
[BUG?] Missing accumulate_iterations call in pairwise score computation when compute_per_module_scores=True #45

@MyDum-bsu

Description


Is there a critical bug in score/dot_product.py? accumulate_iterations is only called when compute_per_module_scores=False, which would cause incorrect module state management when computing per-module pairwise scores.

  • File: score/dot_product.py
  • Function: compute_dot_products_with_loader

Problem: accumulate_iterations is only called in the else branch (compute_per_module_scores=False), but it should be called in both branches so that module state is cleared after each iteration:

with torch.no_grad():
    if score_args.compute_per_module_scores:
        for module in cached_module_lst:
            score_chunks[module.name].append(
                module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).to(device="cpu", copy=True)
            )
    else:
        # ... code for aggregated scores ...
        score_chunks[ALL_MODULE_NAME].append(pairwise_scores)
        accumulate_iterations(model=model, tracked_module_names=tracked_module_names)  # BUG: only reached when compute_per_module_scores=False

Proposed fix: move the accumulate_iterations call out of the if/else block so it runs after either branch:

with torch.no_grad():
    if score_args.compute_per_module_scores:
        for module in cached_module_lst:
            score_chunks[module.name].append(
                module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).to(device="cpu", copy=True)
            )
    else:
        # ... code for aggregated scores ...
        score_chunks[ALL_MODULE_NAME].append(pairwise_scores)

    # Runs for both per-module and aggregated scores
    accumulate_iterations(model=model, tracked_module_names=tracked_module_names)
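A minimal toy sketch of why the placement could matter, assuming that per-module score state is added to any state left over from earlier iterations unless it is cleared. ToyModule, toy_accumulate_iterations and run_scoring are hypothetical stand-ins, not kronfluence code. If the real per-module storage is instead overwritten on every iteration, the stale-state effect shown here would not occur.

```python
# Hypothetical toy reproduction: ToyModule, toy_accumulate_iterations and
# run_scoring are illustrative stand-ins, NOT kronfluence APIs.
from collections import defaultdict
from typing import Dict, List, Optional


class ToyModule:
    """Stand-in for a tracked module holding per-iteration score state."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.score: Optional[float] = None

    def add_score(self, value: float) -> None:
        # Assumption: new scores are added to any state left from earlier iterations.
        self.score = value if self.score is None else self.score + value

    def clear(self) -> None:
        self.score = None


def toy_accumulate_iterations(modules: List[ToyModule]) -> None:
    # Stand-in for accumulate_iterations: resets per-iteration state.
    for module in modules:
        module.clear()


def run_scoring(
    batches: List[float], per_module: bool, clear_in_both: bool
) -> Dict[str, List[float]]:
    modules = [ToyModule("fc1"), ToyModule("fc2")]
    chunks: Dict[str, List[float]] = defaultdict(list)
    for value in batches:
        for module in modules:
            module.add_score(value)
        if per_module:
            for module in modules:
                chunks[module.name].append(module.score)
        else:
            chunks["all"].append(sum(module.score for module in modules))
            if not clear_in_both:
                toy_accumulate_iterations(modules)  # original placement (else branch only)
        if clear_in_both:
            toy_accumulate_iterations(modules)  # proposed placement (after the if/else)
    return dict(chunks)


batches = [1.0, 2.0, 3.0]
print(run_scoring(batches, per_module=True, clear_in_both=False)["fc1"])  # [1.0, 3.0, 6.0]
print(run_scoring(batches, per_module=True, clear_in_both=True)["fc1"])   # [1.0, 2.0, 3.0]
```

Under this assumption, the original placement turns the per-module chunks into running sums ([1.0, 3.0, 6.0]), while moving the call after the if/else leaves each chunk with only its own iteration's score. The aggregated branch yields [2.0, 4.0, 6.0] with either placement.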

Correct me if I'm wrong.
