diff --git a/src/pytorch_icem/icem.py b/src/pytorch_icem/icem.py index d18a69b..a0dbf10 100644 --- a/src/pytorch_icem/icem.py +++ b/src/pytorch_icem/icem.py @@ -1,9 +1,11 @@ +import scipy.stats import torch import colorednoise from arm_pytorch_utilities import handle_batch_input - +import scipy import logging - +import numpy as np +from scipy.stats import truncnorm logger = logging.getLogger(__name__) def accumulate_running_cost(running_cost, terminal_state_weight=10.0): @@ -26,11 +28,18 @@ def __init__(self, dynamics, trajectory_cost, nx, nu, sigma=None, num_samples=10 warmup_iters=100, online_iters=100, includes_x0=False, fixed_H=True, - device="cpu"): + device="cpu", + low_bound_action = None, + high_bound_action = None): self.dynamics = dynamics self.trajectory_cost = trajectory_cost + self.low_bound_action = low_bound_action + self.high_bound_action = high_bound_action + + + self.nx = nx self.nu = nu self.H = horizon @@ -42,7 +51,7 @@ def __init__(self, dynamics, trajectory_cost, nx, nu, sigma=None, num_samples=10 sigma = torch.ones(self.nu, device=self.device).float() elif isinstance(sigma, float): sigma = torch.ones(self.nu, device=self.device).float() * sigma - if len(sigma.shape) != nu: + if sigma.shape[0] != nu: raise ValueError(f"Sigma must be either a scalar or a vector of length nu {nu}") self.sigma = sigma self.dtype = self.sigma.dtype @@ -81,8 +90,20 @@ def sample_action_sequences(self, state, N): samples = torch.from_numpy(samples).to(device=self.device, dtype=self.dtype) else: samples = torch.randn(N, self.H, self.nu, device=self.device, dtype=self.dtype) - - U = self.mean + self.std * samples + + if self.low_bound_action is not None and self.high_bound_action is not None: + a_trunc = np.array(self.low_bound_action) + b_trunc = np.array(self.high_bound_action) + scale = np.array(self.std) + loc = np.array(self.mean) + a, b = (a_trunc - loc) / scale, (b_trunc - loc) / scale + a[np.isinf(a)] = -1.0 + b[np.isinf(b)] = 1.0 + U_np = truncnorm.rvs(a = a, b = b, size = [N, self.H, self.nu], scale = scale ,loc=loc) + U = torch.Tensor(U_np) + + else: + U = self.mean + self.std * samples return U def update_distribution(self, elites):