[REFACTOR] Refactor training setup to use TrainingDataAdapter #200

@jeipollack

Description of Refactoring/Improvement
Update get_loss_metrics_monitor_and_outputs() and related training setup code to use the new TrainingDataAdapter interface (from Issue #199) instead of directly accessing DataConfigHandler internals. This completes the decoupling of training logic from dataset implementation details.

Goals and Objectives

  • Eliminate hardcoded dataset field access from training logic
  • Simplify get_loss_metrics_monitor_and_outputs() signature and implementation
  • Enable clean support for both simulation and real data workflows
  • Improve code readability and maintainability

Current Code Behavior
get_loss_metrics_monitor_and_outputs() currently:

  • Accepts data_conf (entire DataConfigHandler object)
  • Contains branching logic based on data_conf.training_data.dataset structure
  • Hardcodes field names like "noisy_stars", "stars", "masks"
  • Cannot easily support alternative data sources
def get_loss_metrics_monitor_and_outputs(training_handler, data_conf):
    # ... complex conditional logic ...
    outputs = data_conf.training_data.dataset["noisy_stars"]  # Brittle

Proposed Changes

  1. Update function signature:
def get_loss_metrics_monitor_and_outputs(
    training_handler,
    data_adapter: TrainingDataAdapter  # Clean interface
):
  2. Simplify implementation:
def get_loss_metrics_monitor_and_outputs(training_handler, data_adapter):
    use_masks = training_handler.training_hparams.loss == "mask_mse"

    # Clean adapter calls (no dataset structure knowledge needed)
    outputs = data_adapter.get_training_outputs(use_masks=use_masks)
    output_val = data_adapter.get_validation_outputs(use_masks=use_masks)

    # Loss/metrics configuration unchanged
    if use_masks:
        loss = train_utils.MaskedMeanSquaredError()
        # ...
  3. Update calling code in training scripts:
# Before:
data_config = DataConfigHandler(conf_path, model_params, batch_size=32)
loss, metrics, ... = get_loss_metrics_monitor_and_outputs(
    training_handler, 
    data_config  # Heavy object
)

# After:
data_config = DataConfigHandler(conf_path, model_params)
data_adapter = DataAdapterFactory.create_from_config(data_config)
loss, metrics, ... = get_loss_metrics_monitor_and_outputs(
    training_handler,
    data_adapter  # Focused interface
)
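
For illustration only, here is a minimal sketch of the kind of interface the proposed changes rely on. The real TrainingDataAdapter is defined in Issue #199 and may differ. The method names and the use_masks keyword come from the snippets above. Everything else is an assumption made for the example: the Protocol definition, the hypothetical DictBackedAdapter class, the select_outputs helper, and the choice to return targets paired with masks when use_masks is true.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class TrainingDataAdapter(Protocol):
    """Assumed shape of the interface; the real one is defined in Issue #199."""

    def get_training_outputs(self, use_masks: bool = False) -> Any: ...

    def get_validation_outputs(self, use_masks: bool = False) -> Any: ...


class DictBackedAdapter:
    """Hypothetical adapter that hides dataset field names from callers."""

    def __init__(self, train: dict, validation: dict):
        self._train = train
        self._validation = validation

    @staticmethod
    def _select(dataset: dict, use_masks: bool):
        # Assumption: a masked loss needs the targets paired with their masks.
        if use_masks:
            return dataset["noisy_stars"], dataset["masks"]
        return dataset["noisy_stars"]

    def get_training_outputs(self, use_masks: bool = False):
        return self._select(self._train, use_masks)

    def get_validation_outputs(self, use_masks: bool = False):
        return self._select(self._validation, use_masks)


def select_outputs(loss_name: str, data_adapter: TrainingDataAdapter):
    """Mirror of the adapter calls in the proposed implementation."""
    use_masks = loss_name == "mask_mse"
    outputs = data_adapter.get_training_outputs(use_masks=use_masks)
    output_val = data_adapter.get_validation_outputs(use_masks=use_masks)
    return outputs, output_val
```

With this shape, the field names "noisy_stars" and "masks" appear only inside the adapter, so the training code never touches the dataset layout directly.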

Expected Benefits

  • Simplification: Remove ~15 lines of conditional dataset access logic
  • Clarity: Function dependencies are made explicit through type hints
  • Testability: Can test with mock adapters (no need for full DataConfigHandler setup)
  • Flexibility: Switching data sources requires changing only the adapter, not the training code
  • Type Safety: IDE autocomplete and type checking work properly

Dependencies
Requires: Issue #199 (TrainingDataAdapter implementation) to be completed first
Affects:

  • Main training script(s) that call get_loss_metrics_monitor_and_outputs()
  • Any test code that uses this function
  • Documentation/examples showing training setup

Testing Plan

  1. Unit tests for refactored function:
     • Test with mock TrainingDataAdapter (verify it calls correct methods)
     • Test both masked and unmasked loss configurations
     • Verify correct loss/metrics objects returned
  2. Integration tests:
     • Test with real PreSplitDataAdapter (simulation workflow)
     • Test with real SplitOnLoadDataAdapter (real data workflow)
     • Verify training loop works end-to-end with both adapter types
  3. Regression tests:
  • Ensure existing simulation training produces identical results
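
The mock-based unit tests in the plan above could look roughly like the sketch below. It uses only unittest.mock from the standard library. The function get_outputs_for_loss is a hypothetical stand-in for the adapter-facing part of get_loss_metrics_monitor_and_outputs(), and the "mse" loss name and the helper names are assumptions for the example, not project code.

```python
from unittest.mock import Mock


def get_outputs_for_loss(loss_name, data_adapter):
    """Stand-in for the adapter-facing part of the refactored function."""
    use_masks = loss_name == "mask_mse"
    outputs = data_adapter.get_training_outputs(use_masks=use_masks)
    output_val = data_adapter.get_validation_outputs(use_masks=use_masks)
    return outputs, output_val


def check_adapter_calls(loss_name, expected_use_masks):
    # A spec list restricts the mock to the two adapter methods, so any
    # access to dataset internals raises AttributeError.
    adapter = Mock(spec=["get_training_outputs", "get_validation_outputs"])
    adapter.get_training_outputs.return_value = "train"
    adapter.get_validation_outputs.return_value = "val"

    result = get_outputs_for_loss(loss_name, adapter)

    adapter.get_training_outputs.assert_called_once_with(
        use_masks=expected_use_masks
    )
    adapter.get_validation_outputs.assert_called_once_with(
        use_masks=expected_use_masks
    )
    return result


def test_masked_loss_requests_masks():
    assert check_adapter_calls("mask_mse", True) == ("train", "val")


def test_unmasked_loss_skips_masks():
    assert check_adapter_calls("mse", False) == ("train", "val")
```

No DataConfigHandler is constructed here, which is the testability benefit the issue describes.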

Additional Context
This refactoring is the second step in modernizing the data handling architecture. Issue #199 introduces the adapter abstraction, while this issue migrates the training code to actually use it.
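Breaking change consideration: The second parameter of get_loss_metrics_monitor_and_outputs() changes from a DataConfigHandler to a TrainingDataAdapter, so existing callers will break until they are updated.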

Impact Assessment
Medium impact, high value:

  • Breaking change to get_loss_metrics_monitor_and_outputs() signature
  • All code calling this function needs updating (estimate: 3-5 call sites)
  • Significantly improves code quality and maintainability
  • Unblocks future work on real data training pipeline
  • Sets precedent for adapter pattern usage elsewhere in codebase

Migration effort: ~2-4 hours for implementation + testing

Next Steps

  1. Verify that Issue #199 ([REFACTOR] Add TrainingDataAdapter abstraction for simulation data) is merged and adapters are available
  2. Update get_loss_metrics_monitor_and_outputs() signature and implementation
  3. Update all calling code in training scripts
  4. Update tests to use new signature (or test both if using shim)
  5. Update documentation and examples
  6. Optional: Add deprecation shim if backward compatibility needed
  7. Run full regression test suite
  8. Update CHANGELOG with migration guide if breaking change
  9. Optional follow-up: Clean up DataConfigHandler to remove now-unused attributes (separate issue)
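
If the optional deprecation shim from step 6 is wanted, one possible shape is sketched below. This is an assumption about how such a shim could work, not an agreed design. The helper adapt_legacy_argument is hypothetical, and adapter_factory stands for whatever callable builds an adapter from a config object (for example DataAdapterFactory.create_from_config from the snippet above).

```python
import warnings


def adapt_legacy_argument(data_source, adapter_factory):
    """Return an adapter, converting legacy config objects with a warning.

    Hypothetical helper: objects exposing both adapter methods are passed
    through unchanged; anything else is treated as a legacy config object.
    """
    has_adapter_api = all(
        callable(getattr(data_source, name, None))
        for name in ("get_training_outputs", "get_validation_outputs")
    )
    if has_adapter_api:
        return data_source

    warnings.warn(
        "Passing a DataConfigHandler is deprecated; "
        "pass a TrainingDataAdapter instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return adapter_factory(data_source)
```

Calling this helper at the top of get_loss_metrics_monitor_and_outputs() would keep old call sites working for a transition period while flagging them for migration.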

Labels: enhancement