From f0640fdfd38481c19cadb74a8362e7be946e1bb7 Mon Sep 17 00:00:00 2001
From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com>
Date: Fri, 9 Jan 2026 14:19:18 +0100
Subject: [PATCH 1/8] Implement DLT creator

---
 CLAUDE.md                                     |  82 ++++
 .../databricks_dlt_creator.py                 | 118 ++++++
 .../dag_creator/airflow/operator_factory.py   |   1 +
 dagger/pipeline/task_factory.py               |   3 +-
 dagger/pipeline/tasks/databricks_dlt_task.py  | 164 ++++++++
 dagger/plugins/__init__.py                    |   1 +
 dagger/plugins/dlt_task_generator/__init__.py |   6 +
 .../dlt_task_generator/bundle_parser.py       | 309 +++++++++++++++
 .../dlt_task_generator/dlt_task_generator.py  | 374 ++++++++++++++++++
 dagger/utilities/dbt_config_parser.py         | 180 ++++++++-
 dagger/utilities/module.py                    |  23 +-
 11 files changed, 1234 insertions(+), 27 deletions(-)
 create mode 100644 CLAUDE.md
 create mode 100644 dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py
 create mode 100644 dagger/pipeline/tasks/databricks_dlt_task.py
 create mode 100644 dagger/plugins/__init__.py
 create mode 100644 dagger/plugins/dlt_task_generator/__init__.py
 create mode 100644 dagger/plugins/dlt_task_generator/bundle_parser.py
 create mode 100644 dagger/plugins/dlt_task_generator/dlt_task_generator.py

diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000..ffd1ebd
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,82 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Project Overview
+
+Dagger is a configuration-driven framework that transforms YAML definitions into Apache Airflow DAGs. It uses dataset lineage (matching inputs/outputs) to automatically build dependency graphs across workflows.
+
+## Common Commands
+
+### Development Setup
+```bash
+make install-dev  # Create venv, install package in editable mode with dev/test deps
+source venv/bin/activate
+```
+
+### Testing
+```bash
+make test  # Run all tests with coverage (sets AIRFLOW_HOME automatically)
+
+# Run a single test file
+AIRFLOW_HOME=$(pwd)/tests/fixtures/config_finder/root/ ENV=local pytest -s tests/path/to/test_file.py
+
+# Run a specific test
+AIRFLOW_HOME=$(pwd)/tests/fixtures/config_finder/root/ ENV=local pytest -s tests/path/to/test_file.py::test_function_name
+```
+
+### Linting
+```bash
+make lint  # Run flake8 on dagger and tests directories
+black dagger tests  # Format code
+```
+
+### Local Airflow Testing
+```bash
+make test-airflow  # Build and start Airflow in Docker (localhost:8080, user: dev_user, pass: dev_user)
+make stop-airflow  # Stop Airflow containers
+```
+
+### CLI
+```bash
+dagger --help
+dagger list-tasks  # Show available task types
+dagger list-ios  # Show available IO types
+dagger init-pipeline  # Create a new pipeline.yaml
+dagger init-task --type=<task_type>  # Add a task configuration
+dagger init-io --type=<io_type>  # Add an IO definition
+dagger print-graph  # Visualize dependency graph
+```
+
+## Architecture
+
+### Core Flow
+1. **ConfigFinder** discovers pipeline directories (each with `pipeline.yaml` + task YAML files)
+2. **ConfigProcessor** loads YAML configs with environment variable support
+3. **TaskFactory/IOFactory** use reflection to instantiate task/IO objects from YAML
+4. **TaskGraph** builds a 3-layer graph: Pipeline → Task → Dataset nodes
+5. **DagCreator** traverses the graph and generates Airflow DAGs using **OperatorFactory**
+
+### Key Directories
+- `dagger/pipeline/tasks/` - Task type definitions (DbtTask, SparkTask, AthenaTransformTask, etc.)
+- `dagger/pipeline/ios/` - IO type definitions (S3, Redshift, Athena, Databricks, etc.)
+- `dagger/dag_creator/airflow/operator_creators/` - One creator per task type, translates tasks to Airflow operators
+- `dagger/graph/` - Graph construction from task inputs/outputs
+- `dagger/config_finder/` - YAML discovery and loading
+- `tests/fixtures/config_finder/root/dags/` - Example DAG configurations for testing
+
+### Adding a New Task Type
+1. Create task definition in `dagger/pipeline/tasks/` (subclass of Task)
+2. Create any needed IOs in `dagger/pipeline/ios/` (if new data sources)
+3. Create operator creator in `dagger/dag_creator/airflow/operator_creators/`
+4. Register in `dagger/dag_creator/airflow/operator_factory.py`
+
+### Configuration Files
+- `pipeline.yaml` - Pipeline metadata (owner, schedule, alerts, airflow_parameters)
+- `[taskname].yaml` - Task configs (type, inputs, outputs, task-specific params)
+- `dagger_config.yaml` - System config (Neo4j, Elasticsearch, Spark settings)
+
+### Key Patterns
+- **Factory Pattern**: TaskFactory/IOFactory auto-discover types via reflection
+- **Strategy Pattern**: OperatorCreator subclasses handle task-specific operator creation
+- **Dataset Aliasing**: IO `alias()` method enables automatic dependency detection across pipelines
diff --git a/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py b/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py
new file mode 100644
index 0000000..66034a6
--- /dev/null
+++ b/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py
@@ -0,0 +1,118 @@
+"""Operator creator for Databricks DLT (Delta Live Tables) pipelines."""
+
+import logging
+from typing import Any
+
+from airflow.models import BaseOperator, DAG
+
+from dagger.dag_creator.airflow.operator_creator import OperatorCreator
+from dagger.pipeline.tasks.databricks_dlt_task import DatabricksDLTTask
+
+_logger = logging.getLogger(__name__)
+
+
+def _cancel_databricks_run(context: dict[str, Any]) -> None:
+    """Cancel a Databricks job run when task fails or is cleared.
+
+    This callback retrieves the run_id from XCom and cancels the corresponding
+    Databricks job run. Used as on_failure_callback to ensure jobs are cancelled
+    when tasks are marked as failed.
+
+    Args:
+        context: Airflow context dictionary containing task instance and other metadata.
+    """
+    from airflow.providers.databricks.hooks.databricks import DatabricksHook
+
+    ti = context.get("task_instance")
+    if not ti:
+        _logger.warning("No task instance in context, cannot cancel Databricks run")
+        return
+
+    # Get run_id from XCom (pushed by DatabricksRunNowOperator)
+    run_id = ti.xcom_pull(task_ids=ti.task_id, key="run_id")
+    if not run_id:
+        _logger.warning(f"No run_id found in XCom for task {ti.task_id}")
+        return
+
+    # Get the databricks_conn_id from the operator
+    databricks_conn_id = getattr(ti.task, "databricks_conn_id", "databricks_default")
+
+    try:
+        hook = DatabricksHook(databricks_conn_id=databricks_conn_id)
+        hook.cancel_run(run_id)
+        _logger.info(f"Cancelled Databricks run {run_id} for task {ti.task_id}")
+    except Exception as e:
+        _logger.error(f"Failed to cancel Databricks run {run_id}: {e}")
+
+
+class DatabricksDLTCreator(OperatorCreator):
+    """Creates operators for triggering Databricks DLT pipelines via Jobs.
+
+    This creator uses DatabricksRunNowOperator to trigger a Databricks Job
+    that wraps the DLT pipeline. The job is identified by name and must be
+    defined in the Databricks Asset Bundle.
+
+    Attributes:
+        ref_name: Reference name used by OperatorFactory to match this creator
+            with DatabricksDLTTask instances.
+    """
+
+    ref_name: str = "databricks_dlt"
+
+    def __init__(self, task: DatabricksDLTTask, dag: DAG) -> None:
+        """Initialize the DatabricksDLTCreator.
+
+        Args:
+            task: The DatabricksDLTTask containing pipeline configuration.
+            dag: The Airflow DAG this operator will belong to.
+        """
+        super().__init__(task, dag)
+
+    def _create_operator(self, **kwargs: Any) -> BaseOperator:
+        """Create a DatabricksRunNowOperator for the DLT pipeline.
+
+        Creates an Airflow operator that triggers an existing Databricks Job
+        by name. The job must have a pipeline_task that references the DLT
+        pipeline.
+
+        Args:
+            **kwargs: Additional keyword arguments passed to the operator.
+
+        Returns:
+            A configured DatabricksRunNowOperator instance.
+        """
+        # Import here to avoid import errors if databricks provider not installed
+        from datetime import timedelta
+
+        from airflow.providers.databricks.operators.databricks import (
+            DatabricksRunNowOperator,
+        )
+
+        # Get task parameters
+        job_name: str = self._task.job_name
+        databricks_conn_id: str = getattr(
+            self._task, "databricks_conn_id", "databricks_default"
+        )
+        wait_for_completion: bool = getattr(self._task, "wait_for_completion", True)
+        poll_interval_seconds: int = getattr(self._task, "poll_interval_seconds", 30)
+        timeout_seconds: int = getattr(self._task, "timeout_seconds", 3600)
+
+        # DatabricksRunNowOperator triggers an existing Databricks Job by name
+        # The job must have a pipeline_task that references the DLT pipeline
+        # Note: timeout is handled via Airflow's execution_timeout, not a direct parameter
+        # Note: on_kill() is already implemented in DatabricksRunNowOperator to cancel runs
+        # We add on_failure_callback to also cancel when task is marked as failed
+        operator: BaseOperator = DatabricksRunNowOperator(
+            dag=self._dag,
+            task_id=self._task.name,
+            databricks_conn_id=databricks_conn_id,
+            job_name=job_name,
+            wait_for_termination=wait_for_completion,
+            polling_period_seconds=poll_interval_seconds,
+            execution_timeout=timedelta(seconds=timeout_seconds),
+            do_xcom_push=True,  # Required to store run_id for cancellation callback
+            on_failure_callback=_cancel_databricks_run,
+            **kwargs,
+        )
+
+        return operator
diff --git a/dagger/dag_creator/airflow/operator_factory.py b/dagger/dag_creator/airflow/operator_factory.py
index 2a1654a..dd7344e 100644
--- a/dagger/dag_creator/airflow/operator_factory.py
+++ b/dagger/dag_creator/airflow/operator_factory.py
@@ -4,6 +4,7 @@
     airflow_op_creator,
     athena_transform_creator,
     batch_creator,
+    databricks_dlt_creator,
     dbt_creator,
     dummy_creator,
     python_creator,
diff --git a/dagger/pipeline/task_factory.py b/dagger/pipeline/task_factory.py
index 9ed79e7..f5f80bb 100644
--- a/dagger/pipeline/task_factory.py
+++ b/dagger/pipeline/task_factory.py
@@ -3,6 +3,7 @@
     airflow_op_task,
     athena_transform_task,
     batch_task,
+    databricks_dlt_task,
     dbt_task,
     dummy_task,
     python_task,
@@ -12,7 +13,7 @@
     reverse_etl_task,
     spark_task,
     sqoop_task,
-    soda_task
+    soda_task,
 )
 
 from dagger.utilities.classes import get_deep_obj_subclasses
diff --git a/dagger/pipeline/tasks/databricks_dlt_task.py b/dagger/pipeline/tasks/databricks_dlt_task.py
new file mode 100644
index 0000000..4f0b113
--- /dev/null
+++ b/dagger/pipeline/tasks/databricks_dlt_task.py
@@ -0,0 +1,164 @@
+"""Task configuration for Databricks DLT (Delta Live Tables) pipelines."""
+
+from typing import Any, Optional
+
+from dagger.pipeline.task import Task
+from dagger.utilities.config_validator import Attribute
+
+
+class DatabricksDLTTask(Task):
+    """Task configuration for triggering Databricks DLT pipelines via Jobs.
+
+    This task type uses DatabricksRunNowOperator to trigger a Databricks Job
+    that wraps the DLT pipeline. The job is identified by name and must be
+    defined in the Databricks Asset Bundle.
+
+    Attributes:
+        ref_name: Reference name used by TaskFactory to instantiate this task type.
+        job_name: Databricks Job name that triggers the DLT pipeline.
+        databricks_conn_id: Airflow connection ID for Databricks.
+        wait_for_completion: Whether to wait for job completion.
+        poll_interval_seconds: Polling interval in seconds.
+        timeout_seconds: Timeout in seconds.
+        cancel_on_kill: Whether to cancel Databricks job if Airflow task is killed.
+
+    Example YAML configuration:
+        type: databricks_dlt
+        description: Run DLT pipeline users
+        inputs:
+          - type: athena
+            schema: ddb_changelogs
+            table: order_preference
+            follow_external_dependency: true
+        outputs:
+          - type: databricks
+            catalog: ${ENV_MARTS}
+            schema: dlt_users
+            table: silver_order_preference
+        task_parameters:
+            job_name: dlt-users
+            databricks_conn_id: databricks_default
+            wait_for_completion: true
+            poll_interval_seconds: 30
+            timeout_seconds: 3600
+    """
+
+    ref_name: str = "databricks_dlt"
+
+    @classmethod
+    def init_attributes(cls, orig_cls: type) -> None:
+        """Initialize configuration attributes for YAML parsing.
+
+        Registers all task_parameters attributes that can be specified in the
+        YAML configuration file. Called by the Task metaclass during class creation.
+
+        Args:
+            orig_cls: The original class being initialized (used for attribute registration).
+        """
+        cls.add_config_attributes(
+            [
+                Attribute(
+                    attribute_name="job_name",
+                    parent_fields=["task_parameters"],
+                    comment="Databricks Job name that triggers the DLT pipeline",
+                ),
+                Attribute(
+                    attribute_name="databricks_conn_id",
+                    parent_fields=["task_parameters"],
+                    required=False,
+                    comment="Airflow connection ID for Databricks (default: databricks_default)",
+                ),
+                Attribute(
+                    attribute_name="wait_for_completion",
+                    parent_fields=["task_parameters"],
+                    required=False,
+                    validator=bool,
+                    comment="Wait for job to complete (default: true)",
+                ),
+                Attribute(
+                    attribute_name="poll_interval_seconds",
+                    parent_fields=["task_parameters"],
+                    required=False,
+                    validator=int,
+                    comment="Polling interval in seconds (default: 30)",
+                ),
+                Attribute(
+                    attribute_name="timeout_seconds",
+                    parent_fields=["task_parameters"],
+                    required=False,
+                    validator=int,
+                    comment="Timeout in seconds (default: 3600)",
+                ),
+                Attribute(
+                    attribute_name="cancel_on_kill",
+                    parent_fields=["task_parameters"],
+                    required=False,
+                    validator=bool,
+                    comment="Cancel Databricks job if Airflow task is killed (default: true)",
+                ),
+            ]
+        )
+
+    def __init__(
+        self,
+        name: str,
+        pipeline_name: str,
+        pipeline: Any,
+        job_config: dict[str, Any],
+    ) -> None:
+        """Initialize a DatabricksDLTTask instance.
+
+        Args:
+            name: The task name (used as task_id in Airflow).
+            pipeline_name: Name of the Dagger pipeline this task belongs to.
+            pipeline: The parent Pipeline object.
+            job_config: Dictionary containing the task configuration from YAML.
+        """
+        super().__init__(name, pipeline_name, pipeline, job_config)
+
+        self._job_name: str = self.parse_attribute("job_name")
+        self._databricks_conn_id: str = (
+            self.parse_attribute("databricks_conn_id") or "databricks_default"
+        )
+        wait_for_completion: Optional[bool] = self.parse_attribute("wait_for_completion")
+        self._wait_for_completion: bool = (
+            wait_for_completion if wait_for_completion is not None else True
+        )
+        self._poll_interval_seconds: int = (
+            self.parse_attribute("poll_interval_seconds") or 30
+        )
+        self._timeout_seconds: int = self.parse_attribute("timeout_seconds") or 3600
+        cancel_on_kill: Optional[bool] = self.parse_attribute("cancel_on_kill")
+        self._cancel_on_kill: bool = (
+            cancel_on_kill if cancel_on_kill is not None else True
+        )
+
+    @property
+    def job_name(self) -> str:
+        """Databricks Job name that triggers the DLT pipeline."""
+        return self._job_name
+
+    @property
+    def databricks_conn_id(self) -> str:
+        """Airflow connection ID for Databricks."""
+        return self._databricks_conn_id
+
+    @property
+    def wait_for_completion(self) -> bool:
+        """Whether to wait for job completion."""
+        return self._wait_for_completion
+
+    @property
+    def poll_interval_seconds(self) -> int:
+        """Polling interval in seconds."""
+        return self._poll_interval_seconds
+
+    @property
+    def timeout_seconds(self) -> int:
+        """Timeout in seconds."""
+        return self._timeout_seconds
+
+    @property
+    def cancel_on_kill(self) -> bool:
+        """Whether to cancel Databricks job if Airflow task is killed."""
+        return self._cancel_on_kill
diff --git a/dagger/plugins/__init__.py b/dagger/plugins/__init__.py
new file mode 100644
index 0000000..26acb8c
--- /dev/null
+++ b/dagger/plugins/__init__.py
@@ -0,0 +1 @@
+"""Dagger plugins for task generation."""
diff --git a/dagger/plugins/dlt_task_generator/__init__.py b/dagger/plugins/dlt_task_generator/__init__.py
new file mode 100644
index 0000000..49e4a17
--- /dev/null
+++ b/dagger/plugins/dlt_task_generator/__init__.py
@@ -0,0 +1,6 @@
+"""DLT Task Generator plugin for generating Dagger configs from Databricks Asset Bundles."""
+
+from dagger.plugins.dlt_task_generator.bundle_parser import DatabricksBundleParser
+from dagger.plugins.dlt_task_generator.dlt_task_generator import DLTTaskGenerator
+
+__all__ = ["DatabricksBundleParser", "DLTTaskGenerator"]
diff --git a/dagger/plugins/dlt_task_generator/bundle_parser.py b/dagger/plugins/dlt_task_generator/bundle_parser.py
new file mode 100644
index 0000000..5a15ed7
--- /dev/null
+++ b/dagger/plugins/dlt_task_generator/bundle_parser.py
@@ -0,0 +1,309 @@
+"""Parse Databricks Asset Bundle YAML files for DLT pipeline configuration."""
+
+import logging
+import re
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Optional
+
+import yaml
+
+_logger = logging.getLogger(__name__)
+
+
+@dataclass
+class TableConfig:
+    """Configuration for a single table in a DLT pipeline.
+
+    Attributes:
+        database: Source database name.
+        table: Source table name.
+        changelog_type: Type of changelog source ('dynamodb' or 'postgres').
+        unique_keys: List of columns that uniquely identify a row.
+        scd_type: Slowly changing dimension type (1 or 2).
+    """
+
+    database: str
+    table: str
+    changelog_type: str  # 'dynamodb' or 'postgres'
+    unique_keys: list[str] = field(default_factory=list)
+    scd_type: int = 1
+
+    @property
+    def source_schema(self) -> str:
+        """Get the source schema name for Athena/Glue catalog.
+
+        For DynamoDB: ddb_changelogs
+        For PostgreSQL: pg_changelogs_kafka_{database_normalized}
+        """
+        if self.changelog_type == "dynamodb":
+            return self.database
+        elif self.changelog_type == "postgres":
+            # Normalize database name (replace hyphens with underscores)
+            db_normalized = self.database.replace("-", "_")
+            return f"pg_changelogs_kafka_{db_normalized}"
+        else:
+            return self.database
+
+    @property
+    def silver_table_name(self) -> str:
+        """Get the silver table name produced by DLT."""
+        return f"silver_{self.table}"
+
+    @property
+    def bronze_table_name(self) -> str:
+        """Get the bronze table name produced by DLT."""
+        return f"bronze_{self.table}"
+
+
+@dataclass
+class PipelineConfig:
+    """Configuration for a DLT pipeline parsed from Databricks Asset Bundle.
+
+    Attributes:
+        name: Pipeline/bundle name.
+        catalog: Target Unity Catalog name.
+        schema: Target schema name.
+        tables: List of table configurations for the pipeline.
+        targets: Target environment configurations (dev/prod).
+        variables: Variable definitions from databricks.yml.
+        tags: Pipeline tags.
+    """
+
+    name: str
+    catalog: str
+    schema: str
+    tables: list[TableConfig] = field(default_factory=list)
+    targets: dict[str, Any] = field(default_factory=dict)
+    variables: dict[str, Any] = field(default_factory=dict)
+    tags: dict[str, str] = field(default_factory=dict)
+
+
+class DatabricksBundleParser:
+    """Parse Databricks Asset Bundle YAML files (databricks.yml and tables.yml).
+
+    This parser extracts pipeline configuration from Databricks Asset Bundles,
+    including the target catalog, schema, table definitions, and job configuration.
+    It resolves variable references to Dagger environment variable format.
+
+    Attributes:
+        _databricks_yml_path: Path to the databricks.yml file.
+        _tables_yml_path: Path to the tables.yml file.
+        _databricks_config: Parsed databricks.yml content.
+        _tables_config: Parsed tables.yml content.
+        _pipeline_config: Cached PipelineConfig instance.
+    """
+
+    def __init__(
+        self,
+        databricks_yml_path: Path,
+        tables_yml_path: Optional[Path] = None,
+    ) -> None:
+        """Initialize the parser with paths to bundle YAML files.
+
+        Args:
+            databricks_yml_path: Path to the databricks.yml file.
+            tables_yml_path: Optional path to the tables.yml file. If not provided,
+                will look for tables.yml in the same directory.
+        """
+        self._databricks_yml_path = Path(databricks_yml_path)
+        self._tables_yml_path = (
+            Path(tables_yml_path)
+            if tables_yml_path
+            else self._databricks_yml_path.parent / "tables.yml"
+        )
+
+        self._databricks_config = self._load_yaml(self._databricks_yml_path)
+        self._tables_config = (
+            self._load_yaml(self._tables_yml_path)
+            if self._tables_yml_path.exists()
+            else {}
+        )
+
+        self._pipeline_config: Optional[PipelineConfig] = None
+
+    @staticmethod
+    def _load_yaml(path: Path) -> dict[str, Any]:
+        """Load and parse a YAML file.
+
+        Args:
+            path: Path to the YAML file.
+
+        Returns:
+            Parsed YAML content as a dictionary.
+
+        Raises:
+            yaml.YAMLError: If the YAML file is malformed.
+        """
+        try:
+            with open(path, "r") as f:
+                return yaml.safe_load(f) or {}
+        except FileNotFoundError:
+            _logger.warning(f"YAML file not found: {path}")
+            return {}
+        except yaml.YAMLError as e:
+            _logger.error(f"Error parsing YAML file {path}: {e}")
+            raise
+
+    def _resolve_variable(self, value: str) -> str:
+        """Resolve Databricks bundle variable references like ${var.catalog}.
+
+        Args:
+            value: String that may contain variable references
+
+        Returns:
+            Resolved string with environment variable format for Dagger
+        """
+        if not isinstance(value, str):
+            return value
+
+        # Match ${var.variable_name} pattern
+        var_pattern = re.compile(r"\$\{var\.(\w+)\}")
+
+        def replace_var(match):
+            var_name = match.group(1)
+            # Get the default value from variables section
+            var_config = self._databricks_config.get("variables", {}).get(var_name, {})
+            default_value = var_config.get("default", "")
+
+            # Map to Dagger environment variables
+            if var_name == "catalog":
+                # Map catalog to Dagger's ${ENV_MARTS} pattern
+                return "${ENV_MARTS}"
+            return default_value
+
+        return var_pattern.sub(replace_var, value)
+
+    def _parse_tables(self) -> list[TableConfig]:
+        """Parse table configurations from tables.yml.
+
+        Returns:
+            List of TableConfig instances for each table defined in the bundle.
+        """
+        tables = []
+        defaults = self._tables_config.get("defaults", {})
+        default_scd_type = defaults.get("scd_type", 1)
+
+        for table_config in self._tables_config.get("tables", []):
+            tables.append(
+                TableConfig(
+                    database=table_config.get("database", ""),
+                    table=table_config.get("table", ""),
+                    changelog_type=table_config.get("changelog_type", "dynamodb"),
+                    unique_keys=table_config.get("unique_keys", []),
+                    scd_type=table_config.get("scd_type", default_scd_type),
+                )
+            )
+
+        return tables
+
+    def _parse_pipeline(self) -> PipelineConfig:
+        """Parse pipeline configuration from databricks.yml.
+
+        Extracts bundle name, variables, targets, and pipeline-specific settings
+        from the Databricks Asset Bundle configuration.
+
+        Returns:
+            PipelineConfig instance with all parsed configuration.
+        """
+        bundle_name = self._databricks_config.get("bundle", {}).get("name", "")
+        variables = self._databricks_config.get("variables", {})
+        targets = self._databricks_config.get("targets", {})
+
+        # Get pipeline configuration from resources
+        resources = self._databricks_config.get("resources", {})
+        pipelines = resources.get("pipelines", {})
+
+        # Get the first pipeline (usually matches bundle name)
+        pipeline_key = bundle_name or next(iter(pipelines.keys()), "")
+        pipeline_config = pipelines.get(pipeline_key, {})
+
+        catalog = self._resolve_variable(pipeline_config.get("catalog", ""))
+        schema = pipeline_config.get("schema", "")
+        tags = pipeline_config.get("tags", {})
+
+        return PipelineConfig(
+            name=bundle_name,
+            catalog=catalog,
+            schema=schema,
+            tables=self._parse_tables(),
+            targets=targets,
+            variables=variables,
+            tags=tags,
+        )
+
+    def parse(self) -> PipelineConfig:
+        """Parse the Databricks Asset Bundle and return pipeline configuration.
+
+        Returns:
+            PipelineConfig with all parsed configuration
+        """
+        if self._pipeline_config is None:
+            self._pipeline_config = self._parse_pipeline()
+        return self._pipeline_config
+
+    def get_bundle_name(self) -> str:
+        """Return the bundle/pipeline name.
+
+        Returns:
+            The bundle name from databricks.yml.
+        """
+        return self.parse().name
+
+    def get_catalog(self) -> str:
+        """Return the target catalog with Dagger environment variable format.
+
+        Returns:
+            Catalog name with variables resolved to Dagger format (e.g., ${ENV_MARTS}).
+        """
+        return self.parse().catalog
+
+    def get_schema(self) -> str:
+        """Return the target schema.
+
+        Returns:
+            Target schema name for the DLT pipeline.
+        """
+        return self.parse().schema
+
+    def get_tables(self) -> list[TableConfig]:
+        """Return the list of table configurations.
+
+        Returns:
+            List of TableConfig instances for all tables in the pipeline.
+        """
+        return self.parse().tables
+
+    def get_targets(self) -> dict[str, Any]:
+        """Return target configurations (dev/prod).
+
+        Returns:
+            Dictionary of target environment configurations.
+        """
+        return self.parse().targets
+
+    def get_variables(self) -> dict[str, Any]:
+        """Return variable definitions.
+
+        Returns:
+            Dictionary of variable definitions from databricks.yml.
+        """
+        return self.parse().variables
+
+    def get_job_name(self) -> str:
+        """Get the Databricks Job name that triggers this pipeline.
+
+        Looks for a job defined in resources.jobs that wraps the pipeline.
+        Falls back to a default naming convention if no job is defined.
+
+        Returns:
+            The job name to use with DatabricksRunNowOperator
+        """
+        resources = self._databricks_config.get("resources", {})
+        jobs = resources.get("jobs", {})
+
+        # Return the first job's name, or use default naming convention
+        for job_config in jobs.values():
+            return job_config.get("name", f"dlt-{self.get_bundle_name()}")
+
+        return f"dlt-{self.get_bundle_name()}"
diff --git a/dagger/plugins/dlt_task_generator/dlt_task_generator.py b/dagger/plugins/dlt_task_generator/dlt_task_generator.py
new file mode 100644
index 0000000..6136227
--- /dev/null
+++ b/dagger/plugins/dlt_task_generator/dlt_task_generator.py
@@ -0,0 +1,374 @@
+"""Generate Dagger task configurations from Databricks DLT bundle definitions."""
+
+import logging
+import os
+from pathlib import Path
+from typing import Any, Optional
+
+import yaml
+
+from dagger.plugins.dlt_task_generator.bundle_parser import (
+    DatabricksBundleParser,
+    PipelineConfig,
+    TableConfig,
+)
+
+_logger = logging.getLogger(__name__)
+
+# Default path to the DLT pipelines repository (can be overridden via env var)
+DEFAULT_DLT_PIPELINES_REPO = os.getenv(
+    "DLT_PIPELINES_REPO",
+    str(Path(__file__).parent.parent.parent.parent.parent / "dataeng-databricks-dlt-pipelines"),
+)
+
+
+class DLTTaskGenerator:
+    """Generate Dagger task configurations from Databricks DLT bundle definitions.
+
+    This generator reads Databricks Asset Bundle configurations and produces
+    Dagger-compatible YAML task configurations for DLT pipelines.
+
+    Attributes:
+        ATHENA_TASK_BASE: Base configuration for Athena input tasks.
+        DATABRICKS_TASK_BASE: Base configuration for Databricks output tasks.
+        DUMMY_TASK_BASE: Base configuration for dummy tasks.
+    """
+
+    ATHENA_TASK_BASE: dict[str, str] = {"type": "athena"}
+    DATABRICKS_TASK_BASE: dict[str, str] = {"type": "databricks"}
+    DUMMY_TASK_BASE: dict[str, str] = {"type": "dummy"}
+
+    def __init__(self, dlt_repo_path: Optional[str] = None) -> None:
+        """Initialize the generator with path to DLT pipelines repository.
+
+        Args:
+            dlt_repo_path: Path to the dataeng-databricks-dlt-pipelines repository.
+                Defaults to DLT_PIPELINES_REPO env var or sibling directory.
+        """
+        self._dlt_repo_path = Path(dlt_repo_path or DEFAULT_DLT_PIPELINES_REPO)
+        self._pipelines: dict[str, DatabricksBundleParser] = {}
+        self._load_all_pipelines()
+
+    def _load_all_pipelines(self) -> None:
+        """Load all pipeline bundles from the DLT repository.
+
+        Scans the pipelines directory and loads each valid Databricks Asset Bundle
+        found. Bundles are identified by the presence of a databricks.yml file.
+ """ + pipelines_dir = self._dlt_repo_path / "pipelines" + + if not pipelines_dir.exists(): + _logger.warning(f"DLT pipelines directory not found: {pipelines_dir}") + return + + for pipeline_dir in pipelines_dir.iterdir(): + if not pipeline_dir.is_dir(): + continue + + databricks_yml = pipeline_dir / "databricks.yml" + if not databricks_yml.exists(): + continue + + tables_yml = pipeline_dir / "tables.yml" + try: + parser = DatabricksBundleParser(databricks_yml, tables_yml) + pipeline_name = parser.get_bundle_name() or pipeline_dir.name + self._pipelines[pipeline_name] = parser + _logger.info(f"Loaded DLT pipeline: {pipeline_name}") + except Exception as e: + _logger.error(f"Error loading pipeline from {pipeline_dir}: {e}") + + def get_pipeline_names(self) -> list[str]: + """Return list of available DLT pipeline names.""" + return list(self._pipelines.keys()) + + def get_pipeline_config(self, pipeline_name: str) -> PipelineConfig: + """Get the parsed pipeline configuration. + + Args: + pipeline_name: Name of the pipeline + + Returns: + PipelineConfig object + + Raises: + ValueError: If pipeline not found + """ + if pipeline_name not in self._pipelines: + raise ValueError( + f"Unknown pipeline: {pipeline_name}. " + f"Available pipelines: {self.get_pipeline_names()}" + ) + return self._pipelines[pipeline_name].parse() + + def _get_athena_input( + self, table: TableConfig, follow_external_dependency: bool = True + ) -> dict[str, Any]: + """Generate an Athena input task for a source changelog table. + + Args: + table: Table configuration from the DLT bundle. + follow_external_dependency: Whether to create an ExternalTaskSensor + for cross-pipeline dependency tracking. + + Returns: + Dagger Athena task configuration dict. + """ + task = self.ATHENA_TASK_BASE.copy() + task.update( + { + "schema": table.source_schema, + "table": table.table, + "name": f"{table.source_schema}__{table.table}_athena", + } + ) + if follow_external_dependency: + task["follow_external_dependency"] = True + return task + + def _get_databricks_output( + self, table: TableConfig, catalog: str, schema: str + ) -> dict[str, Any]: + """Generate a Databricks output task for a silver table. + + Args: + table: Table configuration from the DLT bundle. + catalog: Target Unity Catalog name (e.g., ${ENV_MARTS}). + schema: Target schema name. + + Returns: + Dagger Databricks task configuration dict. + """ + task = self.DATABRICKS_TASK_BASE.copy() + # Normalize catalog name for task naming + catalog_name = catalog.replace("${", "").replace("}", "").lower() + task.update( + { + "catalog": catalog, + "schema": schema, + "table": table.silver_table_name, + "name": f"{catalog_name}__{schema}__{table.silver_table_name}_databricks", + } + ) + return task + + def get_inputs( + self, pipeline_name: str, follow_external_dependency: bool = True + ) -> list[dict[str, Any]]: + """Generate input dependencies for a DLT pipeline task. + + These are the source changelog tables that the DLT pipeline reads from. + + Args: + pipeline_name: Name of the DLT pipeline. + follow_external_dependency: Whether to create ExternalTaskSensors + for cross-pipeline dependency tracking. + + Returns: + List of Dagger input task configurations. 
+ """ + config = self.get_pipeline_config(pipeline_name) + inputs = [] + + for table in config.tables: + input_task = self._get_athena_input(table, follow_external_dependency) + inputs.append(input_task) + + return inputs + + def get_outputs(self, pipeline_name: str) -> list[dict[str, Any]]: + """Generate output declarations for a DLT pipeline task. + + These are the silver tables produced by the DLT pipeline. + + Args: + pipeline_name: Name of the DLT pipeline. + + Returns: + List of Dagger output task configurations. + """ + config = self.get_pipeline_config(pipeline_name) + outputs = [] + + for table in config.tables: + output_task = self._get_databricks_output( + table, config.catalog, config.schema + ) + outputs.append(output_task) + + return outputs + + def get_task_parameters(self, pipeline_name: str) -> dict[str, Any]: + """Generate task parameters for triggering a DLT pipeline via Databricks Job. + + Args: + pipeline_name: Name of the DLT pipeline. + + Returns: + Dict of task parameters for the DatabricksRunNowOperator. + """ + parser = self._pipelines[pipeline_name] + return { + "job_name": parser.get_job_name(), + "databricks_conn_id": "${DATABRICKS_CONN_ID}", + "wait_for_completion": True, + "poll_interval_seconds": 30, + "timeout_seconds": 3600, + } + + def generate_task_config( + self, + pipeline_name: str, + description: Optional[str] = None, + follow_external_dependency: bool = True, + ) -> dict[str, Any]: + """Generate a complete Dagger task configuration for a DLT pipeline. + + Args: + pipeline_name: Name of the DLT pipeline. + description: Optional task description. Defaults to auto-generated. + follow_external_dependency: Whether to create ExternalTaskSensors for inputs. + + Returns: + Complete Dagger task configuration dict ready for YAML serialization. + """ + config = self.get_pipeline_config(pipeline_name) + + task_config = { + "type": "databricks_dlt", + "description": description or f"Run DLT pipeline {pipeline_name}", + "inputs": self.get_inputs(pipeline_name, follow_external_dependency), + "outputs": self.get_outputs(pipeline_name), + "airflow_task_parameters": { + "retries": 2, + "retry_delay": 300, + }, + "template_parameters": {}, + "task_parameters": self.get_task_parameters(pipeline_name), + } + + return task_config + + def generate_pipeline_config( + self, + pipeline_name: str, + schedule: str = "0 * * * *", + owner: str = "dataeng@choco.com", + ) -> dict[str, Any]: + """Generate a Dagger pipeline.yaml configuration for a DLT pipeline DAG. + + Args: + pipeline_name: Name of the DLT pipeline. + schedule: Cron schedule expression. Defaults to hourly. + owner: Pipeline owner email address. + + Returns: + Dagger pipeline.yaml configuration dict ready for YAML serialization. + """ + config = self.get_pipeline_config(pipeline_name) + + return { + "owner": owner, + "description": f"DLT Pipeline - {pipeline_name}", + "schedule": schedule, + "start_date": "2024-01-01T00:00", + "airflow_parameters": { + "default_args": { + "retries": 2, + "retry_delay": 180, + "depends_on_past": False, + }, + "dag_parameters": { + "catchup": False, + "max_active_runs": 1, + "tags": ["dlt", "databricks", pipeline_name], + }, + }, + "alerts": [ + { + "type": "slack", + "channel": "#${ENV}-airflow-alerts", + "mentions": ["@dataeng-oncall"], + } + ], + } + + def write_task_config( + self, pipeline_name: str, output_path: Path, **kwargs: Any + ) -> Path: + """Write a task configuration to a YAML file. + + Args: + pipeline_name: Name of the DLT pipeline. 
+ output_path: Directory to write the file to. + **kwargs: Additional arguments passed to generate_task_config. + + Returns: + Path to the written YAML file. + """ + output_path = Path(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + task_config = self.generate_task_config(pipeline_name, **kwargs) + file_path = output_path / f"{pipeline_name}_dlt.yaml" + + with open(file_path, "w") as f: + # Add autogenerated marker + task_config["autogenerated_by_dagger"] = f"dlt_task_generator:{pipeline_name}" + yaml.dump(task_config, f, default_flow_style=False, sort_keys=False) + + _logger.info(f"Generated task config: {file_path}") + return file_path + + def write_pipeline_config( + self, pipeline_name: str, output_path: Path, **kwargs: Any + ) -> Path: + """Write a pipeline configuration to a YAML file. + + Args: + pipeline_name: Name of the DLT pipeline. + output_path: Directory to write the file to. + **kwargs: Additional arguments passed to generate_pipeline_config. + + Returns: + Path to the written YAML file. + """ + output_path = Path(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + pipeline_config = self.generate_pipeline_config(pipeline_name, **kwargs) + file_path = output_path / "pipeline.yaml" + + with open(file_path, "w") as f: + yaml.dump(pipeline_config, f, default_flow_style=False, sort_keys=False) + + _logger.info(f"Generated pipeline config: {file_path}") + return file_path + + def generate_all(self, output_base_path: Path) -> list[Path]: + """Generate all DLT pipeline configurations. + + Creates pipeline.yaml and task configuration files for each loaded + DLT pipeline in the repository. + + Args: + output_base_path: Base directory for output (e.g., dags/dlt/). + + Returns: + List of paths to all generated files. + """ + output_base_path = Path(output_base_path) + generated_files = [] + + for pipeline_name in self.get_pipeline_names(): + pipeline_output_path = output_base_path / pipeline_name + + # Generate pipeline.yaml + pipeline_file = self.write_pipeline_config(pipeline_name, pipeline_output_path) + generated_files.append(pipeline_file) + + # Generate task config + task_file = self.write_task_config(pipeline_name, pipeline_output_path) + generated_files.append(task_file) + + return generated_files diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 9a341f6..9b86d5d 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -296,7 +296,20 @@ def _generate_dagger_output(self, node: dict): class DatabricksDBTConfigParser(DBTConfigParser): - """Implementation for Databricks configurations.""" + """DBT config parser implementation for Databricks Unity Catalog. + + Parses dbt manifest.json files for projects using the databricks-dbt adapter + and generates Dagger task configurations. Handles both Unity Catalog sources + (accessed via Databricks) and legacy Hive metastore sources (accessed via Athena). + + Attributes: + LEGACY_HIVE_DATABASES: Set of database names that indicate legacy Hive + metastore tables accessed via Athena rather than Unity Catalog. 
+ """ + + # Schemas that indicate sources are in legacy Hive metastore (accessed via Athena) + # rather than Unity Catalog (accessed via Databricks) + LEGACY_HIVE_DATABASES: set[str] = {"hive_metastore"} def __init__(self, default_config_parameters: dict): super().__init__(default_config_parameters) @@ -306,17 +319,132 @@ def __init__(self, default_config_parameters: dict): "create_external_athena_table", False ) - def _is_node_preparation_model(self, node: dict): + def _is_databricks_source(self, node: dict) -> bool: + """Check if a source is a Unity Catalog table (accessed via Databricks). + + Sources with database 'hive_metastore' are legacy tables accessed via Athena. + Sources with other databases (e.g., Unity Catalog like ${ENV_MARTS}) are + Databricks tables that should create databricks input tasks. + + Args: + node: The source node from dbt manifest + + Returns: + True if the source is a Unity Catalog table, False otherwise """ - Define whether it is a preparation model. + database = node.get("database", "") + return database not in self.LEGACY_HIVE_DATABASES + + def _is_node_preparation_model(self, node: dict) -> bool: + """Determine whether a node is a preparation model. + + Preparation models are intermediate models in the transformation pipeline + that should not create external dependencies. + + Args: + node: The dbt node from manifest.json. + + Returns: + True if the node's schema contains 'preparation', False otherwise. """ return "preparation" in node.get("schema", "") - def _get_table_task( + def _get_databricks_source_task( self, node: dict, follow_external_dependency: bool = False ) -> dict: + """Generate a databricks input task for a Unity Catalog source. + + This is used for sources that point to Unity Catalog tables (e.g., DLT outputs) + rather than legacy Hive metastore tables. + + Args: + node: The source node from dbt manifest + follow_external_dependency: Whether to create an ExternalTaskSensor + + Returns: + Dagger databricks task configuration dict """ - Generates the dagger databricks task for the DBT model node + task = DATABRICKS_TASK_BASE.copy() + if follow_external_dependency: + task["follow_external_dependency"] = True + + task["catalog"] = node.get("database", self._default_catalog) + task["schema"] = node.get("schema", self._default_schema) + task["table"] = node.get("name", "") + task["name"] = f"{task['catalog']}__{task['schema']}__{task['table']}_databricks" + + return task + + def _generate_dagger_tasks(self, node_name: str) -> List[Dict]: + """Generate dagger tasks, with special handling for Databricks Unity Catalog sources. + + Overrides the base class method to handle sources that are in Unity Catalog + (e.g., DLT output tables) by creating databricks input tasks instead of athena tasks. 
+ + Args: + node_name: The name of the DBT model node + + Returns: + List[Dict]: The respective dagger tasks for the DBT model node + """ + dagger_tasks = [] + + if node_name.startswith("source"): + node = self._sources_in_manifest[node_name] + else: + node = self._nodes_in_manifest[node_name] + + resource_type = node.get("resource_type") + materialized_type = node.get("config", {}).get("materialized") + + follow_external_dependency = True + if resource_type == "seed" or (self._is_node_preparation_model(node) and materialized_type != "table"): + follow_external_dependency = False + + if resource_type == "source": + # Check if this source is a Unity Catalog table (e.g., DLT outputs) + if self._is_databricks_source(node): + table_task = self._get_databricks_source_task( + node, follow_external_dependency=follow_external_dependency + ) + else: + # Legacy Hive metastore sources use Athena + table_task = self._get_athena_table_task( + node, follow_external_dependency=follow_external_dependency + ) + dagger_tasks.append(table_task) + + elif materialized_type == "ephemeral": + task = self._get_dummy_task(node) + dagger_tasks.append(task) + for dependent_node_name in node.get("depends_on", {}).get("nodes", []): + dagger_tasks += self._generate_dagger_tasks(dependent_node_name) + + else: + table_task = self._get_table_task(node, follow_external_dependency=follow_external_dependency) + dagger_tasks.append(table_task) + + if materialized_type in ("table", "incremental"): + dagger_tasks.append(self._get_s3_task(node)) + elif self._is_node_preparation_model(node): + for dependent_node_name in node.get("depends_on", {}).get("nodes", []): + dagger_tasks.extend( + self._generate_dagger_tasks(dependent_node_name) + ) + + return dagger_tasks + + def _get_table_task( + self, node: dict, follow_external_dependency: bool = False + ) -> dict: + """Generate a Databricks table task for a dbt model node. + + Args: + node: The dbt model node from manifest.json. + follow_external_dependency: Whether to create an ExternalTaskSensor. + + Returns: + Dagger databricks task configuration dict. """ task = DATABRICKS_TASK_BASE.copy() if follow_external_dependency: @@ -334,8 +462,15 @@ def _get_table_task( def _get_model_data_location( self, node: dict, schema: str, model_name: str ) -> Tuple[str, str]: - """ - Gets the S3 path of the dbt model relative to the data bucket. + """Get the S3 path of a dbt model relative to the data bucket. + + Args: + node: The dbt model node from manifest.json. + schema: The schema name (unused for Databricks, kept for interface compatibility). + model_name: The model name. + + Returns: + Tuple of (bucket_name, data_path). """ location_root = node.get("config", {}).get("location_root") location = join(location_root, model_name) @@ -345,32 +480,39 @@ def _get_model_data_location( return bucket_name, data_path def _get_s3_task(self, node: dict, is_output: bool = False) -> dict: - """ - Generates the dagger s3 task for the databricks-dbt model node + """Generate an S3 task for a databricks-dbt model node. + + Args: + node: The dbt model node from manifest.json. + is_output: If True, names the task 'output_s3_path' for output declarations. + + Returns: + Dagger S3 task configuration dict. 
""" task = S3_TASK_BASE.copy() schema = node.get("schema", self._default_schema) table = node.get("name", "") - task["name"] = f"output_s3_path" if is_output else f"s3_{table}" + task["name"] = "output_s3_path" if is_output else f"s3_{table}" task["bucket"], task["path"] = self._get_model_data_location( node, schema, table ) return task - def _generate_dagger_output(self, node: dict): - """ - Generates the dagger output for the DBT model node with the databricks-dbt adapter. - If the model is materialized as a view or ephemeral, then a dummy task is created. - Otherwise, and databricks and s3 task is created for the DBT model node. - And if create_external_athena_table is True te an extra athena task is created. + def _generate_dagger_output(self, node: dict) -> List[Dict]: + """Generate dagger output tasks for a databricks-dbt model node. + + Creates output task configurations based on the model's materialization type: + - Ephemeral models produce a dummy task + - Table/incremental models produce databricks + S3 tasks + - Optionally adds an Athena task if create_external_athena_table is True + Args: - node: The extracted node from the manifest.json file + node: The dbt model node from manifest.json. Returns: - dict: The dagger output, which is a combination of an athena and s3 task for the DBT model node - + List of dagger output task configuration dicts. """ materialized_type = node.get("config", {}).get("materialized") if materialized_type == "ephemeral": diff --git a/dagger/utilities/module.py b/dagger/utilities/module.py index 7f33690..a12c25f 100644 --- a/dagger/utilities/module.py +++ b/dagger/utilities/module.py @@ -51,7 +51,9 @@ def read_task_config(self, task): return content @staticmethod - def load_plugins_to_jinja_environment(environment: jinja2.Environment) -> jinja2.Environment: + def load_plugins_to_jinja_environment( + environment: jinja2.Environment, + ) -> jinja2.Environment: """ Dynamically load all classes(plugins) from the folders defined in the conf.PLUGIN_DIRS variable. The folder contains all plugins that are part of the project. 
@@ -60,12 +62,20 @@ def load_plugins_to_jinja_environment(environment: jinja2.Environment) -> jinja2 """ for plugin_path in conf.PLUGIN_DIRS: for root, dirs, files in os.walk(plugin_path): - dirs[:] = [directory for directory in dirs if not directory.lower().startswith("test")] + dirs[:] = [ + directory + for directory in dirs + if not directory.lower().startswith("test") + ] for plugin_file in files: - if plugin_file.endswith(".py") and not (plugin_file.startswith("__") or plugin_file.startswith("test")): + if plugin_file.endswith(".py") and not ( + plugin_file.startswith("__") or plugin_file.startswith("test") + ): module_name = plugin_file.replace(".py", "") module_path = os.path.join(root, plugin_file) - spec = importlib.util.spec_from_file_location(module_name, module_path) + spec = importlib.util.spec_from_file_location( + module_name, module_path + ) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) @@ -84,8 +94,7 @@ def replace_template_parameters(_task_str, _template_parameters): return ( rendered_task # TODO Remove this hack and use Jinja escaping instead of special expression in template files - .replace("__CBS__", "{") - .replace("__CBE__", "}") + .replace("__CBS__", "{").replace("__CBE__", "}") ) @staticmethod @@ -102,7 +111,7 @@ def generate_task_configs(self): template_parameters = {} template_parameters.update(self._default_parameters or {}) template_parameters.update(attrs) - template_parameters['branch_name'] = branch_name + template_parameters["branch_name"] = branch_name template_parameters.update(self._jinja_parameters) for task, task_yaml in self._tasks.items(): From 1a7ff6ecdf0f79f2a6dd434e91ab60f956d77e85 Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Fri, 9 Jan 2026 15:07:40 +0100 Subject: [PATCH 2/8] Add databricks provider to Airflow dependencies Required for the new DLT creator which uses DatabricksRunNowOperator and DatabricksHook from apache-airflow-providers-databricks. --- reqs/dev.txt | 2 +- reqs/test.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/reqs/dev.txt b/reqs/dev.txt index b39d00a..cb96f9e 100644 --- a/reqs/dev.txt +++ b/reqs/dev.txt @@ -1,5 +1,5 @@ pip==24.0 -apache-airflow[amazon,postgres,s3,statsd]==2.11.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.11.0/constraints-3.12.txt" +apache-airflow[amazon,databricks,postgres,s3,statsd]==2.11.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.11.0/constraints-3.12.txt" black==22.10.0 bumpversion==0.6.0 coverage==7.4.4 diff --git a/reqs/test.txt b/reqs/test.txt index 6bb6c2e..3b97347 100644 --- a/reqs/test.txt +++ b/reqs/test.txt @@ -1,4 +1,4 @@ -apache-airflow[amazon,postgres,s3,statsd]==2.11.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.11.0/constraints-3.12.txt" +apache-airflow[amazon,databricks,postgres,s3,statsd]==2.11.0 --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.11.0/constraints-3.12.txt" pytest-cov==4.0.0 pytest==7.2.0 graphviz From 435beedec013ba2e77072d31f09bd24648b2ccb1 Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Fri, 9 Jan 2026 15:10:30 +0100 Subject: [PATCH 3/8] Add databricks provider to production Airflow image Required for DLT creator to work in production. 
--- dockers/airflow/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dockers/airflow/Dockerfile b/dockers/airflow/Dockerfile index 2bd40d5..71e73d7 100644 --- a/dockers/airflow/Dockerfile +++ b/dockers/airflow/Dockerfile @@ -52,7 +52,7 @@ RUN curl -Ls "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awsc RUN pip install -U --progress-bar off --no-cache-dir pip setuptools wheel COPY requirements.txt requirements.txt -RUN pip install --progress-bar off --no-cache-dir apache-airflow[amazon,postgres,s3,statsd]==$AIRFLOW_VERSION --constraint $AIRFLOW_CONSTRAINTS && \ +RUN pip install --progress-bar off --no-cache-dir apache-airflow[amazon,databricks,postgres,s3,statsd]==$AIRFLOW_VERSION --constraint $AIRFLOW_CONSTRAINTS && \ pip install --progress-bar off --no-cache-dir -r requirements.txt && \ apt-get purge --auto-remove -yq $BUILD_DEPS && \ apt-get autoremove --purge -yq && \ From 1a82c9d58cbd6c95904cf56ad18faa8fa9fbef4f Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Mon, 12 Jan 2026 16:39:30 +0100 Subject: [PATCH 4/8] Remove plugin --- .../databricks_dlt_creator.py | 16 +- dagger/plugins/dlt_task_generator/__init__.py | 6 - .../dlt_task_generator/bundle_parser.py | 309 --------------- .../dlt_task_generator/dlt_task_generator.py | 374 ------------------ 4 files changed, 7 insertions(+), 698 deletions(-) delete mode 100644 dagger/plugins/dlt_task_generator/__init__.py delete mode 100644 dagger/plugins/dlt_task_generator/bundle_parser.py delete mode 100644 dagger/plugins/dlt_task_generator/dlt_task_generator.py diff --git a/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py b/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py index 66034a6..d3f538c 100644 --- a/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py @@ -34,8 +34,8 @@ def _cancel_databricks_run(context: dict[str, Any]) -> None: _logger.warning(f"No run_id found in XCom for task {ti.task_id}") return - # Get the databricks_conn_id from the operator - databricks_conn_id = getattr(ti.task, "databricks_conn_id", "databricks_default") + # Get the databricks_conn_id from the operator (set during operator creation) + databricks_conn_id = ti.task.databricks_conn_id try: hook = DatabricksHook(databricks_conn_id=databricks_conn_id) @@ -88,14 +88,12 @@ def _create_operator(self, **kwargs: Any) -> BaseOperator: DatabricksRunNowOperator, ) - # Get task parameters + # Get task parameters - defaults are handled in DatabricksDLTTask job_name: str = self._task.job_name - databricks_conn_id: str = getattr( - self._task, "databricks_conn_id", "databricks_default" - ) - wait_for_completion: bool = getattr(self._task, "wait_for_completion", True) - poll_interval_seconds: int = getattr(self._task, "poll_interval_seconds", 30) - timeout_seconds: int = getattr(self._task, "timeout_seconds", 3600) + databricks_conn_id: str = self._task.databricks_conn_id + wait_for_completion: bool = self._task.wait_for_completion + poll_interval_seconds: int = self._task.poll_interval_seconds + timeout_seconds: int = self._task.timeout_seconds # DatabricksRunNowOperator triggers an existing Databricks Job by name # The job must have a pipeline_task that references the DLT pipeline diff --git a/dagger/plugins/dlt_task_generator/__init__.py b/dagger/plugins/dlt_task_generator/__init__.py deleted file mode 100644 index 49e4a17..0000000 --- 
a/dagger/plugins/dlt_task_generator/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""DLT Task Generator plugin for generating Dagger configs from Databricks Asset Bundles.""" - -from dagger.plugins.dlt_task_generator.bundle_parser import DatabricksBundleParser -from dagger.plugins.dlt_task_generator.dlt_task_generator import DLTTaskGenerator - -__all__ = ["DatabricksBundleParser", "DLTTaskGenerator"] diff --git a/dagger/plugins/dlt_task_generator/bundle_parser.py b/dagger/plugins/dlt_task_generator/bundle_parser.py deleted file mode 100644 index 5a15ed7..0000000 --- a/dagger/plugins/dlt_task_generator/bundle_parser.py +++ /dev/null @@ -1,309 +0,0 @@ -"""Parse Databricks Asset Bundle YAML files for DLT pipeline configuration.""" - -import logging -import re -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Optional - -import yaml - -_logger = logging.getLogger(__name__) - - -@dataclass -class TableConfig: - """Configuration for a single table in a DLT pipeline. - - Attributes: - database: Source database name. - table: Source table name. - changelog_type: Type of changelog source ('dynamodb' or 'postgres'). - unique_keys: List of columns that uniquely identify a row. - scd_type: Slowly changing dimension type (1 or 2). - """ - - database: str - table: str - changelog_type: str # 'dynamodb' or 'postgres' - unique_keys: list[str] = field(default_factory=list) - scd_type: int = 1 - - @property - def source_schema(self) -> str: - """Get the source schema name for Athena/Glue catalog. - - For DynamoDB: ddb_changelogs - For PostgreSQL: pg_changelogs_kafka_{database_normalized} - """ - if self.changelog_type == "dynamodb": - return self.database - elif self.changelog_type == "postgres": - # Normalize database name (replace hyphens with underscores) - db_normalized = self.database.replace("-", "_") - return f"pg_changelogs_kafka_{db_normalized}" - else: - return self.database - - @property - def silver_table_name(self) -> str: - """Get the silver table name produced by DLT.""" - return f"silver_{self.table}" - - @property - def bronze_table_name(self) -> str: - """Get the bronze table name produced by DLT.""" - return f"bronze_{self.table}" - - -@dataclass -class PipelineConfig: - """Configuration for a DLT pipeline parsed from Databricks Asset Bundle. - - Attributes: - name: Pipeline/bundle name. - catalog: Target Unity Catalog name. - schema: Target schema name. - tables: List of table configurations for the pipeline. - targets: Target environment configurations (dev/prod). - variables: Variable definitions from databricks.yml. - tags: Pipeline tags. - """ - - name: str - catalog: str - schema: str - tables: list[TableConfig] = field(default_factory=list) - targets: dict[str, Any] = field(default_factory=dict) - variables: dict[str, Any] = field(default_factory=dict) - tags: dict[str, str] = field(default_factory=dict) - - -class DatabricksBundleParser: - """Parse Databricks Asset Bundle YAML files (databricks.yml and tables.yml). - - This parser extracts pipeline configuration from Databricks Asset Bundles, - including the target catalog, schema, table definitions, and job configuration. - It resolves variable references to Dagger environment variable format. - - Attributes: - _databricks_yml_path: Path to the databricks.yml file. - _tables_yml_path: Path to the tables.yml file. - _databricks_config: Parsed databricks.yml content. - _tables_config: Parsed tables.yml content. - _pipeline_config: Cached PipelineConfig instance. 
- """ - - def __init__( - self, - databricks_yml_path: Path, - tables_yml_path: Optional[Path] = None, - ) -> None: - """Initialize the parser with paths to bundle YAML files. - - Args: - databricks_yml_path: Path to the databricks.yml file. - tables_yml_path: Optional path to the tables.yml file. If not provided, - will look for tables.yml in the same directory. - """ - self._databricks_yml_path = Path(databricks_yml_path) - self._tables_yml_path = ( - Path(tables_yml_path) - if tables_yml_path - else self._databricks_yml_path.parent / "tables.yml" - ) - - self._databricks_config = self._load_yaml(self._databricks_yml_path) - self._tables_config = ( - self._load_yaml(self._tables_yml_path) - if self._tables_yml_path.exists() - else {} - ) - - self._pipeline_config: Optional[PipelineConfig] = None - - @staticmethod - def _load_yaml(path: Path) -> dict[str, Any]: - """Load and parse a YAML file. - - Args: - path: Path to the YAML file. - - Returns: - Parsed YAML content as a dictionary. - - Raises: - yaml.YAMLError: If the YAML file is malformed. - """ - try: - with open(path, "r") as f: - return yaml.safe_load(f) or {} - except FileNotFoundError: - _logger.warning(f"YAML file not found: {path}") - return {} - except yaml.YAMLError as e: - _logger.error(f"Error parsing YAML file {path}: {e}") - raise - - def _resolve_variable(self, value: str) -> str: - """Resolve Databricks bundle variable references like ${var.catalog}. - - Args: - value: String that may contain variable references - - Returns: - Resolved string with environment variable format for Dagger - """ - if not isinstance(value, str): - return value - - # Match ${var.variable_name} pattern - var_pattern = re.compile(r"\$\{var\.(\w+)\}") - - def replace_var(match): - var_name = match.group(1) - # Get the default value from variables section - var_config = self._databricks_config.get("variables", {}).get(var_name, {}) - default_value = var_config.get("default", "") - - # Map to Dagger environment variables - if var_name == "catalog": - # Map catalog to Dagger's ${ENV_MARTS} pattern - return "${ENV_MARTS}" - return default_value - - return var_pattern.sub(replace_var, value) - - def _parse_tables(self) -> list[TableConfig]: - """Parse table configurations from tables.yml. - - Returns: - List of TableConfig instances for each table defined in the bundle. - """ - tables = [] - defaults = self._tables_config.get("defaults", {}) - default_scd_type = defaults.get("scd_type", 1) - - for table_config in self._tables_config.get("tables", []): - tables.append( - TableConfig( - database=table_config.get("database", ""), - table=table_config.get("table", ""), - changelog_type=table_config.get("changelog_type", "dynamodb"), - unique_keys=table_config.get("unique_keys", []), - scd_type=table_config.get("scd_type", default_scd_type), - ) - ) - - return tables - - def _parse_pipeline(self) -> PipelineConfig: - """Parse pipeline configuration from databricks.yml. - - Extracts bundle name, variables, targets, and pipeline-specific settings - from the Databricks Asset Bundle configuration. - - Returns: - PipelineConfig instance with all parsed configuration. 
- """ - bundle_name = self._databricks_config.get("bundle", {}).get("name", "") - variables = self._databricks_config.get("variables", {}) - targets = self._databricks_config.get("targets", {}) - - # Get pipeline configuration from resources - resources = self._databricks_config.get("resources", {}) - pipelines = resources.get("pipelines", {}) - - # Get the first pipeline (usually matches bundle name) - pipeline_key = bundle_name or next(iter(pipelines.keys()), "") - pipeline_config = pipelines.get(pipeline_key, {}) - - catalog = self._resolve_variable(pipeline_config.get("catalog", "")) - schema = pipeline_config.get("schema", "") - tags = pipeline_config.get("tags", {}) - - return PipelineConfig( - name=bundle_name, - catalog=catalog, - schema=schema, - tables=self._parse_tables(), - targets=targets, - variables=variables, - tags=tags, - ) - - def parse(self) -> PipelineConfig: - """Parse the Databricks Asset Bundle and return pipeline configuration. - - Returns: - PipelineConfig with all parsed configuration - """ - if self._pipeline_config is None: - self._pipeline_config = self._parse_pipeline() - return self._pipeline_config - - def get_bundle_name(self) -> str: - """Return the bundle/pipeline name. - - Returns: - The bundle name from databricks.yml. - """ - return self.parse().name - - def get_catalog(self) -> str: - """Return the target catalog with Dagger environment variable format. - - Returns: - Catalog name with variables resolved to Dagger format (e.g., ${ENV_MARTS}). - """ - return self.parse().catalog - - def get_schema(self) -> str: - """Return the target schema. - - Returns: - Target schema name for the DLT pipeline. - """ - return self.parse().schema - - def get_tables(self) -> list[TableConfig]: - """Return the list of table configurations. - - Returns: - List of TableConfig instances for all tables in the pipeline. - """ - return self.parse().tables - - def get_targets(self) -> dict[str, Any]: - """Return target configurations (dev/prod). - - Returns: - Dictionary of target environment configurations. - """ - return self.parse().targets - - def get_variables(self) -> dict[str, Any]: - """Return variable definitions. - - Returns: - Dictionary of variable definitions from databricks.yml. - """ - return self.parse().variables - - def get_job_name(self) -> str: - """Get the Databricks Job name that triggers this pipeline. - - Looks for a job defined in resources.jobs that wraps the pipeline. - Falls back to a default naming convention if no job is defined. 
- - Returns: - The job name to use with DatabricksRunNowOperator - """ - resources = self._databricks_config.get("resources", {}) - jobs = resources.get("jobs", {}) - - # Return the first job's name, or use default naming convention - for job_config in jobs.values(): - return job_config.get("name", f"dlt-{self.get_bundle_name()}") - - return f"dlt-{self.get_bundle_name()}" diff --git a/dagger/plugins/dlt_task_generator/dlt_task_generator.py b/dagger/plugins/dlt_task_generator/dlt_task_generator.py deleted file mode 100644 index 6136227..0000000 --- a/dagger/plugins/dlt_task_generator/dlt_task_generator.py +++ /dev/null @@ -1,374 +0,0 @@ -"""Generate Dagger task configurations from Databricks DLT bundle definitions.""" - -import logging -import os -from pathlib import Path -from typing import Any, Optional - -import yaml - -from dagger.plugins.dlt_task_generator.bundle_parser import ( - DatabricksBundleParser, - PipelineConfig, - TableConfig, -) - -_logger = logging.getLogger(__name__) - -# Default path to the DLT pipelines repository (can be overridden via env var) -DEFAULT_DLT_PIPELINES_REPO = os.getenv( - "DLT_PIPELINES_REPO", - str(Path(__file__).parent.parent.parent.parent.parent / "dataeng-databricks-dlt-pipelines"), -) - - -class DLTTaskGenerator: - """Generate Dagger task configurations from Databricks DLT bundle definitions. - - This generator reads Databricks Asset Bundle configurations and produces - Dagger-compatible YAML task configurations for DLT pipelines. - - Attributes: - ATHENA_TASK_BASE: Base configuration for Athena input tasks. - DATABRICKS_TASK_BASE: Base configuration for Databricks output tasks. - DUMMY_TASK_BASE: Base configuration for dummy tasks. - """ - - ATHENA_TASK_BASE: dict[str, str] = {"type": "athena"} - DATABRICKS_TASK_BASE: dict[str, str] = {"type": "databricks"} - DUMMY_TASK_BASE: dict[str, str] = {"type": "dummy"} - - def __init__(self, dlt_repo_path: Optional[str] = None) -> None: - """Initialize the generator with path to DLT pipelines repository. - - Args: - dlt_repo_path: Path to the dataeng-databricks-dlt-pipelines repository. - Defaults to DLT_PIPELINES_REPO env var or sibling directory. - """ - self._dlt_repo_path = Path(dlt_repo_path or DEFAULT_DLT_PIPELINES_REPO) - self._pipelines: dict[str, DatabricksBundleParser] = {} - self._load_all_pipelines() - - def _load_all_pipelines(self) -> None: - """Load all pipeline bundles from the DLT repository. - - Scans the pipelines directory and loads each valid Databricks Asset Bundle - found. Bundles are identified by the presence of a databricks.yml file. 
- """ - pipelines_dir = self._dlt_repo_path / "pipelines" - - if not pipelines_dir.exists(): - _logger.warning(f"DLT pipelines directory not found: {pipelines_dir}") - return - - for pipeline_dir in pipelines_dir.iterdir(): - if not pipeline_dir.is_dir(): - continue - - databricks_yml = pipeline_dir / "databricks.yml" - if not databricks_yml.exists(): - continue - - tables_yml = pipeline_dir / "tables.yml" - try: - parser = DatabricksBundleParser(databricks_yml, tables_yml) - pipeline_name = parser.get_bundle_name() or pipeline_dir.name - self._pipelines[pipeline_name] = parser - _logger.info(f"Loaded DLT pipeline: {pipeline_name}") - except Exception as e: - _logger.error(f"Error loading pipeline from {pipeline_dir}: {e}") - - def get_pipeline_names(self) -> list[str]: - """Return list of available DLT pipeline names.""" - return list(self._pipelines.keys()) - - def get_pipeline_config(self, pipeline_name: str) -> PipelineConfig: - """Get the parsed pipeline configuration. - - Args: - pipeline_name: Name of the pipeline - - Returns: - PipelineConfig object - - Raises: - ValueError: If pipeline not found - """ - if pipeline_name not in self._pipelines: - raise ValueError( - f"Unknown pipeline: {pipeline_name}. " - f"Available pipelines: {self.get_pipeline_names()}" - ) - return self._pipelines[pipeline_name].parse() - - def _get_athena_input( - self, table: TableConfig, follow_external_dependency: bool = True - ) -> dict[str, Any]: - """Generate an Athena input task for a source changelog table. - - Args: - table: Table configuration from the DLT bundle. - follow_external_dependency: Whether to create an ExternalTaskSensor - for cross-pipeline dependency tracking. - - Returns: - Dagger Athena task configuration dict. - """ - task = self.ATHENA_TASK_BASE.copy() - task.update( - { - "schema": table.source_schema, - "table": table.table, - "name": f"{table.source_schema}__{table.table}_athena", - } - ) - if follow_external_dependency: - task["follow_external_dependency"] = True - return task - - def _get_databricks_output( - self, table: TableConfig, catalog: str, schema: str - ) -> dict[str, Any]: - """Generate a Databricks output task for a silver table. - - Args: - table: Table configuration from the DLT bundle. - catalog: Target Unity Catalog name (e.g., ${ENV_MARTS}). - schema: Target schema name. - - Returns: - Dagger Databricks task configuration dict. - """ - task = self.DATABRICKS_TASK_BASE.copy() - # Normalize catalog name for task naming - catalog_name = catalog.replace("${", "").replace("}", "").lower() - task.update( - { - "catalog": catalog, - "schema": schema, - "table": table.silver_table_name, - "name": f"{catalog_name}__{schema}__{table.silver_table_name}_databricks", - } - ) - return task - - def get_inputs( - self, pipeline_name: str, follow_external_dependency: bool = True - ) -> list[dict[str, Any]]: - """Generate input dependencies for a DLT pipeline task. - - These are the source changelog tables that the DLT pipeline reads from. - - Args: - pipeline_name: Name of the DLT pipeline. - follow_external_dependency: Whether to create ExternalTaskSensors - for cross-pipeline dependency tracking. - - Returns: - List of Dagger input task configurations. 
- """ - config = self.get_pipeline_config(pipeline_name) - inputs = [] - - for table in config.tables: - input_task = self._get_athena_input(table, follow_external_dependency) - inputs.append(input_task) - - return inputs - - def get_outputs(self, pipeline_name: str) -> list[dict[str, Any]]: - """Generate output declarations for a DLT pipeline task. - - These are the silver tables produced by the DLT pipeline. - - Args: - pipeline_name: Name of the DLT pipeline. - - Returns: - List of Dagger output task configurations. - """ - config = self.get_pipeline_config(pipeline_name) - outputs = [] - - for table in config.tables: - output_task = self._get_databricks_output( - table, config.catalog, config.schema - ) - outputs.append(output_task) - - return outputs - - def get_task_parameters(self, pipeline_name: str) -> dict[str, Any]: - """Generate task parameters for triggering a DLT pipeline via Databricks Job. - - Args: - pipeline_name: Name of the DLT pipeline. - - Returns: - Dict of task parameters for the DatabricksRunNowOperator. - """ - parser = self._pipelines[pipeline_name] - return { - "job_name": parser.get_job_name(), - "databricks_conn_id": "${DATABRICKS_CONN_ID}", - "wait_for_completion": True, - "poll_interval_seconds": 30, - "timeout_seconds": 3600, - } - - def generate_task_config( - self, - pipeline_name: str, - description: Optional[str] = None, - follow_external_dependency: bool = True, - ) -> dict[str, Any]: - """Generate a complete Dagger task configuration for a DLT pipeline. - - Args: - pipeline_name: Name of the DLT pipeline. - description: Optional task description. Defaults to auto-generated. - follow_external_dependency: Whether to create ExternalTaskSensors for inputs. - - Returns: - Complete Dagger task configuration dict ready for YAML serialization. - """ - config = self.get_pipeline_config(pipeline_name) - - task_config = { - "type": "databricks_dlt", - "description": description or f"Run DLT pipeline {pipeline_name}", - "inputs": self.get_inputs(pipeline_name, follow_external_dependency), - "outputs": self.get_outputs(pipeline_name), - "airflow_task_parameters": { - "retries": 2, - "retry_delay": 300, - }, - "template_parameters": {}, - "task_parameters": self.get_task_parameters(pipeline_name), - } - - return task_config - - def generate_pipeline_config( - self, - pipeline_name: str, - schedule: str = "0 * * * *", - owner: str = "dataeng@choco.com", - ) -> dict[str, Any]: - """Generate a Dagger pipeline.yaml configuration for a DLT pipeline DAG. - - Args: - pipeline_name: Name of the DLT pipeline. - schedule: Cron schedule expression. Defaults to hourly. - owner: Pipeline owner email address. - - Returns: - Dagger pipeline.yaml configuration dict ready for YAML serialization. - """ - config = self.get_pipeline_config(pipeline_name) - - return { - "owner": owner, - "description": f"DLT Pipeline - {pipeline_name}", - "schedule": schedule, - "start_date": "2024-01-01T00:00", - "airflow_parameters": { - "default_args": { - "retries": 2, - "retry_delay": 180, - "depends_on_past": False, - }, - "dag_parameters": { - "catchup": False, - "max_active_runs": 1, - "tags": ["dlt", "databricks", pipeline_name], - }, - }, - "alerts": [ - { - "type": "slack", - "channel": "#${ENV}-airflow-alerts", - "mentions": ["@dataeng-oncall"], - } - ], - } - - def write_task_config( - self, pipeline_name: str, output_path: Path, **kwargs: Any - ) -> Path: - """Write a task configuration to a YAML file. - - Args: - pipeline_name: Name of the DLT pipeline. 
- output_path: Directory to write the file to. - **kwargs: Additional arguments passed to generate_task_config. - - Returns: - Path to the written YAML file. - """ - output_path = Path(output_path) - output_path.mkdir(parents=True, exist_ok=True) - - task_config = self.generate_task_config(pipeline_name, **kwargs) - file_path = output_path / f"{pipeline_name}_dlt.yaml" - - with open(file_path, "w") as f: - # Add autogenerated marker - task_config["autogenerated_by_dagger"] = f"dlt_task_generator:{pipeline_name}" - yaml.dump(task_config, f, default_flow_style=False, sort_keys=False) - - _logger.info(f"Generated task config: {file_path}") - return file_path - - def write_pipeline_config( - self, pipeline_name: str, output_path: Path, **kwargs: Any - ) -> Path: - """Write a pipeline configuration to a YAML file. - - Args: - pipeline_name: Name of the DLT pipeline. - output_path: Directory to write the file to. - **kwargs: Additional arguments passed to generate_pipeline_config. - - Returns: - Path to the written YAML file. - """ - output_path = Path(output_path) - output_path.mkdir(parents=True, exist_ok=True) - - pipeline_config = self.generate_pipeline_config(pipeline_name, **kwargs) - file_path = output_path / "pipeline.yaml" - - with open(file_path, "w") as f: - yaml.dump(pipeline_config, f, default_flow_style=False, sort_keys=False) - - _logger.info(f"Generated pipeline config: {file_path}") - return file_path - - def generate_all(self, output_base_path: Path) -> list[Path]: - """Generate all DLT pipeline configurations. - - Creates pipeline.yaml and task configuration files for each loaded - DLT pipeline in the repository. - - Args: - output_base_path: Base directory for output (e.g., dags/dlt/). - - Returns: - List of paths to all generated files. - """ - output_base_path = Path(output_base_path) - generated_files = [] - - for pipeline_name in self.get_pipeline_names(): - pipeline_output_path = output_base_path / pipeline_name - - # Generate pipeline.yaml - pipeline_file = self.write_pipeline_config(pipeline_name, pipeline_output_path) - generated_files.append(pipeline_file) - - # Generate task config - task_file = self.write_task_config(pipeline_name, pipeline_output_path) - generated_files.append(task_file) - - return generated_files From bfeb772df1acbe648205231bac4bc745b0485f75 Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Mon, 12 Jan 2026 17:34:49 +0100 Subject: [PATCH 5/8] Update the module to support yaml --- dagger/cli/module.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/dagger/cli/module.py b/dagger/cli/module.py index 931e809..d77897d 100644 --- a/dagger/cli/module.py +++ b/dagger/cli/module.py @@ -1,19 +1,34 @@ +import json + import click +import yaml + from dagger.utilities.module import Module from dagger.utils import Printer -import json def parse_key_value(ctx, param, value): - #print('YYY', value) + """Parse key=value pairs where value is a path to JSON or YAML file. + + Args: + ctx: Click context. + param: Click parameter. + value: List of key=value pairs. + + Returns: + Dictionary mapping variable names to parsed file contents. 
+ """ if not value: return {} key_value_dict = {} for pair in value: try: key, val_file_path = pair.split('=', 1) - #print('YYY', key, val_file_path, pair) - val = json.load(open(val_file_path)) + with open(val_file_path, 'r') as f: + if val_file_path.endswith(('.yaml', '.yml')): + val = yaml.safe_load(f) + else: + val = json.load(f) key_value_dict[key] = val except ValueError: raise click.BadParameter(f"Key-value pair '{pair}' is not in the format key=value") @@ -22,7 +37,7 @@ def parse_key_value(ctx, param, value): @click.command() @click.option("--config_file", "-c", help="Path to module config file") @click.option("--target_dir", "-t", help="Path to directory to generate the task configs to") -@click.option("--jinja_parameters", "-j", callback=parse_key_value, multiple=True, default=None, help="Path to jinja parameters json file in the format: =") +@click.option("--jinja_parameters", "-j", callback=parse_key_value, multiple=True, default=None, help="Jinja parameters file in the format: =") def generate_tasks(config_file: str, target_dir: str, jinja_parameters: dict) -> None: """ Generating tasks for a module based on config From 63d03452c5f78e299712d512740ede9479bb18cf Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Tue, 13 Jan 2026 17:21:39 +0100 Subject: [PATCH 6/8] Add comprehensive tests and improve Databricks DLT components - Add type hints and docstrings to DatabricksIO - Improve error handling in DatabricksDLTCreator with ImportError support - Add validation for empty job_name in DatabricksDLTCreator - Add comprehensive test coverage for DatabricksIO, DatabricksDLTTask, and DatabricksDLTCreator - All Databricks components now have 100% test coverage --- .../databricks_dlt_creator.py | 17 +- dagger/pipeline/ios/databricks_io.py | 106 ++++++- .../airflow/operator_creators/__init__.py | 0 .../test_databricks_dlt_creator.py | 276 ++++++++++++++++++ .../pipeline/tasks/databricks_dlt_task.yaml | 22 ++ tests/pipeline/ios/test_databricks_io.py | 222 +++++++++++++- .../tasks/test_databricks_dlt_task.py | 176 +++++++++++ 7 files changed, 798 insertions(+), 21 deletions(-) create mode 100644 tests/dag_creator/airflow/operator_creators/__init__.py create mode 100644 tests/dag_creator/airflow/operator_creators/test_databricks_dlt_creator.py create mode 100644 tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml create mode 100644 tests/pipeline/tasks/test_databricks_dlt_task.py diff --git a/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py b/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py index d3f538c..87a11ac 100644 --- a/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py +++ b/dagger/dag_creator/airflow/operator_creators/databricks_dlt_creator.py @@ -21,8 +21,6 @@ def _cancel_databricks_run(context: dict[str, Any]) -> None: Args: context: Airflow context dictionary containing task instance and other metadata. 
""" - from airflow.providers.databricks.hooks.databricks import DatabricksHook - ti = context.get("task_instance") if not ti: _logger.warning("No task instance in context, cannot cancel Databricks run") @@ -37,10 +35,18 @@ def _cancel_databricks_run(context: dict[str, Any]) -> None: # Get the databricks_conn_id from the operator (set during operator creation) databricks_conn_id = ti.task.databricks_conn_id + # Import here to avoid import errors if databricks provider not installed + # and to only import when actually needed (after early returns) try: + from airflow.providers.databricks.hooks.databricks import DatabricksHook + hook = DatabricksHook(databricks_conn_id=databricks_conn_id) hook.cancel_run(run_id) _logger.info(f"Cancelled Databricks run {run_id} for task {ti.task_id}") + except ImportError: + _logger.error( + "airflow-providers-databricks is not installed, cannot cancel run" + ) except Exception as e: _logger.error(f"Failed to cancel Databricks run {run_id}: {e}") @@ -80,6 +86,9 @@ def _create_operator(self, **kwargs: Any) -> BaseOperator: Returns: A configured DatabricksRunNowOperator instance. + + Raises: + ValueError: If job_name is empty or not provided. """ # Import here to avoid import errors if databricks provider not installed from datetime import timedelta @@ -90,6 +99,10 @@ def _create_operator(self, **kwargs: Any) -> BaseOperator: # Get task parameters - defaults are handled in DatabricksDLTTask job_name: str = self._task.job_name + if not job_name: + raise ValueError( + f"job_name is required for DatabricksDLTTask '{self._task.name}'" + ) databricks_conn_id: str = self._task.databricks_conn_id wait_for_completion: bool = self._task.wait_for_completion poll_interval_seconds: int = self._task.poll_interval_seconds diff --git a/dagger/pipeline/ios/databricks_io.py b/dagger/pipeline/ios/databricks_io.py index 15be2c1..7c7b4d2 100644 --- a/dagger/pipeline/ios/databricks_io.py +++ b/dagger/pipeline/ios/databricks_io.py @@ -1,12 +1,45 @@ +"""IO representation for Databricks Unity Catalog tables.""" + +from typing import Any + from dagger.pipeline.io import IO from dagger.utilities.config_validator import Attribute class DatabricksIO(IO): - ref_name = "databricks" + """IO representation for Databricks Unity Catalog tables. + + Represents a table in Databricks Unity Catalog with catalog.schema.table naming. + Used to define inputs and outputs for tasks that read from or write to + Databricks tables. + + Attributes: + ref_name: Reference name used by IOFactory to instantiate this IO type. + catalog: Databricks Unity Catalog name. + schema: Schema/database name within the catalog. + table: Table name. + + Example YAML configuration: + type: databricks + name: my_output_table + catalog: prod_catalog + schema: analytics + table: user_metrics + """ + + ref_name: str = "databricks" @classmethod - def init_attributes(cls, orig_cls): + def init_attributes(cls, orig_cls: type) -> None: + """Initialize configuration attributes for YAML parsing. + + Registers all attributes that can be specified in the YAML configuration. + Called by the IO metaclass during class creation. + + Args: + orig_cls: The original class being initialized (used for attribute + registration). + """ cls.add_config_attributes( [ Attribute(attribute_name="catalog"), @@ -15,32 +48,81 @@ def init_attributes(cls, orig_cls): ] ) - def __init__(self, io_config, config_location): + def __init__(self, io_config: dict[str, Any], config_location: str) -> None: + """Initialize a DatabricksIO instance. 
+ + Args: + io_config: Dictionary containing the IO configuration from YAML. + config_location: Path to the configuration file for error reporting. + + Raises: + DaggerMissingFieldException: If required fields (catalog, schema, table) + are missing from the configuration. + """ super().__init__(io_config, config_location) - self._catalog = self.parse_attribute("catalog") - self._schema = self.parse_attribute("schema") - self._table = self.parse_attribute("table") + self._catalog: str = self.parse_attribute("catalog") + self._schema: str = self.parse_attribute("schema") + self._table: str = self.parse_attribute("table") - def alias(self): + def alias(self) -> str: + """Return the unique alias for this IO in databricks:// URI format. + + The alias is used for dataset lineage tracking and dependency resolution + across pipelines. + + Returns: + A unique identifier string in the format + 'databricks://{catalog}/{schema}/{table}'. + """ return f"databricks://{self._catalog}/{self._schema}/{self._table}" @property - def rendered_name(self): + def rendered_name(self) -> str: + """Return the fully qualified table name in dot notation. + + This format is used in SQL queries and Databricks API calls. + + Returns: + The table name in '{catalog}.{schema}.{table}' format. + """ return f"{self._catalog}.{self._schema}.{self._table}" @property - def airflow_name(self): + def airflow_name(self) -> str: + """Return an Airflow-safe identifier for this table. + + Airflow task/dataset IDs cannot contain dots, so this returns a + hyphen-separated format suitable for use in Airflow contexts. + + Returns: + The table name in 'databricks-{catalog}-{schema}-{table}' format. + """ return f"databricks-{self._catalog}-{self._schema}-{self._table}" @property - def catalog(self): + def catalog(self) -> str: + """Return the Databricks Unity Catalog name. + + Returns: + The catalog name. + """ return self._catalog @property - def schema(self): + def schema(self) -> str: + """Return the schema/database name within the catalog. + + Returns: + The schema name. + """ return self._schema @property - def table(self): + def table(self) -> str: + """Return the table name. + + Returns: + The table name. 
+ """ return self._table diff --git a/tests/dag_creator/airflow/operator_creators/__init__.py b/tests/dag_creator/airflow/operator_creators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/dag_creator/airflow/operator_creators/test_databricks_dlt_creator.py b/tests/dag_creator/airflow/operator_creators/test_databricks_dlt_creator.py new file mode 100644 index 0000000..39de91b --- /dev/null +++ b/tests/dag_creator/airflow/operator_creators/test_databricks_dlt_creator.py @@ -0,0 +1,276 @@ +"""Unit tests for DatabricksDLTCreator.""" + +import sys +import unittest +from datetime import timedelta +from unittest.mock import MagicMock, patch + +from dagger.dag_creator.airflow.operator_creators.databricks_dlt_creator import ( + DatabricksDLTCreator, + _cancel_databricks_run, +) + + +class TestDatabricksDLTCreator(unittest.TestCase): + """Test cases for DatabricksDLTCreator.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.mock_task = MagicMock() + self.mock_task.name = "test_dlt_task" + self.mock_task.job_name = "test-dlt-job" + self.mock_task.databricks_conn_id = "databricks_default" + self.mock_task.wait_for_completion = True + self.mock_task.poll_interval_seconds = 30 + self.mock_task.timeout_seconds = 3600 + self.mock_task.cancel_on_kill = True + + self.mock_dag = MagicMock() + + # Set up mock for DatabricksRunNowOperator + self.mock_operator = MagicMock() + self.mock_operator_class = MagicMock(return_value=self.mock_operator) + self.mock_databricks_module = MagicMock() + self.mock_databricks_module.DatabricksRunNowOperator = self.mock_operator_class + + def test_ref_name(self) -> None: + """Test that ref_name is correctly set.""" + self.assertEqual(DatabricksDLTCreator.ref_name, "databricks_dlt") + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator(self) -> None: + """Test operator creation returns an operator instance.""" + mock_operator = MagicMock() + mock_operator_class = MagicMock(return_value=mock_operator) + sys.modules[ + "airflow.providers.databricks.operators.databricks" + ].DatabricksRunNowOperator = mock_operator_class + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + operator = creator._create_operator() + + mock_operator_class.assert_called_once() + self.assertEqual(operator, mock_operator) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator_maps_task_properties(self) -> None: + """Test that task properties are correctly mapped to operator.""" + mock_operator_class = MagicMock() + sys.modules[ + "airflow.providers.databricks.operators.databricks" + ].DatabricksRunNowOperator = mock_operator_class + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + creator._create_operator() + + call_kwargs = mock_operator_class.call_args[1] + + self.assertEqual(call_kwargs["dag"], self.mock_dag) + self.assertEqual(call_kwargs["task_id"], "test_dlt_task") + self.assertEqual(call_kwargs["databricks_conn_id"], "databricks_default") + self.assertEqual(call_kwargs["job_name"], "test-dlt-job") + self.assertEqual(call_kwargs["wait_for_termination"], True) + self.assertEqual(call_kwargs["polling_period_seconds"], 30) + self.assertEqual(call_kwargs["execution_timeout"], timedelta(seconds=3600)) + self.assertTrue(call_kwargs["do_xcom_push"]) + self.assertEqual(call_kwargs["on_failure_callback"], _cancel_databricks_run) + + @patch.dict( + sys.modules, + 
{"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator_with_custom_values(self) -> None: + """Test operator creation with non-default values.""" + self.mock_task.databricks_conn_id = "custom_conn" + self.mock_task.wait_for_completion = False + self.mock_task.poll_interval_seconds = 60 + self.mock_task.timeout_seconds = 7200 + + mock_operator_class = MagicMock() + sys.modules[ + "airflow.providers.databricks.operators.databricks" + ].DatabricksRunNowOperator = mock_operator_class + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + creator._create_operator() + + call_kwargs = mock_operator_class.call_args[1] + + self.assertEqual(call_kwargs["databricks_conn_id"], "custom_conn") + self.assertEqual(call_kwargs["wait_for_termination"], False) + self.assertEqual(call_kwargs["polling_period_seconds"], 60) + self.assertEqual(call_kwargs["execution_timeout"], timedelta(seconds=7200)) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator_empty_job_name_raises_error(self) -> None: + """Test that empty job_name raises ValueError.""" + self.mock_task.job_name = "" + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + + with self.assertRaises(ValueError) as context: + creator._create_operator() + + self.assertIn("job_name is required", str(context.exception)) + self.assertIn("test_dlt_task", str(context.exception)) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator_none_job_name_raises_error(self) -> None: + """Test that None job_name raises ValueError.""" + self.mock_task.job_name = None + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + + with self.assertRaises(ValueError) as context: + creator._create_operator() + + self.assertIn("job_name is required", str(context.exception)) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.operators.databricks": MagicMock()}, + ) + def test_create_operator_passes_kwargs(self) -> None: + """Test that additional kwargs are passed to operator.""" + mock_operator_class = MagicMock() + sys.modules[ + "airflow.providers.databricks.operators.databricks" + ].DatabricksRunNowOperator = mock_operator_class + + creator = DatabricksDLTCreator(self.mock_task, self.mock_dag) + creator._create_operator(retries=3, retry_delay=60) + + call_kwargs = mock_operator_class.call_args[1] + + self.assertEqual(call_kwargs["retries"], 3) + self.assertEqual(call_kwargs["retry_delay"], 60) + + +class TestCancelDatabricksRun(unittest.TestCase): + """Test cases for _cancel_databricks_run callback.""" + + def test_cancel_run_no_task_instance(self) -> None: + """Test callback handles missing task instance gracefully.""" + context: dict = {} + + # Should not raise, just log warning + _cancel_databricks_run(context) + + def test_cancel_run_no_run_id(self) -> None: + """Test callback handles missing run_id gracefully.""" + mock_ti = MagicMock() + mock_ti.task_id = "test_task" + mock_ti.xcom_pull.return_value = None + + context = {"task_instance": mock_ti} + + # Should not raise, just log warning + _cancel_databricks_run(context) + + mock_ti.xcom_pull.assert_called_once_with(task_ids="test_task", key="run_id") + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.hooks.databricks": MagicMock()}, + ) + def test_cancel_run_success(self) -> None: + """Test successful cancellation of Databricks run.""" + mock_hook = MagicMock() + 
mock_hook_class = MagicMock(return_value=mock_hook) + sys.modules[ + "airflow.providers.databricks.hooks.databricks" + ].DatabricksHook = mock_hook_class + + mock_ti = MagicMock() + mock_ti.task_id = "test_task" + mock_ti.xcom_pull.return_value = "run_12345" + mock_ti.task.databricks_conn_id = "databricks_default" + + context = {"task_instance": mock_ti} + + _cancel_databricks_run(context) + + mock_hook_class.assert_called_once_with(databricks_conn_id="databricks_default") + mock_hook.cancel_run.assert_called_once_with("run_12345") + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.hooks.databricks": MagicMock()}, + ) + def test_cancel_run_handles_exception(self) -> None: + """Test callback handles cancellation errors gracefully.""" + mock_hook = MagicMock() + mock_hook.cancel_run.side_effect = Exception("API Error") + mock_hook_class = MagicMock(return_value=mock_hook) + sys.modules[ + "airflow.providers.databricks.hooks.databricks" + ].DatabricksHook = mock_hook_class + + mock_ti = MagicMock() + mock_ti.task_id = "test_task" + mock_ti.xcom_pull.return_value = "run_12345" + mock_ti.task.databricks_conn_id = "databricks_default" + + context = {"task_instance": mock_ti} + + # Should not raise, just log error + _cancel_databricks_run(context) + + mock_hook.cancel_run.assert_called_once_with("run_12345") + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.hooks.databricks": MagicMock()}, + ) + def test_cancel_run_with_custom_conn_id(self) -> None: + """Test cancellation uses correct connection ID.""" + mock_hook = MagicMock() + mock_hook_class = MagicMock(return_value=mock_hook) + sys.modules[ + "airflow.providers.databricks.hooks.databricks" + ].DatabricksHook = mock_hook_class + + mock_ti = MagicMock() + mock_ti.task_id = "test_task" + mock_ti.xcom_pull.return_value = "run_67890" + mock_ti.task.databricks_conn_id = "custom_databricks_conn" + + context = {"task_instance": mock_ti} + + _cancel_databricks_run(context) + + mock_hook_class.assert_called_once_with( + databricks_conn_id="custom_databricks_conn" + ) + + @patch.dict( + sys.modules, + {"airflow.providers.databricks.hooks.databricks": None}, + ) + def test_cancel_run_handles_import_error(self) -> None: + """Test callback handles missing databricks provider gracefully.""" + mock_ti = MagicMock() + mock_ti.task_id = "test_task" + mock_ti.xcom_pull.return_value = "run_12345" + mock_ti.task.databricks_conn_id = "databricks_default" + + context = {"task_instance": mock_ti} + + # Should not raise, just log error + _cancel_databricks_run(context) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml b/tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml new file mode 100644 index 0000000..1902cf0 --- /dev/null +++ b/tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml @@ -0,0 +1,22 @@ +type: databricks_dlt +description: Test DLT pipeline task +inputs: + - type: athena + name: input_table + schema: test_schema + table: input_table +outputs: + - type: databricks + name: output_table + catalog: test_catalog + schema: test_schema + table: output_table +airflow_task_parameters: +template_parameters: +task_parameters: + job_name: test-dlt-job + databricks_conn_id: databricks_test + wait_for_completion: true + poll_interval_seconds: 60 + timeout_seconds: 7200 + cancel_on_kill: true diff --git a/tests/pipeline/ios/test_databricks_io.py b/tests/pipeline/ios/test_databricks_io.py index b1d0c45..e4a9456 100644 --- a/tests/pipeline/ios/test_databricks_io.py 
+++ b/tests/pipeline/ios/test_databricks_io.py @@ -1,17 +1,225 @@ +"""Unit tests for DatabricksIO.""" + import unittest -from dagger.pipeline.io_factory import databricks_io import yaml +from dagger.pipeline.ios import databricks_io +from dagger.utilities.exceptions import DaggerMissingFieldException + + +class TestDatabricksIO(unittest.TestCase): + """Test cases for DatabricksIO.""" -class DbIOTest(unittest.TestCase): def setUp(self) -> None: - with open('tests/fixtures/pipeline/ios/databricks_io.yaml', "r") as stream: + """Set up test fixtures.""" + with open("tests/fixtures/pipeline/ios/databricks_io.yaml", "r") as stream: config = yaml.safe_load(stream) self.db_io = databricks_io.DatabricksIO(config, "/") - def test_properties(self): - self.assertEqual(self.db_io.alias(), "databricks://test_catalog/test_schema/test_table") - self.assertEqual(self.db_io.rendered_name, "test_catalog.test_schema.test_table") - self.assertEqual(self.db_io.airflow_name, "databricks-test_catalog-test_schema-test_table") + def test_ref_name(self) -> None: + """Test that ref_name is correctly set.""" + self.assertEqual(databricks_io.DatabricksIO.ref_name, "databricks") + + def test_catalog(self) -> None: + """Test catalog property.""" + self.assertEqual(self.db_io.catalog, "test_catalog") + + def test_schema(self) -> None: + """Test schema property.""" + self.assertEqual(self.db_io.schema, "test_schema") + + def test_table(self) -> None: + """Test table property.""" + self.assertEqual(self.db_io.table, "test_table") + + def test_alias(self) -> None: + """Test alias method returns databricks:// URI format.""" + self.assertEqual( + self.db_io.alias(), "databricks://test_catalog/test_schema/test_table" + ) + + def test_rendered_name(self) -> None: + """Test rendered_name returns dot-separated format.""" + self.assertEqual( + self.db_io.rendered_name, "test_catalog.test_schema.test_table" + ) + + def test_airflow_name(self) -> None: + """Test airflow_name returns hyphen-separated format.""" + self.assertEqual( + self.db_io.airflow_name, "databricks-test_catalog-test_schema-test_table" + ) + + def test_name(self) -> None: + """Test name property from base IO class.""" + self.assertEqual(self.db_io.name, "test") + + def test_has_dependency_default(self) -> None: + """Test that has_dependency defaults to True.""" + self.assertTrue(self.db_io.has_dependency) + + +class TestDatabricksIOInlineConfig(unittest.TestCase): + """Test cases for DatabricksIO with inline configuration.""" + + def test_with_minimal_config(self) -> None: + """Test DatabricksIO with minimal required configuration.""" + config = { + "type": "databricks", + "name": "minimal_table", + "catalog": "my_catalog", + "schema": "my_schema", + "table": "my_table", + } + + db_io = databricks_io.DatabricksIO(config, "/test/path") + + self.assertEqual(db_io.catalog, "my_catalog") + self.assertEqual(db_io.schema, "my_schema") + self.assertEqual(db_io.table, "my_table") + self.assertEqual(db_io.name, "minimal_table") + + def test_alias_format_with_special_characters(self) -> None: + """Test alias format with underscores and numbers.""" + config = { + "type": "databricks", + "name": "output_123", + "catalog": "prod_catalog_v2", + "schema": "analytics_schema", + "table": "user_events_2024", + } + + db_io = databricks_io.DatabricksIO(config, "/") + + self.assertEqual( + db_io.alias(), + "databricks://prod_catalog_v2/analytics_schema/user_events_2024", + ) + self.assertEqual( + db_io.rendered_name, "prod_catalog_v2.analytics_schema.user_events_2024" + ) + 
self.assertEqual( + db_io.airflow_name, + "databricks-prod_catalog_v2-analytics_schema-user_events_2024", + ) + + def test_has_dependency_false(self) -> None: + """Test that has_dependency can be set to False.""" + config = { + "type": "databricks", + "name": "no_dep_table", + "catalog": "cat", + "schema": "sch", + "table": "tbl", + "has_dependency": False, + } + + db_io = databricks_io.DatabricksIO(config, "/") + + self.assertFalse(db_io.has_dependency) + + +class TestDatabricksIOMissingFields(unittest.TestCase): + """Test cases for DatabricksIO error handling.""" + + def test_missing_catalog_raises_exception(self) -> None: + """Test that missing catalog raises DaggerMissingFieldException.""" + config = { + "type": "databricks", + "name": "test_table", + "schema": "test_schema", + "table": "test_table", + } + + with self.assertRaises(DaggerMissingFieldException): + databricks_io.DatabricksIO(config, "/test/config.yaml") + + def test_missing_schema_raises_exception(self) -> None: + """Test that missing schema raises DaggerMissingFieldException.""" + config = { + "type": "databricks", + "name": "test_table", + "catalog": "test_catalog", + "table": "test_table", + } + + with self.assertRaises(DaggerMissingFieldException): + databricks_io.DatabricksIO(config, "/test/config.yaml") + + def test_missing_table_raises_exception(self) -> None: + """Test that missing table raises DaggerMissingFieldException.""" + config = { + "type": "databricks", + "name": "test_table", + "catalog": "test_catalog", + "schema": "test_schema", + } + + with self.assertRaises(DaggerMissingFieldException): + databricks_io.DatabricksIO(config, "/test/config.yaml") + + def test_missing_name_raises_exception(self) -> None: + """Test that missing name raises DaggerMissingFieldException.""" + config = { + "type": "databricks", + "catalog": "test_catalog", + "schema": "test_schema", + "table": "test_table", + } + + with self.assertRaises(DaggerMissingFieldException): + databricks_io.DatabricksIO(config, "/test/config.yaml") + + +class TestDatabricksIOEquality(unittest.TestCase): + """Test cases for DatabricksIO equality comparison.""" + + def test_equal_ios_are_equal(self) -> None: + """Test that two IOs with same alias are equal.""" + config1 = { + "type": "databricks", + "name": "table1", + "catalog": "cat", + "schema": "sch", + "table": "tbl", + } + config2 = { + "type": "databricks", + "name": "table2", # Different name, same catalog.schema.table + "catalog": "cat", + "schema": "sch", + "table": "tbl", + } + + io1 = databricks_io.DatabricksIO(config1, "/") + io2 = databricks_io.DatabricksIO(config2, "/") + + self.assertEqual(io1, io2) + + def test_different_ios_are_not_equal(self) -> None: + """Test that two IOs with different aliases are not equal.""" + config1 = { + "type": "databricks", + "name": "table1", + "catalog": "cat1", + "schema": "sch", + "table": "tbl", + } + config2 = { + "type": "databricks", + "name": "table2", + "catalog": "cat2", + "schema": "sch", + "table": "tbl", + } + + io1 = databricks_io.DatabricksIO(config1, "/") + io2 = databricks_io.DatabricksIO(config2, "/") + + self.assertNotEqual(io1, io2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/pipeline/tasks/test_databricks_dlt_task.py b/tests/pipeline/tasks/test_databricks_dlt_task.py new file mode 100644 index 0000000..a222148 --- /dev/null +++ b/tests/pipeline/tasks/test_databricks_dlt_task.py @@ -0,0 +1,176 @@ +"""Unit tests for DatabricksDLTTask.""" + +import unittest +from unittest.mock import MagicMock + +import yaml + 
+from dagger.pipeline.tasks.databricks_dlt_task import DatabricksDLTTask + + +class TestDatabricksDLTTask(unittest.TestCase): + """Test cases for DatabricksDLTTask.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + with open( + "tests/fixtures/pipeline/tasks/databricks_dlt_task.yaml", "r" + ) as stream: + self.config = yaml.safe_load(stream) + + # Create a mock pipeline object + self.mock_pipeline = MagicMock() + self.mock_pipeline.directory = "tests/fixtures/pipeline/tasks" + + self.task = DatabricksDLTTask( + name="test_dlt_task", + pipeline_name="test_pipeline", + pipeline=self.mock_pipeline, + job_config=self.config, + ) + + def test_ref_name(self) -> None: + """Test that ref_name is correctly set.""" + self.assertEqual(DatabricksDLTTask.ref_name, "databricks_dlt") + + def test_job_name(self) -> None: + """Test job_name property.""" + self.assertEqual(self.task.job_name, "test-dlt-job") + + def test_databricks_conn_id(self) -> None: + """Test databricks_conn_id property.""" + self.assertEqual(self.task.databricks_conn_id, "databricks_test") + + def test_wait_for_completion(self) -> None: + """Test wait_for_completion property.""" + self.assertTrue(self.task.wait_for_completion) + + def test_poll_interval_seconds(self) -> None: + """Test poll_interval_seconds property.""" + self.assertEqual(self.task.poll_interval_seconds, 60) + + def test_timeout_seconds(self) -> None: + """Test timeout_seconds property.""" + self.assertEqual(self.task.timeout_seconds, 7200) + + def test_cancel_on_kill(self) -> None: + """Test cancel_on_kill property.""" + self.assertTrue(self.task.cancel_on_kill) + + def test_task_name(self) -> None: + """Test that task name is correctly set.""" + self.assertEqual(self.task.name, "test_dlt_task") + + def test_pipeline_name(self) -> None: + """Test that pipeline_name is correctly set.""" + self.assertEqual(self.task.pipeline_name, "test_pipeline") + + +class TestDatabricksDLTTaskDefaults(unittest.TestCase): + """Test cases for DatabricksDLTTask default values.""" + + def setUp(self) -> None: + """Set up test fixtures with minimal config.""" + self.config = { + "type": "databricks_dlt", + "description": "Test DLT task with defaults", + "inputs": [], + "outputs": [], + "airflow_task_parameters": None, + "template_parameters": None, + "task_parameters": { + "job_name": "minimal-dlt-job", + }, + } + + self.mock_pipeline = MagicMock() + self.mock_pipeline.directory = "tests/fixtures/pipeline/tasks" + + self.task = DatabricksDLTTask( + name="minimal_dlt_task", + pipeline_name="test_pipeline", + pipeline=self.mock_pipeline, + job_config=self.config, + ) + + def test_default_databricks_conn_id(self) -> None: + """Test default databricks_conn_id value.""" + self.assertEqual(self.task.databricks_conn_id, "databricks_default") + + def test_default_wait_for_completion(self) -> None: + """Test default wait_for_completion value.""" + self.assertTrue(self.task.wait_for_completion) + + def test_default_poll_interval_seconds(self) -> None: + """Test default poll_interval_seconds value.""" + self.assertEqual(self.task.poll_interval_seconds, 30) + + def test_default_timeout_seconds(self) -> None: + """Test default timeout_seconds value.""" + self.assertEqual(self.task.timeout_seconds, 3600) + + def test_default_cancel_on_kill(self) -> None: + """Test default cancel_on_kill value.""" + self.assertTrue(self.task.cancel_on_kill) + + +class TestDatabricksDLTTaskBooleanHandling(unittest.TestCase): + """Test cases for boolean parameter handling edge cases.""" + + def 
test_wait_for_completion_false(self) -> None: + """Test that wait_for_completion=false is correctly handled.""" + config = { + "type": "databricks_dlt", + "description": "Test", + "inputs": [], + "outputs": [], + "airflow_task_parameters": None, + "template_parameters": None, + "task_parameters": { + "job_name": "test-job", + "wait_for_completion": False, + }, + } + + mock_pipeline = MagicMock() + mock_pipeline.directory = "tests/fixtures/pipeline/tasks" + + task = DatabricksDLTTask( + name="test_task", + pipeline_name="test_pipeline", + pipeline=mock_pipeline, + job_config=config, + ) + + self.assertFalse(task.wait_for_completion) + + def test_cancel_on_kill_false(self) -> None: + """Test that cancel_on_kill=false is correctly handled.""" + config = { + "type": "databricks_dlt", + "description": "Test", + "inputs": [], + "outputs": [], + "airflow_task_parameters": None, + "template_parameters": None, + "task_parameters": { + "job_name": "test-job", + "cancel_on_kill": False, + }, + } + + mock_pipeline = MagicMock() + mock_pipeline.directory = "tests/fixtures/pipeline/tasks" + + task = DatabricksDLTTask( + name="test_task", + pipeline_name="test_pipeline", + pipeline=mock_pipeline, + job_config=config, + ) + + self.assertFalse(task.cancel_on_kill) + + +if __name__ == "__main__": + unittest.main() From 40bdea4b037182e0107184bb9461a3d8cbb8805e Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Tue, 13 Jan 2026 17:22:40 +0100 Subject: [PATCH 7/8] Add coding standard to avoid getattr in CLAUDE.md Prefer explicit properties over getattr for type safety and better IDE support. --- CLAUDE.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index ffd1ebd..9cf1b75 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -80,3 +80,19 @@ dagger print-graph # Visualize dependency graph - **Factory Pattern**: TaskFactory/IOFactory auto-discover types via reflection - **Strategy Pattern**: OperatorCreator subclasses handle task-specific operator creation - **Dataset Aliasing**: IO `alias()` method enables automatic dependency detection across pipelines + +## Coding Standards + +### Avoid getattr +Do not use `getattr` for accessing task or IO properties. Instead, define explicit properties on the class. This ensures: +- Type safety and IDE autocompletion +- Clear interface contracts +- Easier debugging and testing + +```python +# Bad - avoid this pattern +value = getattr(self._task, 'some_property', default) + +# Good - use explicit properties +value = self._task.some_property # Property defined on task class +``` From 99898b47c1a87f7cde3c25bb659b8ea0a3563aa3 Mon Sep 17 00:00:00 2001 From: Mohamad Hallak <16711801+mrhallak@users.noreply.github.com> Date: Tue, 13 Jan 2026 17:26:54 +0100 Subject: [PATCH 8/8] Revert dbt_config_parser.py changes --- dagger/utilities/dbt_config_parser.py | 180 +++----------------------- 1 file changed, 19 insertions(+), 161 deletions(-) diff --git a/dagger/utilities/dbt_config_parser.py b/dagger/utilities/dbt_config_parser.py index 9b86d5d..9a341f6 100644 --- a/dagger/utilities/dbt_config_parser.py +++ b/dagger/utilities/dbt_config_parser.py @@ -296,20 +296,7 @@ def _generate_dagger_output(self, node: dict): class DatabricksDBTConfigParser(DBTConfigParser): - """DBT config parser implementation for Databricks Unity Catalog. - - Parses dbt manifest.json files for projects using the databricks-dbt adapter - and generates Dagger task configurations. 
Handles both Unity Catalog sources - (accessed via Databricks) and legacy Hive metastore sources (accessed via Athena). - - Attributes: - LEGACY_HIVE_DATABASES: Set of database names that indicate legacy Hive - metastore tables accessed via Athena rather than Unity Catalog. - """ - - # Schemas that indicate sources are in legacy Hive metastore (accessed via Athena) - # rather than Unity Catalog (accessed via Databricks) - LEGACY_HIVE_DATABASES: set[str] = {"hive_metastore"} + """Implementation for Databricks configurations.""" def __init__(self, default_config_parameters: dict): super().__init__(default_config_parameters) @@ -319,132 +306,17 @@ def __init__(self, default_config_parameters: dict): "create_external_athena_table", False ) - def _is_databricks_source(self, node: dict) -> bool: - """Check if a source is a Unity Catalog table (accessed via Databricks). - - Sources with database 'hive_metastore' are legacy tables accessed via Athena. - Sources with other databases (e.g., Unity Catalog like ${ENV_MARTS}) are - Databricks tables that should create databricks input tasks. - - Args: - node: The source node from dbt manifest - - Returns: - True if the source is a Unity Catalog table, False otherwise + def _is_node_preparation_model(self, node: dict): """ - database = node.get("database", "") - return database not in self.LEGACY_HIVE_DATABASES - - def _is_node_preparation_model(self, node: dict) -> bool: - """Determine whether a node is a preparation model. - - Preparation models are intermediate models in the transformation pipeline - that should not create external dependencies. - - Args: - node: The dbt node from manifest.json. - - Returns: - True if the node's schema contains 'preparation', False otherwise. + Define whether it is a preparation model. """ return "preparation" in node.get("schema", "") - def _get_databricks_source_task( - self, node: dict, follow_external_dependency: bool = False - ) -> dict: - """Generate a databricks input task for a Unity Catalog source. - - This is used for sources that point to Unity Catalog tables (e.g., DLT outputs) - rather than legacy Hive metastore tables. - - Args: - node: The source node from dbt manifest - follow_external_dependency: Whether to create an ExternalTaskSensor - - Returns: - Dagger databricks task configuration dict - """ - task = DATABRICKS_TASK_BASE.copy() - if follow_external_dependency: - task["follow_external_dependency"] = True - - task["catalog"] = node.get("database", self._default_catalog) - task["schema"] = node.get("schema", self._default_schema) - task["table"] = node.get("name", "") - task["name"] = f"{task['catalog']}__{task['schema']}__{task['table']}_databricks" - - return task - - def _generate_dagger_tasks(self, node_name: str) -> List[Dict]: - """Generate dagger tasks, with special handling for Databricks Unity Catalog sources. - - Overrides the base class method to handle sources that are in Unity Catalog - (e.g., DLT output tables) by creating databricks input tasks instead of athena tasks. 
- - Args: - node_name: The name of the DBT model node - - Returns: - List[Dict]: The respective dagger tasks for the DBT model node - """ - dagger_tasks = [] - - if node_name.startswith("source"): - node = self._sources_in_manifest[node_name] - else: - node = self._nodes_in_manifest[node_name] - - resource_type = node.get("resource_type") - materialized_type = node.get("config", {}).get("materialized") - - follow_external_dependency = True - if resource_type == "seed" or (self._is_node_preparation_model(node) and materialized_type != "table"): - follow_external_dependency = False - - if resource_type == "source": - # Check if this source is a Unity Catalog table (e.g., DLT outputs) - if self._is_databricks_source(node): - table_task = self._get_databricks_source_task( - node, follow_external_dependency=follow_external_dependency - ) - else: - # Legacy Hive metastore sources use Athena - table_task = self._get_athena_table_task( - node, follow_external_dependency=follow_external_dependency - ) - dagger_tasks.append(table_task) - - elif materialized_type == "ephemeral": - task = self._get_dummy_task(node) - dagger_tasks.append(task) - for dependent_node_name in node.get("depends_on", {}).get("nodes", []): - dagger_tasks += self._generate_dagger_tasks(dependent_node_name) - - else: - table_task = self._get_table_task(node, follow_external_dependency=follow_external_dependency) - dagger_tasks.append(table_task) - - if materialized_type in ("table", "incremental"): - dagger_tasks.append(self._get_s3_task(node)) - elif self._is_node_preparation_model(node): - for dependent_node_name in node.get("depends_on", {}).get("nodes", []): - dagger_tasks.extend( - self._generate_dagger_tasks(dependent_node_name) - ) - - return dagger_tasks - def _get_table_task( self, node: dict, follow_external_dependency: bool = False ) -> dict: - """Generate a Databricks table task for a dbt model node. - - Args: - node: The dbt model node from manifest.json. - follow_external_dependency: Whether to create an ExternalTaskSensor. - - Returns: - Dagger databricks task configuration dict. + """ + Generates the dagger databricks task for the DBT model node """ task = DATABRICKS_TASK_BASE.copy() if follow_external_dependency: @@ -462,15 +334,8 @@ def _get_table_task( def _get_model_data_location( self, node: dict, schema: str, model_name: str ) -> Tuple[str, str]: - """Get the S3 path of a dbt model relative to the data bucket. - - Args: - node: The dbt model node from manifest.json. - schema: The schema name (unused for Databricks, kept for interface compatibility). - model_name: The model name. - - Returns: - Tuple of (bucket_name, data_path). + """ + Gets the S3 path of the dbt model relative to the data bucket. """ location_root = node.get("config", {}).get("location_root") location = join(location_root, model_name) @@ -480,39 +345,32 @@ def _get_model_data_location( return bucket_name, data_path def _get_s3_task(self, node: dict, is_output: bool = False) -> dict: - """Generate an S3 task for a databricks-dbt model node. - - Args: - node: The dbt model node from manifest.json. - is_output: If True, names the task 'output_s3_path' for output declarations. - - Returns: - Dagger S3 task configuration dict. 
+ """ + Generates the dagger s3 task for the databricks-dbt model node """ task = S3_TASK_BASE.copy() schema = node.get("schema", self._default_schema) table = node.get("name", "") - task["name"] = "output_s3_path" if is_output else f"s3_{table}" + task["name"] = f"output_s3_path" if is_output else f"s3_{table}" task["bucket"], task["path"] = self._get_model_data_location( node, schema, table ) return task - def _generate_dagger_output(self, node: dict) -> List[Dict]: - """Generate dagger output tasks for a databricks-dbt model node. - - Creates output task configurations based on the model's materialization type: - - Ephemeral models produce a dummy task - - Table/incremental models produce databricks + S3 tasks - - Optionally adds an Athena task if create_external_athena_table is True - + def _generate_dagger_output(self, node: dict): + """ + Generates the dagger output for the DBT model node with the databricks-dbt adapter. + If the model is materialized as a view or ephemeral, then a dummy task is created. + Otherwise, and databricks and s3 task is created for the DBT model node. + And if create_external_athena_table is True te an extra athena task is created. Args: - node: The dbt model node from manifest.json. + node: The extracted node from the manifest.json file Returns: - List of dagger output task configuration dicts. + dict: The dagger output, which is a combination of an athena and s3 task for the DBT model node + """ materialized_type = node.get("config", {}).get("materialized") if materialized_type == "ephemeral":