From 0b78aee10ba239ae1d1608282e3d1b8c93c5eb69 Mon Sep 17 00:00:00 2001
From: wilson
Date: Sun, 9 Mar 2025 02:13:21 +0000
Subject: [PATCH 1/2] sampler.py: only copy params that are optimized over

---
 src/devinterp/backends/default/slt/sampler.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/src/devinterp/backends/default/slt/sampler.py b/src/devinterp/backends/default/slt/sampler.py
index 2cd50538..d09f2db3 100644
--- a/src/devinterp/backends/default/slt/sampler.py
+++ b/src/devinterp/backends/default/slt/sampler.py
@@ -53,8 +53,15 @@ def sample_single_chain(
             "You are taking more sample batches than there are dataloader batches available, this removes some randomness from sampling but is probably fine. (All sample batches beyond the number dataloader batches are cycled from the start, f.e. 9 samples from [A, B, C] would be [B, A, C, B, A, C, B, A, C].)"
         )
 
-    # Initialize new model and optimizer for this chain
-    model = deepcopy(ref_model).to(device)
+    # Smuggling in named_params through optimizer_kwargs is a little hacky
+    # Maybe better to just pass named_params as a separate argument
+    # Also this approach probably doesn't work with multiple GPUs
+    model = ref_model.to(device)
+    if 'named_params' in optimizer_kwargs:
+        named_params = optimizer_kwargs.pop('named_params')
+    else:
+        named_params = dict(model.named_parameters())
+    ref_named_params = deepcopy(named_params)
 
     if "temperature" in optimizer_kwargs:
         assert (
@@ -78,7 +85,7 @@
 
     if optimize_over_per_model_param:
         param_groups = []
-        for name, parameter in model.named_parameters():
+        for name, parameter in named_params.items():
             param_groups.append(
                 {
                     "params": parameter,
@@ -90,7 +97,7 @@
             **optimizer_kwargs,
         )
     else:
-        optimizer = sampling_method(model.parameters(), **optimizer_kwargs)
+        optimizer = sampling_method(named_params.values(), **optimizer_kwargs)
 
     if seed is not None:
         torch.manual_seed(seed)
@@ -138,6 +145,9 @@
                 cumulative_loss = 0
             pbar.update(1)
 
+    # Restore original params
+    model.load_state_dict(ref_named_params, strict=False)
+
 
 def _sample_single_chain(kwargs):
     pickled_args = ["evaluate", "loader"]

From 91aba28c41c6e5e3ce5dc0c060b28d44525c142f Mon Sep 17 00:00:00 2001
From: wilson
Date: Sun, 9 Mar 2025 02:25:40 +0000
Subject: [PATCH 2/2] Copy full model when cores > 1

---
 src/devinterp/backends/default/slt/sampler.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/devinterp/backends/default/slt/sampler.py b/src/devinterp/backends/default/slt/sampler.py
index d09f2db3..bb6a5623 100644
--- a/src/devinterp/backends/default/slt/sampler.py
+++ b/src/devinterp/backends/default/slt/sampler.py
@@ -41,6 +41,7 @@ def sample_single_chain(
     optimize_over_per_model_param: Optional[dict] = None,
     callbacks: List[SamplerCallback] = [],
     use_amp: bool = False,
+    copy_full_model: bool = True,
     **kwargs,
 ):
     if grad_accum_steps > 1:
@@ -53,15 +54,16 @@
             "You are taking more sample batches than there are dataloader batches available, this removes some randomness from sampling but is probably fine. (All sample batches beyond the number dataloader batches are cycled from the start, f.e. 9 samples from [A, B, C] would be [B, A, C, B, A, C, B, A, C].)"
         )
 
-    # Smuggling in named_params through optimizer_kwargs is a little hacky
-    # Maybe better to just pass named_params as a separate argument
-    # Also this approach probably doesn't work with multiple GPUs
-    model = ref_model.to(device)
+    if copy_full_model:
+        model = deepcopy(ref_model).to(device)
+    else:
+        model = ref_model.to(device)
     if 'named_params' in optimizer_kwargs:
         named_params = optimizer_kwargs.pop('named_params')
     else:
         named_params = dict(model.named_parameters())
-    ref_named_params = deepcopy(named_params)
+    if not copy_full_model:
+        ref_named_params = deepcopy(named_params)
 
     if "temperature" in optimizer_kwargs:
         assert (
@@ -146,7 +148,8 @@
             pbar.update(1)
 
     # Restore original params
-    model.load_state_dict(ref_named_params, strict=False)
+    if not copy_full_model:
+        model.load_state_dict(ref_named_params, strict=False)
 
 
 def _sample_single_chain(kwargs):
@@ -352,6 +355,7 @@ def evaluate(model, data):
         verbose=verbose,
         optimize_over_per_model_param=optimize_over_per_model_param,
         use_amp=use_amp,
+        copy_full_model=cores > 1,
    )
 
     if cores > 1:
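
Not part of the patches above: a minimal usage sketch of the new behaviour. It
assumes the top-level `sample` wrapper in this file (where the last hunk's call
site lives) forwards `optimizer_kwargs` unchanged into `sample_single_chain`,
that `named_params` is expected to be a dict mapping parameter names to tensors,
and that the import paths and argument names below match how devinterp exposes
`sample` and `SGLD`; the toy model, loader, and layer filter are placeholders.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from devinterp.optim import SGLD            # assumed import path
    from devinterp.slt.sampler import sample    # assumed import path

    model = torch.nn.Sequential(
        torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1)
    )
    loader = DataLoader(
        TensorDataset(torch.randn(256, 4), torch.randn(256, 1)),
        batch_size=32,
        shuffle=True,
    )

    # Only sample over the final layer: sample_single_chain pops this dict out
    # of optimizer_kwargs, hands exactly these tensors to SGLD, and restores
    # them afterwards when it did not deep-copy the model (cores == 1).
    named_params = {n: p for n, p in model.named_parameters() if n.startswith("2.")}

    sample(
        model,
        loader,
        callbacks=[],
        sampling_method=SGLD,
        optimizer_kwargs=dict(lr=1e-4, named_params=named_params),
        num_chains=2,
        num_draws=100,
        cores=1,   # cores > 1 sets copy_full_model=True (PATCH 2/2): each chain deep-copies the model
        device="cpu",
    )

The point of the split behaviour: with a single core the chains reuse the
reference model in place (cheap for large models, hence the restore step), while
with multiple cores each worker gets its own deep copy so chains cannot mutate a
shared model.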