Pixio

A capable vision encoder dedicated to dense prediction, simply by pixel reconstruction

[Paper] [Hugging Face Model Card]


Official implementation of Pixio from the paper In Pursuit of Pixel Supervision for Visual Pre-training.

Lihe Yang, Shang-Wen Li, Yang Li, Xinjie Lei, Dong Wang, Abdelrahman Mohamed, Hengshuang Zhao, Hu Xu

[BibTeX]

Pixio is largely built on MAE, with three minimal yet critical algorithmic updates:

  • a deeper decoder
  • larger masking granularity (sketched below)
  • more class tokens

Pixio also updates MAE's pre-training data from ImageNet-1K to MetaCLIP-2B with a simple self-curation strategy.
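
To make the masking-granularity update concrete, here is a minimal sketch of block-wise masking in PyTorch. This is our reading of "larger masking granularity" (masking contiguous groups of patches rather than individual patches as in MAE), not the official implementation; the block size and mask ratio below are illustrative.

import torch

def blockwise_mask(patches_per_side=16, block=4, mask_ratio=0.75):
    # sample a mask over block x block groups of patches, then expand it
    # to the full patch grid, so masked regions are large contiguous squares
    g = patches_per_side // block                    # blocks per side
    num_blocks = g * g
    num_masked = int(round(num_blocks * mask_ratio))
    perm = torch.randperm(num_blocks)
    block_mask = torch.zeros(num_blocks, dtype=torch.bool)
    block_mask[perm[:num_masked]] = True             # True = masked
    block_mask = block_mask.view(g, g)
    patch_mask = block_mask.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)
    return patch_mask.flatten()                      # (patches_per_side ** 2,)

mask = blockwise_mask()  # ~75% of patches masked, in 4x4-patch blocks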

Performance

Monocular depth estimation ($\delta_1 \uparrow$, frozen encoder)

| Method | ViT | #Params | NYUv2 (DPT head) | KITTI (DPT head) | NYUv2 (linear head) | KITTI (linear head) |
|---|---|---|---|---|---|---|
| MAE | H/14 | 631M | 80.8 | 90.9 | 70.3 | 79.4 |
| DINOv2 | g/14 | 1137M | 90.1 | 94.6 | 75.3 | 78.1 |
| DINOv3 | H+/16 | 841M | 93.2 | 95.6 | 76.3 | 73.2 |
| Pixio | H/16 | 631M | 95.5 | 96.7 | 90.8 | 90.3 |
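
Here, $\delta_1$ is the standard depth-accuracy metric: the fraction of pixels whose predicted depth $\hat{d}$ satisfies $\max(\hat{d}/d, d/\hat{d}) < 1.25$ with respect to the ground-truth depth $d$.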

Feed-forward 3D reconstruction (MapAnything, ScanNet++ v2)

| Method | ViT | #Params | Scale (rel $\downarrow$) | Points (rel $\downarrow$) | Points ($\tau \uparrow$) | Pose (auc5 $\uparrow$) | Depth (rel $\downarrow$) | Depth ($\tau \uparrow$) |
|---|---|---|---|---|---|---|---|---|
| MAE | H/14 | 631M | 0.050 | 0.057 | 63.3 | 65.6 | 0.058 | 55.4 |
| DINOv2 | L/14 | 304M | 0.041 | 0.052 | 67.6 | 73.2 | 0.052 | 60.6 |
| DINOv3 | H+/16 | 841M | 0.035 | 0.051 | 69.0 | 68.5 | 0.051 | 61.2 |
| Pixio | H/16 | 631M | 0.029 | 0.041 | 78.8 | 80.5 | 0.042 | 72.4 |

Semantic segmentation (mIoU $\uparrow$, frozen encoder)

| Method | ViT | #Params | ADE20K (DPT) | VOC (DPT) | LoveDA (DPT) | ADE20K (linear) | VOC (linear) | LoveDA (linear) |
|---|---|---|---|---|---|---|---|---|
| MAE | H/14 | 631M | 37.6 | 76.0 | 50.2 | 35.2 | 70.8 | 47.6 |
| DINOv2 | g/14 | 1137M | 51.5 | 85.2 | 55.0 | 49.0 | 81.8 | 51.9 |
| DINOv3 | H+/16 | 841M | 52.3 | 85.6 | 55.3 | 50.3 | 82.1 | 52.7 |
| Pixio | H/16 | 631M | 53.6 | 85.9 | 54.7 | 50.2 | 82.2 | 53.9 |

Installation

This codebase is developed with PyTorch 2.8.0 + CUDA 12.8.

conda create -n pixio python=3.10.18
conda activate pixio
pip install -r requirements.txt

Inference (may need Hugging Face login)

You can either use the source code from this repo or call the Transformers API.

Source Code

Pixio ViT models pre-trained on a web-scale dataset (MetaCLIP-2B):

| Model | Parameters | Pre-training Dataset | Download |
|---|---|---|---|
| Pixio-B/16 | 86M | MetaCLIP-2B | [link] |
| Pixio-L/16 | 303M | MetaCLIP-2B | [link] |
| Pixio-H/16 | 631M | MetaCLIP-2B | [link] |
| Pixio-1B/16 | 1362M | MetaCLIP-2B | [link] |
| Pixio-5B/16 | 5441M | MetaCLIP-2B | [link] |

cd pixio

Then run inference as follows:

from PIL import Image
from torchvision import transforms

from pixio import pixio_vith16

model = pixio_vith16(pretrained="your/checkpoint/path")

# you can try larger resolution, but ensure both sides are divisible by 16
transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=3), # 3 is bicubic
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

img = Image.open("your/image/path").convert("RGB")
img = transform(img)

# block-wise features containing class tokens and patch tokens
features = model(img.unsqueeze(0))
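
If you feed images at other resolutions, both sides must be divisible by the 16-pixel patch size, as noted in the transform above. Below is a small hypothetical helper (not part of the pixio API) that rescales an image and rounds both sides to multiples of 16:

from PIL import Image

def resize_to_multiple_of_16(img, target_long_side=512):
    # rescale so the longer side is roughly target_long_side,
    # then round both sides to the nearest multiple of 16
    w, h = img.size
    scale = target_long_side / max(w, h)
    new_w = max(16, round(w * scale / 16) * 16)
    new_h = max(16, round(h * scale / 16) * 16)
    return img.resize((new_w, new_h), Image.BICUBIC)

img = resize_to_multiple_of_16(Image.open("your/image/path").convert("RGB"))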

Transformers (may need Hugging Face login)

You can find all Hugging Face model paths in this collection.

from transformers import AutoImageProcessor, AutoModel
from PIL import Image

img = Image.open("your/image/path").convert("RGB")

processor = AutoImageProcessor.from_pretrained("facebook/pixio-vith16")
model = AutoModel.from_pretrained("facebook/pixio-vith16")

inputs = processor(images=img, return_tensors="pt")
outputs = model(**inputs, output_hidden_states=True)
features_norm = outputs.last_hidden_state # 8 class tokens + patch tokens after last LayerNorm
features = outputs.hidden_states[-1] # 8 class tokens + patch tokens before last LayerNorm
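
For dense prediction you typically want the patch tokens as a 2D feature map. Continuing from the snippet above, here is a minimal sketch that assumes the token layout described in the comments (8 class tokens first, then one token per 16x16 patch) and a square input from the processor:

num_cls = 8
tokens = outputs.last_hidden_state         # (B, 8 + N, C)
patch_tokens = tokens[:, num_cls:, :]      # (B, N, C), drop the class tokens
B, N, C = patch_tokens.shape
h = w = int(N ** 0.5)                      # assumes a square patch grid
feature_map = patch_tokens.transpose(1, 2).reshape(B, C, h, w)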

Pre-training

Data Preparation

We provide examples using ImageNet-1K and ImageNet-21K. We use ImageNet datasets organized as tar files from Hugging Face.
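
As a rough illustration (the repository id and local directory below are placeholders, not the paths used in our scripts), tar shards hosted on Hugging Face can be fetched with huggingface_hub:

from huggingface_hub import snapshot_download

# placeholder repo id; substitute the ImageNet tar dataset you actually use
snapshot_download(
    repo_id="your-org/imagenet-1k-tars",
    repo_type="dataset",
    local_dir="data/imagenet-1k",
)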

Launch Pre-training

cd pretraining

# specify your data path in the script
bash scripts/pretrain_pixio_vith16_imagenet.sh

Evaluation

We provide the evaluation code for monocular depth estimation (NYUv2, KITTI), semantic segmentation (ADE20K, Pascal VOC, LoveDA), and k-NN classification (ImageNet-1K).

Data Preparation


Monocular Depth Estimation

We follow ZoeDepth and BTS for data preparation. Please organize the data as follows:

├── [Your NYUv2 Path]
    ├── sync
    │   ├── basement_0001a
    │   ├── bathroom_0001
    │   └── ...    
    └── official_splits
        └── test
            ├── bathroom
            ├── bedroom
            └── ...

├── [Your KITTI Path]
    ├── images
    │   ├── 2011_09_26
    │   ├── 2011_09_28
    │   └── ...    
    └── annotations # extracted from data_depth_annotated.zip
        ├── 2011_09_26_drive_0001_sync
        ├── 2011_09_26_drive_0002_sync
        └── ...

Semantic Segmentation

We mainly follow UniMatch V2 for data preparation. Please organize the data as follows:

├── [Your ADE20K Path]
    ├── images
    │   ├── training
    │   └── validation
    └── annotations
        ├── training
        └── validation

├── [Your Pascal Path]
    ├── JPEGImages
    └── SegmentationClass

├── [Your LoveDA Path]
    ├── Train/Train
    └── Val/Val

k-NN Classification

Follow this script to prepare ImageNet-1K.

Launch Evaluation

cd evaluation

model="pixio_vith16"
pretrained="your/checkpoint/path"

# specify the data path in config files or script
sbatch launch_monodepth.sh monodepth/configs/nyuv2_dpt.yaml $model $pretrained
sbatch launch_semseg.sh semseg/configs/ade20k_linear.yaml $model $pretrained
sbatch launch_knn.sh $model $pretrained

# or run all evaluations together
bash run_all.sh $model $pretrained

Distillation

Launch Distillation

cd distillation

# specify your data path and teacher checkpoint path in the scripts
bash scripts/distill_pixio_vit5b16_to_vit1b16+vith16_imagenet.sh

License

Pixio is licensed under the Facebook license.

Acknowledgement

We sincerely thank the authors of MAE, DINO, DINOv2, and DINOv3 for open-sourcing their code and models.

Citation

@article{pixio,
  title={In Pursuit of Pixel Supervision for Visual Pre-training},
  author={Yang, Lihe and Li, Shang-Wen and Li, Yang and Lei, Xinjie and Wang, Dong and Mohamed, Abdelrahman and Zhao, Hengshuang and Xu, Hu},
  journal={arXiv:2512.15715},
  year={2025}
}
