Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 34 additions & 18 deletions dreamer/extraction/samplers/raycast_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,30 @@ def harvest(
Logger.Levels.debug
).log()

def finalize_rays(raw_rays, target_rays):
lengths = np.linalg.norm(raw_rays, axis=1)
sorted_indices = np.argsort(lengths)
best_rays = raw_rays[sorted_indices][:target_rays]
np.random.shuffle(best_rays)
final_rays = best_rays
return final_rays

max_radius = 100
raw_rays = np.array([])

raddai = []
expansions = []

while len(final_rays) < target_rays:
Logger(f"Sweeping lattice up to R_max = {current_R_max:.2f}...", Logger.Levels.debug).log()
if current_R_max >= max_radius:
Logger(
f"[Pipeline] Could not achieve quota, found {len(final_rays)}/{target_rays}",
Logger.Levels.debug
).log()
final_rays = finalize_rays(raw_rays, target_rays)
break

# Logger(f"Sweeping lattice up to R_max = {current_R_max:.2f}...", Logger.Levels.debug).log()

# Enforce max_per_ray=1 for the "Fair Slice"
raw_rays = sampler.harvest(
Expand All @@ -200,15 +222,8 @@ def harvest(
)

if len(raw_rays) >= target_rays:
Logger(
f"Quota exceeded ({len(raw_rays)}). Engaging Expanding Ball (Length Sort)...",
Logger.Levels.debug
).log()
lengths = np.linalg.norm(raw_rays, axis=1)
sorted_indices = np.argsort(lengths)
best_rays = raw_rays[sorted_indices][:target_rays]
np.random.shuffle(best_rays)
final_rays = best_rays
Logger(f"[Pipeline] Quota exceeded ({len(raw_rays)})!", Logger.Levels.debug).log()
final_rays = finalize_rays(raw_rays, target_rays)
break
else:
if len(raw_rays) == 0:
Expand All @@ -220,15 +235,16 @@ def harvest(

# Cap the multiplier between 1.10 (minimum safety step) and 3.0 (max jump)
multiplier = np.clip(momentum_multiplier, 1.10, 3.0)
Logger(
f"Discretization Gap hit: Yielded {len(raw_rays)} bounded rays. Target: {target_rays}.",
Logger.Levels.debug
).log()
Logger(
f"\n Momentum Expansion: Multiplying R_max by {multiplier:.3f}",
Logger.Levels.debug
).log()
expansions.append(multiplier)
raddai.append(current_R_max)
current_R_max *= multiplier

multipliers = [f'{radius:.2f} by {multiplier:.3f}' for multiplier, radius in zip(expansions, raddai)]
multipliers = ', '.join(multipliers)
if multipliers:
Logger(f"\tMomentum Expansion up to {current_R_max:.2f}: {multipliers}", Logger.Levels.debug).log()
else:
Logger(f"\tSearch radius used: {current_R_max:.2f}", Logger.Levels.debug)

self._verify_uniformity(final_rays, fraction, d_flat)
return final_rays
2 changes: 1 addition & 1 deletion examples/main_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,5 @@ def trajectory_compute_func_analysis(d):
if_srcs=[pFq(log(2), 2, 1, -1)],
extractor=extraction.extractor.ShardExtractorMod,
analyzers=[analysis.AnalyzerModV1],
searcher=search.GeneticSearchMod
searcher=search.SearcherModV1
).run(constants=[log(2)])
Loading