
Greg testing #24

Merged
payalchandak merged 24 commits into main from greg_testing
Feb 10, 2026

Conversation

@gkondas (Collaborator) commented Nov 25, 2025

#22 #21 #23

Summary by CodeRabbit

Release Notes

  • New Features

    • Added evaluation suite for generating prediction indices, task datasets, and model evaluation workflows
    • Introduced composite query evaluation with per-code ROC AUC metrics and probability aggregation
    • Added code sampling utilities for training and evaluation set generation
  • Configuration & Improvements

    • Updated training defaults with optimized learning rates, improved early stopping, and reduced max steps
    • Enhanced model architecture with configurable dropout regularization
    • Extended dataset support for evaluation metrics including subject IDs and prediction timestamps
  • Documentation

    • Added evaluation suite instructions

@coderabbitai bot commented Nov 25, 2025

📝 Walkthrough


This PR introduces a comprehensive evaluation pipeline for EveryQuery. It adds scripts for sampling training/OOD codes, generating evaluation indices and per-code task data, performing model evaluation, and aggregating prediction results. Configuration files for training and evaluation are updated with revised hyperparameters and new workflows. The model gains dropout configuration, the datamodule supports held-out splits with subject/time tracking, and new utilities enable code slugging and data aggregation.

Changes

  • Training & Evaluation Configuration (src/every_query/config.yaml, src/every_query/aces_to_eq/config.yaml, src/every_query/eval_suite/conf/*config.yaml): Training hyperparameters updated (learning rate: 5.0e-4→1e-5, batch_size: 160→128, max_steps: 1M→40K); new early_stopping callback; evaluation configs added for index-time generation, task generation, and per-code and composite evaluation.
  • Code Sampling & Training Sets (src/every_query/sample_codes/sample_train_codes.py, src/every_query/sample_codes/sample_eval_codes.py): New scripts generate balanced train/OOD code samples with stable SHA-1 hashing and output YAML files with training set IDs and OOD codes for each configuration.
  • Evaluation Data Preparation (src/every_query/eval_suite/gen_index_times.py, src/every_query/eval_suite/gen_task.py, src/every_query/eval_suite/conf/gen_index_times_config.yaml, src/every_query/eval_suite/conf/gen_tasks_config.yaml): New workflows: gen_index_times.py samples prediction times per subject with deterministic hashing; gen_task.py joins index data with per-code features and generates per-code task Parquets.
  • Model & Output Enhancements (src/every_query/model.py): Adds an mlp_dropout parameter (default 0.1) to EveryQueryModel; introduces a logits_to_probs static method and occurs_probs/censor_probs properties on EveryQueryOutput for probability computation.
  • Lightning Module & Data Handling (src/every_query/lightning_module.py, src/every_query/dataset.py, src/every_query/lit_datamodule.py): predict_step now returns dict[str, torch.Tensor] instead of EveryQueryOutput; EveryQueryBatch gains subject_id and prediction_time fields for held-out split tracking; a new Datamodule class delegates predict_dataloader to test_dataloader.
  • Training Orchestration (src/every_query/train.py): Adds OmegaConf resolvers (list_len, int_prod); collate_tasks returns the write_dir path; a new only_preprocess flag enables early exit after preprocessing; held_out_split removed from the collation loop.
  • Task Code Management (src/every_query/tasks.py): Updates the query-code reading source from read_dir to read_codes_dir; adjusts output path formatting; fixes the skip message to use an f-string.
  • ACES Integration (src/every_query/aces_to_eq/aces_to_eq.py): New module joins EQ and ACES Parquet shards and generates per-code task dataframes with subject_id, prediction_time, task_label, occurs, and censoring information.
  • Per-Code Evaluation (src/every_query/eval.py, src/every_query/eval_suite/README.md): New eval.py orchestrates per-code model testing, collects occurs_auc and censor_auc metrics, and writes meds_death_from_discharge.csv; the README documents the evaluation workflow.
  • Composite Prediction Aggregation (src/every_query/eval_composite.py, src/every_query/process_composite/process_composite.py, src/every_query/process_composite/get_per_code_from_composite.py, src/every_query/process_composite/process_composite_config.yaml): New workflows for aggregating predictions across codes (max/or/sum aggregation), computing ROC AUC metrics, and extracting per-code performance from composite predictions.
  • Utility Functions (src/every_query/utils/codes.py): New utilities: code_slug generates deterministic, sanitized, hash-prefixed codes; values_as_list extracts kwargs values as a list.
  • Project Maintenance (.gitignore, env.yaml, README.md, src/every_query/README.md): Deleted env.yaml and src/every_query/README.md; added gitignore entries for eval_codes/, query_codes/, train_codes/; minor README formatting fixes.
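The max/or/sum aggregation named for the composite-prediction workflow can be sketched as below. This is a minimal illustration under stated assumptions, not the repository's actual process_composite.py code: the function name and the per-code list layout are invented for the example.

```python
from math import prod

def aggregate_probs(per_code_probs: list[list[float]], method: str = "or") -> list[float]:
    """Collapse per-code probabilities (one inner list per code, aligned by sample)
    into a single composite score per sample."""
    columns = zip(*per_code_probs)  # iterate sample-wise across codes
    if method == "max":
        # score of the single most confident code
        return [max(col) for col in columns]
    if method == "or":
        # P(at least one code occurs), assuming independence across codes
        return [1.0 - prod(1.0 - p for p in col) for col in columns]
    if method == "sum":
        # unnormalized total evidence; usable for ranking/ROC AUC, not a calibrated probability
        return [sum(col) for col in columns]
    raise ValueError(f"unknown aggregation method: {method}")
```

The aggregated scores could then be compared against composite labels with, e.g., sklearn.metrics.roc_auc_score, as the summary describes.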

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant EvalScript as eval.py
    participant Model as LightningModule
    participant Trainer
    participant DataModule
    participant Metrics as Results

    User->>EvalScript: Execute with config
    EvalScript->>EvalScript: Load model from checkpoint
    EvalScript->>EvalScript: Iterate over code list
    
    loop For each code
        EvalScript->>EvalScript: Derive code_slug
        EvalScript->>DataModule: Configure with per-code<br/>task_labels_dir
        EvalScript->>DataModule: Instantiate
        EvalScript->>Trainer: Run test(model, datamodule)
        Trainer->>Model: Execute forward pass
        Model->>Trainer: Return predictions
        Trainer->>Metrics: Extract occurs_auc,<br/>censor_auc
        Metrics->>EvalScript: Return per-code metrics
        EvalScript->>EvalScript: Append row to results
    end
    
    EvalScript->>Metrics: Aggregate all codes
    Metrics->>Metrics: Write to CSV
    EvalScript->>User: Output: meds_death_from_discharge.csv

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Poem

🐰 In burrows deep where queries bloom,
New codes and slugs light up the room,
Per-subject times and tasks aligned,
A pipeline's art, so well-designed,
Predictions aggregate with care—
Evaluation magic everywhere!

🚥 Pre-merge checks: 1 passed, 2 failed
❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 13.95%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
  • Title check (❓ Inconclusive): The pull request title 'Greg testing' is vague and generic and does not convey meaningful information about the changeset. Resolution: replace it with a descriptive summary of the main change, such as 'Refactor composite evaluation to use predict step'.
✅ Passed checks (1 passed)
  • Description Check (✅ Passed): Check skipped because CodeRabbit's high-level summary is enabled.


"""
Use best_model.ckpt if present; otherwise fall back to last.ckpt.
"""
best_model_ckpt = run_dir / "best_model.ckpt"
Owner:

Reading best from run_dir / "best_model.ckpt" but last from run_dir / "checkpoints" / "last.ckpt"

cfg.output_dir = str(run_dir)

# 2) Make eval reproducible: same seed & matmul precision as training
seed = cfg.get("seed", None)
Owner:

Different runs/models may be trained on different seeds. I think we want to use a stable test seed to make sure we sample the same patients and histories in the test set, rather than using the seed set during training?


# 4) Instantiate datamodule, lightning_module, trainer from saved cfg
logger.info("Instantiating datamodule...")
datamodule = instantiate(cfg.datamodule)
Owner:

Where are you setting the inference query? Would this not load the training queries into the datamodule?

sample_times_per_subject: 20 # epochs worth of data
codes:
- MEDS_DEATH
- LAB//50976//g/dL//value_[7.2,7.4)
Collaborator:

Don't add concrete code lists into the github as they are dataset specific.

@coderabbitai bot left a comment:

Actionable comments posted: 3

🧹 Nitpick comments (3)
src/every_query/eval_config.yaml (1)

5-7: Consider using environment variables for portability.

The hardcoded absolute paths contain a specific username and directory structure. For better portability across different environments and users, consider using environment variables (e.g., ${oc.env:MODEL_RUN_DIR}) similar to the pattern used in config.yaml.
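As a hedged sketch of what such an interpolation could look like (the variable name MODEL_RUN_DIR and the fallback path are hypothetical, not values from this PR):

```yaml
# Hypothetical eval_config.yaml fragment: resolve the run directory from the
# environment, with a fallback default when the variable is unset.
model_run_dir: ${oc.env:MODEL_RUN_DIR,/path/to/default/run}
```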

src/every_query/train.py (1)

193-193: Consider removing debug print statement.

The print("fitting model") statement appears to be debug code. Consider replacing it with a proper logger call or removing it entirely.

-    print("fitting model")
+    logger.info("Starting model training...")
src/every_query/config.yaml (1)

67-67: Consider making the W&B entity configurable.

The W&B entity is hardcoded to a specific user. For better portability and multi-user collaboration, consider using an environment variable or a separate config field that can be easily overridden.

-    entity: "gregkondas9-columbia-university"
+    entity: ${oc.env:WANDB_ENTITY,"gregkondas9-columbia-university"}
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7ae4205 and 1d763f9.

📒 Files selected for processing (7)
  • env.yaml (0 hunks)
  • src/every_query/config.yaml (4 hunks)
  • src/every_query/eval.py (1 hunks)
  • src/every_query/eval_config.yaml (1 hunks)
  • src/every_query/query_codes (1 hunks)
  • src/every_query/tasks.py (3 hunks)
  • src/every_query/train.py (6 hunks)
💤 Files with no reviewable changes (1)
  • env.yaml
🔇 Additional comments (17)
src/every_query/eval_config.yaml (1)

13-23: LGTM!

The derived path configurations and control flags are well-structured with clear variable interpolation and sensible defaults.

src/every_query/tasks.py (3)

74-74: LGTM!

The explicit comparison == False is appropriate for Polars column expressions and improves clarity.


120-120: Good fix!

Converting to an f-string properly formats the skip message with the filename.


126-126: LGTM!

Using read_codes_dir for query codes correctly separates the PROCESSED data source from INTERMEDIATE, aligning with the updated data source separation.

src/every_query/eval.py (5)

1-17: LGTM!

Imports are appropriate and the values_as_list helper is straightforward.


32-35: Verify seed source for evaluation reproducibility.

The code uses the training seed (train_cfg.get("seed", 42)) rather than the evaluation seed (cfg.seed). As noted in a previous comment, different training runs may use different seeds, but evaluation should use a stable test seed to ensure consistent sampling of test patients and histories across different model evaluations.

Consider whether the evaluation should use:

seed = cfg.get("seed", 42)  # from eval_config.yaml

instead of:

seed = train_cfg.get("seed", 42)  # from training config

46-57: LGTM!

Manifest loading and filtering logic is clear and correct.


80-91: LGTM!

Output handling correctly respects the do_overwrite flag and safely creates parent directories.


65-66: Verify that evaluation queries are properly configured.

As noted in a previous comment, the datamodule is instantiated from train_cfg.datamodule, which may contain training query codes rather than the per-task evaluation queries. The code updates task_labels_dir (line 65) but doesn't explicitly set which query codes should be evaluated.

Please verify:

  1. Does the datamodule automatically infer query codes from the task_labels_dir?
  2. Or should the query codes be explicitly set for each evaluation iteration?

Generate a script to check how query codes are determined:

#!/bin/bash
# Check how query codes are loaded in the datamodule and dataset classes
ast-grep --pattern $'class $CLASS {
  $$$
  def __init__($$$) {
    $$$
    query$$$
    $$$
  }
  $$$
}'
src/every_query/train.py (3)

22-39: LGTM!

The OmegaConf resolver functions are well-documented with clear examples and appropriate rounding behavior for int_prod.


84-132: LGTM with a minor observation.

The per-split handling and sample sizing logic is well-structured. The early skip for existing shards (lines 110-112) is a good optimization.

Minor note: Line 108 logs every file being collated, which may be verbose for large datasets. Consider using a progress bar or reducing log frequency if this becomes noisy.


141-143: LGTM!

The only_preprocess flag provides useful control for running only the data collation step without training.

src/every_query/config.yaml (5)

1-10: LGTM!

The defaults and control flags are well-structured. The query_codes reference pattern avoids hardcoding dataset-specific code lists in the main config.


12-16: LGTM!

The per-split sample sizing configuration is clear and the codes reference maintains flexibility.


18-30: LGTM with a performance note.

The datamodule configuration is well-structured. Note that num_workers: 1 may limit data loading throughput. Consider increasing this value if data loading becomes a bottleneck during training.


32-53: LGTM - significant learning rate change.

The updated targets and dynamic warmup/training step calculation using resolvers is well-designed. Note that the learning rate was decreased from 5e-4 to 1e-4 (5x reduction), which will affect training dynamics and convergence speed.


71-112: LGTM - comprehensive trainer configuration.

The trainer configuration is thorough with appropriate settings for monitoring, checkpointing, early stopping, and validation. The early stopping configuration with 20-step patience and tuning loss monitoring should help prevent overfitting.

check_on_train_epoch_end: false # set to false so that early stopping is checked after each validation loop
default_root_dir: ${output_dir}
min_epochs: 1 # prevents early stopping
min_epochs: 0 # prevents early stopping

⚠️ Potential issue | 🟡 Minor

Correct the misleading comment.

The comment states "prevents early stopping," but min_epochs: 0 actually allows training to stop immediately if early stopping conditions are met. This contradicts the intent suggested by the comment.

If the intent is to prevent very early stopping, consider:

-  min_epochs: 0 # prevents early stopping
+  min_epochs: 0 # allows early stopping from the first epoch

Or set to a higher value if you want a minimum training duration:

-  min_epochs: 0 # prevents early stopping
+  min_epochs: 10 # ensures at least 10 epochs before early stopping

@@ -1,13 +1,19 @@
defaults:
- _self_
- query_codes: "2_10000_ID"
Owner:

What does the notation mean here? Just curious

min_epochs: 1 # prevents early stopping
min_epochs: 0 # prevents early stopping
max_epochs: null # We don't control the max epochs, we control the max steps.
max_steps: 1000000
Owner:

probably want to bring this down to 100k?

logging.basicConfig(level=logging.INFO)


def values_as_list(**kwargs) -> list[Any]:
Owner:

move to utils?

train_cfg.trainer.logger = ""

seed = train_cfg.get("seed", 42)
if seed is not None:
Owner:

if we are setting a default seed 42, how would seed ever be None

.select(["bucket", "code", "code_slug"])
.unique()
.sort(["bucket", "code"])
)
Owner:

i assume this is reading a list of codes that we want to evaluate on, is this right? what is a bucket or code_slug?

code = row["code"]
code_slug = row["code_slug"]

task_labels_dir = str(task_set_dir / bucket / code_slug)
Owner:

is setting code_slug somehow equivalent to specifying a eval code?

# inputs
model_run_dir: /users/gbk2114/EveryQuery/results/outputs/2025-12-10/17-58-26
task_parent_dir: /users/gbk2114/eq_stuff/tasks/manual
task_set_name: cde742daabc4ddc5e79f29102ea14107
Owner:

what does hash correspond to?

model_run_dir: /users/gbk2114/EveryQuery/results/outputs/2025-12-10/17-58-26
task_parent_dir: /users/gbk2114/eq_stuff/tasks/manual
task_set_name: cde742daabc4ddc5e79f29102ea14107
eval_code_idx: 0
Owner:

what is this indexing?

7
"""
return round(x * y)

Owner:

move to utils?

@coderabbitai bot left a comment:

Actionable comments posted: 5

♻️ Duplicate comments (4)
src/every_query/train.py (1)

22-40: Consider moving resolver utilities to a shared utils module.

As noted in previous reviews, these OmegaConf resolver utilities might be better placed in a shared utils module for reusability across the codebase.
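A sketch of what such a shared module could expose, assuming only what the thread states (resolver names list_len and int_prod, with int_prod rounding its product); the module path and exact signatures are assumptions:

```python
# Hypothetical shared module, e.g. src/every_query/utils/resolvers.py.

def list_len(xs) -> int:
    """Resolver body for ${list_len:...}: number of items in a list."""
    return len(xs)

def int_prod(x: float, y: float) -> int:
    """Resolver body for ${int_prod:...}: product rounded to the nearest int."""
    return round(x * y)

# Registration would then happen once at startup (requires omegaconf):
# from omegaconf import OmegaConf
# OmegaConf.register_new_resolver("list_len", list_len)
# OmegaConf.register_new_resolver("int_prod", int_prod)
```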

src/every_query/eval.py (2)

16-17: Consolidate duplicate utility function.

This function is identical to the one in train.py (lines 42-43). As noted in previous reviews, consider moving it to a shared utils module to eliminate duplication.


32-35: Remove unreachable condition check.

The seed check at line 33 is unreachable because train_cfg.get("seed", 42) returns 42 as the default, so seed will never be None. This was flagged in previous reviews.

🔎 Suggested fix
-    seed = train_cfg.get("seed", 42)
-    if seed is not None:
-        logger.info(f"Seeding with seed={seed}")
-        seed_everything(seed, workers=True)
+    seed = train_cfg.get("seed", 42)
+    logger.info(f"Seeding with seed={seed}")
+    seed_everything(seed, workers=True)
src/every_query/config.yaml (1)

97-97: Correct the misleading comment.

The comment states "prevents early stopping," but min_epochs: 0 actually allows training to stop immediately if early stopping conditions are met. This was flagged in previous reviews.

🔎 Suggested fix
-  min_epochs: 0 # prevents early stopping
+  min_epochs: 0 # allows early stopping from the first epoch

Or if you want to prevent very early stopping:

-  min_epochs: 0 # prevents early stopping
+  min_epochs: 1 # ensures at least one complete epoch before early stopping
🧹 Nitpick comments (9)
src/every_query/sample_codes/sample_eq_train_eval_codes.py (1)

9-9: Consider making the path configurable.

The hardcoded absolute path makes this script less portable and tied to a specific user environment. Consider using a command-line argument, environment variable, or config file to make the path configurable.

🔎 Example refactor using environment variable
-PARQUET_PATH = "/users/gbk2114/data/MIMIC_MEDS/MEDS_cohort/processed/metadata/codes.parquet"
+PARQUET_PATH = os.getenv("CODES_PARQUET_PATH", "/users/gbk2114/data/MIMIC_MEDS/MEDS_cohort/processed/metadata/codes.parquet")

Or use argparse for a command-line option.

src/every_query/train.py (1)

182-182: Use logger instead of print for consistency.

The codebase uses the logging module throughout. Replace this print statement with a logger call for consistency.

🔎 Suggested change
-    print("fitting model")
+    logger.info("Fitting model")
src/every_query/config.yaml (1)

66-66: Consider parameterizing the wandb entity name.

The hardcoded entity name ties this config to a specific user. Consider using an environment variable or OmegaConf resolver to make it configurable across different users/teams.

🔎 Example using environment variable
-    entity: "gregkondas9-columbia-university"
+    entity: ${oc.env:WANDB_ENTITY,"gregkondas9-columbia-university"}
src/every_query/eval_suite/conf/gen_tasks_config.yaml (1)

7-9: Document or parameterize user-specific paths.

The hardcoded absolute paths are tied to a specific user environment (/users/gbk2114/), making this config non-portable. Consider either:

  1. Documenting that this is a sample config to be copied and customized per user, or
  2. Using environment variables for the base paths
🔎 Example using environment variables
paths:
  index_times_dir: ${oc.env:EQ_TASKS_ROOT}/eval/index_times/${index_hash}/held_out
  all_tasks_dir: ${oc.env:EQ_TASKS_ROOT}/all/held_out
  out_root_dir: ${oc.env:EQ_TASKS_ROOT}/eval/tasks_from_index
src/every_query/eval_suite/conf/gen_index_times_config.yaml (1)

6-9: Document or parameterize user-specific paths.

The hardcoded absolute paths (/users/gbk2114/) limit portability. Consider documenting this as a sample config or using environment variables for base paths.

🔎 Example using environment variables
io:
  task_dir: ${oc.env:EQ_TASKS_ROOT,/users/gbk2114/eq_stuff/tasks}
  out_root: ${oc.env:EQ_EVAL_ROOT,/users/gbk2114/eq_stuff/tasks/eval/}
src/every_query/eval_suite/gen_index_times.py (2)

11-12: Consolidate duplicate list_parquets implementation.

This function is duplicated in src/every_query/eval_suite/gen_task.py (lines 16-17) with a slight variation. Consider moving it to a shared utils module to eliminate duplication.

🔎 Create a shared utility

Create src/every_query/eval_suite/utils.py:

from pathlib import Path

def list_parquets(d: Path) -> list[Path]:
    """Return sorted list of Parquet files in directory."""
    return sorted(p for p in d.iterdir() if p.is_file() and p.suffix == ".parquet")

Then import in both files:

from every_query.eval_suite.utils import list_parquets

35-48: Remove redundant mkdir call.

write_root.mkdir(...) is called twice (lines 35 and 46). The second call is redundant since the directory is already created on line 35.

🔎 Remove the duplicate
     write_root = out_root / "index_times" / time_hash
     write_root.mkdir(parents=True, exist_ok=True)
     
     ...
     
     in_dir = read_dir / held_out_split
-    write_root.mkdir(parents=True, exist_ok=True)
     write_shards = write_root / split
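On the "deterministic hashing" gen_index_times uses to sample prediction times per subject: one way to get a stable per-subject random stream is to seed an RNG from a hash of the subject id. This is only a sketch of the idea; the function name, salt, and seeding scheme are assumptions, not the repo's actual code.

```python
import hashlib
import random

def subject_rng(subject_id: int, salt: str = "index_times") -> random.Random:
    """Deterministic per-subject RNG: the same subject_id and salt
    always produce the same random stream across runs."""
    digest = hashlib.md5(f"{salt}:{subject_id}".encode()).hexdigest()
    return random.Random(int(digest, 16))

# Example: 20 reproducible sample offsets for one subject.
rng = subject_rng(42)
offsets = [rng.random() for _ in range(20)]
```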
src/every_query/eval_suite/gen_task.py (2)

16-17: Consolidate duplicate list_parquets implementation.

This function is duplicated in src/every_query/eval_suite/gen_index_times.py (lines 11-12). Consider moving it to a shared utils module to eliminate duplication.


37-37: Add explicit strict parameter to zip().

For safety, explicitly specify strict=True to ensure both iterables have the same length. While you have an assertion on line 35 that checks filename equality, using strict=True provides an additional runtime safeguard.

🔎 Suggested change
-    for shard_idx, (idx_fp, all_fp) in enumerate(zip(index_shards, all_shards)):
+    for shard_idx, (idx_fp, all_fp) in enumerate(zip(index_shards, all_shards, strict=True)):

Note: strict=True requires Python 3.10+. If targeting earlier versions, the current assertion on line 35 is sufficient.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1d763f9 and aad00f5.

📒 Files selected for processing (10)
  • src/every_query/config.yaml (4 hunks)
  • src/every_query/eval.py (1 hunks)
  • src/every_query/eval_config.yaml (1 hunks)
  • src/every_query/eval_suite/README.md (1 hunks)
  • src/every_query/eval_suite/conf/gen_index_times_config.yaml (1 hunks)
  • src/every_query/eval_suite/conf/gen_tasks_config.yaml (1 hunks)
  • src/every_query/eval_suite/gen_index_times.py (1 hunks)
  • src/every_query/eval_suite/gen_task.py (1 hunks)
  • src/every_query/sample_codes/sample_eq_train_eval_codes.py (1 hunks)
  • src/every_query/train.py (5 hunks)
✅ Files skipped from review due to trivial changes (1)
  • src/every_query/eval_suite/README.md
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/every_query/eval_config.yaml
🧰 Additional context used
🧬 Code graph analysis (3)
src/every_query/eval_suite/gen_index_times.py (1)
src/every_query/eval_suite/gen_task.py (2)
  • list_parquets (16-17)
  • main (21-74)
src/every_query/eval_suite/gen_task.py (1)
src/every_query/eval_suite/gen_index_times.py (2)
  • list_parquets (11-12)
  • main (24-72)
src/every_query/eval.py (1)
src/every_query/train.py (1)
  • values_as_list (42-43)
🪛 Ruff (0.14.8)
src/every_query/sample_codes/sample_eq_train_eval_codes.py

22-22: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)


23-23: Standard pseudo-random generators are not suitable for cryptographic purposes

(S311)

src/every_query/eval_suite/gen_index_times.py

19-19: Probable use of insecure hash functions in hashlib: md5

(S324)


31-31: Local variable method is assigned to but never used

Remove assignment to unused variable method

(F841)

src/every_query/eval_suite/gen_task.py

11-11: Probable use of insecure hash functions in hashlib: sha1

(S324)


37-37: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

src/every_query/eval.py

24-24: Avoid specifying long messages outside the exception class

(TRY003)


48-48: Avoid specifying long messages outside the exception class

(TRY003)


52-52: Avoid specifying long messages outside the exception class

(TRY003)


57-57: Undefined name code_slug

(F821)


64-64: Comment contains ambiguous (RIGHT SINGLE QUOTATION MARK). Did you mean ``` (GRAVE ACCENT)?

(RUF003)


76-76: Undefined name split

(F821)


76-76: Undefined name split

(F821)


77-77: Undefined name split

(F821)


77-77: Undefined name split

(F821)

src/every_query/train.py

87-87: Undefined name train_val_n

(F821)


87-87: Undefined name held_out_n

(F821)


88-88: Probable use of insecure hash functions in hashlib: md5

(S324)

@coderabbitai bot left a comment:

Actionable comments posted: 3

♻️ Duplicate comments (4)
src/every_query/eval.py (2)

22-25: Move code_slug to shared utils module.

This function is duplicated from src/every_query/eval_suite/gen_task.py (lines 10-13). Since it's used in multiple files across the evaluation suite, it should be moved to a shared utils module for reuse.

🔎 Proposed refactor

Create a shared utils module (e.g., src/every_query/utils/code_utils.py):

import hashlib
import re

def code_slug(code: str, n_hash: int = 10, prefix_len: int = 24) -> str:
    h = hashlib.sha1(code.encode("utf-8")).hexdigest()[:n_hash]
    prefix = re.sub(r"[^A-Za-z0-9._-]+", "_", code).strip("_")[:prefix_len]
    return f"{prefix}__{h}" if prefix else h

Then import it in both files:

from every_query.utils.code_utils import code_slug

85-86: Critical: Undefined variable 'split'.

The variable split is used in the f-string metric keys but is never defined. Based on the trainer.test() call on Line 77, this should likely be "test".

🔎 Proposed fix
+    split = "test"
     rows.append(
         {
             "code": code,
             "code_slug": slug,
             "bucket": str(cfg.bucket) if cfg.get("bucket") is not None else None,
             "occurs_auc": float(m.get(f"{split}/occurs_auc")) if m.get(f"{split}/occurs_auc") is not None else None,
             "censor_auc": float(m.get(f"{split}/censor_auc")) if m.get(f"{split}/censor_auc") is not None else None,
         }
     )
src/every_query/eval_suite/gen_task.py (1)

10-13: Move code_slug to shared utils module.

This function is duplicated in src/every_query/eval.py (lines 22-25). As noted in the review of that file, consider moving this to a shared utils module.

src/every_query/config.yaml (1)

98-98: Correct the misleading comment.

The comment states "prevents early stopping," but min_epochs: 0 actually allows training to stop immediately if early stopping conditions are met (including at epoch 0). This contradicts the comment.

🔎 Proposed fix
-  min_epochs: 0 # prevents early stopping
+  min_epochs: 0 # allows immediate early stopping

Or, if you want to prevent very early stopping:

-  min_epochs: 0 # prevents early stopping
+  min_epochs: 1 # ensures at least 1 full epoch before early stopping
🧹 Nitpick comments (7)
src/every_query/sample_codes/sample_eval_codes.py (1)

21-21: Consider discovering training sets dynamically.

The hardcoded list of training sets reduces flexibility. Consider discovering training set files from the ../train_codes/ directory using glob patterns, which would automatically accommodate new training sets without code changes.

🔎 Proposed refactor
-training_sets = ["10000_ID__8db2be6fadf8","10000_ID__fef969fa50be","10000_ID__e267abdbc547"]
+# Discover all training set YAML files
+training_set_files = glob("../train_codes/10000_ID__*.yaml")
+training_sets = [Path(f).stem for f in training_set_files]
src/every_query/train.py (1)

182-182: Consider using logger instead of print statement.

The debug print statement would be better as a logger call for consistency with the rest of the codebase.

🔎 Proposed refactor
-    print("fitting model")
+    logger.info("Starting model training.")
src/every_query/eval.py (2)

18-19: Move duplicated utility to shared module.

This function is duplicated from src/every_query/train.py. Consider moving it to a shared utils module to eliminate duplication.


41-44: Simplify seed handling.

Since seed is retrieved with a default value of 42, the None check on Line 42 will always pass. Consider simplifying this logic.

🔎 Proposed simplification
-    seed = train_cfg.get("seed", 42)
-    if seed is not None:
-        logger.info(f"Seeding with seed={seed}")
-        seed_everything(seed, workers=True)
+    seed = train_cfg.get("seed", 42)
+    logger.info(f"Seeding with seed={seed}")
+    seed_everything(seed, workers=True)
src/every_query/eval_suite/gen_task.py (3)

16-17: Move list_parquets to shared module.

This utility function is also defined in src/every_query/eval_suite/gen_index_times.py. Consider moving it to a shared utils module to eliminate duplication across the evaluation suite.


38-38: Provide descriptive assertion message.

The assertion validates shard alignment but lacks a descriptive message. This makes debugging harder if the assertion fails.

🔎 Proposed improvement
-    assert [p.name for p in index_shards] == [p.name for p in all_shards]
+    assert [p.name for p in index_shards] == [p.name for p in all_shards], (
+        f"Shard filename mismatch between index_dir and all_dir. "
+        f"Index: {[p.name for p in index_shards]}, All: {[p.name for p in all_shards]}"
+    )

40-40: Use strict=True in zip() for clarity and safety.

The zip() call lacks an explicit strict parameter. While Line 38 validates that index_shards and all_shards have the same length via assertion, using strict=True (available in Python 3.10+, which your project exceeds with its Python 3.12+ requirement) would make this safety check more explicit and eliminate the need for the separate assertion.

🔎 Proposed improvement
-    assert [p.name for p in index_shards] == [p.name for p in all_shards]
-
-    for shard_idx, (idx_fp, all_fp) in tqdm(enumerate(zip(index_shards, all_shards)),total=len(index_shards)):
+    for shard_idx, (idx_fp, all_fp) in tqdm(enumerate(zip(index_shards, all_shards, strict=True)),total=len(index_shards)):

This would raise a ValueError automatically if the lists have different lengths, providing a clearer error message than the assertion.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between aad00f5 and dab6a3b.

📒 Files selected for processing (15)
  • .gitignore
  • src/every_query/config.yaml
  • src/every_query/eval.py
  • src/every_query/eval_config.yaml
  • src/every_query/eval_suite/conf/gen_tasks_config.yaml
  • src/every_query/eval_suite/gen_task.py
  • src/every_query/model.py
  • src/every_query/sample_codes/sample_eval_codes.py
  • src/every_query/sample_codes/sample_train_codes.py
  • src/every_query/train.py
  • src/every_query/train_codes/10000_ID__03b39cc59fb1.yaml
  • src/every_query/train_codes/10000_ID__8db2be6fadf8.yaml
  • src/every_query/train_codes/10000_ID__cee85062a596.yaml
  • src/every_query/train_codes/10000_ID__e267abdbc547.yaml
  • src/every_query/train_codes/10000_ID__fef969fa50be.yaml
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/every_query/eval_suite/conf/gen_tasks_config.yaml
  • src/every_query/eval_config.yaml
🧰 Additional context used
🧬 Code graph analysis (2)
src/every_query/eval.py (2)
src/every_query/train.py (2)
  • values_as_list (42-43)
  • main (125-197)
src/every_query/eval_suite/gen_task.py (2)
  • code_slug (10-13)
  • main (21-76)
src/every_query/eval_suite/gen_task.py (1)
src/every_query/eval_suite/gen_index_times.py (2)
  • list_parquets (11-12)
  • main (24-72)
🪛 GitHub Actions: Tests
src/every_query/sample_codes/sample_eval_codes.py

[error] 18-18: FileNotFoundError: No such file or directory (os error 2): /users/gbk2114/data/MIMIC_MEDS/MEDS_cohort/processed/metadata/codes.parquet

src/every_query/sample_codes/sample_train_codes.py

[error] 25-25: IndentationError: unindent does not match any outer indentation level

🪛 Ruff (0.14.10)
src/every_query/eval.py

23-23: Probable use of insecure hash functions in hashlib: sha1

(S324)


33-33: Avoid specifying long messages outside the exception class

(TRY003)


57-57: Avoid specifying long messages outside the exception class

(TRY003)


61-61: Avoid specifying long messages outside the exception class

(TRY003)


73-73: Comment contains ambiguous ' (RIGHT SINGLE QUOTATION MARK). Did you mean ` (GRAVE ACCENT)?

(RUF003)


85-85: Undefined name split

(F821)


85-85: Undefined name split

(F821)


86-86: Undefined name split

(F821)


86-86: Undefined name split

(F821)

src/every_query/train.py

88-88: Probable use of insecure hash functions in hashlib: md5

(S324)

src/every_query/sample_codes/sample_train_codes.py

25-25: unindent does not match any outer indentation level

(invalid-syntax)

src/every_query/eval_suite/gen_task.py

11-11: Probable use of insecure hash functions in hashlib: sha1

(S324)


40-40: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: code-quality
🔇 Additional comments (13)
.gitignore (1)

179-180: LGTM! Appropriate ignore rule for generated evaluation configurations.

The new ignore rule correctly excludes the generated eval_codes directory, which aligns with the script in sample_eval_codes.py that writes YAML files to this location.

src/every_query/model.py (2)

211-215: LGTM! Clean addition of mlp_dropout parameter.

The new mlp_dropout parameter is properly typed, has a sensible default value (0.1), and follows the existing parameter pattern.


242-242: LGTM! Complete and consistent mlp_dropout integration.

The dropout parameter is properly:

  1. Stored in the HF model config (line 242)
  2. Referenced by both MLP instances from the config (lines 252, 255)
  3. Exposed in hparams for tracking (line 268)

This ensures consistency across the model and proper hyperparameter logging.

Also applies to: 252-256, 268-268

src/every_query/train.py (3)

22-40: LGTM! Well-implemented OmegaConf resolvers.

Both resolver functions are correctly implemented:

  • list_len provides a clean wrapper for length operations in configs
  • int_prod properly rounds the product and includes helpful docstring examples

The @OmegaConfResolver decorator properly registers these for use in configuration files.
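For context, a plain-Python sketch of what these two resolvers compute; the bodies here are inferred from the names and the docstring example mentioned above, not copied from `train.py`:

```python
import math

def list_len(xs) -> int:
    """Length of a config list."""
    return len(xs)

def int_prod(*xs) -> int:
    """Product of the arguments rounded to the nearest int,
    e.g. int_prod(0.05, 40000) -> 2000 for a 5% warmup of max_steps."""
    return round(math.prod(xs))

warmup_steps = int_prod(0.05, 40_000)
print(list_len([1, 2, 3]), warmup_steps)
```

Rounding matters here: `0.05 * 40000` is not exact in floating point, so `round` is what keeps the resolved warmup an integer step count.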


84-121: LGTM! Collate function correctly refactored.

The function now properly:

  1. Returns the collated task directory path (addressing the signature change)
  2. Generates task hash from only the sorted codes (the undefined train_val_n and held_out_n variables from the previous review have been correctly removed)
  3. Processes only train and tuning splits with appropriate logging

Note: MD5 usage here is for directory naming/identification, not cryptographic security, so the static analysis warning can be safely ignored.


130-133: LGTM! Useful preprocessing-only mode.

The early exit when cfg.only_preprocess is true allows running the collation step independently, which is valuable for testing, debugging, and pipeline workflows.

src/every_query/eval.py (1)

100-102: LGTM!

The overwrite protection is well-implemented with clear logging.

src/every_query/eval_suite/gen_task.py (1)

57-76: LGTM!

The DataFrame construction logic is correct. The left join, column renaming, and null filtering appropriately handle per-code task generation.

src/every_query/config.yaml (5)

42-42: LGTM!

The learning rate reduction to 2e-5 aligns with common transformer training practices.


51-52: LGTM!

Making warmup proportional to training steps (5%) is a good improvement over the fixed 100 steps.


85-96: LGTM!

The early stopping configuration is well-designed with appropriate thresholds and monitoring strategy.


100-100: LGTM!

The max_steps reduction from 1,000,000 to 40,000 is noted. This significantly shortens training duration; ensure this aligns with your experimental objectives.


38-38: LGTM!

The MLP dropout rate of 0.1 is reasonable. EveryQueryModel accepts this parameter with a matching default value and uses it for dropout in the MLP layer construction.


# goal is to sample 20 random codes from each ID and OOD pair to get a
# total of 40 codes to calculate auroc for
PARQUET_PATH = "/users/gbk2114/data/MIMIC_MEDS/MEDS_cohort/processed/metadata/codes.parquet"

⚠️ Potential issue | 🟠 Major

Hardcoded absolute path causes portability issues and pipeline failure.

The hardcoded path /users/gbk2114/data/... fails on other systems, as evidenced by the pipeline failure. Consider using a configurable path via environment variable, command-line argument, or configuration file.

🔎 Suggested approach
-PARQUET_PATH = "/users/gbk2114/data/MIMIC_MEDS/MEDS_cohort/processed/metadata/codes.parquet"
+PARQUET_PATH = os.environ.get(
+    "CODES_PARQUET_PATH",
+    "/users/gbk2114/data/MIMIC_MEDS/MEDS_cohort/processed/metadata/codes.parquet"
+)
🤖 Prompt for AI Agents
In src/every_query/sample_codes/sample_eval_codes.py around line 10, the
PARQUET_PATH is a hardcoded absolute path which breaks portability; change it to
read the path from a configurable source (preferably an environment variable
like PARQUET_PATH with a sensible relative-path default or a CLI/config file
override), validate that the resolved path exists and raise a clear error if
not, and update any docs or tests to rely on the new env/arg/config mechanism.
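A minimal stdlib sketch of the env-var approach with the fail-fast validation the prompt asks for; `CODES_PARQUET_PATH` comes from the suggested diff, while the relative default path is illustrative:

```python
import os
import tempfile
from pathlib import Path

def resolve_parquet_path(env_var: str = "CODES_PARQUET_PATH",
                         default: str = "data/metadata/codes.parquet") -> Path:
    """Resolve the codes.parquet location from the environment with a relative
    default, failing fast with a clear error when the file is absent."""
    path = Path(os.environ.get(env_var, default))
    if not path.is_file():
        raise FileNotFoundError(f"codes parquet not found at {path}; set {env_var} to override")
    return path

# Demo: point the env var at a real (empty) temp file so resolution succeeds.
with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as f:
    os.environ["CODES_PARQUET_PATH"] = f.name
resolved = resolve_parquet_path()
print(resolved)
```

The explicit existence check turns the opaque `FileNotFoundError` seen in CI into a message that tells the next developer exactly which variable to set.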

# -------------------
# Config
# -------------------
PARQUET_PATH = "/users/gbk2114/data/MIMIC_MEDS/MEDS_cohort/processed/metadata/codes.parquet"

⚠️ Potential issue | 🟠 Major

Hardcoded absolute path reduces portability.

The hardcoded path /users/gbk2114/data/... will fail on other systems or deployments. Consider making this configurable via command-line argument, environment variable, or a configuration file.

🔎 Suggested approach

Option 1 - Environment variable:

-PARQUET_PATH = "/users/gbk2114/data/MIMIC_MEDS/MEDS_cohort/processed/metadata/codes.parquet"
+PARQUET_PATH = os.environ.get(
+    "CODES_PARQUET_PATH",
+    "/users/gbk2114/data/MIMIC_MEDS/MEDS_cohort/processed/metadata/codes.parquet"
+)

Option 2 - Command-line argument (requires argparse setup):

import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--parquet-path", required=True)
args = parser.parse_args()
PARQUET_PATH = args.parquet_path
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
PARQUET_PATH = "/users/gbk2114/data/MIMIC_MEDS/MEDS_cohort/processed/metadata/codes.parquet"
PARQUET_PATH = os.environ.get(
"CODES_PARQUET_PATH",
"/users/gbk2114/data/MIMIC_MEDS/MEDS_cohort/processed/metadata/codes.parquet"
)
🤖 Prompt for AI Agents
In src/every_query/sample_codes/sample_train_codes.py around line 9, the
PARQUET_PATH is a hardcoded absolute path which breaks portability; change it to
be configurable by reading from an environment variable (e.g., read PARQUET_PATH
from os.environ with a sensible default) or by accepting a command-line argument
(add argparse to parse --parquet-path and assign it to PARQUET_PATH), and update
any usage or docstring to reflect the new configuration method.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 13

🤖 Fix all issues with AI agents
In @src/every_query/aces_to_eq/aces_to_eq.py:
- Line 1: Remove the trailing whitespace at the end of the "import logging" line
in aces_to_eq.py so the import statement has no trailing spaces; open the top of
the module where the "import logging" statement exists and delete any trailing
space or tabs so the pre-commit hook no longer flags the line.
- Around line 19-22: Extract the duplicated function code_slug into a single
shared utility module (create every_query.utils with the code_slug function
signature and body exactly as in the diff), then remove the local definitions of
code_slug in this file and the other two duplicates and replace them with from
every_query.utils import code_slug; ensure the function name and parameters
(code_slug(code: str, n_hash: int = 10, prefix_len: int = 24)) are unchanged so
callers in this module (aces_to_eq.py) and in process_composite.py and
eval_composite.py continue to work; you can ignore the static analysis SHA1
warning since it is used only for slug generation.
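A sketch of what the shared `every_query.utils` module could contain. The signature is the one quoted above; the body is a guess at the usual sanitized-prefix-plus-short-digest pattern (SHA-1 here is for stable naming, not security):

```python
import hashlib
import re

def code_slug(code: str, n_hash: int = 10, prefix_len: int = 24) -> str:
    """Filesystem-safe slug for a code string: a sanitized, truncated prefix
    for readability plus a short SHA-1 digest for uniqueness.

    NOTE: illustrative body only -- the real implementation lives in the
    duplicated helpers this review asks to consolidate.
    """
    prefix = re.sub(r"[^A-Za-z0-9]+", "_", code)[:prefix_len]
    digest = hashlib.sha1(code.encode("utf-8")).hexdigest()[:n_hash]
    return f"{prefix}__{digest}"

print(code_slug("LAB//GLUCOSE//mg/dL"))
```

Whatever the real body is, keeping it in one module guarantees all three call sites produce identical slugs, which is what the on-disk task directories depend on.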

In @src/every_query/aces_to_eq/config.yaml:
- Around line 9-11: The three hardcoded config keys eq_tasks_all_dir,
aces_shards_dir, and output_dir should be made configurable: replace the
absolute paths with environment-variable backed values (e.g., use
${oc.env:EQ_DATA_ROOT:-./data}/tasks/all/held_out for eq_tasks_all_dir,
${oc.env:EIC_CONFIG_ROOT:-./configs}/make_index_dfs/task_configs/${task_name}/held_out
for aces_shards_dir, and
${oc.env:EQ_OUTPUT_ROOT:-./output}/tasks/eval/aces/${task_name}/held_out for
output_dir) or use relative paths from the project root and document an optional
git-ignored local override file; ensure ${task_name} placeholder remains intact
and provide sensible defaults so the config works for other developers and CI
without user-specific absolute paths.

In @src/every_query/eval_composite.py:
- Around line 22-25: The function code_slug is duplicated in eval_composite.py
and process_composite.py; extract it into a shared utility module (e.g.,
every_query.utils) and replace the local definitions with a single import from
that module; ensure the new utils module exports code_slug with the same
signature (code: str, n_hash: int = 10, prefix_len: int = 24) so both
eval_composite.py and process_composite.py simply do from every_query.utils
import code_slug and remove the duplicate function bodies.
- Around line 91-102: Before calling pl.concat(rows, how="vertical"), check
whether rows is empty; if it is, avoid concatenation and either log and return
(using logger) or create and write an empty DataFrame so downstream code gets a
valid CSV. Concretely, guard the existing logic around final_df =
pl.concat(rows, how="vertical") with an if not rows: branch that logs something
like "no rows produced, skipping output" and returns (honoring cfg.do_overwrite
behavior), or construct an empty pl.DataFrame() and continue to the
out_dir/out_fp write path so out_fp is created with headers before calling
final_df.write_csv(out_fp). Ensure you reference the same variables: rows,
final_df, pl.concat, out_dir, out_fp, cfg.output_root, cfg.task_name,
cfg.do_overwrite, and logger.
- Line 76: The assignment "m = out[0] if out else {}" creates an unused variable
`m`; remove this unused assignment and instead directly use `out[0]` (with the
same fallback {}) at the call site or simply omit the assignment if nothing is
needed. Update the expression around `out` handling so there are no unused
variables (replace `m` usage with `out[0] if out else {}` or remove the line
entirely) to satisfy static analysis; refer to the variable `m` and the `out`
list in the existing code.
- Around line 78-86: Check and validate M.test_predictions before using it:
replace direct access to M.test_predictions with a guarded retrieval (e.g., pred
= getattr(M, "test_predictions", None)) and if pred is None raise/log a clear
error mentioning M and the test stage. Also validate that pred is a mapping
containing the keys "subject_id", "prediction_time", and "occurs_probs" (and
that each value is an iterable/Series); if any key is missing, raise/log a
descriptive error or provide a safe fallback, and only then construct the
pl.DataFrame with those validated values.
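A stdlib sketch of the guarded retrieval described above, using a dummy object in place of the Lightning module `M`:

```python
REQUIRED_KEYS = ("subject_id", "prediction_time", "occurs_probs")

def get_test_predictions(model) -> dict:
    """Fetch model.test_predictions defensively: fail with a clear message if
    the attribute is missing or does not carry the expected keys."""
    pred = getattr(model, "test_predictions", None)
    if pred is None:
        raise RuntimeError("model has no test_predictions; did trainer.test() run?")
    missing = [k for k in REQUIRED_KEYS if k not in pred]
    if missing:
        raise KeyError(f"test_predictions missing keys: {missing}")
    return pred

class DummyModel:  # stand-in for the Lightning module M
    test_predictions = {"subject_id": [1], "prediction_time": [0], "occurs_probs": [0.5]}

pred = get_test_predictions(DummyModel())
print(sorted(pred))
```

The validated dict can then be handed straight to the `pl.DataFrame` constructor with no further key checks at the call site.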

In @src/every_query/eval_suite/conf/eval_composite_config.yaml:
- Around line 6-8: The config currently hard-codes user-specific absolute paths
for model_run_dir and task_set_dir; change these keys to use
environment-variable-based roots or configurable placeholders (e.g.
model_run_dir: ${MODEL_RUN_ROOT}/outputs/${run_timestamp} and task_set_dir:
${TASK_SET_ROOT}/tasks/eval/aces/${task_name}/held_out) so task_name remains a
relative token, and update any code that loads eval_composite_config.yaml to
expand those env vars (or read a configured root) before using the paths.

In @src/every_query/eval_suite/conf/eval_config.yaml:
- Around line 6-8: Replace the hard-coded user-specific paths in the YAML keys
model_run_dir and task_set_dir (and any others using /users/gbk2114/) with
configurable, portable values: read them from environment variables (e.g.,
MODEL_RUN_DIR and TASK_SET_DIR) or construct them via a Hydra resolver or a
project-root relative path that uses index_time_hash dynamically; update the
YAML to reference the env/resolver placeholders (keeping index_time_hash as the
variable used in task_set_dir) and ensure calling code that reads
model_run_dir/task_set_dir falls back to sensible defaults if the env vars are
unset.

In @src/every_query/process_composite_config.yaml:
- Around line 1-3: The YAML contains hard-coded user-specific paths in
predictions_df_path and task_labels_df_path which break portability; change
these to use configurable roots (environment variables or a top-level config
key) or relative paths (e.g., use ${PREDICTIONS_ROOT}/... and
${TASK_LABELS_ROOT}/... or a base_path variable) and update any code that reads
this YAML to expand those variables; also verify and correct the inconsistent
directory name (eic_stuff vs eq_stuff) in task_labels_df_path to the intended
name so all configs use the same canonical directory.

In @src/every_query/process_composite.py:
- Line 9: The import "from sklearn.metrics import roc_auc_score" is failing
because scikit-learn is not declared as a project dependency; add
"scikit-learn>=1.0.0" to your dependency manifest (e.g., requirements.txt,
setup.py, or pyproject.toml/Poetry under [tool.poetry.dependencies]) and run
your install step so the import in process_composite.py succeeds; update
CI/pipeline dependency cache if necessary.
- Around line 25-35: Wrap the CSV and Parquet reads with defensive I/O checks:
before calling pl.read_csv(cfg.predictions_df_path) and
pl.read_parquet(cfg.task_labels_df_path) verify the paths exist
(cfg.predictions_df_path and cfg.task_labels_df_path), and then wrap the
pl.read_* calls that populate all_preds_df and all_aces_labels_df in try/except
to catch and log IOErrors/Exceptions (including the exception message), failing
fast with a clear error via processLogger.error or re-raising a descriptive
exception; keep subsequent logic that computes probs_df (group_by/agg on
all_preds_df) unchanged but ensure it only runs after successful reads.
- Around line 43-51: Check that the inner join result (joined_df produced from
probs_df.join(all_aces_labels_df,on=["subject_id","prediction_time"],how="inner"))
is non-empty and that joined_df["boolean_value"] contains at least two distinct
classes before calling roc_auc_score; if joined_df.empty or
joined_df["boolean_value"].nunique() < 2, short-circuit (return/raise/log a
clear message or set auc to None/np.nan) instead of calling roc_auc_score to
avoid ValueError, and optionally wrap the roc_auc_score call in a try/except to
log unexpected errors while preserving context (include references to joined_df,
probs_df, all_aces_labels_df, and roc_auc_score).
🧹 Nitpick comments (10)
src/every_query/aces_to_eq/aces_to_eq.py (3)

27-27: Remove or use the unused target_rows parameter.

The target_rows parameter is passed from the configuration but never used in the function body. Either:

  1. Remove it if it's not needed, or
  2. Implement row limiting logic if it was intended but not yet implemented

Use iterable unpacking for list concatenation.

The current concatenation style works, but unpacking with * is more idiomatic in Python (this is what RUF005 flags).

♻️ Proposed refactor
-            .select(base_cols + [code])
+            .select([*base_cols, code])

88-88: Add strict=True to zip() for safer iteration.

While the assertion on line 85 validates that shard IDs match, adding strict=True makes the intent explicit and provides an additional safety check.

♻️ Proposed refactor
-    for eq_shard_fp, aces_shard_fp in tqdm(zip(all_eq_shards,all_aces_shards),total=len(all_eq_shards)):
+    for eq_shard_fp, aces_shard_fp in tqdm(zip(all_eq_shards, all_aces_shards, strict=True), total=len(all_eq_shards)):
src/every_query/eval_suite/conf/eval_config.yaml (1)

12-13: Clarify commented configuration intent.

The id_codes and ood_codes are commented out. If these are placeholders for future work, consider adding a TODO comment. If they're no longer needed, they should be removed for clarity.

src/every_query/eval_composite.py (3)

18-19: Consider removing trivial helper to reduce duplication.

values_as_list is duplicated across eval_composite.py and process_composite.py but only converts kwargs.values() to a list—a one-liner that could be used inline. Since it's not currently called anywhere in this file, consider removing it or moving it to a shared utilities module if it's genuinely needed elsewhere.


38-39: Replace logger nullification with proper configuration.

Setting train_cfg.trainer.logger = "" to "nuke the logger" is a workaround. Consider disabling the logger via proper Lightning configuration to avoid potential issues.

♻️ Cleaner approach
-    # Nuke the logger so wandb dashboard is clean
-    train_cfg.trainer.logger = ""
+    # Disable logger for evaluation runs
+    train_cfg.trainer.logger = False

63-89: Add error handling for per-code evaluation loop.

The loop lacks exception handling. If any code's evaluation fails (e.g., during instantiate, trainer.test, or DataFrame construction), the entire run aborts. Consider wrapping the loop body in a try-except to log failures and continue processing remaining codes.

♻️ Add graceful error handling
     for code in codes:
+        try:
             slug = code_slug(code)
             task_labels_dir = str(task_set_dir / slug)

             if not Path(task_labels_dir).is_dir():
                 logger.warning(f"Missing task_labels_dir for code={code}: {task_labels_dir} (skipping)")
                 continue

             # Point datamodule at this code's task dfs
             train_cfg.datamodule.config.task_labels_dir = task_labels_dir
             D = instantiate(train_cfg.datamodule)

             trainer.test(model=M, datamodule=D, ckpt_path=cfg.ckpt_path)
             
             if not hasattr(M, 'test_predictions'):
                 logger.warning(f"Model has no test_predictions for code={code} (skipping)")
                 continue

             pred = M.test_predictions

             df = pl.DataFrame({
                 "subject_id": pred["subject_id"],
                 "prediction_time": pred["prediction_time"],
                 "occurs_probs": pred["occurs_probs"],
             }).with_columns(
                 pl.lit(code).alias("code")
             )

             rows.append(df)
+        except Exception as e:
+            logger.error(f"Failed to evaluate code={code}: {e}", exc_info=True)
+            continue
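Relatedly, `rows` can still come out empty even with this error handling (every code skipped or failed). A stdlib CSV sketch of the header-only fallback suggested earlier in this review, with column names taken from the DataFrame above:

```python
import csv
import tempfile
from pathlib import Path

FIELDS = ["subject_id", "prediction_time", "occurs_probs", "code"]

def write_rows(rows: list[dict], out_fp: Path) -> None:
    """Write collected per-code rows to CSV; for an empty list, still emit a
    header-only file so downstream readers see a valid (empty) table instead
    of crashing on a missing path."""
    out_fp.parent.mkdir(parents=True, exist_ok=True)
    with out_fp.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDS)
        writer.writeheader()
        writer.writerows(rows)  # no-op when rows == []

with tempfile.TemporaryDirectory() as d:
    out_fp = Path(d) / "preds.csv"
    write_rows([], out_fp)  # the empty case flagged above
    header = out_fp.read_text().strip()
print(header)
```

The polars equivalent is to guard `pl.concat(rows)` and write an empty, correctly-schemed DataFrame in the empty branch.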
src/every_query/process_composite.py (3)

25-25: Add spacing around assignment operator.

Minor style issue: missing space after = in the assignment.

-    all_preds_df=pl.read_csv(cfg.predictions_df_path)
+    all_preds_df = pl.read_csv(cfg.predictions_df_path)

51-51: Use logger instead of print for output.

Using print for the AUC result makes it difficult to capture, filter, or redirect output in production environments. Consider using a logger or writing to a file.

♻️ Better output handling
+    import logging
+    logger = logging.getLogger(__name__)
+    
     auc = roc_auc_score(
         joined_df["boolean_value"].to_numpy(),
         joined_df["max_prob"].to_numpy(),
     )

-    print(auc)
+    logger.info(f"ROC-AUC: {auc:.4f}")
+    
+    # Optionally save to file
+    # output_path = Path(cfg.output_root) / "metrics.txt"
+    # output_path.parent.mkdir(parents=True, exist_ok=True)
+    # output_path.write_text(f"ROC-AUC: {auc:.4f}\n")

53-59: Remove excessive blank lines.

Lines 53-59 contain 7 consecutive blank lines. Python PEP 8 recommends at most 2 blank lines between top-level definitions.

     print(auc)


-
-
-
-
-
     
 if __name__ == "__main__":
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dab6a3b and c5f3c7e.

📒 Files selected for processing (8)
  • src/every_query/aces_to_eq/aces_to_eq.py
  • src/every_query/aces_to_eq/config.yaml
  • src/every_query/eval_composite.py
  • src/every_query/eval_suite/conf/eval_composite_config.yaml
  • src/every_query/eval_suite/conf/eval_config.yaml
  • src/every_query/eval_suite/process_composite_eval.py
  • src/every_query/process_composite.py
  • src/every_query/process_composite_config.yaml
🧰 Additional context used
🧬 Code graph analysis (1)
src/every_query/aces_to_eq/aces_to_eq.py (2)
src/every_query/eval_composite.py (2)
  • code_slug (22-25)
  • main (30-102)
src/every_query/process_composite.py (2)
  • code_slug (16-19)
  • main (24-51)
🪛 GitHub Actions: Code Quality PR
src/every_query/eval_suite/conf/eval_config.yaml

[warning] 1-1: yaml formatting adjusted by prettier/mdformat.

src/every_query/aces_to_eq/aces_to_eq.py

[error] 1-1: Trailing whitespace detected and fixed by pre-commit hook (trace: trailing-whitespace).

src/every_query/eval_composite.py

[error] 1-1: Trailing whitespace detected and fixed by pre-commit hook (trace: trailing-whitespace).


[error] 74-74: Local variable 'm' assigned to but never used (F841) in eval_composite.py.

src/every_query/eval_suite/conf/eval_composite_config.yaml

[error] 1-1: Trailing whitespace detected and fixed by pre-commit hook (trace: trailing-whitespace).

src/every_query/process_composite.py

[error] 1-1: Trailing whitespace detected and fixed by pre-commit hook (trace: trailing-whitespace).

🪛 GitHub Actions: Tests
src/every_query/process_composite.py

[error] 9-9: ImportError: No module named 'sklearn'. ModuleNotFoundError occurred while importing test module during collection.

🪛 Ruff (0.14.10)
src/every_query/aces_to_eq/aces_to_eq.py

20-20: Probable use of insecure hash functions in hashlib: sha1

(S324)


27-27: Unused function argument: target_rows

(ARG001)


50-50: Consider [*base_cols, code] instead of concatenation

Replace with [*base_cols, code]

(RUF005)


88-88: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

src/every_query/eval_composite.py

23-23: Probable use of insecure hash functions in hashlib: sha1

(S324)


33-33: Avoid specifying long messages outside the exception class

(TRY003)


57-57: Avoid specifying long messages outside the exception class

(TRY003)


71-71: Comment contains ambiguous ' (RIGHT SINGLE QUOTATION MARK). Did you mean ` (GRAVE ACCENT)?

(RUF003)


76-76: Local variable m is assigned to but never used

Remove assignment to unused variable m

(F841)

src/every_query/process_composite.py

17-17: Probable use of insecure hash functions in hashlib: sha1

(S324)

🔇 Additional comments (1)
src/every_query/aces_to_eq/aces_to_eq.py (1)

23-24: LGTM!

Clean helper function for extracting shard IDs from filenames.

@@ -0,0 +1,99 @@
import logging

⚠️ Potential issue | 🟡 Minor

Fix trailing whitespace.

The pre-commit hook detected trailing whitespace on this line, which caused a pipeline failure.

🧰 Tools
🪛 GitHub Actions: Code Quality PR

[error] 1-1: Trailing whitespace detected and fixed by pre-commit hook (trace: trailing-whitespace).

🤖 Prompt for AI Agents
In @src/every_query/aces_to_eq/aces_to_eq.py at line 1, Remove the trailing
whitespace at the end of the "import logging" line in aces_to_eq.py so the
import statement has no trailing spaces; open the top of the module where the
"import logging" statement exists and delete any trailing space or tabs so the
pre-commit hook no longer flags the line.

Comment on lines +9 to +11
eq_tasks_all_dir: /users/gbk2114/eq_stuff/tasks/all/held_out
aces_shards_dir: /users/gbk2114/eic_stuff/make_index_dfs/task_configs/${task_name}/held_out
output_dir: /users/gbk2114/eq_stuff/tasks/eval/aces/${task_name}/held_out

⚠️ Potential issue | 🟠 Major

Replace hardcoded absolute paths with configurable alternatives.

These absolute paths are specific to user gbk2114 and will break for other developers or deployment environments. This creates a major portability and reproducibility issue.

Consider using:

  • Relative paths from the project root
  • Environment variables (e.g., ${oc.env:EQ_DATA_ROOT})
  • A separate config override file for local paths that is git-ignored
♻️ Example using environment variables
-eq_tasks_all_dir: /users/gbk2114/eq_stuff/tasks/all/held_out
-aces_shards_dir: /users/gbk2114/eic_stuff/make_index_dfs/task_configs/${task_name}/held_out
-output_dir: /users/gbk2114/eq_stuff/tasks/eval/aces/${task_name}/held_out
+eq_tasks_all_dir: ${oc.env:EQ_DATA_ROOT}/tasks/all/held_out
+aces_shards_dir: ${oc.env:ACES_DATA_ROOT}/make_index_dfs/task_configs/${task_name}/held_out
+output_dir: ${oc.env:EQ_DATA_ROOT}/tasks/eval/aces/${task_name}/held_out
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
eq_tasks_all_dir: /users/gbk2114/eq_stuff/tasks/all/held_out
aces_shards_dir: /users/gbk2114/eic_stuff/make_index_dfs/task_configs/${task_name}/held_out
output_dir: /users/gbk2114/eq_stuff/tasks/eval/aces/${task_name}/held_out
eq_tasks_all_dir: ${oc.env:EQ_DATA_ROOT}/tasks/all/held_out
aces_shards_dir: ${oc.env:ACES_DATA_ROOT}/make_index_dfs/task_configs/${task_name}/held_out
output_dir: ${oc.env:EQ_DATA_ROOT}/tasks/eval/aces/${task_name}/held_out
🤖 Prompt for AI Agents
In @src/every_query/aces_to_eq/config.yaml around lines 9 - 11, The three
hardcoded config keys eq_tasks_all_dir, aces_shards_dir, and output_dir should
be made configurable: replace the absolute paths with environment-variable
backed values (e.g., use ${oc.env:EQ_DATA_ROOT:-./data}/tasks/all/held_out for
eq_tasks_all_dir,
${oc.env:EIC_CONFIG_ROOT:-./configs}/make_index_dfs/task_configs/${task_name}/held_out
for aces_shards_dir, and
${oc.env:EQ_OUTPUT_ROOT:-./output}/tasks/eval/aces/${task_name}/held_out for
output_dir) or use relative paths from the project root and document an optional
git-ignored local override file; ensure ${task_name} placeholder remains intact
and provide sensible defaults so the config works for other developers and CI
without user-specific absolute paths.

D = instantiate(train_cfg.datamodule)

out = trainer.test(model=M, datamodule=D, ckpt_path=cfg.ckpt_path)
m = out[0] if out else {}

⚠️ Potential issue | 🟡 Minor

Remove unused variable m.

The variable m is assigned the test results but never used. This is flagged by both pipeline checks and static analysis.

🧹 Proposed fix
-        out = trainer.test(model=M, datamodule=D, ckpt_path=cfg.ckpt_path)
-        m = out[0] if out else {}
+        trainer.test(model=M, datamodule=D, ckpt_path=cfg.ckpt_path)

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.14.10)

76-76: Local variable m is assigned to but never used

Remove assignment to unused variable m

(F841)

🤖 Prompt for AI Agents
In @src/every_query/eval_composite.py at line 76, The assignment "m = out[0] if
out else {}" creates an unused variable `m`; remove this unused assignment and
instead directly use `out[0]` (with the same fallback {}) at the call site or
simply omit the assignment if nothing is needed. Update the expression around
`out` handling so there are no unused variables (replace `m` usage with `out[0]
if out else {}` or remove the line entirely) to satisfy static analysis; refer
to the variable `m` and the `out` list in the existing code.

Comment on lines +6 to +8
model_run_dir: /users/gbk2114/EveryQuery/results/outputs/2026-01-03/15-43-49/
index_time_hash: 7573f855c4b050a9d79d57fefd8a139c
task_set_dir: /users/gbk2114/eq_stuff/tasks/eval/aces/${index_time_hash}/

⚠️ Potential issue | 🟠 Major

Hard-coded user-specific paths block portability.

Lines 6-8 contain absolute paths tied to a specific user directory (/users/gbk2114/). This prevents the code from running in other environments, CI/CD pipelines, or by other developers.

💡 Recommended approach

Consider using:

  • Environment variables (e.g., ${oc.env:MODEL_RUN_DIR})
  • Relative paths from a configurable root
  • Hydra resolvers for user-agnostic path construction

Example using environment variables:

-model_run_dir: /users/gbk2114/EveryQuery/results/outputs/2026-01-03/15-43-49/
-task_set_dir: /users/gbk2114/eq_stuff/tasks/eval/aces/${index_time_hash}/
+model_run_dir: ${oc.env:EVERYQUERY_RESULTS_DIR}/outputs/2026-01-03/15-43-49/
+task_set_dir: ${oc.env:EVERYQUERY_TASKS_DIR}/eval/aces/${index_time_hash}/
📝 Committable suggestion


Suggested change
-model_run_dir: /users/gbk2114/EveryQuery/results/outputs/2026-01-03/15-43-49/
+model_run_dir: ${oc.env:EVERYQUERY_RESULTS_DIR}/outputs/2026-01-03/15-43-49/
 index_time_hash: 7573f855c4b050a9d79d57fefd8a139c
-task_set_dir: /users/gbk2114/eq_stuff/tasks/eval/aces/${index_time_hash}/
+task_set_dir: ${oc.env:EVERYQUERY_TASKS_DIR}/eval/aces/${index_time_hash}/
🤖 Prompt for AI Agents
In @src/every_query/eval_suite/conf/eval_config.yaml around lines 6 - 8, Replace
the hard-coded user-specific paths in the YAML keys model_run_dir and
task_set_dir (and any others using /users/gbk2114/) with configurable, portable
values: read them from environment variables (e.g., MODEL_RUN_DIR and
TASK_SET_DIR) or construct them via a Hydra resolver or a project-root relative
path that uses index_time_hash dynamically; update the YAML to reference the
env/resolver placeholders (keeping index_time_hash as the variable used in
task_set_dir) and ensure calling code that reads model_run_dir/task_set_dir
falls back to sensible defaults if the env vars are unset.

Comment on lines +1 to +3
predictions_df_path: /users/gbk2114/EveryQuery/results/outputs/2026-01-03/15-43-49/eval/readmiss_30d/readmiss_30d.csv

task_labels_df_path: /users/gbk2114/eic_stuff/make_index_dfs/task_configs/readmiss_30d/held_out

⚠️ Potential issue | 🟠 Major

Hard-coded user-specific paths block portability.

Both configuration paths are tied to specific user directories (/users/gbk2114/), preventing use by other developers or in different environments.

Additionally, note the inconsistent directory naming: eic_stuff (line 3) vs. eq_stuff (used in eval_config.yaml line 8). Verify whether this is intentional.

💡 Recommended approach

Use environment variables or configurable roots:

-predictions_df_path: /users/gbk2114/EveryQuery/results/outputs/2026-01-03/15-43-49/eval/readmiss_30d/readmiss_30d.csv
+predictions_df_path: ${oc.env:EVERYQUERY_RESULTS_DIR}/outputs/2026-01-03/15-43-49/eval/readmiss_30d/readmiss_30d.csv

-task_labels_df_path: /users/gbk2114/eic_stuff/make_index_dfs/task_configs/readmiss_30d/held_out
+task_labels_df_path: ${oc.env:EVERYQUERY_DATA_DIR}/make_index_dfs/task_configs/readmiss_30d/held_out
📝 Committable suggestion


Suggested change
-predictions_df_path: /users/gbk2114/EveryQuery/results/outputs/2026-01-03/15-43-49/eval/readmiss_30d/readmiss_30d.csv
+predictions_df_path: ${oc.env:EVERYQUERY_RESULTS_DIR}/outputs/2026-01-03/15-43-49/eval/readmiss_30d/readmiss_30d.csv
-task_labels_df_path: /users/gbk2114/eic_stuff/make_index_dfs/task_configs/readmiss_30d/held_out
+task_labels_df_path: ${oc.env:EVERYQUERY_DATA_DIR}/make_index_dfs/task_configs/readmiss_30d/held_out
🤖 Prompt for AI Agents
In @src/every_query/process_composite_config.yaml around lines 1 - 3, The YAML
contains hard-coded user-specific paths in predictions_df_path and
task_labels_df_path which break portability; change these to use configurable
roots (environment variables or a top-level config key) or relative paths (e.g.,
use ${PREDICTIONS_ROOT}/... and ${TASK_LABELS_ROOT}/... or a base_path variable)
and update any code that reads this YAML to expand those variables; also verify
and correct the inconsistent directory name (eic_stuff vs eq_stuff) in
task_labels_df_path to the intended name so all configs use the same canonical
directory.

import hydra
import polars as pl
from omegaconf import DictConfig, OmegaConf
from sklearn.metrics import roc_auc_score

⚠️ Potential issue | 🔴 Critical

Add missing scikit-learn dependency.

The import of sklearn.metrics fails because the package is not installed, as shown in the pipeline failure. This blocks the script from running.

📦 Add to project dependencies

Add scikit-learn to your project's dependency file (e.g., requirements.txt, setup.py, pyproject.toml):

scikit-learn>=1.0.0

Or if using Poetry:

[tool.poetry.dependencies]
scikit-learn = "^1.0.0"
🧰 Tools
🪛 GitHub Actions: Tests

[error] 9-9: ImportError: No module named 'sklearn'. ModuleNotFoundError occurred while importing test module during collection.

🤖 Prompt for AI Agents
In @src/every_query/process_composite.py at line 9, The import "from
sklearn.metrics import roc_auc_score" is failing because scikit-learn is not
declared as a project dependency; add "scikit-learn>=1.0.0" to your dependency
manifest (e.g., requirements.txt, setup.py, or pyproject.toml/Poetry under
[tool.poetry.dependencies]) and run your install step so the import in
process_composite.py succeeds; update CI/pipeline dependency cache if necessary.

Comment on lines +25 to +35
    all_preds_df=pl.read_csv(cfg.predictions_df_path)

    probs_df = (
        all_preds_df
        .group_by(["subject_id","prediction_time"])
        .agg(
            pl.col("occurs_probs").max().alias("max_prob")
        )
    )

    all_aces_labels_df=pl.read_parquet(cfg.task_labels_df_path,columns=["subject_id","prediction_time","boolean_value"])

⚠️ Potential issue | 🟠 Major

Add validation and error handling for file I/O.

Lines 25 and 35 read files without checking if they exist or handling potential I/O errors. If the configured paths are incorrect or files are missing, the script will crash with an unclear error.

🛡️ Add defensive checks
+    from pathlib import Path
+    
+    predictions_path = Path(cfg.predictions_df_path)
+    if not predictions_path.exists():
+        raise FileNotFoundError(f"Predictions file not found: {predictions_path}")
+    
-    all_preds_df=pl.read_csv(cfg.predictions_df_path)
+    all_preds_df = pl.read_csv(predictions_path)
     
     probs_df = (
         all_preds_df
         .group_by(["subject_id","prediction_time"])
         .agg(
             pl.col("occurs_probs").max().alias("max_prob")
         )
     )
     
+    labels_path = Path(cfg.task_labels_df_path)
+    if not labels_path.exists():
+        raise FileNotFoundError(f"Task labels file not found: {labels_path}")
+    
-    all_aces_labels_df=pl.read_parquet(cfg.task_labels_df_path,columns=["subject_id","prediction_time","boolean_value"])
+    all_aces_labels_df = pl.read_parquet(labels_path, columns=["subject_id","prediction_time","boolean_value"])
🤖 Prompt for AI Agents
In @src/every_query/process_composite.py around lines 25 - 35, Wrap the CSV and
Parquet reads with defensive I/O checks: before calling
pl.read_csv(cfg.predictions_df_path) and
pl.read_parquet(cfg.task_labels_df_path) verify the paths exist
(cfg.predictions_df_path and cfg.task_labels_df_path), and then wrap the
pl.read_* calls that populate all_preds_df and all_aces_labels_df in try/except
to catch and log IOErrors/Exceptions (including the exception message), failing
fast with a clear error via processLogger.error or re-raising a descriptive
exception; keep subsequent logic that computes probs_df (group_by/agg on
all_preds_df) unchanged but ensure it only runs after successful reads.

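The fail-fast existence check asked for above can be factored into a small helper; a stdlib-only sketch, with the `pl.read_*` calls themselves left unchanged (the helper name is illustrative):

```python
from pathlib import Path


def read_required(path_str: str) -> Path:
    """Fail fast with a clear message when a configured input path is missing."""
    path = Path(path_str)
    if not path.exists():
        raise FileNotFoundError(f"Required input not found: {path}")
    return path


# usage sketch (config keys as in the script):
# all_preds_df = pl.read_csv(read_required(cfg.predictions_df_path))
# all_aces_labels_df = pl.read_parquet(read_required(cfg.task_labels_df_path))
```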
Comment on lines +43 to +51
    joined_df = probs_df.join(all_aces_labels_df,on=["subject_id","prediction_time"],how="inner")


    auc = roc_auc_score(
        joined_df["boolean_value"].to_numpy(),
        joined_df["max_prob"].to_numpy(),
    )

    print(auc)

⚠️ Potential issue | 🟠 Major

Validate join results and ROC-AUC requirements.

The inner join at line 43 may produce an empty DataFrame if there are no matching records. Additionally, roc_auc_score requires at least two classes in the labels; if all labels are identical, it will raise a ValueError.

🛡️ Add validation
     joined_df = probs_df.join(all_aces_labels_df,on=["subject_id","prediction_time"],how="inner")

+    if joined_df.is_empty():
+        raise ValueError("No matching records found between predictions and labels")
+    
+    unique_labels = joined_df["boolean_value"].n_unique()
+    if unique_labels < 2:
+        raise ValueError(f"ROC-AUC requires at least 2 classes, found {unique_labels}")
     
     auc = roc_auc_score(
         joined_df["boolean_value"].to_numpy(),
         joined_df["max_prob"].to_numpy(),
     )

     print(auc)
🤖 Prompt for AI Agents
In @src/every_query/process_composite.py around lines 43 - 51, Check that the
inner join result (joined_df produced from
probs_df.join(all_aces_labels_df,on=["subject_id","prediction_time"],how="inner"))
is non-empty and that joined_df["boolean_value"] contains at least two distinct
classes before calling roc_auc_score; if joined_df.is_empty() or
joined_df["boolean_value"].n_unique() < 2 (note these are the polars APIs, not
pandas' .empty/.nunique()), short-circuit (return/raise/log a clear message or
set auc to None/np.nan) instead of calling roc_auc_score to avoid ValueError,
and optionally wrap the roc_auc_score call in a try/except to log unexpected
errors while preserving context (include references to joined_df, probs_df,
all_aces_labels_df, and roc_auc_score).

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 8

🤖 Fix all issues with AI agents
In `@src/every_query/config.yaml`:
- Around line 51-52: The YAML comment for num_warmup_steps is inconsistent: it
sets num_warmup_steps to 0.05 (5%) but the comment says "~10% of training
steps"; update either the value or the comment so they match. Locate
num_warmup_steps and num_training_steps in config.yaml and change the numeric
value to 0.10 if you want 10% warmup, or edit the trailing comment to "~5% of
training steps" to reflect the current 0.05 value; ensure the placeholder
expression ${int_prod:0.05,${.num_training_steps}} remains syntactically
unchanged except for the numeric literal if you adjust it.

In `@src/every_query/eval.py`:
- Around line 84-89: The bucket assignment uses "code in cfg.ood_codes" which
throws when cfg.ood_codes is None; to fix, introduce a safe local default (e.g.,
ood_codes = cfg.ood_codes or set()) near where codes/cfg are prepared and
replace the direct membership check in the rows.append call with "code in
ood_codes" (update the bucket expression inside the rows.append block in eval.py
so it uses the local ood_codes variable instead of cfg.ood_codes).

In `@src/every_query/lightning_module.py`:
- Around line 204-215: The on_test_epoch_end method must early-return when
self.cache is empty to avoid torch.cat on an empty sequence; add a guard at the
top of on_test_epoch_end that checks if not self.cache and if so sets
self.test_predictions to a dict with "subject_id", "prediction_time", and
"occurs_probs" as appropriately-typed empty numpy arrays (e.g., empty int array
for subject_id and empty float arrays for prediction_time and occurs_probs) and
then returns. Ensure you update the same symbols: on_test_epoch_end, self.cache,
and self.test_predictions so downstream code sees the correctly-typed empty
arrays.
- Around line 299-307: In test_step, guard access to outputs.occurs_logits the
same way _log_metrics does: use getattr(outputs, "occurs_logits", None) and only
compute occurs_probs = torch.sigmoid(...) and detach/CPU it if the logits are
not None; otherwise set occurs_probs to None (and ensure the "occurs_probs"
entry appended to self.cache uses that value). The batch.occurs handling
remains the same; keep using the outputs/occurs_logits symbol names so the
change is localized to test_step.

In `@src/every_query/process_composite/get_per_code_from_composite.py`:
- Around line 64-66: The code builds the output filename via string
concatenation (out_fp = cfg.output_root + "readmiss_30d_per_code.csv"), which
can create invalid paths; change this to use Path joining by importing
pathlib.Path (or using cfg.output_root as a Path) and construct out_fp with
Path(cfg.output_root) / "readmiss_30d_per_code.csv" before calling
out_df.write_csv(out_fp), ensuring out_fp is converted to a string if write_csv
requires it.
- Around line 44-62: The loop that builds joined_df and computes AUC needs
guards: before accessing joined_df["code"][0] or computing auc with
roc_auc_score, check that joined_df is not empty (e.g., len(joined_df) > 0) and
skip/log if empty; then validate y_true and y_score extracted from joined_df (no
NaNs, matching lengths, and y_true contains at least two classes) and skip/log
when any check fails; only compute auc and set results[code] when all
validations pass. Use the existing variables/code symbols (code_task_path,
slugged_code, code_task_df, code_preds_df, joined_df, y_true, y_score, results,
roc_auc_score) to add these guards and avoid indexing an empty DataFrame or
calling roc_auc_score on invalid data.
- Around line 7-11: The module-level import of roc_auc_score from sklearn causes
test collection failures when scikit-learn isn't installed; either add
scikit-learn to project dependencies in pyproject.toml or move the import into
the runtime path (e.g., inside the main() function) and raise a clear error if
it's missing; specifically, remove the top-level "from sklearn.metrics import
roc_auc_score" and instead import roc_auc_score within main() (or the function
that uses it) and catch ImportError to log a helpful message about installing
scikit-learn.

In `@src/every_query/process_composite/process_composite.py`:
- Around line 5-9: The module-level sklearn import causes test failures when
scikit-learn isn't installed and the unused parameter clip_sum_to_1 should be
removed: move the import "from sklearn.metrics import roc_auc_score" into the
agg_probs function (so it is lazily imported only when agg_probs is called) and
remove the unused parameter clip_sum_to_1 from agg_probs's signature and any
callers; update any internal references to roc_auc_score to use the local import
within agg_probs and run tests to ensure no other call sites expect
clip_sum_to_1.
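The lazy-import pattern suggested for both process_composite modules can be sketched with a small stdlib helper (the `require` name and hint wording are illustrative):

```python
import importlib
from types import ModuleType


def require(module_name: str, hint: str = "") -> ModuleType:
    """Import an optional dependency at call time, raising a clear error
    if it is missing instead of failing at module import / test collection."""
    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        msg = f"Optional dependency '{module_name}' is not installed."
        if hint:
            msg += f" {hint}"
        raise ImportError(msg) from err


# usage sketch inside main():
#   metrics = require("sklearn.metrics", hint="Install with: pip install scikit-learn")
#   roc_auc_score = metrics.roc_auc_score
```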
🧹 Nitpick comments (2)
src/every_query/dataset.py (1)

31-51: Add shape checks for new held‑out annotations.

subject_id and prediction_time are new batch fields but aren’t validated in __post_init__, so shape mismatches can slip through silently. Consider validating them like censor/occurs/query.

♻️ Suggested validation
     def __post_init__(self):
         # Run base validations
         super().__post_init__()
 
+        if self.subject_id is not None:
+            self._MEDSTorchBatch__check_shape("subject_id", (self.batch_size,))
+        if self.prediction_time is not None:
+            self._MEDSTorchBatch__check_shape("prediction_time", (self.batch_size,))
+
         # Validate optional per-sample annotation shapes, if provided
         if self.censor is not None:
             self._MEDSTorchBatch__check_shape("censor", (self.batch_size,))
src/every_query/process_composite/process_composite.py (1)

21-64: Wire up clip_sum_to_1 or remove it.

The argument is unused; if it’s intended, implement clipping on prob_sum.

♻️ Possible implementation
     if "sum" in agg_type_set:
-        expr = pl.col(pred_prob_col).sum()
+        expr = pl.col(pred_prob_col).sum()
+        if clip_sum_to_1:
+            expr = expr.clip(0.0, 1.0)
 
         aggs.append(expr.alias("prob_sum"))

Comment on lines +51 to +52
num_warmup_steps: ${int_prod:0.05,${.num_training_steps}} # ~10% of training steps
num_training_steps: ${trainer.max_steps}

⚠️ Potential issue | 🟡 Minor

Warmup percentage comment doesn’t match value.

num_warmup_steps uses 0.05 (=5%), but the comment says ~10%. Align the comment or value.

📝 Comment fix
-    num_warmup_steps: ${int_prod:0.05,${.num_training_steps}} # ~10% of training steps
+    num_warmup_steps: ${int_prod:0.05,${.num_training_steps}} # ~5% of training steps
📝 Committable suggestion


Suggested change
-    num_warmup_steps: ${int_prod:0.05,${.num_training_steps}} # ~10% of training steps
+    num_warmup_steps: ${int_prod:0.05,${.num_training_steps}} # ~5% of training steps
     num_training_steps: ${trainer.max_steps}
🤖 Prompt for AI Agents
In `@src/every_query/config.yaml` around lines 51 - 52, The YAML comment for
num_warmup_steps is inconsistent: it sets num_warmup_steps to 0.05 (5%) but the
comment says "~10% of training steps"; update either the value or the comment so
they match. Locate num_warmup_steps and num_training_steps in config.yaml and
change the numeric value to 0.10 if you want 10% warmup, or edit the trailing
comment to "~5% of training steps" to reflect the current 0.05 value; ensure the
placeholder expression ${int_prod:0.05,${.num_training_steps}} remains
syntactically unchanged except for the numeric literal if you adjust it.

Comment on lines +84 to +89
        rows.append(
            {
                "code": code,
                "code_slug": slug,
                "bucket": "ood" if code in cfg.ood_codes else "id",
                "occurs_auc": float(m.get("held_out/occurs_auc"))

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat -n src/every_query/eval.py | sed -n '70,100p'

Repository: payalchandak/EveryQuery

Length of output: 1358


🏁 Script executed:

# Find the config definition to check ood_codes type
rg "ood_codes" --type py -B 2 -A 2

Repository: payalchandak/EveryQuery

Length of output: 3179


🏁 Script executed:

# Look for cfg/config type definition
rg "class.*Config|@dataclass" --type py -A 10 | head -100

Repository: payalchandak/EveryQuery

Length of output: 1635


🏁 Script executed:

cat -n src/every_query/eval.py | sed -n '55,95p'

Repository: payalchandak/EveryQuery

Length of output: 1766


🏁 Script executed:

# Also check if there's any config class definition with type hints for ood_codes
rg "ood_codes.*:" --type py -B 2 -A 2

Repository: payalchandak/EveryQuery

Length of output: 756


Handle cfg.ood_codes being None in bucket assignment.

code in cfg.ood_codes raises TypeError when cfg.ood_codes is None. Although lines 62–63 check for None before populating the codes list, the membership test at line 88 uses cfg.ood_codes directly without protection. Use a local variable with a safe default.

🛠️ Safe default
     if cfg.manual_codes is not None:
         codes += cfg.manual_codes
 
+    ood_codes = cfg.ood_codes or []
+
     rows: list[dict[str, Any]] = []
@@
                 "code": code,
                 "code_slug": slug,
-                "bucket": "ood" if code in cfg.ood_codes else "id",
+                "bucket": "ood" if code in ood_codes else "id",
📝 Committable suggestion


Suggested change
         rows.append(
             {
                 "code": code,
                 "code_slug": slug,
-                "bucket": "ood" if code in cfg.ood_codes else "id",
+                "bucket": "ood" if code in ood_codes else "id",
                 "occurs_auc": float(m.get("held_out/occurs_auc"))
🤖 Prompt for AI Agents
In `@src/every_query/eval.py` around lines 84 - 89, The bucket assignment uses
"code in cfg.ood_codes" which throws when cfg.ood_codes is None; to fix,
introduce a safe local default (e.g., ood_codes = cfg.ood_codes or set()) near
where codes/cfg are prepared and replace the direct membership check in the
rows.append call with "code in ood_codes" (update the bucket expression inside
the rows.append block in eval.py so it uses the local ood_codes variable instead
of cfg.ood_codes).
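The None-safe bucket assignment can be isolated into a small helper for testing; a minimal sketch (the function name is illustrative):

```python
def bucket_for(code, ood_codes):
    """Classify a code as out-of-distribution ("ood") or in-distribution ("id").

    `ood_codes` may be None (e.g., an unset config field); it is coerced to
    an empty set so the membership test never raises TypeError."""
    ood = set(ood_codes or [])
    return "ood" if code in ood else "id"
```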

Comment on lines +204 to +215
    def on_test_epoch_end(self):
        self._on_epoch_end(held_out_split)

        subject_ids = torch.cat([batch_info["subject_id"] for batch_info in self.cache]).numpy()
        prediction_times = torch.cat([batch_info["prediction_time"] for batch_info in self.cache]).numpy()
        occurs_probs = torch.cat([batch_info["occurs_probs"] for batch_info in self.cache]).numpy()

        self.test_predictions = {
            "subject_id": subject_ids,
            "prediction_time": prediction_times,
            "occurs_probs": occurs_probs,
        }

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat -n src/every_query/lightning_module.py | head -250 | tail -100

Repository: payalchandak/EveryQuery

Length of output: 4236


🏁 Script executed:

# Search for test_step to understand cache population
rg -n "def test_step" src/every_query/lightning_module.py -A 20

Repository: payalchandak/EveryQuery

Length of output: 1071


🏁 Script executed:

# Search for cache initialization
rg -n "self.cache" src/every_query/lightning_module.py -B 2 -A 2

Repository: payalchandak/EveryQuery

Length of output: 814


🏁 Script executed:

# Search for held_out_split
rg -n "held_out_split" src/every_query/lightning_module.py -B 3 -A 3

Repository: payalchandak/EveryQuery

Length of output: 1847


🏁 Script executed:

# Check if there are any test skips or conditions that might prevent test batches
rg -n "test_dataloader\|test_loader" src/every_query/lightning_module.py -B 3 -A 5

Repository: payalchandak/EveryQuery

Length of output: 49


🏁 Script executed:

rg -n "if.*cache\|if.*self\.cache" src/every_query/lightning_module.py

Repository: payalchandak/EveryQuery

Length of output: 49


🏁 Script executed:

# Check for any guards or conditions in on_test_epoch_end
sed -n '204,216p' src/every_query/lightning_module.py

Repository: payalchandak/EveryQuery

Length of output: 623


🏁 Script executed:

# Run Python to verify torch.cat behavior with empty lists
python3 << 'EOF'
import torch

print("=== Test 1: torch.cat([]) ===")
try:
    result = torch.cat([])
    print("Result:", result, "shape:", result.shape)
except Exception as e:
    print("ERROR:", type(e).__name__, "-", str(e))

print("\n=== Test 2: torch.cat on empty list from comprehension ===")
try:
    cache = []
    result = torch.cat([item for item in cache])
    print("Result:", result)
except Exception as e:
    print("ERROR:", type(e).__name__, "-", str(e))

print("\n=== Test 3: Realistic test_step scenario (empty cache) ===")
try:
    cache = []
    subject_ids = torch.cat([batch_info["subject_id"] for batch_info in cache])
    print("Result:", subject_ids)
except Exception as e:
    print("ERROR:", type(e).__name__, "-", str(e))

print("\n=== Test 4: occurs_probs generation (from test_step) ===")
try:
    occurs_logits = torch.randn(5, 1)
    occurs_probs = torch.sigmoid(occurs_logits).squeeze(-1)
    print("occurs_probs shape:", occurs_probs.shape, "type:", type(occurs_probs))
    cache = [{"occurs_probs": occurs_probs}]
    result = torch.cat([batch_info["occurs_probs"] for batch_info in cache])
    print("Concatenation result shape:", result.shape)
except Exception as e:
    print("ERROR:", type(e).__name__, "-", str(e))
EOF

Repository: payalchandak/EveryQuery

Length of output: 184


🌐 Web query:

PyTorch torch.cat empty list behavior

💡 Result:

torch.cat does not accept an empty sequence. If you call:

torch.cat([])

(or torch.cat([], dim=...)) you’ll get a runtime error like RuntimeError: torch.cat(): expected a non-empty list of Tensors. This is consistent with the docs’ requirement that tensors be a (conceptually) non-empty sequence. [1]

What is allowed is concatenating sequences that include empty tensors, as long as the sequence itself is non-empty and shapes are compatible (PyTorch also special-cases allowing a 1-D empty tensor of shape (0,)). [1]

Typical pattern to handle the empty-list case:

if xs:
    y = torch.cat(xs, dim=0)
else:
    y = torch.empty((0, D), device=device, dtype=dtype)  # choose the shape you expect

Sources: PyTorch torch.cat documentation [1]; examples of the empty-list runtime error in real codebases/issues [2].

References: [1] [2]


🏁 Script executed:

# Check if there are any conditions that might guarantee test_step is called
rg -n "def test_dataloader\|test_loader" src/every_query/ -A 10

Repository: payalchandak/EveryQuery

Length of output: 49


🏁 Script executed:

# Look for any setup.py or requirements.txt to understand dependencies
fd -e "requirements.txt" -o -e "setup.py" -o -e "pyproject.toml" | head -5

Repository: payalchandak/EveryQuery

Length of output: 183


Guard on_test_epoch_end against empty test cache.

When the test loader yields zero batches, test_step is never called, leaving self.cache empty. Calling torch.cat([...]) on an empty sequence raises RuntimeError: torch.cat(): expected a non-empty list of Tensors. Add an early return with correctly-typed empty arrays.

Guard implementation
     def on_test_epoch_end(self):
         self._on_epoch_end(held_out_split)
 
+        if not self.cache:
+            self.test_predictions = {
+                "subject_id": torch.empty(0, dtype=torch.long).numpy(),
+                "prediction_time": torch.empty(0, dtype=torch.long).numpy(),
+                "occurs_probs": torch.empty(0, dtype=torch.float).numpy(),
+            }
+            return
+
         subject_ids = torch.cat([batch_info["subject_id"] for batch_info in self.cache]).numpy()
         prediction_times = torch.cat([batch_info["prediction_time"] for batch_info in self.cache]).numpy()
         occurs_probs = torch.cat([batch_info["occurs_probs"] for batch_info in self.cache]).numpy()
📝 Committable suggestion


Suggested change
-    def on_test_epoch_end(self):
-        self._on_epoch_end(held_out_split)
-
-        subject_ids = torch.cat([batch_info["subject_id"] for batch_info in self.cache]).numpy()
-        prediction_times = torch.cat([batch_info["prediction_time"] for batch_info in self.cache]).numpy()
-        occurs_probs = torch.cat([batch_info["occurs_probs"] for batch_info in self.cache]).numpy()
-
-        self.test_predictions = {
-            "subject_id": subject_ids,
-            "prediction_time": prediction_times,
-            "occurs_probs": occurs_probs,
-        }
+    def on_test_epoch_end(self):
+        self._on_epoch_end(held_out_split)
+
+        if not self.cache:
+            self.test_predictions = {
+                "subject_id": torch.empty(0, dtype=torch.long).numpy(),
+                "prediction_time": torch.empty(0, dtype=torch.long).numpy(),
+                "occurs_probs": torch.empty(0, dtype=torch.float).numpy(),
+            }
+            return
+
+        subject_ids = torch.cat([batch_info["subject_id"] for batch_info in self.cache]).numpy()
+        prediction_times = torch.cat([batch_info["prediction_time"] for batch_info in self.cache]).numpy()
+        occurs_probs = torch.cat([batch_info["occurs_probs"] for batch_info in self.cache]).numpy()
+
+        self.test_predictions = {
+            "subject_id": subject_ids,
+            "prediction_time": prediction_times,
+            "occurs_probs": occurs_probs,
+        }
🤖 Prompt for AI Agents
In `@src/every_query/lightning_module.py` around lines 204 - 215, The
on_test_epoch_end method must early-return when self.cache is empty to avoid
torch.cat on an empty sequence; add a guard at the top of on_test_epoch_end that
checks if not self.cache and if so sets self.test_predictions to a dict with
"subject_id", "prediction_time", and "occurs_probs" as appropriately-typed empty
numpy arrays (e.g., empty int array for subject_id and empty float arrays for
prediction_time and occurs_probs) and then returns. Ensure you update the same
symbols: on_test_epoch_end, self.cache, and self.test_predictions so downstream
code sees the correctly-typed empty arrays.

Comment on lines +299 to +307
        # per-row probability of occurs(query)
        occurs_probs = torch.sigmoid(outputs.occurs_logits).squeeze(-1) # [B]

        self.cache.append(
            {
                "subject_id": batch.subject_id.detach().cpu(), # [B]
                "prediction_time": batch.prediction_time.detach().cpu(), # [B]
                "occurs_probs": occurs_probs.detach().cpu(), # [B]
                "occurs": batch.occurs.detach().cpu() if batch.occurs is not None else None,

⚠️ Potential issue | 🟠 Major



Guard occurs_logits before sigmoid.

The test_step method accesses outputs.occurs_logits directly without checking if it exists, while the _log_metrics method (line 271) already uses a defensive pattern with getattr(outputs, "occurs_logits", None) for the same field. This inconsistency leaves test_step vulnerable if the model behavior changes or the occurs head becomes optional. Add a defensive check to match the pattern established elsewhere in the module:

🛠️ Defensive check
-        # per-row probability of occurs(query)
-        occurs_probs = torch.sigmoid(outputs.occurs_logits).squeeze(-1)  # [B]
+        # per-row probability of occurs(query)
+        if getattr(outputs, "occurs_logits", None) is None:
+            raise ValueError("occurs_logits is required to compute occurs probabilities")
+        occurs_probs = torch.sigmoid(outputs.occurs_logits).squeeze(-1)  # [B]
🤖 Prompt for AI Agents
In `@src/every_query/lightning_module.py` around lines 299 - 307: in test_step,
guard access to outputs.occurs_logits the same way _log_metrics does. Use
getattr(outputs, "occurs_logits", None), and only compute occurs_probs =
torch.sigmoid(...) (detached and moved to CPU) when the logits are not None;
otherwise set occurs_probs to None and ensure the "occurs_probs" entry appended
to self.cache uses that value. The batch.occurs handling stays the same, and the
existing outputs/occurs_logits symbol names are kept so the change stays
localized to test_step.

Comment on lines +7 to +11
import hydra
import polars as pl
from omegaconf import DictConfig
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Find the file and check its existence
find . -name "get_per_code_from_composite.py" -type f

Repository: payalchandak/EveryQuery

Length of output: 133


🏁 Script executed:

# List imports and structure of the target file
head -30 src/every_query/process_composite/get_per_code_from_composite.py

Repository: payalchandak/EveryQuery

Length of output: 864


🏁 Script executed:

# Check for dependency declarations (requirements.txt, setup.py, pyproject.toml, etc.)
find . -maxdepth 2 -type f \( -name "requirements*.txt" -o -name "setup.py" -o -name "pyproject.toml" -o -name "setup.cfg" \) | head -10

Repository: payalchandak/EveryQuery

Length of output: 83


🏁 Script executed:

# Search for sklearn usage in the file
rg "sklearn" src/every_query/process_composite/get_per_code_from_composite.py -A 3 -B 3

Repository: payalchandak/EveryQuery

Length of output: 234


🏁 Script executed:

# Check pyproject.toml for sklearn dependency
cat pyproject.toml | head -100

Repository: payalchandak/EveryQuery

Length of output: 2810


🏁 Script executed:

# Check for CI configuration files
find . -maxdepth 3 -type f \( -name ".github" -o -name "*.yml" -o -name "*.yaml" \) | grep -E "(workflow|ci|pytest)" | head -10

Repository: payalchandak/EveryQuery

Length of output: 219


🏁 Script executed:

# Find where roc_auc_score is actually used in the file
rg "roc_auc_score" src/every_query/process_composite/get_per_code_from_composite.py -B 5 -A 5

Repository: payalchandak/EveryQuery

Length of output: 624


🏁 Script executed:

# Get full file to understand context better
wc -l src/every_query/process_composite/get_per_code_from_composite.py

Repository: payalchandak/EveryQuery

Length of output: 134


🏁 Script executed:

# Check the full file to see how it's structured
cat -n src/every_query/process_composite/get_per_code_from_composite.py

Repository: payalchandak/EveryQuery

Length of output: 2764


🏁 Script executed:

# Check CI workflow configurations
cat .github/workflows/tests.yaml

Repository: payalchandak/EveryQuery

Length of output: 1058


🏁 Script executed:

# Check if sklearn is imported anywhere else in the project
rg "from sklearn|import sklearn" --type py | head -20

Repository: payalchandak/EveryQuery

Length of output: 270


🏁 Script executed:

# Check pytest configuration to see what gets collected
cat pyproject.toml | grep -A 10 "pytest"

Repository: payalchandak/EveryQuery

Length of output: 1474


Add scikit-learn as a dependency or move the import inside main().

scikit-learn is not listed in pyproject.toml dependencies, but line 10 imports it at module level. Since pytest is configured with --doctest-modules, pytest collection will fail with ModuleNotFoundError: No module named 'sklearn' before any tests run. Either declare scikit-learn as a dependency or defer the import.

🛠️ Lazy import with clear error handling
 import hydra
 import polars as pl
 from omegaconf import DictConfig
-from sklearn.metrics import roc_auc_score
 from tqdm import tqdm
 
 def main(cfg: DictConfig) -> None:
+    try:
+        from sklearn.metrics import roc_auc_score
+    except ModuleNotFoundError as e:
+        raise ModuleNotFoundError(
+            "scikit-learn is required for ROC AUC computation. Install it or "
+            "provide an alternative implementation."
+        ) from e
     print("starting")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
import hydra
import polars as pl
from omegaconf import DictConfig
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
import hydra
import polars as pl
from omegaconf import DictConfig
from tqdm import tqdm
def main(cfg: DictConfig) -> None:
try:
from sklearn.metrics import roc_auc_score
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"scikit-learn is required for ROC AUC computation. Install it or "
"provide an alternative implementation."
) from e
print("starting")
🧰 Tools
🪛 GitHub Actions: Tests

[error] 10-10: ModuleNotFoundError: No module named 'sklearn' while importing get_per_code_from_composite.py during pytest collection.

🤖 Prompt for AI Agents
In `@src/every_query/process_composite/get_per_code_from_composite.py` around
lines 7 - 11, The module-level import of roc_auc_score from sklearn causes test
collection failures when scikit-learn isn't installed; either add scikit-learn
to project dependencies in pyproject.toml or move the import into the runtime
path (e.g., inside the main() function) and raise a clear error if it's missing;
specifically, remove the top-level "from sklearn.metrics import roc_auc_score"
and instead import roc_auc_score within main() (or the function that uses it)
and catch ImportError to log a helpful message about installing scikit-learn.

Comment on lines +44 to +62
for code_task_path in tqdm(code_task_df_paths, total=len(code_task_df_paths)):
    slugged_code = Path(code_task_path).stem
    code_task_df = pl.read_parquet(code_task_path, columns=["subject_id", "prediction_time", "occurs"])

    code_preds_df = all_preds_df.filter(pl.col("code_slugged") == slugged_code)
    # print(code_preds_df.head())
    # print(slugged_code)

    joined_df = code_preds_df.join(code_task_df, on=["subject_id", "prediction_time"], how="left")
    print(joined_df.head())
    print(joined_df.shape)
    code = joined_df["code"][0]
    print(code)

    y_true = joined_df["occurs"].to_numpy()
    y_score = joined_df["occurs_probs"].to_numpy()

    auc = roc_auc_score(y_true, y_score)
    results[code] = auc

⚠️ Potential issue | 🟠 Major



Add guards for empty joins and invalid labels before computing AUC.

Line 55 (joined_df["code"][0]) will crash if the join is empty. Additionally, roc_auc_score will fail if the data contains NaNs or has only a single class. Add guards to skip processing when these conditions occur:

Suggested guards
         code_preds_df = all_preds_df.filter(pl.col("code_slugged") == slugged_code)
 
         joined_df = code_preds_df.join(code_task_df, on=["subject_id", "prediction_time"], how="left")
+        if joined_df.is_empty():
+            logger.warning(f"No predictions for code slug {slugged_code} (skipping)")
+            continue
+        joined_df = joined_df.drop_nulls(["occurs", "occurs_probs"])
+        if joined_df.is_empty() or joined_df["occurs"].n_unique() < 2:
+            logger.warning(f"Insufficient label variation for {slugged_code} (skipping)")
+            continue
         print(joined_df.head())
         print(joined_df.shape)
         code = joined_df["code"][0]
🤖 Prompt for AI Agents
In `@src/every_query/process_composite/get_per_code_from_composite.py` around
lines 44 - 62, The loop that builds joined_df and computes AUC needs guards:
before accessing joined_df["code"][0] or computing auc with roc_auc_score, check
that joined_df is not empty (e.g., len(joined_df) > 0) and skip/log if empty;
then validate y_true and y_score extracted from joined_df (no NaNs, matching
lengths, and y_true contains at least two classes) and skip/log when any check
fails; only compute auc and set results[code] when all validations pass. Use the
existing variables/code symbols (code_task_path, slugged_code, code_task_df,
code_preds_df, joined_df, y_true, y_score, results, roc_auc_score) to add these
guards and avoid indexing an empty DataFrame or calling roc_auc_score on invalid
data.
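
The validation logic can be factored into one small helper. A hedged sketch (the name `safe_auc` is made up; the rank-based formula below agrees with `roc_auc_score` for binary labels, counting ties as 0.5):

```python
import numpy as np

def safe_auc(y_true, y_score):
    """Return ROC AUC, or None when inputs are empty, NaN-tainted, or single-class."""
    y_true = np.asarray(y_true, dtype=float)
    y_score = np.asarray(y_score, dtype=float)
    if y_true.shape != y_score.shape:
        return None
    keep = ~(np.isnan(y_true) | np.isnan(y_score))
    y_true, y_score = y_true[keep], y_score[keep]
    if y_true.size == 0 or np.unique(y_true).size < 2:
        return None  # empty join or no label variation: AUC undefined
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    # Mann-Whitney U statistic: P(pos score > neg score), ties count 0.5.
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (pos.size * neg.size)
```

The loop body can then `continue` whenever `safe_auc` returns None, instead of letting `roc_auc_score` raise mid-run.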

Comment on lines +64 to +66
out_fp = cfg.output_root + "readmiss_30d_per_code.csv"
out_df = pl.DataFrame(results)
out_df.write_csv(out_fp)

⚠️ Potential issue | 🟡 Minor

Use Path join for output filename.

String concatenation can yield incorrect paths (missing separator). Prefer Path joining.

🛠️ Path join
-    out_fp = cfg.output_root + "readmiss_30d_per_code.csv"
+    out_fp = Path(cfg.output_root) / "readmiss_30d_per_code.csv"
     out_df = pl.DataFrame(results)
     out_df.write_csv(out_fp)
🤖 Prompt for AI Agents
In `@src/every_query/process_composite/get_per_code_from_composite.py` around
lines 64 - 66, The code builds the output filename via string concatenation
(out_fp = cfg.output_root + "readmiss_30d_per_code.csv"), which can create
invalid paths; change this to use Path joining by importing pathlib.Path (or
using cfg.output_root as a Path) and construct out_fp with Path(cfg.output_root)
/ "readmiss_30d_per_code.csv" before calling out_df.write_csv(out_fp), ensuring
out_fp is converted to a string if write_csv requires it.

Comment on lines +5 to +9
import hydra
import polars as pl
from omegaconf import DictConfig
from sklearn.metrics import roc_auc_score


⚠️ Potential issue | 🟠 Major



Move sklearn import into agg_probs function or add scikit-learn as a dependency.

The module-level import at line 8 will cause pytest collection to fail if sklearn is not installed, since it's not declared in pyproject.toml dependencies. Recommended: lazy import within agg_probs where it's actually used (line 72).

Additionally, the clip_sum_to_1 parameter at line 28 is unused in the function body and should be removed.

🛠️ Lazy import
 import hydra
 import polars as pl
 from omegaconf import DictConfig
-from sklearn.metrics import roc_auc_score


 def agg_probs(
     all_preds_df_fp: str,
     all_aces_labels_df_fp: str,
     agg_type: list[str],
     *,
     pred_prob_col: str = "occurs_probs",
     label_col: str = "boolean_value",
-    clip_sum_to_1: bool = False,
     return_auc: bool = True,
 ) -> pl.DataFrame | tuple[pl.DataFrame, dict[str, float]]:
+    from sklearn.metrics import roc_auc_score
+
🧰 Tools
🪛 GitHub Actions: Tests

[error] 8-8: ModuleNotFoundError: No module named 'sklearn' while importing process_composite.py during pytest collection.

🤖 Prompt for AI Agents
In `@src/every_query/process_composite/process_composite.py` around lines 5 - 9,
The module-level sklearn import causes test failures when scikit-learn isn't
installed and the unused parameter clip_sum_to_1 should be removed: move the
import "from sklearn.metrics import roc_auc_score" into the agg_probs function
(so it is lazily imported only when agg_probs is called) and remove the unused
parameter clip_sum_to_1 from agg_probs's signature and any callers; update any
internal references to roc_auc_score to use the local import within agg_probs
and run tests to ensure no other call sites expect clip_sum_to_1.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

🤖 Fix all issues with AI agents
In `@src/every_query/aces_to_eq/aces_to_eq.py`:
- Around line 26-28: The parameter target_rows on create_eq_task_df is unused
and should be either applied or removed to avoid dead configuration; either
implement sampling/truncation logic inside create_eq_task_df (e.g., after
loading the DataFrame apply .sample(n=target_rows) or .head(target_rows) /
otherwise enforce a row limit) or delete the target_rows parameter from the
create_eq_task_df signature and update all callers to stop passing it (also
remove the same unused parameter in the related functions referenced around
lines 79-85); ensure you update tests and call sites to match the chosen
approach.

In `@src/every_query/eval_suite/gen_task.py`:
- Around line 11-14: Replace the SHA-1 usage in code slug generation by
switching hashlib.sha1(...) to hashlib.sha256(...) and keep the existing
hexdigest slicing logic (still use n_hash) in the code_slug function; update
every other implementation of code_slug in the listed files
(src/every_query/eval.py,
src/every_query/process_composite/process_composite.py,
src/every_query/process_composite/get_per_code_from_composite.py,
src/every_query/process_composite/eval_composite.py,
src/every_query/aces_to_eq/aces_to_eq.py) so they use hashlib.sha256(...) and
produce the same formatted prefix/h suffix behavior (prefix generation via
re.sub and prefix_len unchanged); be aware this will change existing task
directory names and run any relevant tests or downstream references after making
the change.
- Line 39: Replace the assert [p.name for p in index_shards] == [p.name for p in
all_shards] with explicit validation that compares the two lists and raises a
clear exception (e.g., ValueError) when they differ; use the variables
index_shards and all_shards and include the expected and actual shard name lists
(or a concise diff) in the error message so the mismatch cannot be skipped under
python -O and is actionable in production.

In `@src/every_query/sample_codes/sample_train_codes.py`:
- Around line 31-42: The top-level script logic (reading PARQUET_PATH, computing
codes/time_codes/filtered_codes, prints and os.makedirs(OUT_DIR)) currently runs
on import; wrap this work in a guarded main routine so imports don’t trigger
file I/O/prints/dir creation: create a main() function (or reuse an existing
stable_hash_list() wrapper if intended) that contains the df =
pl.read_parquet(...), code filtering, prints and os.makedirs(OUT_DIR,
exist_ok=True), then call that function only inside if __name__ == "__main__":
so unit tests and imports can use functions like stable_hash_list() without side
effects.
- Around line 56-57: The ID hash is order-sensitive because id_codes is unsorted
while ood_codes is sorted; to make identity deterministic sort id_codes before
hashing. Locate the code that calls stable_hash_list on id_codes and ood_codes
(symbols: id_codes, ood_codes, stable_hash_list, and filtered_codes) and ensure
you sort id_codes (e.g., stable sort of elements or canonicalize by content)
prior to calling stable_hash_list so both ID and OOD hashes are derived from
content-only ordering.

In `@src/every_query/train.py`:
- Around line 87-90: The current task hashing uses only the code set (task_str)
so collated directories can be reused even when sampling parameters change;
update the string used to compute hash_hex to also include sampling-related
config values (e.g., cfg.query.sample_times_per_subject and cfg.query.seed) and
ensure any collection (like sample_times_per_subject if it's a list) is
deterministically ordered/serialized before joining—then recompute hash_hex from
that extended task string so write_dir reflects sampling config changes.
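
One way to fold the sampling config into the cache key (a sketch; the helper name is made up, and `sample_times_per_subject` and `seed` stand in for the `cfg.query.*` values named in the prompt above):

```python
import hashlib

def task_cache_key(codes, sample_times_per_subject, seed):
    # Deterministically serialize the code set plus the sampling-related
    # config, so the collated directory is not reused when sampling changes.
    parts = sorted(codes)  # order-independent over the code set
    parts.append(f"sample_times_per_subject={sample_times_per_subject!r}")
    parts.append(f"seed={seed}")
    return hashlib.md5("|".join(parts).encode("utf-8")).hexdigest()
```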
🧹 Nitpick comments (4)
src/every_query/sample_codes/sample_train_codes.py (2)

62-64: Manual YAML writing may not handle special characters correctly.

If any code string contains YAML special characters (e.g., :, #, {, }, or leading */&), the output may produce invalid YAML. Consider using yaml.safe_dump() from PyYAML for proper escaping.

🔧 Suggested improvement
+import yaml
+
     # ---- write ID file ----
     id_path = f"{OUT_DIR}/{N_SAMPLES}_ID__{id_hash}.yaml"
     with open(id_path, "x") as f:
-        f.write("codes:\n")
-        for code in id_codes:
-            f.write(f'  - "{code}"\n')
+        yaml.safe_dump({"codes": id_codes}, f, default_flow_style=False)

66-78: Consider removing or enabling the commented-out OOD logic.

Large blocks of commented-out code can clutter the codebase. If OOD file generation is needed, enable it; otherwise, consider removing it and restoring via version control if needed later.

src/every_query/eval_suite/gen_index_times.py (2)

58-62: Preserve shuffle order before per‑subject head.

Polars group_by does not preserve input order by default, so .head() may not reflect the seeded shuffle. Consider maintain_order=True (or a per‑group sample) to keep deterministic sampling.

🔁 Keep shuffled order within groups
-            .group_by("subject_id")
+            .group_by("subject_id", maintain_order=True)

15-20: Use SHA‑256 instead of MD5 if organizational policy requires avoiding MD5.

The function uses MD5 purely for generating a deterministic directory identifier for caching output—not for cryptographic purposes. While MD5 is acceptable for non-crypto uses like config hashing and file naming, some organizations restrict MD5 use for compliance reasons. If such a policy applies, switch to SHA‑256; otherwise, this is fine. Note that src/every_query/train.py uses MD5 similarly for task identification, so any change should be applied consistently across the codebase.
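
For the code_slug helper discussed above, a SHA‑256 variant could look like the following (a sketch only; the real signature and the prefix_len/n_hash defaults live in gen_task.py and are assumptions here):

```python
import hashlib
import re

def code_slug(code: str, prefix_len: int = 32, n_hash: int = 8) -> str:
    # Filesystem-safe prefix plus a short SHA-256 digest. Note that switching
    # the hash changes existing task directory names, as the review notes.
    prefix = re.sub(r"[^A-Za-z0-9]+", "_", code)[:prefix_len]
    digest = hashlib.sha256(code.encode("utf-8")).hexdigest()[:n_hash]
    return f"{prefix}_{digest}"
```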

Comment on lines +26 to +28
def create_eq_task_df(
    eq_shard_fp: str, aces_shard_fp: str, codes: list[str], output_root: str, target_rows: int
) -> int:

⚠️ Potential issue | 🟡 Minor

Remove or use target_rows to avoid dead configuration.

target_rows is passed but never used, which is misleading and may mask missing sampling logic. Either apply it (e.g., limit rows) or remove it from the signature and call sites.

🔧 Suggested fix (remove if unused)
-def create_eq_task_df(
-    eq_shard_fp: str, aces_shard_fp: str, codes: list[str], output_root: str, target_rows: int
-) -> int:
+def create_eq_task_df(
+    eq_shard_fp: str, aces_shard_fp: str, codes: list[str], output_root: str
+) -> int:
         create_eq_task_df(
             eq_shard_fp=str(eq_shard_fp),
             aces_shard_fp=str(aces_shard_fp),
             codes=cfg.queries,
             output_root=cfg.output_dir,
-            target_rows=cfg.target_rows_per_shard,
         )

Also applies to: 79-85

🧰 Tools
🪛 Ruff (0.14.14)

[warning] 27-27: Unused function argument: target_rows

(ARG001)

🤖 Prompt for AI Agents
In `@src/every_query/aces_to_eq/aces_to_eq.py` around lines 26 - 28, The parameter
target_rows on create_eq_task_df is unused and should be either applied or
removed to avoid dead configuration; either implement sampling/truncation logic
inside create_eq_task_df (e.g., after loading the DataFrame apply
.sample(n=target_rows) or .head(target_rows) / otherwise enforce a row limit) or
delete the target_rows parameter from the create_eq_task_df signature and update
all callers to stop passing it (also remove the same unused parameter in the
related functions referenced around lines 79-85); ensure you update tests and
call sites to match the chosen approach.

index_shards = list_parquets(index_dir)
all_shards = list_parquets(all_dir)

assert [p.name for p in index_shards] == [p.name for p in all_shards]

⚠️ Potential issue | 🟠 Major



Replace assert with explicit validation for shard alignment.

Line 39 uses assert for a critical data integrity check, which can be stripped with python -O, risking silent misalignment between index and all shards. Raise an explicit error instead to ensure this precondition is always enforced in production.

✅ Use an explicit validation error
-    assert [p.name for p in index_shards] == [p.name for p in all_shards]
+    if [p.name for p in index_shards] != [p.name for p in all_shards]:
+        raise ValueError("Index and all-task shard filenames are misaligned")
🤖 Prompt for AI Agents
In `@src/every_query/eval_suite/gen_task.py` at line 39, Replace the assert
[p.name for p in index_shards] == [p.name for p in all_shards] with explicit
validation that compares the two lists and raises a clear exception (e.g.,
ValueError) when they differ; use the variables index_shards and all_shards and
include the expected and actual shard name lists (or a concise diff) in the
error message so the mismatch cannot be skipped under python -O and is
actionable in production.

Comment on lines +31 to +42
df = pl.read_parquet(PARQUET_PATH)
codes = df["code"].unique().sort().to_list()
print(f"num all codes {len(codes)}")

time_codes = [code for code in codes if "TIME" in code]
print(f"{len(time_codes)} TIME Codes removed:")
print(time_codes)

filtered_codes = [code for code in codes if "TIME" not in code]
print(f"num codes after filtering: {len(filtered_codes)}")

os.makedirs(OUT_DIR, exist_ok=True)

🛠️ Refactor suggestion | 🟠 Major

Wrap script logic in if __name__ == "__main__": guard.

All code from line 31 onward executes immediately when the module is imported. This:

  • Prevents reusing stable_hash_list() without triggering file I/O and directory creation
  • Makes unit testing difficult
  • Causes side effects on import (file reads, directory creation, prints)
🛠️ Suggested refactor
+def main():
+    # -------------------
+    # Load + filter codes
+    # -------------------
     df = pl.read_parquet(PARQUET_PATH)
     codes = df["code"].unique().sort().to_list()
     # ... rest of script logic ...
+
+if __name__ == "__main__":
+    main()
📝 Committable suggestion


Suggested change
df = pl.read_parquet(PARQUET_PATH)
codes = df["code"].unique().sort().to_list()
print(f"num all codes {len(codes)}")
time_codes = [code for code in codes if "TIME" in code]
print(f"{len(time_codes)} TIME Codes removed:")
print(time_codes)
filtered_codes = [code for code in codes if "TIME" not in code]
print(f"num codes after filtering: {len(filtered_codes)}")
os.makedirs(OUT_DIR, exist_ok=True)
def main():
    # -------------------
    # Load + filter codes
    # -------------------
    df = pl.read_parquet(PARQUET_PATH)
    codes = df["code"].unique().sort().to_list()
    print(f"num all codes {len(codes)}")

    time_codes = [code for code in codes if "TIME" in code]
    print(f"{len(time_codes)} TIME Codes removed:")
    print(time_codes)

    filtered_codes = [code for code in codes if "TIME" not in code]
    print(f"num codes after filtering: {len(filtered_codes)}")

    os.makedirs(OUT_DIR, exist_ok=True)


if __name__ == "__main__":
    main()
🤖 Prompt for AI Agents
In `@src/every_query/sample_codes/sample_train_codes.py` around lines 31 - 42, The
top-level script logic (reading PARQUET_PATH, computing
codes/time_codes/filtered_codes, prints and os.makedirs(OUT_DIR)) currently runs
on import; wrap this work in a guarded main routine so imports don’t trigger
file I/O/prints/dir creation: create a main() function (or reuse an existing
stable_hash_list() wrapper if intended) that contains the df =
pl.read_parquet(...), code filtering, prints and os.makedirs(OUT_DIR,
exist_ok=True), then call that function only inside if __name__ == "__main__":
so unit tests and imports can use functions like stable_hash_list() without side
effects.

Comment on lines +56 to +57
id_hash = stable_hash_list(id_codes)
ood_hash = stable_hash_list(ood_codes)

⚠️ Potential issue | 🟠 Major

Hash ordering inconsistency between ID and OOD sets.

id_codes is in random sample order, while ood_codes preserves the sorted order from filtered_codes. Since stable_hash_list is order-sensitive, two runs that sample the same codes in a different order will produce different ID hashes. This breaks the "identity" guarantee mentioned in the comment on line 55.

For deterministic, content-based identity, sort id_codes before hashing:

🔧 Proposed fix
     # Hashes define identity of the code universes
-    id_hash = stable_hash_list(id_codes)
+    id_hash = stable_hash_list(sorted(id_codes))
     ood_hash = stable_hash_list(ood_codes)
🤖 Prompt for AI Agents
In `@src/every_query/sample_codes/sample_train_codes.py` around lines 56 - 57, The
ID hash is order-sensitive because id_codes is unsorted while ood_codes is
sorted; to make identity deterministic sort id_codes before hashing. Locate the
code that calls stable_hash_list on id_codes and ood_codes (symbols: id_codes,
ood_codes, stable_hash_list, and filtered_codes) and ensure you sort id_codes
(e.g., stable sort of elements or canonicalize by content) prior to calling
stable_hash_list so both ID and OOD hashes are derived from content-only
ordering.
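The fix can be sketched end to end. This is an illustrative stand-in for `stable_hash_list` (the repository's actual implementation is not shown in this thread, and the code strings are made up): sorting before hashing makes the digest a function of the code set's content rather than its sampling order.

```python
import hashlib


def stable_hash_list(codes):
    # Sort first so the digest depends only on the set of codes,
    # not on the (random) order in which they were sampled.
    joined = "|".join(sorted(codes))
    return hashlib.sha256(joined.encode("utf-8")).hexdigest()[:12]


# Two samples of the same codes in different orders agree...
a = stable_hash_list(["ICD//E11", "LAB//HBA1C", "RX//METFORMIN"])
b = stable_hash_list(["RX//METFORMIN", "ICD//E11", "LAB//HBA1C"])
# ...while a different code set does not.
c = stable_hash_list(["ICD//E11", "LAB//HBA1C"])
```

With this canonicalization inside the helper, callers no longer need to remember to sort before hashing.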

Comment on lines +87 to 90
task_str = f"{'|'.join(sorted(cfg.query.codes))}"
hash_hex = hashlib.md5(task_str.encode()).hexdigest()
write_dir = f"{cfg.query.task_dir}/collated/{hash_hex}"

⚠️ Potential issue | 🟠 Major

Include sampling config in the task hash to avoid stale collations.

task_str only hashes the code set, but sample_times_per_subject (and seed) change the generated shards. That can reuse an old collated directory with different sampling settings. Consider including these in the hash key to prevent incorrect reuse.

🔧 Suggested fix
-    task_str = f"{'|'.join(sorted(cfg.query.codes))}"
+    task_str = "|".join(
+        [
+            ",".join(sorted(cfg.query.codes)),
+            f"sps={cfg.query.sample_times_per_subject}",
+            f"seed={cfg.get('seed', 1)}",
+        ]
+    )
🧰 Tools
🪛 Ruff (0.14.14)

[error] 88-88: Probable use of insecure hash functions in hashlib: md5

(S324)

🤖 Prompt for AI Agents
In `@src/every_query/train.py` around lines 87 - 90, The current task hashing uses
only the code set (task_str) so collated directories can be reused even when
sampling parameters change; update the string used to compute hash_hex to also
include sampling-related config values (e.g., cfg.query.sample_times_per_subject
and cfg.query.seed) and ensure any collection (like sample_times_per_subject if
it's a list) is deterministically ordered/serialized before joining—then
recompute hash_hex from that extended task string so write_dir reflects sampling
config changes.
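A hypothetical `task_key` helper shows the idea (the name and fields are illustrative, not the repository's API): fold the sampling settings into the hashed string so any change to them yields a new collated directory. Using sha256 here also sidesteps the Ruff S324 warning about md5 flagged above.

```python
import hashlib


def task_key(codes, sample_times_per_subject, seed):
    # Hash the sorted code set plus the sampling settings, so a change
    # to sampling never silently reuses a stale collated directory.
    parts = [
        ",".join(sorted(codes)),
        f"sps={sample_times_per_subject}",
        f"seed={seed}",
    ]
    return hashlib.sha256("|".join(parts).encode("utf-8")).hexdigest()


same = task_key(["B", "A"], 4, 1) == task_key(["A", "B"], 4, 1)      # order-insensitive
differs = task_key(["A", "B"], 4, 1) != task_key(["A", "B"], 8, 1)   # sampling-sensitive
```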

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@src/every_query/process_composite/process_composite.py`:
- Around line 102-107: The code may fail when cfg.output_path's directory does
not exist before calling out_df.write_csv; update the write path handling around
out_fp in process_composite.py to ensure the parent directory exists by creating
it (use out_fp.parent.mkdir(parents=True, exist_ok=True)) before calling
out_df.write_csv, while preserving the existing overwrite guard that checks
out_fp.exists() and cfg.do_overwrite.
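The pattern the fix describes is small enough to show in isolation (the paths here are throwaway temp files, not the config's real output_path):

```python
import csv
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    # Nested output path whose parent directories do not exist yet,
    # as can happen with cfg.output_path on a fresh run.
    out_fp = Path(tmp) / "results" / "per_code" / "metrics.csv"
    # parents=True creates the whole chain; exist_ok=True keeps it idempotent.
    out_fp.parent.mkdir(parents=True, exist_ok=True)
    with out_fp.open("w", newline="") as f:
        csv.writer(f).writerow(["code", "auc"])
    wrote = out_fp.exists()
```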
🧹 Nitpick comments (1)
src/every_query/process_composite/process_composite.py (1)

12-19: Consider extracting shared utilities to a common module.

values_as_list and code_slug are duplicated in at least three files:

  • src/every_query/process_composite/get_per_code_from_composite.py
  • src/every_query/eval.py
  • This file

Additionally, values_as_list appears unused in this file.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@src/every_query/eval_composite.py`:
- Around line 52-53: cfg.query_codes may be None or not a list so iterating into
"codes" can raise; before using cfg.query_codes (assigned to the local variable
codes) validate and coerce it to a safe default (e.g., if not
isinstance(cfg.query_codes, list) or cfg.query_codes is None then set codes = []
or try to cast iterable to list) and log or raise a clear error if the value is
invalid; update the assignment around the "codes: list[str] = cfg.query_codes"
line to perform this check so subsequent iteration over codes and the creation
of rows is safe.
- Around line 67-71: pred_batches (the result of trainer.predict) may be None,
empty, or contain batch dicts missing keys which will make torch.cat(...) fail;
add defensive validation after calling trainer.predict(M, D,
ckpt_path=cfg.ckpt_path) to (1) check pred_batches is not None and has length>0,
(2) verify each batch is a dict and contains the keys "subject_id",
"prediction_time", and "occurs_probs" (or skip/log/raise a clear error if any
are missing), and (3) optionally ensure the values are tensors before calling
torch.cat; update the code around pred_batches, and the subsequent concatenation
that produces subject_id, prediction_time, and occurs_probs to use these
validated/normalized batches or bail with a helpful error message.
🧹 Nitpick comments (4)
src/every_query/process_composite/process_composite.py (1)

44-46: Comment mentions clipping but no clipping is implemented.

The comment states "Clip inside to avoid negatives from float error" but no clipping is applied. Floating-point errors could cause 1 - p to go slightly negative or the result to exceed 1.

🔧 Consider adding explicit clipping
     if "or" in agg_type_set:
-        # 1 - prod(1 - p). Clip inside to avoid negatives from float error.
-        aggs.append((1.0 - (1.0 - pl.col(pred_prob_col)).product()).alias("prob_or"))
+        # 1 - prod(1 - p). Clip to avoid negatives from float error.
+        aggs.append(
+            (1.0 - (1.0 - pl.col(pred_prob_col).clip(0, 1)).product())
+            .clip(0, 1)
+            .alias("prob_or")
+        )
src/every_query/eval.py (1)

18-19: Unused helper function values_as_list.

This function is defined but never used within this file. Consider removing it or moving it to a shared utilities module if needed elsewhere.

src/every_query/model.py (1)

110-112: Consider specifying dimension in squeeze() to avoid unexpected behavior.

Using squeeze() without a dimension argument removes all dimensions of size 1. With batch_size=1, this could unexpectedly reduce a (1,) tensor to a scalar (), which may cause shape mismatches downstream.

🔧 Proposed fix
     @staticmethod
     def logits_to_probs(logits: torch.Tensor) -> torch.Tensor:
-        return torch.sigmoid(logits).squeeze()
+        return torch.sigmoid(logits).squeeze(-1)
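The behavioral difference is easy to demonstrate; numpy's squeeze has the same semantics as torch's here, so this sketch avoids a torch dependency:

```python
import numpy as np

# With batch_size=1 the sigmoid output has shape (1, 1):
# one subject, one logit.
probs = np.full((1, 1), 0.5)

all_axes = np.squeeze(probs)       # drops every size-1 axis -> shape ()
last_axis = np.squeeze(probs, -1)  # drops only the trailing axis -> shape (1,)
```

Downstream code expecting a 1-D per-subject vector breaks on the scalar but not on the `(1,)` result, which is why pinning the axis matters.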
src/every_query/eval_composite.py (1)

18-19: Remove unused duplicate helper function.

values_as_list is defined identically in both eval.py and eval_composite.py but is not used in either file. Consider removing it from both files.

Comment on lines +52 to +53
codes: list[str] = cfg.query_codes
rows = []
⚠️ Potential issue | 🟡 Minor

Validate cfg.query_codes before iteration.

If cfg.query_codes is None or not a list, line 55 will raise a TypeError when iterating. Add validation or a safe default.

🛠️ Proposed fix
-    codes: list[str] = cfg.query_codes
+    codes: list[str] = cfg.query_codes or []
+    if not codes:
+        logger.warning("No query_codes provided. Exiting.")
+        return
+
     rows = []
🤖 Prompt for AI Agents
In `@src/every_query/eval_composite.py` around lines 52 - 53, cfg.query_codes may
be None or not a list so iterating into "codes" can raise; before using
cfg.query_codes (assigned to the local variable codes) validate and coerce it to
a safe default (e.g., if not isinstance(cfg.query_codes, list) or
cfg.query_codes is None then set codes = [] or try to cast iterable to list) and
log or raise a clear error if the value is invalid; update the assignment around
the "codes: list[str] = cfg.query_codes" line to perform this check so
subsequent iteration over codes and the creation of rows is safe.

Comment on lines +67 to +71
pred_batches = trainer.predict(model=M, datamodule=D, ckpt_path=cfg.ckpt_path)

subject_id = torch.cat([b["subject_id"] for b in pred_batches]).numpy()
prediction_time = torch.cat([b["prediction_time"] for b in pred_batches]).numpy()
occurs_probs = torch.cat([b["occurs_probs"] for b in pred_batches]).numpy()
⚠️ Potential issue | 🟡 Minor

Add validation for prediction batches.

If trainer.predict returns an empty list or None, or if the batch dictionaries don't contain the expected keys (subject_id, prediction_time, occurs_probs), this code will raise an error.

🛡️ Proposed defensive check
         pred_batches = trainer.predict(model=M, datamodule=D, ckpt_path=cfg.ckpt_path)
 
+        if not pred_batches:
+            logger.warning(f"No predictions returned for code={code} (skipping)")
+            continue
+
         subject_id = torch.cat([b["subject_id"] for b in pred_batches]).numpy()
         prediction_time = torch.cat([b["prediction_time"] for b in pred_batches]).numpy()
         occurs_probs = torch.cat([b["occurs_probs"] for b in pred_batches]).numpy()
🤖 Prompt for AI Agents
In `@src/every_query/eval_composite.py` around lines 67 - 71, pred_batches (the
result of trainer.predict) may be None, empty, or contain batch dicts missing
keys which will make torch.cat(...) fail; add defensive validation after calling
trainer.predict(M, D, ckpt_path=cfg.ckpt_path) to (1) check pred_batches is not
None and has length>0, (2) verify each batch is a dict and contains the keys
"subject_id", "prediction_time", and "occurs_probs" (or skip/log/raise a clear
error if any are missing), and (3) optionally ensure the values are tensors
before calling torch.cat; update the code around pred_batches, and the
subsequent concatenation that produces subject_id, prediction_time, and
occurs_probs to use these validated/normalized batches or bail with a helpful
error message.
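One way to centralize the checks described above (a sketch; the function name and error policy are assumptions, and numpy arrays stand in for torch tensors):

```python
import numpy as np

REQUIRED_KEYS = ("subject_id", "prediction_time", "occurs_probs")


def collect_predictions(pred_batches):
    # Fail fast with a readable message instead of letting torch.cat
    # raise opaquely on None, empty, or malformed batches.
    if not pred_batches:
        raise ValueError("trainer.predict returned no batches")
    for i, batch in enumerate(pred_batches):
        missing = [k for k in REQUIRED_KEYS if k not in batch]
        if missing:
            raise KeyError(f"prediction batch {i} missing keys: {missing}")
    return {
        k: np.concatenate([np.asarray(b[k]) for b in pred_batches])
        for k in REQUIRED_KEYS
    }


batches = [
    {"subject_id": [1, 2], "prediction_time": [0.0, 1.0], "occurs_probs": [0.3, 0.7]},
    {"subject_id": [3], "prediction_time": [2.0], "occurs_probs": [0.9]},
]
preds = collect_predictions(batches)
```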

@payalchandak payalchandak self-requested a review February 10, 2026 20:38
@payalchandak payalchandak merged commit bf9b074 into main Feb 10, 2026
5 of 6 checks passed
@gkondas gkondas deleted the greg_testing branch March 11, 2026 01:08