74 changes: 74 additions & 0 deletions src/modelplane/evaluator/prompt_enricher.py
@@ -0,0 +1,74 @@
import string
from typing import Optional, Sequence

from modelgauge.config import load_secrets_from_config
from modelgauge.model_options import ModelOptions
from modelgauge.secret_values import RawSecrets
from modelgauge.sut import PromptResponseSUT, TextPrompt
from modelgauge.sut_factory import SUT_FACTORY

from modelplane.evaluator.context import EvalContext, NodeOutput
from modelplane.evaluator.cost import RealizedCost
from modelplane.evaluator.nodes import Enricher


class PromptEngineeredNode(Enricher):
"""Node that enriches the context by making an LLM call with a prompt template."""

def __init__(
self,
name: str,
routes: Sequence[str],
prompt_template: string.Template,
sut_id: str,
        model_options: Optional[ModelOptions] = None,
sut_secrets: Optional[RawSecrets] = None,
**sut_kwargs,
) -> None:
super().__init__(name=name, routes=routes)

subs = prompt_template.get_identifiers()
if not set(subs).issubset({"prompt", "response"}):
raise ValueError(
"Prompt template may only have 'prompt' and 'response' placeholders."
)
self.prompt_template = prompt_template

if model_options is None:
model_options = ModelOptions()
self.model_options = model_options

if sut_secrets is None:
sut_secrets = load_secrets_from_config()
sut = SUT_FACTORY.make_instance(uid=sut_id, secrets=sut_secrets, **sut_kwargs)
if not isinstance(sut, PromptResponseSUT):
raise ValueError(
f"PromptEngineeredAnnotator only works with PromptResponseSUTs. SUT {sut_id} is of type {type(sut)}"
)
self.sut: PromptResponseSUT = sut

def _build_prompt(self, ctx: EvalContext) -> TextPrompt:
return TextPrompt(
text=self.prompt_template.safe_substitute(
prompt=ctx.prompt, response=ctx.response
)
)

def _count_tokens(self, text: str) -> int:
        # Rough approximation: whitespace-delimited word count, not a real model tokenizer.
return len(text.split())

def run(self, ctx: EvalContext) -> NodeOutput:
prompt = self._build_prompt(ctx)
sut_request = self.sut.translate_text_prompt(prompt, options=self.model_options)
resp = self.sut.evaluate(sut_request)
sut_response = self.sut.translate_response(sut_request, resp)
return NodeOutput(
value=sut_response.text,
realized_cost=RealizedCost(
input_token_cost=self._count_tokens(prompt.text),
output_token_cost=self._count_tokens(sut_response.text),
),
original_ctx=ctx,
updated_ctx=ctx.with_response(sut_response.text),
)
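
Reviewer note: a minimal usage sketch of the new node, not part of the diff. It assumes the "demo_yes_no" SUT used in the tests below is registered with SUT_FACTORY and that its secrets resolve via load_secrets_from_config().

import string

from modelplane.evaluator.context import EvalContext
from modelplane.evaluator.prompt_enricher import PromptEngineeredNode

# Build a node whose template may reference only $prompt and $response.
node = PromptEngineeredNode(
    name="prompt_enricher",
    routes=["next_node"],
    prompt_template=string.Template("$prompt\n$response"),
    sut_id="demo_yes_no",
)

# run() substitutes the context into the template, calls the SUT, and
# returns a NodeOutput carrying the completion plus realized token costs.
output = node.run(EvalContext(prompt="this is even", response="with the response"))
print(output.value)                  # SUT completion text
print(output.realized_cost)          # whitespace-token counts for input/output
print(output.updated_ctx.response)   # context updated with the SUT's response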
14 changes: 13 additions & 1 deletion tests/unit/evaluator/conftest.py
@@ -1,13 +1,15 @@
"""Shared mock node implementations and helpers for evaluator tests."""

import os
import string

import pytest

from modelplane.evaluator.context import EvalContext
from modelplane.evaluator.dag import Composer
from modelplane.evaluator.prompt_enricher import PromptEngineeredNode
from modelplane.evaluator.safety import Safety
from modelplane.evaluator.verdict import Verdict

from .mocks import (
AlwaysFalse,
@@ -196,3 +198,13 @@ def bad_one_step_dag():
)
.add_node(AlwaysUnsafe(name="always_unsafe"))
)


@pytest.fixture
def prompt_enricher() -> PromptEngineeredNode:
return PromptEngineeredNode(
name="prompt_enricher",
routes=["next_node"],
prompt_template=string.Template("$prompt\n$response"),
sut_id="demo_yes_no",
)
33 changes: 33 additions & 0 deletions tests/unit/evaluator/test_prompt_enricher.py
@@ -0,0 +1,33 @@
import string

import pytest

from modelplane.evaluator.context import EvalContext
from modelplane.evaluator.prompt_enricher import PromptEngineeredNode


def test_prompt_enricher_run(prompt_enricher):
    # The fixture's "$prompt\n$response" template renders to six words here
    # (an even count), which the demo_yes_no SUT answers with "Yes".
    even_ctx = EvalContext(prompt="this is even", response="with the response")
output = prompt_enricher.run(even_ctx)
assert isinstance(output.value, str)
assert output.value == "Yes"

odd_ctx = EvalContext(prompt="this is not even", response="with the response")
output = prompt_enricher.run(odd_ctx)
assert isinstance(output.value, str)
assert output.value == "No"


def test_prompt_enricher_bad_template():
with pytest.raises(
ValueError,
match="Prompt template may only have 'prompt' and 'response' placeholders.",
):
PromptEngineeredNode(
name="bad_enricher",
routes=["next_node"],
prompt_template=string.Template(
"This template has an invalid placeholder: $invalid"
),
sut_id="demo_yes_no",
)
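
Side note for reviewers: the placeholder validation relies on string.Template.get_identifiers(), which only exists on Python 3.11+. A standalone illustration of the check the constructor performs:

import string

tmpl = string.Template("This template has an invalid placeholder: $invalid")
print(tmpl.get_identifiers())  # ['invalid']
print(set(tmpl.get_identifiers()).issubset({"prompt", "response"}))  # False -> ValueError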