Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion fgmetric/collections/_counter_pivot_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def _pivot_counter_values(
# Replace the counter field with keys for each of its enum's members
counts = data.pop(self._counter_fieldname)
for key, count in counts.items():
data[key.value] = count
column_name = key if isinstance(key, str) else key.value
data[column_name] = count

return data
6 changes: 4 additions & 2 deletions fgmetric/metric_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def write(self, metric: T) -> None:
"""
Write a single Metric instance to file.

The Metric is converted to a dictionary and then written using the underlying `DictWriter`.
The Metric is serialized using ``model_dump(mode="json")`` and then written using the
underlying `DictWriter`. JSON mode ensures that all field values (e.g., enums) are
converted to JSON-compatible types before writing.

Args:
metric: An instance of the specified Metric.
Expand All @@ -106,7 +108,7 @@ def write(self, metric: T) -> None:
TypeError: If the provided `metric` is not an instance of the Metric class used to
parametrize the writer.
"""
self._writer.writerow(metric.model_dump())
self._writer.writerow(metric.model_dump(mode="json"))

def writeall(self, metrics: Iterable[T]) -> None:
"""
Expand Down
18 changes: 18 additions & 0 deletions tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,24 @@ class FakeMetric(Metric):
next(f)


def test_counter_pivot_table_model_dump_json_mode() -> None:
"""Test that model_dump(mode='json') works with Counter pivot tables."""

@unique
class FakeEnum(StrEnum):
FOO = "foo"
BAR = "bar"

class FakeMetric(Metric):
name: str
counts: Counter[FakeEnum]

metric = FakeMetric(name="test", counts=Counter({FakeEnum.FOO: 1, FakeEnum.BAR: 2}))
result = metric.model_dump(mode="json")

assert result == {"name": "test", "foo": 1, "bar": 2}


def test_counter_pivot_table_missing_enum_members_default_to_zero(tmp_path: Path) -> None:
"""Test that missing enum members in input default to 0."""

Expand Down
27 changes: 27 additions & 0 deletions tests/test_metric_writer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from collections import Counter
from enum import StrEnum
from enum import unique
from pathlib import Path
from typing import assert_type
from unittest import mock
Expand Down Expand Up @@ -47,3 +50,27 @@ def test_init_closes_file_on_failure(tmp_path: Path) -> None:
MetricWriter(FakeMetric, fpath)

assert real_fout.closed


def test_writer_with_counter_metric(tmp_path: Path) -> None:
"""Test we can write a Counter metric through MetricWriter."""

@unique
class FakeEnum(StrEnum):
FOO = "foo"
BAR = "bar"

class CounterMetric(Metric):
name: str
counts: Counter[FakeEnum]

fpath = tmp_path / "test.txt"

with MetricWriter(CounterMetric, fpath) as writer:
writer.write(CounterMetric(name="test", counts=Counter({FakeEnum.FOO: 3, FakeEnum.BAR: 4})))

with fpath.open("r") as f:
assert next(f) == "name\tfoo\tbar\n"
assert next(f) == "test\t3\t4\n"
with pytest.raises(StopIteration):
next(f)