Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 200 additions & 33 deletions src/pref_bo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import argparse
import time
from dataclasses import dataclass, field
from typing import Optional, Sequence

import botorch
import gurobipy as gp
Expand Down Expand Up @@ -36,6 +39,41 @@
torch.set_default_dtype(torch.float64)


@dataclass
class Config:
    """Runtime settings for the preference-based BO experiment.

    All tunables live here so the CLI, the BO loop, and the problem
    instance read from a single source of truth.
    """

    n_items: int = 50              # knapsack items in the MOKP instance
    n_objs: int = 3                # number of objectives
    n_initial_samples: int = 10    # warm-up points before the BO loop
    n_iterations: int = 20         # BO loop iterations
    mc_samples: int = 128          # MC samples for the qEHVI sampler
    batch_size_q: int = 2          # q for batched acquisition optimization
    num_restarts: int = 10         # restarts for optimize_acqf
    raw_samples: int = 512         # raw samples for acqf initialization
    acqf_maxiter: int = 200        # max iterations when optimizing the acqf
    should_maximize: bool = True   # maximize (vs. minimize) the objectives
    sequential: bool = True        # sequential greedy acqf optimization
    seed: int = 123                # RNG seed for the problem instance
    density: float = 0.5           # knapsack capacity density
    rho: float = 1e-4              # augmented Tchebycheff parameter
    # Optional user-supplied reference point; must have length n_objs.
    ref_point_values: Optional[Sequence[float]] = None
    # Derived tensor form of the reference point, built in __post_init__.
    ref_point: torch.Tensor = field(init=False)

    def __post_init__(self):
        # Materialize the hypervolume reference point as a tensor: either
        # the user-supplied values (validated against n_objs) or the origin.
        if self.ref_point_values is None:
            self.ref_point = torch.zeros(
                self.n_objs, dtype=torch.get_default_dtype()
            )
            return
        if len(self.ref_point_values) != self.n_objs:
            raise ValueError(
                "Length of ref_point_values must match the number of objectives"
            )
        self.ref_point = torch.tensor(
            self.ref_point_values, dtype=torch.get_default_dtype()
        )


class MOKPInstance:
"""
A class to define and solve the Multi-Objective Knapsack Problem (MOKP).
Expand Down Expand Up @@ -245,90 +283,219 @@ def initialize_model(train_x, train_obj):
return mll, model


def optimize_qehvi_and_get_observation(model, train_x, sampler, config: Config):
    """Optimize the qLogEHVI acquisition function and return new candidates.

    Args:
        model: Fitted GP surrogate exposing ``posterior``.
        train_x: Training inputs (weight vectors on the simplex) used to
            build the non-dominated partitioning.
        sampler: MC sampler used by the acquisition function.
        config: Run configuration (reference point, q, restarts, ...).

    Returns:
        Tensor of ``config.batch_size_q`` candidate weight vectors.
    """
    # Partition the non-dominated space into disjoint rectangles using the
    # model's posterior mean at the observed inputs.
    with torch.no_grad():
        pred = model.posterior(train_x).mean
    ref_point = config.ref_point.to(pred.device)
    partitioning = FastNondominatedPartitioning(
        ref_point=ref_point,
        Y=pred,
    )
    acq_func = qLogExpectedHypervolumeImprovement(
        model=model,
        ref_point=ref_point,
        partitioning=partitioning,
        sampler=sampler,
    )
    # Search space is the unit box in n_objs dimensions...
    bounds = torch.stack(
        [
            torch.zeros(config.n_objs, dtype=train_x.dtype, device=train_x.device),
            torch.ones(config.n_objs, dtype=train_x.dtype, device=train_x.device),
        ]
    )
    # ...restricted to the simplex: (indices, coefficients, rhs) encodes
    # sum(lambda_i) == 1 so candidates are valid scalarization weights.
    equality_constraints = [
        (
            torch.arange(config.n_objs, device=train_x.device),
            torch.ones(config.n_objs, dtype=train_x.dtype, device=train_x.device),
            1.0,
        )
    ]
    # Optimize the acquisition function over the constrained box.
    candidates, _ = optimize_acqf(
        acq_function=acq_func,
        bounds=bounds,
        q=config.batch_size_q,
        num_restarts=config.num_restarts,
        raw_samples=config.raw_samples,  # used for initialization heuristic
        options={"maxiter": config.acqf_maxiter},
        sequential=config.sequential,
        equality_constraints=equality_constraints,
    )

    return candidates


def main(config: Optional[Config] = None):
    """Run the preference-based BO loop on a random MOKP instance.

    Args:
        config: Run configuration; a default :class:`Config` is used when
            omitted, so ``main()`` remains callable with no arguments.
    """
    if config is None:
        config = Config()
    mokp_problem = MOKPInstance(
        n_items=config.n_items,
        n_objs=config.n_objs,
        density=config.density,
        seed=config.seed,
        rho=config.rho,
    )
    # Warm-up design: random weight vectors and their black-box evaluations.
    train_lambda, train_obj = generate_random_samples(
        mokp_problem, n=config.n_initial_samples, maximize=config.should_maximize
    )

    print(f"Starting BO loop for {config.n_iterations} iterations...")
    start_time = time.time()
    for i in range(config.n_iterations):
        # --- a. Fit the GP Surrogate Models ---
        mll, model = initialize_model(train_lambda, train_obj)
        fit_gpytorch_mll(mll)

        # Sampler for qEHVI
        sampler = SobolQMCNormalSampler(sample_shape=torch.Size([config.mc_samples]))
        new_lambda = optimize_qehvi_and_get_observation(
            model, train_lambda, sampler, config
        )

        # --- d. Evaluate the Black Box ---
        new_obj = mokp_problem(new_lambda, maximize=config.should_maximize)

        # --- e. Update the Dataset ---
        train_lambda = torch.cat([train_lambda, new_lambda])
        train_obj = torch.cat([train_obj, new_obj])

        # Progress: non-dominated count and hypervolume normalized by the
        # (absolute) product of the ideal point's coordinates.
        pareto_mask = is_non_dominated(train_obj)
        bd = FastNondominatedPartitioning(
            ref_point=config.ref_point.to(train_obj.device), Y=train_obj
        )
        # Plain float (not a 0-dim tensor) so the f-string format is robust.
        volume = bd.compute_hypervolume().item() / abs(
            float(mokp_problem.ideal_point.prod())
        )
        print(
            f"Iter {i+1}/{config.n_iterations} | ND: {pareto_mask.sum()} | Hypervolume: {volume:.4f}"
        )

    end_time = time.time()
    print(f"\nBO loop finished in {end_time - start_time:.2f} seconds.")


# Problem & BO Parameters
# NOTE(review): legacy module-level constants. The Config dataclass defined
# earlier in this file carries the same values as defaults; these globals
# appear superseded by it — confirm nothing still references them.
N_ITEMS = 50
N_OBJS = 3
N_INITIAL_SAMPLES = 10  # Warm-up points
N_ITERATIONS = 20  # BO loop iterations
MC_SAMPLES = 128  # Samples for qEHVI
BATCH_SIZE_Q = 2  # q=1 for sequential optimization
NUM_RESTARTS = 10
RAW_SAMPLES = 512
REF_POINT = torch.zeros(N_OBJS)  # Reference point for hypervolume
SHOULD_MAXIMIZE = True  # Whether to maximize the objectives
SEQUENTIAL = True
def _add_bool_argument(
parser: argparse.ArgumentParser, name: str, default: bool, help_text: str
) -> str:
"""Register a pair of --foo/--no-foo flags that toggle a boolean option."""

dest = name.replace("-", "_")
default_state = "enabled" if default else "disabled"
group = parser.add_mutually_exclusive_group()
group.add_argument(
f"--{name}",
dest=dest,
action="store_true",
help=f"{help_text} (default: {default_state})",
)
group.add_argument(
f"--no-{name}",
dest=dest,
action="store_false",
help=f"{help_text} (default: {default_state})",
)
parser.set_defaults(**{dest: default})
return dest


def parse_args(argv: Optional[Sequence[str]] = None) -> Config:
    """Parse command-line arguments into a :class:`Config`.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        A populated Config instance.
    """
    # Single source of truth: read every default from the Config dataclass
    # instead of repeating the literals here, so the two cannot drift apart.
    defaults = Config()
    parser = argparse.ArgumentParser(description="Preference-based BO configuration")
    parser.add_argument(
        "--n-items", type=int, default=defaults.n_items, help="Number of items"
    )
    parser.add_argument(
        "--n-objs",
        type=int,
        default=defaults.n_objs,
        help="Number of objectives in the MOKP",
    )
    parser.add_argument(
        "--n-initial-samples",
        type=int,
        default=defaults.n_initial_samples,
        help="Number of initial samples for warm-up",
    )
    parser.add_argument(
        "--n-iterations",
        type=int,
        default=defaults.n_iterations,
        help="Number of BO iterations",
    )
    parser.add_argument(
        "--mc-samples",
        type=int,
        default=defaults.mc_samples,
        help="Number of Monte Carlo samples for qEHVI",
    )
    parser.add_argument(
        "--batch-size-q",
        type=int,
        default=defaults.batch_size_q,
        help="Batch size q for qEHVI",
    )
    parser.add_argument(
        "--num-restarts",
        type=int,
        default=defaults.num_restarts,
        help="Number of restarts for acquisition optimization",
    )
    parser.add_argument(
        "--raw-samples",
        type=int,
        default=defaults.raw_samples,
        help="Number of raw samples for acquisition initialization",
    )
    parser.add_argument(
        "--acqf-maxiter",
        type=int,
        default=defaults.acqf_maxiter,
        help="Maximum iterations when optimizing the acquisition function",
    )
    should_maximize_dest = _add_bool_argument(
        parser,
        name="should-maximize",
        default=defaults.should_maximize,
        help_text="Whether to maximize the objectives",
    )
    sequential_dest = _add_bool_argument(
        parser,
        name="sequential",
        default=defaults.sequential,
        help_text="Whether to use sequential acquisition optimization",
    )
    parser.add_argument(
        "--seed", type=int, default=defaults.seed, help="Random seed"
    )
    parser.add_argument(
        "--density",
        type=float,
        default=defaults.density,
        help="Knapsack capacity density (percentage of total weight)",
    )
    parser.add_argument(
        "--rho",
        type=float,
        default=defaults.rho,
        help="Augmentation parameter for Tchebycheff scalarization",
    )
    # No default: Config falls back to the origin when this is omitted.
    parser.add_argument(
        "--ref-point",
        type=float,
        nargs="*",
        help="Reference point for hypervolume calculation (length must equal n_objs)",
    )
    args = parser.parse_args(argv)
    return Config(
        n_items=args.n_items,
        n_objs=args.n_objs,
        n_initial_samples=args.n_initial_samples,
        n_iterations=args.n_iterations,
        mc_samples=args.mc_samples,
        batch_size_q=args.batch_size_q,
        num_restarts=args.num_restarts,
        raw_samples=args.raw_samples,
        acqf_maxiter=args.acqf_maxiter,
        should_maximize=getattr(args, should_maximize_dest),
        sequential=getattr(args, sequential_dest),
        seed=args.seed,
        density=args.density,
        rho=args.rho,
        ref_point_values=args.ref_point,
    )


if __name__ == "__main__":
    # Build the configuration from CLI flags, then run the BO loop.
    main(parse_args())

# from botorch.test_functions.multi_objective import BraninCurrin
# problem = BraninCurrin(negate=True)
Expand Down