Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pytei-client"
version = "0.1.0"
version = "0.1.1"
authors = [
{ name="Daniel Gomm", email="daniel.gomm@cwi.nl" },
]
Expand Down
50 changes: 47 additions & 3 deletions src/pytei/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,56 @@ def embed(self, inputs: Union[str, List[str]], normalize: bool = True, prompt_na
else:
raise AttributeError("text_input must be either a string or a list of strings.")

@staticmethod
def _build_rerank_request_body(query: str, texts: List[str], raw_scores: bool = False, return_text: bool = False,
truncate: bool = False, truncation_direction: Optional[Literal['left', 'right']] = None):
body = {
"query": query,
"texts": texts,
"raw_scores": raw_scores,
"return_text": return_text
}
if truncate is True:
body["truncate"] = True
if truncation_direction is None:
truncation_direction = "right"
body["truncation_direction"] = truncation_direction
return body

def _fetch_reranking_result(self, request_body: Dict) -> List[Rank]:
try:
response = requests.post(f"{self._endpoint}/rerank", json=request_body, headers={"Content-Type": "application/json"},
timeout=self._timeout)
response.raise_for_status()
raw_reranking_results = json.loads(response.text)
return [Rank(raw_result["index"], raw_result["score"], raw_result.get("text", None)) for raw_result in raw_reranking_results]
except (requests.RequestException, json.JSONDecodeError, IndexError, ValueError) as e:
raise RuntimeError(f"Failed to rerank texts: {e}")

def rerank(self, query: str, texts: List[str], raw_score: bool = False, return_text: bool = False,
truncate: bool = False, truncation_direction: Union[Literal['left', 'right'], None] = None) -> List[Rank]:
def rerank(self, query: str, texts: List[str], raw_scores: bool = False, return_text: bool = False,
truncate: bool = False, truncation_direction: Optional[Literal['left', 'right']] = None) -> List[Rank]:
raise NotImplementedError("Reranking is not yet implemented.")
"""
Re-rank retrieval results.

:param query: Query text against which to re-rank the texts.
:type query: str
:param texts: Texts that should be re-ranked.
:type texts: List[str]
:param raw_scores: Whether to return raw scores. Default is False.
:type raw_scores: bool, optional
:param return_text: Whether to include the raw text in the ranking results. Default is False.
:type return_text: bool, optional
:param truncate: Whether to apply truncation. Default is False.
:type truncate: bool, optional
:param truncation_direction: Direction in which truncation is applied. Default is 'right'.
:type truncation_direction: str, optional
:return: A re-ranked list of the supplied texts.
:rtype: List[`Rank`]
"""
body = self._build_rerank_request_body(query=query, texts=texts, raw_scores=raw_scores, return_text=return_text,
truncate=truncate,
truncation_direction=truncation_direction)
return self._fetch_reranking_result(body)

def predict(self, inputs: Union[str, Tuple[str, str], List[Union[str, Tuple[str, str]]]], raw_scores: bool = False,
truncate: bool = False, truncation_direction: Optional[Literal['left', 'right']] = None) -> Union[PredictionResult, List[PredictionResult]]:
Expand Down
Loading