Spaces:
Runtime error
Runtime error
"""
Test script to verify the NL→SQL Leaderboard system works correctly.
"""
| import os | |
| import sys | |
| import time | |
| # Add src to path for imports | |
| sys.path.append('src') | |
| from evaluator import evaluator, DatasetManager | |
| from models_registry import models_registry | |
| from scoring import scoring_engine | |
def test_dataset_discovery():
    """Check that dataset discovery finds the NYC Taxi sample dataset.

    Instantiates ``DatasetManager`` and looks for the ``"nyc_taxi_small"``
    key in whatever ``get_datasets()`` returns (presumably a mapping keyed by
    dataset name, since ``.keys()`` is called on it — TODO confirm).

    Exceptions are not caught here; ``main`` reports them as a failure.

    Returns:
        bool: True if ``"nyc_taxi_small"`` was discovered, False otherwise.
    """
    print("Testing dataset discovery...")
    dataset_manager = DatasetManager()
    datasets = dataset_manager.get_datasets()
    print(f"Found datasets: {list(datasets.keys())}")
    if "nyc_taxi_small" in datasets:
        print("✅ NYC Taxi dataset found")
        return True
    print("❌ NYC Taxi dataset not found")
    return False
def test_models_loading():
    """Check that the model registry exposes at least one model.

    Prints the ``name`` of every model returned by
    ``models_registry.get_models()``. Exceptions are not caught here;
    ``main`` reports them as a failure.

    Returns:
        bool: True if at least one model was returned, False otherwise.
    """
    print("\nTesting models loading...")
    models = models_registry.get_models()
    print(f"Found models: {[model.name for model in models]}")
    if models:
        print("✅ Models loaded successfully")
        return True
    print("❌ No models found")
    return False
def test_database_creation():
    """Check that a database file can be built for the NYC Taxi dataset.

    Calls ``DatasetManager().create_database("nyc_taxi_small")`` and checks
    that the returned path exists on disk. The file is then deleted so
    repeated runs start clean.

    Returns:
        bool: True if the database file was created, False if creation
        raised or no file exists at the returned path.
    """
    print("\nTesting database creation...")
    try:
        dataset_manager = DatasetManager()
        db_path = dataset_manager.create_database("nyc_taxi_small")
    except Exception as e:
        print(f"❌ Database creation failed: {e}")
        return False

    # A falsy path (e.g. None) would make os.path.exists raise a TypeError,
    # so treat it as "not created" explicitly.
    if not db_path or not os.path.exists(db_path):
        print("❌ Database file not created")
        return False

    print("✅ Database created successfully")
    # Cleanup is isolated from the creation check: failing to delete the
    # file must not be reported as a creation failure.
    try:
        os.remove(db_path)
    except OSError as e:
        print(f"⚠️ Could not remove database file {db_path}: {e}")
    return True
def test_cases_loading():
    """Check that test cases can be loaded for the NYC Taxi dataset.

    Calls ``DatasetManager().load_cases("nyc_taxi_small")`` and requires a
    non-empty result.

    Returns:
        bool: True if at least one case was loaded, False if none were found
        or loading raised.
    """
    print("\nTesting cases loading...")
    try:
        dataset_manager = DatasetManager()
        cases = dataset_manager.load_cases("nyc_taxi_small")
        print(f"Found {len(cases)} test cases")
        if not cases:
            print("❌ No test cases found")
            return False
        print("✅ Test cases loaded successfully")
        return True
    except Exception as e:
        print(f"❌ Cases loading failed: {e}")
        return False
def test_prompt_templates(*, dialects=("presto", "bigquery", "snowflake"),
                          prompts_dir="prompts"):
    """Check that a prompt template file exists for every SQL dialect.

    Templates are expected at ``<prompts_dir>/template_<dialect>.txt``. Every
    dialect is checked (no early exit), so all missing templates are reported
    in one run.

    Args:
        dialects: Dialect names to check. Defaults to presto, bigquery and
            snowflake.
        prompts_dir: Directory holding the templates. A relative path is
            resolved against the current working directory.

    Returns:
        bool: True only if every template exists as a regular file.
    """
    print("\nTesting prompt templates...")
    all_exist = True
    for dialect in dialects:
        template_path = os.path.join(prompts_dir, f"template_{dialect}.txt")
        # isfile (not exists) so a same-named directory is not mistaken
        # for a template.
        if os.path.isfile(template_path):
            print(f"✅ {dialect} template found")
        else:
            print(f"❌ {dialect} template not found")
            all_exist = False
    return all_exist
def test_scoring_engine():
    """Check that the scoring engine yields a composite score in [0, 1].

    Builds a ``scoring.Metrics`` instance from fixed sample values and passes
    it to ``scoring_engine.compute_composite_score``.

    Returns:
        bool: True if the score lies in the closed range [0.0, 1.0], False
        if it is out of range (NaN included, since every comparison with it
        is False) or if any step raised.
    """
    print("\nTesting scoring engine...")
    try:
        from scoring import Metrics

        # Fixed sample metrics; the exact values only need to be plausible.
        metrics = Metrics(
            correctness_exact=1.0,
            result_match_f1=0.8,
            exec_success=1.0,
            latency_ms=100.0,
            readability=0.9,
            dialect_ok=1.0,
        )
        score = scoring_engine.compute_composite_score(metrics)
        print(f"✅ Composite score computed: {score}")
        if 0.0 <= score <= 1.0:
            print("✅ Score is in valid range")
            return True
        print("❌ Score is out of valid range")
        return False
    except Exception as e:
        print(f"❌ Scoring engine test failed: {e}")
        return False
def test_sql_execution():
    """Check that DuckDB is importable and can run a basic query.

    Creates an in-memory table holding two rows, then runs and fetches
    ``SELECT COUNT(*)``.

    Returns:
        bool: True if every step succeeded, False if any step raised
        (including ``ImportError`` when duckdb is not installed).
    """
    print("\nTesting SQL execution...")
    try:
        import duckdb

        conn = duckdb.connect(":memory:")
        try:
            conn.execute("CREATE TABLE test (id INTEGER, name VARCHAR(10))")
            conn.execute("INSERT INTO test VALUES (1, 'Alice'), (2, 'Bob')")
            # fetchone() reads the scalar directly; fetchdf() would
            # additionally require pandas just to read one value.
            (row_count,) = conn.execute("SELECT COUNT(*) FROM test").fetchone()
        finally:
            # Close even when a statement fails, so the connection is not leaked.
            conn.close()
        print(f"✅ SQL execution successful: {row_count} rows")
        return True
    except Exception as e:
        print(f"❌ SQL execution failed: {e}")
        return False
def test_sqlglot_transpilation():
    """Check that sqlglot can parse a query and render it in several dialects.

    Parses a fixed ``SELECT COUNT(*)`` query once, then re-renders it for
    presto, bigquery and snowflake, printing each result.

    Returns:
        bool: True if parsing and every rendering succeeded, False if any
        step raised (including ``ImportError`` when sqlglot is missing).
    """
    print("\nTesting SQL transpilation...")
    try:
        import sqlglot

        sql = "SELECT COUNT(*) FROM trips"
        parsed = sqlglot.parse_one(sql)
        for dialect in ["presto", "bigquery", "snowflake"]:
            transpiled = parsed.sql(dialect=dialect)
            print(f"✅ {dialect} transpilation: {transpiled}")
        return True
    except Exception as e:
        print(f"❌ SQL transpilation failed: {e}")
        return False
def main():
    """Run every system check and print a pass/fail summary.

    Each check is a zero-argument callable returning a truthy value on
    success. An exception escaping a check is reported and counted as a
    failure, so the remaining checks still run.

    Returns:
        bool: True if every check passed, False otherwise.
    """
    print("NL→SQL Leaderboard System Test")
    print("=" * 40)
    tests = [
        test_dataset_discovery,
        test_models_loading,
        test_database_creation,
        test_cases_loading,
        test_prompt_templates,
        test_scoring_engine,
        test_sql_execution,
        test_sqlglot_transpilation,
    ]
    passed = 0
    total = len(tests)
    for test in tests:
        try:
            if test():
                passed += 1
        except Exception as e:
            print(f"❌ Test {test.__name__} failed with exception: {e}")
    print("\n" + "=" * 40)
    print(f"Test Results: {passed}/{total} tests passed")
    if passed == total:
        print("🎉 All tests passed! The system is ready to use.")
        return True
    print("❌ Some tests failed. Please check the issues above.")
    return False
if __name__ == "__main__":
    # Translate the overall boolean outcome into a conventional process exit
    # status: 0 when every check passed, 1 otherwise.
    all_passed = main()
    exit_code = 0 if all_passed else 1
    sys.exit(exit_code)