102 changes: 101 additions & 1 deletion cosmotech/coal/cosmotech_api/apis/dataset.py
@@ -70,7 +70,21 @@ def upload_dataset(
dataset_name: str,
as_files: Optional[list[Union[Path, str]]] = (),
as_db: Optional[list[Union[Path, str]]] = (),
tags: Optional[list[str]] = None,
additional_data: Optional[dict] = None,
) -> Dataset:
"""Upload a new dataset with optional tags and additional data.

Args:
dataset_name: The name of the dataset to create
as_files: List of file paths to upload as FILE type parts
as_db: List of file paths to upload as DB type parts
tags: Optional list of tags to associate with the dataset
additional_data: Optional dictionary of additional metadata

Returns:
The created Dataset object
"""
_parts = list()

for _f in as_files:
@@ -81,6 +95,8 @@ def upload_dataset(

d_request = DatasetCreateRequest(
name=dataset_name,
tags=tags,
additional_data=additional_data,
parts=list(
DatasetPartCreateRequest(
name=_p_name,
@@ -92,12 +108,96 @@
),
)

_files = []
for _p in _parts:
with _p[1].open("rb") as _p_file:
_files.append((_p[0], _p_file.read()))

d_ret = self.create_dataset(
self.configuration.cosmotech.organization_id,
self.configuration.cosmotech.workspace_id,
d_request,
-    files=list((_p[0], _p[1].open("rb").read()) for _p in _parts),
+    files=_files,
)

LOGGER.info(T("coal.services.dataset.dataset_created").format(dataset_id=d_ret.id))
return d_ret
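Reviewer note: a minimal usage sketch of the extended `upload_dataset` signature. The `api` client handle, dataset name, file paths, tags, and metadata below are all hypothetical illustrations, not taken from this PR:

```python
# Hypothetical usage of the extended signature; the client handle,
# file paths, tags, and metadata are assumptions for illustration.
dataset = api.upload_dataset(
    dataset_name="sales-2024",
    as_files=["./data/customers.csv"],
    as_db=["./data/orders.csv"],
    tags=["sales", "2024"],                    # new in this PR
    additional_data={"source": "erp-export"},  # new in this PR
)
print(dataset.id)
```

Note the diff also replaces the inline `_p[1].open("rb").read()` generator with an explicit `with` loop, so each file handle is closed after its bytes are read.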

def upload_dataset_parts(
self,
dataset_id: str,
as_files: Optional[list[Union[Path, str]]] = (),
as_db: Optional[list[Union[Path, str]]] = (),
replace_existing: bool = False,
) -> Dataset:
"""Upload parts to an existing dataset.

Args:
dataset_id: The ID of the existing dataset
as_files: List of file paths to upload as FILE type parts
as_db: List of file paths to upload as DB type parts
replace_existing: If True, replace existing parts with the same name

Returns:
The updated Dataset object
"""
# Get current dataset to check existing parts
current_dataset = self.get_dataset(
organization_id=self.configuration.cosmotech.organization_id,
workspace_id=self.configuration.cosmotech.workspace_id,
dataset_id=dataset_id,
)

# Map existing part names to their part IDs for quick lookup
existing_parts = {part.source_name: part.id for part in (current_dataset.parts or [])}

# Collect parts to upload
_parts = list()
for _f in as_files:
_parts.extend(self.path_to_parts(_f, DatasetPartTypeEnum.FILE))
for _db in as_db:
_parts.extend(self.path_to_parts(_db, DatasetPartTypeEnum.DB))

# Process each part
for _p_name, _p_path, _type in _parts:
if _p_name in existing_parts:
if replace_existing:
# Delete existing part before creating new one
self.delete_dataset_part(
organization_id=self.configuration.cosmotech.organization_id,
workspace_id=self.configuration.cosmotech.workspace_id,
dataset_id=dataset_id,
dataset_part_id=existing_parts[_p_name],
)
LOGGER.info(T("coal.services.dataset.part_replaced").format(part_name=_p_name))
else:
LOGGER.warning(T("coal.services.dataset.part_skipped").format(part_name=_p_name))
continue

# Create new part
part_request = DatasetPartCreateRequest(
name=_p_name,
description=_p_name,
sourceName=_p_name,
type=_type,
)

with _p_path.open("rb") as _p_file:
self.create_dataset_part(
organization_id=self.configuration.cosmotech.organization_id,
workspace_id=self.configuration.cosmotech.workspace_id,
dataset_id=dataset_id,
dataset_part_create_request=part_request,
file=(_p_name, _p_file.read()),
)
LOGGER.debug(T("coal.services.dataset.part_uploaded").format(part_name=_p_name))

# Return updated dataset
updated_dataset = self.get_dataset(
organization_id=self.configuration.cosmotech.organization_id,
workspace_id=self.configuration.cosmotech.workspace_id,
dataset_id=dataset_id,
)

LOGGER.info(T("coal.services.dataset.parts_uploaded").format(dataset_id=dataset_id))
return updated_dataset
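Reviewer note: a short sketch of exercising the new `upload_dataset_parts` method, assuming an authenticated `api` client instance; the dataset ID and file paths are placeholders:

```python
# Hypothetical call against an existing dataset; IDs and paths are placeholders.
updated = api.upload_dataset_parts(
    dataset_id="d-123456",
    as_files=["./data/customers.csv"],
    replace_existing=True,  # delete-then-recreate any part with the same source name
)
for part in updated.parts or []:
    print(part.id, part.source_name)
```

With `replace_existing=False` (the default), a name collision logs a warning and leaves the existing part untouched, per the `part_skipped` message added below.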
6 changes: 6 additions & 0 deletions cosmotech/translation/coal/en-US/coal/services/dataset.yml
@@ -56,3 +56,9 @@ text_processed: "Processed text file {file_name} with {lines} lines"
# Dataset API operations
part_downloaded: "Downloaded part {part_name} to {file_path}"
dataset_created: "Created dataset {dataset_id}"

# Dataset parts operations
part_uploaded: "Uploaded part {part_name}"
part_replaced: "Replaced existing part {part_name}"
part_skipped: "Skipped existing part {part_name} (use replace_existing=True to overwrite)"
parts_uploaded: "Successfully uploaded parts to dataset {dataset_id}"
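Reviewer note: a sketch of how these keys flow from YAML to log messages. The tiny lookup below only mimics the behavior of the project's `T()` helper; it is illustrative, not the actual implementation:

```python
# Illustrative stand-in for the T() translation lookup (assumption, not the real helper).
TEMPLATES = {
    "coal.services.dataset.part_skipped": (
        "Skipped existing part {part_name} (use replace_existing=True to overwrite)"
    ),
    "coal.services.dataset.parts_uploaded": "Successfully uploaded parts to dataset {dataset_id}",
}

def t(key: str) -> str:
    # Resolve a dotted key to its message template, as defined in the YAML above.
    return TEMPLATES[key]

print(t("coal.services.dataset.part_skipped").format(part_name="orders.csv"))
# Skipped existing part orders.csv (use replace_existing=True to overwrite)
```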
2 changes: 1 addition & 1 deletion tests/unit/coal/test_azure/test_azure_blob.py
@@ -13,7 +13,7 @@
from azure.identity import ClientSecretCredential
from azure.storage.blob import BlobServiceClient, ContainerClient

-from cosmotech.coal.azure.blob import dump_store_to_azure
+from cosmotech.coal.azure.blob import VALID_TYPES, dump_store_to_azure
from cosmotech.coal.store.store import Store
from cosmotech.coal.utils.configuration import Configuration
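Reviewer note: the widened import lets tests reference the module's own allow-list directly. A sketch of the kind of assertion this enables; the exact members of `VALID_TYPES` are an assumption, so nothing specific is asserted about them:

```python
from cosmotech.coal.azure.blob import VALID_TYPES

def test_valid_types_is_exposed():
    # VALID_TYPES is imported directly from the blob module (per this PR's diff);
    # its contents are an assumption, so only check it is a non-empty collection.
    assert len(list(VALID_TYPES)) > 0
```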
