diff --git a/pyproject.toml b/pyproject.toml index 3d425a7..1a083ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,6 @@ dependencies = [ [project.entry-points.pytest11] databricks-labs-pytester = "databricks.labs.pytester.fixtures.plugin" - [project.urls] Issues = "https://github.com/databrickslabs/pytester/issues" Source = "https://github.com/databrickslabs/pytester" diff --git a/src/databricks/labs/pytester/fixtures/compute.py b/src/databricks/labs/pytester/fixtures/compute.py index a4ea066..bd888c5 100644 --- a/src/databricks/labs/pytester/fixtures/compute.py +++ b/src/databricks/labs/pytester/fixtures/compute.py @@ -339,6 +339,7 @@ def create( cluster_size: str | None = None, max_num_clusters: int = 1, enable_serverless_compute: bool = False, + tags: list[EndpointTagPair] = [], **kwargs, ) -> Wait[GetWarehouseResponse]: if warehouse_name is None: @@ -348,7 +349,9 @@ def create( if cluster_size is None: cluster_size = "2X-Small" - remove_after_tags = EndpointTags(custom_tags=[EndpointTagPair(key="RemoveAfter", value=watchdog_remove_after)]) + remove_after_tags = EndpointTags( + custom_tags=[EndpointTagPair(key="RemoveAfter", value=watchdog_remove_after)] + tags + ) return ws.warehouses.create( name=warehouse_name, cluster_size=cluster_size, diff --git a/tests/integration/fixtures/test_compute.py b/tests/integration/fixtures/test_compute.py index d7ca52a..5d9cb05 100644 --- a/tests/integration/fixtures/test_compute.py +++ b/tests/integration/fixtures/test_compute.py @@ -4,6 +4,7 @@ from databricks.sdk import WorkspaceClient from databricks.sdk.service.iam import PermissionLevel from databricks.sdk.service.jobs import RunResultState, SparkPythonTask +from databricks.sdk.service.sql import EndpointTagPair from databricks.labs.pytester.fixtures.watchdog import TEST_RESOURCE_PURGE_TIMEOUT @@ -55,6 +56,13 @@ def test_warehouse_has_remove_after_tag(ws, make_warehouse): assert warehouse_tags["custom_tags"][0]["key"] == "RemoveAfter" +def test_warehouse_has_custom_tag(ws, make_warehouse): + new_warehouse = make_warehouse(tags=[EndpointTagPair(key="my-custom-tag", value="my-custom-tag-value")]) + created_warehouse = ws.warehouses.get(new_warehouse.response.id) + warehouse_tags = created_warehouse.tags.as_dict() + assert warehouse_tags["custom_tags"][1]["key"] == "my-custom-tag" + + def test_remove_after_tag_jobs(ws, env_or_skip, make_job): new_job = make_job() created_job = ws.jobs.get(new_job.job_id)