Embedding outputs & multiple duration preprocessing functionality #25
payalchandak merged 8 commits into main from
Conversation
📝 Walkthrough

This PR refactors the evaluation and task generation pipeline to support multiple duration-based configurations. It consolidates

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant Hydra as Hydra Config
    participant Setup as _setup_eval
    participant Trainer as Lightning Trainer
    participant Model as Model
    participant Test as _run_test
    participant Predict as _run_predict
    participant Output as Output Handler
    User->>Hydra: Load config + mode
    Hydra->>Setup: Pass cfg
    Setup->>Setup: Load train config, seed, precision
    Setup->>Model: Instantiate model
    Setup->>Trainer: Instantiate trainer
    Setup->>Setup: Validate task_set_dir
    Setup-->>Hydra: Return (train_cfg, M, trainer, task_set_dir)
    alt mode == 'test'
        Hydra->>Test: Call _run_test
        Test->>Trainer: trainer.validate/test
        Test->>Test: Collect metrics + eval_time
        Test->>Output: Write all_code_aucs_<date>.csv
    else mode == 'predict'
        Hydra->>Predict: Call _run_predict
        Predict->>Trainer: trainer.predict (per query_code)
        Predict->>Predict: Aggregate subject_id, prediction_time, query_embed
        Predict->>Output: Write all_preds_<eval_codes>_<timestamp>.csv
        Predict->>Output: Write query_embeds_<eval_codes>_<timestamp>.parquet
    end
    Output-->>User: Evaluation complete
```
Estimated Code Review Effort: 🎯 4 (Complex) | ⏱️ ~65 minutes
🚥 Pre-merge checks: ❌ 1 failed (1 warning) | ✅ 2 passed
Actionable comments posted: 9
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/every_query/tasks.py (1)
74-74: ⚠️ Potential issue | 🔴 Critical

Critical bug: `not pl.col("censored")` does not negate the column.

Using Python's `not` operator on a Polars expression returns `False` (since the expression object is truthy), rather than negating the column values. This means `censor_false` will always be an empty DataFrame. Use `~` or `.not_()` to negate a Polars expression:

🐛 Proposed fix

```diff
- censor_false = censor_df.filter(not pl.col("censored"))
+ censor_false = censor_df.filter(~pl.col("censored"))
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/every_query/tasks.py` at line 74, the filter currently uses Python's `not` on a Polars expression, which doesn't negate the column; update the `censor_df.filter(...)` call (the `censor_false` assignment) to use a proper Polars negation such as the bitwise negation operator or the expression method (i.e., replace the `not pl.col("censored")` operand with `~pl.col("censored")` or `pl.col("censored").not_()`) so the filter selects non-censored rows correctly.

src/every_query/eval.py (1)
84-84: ⚠️ Potential issue | 🟠 Major

TypeError if `cfg.ood_codes` is `None`.

When `cfg.ood_codes` is `None`, the expression `code in cfg.ood_codes` raises `TypeError: argument of type 'NoneType' is not iterable`.

🐛 Proposed fix

```diff
- "bucket": "ood" if code in cfg.ood_codes else "id",
+ "bucket": "ood" if cfg.ood_codes and code in cfg.ood_codes else "id",
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/every_query/eval.py` at line 84, the code building the dict uses "code in cfg.ood_codes", which raises TypeError when cfg.ood_codes is None; update that expression to safely handle None (e.g., replace with "code in (cfg.ood_codes or [])" or use "cfg.ood_codes is not None and code in cfg.ood_codes") so the "bucket" value becomes "ood" only when cfg.ood_codes is set and contains code; locate the occurrence that constructs the dict (the line with "bucket": "ood" if code in cfg.ood_codes else "id") and change it to one of the safe forms to avoid iterating None.
🧹 Nitpick comments (7)
src/every_query/eval_suite/gen_task.py (1)
31-34: Use `TypeError` for invalid type validation.

Static analysis (TRY004) suggests `TypeError` is more appropriate than `ValueError` when validating types.

Proposed fix

```diff
 else:
-    raise ValueError(
+    raise TypeError(
         f"eval_codes must be a list or dict with id/ood subfields, got {type(eval_codes_obj)}"
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/every_query/eval_suite/gen_task.py` around lines 31-34, the raise currently uses ValueError for a wrong-type check; change it to raise TypeError instead so type validation uses the correct exception class. Update the raise in the branch that checks eval_codes_obj (the f"eval_codes must be a list or dict..." message) to raise TypeError with the same message, ensuring any callers expecting type-related errors receive a TypeError from the function handling eval_codes_obj.

src/every_query/sample_codes/sample_eval_codes.py (1)
7-8: Stale comment and inconsistent sample sizes.

The comment says "20 random codes from each ID and OOD" but the code now samples `NUM_CODES` (5000) for ID and a hardcoded 500 for OOD. Consider:
- Updating the comment to reflect actual behavior
- Using a constant for OOD sample size for consistency

Proposed fix

```diff
-# goal is to sample 20 random codes from each ID and OOD pair to get a
-# total of 40 codes to calculate auroc for
+# Sample NUM_CODES codes from ID set and 500 from OOD set for AUROC evaluation
 PARQUET_PATH = "/users/gbk2114/data/MIMIC_MEDS/MEDS_cohort/processed/metadata/codes.parquet"
 SEED = 42
 NUM_CODES = 5000
+NUM_OOD_CODES = 500
```

Then update line 33:

```diff
-    ood_sampled = random.sample(ood_codes, 500)
+    ood_sampled = random.sample(ood_codes, NUM_OOD_CODES)
```

Also applies to: 32-33

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/every_query/sample_codes/sample_eval_codes.py` around lines 7-8, update the stale top-of-file comment to accurately describe current behavior (ID samples use NUM_CODES and OOD samples use a separate constant) and introduce a named constant (e.g., OOD_NUM_CODES) to replace the hardcoded 500; then change the code that currently uses the magic number 500 to use OOD_NUM_CODES and ensure any calculations or comments that mention "20" or "40" are revised to reflect the actual sample sizes and totals while keeping NUM_CODES referenced for the ID samples.

src/every_query/sample_codes/sample_embedding_codes.py (2)
47-47: Avoid shadowing built-in `hash`.

Renaming the variable prevents confusion and accidental misuse of the built-in.

♻️ Suggested change

```diff
-hash = stable_hash_list(sampled_embed_queries)
+codes_hash = stable_hash_list(sampled_embed_queries)
 # ---- write file ----
-out_path = f"{OUT_DIR}/embed_{N_SAMPLES}_{hash}.yaml"
+out_path = f"{OUT_DIR}/embed_{N_SAMPLES}_{codes_hash}.yaml"
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/every_query/sample_codes/sample_embedding_codes.py` at line 47, the variable named "hash" shadows Python's built-in; rename it to a clear, non-conflicting identifier (e.g., sampled_queries_hash or queries_hash) where it's assigned via stable_hash_list(sampled_embed_queries) and update any subsequent references in this module (look for usages of "hash" near stable_hash_list and sampled_embed_queries). Ensure the new name is used consistently to avoid relying on the built-in hash function.
12-12: Unused variable `N_REPEATS`.

`N_REPEATS = 1` is defined but never referenced. Remove it or implement the intended repeat logic.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/every_query/sample_codes/sample_embedding_codes.py` at line 12, N_REPEATS is defined but never used; either remove the unused constant or implement the intended repeat behavior by wrapping the embedding-generation calls in a loop using N_REPEATS (e.g., for _ in range(N_REPEATS): ...) so the sample_embedding_codes module actually repeats the embed logic, or simply delete N_REPEATS if repetition is not needed.src/every_query/tasks.py (1)
120-121: Hoist `read_query_codes` outside the file loop.

`query_codes` depends only on `read_codes_dir`, not on `file_name`. Moving it before the file loop avoids redundant reads of `codes.parquet` for each shard.

♻️ Suggested change

```diff
+query_codes = read_query_codes(read_codes_dir)
+print("Completed read_query_codes")
+
 for split in [train_split, tuning_split, held_out_split]:
     shard_directory = f"{read_dir}/data/{split}"
     for file_name in os.listdir(shard_directory):
         if not file_name.endswith(".parquet"):
             continue
         # Read once per shard, shared across all durations
         events_df = read_event_shard(f"{shard_directory}/{file_name}")
         print("Completed read_event_shard")
-        query_codes = read_query_codes(read_codes_dir)
-        print("Completed read_query_codes")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/every_query/tasks.py` around lines 120-121, the call to read_query_codes(read_codes_dir) is inside the file loop, causing repeated reads; move the query_codes = read_query_codes(read_codes_dir) invocation so it runs once before iterating over files (before the loop that uses file_name/shard processing). Update code that currently references query_codes inside the loop to use the hoisted variable; ensure read_codes_dir and query_codes remain in scope for the shard-processing block that uses them.

src/every_query/train.py (1)
89-89: MD5 flagged by static analysis (S324).

MD5 is used here for content-based directory naming, not cryptographic security. While the risk is low, switching to SHA-256 satisfies the linter and future-proofs the code.

♻️ Suggested change

```diff
- hash_hex = hashlib.md5(task_str.encode()).hexdigest()
+ hash_hex = hashlib.sha256(task_str.encode()).hexdigest()[:32]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/every_query/train.py` at line 89, replace the MD5 usage for content-based directory naming with SHA-256 to satisfy the linter: instead of calling hashlib.md5(task_str.encode()).hexdigest() (producing hash_hex), compute a SHA-256 digest via hashlib.sha256(task_str.encode()).hexdigest() and use that value for directory naming; update any references to hash_hex if you rename the variable to reflect SHA256 (e.g., sha256_hex) and ensure encoding and usage remain identical otherwise.
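The SHA-256 variant can be sketched as a small helper (`task_dir_hash` is an illustrative name); truncating to 32 hex characters keeps the directory name the same length as the old MD5 hexdigest:

```python
import hashlib

def task_dir_hash(task_str: str) -> str:
    # Content-based name: SHA-256 truncated to 32 hex chars matches the
    # length of an MD5 hexdigest, so existing directory layouts stay familiar.
    return hashlib.sha256(task_str.encode()).hexdigest()[:32]
```

Note that switching algorithms changes the hash for every existing task string, so previously generated directories will not be reused.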
127-127: Trailing whitespace detected by CI.

The pipeline reports trailing whitespace on these lines. Run your formatter (`ruff format`) to fix.

Also applies to: 139-139, 164-164

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/every_query/eval.py` at line 127, remove the trailing whitespace characters reported by CI: run your formatter (ruff format) or manually delete trailing spaces on the affected lines, then re-run tests; ensure src/every_query/eval.py has no trailing whitespace and commit the formatted file.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c1cd04ef-734b-4357-bca6-8333789f5d3f
📒 Files selected for processing (14)
- src/every_query/config.yaml
- src/every_query/eval.py
- src/every_query/eval_composite.py
- src/every_query/eval_suite/conf/eval_composite_config.yaml
- src/every_query/eval_suite/conf/eval_config.yaml
- src/every_query/eval_suite/conf/gen_tasks_config.yaml
- src/every_query/eval_suite/gen_task.py
- src/every_query/lightning_module.py
- src/every_query/process_composite/process_composite.py
- src/every_query/process_composite/process_composite_config.yaml
- src/every_query/sample_codes/sample_embedding_codes.py
- src/every_query/sample_codes/sample_eval_codes.py
- src/every_query/tasks.py
- src/every_query/train.py
💤 Files with no reviewable changes (2)
- src/every_query/eval_suite/conf/eval_composite_config.yaml
- src/every_query/eval_composite.py
```yaml
id_codes: ${eval_codes.id}
ood_codes: ${eval_codes.ood}
```
Incompatible eval_codes structure between test and predict modes.

There's a structural mismatch:
- Lines 12-13: `id_codes: ${eval_codes.id}` and `ood_codes: ${eval_codes.ood}` assume `eval_codes` is a DictConfig (from `sample_eval_codes.py`)
- Line 20: `query_codes: ${eval_codes}` assigns the entire `eval_codes` object
- Line 26: `mode: predict`

In predict mode, `_run_predict` (eval.py:110) expects `cfg.query_codes` to be a list: `codes: list[str] = cfg.query_codes`. When `eval_codes` is a DictConfig with id/ood subfields, this assignment fails.

Conversely, when `eval_codes` is a flat ListConfig (from `sample_embedding_codes.py`), the `${eval_codes.id}` interpolations on lines 12-13 will fail.
Suggested approach

Consider unifying query_codes resolution in the config or Python code:

```diff
-query_codes: ${eval_codes}
+# For predict mode with flat list eval_codes:
+# query_codes: ${eval_codes}
+# For predict mode with dict eval_codes:
+# query_codes: ${eval_codes.id} + ${eval_codes.ood}  # Not valid YAML - handle in Python
```

Or handle the structure in _run_predict:

```python
if isinstance(cfg.query_codes, DictConfig):
    codes = list(cfg.query_codes.get("id", [])) + list(cfg.query_codes.get("ood", []))
else:
    codes = list(cfg.query_codes)
```

Also applies to: 20-20, 26-26
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/every_query/eval_suite/conf/eval_config.yaml` around lines 12 - 13, The
config uses eval_codes sometimes as a DictConfig with id/ood fields and
sometimes as a flat list, causing _run_predict (eval.py:_run_predict) to receive
cfg.query_codes in the wrong shape; update either the YAML resolution or the
predictor to normalize before use—specifically, modify _run_predict to detect
DictConfig vs list for cfg.query_codes (or resolve query_codes to a flat list in
the config) and produce a single list[str] (e.g., combine eval_codes.id and
eval_codes.ood when present, otherwise convert the list-like value to a list) so
that codes: list[str] = cfg.query_codes always succeeds.
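One way to sketch the normalization without importing omegaconf: since a DictConfig is dict-like, a `Mapping` check covers both shapes here (this is an assumption; the review's explicit `isinstance(cfg.query_codes, DictConfig)` check is the more direct form, and `normalize_query_codes` is an illustrative name):

```python
from collections.abc import Mapping
from typing import Any

def normalize_query_codes(query_codes: Any) -> list[str]:
    """Flatten either {'id': [...], 'ood': [...]} or a flat sequence
    of codes into a single list[str]."""
    if isinstance(query_codes, Mapping):
        # Dict-shaped config: concatenate in-distribution and OOD codes.
        return list(query_codes.get("id", [])) + list(query_codes.get("ood", []))
    # Flat list-shaped config: just materialize it.
    return list(query_codes)
```

Calling this once at the top of `_run_predict` would make the downstream code indifferent to which `eval_codes` variant the config selected.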
```python
hc = HydraConfig.get()
eval_codes_choice_str = hc.runtime.choices["eval_codes"]
```
Potential KeyError if eval_codes choice is missing.
If the Hydra config doesn't define eval_codes in runtime choices, this will raise a KeyError. Consider using .get() with a fallback or validating the key exists.
🛡️ Proposed fix

```diff
 hc = HydraConfig.get()
-eval_codes_choice_str = hc.runtime.choices["eval_codes"]
+eval_codes_choice_str = hc.runtime.choices.get("eval_codes", "unknown")
```

🧰 Tools
🪛 GitHub Actions: Code Quality PR
[error] Trailing whitespace detected and fixed by pre-commit hooks.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/every_query/eval.py` around lines 165 - 166, The code reads hc =
HydraConfig.get() and directly indexes hc.runtime.choices["eval_codes"], which
can raise a KeyError if the choice is missing; change this to safely fetch the
value (e.g., use hc.runtime.choices.get("eval_codes", <fallback>) or check
"eval_codes" in hc.runtime.choices before indexing) and handle the missing case
by setting a sensible default or raising a clear error; update the usage around
eval_codes_choice_str to account for the fallback/validated value.
```yaml
task_labels_df_path: /users/gbk2114/eic_stuff/make_index_dfs/task_configs/${task_name}/held_out
output_path: ${model_path}/eval/readmiss_30d/
predictions_df_path: ${model_path}/eval/${task_name}/${split_name}/all_preds_hosp_admin_codes_20260305_014527.csv
```
Hardcoded timestamp in path will break with new prediction runs.
The path all_preds_hosp_admin_codes_20260305_014527.csv contains a fixed timestamp. Since eval.py generates timestamped filenames dynamically, this config will fail to find newly generated prediction files.
Consider using a glob pattern, latest-file resolution in Python, or removing the timestamp to let Hydra override at runtime.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/every_query/process_composite/process_composite_config.yaml` at line 6,
predictions_df_path in process_composite_config.yaml is hardcoded to a
timestamped file name and will miss newer eval outputs; update the config to a
non-timestamped or glob path (e.g., use all_preds_hosp_admin_codes_*.csv) or
remove the timestamped filename so Hydra can override at runtime, and (if
choosing glob) add logic in eval.py to resolve the latest matching file (e.g.,
resolve predictions_df_path glob to the newest file before reading); ensure you
modify the predictions_df_path entry and the file-resolution in eval.py
accordingly.
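The latest-file resolution mentioned above can be sketched as follows (the function name and default pattern are illustrative, not taken from the repo). Because the `%Y%m%d_%H%M%S` suffix sorts lexicographically in chronological order, the lexicographic maximum is the newest run:

```python
from pathlib import Path

def latest_predictions(eval_dir: Path,
                       pattern: str = "all_preds_hosp_admin_codes_*.csv") -> Path:
    """Return the newest timestamped predictions file matching pattern."""
    matches = sorted(eval_dir.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"no files matching {pattern!r} in {eval_dir}")
    # YYYYMMDD_HHMMSS suffixes sort lexicographically == chronologically.
    return matches[-1]
```

With this in place, the config entry could hold a glob pattern and eval.py would resolve it to a concrete file at read time.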
```diff
@@ -4,6 +4,7 @@
 import polars as pl
 from omegaconf import DictConfig
+from sklearn.metrics import roc_auc_score
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
# Check if scikit-learn is in dependencies
cat pyproject.toml 2>/dev/null | grep -i "scikit\|sklearn" || cat setup.py 2>/dev/null | grep -i "scikit\|sklearn" || cat requirements*.txt 2>/dev/null | grep -i "scikit\|sklearn"
```

Repository: payalchandak/EveryQuery
Length of output: 49

🏁 Script executed:

```shell
# List potential dependency files
git ls-files | grep -E "(requirements|setup|pyproject|poetry|tox|environment)" | head -20
```

Length of output: 115

🏁 Script executed:

```shell
# Check what's imported and used in process_composite.py
cat -n src/every_query/process_composite/process_composite.py | head -20
```

Length of output: 773

🏁 Script executed:

```shell
# Search for any sklearn references in the entire repo
rg -i "sklearn|scikit" --type py | head -30
```

Length of output: 270

🏁 Script executed:

```shell
cat pyproject.toml
```

Length of output: 3025

🏁 Script executed:

```shell
cat .github/actions/setup/action.yaml
```

Length of output: 803
Add scikit-learn to project dependencies.
The code imports sklearn.metrics.roc_auc_score at line 6, but scikit-learn is not listed in pyproject.toml. Add scikit-learn to the dependencies section to resolve the import failure.
🧰 Tools
🪛 GitHub Actions: Tests
[error] 6-6: ModuleNotFoundError: No module named 'sklearn' while importing roc_auc_score
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/every_query/process_composite/process_composite.py` at line 6, The import
of sklearn.metrics.roc_auc_score in process_composite.py will fail unless
scikit-learn is declared as a dependency; add "scikit-learn" (appropriate
version constraint) to the project's dependencies in pyproject.toml so the
package is installed and the import of roc_auc_score succeeds when running the
code.
```python
out_dir = Path(cfg.output_path)
out_fp = out_dir / f"{cfg.task_name}_all_preds.csv"
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
```
Use timezone-aware datetime.now() to avoid ambiguity.
Static analysis (DTZ005) flags datetime.now() called without a timezone argument. Use datetime.now(tz=UTC) for consistency with other files in this PR (e.g., eval.py).
Proposed fix

```diff
-from datetime import datetime
+from datetime import UTC, datetime

-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    timestamp = datetime.now(tz=UTC).strftime("%Y%m%d_%H%M%S")
```

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
from datetime import UTC, datetime

timestamp = datetime.now(tz=UTC).strftime("%Y%m%d_%H%M%S")
```
🧰 Tools
🪛 GitHub Actions: Code Quality PR
[error] 92-92: DTZ005: datetime.now() called without a tz argument. Pass a timezone to datetime.now().
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/every_query/process_composite/process_composite.py` at line 92, Replace
the naive datetime.now() call used to build the timestamp variable with a
timezone-aware UTC call (use datetime.now(tz=timezone.utc) or
datetime.now(tz=datetime.timezone.utc)) so the timestamp is unambiguous; update
the imports in process_composite.py to include timezone (or reference
datetime.timezone) and keep the same strftime("%Y%m%d_%H%M%S") formatting when
constructing timestamp.
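A sketch of the `timezone.utc` variant the prompt mentions, which also works on Python versions before 3.11 where `datetime.UTC` does not exist (`utc_timestamp` is an illustrative helper name):

```python
from datetime import datetime, timezone

def utc_timestamp() -> str:
    # Timezone-aware "now": unambiguous regardless of the host's local zone.
    return datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
```

The output format stays `YYYYMMDD_HHMMSS`, so downstream filename parsing is unaffected.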
```python
# -------------------
# Config
# -------------------
PARQUET_PATH = "/users/gbk2114/data/MIMIC_MEDS/MEDS_cohort/processed/metadata/codes.parquet"
```
Hardcoded absolute path breaks CI and portability.
The pipeline fails with FileNotFoundError because /users/gbk2114/data/... is a user-specific path. Use an environment variable or configuration parameter to make this script portable.
🐛 Proposed fix

```diff
-PARQUET_PATH = "/users/gbk2114/data/MIMIC_MEDS/MEDS_cohort/processed/metadata/codes.parquet"
+PARQUET_PATH = os.environ.get("PROCESSED", "") + "/metadata/codes.parquet"
```
+PARQUET_PATH = os.environ.get("PROCESSED", "") + "/metadata/codes.parquet"🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/every_query/sample_codes/sample_embedding_codes.py` at line 10, The
hardcoded PARQUET_PATH constant in sample_embedding_codes.py causes
portability/CI failures; replace the fixed value by reading from an environment
variable or config (e.g., os.environ.get('PARQUET_PATH') or use a config loader)
and assign that to PARQUET_PATH (or a new variable) so callers use the
configurable path; add a short existence check (os.path.exists) after resolving
the path and raise a clear error if missing to fail fast and guide users to set
the env var or config.
```diff
@@ -8,6 +8,7 @@
 # total of 40 codes to calculate auroc for
 PARQUET_PATH = "/users/gbk2114/data/MIMIC_MEDS/MEDS_cohort/processed/metadata/codes.parquet"
```
Hardcoded absolute path breaks CI and portability.
The pipeline fails with FileNotFoundError because /users/gbk2114/data/MIMIC_MEDS/... is a user-specific path that doesn't exist in CI. Consider using environment variables or relative paths.
Proposed fix
-PARQUET_PATH = "/users/gbk2114/data/MIMIC_MEDS/MEDS_cohort/processed/metadata/codes.parquet"
+PARQUET_PATH = os.environ.get("CODES_PARQUET_PATH", "/users/gbk2114/data/MIMIC_MEDS/MEDS_cohort/processed/metadata/codes.parquet")

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/every_query/sample_codes/sample_eval_codes.py` at line 9, The constant
PARQUET_PATH is hardcoded to a user-specific absolute path which breaks CI;
change sample_eval_codes.py to read the path from configuration or environment
and fall back to a portable relative/default location (e.g., use
os.environ.get('PARQUET_PATH') or a config loader to set PARQUET_PATH, with a
sensible repo-relative default like "data/.../codes.parquet"); update the
PARQUET_PATH symbol accordingly and ensure any code using PARQUET_PATH continues
to work with the environment/config-driven value.
| with open(out_fp, "w") as f: | ||
| yaml.safe_dump(out_codes, f) | ||
|
|
||
| print(f"Writtten to {out_fp}") |
Fix typo: "Writtten" → "Written".
Flagged by codespell in pipeline.
Proposed fix
- print(f"Writtten to {out_fp}")
+ print(f"Written to {out_fp}")

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| print(f"Writtten to {out_fp}") | |
| print(f"Written to {out_fp}") |
🧰 Tools
🪛 GitHub Actions: Code Quality PR
[error] 43-43: Writtten (typo) found by codespell. Consider correcting to 'Written'.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/every_query/sample_codes/sample_eval_codes.py` at line 43, The print
statement in sample_eval_codes.py contains a typo: update the string in the
print call (the f-string "Writtten to {out_fp}") to "Written to {out_fp}" so the
output reads correctly; locate the print(...) invocation in the module (the
f-string print) and correct the misspelling only.
| durations = list(range(cfg.query.duration_min, cfg.query.duration_max)) | ||
|
|
||
| task_str = f"{'|'.join(sorted(cfg.query.codes))}" | ||
| task_str = f"{'|'.join(sorted(cfg.query.codes))}_{'|'.join(str(d) for d in sorted(durations))}" | ||
| hash_hex = hashlib.md5(task_str.encode()).hexdigest() | ||
| write_dir = f"{cfg.query.task_dir}/collated/{hash_hex}" | ||
| write_dir = f"{task_dir}/collated/{hash_hex}" | ||
|
|
||
| first_duration = durations[0] |
Handle empty durations range to prevent IndexError.
If cfg.query.duration_min >= cfg.query.duration_max, durations will be empty, and durations[0] on line 92 will raise an IndexError.
🛡️ Proposed fix
task_dir = cfg.query.task_dir
durations = list(range(cfg.query.duration_min, cfg.query.duration_max))
+ if not durations:
+ raise ValueError(
+ f"Invalid duration range: duration_min ({cfg.query.duration_min}) "
+ f"must be less than duration_max ({cfg.query.duration_max})"
+ )
task_str = f"{'|'.join(sorted(cfg.query.codes))}_{'|'.join(str(d) for d in sorted(durations))}"

📝 Committable suggestion
| durations = list(range(cfg.query.duration_min, cfg.query.duration_max)) | |
| task_str = f"{'|'.join(sorted(cfg.query.codes))}" | |
| task_str = f"{'|'.join(sorted(cfg.query.codes))}_{'|'.join(str(d) for d in sorted(durations))}" | |
| hash_hex = hashlib.md5(task_str.encode()).hexdigest() | |
| write_dir = f"{cfg.query.task_dir}/collated/{hash_hex}" | |
| write_dir = f"{task_dir}/collated/{hash_hex}" | |
| first_duration = durations[0] | |
| durations = list(range(cfg.query.duration_min, cfg.query.duration_max)) | |
| if not durations: | |
| raise ValueError( | |
| f"Invalid duration range: duration_min ({cfg.query.duration_min}) " | |
| f"must be less than duration_max ({cfg.query.duration_max})" | |
| ) | |
| task_str = f"{'|'.join(sorted(cfg.query.codes))}_{'|'.join(str(d) for d in sorted(durations))}" | |
| hash_hex = hashlib.md5(task_str.encode()).hexdigest() | |
| write_dir = f"{task_dir}/collated/{hash_hex}" | |
| first_duration = durations[0] |
🧰 Tools
🪛 Ruff (0.15.4)
[error] 89-89: Probable use of insecure hash functions in hashlib: md5
(S324)
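Since md5 here only derives a directory name from the task string, one way to satisfy S324 is the `usedforsecurity=False` flag (available since Python 3.9); the `task_str` value below is illustrative, not from the repo:

```python
import hashlib

task_str = "CODE_A|CODE_B_1|2|3"  # illustrative task string, not repo data

# md5 is only a cache/directory key here; usedforsecurity=False tells linters
# (and FIPS-mode OpenSSL builds) that no cryptographic guarantee is claimed.
hash_hex = hashlib.md5(task_str.encode(), usedforsecurity=False).hexdigest()

# Alternatively, sha256 silences S324 outright at the cost of a longer digest.
hash_hex_sha = hashlib.sha256(task_str.encode()).hexdigest()
```

Either option keeps the hash deterministic for a given sorted code/duration string, which is all the `write_dir` naming needs.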
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/every_query/train.py` around lines 86 - 92, The code builds durations and
then reads durations[0] which will IndexError when cfg.query.duration_min >=
cfg.query.duration_max; before using durations (and before computing
task_str/hash_hex/write_dir), validate the range and either (a) raise a clear
ValueError if cfg.query.duration_min >= cfg.query.duration_max with a
descriptive message referencing cfg.query.duration_min and
cfg.query.duration_max, or (b) adjust the range construction if the intent was
inclusive (e.g., use range(cfg.query.duration_min, cfg.query.duration_max + 1));
ensure the check/fix is applied near the durations variable so that
first_duration and downstream uses are safe.
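Both options from the prompt, the fail-fast check and the inclusive-range variant, can be folded into one small helper; the names here (`build_durations`, the `inclusive` flag) are illustrative sketches, not repo code:

```python
def build_durations(duration_min: int, duration_max: int, inclusive: bool = True) -> list[int]:
    """Build the duration list, failing fast on an empty range.

    inclusive=True implements option (b) from the review: range(min, max + 1),
    so duration_max itself is included and min == max yields one duration.
    """
    stop = duration_max + 1 if inclusive else duration_max
    durations = list(range(duration_min, stop))
    if not durations:
        raise ValueError(
            f"Invalid duration range: duration_min ({duration_min}) "
            f"must not exceed duration_max ({duration_max})"
        )
    return durations
```

With this in place, `first_duration = durations[0]` can never hit an empty list, because construction either returns a non-empty list or raises with a message naming both config values.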
Summary by CodeRabbit
New Features
Refactor