From b4cb75fa6d1a1ac8b1556f8a0e65c96623ee651e Mon Sep 17 00:00:00 2001 From: John Hartquist Date: Mon, 10 Jan 2022 04:52:37 +0000 Subject: [PATCH 1/8] initial reinforce algorithm --- poetry.lock | 123 ++++++++++++++++++++++++- pyproject.toml | 2 + src/functionrl/algorithms/reinforce.py | 89 ++++++++++++++++++ src/functionrl/experiences.py | 16 +++- 4 files changed, 220 insertions(+), 10 deletions(-) create mode 100644 src/functionrl/algorithms/reinforce.py diff --git a/poetry.lock b/poetry.lock index dd5b9d9..af6d224 100644 --- a/poetry.lock +++ b/poetry.lock @@ -20,6 +20,32 @@ docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"] tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"] tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"] +[[package]] +name = "black" +version = "21.12b0" +description = "The uncompromising code formatter." +category = "dev" +optional = false +python-versions = ">=3.6.2" + +[package.dependencies] +click = ">=7.1.2" +mypy-extensions = ">=0.4.3" +pathspec = ">=0.9.0,<1" +platformdirs = ">=2" +tomli = ">=0.2.6,<2.0.0" +typing-extensions = [ + {version = ">=3.10.0.0", markers = "python_version < \"3.10\""}, + {version = "!=3.10.0.1", markers = "python_version >= \"3.10\""}, +] + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.7.4)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +python2 = ["typed-ast (>=1.4.3)"] +uvloop = ["uvloop (>=0.15.2)"] + [[package]] name = "certifi" version = "2021.10.8" @@ -120,6 +146,14 @@ category = "dev" optional = false python-versions = "*" +[[package]] +name = "mypy-extensions" +version = "0.4.3" +description = "Experimental type system extensions for programs checked with the mypy typechecker." +category = "dev" +optional = false +python-versions = "*" + [[package]] name = "numpy" version = "1.22.0" @@ -160,6 +194,26 @@ pytz = ">=2017.3" [package.extras] test = ["hypothesis (>=3.58)", "pytest (>=6.0)", "pytest-xdist"] +[[package]] +name = "pathspec" +version = "0.9.0" +description = "Utility library for gitignore style pattern matching of file paths." +category = "dev" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" + +[[package]] +name = "platformdirs" +version = "2.4.1" +description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." 
+category = "dev" +optional = false +python-versions = ">=3.7" + +[package.extras] +docs = ["Sphinx (>=4)", "furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx-autodoc-typehints (>=1.12)"] +test = ["appdirs (==1.4.4)", "pytest (>=6)", "pytest-cov (>=2.7)", "pytest-mock (>=3.6)"] + [[package]] name = "pluggy" version = "1.0.0" @@ -296,11 +350,30 @@ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" [[package]] name = "tomli" -version = "2.0.0" +version = "1.2.3" description = "A lil' TOML parser" category = "dev" optional = false -python-versions = ">=3.7" +python-versions = ">=3.6" + +[[package]] +name = "torch" +version = "1.10.1" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +category = "main" +optional = false +python-versions = ">=3.6.2" + +[package.dependencies] +typing-extensions = "*" + +[[package]] +name = "typing-extensions" +version = "4.0.1" +description = "Backported and Experimental Type Hints for Python 3.6+" +category = "main" +optional = false +python-versions = ">=3.6" [[package]] name = "urllib3" @@ -318,7 +391,7 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "3690820d8b14aff176d2c7a2785ed19b97f36801fa9912e86aced8c0aeb5f3c0" +content-hash = "692fd4035dc8cb9de040ee0d1bafcaf5676d8c04d6fa3d94d6703b098fdd3517" [metadata.files] atomicwrites = [ @@ -329,6 +402,10 @@ attrs = [ {file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"}, {file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"}, ] +black = [ + {file = "black-21.12b0-py3-none-any.whl", hash = "sha256:a615e69ae185e08fdd73e4715e260e2479c861b5740057fde6e8b4e3b7dd589f"}, + {file = "black-21.12b0.tar.gz", hash = "sha256:77b80f693a569e2e527958459634f18df9b0ba2625ba4e0c2d5da5be42e6f2b3"}, +] certifi = [ {file = "certifi-2021.10.8-py2.py3-none-any.whl", hash = "sha256:d62a0163eb4c2344ac042ab2bdf75399a71a2d8c7d47eac2e2ee91b9d6339569"}, {file = "certifi-2021.10.8.tar.gz", hash = "sha256:78884e7c1d4b00ce3cea67b44566851c4343c120abd683433ce934a68ea58872"}, @@ -409,6 +486,10 @@ iniconfig = [ {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, ] +mypy-extensions = [ + {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, + {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, +] numpy = [ {file = "numpy-1.22.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3d22662b4b10112c545c91a0741f2436f8ca979ab3d69d03d19322aa970f9695"}, {file = "numpy-1.22.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:11a1f3816ea82eed4178102c56281782690ab5993251fdfd75039aad4d20385f"}, @@ -464,6 +545,14 @@ pandas = [ {file = "pandas-1.3.5-cp39-cp39-win_amd64.whl", hash = "sha256:32e1a26d5ade11b547721a72f9bfc4bd113396947606e00d5b4a5b79b3dcb006"}, {file = "pandas-1.3.5.tar.gz", hash = "sha256:1e4285f5de1012de20ca46b188ccf33521bff61ba5c5ebd78b4fb28e5416a9f1"}, ] +pathspec = [ + {file = "pathspec-0.9.0-py2.py3-none-any.whl", hash = "sha256:7d15c4ddb0b5c802d161efc417ec1a2558ea2653c2e8ad9c19098201dc1c993a"}, + {file = "pathspec-0.9.0.tar.gz", hash = 
"sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"}, +] +platformdirs = [ + {file = "platformdirs-2.4.1-py3-none-any.whl", hash = "sha256:1d7385c7db91728b83efd0ca99a5afb296cab9d0ed8313a45ed8ba17967ecfca"}, + {file = "platformdirs-2.4.1.tar.gz", hash = "sha256:440633ddfebcc36264232365d7840a970e75e1018d15b4327d11f91909045fda"}, +] pluggy = [ {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, @@ -509,8 +598,32 @@ toml = [ {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, ] tomli = [ - {file = "tomli-2.0.0-py3-none-any.whl", hash = "sha256:b5bde28da1fed24b9bd1d4d2b8cba62300bfb4ec9a6187a957e8ddb9434c5224"}, - {file = "tomli-2.0.0.tar.gz", hash = "sha256:c292c34f58502a1eb2bbb9f5bbc9a5ebc37bee10ffb8c2d6bbdfa8eb13cc14e1"}, + {file = "tomli-1.2.3-py3-none-any.whl", hash = "sha256:e3069e4be3ead9668e21cb9b074cd948f7b3113fd9c8bba083f48247aab8b11c"}, + {file = "tomli-1.2.3.tar.gz", hash = "sha256:05b6166bff487dc068d322585c7ea4ef78deed501cc124060e0f238e89a9231f"}, +] +torch = [ + {file = "torch-1.10.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:adbb5f292e260e39715d67478823e03e3001db1af5b02c18caa34549dccb421e"}, + {file = "torch-1.10.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:ac8cae04458cc47555fa07a760496c2fdf687223bcc13df5fed56ea3aead37f5"}, + {file = "torch-1.10.1-cp36-cp36m-win_amd64.whl", hash = "sha256:40508d67288c46ff1fad301fa6e996e0e936a733f2401475fc92c21dc3ef702d"}, + {file = "torch-1.10.1-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:8b47bd113c6cbd9a49669aaaa233ad5f25852d6ca3e640f9c71c808e65a1fdf4"}, + {file = "torch-1.10.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:50360868ad3f039cf99f0250300dbec51bf686a7b84dc6bbdb8dff4b1171c0f0"}, + {file = "torch-1.10.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:e3d2154722189ed74747a494dce9588978dd55e43ca24c5bd307fb52620b232b"}, + {file = "torch-1.10.1-cp37-cp37m-win_amd64.whl", hash = "sha256:d9c495bcd5f00becff5b051b5e4be86b7eaa0433cd0fe57f77c02bc1b93ab5b1"}, + {file = "torch-1.10.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:6b327d7b4eb2461b16d46763d46df71e597235ccc428650538a2735a0898270d"}, + {file = "torch-1.10.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:1c6c56178e5dacf7602ad00dc79c263d6c41c0f76261e9641e6bd2679678ceb3"}, + {file = "torch-1.10.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:2ffa2db4ccb6466c59e3f95b7a582d47ae721e476468f4ffbcaa2832e0b92b9b"}, + {file = "torch-1.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:af577602e884c5e40fbd29ec978f052202355da93cd31e0a23251bd7aaff5a99"}, + {file = "torch-1.10.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:725d86e9809073eef868a3ddf4189878ee7af46fac71403834dd0925b3db9b82"}, + {file = "torch-1.10.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:fa197cfe047d0515bef238f42472721406609ebaceff2fd4e17f2ad4692ee51c"}, + {file = "torch-1.10.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:cca660b27a90dbbc0af06c859260f6b875aef37c0897bd353e5deed085d2c877"}, + {file = "torch-1.10.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:01f4ffdafbfbd7d106fb4e487feee2cf29cced9903df8cb0444b0e308f9c5e92"}, + {file = "torch-1.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:607eccb7d539a11877cd02d95f4b164b7941fcf538ac7ff087bfed19e3644283"}, + {file = 
"torch-1.10.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:26b6dfbe21e247e67c615bfab0017ec391ed1517f88bbeea6228a49edd24cd88"}, + {file = "torch-1.10.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:5644280d88c5b6de27eacc0d911f968aad41a4bab297af4df5e571bc0927d3e4"}, +] +typing-extensions = [ + {file = "typing_extensions-4.0.1-py3-none-any.whl", hash = "sha256:7f001e5ac290a0c0401508864c7ec868be4e701886d5b573a9528ed3973d9d3b"}, + {file = "typing_extensions-4.0.1.tar.gz", hash = "sha256:4ca091dea149f945ec56afb48dae714f21e8692ef22a395223bcd328961b6a0e"}, ] urllib3 = [ {file = "urllib3-1.26.8-py2.py3-none-any.whl", hash = "sha256:000ca7f471a233c2251c6c7023ee85305721bfdf18621ebff4fd17a8653427ed"}, diff --git a/pyproject.toml b/pyproject.toml index 7620489..53338a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,12 +16,14 @@ requests = "^2.27.1" numpy = "^1.22.0" gym = "^0.21.0" pandas = "^1.3.5" +torch = "^1.10.1" [tool.poetry.dev-dependencies] coverage = {extras = ["toml"], version = "^6.2"} pytest = "^6.2.5" pytest-cov = "^3.0.0" pytest-mock = "^3.6.1" +black = "^21.12b0" [tool.poetry.scripts] diff --git a/src/functionrl/algorithms/reinforce.py b/src/functionrl/algorithms/reinforce.py new file mode 100644 index 0000000..c35fe62 --- /dev/null +++ b/src/functionrl/algorithms/reinforce.py @@ -0,0 +1,89 @@ +import numpy as np +import torch +from torch import nn, optim +from torch.nn import functional as F +from torch.distributions import Categorical +from ..utils import linear_decay +from ..policies import make_epsilon_greedy_policy, make_greedy_policy +from ..experiences import generate_episodes +from ..envs import make_frozen_lake +from ..display import print_pi, print_v + + +class Pi(nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.model = nn.Sequential( + nn.Linear(in_dim, 64), + nn.ReLU(), + nn.Linear(64, out_dim) + ) + + def act(self, state): + state_t = torch.tensor([state]) + state_oh = F.one_hot(state_t, num_classes=self.in_dim).float() + logits = self.forward(state_oh) + pd = Categorical(logits=logits) + action = pd.sample() + log_prob = pd.log_prob(action) + return action.item(), log_prob # .squeeze() + + def forward(self, x): + return self.model(x) + + +def reinforce( + make_env, + gamma: float = 1.0, + learning_rate: float = 1e-3, + n_episodes: int = 10000, + log_interval: int = 100, + eval_episodes: int = 100, +): + env_train = make_env() + env_eval = make_env() + + n_states = env_train.observation_space.n + n_actions = env_train.action_space.n + + pi = Pi(n_states, n_actions) + optimizer = optim.Adam(pi.parameters(), lr=learning_rate) + + def policy(state): + action, log_prob = pi.act(state) + return action, {"log_prob": log_prob} + + for i, episode in enumerate( + generate_episodes(env_train, policy, n=n_episodes), start=1 + ): + T = len(episode) + rewards = [exp.reward for exp in episode] + log_probs = [exp.policy_info["log_prob"] for exp in episode] + rets = np.empty(T, dtype=np.float32) + future_ret = 0.0 + for t in reversed(range(T)): + future_ret = rewards[t] + gamma * future_ret + rets[t] = future_ret + rets = torch.tensor(rets) + log_probs = torch.stack(log_probs) + loss = (-log_probs * rets).sum() + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if i % log_interval == 0: + with torch.no_grad(): + episodes = list(generate_episodes(env_eval, policy, n=eval_episodes)) + returns = [sum(e.reward for e in episode) for episode in episodes] + mean_return = np.mean(returns) + 
print(f"{i:5d}: {mean_return:.3f}") + + +if __name__ == "__main__": + reinforce( + make_frozen_lake, + gamma=0.99, + n_episodes=100000, + ) diff --git a/src/functionrl/experiences.py b/src/functionrl/experiences.py index 206f740..7d83b46 100644 --- a/src/functionrl/experiences.py +++ b/src/functionrl/experiences.py @@ -2,16 +2,22 @@ from collections import namedtuple from .utils import limitable -Experience = namedtuple("Experience", "state action reward next_state is_done") +Experience = namedtuple( + "Experience", + ["state", "action", "reward", "next_state", "is_done", "policy_info"], +) def generate_episode(env, policy): next_state = env.reset() - for step in count(): + while True: state = next_state - action = policy(state, step) - next_state, reward, is_done, _ = env.step(action) - yield Experience(state, action, reward, next_state, is_done) + action = policy(state) + policy_info = None + if isinstance(action, tuple): + action, policy_info = action + next_state, reward, is_done, env_info = env.step(action) + yield Experience(state, action, reward, next_state, is_done, policy_info) if is_done: break From b5fe0bcf143ed9d651592ec6ae8204985984c729 Mon Sep 17 00:00:00 2001 From: John Hartquist Date: Mon, 10 Jan 2022 05:23:09 +0000 Subject: [PATCH 2/8] solve frozenlake with reinforce --- src/functionrl/algorithms/reinforce.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/functionrl/algorithms/reinforce.py b/src/functionrl/algorithms/reinforce.py index c35fe62..243a39d 100644 --- a/src/functionrl/algorithms/reinforce.py +++ b/src/functionrl/algorithms/reinforce.py @@ -15,14 +15,15 @@ def __init__(self, in_dim, out_dim): super().__init__() self.in_dim = in_dim self.out_dim = out_dim - self.model = nn.Sequential( - nn.Linear(in_dim, 64), - nn.ReLU(), - nn.Linear(64, out_dim) - ) + # self.model = nn.Sequential( + # nn.Linear(in_dim, 64), + # nn.ReLU(), + # nn.Linear(64, out_dim) + # ) + self.model = nn.Linear(in_dim, out_dim) def act(self, state): - state_t = torch.tensor([state]) + state_t = torch.tensor([state]) # .to("cuda") state_oh = F.one_hot(state_t, num_classes=self.in_dim).float() logits = self.forward(state_oh) pd = Categorical(logits=logits) @@ -40,7 +41,7 @@ def reinforce( learning_rate: float = 1e-3, n_episodes: int = 10000, log_interval: int = 100, - eval_episodes: int = 100, + eval_episodes: int = 1000, ): env_train = make_env() env_eval = make_env() @@ -48,7 +49,7 @@ def reinforce( n_states = env_train.observation_space.n n_actions = env_train.action_space.n - pi = Pi(n_states, n_actions) + pi = Pi(n_states, n_actions) # .to("cuda") optimizer = optim.Adam(pi.parameters(), lr=learning_rate) def policy(state): @@ -66,7 +67,7 @@ def policy(state): for t in reversed(range(T)): future_ret = rewards[t] + gamma * future_ret rets[t] = future_ret - rets = torch.tensor(rets) + rets = torch.tensor(rets) # .to("cuda") log_probs = torch.stack(log_probs) loss = (-log_probs * rets).sum() optimizer.zero_grad() @@ -85,5 +86,6 @@ def policy(state): reinforce( make_frozen_lake, gamma=0.99, - n_episodes=100000, + learning_rate=0.01, + n_episodes=10000, ) From 8257a0be64295bdb938a11ad07141dc0660359ff Mon Sep 17 00:00:00 2001 From: John Hartquist Date: Tue, 11 Jan 2022 05:04:36 +0000 Subject: [PATCH 3/8] refactor policies, make deterministic, fix tests --- .devcontainer/Dockerfile | 2 +- .devcontainer/devcontainer.json | 2 +- src/functionrl/algorithms/reinforce.py | 23 +++++------- src/functionrl/algorithms/tabular_q.py | 40 ++++++++++----------- 
src/functionrl/experiences.py | 12 +++---- src/functionrl/policies.py | 49 ++++++++++---------------- src/functionrl/utils.py | 7 +++- tests/test_experiences.py | 8 ++--- tests/test_policies.py | 2 +- 9 files changed, 64 insertions(+), 81 deletions(-) diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 822dec6..3169257 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -16,7 +16,7 @@ USER $USERNAME # install poetry for package management RUN curl -sSL https://install.python-poetry.org | python3 - -ENV PATH="~/.local/bin:$PATH" +ENV PATH="/home/$USERNAME/.local/bin:$PATH" WORKDIR $DIR diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 00bf2dd..ca3e97d 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,5 +1,5 @@ { "build": { "dockerfile": "Dockerfile", "context": ".." }, "runArgs": ["--gpus=all"], - "extensions": ["ms-python.python", "tamasfe.even-better-toml"], + "extensions": ["ms-python.python", "tamasfe.even-better-toml"] } diff --git a/src/functionrl/algorithms/reinforce.py b/src/functionrl/algorithms/reinforce.py index 243a39d..5c7090b 100644 --- a/src/functionrl/algorithms/reinforce.py +++ b/src/functionrl/algorithms/reinforce.py @@ -5,7 +5,7 @@ from torch.distributions import Categorical from ..utils import linear_decay from ..policies import make_epsilon_greedy_policy, make_greedy_policy -from ..experiences import generate_episodes +from ..experiences import gen_episodes from ..envs import make_frozen_lake from ..display import print_pi, print_v @@ -15,21 +15,16 @@ def __init__(self, in_dim, out_dim): super().__init__() self.in_dim = in_dim self.out_dim = out_dim - # self.model = nn.Sequential( - # nn.Linear(in_dim, 64), - # nn.ReLU(), - # nn.Linear(64, out_dim) - # ) - self.model = nn.Linear(in_dim, out_dim) + self.model = nn.Linear(in_dim, out_dim, bias=False) def act(self, state): - state_t = torch.tensor([state]) # .to("cuda") + state_t = torch.tensor([state]) state_oh = F.one_hot(state_t, num_classes=self.in_dim).float() logits = self.forward(state_oh) pd = Categorical(logits=logits) action = pd.sample() log_prob = pd.log_prob(action) - return action.item(), log_prob # .squeeze() + return action.item(), log_prob def forward(self, x): return self.model(x) @@ -49,16 +44,14 @@ def reinforce( n_states = env_train.observation_space.n n_actions = env_train.action_space.n - pi = Pi(n_states, n_actions) # .to("cuda") + pi = Pi(n_states, n_actions) optimizer = optim.Adam(pi.parameters(), lr=learning_rate) def policy(state): action, log_prob = pi.act(state) return action, {"log_prob": log_prob} - for i, episode in enumerate( - generate_episodes(env_train, policy, n=n_episodes), start=1 - ): + for i, episode in enumerate(gen_episodes(env_train, policy, n=n_episodes), start=1): T = len(episode) rewards = [exp.reward for exp in episode] log_probs = [exp.policy_info["log_prob"] for exp in episode] @@ -67,7 +60,7 @@ def policy(state): for t in reversed(range(T)): future_ret = rewards[t] + gamma * future_ret rets[t] = future_ret - rets = torch.tensor(rets) # .to("cuda") + rets = torch.tensor(rets) log_probs = torch.stack(log_probs) loss = (-log_probs * rets).sum() optimizer.zero_grad() @@ -76,7 +69,7 @@ def policy(state): if i % log_interval == 0: with torch.no_grad(): - episodes = list(generate_episodes(env_eval, policy, n=eval_episodes)) + episodes = list(gen_episodes(env_eval, policy, n=eval_episodes)) returns = [sum(e.reward for e in episode) for episode in episodes] mean_return 
= np.mean(returns) print(f"{i:5d}: {mean_return:.3f}") diff --git a/src/functionrl/algorithms/tabular_q.py b/src/functionrl/algorithms/tabular_q.py index a9177b4..ece5fdc 100644 --- a/src/functionrl/algorithms/tabular_q.py +++ b/src/functionrl/algorithms/tabular_q.py @@ -1,7 +1,9 @@ +from itertools import count +from typing import Optional import numpy as np from ..utils import linear_decay from ..policies import make_epsilon_greedy_policy, make_greedy_policy -from ..experiences import generate_experiences, generate_episodes +from ..experiences import gen_experiences, gen_episodes from ..envs import make_frozen_lake from ..display import print_pi, print_v @@ -18,6 +20,7 @@ def tabular_q( n_steps: int = 5000, log_interval: int = 1000, eval_episodes: int = 1000, + seed: Optional[int] = None, ): env_train = make_env() env_eval = make_env() @@ -25,36 +28,30 @@ def tabular_q( n_states = env_train.observation_space.n n_actions = env_train.action_space.n + q = np.zeros((n_states, n_actions)) + alpha_decay = linear_decay(alpha_max, alpha_min, alpha_decay_steps) epsilon_decay = linear_decay(epsilon_max, epsilon_min, epsilon_decay_steps) - q = np.zeros((n_states, n_actions)) - - # TODO: pass decay into make_eps - policy_train = make_epsilon_greedy_policy( - q, epsilon_max, epsilon_min, epsilon_decay_steps - ) + policy_train = make_epsilon_greedy_policy(q, epsilon_decay, seed=seed) policy_eval = make_greedy_policy(q) - for step, exp in enumerate( - generate_experiences(env_train, policy_train, n=n_steps) + for i, exp in enumerate( + gen_experiences(env_train, policy_train, n=n_steps), start=1 ): + state, action, reward, next_state, is_done, policy_info = exp + td_target = reward + gamma * float(not is_done) * q[next_state].max() + td_error = td_target - q[state, action] - td_target = ( - exp.reward + gamma * float(not exp.is_done) * q[exp.next_state].max() - ) - td_error = td_target - q[exp.state, exp.action] - - alpha = alpha_decay(step) - q[exp.state, exp.action] += alpha * td_error + alpha = alpha_decay(i) + q[state, action] += alpha * td_error - if (step + 1) % log_interval == 0: - episodes = list(generate_episodes(env_eval, policy_eval, n=eval_episodes)) + if i % log_interval == 0: + epsilon = policy_info["epsilon"] + episodes = list(gen_episodes(env_eval, policy_eval, n=eval_episodes)) returns = [sum(e.reward for e in episode) for episode in episodes] mean_return = np.mean(returns) - print( - f"{step+1:5d}: {mean_return:.3f}, eps: {epsilon_decay(step):.3f}, alpha: {alpha:.6f}" - ) + print(f"{i:5d}: {mean_return:.3f}, eps: {epsilon:.3f}, alpha: {alpha:.6f}") return q @@ -71,6 +68,7 @@ def tabular_q( epsilon_decay_steps=100_000, n_steps=100_000, log_interval=10_000, + seed=0, ) pi = np.argmax(q, axis=1) print_pi(pi) diff --git a/src/functionrl/experiences.py b/src/functionrl/experiences.py index 7d83b46..5428aba 100644 --- a/src/functionrl/experiences.py +++ b/src/functionrl/experiences.py @@ -8,7 +8,7 @@ ) -def generate_episode(env, policy): +def gen_episode(env, policy): next_state = env.reset() while True: state = next_state @@ -16,19 +16,19 @@ def generate_episode(env, policy): policy_info = None if isinstance(action, tuple): action, policy_info = action - next_state, reward, is_done, env_info = env.step(action) + next_state, reward, is_done, _ = env.step(action) yield Experience(state, action, reward, next_state, is_done, policy_info) if is_done: break @limitable -def generate_episodes(env, policy): +def gen_episodes(env, policy): while True: - yield list(generate_episode(env, policy)) + 
yield list(gen_episode(env, policy)) @limitable -def generate_experiences(env, policy): +def gen_experiences(env, policy): while True: - yield from generate_episode(env, policy) + yield from gen_episode(env, policy) diff --git a/src/functionrl/policies.py b/src/functionrl/policies.py index 07cc4d6..ee0d8d9 100644 --- a/src/functionrl/policies.py +++ b/src/functionrl/policies.py @@ -1,49 +1,36 @@ import numpy as np -from .utils import linear_decay +from .utils import decay_generator def make_random_policy(n_actions, seed=None): rng = np.random.default_rng(seed=seed) - def _policy(state, step): + def _policy(state): return rng.integers(n_actions) return _policy -def make_greedy_action_selector(q): - return lambda state: q[state].argmax() - - -def make_random_action_selector(n_actions): - return lambda state: np.random.randint(n_actions) - - -def make_epsilon_greedy_policy(q, eps_start, eps_end, eps_decay_steps): - n_states, n_actions = q.shape - greedy_selector = make_greedy_action_selector(q) - random_selector = make_random_action_selector(n_actions) - eps_decay = linear_decay(eps_start, eps_end, eps_decay_steps) - - step = 0 - - def _policy(state, _): - nonlocal step - epsilon = eps_decay(step) - step += 1 - is_exploring = np.random.random() < epsilon - selector = random_selector if is_exploring else greedy_selector - action = selector(state) - return action +def make_greedy_policy(q): + def _policy(state): + return q[state].argmax() return _policy -def make_greedy_policy(q): - greedy_selector = make_greedy_action_selector(q) +def make_epsilon_greedy_policy(q, epsilon_decay_fn, seed=None): + rng = np.random.default_rng(seed=seed) - def _policy(state, step): - action = greedy_selector(state) - return action + n_actions = q.shape[1] + greedy_policy = make_greedy_policy(q) + random_policy = make_random_policy(n_actions, seed=seed) + epsilon_generator = iter(decay_generator(epsilon_decay_fn)) + + def _policy(state): + epsilon = next(epsilon_generator) + is_exploring = rng.random() < epsilon + policy = random_policy if is_exploring else greedy_policy + action = policy(state) + return action, {"epsilon": epsilon} return _policy diff --git a/src/functionrl/utils.py b/src/functionrl/utils.py index ebe0cd2..c887aae 100644 --- a/src/functionrl/utils.py +++ b/src/functionrl/utils.py @@ -1,5 +1,5 @@ from functools import wraps -from itertools import islice +from itertools import count, islice def linear_decay(start, end, decay_steps): @@ -7,6 +7,11 @@ def linear_decay(start, end, decay_steps): return lambda step: start + delta * step if step < decay_steps else end +def decay_generator(decay_fn): + for i in count(): + yield decay_fn(i) + + def limitable(func): @wraps(func) def wrapper(*args, n=None, **kwargs): diff --git a/tests/test_experiences.py b/tests/test_experiences.py index ee71e56..45bc726 100644 --- a/tests/test_experiences.py +++ b/tests/test_experiences.py @@ -1,7 +1,7 @@ import pytest from functionrl.envs import make_frozen_lake -from functionrl.experiences import generate_episode, generate_episodes, generate_experiences +from functionrl.experiences import gen_episode, gen_episodes, gen_experiences from functionrl.policies import make_random_policy @@ -16,17 +16,17 @@ def policy(): def test_generate_episode(env, policy): - episode = list(generate_episode(env, policy)) + episode = list(gen_episode(env, policy)) assert len(episode) == 5 assert episode[-1].is_done def test_episode_gen(env, policy): - episodes = list(generate_episodes(env, policy, n=4)) + episodes = list(gen_episodes(env, 
policy, n=4)) assert [len(episode) for episode in episodes] == [5, 2, 2, 3] def test_generate_experiences(env, policy): - experiences = list(generate_experiences(env, policy, n=10)) + experiences = list(gen_experiences(env, policy, n=10)) assert len(experiences) == 10 assert len([e for e in experiences if e.is_done]) == 3 diff --git a/tests/test_policies.py b/tests/test_policies.py index 33e1803..ec1aba7 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -3,5 +3,5 @@ def test_make_random_policy(): policy = make_random_policy(4, seed=1) - actions = [policy(None, step) for step in range(8)] + actions = [policy(None) for _ in range(8)] assert actions == [1, 2, 3, 3, 0, 0, 3, 3] From 7b65c287d174153cb62f489de3ebb153972c1840 Mon Sep 17 00:00:00 2001 From: John Hartquist Date: Tue, 11 Jan 2022 05:12:00 +0000 Subject: [PATCH 4/8] speed up reinforce evaluation --- src/functionrl/algorithms/reinforce.py | 30 +++++++++++++++++++------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/src/functionrl/algorithms/reinforce.py b/src/functionrl/algorithms/reinforce.py index 5c7090b..df873e2 100644 --- a/src/functionrl/algorithms/reinforce.py +++ b/src/functionrl/algorithms/reinforce.py @@ -1,13 +1,11 @@ import numpy as np import torch from torch import nn, optim -from torch.nn import functional as F from torch.distributions import Categorical -from ..utils import linear_decay -from ..policies import make_epsilon_greedy_policy, make_greedy_policy -from ..experiences import gen_episodes +from torch.nn import functional as F + from ..envs import make_frozen_lake -from ..display import print_pi, print_v +from ..experiences import gen_episodes class Pi(nn.Module): @@ -47,10 +45,20 @@ def reinforce( pi = Pi(n_states, n_actions) optimizer = optim.Adam(pi.parameters(), lr=learning_rate) + print(pi) + def policy(state): action, log_prob = pi.act(state) return action, {"log_prob": log_prob} + def make_eval_policy(): + states = torch.arange(n_states) + states_oh = F.one_hot(states, num_classes=n_states).float() + logits = pi(states_oh) + pds = [Categorical(logits=sl) for sl in logits] + return lambda state: pds[state].sample().item() + + losses = [] for i, episode in enumerate(gen_episodes(env_train, policy, n=n_episodes), start=1): T = len(episode) rewards = [exp.reward for exp in episode] @@ -60,19 +68,25 @@ def policy(state): for t in reversed(range(T)): future_ret = rewards[t] + gamma * future_ret rets[t] = future_ret - rets = torch.tensor(rets) + rets = torch.tensor(rets) # .to("cuda") + # rets.sub_(rets.mean()) log_probs = torch.stack(log_probs) loss = (-log_probs * rets).sum() optimizer.zero_grad() loss.backward() optimizer.step() + losses.append(loss.item()) + if i % log_interval == 0: with torch.no_grad(): - episodes = list(gen_episodes(env_eval, policy, n=eval_episodes)) + episodes = list( + gen_episodes(env_eval, make_eval_policy(), n=eval_episodes) + ) returns = [sum(e.reward for e in episode) for episode in episodes] mean_return = np.mean(returns) - print(f"{i:5d}: {mean_return:.3f}") + mean_loss = np.array(losses[-log_interval:]).mean() + print(f"{i:5d}: {mean_return:.3f} - loss: {mean_loss:8.4f}") if __name__ == "__main__": From de2f7264244e1a396bff06f5ce7209078b5a49d0 Mon Sep 17 00:00:00 2001 From: John Hartquist Date: Tue, 11 Jan 2022 17:49:32 +0000 Subject: [PATCH 5/8] add *.code-workspace to .gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index d9005f2..91668e1 100644 --- a/.gitignore +++ b/.gitignore @@ -150,3 
+150,6 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +# VSCode +*.code-workspace From d6e8b36fbe210b8ca0477af8b7dc739814c56827 Mon Sep 17 00:00:00 2001 From: John Hartquist Date: Tue, 11 Jan 2022 18:10:40 +0000 Subject: [PATCH 6/8] add evaluate_policy helper and add seeding to reinforce --- src/functionrl/algorithms/reinforce.py | 28 +++++++++++++++----------- src/functionrl/algorithms/tabular_q.py | 23 +++++++++------------ src/functionrl/experiences.py | 4 ++++ src/functionrl/policies.py | 10 +++++++++ 4 files changed, 39 insertions(+), 26 deletions(-) diff --git a/src/functionrl/algorithms/reinforce.py b/src/functionrl/algorithms/reinforce.py index df873e2..f3a08ed 100644 --- a/src/functionrl/algorithms/reinforce.py +++ b/src/functionrl/algorithms/reinforce.py @@ -1,9 +1,12 @@ +from typing import Optional import numpy as np import torch from torch import nn, optim from torch.distributions import Categorical from torch.nn import functional as F +from functionrl.policies import evaluate_policy + from ..envs import make_frozen_lake from ..experiences import gen_episodes @@ -35,12 +38,15 @@ def reinforce( n_episodes: int = 10000, log_interval: int = 100, eval_episodes: int = 1000, + seed: Optional[int] = None, ): - env_train = make_env() - env_eval = make_env() + if seed is not None: + torch.manual_seed(seed) + + env = make_env() - n_states = env_train.observation_space.n - n_actions = env_train.action_space.n + n_states = env.observation_space.n + n_actions = env.action_space.n pi = Pi(n_states, n_actions) optimizer = optim.Adam(pi.parameters(), lr=learning_rate) @@ -54,12 +60,13 @@ def policy(state): def make_eval_policy(): states = torch.arange(n_states) states_oh = F.one_hot(states, num_classes=n_states).float() - logits = pi(states_oh) + with torch.no_grad(): + logits = pi(states_oh) pds = [Categorical(logits=sl) for sl in logits] return lambda state: pds[state].sample().item() losses = [] - for i, episode in enumerate(gen_episodes(env_train, policy, n=n_episodes), start=1): + for i, episode in enumerate(gen_episodes(env, policy, n=n_episodes), start=1): T = len(episode) rewards = [exp.reward for exp in episode] log_probs = [exp.policy_info["log_prob"] for exp in episode] @@ -79,12 +86,8 @@ def make_eval_policy(): losses.append(loss.item()) if i % log_interval == 0: - with torch.no_grad(): - episodes = list( - gen_episodes(env_eval, make_eval_policy(), n=eval_episodes) - ) - returns = [sum(e.reward for e in episode) for episode in episodes] - mean_return = np.mean(returns) + policy = make_eval_policy() + mean_return = evaluate_policy(make_env, policy, eval_episodes) mean_loss = np.array(losses[-log_interval:]).mean() print(f"{i:5d}: {mean_return:.3f} - loss: {mean_loss:8.4f}") @@ -95,4 +98,5 @@ def make_eval_policy(): gamma=0.99, learning_rate=0.01, n_episodes=10000, + seed=0, ) diff --git a/src/functionrl/algorithms/tabular_q.py b/src/functionrl/algorithms/tabular_q.py index ece5fdc..53e4202 100644 --- a/src/functionrl/algorithms/tabular_q.py +++ b/src/functionrl/algorithms/tabular_q.py @@ -2,8 +2,8 @@ from typing import Optional import numpy as np from ..utils import linear_decay -from ..policies import make_epsilon_greedy_policy, make_greedy_policy -from ..experiences import gen_experiences, gen_episodes +from ..policies import evaluate_policy, make_epsilon_greedy_policy, make_greedy_policy +from ..experiences import 
gen_experiences from ..envs import make_frozen_lake from ..display import print_pi, print_v @@ -22,11 +22,10 @@ def tabular_q( eval_episodes: int = 1000, seed: Optional[int] = None, ): - env_train = make_env() - env_eval = make_env() + env = make_env() - n_states = env_train.observation_space.n - n_actions = env_train.action_space.n + n_states = env.observation_space.n + n_actions = env.action_space.n q = np.zeros((n_states, n_actions)) @@ -36,9 +35,7 @@ def tabular_q( policy_train = make_epsilon_greedy_policy(q, epsilon_decay, seed=seed) policy_eval = make_greedy_policy(q) - for i, exp in enumerate( - gen_experiences(env_train, policy_train, n=n_steps), start=1 - ): + for i, exp in enumerate(gen_experiences(env, policy_train, n=n_steps), start=1): state, action, reward, next_state, is_done, policy_info = exp td_target = reward + gamma * float(not is_done) * q[next_state].max() td_error = td_target - q[state, action] @@ -48,10 +45,10 @@ def tabular_q( if i % log_interval == 0: epsilon = policy_info["epsilon"] - episodes = list(gen_episodes(env_eval, policy_eval, n=eval_episodes)) - returns = [sum(e.reward for e in episode) for episode in episodes] - mean_return = np.mean(returns) + mean_return = evaluate_policy(make_env, policy_eval, eval_episodes) print(f"{i:5d}: {mean_return:.3f}, eps: {epsilon:.3f}, alpha: {alpha:.6f}") + pi = np.argmax(q, axis=1) + print_pi(pi) return q @@ -70,7 +67,5 @@ def tabular_q( log_interval=10_000, seed=0, ) - pi = np.argmax(q, axis=1) - print_pi(pi) v = np.max(q, axis=1) print_v(v) diff --git a/src/functionrl/experiences.py b/src/functionrl/experiences.py index 5428aba..3116270 100644 --- a/src/functionrl/experiences.py +++ b/src/functionrl/experiences.py @@ -32,3 +32,7 @@ def gen_episodes(env, policy): def gen_experiences(env, policy): while True: yield from gen_episode(env, policy) + + +def calc_episode_return(episode): + return sum(experience.reward for experience in episode) diff --git a/src/functionrl/policies.py b/src/functionrl/policies.py index ee0d8d9..4a29fbc 100644 --- a/src/functionrl/policies.py +++ b/src/functionrl/policies.py @@ -1,4 +1,6 @@ import numpy as np + +from functionrl.experiences import calc_episode_return, gen_episodes from .utils import decay_generator @@ -34,3 +36,11 @@ def _policy(state): return action, {"epsilon": epsilon} return _policy + + +def evaluate_policy(make_env, policy, n_episodes): + env = make_env() + episodes = gen_episodes(env, policy, n=n_episodes) + returns = [calc_episode_return(episode) for episode in episodes] + mean_return = np.mean(returns) + return mean_return From f3261a8b321418f36ba070267d8615723a10daa3 Mon Sep 17 00:00:00 2001 From: John Hartquist Date: Thu, 13 Jan 2022 05:35:33 +0000 Subject: [PATCH 7/8] refactor model policy code --- src/functionrl/algorithms/reinforce.py | 60 +++++++------------------- src/functionrl/algorithms/tabular_q.py | 1 - src/functionrl/models.py | 22 ++++++++++ src/functionrl/policies.py | 22 +++++++++- 4 files changed, 58 insertions(+), 47 deletions(-) create mode 100644 src/functionrl/models.py diff --git a/src/functionrl/algorithms/reinforce.py b/src/functionrl/algorithms/reinforce.py index f3a08ed..bde6d4d 100644 --- a/src/functionrl/algorithms/reinforce.py +++ b/src/functionrl/algorithms/reinforce.py @@ -1,36 +1,19 @@ from typing import Optional + import numpy as np import torch -from torch import nn, optim -from torch.distributions import Categorical -from torch.nn import functional as F - -from functionrl.policies import evaluate_policy +from functionrl.models import 
LinearNet +from functionrl.policies import ( + evaluate_policy, + make_categorical_policy_from_model, + make_greedy_policy_from_model, +) +from torch import optim from ..envs import make_frozen_lake from ..experiences import gen_episodes -class Pi(nn.Module): - def __init__(self, in_dim, out_dim): - super().__init__() - self.in_dim = in_dim - self.out_dim = out_dim - self.model = nn.Linear(in_dim, out_dim, bias=False) - - def act(self, state): - state_t = torch.tensor([state]) - state_oh = F.one_hot(state_t, num_classes=self.in_dim).float() - logits = self.forward(state_oh) - pd = Categorical(logits=logits) - action = pd.sample() - log_prob = pd.log_prob(action) - return action.item(), log_prob - - def forward(self, x): - return self.model(x) - - def reinforce( make_env, gamma: float = 1.0, @@ -44,26 +27,14 @@ def reinforce( torch.manual_seed(seed) env = make_env() - n_states = env.observation_space.n n_actions = env.action_space.n - pi = Pi(n_states, n_actions) - optimizer = optim.Adam(pi.parameters(), lr=learning_rate) - + pi = LinearNet(n_states, n_actions) print(pi) - def policy(state): - action, log_prob = pi.act(state) - return action, {"log_prob": log_prob} - - def make_eval_policy(): - states = torch.arange(n_states) - states_oh = F.one_hot(states, num_classes=n_states).float() - with torch.no_grad(): - logits = pi(states_oh) - pds = [Categorical(logits=sl) for sl in logits] - return lambda state: pds[state].sample().item() + optimizer = optim.Adam(pi.parameters(), lr=learning_rate) + policy = make_categorical_policy_from_model(pi) losses = [] for i, episode in enumerate(gen_episodes(env, policy, n=n_episodes), start=1): @@ -75,7 +46,7 @@ def make_eval_policy(): for t in reversed(range(T)): future_ret = rewards[t] + gamma * future_ret rets[t] = future_ret - rets = torch.tensor(rets) # .to("cuda") + rets = torch.tensor(rets) # rets.sub_(rets.mean()) log_probs = torch.stack(log_probs) loss = (-log_probs * rets).sum() @@ -86,10 +57,10 @@ def make_eval_policy(): losses.append(loss.item()) if i % log_interval == 0: - policy = make_eval_policy() + policy = make_greedy_policy_from_model(pi, n_states) mean_return = evaluate_policy(make_env, policy, eval_episodes) mean_loss = np.array(losses[-log_interval:]).mean() - print(f"{i:5d}: {mean_return:.3f} - loss: {mean_loss:8.4f}") + print(f"{i:5d} mean_return: {mean_return:.3f} - loss: {mean_loss:8.4f}") if __name__ == "__main__": @@ -98,5 +69,6 @@ def make_eval_policy(): gamma=0.99, learning_rate=0.01, n_episodes=10000, - seed=0, + seed=1, + eval_episodes=1000, ) diff --git a/src/functionrl/algorithms/tabular_q.py b/src/functionrl/algorithms/tabular_q.py index 53e4202..aa4cc30 100644 --- a/src/functionrl/algorithms/tabular_q.py +++ b/src/functionrl/algorithms/tabular_q.py @@ -1,4 +1,3 @@ -from itertools import count from typing import Optional import numpy as np from ..utils import linear_decay diff --git a/src/functionrl/models.py b/src/functionrl/models.py new file mode 100644 index 0000000..89388a2 --- /dev/null +++ b/src/functionrl/models.py @@ -0,0 +1,22 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +def to_tensor(x): + return torch.as_tensor(x) + + +class LinearNet(nn.Module): + def __init__(self, in_dim, out_dim, one_hot=True): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.one_hot = one_hot + self.layer = nn.Linear(in_dim, out_dim, bias=False) + + def forward(self, state): + state = to_tensor(state) + if self.one_hot: + state = F.one_hot(state, 
num_classes=self.in_dim).float() + return self.layer(state) diff --git a/src/functionrl/policies.py b/src/functionrl/policies.py index 4a29fbc..3bc01d1 100644 --- a/src/functionrl/policies.py +++ b/src/functionrl/policies.py @@ -1,7 +1,8 @@ import numpy as np - +import torch from functionrl.experiences import calc_episode_return, gen_episodes from .utils import decay_generator +from torch.distributions import Categorical def make_random_policy(n_actions, seed=None): @@ -20,6 +21,12 @@ def _policy(state): return _policy +def make_greedy_policy_from_model(model, n_states): + with torch.no_grad(): + q = model(torch.arange(n_states)).numpy() + return make_greedy_policy(q) + + def make_epsilon_greedy_policy(q, epsilon_decay_fn, seed=None): rng = np.random.default_rng(seed=seed) @@ -33,7 +40,18 @@ def _policy(state): is_exploring = rng.random() < epsilon policy = random_policy if is_exploring else greedy_policy action = policy(state) - return action, {"epsilon": epsilon} + return action, {"epsilon": epsilon, "is_exploring": is_exploring} + + return _policy + + +def make_categorical_policy_from_model(model): + def _policy(state): + logits = model(state) + prob_dist = Categorical(logits=logits) + action = prob_dist.sample() + log_prob = prob_dist.log_prob(action) + return action.item(), {"log_prob": log_prob} return _policy From 479d88cdb5943730b226251a1114aa0ca8775bd5 Mon Sep 17 00:00:00 2001 From: John Hartquist Date: Tue, 18 Jan 2022 05:22:29 +0000 Subject: [PATCH 8/8] 100% test coverage --- src/functionrl/algorithms/reinforce.py | 8 +- src/functionrl/algorithms/tabular_q.py | 2 +- src/functionrl/algorithms/value_iteration.py | 4 +- src/functionrl/envs.py | 8 +- src/functionrl/train.py | 118 +++++++++---------- src/functionrl/utils.py | 2 +- tests/test_display.py | 21 ++++ tests/test_experiences.py | 23 +++- tests/test_models.py | 12 ++ tests/test_policies.py | 10 +- tests/test_reinforce.py | 30 +++++ tests/test_tabular_q.py | 18 +++ tests/test_utils.py | 9 +- tests/test_value_iteration.py | 9 ++ 14 files changed, 201 insertions(+), 73 deletions(-) create mode 100644 tests/test_display.py create mode 100644 tests/test_models.py create mode 100644 tests/test_reinforce.py create mode 100644 tests/test_tabular_q.py create mode 100644 tests/test_value_iteration.py diff --git a/src/functionrl/algorithms/reinforce.py b/src/functionrl/algorithms/reinforce.py index bde6d4d..2caeeb2 100644 --- a/src/functionrl/algorithms/reinforce.py +++ b/src/functionrl/algorithms/reinforce.py @@ -57,13 +57,15 @@ def reinforce( losses.append(loss.item()) if i % log_interval == 0: - policy = make_greedy_policy_from_model(pi, n_states) - mean_return = evaluate_policy(make_env, policy, eval_episodes) + eval_policy = make_greedy_policy_from_model(pi, n_states) + mean_return = evaluate_policy(make_env, eval_policy, eval_episodes) mean_loss = np.array(losses[-log_interval:]).mean() print(f"{i:5d} mean_return: {mean_return:.3f} - loss: {mean_loss:8.4f}") + return policy -if __name__ == "__main__": + +if __name__ == "__main__": # pragma: no cover reinforce( make_frozen_lake, gamma=0.99, diff --git a/src/functionrl/algorithms/tabular_q.py b/src/functionrl/algorithms/tabular_q.py index aa4cc30..6e33d90 100644 --- a/src/functionrl/algorithms/tabular_q.py +++ b/src/functionrl/algorithms/tabular_q.py @@ -52,7 +52,7 @@ def tabular_q( return q -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover q = tabular_q( make_frozen_lake, gamma=1, diff --git a/src/functionrl/algorithms/value_iteration.py 
b/src/functionrl/algorithms/value_iteration.py index ce30811..cd4dac9 100644 --- a/src/functionrl/algorithms/value_iteration.py +++ b/src/functionrl/algorithms/value_iteration.py @@ -12,7 +12,7 @@ def value_iteration(env, gamma=0.99, theta=1e-10): n_actions = env.action_space.n transitions = env.P v = [0.0 for _ in range(n_states)] - for step in count(1): + for step in count(1): # pragma: no branch q = [ [ sum( @@ -34,7 +34,7 @@ def value_iteration(env, gamma=0.99, theta=1e-10): return pi, {"steps": step, "q": q, "v": v} -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover env = gym.make("FrozenLake-v1") pi, info = value_iteration(env) print_pi(pi) diff --git a/src/functionrl/envs.py b/src/functionrl/envs.py index b093cd8..9126de0 100644 --- a/src/functionrl/envs.py +++ b/src/functionrl/envs.py @@ -7,7 +7,7 @@ def make_frozen_lake(): return env -def make_frozen_lake_not_slippery(): - env = gym.make("FrozenLake-v1", is_slippery=False) - env.seed(0) - return env +# def make_frozen_lake_not_slippery(): +# env = gym.make("FrozenLake-v1", is_slippery=False) +# env.seed(0) +# return env diff --git a/src/functionrl/train.py b/src/functionrl/train.py index ced86ab..fd70212 100644 --- a/src/functionrl/train.py +++ b/src/functionrl/train.py @@ -1,78 +1,78 @@ -import gym -import numpy as np -import torch -from torch import nn -from algorithms.value_iteration import value_iteration -from display import print_v, print_pi +# import gym +# import numpy as np +# import torch +# from torch import nn +# from algorithms.value_iteration import value_iteration +# from display import print_v, print_pi -np.set_printoptions(suppress=True, precision=4) +# np.set_printoptions(suppress=True, precision=4) -torch.manual_seed(0) +# torch.manual_seed(0) -NUM_EPOCHS = 10000 -LEARNING_RATE = 0.1 +# NUM_EPOCHS = 10000 +# LEARNING_RATE = 0.1 -# HIDDEN_DIM = 8 -# HIDDEN_DIM = 8 -HIDDEN_DIM = 16 +# # HIDDEN_DIM = 8 +# # HIDDEN_DIM = 8 +# HIDDEN_DIM = 16 -class QNetwork(nn.Module): - def __init__(self, in_dim, out_dim): - super().__init__() - self.layers = nn.Sequential( - nn.Linear(in_dim, out_dim, bias=False), - # nn.Linear(in_dim, HIDDEN_DIM), - # nn.ReLU(), - # nn.Linear(HIDDEN_DIM, HIDDEN_DIM), - # nn.ReLU(), - # nn.Linear(HIDDEN_DIM, out_dim), - ) +# class QNetwork(nn.Module): +# def __init__(self, in_dim, out_dim): +# super().__init__() +# self.layers = nn.Sequential( +# nn.Linear(in_dim, out_dim, bias=False), +# # nn.Linear(in_dim, HIDDEN_DIM), +# # nn.ReLU(), +# # nn.Linear(HIDDEN_DIM, HIDDEN_DIM), +# # nn.ReLU(), +# # nn.Linear(HIDDEN_DIM, out_dim), +# ) - def forward(self, x): - return self.layers(x) +# def forward(self, x): +# return self.layers(x) -if __name__ == "__main__": - env = gym.make("FrozenLake-v1") +# if __name__ == "__main__": +# env = gym.make("FrozenLake-v1") - pi, info = value_iteration(env, gamma=1) - q = np.array(info["q"]) - n_states, n_actions = q.shape - print(q) +# pi, info = value_iteration(env, gamma=1) +# q = np.array(info["q"]) +# n_states, n_actions = q.shape +# print(q) - x = nn.functional.one_hot(torch.arange(n_states)).float() - y = torch.Tensor(q) - ds = torch.utils.data.TensorDataset(x, y) - dl = torch.utils.data.DataLoader(ds, batch_size=n_states) +# x = nn.functional.one_hot(torch.arange(n_states)).float() +# y = torch.Tensor(q) +# ds = torch.utils.data.TensorDataset(x, y) +# dl = torch.utils.data.DataLoader(ds, batch_size=n_states) - net = QNetwork(16, 4) +# net = QNetwork(16, 4) - print(f"{sum(p.numel() for p in net.parameters())} params") +# 
print(f"{sum(p.numel() for p in net.parameters())} params") - opt = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE) - loss_fn = nn.MSELoss() +# opt = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE) +# loss_fn = nn.MSELoss() - for epoch in range(1, NUM_EPOCHS + 1): - for batch in dl: - xb, yb = batch - opt.zero_grad() - yh = net(x) - loss = loss_fn(yh, y) - loss.backward() - opt.step() +# for epoch in range(1, NUM_EPOCHS + 1): +# for batch in dl: +# xb, yb = batch +# opt.zero_grad() +# yh = net(x) +# loss = loss_fn(yh, y) +# loss.backward() +# opt.step() - if epoch % 100 == 0: - print(f"{epoch:4d}: {loss.item():.8f}") +# if epoch % 100 == 0: +# print(f"{epoch:4d}: {loss.item():.8f}") - if loss.item() < 1e-7: - print(f"Stopping early after {epoch} epochs") - break +# if loss.item() < 1e-7: +# print(f"Stopping early after {epoch} epochs") +# break - # yh = net(x) - # q = yh.detach().numpy() - # v = np.max(q, axis=1).reshape(4, 4) - # print(v) +# # yh = net(x) +# # q = yh.detach().numpy() +# # v = np.max(q, axis=1).reshape(4, 4) +# # print(v) - print_v(info["v"]) - print_pi(pi) +# print_v(info["v"]) +# print_pi(pi) diff --git a/src/functionrl/utils.py b/src/functionrl/utils.py index c887aae..2aa8be2 100644 --- a/src/functionrl/utils.py +++ b/src/functionrl/utils.py @@ -8,7 +8,7 @@ def linear_decay(start, end, decay_steps): def decay_generator(decay_fn): - for i in count(): + for i in count(): # pragma: no branch yield decay_fn(i) diff --git a/tests/test_display.py b/tests/test_display.py new file mode 100644 index 0000000..8c2a547 --- /dev/null +++ b/tests/test_display.py @@ -0,0 +1,21 @@ +from functionrl.display import print_grid, print_pi, print_v + + +def test_print_grid(capfd): + print_grid([1, 2, 3, 4]) + out = capfd.readouterr()[0] + assert out == "1 2\n3 4\n\n" + + +def test_v(capfd): + v = [1, 2, 3, 4] + print_v(v) + out = capfd.readouterr()[0] + assert out == "1.0000 2.0000\n3.0000 4.0000\n\n" + + +def test_print_pi(capfd): + pi = [0, 1, 3, 2] + print_pi(pi) + out = capfd.readouterr()[0] + assert out == "← ↓\n↑ →\n\n" diff --git a/tests/test_experiences.py b/tests/test_experiences.py index 45bc726..7c444a8 100644 --- a/tests/test_experiences.py +++ b/tests/test_experiences.py @@ -1,7 +1,13 @@ import pytest from functionrl.envs import make_frozen_lake -from functionrl.experiences import gen_episode, gen_episodes, gen_experiences +from functionrl.experiences import ( + Experience, + calc_episode_return, + gen_episode, + gen_episodes, + gen_experiences, +) from functionrl.policies import make_random_policy @@ -25,8 +31,23 @@ def test_episode_gen(env, policy): episodes = list(gen_episodes(env, policy, n=4)) assert [len(episode) for episode in episodes] == [5, 2, 2, 3] + def info_policy(state): + return 0, {} + + episodes = list(gen_episodes(env, info_policy, n=4)) + assert [len(episode) for episode in episodes] == [8, 24, 8, 10] + def test_generate_experiences(env, policy): experiences = list(gen_experiences(env, policy, n=10)) assert len(experiences) == 10 assert len([e for e in experiences if e.is_done]) == 3 + + +def test_calc_episode_return(): + episode = [ + Experience(0, 0, 1, 0, False, {}), + Experience(0, 0, 2, 0, False, {}), + Experience(0, 0, 3, 0, False, {}), + ] + assert calc_episode_return(episode) == 6 diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..a991b07 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,12 @@ +import torch +from functionrl.models import LinearNet + + +def test_linear_net(): + net = LinearNet(2, 4) + 
assert net(0).shape == (4,) + assert net([0]).shape == (1, 4) + + net = LinearNet(2, 4, one_hot=False) + assert net(torch.tensor([0, 1], dtype=torch.float)).shape == (4,) + assert net(torch.tensor([[0, 1]], dtype=torch.float)).shape == (1, 4) diff --git a/tests/test_policies.py b/tests/test_policies.py index ec1aba7..cc040e1 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1,7 +1,15 @@ -from functionrl.policies import make_random_policy +import numpy as np +from functionrl.policies import make_random_policy, make_greedy_policy def test_make_random_policy(): policy = make_random_policy(4, seed=1) actions = [policy(None) for _ in range(8)] assert actions == [1, 2, 3, 3, 0, 0, 3, 3] + + +def test_make_greedy_policy(): + q = np.array([[0, 2, 1], [2, 1, 0]]) + policy = make_greedy_policy(q) + assert policy(0) == 1 + assert policy(1) == 0 diff --git a/tests/test_reinforce.py b/tests/test_reinforce.py new file mode 100644 index 0000000..a75016c --- /dev/null +++ b/tests/test_reinforce.py @@ -0,0 +1,30 @@ +from functionrl.algorithms.reinforce import reinforce +from functionrl.envs import make_frozen_lake + + +def test_reinforce(): + policy = reinforce( + make_frozen_lake, + gamma=0.99, + learning_rate=0.01, + n_episodes=2, + log_interval=2, + eval_episodes=2, + seed=1, + ) + + assert policy(0)[0] == 3 + assert policy(1)[0] == 3 + assert policy(2)[0] == 1 + assert policy(3)[0] == 3 + + +def test_reinforce_no_seed(): + reinforce( + make_frozen_lake, + gamma=0.99, + learning_rate=0.01, + n_episodes=2, + log_interval=1, + eval_episodes=2, + ) diff --git a/tests/test_tabular_q.py b/tests/test_tabular_q.py new file mode 100644 index 0000000..0f951f1 --- /dev/null +++ b/tests/test_tabular_q.py @@ -0,0 +1,18 @@ +from functionrl.algorithms.tabular_q import tabular_q +from functionrl.envs import make_frozen_lake + + +def test_tabular_q(): + q = tabular_q( + make_frozen_lake, + gamma=1, + alpha_max=1e-1, + alpha_min=1e-3, + alpha_decay_steps=100_000, + epsilon_max=1.0, + epsilon_min=0.1, + epsilon_decay_steps=100_000, + n_steps=10, + log_interval=5, + seed=0, + ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 1d44558..fe06c1d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,5 @@ -from functionrl.utils import linear_decay, limitable +from itertools import islice +from functionrl.utils import linear_decay, decay_generator, limitable def test_linear_decay(): @@ -11,6 +12,12 @@ def test_linear_decay(): assert decay(100000) == 0.5 +def test_decay_generator(): + decay = linear_decay(1, 0.5, 2) + generator = decay_generator(decay) + assert list(islice(generator, 5)) == [1.0, 0.75, 0.5, 0.5, 0.5] + + def test_limitable(): @limitable def generator(): diff --git a/tests/test_value_iteration.py b/tests/test_value_iteration.py new file mode 100644 index 0000000..fb21114 --- /dev/null +++ b/tests/test_value_iteration.py @@ -0,0 +1,9 @@ +from functionrl.algorithms.value_iteration import value_iteration +from functionrl.envs import make_frozen_lake + + +def test_value_iteration(): + env = make_frozen_lake() + pi, info = value_iteration(env) + assert pi == [0, 3, 3, 3, 0, 0, 0, 0, 3, 1, 0, 0, 0, 2, 1, 0] + assert info["steps"] == 571
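
The core of the `reinforce()` update introduced and refined across these patches is the backwards reward-to-go computation followed by the policy-gradient loss `(-log_probs * rets).sum()`. A minimal, self-contained sketch of just that step — with toy rewards and log-probabilities standing in for the repository's `gen_episodes` output, so the values and episode shown here are illustrative only — might look like this:

```python
import torch


def discounted_returns(rewards, gamma):
    # Reward-to-go: rets[t] = rewards[t] + gamma * rets[t + 1], computed backwards,
    # mirroring the inner loop of reinforce() in the patches above.
    rets = torch.empty(len(rewards))
    future_ret = 0.0
    for t in reversed(range(len(rewards))):
        future_ret = rewards[t] + gamma * future_ret
        rets[t] = future_ret
    return rets


# Toy episode: three steps with a single terminal reward (illustrative values only).
rewards = [0.0, 0.0, 1.0]
log_probs = torch.log(torch.tensor([0.25, 0.5, 0.5], requires_grad=True))

rets = discounted_returns(rewards, gamma=0.99)  # approx tensor([0.9801, 0.9900, 1.0000])
loss = (-log_probs * rets).sum()                # REINFORCE objective for this episode
loss.backward()                                 # gradients would flow into the policy parameters
```

In the actual patches the `log_probs` come from `Categorical.log_prob` on the policy network's logits (via `make_categorical_policy_from_model` after the refactor), and the optimizer step follows `loss.backward()`.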