Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 138 additions & 27 deletions dbt_coverage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
class CoverageType(Enum):
DOC = "doc"
TEST = "test"
UNIT_TEST = "unit_test"


class CoverageFormat(str, Enum):
Expand Down Expand Up @@ -64,6 +65,7 @@ class Table:
name: str
original_file_path: str
columns: Dict[str, Column]
unit_tests: List[Dict] = field(default_factory=list)

@staticmethod
def from_node(node, manifest: Manifest) -> Table:
Expand Down Expand Up @@ -159,6 +161,7 @@ class Manifest:
seeds: Dict[str, Dict[str, Dict[str, Dict]]]
snapshots: Dict[str, Dict[str, Dict[str, Dict]]]
tests: Dict[str, Dict[str, List[Dict]]]
unit_tests: Dict[str, List[Dict]]

@classmethod
def from_nodes(cls, manifest_nodes: Dict[str, Dict]) -> Manifest:
Expand Down Expand Up @@ -209,8 +212,9 @@ def from_nodes(cls, manifest_nodes: Dict[str, Dict]) -> Manifest:
}

tests = cls._parse_tests(manifest_nodes)
unit_tests = cls._parse_unit_tests(manifest_nodes)

return Manifest(sources, models, seeds, snapshots, tests)
return Manifest(sources, models, seeds, snapshots, tests, unit_tests)

def get_table(self, table_id):
source_candidate = self.sources.get(table_id)
Expand Down Expand Up @@ -270,6 +274,27 @@ def _parse_tests(cls, manifest_nodes: Dict[str, Dict]) -> Dict[str, Dict[str, Li

return tests

@classmethod
def _parse_unit_tests(cls, manifest_nodes: Dict[str, Dict]) -> Dict[str, List[Dict]]:
"""Parses unit tests from manifest.json nodes."""

unit_tests = {}
for node in manifest_nodes.values():
if node["resource_type"] != "unit_test":
continue

depends_on = node["depends_on"]["nodes"]
if not depends_on:
continue

table_id = depends_on[0]

# Create a model unique_id pattern to match against models
# Unit tests are associated with models, not individual columns
unit_tests.setdefault(table_id, []).append(node)

return unit_tests

@staticmethod
def _full_table_name(table):
return f"{table['schema']}.{table['name']}".lower()
Expand Down Expand Up @@ -323,12 +348,45 @@ class ColumnRef:
subentities: Dict[str, CoverageReport]

def __post_init__(self):
if self.covered is not None and self.total is not None and self.total != 0:
self.misses = self.total - self.covered
self.coverage = len(self.covered) / len(self.total)
if self.cov_type == CoverageType.UNIT_TEST:
# For unit tests, coverage is based on tables that have at least one unit test
# We need to count unique tables that have unit tests vs total tables
covered_tables = set()
total_tables = set()

if self.entity_type == CoverageReport.EntityType.CATALOG:
# For catalog, aggregate from subentities
for table_report in self.subentities.values():
if table_report.covered: # Table has unit tests
covered_tables.add(table_report.entity_name)
total_tables.add(table_report.entity_name)
elif self.entity_type == CoverageReport.EntityType.TABLE:
# For table, check if it has unit tests
if self.covered: # Has unit tests
covered_tables.add(self.entity_name)
total_tables.add(self.entity_name)

if total_tables:
self.coverage = len(covered_tables) / len(total_tables)
else:
self.coverage = 0.0

# Set misses based on tables without unit tests
if self.entity_type == CoverageReport.EntityType.CATALOG:
self.misses = {CoverageReport.ColumnRef(table, None) for table in total_tables - covered_tables}
else:
self.misses = set() if self.covered else {CoverageReport.ColumnRef(self.entity_name, None)}
else:
self.misses = None
self.coverage = None
# Original logic for doc and test coverage
if self.covered and self.total:
self.misses = self.total - self.covered
self.coverage = len(self.covered) / len(self.total)
elif self.covered:
self.misses = None
self.coverage = 1.0
else:
self.misses = None
self.coverage = 0.0

@classmethod
def from_catalog(cls, catalog: Catalog, cov_type: CoverageType):
Expand All @@ -343,6 +401,21 @@ def from_catalog(cls, catalog: Catalog, cov_type: CoverageType):

@classmethod
def from_table(cls, table: Table, cov_type: CoverageType):
if cov_type == CoverageType.UNIT_TEST:
# For unit tests:
# - covered contains one entry per unit test (for count display)
# - total contains one entry per table (for coverage calculation)
covered = {
CoverageReport.ColumnRef(table.name, unit_test.get("name", f"unit_test_{i}"))
for i, unit_test in enumerate(table.unit_tests)
}
total = {CoverageReport.ColumnRef(table.name, None)} # One entry per table

return CoverageReport(
cls.EntityType.TABLE, cov_type, table.name, covered, total, {}
)


subentities = {
col.name: CoverageReport.from_column(col, cov_type) for col in table.columns.values()
}
Expand All @@ -367,6 +440,9 @@ def from_column(cls, column: Column, cov_type: CoverageType):
covered = column.doc
elif cov_type == CoverageType.TEST:
covered = column.test
elif cov_type == CoverageType.UNIT_TEST:
# Unit tests don't apply at column level
raise ValueError("Unit test coverage is not supported at column level")
else:
raise ValueError(f"Unsupported cov_type {cov_type}")

Expand All @@ -377,22 +453,41 @@ def from_column(cls, column: Column, cov_type: CoverageType):

def to_markdown_table(self):
if self.entity_type == CoverageReport.EntityType.TABLE:
return (
f"| {self.entity_name:70} | {len(self.covered):5}/{len(self.total):<5} | "
f"{self.coverage * 100:5.1f}% |"
)
if self.cov_type == CoverageType.UNIT_TEST:
return (
f"| {self.entity_name:70} | {len(self.covered):5} tests | "
f"{self.coverage * 100:5.1f}% |"
)
else:
return (
f"| {self.entity_name:70} | {len(self.covered):5}/{len(self.total):<5} | "
f"{self.coverage * 100:5.1f}% |"
)
elif self.entity_type == CoverageReport.EntityType.CATALOG:
buf = io.StringIO()

buf.write(f"# Coverage report ({self.cov_type.value})\n")
buf.write("| Model | Columns Covered | % |\n")
buf.write("|:------|----------------:|:-:|\n")
if self.cov_type == CoverageType.UNIT_TEST:
buf.write("| Model | Unit Tests | % |\n")
buf.write("|:------|----------------:|:-:|\n")
else:
buf.write("| Model | Covered | % |\n")
buf.write("|:------|----------------:|:-:|\n")
for _, table_cov in sorted(self.subentities.items()):
buf.write(table_cov.to_markdown_table() + "\n")
buf.write(
f"| {'Total':70} | {len(self.covered):5}/{len(self.total):<5} | "
f"{self.coverage * 100:5.1f}% |\n"
)
if self.cov_type == CoverageType.UNIT_TEST:
buf.write(
f"| {'Total':70} | {len(self.covered):5} tests | "
f"{self.coverage * 100:5.1f}% |\n"
)
else:
buf.write(
f"| {'Total':70} | {len(self.covered):5}/{len(self.total):<5} | "
f"{self.coverage * 100:5.1f}% |\n"
) if self.total != 0 else buf.write(
f"| {'Total':70} | {len(self.covered):10} | "
f"{self.coverage * 100:5.1f}% |\n"
)

return buf.getvalue()
else:
Expand All @@ -403,22 +498,34 @@ def to_markdown_table(self):

def to_formatted_string(self):
if self.entity_type == CoverageReport.EntityType.TABLE:
return (
f"{self.entity_name:50} {len(self.covered):5}/{len(self.total):<5} "
f"{self.coverage * 100:5.1f}%"
)
if self.cov_type == CoverageType.UNIT_TEST:
return (
f"{self.entity_name:50} {len(self.covered):5} tests "
f"{self.coverage * 100:5.1f}%"
)
else:
return (
f"{self.entity_name:50} {len(self.covered):5}/{len(self.total):<5} "
f"{self.coverage * 100:5.1f}%"
)
elif self.entity_type == CoverageReport.EntityType.CATALOG:
buf = io.StringIO()

buf.write(f"Coverage report ({self.cov_type.value})\n")
buf.write("=" * 69 + "\n")
buf.write("=" * 70 + "\n")
for _, table_cov in sorted(self.subentities.items()):
buf.write(table_cov.to_formatted_string() + "\n")
buf.write("=" * 69 + "\n")
buf.write(
f"{'Total':50} {len(self.covered):5}/{len(self.total):<5} "
f"{self.coverage * 100:5.1f}%\n"
)
buf.write("=" * 70 + "\n")
if self.cov_type == CoverageType.UNIT_TEST:
buf.write(
f"{'Total':50} {len(self.covered):5} tests "
f"{self.coverage * 100:5.1f}%\n"
)
else:
buf.write(
f"{'Total':50} {len(self.covered):5}/{len(self.total):<5} "
f"{self.coverage * 100:5.1f}%\n"
)

return buf.getvalue()
else:
Expand Down Expand Up @@ -731,7 +838,7 @@ def load_manifest(project_dir: Path, run_artifacts_dir: Path) -> Manifest:

check_manifest_version(manifest_json)

manifest_nodes = {**manifest_json["sources"], **manifest_json["nodes"]}
manifest_nodes = {**manifest_json["sources"], **manifest_json["nodes"], **manifest_json["unit_tests"]}
manifest = Manifest.from_nodes(manifest_nodes)

return manifest
Expand Down Expand Up @@ -772,6 +879,10 @@ def load_files(project_dir: Path, run_artifacts_dir: Path) -> Catalog:
catalog_column.doc = Column.is_valid_doc(doc)
catalog_column.test = Column.is_valid_test(manifest_column_tests)

# Set unit tests for the table
table_unit_tests = manifest.unit_tests.get(table_id, [])
catalog_table.unit_tests = table_unit_tests

return catalog


Expand Down
Loading