API Reference: Training & Configuration
This section covers the Pydantic configuration classes used by the main training script, pretrain.py.
PretrainConfig
This is the top-level configuration model that orchestrates the entire training run.
Source: pretrain.py
```python
class PretrainConfig(pydantic.BaseModel):
    # Config
    arch: ArchConfig
    # Data
    data_path: str

    # Hyperparams
    global_batch_size: int
    epochs: int

    lr: float
    lr_min_ratio: float
    lr_warmup_steps: int

    weight_decay: float
    beta1: float
    beta2: float

    # Puzzle embedding
    puzzle_emb_lr: float
    puzzle_emb_weight_decay: float

    # Names
    project_name: Optional[str] = None
    run_name: Optional[str] = None
    checkpoint_path: Optional[str] = None

    # Extras
    seed: int = 0
    checkpoint_every_eval: bool = False
    eval_interval: Optional[int] = None
    eval_save_outputs: List[str] = []
```
Parameters:

- `arch`: A nested `ArchConfig` object specifying the model architecture.
- `data_path`: The path to the processed dataset.
- `global_batch_size`: The total batch size across all devices.
- `epochs`: Total number of training epochs.
- `lr`, `lr_min_ratio`, `lr_warmup_steps`: Parameters for the cosine learning rate scheduler.
- `weight_decay`, `beta1`, `beta2`: Adam optimizer parameters.
- `puzzle_emb_lr`, `puzzle_emb_weight_decay`: Learning rate and weight decay specific to the sparse puzzle embeddings.
- `project_name`, `run_name`: Names for the Weights & Biases project and run.
- `checkpoint_path`: Directory to save model checkpoints.
- `seed`: Random seed for the run (default `0`).
- `checkpoint_every_eval`: Whether to save a checkpoint after every evaluation (default `False`).
- `eval_interval`: Frequency (in epochs) for running evaluation.
- `eval_save_outputs`: List of keys to save from the model output during evaluation.
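As a minimal sketch of how these models fit together, the following redeclares the three config classes (matching the fields listed above) and validates a nested dictionary into a `PretrainConfig`. All hyperparameter values here are illustrative placeholders, not recommended settings:

```python
import pydantic
from typing import Optional, List

# Redeclared here so the example is self-contained; see pretrain.py for the originals.
class LossConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')
    name: str

class ArchConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')
    name: str
    loss: LossConfig

class PretrainConfig(pydantic.BaseModel):
    arch: ArchConfig
    data_path: str
    global_batch_size: int
    epochs: int
    lr: float
    lr_min_ratio: float
    lr_warmup_steps: int
    weight_decay: float
    beta1: float
    beta2: float
    puzzle_emb_lr: float
    puzzle_emb_weight_decay: float
    project_name: Optional[str] = None
    run_name: Optional[str] = None
    checkpoint_path: Optional[str] = None
    seed: int = 0
    checkpoint_every_eval: bool = False
    eval_interval: Optional[int] = None
    eval_save_outputs: List[str] = []

# Nested dicts are validated into the nested models automatically.
config = PretrainConfig(
    arch={
        "name": "hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1",
        "loss": {"name": "losses@ACTLossHead"},
    },
    data_path="data/my-dataset",   # hypothetical path
    global_batch_size=768,
    epochs=100,
    lr=1e-4, lr_min_ratio=0.1, lr_warmup_steps=2000,
    weight_decay=0.1, beta1=0.9, beta2=0.95,
    puzzle_emb_lr=1e-2, puzzle_emb_weight_decay=0.1,
)
print(config.seed)             # falls back to the default, 0
print(config.arch.loss.name)   # "losses@ACTLossHead"
```

Fields without defaults (`lr`, `epochs`, etc.) are required; Pydantic raises a `ValidationError` if any are missing or have the wrong type.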
ArchConfig
Defines the model architecture and its associated loss function.
Source: pretrain.py
```python
class ArchConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')

    name: str
    loss: LossConfig
```
Parameters:

- `name`: The identifier for the model class (e.g., `hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1`).
- `loss`: A nested `LossConfig` object.
- `extra='allow'`: Allows passing additional, unspecified parameters that will be forwarded to the model's constructor.
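The `extra='allow'` behavior can be seen in a small self-contained sketch. The field names `hidden_size` and `num_heads` below are hypothetical extras chosen for illustration; they are not declared on `ArchConfig`, yet they survive validation and appear in the dumped dict:

```python
import pydantic

class LossConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')
    name: str

class ArchConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')
    name: str
    loss: LossConfig

# hidden_size and num_heads are not declared fields -- extra='allow' keeps them.
arch = ArchConfig(
    name="hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1",
    loss={"name": "losses@ACTLossHead"},
    hidden_size=512,
    num_heads=8,
)

# The extras are retained in the model, so they can be forwarded
# to the model's constructor as keyword arguments.
kwargs = arch.model_dump()
print(kwargs["hidden_size"])  # 512
```

This is what lets a single `ArchConfig` carry architecture-specific hyperparameters without the config schema having to enumerate every model's options.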
LossConfig
Defines the loss head and its parameters.
Source: pretrain.py
```python
class LossConfig(pydantic.BaseModel):
    model_config = pydantic.ConfigDict(extra='allow')

    name: str
```
Parameters:

- `name`: The identifier for the loss head class (e.g., `losses@ACTLossHead`).
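The `module@ClassName` identifiers used by both `ArchConfig.name` and `LossConfig.name` can be resolved with `importlib`. The actual loader in `pretrain.py` may differ in detail; this is a sketch of the general pattern, demonstrated with a standard-library class since `losses` is only importable inside the project:

```python
import importlib

def load_class(identifier: str):
    """Resolve a 'module.path@ClassName' identifier to the class it names."""
    module_path, class_name = identifier.split("@")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

# Stand-in demonstration; in pretrain.py the identifier would be
# e.g. "losses@ACTLossHead" or "hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1".
cls = load_class("collections@OrderedDict")
print(cls.__name__)  # OrderedDict
```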