50 changes: 37 additions & 13 deletions cli/alora/train.py
@@ -1,14 +1,34 @@
 import json
 import os
+import sys
+import warnings

 import torch
 import typer
-from alora.config import aLoraConfig
-from alora.peft_model_alora import aLoRAPeftModelForCausalLM
 from datasets import Dataset
-from peft import LoraConfig, PeftModelForCausalLM
+from peft import LoraConfig, get_peft_model
 from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
 from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer

+# Handle MPS with old PyTorch versions on macOS only
+# Accelerate's GradScaler requires PyTorch >= 2.8.0 for MPS
+if sys.platform == "darwin" and hasattr(torch.backends, "mps"):
+    if torch.backends.mps.is_available():
+        pytorch_version = tuple(int(x) for x in torch.__version__.split(".")[:2])
+        if pytorch_version < (2, 8):
+            # Disable MPS detection to force CPU usage on macOS
+            # This must be done before any models or tensors are initialized
+            torch.backends.mps.is_available = lambda: False  # type: ignore[assignment]
+            torch.backends.mps.is_built = lambda: False  # type: ignore[assignment]
+            os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "0"
+            warnings.warn(
+                "MPS is available but PyTorch < 2.8.0. Disabling MPS to avoid "
+                "gradient scaling issues. Training will run on CPU. "
+                "To use MPS, upgrade to PyTorch >= 2.8.0.",
+                UserWarning,
+                stacklevel=2,
+            )
+
+
 def load_dataset_from_json(json_path, tokenizer, invocation_prompt):
     data = []
@@ -90,8 +110,12 @@ def train_model(
     train_dataset = dataset.select(range(split_idx))
     val_dataset = dataset.select(range(split_idx, len(dataset)))

+    # Use device_map="auto" only when CUDA is available
+    # In CPU-only environments (like CI), device_map="auto" creates meta tensors
+    # which cause "Cannot copy out of meta tensor" errors
+    device_map = "auto" if torch.cuda.is_available() else None
     model_base = AutoModelForCausalLM.from_pretrained(
-        base_model, device_map="auto", use_cache=False
+        base_model, device_map=device_map, use_cache=False
     )

     collator = DataCollatorForCompletionOnlyLM(invocation_prompt, tokenizer=tokenizer)
@@ -100,21 +124,21 @@
     os.makedirs(output_dir, exist_ok=True)

     if adapter == "alora":
-        peft_config = aLoraConfig(
-            invocation_string=invocation_prompt,
+        # Tokenize the invocation string for PEFT 0.18.0 native aLoRA
+        invocation_token_ids = tokenizer.encode(
+            invocation_prompt, add_special_tokens=False
+        )
+
+        peft_config = LoraConfig(
             r=32,
             lora_alpha=32,
             lora_dropout=0.05,
             bias="none",
             task_type="CAUSAL_LM",
             target_modules=["q_proj", "k_proj", "v_proj"],
+            alora_invocation_tokens=invocation_token_ids,  # Enable aLoRA
         )
-        response_token_ids = tokenizer(
-            invocation_prompt, return_tensors="pt", add_special_tokens=False
-        )["input_ids"]
-        model = aLoRAPeftModelForCausalLM(
-            model_base, peft_config, response_token_ids=response_token_ids
-        )
+        model = get_peft_model(model_base, peft_config)

         sft_args = SFTConfig(
             output_dir=output_dir,
@@ -148,7 +172,7 @@ def train_model(
             task_type="CAUSAL_LM",
             target_modules=["q_proj", "k_proj", "v_proj"],
         )
-        model = PeftModelForCausalLM(model_base, peft_config)
+        model = get_peft_model(model_base, peft_config)

         sft_args = SFTConfig(
             output_dir=output_dir,
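Taken together, the hunks above replace the IBM `alora` package with PEFT's native aLoRA support. A condensed, standalone sketch of the new path, assuming `peft>=0.18.1`; the model id and invocation prompt below are illustrative placeholders, not values from this PR:

```python
# Sketch only: mirrors the aLoRA branch of train_model() above.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = "ibm-granite/granite-3.3-8b-instruct"  # illustrative model id
invocation_prompt = "<|start_of_role|>assistant<|end_of_role|>"  # illustrative

tokenizer = AutoTokenizer.from_pretrained(base_model)
model_base = AutoModelForCausalLM.from_pretrained(base_model, use_cache=False)

# PEFT matches these token IDs in the input to decide where the adapter
# activates, so tokenize without special tokens, exactly as train.py does.
invocation_token_ids = tokenizer.encode(invocation_prompt, add_special_tokens=False)

peft_config = LoraConfig(
    r=32,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj"],
    alora_invocation_tokens=invocation_token_ids,  # the only aLoRA-specific field
)
model = get_peft_model(model_base, peft_config)
```

Everything except `alora_invocation_tokens` is a standard LoRA setup, which is what lets the `lora` and `alora` branches share `get_peft_model`.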
7 changes: 3 additions & 4 deletions docs/alora.md
@@ -1,6 +1,6 @@
 # Mellea CLI — Train & Upload LoRA/aLoRA Adapters

-Mellea provides a command-line interface for training and uploading [LoRA](https://arxiv.org/abs/2106.09685) or [aLoRA](https://github.com/IBM/alora) adapters for causal language models. This tool is useful for adapting base models like IBM Granite to custom tasks using prompt-based classification. The major goal is to help customer train a requirement validator.
+Mellea provides a command-line interface for training and uploading [LoRA](https://arxiv.org/abs/2106.09685) or [aLoRA](https://huggingface.co/docs/peft/main/en/package_reference/lora#alora) adapters for causal language models. This tool is useful for adapting base models like IBM Granite to custom tasks using prompt-based classification. The primary goal is to help customers train a requirement validator.

 ---

@@ -82,13 +82,12 @@ This will:
 ## 🛠 Requirements

 - Python 3.8+
-- Install the following dependencies manually or via `pip install mellea`:
+- Install the following dependencies manually or via `pip install mellea[hf]`:
   - `transformers`
   - `trl`
-  - `peft`
+  - `peft>=0.18.1` (native aLoRA support)
   - `datasets`
   - `huggingface_hub`
-  - `alora`


 ---
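The updated requirements drop the separate `alora` package in favor of `peft>=0.18.1`. For completeness, a minimal sketch of loading a trained adapter back for inference, assuming the standard PEFT `save_pretrained` layout; the model id and adapter path are illustrative:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("ibm-granite/granite-3.3-8b-instruct")
# The saved adapter config carries alora_invocation_tokens, so PEFT only
# applies the adapter to tokens after the invocation sequence appears.
model = PeftModel.from_pretrained(base, "./output/requirement-validator")
```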
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -69,11 +69,10 @@ m = "cli.m:cli"

 hf = [
     "accelerate>=1.9.0",
-    "alora==0.2.0",
     "datasets>=4.0.0",
     "outlines-core==0.1.26",
     "outlines",  # intentionally un-versioned, expecting a minor update. outlines-core version should be enough to specify it
-    "peft>=0.18.0",  # aLoRA support was added in Peft 0.18.0
+    "peft>=0.18.1",  # Native aLoRA support added in PEFT 0.18.0
     "transformers>=4.53.2,<5",
     "trl==0.19.1",
     "granite-common[transformers]",
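With `alora==0.2.0` removed, the PEFT floor is the only aLoRA-related constraint left. A hedged sketch of an optional runtime guard mirroring it (not part of this PR; `importlib.metadata` is stdlib on Python 3.8+):

```python
from importlib.metadata import version

# Naive version parse; pre-release suffixes such as "0.19.0rc1" are not handled.
peft_version = tuple(int(part) for part in version("peft").split(".")[:3])
if peft_version < (0, 18, 1):
    raise RuntimeError("Native aLoRA support requires peft>=0.18.1")
```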
221 changes: 221 additions & 0 deletions test/cli/test_alora_train.py
@@ -0,0 +1,221 @@
"""Unit tests for aLoRA/LoRA training configuration."""

from unittest.mock import MagicMock, Mock, patch

import pytest
from peft import LoraConfig


@pytest.mark.huggingface
def test_alora_config_creation():
"""Test that aLoRA config is created correctly with PEFT 0.18+."""
from cli.alora.train import train_model

# Mock all the heavy dependencies
with (
patch("cli.alora.train.AutoTokenizer") as mock_tokenizer_class,
patch("cli.alora.train.AutoModelForCausalLM") as mock_model_class,
patch("cli.alora.train.Dataset"),
patch("cli.alora.train.SafeSaveTrainer") as mock_trainer,
patch("cli.alora.train.get_peft_model") as mock_get_peft_model,
patch("cli.alora.train.load_dataset_from_json") as mock_load_dataset,
patch("cli.alora.train.DataCollatorForCompletionOnlyLM"),
):
# Setup mocks
mock_tokenizer = Mock()
mock_tokenizer.encode.return_value = [123, 456, 789] # Mock token IDs
mock_tokenizer.eos_token = "<eos>"
mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer

mock_model = Mock()
mock_model_class.from_pretrained.return_value = mock_model

mock_peft_model = Mock()
mock_get_peft_model.return_value = mock_peft_model

# Mock dataset
mock_ds = MagicMock()
mock_ds.shuffle.return_value = mock_ds
mock_ds.select.return_value = mock_ds
mock_ds.__len__ = Mock(return_value=10)
mock_load_dataset.return_value = mock_ds

# Mock trainer
mock_trainer_instance = Mock()
mock_trainer.return_value = mock_trainer_instance

# Call train_model with aLoRA adapter
train_model(
dataset_path="test.jsonl",
base_model="test-model",
output_file="./test_output/adapter",
adapter="alora",
epochs=1,
)

# Verify get_peft_model was called
assert mock_get_peft_model.called, "get_peft_model should be called"

# Get the LoraConfig that was passed to get_peft_model
call_args = mock_get_peft_model.call_args
assert call_args is not None, (
"get_peft_model should have been called with arguments"
)

peft_config = call_args[0][1] # Second argument is the config

# Verify it's a LoraConfig
assert isinstance(peft_config, LoraConfig), "Should use LoraConfig"

# Verify aLoRA-specific parameter is set
assert hasattr(peft_config, "alora_invocation_tokens"), (
"Config should have alora_invocation_tokens attribute"
)
assert peft_config.alora_invocation_tokens == [123, 456, 789], (
"alora_invocation_tokens should match tokenized invocation prompt"
)

# Verify other LoRA parameters
assert peft_config.r == 32, "Rank should be 32 for aLoRA"
assert peft_config.lora_alpha == 32, "Alpha should be 32"
assert peft_config.task_type == "CAUSAL_LM", "Task type should be CAUSAL_LM"


@pytest.mark.huggingface
def test_lora_config_creation():
"""Test that standard LoRA config is created correctly."""
from cli.alora.train import train_model

# Mock all the heavy dependencies
with (
patch("cli.alora.train.AutoTokenizer") as mock_tokenizer_class,
patch("cli.alora.train.AutoModelForCausalLM") as mock_model_class,
patch("cli.alora.train.Dataset"),
patch("cli.alora.train.SafeSaveTrainer") as mock_trainer,
patch("cli.alora.train.get_peft_model") as mock_get_peft_model,
patch("cli.alora.train.load_dataset_from_json") as mock_load_dataset,
patch("cli.alora.train.DataCollatorForCompletionOnlyLM"),
):
# Setup mocks
mock_tokenizer = Mock()
mock_tokenizer.eos_token = "<eos>"
mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer

mock_model = Mock()
mock_model_class.from_pretrained.return_value = mock_model

mock_peft_model = Mock()
mock_get_peft_model.return_value = mock_peft_model

# Mock dataset
mock_ds = MagicMock()
mock_ds.shuffle.return_value = mock_ds
mock_ds.select.return_value = mock_ds
mock_ds.__len__ = Mock(return_value=10)
mock_load_dataset.return_value = mock_ds

# Mock trainer
mock_trainer_instance = Mock()
mock_trainer.return_value = mock_trainer_instance

# Call train_model with standard LoRA adapter
train_model(
dataset_path="test.jsonl",
base_model="test-model",
output_file="./test_output/adapter",
adapter="lora", # Standard LoRA, not aLoRA
epochs=1,
)

# Verify get_peft_model was called
assert mock_get_peft_model.called, "get_peft_model should be called"

# Get the LoraConfig that was passed to get_peft_model
call_args = mock_get_peft_model.call_args
assert call_args is not None, (
"get_peft_model should have been called with arguments"
)

peft_config = call_args[0][1] # Second argument is the config

# Verify it's a LoraConfig
assert isinstance(peft_config, LoraConfig), "Should use LoraConfig"

# Verify aLoRA-specific parameter is NOT set for standard LoRA
assert (
not hasattr(peft_config, "alora_invocation_tokens")
or peft_config.alora_invocation_tokens is None
), "Standard LoRA should not have alora_invocation_tokens"

# Verify other LoRA parameters
assert peft_config.r == 6, "Rank should be 6 for standard LoRA"
assert peft_config.lora_alpha == 32, "Alpha should be 32"
assert peft_config.task_type == "CAUSAL_LM", "Task type should be CAUSAL_LM"


@pytest.mark.huggingface
def test_invocation_prompt_tokenization():
"""Test that invocation prompt is correctly tokenized for aLoRA."""
from cli.alora.train import train_model

with (
patch("cli.alora.train.AutoTokenizer") as mock_tokenizer_class,
patch("cli.alora.train.AutoModelForCausalLM") as mock_model_class,
patch("cli.alora.train.get_peft_model") as mock_get_peft_model,
patch("cli.alora.train.load_dataset_from_json") as mock_load_dataset,
patch("cli.alora.train.SafeSaveTrainer"),
patch("cli.alora.train.DataCollatorForCompletionOnlyLM"),
patch("cli.alora.train.os.makedirs"),
):
# Setup tokenizer mock
mock_tokenizer = Mock()
custom_tokens = [111, 222, 333, 444]
mock_tokenizer.encode.return_value = custom_tokens
mock_tokenizer.eos_token = "<eos>"
mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer

# Setup other mocks
mock_model_class.from_pretrained.return_value = Mock()
mock_get_peft_model.return_value = Mock()

mock_ds = MagicMock()
mock_ds.shuffle.return_value = mock_ds
mock_ds.select.return_value = mock_ds
mock_ds.__len__ = Mock(return_value=10)
mock_load_dataset.return_value = mock_ds

# Call with custom invocation prompt
train_model(
dataset_path="test.jsonl",
base_model="test-model",
output_file="./test_output/adapter",
adapter="alora",
epochs=1,
)

# Verify tokenizer.encode was called with the invocation prompt
assert mock_tokenizer.encode.called, "Tokenizer encode should be called"

# Verify the config has the correct tokens
peft_config = mock_get_peft_model.call_args[0][1]
assert peft_config.alora_invocation_tokens == custom_tokens, (
"Config should have the tokenized invocation prompt"
)


def test_imports_work():
"""Test that PEFT imports work correctly (no IBM alora dependency)."""
# This test verifies the migration was successful
from peft import LoraConfig, get_peft_model

# Verify we can create a LoraConfig with alora_invocation_tokens
config = LoraConfig(
r=32, lora_alpha=32, task_type="CAUSAL_LM", alora_invocation_tokens=[1, 2, 3]
)

assert config.alora_invocation_tokens == [1, 2, 3], (
"LoraConfig should support alora_invocation_tokens parameter"
)

# Verify get_peft_model is available
assert callable(get_peft_model), "get_peft_model should be callable"
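Note that the three training-path tests are gated behind the `huggingface` pytest marker, while `test_imports_work` carries no marker and runs everywhere. Assuming a conventional marker registration, `pytest -m huggingface test/cli/test_alora_train.py` should exercise the mocked training paths on their own.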