Training Guide

The primary script for training the Hierarchical Reasoning Model is pretrain.py. It leverages Hydra for configuration and can be run in both single-GPU and distributed multi-GPU settings.

Running the Training Script

The script is designed to be launched with torchrun for distributed training, which is the recommended approach for larger-scale experiments on multiple GPUs.

General Command Structure (Multi-GPU):

OMP_NUM_THREADS=8 torchrun --nproc-per-node <NUM_GPUS> pretrain.py [HYDRA_OPTIONS]
  • <NUM_GPUS>: The number of GPUs to use for training (e.g., 8).
  • [HYDRA_OPTIONS]: Configuration overrides for the experiment, such as data_path or lr, passed as key=value pairs. See the Configuration Guide for more details, and the sketch below for an example of overrides in use.
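
For smaller experiments, pretrain.py can also be launched as a single-process job without torchrun. The command below is only a sketch: it assumes the script accepts the same Hydra overrides in this mode, and the data_path and lr values are purely illustrative.

# Sketch (assumption): single-GPU launch with example Hydra overrides
OMP_NUM_THREADS=8 python pretrain.py data_path=data/arc-aug-1000 lr=1e-4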

Dataset Preparation

Before starting any training run, ensure that the corresponding dataset has been built using the scripts in the dataset/ directory. For more details on this process, see the Data Pipeline Concepts page.
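
As a quick sanity check before launching, you can confirm that the built dataset directory is present. The path below is only an example; use whatever output directory your builder produced.

# Sketch: verify the built dataset directory exists (path is an example)
test -d data/arc-aug-1000 && echo "dataset found" || echo "dataset missing - run the matching script in dataset/ first"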

Example Training Commands

Below are the commands for the full-scale experiments mentioned in the README, assuming an 8-GPU setup.

ARC-1 (1,000 samples)

This experiment trains the model on the combined ARC-AGI and ConceptARC datasets.

# First, build the dataset
python dataset/build_arc_dataset.py

# Launch training
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py
Runtime: ~24 hours

ARC-2 (1,000 samples)

This experiment trains the model on the ARC-AGI-2 dataset.

# First, build the dataset
python dataset/build_arc_dataset.py --dataset-dirs dataset/raw-data/ARC-AGI-2/data --output-dir data/arc-2-aug-1000

# Launch training
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/arc-2-aug-1000
Runtime: ~24 hours (a checkpoint from around the 8-hour mark is often sufficient)

Maze 30x30 Hard (1,000 samples)

This experiment trains the model on the hard 30x30 maze dataset.

# First, build the dataset
python dataset/build_maze_dataset.py

# Launch training
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py \
  data_path=data/maze-30x30-hard-1k \
  epochs=20000 \
  eval_interval=2000 \
  lr=1e-4 \
  puzzle_emb_lr=1e-4 \
  weight_decay=1.0 \
  puzzle_emb_weight_decay=1.0
Runtime: ~1 hour

Checkpoints

By default, model checkpoints are saved to the checkpoints/ directory. The path is determined by the project_name and run_name configuration values. Checkpoints are saved periodically based on the eval_interval and checkpoint_every_eval settings in the configuration.
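
As a sketch, naming a run and checkpointing at every evaluation from the command line could look like the following. The project_name and run_name values are hypothetical, and the option names should be checked against your configuration files before use.

# Sketch: set run identifiers and checkpoint cadence (values are hypothetical examples)
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py \
  data_path=data/maze-30x30-hard-1k \
  project_name=hrm-maze \
  run_name=maze-hard-1k-run1 \
  eval_interval=2000 \
  checkpoint_every_eval=True
# Checkpoints should then appear under checkpoints/, in a subdirectory derived from project_name and run_name.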