diff --git a/src/devinterp/backends/default/slt/sampler.py b/src/devinterp/backends/default/slt/sampler.py index 2cd50538..bb6a5623 100644 --- a/src/devinterp/backends/default/slt/sampler.py +++ b/src/devinterp/backends/default/slt/sampler.py @@ -41,6 +41,7 @@ def sample_single_chain( optimize_over_per_model_param: Optional[dict] = None, callbacks: List[SamplerCallback] = [], use_amp: bool = False, + copy_full_model: bool = True, **kwargs, ): if grad_accum_steps > 1: @@ -53,8 +54,16 @@ def sample_single_chain( "You are taking more sample batches than there are dataloader batches available, this removes some randomness from sampling but is probably fine. (All sample batches beyond the number dataloader batches are cycled from the start, f.e. 9 samples from [A, B, C] would be [B, A, C, B, A, C, B, A, C].)" ) - # Initialize new model and optimizer for this chain - model = deepcopy(ref_model).to(device) + if copy_full_model: + model = deepcopy(ref_model).to(device) + else: + model = ref_model.to(device) + if 'named_params' in optimizer_kwargs: + named_params = dict(optimizer_kwargs.pop('named_params')) + else: + named_params = dict(model.named_parameters()) + if not copy_full_model: + ref_named_params = deepcopy(named_params) if "temperature" in optimizer_kwargs: assert ( @@ -78,7 +87,7 @@ if optimize_over_per_model_param: param_groups = [] - for name, parameter in model.named_parameters(): + for name, parameter in named_params.items(): param_groups.append( { "params": parameter, @@ -90,7 +99,7 @@ **optimizer_kwargs, ) else: - optimizer = sampling_method(model.parameters(), **optimizer_kwargs) + optimizer = sampling_method(named_params.values(), **optimizer_kwargs) if seed is not None: torch.manual_seed(seed) @@ -138,6 +147,10 @@ cumulative_loss = 0 pbar.update(1) + # Restore original params + if not copy_full_model: + model.load_state_dict(ref_named_params, strict=False) + def _sample_single_chain(kwargs):
pickled_args = ["evaluate", "loader"] @@ -342,6 +355,7 @@ def evaluate(model, data): verbose=verbose, optimize_over_per_model_param=optimize_over_per_model_param, use_amp=use_amp, + copy_full_model=cores > 1, ) if cores > 1: