OmniWM Spec · Omni-Wan backbone
Design spec — optional Wan2.2-TI2V-5B video backbone

OmniWM + Wan2.2-TI2V-5B backbone

Add an optional Wan2.2-TI2V-5B video backbone to OmniWM as a new model arch family (Omni-Wan-5B) plus a training module (omni_wan), reusing the one-directional Omni coupling: the action expert cross-attends to f0 video tokens, video tokens never attend to the action, and f0-only inference runs via a prefill cache. The backbone is selected purely by config; the NanoWM/Omni path never changes behavior. This page is the HTML rendering of docs/wan_backbone_design.md.

00

Goal & non-goals

Every additive branch is gated so that the SD/NanoWM/Omni path is never entered for Wan, and vice versa. The integration seams below were confirmed by reading the code.

01

Arch family & config gate

The registry (src/models/__init__.py:11) merges NanoWM_models and Omni_models; dispatch guards on the substring 'NanoWM' or 'Omni' in args.model.arch. The new arch Omni-Wan-5B (registered as OmniWan_models in the new src/models/omni_wan.py) deliberately keeps the Omni- prefix so the joint-knob plumbing still passes action_expert_depth/action_loss_dim. Because 'Omni' is also a substring of 'Omni-Wan-5B', the existing Omni guard must be specialized so that it does not capture the Wan arch.

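# src/models/__init__.py: additive Wan-only knobs
# (runs where `arch = args.model.arch` and `common` holds the shared constructor kwargs)
from omegaconf import OmegaConf

if "Omni-Wan" in arch:
    common["video_backbone"] = "wan2.2-ti2v-5b"
    common["wan_repo"] = args.model.get("wan_repo", "Wan-AI/Wan2.2-TI2V-5B-Diffusers")
    # OmegaConf.to_container rejects plain dicts, so only convert a configured peft block
    peft_cfg = args.model.get("peft", None)
    common["peft"] = OmegaConf.to_container(peft_cfg, resolve=True) if peft_cfg is not None else {}
    common["latent_channels"] = args.model.get("latent_channels", 48)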
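A minimal sketch of the guard ordering and the registry merge, assuming plain dicts for the registries. The helper names model_family and merge_registries, and the example arch strings "Omni-S/2" and "NanoWM-S/2", are hypothetical; only the substring tests mirror the spec. The point is ordering: "Omni-Wan-5B" contains "Omni", so the Wan check has to run first.

```python
def model_family(arch: str) -> str:
    """Classify an arch string; the Wan check must precede the generic 'Omni' check."""
    if "Omni-Wan" in arch:
        return "omni_wan"
    if "Omni" in arch:
        return "omni"
    if "NanoWM" in arch:
        return "nanowm"
    raise ValueError(f"unknown arch family: {arch!r}")


def merge_registries(*registries: dict) -> dict:
    """Merge arch -> constructor dicts, refusing silent overrides on name clashes."""
    merged: dict = {}
    for reg in registries:
        clash = merged.keys() & reg.keys()
        if clash:
            raise KeyError(f"duplicate arch names: {sorted(clash)}")
        merged.update(reg)
    return merged
```

Refusing duplicate keys keeps a new Wan entry from silently shadowing an existing NanoWM/Omni arch.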

New module kind omni_wan: pairing an Omni-Wan arch with any other module kind raises an error. OmniWanTrainingModule lives in a new src/experiments/train_wan_experiment.py and is imported lazily inside the omni_wan branch, so diffusers' Wan classes and peft are not required for NanoWM/Omni runs.
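The module-kind check and the lazy import can be sketched as follows. This is a hypothetical helper, not the repo's dispatch: it assumes the module is importable as src.experiments.train_wan_experiment, takes the importer as a parameter only so the lazy path is testable, and also rejects the reverse pairing (omni_wan with a non-Wan arch), which the spec does not state explicitly.

```python
import importlib
from typing import Callable, Optional


def check_module_arch(module_kind: str, arch: str) -> None:
    """Reject mismatched (module kind, arch) pairs before anything heavy is imported."""
    is_wan = "Omni-Wan" in arch
    if is_wan and module_kind != "omni_wan":
        raise ValueError(f"arch {arch!r} requires module 'omni_wan', got {module_kind!r}")
    if module_kind == "omni_wan" and not is_wan:  # reverse pairing: an assumption
        raise ValueError(f"module 'omni_wan' requires an Omni-Wan arch, got {arch!r}")


def resolve_wan_module_cls(module_kind: str, arch: str,
                           importer: Callable = importlib.import_module) -> Optional[type]:
    """Return OmniWanTrainingModule for the omni_wan branch, else None (existing dispatch)."""
    check_module_arch(module_kind, arch)
    if module_kind != "omni_wan":
        return None  # NanoWM/Omni paths never import diffusers' Wan classes or peft
    # Lazy import: only this branch pays for diffusers' Wan classes and peft.
    mod = importer("src.experiments.train_wan_experiment")
    return mod.OmniWanTrainingModule
```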

02

The OmniWan model class

OmniWan composes the frozen diffusers WanTransformer3DModel with a from-scratch action expert. It does not subclass NanoWM/Omni: Wan's 3D patchify, per-token 2D timestep, RoPE, and 30 full-attention blocks have no 1:1 mapping, and the Omni assert action_expert_depth == len(blocks)//2 is structurally invalid for Wan. The public joint API is preserved: forward_joint, prefill_f0, action_from_cache.

Figure: OmniWan, a frozen Wan tower plus a trained action expert
Video tower (per layer j = 0..29): frozen + LoRA WanTransformer3DModel; 30 blocks, 24 heads × 128 = 3072 hidden, FFN 14336, 48 in/out channels; run through a forked forward loop.
Cache f0_list[j]: first-frame latent tokens, [B, V·n_f0, 3072], with n_f0 = 8×8 = 64 at 256px.
Action expert (trained from scratch, hidden Da = 1024):
  Embed: action_in + t-embedder; Linear(action_dim, Da), one token per action step.
  Block ×30: ActionBlock + gated v2a cross-attn; Q from the Da stream, K/V from f0_list[j] projected 3072 → Da; zero-init gate.
  Head: ActionHead (zero-init), so a_pred == 0 at step 0 for a clean warm start.
Legend: frozen (+LoRA) · cached f0 · trained head/loss

ActionInEmbedder, ActionBlock, ActionHead, TimestepEmbedder, and CrossAttention are reused by import from omni.py/nanowm.py; they are backbone-agnostic, and only the cross-attn K/V width changes. The empty-prompt umT5 embedding [1, 512, 4096] is precomputed once and registered as a buffer, since the world model is text-free (null prompt only).
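To make the "only the K/V width changes" point concrete, here is a single-head NumPy sketch of a gated v2a cross-attention: queries come from the Da = 1024 action stream, keys and values are projected from the 3072-wide cached f0 tokens, and a zero-initialized scalar gate makes the block an identity map at initialization. The class name and initialization are illustrative assumptions, not the repo's CrossAttention.

```python
import numpy as np


class GatedV2ACrossAttention:
    """Single-head sketch: action tokens (width Da) read cached f0 tokens (width 3072)."""

    def __init__(self, d_action=1024, d_video=3072, seed=0):
        rng = np.random.default_rng(seed)
        s = 0.02
        self.w_q = rng.normal(0, s, (d_action, d_action))
        self.w_k = rng.normal(0, s, (d_video, d_action))  # K projection 3072 -> Da
        self.w_v = rng.normal(0, s, (d_video, d_action))  # V projection 3072 -> Da
        self.w_o = rng.normal(0, s, (d_action, d_action))
        self.gate = 0.0  # zero-init: the residual branch is off at step 0

    def __call__(self, x, f0):
        # x: [B, T_action, Da]; f0: [B, V*n_f0, d_video]
        q = x @ self.w_q
        k = f0 @ self.w_k
        v = f0 @ self.w_v
        scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
        scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
        attn = np.exp(scores)
        attn /= attn.sum(axis=-1, keepdims=True)
        return x + self.gate * ((attn @ v) @ self.w_o)
```

Because the gate starts at zero, adding the block cannot perturb the action expert's initial output; the video features only start to matter once training moves the gate.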

f0 token surgery

WanTransformer3DModel.forward does not expose per-block hidden states (there is no output_hidden_states), so we fork its block loop (copying WanTransformer3DModel.forward) and, after each of the 30 blocks, slice out the first-frame latent tokens. Token order is frame-major (patch_embedding is a Conv3d whose output goes through .flatten(2).transpose(1, 2)), so f0 is the first n_f0 = (H_lat/2)·(W_lat/2) tokens.
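A NumPy stand-in for the token ordering, assuming the patch embedding yields a [B, D, T_lat, H_lat/2, W_lat/2] grid, as a Conv3d with a (1, 2, 2) patch would. Row-major flattening makes the latent-frame index vary slowest, which is why the first n_f0 tokens are exactly latent-frame 0. In the forked loop the same slice would be taken from hidden_states after every block to fill f0_list; the function names are hypothetical.

```python
import numpy as np


def tokens_from_conv_grid(grid):
    """NumPy equivalent of patch_embedding(x).flatten(2).transpose(1, 2).

    grid: Conv3d output [B, D, T_lat, H_lat/2, W_lat/2] -> tokens [B, N, D].
    """
    b, d = grid.shape[:2]
    return grid.reshape(b, d, -1).transpose(0, 2, 1)


def slice_f0(tokens, h_tok, w_tok):
    """Frame-major order puts latent-frame 0 first: keep the first n_f0 tokens."""
    n_f0 = h_tok * w_tok
    return tokens[:, :n_f0]
```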

Cross-attn realization

Clean-action conditioning into Wan

Wan TI2V conditions on the initial image by latent inpainting plus per-token 2D timestep masking (expand_timesteps), not through a separate cross-attn stream. We adopt the same mechanism for f0 = the current observation: place the clean current-obs latent at latent-frame 0, set its per-token timestep to 0, and form latent_model_input = (1 − mask)·cond + mask·noisy, where mask is 0 on latent-frame 0 and 1 elsewhere. Action conditioning of the video is added to Wan's timestep_proj/AdaLN temb per latent frame through a zero-init action_to_temb hook, time-shifted so that latent-frame 0 receives zero action; f0 therefore stays a leak-free, policy-readable observation.
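The inpainting mix, the per-token timesteps, and the action time shift can be sketched in NumPy as below. Shapes follow the spec (48 channels, 16×16 latents at 256px, 2×2 spatial patches). The assumption that actions have already been pooled to one embedding per latent frame, with entry k driving the transition from latent frame k to k+1, is a hypothetical convention for illustration, as are the function names.

```python
import numpy as np


def first_frame_mask(t_lat, h_lat, w_lat):
    """1 = noisy/generated, 0 = clean conditioning (latent-frame 0)."""
    mask = np.ones((1, t_lat, h_lat, w_lat), dtype=np.float32)
    mask[:, 0] = 0.0
    return mask


def inpaint_input(cond, noisy, mask):
    # latent_model_input = (1 - mask) * cond + mask * noisy; mask broadcasts over channels
    return (1.0 - mask) * cond + mask * noisy


def per_token_timesteps(mask, t, patch=2):
    # one timestep per token: subsample the mask to the patch grid, scale by t,
    # flatten frame-major so it lines up with the token order
    return (mask[0][:, ::patch, ::patch] * t).reshape(-1)


def shifted_action_temb(action_emb):
    # action_emb: [B, T_lat, D]; latent frame 0 gets zeros, frame k >= 1 gets entry k-1
    # (under this convention the last entry is unused)
    out = np.zeros_like(action_emb)
    out[:, 1:] = action_emb[:, :-1]
    return out
```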

03

Wan VAE / latent pipeline

The core break: the repo's contract is per-frame 2D SD latents [T, V, 4, 32, 32] sliced by frame index. Wan's VAE is 3D-causal: T_lat = 1 + (T-1)//4, z_dim=48, spatial 16×. On-disk latents become [T_lat, V, 48, 16, 16]; the naive per-frame slice is wrong for Wan.

| Property | Wan2.2 VAE | SD VAE |
|---|---|---|
| latent channels | 48 | 4 |
| spatial compression | 16× (256px → 16×16) | 8× |
| temporal compression | 4×, causal: T_lat = 1 + (T−1)//4 | none (per-frame 2D) |
| normalization | per-channel (z − latents_mean) / latents_std | scalar scaling_factor |

Precompute rather than encode on the fly: a new tool, src/tools/precompute_latents_wan.py (a fork; the SD tool stays untouched), encodes each view independently in 3D over the whole episode and stores normalized fp16 latents [T_lat, V, 48, 16, 16]. meta.json records z_dim, the scale factors, latents_mean/std, and the frame_to_latent map.
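A sketch of the per-channel normalization and on-disk layout, assuming raw per-view latents of shape [T_lat, 48, 16, 16] and length-48 latents_mean/latents_std vectors as recorded in meta.json (diffusers' AutoencoderKLWan exposes values under these names in its config). The function names are hypothetical.

```python
import numpy as np


def _per_channel(values):
    # reshape a length-C vector so it broadcasts over [T_lat, V, C, H, W]
    return np.asarray(values, dtype=np.float32).reshape(1, 1, -1, 1, 1)


def stack_views(per_view_latents):
    """Each view is encoded independently as [T_lat, C, H, W]; stack -> [T_lat, V, C, H, W]."""
    return np.stack(per_view_latents, axis=1)


def normalize_wan_latents(z, latents_mean, latents_std):
    """(z - mean) / std per channel, stored as fp16 (no scalar scaling_factor)."""
    return ((z - _per_channel(latents_mean)) / _per_channel(latents_std)).astype(np.float16)


def denormalize_wan_latents(zn, latents_mean, latents_std):
    """Inverse mapping, e.g. before wan_decode."""
    return zn.astype(np.float32) * _per_channel(latents_std) + _per_channel(latents_mean)
```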

Parallel cache key + slice logic: a new mixture knob latents_backbone (default sdvae) switches the latent dir suffix (<name>_wanvae_<size>) and dispatches precompute. Data sources gain a temporal_compression field (default 1) mapping a raw-frame window to its latent-frame window; with 1 the behavior is bit-identical to today.
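The raw-frame to latent-frame bookkeeping behind temporal_compression fits in a few lines. This sketch assumes the causal grouping implied by T_lat = 1 + (T−1)//4 (latent 0 holds frame 0 alone, each later latent holds the next 4 frames); the function names are hypothetical. With tc = 1 every mapping collapses to the identity, which is the bit-identical default. It only reproduces index arithmetic: for a window starting after frame 0, the first latent of the slice also summarizes the preceding raw frames of its group, which this mapping does not address.

```python
def num_latent_frames(num_frames: int, tc: int = 4) -> int:
    """Causal VAE: T_lat = 1 + (T - 1) // tc."""
    return 1 + (num_frames - 1) // tc


def frame_to_latent(frame: int, tc: int = 4) -> int:
    """Latent 0 holds raw frame 0 alone; latent k >= 1 holds frames tc*(k-1)+1 .. tc*k."""
    return 0 if frame == 0 else (frame - 1) // tc + 1


def latent_window(start_frame: int, num_frames: int, tc: int = 1) -> tuple[int, int]:
    """Half-open latent window [lo, hi) covering raw frames [start, start + num_frames)."""
    lo = frame_to_latent(start_frame, tc)
    hi = frame_to_latent(start_frame + num_frames - 1, tc) + 1
    return lo, hi
```

For example, the 33-frame window of the model config maps to 9 latent frames, and with tc = 1 a window [7, 17) stays [7, 17).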

04

Multiview

Wan models a single video stream, while the repo carries V views (e.g. 2 wrist cams). Encode each view independently through the Wan VAE, then concatenate the views' patch tokens along the sequence with a learned per-view embedding, matching the existing view_embed + (v p) contract. The f0 tokens become V·n_f0 (every view's frame-0 tokens). Do not concatenate views along time: that mixes views temporally and breaks f0. The first deliverable runs V=1 to de-risk; lift to V=2 once V=1 trains.
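A NumPy sketch of the view token-concat, assuming each view has already been patchified to [B, N, D] tokens and that view_embed is a [V, D] array of learned vectors. Because the layout is (v p), each view's frame-0 tokens form a contiguous block at the start of that view's span, so the V·n_f0 f0 tokens can be gathered with a reshape. concat_views and gather_f0 are hypothetical helper names.

```python
import numpy as np


def concat_views(view_tokens, view_embed):
    """view_tokens: list of V arrays [B, N, D]; view_embed: [V, D].

    Returns [B, V*N, D] in (v p) order: all of view 0's tokens, then view 1's, ...
    """
    return np.concatenate([t + view_embed[v] for v, t in enumerate(view_tokens)], axis=1)


def gather_f0(tokens, num_views, n_f0):
    """Pick each view's first n_f0 (frame-0) tokens -> [B, V*n_f0, D]."""
    b, n, d = tokens.shape
    per_view = tokens.reshape(b, num_views, n // num_views, d)
    return per_view[:, :, :n_f0].reshape(b, num_views * n_f0, d)
```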

05

Flow-matching joint loss

Wan is flow-matching (UniPC flow_prediction, flow_shift=5.0); the Gaussian compute_joint_loss stays untouched. A parallel compute_joint_loss_flow lands in train_wan_experiment.py:

| Component | Rule |
|---|---|
| Forward | x_t = σ·noise + (1−σ)·x0; target = noise − x0 (velocity). Video and action get independently sampled σ (two schedulers, per FastWAM). |
| Video | Latent-frame 0 is kept clean (per-token timestep 0) and dropped from the loss; only future latent frames are supervised, via Wan's per-token 2D timestep path. |
| Action | The full chunk is noised at its own σ; flow target noise_a − a0; MSE masked by action_mask > 0.5. |
| Total | λ_video·L_video + λ_action·L_action (defaults 1.0 / 1.0). |
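A NumPy sketch of the loss arithmetic in the table, assuming the model outputs v_pred and a_pred have already been computed from inputs noised with independently sampled σ_video and σ_action. The function name and argument layout are hypothetical; it shows latent-frame 0 dropped from the video term and the action MSE averaged over valid steps only.

```python
import numpy as np


def compute_joint_loss_flow_sketch(v_pred, a_pred, video_x0, video_noise,
                                   action_x0, action_noise, action_mask,
                                   lam_video=1.0, lam_action=1.0):
    """v_pred/video_*: [B, T_lat, C, H, W]; a_pred/action_*: [B, T_a, A]; action_mask: [B, T_a]."""
    # Video: velocity target; latent-frame 0 is clean conditioning, so it is excluded.
    v_target = video_noise - video_x0
    l_video = np.mean((v_pred[:, 1:] - v_target[:, 1:]) ** 2)
    # Action: velocity target; average the squared error over valid steps only.
    a_target = action_noise - action_x0
    valid = (action_mask > 0.5).astype(a_pred.dtype)[..., None]
    denom = valid.sum() * a_pred.shape[-1]
    l_action = float(((a_pred - a_target) ** 2 * valid).sum() / denom) if denom > 0 else 0.0
    total = lam_video * l_video + lam_action * l_action
    return total, l_video, l_action
```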

A new wrapper src/diffusion/flow_match.py holds add_noise, training_target, Euler step, and σ shifting — gaussian_diffusion.py untouched.
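The four helpers listed for flow_match.py might look like the NumPy sketch below (names assumed). The shift σ' = s·σ / (1 + (s−1)·σ) is the standard flow-matching timestep shift used by SD3/Wan-style schedulers; that flow_match.py uses exactly this form is an assumption.

```python
import numpy as np


def add_noise(x0, noise, sigma):
    """x_t = sigma * noise + (1 - sigma) * x0 (sigma = 1 is pure noise)."""
    return sigma * noise + (1.0 - sigma) * x0


def training_target(x0, noise):
    """Velocity target dx_t/dsigma = noise - x0."""
    return noise - x0


def shift_sigma(sigma, shift=5.0):
    """Timestep shift; with shift > 1 it pushes sigmas toward 1 (more noise)."""
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)


def euler_step(x, velocity, sigma, sigma_next):
    """One ODE step toward smaller sigma: x <- x + (sigma_next - sigma) * v."""
    return x + (sigma_next - sigma) * velocity
```

With the exact velocity, Euler integration of the straight-line path recovers x0 for any step schedule, which makes a cheap unit test for the helpers.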

06

f0-only policy sampling — prefill cache

Diagram: one big forward + N tiny action steps
Step 1 (once) · prefill_f0: one 5B Wan forward over the clean-f0 latent → cache f0_list (30 layers).
Step 2 (×N) · action_from_cache: a cheap action-expert step, integrated with the Euler ODE update a ← a + (σ_next − σ)·a_out.

The existing dfot_action_sample is coupled to Gaussian DDIM (alphas_cumprod), so a parallel dfot_action_sample_flow (an Euler ODE over sigmas) is added and df_sample.py stays untouched. Prefilling is what makes the 5B policy tractable: the expensive Wan forward runs once per observation instead of once per denoising step. It is the analogue of FastWAM's prefill_video_cache and identical in spirit to the current Omni prefill. OmniWanTrainingModule.policy_sample mirrors the Omni version and reuses _denorm_action unchanged.
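A sketch of the prefill-then-integrate loop behind dfot_action_sample_flow, with prefill_f0 and action_from_cache passed in as callables. Assumptions: the sampler starts from Gaussian noise at σ = 1, uses the shifted linear σ schedule, and passes σ directly where the real action_from_cache takes t_a (converting σ to Wan's timestep scale is not shown). The function name is hypothetical.

```python
import numpy as np


def dfot_action_sample_flow_sketch(prefill_f0, action_from_cache, f0_latent,
                                   action_shape, num_steps=10, shift=5.0, seed=0):
    """Prefill once, then integrate the action ODE from sigma=1 (noise) to sigma=0."""
    cache = prefill_f0(f0_latent)  # the single expensive 5B forward
    base = np.linspace(1.0, 0.0, num_steps + 1)
    sigmas = shift * base / (1.0 + (shift - 1.0) * base)  # shifted schedule, 1 -> 0
    a = np.random.default_rng(seed).normal(size=action_shape)
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        a_out = action_from_cache(a, sigma, cache)  # cheap action-expert call
        a = a + (sigma_next - sigma) * a_out  # Euler: a <- a + (sigma_next - sigma) * a_out
    return a
```

The expert is evaluated only at σ > 0, and the video tower never runs inside the loop, so per-step cost is that of the action expert alone.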

07

PEFT & compute plan

| Component | Treatment |
|---|---|
| Wan transformer | LoRA r=16–32, α=2r on attn q/k/v/out + FFN projections; base frozen; gradient checkpointing. Tens of M trainable. Add peft to pyproject deps. |
| Action expert + v2a + action_to_temb | Fully trained (small, from scratch). |
| Wan VAE | Frozen, fp32; not resident during training (latents are precomputed). |
| umT5 text encoder (~11GB) | Frozen; not resident (null-prompt embedding precomputed as a buffer). |
| Precision | bf16-mixed (repo default). |
| Parallelism | Auto-DDP as today; LoRA fits on 2–4× H200 (143GB), 8 for throughput. A full 5B fine-tune would need FSDP FULL_SHARD (~40GB of AdamW states alone); going LoRA-first avoids this. |
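The "tens of M trainable" estimate can be checked by counting LoRA parameters, r·(d_in + d_out) per adapted Linear. The sketch assumes every block adapts self-attention and text cross-attention q/k/v/out (all 3072 → 3072, assuming the text context is projected to the model width before attention) plus the two FFN projections; dropping cross-attention is an option. It also reproduces the ~40GB AdamW figure for a nominal 5e9 parameters.

```python
def lora_params(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) per adapted Linear."""
    return r * (d_in + d_out)


def wan5b_lora_params(r: int, hidden: int = 3072, ffn: int = 14336,
                      layers: int = 30, include_cross_attn: bool = True) -> int:
    attn = 4 * lora_params(hidden, hidden, r)  # q, k, v, out
    n_attn = 2 if include_cross_attn else 1  # self-attn (+ text cross-attn)
    ffn_p = lora_params(hidden, ffn, r) + lora_params(ffn, hidden, r)  # up + down
    return layers * (n_attn * attn + ffn_p)


ADAMW_BYTES_PER_PARAM = 8  # fp32 exp_avg + exp_avg_sq
full_ft_adamw_gb = 5e9 * ADAMW_BYTES_PER_PARAM / 1e9  # optimizer moments only
```

Under these assumptions r=16 gives about 40M and r=32 about 81M trainable parameters (about 29M and 57M without cross-attention), consistent with "tens of M".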
08

Configs & smoke tests

Two new config files, siblings of model/omni_s2.yaml / experiment/train_omni.yaml (which stay untouched):

# src/configs/model/omni_wan_5b.yaml
arch: Omni-Wan-5B
image_size: 256
latent_channels: 48
num_frames: 33          # -> T_lat = 9
action_expert_hidden: 1024
peft: {type: lora, r: 32, alpha: 64}
wan_repo: Wan-AI/Wan2.2-TI2V-5B-Diffusers

# src/configs/experiment/train_omni_wan.yaml
module: omni_wan
mixture: {enabled: true, latents: auto, latents_backbone: wanvae}
diffusion: {objective: flow_match, flow_shift: 5.0, action_loss_weight: 1.0}
infra: {gradient_checkpointing: true, vae_precision: fp32}
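A hypothetical consistency check for the two configs, operating on plain dicts (e.g. the output of OmegaConf.to_container). It derives T_lat, the latent side length, and n_f0 from the fields above and rejects an arch/module mismatch. The requirement that num_frames be of the form 4k+1 is an assumption that avoids silently dropping trailing frames in 1 + (T−1)//4; the function name is illustrative.

```python
def derive_wan_dims(model_cfg: dict, exp_cfg: dict, tc: int = 4,
                    vae_stride: int = 16, patch: int = 2) -> dict:
    """Derive latent sizes from the model/experiment configs; reject inconsistent pairs."""
    arch, module = model_cfg["arch"], exp_cfg["module"]
    if ("Omni-Wan" in arch) != (module == "omni_wan"):
        raise ValueError(f"arch {arch!r} and module {module!r} do not match")
    if model_cfg.get("latent_channels", 48) != 48:
        raise ValueError("Wan2.2 VAE latents have 48 channels")
    frames, size = model_cfg["num_frames"], model_cfg["image_size"]
    if (frames - 1) % tc != 0:
        raise ValueError(f"num_frames={frames} is not of the form {tc}k+1")
    if size % (vae_stride * patch) != 0:
        raise ValueError(f"image_size={size} must be a multiple of {vae_stride * patch}")
    latent_hw = size // vae_stride
    return {"t_lat": 1 + (frames - 1) // tc, "latent_hw": latent_hw,
            "n_f0": (latent_hw // patch) ** 2}
```

For the configs above this yields T_lat = 9, 16×16 latents, and n_f0 = 64, matching the shapes used throughout the spec.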
| # | Smoke test (scaffold-first) | Asserts |
|---|---|---|
| 1 | Import/registry | get_models returns OmniWan for Omni-Wan-5B; NanoWM/Omni still resolve without importing peft |
| 2 | Load | from_pretrained under the HF_HOME redirect; config has in/out 48, 30 layers, 3072 hidden |
| 3 | Forward (tiny) | v_pred shape == input shape; a_pred is [B, T_action, A] and == 0 at init; f0_list has length 30, each [B, V·64, 3072] |
| 4 | VAE round-trip | wan_encode/wan_decode outputs are finite; decoded shape [B, 3, T, H, W] |
| 5 | Prefill equivalence | forward_joint(...)[1] == action_from_cache(a_t, t_a, prefill_f0(...)); the load-bearing inference-correctness check |
| 6 | Loss | compute_joint_loss_flow is finite; action loss is 0 when action_mask == 0 |
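The sketch below illustrates what check 5 asserts, using a toy NumPy model rather than Wan: token-mixing layers stand in for the video tower, and a readout over the per-layer f0 tokens stands in for the action expert. The equality holds only if frame-0 tokens never read later-frame tokens inside forward_joint; with isolate_f0=False the two paths diverge. Since Wan's blocks use full attention, the real check presumably needs frame-0 tokens masked off from later frames in forward_joint; the spec does not say how that is enforced, so treat this as an assumption to verify.

```python
import numpy as np


class ToyJoint:
    """Toy stand-in (not Wan): token-mixing layers + an action readout over per-layer f0."""

    def __init__(self, n_tokens=12, n_f0=4, d=8, layers=3, isolate_f0=True, seed=0):
        rng = np.random.default_rng(seed)
        mix = rng.normal(0, 0.3, (n_tokens, n_tokens))
        if isolate_f0:
            mix[:n_f0, n_f0:] = 0.0  # frame-0 rows never read later-frame tokens
        self.mix, self.n_f0 = mix, n_f0
        self.ws = [rng.normal(0, 0.3, (d, d)) for _ in range(layers)]
        self.readout = rng.normal(0, 0.3, (d, d))

    def _tower(self, h, mix):
        f0_list = []
        for w in self.ws:
            h = np.tanh(mix @ h @ w)
            f0_list.append(h[: self.n_f0])  # slice f0 after every block
        return h, f0_list

    def prefill_f0(self, f0_tokens):
        n = self.n_f0
        return self._tower(f0_tokens, self.mix[:n, :n])[1]

    def action_from_cache(self, a_t, t_a, f0_list):
        ctx = sum(f.mean(axis=0) for f in f0_list)  # pooled f0 context, shape [d]
        return a_t * (1.0 - t_a) + ctx @ self.readout

    def forward_joint(self, tokens, a_t, t_a):
        h, f0_list = self._tower(tokens, self.mix)
        return h, self.action_from_cache(a_t, t_a, f0_list)
```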
09

Untouched vs additive

| File | Change |
|---|---|
| src/models/nanowm.py · omni.py · gaussian_diffusion.py · df_sample.py · respace.py · vae_ops.py · compute_joint_loss | untouched |
| src/models/__init__.py | additive: merge OmniWan_models, specialize guard, Wan-only knobs |
| train_experiment.py module dispatch | additive: omni_wan branch + de-substring guard |
| mixture_dataset.py · world_model_dataset.py · lerobot_data_source.py | additive: latents_backbone key + compression-aware latent slice (default no-op) |
| src/models/omni_wan.py · src/experiments/train_wan_experiment.py · src/diffusion/flow_match.py · src/utils/wan_vae_ops.py · src/tools/precompute_latents_wan.py | NEW |
| src/configs/model/omni_wan_5b.yaml · src/configs/experiment/train_omni_wan.yaml | NEW |
| pyproject.toml | add peft; set HF_HOME in the activation env |