Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 0 additions & 90 deletions utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,96 +133,6 @@ def get_predicted_noise(
return pred_epsilon


# From LatentConsistencyModel.get_guidance_scale_embedding
def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """
    Sinusoidal embedding of guidance-scale values.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance-scale values to embed.
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`
    """
    assert len(w.shape) == 1
    # Scale guidance values into the timestep-like range the sinusoidal basis expects.
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad the last column for odd dimensions
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    # Add one trailing singleton axis at a time until the rank matches.
    for _ in range(missing):
        x = x[..., None]
    return x


# From LCMScheduler.get_scalings_for_boundary_condition_discrete
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    """Return the (c_skip, c_out) boundary-condition scalings used by consistency models."""
    t = timestep_scaling * timestep
    denom = t**2 + sigma_data**2
    c_skip = sigma_data**2 / denom
    c_out = t / denom**0.5
    return c_skip, c_out


# Compare LCMScheduler.step, Step 4
def get_predicted_original_sample(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    """Recover the predicted clean sample x_0 from the model output under `prediction_type`."""
    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
    if prediction_type == "epsilon":
        return (sample - sigmas * model_output) / alphas
    if prediction_type == "sample":
        return model_output
    if prediction_type == "v_prediction":
        return alphas * sample - sigmas * model_output
    raise ValueError(
        f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
        f" are supported."
    )


# Based on step 4 in DDIMScheduler.step
def get_predicted_noise(
    model_output, timesteps, sample, prediction_type, alphas, sigmas
):
    """Recover the predicted noise epsilon from the model output under `prediction_type`."""
    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
    # Dispatch table keeps the three parameterizations side by side.
    handlers = {
        "epsilon": lambda: model_output,
        "sample": lambda: (sample - alphas * model_output) / sigmas,
        "v_prediction": lambda: alphas * model_output + sigmas * sample,
    }
    if prediction_type not in handlers:
        raise ValueError(
            f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
            f" are supported."
        )
    return handlers[prediction_type]()


def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
extra_params = extra_params if len(extra_params.keys()) > 0 else None
return {
Expand Down
34 changes: 34 additions & 0 deletions wan2_consistency/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Wan2.2 Consistency Distillation

Lightweight project scaffold for training classic consistency-distilled LoRA experts on **Wan2.2** without motion guidance or reward models. The goal is to produce two LoRAs from a single Wan2.2 base checkpoint:
- **High-noise expert**: distills the teacher at the upper timesteps.
- **Low-noise expert**: distills the teacher at the lower timesteps.

## Layout
- `train_lcd_wan22.py`: entrypoint for distillation.
- `config/lcd_wan22.yaml`: editable defaults (paths, noise ranges, LoRA ranks, etc.).
- `data/local_video_dataset.py`: simple local CSV-driven video + conditioning image loader (no S3 deps).
- `modeling/wan22_adapter.py`: helper to load Wan2.2 via `DiffusionPipeline` and to encode prompts/videos.

## Quickstart (skeleton)
```bash
pip install -r requirements.txt # make sure diffusers/accelerate/decord/omegaconf/torch are installed

python wan2_consistency/train_lcd_wan22.py \
--config wan2_consistency/config/lcd_wan22.yaml \
--base_model Wan-2.2/YourModelIdOrPath \
--train_csv /path/to/train.csv \
--output_dir /path/to/output
```

`train.csv` format (minimal):
```
video_path,prompt,image_path
/abs/path/to/video.mp4,"An astronaut riding a horse","/abs/path/to/cond_image.png"
```
`image_path` is optional; if omitted, the first sampled frame from the video is used as the conditioning image.

## Notes
- The current scaffold focuses purely on classic consistency distillation; motion guidance, reward models, and multi-reward mixing are intentionally absent.
- Wan2.2 specifics can vary. If the tokenizer/text encoder or VAE layout differs from a standard Diffusers video pipeline, adjust `Wan22Adapter.encode_prompt`, `encode_video_latents`, and `encode_image_latents` accordingly.
- Two separate LoRA checkpoints are written: `lora_high_noise.pt` and `lora_low_noise.pt`.
1 change: 1 addition & 0 deletions wan2_consistency/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

44 changes: 44 additions & 0 deletions wan2_consistency/config/lcd_wan22.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Base model and IO
base_model: "Wan2.2/your-model-id" # HF repo ID or local path for Wan2.2
revision: null # Optional model revision
output_dir: "output/wan2.2-lcd" # Where LoRA checkpoints will be saved

# Data
train_csv: "data/train.csv"               # CSV with columns: video_path,prompt[,image_path]
sample_fps: 8
sample_frames: 16
sample_size: [320, 512] # H, W

# Distillation / noise
num_train_timesteps: 1000 # Teacher diffusion timesteps (matches Wan2.2 scheduler)
num_ddim_timesteps: 50 # DDIM substeps used by the teacher
distill_topk: 4 # Steps to roll back per LCD target (k)
high_noise_range: [600, 999] # Inclusive range for the high-noise expert
low_noise_range: [0, 399] # Inclusive range for the low-noise expert
w_min: 5.0 # Guidance sampling lower bound
w_max: 15.0 # Guidance sampling upper bound
timestep_scaling_factor: 10.0 # LCD boundary scaling (see paper)
prediction_type: "epsilon" # Teacher prediction type

# LoRA
lora_rank: 32
lora_alpha: 32
lora_dropout: 0.0
unet_replace_modules: ["UNet3DConditionModel"] # Change if Wan2.2 uses a different UNet class name

# Training
train_batch_size: 1
learning_rate: 1.0e-4
adam_beta1: 0.9
adam_beta2: 0.999
adam_weight_decay: 0.0
adam_epsilon: 1.0e-8
max_train_steps: 10000
gradient_accumulation_steps: 1
mixed_precision: "bf16" # ["no", "fp16", "bf16"]
seed: 42

# Logging / checkpoints
checkpoint_every: 1000
save_dtype: "float32" # LoRA save dtype
log_every: 10
1 change: 1 addition & 0 deletions wan2_consistency/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

90 changes: 90 additions & 0 deletions wan2_consistency/data/local_video_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import pandas as pd
import torch
import torchvision
from torch.utils.data import Dataset
import torch.nn.functional as F

from utils.common_utils import read_video_to_tensor


class LocalVideoDataset(Dataset):
    """
    Minimal CSV-driven video dataset for local training.

    CSV columns:
        - video_path: absolute or relative path to an .mp4 (or decord-supported) file
        - prompt: text prompt for the clip
        - image_path: optional path to a conditioning image. If missing, the first sampled frame is used.
    """

    def __init__(
        self,
        csv_path: str,
        sample_fps: int = 8,
        sample_frames: int = 16,
        sample_size=(320, 512),
    ):
        """
        Args:
            csv_path: path to the CSV described in the class docstring.
            sample_fps: target frames-per-second used when sampling the video.
            sample_frames: number of frames returned per clip.
            sample_size: (H, W) output resolution; a bare int means a square size.
        """
        self.df = pd.read_csv(csv_path)
        if "video_path" not in self.df or "prompt" not in self.df:
            raise ValueError("CSV must contain 'video_path' and 'prompt' columns.")
        self.sample_fps = sample_fps
        self.sample_frames = sample_frames
        if isinstance(sample_size, int):
            sample_size = (sample_size, sample_size)
        self.sample_size = sample_size

    def __len__(self):
        return len(self.df)

    def _load_cond_image(self, image_path, fallback):
        """Load, resize, and normalize a conditioning image to [-1, 1].

        Falls back to `fallback` (the first video frame) on any read/decode error,
        so a missing or corrupt image does not abort the epoch.
        """
        try:
            # Force a 3-channel RGB decode: without an explicit mode, read_image can
            # return 1-channel (grayscale) or 4-channel (RGBA) tensors, which would
            # mismatch the RGB video frames downstream.
            img = torchvision.io.read_image(
                image_path, mode=torchvision.io.ImageReadMode.RGB
            ).float() / 255.0
            img = F.interpolate(
                img.unsqueeze(0),
                size=self.sample_size,
                mode="bilinear",
                align_corners=False,
            )
            img = torch.clamp(img, 0.0, 1.0).squeeze(0)
            return img * 2.0 - 1.0
        except Exception:
            # Deliberate best-effort fallback; see docstring.
            return fallback

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        video_path = row["video_path"]
        prompt = row["prompt"]

        # NOTE(review): assumes read_video_to_tensor returns (T, C, H, W) in [0, 1]
        # — confirm against utils.common_utils.
        pixel_values = read_video_to_tensor(
            video_path,
            self.sample_fps,
            self.sample_frames,
            uniform_sampling=False,
        )

        # Resize all frames jointly using trilinear interpolate (treat frames as depth)
        pixel_values = pixel_values.unsqueeze(0).permute(0, 2, 1, 3, 4)  # 1, C, T, H, W
        pixel_values = F.interpolate(
            pixel_values,
            size=(self.sample_frames, self.sample_size[0], self.sample_size[1]),
            mode="trilinear",
            align_corners=False,
        )
        pixel_values = pixel_values.squeeze(0).permute(1, 0, 2, 3)  # T, C, H, W

        # Normalize to [-1, 1]
        pixel_values = pixel_values * 2.0 - 1.0

        # Conditioning image: use provided image_path if available, else first frame.
        # The isinstance check filters out NaN that pandas yields for empty CSV cells.
        if "image_path" in row and isinstance(row["image_path"], str) and row["image_path"]:
            cond_image = self._load_cond_image(row["image_path"], pixel_values[0])
        else:
            cond_image = pixel_values[0]

        return {
            "pixel_values": pixel_values,  # (T, C, H, W) in [-1, 1]
            "prompt": prompt,
            "cond_image": cond_image,  # (C, H, W) in [-1, 1]
        }
1 change: 1 addition & 0 deletions wan2_consistency/modeling/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

120 changes: 120 additions & 0 deletions wan2_consistency/modeling/wan22_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from diffusers import DiffusionPipeline


@dataclass
class Wan22Components:
    # Snapshot of the submodules pulled off a loaded Wan2.2 DiffusionPipeline
    # (built by Wan22Adapter.components).
    unet: torch.nn.Module  # denoising backbone
    vae: torch.nn.Module  # video/image autoencoder
    text_encoder: torch.nn.Module  # prompt encoder paired with `tokenizer`
    tokenizer: object  # may be None if the pipeline exposes no tokenizer
    vae_scale_factor: float  # multiplier applied to latents after VAE encode
    scheduler: object  # may be None if the pipeline exposes no scheduler


class Wan22Adapter:
    """
    Thin wrapper around a Wan2.2 DiffusionPipeline.
    If your Wan2.2 checkpoint uses a custom pipeline or tokenizer layout, adjust the
    loader/encoder methods here.
    """

    def __init__(
        self,
        model_id_or_path: str,
        revision: Optional[str] = None,
        torch_dtype=torch.float16,
    ):
        """Load the pipeline from a HF repo id or a local path."""
        self.pipe = DiffusionPipeline.from_pretrained(
            model_id_or_path,
            revision=revision,
            torch_dtype=torch_dtype,
        )
        self.pipe.set_progress_bar_config(disable=True)

    def to(self, device: torch.device, dtype: torch.dtype):
        """Move/cast the whole pipeline; returns self for chaining."""
        self.pipe.to(device=device, dtype=dtype)
        return self

    @property
    def components(self) -> Wan22Components:
        """Build a fresh Wan22Components snapshot from the pipeline submodules."""
        vae = self.pipe.vae
        # Diffusers stores the latent scaling factor on the VAE *config*
        # (vae.config.scaling_factor), not as a direct attribute — a plain
        # getattr(vae, "scaling_factor", ...) would silently fall back to the
        # SD default for every model. Check config first, then the attribute.
        vae_scale = getattr(getattr(vae, "config", None), "scaling_factor", None)
        if vae_scale is None:
            vae_scale = getattr(vae, "scaling_factor", 0.18215)
        return Wan22Components(
            unet=self.pipe.unet,
            vae=vae,
            text_encoder=self.pipe.text_encoder,
            tokenizer=getattr(self.pipe, "tokenizer", None),
            vae_scale_factor=vae_scale,
            scheduler=getattr(self.pipe, "scheduler", None),
        )

    def encode_prompt(
        self,
        prompts: List[str],
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """
        Tokenize and encode `prompts` with the pipeline's text encoder.

        Returns:
            Encoder hidden states cast to `dtype`, shape (B, L, D).
        Raises:
            ValueError: if the pipeline exposes no tokenizer or text encoder.
        """
        # `components` is a property that rebuilds the snapshot — fetch it once.
        components = self.components
        tokenizer = components.tokenizer
        text_encoder = components.text_encoder
        if tokenizer is None or text_encoder is None:
            raise ValueError(
                "Tokenizer or text encoder missing. Please adjust Wan22Adapter for your model."
            )

        tokens = tokenizer(
            prompts,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokens.input_ids.to(device)
        attention_mask = tokens.attention_mask.to(device)
        with torch.no_grad():
            enc_out = text_encoder(input_ids, attention_mask=attention_mask)
        # Handle encoders that return tuples vs. BaseModelOutput
        hidden_states = enc_out[0] if isinstance(enc_out, (tuple, list)) else enc_out.last_hidden_state
        return hidden_states.to(dtype=dtype)

    def encode_video_latents(
        self,
        videos: torch.Tensor,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """
        videos: (B, T, C, H, W) in [-1, 1]
        returns latents shaped for a 3D UNet: (B, C, T, H', W')
        """
        components = self.components
        vae = components.vae
        vae_scale = components.vae_scale_factor
        videos = videos.to(device=device, dtype=dtype)
        b, t, c, h, w = videos.shape
        # Fold time into the batch so a 2D VAE can encode every frame at once.
        flat = videos.reshape(b * t, c, h, w)
        with torch.no_grad():
            latents = vae.encode(flat).latent_dist.sample() * vae_scale
        c_latent, h_latent, w_latent = latents.shape[1:]
        # reshape (not view): the sampled latents are not guaranteed contiguous.
        latents = latents.reshape(b, t, c_latent, h_latent, w_latent)
        latents = latents.permute(0, 2, 1, 3, 4).contiguous()
        return latents

    def encode_image_latents(
        self,
        images: torch.Tensor,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """
        images: (B, C, H, W) in [-1, 1]
        returns latents: (B, C, H', W')
        """
        components = self.components
        vae = components.vae
        vae_scale = components.vae_scale_factor
        images = images.to(device=device, dtype=dtype)
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample() * vae_scale
        return latents
Loading