diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 564a70f..37d4d57 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -16,7 +16,8 @@ jobs: fail-fast: false matrix: version: - - "1.11" # Latest + - "1.11" + - "1.12" os: - windows-latest - ubuntu-latest diff --git a/Project.toml b/Project.toml index 5f68695..42e66be 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Crux" uuid = "e51cc422-768a-4345-bb8e-2246287ae729" authors = ["Anthony Corso "] -version = "0.1.3" +version = "0.1.4" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" @@ -11,6 +11,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca" POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7" POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" @@ -36,6 +37,7 @@ ColorSchemes = "3" Distributions = "0.25" Flux = "0.14" Images = "0.25, 0.26" +POMDPModels = "0.4" POMDPTools = "0.1" POMDPs = "0.9, 1.0" Parameters = "0.12" diff --git a/README.md b/README.md index 788b717..908e87a 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,6 @@ # Crux.jl -[![Documentation](https://img.shields.io/badge/docs-latest-blue.svg)](https://sisl.github.io/Crux.jl/dev/) -[![Build Status](https://github.com/sisl/Crux.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/sisl/Crux.jl/actions/workflows/CI.yml) -[![Code Coverage](https://codecov.io/gh/sisl/Crux.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/sisl/Crux.jl) +[![Documentation](https://img.shields.io/badge/docs-latest-blue.svg)](https://sisl.github.io/Crux.jl/dev/) [![Build Status](https://github.com/sisl/Crux.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/sisl/Crux.jl/actions/workflows/CI.yml) [![Code 
Coverage](https://codecov.io/gh/sisl/Crux.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/sisl/Crux.jl) Deep RL library with concise implementations of popular algorithms. Implemented using [Flux.jl](https://github.com/FluxML/Flux.jl) and fits into the [POMDPs.jl](https://github.com/JuliaPOMDP/POMDPs.jl) interface. @@ -40,34 +38,54 @@ Supports CPU and GPU computation and implements deep reinforcement learning, imi * [Experience Replay](https://github.com/sisl/Crux.jl/blob/master/src/model_free/cl/experience_replay.jl) +## Installation + +To install the package, run: +```julia +] add Crux +``` + ## Usage -An example usage of the `REINFORCE` algorithm with a simple Flux network for the Cart Pole problem is shown here: +### Basic Example (Pure Julia) + +A minimal example using DQN to solve a GridWorld problem: ```julia -using Crux, POMDPGym +using Crux -# Problem setup -mdp = GymPOMDP(:CartPole) -as = actions(mdp) +mdp = SimpleGridWorld() S = state_space(mdp) -# Flux network: Map states to actions -A() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, length(as))), as) +A() = DiscreteNetwork(Chain(Dense(2, 8, relu), Dense(8, 4)), actions(mdp)) -# Setup REINFORCE solver -solver_reinforce = REINFORCE(S=S, π=A()) - -# Solve the `mdp` to get the `policy` -policy_reinforce = solve(solver_reinforce, mdp) +solver = DQN(π=A(), S=S, N=100_000) +policy = solve(solver, mdp) ``` +See the [documentation](https://sisl.github.io/Crux.jl/dev/) for more examples and details. -## Installation +### Gym Environments -To install the package, run: +For OpenAI Gym environments like CartPole, install [POMDPGym.jl](https://github.com/ancorso/POMDPGym.jl): + +```julia +] add https://github.com/ancorso/POMDPGym.jl ``` -] add Crux + +> **Note:** POMDPGym requires Python with [Gymnasium](https://gymnasium.farama.org/) installed (`pip install gymnasium`). 
+ +```julia +using Crux, POMDPGym + +mdp = GymPOMDP(:CartPole) +as = actions(mdp) +S = state_space(mdp) + +A() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, length(as))), as) + +solver = REINFORCE(S=S, π=A()) +policy = solve(solver, mdp) ``` -See the [installation documentation](https://github.com/ancorso/POMDPGym.jl?tab=readme-ov-file#installation) for more details on how to install POMDPGym for more environment. +See the [installation documentation](https://github.com/ancorso/POMDPGym.jl?tab=readme-ov-file#installation) for more details on how to install POMDPGym and use additional environments. \ No newline at end of file diff --git a/docs/src/examples.md b/docs/src/examples.md index 6b08d7a..0902997 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -25,10 +25,34 @@ For a full set of examples, please see the [`examples/`](https://github.com/sisl - [Hopper Medium (MuJoCo)](https://github.com/sisl/Crux.jl/blob/master/examples/offline%20rl/hopper_medium.jl) -## Minimal RL Example +## Minimal RL Example (Pure Julia) -As a minimal example, we'll show how to set up a cart-pole problem and solve it with a simple Flux network using the REINFORCE algorithm. +As a minimal self-contained example, we'll show how to solve a GridWorld problem using DQN with no external dependencies: +```julia +using Crux + +mdp = SimpleGridWorld() +S = state_space(mdp) + +A() = DiscreteNetwork(Chain(Dense(2, 8, relu), Dense(8, 4)), actions(mdp)) + +solver = DQN(π=A(), S=S, N=100_000) +policy = solve(solver, mdp) +``` + +## Gym Environments + +For OpenAI Gym environments like CartPole, you need to install [POMDPGym.jl](https://github.com/ancorso/POMDPGym.jl): + +```julia +] add https://github.com/ancorso/POMDPGym.jl +``` + +!!! note + POMDPGym requires Python with [Gymnasium](https://gymnasium.farama.org/) installed (`pip install gymnasium`).
+ +Here's an example using REINFORCE to solve CartPole: ```julia using Crux, POMDPGym diff --git a/src/Crux.jl b/src/Crux.jl index 113e48b..ac46c17 100644 --- a/src/Crux.jl +++ b/src/Crux.jl @@ -7,6 +7,7 @@ module Crux using Random using Distributions @reexport using POMDPs + @reexport using POMDPModels using POMDPTools:render using Parameters using TensorBoardLogger diff --git a/test/readme.jl b/test/readme.jl new file mode 100644 index 0000000..3c75356 --- /dev/null +++ b/test/readme.jl @@ -0,0 +1,9 @@ +using Crux + +mdp = SimpleGridWorld() +S = state_space(mdp) + +A() = DiscreteNetwork(Chain(Dense(2, 8, relu), Dense(8, 4)), actions(mdp)) + +solver = DQN(π=A(), S=S, N=100_000) +policy = solve(solver, mdp) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 860bd7a..fbee056 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,6 +9,9 @@ try catch end ## Run functionality tests +@testset "README example" begin + include("readme.jl") +end @testset "spaces" begin include("spaces_tests.jl") end