diff --git a/dreamer/extraction/samplers/raycast_sampler.py b/dreamer/extraction/samplers/raycast_sampler.py index d6a0f64..cce5318 100644 --- a/dreamer/extraction/samplers/raycast_sampler.py +++ b/dreamer/extraction/samplers/raycast_sampler.py @@ -189,8 +189,30 @@ def harvest( Logger.Levels.debug ).log() + def finalize_rays(raw_rays, target_rays): + lengths = np.linalg.norm(raw_rays, axis=1) + sorted_indices = np.argsort(lengths) + best_rays = raw_rays[sorted_indices][:target_rays] + np.random.shuffle(best_rays) + final_rays = best_rays + return final_rays + + max_radius = 100 + raw_rays = np.array([]) + + raddai = [] + expansions = [] + while len(final_rays) < target_rays: - Logger(f"Sweeping lattice up to R_max = {current_R_max:.2f}...", Logger.Levels.debug).log() + if current_R_max >= max_radius: + Logger( + f"[Pipeline] Could not achieve quota, found {len(final_rays)}/{target_rays}", + Logger.Levels.debug + ).log() + final_rays = finalize_rays(raw_rays, target_rays) + break + + # Logger(f"Sweeping lattice up to R_max = {current_R_max:.2f}...", Logger.Levels.debug).log() # Enforce max_per_ray=1 for the "Fair Slice" raw_rays = sampler.harvest( @@ -200,15 +222,8 @@ def harvest( ) if len(raw_rays) >= target_rays: - Logger( - f"Quota exceeded ({len(raw_rays)}). Engaging Expanding Ball (Length Sort)...", - Logger.Levels.debug - ).log() - lengths = np.linalg.norm(raw_rays, axis=1) - sorted_indices = np.argsort(lengths) - best_rays = raw_rays[sorted_indices][:target_rays] - np.random.shuffle(best_rays) - final_rays = best_rays + Logger(f"[Pipeline] Quota exceeded ({len(raw_rays)})!", Logger.Levels.debug).log() + final_rays = finalize_rays(raw_rays, target_rays) break else: if len(raw_rays) == 0: @@ -220,15 +235,16 @@ def harvest( # Cap the multiplier between 1.10 (minimum safety step) and 3.0 (max jump) multiplier = np.clip(momentum_multiplier, 1.10, 3.0) - Logger( - f"Discretization Gap hit: Yielded {len(raw_rays)} bounded rays. Target: {target_rays}.", - Logger.Levels.debug - ).log() - Logger( - f"\n Momentum Expansion: Multiplying R_max by {multiplier:.3f}", - Logger.Levels.debug - ).log() + expansions.append(multiplier) + raddai.append(current_R_max) current_R_max *= multiplier + multipliers = [f'{radius:.2f} by {multiplier:.3f}' for multiplier, radius in zip(expansions, raddai)] + multipliers = ', '.join(multipliers) + if multipliers: + Logger(f"\tMomentum Expansion up to {current_R_max:.2f}: {multipliers}", Logger.Levels.debug).log() + else: + Logger(f"\tSearch radius used: {current_R_max:.2f}", Logger.Levels.debug) + self._verify_uniformity(final_rays, fraction, d_flat) return final_rays diff --git a/examples/main_example.py b/examples/main_example.py index dea5bea..769a511 100644 --- a/examples/main_example.py +++ b/examples/main_example.py @@ -46,5 +46,5 @@ def trajectory_compute_func_analysis(d): if_srcs=[pFq(log(2), 2, 1, -1)], extractor=extraction.extractor.ShardExtractorMod, analyzers=[analysis.AnalyzerModV1], - searcher=search.GeneticSearchMod + searcher=search.SearcherModV1 ).run(constants=[log(2)])