Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class OfflineTrainConfig:
initial_rtg: list[float] = (0.0, 1.0)
eval_max_time_steps: int = 100
eval_num_envs: int = 8
num_checkpoints: int = 10

def __post_init__(self):
assert self.model_type in ["decision_transformer", "clone_transformer"]
Expand Down
24 changes: 3 additions & 21 deletions src/decision_transformer/runner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import os
import time
import warnings
Expand All @@ -8,7 +7,6 @@

import wandb
from src.config import (
ConfigJsonEncoder,
EnvironmentConfig,
OfflineTrainConfig,
RunConfig,
Expand All @@ -27,7 +25,7 @@
one_hot_encode_observation,
)
from .train import train
from .utils import get_max_len_from_model_type
from .utils import get_max_len_from_model_type, store_transformer_model


def run_decision_transformer(
Expand Down Expand Up @@ -154,6 +152,8 @@ def run_decision_transformer(
initial_rtg=offline_config.initial_rtg,
eval_max_time_steps=offline_config.eval_max_time_steps,
eval_num_envs=offline_config.eval_num_envs,
exp_name=run_config.exp_name,
offline_config=offline_config,
)

if run_config.track:
Expand All @@ -177,21 +177,3 @@ def run_decision_transformer(
os.remove(model_path)

wandb.finish()


def store_transformer_model(path, model, offline_config):
    """Save a model's weights together with its JSON-encoded configs.

    The saved object is a dict with four entries: ``model_state_dict``
    (the result of ``model.state_dict()``) and three JSON strings,
    ``offline_config``, ``environment_config`` and ``model_config``,
    each produced with ``ConfigJsonEncoder``.

    Args:
        path: Destination handed straight to ``t.save``.
        model: Object exposing ``state_dict()``, ``environment_config``
            and ``transformer_config``.
        offline_config: Offline training configuration to embed.
    """

    def _as_json(config):
        # All configs share the project's custom JSON encoder.
        return json.dumps(config, cls=ConfigJsonEncoder)

    checkpoint = {"model_state_dict": model.state_dict()}
    checkpoint["offline_config"] = _as_json(offline_config)
    checkpoint["environment_config"] = _as_json(model.environment_config)
    checkpoint["model_config"] = _as_json(model.transformer_config)

    # NOTE(review): `t` is presumably the module-level torch alias — confirm.
    t.save(checkpoint, path)
38 changes: 37 additions & 1 deletion src/decision_transformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tqdm import tqdm

import wandb
from src.config import EnvironmentConfig
from src.config import EnvironmentConfig, OfflineTrainConfig
from src.models.trajectory_transformer import (
CloneTransformer,
DecisionTransformer,
Expand All @@ -16,11 +16,13 @@

from .offline_dataset import TrajectoryDataset
from .eval import evaluate_dt_agent
from .utils import store_model_checkpoint


def train(
model: TrajectoryTransformer,
trajectory_data_set: TrajectoryDataset,
offline_config: OfflineTrainConfig,
env,
make_env,
batch_size=128,
Expand All @@ -36,6 +38,7 @@ def train(
initial_rtg=[0.0, 1.0],
eval_max_time_steps=100,
eval_num_envs=8,
exp_name=""
):
loss_fn = nn.CrossEntropyLoss()
model = model.to(device)
Expand Down Expand Up @@ -71,6 +74,20 @@ def train(
test_dataset, batch_size=batch_size, sampler=test_sampler
)

if track:
checkpoint_artifact = wandb.Artifact(
f"{exp_name}_checkpoints", type="model"
)
checkpoint_num = 1
checkpoint_interval = train_epochs // offline_config.num_checkpoints + 1
checkpoint_num = store_model_checkpoint(
model,
exp_name,
offline_config,
checkpoint_num,
checkpoint_artifact
)

train_batches_per_epoch = len(train_dataloader)
pbar = tqdm(range(train_epochs))
for epoch in pbar:
Expand Down Expand Up @@ -175,6 +192,25 @@ def train(
num_envs=eval_num_envs,
)

if track and (epoch + 1) % checkpoint_interval == 0:
checkpoint_num = store_model_checkpoint(
model,
exp_name,
offline_config,
checkpoint_num,
checkpoint_artifact
)

if track:
store_model_checkpoint(
model,
exp_name,
offline_config,
checkpoint_num,
checkpoint_artifact
)
wandb.log_artifact(checkpoint_artifact)

return model


Expand Down
41 changes: 41 additions & 0 deletions src/decision_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from src.config import (
EnvironmentConfig,
TransformerModelConfig,
ConfigJsonEncoder,
)
from src.models.trajectory_transformer import (
DecisionTransformer,
Expand Down Expand Up @@ -75,6 +76,13 @@ def parse_args():
default=False,
action=argparse.BooleanOptionalAction,
)
parser.add_argument(
"--num_checkpoints",
type=int,
default=10,
help="how many checkpoints are stored and uploaded to wandb during training",
)

args = parser.parse_args()
return args

Expand Down Expand Up @@ -198,3 +206,36 @@ def initialize_padding_inputs(
).to(device)

return obs, actions, reward, rtg, timesteps, mask


def store_model_checkpoint(
    model, exp_name, offline_config, checkpoint_num, checkpoint_artifact
) -> int:
    """Write one numbered checkpoint to disk and attach it to an artifact.

    The checkpoint is stored as ``models/{exp_name}_{NN}.pt`` where ``NN``
    is ``checkpoint_num`` left-padded with zeros to at least two characters.
    The file is then added to ``checkpoint_artifact`` under the name
    ``{exp_name}_{NN}.pt``.

    Args:
        model: Model passed on to ``store_transformer_model``.
        exp_name: Experiment name used as the file-name prefix.
        offline_config: Offline training configuration saved alongside
            the model.
        checkpoint_num: Number of the checkpoint being written.
        checkpoint_artifact: Object exposing
            ``add_file(local_path=..., name=...)``.

    Returns:
        ``checkpoint_num + 1``, i.e. the number to use for the next
        checkpoint.
    """
    # Function-scope import so this block does not depend on the
    # module's import list.
    import os

    checkpoint_name = f"{exp_name}_{checkpoint_num:0>2}"
    checkpoint_path = f"models/{checkpoint_name}.pt"

    # The first checkpoint may be requested before anything else has
    # created the output directory; create it (and any sub-directories
    # implied by ``exp_name``) so the save below cannot fail on a missing
    # path.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    store_transformer_model(checkpoint_path, model, offline_config)

    checkpoint_artifact.add_file(
        local_path=checkpoint_path, name=f"{checkpoint_name}.pt"
    )

    return checkpoint_num + 1


def store_transformer_model(path, model, offline_config):
    """Persist model weights and the configs that produced them.

    Saves a single dict via ``t.save`` containing the model's state dict
    plus JSON strings (encoded with ``ConfigJsonEncoder``) of the offline
    training config, the model's environment config and the model's
    transformer config.

    Args:
        path: Target location passed unchanged to ``t.save``.
        model: Object providing ``state_dict()``, ``environment_config``
            and ``transformer_config``.
        offline_config: Offline training configuration to store.
    """
    # Evaluate the pieces one by one, in the same order as the saved keys.
    weights = model.state_dict()
    offline_json = json.dumps(offline_config, cls=ConfigJsonEncoder)
    environment_json = json.dumps(
        model.environment_config, cls=ConfigJsonEncoder
    )
    model_json = json.dumps(model.transformer_config, cls=ConfigJsonEncoder)

    # NOTE(review): `t` is presumably the module-level torch alias — confirm.
    t.save(
        dict(
            model_state_dict=weights,
            offline_config=offline_json,
            environment_config=environment_json,
            model_config=model_json,
        ),
        path,
    )
3 changes: 2 additions & 1 deletion src/run_decision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@
eval_max_time_steps=args.eval_max_time_steps,
track=args.track,
convert_to_one_hot=args.convert_to_one_hot,
device=run_config.device
device=run_config.device,
num_checkpoints=args.num_checkpoints,
)

run_decision_transformer(
Expand Down
35 changes: 33 additions & 2 deletions tests/acceptance/test_model_saving_and_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
import torch
import wandb

from src.config import (
EnvironmentConfig,
Expand All @@ -12,9 +13,8 @@
TransformerModelConfig,
)

from src.decision_transformer.runner import store_transformer_model
from src.decision_transformer.offline_dataset import TrajectoryDataset
from src.decision_transformer.utils import load_decision_transformer
from src.decision_transformer.utils import load_decision_transformer, store_model_checkpoint, store_transformer_model
from src.models.trajectory_transformer import DecisionTransformer


Expand Down Expand Up @@ -145,6 +145,37 @@ def test_load_decision_transformer(
assert new_model.environment_config == environment_config


def test_decision_transformer_checkpoint_saving_and_loading(
    transformer_config, environment_config, offline_config, run_config
):
    """Round-trip a checkpoint through store_model_checkpoint and reload it.

    Checks that the returned checkpoint counter advances from 1 to 2 and
    that the model loaded from ``models/{exp_name}_01.pt`` has the same
    state dict and configs as the model that was saved.
    """
    # Offline mode is requested for the wandb run used by this test.
    wandb.init(mode="offline")

    original = DecisionTransformer(
        environment_config=environment_config,
        transformer_config=transformer_config,
    )
    artifact = wandb.Artifact(
        f"{run_config.exp_name}_checkpoints", type="model"
    )

    next_checkpoint_num = store_model_checkpoint(
        model=original,
        exp_name=run_config.exp_name,
        offline_config=offline_config,
        checkpoint_num=1,
        checkpoint_artifact=artifact,
    )
    assert next_checkpoint_num == 2

    # Checkpoint number 1 is zero-padded to "01" in the file name.
    restored = load_decision_transformer(
        f"models/{run_config.exp_name}_01.pt"
    )

    assert_state_dicts_are_equal(restored.state_dict(), original.state_dict())

    assert restored.transformer_config == transformer_config
    assert restored.environment_config == environment_config


def assert_state_dicts_are_equal(dict1, dict2):
keys1 = sorted(dict1.keys())
keys2 = sorted(dict2.keys())
Expand Down
2 changes: 1 addition & 1 deletion tests/end_end/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)

from src.models.trajectory_transformer import DecisionTransformer
from src.decision_transformer.runner import store_transformer_model
from src.decision_transformer.utils import store_transformer_model


@pytest.fixture()
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_streamlit_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)
from src.models.trajectory_transformer import DecisionTransformer

from src.decision_transformer.runner import store_transformer_model
from src.decision_transformer.utils import store_transformer_model
from src.streamlit_app.environment import get_env_and_dt
from src.environments.registration import register_envs

Expand Down