From cfc641692568534b3da7623928638ebb54efad69 Mon Sep 17 00:00:00 2001 From: altescy Date: Tue, 18 Feb 2025 23:32:43 +0900 Subject: [PATCH 1/5] generic label type --- collatable/fields/label_field.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/collatable/fields/label_field.py b/collatable/fields/label_field.py index 390bc59..6b972f5 100644 --- a/collatable/fields/label_field.py +++ b/collatable/fields/label_field.py @@ -1,4 +1,4 @@ -from typing import Callable, Mapping, Optional, TypeVar, Union +from typing import Callable, Generic, Hashable, Mapping, Optional, TypeVar import numpy @@ -6,17 +6,18 @@ from collatable.typing import IntTensor Self = TypeVar("Self", bound="LabelField") +LabelT = TypeVar("LabelT", bound=Hashable) -class LabelField(Field[IntTensor]): +class LabelField(Generic[LabelT], Field[IntTensor]): __slots__ = ["_label", "_label_index"] def __init__( self, - label: Union[int, str], + label: LabelT, *, - vocab: Optional[Mapping[str, int]] = None, - indexer: Optional[Callable[[str], int]] = None, + vocab: Optional[Mapping[LabelT, int]] = None, + indexer: Optional[Callable[[LabelT], int]] = None, ) -> None: if isinstance(label, str) and vocab is None is indexer: raise ValueError("LabelField with string labels requires vocab or indexer") @@ -42,15 +43,15 @@ def __repr__(self) -> str: return f"LabelField(label={self._label})" @property - def label(self) -> Union[int, str]: + def label(self) -> LabelT: return self._label def as_array(self) -> IntTensor: return numpy.array(self._label_index, dtype=numpy.int32) @staticmethod - def _make_indexer(vocab: Mapping[str, int]) -> Callable[[str], int]: - def indexer(label: str) -> int: + def _make_indexer(vocab: Mapping[LabelT, int]) -> Callable[[LabelT], int]: + def indexer(label: LabelT) -> int: return vocab[label] return indexer From b3e062828f6e3d8fdd74435899aa2c12d0e541fa Mon Sep 17 00:00:00 2001 From: altescy Date: Tue, 18 Feb 2025 23:41:09 +0900 Subject: [PATCH 2/5] add 
datamodule --- README.md | 72 ++++++++++++++ collatable/extras/__init__.py | 7 ++ collatable/extras/datamodule.py | 161 ++++++++++++++++++++++++++++++++ 3 files changed, 240 insertions(+) create mode 100644 collatable/extras/datamodule.py diff --git a/README.md b/README.md index 8822420..cd1b02c 100644 --- a/README.md +++ b/README.md @@ -181,3 +181,75 @@ Execution result: [-1, -1, -1], [-1, -1, -1]]], dtype=int32)} ``` + + +### DataModule + +```python +from dataclasses import dataclass +from typing import Sequence, Union + +from collatable import LabelField, TextField +from collatable.extras import DataLoader, LabelIndexer, TokenIndexer +from collatable.extras.datamodule import DataModule, LabelFieldTransform, TextFieldTransform + + +@dataclass +class Text2TextExample: + source: Union[str, Sequence[str]] + target: Union[str, Sequence[str]] + + +text2text_dataset = [ + Text2TextExample(source="how are you?", target="I am fine."), + Text2TextExample(source="what is your name?", target="My name is John."), + Text2TextExample(source="where are you?", target="I am in New York."), + Text2TextExample(source="what is the time?", target="It is 10:00 AM."), +] + +shared_token_indexer = TokenIndexer( + default="", + specials=["", ""], +) + +text2text_datamodule = DataModule[Text2TextExample]( + fields={ + "source": TextFieldTransform(indexer=shared_token_indexer, pad_token=""), + "target": TextFieldTransform(indexer=shared_token_indexer, pad_token=""), + } +) + +with shared_token_indexer.context(train=True): + text2text_datamodule.build(text2text_dataset) + +text2text_instances = list(text2text_datamodule(text2text_dataset)) + + +dataloader = DataLoader(batch_size=2) +for batch in dataloader(list(text2text_datamodule(text2text_dataset))): + print(batch) +``` + +Execution result: + +``` +{'target': { + 'token_ids': array([[12, 13, 0, 0], + [14, 8, 6, 15]]), + 'mask': array([[ True, True, False, False], + [ True, True, True, True]])}, + 'source': { + 'token_ids': array([[2, 3, 
4, 0], + [5, 6, 7, 8]]), + 'mask': array([[ True, True, True, False], + [ True, True, True, True]])}} +{'target': { + 'token_ids': array([[12, 16, 17, 18, 0], + g[19, 6, 20, 21, 22]]), + 'mask': array([[ True, True, True, True, False], + g [ True, True, True, True, True]])}, + 'source': {'token_ids': array([[ 9, 3, 4, 0], + g [ 5, 6, 10, 11]]), + 'mask': array([[ True, True, True, False], + [ True, True, True, True]])}} +``` diff --git a/collatable/extras/__init__.py b/collatable/extras/__init__.py index 866a29c..5a04d2a 100644 --- a/collatable/extras/__init__.py +++ b/collatable/extras/__init__.py @@ -1,3 +1,10 @@ from collatable.extras.dataloader import DataLoader # noqa: F401 +from collatable.extras.datamodule import ( # noqa: F401 + DataModule, + FieldConfig, + FieldTransform, + LabelFieldTransform, + TextFieldTransform, +) from collatable.extras.dataset import Dataset # noqa: F401 from collatable.extras.indexer import Indexer, LabelIndexer, TokenIndexer # noqa: F401 diff --git a/collatable/extras/datamodule.py b/collatable/extras/datamodule.py new file mode 100644 index 0000000..e32001b --- /dev/null +++ b/collatable/extras/datamodule.py @@ -0,0 +1,161 @@ +import re +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Dict, + Generic, + Hashable, + Iterable, + Mapping, + Optional, + Protocol, + Sequence, + TypeVar, + Union, + runtime_checkable, +) + +from collatable import Field, LabelField, TextField +from collatable.typing import DataArray, Scalar, Tensor + +S = TypeVar("S") +T = TypeVar("T") +U = TypeVar("U") +HashableT = TypeVar("HashableT", bound=Hashable) +HashableT_contra = TypeVar("HashableT_contra", bound=Hashable, contravariant=True) +IndexT_co = TypeVar("IndexT_co", bound=Union[Scalar, DataArray], covariant=True) + + +@runtime_checkable +class IIndexer(Protocol[HashableT_contra, IndexT_co]): + def __len__(self) -> int: ... + + def __getitem__(self, value: HashableT_contra, /) -> int: ... 
+ + def __call__(self, value: HashableT_contra, /) -> IndexT_co: ... + + +@runtime_checkable +class ISequenceIndexer(Protocol[HashableT_contra, IndexT_co]): + def __len__(self) -> int: ... + + def __getitem__(self, value: HashableT_contra, /) -> int: ... + + def __call__(self, values: Sequence[HashableT_contra], /) -> IndexT_co: ... + + +class FieldAccessor: + def __init__(self, name: str) -> None: + self._name = name + + def __call__(self, obj: Any) -> Any: + return getattr(obj, self._name) + + +class FieldTransform(Generic[S]): + def __call__(self, obj: S) -> Field: + raise NotImplementedError + + def build(self, dataset: Iterable[S]) -> None: + pass + + def indexers(self) -> Mapping[str, IIndexer]: + return dict((attribute, value) for attribute, value in self.__dict__.items() if isinstance(value, IIndexer)) + + +class TextFieldTransform(Generic[HashableT], FieldTransform[Union[str, Sequence[HashableT]]]): + _DEFAULT_TOKENIZER_PATTERN = re.compile(r"(?u)\b\w\w+\b") + + def __init__( + self, + tokenizer: Optional[Callable[[str], Sequence[HashableT]]] = None, + pad_token: Optional[HashableT] = None, + unk_token: Optional[HashableT] = None, + special_tokens: Optional[Sequence[HashableT]] = None, + indexer: Optional[ISequenceIndexer[HashableT, Mapping[str, Tensor]]] = None, + ) -> None: + from .indexer import TokenIndexer + + self._tokenizer = tokenizer or (lambda text: self._DEFAULT_TOKENIZER_PATTERN.findall(text)) + self._indexer: ISequenceIndexer[HashableT, Mapping[str, Tensor]] = ( + indexer if indexer is not None else TokenIndexer[HashableT]() + ) + self._pad_token = pad_token + self._special_tokens = special_tokens or [] + if unk_token is not None and unk_token not in self._special_tokens: + self._special_tokens = [unk_token, *self._special_tokens] + if pad_token is not None and pad_token not in self._special_tokens: + self._special_tokens = [pad_token, *self._special_tokens] + + def __call__(self, obj: Union[str, Sequence[HashableT]]) -> TextField: + if 
isinstance(obj, str): + obj = self._tokenizer(obj) + return TextField( + obj, + indexer=self._indexer, + padding_value=self._indexer[self._pad_token] if self._pad_token is not None else 0, + ) + + def build(self, dataset: Iterable[Union[str, Sequence[HashableT]]]) -> None: + for special_token in self._special_tokens: + self._indexer[special_token] + for text in dataset: + if isinstance(text, str): + text = self._tokenizer(text) + self._indexer(text) + + def indexers(self) -> Mapping[str, IIndexer]: + return {"tokens": self._indexer} + + +class LabelFieldTransform(FieldTransform[HashableT]): + def __init__( + self, + indexer: Optional[IIndexer[HashableT, int]] = None, + ) -> None: + from .indexer import LabelIndexer + + self._indexer: IIndexer[HashableT, int] = indexer or LabelIndexer[HashableT]() + + def __call__(self, obj: HashableT) -> LabelField: + return LabelField(obj, indexer=self._indexer) + + def build(self, dataset: Iterable[HashableT]) -> None: + for label in dataset: + self._indexer[label] + + def indexers(self) -> Mapping[str, IIndexer[HashableT, int]]: + return {"labels": self._indexer} + + +@dataclass +class FieldConfig(Generic[S, T]): + accessor: Callable[[S], T] + transform: FieldTransform[T] + + +class DataModule(Generic[T]): + def __init__( + self, + fields: Mapping[str, Union[FieldTransform, FieldConfig[T, Any]]], + ) -> None: + self._fields: Mapping[str, FieldConfig[T, Any]] = { + name: ( + FieldConfig(accessor=FieldAccessor(name), transform=transform) + if isinstance(transform, FieldTransform) + else transform + ) + for name, transform in fields.items() + } + + def build(self, dataset: Iterable[T]) -> None: + for field in self._fields.values(): + field.transform.build(field.accessor(obj) for obj in dataset) + + def __call__(self, dataset: Iterable[T]) -> Iterable[Dict[str, Field]]: + for obj in dataset: + yield {name: field.transform(field.accessor(obj)) for name, field in self._fields.items()} + + def indexer(self, field: str, name: str) -> 
IIndexer: + return self._fields[field].transform.indexers()[name] From 2fce7efa741d9bc2e5569a75f680523c18ee73fe Mon Sep 17 00:00:00 2001 From: altescy Date: Tue, 18 Feb 2025 23:42:14 +0900 Subject: [PATCH 3/5] fix indent --- README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index cd1b02c..32bc86d 100644 --- a/README.md +++ b/README.md @@ -237,19 +237,19 @@ Execution result: 'token_ids': array([[12, 13, 0, 0], [14, 8, 6, 15]]), 'mask': array([[ True, True, False, False], - [ True, True, True, True]])}, + [ True, True, True, True]])}, 'source': { - 'token_ids': array([[2, 3, 4, 0], + 'token_ids': array([[2, 3, 4, 0], [5, 6, 7, 8]]), - 'mask': array([[ True, True, True, False], + 'mask': array([[ True, True, True, False], [ True, True, True, True]])}} {'target': { - 'token_ids': array([[12, 16, 17, 18, 0], - g[19, 6, 20, 21, 22]]), + 'token_ids': array([[12, 16, 17, 18, 0], + [19, 6, 20, 21, 22]]), 'mask': array([[ True, True, True, True, False], - g [ True, True, True, True, True]])}, + [ True, True, True, True, True]])}, 'source': {'token_ids': array([[ 9, 3, 4, 0], - g [ 5, 6, 10, 11]]), + [ 5, 6, 10, 11]]), 'mask': array([[ True, True, True, False], [ True, True, True, True]])}} ``` From 8d28803720fcd2daffc5601abdaaeb795080099d Mon Sep 17 00:00:00 2001 From: altescy Date: Tue, 18 Feb 2025 23:48:04 +0900 Subject: [PATCH 4/5] modify example --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index 32bc86d..7cdbb7c 100644 --- a/README.md +++ b/README.md @@ -224,9 +224,8 @@ with shared_token_indexer.context(train=True): text2text_instances = list(text2text_datamodule(text2text_dataset)) - dataloader = DataLoader(batch_size=2) -for batch in dataloader(list(text2text_datamodule(text2text_dataset))): +for batch in dataloader(text2text_instances): print(batch) ``` From c1c3ce22681b25e13d1f7fe62ac243043dfe1f26 Mon Sep 17 00:00:00 2001 From: altescy Date: Tue, 
18 Feb 2025 23:53:38 +0900 Subject: [PATCH 5/5] allow accessing nested fields --- collatable/extras/datamodule.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/collatable/extras/datamodule.py b/collatable/extras/datamodule.py index e32001b..60f3b27 100644 --- a/collatable/extras/datamodule.py +++ b/collatable/extras/datamodule.py @@ -46,11 +46,13 @@ def __call__(self, values: Sequence[HashableT_contra], /) -> IndexT_co: ... class FieldAccessor: - def __init__(self, name: str) -> None: - self._name = name + def __init__(self, field: str) -> None: + self._field = field.split(".") def __call__(self, obj: Any) -> Any: - return getattr(obj, self._name) + for part in self._field: + obj = obj[part] if isinstance(obj, Mapping) else getattr(obj, part) + return obj class FieldTransform(Generic[S]):