6 changes: 5 additions & 1 deletion docs/alora.md
@@ -37,7 +37,7 @@ Use the `m alora train` command to fine-tune a LoRA or aLoRA adapter requirement

```bash
m alora train path/to/data.jsonl \
--basemodel ibm-granite/granite-3.2-8b-instruct \
--basemodel ibm-granite/granite-4.0-micro \
--outfile ./checkpoints/alora_adapter \
--adapter alora \
--epochs 6 \
@@ -47,6 +47,10 @@ m alora train path/to/data.jsonl \
--grad-accum 4
```

> **Note on Model Selection**: Use non-hybrid models (e.g., `granite-4.0-micro`) for aLoRA training.
> Hybrid models (e.g., `granite-4.0-h-micro`) are recommended for general inference, but adapters should
> be trained on non-hybrid base models for compatibility with the `ibm-granite/rag-intrinsics-lib` repository.

### 📌 Parameters

| Flag | Type | Default | Description |
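A minimal sketch of how the note above might play out after training: it assumes that `GraniteCommonAdapter` also accepts a local checkpoint directory in place of the `ibm-granite/rag-intrinsics-lib` repo id, and the checkpoint path is illustrative only.

```python
# Sketch: load the adapter trained above onto the same non-hybrid base for inference.
from mellea import start_session
from mellea.backends.adapters import GraniteCommonAdapter

m = start_session("huggingface.LocalHFBackend:ibm-granite/granite-4.0-micro")

# Assumption: a local checkpoint directory (the --outfile from `m alora train`)
# can be passed where the rag-intrinsics-lib repo id normally goes.
m.backend.add_adapter(
    GraniteCommonAdapter(
        "./checkpoints/alora_adapter",
        "requirement_check",
        base_model_name="granite-4.0-micro",  # must match the non-hybrid training base
    )
)
```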
4 changes: 2 additions & 2 deletions docs/dev/requirement_aLoRA_rerouting.md
@@ -37,10 +37,10 @@ from mellea.core import Requirement
from mellea.backends.adapters import GraniteCommonAdapter

m = start_session(
"huggingface.LocalHFBackend:ibm-granite/granite-3.2-8b-instruct")
"huggingface.LocalHFBackend:ibm-granite/granite-4.0-micro")

# By default, the AloraRequirement uses a GraniteCommonAdapter with "requirement_check".
m.backend.add_adapter(GraniteCommonAdapter("ibm-granite/rag-intrinsics-lib", "requirement_check", base_model_name="granite-3.2-8b-instruct"))
m.backend.add_adapter(GraniteCommonAdapter("ibm-granite/rag-intrinsics-lib", "requirement_check", base_model_name="granite-4.0-micro"))

m.instruct(
"Corporate wants you to find the difference between these two strings:\n\naaa\naba")
2 changes: 1 addition & 1 deletion docs/examples/aLora/101_example.py
@@ -13,7 +13,7 @@

# Define a backend and add the constraint aLora
backend = LocalHFBackend(
model_id="ibm-granite/granite-3.2-8b-instruct", cache=SimpleLRUCache(5)
model_id="ibm-granite/granite-4.0-h-micro", cache=SimpleLRUCache(5)
)

custom_stembolt_failure_constraint = HFConstraintAlora(
2 changes: 1 addition & 1 deletion docs/examples/generative_slots/generative_gsm8k.py
@@ -1,6 +1,6 @@
# pytest: ollama, qualitative, llm, slow

"""Example of chain-of-thought reasoning on a mathematical question from the GSM8K dataset, structured as code for improved performance with the granite-3.3-8B model. The original accuracy in standard "thinking" mode is approximately 80%, while this implementation achieves 85-89% accuracy—up to 9 points higher.
"""Example of chain-of-thought reasoning on a mathematical question from the GSM8K dataset, structured as code for improved performance with Granite 4 models. The original accuracy in standard "thinking" mode is approximately 80%, while this implementation achieves 85-89% accuracy—up to 9 points higher.

This demonstrates that generative decorators are sufficient for complex reasoning tasks: not only do they maintain or improve performance, but they also significantly enhance observability and control. For instance, the structured Thought titles can be easily surfaced in a UI, providing instant insight into the model's reasoning process.
"""
@@ -3,7 +3,7 @@
from docs.examples.helper import req_print, w
from mellea import start_session
from mellea.backends import ModelOption
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
from mellea.backends.model_ids import IBM_GRANITE_4_HYBRID_MICRO
from mellea.stdlib.sampling import RejectionSamplingStrategy

# create a session using Granite 4 Micro (3B) on Ollama and a simple context [see below]
2 changes: 1 addition & 1 deletion docs/examples/intrinsics/intrinsics.py
@@ -13,7 +13,7 @@
# Create the backend. Example for a VLLM Server. Commented out in favor of the hugging face code for now.
# # Assumes a locally running VLLM server.
# backend = OpenAIBackend(
# model_id="ibm-granite/granite-3.3-8b-instruct",
# model_id="ibm-granite/granite-4.0-micro",
# base_url="http://0.0.0.0:8000/v1",
# api_key="EMPTY",
# )
2 changes: 1 addition & 1 deletion docs/examples/m_serve/client.py
@@ -6,7 +6,7 @@

response = client.chat.completions.create(
messages=[{"role": "user", "content": "Find all the real roots of x^3 + 1."}],
model="granite3.3:8b",
model="granite4:micro-h",
)

print(response.choices[0])
2 changes: 1 addition & 1 deletion docs/examples/tutorial/document_mobject.py
@@ -1,7 +1,7 @@
# pytest: ollama, qualitative, llm, requires_heavy_ram

from mellea.backends import model_ids
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
from mellea.backends.model_ids import IBM_GRANITE_4_HYBRID_MICRO
from mellea.stdlib.components.docs.richdocument import RichDocument

rd = RichDocument.from_document_file("https://arxiv.org/pdf/1906.04043")
4 changes: 2 additions & 2 deletions docs/kv_smash/hf_example.py
@@ -2,7 +2,7 @@

from mellea.backends import ModelOption
from mellea.backends.huggingface import LocalHFBackend
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
from mellea.backends.model_ids import IBM_GRANITE_4_HYBRID_MICRO
from mellea.core import CBlock
from mellea.stdlib.components import Message
from mellea.stdlib.context import ChatContext
@@ -30,7 +30,7 @@ async def example():
role="user",
content="What is the likely ZIP code of Nathan Fulton's work address?",
)
backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B)
backend = LocalHFBackend(model_id=IBM_GRANITE_4_HYBRID_MICRO)
mot = await backend._generate_from_context_with_kv_cache(
action=msg, ctx=ctx, model_options={ModelOption.MAX_NEW_TOKENS: 64}
)
4 changes: 2 additions & 2 deletions docs/kv_smash/kv_with_chat.py
@@ -2,9 +2,9 @@

from mellea.backends.huggingface import LocalHFBackend
from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
from mellea.backends.model_ids import IBM_GRANITE_4_HYBRID_MICRO

backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B)
backend = LocalHFBackend(model_id=IBM_GRANITE_4_HYBRID_MICRO)

model = backend._model
tokenizer = backend._tokenizer
2 changes: 1 addition & 1 deletion docs/kv_smash/kvcache.py
@@ -9,7 +9,7 @@

from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches

model_id = "ibm-granite/granite-3.3-8b-instruct"
model_id = "ibm-granite/granite-4.0-tiny-preview"
device = torch.device("mps")
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model_id)
6 changes: 3 additions & 3 deletions docs/tutorial.md
@@ -791,7 +791,7 @@ We will train a lightweight adapter with the `m alora train` command on this sma
```bash
m alora train /to/stembolts_data.jsonl \
--promptfile ./prompt_config.json \
--basemodel ibm-granite/granite-3.2-8b-instruct \
--basemodel ibm-granite/granite-4.0-h-micro \
--outfile ./checkpoints/alora_adapter \
--adapter alora \
--epochs 6 \
@@ -1321,9 +1321,9 @@ Assuming a component's TemplateRepresentation contains a `template_order` field,

If the default formatter searches the template path or the package, it uses the following logic:
- look in the `.../templates/prompts/...` directory
- traverse sub-directories in that path that match the formatter's model id (ie `ibm-granite/granite-3.2-8b-instruct` will match `.../templates/prompts/granite/granite-3-2/instruct`) or default (ie `.../templates/prompts/default`)
- traverse sub-directories in that path that match the formatter's model id (ie `ibm-granite/granite-4.0-h-micro` will match `.../templates/prompts/granite/granite-4-0-h/micro`) or default (ie `.../templates/prompts/default`)
- return the template at the deepest directory path
- the default template formatter assumes that a model will only have one match in any given directory; in other words, traversing a `templates` directory with both `prompts/granite/...` and `prompts/ibm/...` for `ibm-granite/granite-3.2-8b-instruct` should not happen
- the default template formatter assumes that a model will only have one match in any given directory; in other words, traversing a `templates` directory with both `prompts/granite/...` and `prompts/ibm/...` for `ibm-granite/granite-4.0-h-micro` should not happen
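
A hypothetical sketch of the look-up rules just listed; the helper below is illustrative and not mellea's actual implementation.

```python
# Hypothetical sketch of the template-directory matching described above.
from pathlib import Path


def deepest_template_dir(prompts_root: Path, model_id: str) -> Path:
    """Descend from .../templates/prompts into sub-directories whose names match
    pieces of the model id, returning the deepest match (or prompts/default)."""
    # e.g. "ibm-granite/granite-4.0-h-micro" -> "granite-4-0-h-micro"
    normalized = model_id.split("/")[-1].replace(".", "-").lower()
    current = prompts_root
    while True:
        match = next(
            (d for d in current.iterdir() if d.is_dir() and d.name in normalized),
            None,  # relies on at most one matching sub-directory per level
        )
        if match is None:
            break
        current = match
    return current if current != prompts_root else prompts_root / "default"
```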

#### Editing an Existing Class
To customize the template and template representation of an existing class, simply create a new class that inherits from the class you want to edit. Then, override the format_for_llm function and create a new template.
39 changes: 35 additions & 4 deletions mellea/backends/model_ids.py
@@ -25,14 +25,28 @@ class ModelIdentifier:
#### IBM models ####
####################

IBM_GRANITE_4_MICRO_3B = ModelIdentifier(
hf_model_name="ibm-granite/granite-4.0-micro",
ollama_name="granite4:micro",
# Granite 4 Hybrid Models (Recommended for general use)
IBM_GRANITE_4_HYBRID_MICRO = ModelIdentifier(
hf_model_name="ibm-granite/granite-4.0-h-micro",
ollama_name="granite4:micro-h",
watsonx_name=None, # Only h-small available on Watsonx
)

IBM_GRANITE_4_HYBRID_TINY = ModelIdentifier(
hf_model_name="ibm-granite/granite-4.0-h-tiny",
ollama_name="granite4:tiny-h",
watsonx_name=None, # Only h-small available on Watsonx
)

IBM_GRANITE_4_HYBRID_SMALL = ModelIdentifier(
hf_model_name="ibm-granite/granite-4.0-h-small",
ollama_name="granite4:small-h",
watsonx_name="ibm/granite-4-h-small",
)
# todo: watsonx model is different from ollama model - should be same.


# Deprecated Granite 3 models - kept for backward compatibility
# These maintain their original model references (not upgraded to Granite 4)
IBM_GRANITE_3_2_8B = ModelIdentifier(
hf_model_name="ibm-granite/granite-3.2-8b-instruct",
ollama_name="granite3.2:8b",
@@ -45,6 +59,23 @@ class ModelIdentifier:
watsonx_name="ibm/granite-3-3-8b-instruct",
)

# Deprecated: Use IBM_GRANITE_4_HYBRID_MICRO or IBM_GRANITE_4_HYBRID_SMALL instead
# Kept for backward compatibility with per-backend model selection:
# - Ollama/HF: Uses MICRO (fits in CI memory constraints)
# - Watsonx: Uses SMALL (required for watsonx support)
IBM_GRANITE_4_MICRO_3B = ModelIdentifier(
hf_model_name="ibm-granite/granite-4.0-h-micro",
ollama_name="granite4:micro-h",
watsonx_name="ibm/granite-4-h-small",
)

# Granite 3.3 Vision Model (2B)
IBM_GRANITE_3_3_VISION_2B = ModelIdentifier(
hf_model_name="ibm-granite/granite-vision-3.3-2b",
ollama_name="ibm/granite3.3-vision:2b",
watsonx_name=None,
)

IBM_GRANITE_GUARDIAN_3_0_2B = ModelIdentifier(
hf_model_name="ibm-granite/granite-guardian-3.0-2b",
ollama_name="granite3-guardian:2b",
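For orientation, the new identifiers resolve to per-provider names as defined above; a quick check:

```python
# Quick illustration of how the Granite 4 identifiers map to provider-specific names.
from mellea.backends.model_ids import (
    IBM_GRANITE_4_HYBRID_MICRO,
    IBM_GRANITE_4_HYBRID_SMALL,
)

print(IBM_GRANITE_4_HYBRID_MICRO.hf_model_name)  # ibm-granite/granite-4.0-h-micro
print(IBM_GRANITE_4_HYBRID_MICRO.ollama_name)    # granite4:micro-h
print(IBM_GRANITE_4_HYBRID_MICRO.watsonx_name)   # None -- only h-small is on watsonx
print(IBM_GRANITE_4_HYBRID_SMALL.watsonx_name)   # ibm/granite-4-h-small
```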
2 changes: 1 addition & 1 deletion mellea/backends/openai.py
@@ -83,7 +83,7 @@ def __init__(
"""Initialize and OpenAI compatible backend. For any additional kwargs that you need to pass the the client, pass them as a part of **kwargs.

Args:
model_id : A generic model identifier or OpenAI compatible string. Defaults to model_ids.IBM_GRANITE_3_3_8B.
model_id : A generic model identifier or OpenAI compatible string. Defaults to model_ids.IBM_GRANITE_4_HYBRID_MICRO.
formatter: A custom formatter based on backend. If None, defaults to TemplateFormatter
base_url : Base url for LLM API. Defaults to None.
model_options : Generation options to pass to the LLM. Defaults to None.
4 changes: 2 additions & 2 deletions mellea/backends/watsonx.py
@@ -54,7 +54,7 @@ class WatsonxAIBackend(FormatterBackend):

def __init__(
self,
model_id: str | ModelIdentifier = model_ids.IBM_GRANITE_3_3_8B,
model_id: str | ModelIdentifier = model_ids.IBM_GRANITE_4_HYBRID_SMALL,
formatter: ChatFormatter | None = None,
base_url: str | None = None,
model_options: dict | None = None,
@@ -66,7 +66,7 @@ def __init__(
"""A generic watsonx backend that wraps around the ibm_watsonx_ai sdk.

Args:
model_id : Model id. Defaults to model_ids.IBM_GRANITE_3_3_8B.
model_id : Model id. Defaults to model_ids.IBM_GRANITE_4_HYBRID_SMALL.
formatter : input formatter. Defaults to TemplateFormatter in __init__.
base_url : url for watson ML deployment. Defaults to env(WATSONX_URL).
model_options : Global model options to pass to the model. Defaults to None.
9 changes: 7 additions & 2 deletions test/backends/test_huggingface.py
@@ -27,7 +27,7 @@
]

from mellea import MelleaSession
from mellea.backends import ModelOption
from mellea.backends import ModelOption, model_ids
from mellea.backends.adapters import GraniteCommonAdapter
from mellea.backends.cache import SimpleLRUCache
from mellea.backends.huggingface import LocalHFBackend, _assert_correct_adapters
@@ -46,7 +46,12 @@

@pytest.fixture(scope="module")
def backend():
"""Shared HuggingFace backend for all tests in this module."""
"""Shared HuggingFace backend for all tests in this module.

Uses Granite 3.3-8b for aLoRA adapter compatibility.
The ibm-granite/rag-intrinsics-lib repository only has adapters for
Granite 3.3 models. Granite 4 adapters are not yet available.
"""
backend = LocalHFBackend(
model_id="ibm-granite/granite-3.3-8b-instruct",
formatter=TemplateFormatter(model_id="ibm-granite/granite-4.0-tiny-preview"),
2 changes: 1 addition & 1 deletion test/backends/test_litellm_ollama.py
@@ -14,7 +14,7 @@
from mellea.stdlib.context import SimpleContext
from mellea.stdlib.sampling import RejectionSamplingStrategy

_MODEL_ID = f"ollama_chat/{model_ids.IBM_GRANITE_4_MICRO_3B.ollama_name}"
_MODEL_ID = f"ollama_chat/{model_ids.IBM_GRANITE_4_HYBRID_MICRO.ollama_name}"


@pytest.fixture(scope="function")
6 changes: 3 additions & 3 deletions test/backends/test_openai_ollama.py
@@ -12,7 +12,7 @@

from mellea import MelleaSession
from mellea.backends import ModelOption
from mellea.backends.model_ids import IBM_GRANITE_4_MICRO_3B
from mellea.backends.model_ids import IBM_GRANITE_4_HYBRID_MICRO
from mellea.backends.openai import OpenAIBackend
from mellea.core import CBlock, ModelOutputThunk
from mellea.formatters import TemplateFormatter
@@ -23,8 +23,8 @@
def backend(gh_run: int):
"""Shared OpenAI backend configured for Ollama."""
return OpenAIBackend(
model_id=IBM_GRANITE_4_MICRO_3B.ollama_name, # type: ignore
formatter=TemplateFormatter(model_id=IBM_GRANITE_4_MICRO_3B.hf_model_name), # type: ignore
model_id=IBM_GRANITE_4_HYBRID_MICRO.ollama_name, # type: ignore
formatter=TemplateFormatter(model_id=IBM_GRANITE_4_HYBRID_MICRO.hf_model_name), # type: ignore
base_url=f"http://{os.environ.get('OLLAMA_HOST', 'localhost:11434')}/v1",
api_key="ollama",
)
2 changes: 1 addition & 1 deletion test/backends/test_openai_vllm/test_openai_vllm.py
@@ -133,7 +133,7 @@ class Answer(pydantic.BaseModel):
class TestOpenAIALoraStuff:
backend = OpenAIBackend(
model_id="ibm-granite/granite-3.3-8b-instruct",
formatter=TemplateFormatter(model_id="ibm-granite/granite-4.0-tiny-preview"),
formatter=TemplateFormatter(model_id="ibm-granite/granite-3.3-8b-instruct"),
base_url="http://localhost:8000/v1",
api_key="EMPTY",
)
2 changes: 1 addition & 1 deletion test/backends/test_vision_ollama.py
@@ -9,7 +9,7 @@
pytestmark = [pytest.mark.ollama, pytest.mark.llm]

from mellea import MelleaSession, start_session
from mellea.backends import ModelOption
from mellea.backends import ModelOption, model_ids
from mellea.core import ImageBlock, ModelOutputThunk
from mellea.stdlib.components import Instruction, Message

4 changes: 2 additions & 2 deletions test/backends/test_vision_openai.py
@@ -11,7 +11,7 @@

from mellea import MelleaSession, start_session
from mellea.backends import ModelOption
from mellea.backends.model_ids import IBM_GRANITE_4_MICRO_3B
from mellea.backends.model_ids import IBM_GRANITE_4_HYBRID_MICRO
from mellea.core import ImageBlock, ModelOutputThunk
from mellea.stdlib.components import Instruction, Message

@@ -21,7 +21,7 @@ def m_session(gh_run):
if gh_run == 1:
m = start_session(
"openai",
model_id=IBM_GRANITE_4_MICRO_3B.ollama_name, # type: ignore
model_id=IBM_GRANITE_4_HYBRID_MICRO.ollama_name, # type: ignore
base_url=f"http://{os.environ.get('OLLAMA_HOST', 'localhost:11434')}/v1",
api_key="ollama",
model_options={ModelOption.MAX_NEW_TOKENS: 5},
21 changes: 18 additions & 3 deletions test/backends/test_watsonx.py
@@ -18,7 +18,7 @@
]

from mellea import MelleaSession
from mellea.backends import ModelOption
from mellea.backends import ModelOption, model_ids
from mellea.backends.watsonx import WatsonxAIBackend
from mellea.core import CBlock, ModelOutputThunk
from mellea.formatters import TemplateFormatter
@@ -32,8 +32,8 @@ def backend():
pytest.skip("Skipping watsonx tests.")
else:
return WatsonxAIBackend(
model_id="ibm/granite-3-3-8b-instruct",
formatter=TemplateFormatter(model_id="ibm-granite/granite-3.3-8b-instruct"),
model_id=model_ids.IBM_GRANITE_4_HYBRID_SMALL,
formatter=TemplateFormatter(model_id=model_ids.IBM_GRANITE_4_HYBRID_SMALL),
)


@@ -223,6 +223,21 @@ async def get_client_async():
assert len(backend._client_cache.cache.values()) == 2


def test_default_model():
"""Verify WatsonxAIBackend uses correct default model."""
if int(os.environ.get("CICD", 0)) == 1:
pytest.skip("Skipping watsonx tests.")

# Create backend without specifying model_id
default_backend = WatsonxAIBackend()

# Verify it uses IBM_GRANITE_4_HYBRID_SMALL as default
assert default_backend._model_id == model_ids.IBM_GRANITE_4_HYBRID_SMALL, (
f"Expected default model to be IBM_GRANITE_4_HYBRID_SMALL, "
f"but got {default_backend._model_id}"
)


if __name__ == "__main__":
import pytest

9 changes: 4 additions & 5 deletions test/core/test_component_typing.py
@@ -6,7 +6,7 @@

import mellea.stdlib.functional as mfuncs
from mellea import MelleaSession, start_session
from mellea.backends.model_ids import IBM_GRANITE_4_MICRO_3B
from mellea.backends.model_ids import IBM_GRANITE_4_HYBRID_MICRO
from mellea.backends.ollama import OllamaModelBackend
from mellea.core import (
CBlock,
@@ -64,10 +64,10 @@ def backend(gh_run: int):
"""Shared backend."""
if gh_run == 1:
return OllamaModelBackend(
model_id=IBM_GRANITE_4_MICRO_3B.ollama_name # type: ignore
model_id=IBM_GRANITE_4_HYBRID_MICRO.ollama_name # type: ignore
)
else:
return OllamaModelBackend(model_id="granite3.3:8b")
return OllamaModelBackend(model_id=IBM_GRANITE_4_HYBRID_MICRO.ollama_name) # type: ignore


@pytest.fixture(scope="module")
@@ -117,11 +117,10 @@ def test_incorrect_type_override():


# Marking as qualitative for now since there's so much generation required for this.
# Uses granite3.3:8b (8B, heavy) in local mode
# Uses granite4:micro-h (3B hybrid, lightweight) in local mode
@pytest.mark.qualitative
@pytest.mark.ollama
@pytest.mark.requires_gpu
@pytest.mark.requires_heavy_ram
@pytest.mark.llm
async def test_generating(session):
m = session