diff --git a/botorch_community/acquisition/discretized.py b/botorch_community/acquisition/discretized.py new file mode 100644 index 0000000..b80eea2 --- /dev/null +++ b/botorch_community/acquisition/discretized.py @@ -0,0 +1,306 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r"""Discretized acquisition functions for Riemann-distributed posteriors. + +NOTE: This module should eventually be moved to: + botorch_community/acquisition/discretized.py +in the meta-pytorch/botorch repository. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod + +import torch +from botorch.acquisition import AcquisitionFunction +from botorch.acquisition.objective import ( + PosteriorTransform, + ScalarizedPosteriorTransform, +) +from botorch.exceptions.errors import UnsupportedError +from botorch.models.model import Model +from botorch.utils.transforms import ( + average_over_ensemble_models, + t_batch_mode_transform, +) +from torch import Tensor + + +class DiscretizedAcquisitionFunction(AcquisitionFunction, ABC): + r"""DiscretizedAcquisitionFunction is an abstract base class for acquisition + functions that are defined on discrete distributions. It wraps a model and + implements a forward method that computes the acquisition function value at + a given set of points. + This class can be subclassed to define acquisition functions for Riemann- + distributed posteriors. + The acquisition function must have the form $$acq(x) = \int p(y|x) ag(x)$$, + where $$ag$$ is defined differently for each acquisition function. + The ag_integrate method, which computes the integral of ag between two points, must + be implemented by subclasses to define the specific acquisition functions. 
+ """ + + def __init__( + self, + model: Model, + posterior_transform: PosteriorTransform, + assume_symmetric_posterior: bool = True, + ) -> None: + r""" + Initialize the DiscretizedAcquisitionFunction + + Args: + model: A fitted model that is used to compute the posterior + distribution over the outcomes of interest. + The model should be a ``PFNModel``. + posterior_transform: A ScalarizedPosteriorTransform that can only + indicate minimization or maximization of the objective. + assume_symmetric_posterior: If True, we simply negate train y, if + the task is to minimize the objective. Else, we use a proper + posterior transform. We cannot do this generally, as some + models only support maximization. This does not mean that + the posterior distribution for a particular set is symmetric + but that one can negate the y's of the context and get out + negated ys. + """ + super().__init__(model=model) + self.set_X_pending(None) + self.assume_symmetric_posterior = assume_symmetric_posterior + self.maximize = True + if posterior_transform is not None: + unsupported_error_message = ( + "Only scalarized posterior transforms with a" + "single objective and 0.0 offset are supported." + ) + if ( + not isinstance(posterior_transform, ScalarizedPosteriorTransform) + or (posterior_transform.offset != 0.0) + or len(posterior_transform.weights) != 1 + or posterior_transform.weights[0] not in [-1.0, 1.0] + ): + raise UnsupportedError(unsupported_error_message) + + self.maximize = posterior_transform.weights[0] == 1.0 + + @t_batch_mode_transform() + @average_over_ensemble_models + def forward(self, X: Tensor) -> Tensor: + r"""Evaluate the acquisition function on the candidate set X. + + Args: + X: A ``(b) x q x d``-dim Tensor of ``(b)`` t-batches with ``q`` ``d``-dim + design points each. + + Returns: + A ``(b)``-dim Tensor of the acquisition function at the given + design points ``X``. 
+ """ + # Note: pending_X is not supported by PFNModel.posterior() + # If X_pending is set, it would need to be handled differently + discrete_posterior = self.model.posterior( + X, + negate_train_ys=(not self.maximize) and self.assume_symmetric_posterior, + ) + if not self.maximize and not self.assume_symmetric_posterior: + discrete_posterior.borders = -torch.flip(discrete_posterior.borders, [0]) + discrete_posterior.probabilities = torch.flip( + discrete_posterior.probabilities, [-1] + ) + + result = discrete_posterior.integrate(self.ag_integrate) + # result has shape (b, q) - sum over q dimension for batch acquisition + # For q=1, this is equivalent to squeeze(-1) + return result.sum(dim=-1) + + @abstractmethod + def ag_integrate(self, lower_bound: Tensor, upper_bound: Tensor) -> Tensor: + r""" + This function calculates the integral that computes the acquisition function + without the posterior factor from lower_bound to upper_bound. + That is, our acquisition function is assumed to have the form + \int ag(x) * p(x) dx, + and this function calculates \int_{lower_bound}^{upper_bound} ag(x) dx. + The ``integrate`` method of the posterior (``BoundedRiemannPosterior``) + then computes the final acquisition value. + + Args: + lower_bound: lower bound of integral + upper_bound: upper bound of integral + + Returns: + A ``(b)``-dim Tensor of acquisition function derivatives at the given + design points ``X``. + """ + pass # pragma: no cover + + +class DiscretizedExpectedImprovement(DiscretizedAcquisitionFunction): + r"""DiscretizedExpectedImprovement is an acquisition function that + computes the expected improvement over the current best observed value + for a Riemann distribution. 
+ """ + + def __init__( + self, + model: Model, + best_f: Tensor, + posterior_transform: PosteriorTransform | None = None, + assume_symmetric_posterior: bool = True, + ) -> None: + r""" + Initialize the DiscretizedExpectedImprovement + + Args: + model: A fitted model that is used to compute the posterior + distribution over the outcomes of interest. + The model should be a ``PFNModel``. + best_f: A tensor representing the current best observed value. + """ + super().__init__( + model=model, + posterior_transform=posterior_transform, + assume_symmetric_posterior=assume_symmetric_posterior, + ) + self.register_buffer("best_f", torch.as_tensor(best_f)) + + def ag_integrate(self, lower_bound: Tensor, upper_bound: Tensor) -> Tensor: + r""" + As Expected improvement has ag(y) = (y - best_f).clamp(min=0), and + is defined as \int ag(y) * p(y) dy, we can calculate the integral + of ag(y) like so: + We just calculate ag(y) at beginning and end, and since the function has + a gradient of 1 or 0, we can just take the average of the two. + + Args: + lower_bound: lower bound of integral + upper_bound: upper bound of integral + + Returns: + A ``(b)``-dim Tensor of acquisition function derivatives at the given + design points ``X``. 
+ """ + best_f = self.best_f.to(lower_bound) + + # Case 1: best_f >= upper_bound, entire interval gives 0 improvement + case1_mask = best_f >= upper_bound + + # Case 2: best_f <= lower_bound, entire interval gives improvement + case2_mask = best_f <= lower_bound + + # Case 3: lower_bound < best_f < upper_bound, partial improvement + case3_mask = ~(case1_mask | case2_mask) + + # Initialize result tensor + result = torch.zeros_like(lower_bound) + + # Case 1: result is already 0 + + # Case 2: integral = ( + # ((upper_bound + lower_bound)/2 - best_f) + # * (upper_bound - lower_bound) + # ) + if case2_mask.any(): + bucket_width = upper_bound - lower_bound + bucket_center = (upper_bound + lower_bound) / 2 + result = torch.where( + case2_mask, (bucket_center - best_f) * bucket_width, result + ) + + # Case 3: integral = (upper_bound - best_f)²/2 + if case3_mask.any(): + result = torch.where(case3_mask, (upper_bound - best_f).pow(2) / 2, result) + + return result.clamp_min(0) + + +class DiscretizedNoisyExpectedImprovement(DiscretizedExpectedImprovement): + def __init__( + self, + model: Model, + posterior_transform: PosteriorTransform | None = None, + X_pending: Tensor | None = None, + ) -> None: + r""" + Only works with models trained specifically for this. + + Args: + model: A fitted model that is used to compute the posterior + distribution over the outcomes of interest. + The model should be a ``PFNModelWithPendingPoints``. + X_pending: Optional pending points to include in the model. + """ + super().__init__( + model=model, + posterior_transform=posterior_transform, + best_f=0.0, + ) + # Set pending points on the model if it supports them + if X_pending is not None: + if hasattr(model, "pending_X"): + model.pending_X = X_pending + else: + raise UnsupportedError( + f"Model {type(model).__name__} does not support pending points. " + "Use PFNModelWithPendingPoints for NEI with pending evaluations." 
+ ) + self.set_X_pending(X_pending) + + +class DiscretizedProbabilityOfImprovement(DiscretizedAcquisitionFunction): + r"""DiscretizedProbabilityOfImprovement is an acquisition function that + computes the probability of improvement over the current best observed value + for a Riemann distribution. + """ + + def __init__( + self, + model: Model, + best_f: Tensor, + posterior_transform: PosteriorTransform | None = None, + assume_symmetric_posterior: bool = True, + ) -> None: + r""" + Initialize the DiscretizedProbabilityOfImprovement + + Args: + model: A fitted model that is used to compute the posterior + distribution over the outcomes of interest. + The model should be a ``PFNModel``. + best_f: A tensor representing the current best observed value. + """ + + super().__init__( + model, + posterior_transform, + assume_symmetric_posterior=assume_symmetric_posterior, + ) + self.register_buffer("best_f", torch.as_tensor(best_f)) + + def ag_integrate(self, lower_bound: Tensor, upper_bound: Tensor) -> Tensor: + r""" + PI is defined as \int ag(y) * p(y) dy, where ag(y) = I(y - best_f) + and I being the indicator function. + + So all we need to do is calculate the portion between the bounds + that is larger than best_f. + We do this by comparing how much higher the upper bound is than best_f, + compared to the size of the bucket. + Then we clamp at one if best_f is below lower_bound and at zero if + best_f is above upper_bound. + + Args: + lower_bound: lower bound of integral + upper_bound: upper bound of integral + + Returns: + A ``(b)``-dim Tensor of acquisition function derivatives at the given + design points ``X``. 
+ """ + best_f = self.best_f.to(lower_bound) + # two separate clamps needed below, as one is a tensor and one a scalar + return ( + (upper_bound - best_f).clamp(min=0.0).clamp(max=upper_bound - lower_bound) + ) diff --git a/botorch_community/models/prior_fitted_network.py b/botorch_community/models/prior_fitted_network.py new file mode 100644 index 0000000..435d4c1 --- /dev/null +++ b/botorch_community/models/prior_fitted_network.py @@ -0,0 +1,966 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +This module defines the botorch model for PFNs (``PFNModel``) and it +provides handy helpers to download pretrained, public PFNs +with ``download_model`` and model paths with ``ModelPaths``. +For the latter to work ``pfns4bo`` must be installed. +""" + +from __future__ import annotations + +from contextlib import contextmanager +from typing import Any, Iterator, Optional, Union + +import torch +from botorch.acquisition.objective import PosteriorTransform +from botorch.exceptions.errors import UnsupportedError +from botorch.logging import logger +from botorch.models.model import Model +from botorch.models.transforms.input import InputTransform +from botorch.utils.transforms import match_batch_shape +from botorch_community.models.utils.prior_fitted_network import ( + download_model, + ModelPaths, +) +from botorch_community.posteriors.riemann import ( + BoundedRiemannPosterior, + MultivariateRiemannPosterior, +) +from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood +from pfns.train import MainConfig # @manual=//pytorch/PFNs:PFNs +from torch import Tensor +from torch.nn import Module + + +def get_styles( + model: Module, hps: dict | None, batch_size: int, device: str +) -> dict[str, Tensor]: + if hps is None or (model.style_encoder is None and model.y_style_encoder is None): + 
return {} + style_kwargs = {} + if model.style_encoder is not None: + hps_subset = { + k: v for k, v in hps.items() if k in model.style_encoder[0].hyperparameters + } + style = ( + model.style_encoder[0] + .hyperparameters_dict_to_tensor(hps_subset) + .repeat(batch_size, 1) + .to(device) + .float() + ) # shape (batch_size, num_styles) + style_kwargs["style"] = style + + if model.y_style_encoder is not None: + hps_subset = { + k: v + for k, v in hps.items() + if k in model.y_style_encoder[0].hyperparameters + } + y_style = ( + model.y_style_encoder[0] + .hyperparameters_dict_to_tensor(hps_subset) + .repeat(batch_size, 1) + .to(device) + .float() + ) # shape (batch_size, num_styles) + style_kwargs["y_style"] = y_style + return style_kwargs + + +class PFNModel(Model): + """Prior-data Fitted Network""" + + def __init__( + self, + train_X: Tensor, + train_Y: Tensor, + model: Module | None = None, + checkpoint_url: str = ModelPaths.pfns4bo_hebo, + train_Yvar: Tensor | None = None, + batch_first: bool = False, + constant_model_kwargs: dict[str, Any] | None = None, + input_transform: InputTransform | None = None, + load_training_checkpoint: bool = False, + style_hyperparameters: dict[str, Any] | None = None, + style: Tensor + | None = None, # should have shape (num_styles,) or (num_features, num_styles) + ) -> None: + """Initialize a PFNModel. + + Either a pre-trained PFN model can be provided via the model kwarg, + or a checkpoint_url can be provided from which the model will be + downloaded. This defaults to the pfns4bo_hebo model. + + Loading the model does an unsafe "weights_only=False" load, so + it is essential that checkpoint_url be a trusted source. + + Args: + train_X: A ``n x d`` tensor of training features. + train_Y: A ``n x 1`` tensor of training observations. 
+ model: A pre-trained PFN model with the following + forward(train_X, train_Y, X) -> logit predictions of shape + ``n x b x c`` where c is the number of discrete buckets + borders: A ``c+1``-dim tensor of bucket borders. + checkpoint_url: The string URL of the PFN model to download and load. + Will be ignored if model is provided. + train_Yvar: Observed variance of train_Y. Currently ignored. + batch_first: Whether the batch dimension is the first dimension of + the input tensors. This is needed to support different PFN + models. For batch-first x has shape ``batch x seq_len x features`` + and for non-batch-first it has shape ``seq_len x batch x features``. + constant_model_kwargs: A dictionary of model kwargs that + will be passed to the model in each forward pass. + input_transform: A Botorch input transform. + load_training_checkpoint: Whether to load a training checkpoint as + produced by the PFNs training code, see github.com/automl/PFNs. + style_hyperparameters: A dictionary of hyperparameters to be passed + to the style and the y-style encoders. It is useful when training + models with ``hyperparameter_sampling`` prior and its style + encoder. One simply supplies the dict with the unnormalized + hyperparameters, e.g., {"noise_std": 0.1}. Omitted values are + treated as unknown and the value will build a Bayesian average + for these, if ``hyperparameter_sampling_skip_style_prob`` > 0 + during pre-training. + style: A tensor of style values to be passed to the model. These + are raw style values of shape (num_styles,), which will then + be extended as needed. 
+ + """ + super().__init__() + if model is None: + model = download_model( + model_path=checkpoint_url, + ) + + if load_training_checkpoint: + # the model is not an actual model, but a training checkpoint + # make a model out of it + checkpoint = model + config = MainConfig.from_dict(checkpoint["config"]) + model = config.model.create_model() + model.load_state_dict(checkpoint["model_state_dict"]) + model.eval() + + if train_Yvar is not None: + logger.debug("train_Yvar provided but ignored for PFNModel.") + + if train_Y.dim() != 2: + raise UnsupportedError("train_Y must be 2-dimensional.") + + if train_X.dim() != 2: + raise UnsupportedError("train_X must be 2-dimensional.") + + if train_Y.shape[-1] > 1: + raise UnsupportedError("Only 1 target allowed for PFNModel.") + + if train_X.shape[0] != train_Y.shape[0]: + raise UnsupportedError( + "train_X and train_Y must have the same number of rows." + ) + + with torch.no_grad(): + self.transformed_X = self.transform_inputs( + X=train_X, input_transform=input_transform + ) + + self.train_X = train_X # shape: (n, d) + self.train_Y = train_Y # shape: (n, 1) + # Downstream botorch tooling expects a likelihood to be specified, + # so here we use a FixedNoiseGaussianLikelihood that is unused. 
+ if train_Yvar is None: + train_Yvar = torch.zeros_like(train_Y) + self.train_Yvar = train_Yvar # shape: (n, 1) + self.likelihood = FixedNoiseGaussianLikelihood(noise=train_Yvar) + self.pfn = model.to(device=train_X.device) + self.batch_first = batch_first + self.constant_model_kwargs = constant_model_kwargs or {} + self.style_hyperparameters = style_hyperparameters + self.style = style + if input_transform is not None: + self.input_transform = input_transform + self._compute_styles() + + # Cache state initialization + self._training_cache_ready = False + self._cached_negate_train_ys = False + self._cached_context_hash: int | None = None + self._cached_y_encoder_state: dict | None = None + + def _compute_styles(self): + """ + Can be used to compute styles to be used for PFN prediction based on + training data. + + When implemented, will directly modify self.style_hyperparameters or + self.style. + """ + pass + + def _compute_context_hash( + self, train_X: Tensor, train_Y: Tensor, negate_train_ys: bool + ) -> int: + """Compute hash of training context to detect changes. + + Returns a hash based on training data and negation setting. + Used to validate cache consistency, especially with pending points. + """ + return hash(( + id(train_X), + train_X.shape, + tuple(train_X.stride()), + id(train_Y), + train_Y.shape, + tuple(train_Y.stride()), + negate_train_ys, + )) + + @contextmanager + def cache_training_context( + self, negate_train_ys: bool = False + ) -> Iterator[None]: + """Context manager to cache training data for efficient batch evaluation. + + When used with BoTorch's ``optimize_acqf()``, this caches the training + data representations once, allowing many test point evaluations without + recomputing the training context. Gradients flow through test_x only. + + Memory savings: O(n) instead of O(b × n) for training data, where + b is the batch size (num_restarts) and n is the training set size. 
+ + Usage:: + + with model.cache_training_context(): + acqf = ExpectedImprovement(model, best_f=train_Y.max()) + candidates, value = optimize_acqf( + acqf, bounds=bounds, + num_restarts=64, + raw_samples=512, + options={"batch_limit": 64} + ) + + Args: + negate_train_ys: Whether to negate training Y values (for minimization). + """ + # Store original state + original_cache_trainset_representation = getattr( + self.pfn, "cache_trainset_representation", False + ) + + # Validate transformer supports caching + if not hasattr(self.pfn, "cache_trainset_representation"): + raise TypeError( + f"Model {type(self.pfn).__name__} does not support caching. " + "Requires 'cache_trainset_representation' attribute." + ) + if not callable(getattr(self.pfn, "empty_trainset_representation_cache", None)): + raise TypeError( + f"Model {type(self.pfn).__name__} does not support caching. " + "Requires 'empty_trainset_representation_cache()' method." + ) + + try: + # Enable caching on the transformer + self.pfn.cache_trainset_representation = True + + # Prepare training data (single batch) + train_X_bf = self.transformed_X.unsqueeze(0) # (1, n, d) + train_Y_bf = self.train_Y.unsqueeze(0) # (1, n, 1) + if negate_train_ys: + assert self.train_Y.mean().abs() < 1e-4, ( + "train_Y must be zero-centered for negation." 
+ ) + train_Y_bf = -train_Y_bf + + # Hook for subclasses to augment training data (e.g., pending points) + dummy_X = torch.zeros( + 1, 1, train_X_bf.shape[-1], + device=train_X_bf.device, + dtype=train_X_bf.dtype, + ) + train_X_bf, train_Y_bf = self._augment_training_data( + train_X_bf, train_Y_bf, dummy_X, + negate_train_ys=negate_train_ys, use_cache=False + ) + + # Get styles for batch_size=1 (will be broadcast during inference) + styles = self._get_styles(batch_size=1) + + # Populate cache with training data (no gradients on train) + with torch.no_grad(): + if not self.batch_first: + train_X_bf = train_X_bf.transpose(0, 1) # (n, 1, d) + train_Y_bf = train_Y_bf.transpose(0, 1) # (n, 1, 1) + # Create dummy test point to trigger full forward + dummy_test = torch.zeros( + 1, 1, train_X_bf.shape[-1], + device=train_X_bf.device, + dtype=train_X_bf.dtype, + ) # (1, 1, d) -> seq-first + else: + dummy_test = torch.zeros( + 1, 1, train_X_bf.shape[-1], + device=train_X_bf.device, + dtype=train_X_bf.dtype, + ) # (1, 1, d) -> batch-first + + # Forward pass to populate cache + self.pfn( + x=train_X_bf.float(), + y=train_Y_bf.float(), + test_x=dummy_test.float(), + **self.constant_model_kwargs, + **styles, + ) + + # Save y_encoder state after cache population. + # This is critical for MultivariatePFNModel where _compute_conditional_means() + # temporarily disables caching and calls the encoder with augmented data, + # which corrupts the encoder state. We save the state here so it can be + # restored after each correlation estimation call. 
+ if hasattr(self.pfn, "y_encoder") and hasattr( + self.pfn.y_encoder, "save_fitted_state" + ): + self._cached_y_encoder_state = self.pfn.y_encoder.save_fitted_state() + else: + self._cached_y_encoder_state = None + + # Mark cache as ready and store context hash + self._training_cache_ready = True + self._cached_negate_train_ys = negate_train_ys + self._cached_context_hash = self._compute_context_hash( + train_X_bf, train_Y_bf, negate_train_ys + ) + + yield + + finally: + # Clear cache and restore state + self._training_cache_ready = False + self._cached_negate_train_ys = False + self._cached_context_hash = None + self._cached_y_encoder_state = None + + self.pfn.cache_trainset_representation = original_cache_trainset_representation + if hasattr(self.pfn, "empty_trainset_representation_cache"): + self.pfn.empty_trainset_representation_cache() + + def posterior( + self, + X: Tensor, + output_indices: Optional[list[int]] = None, + observation_noise: Union[bool, Tensor] = False, + posterior_transform: Optional[PosteriorTransform] = None, + negate_train_ys: bool = False, + ) -> BoundedRiemannPosterior: + r"""Computes the posterior over model outputs at the provided points. + + Subclasses should override hooks (_augment_training_data, _build_posterior) + rather than this method directly. + + Args: + X: A ``b? x q? x d``-dim Tensor. + output_indices: **Not supported.** + observation_noise: **Not supported.** + posterior_transform: **Not supported.** + negate_train_ys: Whether to negate training Ys (for minimization). + + Returns: + A ``BoundedRiemannPosterior``. 
+ """ + self.pfn.eval() + self._validate_posterior_args(output_indices, observation_noise, posterior_transform) + + # Check cache state + use_cache = self._check_cache_compatibility(negate_train_ys) + + # Prepare base training data + X, train_X, train_Y, orig_X_shape, styles = self._prepare_data( + X, negate_train_ys=negate_train_ys, skip_train_replication=use_cache + ) + + # Hook: Allow subclasses to augment training data (e.g., pending points) + train_X, train_Y = self._augment_training_data( + train_X, train_Y, X, negate_train_ys=negate_train_ys, use_cache=use_cache + ) + + # Core prediction + probabilities = self.pfn_predict( + X=X, + train_X=train_X, + train_Y=train_Y, + use_cache=use_cache, + **self.constant_model_kwargs, + **styles, + ) + probabilities = probabilities.view(*orig_X_shape[:-1], -1) + + # Hook: Allow subclasses to build custom posteriors + return self._build_posterior( + X=X, + probabilities=probabilities, + orig_X_shape=orig_X_shape, + train_X=train_X, + train_Y=train_Y, + styles=styles, + use_cache=use_cache, + ) + + def _validate_posterior_args( + self, + output_indices: Optional[list[int]], + observation_noise: Union[bool, Tensor], + posterior_transform: Optional[PosteriorTransform], + ) -> None: + """Validate unsupported posterior arguments.""" + if output_indices is not None: + raise UnsupportedError( + "output_indices is not None. PFNModel should not be a multi-output model." + ) + if observation_noise: + logger.warning( + "observation_noise is not supported for PFNModel and is being ignored." 
+ ) + if posterior_transform is not None: + raise UnsupportedError("posterior_transform is not supported for PFNModel.") + + def _check_cache_compatibility(self, negate_train_ys: bool) -> bool: + """Check if cache can be used and validate settings.""" + if not self._training_cache_ready: + return False + + if negate_train_ys != self._cached_negate_train_ys: + raise ValueError( + f"negate_train_ys={negate_train_ys} does not match cached " + f"negate_train_ys={self._cached_negate_train_ys}. " + "The cache was populated with a different negation setting." + ) + return True + + def _augment_training_data( + self, + train_X: Tensor | None, + train_Y: Tensor | None, + X: Tensor, + *, + negate_train_ys: bool, + use_cache: bool, + ) -> tuple[Tensor | None, Tensor | None]: + """Hook for subclasses to augment training data. + + Override this instead of posterior() to add pending points, etc. + + Args: + train_X: Training features, shape (b, n, d) or None if use_cache + train_Y: Training targets, shape (b, n, 1) or None if use_cache + X: Test points, shape (b, q, d) + negate_train_ys: Whether train_Y was negated + use_cache: Whether caching is active + + Returns: + Potentially augmented (train_X, train_Y) + """ + return train_X, train_Y + + def _build_posterior( + self, + X: Tensor, + probabilities: Tensor, + orig_X_shape: torch.Size, + train_X: Tensor | None, + train_Y: Tensor | None, + styles: dict[str, Tensor], + use_cache: bool, + ) -> BoundedRiemannPosterior: + """Hook for subclasses to build custom posteriors. + + Override this instead of posterior() for multivariate posteriors, etc. + """ + return BoundedRiemannPosterior( + borders=self.borders, + probabilities=probabilities, + ) + + def _prepare_data( + self, + X: Tensor, + negate_train_ys: bool = False, + skip_train_replication: bool = False, + ) -> tuple[Tensor, Tensor | None, Tensor | None, torch.Size, dict[str, Tensor]]: + """Prepare data for posterior computation. 
+ + Returns: + X: Transformed test points (b, q, d) + train_X: Training features (b, n, d) or None if skip_train_replication + train_Y: Training targets (b, n, 1) or None if skip_train_replication + orig_X_shape: Original shape of X before unsqueezing + styles: Style tensors for the model + """ + orig_X_shape = X.shape # X has shape b? x q? x d + if len(X.shape) > 3: + raise UnsupportedError(f"X must be at most 3-d, got {X.shape}.") + while len(X.shape) < 3: + X = X.unsqueeze(0) + + X = self.transform_inputs(X) # shape (b , q, d) + + if skip_train_replication: + # When using cached training context, we don't need to replicate + # training data since it's already cached in the transformer. + train_X = None + train_Y = None + else: + train_X = match_batch_shape(self.transformed_X, X) # shape (b, n, d) + train_Y = match_batch_shape(self.train_Y, X) # shape (b, n, 1) + if negate_train_ys: + assert self.train_Y.mean().abs() < 1e-4, "train_Y must be zero-centered." + train_Y = -train_Y + + styles = self._get_styles( + batch_size=X.shape[0], + ) # shape (b, num_styles) + return X, train_X, train_Y, orig_X_shape, styles + + def _get_styles(self, batch_size) -> dict[str, Tensor]: + style_kwargs = get_styles( + model=self.pfn, + hps=self.style_hyperparameters, + batch_size=batch_size, + device=self.train_X.device, + ) + if self.style is not None: + assert style_kwargs == {}, ( + "Cannot provide both style and style_hyperparameters." + ) + style_kwargs["style"] = ( + self.style[None] + .repeat(batch_size, 1, 1) + .to(self.train_X.device) + .float() + ) + return style_kwargs + + def pfn_predict( + self, + X: Tensor, + train_X: Tensor | None, + train_Y: Tensor | None, + use_cache: bool = False, + **forward_kwargs, + ) -> Tensor: + """Make a prediction using the PFN model. 
+ + Args: + X: Test points, shape (b, q, d) + train_X: Training features, shape (b, n, d) or None if use_cache + train_Y: Training targets, shape (b, n, 1) or None if use_cache + use_cache: Whether to use cached training representations + **forward_kwargs: Additional kwargs for the PFN model + + Returns: + Probabilities (b, q, num_buckets) + """ + if use_cache: + assert train_X is None and train_Y is None, ( + "Bug: use_cache=True but train_X/train_Y provided" + ) + assert self._training_cache_ready, ( + "Cache requested but not populated. Call cache_training_context() first." + ) + if not self.batch_first: + X = X.transpose(0, 1) + + logits = self.pfn( + x=None, + y=None, + test_x=X.float(), + **forward_kwargs, + ) + + if not self.batch_first: + logits = logits.transpose(0, 1) + else: + assert train_X is not None and train_Y is not None, ( + "Bug: use_cache=False but train_X/train_Y is None" + ) + if not self.batch_first: + X = X.transpose(0, 1) + train_X = train_X.transpose(0, 1) + train_Y = train_Y.transpose(0, 1) + + logits = self.pfn( + x=train_X.float(), + y=train_Y.float(), + test_x=X.float(), + **forward_kwargs, + ) + + if not self.batch_first: + logits = logits.transpose(0, 1) + + logits = logits.to(X.dtype) + probabilities = logits.softmax(dim=-1) + return probabilities + + @property + def borders(self): + return self.pfn.criterion.borders.to(self.train_X.dtype) + + @property + def num_outputs(self) -> int: + """Number of outputs of the model (always 1 for PFNModel).""" + return 1 + + +class PFNModelWithPendingPoints(PFNModel): + """PFNModel that supports pending points (unobserved evaluations). + + Pending points are added to the training context with NaN labels, + allowing the model to account for in-flight evaluations during + Bayesian optimization. + """ + + def __init__( + self, + train_X: Tensor, + train_Y: Tensor, + pending_X: Tensor | None = None, + **kwargs, + ) -> None: + """Initialize with optional pending points. 
+ + Args: + train_X: Training features (n, d) + train_Y: Training targets (n, 1) + pending_X: Optional pending point locations (n', d) + **kwargs: Additional arguments for PFNModel + """ + super().__init__(train_X=train_X, train_Y=train_Y, **kwargs) + self._pending_X = pending_X + + @property + def pending_X(self) -> Tensor | None: + """Current pending points.""" + return self._pending_X + + @pending_X.setter + def pending_X(self, value: Tensor | None) -> None: + """Set pending points. Cannot modify while cache is active.""" + if self._training_cache_ready: + raise RuntimeError( + "Cannot modify pending_X while cache_training_context is active. " + "Exit the context manager first." + ) + self._pending_X = value + + def _augment_training_data( + self, + train_X: Tensor | None, + train_Y: Tensor | None, + X: Tensor, + *, + negate_train_ys: bool, + use_cache: bool, + ) -> tuple[Tensor | None, Tensor | None]: + """Add pending points to training data. + + Args: + train_X: Training features (b, n, d) or None if using cache + train_Y: Training targets (b, n, 1) or None if using cache + X: Test points (b, q, d) + negate_train_ys: Whether Y values are negated + use_cache: Whether using cached training context + + Returns: + Augmented (train_X, train_Y) or (None, None) if using cache + """ + if self._pending_X is None: + return train_X, train_Y + + if use_cache: + # When using cache, pending points were included at cache creation + return None, None + + # Non-cached path: augment training data with pending points + assert train_X is not None and train_Y is not None + + pending_X = self._pending_X[None].expand(X.shape[0], -1, -1) # (b, n', d) + train_X = torch.cat([train_X, pending_X], dim=1) # (b, n+n', d) + train_Y = torch.cat([ + train_Y, + torch.full( + (train_Y.shape[0], pending_X.shape[1], 1), + torch.nan, + device=train_Y.device, + dtype=train_Y.dtype, + ), + ], dim=1) # (b, n+n', 1) + + return train_X, train_Y + + +class MultivariatePFNModel(PFNModel): + """A 
multivariate PFN model that returns a joint posterior over q batch inputs. + + For this to work correctly it is necessary that the underlying model return a + posterior for the latent f, not the noisy observed y. + """ + + def _build_posterior( + self, + X: Tensor, + probabilities: Tensor, + orig_X_shape: torch.Size, + train_X: Tensor | None, + train_Y: Tensor | None, + styles: dict[str, Tensor], + use_cache: bool, + ) -> Union[BoundedRiemannPosterior, MultivariateRiemannPosterior]: + """Build multivariate posterior with correlation estimation. + + Args: + X: Prepared test points (b, q, d) + probabilities: Predicted probabilities (b?, q?, num_buckets) + orig_X_shape: Original shape of test points + train_X: Training features or None if using cache + train_Y: Training targets or None if using cache + styles: Style tensors for forward pass + use_cache: Whether using cached training context + + Returns: + MultivariateRiemannPosterior if q > 1, else BoundedRiemannPosterior + """ + marginals = BoundedRiemannPosterior( + borders=self.borders, + probabilities=probabilities, + ) + + # If no q dimension or q=1, return marginals + if len(orig_X_shape) == 1 or orig_X_shape[-2] == 1: + return marginals + + # For correlation estimation, we need full training data + if train_X is None or train_Y is None: + # Re-fetch training data without cache for correlation estimation + _, train_X, train_Y, _, _ = self._prepare_data( + X.view(*orig_X_shape), + skip_train_replication=False + ) + # Apply any augmentation (e.g., pending points from subclasses) + train_X, train_Y = self._augment_training_data( + train_X, train_Y, X, + negate_train_ys=self._cached_negate_train_ys if use_cache else False, + use_cache=False, # Force non-cached for correlation + ) + + # Estimate correlation structure with additional forward pass + R = self.estimate_correlations( + X=X, + train_X=train_X, + train_Y=train_Y, + styles=styles, + marginals=marginals, + ) # (b, q, q) + R = R.view(*orig_X_shape[:-2], 
X.shape[-2], X.shape[-2]) # (b?, q, q) + + return MultivariateRiemannPosterior( + borders=self.borders, + probabilities=marginals.probabilities, + correlation_matrix=R, + ) + + def estimate_correlations( + self, + X: Tensor, + train_X: Tensor, + train_Y: Tensor, + styles: dict[str, Tensor], + marginals: BoundedRiemannPosterior, + ) -> Tensor: + """ + Estimate a correlation matrix R across the q batch of points in X. + Will do a forward pass through the PFN model with batch size O(q^2). + + For every x_q in [x_1, ..., x_Q]: + 1. Add x_q to train_X, with y_q the 90th percentile value for f(x_q) + 2. Evaluate p(f(x_i)) for all points. + + Uses bivariate normal conditioning formulae, and so will be approximate. + + Args: + X: evaluation point, shape (b, q, d) + train_X: Training X, shape (b, n, d) + train_Y: Training Y, shape (b, n, 1) + styles: dict from name to tensor shaped (b, ns) for any styles. + marginals: A posterior object with marginal posteriors for f(X), but no + correlation structure yet added. posterior.probabilities has + shape (b?, q, num_buckets). 
+ + Returns: A (b, q, q) correlation matrix + """ + # Compute conditional distributions with a forward pass + cond_mean, cond_val = self._compute_conditional_means( + X=X, + train_X=train_X, + train_Y=train_Y, + styles=styles, + marginals=marginals, + ) + # Get marginal moments + var = marginals.variance.squeeze(-1) # (b?, q) + mean = marginals.mean.squeeze(-1) # (b?, q) + if len(var.shape) == 1: + var = var.unsqueeze(0) # (b, q) + mean = mean.unsqueeze(0) # (b, q) + # Estimate covariances from conditional distributions + cov = self._estimate_covariances( + cond_mean=cond_mean, + cond_val=cond_val, + mean=mean, + var=var, + ) + # Convert to correlation matrix + S = 1 / torch.sqrt(torch.diagonal(cov, dim1=-2, dim2=-1)) # (b, q) + S = S.unsqueeze(-1).expand(cov.shape) # (b, q, q) + R = S * cov * S.transpose(-1, -2) # (b, q, q) + return R + + def _compute_conditional_means( + self, + X: Tensor, + train_X: Tensor, + train_Y: Tensor, + styles: dict[str, Tensor], + marginals: BoundedRiemannPosterior, + ) -> tuple[Tensor, Tensor]: + """ + Compute conditional means between pairs of points in X. + + Conditioning is done with an additional forward pass through the model. The + returned conditional mean will be of shape (b, q, q), with entry [b, i, j] the + conditional mean of j given i set to the conditioning value. + + Args: + X: evaluation point, shape (b, q, d) + train_X: Training X, shape (b, n, d) + train_Y: Training Y, shape (b, n, 1) + styles: dict from name to tensor shaped (b, ns) for any styles. + marginals: A posterior object with marginal posteriors for f(X), but no + correlation structure yet added. posterior.probabilities has + shape (b?, q, num_buckets). + + Returns: conditional means (b, q, q), and values used for conditioning (b, q). + """ + b, q, d = X.shape + n = train_X.shape[-2] + post_shape = marginals.probabilities.shape[:-1] + # Find the 90th percentile of each eval point. 
+ cond_val = marginals.icdf( + torch.full(post_shape, 0.9, device=X.device, dtype=X.dtype).unsqueeze(0) + ) # (1, b?, q, 1) + cond_val = cond_val.view(b, q) # (b, q) + # Construct conditional training data. + # train_X will have shape (b, q, n+1, d), to have a conditional observation + # for each point. train_Y will have shape (b, q, n+1, 1). + train_X = train_X.unsqueeze(1).expand(b, q, n, d) + cond_X = X.unsqueeze(-2) # (b, q, 1, d) + train_X = torch.cat((train_X, cond_X), dim=-2) # (b, q, n+1, d) + train_Y = train_Y.unsqueeze(1).expand(b, q, n, 1) + cond_Y = cond_val.unsqueeze(-1).unsqueeze(-1) # (b, q, 1, 1) + train_Y = torch.cat((train_Y, cond_Y), dim=-2) # (b, q, n+1, 1) + cond_styles = {} + for name, style in styles.items(): + ns = style.shape[-1] + cond_styles[name] = style.unsqueeze(-2).expand(b, q, ns).reshape(b * q, ns) + # Construct eval points + eval_X = X.unsqueeze(1).expand(b, q, q, d) + # Squeeze everything into necessary 2 batch dims, and do PFN forward pass + # Temporarily disable caching for correlation estimation since we use + # a different training set (with conditioning point) than the cached one + cache_was_enabled = self.pfn.cache_trainset_representation + try: + self.pfn.cache_trainset_representation = False + cond_probabilities = self.pfn_predict( + X=eval_X.reshape(b * q, q, d), + train_X=train_X.reshape(b * q, n + 1, d), + train_Y=train_Y.reshape(b * q, n + 1, 1), + **cond_styles, + ) # (b * q, q, num_buckets) + finally: + self.pfn.cache_trainset_representation = cache_was_enabled + # Restore y_encoder state if it was saved during cache population. + # The pfn_predict call above fitted the encoder on augmented data (n+1 points), + # which corrupts the state that was set when the cache was populated. 
+ if ( + hasattr(self, "_cached_y_encoder_state") + and self._cached_y_encoder_state is not None + and hasattr(self.pfn, "y_encoder") + and hasattr(self.pfn.y_encoder, "restore_fitted_state") + ): + self.pfn.y_encoder.restore_fitted_state(self._cached_y_encoder_state) + # Object for conditional posteriors + cond_posterior = BoundedRiemannPosterior( + borders=self.borders, + probabilities=cond_probabilities, + ) + # Get conditional means + cond_mean = cond_posterior.mean.squeeze(-1) # (b * q, q) + cond_mean = cond_mean.unsqueeze(0).view(b, q, q) + return cond_mean, cond_val + + def _estimate_covariances( + self, + cond_mean: Tensor, + cond_val: Tensor, + mean: Tensor, + var: Tensor, + ) -> Tensor: + """ + Estimate covariances from conditional distributions. + + Part one: Compute noise variance implied by conditional distributions + E[f_j | y_j=y] = E[f_j] + var[f_j]/(var[f_j] + noise_var) * (y - E[f_j]) + Let Z_jj = (E[f_j | y_j=y] - E[f_j]) / (y - E[f_j]). + Note that Z is in (0, 1]. + Then, noise_var_j = var[f_j](1/Z_jj - 1). + + Part two: Compute covariances for all pairs + E[f_j|y_i=y] = E[f_j]+cov[f_j, f_i]/(var[f_i] + noise_var_i) * (y - E[f_i]) + Let Z_ij = (E[f_j | y_i=y] - E[f_j]) / (y - E[f_i]). + Then, cov[f_j, f_i] = Z * (var[f_i] + noise_var) + + Args: + cond_mean: (b, q, q) means of dim -1 conditioned on dim -2 + cond_val: (b, q) conditioned y value. + var: (b, q) marginal variances + mean: (b, q) marginal means + + Returns: Covariance matrix + """ + Z = (cond_mean - mean.unsqueeze(-2).expand(cond_mean.shape)) / ( + cond_val - mean + ).unsqueeze(-1) # (b, q, q) + # Z[i, j] is for j cond. on i + noise_var = torch.clamp( + var * (1 / torch.diagonal(Z, dim1=-2, dim2=-1) - 1), min=1e-8 + ) # (b, q) + cov = Z * (var + noise_var).unsqueeze(-1) # (b, q, q) + # Symmetrize + cov = 0.5 * (cov + cov.transpose(-1, -2)) + cov = self._map_psd(cov) + return cov + + def _map_psd(self, A): + """ + Map A (assumed symmetric) to the nearest PSD matrix. 
"""Tests for PFNModel caching functionality.

These tests verify the refactored caching implementation including:
- Cache state initialization
- Cache capability validation
- Context manager behavior
- Cached prediction correctness
- Gradient flow through cached predictions
- Negation handling
- Subclass behavior (PFNModelWithPendingPoints, MultivariatePFNModel)
"""

import pytest
import torch
from unittest.mock import MagicMock

from botorch_community.models.prior_fitted_network import (
    PFNModel,
    PFNModelWithPendingPoints,
    MultivariatePFNModel,
)
from pfns.model.bar_distribution import BarDistribution
from pfns.model.transformer import TableTransformer


@pytest.fixture
def device():
    return "cuda" if torch.cuda.is_available() else "cpu"


@pytest.fixture
def dtype():
    return torch.float32


@pytest.fixture
def simple_transformer(device, dtype):
    """Small TableTransformer with a BarDistribution head for testing."""
    torch.manual_seed(42)
    num_buckets = 10
    borders = torch.linspace(-3, 3, num_buckets + 1)

    model = TableTransformer(
        ninp=32,
        nhead=2,
        nhid=64,
        nlayers=2,
        batch_first=True,
        cache_trainset_representation=False,
        decoder_dict={"standard": (None, num_buckets)},
    )
    model.criterion = BarDistribution(borders=borders)
    model = model.to(device=device, dtype=dtype)

    # Perturb weights slightly so outputs are non-degenerate and gradients flow.
    with torch.no_grad():
        for p in model.parameters():
            p.add_(0.01 * torch.randn_like(p))

    return model.eval()


@pytest.fixture
def train_data(device, dtype):
    """Deterministic (train_X, train_Y) pair with standardized targets."""
    torch.manual_seed(42)
    n_train, n_dims = 20, 5
    train_X = torch.rand(n_train, n_dims, device=device, dtype=dtype)
    train_Y = torch.sin(train_X.sum(dim=-1, keepdim=True))
    train_Y = (train_Y - train_Y.mean()) / (train_Y.std() + 1e-8)
    return train_X, train_Y


def make_model(transformer, train_data, model_cls=PFNModel, **kwargs):
    """Build a model with the standard test configuration.

    Centralizes the constructor call that was previously copy-pasted into
    every test. Extra kwargs (e.g. ``pending_X``) are forwarded.
    """
    train_X, train_Y = train_data
    return model_cls(
        train_X=train_X,
        train_Y=train_Y,
        model=transformer,
        batch_first=True,
        **kwargs,
    )


def make_uncachable_mock(device, dtype, *, missing_attr):
    """Mock transformer that lacks one attribute required for caching."""
    mock_model = MagicMock()
    mock_model.to = MagicMock(return_value=mock_model)
    if missing_attr == "cache_trainset_representation":
        del mock_model.cache_trainset_representation  # ensure missing
    else:
        mock_model.cache_trainset_representation = False
        del mock_model.empty_trainset_representation_cache
    mock_model.criterion = MagicMock()
    mock_model.criterion.borders = torch.linspace(
        -3, 3, 11, device=device, dtype=dtype
    )
    return mock_model


class TestPFNModelCacheInitialization:
    """Cache state must be explicitly initialized in __init__."""

    def test_cache_state_initialized_in_init(self, simple_transformer, train_data):
        model = make_model(simple_transformer, train_data)

        assert hasattr(model, "_training_cache_ready")
        assert model._training_cache_ready is False
        assert hasattr(model, "_cached_negate_train_ys")
        assert model._cached_negate_train_ys is False
        assert hasattr(model, "_cached_context_hash")
        assert model._cached_context_hash is None


class TestPFNModelCacheCapabilityCheck:
    """Validation of the transformer's caching capability."""

    def test_raises_on_missing_cache_attribute(self, train_data, device, dtype):
        """Should raise if the transformer cannot cache at all."""
        mock_model = make_uncachable_mock(
            device, dtype, missing_attr="cache_trainset_representation"
        )
        model = make_model(mock_model, train_data)

        with pytest.raises(TypeError, match="does not support caching"):
            with model.cache_training_context():
                pass

    def test_raises_on_missing_cache_clear_method(self, train_data, device, dtype):
        """Should raise if the transformer lacks the cache-clearing method."""
        mock_model = make_uncachable_mock(
            device, dtype, missing_attr="empty_trainset_representation_cache"
        )
        model = make_model(mock_model, train_data)

        with pytest.raises(TypeError, match="does not support caching"):
            with model.cache_training_context():
                pass


class TestPFNModelCacheContextManager:
    """Behavior of cache_training_context."""

    def test_cache_enabled_within_context(self, simple_transformer, train_data, device):
        model = make_model(simple_transformer, train_data)

        assert not model._training_cache_ready
        with model.cache_training_context():
            assert model._training_cache_ready
            assert model.pfn.cache_trainset_representation
        assert not model._training_cache_ready

    def test_cache_cleaned_on_exit(self, simple_transformer, train_data, device):
        model = make_model(simple_transformer, train_data)

        with model.cache_training_context():
            pass

        assert not model._training_cache_ready
        assert model._cached_context_hash is None

    def test_cache_cleaned_on_exception(self, simple_transformer, train_data, device):
        """Cache must be cleared even when the body raises."""
        model = make_model(simple_transformer, train_data)

        with pytest.raises(RuntimeError):
            with model.cache_training_context():
                assert model._training_cache_ready
                raise RuntimeError("Test exception")

        assert not model._training_cache_ready


class TestPFNModelCachedPredictions:
    """Cached predictions must match non-cached ones exactly."""

    def test_predictions_match(self, simple_transformer, train_data, device, dtype):
        model = make_model(simple_transformer, train_data)

        torch.manual_seed(123)
        test_X = torch.rand(5, train_data[0].shape[-1], device=device, dtype=dtype)

        with torch.no_grad():
            posterior_std = model.posterior(test_X)
            mean_std = posterior_std.mean.clone()
            var_std = posterior_std.variance.clone()

        with model.cache_training_context():
            with torch.no_grad():
                posterior_cached = model.posterior(test_X)
                mean_cached = posterior_cached.mean.clone()
                var_cached = posterior_cached.variance.clone()

        torch.testing.assert_close(mean_std, mean_cached, atol=1e-7, rtol=1e-7)
        torch.testing.assert_close(var_std, var_cached, atol=1e-7, rtol=1e-7)

    def test_predictions_match_batched(self, simple_transformer, train_data, device, dtype):
        model = make_model(simple_transformer, train_data)

        torch.manual_seed(456)
        test_X = torch.rand(
            8, 3, train_data[0].shape[-1], device=device, dtype=dtype
        )  # (b, q, d)

        with torch.no_grad():
            mean_std = model.posterior(test_X).mean.clone()

        with model.cache_training_context():
            with torch.no_grad():
                mean_cached = model.posterior(test_X).mean.clone()

        torch.testing.assert_close(mean_std, mean_cached, atol=1e-7, rtol=1e-7)

    def test_multiple_calls_in_context(self, simple_transformer, train_data, device, dtype):
        """Multiple posterior calls within one cache context stay consistent."""
        model = make_model(simple_transformer, train_data)

        torch.manual_seed(789)
        test_X1 = torch.rand(3, train_data[0].shape[-1], device=device, dtype=dtype)
        test_X2 = torch.rand(5, train_data[0].shape[-1], device=device, dtype=dtype)

        with torch.no_grad():
            mean_std1 = model.posterior(test_X1).mean.clone()
            mean_std2 = model.posterior(test_X2).mean.clone()

        with model.cache_training_context():
            with torch.no_grad():
                mean_cached1 = model.posterior(test_X1).mean.clone()
                mean_cached2 = model.posterior(test_X2).mean.clone()

        torch.testing.assert_close(mean_std1, mean_cached1, atol=1e-7, rtol=1e-7)
        torch.testing.assert_close(mean_std2, mean_cached2, atol=1e-7, rtol=1e-7)


class TestPFNModelCacheGradients:
    """Gradient flow through cached predictions."""

    def test_gradients_flow_through_test_x(self, simple_transformer, train_data, device, dtype):
        model = make_model(simple_transformer, train_data)

        torch.manual_seed(111)
        test_X = torch.rand(
            5, train_data[0].shape[-1], device=device, dtype=dtype, requires_grad=True
        )

        with model.cache_training_context():
            posterior = model.posterior(test_X)
            posterior.mean.sum().backward()

        assert test_X.grad is not None
        assert not torch.all(test_X.grad == 0)

    def test_gradients_match_non_cached(self, simple_transformer, train_data, device, dtype):
        model = make_model(simple_transformer, train_data)

        torch.manual_seed(222)
        base_test_X = torch.rand(5, train_data[0].shape[-1], device=device, dtype=dtype)

        # Non-cached gradient.
        test_X_std = base_test_X.clone().requires_grad_(True)
        model.posterior(test_X_std).mean.sum().backward()
        grad_std = test_X_std.grad.clone()

        # Cached gradient.
        test_X_cached = base_test_X.clone().requires_grad_(True)
        with model.cache_training_context():
            model.posterior(test_X_cached).mean.sum().backward()
        grad_cached = test_X_cached.grad.clone()

        torch.testing.assert_close(grad_std, grad_cached, atol=1e-7, rtol=1e-7)

    def test_gradient_shape_consistency(self, simple_transformer, train_data, device, dtype):
        """Gradients keep the input shape for various batch layouts."""
        model = make_model(simple_transformer, train_data)
        d = train_data[0].shape[-1]

        for shape in [(5, d), (4, 1, d), (3, 2, d)]:
            test_X = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)

            with model.cache_training_context():
                model.posterior(test_X).mean.sum().backward()

            assert test_X.grad is not None, f"No gradient for shape {shape}"
            assert test_X.grad.shape == shape, (
                f"Wrong gradient shape for input {shape}: got {test_X.grad.shape}"
            )


class TestPFNModelCacheNegation:
    """Cache behavior with negate_train_ys."""

    def test_negate_mismatch_raises(self, simple_transformer, train_data, device, dtype):
        model = make_model(simple_transformer, train_data)
        test_X = torch.rand(5, train_data[0].shape[-1], device=device, dtype=dtype)

        with model.cache_training_context(negate_train_ys=False):
            with pytest.raises(ValueError, match="negate_train_ys"):
                model.posterior(test_X, negate_train_ys=True)

    def test_negation_cached_correctly(self, simple_transformer, train_data, device, dtype):
        """Negated predictions match between cached and non-cached paths."""
        model = make_model(simple_transformer, train_data)

        torch.manual_seed(333)
        test_X = torch.rand(5, train_data[0].shape[-1], device=device, dtype=dtype)

        with torch.no_grad():
            mean_std = model.posterior(test_X, negate_train_ys=True).mean.clone()

        with model.cache_training_context(negate_train_ys=True):
            with torch.no_grad():
                mean_cached = model.posterior(test_X, negate_train_ys=True).mean.clone()

        torch.testing.assert_close(mean_std, mean_cached, atol=1e-7, rtol=1e-7)


class TestPFNModelWithPendingPointsCaching:
    """Caching with pending points."""

    def test_cache_includes_pending_points(self, simple_transformer, train_data, device, dtype):
        pending_X = torch.rand(3, train_data[0].shape[-1], device=device, dtype=dtype)
        model = make_model(
            simple_transformer,
            train_data,
            model_cls=PFNModelWithPendingPoints,
            pending_X=pending_X,
        )

        torch.manual_seed(444)
        test_X = torch.rand(5, train_data[0].shape[-1], device=device, dtype=dtype)

        with torch.no_grad():
            mean_std = model.posterior(test_X).mean.clone()

        with model.cache_training_context():
            with torch.no_grad():
                mean_cached = model.posterior(test_X).mean.clone()

        torch.testing.assert_close(mean_std, mean_cached, atol=1e-7, rtol=1e-7)

    def test_pending_modification_blocked_during_cache(
        self, simple_transformer, train_data, device, dtype
    ):
        pending_X = torch.rand(3, train_data[0].shape[-1], device=device, dtype=dtype)
        model = make_model(
            simple_transformer,
            train_data,
            model_cls=PFNModelWithPendingPoints,
            pending_X=pending_X,
        )

        with model.cache_training_context():
            with pytest.raises(RuntimeError, match="Cannot modify pending_X"):
                model.pending_X = torch.rand(
                    2, train_data[0].shape[-1], device=device, dtype=dtype
                )

    def test_pending_points_property(self, simple_transformer, train_data, device, dtype):
        """pending_X getter and setter outside a cache context."""
        pending_X = torch.rand(3, train_data[0].shape[-1], device=device, dtype=dtype)
        model = make_model(
            simple_transformer,
            train_data,
            model_cls=PFNModelWithPendingPoints,
            pending_X=pending_X,
        )

        assert model.pending_X is not None
        assert model.pending_X.shape == pending_X.shape

        new_pending = torch.rand(5, train_data[0].shape[-1], device=device, dtype=dtype)
        model.pending_X = new_pending
        assert model.pending_X.shape == new_pending.shape

        model.pending_X = None
        assert model.pending_X is None


class TestMultivariatePFNModelCaching:
    """Caching with multivariate posteriors."""

    def test_marginals_use_cache(self, simple_transformer, train_data, device, dtype):
        """Single-point posterior (marginal only) matches under caching."""
        model = make_model(simple_transformer, train_data, model_cls=MultivariatePFNModel)

        torch.manual_seed(555)
        test_X = torch.rand(1, train_data[0].shape[-1], device=device, dtype=dtype)

        with torch.no_grad():
            mean_std = model.posterior(test_X).mean.clone()

        with model.cache_training_context():
            with torch.no_grad():
                mean_cached = model.posterior(test_X).mean.clone()

        torch.testing.assert_close(mean_std, mean_cached, atol=1e-7, rtol=1e-7)

    def test_multivariate_with_cache(self, simple_transformer, train_data, device, dtype):
        """Multivariate posterior works under caching; marginals match."""
        model = make_model(simple_transformer, train_data, model_cls=MultivariatePFNModel)

        torch.manual_seed(666)
        test_X = torch.rand(3, train_data[0].shape[-1], device=device, dtype=dtype)

        with torch.no_grad():
            mean_std = model.posterior(test_X).mean.clone()

        with model.cache_training_context():
            with torch.no_grad():
                posterior_cached = model.posterior(test_X)
                mean_cached = posterior_cached.mean.clone()

        torch.testing.assert_close(mean_std, mean_cached, atol=1e-7, rtol=1e-7)

        # Correlation structure should have been attached.
        assert hasattr(posterior_cached, "correlation_matrix")


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
BarDistribution(nn.Module): - def __init__(self, borders: torch.Tensor, *, ignore_nan_targets: bool = True): + def __init__( + self, + borders: torch.Tensor, + *, + ignore_nan_targets: bool = True, + probabilities: torch.Tensor | None = None, + ): """Loss for a distribution over bars. The bars are defined by the borders. The loss is the negative log density of the distribution. The density is defined as a softmax over the logits, where the softmax is scaled by the width of the @@ -59,6 +70,7 @@ def __init__(self, borders: torch.Tensor, *, ignore_nan_targets: bool = True): assert len(borders.shape) == 1 borders = borders.contiguous() self.register_buffer("borders", borders) + self.probabilities = probabilities full_width = self.bucket_widths.sum() assert (1 - (full_width / (self.borders[-1] - self.borders[0]))).abs() < 1e-2, ( @@ -223,14 +235,13 @@ def forward( """Returns the negative log density (the _loss_). Args: - logits: The logits of the model. - y: The ys to compute the loss for. + logits: The logits of the model, shape (*batch_shape, num_bars) + y: The ys to compute the loss for, shape (*batch_shape,) Returns: The negative log density. 
""" # gives the negative log density (the _loss_), - # y: T x B, logits: T x B x self.num_bars y = y.clone().view(*logits.shape[:-1]) # no trailing one dimension ignore_loss_mask = self.ignore_init(y) target_sample = self.map_to_bucket_idx(y) @@ -244,7 +255,6 @@ def forward( scaled_bucket_log_probs = self.compute_scaled_log_probs(logits) - # T x B nll_loss = -scaled_bucket_log_probs.gather( -1, target_sample[..., None], @@ -266,27 +276,30 @@ def mean(self, logits: torch.Tensor) -> torch.Tensor: return p @ bucket_means def median(self, logits: torch.Tensor) -> torch.Tensor: - return self.icdf(logits, 0.5) + return self.icdf(logits=logits, left_prob=0.5) - def icdf(self, logits: torch.Tensor, left_prob: float) -> torch.Tensor: + def icdf(self, left_prob: float, logits: torch.Tensor | None) -> torch.Tensor: """Implementation of the quantile function :param logits: Tensor of any shape, with the last dimension being logits :param left_prob: float: The probability mass to the left of the result. :return: Position with `left_prob` probability weight to the left. """ - probs = logits.softmax(-1) + if logits is None and self.probabilities is not None: + probs = self.probabilities + else: + probs = logits.softmax(-1) cumprobs = torch.cumsum(probs, -1) idx = ( torch.searchsorted( cumprobs, - left_prob * torch.ones(*cumprobs.shape[:-1], 1, device=logits.device), + left_prob * torch.ones(*cumprobs.shape[:-1], 1, device=probs.device), ) .squeeze(-1) .clamp(0, cumprobs.shape[-1] - 1) ) # this might not do the right for outliers cumprobs = torch.cat( [ - torch.zeros(*cumprobs.shape[:-1], 1, device=logits.device), + torch.zeros(*cumprobs.shape[:-1], 1, device=probs.device), cumprobs, ], -1, @@ -300,6 +313,19 @@ def icdf(self, logits: torch.Tensor, left_prob: float) -> torch.Tensor: idx[..., None], ).squeeze(-1) + def sample(self, logits: torch.Tensor, t: float = 1.0) -> torch.Tensor: + """Samples values from the distribution. + + Temperature t. 
+ """ + p_cdf = torch.rand(*logits.shape[:-1]) + return torch.tensor( + [ + self.icdf(logits=logits[i, :] / t, left_prob=p) + for i, p in enumerate(p_cdf.tolist()) + ], + ) + def quantile( self, logits: torch.Tensor, @@ -308,8 +334,8 @@ def quantile( side_probs = (1.0 - center_prob) / 2 return torch.stack( ( - self.icdf(logits, side_probs), - self.icdf(logits, 1.0 - side_probs), + self.icdf(logits=logits, left_prob=side_probs), + self.icdf(logits=logits, left_prob=1.0 - side_probs), ), -1, ) @@ -344,7 +370,7 @@ def ucb( """ if maximize: rest_prob = 1 - rest_prob - return self.icdf(logits, rest_prob) + return self.icdf(logits=logits, left_prob=rest_prob) def mode(self, logits: torch.Tensor) -> torch.Tensor: density = logits.softmax(-1) / self.bucket_widths @@ -586,16 +612,6 @@ def pdf(self, logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Probability density function at y.""" return torch.exp(self.forward(logits, y)) - def sample(self, logits: torch.Tensor, t: float = 1.0) -> torch.Tensor: - """Samples values from the distribution. - - Temperature t. 
- """ - p_cdf = torch.rand(*logits.shape[:-1]) - return torch.tensor( - [self.icdf(logits[i, :] / t, p) for i, p in enumerate(p_cdf.tolist())], - ) - @override def mean(self, logits: torch.Tensor) -> torch.Tensor: bucket_means = self.borders[:-1] + self.bucket_widths / 2 @@ -834,6 +850,8 @@ def get_bucket_borders( if ys is not None: ys = ys.flatten() ys = ys[~torch.isnan(ys)] + ys = ys.sort()[0] + ys = torch.unique_consecutive(ys) assert ( len(ys) > num_outputs ), f"Number of ys :{len(ys)} must be larger than num_outputs: {num_outputs}" @@ -869,9 +887,9 @@ def get_bucket_borders( 0, ) - assert len(borders) - 1 == num_outputs, ( - f"len(borders) - 1 == {len(borders) - 1}" f" != {num_outputs} == num_outputs" - ) + assert ( + len(borders) - 1 == num_outputs + ), f"len(borders) - 1 == {len(borders) - 1} != {num_outputs} == num_outputs" if not widen_borders_factor or widen_borders_factor == 1.0: assert ( @@ -881,9 +899,4 @@ def get_bucket_borders( full_range[-1] == borders[-1] # type: ignore ), f"{full_range[-1]} != {borders[-1]}" # type: ignore - unique_borders = torch.unique_consecutive(borders) - - if (unique_borders != borders).any(): - print("Borders were not unique, removed duplicates.") - - return unique_borders + return borders diff --git a/pfns/model/encoders.py b/pfns/model/encoders.py index 171955b..28bb5c7 100644 --- a/pfns/model/encoders.py +++ b/pfns/model/encoders.py @@ -3,12 +3,10 @@ from __future__ import annotations from dataclasses import dataclass - from typing import Any import numpy as np import torch - from pfns import base_config from pfns.model import encoders from pfns.priors.hyperparameter_sampling import ( @@ -121,6 +119,26 @@ def create_encoder(self, features, emsize): ### Style Encoders +class ConstantStyleNormalization(nn.Module): + def __init__(self, mean, std): + super().__init__() + self.mean = mean + self.std = std + + def forward(self, x): + return (x - self.mean) / self.std + + +class NanHandlingStyleEncoder(nn.Module): + def 
__init__(self): + super().__init__() + + def forward(self, x): + nan_indicators = torch.isnan(x).float() * 2.0 - 1 + x = torch.nan_to_num(x, nan=0.0) + return torch.cat([x, nan_indicators], dim=-1) + + def linear_style_encoder(num_styles, emsize): return nn.Linear(num_styles, emsize) @@ -132,6 +150,16 @@ class StyleEncoderConfig(base_config.BaseConfig): dict[str, base_config.BaseTypes | DistributionConfig] | None ) = None encoder_type: str = "linear" + constant_normalization_mean: float = 0.0 + constant_normalization_std: float = 1.0 + nan_handling: bool | None = None + + def __post_init__(self): + if self.normalize_to_hyperparameters is not None: + assert self.constant_normalization_mean == 0.0 + assert self.constant_normalization_std == 1.0 + assert self.nan_handling is None + assert self.num_styles is None def create_encoder(self, emsize): num_features = self.num_styles @@ -145,9 +173,21 @@ def create_encoder(self, emsize): hpn = HyperparameterNormalizer(self.normalize_to_hyperparameters) num_features = hpn.num_hps * 2 modules.append(hpn) + else: + assert self.num_styles is not None + normalizer = ConstantStyleNormalization( + mean=self.constant_normalization_mean, + std=self.constant_normalization_std, + ) + modules.append(normalizer) + + if self.nan_handling: + modules.append(NanHandlingStyleEncoder()) + num_features *= 2 if self.encoder_type == "linear": modules.append(encoders.linear_style_encoder(num_features, emsize)) + modules.append(nn.Flatten()) return nn.Sequential(*modules) else: raise ValueError( @@ -206,6 +246,63 @@ def forward( return input[self.output_key] if self.output_key is not None else input + def freeze_fitted_state(self) -> None: + """Freeze all encoder steps to prevent _fit() from being called. + + When frozen, encoder steps will reuse their existing fitted state + instead of refitting on new data. This is used during cached inference + to prevent correlation estimation from corrupting the encoder state. 
+ """ + for module in self: + if hasattr(module, "_fitted_state_frozen"): + module._fitted_state_frozen = True + + def unfreeze_fitted_state(self) -> None: + """Unfreeze all encoder steps to allow _fit() to be called. + + This restores normal behavior where encoder steps fit on training data. + """ + for module in self: + if hasattr(module, "_fitted_state_frozen"): + module._fitted_state_frozen = False + + def save_fitted_state(self) -> dict: + """Save the fitted state of all encoder steps. + + Returns a dictionary mapping module indices to their fitted state. + This can be restored later with `restore_fitted_state`. + """ + saved_state = {} + for idx, module in enumerate(self): + module_state = {} + # Save any attribute that starts with underscore and ends with underscore + # (fitted state convention: _attribute_) + for attr_name in dir(module): + if ( + attr_name.endswith("_") + and not attr_name.startswith("__") + and not callable(getattr(module, attr_name, None)) + ): + attr_value = getattr(module, attr_name, None) + if isinstance(attr_value, torch.Tensor): + module_state[attr_name] = attr_value.clone() + elif attr_value is not None: + module_state[attr_name] = attr_value + if module_state: + saved_state[idx] = module_state + return saved_state + + def restore_fitted_state(self, saved_state: dict) -> None: + """Restore the fitted state of all encoder steps. + + Args: + saved_state: Dictionary returned by `save_fitted_state`. + """ + for idx, module_state in saved_state.items(): + module = self[idx] + for attr_name, attr_value in module_state.items(): + setattr(module, attr_name, attr_value) + class SeqEncStep(nn.Module): """Abstract base class for sequential encoder steps. @@ -235,6 +332,10 @@ def __init__( super().__init__() self.in_keys = in_keys self.out_keys = out_keys + # When frozen, _fit() is skipped and the existing fitted state is used. + # This is used during cached inference to prevent correlation estimation + # from corrupting the encoder state. 
+ self._fitted_state_frozen = False # Either implement _forward: @@ -307,7 +408,14 @@ def forward( """ args = [state[in_key] for in_key in self.in_keys] if hasattr(self, "_fit"): - if kwargs["single_eval_pos"] or not cache_trainset_representation: + # Skip _fit if: + # 1. State is frozen (during cached inference to prevent corruption) + # 2. OR cache is enabled and we're in inference mode (no training data) + should_fit = ( + not self._fitted_state_frozen + and (kwargs["single_eval_pos"] or not cache_trainset_representation) + ) + if should_fit: self._fit(*args, **kwargs) out = self._transform(*args, **kwargs) else: @@ -444,11 +552,21 @@ def _fit(self, x: torch.Tensor, single_eval_pos: int, **kwargs: Any) -> None: """Compute the feature means on the training set for replacing NaNs. Args: - x: The input tensor. + x: The input tensor of shape (seq_len, batch_size, num_features). single_eval_pos: The position to use for single evaluation. **kwargs: Additional keyword arguments (unused). """ - self.feature_means_ = torch_nanmean(x[:single_eval_pos], axis=0) + train_data = x[:single_eval_pos] # (train_len, batch, num_features) + # Compute mean across sequence dimension (axis=0) → (batch, num_features) + seq_mean = torch_nanmean(train_data, axis=0) + # Average across batch dimension to get batch-independent means → (num_features,) + # This enables caching to work with different batch sizes between + # cache population and cache usage + if seq_mean.dim() > 1: + self.feature_means_ = torch_nanmean(seq_mean, axis=0) + else: + # Already 1D (single batch element case) + self.feature_means_ = seq_mean def _transform( self, @@ -478,7 +596,9 @@ def _transform( nan_mask = torch.logical_or(torch.isnan(x), torch.isinf(x)) # replace nans with the mean of the corresponding feature x = x.clone() # clone to avoid inplace operations - x[nan_mask] = self.feature_means_.unsqueeze(0).expand_as(x)[nan_mask] + # feature_means_ is (num_features,), reshape to broadcast with x (seq, batch, 
features) + means_expanded = self.feature_means_.view(1, 1, -1).expand_as(x) + x[nan_mask] = means_expanded[nan_mask] return x, nans_indicator @@ -516,12 +636,18 @@ def _fit(self, x: torch.Tensor, **kwargs: Any) -> None: """Compute the number of used features on the training set. Args: - x: The input tensor. + x: The input tensor of shape (seq_len, batch_size, num_features). **kwargs: Additional keyword arguments (unused). """ + # sel: (batch, num_features) - True where feature varies across sequence sel = (x[1:] == x[0]).sum(0) != (x.shape[0] - 1) + # per_batch_used: (batch,) - count of used features per batch element + per_batch_used = sel.sum(-1) + # Average across batch to get batch-independent count → scalar + # This enables caching to work with different batch sizes + avg_used = per_batch_used.float().mean() self.number_of_used_features_ = torch.clip( - sel.sum(-1).unsqueeze(-1), + avg_used.unsqueeze(-1), # (1,) for broadcasting min=1, ).cpu() @@ -544,13 +670,14 @@ def _transform(self, x: torch.Tensor, **kwargs: Any) -> tuple[torch.Tensor]: dtype=x.dtype, ) if self.normalize_by_used_features: + # number_of_used_features_ is a scalar (1,) that broadcasts to any batch + n_used = self.number_of_used_features_.to(x.device) + if self.normalize_by_sqrt: # Verified that this gives indeed unit variance with appended zeros - x = x * torch.sqrt( - self.num_features / self.number_of_used_features_.to(x.device), - ) + x = x * torch.sqrt(self.num_features / n_used) else: - x = x * (self.num_features / self.number_of_used_features_.to(x.device)) + x = x * (self.num_features / n_used) zeros_appended = torch.zeros( *x.shape[:-1], diff --git a/pfns/model/layer.py b/pfns/model/layer.py index 3ba1e8d..f2496ff 100644 --- a/pfns/model/layer.py +++ b/pfns/model/layer.py @@ -50,6 +50,10 @@ def __init__( # noqa: PLR0913 d_k: int | None = None, d_v: int | None = None, precomputed_kv: None | torch.Tensor | tuple[torch.Tensor, torch.Tensor] = None, + attention_across_items_first: bool 
= False, + positions_num_measures: int = 0, + positions_base: float = 0.02, + dont_look_at_yourself: bool = False, ) -> None: """ Args: @@ -132,7 +136,11 @@ def __init__( # noqa: PLR0913 precomputed_v=precomputed_v, precomputed_kv=precomputed_kv, init_gain=attention_init_gain, + positions_base=positions_base, + positions_num_measures=positions_num_measures, + dont_look_at_yourself=dont_look_at_yourself, ) + self.positions_num_measures = positions_num_measures if dim_feedforward is None: dim_feedforward = 2 * d_model @@ -179,6 +187,7 @@ def __init__( # noqa: PLR0913 self.multiquery_item_attention_for_test_set = ( multiquery_item_attention_for_test_set ) + self.attention_across_items_first = attention_across_items_first def __setstate__(self, state: dict[str, Any]) -> None: state.setdefault("save_peak_mem_factor", None) @@ -191,6 +200,9 @@ def forward( # noqa: C901 *, cache_trainset_representation: bool = False, att_src: Tensor | None = None, + rope_vals: torch.Tensor | None = None, + positions: torch.Tensor + | None = None, # shape: [batch, seqlen_q, num_feature_blocks, 1] ) -> Tensor: """Pass the input through the encoder layer. 
@@ -254,9 +266,15 @@ def attn_between_features(x: torch.Tensor) -> torch.Tensor: ) def attn_between_items(x: torch.Tensor) -> torch.Tensor: + transposed_rope_vals = ( + rope_vals.transpose(1, 2) if rope_vals is not None else None + ) # we need to transpose as self attention always treats # dim -2 as the sequence dimension if self.multiquery_item_attention_for_test_set: + assert rope_vals is None + assert self.positions_num_measures == 0 + if single_eval_pos < x.shape[1]: new_x_test = self.self_attn_between_items( x[:, single_eval_pos:].transpose(1, 2), @@ -295,10 +313,15 @@ def attn_between_items(x: torch.Tensor) -> torch.Tensor: ) attention_src_x = None + positions_kv = None if att_src is not None: attention_src_x = att_src.transpose(1, 2) elif single_eval_pos: attention_src_x = x[:, :single_eval_pos].transpose(1, 2) + if positions is not None: + positions_kv = positions[:, :single_eval_pos].transpose(1, 2) + else: + assert positions is None return self.self_attn_between_items( x.transpose(1, 2), @@ -308,6 +331,9 @@ def attn_between_items(x: torch.Tensor) -> torch.Tensor: add_input=True, allow_inplace=True, use_cached_kv=cache_trainset_representation and not single_eval_pos, + rope_vals=transposed_rope_vals, + positions=positions.transpose(1, 2) if positions is not None else None, + positions_kv=positions_kv, ).transpose(1, 2) # the mlp tends to require 8 times more memory at its peak, that is why we use 8 here @@ -325,8 +351,12 @@ def attn_between_items(x: torch.Tensor) -> torch.Tensor: " blocks must be 1." 
) + if self.attention_across_items_first: + sublayers.insert(0, attn_between_items) + else: + sublayers.append(attn_between_items) + sublayers += [ - attn_between_items, partial( self.mlp.__call__, save_peak_mem_factor=( diff --git a/pfns/model/multi_head_attention.py b/pfns/model/multi_head_attention.py index 8eaee3b..36e15f6 100644 --- a/pfns/model/multi_head_attention.py +++ b/pfns/model/multi_head_attention.py @@ -21,6 +21,26 @@ HAVE_FLASH_ATTN = False +def apply_rope(x, rope_vals): + # x has shape [b,s,h,d] + # rope_vals has shape [b,s,d] + b, s, h, dim = x.shape + assert rope_vals.shape == (b, s, dim), f"{rope_vals.shape=} != {(b, s, dim)}" + assert (dim // 2) * 2 == dim, f"{dim} is not even" + + x = x.reshape(b, s, h, -1, 2) + rope_vals = rope_vals.reshape(b, s, 1, -1, 2) + + out = torch.stack( + [ + x[..., 0] * rope_vals[..., 0] - x[..., 1] * rope_vals[..., 1], + x[..., 1] * rope_vals[..., 0] + x[..., 0] * rope_vals[..., 1], + ], + -1, + ).view(b, s, h, dim) + return out + + class MultiHeadAttention(torch.nn.Module): """ An implementation of multi-head attention, heavily relying on the pytorch @@ -186,6 +206,9 @@ def __init__( # noqa: PLR0913 precomputed_kv: torch.Tensor | None = None, recompute: bool = False, init_gain: float = 1.0, + positions_base: float = 0.02, + positions_num_measures: int = 0, + dont_look_at_yourself: bool = False, ): super().__init__() assert nhead % share_kv_across_n_heads == 0 @@ -202,6 +225,14 @@ def __init__( # noqa: PLR0913 self.recompute = recompute self.init_gain = init_gain + self.positions_base = positions_base + assert (positions_num_measures + 1) <= self._d_k == self._d_v, ( + f"positions_num_measures {positions_num_measures} must be less than " + f"or equal to d_k {self._d_k} and d_v {self._d_v}." 
+ ) + self.positions_num_measures = positions_num_measures + self.dont_look_at_yourself = dont_look_at_yourself + w_out = torch.nn.Parameter( torch.empty(nhead, d_v, output_size, device=device, dtype=dtype), ) @@ -287,6 +318,9 @@ def forward( reuse_first_head_kv: bool = False, only_cache_first_head_kv: bool = False, use_cached_kv: bool = False, + rope_vals: torch.Tensor | None = None, # shape: [..., batch, seq_len, head dim] + positions: torch.Tensor | None = None, # shape: [..., batch, seqlen_q, 1] + positions_kv: torch.Tensor | None = None, # shape: [..., batch, seqlen_kv, 1] ): """X is the current hidden and has a shape of [batch, ..., seq_len, input_size]. If keys and values are present in the cache and 'freeze_kv' is not set, they @@ -299,10 +333,32 @@ def forward( assert not ( cache_kv and use_cached_kv ), "Cannot cache and use cached keys and values at the same time." - assert not x.requires_grad or ( - not self.has_cached_kv and not cache_kv - ), "Saving keys and values is only supported during inference." + assert ( + not x.requires_grad or not cache_kv + ), "Saving keys and values will stop gradients to flow into trainset." 
+ assert not (use_cached_kv and rope_vals is not None), ( + "Cannot use rope_vals with cached KV - ROPE is already applied to cached values" + ) + + # Flatten the batch dimensions x, x_kv, x_shape_after_transpose = self._rearrange_inputs_to_flat_batch(x, x_kv) + if rope_vals is not None: + rope_vals = rope_vals.reshape(-1, *rope_vals.shape[-2:]) + + # flatten positions' batch dims, just like x and x_kv are flattened + positions_flat = None + positions_kv_flat = None + if positions is not None: + assert ( + positions.shape[-1] == 1 + ), f"positions must have trailing dim 1, got {positions.shape}" + # Squeeze the trailing dim and flatten batch dims + positions_flat = positions.reshape(-1, positions.shape[-2]) + if positions_kv is not None: + assert ( + positions_kv.shape[-1] == 1 + ), f"positions_kv must have trailing dim 1, got {positions_kv.shape}" + positions_kv_flat = positions_kv.reshape(-1, positions_kv.shape[-2]) nhead_kv = 1 if reuse_first_head_kv else self._nhead_kv @@ -315,7 +371,6 @@ def forward( else: batch_size, seqlen_kv = x.shape[:2] - # TODO: handling of device and dtype. if self._w_kv is not None or self._w_qkv is not None: self._kv_cache = torch.empty( batch_size, @@ -356,6 +411,9 @@ def forward( allow_inplace=allow_inplace, save_peak_mem_factor=save_peak_mem_factor, reuse_first_head_kv=reuse_first_head_kv, + rope_vals=rope_vals, + positions=positions_flat, + positions_kv=positions_kv_flat, ) return output.reshape(x_shape_after_transpose[:-1] + output.shape[-1:]) @@ -397,6 +455,46 @@ def compute_qkv( # noqa: PLR0912, C901 v = v_cache kv = kv_cache + # Handle batch dimension broadcast when cached KV has different batch + # size than the query. This is common when caching training context + # with batch_size=1 and evaluating multiple test batches. + # The batch dimension includes feature groups, so we need to repeat + # (not just expand) when batch sizes differ. 
+ query_batch_size = x.shape[0] + if k is not None and k.shape[0] != query_batch_size: + if k.shape[0] == 1: + k = k.expand(query_batch_size, *k.shape[1:]) + elif query_batch_size % k.shape[0] == 0: + n_repeats = query_batch_size // k.shape[0] + k = k.repeat(n_repeats, *([1] * (k.dim() - 1))) + else: + raise RuntimeError( + f"Cannot broadcast cached k with shape {k.shape} to query " + f"batch size {query_batch_size}." + ) + if v is not None and v.shape[0] != query_batch_size: + if v.shape[0] == 1: + v = v.expand(query_batch_size, *v.shape[1:]) + elif query_batch_size % v.shape[0] == 0: + n_repeats = query_batch_size // v.shape[0] + v = v.repeat(n_repeats, *([1] * (v.dim() - 1))) + else: + raise RuntimeError( + f"Cannot broadcast cached v with shape {v.shape} to query " + f"batch size {query_batch_size}." + ) + if kv is not None and kv.shape[0] != query_batch_size: + if kv.shape[0] == 1: + kv = kv.expand(query_batch_size, *kv.shape[1:]) + elif query_batch_size % kv.shape[0] == 0: + n_repeats = query_batch_size // kv.shape[0] + kv = kv.repeat(n_repeats, *([1] * (kv.dim() - 1))) + else: + raise RuntimeError( + f"Cannot broadcast cached kv with shape {kv.shape} to query " + f"batch size {query_batch_size}." + ) + assert (k is None) == (v is None) if self._w_qkv is None: @@ -469,6 +567,9 @@ def _compute( cache_kv: bool, use_cached_kv: bool, reuse_first_head_kv: bool, + rope_vals: torch.Tensor | None = None, # shape: [batch, seq_len, head dim] + positions: torch.Tensor | None = None, # shape: [batch, seqlen_q] + positions_kv: torch.Tensor | None = None, # shape: [batch, seqlen_kv] ) -> torch.Tensor: """Attention computation. Called by 'forward', potentially on shards, once shapes have been normalized. 
@@ -483,7 +584,21 @@ def _compute( use_cached_kv=use_cached_kv, reuse_first_head_kv=reuse_first_head_kv, ) - attention_head_outputs = MultiHeadAttention.compute_attention_heads( + + if rope_vals is not None: + if q is not None: + q = apply_rope(q, rope_vals) + if k is not None: + k = apply_rope(k, rope_vals[:, : k.shape[1]]) + if kv is not None: + kv[..., 0, :, :] = apply_rope( + kv[..., 0, :, :], rope_vals[:, : kv.shape[1]] + ) + if qkv is not None: + qkv[..., 0, :, :] = apply_rope(qkv[..., 0, :, :], rope_vals) + qkv[..., 1, :, :] = apply_rope(qkv[..., 1, :, :], rope_vals) + + attention_head_outputs = self.compute_attention_heads( q, k, v, @@ -491,6 +606,11 @@ def _compute( qkv, self.dropout_p, self.softmax_scale, + positions=positions, + positions_kv=positions_kv, + positions_base=self.positions_base, + positions_num_measures=self.positions_num_measures, + dont_look_at_yourself=self.dont_look_at_yourself, ) return torch.einsum( "... h d, h d s -> ... s", @@ -524,7 +644,7 @@ def broadcast_kv_across_heads( share_kv_across_n_heads, -1, ) - return kv.reshape(*kv.shape[:-3], nhead * share_kv_across_n_heads, d) + return kv.contiguous().reshape(*kv.shape[:-3], nhead * share_kv_across_n_heads, d) @staticmethod def compute_attention_heads( # noqa: C901, PLR0912 @@ -535,6 +655,12 @@ def compute_attention_heads( # noqa: C901, PLR0912 qkv: torch.Tensor | None, dropout_p: float | None = None, softmax_scale: float | None = None, + *, + positions: torch.Tensor | None = None, # shape: [batch, seqlen_q] + positions_kv: torch.Tensor | None = None, # shape: [batch, seqlen_kv] + positions_base: float = 0.02, + positions_num_measures: int = 8, + dont_look_at_yourself: bool = False, # only with positions is not None ) -> torch.Tensor: assert (k is None) == (v is None) assert sum([qkv is None, kv is None, k is None and v is None]) == 2 @@ -549,6 +675,14 @@ def compute_attention_heads( # noqa: C901, PLR0912 assert k is not None assert v is not None + if positions is not None: + assert ( 
+ positions_kv is not None + ), "positions is not None but positions_kv is None" + + if dont_look_at_yourself: + assert positions_num_measures > 0 + batch_size, seqlen_q, nhead, d_k = q.shape _, seqlen_kv, nhead_kv, d_v = v.shape share_kv_across_n_heads = nhead // nhead_kv @@ -565,6 +699,10 @@ def compute_attention_heads( # noqa: C901, PLR0912 TORCH_2_ATTENTION_POSSIBLE = ( torch.__version__ >= "2" and torch.cuda.is_available() ) + if positions_num_measures > 0: + TORCH_2_ATTENTION_POSSIBLE = False + use_flash_attention = False + USE_TORCH_2_GQA = False if TORCH_2_ATTENTION_POSSIBLE: # check whether torch.nn.functional.scaled_dot_product_attention has a @@ -618,7 +756,7 @@ def get_seqlen_cumsums( if qkv is not None: attention_head_outputs = flash_attn_unpadded_qkvpacked_func( # type: ignore - qkv.reshape(batch_size * seqlen_q, 3, nhead, d_k), + qkv.contiguous().reshape(batch_size * seqlen_q, 3, nhead, d_k), get_seqlen_cumsums(batch_size, seqlen_q, qkv.device), seqlen_q, dropout_p=dropout_p, @@ -633,8 +771,8 @@ def get_seqlen_cumsums( share_kv_across_n_heads, ) attention_head_outputs = flash_attn_unpadded_kvpacked_func( # type: ignore - q.reshape(batch_size * seqlen_q, nhead, d_k), - kv.reshape(batch_size * seqlen_kv, 2, nhead, d_k), + q.contiguous().reshape(batch_size * seqlen_q, nhead, d_k), + kv.contiguous().reshape(batch_size * seqlen_kv, 2, nhead, d_k), get_seqlen_cumsums(batch_size, seqlen_q, q.device), get_seqlen_cumsums(batch_size, seqlen_kv, kv.device), seqlen_q, @@ -662,9 +800,9 @@ def get_seqlen_cumsums( share_kv_across_n_heads, ) attention_head_outputs = flash_attn_unpadded_func( # type: ignore - q.reshape(batch_size * seqlen_q, nhead, d_k_), # type: ignore - k.reshape(batch_size * seqlen_kv, nhead, d_k_), # type: ignore - v.reshape(batch_size * seqlen_kv, nhead, d_v), + q.contiguous().reshape(batch_size * seqlen_q, nhead, d_k_), # type: ignore + k.contiguous().reshape(batch_size * seqlen_kv, nhead, d_k_), # type: ignore + 
v.contiguous().reshape(batch_size * seqlen_kv, nhead, d_v), get_seqlen_cumsums(batch_size, seqlen_q, q.device), get_seqlen_cumsums(batch_size, seqlen_kv, k.device), seqlen_q, @@ -704,9 +842,58 @@ def get_seqlen_cumsums( ) attention_head_outputs = attention_head_outputs.transpose(1, 2) else: - k = MultiHeadAttention.broadcast_kv_across_heads(k, share_kv_across_n_heads) - v = MultiHeadAttention.broadcast_kv_across_heads(v, share_kv_across_n_heads) - logits = torch.einsum("b q h d, b k h d -> b q k h", q, k) + k = MultiHeadAttention.broadcast_kv_across_heads( + k, share_kv_across_n_heads + ) # [b,k,h,d] + v = MultiHeadAttention.broadcast_kv_across_heads( + v, share_kv_across_n_heads + ) # [b,k,h,d] + # Prototype: compute pairwise distance metrics without modifying attention + if positions_num_measures > 0: + assert ( + positions.shape[0] == batch_size and positions.shape[1] == seqlen_q + ), f"positions (q) shape {positions.shape} incompatible with {(batch_size, seqlen_q)}" + assert ( + positions_kv.shape[0] == batch_size + and positions_kv.shape[1] == seqlen_kv + ), f"positions_kv (kv) shape {positions_kv.shape} incompatible with {(batch_size, seqlen_kv)}" + pos_q = positions + pos_k = positions_kv + # Distance per (q,k) + dist = (pos_q[:, :, None] - pos_k[:, None, :]).abs() # [b,q,k] + scales = torch.tensor( + [2**i for i in range(positions_num_measures)], + device=positions.device, + dtype=positions.dtype, + ) + thresholds = (positions_base * scales)[None, None, None, :] + # Vector of within-threshold interpolations in [0,1] + g = 1.0 - 2.0 * (dist[..., None] / thresholds).clamp( + 0.0, 1.0 + ) # [b,q,k,T] + # Direction indicator + diff = pos_k[:, None, :] - pos_q[:, :, None] + is_right = 2 * (diff >= 0.0).to(q.dtype) - 1 # [b,q,k] + is_right[diff == 0.0] = 0.0 + g = torch.cat([g, is_right[..., None]], dim=-1) # [b,q,k,T+1] + t = g.shape[-1] + + logits = torch.einsum( + "b q h d, b k h d -> b q k h", q[..., :-t], k[..., :-t] + ) + position_available_mask = 
~positions.isnan().any(1) # [b] + logits[position_available_mask] += torch.einsum( + "b q h t, b q k t -> b q k h", + q[position_available_mask, ..., -t:], + g[position_available_mask], + ) + + if dont_look_at_yourself: + logits[:, torch.arange(seqlen_kv), torch.arange(seqlen_kv), :] = ( + float("-inf") + ) + else: + logits = torch.einsum("b q h d, b k h d -> b q k h", q, k) logits *= ( torch.sqrt(torch.tensor(1.0 / d_k)).to(k.device) if softmax_scale is None @@ -714,7 +901,21 @@ def get_seqlen_cumsums( ) ps = torch.softmax(logits, dim=2) ps = torch.dropout(ps, dropout_p, train=True) - attention_head_outputs = torch.einsum("b q k h, b k h d -> b q h d", ps, v) + if positions_num_measures > 0: + attention_head_outputs = torch.einsum( + "b q k h, b k h d -> b q h d", ps, v + ) + attention_head_outputs[position_available_mask, ..., -t:] = ( + torch.einsum( + "b q k h, b q k t -> b q h t", + ps[position_available_mask], + g[position_available_mask], + ) + ) + else: + attention_head_outputs = torch.einsum( + "b q k h, b k h d -> b q h d", ps, v + ) return attention_head_outputs.reshape( batch_size, diff --git a/pfns/model/transformer.py b/pfns/model/transformer.py index e667fd8..cd4e287 100644 --- a/pfns/model/transformer.py +++ b/pfns/model/transformer.py @@ -69,6 +69,10 @@ def __init__( # noqa: C901, D417, PLR0913 y_style_encoder: nn.Module | None = None, attention_between_features: bool = True, batch_first: bool = True, + use_rope: bool = False, + rope_multiplier: float = 1, + positions_num_measures: int = 0, + x_only_mode: bool = False, **layer_kwargs: Any, ): """Initializes the PerFeatureTransformer module. 
@@ -146,7 +150,7 @@ def __init__( # noqa: C901, D417, PLR0913 print("Using linear x encoder, as no encoder was provided.") encoder = get_linear_x_encoder(ninp, features_per_group) - if y_encoder is None: + if y_encoder is None and not x_only_mode: print("Using linear y encoder, as no y_encoder was provided.") y_encoder = get_linear_y_encoder(ninp) @@ -160,6 +164,9 @@ def __init__( # noqa: C901, D417, PLR0913 self.cached_embeddings: torch.Tensor | None = None self.attention_between_features = attention_between_features self.batch_first = batch_first + self.use_rope = use_rope + self.rope_multiplier = rope_multiplier + self.positions_num_measures = positions_num_measures def layer_creator(): return PerFeatureLayer( @@ -172,6 +179,7 @@ def layer_creator(): precomputed_kv.pop(0) if precomputed_kv is not None else None ), attention_between_features=attention_between_features, + positions_num_measures=positions_num_measures, **layer_kwargs, ) @@ -215,6 +223,16 @@ def layer_creator(): assert attention_between_features, "Attention between features must be True when using a y_style_encoder, otherwise only use a style_encoder." self.y_style_encoder = y_style_encoder + if x_only_mode: + assert ( + attention_between_features + ), "attention_between_features must be True when x_only_mode is True" + assert ( + features_per_group == 1 + ), "features_per_group must be 1 when x_only_mode is True" + + self.x_only_mode = x_only_mode + def forward( self, x: torch.Tensor | None, @@ -249,7 +267,6 @@ def forward( these are shared between the datasets within a batch. - `half_layers`: Whether to use the first half of the layers. 
""" - # Prepare batch-first versions of x, y, test_x for _forward # and clone all to be sure not to change outside data x_bf = x.clone() if x is not None else None @@ -271,11 +288,16 @@ def forward( # Determine single_eval_pos based on the original y shape if y_bf is not None: single_eval_pos = y_bf.shape[1] + elif self.x_only_mode: + single_eval_pos = x_bf.shape[1] else: single_eval_pos = None # Handle cache_trainset_representation and combining x, test_x if self.cache_trainset_representation and y is None: + assert ( + not self.x_only_mode + ), "x_only_mode is not supported when cache_trainset_representation is True" assert ( (test_x is None) != (x is None) ), "Provide the test inputs only via test_x or x, not both, when cache_trainset_representation is True" @@ -286,7 +308,7 @@ def forward( x_bf is not None ), "x must be provided when not predicting from cached trainset representations" assert ( - y is not None + y is not None or self.x_only_mode ), "y must be provided when not predicting from cached trainset representations" if test_x_bf is not None: @@ -335,13 +357,20 @@ def _forward( # noqa: PLR0912, C901 y is None ), "_forward expects y=None if single_eval_pos is 0/None and caching" else: - assert ( - y is not None - ), "_forward expects y if not caching for pure inference or during training" + if not self.x_only_mode: + assert ( + y is not None + ), "_forward expects y if not caching for pure inference or during training" + assert ( single_eval_pos is not None ), "_forward expects single_eval_pos if not caching for pure inference or during training" + if self.use_rope or (self.positions_num_measures > 0): + assert ( + not self.x_only_mode + ), "Rope/Positional Embs only supported for x_only_mode=False" + # single_eval_pos is the length of the training sequence part. # If None (e.g. pure inference from cache), treat as 0. 
current_context_len = single_eval_pos or 0 @@ -355,7 +384,7 @@ def _forward( # noqa: PLR0912, C901 _batch_size, _seq_len, _num_features_orig_main = x["main"].shape if ( - y is None + y is None and not self.x_only_mode ): # Should only happen if self.cache_trainset_representation and not single_eval_pos y_main_ref = x["main"] y = { @@ -451,60 +480,96 @@ def _forward( # noqa: PLR0912, C901 categorical_inds_to_use = new_categorical_inds - for k in y: - # y[k] is (batch_size, current_seq_len_y, num_targets_y) - if y[k].ndim == 2: # (B,S) or (B,T) - y[k] = y[k].unsqueeze(-1) # B S -> B S 1 + rope_vals = None + positions = None + if self.x_only_mode: + embedded_y = None + else: + for k in y: + # y[k] is (batch_size, current_seq_len_y, num_targets_y) + if y[k].ndim == 2: # (B,S) or (B,T) + y[k] = y[k].unsqueeze(-1) # B S -> B S 1 + + # Pad y sequence length if shorter than x's sequence length (_seq_len) + if y[k].shape[1] < _seq_len: # _seq_len is full sequence length from x + # current_context_len is the length of the training part of y + assert ( + y[k].shape[1] + == current_context_len # y should only contain train part if shorter + or y[k].shape[1] + == _seq_len # Should not happen if already shorter + ), f"y[{k}] seq len {y[k].shape[1]} vs train_seq_len {current_context_len} vs x_seq_len {_seq_len}" + + # Only pad if y is for training part or not main y (auxiliary targets might be full length) + if k != "main" or y[k].shape[1] == current_context_len: + y[k] = torch.cat( + ( + y[k], + torch.nan + * torch.zeros( + y[k].shape[0], # batch_size + _seq_len - y[k].shape[1], # seq_len difference + y[k].shape[2], # num_targets_y + device=y[k].device, + dtype=y[k].dtype, + ), + ), + dim=1, # Pad along sequence dimension (dim 1 for batch-first) + ) + # Now y[k] is (batch_size, _seq_len, num_targets_y) + + # Making sure no label leakage ever happens for y["main"] (batch-first indexing) + # current_context_len is the length of the training data part + if "main" in y and 
y["main"].shape[1] > current_context_len: + y["main"][:, current_context_len:] = torch.nan + + # Prepare y for y_encoder (transpose to sequence-first if y_encoder expects it) + y_for_y_encoder = {} + for k_enc, v_enc in y.items(): + y_for_y_encoder[k_enc] = v_enc.transpose(0, 1) # B S T -> S B T + + embedded_y = self.y_encoder( + y_for_y_encoder, + single_eval_pos=current_context_len, # Length of training part for y_encoder + cache_trainset_representation=self.cache_trainset_representation, + ).transpose(0, 1) - # Pad y sequence length if shorter than x's sequence length (_seq_len) - if y[k].shape[1] < _seq_len: # _seq_len is full sequence length from x - # current_context_len is the length of the training part of y + if self.use_rope or self.positions_num_measures > 0: assert ( - y[k].shape[1] - == current_context_len # y should only contain train part if shorter - or y[k].shape[1] == _seq_len # Should not happen if already shorter - ), f"y[{k}] seq len {y[k].shape[1]} vs train_seq_len {current_context_len} vs x_seq_len {_seq_len}" - - # Only pad if y is for training part or not main y (auxiliary targets might be full length) - if k != "main" or y[k].shape[1] == current_context_len: - y[k] = torch.cat( - ( - y[k], - torch.nan - * torch.zeros( - y[k].shape[0], # batch_size - _seq_len - y[k].shape[1], # seq_len difference - y[k].shape[2], # num_targets_y - device=y[k].device, - dtype=y[k].dtype, - ), - ), - dim=1, # Pad along sequence dimension (dim 1 for batch-first) + self.attention_between_features + ), "Rope only supported for attention_between_features=True" + assert ( + self.features_per_group == 1 + ), "Rope only supported for features_per_group=1" + assert not ( + self.use_rope and (self.positions_num_measures > 0) + ), "Rope and positions_num_measures > 0 not supported at the same time" + if self.use_rope: + head_dim = self.ninp // self.nhead + rope_vals_x = get_rope_vals( + x["main"].flatten(), head_dim, multiplier=self.rope_multiplier + 
).view(_batch_size, _seq_len, num_groups_main, head_dim) + rope_vals_y = torch.ones_like(rope_vals_x[:, :, :1, :]).view( + _batch_size, _seq_len, 1, -1, 2 ) - # Now y[k] is (batch_size, _seq_len, num_targets_y) - - # Making sure no label leakage ever happens for y["main"] (batch-first indexing) - # current_context_len is the length of the training data part - if "main" in y and y["main"].shape[1] > current_context_len: - y["main"][:, current_context_len:] = torch.nan - - # Prepare y for y_encoder (transpose to sequence-first if y_encoder expects it) - y_for_y_encoder = {} - for k_enc, v_enc in y.items(): - y_for_y_encoder[k_enc] = v_enc.transpose(0, 1) # B S T -> S B T - - embedded_y = self.y_encoder( - y_for_y_encoder, - single_eval_pos=current_context_len, # Length of training part for y_encoder - cache_trainset_representation=self.cache_trainset_representation, - ).transpose(0, 1) - - del y, y_for_y_encoder - if torch.isnan(embedded_y).any(): - raise ValueError( - f"{torch.isnan(embedded_y).any()=}, make sure to add nan handlers" - " to the ys that are not fully provided (test set missing)", - ) + rope_vals_y[..., 1] = 0.0 + rope_vals_y = rope_vals_y.view(_batch_size, _seq_len, 1, head_dim) + rope_vals = torch.cat((rope_vals_x, rope_vals_y), dim=2) + else: + assert ( + "main" in y and len(y) == 1 + ), "Positions in attention only supported for simple y" + # [batch, seqlen_q, num_feature_blocks, 1] + positions = x["main"] # [_batch_size, _seq_len, num_groups_main, 1] + # add positions for y [_batch_size, _seq_len, 1] + positions = torch.cat((positions, y["main"].unsqueeze(2)), dim=2) + + del y, y_for_y_encoder + if torch.isnan(embedded_y).any(): + raise ValueError( + f"{torch.isnan(embedded_y).any()=}, make sure to add nan handlers" + " to the ys that are not fully provided (test set missing)", + ) extra_encoders_args = {} if categorical_inds_to_use is not None and isinstance( @@ -524,13 +589,13 @@ def _forward( # noqa: PLR0912, C901 **extra_encoders_args, ), "s 
(b f) e -> b s f e", - b=embedded_y.shape[0], + b=_batch_size, ) # b s f 1 -> b s f e del x embedded_x, embedded_y = self.add_embeddings( embedded_x, # (b s num_groups e) - embedded_y, # (b s e) + embedded_y, # (b s e) | None num_features=_num_features_orig_main, seq_len=_seq_len, cache_embeddings=( @@ -542,9 +607,16 @@ def _forward( # noqa: PLR0912, C901 ) if self.attention_between_features: - # b s f e + b s 1 e -> b s f+1 e - embedded_input = torch.cat((embedded_x, embedded_y.unsqueeze(2)), dim=2) + if self.x_only_mode: + embedded_input = embedded_x + else: + # b s f e + b s 1 e -> b s f+1 e + embedded_input = torch.cat((embedded_x, embedded_y.unsqueeze(2)), dim=2) + else: + assert ( + not self.x_only_mode + ), "x_only_mode is not supported when attention_between_features is False" # add them together in this case, like for the original PFNs assert ( embedded_x.shape[2] == 1 @@ -592,8 +664,17 @@ def _forward( # noqa: PLR0912, C901 device=embedded_input.device, dtype=embedded_input.dtype, ) + else: + assert ( + not self.x_only_mode + ), "x_only_mode is not supported when embedded_y_style is not None" - full_embedded_style = torch.cat((embedded_style, embedded_y_style), dim=2) + if self.x_only_mode: + full_embedded_style = embedded_style + else: + full_embedded_style = torch.cat( + (embedded_style, embedded_y_style), dim=2 + ) embedded_input = torch.cat( (full_embedded_style, embedded_input), @@ -606,12 +687,14 @@ def _forward( # noqa: PLR0912, C901 f"There should be no NaNs in the encoded x and y." "Check that you do not feed NaNs or use a NaN-handling enocder." 
"Your embedded x and y returned the following:" - f"{torch.isnan(embedded_x).any()=} | {torch.isnan(embedded_y).any()=}", + f"{torch.isnan(embedded_x).any()=} | {(embedded_y is not None and torch.isnan(embedded_y).any())=}", ) del embedded_y, embedded_x encoder_out = self.transformer_layers( embedded_input, # (b s_effective (num_groups+1_for_y) e) + rope_vals=rope_vals, + positions=positions, single_eval_pos=current_context_len, # Pass the context length including style half_layers=half_layers, cache_trainset_representation=self.cache_trainset_representation, @@ -625,11 +708,19 @@ def _forward( # noqa: PLR0912, C901 # for the test sequence part (after current_context_len). test_encoder_out = encoder_out[ - :, current_context_len:, -1 - ] # (batch, seq_test, embed_dim) + :, current_context_len:, : + ] # (batch, seq_test, num_groups[+1_for_y], embed_dim) train_encoder_out = encoder_out[ - :, :current_context_len, -1 - ] # (batch, seq_train_and_style, embed_dim) + :, :current_context_len, : + ] # (batch, seq_train_and_style, num_groups[+1_for_y], embed_dim) + + if not self.x_only_mode: + test_encoder_out = test_encoder_out[ + :, :, -1, : + ] # (batch, seq_test, embed_dim) + train_encoder_out = train_encoder_out[ + :, :, -1, : + ] # (batch, seq_train_and_style, embed_dim) # No transposition needed here as _forward returns batch-first @@ -647,7 +738,7 @@ def _forward( # noqa: PLR0912, C901 def add_embeddings( # noqa: C901, PLR0912 self, x: torch.Tensor, # (b s num_groups e) - y: torch.Tensor, # (b s e) + y: torch.Tensor | None, # (b s e) *, num_features: int, # Original number of features (before grouping) seq_len: int, # Sequence length @@ -826,3 +917,29 @@ def isolate_torch_rng(seed: int, device: torch.device) -> Generator[None, None, torch.set_rng_state(torch_rng_state) if torch.cuda.is_available(): torch.cuda.set_rng_state(torch_cuda_rng_state, device=device) + + +def get_rope_vals( + inputs: torch.Tensor, dim: int, base: int = 10_000, multiplier: float = 1.0 +): + 
# inputs has to have shape [b] + + assert (dim // 2) * 2 == dim, f"{dim=} not divisible by 2" + + theta = ( + multiplier + * 1000.0 + / ( + base + ** ( + torch.arange(0, dim, 2, device=inputs.device)[: (dim // 2)].float() + / dim + ) + ) + ) + + deg = torch.einsum("b,d->bd", inputs, theta) + + rope_vals = torch.stack([torch.cos(deg), torch.sin(deg)], dim=-1) # b d/2 2 + + return rope_vals.view(-1, dim) diff --git a/pfns/model/transformer_config.py b/pfns/model/transformer_config.py index e3c2610..fbabdf4 100644 --- a/pfns/model/transformer_config.py +++ b/pfns/model/transformer_config.py @@ -1,5 +1,6 @@ import typing as tp from dataclasses import dataclass +from typing import Literal from pfns import base_config from pfns.model import encoders, transformer @@ -29,6 +30,13 @@ class TransformerConfig(base_config.BaseConfig): features_per_group: int = 1 attention_between_features: bool = True model_extra_args: tp.Dict[str, base_config.BaseTypes] | None = None + multiquery_item_attention_for_test_set: bool = False + activation: Literal["gelu", "relu"] = "relu" + recompute_layer: bool = False + use_rope: bool = False + rope_multiplier: float = 1 + positions_num_measures: int = 0 + x_only_mode: bool = False def create_model(self) -> transformer.TableTransformer: # Resolve criterion @@ -81,6 +89,13 @@ def create_model(self) -> transformer.TableTransformer: style_encoder=style_encoder, y_style_encoder=y_style_encoder, batch_first=True, # model is batch_first by default now + multiquery_item_attention_for_test_set=self.multiquery_item_attention_for_test_set, + activation=self.activation, + recompute_layer=self.recompute_layer, + use_rope=self.use_rope, + rope_multiplier=self.rope_multiplier, + positions_num_measures=self.positions_num_measures, + x_only_mode=self.x_only_mode, **(self.model_extra_args or {}), ) model.criterion = criterion diff --git a/pfns/priors/condition_on_area_of_opt_continuous.py b/pfns/priors/condition_on_area_of_opt_continuous.py index 
638cc1b..fd56e71 100644 --- a/pfns/priors/condition_on_area_of_opt_continuous.py +++ b/pfns/priors/condition_on_area_of_opt_continuous.py @@ -1,7 +1,5 @@ import torch -from ..utils import default_device - from .prior import Batch @@ -11,8 +9,7 @@ def get_batch( seq_len, num_features, get_batch, - epoch, - device=default_device, + device="cpu", hyperparameters=None, **kwargs, ): @@ -31,7 +28,6 @@ def get_batch( :param seq_len: :param num_features: :param get_batch: - :param epoch: :param device: :param hyperparameters: :param kwargs: @@ -42,65 +38,69 @@ def get_batch( hyperparameters = {} maximize = hyperparameters.get("condition_on_area_maximize", True) - size_range = hyperparameters.get("condition_on_area_size_range", (0.1, 0.5)) + size_range = hyperparameters.get("condition_on_area_size_range", (0.2, 0.99)) distribution = hyperparameters.get("condition_on_area_distribution", "uniform") assert distribution in ["uniform"] + extra_samples = hyperparameters.get("condition_on_area_extra_samples", 0) batch: Batch = get_batch( batch_size=batch_size, - seq_len=seq_len, + seq_len=seq_len + extra_samples, num_features=num_features, device=device, hyperparameters=hyperparameters, - epoch=epoch, **kwargs, ) assert batch.style is None d = batch.x.shape[2] - prob_correct = torch.rand(batch_size, d, device=device) - correct_opt = torch.rand(batch_size, d, device=device) < prob_correct division_size = ( torch.rand(batch_size, d, device=device) * (size_range[1] - size_range[0]) + size_range[0] ) + division_start = torch.rand(batch_size, d, device=device) * (1 - division_size) + + assert batch.target_y.shape[2] == 1, "Only support single objective." 
- optima = ( - batch.target_y.argmax(0).squeeze() + optima_inds = ( + batch.target_y.argmax(1).squeeze(-1) if maximize - else batch.target_y.argmin(0).squeeze() + else batch.target_y.argmin(1).squeeze(-1) ) # batch_size, d - optima_hints = ( - batch.x[optima, torch.arange(batch_size, device=device)] - - division_size / 2 - + torch.rand(batch_size, d, device=device) * division_size - ) # shape: (batch_size, d) - optima_hints = optima_hints.clamp(0, 1) - - optima_division_lower_bound = (optima_hints - division_size / 2).clamp(0, 1) - optima_division_upper_bound = (optima_hints + division_size / 2).clamp(0, 1) + optima = batch.x[torch.arange(batch_size), optima_inds] # batch_size, d - random_hints = ( - torch.rand(batch_size, d, device=device) - - division_size / 2 - + torch.rand(batch_size, d, device=device) * division_size - ) # shape: (batch_size, d) - random_hints = random_hints.clamp(0, 1) + is_inside = (division_start <= optima) & ( + optima <= division_start + division_size + ) # batch_size, d - random_division_lower_bound = (random_hints - division_size / 2).clamp(0, 1) - random_division_upper_bound = (random_hints + division_size / 2).clamp(0, 1) + # hint_probs = torch.rand(batch_size, d, device=device) # probs are chosen randomly + # hint probs need to be drawn dependent on whether it is inside or not + # what we want is R ~ Uniform(0, 1), and we now sample p(R=r|Ber(R)=1) and p(R=r|Ber(R)=0) + # that is: p(R=r|Ber(R)=1) = r / 0.5 = 2.0 * r, and p(R=r|Ber(R)=0) = (1-r) / 0.5 = 2.0 * (1-r) + # we can compute the icdfs as icdf(|Ber(R)=1) = np.sqrt(u), icdf(|Ber(R)=0)= 1 - np.sqrt(1 - u) - lower_bounds = torch.where( - correct_opt, optima_division_lower_bound, random_division_lower_bound - ) - upper_bounds = torch.where( - correct_opt, optima_division_upper_bound, random_division_upper_bound + hint_probs = torch.where( + is_inside, + torch.sqrt(torch.rand(batch_size, d, device=device)), + 1 - torch.sqrt(1 - torch.rand(batch_size, d, device=device)), ) - 
batch.style = torch.stack([prob_correct, lower_bounds, upper_bounds], 2).view( - batch_size, -1 - ) # shape: (batch_size, 3*d) + batch.style = torch.stack( + [hint_probs, division_start, division_start + division_size], 2 + ) # batch_size, d, 3 + + skip_style_prob = hyperparameters.get("condition_on_opt_area_skip_style_prob", 0.0) + + skip_style_mask = torch.rand(batch_size, device=device) < skip_style_prob + + # set to nan for the encoder to figure this out + batch.style[skip_style_mask, :, :] = torch.nan + + if extra_samples: + batch.x = batch.x[:, :-extra_samples] + batch.y = batch.y[:, :-extra_samples] + batch.target_y = batch.target_y[:, :-extra_samples] return batch diff --git a/pfns/priors/convert_prior_to_x_only_format.py b/pfns/priors/convert_prior_to_x_only_format.py new file mode 100644 index 0000000..6476070 --- /dev/null +++ b/pfns/priors/convert_prior_to_x_only_format.py @@ -0,0 +1,120 @@ +# pyre-strict + +from dataclasses import fields + +import torch +from pfns.priors.prior import Batch + + +def get_batch( + batch_size: int, + seq_len: int, + num_features: int, + single_eval_pos: int, + get_batch: callable, + hyperparameters: dict | None = None, + n_targets_per_input: int = 1, + **kwargs, +) -> Batch: + """ + Wrapper function that converts traditional batch format to x-only format. + + This function takes a traditional get_batch function and converts its output + from the format with separate x, y, target_y to the x-only format with + x, test_x, target (and y=None, target_y=None). 
+ + Args: + batch_size: Number of sequences in the batch + seq_len: Total sequence length (train + test) + num_features: Number of input features + single_eval_pos: Position where training ends and testing begins + hyperparameters: Hyperparameter dictionary + get_batch: The traditional get_batch function to wrap + n_targets_per_input: Number of targets per input + **kwargs: Additional arguments to pass to the wrapped get_batch function + + Returns: + Batch in x-only format with x, test_x, target fields + """ + assert n_targets_per_input == 1, "Only single target per input supported" + # Call the traditional get_batch function + traditional_batch = get_batch( + batch_size=batch_size, + seq_len=seq_len, + num_features=num_features, + single_eval_pos=single_eval_pos, + hyperparameters=hyperparameters, + n_targets_per_input=n_targets_per_input, + **kwargs, + ) + + # Extract traditional format components + x_traditional = traditional_batch.x # shape: (batch_size, seq_len, num_features) + y_traditional = ( + traditional_batch.y + ) # shape: (batch_size, seq_len, 1) or (batch_size, seq_len,) + if len(y_traditional.shape) == 2: + y_traditional = y_traditional.unsqueeze(-1) + target_y_traditional = ( + traditional_batch.target_y + ) # shape: (batch_size, seq_len, n_targets_per_input) or (batch_size, seq_len,) + if len(target_y_traditional.shape) == 2: + target_y_traditional = target_y_traditional.unsqueeze(-1) + + # Split into train and test portions + x_train = x_traditional[ + :, :single_eval_pos, : + ] # shape: (batch_size, single_eval_pos, num_features) + x_test = x_traditional[ + :, single_eval_pos:, : + ] # shape: (batch_size, seq_len - single_eval_pos, num_features) + y_train = y_traditional[ + :, :single_eval_pos, : + ] # shape: (batch_size, single_eval_pos, 1) + y_test_targets = target_y_traditional[ + :, single_eval_pos:, : + ] # shape: (batch_size, seq_len - single_eval_pos, n_targets_per_input) + + # Convert to x-only format + # x: concatenate training inputs 
with training outputs + x_with_y = torch.cat( + [x_train, y_train], dim=2 + ) # shape: (batch_size, single_eval_pos, num_features + 1) + + # test_x: test inputs with NaN for y values (to be predicted) + test_y_nan = torch.full( + (batch_size, seq_len - single_eval_pos, 1), + torch.nan, + dtype=x_test.dtype, + device=x_test.device, + ) + test_x_with_nan_y = torch.cat( + [x_test, test_y_nan], dim=2 + ) # shape: (batch_size, seq_len - single_eval_pos, num_features + 1) + + # target: test inputs with target y values + target_with_y = torch.cat( + [torch.full_like(x_test, torch.nan), y_test_targets], dim=2 + ) # shape: (batch_size, seq_len - single_eval_pos, num_features + 1) + + # Create the x-only format batch, taking over all entries from the original batch + # except for the ones we need to change + batch_dict = { + field.name: getattr(traditional_batch, field.name) + for field in fields(traditional_batch) + } + + # Override the fields that need to change for x-only format + batch_dict.update( + { + "x": x_with_y, + "test_x": test_x_with_nan_y, + "target": target_with_y, + "y": None, + "target_y": None, + } + ) + + x_only_batch = Batch(**batch_dict) + + return x_only_batch diff --git a/pfns/priors/data_loading.py b/pfns/priors/data_loading.py index 9869104..c7c698a 100644 --- a/pfns/priors/data_loading.py +++ b/pfns/priors/data_loading.py @@ -1,9 +1,7 @@ import math import os import random - from copy import deepcopy - from functools import partial from typing import Callable, Iterator @@ -86,9 +84,10 @@ def __iter__(self) -> Iterator[Batch]: b = self.get_batch_method(**kwargs) - assert ( - len(b.x) == len(b.y) == len(b.target_y) == batch_shape.batch_size - ), "Our code was updated to use the more intuitive batch first format, please make sure your get_batch function returns data with shapes (batch_size, seq_len, ...)" + if b.y is not None: + assert ( + len(b.x) == len(b.y) == len(b.target_y) == batch_shape.batch_size + ), "Our code was updated to use the more 
intuitive batch first format, please make sure your get_batch function returns data with shapes (batch_size, seq_len, ...)" # Ensure single_eval_pos is set on the batch object if get_batch_method doesn't handle it if b.single_eval_pos is None: diff --git a/pfns/priors/formula/get_batch.py b/pfns/priors/formula/get_batch.py index 28f5853..f79db5f 100644 --- a/pfns/priors/formula/get_batch.py +++ b/pfns/priors/formula/get_batch.py @@ -2,11 +2,9 @@ from typing import Literal import numpy as np - import torch from .. import Batch - from .ops import binary_ops, unary_ops from .trees import evaluate_tree, sample_tree @@ -80,6 +78,8 @@ def get_batch( n_targets_per_input=1, single_eval_pos=None, # not using this device="cpu", # ignoring this + return_trees=False, + batch_size_per_gp_sample=None, ): assert ( n_targets_per_input == 1 @@ -106,6 +106,9 @@ def get_batch( # longterm todo: add styles based on the tree + if return_trees: + return batch, [tree for _, _, tree in batch_as_list] + return batch diff --git a/pfns/priors/formula/ops.py b/pfns/priors/formula/ops.py index f204b74..41c6323 100644 --- a/pfns/priors/formula/ops.py +++ b/pfns/priors/formula/ops.py @@ -1,7 +1,9 @@ import math +from functools import partial import torch +MAX_OP_OR_INPUT_KEY_LENGTH = 12 # all ops are implemented in torch mapping vectors of length n to vectors of length n ### binary operations @@ -23,8 +25,12 @@ def power(x, y): return x.abs() ** y -def rr_ad_gate(x, y): +def rr_ad_gate(x, y, zero_at_bound=False): bound_index = torch.randint(len(x), (1,)).squeeze() + + if zero_at_bound: + y = y - y[bound_index] + return (x > x[bound_index]) * y @@ -34,13 +40,14 @@ def rr_ad_gate(x, y): "gate": gate, "power": power, "rr_ad_gate": rr_ad_gate, + "rr_aad_gate": partial(rr_ad_gate, zero_at_bound=False), } -### unary operations - +assert ( + MAX_OP_OR_INPUT_KEY_LENGTH >= len(max(binary_ops.keys(), key=len)) +), "MAX_OP_OR_INPUT_KEY_LENGTH must be greater than the length of the longest binary operation" 
-def ident(x): - return x +### unary operations def absolute(x): @@ -79,11 +86,37 @@ def cubic(x): return x**3 +def relu(x): + return torch.relu(x) + + +def rr_repeat(x, mirror=False): + min_index = torch.randint(len(x), (1,)).squeeze() + max_index = torch.randint(len(x), (1,)).squeeze() + + mini = x[min_index] + maxi = x[max_index] + lo = torch.minimum(mini, maxi) + hi = torch.maximum(mini, maxi) + width = hi - lo + + if width.item() == 0: + return torch.full_like(x, lo) + + offset = torch.remainder(x - lo, width) + + if not mirror: + return lo + offset + + q = torch.floor((x - lo) / width).to(torch.int64) + flip = torch.remainder(q, 2) == 1 + return torch.where(flip, hi - offset, lo + offset) + + # def relative_noise__random(x): # return x * torch.randn_like(x) unary_ops = { - "ident": ident, "abs": absolute, "inv": inv, "sin": sin, @@ -93,5 +126,12 @@ def cubic(x): "square": square, "sigmoid": sigmoid, "cubic": cubic, + "relu": relu, + "rr_repeat": rr_repeat, + "rr_repeat_m": partial(rr_repeat, mirror=True), # 'rr_rel_noise': relative_noise__random, } + +assert ( + MAX_OP_OR_INPUT_KEY_LENGTH >= len(max(unary_ops.keys(), key=len)) +), "MAX_OP_OR_INPUT_KEY_LENGTH must be greater than the length of the longest unary operation" diff --git a/pfns/priors/formula/trees.py b/pfns/priors/formula/trees.py index 48bc9d8..3cc8f88 100644 --- a/pfns/priors/formula/trees.py +++ b/pfns/priors/formula/trees.py @@ -4,7 +4,7 @@ import numpy as np import torch -from .ops import binary_ops, unary_ops +from .ops import binary_ops, MAX_OP_OR_INPUT_KEY_LENGTH, unary_ops node_dtype = np.dtype( [ @@ -14,7 +14,7 @@ ), # The type of the node can be 'binary', 'unary', or 'leaf' ( "op_or_input", - "U12", + f"U{MAX_OP_OR_INPUT_KEY_LENGTH}", ), # The operation (e.g. 'add') or input index, if it is a leaf. 
( "left", diff --git a/pfns/priors/heteroskedastic_prior.py b/pfns/priors/heteroskedastic_prior.py new file mode 100644 index 0000000..d9d711a --- /dev/null +++ b/pfns/priors/heteroskedastic_prior.py @@ -0,0 +1,239 @@ +# This is a wrapper prior that adds heteroscedastic noise to datasets. +# The noise variance varies spatially based on a random GP function, and can be +# either normally distributed or long-tailed (with outliers). + +from copy import deepcopy + +import torch +from gpytorch.priors import LogNormalPrior +from pfns.priors import Batch + +from .path_stgp import sample_paths + + +@torch.no_grad() +def get_batch( + batch_size: int, + seq_len: int, + num_features: int, + *args, + hyperparameters: dict | None = None, + get_batch=None, + **kwargs, +) -> Batch: + """Generate a batch with heteroscedastic noise added to the base prior. + + This wrapper prior adds spatially-varying noise to datasets. The noise + variance at each point is determined by a random GP function sampled + using sample_paths from path_stgp.py. The noise can be either normal + or long-tailed (Student-t distribution with outliers). + + Hyperparameters: + hetero_noise_prob: Probability of making the noise heteroscedastic + (spatially varying) vs homoscedastic (constant) (default: 0.5) + hetero_noise_long_tailed_prob: Probability of using long-tailed noise + vs normal noise (default: 0.5) + hetero_noise_df_min: Minimum degrees of freedom for Student-t when + long-tailed (default: 2.0) + hetero_noise_df_max: Maximum degrees of freedom for Student-t when + long-tailed (default: 5.0) + hetero_noise_var_loc: Location parameter for log-normal base noise + variance distribution (default: -4.0, same as path_stgp) + hetero_noise_var_scale: Scale parameter for log-normal base noise + variance distribution (default: 1.0, same as path_stgp) + hetero_noise_range_scale: Scale factor for heteroscedastic variation + on top of base noise (default: 1.0). 
The heteroscedastic + component adds variation in [0, base_std * range_scale]. + + Additional hyperparameters are passed to sample_paths for the + variance function GP: + - use_rbf_kernel, lengthscale_loc_constant_add, lengthscale_loc_feature_mul, + lengthscale_scale, mean_width, additive_cosine_per_dim_prob + + Args: + batch_size: Number of samples in the batch + seq_len: Sequence length (number of points per sample) + num_features: Number of input features + hyperparameters: Dictionary of hyperparameters + get_batch: The underlying prior's get_batch function to wrap + **kwargs: Additional arguments passed to the underlying prior + + Returns: + Batch with heteroscedastic noise added + """ + if hyperparameters is None: + hyperparameters = {} + + hyperparameters = deepcopy(hyperparameters) + + # Extract heteroscedastic noise hyperparameters + hetero_prob: float = hyperparameters.pop("hetero_noise_prob", 0.5) + long_tailed_prob: float = hyperparameters.pop("hetero_noise_long_tailed_prob", 0.5) + df_min: float = hyperparameters.pop("hetero_noise_df_min", 2.0) + df_max: float = hyperparameters.pop("hetero_noise_df_max", 5.0) + noise_var_loc: float = hyperparameters.pop("hetero_noise_var_loc", -4.0) + noise_var_scale: float = hyperparameters.pop("hetero_noise_var_scale", 1.0) + + # Extract hyperparameters for the variance GP (with prefixed names) + variance_gp_hyperparameters = {} + variance_gp_keys = [ + "hetero_noise_use_rbf_kernel", + "hetero_noise_lengthscale_loc_constant_add", + "hetero_noise_lengthscale_loc_feature_mul", + "hetero_noise_lengthscale_scale", + "hetero_noise_mean_width", + "hetero_noise_additive_cosine_per_dim_prob", + ] + + for key in variance_gp_keys: + if key in hyperparameters: + # Remove prefix and add to variance GP hyperparameters + gp_key = key.replace("hetero_noise_", "") + variance_gp_hyperparameters[gp_key] = hyperparameters.pop(key) + + # Set defaults for variance GP if not provided + 
variance_gp_hyperparameters.setdefault("use_rbf_kernel", True) + variance_gp_hyperparameters.setdefault("mean_width", 1.0) + + if get_batch is None: + raise ValueError( + "heteroscedastic_noise_prior requires a base get_batch function to wrap" + ) + + # Get batch from base prior + base_batch = get_batch( + batch_size, + seq_len, + num_features, + *args, + hyperparameters=hyperparameters, + **kwargs, + ) + + device = base_batch.x.device + dtype = base_batch.x.dtype + x = base_batch.x # (batch_size, seq_len, num_features) + + # Decide per batch element whether to use heteroscedastic or homoscedastic noise + use_hetero = torch.rand(batch_size, device=device) < hetero_prob + + # Sample base noise variance from log-normal distribution (like path_stgp.py) + base_noise_variance: torch.Tensor = LogNormalPrior( + loc=noise_var_loc, + scale=noise_var_scale, + ).sample((batch_size,)) + base_noise_std = base_noise_variance.sqrt().to(device=device, dtype=dtype) + + # Sample the variance function using sample_paths from path_stgp + # This returns a function that maps x -> y where y varies smoothly + variance_paths = sample_paths(batch_size, num_features, variance_gp_hyperparameters) + + # Evaluate the variance function at all x points + # variance_paths expects (batch_size, n, num_features) and returns (1, batch_size, n) + variance_func_values = variance_paths(x).squeeze(0) # (batch_size, seq_len) + + # Normalize variance function values to [0, 1] per batch + var_min = variance_func_values.min(dim=1, keepdim=True).values + var_max = variance_func_values.max(dim=1, keepdim=True).values + var_range = (var_max - var_min).clamp(min=1e-8) + normalized_variance = (variance_func_values - var_min) / var_range # [0, 1] + + # For heteroscedastic noise: + # std = base_std + normalized_variance * (base_std * range_scale) + # This adds variation in [0, base_std * range_scale] on top of base_std + base_std_expanded = base_noise_std.unsqueeze(1) # (batch_size, 1) + range_size = 
torch.rand(batch_size, device=device) * 2 # [0, 2] + hetero_std = base_std_expanded * ( + 1 - range_size.clamp(max=1.0).unsqueeze(1) + ) + normalized_variance * ( + base_std_expanded * range_size.unsqueeze(1) + ) # (batch_size, seq_len) + + # For homoscedastic noise: just use the base std + homo_std = base_std_expanded.expand(-1, seq_len) # (batch_size, seq_len) + + # Select heteroscedastic or homoscedastic std per batch element + point_std = torch.where( + use_hetero.unsqueeze(1).expand(-1, seq_len), + hetero_std, + homo_std, + ) # (batch_size, seq_len) + + # Decide per batch element whether to use long-tailed or normal noise + use_long_tailed = torch.rand(batch_size, device=device) < long_tailed_prob + + # Sample noise + # For normal noise: N(0, std^2) + # For long-tailed noise: Student-t with df degrees of freedom, scaled by std + normal_noise = torch.randn(batch_size, seq_len, device=device, dtype=dtype) + + # Sample degrees of freedom for Student-t (per batch) + df = torch.rand(batch_size, device=device) * (df_max - df_min) + df_min + + # Sample Student-t noise using the relationship: t = Z / sqrt(V/df) + # where Z ~ N(0,1) and V ~ Chi-squared(df) + chi2_samples = torch.distributions.Chi2( + df.unsqueeze(1).expand(-1, seq_len) + ).sample() + student_t_noise = normal_noise / torch.sqrt( + chi2_samples / df.unsqueeze(1) + ) # (batch_size, seq_len) + + # Select noise type based on use_long_tailed + noise = torch.where( + use_long_tailed.unsqueeze(1).expand(-1, seq_len), + student_t_noise, + normal_noise, + ) + + # Scale noise by point-wise std + scaled_noise = noise * point_std # (batch_size, seq_len) + + # Add noise to y + y_with_noise = base_batch.y + scaled_noise.unsqueeze(2) + target_y = base_batch.target_y + if hyperparameters.get("noisy_predictions", False): + target_y += scaled_noise.unsqueeze(2) + + # # Plot noises on linspace [0,1] if num_features == 1 (for prototyping) + # import matplotlib.pyplot as plt + # if num_features == 1: + # print("hiii") + # 
x_plot = torch.linspace(0, 1, 100).unsqueeze(0).unsqueeze(-1) # (1, 100, 1) + # x_plot = x_plot.expand(batch_size, -1, -1).to(device=device, dtype=dtype) + + # # Evaluate variance function on linspace + # variance_func_plot = variance_paths(x_plot).squeeze(0) # (batch_size, 100) + + # # Normalize to [0, 1] + # var_min_plot = variance_func_plot.min(dim=1, keepdim=True).values + # var_max_plot = variance_func_plot.max(dim=1, keepdim=True).values + # var_range_plot = (var_max_plot - var_min_plot).clamp(min=1e-8) + # normalized_variance_plot = (variance_func_plot - var_min_plot) / var_range_plot + + # # Compute std on linspace + # hetero_std_plot = base_std_expanded * ( + # 1 - range_size.clamp(max=1.0).unsqueeze(1) + # ) + normalized_variance_plot * (base_std_expanded * range_size.unsqueeze(1)) + + # # Plot a few samples + # fig, axes = plt.subplots(2, 2, figsize=(10, 8)) + # _ = fig # suppress unused variable warning + # x_np = torch.linspace(0, 1, 100).numpy() + # for i, ax in enumerate(axes.flat): + # if i < batch_size: + # ax.plot(x_np, hetero_std_plot[i].cpu().numpy(), label="std(x)") + # ax.set_xlabel("x") + # ax.set_ylabel("noise std") + # ax.set_title(f"Sample {i} (hetero={use_hetero[i].item()})") + # ax.legend() + # plt.suptitle("Heteroscedastic Noise Std on [0,1]") + # plt.tight_layout() + # plt.show() + + return Batch( + x=base_batch.x, + y=y_with_noise, + target_y=target_y, + style=base_batch.style, + ) diff --git a/pfns/priors/hyperparameter_sampling.py b/pfns/priors/hyperparameter_sampling.py index 3e4632e..87cdac1 100644 --- a/pfns/priors/hyperparameter_sampling.py +++ b/pfns/priors/hyperparameter_sampling.py @@ -44,11 +44,19 @@ def sample(self): def normalize(self, value: torch.Tensor) -> torch.Tensor: if self.log: - return (torch.log(value) - math.log(self.lower)) / ( - math.log(self.upper) - math.log(self.lower) + return (torch.log(value) - torch.log(self.lower)) / ( + torch.log(self.upper) - torch.log(self.lower) ) return (value - self.lower) / 
(self.upper - self.lower) + def unnormalize(self, encoded_value: torch.Tensor) -> torch.Tensor: + if self.log: + return torch.exp( + encoded_value * (torch.log(self.upper) - torch.log(self.lower)) + + torch.log(self.lower) + ) + return encoded_value * (self.upper - self.lower) + self.lower + def encode_to_torch(self, value): assert (value >= self.lower) and ( value <= self.upper @@ -95,6 +103,12 @@ def normalize(self, value: torch.Tensor) -> torch.Tensor: transformed_upper - transformed_lower ) + def unnormalize(self, encoded_value: torch.Tensor) -> torch.Tensor: + transformed_value = encoded_value * ( + self.upper ** (1 / self.power) - self.lower ** (1 / self.power) + ) + self.lower ** (1 / self.power) + return torch.pow(transformed_value, self.power) + def encode_to_torch(self, value): assert (value >= self.lower) and ( value <= self.upper @@ -127,6 +141,14 @@ def normalize(self, value): ) return (value - self.lower) / (self.upper - self.lower) + def unnormalize(self, encoded_value): + if self.log: + return torch.exp( + encoded_value * (math.log(self.upper) - math.log(self.lower)) + + math.log(self.lower) + ) + return encoded_value * (self.upper - self.lower) + self.lower + def encode_to_torch(self, value): assert (value >= self.lower) and ( value <= self.upper @@ -149,6 +171,10 @@ def normalize(self, value: torch.Tensor): # Return one-hot encoding of the choice return value / (len(self.choices) - 1) + def unnormalize(self, encoded_value: torch.Tensor): + # Return the index of the one-hot encoded value + return encoded_value * (len(self.choices) - 1) + def sample_hyperparameters(config): """Sample values for all hyperparameters in the config""" diff --git a/pfns/priors/path_stgp.py b/pfns/priors/path_stgp.py new file mode 100644 index 0000000..5179e7f --- /dev/null +++ b/pfns/priors/path_stgp.py @@ -0,0 +1,477 @@ +from functools import partial +from math import log, sqrt + +import gpytorch +import torch +from botorch.models.gp_regression import SingleTaskGP +from 
botorch.sampling.pathwise.prior_samplers import draw_kernel_feature_paths +from gpytorch.constraints.constraints import GreaterThan +from gpytorch.kernels import MaternKernel, RBFKernel +from gpytorch.module import _pyro_sample_from_prior +from gpytorch.priors import LogNormalPrior, NormalPrior +from torch import Size + +from .path_trace_sampling import generate_trace +from .prior import Batch +from .utils import sample_x_around_points + + +def to_random_module_no_copy(module) -> gpytorch.Module: + random_module_cls = type( + "_Random" + module.__class__.__name__, + (gpytorch.module.RandomModuleMixin, module.__class__), + {}, + ) + module.__class__ = random_module_cls # hack + + for mname, child in module.named_children(): + if isinstance(child, gpytorch.Module): + setattr(module, mname, to_random_module_no_copy(child)) + return module + + +def _sample_paths_inner(batch_size, num_features, hyperparameters=None): + """Internal function that samples paths for the given number of non-dummy features. + + Returns a function that takes [batch_size, n, num_features] and returns [1, batch_size, n]. + """ + if hyperparameters is None: + hyperparameters = {} + base_class = ( + RBFKernel if hyperparameters.get("use_rbf_kernel", True) else MaternKernel + ) + lengthscale_prior = LogNormalPrior( + loc=hyperparameters.get("lengthscale_loc_constant_add", sqrt(2)) + + log(num_features) * hyperparameters.get("lengthscale_loc_feature_mul", 0.5), + scale=hyperparameters.get("lengthscale_scale", sqrt(3)), + ) + base_kernel = base_class( + ard_num_dims=num_features, + batch_shape=torch.Size([batch_size]), + lengthscale_prior=lengthscale_prior, + lengthscale_constraint=GreaterThan( + 2.5e-2, transform=None, initial_value=lengthscale_prior.mode + ), + # pyre-ignore[6] GPyTorch type is unnecessarily restrictive. 
+ active_dims=None, + ) + + model = SingleTaskGP( + torch.zeros(batch_size, 1, num_features, dtype=torch.double), + torch.zeros(batch_size, 1, 1, dtype=torch.double), + covar_module=base_kernel, + mean_module=gpytorch.means.ConstantMean( + constant_prior=NormalPrior(loc=0.0, scale=hyperparameters["mean_width"] / 2) + ), + ).to(dtype=torch.double) + model = to_random_module_no_copy(model) + _pyro_sample_from_prior(module=model, memo=None, prefix="") + + init_paths = draw_kernel_feature_paths(model=model, sample_shape=Size((1,))).to( + dtype=torch.double + ) + + if ( + additive_cosine_per_dim_prob := hyperparameters.get( + "additive_cosine_per_dim_prob", 0.0 + ) + ) > 0.0: + additive_cosine_per_dim = ( + torch.rand(batch_size, num_features) < additive_cosine_per_dim_prob + ) + lengthscale = ( + torch.rand(batch_size, num_features, dtype=torch.double) * 0.2 + 0.08 + ) + gp_lengthscale = model.covar_module.lengthscale.view(batch_size, num_features) + magnitude = ( + torch.randn(batch_size, num_features, dtype=torch.double) + / 10 + / gp_lengthscale + ) # very rough... + offset = torch.rand(batch_size, num_features, dtype=torch.double) + + def paths(x): # [batch_size, n, num_features] + y = init_paths(x) + mask = additive_cosine_per_dim.unsqueeze(1).expand(-1, x.shape[1], -1) + return y + torch.where( + mask, + ( + magnitude.unsqueeze(1) + * torch.cos( + 2 + * torch.pi + * (x / lengthscale.unsqueeze(1) + offset.unsqueeze(1)), + dtype=torch.double, + ) + ), + 0.0, + ).sum(-1).unsqueeze(0) / sqrt(num_features) + else: + paths = init_paths + + def paths_with_double(x): + dtype = x.dtype + y = paths(x.double()) + return y.to(dtype) + + # paths takes a tensor of shape [batch_size, n, num_features] and returns a tensor of shape [1, batch_size, n] + return paths_with_double + + +def _get_dummy_dims(num_features, hyperparameters): + """Determine which dimensions are dummy. Returns (dummy_dims mask, num_non_dummy). 
+ + Hyperparameters: + dummy_dim_sample_non_dummy_range: tuple (min, max) - Sample number of non-dummy + dimensions uniformly from [min, max]. Remaining dimensions are dummy. + dummy_dim_sample_non_dummy_range_prob: float - Probability of applying the + dummy_dim_sample_non_dummy_range logic. If not applied, all dimensions + are non-dummy. Default: 1.0 (always apply when range is set). + dummy_dim_prob: float - Each dimension has this probability of being a dummy. + At least one dimension will always be non-dummy. + """ + if ( + non_dummy_range := hyperparameters.get("dummy_dim_sample_non_dummy_range", None) + ) is not None: + # Check if we should apply the non-dummy range logic for this dataset + range_prob = hyperparameters.get("dummy_dim_sample_non_dummy_range_prob", 1.0) + if torch.rand(()).item() >= range_prob: + # Don't apply dummy dims - all dimensions are non-dummy + return None, num_features + + min_non_dummy, max_non_dummy = non_dummy_range + num_non_dummy = min( + torch.randint(min_non_dummy, max_non_dummy + 1, ()).item(), + num_features, + ) + perm = torch.randperm(num_features) + dummy_dims = torch.ones(num_features, dtype=torch.bool) + dummy_dims[perm[:num_non_dummy]] = False + return dummy_dims, num_non_dummy + + elif (dummy_dim_prob := hyperparameters.get("dummy_dim_prob", 0.0)) > 0.0: + dummy_dims = torch.rand(num_features) < dummy_dim_prob + if (~dummy_dims).sum().item() == 0: + dummy_dims[torch.randint(0, num_features, ())] = False + return dummy_dims, (~dummy_dims).int().sum().item() + + return None, num_features + + +def sample_paths(batch_size, num_features, hyperparameters=None): + """Sample GP paths with optional dummy dimension handling and gap discontinuities. + + This function handles dummy dimensions by sampling a GP on only the non-dummy + dimensions and wrapping it in a function that filters out dummy dimensions. + When gaps are enabled, all region functions share the same dummy dimensions. 
+ + Hyperparameters: + dummy_dim_sample_non_dummy_range: tuple (min, max) - Sample number of non-dummy + dimensions uniformly from [min, max]. Remaining dimensions are dummy. + dummy_dim_prob: float - Each dimension has this probability of being a dummy. + At least one dimension will always be non-dummy. + gap_max_splits: int - Max axis-aligned splits for gaps (default: 0) + gap_prob: float - Probability of applying gaps (default: 1.0) + gap_lengthscale_add: float - Add to lengthscale_loc_constant_add when + gaps are applied (default: 0.0, no change) + gap_lengthscale_add_prob: float - Probability of applying the + lengthscale adjustment when gaps are applied (default: 1.0) + + Returns a function that takes [batch_size, n, num_features] and returns [1, batch_size, n]. + """ + if hyperparameters is None: + hyperparameters = {} + + # Determine if gaps will be applied + max_splits = hyperparameters.get("gap_max_splits", 0) + gap_prob = hyperparameters.get("gap_prob", 1.0) + apply_gaps = max_splits > 0 and torch.rand(()).item() < gap_prob + num_splits = torch.randint(1, max_splits + 1, (1,)).item() if apply_gaps else 0 + + # Optionally adjust lengthscale when gaps are applied + if num_splits > 0: + lf = hyperparameters.get("gap_lengthscale_add", 0.0) + lf_prob = hyperparameters.get("gap_lengthscale_add_prob", 1.0) + if lf != 0.0 and torch.rand(()).item() < lf_prob: + hyperparameters = hyperparameters.copy() + loc = hyperparameters.get("lengthscale_loc_constant_add", sqrt(2)) + hyperparameters["lengthscale_loc_constant_add"] = loc + lf + + # Determine dummy dimensions once (shared across all region functions) + dummy_dims, num_non_dummy = _get_dummy_dims(num_features, hyperparameters) + + if num_splits == 0: + inner = _sample_paths_inner(batch_size, num_non_dummy, hyperparameters) + if dummy_dims is None: + return inner + return lambda x: inner(x[:, :, ~dummy_dims]) + + # Sample independent GP paths for each region (2^num_splits regions) + num_regions = 2**num_splits + 
region_paths = [ + _sample_paths_inner(batch_size, num_non_dummy, hyperparameters) + for _ in range(num_regions) + ] + + # Sample split parameters (on original feature space, not reduced) + feature_indices = torch.randint(0, num_features, (batch_size, num_splits)) + thresholds = torch.rand(batch_size, num_splits) + + def paths_with_gaps(x): + # Compute region index for each point (using full x with all features) + fi_exp = feature_indices.unsqueeze(1).expand(-1, x.shape[1], -1) + x_at_splits = torch.gather(x, dim=2, index=fi_exp) + above = (x_at_splits > thresholds.unsqueeze(1)).long() + powers = (2 ** torch.arange(num_splits)).view(1, 1, -1) + region_idx = (above * powers).sum(dim=2) # [batch_size, n] + + # Filter to non-dummy dims for GP evaluation + x_filtered = x if dummy_dims is None else x[:, :, ~dummy_dims] + + # Evaluate all region paths and select based on region_idx + all_ys = torch.stack( + [p(x_filtered).squeeze(0) for p in region_paths], dim=2 + ) # [batch_size, n, num_regions] + y = torch.gather(all_ys, dim=2, index=region_idx.unsqueeze(2)).squeeze(2) + return y.unsqueeze(0) # [1, batch_size, n] + + return paths_with_gaps + + +# paths = draw_matheron_paths(model=model, sample_shape=Size((128,))) + + +def sample_clustered_x( + batch_size, + seq_len, + num_features, + pad_factor: int = 10, + num_cluster_max: int = 1, + max_std: float = 0.25, +): + """ + This function samples a batch of inputs from normal distributions. + Its outputs are all in [0,1], which is ensured by over-sampling (pad_factor) + and then rejecting outside samples. In addition, we clamp the values to [0,1]. 
+ """ + num_clusters = torch.randint(1, num_cluster_max + 1, tuple()).item() + + mean = torch.rand(batch_size, num_clusters, num_features) + std = torch.rand(batch_size, num_clusters, num_features) * max_std + + # define mean and std for each position + # to do that we randomly pick values from the num_clusters dimension + mean = ( + mean[:, torch.randint(0, num_clusters, (seq_len,)), :] + .transpose(1, 2) + .repeat(1, 1, pad_factor) + ) + std = ( + std[:, torch.randint(0, num_clusters, (seq_len,)), :] + .transpose(1, 2) + .repeat(1, 1, pad_factor) + ) + + x = torch.randn(batch_size, num_features, seq_len * pad_factor) * std + mean + x = x.transpose(1, 2) + sorting_x = ((x >= 0.0) & (x <= 1.0)).sum(dim=-1) + order = torch.argsort(sorting_x, dim=-1, stable=True, descending=True) + print(f"{order.shape=}") + x = x.gather( + dim=1, index=order[:, :seq_len].unsqueeze(-1).expand(-1, -1, num_features) + ) + x = x.clamp(0, 1) + return x + + +def sample_around_train_point( + batch_size, + seq_len, + num_features, + single_eval_pos, + surrounding_std: float = 0.01, + surrounding_share: float = 0.5, + binary_feature_likelihood: float = 0.0, +): + binary_features = ( + (torch.rand(batch_size, num_features) < binary_feature_likelihood) + .unsqueeze(1) + .expand(-1, single_eval_pos, -1) + ) + train_x = torch.rand(batch_size, single_eval_pos, num_features) + train_x_cutoffs = torch.rand(batch_size, single_eval_pos, num_features) + train_x[binary_features] = ( + train_x[binary_features] > train_x_cutoffs[binary_features] + ).float() + + num_test_points = seq_len - single_eval_pos + num_surrounding = int(num_test_points * surrounding_share) + + normal_test_x = torch.rand( + batch_size, num_test_points - num_surrounding, num_features + ) + + # Use shared utility for sampling around training points + surrounding_test_x = sample_x_around_points( + batch_size=batch_size, + num_samples=num_surrounding, + num_features=num_features, + centers=train_x, + std=surrounding_std, + 
device=train_x.device, + ) + + x = torch.cat([train_x, normal_test_x, surrounding_test_x], dim=1) + return x + + +def add_noise(y, hyperparameters, no_noise=None): + batch_size = y.shape[0] + noise_var_dist = hyperparameters.get("noise_var_dist", "lognormal") + if noise_var_dist == "lognormal": + noise_variance: torch.Tensor = LogNormalPrior( + loc=hyperparameters["noise_var_loc"], + scale=hyperparameters["noise_var_scale"], + ).sample((batch_size,)) + elif noise_var_dist == "gamma": + noise_variance: torch.Tensor = ( + torch.distributions.Gamma( + hyperparameters["noise_var_concentration"], + hyperparameters["noise_var_rate"], + ).sample((batch_size,)) + + 1e-4 + ) + else: + raise ValueError(f"Unknown noise variance distribution {noise_var_dist}") + + if no_noise is not None: + noise_variance[no_noise] = 0.0 + + noisy_y = y + torch.randn_like(y) * noise_variance[:, None, None] ** (1 / 2) + return noisy_y + + +@torch.no_grad() +def get_batch( + batch_size, + seq_len, + num_features, + single_eval_pos, + hyperparameters=None, + n_targets_per_input=1, + **kwargs, +): + if hyperparameters is None: + hyperparameters = { + "noise_var_loc": -4.0, + "noise_var_scale": 1.0, + } + + no_noise_prob = hyperparameters.get("no_noise_prob", 0.0) + no_noise = torch.rand(batch_size) < no_noise_prob + + assert hyperparameters.get("mean_dist", "normal") == "normal" + + paths = sample_paths( + batch_size, num_features, hyperparameters + ) # paths maps [batch_size, n, num_features] to [batch_size, n, 1] + + sample_clustered_x_hp = hyperparameters.get("sample_clustered_x", None) + + if sample_clustered_x_hp == "trace": + assert no_noise_prob == 0.0 + x, y = generate_trace( + batch_size, + paths, + partial(add_noise, hyperparameters=hyperparameters), + seq_len, + single_eval_pos, + bounds=[(0, 1)] * num_features, + ) + + y = y.view(batch_size, seq_len, 1) + noisy_y = y + + else: + if sample_clustered_x_hp is True or sample_clustered_x_hp == "clustered": + x = sample_clustered_x( + 
batch_size, + seq_len, + num_features, + num_cluster_max=hyperparameters.get("num_cluster_max", 1), + max_std=hyperparameters.get("max_std", 0.25), + ) + elif sample_clustered_x_hp.startswith("around_train_point_binp_"): + binary_prob = float(sample_clustered_x_hp.split("_")[-1]) + x = sample_around_train_point( + batch_size, + seq_len, + num_features, + single_eval_pos, + binary_feature_likelihood=binary_prob, + ) + elif sample_clustered_x_hp == "around_train_point": + x = sample_around_train_point( + batch_size, + seq_len, + num_features, + single_eval_pos, + ) + else: + assert (sample_clustered_x_hp is None) or ( + sample_clustered_x_hp == "none" + ), sample_clustered_x_hp + x = torch.rand(batch_size, seq_len, num_features) + + y = paths(x).squeeze(0) # shape: (batch_size, seq_len) + + y = y.view(batch_size, seq_len, 1) + + noisy_y = add_noise(y, hyperparameters, no_noise) + + if hyperparameters["noisy_predictions"]: + target_y = add_noise( + y.expand(-1, -1, n_targets_per_input), hyperparameters, no_noise + ) + else: + target_y = y.expand(-1, -1, n_targets_per_input) + + if hyperparameters.get("train_normalized_y", False): + if single_eval_pos <= 1: + raise ValueError("train_normalized_y requires single_eval_pos > 1") + + train_mean = noisy_y[:, :single_eval_pos].mean(dim=1, keepdim=True) + train_std = noisy_y[:, :single_eval_pos].std(dim=1, keepdim=True) + noisy_y = (noisy_y - train_mean) / train_std + target_y = (target_y - train_mean) / train_std + + predict_advantage = hyperparameters.get("predict_advantage", False) + if predict_advantage is True: + target_y = ( + target_y[:, :, :] + - target_y[:, :single_eval_pos, :] + .max(dim=-1, keepdim=True) + .values.max(dim=-2, keepdim=True) + .values + ) + elif predict_advantage == "y": + target_y = ( + target_y[:, :, :] + - noisy_y[:, :single_eval_pos, :] + .max(dim=-1, keepdim=True) + .values.max(dim=-2, keepdim=True) + .values + ) + else: + assert predict_advantage is False + + if 
hyperparameters.get("relu_target", False): + target_y = target_y.clamp(min=0.0) + + # set ys to nan in training set + number_of_y_hidden = torch.randint( + 0, hyperparameters.get("max_num_hidden_y", 0) + 1, tuple() + ) + noisy_y[:, single_eval_pos - number_of_y_hidden : single_eval_pos] = torch.nan + + return Batch(x=x, y=noisy_y, target_y=target_y) diff --git a/pfns/priors/path_trace_sampling.py b/pfns/priors/path_trace_sampling.py new file mode 100644 index 0000000..3e0115e --- /dev/null +++ b/pfns/priors/path_trace_sampling.py @@ -0,0 +1,183 @@ +import numpy as np +import torch + + +# Vectorized corner check for batch of points +def corner_check(x, corners): + # x can be a single point or batch of points + batching = True + if x.ndim == 1: + batching = False + x = x[np.newaxis, :] + + # Check which points are corners (all coords <= 0 or >= 1) + is_corner = np.all((x <= 0) | (x >= 1), axis=1) + new_corners = np.zeros(len(x), dtype=bool) + + # For non-corner points, return True + results = np.ones(len(x), dtype=bool) + + if np.any(is_corner): + # For corner points, compute their IDs + corner_powers = np.array([2**i for i in range(x.shape[1])]) + corner_ids = (x[is_corner] @ corner_powers).round().astype(int) + corner_ids = np.array( + [cid * x.shape[0] + i for i, cid in enumerate(corner_ids)] + ) + + # Check which corner IDs are new + new_corners = np.array([cid not in corners for cid in corner_ids]) + + # Add new corner IDs to the set + corners.update(corner_ids[new_corners]) + + # For corner points, return True if new corner, False if already seen + results[is_corner] = new_corners + + return results if batching else results[0], corners + + +def sample_until_all_success(sampling_function, corners): + sample = sampling_function() # shape: [batch_size, d] + success, corners = corner_check(sample, corners) + while not np.all(success): + # Keep successful samples, only resample failed ones + failed_mask = ~success + new_sample = sampling_function() + new_success, 
corners = corner_check(new_sample, corners) + # Update only failed positions with new successful samples + sample[failed_mask] = new_sample[failed_mask] + success[failed_mask] = new_success[failed_mask] + return sample + + +@torch.no_grad() +def generate_trace( + batch_size, + paths, + add_noise, + L, + cutoff, + bounds, + best=None, + never_local=False, + dtype=torch.float, +): + """ + Generate optimization traces blending exploration and exploitation for batched Gaussian Processes. + + Parameters: + - L: int, length of the trace. + - cutoff: int, position in the trace after which we may sample around the global optimum. + - bounds: list of tuples [(min1, max1), (min2, max2), ...], search space bounds. + - t_random_weights: torch.Tensor, random weights for GP Fourier features [batch_size, d, num_features]. + - t_random_offset: torch.Tensor, random offsets for GP Fourier features [batch_size, num_features]. + - t_W_GP: torch.Tensor, weights defining the GP in RFF space [batch_size, num_features, 1]. + - sigma_output: float, output scale of the GP. + - sigma_noise: float, observation noise level. + - mean_function: float, mean function of the GP. + - best: ndarray, optional, the location of the global optimum [batch_size, d]. + + Returns: + - trace: ndarray of shape [batch_size, L, d], the optimization traces. + - y: ndarray of shape [batch_size, L], the function values at each point in the traces. 
+ """ + d = len(bounds) + trace = np.zeros((batch_size, L, d)) + y = np.zeros((batch_size, L)) + + corners = set() + + # Initialize + eps = (1 - np.random.rand(batch_size) ** (d / 6)) / 2 + sigma = np.exp(np.random.normal(-3, 0.5, size=(batch_size,))) + u = np.random.uniform(size=(batch_size, 3)) + initial_alpha = u.min(axis=1) + final_alpha = u.max(axis=1) + trace[:, 0] = np.clip( + np.random.uniform(-eps[:, None], 1 + eps[:, None], (batch_size, d)), 0, 1 + ) + best_point = trace[:, 0].copy() + + # Get initial values using vectorized GP evaluation + y[:, :1] = add_noise(paths(torch.tensor(trace[:, :1], dtype=dtype))).numpy() + y_best = y[:, 0].copy() + + for i in range(1, L): + alpha = initial_alpha + (final_alpha - initial_alpha) * (i / L) + local = np.random.rand(batch_size) < alpha + if never_local: + local[:] = False + + # could speed up by factor of 2 + + def sample_local(): + if i < cutoff: # noqa: B023 + inc = best_point + else: + inc = np.zeros_like(best_point) + if best is not None: + use_best = (L - cutoff) * np.random.rand(batch_size) < 5 * d + inc[use_best] = best[use_best] + + if cutoff > 0: + use_cutoff = ( + ~use_best + if best is not None + else np.ones(batch_size, dtype=bool) + ) + random_cutoff_indices = np.random.choice(cutoff, size=batch_size) + inc[use_cutoff] = trace[ + np.arange(batch_size)[use_cutoff], + random_cutoff_indices[use_cutoff], + ] + else: + # No point to sample locally around, just sample globally + use_global = ( + ~use_best + if best is not None + else np.ones(batch_size, dtype=bool) + ) + inc[use_global] = np.clip( + np.random.uniform( + -eps[use_global, None], + 1 + eps[use_global, None], + (np.sum(use_global), d), + ), + 0, + 1, + ) + + ret = np.random.normal(inc, sigma[:, None], size=(batch_size, d)) + return np.clip( + ret, [low for low, _ in bounds], [high for _, high in bounds] + ) + + def sample_global(): + return np.clip( + np.random.uniform(-eps[:, None], 1 + eps[:, None], (batch_size, d)), + 0, + 1, + ) + + # Sample 
points based on local/global strategy + trace[:, i] = np.where( + local[:, None], + sample_until_all_success(sample_local, corners), + sample_until_all_success(sample_global, corners), + ) + + # Update the current best if before cutoff + if i < cutoff: + y[:, i : i + 1] = add_noise( + paths(torch.tensor(trace[:, i : i + 1], dtype=dtype)) + ).numpy() + better_mask = y[:, i] > y_best + best_point[better_mask] = trace[better_mask, i] + y_best[better_mask] = y[:, i][better_mask] + + # Get noiseless values after cutoff using vectorized evaluation + if cutoff < L: + y[:, cutoff:] = paths(torch.tensor(trace[:, cutoff:], dtype=dtype)).numpy() + + return torch.tensor(trace, dtype=dtype), torch.tensor(y, dtype=dtype) diff --git a/pfns/priors/prior.py b/pfns/priors/prior.py index 8120045..c31729d 100644 --- a/pfns/priors/prior.py +++ b/pfns/priors/prior.py @@ -20,9 +20,11 @@ def create_get_batch_method(self) -> Callable: @dataclass(frozen=True) class AdhocPriorConfig(PriorConfig): # Set as a class variable instead of being set at init - prior_names: str | Sequence[str] | None = None + prior_names: list[str] | None = None get_batch_methods: Callable | Sequence[Callable] | None = None prior_kwargs: dict | None = None + prior_dirs: str | list[str] = "pfns.priors" + get_batch_names: str | list[str] = "get_batch" strict_field_types: ClassVar[bool] = False @@ -36,13 +38,21 @@ def create_get_batch_method(self) -> Callable: if self.prior_names is not None: get_batch_methods = [] - for prior_name in ( - self.prior_names - if isinstance(self.prior_names, Sequence) - else [self.prior_names] - ): - prior_module = importlib.import_module(f"pfns.priors.{prior_name}") - get_batch_methods.append(prior_module.get_batch) + get_batch_names = ( + self.get_batch_names + if isinstance(self.get_batch_names, list) + else [self.get_batch_names] * len(self.prior_names) + ) + prior_dirs = ( + self.prior_dirs + if isinstance(self.prior_dirs, list) + else [self.prior_dirs] * len(self.prior_names) + ) + 
assert len(self.prior_names) == len(get_batch_names) == len(prior_dirs) + + for i, prior_name in enumerate(self.prior_names): + prior_module = importlib.import_module(f"{prior_dirs[i]}.{prior_name}") + get_batch_methods.append(getattr(prior_module, get_batch_names[i])) else: get_batch_methods = ( self.get_batch_methods @@ -68,8 +78,14 @@ class Batch: # Required entries x: torch.Tensor - y: torch.Tensor - target_y: torch.Tensor + + # Entries when using sep. y + y: torch.Tensor | None + target_y: torch.Tensor | None + + # Entries for x_only_mode + target: torch.Tensor | None = None + test_x: torch.Tensor | None = None # Optional Batch Entries style: Optional[torch.Tensor] = None diff --git a/pfns/priors/saas.py b/pfns/priors/saas.py new file mode 100644 index 0000000..55482a3 --- /dev/null +++ b/pfns/priors/saas.py @@ -0,0 +1,105 @@ +import torch +from botorch.models.fully_bayesian import SaasPyroModel + +from .path_stgp import sample_around_train_point, sample_clustered_x +from .prior import Batch + + +@torch.no_grad() +def get_batch( + batch_size, + seq_len, + num_features, + single_eval_pos, + hyperparameters=None, + n_targets_per_input=1, + **kwargs, +): + if hyperparameters is None: + hyperparameters = {} + + sample_clustered_x_hp = hyperparameters.get("sample_clustered_x", None) + + # Sample x based on the specified method (same logic as path_stgp.py) + if sample_clustered_x_hp is True or sample_clustered_x_hp == "clustered": + x = sample_clustered_x( + batch_size, + seq_len, + num_features, + num_cluster_max=hyperparameters.get("num_cluster_max", 1), + max_std=hyperparameters.get("max_std", 0.25), + ) + elif sample_clustered_x_hp and sample_clustered_x_hp.startswith( + "around_train_point_binp_" + ): + binary_prob = float(sample_clustered_x_hp.split("_")[-1]) + x = sample_around_train_point( + batch_size, + seq_len, + num_features, + single_eval_pos, + binary_feature_likelihood=binary_prob, + ) + elif sample_clustered_x_hp == "around_train_point": + x = 
sample_around_train_point( + batch_size, + seq_len, + num_features, + single_eval_pos, + ) + else: + assert sample_clustered_x_hp is None, sample_clustered_x_hp + x = torch.rand(batch_size, seq_len, num_features) + + no_noise_prob = hyperparameters.get("no_noise_prob", 0.0) + no_noise = torch.rand(batch_size) < no_noise_prob + + # Sample each batch item separately using SAAS prior + y_list = [] + noisy_y_list = [] + + for i in range(batch_size): + m = SaasPyroModel() + m._prior_mode = True + m.set_inputs( + x[i].to(dtype=torch.float64), + torch.zeros(seq_len, 1, dtype=torch.float64), + ) + m.sample() + + # Get noiseless and noisy predictions + noiseless_y = m.f_prior_sample.to(dtype=torch.float32) + if no_noise[i]: + noisy_y = noiseless_y.clone() + else: + noisy_y = m.Y_prior_sample.to(dtype=torch.float32) + + y_list.append(noiseless_y) + noisy_y_list.append(noisy_y) + + y = torch.stack(y_list, dim=0).view(batch_size, seq_len, 1) + noisy_y = torch.stack(noisy_y_list, dim=0).view(batch_size, seq_len, 1) + + # Handle n_targets_per_input + if hyperparameters.get("noisy_predictions", False): + target_y = noisy_y.expand(-1, -1, n_targets_per_input) + else: + target_y = y.expand(-1, -1, n_targets_per_input) + + # Handle predict_advantage + if hyperparameters.get("predict_advantage", False): + target_y = ( + target_y[:, :, :] + - target_y[:, :single_eval_pos, :] + .max(dim=-1, keepdim=True) + .values.max(dim=-2, keepdim=True) + .values + ) + + # Set ys to nan in training set + number_of_y_hidden = torch.randint( + 0, hyperparameters.get("max_num_hidden_y", 0) + 1, tuple() + ) + noisy_y[:, single_eval_pos - number_of_y_hidden : single_eval_pos] = torch.nan + + return Batch(x=x, y=noisy_y, target_y=target_y) diff --git a/pfns/priors/singletaskgp.py b/pfns/priors/singletaskgp.py new file mode 100644 index 0000000..634d3e0 --- /dev/null +++ b/pfns/priors/singletaskgp.py @@ -0,0 +1,441 @@ +from math import log, sqrt + +import torch + +from 
gpytorch.distributions.multivariate_normal import MultivariateNormal +from gpytorch.kernels import LinearKernel, MaternKernel, RBFKernel +from gpytorch.priors import LogNormalPrior + +from pfns.priors.prior import Batch +from torch import Tensor + + +def sample_clustered_x( + batch_size, + seq_len, + num_features, + pad_factor: int = 10, + num_cluster_max: int = 1, + max_std: float = 0.25, +): + """ + This function samples a batch of inputs from normal distributions. + Its outputs are all in [0,1], which is ensured by over-sampling (pad_factor) + and then rejecting outside samples. In addition, we clamp the values to [0,1]. + """ + num_clusters = torch.randint(1, num_cluster_max + 1, tuple()).item() + + mean = torch.rand(batch_size, num_clusters, num_features) + std = torch.rand(batch_size, num_clusters, num_features) * max_std + + # define mean and std for each position + # to do that we randomly pick values from the num_clusters dimension + mean = ( + mean[:, torch.randint(0, num_clusters, (seq_len,)), :] + .transpose(1, 2) + .repeat(1, 1, pad_factor) + ) + std = ( + std[:, torch.randint(0, num_clusters, (seq_len,)), :] + .transpose(1, 2) + .repeat(1, 1, pad_factor) + ) + + x = torch.randn(batch_size, num_features, seq_len * pad_factor) * std + mean + x = x.transpose(1, 2) + sorting_x = ((x >= 0.0) & (x <= 1.0)).sum(dim=-1) + order = torch.argsort(sorting_x, dim=-1, stable=True, descending=True) + x = x.gather( + dim=1, index=order[:, :seq_len].unsqueeze(-1).expand(-1, -1, num_features) + ) + x = x.clamp(0, 1) + return x + + +def sample_around_train_point( + batch_size, + seq_len, + num_features, + single_eval_pos, + surrounding_std: float = 0.01, + surrounding_share: float = 0.5, +): + train_x = torch.rand(batch_size, single_eval_pos, num_features) + + num_test_points = seq_len - single_eval_pos + num_surrounding = int(num_test_points * surrounding_share) + + normal_test_x = torch.rand( + batch_size, num_test_points - num_surrounding, num_features + ) + if 
single_eval_pos > 0: + centers = torch.multinomial( + torch.ones(single_eval_pos), num_surrounding, replacement=True + ) + surrounding_test_x = ( + torch.randn(batch_size, num_surrounding, num_features) * surrounding_std + + train_x[:, centers] + ) + else: + surrounding_test_x = torch.rand(batch_size, num_surrounding, num_features) + x = torch.cat([train_x, normal_test_x, surrounding_test_x], dim=1) + return x + + +# adapted from botorch to support batching +def inv_kumaraswamy_warp( + X: Tensor, c0: Tensor, c1: Tensor, eps: float = 1e-8 +) -> Tensor: + """Map warped inputs through an inverse Kumaraswamy CDF. + + This takes warped inputs (X) and transforms those via an inverse + Kumaraswamy CDF. This then unnormalizes the inputs using bounds of + [eps, 1-eps]^d and ensures that the values are within [0, 1]^d. + + Args: + X: A `b x n x d`-dim tensor of inputs. + c0: A `b x d`-dim tensor of the concentration0 parameter for the + Kumaraswamy distribution. + c1: A `b x d`-dim tensor of the concentration1 parameter for the + Kumaraswamy distribution. + eps: A small value that is used to ensure inputs are not 0 or 1. + + Returns: + A `batch_shape x n x d`-dim tensor of untransformed inputs. 
+ """ + X_range = 1 - 2 * eps + # unnormalize from [eps, 1-eps] to [0,1] + untf_X = (1 - (1 - X).pow(1 / c0.unsqueeze(1))).pow(1 / c1.unsqueeze(1)) + return ((untf_X - eps) / X_range).clamp(0.0, 1.0) + + +@torch.no_grad() +def get_batch( + batch_size, + seq_len, + num_features, + single_eval_pos, + hyperparameters=None, + n_targets_per_input=1, + print_infos=False, + **kwargs, +): + if hyperparameters is None: + hyperparameters = { + "lengthscale_loc_constant_add": sqrt(2), # same as in fully + "lengthscale_loc_feature_mul": 0.5, # same as in fully + "lengthscale_scale": sqrt(3), # same as in fully + "noise_var_loc": -4.0, # different in fully bayesian + "noise_var_scale": 1.0, # different in fully bayesian + "noise_var_dist": "lognormal", # different in fully bayesian, where it is "gamma" + "noise_var_concentration": 0.9, + "noise_var_rate": 10.0, + "mean_width": 2.0, + "mean_dist": "uniform", + "attsink_tokens": 0, + "noisy_predictions": False, + "sample_strategy": "uniform", + "train_normalized_y": False, + "style_for_max_on_border_likelihood": False, + "dummy_dim_prob": 0.0, + # Oversampling factor: build a super dataset of this factor times seq_len + "oversample_factor": 1.0, + # Proportion of final dataset sampled from the top-share (by non-noisy y) + "top_sampling_share": 0.0, + # The fraction of the super dataset considered as the top-share + "top_share_of_super": 0.1, + # input warping + "input_warping_prob": 0.0, + "input_warping_c0_std": 0.75**0.5, + "input_warping_c1_std": 0.75**0.5, + } + + # Build dataset, possibly oversampled if using top sampling + top_sampling_share = hyperparameters.get("top_sampling_share", 0.0) + + if top_sampling_share > 0.0: + oversample_factor = hyperparameters.get("oversample_factor", 1.0) + assert ( + oversample_factor > 1.0 + ), "oversample_factor must be > 1.0 when top_sampling_share > 0" + super_seq_len = round(seq_len * oversample_factor) + else: + assert ( + hyperparameters.get("oversample_factor", 1.0) == 1.0 + ), 
"oversample_factor must be 1.0 when top_sampling_share is 0" + super_seq_len = seq_len + + sample_clustered_x_hp = hyperparameters.get("sample_clustered_x", False) + + if sample_clustered_x_hp is True or sample_clustered_x_hp == "clustered": + x_super = sample_clustered_x( + batch_size, + super_seq_len, + num_features, + num_cluster_max=hyperparameters.get("num_cluster_max", 1), + max_std=hyperparameters.get("max_std", 0.25), + ) + elif sample_clustered_x_hp == "around_train_point": + x_super = sample_around_train_point( + batch_size, + super_seq_len, + num_features, + single_eval_pos, + ) + else: + x_super = torch.rand(batch_size, super_seq_len, num_features) + + mean_width = hyperparameters["mean_width"] + if mean_width == 0: + mean = torch.zeros(batch_size) + else: + mean_dist = hyperparameters.get("mean_dist", "uniform") + if mean_dist == "uniform": + min_mean, max_mean = -mean_width / 2, mean_width / 2 + mean = torch.rand(batch_size) * (max_mean - min_mean) + min_mean + elif mean_dist == "normal": + mean = torch.randn(batch_size) * mean_width / 2 + else: + raise ValueError(f"Unknown mean distribution {mean_dist}") + + if (dummy_dim_prob := hyperparameters.get("dummy_dim_prob", 0.0)) > 0.0: + num_important_features = 0 + while num_important_features == 0: + dummy_dims_mask = torch.bernoulli( + torch.full((num_features,), dummy_dim_prob) + ).bool() + used_dims_mask = ~dummy_dims_mask + num_important_features = used_dims_mask.sum() + else: + used_dims_mask = torch.ones(num_features, dtype=torch.bool) + num_important_features = num_features + + length_scales = LogNormalPrior( + loc=hyperparameters["lengthscale_loc_constant_add"] + + log(num_important_features) * hyperparameters["lengthscale_loc_feature_mul"], + scale=hyperparameters["lengthscale_scale"], + ).sample((batch_size, num_important_features)) + + kernel_name = hyperparameters.get("kernel", "rbf") + + def get_covar(length_scales, x): + if kernel_name == "rbf": + covar_module = RBFKernel( + 
batch_shape=torch.Size([batch_size]), + ard_num_dims=length_scales.shape[1], + ) + elif kernel_name == "matern_1.5": + covar_module = MaternKernel( + batch_shape=torch.Size([batch_size]), + ard_num_dims=length_scales.shape[1], + nu=1.5, + ) + elif kernel_name == "matern_2.5": + covar_module = MaternKernel( + batch_shape=torch.Size([batch_size]), + ard_num_dims=length_scales.shape[1], + nu=2.5, + ) + elif kernel_name == "linear": + covar_module = LinearKernel( + batch_shape=torch.Size([batch_size]), + ard_num_dims=length_scales.shape[1], + ) + else: + raise ValueError(f"Unknown kernel {kernel_name}") + if covar_module.has_lengthscale: + covar_module._set_lengthscale(length_scales) + covar = covar_module(x[..., used_dims_mask], x[..., used_dims_mask]) + return covar + + style = None + + if hyperparameters.get("additive", False): + num_features_in_group1 = torch.randint(0, num_features, tuple()).item() + perm = torch.randperm(num_features) + features_in_group1 = perm[:num_features_in_group1] + features_in_group0 = perm[num_features_in_group1:] + + covar0 = get_covar( + length_scales[:, features_in_group0], x_super[:, :, features_in_group0] + ) + covar1 = get_covar( + length_scales[:, features_in_group1], x_super[:, :, features_in_group1] + ) + + d0 = MultivariateNormal( + torch.ones_like(x_super[:, :, 0]) * mean[:, None], covar0 + ) + d1 = MultivariateNormal(torch.zeros_like(x_super[:, :, 0]), covar1) + y_super: torch.Tensor = d0.sample() + d1.sample() + style = torch.zeros(batch_size, num_features, 1) + style[:, features_in_group1, :] = 1.0 + style = (style * 2.0) - 1.0 + else: + covar = get_covar(length_scales, x_super) + d = MultivariateNormal(torch.ones_like(x_super[:, :, 0]) * mean[:, None], covar) + y_super: torch.Tensor = d.sample() + + # Select final dataset of length seq_len, possibly from top share of larger super dataset + device = x_super.device + x = torch.empty(batch_size, seq_len, num_features, device=device) + y = torch.empty(batch_size, seq_len, 
device=device) + + if top_sampling_share > 0.0: + top_share_of_super = hyperparameters.get("top_share_of_super", 0.1) + # Calculate sizes with bounds checking using max/min + top_k_count = min( + max(0, round(top_share_of_super * super_seq_len)), super_seq_len + ) + n_top = min( + max(0, round(top_sampling_share * seq_len)), min(top_k_count, seq_len) + ) + n_rest = seq_len - n_top + + if top_k_count > 0: + top_inds = torch.topk(y_super, k=top_k_count, largest=True).indices + else: + top_inds = torch.empty(batch_size, 0, dtype=torch.long, device=device) + + for b in range(batch_size): + # Get top indices by non-noisy y value + top_idx_b = top_inds[b] + + chosen_indices = [] + if n_top > 0 and top_idx_b.numel() > 0: + perm_top = torch.randperm(top_idx_b.numel(), device=device) + chosen_top = top_idx_b[perm_top[:n_top]] + chosen_indices.append(chosen_top) + + if n_rest > 0: + all_idx = torch.arange(super_seq_len, device=device) + if chosen_indices: + chosen_cat = torch.cat(chosen_indices) + mask = torch.ones(super_seq_len, dtype=torch.bool, device=device) + mask[chosen_cat] = False + remaining_idx = all_idx[mask] + else: + remaining_idx = all_idx + perm_rem = torch.randperm(remaining_idx.numel(), device=device) + chosen_rest = remaining_idx[perm_rem[:n_rest]] + chosen_indices.append(chosen_rest) + + # Combine and shuffle indices + final_idx = torch.cat(chosen_indices) + final_idx = final_idx[torch.randperm(final_idx.numel(), device=device)] + + x[b] = x_super[b, final_idx] + y[b] = y_super[b, final_idx] + else: + # No oversampling, just use the dataset as is + x = x_super + y = y_super + + if hyperparameters.get("style_for_max_on_border_likelihood", False): + max_i = y.max(1).indices # (B,) + max_x = x[torch.arange(len(max_i)), max_i] # (B,F) + mins = x.min(1).values # (B,F) + maxs = x.max(1).values # (B,F) + is_max_or_min_on_border = (max_x == mins) | (max_x == maxs) # (B,F) + + sureness = torch.rand(batch_size, num_features) # (B,F) + correct_hint = 
torch.bernoulli(sureness).bool() # (B,F) + border_style = ( + correct_hint * is_max_or_min_on_border * sureness + + ~correct_hint * is_max_or_min_on_border * (1 - sureness) + + correct_hint * ~is_max_or_min_on_border * (1 - sureness) + + ~correct_hint * ~is_max_or_min_on_border * sureness + )[:, :, None] # (B,F,1) + # for b in [0.1 * i for i in range(1, 10)]: + # mask = (border_style.flatten() < b) & (border_style.flatten() >= b - 0.1) + # print(b, is_max_or_min_on_border.flatten()[mask].float().mean(), mask.sum()) + # border_style = torch.zeros(batch_size, num_features, 1) + if style is None: + style = border_style + else: + style = torch.cat([style, border_style], dim=-1) + + noise_var_dist = hyperparameters.get("noise_var_dist", "lognormal") + if noise_var_dist == "lognormal": + noise_variance: torch.Tensor = LogNormalPrior( + loc=hyperparameters["noise_var_loc"], + scale=hyperparameters["noise_var_scale"], + ).sample((batch_size,)) + elif noise_var_dist == "gamma": + noise_variance: torch.Tensor = ( + torch.distributions.Gamma( + hyperparameters["noise_var_concentration"], + hyperparameters["noise_var_rate"], + ).sample((batch_size,)) + + 1e-4 + ) + else: + raise ValueError(f"Unknown noise variance distribution {noise_var_dist}") + + noisy_y = y + torch.randn_like(y) * noise_variance[:, None] ** (1 / 2) + + if hyperparameters["noisy_predictions"]: + target_y = y.view(batch_size, seq_len, 1) + torch.randn( + batch_size, seq_len, n_targets_per_input + ) * torch.sqrt(noise_variance[:, None, None]) + else: + target_y = y.view(batch_size, seq_len, 1).repeat(1, 1, n_targets_per_input) + + train_normalized_y = hyperparameters.get("train_normalized_y", False) + + if hyperparameters.get("predict_advantage", False): + assert not train_normalized_y + target_y = ( + target_y[:, :, :] + - target_y[:, :single_eval_pos, :] + .max(dim=-1, keepdim=True) + .values.max(dim=-2, keepdim=True) + .values + ) + + if train_normalized_y: + if single_eval_pos <= 1: + raise 
ValueError("train_normalized_y requires single_eval_pos > 1") + + train_mean = noisy_y[:, :single_eval_pos].mean(dim=1, keepdim=True) + train_std = noisy_y[:, :single_eval_pos].std(dim=1, keepdim=True) + noisy_y = (noisy_y - train_mean) / train_std + target_y = (target_y - train_mean[..., None]) / train_std[..., None] + + # Apply input warping if specified + input_warping_prob = hyperparameters.get("input_warping_prob", 0.0) + if input_warping_prob > 0.0: + c0_std = hyperparameters.get("input_warping_c0_std", 0.75**0.5) + c1_std = hyperparameters.get("input_warping_c1_std", 0.75**0.5) + + # Sample c0 and c1 parameters from LogNormal distributions for each batch + c0 = LogNormalPrior(loc=0.0, scale=c0_std).sample((batch_size, num_features)) + c1 = LogNormalPrior(loc=0.0, scale=c1_std).sample((batch_size, num_features)) + + no_warping_mask = torch.rand(batch_size, num_features) > input_warping_prob + c0[no_warping_mask] = 1.0 + c1[no_warping_mask] = 1.0 + + # Apply inverse Kumaraswamy warping to inputs with per-batch parameters + x = inv_kumaraswamy_warp(x, c0, c1) + + # set ys to nan in training set + number_of_y_hidden = torch.randint( + 0, hyperparameters.get("max_num_hidden_y", 0) + 1, tuple() + ) + noisy_y[:, single_eval_pos - number_of_y_hidden : single_eval_pos] = torch.nan + + if print_infos: + import pprint + + infos = { + "lengthscales": length_scales, + "noise_variances": noise_variance, + "means": mean, + "num_important_features": num_important_features, + "kernel": kernel_name, + } + pprint.pprint(infos) + + b = Batch(x=x, y=noisy_y, target_y=target_y, style=style) + return b diff --git a/pfns/priors/singletaskgp_x_only.py b/pfns/priors/singletaskgp_x_only.py new file mode 100644 index 0000000..ccaf8f0 --- /dev/null +++ b/pfns/priors/singletaskgp_x_only.py @@ -0,0 +1,281 @@ +from math import log, sqrt + +import torch + +from gpytorch.distributions.multivariate_normal import MultivariateNormal +from gpytorch.kernels import LinearKernel, MaternKernel, 
def sample_clustered_x(batch_size, seq_len, num_features, pad_factor: int = 10):
    """Sample clustered inputs in [0, 1]^num_features.

    Draws points from per-(batch, feature) normal distributions, oversampling
    by ``pad_factor`` and preferring in-range samples: valid values in [0, 1]
    are moved to the front via a stable sort before the first ``seq_len``
    entries are kept. Remaining out-of-range values are clamped to [0, 1].

    Returns:
        Tensor of shape (batch_size, seq_len, num_features) in [0, 1].
    """
    mean = torch.rand(batch_size, num_features)
    std = torch.rand(batch_size, num_features) / 4.0

    x = (
        torch.randn(batch_size, num_features, seq_len * pad_factor) * std[:, :, None]
        + mean[:, :, None]
    )
    sorting_x = (x >= 0.0) & (x <= 1.0)
    order = torch.argsort(sorting_x, dim=-1, stable=True, descending=True)
    x = x.gather(dim=-1, index=order[:, :, :seq_len])
    x = x.transpose(1, 2)
    x = x.clamp(0, 1)
    return x


@torch.no_grad()
def get_batch(
    batch_size,
    seq_len,
    num_features,
    single_eval_pos,
    hyperparameters=None,
    n_targets_per_input=1,
    **kwargs,
):
    """Sample GP datasets for the x-only PFN training mode.

    In this mode the model receives train rows as (x, noisy_y) vectors and
    predicts masked entries of test rows, which are partitioned into three
    cases: (1) y shown, a random feature subset hidden; (2) all features
    shown, y (EI target) hidden; (3, optional) predict the EI-maximizing
    point feature-by-feature. Hidden entries are encoded as NaN.

    Args:
        batch_size: Number of independent datasets.
        seq_len: Total number of points (train + test) per dataset.
        num_features: Input dimensionality.
        single_eval_pos: Split index between train and test points.
        hyperparameters: Prior hyperparameters; ``None`` selects defaults.
        n_targets_per_input: Must be 1 in x-only mode.
        **kwargs: Ignored; accepted for get_batch interface compatibility.

    Returns:
        A ``Batch`` with ``x`` (train rows, (B, single_eval_pos, F+1)),
        ``test_x`` and ``target`` ((B, test_len, F+1), NaN-masked), and
        ``y`` / ``target_y`` set to None.
    """
    if hyperparameters is None:
        hyperparameters = {
            "lengthscale_loc_constant_add": sqrt(2),  # same as in fully
            "lengthscale_loc_feature_mul": 0.5,  # same as in fully
            "lengthscale_scale": sqrt(3),  # same as in fully
            "noise_var_loc": -4.0,  # different in fully bayesian
            "noise_var_scale": 1.0,  # different in fully bayesian
            "noise_var_dist": "lognormal",  # different in fully bayesian, where it is "gamma"
            "noise_var_concentration": 0.9,
            "noise_var_rate": 10.0,
            "mean_width": 2.0,
            "mean_dist": "uniform",
            "attsink_tokens": 0,
            "noisy_predictions": False,
            "sample_clustered_x": False,
            "train_normalized_y": False,
            "style_for_max_on_border_likelihood": False,
            "dummy_dim_prob": 0.0,
        }

    assert not hyperparameters.get(
        "noisy_predictions", False
    ), "noisy_predictions is not supported for x_only mode"

    # Inputs: clustered train points with uniform test points, or all uniform.
    if hyperparameters["sample_clustered_x"]:
        x_train = sample_clustered_x(batch_size, single_eval_pos, num_features)
        x_test = torch.rand(batch_size, seq_len - single_eval_pos, num_features)
        x = torch.cat([x_train, x_test], dim=1)
    else:
        x = torch.rand(batch_size, seq_len, num_features)

    # Per-dataset GP mean.
    mean_width = hyperparameters["mean_width"]
    if mean_width == 0:
        mean = torch.zeros(batch_size)
    else:
        mean_dist = hyperparameters.get("mean_dist", "uniform")
        if mean_dist == "uniform":
            min_mean, max_mean = -mean_width / 2, mean_width / 2
            mean = torch.rand(batch_size) * (max_mean - min_mean) + min_mean
        elif mean_dist == "normal":
            mean = torch.randn(batch_size) * mean_width / 2
        else:
            raise ValueError(f"Unknown mean distribution {mean_dist}")

    # Optionally ignore ("dummy") some features; resample until at least one
    # feature remains important.
    if (dummy_dim_prob := hyperparameters.get("dummy_dim_prob", 0.0)) > 0.0:
        num_important_features = 0
        while num_important_features == 0:
            dummy_dims_mask = torch.bernoulli(
                torch.full((num_features,), dummy_dim_prob)
            ).bool()
            used_dims_mask = ~dummy_dims_mask
            num_important_features = used_dims_mask.sum()
    else:
        used_dims_mask = torch.ones(num_features, dtype=torch.bool)
        num_important_features = num_features

    # ARD lengthscales; location grows with log(#important features).
    length_scales = LogNormalPrior(
        loc=hyperparameters["lengthscale_loc_constant_add"]
        + log(num_important_features) * hyperparameters["lengthscale_loc_feature_mul"],
        scale=hyperparameters["lengthscale_scale"],
    ).sample((batch_size, num_important_features))

    kernel_name = hyperparameters.get("kernel", "rbf")

    def get_covar(length_scales, x):
        # Build the configured GPyTorch kernel and evaluate it on the
        # important (non-dummy) feature subset only.
        if kernel_name == "rbf":
            covar_module = RBFKernel(
                batch_shape=torch.Size([batch_size]),
                ard_num_dims=length_scales.shape[1],
            )
        elif kernel_name == "matern_1.5":
            covar_module = MaternKernel(
                batch_shape=torch.Size([batch_size]),
                ard_num_dims=length_scales.shape[1],
                nu=1.5,
            )
        elif kernel_name == "matern_2.5":
            covar_module = MaternKernel(
                batch_shape=torch.Size([batch_size]),
                ard_num_dims=length_scales.shape[1],
                nu=2.5,
            )
        elif kernel_name == "linear":
            covar_module = LinearKernel(
                batch_shape=torch.Size([batch_size]),
                ard_num_dims=length_scales.shape[1],
            )
        else:
            raise ValueError(f"Unknown kernel {kernel_name}")
        if covar_module.has_lengthscale:
            covar_module._set_lengthscale(length_scales)
        covar = covar_module(x[..., used_dims_mask], x[..., used_dims_mask])
        return covar

    covar = get_covar(length_scales, x)
    d = MultivariateNormal(torch.ones_like(x[:, :, 0]) * mean[:, None], covar)
    y: torch.Tensor = d.sample()

    # Observation-noise variance, one scalar per dataset.
    noise_var_dist = hyperparameters.get("noise_var_dist", "lognormal")
    if noise_var_dist == "lognormal":
        noise_variance: torch.Tensor = LogNormalPrior(
            loc=hyperparameters["noise_var_loc"],
            scale=hyperparameters["noise_var_scale"],
        ).sample((batch_size,))
    elif noise_var_dist == "gamma":
        noise_variance: torch.Tensor = (
            torch.distributions.Gamma(
                hyperparameters["noise_var_concentration"],
                hyperparameters["noise_var_rate"],
            ).sample((batch_size,))
            + 1e-4
        )
    else:
        raise ValueError(f"Unknown noise variance distribution {noise_var_dist}")

    noisy_y = y + torch.randn_like(y) * noise_variance[:, None] ** (1 / 2)
    noisy_y = noisy_y.view(batch_size, seq_len, 1)

    assert n_targets_per_input == 1, "n_targets_per_input must be 1 for x_only mode"
    target_y = y.view(batch_size, seq_len, 1).repeat(1, 1, n_targets_per_input)

    if hyperparameters.get("train_normalized_y", False):
        if single_eval_pos <= 1:
            raise ValueError("train_normalized_y requires single_eval_pos > 1")

        # Standardize with statistics of the (noisy) train context only.
        # noisy_y and target_y are both (B, S, 1) here, so the (B, 1, 1)
        # statistics broadcast directly. BUGFIX: the previous code used
        # train_mean[..., None] / train_std[..., None] (shapes (B, 1, 1, 1)),
        # which broadcast target_y up to (B, B, S, 1) and corrupted all
        # downstream slicing; that extra axis was only needed in the sibling
        # prior where noisy_y stays 2-D.
        train_mean = noisy_y[:, :single_eval_pos].mean(dim=1, keepdim=True)
        train_std = noisy_y[:, :single_eval_pos].std(dim=1, keepdim=True)
        noisy_y = (noisy_y - train_mean) / train_std
        target_y = (target_y - train_mean) / train_std

    # we should hide the y from time to time in training but still incorporate it to compute EI
    # that is exactly what we need for batch EI, I believe
    # then we simply do EI and then condition on the point without passing y again
    number_of_y_hidden = torch.randint(
        0, hyperparameters.get("max_num_hidden_y", 0) + 1, tuple()
    )
    noisy_y[:, single_eval_pos - number_of_y_hidden : single_eval_pos, :] = torch.nan
    full_train_x = torch.cat([x, noisy_y], dim=2)[:, :single_eval_pos, :]

    # LETS GET TO THE TEST PART

    # ei values
    # target_y shape: batch_size, seq_len, 1
    if hyperparameters.get("predict_ei", True):
        # Improvement of each test target over the best train target.
        target_y = (
            target_y[:, single_eval_pos:].squeeze(-1)
            - target_y[:, :single_eval_pos, :]
            .squeeze(-1)
            .max(dim=-1, keepdim=True)
            .values
        )
    else:
        target_y = target_y[:, single_eval_pos:].squeeze(-1)

    # ei shape: batch_size, test size

    full_test_x = torch.cat([x, noisy_y], dim=2)[:, single_eval_pos:, :]
    full_target = torch.cat([x[:, single_eval_pos:], target_y.unsqueeze(-1)], dim=2)

    # we need three partitions of test
    # 1. y shown, predict subset of features (rest shown)
    # 2. y hidden, all features shown, predict y
    # 3. y hidden, predict subset of features that maximize EI

    batch_size, test_size, num_features_plus_1 = full_target.shape
    # let's do case 3 first
    if hyperparameters.get("predict_maximizer", False):
        case_3_test_size = num_features  # all empty to full - 1
        # NOTE(review): assumes test_size >= num_features — no guard here.
        num_1_and_2_test_points = test_size - case_3_test_size

        # Features of the test point with maximal (EI) target.
        max_target_index = target_y.argmax(dim=-1)
        max_ei_features = x[
            torch.arange(batch_size), max_target_index + single_eval_pos, :
        ]
        max_ei_features_and_nan_for_y = torch.cat(
            [max_ei_features, torch.full((batch_size, 1), torch.nan)], dim=1
        )

        case_3_test_x = max_ei_features_and_nan_for_y.unsqueeze(1).repeat(
            1, case_3_test_size, 1
        )  # shape: batch_size, case_3_test_size, num_features + 1
        case_3_target = case_3_test_x.clone()
        # Row i hides the first i+1 features in the input and predicts them;
        # the remaining entries stay visible (and are not targets).
        for i in range(case_3_test_size):
            case_3_test_x[:, i, : i + 1] = torch.nan
            case_3_target[:, i, i + 1 :] = torch.nan

        # add case 3 to the tensors
        full_test_x[:, num_1_and_2_test_points:, :] = case_3_test_x
        full_target[:, num_1_and_2_test_points:, :] = case_3_target
    else:
        num_1_and_2_test_points = test_size

    num_1_test_points = round(
        hyperparameters.get("y_conditioned_share", 0.5) * num_1_and_2_test_points
    )
    num_2_test_points = num_1_and_2_test_points - num_1_test_points

    if num_1_test_points > 0:
        # Case 1: y shown, a random subset of features hidden and predicted.
        target_mask = torch.ones_like(
            full_target[:, :num_1_test_points, :], dtype=torch.bool
        )
        # show targets for case 1
        target_mask[:, :, -1] = False  # not target but shown

        # for features, sample uniformly 0 to all
        assert full_target.shape[2] == num_features + 1
        num_shown_features = torch.randint(
            0, num_features, (batch_size, num_1_test_points)
        )

        for i in range(batch_size):
            for j in range(num_1_test_points):
                shown_features_for_example = num_shown_features[i, j]
                shown_features = torch.randperm(num_features + 1)[
                    :shown_features_for_example
                ]
                target_mask[i, j, shown_features] = (
                    False  # not in the target_mask anymore
                )

        assert not target_mask[:, :, -1].any()  # always predict the y target

        # set inputs that are hidden to nan
        full_test_x[:, :num_1_test_points, :][target_mask] = torch.nan

        # copy over the targets, such that we don't condition on noisy values
        full_test_x[:, :num_1_test_points, -1] = full_target[:, :num_1_test_points, -1]

        full_target[:, :num_1_test_points, :][~target_mask] = torch.nan

    # finally let's do case 2
    if num_2_test_points > 0:
        # y hidden, all features shown, predict ei
        full_test_x[:, num_1_test_points:num_1_and_2_test_points, -1] = torch.nan
        full_target[:, num_1_test_points:num_1_and_2_test_points, :-1] = torch.nan

    return Batch(
        x=full_train_x, test_x=full_test_x, target=full_target, y=None, target_y=None
    )
def get_config(config_index: int):
    """Build the ``MainConfig`` for the ``config_index``-th hyperparameter combo.

    Materializes the ``config_index``-th entry of the module-level
    ``config_dicts`` grid, samples target values from the ``small_peaks``
    prior to compute bar-distribution bucket borders, and assembles the full
    training configuration.

    Args:
        config_index: Index into the ``product_dict`` grid of hyperparameter
            combinations.

    Returns:
        A fully populated ``MainConfig`` ready for training.

    Note:
        This function draws random batches from the prior to estimate bucket
        borders, so it is non-deterministic across calls and prints progress
        and debug information to stdout.
    """
    config_dict = list(config_dicts)[config_index]

    emsize = config_dict["emsize"]
    epochs = config_dict["epochs"]
    lr = config_dict["lr"]
    nlayers = config_dict["nlayers"]
    batch_size = config_dict["batch_size"]
    num_workers = config_dict["num_workers"]
    max_seq_len = config_dict["max_seq_len"]

    steps_per_epoch = 1000
    num_features = 2
    hyperparameters = {}

    def get_prior_config(plotting=False):
        # Prior over small_peaks functions; plotting uses a single feature.
        prior_config = AdhocPriorConfig(
            prior_names=["small_peaks"],
            prior_kwargs={
                "num_features": 1 if plotting else num_features,
                "hyperparameters": {**hyperparameters},
                "batch_size_per_gp_sample": config_dict["batch_size_per_gp_sample"],
            },
        )
        return prior_config, hyperparameters

    prior_config, hps = get_prior_config()

    gb = prior_config.create_get_batch_method()

    # Sample target values across feature counts 1..num_features-1 (200 reps
    # each) to estimate the empirical y distribution for the bucket borders.
    ys = []
    for nf in tqdm(list(range(1, num_features)) * 200):
        ys.append(gb(batch_size=128, seq_len=1000, num_features=nf).target_y.flatten())

    ys = torch.cat(ys)
    print(f"{len(ys)=} for {config_dict['num_buckets']=}")

    borders = bar_distribution.get_bucket_borders(config_dict["num_buckets"], ys=ys)

    print(f"{borders=}")

    return MainConfig(
        priors=[prior_config],
        optimizer=OptimizerConfig("adamw", lr=lr, weight_decay=0.0),
        scheduler="cosine_decay",
        model=TransformerConfig(
            criterion=bar_distribution.BarDistributionConfig(
                borders.tolist(), full_support=True
            ),
            emsize=emsize,
            nhead=emsize // 32,
            nhid=emsize * 4,
            nlayers=nlayers,
            # 0.5 / 1/sqrt(12) are mean and std of U[0, 1]: inputs are
            # normalized as if uniformly distributed on the unit interval.
            encoder=EncoderConfig(
                variable_num_features_normalization=True,
                constant_normalization_mean=0.5,
                constant_normalization_std=1 / math.sqrt(12),
                hidden_size=config_dict["encoder_hidden_size"],
            ),
            y_encoder=EncoderConfig(
                nan_handling=True,
                constant_normalization_mean=0.5,
                constant_normalization_std=1 / math.sqrt(12),
                hidden_size=config_dict["encoder_hidden_size"],
            ),
            attention_between_features=True,
        ),
        batch_shape_sampler=BatchShapeSamplerConfig(
            batch_size=batch_size,
            max_seq_len=max_seq_len,
            fixed_num_test_instances=10,
            max_num_features=num_features,
        ),
        epochs=epochs,
        warmup_epochs=epochs // 10,
        steps_per_epoch=steps_per_epoch,
        num_workers=num_workers,
        train_mixed_precision=True,
        verbose=True,
    )
instance, got {config.__class__.__name__}" + ) + def get_filename(config_file): return f"{config_file.split('/')[-1].split('.')[0]}" diff --git a/pfns/train.py b/pfns/train.py index d9926bb..283e8ca 100644 --- a/pfns/train.py +++ b/pfns/train.py @@ -1,14 +1,12 @@ from __future__ import annotations import importlib - import os import time import typing as tp from contextlib import nullcontext from dataclasses import dataclass -import einops import torch from torch import nn from torch.amp import autocast, GradScaler @@ -19,9 +17,7 @@ from .batch_shape_sampler import BatchShapeSamplerConfig from .model.transformer_config import TransformerConfig from .optimizer import OptimizerConfig - from .priors import data_loading, prior, utils as priors_utils - from .training_utils import ( Metrics, move_style_and_check_shape, @@ -383,6 +379,10 @@ def create_get_batch_method(priors: tp.List[prior.PriorConfig] | None): data_loader = None if writer: writer.close() + # Clean up distributed training + if using_dist: + torch.distributed.destroy_process_group() + return { "total_loss": total_loss, "model": model.to("cpu"), @@ -417,6 +417,10 @@ def train_or_evaluate_epoch( metrics = Metrics(steps_per_epoch=len(dl)) + # Whether the prior does not return y, but but instead uses + # a separate x_test and target. 
+ x_only_mode = c.model.x_only_mode + importance_sampling_infos = [] before_get_batch = time.time() @@ -433,11 +437,11 @@ def train_or_evaluate_epoch( for batch_index, batch in enumerate(dl): batch: prior.Batch = batch # for IDE support # batch.x.shape == (batch_size, seq_len, num_features) - if not c.model.attention_between_features: - assert ( - c.model.features_per_group == batch.x.shape[2] - ), "features_per_group must match the number of features in the input, if attention_between_features is False" - targets = batch.target_y.to(device) + + if x_only_mode: + targets = batch.target.to(device) + else: + targets = batch.target_y.to(device) single_eval_pos = batch.single_eval_pos if tqdm_iter is not None: @@ -460,33 +464,49 @@ def train_or_evaluate_epoch( before_forward = time.time() try: with autocast(device.split(":")[0], enabled=scaler is not None): - output = model( - x=batch.x.to(device), - y=batch.y[:, :single_eval_pos].to(device), - style=move_style_and_check_shape(batch.style, batch.x, device), - y_style=move_y_style_and_check_shape( - batch.y_style, batch.y, device - ), - only_return_standard_out=True, - ) # shape: (batch_size, test_len) + if x_only_mode: + assert ( + batch.target_y is None + and batch.y is None + and batch.y_style is None + ), "model.x_only_mode is not supported when y, target_y, or y_style are not None" + output = model( + x=batch.x.to( + device + ), # shape: (batch_size, train_len, num_features) + test_x=batch.test_x.to( + device + ), # shape: (batch_size, test_len, num_features) + y=None, + style=move_style_and_check_shape( + batch.style, batch.x, device + ), + only_return_standard_out=True, + ) # shape: (batch_size, test_len, num_groups) + else: + output = model( + x=batch.x.to(device), + y=batch.y[:, :single_eval_pos].to(device), + style=move_style_and_check_shape( + batch.style, batch.x, device + ), + y_style=move_y_style_and_check_shape( + batch.y_style, batch.y, device + ), + only_return_standard_out=True, + ) # shape: 
(batch_size, test_len) forward_time = time.time() - before_forward - if single_eval_pos is not None: + if single_eval_pos is not None and not x_only_mode: targets = targets[ :, single_eval_pos: ] # shape: (batch_size, test_len) - losses = compute_losses( - output, targets, criterion, c.n_targets_per_input - ) # shape: (batch_size, test_len) + loss, nan_share = compute_loss( + output, targets, criterion, c.n_targets_per_input, x_only_mode + ) # shape: (batch_size, test_len) | (batch_size, test_len, n_features) - loss, nan_share = utils.torch_nanmean( - losses.mean( - 1 - ), # loss per sequence without nanmean, if any loss in a sequence is nan, the whole sequence is ignored - return_nanshare=True, - ) # loss and nan_share are both scalar tensors loss_scaled = loss / c.aggregate_k_gradients if scaler: @@ -576,34 +596,38 @@ def train_or_evaluate_epoch( return metrics.get_epoch_result(importance_sampling_infos) -def compute_losses( +def compute_loss( output: torch.Tensor, targets: torch.Tensor, criterion: torch.nn.Module, n_targets_per_input: int, + x_only_mode: bool, ): """ Compute the losses for the given output and targets. Args: - output: The output of the model, shape (batch_size, num_eval_positions, n_out) - targets: The targets, shape (batch_size, num_eval_positions[, n_targets_per_input]) + output: The output of the model, shape (batch_size, num_eval_positions, n_out) | (batch_size, num_eval_positions, n_features, n_out) + targets: The targets, shape (batch_size, num_eval_positions[, n_targets_per_input]) | (batch_size, num_eval_positions, n_features, n_targets_per_input) criterion: The criterion to use. n_targets_per_input: The number of targets per input. 
Returns: The losses, shape (batch_size, num_eval_positions) """ - # Repeat output in the semi-last dimension n_targets_per_input times - output = output.unsqueeze(2).expand( - *output.shape[:2], - n_targets_per_input, - output.shape[-1], - ) + if ( + len(output.shape) == 3 + ): # else it is (batch_size, num_eval_positions, n_features, n_out) + # Repeat output in the semi-last dimension n_targets_per_input times + output = output.unsqueeze(2).expand( + *output.shape[:2], + n_targets_per_input, + output.shape[-1], + ) - if len(targets.shape) == 2: - # This implies we only have a single target per input - targets = targets.unsqueeze(2) + if len(targets.shape) == 2: + # This implies we only have a single target per input + targets = targets.unsqueeze(2) assert targets.shape == output.shape[:-1], ( f"Target shape {targets.shape} " @@ -612,9 +636,6 @@ def compute_losses( "1 dimension in the target." ) - output = einops.rearrange(output, "b s t l -> (b t) s l") - targets = einops.rearrange(targets, "b s t -> (b t) s") - if isinstance(criterion, nn.GaussianNLLLoss): assert ( output.shape[-1] == 2 @@ -639,9 +660,21 @@ def compute_losses( ) else: losses = criterion(output, targets.unsqueeze(-1)) - losses = einops.rearrange(losses, "(b t) s -> b s t", t=n_targets_per_input) - losses = losses.mean(-1) - return losses + # mean over the last dimension (either features or target repetitions) + if x_only_mode: + loss = losses[~torch.isnan(targets)].mean() + nan_share = torch.tensor(0.0) + else: + losses = losses.mean(-1) + + loss, nan_share = utils.torch_nanmean( + losses.mean( + 1 + ), # loss per sequence without nanmean, if any loss in a sequence is nan, the whole sequence is ignored + return_nanshare=True, + ) # loss and nan_share are both scalar tensors + + return loss, nan_share def should_load_checkpoint( diff --git a/pfns/utils.py b/pfns/utils.py index 549d1fd..233d172 100644 --- a/pfns/utils.py +++ b/pfns/utils.py @@ -7,9 +7,8 @@ import re import numpy as np - import torch 
-from torch import nn +from torch import distributed as dist, nn from torch.optim.lr_scheduler import LambdaLR @@ -302,20 +301,31 @@ def print(*args, **kwargs): def init_dist(device): print("init dist") - if "LOCAL_RANK" in os.environ: - # launched with torch.distributed.launch + if "WORLD_SIZE" in os.environ: + dist.init_process_group(backend="nccl") + distributed = dist.is_available() and dist.is_initialized() + rank = dist.get_rank() if distributed else 0 + + # Get local rank from environment variable (set by torchrun) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) if distributed else 0 + torch.cuda.set_device(local_rank) + print("Initialized rank", rank) + print_on_master_only(rank == 0) + return distributed, rank, f"cuda:{local_rank}" + elif "LOCAL_RANK" in os.environ: + # launched with dist.launch rank = int(os.environ["LOCAL_RANK"]) - print("torch.distributed.launch and my rank is", rank) + print("dist.launch and my rank is", rank) torch.cuda.set_device(rank) os.environ["CUDA_VISIBLE_DEVICES"] = str(rank) - torch.distributed.init_process_group( + dist.init_process_group( backend="nccl", init_method="env://", timeout=datetime.timedelta(seconds=20), world_size=torch.cuda.device_count(), rank=rank, ) - torch.distributed.barrier() + dist.barrier() print_on_master_only(rank == 0) print( f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, " @@ -333,14 +343,14 @@ def init_dist(device): torch.cuda.set_device(rank) # os.environ["CUDA_VISIBLE_DEVICES"] = str(rank) print("distributed submitit launch and my rank is", rank) - torch.distributed.init_process_group( + dist.init_process_group( backend="nccl", init_method="env://", timeout=datetime.timedelta(seconds=20), world_size=torch.cuda.device_count(), rank=rank, ) - torch.distributed.barrier() + dist.barrier() print_on_master_only(rank == 0) print( f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, " diff --git a/rl/function_sampler.py 
class FunctionSamplerConfig(BaseConfig, metaclass=ABCMeta):
    """Abstract config for objects that sample random objective functions.

    Concrete subclasses implement :meth:`function_sampler`, which returns a
    batched callable evaluating the sampled functions.
    """

    @property
    def restricts_sampling_points(self) -> bool:
        """Indicates whether this sampler restricts which points can be sampled.

        If True, the sampler's callable will have a `get_candidate_points` method
        that returns the available candidate points for each batch element.
        """
        return False

    @abstractmethod
    def function_sampler(
        self, batch_size: int, num_features: int = 1, device: str = "cpu"
    ) -> Callable[
        [Tensor], Tensor
    ]:  # going from tensor of shape (batch_size, n) to (batch_size, n)
        """Return a callable that evaluates ``batch_size`` sampled functions.

        Args:
            batch_size: Number of independent functions to sample.
            num_features: Input dimensionality of each function.
            device: Device the returned callable operates on.
        """
        pass
lengthscale_scale: float = sqrt(3) + mean_width: float = 2.0 + + # Dummy dimension configuration + # If set, sample number of non-dummy dimensions from [min, max] range + # E.g., (1, 3) means 1-3 dimensions are non-dummy, rest are ignored + dummy_dim_sample_non_dummy_range: Optional[tuple[int, int]] = None + # Probability of applying dummy_dim_sample_non_dummy_range logic. + # If not applied, all dimensions are non-dummy. + dummy_dim_sample_non_dummy_range_prob: float = 1.0 + + # Gap discontinuity configuration + gap_max_splits: int = 0 # Max axis-aligned splits (0 = disabled) + gap_prob: float = 1.0 # Probability of applying gaps + # Add to lengthscale_loc_constant_add when gaps applied + gap_lengthscale_add: float = 0.0 + # Probability of applying lengthscale adjustment + gap_lengthscale_add_prob: float = 1.0 + + def _sample_noise_variance(self, batch_size: int) -> torch.Tensor: + if self.noise_var_dist == "lognormal": + return LogNormalPrior( + loc=self.noise_var_loc, + scale=self.noise_var_scale, + ).sample((batch_size,)) + elif self.noise_var_dist == "gamma": + return ( + torch.distributions.Gamma( + self.noise_variance_gamma_concentration, + self.noise_variance_gamma_rate, + ).sample((batch_size,)) + + 1e-4 + ) + else: + raise ValueError( + f"Unknown noise variance distribution {self.noise_var_dist}" + ) + + @torch.no_grad() + def function_sampler(self, batch_size, num_features=1, device="cpu", seed=None): + hyperparameters = { + "use_rbf_kernel": self.use_rbf_kernel, + "lengthscale_loc_constant_add": self.lengthscale_loc_constant_add, + "lengthscale_loc_feature_mul": self.lengthscale_loc_feature_mul, + "lengthscale_scale": self.lengthscale_scale, + "mean_width": self.mean_width, + "dummy_dim_sample_non_dummy_range": self.dummy_dim_sample_non_dummy_range, + "dummy_dim_sample_non_dummy_range_prob": self.dummy_dim_sample_non_dummy_range_prob, + "gap_max_splits": self.gap_max_splits, + "gap_prob": self.gap_prob, + "gap_lengthscale_add": self.gap_lengthscale_add, 
+ "gap_lengthscale_add_prob": self.gap_lengthscale_add_prob, + } + paths = sample_paths(batch_size, num_features, hyperparameters) + + noise_variance: torch.Tensor = self._sample_noise_variance(batch_size) + + @torch.no_grad() + def noisy_eval(batch_inputs, independent_noise=False): + # Calculate a spike function: resembles a triangle with peak at 1 + noiseless_outputs = paths(batch_inputs.cpu())[0] + if independent_noise: + noise = ( + torch.randn(batch_inputs.shape[:-1]) + * noise_variance[:, None] ** (1 / 2) + ).squeeze(-1) + else: + noise = (torch.randn(batch_size) * noise_variance ** (1 / 2))[:, None] + outputs = noiseless_outputs + noise + + return noiseless_outputs.to(device), outputs.to(device) + + return noisy_eval diff --git a/rl/run_training_cli.py b/rl/run_training_cli.py new file mode 100644 index 0000000..b80bade --- /dev/null +++ b/rl/run_training_cli.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +""" +Command-line interface for training PFNs models. +""" + +import argparse +import io +import os +import sys +from functools import partial +from pathlib import Path + +import pfns.run_training_cli as original_cli +import pfns.train +import torch + +from manifold.clients.python import ManifoldClient + +from .discrete_eval import evaluate_bo_on_hpob +from .discrete_pfns_bayesopt import get_acquisition_values_pfn + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Train a PFNs model using configuration from a Python file", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "config_file", + type=str, + help="Path to the Python configuration file that defines a 'config' variable or `get_config` function. Path is relative to PFNs/fb.", + ) + + parser.add_argument( + "--device", + type=str, + default=None, + help="Device to use for training (e.g., 'cuda:0', 'cpu', 'mps'). 
If not specified, will auto-detect cuda, but not mps.", + ) + + parser.add_argument( + "--checkpoint-save-load-prefix", + type=str, + default=None, + help="Path to save/load checkpoint and for tensorboard.", + ) + + parser.add_argument( + "--checkpoint-save-load-suffix", + type=str, + default="", + help="Suffix to add to the checkpoint save/load path. This can e.g. be the seed.", + ) + + parser.add_argument( + "--tensorboard-path", + type=str, + default=None, + help=( + "Path to save tensorboard. If not provided, will use the " + "checkpoint save/load prefix or the path in the config file." + ), + ) + + parser.add_argument( + "--config-index", + type=int, + default=0, + help="Index of the config to use. This is used to select a config from the config file.", + ) + + return parser.parse_args() + + +def manifold_load(path: str, map_location: str | None = None) -> object: + """ + A wrapper around torch.load with the same API. + Loads from manifold instead of local disk, though. + + Args: + path: The path to the file to load. The path has the format: manifold:///. + map_location: The device to load the tensors to. Same as torch.load, e.g. "cpu" or "cuda:0". + + Returns: + The loaded object. + """ + + with ManifoldClient.get_client("ae_generic") as client: + stream = io.BytesIO() + client.sync_get(path, stream) + stream.seek(0) + return torch.load(stream, map_location=map_location, weights_only=True) + + +def manifold_exists(path: str) -> bool: + """ + A replacement for os.path.exists that works for manifold paths. + + Args: + path: The path to check. The path has the format: manifold:///. + + Returns: + True if the path exists, False otherwise. + """ + + with ManifoldClient.get_client("ae_generic") as client: + return client.sync_exists(path) + + +def manifold_save(obj, path: str): + """ + A wrapper around torch.save with the same API that saves to manifold. + + Args: + obj: The object to save. + path: The path to save the object to. 
def manifold_save(obj, path: str):
    """torch.save-compatible wrapper that writes to manifold.

    Args:
        obj: The object to save.
        path: Bucket-relative manifold destination path. Must contain a
            non-empty directory component.

    Returns:
        None
    """
    dir_path = os.path.dirname(path)

    assert dir_path != "", "dir_path must not be empty"

    with ManifoldClient.get_client("ae_generic") as client:
        # Manifold has no implicit directory creation, so make the parent first.
        if not client.sync_exists(dir_path):
            client.sync_mkdirs(dir_path)
            print("made path")

        stream = io.BytesIO()
        torch.save(obj, stream)
        stream.seek(0)
        client.sync_put(
            path, stream, predicate=ManifoldClient.Predicates.AllowOverwrite
        )


def main():
    """Main CLI entry point: load config, train, then evaluate on HPO-B."""
    args = parse_args()

    config_file = args.config_file
    config_file = config_file[3:] if config_file.startswith("fb/") else config_file

    # Load configuration from Python file
    config = original_cli.load_config_from_python(
        config_file, args.config_index, config_base_path=Path(__file__).parent
    )

    def get_filename(config_file):
        # Config-file stem is used as the checkpoint directory name.
        return Path(config_file).stem

    if args.checkpoint_save_load_suffix:
        assert (
            args.checkpoint_save_load_prefix is not None
        ), "checkpoint_save_load_prefix is required when checkpoint_save_load_suffix is provided"

    config_tensorboard_path_is_none = config.tensorboard_path is None

    # Override checkpoint paths if specified via CLI
    if args.checkpoint_save_load_prefix is not None:
        assert (
            config.train_state_dict_save_path is None
        ), "train_state_dict_save_path is already set"
        assert (
            config.train_state_dict_load_path is None
        ), "train_state_dict_load_path is already set"
        assert config_tensorboard_path_is_none, "tensorboard_path is already set"

        # Add suffix if it exists
        suffix = f"_{args.config_index}"
        if args.checkpoint_save_load_suffix:
            suffix += f"_{args.checkpoint_save_load_suffix}"

        path = f"{args.checkpoint_save_load_prefix}/{get_filename(config_file)}{suffix}"

        config = config.__class__(
            **{
                **config.__dict__,
                "train_state_dict_save_path": path + "/checkpoint.pt",
                "train_state_dict_load_path": path + "/checkpoint.pt",
                "tensorboard_path": "manifold://ae_generic/" + path + "/tensorboard",
            }
        )

    if args.tensorboard_path is not None:
        assert config_tensorboard_path_is_none, "tensorboard_path is already set"
        config = config.__class__(
            **{
                **config.__dict__,
                "tensorboard_path": args.tensorboard_path,
            }
        )

    # We overwrite the config with the one from the checkpoint if it exists
    # as there is some randomness in the config and we want to use the exact
    # same config again.
    if pfns.train.should_load_checkpoint(
        config, check_path_exists_function=manifold_exists
    ):
        config = pfns.train.load_config(
            config.train_state_dict_load_path, load_function=manifold_load
        )

    print("Starting training with configuration:")
    print(f"  Epochs: {config.epochs}")
    print(f"  Steps per epoch: {config.steps_per_epoch}")
    print(f"  Device: {args.device or 'auto-detect'}")
    print(f"  Mixed precision: {config.train_mixed_precision}")

    try:
        result = pfns.train.train(
            c=config,
            device=args.device,
            # overrides for filesystem things
            save_object_function=manifold_save,
            load_object_function=manifold_load,
            check_path_exists_function=manifold_exists,
        )
    except KeyboardInterrupt:
        print("\nTraining interrupted by user.")
        sys.exit(1)

    print("\nTraining completed successfully!")
    print(f"Total training time: {result['total_time']:.2f} seconds")
    print(f"Final loss: {result['total_loss']:.6f}")

    if config.train_state_dict_save_path is not None:
        print(f"Model saved to: {config.train_state_dict_save_path}")
        # run eval
        device = args.device
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        model = result["model"].to(device)
        acq_function = partial(get_acquisition_values_pfn, model=model, device=device)
        results = evaluate_bo_on_hpob(acq_function, verbose=True)
        # Bug fix: torch.save cannot open a "manifold://..." URI. Use
        # manifold_save with a bucket-relative path, consistent with how
        # checkpoints are written via save_object_function above.
        manifold_save(
            results,
            str(Path(config.train_state_dict_save_path).parent)
            + "/hpob_results.pt",
        )
__name__ == "__main__": + main() diff --git a/rl/train_rl.py b/rl/train_rl.py new file mode 100644 index 0000000..3b6ffcf --- /dev/null +++ b/rl/train_rl.py @@ -0,0 +1,2488 @@ +#!/usr/bin/env python3 +"""Command-line interface for running RL fine-tuning of PFN models.""" + +import copy +import math +import os +import random +import time +import typing as tp +from contextlib import nullcontext +from dataclasses import dataclass, fields, replace +from functools import partial + +import numpy as np +import torch +import torch.distributed as dist +from pfns import base_config +from pfns.model import transformer_config +from pfns.model.encoders import StyleEncoderConfig +from pfns.model.transformer import TableTransformer +from pfns.priors.utils import sample_x_around_points +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from .function_samplers import function_sampler # noqa: F401 +from .utils import load_config_and_model + + +def local_load(path: str, map_location: str | None = None) -> object: + """ + Load a checkpoint from the local filesystem. + + Args: + path: The path to the file to load. + map_location: The device to load the tensors to. Same as torch.load, e.g. "cpu" or "cuda:0". + + Returns: + The loaded object. + """ + return torch.load(path, map_location=map_location, weights_only=True) + + +def local_exists(path: str) -> bool: + """ + Check if a path exists on the local filesystem. + + Args: + path: The path to check. + + Returns: + True if the path exists, False otherwise. + """ + return os.path.exists(path) + + +def local_save(obj, path: str): + """ + Save an object to the local filesystem. + + Args: + obj: The object to save. + path: The path to save the object to. 
+ + Returns: + None + """ + dir_path = os.path.dirname(path) + + if dir_path and not os.path.exists(dir_path): + os.makedirs(dir_path, exist_ok=True) + print(f"Created directory: {dir_path}") + + torch.save(obj, path) + + +@dataclass +class PathGenerationResult: + """Result of path generation for RL training. + + Attributes: + ys: Tensor of y values, shape [batch_size * sub_batch_size, seq_len] + target_ys: Tensor of target y values, shape [batch_size * sub_batch_size, seq_len] + predictions: Tensor of predictions, shape [batch_size * sub_batch_size, seq_len, num_features] + options: List of option tensors (if choose_next_in_set=True) + chosen_options: List of chosen option indices (if choose_next_in_set=True) + choice_probs: List of choice probability tensors (if choose_next_in_set=True) + current_num_features: Number of features used for this generation + draw: Random draw tensor for computing regret, shape [batch_size, draw_size] + y_quantiles: Quantiles of y values, shape [batch_size * sub_batch_size, seq_len] + draw_size: Size of the random draw + joint_steps: Number of initial joint steps used during generation + step_entropies: List of entropy values per step (for tensorboard logging) + step_max_probs: List of max probability values per step (for tensorboard logging) + step_sampled_probs: List of sampled probability values per step (for tensorboard logging) + """ + + ys: torch.Tensor + target_ys: torch.Tensor + predictions: torch.Tensor + options: list[torch.Tensor] + chosen_options: list[torch.Tensor] + choice_probs: list[torch.Tensor] + current_num_features: int + joint_steps: int + basemodel_ei_values: list[ + torch.Tensor + ] # EI values from unfinetuned basemodel for each step + step_entropies: list[torch.Tensor] # Entropy of distribution at each step + step_max_probs: list[torch.Tensor] # Max probability at each step + step_sampled_probs: list[torch.Tensor] # Probability of sampled action at each step + binary_features_mask: ( + torch.Tensor | None + ) 
# Mask indicating which features are binary, shape [batch_size, num_features] + bo_batch_size: int # The bo_batch_size used for this generation + seq_len: int # The seq_len used for this generation (for random horizon training) + y_style: torch.Tensor | None # The y_style tensor used for this generation + + # placeholders to be filled by subsequent computations + normalized_avg_rewards: torch.Tensor | None = None + unnormalized_avg_rewards: torch.Tensor | None = None + draw: torch.Tensor | None = None + y_quantiles: torch.Tensor | None = None + target_y_quantiles: torch.Tensor | None = None + standardized_ys: torch.Tensor | None = None + standardized_target_ys: torch.Tensor | None = None + draw_size: int | None = None + + def to_device(self, device): + for field in fields(self): + value = getattr(self, field.name) + if isinstance(value, torch.Tensor): + setattr(self, field.name, value.to(device)) + elif isinstance(value, list): + new_list = [ + item.to(device) if isinstance(item, torch.Tensor) else item + for item in value + ] + setattr(self, field.name, new_list) + return self + + +@dataclass(frozen=True) +class RewardConfig(base_config.BaseConfig): + reward_type: tp.Literal[ + "raw", "quantile", "standardized", "log_quantile", "rs_equivalent" + ] = "raw" + standardization_source: tp.Literal["batch", "draw"] = "draw" + only_future: bool = False + aggregation: str = ( + "max" # we can sum instead of mean, because everything is standardized anyways + # Options: "sum", "max", "avgmax", "max_imp", "max_sparse", "myopic_X" (where X is 0, 1, 2, ...) + # myopic_0: only current position, myopic_1: current + next, etc. 
@dataclass(frozen=True)
class RewardConfig(base_config.BaseConfig):
    """Configuration for turning rollout outcomes into per-timestep RL rewards.

    The pipeline in ``compute_reward`` is: pick a per-position reward signal
    (``reward_type``), aggregate it over time (``aggregation``, optionally
    restricted to current/future positions via ``only_future``), normalize
    across the sub-batch (``standardization``), and optionally zero rewards
    after the best evaluation (``no_reward_after_peak``).
    """

    # Per-position reward signal selector.
    reward_type: tp.Literal[
        "raw", "quantile", "standardized", "log_quantile", "rs_equivalent"
    ] = "raw"
    # Source of the quantile/standardization statistics (consumed by callers;
    # not read inside compute_reward itself).
    standardization_source: tp.Literal["batch", "draw"] = "draw"
    # If True, each timestep's reward only aggregates over that timestep and later ones.
    only_future: bool = False
    aggregation: str = (
        "max"  # we can sum instead of mean, because everything is standardized anyways
        # Options: "sum", "max", "avgmax", "max_imp", "max_sparse", "myopic_X" (where X is 0, 1, 2, ...)
        # myopic_0: only current position, myopic_1: current + next, etc.
    )
    # How the aggregated rewards are normalized across the sub-batch dimension.
    standardization: tp.Literal[
        "none",
        "per_step_and_function",
        "divide_per_step_and_function",
        "mean_divide_per_step_and_function",
        "mean_sub_per_step_and_function",
        "mean_divide_per_function",
        "top_0.1_per_function",
        "top_0.2_per_function",
    ] = "per_step_and_function"
    # Numerical floor added to denominators during standardization.
    standardization_eps: float = 1e-8
    # If True, compute rewards on the (noiseless) target ys instead of observed ys.
    reward_on_targets: bool = False
    # True: zero rewards after each trajectory's own peak;
    # "global": zero rewards after the batch-wide peak timestep.
    no_reward_after_peak: bool | tp.Literal["global"] = False

    @classmethod
    def _loading_kwarg_transform(cls, kwargs):
        # Backwards compatibility: older checkpoints stored a boolean
        # "quantile_reward" flag instead of the reward_type string.
        if "quantile_reward" in kwargs:
            qr = kwargs.pop("quantile_reward")
            kwargs["reward_type"] = "quantile" if qr else "raw"
        return kwargs

    def __post_init__(self):
        # "max_imp"/"max_sparse"/"myopic_X" are defined relative to future
        # rewards, so they require only_future=True.
        if self.aggregation in ("max_imp", "max_sparse"):
            assert (
                self.only_future
            ), f"{self.aggregation} does only make sense with future rewards."
        if self.aggregation.startswith("myopic_"):
            assert (
                self.only_future
            ), f"{self.aggregation} does only make sense with future rewards."
            try:
                window = int(self.aggregation.split("_")[1])
                assert window >= 0, f"myopic window must be non-negative, got {window}"
            except (IndexError, ValueError):
                raise ValueError(
                    f"Invalid myopic aggregation format: {self.aggregation}. Expected 'myopic_X' where X is a non-negative integer."
                )

    @torch.no_grad()
    def compute_reward(
        self,
        ys: torch.Tensor,
        target_ys: torch.Tensor,
        quantile_ys: torch.Tensor,
        quantile_target_ys: torch.Tensor,
        standardized_ys: torch.Tensor,
        standardized_target_ys: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute per-timestep rewards for a batch of rollouts.

        Args:
            ys: Observed outcomes, shape [batch_size, sub_batch_size, seq_len].
            target_ys: Noiseless outcomes, same shape.
            quantile_ys / quantile_target_ys: Quantile-transformed outcomes.
            standardized_ys / standardized_target_ys: Standardized outcomes.

        Returns:
            Tuple of (normalized rewards, unnormalized aggregated rewards),
            both of shape [batch_size, sub_batch_size, seq_len].
        """
        batch_size, sub_batch_size, seq_len = ys.shape

        if self.reward_on_targets:
            # Swap in the target-based signals; the variable names below stay
            # the same so the reward_type dispatch is unchanged.
            ys = target_ys
            quantile_ys = quantile_target_ys
            standardized_ys = standardized_target_ys

        # --- Step 1: choose the per-position reward signal -------------------
        if self.reward_type == "quantile":
            rewards_curr_pos = quantile_ys
        elif self.reward_type == "standardized":
            rewards_curr_pos = standardized_ys
        elif self.reward_type == "log_quantile":
            quantile_regret = (1 - quantile_ys).clamp(
                min=1 / 10_000
            )  # clamp s.t. we don't get log(0.) errors
            log_quantile_regret = torch.log(quantile_regret)
            rewards_curr_pos = -log_quantile_regret
        elif self.reward_type == "rs_equivalent":
            # Transform quantile to equivalent random search size
            # 1/(1-q) represents how many random samples would be needed on average
            # to find a value at least this good (e.g., q=0.99 -> 100 samples)
            quantile_regret = (1 - quantile_ys).clamp(
                min=1 / 10_000
            )  # clamp s.t. we don't get division by 0
            rewards_curr_pos = 1.0 / quantile_regret
        else:
            assert self.reward_type == "raw", f"Unknown reward_type: {self.reward_type}"
            rewards_curr_pos = ys

        # --- Step 2: aggregate over time -------------------------------------
        if self.only_future:
            # Each timestep aggregates over itself and all later timesteps
            # (implemented via flip + cumulative op + flip back).
            if self.aggregation == "sum":
                timestep_rewards = torch.cumsum(rewards_curr_pos.flip(-1), dim=-1).flip(
                    -1
                )
            elif self.aggregation == "avgmax":
                max_so_far = torch.cummax(rewards_curr_pos, dim=-1)[0]
                timestep_rewards = torch.cumsum(max_so_far.flip(-1), dim=-1).flip(-1)
                # normalize sums to be averages
                num_remaining = torch.arange(
                    seq_len, 0, -1, device=timestep_rewards.device
                )
                timestep_rewards = timestep_rewards / num_remaining.view(1, 1, -1)
            else:
                # Default future aggregation: max over current + future steps.
                timestep_rewards = torch.cummax(rewards_curr_pos.flip(-1), dim=-1)[
                    0
                ].flip(-1)
                if self.aggregation == "max_imp":  # try this with future rewards
                    # reward for the first one is a little random
                    # as it is against a baseline that is 0 starting
                    # for the quantiles it makes sense, but without not really
                    max_so_far = torch.cummax(rewards_curr_pos, dim=-1)[0]
                    # NOTE: timestep_improvement aliases timestep_rewards; the
                    # in-place subtraction below mutates both, but
                    # timestep_rewards is rebound by the clamp right after.
                    timestep_improvement = timestep_rewards
                    timestep_improvement[..., 1:] -= max_so_far[..., :-1]
                    average_y_first_guess = rewards_curr_pos[:, :, 0].mean(
                        1, keepdim=True
                    )
                    timestep_improvement[..., 0] -= average_y_first_guess
                    timestep_rewards = timestep_improvement.clamp(min=0.0)
                elif self.aggregation == "max_sparse":
                    # todo: think about this more
                    # Zero reward for steps once a positive value was already found.
                    max_so_far = torch.cummax(rewards_curr_pos, dim=-1)[0]
                    timestep_rewards[..., 1:] = torch.where(
                        max_so_far[..., :-1] > 0.0, 0.0, timestep_rewards[..., 1:]
                    )
                elif self.aggregation.startswith("myopic_"):
                    # myopic_X: only look at current position and the next X positions
                    window = int(self.aggregation.split("_")[1])
                    # For each position i, compute max over positions i to min(i+window, seq_len-1)
                    timestep_rewards = torch.zeros_like(rewards_curr_pos)
                    for i in range(seq_len):
                        end_idx = min(i + window + 1, seq_len)
                        timestep_rewards[..., i] = rewards_curr_pos[..., i:end_idx].max(
                            dim=-1
                        )[0]
                else:
                    assert self.aggregation == "max", self.aggregation
        else:
            # Whole-trajectory aggregation: every timestep gets the same value.
            if self.aggregation == "sum":
                timestep_rewards = rewards_curr_pos.sum(-1, keepdim=True).repeat(
                    1, 1, seq_len
                )
            elif self.aggregation == "avgmax":
                # Average of cumulative max across the sequence
                max_so_far = torch.cummax(rewards_curr_pos, dim=-1)[0]
                avg_max = max_so_far.mean(-1, keepdim=True)
                timestep_rewards = avg_max.repeat(1, 1, seq_len)
            else:
                assert self.aggregation == "max", self.aggregation
                timestep_rewards = rewards_curr_pos.max(-1, keepdim=True)[0].repeat(
                    1, 1, seq_len
                )

        # --- Step 3: normalize across the sub-batch (dim 1) ------------------
        if self.standardization == "per_step_and_function":
            normalized_avg_rewards_future = (
                timestep_rewards - timestep_rewards.mean(1, keepdim=True)
            ) / (timestep_rewards.std(1, keepdim=True) + self.standardization_eps)
        elif self.standardization == "mean_sub_per_step_and_function":
            normalized_avg_rewards_future = timestep_rewards - timestep_rewards.mean(
                1, keepdim=True
            )
        elif self.standardization == "divide_per_step_and_function":  # try this
            normalized_avg_rewards_future = timestep_rewards / (
                timestep_rewards.std(1, keepdim=True) + self.standardization_eps
            )
        elif self.standardization == "mean_divide_per_step_and_function":  # try this
            normalized_avg_rewards_future = timestep_rewards / (
                timestep_rewards.mean(1, keepdim=True) + self.standardization_eps
            )
        elif self.standardization == "mean_divide_per_function":
            normalized_avg_rewards_future = timestep_rewards / (
                timestep_rewards.mean(1, keepdim=True).mean(2, keepdim=True)
                + self.standardization_eps
            )
        elif self.standardization.startswith("top_") and self.standardization.endswith(
            "per_function"
        ):
            # generalize to any quantile, e.g. "top_0.1_per_function"
            try:
                quantile = float(self.standardization.split("_")[1])
            except (IndexError, ValueError):
                raise ValueError(
                    f"Invalid standardization format: {self.standardization}"
                )
            # cutoff top quantile
            quantile_cutoffs = timestep_rewards.view(batch_size, -1).sort(-1)[0][
                :, -round(quantile * sub_batch_size * seq_len)
            ]
            # Keep only rewards strictly above the per-function cutoff.
            normalized_avg_rewards_future = torch.where(
                timestep_rewards > quantile_cutoffs[:, None, None],
                timestep_rewards,
                0.0,
            )
        else:
            assert self.standardization == "none", self.standardization
            normalized_avg_rewards_future = timestep_rewards
        # normalized_avg_rewards_future = normalized_avg_rewards_future.view(
        #     batch_size * sub_batch_size, -1
        # )

        # --- Step 4: optionally zero out rewards after the peak --------------
        if self.no_reward_after_peak:
            # Set rewards to 0 for all timesteps after the peak evaluation
            position_indices = torch.arange(seq_len, device=rewards_curr_pos.device)

            if self.no_reward_after_peak == "global":
                # Find the global peak across all trajectories in each sub-batch
                # and use it to cut off rewards for all items
                # Flatten sub_batch and seq_len to find global max per batch
                flat_rewards = rewards_curr_pos.view(
                    batch_size, sub_batch_size * seq_len
                )
                global_peak_flat_indices = flat_rewards.argmax(dim=-1)  # [batch_size]
                # Convert flat index to seq_len index (the timestep of the global peak)
                global_peak_timestep = (
                    global_peak_flat_indices % seq_len
                )  # [batch_size]
                # Create mask: positions after global peak timestep are True for all trajectories
                after_peak_mask = position_indices.view(
                    1, 1, -1
                ) > global_peak_timestep.view(
                    batch_size, 1, 1
                )  # [batch_size, sub_batch_size, seq_len]
            else:
                assert self.no_reward_after_peak is True
                # Original behavior: find peak per trajectory
                peak_indices = rewards_curr_pos.argmax(
                    dim=-1
                )  # [batch_size, sub_batch_size]
                # Create a mask where positions after the peak are True
                after_peak_mask = position_indices.view(
                    1, 1, -1
                ) > peak_indices.unsqueeze(-1)  # [batch_size, sub_batch_size, seq_len]

            # Zero out rewards after the peak
            normalized_avg_rewards_future = torch.where(
                after_peak_mask,
                torch.zeros_like(normalized_avg_rewards_future),
                normalized_avg_rewards_future,
            )

        return (
            normalized_avg_rewards_future,
            timestep_rewards,
        )  # both shape: [batch_size, sub_batch_size, seq_len]
@dataclass(frozen=True)
class RLConfig(base_config.BaseConfig):
    """Top-level configuration for RL fine-tuning of a pretrained PFN model.

    Groups the base checkpoint, the function sampler used to generate training
    objectives, reward shaping, PPO-style optimization hyperparameters, and the
    choice-set rollout options.
    """

    # Path of the pretrained model checkpoint to fine-tune.
    model_path: str
    # Sampler producing the black-box functions used as RL environments.
    function_sampler: function_sampler.FunctionSamplerConfig
    # Reward shaping configuration (frozen dataclass, safe as a shared default).
    reward: RewardConfig = RewardConfig()
    batch_size: int = 16
    sub_batch_size: int = 32  # in the scaling RL paper, they use 16
    seq_len: int = 5
    # Policy-gradient variant used for the update step.
    algorithm: tp.Literal["grpo", "cispo"] = "grpo"
    # PPO clipping range (upper bound; see eps_low for asymmetric clipping).
    eps: float = 0.2
    eps_low: float | None = None  # If None, use eps for lower bound (symmetric)
    filter_rewards_up_to_magnitude: float | None = (
        None  # Zero out rewards for functions with avg magnitude <= this threshold
    )
    num_batches: int = 100
    # How many optimization passes are made over each generated batch of rollouts.
    experience_repetitions: int = 10
    learning_rate: float = 1e-5
    min_learning_rate: float | None = 0.0  # No LR schedule, when None
    lr_schedule: tp.Literal["cosine", "linear"] = "cosine"  # LR decay schedule type
    warmup_batches: int | None = None  # Optional linear LR warmup for this many batches
    # Adam hyperparameters.
    opt_beta2: float = 0.99
    opt_eps: float = 1e-8
    weight_decay: float = 0.0
    grad_clip_norm: float = 1.0
    device: str | None = None
    # Standardization applied to observed ys before feeding them to the model
    # (see preprocess_train_x_and_y: "m0s1" or "m0.5s0.3"), or None.
    standardize_y: None | str = None
    seed: int = 21415
    independent_noise_in_function_draws: bool = False

    # For choice based RL
    choose_next_in_set: bool = False
    choice_set_size: int = 100
    super_choice_set_factor: float = 1.0
    choice_set_top_share: float = (
        0.5  # only relevant when super_choice_set_factor > 1.0
    )
    ei_selector: bool = False
    keep_head: bool = False
    num_features: int = 1  # Number of input features for choose_next_in_set mode
    mix_k_features_in_opt: int = 1
    basemodel_ei_input: bool = (
        False  # Pass EI from unfinetuned basemodel as input feature
    )
    binary_feature_likelihood: float = (
        0.0  # Probability that each feature is binary (0 or 1) in the choice set
    )

    # Around train point sampling: sample some options near previous training points
    around_train_point_share: float = (
        0.0  # Fraction of options to sample around training points (0.0 = disabled)
    )
    around_train_point_std: float = 0.01  # Standard deviation for Gaussian noise when sampling around training points

    # Joint rollout training: keep trajectories identical for a random number of initial steps
    # None: disabled (default), train on all positions independently
    # "single": train only on the first position after joint steps (the split point)
    # "remaining": train on all positions from the split point onwards
    joint_rollout_training: tp.Literal["single", "remaining"] | None = None

    # Batched BO: select multiple points per batch with NaN y values for pending evaluations
    # When bo_batch_size > 1, points within a batch don't see each other's y values
    # and rewards are copied from the last position in each batch to all positions
    bo_batch_size: int = 1
    randomize_bo_batch_size: bool = False  # If True, sample bo_batch_size uniformly from 1 to bo_batch_size for each rollout

    # Random horizon: sample seq_len uniformly at random up to the specified value for each rollout
    randomize_seq_len: bool = False
    # When randomize_seq_len is True, add a y_style_encoder that encodes the current seq_len
    # The seq_len is normalized as (curr_seq_len/seq_len)*2-1 to be in range [-1, 1]
    seq_len_y_style_encoder: bool = False

    # Mixed precision for path generation
    mixed_precision_path_generation: bool = False

    # Checkpointing
    checkpoint_save_path: str | None = None  # Path to save checkpoints after each batch
    checkpoint_load_path: str | None = (
        None  # Path to load checkpoint from to resume training
    )

    # Rollback on high loss
    rollback_loss_threshold: float | None = (
        None  # Roll back optimizer step if loss exceeds this threshold
    )

    # Sequence length curriculum: scales seq_len with factors of 2 until reaching target
    # Options:
    #  - None: No curriculum, use seq_len for all batches
    #  - "equal": Train equal number of batches at each curriculum stage
    #  - "exponential": Train exponentially fewer batches at longer sequence lengths
    #    (double seq_len, half the batches at each stage)
    seq_len_curriculum: tp.Literal["equal", "exponential"] | None = None
    seq_len_curriculum_min: int = 8  # Minimum sequence length for curriculum stages

    # Filled in automatically
    tensorboard_path: str | None = None
    model: transformer_config.TransformerConfig | None = None

    # make it backwards compatible
    @classmethod
    def _loading_kwarg_transform(cls, kwargs):
        # Drop a removed field and translate the old single_position_training
        # flag into the newer joint_rollout_training option.
        kwargs.pop("output_checkpoint_path", None)

        if kwargs.pop("single_position_training", None):
            kwargs["joint_rollout_training"] = "single"

        return kwargs

    def __post_init__(self):
        # The following features only make sense in choice-set mode, where the
        # model picks the next point from a discrete set of candidates.
        if self.bo_batch_size > 1 and not self.choose_next_in_set:
            raise ValueError(
                "bo_batch_size > 1 is only supported when choose_next_in_set=True"
            )
        if self.binary_feature_likelihood > 0.0 and not self.choose_next_in_set:
            raise ValueError(
                "binary_feature_likelihood > 0 is only supported when choose_next_in_set=True"
            )
        if self.around_train_point_share > 0.0 and not self.choose_next_in_set:
            raise ValueError(
                "around_train_point_share > 0 is only supported when choose_next_in_set=True"
            )
def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    """Get the underlying module from a DDP-wrapped model or return the model itself."""
    if isinstance(model, DDP):
        return model.module
    return model


def transform_logits(model: TableTransformer, logits: torch.Tensor, ei_selector: bool):
    """Turn raw model logits into per-candidate selection scores.

    Without the EI selector the squeezed logits are used directly; with it,
    the model's criterion computes EI values from the logits.
    """
    if not ei_selector:
        return logits.squeeze(-1)
    else:
        # Handle DDP-wrapped models
        model_module = unwrap_model(model)
        # NOTE(review): the 1e9 scale factor presumably sharpens the downstream
        # selection distribution over EI values — confirm against the caller.
        return model_module.criterion.ei(logits, best_f=0.0) * 1_000_000_000


def preprocess_train_x_and_y(
    train_x: torch.Tensor,
    train_y: torch.Tensor,
    standardize_y: str | None,
    bo_batch_size: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:  # both input tensors are [b,seq,features / 1]
    """Preprocess training data for model input.

    Args:
        train_x: Training x values, shape [batch, seq_len, num_features]
        train_y: Training y values, shape [batch, seq_len, 1]
        standardize_y: Standardization method ("m0s1" or "m0.5s0.3" or None)
        bo_batch_size: Batch size for batched BO. When > 1, y values for
            previous elements in the current batch are set to NaN (as they
            haven't been evaluated yet in batched BO).

    Returns:
        Tuple of (train_x, train_y) with appropriate preprocessing applied.
    """
    assert train_x.numel() > 0
    current_seq_len = train_y.shape[1]

    # For batched BO, determine which positions have been "evaluated"
    # Positions in previous batches are evaluated, positions in current batch are pending
    if bo_batch_size > 1 and current_seq_len > 0:
        batch_start = (current_seq_len // bo_batch_size) * bo_batch_size
        # Positions 0 to batch_start-1 are evaluated (from previous batches)
        # Positions batch_start to current_seq_len-1 are in current batch (pending)
        evaluated_y = train_y[:, :batch_start, :] if batch_start > 0 else None
    else:
        evaluated_y = train_y
        batch_start = current_seq_len

    if standardize_y:
        # Compute mean/std only over evaluated positions (previous batches)
        if evaluated_y is not None and evaluated_y.shape[1] > 0:
            mean = evaluated_y.mean(1, keepdim=True)
            if evaluated_y.shape[1] == 1:
                # A single observation has no spread; fall back to unit scale.
                std = 1.0
            else:
                std = evaluated_y.std(1, keepdim=True)
                # Guard against near-constant observations (division blow-up).
                std[std < 1e-8] = 1.0
        else:
            # No evaluated positions yet, use defaults
            mean = 0.0
            std = 1.0

        # Apply standardization (NaNs remain NaN)
        train_y = (train_y - mean) / std

        if standardize_y == "m0.5s0.3":
            # Rescale from mean 0 / std 1 to mean 0.5 / std 0.3.
            train_y = train_y * 0.3 + 0.5
        else:
            assert standardize_y == "m0s1"
    else:
        assert standardize_y is None

    # Still need to NaN out batch positions even without standardization
    if bo_batch_size > 1 and batch_start < current_seq_len:
        # Clone first so the caller's tensor is not mutated in place.
        train_y = train_y.clone()
        train_y[:, batch_start:, :] = torch.nan

    return train_x, train_y
+ + Args: + model: The model to use for generating predictions + sampler: Sampler function taking inputs (batch_size, sub_batch_size) returning tuple (batch_size, sub_batch_size) + batch_size: Number of batches + seq_len: Number of sequential predictions to make + sub_batch_size: Number of sub-batches per sampler + choose_next_in_set: Whether to choose next point from a set of options + choice_set_size: Size of the choice set when choose_next_in_set=True + ei_selector: Whether to use EI selector + standardize_y: Standardization method for y values + argmax_selection: Whether to use argmax for selection + current_num_features: Number of features to use for this batch (constant across all steps) + device: Device to run on + joint_steps: Number of initial steps where all sub_batches share the same trajectory. + During these steps, only one sample is taken per batch (not per sub_batch). + At step joint_steps, trajectories are expanded to sub_batch_size copies each. + + Returns: + Dictionary with keys: + - ys: Tensor of shape [batch_size * sub_batch_size, seq_len] + - predictions: Tensor of predictions + - options: List of option tensors (if choose_next_in_set=True) + - chosen_options: List of chosen option indices (if choose_next_in_set=True) + """ + if current_num_features > 1: + assert choose_next_in_set + + if argmax_selection: + assert choose_next_in_set, "We still need to implement the argmax selection." 
+ + if basemodel_ei_input: + assert ( + choose_next_in_set + ), "basemodel_ei_input only works with choose_next_in_set" + assert ( + basemodel_for_ei is not None + ), "basemodel_for_ei must be provided when basemodel_ei_input=True" + + # Sample which features are binary for this batch (consistent across all steps) + if binary_feature_likelihood > 0.0: + binary_features_mask = ( + torch.rand(batch_size, current_num_features, device=device) + < binary_feature_likelihood + ) + else: + binary_features_mask = None + + super_batch_size = batch_size * sub_batch_size + + ys = None + target_ys = None + predictions = None + options = [] + chosen_options = [] + choice_probs = [] + basemodel_ei_values = [] + step_entropies = [] + step_max_probs = [] + step_sampled_probs = [] + + for step_idx in range(seq_len): + # Determine effective batch size for this step + # During joint_steps, we only sample batch_size trajectories (one per batch) + # After joint_steps, we sample super_batch_size trajectories (one per sub_batch) + is_joint_step = step_idx < joint_steps + effective_batch_size = batch_size if is_joint_step else super_batch_size + + # At the transition from joint to split steps, expand trajectories + if step_idx == joint_steps and joint_steps > 0: + # Expand from [batch_size, seq] to [super_batch_size, seq] + # by repeating each trajectory sub_batch_size times + ys = ( + ys.unsqueeze(1) + .expand(-1, sub_batch_size, -1) + .reshape(super_batch_size, joint_steps) + ) + target_ys = ( + target_ys.unsqueeze(1) + .expand(-1, sub_batch_size, -1) + .reshape(super_batch_size, joint_steps) + ) + predictions = ( + predictions.unsqueeze(1) + .expand(-1, sub_batch_size, -1, -1) + .reshape(super_batch_size, joint_steps, current_num_features) + ) + # Expand options and chosen_options for PPO training + for i in range(len(options)): + options[i] = ( + options[i] + .unsqueeze(1) + .expand(-1, sub_batch_size, -1, -1) + .reshape(super_batch_size, -1, current_num_features) + ) + 
chosen_options[i] = ( + chosen_options[i] + .unsqueeze(1) + .expand(-1, sub_batch_size) + .reshape(super_batch_size) + ) + choice_probs[i] = ( + choice_probs[i] + .unsqueeze(1) + .expand(-1, sub_batch_size, -1) + .reshape(super_batch_size, -1) + ) + + if choose_next_in_set: + if ys is None: + x = torch.zeros( + effective_batch_size, 0, current_num_features, device=device + ) + y = torch.zeros(effective_batch_size, 0, 1, device=device) + else: + x, y = preprocess_train_x_and_y( + predictions, + ys[:, :, None], + standardize_y, + bo_batch_size=bo_batch_size, + ) + + # we make sure the rollouts within each sub-batch get the same choices in each step + total_opts = round(choice_set_size * super_choice_set_factor) + + # Check if sampler restricts sampling points by checking for get_candidate_points + if hasattr(sampler, "get_candidate_points"): + # Validate incompatible options + if around_train_point_share > 0.0: + raise ValueError( + "around_train_point_share > 0 is not supported with samplers that provide candidate points" + ) + if binary_features_mask is not None: + raise ValueError( + "binary_feature_likelihood > 0 is not supported with samplers that provide candidate points" + ) + + # Use candidate points from the sampler + candidate_points = sampler.get_candidate_points() + # Build options from the candidate points for each batch element + opts_list = [] + for batch_idx in range(batch_size): + candidates = candidate_points[batch_idx].to(device) + n_candidates = candidates.shape[0] + if n_candidates >= total_opts: + # Sample without replacement + perm = torch.randperm(n_candidates, device=device)[:total_opts] + opts_list.append(candidates[perm]) + else: + # Sample with replacement if we don't have enough candidates + indices = torch.randint( + 0, n_candidates, (total_opts,), device=device + ) + opts_list.append(candidates[indices]) + opts = torch.stack( + opts_list, dim=0 + ) # [batch_size, total_opts, features] + else: + # Calculate how many options should be 
sampled around training points + num_around_train = ( + round(total_opts * around_train_point_share) + if around_train_point_share > 0.0 and predictions is not None + else 0 + ) + num_uniform = total_opts - num_around_train + + # Sample uniform options + opts = torch.rand( + batch_size, + num_uniform, + current_num_features, + device=device, + ) + + # Sample options around training points if applicable + if num_around_train > 0: + # Get unique train points per batch + if is_joint_step: + # During joint steps, predictions already has shape [batch_size, step_idx, num_features] + train_points = predictions + else: + # After joint steps, take every sub_batch_size-th trajectory + train_points = predictions[ + ::sub_batch_size + ] # [batch_size, step_idx, num_features] + + # Use shared utility for sampling around training points + around_train_opts = sample_x_around_points( + batch_size=batch_size, + num_samples=num_around_train, + num_features=current_num_features, + centers=train_points, + std=around_train_point_std, + device=device, + ) + # Concatenate with uniform options + opts = torch.cat([opts, around_train_opts], dim=1) + + # Apply binary feature restriction to options + if binary_features_mask is not None: + binary_mask_expanded = binary_features_mask.unsqueeze(1).expand( + -1, total_opts, -1 + ) + opts = torch.where(binary_mask_expanded, (opts > 0.5).float(), opts) + + if super_choice_set_factor < 1.0: + raise NotImplementedError("Please use super_choice_set_factor >= 1.") + if not is_joint_step: + opts = ( + opts.unsqueeze(1) + .repeat( + 1, sub_batch_size, 1, 1 + ) # using repeat instead of expand as we have to copy for the next line either way + .view(super_batch_size, total_opts, current_num_features) + ) + + autocast_ctx = ( + torch.amp.autocast(dtype=torch.float16, device_type="cuda") + if mixed_precision + else nullcontext() + ) + # Compute base model EI and add as additional feature to options + if basemodel_ei_input: + with torch.no_grad(): + # Get 
logits from the base model (unfinetuned) + basemodel_logits = basemodel_for_ei(x=x, y=y, test_x=opts) + # Compute EI from base model - using best_f=0.0 as in ei_selector + basemodel_ei = ( + basemodel_for_ei.criterion.ei(basemodel_logits, best_f=0.0) + .detach() + .view(super_batch_size, total_opts, 1) + ) # shape: [batch, num_opts, 1] + # Create augmented options with EI as additional feature for model input + opts_with_ei = torch.cat([opts, basemodel_ei], dim=-1) + # Add 0s to train_x for the EI column + x_with_ei = torch.cat( + [ + x, + 0.5 + torch.zeros(x.shape[0], x.shape[1], 1, device=device), + ], + dim=-1, + ) + + # Subselect y_style for joint steps (every sub_batch_size-th element) + # y_style is already in expanded shape [super_batch_size, 1] + current_y_style = None + if y_style is not None: + if is_joint_step: + # Subselect from [super_batch_size, 1] to [batch_size, 1] + current_y_style = y_style[::sub_batch_size] + else: + current_y_style = y_style + + with autocast_ctx: + logits = transform_logits( + model, + model( + x=x_with_ei, + y=y, + test_x=opts_with_ei, + y_style=current_y_style, + ), + ei_selector, + ) + else: + # Subselect y_style for joint steps (every sub_batch_size-th element) + # y_style is already in expanded shape [super_batch_size, 1] + current_y_style = None + if y_style is not None: + if is_joint_step: + # Subselect from [super_batch_size, 1] to [batch_size, 1] + current_y_style = y_style[::sub_batch_size] + else: + current_y_style = y_style + + with autocast_ctx: + logits = transform_logits( + model, + model(x=x, y=y, test_x=opts, y_style=current_y_style), + ei_selector, + ) + basemodel_ei = None + + if super_choice_set_factor > 1.0: + # Select a super set of options consisting of: + # - the best top_share * total options by model score + # - plus the first (1 - top_share) * total options among the remaining non-best ones + k_best = max(1, round(top_share * choice_set_size)) + k_first_non_best = choice_set_size - k_best + + # Get 
indices of top-k according to logits + topk_vals, topk_inds = torch.topk(logits.squeeze(-1), k=k_best, dim=1) + + # Build a mask for selected top-k to find non-best candidates + top_mask = torch.zeros( + effective_batch_size, total_opts, dtype=torch.bool, device=device + ) + top_mask[torch.arange(effective_batch_size).unsqueeze(1), topk_inds] = ( + True + ) + + # Indices of non-best (complement of top-k) + all_inds = ( + torch.arange(total_opts, device=device) + .unsqueeze(0) + .expand(effective_batch_size, -1) + ) + non_best_inds = all_inds[~top_mask].view(effective_batch_size, -1) + + # Take the first k_first_non_best from non-best in their original order + if k_first_non_best > 0: + first_non_best_inds = non_best_inds[:, :k_first_non_best] + combined_inds = torch.cat([topk_inds, first_non_best_inds], dim=1) + else: + combined_inds = topk_inds + + # Subselect options and logits to the combined set + opts = opts[ + torch.arange(effective_batch_size).unsqueeze(1), combined_inds + ] + logits = logits[ + torch.arange(effective_batch_size).unsqueeze(1), combined_inds + ] + # Also subselect EI values if basemodel_ei_input is enabled + if basemodel_ei_input: + basemodel_ei = basemodel_ei[ + torch.arange(effective_batch_size).unsqueeze(1), combined_inds + ] + + logits = logits.squeeze(-1) + if argmax_selection: + sampled_inds = logits.argmax(1) + else: + sampled_inds = torch.distributions.Categorical(logits=logits).sample() + + # Compute entropy, max prob, and sampled prob for tensorboard logging + # Entropy: -sum(p * log(p)), using clamp to avoid log(0) + log_probs = logits.log_softmax( + dim=-1 + ) # shape: [effective_batch_size, num_opts] + probs = log_probs.exp() # shape: [effective_batch_size, num_opts] + entropy = -(probs * log_probs).sum(dim=-1) # shape: [effective_batch_size] + max_prob = probs.max(dim=-1).values # shape: [effective_batch_size] + sampled_prob = probs[ + torch.arange(effective_batch_size, device=device), sampled_inds + ] # shape: 
[effective_batch_size] + + step_entropies.append(entropy) + step_max_probs.append(max_prob) + step_sampled_probs.append(sampled_prob) + + pred = opts[torch.arange(effective_batch_size), sampled_inds] + + options.append(opts) + chosen_options.append(sampled_inds) + choice_probs.append(probs) + # Store the EI values for reuse in training loop + basemodel_ei_values.append(basemodel_ei) + else: + if ys is None: + full_train_x = torch.zeros(effective_batch_size, 0, 2, device=device) + else: + train_x, train_y = preprocess_train_x_and_y( + predictions[:, :], + ys[:, :, None], + standardize_y, + bo_batch_size=bo_batch_size, + ) + full_train_x = torch.cat((train_x, train_y), -1) + + full_test_x = torch.full( + (effective_batch_size, 1, 2), + torch.nan, + device=device, + ) + logits = model(x=full_train_x, y=None, test_x=full_test_x)[:, :, 0].squeeze( + 1 + ) + + p_cdf = torch.rand(*logits.shape[:-1], device=device) + pred = torch.stack( + [ + unwrap_model(model).criterion.icdf(logits[i, :], p) + for i, p in enumerate(p_cdf.tolist()) + ], + ).clamp(0, 1) + + target_y, y = sampler( + pred.view( + batch_size, 1 if is_joint_step else sub_batch_size, current_num_features + ) + ) + # Flatten the result back to effective_batch_size + y = y.view(effective_batch_size) + target_y = target_y.view(effective_batch_size) + + if ys is None: + ys = y.view(effective_batch_size, 1) + target_ys = target_y.view(effective_batch_size, 1) + predictions = pred.view(effective_batch_size, 1, current_num_features) + else: + ys = torch.cat((ys, y.view(effective_batch_size, 1)), 1) + target_ys = torch.cat( + (target_ys, target_y.view(effective_batch_size, 1)), 1 + ) + predictions = torch.cat( + (predictions, pred.view(effective_batch_size, 1, current_num_features)), + 1, + ) + + return PathGenerationResult( + ys=ys, + target_ys=target_ys, + predictions=predictions, + options=options, + chosen_options=chosen_options, + choice_probs=choice_probs, + current_num_features=current_num_features, + 
joint_steps=joint_steps, + basemodel_ei_values=basemodel_ei_values, + step_entropies=step_entropies, + step_max_probs=step_max_probs, + step_sampled_probs=step_sampled_probs, + binary_features_mask=binary_features_mask, + bo_batch_size=bo_batch_size, + seq_len=seq_len, + y_style=y_style, + ) + + +def compute_curriculum_schedule( + num_batches: int, + target_seq_len: int, + mode: tp.Literal["equal", "exponential"], + min_seq_len: int = 8, +) -> list[int]: + """Compute the sequence length for each batch based on curriculum. + + The curriculum scales the rollout length with factors of 2 until reaching + the target sequence length. + + Args: + num_batches: Total number of batches to train + target_seq_len: Final sequence length to reach + mode: Curriculum mode: + - "equal": Train equal number of batches at each curriculum stage + - "exponential": Train exponentially fewer batches at longer sequence + lengths (double seq_len, half the batches at each stage) + min_seq_len: Minimum sequence length for curriculum stages (default: 8) + + Returns: + List of sequence lengths, one per batch + """ + # Clamp min_seq_len to not exceed target_seq_len + min_seq_len = min(min_seq_len, target_seq_len) + + # Generate stages: min_seq_len, min_seq_len*2, min_seq_len*4, ..., up to target_seq_len + curriculum_stages: list[int] = [] + current = min_seq_len + while current < target_seq_len: + curriculum_stages.append(current) + current *= 2 + curriculum_stages.append(target_seq_len) # Always end with target + + n_stages = len(curriculum_stages) + + # If we have fewer batches than stages, trim stages to only include + # the last num_batches stages (prioritize longer sequences) + if num_batches < n_stages: + curriculum_stages = curriculum_stages[-num_batches:] + n_stages = num_batches + + if mode == "equal": + # Equal batches per stage + base_batches = num_batches // n_stages + remainder = num_batches % n_stages + batches_per_stage = [base_batches] * n_stages + # Distribute remainder to later 
stages (longer sequences get slightly more) + for i in range(remainder): + batches_per_stage[-(i + 1)] += 1 + elif mode == "exponential": + # Exponentially decreasing batches for longer sequences + # First stage (shortest seq) gets the most batches, halving each time + # Sum of geometric series: 1 + 1/2 + 1/4 + ... + 1/2^(n-1) = 2 - 2^(1-n) + geometric_sum = sum(1 / (2**i) for i in range(n_stages)) + first_stage_batches = num_batches / geometric_sum + batches_per_stage = [ + max(1, int(round(first_stage_batches / (2**i)))) for i in range(n_stages) + ] + # Adjust total to exactly match num_batches + total = sum(batches_per_stage) + diff = num_batches - total + if diff != 0: + # Distribute the difference across stages, starting from first stage + # to maintain the exponential property as much as possible + for i in range(abs(diff)): + idx = i % n_stages + if diff > 0: + batches_per_stage[idx] += 1 + else: + # Only subtract if stage has more than 1 batch + if batches_per_stage[idx] > 1: + batches_per_stage[idx] -= 1 + else: + # Find another stage to subtract from + for j in range(n_stages): + if batches_per_stage[j] > 1: + batches_per_stage[j] -= 1 + break + else: + raise ValueError(f"Unknown curriculum mode: {mode}") + + # Build the schedule: list of seq_len for each batch + schedule: list[int] = [] + for stage_idx, stage_len in enumerate(curriculum_stages): + schedule.extend([stage_len] * batches_per_stage[stage_idx]) + + return schedule + + +def run_rl_training( + rl_config: RLConfig, + device_override: str | None = None, +) -> dict: + # Detect if we're in distributed mode + if "WORLD_SIZE" in os.environ: + dist.init_process_group(backend="nccl") + + distributed = dist.is_available() and dist.is_initialized() + rank = dist.get_rank() if distributed else 0 + world_size = dist.get_world_size() if distributed else 1 + + # Get local rank from environment variable (set by torchrun) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) if distributed else 0 + 
torch.cuda.set_device(local_rank) + + print("Training distributed? ", distributed) + print("Rank?", rank, "Local?", local_rank) + print("World?", world_size) + + is_main = rank == 0 + + if is_main and distributed: + print( + f"Running in distributed mode: rank {rank}/{world_size}, local_rank {local_rank}" + ) + + # Set seed (different per rank for diversity in trajectory generation) + if rl_config.seed is not None: + seed = rl_config.seed + (rank if distributed else 0) + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + writer_path = rl_config.tensorboard_path + + # Only create writer on main process + writer: SummaryWriter | None = None + if writer_path is not None and is_main: + writer = SummaryWriter(log_dir=writer_path) + print(f"Tensorboard logging to: {writer_path}") + + if is_main: + print(f"Loading base model from {rl_config.model_path}") + base_train_config, model = load_config_and_model( + rl_config.model_path, map_location="cpu" + ) + + if ( + rl_config.choose_next_in_set + and not rl_config.ei_selector + and not rl_config.keep_head + ): + # edit model head to be a simple 1 size prediction + + base_train_config = replace( + base_train_config, + model=replace( + base_train_config.model, decoder_dict={"standard": (None, 1)} + ), + ) + model_with_single_output = base_train_config.model.create_model() + og_statedict = model.state_dict() + filtered_og_statedict = {} + for n in og_statedict: + if "decoder_dict" in n: + print("removing", n) + else: + filtered_og_statedict[n] = og_statedict[n] + model_with_single_output.load_state_dict(filtered_og_statedict, strict=False) + model = model_with_single_output + + # Add y_style_encoder for seq_len encoding if enabled + if rl_config.seq_len_y_style_encoder: + if not rl_config.randomize_seq_len: + print( + "Warning: seq_len_y_style_encoder is True but randomize_seq_len is False. " + "The y_style_encoder will always receive the same normalized seq_len value." 
+ ) + # Update the model config to include y_style_encoder + # This ensures the encoder is saved with the model and can be loaded later + base_train_config = replace( + base_train_config, + model=replace( + base_train_config.model, + y_style_encoder=StyleEncoderConfig(num_styles=1), + ), + ) + # Recreate model with the y_style_encoder + model_with_y_style_encoder = base_train_config.model.create_model() + og_statedict = model.state_dict() + model_with_y_style_encoder.load_state_dict(og_statedict, strict=False) + model = model_with_y_style_encoder + print(f"Added y_style_encoder: Linear(1, {model.ninp}) for seq_len encoding") + + # We do this s.t. we can load the model w/o base_train_config + rl_config = replace(rl_config, model=base_train_config.model) + + sampler_factory = rl_config.function_sampler.function_sampler + + try: + # Setup device based on distributed training config + # TODO test which of this is right for distributed, might want to use "cuda" on all. + if device_override: + device = device_override + elif distributed: + device = f"cuda:{local_rank}" + elif rl_config.device is not None: + device = rl_config.device + else: + device = "cuda:0" if torch.cuda.is_available() else "cpu" + + if is_main: + print(f"Using device {device}") + + model.to(device) + model.train() + + # Create old_model for PPO-style training + # For DDP, we need to access the underlying module + old_model = copy.deepcopy(model) + old_model.requires_grad_(False) + old_model.to(device) + old_model.eval() + + # Store a separate copy of the basemodel for EI computation (never updated during training) + basemodel_for_ei = None + if rl_config.basemodel_ei_input: + basemodel_for_ei = copy.deepcopy(model) + basemodel_for_ei.requires_grad_(False) + basemodel_for_ei.to(device) + basemodel_for_ei.eval() + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=rl_config.learning_rate, + eps=rl_config.opt_eps, + betas=(0.9, rl_config.opt_beta2), + weight_decay=rl_config.weight_decay, + ) + 
# Set up learning rate scheduler with optional warmup + warmup_batches = rl_config.warmup_batches or 0 + main_batches = max(1, rl_config.num_batches - warmup_batches) + eta_min = ( + rl_config.learning_rate + if rl_config.min_learning_rate is None + else rl_config.min_learning_rate + ) + + # Create main scheduler based on lr_schedule config + if rl_config.lr_schedule == "linear": + # Linear decay from learning_rate to eta_min + main_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=1.0, + end_factor=eta_min / rl_config.learning_rate + if rl_config.learning_rate > 0 + else 1.0, + total_iters=main_batches, + ) + else: + # Default: Cosine annealing + main_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=main_batches, + eta_min=eta_min, + ) + + if warmup_batches > 0: + warmup_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=1e-8 / rl_config.learning_rate, # Start from near-zero + end_factor=1.0, + total_iters=warmup_batches, + ) + scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup_scheduler, main_scheduler], + milestones=[warmup_batches], + ) + if is_main: + print(f"Using linear warmup for {warmup_batches} batches") + else: + scheduler = main_scheduler + + if is_main: + print(f"Using {rl_config.lr_schedule} LR schedule") + scaler = torch.amp.GradScaler( + enabled=True, + ) + + # Load checkpoint if available + start_batch = 0 + if should_load_rl_checkpoint(rl_config): + start_batch = load_rl_checkpoint( + model=model, + optimizer=optimizer, + scheduler=scheduler, + scaler=scaler, + checkpoint_load_path=rl_config.checkpoint_load_path, + device=device, + ) + if is_main: + print(f"Resuming training from batch {start_batch}") + else: + if rl_config.checkpoint_load_path is not None and is_main: + print( + f"Checkpoint file {rl_config.checkpoint_load_path} not found or load/save paths are identical and file doesn't exist. Starting from scratch." 
+ ) + + # Wrap model with DDP if using distributed training + if distributed: + model = DDP( + model, + device_ids=[local_rank], + output_device=local_rank, + find_unused_parameters=False, + broadcast_buffers=False, + ) + if is_main: + print("Wrapped model with DistributedDataParallel") + if writer is not None: + writer.add_scalar("config/num_gpus", world_size) + + eps = rl_config.eps + target_seq_len = rl_config.seq_len + batch_size = rl_config.batch_size + sub_batch_size = rl_config.sub_batch_size + super_batch_size = batch_size * sub_batch_size + + # Compute curriculum schedule if enabled + if rl_config.seq_len_curriculum is not None: + curriculum_schedule = compute_curriculum_schedule( + rl_config.num_batches, + target_seq_len, + rl_config.seq_len_curriculum, + min_seq_len=rl_config.seq_len_curriculum_min, + ) + if is_main: + # Log curriculum stages + stages = [] + current_len = curriculum_schedule[0] + stage_start = 0 + for i, length in enumerate(curriculum_schedule): + if length != current_len: + stages.append( + (current_len, stage_start, i - 1, i - stage_start) + ) + current_len = length + stage_start = i + stages.append( + ( + current_len, + stage_start, + len(curriculum_schedule) - 1, + len(curriculum_schedule) - stage_start, + ) + ) + print( + f"Curriculum schedule ({rl_config.seq_len_curriculum} mode): " + f"{len(stages)} stages" + ) + for seq_len_stage, start, end, num_batches_stage in stages: + print( + f" seq_len={seq_len_stage}: batches {start}-{end} ({num_batches_stage} batches)" + ) + else: + curriculum_schedule = None + + batch_losses: list[float] = [] + + for batch_i in range(start_batch, rl_config.num_batches): + # Determine seq_len for this batch (curriculum or fixed) + if curriculum_schedule is not None: + seq_len = curriculum_schedule[batch_i] + else: + seq_len = target_seq_len + if rl_config.seed is not None: + # We re-initialize the seed for every batch + # As randomness might be different in the fitting + # step. 
+ seed = rl_config.seed + batch_i + (rank if distributed else 0) + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + # Save model/optimizer state before this batch for potential rollback + batch_rolled_back = False + if rl_config.rollback_loss_threshold is not None: + model_state_before_batch = { + k: v.clone() for k, v in unwrap_model(model).state_dict().items() + } + optimizer_state_before_batch = copy.deepcopy(optimizer.state_dict()) + + # Calculate current number of features for this batch: (batch_i % num_features) + 1 + if rl_config.mix_k_features_in_opt > 1: + all_features = list(range(1, rl_config.num_features + 1)) + num_full_copies = rl_config.mix_k_features_in_opt // len(all_features) + remainder = rl_config.mix_k_features_in_opt % len(all_features) + # Take all features num_full_copies times + current_num_features_list = all_features * num_full_copies + # Sample the remainder + if remainder > 0: + current_num_features_list += random.sample(all_features, remainder) + + else: + current_num_features_list = [ + ((batch_i + 1) % rl_config.num_features) + 1 + ] + + start_generation_time = time.time() + + gen_results: list[PathGenerationResult] = [] + for feat_idx, current_num_features in enumerate(current_num_features_list): + # Sample seq_len for this rollout if randomization is enabled + # Use synchronized RNG to ensure same seq_len across distributed workers + if rl_config.randomize_seq_len and seq_len > 3: + seq_len_rng = random.Random( + (rl_config.seed or 0) + batch_i * 10_000 + feat_idx + ) + current_seq_len = seq_len_rng.randint(3, seq_len) + else: + current_seq_len = seq_len + + # Compute normalized seq_len for y_style_encoder: (curr/max)*2-1 to be in [-1, 1] + # Create y_style in expanded shape [super_batch_size, 1] upfront + # and subselect for joint steps in generate_paths + if rl_config.seq_len_y_style_encoder: + normalized_seq_len = (current_seq_len / seq_len) * 2 - 1 + y_style = torch.full( + (batch_size * sub_batch_size, 
1), + normalized_seq_len, + device=device, + ) # [super_batch_size, 1] + else: + y_style = None + + # Sample bo_batch_size for this rollout if randomization is enabled + if rl_config.randomize_bo_batch_size and rl_config.bo_batch_size > 1: + current_bo_batch_size = random.randint(1, rl_config.bo_batch_size) + else: + current_bo_batch_size = rl_config.bo_batch_size + + # Compute a synchronized seed for sampler feature selection + # This ensures all distributed workers use the same dimensionality + sampler_seed = ( + (rl_config.seed or 0) + batch_i * 1_000 + current_num_features + ) + print( + f"{sampler_seed=}, {current_num_features=}, {current_bo_batch_size=}, {current_seq_len=}" + ) + + sampler = sampler_factory( + batch_size, + num_features=current_num_features, + device=device, + seed=sampler_seed, + ) + if rl_config.independent_noise_in_function_draws: + sampler = partial(sampler, independent_noise=True) + + # Check if the sampler provides its own num_features + if hasattr(sampler, "num_features"): + current_num_features = sampler.num_features + + # Sample joint_steps for joint rollout training + # Use current_seq_len (not seq_len) to ensure joint_steps < current_seq_len + joint_steps = 0 + if rl_config.joint_rollout_training is not None: + # Sample a random position to train on (0 to current_seq_len-1) + # joint_steps = training position, so trajectories are identical until then + if rl_config.joint_rollout_training == "remaining" and distributed: + # For "remaining" mode in DDP, use a synchronized seed across all workers + # to ensure the same joint_steps value, enabling proper gradient synchronization + sync_rng = random.Random( + (rl_config.seed or 0) + batch_i * 10_000 + feat_idx + 1412 + ) + joint_steps = sync_rng.randint(0, current_seq_len - 1) + else: + joint_steps = random.randint(0, current_seq_len - 1) + + gen_res = generate_paths( + model=model, + sampler=sampler, + batch_size=batch_size, + seq_len=current_seq_len, + sub_batch_size=sub_batch_size, + 
choose_next_in_set=rl_config.choose_next_in_set, + choice_set_size=rl_config.choice_set_size, + ei_selector=rl_config.ei_selector, + standardize_y=rl_config.standardize_y, + current_num_features=current_num_features, + device=device, + super_choice_set_factor=rl_config.super_choice_set_factor, + top_share=rl_config.choice_set_top_share, + joint_steps=joint_steps, + basemodel_ei_input=rl_config.basemodel_ei_input, + basemodel_for_ei=basemodel_for_ei, + mixed_precision=rl_config.mixed_precision_path_generation, + bo_batch_size=current_bo_batch_size, + binary_feature_likelihood=rl_config.binary_feature_likelihood, + around_train_point_share=rl_config.around_train_point_share, + around_train_point_std=rl_config.around_train_point_std, + y_style=y_style, + ) + + generation_time = time.time() - start_generation_time + + start_draw_time = time.time() + + draw_size = 100_000 + draw_x = torch.rand( + batch_size, draw_size, current_num_features, device=device + ) + # Apply binary feature constraint to draw samples if applicable + if gen_res.binary_features_mask is not None: + binary_mask_expanded = gen_res.binary_features_mask.unsqueeze( + 1 + ).expand(-1, draw_size, -1) + draw_x = torch.where( + binary_mask_expanded, (draw_x > 0.5).float(), draw_x + ) + draw, _ = sampler( + draw_x, + independent_noise=True, + ) + + def compute_quantiles( + y_version, use_draw: bool = True, draw=draw, draw_size=draw_size + ): + y_view = y_version.view(batch_size, -1) + if use_draw: + # Use draw samples to compute quantile positions + sorted_ref = draw.view(batch_size, draw_size).sort(1).values + ref_size = draw_size + else: + # Use batch data itself to compute quantile positions + sorted_ref = y_view.sort(1).values + ref_size = y_view.shape[1] + + quantiles = ( + torch.searchsorted(sorted_ref, y_view).float() / ref_size + ) + + return quantiles.view(batch_size * sub_batch_size, -1) + + def compute_standardized_ys( + y_version, use_draw: bool = False, draw=draw, draw_size=draw_size + ): + # 
Standardize per function (per batch) - z-score normalization + if use_draw: + # Use draw samples to compute mean and std + draw_view = draw.view(batch_size, draw_size) + mean = draw_view.mean(dim=1, keepdim=True) + std = draw_view.std(dim=1, keepdim=True) + else: + # Use batch data to compute mean and std + y_view = y_version.view(batch_size, -1) + mean = y_view.mean(dim=1, keepdim=True) + std = y_view.std(dim=1, keepdim=True) + std = std.clamp(min=1e-8) + y_view = y_version.view(batch_size, -1) + standardized = (y_view - mean) / std + return standardized.view(batch_size * sub_batch_size, -1) + + use_draw_for_reward = rl_config.reward.standardization_source == "draw" + + # Compute quantiles and standardized values for REWARDS + # (respects standardization_source setting) + reward_y_quantiles = compute_quantiles( + gen_res.ys, use_draw=use_draw_for_reward + ) + reward_target_y_quantiles = compute_quantiles( + gen_res.target_ys, use_draw=use_draw_for_reward + ) + reward_standardized_ys = compute_standardized_ys( + gen_res.ys, use_draw=use_draw_for_reward + ) + reward_standardized_target_ys = compute_standardized_ys( + gen_res.target_ys, use_draw=use_draw_for_reward + ) + + # Compute quantiles and standardized values for PLOTS + # (always uses draw for quantiles, batch for standardized) + plot_y_quantiles = compute_quantiles(gen_res.ys, use_draw=True) + plot_target_y_quantiles = compute_quantiles( + gen_res.target_ys, use_draw=True + ) + plot_standardized_ys = compute_standardized_ys( + gen_res.ys, use_draw=True + ) + plot_standardized_target_ys = compute_standardized_ys( + gen_res.target_ys, use_draw=True + ) + + draw_time = time.time() - start_draw_time + + # These rewards make sense, but they do have the problem that randomly roll-out might just be disadvantaged by earlier mishaps in it + # This could be fixed by keeping all trajectories the same until we break out into sub trajectories and then only train at the break-out point + + gen_res.normalized_avg_rewards, 
gen_res.unnormalized_avg_rewards = ( + rl_config.reward.compute_reward( + gen_res.ys.view(batch_size, sub_batch_size, -1), + gen_res.target_ys.view(batch_size, sub_batch_size, -1), + reward_y_quantiles.view(batch_size, sub_batch_size, -1), + reward_target_y_quantiles.view(batch_size, sub_batch_size, -1), + reward_standardized_ys.view(batch_size, sub_batch_size, -1), + reward_standardized_target_ys.view( + batch_size, sub_batch_size, -1 + ), + ) + ) + + # For batched BO, copy reward from last position in each batch to all positions + if gen_res.bo_batch_size > 1: + reward_seq_len = gen_res.normalized_avg_rewards.shape[-1] + normalized = gen_res.normalized_avg_rewards.clone() + unnormalized = gen_res.unnormalized_avg_rewards.clone() + + for batch_start in range(0, reward_seq_len, gen_res.bo_batch_size): + batch_end = min( + batch_start + gen_res.bo_batch_size, reward_seq_len + ) + # Copy reward from last position in batch to all positions in batch + normalized[..., batch_start:batch_end] = normalized[ + ..., batch_end - 1 : batch_end + ] + unnormalized[..., batch_start:batch_end] = unnormalized[ + ..., batch_end - 1 : batch_end + ] + + gen_res.normalized_avg_rewards = normalized + gen_res.unnormalized_avg_rewards = unnormalized + gen_res.draw = draw + # Store plot versions (not reward versions) for tensorboard logging + gen_res.y_quantiles = plot_y_quantiles + gen_res.target_y_quantiles = plot_target_y_quantiles + gen_res.standardized_ys = plot_standardized_ys + gen_res.standardized_target_ys = plot_standardized_target_ys + gen_res.draw_size = draw_size + + gen_results.append(gen_res) + + # TODO put all the relevant var's, that is for stuff without plotting just normalized_avg_rewards, I believe, into the gen_res and then generate multiple gen_res and cycle through them during training + # when we do the features, we should enable it to just be a subset of the features + # and when we run the trainings with this setting we should maybe think about reducing the batch 
size!? + + if writer is not None: + # Log current curriculum seq_len + writer.add_scalar("curriculum/seq_len", seq_len, batch_i) + + def avg( + values: list[float], + ) -> float: + return sum(values) / len(values) + + # Aggregated metrics (average across all gen_results) + avg_reward = avg( + [gr.unnormalized_avg_rewards.mean().item() for gr in gen_results] + ) + writer.add_scalar( + "avg_reward", + avg_reward, + batch_i, + ) + + writer.add_scalar( + "reward_metrics/std_mean", + avg( + [ + gr.unnormalized_avg_rewards.std(1, keepdim=True) + .mean() + .item() + for gr in gen_results + ] + ), + batch_i, + ) + writer.add_scalar( + "reward_metrics/std_min", + min( + gr.unnormalized_avg_rewards.std(1, keepdim=True).min().item() + for gr in gen_results + ), + batch_i, + ) + writer.add_scalar( + "retrieved_y/mean", + avg([gr.ys.mean().item() for gr in gen_results]), + batch_i, + ) + writer.add_scalar( + "retrieved_y/last", + avg([gr.ys[:, -1].mean().item() for gr in gen_results]), + batch_i, + ) + writer.add_scalar( + "retrieved_y/max", + avg([gr.ys.max(1).values.mean().item() for gr in gen_results]), + batch_i, + ) + + # Regret metrics + for name in ["noisy", "noiseless"]: + regrets = [] + for gr in gen_results: + max_draw_per_function = gr.draw.max(1).values + used_ys = gr.ys if name == "noisy" else gr.target_ys + regret = ( + ( + max_draw_per_function[:, None] + - used_ys.max(1).values.view(batch_size, sub_batch_size) + ) + .mean() + .item() + ) + regrets.append(regret) + writer.add_scalar( + f"retrieved_y/final_{name}_regret_v_{draw_size}rs", + avg(regrets), + batch_i, + ) + + # Noiseless regret at step increments of 5 + step_increments = list(range(5, seq_len + 1, 5)) + for step_cutoff in step_increments: + regrets_at_step = [] + for gr in gen_results: + max_draw_per_function = gr.draw.max(1).values + used_ys = gr.target_ys[:, :step_cutoff] + max_y_up_to_step = used_ys.max(1).values.view( + batch_size, sub_batch_size + ) + regret = ( + (max_draw_per_function[:, None] 
- max_y_up_to_step) + .mean() + .item() + ) + regrets_at_step.append(regret) + writer.add_scalar( + f"retrieved_y/noiseless_regret_at_step_{step_cutoff}", + avg(regrets_at_step), + batch_i, + ) + + # Distance to incumbent (Euclidean distance to best point seen so far) + # Averaged over groups of 5 steps + # Skip this logging when using randomize_seq_len as sequence lengths vary + if not rl_config.randomize_seq_len: + for group_start in range(0, seq_len, 5): + group_end = min(group_start + 5, seq_len) + all_distances = [] + for gr in gen_results: + for step_i in range(group_start, group_end): + if step_i == 0: + # At step 0, there's no previous incumbent + continue + # Incumbent is the argmax of target_ys among steps 0 to step_i-1 + incumbent_indices = gr.target_ys[:, :step_i].argmax( + dim=1 + ) + batch_indices = torch.arange( + gr.predictions.shape[0], + device=gr.predictions.device, + ) + incumbent_x = gr.predictions[ + batch_indices, incumbent_indices + ] + current_x = gr.predictions[:, step_i] + # Euclidean distance + distance = ( + (current_x - incumbent_x).pow(2).sum(dim=-1).sqrt() + ) + all_distances.append(distance.mean().item()) + + if all_distances: + avg_distance = sum(all_distances) / len(all_distances) + writer.add_scalar( + f"incumbent_distance/steps_{group_start + 1}_to_{group_end}", + avg_distance, + batch_i, + ) + + # Quantile metrics + writer.add_scalar( + "retrieved_y/mean_quantile", + avg([gr.target_y_quantiles.mean().item() for gr in gen_results]), + batch_i, + ) + writer.add_scalar( + "retrieved_y/last_quantile", + avg( + [ + gr.target_y_quantiles[:, -1].mean().item() + for gr in gen_results + ] + ), + batch_i, + ) + max_quantile = avg( + [ + gr.target_y_quantiles.max(1).values.mean().item() + for gr in gen_results + ] + ) + writer.add_scalar( + "retrieved_y/max_quantile", + max_quantile, + batch_i, + ) + + print( + "average max quantile", + max_quantile, + ) + print("avereage reward", avg_reward) + print("generation took ", generation_time, 
"seconds") + + writer.add_scalar( + "retrieved_y/max_standardized", + avg( + [ + gr.standardized_target_ys.max(1).values.mean().item() + for gr in gen_results + ] + ), + batch_i, + ) + + # Per-feature metrics + for gr in gen_results: + num_feats = gr.current_num_features + local_draw_size = gr.draw_size + max_draw_per_function = gr.draw.max(1).values + + writer.add_scalar( + f"{num_feats}_features/avg_reward", + gr.unnormalized_avg_rewards.mean().item(), + batch_i, + ) + for name, used_ys in [ + ("noisy", gr.ys), + ("noiseless", gr.target_ys), + ]: + writer.add_scalar( + f"{num_feats}_features/retrieved_y/final_{name}_regret_v_{local_draw_size}rs", + ( + max_draw_per_function[:, None] + - used_ys.max(1).values.view(batch_size, sub_batch_size) + ) + .mean() + .item(), + batch_i, + ) + writer.add_scalar( + f"{num_feats}_features/retrieved_y/max_quantile", + gr.target_y_quantiles.max(1).values.mean().item(), + batch_i, + ) + + # Rollout distribution metrics (entropy, max_prob, sampled_prob) + # Averaged across all steps and all samples in the batch + # Note: tensors may have different shapes (batch_size during joint steps, + # super_batch_size after), so we concatenate and compute mean safely + all_entropies = [] + all_max_probs = [] + all_sampled_probs = [] + for gr in gen_results: + if gr.step_entropies: + # Concatenate all step tensors (handles different shapes) + entropies = torch.cat([e.flatten() for e in gr.step_entropies]) + max_probs = torch.cat([m.flatten() for m in gr.step_max_probs]) + sampled_probs = torch.cat( + [s.flatten() for s in gr.step_sampled_probs] + ) + # Mean across all samples + all_entropies.append(entropies.mean().item()) + all_max_probs.append(max_probs.mean().item()) + all_sampled_probs.append(sampled_probs.mean().item()) + + if all_entropies: + writer.add_scalar( + "rollout/entropy", + sum(all_entropies) / len(all_entropies), + batch_i, + ) + writer.add_scalar( + "rollout/max_prob", + sum(all_max_probs) / len(all_max_probs), + batch_i, + ) 
+ writer.add_scalar( + "rollout/sampled_prob", + sum(all_sampled_probs) / len(all_sampled_probs), + batch_i, + ) + + # training actually starts here + start_training_loop_time = time.time() + + old_model.load_state_dict(copy.deepcopy(unwrap_model(model).state_dict())) + + losses: list[float] = [] + nan_encountered = False + + # Tracking for eps clamping analysis - per repetition + per_rep_stats: list[dict[str, int]] = [] + + # Tracking for zero variance filtering + zero_variance_filtered_samples = 0 + total_filtering_samples = 0 + + # Build list of (gen_res_idx, seq_idx) pairs to iterate over + if rl_config.joint_rollout_training == "single": + # Only train on the first non-joint position (the split point) + training_steps = [ + (gr_idx, gr.joint_steps) for gr_idx, gr in enumerate(gen_results) + ] + elif rl_config.joint_rollout_training == "remaining": + # Train on all positions from the split point onwards + training_steps = [ + (gr_idx, sep) + for gr_idx, gr in enumerate(gen_results) + for sep in range(gr.joint_steps, gr.predictions.shape[1]) + ] + else: + # No joint rollout training: train on all positions + training_steps = [ + (gr_idx, sep) + for gr_idx, gr in enumerate(gen_results) + for sep in range(gr.predictions.shape[1]) + ] + + for _rep_idx in range(rl_config.experience_repetitions): + # Per-repetition tracking + rep_total_samples = 0 + rep_ratio_above_eps = 0 + rep_ratio_below_eps = 0 + rep_clamp_active = 0 + random.shuffle(training_steps) + + for gr_idx, i in training_steps: + gen_res = gen_results[gr_idx] + current_num_features = gen_res.current_num_features + + with torch.amp.autocast(dtype=torch.float16, device_type="cuda"): + optimizer.zero_grad(set_to_none=True) + if i == 0: + train_x = gen_res.predictions[:, :i] + train_y = gen_res.ys[:, :i, None] + full_train_x = torch.cat((train_x, train_y), -1) + else: + train_x, train_y = preprocess_train_x_and_y( + gen_res.predictions[:, :i], + gen_res.ys[:, :i, None], + rl_config.standardize_y, + 
bo_batch_size=gen_res.bo_batch_size, + ) + full_train_x = torch.cat((train_x, train_y), -1) + full_test_x = torch.full( + (batch_size * sub_batch_size, 1, current_num_features + 1), + torch.nan, + device=device, + ) + if rl_config.choose_next_in_set: + # Prepare options with base model EI if enabled + if rl_config.basemodel_ei_input: + # Reuse the precomputed EI values from generation + basemodel_ei = gen_res.basemodel_ei_values[i] + # Create augmented options with EI + opts = torch.cat( + [gen_res.options[i], basemodel_ei], dim=-1 + ) + # Add 0s to train_x for the EI column + train_x = torch.cat( + [ + train_x, + torch.zeros( + train_x.shape[0], + train_x.shape[1], + 1, + device=device, + ) + + 0.5, + ], + dim=-1, + ) + else: + opts = gen_res.options[i] + + # Use stored y_style from path generation (already in super_batch_size shape) + with torch.no_grad(): + logits_old = transform_logits( + old_model, + old_model( + x=train_x, + y=train_y, + test_x=opts, + y_style=gen_res.y_style, + ), + rl_config.ei_selector, + ) # shape: superb x num options + + logits_new = transform_logits( + model, + model( + x=train_x, + y=train_y, + test_x=opts, + y_style=gen_res.y_style, + ), + rl_config.ei_selector, + ) # shape: superb x num options + + log_p_new = logits_new.log_softmax(1)[ + torch.arange(super_batch_size), + gen_res.chosen_options[i], + ] + log_p_old = logits_old.log_softmax(1)[ + torch.arange(super_batch_size), + gen_res.chosen_options[i], + ] + pred_ratio = (log_p_new - log_p_old).exp() + + else: + logits_new = model( + x=full_train_x, y=None, test_x=full_test_x + )[:, :, 0].squeeze(1) + with torch.no_grad(): + logits_old = ( + old_model( + x=full_train_x, y=None, test_x=full_test_x + )[:, :, 0] + .squeeze(1) + .detach() + ) + # criterion computes the neg log likelihood, that is why the ratio is inverted + # Note: old_model is always unwrapped, but model might be DDP-wrapped + nll_old = old_model.criterion( + logits_old, gen_res.predictions[:, i] + ) + nll_new = 
unwrap_model(model).criterion( + logits_new, gen_res.predictions[:, i] + ) + log_p_new = -nll_new + pred_ratio = (nll_old - nll_new).exp() + normalized_avg_rewards = gen_res.normalized_avg_rewards.view( + batch_size * sub_batch_size, -1 + ) + + # Filter low-magnitude functions per time step + # Compute average magnitude (L1 / sub_batch_size) per function + rewards_at_step = normalized_avg_rewards[:, i].view( + batch_size, sub_batch_size + ) + avg_magnitude_per_func = rewards_at_step.abs().mean( + dim=1 + ) # [batch_size] + + # Zero out rewards for functions with low average magnitude + # and scale up remaining rewards to maintain gradient magnitude + if rl_config.filter_rewards_up_to_magnitude is not None: + low_magnitude_mask = ( + avg_magnitude_per_func + <= rl_config.filter_rewards_up_to_magnitude + ) # [batch_size] + effective_batch_size = ( + batch_size - low_magnitude_mask.sum().item() + ) + + # Track zero variance filtered samples + zero_variance_filtered_samples += ( + low_magnitude_mask.sum().item() + ) + total_filtering_samples += batch_size + + # Expand mask to super_batch_size and zero out rewards + full_zero_mask = ( + low_magnitude_mask.unsqueeze(1) + .expand(-1, sub_batch_size) + .reshape(-1) + ) # [batch_size * sub_batch_size] + # Create a copy of rewards and zero out low-magnitude functions + rewards_for_loss = normalized_avg_rewards[:, i].clone() + rewards_for_loss[full_zero_mask] = 0.0 + # Scale up by batch_size / effective_batch_size to maintain gradient magnitude + if effective_batch_size > 0: + rewards_for_loss = ( + rewards_for_loss * batch_size / effective_batch_size + ) + else: + rewards_for_loss = normalized_avg_rewards[:, i] + + # Compute eps bounds (asymmetric if eps_low is specified) + eps_lower = ( + rl_config.eps_low if rl_config.eps_low is not None else eps + ) + eps_upper = eps + + # Track eps clamping statistics for this repetition + with torch.no_grad(): + n_samples = pred_ratio.numel() + rep_total_samples += n_samples + 
rep_ratio_above_eps += ( + (pred_ratio > 1 + eps_upper).sum().item() + ) + rep_ratio_below_eps += ( + (pred_ratio < 1 - eps_lower).sum().item() + ) + # Clamp is active when pred_ratio is outside bounds + rep_clamp_active += ( + ( + (pred_ratio > 1 + eps_upper) + | (pred_ratio < 1 - eps_lower) + ) + .sum() + .item() + ) + + if rl_config.algorithm == "grpo": + # GRPO: min(ratio * advantage, clamp(ratio) * advantage) + goal = pred_ratio * rewards_for_loss + clamped_goal = ( + pred_ratio.clamp(1 - eps_lower, 1 + eps_upper) + * rewards_for_loss + ) + # GRPO formulation is a maximization and we minimize + loss = -torch.min(goal, clamped_goal) + else: + # CISPO: stop_grad(clamp(ratio)) * advantage * log_p_new + assert rl_config.algorithm == "cispo" + clamped_weight = pred_ratio.detach().clamp( + 1 - eps_lower, 1 + eps_upper + ) + # Maximize weighted log probability, so negate for minimization + loss = -clamped_weight * rewards_for_loss * log_p_new + + loss = loss.mean() + + scaler.scale(loss).backward() + + losses.append(loss.detach().item()) + + if torch.isnan(loss): + if is_main: + print("nan loss") + nan_encountered = True + break + + if ( + rl_config.grad_clip_norm is not None + and rl_config.grad_clip_norm > 0 + ): + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_( + model.parameters(), rl_config.grad_clip_norm + ) + + scaler.step(optimizer) + scaler.update() + + if nan_encountered: + break + + # Store this repetition's stats + per_rep_stats.append( + { + "total_samples": rep_total_samples, + "ratio_above_eps": rep_ratio_above_eps, + "ratio_below_eps": rep_ratio_below_eps, + "clamp_active": rep_clamp_active, + } + ) + + training_loop_time = time.time() - start_training_loop_time + mean_loss = sum(losses) / len(losses) if losses else float("nan") + + # Synchronize mean_loss across all workers for coordinated rollback decision + if distributed: + mean_loss_tensor = torch.tensor(mean_loss, device=device) + dist.all_reduce(mean_loss_tensor, 
op=dist.ReduceOp.AVG) + mean_loss = mean_loss_tensor.item() + + # Check if batch should be rolled back due to high loss + if ( + rl_config.rollback_loss_threshold is not None + and mean_loss > rl_config.rollback_loss_threshold + ): + # Rollback model and optimizer state to before this batch + unwrap_model(model).load_state_dict(model_state_before_batch) + optimizer.load_state_dict(optimizer_state_before_batch) + batch_rolled_back = True + if is_main: + print( + f"Batch {batch_i} rolled back: mean loss {mean_loss:.4f} > threshold {rl_config.rollback_loss_threshold}" + ) + + batch_losses.append(mean_loss) + if is_main: + print("mean loss", mean_loss, "after batch", batch_i) + + if writer is not None and not math.isnan(mean_loss): + print(f"rank {rank} is pushing times for batch {batch_i}") + writer.add_scalar("loss/mean_batch_loss", mean_loss, batch_i) + writer.add_scalar( + "optimizer/lr", optimizer.param_groups[0]["lr"], batch_i + ) + writer.add_scalar("time/generation_time", generation_time, batch_i) + writer.add_scalar("time/draw_time", draw_time, batch_i) + writer.add_scalar( + "time/training_loop_time", training_loop_time, batch_i + ) + + # Log eps clamping statistics - per repetition + for rep_idx, stats in enumerate(per_rep_stats): + if stats["total_samples"] > 0: + ratio_above_eps_frac = ( + stats["ratio_above_eps"] / stats["total_samples"] + ) + ratio_below_eps_frac = ( + stats["ratio_below_eps"] / stats["total_samples"] + ) + clamp_active_frac = ( + stats["clamp_active"] / stats["total_samples"] + ) + + writer.add_scalar( + f"ppo_clipping/rep_{rep_idx}/ratio_above_eps_frac", + ratio_above_eps_frac, + batch_i, + ) + writer.add_scalar( + f"ppo_clipping/rep_{rep_idx}/ratio_below_eps_frac", + ratio_below_eps_frac, + batch_i, + ) + writer.add_scalar( + f"ppo_clipping/rep_{rep_idx}/clamp_active_frac", + clamp_active_frac, + batch_i, + ) + writer.add_scalar( + f"ppo_clipping/rep_{rep_idx}/ratio_within_bounds_frac", + 1.0 - clamp_active_frac, + batch_i, + ) + + # 
Also log aggregated stats across all repetitions + total_samples_all = sum(s["total_samples"] for s in per_rep_stats) + if total_samples_all > 0: + total_above = sum(s["ratio_above_eps"] for s in per_rep_stats) + total_below = sum(s["ratio_below_eps"] for s in per_rep_stats) + total_clamp = sum(s["clamp_active"] for s in per_rep_stats) + + writer.add_scalar( + "ppo_clipping/total/ratio_above_eps_frac", + total_above / total_samples_all, + batch_i, + ) + writer.add_scalar( + "ppo_clipping/total/ratio_below_eps_frac", + total_below / total_samples_all, + batch_i, + ) + writer.add_scalar( + "ppo_clipping/total/clamp_active_frac", + total_clamp / total_samples_all, + batch_i, + ) + + if per_rep_stats: + print(" PPO clipping stats per repetition:") + for rep_idx, stats in enumerate(per_rep_stats): + if stats["total_samples"] > 0: + print( + f" rep_{rep_idx}: " + f"above_eps={stats['ratio_above_eps'] / stats['total_samples']:.4f}, " + f"below_eps={stats['ratio_below_eps'] / stats['total_samples']:.4f}, " + f"clamp_active={stats['clamp_active'] / stats['total_samples']:.4f}" + ) + + # Log whether this batch was rolled back (1) or not (0) + writer.add_scalar( + "training/batch_rolled_back", + 1 if batch_rolled_back else 0, + batch_i, + ) + + if is_main and total_filtering_samples > 0: + zero_variance_filtered_share = ( + zero_variance_filtered_samples / total_filtering_samples + ) + writer.add_scalar( + "reward_metrics/zero_variance_filtered_share", + zero_variance_filtered_share, + batch_i, + ) + + scheduler.step() + + # Save checkpoint after each batch (only on main process) + if rl_config.checkpoint_save_path is not None and is_main: + _save_checkpoint( + model=unwrap_model(model), + base_train_config=base_train_config, + rl_config=rl_config, + output_path=rl_config.checkpoint_save_path, + batch_i=batch_i, + optimizer=optimizer, + scaler=scaler, + scheduler=scheduler, + ) + + # Save final checkpoint only on main process (legacy path for backwards compatibility) + + 
def _save_checkpoint(
    model: torch.nn.Module,
    base_train_config: base_config.BaseConfig,
    rl_config: RLConfig,
    output_path: str,
    batch_i: int | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    scaler: torch.amp.GradScaler | None = None,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
) -> None:
    """Save a training checkpoint to the specified path.

    Model weights are detached and moved to CPU so the checkpoint can be
    loaded on any device. Optimizer/scaler/scheduler state and the batch
    index are only written when provided, which is what makes a checkpoint
    resumable (see ``load_rl_checkpoint``).

    Args:
        model: The model to save.
        base_train_config: The base training config.
        rl_config: The RL config.
        output_path: Path to save the checkpoint.
        batch_i: Current batch index (optional, for resumable checkpoints).
        optimizer: Optimizer state (optional, for resumable checkpoints).
        scaler: GradScaler state (optional, for resumable checkpoints).
        scheduler: LR scheduler state (optional, for resumable checkpoints).
    """
    checkpoint = {
        "model_state_dict": {
            k: v.detach().cpu() for k, v in model.state_dict().items()
        },
        "base_train_config": base_train_config.to_dict(),
        "config": rl_config.to_dict(),
    }

    if batch_i is not None:
        checkpoint["batch_i"] = batch_i
    if optimizer is not None:
        checkpoint["optimizer_state_dict"] = optimizer.state_dict()
    if scaler is not None:
        checkpoint["scaler_state_dict"] = scaler.state_dict()
    if scheduler is not None:
        checkpoint["scheduler_state_dict"] = scheduler.state_dict()

    local_save(checkpoint, output_path)


def should_load_rl_checkpoint(
    rl_config: RLConfig,
    check_path_exists_function: tp.Callable[[str], bool] | None = None,
) -> bool:
    """Check if we should load a checkpoint.

    Returns True if:
    - checkpoint_load_path is set AND
    - Either load_path != save_path, OR the file exists

    Args:
        rl_config: Config carrying ``checkpoint_load_path`` and
            ``checkpoint_save_path``.
        check_path_exists_function: Existence predicate used when load and
            save paths coincide; defaults to ``local_exists``.
    """
    if rl_config.checkpoint_load_path is None:
        return False

    if check_path_exists_function is None:
        check_path_exists_function = local_exists

    # Equivalent to the spec above. The original expression
    # `(save != load) or ((save == load) and exists)` repeated the path
    # comparison; the second occurrence was redundant.
    return (
        rl_config.checkpoint_save_path != rl_config.checkpoint_load_path
        or check_path_exists_function(rl_config.checkpoint_load_path)
    )


def load_rl_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    scaler: torch.amp.GradScaler,
    checkpoint_load_path: str,
    device: str,
    load_function: tp.Callable | None = None,
) -> int:
    """Load a checkpoint and restore training state.

    Args:
        model: The model to load state into
        optimizer: The optimizer to load state into
        scheduler: The scheduler to load state into (or fast-forward)
        scaler: The GradScaler to load state into
        checkpoint_load_path: Path to load checkpoint from
        device: Device to map tensors to
        load_function: Custom load function (defaults to local_load)

    Returns:
        The batch index to resume from (checkpoint batch_i + 1), or 0 when
        the checkpoint carries no batch index.

    Raises:
        ValueError: If the checkpoint does not contain a model state dict.
    """
    print(f"Loading checkpoint from {checkpoint_load_path}")

    if load_function is None:
        load_function = local_load

    try:
        checkpoint = load_function(checkpoint_load_path, map_location=device)

        if not (isinstance(checkpoint, dict) and "model_state_dict" in checkpoint):
            raise ValueError(
                f"Checkpoint does not contain 'model_state_dict'. Keys: {checkpoint.keys() if isinstance(checkpoint, dict) else type(checkpoint)}"
            )

        model.load_state_dict(checkpoint["model_state_dict"], strict=True)

        if "optimizer_state_dict" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        if "scaler_state_dict" in checkpoint:
            scaler.load_state_dict(checkpoint["scaler_state_dict"])

        if "batch_i" not in checkpoint:
            print("Checkpoint does not contain batch index, starting from 0")
            return 0

        start_batch = checkpoint["batch_i"] + 1
        print(f"Resuming from batch {start_batch}")

        if "scheduler_state_dict" in checkpoint:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        else:
            # No saved scheduler state: fast-forward so the LR matches
            # where training left off.
            for _ in range(start_batch):
                scheduler.step()

        return start_batch
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        # Bare `raise` re-raises with the original traceback intact
        # (`raise e` would restart the traceback at this line).
        raise


def load_config_and_model(
    path: str = "tree/pfns/runs/singletaskgp2_clusteredx_0/checkpoint.pt",
    map_location: str | None = None,
) -> Tuple[MainConfig, transformer.TableTransformer]:
    """
    Load a config and model from a checkpoint file on the local filesystem.

    Args:
        path: The path to the checkpoint file.
        map_location: The device to load the tensors to. Same as torch.load,
            e.g. "cpu" or "cuda:0". Defaults to "cuda" when available,
            otherwise "cpu".

    Returns:
        A tuple of (config, model), with the model in eval mode and moved
        to ``map_location``.
    """
    if map_location is None:
        map_location = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): torch.load unpickles arbitrary objects here (the config
    # dict), so checkpoints are assumed trusted — do not point this at
    # untrusted files.
    checkpoint = torch.load(path, map_location=map_location)

    config_dict = checkpoint["config"]

    c = MainConfig.from_dict(config_dict)
    model = c.model.create_model()
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return c, model.to(map_location)
"""Tests for PFN caching functionality.

Tests verify:
1. Gradient flow through test_x when using cached training context
2. Output consistency between cached and non-cached modes
3. Batch broadcast behavior with cached KV
4. Sequential vs batch evaluation equivalence
"""

import pytest
import torch

from pfns.model.transformer import TableTransformer


device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Testing PFN caching on {device=}.")

dtype = torch.float32  # Use float32 for gradient tests


@pytest.fixture
def minimal_transformer():
    """Create a minimal TableTransformer for testing."""
    model = TableTransformer(
        ninp=32,
        nhead=2,
        nhid=64,
        nlayers=2,
        batch_first=True,
        cache_trainset_representation=False,
    )
    model = model.to(device=device, dtype=dtype)
    # Add small non-zero values to weights to make outputs more interesting
    # (fresh weights may otherwise produce near-constant outputs).
    with torch.no_grad():
        for p in model.parameters():
            p.add_(0.01 * torch.randn_like(p))
    return model


@pytest.fixture
def sample_data():
    """Generate sample training and test data.

    Returns a dict with (batch, seq, feature)-shaped tensors plus the sizes
    used to build them. Note: torch's global RNG is used unseeded, so data
    differs between runs.
    """
    n_train = 10
    n_test = 5
    n_features = 4
    batch_size = 1  # Base batch size for training data

    train_x = torch.randn(batch_size, n_train, n_features, device=device, dtype=dtype)
    train_y = torch.randn(batch_size, n_train, 1, device=device, dtype=dtype)
    test_x = torch.randn(batch_size, n_test, n_features, device=device, dtype=dtype)

    return {
        "train_x": train_x,
        "train_y": train_y,
        "test_x": test_x,
        "n_train": n_train,
        "n_test": n_test,
        "n_features": n_features,
    }


def test_cache_output_consistency(minimal_transformer, sample_data):
    """Test that cached and non-cached outputs are identical."""
    model = minimal_transformer
    train_x = sample_data["train_x"]
    train_y = sample_data["train_y"]
    test_x = sample_data["test_x"]

    # Non-cached forward pass
    model.cache_trainset_representation = False
    with torch.no_grad():
        output_no_cache = model(x=train_x, y=train_y, test_x=test_x)

    # Cached forward pass: first populate cache, then use it
    model.cache_trainset_representation = True
    model.empty_trainset_representation_cache()

    with torch.no_grad():
        # Populate cache with training data
        _ = model(x=train_x, y=train_y, test_x=test_x[:, :1, :])  # dummy test

        # Use cache for actual test
        output_with_cache = model(x=None, y=None, test_x=test_x)

    # Clear cache so later tests/fixtures see a clean model state
    model.empty_trainset_representation_cache()
    model.cache_trainset_representation = False

    assert torch.allclose(output_no_cache, output_with_cache, atol=1e-7), (
        f"Cached and non-cached outputs differ. "
        f"Max diff: {(output_no_cache - output_with_cache).abs().max()}"
    )


def test_cache_gradient_flow(minimal_transformer, sample_data):
    """Test that gradients flow through test_x when using cached training context.

    This is critical for acquisition function optimization in BoTorch.
    """
    model = minimal_transformer
    train_x = sample_data["train_x"]
    train_y = sample_data["train_y"]
    test_x = sample_data["test_x"].clone().requires_grad_(True)

    # Enable caching
    model.cache_trainset_representation = True
    model.empty_trainset_representation_cache()

    # Populate cache with no gradients on training data
    with torch.no_grad():
        _ = model(x=train_x, y=train_y, test_x=torch.zeros_like(test_x[:, :1, :]))

    # Forward pass using cache - should allow gradients on test_x
    output = model(x=None, y=None, test_x=test_x)

    # Backward pass
    loss = output.sum()
    loss.backward()

    # Verify gradients exist on test_x
    assert test_x.grad is not None, "test_x.grad should not be None"
    assert not torch.all(test_x.grad == 0), "test_x.grad should not be all zeros"

    # Cleanup
    model.empty_trainset_representation_cache()
    model.cache_trainset_representation = False


def test_cache_gradient_correctness(minimal_transformer, sample_data):
    """Test that gradients are identical with and without caching."""
    model = minimal_transformer
    train_x = sample_data["train_x"]
    train_y = sample_data["train_y"]

    # Test 1: Compute gradients without caching
    test_x_no_cache = sample_data["test_x"].clone().requires_grad_(True)
    model.cache_trainset_representation = False
    output_no_cache = model(x=train_x, y=train_y, test_x=test_x_no_cache)
    output_no_cache.sum().backward()
    grad_no_cache = test_x_no_cache.grad.clone()

    # Test 2: Compute gradients with caching
    test_x_with_cache = sample_data["test_x"].clone().requires_grad_(True)
    model.cache_trainset_representation = True
    model.empty_trainset_representation_cache()

    # Populate cache
    with torch.no_grad():
        _ = model(
            x=train_x, y=train_y, test_x=torch.zeros_like(test_x_with_cache[:, :1, :])
        )

    # Forward with cache
    output_with_cache = model(x=None, y=None, test_x=test_x_with_cache)
    output_with_cache.sum().backward()
    grad_with_cache = test_x_with_cache.grad.clone()

    # Cleanup
    model.empty_trainset_representation_cache()
    model.cache_trainset_representation = False

    # Compare gradients
    assert torch.allclose(grad_no_cache, grad_with_cache, atol=1e-7), (
        f"Gradients differ. Max diff: {(grad_no_cache - grad_with_cache).abs().max()}"
    )


def test_cache_batch_broadcast(minimal_transformer, sample_data):
    """Test that cached KV with batch_size=1 can be used with larger query batches.

    This is essential for efficient acquisition function optimization where
    we cache training context once and evaluate many test points in parallel.
    """
    model = minimal_transformer
    train_x = sample_data["train_x"]  # (1, n, d)
    train_y = sample_data["train_y"]  # (1, n, 1)
    n_features = sample_data["n_features"]

    # Create a larger batch of test points
    batch_size = 4
    n_test = 3
    test_x_batch = torch.randn(
        batch_size, n_test, n_features, device=device, dtype=dtype
    )

    # Compute reference: sequential evaluation (one at a time)
    model.cache_trainset_representation = False
    reference_outputs = []
    for i in range(batch_size):
        with torch.no_grad():
            out = model(x=train_x, y=train_y, test_x=test_x_batch[i : i + 1])
        reference_outputs.append(out)
    reference_output = torch.cat(reference_outputs, dim=0)

    # Enable caching with batch_size=1
    model.cache_trainset_representation = True
    model.empty_trainset_representation_cache()

    # Populate cache with training data (batch_size=1)
    with torch.no_grad():
        _ = model(
            x=train_x,
            y=train_y,
            test_x=torch.zeros(1, 1, n_features, device=device, dtype=dtype),
        )

    # Evaluate all test points in parallel using cache
    # The test_x has batch_size=4, but cache has batch_size=1
    # This requires broadcasting in the attention layer
    with torch.no_grad():
        # Expand train context for batch
        batch_output = model(x=None, y=None, test_x=test_x_batch)

    # Cleanup
    model.empty_trainset_representation_cache()
    model.cache_trainset_representation = False

    # Compare
outputs + assert batch_output.shape == reference_output.shape, ( + f"Shape mismatch: batch={batch_output.shape}, ref={reference_output.shape}" + ) + assert torch.allclose(reference_output, batch_output, atol=1e-7), ( + f"Batch output differs from sequential. " + f"Max diff: {(reference_output - batch_output).abs().max()}" + ) + + +def test_attention_caching(): + """Test MultiHeadAttention caching produces identical outputs and correct gradients.""" + from pfns.model.multi_head_attention import MultiHeadAttention + + embed_dim = 128 + nhead = 4 + n_seq_kv = 316 + n_seq_q = 534 + n_batch = 7 + + att = MultiHeadAttention( + input_size=embed_dim, + output_size=embed_dim, + d_k=embed_dim // nhead, + d_v=embed_dim // nhead, + nhead=nhead, + device=device, + dtype=dtype, + ) + + x_kv = torch.randn(n_batch, n_seq_kv, embed_dim, device=device, dtype=dtype) + x_q = torch.randn(n_batch, n_seq_q, embed_dim, device=device, dtype=dtype) + + # Test output consistency between cached and non-cached + y = att(x_q, x_kv, cache_kv=True) + y_ = att(x_q, use_cached_kv=True) + assert torch.allclose(y, y_, atol=1e-7), ( + f"Cached and non-cached attention outputs differ. 
" + f"Max diff: {(y - y_).abs().max()}" + ) + + # Gradients should fail for train part (caching requires no grad) + x_q_grad = x_q.clone().requires_grad_(True) + with pytest.raises(AssertionError): + att(x_q_grad, x_kv, cache_kv=True) + + # Gradients should flow through test part when using cache + x_q_grad = x_q.clone().requires_grad_(True) + y_ = att(x_q_grad, x_kv, use_cached_kv=True) + y_.mean().backward() + assert x_q_grad.grad is not None + grads_with_cache = x_q_grad.grad.clone() + + # Compute gradients without caching - should match + x_q_grad = x_q.clone().requires_grad_(True) + y_ = att(x_q_grad, x_kv) + y_.mean().backward() + assert x_q_grad.grad is not None + grads_without_cache = x_q_grad.grad.clone() + + assert torch.allclose(grads_with_cache, grads_without_cache, atol=1e-7), ( + f"Attention gradients differ. " + f"Max diff: {(grads_with_cache - grads_without_cache).abs().max()}" + ) + + # Cleanup + att.empty_kv_cache() + + +def test_attention_cache_batch_broadcast(): + """Test MultiHeadAttention layer directly for batch broadcast with cached KV.""" + from pfns.model.multi_head_attention import MultiHeadAttention + + embed_dim = 32 + nhead = 2 + n_seq_kv = 10 + n_seq_q = 5 + + att = MultiHeadAttention( + input_size=embed_dim, + output_size=embed_dim, + d_k=embed_dim // nhead, + d_v=embed_dim // nhead, + nhead=nhead, + device=device, + dtype=dtype, + ) + + # KV with batch_size=1 + x_kv = torch.randn(1, n_seq_kv, embed_dim, device=device, dtype=dtype) + + # Query with batch_size=4 + batch_size_q = 4 + x_q = torch.randn(batch_size_q, n_seq_q, embed_dim, device=device, dtype=dtype) + + # Cache KV with batch_size=1 + with torch.no_grad(): + att(x_kv, x_kv, cache_kv=True) + + # Reference: compute each batch element separately by expanding KV + x_kv_expanded = x_kv.expand(batch_size_q, -1, -1) + with torch.no_grad(): + ref_output = att(x_q, x_kv_expanded) + + # Test using cache with batch_size=4 query + # Batch broadcast is supported - verify it works 
correctly + with torch.no_grad(): + cached_output = att(x_q, use_cached_kv=True) + + assert cached_output.shape == ref_output.shape + assert torch.allclose(ref_output, cached_output, atol=1e-7), ( + f"Cached attention output differs. " + f"Max diff: {(ref_output - cached_output).abs().max()}" + ) + + # Cleanup + att.empty_kv_cache() + + +def test_deterministic_seeded_outputs(minimal_transformer, sample_data): + """Test that with same seed, outputs are deterministic.""" + model = minimal_transformer + train_x = sample_data["train_x"] + train_y = sample_data["train_y"] + test_x = sample_data["test_x"] + + torch.manual_seed(42) + model.cache_trainset_representation = False + with torch.no_grad(): + output1 = model(x=train_x, y=train_y, test_x=test_x) + + torch.manual_seed(42) + with torch.no_grad(): + output2 = model(x=train_x, y=train_y, test_x=test_x) + + assert torch.allclose(output1, output2, atol=1e-7), "Outputs should be deterministic" + + +def test_cache_invalidation_on_data_change(minimal_transformer, sample_data): + """Test that cache is properly invalidated when training data changes. + + In a real BO loop, training data grows as new observations are added. + The cache must be invalidated when this happens to avoid stale predictions. 
+ """ + model = minimal_transformer + train_x = sample_data["train_x"].clone() + train_y = sample_data["train_y"].clone() + test_x = sample_data["test_x"] + n_features = sample_data["n_features"] + + # Step 1: Populate cache with initial training data + model.cache_trainset_representation = True + model.empty_trainset_representation_cache() + + with torch.no_grad(): + _ = model( + x=train_x, + y=train_y, + test_x=torch.zeros(1, 1, n_features, device=device, dtype=dtype), + ) + + # Get prediction with original cache + with torch.no_grad(): + output_original = model(x=None, y=None, test_x=test_x) + + # Step 2: Clear cache and repopulate with DIFFERENT training data + model.empty_trainset_representation_cache() + + # Create modified training data (add noise) + train_x_modified = train_x + 0.5 * torch.randn_like(train_x) + train_y_modified = train_y + 0.5 * torch.randn_like(train_y) + + with torch.no_grad(): + _ = model( + x=train_x_modified, + y=train_y_modified, + test_x=torch.zeros(1, 1, n_features, device=device, dtype=dtype), + ) + + # Get prediction with new cache + with torch.no_grad(): + output_modified = model(x=None, y=None, test_x=test_x) + + # Step 3: Verify outputs are DIFFERENT (cache was properly updated) + assert not torch.allclose(output_original, output_modified, atol=1e-3), ( + "Outputs should differ after training data change. " + "Cache may not have been properly invalidated/repopulated." + ) + + # Step 4: Verify the new output matches a fresh non-cached computation + model.cache_trainset_representation = False + model.empty_trainset_representation_cache() + + with torch.no_grad(): + output_fresh = model(x=train_x_modified, y=train_y_modified, test_x=test_x) + + assert torch.allclose(output_modified, output_fresh, atol=1e-7), ( + f"Cached output with new data should match fresh computation. 
" + f"Max diff: {(output_modified - output_fresh).abs().max()}" + ) + + +def test_cache_with_batch_first_false(): + """Test caching works correctly with batch_first=False configuration.""" + model = TableTransformer( + ninp=32, + nhead=2, + nhid=64, + nlayers=2, + batch_first=False, # Key difference from other tests + cache_trainset_representation=False, + ) + model = model.to(device=device, dtype=dtype) + + # Add small non-zero values to weights + with torch.no_grad(): + for p in model.parameters(): + p.add_(0.01 * torch.randn_like(p)) + + n_train, n_test, n_features = 10, 5, 4 + batch_size = 1 + + # For batch_first=False, shape is (seq, batch, features) + train_x = torch.randn(n_train, batch_size, n_features, device=device, dtype=dtype) + train_y = torch.randn(n_train, batch_size, 1, device=device, dtype=dtype) + test_x = torch.randn(n_test, batch_size, n_features, device=device, dtype=dtype) + + # Non-cached forward pass + model.cache_trainset_representation = False + with torch.no_grad(): + output_no_cache = model(x=train_x, y=train_y, test_x=test_x) + + # Cached forward pass + model.cache_trainset_representation = True + model.empty_trainset_representation_cache() + + with torch.no_grad(): + # Populate cache + _ = model( + x=train_x, + y=train_y, + test_x=torch.zeros(1, batch_size, n_features, device=device, dtype=dtype), + ) + + # Use cache + output_with_cache = model(x=None, y=None, test_x=test_x) + + # Cleanup + model.empty_trainset_representation_cache() + model.cache_trainset_representation = False + + assert torch.allclose(output_no_cache, output_with_cache, atol=1e-7), ( + f"batch_first=False: Cached and non-cached outputs differ. 
" + f"Max diff: {(output_no_cache - output_with_cache).abs().max()}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/priors/test_convert_prior_to_x_only_format.py b/tests/priors/test_convert_prior_to_x_only_format.py new file mode 100644 index 0000000..8188d40 --- /dev/null +++ b/tests/priors/test_convert_prior_to_x_only_format.py @@ -0,0 +1,424 @@ +# pyre-strict + +import torch +from pfns.priors.convert_prior_to_x_only_format import get_batch +from pfns.priors.prior import Batch + + +def create_simple_traditional_get_batch( + batch_size: int, + seq_len: int, + num_features: int, + single_eval_pos: int, + hyperparameters: dict | None = None, + n_targets_per_input: int = 1, + **kwargs, +) -> Batch: + """ + Simple traditional get_batch function for testing. + Creates linear functions: y = sum(x) + noise + """ + # Generate random input features + x = torch.rand(batch_size, seq_len, num_features) + + # Create simple linear relationship: y = sum(x_features) + small noise + y = x.sum(dim=2, keepdim=True) + torch.randn(batch_size, seq_len, 1) * 0.1 + + # For traditional format, target_y is the same as y but potentially repeated + target_y = y.repeat(1, 1, n_targets_per_input) + + return Batch( + x=x, + y=y, + target_y=target_y, + single_eval_pos=single_eval_pos, + ) + + +def create_complex_traditional_get_batch( + batch_size: int, + seq_len: int, + num_features: int, + single_eval_pos: int, + hyperparameters: dict | None = None, + n_targets_per_input: int = 1, + **kwargs, +) -> Batch: + """ + More complex traditional get_batch function with optional attributes for testing. 
+ """ + x = torch.rand(batch_size, seq_len, num_features) + y = x.mean(dim=2, keepdim=True) + torch.randn(batch_size, seq_len, 1) * 0.2 + target_y = y.repeat(1, 1, n_targets_per_input) + + # Add some optional attributes + style = torch.randn(batch_size, 3) # 3 style dimensions + y_style = torch.randn(batch_size, 2) # 2 y_style dimensions + + return Batch( + x=x, + y=y, + target_y=target_y, + single_eval_pos=single_eval_pos, + style=style, + y_style=y_style, + ) + + +class TestConvertPriorToXOnlyFormat: + """Test suite for the convert_prior_to_x_only_format wrapper function.""" + + def test_basic_conversion_functionality(self) -> None: + """Test that the basic conversion from traditional to x-only format works correctly.""" + batch_size = 4 + seq_len = 10 + num_features = 3 + single_eval_pos = 6 + + x_only_batch = get_batch( + batch_size=batch_size, + seq_len=seq_len, + num_features=num_features, + single_eval_pos=single_eval_pos, + get_batch=create_simple_traditional_get_batch, + ) + + # Check that x-only format fields are properly set + assert x_only_batch.y is None + assert x_only_batch.target_y is None + assert x_only_batch.x is not None + assert x_only_batch.test_x is not None + assert x_only_batch.target is not None + + # Check shapes + assert x_only_batch.x.shape == (batch_size, single_eval_pos, num_features + 1) + assert x_only_batch.test_x.shape == ( + batch_size, + seq_len - single_eval_pos, + num_features + 1, + ) + assert x_only_batch.target.shape == ( + batch_size, + seq_len - single_eval_pos, + num_features + 1, + ) + + def test_x_format_contains_concatenated_features_and_y(self) -> None: + """Test that x contains the concatenation of training features and training y values.""" + batch_size = 2 + seq_len = 8 + num_features = 2 + single_eval_pos = 5 + + # Create traditional batch first to compare + traditional_batch = create_simple_traditional_get_batch( + batch_size=batch_size, + seq_len=seq_len, + num_features=num_features, + 
single_eval_pos=single_eval_pos, + ) + + x_only_batch = get_batch( + batch_size=batch_size, + seq_len=seq_len, + num_features=num_features, + single_eval_pos=single_eval_pos, + get_batch=create_simple_traditional_get_batch, + ) + + # Extract training portions + x_train_traditional = traditional_batch.x[:, :single_eval_pos, :] + y_train_traditional = traditional_batch.y[:, :single_eval_pos, :] + + # Check that x in x-only format is [x_features, y_values] + expected_x = torch.cat([x_train_traditional, y_train_traditional], dim=2) + torch.testing.assert_close(x_only_batch.x, expected_x) + + def test_test_x_has_nan_for_y_values(self) -> None: + """Test that test_x contains NaN for y values (to be predicted).""" + batch_size = 3 + seq_len = 7 + num_features = 2 + single_eval_pos = 4 + + x_only_batch = get_batch( + batch_size=batch_size, + seq_len=seq_len, + num_features=num_features, + single_eval_pos=single_eval_pos, + get_batch=create_simple_traditional_get_batch, + ) + + # Check that the last column (y values) in test_x contains only NaN + y_column_in_test_x = x_only_batch.test_x[:, :, -1] + assert torch.all(torch.isnan(y_column_in_test_x)) + + # Check that the feature columns don't contain NaN + feature_columns_in_test_x = x_only_batch.test_x[:, :, :-1] + assert not torch.any(torch.isnan(feature_columns_in_test_x)) + + def test_target_contains_features_and_target_y(self) -> None: + """Test that target contains test features concatenated with target y values.""" + batch_size = 2 + seq_len = 9 + num_features = 3 + single_eval_pos = 5 + + # Create traditional batch to compare + traditional_batch = create_simple_traditional_get_batch( + batch_size=batch_size, + seq_len=seq_len, + num_features=num_features, + single_eval_pos=single_eval_pos, + ) + + x_only_batch = get_batch( + batch_size=batch_size, + seq_len=seq_len, + num_features=num_features, + single_eval_pos=single_eval_pos, + get_batch=create_simple_traditional_get_batch, + ) + + # Extract test portions from 
traditional format + x_test_traditional = traditional_batch.x[:, single_eval_pos:, :] + target_y_test_traditional = traditional_batch.target_y[:, single_eval_pos:, :] + + # Check that target in x-only format is [x_test_features, target_y_values] + expected_target = torch.cat( + [x_test_traditional, target_y_test_traditional], dim=2 + ) + torch.testing.assert_close(x_only_batch.target, expected_target) + + def test_multiple_targets_per_input(self) -> None: + """Test conversion with n_targets_per_input > 1.""" + batch_size = 2 + seq_len = 8 + num_features = 2 + single_eval_pos = 5 + n_targets_per_input = 3 + + x_only_batch = get_batch( + batch_size=batch_size, + seq_len=seq_len, + num_features=num_features, + single_eval_pos=single_eval_pos, + get_batch=create_simple_traditional_get_batch, + n_targets_per_input=n_targets_per_input, + ) + + # With multiple targets, the target shape should be expanded + expected_target_shape = ( + batch_size, + seq_len - single_eval_pos, + num_features + 1, + n_targets_per_input, + ) + assert x_only_batch.target.shape == expected_target_shape + + # x and test_x should still have the same shapes as single target case + assert x_only_batch.x.shape == (batch_size, single_eval_pos, num_features + 1) + assert x_only_batch.test_x.shape == ( + batch_size, + seq_len - single_eval_pos, + num_features + 1, + ) + + def test_preserves_optional_attributes(self) -> None: + """Test that optional attributes from the traditional batch are preserved.""" + batch_size = 2 + seq_len = 6 + num_features = 2 + single_eval_pos = 3 + + x_only_batch = get_batch( + batch_size=batch_size, + seq_len=seq_len, + num_features=num_features, + single_eval_pos=single_eval_pos, + get_batch=create_complex_traditional_get_batch, + ) + + # Check that optional attributes are preserved + assert x_only_batch.style is not None + assert x_only_batch.y_style is not None + assert x_only_batch.style.shape == (batch_size, 3) + assert x_only_batch.y_style.shape == (batch_size, 2) + 
assert x_only_batch.single_eval_pos == single_eval_pos + + def test_hyperparameters_passed_through(self) -> None: + """Test that hyperparameters are correctly passed to the wrapped function.""" + + def test_get_batch_with_hyperparams( + batch_size: int, + seq_len: int, + num_features: int, + single_eval_pos: int, + hyperparameters: dict | None = None, + **kwargs, + ) -> Batch: + # Check that hyperparameters are passed correctly + assert hyperparameters is not None + assert hyperparameters["test_param"] == "test_value" + + return create_simple_traditional_get_batch( + batch_size, + seq_len, + num_features, + single_eval_pos, + hyperparameters, + **kwargs, + ) + + test_hyperparams = {"test_param": "test_value"} + + x_only_batch = get_batch( + batch_size=2, + seq_len=6, + num_features=2, + single_eval_pos=3, + get_batch=test_get_batch_with_hyperparams, + hyperparameters=test_hyperparams, + ) + + # If we get here without assertion error, hyperparameters were passed correctly + assert x_only_batch is not None + + def test_kwargs_passed_through(self) -> None: + """Test that additional kwargs are correctly passed to the wrapped function.""" + + def test_get_batch_with_kwargs( + batch_size: int, + seq_len: int, + num_features: int, + single_eval_pos: int, + hyperparameters: dict | None = None, + extra_param: str = "default", + **kwargs, + ) -> Batch: + # Check that kwargs are passed correctly + assert extra_param == "test_extra" + + return create_simple_traditional_get_batch( + batch_size, + seq_len, + num_features, + single_eval_pos, + hyperparameters, + **kwargs, + ) + + x_only_batch = get_batch( + batch_size=2, + seq_len=6, + num_features=2, + single_eval_pos=3, + get_batch=test_get_batch_with_kwargs, + extra_param="test_extra", + ) + + # If we get here without assertion error, kwargs were passed correctly + assert x_only_batch is not None + + def test_empty_hyperparameters_default(self) -> None: + """Test that empty hyperparameters dict is used when None is passed.""" 
+ + def test_get_batch_checks_empty_hyperparams( + batch_size: int, + seq_len: int, + num_features: int, + single_eval_pos: int, + hyperparameters: dict | None = None, + **kwargs, + ) -> Batch: + # When hyperparameters=None is passed to wrapper, it should become {} + assert hyperparameters == {} + + return create_simple_traditional_get_batch( + batch_size, + seq_len, + num_features, + single_eval_pos, + hyperparameters, + **kwargs, + ) + + x_only_batch = get_batch( + batch_size=2, + seq_len=6, + num_features=2, + single_eval_pos=3, + get_batch=test_get_batch_checks_empty_hyperparams, + hyperparameters=None, + ) + + assert x_only_batch is not None + + def test_single_eval_pos_boundary_conditions(self) -> None: + """Test behavior at boundary conditions for single_eval_pos.""" + batch_size = 2 + seq_len = 6 + num_features = 2 + + # Test with single_eval_pos = 1 (minimal training data) + x_only_batch = get_batch( + batch_size=batch_size, + seq_len=seq_len, + num_features=num_features, + single_eval_pos=1, + get_batch=create_simple_traditional_get_batch, + ) + + assert x_only_batch.x.shape == (batch_size, 1, num_features + 1) + assert x_only_batch.test_x.shape == (batch_size, seq_len - 1, num_features + 1) + assert x_only_batch.target.shape == (batch_size, seq_len - 1, num_features + 1) + + # Test with single_eval_pos = seq_len - 1 (minimal test data) + x_only_batch = get_batch( + batch_size=batch_size, + seq_len=seq_len, + num_features=num_features, + single_eval_pos=seq_len - 1, + get_batch=create_simple_traditional_get_batch, + ) + + assert x_only_batch.x.shape == (batch_size, seq_len - 1, num_features + 1) + assert x_only_batch.test_x.shape == (batch_size, 1, num_features + 1) + assert x_only_batch.target.shape == (batch_size, 1, num_features + 1) + + def test_device_consistency(self) -> None: + """Test that tensors maintain device consistency.""" + + def get_batch_with_specific_device( + batch_size: int, + seq_len: int, + num_features: int, + single_eval_pos: int, 
+ hyperparameters: dict | None = None, + **kwargs, + ) -> Batch: + device = torch.device("cpu") # Force CPU for testing + x = torch.rand(batch_size, seq_len, num_features, device=device) + y = torch.rand(batch_size, seq_len, 1, device=device) + target_y = y.clone() + + return Batch(x=x, y=y, target_y=target_y, single_eval_pos=single_eval_pos) + + x_only_batch = get_batch( + batch_size=2, + seq_len=6, + num_features=2, + single_eval_pos=3, + get_batch=get_batch_with_specific_device, + ) + + # All tensors should be on the same device + assert x_only_batch.x.device == x_only_batch.test_x.device + assert x_only_batch.x.device == x_only_batch.target.device + + +if __name__ == "__main__": + # Run a simple test if executed directly + test_instance = TestConvertPriorToXOnlyFormat() + test_instance.test_basic_conversion_functionality() + print("Basic test passed!")