A complete Python 3 + PyTorch reproduction of Sat2Graph: Road Graph Extraction through Graph-Tensor Encoding (ECCV 2020), with integrated TOPO/APLS evaluation metrics.
This project is a faithful Python 3 + PyTorch refactoring of the original Sat2Graph codebase (Python 2.7 + TensorFlow 1.x). We deeply appreciate the original authors and maintainers:
- He et al. (ECCV 2020) for the core Sat2Graph method (arxiv.org/pdf/2007.09547.pdf)
- Original repository: songtaohe/Sat2Graph
| Method | TOPO-F1 | APLS | Notes |
|---|---|---|---|
| Original Sat2Graph (Paper) | 76.26% | 63.14% | 20-city US model from ECCV 2020 |
| Our PyTorch Reproduction | 77.04% | 65.05% | PyTorch reimplementation (fp32, 300k steps) |
| Improvement | +0.78pp | +1.91pp | ✅ Successful reproduction |
Precision: 84.59% (mean across 27 tiles)
Overall-Recall: 76.40%
F1 Score: 80.35%
Mean APLS: 65.05%
Median APLS: 69.29%
Min APLS: 31.08%
Max APLS: 80.72%
Std Dev: 12.92%
- Training time: ~72 hours on RTX 3060 (8GB VRAM)
- Convergence: 300k steps, learning rate 0.001 with 0.5 decay every 50k steps
- Precision: trained in fp32, which is stable (fp16 mixed precision requires careful tuning)
- Inference speed: ~2-3 seconds per 2048×2048 tile (sliding window 352×352)
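The schedule and tiling figures above can be sketched in a few lines. This is a minimal illustration, not the code in `train.py` or `infer.py`: it assumes the decay is applied as a staircase (one halving per 50k steps), and the `stride` of 176 pixels is a hypothetical value chosen only for the example.

```python
def learning_rate(step, base_lr=0.001, decay=0.5, decay_every=50_000):
    """Staircase decay (assumed form): halve base_lr once per 50k steps."""
    return base_lr * decay ** (step // decay_every)


def window_origins(tile_size=2048, window=352, stride=176):
    """Top-left offsets of sliding windows along one axis of a tile.

    `stride` is a hypothetical value. The last window is shifted back
    so that it ends exactly at the tile border.
    """
    origins = list(range(0, tile_size - window + 1, stride))
    if origins[-1] != tile_size - window:
        origins.append(tile_size - window)
    return origins


if __name__ == "__main__":
    for step in (0, 50_000, 150_000, 299_999):
        print(step, learning_rate(step))
    xs = window_origins()
    print(len(xs) ** 2, "windows per 2048x2048 tile")
```

With a smaller stride the windows overlap more, which usually smooths predictions at window borders at the cost of more forward passes per tile.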
Sat2GraphModel (55M parameters)
├── DLA Encoder with ResNet blocks
├── Multi-scale decoder with 26-channel output:
│ ├── Channels 0-1: Keypoint probability (vertex)
│ ├── Channels 2-25: 6 directions × 4 channels
│ │ ├── 1 channel: Direction probability (edge exists)
│ │ ├── 2 channels: Direction vector (dx, dy)
│ │ └── 1 channel: Reserved
│ └── Channels 26-27: Segmentation mask (optional, appended after the 26 encoding channels)
└── Graph decoding pipeline (3-pass snapping + R-tree refinement)
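The channel arithmetic in the layout above can be checked with a small helper. This is a sketch under assumptions: the order of channels inside each 4-channel direction group (probability, dx, dy, reserved) is assumed for illustration, and `direction_channels` is a hypothetical name, not a function from `model.py` or `decoder.py`.

```python
MAX_DEGREE = 6          # directions per pixel, per the layout above
KEYPOINT_CHANNELS = 2   # channels 0-1
PER_DIRECTION = 4       # probability + (dx, dy) + reserved

# Keypoint channels plus all direction groups.
ENCODING_CHANNELS = KEYPOINT_CHANNELS + PER_DIRECTION * MAX_DEGREE


def direction_channels(k):
    """Return (prob, dx, dy, reserved) channel indices for direction k.

    The within-group ordering is an assumption made for this example.
    """
    if not 0 <= k < MAX_DEGREE:
        raise ValueError(f"direction index out of range: {k}")
    base = KEYPOINT_CHANNELS + PER_DIRECTION * k
    return base, base + 1, base + 2, base + 3


if __name__ == "__main__":
    for k in range(MAX_DEGREE):
        print(k, direction_channels(k))
    print("encoding channels:", ENCODING_CHANNELS)
```

The keypoint channels and the six direction groups together fill exactly 26 channels, so the last direction group ends at channel 25.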
| Aspect | TensorFlow | PyTorch |
|---|---|---|
| Padding | `tf.pad` + `VALID` | `F.pad` + `padding=0` |
| BatchNorm | `decay=0.99` | `momentum=0.01` (inverted) |
| Data format | NHWC | NCHW (with conversion helpers) |
| Training loop | `tf.Session` | Eager execution |
| Checkpoint | `tf.train.Saver` | `torch.save` / `torch.load` |
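The BatchNorm row is the easiest one to get wrong, because the two frameworks weight the running statistics in opposite ways. The pure-Python sketch below writes out both update rules as documented by each framework and shows that TensorFlow's `decay=0.99` matches PyTorch's `momentum=0.01`. The function names and sample values are illustrative only.

```python
def tf_moving_average(moving, batch_value, decay=0.99):
    """TensorFlow convention: `decay` weights the OLD statistic."""
    return decay * moving + (1.0 - decay) * batch_value


def torch_running_average(running, batch_value, momentum=0.01):
    """PyTorch convention: `momentum` weights the NEW batch statistic."""
    return (1.0 - momentum) * running + momentum * batch_value


if __name__ == "__main__":
    tf_stat = torch_stat = 0.0
    for batch_mean in (1.0, 0.5, 2.0, 1.5):   # hypothetical batch means
        tf_stat = tf_moving_average(tf_stat, batch_mean)
        torch_stat = torch_running_average(torch_stat, batch_mean)
    print(tf_stat, torch_stat)   # equal up to floating-point rounding
```

Copying `0.99` directly into `momentum` would make the running statistics track each batch almost entirely, which tends to show up as unstable evaluation results.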
# Clone repository
git clone https://github.com/Oops-maker/sat2graph-py3-pytorch.git
cd sat2graph-py3-pytorch
# Install dependencies (assuming conda PyTorch environment)
pip install torch torchvision opencv-python pillow numpy scipy rtree scikit-image wandb
# For metrics evaluation, install Go (optional)
# APLS metric requires: go >= 1.13
# TOPO metric: pure Python (hopcroftkarp included)

cd model
# Train on 20-city dataset (144 train tiles)
python train.py \
-model_save checkpoints \
-image_size 352 \
-batch_size 2 \
-lr 0.001 \
-channel 12 \
-resnet_step 8 \
-max_steps 300000 \
-mode train
# Expected: 72h on RTX 3060 (8GB)

# Test on 27 test tiles with TOPO/APLS metrics
python train.py \
-model_save checkpoints \
-image_size 352 \
-model_recover checkpoints/model_final.pth \
-mode test \
-eval_metrics
# Outputs:
# - model/outputs/region_*_output_graph.p (predicted graphs)
# - model/outputs/region_*_topo.txt (TOPO metrics)
# - model/outputs/region_*_apls.txt (APLS metrics)

python infer.py path/to/image.png output_prefix
# Generates: output_prefix_graph.p, output_prefix_graph.json, output_prefix.png

sat2graph-py3-pytorch/
├── README.md # Original paper README
├── README_REPRODUCTION.md # This file
├── model/
│ ├── model.py # Sat2GraphModel (PyTorch nn.Module)
│ ├── train.py # Training & evaluation main script
│ ├── infer.py # Single image inference
│ ├── decoder.py # Graph decoding (tensor→graph)
│ ├── dataloader.py # Dataset loading (20-city dataset)
│ ├── resnet.py # ResNet blocks
│ ├── tf_common_layer.py # Conv layers, BatchNorm utilities
│ ├── common.py # Common utilities
│ ├── douglasPeucker.py # Graph simplification
│ ├── localserver.py # HTTP inference server
│ ├── wandb/ # Weights & Biases logs
│ └── checkpoints/ # Model checkpoints (not in repo)
├── metrics/
│ ├── topo/ # TOPO metric (Python)
│ │ ├── topo.py # Topology matching
│ │ └── main.py # Entry point
│ └── apls/ # APLS metric (Go)
│ ├── main.go # Average Path Length Similarity
│ ├── go.mod # Go dependencies
│ └── convert.py # pickle→JSON converter
├── prepare_dataset/
│ ├── download.py # Dataset downloader
│ ├── config/ # City configurations
│ └── mapbox.py # MapBox imagery interface
└── docker/ # Inference server container
# From successful 300k training run
batch_size = 2
image_size = 352
channel_width = 12
resnet_steps = 8
learning_rate = 0.001
lr_decay = 0.5 every 50k steps
mixed_precision = fp32 (stable)
max_steps = 300000
validation_interval = 200 steps
checkpoint_interval = 10000 steps

loss_total = (
1.0 * keypoint_prob_loss +
10.0 * direction_prob_loss +
1000.0 * direction_vector_loss +
0.1 * segmentation_loss
)

- `scipy.ndimage.imread` → `PIL.Image.open` / `cv2.imread`
- `scipy.misc.imresize` → `cv2.resize` / `PIL.Image.Image.resize`
- `dict.iteritems()` → `dict.items()`
- `xrange()` → `range()`
- Integer division `/` → `//`
- `print` statements → `print()` functions
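The integer-division change is the one most likely to alter results silently, since the Python 2 code keeps running under Python 3 but produces floats. A minimal illustration, with values and variable names chosen only for the example:

```python
tile_size, window = 2048, 352   # illustrative values

# Python 2 evaluated 2048 / 352 to the integer 5; Python 3 returns a float.
true_div = tile_size / window     # float, about 5.82
floor_div = tile_size // window   # int, matches the old Python 2 result

# dict.iteritems() -> dict.items(), xrange() -> range()
degree_by_node = {"a": 3, "b": 1}   # hypothetical data
pairs = sorted(degree_by_node.items())
indices = list(range(floor_div))

print(true_div, floor_div, pairs, indices)
```

Any value that is later used as an array index or a shape needs `//`, because a float index raises a `TypeError` in NumPy and PyTorch.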
- Satellite images: normalized `(pixel/255 - 0.5) * 0.9`
- Graph pickle format: `{node_id: {neighbor_id: {"nid": int, "lat": float, "lon": float}}}`
- Tensor encoding: 26 channels (keypoint, directions, segmentation)
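To make the two formats concrete, the sketch below applies the normalization formula to single pixel values and walks a tiny graph in the nested-dict layout. The helper names and the example graph (including its coordinates) are hypothetical, and the meaning assigned to the `nid`, `lat`, and `lon` fields is an assumption.

```python
def normalize_pixel(value):
    """Map an 8-bit pixel value to the range [-0.45, 0.45]."""
    return (value / 255 - 0.5) * 0.9


def edges(graph):
    """Yield (node_id, neighbor_id) pairs from the nested-dict format."""
    for node_id, neighbors in graph.items():
        for neighbor_id in neighbors:
            yield node_id, neighbor_id


# Hypothetical two-node graph; field meanings are assumed for illustration.
example_graph = {
    0: {1: {"nid": 1, "lat": 42.3601, "lon": -71.0589}},
    1: {0: {"nid": 0, "lat": 42.3605, "lon": -71.0592}},
}

if __name__ == "__main__":
    print(normalize_pixel(0), normalize_pixel(255))
    print(sorted(edges(example_graph)))
```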
- TOPO measures the correctness of road connectivity
- Accounts for precision and recall of intersections
- Formula: `F1 = 2 * Precision * Recall / (Precision + Recall)`
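The F1 formula translates directly into code. The helper below is a minimal sketch with a hypothetical name (`f1_score`) and is not taken from `metrics/topo`. It guards against the zero-denominator case, which can occur on a tile with no matched points.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


if __name__ == "__main__":
    # Precision and recall taken from the results reported above.
    print(round(f1_score(0.8459, 0.7640), 4))
```

Applied to the mean precision and recall this gives roughly 0.803. An F1 that is computed per tile and then averaged need not match this value exactly, because the harmonic mean is not linear.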
- APLS measures the fidelity of predicted paths vs ground truth
- Accounts for both connectivity and geometry
- Range: 0-1 (higher is better)
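For intuition, here is a sketch of the commonly cited APLS formulation from the SpaceNet challenge: each sampled node pair contributes a penalty equal to the relative difference in shortest-path length, capped at 1, and a missing path counts as a full penalty. The Go implementation in `metrics/apls` may sample node pairs and handle edge cases differently, so treat this as an illustration of the idea rather than a reference. The function name and inputs are hypothetical.

```python
def apls(lengths_gt, lengths_prop):
    """APLS over paired shortest-path lengths.

    lengths_gt:   path lengths in the ground-truth graph (positive numbers)
    lengths_prop: lengths of the matching paths in the proposal graph,
                  with None where no path exists
    """
    if len(lengths_gt) != len(lengths_prop):
        raise ValueError("both inputs must describe the same node pairs")
    penalties = []
    for l_gt, l_prop in zip(lengths_gt, lengths_prop):
        if l_gt <= 0:
            continue  # skip degenerate pairs
        if l_prop is None:
            penalties.append(1.0)  # missing path: maximum penalty
        else:
            penalties.append(min(1.0, abs(l_gt - l_prop) / l_gt))
    if not penalties:
        raise ValueError("no valid path pairs to score")
    return 1.0 - sum(penalties) / len(penalties)


if __name__ == "__main__":
    # One path 10% longer than ground truth, one path missing entirely.
    print(apls([100.0, 100.0], [110.0, None]))
```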
# Tiles indexed: 0-179 (180 total)
# Train: 80% (indices % 10 not in {1, 2})
# Val: 10% (indices % 10 == 1)
# Test: 10% (indices % 10 == 2) ← 27 test tiles

# TOPO (automatic in train.py)
cd metrics/topo
python main.py -graph_gt gt.p -graph_prop prop.p -output result.txt
# APLS (requires Go)
cd metrics/apls
python convert.py gt.p gt.json
python convert.py prop.p prop.json
go run main.go gt.json prop.json result.txt

- Overpasses: Model struggles with stacked roads (lacks 3D information)
- Sparse annotations: Ground truth from OpenStreetMap may be incomplete
- Single-scale inference: Optimized for 2048×2048 tiles at 1m GSD
- Training convergence: Requires ~72h on consumer GPU (no distributed training)
If you use this PyTorch reproduction, please cite the original Sat2Graph:
@inproceedings{he2020sat2graph,
title={Sat2Graph: Road Graph Extraction through Graph-Tensor Encoding},
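  author={He, Songtao and Bastani, Favyen and Jagwani, Satvat and Alizadeh, Mohammad and Balakrishnan, Hari and Chawla, Sanjay and Elshrif, Mohamed M. and Madden, Samuel and Sadeghi, Mohammad Amin},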
booktitle={ECCV},
year={2020}
}

Last Updated: March 26, 2026
License: See LICENSE (non-commercial academic use)
Status: ✅ Reproduction verified