Module loda.llm.trainer

Training script for the LODA LLM (Large Language Model).

This script handles the complete training pipeline:

1. Load and preprocess training data
2. Set up the model and training loop
3. Train the model with proper validation
4. Save the trained model

Functions

def train_loda_llm(programs_dir: str,
                   output_dir: str = 'loda_llm_model',
                   model_name: str = 't5-small',
                   max_examples: int = -1,
                   val_split: float = 0.1,
                   batch_size: int = 8,
                   learning_rate: float = 5e-05,
                   num_epochs: int = 3)
def train_loda_llm(programs_dir: str,
                   output_dir: str = "loda_llm_model",
                   model_name: str = "t5-small",
                   max_examples: int = -1,
                   val_split: float = 0.1,
                   batch_size: int = 8,
                   learning_rate: float = 5e-5,
                   num_epochs: int = 3):
    """
    Main training function.
    
    Args:
        programs_dir: Directory containing OEIS programs
        output_dir: Directory to save the trained model
        model_name: Base T5 model to use
        max_examples: Maximum number of training examples (-1 for all)
        val_split: Fraction of data to use for validation
        batch_size: Training batch size
        learning_rate: Learning rate
        num_epochs: Number of training epochs
    """
    print("Preparing training data...")
    
    # Create training examples
    preprocessor = DataPreprocessor(programs_dir)
    examples = preprocessor.create_training_examples(max_examples)
    
    if len(examples) == 0:
        print("No training examples found!")
        return None
    
    # Augment examples
    print("Augmenting training examples...")
    examples = preprocessor.augment_descriptions(examples)
    
    # Split into train/validation
    if val_split > 0:
        split_idx = int(len(examples) * (1 - val_split))
        train_examples = examples[:split_idx]
        val_examples = examples[split_idx:]
    else:
        train_examples = examples
        val_examples = None
    
    print(f"Training examples: {len(train_examples)}")
    if val_examples:
        print(f"Validation examples: {len(val_examples)}")
    
    # Create model
    print(f"Creating model based on {model_name}...")
    model = LodaT5Model(model_name)
    
    # Create datasets
    train_dataset = LodaDataset(train_examples, model)
    val_dataset = LodaDataset(val_examples, model) if val_examples else None
    
    # Create trainer
    trainer = LodaTrainer(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        learning_rate=learning_rate,
        batch_size=batch_size,
        num_epochs=num_epochs,
        save_dir=output_dir
    )
    
    # Train the model
    trained_model = trainer.train()
    
    # Save final model
    trained_model.save_model(output_dir)
    print(f"Final model saved to {output_dir}")
    
    return trained_model

Main training function.

Args

programs_dir: Directory containing OEIS programs
output_dir: Directory to save the trained model
model_name: Base T5 model to use
max_examples: Maximum number of training examples (-1 for all)
val_split: Fraction of data to use for validation
batch_size: Training batch size
learning_rate: Learning rate
num_epochs: Number of training epochs
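
A minimal usage sketch, assuming a local checkout of the OEIS programs; the directory paths below are placeholders, and the remaining arguments simply restate the defaults documented above:

from loda.llm.trainer import train_loda_llm

# Hypothetical paths; point programs_dir at your local OEIS programs checkout.
trained_model = train_loda_llm(
    programs_dir="programs/oeis",
    output_dir="loda_llm_model",
    model_name="t5-small",
    max_examples=1000,   # cap examples for a quick run; use -1 for the full set
    val_split=0.1,
    batch_size=8,
    learning_rate=5e-5,
    num_epochs=3,
)
# Returns the trained LodaT5Model (or None if no training examples were found).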

Classes

class LodaDataset (examples: List[TrainingExample],
                   model: LodaT5Model,
                   max_length: int = 512)
class LodaDataset(Dataset):
    """PyTorch dataset for LODA training examples."""
    
    def __init__(self, examples: List[TrainingExample], model: LodaT5Model, max_length: int = 512):
        """
        Initialize the dataset.
        
        Args:
            examples: List of training examples
            model: LodaT5Model instance for tokenization
            max_length: Maximum sequence length
        """
        self.examples = examples
        self.model = model
        self.max_length = max_length
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        example = self.examples[idx]
        
        # Prepare input (description)
        input_encoding = self.model.prepare_input([example.description])
        
        # Prepare target (LODA code)
        target_encoding = self.model.prepare_target([example.loda_code])
        
        return {
            'input_ids': input_encoding['input_ids'].squeeze(),
            'attention_mask': input_encoding['attention_mask'].squeeze(),
            'labels': target_encoding['input_ids'].squeeze(),
            'decoder_attention_mask': target_encoding['attention_mask'].squeeze()
        }

PyTorch dataset for LODA training examples.

Initialize the dataset.

Args

examples: List of training examples
model: LodaT5Model instance for tokenization
max_length: Maximum sequence length

Ancestors

  • torch.utils.data.dataset.Dataset
  • typing.Generic
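
A short sketch of the dataset in isolation, reusing the same preprocessing path as train_loda_llm; the programs directory and the import paths for DataPreprocessor and LodaT5Model are assumptions about the package layout:

from loda.llm.trainer import LodaDataset
from loda.llm.model import LodaT5Model          # assumed import path
from loda.llm.data import DataPreprocessor      # assumed import path

preprocessor = DataPreprocessor("programs/oeis")        # hypothetical path
examples = preprocessor.create_training_examples(100)   # limit to 100 examples
model = LodaT5Model("t5-small")

dataset = LodaDataset(examples, model, max_length=512)
print(len(dataset))          # number of training examples

item = dataset[0]
# Each item is a dict of tensors: 'input_ids', 'attention_mask',
# 'labels', and 'decoder_attention_mask', ready for batching.
print(item["input_ids"].shape, item["labels"].shape)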
class LodaTrainer (model: LodaT5Model,
                   train_dataset: LodaDataset,
                   val_dataset: LodaDataset | None = None,
                   learning_rate: float = 5e-05,
                   batch_size: int = 8,
                   num_epochs: int = 3,
                   warmup_steps: int = 500,
                   save_dir: str = 'loda_llm_model')
class LodaTrainer:
    """Trainer class for LODA LLM."""
    
    def __init__(self, 
                 model: LodaT5Model,
                 train_dataset: LodaDataset,
                 val_dataset: Optional[LodaDataset] = None,
                 learning_rate: float = 5e-5,
                 batch_size: int = 8,
                 num_epochs: int = 3,
                 warmup_steps: int = 500,
                 save_dir: str = "loda_llm_model"):
        """
        Initialize the trainer.
        
        Args:
            model: LodaT5Model to train
            train_dataset: Training dataset
            val_dataset: Validation dataset (optional)
            learning_rate: Learning rate
            batch_size: Batch size
            num_epochs: Number of training epochs
            warmup_steps: Number of warmup steps for learning rate schedule
            save_dir: Directory to save the model
        """
        self.model = model
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.warmup_steps = warmup_steps
        self.save_dir = save_dir
        
        # Set up device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.model.to(self.device)
        
        # Set up data loaders
        self.train_loader = DataLoader(
            train_dataset, 
            batch_size=batch_size, 
            shuffle=True,
            collate_fn=self._collate_fn
        )
        
        if val_dataset:
            self.val_loader = DataLoader(
                val_dataset, 
                batch_size=batch_size, 
                shuffle=False,
                collate_fn=self._collate_fn
            )
        
        # Set up optimizer
        self.optimizer = AdamW(
            self.model.model.parameters(),
            lr=learning_rate,
            weight_decay=0.01
        )
        
        # Set up learning rate scheduler
        total_steps = len(self.train_loader) * num_epochs
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )
    
    def _collate_fn(self, batch):
        """Collate function for DataLoader."""
        # Pad sequences to the same length
        input_ids = [item['input_ids'] for item in batch]
        attention_masks = [item['attention_mask'] for item in batch]
        labels = [item['labels'] for item in batch]
        decoder_attention_masks = [item['decoder_attention_mask'] for item in batch]
        
        # Pad input sequences
        max_input_len = max(len(seq) for seq in input_ids)
        padded_input_ids = []
        padded_attention_masks = []
        
        for i in range(len(input_ids)):
            seq_len = len(input_ids[i])
            pad_len = max_input_len - seq_len
            
            padded_input_ids.append(
                torch.cat([input_ids[i], torch.zeros(pad_len, dtype=torch.long)])
            )
            padded_attention_masks.append(
                torch.cat([attention_masks[i], torch.zeros(pad_len, dtype=torch.long)])
            )
        
        # Pad target sequences
        max_target_len = max(len(seq) for seq in labels)
        padded_labels = []
        padded_decoder_masks = []
        
        for i in range(len(labels)):
            seq_len = len(labels[i])
            pad_len = max_target_len - seq_len
            
            # For labels, use -100 for padding (ignored in loss calculation)
            padded_labels.append(
                torch.cat([labels[i], torch.full((pad_len,), -100, dtype=torch.long)])
            )
            padded_decoder_masks.append(
                torch.cat([decoder_attention_masks[i], torch.zeros(pad_len, dtype=torch.long)])
            )
        
        return {
            'input_ids': torch.stack(padded_input_ids),
            'attention_mask': torch.stack(padded_attention_masks),
            'labels': torch.stack(padded_labels),
            'decoder_attention_mask': torch.stack(padded_decoder_masks)
        }
    
    def train_epoch(self):
        """Train for one epoch."""
        self.model.model.train()
        total_loss = 0
        
        progress_bar = tqdm(self.train_loader, desc="Training")
        
        for batch in progress_bar:
            # Move to device
            batch = {k: v.to(self.device) for k, v in batch.items()}
            
            # Forward pass
            outputs = self.model.forward(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels']
            )
            
            loss = outputs.loss
            total_loss += loss.item()
            
            # Backward pass
            loss.backward()
            
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.model.parameters(), 1.0)
            
            # Update parameters
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()
            
            # Update progress bar
            progress_bar.set_postfix({'loss': loss.item()})
        
        return total_loss / len(self.train_loader)
    
    def validate(self):
        """Validate the model."""
        if not self.val_dataset:
            return None
        
        self.model.model.eval()
        total_loss = 0
        
        with torch.no_grad():
            progress_bar = tqdm(self.val_loader, desc="Validation")
            
            for batch in progress_bar:
                # Move to device
                batch = {k: v.to(self.device) for k, v in batch.items()}
                
                # Forward pass
                outputs = self.model.forward(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    labels=batch['labels']
                )
                
                loss = outputs.loss
                total_loss += loss.item()
                
                progress_bar.set_postfix({'val_loss': loss.item()})
        
        return total_loss / len(self.val_loader)
    
    def train(self):
        """Train the model."""
        print(f"Training on device: {self.device}")
        print(f"Training examples: {len(self.train_dataset)}")
        if self.val_dataset:
            print(f"Validation examples: {len(self.val_dataset)}")
        
        best_val_loss = float('inf')
        
        for epoch in range(self.num_epochs):
            print(f"\nEpoch {epoch + 1}/{self.num_epochs}")
            
            # Train
            train_loss = self.train_epoch()
            print(f"Training loss: {train_loss:.4f}")
            
            # Validate
            val_loss = self.validate()
            if val_loss is not None:
                print(f"Validation loss: {val_loss:.4f}")
                
                # Save best model
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    self.save_model(f"{self.save_dir}_best")
                    print("Saved best model")
            
            # Save checkpoint
            self.save_model(f"{self.save_dir}_epoch_{epoch + 1}")
        
        print("\nTraining completed!")
        return self.model
    
    def save_model(self, path: str):
        """Save the model."""
        self.model.save_model(path)

Trainer class for LODA LLM.

Initialize the trainer.

Args

model: LodaT5Model to train
train_dataset: Training dataset
val_dataset: Validation dataset (optional)
learning_rate: Learning rate
batch_size: Batch size
num_epochs: Number of training epochs
warmup_steps: Number of warmup steps for learning rate schedule
save_dir: Directory to save the model
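
Continuing the dataset sketch above, a hedged example of driving LodaTrainer directly instead of going through train_loda_llm; the save directory is a placeholder:

# train_examples / val_examples are TrainingExample lists, e.g. a 90/10 split
# of the preprocessed examples shown earlier.
train_dataset = LodaDataset(train_examples, model)
val_dataset = LodaDataset(val_examples, model)

trainer = LodaTrainer(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,      # optional; pass None to skip validation
    learning_rate=5e-5,
    batch_size=8,
    num_epochs=3,
    warmup_steps=500,
    save_dir="loda_llm_model",
)

trained_model = trainer.train()
# train() saves a checkpoint per epoch ("<save_dir>_epoch_N") and, when a
# validation set is given, the best model so far ("<save_dir>_best").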

Methods

def save_model(self, path: str)
def save_model(self, path: str):
    """Save the model."""
    self.model.save_model(path)

Save the model.

def train(self)
def train(self):
    """Train the model."""
    print(f"Training on device: {self.device}")
    print(f"Training examples: {len(self.train_dataset)}")
    if self.val_dataset:
        print(f"Validation examples: {len(self.val_dataset)}")
    
    best_val_loss = float('inf')
    
    for epoch in range(self.num_epochs):
        print(f"\nEpoch {epoch + 1}/{self.num_epochs}")
        
        # Train
        train_loss = self.train_epoch()
        print(f"Training loss: {train_loss:.4f}")
        
        # Validate
        val_loss = self.validate()
        if val_loss is not None:
            print(f"Validation loss: {val_loss:.4f}")
            
            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.save_model(f"{self.save_dir}_best")
                print("Saved best model")
        
        # Save checkpoint
        self.save_model(f"{self.save_dir}_epoch_{epoch + 1}")
    
    print("\nTraining completed!")
    return self.model

Train the model.

def train_epoch(self)
def train_epoch(self):
    """Train for one epoch."""
    self.model.model.train()
    total_loss = 0
    
    progress_bar = tqdm(self.train_loader, desc="Training")
    
    for batch in progress_bar:
        # Move to device
        batch = {k: v.to(self.device) for k, v in batch.items()}
        
        # Forward pass
        outputs = self.model.forward(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )
        
        loss = outputs.loss
        total_loss += loss.item()
        
        # Backward pass
        loss.backward()
        
        # Clip gradients
        torch.nn.utils.clip_grad_norm_(self.model.model.parameters(), 1.0)
        
        # Update parameters
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
        
        # Update progress bar
        progress_bar.set_postfix({'loss': loss.item()})
    
    return total_loss / len(self.train_loader)

Train for one epoch.

def validate(self)
def validate(self):
    """Validate the model."""
    if not self.val_dataset:
        return None
    
    self.model.model.eval()
    total_loss = 0
    
    with torch.no_grad():
        progress_bar = tqdm(self.val_loader, desc="Validation")
        
        for batch in progress_bar:
            # Move to device
            batch = {k: v.to(self.device) for k, v in batch.items()}
            
            # Forward pass
            outputs = self.model.forward(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels']
            )
            
            loss = outputs.loss
            total_loss += loss.item()
            
            progress_bar.set_postfix({'val_loss': loss.item()})
    
    return total_loss / len(self.val_loader)

Validate the model.