Training Guide
The primary script for training the Hierarchical Reasoning Model is pretrain.py. It leverages Hydra for configuration and can be run in both single-GPU and distributed multi-GPU settings.
Running the Training Script
The script is designed to be launched with torchrun for distributed training, which is the recommended approach for larger-scale experiments on multiple GPUs.
General Command Structure (Multi-GPU):
OMP_NUM_THREADS=8 torchrun --nproc-per-node <NUM_GPUS> pretrain.py [HYDRA_OPTIONS]
<NUM_GPUS>: The number of GPUs to use for training (e.g., 8).
[HYDRA_OPTIONS]: Configuration overrides for the experiment, such as data_path, lr, etc. See the Configuration Guide for more details; a concrete launch with overrides is sketched below.
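As a minimal illustration of how these overrides compose, the same launch pattern can be scaled down to a single GPU by setting --nproc-per-node to 1. The override values below reuse names and paths that appear elsewhere in this guide but are otherwise placeholder assumptions, not tuned settings:
# Sketch: single-GPU launch with a few Hydra overrides (placeholder values)
OMP_NUM_THREADS=8 torchrun --nproc-per-node 1 pretrain.py \
  data_path=data/arc-2-aug-1000 \
  epochs=20000 \
  eval_interval=2000 \
  lr=1e-4
Overrides use Hydra's key=value syntax, so any value in the configuration can be changed per run without editing the configuration file itself.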
Dataset Preparation
Before starting any training run, ensure that the corresponding dataset has been built using the scripts in the dataset/ directory. For more details on this process, see the Data Pipeline Concepts page.
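As a quick sanity check before launching training (a sketch based on the Maze example later in this guide), you can confirm that the builder's output directory matches the data_path you intend to pass to pretrain.py:
# Build the maze dataset (the Maze example below trains from data/maze-30x30-hard-1k)
python dataset/build_maze_dataset.py
# Verify the processed dataset directory exists before training
ls data/maze-30x30-hard-1k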
Example Training Commands
Below are the commands for the full-scale experiments mentioned in the README, assuming an 8-GPU setup.
ARC-1 (1,000 samples)
This command trains the model on the combined ARC-AGI and ConceptARC datasets.
# First, build the dataset
python dataset/build_arc_dataset.py
# Launch training
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py
Runtime: ~24 hours
ARC-2 (1,000 samples)
This command trains on the ARC-AGI-2 dataset.
# First, build the dataset
python dataset/build_arc_dataset.py --dataset-dirs dataset/raw-data/ARC-AGI-2/data --output-dir data/arc-2-aug-1000
# Launch training
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py data_path=data/arc-2-aug-1000
Runtime: ~24 hours (a checkpoint after 8 hours is often sufficient)
Maze 30x30 Hard (1,000 samples)
# First, build the dataset
python dataset/build_maze_dataset.py
# Launch training
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py \
data_path=data/maze-30x30-hard-1k \
epochs=20000 \
eval_interval=2000 \
lr=1e-4 \
puzzle_emb_lr=1e-4 \
weight_decay=1.0 \
puzzle_emb_weight_decay=1.0
Runtime: ~1 hour
Checkpoints
By default, model checkpoints are saved to the checkpoints/ directory. The path is determined by the project_name and run_name configuration values. Checkpoints are saved periodically based on the eval_interval and checkpoint_every_eval settings in the configuration.
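Both naming values can be overridden on the command line like any other Hydra option. The sketch below assumes checkpoint_every_eval is a boolean flag and that checkpoints land in a subdirectory of checkpoints/ derived from project_name and run_name; the project and run names shown are hypothetical:
# Sketch: name the run explicitly and checkpoint at every evaluation (hypothetical names)
OMP_NUM_THREADS=8 torchrun --nproc-per-node 8 pretrain.py \
  data_path=data/maze-30x30-hard-1k \
  project_name=hrm-maze \
  run_name=maze-1k-baseline \
  checkpoint_every_eval=True
# Inspect the resulting checkpoint directory (exact layout may differ)
ls checkpoints/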