diff --git a/README.md b/README.md index 7a57848..1fb9d53 100644 --- a/README.md +++ b/README.md @@ -175,6 +175,21 @@ Currently the following ways of fetching simulation results are supported. - Record and collect time series data of values sampled from components. Call `builder.record_time_series::()` on the builder to set up the recording, and then `simulation.get_time_series::()` to collect the results. +## Examples + +The crate includes several comprehensive examples demonstrating different simulation patterns: + +### Basic Simulations +- **[Counter](examples/counter.rs)** - Simple entity counting and state management +- **[Pandemic](examples/pandemic.rs)** - Basic disease spread simulation with random movement +- **[Step Counter](examples/step_counter.rs)** - Demonstrates the built-in step counter plugin +- **[Traders](examples/traders.rs)** - Time series collection with multiple trader agents + +### Spatial Simulations +- **[Forest Fire](examples/forest_fire.rs)** - Cellular automaton fire spread with spatial grid optimization +- **[Pandemic Spatial](examples/pandemic_spatial.rs)** - Advanced epidemic modeling with infection radius, social distancing, contact tracing, and quarantine zones + + ## Performance When it comes to experiments like Monte Carlo, performance is typically of paramount importance since it defines their limits in terms of scope, size, length and granularity. Hence why I made the decision build this crate on top of bevy. The ECS architecture on offer here is likely the most memory-efficient and parallelizable way one can build such simulations, while still maintaining some agency of high-level programming. diff --git a/examples/README.md b/examples/README.md index 59d5684..f2ecf7a 100644 --- a/examples/README.md +++ b/examples/README.md @@ -7,7 +7,14 @@ - Shows how entities can interact with each other. - Samples the simulation by counting entities. - **[traders.rs](traders.rs)** - - The most complicated example. - Two different kinds of entities interacting with each other. - Data is sampled as a time series and plotted. +- **[forest_fire.rs](forest_fire.rs)** + - Spatial cellular automaton simulation. + - Demonstrates grid-based entities and neighborhood interactions. + - Shows probabilistic fire spread with time series analysis. +- **[pandemic_spatial.rs](pandemic_spatial.rs)** + - Advanced epidemic simulation with spatial features. + - Infection radius, social distancing, contact tracing, and quarantine zones. + - Demonstrates realistic epidemic modeling with spatial grid optimization. diff --git a/examples/air_traffic_3d.rs b/examples/air_traffic_3d.rs new file mode 100644 index 0000000..0a6f8cf --- /dev/null +++ b/examples/air_traffic_3d.rs @@ -0,0 +1,244 @@ +//! # Simple 3D Air Traffic Control Simulation +//! +//! This example demonstrates 3D spatial grid functionality with aircraft moving +//! through different altitude levels in a simple airspace. + +#![allow(clippy::cast_possible_wrap)] +#![allow(clippy::cast_sign_loss)] +#![allow(clippy::expect_used)] + +use bevy::prelude::*; +use incerto::prelude::*; +use rand::prelude::*; + +// Simulation parameters +const SIMULATION_STEPS: usize = 50; +const AIRSPACE_SIZE: i32 = 20; // 20x20x10 airspace +const AIRSPACE_HEIGHT: i32 = 10; +const NUM_AIRCRAFT: usize = 15; + +/// Aircraft component with basic properties +#[derive(Component, Debug)] +pub struct Aircraft +{ + pub id: usize, + pub aircraft_type: AircraftType, + pub target_altitude: i32, + pub speed: i32, // cells per step +} + +#[derive(Debug, Clone, Copy)] +pub enum AircraftType +{ + Commercial, + PrivateJet, + Cargo, +} + +impl AircraftType +{ + const fn symbol(self) -> &'static str + { + match self + { + Self::Commercial => "āœˆļø", + Self::PrivateJet => "šŸ›©ļø", + Self::Cargo => "šŸ›«", + } + } +} + +/// Sample trait for counting aircraft +impl Sample for Aircraft +{ + fn sample(components: &[&Self]) -> usize + { + components.len() + } +} + +/// Spawn aircraft at random positions and altitudes +fn spawn_aircraft(spawner: &mut Spawner) +{ + let mut rng = rand::rng(); + + for i in 0..NUM_AIRCRAFT + { + let x = rng.random_range(0..AIRSPACE_SIZE); + let y = rng.random_range(0..AIRSPACE_SIZE); + let z = rng.random_range(0..AIRSPACE_HEIGHT); + + let aircraft_type = match rng.random_range(0..3) + { + 0 => AircraftType::Commercial, + 1 => AircraftType::PrivateJet, + _ => AircraftType::Cargo, + }; + + let target_altitude = rng.random_range(0..AIRSPACE_HEIGHT); + + let aircraft = Aircraft { + id: i, + aircraft_type, + target_altitude, + speed: 1, + }; + + let position = GridPosition3D::new(x, y, z); + spawner.spawn((aircraft, position)); + } +} + +/// Move aircraft towards their target altitudes and in random horizontal directions +fn move_aircraft(mut query: Query<(&mut GridPosition3D, &Aircraft)>) +{ + let mut rng = rand::rng(); + + for (mut position, aircraft) in &mut query + { + // Move towards target altitude + if position.z() < aircraft.target_altitude + { + *position = GridPosition3D::new(position.x(), position.y(), position.z() + 1); + } + else if position.z() > aircraft.target_altitude + { + *position = GridPosition3D::new(position.x(), position.y(), position.z() - 1); + } + + // Random horizontal movement + if rng.random_bool(0.7) + { + let dx = rng.random_range(-1..=1); + let dy = rng.random_range(-1..=1); + + let new_x = (position.x() + dx).clamp(0, AIRSPACE_SIZE - 1); + let new_y = (position.y() + dy).clamp(0, AIRSPACE_SIZE - 1); + + *position = GridPosition3D::new(new_x, new_y, position.z()); + } + } +} + +/// Check for aircraft conflicts (too close in 3D space) +fn check_conflicts( + spatial_grid: Res>, + query: Query<(Entity, &GridPosition3D, &Aircraft)>, +) +{ + let mut conflicts = 0; + + for (entity, position, aircraft) in &query + { + // Check for nearby aircraft (within 1 cell in any direction) + let nearby_aircraft = spatial_grid.neighbors_of(position); + + for nearby_entity in nearby_aircraft + { + if nearby_entity != entity + && let Ok((_, nearby_pos, nearby_aircraft)) = query.get(nearby_entity) + { + let distance = (position.0 - nearby_pos.0).length_squared(); + if distance <= 1 + { + conflicts += 1; + println!( + "āš ļø CONFLICT: Aircraft {} {} and {} {} too close at distance {}", + aircraft.id, + aircraft.aircraft_type.symbol(), + nearby_aircraft.id, + nearby_aircraft.aircraft_type.symbol(), + distance + ); + } + } + } + } + + if conflicts > 0 + { + println!(" Total conflicts detected: {}", conflicts / 2); // Divide by 2 since each conflict is counted twice + } +} + +/// Display airspace status +fn display_airspace(query: Query<(&GridPosition3D, &Aircraft)>) +{ + println!("\nšŸ“” Airspace Status:"); + + // Count aircraft by altitude + let mut altitude_counts = vec![0; AIRSPACE_HEIGHT as usize]; + let mut aircraft_positions = Vec::new(); + + for (position, aircraft) in &query + { + altitude_counts[position.z() as usize] += 1; + aircraft_positions.push((position, aircraft)); + } + + for altitude in 0..AIRSPACE_HEIGHT + { + let count = altitude_counts[altitude as usize]; + if count > 0 + { + print!(" FL{altitude:02}: {count} aircraft "); + + // Show aircraft at this altitude + for (pos, aircraft) in &aircraft_positions + { + if pos.z() == altitude + { + print!("{} ", aircraft.aircraft_type.symbol()); + } + } + println!(); + } + } + + println!(" Airspace: {AIRSPACE_SIZE}x{AIRSPACE_SIZE}x{AIRSPACE_HEIGHT} cells"); +} + +fn main() +{ + println!("āœˆļø 3D Air Traffic Control Simulation"); + println!("Airspace: {AIRSPACE_SIZE}x{AIRSPACE_SIZE}x{AIRSPACE_HEIGHT} cells"); + println!("Aircraft: {NUM_AIRCRAFT}"); + println!("Duration: {SIMULATION_STEPS} steps\n"); + + // Create 3D airspace bounds + let bounds = GridBounds3D { + min: IVec3::ZERO, + max: IVec3::new(AIRSPACE_SIZE, AIRSPACE_SIZE, AIRSPACE_HEIGHT), + }; + + let mut simulation = SimulationBuilder::new() + .add_spatial_grid::(bounds) + .add_entity_spawner(spawn_aircraft) + .add_systems((move_aircraft, check_conflicts, display_airspace)) + .build(); + + // Run simulation + for step in 1..=SIMULATION_STEPS + { + println!("šŸ• Step {step}/{SIMULATION_STEPS}"); + simulation.run(1); + + if step.is_multiple_of(10) + { + let aircraft_count = simulation + .sample::() + .expect("Failed to sample aircraft count"); + println!(" šŸ“Š Total aircraft tracked: {aircraft_count}"); + } + + // Add small delay for readability + std::thread::sleep(std::time::Duration::from_millis(200)); + } + + println!("\nāœ… Air Traffic Control simulation completed!"); + + let final_count = simulation + .sample::() + .expect("Failed to sample final aircraft count"); + println!("šŸ“ˆ Final aircraft count: {final_count}"); +} diff --git a/examples/forest_fire.rs b/examples/forest_fire.rs new file mode 100644 index 0000000..2a2d0c3 --- /dev/null +++ b/examples/forest_fire.rs @@ -0,0 +1,362 @@ +//! # Monte Carlo simulation of forest fire spread. +//! +//! This example showcases a spatial cellular automaton simulation where fire spreads +//! through a forest based on probabilistic rules. The simulation demonstrates: +//! +//! * Spatial grid-based entities using the `SpatialGridPlugin` +//! * Entity state transitions (Healthy → Burning → Burned → Empty) +//! * Neighborhood interactions for fire spreading +//! * Time series collection of fire statistics +//! * Configurable fire spread parameters for Monte Carlo analysis +//! +//! Each cell in the forest can be in one of four states: +//! * **Healthy**: Can catch fire from burning neighbors +//! * **Burning**: Spreads fire to healthy neighbors, burns for a duration +//! * **Burned**: No longer spreads fire, can't burn again +//! * **Empty**: Vacant land that can regrow over time +//! +//! The simulation allows for studying fire spread patterns, firebreak effectiveness, +//! and forest management strategies under different conditions. + +#![allow(clippy::unwrap_used)] +#![allow(clippy::expect_used)] +#![allow(clippy::cast_precision_loss)] + +use std::collections::HashSet; + +use bevy::prelude::IVec2; +use incerto::prelude::*; +use rand::prelude::*; + +// Simulation parameters +const SIMULATION_STEPS: usize = 500; +const GRID_WIDTH: i32 = 50; +const GRID_HEIGHT: i32 = 50; + +// Fire parameters +const INITIAL_FOREST_DENSITY: f64 = 0.7; // Probability a cell starts as forest +const FIRE_SPREAD_PROBABILITY: f64 = 0.6; // Probability fire spreads to neighbor +const BURN_DURATION: usize = 3; // Steps a cell burns before becoming burned +const REGROWTH_PROBABILITY: f64 = 0.001; // Probability empty cell becomes forest +const INITIAL_FIRE_COUNT: usize = 3; // Number of initial fire sources + +// Time series sampling +const SAMPLE_INTERVAL: usize = 1; + +/// Represents the state of a forest cell. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum CellState +{ + /// Healthy forest that can catch fire + Healthy, + /// Currently burning, will spread fire + Burning + { + remaining_burn_time: usize + }, + /// Already burned, cannot burn again + Burned, + /// Empty land that can regrow + #[default] + Empty, +} + +/// Component representing a single cell in the forest grid. +#[derive(Component, Debug, Default)] +pub struct ForestCell +{ + pub state: CellState, +} + +/// Fire statistics collected during simulation. +#[derive(Debug, Clone, Copy)] +pub struct FireStats +{ + pub healthy_count: usize, + pub burning_count: usize, + pub burned_count: usize, + pub empty_count: usize, + pub total_cells: usize, +} + +impl FireStats +{ + #[must_use] + pub fn fire_activity(&self) -> f64 + { + self.burning_count as f64 / self.total_cells as f64 + } + + #[must_use] + pub fn forest_coverage(&self) -> f64 + { + (self.healthy_count + self.burning_count) as f64 / self.total_cells as f64 + } + + #[must_use] + pub fn burned_percentage(&self) -> f64 + { + self.burned_count as f64 / self.total_cells as f64 + } +} + +/// Implement sampling to collect fire statistics. +impl Sample for ForestCell +{ + fn sample(components: &[&Self]) -> FireStats + { + assert!(!components.is_empty()); + + let mut healthy_count = 0; + let mut burning_count = 0; + let mut burned_count = 0; + let mut empty_count = 0; + + for cell in components + { + match cell.state + { + CellState::Healthy => healthy_count += 1, + CellState::Burning { .. } => burning_count += 1, + CellState::Burned => burned_count += 1, + CellState::Empty => empty_count += 1, + } + } + + FireStats { + healthy_count, + burning_count, + burned_count, + empty_count, + total_cells: components.len(), + } + } +} + +fn main() +{ + println!("šŸ”„ Starting Forest Fire Simulation"); + println!("Grid size: {GRID_WIDTH}x{GRID_HEIGHT}"); + println!( + "Initial forest density: {:.1}%", + INITIAL_FOREST_DENSITY * 100.0 + ); + println!( + "Fire spread probability: {:.1}%", + FIRE_SPREAD_PROBABILITY * 100.0 + ); + println!("Simulation steps: {SIMULATION_STEPS}"); + println!(); + + // Build the simulation + let bounds = GridBounds2D { + min: IVec2::new(0, 0), + max: IVec2::new(GRID_WIDTH - 1, GRID_HEIGHT - 1), + }; + let mut simulation = SimulationBuilder::new() + // Add spatial grid support + .add_spatial_grid::(bounds) + // Spawn the forest grid + .add_entity_spawner(spawn_forest_grid) + // Add fire spread system + .add_systems(fire_spread_system) + // Add burn progression system + .add_systems(burn_progression_system) + // Add regrowth system + .add_systems(regrowth_system) + // Record time series of fire statistics + .record_time_series::(SAMPLE_INTERVAL) + .expect("Failed to set up time series recording") + .build(); + + // Run the simulation + println!("Running simulation..."); + simulation.run(SIMULATION_STEPS); + + // Collect and display results + let final_stats = simulation + .sample::() + .expect("Failed to sample fire statistics"); + + println!("šŸ“Š Final Statistics:"); + println!( + " Healthy forest: {} cells ({:.1}%)", + final_stats.healthy_count, + final_stats.healthy_count as f64 / final_stats.total_cells as f64 * 100.0 + ); + println!( + " Currently burning: {} cells ({:.1}%)", + final_stats.burning_count, + final_stats.fire_activity() * 100.0 + ); + println!( + " Burned areas: {} cells ({:.1}%)", + final_stats.burned_count, + final_stats.burned_percentage() * 100.0 + ); + println!( + " Empty land: {} cells ({:.1}%)", + final_stats.empty_count, + final_stats.empty_count as f64 / final_stats.total_cells as f64 * 100.0 + ); + + // Display time series summary + let time_series = simulation + .get_time_series::() + .expect("Failed to get time series data"); + + println!("\nšŸ“ˆ Time Series Summary:"); + println!(" Data points collected: {}", time_series.len()); + + if let Some(peak_fire) = time_series + .iter() + .max_by(|a, b| a.burning_count.cmp(&b.burning_count)) + { + println!( + " Peak fire activity: {} burning cells ({:.1}%)", + peak_fire.burning_count, + peak_fire.fire_activity() * 100.0 + ); + } + + let total_burned = time_series.last().map_or(0, |stats| stats.burned_count); + println!( + " Total area burned: {} cells ({:.1}%)", + total_burned, + total_burned as f64 / final_stats.total_cells as f64 * 100.0 + ); +} + +/// Spawn the initial forest grid with random forest coverage. +fn spawn_forest_grid(spawner: &mut Spawner) +{ + let mut rng = rand::rng(); + + // Spawn all grid cells + for x in 0..GRID_WIDTH + { + for y in 0..GRID_HEIGHT + { + let position = GridPosition2D::new(x, y); + + // Determine initial state + let state = if rng.random_bool(INITIAL_FOREST_DENSITY) + { + CellState::Healthy + } + else + { + CellState::Empty + }; + + let cell = ForestCell { state }; + spawner.spawn((position, cell)); + } + } + + // Start some initial fires at random locations + let healthy_positions: Vec = (0..GRID_WIDTH) + .flat_map(|x| (0..GRID_HEIGHT).map(move |y| GridPosition2D::new(x, y))) + .collect(); + + // This is a simplified approach - in a real implementation you'd query existing entities + // For this example, we'll start fires by spawning burning cells at random positions + for _ in 0..INITIAL_FIRE_COUNT + { + if let Some(&pos) = healthy_positions.choose(&mut rng) + { + let burning_cell = ForestCell { + state: CellState::Burning { + remaining_burn_time: BURN_DURATION, + }, + }; + spawner.spawn((pos, burning_cell)); + } + } +} + +/// System that handles fire spreading to neighboring cells using the spatial grid. +fn fire_spread_system( + spatial_grid: Res>, + query_burning: Query<(Entity, &GridPosition2D), With>, + mut query_cells: Query<(&GridPosition2D, &mut ForestCell)>, +) +{ + let mut rng = rand::rng(); + let mut spread_positions = HashSet::new(); + + // Find all burning cells + for (burning_entity, burning_pos) in &query_burning + { + // Check if this cell is actually burning + if let Ok((_, cell)) = query_cells.get(burning_entity) + && matches!(cell.state, CellState::Burning { .. }) + { + // Get orthogonal neighbors using the spatial grid + let neighbors = spatial_grid.orthogonal_neighbors_of(burning_pos); + + for neighbor_entity in neighbors + { + if let Ok((neighbor_pos, neighbor_cell)) = query_cells.get(neighbor_entity) + { + // Check if neighbor is healthy and can catch fire + if matches!(neighbor_cell.state, CellState::Healthy) + { + // Fire spreads with probability + if rng.random_bool(FIRE_SPREAD_PROBABILITY) + { + spread_positions.insert(*neighbor_pos); + } + } + } + } + } + } + + // Apply fire spread + for (position, mut cell) in &mut query_cells + { + if spread_positions.contains(position) + { + cell.state = CellState::Burning { + remaining_burn_time: BURN_DURATION, + }; + } + } +} + +/// System that progresses burning cells through their burn cycle. +fn burn_progression_system(mut query: Query<&mut ForestCell>) +{ + for mut cell in &mut query + { + if let CellState::Burning { + remaining_burn_time, + } = &mut cell.state + { + if *remaining_burn_time > 1 + { + *remaining_burn_time -= 1; + } + else + { + // Fire burns out, cell becomes burned + cell.state = CellState::Burned; + } + } + } +} + +/// System that handles forest regrowth on empty land. +fn regrowth_system(mut query: Query<&mut ForestCell>) +{ + let mut rng = rand::rng(); + + for mut cell in &mut query + { + if matches!(cell.state, CellState::Empty) && rng.random_bool(REGROWTH_PROBABILITY) + { + cell.state = CellState::Healthy; + } + } +} diff --git a/examples/pandemic_spatial.rs b/examples/pandemic_spatial.rs new file mode 100644 index 0000000..f8f76da --- /dev/null +++ b/examples/pandemic_spatial.rs @@ -0,0 +1,619 @@ +//! # Enhanced Monte Carlo simulation of pandemic spread with spatial features. +//! +//! This example demonstrates advanced spatial epidemic modeling using the `SpatialGrid` plugin. +//! It showcases realistic pandemic dynamics including: +//! +//! * **Infection radius**: Disease spreads within a configurable distance, not just same-cell +//! * **Contact tracing**: Track and quarantine people who were near infected individuals +//! * **Social distancing**: People avoid crowded areas and maintain distance +//! * **Quarantine zones**: Restricted movement areas to contain outbreaks +//! * **Superspreader events**: Detection of high-transmission locations +//! * **Population density tracking**: Monitor crowding and movement patterns +//! +//! The simulation models a more realistic epidemic than simple grid-cell transmission, +//! allowing for analysis of various intervention strategies and their effectiveness. + +#![allow(clippy::unwrap_used)] +#![allow(clippy::expect_used)] +#![allow(clippy::cast_precision_loss)] +#![allow(clippy::too_many_lines)] +#![allow(clippy::type_complexity)] + +use std::collections::HashSet; + +use bevy::prelude::IVec2; +use incerto::prelude::*; +use rand::prelude::*; + +// Simulation parameters +const SIMULATION_STEPS: usize = 500; +const INITIAL_POPULATION: usize = 2000; +const GRID_SIZE: i32 = 40; + +// Disease parameters +const CHANCE_START_INFECTED: f64 = 0.02; +const INFECTION_RADIUS: i32 = 2; // Can infect within 2 cells distance +const CHANCE_INFECT_AT_DISTANCE_1: f64 = 0.15; // High chance at close distance +const CHANCE_INFECT_AT_DISTANCE_2: f64 = 0.05; // Lower chance at farther distance +const CHANCE_RECOVER: f64 = 0.03; +const CHANCE_DIE: f64 = 0.001; +const INCUBATION_PERIOD: usize = 5; // Steps before becoming infectious + +// Social distancing parameters +const SOCIAL_DISTANCING_ENABLED: bool = true; +const CROWDING_THRESHOLD: usize = 8; // Avoid areas with more than 8 people +const SOCIAL_DISTANCE_COMPLIANCE: f64 = 0.7; // 70% of people practice social distancing + +// Contact tracing parameters +const CONTACT_TRACING_ENABLED: bool = true; +const CONTACT_QUARANTINE_DURATION: usize = 14; + +// Quarantine zone (center area) +const QUARANTINE_ZONE_ENABLED: bool = true; +const QUARANTINE_CENTER_X: i32 = GRID_SIZE / 2; +const QUARANTINE_CENTER_Y: i32 = GRID_SIZE / 2; +const QUARANTINE_RADIUS: i32 = 8; + +// Time series sampling +const SAMPLE_INTERVAL: usize = 1; + +/// Disease states for people in the simulation +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DiseaseState +{ + Healthy, + Exposed + { + incubation_remaining: usize, + }, // Infected but not yet infectious + Infectious, + Recovered, +} + +/// Component representing a person in the simulation +#[derive(Component, Debug)] +pub struct Person +{ + pub disease_state: DiseaseState, + pub social_distancing: bool, // Whether this person practices social distancing +} + +/// Component for people under quarantine (contact tracing) +#[derive(Component, Debug)] +pub struct Quarantined +{ + pub remaining_duration: usize, +} + +/// Component to track contact history for contact tracing +#[derive(Component, Debug, Default)] +pub struct ContactHistory +{ + pub recent_contacts: HashSet, // Positions visited recently +} + +/// Pandemic statistics collected during simulation +#[derive(Debug, Clone, Copy)] +pub struct PandemicStats +{ + pub healthy_count: usize, + pub exposed_count: usize, + pub infectious_count: usize, + pub recovered_count: usize, + pub dead_count: usize, + pub total_population: usize, +} + +impl Sample for Person +{ + fn sample(components: &[&Self]) -> PandemicStats + { + assert!(!components.is_empty()); + + let mut healthy_count = 0; + let mut exposed_count = 0; + let mut infectious_count = 0; + let mut recovered_count = 0; + + for person in components + { + match person.disease_state + { + DiseaseState::Healthy => healthy_count += 1, + DiseaseState::Exposed { .. } => exposed_count += 1, + DiseaseState::Infectious => infectious_count += 1, + DiseaseState::Recovered => recovered_count += 1, + } + } + + PandemicStats { + healthy_count, + exposed_count, + infectious_count, + recovered_count, + dead_count: INITIAL_POPULATION - components.len(), + total_population: components.len(), + } + } +} + +fn main() +{ + println!("🦠 Starting Enhanced Pandemic Simulation"); + println!("Population: {INITIAL_POPULATION}"); + println!("Grid size: {GRID_SIZE}x{GRID_SIZE}"); + println!("Infection radius: {INFECTION_RADIUS} cells"); + println!( + "Social distancing: {}", + if SOCIAL_DISTANCING_ENABLED + { + "ON" + } + else + { + "OFF" + } + ); + println!( + "Contact tracing: {}", + if CONTACT_TRACING_ENABLED { "ON" } else { "OFF" } + ); + println!( + "Quarantine zone: {}", + if QUARANTINE_ZONE_ENABLED { "ON" } else { "OFF" } + ); + println!(); + + let bounds = GridBounds2D { + min: IVec2::new(0, 0), + max: IVec2::new(GRID_SIZE - 1, GRID_SIZE - 1), + }; + let mut simulation = SimulationBuilder::new() + // Add spatial grid support + .add_spatial_grid::(bounds) + // Spawn initial population + .add_entity_spawner(spawn_population) + // Movement and social distancing + .add_systems(people_move_with_social_distancing) + // Disease progression and transmission + .add_systems(( + disease_incubation_progression, + spatial_disease_transmission, + disease_recovery_and_death, + )) + // Contact tracing and quarantine + .add_systems(( + update_contact_history, + process_contact_tracing, + update_quarantine_status, + )) + // Record pandemic statistics + .record_time_series::(SAMPLE_INTERVAL) + .expect("Failed to set up time series recording") + .build(); + + println!("Running simulation..."); + simulation.run(SIMULATION_STEPS); + + // Collect and display results + let final_stats = simulation + .sample::() + .expect("Failed to sample pandemic statistics"); + + println!("šŸ“Š Final Statistics:"); + println!( + " Population: {} ({}% survived)", + final_stats.total_population, + final_stats.total_population as f64 / INITIAL_POPULATION as f64 * 100.0 + ); + println!( + " Healthy: {} ({:.1}%)", + final_stats.healthy_count, + final_stats.healthy_count as f64 / final_stats.total_population as f64 * 100.0 + ); + println!( + " Exposed: {} ({:.1}%)", + final_stats.exposed_count, + final_stats.exposed_count as f64 / final_stats.total_population as f64 * 100.0 + ); + println!( + " Infectious: {} ({:.1}%)", + final_stats.infectious_count, + final_stats.infectious_count as f64 / final_stats.total_population as f64 * 100.0 + ); + println!( + " Recovered: {} ({:.1}%)", + final_stats.recovered_count, + final_stats.recovered_count as f64 / final_stats.total_population as f64 * 100.0 + ); + println!( + " Deaths: {} ({:.1}%)", + final_stats.dead_count, + final_stats.dead_count as f64 / INITIAL_POPULATION as f64 * 100.0 + ); + + // Display time series summary + let time_series = simulation + .get_time_series::() + .expect("Failed to get time series data"); + + println!("\nšŸ“ˆ Pandemic Timeline:"); + println!(" Data points collected: {}", time_series.len()); + + if let Some(peak_infections) = time_series + .iter() + .max_by_key(|stats| stats.infectious_count) + { + println!( + " Peak infectious: {} people ({:.1}%)", + peak_infections.infectious_count, + peak_infections.infectious_count as f64 / INITIAL_POPULATION as f64 * 100.0 + ); + } + + let total_recovered = time_series.last().map_or(0, |stats| stats.recovered_count); + let total_deaths = time_series.last().map_or(0, |stats| stats.dead_count); + println!( + " Total recovered: {} ({:.1}%)", + total_recovered, + total_recovered as f64 / INITIAL_POPULATION as f64 * 100.0 + ); + println!( + " Total deaths: {} ({:.1}%)", + total_deaths, + total_deaths as f64 / INITIAL_POPULATION as f64 * 100.0 + ); + + // Calculate attack rate (percentage who got infected) + let attack_rate = (total_recovered + total_deaths) as f64 / INITIAL_POPULATION as f64 * 100.0; + println!(" Attack rate: {attack_rate:.1}% (total who got infected)"); +} + +/// Spawn the initial population with random positions and infection states +fn spawn_population(spawner: &mut Spawner) +{ + let mut rng = rand::rng(); + + for _ in 0..INITIAL_POPULATION + { + // Random position on the grid + let position = GridPosition2D::new( + rng.random_range(0..GRID_SIZE), + rng.random_range(0..GRID_SIZE), + ); + + // Determine if person practices social distancing + let social_distancing = rng.random_bool(SOCIAL_DISTANCE_COMPLIANCE); + + // Initial disease state + let disease_state = if rng.random_bool(CHANCE_START_INFECTED) + { + DiseaseState::Exposed { + incubation_remaining: INCUBATION_PERIOD, + } + } + else + { + DiseaseState::Healthy + }; + + let person = Person { + disease_state, + social_distancing, + }; + + spawner.spawn((position, person, ContactHistory::default())); + } +} + +/// Enhanced movement system with social distancing behavior +fn people_move_with_social_distancing( + mut query: Query<(&mut GridPosition2D, &Person, Option<&Quarantined>)>, + spatial_grid: Res>, +) +{ + let mut rng = rand::rng(); + + for (mut position, person, quarantined) in &mut query + { + // Quarantined people don't move + if quarantined.is_some() + { + continue; + } + + // 50% chance to try to move + if !rng.random_bool(0.5) + { + continue; + } + + // Get potential movement directions + let directions = [ + GridPosition2D::new(position.x(), position.y() - 1), // up + GridPosition2D::new(position.x() - 1, position.y()), // left + GridPosition2D::new(position.x() + 1, position.y()), // right + GridPosition2D::new(position.x(), position.y() + 1), // down + ]; + + let mut best_moves = Vec::new(); + let mut min_crowding = usize::MAX; + + for new_pos in directions + { + // Check bounds + if new_pos.x() < 0 + || new_pos.x() >= GRID_SIZE + || new_pos.y() < 0 + || new_pos.y() >= GRID_SIZE + { + continue; + } + + // Check if in quarantine zone + if QUARANTINE_ZONE_ENABLED + { + let quarantine_center = + GridPosition2D::new(QUARANTINE_CENTER_X, QUARANTINE_CENTER_Y); + if (new_pos.0 - quarantine_center.0).abs().element_sum() <= QUARANTINE_RADIUS + { + // Only enter quarantine zone if not practicing social distancing + if person.social_distancing + { + continue; + } + } + } + + // Count people at potential destination for social distancing + let people_at_destination = spatial_grid.entities_at(&new_pos).count(); + + if person.social_distancing && SOCIAL_DISTANCING_ENABLED + { + // Social distancing: prefer less crowded areas + if people_at_destination < min_crowding + { + min_crowding = people_at_destination; + best_moves.clear(); + best_moves.push(new_pos); + } + else if people_at_destination == min_crowding + { + best_moves.push(new_pos); + } + } + else + { + // No social distancing: any valid move is fine + best_moves.push(new_pos); + } + } + + // Move to a randomly selected best position + if !best_moves.is_empty() + { + let chosen_move = best_moves.choose(&mut rng).copied().unwrap(); + + // Only move if it's not too crowded (even for non-social-distancing people) + let people_at_destination = spatial_grid.entities_at(&chosen_move).count(); + if people_at_destination < CROWDING_THRESHOLD + { + *position = chosen_move; + } + } + } +} + +/// Progress disease through incubation period +fn disease_incubation_progression(mut query: Query<&mut Person>) +{ + for mut person in &mut query + { + if let DiseaseState::Exposed { + incubation_remaining, + } = &mut person.disease_state + { + if *incubation_remaining > 1 + { + *incubation_remaining -= 1; + } + else + { + // Become infectious + person.disease_state = DiseaseState::Infectious; + } + } + } +} + +/// Advanced spatial disease transmission system using infection radius +fn spatial_disease_transmission( + spatial_grid: Res>, + mut query: Query<(Entity, &GridPosition2D, &mut Person), Without>, +) +{ + let mut rng = rand::rng(); + let mut new_exposures = Vec::new(); + + // Collect infectious people first to avoid borrowing conflicts + let infectious_people: Vec<(Entity, GridPosition2D)> = query + .iter() + .filter_map(|(entity, pos, person)| { + if matches!(person.disease_state, DiseaseState::Infectious) + { + Some((entity, *pos)) + } + else + { + None + } + }) + .collect(); + + for (infectious_entity, infectious_pos) in infectious_people + { + // Get all people within infection radius using iterative approach + let mut nearby_entities = Vec::new(); + let infectious_coord = infectious_pos.0; + + // Check all positions within Manhattan distance of INFECTION_RADIUS + for dx in -INFECTION_RADIUS..=INFECTION_RADIUS + { + for dy in -INFECTION_RADIUS..=INFECTION_RADIUS + { + let manhattan_distance = dx.abs() + dy.abs(); + if manhattan_distance <= INFECTION_RADIUS + { + let check_pos = + GridPosition2D::new(infectious_coord.x + dx, infectious_coord.y + dy); + nearby_entities.extend(spatial_grid.entities_at(&check_pos)); + } + } + } + + for nearby_entity in nearby_entities + { + if nearby_entity == infectious_entity + { + continue; // Don't infect self + } + + if let Ok((entity, susceptible_pos, person)) = query.get(nearby_entity) + { + // Only infect healthy people + if matches!(person.disease_state, DiseaseState::Healthy) + { + // Calculate infection probability based on distance + let distance = (infectious_pos.0 - susceptible_pos.0).abs().element_sum(); + let infection_chance = match distance + { + 0 | 1 => CHANCE_INFECT_AT_DISTANCE_1, // Same cell or adjacent + 2 => CHANCE_INFECT_AT_DISTANCE_2, // 2 cells away + _ => 0.0, // Too far + }; + + if rng.random_bool(infection_chance) + { + new_exposures.push(entity); + } + } + } + } + } + + // Apply new exposures + for entity in new_exposures + { + let (_, _, mut person) = query.get_mut(entity).expect("Entity should exist"); + person.disease_state = DiseaseState::Exposed { + incubation_remaining: INCUBATION_PERIOD, + }; + } +} + +/// Handle disease recovery and death +fn disease_recovery_and_death(mut commands: Commands, mut query: Query<(Entity, &mut Person)>) +{ + let mut rng = rand::rng(); + + for (entity, mut person) in &mut query + { + if matches!(person.disease_state, DiseaseState::Infectious) + { + if rng.random_bool(CHANCE_DIE) + { + // Person dies + commands.entity(entity).despawn(); + } + else if rng.random_bool(CHANCE_RECOVER) + { + // Person recovers and gains immunity + person.disease_state = DiseaseState::Recovered; + } + } + } +} + +/// Update contact history for contact tracing +fn update_contact_history(mut query: Query<(&GridPosition2D, &mut ContactHistory)>) +{ + for (position, mut contact_history) in &mut query + { + // Add current position to recent contacts + contact_history.recent_contacts.insert(*position); + + // Limit history size (keep last 14 positions) + if contact_history.recent_contacts.len() > 14 + { + // In a real implementation, you'd track timestamps and remove old ones + // For simplicity, we'll just clear periodically + if contact_history.recent_contacts.len() > 20 + { + contact_history.recent_contacts.clear(); + } + } + } +} + +/// Process contact tracing when someone becomes infectious +fn process_contact_tracing( + mut commands: Commands, + spatial_grid: Res>, + query_newly_infectious: Query< + (Entity, &GridPosition2D, &ContactHistory), + (With, Without), + >, + query_potential_contacts: Query, Without)>, +) +{ + if !CONTACT_TRACING_ENABLED + { + return; + } + + for (infectious_entity, _infectious_pos, contact_history) in &query_newly_infectious + { + // Check if this person just became infectious (simplified check) + // In a real implementation, you'd track state changes + + // Quarantine people who were in recent contact locations + for &contact_location in &contact_history.recent_contacts + { + let people_at_location = spatial_grid.entities_at(&contact_location); + + for potential_contact in people_at_location + { + if potential_contact == infectious_entity + { + continue; + } + + // Quarantine this person if they're not already quarantined + if query_potential_contacts.get(potential_contact).is_ok() + { + // Use try_insert to handle entities that may have been despawned + commands.entity(potential_contact).try_insert(Quarantined { + remaining_duration: CONTACT_QUARANTINE_DURATION, + }); + } + } + } + } +} + +/// Update quarantine status +fn update_quarantine_status(mut commands: Commands, mut query: Query<(Entity, &mut Quarantined)>) +{ + for (entity, mut quarantined) in &mut query + { + if quarantined.remaining_duration > 1 + { + quarantined.remaining_duration -= 1; + } + else + { + // End quarantine + commands.entity(entity).remove::(); + } + } +} diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..5d56faf --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "nightly" diff --git a/src/lib.rs b/src/lib.rs index 3b80c6d..d92f9fa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,6 +23,10 @@ mod spawner; mod traits; pub use error::*; +pub use plugins::{ + GridBounds, GridBounds2D, GridBounds3D, GridCoordinates, GridPosition, GridPosition2D, + GridPosition3D, SpatialGrid, SpatialGrid2D, SpatialGrid3D, TimeSeries, +}; pub use simulation::Simulation; pub use simulation_builder::SimulationBuilder; pub use spawner::Spawner; diff --git a/src/plugins/mod.rs b/src/plugins/mod.rs index e7353ce..74421dd 100644 --- a/src/plugins/mod.rs +++ b/src/plugins/mod.rs @@ -3,3 +3,9 @@ pub use step_counter::StepCounterPlugin; mod time_series; pub use time_series::{TimeSeries, TimeSeriesPlugin}; + +mod spatial_grid; +pub use spatial_grid::{ + GridBounds, GridBounds2D, GridBounds3D, GridCoordinates, GridPosition, GridPosition2D, + GridPosition3D, SpatialGrid, SpatialGrid2D, SpatialGrid3D, SpatialGridPlugin, +}; diff --git a/src/plugins/spatial_grid.rs b/src/plugins/spatial_grid.rs new file mode 100644 index 0000000..c381a65 --- /dev/null +++ b/src/plugins/spatial_grid.rs @@ -0,0 +1,481 @@ +use std::hash::Hash; + +use bevy::{ + ecs::entity::EntityHashMap, + platform::collections::{HashMap, HashSet}, + prelude::*, +}; + +use crate::plugins::step_counter::StepCounter; + +// Direction constants for 2D grid movement +const NORTH: IVec2 = IVec2::new(0, -1); +const SOUTH: IVec2 = IVec2::new(0, 1); +const EAST: IVec2 = IVec2::new(1, 0); +const WEST: IVec2 = IVec2::new(-1, 0); +const NORTH_EAST: IVec2 = IVec2::new(1, -1); +const NORTH_WEST: IVec2 = IVec2::new(-1, -1); +const SOUTH_EAST: IVec2 = IVec2::new(1, 1); +const SOUTH_WEST: IVec2 = IVec2::new(-1, 1); + +// Direction constants for 3D grid movement (orthogonal only) +const UP: IVec3 = IVec3::new(0, 0, 1); +const DOWN: IVec3 = IVec3::new(0, 0, -1); +const NORTH_3D: IVec3 = IVec3::new(0, -1, 0); +const SOUTH_3D: IVec3 = IVec3::new(0, 1, 0); +const EAST_3D: IVec3 = IVec3::new(1, 0, 0); +const WEST_3D: IVec3 = IVec3::new(-1, 0, 0); + +/// A sealed trait for coordinates on a grid. +/// Will be implemented for [`IVec2`] and [`IVec3`]. +pub trait GridCoordinates: + private::Sealed + Clone + Copy + Hash + PartialEq + Eq + Send + Sync + 'static +{ + fn neighbors(&self) -> impl Iterator; + + fn neighbors_orthogonal(&self) -> impl Iterator; + + fn in_bounds(&self, bounds: &GridBounds) -> bool; +} + +/// Describes the bounds of a grid. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct GridBounds +{ + pub min: T, + pub max: T, +} + +impl GridCoordinates for IVec2 +{ + fn neighbors(&self) -> impl Iterator + { + const DIRECTIONS: [IVec2; 8] = [ + NORTH_WEST, NORTH, NORTH_EAST, WEST, EAST, SOUTH_WEST, SOUTH, SOUTH_EAST, + ]; + DIRECTIONS.into_iter().map(move |dir| self + dir) + } + + fn neighbors_orthogonal(&self) -> impl Iterator + { + const DIRECTIONS: [IVec2; 4] = [NORTH, WEST, EAST, SOUTH]; + DIRECTIONS.into_iter().map(move |dir| self + dir) + } + + fn in_bounds(&self, bounds: &GridBounds) -> bool + { + bounds.contains(self) + } +} + +impl GridCoordinates for IVec3 +{ + fn neighbors(&self) -> impl Iterator + { + // 26 neighbors in 3D (3x3x3 cube minus center) + (-1..=1).flat_map(move |dx| { + (-1..=1).flat_map(move |dy| { + (-1..=1).filter_map(move |dz| { + if dx == 0 && dy == 0 && dz == 0 + { + None // Skip center + } + else + { + Some(self + Self::new(dx, dy, dz)) + } + }) + }) + }) + } + + fn neighbors_orthogonal(&self) -> impl Iterator + { + // 6 orthogonal neighbors in 3D + const DIRECTIONS: [IVec3; 6] = [WEST_3D, EAST_3D, NORTH_3D, SOUTH_3D, DOWN, UP]; + DIRECTIONS.into_iter().map(move |dir| self + dir) + } + + fn in_bounds(&self, bounds: &GridBounds) -> bool + { + bounds.contains(self) + } +} + +/// Component representing a position in the spatial grid. +/// Generic over coordinate types that implement the `GridCoordinate` trait. +#[derive(Component, Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct GridPosition(pub T); + +// Convenience methods for 2D positions +impl GridPosition +{ + /// Create a new 2D grid position from x, y coordinates. + #[must_use] + pub const fn new(x: i32, y: i32) -> Self + { + Self(IVec2::new(x, y)) + } + + pub fn neighbors(&self) -> impl Iterator + { + self.0.neighbors().map(Self) + } + + pub fn neighbors_orthogonal(&self) -> impl Iterator + { + self.0.neighbors_orthogonal().map(Self) + } + + #[must_use] + pub const fn x(&self) -> i32 + { + self.0.x + } + + #[must_use] + pub const fn y(&self) -> i32 + { + self.0.y + } +} + +// Convenience methods for 3D positions +impl GridPosition +{ + /// Create a new 3D grid position from x, y, z coordinates. + #[must_use] + pub const fn new(x: i32, y: i32, z: i32) -> Self + { + Self(IVec3::new(x, y, z)) + } + + pub fn neighbors(&self) -> impl Iterator + { + self.0.neighbors().map(Self) + } + + pub fn neighbors_orthogonal(&self) -> impl Iterator + { + self.0.neighbors_orthogonal().map(Self) + } + + #[must_use] + pub const fn x(&self) -> i32 + { + self.0.x + } + + #[must_use] + pub const fn y(&self) -> i32 + { + self.0.y + } + + #[must_use] + pub const fn z(&self) -> i32 + { + self.0.z + } +} + +/// Component that maintains a spatial index for efficient neighbor queries. +/// Generic over coordinate types that implement the `GridCoordinate` trait and component types. +#[derive(Resource)] +pub struct SpatialGrid +{ + /// Maps grid positions to entities at those positions. + position_to_entities: HashMap, HashSet>, + /// Maps entities to their grid positions for fast lookups (optimized for Entity keys). + entity_to_position: EntityHashMap>, + /// Grid bounds for validation and iteration. + bounds: Option>, + /// Phantom data to maintain type association with component C. + _phantom: std::marker::PhantomData, +} + +/// Specific implementations for 2D bounds +impl GridBounds +{ + /// Check if a position is within these bounds. + /// + /// # Panics + /// + /// If [`Self::min`] is larger than [`Self::max`] along any axis. + #[must_use] + pub fn contains(&self, pos: &IVec2) -> bool + { + assert!(self.min.x <= self.max.x); + assert!(self.min.y <= self.max.y); + (pos.x >= self.min.x && pos.x <= self.max.x) && (pos.y >= self.min.y && pos.y <= self.max.y) + } +} + +/// Specific implementations for 3D bounds +impl GridBounds +{ + /// Check if a position is within these bounds. + /// + /// # Panics + /// + /// If [`Self::min`] is larger than [`Self::max`] along any axis. + #[must_use] + pub fn contains(&self, pos: &IVec3) -> bool + { + assert!(self.min.x <= self.max.x); + assert!(self.min.y <= self.max.y); + assert!(self.min.z <= self.max.z); + (pos.x >= self.min.x && pos.x <= self.max.x) + && (pos.y >= self.min.y && pos.y <= self.max.y) + && (pos.z >= self.min.z && pos.z <= self.max.z) + } +} + +impl SpatialGrid +{ + #[must_use] + pub fn new(bounds: Option>) -> Self + { + Self { + position_to_entities: HashMap::default(), + entity_to_position: EntityHashMap::default(), + bounds, + _phantom: std::marker::PhantomData, + } + } + + #[must_use] + pub const fn bounds(&self) -> Option> + { + self.bounds + } + + /// Add an entity at a specific grid position. + fn insert(&mut self, entity: Entity, position: GridPosition) + { + // Remove entity from old position if it exists + self.remove(entity); + + // Insert at new position + self.position_to_entities + .entry(position) + .or_default() + .insert(entity); + self.entity_to_position.insert(entity, position); + } + + /// Remove an entity from the spatial index. + /// + /// Returns the position where the entity was located, if it was found. + fn remove(&mut self, entity: Entity) -> Option> + { + let position = self.entity_to_position.remove(&entity)?; + + let Some(entities_at_position) = self.position_to_entities.get_mut(&position) + else + { + panic!("entity found in one hashmap but not the other?"); + }; + + entities_at_position.remove(&entity); + if entities_at_position.is_empty() + { + self.position_to_entities.remove(&position); + } + Some(position) + } + + /// Get all entities at a specific position. + pub fn entities_at(&self, position: &GridPosition) -> impl Iterator + '_ + { + self.position_to_entities + .get(position) + .into_iter() + .flat_map(|set| set.iter().copied()) + } + + /// Get the position of an entity. + #[must_use] + pub fn position_of(&self, entity: Entity) -> Option> + { + self.entity_to_position.get(&entity).copied() + } + + /// Get all entities in the neighborhood of a position (Moore neighborhood). + pub fn neighbors_of(&self, position: &GridPosition) -> impl Iterator + { + position + .0 + .neighbors() + .filter(|neighbor_pos| { + self.bounds + .is_none_or(|bounds| neighbor_pos.in_bounds(&bounds)) + }) + .map(|p| GridPosition(p)) + .flat_map(|neighbor_pos| { + self.position_to_entities + .get(&neighbor_pos) + .into_iter() + .flat_map(|set| set.iter().copied()) + }) + } + + /// Get all entities in the orthogonal neighborhood of a position (Von Neumann neighborhood). + pub fn orthogonal_neighbors_of( + &self, + position: &GridPosition, + ) -> impl Iterator + { + position + .0 + .neighbors_orthogonal() + .filter(|neighbor_pos| { + self.bounds + .is_none_or(|bounds| neighbor_pos.in_bounds(&bounds)) + }) + .map(|p| GridPosition(p)) + .flat_map(|neighbor_pos| { + self.position_to_entities + .get(&neighbor_pos) + .into_iter() + .flat_map(|set| set.iter().copied()) + }) + } + + /// Clear all entities from the spatial index. + fn clear(&mut self) + { + self.position_to_entities.clear(); + self.entity_to_position.clear(); + } + + /// Check if a position is empty (has no entities). + #[must_use] + pub fn is_empty(&self, position: &GridPosition) -> bool + { + self.position_to_entities + .get(position) + .is_none_or(HashSet::is_empty) + } + + /// Get total number of entities in the grid. + #[must_use] + pub fn num_entities(&self) -> usize + { + self.entity_to_position.len() + } +} + +/// Plugin that maintains a spatial index for entities with `GridPosition` components. +/// Generic over coordinate types that implement the `GridCoordinate` trait and component types. +pub struct SpatialGridPlugin +{ + bounds: Option>, + _phantom: std::marker::PhantomData<(T, C)>, +} + +impl SpatialGridPlugin +{ + pub const fn new(bounds: Option>) -> Self + { + Self { + bounds, + _phantom: std::marker::PhantomData, + } + } + + pub fn init(app: &mut App, bounds: Option>) + { + // Spawn the spatial grid entity directly + let spatial_grid = SpatialGrid::::new(bounds); + app.world_mut().insert_resource(spatial_grid); + } +} + +impl Plugin for SpatialGridPlugin +{ + fn build(&self, app: &mut App) + { + Self::init(app, self.bounds); + + // System to maintain the spatial index + app.add_systems( + PreUpdate, + ( + spatial_grid_reset_system::, + spatial_grid_update_system::, + spatial_grid_cleanup_system::, + ) + .chain(), + ); + } +} + +/// System that resets the spatial grid at the beginning of each simulation. +fn spatial_grid_reset_system( + mut spatial_grid: ResMut>, + step_counter: Res, +) +{ + // Reset the spatial grid whenever the step counter is 0 + // This should occur on the first step of every simulation + if **step_counter == 0 + { + spatial_grid.clear(); + } +} + +/// Query for entities with `GridPosition` components that have been added or changed. +type GridPositionQuery<'world, 'state, T, C> = + Query<'world, 'state, (Entity, &'static GridPosition), (Changed>, With)>; + +/// System that updates the spatial grid when entities with `GridPosition` are added or moved. +fn spatial_grid_update_system( + mut spatial_grid: ResMut>, + query: GridPositionQuery, +) +{ + for (entity, position) in &query + { + spatial_grid.insert(entity, *position); + } +} + +/// System that removes entities from the spatial grid when they no longer have `GridPosition`. +fn spatial_grid_cleanup_system( + mut spatial_grid: ResMut>, + mut removed: RemovedComponents>, +) +{ + for entity in removed.read() + { + spatial_grid.remove(entity); + } +} + +// Type aliases for convenience +/// 2D spatial grid using `IVec2` coordinates. +pub type SpatialGrid2D = SpatialGrid; + +/// 3D spatial grid using `IVec3` coordinates. +pub type SpatialGrid3D = SpatialGrid; + +/// 2D grid position using `IVec2` coordinates. +pub type GridPosition2D = GridPosition; + +/// 3D grid position using `IVec3` coordinates. +pub type GridPosition3D = GridPosition; + +/// 2D grid bounds using `IRect`. +pub type GridBounds2D = GridBounds; + +/// 3D grid bounds using custom `Bounds3D`. +pub type GridBounds3D = GridBounds; + +/// Private module to enforce the sealed trait pattern. +mod private +{ + pub trait Sealed {} + impl Sealed for bevy::prelude::IVec2 {} + impl Sealed for bevy::prelude::IVec3 {} +} diff --git a/src/plugins/time_series.rs b/src/plugins/time_series.rs index 6b2b441..0c6e24b 100644 --- a/src/plugins/time_series.rs +++ b/src/plugins/time_series.rs @@ -32,6 +32,7 @@ where O: Send + Sync + 'static, F: QueryFilter + Send + Sync + 'static, { + #[must_use] pub const fn new(sample_interval: usize) -> Self { Self { diff --git a/src/prelude.rs b/src/prelude.rs index 3207220..235b3f2 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -4,6 +4,7 @@ pub use bevy::prelude::{ }; pub use super::{ - error::*, simulation::Simulation, simulation_builder::SimulationBuilder, spawner::Spawner, - traits::*, + GridBounds, GridBounds2D, GridBounds3D, GridCoordinates, GridPosition, GridPosition2D, + GridPosition3D, SpatialGrid, SpatialGrid2D, SpatialGrid3D, TimeSeries, error::*, + simulation::Simulation, simulation_builder::SimulationBuilder, spawner::Spawner, traits::*, }; diff --git a/src/simulation_builder.rs b/src/simulation_builder.rs index c67205c..3fec802 100644 --- a/src/simulation_builder.rs +++ b/src/simulation_builder.rs @@ -6,7 +6,10 @@ use bevy::{ use crate::{ Sample, SimulationBuildError, - plugins::{StepCounterPlugin, TimeSeries, TimeSeriesPlugin}, + plugins::{ + GridBounds, GridCoordinates, SpatialGridPlugin, StepCounterPlugin, TimeSeries, + TimeSeriesPlugin, + }, simulation::Simulation, spawner::Spawner, }; @@ -76,6 +79,44 @@ impl SimulationBuilder self } + /// Add a spatial grid for a specific component type to the simulation. + /// + /// + /// This creates a spatial index for entities that have both `GridPosition` and the specified component `C`. + /// Multiple spatial grids can coexist for different component types. + /// The spatial grid will be spawned as an entity during simulation startup. + /// + /// Example: + /// ``` + /// # use bevy::prelude::IVec2; + /// # use incerto::prelude::*; + /// #[derive(Component)] + /// struct Person; + /// + /// #[derive(Component)] + /// struct Vehicle; + /// + /// let bounds = GridBounds2D { + /// min: IVec2::new(0, 0), + /// max: IVec2::new(99, 99), + /// }; + /// let simulation = SimulationBuilder::new() + /// .add_spatial_grid::(bounds) + /// .add_spatial_grid::(bounds) + /// .build(); + /// ``` + #[must_use] + pub fn add_spatial_grid( + mut self, + bounds: GridBounds, + ) -> Self + { + self.sim + .app + .add_plugins(SpatialGridPlugin::::new(Some(bounds))); + self + } + /// Add an entity spawner function to the simulation. /// /// In the beginning of ever simulation, each of the spawner functions added here diff --git a/tests/mod.rs b/tests/mod.rs index 5fbf105..b9cf797 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -1,2 +1,3 @@ mod test_builder; mod test_counter; +mod test_spatial_grid; diff --git a/tests/test_spatial_grid.rs b/tests/test_spatial_grid.rs new file mode 100644 index 0000000..7fbfab7 --- /dev/null +++ b/tests/test_spatial_grid.rs @@ -0,0 +1,309 @@ +#![allow(clippy::expect_used)] +#![allow(clippy::uninlined_format_args)] +#![allow(clippy::cast_possible_truncation)] + +use bevy::prelude::{IVec2, IVec3}; +use incerto::prelude::*; + +#[test] +fn test_grid_position_neighbors() +{ + let pos = GridPosition2D::new(1, 1); + + let neighbors: Vec = pos.neighbors().collect(); + assert_eq!(neighbors.len(), 8); + + // Check all 8 neighbors are present + let expected_neighbors = [ + GridPosition2D::new(0, 0), + GridPosition2D::new(1, 0), + GridPosition2D::new(2, 0), + GridPosition2D::new(0, 1), + GridPosition2D::new(2, 1), + GridPosition2D::new(0, 2), + GridPosition2D::new(1, 2), + GridPosition2D::new(2, 2), + ]; + + for expected in expected_neighbors + { + assert!( + neighbors.contains(&expected), + "Missing neighbor: {:?}", + expected + ); + } +} + +#[test] +fn test_grid_position_orthogonal_neighbors() +{ + let pos = GridPosition2D::new(1, 1); + + let neighbors: Vec = pos.neighbors_orthogonal().collect(); + assert_eq!(neighbors.len(), 4); + + let expected_neighbors = [ + GridPosition2D::new(1, 0), // top + GridPosition2D::new(0, 1), // left + GridPosition2D::new(2, 1), // right + GridPosition2D::new(1, 2), // bottom + ]; + + for expected in expected_neighbors + { + assert!( + neighbors.contains(&expected), + "Missing orthogonal neighbor: {:?}", + expected + ); + } +} + +#[test] +fn test_grid_position_distances() +{ + let pos1 = GridPosition2D::new(0, 0); + let pos2 = GridPosition2D::new(3, 4); + + // Test Manhattan distance + let diff = pos2.0 - pos1.0; + assert_eq!(diff.abs().element_sum(), 7); + + let pos3 = GridPosition2D::new(1, 1); + let diff2 = pos3.0 - pos1.0; + assert_eq!(diff2.abs().element_sum(), 2); +} + +#[test] +fn test_grid_bounds() +{ + let bounds = GridBounds2D { + min: IVec2::new(0, 0), + max: IVec2::new(9, 9), + }; + + assert!(bounds.contains(&GridPosition2D::new(0, 0).0)); + assert!(bounds.contains(&GridPosition2D::new(9, 9).0)); + assert!(bounds.contains(&GridPosition2D::new(5, 5).0)); + + assert!(!bounds.contains(&GridPosition2D::new(-1, 0).0)); + assert!(!bounds.contains(&GridPosition2D::new(0, -1).0)); + assert!(!bounds.contains(&GridPosition2D::new(10, 5).0)); + assert!(!bounds.contains(&GridPosition2D::new(5, 10).0)); +} + +#[test] +fn test_3d_grid_position_neighbors() +{ + let pos = GridPosition3D::new(1, 1, 1); + + let neighbors: Vec = pos.neighbors().collect(); + assert_eq!(neighbors.len(), 26); // 3x3x3 cube minus center = 26 neighbors + + // Check that center position is not included + assert!(!neighbors.contains(&pos)); + + // Check some specific 3D neighbors + assert!(neighbors.contains(&GridPosition3D::new(0, 0, 0))); // corner + assert!(neighbors.contains(&GridPosition3D::new(2, 2, 2))); // opposite corner + assert!(neighbors.contains(&GridPosition3D::new(1, 1, 0))); // directly below + assert!(neighbors.contains(&GridPosition3D::new(1, 1, 2))); // directly above +} + +#[test] +fn test_3d_grid_position_orthogonal_neighbors() +{ + let pos = GridPosition3D::new(1, 1, 1); + + let neighbors: Vec = pos.neighbors_orthogonal().collect(); + assert_eq!(neighbors.len(), 6); // 6 orthogonal directions in 3D + + let expected_neighbors = [ + GridPosition3D::new(0, 1, 1), // -x + GridPosition3D::new(2, 1, 1), // +x + GridPosition3D::new(1, 0, 1), // -y + GridPosition3D::new(1, 2, 1), // +y + GridPosition3D::new(1, 1, 0), // -z + GridPosition3D::new(1, 1, 2), // +z + ]; + + for expected in expected_neighbors + { + assert!( + neighbors.contains(&expected), + "Missing 3D orthogonal neighbor: {:?}", + expected + ); + } +} + +#[test] +fn test_3d_grid_position_distances() +{ + let pos1 = GridPosition3D::new(0, 0, 0); + let pos2 = GridPosition3D::new(3, 4, 5); + + // Test 3D Manhattan distance + let diff = pos2.0 - pos1.0; + assert_eq!(diff.abs().element_sum(), 12); // 3 + 4 + 5 = 12 + + let pos3 = GridPosition3D::new(1, 1, 1); + let diff2 = pos3.0 - pos1.0; + assert_eq!(diff2.abs().element_sum(), 3); // 1 + 1 + 1 = 3 +} + +#[test] +fn test_3d_grid_bounds() +{ + let bounds = GridBounds3D { + min: IVec3::new(0, 0, 0), + max: IVec3::new(9, 9, 9), + }; + + assert!(bounds.contains(&GridPosition3D::new(0, 0, 0).0)); + assert!(bounds.contains(&GridPosition3D::new(9, 9, 9).0)); + assert!(bounds.contains(&GridPosition3D::new(5, 5, 5).0)); + + assert!(!bounds.contains(&GridPosition3D::new(-1, 0, 0).0)); + assert!(!bounds.contains(&GridPosition3D::new(0, -1, 0).0)); + assert!(!bounds.contains(&GridPosition3D::new(0, 0, -1).0)); + assert!(!bounds.contains(&GridPosition3D::new(10, 5, 5).0)); + assert!(!bounds.contains(&GridPosition3D::new(5, 10, 5).0)); + assert!(!bounds.contains(&GridPosition3D::new(5, 5, 10).0)); +} + +#[test] +fn test_3d_spatial_grid_integration() +{ + #[derive(Component)] + struct TestEntity3D(i32); + + impl Sample for TestEntity3D + { + fn sample(components: &[&Self]) -> usize + { + components.len() + } + } + + let bounds = GridBounds3D { + min: IVec3::new(0, 0, 0), + max: IVec3::new(4, 4, 4), + }; + + let builder = SimulationBuilder::new() + .add_spatial_grid::(bounds) + .add_entity_spawner(|spawner| { + // Spawn entities at different 3D positions + spawner.spawn((GridPosition3D::new(0, 0, 0), TestEntity3D(1))); + spawner.spawn((GridPosition3D::new(2, 2, 2), TestEntity3D(2))); + spawner.spawn((GridPosition3D::new(4, 4, 4), TestEntity3D(3))); + spawner.spawn((GridPosition3D::new(1, 2, 3), TestEntity3D(4))); + }) + .add_systems( + |spatial_grid: Res>, + query: Query<(Entity, &GridPosition3D, &TestEntity3D)>| { + // Test 3D spatial queries + let center_pos = GridPosition3D::new(2, 2, 2); + + // Find entities within distance 2 in 3D space using neighbor-based approach + let mut nearby_entities = Vec::new(); + let center_coord = center_pos.0; + + // Check all positions within Manhattan distance of 2 + for dx in -2i32..=2i32 + { + for dy in -2i32..=2i32 + { + for dz in -2i32..=2i32 + { + let manhattan_distance = dx.abs() + dy.abs() + dz.abs(); + if manhattan_distance <= 2 + { + let check_pos = GridPosition3D::new( + center_coord.x + dx, + center_coord.y + dy, + center_coord.z + dz, + ); + nearby_entities.extend(spatial_grid.entities_at(&check_pos)); + } + } + } + } + + assert!( + !nearby_entities.is_empty(), + "Should find nearby entities in 3D" + ); + + // Verify all positions are valid 3D coordinates + for (_, position, test_entity) in &query + { + assert!(test_entity.0 > 0); + assert!(position.x() >= 0 && position.x() <= 4); + assert!(position.y() >= 0 && position.y() <= 4); + assert!(position.z() >= 0 && position.z() <= 4); + } + }, + ); + + let mut simulation = builder.build(); + simulation.run(1); + + // Verify all entities were created + let entity_count = simulation + .sample::() + .expect("Failed to sample TestEntity3D count"); + assert_eq!(entity_count, 4); +} + +#[test] +fn test_spatial_grid_reset_functionality() +{ + #[derive(Component)] + #[allow(dead_code)] + struct TestResetEntity(i32); + + impl Sample for TestResetEntity + { + fn sample(components: &[&Self]) -> usize + { + components.len() + } + } + + let bounds = GridBounds2D { + min: IVec2::new(0, 0), + max: IVec2::new(4, 4), + }; + + let mut simulation = SimulationBuilder::new() + .add_spatial_grid::(bounds) + .add_entity_spawner(|spawner| { + // Spawn entities at different positions + spawner.spawn((GridPosition2D::new(0, 0), TestResetEntity(1))); + spawner.spawn((GridPosition2D::new(2, 2), TestResetEntity(2))); + spawner.spawn((GridPosition2D::new(4, 4), TestResetEntity(3))); + }) + .build(); + + // Run first simulation + simulation.run(2); + + // Verify entities are tracked + let entity_count = simulation + .sample::() + .expect("Failed to sample TestResetEntity count"); + assert_eq!(entity_count, 3); + + // Reset simulation (this should trigger spatial grid reset on step 0) + simulation.reset(); + simulation.run(1); + + // Verify entities are still tracked after reset + let entity_count_after_reset = simulation + .sample::() + .expect("Failed to sample TestResetEntity count after reset"); + assert_eq!(entity_count_after_reset, 3); +}