Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 0 additions & 29 deletions pyproject.toml

This file was deleted.

2 changes: 1 addition & 1 deletion table_rl/learners/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from table_rl.learners.q_learning import QLearning
from table_rl.learners.double_q_learning import DoubleQLearning
from table_rl.learners.sarsa import SARSA
from table_rl.learners.sarsa import Sarsa, SarsaLambda
from table_rl.learners.qv import QVLearning
77 changes: 75 additions & 2 deletions table_rl/learners/sarsa.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,82 @@
from table_rl import learner
import numpy as np

class SarsaLambda(learner.Learner):
    """Tabular SARSA(λ) with accumulating eligibility traces.

    Implements the on-policy SARSA(λ) algorithm (Sutton & Barto, ch. 12):
    each step computes the TD error δ, bumps the eligibility trace of the
    visited (state, action) pair, backs up *all* Q-values in proportion to
    their traces, and then decays every trace once by γλ.

    Args:
        num_states: number of discrete states.
        num_actions: number of discrete actions.
        step_size_schedule: object exposing ``step_size(obs, action)`` and
            ``observe(...)`` — supplies the learning rate α per pair.
        explorer: object exposing ``select_action(obs, q_values)`` and
            ``observe(...)`` — the behavior policy.
        discount: discount factor γ.
        trace_lambda: trace-decay parameter λ in [0, 1].
        initial_val: optimistic/neutral initial value for every Q entry.
    """

    def __init__(self,
                 num_states,
                 num_actions,
                 step_size_schedule,
                 explorer,
                 discount=0.99,
                 trace_lambda=0.9,
                 initial_val=0.):
        self.explorer = explorer
        self.step_size_schedule = step_size_schedule
        self.q = np.full((num_states, num_actions), initial_val, dtype=float)
        # Accumulating eligibility traces, one per (state, action) pair.
        self.e = np.zeros((num_states, num_actions), dtype=float)
        self.discount = discount
        self.trace_lambda = trace_lambda
        # Transition bookkeeping between act() and observe(). Initialized
        # here so a misordered observe() fails on a None index rather than
        # an AttributeError for a never-created attribute.
        self.current_obs = None
        self.action = None
        self.next_obs = None
        self.next_action = None

    def update_q(self, obs, action, reward, terminated, next_obs, next_action):
        """Apply one SARSA(λ) backup for the transition (obs, action) → next_obs.

        Args:
            obs: state where ``action`` was taken.
            action: action taken in ``obs``.
            reward: reward received.
            terminated: True if ``next_obs`` is terminal (target is reward only).
            next_obs: successor state.
            next_action: action selected in ``next_obs`` (ignored if terminated).
        """
        if terminated:
            target = reward
        else:
            target = reward + self.discount * self.q[next_obs, next_action]
        delta = target - self.q[obs, action]
        step_size = self.step_size_schedule.step_size(obs, action)
        # Standard accumulating-trace order: bump the current pair, back up
        # every pair in proportion to its trace, then decay all traces ONCE
        # by γλ. (Decaying both before the increment and after the update
        # would shrink traces by (γλ)² per step and understate credit.)
        self.e[obs, action] += 1
        self.q += step_size * delta * self.e
        self.e *= self.discount * self.trace_lambda

    def act(self, obs: int, train: bool) -> int:
        """Return the action to take in ``obs``.

        In evaluation mode (``train=False``) this is greedy w.r.t. Q. In
        training mode the action pre-committed by the previous observe()
        is reused (SARSA is on-policy), falling back to the explorer on
        the first step of an episode.
        """
        if not train:
            return np.argmax(self.q[obs])
        if self.next_obs is not None:
            # The env must hand us the state observe() already saw.
            assert obs == self.next_obs
        q_values = self.q[obs]
        if self.next_action is None:
            action = self.explorer.select_action(obs, q_values) if train else np.argmax(q_values)
        else:
            action = self.next_action
        self.current_obs = obs
        self.action = action
        return action

    def observe(self, obs: int, reward: float, terminated: bool, truncated: bool, training_mode: bool) -> None:
        """Observe consequences of the last action and update estimates accordingly.

        Pre-selects the next action (so act() replays it, keeping the update
        on-policy), performs the SARSA(λ) backup, forwards the observation to
        the explorer and step-size schedule, and clears all per-episode state
        — including the eligibility traces — at episode end.

        Returns:
            None
        """
        self.next_obs = obs
        next_obs_q_values = self.q[self.next_obs]
        if terminated:
            self.next_action = None
        else:
            self.next_action = self.explorer.select_action(self.next_obs, next_obs_q_values) if training_mode else np.argmax(next_obs_q_values)
        self.update_q(self.current_obs, self.action, reward, terminated, obs, self.next_action)
        self.explorer.observe(obs, reward, terminated, truncated, training_mode)
        self.step_size_schedule.observe(obs, reward, terminated, truncated, training_mode)
        if terminated or truncated:
            self.current_obs = None
            self.next_obs = None
            self.next_action = None
            self.action = None
            # Traces must not carry credit across episode boundaries.
            self.e[:] = 0.0


class Sarsa(learner.Learner):
"""Class that implements Sarsa."""

def __init__(self,
num_states,
Expand Down
12 changes: 6 additions & 6 deletions tests/learner_tests/test_sarsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import numpy as np
import table_rl
import table_rl.dp.dp as dp
from table_rl.learners import SARSA
from table_rl.learners import Sarsa

class TestSARSA:
class TestSarsa:
@pytest.fixture(autouse=True)
def setUp(self):
self.env = table_rl.envs.BasicEnv(discount=0.9)
Expand All @@ -21,7 +21,7 @@ def setUp(self):
def test_sarsa_loop(self):
explorer = table_rl.explorers.PolicyExecutor(self.policy)

agent = SARSA(self.T.shape[0],
agent = Sarsa(self.T.shape[0],
self.T.shape[1],
table_rl.step_size_schedulers.ConstantStepSize(0.015),
explorer,
Expand All @@ -43,7 +43,7 @@ def test_sarsa_loop(self):

def test_sarsa_update(self):
explorer = table_rl.explorers.GreedyExplorer(self.T.shape[1])
agent = SARSA(self.T.shape[0],
agent = Sarsa(self.T.shape[0],
self.T.shape[1],
table_rl.step_size_schedulers.ConstantStepSize(0.1),
explorer,
Expand All @@ -60,7 +60,7 @@ def test_sarsa_update(self):

def test_sarsa_update_termination(self):
explorer = table_rl.explorers.GreedyExplorer(self.T.shape[1])
agent = SARSA(self.T.shape[0],
agent = Sarsa(self.T.shape[0],
self.T.shape[1],
table_rl.step_size_schedulers.ConstantStepSize(0.1),
explorer,
Expand All @@ -78,7 +78,7 @@ def test_sarsa_update_termination(self):

def test_sarsa_truncation(self):
explorer = table_rl.explorers.GreedyExplorer(self.T.shape[1])
agent = SARSA(self.T.shape[0],
agent = Sarsa(self.T.shape[0],
self.T.shape[1],
table_rl.step_size_schedulers.ConstantStepSize(0.1),
explorer,
Expand Down