diff --git a/.github/workflows/_pull_requests.yml b/.github/workflows/_pull_requests.yml index d73a91f9..54fdb78a 100644 --- a/.github/workflows/_pull_requests.yml +++ b/.github/workflows/_pull_requests.yml @@ -16,10 +16,10 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v7.0.0 - name: Set up Miniconda - uses: conda-incubator/setup-miniconda@v3 + uses: conda-incubator/setup-miniconda@v4.0.1 with: activate-environment: graphrag-eval environment-file: environment.yml diff --git a/.github/workflows/_release.yml b/.github/workflows/_release.yml index e491a656..874a9515 100644 --- a/.github/workflows/_release.yml +++ b/.github/workflows/_release.yml @@ -9,10 +9,10 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v7.0.0 - name: Set up Miniconda - uses: conda-incubator/setup-miniconda@v3 + uses: conda-incubator/setup-miniconda@v4.0.1 with: activate-environment: graphrag-eval environment-file: environment.yml diff --git a/docs/config.md b/docs/config.md index 5f4cb724..3b5ddb85 100644 --- a/docs/config.md +++ b/docs/config.md @@ -5,6 +5,9 @@ `run_evaluation()` and `compute_aggregates()` are configured using a YAML file whose path is passed as the optional parameter `config_file_path`. Example call: ```python +from graphrag_eval import run_evaluation + + evaluation_results = await run_evaluation( reference_dataset, target_dataset, @@ -43,6 +46,7 @@ The configuration has the following structure: - `instructions`: (`str`) instructions for the evaluation - `outputs`: (`map[str,str]`) output variable names and descriptions - `answer_correctness` + - `enabled` (`bool`, default: `True`) - if `False`, then [answer correctness metrics](https://github.com/Ontotext-AD/graphrag-eval/blob/main/docs/metrics.md) won't be calculated. 
- `prompt`: (`str`, default [here](https://github.com/Ontotext-AD/graphrag-eval/blob/main/graphrag_eval/prompts/template.md)) Template for instructions to the LLM on how to compute the answer correctness. Must contain placeholders `{question}`, `{reference_answer}` and `{actual_answer}` and no others. @@ -52,6 +56,8 @@ The configuration has the following structure: # Output format ``` +- `answer_relevance` + - `enabled` (`bool`, default: `True`) - if `False`, then [answer relevance metric](https://github.com/Ontotext-AD/graphrag-eval/blob/main/docs/metrics.md) won't be calculated. ## Example configuration file with LLM configuration diff --git a/docs/metrics.md b/docs/metrics.md index 091c736a..6fe7be87 100644 --- a/docs/metrics.md +++ b/docs/metrics.md @@ -50,6 +50,9 @@ See [§ Example configuration file with custom evaluation](https://github.com/On ### Example call to evaluate using custom metrics ```python +from graphrag_eval import run_evaluation, compute_aggregates + + evaluation_results = await run_evaluation( reference_qa_dataset, chat_responses, diff --git a/docs/output.md b/docs/output.md index 9c3cce9e..6a0a154c 100644 --- a/docs/output.md +++ b/docs/output.md @@ -18,7 +18,7 @@ The output is a list of objects, one for each reference item.
Each output object - `answer_recall`: (optional) `answer_matching_claims_count / answer_reference_claims_count` - `answer_precision`: (optional) `answer_matching_claims_count / answer_actual_claims_count` - `answer_correctness_reason`: (optional) LLM reasoning in extracting and matching claims from `reference_answer` and `actual_answer` -- `answer_eval_error`: (optional) error message if answer evaluation failed +- `answer_correctness_error`: (optional) error message if answer correctness evaluation failed - `answer_f1`: (optional) Harmonic mean of `answer_recall` and `answer_precision` - `answer_relevance`: (optional `float` in [0, 1]) answer relevance score - `answer_relevance_error`: (optional) error message if answer relevance evaluation failed diff --git a/graphrag_eval/aggregation.py b/graphrag_eval/aggregation.py index acedd665..461fed4e 100644 --- a/graphrag_eval/aggregation.py +++ b/graphrag_eval/aggregation.py @@ -1,13 +1,13 @@ import json -import yaml from collections import defaultdict from collections.abc import Sequence from pathlib import Path from statistics import mean, median from typing import Any, Collection, Iterable -from . import evaluation +import yaml +from . 
import evaluation METRICS = [ "answer_recall", @@ -155,7 +155,7 @@ def compute_micro_stats( ) -> dict: if custom_metrics is None: custom_metrics = [] - + values = number_of_samples_per_template_by_status.values() micro_summary = defaultdict(dict, { "number_of_error_samples": sum(v["error"] for v in values), @@ -197,7 +197,7 @@ def compute_macro_stats( ) -> dict: if custom_metrics is None: custom_metrics = [] - + macro_summary = defaultdict(dict) for metric in METRICS + custom_metrics: means = [ diff --git a/graphrag_eval/answer_correctness.py b/graphrag_eval/answer_correctness.py index b2bca48d..100a095d 100644 --- a/graphrag_eval/answer_correctness.py +++ b/graphrag_eval/answer_correctness.py @@ -1,21 +1,36 @@ +from __future__ import annotations + from pathlib import Path +from typing import Any, Self, TYPE_CHECKING from pydantic import BaseModel, Field from graphrag_eval.util import compute_f1 +from .evaluator import Evaluator + +if TYPE_CHECKING: + from ragas.llms.base import InstructorBaseRagasLLM def load_default_prompt() -> str: - with open(Path(__file__).parent / "prompts" / "template.md", "r", encoding="utf-8") as f: + with open( + Path(__file__).parent / "prompts" / "template.md", + encoding="utf-8" + ) as f: return f.read() class AnswerCorrectnessConfig(BaseModel): + enabled: bool = Field(default=True) prompt: str = Field(default_factory=load_default_prompt) class InvalidPromptException(Exception): - def __init__(self, message="The prompt template is invalid and cannot be formatted."): + def __init__( + self, + message="The prompt template is invalid and cannot be " + "formatted." 
+ ): self.message = message super().__init__(self.message) @@ -23,13 +38,25 @@ def __init__(self, message="The prompt template is invalid and cannot be formatt class AnswerCorrectnessEvaluator: def __init__( self, - llm: "InstructorBaseRagasLLM", + ragas_llm: InstructorBaseRagasLLM, config: AnswerCorrectnessConfig | None = None, ): self.config = config or AnswerCorrectnessConfig() self.__validate_prompt_template(self.config.prompt) self.prompt_template = self.config.prompt - self.llm = llm + self.ragas_llm = ragas_llm + + @classmethod + def from_config( + cls, + ragas_llm: InstructorBaseRagasLLM | None, + config: AnswerCorrectnessConfig | None + ) -> Self | None: + if ragas_llm is None: + return None + if config is None or not config.enabled: + return None + return cls(ragas_llm=ragas_llm, config=config) @staticmethod def __validate_prompt_template(prompt_template: str): @@ -48,7 +75,7 @@ def __validate_prompt_template(prompt_template: str): async def _agenerate(self, prompt): """Wrapper method for easier testing""" - return (await self.llm.agenerate(prompt, None)).choices[0].message.content + return (await self.ragas_llm.agenerate(prompt, None)).choices[0].message.content async def evaluate_answer( self, @@ -56,9 +83,13 @@ async def evaluate_answer( reference_answer: str, actual_answer: str ) -> tuple[int, int, int, str]: - if any(not s.strip() for s in [question, reference_answer, actual_answer]): - raise ValueError("The question of the reference or the actual answer is a blank " - "string!") + if any( + not s.strip() for s in [question, reference_answer, actual_answer] + ): + raise ValueError( + "The question of the reference or the actual answer is a blank " + "string!" 
+ ) prompt = self.prompt_template.format( question=question, reference_answer=reference_answer, @@ -67,12 +98,14 @@ async def evaluate_answer( response_str = await self._agenerate(prompt) return self.extract_response_values(response_str) - async def get_correctness_dict( + async def evaluate( self, - reference: dict, - actual: dict, - ): - result = {"reference_answer": reference["reference_answer"]} + reference: dict[str, Any], + actual: dict[str, Any] + ) -> dict[str, Any]: + if "actual_answer" not in actual or "reference_answer" not in reference: + return {} + result = {} try: num_ref_claims, num_actual_claims, num_matching_claims, reason = \ await self.evaluate_answer( @@ -96,7 +129,7 @@ async def get_correctness_dict( if f1 is not None: result["answer_f1"] = f1 except Exception as exc: - result["answer_eval_error"] = str(exc) + result["answer_correctness_error"] = str(exc) return result @staticmethod @@ -134,6 +167,10 @@ def extract_response_values( n_matching > n_actual ]): raise ValueError( - f"Invalid claims counts combination: {n_ref}\t{n_actual}\t{n_matching}" + "Invalid claims counts combination: " + f"{n_ref}\t{n_actual}\t{n_matching}" ) return n_ref, n_actual, n_matching, vals[3] + + +_: Evaluator = AnswerCorrectnessEvaluator diff --git a/graphrag_eval/answer_relevance.py b/graphrag_eval/answer_relevance.py index 2e4c367a..c1876d8c 100644 --- a/graphrag_eval/answer_relevance.py +++ b/graphrag_eval/answer_relevance.py @@ -1,24 +1,53 @@ -from ragas.embeddings.base import BaseRagasEmbedding -from ragas.llms.base import InstructorBaseRagasLLM -from ragas.metrics.collections import AnswerRelevancy +from __future__ import annotations -from graphrag_eval.util import singleton +from typing import Any, Self, TYPE_CHECKING +from pydantic import BaseModel, Field -@singleton -class Evaluator: - def __init__(self, ragas_llm: InstructorBaseRagasLLM, ragas_embedder: BaseRagasEmbedding): +from .evaluator import Evaluator + +if TYPE_CHECKING: + from ragas.llms.base 
import InstructorBaseRagasLLM + from ragas.embeddings.base import BaseRagasEmbeddings, BaseRagasEmbedding + + +class AnswerRelevanceConfig(BaseModel): + enabled: bool = Field(default=True) + + +class AnswerRelevanceEvaluator: + def __init__( + self, + ragas_llm: InstructorBaseRagasLLM, + ragas_embedder: BaseRagasEmbeddings | BaseRagasEmbedding + ): + from ragas.metrics.collections import AnswerRelevancy self.scorer = AnswerRelevancy(llm=ragas_llm, embeddings=ragas_embedder) - async def get_relevance_dict( + @classmethod + def from_config( + cls, + ragas_llm: InstructorBaseRagasLLM | None, + ragas_embedder: BaseRagasEmbeddings | BaseRagasEmbedding | None, + config: AnswerRelevanceConfig | None + ) -> Self | None: + if ragas_llm is None or ragas_embedder is None: + return None + if config is None or not config.enabled: + return None + return cls(ragas_llm=ragas_llm, ragas_embedder=ragas_embedder) + + async def evaluate( self, - question_text: str, - actual_answer: str, - ) -> dict: + reference: dict[str, Any], + actual: dict[str, Any] + ) -> dict[str, Any]: + if "actual_answer" not in actual: + return {} try: result = await self.scorer.ascore( - user_input=question_text, - response=actual_answer + user_input=reference["question_text"], + response=actual["actual_answer"] ) return { "answer_relevance": result.value @@ -27,3 +56,6 @@ async def get_relevance_dict( return { "answer_relevance_error": str(e) } + + +_: Evaluator = AnswerRelevanceEvaluator diff --git a/graphrag_eval/cli/answer_correctness.py b/graphrag_eval/cli/answer_correctness.py index 32b62b26..38122df0 100644 --- a/graphrag_eval/cli/answer_correctness.py +++ b/graphrag_eval/cli/answer_correctness.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import argparse import asyncio import csv from argparse import ArgumentParser from pathlib import Path +from typing import TYPE_CHECKING from tqdm import tqdm @@ -10,34 +13,39 @@ from graphrag_eval.answer_correctness import AnswerCorrectnessEvaluator from 
graphrag_eval.evaluation import Config +if TYPE_CHECKING: + from ragas.llms.base import InstructorBaseRagasLLM + def parse_args() -> argparse.Namespace: parser = ArgumentParser( - description="Calculates answer correctness over the entries from the input tsv file and " - "stores the output in the output tsv file.", + description="Calculates answer correctness over the entries from the " + "input tsv file and stores the output in the output tsv " + "file.", ) parser.add_argument( "-i", "--input-tsv-file-path", type=Path, required=True, - help="Input tsv file path with columns `Question`, `Reference answer` and `Actual answer`", + help="Input tsv file path with columns `Question`, `Reference answer` " + "and `Actual answer`", ) parser.add_argument( "-o", "--output-tsv-file-path", type=Path, required=True, - help="Output tsv file path with columns `#Reference`, `#PTarget`, `#Matching`, " - "`Reasoning`, `Error`", + help="Output tsv file path with columns `#Reference`, `#PTarget`, " + "`#Matching`, `Reasoning`, `Error`", ) parser.add_argument( "-c", "--config-yaml-file-path", type=Path, required=True, - help="Config yaml file path with definition of the LLM to use and optionally a custom " - "prompt.", + help="Config yaml file path with definition of the LLM to use and " + "optionally a custom prompt.", ) return parser.parse_args() @@ -54,7 +62,9 @@ async def evaluate_and_write( output_tsv_file_path.parent.mkdir(parents=True, exist_ok=True) with open(output_tsv_file_path, "w", encoding="utf-8") as f: writer = csv.writer(f, delimiter="\t") - writer.writerow(["#Reference", "#PTarget", "#Matching", "Reasoning", "Error"]) + writer.writerow( + ["#Reference", "#PTarget", "#Matching", "Reasoning", "Error"] + ) for row in tqdm(rows): if "Question" not in row or \ @@ -81,19 +91,26 @@ def run( output_tsv_file_path: Path, ): config = Config.parse(config_yaml_file_path) - ragas_llm = llm_factory.create_llm(config) + ragas_llm: InstructorBaseRagasLLM | None = 
llm_factory.create_llm( + config.llm + ) if ragas_llm is None: - raise ValueError("LLM must be configured to calculate the answer correctness!") - else: - evaluator = AnswerCorrectnessEvaluator( - llm=ragas_llm, - config=config.answer_correctness, + raise ValueError( + "LLM must be configured to calculate the answer correctness!" + ) + if config.answer_correctness and not config.answer_correctness.enabled: + raise ValueError( + "Can't disable answer correctness, when running this script!" ) - asyncio.run(evaluate_and_write( - input_tsv_file_path, - output_tsv_file_path, - evaluator, - )) + evaluator = AnswerCorrectnessEvaluator( + ragas_llm=ragas_llm, + config=config.answer_correctness, + ) + asyncio.run(evaluate_and_write( + input_tsv_file_path, + output_tsv_file_path, + evaluator, + )) def main(): diff --git a/graphrag_eval/custom_evaluation.py b/graphrag_eval/custom_evaluation.py index 7e43fcc9..9a55e856 100644 --- a/graphrag_eval/custom_evaluation.py +++ b/graphrag_eval/custom_evaluation.py @@ -1,9 +1,14 @@ +from __future__ import annotations + import json -from typing import Literal +from typing import Literal, Self, TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict, Field, model_validator -from graphrag_eval.llm_factory import create_llm +from .evaluator import Evaluator + +if TYPE_CHECKING: + from ragas.llms.base import InstructorBaseRagasLLM RESERVED_KEYS = { "template_id", @@ -43,7 +48,7 @@ StepsKey = Literal["args", "output"] -class Config(BaseModel): +class EvaluatorConfig(BaseModel): model_config = ConfigDict(extra='forbid') name: str inputs: list[Inputs] = Field(..., min_length=1) @@ -53,7 +58,7 @@ class Config(BaseModel): steps_keys: set[StepsKey] | None = Field(default=None, min_length=1) @model_validator(mode='after') - def validate_step_dependencies(self) -> 'Config': + def validate_step_dependencies(self) -> Self: if set(self.inputs) & {"reference_steps", "actual_steps"}: suffix = "is required when steps are in inputs" for var_name in 
["steps_name", "steps_keys"]: @@ -62,7 +67,7 @@ def validate_step_dependencies(self) -> 'Config': return self @model_validator(mode='after') - def validate_name_and_outputs(self) -> 'Config': + def validate_name_and_outputs(self) -> Self: if self.name + "_error" in RESERVED_KEYS: raise ValueError(f"Name {self.name} is reserved") conflicting_keys = set(self.outputs.keys()) & RESERVED_KEYS @@ -76,7 +81,7 @@ def create_input_template(input_key: str) -> str: return f"# {header}\n{{{input_key}}}" -def create_prompt_template(config: Config, output_variables: list[str]) -> str: +def create_prompt_template(config: EvaluatorConfig, output_variables: list[str]) -> str: """ Return a template for the LLM prompt, with placeholders for the inputs, instructions, outputs etc. We use this template at evaluation time to @@ -99,8 +104,8 @@ def create_prompt_template(config: Config, output_variables: list[str]) -> str: class CustomEvaluator: def __init__( self, - config: Config, - eval_config: "evaluation.Config", + ragas_llm: InstructorBaseRagasLLM, + config: EvaluatorConfig, ): self.name = config.name self.input_variables = config.inputs @@ -111,11 +116,24 @@ def __init__( config, self.output_variables ) - self.llm = create_llm(eval_config) + self.ragas_llm = ragas_llm + + @classmethod + def from_config( + cls, + ragas_llm: InstructorBaseRagasLLM | None, + evaluation_configs: list[EvaluatorConfig] | None + ) -> list[Self]: + if ragas_llm and evaluation_configs: + return [ + cls(ragas_llm, evaluation_config) + for evaluation_config in evaluation_configs + ] + return [] async def _agenerate(self, prompt: str) -> str: """Wrapper method for easier testing""" - return (await self.llm.agenerate(prompt, None)).choices[0].message.content + return (await self.ragas_llm.agenerate(prompt, None)).choices[0].message.content def format_steps(self, steps: list) -> str: steps_formatted = [] @@ -157,7 +175,11 @@ def parse_outputs(self, response: str) -> dict[str, str | None]: return result return 
self.error(f"Expected {n_exp} tab-separated values, got: {response}") - async def evaluate(self, reference: dict, actual: dict) -> dict[str, str | None]: + async def evaluate( + self, + reference: dict[str, Any], + actual: dict[str, Any] + ) -> dict[str, Any]: inputs = {} if "question" in self.input_variables: if "question_text" not in reference: @@ -195,10 +217,4 @@ async def evaluate(self, reference: dict, actual: dict) -> dict[str, str | None] return self.parse_outputs(response) -def create_evaluators(config: "evaluation.Config") -> list[CustomEvaluator]: - if config.custom_evaluations and config.llm: - return [ - CustomEvaluator(custom_evaluation_config, config) - for custom_evaluation_config in config.custom_evaluations - ] - return [] +_: Evaluator = CustomEvaluator diff --git a/graphrag_eval/evaluation.py b/graphrag_eval/evaluation.py index d3dfe80c..d5e6315e 100644 --- a/graphrag_eval/evaluation.py +++ b/graphrag_eval/evaluation.py @@ -1,29 +1,71 @@ +from __future__ import annotations + from pathlib import Path +from typing import Self, TYPE_CHECKING import yaml from pydantic import BaseModel, Field, model_validator -from . 
import custom_evaluation -from .answer_correctness import AnswerCorrectnessConfig -from .llm_factory import Config as LLMConfig, create_llm, create_embedder +from .answer_correctness import ( + AnswerCorrectnessConfig, + AnswerCorrectnessEvaluator, +) +from .answer_relevance import AnswerRelevanceConfig, AnswerRelevanceEvaluator +from .custom_evaluation import EvaluatorConfig, CustomEvaluator +from .evaluator import Evaluator +from .llm_factory import LLMConfig, create_llm, create_embedder from .steps.evaluation import evaluate_steps +if TYPE_CHECKING: + from ragas.llms.base import InstructorBaseRagasLLM + from ragas.embeddings.base import BaseRagasEmbeddings, BaseRagasEmbedding + class Config(BaseModel): llm: LLMConfig | None = None - custom_evaluations: list[custom_evaluation.Config] | None \ - = Field(default=None, min_length=1) + custom_evaluations: list[EvaluatorConfig] | None = Field( + default=None, + min_length=1 + ) answer_correctness: AnswerCorrectnessConfig | None = None + answer_relevance: AnswerRelevanceConfig | None = None @model_validator(mode="after") - def validate_config(self) -> "Config": - if self.custom_evaluations and not self.llm: - msg = "llm config is required if custom_evaluations are provided" - raise ValueError(msg) + def validate_config_and_set_defaults(self) -> Self: + has_llm = self.llm is not None + has_embedding = has_llm and self.llm.embedding is not None + + if self.answer_correctness is None and has_llm: + self.answer_correctness = AnswerCorrectnessConfig() + + if self.answer_relevance is None and has_embedding: + self.answer_relevance = AnswerRelevanceConfig() + + if self.custom_evaluations and not has_llm: + raise ValueError( + "llm config is required if custom_evaluations are provided" + ) + if ( + self.answer_correctness + and self.answer_correctness.enabled + and not has_llm + ): + raise ValueError( + "llm config is required if answer correctness is enabled" + ) + if ( + self.answer_relevance + and 
self.answer_relevance.enabled + and not has_embedding + ): + raise ValueError( + "llm config including embedding is required if answer " + "relevance is enabled" + ) return self @classmethod - def parse(cls, config_file_path: str | Path | None) -> "Config": + def parse(cls, config_file_path: str | Path | None) -> Self: if config_file_path: with open(config_file_path, encoding="utf-8") as f: config_dict = yaml.safe_load(f) @@ -36,12 +78,10 @@ async def run_evaluation( responses_dict: dict, config_file_path: str | Path | None = None, ) -> list[dict]: + evaluators, ragas_llm = parse_config_and_init_evaluators(config_file_path) + # Output metrics are not nested, for simpler aggregation evaluation_results = [] - config = Config.parse(config_file_path) - ragas_llm = create_llm(config) - ragas_embedder = create_embedder(config) - custom_evaluators = custom_evaluation.create_evaluators(config) for template in qa_dataset: template_id = template["template_id"] for question in template["questions"]: @@ -51,6 +91,12 @@ async def run_evaluation( "question_id": actual_result["question_id"], "question_text": question["question_text"] } + for key in ("input_tokens", "output_tokens", "total_tokens", + "elapsed_sec"): + if key in actual_result: + eval_result[key] = actual_result[key] + if "actual_answer" in actual_result: + eval_result["actual_answer"] = actual_result["actual_answer"] if "reference_answer" in question: eval_result["reference_answer"] = question["reference_answer"] if "reference_steps" in question: @@ -63,42 +109,46 @@ async def run_evaluation( else: eval_result["status"] = "success" - if "actual_answer" in actual_result: - eval_result["actual_answer"] = actual_result["actual_answer"] - if ragas_llm: - from graphrag_eval.answer_relevance import Evaluator - relevance_evaluator = Evaluator(ragas_llm, ragas_embedder) - eval_result.update( - await relevance_evaluator.get_relevance_dict( - question["question_text"], - actual_result["actual_answer"], - ) - ) - if 
"reference_answer" in question and ragas_llm: - from graphrag_eval.answer_correctness import AnswerCorrectnessEvaluator - answer_correctness_evaluator = AnswerCorrectnessEvaluator( - llm=ragas_llm, - config=config.answer_correctness, - ) - eval_result.update( - await answer_correctness_evaluator.get_correctness_dict( - question, - actual_result, - ) - ) eval_result.update( - await evaluate_steps( - question, - actual_result, - ragas_llm, - ) + await evaluate_steps(question, actual_result, ragas_llm) ) - for custom_evaluator in custom_evaluators: - custom_metrics = await custom_evaluator.evaluate(question, actual_result) - eval_result.update(**custom_metrics) - for key in "input_tokens", "output_tokens", "total_tokens", "elapsed_sec": - if key in actual_result: - eval_result[key] = actual_result[key] + for evaluator in evaluators: + eval_result.update( + await evaluator.evaluate(question, actual_result) + ) evaluation_results.append(eval_result) return evaluation_results + + +def parse_config_and_init_evaluators( + config_file_path: str | Path | None +) -> tuple[ + list[Evaluator], + InstructorBaseRagasLLM | None, +]: + config = Config.parse(config_file_path) + ragas_llm: InstructorBaseRagasLLM | None = create_llm(config.llm) + ragas_embedder: BaseRagasEmbeddings | BaseRagasEmbedding | None = ( + create_embedder(config.llm) + ) + + evaluators: list[Evaluator] = [] + + answer_relevance_evaluator = AnswerRelevanceEvaluator.from_config( + ragas_llm, ragas_embedder, config.answer_relevance + ) + if answer_relevance_evaluator: + evaluators.append(answer_relevance_evaluator) + + answer_correctness_evaluator = AnswerCorrectnessEvaluator.from_config( + ragas_llm, config.answer_correctness + ) + if answer_correctness_evaluator: + evaluators.append(answer_correctness_evaluator) + + evaluators.extend( + CustomEvaluator.from_config(ragas_llm, config.custom_evaluations) + ) + + return evaluators, ragas_llm diff --git a/graphrag_eval/evaluator.py b/graphrag_eval/evaluator.py new 
file mode 100644 index 00000000..fcf2720d --- /dev/null +++ b/graphrag_eval/evaluator.py @@ -0,0 +1,14 @@ +from typing import Protocol, Any + + +class Evaluator(Protocol): + async def evaluate( + self, + reference: dict[str, Any], + actual: dict[str, Any] + ) -> dict[str, Any]: + """ + Evaluate the actual output against the reference. + Returns a flat dictionary containing scores or error tracking logs. + """ + ... diff --git a/graphrag_eval/llm_factory.py b/graphrag_eval/llm_factory.py index 87b0d803..345791f7 100644 --- a/graphrag_eval/llm_factory.py +++ b/graphrag_eval/llm_factory.py @@ -1,7 +1,13 @@ -from typing import Optional +from __future__ import annotations + +from typing import TYPE_CHECKING from pydantic import BaseModel, ConfigDict, Field +if TYPE_CHECKING: + from ragas.llms.base import InstructorBaseRagasLLM + from ragas.embeddings.base import BaseRagasEmbeddings, BaseRagasEmbedding + class GenerationConfig(BaseModel): provider: str @@ -17,18 +23,20 @@ class EmbeddingConfig(BaseModel): model_config = ConfigDict(extra='allow') -class Config(BaseModel): +class LLMConfig(BaseModel): generation: GenerationConfig embedding: EmbeddingConfig | None = None -def create_llm(config: "evaluation.Config") -> Optional["InstructorBaseRagasLLM"]: - if config.llm: +def create_llm( + config: LLMConfig | None +) -> InstructorBaseRagasLLM | None: + if config: import litellm from ragas.llms import llm_factory litellm.drop_params = True # Remove unsupported params from requests - params = config.llm.generation.model_dump() + params = config.generation.model_dump() ragas_llm = llm_factory( provider="litellm", model=f"{params.pop('provider')}/{params.pop('model')}", @@ -40,13 +48,15 @@ def create_llm(config: "evaluation.Config") -> Optional["InstructorBaseRagasLLM" return None -def create_embedder(config: "evaluation.Config") -> Optional["BaseRagasEmbedding"]: - if config.llm and config.llm.embedding: +def create_embedder( + config: LLMConfig | None +) -> BaseRagasEmbeddings 
| BaseRagasEmbedding | None: + if config and config.embedding: import litellm from ragas.embeddings.base import embedding_factory litellm.drop_params = True # Remove unsupported params from requests - params = config.llm.embedding.model_dump() + params = config.embedding.model_dump() ragas_embedder = embedding_factory( provider="litellm", model=f"{params.pop('provider')}/{params.pop('model')}", diff --git a/graphrag_eval/steps/evaluation.py b/graphrag_eval/steps/evaluation.py index 53f0f54c..243b2413 100644 --- a/graphrag_eval/steps/evaluation.py +++ b/graphrag_eval/steps/evaluation.py @@ -1,13 +1,21 @@ +from __future__ import annotations + import json import logging from collections import defaultdict from collections.abc import Sequence -from typing import Any +from typing import Any, TYPE_CHECKING from .iri_discovery import do_iri_discovery_steps_equal from .retrieval_context_ids import recall_at_k from .sparql import compare_sparql_results -from .timeseries import do_retrieve_time_series_steps_equal, do_retrieve_data_points_steps_equal +from .timeseries import ( + do_retrieve_time_series_steps_equal, + do_retrieve_data_points_steps_equal, +) + +if TYPE_CHECKING: + from ragas.llms.base import InstructorBaseRagasLLM logger = logging.getLogger(__name__) @@ -140,7 +148,7 @@ def calculate_steps_score( async def evaluate_steps( reference: dict, actual: dict, - ragas_llm: "InstructorBaseRagasLLM", + ragas_llm: InstructorBaseRagasLLM | None, ) -> dict: eval_result = {} actual_steps = actual.get("actual_steps", []) diff --git a/poetry.lock b/poetry.lock index 18ce6a79..a1fd329e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. 
[[package]] name = "aiohappyeyeballs" @@ -1679,14 +1679,14 @@ xxhash = ">=3.5.0" [[package]] name = "langgraph-checkpoint" -version = "4.1.0" +version = "4.1.1" description = "Library with base interfaces for LangGraph checkpoint savers." optional = false python-versions = ">=3.10" groups = ["main", "llm"] files = [ - {file = "langgraph_checkpoint-4.1.0-py3-none-any.whl", hash = "sha256:8bc2a0466a20c38b865ce6671b42093fd5c041133f32351cae4222e0eeaf7fb5"}, - {file = "langgraph_checkpoint-4.1.0.tar.gz", hash = "sha256:e5bb304e30fc1363ac8fcb5f7dee5ca2185d77fe475b0d01de2c5f91324c2c21"}, + {file = "langgraph_checkpoint-4.1.1-py3-none-any.whl", hash = "sha256:25d29144b082827218e7bc3f1e9b0566a4bb007895cd6cc26f66a8428739f56e"}, + {file = "langgraph_checkpoint-4.1.1.tar.gz", hash = "sha256:6c2bdb530c91f91d7d9c1bd100925d0fc4f498d418c17f3587d1526279482a25"}, ] [package.dependencies] diff --git a/pyproject.toml b/pyproject.toml index c54a9ac1..ce15992f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "graphrag-eval" -version = "6.3.0" +version = "6.4.0" description = "For assessing question answering systems' final answers and intermediate steps, against a given set of questions, reference answers and steps." 
authors = [ { name = "Philip Ganchev", email = "philip.ganchev@graphwise.ai" }, diff --git a/system-tests/test_answer_correctness_azure.py b/system-tests/test_answer_correctness_azure.py index 855650ba..6913d6bc 100644 --- a/system-tests/test_answer_correctness_azure.py +++ b/system-tests/test_answer_correctness_azure.py @@ -12,7 +12,7 @@ async def test_answer_correctness(): with open(config_path, encoding="utf-8") as f: config_dict = yaml.safe_load(f) config = Config(**config_dict) - ragas_llm = llm_factory.create_llm(config) + ragas_llm = llm_factory.create_llm(config.llm) reference = { "template_id": "geography", @@ -25,9 +25,9 @@ async def test_answer_correctness(): "actual_answer": "The capital of Bulgaria is Sofia" } - evaluator = AnswerCorrectnessEvaluator(llm=ragas_llm) - result = await evaluator.get_correctness_dict(reference, actual) - assert "answer_eval_error" not in result + evaluator = AnswerCorrectnessEvaluator(ragas_llm=ragas_llm) + result = await evaluator.evaluate(reference, actual) + assert "answer_correctness_error" not in result assert isinstance(result["answer_recall"], float) assert isinstance(result["answer_precision"], float) assert isinstance(result["answer_f1"], float) diff --git a/system-tests/test_answer_correctness_openai.py b/system-tests/test_answer_correctness_openai.py index 2999cf72..05b18193 100644 --- a/system-tests/test_answer_correctness_openai.py +++ b/system-tests/test_answer_correctness_openai.py @@ -12,7 +12,7 @@ async def test_answer_correctness(): with open(config_path, encoding="utf-8") as f: config_dict = yaml.safe_load(f) config = Config(**config_dict) - ragas_llm = llm_factory.create_llm(config) + ragas_llm = llm_factory.create_llm(config.llm) reference = { "template_id": "geography", @@ -25,8 +25,8 @@ async def test_answer_correctness(): "actual_answer": "The capital of Bulgaria is Sofia" } - evaluator = AnswerCorrectnessEvaluator(llm=ragas_llm) - result = await evaluator.get_correctness_dict(reference, actual) + 
evaluator = AnswerCorrectnessEvaluator(ragas_llm=ragas_llm) + result = await evaluator.evaluate(reference, actual) assert isinstance(result["answer_recall"], float) assert isinstance(result["answer_precision"], float) assert isinstance(result["answer_f1"], float) diff --git a/system-tests/test_answer_relevance_azure.py b/system-tests/test_answer_relevance_azure.py index 53a5ba19..67f36a2e 100644 --- a/system-tests/test_answer_relevance_azure.py +++ b/system-tests/test_answer_relevance_azure.py @@ -4,7 +4,7 @@ import yaml from graphrag_eval import llm_factory -from graphrag_eval.answer_relevance import Evaluator +from graphrag_eval.answer_relevance import AnswerRelevanceEvaluator from graphrag_eval.evaluation import Config @@ -17,13 +17,13 @@ async def test_answer_relevance(): config_dict["llm"]["generation"]["api_key"] = os.getenv("AZURE_OPENAI_GENERATION_KEY") config_dict["llm"]["embedding"]["api_key"] = os.getenv("AZURE_OPENAI_EMBEDDING_KEY") config = Config(**config_dict) - ragas_llm = llm_factory.create_llm(config) - ragas_embedder = llm_factory.create_embedder(config) + ragas_llm = llm_factory.create_llm(config.llm) + ragas_embedder = llm_factory.create_embedder(config.llm) - evaluator = Evaluator(ragas_llm, ragas_embedder) - result = await evaluator.get_relevance_dict( - question_text="Why is the sky blue?", - actual_answer="Oxygen makes it blue", + evaluator = AnswerRelevanceEvaluator(ragas_llm, ragas_embedder) + result = await evaluator.evaluate( + {"question_text": "Why is the sky blue?"}, + {"actual_answer": "Oxygen makes it blue"}, ) assert "answer_relevance_error" not in result assert "answer_relevance" in result diff --git a/system-tests/test_answer_relevance_openai.py b/system-tests/test_answer_relevance_openai.py index be002903..54ffbf7c 100644 --- a/system-tests/test_answer_relevance_openai.py +++ b/system-tests/test_answer_relevance_openai.py @@ -2,7 +2,7 @@ import yaml from graphrag_eval import llm_factory -from graphrag_eval.answer_relevance 
import Evaluator +from graphrag_eval.answer_relevance import AnswerRelevanceEvaluator from graphrag_eval.evaluation import Config @@ -12,13 +12,13 @@ async def test_answer_relevance(): with open(path, encoding="utf-8") as f: config_dict = yaml.safe_load(f) config = Config(**config_dict) - ragas_llm = llm_factory.create_llm(config) - ragas_embedder = llm_factory.create_embedder(config) + ragas_llm = llm_factory.create_llm(config.llm) + ragas_embedder = llm_factory.create_embedder(config.llm) - evaluator = Evaluator(ragas_llm, ragas_embedder) - result = await evaluator.get_relevance_dict( - question_text="Why is the sky blue?", - actual_answer="Oxygen makes it blue", + evaluator = AnswerRelevanceEvaluator(ragas_llm, ragas_embedder) + result = await evaluator.evaluate( + {"question_text": "Why is the sky blue?"}, + {"actual_answer": "Oxygen makes it blue"}, ) assert isinstance(result["answer_relevance"], float) assert 0 <= result["answer_relevance"] <= 1 diff --git a/system-tests/test_custom_evaluation.py b/system-tests/test_custom_evaluation.py index 2dc61f7b..ef6f0892 100644 --- a/system-tests/test_custom_evaluation.py +++ b/system-tests/test_custom_evaluation.py @@ -1,5 +1,3 @@ -from pathlib import Path - import pytest from graphrag_eval.evaluation import run_evaluation @@ -26,7 +24,7 @@ async def test_custom_evaluation(): "actual_answer": "Nein" } }, - Path("system-tests/config/legal-config.yaml"), + "system-tests/config/legal-config.yaml", ) assert len(evaluation_results) == 1 assert "legal_recall" in evaluation_results[0] diff --git a/system-tests/test_retrieval_answer_openai.py b/system-tests/test_retrieval_answer_openai.py index 186e79ce..906de928 100644 --- a/system-tests/test_retrieval_answer_openai.py +++ b/system-tests/test_retrieval_answer_openai.py @@ -12,7 +12,7 @@ async def test_retrieval_answer(): with open(path, encoding="utf-8") as f: config_dict = yaml.safe_load(f) config = Config(**config_dict) - ragas_llm = llm_factory.create_llm(config) + 
ragas_llm = llm_factory.create_llm(config.llm) evaluator = Evaluator(ragas_llm) result = await evaluator.get_retrieval_evaluation_dict( diff --git a/system-tests/test_retrieval_context_texts_openai.py b/system-tests/test_retrieval_context_texts_openai.py index ccbbdec5..6cd3c4ab 100644 --- a/system-tests/test_retrieval_context_texts_openai.py +++ b/system-tests/test_retrieval_context_texts_openai.py @@ -12,7 +12,7 @@ async def test_retrieval_contexts(): with open(path, encoding="utf-8") as f: config_dict = yaml.safe_load(f) config = Config(**config_dict) - ragas_llm = llm_factory.create_llm(config) + ragas_llm = llm_factory.create_llm(config.llm) evaluator = Evaluator(ragas_llm) result = await evaluator.get_retrieval_evaluation_dict( diff --git a/tests-with-llm/__init__.py b/tests-with-llm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests-with-llm/cli/test_answer_correctness.py b/tests-with-llm/cli/test_answer_correctness.py index 0cd88c0a..d19bf423 100644 --- a/tests-with-llm/cli/test_answer_correctness.py +++ b/tests-with-llm/cli/test_answer_correctness.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, AsyncMock, patch import pytest +import yaml from graphrag_eval.answer_correctness import InvalidPromptException from graphrag_eval.cli.answer_correctness import evaluate_and_write, run @@ -79,7 +80,7 @@ async def test_evaluate_and_write_success(tmp_path): "Gases scatter sunlight" ) - with open(output_file, "r", encoding="utf-8") as f: + with open(output_file, encoding="utf-8") as f: reader = csv.reader(f, delimiter="\t") output_data = list(reader) @@ -120,7 +121,7 @@ async def test_evaluate_and_write_wrong_input_format(tmp_path): mock_evaluator.evaluate_answer.assert_not_called() - with open(output_file, "r", encoding="utf-8") as f: + with open(output_file, encoding="utf-8") as f: reader = csv.reader(f, delimiter="\t") output_data = list(reader) @@ -139,9 +140,13 @@ def test_run_with_custom_prompt(mock_create_llm, tmp_path): 
{reference_answer} {actual_answer} """ + config_data = { + "llm": {"generation": {"provider": "openai", "model": "gpt-4o-mini"}}, + "answer_correctness": {"prompt": custom_prompt} + } + with open(config_path, "w", encoding="utf-8") as f: - f.write(f"answer_correctness:\n") - f.write(f" prompt: \"{custom_prompt}\"\n") + yaml.dump(config_data, f) with open(input_path, "w", encoding="utf-8", newline="") as f: writer = csv.writer(f, delimiter="\t") @@ -165,11 +170,14 @@ def test_run_with_custom_prompt(mock_create_llm, tmp_path): mock_create_llm.assert_called_once() called_prompt = mock_llm.agenerate.call_args[0][0] - assert ("""You are an expert evaluator assessing factual criteria... What is 1+1? 2 3 """ == - called_prompt) + assert ("""You are an expert evaluator assessing factual criteria... +What is 1+1? +2 +3 +""" == called_prompt) assert output_path.exists() - with open(output_path, "r", encoding="utf-8") as f: + with open(output_path, encoding="utf-8") as f: reader = csv.reader(f, delimiter="\t") output_data = list(reader) @@ -184,10 +192,13 @@ def test_run_with_invalid_prompt(mock_create_llm, tmp_path): output_path = tmp_path / "output.tsv" custom_prompt = "You are an expert evaluator assessing factual criteria...\\n{unexpected}\\n" + config_data = { + "llm": {"generation": {"provider": "openai", "model": "gpt-4o-mini"}}, + "answer_correctness": {"prompt": custom_prompt} + } with open(config_path, "w", encoding="utf-8") as f: - f.write("answer_correctness:\n") - f.write(f" prompt: \"{custom_prompt}\"\n") + yaml.dump(config_data, f) with open(input_path, "w", encoding="utf-8", newline="") as f: writer = csv.writer(f, delimiter="\t") diff --git a/tests-with-llm/test_answer_correctness.py b/tests-with-llm/test_answer_correctness.py index d13df9a5..3c1d72ee 100644 --- a/tests-with-llm/test_answer_correctness.py +++ b/tests-with-llm/test_answer_correctness.py @@ -10,7 +10,7 @@ def test_extract_response_values_expected_case(): response = "2\t3\t1\treason" - result = 
AnswerCorrectnessEvaluator(llm=MagicMock()).extract_response_values(response) + result = AnswerCorrectnessEvaluator(ragas_llm=MagicMock()).extract_response_values(response) assert result == (2, 3, 1, "reason") @@ -37,7 +37,7 @@ async def test_evaluate_answer_empty_strings( ): with raises(ValueError, match="The question of the reference or the actual answer is a blank " "string!"): - await AnswerCorrectnessEvaluator(llm=MagicMock()).evaluate_answer( + await AnswerCorrectnessEvaluator(ragas_llm=MagicMock()).evaluate_answer( question, reference_answer, actual_answer ) @@ -57,7 +57,7 @@ def test_extract_response_values_invalid_values(n_ref: int, n_actual: int, n_mat response = f"{n_ref}\t{n_actual}\t{n_matching}\treason" with raises(ValueError, match=f"Invalid claims counts combination: {n_ref}\t{n_actual}\t{n_matching}"): - AnswerCorrectnessEvaluator(llm=MagicMock()).extract_response_values(response) + AnswerCorrectnessEvaluator(ragas_llm=MagicMock()).extract_response_values(response) @pytest.mark.parametrize( @@ -79,16 +79,16 @@ def test_extract_response_values_non_int(n_ref: Any, n_actual: Any, n_matching: f"Claims counts should be ints: ['{n_ref}', '{n_actual}', '{n_matching}', " f"'reason']" )): - AnswerCorrectnessEvaluator(llm=MagicMock()).extract_response_values(response) + AnswerCorrectnessEvaluator(ragas_llm=MagicMock()).extract_response_values(response) def test_extract_response_values_too_few_values(): response = "2\t2\treason" with raises(ValueError, match=f"Expected 4 tab-separated values: {response}"): - AnswerCorrectnessEvaluator(llm=MagicMock()).extract_response_values(response) + AnswerCorrectnessEvaluator(ragas_llm=MagicMock()).extract_response_values(response) def test_extract_response_values_too_many_values(): response = "2\t2\t2\treason\textra" - result = AnswerCorrectnessEvaluator(llm=MagicMock()).extract_response_values(response) + result = AnswerCorrectnessEvaluator(ragas_llm=MagicMock()).extract_response_values(response) assert result == (2, 2, 
2, "reason") diff --git a/tests-with-llm/test_answer_relevance.py b/tests-with-llm/test_answer_relevance.py index 4b748d78..299acd7f 100644 --- a/tests-with-llm/test_answer_relevance.py +++ b/tests-with-llm/test_answer_relevance.py @@ -1,9 +1,10 @@ import os from unittest.mock import AsyncMock, MagicMock +import pytest from ragas.llms.base import InstructorBaseRagasLLM -import pytest +from graphrag_eval.answer_relevance import AnswerRelevanceEvaluator def get_ragas_llm() -> InstructorBaseRagasLLM: @@ -15,8 +16,8 @@ def get_ragas_llm() -> InstructorBaseRagasLLM: def get_ragas_embedder(): from openai import AsyncOpenAI - from ragas.embeddings.base import embedding_factory - + from ragas.embeddings.base import embedding_factory + return embedding_factory("openai", client=AsyncOpenAI()) @@ -26,15 +27,15 @@ def set_env(): @pytest.mark.asyncio -async def test_get_relevance_dict_eval_success(monkeypatch): - from graphrag_eval.answer_relevance import AnswerRelevancy, Evaluator - - relevance_mock = AsyncMock(return_value=MagicMock(value=0.9)) - monkeypatch.setattr(AnswerRelevancy, 'ascore', relevance_mock) - evaluator = Evaluator(get_ragas_llm(), get_ragas_embedder()) - eval_result_dict = await evaluator.get_relevance_dict( - "Why is the sky blue?", - "Because of the oxygen in the air", +async def test_evaluate_answer_relevance_success(monkeypatch): + async_mock = AsyncMock(return_value=MagicMock(value=0.9)) + from ragas.metrics.collections import AnswerRelevancy + monkeypatch.setattr(AnswerRelevancy, "ascore", async_mock) + + evaluator = AnswerRelevanceEvaluator(get_ragas_llm(), get_ragas_embedder()) + eval_result_dict = await evaluator.evaluate( + {"question_text": "Why is the sky blue?"}, + {"actual_answer": "Because of the oxygen in the air"}, ) assert eval_result_dict == { "answer_relevance": 0.9 @@ -42,14 +43,15 @@ async def test_get_relevance_dict_eval_success(monkeypatch): @pytest.mark.asyncio -async def test_get_relevance_dict_eval_error(monkeypatch): - from 
graphrag_eval.answer_relevance import AnswerRelevancy, Evaluator - relevance_mock = AsyncMock(side_effect=Exception("some error")) - monkeypatch.setattr(AnswerRelevancy, 'ascore', relevance_mock) - evaluator = Evaluator(get_ragas_llm(), get_ragas_embedder()) - eval_result_dict = await evaluator.get_relevance_dict( - "Why is the sky blue?", - "Because of the oxygen in the air", +async def test_evaluate_answer_relevance_error(monkeypatch): + async_mock = AsyncMock(side_effect=Exception("some error")) + from ragas.metrics.collections import AnswerRelevancy + monkeypatch.setattr(AnswerRelevancy, 'ascore', async_mock) + + evaluator = AnswerRelevanceEvaluator(get_ragas_llm(), get_ragas_embedder()) + eval_result_dict = await evaluator.evaluate( + {"question_text": "Why is the sky blue?"}, + {"actual_answer": "Because of the oxygen in the air"}, ) assert eval_result_dict == { "answer_relevance_error": "some error" diff --git a/tests-with-llm/test_custom_evaluation.py b/tests-with-llm/test_custom_evaluation.py index c349c890..f1e57778 100644 --- a/tests-with-llm/test_custom_evaluation.py +++ b/tests-with-llm/test_custom_evaluation.py @@ -1,16 +1,18 @@ -import os from copy import deepcopy from pathlib import Path from unittest.mock import AsyncMock, MagicMock import pytest import yaml +from ragas.embeddings.base import BaseRagasEmbedding +from ragas.llms.base import InstructorBaseRagasLLM from graphrag_eval import ( compute_aggregates, run_evaluation, ) from graphrag_eval.answer_correctness import AnswerCorrectnessEvaluator +from graphrag_eval.answer_relevance import AnswerRelevanceEvaluator from graphrag_eval.custom_evaluation import CustomEvaluator from tests.util import read_responses @@ -29,24 +31,47 @@ async def mock_agenerate_correctness(self, prompt): ) -@pytest.fixture(scope="session", autouse=True) -def set_env(): - os.environ["OPENAI_API_KEY"] = "fake-key" +def _mock_common_calls(monkeypatch): + async_mock = AsyncMock(return_value=MagicMock(value=0.9)) + from 
ragas.metrics.collections import AnswerRelevancy + monkeypatch.setattr(AnswerRelevancy, "ascore", async_mock) -def _mock_common_calls(monkeypatch): from graphrag_eval.steps.retrieval_answer import ( ContextRecall, ContextPrecision ) - from graphrag_eval.answer_relevance import AnswerRelevancy - - mock = AsyncMock(return_value=MagicMock(value=0.9)) - monkeypatch.setattr(AnswerRelevancy, "ascore", mock) - monkeypatch.setattr(ContextRecall, "ascore", mock) - monkeypatch.setattr(ContextPrecision, "ascore", mock) + monkeypatch.setattr(ContextRecall, "ascore", async_mock) + monkeypatch.setattr(ContextPrecision, "ascore", async_mock) mock_answer_correctness_evaluator(monkeypatch) + def mock_init_evaluators(config_path): + mock_llm = MagicMock(spec=InstructorBaseRagasLLM) + evaluators = [] + answer_relevance_evaluator = AnswerRelevanceEvaluator( + mock_llm, + MagicMock(spec=BaseRagasEmbedding) + ) + evaluators.append(answer_relevance_evaluator) + answer_correctness_evaluator = AnswerCorrectnessEvaluator( + ragas_llm=mock_llm + ) + evaluators.append(answer_correctness_evaluator) + config = evaluation.Config.parse(config_path) + custom_evaluators = CustomEvaluator.from_config( + mock_llm, config.custom_evaluations + ) + evaluators.extend(custom_evaluators) + + return evaluators, mock_llm + + from graphrag_eval import evaluation + monkeypatch.setattr( + evaluation, + "parse_config_and_init_evaluators", + mock_init_evaluators + ) + @pytest.mark.asyncio async def test_run_custom_evaluation_ok(monkeypatch): @@ -69,8 +94,8 @@ async def mock_agenerate(self, prompt): return "0.5\t0.67\tThere are 4 reference claims and 3 actual " \ "claims; 2 claims match" if i == 3: - return "0.75\t0.6\tThe reference answer has 4 claims; there are 5 " \ - "SPARQL results; 3 claims match" + return "0.75\t0.6\tThe reference answer has 4 claims; " \ + "there are 5 SPARQL results; 3 claims match" return "0.0\tDefault mock fallback response" monkeypatch.setattr(CustomEvaluator, "_agenerate", 
mock_agenerate) @@ -145,8 +170,8 @@ async def test_run_custom_evaluation_config_error(monkeypatch): file_path = DATA_DIR / f"evaluation_{i}.yaml" with open(file_path) as f: eval_dicts = yaml.safe_load(f) - for eval in eval_dicts: - reserved_keys |= eval.keys() + for eval_dict in eval_dicts: + reserved_keys |= eval_dict.keys() for key in reserved_keys: error_config = deepcopy(correct_config) error_config["custom_evaluations"][0]["outputs"][key] = "invalid" diff --git a/tests-with-llm/test_evaluation.py b/tests-with-llm/test_evaluation.py index 06f74d20..763c0c6e 100644 --- a/tests-with-llm/test_evaluation.py +++ b/tests-with-llm/test_evaluation.py @@ -1,14 +1,17 @@ -import os from pathlib import Path from unittest.mock import AsyncMock, MagicMock import pytest import yaml +from ragas.embeddings.base import BaseRagasEmbedding +from ragas.llms.base import InstructorBaseRagasLLM from graphrag_eval import ( compute_aggregates, run_evaluation, ) +from graphrag_eval.answer_correctness import AnswerCorrectnessEvaluator +from graphrag_eval.answer_relevance import AnswerRelevanceEvaluator from tests.util import read_responses DATA_DIR = Path(__file__).parent / "test_data" @@ -16,8 +19,6 @@ def mock_answer_correctness_evaluator(monkeypatch): - from graphrag_eval.answer_correctness import AnswerCorrectnessEvaluator - async def mock_agenerate_correctness(self, prompt): return "2\t2\t2\tanswer correctness reason" @@ -28,26 +29,45 @@ async def mock_agenerate_correctness(self, prompt): ) -@pytest.fixture(scope="session", autouse=True) -def set_env(): - os.environ["OPENAI_API_KEY"] = "fake-key" - - @pytest.mark.asyncio async def test_run_evaluation_and_compute_aggregates(monkeypatch): - reference_data = yaml.safe_load( - (DATA_DIR / "reference_1.yaml").read_text(encoding="utf-8") - ) async_mock = AsyncMock(return_value=MagicMock(value=0.9)) - from graphrag_eval.answer_relevance import AnswerRelevancy + + from ragas.metrics.collections import AnswerRelevancy + 
monkeypatch.setattr(AnswerRelevancy, "ascore", async_mock) + from graphrag_eval.steps.retrieval_answer import ( ContextRecall, ContextPrecision ) - monkeypatch.setattr(AnswerRelevancy, "ascore", async_mock) monkeypatch.setattr(ContextRecall, "ascore", async_mock) monkeypatch.setattr(ContextPrecision, "ascore", async_mock) mock_answer_correctness_evaluator(monkeypatch) + + def mock_init_evaluators(_): + mock_llm = MagicMock(spec=InstructorBaseRagasLLM) + evaluators = [] + answer_relevance_evaluator = AnswerRelevanceEvaluator( + mock_llm, + MagicMock(spec=BaseRagasEmbedding) + ) + evaluators.append(answer_relevance_evaluator) + answer_correctness_evaluator = AnswerCorrectnessEvaluator( + ragas_llm=mock_llm + ) + evaluators.append(answer_correctness_evaluator) + return evaluators, mock_llm + + from graphrag_eval import evaluation + monkeypatch.setattr( + evaluation, + "parse_config_and_init_evaluators", + mock_init_evaluators + ) + + reference_data = yaml.safe_load( + (DATA_DIR / "reference_1.yaml").read_text(encoding="utf-8") + ) actual_responses = read_responses(DATA_DIR / "actual_responses_1.jsonl") evaluation_results = await run_evaluation( reference_data, @@ -70,13 +90,37 @@ async def test_run_evaluation_and_compute_aggregates(monkeypatch): async def test_run_evaluation_and_compute_aggregates_no_actual_steps( monkeypatch ): - reference_data = yaml.safe_load( - (DATA_DIR / "reference_1.yaml").read_text(encoding="utf-8") - ) async_mock = AsyncMock(return_value=MagicMock(value=0.9)) - from graphrag_eval.answer_relevance import AnswerRelevancy + + from ragas.metrics.collections import AnswerRelevancy monkeypatch.setattr(AnswerRelevancy, "ascore", async_mock) mock_answer_correctness_evaluator(monkeypatch) + + def mock_init_evaluators(_): + mock_llm = MagicMock(spec=InstructorBaseRagasLLM) + evaluators = [] + answer_relevance_evaluator = AnswerRelevanceEvaluator( + mock_llm, + MagicMock(spec=BaseRagasEmbedding) + ) + evaluators.append(answer_relevance_evaluator) + 
answer_correctness_evaluator = AnswerCorrectnessEvaluator( + ragas_llm=mock_llm + ) + evaluators.append(answer_correctness_evaluator) + + return evaluators, mock_llm + + from graphrag_eval import evaluation + monkeypatch.setattr( + evaluation, + "parse_config_and_init_evaluators", + mock_init_evaluators + ) + + reference_data = yaml.safe_load( + (DATA_DIR / "reference_1.yaml").read_text(encoding="utf-8") + ) actual_responses = read_responses(DATA_DIR / "actual_responses_3.jsonl") evaluation_results = await run_evaluation( reference_data, @@ -114,3 +158,100 @@ async def test_run_evaluation_and_compute_aggregates_all_errors(): ) aggregates = compute_aggregates(evaluation_results, CONFIG_FILE_PATH) assert expected_aggregates == aggregates + + +@pytest.mark.asyncio +async def test_answer_correctness_disabled(monkeypatch): + async_mock = AsyncMock(return_value=MagicMock(value=0.9)) + + from ragas.metrics.collections import AnswerRelevancy + monkeypatch.setattr(AnswerRelevancy, "ascore", async_mock) + + def mock_init_evaluators(_): + mock_llm = MagicMock(spec=InstructorBaseRagasLLM) + answer_relevance_evaluator = AnswerRelevanceEvaluator( + mock_llm, + MagicMock(spec=BaseRagasEmbedding) + ) + return [answer_relevance_evaluator], mock_llm + + from graphrag_eval import evaluation + monkeypatch.setattr( + evaluation, + "parse_config_and_init_evaluators", + mock_init_evaluators + ) + + reference_data = yaml.safe_load( + (DATA_DIR / "reference_1.yaml").read_text(encoding="utf-8") + ) + actual_responses = read_responses(DATA_DIR / "actual_responses_3.jsonl") + evaluation_results = await run_evaluation( + reference_data, + actual_responses, + CONFIG_FILE_PATH + ) + assert len(evaluation_results) > 0 + for res in evaluation_results: + assert "answer_relevance" in res + assert "answer_f1" not in res + + +@pytest.mark.asyncio +async def test_answer_relevance_disabled(monkeypatch): + mock_answer_correctness_evaluator(monkeypatch) + + def mock_init_evaluators(_): + mock_llm = 
MagicMock(spec=InstructorBaseRagasLLM) + answer_correctness_evaluator = AnswerCorrectnessEvaluator( + ragas_llm=mock_llm + ) + return [answer_correctness_evaluator], mock_llm + + from graphrag_eval import evaluation + monkeypatch.setattr( + evaluation, + "parse_config_and_init_evaluators", + mock_init_evaluators + ) + + reference_data = yaml.safe_load( + (DATA_DIR / "reference_1.yaml").read_text(encoding="utf-8") + ) + actual_responses = read_responses(DATA_DIR / "actual_responses_3.jsonl") + evaluation_results = await run_evaluation( + reference_data, + actual_responses, + CONFIG_FILE_PATH + ) + assert len(evaluation_results) > 0 + for res in evaluation_results: + assert "answer_f1" in res + assert "answer_relevance" not in res + + +@pytest.mark.asyncio +async def test_answer_correctness_and_answer_relevance_disabled(monkeypatch): + def mock_init_evaluators(_): + return [], MagicMock(spec=InstructorBaseRagasLLM) + + from graphrag_eval import evaluation + monkeypatch.setattr( + evaluation, + "parse_config_and_init_evaluators", + mock_init_evaluators + ) + + reference_data = yaml.safe_load( + (DATA_DIR / "reference_1.yaml").read_text(encoding="utf-8") + ) + actual_responses = read_responses(DATA_DIR / "actual_responses_3.jsonl") + evaluation_results = await run_evaluation( + reference_data, + actual_responses, + CONFIG_FILE_PATH + ) + assert len(evaluation_results) > 0 + for res in evaluation_results: + assert "answer_f1" not in res + assert "answer_relevance" not in res diff --git a/tests-with-llm/test_llm_factory.py b/tests-with-llm/test_llm_factory.py index 5675e864..74d812c0 100644 --- a/tests-with-llm/test_llm_factory.py +++ b/tests-with-llm/test_llm_factory.py @@ -1,6 +1,6 @@ from graphrag_eval import evaluation from graphrag_eval.llm_factory import ( - Config, + LLMConfig, GenerationConfig, EmbeddingConfig, create_llm, @@ -9,15 +9,15 @@ def test_create_llm_and_embeddings_no_llm_config(): - llm = create_llm(evaluation.Config()) - embedder = 
create_embedder(evaluation.Config()) + llm = create_llm(evaluation.Config().llm) + embedder = create_embedder(evaluation.Config().llm) assert llm is None assert embedder is None def test_create_llm_and_embeddings_llm_config_no_embedding_config(): config = evaluation.Config( - llm=Config( + llm=LLMConfig( generation=GenerationConfig( provider="openai", model="gpt-3.5-turbo", @@ -26,8 +26,8 @@ def test_create_llm_and_embeddings_llm_config_no_embedding_config(): ) ) ) - llm = create_llm(config) - embedder = create_embedder(config) + llm = create_llm(config.llm) + embedder = create_embedder(config.llm) assert llm is not None assert llm.model == "openai/gpt-3.5-turbo" assert embedder is None @@ -35,7 +35,7 @@ def test_create_llm_and_embeddings_llm_config_no_embedding_config(): def test_create_llm_and_embeddings_llm_config_embedding_config(): config = evaluation.Config( - llm=Config( + llm=LLMConfig( generation=GenerationConfig( provider="openai", model="gpt-3.5-turbo", @@ -48,8 +48,8 @@ def test_create_llm_and_embeddings_llm_config_embedding_config(): ), ), ) - llm = create_llm(config) - embedder = create_embedder(config) + llm = create_llm(config.llm) + embedder = create_embedder(config.llm) assert llm is not None assert llm.model == "openai/gpt-3.5-turbo" assert embedder is not None diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py index 79299ff7..59ec8c14 100644 --- a/tests/test_aggregation.py +++ b/tests/test_aggregation.py @@ -13,6 +13,6 @@ def test_compute_aggregates_doesnt_throw_exception(): is a DESCRIBE query containing the string "results" in the text. 
""" evaluation_results_file = DATA_DIR / f"evaluation_3.yaml" - with open(evaluation_results_file, "r", encoding="utf-8") as yaml_file: + with open(evaluation_results_file, encoding="utf-8") as yaml_file: per_question_eval = yaml.safe_load(yaml_file) compute_aggregates(per_question_eval) diff --git a/tests/test_llm_factory.py b/tests/test_llm_factory.py index 4e7f05f3..de04c07a 100644 --- a/tests/test_llm_factory.py +++ b/tests/test_llm_factory.py @@ -1,10 +1,14 @@ from pytest import raises -from graphrag_eval.llm_factory import Config, GenerationConfig, EmbeddingConfig +from graphrag_eval.llm_factory import ( + LLMConfig, + GenerationConfig, + EmbeddingConfig, +) def test_config_ok(): - c = Config( + c = LLMConfig( generation=GenerationConfig( provider="generation_provider", model="generation_model", @@ -25,7 +29,7 @@ def test_config_ok(): def test_config_optional(): - c = Config( + c = LLMConfig( generation=GenerationConfig( provider="generation_provider", model="generation_model", @@ -51,7 +55,7 @@ def test_config_optional(): def test_config_invalid_temperature(): with raises(ValueError): - Config( + LLMConfig( generation=GenerationConfig( provider="generation_provider", model="generation_model", @@ -64,7 +68,7 @@ def test_config_invalid_temperature(): ) ) with raises(ValueError): - Config( + LLMConfig( generation=GenerationConfig( provider="generation_provider", model="generation_model", @@ -80,7 +84,7 @@ def test_config_invalid_temperature(): def test_config_invalid_max_tokens(): with raises(ValueError): - Config( + LLMConfig( generation=GenerationConfig( provider="generation_provider", model="generation_model", @@ -94,7 +98,7 @@ def test_config_invalid_max_tokens(): ) with raises(ValueError): - Config( + LLMConfig( generation=GenerationConfig( provider="generation_provider", model="generation_model",