Configuration Management

The project uses Hydra to manage all configuration for training and evaluation. Settings live in YAML files, cleanly separated from the code, and any parameter can be overridden from the command line.
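
A script's entry point ties these pieces together with Hydra's decorator. The following is a minimal sketch using the standard Hydra API, not a copy of the repository's actual pretrain.py:

# minimal sketch of a Hydra entry point (illustrative; pretrain.py may differ)
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base=None, config_path="config", config_name="cfg_pretrain")
def main(cfg: DictConfig) -> None:
    # cfg is cfg_pretrain.yaml composed with any command-line overrides
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    main()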

Configuration File Structure

The main configuration files are located in the config/ directory:

  • config/cfg_pretrain.yaml: The main configuration file that sets default values for training.
  • config/arch/hrm_v1.yaml: The configuration file specific to the HRM model architecture.

Main Configuration (cfg_pretrain.yaml)

This file contains hyperparameters and settings related to the training process, data, and logging.

# config/cfg_pretrain.yaml

# Data path
data_path: data/arc-aug-1000

# Hyperparams - Training
global_batch_size: 768
epochs: 100000
eval_interval: 10000
checkpoint_every_eval: True

lr: 1e-4
lr_min_ratio: 1.0
lr_warmup_steps: 2000

# Standard hyperparameter settings for LM, as used in Llama
beta1: 0.9
beta2: 0.95
weight_decay: 0.1
puzzle_emb_weight_decay: 0.1

# Hyperparams - Puzzle embeddings training
puzzle_emb_lr: 1e-2

Key Parameters:

  • data_path: Path to the processed dataset directory.
  • global_batch_size: Total batch size across all GPUs.
  • epochs: Total number of training epochs.
  • eval_interval: Run evaluation every N epochs.
  • lr: Peak learning rate for the main model parameters; lr_warmup_steps and lr_min_ratio shape the warmup and decay of the schedule (see the sketch after this list).
  • puzzle_emb_lr: Peak learning rate for the sparse puzzle embeddings.
  • weight_decay, puzzle_emb_weight_decay: Weight decay applied to the main model parameters and to the puzzle embeddings, respectively.
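
A minimal sketch of how such a schedule is commonly computed, assuming linear warmup followed by cosine decay to lr * lr_min_ratio (the repository's exact schedule may differ):

# assumed warmup + cosine schedule; not copied from the repository
import math

def lr_at_step(step: int, base_lr: float = 1e-4, warmup_steps: int = 2000,
               total_steps: int = 100_000, min_ratio: float = 1.0) -> float:
    if step < warmup_steps:
        # linear warmup from 0 up to base_lr
        return base_lr * (step + 1) / warmup_steps
    # cosine decay from base_lr toward base_lr * min_ratio
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return base_lr * (min_ratio + (1.0 - min_ratio) * cosine)

Note that with lr_min_ratio: 1.0, as set above, the rate simply stays at lr after warmup; lowering the ratio enables decay.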

Architecture Configuration (arch/hrm_v1.yaml)

This file defines the structure and hyperparameters of the HierarchicalReasoningModel_ACTV1.

# config/arch/hrm_v1.yaml

name: hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 2
L_cycles: 2

H_layers: 4
L_layers: 4

hidden_size: 512
num_heads: 8
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope
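
The ${.hidden_size} value is an OmegaConf relative interpolation: the leading dot resolves the reference against sibling keys, so puzzle_emb_ndim tracks hidden_size regardless of where this file lands in the composed config. Demonstrated in isolation:

# OmegaConf relative interpolation, shown standalone
from omegaconf import OmegaConf

cfg = OmegaConf.create({"hidden_size": 512, "puzzle_emb_ndim": "${.hidden_size}"})
print(cfg.puzzle_emb_ndim)  # 512, resolved on access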

Key Parameters:

  • name: The model class to instantiate, given as a module@ClassName string (see the sketch after this list).
  • loss: Configuration for the loss head (ACTLossHead) and the loss type (stablemax_cross_entropy).
  • halt_max_steps: Maximum number of recurrent steps for the ACT mechanism.
  • H_cycles, L_cycles: Number of update cycles for the high-level and low-level modules.
  • H_layers, L_layers: Number of Transformer blocks in each module.
  • hidden_size, num_heads: Standard Transformer dimensions.
  • puzzle_emb_ndim: Dimensionality of the per-puzzle embeddings; tied to hidden_size via the ${.hidden_size} interpolation described above.
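
The name field packs a module path and a class name into a single module@ClassName string. A hypothetical resolver for this convention is shown below; the repository's actual loading code may differ, and the importability of the module path as written is an assumption:

# hypothetical resolver for the module@ClassName convention
import importlib

def load_class(spec: str):
    # e.g. "hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1"
    module_path, class_name = spec.split("@")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

# usage: model_cls = load_class(cfg.arch.name)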

Overriding Configuration via Command Line

Hydra's chief benefit is that any configuration value can be overridden from the command line when launching a script.

Syntax: python <script>.py path.to.key=value
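
Nested keys are addressed with dots. Assuming the architecture file is pulled in as a Hydra config group named arch (the usual pattern for a config/arch/ directory, though not confirmed by the files shown above), its values can be overridden the same way, and a + prefix adds a key absent from the YAML:

# assumes 'arch' is composed as a Hydra config group
python pretrain.py arch.hidden_size=768 arch.halt_max_steps=8
python pretrain.py +run_name=sudoku_01   # '+' adds a new key (run_name is hypothetical)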

Example: To run the Sudoku experiment, you might override the data path, batch size, and learning rate:

python pretrain.py \
  data_path=data/sudoku-extreme-1k-aug-1000 \
  global_batch_size=384 \
  lr=7e-5

This command tells Hydra to use data/sudoku-extreme-1k-aug-1000 for data_path instead of the default in cfg_pretrain.yaml, and likewise for global_batch_size and lr. Multiple experiments can therefore be managed and run without editing the YAML files directly.
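
Hydra can also launch parameter sweeps with its --multirun (-m) flag; comma-separated values expand into one run per combination:

# four runs: the Cartesian product of the two value lists
python pretrain.py --multirun lr=1e-4,7e-5 global_batch_size=384,768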