diff --git a/README.md b/README.md
index 8822420..7cdbb7c 100644
--- a/README.md
+++ b/README.md
@@ -181,3 +181,74 @@ Execution result:
     [-1, -1, -1],
     [-1, -1, -1]]], dtype=int32)}
 ```
+
+
+### DataModule
+
+```python
+from dataclasses import dataclass
+from typing import Sequence, Union
+
+from collatable import LabelField, TextField
+from collatable.extras import DataLoader, LabelIndexer, TokenIndexer
+from collatable.extras.datamodule import DataModule, LabelFieldTransform, TextFieldTransform
+
+
+@dataclass
+class Text2TextExample:
+    source: Union[str, Sequence[str]]
+    target: Union[str, Sequence[str]]
+
+
+text2text_dataset = [
+    Text2TextExample(source="how are you?", target="I am fine."),
+    Text2TextExample(source="what is your name?", target="My name is John."),
+    Text2TextExample(source="where are you?", target="I am in New York."),
+    Text2TextExample(source="what is the time?", target="It is 10:00 AM."),
+]
+
+shared_token_indexer = TokenIndexer(
+    default="<unk>",
+    specials=["<pad>", "<unk>"],
+)
+
+text2text_datamodule = DataModule[Text2TextExample](
+    fields={
+        "source": TextFieldTransform(indexer=shared_token_indexer, pad_token="<pad>"),
+        "target": TextFieldTransform(indexer=shared_token_indexer, pad_token="<pad>"),
+    }
+)
+
+with shared_token_indexer.context(train=True):
+    text2text_datamodule.build(text2text_dataset)
+
+text2text_instances = list(text2text_datamodule(text2text_dataset))
+
+dataloader = DataLoader(batch_size=2)
+for batch in dataloader(text2text_instances):
+    print(batch)
+```
+
+Execution result:
+
+```
+{'target': {
+    'token_ids': array([[12, 13,  0,  0],
+           [14,  8,  6, 15]]),
+    'mask': array([[ True,  True, False, False],
+           [ True,  True,  True,  True]])},
+ 'source': {
+    'token_ids': array([[2, 3, 4, 0],
+           [5, 6, 7, 8]]),
+    'mask': array([[ True,  True,  True, False],
+           [ True,  True,  True,  True]])}}
+{'target': {
+    'token_ids': array([[12, 16, 17, 18,  0],
+           [19,  6, 20, 21, 22]]),
+    'mask': array([[ True,  True,  True,  True, False],
+           [ True,  True,  True,  True,  True]])},
+ 'source': {'token_ids': array([[ 9,  3,  4,  0],
+           [ 5,  6, 10, 11]]),
+    'mask': array([[ True,  True,  True, False],
+           [ True,  True,  True,  True]])}}
+```
diff --git a/collatable/extras/__init__.py b/collatable/extras/__init__.py
index 866a29c..5a04d2a 100644
--- a/collatable/extras/__init__.py
+++ b/collatable/extras/__init__.py
@@ -1,3 +1,10 @@
 from collatable.extras.dataloader import DataLoader  # noqa: F401
+from collatable.extras.datamodule import (  # noqa: F401
+    DataModule,
+    FieldConfig,
+    FieldTransform,
+    LabelFieldTransform,
+    TextFieldTransform,
+)
 from collatable.extras.dataset import Dataset  # noqa: F401
 from collatable.extras.indexer import Indexer, LabelIndexer, TokenIndexer  # noqa: F401
diff --git a/collatable/extras/datamodule.py b/collatable/extras/datamodule.py
new file mode 100644
index 0000000..60f3b27
--- /dev/null
+++ b/collatable/extras/datamodule.py
@@ -0,0 +1,163 @@
+import re
+from dataclasses import dataclass
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    Hashable,
+    Iterable,
+    Mapping,
+    Optional,
+    Protocol,
+    Sequence,
+    TypeVar,
+    Union,
+    runtime_checkable,
+)
+
+from collatable import Field, LabelField, TextField
+from collatable.typing import DataArray, Scalar, Tensor
+
+S = TypeVar("S")
+T = TypeVar("T")
+U = TypeVar("U")
+HashableT = TypeVar("HashableT", bound=Hashable)
+HashableT_contra = TypeVar("HashableT_contra", bound=Hashable, contravariant=True)
+IndexT_co = TypeVar("IndexT_co", bound=Union[Scalar, DataArray], covariant=True)
+
+
+@runtime_checkable
+class IIndexer(Protocol[HashableT_contra, IndexT_co]):
+    def __len__(self) -> int: ...
+
+    def __getitem__(self, value: HashableT_contra, /) -> int: ...
+
+    def __call__(self, value: HashableT_contra, /) -> IndexT_co: ...
+
+
+@runtime_checkable
+class ISequenceIndexer(Protocol[HashableT_contra, IndexT_co]):
+    def __len__(self) -> int: ...
+
+    def __getitem__(self, value: HashableT_contra, /) -> int: ...
+
+    def __call__(self, values: Sequence[HashableT_contra], /) -> IndexT_co: ...
+
+
+class FieldAccessor:
+    def __init__(self, field: str) -> None:
+        self._field = field.split(".")
+
+    def __call__(self, obj: Any) -> Any:
+        for part in self._field:
+            obj = obj[part] if isinstance(obj, Mapping) else getattr(obj, part)
+        return obj
+
+
+class FieldTransform(Generic[S]):
+    def __call__(self, obj: S) -> Field:
+        raise NotImplementedError
+
+    def build(self, dataset: Iterable[S]) -> None:
+        pass
+
+    def indexers(self) -> Mapping[str, IIndexer]:
+        return dict((attribute, value) for attribute, value in self.__dict__.items() if isinstance(value, IIndexer))
+
+
+class TextFieldTransform(Generic[HashableT], FieldTransform[Union[str, Sequence[HashableT]]]):
+    _DEFAULT_TOKENIZER_PATTERN = re.compile(r"(?u)\b\w\w+\b")
+
+    def __init__(
+        self,
+        tokenizer: Optional[Callable[[str], Sequence[HashableT]]] = None,
+        pad_token: Optional[HashableT] = None,
+        unk_token: Optional[HashableT] = None,
+        special_tokens: Optional[Sequence[HashableT]] = None,
+        indexer: Optional[ISequenceIndexer[HashableT, Mapping[str, Tensor]]] = None,
+    ) -> None:
+        from .indexer import TokenIndexer
+
+        self._tokenizer = tokenizer or (lambda text: self._DEFAULT_TOKENIZER_PATTERN.findall(text))
+        self._indexer: ISequenceIndexer[HashableT, Mapping[str, Tensor]] = (
+            indexer if indexer is not None else TokenIndexer[HashableT]()
+        )
+        self._pad_token = pad_token
+        self._special_tokens = special_tokens or []
+        if unk_token is not None and unk_token not in self._special_tokens:
+            self._special_tokens = [unk_token, *self._special_tokens]
+        if pad_token is not None and pad_token not in self._special_tokens:
+            self._special_tokens = [pad_token, *self._special_tokens]
+
+    def __call__(self, obj: Union[str, Sequence[HashableT]]) -> TextField:
+        if isinstance(obj, str):
+            obj = self._tokenizer(obj)
+        return TextField(
+            obj,
+            indexer=self._indexer,
+            padding_value=self._indexer[self._pad_token] if self._pad_token is not None else 0,
+        )
+
+    def build(self, dataset: Iterable[Union[str, Sequence[HashableT]]]) -> None:
+        for special_token in self._special_tokens:
+            self._indexer[special_token]
+        for text in dataset:
+            if isinstance(text, str):
+                text = self._tokenizer(text)
+            self._indexer(text)
+
+    def indexers(self) -> Mapping[str, IIndexer]:
+        return {"tokens": self._indexer}
+
+
+class LabelFieldTransform(FieldTransform[HashableT]):
+    def __init__(
+        self,
+        indexer: Optional[IIndexer[HashableT, int]] = None,
+    ) -> None:
+        from .indexer import LabelIndexer
+
+        self._indexer: IIndexer[HashableT, int] = indexer or LabelIndexer[HashableT]()
+
+    def __call__(self, obj: HashableT) -> LabelField:
+        return LabelField(obj, indexer=self._indexer)
+
+    def build(self, dataset: Iterable[HashableT]) -> None:
+        for label in dataset:
+            self._indexer[label]
+
+    def indexers(self) -> Mapping[str, IIndexer[HashableT, int]]:
+        return {"labels": self._indexer}
+
+
+@dataclass
+class FieldConfig(Generic[S, T]):
+    accessor: Callable[[S], T]
+    transform: FieldTransform[T]
+
+
+class DataModule(Generic[T]):
+    def __init__(
+        self,
+        fields: Mapping[str, Union[FieldTransform, FieldConfig[T, Any]]],
+    ) -> None:
+        self._fields: Mapping[str, FieldConfig[T, Any]] = {
+            name: (
+                FieldConfig(accessor=FieldAccessor(name), transform=transform)
+                if isinstance(transform, FieldTransform)
+                else transform
+            )
+            for name, transform in fields.items()
+        }
+
+    def build(self, dataset: Iterable[T]) -> None:
+        for field in self._fields.values():
+            field.transform.build(field.accessor(obj) for obj in dataset)
+
+    def __call__(self, dataset: Iterable[T]) -> Iterable[Dict[str, Field]]:
+        for obj in dataset:
+            yield {name: field.transform(field.accessor(obj)) for name, field in self._fields.items()}
+
+    def indexer(self, field: str, name: str) -> IIndexer:
+        return self._fields[field].transform.indexers()[name]
diff --git a/collatable/fields/label_field.py b/collatable/fields/label_field.py
index 390bc59..6b972f5 100644
--- a/collatable/fields/label_field.py
+++ b/collatable/fields/label_field.py
@@ -1,4 +1,4 @@
-from typing import Callable, Mapping, Optional, TypeVar, Union
+from typing import Callable, Generic, Hashable, Mapping, Optional, TypeVar
 
 import numpy
 
@@ -6,17 +6,18 @@ from collatable.typing import IntTensor
 
 Self = TypeVar("Self", bound="LabelField")
+LabelT = TypeVar("LabelT", bound=Hashable)
 
 
-class LabelField(Field[IntTensor]):
+class LabelField(Generic[LabelT], Field[IntTensor]):
     __slots__ = ["_label", "_label_index"]
 
     def __init__(
         self,
-        label: Union[int, str],
+        label: LabelT,
         *,
-        vocab: Optional[Mapping[str, int]] = None,
-        indexer: Optional[Callable[[str], int]] = None,
+        vocab: Optional[Mapping[LabelT, int]] = None,
+        indexer: Optional[Callable[[LabelT], int]] = None,
     ) -> None:
         if isinstance(label, str) and vocab is None is indexer:
             raise ValueError("LabelField with string labels requires vocab or indexer")
@@ -42,15 +43,15 @@ def __repr__(self) -> str:
         return f"LabelField(label={self._label})"
 
     @property
-    def label(self) -> Union[int, str]:
+    def label(self) -> LabelT:
         return self._label
 
     def as_array(self) -> IntTensor:
         return numpy.array(self._label_index, dtype=numpy.int32)
 
     @staticmethod
-    def _make_indexer(vocab: Mapping[str, int]) -> Callable[[str], int]:
-        def indexer(label: str) -> int:
+    def _make_indexer(vocab: Mapping[LabelT, int]) -> Callable[[LabelT], int]:
+        def indexer(label: LabelT) -> int:
             return vocab[label]
 
         return indexer