A high-fidelity drone simulation framework integrating MuJoCo physics with ArduPilot SITL for training neural network controllers using reinforcement learning.
This project provides:
- MuJoCo-based quadrotor simulation with realistic physics
- ArduPilot SITL integration via JSON/MAVLink protocols
- Gymnasium environments for reinforcement learning
- Neural network controllers trained with PPO/SAC
- Yaw tracking task as a demonstration of NN control
# Clone repository
git clone https://github.com/brainnotincluded/NNPID.git
cd NNPID
# Install dependencies (using uv - recommended)
# Inference-only (load trained models):
uv sync
# Full install (training + simulation + visualization):
uv sync --all-extras
# Or using pip
pip install -e .
pip install -e ".[full]"
# Train yaw tracking model
python scripts/train_yaw_tracker.py --timesteps 500000
# Evaluate trained model
python scripts/evaluate_yaw_tracker.py --model runs/<run_name>/best_model
# Visualize (interactive or video)
python scripts/visualize_mujoco.py --mode interactive --model runs/<run_name>/best_model
python scripts/visualize_mujoco.py --mode video --model runs/<run_name>/best_model --output runs/visualizations/demo.mp4
# List available runs
ls runs/

For full setup details, see docs/GETTING_STARTED.md.
NNPID/
├── src/ # Main source code
│ ├── core/ # MuJoCo simulation core
│ │ ├── mujoco_sim.py # Main simulator class
│ │ ├── quadrotor.py # Quadrotor dynamics
│ │ └── sensors.py # Sensor models
│ ├── environments/ # Gymnasium environments
│ │ ├── base_drone_env.py # Base environment
│ │ ├── yaw_tracking_env.py # Yaw tracking task
│ │ ├── hover_env.py # Hover task
│ │ └── waypoint_env.py # Waypoint navigation
│ ├── controllers/ # Control algorithms
│ │ ├── base_controller.py # PID controller
│ │ ├── nn_controller.py # Neural network controller
│ │ └── yaw_rate_controller.py # Yaw rate controller
│ ├── communication/ # SITL communication
│ │ ├── mavlink_bridge.py # MAVLink protocol
│ │ └── messages.py # Message definitions
│ ├── deployment/ # Model deployment
│ │ ├── yaw_tracker_sitl.py # Deploy to ArduPilot SITL
│ │ └── model_export.py # Export models
│ ├── utils/ # Utilities
│ │ ├── rotations.py # Quaternion math
│ │ └── transforms.py # Coordinate transforms
│ ├── visualization/ # Visualization tools
│ │ ├── viewer.py # MuJoCo viewer
│ │ ├── scene_objects.py # 3D scene objects
│ │ ├── nn_visualizer.py # Neural network visualizer
│ │ ├── telemetry_hud.py # Real-time HUD
│ │ └── mujoco_overlay.py # Combined overlay system
│ └── perturbations/ # Realistic perturbations
│ ├── wind.py # Wind effects
│ ├── delays.py # Sensor/actuator delays
│ └── ... # Other perturbations
├── scripts/ # Executable scripts
│ ├── train_yaw_tracker.py # Train yaw tracking
│ ├── evaluate_yaw_tracker.py # Evaluate models
│ ├── visualize_mujoco.py # Unified MuJoCo visualization
│ ├── run_mega_viz.py # Full visualization
│ ├── model_inspector.py # CLI model analysis
│ ├── run_ardupilot_sim.py # Run with ArduPilot SITL
│ └── run_yaw_tracker_sitl.py # Deploy NN to SITL
├── models/ # MuJoCo XML models
│ └── quadrotor_x500.xml # X500 quadrotor model
├── config/ # Configuration files
│ ├── yaw_tracking.yaml # Yaw tracking config
│ └── simulation.yaml # Simulation settings
├── tests/ # Unit tests
├── docs/ # Documentation
└── runs/ # Training runs & checkpoints
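As a taste of the quaternion utilities listed under src/utils/ (the repo's actual API may differ), yaw can be recovered from a unit quaternion with standard math. A standalone sketch, assuming MuJoCo's [w, x, y, z] ordering:

```python
import math

def quat_to_yaw(w: float, x: float, y: float, z: float) -> float:
    """Extract the yaw (heading) angle in radians from a unit quaternion.

    Uses the standard ZYX Euler decomposition and assumes MuJoCo's
    [w, x, y, z] component ordering.
    """
    return math.atan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y * y + z * z))

# A 90-degree rotation about the z-axis:
half = math.radians(90.0) / 2.0
yaw = quat_to_yaw(math.cos(half), 0.0, 0.0, math.sin(half))
# yaw is approximately pi/2
```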
High-fidelity physics simulation with:
- 500Hz physics timestep
- Accurate motor dynamics
- IMU, gyroscope, accelerometer sensors
- Ground contact detection
from src.core.mujoco_sim import MuJoCoSimulator
sim = MuJoCoSimulator("models/quadrotor_x500.xml")
sim.reset(position=[0, 0, 1])
sim.step(motor_commands=[0.5, 0.5, 0.5, 0.5])
state = sim.get_state()

Standard RL interface for training:
import gymnasium as gym
from src.environments import YawTrackingEnv
env = YawTrackingEnv(render_mode="rgb_array")
obs, info = env.reset()
for _ in range(1000):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)

Both classical and neural network controllers:
from src.controllers import PIDController, NNController
# Classical PID
pid = PIDController(kp=1.0, kd=0.1)
# Neural network (trained)
nn = NNController(model_path="runs/<run_name>/best_model")
action = nn.compute_action(observation)

Connect the MuJoCo simulation to ArduPilot:
# Terminal 1: Start ArduPilot SITL
sim_vehicle.py -v ArduCopter -f JSON --console
# Terminal 2: Run MuJoCo bridge
python scripts/run_ardupilot_sim.py

Train a neural network to keep the drone facing a moving target:
# Basic training
python scripts/train_yaw_tracker.py
# With custom settings
python scripts/train_yaw_tracker.py \
    --config config/yaw_tracking.yaml \
    --timesteps 1000000 \
    --n-envs 8

Edit config/yaw_tracking.yaml:
environment:
  hover_height: 1.0
  max_episode_steps: 1000
  target_patterns: ["circular", "random"]
  target_speed_min: 0.1
  target_speed_max: 0.3
  max_yaw_rate: 1.0
  action_dead_zone: 0.08

training:
  algorithm: PPO
  learning_rate: 0.0003
  n_steps: 2048
  batch_size: 64
  policy_kwargs:
    net_arch: [64, 64]

Use the unified CLI for interactive or video visualization:
# Interactive viewer
python scripts/visualize_mujoco.py --mode interactive --model runs/<run_name>/best_model
# Record video
python scripts/visualize_mujoco.py \
    --mode video \
    --model runs/<run_name>/best_model \
    --patterns circular sinusoidal \
    --target-speed-min 0.05 \
    --target-speed-max 0.1 \
    --output runs/visualizations/demo.mp4
# Debug raw observations (skip VecNormalize)
python scripts/visualize_mujoco.py \
    --mode interactive \
    --model runs/<run_name>/best_model \
    --no-normalize

Run the full visualization with all effects:
# Run with trained model
python scripts/run_mega_viz.py --model runs/<run_name>/best_model
# With perturbations and video recording
python scripts/run_mega_viz.py \
    --model runs/<run_name>/best_model \
    --perturbations config/perturbations.yaml \
    --record output.mp4

Features:
- 3D Scene Objects: Wind arrows, force vectors, trajectory trails, VRS danger zones
- Neural Network Visualizer: Real-time activation display
- Telemetry HUD: Roll/pitch/yaw graphs, motor indicators, attitude display
- Perturbation Effects: Visual wind, ground effects, sensor delays
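Activation displays like the visualizer above are typically built by recording each layer's output during the forward pass. A minimal, framework-free sketch of the idea (the repo's nn_visualizer.py likely uses PyTorch hooks instead; all names here are illustrative):

```python
import math

def forward_with_activations(obs, layers):
    """Run a tiny MLP forward pass, recording each layer's tanh output.

    `layers` is a list of (weights, biases) pairs, with weights as
    row-major lists of lists. Purely illustrative, not the repo's API.
    """
    activations = []
    x = obs
    for weights, biases in layers:
        x = [math.tanh(sum(w * v for w, v in zip(row, x)) + b)
             for row, b in zip(weights, biases)]
        activations.append(x)  # snapshot for the visualizer
    return x, activations

# Two layers mapping 2 inputs -> 2 hidden units -> 1 output:
layers = [
    ([[0.5, -0.2], [0.1, 0.3]], [0.0, 0.0]),
    ([[1.0, 1.0]], [0.0]),
]
out, acts = forward_with_activations([1.0, 2.0], layers)
# acts[0] holds the hidden-layer activations, acts[1] the output layer
```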
Analyze trained models without running simulation:
# Show architecture with ASCII diagram
python scripts/model_inspector.py arch runs/model.zip --diagram
# Visualize weights as heatmap
python scripts/model_inspector.py weights runs/model.zip --heatmap
# Analyze activations over episodes
python scripts/model_inspector.py activations runs/model.zip --episodes 5
# Export statistics to JSON
python scripts/model_inspector.py stats runs/model.zip --output stats.json

Add realistic disturbances for robust training:
# Train with perturbations
python scripts/train_yaw_tracker.py --perturbations config/perturbations.yaml

Available perturbations:
- Wind: Steady wind, gusts, turbulence (Dryden model)
- Delays: Sensor latency, actuator delays, jitter
- Sensor Noise: Gaussian noise, drift, outliers, GPS loss
- Physics: CoM offset, motor variations, ground effect
- Aerodynamics: Air drag, blade flapping, VRS
- External Forces: Impulses, vibrations, EMI
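A config/perturbations.yaml enabling a few of these might look roughly like the following (hypothetical keys and values; check the repo's actual schema before use):

```yaml
wind:
  enabled: true
  steady: [1.0, 0.0, 0.0]   # m/s, world frame
  turbulence: dryden
delays:
  sensor_latency_s: 0.02
  actuator_delay_s: 0.01
sensor_noise:
  gyro_std: 0.005           # rad/s
```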
Deploy trained model to ArduPilot SITL:
python scripts/run_yaw_tracker_sitl.py \
    --model runs/<run_name>/best_model \
    --connection udp:127.0.0.1:14550

Note: For correct inference, vec_normalize.pkl must exist in the run directory.
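The role of vec_normalize.pkl is to rescale raw observations with the running statistics saved during training; without it the policy sees inputs on the wrong scale. Conceptually, VecNormalize standardizes and clips each observation (a sketch with illustrative numbers, not the repo's code):

```python
import math

def normalize_obs(obs, mean, var, clip_obs=10.0, eps=1e-8):
    """VecNormalize-style observation scaling: standardize, then clip."""
    return [max(-clip_obs, min(clip_obs, (o - m) / math.sqrt(v + eps)))
            for o, m, v in zip(obs, mean, var)]

scaled = normalize_obs([0.5, -2.0], mean=[0.0, -1.0], var=[0.25, 1.0])
# With Stable-Baselines3, the saved statistics are restored via
# VecNormalize.load("runs/<run_name>/vec_normalize.pkl", venv),
# with venv.training = False during inference.
```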
- Python 3.10+
- Inference-only: Stable-Baselines3 + PyTorch
- Full simulation/training: MuJoCo 3.0+
- ArduPilot SITL (optional, for deployment)
See docs/SITL_INTEGRATION.md for full setup instructions on macOS and Linux.
- Getting Started - Setup and first steps
- Architecture - Code structure for developers
- Training Guide - How to train models
- Using Trained Models - Load and use trained models
- Webots Human Tracking - Track pedestrians in Webots
- Webots Quickstart - Fast setup for Webots tracking
- SITL Integration - ArduPilot connection
- Fork the repository
- Create a feature branch
- Make changes with tests
- Submit a pull request
MIT License - see LICENSE file for details.
- MuJoCo - Physics simulation
- Gymnasium - RL interface
- Stable-Baselines3 - RL algorithms
- ArduPilot - Flight controller