From b496e71c4048e1ae73ce066ce0d2a817a4c19aa0 Mon Sep 17 00:00:00 2001 From: Madhup Sukoon <29144316+vagrantism@users.noreply.github.com> Date: Mon, 27 May 2024 15:33:02 +0530 Subject: [PATCH 1/5] Added branch automation and CODEOWNERS (#19) --- .github/CODEOWNERS | 9 +++++++++ .github/issue-branch.yml | 10 ++++++++++ .github/workflows/issue-branch.yml | 15 +++++++++++++++ 3 files changed, 34 insertions(+) create mode 100644 .github/CODEOWNERS create mode 100644 .github/issue-branch.yml create mode 100644 .github/workflows/issue-branch.yml diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..ab5bb69 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,9 @@ +# This is a comment. +# Each line is a file pattern followed by one or more owners. + +# These owners will be the default owners for everything in +# the repo. Unless a later match takes precedence, +# the @GoogleCloudPlatform/nl2sql-maintainers team will be requested for +# review when someone opens a pull request. +* @GoogleCloudPlatform/nl2sql-maintainers +/.github/ @GoogleCloudPlatform/nl2sql-maintainers \ No newline at end of file diff --git a/.github/issue-branch.yml b/.github/issue-branch.yml new file mode 100644 index 0000000..6f7dd0c --- /dev/null +++ b/.github/issue-branch.yml @@ -0,0 +1,10 @@ +autoLinkIssue: true +autoCloseIssue: true +defaultBranch: 'dev' +openDraftPR: true +copyIssueDescriptionToPR: true +copyIssueLabelsToPR: true +copyIssueAssigneeToPR: true +copyIssueProjectsToPR: true +copyIssueMilestoneToPR: true +conventionalPrTitles: true diff --git a/.github/workflows/issue-branch.yml b/.github/workflows/issue-branch.yml new file mode 100644 index 0000000..e966ceb --- /dev/null +++ b/.github/workflows/issue-branch.yml @@ -0,0 +1,15 @@ +on: + issues: + types: [ assigned ] + # The pull_request events below are only needed for pull-request related features. 
+ pull_request: + types: [ opened, closed ] + +jobs: + create_issue_branch_job: + runs-on: ubuntu-latest + steps: + - name: Create Issue Branch + uses: robvanderleek/create-issue-branch@main + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file From a7121ec5e740036ee69ac83afaa8e109f892e73b Mon Sep 17 00:00:00 2001 From: Dhruv Ahuja Date: Mon, 8 Jul 2024 14:38:16 +0530 Subject: [PATCH 2/5] added sql_generation_test.py with initial testing --- tests/tasks/sql_generation_test.py | 38 ++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 tests/tasks/sql_generation_test.py diff --git a/tests/tasks/sql_generation_test.py b/tests/tasks/sql_generation_test.py new file mode 100644 index 0000000..2919f3b --- /dev/null +++ b/tests/tasks/sql_generation_test.py @@ -0,0 +1,38 @@ +import unittest +from unittest.mock import MagicMock +from nl2sql.tasks.sql_generation.core import CoreSqlGenerator, CoreSqlGenratorResult + +class TestCoreSqlGenerator(unittest.TestCase): + def test_core_sql_generator_with_valid_response(self): + mock_llm = MagicMock() + mock_llm.generate.return_value = MagicMock( + generations=[ + [ + MagicMock(text="SELECT AVG(price) FROM products WHERE category = 'Electronics';") + ] + ] + ) + + mock_db = MagicMock() + mock_db.db.dialect = "sqlite" + mock_db.db.table_info = { + "products": {"product_id": "INT PRIMARY KEY", "name": "TEXT", "price": "REAL", "category": "TEXT"} + } + mock_db.db._usable_tables = ["products"] + mock_db.name = "test_db" + mock_db.descriptor = "A test database" + + # Initialize with the mock LLM + generator = CoreSqlGenerator(llm=mock_llm) + + # Run the generator + result = generator(mock_db, "What is the average price of products in the 'Electronics' category?") + + # Assertions + self.assertEqual(result.generated_query, "SELECT AVG(price) FROM products WHERE category = 'Electronics';") + self.assertEqual(result.db_name, "test_db") + self.assertEqual(result.question, "What is the 
average price of products in the 'Electronics' category?") + self.assertEqual(len(result.intermediate_steps), 1) + + # Verify LLM call + mock_llm.generate.assert_called_once() \ No newline at end of file From 19a1b1f63f3f85cabd122ffc72aa4ffbd8004f90 Mon Sep 17 00:00:00 2001 From: Dhruv Ahuja Date: Mon, 8 Jul 2024 14:46:15 +0530 Subject: [PATCH 3/5] Added optional instead of using | operator --- nl2sql/tasks/sql_generation/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nl2sql/tasks/sql_generation/core.py b/nl2sql/tasks/sql_generation/core.py index d43a0cc..db51a7a 100644 --- a/nl2sql/tasks/sql_generation/core.py +++ b/nl2sql/tasks/sql_generation/core.py @@ -26,7 +26,7 @@ from loguru import logger from pydantic import BaseModel, SkipValidation from typing_extensions import Literal - +from typing import Optional from nl2sql.assets.prompts import FewShot as FewShotPrompts from nl2sql.assets.prompts import ZeroShot as ZeroShotPrompts from nl2sql.datasets.base import Database @@ -115,9 +115,9 @@ def LANGCHAIN_ZERO_SHOT_PROMPT(self) -> _CoreSqlGeneratorPrompt: def custom_prompt( cls, prompt_template: BasePromptTemplate, - parser: StructuredOutputParser | None = None, + parser: Optional[StructuredOutputParser] = None, post_processor: Callable = lambda x: x, - prompt_template_id: str | None = None, + prompt_template_id: Optional[str] = None, ) -> _CoreSqlGeneratorPrompt: """ Use a custom PromptTemplate for SQL Generation. 
From 2c5a5c45bfea6eeec685eb8062e45a10b8b37125 Mon Sep 17 00:00:00 2001 From: Dhruv Ahuja Date: Wed, 10 Jul 2024 14:28:55 +0530 Subject: [PATCH 4/5] added table_selection_test --- tests/tasks/table_selection_test.py | 98 +++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 tests/tasks/table_selection_test.py diff --git a/tests/tasks/table_selection_test.py b/tests/tasks/table_selection_test.py new file mode 100644 index 0000000..39e8b58 --- /dev/null +++ b/tests/tasks/table_selection_test.py @@ -0,0 +1,98 @@ +import unittest +from unittest.mock import MagicMock +from loguru import logger +from nl2sql.datasets.base import Database +from nl2sql.tasks.table_selection.core import ( + CoreTableSelector, + _TableSelectorPrompts, +) + +class TestCoreTableSelector(unittest.TestCase): + def test_call_with_langchain_decider_prompt(self): + mock_llm = MagicMock() + mock_llm.generate.return_value = MagicMock( + generations=[ + [MagicMock(text="TableA, TableB")] + ] + ) + selector = CoreTableSelector( + llm=mock_llm, prompt=_TableSelectorPrompts().LANGCHAIN_DECIDER_PROMPT + ) + mock_db = MagicMock() + mock_db.name = "test_db" + mock_db.db._usable_tables = {"TableA", "TableC"} + mock_db.descriptor = { + "TableA": "Description A", + "TableB": "Description B", + "TableC": "Description C", + } + result = selector(mock_db, "test question") + self.assertEqual(result.selected_tables, {"TableA"}) + self.assertEqual(result.db_name, "test_db") + self.assertEqual(result.question, "test question") + self.assertEqual(result.available_tables, {"TableA", "TableC"}) + + def test_call_with_curated_few_shot_cot_prompt(self): + mock_llm = MagicMock() + mock_llm.generate.side_effect = [ + MagicMock( + generations=[ + [MagicMock(text="Yes. TableA is relevant")] + ] + ), + MagicMock( + generations=[ + [MagicMock(text="No. 
TableB is not relevant")] + ] + ), + ] + selector = CoreTableSelector( + llm=mock_llm, prompt=_TableSelectorPrompts().CURATED_FEW_SHOT_COT_PROMPT + ) + mock_db = MagicMock() + mock_db.name = "test_db" + mock_db.db._usable_tables = {"TableA", "TableB"} + mock_db.descriptor = { + "TableA": { + "col_descriptor": { + "column1": "data_type", + "column2": "data_type" + } + }, + "TableB": { + "col_descriptor": { + "column1": "data_type", + "column2": "data_type" + } + } + } + result = selector(mock_db, "test question") + self.assertEqual(result.selected_tables, {"TableA"}) + self.assertEqual(result.db_name, "test_db") + self.assertEqual(result.question, "test question") + self.assertEqual(result.available_tables, {"TableA", "TableB"}) + + def test_call_with_empty_response(self): + mock_llm = MagicMock() + mock_llm.generate.return_value = MagicMock( + generations=[ + [MagicMock(text=" ")] # Whitespace-only text stands in for an empty LLM response + ] + ) + selector = CoreTableSelector( + llm=mock_llm, prompt=_TableSelectorPrompts().LANGCHAIN_DECIDER_PROMPT + ) + mock_db = MagicMock() + mock_db.name = "test_db" + mock_db.db._usable_tables = {"TableA", "TableC"} + mock_db.descriptor = { + "TableA": "Description A", + "TableB": "Description B", + "TableC": "Description C", + } + # with self.assertLogs("nl2sql.tasks.table_selection.core", level="CRITICAL"): + result = selector(mock_db, "test question") + self.assertEqual(result.selected_tables, set()) + self.assertEqual(result.db_name, "test_db") + self.assertEqual(result.question, "test question") + self.assertEqual(result.available_tables, {"TableA", "TableC"}) \ No newline at end of file From 9e0cbbe872a644a5e4f33f6fbb56c1401f54e595 Mon Sep 17 00:00:00 2001 From: Dhruv Ahuja Date: Wed, 17 Jul 2024 11:19:00 +0530 Subject: [PATCH 5/5] added Apache licence --- tests/tasks/sql_generation_test.py | 14 ++++++++++++++ tests/tasks/table_selection_test.py | 14 ++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/tests/tasks/sql_generation_test.py 
b/tests/tasks/sql_generation_test.py index 2919f3b..4b3eafd 100644 --- a/tests/tasks/sql_generation_test.py +++ b/tests/tasks/sql_generation_test.py @@ -1,3 +1,17 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from unittest.mock import MagicMock from nl2sql.tasks.sql_generation.core import CoreSqlGenerator, CoreSqlGenratorResult diff --git a/tests/tasks/table_selection_test.py b/tests/tasks/table_selection_test.py index 39e8b58..7f8f328 100644 --- a/tests/tasks/table_selection_test.py +++ b/tests/tasks/table_selection_test.py @@ -1,3 +1,17 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from unittest.mock import MagicMock from loguru import logger