Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions configs/test_shakespeare.json
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
{
"batch_size": 16,
"batch_size": 64,
"dataset": "shakespeare",
"dataset_kwargs": {},
"max_epoch": 10,
"model": "llama",
"model_kwargs": {"vocab_size": 92, "dim": 384, "expand": 4, "n_layers": 3, "n_heads": 2, "mlp": "mlp", "seq_len": 512},
"model_kwargs": {"vocab_size": 92, "dim": 384, "expand": 4, "n_layers": 6, "n_heads": 6, "mlp": "mlp", "seq_len": 256},
"opt": [{"name": "adam", "lr": [1e-3], "lr_schedule": "constant", "warmup_steps": 100, "stepwise_schedule": true}],
"loss_func": "sequence_cross_entropy",
"score_func": "sequence_cross_entropy_accuracy",
Expand Down
164 changes: 82 additions & 82 deletions output/test_shakespeare.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[
{
"config": {
"batch_size": 16,
"batch_size": 64,
"dataset": "shakespeare",
"dataset_kwargs": {},
"loss_func": "sequence_cross_entropy",
Expand All @@ -11,9 +11,9 @@
"dim": 384,
"expand": 4,
"mlp": "mlp",
"n_heads": 2,
"n_layers": 3,
"seq_len": 512,
"n_heads": 6,
"n_layers": 6,
"seq_len": 256,
"vocab_size": 92
},
"opt": {
Expand All @@ -29,129 +29,129 @@
"history": [
{
"epoch": 0,
"grad_norm": 1.992753267288208,
"grad_norm": 1.121427059173584,
"learning_rate": 1e-13,
"model_norm": 64.88945770263672,
"train_epoch_time": 20.36537790298462,
"train_loss": 2.165691924560494,
"train_score": 0.36806988216568765,
"val_loss": 2.218102411840154,
"val_score": 0.3520025144363272
"model_norm": 87.65153503417969,
"train_epoch_time": 36.848896980285645,
"train_loss": 2.461630094821005,
"train_score": 0.29407953720865837,
"val_loss": 2.4997918203421765,
"val_score": 0.28291834143218164
},
{
"epoch": 1,
"grad_norm": 1.334716558456421,
"learning_rate": 0.001,
"model_norm": 66.11741638183594,
"train_epoch_time": 20.214939832687378,
"train_loss": 1.8825722357983157,
"train_score": 0.4417044885807278,
"val_loss": 2.0125111694993643,
"val_score": 0.4103133981255279
"grad_norm": 1.3179094791412354,
"learning_rate": 0.000540000000046,
"model_norm": 88.65203094482422,
"train_epoch_time": 36.27439522743225,
"train_loss": 2.0504614697296275,
"train_score": 0.40136858857687097,
"val_loss": 2.11395750226438,
"val_score": 0.3813056111609484
},
{
"epoch": 2,
"grad_norm": 1.0500348806381226,
"grad_norm": 1.0813289880752563,
"learning_rate": 0.001,
"model_norm": 67.5,
"train_epoch_time": 19.981853008270264,
"train_loss": 1.7142505491094666,
"train_score": 0.48880484908921845,
"val_loss": 1.9234138680600572,
"val_score": 0.4437095905172414
"model_norm": 89.66089630126953,
"train_epoch_time": 36.44406986236572,
"train_loss": 1.7927071080488317,
"train_score": 0.46938217360969936,
"val_loss": 1.929213602014579,
"val_score": 0.43094324097293935
},
{
"epoch": 3,
"grad_norm": 1.0010490417480469,
"grad_norm": 1.2653794288635254,
"learning_rate": 0.001,
"model_norm": 68.94657897949219,
"train_epoch_time": 19.904484033584595,
"train_loss": 1.6100540534773868,
"train_score": 0.5147796176741782,
"val_loss": 1.8390625857758796,
"val_score": 0.4670977011494253
"model_norm": 90.63284301757812,
"train_epoch_time": 36.52592372894287,
"train_loss": 1.6544971603232101,
"train_score": 0.5043736549497848,
"val_loss": 1.8547480991869378,
"val_score": 0.4552821832806317
},
{
"epoch": 4,
"grad_norm": 0.867782711982727,
"grad_norm": 0.9828493595123291,
"learning_rate": 0.001,
"model_norm": 70.42530822753906,
"train_epoch_time": 19.90021586418152,
"train_loss": 1.5380675991148134,
"train_score": 0.5328275243806236,
"val_loss": 1.792324341576675,
"val_score": 0.4788793104818498
"model_norm": 91.56520080566406,
"train_epoch_time": 36.71347689628601,
"train_loss": 1.5387006788034863,
"train_score": 0.5344926917056956,
"val_loss": 1.7642201408316156,
"val_score": 0.4774101251644327
},
{
"epoch": 5,
"grad_norm": 0.7821542620658875,
"grad_norm": 1.0504655838012695,
"learning_rate": 0.001,
"model_norm": 71.892822265625,
"train_epoch_time": 19.903041124343872,
"train_loss": 1.4823639410645215,
"train_score": 0.54531541538567,
"val_loss": 1.7534021857141078,
"val_score": 0.49129849144782145
"model_norm": 92.474609375,
"train_epoch_time": 36.35963201522827,
"train_loss": 1.4742244492986452,
"train_score": 0.5499394725350772,
"val_loss": 1.7320461796017108,
"val_score": 0.49073442892009983
},
{
"epoch": 6,
"grad_norm": 0.888006865978241,
"grad_norm": 0.8896018862724304,
"learning_rate": 0.001,
"model_norm": 73.36498260498047,
"train_epoch_time": 19.90959620475769,
"train_loss": 1.4459222527787552,
"train_score": 0.5567617144705097,
"val_loss": 1.7188306109658602,
"val_score": 0.5006106322524191
"model_norm": 93.3919906616211,
"train_epoch_time": 36.567251205444336,
"train_loss": 1.4220848283945573,
"train_score": 0.5644335094791232,
"val_loss": 1.6994645117892189,
"val_score": 0.5016414314405789
},
{
"epoch": 7,
"grad_norm": 0.7781890034675598,
"grad_norm": 0.7851909399032593,
"learning_rate": 0.001,
"model_norm": 74.81156921386719,
"train_epoch_time": 20.277348041534424,
"train_loss": 1.4094120469350628,
"train_score": 0.564420610495194,
"val_loss": 1.7047440726181557,
"val_score": 0.5073141165163325
"model_norm": 94.29660034179688,
"train_epoch_time": 36.6595242023468,
"train_loss": 1.3814418076784745,
"train_score": 0.5747007264639418,
"val_loss": 1.6752228601381234,
"val_score": 0.5100907716904387
},
{
"epoch": 8,
"grad_norm": 0.7673189043998718,
"grad_norm": 0.7439262866973877,
"learning_rate": 0.001,
"model_norm": 76.27680969238281,
"train_epoch_time": 20.30754780769348,
"train_loss": 1.3734883747473103,
"train_score": 0.5729114344828986,
"val_loss": 1.693546010159898,
"val_score": 0.5119118175287356
"model_norm": 95.21795654296875,
"train_epoch_time": 36.596500873565674,
"train_loss": 1.359556995163347,
"train_score": 0.5797659610573155,
"val_loss": 1.6677938104360166,
"val_score": 0.511516934831709
},
{
"epoch": 9,
"grad_norm": 0.6948201060295105,
"grad_norm": 0.7252312898635864,
"learning_rate": 0.001,
"model_norm": 77.71929931640625,
"train_epoch_time": 20.247668027877808,
"train_loss": 1.3485437734255699,
"train_score": 0.5800568226694924,
"val_loss": 1.672089501358997,
"val_score": 0.5159931754243785
"model_norm": 96.16193389892578,
"train_epoch_time": 36.32176995277405,
"train_loss": 1.3168328549290662,
"train_score": 0.5910767124578292,
"val_loss": 1.6540300926850118,
"val_score": 0.5194550074474409
}
],
"summary": {
"data_parallel": "false",
"end_time": "2025-07-23 14:30:37.932485",
"final_model_norm": 77.71929931640625,
"init_model_norm": 64.04080963134766,
"end_time": "2025-12-01 15:12:04.720226",
"final_model_norm": 96.16193389892578,
"init_model_norm": 87.41546630859375,
"input_dim": [
512
256
],
"num_batches_per_epoch": 108,
"num_batches_per_epoch": 54,
"num_workers": 0,
"output_dim": [
512
256
],
"start_time": "2025-07-23 14:25:44.459311",
"start_time": "2025-12-01 15:03:08.470563",
"step_scheduler_on_epoch": false
}
}
Expand Down
63 changes: 34 additions & 29 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,18 @@
parser.add_argument('--verbose', action="store_true", help="Verbose mode.")
parser.add_argument('--force-deterministic', action="store_true", help="Use deterministic mode in Pytorch. Might require setting environment variables.")

def run_one(exp_id: str,
config_dir: str=DEFAULTS.config_dir,
output_dir: str=DEFAULTS.output_dir,
data_dir: str=DEFAULTS.data_dir,
device: str=DEFAULTS.device,
num_workers: int=DEFAULTS.num_workers,
data_parallel: Union[list, None]=DEFAULTS.data_parallel,
log_every_k_steps: Union[int, None]=DEFAULTS.log_every_k_steps,
verbose: bool=DEFAULTS.verbose,
force_deterministic: bool=DEFAULTS.force_deterministic
):
def run_one(
exp_id: str,
config_dir: str=DEFAULTS.config_dir,
output_dir: str=DEFAULTS.output_dir,
data_dir: str=DEFAULTS.data_dir,
device: str=DEFAULTS.device,
num_workers: int=DEFAULTS.num_workers,
data_parallel: Union[list, None]=DEFAULTS.data_parallel,
log_every_k_steps: Union[int, None]=DEFAULTS.log_every_k_steps,
verbose: bool=DEFAULTS.verbose,
force_deterministic: bool=DEFAULTS.force_deterministic
):
"""Function for running all runs from one config file.
Default values for all arguments can be found in ``stepback/defaults.py``.

Expand Down Expand Up @@ -83,14 +84,16 @@ def run_one(exp_id: str,

for j, config in enumerate(exp_list):
# each run gets id, by position in the list
B = Base(name=exp_id + f'_{j}',
config=config,
device=device,
data_dir=data_dir,
num_workers=num_workers,
data_parallel=data_parallel,
log_every_k_steps=log_every_k_steps,
verbose=verbose)
B = Base(
name=exp_id + f'_{j}',
config=config,
device=device,
data_dir=data_dir,
num_workers=num_workers,
data_parallel=data_parallel,
log_every_k_steps=log_every_k_steps,
verbose=verbose
)

B.setup()
B.run() # train and validate
Expand All @@ -106,15 +109,17 @@ def run_one(exp_id: str,

print(args)

run_one(args.id,
config_dir=args.config_dir,
output_dir=args.output_dir,
data_dir=args.data_dir,
device=args.device,
num_workers=args.num_workers,
data_parallel=args.data_parallel,
log_every_k_steps=args.log_every_k_steps,
verbose=args.verbose,
force_deterministic=args.force_deterministic)
run_one(
args.id,
config_dir=args.config_dir,
output_dir=args.output_dir,
data_dir=args.data_dir,
device=args.device,
num_workers=args.num_workers,
data_parallel=args.data_parallel,
log_every_k_steps=args.log_every_k_steps,
verbose=args.verbose,
force_deterministic=args.force_deterministic
)


Loading