From 5fccad7c5f8f7baf852b016e4a161917c225b12d Mon Sep 17 00:00:00 2001 From: altescy Date: Mon, 24 Feb 2025 23:30:57 +0900 Subject: [PATCH 1/2] add mapping field --- collatable/fields/__init__.py | 1 + collatable/fields/mapping_field.py | 50 ++++++++++++++ tests/fields/test_mapping_field.py | 102 +++++++++++++++++++++++++++++ 3 files changed, 153 insertions(+) create mode 100644 collatable/fields/mapping_field.py create mode 100644 tests/fields/test_mapping_field.py diff --git a/collatable/fields/__init__.py b/collatable/fields/__init__.py index ae73a7f..3f1e717 100644 --- a/collatable/fields/__init__.py +++ b/collatable/fields/__init__.py @@ -3,6 +3,7 @@ from collatable.fields.index_field import IndexField # noqa: F401 from collatable.fields.label_field import LabelField # noqa: F401 from collatable.fields.list_field import ListField # noqa: F401 +from collatable.fields.mapping_field import MappingField # noqa: F401 from collatable.fields.metadata_field import MetadataField # noqa: F401 from collatable.fields.scalar_field import ScalarField # noqa: F401 from collatable.fields.sequence_field import SequenceField # noqa: F401 diff --git a/collatable/fields/mapping_field.py b/collatable/fields/mapping_field.py new file mode 100644 index 0000000..72db184 --- /dev/null +++ b/collatable/fields/mapping_field.py @@ -0,0 +1,50 @@ +from collections import abc +from typing import Any, Dict, Iterator, Mapping, Sequence, Tuple, Type + +from collatable.fields.field import Field +from collatable.types import DataArray + + +class MappingField(Field[Dict[str, Any]], abc.Mapping[str, Field]): + def __init__(self, mapping: Mapping[str, Field]) -> None: + super().__init__() + self._mapping = mapping + + def __len__(self) -> int: + return len(self._mapping) + + def __iter__(self) -> Iterator[str]: + return iter(self._mapping) + + def __getitem__(self, key: str) -> Field: + return self._mapping[key] + + def __contains__(self, key: object) -> bool: + return key in self._mapping + + def __str__(self) -> str: + return str(self._mapping) + + def __repr__(self) -> str: + return f"MappingField({self._mapping})" + + def as_array(self) -> Dict[str, Any]: + return {key: field.as_array() for key, field in self._mapping.items()} + + @classmethod + def from_array( # type: ignore[override] + cls, + array: Mapping[str, DataArray], + *, + fields: Mapping[str, Tuple[Type[Field], Mapping[str, Any]]], + ) -> "MappingField": + return cls({key: field.from_array(array[key], **params) for key, (field, params) in fields.items()}) + + def collate( # type: ignore[override] + self, + arrays: Sequence, + ) -> Dict[str, Any]: + if not isinstance(arrays[0], MappingField): + return super().collate(arrays) # type: ignore[no-any-return] + arrays = [x.as_array() for x in arrays] + return {key: field.collate([x[key] for x in arrays]) for key, field in self._mapping.items()} diff --git a/tests/fields/test_mapping_field.py b/tests/fields/test_mapping_field.py new file mode 100644 index 0000000..2794ef0 --- /dev/null +++ b/tests/fields/test_mapping_field.py @@ -0,0 +1,102 @@ +from typing import Mapping, Sequence + +import numpy + +from collatable.fields import LabelField, MappingField, TextField +from collatable.fields.text_field import PaddingValue +from collatable.utils import debatched + + +def test_mapping_field_can_be_converted_to_array() -> None: + tokens = ["this", "is", "a", "test"] + vocab = {"a": 0, "is": 1, "test": 2, "this": 3} + text_field = TextField(tokens, vocab=vocab) + + labels = {"negative": 0, "positive": 1} + label_field = LabelField("positive", vocab=labels) + + field = MappingField({"text": text_field, "label": label_field}) + array = field.as_array() + + assert isinstance(array, dict) + assert array.keys() == {"text", "label"} + assert array["text"]["token_ids"].tolist() == [3, 1, 0, 2] + assert array["text"]["mask"].tolist() == [True, True, True, True] + assert array["label"] == 1 + + +def test_mapping_field_can_be_collated() -> None: + class TokenIndexer: + VOCAB = {"": -1, "a": 0, "first": 1, "is": 2, "this": 3, "second": 4, "sentence": 5, "!": 6} + INV_VOCAB = {index: token for token, index in VOCAB.items()} + + def __call__(self, tokens: Sequence[str], /) -> Mapping[str, numpy.ndarray]: + return { + "token_ids": numpy.array([self.VOCAB[token] for token in tokens], dtype=numpy.int64), + "mask": numpy.array([True] * len(tokens), dtype=numpy.bool_), + } + + def decode(self, index: Mapping[str, numpy.ndarray], /) -> Sequence[str]: + return [self.INV_VOCAB[index] for index in index["token_ids"] if index != -1] + + class LabelIndexer: + LABELS = {"negative": 0, "positive": 1} + INV_LABELS = {index: label for label, index in LABELS.items()} + + def __call__(self, label: str, /) -> int: + return self.LABELS[label] + + def decode(self, index: int, /) -> str: + return self.INV_LABELS[index] + + token_indexer = TokenIndexer() + label_indexer = LabelIndexer() + padding_value: PaddingValue = {"token_ids": -1} + fields = [ + MappingField( + { + "text": TextField( + ["this", "is", "a", "first", "sentence"], indexer=token_indexer, padding_value=padding_value + ), + "label": LabelField("positive", indexer=label_indexer), + } + ), + MappingField( + { + "text": TextField( + ["this", "is", "a", "second", "sentence", "!"], indexer=token_indexer, padding_value=padding_value + ), + "label": LabelField("negative", indexer=label_indexer), + } + ), + ] + + output = fields[0].collate(fields) + + assert isinstance(output, dict) + assert output.keys() == {"text", "label"} + assert output["text"]["token_ids"].tolist() == [[3, 2, 0, 1, 5, -1], [3, 2, 0, 4, 5, 6]] + assert output["text"]["mask"].sum(1).tolist() == [5, 6] + assert output["label"].tolist() == [1, 0] + + reconstruction = [ + MappingField.from_array( + array, # type: ignore[arg-type] + fields={ + "text": (TextField, {"indexer": token_indexer}), + "label": (LabelField, {"indexer": label_indexer}), + }, + ) + for array in debatched(output) + ] + + assert len(reconstruction) == len(fields) + assert all(isinstance(field, MappingField) for field in reconstruction) + assert isinstance(reconstruction[0]["text"], TextField) + assert isinstance(reconstruction[1]["text"], TextField) + assert reconstruction[0]["text"].tokens == ["this", "is", "a", "first", "sentence"] + assert reconstruction[1]["text"].tokens == ["this", "is", "a", "second", "sentence", "!"] + assert isinstance(reconstruction[0]["label"], LabelField) + assert isinstance(reconstruction[1]["label"], LabelField) + assert reconstruction[0]["label"].label == "positive" + assert reconstruction[1]["label"].label == "negative" From f2cafe06a4a26b4df8894f35a3023861b2089ff6 Mon Sep 17 00:00:00 2001 From: altescy Date: Mon, 24 Feb 2025 23:33:39 +0900 Subject: [PATCH 2/2] use typing.Mapping --- collatable/fields/mapping_field.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/collatable/fields/mapping_field.py b/collatable/fields/mapping_field.py index 72db184..f87d306 100644 --- a/collatable/fields/mapping_field.py +++ b/collatable/fields/mapping_field.py @@ -1,11 +1,10 @@ -from collections import abc from typing import Any, Dict, Iterator, Mapping, Sequence, Tuple, Type from collatable.fields.field import Field from collatable.types import DataArray -class MappingField(Field[Dict[str, Any]], abc.Mapping[str, Field]): +class MappingField(Field[Dict[str, Any]], Mapping[str, Field]): def __init__(self, mapping: Mapping[str, Field]) -> None: super().__init__() self._mapping = mapping