WangYixuan12/interactive_world_sim
Interactive World Simulator for Robot Policy Training and Evaluation

Yixuan Wang1, Rhythm Syed1, Fangyu Wu1, Mengchao Zhang2, Aykut Onol2, Jose Barreiros2, Hooshang Nayyeri3, Tony Dear1, Huan Zhang4, Yunzhu Li1

1Columbia University   2Toyota Research Institute   3Amazon   4University of Illinois Urbana-Champaign

Paper | Project Page | Video | Code

github_teaser.mp4


πŸ”¨ Installation

Step 1: Create and activate the conda environment.

mamba env create -f conda_env.yaml
conda activate iws

Step 2: Install Python dependencies.

uv pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu126/

Step 3: Install the package in editable mode.

pip install -e .

Step 4: Configure your Weights & Biases entity in configurations/config.yaml:

wandb:
  entity: YOUR_WANDB_ENTITY

πŸ€– Inference

Download Checkpoints

Download all pretrained checkpoints with one command (the script requires gdown and installs it automatically if missing):

bash scripts/download_checkpoints.sh

This downloads 7 checkpoints to outputs/, each with its Hydra config:

| Directory | Task | Camera(s) |
| --- | --- | --- |
| outputs/pusht_cam1/ | PushT | cam1 |
| outputs/single_grasp_cam0/ | Single Grasp | cam0 |
| outputs/single_grasp_cam1/ | Single Grasp | cam1 |
| outputs/bimanual_sweep_cam0/ | Bimanual Sweep | cam0 |
| outputs/bimanual_sweep_cam1/ | Bimanual Sweep | cam1 |
| outputs/bimanual_rope_cam0/ | Bimanual Rope | cam0 |
| outputs/bimanual_rope_cam1/ | Bimanual Rope | cam1 |

Each directory contains checkpoints/best.ckpt and .hydra/config.yaml.
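After downloading, you can sanity-check the layout with a small script. This is a hypothetical helper (not part of the repo); it only assumes the directory structure listed above.

```python
from pathlib import Path

# The seven checkpoint directories listed in the table above.
EXPECTED_DIRS = [
    "pusht_cam1",
    "single_grasp_cam0", "single_grasp_cam1",
    "bimanual_sweep_cam0", "bimanual_sweep_cam1",
    "bimanual_rope_cam0", "bimanual_rope_cam1",
]

def missing_checkpoint_files(outputs_root="outputs"):
    """Return the list of expected files that are absent under outputs_root."""
    missing = []
    for name in EXPECTED_DIRS:
        for rel in ("checkpoints/best.ckpt", ".hydra/config.yaml"):
            path = Path(outputs_root) / name / rel
            if not path.exists():
                missing.append(str(path))
    return missing
```

Calling `missing_checkpoint_files()` after the download script should return an empty list; any entries it returns point at files that failed to download.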

Alternatively, download from Hugging Face manually.

Download Data

Mini dataset (for inference and debugging β€” a few episodes per task):

bash scripts/download_mini_data.sh

Downloads to data/mini/{pusht,single_grasp,bimanual_sweep,bimanual_rope}/. (Hugging Face)

Full training dataset (real ALOHA):

bash scripts/download_full_data.sh

Downloads to data/full/{pusht,single_grasp,bimanual_sweep,bimanual_rope,bimanual_box,single_chain_in_box}/. (Hugging Face)

MuJoCo simulation dataset (PushT task, scripted policy):

python scripts/download_data_hf.py \
    --repo yixuan1999/interactive-world-sim-mujoco-data \
    --local_dir data/mujoco

(Hugging Face)

Teleoperate from Keyboard

Requirements: finish Download Checkpoints and Download Data (mini dataset). Hardware: an NVIDIA RTX 2080 GPU is the minimum for inference.

Use the keyboard to teleoperate the robot through the world model (no physical robot required). Example for the PushT task:

python scripts/inference/teleoperate_keyboard.py \
  +output_dir='data/wm_demo' \
  +use_joystick=false \
  +use_dataset=false \
  +act_horizon=1 \
  +scene=real \
  "+ckpt_paths=['outputs/pusht_cam1/checkpoints/best.ckpt']" \
  dataset=real_aloha_dataset \
  dataset.dataset_dir=data/mini/pusht/val \
  "dataset.obs_keys=['camera_1_color']"

Controls: WASD moves the left end-effector; IJKL moves the right end-effector. Press "c" to start recording to an HDF5 file, "s" to stop recording and save the file, and "q" to discard the recording.

Inference scripts for all camera views and tasks are in scripts/inference/keyboard/. Controls vary by task:

| Task | Script(s) | Keys | Action |
| --- | --- | --- | --- |
| PushT | pusht_kybd.sh | WASD | Move left arm XY |
| | | IJKL | Move right arm XY |
| Single Grasp | single_grasp_cam{0,1}_kybd.sh | WASD | Move arm XY |
| | | IK | Move arm Z |
| | | JL | Open / close gripper |
| Bimanual Sweep | bimanual_sweep_cam{0,1}_kybd.sh | WASD | Move left arm XY |
| | | IJKL | Move right arm XY |
| Bimanual Rope | bimanual_rope_cam{0,1}_kybd.sh | WASD | Move left arm XY |
| | | IJKL | Move right arm XY |
| | | QE | Move left arm Z |
| | | UO | Move right arm Z |
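The keyboard mappings above can be sketched as a key-to-action table. The 4-D action layout `[left_dx, left_dy, right_dx, right_dy]` and the step size are illustrative assumptions for a bimanual XY task, not the repo's actual interface.

```python
STEP = 0.01  # displacement per key press (illustrative value)

# key -> (arm index, (dx, dy)); arm 0 = left (WASD), arm 1 = right (IJKL)
KEY_DELTAS = {
    "w": (0, (0.0, +STEP)), "s": (0, (0.0, -STEP)),
    "a": (0, (-STEP, 0.0)), "d": (0, (+STEP, 0.0)),
    "i": (1, (0.0, +STEP)), "k": (1, (0.0, -STEP)),
    "j": (1, (-STEP, 0.0)), "l": (1, (+STEP, 0.0)),
}

def key_to_action(key):
    """Map one key press to [left_dx, left_dy, right_dx, right_dy]."""
    action = [0.0, 0.0, 0.0, 0.0]
    if key in KEY_DELTAS:
        arm, (dx, dy) = KEY_DELTAS[key]
        action[2 * arm] = dx
        action[2 * arm + 1] = dy
    return action
```

Unmapped keys produce a zero action, so holding nothing leaves the end-effectors in place.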

Teleoperate from ALOHA Robot

First, follow the ALOHA setup process here to set up the real robots. Example for Single Grasp with both cameras:

python scripts/inference/teleoperate_aloha.py \
  +output_dir='data/wm_demo' \
  +act_horizon=1 \
  +scene=single_grasp_cam_0 \
  "+ckpt_paths=['outputs/single_grasp_cam0/checkpoints/best.ckpt', 'outputs/single_grasp_cam1/checkpoints/best.ckpt']" \
  dataset=real_aloha_dataset \
  dataset.dataset_dir=data/real_aloha/single_grasp/val \
  "dataset.obs_keys=['camera_0_color', 'camera_1_color']"

All scripts are in scripts/inference/aloha/. Each task has three variants:

| Task | Script(s) | Cameras |
| --- | --- | --- |
| Single Grasp | single_grasp_cam0_aloha.sh | cam0 only |
| | single_grasp_cam1_aloha.sh | cam1 only |
| | single_grasp_cam0_and_cam1_aloha.sh | both |
| Bimanual Sweep | bimanual_sweep_cam0_aloha.sh | cam0 only |
| | bimanual_sweep_cam1_aloha.sh | cam1 only |
| | bimanual_sweep_cam0_and_cam1_aloha.sh | both |
| Bimanual Rope | bimanual_rope_cam0_aloha.sh | cam0 only |
| | bimanual_rope_cam1_aloha.sh | cam1 only |
| | bimanual_rope_cam0_and_cam1_aloha.sh | both |

πŸ–₯️ Local Interactive Demo

Interact with the world model live in your browser!

  1. Start the server: bash deploy/start_demo.sh

  2. Open https://www.yixuanwang.me/interactive_world_sim/ in your browser and click Connect Locally

Requirements: finish Download Checkpoints and Download Data (mini dataset).

πŸ‹οΈ Training

Training uses Weights & Biases for logging, so make sure your entity is configured in configurations/config.yaml. Below are example scripts for training the world model on the PushT task.

Stage 1: Autoencoder Training

Train the encoder and diffusion decoder to compress RGB observations into a compact latent space.

python main.py +name=pusht_stage_1 algorithm=latent_world_model \
  experiment=exp_latent_dyn dataset=real_aloha_dataset \
  dataset.dataset_dir=data/mini/pusht \
  dataset.horizon=1 dataset.val_horizon=1 \
  dataset.obs_keys=[camera_1_color] \
  dataset.action_mode=bimanual_push \
  experiment.training.batch_size=1 \
  experiment.training.max_steps=1000005 \
  experiment.training.log_every_n_steps=100 \
  experiment.validation.limit_batch=1.0 \
  experiment.validation.batch_size=10 \
  experiment.validation.val_every_n_step=6000 \
  algorithm.latent_dim=512 algorithm.action_dim=4 \
  algorithm.training_stage=1
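As a shape-level sketch of what stage 1 learns: an encoder compresses an RGB frame to a 512-D latent (matching `algorithm.latent_dim=512` above) and a decoder maps it back. The random linear maps and the 64×64 resolution below are stand-ins for illustration, not the repo's actual encoder and diffusion decoder.

```python
import numpy as np

rng = np.random.default_rng(0)
H, W, LATENT_DIM = 64, 64, 512  # resolution is illustrative; latent_dim=512 as above

# Random linear "encoder"/"decoder" weights standing in for learned networks.
enc_w = rng.normal(size=(H * W * 3, LATENT_DIM)) * 0.01
dec_w = rng.normal(size=(LATENT_DIM, H * W * 3)) * 0.01

def encode(rgb):
    """Flatten an (H, W, 3) frame and project it to a 512-D latent."""
    return rgb.reshape(-1) @ enc_w

def decode(z):
    """Project a latent back to an (H, W, 3) reconstruction."""
    return (z @ dec_w).reshape(H, W, 3)

frame = rng.random((H, W, 3))
z = encode(frame)        # shape (512,)
recon = decode(frame_z := z)  # shape (64, 64, 3)
```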

My stage 1 training report is attached here for reference.

Stage 2: Dynamics Training

Train the latent dynamics model to predict future latent states from past observations and actions. Requires a Stage 1 checkpoint.

python main.py +name=pusht_stage_2 algorithm=latent_world_model \
  experiment=exp_latent_dyn dataset=real_aloha_dataset \
  dataset.dataset_dir=data/mini/pusht \
  dataset.horizon=10 dataset.val_horizon=200 \
  dataset.obs_keys=[camera_1_color] \
  dataset.action_mode=bimanual_push \
  experiment.training.batch_size=4 \
  experiment.training.max_steps=1000005 \
  experiment.training.log_every_n_steps=100 \
  experiment.validation.limit_batch=1.0 \
  experiment.validation.batch_size=2 \
  experiment.validation.val_every_n_step=30000 \
  experiment.training.checkpointing.every_n_train_steps=10000 \
  experiment.training.data.num_workers=4 \
  experiment.validation.data.num_workers=4 \
  algorithm.latent_dim=512 algorithm.action_dim=4 \
  algorithm.noise_scheduler.loss_weighting=uniform \
  algorithm.sampling_strategy=terminal_only \
  algorithm.load_ae="path_to_stage_1.ckpt" \
  algorithm.training_stage=2
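Stage 2's idea can be sketched as an autoregressive rollout in latent space: given the current latent and an action, predict the next latent. The linear dynamics below is a toy stand-in (the repo learns this model with diffusion-based training); only the dimensions match the `algorithm.latent_dim=512` / `algorithm.action_dim=4` settings above.

```python
import numpy as np

rng = np.random.default_rng(0)
LATENT_DIM, ACTION_DIM = 512, 4  # matches latent_dim / action_dim above

# Random linear dynamics standing in for the learned latent dynamics model.
A = rng.normal(size=(LATENT_DIM, LATENT_DIM)) * 0.01
B = rng.normal(size=(ACTION_DIM, LATENT_DIM)) * 0.01

def step(z, action):
    """Predict the next latent from the current latent and action."""
    return z @ A + action @ B

def rollout(z0, actions):
    """Autoregressively roll the dynamics forward over a sequence of actions."""
    zs = [z0]
    for a in actions:
        zs.append(step(zs[-1], a))
    return zs

z0 = rng.normal(size=LATENT_DIM)
traj = rollout(z0, rng.normal(size=(10, ACTION_DIM)))  # 11 latents: z0 + 10 steps
```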

My stage 2 training report is attached here for reference.

Stage 3: Autoencoder Finetuning

Finetune the decoder to make it robust to noise in the latent space.

python main.py +name=pusht_stage_3 algorithm=latent_world_model \
  experiment=exp_latent_dyn dataset=real_aloha_dataset \
  dataset.dataset_dir=data/mini/pusht \
  dataset.horizon=1 dataset.val_horizon=200 \
  dataset.obs_keys=[camera_1_color] \
  dataset.action_mode=bimanual_push \
  experiment.training.batch_size=16 \
  experiment.training.max_steps=1000005 \
  experiment.training.log_every_n_steps=100 \
  experiment.validation.limit_batch=1.0 \
  experiment.validation.batch_size=2 \
  experiment.validation.val_every_n_step=30000 \
  experiment.training.checkpointing.every_n_train_steps=10000 \
  experiment.training.data.num_workers=4 \
  experiment.validation.data.num_workers=4 \
  algorithm.latent_dim=512 algorithm.action_dim=4 \
  algorithm.noise_scheduler.loss_weighting=uniform \
  algorithm.sampling_strategy=terminal_only \
  algorithm.load_ae="path_to_stage_2.ckpt" \
  algorithm.training_stage=3

My stage 3 training report is attached here for reference.

Empirical tips for training:

  • For stage 1, the reconstruction should be nearly perfect before proceeding to stage 2.
  • For new tasks, you only need to change action_mode and action_dim accordingly.
  • To train a good model, the play dataset should have broad action coverage (different speeds, contact modes, and positions) and minimal occlusion.
  • Stages 2 and 3 can be run in either order.
  • Even after the validation metrics converge, training longer typically yields better results.
  • Stage 2 training is the most time-consuming; stage 1 takes less time than stage 2 but more than stage 3.
  • About six hours of data (~600 episodes with 200 steps each) is typically enough for world model training.

πŸ“¦ Real-World Data Collection on ALOHA

First, follow the ALOHA setup process here to set up the real robots. An example command for recording HDF5 episodes is shown below.

python scripts/data_collection/collect_real_aloha.py \
  --output_dir data/bimanual_push \
  --robot_sides right \
  --robot_sides left \
  --frequency 10 \
  --ctrl_mode bimanual_push \
  --total_steps 200

You need to change ctrl_mode for different tasks. After data collection, run the following command to safely put the robots to sleep:

python -m interactive_world_sim.real_world.robot_sleep --left --right


Data is saved in HDF5 format and cached as a zarr dataset for fast loading during training.
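Since episodes land on disk as HDF5, a generic h5py walk is a convenient way to inspect what a recording contains. This helper is not part of the repo and makes no assumption about the repo's key names; it just lists every dataset and its shape.

```python
import h5py

def summarize_episode(path):
    """Return {dataset_name: shape} for every dataset in an HDF5 episode file."""
    summary = {}
    def visit(name, obj):
        # visititems walks nested groups; record only leaf datasets.
        if isinstance(obj, h5py.Dataset):
            summary[name] = obj.shape
    with h5py.File(path, "r") as f:
        f.visititems(visit)
    return summary
```

For a 200-step episode recorded at 10 Hz (as in the collection command above), the time axis of each dataset should be 200.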

πŸ€– Sim Data Collection (MuJoCo)

Collect scripted demonstration data in MuJoCo simulation for the PushT task. A scripted policy automatically generates diverse motions (linear pushes, rotations, random contact, random exploration) and saves successful episodes.

Install the MuJoCo environment by running:

git submodule update --init --recursive
uv pip install -e external/gym-aloha/

Then generate data with a specific motion type (linear, rotating, random_contact, or random_no_contact):

python scripts/data_collection/sim_aloha_dataset_collection_scripted.py \
    --output_dir data/mujoco/pusht/train \
    --motion_type random_no_contact

Use --headless to run without visualization. Episodes are auto-saved; only successful ones are kept based on a task-specific success function.
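The "keep only successful episodes" filtering can be sketched as below. `toy_success` is a made-up stand-in for illustration; the repo uses its own task-specific success functions.

```python
def filter_episodes(episodes, is_success):
    """Keep only episodes whose final state passes the success check."""
    return [ep for ep in episodes if is_success(ep[-1])]

def toy_success(final_state, goal=(0.5, 0.5), tol=0.05):
    """Toy check: final XY position within `tol` of a goal position."""
    dx = final_state[0] - goal[0]
    dy = final_state[1] - goal[1]
    return (dx * dx + dy * dy) ** 0.5 < tol
```

Episodes are passed through the filter after collection, so only trajectories ending near the goal are written out.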

The collected data is also available on Hugging Face.

🌎 WM Data Collection on ALOHA

You can reuse the commands from Teleoperate from ALOHA Robot to collect data. Press "c" to start recording to an HDF5 file, "s" to stop recording and save the file, and "q" to discard the recording.

Acknowledgements

This repo is forked from Boyuan Chen's research template repo. By its MIT license, you must keep the above sentence in README.md and the LICENSE file to credit the author.

About

[RSS 2026] Interactive World Simulator for Robot Policy Training and Evaluation
