From c802d98757ffca6dc3c06192b8bfe146b90a5fa6 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Tue, 6 May 2025 12:04:29 -0400 Subject: [PATCH] Upgraded to MEDS v0.4 --- pyproject.toml | 2 +- src/aces/run.py | 45 ++++++++++++++-------------------- tests/test_meds.py | 60 +++++++++++++++++++--------------------------- 3 files changed, 44 insertions(+), 63 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8ef4ea4..bffaca9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "pytimeparse == 1.1.*", "networkx == 3.3.*", "pyarrow == 17.*", - "meds == 0.3.3", + "meds ~= 0.4.0", ] [tool.setuptools] diff --git a/src/aces/run.py b/src/aces/run.py index be0c926..7efc41e 100644 --- a/src/aces/run.py +++ b/src/aces/run.py @@ -11,7 +11,7 @@ import polars as pl import pyarrow as pa import pyarrow.parquet as pq -from meds import label_schema, prediction_time_field, subject_id_field +from meds import LabelSchema from omegaconf import DictConfig, OmegaConf from . import config, predicates, query @@ -20,15 +20,15 @@ config_yaml = files("aces").joinpath("configs/_aces.yaml") MEDS_LABEL_MANDATORY_TYPES = { - subject_id_field: pl.Int64, + LabelSchema.subject_id_name: pl.Int64, } MEDS_LABEL_OPTIONAL_TYPES = { - "boolean_value": pl.Boolean, - "integer_value": pl.Int64, - "float_value": pl.Float64, - "categorical_value": pl.String, - prediction_time_field: pl.Datetime("us"), + LabelSchema.prediction_time_name: pl.Datetime("us"), + LabelSchema.boolean_value_name: pl.Boolean, + LabelSchema.integer_value_name: pl.Int64, + LabelSchema.float_value_name: pl.Float64, + LabelSchema.categorical_value_name: pl.String, } @@ -56,9 +56,9 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table: >>> get_and_validate_label_schema(df) Traceback (most recent call last): ... - ValueError: MEDS Data DataFrame must have a 'subject_id' column of type Int64. + ValueError: MEDS Label DataFrame must have a 'subject_id' column of type Int64. >>> df = pl.DataFrame({ - ... subject_id_field: pl.Series([1, 3, 2], dtype=pl.UInt32), + ... "subject_id": pl.Series([1, 3, 2], dtype=pl.UInt32), ... "time": [datetime(2021, 1, 1), datetime(2021, 1, 2), datetime(2021, 1, 3)], ... "boolean_value": [1, 0, 100], ... }) @@ -68,7 +68,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table: prediction_time: timestamp[us] boolean_value: bool integer_value: int64 - float_value: double + float_value: float categorical_value: string ---- subject_id: [[1,3,2]] @@ -80,7 +80,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table: """ schema = df.schema - if "prediction_time" not in schema: + if LabelSchema.prediction_time_name not in schema: logger.warning( "Output DataFrame is missing a 'prediction_time' column. If this is not intentional, add a " "'index_timestamp' (yes, it should be different) key to the task configuration identifying " @@ -92,7 +92,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table: if col in schema and schema[col] != dtype: df = df.with_columns(pl.col(col).cast(dtype, strict=False)) elif col not in schema: - errors.append(f"MEDS Data DataFrame must have a '{col}' column of type {dtype}.") + errors.append(f"MEDS Label DataFrame must have a '{col}' column of type {dtype}.") if errors: raise ValueError("\n".join(errors)) @@ -115,16 +115,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table: ) df = df.drop(extra_cols) - df = df.select( - subject_id_field, - "prediction_time", - "boolean_value", - "integer_value", - "float_value", - "categorical_value", - ) - - return df.to_arrow().cast(label_schema) + return LabelSchema.align(df.to_arrow()) @hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem) @@ -154,18 +145,18 @@ def main(cfg: DictConfig) -> None: # pragma: no cover if cfg.data.standard.lower() == "meds": for in_col, out_col in [ - ("subject_id", subject_id_field), - ("index_timestamp", "prediction_time"), - ("label", "boolean_value"), + ("subject_id", LabelSchema.subject_id_name), + ("index_timestamp", LabelSchema.prediction_time_name), + ("label", LabelSchema.boolean_value_name), ]: if in_col in result.columns: result = result.rename({in_col: out_col}) - if subject_id_field not in result.columns: + if LabelSchema.subject_id_name not in result.columns: if not result_is_empty: raise ValueError("Output dataframe is missing a 'subject_id' column.") else: logger.warning("Output dataframe is empty; adding an empty patient ID column.") - result = result.with_columns(pl.lit(None, dtype=pl.Int64).alias(subject_id_field)) + result = result.with_columns(pl.lit(None, dtype=pl.Int64).alias(LabelSchema.subject_id_name)) result = result.head(0) if cfg.window_stats_dir: Path(cfg.window_stats_filepath).parent.mkdir(exist_ok=True, parents=True) diff --git a/tests/test_meds.py b/tests/test_meds.py index e7b2caf..4382cf8 100644 --- a/tests/test_meds.py +++ b/tests/test_meds.py @@ -7,7 +7,7 @@ import polars as pl import pyarrow as pa -from meds import label_schema, subject_id_field +from meds import DataSchema, LabelSchema from yaml import load as load_yaml from .utils import ( @@ -36,24 +36,23 @@ # TODO: Make use meds library MEDS_PL_SCHEMA = { - subject_id_field: pl.Int64, - "time": pl.Datetime("us"), - "code": pl.Utf8, - "numeric_value": pl.Float32, - "numeric_value/is_inlier": pl.Boolean, + DataSchema.subject_id_name: pl.Int64, + DataSchema.time_name: pl.Datetime("us"), + DataSchema.code_name: pl.Utf8, + DataSchema.numeric_value_name: pl.Float32, } MEDS_LABEL_MANDATORY_TYPES = { - subject_id_field: pl.Int64, + LabelSchema.subject_id_name: pl.Int64, } MEDS_LABEL_OPTIONAL_TYPES = { - "boolean_value": pl.Boolean, - "integer_value": pl.Int64, - "float_value": pl.Float64, - "categorical_value": pl.String, - "prediction_time": pl.Datetime("us"), + LabelSchema.boolean_value_name: pl.Boolean, + LabelSchema.integer_value_name: pl.Int64, + LabelSchema.float_value_name: pl.Float64, + LabelSchema.categorical_value_name: pl.String, + LabelSchema.prediction_time_name: pl.Datetime("us"), } @@ -113,16 +112,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table: ) df = df.drop(extra_cols) - df = df.select( - subject_id_field, - "prediction_time", - "boolean_value", - "integer_value", - "float_value", - "categorical_value", - ) - - return df.to_arrow().cast(label_schema) + return LabelSchema.align(df.to_arrow()) def parse_meds_csvs( @@ -140,7 +130,7 @@ def reader(csv_str: str) -> pl.DataFrame: cols = csv_str.strip().split("\n")[0].split(",") read_schema = {k: v for k, v in default_read_schema.items() if k in cols} return pl.read_csv(StringIO(csv_str), schema=read_schema).with_columns( - pl.col("time").str.strptime(MEDS_PL_SCHEMA["time"], DEFAULT_CSV_TS_FORMAT) + pl.col("time").str.strptime(MEDS_PL_SCHEMA[DataSchema.time_name], DEFAULT_CSV_TS_FORMAT) ) if isinstance(csvs, str): @@ -169,9 +159,9 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]: # Data (input) MEDS_SHARDS = parse_shards_yaml( - f""" + """ "train/0": |- - {subject_id_field},time,code,numeric_value + subject_id,time,code,numeric_value 2,,SNP//rs234567, 2,,SNP//rs345678, 2,,GENDER//FEMALE, @@ -196,7 +186,7 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]: 2,6/8/1996 3:00,DEATH, "train/1": |-2 - {subject_id_field},time,code,numeric_value + subject_id,time,code,numeric_value 4,,GENDER//MALE, 4,,SNP//rs123456, 4,12/1/1989 12:03,ADMISSION//CARDIAC, @@ -246,7 +236,7 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]: 6,3/12/1996 0:00,DEATH, "held_out/0/0": |-2 - {subject_id_field},time,code,numeric_value + subject_id,time,code,numeric_value 3,,GENDER//FEMALE, 3,,SNP//rs234567, 3,,SNP//rs345678, @@ -261,10 +251,10 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]: 3,3/12/1996 0:00,DEATH, "empty_shard": |-2 - {subject_id_field},time,code,numeric_value + subject_id,time,code,numeric_value "held_out": |-2 - {subject_id_field},time,code,numeric_value + subject_id,time,code,numeric_value 1,,GENDER//MALE, 1,,SNP//rs123456, 1,12/1/1989 12:03,ADMISSION//CARDIAC, @@ -349,22 +339,22 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]: """ WANT_SHARDS = parse_labels_yaml( - f""" + """ "train/0": |-2 - {subject_id_field},prediction_time,boolean_value,integer_value,float_value,categorical_value + subject_id,prediction_time,boolean_value,integer_value,float_value,categorical_value "train/1": |-2 - {subject_id_field},prediction_time,boolean_value,integer_value,float_value,categorical_value + subject_id,prediction_time,boolean_value,integer_value,float_value,categorical_value 4,1/28/1991 23:32,False,,,, "held_out/0/0": |-2 - {subject_id_field},prediction_time,boolean_value,integer_value,float_value,categorical_value + subject_id,prediction_time,boolean_value,integer_value,float_value,categorical_value "empty_shard": |-2 - {subject_id_field},prediction_time,boolean_value,integer_value,float_value,categorical_value + subject_id,prediction_time,boolean_value,integer_value,float_value,categorical_value "held_out": |-2 - {subject_id_field},prediction_time,boolean_value,integer_value,float_value,categorical_value + subject_id,prediction_time,boolean_value,integer_value,float_value,categorical_value 1,1/28/1991 23:32,False,,,, """ )