diff --git a/pose_evaluation/metrics/base.py b/pose_evaluation/metrics/base.py index 75f8803..692f04b 100644 --- a/pose_evaluation/metrics/base.py +++ b/pose_evaluation/metrics/base.py @@ -1,11 +1,70 @@ # pylint: disable=undefined-variable +from typing import Any, Callable from tqdm import tqdm +class Signature: + """Represents reproducibility signatures for metrics. Inspired by sacreBLEU + """ + def __init__(self, name:str, args: dict): + + self._abbreviated = { + "name":"n", + "higher_is_better":"hb" + } + + self.signature_info = {"name": name, **args} + + def update(self, key: str, value: Any): + self.signature_info[key] = value + + def update_signature_and_abbr(self, key:str, abbr:str, args:dict): + self._abbreviated.update({ + key: abbr + }) + + self.signature_info.update({ + key: args.get(key, None) + }) + + def format(self, short: bool = False) -> str: + pairs = [] + keys = list(self.signature_info.keys()) + for name in keys: + value = self.signature_info[name] + if value is not None: + # Check for nested signature objects + if hasattr(value, "get_signature"): + # Wrap nested signatures in brackets + nested_signature = value.get_signature() + if isinstance(nested_signature, Signature): + nested_signature = nested_signature.format(short=short) + value = f"{{{nested_signature}}}" + if isinstance(value, bool): + # Replace True/False with yes/no + value = "yes" if value else "no" + if isinstance(value, Callable): + value = value.__name__ + + # if the abbreviation is not defined, use the full name as a fallback. + abbreviated_name = self._abbreviated.get(name, name) + final_name = abbreviated_name if short else name + pairs.append(f"{final_name}:{value}") + + return "|".join(pairs) + + def __str__(self): + return self.format() + + def __repr__(self): + return self.format() + class BaseMetric[T]: """Base class for all metrics.""" + # Each metric should define its Signature class' name here + _SIGNATURE_TYPE = Signature - def __init__(self, name: str, higher_is_better: bool = True): + def __init__(self, name: str, higher_is_better: bool = False): self.name = name self.higher_is_better = higher_is_better @@ -38,3 +97,6 @@ def score_all(self, hypotheses: list[T], references: list[T], progress_bar=True) def __str__(self): return self.name + + def get_signature(self) -> Signature: + return self._SIGNATURE_TYPE(self.name, self.__dict__)