
Sat2Graph Python 3 + PyTorch Reproduction

A complete Python 3 + PyTorch reproduction of Sat2Graph: Road Graph Extraction through Graph-Tensor Encoding (ECCV 2020), with integrated TOPO/APLS evaluation metrics.

🙏 Acknowledgments

This project is a faithful Python 3 + PyTorch refactoring of the original Sat2Graph codebase (Python 2.7 + TensorFlow 1.x). We deeply appreciate the original authors and maintainers:


📊 Reproduction Results

Baseline Comparison (City-Scale 20 Cities Dataset)

| Method | TOPO-F1 | APLS | Notes |
|---|---|---|---|
| Original Sat2Graph (paper) | 76.26% | 63.14% | 20-city US model from ECCV 2020 |
| Our PyTorch reproduction | 77.04% | 65.05% | PyTorch reimplementation (fp32, 300k steps) |
| Improvement | +0.78 pp | +1.91 pp | ✅ Successful reproduction |

Detailed Metrics (27 Test Tiles)

TOPO Metrics

Precision:     84.59%  (mean across 27 tiles)
Overall-Recall: 76.40%
F1 Score:      80.35%

APLS Metrics

Mean APLS:     65.05%
Median APLS:   69.29%
Min APLS:      31.08%
Max APLS:      80.72%
Std Dev:       12.92%

Performance Characteristics

  • Training time: ~72 hours on RTX 3060 (8GB VRAM)
  • Convergence: 300k steps; learning rate 0.001, multiplied by 0.5 every 50k steps
  • Precision: fp32 training is stable (fp16 mixed precision requires careful tuning)
  • Inference speed: ~2-3 seconds per 2048×2048 tile (sliding window 352×352)
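Running a 352×352 model over a 2048×2048 tile means tiling the tile into overlapping crops and stitching the outputs. Below is a minimal sketch of that idea, not the code in infer.py. It assumes a hypothetical `predict` callable that maps an (h, w, 3) crop to an (h, w, C) array, a stride of 176 px (50% overlap), and plain averaging where crops overlap. The repository may use a different stride, crop window borders, or weight overlaps differently.

```python
import numpy as np


def window_origins(size, window=352, stride=176):
    """Top-left offsets along one axis so that `window`-pixel crops cover [0, size)."""
    if size <= window:
        return [0]
    origins = list(range(0, size - window + 1, stride))
    if origins[-1] != size - window:
        origins.append(size - window)  # shift the last crop flush with the border
    return origins


def sliding_window_predict(image, predict, window=352, stride=176):
    """Run `predict` (hypothetical) on overlapping crops of an (H, W, C) image
    and average the predictions wherever crops overlap."""
    h, w = image.shape[:2]
    out = None
    count = np.zeros((h, w, 1))
    for y in window_origins(h, window, stride):
        for x in window_origins(w, window, stride):
            pred = predict(image[y:y + window, x:x + window])
            if out is None:
                out = np.zeros((h, w, pred.shape[-1]))
            out[y:y + window, x:x + window] += pred
            count[y:y + window, x:x + window] += 1
    return out / count
```

For a 2048-pixel axis this yields 11 crop origins (0, 176, ..., 1584, 1696), so each tile is covered by 121 model calls.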

🏗️ Architecture

Model Components (PyTorch nn.Module)

Sat2GraphModel (55M parameters)
├── DLA Encoder with ResNet blocks
├── Multi-scale decoder with 26-channel output:
│   ├── Channels 0-1: Keypoint probability (vertex)
│   ├── Channels 2-25: 6 directions × 4 channels
│   │   ├── 1 channel: Direction probability (edge exists)
│   │   ├── 2 channels: Direction vector (dx, dy)
│   │   └── 1 channel: Reserved
│   └── (Optional) Segmentation mask: separate output, not counted in the 26 channels
└── Graph decoding pipeline (3-pass snapping + R-tree refinement)
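The channel arithmetic above (2 keypoint channels + 6 directions × 4 channels = 26) can be made explicit with a small helper. This is a hypothetical sketch: the function names and the channel-last `(H, W, 26)` layout (e.g. after moving a PyTorch NCHW output to HWC) are assumptions, not the repository's API. It only computes which channel slice belongs to the keypoint head and to each direction group. It does not assume the order of channels inside a group.

```python
NUM_DIRECTIONS = 6          # direction slots per vertex
CHANNELS_PER_DIRECTION = 4
KEYPOINT_CHANNELS = 2
NUM_CHANNELS = KEYPOINT_CHANNELS + NUM_DIRECTIONS * CHANNELS_PER_DIRECTION  # 26


def direction_slice(j):
    """Channel slice holding the 4 channels of direction group j (0 <= j < 6)."""
    if not 0 <= j < NUM_DIRECTIONS:
        raise ValueError(f"direction index must be in [0, {NUM_DIRECTIONS}), got {j}")
    start = KEYPOINT_CHANNELS + CHANNELS_PER_DIRECTION * j
    return slice(start, start + CHANNELS_PER_DIRECTION)


def split_graph_tensor(tensor):
    """Split a channel-last (H, W, 26) graph tensor into keypoint and per-direction parts."""
    if tensor.shape[-1] != NUM_CHANNELS:
        raise ValueError(f"expected {NUM_CHANNELS} channels, got {tensor.shape[-1]}")
    keypoint = tensor[..., :KEYPOINT_CHANNELS]
    directions = [tensor[..., direction_slice(j)] for j in range(NUM_DIRECTIONS)]
    return keypoint, directions


if __name__ == "__main__":
    import numpy as np

    keypoint, directions = split_graph_tensor(np.zeros((352, 352, NUM_CHANNELS)))
    print(keypoint.shape, directions[0].shape)  # (352, 352, 2) (352, 352, 4)
```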

Key Differences from TensorFlow Original

| Aspect | TensorFlow | PyTorch |
|---|---|---|
| Padding | tf.pad + VALID | F.pad + padding=0 |
| BatchNorm | decay=0.99 | momentum=0.01 (momentum = 1 - decay) |
| Data format | NHWC | NCHW (with conversion helpers) |
| Training loop | tf.Session | Eager execution |
| Checkpoint | tf.train.Saver | torch.save/load |
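The BatchNorm and data-format rows reduce to two small conversions. TensorFlow's `decay` weights the old running statistic, while PyTorch's `momentum` (as in `torch.nn.BatchNorm2d(num_features, momentum=0.01)`) weights the new batch statistic, so one is 1 minus the other. A sketch using NumPy arrays; the helper names are hypothetical and are not the repository's conversion helpers. In PyTorch itself the layout change is `tensor.permute(0, 3, 1, 2)`.

```python
import numpy as np


def tf_decay_to_torch_momentum(decay):
    """Convert a TensorFlow moving-average decay to a PyTorch BatchNorm momentum."""
    return 1.0 - decay


def tf_running_update(running, batch, decay):
    # TensorFlow: running = decay * running + (1 - decay) * batch
    return decay * running + (1.0 - decay) * batch


def torch_running_update(running, batch, momentum):
    # PyTorch: running = (1 - momentum) * running + momentum * batch
    return (1.0 - momentum) * running + momentum * batch


def nhwc_to_nchw(x):
    """(batch, height, width, channels) -> (batch, channels, height, width)."""
    return np.transpose(x, (0, 3, 1, 2))


def nchw_to_nhwc(x):
    """(batch, channels, height, width) -> (batch, height, width, channels)."""
    return np.transpose(x, (0, 2, 3, 1))
```

With `momentum = tf_decay_to_torch_momentum(0.99)`, both running-update rules produce the same value for any `running` and `batch`.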

🚀 Quick Start

Installation

# Clone repository
git clone https://github.com/Oops-maker/sat2graph-py3-pytorch.git
cd sat2graph-py3-pytorch

# Install dependencies (assuming conda PyTorch environment)
pip install torch torchvision opencv-python pillow numpy scipy rtree scikit-image wandb

# For metrics evaluation (optional): the APLS metric requires Go >= 1.13
# TOPO metric: pure Python (hopcroftkarp included)

Training

cd model

# Train on 20-city dataset (144 train tiles)
python train.py \
  -model_save checkpoints \
  -image_size 352 \
  -batch_size 2 \
  -lr 0.001 \
  -channel 12 \
  -resnet_step 8 \
  -max_steps 300000 \
  -mode train

# Expected: 72h on RTX 3060 (8GB)

Testing & Evaluation

# Test on 27 test tiles with TOPO/APLS metrics
python train.py \
  -model_save checkpoints \
  -image_size 352 \
  -model_recover checkpoints/model_final.pth \
  -mode test \
  -eval_metrics

# Outputs:
# - model/outputs/region_*_output_graph.p (predicted graphs)
# - model/outputs/region_*_topo.txt (TOPO metrics)
# - model/outputs/region_*_apls.txt (APLS metrics)

Single Image Inference

python infer.py path/to/image.png output_prefix
# Generates: output_prefix_graph.p, output_prefix_graph.json, output_prefix.png

📂 Repository Structure

sat2graph-py3-pytorch/
├── README.md                          # Original paper README
├── README_REPRODUCTION.md             # This file
├── model/
│   ├── model.py                       # Sat2GraphModel (PyTorch nn.Module)
│   ├── train.py                       # Training & evaluation main script
│   ├── infer.py                       # Single image inference
│   ├── decoder.py                     # Graph decoding (tensor→graph)
│   ├── dataloader.py                  # Dataset loading (20-city dataset)
│   ├── resnet.py                      # ResNet blocks
│   ├── tf_common_layer.py             # Conv layers, BatchNorm utilities
│   ├── common.py                      # Common utilities
│   ├── douglasPeucker.py              # Graph simplification
│   ├── localserver.py                 # HTTP inference server
│   ├── wandb/                         # Weights & Biases logs
│   └── checkpoints/                   # Model checkpoints (not in repo)
├── metrics/
│   ├── topo/                          # TOPO metric (Python)
│   │   ├── topo.py                    # Topology matching
│   │   └── main.py                    # Entry point
│   └── apls/                          # APLS metric (Go)
│       ├── main.go                    # Average Path Length Similarity
│       ├── go.mod                     # Go dependencies
│       └── convert.py                 # pickle→JSON converter
├── prepare_dataset/
│   ├── download.py                    # Dataset downloader
│   ├── config/                        # City configurations
│   └── mapbox.py                      # MapBox imagery interface
└── docker/                            # Inference server container

📈 Training Progress & Hyperparameters

Baseline Run Configuration

# From the successful 300k-step training run
batch_size = 2
image_size = 352
channel_width = 12
resnet_steps = 8
learning_rate = 0.001
lr_decay_factor = 0.5          # LR is multiplied by this factor...
lr_decay_interval = 50000      # ...every 50k steps
precision = "fp32"             # stable; fp16 requires careful tuning
max_steps = 300000
validation_interval = 200      # steps
checkpoint_interval = 10000    # steps
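The learning-rate schedule in this configuration is a plain step decay. Below is a minimal sketch of that schedule as a pure function of the global step; the function name is hypothetical. In PyTorch the same schedule can be obtained with `torch.optim.lr_scheduler.StepLR(optimizer, step_size=50000, gamma=0.5)`, provided `scheduler.step()` is called once per training step rather than once per epoch.

```python
def learning_rate_at(step, base_lr=0.001, decay_factor=0.5, decay_interval=50000):
    """Step decay: multiply base_lr by decay_factor once per completed interval."""
    return base_lr * decay_factor ** (step // decay_interval)


if __name__ == "__main__":
    for step in (0, 49_999, 50_000, 150_000, 299_999):
        print(step, learning_rate_at(step))
```

Over 300k steps the rate is halved five times, ending at 0.001 × 0.5^5 = 3.125e-05.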

Loss Weights (Supervised Training)

loss_total = (
    1.0 * keypoint_prob_loss +
    10.0 * direction_prob_loss +
    1000.0 * direction_vector_loss +
    0.1 * segmentation_loss
)

🔧 Key Implementation Details

Python 3 Migrations

  • scipy.ndimage.imread → PIL.Image.open / cv2.imread
  • scipy.misc.imresize → cv2.resize / PIL Image.resize()
  • dict.iteritems() → dict.items()
  • xrange() → range()
  • Integer division: / → //
  • print statements → print() functions

Data Format

  • Satellite images: normalized (pixel/255 - 0.5) * 0.9
  • Graph pickle format: {node_id: {neighbor_id: {"nid": int, "lat": float, "lon": float}}}
  • Tensor encoding: 26 channels (2 keypoint + 6 directions × 4), with segmentation as an optional extra output
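The input normalization above maps 8-bit pixels into [-0.45, 0.45]. Below is a minimal NumPy sketch of it, plus an inverse for visualizing network inputs. The function names are hypothetical; only the formula comes from this README.

```python
import numpy as np


def normalize_image(image_uint8):
    """Map 8-bit pixels to [-0.45, 0.45] via (pixel / 255 - 0.5) * 0.9."""
    return (image_uint8.astype(np.float32) / 255.0 - 0.5) * 0.9


def denormalize_image(image_float):
    """Invert normalize_image back to uint8 (e.g. for visualizing inputs)."""
    pixels = (image_float / 0.9 + 0.5) * 255.0
    return np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
```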

Evaluation Metrics

TOPO (Topology Matching)

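  • Measures local topological correctness of the predicted road network
  • Samples points along both graphs around seed locations, pairs them by bipartite matching, and reports precision and recall of the matched points
  • Formula: F1 = 2 * Precision * Recall / (Precision + Recall)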
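The matching step behind TOPO's precision and recall can be sketched as follows. This is a simplified illustration, not the metrics/topo implementation. It skips the seed-based exploration and takes already-sampled point lists. It uses Kuhn's augmenting-path algorithm instead of the bundled Hopcroft-Karp: both find a maximum bipartite matching of the same size, but Kuhn's is slower. The `radius` value and function names are hypothetical.

```python
import math


def match_points(gt, prop, radius):
    """Size of a maximum bipartite matching between ground-truth and proposal
    points, where a pair may match only if they are within `radius`."""
    adj = [[j for j, q in enumerate(prop) if math.dist(p, q) <= radius] for p in gt]
    match_of_prop = [-1] * len(prop)  # index of the gt point matched to each proposal point

    def try_assign(i, seen):
        # Find an augmenting path starting from gt point i.
        for j in adj[i]:
            if not seen[j]:
                seen[j] = True
                if match_of_prop[j] == -1 or try_assign(match_of_prop[j], seen):
                    match_of_prop[j] = i
                    return True
        return False

    return sum(try_assign(i, [False] * len(prop)) for i in range(len(gt)))


def topo_like_scores(gt, prop, radius=5.0):
    """Precision, recall and F1 of matched sample points."""
    matched = match_points(gt, prop, radius)
    precision = matched / len(prop) if prop else 0.0
    recall = matched / len(gt) if gt else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    return precision, recall, 2 * precision * recall / (precision + recall)
```

For example, with ground-truth samples `[(0, 0), (10, 0), (20, 0), (30, 0)]` and proposal samples `[(1, 0), (11, 0), (50, 0)]`, two pairs match. That gives precision 2/3, recall 1/2 and F1 4/7.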

APLS (Average Path Length Similarity)

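  • Compares shortest-path lengths between corresponding pairs of nodes in the predicted and ground-truth graphs
  • Penalizes both missing connectivity and geometric (path-length) errors
  • Range: 0-1 (higher is better); reported as a percentage above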
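The core APLS score is 1 - mean(min(1, |L_gt(a, b) - L_prop(a, b)| / L_gt(a, b))) over node pairs (a, b). A pair with no path in the proposal receives the maximum penalty of 1. Below is a simplified sketch of that score, not the Go tool in metrics/apls. It assumes the control nodes already share identifiers in both graphs; the full metric must first snap control points between the graphs. It computes only the ground-truth → proposal direction, whereas the metric is commonly symmetrized by also scoring the proposal against the ground truth. All function names are hypothetical.

```python
import heapq
import itertools


def undirected(edges):
    """Build {node: {neighbor: length}} from (u, v, length) tuples."""
    graph = {}
    for u, v, length in edges:
        graph.setdefault(u, {})[v] = length
        graph.setdefault(v, {})[u] = length
    return graph


def shortest_lengths(graph, source):
    """Dijkstra: shortest path length from `source` to every reachable node."""
    dist = {source: 0.0}
    heap = [(0.0, source)]
    while heap:
        d, u = heapq.heappop(heap)
        if d > dist[u]:
            continue  # stale heap entry
        for v, w in graph.get(u, {}).items():
            nd = d + w
            if nd < dist.get(v, float("inf")):
                dist[v] = nd
                heapq.heappush(heap, (nd, v))
    return dist


def apls_one_way(gt, prop, control_nodes):
    """APLS-style score over control-node pairs that are connected in the ground truth."""
    gt_dist = {a: shortest_lengths(gt, a) for a in control_nodes}
    prop_dist = {a: shortest_lengths(prop, a) for a in control_nodes}
    penalties = []
    for a, b in itertools.combinations(control_nodes, 2):
        l_gt = gt_dist[a].get(b)
        if not l_gt:
            continue  # unreachable or zero-length in ground truth: skip the pair
        l_prop = prop_dist[a].get(b)
        if l_prop is None:
            penalties.append(1.0)  # path missing in the proposal
        else:
            penalties.append(min(1.0, abs(l_gt - l_prop) / l_gt))
    return 1.0 - sum(penalties) / len(penalties) if penalties else 0.0
```

In a ground truth A-B-C with 10 m edges, a proposal whose B-C edge is 14 m long scores 0.8: the pair penalties are 0, 0.2 and 0.4. Dropping B-C entirely scores 1/3.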

🧪 Validation & Testing

Test Split (27 tiles from 20-city dataset)

# Tiles indexed: 0-179 (180 total)
# Train: 144 tiles (80%)
# Val:     9 tiles (5%)
# Test:   27 tiles (15%)  ← used for all reported metrics
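The function below is one index rule that produces exactly this 144/9/27 partition. It is an assumption recalled from the original TensorFlow Sat2Graph training script, not read from this repository, so verify it against dataloader.py / train.py before relying on it. The function name is hypothetical.

```python
def split_indices(num_tiles=180):
    """Assumed 20-city split rule: returns (train, val, test) tile indices."""
    train, val, test = [], [], []
    for i in range(num_tiles):
        if i % 10 < 8:
            train.append(i)
        elif i % 10 == 9 or i % 20 == 8:
            test.append(i)
        else:  # remaining case: i % 20 == 18
            val.append(i)
    return train, val, test


if __name__ == "__main__":
    train, val, test = split_indices()
    print(len(train), len(val), len(test))  # 144 9 27
```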

Metric Computation Pipeline

# TOPO (automatic in train.py)
cd metrics/topo
python main.py -graph_gt gt.p -graph_prop prop.p -output result.txt

# APLS (requires Go)
cd metrics/apls
python convert.py gt.p gt.json
python convert.py prop.p prop.json
go run main.go gt.json prop.json result.txt

🌟 Known Limitations

  1. Overpasses: Model struggles with stacked roads (lacks 3D information)
  2. Sparse annotations: Ground truth from OpenStreetMap may be incomplete
  3. Single-scale inference: Optimized for 2048×2048 tiles at 1m GSD
  4. Training convergence: Requires ~72h on consumer GPU (no distributed training)

📝 Citation

If you use this PyTorch reproduction, please cite the original Sat2Graph:

@inproceedings{he2020sat2graph,
  title={Sat2Graph: Road Graph Extraction through Graph-Tensor Encoding},
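  author={He, Songtao and Bastani, Favyen and Jagwani, Satvat and Alizadeh, Mohammad and Balakrishnan, Hari and Chawla, Sanjay and Elshrif, Mohamed M. and Madden, Samuel and Sadeghi, Mohammad Amin},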
  booktitle={ECCV},
  year={2020}
}

Last Updated: March 26, 2026
License: See LICENSE (non-commercial academic use)
Status: ✅ Reproduction verified