Add an optional Wan2.2-TI2V-5B video backbone to OmniWM as a new model arch
family (Omni-Wan-5B) plus a training module (omni_wan), reusing the
one-directional Omni coupling: the action expert cross-attends to f0 video tokens, the video stream never
reads the action, and f0-only inference runs via a prefill cache. The backbone is selected purely by config;
behavior on the NanoWM / Omni path is unchanged. HTML rendering of
docs/wan_backbone_design.md.
Every additive branch is gated so the SD/NanoWM/Omni path is never entered for Wan and vice-versa. The seams below were confirmed by code investigation.
The registry (src/models/__init__.py:11) merges NanoWM_models +
Omni_models; dispatch guards on the substring 'NanoWM'/'Omni'
in args.model.arch. New arch Omni-Wan-5B (registered in
new src/models/omni_wan.py → OmniWan_models) deliberately keeps the
Omni- prefix so the joint-knob plumbing still passes
action_expert_depth/action_loss_dim.
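```python
# src/models/__init__.py: additive Wan-only knobs
from omegaconf import OmegaConf

if "Omni-Wan" in arch:
    common["video_backbone"] = "wan2.2-ti2v-5b"
    common["wan_repo"] = args.model.get("wan_repo", "Wan-AI/Wan2.2-TI2V-5B-Diffusers")
    # to_container only accepts OmegaConf containers; a missing peft key falls back to {}.
    peft_cfg = args.model.get("peft", None)
    common["peft"] = OmegaConf.to_container(peft_cfg, resolve=True) if OmegaConf.is_config(peft_cfg) else {}
    common["latent_channels"] = args.model.get("latent_channels", 48)
```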
New module kind omni_wan: pairing an Omni-Wan arch with any
other module raises an error. OmniWanTrainingModule lives in a new
src/experiments/train_wan_experiment.py, which is imported lazily inside the
omni_wan branch so that diffusers' Wan classes and peft are not required for
NanoWM/Omni runs.
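"Omni-Wan-5B" also contains the substring "Omni", so a substring guard has to test the Wan case first. Below is a minimal pure-Python sketch of that ordering, together with the arch/module pairing check.

- resolve_family and check_module are hypothetical names, and the family labels are illustrative.
- Rejecting omni_wan with a non-Wan arch (the reverse direction) is an assumption drawn from the "and vice-versa" gating rule, not from the repo.

```python
# Hypothetical sketch; the real guards live in src/models/__init__.py and
# train_experiment.py and may be structured differently.

def resolve_family(arch: str) -> str:
    """Map an arch name to a model family, checking the most specific substring first."""
    # "Omni-Wan-5B" also contains "Omni", so the Wan check must run before the Omni one.
    if "Omni-Wan" in arch:
        return "omni_wan"
    if "Omni" in arch:
        return "omni"
    if "NanoWM" in arch:
        return "nanowm"
    raise ValueError(f"unknown arch: {arch}")


def check_module(arch: str, module: str) -> None:
    """Reject mismatched arch/module pairs in both directions (reverse direction assumed)."""
    is_wan_arch = resolve_family(arch) == "omni_wan"
    if is_wan_arch != (module == "omni_wan"):
        raise ValueError(f"arch {arch!r} cannot be trained with module {module!r}")
```

Ordering the checks from most to least specific means the existing "Omni" substring guard never sees a Wan arch.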
OmniWan model class: composition over the frozen diffusers WanTransformer3DModel plus a
from-scratch action expert. It does not subclass NanoWM/Omni: Wan's 3D patchify,
per-token 2D timestep, RoPE, and 30 full-attention blocks have no 1:1 mapping, and the Omni assert
action_expert_depth == len(blocks)//2 is structurally invalid for Wan. The public joint
API is preserved: forward_joint, prefill_f0, action_from_cache.
ActionInEmbedder, ActionBlock, ActionHead,
TimestepEmbedder, and CrossAttention are reused by import from
omni.py/nanowm.py. They are backbone-agnostic; only the cross-attention K/V width
changes (to Wan's 3072-wide f0 tokens). The empty-prompt umT5 embedding [1, 512, 4096] is precomputed once and
registered as a buffer, since the world model is conditioned only on the null text prompt.
Wan's transformer exposes no output_hidden_states, so we fork the block loop (copying
WanTransformer3DModel.forward) and, after each of the 30 blocks, slice out the
first-frame latent tokens. Token order is frame-major
(patch_embedding(Conv3d).flatten(2).transpose(1,2)), so f0 is the first
n_f0 = (H_lat/2)·(W_lat/2) tokens, i.e. 64 for a 16×16 latent.
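The NumPy sketch below makes the frame-major claim concrete. It emulates Wan's Conv3d patch embedding as a reshape and checks that the first n_f0 tokens depend only on latent frame 0.

- The kernel is assumed to be kernel = stride = (1, 2, 2), and the learned projection is omitted.
- patchify_frame_major and f0_tokens are hypothetical helpers, not functions in the repo or in diffusers.

```python
import numpy as np

def patchify_frame_major(latent: np.ndarray, p: int = 2) -> np.ndarray:
    """Emulate a Conv3d patch embedding with kernel = stride = (1, p, p), minus the
    learned projection: [C, T, H, W] -> [T * (H//p) * (W//p), C * p * p].
    Like conv(x).flatten(2).transpose(1, 2), tokens come out frame-major:
    every patch of latent frame 0 first, then frame 1, and so on."""
    C, T, H, W = latent.shape
    x = latent.reshape(C, T, H // p, p, W // p, p)
    # Move (T, H/p, W/p) to the front so the flattened token axis is frame-major, row-major.
    x = x.transpose(1, 2, 4, 0, 3, 5)
    return x.reshape(T * (H // p) * (W // p), C * p * p)

def f0_tokens(tokens: np.ndarray, h_lat: int, w_lat: int, p: int = 2) -> np.ndarray:
    """Slice the first-frame tokens: the first n_f0 = (h_lat/p) * (w_lat/p) positions."""
    n_f0 = (h_lat // p) * (w_lat // p)
    return tokens[:n_f0]
```

Because the slice is a plain prefix, the forked block loop can take h[:, :n_f0] after each block without any gather.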
Wan TI2V conditions on the initial image by latent inpainting + per-token 2D timestep
masking (expand_timesteps), not a separate cross-attn stream. We adopt the
same mechanism for f0 = current obs: place the clean current-obs latent at latent-frame 0,
set its per-token timestep to 0, and run
latent_model_input = (1-mask)·cond + mask·noisy. Action conditioning of the video
is added to Wan's timestep_proj/AdaLN temb per latent frame through a zero-initialized
action_to_temb hook, time-shifted so that latent frame 0 receives zero action; f0 therefore
stays a leak-free observation that the policy can read.
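The NumPy sketch below covers both mechanisms above on an already-flattened token sequence.

- build_ti2v_inputs performs the inpainting blend and the per-token timestep. Wan applies the mask in latent space before patchify; working on tokens here is a simplification.
- shifted_action_temb shows one plausible reading of "time-shifted": frame k receives frame k-1's action features and frame 0 receives zeros, so it can never see the action.
- Both names, the feature shapes, and the single linear projection are assumptions.

```python
import numpy as np

def build_ti2v_inputs(cond, noisy, t, n_f0):
    """cond, noisy: [N, D] frame-major tokens; cond holds the clean current-obs latent
    in its first n_f0 tokens. t: scalar timestep for the future frames.
    Returns the blended tokens and a per-token timestep vector."""
    mask = np.ones((noisy.shape[0], 1), dtype=noisy.dtype)
    mask[:n_f0] = 0.0                          # frame-0 tokens are kept clean
    x = (1.0 - mask) * cond + mask * noisy     # latent_model_input
    timesteps = mask[:, 0] * t                 # 0 on f0 tokens, t everywhere else
    return x, timesteps

def shifted_action_temb(action_emb, w):
    """action_emb: [T_lat, D_a] per-latent-frame action features (assumed layout).
    w: [D_a, D_temb] projection, zero-initialized at the start of training.
    Returns an additive temb offset per latent frame; row 0 is always zero."""
    shifted = np.concatenate([np.zeros_like(action_emb[:1]), action_emb[:-1]], axis=0)
    return shifted @ w
```

Because w starts at zero, the hook leaves pretrained Wan behavior untouched at initialization. The shift guarantees frame 0's temb is action-free for any value of w.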
The core break: the repo's contract is per-frame 2D SD latents
[T, V, 4, 32, 32] sliced by frame index. Wan's VAE is 3D-causal:
T_lat = 1 + (T-1)//4, z_dim=48, spatial 16×. For 256px frames, on-disk latents become
[T_lat, V, 48, 16, 16]. Every latent frame after frame 0 covers 4 raw frames, so the naive per-frame slice is wrong for Wan.
| Property | Wan VAE | SD VAE |
|---|---|---|
| latent channels | 48 | 4 |
| spatial compression | 16× (256px → 16×16) | 8× (256px → 32×32) |
| temporal compression | 4×, causal: T_lat = 1 + (T-1)//4 | none (per-frame 2D) |
| normalization | per-channel (z − latents_mean) / latents_std; no scalar scaling_factor | scalar scaling_factor |
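The temporal row of the table can be written as two small pure-Python helpers: the latent-frame count and the raw-frame → latent-frame map that meta.json is described as storing.

- Both function names are hypothetical.
- The sketch assumes T = 1 + k·4, which holds for num_frames: 33. Other lengths would need padding that is not modeled here.

```python
def num_latent_frames(t: int, tc: int = 4) -> int:
    """Causal temporal compression: raw frame 0 alone, then one latent per tc raw frames."""
    assert (t - 1) % tc == 0, "sketch assumes T = 1 + k*tc, e.g. num_frames: 33"
    return 1 + (t - 1) // tc

def frame_to_latent(f: int, tc: int = 4) -> int:
    """Index of the latent frame that raw frame f is folded into."""
    return 0 if f == 0 else (f - 1) // tc + 1

T = 33
mapping = [frame_to_latent(f) for f in range(T)]  # 0, 1,1,1,1, 2,2,2,2, ..., 8,8,8,8
```

With tc=1, frame_to_latent is the identity, which is the property the compression-aware slice relies on to stay bit-identical for SD latents.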
Precompute rather than encode on the fly: a new tool,
src/tools/precompute_latents_wan.py (a fork; the SD tool is untouched), encodes each view
independently, in 3D over the whole episode, and stores normalized fp16 latents of shape
[T_lat, V, 48, 16, 16]. meta.json records z_dim, the scale
factors, latents_mean/std, and the frame_to_latent map.
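The per-channel normalization applied before the fp16 store broadcasts over axis 2 of the on-disk layout, as in the NumPy sketch below.

- normalize_latents and denormalize_latents are hypothetical helpers.
- latents_mean/std are treated as length-48 vectors, as recorded in meta.json.
- The round trip is only approximate because of fp16 rounding.

```python
import numpy as np

def _channel_stats(latents_mean, latents_std):
    # Reshape per-channel stats so they broadcast over [T_lat, V, 48, h, w].
    mean = np.asarray(latents_mean, dtype=np.float32).reshape(1, 1, -1, 1, 1)
    std = np.asarray(latents_std, dtype=np.float32).reshape(1, 1, -1, 1, 1)
    return mean, std

def normalize_latents(z, latents_mean, latents_std):
    """Raw VAE latents [T_lat, V, 48, h, w] -> (z - mean) / std per channel, stored as fp16."""
    mean, std = _channel_stats(latents_mean, latents_std)
    return ((z - mean) / std).astype(np.float16)

def denormalize_latents(z_norm, latents_mean, latents_std):
    """Inverse map applied before wan_decode: upcast to fp32, then z * std + mean."""
    mean, std = _channel_stats(latents_mean, latents_std)
    return z_norm.astype(np.float32) * std + mean
```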
Parallel cache key + slice logic: a new mixture knob
latents_backbone (default sdvae) switches the latent dir suffix
(<name>_wanvae_<size>) and dispatches precompute. Data sources gain a
temporal_compression field (default 1) mapping a raw-frame window to its
latent-frame window; with 1 the behavior is bit-identical to today.
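The temporal_compression field could map a half-open raw-frame window onto the latent frames that cover it, as in the sketch below.

- latent_window is a hypothetical name.
- Windows are assumed to be aligned to latent-chunk boundaries. A window starting mid-chunk would pull in the whole chunk containing its first frame.

```python
def latent_window(start: int, end: int, temporal_compression: int = 1) -> tuple[int, int]:
    """Map raw frames [start, end) to latent frames [l_start, l_end).
    With temporal_compression == 1 this is the identity (today's per-frame slice)."""
    tc = temporal_compression

    def f2l(f: int) -> int:
        # Causal layout: frame 0 is its own latent; later frames are grouped in chunks of tc.
        return 0 if f == 0 else (f - 1) // tc + 1

    return f2l(start), f2l(end - 1) + 1
```

For example, latent_window(5, 17, 4) selects latent frames 2 through 4, which together cover raw frames 5 to 16.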
Wan is a single video stream; the repo carries V views (e.g. 2 wrist cams). Encode each
view independently through the Wan VAE, then token-concat the views'
patch tokens along the sequence with a learned per-view embedding — matching the existing
view_embed + (v p) contract. f0 tokens become V·n_f0 (both
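views' frame-0 tokens). Do not concatenate views along time: the causal VAE would then compress frames
from different views together, and latent frame 0 would hold only the first view, breaking f0. The first
deliverable runs V=1 to de-risk; lift to V=2 once V=1 trains.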
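With view-major "(v p)" concatenation, the V·n_f0 frame-0 tokens are not one contiguous prefix. Each view contributes its own first n_f0 positions. The NumPy sketch below builds the sequence and the f0 index set.

- concat_views and f0_indices are hypothetical names.
- The view embedding is a random stand-in for the learned one.

```python
import numpy as np

def concat_views(view_tokens, view_embed):
    """view_tokens: list of V arrays [N, D], each one view's frame-major tokens.
    view_embed: [V, D] per-view embedding added to every token of that view.
    Returns [V*N, D] in (v p) order: all of view 0, then all of view 1, ..."""
    return np.concatenate([tok + view_embed[v] for v, tok in enumerate(view_tokens)], axis=0)

def f0_indices(num_views, n_tokens_per_view, n_f0):
    """Positions of every view's frame-0 tokens inside the (v p) sequence."""
    return np.concatenate([v * n_tokens_per_view + np.arange(n_f0) for v in range(num_views)])
```

With V=2, 9 latent frames, and 64 tokens per frame, the f0 positions are 0–63 and 576–639.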
Wan is flow-matching (UniPC flow_prediction, flow_shift=5.0); the
Gaussian compute_joint_loss stays untouched. A parallel
compute_joint_loss_flow lands in train_wan_experiment.py:
| Component | Rule |
|---|---|
| Forward | x_t = σ·noise + (1−σ)·x0; target = noise − x0 (velocity). Video and action get independently sampled σ (two schedulers, per FastWAM). |
| Video | latent-frame 0 kept clean (per-token timestep 0) and dropped from the loss — only future latent frames supervised, via Wan's per-token 2D timestep path. |
| Action | full chunk noised at its own σ; flow target noise_a − a0; MSE masked by action_mask > 0.5. |
| Total | λ_video·L_video + λ_action·L_action (defaults 1.0 / 1.0). |
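The NumPy sketch below mirrors the loss table on flattened tokens. It draws independent σ for video and action, keeps frame 0 clean, excludes frame 0 from the video loss, and masks the action MSE.

- noise_inputs and compute_joint_loss_flow_sketch are hypothetical names, and the shapes are assumptions.
- Uniform σ sampling is a simplification. The real schedulers may sample and shift σ differently.

```python
import numpy as np

def noise_inputs(x0, a0, n_f0, rng):
    """x0: [B, N, D] frame-major video tokens; a0: [B, T_a, A] clean action chunk."""
    b = x0.shape[0]
    sigma_v = rng.uniform(size=(b, 1, 1))      # video sigma ...
    sigma_a = rng.uniform(size=(b, 1, 1))      # ... and an independent action sigma
    noise_v = rng.standard_normal(x0.shape)
    noise_a = rng.standard_normal(a0.shape)
    x_t = sigma_v * noise_v + (1.0 - sigma_v) * x0
    x_t[:, :n_f0] = x0[:, :n_f0]               # latent frame 0 stays clean (per-token timestep 0)
    a_t = sigma_a * noise_a + (1.0 - sigma_a) * a0
    return x_t, a_t, noise_v, noise_a, sigma_v, sigma_a

def compute_joint_loss_flow_sketch(v_pred, a_pred, x0, noise_v, a0, noise_a,
                                   action_mask, n_f0, lam_video=1.0, lam_action=1.0):
    """action_mask: [B, T_a]. Returns (total, l_video, l_action)."""
    # Video: velocity target; latent-frame-0 tokens are excluded from the loss.
    target_v = noise_v - x0
    l_video = np.mean((v_pred[:, n_f0:] - target_v[:, n_f0:]) ** 2)
    # Action: velocity target; MSE averaged only over steps with action_mask > 0.5.
    target_a = noise_a - a0
    m = (action_mask > 0.5).astype(a_pred.dtype)[..., None]   # [B, T_a, 1]
    denom = m.sum() * a_pred.shape[-1]
    l_action = (((a_pred - target_a) ** 2) * m).sum() / denom if denom > 0 else 0.0
    return lam_video * l_video + lam_action * l_action, l_video, l_action
```

Returning 0.0 when no action step is valid matches smoke test 6, "action loss 0 when action_mask==0".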
A new wrapper src/diffusion/flow_match.py holds add_noise,
training_target, Euler step, and σ shifting —
gaussian_diffusion.py untouched.
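The helpers listed for flow_match.py are each a line or two. The NumPy sketch below shows them under the forward process in the table.

- Shapes are left generic.
- The σ shift uses the standard flow-matching formula σ' = s·σ / (1 + (s−1)·σ), as found in diffusers' flow schedulers. Treat the exact form Wan's UniPC scheduler applies as something to verify.
- Function names are hypothetical, not the module's actual API.

```python
import numpy as np

def shift_sigma(sigma, shift=5.0):
    """Flow-matching timestep shift: maps [0, 1] onto itself, pushing sigmas toward 1."""
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)

def add_noise(x0, noise, sigma):
    """Forward process x_sigma = sigma * noise + (1 - sigma) * x0."""
    return sigma * noise + (1.0 - sigma) * x0

def training_target(x0, noise):
    """Velocity d x_sigma / d sigma, constant along the straight path."""
    return noise - x0

def euler_step(x, v, sigma, sigma_next):
    """One Euler step of dx/dsigma = v from sigma to sigma_next (sigma_next < sigma)."""
    return x + (sigma_next - sigma) * v

def sigma_schedule(num_steps, shift=5.0):
    """Decreasing sigmas from 1 to 0, shifted; num_steps + 1 entries."""
    return shift_sigma(np.linspace(1.0, 0.0, num_steps + 1), shift)
```

The path is a straight line, so Euler integration with the exact velocity recovers x0 for any schedule. Discretization error only comes from an imperfect v_pred.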
The existing dfot_action_sample is coupled to Gaussian DDIM
(alphas_cumprod), so a parallel dfot_action_sample_flow (an Euler ODE over
sigmas) is added; df_sample.py is untouched. The sampler runs the video backbone once on f0
(prefill_f0) and then queries only the action expert against that cache (action_from_cache)
at every ODE step. This prefill is what makes the 5B policy tractable: it is the analogue of
FastWAM's prefill_video_cache and identical in spirit to the current Omni prefill.
OmniWanTrainingModule.policy_sample mirrors the Omni version and reuses _denorm_action unchanged.
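The NumPy sketch below shows the sampler's structure: one prefill, then an Euler loop that touches only the action expert.

- sample_actions_flow is a hypothetical stand-in, not the actual dfot_action_sample_flow. Any DFoT-specific scheduling is omitted.
- The model is only assumed to expose prefill_f0 and action_from_cache, with action_from_cache returning the predicted velocity (noise − a0).

```python
import numpy as np

def sample_actions_flow(model, obs_f0, a_shape, sigmas, rng):
    """Euler-ODE action sampling over a precomputed f0 cache.
    sigmas: decreasing array from 1.0 to 0.0."""
    cache = model.prefill_f0(obs_f0)            # the only pass through the video backbone
    a = rng.standard_normal(a_shape)            # sigma = 1: pure noise
    for s, s_next in zip(sigmas[:-1], sigmas[1:]):
        v = model.action_from_cache(a, s, cache)  # action expert only
        a = a + (s_next - s) * v                 # Euler step toward sigma = 0
    return a

class OracleStub:
    """Toy model for checking the loop: knows the clean action, returns the exact velocity."""
    def __init__(self, a0):
        self.a0 = a0
        self.prefill_calls = 0

    def prefill_f0(self, obs_f0):
        self.prefill_calls += 1
        return {"f0": obs_f0}

    def action_from_cache(self, a_t, sigma, cache):
        # On the path a_t = sigma*noise + (1 - sigma)*a0, (a_t - a0)/sigma == noise - a0.
        return (a_t - self.a0) / sigma
```

Backbone cost per action chunk is therefore one forward on V·n_f0 tokens per block, regardless of the number of ODE steps.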
| Component | Treatment |
|---|---|
| Wan transformer | LoRA r=16–32, α=2r on attn q/k/v/out + ffn projections; base frozen; gradient checkpointing. Tens of M trainable. Add peft to pyproject deps. |
| Action expert + v2a + action_to_temb | Fully trained (small, from scratch) |
| Wan VAE | Frozen, fp32, not resident during training (precomputed latents) |
| umT5 text encoder (~11GB) | Frozen, not resident (precomputed null-prompt buffer) |
| Precision | bf16-mixed (repo default) |
| Parallelism | Auto-DDP as today; LoRA fits on 2–4× H200 (143GB), with 8 GPUs for throughput. A full 5B fine-tune would need FSDP FULL_SHARD (the fp32 AdamW moments alone are ~40GB); going LoRA-first avoids this. |
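The "tens of M trainable" and "~40GB AdamW states" figures can be sanity-checked with a small pure-Python calculation.

- The 30 blocks and 3072 hidden size come from smoke test 2.
- ffn_dim=14336 is an assumption to verify against the checkpoint config.
- LoRA is assumed to target 8 square 3072×3072 attention projections per block (q/k/v/out of self- and cross-attention) plus the two FFN projections.
- The helper names are hypothetical.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds A: [r, d_in] and B: [d_out, r] next to a frozen Linear."""
    return r * (d_in + d_out)

def wan_lora_trainable(r, num_blocks=30, dim=3072, ffn_dim=14336, attn_linears=8):
    """Trainable LoRA parameters under the assumed per-block targets."""
    attn = attn_linears * lora_params(dim, dim, r)
    ffn = lora_params(dim, ffn_dim, r) + lora_params(ffn_dim, dim, r)
    return num_blocks * (attn + ffn)

def adamw_state_bytes(n_params: int) -> int:
    """Two fp32 moments (exp_avg, exp_avg_sq) per parameter."""
    return n_params * 2 * 4
```

Under these assumptions, r=16 gives about 40.3M and r=32 about 80.6M LoRA parameters. The AdamW moments for all 5B weights come to about 40GB, excluding master weights and gradients.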
Two new config files, siblings of model/omni_s2.yaml /
experiment/train_omni.yaml (which stay untouched):
```yaml
# src/configs/model/omni_wan_5b.yaml
arch: Omni-Wan-5B
image_size: 256
latent_channels: 48
num_frames: 33 # -> T_lat = 9
action_expert_hidden: 1024
peft: {type: lora, r: 32, alpha: 64}
wan_repo: Wan-AI/Wan2.2-TI2V-5B-Diffusers
```

```yaml
# src/configs/experiment/train_omni_wan.yaml
module: omni_wan
mixture: {enabled: true, latents: auto, latents_backbone: wanvae}
diffusion: {objective: flow_match, flow_shift: 5.0, action_loss_weight: 1.0}
infra: {gradient_checkpointing: true, vae_precision: fp32}
```
| # | Smoke test (scaffold-first) | Asserts |
|---|---|---|
| 1 | Import/registry | get_models returns OmniWan for Omni-Wan-5B; NanoWM/Omni still resolve without peft import |
| 2 | Load | from_pretrained under HF_HOME redirect; config = in/out 48, 30 layers, 3072 hidden |
| 3 | Forward (tiny) | v_pred shape == input; a_pred [B, T_action, A] and == 0 at init; f0_list len 30, each [B, V·64, 3072] |
| 4 | VAE round-trip | wan_encode/wan_decode finite + [B, 3, T, H, W] |
| 5 | Prefill equivalence | forward_joint(...)[1] == action_from_cache(a_t, t_a, prefill_f0(...)) — the load-bearing inference correctness check |
| 6 | Loss | compute_joint_loss_flow finite; action loss 0 when action_mask==0 |
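Smoke test 5 only holds if the f0 tokens the joint forward hands to the action expert never depend on the future frames. The prefill never sees those frames. The toy NumPy model below shows this.

- It exposes the same three entry points as the real model.
- Its f0 tokens attend only to f0 tokens. With block_f0=False they also read the noisy future tokens, and prefill equivalence breaks.
- ToyOmni, its layers, and the block_f0 flag are illustrative assumptions, not OmniWan internals.

```python
import numpy as np

def _softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

class ToyOmni:
    """Toy with the public joint API: forward_joint, prefill_f0, action_from_cache."""

    def __init__(self, dim=16, n_layers=3, block_f0=True, seed=0):
        rng = np.random.default_rng(seed)
        self.wv = [rng.standard_normal((dim, dim)) / np.sqrt(dim) for _ in range(n_layers)]
        self.wa = [rng.standard_normal((dim, dim)) / np.sqrt(dim) for _ in range(n_layers)]
        self.block_f0 = block_f0

    def _video_layer(self, h, w, n_f0):
        n = h.shape[0]
        mask = np.zeros((n, n))
        if self.block_f0:
            mask[:n_f0, n_f0:] = -np.inf      # f0 tokens never read future-frame tokens
        attn = _softmax(h @ h.T / np.sqrt(h.shape[1]) + mask)
        return np.tanh((attn @ h) @ w)

    def _action(self, a_t, t_a, f0_list):
        a = a_t + t_a
        for f0, w in zip(f0_list, self.wa):   # one cross-attention per video layer
            a = a + _softmax(a @ f0.T / np.sqrt(a.shape[1])) @ f0 @ w
        return a

    def forward_joint(self, x, a_t, t_a, n_f0):
        f0_list, h = [], x
        for w in self.wv:
            h = self._video_layer(h, w, n_f0)
            f0_list.append(h[:n_f0])
        return h, self._action(a_t, t_a, f0_list)

    def prefill_f0(self, x_f0):
        f0_list, h = [], x_f0
        for w in self.wv:
            h = self._video_layer(h, w, h.shape[0])
            f0_list.append(h)
        return f0_list

    def action_from_cache(self, a_t, t_a, cache):
        return self._action(a_t, t_a, cache)
```

The comparison uses np.allclose rather than exact equality, because the two paths can sum in a different order.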
| File | Change |
|---|---|
| src/models/nanowm.py · omni.py · gaussian_diffusion.py · df_sample.py · respace.py · vae_ops.py · compute_joint_loss | untouched |
| src/models/__init__.py | additive: merge OmniWan_models, specialize guard, Wan-only knobs |
| train_experiment.py module dispatch | additive: omni_wan branch + de-substring guard |
| mixture_dataset.py · world_model_dataset.py · lerobot_data_source.py | additive: latents_backbone key + compression-aware latent slice (default no-op) |
| src/models/omni_wan.py · src/experiments/train_wan_experiment.py · src/diffusion/flow_match.py · src/utils/wan_vae_ops.py · src/tools/precompute_latents_wan.py | NEW |
| src/configs/model/omni_wan_5b.yaml · src/configs/experiment/train_omni_wan.yaml | NEW |
| pyproject.toml | add peft, set HF_HOME activation env |