Prepare sample database script test cases #128
Status: Open. AwaisKamran wants to merge 38 commits into main from prepare-sample-database-script-test-cases.
Commits (38)
All commits by AwaisKamran:

- e6c212e  refactor: prepare_sample_dataset updated
- 7307e36  refactor: misleading parameters removed from function
- 05994f9  Merge branch 'main' of github.com:Conrad-X/text2SQL into refactor-pre…
- 9091924  fix: replace add_question_id_for_bird_train with add_sequential_ids_t…
- 479da2c  feat: Replace alive bar with tqdm
- 5390337  refactor: replace write_train_data_to_file with save_json_to_file
- b206ccc  refactor: update docstrings
- 760e0f9  refactor: rename error_messages.py to response_messages.py
- 1bba42a  refactor: reused bird utils constants
- 75be5ec  refactor: response messages updated for consistency
- bc96600  refactor: update import paths
- daced9a  refactor: use with statement instead of conventional connection creat…
- 8afeb3d  refactor: create sql connection explicitly
- af5dd1c  refactor: update import statements
- f74ab77  docs: add docstring to create_database_connection function
- 0a26d3a  refactor: fixed docstrings
- 37363b2  refactor: copilot suggestions accommodated
- 51a1c32  refactor: segregated current_db and train_data fetching
- 7dd755e  refactor: typehints added
- b6e5a09  docs: add module docstring to indexing_constants
- 3c5a6a7  Merge branch 'main' of github.com:Conrad-X/text2SQL into refactor-pre…
- b0cfc7d  feat: updated TODO comment for add_database_descriptions
- cf1b0ea  fix: import statement changed
- b503583  chore: merge conflicts resolved
- 68fa986  refactor: error message name issue
- 7d659c8  refactor: add_schema_used updated
- 8cd8470  fix: update parameters for add_schema_used method
- b8f2238  chore: LLMConfig added to add_database_descriptions & typehints added…
- 85d1a07  chore: connection removed from add_schema_used
- bb9e8bc  fix: fixed docstrings
- 49b44fc  chore: added try/catch around updating schema_used
- db53d4b  chore: item[SCHEMA_USED] updated
- 0a5c46b  fix: close_connection method removed
- c52ebaf  Merge branch 'main' of github.com:Conrad-X/text2SQL into refactor-pre…
- e526088  Merge branch 'main' of github.com:Conrad-X/text2SQL into refactor-pre…
- eeabf76  chore: add test cases for prepare_sample_dataset.py
- 320d9ab  chore: update test cases
- 21dbc1e  fix: update test cases
File 1 (new file, 5 lines; a package docstring):

```diff
@@ -0,0 +1,5 @@
+"""
+text2SQL Server Package.
+
+This package handles the server-side operations for converting natural language text to SQL queries.
+"""
```
File 2 (prepare_sample_dataset.py, path as imported in the tests below):

```diff
@@ -15,6 +15,7 @@
 import os
 import shutil
 from pathlib import Path
 
 from preprocess.add_descriptions_bird_dataset import add_database_descriptions
+from tqdm import tqdm
 
@@ -85,10 +86,16 @@ def create_train_file(train_file: str) -> None:
         if PATH_CONFIG.sample_dataset_type == DatasetType.BIRD_TRAIN:
             add_sequential_ids_to_questions(train_file)
     except FileNotFoundError as e:
+        if isinstance(e, FileNotFoundError):
+            raise
         logger.error(ERROR_FILE_NOT_FOUND.format(error=str(e)))
     except IOError as e:
+        if isinstance(e, IOError):
+            raise
         logger.error(ERROR_IO.format(error=str(e)))
     except Exception as e:
+        if isinstance(e, PermissionError):
+            raise
         logger.error(UNEXPECTED_ERROR.format(error=str(e)))
 
 def copy_bird_train_file(train_file: str) -> None:
@@ -131,7 +138,7 @@ def get_train_data(train_file: str) -> list:
         logger.error(ERROR_INVALID_TRAIN_FILE)
         return None
 
-def add_schema_used(train_data: list, dataset_type: str) -> None:
+def add_schema_used(train_data: list, dataset_type: str, train_file: Path) -> None:
     """
     Add schema_used field to each item in the train data.
 
@@ -162,7 +169,9 @@ def add_schema_used(train_data: list, dataset_type: str) -> None:
         except Exception as e:
             logger.warning(WARNING_FAILED_TO_ADD_SCHEMA_USED.format(question_id=item[QUESTION_ID_KEY]))
             item[SCHEMA_USED] = None
-    except KeyboardInterrupt:
+    except KeyboardInterrupt as e:
+        if isinstance(e, KeyboardInterrupt):
+            raise
         logger.error(ERROR_USER_KEYBOARD_INTERRUPTION)
 
     finally:
@@ -198,10 +207,10 @@ def add_schema_used(train_data: list, dataset_type: str) -> None:
     """
     train_file = get_train_file_path()
     dataset_type = PATH_CONFIG.sample_dataset_type
-    train_data = get_train_data(train_file)
+    train_data = get_train_data(Path(train_file))
 
     if train_data:
-        add_schema_used(train_data, dataset_type)
+        add_schema_used(train_data, dataset_type, Path(train_file))
 
     # Need to Updated `add_database_descriptions` to work for sample datasets
     add_database_descriptions(
         dataset_type=dataset_type,
```
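An editorial aside on the error handling added in `create_train_file` above: inside an `except FileNotFoundError as e:` clause, `isinstance(e, FileNotFoundError)` is always true, so each guarded `raise` fires unconditionally and the `logger.error` call after it is unreachable. A minimal sketch of the equivalent, simpler pattern (the function name, `source` parameter, and `UNEXPECTED_ERROR` constant here are stand-ins, not the PR's actual code):

```python
import logging
import shutil

logger = logging.getLogger(__name__)

# Hypothetical stand-in for the PR's UNEXPECTED_ERROR message constant
UNEXPECTED_ERROR = "Unexpected error: {error}"


def create_train_file_sketch(train_file: str, source: str) -> None:
    """Equivalent behavior without the always-true isinstance guards."""
    try:
        shutil.copyfile(source, train_file)
    except (FileNotFoundError, IOError, PermissionError):
        # The diff re-raises these unconditionally, so the logger calls
        # placed after each `raise` can never run; re-raise directly.
        raise
    except Exception as e:
        logger.error(UNEXPECTED_ERROR.format(error=str(e)))
```

Note that `FileNotFoundError` and `PermissionError` are both subclasses of `OSError` (aliased as `IOError`), so the tuple could be collapsed further; it is kept expanded here only to mirror the diff.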
File 3 (new test file for prepare_sample_dataset, 141 lines):

```python
import os
import sys
import unittest
from unittest.mock import MagicMock, mock_open, patch

# Append the project root directory to sys.path
sys.path.append(os.path.join(os.path.dirname(__file__), '../../..'))
from pathlib import Path

from utilities.config import PATH_CONFIG
from utilities.constants.bird_utils.indexing_constants import (DB_ID_KEY,
                                                               QUESTION_ID_KEY,
                                                               SCHEMA_USED,
                                                               SQL)
from utilities.constants.database_enums import DatasetType

from server.preprocess.prepare_sample_dataset import (add_schema_used,
                                                      copy_bird_train_file,
                                                      create_train_file,
                                                      get_train_data,
                                                      get_train_file_path)


class TestGetTrainFilePathExisting(unittest.TestCase):
    @patch('os.path.exists', return_value=True)
    @patch('os.makedirs')
    @patch('shutil.copyfile')
    @patch('server.preprocess.prepare_sample_dataset.add_sequential_ids_to_questions')
    def test_get_train_file_path_existing(self, mock_add_ids, mock_copyfile, mock_makedirs, mock_exists):
        PATH_CONFIG.processed_train_path = MagicMock(return_value='/path/to/train_file.json')
        result = get_train_file_path()
        self.assertEqual(result, '/path/to/train_file.json')
        mock_add_ids.assert_not_called()


class TestGetTrainFilePathNonExisting(unittest.TestCase):
    @patch('os.path.exists', return_value=False)
    @patch('os.makedirs')
    @patch('shutil.copyfile')
    @patch('server.preprocess.prepare_sample_dataset.add_sequential_ids_to_questions')
    def test_get_train_file_path_non_existing(self, mock_add_ids, mock_copyfile, mock_makedirs, mock_exists):
        PATH_CONFIG.processed_train_path = MagicMock(return_value='/path/to/train_file.json')
        PATH_CONFIG.sample_dataset_type = DatasetType.BIRD_TRAIN
        result = get_train_file_path()
        self.assertEqual(result, '/path/to/train_file.json')
        mock_add_ids.assert_called_once_with('/path/to/train_file.json')


class TestCreateTrainFile(unittest.TestCase):
    @patch('os.makedirs')
    @patch('shutil.copyfile')
    @patch('server.preprocess.prepare_sample_dataset.add_sequential_ids_to_questions')
    def test_create_train_file(self, mock_add_ids, mock_copyfile, mock_makedirs):
        PATH_CONFIG.sample_dataset_type = DatasetType.BIRD_TRAIN

        # Test normal behavior
        create_train_file('/path/to/train_file.json')
        mock_makedirs.assert_called_once_with('/path/to', exist_ok=True)
        mock_copyfile.assert_called_once()

    @patch('os.makedirs')
    @patch('shutil.copyfile')
    @patch('server.preprocess.prepare_sample_dataset.add_sequential_ids_to_questions')
    def test_file_not_found_error(self, mock_add_ids, mock_copyfile, mock_makedirs):
        """Test FileNotFoundError"""
        mock_copyfile.side_effect = FileNotFoundError
        with self.assertRaises(FileNotFoundError):
            create_train_file('/path/to/train_file.json')

    @patch('os.makedirs')
    @patch('shutil.copyfile')
    @patch('server.preprocess.prepare_sample_dataset.add_sequential_ids_to_questions')
    def test_io_error(self, mock_add_ids, mock_copyfile, mock_makedirs):
        """Test IOError"""
        mock_copyfile.side_effect = IOError
        with self.assertRaises(IOError):
            create_train_file('/path/to/train_file.json')

    @patch('os.makedirs')
    @patch('shutil.copyfile')
    @patch('server.preprocess.prepare_sample_dataset.add_sequential_ids_to_questions')
    def test_permission_error(self, mock_add_ids, mock_copyfile, mock_makedirs):
        """Test PermissionError"""
        mock_copyfile.side_effect = PermissionError
        with self.assertRaises(PermissionError):
            create_train_file('/path/to/train_file.json')


class TestCreateTrainFilePermissionError(unittest.TestCase):
    @patch('shutil.copyfile')
    @patch('os.makedirs', side_effect=PermissionError("Permission denied"))
    def test_create_train_file_permission_error(self, mock_makedirs, mock_copyfile):
        PATH_CONFIG.sample_dataset_type = DatasetType.BIRD_TRAIN
        with self.assertRaises(PermissionError):
            create_train_file('/path/to/train_file.json')


class TestCopyBirdTrainFile(unittest.TestCase):
    @patch('shutil.copyfile')
    def test_copy_bird_train_file(self, mock_copyfile):
        PATH_CONFIG.bird_file_path = MagicMock(return_value='/path/to/source_file.json')
        copy_bird_train_file('/path/to/train_file.json')
        mock_copyfile.assert_called_once_with('/path/to/source_file.json', '/path/to/train_file.json')


class TestGetTrainDataValid(unittest.TestCase):
    @patch('builtins.open', new_callable=mock_open, read_data='{"db_id": "db1", "question_id": "q1", "sql": "SELECT * FROM table"}')
    @patch('os.path.exists', return_value=True)
    @patch('server.preprocess.prepare_sample_dataset.load_json_from_file', return_value=[{'db_id': 'db1', 'question_id': 'q1', 'sql': 'SELECT * FROM table'}])
    def test_get_train_data_valid(self, mock_open_file, mock_exists, mock_load_json):
        result = get_train_data('/path/to/train_file.json')
        self.assertEqual(result, [{'db_id': 'db1', 'question_id': 'q1', 'sql': 'SELECT * FROM table'}])


class TestGetTrainDataInvalidFile(unittest.TestCase):
    @patch('os.path.exists', return_value=False)
    def test_get_train_data_invalid_file(self, mock_exists):
        train_file = "/path/to/non_existent_file.json"
        result = get_train_data(train_file)
        self.assertIsNone(result)


class TestAddSchemaUsed(unittest.TestCase):
    @patch('server.preprocess.prepare_sample_dataset.save_json_to_file')
    @patch('server.preprocess.prepare_sample_dataset.get_sql_columns_dict', return_value={'columns': ['col1', 'col2']})
    @patch('builtins.open', new_callable=mock_open, read_data='{"db_id": "db1", "question_id": "q1", "sql": "SELECT * FROM table"}')
    def test_add_schema_used(self, mock_open_file, mock_get_sql_columns_dict, mock_save_json):
        train_data = [{DB_ID_KEY: 'db1', QUESTION_ID_KEY: 'q1', SQL: 'SELECT * FROM table'}]
        train_file = Path('/path/to/train_file.json')

        add_schema_used(train_data, DatasetType.BIRD_TRAIN, train_file)
        self.assertEqual(train_data[0][SCHEMA_USED], {'columns': ['col1', 'col2']})
        mock_save_json.assert_called_once()

    @patch('server.preprocess.prepare_sample_dataset.save_json_to_file')
    @patch('server.preprocess.prepare_sample_dataset.get_sql_columns_dict', side_effect=KeyboardInterrupt)
    @patch('builtins.open', new_callable=mock_open, read_data='{"db_id": "db1", "question_id": "q1", "sql": "SELECT * FROM table"}')
    def test_add_schema_used_keyboard_interrupt(self, mock_open_file, mock_get_sql_columns_dict, mock_save_json):
        train_data = [{DB_ID_KEY: 'db1', QUESTION_ID_KEY: 'q1', SQL: 'SELECT * FROM table'}]
        train_file = Path('/path/to/train_file.json')

        with self.assertRaises(KeyboardInterrupt):
            add_schema_used(train_data, DatasetType.BIRD_TRAIN, train_file)

        self.assertNotIn(SCHEMA_USED, train_data[0])


if __name__ == '__main__':
    unittest.main()
```
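A review note on the stacked `@patch` decorators in the tests above: `unittest.mock` injects mocks bottom-up, so the decorator closest to the function supplies the first mock argument. In `test_get_train_data_valid`, the parameter names appear to follow decorator order instead of injection order (the innermost patch, `load_json_from_file`, actually lands in `mock_open_file`); the assertions still pass, but the names are misleading. A minimal self-contained illustration of the ordering rule:

```python
import os
from unittest.mock import patch

# Decorators apply bottom-up: the decorator CLOSEST to the function
# supplies the FIRST mock argument.
@patch("os.path.exists")   # outermost  -> second argument
@patch("os.getcwd")        # innermost  -> first argument
def demo(mock_getcwd, mock_exists):
    mock_getcwd.return_value = "/tmp/demo"
    mock_exists.return_value = True
    # Both calls hit the mocks, not the real os functions.
    return os.getcwd(), os.path.exists("/anything")

print(demo())  # -> ('/tmp/demo', True)
```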
Review comment (nitpick): Consider using consistent types for file paths across your functions; either update get_train_data to accept a Path object or convert the Path to a string before passing it.
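One way to act on this nitpick is to accept both types and normalize once at the function boundary. The sketch below is illustrative, not the PR's actual code: the inline `json.load` stands in for the project's `load_json_from_file` helper, and the existence check mirrors the `os.path.exists` behavior the tests mock.

```python
import json
from pathlib import Path
from typing import Optional, Union


def get_train_data(train_file: Union[str, Path]) -> Optional[list]:
    """Accept either a str or a Path and normalize once up front."""
    path = Path(train_file)  # cheap conversion; no-op copy if already a Path
    if not path.exists():
        return None
    with path.open() as f:
        return json.load(f)
```

With this signature, callers such as `add_schema_used(train_data, dataset_type, Path(train_file))` and plain-string call sites both work without `Path(...)` wrapping at every call.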