Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ jobs:
fail-fast: false
matrix:
version:
- "1.11" # Latest
- "1.11"
- "1.12"
os:
- windows-latest
- ubuntu-latest
Expand Down
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Crux"
uuid = "e51cc422-768a-4345-bb8e-2246287ae729"
authors = ["Anthony Corso <anthonycorso92@gmail.com>"]
version = "0.1.3"
version = "0.1.4"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand All @@ -11,6 +11,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca"
POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7"
POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Expand All @@ -36,6 +37,7 @@ ColorSchemes = "3"
Distributions = "0.25"
Flux = "0.14"
Images = "0.25, 0.26"
POMDPModels = "0.4"
POMDPTools = "0.1"
POMDPs = "0.9, 1.0"
Parameters = "0.12"
Expand Down
56 changes: 37 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Crux.jl

[![Documentation](https://img.shields.io/badge/docs-latest-blue.svg)](https://sisl.github.io/Crux.jl/dev/)
[![Build Status](https://github.com/sisl/Crux.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/sisl/Crux.jl/actions/workflows/CI.yml)
[![Code Coverage](https://codecov.io/gh/sisl/Crux.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/sisl/Crux.jl)
[![Documentation](https://img.shields.io/badge/docs-latest-blue.svg)](https://sisl.github.io/Crux.jl/dev/) [![Build Status](https://github.com/sisl/Crux.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/sisl/Crux.jl/actions/workflows/CI.yml) [![Code Coverage](https://codecov.io/gh/sisl/Crux.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/sisl/Crux.jl)

Deep RL library with concise implementations of popular algorithms. Implemented using [Flux.jl](https://github.com/FluxML/Flux.jl) and fits into the [POMDPs.jl](https://github.com/JuliaPOMDP/POMDPs.jl) interface.

Expand Down Expand Up @@ -40,34 +38,54 @@ Supports CPU and GPU computation and implements deep reinforcement learning, imi
* [Experience Replay](https://github.com/sisl/Crux.jl/blob/master/src/model_free/cl/experience_replay.jl)


## Installation

To install the package, run:
```julia
] add Crux
```

## Usage

An example usage of the `REINFORCE` algorithm with a simple Flux network for the Cart Pole problem is shown here:
### Basic Example (Pure Julia)

A minimal example using DQN to solve a GridWorld problem:

```julia
using Crux, POMDPGym
using Crux

# Problem setup
mdp = GymPOMDP(:CartPole)
as = actions(mdp)
mdp = SimpleGridWorld()
S = state_space(mdp)

# Flux network: Map states to actions
A() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, length(as))), as)
A() = DiscreteNetwork(Chain(Dense(2, 8, relu), Dense(8, 4)), actions(mdp))

# Setup REINFORCE solver
solver_reinforce = REINFORCE(S=S, π=A())

# Solve the `mdp` to get the `policy`
policy_reinforce = solve(solver_reinforce, mdp)
solver = DQN(π=A(), S=S, N=100_000)
policy = solve(solver, mdp)
```

See the [documentation](https://sisl.github.io/Crux.jl/dev/) for more examples and details.

## Installation
### Gym Environments

To install the package, run:
For OpenAI Gym environments like CartPole, install [POMDPGym.jl](https://github.com/ancorso/POMDPGym.jl):

```julia
] add https://github.com/ancorso/POMDPGym.jl
```
] add Crux

> **Note:** POMDPGym requires Python with [Gymnasium](https://gymnasium.farama.org/) installed (`pip install gymnasium`).

```julia
using Crux, POMDPGym

mdp = GymPOMDP(:CartPole)
as = actions(mdp)
S = state_space(mdp)

A() = DiscreteNetwork(Chain(Dense(dim(S)..., 64, relu), Dense(64, length(as))), as)

solver = REINFORCE(S=S, π=A())
policy = solve(solver, mdp)
```

See the [installation documentation](https://github.com/ancorso/POMDPGym.jl?tab=readme-ov-file#installation) for more details on how to install POMDPGym for more environment.
See the [installation documentation](https://github.com/ancorso/POMDPGym.jl?tab=readme-ov-file#installation) for more details on how to install POMDPGym for more environment.
28 changes: 26 additions & 2 deletions docs/src/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,34 @@ For a full set of examples, please see the [`examples/`](https://github.com/sisl
- [Hopper Medium (MuJoCo)](https://github.com/sisl/Crux.jl/blob/master/examples/offline%20rl/hopper_medium.jl)


## Minimal RL Example
## Minimal RL Example (Pure Julia)

As a minimal example, we'll show how to set up a cart-pole problem and solve it with a simple Flux network using the REINFORCE algorithm.
As a minimal self-contained example, we'll show how to solve a GridWorld problem using DQN with no external dependencies:

```julia
using Crux

mdp = SimpleGridWorld()
S = state_space(mdp)

A() = DiscreteNetwork(Chain(Dense(2, 8, relu), Dense(8, 4)), actions(mdp))

solver = DQN(π=A(), S=S, N=100_000)
policy = solve(solver, mdp)
```

## Gym Environments

For OpenAI Gym environments like CartPole, you need to install [POMDPGym.jl](https://github.com/ancorso/POMDPGym.jl):

```julia
] add https://github.com/ancorso/POMDPGym.jl
```

!!! note
POMDPGym requires Python with [Gymnasium](https://gymnasium.farama.org/) installed (`pip install gymnasium`).

Here's an example using REINFORCE to solve CartPole:

```julia
using Crux, POMDPGym
Expand Down
1 change: 1 addition & 0 deletions src/Crux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module Crux
using Random
using Distributions
@reexport using POMDPs
@reexport using POMDPModels
using POMDPTools:render
using Parameters
using TensorBoardLogger
Expand Down
9 changes: 9 additions & 0 deletions test/readme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using Crux

mdp = SimpleGridWorld()
S = state_space(mdp)

A() = DiscreteNetwork(Chain(Dense(2, 8, relu), Dense(8, 4)), actions(mdp))

solver = DQN(π=A(), S=S, N=100_000)
policy = solve(solver, mdp)
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ try
catch end

## Run functionality tests
@testset "README example" begin
include("readme.jl")
end
@testset "spaces" begin
include("spaces_tests.jl")
end
Expand Down
Loading