From 871d87df14bc6d1897b3d0b50c2634d04f3dd685 Mon Sep 17 00:00:00 2001 From: Daniel Gomm Date: Tue, 21 Jan 2025 10:48:25 +0100 Subject: [PATCH] initial reranking implementation --- pyproject.toml | 2 +- src/pytei/client.py | 50 ++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ab6f9d5..9e205fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pytei-client" -version = "0.1.0" +version = "0.1.1" authors = [ { name="Daniel Gomm", email="daniel.gomm@cwi.nl" }, ] diff --git a/src/pytei/client.py b/src/pytei/client.py index 0679f7a..c4b6322 100644 --- a/src/pytei/client.py +++ b/src/pytei/client.py @@ -137,12 +137,56 @@ def embed(self, inputs: Union[str, List[str]], normalize: bool = True, prompt_na else: raise AttributeError("text_input must be either a string or a list of strings.") + @staticmethod + def _build_rerank_request_body(query: str, texts: List[str], raw_scores: bool = False, return_text: bool = False, + truncate: bool = False, truncation_direction: Optional[Literal['left', 'right']] = None): + body = { + "query": query, + "texts": texts, + "raw_scores": raw_scores, + "return_text": return_text + } + if truncate is True: + body["truncate"] = True + if truncation_direction is None: + truncation_direction = "right" + body["truncation_direction"] = truncation_direction + return body + + def _fetch_reranking_result(self, request_body: Dict) -> List[Rank]: + try: + response = requests.post(f"{self._endpoint}/rerank", json=request_body, headers={"Content-Type": "application/json"}, + timeout=self._timeout) + response.raise_for_status() + raw_reranking_results = json.loads(response.text) + return [Rank(raw_result["index"], raw_result["score"], raw_result.get("text", None)) for raw_result in raw_reranking_results] + except (requests.RequestException, json.JSONDecodeError, IndexError, ValueError) as e: + raise RuntimeError(f"Failed to rerank texts: {e}") - def rerank(self, query: str, texts: List[str], raw_score: bool = False, return_text: bool = False, - truncate: bool = False, truncation_direction: Union[Literal['left', 'right'], None] = None) -> List[Rank]: def rerank(self, query: str, texts: List[str], raw_scores: bool = False, return_text: bool = False, truncate: bool = False, truncation_direction: Optional[Literal['left', 'right']] = None) -> List[Rank]: - raise NotImplementedError("Reranking is not yet implemented.") + """ + Re-rank retrieval results. + + :param query: Query text against which to re-rank the texts. + :type query: str + :param texts: Texts that should be re-ranked. + :type texts: List[str] + :param raw_scores: Whether to return raw scores. Default is False. + :type raw_scores: bool, optional + :param return_text: Whether to include the raw text in the ranking results. Default is False. + :type return_text: bool, optional + :param truncate: Whether to apply truncation. Default is False. + :type truncate: bool, optional + :param truncation_direction: Direction in which truncation is applied. Default is 'right'. + :type truncation_direction: str, optional + :return: A re-ranked list of the supplied texts. + :rtype: List[`Rank`] + """ + body = self._build_rerank_request_body(query=query, texts=texts, raw_scores=raw_scores, return_text=return_text, + truncate=truncate, + truncation_direction=truncation_direction) + return self._fetch_reranking_result(body) def predict(self, inputs: Union[str, Tuple[str, str], List[Union[str, Tuple[str, str]]]], raw_scores: bool = False, truncate: bool = False, truncation_direction: Optional[Literal['left', 'right']] = None) -> Union[PredictionResult, List[PredictionResult]]: