API Reference: Training & Configuration

This section covers the Pydantic configuration classes used by the main training script, pretrain.py.

PretrainConfig

This is the top-level configuration model that orchestrates the entire training run.

Source: pretrain.py

class PretrainConfig(pydantic.BaseModel):
    # Config
    arch: ArchConfig
    # Data
    data_path: str

    # Hyperparams
    global_batch_size: int
    epochs: int

    lr: float
    lr_min_ratio: float
    lr_warmup_steps: int

    weight_decay: float
    beta1: float
    beta2: float

    # Puzzle embedding
    puzzle_emb_lr: float
    puzzle_emb_weight_decay: float

    # Names
    project_name: Optional[str] = None
    run_name: Optional[str] = None
    checkpoint_path: Optional[str] = None

    # Extras
    seed: int = 0
    checkpoint_every_eval: bool = False
    eval_interval: Optional[int] = None
    eval_save_outputs: List[str] = []

Parameters:

  • arch: A nested ArchConfig object specifying the model architecture.
  • data_path: The path to the processed dataset.
  • global_batch_size: The total batch size across all devices.
  • epochs: Total number of training epochs.
  • lr, lr_min_ratio, lr_warmup_steps: Parameters for the cosine learning rate scheduler.
  • weight_decay, beta1, beta2: Adam optimizer parameters.
  • puzzle_emb_lr, puzzle_emb_weight_decay: Specific learning rate and weight decay for the sparse puzzle embeddings.
  • project_name, run_name: Names for the Weights & Biases project and run.
  • checkpoint_path: Directory in which to save model checkpoints.
  • seed: Random seed for the run (default 0).
  • checkpoint_every_eval: Whether to write a checkpoint after every evaluation (default False).
  • eval_interval: Frequency (in epochs) for running evaluation.
  • eval_save_outputs: Keys to save from the model output during evaluation.
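Taken together, a PretrainConfig can be validated from a plain nested dict. The sketch below reproduces the class definitions from this page; all field values are illustrative examples, not defaults from pretrain.py.

```python
from typing import List, Optional

import pydantic


class LossConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')
    name: str


class ArchConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')
    name: str
    loss: LossConfig


class PretrainConfig(pydantic.BaseModel):
    arch: ArchConfig
    data_path: str
    global_batch_size: int
    epochs: int
    lr: float
    lr_min_ratio: float
    lr_warmup_steps: int
    weight_decay: float
    beta1: float
    beta2: float
    puzzle_emb_lr: float
    puzzle_emb_weight_decay: float
    project_name: Optional[str] = None
    run_name: Optional[str] = None
    checkpoint_path: Optional[str] = None
    seed: int = 0
    checkpoint_every_eval: bool = False
    eval_interval: Optional[int] = None
    eval_save_outputs: List[str] = []


# Nested dicts are coerced into ArchConfig / LossConfig during validation.
config = PretrainConfig(
    arch={
        "name": "hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1",
        "loss": {"name": "losses@ACTLossHead"},
    },
    data_path="data/example",  # hypothetical path
    global_batch_size=768,
    epochs=100,
    lr=1e-4,
    lr_min_ratio=1.0,
    lr_warmup_steps=2000,
    weight_decay=0.1,
    beta1=0.9,
    beta2=0.95,
    puzzle_emb_lr=1e-2,
    puzzle_emb_weight_decay=0.1,
)
print(config.seed)  # optional fields fall back to their defaults
```

Fields with defaults (seed, checkpoint_every_eval, and the Optional names) may be omitted; everything else must be supplied or validation raises a pydantic.ValidationError.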

ArchConfig

Defines the model architecture and its associated loss function.

Source: pretrain.py

class ArchConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')

    name: str
    loss: LossConfig

Parameters:

  • name: The identifier for the model class (e.g., hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1).
  • loss: A nested LossConfig object.
  • extra='allow': Allows passing additional, unspecified parameters that will be forwarded to the model's constructor.
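Because ArchConfig sets extra='allow', architecture hyperparameters that are not declared as fields still validate and are retained on the model. The sketch below illustrates this; the field names hidden_size and num_heads are hypothetical, and exactly how pretrain.py forwards the extras to the model constructor is an assumption here.

```python
import pydantic


class LossConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')
    name: str


class ArchConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')
    name: str
    loss: LossConfig


arch = ArchConfig(
    name="hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1",
    loss={"name": "losses@ACTLossHead"},
    hidden_size=512,  # not declared on ArchConfig; kept via extra='allow'
    num_heads=8,      # hypothetical extra field
)

# Undeclared fields are collected in model_extra (pydantic v2), so they can
# be passed on to the model class as keyword arguments.
print(arch.model_extra)  # {'hidden_size': 512, 'num_heads': 8}
```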

LossConfig

Defines the loss head and its parameters.

Source: pretrain.py

class LossConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')

    name: str

Parameters:

  • name: The identifier for the loss head class (e.g., losses@ACTLossHead).
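The module@ClassName identifiers used by both ArchConfig and LossConfig suggest a loader that splits on "@" and imports the class dynamically. The sketch below is a plausible reconstruction of that pattern, not the actual loader from pretrain.py; the usage example resolves a standard-library class so it stays runnable.

```python
import importlib


def load_class(identifier: str):
    """Resolve a "module@ClassName" identifier to a class object.

    Hypothetical helper mirroring the identifier format used in the
    configs, e.g. "losses@ACTLossHead".
    """
    module_name, class_name = identifier.split("@")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Usage: resolve a stdlib class with the same identifier convention.
cls = load_class("collections@OrderedDict")
print(cls.__name__)
```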