Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 66 additions & 2 deletions fgmetric/_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ def is_optional(annotation: TypeAnnotation | None) -> bool:
Returns:
True if the type is a union type containing `None`.
False otherwise.

Examples:
>>> is_optional(int | None)
True
>>> is_optional(int | str | None)
True
>>> is_optional(int)
False
>>> is_optional(list[int])
False
>>> is_optional(None)
False
"""
if annotation is None:
return False
Expand Down Expand Up @@ -61,6 +73,16 @@ def unpack_optional(annotation: TypeAnnotation) -> TypeAnnotation:

Raises:
ValueError: If the input is not an optional type.

Examples:
>>> unpack_optional(int | None)
<class 'int'>
>>> unpack_optional(int | str | None)
int | str
>>> unpack_optional(list[int] | None)
list[int]
>>> unpack_optional(int) # not optional
ValueError: Type is not Optional: <class 'int'>
"""
if not is_optional(annotation):
raise ValueError(f"Type is not Optional: {annotation}")
Expand All @@ -80,7 +102,19 @@ def unpack_optional(annotation: TypeAnnotation) -> TypeAnnotation:


def has_optional_elements(annotation: TypeAnnotation | None) -> bool:
"""True if annotation is a list with optional element type (e.g., list[int | None])."""
"""
True if annotation is a list with optional element type (e.g., list[int | None]).

Examples:
>>> has_optional_elements(list[int | None])
True
>>> has_optional_elements(list[int | None] | None)
True
>>> has_optional_elements(list[int])
False
>>> has_optional_elements(list[int] | None)
False
"""
if annotation is None:
return False

Expand All @@ -106,6 +140,14 @@ def has_origin(annotation: TypeAnnotation | None, origin: type) -> bool:
Returns:
True if the annotation is a parameterized instance of `origin`.
False otherwise.

Examples:
>>> has_origin(list[int], list)
True
>>> has_origin(list[int] | None, list)
True
>>> has_origin(set[int], list)
False
"""
if annotation is None:
return False
Expand All @@ -122,10 +164,32 @@ def is_list(annotation: TypeAnnotation | None) -> bool:
Check if a type annotation is a list type.

Matches `list[T]`, `Optional[list[T]]`, and `list[T] | None`.

Examples:
>>> is_list(list[int])
True
>>> is_list(list[int] | None)
True
>>> is_list(set[int])
False
>>> is_list(list) # bare list, no type parameter
False
"""
return has_origin(annotation, list)


def is_counter(annotation: TypeAnnotation | None) -> bool:
"""True if the type annotation is a Counter."""
"""
True if the type annotation is a Counter.

Examples:
>>> is_counter(Counter[str])
True
>>> is_counter(Counter[str] | None)
True
>>> is_counter(dict[str, int])
False
>>> is_counter(Counter) # bare Counter, no type parameter
False
"""
return has_origin(annotation, Counter)
75 changes: 74 additions & 1 deletion fgmetric/collections/_counter_pivot_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,43 @@ class CounterPivotTable(BaseModel):

**IMPORTANT:** As with all Python mixins, this class must precede `Metric` when declaring a
*metric's parent classes, in order for its methods to take precedence over `Metric`'s defaults.

Examples:
Defining a metric with a pivot-table counter:

```python
class Color(StrEnum):
RED = "red"
GREEN = "green"
BLUE = "blue"

class MyMetric(CounterPivotTable, Metric):
name: str
counts: Counter[Color]
```

Deserialization — wide columns are folded into the Counter:

```python
row = {"name": "foo", "red": 10, "green": 20, "blue": 30}
m = MyMetric.model_validate(row)
m.counts # Counter({Color.RED: 10, Color.GREEN: 20, Color.BLUE: 30})
```

Serialization — the Counter is pivoted back to wide columns:

```python
m.model_dump()
# {"name": "foo", "red": 10, "green": 20, "blue": 30}
```

Missing enum members default to zero:

```python
row = {"name": "foo", "red": 5}
m = MyMetric.model_validate(row)
m.counts # Counter({Color.RED: 5, Color.GREEN: 0, Color.BLUE: 0})
```
"""

_counter_fieldname: ClassVar[str | None]
Expand Down Expand Up @@ -62,6 +99,25 @@ def _get_counter_fieldname(cls) -> str | None:
Raises:
TypeError: If the user-specified model includes more than one field annotated as
`Counter[T]`.

Examples:
>>> # One counter field -> returns its name
>>> class M(CounterPivotTable, Metric):
... counts: Counter[Color]
>>> M._counter_fieldname
'counts'

>>> # No counter field -> returns None
>>> class M(CounterPivotTable, Metric):
... name: str
>>> M._counter_fieldname
None

>>> # Two counter fields -> raises TypeError
>>> class M(CounterPivotTable, Metric):
... counts_a: Counter[Color]
... counts_b: Counter[Color]
TypeError: Only one Counter per model is currently supported. ...
"""
counter_fieldnames = [
name for name, info in cls.model_fields.items() if is_counter(info.annotation)
Expand Down Expand Up @@ -97,6 +153,16 @@ def _get_counter_enum(cls) -> type[StrEnum] | None: # noqa: C901
Raises:
TypeError: If the user-specified model includes a Counter field with a type parameter
that is not a subclass of `StrEnum`.

Examples:
>>> class M(CounterPivotTable, Metric):
... counts: Counter[Color] # Color is a StrEnum
>>> M._counter_enum
<enum 'Color'>

>>> class M(CounterPivotTable, Metric):
... counts: Counter[int] # int is not a StrEnum
TypeError: Counter fields must have a StrEnum type parameter: ...
"""
if cls._counter_fieldname is None:
# No counter fields -> short-circuit
Expand Down Expand Up @@ -171,7 +237,14 @@ def _pivot_counter_values(
nxt: SerializerFunctionWrapHandler, # noqa: ARG002
info: SerializationInfo, # noqa: ARG002
) -> dict[str, Any]:
"""Pivot the Counter values out wide."""
"""
Pivot the Counter values out wide.

Example:
Given ``counts = Counter({Color.RED: 10, Color.GREEN: 20, Color.BLUE: 30})``,
the output dict will contain ``{"red": 10, "green": 20, "blue": 30}`` in place
of ``{"counts": Counter(...)}``.
"""
# Call the default serializer
data: dict[str, Any] = nxt(self)

Expand Down
18 changes: 15 additions & 3 deletions fgmetric/collections/_delimited_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ class DelimitedList(BaseModel):
`["a", "b", "c"]`. Avoid using delimiters that may appear in element values.

Examples:
Basic usage with comma delimiter (default):
Basic usage comma delimiter (default):

```python
class MyMetric(Metric):
tags: list[int] # "1,2,3" becomes [1, 2, 3]

MyMetric.model_validate({"tags": "1,2,3"}).tags # -> [1, 2, 3]
MyMetric(tags=[1, 2, 3]).model_dump() # -> {"tags": "1,2,3"}
```

Custom delimiter:
Expand All @@ -54,20 +57,29 @@ class MyMetric(Metric):
class MyMetric(Metric):
collection_delimiter = ";"
tags: list[int] # "1;2;3" becomes [1, 2, 3]

MyMetric.model_validate({"tags": "1;2;3"}).tags # -> [1, 2, 3]
MyMetric(tags=[1, 2, 3]).model_dump() # -> {"tags": "1;2;3"}
```

Optional list field:
Optional list field — the whole field may be absent:

```python
class MyMetric(Metric):
tags: list[int] | None # "" becomes None

MyMetric.model_validate({"tags": ""}).tags # -> None
MyMetric(tags=None).model_dump() # -> {"tags": None}
```

List field with Optional elements:
List with optional elements — individual elements may be absent:

```python
class MyMetric(Metric):
tags: list[int | None] # "1,,3" becomes [1, None, 3]

MyMetric.model_validate({"tags": "1,,3"}).tags # -> [1, None, 3]
MyMetric(tags=[1, None, 3]).model_dump() # -> {"tags": "1,,3"}
```
"""

Expand Down
19 changes: 18 additions & 1 deletion fgmetric/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,15 @@ class AlignmentMetric(Metric):

@classmethod
def read(cls, path: Path, delimiter: str = "\t") -> Iterator[Self]:
"""Read Metric instances from file."""
"""
Read Metric instances from file.

Example:
```python
for m in AlignmentMetric.read(Path("out.tsv")):
print(m.read_name, m.mapping_quality)
```
"""
# NOTE: the utf-8-sig encoding is required to auto-remove BOM from input file headers
with path.open(encoding="utf-8-sig") as fin:
for record in DictReader(fin, delimiter=delimiter):
Expand Down Expand Up @@ -108,6 +116,15 @@ def _header_fieldnames(cls) -> list[str]:

Returns:
The list of fieldnames to use as the header row.

Example:
Given a model with ``name: str`` and ``counts: Counter[Color]`` where
``Color`` has members ``RED``, ``GREEN``, ``BLUE``:

```python
cls._header_fieldnames()
# -> ["name", "red", "green", "blue"]
```
"""
# TODO: support returning the set of fields that would be constructed if the class has a
# custom model serializer
Expand Down
3 changes: 1 addition & 2 deletions fgmetric/metric_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,12 @@ class AlignmentMetric(Metric):
mapping_quality: int
is_duplicate: bool = False

# Write metrics to a TSV file
metrics = [
AlignmentMetric(read_name="read1", mapping_quality=60, is_duplicate=False),
AlignmentMetric(read_name="read2", mapping_quality=30, is_duplicate=True),
]

with MetricWriter(AlignmentMetric, "output.txt") as writer:
with MetricWriter(AlignmentMetric, Path("output.tsv")) as writer:
writer.writeall(metrics)
```

Expand Down