Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tutorials/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .train_hypergrid_simple_ls import main as train_hypergrid_simple_ls_main
from .train_ising import main as train_ising_main
from .train_line import main as train_line_main
from .train_rng_gfn import main as train_rng_gfn_main


@dataclass
Expand Down Expand Up @@ -170,6 +171,25 @@ class BayesianStructureArgs(CommonArgs):
use_cuda: bool = False


@dataclass
class RNGGFNArgs(CommonArgs):
batch_size: int = 8
n_trajectories: int = 8
n_iterations: int = 10
lr: float = 1e-4
max_length: int = 5
prompt: str = "The following is a random integer drawn uniformly between 0 and 100: "


@pytest.mark.parametrize("n_iterations", [10])
def test_rng_gfn_smoke(n_iterations: int):
"""Smoke test for the RNG GFN training script."""
args = RNGGFNArgs(n_iterations=n_iterations)
args_dict = asdict(args)
namespace_args = Namespace(**args_dict)
train_rng_gfn_main(namespace_args) # Just ensure it runs without errors.


@pytest.mark.parametrize("ndim", [2, 4])
@pytest.mark.parametrize("height", [8, 16])
@pytest.mark.parametrize("replay_buffer_size", [0, 1000])
Expand Down
Loading