End-to-end training spec: the PyTorch Lightning loop, the Hydra config system, and the three
design axes we ablated head-to-head on RT-1 (prediction target, action injection, model
scale). The winning settings are baked into experiment=default:
pred-v · additive injection · cosine + ZTSNR. This page is the HTML rendering of
docs/training.md and docs/config_system.md.
conda env create -f environment.yml && conda activate nanowm
# data + output paths (or use src/configs/local/paths.yaml)
export DATASET_DIR=/path/to/dino_wm_data # DINO-WM envs
export CSGO_DATA_DIR=/path/to/csgo # CSGO
export RT1_DATA_ROOT=/path/to/rt1_fractal # RT-1 (LeRobot fractal)
export RESULTS_DIR=/path/to/results # checkpoints + logs
# i3d torchscript for periodic FID/FVD during training (one-time)
mkdir -p pretrained_models/i3d
curl -L "https://www.dropbox.com/scl/fi/c5nfs6c422nlpj880jbmh/i3d_torchscript.pt?rlkey=x5xcjsrz0818i4qxyoglp5bb8&dl=1" \
-o pretrained_models/i3d/i3d_torchscript.pt
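Before launching a run it can help to confirm which of the path variables above are actually exported. The following is a minimal sketch, not part of the repo: the helper name `missing_env_vars` is hypothetical, and it only reports unset names (variables that have a fallback in the config would still resolve to that fallback).

```python
import os

# Variable names are the ones exported above; the helper itself is hypothetical.
EXPECTED_VARS = ("DATASET_DIR", "CSGO_DATA_DIR", "RT1_DATA_ROOT", "RESULTS_DIR")


def missing_env_vars(env=None, expected=EXPECTED_VARS):
    """Return the names in `expected` that are unset or empty in `env`."""
    env = os.environ if env is None else env
    return [name for name in expected if not env.get(name)]


if __name__ == "__main__":
    missing = missing_env_vars()
    if missing:
        print("unset:", ", ".join(missing))
    else:
        print("all data/output paths are set")
```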
All runs go through src/main.py with an experiment / dataset / model triple:
# CSGO — 4x GPU, 50k steps
python src/main.py experiment=csgo dataset=game/csgo model=nanowm_l2_csgo
# RT-1 (fractal) — main run, NanoWM-B/2
python src/main.py experiment=rt1 dataset=rt1/rt1 model=nanowm_b2
# DINO-WM PushT
python src/main.py experiment=dino_wm_pusht dataset=dino_wm/pusht model=nanowm_b2
# Resume
python src/main.py experiment=csgo dataset=game/csgo model=nanowm_l2_csgo \
resume_from_checkpoint=<path/to/ckpt>
Outputs land under ${RESULTS_DIR}/<run_dir>/: .hydra/ (composed
config snapshot), checkpoints/latest/ (overwritten every 1k steps),
checkpoints/across_timesteps/ (every 10k steps), and tb/ for
TensorBoard. For W&B logging, set wandb.enabled=true and export
WANDB_ENTITY / WANDB_PROJECT.
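The two checkpoint cadences can be pictured with a small helper. This is an illustrative sketch only: the function name is hypothetical, and it assumes saves fire on exact multiples of the stated intervals, which the text above does not confirm.

```python
def checkpoint_targets(step, latest_every=1_000, snapshot_every=10_000):
    """Return the checkpoint directories written at `step`.

    Assumption: a save happens when `step` is a positive multiple of the interval.
    """
    targets = []
    if step > 0 and step % latest_every == 0:
        targets.append("checkpoints/latest")  # overwritten in place
    if step > 0 and step % snapshot_every == 0:
        targets.append("checkpoints/across_timesteps")  # kept per snapshot
    return targets


if __name__ == "__main__":
    for step in (500, 3_000, 10_000):
        print(step, checkpoint_targets(step))
```

Under that assumption, a 50k-step run would refresh checkpoints/latest/ 50 times and add 5 entries to checkpoints/across_timesteps/.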
PyTorch Lightning drives the loop. Validation runs every val_every_n_steps
(default 1k); FID/FVD every metrics.log_every_n_train_steps (default 5k).
Key knobs: experiment.training.{batch_size, max_steps, gradient_clip_norm},
experiment.diffusion.{pred_name, noise_schedule, zero_terminal_snr, snr_gamma,
timestep_sampling}, experiment.infra.{mixed_precision, num_workers, compile}.
src/configs/config.yaml composes one option from each group;
CLI overrides beat everything.
defaults:
- model: nanowm_b2
- dataset: dino_wm/point_maze
- experiment: default
- planning: base
dataset_dir: ${oc.env:DATASET_DIR,./data}
csgo_data_dir: ${oc.env:CSGO_DATA_DIR,./data/csgo}
vae_model_path: ${oc.env:VAE_MODEL_PATH,stabilityai/sd-vae-ft-mse}
results_dir: ${oc.env:RESULTS_DIR,./results}
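The `${oc.env:VAR,default}` entries above read an environment variable and fall back to the given default when it is unset, and a CLI override such as `results_dir=...` wins over both. The sketch below mimics that precedence with the standard library only. It is not how Hydra or OmegaConf resolves values internally, and `resolve_paths` is a hypothetical name.

```python
import os

# (config key, environment variable, fallback) triples copied from config.yaml above.
PATH_SPECS = (
    ("dataset_dir", "DATASET_DIR", "./data"),
    ("csgo_data_dir", "CSGO_DATA_DIR", "./data/csgo"),
    ("vae_model_path", "VAE_MODEL_PATH", "stabilityai/sd-vae-ft-mse"),
    ("results_dir", "RESULTS_DIR", "./results"),
)


def resolve_paths(env=None, cli_overrides=None):
    """Resolve each key as: CLI override, else environment variable, else fallback."""
    env = os.environ if env is None else env
    cli_overrides = cli_overrides or {}
    resolved = {}
    for key, var, fallback in PATH_SPECS:
        resolved[key] = cli_overrides.get(key, env.get(var, fallback))
    return resolved


if __name__ == "__main__":
    for key, value in resolve_paths().items():
        print(f"{key}: {value}")
```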
| Profile | What it sets |
|---|---|
| default | Base training (1M steps, lr=1e-4, bs=8, pred-v + cosine + ZTSNR) |
| csgo | lr=1e-5, bs=6, max_steps=50k (CSGO-specific) |
| rt1 | RT-1 main training defaults |
| ablation_rt1 | RT-1 ablation arms (50k steps) |
| dino_wm_{env} | DINO-WM per-env overrides (point_maze, pusht, wall, rope, granular) |
| evaluate_only | tasks=[evaluate], full validation set |
| planning | tasks=[planning], requires ckpt_path=... |
| Model config | Architecture | Params | Frames | Image size |
|---|---|---|---|---|
| nanowm_s2 | NanoWM-S/2 | ~40M | 4 | 256 |
| nanowm_b2 | NanoWM-B/2 (default) | ~160M | 4 | 256 |
| nanowm_l2 | NanoWM-L/2 | ~460M | 4 | 256 |
| nanowm_s2_csgo | NanoWM-S/2 (CSGO) | ~40M | 4 | 320×512 |
| nanowm_l2_csgo | NanoWM-L/2 (CSGO) | ~460M | 4 | 320×512 |
Debug composition without running: python src/main.py ... --cfg job
(add --package experiment.training for one section). Common errors:
ConfigCompositionException (typo in group name),
MissingMandatoryValue: ckpt_path (planning needs a checkpoint),
unresolvable ${oc.env:DATASET_DIR} (export the env var or pass
dataset_dir= on the CLI).
Compares ε / v / x prediction, each trained with its native schedule: cosine + ZTSNR for v and x, linear for ε (cosine + ε is numerically degenerate at t=T). Setup: RT-1, NanoWM-B/2, 50k steps.
| Target | PSNR ↑ | SSIM ↑ | LPIPS ↓ | FID ↓ | Schedule |
|---|---|---|---|---|---|
| v | 23.07 | 0.760 | 0.207 | 42.27 | cosine + ZTSNR |
| x | 23.37 | 0.783 | 0.184 | 42.99 | cosine + ZTSNR |
| ε | 21.89 | 0.739 | 0.225 | 48.86 | linear |
python src/main.py experiment=ablation_rt1 dataset=rt1/rt1 model=nanowm_b2 \
experiment.diffusion.pred_name=x
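One common way to see why ε-prediction degrades as t approaches T is that recovering x0 from a predicted ε divides by the signal scale α_t, which goes to zero there, while the v parameterization does not. The sketch below illustrates this on scalars. It assumes a simplified continuous cosine schedule (α = cos(πt/2T), σ = sin(πt/2T)) and the usual definition v = α·ε − σ·x0. It is not the repo's scheduler code, and all function names are hypothetical.

```python
import math


def cosine_alpha_sigma(t, T):
    """Simplified cosine schedule with alpha**2 + sigma**2 == 1 and alpha -> 0 at t = T."""
    angle = 0.5 * math.pi * t / T
    return math.cos(angle), math.sin(angle)


def make_targets(x0, eps, alpha, sigma):
    """Noisy input x_t and the three regression targets for one scalar sample."""
    x_t = alpha * x0 + sigma * eps
    return x_t, {"eps": eps, "x": x0, "v": alpha * eps - sigma * x0}


def x0_from_eps(x_t, eps_hat, alpha, sigma):
    # Divides by alpha, so prediction errors are amplified by sigma / alpha.
    return (x_t - sigma * eps_hat) / alpha


def x0_from_v(x_t, v_hat, alpha, sigma):
    # No division: prediction errors are scaled by sigma <= 1.
    return alpha * x_t - sigma * v_hat


def x0_errors_near_T(t=999, T=1000, x0=0.5, eps=0.3, delta=1e-3):
    """Perturb each prediction by `delta` and return the resulting |x0 error| pair."""
    alpha, sigma = cosine_alpha_sigma(t, T)
    x_t, targets = make_targets(x0, eps, alpha, sigma)
    err_eps = abs(x0_from_eps(x_t, targets["eps"] + delta, alpha, sigma) - x0)
    err_v = abs(x0_from_v(x_t, targets["v"] + delta, alpha, sigma) - x0)
    return err_eps, err_v


if __name__ == "__main__":
    err_eps, err_v = x0_errors_near_T()
    print(f"x0 error via eps: {err_eps:.4f}, via v: {err_v:.6f}")
```

With these toy numbers, the same 1e-3 prediction error moves the recovered x0 several hundred times further under ε-prediction than under v-prediction one step before T.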
Five conditioning mechanisms, shared action-embedding MLP, everything else fixed. Only the way actions enter the transformer differs.
| RT-1 (7D EE actions) | PSNR ↑ | SSIM ↑ | LPIPS ↓ | FID ↓ | Params |
|---|---|---|---|---|---|
| additive | 23.07 | 0.760 | 0.207 | 42.27 | 158.6M |
| adaLN | 23.19 | 0.762 | 0.206 | 43.62 | 158.6M |
| adaLN-fuse | 23.10 | 0.762 | 0.206 | 43.03 | 158.6M |
| FiLM | 23.20 | 0.763 | 0.203 | 40.62 | 172.8M |
| cross-attention | 20.82 | 0.721 | 0.242 | 51.12 | 187.0M |
| PushT (2D actions, 30k steps) | PSNR ↑ | SSIM ↑ | LPIPS ↓ | FID ↓ | Extra params |
|---|---|---|---|---|---|
| additive | 26.20 | 0.962 | 0.053 | 23.89 | 0 |
| adaLN-fuse | 26.17 | 0.961 | 0.051 | 30.28 | 0 |
| adaLN | 26.09 | 0.960 | 0.053 | 26.32 | ~42.5M |
| cross-attention | 25.95 | 0.959 | 0.055 | 28.64 | ~28.3M |
| FiLM | 25.88 | 0.960 | 0.056 | 25.45 | ~14.4M |
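For orientation, two of the mechanisms in the tables above can be written in their textbook forms: additive injection adds the action embedding to every token, while FiLM applies a per-channel scale and shift derived from the action. The sketch below uses plain Python lists and generic definitions. It is not taken from the repo's modules, and the function names and shapes are assumptions made for illustration.

```python
def additive_injection(tokens, action_emb):
    """Add the same action embedding to every token (no extra parameters)."""
    return [[h + a for h, a in zip(token, action_emb)] for token in tokens]


def film_injection(tokens, scale, shift):
    """FiLM-style modulation: per-channel scale and shift predicted from the action."""
    return [[g * h + b for h, g, b in zip(token, scale, shift)] for token in tokens]


if __name__ == "__main__":
    tokens = [[1.0, 2.0], [3.0, 4.0]]  # two tokens, two channels
    print(additive_injection(tokens, [0.5, -1.0]))
    print(film_injection(tokens, scale=[2.0, 0.5], shift=[0.0, 1.0]))
```

In a real model the scale and shift would come from an extra learned projection of the action embedding, which is one plausible source of added parameters for FiLM-style variants.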
Width × depth × patch-size sweep on RT-1, 50k steps. B/2 is the reference; S/2 is ~4× smaller, L/2 is ~3× larger.
| Architecture | Params | PSNR ↑ | SSIM ↑ | LPIPS ↓ | FID ↓ |
|---|---|---|---|---|---|
| NanoWM-S/2 | 39.8M | 22.30 | 0.739 | 0.230 | 54.95 |
| NanoWM-B/2 | 158.6M | 23.07 | 0.760 | 0.207 | 42.27 |
| NanoWM-L/2 | ~460M | 23.62 | 0.777 | 0.186 | 36.31 |
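The "~4× smaller" and "~3× larger" figures follow directly from the parameter counts in the table. A short check, using 460M as a stand-in for the approximate L/2 count (the table only gives "~460M"):

```python
# Parameter counts in millions, from the table above; the L/2 value is approximate.
PARAMS_M = {"NanoWM-S/2": 39.8, "NanoWM-B/2": 158.6, "NanoWM-L/2": 460.0}


def ratio_to_reference(name, reference="NanoWM-B/2"):
    """Parameter count of `name` divided by that of the reference model."""
    return PARAMS_M[name] / PARAMS_M[reference]


if __name__ == "__main__":
    print(f"B/2 vs S/2: {1 / ratio_to_reference('NanoWM-S/2'):.2f}x")
    print(f"L/2 vs B/2: {ratio_to_reference('NanoWM-L/2'):.2f}x")
```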
Best-config runs (pred-v, additive, cosine + ZTSNR) on NanoWM-B/2 unless
noted. The 11 ablation arms above each ship a checkpoint too (HF
knightnemo/nanowm-*).
| Domain | HF checkpoint | Steps |
|---|---|---|
| DINO-WM Point Maze | nanowm-b2-dino-wm-point-maze-30k | 30k |
| DINO-WM Wall | nanowm-b2-dino-wm-wall-15k | 15k |
| DINO-WM Rope | nanowm-b2-dino-wm-rope-15k | 15k |
| DINO-WM Granular | nanowm-b2-dino-wm-granular-15k | 15k |
| DINO-WM PushT | nanowm-b2-dino-wm-pusht-100k | 100k |
| RT-1 (fractal) | nanowm-b2-rt1-300k | 300k |
| CSGO | nanowm-l2-csgo-50k / nanowm-l2-csgo-100k | 50k / 100k |
Reproduce any ablation arm with experiment=ablation_rt1 and the matching
experiment.diffusion.* / model.action_injection.type /
model= overrides. Eval numbers for these checkpoints are in the
evaluation spec.
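When launching several arms, the command line can be assembled programmatically. The helper below is a hypothetical convenience, not something the repo ships. It only joins the fixed experiment / dataset / model triple with whatever overrides are passed in, and the example uses the one override value that appears on this page (pred_name=x).

```python
import shlex


def ablation_command(model="nanowm_b2", overrides=None):
    """Build the `python src/main.py ...` command for one RT-1 ablation arm."""
    parts = [
        "python",
        "src/main.py",
        "experiment=ablation_rt1",
        "dataset=rt1/rt1",
        f"model={model}",
    ]
    for key, value in (overrides or {}).items():
        parts.append(f"{key}={value}")
    return shlex.join(parts)  # quotes any argument that needs it


if __name__ == "__main__":
    print(ablation_command(overrides={"experiment.diffusion.pred_name": "x"}))
```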