Training the Nano World Model

End-to-end training spec: the PyTorch Lightning loop, the Hydra config system, and the three design axes we ablated head-to-head on RT-1 (prediction target, action injection, model scale). The winning settings are baked into experiment=default: pred-v · additive injection · cosine schedule + ZTSNR (zero terminal SNR). This page is the HTML rendering of docs/training.md + docs/config_system.md.

00

Setup & quick start

conda env create -f environment.yml && conda activate nanowm

# data + output paths (or use src/configs/local/paths.yaml)
export DATASET_DIR=/path/to/dino_wm_data    # DINO-WM envs
export CSGO_DATA_DIR=/path/to/csgo          # CSGO
export RT1_DATA_ROOT=/path/to/rt1_fractal   # RT-1 (LeRobot fractal)
export RESULTS_DIR=/path/to/results         # checkpoints + logs

# i3d torchscript for periodic FID/FVD during training (one-time)
mkdir -p pretrained_models/i3d
curl -L "https://www.dropbox.com/scl/fi/c5nfs6c422nlpj880jbmh/i3d_torchscript.pt?rlkey=x5xcjsrz0818i4qxyoglp5bb8&dl=1" \
    -o pretrained_models/i3d/i3d_torchscript.pt
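Before launching a run it can help to confirm that the path variables exported above are set and point at real directories. The following is a minimal preflight sketch, not part of the repository: the helper name `check_paths` and the tuple `PATH_VARS` are hypothetical, and which variables a given run actually needs depends on the dataset.

```python
import os

# Variable names are taken from the exports above; the helper itself is hypothetical.
PATH_VARS = ("DATASET_DIR", "CSGO_DATA_DIR", "RT1_DATA_ROOT", "RESULTS_DIR")


def check_paths(environ=os.environ):
    """Return {var: status} where status is 'unset', 'missing', or 'ok'."""
    report = {}
    for var in PATH_VARS:
        value = environ.get(var)
        if not value:
            report[var] = "unset"      # variable not exported (or empty)
        elif not os.path.isdir(value):
            report[var] = "missing"    # exported, but the directory does not exist
        else:
            report[var] = "ok"
    return report


if __name__ == "__main__":
    for var, status in check_paths().items():
        print(f"{var:15s} {status}")
```

An "unset" entry is not necessarily an error: the config falls back to defaults such as ./data and ./results for some of these paths.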

All runs go through src/main.py with an experiment / dataset / model triple:

# CSGO — 4x GPU, 50k steps
python src/main.py experiment=csgo dataset=game/csgo model=nanowm_l2_csgo

# RT-1 (fractal) — main run, NanoWM-B/2
python src/main.py experiment=rt1 dataset=rt1/rt1 model=nanowm_b2

# DINO-WM PushT
python src/main.py experiment=dino_wm_pusht dataset=dino_wm/pusht model=nanowm_b2

# Resume
python src/main.py experiment=csgo dataset=game/csgo model=nanowm_l2_csgo \
    resume_from_checkpoint=<path/to/ckpt>

Outputs land under ${RESULTS_DIR}/<run_dir>/: .hydra/ (composed config snapshot), checkpoints/latest/ (overwritten every 1k steps), checkpoints/across_timesteps/ (saved every 10k steps), and tb/ (TensorBoard logs). For W&B logging, set wandb.enabled=true and provide WANDB_ENTITY / WANDB_PROJECT.

01

Training loop, in one pass

One optimizer step

Input: video clip, [B, T, 3, H, W], sampled from the dataset
Encode: per-frame latents from the frozen VAE · per-frame diffusion timesteps (logit-normal, SD3-style)
Model: NanoWM transformer · factorized spatial/temporal blocks · action-conditioned
Loss: prediction-target loss (v / x / ε) · AdamW lr=1e-4, warmup 1000, cosine decay
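The learning-rate settings in the step above (lr=1e-4, warmup 1000, cosine decay) can be sketched as a plain function of the step index. This is an illustration under assumptions, not the repository's scheduler: it assumes linear warmup from zero, a half-cosine decay that reaches `min_lr` at `max_steps`, and the 1M-step horizon of the default profile. The function name `lr_at` and the `min_lr` parameter are hypothetical.

```python
import math


def lr_at(step, base_lr=1e-4, warmup_steps=1000, max_steps=1_000_000, min_lr=0.0):
    """Linear warmup to base_lr, then half-cosine decay to min_lr at max_steps."""
    if step < warmup_steps:
        # Assumed: warmup ramps linearly from 0 to base_lr.
        return base_lr * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, max_steps - warmup_steps))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for step in (0, 500, 1000, 500_000, 1_000_000):
    print(step, lr_at(step))
```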

PyTorch Lightning drives the loop. Validation runs every val_every_n_steps (default 1k); FID/FVD every metrics.log_every_n_train_steps (default 5k). Key knobs: experiment.training.{batch_size, max_steps, gradient_clip_norm}, experiment.diffusion.{pred_name, noise_schedule, zero_terminal_snr, snr_gamma, timestep_sampling}, experiment.infra.{mixed_precision, num_workers, compile}.
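The per-frame, logit-normal timestep sampling mentioned in the loop can be illustrated with the standard library alone. The sketch below draws one continuous t in (0, 1) per frame as the sigmoid of a Gaussian sample. The function name `sample_timesteps` and the `mean=0.0`, `std=1.0` defaults are assumptions for illustration; the actual values and any discretization used by experiment.diffusion.timestep_sampling are not stated in this spec.

```python
import math
import random


def sample_timesteps(batch_size, num_frames, mean=0.0, std=1.0, rng=random):
    """Draw one t in (0, 1) per frame: t = sigmoid(z), z ~ Normal(mean, std)."""
    return [
        [1.0 / (1.0 + math.exp(-rng.gauss(mean, std))) for _ in range(num_frames)]
        for _ in range(batch_size)
    ]


# Each frame in a clip gets its own noise level, so the result is a [B, T] grid.
timesteps = sample_timesteps(batch_size=2, num_frames=4, rng=random.Random(0))
for row in timesteps:
    print([round(t, 3) for t in row])
```

Compared with uniform sampling, the logit-normal density concentrates samples around intermediate noise levels and thins out near t=0 and t=1.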

02

Config system (Hydra)

src/configs/config.yaml composes one option from each group; CLI overrides beat everything.

defaults:
  - model: nanowm_b2
  - dataset: dino_wm/point_maze
  - experiment: default
  - planning: base

dataset_dir:    ${oc.env:DATASET_DIR,./data}
csgo_data_dir:  ${oc.env:CSGO_DATA_DIR,./data/csgo}
vae_model_path: ${oc.env:VAE_MODEL_PATH,stabilityai/sd-vae-ft-mse}
results_dir:    ${oc.env:RESULTS_DIR,./results}

| Profile | What it sets |
|---|---|
| default | Base training (1M steps, lr=1e-4, bs=8, pred-v + cosine + ZTSNR) |
| csgo | lr=1e-5, bs=6, max_steps=50k (CSGO-specific) |
| rt1 | RT-1 main training defaults |
| ablation_rt1 | RT-1 ablation arms (50k steps) |
| dino_wm_{env} | DINO-WM per-env overrides (point_maze, pusht, wall, rope, granular) |
| evaluate_only | tasks=[evaluate], full validation set |
| planning | tasks=[planning], requires ckpt_path=... |

| Model config | Architecture | Params | Frames | Image size |
|---|---|---|---|---|
| nanowm_s2 | NanoWM-S/2 | ~40M | 4 | 256 |
| nanowm_b2 | NanoWM-B/2 (default) | ~160M | 4 | 256 |
| nanowm_l2 | NanoWM-L/2 | ~460M | 4 | 256 |
| nanowm_s2_csgo | NanoWM-S/2 (CSGO) | ~40M | 4 | 320×512 |
| nanowm_l2_csgo | NanoWM-L/2 (CSGO) | ~460M | 4 | 320×512 |

Debug composition without running: python src/main.py ... --cfg job (add --package experiment.training for one section). Common errors: ConfigCompositionException (typo in group name), MissingMandatoryValue: ckpt_path (planning needs a checkpoint), unresolvable ${oc.env:DATASET_DIR} (export the env var or pass dataset_dir= on the CLI).
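The rule that CLI overrides beat everything can be pictured as a last-write-wins merge onto the composed config. The sketch below is a simplified emulation in plain Python, not Hydra's implementation: the helper `apply_overrides` is hypothetical, it handles only dotted value overrides (not group selections such as experiment=csgo), and it leaves values as strings where Hydra would parse types.

```python
import copy


def apply_overrides(config, overrides):
    """Return a copy of config with dotted key=value overrides applied last."""
    merged = copy.deepcopy(config)
    for item in overrides:
        dotted_key, _, raw = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = merged
        for name in parents:
            node = node.setdefault(name, {})  # create intermediate sections if absent
        node[leaf] = raw                      # values stay strings in this emulation
    return merged


# Values mirror the default profile (bs=8, 1M steps); the dict layout is assumed.
composed = {"experiment": {"training": {"batch_size": 8, "max_steps": 1_000_000}}}
final = apply_overrides(composed, ["experiment.training.batch_size=16"])
print(final["experiment"]["training"])
```

To see what Hydra itself composed, the --cfg job flag described above remains the authoritative check.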

03

Axis 1 — prediction target

ε / v / x prediction, each in its native schedule (cosine + ZTSNR for v / x; linear for ε — cosine + ε is numerically degenerate at t=T). RT-1, NanoWM-B/2, 50k steps.
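A small numeric sketch can make the three targets, and the trouble with ε at the terminal step, concrete. It assumes a variance-preserving cosine schedule on continuous t in [0, 1] with alpha = cos(πt/2) and sigma = sin(πt/2), which may differ from the schedule the repository implements; all function names are hypothetical. Recovering x0 from an ε prediction divides by alpha, so a small prediction error is amplified without bound as alpha approaches 0, which is one likely reading of the degeneracy noted above. Recovering x0 from a v prediction involves no division.

```python
import math


def schedule(t):
    """Assumed cosine schedule on t in [0, 1]; alpha**2 + sigma**2 == 1."""
    return math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)


def targets(x0, eps, t):
    """Noisy sample x_t and the three regression targets (scalars for clarity)."""
    alpha, sigma = schedule(t)
    x_t = alpha * x0 + sigma * eps
    return x_t, {"x": x0, "eps": eps, "v": alpha * eps - sigma * x0}


def x0_from_eps(x_t, eps_hat, t):
    alpha, sigma = schedule(t)
    return (x_t - sigma * eps_hat) / alpha  # divides by alpha, which goes to 0 at t=1


def x0_from_v(x_t, v_hat, t):
    alpha, sigma = schedule(t)
    return alpha * x_t - sigma * v_hat      # no division, stays bounded


x0, eps, err = 0.7, -0.3, 1e-3              # err is a small prediction error
for t in (0.5, 0.99, 1.0):
    x_t, tgt = targets(x0, eps, t)
    via_eps = abs(x0_from_eps(x_t, tgt["eps"] + err, t) - x0)
    via_v = abs(x0_from_v(x_t, tgt["v"] + err, t) - x0)
    # In floating point cos(pi/2) is tiny but nonzero, so t=1.0 yields a huge
    # error rather than a ZeroDivisionError.
    print(f"t={t:4.2f}  x0 error via eps: {via_eps:.3e}  via v: {via_v:.3e}")
```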

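| Target | PSNR ↑ | SSIM ↑ | LPIPS ↓ | FID ↓ | Schedule |
|---|---|---|---|---|---|
| v | 23.07 | 0.760 | 0.207 | 42.27 | cosine + ZTSNR |
| x | 23.37 | 0.783 | 0.184 | 42.99 | cosine + ZTSNR |
| ε | 21.89 | 0.739 | 0.225 | 48.86 | linear |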
python src/main.py experiment=ablation_rt1 dataset=rt1/rt1 model=nanowm_b2 \
    experiment.diffusion.pred_name=x
04

Axis 2 — action injection

Five conditioning mechanisms, shared action-embedding MLP, everything else fixed. Only the way actions enter the transformer differs.
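For orientation, three of these mechanisms can be written in their generic textbook forms on a single token vector. This is a sketch under assumptions, not the repository's modules: it uses plain Python lists instead of tensors, the function names are hypothetical, and it passes the modulation vectors in directly, whereas a real model would typically produce them by projecting the shared action embedding. adaLN-fuse and cross-attention are omitted because this spec does not define them in enough detail.

```python
import math


def layer_norm(x, eps=1e-6):
    """Normalise a vector to zero mean and unit variance (no learned affine)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def additive(h, a):
    """Additive injection: the action embedding is summed into the token."""
    return [hv + av for hv, av in zip(h, a)]


def film(h, gamma, beta):
    """FiLM: per-channel scale and shift, gamma * h + beta."""
    return [g * hv + b for hv, g, b in zip(h, gamma, beta)]


def adaln(h, scale, shift):
    """adaLN (DiT-style form): LayerNorm(h) * (1 + scale) + shift."""
    return [n * (1.0 + s) + b for n, s, b in zip(layer_norm(h), scale, shift)]


token = [1.0, 2.0, 3.0, 4.0]
action = [0.5, -1.0, 0.0, 0.25]
print(additive(token, action))
print(film(token, gamma=[2.0, 1.0, 1.0, 0.0], beta=[0.0, 0.0, 1.0, 1.0]))
print(adaln(token, scale=[0.0] * 4, shift=[0.0] * 4))
```

Additive injection needs no projection beyond the shared embedding, whereas modulation-style mechanisms usually need extra layers to produce their scale and shift vectors.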

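| RT-1 (7D EE actions) | PSNR ↑ | SSIM ↑ | LPIPS ↓ | FID ↓ | Params |
|---|---|---|---|---|---|
| additive | 23.07 | 0.760 | 0.207 | 42.27 | 158.6M |
| adaLN | 23.19 | 0.762 | 0.206 | 43.62 | 158.6M |
| adaLN-fuse | 23.10 | 0.762 | 0.206 | 43.03 | 158.6M |
| FiLM | 23.20 | 0.763 | 0.203 | 40.62 | 172.8M |
| cross-attention | 20.82 | 0.721 | 0.242 | 51.12 | 187.0M |

| PushT (2D actions, 30k steps) | PSNR ↑ | SSIM ↑ | LPIPS ↓ | FID ↓ | Extra params |
|---|---|---|---|---|---|
| additive | 26.20 | 0.962 | 0.053 | 23.89 | 0 |
| adaLN-fuse | 26.17 | 0.961 | 0.051 | 30.28 | 0 |
| adaLN | 26.09 | 0.960 | 0.053 | 26.32 | ~42.5M |
| cross-attention | 25.95 | 0.959 | 0.055 | 28.64 | ~28.3M |
| FiLM | 25.88 | 0.960 | 0.056 | 25.45 | ~14.4M |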
05

Axis 3 — model scale

Width × depth × patch-size sweep on RT-1, 50k steps. B/2 is the reference; S/2 is ~4× smaller, L/2 is ~3× larger.
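The size ratios quoted above follow directly from the parameter counts reported for this sweep. A short check, using the reported counts in millions (the L/2 figure is approximate in the source, so its ratio is approximate too):

```python
# Parameter counts in millions, as reported for the scale sweep.
params_m = {"S/2": 39.8, "B/2": 158.6, "L/2": 460.0}

b_over_s = params_m["B/2"] / params_m["S/2"]
l_over_b = params_m["L/2"] / params_m["B/2"]

print(f"B/2 vs S/2: {b_over_s:.2f}x")  # prints 3.98x
print(f"L/2 vs B/2: {l_over_b:.2f}x")  # prints 2.90x
```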

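| Architecture | Params | PSNR ↑ | SSIM ↑ | LPIPS ↓ | FID ↓ |
|---|---|---|---|---|---|
| NanoWM-S/2 | 39.8M | 22.30 | 0.739 | 0.230 | 54.95 |
| NanoWM-B/2 | 158.6M | 23.07 | 0.760 | 0.207 | 42.27 |
| NanoWM-L/2 | ~460M | 23.62 | 0.777 | 0.186 | 36.31 |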
06

Pretrained checkpoints

Best-config runs (pred-v, additive, cosine + ZTSNR) on NanoWM-B/2 unless noted. Each of the 11 ablation arms above also ships a checkpoint (on Hugging Face under knightnemo/nanowm-*).

| Domain | HF checkpoint | Steps |
|---|---|---|
| DINO-WM Point Maze | nanowm-b2-dino-wm-point-maze-30k | 30k |
| DINO-WM Wall | nanowm-b2-dino-wm-wall-15k | 15k |
| DINO-WM Rope | nanowm-b2-dino-wm-rope-15k | 15k |
| DINO-WM Granular | nanowm-b2-dino-wm-granular-15k | 15k |
| DINO-WM PushT | nanowm-b2-dino-wm-pusht-100k | 100k |
| RT-1 (fractal) | nanowm-b2-rt1-300k | 300k |
| CSGO | nanowm-l2-csgo-50k / nanowm-l2-csgo-100k | 50k / 100k |

Reproduce any ablation arm with experiment=ablation_rt1 and the matching experiment.diffusion.* / model.action_injection.type / model= overrides. Eval numbers for these checkpoints are in the evaluation spec.
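When several arms need to be launched, the command lines can be generated rather than typed. The sketch below only builds and prints the argument lists; the helper `ablation_command` and the arm labels are hypothetical. It uses only override values that appear in this spec (pred_name=x and the model= configs), since the accepted values for model.action_injection.type and noise_schedule are not listed here.

```python
import shlex


def ablation_command(overrides, dataset="rt1/rt1", model="nanowm_b2"):
    """Build the argv for one ablation arm; overrides maps Hydra key -> value."""
    argv = [
        "python", "src/main.py",
        "experiment=ablation_rt1", f"dataset={dataset}", f"model={model}",
    ]
    argv += [f"{key}={value}" for key, value in overrides.items()]
    return argv


# Arm labels are illustrative; each maps to (overrides, model config).
arms = {
    "pred_x": ({"experiment.diffusion.pred_name": "x"}, "nanowm_b2"),
    "scale_s2": ({}, "nanowm_s2"),
    "scale_l2": ({}, "nanowm_l2"),
}

for name, (overrides, model) in arms.items():
    print(f"# {name}")
    print(shlex.join(ablation_command(overrides, model=model)))
```

Passing the resulting list to subprocess.run would launch the run; printing first makes it easy to review the overrides against the tables above.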