Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions src/devinterp/backends/default/slt/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def sample_single_chain(
optimize_over_per_model_param: Optional[dict] = None,
callbacks: List[SamplerCallback] = [],
use_amp: bool = False,
copy_full_model: bool=True,
**kwargs,
):
if grad_accum_steps > 1:
Expand All @@ -53,8 +54,16 @@ def sample_single_chain(
"You are taking more sample batches than there are dataloader batches available, this removes some randomness from sampling but is probably fine. (All sample batches beyond the number dataloader batches are cycled from the start, f.e. 9 samples from [A, B, C] would be [B, A, C, B, A, C, B, A, C].)"
)

# Initialize new model and optimizer for this chain
model = deepcopy(ref_model).to(device)
if copy_full_model:
model = deepcopy(ref_model).to(device)
else:
model = ref_model.to(device)
if 'named_params' in optimizer_kwargs:
named_params = optimizer_kwargs.pop('named_params')
else:
named_params = model.named_parameters()
if not copy_full_model:
ref_named_params = deepcopy(named_params)

if "temperature" in optimizer_kwargs:
assert (
Expand All @@ -78,7 +87,7 @@ def sample_single_chain(

if optimize_over_per_model_param:
param_groups = []
for name, parameter in model.named_parameters():
for name, parameter in named_params:
param_groups.append(
{
"params": parameter,
Expand All @@ -90,7 +99,7 @@ def sample_single_chain(
**optimizer_kwargs,
)
else:
optimizer = sampling_method(model.parameters(), **optimizer_kwargs)
optimizer = sampling_method(named_params.values(), **optimizer_kwargs)

if seed is not None:
torch.manual_seed(seed)
Expand Down Expand Up @@ -138,6 +147,10 @@ def sample_single_chain(
cumulative_loss = 0
pbar.update(1)

# Restore original params
if not copy_full_model:
model.load_state_dict(ref_named_params, strict=False)


def _sample_single_chain(kwargs):
pickled_args = ["evaluate", "loader"]
Expand Down Expand Up @@ -342,6 +355,7 @@ def evaluate(model, data):
verbose=verbose,
optimize_over_per_model_param=optimize_over_per_model_param,
use_amp=use_amp,
copy_full_model=cores > 1,
)

if cores > 1:
Expand Down