
[REFACTOR] Add TrainingDataAdapter abstraction for simulation data #199

@jeipollack


Description of Refactoring/Improvement
Introduce an adapter pattern to decouple the training logic from dataset implementation details. Currently, get_loss_metrics_monitor_and_outputs() accesses dataset structures directly (e.g., data_conf.training_data.dataset["noisy_stars"]). This makes it difficult to support different data sources (simulations vs. real data) without adding conditional logic.

Goals and Objectives

  • Create a clean abstraction layer between data loading and training code
  • Enable support for multiple data sources (pre-split simulations, real data requiring splitting) without modifying training logic
  • Improve testability by allowing mock adapters
  • Establish an extensible pattern for future data sources

Current Code Behavior
The get_loss_metrics_monitor_and_outputs() function currently:

  • Contains hardcoded dataset field access (dataset["noisy_stars"], dataset["stars"])
  • Assumes pre-split train/test data
  • Tightly couples training configuration with dataset structure
  • Cannot handle real SHE data that requires train/val splitting
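To make the coupling concrete, here is a simplified, hypothetical illustration of the direct-access pattern. It is not the actual body of get_loss_metrics_monitor_and_outputs(). The `data_conf.training_data.dataset` path and the `"noisy_stars"`/`"stars"` keys come from the description above. The `test_data` attribute, the `get_outputs_coupled` name, the array shapes, and the `SimpleNamespace` stand-in for the configuration object are all assumptions.

```python
from types import SimpleNamespace

import numpy as np


def get_outputs_coupled(data_conf):
    """Hypothetical, simplified stand-in for the hardcoded access pattern.

    The caller must know the dataset layout: the dict keys, and that the
    train and test splits already exist as separate objects.
    """
    train_outputs = data_conf.training_data.dataset["noisy_stars"]
    # Assumes a pre-split test set exists; a single real dataset has none.
    test_outputs = data_conf.test_data.dataset["stars"]
    return train_outputs, test_outputs


# Hypothetical configuration object mimicking the pre-split simulation layout.
data_conf = SimpleNamespace(
    training_data=SimpleNamespace(dataset={"noisy_stars": np.zeros((8, 32, 32))}),
    test_data=SimpleNamespace(dataset={"stars": np.zeros((2, 32, 32))}),
)
train_outputs, test_outputs = get_outputs_coupled(data_conf)
```

Any new data source would either have to mimic this exact layout or force a branch inside the training function. That is the coupling the adapter is meant to remove.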

Proposed Changes
This work establishes the adapter pattern for simulation data only. Support for real PSF datasets will come in a follow-up PR.

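  1. Add a new file, training_data_adapter.py, defining the abstract interface:

```python
from abc import ABC, abstractmethod
from typing import Tuple

import tensorflow as tf


class TrainingDataAdapter(ABC):
    """Interface the training code uses instead of accessing dataset fields directly."""

    @abstractmethod
    def get_training_outputs(self, use_masks: bool = False) -> tf.Tensor:
        """Targets for the training split."""

    @abstractmethod
    def get_validation_outputs(self, use_masks: bool = False) -> tf.Tensor:
        """Targets for the validation split."""

    @abstractmethod
    def get_training_inputs(self) -> Tuple[tf.Tensor, ...]:
        """Model inputs for the training split."""

    @abstractmethod
    def get_validation_inputs(self) -> Tuple[tf.Tensor, ...]:
        """Model inputs for the validation split."""
```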
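  2. Concrete implementation for the simulation workflow. PreSplitDataAdapter (the PreSplitSimulationAdapter class below) wraps the existing pre-split simulation data:

```python
class PreSplitSimulationAdapter(TrainingDataAdapter):
    """Adapter for pre-split simulation data loaded from .npy files."""
    def __init__(self, train_loader: SimulationDataLoader, test_loader: SimulationDataLoader):
        # Wraps two SimulationDataLoader instances, one per split
        self.train_loader = train_loader
        self.test_loader = test_loader

    # The four abstract methods are implemented on top of these two loaders.
```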
  3. A factory for adapter creation:

```python
class DataAdapterFactory:
    @staticmethod
    def create_from_config(data_config_handler, load_data=True) -> TrainingDataAdapter:
        # Returns the appropriate adapter based on the data_type configuration
        ...
```
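The three pieces above can be sketched end to end as follows. This is a minimal, runnable sketch under several assumptions. NumPy arrays stand in for `tf.Tensor` so it runs without TensorFlow. Each split is a plain dict rather than a `SimulationDataLoader`. The `"masks"`, `"positions"`, and `"SEDs"` keys, the `output_key` parameter, and the `fake_split` helper are hypothetical. `use_masks=True` is assumed to mean multiplying the outputs by a per-pixel mask. The factory takes a simplified dict instead of `data_config_handler`.

```python
from abc import ABC, abstractmethod
from typing import Dict, Tuple

import numpy as np

Array = np.ndarray  # stand-in for tf.Tensor so the sketch runs without TensorFlow


class TrainingDataAdapter(ABC):
    @abstractmethod
    def get_training_outputs(self, use_masks: bool = False) -> Array: ...

    @abstractmethod
    def get_validation_outputs(self, use_masks: bool = False) -> Array: ...

    @abstractmethod
    def get_training_inputs(self) -> Tuple[Array, ...]: ...

    @abstractmethod
    def get_validation_inputs(self) -> Tuple[Array, ...]: ...


class PreSplitSimulationAdapter(TrainingDataAdapter):
    """Wraps two already-separated splits, here plain dicts of arrays (an assumption)."""

    def __init__(self, train: Dict[str, Array], test: Dict[str, Array],
                 output_key: str = "noisy_stars"):
        self._train, self._test, self._output_key = train, test, output_key

    def _outputs(self, split: Dict[str, Array], use_masks: bool) -> Array:
        outputs = split[self._output_key]
        # Assumption: masks are per-pixel weights with the same shape as the stars.
        return outputs * split["masks"] if use_masks else outputs

    def get_training_outputs(self, use_masks: bool = False) -> Array:
        return self._outputs(self._train, use_masks)

    def get_validation_outputs(self, use_masks: bool = False) -> Array:
        return self._outputs(self._test, use_masks)

    def get_training_inputs(self) -> Tuple[Array, ...]:
        return self._train["positions"], self._train["SEDs"]

    def get_validation_inputs(self) -> Tuple[Array, ...]:
        return self._test["positions"], self._test["SEDs"]


class DataAdapterFactory:
    @staticmethod
    def create_from_config(config: dict) -> TrainingDataAdapter:
        # Simplified, hypothetical config: {"data_type": ..., "train": ..., "test": ...}
        data_type = config.get("data_type", "simulation")
        if data_type == "simulation":
            return PreSplitSimulationAdapter(config["train"], config["test"])
        raise ValueError(f"Unsupported data_type: {data_type!r}")


def fake_split(n: int, seed: int = 0) -> Dict[str, Array]:
    """Synthetic split with hypothetical shapes (32x32 stars, 2-D positions)."""
    rng = np.random.default_rng(seed)
    return {
        "noisy_stars": rng.normal(size=(n, 32, 32)),
        "masks": np.ones((n, 32, 32)),
        "positions": rng.uniform(size=(n, 2)),
        "SEDs": rng.uniform(size=(n, 10, 2)),
    }


adapter = DataAdapterFactory.create_from_config(
    {"data_type": "simulation", "train": fake_split(8), "test": fake_split(2, seed=1)}
)
positions, seds = adapter.get_training_inputs()
```

Unknown `data_type` values fail loudly in the factory rather than silently falling through, so adding a real-data adapter later only means adding one branch there.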

Expected Benefits

  • Maintainability: Dataset access logic isolated in adapters, not scattered across training code
  • Extensibility: New data sources added by creating new adapter classes
  • Testability: Mock adapters can be injected for unit testing training logic
  • Flexibility: Support for both simulation (pre-split) and, later, real data (requiring splitting) workflows
  • Type Safety: Clear interface contracts via an abstract base class
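The testability benefit can be illustrated with a mock adapter. Training code that depends only on the adapter interface can be unit-tested with tiny synthetic arrays, with no files or configuration. In this sketch, `summarize_training_data` is a hypothetical consumer standing in for the training logic, and NumPy arrays stand in for `tf.Tensor`. The mock is built with the standard library's `unittest.mock.create_autospec`.

```python
from abc import ABC, abstractmethod
from typing import Tuple
from unittest import mock

import numpy as np


class TrainingDataAdapter(ABC):
    @abstractmethod
    def get_training_outputs(self, use_masks: bool = False) -> np.ndarray: ...

    @abstractmethod
    def get_validation_outputs(self, use_masks: bool = False) -> np.ndarray: ...

    @abstractmethod
    def get_training_inputs(self) -> Tuple[np.ndarray, ...]: ...

    @abstractmethod
    def get_validation_inputs(self) -> Tuple[np.ndarray, ...]: ...


def summarize_training_data(adapter: TrainingDataAdapter) -> dict:
    """Hypothetical training-side helper that depends only on the adapter interface."""
    outputs = adapter.get_training_outputs(use_masks=False)
    inputs = adapter.get_training_inputs()
    return {"n_train": outputs.shape[0], "n_inputs": len(inputs)}


# create_autospec builds a mock whose methods check the adapter's signatures:
# calling a method with arguments that do not fit the interface raises TypeError.
fake = mock.create_autospec(TrainingDataAdapter, instance=True)
fake.get_training_outputs.return_value = np.zeros((4, 32, 32))
fake.get_training_inputs.return_value = (np.zeros((4, 2)), np.zeros((4, 10, 2)))

summary = summarize_training_data(fake)
fake.get_training_outputs.assert_called_once_with(use_masks=False)
```

Autospeccing against the abstract base class, rather than using a bare `MagicMock`, means the test breaks if the training code drifts away from the interface contract.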

Dependencies

Testing Plan

  1. Unit tests for each adapter:
      • Test PreSplitDataAdapter with mock DataHandler instances
      • Verify correct tensor shapes and data types
      • Test masking behaviour (with and without masks)
  2. Integration tests:
      • Test DataAdapterFactory with different config types
      • Verify adapters work with real DataHandler instances
      • Test edge cases (empty datasets, single sample, all train/no validation)
  3. Property-based tests:
      • Verify train/validation splits sum to the total dataset size
      • Ensure no data leakage between splits
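The split properties and edge cases above can be expressed as plain assertions over index arrays. This sketch assumes splitting is done by a seeded permutation. The `split_indices` and `check_split_properties` helpers are hypothetical, illustrating the kind of splitting a future real-data adapter would need rather than code from this issue. A fixed grid of `(n, val_fraction)` values stands in for a property-based generator such as Hypothesis.

```python
import numpy as np


def split_indices(n: int, val_fraction: float, seed: int = 0):
    """Hypothetical helper: shuffle indices 0..n-1 and cut off a validation tail."""
    if not 0.0 <= val_fraction <= 1.0:
        raise ValueError("val_fraction must be in [0, 1]")
    perm = np.random.default_rng(seed).permutation(n)
    n_val = int(round(n * val_fraction))
    return perm[: n - n_val], perm[n - n_val:]


def check_split_properties(n: int, val_fraction: float) -> None:
    train_idx, val_idx = split_indices(n, val_fraction)
    # Split sizes sum to the total dataset size.
    assert len(train_idx) + len(val_idx) == n
    # No leakage: no index appears in both splits.
    assert np.intersect1d(train_idx, val_idx).size == 0
    # Together, the splits cover every sample exactly once.
    assert np.array_equal(np.sort(np.concatenate([train_idx, val_idx])), np.arange(n))


# Edge cases from the plan: empty dataset, single sample, all train / no validation.
for n in (0, 1, 2, 10, 101):
    for frac in (0.0, 0.2, 0.5, 1.0):
        check_split_properties(n, frac)
```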

Additional Context
This refactoring is motivated by the need to support real SHE PSF data (represented as the SHEPSFDataset dataclass), which arrives as a single dataset requiring train/test splitting, unlike simulations, which provide pre-split datasets. The adapter pattern allows both workflows to coexist cleanly.
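A possible shape for that follow-up work is an adapter that splits a single dataset internally while exposing the same four methods. This is a hypothetical sketch. `SingleDatasetSplitAdapter` and `PSFDatasetStub` are illustrative names, and the stub's fields (`stars`, `masks`, `positions`) are assumptions, since the fields of SHEPSFDataset are not specified here. As before, NumPy arrays stand in for `tf.Tensor`, and `use_masks=True` is assumed to multiply the outputs by a per-pixel mask.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Tuple

import numpy as np


class TrainingDataAdapter(ABC):
    @abstractmethod
    def get_training_outputs(self, use_masks: bool = False) -> np.ndarray: ...

    @abstractmethod
    def get_validation_outputs(self, use_masks: bool = False) -> np.ndarray: ...

    @abstractmethod
    def get_training_inputs(self) -> Tuple[np.ndarray, ...]: ...

    @abstractmethod
    def get_validation_inputs(self) -> Tuple[np.ndarray, ...]: ...


@dataclass
class PSFDatasetStub:
    """Hypothetical stand-in for SHEPSFDataset; field names are assumptions."""

    stars: np.ndarray      # (n, h, w) observed star images
    masks: np.ndarray      # (n, h, w) per-pixel masks
    positions: np.ndarray  # (n, 2) field positions


class SingleDatasetSplitAdapter(TrainingDataAdapter):
    """Hypothetical adapter that splits one dataset into train/validation."""

    def __init__(self, data: PSFDatasetStub, val_fraction: float = 0.2, seed: int = 0):
        n = data.stars.shape[0]
        # Split once, with a seeded permutation, so every getter sees the same partition.
        perm = np.random.default_rng(seed).permutation(n)
        n_val = int(round(n * val_fraction))
        self._data = data
        self.train_idx, self.val_idx = perm[: n - n_val], perm[n - n_val:]

    def _outputs(self, idx: np.ndarray, use_masks: bool) -> np.ndarray:
        stars = self._data.stars[idx]
        return stars * self._data.masks[idx] if use_masks else stars

    def get_training_outputs(self, use_masks: bool = False) -> np.ndarray:
        return self._outputs(self.train_idx, use_masks)

    def get_validation_outputs(self, use_masks: bool = False) -> np.ndarray:
        return self._outputs(self.val_idx, use_masks)

    def get_training_inputs(self) -> Tuple[np.ndarray, ...]:
        return (self._data.positions[self.train_idx],)

    def get_validation_inputs(self) -> Tuple[np.ndarray, ...]:
        return (self._data.positions[self.val_idx],)


rng = np.random.default_rng(42)
data = PSFDatasetStub(
    stars=rng.normal(size=(10, 16, 16)),
    masks=np.ones((10, 16, 16)),
    positions=rng.uniform(size=(10, 2)),
)
adapter = SingleDatasetSplitAdapter(data, val_fraction=0.2)
```

Because the training code only sees the four adapter methods, it would consume this adapter in the same way as the pre-split simulation adapter. The split logic stays out of the training function entirely.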

Impact Assessment
Low risk, high value addition:

  • Pure addition with no modifications to existing code
  • Establishes architectural pattern for future data handling improvements
  • Enables follow-up work to simplify training logic
  • No performance impact (thin wrapper around existing operations)

Next Steps

  1. Implement the abstract TrainingDataAdapter base class
  2. Implement PreSplitDataAdapter for the existing simulation workflow
  3. Implement DataAdapterFactory with configuration-based creation
  4. Write comprehensive unit tests
  5. Document the adapter pattern in the developer guide
  6. Open a follow-up issue to migrate the training code to use adapters


Labels: enhancement

Projects

Status

No status

Relationships

None yet

Development

No branches or pull requests

Issue actions