Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions auto_circuit/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ def init_task(self):
bs_2 = bs[1] if isinstance(bs, tuple) else bs
count_2 = b_count[1] if isinstance(b_count, tuple) else b_count
has_tokenizer = hasattr(model, "tokenizer") and model.tokenizer is not None
if has_tokenizer and hasattr(model.tokenizer, "add_bos_token"): # ensures bos token not added to answers/wrong answers for pythia models
model.tokenizer.add_bos_token = False
train_loader, test_loader = load_datasets_from_json(
model=model if has_tokenizer else None,
path=repo_path_to_abs_path(f"datasets/{self._dataset_name}.json"),
Expand Down Expand Up @@ -296,7 +298,7 @@ def init_task(self):
key="Indirect Object Identification Component Circuit",
name="Indirect Object Identification",
_model_def="gpt2-small",
_dataset_name="ioi_prompts",
_dataset_name="ioi/ioi_prompts",
batch_size=64,
batch_count=2,
_true_edge_func=ioi_true_edges,
Expand All @@ -306,7 +308,7 @@ def init_task(self):
key="Indirect Object Identification GPT2 Autoencoder Component Circuit",
name="Indirect Object Identification",
_model_def="gpt2-small",
_dataset_name="ioi_prompts",
_dataset_name="ioi/ioi_prompts",
batch_size=1,
batch_count=2,
_true_edge_func=None,
Expand Down