Yixuan Wang1, Rhythm Syed1, Fangyu Wu1, Mengchao Zhang2, Aykut Onol2, Jose Barreiros2, Hooshang Nayyeri3, Tony Dear1, Huan Zhang4, Yunzhu Li1
1Columbia University · 2Toyota Research Institute · 3Amazon · 4University of Illinois Urbana-Champaign
Paper | Project Page | Video | Code
github_teaser.mp4
- Interactive World Simulator for Robot Policy Training and Evaluation
Step 1: Create and activate the conda environment.
```bash
mamba env create -f conda_env.yaml
conda activate iws
```

Step 2: Install Python dependencies.

```bash
uv pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu126/
```

Step 3: Install the package in editable mode.

```bash
pip install -e .
```

Step 4: Configure your Weights & Biases entity in `configurations/config.yaml`:

```yaml
wandb:
  entity: YOUR_WANDB_ENTITY
```

Download all pretrained checkpoints with one command (requires gdown, installed automatically):

```bash
bash scripts/download_checkpoints.sh
```

This downloads 7 checkpoints to `outputs/`, each with its Hydra config:
| Directory | Task | Camera(s) |
|---|---|---|
| `outputs/pusht_cam1/` | PushT | cam1 |
| `outputs/single_grasp_cam0/` | Single Grasp | cam0 |
| `outputs/single_grasp_cam1/` | Single Grasp | cam1 |
| `outputs/bimanual_sweep_cam0/` | Bimanual Sweep | cam0 |
| `outputs/bimanual_sweep_cam1/` | Bimanual Sweep | cam1 |
| `outputs/bimanual_rope_cam0/` | Bimanual Rope | cam0 |
| `outputs/bimanual_rope_cam1/` | Bimanual Rope | cam1 |
Each directory contains `checkpoints/best.ckpt` and `.hydra/config.yaml`.
Alternatively, download from Hugging Face manually.
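If you want to sanity-check a downloaded checkpoint before running inference, a minimal sketch along the following lines works; it is not part of this repo's API, and the exact keys inside the checkpoint depend on how it was saved:

```python
# Minimal sanity check for a downloaded checkpoint directory (not a repo API).
# The run directory follows the table above; printed keys depend on the checkpoint format.
from pathlib import Path

import torch
from omegaconf import OmegaConf

run_dir = Path("outputs/pusht_cam1")

cfg = OmegaConf.load(run_dir / ".hydra" / "config.yaml")
print(OmegaConf.to_yaml(cfg))  # the resolved Hydra config used to train this checkpoint

ckpt = torch.load(run_dir / "checkpoints" / "best.ckpt",
                  map_location="cpu", weights_only=False)
print(list(ckpt.keys()))  # typically includes 'state_dict' for Lightning-style checkpoints
```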
Mini dataset (for inference and debugging β a few episodes per task):
```bash
bash scripts/download_mini_data.sh
```

Downloads to `data/mini/{pusht,single_grasp,bimanual_sweep,bimanual_rope}/`. (Hugging Face)
Full training dataset (real ALOHA):
```bash
bash scripts/download_full_data.sh
```

Downloads to `data/full/{pusht,single_grasp,bimanual_sweep,bimanual_rope,bimanual_box,single_chain_in_box}/`. (Hugging Face)
MuJoCo simulation dataset (PushT task, scripted policy):
```bash
python scripts/download_data_hf.py \
    --repo yixuan1999/interactive-world-sim-mujoco-data \
    --local_dir data/mujoco
```

Requirements: finish Download Checkpoints and Download Data (mini dataset). Hardware: the minimum requirement for inference is an RTX 2080 GPU.
Use the keyboard to teleoperate the robot through the world model (no physical robot required). Example for the PushT task:
```bash
python scripts/inference/teleoperate_keyboard.py \
    +output_dir='data/wm_demo' \
    +use_joystick=false \
    +use_dataset=false \
    +act_horizon=1 \
    +scene=real \
    "+ckpt_paths=['outputs/pusht_cam1/checkpoints/best.ckpt']" \
    dataset=real_aloha_dataset \
    dataset.dataset_dir=data/mini/pusht/val \
    "dataset.obs_keys=['camera_1_color']"
```

Controls:
- WASD: move the left end-effector
- IJKL: move the right end-effector
- "c": start recording into an HDF5 file
- "s": stop recording and save the HDF5 file
- "q": abandon the current HDF5 recording
Inference scripts for all camera views and tasks are in scripts/inference/keyboard/. Controls vary by task:
| Task | Script(s) | Keys | Action |
|---|---|---|---|
| PushT | `pusht_kybd.sh` | WASD | Move left arm XY |
| | | IJKL | Move right arm XY |
| Single Grasp | `single_grasp_cam{0,1}_kybd.sh` | WASD | Move arm XY |
| | | IK | Move arm Z |
| | | JL | Open / close gripper |
| Bimanual Sweep | `bimanual_sweep_cam{0,1}_kybd.sh` | WASD | Move left arm XY |
| | | IJKL | Move right arm XY |
| Bimanual Rope | `bimanual_rope_cam{0,1}_kybd.sh` | WASD | Move left arm XY |
| | | IJKL | Move right arm XY |
| | | QE | Move left arm Z |
| | | UO | Move right arm Z |
First, follow the ALOHA setup process here to set up the real robots. Example for Single Grasp with both cameras:
```bash
python scripts/inference/teleoperate_aloha.py \
    +output_dir='data/wm_demo' \
    +act_horizon=1 \
    +scene=single_grasp_cam_0 \
    "+ckpt_paths=['outputs/single_grasp_cam0/checkpoints/best.ckpt', 'outputs/single_grasp_cam1/checkpoints/best.ckpt']" \
    dataset=real_aloha_dataset \
    dataset.dataset_dir=data/real_aloha/single_grasp/val \
    "dataset.obs_keys=['camera_0_color', 'camera_1_color']"
```

All scripts are in `scripts/inference/aloha/`. Each task has three variants:
| Task | Script(s) | Cameras |
|---|---|---|
| Single Grasp | `single_grasp_cam0_aloha.sh` | cam0 only |
| | `single_grasp_cam1_aloha.sh` | cam1 only |
| | `single_grasp_cam0_and_cam1_aloha.sh` | both |
| Bimanual Sweep | `bimanual_sweep_cam0_aloha.sh` | cam0 only |
| | `bimanual_sweep_cam1_aloha.sh` | cam1 only |
| | `bimanual_sweep_cam0_and_cam1_aloha.sh` | both |
| Bimanual Rope | `bimanual_rope_cam0_aloha.sh` | cam0 only |
| | `bimanual_rope_cam1_aloha.sh` | cam1 only |
| | `bimanual_rope_cam0_and_cam1_aloha.sh` | both |
Interact with the world model live in your browser!
- Start the server:

  ```bash
  bash deploy/start_demo.sh
  ```

- Open https://www.yixuanwang.me/interactive_world_sim/ in your browser and click Connect Locally.
Requirements: finish Download Checkpoints and Download Data (mini dataset).
Training uses Weights & Biases for logging. Make sure your entity is configured in `configurations/config.yaml`. Below are example scripts to train the world model for the PushT task.
Train the encoder and diffusion decoder to compress RGB observations into a compact latent space.
```bash
python main.py +name=pusht_stage_1 algorithm=latent_world_model \
    experiment=exp_latent_dyn dataset=real_aloha_dataset \
    dataset.dataset_dir=data/mini/pusht \
    dataset.horizon=1 dataset.val_horizon=1 \
    dataset.obs_keys=[camera_1_color] \
    dataset.action_mode=bimanual_push \
    experiment.training.batch_size=1 \
    experiment.training.max_steps=1000005 \
    experiment.training.log_every_n_steps=100 \
    experiment.validation.limit_batch=1.0 \
    experiment.validation.batch_size=10 \
    experiment.validation.val_every_n_step=6000 \
    algorithm.latent_dim=512 algorithm.action_dim=4 \
    algorithm.training_stage=1
```

My stage 1 training report is attached here for reference.
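Conceptually, stage 1 learns an observation autoencoder: an encoder maps each RGB frame to a latent vector of size `algorithm.latent_dim`, and a diffusion decoder reconstructs the frame from that latent. The sketch below only illustrates that objective; the modules and image resolution are placeholders, not the classes used in this repo.

```python
# Schematic of the stage 1 objective (placeholder modules, not the repo's classes):
# compress an RGB frame into a 512-d latent, then reconstruct the frame from it.
import torch

latent_dim = 512                     # matches algorithm.latent_dim above
frame = torch.rand(1, 3, 224, 224)   # assumed image resolution, for illustration

encoder = torch.nn.Sequential(       # stand-in for the real encoder
    torch.nn.Flatten(), torch.nn.LazyLinear(latent_dim)
)
decoder = torch.nn.Sequential(       # stand-in for the diffusion decoder
    torch.nn.Linear(latent_dim, 3 * 224 * 224), torch.nn.Unflatten(1, (3, 224, 224))
)

z = encoder(frame)                   # compact latent state
recon = decoder(z)                   # stage 1 trains this reconstruction
loss = torch.nn.functional.mse_loss(recon, frame)
loss.backward()
```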
Train the latent dynamics model to predict future latent states from past observations and actions. Requires a Stage 1 checkpoint.
```bash
python main.py +name=pusht_stage_2 algorithm=latent_world_model \
    experiment=exp_latent_dyn dataset=real_aloha_dataset \
    dataset.dataset_dir=data/mini/pusht \
    dataset.horizon=10 dataset.val_horizon=200 \
    dataset.obs_keys=[camera_1_color] \
    dataset.action_mode=bimanual_push \
    experiment.training.batch_size=4 \
    experiment.training.max_steps=1000005 \
    experiment.training.log_every_n_steps=100 \
    experiment.validation.limit_batch=1.0 \
    experiment.validation.batch_size=2 \
    experiment.validation.val_every_n_step=30000 \
    experiment.training.checkpointing.every_n_train_steps=10000 \
    experiment.training.data.num_workers=4 \
    experiment.validation.data.num_workers=4 \
    algorithm.latent_dim=512 algorithm.action_dim=4 \
    algorithm.noise_scheduler.loss_weighting=uniform \
    algorithm.sampling_strategy=terminal_only \
    algorithm.load_ae="path_to_stage_1.ckpt" \
    algorithm.training_stage=2
```

My stage 2 training report is attached here for reference.
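Schematically, stage 2 keeps the stage 1 encoder fixed and trains a dynamics model that, given latents of past frames and the corresponding actions, predicts future latents over `dataset.horizon` steps. The names below (`dynamics`, the random stand-in latents) are placeholders for illustration, not the repo's modules or its diffusion-based training objective.

```python
# Schematic of the stage 2 objective (placeholder modules): predict the next latent
# from the current latent and action, supervised by latents from the frozen encoder.
import torch

latent_dim, action_dim = 512, 4      # match algorithm.latent_dim / algorithm.action_dim

dynamics = torch.nn.Sequential(      # stand-in for the latent dynamics model
    torch.nn.Linear(latent_dim + action_dim, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, latent_dim),
)

z_t = torch.randn(8, latent_dim)             # current-frame latents (from the frozen encoder)
a_t = torch.randn(8, action_dim)             # actions taken at those frames
z_next_target = torch.randn(8, latent_dim)   # encoder latents of the next frames

z_next_pred = dynamics(torch.cat([z_t, a_t], dim=-1))
loss = torch.nn.functional.mse_loss(z_next_pred, z_next_target)
loss.backward()
```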
Fine-tune the decoder to make it robust to latent noise.
```bash
python main.py +name=pusht_stage_3 algorithm=latent_world_model \
    experiment=exp_latent_dyn dataset=real_aloha_dataset \
    dataset.dataset_dir=data/mini/pusht \
    dataset.horizon=1 dataset.val_horizon=200 \
    dataset.obs_keys=[camera_1_color] \
    dataset.action_mode=bimanual_push \
    experiment.training.batch_size=16 \
    experiment.training.max_steps=1000005 \
    experiment.training.log_every_n_steps=100 \
    experiment.validation.limit_batch=1.0 \
    experiment.validation.batch_size=2 \
    experiment.validation.val_every_n_step=30000 \
    experiment.training.checkpointing.every_n_train_steps=10000 \
    experiment.training.data.num_workers=4 \
    experiment.validation.data.num_workers=4 \
    algorithm.latent_dim=512 algorithm.action_dim=4 \
    algorithm.noise_scheduler.loss_weighting=uniform \
    algorithm.sampling_strategy=terminal_only \
    algorithm.load_ae="path_to_stage_2.ckpt" \
    algorithm.training_stage=3
```

My stage 3 training report is attached here for reference.
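Schematically, stage 3 fine-tunes the decoder on perturbed latents so that small prediction errors from stage 2 do not degrade the rendered frames. The snippet below only illustrates the noise-injection idea with placeholder modules and an assumed noise scale.

```python
# Schematic of the stage 3 idea (placeholder decoder): fine-tune the decoder on latents
# corrupted with small Gaussian noise, mimicking stage 2 prediction error.
import torch

latent_dim = 512
decoder = torch.nn.Linear(latent_dim, 3 * 224 * 224)  # stand-in for the diffusion decoder

z_clean = torch.randn(16, latent_dim)                  # latents from the frozen encoder
target = torch.rand(16, 3 * 224 * 224)                 # corresponding ground-truth frames

z_noisy = z_clean + 0.1 * torch.randn_like(z_clean)    # assumed noise scale, for illustration
recon = decoder(z_noisy)
loss = torch.nn.functional.mse_loss(recon, target)     # decoder must stay accurate under noise
loss.backward()
```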
- For stage 1 training, the reconstruction should be almost perfect before proceeding to stage 2.
- For new tasks, you only need to change `action_mode` and `action_dim` accordingly.
- To train a good model, the play dataset should have broad action coverage (different speeds, contact modes, and positions) and minimal occlusion.
- The order of stage 2 and stage 3 can be swapped.
- Even after the validation metrics converge, it is worth training longer to get the best result.
- Stage 2 training is the most time-consuming; stage 1 takes less time than stage 2 but more than stage 3.
- About 6 hours of data (~600 episodes of 200 steps each) is typically enough for world model training.
First, follow the ALOHA setup process here to set up the real robots. Example commands for recording HDF5 episodes are shown below.
```bash
python scripts/data_collection/collect_real_aloha.py \
    --output_dir data/bimanual_push \
    --robot_sides right \
    --robot_sides left \
    --frequency 10 \
    --ctrl_mode bimanual_push \
    --total_steps 200
```

Change `ctrl_mode` for different tasks. After data collection, run the following command to put the robots to sleep safely:

```bash
python -m interactive_world_sim.real_world.robot_sleep --left --right
```

Requirements:
- ALOHA robot hardware (see `real_world/`)
- Intel RealSense cameras (configured in `real_world/aloha_extrinsics/`)
Data is saved in HDF5 format and cached as a zarr dataset for fast loading during training.
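To check what a recorded episode contains, you can open one of the HDF5 files directly. The sketch below just lists dataset names, shapes, and dtypes; the path is an example and the exact keys depend on the task and cameras used:

```python
# List every dataset (name + shape + dtype) inside one recorded HDF5 episode.
# The path is an example; point it at any .hdf5 file under your output_dir.
import h5py

episode_path = "data/bimanual_push/episode_0.hdf5"  # example path, adjust to your recording

with h5py.File(episode_path, "r") as f:
    f.visititems(
        lambda name, obj: print(name, obj.shape, obj.dtype)
        if isinstance(obj, h5py.Dataset)
        else None
    )
```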
Collect scripted demonstration data in MuJoCo simulation for the PushT task. A scripted policy automatically generates diverse motions (linear pushes, rotations, random contact, random exploration) and saves successful episodes.
Install the MuJoCo environment:

```bash
git submodule update --init --recursive
uv pip install -e external/gym-aloha/
```

Then generate data with a specific motion type (linear, rotating, random_contact, random_no_contact):
```bash
python scripts/data_collection/sim_aloha_dataset_collection_scripted.py \
    --output_dir data/mujoco/pusht/train \
    --motion_type random_no_contact
```

Use `--headless` to run without visualization. Episodes are auto-saved; only successful ones are kept based on a task-specific success function.
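The success function is task-specific. As a rough illustration of the idea only (not the criterion the script actually uses, and with made-up tolerances), a PushT-style check might compare the final T-block pose to a goal pose:

```python
# Rough illustration of a PushT-style success check (NOT the script's actual criterion):
# an episode counts as successful if the final T-block pose is close enough to the goal.
import numpy as np

def pusht_success(final_pos, goal_pos, final_angle, goal_angle,
                  pos_tol=0.03, angle_tol=0.2):
    """Positions in meters, angles in radians; tolerances are example values."""
    pos_ok = np.linalg.norm(np.asarray(final_pos) - np.asarray(goal_pos)) < pos_tol
    # wrap the angle difference to [-pi, pi) before thresholding
    angle_ok = abs((final_angle - goal_angle + np.pi) % (2 * np.pi) - np.pi) < angle_tol
    return pos_ok and angle_ok
```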
The collected data is also available on Hugging Face.
You can reuse the commands from Teleoperate from ALOHA Robot to collect data. Press "c" to start recording into an HDF5 file, "s" to stop recording and save, and "q" to abandon the recording.
- Built on Diffusion Forcing.
This repo is forked from Boyuan Chen's research template repo. By its MIT license, you must keep the above sentence in
`README.md` and the `LICENSE` file to credit the author.