Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions sae_lens/cache_activations_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@
from datasets.fingerprint import generate_fingerprint
from huggingface_hub import HfApi
from jaxtyping import Float
from tqdm import tqdm

from sae_lens.config import DTYPE_MAP, CacheActivationsRunnerConfig
from sae_lens.load_model import load_model
from sae_lens.training.activations_store import ActivationsStore
from tqdm import tqdm


class CacheActivationsRunner:
Expand Down Expand Up @@ -273,10 +272,20 @@ def run(self) -> Dataset:
meta_io.seek(0)

api = HfApi()

# Tacks on username to the repo id
user_repo_id = api.create_repo(
self.cfg.hf_repo_id,
private=self.cfg.hf_is_private_repo,
repo_type="dataset",
exist_ok=True, # should exist already
).repo_id

api.upload_file(
path_or_fileobj=meta_io,
path_in_repo="cache_activations_runner_cfg.json",
repo_id=self.cfg.hf_repo_id,
repo_id=user_repo_id,
revision=self.cfg.hf_revision,
repo_type="dataset",
commit_message="Add cache_activations_runner metadata",
)
Expand Down