Overview

The LlaSMol training script uses LoRA (Low-Rank Adaptation) to efficiently fine-tune large language models on chemistry tasks. It supports distributed training, Weights & Biases integration, and custom data collation.

Running Training

Basic Command

python finetune.py \
  --base_model="mistralai/Mistral-7B-v0.1" \
  --data_path="osunlp/SMolInstruct" \
  --output_dir="./checkpoints/llasmol" \
  --batch_size=512 \
  --micro_batch_size=4 \
  --num_epochs=3 \
  --learning_rate=1e-4

Distributed Training

torchrun --nproc_per_node=4 finetune.py \
  --base_model="mistralai/Mistral-7B-v0.1" \
  --data_path="osunlp/SMolInstruct" \
  --output_dir="./checkpoints/llasmol" \
  --batch_size=512 \
  --micro_batch_size=4

Training Parameters

Model & Data

base_model
string
required
The base model to fine-tune. Supported models:
  • mistralai/Mistral-7B-v0.1
  • facebook/galactica-6.7b
  • meta-llama/Llama-2-7b-hf
  • codellama/CodeLlama-7b-hf
data_path
string
required
Path to the training dataset (Hugging Face dataset path or local path).
output_dir
string
default:"checkpoint"
Directory where model checkpoints will be saved.
train_split
string
default:"train"
Name of the training split in the dataset.
dev_split
string
default:"validation"
Name of the validation split in the dataset.
tasks
list[string]
default:"None"
Specific tasks to train on. If None, trains on all tasks in the dataset.

Training Hyperparameters

batch_size
int
default:"512"
Total batch size across all devices (effective batch size).
micro_batch_size
int
default:"4"
Batch size per device. Gradient accumulation steps are calculated as batch_size / micro_batch_size.
num_epochs
int
default:"3"
Number of training epochs.
learning_rate
float
default:"1e-4"
Learning rate for the optimizer.
cutoff_len
int
default:"512"
Maximum sequence length. Longer sequences will be truncated.
optim
string
default:"adamw_bnb_8bit"
Optimizer to use. Options include:
  • adamw_bnb_8bit: 8-bit AdamW (memory efficient)
  • adamw_torch: Standard PyTorch AdamW
  • Other optimizers supported by the Hugging Face Trainer
lr_scheduler
string
default:"cosine"
Learning rate scheduler type. Options: cosine, linear, constant, etc.
warmup_steps
int
default:"1000"
Number of warmup steps for the learning rate scheduler.
use_val_set
bool
default:"True"
Whether to use a validation set for evaluation during training.
train_on_inputs
bool
default:"True"
If False, input tokens are masked out of the loss calculation, so the loss is computed only on the outputs.
group_by_length
bool
default:"False"
Whether to group training samples by length. Can speed up training but may produce odd loss curves.
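The relationship between batch_size, micro_batch_size, and gradient accumulation can be checked with a little arithmetic. The sketch below is illustrative only: the helper name grad_accum_steps is hypothetical, and the division by the number of processes under torchrun is an assumption based on a common convention in LoRA fine-tuning scripts, not something confirmed from finetune.py.

```python
def grad_accum_steps(batch_size: int, micro_batch_size: int, world_size: int = 1) -> int:
    """Hypothetical helper: accumulation steps needed to reach batch_size."""
    steps = batch_size // micro_batch_size
    # Assumption: with several processes, each one handles an equal share
    # of the accumulation, so the per-process count shrinks accordingly.
    return max(steps // world_size, 1)


single_gpu = grad_accum_steps(512, 4)
four_gpus = grad_accum_steps(512, 4, world_size=4)
print(single_gpu, four_gpus)  # 128 32
```

With the defaults, a single device accumulates 128 micro-batches per optimizer step; under the stated assumption, four devices would each accumulate 32.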

LoRA Hyperparameters

lora_r
int
default:"16"
LoRA rank (dimensionality of the low-rank matrices).
lora_alpha
int
default:"16"
LoRA scaling parameter.
lora_dropout
float
default:"0.05"
Dropout probability for LoRA layers.
lora_target_modules
list[string]
Model modules to apply LoRA to. Common options:
  • Attention projections (query, key, value, output): q_proj, k_proj, v_proj, o_proj
  • MLP layers: gate_proj, down_proj, up_proj
  • Embeddings: wte
modules_to_save
list[string]
default:"[]"
Additional modules to save (not using LoRA). Useful for saving embeddings or classification heads.

Weights & Biases Integration

wandb_project
string
default:""
Weights & Biases project name for logging.
wandb_run_name
string
default:""
Name for this specific W&B run.
wandb_watch
string
default:""
W&B watch mode. Options: false, gradients, all.
wandb_log_model
string
default:""
Whether to log model checkpoints to W&B. Options: false, true.
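One common way for a training script to hand these options to the Hugging Face W&B integration is through environment variables. The sketch below assumes that pattern; the helper wandb_env is hypothetical, and whether finetune.py does exactly this is not confirmed here. WANDB_PROJECT, WANDB_WATCH, and WANDB_LOG_MODEL are the variable names read by the integration.

```python
import os


def wandb_env(project: str = "", watch: str = "", log_model: str = "") -> dict:
    """Hypothetical helper: map non-empty W&B options to environment variables."""
    env = {}
    if project:
        env["WANDB_PROJECT"] = project
    if watch:
        env["WANDB_WATCH"] = watch
    if log_model:
        env["WANDB_LOG_MODEL"] = log_model
    return env


settings = wandb_env(project="chemistry-llm", watch="gradients")
os.environ.update(settings)  # empty options are left unset
```

The run name is usually passed to the trainer directly (TrainingArguments has a run_name field) rather than through the environment.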

Checkpointing

resume_from_checkpoint
string
default:"None"
Path to a checkpoint to resume training from.
logging_steps
int
default:"10"
Log metrics every N steps.
save_steps
int
default:"200"
Save checkpoint every N steps.
eval_steps
int
default:"200"
Run evaluation every N steps (if validation set is enabled).
save_total_limit
int
default:"None"
Maximum number of checkpoints to keep. Older checkpoints are deleted.
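The interaction between save_steps and save_total_limit can be sketched as follows. This is a simplified model with a hypothetical helper name: it ignores any extra save at the end of training and any special handling of a best checkpoint.

```python
def checkpoints_kept(total_steps: int, save_steps: int = 200, save_total_limit=None) -> list:
    """Hypothetical helper: steps whose checkpoints remain on disk after training."""
    saved = list(range(save_steps, total_steps + 1, save_steps))
    if not save_total_limit:
        # No limit: every checkpoint is kept.
        return saved
    # Only the most recent save_total_limit checkpoints survive.
    return saved[-save_total_limit:]


print(checkpoints_kept(1000, save_steps=200, save_total_limit=3))  # [600, 800, 1000]
```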

Precision & Device

precision
string
default:"bf16"
Training precision. Currently only bf16 (bfloat16) is tested and supported.
use_int8
bool
default:"False"
Whether to load the model in 8-bit mode (requires bitsandbytes).

Training Function

from finetune import train

train(
    base_model="mistralai/Mistral-7B-v0.1",
    data_path="osunlp/SMolInstruct",
    output_dir="./checkpoints/llasmol",
    batch_size=512,
    micro_batch_size=4,
    num_epochs=3,
    learning_rate=1e-4,
    lora_r=16,
    lora_alpha=16,
    wandb_project="llasmol-training",
    wandb_run_name="mistral-7b-chemistry"
)

LoRA Fine-Tuning

LlaSMol uses Parameter-Efficient Fine-Tuning (PEFT) with LoRA to reduce memory requirements:
  • Low-Rank Adaptation: Instead of fine-tuning all model parameters, LoRA freezes the base weights and adds trainable low-rank matrices to selected layers, such as the attention and MLP projections
  • Memory Efficient: Typically requires 3-4x less GPU memory than full fine-tuning
  • Fast Training: Fewer parameters to update means faster training iterations
  • Modular: LoRA adapters can be easily swapped or combined
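To make the parameter savings concrete, the sketch below counts trainable parameters for a single weight matrix. The helper name is hypothetical, and the 4096 x 4096 shape is chosen for illustration as a typical attention projection size in a 7B-class model.

```python
def lora_param_count(d_out: int, d_in: int, r: int) -> int:
    """Hypothetical helper: trainable parameters LoRA adds to one d_out x d_in weight."""
    # LoRA learns B (d_out x r) and A (r x d_in); the original weight stays frozen.
    return r * (d_out + d_in)


full = 4096 * 4096                       # parameters updated by full fine-tuning
lora = lora_param_count(4096, 4096, 16)  # parameters updated with lora_r=16
print(full, lora, full // lora)  # 16777216 131072 128
```

For this one matrix, rank 16 trains 128 times fewer parameters. Overall memory savings are smaller than this ratio because the frozen base weights and activations still occupy GPU memory.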

LoRA Configuration

The LoRA configuration is defined using PEFT’s LoraConfig:

from peft import LoraConfig

config = LoraConfig(
    r=16,                    # Rank of LoRA matrices
    lora_alpha=16,           # Scaling factor
    target_modules=[         # Modules to apply LoRA to
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "down_proj", "up_proj"
    ],
    lora_dropout=0.05,       # Dropout for regularization
    bias="none",             # Don't train bias parameters
    task_type="CAUSAL_LM"    # Causal language modeling
)
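In PEFT's standard LoRA (without rank-stabilized scaling), the low-rank update is multiplied by lora_alpha / r, so the two values are best read together. A minimal sketch, with a hypothetical helper name:

```python
def lora_scaling(lora_alpha: float, r: int) -> float:
    """Hypothetical helper: factor applied to the low-rank update B @ A."""
    return lora_alpha / r


default_scale = lora_scaling(16, 16)  # configuration above: 1.0
custom_scale = lora_scaling(64, 32)   # lora_r=32, lora_alpha=64: 2.0
print(default_scale, custom_scale)  # 1.0 2.0
```

Raising lora_alpha relative to r makes the adapter's contribution larger for the same learned matrices, which interacts with the learning rate.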

Custom Training Components

CustomTrainer

The training script uses a CustomTrainer class (defined in trainer.py) that extends Hugging Face’s Trainer with the following additions:
  • Core Loss Tracking: Tracks loss on “core” tokens (important chemical entities)
  • Core Mask Support: Uses core_mask to weight loss on specific tokens
  • Custom Data Collation: Handles chemistry-specific data formatting
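The idea of tracking a separate loss on core tokens can be illustrated with per-token losses and a 0/1 mask. This is only a conceptual sketch with a hypothetical helper; how CustomTrainer actually computes or weights the core loss is defined in trainer.py and is not reproduced here.

```python
def mean_losses(token_losses: list, core_mask: list) -> tuple:
    """Hypothetical helper: (mean loss over all tokens, mean loss over core tokens)."""
    overall = sum(token_losses) / len(token_losses)
    core = [loss for loss, is_core in zip(token_losses, core_mask) if is_core]
    # With no core tokens in the sample, report 0.0 instead of dividing by zero.
    core_loss = sum(core) / len(core) if core else 0.0
    return overall, core_loss


# Tokens 1 and 2 are marked as core (for example, part of a molecule string).
overall, core = mean_losses([0.5, 1.5, 2.0, 4.0], [0, 1, 1, 0])
# overall is 2.0, core is 1.75
```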

Data Processing

The generate_and_tokenize_prompt() function:
  1. Generates chat-formatted prompts from input/output pairs
  2. Tokenizes with truncation to cutoff_len
  3. Optionally masks input tokens (when train_on_inputs=False)
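  4. Generates core masks for important chemical tokens

def generate_and_tokenize_prompt(data_point, add_core_mask=True):
    # Simplified excerpt: prompter, tokenize, generate_chat, and train_on_inputs
    # come from the enclosing training scope.
    input_text = data_point['input']
    output_text = data_point['output']

    # Create chat format
    chat = generate_chat(input_text, output_text)
    full_prompt = prompter.generate_prompt(chat)

    # Tokenize (truncates to cutoff_len)
    tokenized = tokenize(full_prompt)

    # Core-mask generation (step 4, controlled by add_core_mask) is omitted here.

    # Mask inputs if specified
    if not train_on_inputs:
        # Assumption: the input-only prompt is the chat rendered without an output
        input_only_prompt = prompter.generate_prompt(generate_chat(input_text, None))
        # Count prompt tokens (assumes tokenize returns a dict with "input_ids")
        user_prompt_len = len(tokenize(input_only_prompt)["input_ids"])
        tokenized["labels"] = [-100] * user_prompt_len + tokenized["labels"][user_prompt_len:]

    return tokenized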
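The truncation and masking steps can be tried in isolation on plain token ID lists. The sketch below uses a hypothetical helper, build_labels, and made-up token IDs; -100 is the label value that PyTorch's cross-entropy loss ignores by default.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def build_labels(input_ids: list, prompt_len: int, cutoff_len: int = 512,
                 train_on_inputs: bool = True) -> tuple:
    """Hypothetical helper: truncate token IDs and derive training labels."""
    ids = input_ids[:cutoff_len]
    labels = list(ids)
    if not train_on_inputs:
        # Hide the prompt tokens from the loss; only output tokens keep their labels.
        n = min(prompt_len, len(labels))
        labels[:n] = [IGNORE_INDEX] * n
    return ids, labels


ids, labels = build_labels([11, 12, 13, 14, 15], prompt_len=2, train_on_inputs=False)
print(labels)  # [-100, -100, 13, 14, 15]
```

With train_on_inputs=True the labels are simply a copy of the (truncated) input IDs, so the loss covers prompt and output alike.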

Usage Examples

Basic Training

python finetune.py \
  --base_model="mistralai/Mistral-7B-v0.1" \
  --data_path="osunlp/SMolInstruct" \
  --output_dir="./output/mistral-chemistry" \
  --num_epochs=3

Training with W&B Logging

python finetune.py \
  --base_model="mistralai/Mistral-7B-v0.1" \
  --data_path="osunlp/SMolInstruct" \
  --output_dir="./output/mistral-chemistry" \
  --wandb_project="chemistry-llm" \
  --wandb_run_name="mistral-llasmol-v1" \
  --wandb_watch="gradients"

Custom LoRA Configuration

python finetune.py \
  --base_model="mistralai/Mistral-7B-v0.1" \
  --data_path="osunlp/SMolInstruct" \
  --output_dir="./output/mistral-chemistry" \
  --lora_r=32 \
  --lora_alpha=64 \
  --lora_dropout=0.1 \
  --lora_target_modules q_proj k_proj v_proj o_proj

Resume from Checkpoint

python finetune.py \
  --base_model="mistralai/Mistral-7B-v0.1" \
  --data_path="osunlp/SMolInstruct" \
  --output_dir="./output/mistral-chemistry" \
  --resume_from_checkpoint="./output/mistral-chemistry/checkpoint-1000"

Task-Specific Training

python finetune.py \
  --base_model="mistralai/Mistral-7B-v0.1" \
  --data_path="osunlp/SMolInstruct" \
  --output_dir="./output/name-conversion" \
  --tasks name_conversion-i2s name_conversion-s2i
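
The effect of the tasks option can be pictured as a filter over the dataset's samples. The sketch below is an assumption about the behavior, not the script's code: the helper select_tasks is hypothetical, as is the "task" field name, and "other_task" is a placeholder rather than a real task name.

```python
def select_tasks(samples: list, tasks=None) -> list:
    """Hypothetical helper: keep only samples whose task is in tasks (all if None)."""
    if tasks is None:
        return list(samples)
    wanted = set(tasks)
    return [s for s in samples if s["task"] in wanted]


samples = [
    {"task": "name_conversion-i2s", "input": "...", "output": "..."},
    {"task": "name_conversion-s2i", "input": "...", "output": "..."},
    {"task": "other_task", "input": "...", "output": "..."},
]
subset = select_tasks(samples, ["name_conversion-i2s", "name_conversion-s2i"])
print(len(subset))  # 2
```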

File Location

LLM4Chem/finetune.py

Related files:
  • trainer.py: CustomTrainer and CustomDataCollator classes
  • model.py: Model loading utilities
  • config.py: Task configurations and base model mappings
  • utils/general_prompter.py: Prompt formatting
  • utils/core_tagger.py: Core token masking for chemistry entities

The training script uses torch.distributed and requires proper initialization. Make sure to use torchrun for multi-GPU training.

For best results with chemistry tasks, use the default LoRA configuration, which has been optimized for molecular understanding and generation.