Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions collatable/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collatable.fields.index_field import IndexField # noqa: F401
from collatable.fields.label_field import LabelField # noqa: F401
from collatable.fields.list_field import ListField # noqa: F401
from collatable.fields.mapping_field import MappingField # noqa: F401
from collatable.fields.metadata_field import MetadataField # noqa: F401
from collatable.fields.scalar_field import ScalarField # noqa: F401
from collatable.fields.sequence_field import SequenceField # noqa: F401
Expand Down
49 changes: 49 additions & 0 deletions collatable/fields/mapping_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Any, Dict, Iterator, Mapping, Sequence, Tuple, Type

from collatable.fields.field import Field
from collatable.types import DataArray


class MappingField(Field[Dict[str, Any]], Mapping[str, Field]):
def __init__(self, mapping: Mapping[str, Field]) -> None:
super().__init__()
self._mapping = mapping

def __len__(self) -> int:
return len(self._mapping)

def __iter__(self) -> Iterator[str]:
return iter(self._mapping)

def __getitem__(self, key: str) -> Field:
return self._mapping[key]

def __contains__(self, key: object) -> bool:
return key in self._mapping

def __str__(self) -> str:
return str(self._mapping)

def __repr__(self) -> str:
return f"MappingField({self._mapping})"

def as_array(self) -> Dict[str, Any]:
return {key: field.as_array() for key, field in self._mapping.items()}

@classmethod
def from_array( # type: ignore[override]
cls,
array: Mapping[str, DataArray],
*,
fields: Mapping[str, Tuple[Type[Field], Mapping[str, Any]]],
) -> "MappingField":
return cls({key: field.from_array(array[key], **params) for key, (field, params) in fields.items()})

def collate( # type: ignore[override]
self,
arrays: Sequence,
) -> Dict[str, Any]:
if not isinstance(arrays[0], MappingField):
return super().collate(arrays) # type: ignore[no-any-return]
arrays = [x.as_array() for x in arrays]
return {key: field.collate([x[key] for x in arrays]) for key, field in self._mapping.items()}
102 changes: 102 additions & 0 deletions tests/fields/test_mapping_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from typing import Mapping, Sequence

import numpy

from collatable.fields import LabelField, MappingField, TextField
from collatable.fields.text_field import PaddingValue
from collatable.utils import debatched


def test_mapping_field_can_be_converted_to_array() -> None:
tokens = ["this", "is", "a", "test"]
vocab = {"a": 0, "is": 1, "test": 2, "this": 3}
text_field = TextField(tokens, vocab=vocab)

labels = {"negative": 0, "positive": 1}
label_field = LabelField("positive", vocab=labels)

field = MappingField({"text": text_field, "label": label_field})
array = field.as_array()

assert isinstance(array, dict)
assert array.keys() == {"text", "label"}
assert array["text"]["token_ids"].tolist() == [3, 1, 0, 2]
assert array["text"]["mask"].tolist() == [True, True, True, True]
assert array["label"] == 1


def test_mapping_field_can_be_collated() -> None:
class TokenIndexer:
VOCAB = {"<pad>": -1, "a": 0, "first": 1, "is": 2, "this": 3, "second": 4, "sentence": 5, "!": 6}
INV_VOCAB = {index: token for token, index in VOCAB.items()}

def __call__(self, tokens: Sequence[str], /) -> Mapping[str, numpy.ndarray]:
return {
"token_ids": numpy.array([self.VOCAB[token] for token in tokens], dtype=numpy.int64),
"mask": numpy.array([True] * len(tokens), dtype=numpy.bool_),
}

def decode(self, index: Mapping[str, numpy.ndarray], /) -> Sequence[str]:
return [self.INV_VOCAB[index] for index in index["token_ids"] if index != -1]

class LabelIndexer:
LABELS = {"negative": 0, "positive": 1}
INV_LABELS = {index: label for label, index in LABELS.items()}

def __call__(self, label: str, /) -> int:
return self.LABELS[label]

def decode(self, index: int, /) -> str:
return self.INV_LABELS[index]

token_indexer = TokenIndexer()
label_indexer = LabelIndexer()
padding_value: PaddingValue = {"token_ids": -1}
fields = [
MappingField(
{
"text": TextField(
["this", "is", "a", "first", "sentence"], indexer=token_indexer, padding_value=padding_value
),
"label": LabelField("positive", indexer=label_indexer),
}
),
MappingField(
{
"text": TextField(
["this", "is", "a", "second", "sentence", "!"], indexer=token_indexer, padding_value=padding_value
),
"label": LabelField("negative", indexer=label_indexer),
}
),
]

output = fields[0].collate(fields)

assert isinstance(output, dict)
assert output.keys() == {"text", "label"}
assert output["text"]["token_ids"].tolist() == [[3, 2, 0, 1, 5, -1], [3, 2, 0, 4, 5, 6]]
assert output["text"]["mask"].sum(1).tolist() == [5, 6]
assert output["label"].tolist() == [1, 0]

reconstruction = [
MappingField.from_array(
array, # type: ignore[arg-type]
fields={
"text": (TextField, {"indexer": token_indexer}),
"label": (LabelField, {"indexer": label_indexer}),
},
)
for array in debatched(output)
]

assert len(reconstruction) == len(fields)
assert all(isinstance(field, MappingField) for field in reconstruction)
assert isinstance(reconstruction[0]["text"], TextField)
assert isinstance(reconstruction[1]["text"], TextField)
assert reconstruction[0]["text"].tokens == ["this", "is", "a", "first", "sentence"]
assert reconstruction[1]["text"].tokens == ["this", "is", "a", "second", "sentence", "!"]
assert isinstance(reconstruction[0]["label"], LabelField)
assert isinstance(reconstruction[1]["label"], LabelField)
assert reconstruction[0]["label"].label == "positive"
assert reconstruction[1]["label"].label == "negative"