diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..ab6ce16 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,11 @@ +[run] +source = server + +[report] +omit = + */__init__.py + server/test/* + server/utilities/constants/* + +exclude_lines = + ^\s*(import|from)\b.* #match import statements \ No newline at end of file diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml new file mode 100644 index 0000000..7f45ff0 --- /dev/null +++ b/.github/workflows/coverage.yaml @@ -0,0 +1,39 @@ +name: Test and Coverage + +on: + pull_request: + branches: + - main + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -r server/requirements.txt + + - name: Run tests with coverage + run: | + python -m pytest --cov=server server/test + env: + OPENAI_API_KEY: "DUMMY KEY" + ANTHROPIC_API_KEY: "DUMMY KEY" + GOOGLE_AI_API_KEY: "DUMMY KEY" + DEEPSEEK_API_KEY: "DUMMY KEY" + DASHSCOPE_API_KEY: "DUMMY KEY" + ALL_GOOGLE_API_KEYS: "DUMMY KEY" + + - name: Upload coverage to Coveralls + env: + COVERALLS_REPO_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }} + run: | + coveralls \ No newline at end of file diff --git a/server/app/main.py b/server/app/main.py index ddb5f7d..7b8e7e2 100644 --- a/server/app/main.py +++ b/server/app/main.py @@ -5,9 +5,6 @@ from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from services.client_factory import ClientFactory -from utilities.batch_job import (create_batch_input_file, - download_batch_job_output_file, - upload_and_run_batch_job) from utilities.config import PATH_CONFIG from utilities.constants.LLM_enums import LLMType, ModelType from utilities.constants.prompts_enums import PromptType @@ -81,35 +78,6 @@ async def test_cost_estimations(): raise HTTPException(status_code=500, detail=str(e)) -@app.post("/process_batch_job") -async def process_batch_job(request: BatchJobRequest): - try: - messages = [] - - create_msg = create_batch_input_file( - prompt_type_with_shots = { - request.prompt_type: request.shots, - PromptType.DAIL_SQL: 5 - }, - model=request.model, - temperature=request.temperature, - max_tokens=request.max_tokens, - database_name=request.database_name - ) - if create_msg.get('success'): - messages.append(create_msg['success']) - - input_file_id, batch_job_id = upload_and_run_batch_job(request.database_name) - - download_msg = download_batch_job_output_file(batch_job_id, request.database_name) - if download_msg.get('success'): - messages.append(download_msg['success']) - - return {'success': messages} - - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - @app.post("/queries/generate-and-execute/") async def generate_and_execute_sql_query(body: QueryGenerationRequest): question = body.question @@ -224,18 +192,6 @@ def mask_single_question_and_query(request: MaskRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) -@app.post("/masking/file/") -def mask_question_and_answer_file_by_filename(request: MaskFileRequest): - try: - table_and_column_names = get_array_of_table_and_column_name(PATH_CONFIG.sqlite_path()) - masked_file_name = mask_question_and_answer_files(database_name=request.database_name, table_and_column_names=table_and_column_names) - - return { - "masked_file_name": masked_file_name - } - - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) @app.post("/prompts/generate/") async def generate_prompt(request: PromptGenerationRequest): diff --git a/server/requirements.txt b/server/requirements.txt index 5beabcb..39436ba 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -22,4 +22,7 @@ tiktoken==0.8.0 tqdm==4.67.0 uvicorn==0.34.0 onnxruntime==1.16.3 -pre-commit==3.7.0 \ No newline at end of file +pre-commit==3.7.0 +pytest==8.3.5 +pytest-cov==6.1.1 +coveralls==4.0.1 \ No newline at end of file diff --git a/server/scripts/download_batch_output.py b/server/scripts/download_batch_output.py deleted file mode 100644 index 15fc119..0000000 --- a/server/scripts/download_batch_output.py +++ /dev/null @@ -1,78 +0,0 @@ -import json -import time - -from tqdm import tqdm -from utilities.batch_job import download_batch_job_output_file -from utilities.config import PATH_CONFIG -from utilities.constants.script_constants import BatchFileStatus - - -def download_batch_output_files(metadata_path: str): - """Download all batch jobs output files from OpenAI corresponding to the given meta data file.""" - - with open(metadata_path, "r") as file: - metadata = json.load(file) - - count = 0 - batch_jobs = metadata.get("databases", {}) - - # Retry until all batch jobs are downloaded - with tqdm(total=len(batch_jobs), desc="Downloading batch jobs") as progress_bar: - progress_bar.n = sum( - 1 - for batch_job_data in batch_jobs.values() - if batch_job_data["state"] == BatchFileStatus.DOWNLOADED.value - ) - progress_bar.refresh() - - while progress_bar.n < len(batch_jobs): - tqdm.write(f"Try: {count}") - for database, batch_job_data in batch_jobs.items(): - # Skip already downloaded jobs - if batch_job_data["state"] == BatchFileStatus.DOWNLOADED.value: - continue - - try: - tqdm.write(f"Downloading: {batch_job_data['batch_job_id']}") - download_batch_job_output_file( - batch_job_id=batch_job_data["batch_job_id"], - download_file_path=PATH_CONFIG.batch_output_path(database_name=database), - ) - - # Update the state to downloaded - batch_job_data["state"] = BatchFileStatus.DOWNLOADED.value - metadata["databases"][database] = batch_job_data - with open(metadata_path, "w") as file: - json.dump(metadata, file, indent=4) - - progress_bar.update(1) - - except Exception as e: - tqdm.write(str(e)) - - if progress_bar.n < len(batch_jobs): - time.sleep(10) # retry after 10 secs - - count += 1 - - -if __name__ == "__main__": - """ - To run this script: - - 1. Ensure the correct metadata file for batch jobs is available: - - The metadata file should be in the directory specified by `PATH_CONFIG.batch_job_metadata_dir()` with the time stamp format `YYYY-MM-DD_HH-MM-SS.json`. - - 2. Run the script: - - In the terminal, run `python3 -m scripts.download_batch_jobs`. - - Expected Output: - - Batch job output files will be downloaded to the appropriate directories, and the metadata file will be updated to reflect the download status. - - If any jobs fail to download, the script will retry the download every 10 seconds. - """ - - # Inputs - time_stamp = "2024-12-11_17:01:02.json" - metadata_path = f"{PATH_CONFIG.batch_job_metadata_dir}/{time_stamp}" - - download_batch_output_files(metadata_path) diff --git a/server/scripts/generate_batch_file.py b/server/scripts/generate_batch_file.py deleted file mode 100644 index 1ecb0af..0000000 --- a/server/scripts/generate_batch_file.py +++ /dev/null @@ -1,132 +0,0 @@ -import json -import os -from datetime import datetime - -from tqdm import tqdm -from utilities.batch_job import (create_batch_input_file, - upload_and_run_batch_job) -from utilities.config import PATH_CONFIG -from utilities.constants.LLM_enums import ModelType -from utilities.constants.prompts_enums import PromptType -from utilities.constants.script_constants import BatchFileStatus -from utilities.cost_estimation import calculate_cost_and_tokens_for_file - - -def generate_and_run_batch_input_files( - prompt_types_with_shots: dict, - model: ModelType, - temperature: float, - max_tokens: int, - dataset_dir: str, -): - databases = [d for d in os.listdir(dataset_dir)] - - total_estimated_cost = 0.0 - - metadata = { - "batch_info": { - "dataset_path": dataset_dir, - "candidates": { - str(key): value for key, value in prompt_types_with_shots.items() - }, - "model": model.value, - "temperature": temperature, - "max_tokens": max_tokens, - "total_estimated_cost": 0.0, - }, - "databases": {} - } - - - for database in tqdm(databases[:2], desc="Creating and running batch input files"): - db_name = os.path.splitext(database)[0] # remove .db in case we are working in synthetic data dir - - try: - create_batch_input_file( - prompt_types_with_shots=prompt_types_with_shots, - model=model, - temperature=temperature, - max_tokens=max_tokens, - database_name=db_name, - ) - - _, estimated_cost, _ = calculate_cost_and_tokens_for_file( - file_path=PATH_CONFIG.batch_input_path(database_name=db_name), - model=model, - is_batched=True, - ) - - total_estimated_cost += estimated_cost - - _, batch_job_id = upload_and_run_batch_job( - database_name=db_name - ) - - # Store database-specific metadata - metadata["databases"][db_name] = { - "batch_job_id": batch_job_id, - "state": BatchFileStatus.UPLOADED.value, - "estimated_cost": estimated_cost - } - - except Exception as e: - tqdm.write(str(e)) - - metadata["batch_info"]["total_estimated_cost"] = total_estimated_cost - - # Storing batch job metadata with the corresponding DB directory - now = datetime.now() - time_stamp = now.strftime("%Y-%m-%d_%H:%M:%S") - metadata_file_path = os.path.join(PATH_CONFIG.batch_job_metadata_dir(), f"{time_stamp}.json") - - os.makedirs(PATH_CONFIG.batch_job_metadata_dir(), exist_ok=True) - - with open(metadata_file_path, "w") as file: - json.dump(metadata, file) - - return metadata_file_path - - -if __name__ == "__main__": - """ - This script automates the process of creating batch input files for multiple databases, uploading the files to an API client, and running batch jobs for LLM evaluation. - To run this script: - - 1. Ensure you have set the correct `PATH_CONFIG.dataset_type` and `PATH_CONFIG.sample_dataset_type` in `utilities.config`: - - Set `PATH_CONFIG.dataset_type` to DatasetType.BIRD_TRAIN for training data. - - Set `PATH_CONFIG.dataset_type` to DatasetType.BIRD_DEV for development data. - - Set `PATH_CONFIG.sample_dataset_type` to DatasetType.BIRD_DEV or DatasetType.BIRD_TRAIN. - - 2. Run the Script: - - Navigate to the project directory. - - Execute the script using the following command: - python3 -m scripts.generate_batch_files - - 3. Expected Functionality: - - The script performs the following steps: - - Creates batch input files for candidate prompt types and shots for each database. - - Uploads the batch input files to the LLM client (e.g., OpenAI) and initiates batch jobs. - - Stores batch job metadata, linking batch jobs to their respective databases. - - Batch input files are saved to paths defined in `PATH_CONFIG.batch_output_file()`. - - Batch job metadata is saved as a timestamped `.json` file in the `PATH_CONFIG.batch_job_metadata_dir()`. - - Extra Note: - - Make sure that the dataset is split in the correct ratio. e.g. 50% for test data and 50% for examples data. - - If you want to generate a batch input file for only one candidate, just set the prompt_type_with_shots accordingly. - """ - - # Inputs, update these accordingly - prompt_type_with_shots = {PromptType.CODE_REPRESENTATION: 0, PromptType.DAIL_SQL: 5,PromptType.OPENAI_DEMO: 0} - temperature = 0.5 - model = ModelType.OPENAI_GPT4_O_MINI - max_tokens = 1000 - - metadata_file_path = generate_and_run_batch_input_files( - prompt_types_with_shots=prompt_type_with_shots, - model=model, - temperature=temperature, - max_tokens=max_tokens, - dataset_dir=PATH_CONFIG.dataset_dir(), - ) - - print("Batch Metadata File saved at: ", metadata_file_path) diff --git a/server/scripts/process_batch_output.py b/server/scripts/process_batch_output.py deleted file mode 100644 index a9356e2..0000000 --- a/server/scripts/process_batch_output.py +++ /dev/null @@ -1,114 +0,0 @@ -import json - -from tqdm import tqdm -from utilities.config import PATH_CONFIG -from utilities.constants.script_constants import BatchFileStatus - - -def format_batch_output(database, batch_output_data, test_data): - """ Formats the batch output for BIRD evaluation. """ - - predicted_scripts = {} - gold_items = [] - grouped_candidates = {} - - # Group candidates by question_id - for prediction in batch_output_data: - custom_id = prediction['custom_id'][8:] - custom_id_number = int(custom_id.split('-')[-1]) - - if custom_id_number not in grouped_candidates: - grouped_candidates[custom_id_number] = [] - - grouped_candidates[custom_id_number].append(prediction) - - for question_id, candidates in grouped_candidates.items(): - # TODO: Implement the judge LLM logic to select the best candidate based on some criteria. - chosen_candidate = candidates[0] - - # Get the SQL prediction from the chosen candidate - pred_sql = chosen_candidate['response']['body']['choices'][0]['message']['content'] - - # Clean up SQL predictions - if pred_sql.startswith('```sql') and pred_sql.endswith('```'): - pred_sql = pred_sql.strip('```sql\n').strip('```') - - predicted_scripts[question_id] = f'{pred_sql}\t----- bird -----\t{database}' - - # Match gold query for the prediction - for item in test_data: - if question_id == item['question_id']: - gt_sql = item['SQL'] - gold_items.append(f'{gt_sql}\t{database}') - - formatted_pred_path = PATH_CONFIG.formatted_predictions_path(database_name=database) - with open(formatted_pred_path, 'w') as file: - json.dump(predicted_scripts, file) - - gold_sql_path = PATH_CONFIG.test_gold_path(database_name=database) - with open(gold_sql_path, 'w') as file: - for item in gold_items: - file.write(f'{item}\n') - - -def format_batch_output_files(metadata_path: str): - - # Load batch job metadata - with open(metadata_path, "r") as file: - metadata = json.load(file) - - batch_jobs = metadata.get("databases", {}) - - for database, batch_job_data in tqdm(batch_jobs.items(), desc="Formatting batch output files"): - if batch_job_data["state"] == BatchFileStatus.FORMATTED_PRED_FILE.value: - continue - - if batch_job_data["state"] != BatchFileStatus.DOWNLOADED.value: - tqdm.write(f"Skipping {batch_job_data['database']} because is it not downloaded") - continue - - batch_output_path = PATH_CONFIG.batch_output_path(database_name=database) - - # Load batch output - with open(batch_output_path, 'r') as file: - batch_output_data = [json.loads(line) for line in file] - - # Load test file - test_file_path = PATH_CONFIG.processed_test_path(database_name=database) - with open(test_file_path, 'r') as file: - test_data = json.loads(file.read()) - - # Format batch output and generate gold queries - format_batch_output(database, batch_output_data, test_data) - - # Update the state to formatted - batch_job_data["state"] = BatchFileStatus.FORMATTED_PRED_FILE.value - metadata["databases"][database] = batch_job_data - with open(metadata_path, "w") as file: - json.dump(metadata, file, indent=4) - - -if __name__ == "__main__": - """ - To run this script: - - 1. Ensure that batch output files and test files are available: - - The batch output files should be in the directory specified by `GENERATE_BATCH_SCRIPT_PATH` for each database. - - Test files for each database should be located in the same directory with the filename format `test_{database}.json`. - - The batch job metadata file should be present in the directory specified by `BATCH_JOB_METADATA_DIR` with the time stamp format `YYYY-MM-DD_HH-MM-SS.json`. - - 2. Run the script: - - In the terminal, run `python3 -m scripts.format_batch_output`. - - The script will create two files for each database: - - A JSON file with the formatted predictions (`FORMATTED_PRED_FILE_{database}.json`). - - A SQL file with the gold queries (`gold_{database}.sql`). - - Expected Output: - - The metadata file will be updated with the new state for each batch job, marking it as formatted. - - The script will skip any jobs that are already formatted or not yet downloaded. - """ - # Inputs - time_stamp = "2024-12-11_17:01:02.json" - metadata_path = f"{PATH_CONFIG.batch_job_metadata_dir()}/{time_stamp}" - - format_batch_output_files(metadata_path) diff --git a/server/test/test_main.py b/server/test/test_main.py index b1c7301..69f8389 100644 --- a/server/test/test_main.py +++ b/server/test/test_main.py @@ -251,56 +251,6 @@ def test_masking_failure_due_to_exception( assert response.json()["detail"] == "Failed to retrieve table and column names" -class TestMaskQuestionAndAnswerFile: - - @patch("app.main.get_array_of_table_and_column_name") - @patch("app.main.mask_question_and_answer_files") - def test_successful_file_masking( - self, - mock_mask_question_and_answer_files, - mock_get_array_of_table_and_column_names, - ): - mock_get_array_of_table_and_column_names.return_value = [ - {"table": "test_table", "columns": ["id", "name"]} - ] - mock_mask_question_and_answer_files.return_value = "masked_file_name.txt" - - request_data = {"database_name": "test"} - response = client.post("/masking/file/", json=request_data) - - assert response.status_code == 200 - response_json = response.json() - assert "masked_file_name" in response_json - assert response_json["masked_file_name"] == "masked_file_name.txt" - - mock_get_array_of_table_and_column_names.assert_called_once_with( - PATH_CONFIG.sqlite_path() - ) - mock_mask_question_and_answer_files.assert_called_once_with( - database_name=request_data["database_name"], - table_and_column_names=mock_get_array_of_table_and_column_names.return_value, - ) - - @patch("app.main.get_array_of_table_and_column_name") - def test_file_masking_failure_due_to_exception( - self, mock_get_array_of_table_and_column_names - ): - mock_get_array_of_table_and_column_names.side_effect = Exception( - "Failed to retrieve table and column names" - ) - - request_data = {"database_name": "test"} - - response = client.post("/masking/file/", json=request_data) - - assert response.status_code == 500 - assert response.json()["detail"] == "Failed to retrieve table and column names" - - mock_get_array_of_table_and_column_names.assert_called_once_with( - PATH_CONFIG.sqlite_path() - ) - - class TestGeneratePrompt: @patch("app.main.PromptFactory.get_prompt_class") @@ -371,7 +321,7 @@ def test_successful_database_change( mock_format_schema.return_value = "Database Schema" mock_active_database.value = "hotel" - request_data = {"database_type": "hotel"} + request_data = {"database_name": "hotel"} response = client.post("/database/change/", json=request_data) @@ -379,13 +329,12 @@ def test_successful_database_change( response_json = response.json() assert "database_type" in response_json - assert response_json["database_type"] == "hotel" assert "schema" in response_json assert response_json["schema"] == "Database Schema" - mock_set_database.assert_called_once_with(DatabaseType.HOTEL) + mock_set_database.assert_called_once_with("hotel") mock_format_schema.assert_called_once_with( - FormatType.CODE, PATH_CONFIG.sqlite_path() + FormatType.CODE, database_name="hotel" ) def test_change_database_failure(self): @@ -398,12 +347,12 @@ def test_change_database_failure(self): class TestGetDatabaseSchema: @patch("app.main.format_schema") - @patch("app.main.PATH_CONFIG.database_name", new_callable=MagicMock) + @patch("app.main.PATH_CONFIG.database_name", "hotel") def test_successful_get_database_schema( - self, mock_active_database, mock_format_schema + self, mock_format_schema ): mock_format_schema.return_value = "Database Schema" - mock_active_database.value = "hotel" + # mock_active_database.value = "hotel" response = client.get("/database/schema/") @@ -415,7 +364,7 @@ def test_successful_get_database_schema( assert response_json["schema"] == "Database Schema" mock_format_schema.assert_called_once_with( - FormatType.CODE, PATH_CONFIG.sqlite_path() + FormatType.CODE, "hotel" ) @patch("app.main.format_schema") diff --git a/server/utilities/batch_job.py b/server/utilities/batch_job.py deleted file mode 100644 index 7f57345..0000000 --- a/server/utilities/batch_job.py +++ /dev/null @@ -1,126 +0,0 @@ -import json -import os -from typing import Dict - -from app import db -from services.client_factory import ClientFactory -from tqdm import tqdm -from utilities.config import PATH_CONFIG, ChromadbClient -from utilities.constants.LLM_enums import LLMType, ModelType -from utilities.constants.prompts_enums import PromptType -from utilities.constants.response_messages import ( - ERROR_BATCH_INPUT_FILE_CREATION, ERROR_BATCH_JOB_CREATION, - ERROR_BATCH_JOB_STATUS_NOT_COMPLETED, ERROR_DOWNLOAD_BATCH_FILE, - ERROR_SCHEMA_FILE_NOT_FOUND, SUCCESS_BATCH_INPUT_FILE_CREATED, - SUCCESS_BATCH_OUTPUT_FILE_DOWNLOADED) -from utilities.constants.script_constants import BatchJobStatus -from utilities.prompts.prompt_factory import PromptFactory -from utilities.vectorize import vectorize_data_samples - - -def create_batch_input_file( - prompt_types_with_shots: Dict[PromptType, int], - model: ModelType, - temperature: float, - max_tokens: int, - database_name: str, -): - # Change Database - db.set_database(database_name) - ChromadbClient.reset_chroma( - PATH_CONFIG.processed_train_path(database_name=database_name) - ) - vectorize_data_samples() - - # Set file path - batch_input_file_path = PATH_CONFIG.batch_input_path(database_name=database_name) - - # Load test data - test_data = None - try: - with open(PATH_CONFIG.processed_test_path(database_name=database_name), "r") as f: - test_data = json.load(f) - except FileNotFoundError as e: - raise RuntimeError(ERROR_SCHEMA_FILE_NOT_FOUND.format(error=(str(e)))) - - try: - # Create batch input items - batch_input = [] - for item in tqdm(test_data, desc="Processing queries", unit="item"): - for prompt_type, shots in prompt_types_with_shots.items(): - prompt = PromptFactory.get_prompt_class( - prompt_type=prompt_type, - target_question=item["question"], - shots=shots, - ) - batch_input.append( - { - "custom_id": f"{prompt_type.value}-{item['question_id']}", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": model.value, - "messages": [ - {"role": "system", "content": prompt}, - ], - "max_tokens": max_tokens, - "temperature": temperature, - }, - } - ) - - # Create and save batch input file - batch_input_dir = os.path.dirname(batch_input_file_path) - os.makedirs(batch_input_dir, exist_ok=True) - - with open(batch_input_file_path, "w") as f: - for item in batch_input: - json.dump(item, f) - f.write("\n") - - return {"success": SUCCESS_BATCH_INPUT_FILE_CREATED} - - except Exception as e: - raise RuntimeError( - ERROR_BATCH_INPUT_FILE_CREATION.format( - file_path=batch_input_file_path, error=str(e) - ) - ) - - -def upload_and_run_batch_job(database_name: str): - try: - client = ClientFactory.get_client(LLMType.OPENAI) - - batch_input_file = client.upload_batch_input_file(database_name) - batch_job = client.create_batch_job(batch_input_file.id) - - return batch_input_file.id, batch_job.id - - except Exception as e: - raise RuntimeError(ERROR_BATCH_JOB_CREATION.format(error=str(e))) - - -def download_batch_job_output_file(batch_job_id: str, download_file_path: str): - try: - client = ClientFactory.get_client(LLMType.OPENAI) - - batch_job = client.client.batches.retrieve(batch_job_id) - - if batch_job.status != BatchJobStatus.COMPLETED.value: - raise RuntimeError( - ERROR_BATCH_JOB_STATUS_NOT_COMPLETED.format(status=batch_job.status) - ) - - batch_output_file = client.download_file( - batch_job.output_file_id, download_file_path - ) - - return { - "success": SUCCESS_BATCH_OUTPUT_FILE_DOWNLOADED.format( - output_file_path=batch_output_file - ) - } - - except Exception as e: - raise RuntimeError(ERROR_DOWNLOAD_BATCH_FILE.format(error=str(e))) diff --git a/server/utilities/path_config.py b/server/utilities/path_config.py index 6d6f7b7..aac7468 100644 --- a/server/utilities/path_config.py +++ b/server/utilities/path_config.py @@ -19,7 +19,7 @@ class PathConfig: def __post_init__(self): self.database_name = next( - (p.name for p in self.dataset_dir().iterdir() if p.is_dir()), None + (p.name for p in self.dataset_dir().iterdir() if p.is_dir()), "" ) def set_database(self, database_name: str): @@ -172,43 +172,6 @@ def minhashes_path(self, database_name: Optional[str] = None) -> Path: return None - def batch_input_path(self, database_name: Optional[str] = None) -> Path: - database_name = database_name if database_name is not None else self.database_name - - if self.dataset_type in (DatasetType.BIRD_TRAIN, DatasetType.BIRD_DEV, DatasetType.BIRD_TEST): - return ( - self.database_dir(database_name=database_name) - / "batch_jobs" - / f"batch_job_input_{database_name}.jsonl" - ) - - elif self.dataset_type == DatasetType.SYNTHETIC: - return ( - self.repo_root - / "data" - / "batch_jobs" - / "batch_input_files" - / f"{database_name}_batch_job_input.jsonl" - ) - - def batch_output_path(self, database_name: Optional[str] = None) -> Path: - database_name = database_name if database_name is not None else self.database_name - - if self.dataset_type in (DatasetType.BIRD_TRAIN, DatasetType.BIRD_DEV, DatasetType.BIRD_TEST): - return ( - self.database_dir(database_name=database_name) - / "batch_jobs" - / f"batch_job_output_{database_name}.jsonl" - ) - - elif self.dataset_type == DatasetType.SYNTHETIC: - return ( - self.repo_root - / "data" - / "batch_jobs" - / "batch_input_files" - / f"{database_name}_batch_job_output.jsonl" - ) def description_dir(self, database_name: Optional[str] = None, dataset_type: Optional[DatasetType] = None ) -> Path: database_name = database_name if database_name is not None else self.database_name @@ -239,8 +202,6 @@ def column_meaning_path(self, dataset_type: Optional[DatasetType] = None) -> Pat def bird_results_dir(self) -> Path: return self.repo_root / "bird_results" - def batch_job_metadata_dir(self) -> Path: - return self.repo_root / "batch_job_metadata" def bird_file_path(self, dataset_type: Optional[DatasetType] = None) -> Path: dataset_type = dataset_type if dataset_type is not None else self.dataset_type