diff --git a/src/gimbench/arguments.py b/src/gimbench/arguments.py index 45b2823..40b3611 100644 --- a/src/gimbench/arguments.py +++ b/src/gimbench/arguments.py @@ -202,6 +202,7 @@ def _add_scierc_eval_args(parser): def _add_cv_eval_args(parser): parser.add_argument("--use_outlines", action="store_true", help="Whether to use outlines in CV evaluation") + parser.add_argument("--use_gliner2", action="store_true", help="Whether to use GLiNER2 in CV evaluation") parser.add_argument( "--judge_model_name", type=str, diff --git a/src/gimbench/cv/evaluators.py b/src/gimbench/cv/evaluators.py index b169c93..2f2d76d 100644 --- a/src/gimbench/cv/evaluators.py +++ b/src/gimbench/cv/evaluators.py @@ -17,7 +17,14 @@ from gimbench.log import get_logger from gimbench.models import SimpleGIM -from .schema import CV_FIELDS, GIMKIT_TEMPLATE, OUTLINES_JSON_SCHEMA, OUTLINES_TEMPLATE, SHARED_PROMPT_PREFIX +from .schema import ( + CV_FIELDS, + GIMKIT_TEMPLATE, + GLINER_SCHEMA, + OUTLINES_JSON_SCHEMA, + OUTLINES_TEMPLATE, + SHARED_PROMPT_PREFIX, +) logger = get_logger(__name__) @@ -236,7 +243,39 @@ def _extract_fields(self, cv_content: str) -> dict[str, str]: raise ValueError(f"Expected dict but got {type(extraction).__name__}: {extraction}") +class GLiNEREvaluator(CVEvaluator): + def __init__(self, args: Namespace, dataset: Dataset): + super().__init__(args, dataset) + try: + from gliner2 import GLiNER2 + except ImportError: + raise ImportError( + "The 'gliner2' package is required but not installed. " + "Please install it manually using `pip install gliner2` or `uv add gliner2` " + "to evaluate using this model." + ) + self.model = GLiNER2.from_pretrained(args.model_name) + + def _extract_fields(self, cv_content: str) -> dict[str, str]: + # GLiNER2 has a length limit, let's truncate just in case, or pass directly + result = self.model.extract_json(cv_content, GLINER_SCHEMA) + + extraction = {} + if "cv" in result and isinstance(result["cv"], list) and len(result["cv"]) > 0: + extracted_item = result["cv"][0] + if isinstance(extracted_item, dict): + for field in CV_FIELDS: + val = extracted_item.get(field, "") + extraction[field] = str(val) if val is not None else "" + return extraction + + def conduct_eval(args: Namespace, ds: Dataset): - evaluator = OutlinesEvaluator(args, ds) if args.use_outlines else GIMEvaluator(args, ds) + if args.use_outlines: + evaluator = OutlinesEvaluator(args, ds) + elif getattr(args, "use_gliner2", False): + evaluator = GLiNEREvaluator(args, ds) + else: + evaluator = GIMEvaluator(args, ds) result = evaluator.evaluate() result.dump() diff --git a/src/gimbench/cv/schema.py b/src/gimbench/cv/schema.py index 353172f..40d7d36 100644 --- a/src/gimbench/cv/schema.py +++ b/src/gimbench/cv/schema.py @@ -90,3 +90,21 @@ class CVData(BaseModel): OUTLINES_JSON_SCHEMA = CVData.model_json_schema() + +GLINER_SCHEMA = { + "cv": [ + "name::str::Full name of the person", + "country::str::Country, nationality, or country of residence", + "birthday::str::Date of birth", + "phone_number::str::Phone number", + "email::str::Email address", + "highest_level_degree::[Bachelor|Master|PhD]::str::Highest educational degree", + "university::str::University name", + "department::str::Department or school", + "major::str::Major or field of study", + "start_date::str::Start date of education", + "end_date::str::End date of education", + "homepage_url::str::Personal homepage URL", + "github_url::str::GitHub profile URL", + ] +}