diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
index 822dec6..3169257 100644
--- a/.devcontainer/Dockerfile
+++ b/.devcontainer/Dockerfile
@@ -16,7 +16,7 @@ USER $USERNAME
 # install poetry for package management
 RUN curl -sSL https://install.python-poetry.org | python3 -
 
-ENV PATH="~/.local/bin:$PATH"
+ENV PATH="/home/$USERNAME/.local/bin:$PATH"
 
 WORKDIR $DIR
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 00bf2dd..ca3e97d 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -1,5 +1,5 @@
 {
   "build": { "dockerfile": "Dockerfile", "context": ".." },
   "runArgs": ["--gpus=all"],
-  "extensions": ["ms-python.python", "tamasfe.even-better-toml"],
+  "extensions": ["ms-python.python", "tamasfe.even-better-toml"]
 }
diff --git a/.gitignore b/.gitignore
index d9005f2..91668e1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -150,3 +150,6 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+# VSCode
+*.code-workspace
diff --git a/poetry.lock b/poetry.lock
index dd5b9d9..af6d224 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -20,6 +20,32 @@ docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"]
 tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"]
 tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"]
 
+[[package]]
+name = "black"
+version = "21.12b0"
+description = "The uncompromising code formatter."
+category = "dev"
+optional = false
+python-versions = ">=3.6.2"
+
+[package.dependencies]
+click = ">=7.1.2"
+mypy-extensions = ">=0.4.3"
+pathspec = ">=0.9.0,<1"
+platformdirs = ">=2"
+tomli = ">=0.2.6,<2.0.0"
+typing-extensions = [
+    {version = ">=3.10.0.0", markers = "python_version < \"3.10\""},
+    {version = "!=3.10.0.1", markers = "python_version >= \"3.10\""},
+]
+
+[package.extras]
+colorama = ["colorama (>=0.4.3)"]
+d = ["aiohttp (>=3.7.4)"]
+jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
+python2 = ["typed-ast (>=1.4.3)"]
+uvloop = ["uvloop (>=0.15.2)"]
+
 [[package]]
 name = "certifi"
 version = "2021.10.8"
@@ -120,6 +146,14 @@ category = "dev"
 optional = false
 python-versions = "*"
 
+[[package]]
+name = "mypy-extensions"
+version = "0.4.3"
+description = "Experimental type system extensions for programs checked with the mypy typechecker."
+category = "dev"
+optional = false
+python-versions = "*"
+
 [[package]]
 name = "numpy"
 version = "1.22.0"
@@ -160,6 +194,26 @@ pytz = ">=2017.3"
 [package.extras]
 test = ["hypothesis (>=3.58)", "pytest (>=6.0)", "pytest-xdist"]
 
+[[package]]
+name = "pathspec"
+version = "0.9.0"
+description = "Utility library for gitignore style pattern matching of file paths."
+category = "dev"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
+
+[[package]]
+name = "platformdirs"
+version = "2.4.1"
+description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
+category = "dev" +optional = false +python-versions = ">=3.7" + +[package.extras] +docs = ["Sphinx (>=4)", "furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx-autodoc-typehints (>=1.12)"] +test = ["appdirs (==1.4.4)", "pytest (>=6)", "pytest-cov (>=2.7)", "pytest-mock (>=3.6)"] + [[package]] name = "pluggy" version = "1.0.0" @@ -296,11 +350,30 @@ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" [[package]] name = "tomli" -version = "2.0.0" +version = "1.2.3" description = "A lil' TOML parser" category = "dev" optional = false -python-versions = ">=3.7" +python-versions = ">=3.6" + +[[package]] +name = "torch" +version = "1.10.1" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +category = "main" +optional = false +python-versions = ">=3.6.2" + +[package.dependencies] +typing-extensions = "*" + +[[package]] +name = "typing-extensions" +version = "4.0.1" +description = "Backported and Experimental Type Hints for Python 3.6+" +category = "main" +optional = false +python-versions = ">=3.6" [[package]] name = "urllib3" @@ -318,7 +391,7 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "3690820d8b14aff176d2c7a2785ed19b97f36801fa9912e86aced8c0aeb5f3c0" +content-hash = "692fd4035dc8cb9de040ee0d1bafcaf5676d8c04d6fa3d94d6703b098fdd3517" [metadata.files] atomicwrites = [ @@ -329,6 +402,10 @@ attrs = [ {file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"}, {file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"}, ] +black = [ + {file = "black-21.12b0-py3-none-any.whl", hash = "sha256:a615e69ae185e08fdd73e4715e260e2479c861b5740057fde6e8b4e3b7dd589f"}, + {file = "black-21.12b0.tar.gz", hash = "sha256:77b80f693a569e2e527958459634f18df9b0ba2625ba4e0c2d5da5be42e6f2b3"}, +] certifi = [ {file = "certifi-2021.10.8-py2.py3-none-any.whl", hash = "sha256:d62a0163eb4c2344ac042ab2bdf75399a71a2d8c7d47eac2e2ee91b9d6339569"}, {file = "certifi-2021.10.8.tar.gz", hash = "sha256:78884e7c1d4b00ce3cea67b44566851c4343c120abd683433ce934a68ea58872"}, @@ -409,6 +486,10 @@ iniconfig = [ {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, ] +mypy-extensions = [ + {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, + {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, +] numpy = [ {file = "numpy-1.22.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3d22662b4b10112c545c91a0741f2436f8ca979ab3d69d03d19322aa970f9695"}, {file = "numpy-1.22.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:11a1f3816ea82eed4178102c56281782690ab5993251fdfd75039aad4d20385f"}, @@ -464,6 +545,14 @@ pandas = [ {file = "pandas-1.3.5-cp39-cp39-win_amd64.whl", hash = "sha256:32e1a26d5ade11b547721a72f9bfc4bd113396947606e00d5b4a5b79b3dcb006"}, {file = "pandas-1.3.5.tar.gz", hash = "sha256:1e4285f5de1012de20ca46b188ccf33521bff61ba5c5ebd78b4fb28e5416a9f1"}, ] +pathspec = [ + {file = "pathspec-0.9.0-py2.py3-none-any.whl", hash = "sha256:7d15c4ddb0b5c802d161efc417ec1a2558ea2653c2e8ad9c19098201dc1c993a"}, + {file = "pathspec-0.9.0.tar.gz", hash = 
"sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"}, +] +platformdirs = [ + {file = "platformdirs-2.4.1-py3-none-any.whl", hash = "sha256:1d7385c7db91728b83efd0ca99a5afb296cab9d0ed8313a45ed8ba17967ecfca"}, + {file = "platformdirs-2.4.1.tar.gz", hash = "sha256:440633ddfebcc36264232365d7840a970e75e1018d15b4327d11f91909045fda"}, +] pluggy = [ {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, @@ -509,8 +598,32 @@ toml = [ {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, ] tomli = [ - {file = "tomli-2.0.0-py3-none-any.whl", hash = "sha256:b5bde28da1fed24b9bd1d4d2b8cba62300bfb4ec9a6187a957e8ddb9434c5224"}, - {file = "tomli-2.0.0.tar.gz", hash = "sha256:c292c34f58502a1eb2bbb9f5bbc9a5ebc37bee10ffb8c2d6bbdfa8eb13cc14e1"}, + {file = "tomli-1.2.3-py3-none-any.whl", hash = "sha256:e3069e4be3ead9668e21cb9b074cd948f7b3113fd9c8bba083f48247aab8b11c"}, + {file = "tomli-1.2.3.tar.gz", hash = "sha256:05b6166bff487dc068d322585c7ea4ef78deed501cc124060e0f238e89a9231f"}, +] +torch = [ + {file = "torch-1.10.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:adbb5f292e260e39715d67478823e03e3001db1af5b02c18caa34549dccb421e"}, + {file = "torch-1.10.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:ac8cae04458cc47555fa07a760496c2fdf687223bcc13df5fed56ea3aead37f5"}, + {file = "torch-1.10.1-cp36-cp36m-win_amd64.whl", hash = "sha256:40508d67288c46ff1fad301fa6e996e0e936a733f2401475fc92c21dc3ef702d"}, + {file = "torch-1.10.1-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:8b47bd113c6cbd9a49669aaaa233ad5f25852d6ca3e640f9c71c808e65a1fdf4"}, + {file = "torch-1.10.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:50360868ad3f039cf99f0250300dbec51bf686a7b84dc6bbdb8dff4b1171c0f0"}, + {file = "torch-1.10.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:e3d2154722189ed74747a494dce9588978dd55e43ca24c5bd307fb52620b232b"}, + {file = "torch-1.10.1-cp37-cp37m-win_amd64.whl", hash = "sha256:d9c495bcd5f00becff5b051b5e4be86b7eaa0433cd0fe57f77c02bc1b93ab5b1"}, + {file = "torch-1.10.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:6b327d7b4eb2461b16d46763d46df71e597235ccc428650538a2735a0898270d"}, + {file = "torch-1.10.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:1c6c56178e5dacf7602ad00dc79c263d6c41c0f76261e9641e6bd2679678ceb3"}, + {file = "torch-1.10.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:2ffa2db4ccb6466c59e3f95b7a582d47ae721e476468f4ffbcaa2832e0b92b9b"}, + {file = "torch-1.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:af577602e884c5e40fbd29ec978f052202355da93cd31e0a23251bd7aaff5a99"}, + {file = "torch-1.10.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:725d86e9809073eef868a3ddf4189878ee7af46fac71403834dd0925b3db9b82"}, + {file = "torch-1.10.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:fa197cfe047d0515bef238f42472721406609ebaceff2fd4e17f2ad4692ee51c"}, + {file = "torch-1.10.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:cca660b27a90dbbc0af06c859260f6b875aef37c0897bd353e5deed085d2c877"}, + {file = "torch-1.10.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:01f4ffdafbfbd7d106fb4e487feee2cf29cced9903df8cb0444b0e308f9c5e92"}, + {file = "torch-1.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:607eccb7d539a11877cd02d95f4b164b7941fcf538ac7ff087bfed19e3644283"}, + {file = 
"torch-1.10.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:26b6dfbe21e247e67c615bfab0017ec391ed1517f88bbeea6228a49edd24cd88"}, + {file = "torch-1.10.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:5644280d88c5b6de27eacc0d911f968aad41a4bab297af4df5e571bc0927d3e4"}, +] +typing-extensions = [ + {file = "typing_extensions-4.0.1-py3-none-any.whl", hash = "sha256:7f001e5ac290a0c0401508864c7ec868be4e701886d5b573a9528ed3973d9d3b"}, + {file = "typing_extensions-4.0.1.tar.gz", hash = "sha256:4ca091dea149f945ec56afb48dae714f21e8692ef22a395223bcd328961b6a0e"}, ] urllib3 = [ {file = "urllib3-1.26.8-py2.py3-none-any.whl", hash = "sha256:000ca7f471a233c2251c6c7023ee85305721bfdf18621ebff4fd17a8653427ed"}, diff --git a/pyproject.toml b/pyproject.toml index 7620489..53338a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,12 +16,14 @@ requests = "^2.27.1" numpy = "^1.22.0" gym = "^0.21.0" pandas = "^1.3.5" +torch = "^1.10.1" [tool.poetry.dev-dependencies] coverage = {extras = ["toml"], version = "^6.2"} pytest = "^6.2.5" pytest-cov = "^3.0.0" pytest-mock = "^3.6.1" +black = "^21.12b0" [tool.poetry.scripts] diff --git a/src/functionrl/algorithms/reinforce.py b/src/functionrl/algorithms/reinforce.py new file mode 100644 index 0000000..2caeeb2 --- /dev/null +++ b/src/functionrl/algorithms/reinforce.py @@ -0,0 +1,76 @@ +from typing import Optional + +import numpy as np +import torch +from functionrl.models import LinearNet +from functionrl.policies import ( + evaluate_policy, + make_categorical_policy_from_model, + make_greedy_policy_from_model, +) +from torch import optim + +from ..envs import make_frozen_lake +from ..experiences import gen_episodes + + +def reinforce( + make_env, + gamma: float = 1.0, + learning_rate: float = 1e-3, + n_episodes: int = 10000, + log_interval: int = 100, + eval_episodes: int = 1000, + seed: Optional[int] = None, +): + if seed is not None: + torch.manual_seed(seed) + + env = make_env() + n_states = env.observation_space.n + n_actions = env.action_space.n + + pi = LinearNet(n_states, n_actions) + print(pi) + + optimizer = optim.Adam(pi.parameters(), lr=learning_rate) + policy = make_categorical_policy_from_model(pi) + + losses = [] + for i, episode in enumerate(gen_episodes(env, policy, n=n_episodes), start=1): + T = len(episode) + rewards = [exp.reward for exp in episode] + log_probs = [exp.policy_info["log_prob"] for exp in episode] + rets = np.empty(T, dtype=np.float32) + future_ret = 0.0 + for t in reversed(range(T)): + future_ret = rewards[t] + gamma * future_ret + rets[t] = future_ret + rets = torch.tensor(rets) + # rets.sub_(rets.mean()) + log_probs = torch.stack(log_probs) + loss = (-log_probs * rets).sum() + optimizer.zero_grad() + loss.backward() + optimizer.step() + + losses.append(loss.item()) + + if i % log_interval == 0: + eval_policy = make_greedy_policy_from_model(pi, n_states) + mean_return = evaluate_policy(make_env, eval_policy, eval_episodes) + mean_loss = np.array(losses[-log_interval:]).mean() + print(f"{i:5d} mean_return: {mean_return:.3f} - loss: {mean_loss:8.4f}") + + return policy + + +if __name__ == "__main__": # pragma: no cover + reinforce( + make_frozen_lake, + gamma=0.99, + learning_rate=0.01, + n_episodes=10000, + seed=1, + eval_episodes=1000, + ) diff --git a/src/functionrl/algorithms/tabular_q.py b/src/functionrl/algorithms/tabular_q.py index a9177b4..6e33d90 100644 --- a/src/functionrl/algorithms/tabular_q.py +++ b/src/functionrl/algorithms/tabular_q.py @@ -1,7 +1,8 @@ +from typing import Optional import numpy as np from ..utils 
diff --git a/src/functionrl/algorithms/tabular_q.py b/src/functionrl/algorithms/tabular_q.py
index a9177b4..6e33d90 100644
--- a/src/functionrl/algorithms/tabular_q.py
+++ b/src/functionrl/algorithms/tabular_q.py
@@ -1,7 +1,8 @@
+from typing import Optional
 import numpy as np
 from ..utils import linear_decay
-from ..policies import make_epsilon_greedy_policy, make_greedy_policy
-from ..experiences import generate_experiences, generate_episodes
+from ..policies import evaluate_policy, make_epsilon_greedy_policy, make_greedy_policy
+from ..experiences import gen_experiences
 from ..envs import make_frozen_lake
 from ..display import print_pi, print_v
 
@@ -18,48 +19,40 @@ def tabular_q(
     n_steps: int = 5000,
     log_interval: int = 1000,
     eval_episodes: int = 1000,
+    seed: Optional[int] = None,
 ):
-    env_train = make_env()
-    env_eval = make_env()
+    env = make_env()
 
-    n_states = env_train.observation_space.n
-    n_actions = env_train.action_space.n
+    n_states = env.observation_space.n
+    n_actions = env.action_space.n
+
+    q = np.zeros((n_states, n_actions))
 
     alpha_decay = linear_decay(alpha_max, alpha_min, alpha_decay_steps)
     epsilon_decay = linear_decay(epsilon_max, epsilon_min, epsilon_decay_steps)
 
-    q = np.zeros((n_states, n_actions))
-
-    # TODO: pass decay into make_eps
-    policy_train = make_epsilon_greedy_policy(
-        q, epsilon_max, epsilon_min, epsilon_decay_steps
-    )
+    policy_train = make_epsilon_greedy_policy(q, epsilon_decay, seed=seed)
     policy_eval = make_greedy_policy(q)
 
-    for step, exp in enumerate(
-        generate_experiences(env_train, policy_train, n=n_steps)
-    ):
-
-        td_target = (
-            exp.reward + gamma * float(not exp.is_done) * q[exp.next_state].max()
-        )
-        td_error = td_target - q[exp.state, exp.action]
+    for i, exp in enumerate(gen_experiences(env, policy_train, n=n_steps), start=1):
+        state, action, reward, next_state, is_done, policy_info = exp
+        td_target = reward + gamma * float(not is_done) * q[next_state].max()
+        td_error = td_target - q[state, action]
 
-        alpha = alpha_decay(step)
-        q[exp.state, exp.action] += alpha * td_error
+        alpha = alpha_decay(i)
+        q[state, action] += alpha * td_error
 
-        if (step + 1) % log_interval == 0:
-            episodes = list(generate_episodes(env_eval, policy_eval, n=eval_episodes))
-            returns = [sum(e.reward for e in episode) for episode in episodes]
-            mean_return = np.mean(returns)
-            print(
-                f"{step+1:5d}: {mean_return:.3f}, eps: {epsilon_decay(step):.3f}, alpha: {alpha:.6f}"
-            )
+        if i % log_interval == 0:
+            epsilon = policy_info["epsilon"]
+            mean_return = evaluate_policy(make_env, policy_eval, eval_episodes)
+            print(f"{i:5d}: {mean_return:.3f}, eps: {epsilon:.3f}, alpha: {alpha:.6f}")
 
+    pi = np.argmax(q, axis=1)
+    print_pi(pi)
     return q
 
 
-if __name__ == "__main__":
+if __name__ == "__main__":  # pragma: no cover
     q = tabular_q(
         make_frozen_lake,
         gamma=1,
@@ -71,8 +64,7 @@ def tabular_q(
         epsilon_decay_steps=100_000,
         n_steps=100_000,
         log_interval=10_000,
+        seed=0,
     )
-    pi = np.argmax(q, axis=1)
-    print_pi(pi)
     v = np.max(q, axis=1)
     print_v(v)
diff --git a/src/functionrl/algorithms/value_iteration.py b/src/functionrl/algorithms/value_iteration.py
index ce30811..cd4dac9 100644
--- a/src/functionrl/algorithms/value_iteration.py
+++ b/src/functionrl/algorithms/value_iteration.py
@@ -12,7 +12,7 @@ def value_iteration(env, gamma=0.99, theta=1e-10):
     n_actions = env.action_space.n
     transitions = env.P
     v = [0.0 for _ in range(n_states)]
-    for step in count(1):
+    for step in count(1):  # pragma: no branch
         q = [
             [
                 sum(
@@ -34,7 +34,7 @@ def value_iteration(env, gamma=0.99, theta=1e-10):
     return pi, {"steps": step, "q": q, "v": v}
 
 
-if __name__ == "__main__":
+if __name__ == "__main__":  # pragma: no cover
     env = gym.make("FrozenLake-v1")
     pi, info = value_iteration(env)
     print_pi(pi)
diff --git a/src/functionrl/envs.py b/src/functionrl/envs.py
index b093cd8..9126de0 100644
--- a/src/functionrl/envs.py
+++ b/src/functionrl/envs.py
@@ -7,7 +7,7 @@ def make_frozen_lake():
     return env
 
 
-def make_frozen_lake_not_slippery():
-    env = gym.make("FrozenLake-v1", is_slippery=False)
-    env.seed(0)
-    return env
+# def make_frozen_lake_not_slippery():
+#     env = gym.make("FrozenLake-v1", is_slippery=False)
+#     env.seed(0)
+#     return env
diff --git a/src/functionrl/experiences.py b/src/functionrl/experiences.py
index 206f740..3116270 100644
--- a/src/functionrl/experiences.py
+++ b/src/functionrl/experiences.py
@@ -2,27 +2,37 @@ from collections import namedtuple
 
 from .utils import limitable
 
-Experience = namedtuple("Experience", "state action reward next_state is_done")
+Experience = namedtuple(
+    "Experience",
+    ["state", "action", "reward", "next_state", "is_done", "policy_info"],
+)
 
 
-def generate_episode(env, policy):
+def gen_episode(env, policy):
     next_state = env.reset()
-    for step in count():
+    while True:
         state = next_state
-        action = policy(state, step)
+        action = policy(state)
+        policy_info = None
+        if isinstance(action, tuple):
+            action, policy_info = action
         next_state, reward, is_done, _ = env.step(action)
-        yield Experience(state, action, reward, next_state, is_done)
+        yield Experience(state, action, reward, next_state, is_done, policy_info)
         if is_done:
            break
 
 
 @limitable
-def generate_episodes(env, policy):
+def gen_episodes(env, policy):
     while True:
-        yield list(generate_episode(env, policy))
+        yield list(gen_episode(env, policy))
 
 
 @limitable
-def generate_experiences(env, policy):
+def gen_experiences(env, policy):
     while True:
-        yield from generate_episode(env, policy)
+        yield from gen_episode(env, policy)
+
+
+def calc_episode_return(episode):
+    return sum(experience.reward for experience in episode)
diff --git a/src/functionrl/models.py b/src/functionrl/models.py
new file mode 100644
index 0000000..89388a2
--- /dev/null
+++ b/src/functionrl/models.py
@@ -0,0 +1,22 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+def to_tensor(x):
+    return torch.as_tensor(x)
+
+
+class LinearNet(nn.Module):
+    def __init__(self, in_dim, out_dim, one_hot=True):
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.one_hot = one_hot
+        self.layer = nn.Linear(in_dim, out_dim, bias=False)
+
+    def forward(self, state):
+        state = to_tensor(state)
+        if self.one_hot:
+            state = F.one_hot(state, num_classes=self.in_dim).float()
+        return self.layer(state)
diff --git a/src/functionrl/policies.py b/src/functionrl/policies.py
index 07cc4d6..3bc01d1 100644
--- a/src/functionrl/policies.py
+++ b/src/functionrl/policies.py
@@ -1,49 +1,64 @@
 import numpy as np
-from .utils import linear_decay
+import torch
+from functionrl.experiences import calc_episode_return, gen_episodes
+from .utils import decay_generator
+from torch.distributions import Categorical
 
 
 def make_random_policy(n_actions, seed=None):
     rng = np.random.default_rng(seed=seed)
 
-    def _policy(state, step):
+    def _policy(state):
         return rng.integers(n_actions)
 
     return _policy
 
 
-def make_greedy_action_selector(q):
-    return lambda state: q[state].argmax()
+def make_greedy_policy(q):
+    def _policy(state):
+        return q[state].argmax()
 
+    return _policy
 
-def make_random_action_selector(n_actions):
-    return lambda state: np.random.randint(n_actions)
 
+def make_greedy_policy_from_model(model, n_states):
+    with torch.no_grad():
+        q = model(torch.arange(n_states)).numpy()
+    return make_greedy_policy(q)
 
-def make_epsilon_greedy_policy(q, eps_start, eps_end, eps_decay_steps):
-    n_states, n_actions = q.shape
-    greedy_selector = make_greedy_action_selector(q)
-    random_selector = make_random_action_selector(n_actions)
-    eps_decay = linear_decay(eps_start, eps_end, eps_decay_steps)
-    step = 0
+def make_epsilon_greedy_policy(q, epsilon_decay_fn, seed=None):
+    rng = np.random.default_rng(seed=seed)
 
-    def _policy(state, _):
-        nonlocal step
-        epsilon = eps_decay(step)
-        step += 1
-        is_exploring = np.random.random() < epsilon
-        selector = random_selector if is_exploring else greedy_selector
-        action = selector(state)
-        return action
+    n_actions = q.shape[1]
+    greedy_policy = make_greedy_policy(q)
+    random_policy = make_random_policy(n_actions, seed=seed)
+    epsilon_generator = iter(decay_generator(epsilon_decay_fn))
 
-    return _policy
+    def _policy(state):
+        epsilon = next(epsilon_generator)
+        is_exploring = rng.random() < epsilon
+        policy = random_policy if is_exploring else greedy_policy
+        action = policy(state)
+        return action, {"epsilon": epsilon, "is_exploring": is_exploring}
 
+    return _policy
 
-def make_greedy_policy(q):
-    greedy_selector = make_greedy_action_selector(q)
 
-    def _policy(state, step):
-        action = greedy_selector(state)
-        return action
+def make_categorical_policy_from_model(model):
+    def _policy(state):
+        logits = model(state)
+        prob_dist = Categorical(logits=logits)
+        action = prob_dist.sample()
+        log_prob = prob_dist.log_prob(action)
+        return action.item(), {"log_prob": log_prob}
 
     return _policy
+
+
+def evaluate_policy(make_env, policy, n_episodes):
+    env = make_env()
+    episodes = gen_episodes(env, policy, n=n_episodes)
+    returns = [calc_episode_return(episode) for episode in episodes]
+    mean_return = np.mean(returns)
+    return mean_return
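With the changes above, a policy takes only the state and may return either a bare action or an (action, info) tuple; gen_episode() unpacks the tuple and stores the dict on Experience.policy_info. A usage sketch of how the pieces compose (not part of the patch; assumes the functionrl package from this repo is importable):

import numpy as np

from functionrl.envs import make_frozen_lake
from functionrl.experiences import gen_experiences
from functionrl.policies import make_epsilon_greedy_policy
from functionrl.utils import linear_decay

env = make_frozen_lake()
q = np.zeros((env.observation_space.n, env.action_space.n))

# epsilon now comes from a decay generator owned by the policy closure and is
# reported back through Experience.policy_info on every step
policy = make_epsilon_greedy_policy(q, linear_decay(1.0, 0.1, 1000), seed=0)

for exp in gen_experiences(env, policy, n=5):
    print(exp.state, exp.action, exp.reward, exp.policy_info["epsilon"])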
diff --git a/src/functionrl/train.py b/src/functionrl/train.py
index ced86ab..fd70212 100644
--- a/src/functionrl/train.py
+++ b/src/functionrl/train.py
@@ -1,78 +1,78 @@
-import gym
-import numpy as np
-import torch
-from torch import nn
-from algorithms.value_iteration import value_iteration
-from display import print_v, print_pi
+# import gym
+# import numpy as np
+# import torch
+# from torch import nn
+# from algorithms.value_iteration import value_iteration
+# from display import print_v, print_pi
 
-np.set_printoptions(suppress=True, precision=4)
+# np.set_printoptions(suppress=True, precision=4)
 
-torch.manual_seed(0)
+# torch.manual_seed(0)
 
-NUM_EPOCHS = 10000
-LEARNING_RATE = 0.1
+# NUM_EPOCHS = 10000
+# LEARNING_RATE = 0.1
 
-# HIDDEN_DIM = 8
-# HIDDEN_DIM = 8
-HIDDEN_DIM = 16
+# # HIDDEN_DIM = 8
+# # HIDDEN_DIM = 8
+# HIDDEN_DIM = 16
 
 
-class QNetwork(nn.Module):
-    def __init__(self, in_dim, out_dim):
-        super().__init__()
-        self.layers = nn.Sequential(
-            nn.Linear(in_dim, out_dim, bias=False),
-            # nn.Linear(in_dim, HIDDEN_DIM),
-            # nn.ReLU(),
-            # nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
-            # nn.ReLU(),
-            # nn.Linear(HIDDEN_DIM, out_dim),
-        )
+# class QNetwork(nn.Module):
+#     def __init__(self, in_dim, out_dim):
+#         super().__init__()
+#         self.layers = nn.Sequential(
+#             nn.Linear(in_dim, out_dim, bias=False),
+#             # nn.Linear(in_dim, HIDDEN_DIM),
+#             # nn.ReLU(),
+#             # nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
+#             # nn.ReLU(),
+#             # nn.Linear(HIDDEN_DIM, out_dim),
+#         )
 
-    def forward(self, x):
-        return self.layers(x)
+#     def forward(self, x):
+#         return self.layers(x)
 
 
-if __name__ == "__main__":
-    env = gym.make("FrozenLake-v1")
+# if __name__ == "__main__":
+#     env = gym.make("FrozenLake-v1")
 
-    pi, info = value_iteration(env, gamma=1)
-    q = np.array(info["q"])
-    n_states, n_actions = q.shape
-    print(q)
+#     pi, info = value_iteration(env, gamma=1)
+#     q = np.array(info["q"])
+#     n_states, n_actions = q.shape
+#     print(q)
 
-    x = nn.functional.one_hot(torch.arange(n_states)).float()
-    y = torch.Tensor(q)
-    ds = torch.utils.data.TensorDataset(x, y)
-    dl = torch.utils.data.DataLoader(ds, batch_size=n_states)
+#     x = nn.functional.one_hot(torch.arange(n_states)).float()
+#     y = torch.Tensor(q)
+#     ds = torch.utils.data.TensorDataset(x, y)
+#     dl = torch.utils.data.DataLoader(ds, batch_size=n_states)
 
-    net = QNetwork(16, 4)
+#     net = QNetwork(16, 4)
 
-    print(f"{sum(p.numel() for p in net.parameters())} params")
+#     print(f"{sum(p.numel() for p in net.parameters())} params")
 
-    opt = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE)
-    loss_fn = nn.MSELoss()
+#     opt = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE)
+#     loss_fn = nn.MSELoss()
 
-    for epoch in range(1, NUM_EPOCHS + 1):
-        for batch in dl:
-            xb, yb = batch
-            opt.zero_grad()
-            yh = net(x)
-            loss = loss_fn(yh, y)
-            loss.backward()
-            opt.step()
+#     for epoch in range(1, NUM_EPOCHS + 1):
+#         for batch in dl:
+#             xb, yb = batch
+#             opt.zero_grad()
+#             yh = net(x)
+#             loss = loss_fn(yh, y)
+#             loss.backward()
+#             opt.step()
 
-        if epoch % 100 == 0:
-            print(f"{epoch:4d}: {loss.item():.8f}")
+#         if epoch % 100 == 0:
+#             print(f"{epoch:4d}: {loss.item():.8f}")
 
-        if loss.item() < 1e-7:
-            print(f"Stopping early after {epoch} epochs")
-            break
+#         if loss.item() < 1e-7:
+#             print(f"Stopping early after {epoch} epochs")
+#             break
 
-    # yh = net(x)
-    # q = yh.detach().numpy()
-    # v = np.max(q, axis=1).reshape(4, 4)
-    # print(v)
+#     # yh = net(x)
+#     # q = yh.detach().numpy()
+#     # v = np.max(q, axis=1).reshape(4, 4)
+#     # print(v)
 
-    print_v(info["v"])
-    print_pi(pi)
+#     print_v(info["v"])
+#     print_pi(pi)
diff --git a/src/functionrl/utils.py b/src/functionrl/utils.py
index ebe0cd2..2aa8be2 100644
--- a/src/functionrl/utils.py
+++ b/src/functionrl/utils.py
@@ -1,5 +1,5 @@
 from functools import wraps
-from itertools import islice
+from itertools import count, islice
 
 
 def linear_decay(start, end, decay_steps):
@@ -7,6 +7,11 @@ def linear_decay(start, end, decay_steps):
     return lambda step: start + delta * step if step < decay_steps else end
 
 
+def decay_generator(decay_fn):
+    for i in count():  # pragma: no branch
+        yield decay_fn(i)
+
+
 def limitable(func):
     @wraps(func)
     def wrapper(*args, n=None, **kwargs):
diff --git a/tests/test_display.py b/tests/test_display.py
new file mode 100644
index 0000000..8c2a547
--- /dev/null
+++ b/tests/test_display.py
@@ -0,0 +1,21 @@
+from functionrl.display import print_grid, print_pi, print_v
+
+
+def test_print_grid(capfd):
+    print_grid([1, 2, 3, 4])
+    out = capfd.readouterr()[0]
+    assert out == "1 2\n3 4\n\n"
+
+
+def test_v(capfd):
+    v = [1, 2, 3, 4]
+    print_v(v)
+    out = capfd.readouterr()[0]
+    assert out == "1.0000 2.0000\n3.0000 4.0000\n\n"
+
+
+def test_print_pi(capfd):
+    pi = [0, 1, 3, 2]
+    print_pi(pi)
+    out = capfd.readouterr()[0]
+    assert out == "← ↓\n↑ →\n\n"
diff --git a/tests/test_experiences.py b/tests/test_experiences.py
index ee71e56..7c444a8 100644
--- a/tests/test_experiences.py
+++ b/tests/test_experiences.py
@@ -1,7 +1,13 @@
 import pytest
 
 from functionrl.envs import make_frozen_lake
-from functionrl.experiences import generate_episode, generate_episodes, generate_experiences
+from functionrl.experiences import (
+    Experience,
+    calc_episode_return,
+    gen_episode,
+    gen_episodes,
+    gen_experiences,
+)
 from functionrl.policies import make_random_policy
 
 
@@ -16,17 +22,32 @@ def policy():
 
 
 def test_generate_episode(env, policy):
-    episode = list(generate_episode(env, policy))
+    episode = list(gen_episode(env, policy))
     assert len(episode) == 5
     assert episode[-1].is_done
 
 
 def test_episode_gen(env, policy):
-    episodes = list(generate_episodes(env, policy, n=4))
+    episodes = list(gen_episodes(env, policy, n=4))
     assert [len(episode) for episode in episodes] == [5, 2, 2, 3]
 
+    def info_policy(state):
+        return 0, {}
+
+    episodes = list(gen_episodes(env, info_policy, n=4))
+    assert [len(episode) for episode in episodes] == [8, 24, 8, 10]
+
 
 def test_generate_experiences(env, policy):
-    experiences = list(generate_experiences(env, policy, n=10))
+    experiences = list(gen_experiences(env, policy, n=10))
     assert len(experiences) == 10
     assert len([e for e in experiences if e.is_done]) == 3
+
+
+def test_calc_episode_return():
+    episode = [
+        Experience(0, 0, 1, 0, False, {}),
+        Experience(0, 0, 2, 0, False, {}),
+        Experience(0, 0, 3, 0, False, {}),
+    ]
+    assert calc_episode_return(episode) == 6
diff --git a/tests/test_models.py b/tests/test_models.py
new file mode 100644
index 0000000..a991b07
--- /dev/null
+++ b/tests/test_models.py
@@ -0,0 +1,12 @@
+import torch
+from functionrl.models import LinearNet
+
+
+def test_linear_net():
+    net = LinearNet(2, 4)
+    assert net(0).shape == (4,)
+    assert net([0]).shape == (1, 4)
+
+    net = LinearNet(2, 4, one_hot=False)
+    assert net(torch.tensor([0, 1], dtype=torch.float)).shape == (4,)
+    assert net(torch.tensor([[0, 1]], dtype=torch.float)).shape == (1, 4)
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 33e1803..cc040e1 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -1,7 +1,15 @@
-from functionrl.policies import make_random_policy
+import numpy as np
+from functionrl.policies import make_random_policy, make_greedy_policy
 
 
 def test_make_random_policy():
     policy = make_random_policy(4, seed=1)
-    actions = [policy(None, step) for step in range(8)]
+    actions = [policy(None) for _ in range(8)]
     assert actions == [1, 2, 3, 3, 0, 0, 3, 3]
+
+
+def test_make_greedy_policy():
+    q = np.array([[0, 2, 1], [2, 1, 0]])
+    policy = make_greedy_policy(q)
+    assert policy(0) == 1
+    assert policy(1) == 0
diff --git a/tests/test_reinforce.py b/tests/test_reinforce.py
new file mode 100644
index 0000000..a75016c
--- /dev/null
+++ b/tests/test_reinforce.py
@@ -0,0 +1,30 @@
+from functionrl.algorithms.reinforce import reinforce
+from functionrl.envs import make_frozen_lake
+
+
+def test_reinforce():
+    policy = reinforce(
+        make_frozen_lake,
+        gamma=0.99,
+        learning_rate=0.01,
+        n_episodes=2,
+        log_interval=2,
+        eval_episodes=2,
+        seed=1,
+    )
+
+    assert policy(0)[0] == 3
+    assert policy(1)[0] == 3
+    assert policy(2)[0] == 1
+    assert policy(3)[0] == 3
+
+
+def test_reinforce_no_seed():
+    reinforce(
+        make_frozen_lake,
+        gamma=0.99,
+        learning_rate=0.01,
+        n_episodes=2,
+        log_interval=1,
+        eval_episodes=2,
+    )
diff --git a/tests/test_tabular_q.py b/tests/test_tabular_q.py
new file mode 100644
index 0000000..0f951f1
--- /dev/null
+++ b/tests/test_tabular_q.py
@@ -0,0 +1,18 @@
+from functionrl.algorithms.tabular_q import tabular_q
+from functionrl.envs import make_frozen_lake
+
+
+def test_tabular_q():
+    q = tabular_q(
+        make_frozen_lake,
+        gamma=1,
+        alpha_max=1e-1,
+        alpha_min=1e-3,
+        alpha_decay_steps=100_000,
+        epsilon_max=1.0,
+        epsilon_min=0.1,
+        epsilon_decay_steps=100_000,
+        n_steps=10,
+        log_interval=5,
+        seed=0,
+    )
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 1d44558..fe06c1d 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,4 +1,5 @@
-from functionrl.utils import linear_decay, limitable
+from itertools import islice
+from functionrl.utils import linear_decay, decay_generator, limitable
 
 
 def test_linear_decay():
@@ -11,6 +12,12 @@ def test_linear_decay():
     assert decay(100000) == 0.5
 
 
+def test_decay_generator():
+    decay = linear_decay(1, 0.5, 2)
+    generator = decay_generator(decay)
+    assert list(islice(generator, 5)) == [1.0, 0.75, 0.5, 0.5, 0.5]
+
+
 def test_limitable():
     @limitable
     def generator():
diff --git a/tests/test_value_iteration.py b/tests/test_value_iteration.py
new file mode 100644
index 0000000..fb21114
--- /dev/null
+++ b/tests/test_value_iteration.py
@@ -0,0 +1,9 @@
+from functionrl.algorithms.value_iteration import value_iteration
+from functionrl.envs import make_frozen_lake
+
+
+def test_value_iteration():
+    env = make_frozen_lake()
+    pi, info = value_iteration(env)
+    assert pi == [0, 3, 3, 3, 0, 0, 0, 0, 3, 1, 0, 0, 0, 2, 1, 0]
+    assert info["steps"] == 571
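End to end, the new REINFORCE entry point can be exercised straight from Python; a quick smoke-run sketch (not part of the patch; the hyperparameters are placeholders, not tuned values):

from functionrl.algorithms.reinforce import reinforce
from functionrl.envs import make_frozen_lake
from functionrl.policies import evaluate_policy

# short training run; reinforce() returns the categorical policy wrapping the trained LinearNet
policy = reinforce(
    make_frozen_lake,
    gamma=0.99,
    learning_rate=0.01,
    n_episodes=500,
    log_interval=100,
    eval_episodes=100,
    seed=1,
)

# policies follow the new (action, info) convention
action, info = policy(0)
print(action, float(info["log_prob"]))

# evaluate_policy() builds a fresh env and averages episode returns
print(evaluate_policy(make_frozen_lake, policy, n_episodes=100))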