Description
Description of Refactoring/Improvement
Update get_loss_metrics_monitor_and_outputs() and related training setup code to use the new TrainingDataAdapter interface (from Issue #199) instead of directly accessing DataConfigHandler internals. This completes the decoupling of training logic from dataset implementation details.
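For readers who have not followed Issue #199, a minimal sketch of the adapter contract may help. It assumes an abstract base class that exposes only the two accessors the proposed training code calls (get_training_outputs and get_validation_outputs). The real interface may expose more methods or use a different mechanism, and the return types are deliberately left open. InMemoryAdapter is a hypothetical toy implementation and is not part of the codebase.

```python
from abc import ABC, abstractmethod
from typing import Any


class TrainingDataAdapter(ABC):
    """Hypothetical sketch of the Issue #199 interface.

    Only the two accessors used by the refactored training code are shown.
    Return types are Any because the real types are not specified here.
    """

    @abstractmethod
    def get_training_outputs(self, use_masks: bool) -> Any:
        """Return training targets, bundled with masks when use_masks is True."""

    @abstractmethod
    def get_validation_outputs(self, use_masks: bool) -> Any:
        """Return validation targets, bundled with masks when use_masks is True."""


class InMemoryAdapter(TrainingDataAdapter):
    """Hypothetical toy adapter backed by plain Python lists."""

    def __init__(self, train, val, train_masks=None, val_masks=None):
        self._train, self._val = train, val
        self._train_masks, self._val_masks = train_masks, val_masks

    def get_training_outputs(self, use_masks: bool) -> Any:
        # Pair targets with masks only when the masked loss needs them.
        return (self._train, self._train_masks) if use_masks else self._train

    def get_validation_outputs(self, use_masks: bool) -> Any:
        return (self._val, self._val_masks) if use_masks else self._val
```

Because both methods are abstract, a subclass that forgets one of them fails at instantiation rather than midway through training.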
Goals and Objectives
- Eliminate hardcoded dataset field access from training logic
- Simplify the get_loss_metrics_monitor_and_outputs() signature and implementation
- Enable clean support for both simulation and real data workflows
- Improve code readability and maintainability
Current Code Behavior
get_loss_metrics_monitor_and_outputs() currently:
- Accepts data_conf (the entire DataConfigHandler object)
- Contains branching logic based on the structure of data_conf.training_data.dataset
- Hardcodes field names like "noisy_stars", "stars", and "masks"
- Cannot easily support alternative data sources
```python
def get_loss_metrics_monitor_and_outputs(training_handler, data_conf):
    # ... complex conditional logic ...
    outputs = data_conf.training_data.dataset["noisy_stars"]  # Brittle
```

Proposed Changes
- Update the function signature:

```python
def get_loss_metrics_monitor_and_outputs(
    training_handler,
    data_adapter: TrainingDataAdapter,  # Clean interface
):
    ...
```
- Simplify the implementation:

```python
def get_loss_metrics_monitor_and_outputs(training_handler, data_adapter):
    use_masks = training_handler.training_hparams.loss == "mask_mse"

    # Clean adapter calls (no dataset structure knowledge needed)
    outputs = data_adapter.get_training_outputs(use_masks=use_masks)
    output_val = data_adapter.get_validation_outputs(use_masks=use_masks)

    # Loss/metrics configuration unchanged
    if use_masks:
        loss = train_utils.MaskedMeanSquaredError()
        # ...
```
- Update the calling code in the training scripts:

```python
# Before:
data_config = DataConfigHandler(conf_path, model_params, batch_size=32)
loss, metrics, ... = get_loss_metrics_monitor_and_outputs(
    training_handler,
    data_config,  # Heavy object
)

# After:
data_config = DataConfigHandler(conf_path, model_params)
data_adapter = DataAdapterFactory.create_from_config(data_config)
loss, metrics, ... = get_loss_metrics_monitor_and_outputs(
    training_handler,
    data_adapter,  # Focused interface
)
```
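The proposed implementation and call-site change can be exercised end to end with stand-ins. In this sketch, MaskedMeanSquaredError and MeanSquaredError are hypothetical placeholders for the real train_utils losses, "mse" is a hypothetical name for the unmasked loss setting, and DictAdapter is a hypothetical adapter over plain dicts that reuses the "noisy_stars" and "masks" field names from the current code. The function returns only (loss, outputs, output_val) because the real metrics and monitor setup is not shown in this issue.

```python
from types import SimpleNamespace


class MaskedMeanSquaredError:
    """Hypothetical stand-in for train_utils.MaskedMeanSquaredError."""


class MeanSquaredError:
    """Hypothetical stand-in for the unmasked loss."""


def get_loss_metrics_monitor_and_outputs(training_handler, data_adapter):
    # Only the configured loss decides whether masks are needed.
    use_masks = training_handler.training_hparams.loss == "mask_mse"

    # The adapter hides dataset field names and structure.
    outputs = data_adapter.get_training_outputs(use_masks=use_masks)
    output_val = data_adapter.get_validation_outputs(use_masks=use_masks)

    loss = MaskedMeanSquaredError() if use_masks else MeanSquaredError()
    # Metrics and monitor setup are omitted in this sketch.
    return loss, outputs, output_val


class DictAdapter:
    """Hypothetical adapter over plain dicts of train/validation fields."""

    def __init__(self, train, val):
        self._train, self._val = train, val

    @staticmethod
    def _select(dataset, use_masks):
        # The field names live here, not in the training code.
        if use_masks:
            return dataset["noisy_stars"], dataset["masks"]
        return dataset["noisy_stars"]

    def get_training_outputs(self, use_masks):
        return self._select(self._train, use_masks)

    def get_validation_outputs(self, use_masks):
        return self._select(self._val, use_masks)


handler = SimpleNamespace(training_hparams=SimpleNamespace(loss="mask_mse"))
adapter = DictAdapter(
    train={"noisy_stars": [0.1, 0.2], "masks": [1, 0]},
    val={"noisy_stars": [0.3], "masks": [1]},
)
loss, outputs, output_val = get_loss_metrics_monitor_and_outputs(handler, adapter)
```

The training function never indexes a dataset directly, so replacing DictAdapter with any object that has the same two methods leaves it unchanged.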
Expected Benefits
- Simplification: Remove ~15 lines of conditional dataset access logic
- Clarity: Function dependencies are explicit in the type hints
- Testability: Can test with mock adapters (no need for full DataConfigHandler setup)
- Flexibility: Switching data sources requires only changing the adapter, not the training code
- Type Safety: IDE autocompletion and static type checking work against the adapter interface
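The flexibility benefit amounts to this: the training code depends only on the two accessors, so switching between simulation and real data becomes a choice of adapter at construction time. The sketch below assumes, purely for illustration, that the configuration is a plain dict with an "adapter" key and constructor "params" (the proposed create_from_config takes a DataConfigHandler instead). The bodies of PreSplitDataAdapter and SplitOnLoadDataAdapter are hypothetical; in particular, splitting a single dataset by a validation fraction is an assumed behavior for the real-data adapter.

```python
from dataclasses import dataclass
from typing import Any, Dict, Mapping


def _outputs(dataset: Mapping[str, list], use_masks: bool, key: str = "noisy_stars"):
    # Single place that knows the field names; training code never sees them.
    return (dataset[key], dataset["masks"]) if use_masks else dataset[key]


@dataclass
class PreSplitDataAdapter:
    """Hypothetical: simulation data with train/validation splits already made."""
    train: Mapping[str, list]
    val: Mapping[str, list]

    def get_training_outputs(self, use_masks: bool):
        return _outputs(self.train, use_masks)

    def get_validation_outputs(self, use_masks: bool):
        return _outputs(self.val, use_masks)


@dataclass
class SplitOnLoadDataAdapter:
    """Hypothetical: real data stored as one set and split when loaded."""
    data: Mapping[str, list]
    val_fraction: float = 0.2

    def _part(self, train: bool) -> Dict[str, list]:
        n_total = len(self.data["noisy_stars"])
        n_train = n_total - round(n_total * self.val_fraction)
        # Slice every field the same way so targets and masks stay aligned.
        return {k: (v[:n_train] if train else v[n_train:]) for k, v in self.data.items()}

    def get_training_outputs(self, use_masks: bool):
        return _outputs(self._part(train=True), use_masks)

    def get_validation_outputs(self, use_masks: bool):
        return _outputs(self._part(train=False), use_masks)


class DataAdapterFactory:
    """Hypothetical factory: picks an adapter class from a plain-dict config."""

    _registry = {"pre_split": PreSplitDataAdapter, "split_on_load": SplitOnLoadDataAdapter}

    @classmethod
    def create_from_config(cls, config: Mapping[str, Any]):
        kind = config["adapter"]
        if kind not in cls._registry:
            raise ValueError(f"unknown adapter type: {kind!r}")
        return cls._registry[kind](**config["params"])
```

A registry keeps the factory closed to edits when a new data source is added: registering one more class is enough.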
Dependencies
Requires: Issue #199 (TrainingDataAdapter implementation) to be completed first
Affects:
- Main training script(s) that call get_loss_metrics_monitor_and_outputs()
- Any test code that uses this function
- Documentation/examples showing training setup
Testing Plan
- Unit tests for the refactored function:
  - Test with a mock TrainingDataAdapter (verify that the correct methods are called)
  - Test both masked and unmasked loss configurations
  - Verify that the correct loss/metrics objects are returned
- Integration tests:
  - Test with a real PreSplitDataAdapter (simulation workflow)
  - Test with a real SplitOnLoadDataAdapter (real data workflow)
  - Verify that the training loop works end-to-end with both adapter types
- Regression tests:
  - Ensure existing simulation training produces identical results
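The unit tests in the plan can use unittest.mock.Mock as the adapter and assert on the calls it receives. This sketch tests a simplified stand-in for the refactored function that returns only the loss and outputs; the loss classes and the "mse" loss name are hypothetical placeholders, not the project's real objects. It runs under pytest or python -m unittest.

```python
import unittest
from unittest import mock


class MaskedMeanSquaredError:
    """Hypothetical stand-in for train_utils.MaskedMeanSquaredError."""


class MeanSquaredError:
    """Hypothetical stand-in for the unmasked loss."""


def get_loss_metrics_monitor_and_outputs(training_handler, data_adapter):
    """Simplified stand-in for the refactored function (loss and outputs only)."""
    use_masks = training_handler.training_hparams.loss == "mask_mse"
    outputs = data_adapter.get_training_outputs(use_masks=use_masks)
    output_val = data_adapter.get_validation_outputs(use_masks=use_masks)
    loss = MaskedMeanSquaredError() if use_masks else MeanSquaredError()
    return loss, outputs, output_val


def make_handler(loss_name):
    # Mock auto-creates nested attributes, so only the loss name needs setting.
    handler = mock.Mock()
    handler.training_hparams.loss = loss_name
    return handler


class TestGetLossMetricsMonitorAndOutputs(unittest.TestCase):
    def test_masked_loss_requests_masked_outputs(self):
        adapter = mock.Mock()
        loss, _, _ = get_loss_metrics_monitor_and_outputs(make_handler("mask_mse"), adapter)
        adapter.get_training_outputs.assert_called_once_with(use_masks=True)
        adapter.get_validation_outputs.assert_called_once_with(use_masks=True)
        self.assertIsInstance(loss, MaskedMeanSquaredError)

    def test_unmasked_loss_requests_plain_outputs(self):
        adapter = mock.Mock()
        loss, outputs, _ = get_loss_metrics_monitor_and_outputs(make_handler("mse"), adapter)
        adapter.get_training_outputs.assert_called_once_with(use_masks=False)
        # The function should pass the adapter's result through untouched.
        self.assertIs(outputs, adapter.get_training_outputs.return_value)
        self.assertIsInstance(loss, MeanSquaredError)
```

No DataConfigHandler or dataset files are needed, which is the testability gain the refactoring aims for.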
Additional Context
This refactoring is the second step in modernizing the data handling architecture. Issue #199 introduces the adapter abstraction, while this issue migrates the training code to actually use it.
Breaking change consideration: The function signature changes, so this is technically a breaking change.
Impact Assessment
Medium impact, high value:
- Breaking change to the get_loss_metrics_monitor_and_outputs() signature
- All code calling this function needs updating (estimate: 3-5 call sites)
- Significantly improves code quality and maintainability
- Unblocks future work on real data training pipeline
- Sets precedent for adapter pattern usage elsewhere in codebase
Migration effort: ~2-4 hours for implementation + testing
Next Steps
- Verify that Issue #199 ([REFACTOR] Add TrainingDataAdapter abstraction for simulation data) is merged and the adapters are available
- Update the get_loss_metrics_monitor_and_outputs() signature and implementation
- Update all calling code in training scripts
- Update tests to use the new signature (or test both signatures if a deprecation shim is added)
- Update documentation and examples
- Optional: Add deprecation shim if backward compatibility needed
- Run full regression test suite
- Update CHANGELOG with a migration guide for the breaking signature change
- Optional follow-up: Clean up DataConfigHandler to remove now-unused attributes (separate issue)