Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions src/pytorch_icem/icem.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import scipy.stats
import torch
import colorednoise
from arm_pytorch_utilities import handle_batch_input

import scipy
import logging

import numpy as np
from scipy.stats import truncnorm
logger = logging.getLogger(__name__)

def accumulate_running_cost(running_cost, terminal_state_weight=10.0):
Expand All @@ -26,11 +28,18 @@ def __init__(self, dynamics, trajectory_cost, nx, nu, sigma=None, num_samples=10
warmup_iters=100, online_iters=100,
includes_x0=False,
fixed_H=True,
device="cpu"):
device="cpu",
low_bound_action = None,
high_bound_action = None):

self.dynamics = dynamics
self.trajectory_cost = trajectory_cost

self.low_bound_action = low_bound_action
self.high_bound_action = high_bound_action



self.nx = nx
self.nu = nu
self.H = horizon
Expand All @@ -42,7 +51,7 @@ def __init__(self, dynamics, trajectory_cost, nx, nu, sigma=None, num_samples=10
sigma = torch.ones(self.nu, device=self.device).float()
elif isinstance(sigma, float):
sigma = torch.ones(self.nu, device=self.device).float() * sigma
if len(sigma.shape) != nu:
if sigma.shape[0] != nu:
raise ValueError(f"Sigma must be either a scalar or a vector of length nu {nu}")
self.sigma = sigma
self.dtype = self.sigma.dtype
Expand Down Expand Up @@ -81,8 +90,20 @@ def sample_action_sequences(self, state, N):
samples = torch.from_numpy(samples).to(device=self.device, dtype=self.dtype)
else:
samples = torch.randn(N, self.H, self.nu, device=self.device, dtype=self.dtype)

U = self.mean + self.std * samples

if self.low_bound_action is not None and self.high_bound_action is not None:
a_trunc = np.array(self.low_bound_action)
b_trunc = np.array(self.high_bound_action)
scale = np.array(self.std)
loc = np.array(self.mean)
a, b = (a_trunc - loc) / scale, (b_trunc - loc) / scale
a[np.isinf(a)] = -1.0
b[np.isinf(b)] = 1.0
U_np = truncnorm.rvs(a = a, b = b, size = [N, self.H, self.nu], scale = scale ,loc=loc)
U = torch.Tensor(U_np)

else:
U = self.mean + self.std * samples
return U

def update_distribution(self, elites):
Expand Down