From a7b612665a6eb903f40fab5c6f65e253ffe803d5 Mon Sep 17 00:00:00 2001 From: etnikatcosmotech Date: Thu, 18 Dec 2025 09:49:38 +0100 Subject: [PATCH 1/2] feat(dataset-api): extend upload_dataset and add upload_dataset_parts --- cosmotech/coal/cosmotech_api/apis/dataset.py | 94 ++++++ .../coal/en-US/coal/services/dataset.yml | 6 + tests/unit/coal/test_azure/test_azure_blob.py | 2 +- .../test_apis/test_dataset.py | 273 ++++++++++++++++++ 4 files changed, 374 insertions(+), 1 deletion(-) diff --git a/cosmotech/coal/cosmotech_api/apis/dataset.py b/cosmotech/coal/cosmotech_api/apis/dataset.py index aa1bb927..42d75d8c 100644 --- a/cosmotech/coal/cosmotech_api/apis/dataset.py +++ b/cosmotech/coal/cosmotech_api/apis/dataset.py @@ -70,7 +70,21 @@ def upload_dataset( dataset_name: str, as_files: Optional[list[Union[Path, str]]] = (), as_db: Optional[list[Union[Path, str]]] = (), + tags: Optional[list[str]] = None, + additional_data: Optional[dict] = None, ) -> Dataset: + """Upload a new dataset with optional tags and additional data. + + Args: + dataset_name: The name of the dataset to create + as_files: List of file paths to upload as FILE type parts + as_db: List of file paths to upload as DB type parts + tags: Optional list of tags to associate with the dataset + additional_data: Optional dictionary of additional metadata + + Returns: + The created Dataset object + """ _parts = list() for _f in as_files: @@ -81,6 +95,8 @@ def upload_dataset( d_request = DatasetCreateRequest( name=dataset_name, + tags=tags, + additional_data=additional_data, parts=list( DatasetPartCreateRequest( name=_p_name, @@ -101,3 +117,81 @@ def upload_dataset( LOGGER.info(T("coal.services.dataset.dataset_created").format(dataset_id=d_ret.id)) return d_ret + + def upload_dataset_parts( + self, + dataset_id: str, + as_files: Optional[list[Union[Path, str]]] = (), + as_db: Optional[list[Union[Path, str]]] = (), + replace_existing: bool = False, + ) -> Dataset: + """Upload parts to an existing dataset. 
+
+        Args:
+            dataset_id: The ID of the existing dataset
+            as_files: List of file paths to upload as FILE type parts
+            as_db: List of file paths to upload as DB type parts
+            replace_existing: If True, replace existing parts with the same name
+
+        Returns:
+            The updated Dataset object
+        """
+        # Get the current dataset to check its existing parts
+        current_dataset = self.get_dataset(
+            organization_id=self.configuration.cosmotech.organization_id,
+            workspace_id=self.configuration.cosmotech.workspace_id,
+            dataset_id=dataset_id,
+        )
+
+        # Map existing part names to their IDs for quick lookup
+        existing_parts = {part.source_name: part.id for part in (current_dataset.parts or [])}
+
+        # Collect parts to upload
+        _parts = list()
+        for _f in as_files:
+            _parts.extend(self.path_to_parts(_f, DatasetPartTypeEnum.FILE))
+        for _db in as_db:
+            _parts.extend(self.path_to_parts(_db, DatasetPartTypeEnum.DB))
+
+        # Process each part
+        for _p_name, _p_path, _type in _parts:
+            if _p_name in existing_parts:
+                if replace_existing:
+                    # Delete the existing part before creating the new one
+                    self.delete_dataset_part(
+                        organization_id=self.configuration.cosmotech.organization_id,
+                        workspace_id=self.configuration.cosmotech.workspace_id,
+                        dataset_id=dataset_id,
+                        dataset_part_id=existing_parts[_p_name],
+                    )
+                    LOGGER.info(T("coal.services.dataset.part_replaced").format(part_name=_p_name))
+                else:
+                    LOGGER.warning(T("coal.services.dataset.part_skipped").format(part_name=_p_name))
+                    continue
+
+            # Create the new part
+            part_request = DatasetPartCreateRequest(
+                name=_p_name,
+                description=_p_name,
+                sourceName=_p_name,
+                type=_type,
+            )
+
+            self.create_dataset_part(
+                organization_id=self.configuration.cosmotech.organization_id,
+                workspace_id=self.configuration.cosmotech.workspace_id,
+                dataset_id=dataset_id,
+                dataset_part_create_request=part_request,
+                file=(_p_name, _p_path.open("rb").read()),
+            )
+            LOGGER.debug(T("coal.services.dataset.part_uploaded").format(part_name=_p_name))
+
+        # Return updated dataset
+        updated_dataset = self.get_dataset(
+            organization_id=self.configuration.cosmotech.organization_id,
+            workspace_id=self.configuration.cosmotech.workspace_id,
+            dataset_id=dataset_id,
+        )
+
+        LOGGER.info(T("coal.services.dataset.parts_uploaded").format(dataset_id=dataset_id))
+        return updated_dataset
diff --git a/cosmotech/translation/coal/en-US/coal/services/dataset.yml b/cosmotech/translation/coal/en-US/coal/services/dataset.yml
index 4ab9955d..801a22c9 100644
--- a/cosmotech/translation/coal/en-US/coal/services/dataset.yml
+++ b/cosmotech/translation/coal/en-US/coal/services/dataset.yml
@@ -56,3 +56,9 @@ text_processed: "Processed text file {file_name} with {lines} lines"
 # Dataset API operations
 part_downloaded: "Downloaded part {part_name} to {file_path}"
 dataset_created: "Created dataset {dataset_id}"
+
+# Dataset parts operations
+part_uploaded: "Uploaded part {part_name}"
+part_replaced: "Replaced existing part {part_name}"
+part_skipped: "Skipped existing part {part_name} (use replace_existing=True to overwrite)"
+parts_uploaded: "Successfully uploaded parts to dataset {dataset_id}"
diff --git a/tests/unit/coal/test_azure/test_azure_blob.py b/tests/unit/coal/test_azure/test_azure_blob.py
index 0c2a09e2..3e07d0f0 100644
--- a/tests/unit/coal/test_azure/test_azure_blob.py
+++ b/tests/unit/coal/test_azure/test_azure_blob.py
@@ -13,7 +13,7 @@
 from azure.identity import ClientSecretCredential
 from azure.storage.blob import BlobServiceClient, ContainerClient
 
-from cosmotech.coal.azure.blob import dump_store_to_azure
+from cosmotech.coal.azure.blob import VALID_TYPES, dump_store_to_azure from cosmotech.coal.store.store import Store from cosmotech.coal.utils.configuration import Configuration diff --git a/tests/unit/coal/test_cosmotech_api/test_apis/test_dataset.py b/tests/unit/coal/test_cosmotech_api/test_apis/test_dataset.py index a8a5a1a4..191adf0d 100644 --- a/tests/unit/coal/test_cosmotech_api/test_apis/test_dataset.py +++ b/tests/unit/coal/test_cosmotech_api/test_apis/test_dataset.py @@ -236,3 +236,276 @@ def test_upload_dataset_empty(self, mock_cosmotech_config, mock_api_client): call_args = api.create_dataset.call_args # Verify the request has an empty parts list assert len(call_args[0][2].parts) == 0 + + @patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) + @patch("cosmotech_api.ApiClient") + @patch("cosmotech_api.Configuration") + def test_upload_dataset_with_tags(self, mock_cosmotech_config, mock_api_client): + """Test uploading a dataset with tags.""" + mock_config = MagicMock() + mock_config.cosmotech.organization_id = "org-123" + mock_config.cosmotech.workspace_id = "ws-456" + + mock_client_instance = MagicMock() + mock_api_client.return_value = mock_client_instance + mock_configuration_instance = MagicMock() + mock_cosmotech_config.return_value = mock_configuration_instance + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "new-dataset-123" + + api = DatasetApi(configuration=mock_config) + api.create_dataset = MagicMock(return_value=mock_dataset) + + result = api.upload_dataset("Test Dataset", tags=["tag1", "tag2"]) + + assert result == mock_dataset + api.create_dataset.assert_called_once() + call_args = api.create_dataset.call_args + request = call_args[0][2] + assert request.tags == ["tag1", "tag2"] + + @patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) + @patch("cosmotech_api.ApiClient") + @patch("cosmotech_api.Configuration") + def test_upload_dataset_with_additional_data(self, mock_cosmotech_config, mock_api_client): + """Test uploading a dataset with additional_data.""" + mock_config = MagicMock() + mock_config.cosmotech.organization_id = "org-123" + mock_config.cosmotech.workspace_id = "ws-456" + + mock_client_instance = MagicMock() + mock_api_client.return_value = mock_client_instance + mock_configuration_instance = MagicMock() + mock_cosmotech_config.return_value = mock_configuration_instance + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "new-dataset-123" + + api = DatasetApi(configuration=mock_config) + api.create_dataset = MagicMock(return_value=mock_dataset) + + result = api.upload_dataset("Test Dataset", additional_data={"key": "value", "nested": {"a": 1}}) + + assert result == mock_dataset + api.create_dataset.assert_called_once() + call_args = api.create_dataset.call_args + request = call_args[0][2] + assert request.additional_data == {"key": "value", "nested": {"a": 1}} + + @patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) + @patch("cosmotech_api.ApiClient") + @patch("cosmotech_api.Configuration") + def test_upload_dataset_with_tags_and_additional_data(self, mock_cosmotech_config, mock_api_client): + """Test uploading a dataset with both tags and additional_data.""" + mock_config = MagicMock() + mock_config.cosmotech.organization_id = "org-123" + mock_config.cosmotech.workspace_id = "ws-456" + + mock_client_instance = MagicMock() + mock_api_client.return_value = 
mock_client_instance + mock_configuration_instance = MagicMock() + mock_cosmotech_config.return_value = mock_configuration_instance + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "new-dataset-123" + + api = DatasetApi(configuration=mock_config) + api.create_dataset = MagicMock(return_value=mock_dataset) + + with tempfile.TemporaryDirectory() as tmpdir: + file1 = Path(tmpdir) / "file1.csv" + file1.write_text("data1") + + result = api.upload_dataset( + "Test Dataset", + as_files=[str(file1)], + tags=["tag1", "tag2"], + additional_data={"key": "value"}, + ) + + assert result == mock_dataset + api.create_dataset.assert_called_once() + call_args = api.create_dataset.call_args + request = call_args[0][2] + assert request.tags == ["tag1", "tag2"] + assert request.additional_data == {"key": "value"} + assert len(request.parts) == 1 + + @patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) + @patch("cosmotech_api.ApiClient") + @patch("cosmotech_api.Configuration") + def test_upload_dataset_parts_new_parts(self, mock_cosmotech_config, mock_api_client): + """Test uploading new parts to an existing dataset.""" + mock_config = MagicMock() + mock_config.cosmotech.organization_id = "org-123" + mock_config.cosmotech.workspace_id = "ws-456" + + mock_client_instance = MagicMock() + mock_api_client.return_value = mock_client_instance + mock_configuration_instance = MagicMock() + mock_cosmotech_config.return_value = mock_configuration_instance + + # Mock existing dataset with no parts + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "existing-dataset-123" + mock_dataset.parts = [] + + api = DatasetApi(configuration=mock_config) + api.get_dataset = MagicMock(return_value=mock_dataset) + api.create_dataset_part = MagicMock() + + with tempfile.TemporaryDirectory() as tmpdir: + file1 = Path(tmpdir) / "file1.csv" + file1.write_text("data1") + + result = api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1)]) + + assert api.create_dataset_part.called + assert api.get_dataset.call_count == 2 # Called at start and end + + @patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) + @patch("cosmotech_api.ApiClient") + @patch("cosmotech_api.Configuration") + def test_upload_dataset_parts_skip_existing(self, mock_cosmotech_config, mock_api_client): + """Test skipping existing parts when replace_existing=False.""" + mock_config = MagicMock() + mock_config.cosmotech.organization_id = "org-123" + mock_config.cosmotech.workspace_id = "ws-456" + + mock_client_instance = MagicMock() + mock_api_client.return_value = mock_client_instance + mock_configuration_instance = MagicMock() + mock_cosmotech_config.return_value = mock_configuration_instance + + # Mock existing dataset with one existing part + mock_existing_part = MagicMock() + mock_existing_part.source_name = "file1.csv" + mock_existing_part.id = "part-1" + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "existing-dataset-123" + mock_dataset.parts = [mock_existing_part] + + api = DatasetApi(configuration=mock_config) + api.get_dataset = MagicMock(return_value=mock_dataset) + api.create_dataset_part = MagicMock() + + with tempfile.TemporaryDirectory() as tmpdir: + file1 = Path(tmpdir) / "file1.csv" + file1.write_text("data1") + + result = api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1)]) + + # Part should be skipped, not created + api.create_dataset_part.assert_not_called() + + 
@patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) + @patch("cosmotech_api.ApiClient") + @patch("cosmotech_api.Configuration") + def test_upload_dataset_parts_replace_existing(self, mock_cosmotech_config, mock_api_client): + """Test replacing existing parts when replace_existing=True.""" + mock_config = MagicMock() + mock_config.cosmotech.organization_id = "org-123" + mock_config.cosmotech.workspace_id = "ws-456" + + mock_client_instance = MagicMock() + mock_api_client.return_value = mock_client_instance + mock_configuration_instance = MagicMock() + mock_cosmotech_config.return_value = mock_configuration_instance + + # Mock existing dataset with one existing part + mock_existing_part = MagicMock() + mock_existing_part.source_name = "file1.csv" + mock_existing_part.id = "part-1" + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "existing-dataset-123" + mock_dataset.parts = [mock_existing_part] + + api = DatasetApi(configuration=mock_config) + api.get_dataset = MagicMock(return_value=mock_dataset) + api.create_dataset_part = MagicMock() + api.delete_dataset_part = MagicMock() + + with tempfile.TemporaryDirectory() as tmpdir: + file1 = Path(tmpdir) / "file1.csv" + file1.write_text("data1") + + result = api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1)], replace_existing=True) + + # Part should be deleted and then created + api.delete_dataset_part.assert_called_once() + api.create_dataset_part.assert_called_once() + + @patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) + @patch("cosmotech_api.ApiClient") + @patch("cosmotech_api.Configuration") + def test_upload_dataset_parts_mixed(self, mock_cosmotech_config, mock_api_client): + """Test uploading parts with some existing and some new.""" + mock_config = MagicMock() + mock_config.cosmotech.organization_id = "org-123" + mock_config.cosmotech.workspace_id = "ws-456" + + mock_client_instance = MagicMock() + mock_api_client.return_value = mock_client_instance + mock_configuration_instance = MagicMock() + mock_cosmotech_config.return_value = mock_configuration_instance + + # Mock existing dataset with one existing part + mock_existing_part = MagicMock() + mock_existing_part.source_name = "file1.csv" + mock_existing_part.id = "part-1" + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "existing-dataset-123" + mock_dataset.parts = [mock_existing_part] + + api = DatasetApi(configuration=mock_config) + api.get_dataset = MagicMock(return_value=mock_dataset) + api.create_dataset_part = MagicMock() + + with tempfile.TemporaryDirectory() as tmpdir: + # Existing file (should be skipped) + file1 = Path(tmpdir) / "file1.csv" + file1.write_text("data1") + # New file (should be created) + file2 = Path(tmpdir) / "file2.csv" + file2.write_text("data2") + + result = api.upload_dataset_parts("existing-dataset-123", as_files=[str(file1), str(file2)]) + + # Only the new file should be created + assert api.create_dataset_part.call_count == 1 + + @patch.dict(os.environ, {"CSM_API_KEY": "test-api-key", "CSM_API_URL": "https://api.example.com"}, clear=True) + @patch("cosmotech_api.ApiClient") + @patch("cosmotech_api.Configuration") + def test_upload_dataset_parts_with_db_type(self, mock_cosmotech_config, mock_api_client): + """Test uploading parts with DB type.""" + mock_config = MagicMock() + mock_config.cosmotech.organization_id = "org-123" + mock_config.cosmotech.workspace_id = "ws-456" + + mock_client_instance 
= MagicMock()
+        mock_api_client.return_value = mock_client_instance
+        mock_configuration_instance = MagicMock()
+        mock_cosmotech_config.return_value = mock_configuration_instance
+
+        # Mock existing dataset with no parts
+        mock_dataset = MagicMock(spec=Dataset)
+        mock_dataset.id = "existing-dataset-123"
+        mock_dataset.parts = []
+
+        api = DatasetApi(configuration=mock_config)
+        api.get_dataset = MagicMock(return_value=mock_dataset)
+        api.create_dataset_part = MagicMock()
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            db_file = Path(tmpdir) / "data.db"
+            db_file.write_text("database content")
+
+            result = api.upload_dataset_parts("existing-dataset-123", as_db=[str(db_file)])
+
+        assert api.create_dataset_part.called
+        # Verify the part request has DB type
+        call_args = api.create_dataset_part.call_args
+        part_request = call_args.kwargs.get("dataset_part_create_request")
+        assert part_request.type == DatasetPartTypeEnum.DB

From 469116383ba583346e0d1a93ef1e56e149d24000 Mon Sep 17 00:00:00 2001
From: Alexis Fossart
Date: Mon, 19 Jan 2026 09:39:28 +0100
Subject: [PATCH 2/2] fix: use with blocks when opening files to avoid leaking
 file handles

---
 cosmotech/coal/cosmotech_api/apis/dataset.py | 24 ++++++++++++--------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/cosmotech/coal/cosmotech_api/apis/dataset.py b/cosmotech/coal/cosmotech_api/apis/dataset.py
index 42d75d8c..e0cade46 100644
--- a/cosmotech/coal/cosmotech_api/apis/dataset.py
+++ b/cosmotech/coal/cosmotech_api/apis/dataset.py
@@ -108,11 +108,16 @@ def upload_dataset(
             ),
         )
 
+        _files = []
+        for _p in _parts:
+            with _p[1].open("rb") as _p_file:
+                _files.append((_p[0], _p_file.read()))
+
         d_ret = self.create_dataset(
             self.configuration.cosmotech.organization_id,
             self.configuration.cosmotech.workspace_id,
             d_request,
-            files=list((_p[0], _p[1].open("rb").read()) for _p in _parts),
+            files=_files,
         )
 
         LOGGER.info(T("coal.services.dataset.dataset_created").format(dataset_id=d_ret.id))
@@ -177,14 +182,15 @@ def upload_dataset_parts(
                 type=_type,
             )
 
-            self.create_dataset_part(
-                organization_id=self.configuration.cosmotech.organization_id,
-                workspace_id=self.configuration.cosmotech.workspace_id,
-                dataset_id=dataset_id,
-                dataset_part_create_request=part_request,
-                file=(_p_name, _p_path.open("rb").read()),
-            )
-            LOGGER.debug(T("coal.services.dataset.part_uploaded").format(part_name=_p_name))
+            with _p_path.open("rb") as _p_file:
+                self.create_dataset_part(
+                    organization_id=self.configuration.cosmotech.organization_id,
+                    workspace_id=self.configuration.cosmotech.workspace_id,
+                    dataset_id=dataset_id,
+                    dataset_part_create_request=part_request,
+                    file=(_p_name, _p_file.read()),
+                )
+                LOGGER.debug(T("coal.services.dataset.part_uploaded").format(part_name=_p_name))
 
         # Return updated dataset
         updated_dataset = self.get_dataset(
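
A minimal usage sketch of the entry points this series adds, for reviewers. It
assumes a Configuration wired with cosmotech.organization_id and
cosmotech.workspace_id, as the unit tests mock it; the module path for
DatasetApi is inferred from the file layout, and the paths, tags, and metadata
below are illustrative only.

    from pathlib import Path

    from cosmotech.coal.cosmotech_api.apis.dataset import DatasetApi
    from cosmotech.coal.utils.configuration import Configuration

    api = DatasetApi(configuration=Configuration())

    # upload_dataset now accepts optional tags and additional_data.
    dataset = api.upload_dataset(
        "My Dataset",
        as_files=[Path("data/input.csv")],       # illustrative path
        tags=["demo"],                           # illustrative tag
        additional_data={"source": "etl-run"},   # illustrative metadata
    )

    # Push further parts to the same dataset later; a part whose source name
    # already exists is skipped with a warning unless replace_existing=True,
    # in which case the old part is deleted before the new one is created.
    api.upload_dataset_parts(
        dataset.id,
        as_files=[Path("data/input.csv"), Path("data/extra.csv")],
        replace_existing=True,
    )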