Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dreamer/configs/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class SearchConfig(Configurable):
GA_MAX_RETRIES: int = 3 # Retry rounds for invalid/failed trajectory evaluations.
GA_REFINE_PROB: float = 0.5 # Probability of entering refine mutation mode.
GA_REFINE_COORD_PROB: float = 0.5 # Per-coordinate refine perturbation probability.
GA_MAX_NO_IMPROVEMENT_COUNT_RETRY: int = 5 # Max retries when no improvement is observed before giving up.

MAX_TRAJECTORY_COORD: int = 50 # Max coordinate value for a trajectory.


search_config: SearchConfig = SearchConfig()
8 changes: 4 additions & 4 deletions dreamer/extraction/samplers/raycast_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from dreamer.extraction.samplers.conditioner import HyperSpaceConditioner
from dreamer.extraction.samplers.raycaster import RayCastingSamplingMethod
from dreamer.utils.logger import Logger
from dreamer.configs.search import search_config
from .sampler import Sampler
from typing import Callable, cast
import math



class RaycastPipelineSampler(Sampler):
Expand Down Expand Up @@ -197,7 +200,7 @@ def finalize_rays(raw_rays, target_rays):
final_rays = best_rays
return final_rays

max_radius = 100
max_radius = math.sqrt(pow(search_config.MAX_TRAJECTORY_COORD, 2) * d_flat) + 1
raw_rays = np.array([])

raddai = []
Expand All @@ -212,9 +215,6 @@ def finalize_rays(raw_rays, target_rays):
).log()
break

# Logger(f"Sweeping lattice up to R_max = {current_R_max:.2f}...", Logger.Levels.debug).log()

# Enforce max_per_ray=1 for the "Fair Slice"
raw_rays = sampler.harvest(
target_rays=guide_rays_to_shoot,
R_max=current_R_max,
Expand Down
30 changes: 27 additions & 3 deletions dreamer/search/methods/genetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,10 @@ def _sample_valid_trajectories(self, *, count: int, template_pos: Position) -> L
break

if len(sampled) < count:
raise ValueError("Genetic search could not sample enough valid trajectories satisfying A v <= 0")
Logger(
f"Genetic search could not sample enough valid trajectories. Sampled {len(sampled)}/{count}",
Logger.Levels.warning
).log()
return sampled[:count]

def _get_valid_repair_trajectory(self, template_pos: Position) -> Position:
Expand All @@ -293,7 +296,8 @@ def _get_valid_repair_trajectory(self, template_pos: Position) -> Position:
count=self._buffer_chunk_size,
template_pos=template_pos
)
random.shuffle(self._valid_trajectory_buffer)
if self._valid_trajectory_buffer:
random.shuffle(self._valid_trajectory_buffer)
return self._valid_trajectory_buffer.pop()

def _repair_trajectory(self, trajectory: Position, template_pos: Position) -> Position:
Expand Down Expand Up @@ -406,11 +410,28 @@ def _evaluate_population(
population[i]["sd"] = sd

unresolved = invalid_indices
unchanged_count = 0
last_found_amount = -1

for _ in range(self.max_retries):
if not unresolved:
break

if unchanged_count >= search_config.GA_MAX_NO_IMPROVEMENT_COUNT_RETRY:
Logger(
"Genetic algorithm solving unresolved trajectories - giving up resampling.", Logger.Levels.debug
).log()
break

retry_trajectories = self._sample_valid_trajectories(count=len(unresolved), template_pos=template_pos)
if len(retry_trajectories) == 0:
Logger("No valid trajectories could be sampled")

if last_found_amount == -1:
last_found_amount = len(retry_trajectories)
elif last_found_amount == len(retry_trajectories):
unchanged_count += 1

retry_pairs = [(traj, start) for traj in retry_trajectories]
self._compute_missing_search_data(retry_pairs)

Expand All @@ -431,7 +452,6 @@ def _evaluate_population(
next_unresolved.append(i)

unresolved = next_unresolved

return population

def search(
Expand All @@ -450,6 +470,10 @@ def search(
template = self._resolve_template(template_trajectory)

initial_trajectories = self._sample_valid_trajectories(count=self.pop_size, template_pos=template)
if len(initial_trajectories) == 0:
Logger("No valid trajectories could be sampled. Continue...", Logger.Levels.warning).log()
return DataManager(self.use_LIReC)

population: List[Dict[str, Any]] = [
{"trajectory": traj, "delta": None, "sd": None} for traj in initial_trajectories
]
Expand Down
18 changes: 0 additions & 18 deletions tests/test_search_genetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,24 +375,6 @@ def _always_invalid_mutate(_pos, **_kwargs):
assert all(space.is_valid_trajectory(sd.sv.trajectory) for sd in result.values())


def test_genetic_search_raises_when_constraints_have_no_valid_trajectories(monkeypatch):
"""Ensure search fails loudly when constraints admit no valid trajectories.
Assumption: sampler yields only invalid trajectories for constrained space.
Failure mode caught: infinite retries or misleading exception messages.
"""
space = ImpossibleConstrainedSpace()
x, y = space.cmf.symbols
_patch_static_sampler(monkeypatch, [Position({x: 1, y: 1}), Position({x: 2, y: 2})])
_configure_ga(monkeypatch, generations=1, pop_size=2, max_retries=1, parallel_search=False)
method = GeneticSearchMethod(
cast(Searchable, cast(object, space)),
constant=None,
)

with pytest.raises(ValueError, match="could not sample enough valid trajectories"):
method.search(template_trajectory=Position({x: 1, y: 1}))


def test_evaluate_population_resamples_invalid_trajectories_in_batch(monkeypatch):
"""Verify invalid individuals are resampled in one batch per retry round.
Assumption: initial cache lookup returns no SearchData for all individuals.
Expand Down
Loading