From c0b4b1fca85ed516403b87a28f97658b074f66ae Mon Sep 17 00:00:00 2001 From: altescy Date: Wed, 19 Feb 2025 15:57:23 +0900 Subject: [PATCH 1/6] add batch sampler --- collatable/extras/dataloader.py | 91 ++++++++++++++++++++------------- tests/extras/test_dataloader.py | 8 +-- 2 files changed, 60 insertions(+), 39 deletions(-) diff --git a/collatable/extras/dataloader.py b/collatable/extras/dataloader.py index ee04737..e440b27 100644 --- a/collatable/extras/dataloader.py +++ b/collatable/extras/dataloader.py @@ -1,66 +1,87 @@ import math import random -from typing import Dict, Iterator, Mapping, Optional, Sequence +from typing import Dict, Iterable, Iterator, Mapping, Optional, Protocol, Sequence, TypeVar from collatable.collator import Collator from collatable.fields import Field from collatable.typing import DataArray +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) -class BatchIterator: + +class SizedIterator(Protocol[T_co]): + def __len__(self) -> int: ... + + def __next__(self) -> T_co: ... + + def __iter__(self) -> Iterator[T_co]: ... 
+ + +class BatchIterator(SizedIterator[Dict[str, DataArray]]): def __init__( self, dataset: Sequence[Mapping[str, Field]], - batch_size: int = 1, - shuffle: bool = False, - drop_last: bool = False, + indices: Iterable[Sequence[int]], + num_batches: int, collator: Optional[Collator] = None, ) -> None: self._dataset = dataset - self._batch_size = batch_size - self._shuffle = shuffle - self._drop_last = drop_last - self._offset = 0 + self._indices = iter(indices) + self._num_batches = num_batches self._collator = collator or Collator() - self._indices = list(range(len(self._dataset))) - if self._shuffle: - random.shuffle(self._indices) def __len__(self) -> int: - if self._drop_last: - return len(self._dataset) // self._batch_size - return math.ceil(len(self._dataset) / self._batch_size) + return self._num_batches def __next__(self) -> Dict[str, DataArray]: - if self._offset >= len(self._dataset): - raise StopIteration - if self._offset + self._batch_size > len(self._dataset): - if self._drop_last: - raise StopIteration - batch_indices = self._indices[self._offset :] - else: - batch_indices = self._indices[self._offset : self._offset + self._batch_size] - self._offset += self._batch_size - return self._collator([self._dataset[i] for i in batch_indices]) + indices = next(self._indices) + return self._collator([self._dataset[i] for i in indices]) def __iter__(self) -> Iterator[Dict[str, DataArray]]: return self -class DataLoader: +class IBatchSampler(Protocol): + def __call__(self, dataset: Sequence) -> SizedIterator[Mapping[str, DataArray]]: ... 
+ + +class DefaultBatchSampler: def __init__( - self, batch_size: int = 1, shuffle: bool = False, drop_last: bool = False, collator: Optional[Collator] = None + self, + batch_size: int = 1, + shuffle: bool = False, + drop_last: bool = False, ) -> None: self._batch_size = batch_size self._shuffle = shuffle self._drop_last = drop_last - self._collator = collator or Collator() - def __call__(self, dataset: Sequence[Mapping[str, Field]]) -> BatchIterator: - return BatchIterator( - dataset, - batch_size=self._batch_size, - shuffle=self._shuffle, - drop_last=self._drop_last, - collator=self._collator, + def __call__(self, dataset: Sequence) -> BatchIterator: + num_batches = ( + len(dataset) // self._batch_size if self._drop_last else math.ceil(len(dataset) / self._batch_size) ) + indices = list(range(len(dataset))) + if self._shuffle: + random.shuffle(indices) + + def iter_batches() -> Iterator[Sequence[int]]: + for batch_index in range(num_batches): + start_index = batch_index * self._batch_size + end_index = start_index + self._batch_size + yield indices[start_index:end_index] + + return BatchIterator(dataset, iter_batches(), num_batches) + + +class DataLoader: + def __init__( + self, + sampler: Optional[IBatchSampler] = None, + collator: Optional[Collator] = None, + ) -> None: + self._sampler = sampler or DefaultBatchSampler() + self._collator = collator or Collator() + + def __call__(self, dataset: Sequence[Mapping[str, Field]]) -> SizedIterator[Mapping[str, DataArray]]: + return self._sampler(dataset) diff --git a/tests/extras/test_dataloader.py b/tests/extras/test_dataloader.py index 03a7ba9..152da58 100644 --- a/tests/extras/test_dataloader.py +++ b/tests/extras/test_dataloader.py @@ -1,7 +1,7 @@ from typing import Iterator from collatable import LabelField, MetadataField, TextField -from collatable.extras.dataloader import DataLoader +from collatable.extras.dataloader import DataLoader, DefaultBatchSampler from collatable.extras.dataset import Dataset from 
collatable.extras.indexer import LabelIndexer, TokenIndexer @@ -43,14 +43,14 @@ def read_dataset() -> Iterator[dict]: dataset = Dataset.from_iterable(read_dataset()) - dataloader = DataLoader(batch_size=2) + dataloader = DataLoader(DefaultBatchSampler(batch_size=2)) batch_iterator = dataloader(dataset) assert len(batch_iterator) == 2 - dataloader = DataLoader(batch_size=3, drop_last=True) + dataloader = DataLoader(DefaultBatchSampler(batch_size=3, drop_last=True)) batch_iterator = dataloader(dataset) assert len(batch_iterator) == 1 - dataloader = DataLoader(batch_size=2, shuffle=True) + dataloader = DataLoader(DefaultBatchSampler(batch_size=2, shuffle=True)) batch_iterator = dataloader(dataset) assert all(len(batch["label"]) == 2 for batch in batch_iterator) From 18c90911e824dd21b03e271a8c6569d1c8666386 Mon Sep 17 00:00:00 2001 From: altescy Date: Wed, 19 Feb 2025 16:45:03 +0900 Subject: [PATCH 2/6] make fields reconstructable --- collatable/fields/adjacency_field.py | 45 +++++++++++++++++++---- collatable/fields/field.py | 21 +++++++---- collatable/fields/index_field.py | 10 +++++ collatable/fields/label_field.py | 22 ++++++++++- collatable/fields/list_field.py | 12 +++++- collatable/fields/metadata_field.py | 4 ++ collatable/fields/scalar_field.py | 4 ++ collatable/fields/sequence_label_field.py | 45 +++++++++++++++++------ collatable/fields/span_field.py | 10 +++++ collatable/fields/tensor_field.py | 4 ++ collatable/fields/text_field.py | 19 +++++++++- 11 files changed, 166 insertions(+), 30 deletions(-) diff --git a/collatable/fields/adjacency_field.py b/collatable/fields/adjacency_field.py index 51cb102..cdd5b24 100644 --- a/collatable/fields/adjacency_field.py +++ b/collatable/fields/adjacency_field.py @@ -1,4 +1,4 @@ -from typing import Callable, Mapping, Optional, Sequence, Tuple, TypeVar, Union, cast +from typing import Callable, Generic, Hashable, List, Mapping, Optional, Protocol, Sequence, Tuple, TypeVar, Union, cast import numpy @@ -7,9 +7,16 @@ from 
collatable.typing import IntTensor Self = TypeVar("Self", bound="AdjacencyField") +LabelT = TypeVar("LabelT", bound=Hashable) -class AdjacencyField(Field[IntTensor]): +class IDecotableIndexer(Protocol[LabelT]): + def __call__(self, label: LabelT) -> int: ... + + def decode(self, index: int) -> LabelT: ... + + +class AdjacencyField(Generic[LabelT], Field[IntTensor]): __slots__ = ["_indices", "_labels", "_indexed_labels", "_sequence_length", "_padding_value"] def __init__( @@ -17,9 +24,9 @@ def __init__( indices: Sequence[Tuple[int, int]], sequence_field: SequenceField, *, - labels: Optional[Union[Sequence[int], Sequence[str]]] = None, - vocab: Optional[Mapping[str, int]] = None, - indexer: Optional[Callable[[str], int]] = None, + labels: Optional[Union[Sequence[LabelT]]] = None, + vocab: Optional[Mapping[LabelT, int]] = None, + indexer: Optional[Callable[[LabelT], int]] = None, padding_value: PaddingValue = -1, ) -> None: if len(indices) == 0: @@ -46,7 +53,7 @@ def __init__( self._indexed_labels = cast(Sequence[int], self._labels) else: assert indexer is not None - self._indexed_labels = [indexer(label) for label in cast(Sequence[str], self._labels)] + self._indexed_labels = [indexer(label) for label in self._labels] def __str__(self) -> str: return f"[{', '.join(str(index) for index in self._indices)}]" @@ -55,8 +62,8 @@ def __repr__(self) -> str: return f"AdjacencyField(indices={self._indices}, padding_value={self._padding_value})" @staticmethod - def _make_indexer(vocab: Mapping[str, int]) -> Callable[[str], int]: - def indexer(label: str) -> int: + def _make_indexer(vocab: Mapping[LabelT, int]) -> Callable[[LabelT], int]: + def indexer(label: LabelT) -> int: return vocab[label] return indexer @@ -72,3 +79,25 @@ def as_array(self) -> IntTensor: for (i, j), label in zip(self._indices, labels): array[i, j] = label return array + + @classmethod + def from_array( # type: ignore[override] + cls, + array: IntTensor, + *, + sequence_field: SequenceField, + indexer: 
Optional[IDecotableIndexer[LabelT]] = None, + padding_value: PaddingValue = -1, + ) -> "AdjacencyField[LabelT]": + if array.ndim != 2: + raise ValueError(f"AdjacencyField expects a 2-dimensional array, but got shape {array.shape}") + indices = [] + indexed_labels: List[LabelT] = [] + for i, row in enumerate(array): + for j, label in enumerate(row): + if label != padding_value: + indices.append((i, j)) + if indexer is not None: + indexed_labels.append(indexer.decode(label)) + + return cls(indices, sequence_field, labels=indexed_labels, padding_value=padding_value) diff --git a/collatable/fields/field.py b/collatable/fields/field.py index d1c59ef..1128e5f 100644 --- a/collatable/fields/field.py +++ b/collatable/fields/field.py @@ -1,17 +1,17 @@ import abc import copy -from typing import Dict, Generic, List, Sequence, TypeVar, Union, cast +from typing import Any, Dict, Generic, List, Sequence, Type, TypeVar, Union, cast import numpy -from collatable.typing import ArrayLike, DataArrayT_co +from collatable.typing import ArrayLike, DataArrayT from collatable.util import stack_with_padding Self = TypeVar("Self", bound="Field") PaddingValue = Union[Dict[str, ArrayLike], ArrayLike] -class Field(abc.ABC, Generic[DataArrayT_co]): +class Field(abc.ABC, Generic[DataArrayT]): __slots__: List[str] def __init__(self, padding_value: PaddingValue = 0) -> None: @@ -35,17 +35,17 @@ def __eq__(self: Self, other: object) -> bool: def padding_value(self) -> Dict[str, ArrayLike]: return self._padding_value - def collate(self: Self, arrays: Union[Sequence[DataArrayT_co], Sequence[Self]]) -> DataArrayT_co: + def collate(self: Self, arrays: Union[Sequence[DataArrayT], Sequence[Self]]) -> DataArrayT: if isinstance(arrays[0], Field): arrays = [cast(Self, array).as_array() for array in arrays] - arrays = cast(Sequence[DataArrayT_co], arrays) + arrays = cast(Sequence[DataArrayT], arrays) if isinstance(arrays[0], numpy.ndarray): return cast( - DataArrayT_co, + DataArrayT, 
stack_with_padding(cast(Sequence[numpy.ndarray], arrays), padding_value=self.padding_value[""]), ) if isinstance(arrays[0], list): - return cast(DataArrayT_co, list(arrays)) + return cast(DataArrayT, list(arrays)) if isinstance(arrays[0], dict): return { key: stack_with_padding( @@ -60,5 +60,10 @@ def copy(self: Self) -> Self: return copy.deepcopy(self) @abc.abstractmethod - def as_array(self) -> DataArrayT_co: + def as_array(self) -> DataArrayT: + raise NotImplementedError + + @classmethod + @abc.abstractmethod + def from_array(cls: Type[Self], array: DataArrayT, **kwargs: Any) -> Self: raise NotImplementedError diff --git a/collatable/fields/index_field.py b/collatable/fields/index_field.py index 80ec112..29d7cb0 100644 --- a/collatable/fields/index_field.py +++ b/collatable/fields/index_field.py @@ -27,3 +27,13 @@ def index(self) -> int: def as_array(self) -> IntTensor: return numpy.array(self.index) + + @classmethod + def from_array( # type: ignore[override] + cls, + array: IntTensor, + sequence: SequenceField, + ) -> "IndexField": + if array.ndim != 0: + raise ValueError(f"IndexField expects a 0-dimensional array, but got shape {array.shape}") + return cls(array.item(), sequence) diff --git a/collatable/fields/label_field.py b/collatable/fields/label_field.py index 6b972f5..003b438 100644 --- a/collatable/fields/label_field.py +++ b/collatable/fields/label_field.py @@ -1,4 +1,4 @@ -from typing import Callable, Generic, Hashable, Mapping, Optional, TypeVar +from typing import Callable, Generic, Hashable, Mapping, Optional, Protocol, TypeVar, cast import numpy @@ -9,6 +9,12 @@ LabelT = TypeVar("LabelT", bound=Hashable) +class IDecotableIndexer(Protocol[LabelT]): + def __call__(self, label: LabelT) -> int: ... + + def decode(self, index: int) -> LabelT: ... 
+ + class LabelField(Generic[LabelT], Field[IntTensor]): __slots__ = ["_label", "_label_index"] @@ -49,6 +55,20 @@ def label(self) -> LabelT: def as_array(self) -> IntTensor: return numpy.array(self._label_index, dtype=numpy.int32) + @classmethod + def from_array( # type: ignore[override] + cls, + array: IntTensor, + *, + indexer: Optional[IDecotableIndexer[LabelT]] = None, + ) -> "LabelField": + if array.ndim != 0: + raise ValueError(f"LabelField expects a 0-dimensional array, but got shape {array.shape}") + label: LabelT = cast(LabelT, array.item()) + if indexer is not None: + label = indexer.decode(array.item()) + return cls(label, indexer=indexer) + @staticmethod def _make_indexer(vocab: Mapping[LabelT, int]) -> Callable[[LabelT], int]: def indexer(label: LabelT) -> int: diff --git a/collatable/fields/list_field.py b/collatable/fields/list_field.py index 022aad1..112c836 100644 --- a/collatable/fields/list_field.py +++ b/collatable/fields/list_field.py @@ -1,4 +1,4 @@ -from typing import Generic, Iterator, Optional, Sequence +from typing import Generic, Iterator, Optional, Sequence, Type from collatable.fields.field import Field, PaddingValue from collatable.fields.sequence_field import SequenceField @@ -37,3 +37,13 @@ def fields(self) -> Sequence[Field[DataArrayT]]: def as_array(self) -> DataArrayT: return self.fields[0].collate(self.fields) + + @classmethod + def from_array( # type: ignore[override] + cls, + array: DataArrayT, + *, + item_type: Type[Field], + padding_value: Optional[PaddingValue] = None, + ) -> "ListField": + return cls([item_type.from_array(item) for item in array], padding_value=padding_value) diff --git a/collatable/fields/metadata_field.py b/collatable/fields/metadata_field.py index b69b683..da26f25 100644 --- a/collatable/fields/metadata_field.py +++ b/collatable/fields/metadata_field.py @@ -23,6 +23,10 @@ def metadata(self) -> Any: def as_array(self) -> Any: return self._metadata + @classmethod + def from_array(cls, array: Any) -> 
"MetadataField": # type: ignore[override] + return cls(array) + def collate(self, arrays: Sequence[Any]) -> List[Any]: if isinstance(arrays[0], Field): arrays = [array.as_array() for array in arrays] diff --git a/collatable/fields/scalar_field.py b/collatable/fields/scalar_field.py index 25cefb7..a36088e 100644 --- a/collatable/fields/scalar_field.py +++ b/collatable/fields/scalar_field.py @@ -28,3 +28,7 @@ def __repr__(self) -> str: def as_array(self) -> Tensor: return numpy.array(self._value) + + @classmethod + def from_array(cls, array: Tensor) -> "ScalarField": # type: ignore[override] + return cls(array.item()) diff --git a/collatable/fields/sequence_label_field.py b/collatable/fields/sequence_label_field.py index 493a999..13e063d 100644 --- a/collatable/fields/sequence_label_field.py +++ b/collatable/fields/sequence_label_field.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterator, Mapping, Optional, Sequence, Union, cast +from typing import Callable, Generic, Hashable, Iterator, Mapping, Optional, Protocol, Sequence, TypeVar, Union, cast import numpy @@ -6,17 +6,25 @@ from collatable.fields.sequence_field import SequenceField from collatable.typing import IntTensor +LabelT = TypeVar("LabelT", bound=Hashable) -class SequenceLabelField(SequenceField[IntTensor]): + +class IDecotableIndexer(Protocol[LabelT]): + def __call__(self, label: LabelT) -> int: ... + + def decode(self, index: int) -> LabelT: ... 
+ + +class SequenceLabelField(Generic[LabelT], SequenceField[IntTensor]): __slots__ = ["_labels", "_indexed_labels"] def __init__( self, - labels: Union[Sequence[int], Sequence[str]], + labels: Union[Sequence[LabelT]], sequence_field: SequenceField, *, - vocab: Optional[Mapping[str, int]] = None, - indexer: Optional[Callable[[str], int]] = None, + vocab: Optional[Mapping[LabelT, int]] = None, + indexer: Optional[Callable[[LabelT], int]] = None, padding_value: PaddingValue = 0, ) -> None: if len(labels) != len(sequence_field): @@ -36,15 +44,15 @@ def __init__( if indexer is None: raise ValueError("Indexer must be specified if labels are strings.") self._labels = self._labels - self._indexed_labels = [indexer(label) for label in cast(Sequence[str], self._labels)] + self._indexed_labels = [indexer(label) for label in self._labels] def __len__(self) -> int: return len(self._labels) - def __iter__(self) -> Iterator[Union[int, str]]: + def __iter__(self) -> Iterator[LabelT]: return iter(self._labels) - def __getitem__(self, index: int) -> Union[int, str]: + def __getitem__(self, index: int) -> LabelT: return self._labels[index] def __str__(self) -> str: @@ -54,15 +62,30 @@ def __repr__(self) -> str: return f"SequenceLabelField(labels={self._labels}, padding_value={self._padding_value})" @property - def labels(self) -> Union[Sequence[int], Sequence[str]]: + def labels(self) -> Sequence[LabelT]: return self._labels def as_array(self) -> IntTensor: return numpy.array(self._indexed_labels) + @classmethod + def from_array( # type: ignore[override] + cls, + array: IntTensor, + *, + sequence_field: SequenceField, + indexer: Optional[IDecotableIndexer[LabelT]] = None, + ) -> "SequenceLabelField[LabelT]": + if array.ndim != 1: + raise ValueError(f"SequenceLabelField expects a 1-dimensional array, but got shape {array.shape}") + labels: Sequence[LabelT] = cast(Sequence[LabelT], array.tolist()) + if indexer is not None: + labels = [indexer.decode(index) for index in array] + return 
cls(labels, sequence_field, indexer=indexer) + @staticmethod - def _make_indexer(vocab: Mapping[str, int]) -> Callable[[str], int]: - def indexer(label: str) -> int: + def _make_indexer(vocab: Mapping[LabelT, int]) -> Callable[[LabelT], int]: + def indexer(label: LabelT) -> int: return vocab[label] return indexer diff --git a/collatable/fields/span_field.py b/collatable/fields/span_field.py index 4fe1056..dbbf076 100644 --- a/collatable/fields/span_field.py +++ b/collatable/fields/span_field.py @@ -45,3 +45,13 @@ def span_end(self) -> int: def as_array(self) -> IntTensor: return numpy.array([self.span_start, self.span_end]) + + @classmethod + def from_array( # type: ignore[override] + cls, + array: IntTensor, + sequence_field: SequenceField, + ) -> "SpanField": + if array.ndim != 1 or array.shape[0] != 2: + raise ValueError(f"SpanField expects a 1-dimensional array of length 2, but got shape {array.shape}") + return cls(array[0], array[1], sequence_field) diff --git a/collatable/fields/tensor_field.py b/collatable/fields/tensor_field.py index fd918bd..a16cd3f 100644 --- a/collatable/fields/tensor_field.py +++ b/collatable/fields/tensor_field.py @@ -35,3 +35,7 @@ def __repr__(self) -> str: def as_array(self) -> TensorT: return self._tensor + + @classmethod + def from_array(cls, array: TensorT) -> "TensorField": # type: ignore[override] + return cls(array) diff --git a/collatable/fields/text_field.py b/collatable/fields/text_field.py index 14dbd09..b77add9 100644 --- a/collatable/fields/text_field.py +++ b/collatable/fields/text_field.py @@ -1,4 +1,4 @@ -from typing import Callable, Generic, Hashable, Iterator, Mapping, Optional, Sequence, TypeVar +from typing import Callable, Generic, Hashable, Iterator, Mapping, Optional, Protocol, Sequence, TypeVar import numpy @@ -9,6 +9,12 @@ TokenT = TypeVar("TokenT", bound=Hashable) +class IDecotableIndexer(Protocol[TokenT]): + def __call__(self, tokens: Sequence[TokenT]) -> Mapping[str, numpy.ndarray]: ... 
+ + def decode(self, index: Mapping[str, numpy.ndarray]) -> Sequence[TokenT]: ... + + class TextField(Generic[TokenT], SequenceField[Mapping[str, numpy.ndarray]]): __slots__ = ["_tokens", "_padding_value", "_indexed_tokens"] @@ -54,6 +60,17 @@ def tokens(self) -> Sequence[TokenT]: def as_array(self) -> Mapping[str, numpy.ndarray]: return self._indexed_tokens + @classmethod + def from_array( # type: ignore[override] + cls, + array: Mapping[str, numpy.ndarray], + *, + indexer: IDecotableIndexer[TokenT], + padding_value: PaddingValue = 0, + ) -> "TextField": + tokens = indexer.decode(array) + return cls(tokens, indexer=indexer, padding_value=padding_value) + @staticmethod def _make_indexer(vocab: Mapping[TokenT, int]) -> Callable[[Sequence[TokenT]], Mapping[str, numpy.ndarray]]: def indexer(tokens: Sequence[TokenT]) -> Mapping[str, numpy.ndarray]: From 61bd2db6c990f4b378ed63c548d1ed7df1b8521c Mon Sep 17 00:00:00 2001 From: altescy Date: Wed, 19 Feb 2025 18:02:29 +0900 Subject: [PATCH 3/6] fix types --- collatable/fields/label_field.py | 6 +++--- collatable/fields/text_field.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/collatable/fields/label_field.py b/collatable/fields/label_field.py index 003b438..20d578b 100644 --- a/collatable/fields/label_field.py +++ b/collatable/fields/label_field.py @@ -10,9 +10,9 @@ class IDecotableIndexer(Protocol[LabelT]): - def __call__(self, label: LabelT) -> int: ... + def __call__(self, label: LabelT, /) -> int: ... - def decode(self, index: int) -> LabelT: ... + def decode(self, index: int, /) -> LabelT: ... 
class LabelField(Generic[LabelT], Field[IntTensor]): @@ -61,7 +61,7 @@ def from_array( # type: ignore[override] array: IntTensor, *, indexer: Optional[IDecotableIndexer[LabelT]] = None, - ) -> "LabelField": + ) -> "LabelField[LabelT]": if array.ndim != 0: raise ValueError(f"LabelField expects a 0-dimensional array, but got shape {array.shape}") label: LabelT = cast(LabelT, array.item()) diff --git a/collatable/fields/text_field.py b/collatable/fields/text_field.py index b77add9..66ae1e5 100644 --- a/collatable/fields/text_field.py +++ b/collatable/fields/text_field.py @@ -10,9 +10,9 @@ class IDecotableIndexer(Protocol[TokenT]): - def __call__(self, tokens: Sequence[TokenT]) -> Mapping[str, numpy.ndarray]: ... + def __call__(self, tokens: Sequence[TokenT], /) -> Mapping[str, numpy.ndarray]: ... - def decode(self, index: Mapping[str, numpy.ndarray]) -> Sequence[TokenT]: ... + def decode(self, index: Mapping[str, numpy.ndarray], /) -> Sequence[TokenT]: ... class TextField(Generic[TokenT], SequenceField[Mapping[str, numpy.ndarray]]): From ee508ba7b7bd4b5edbb46e53de4cd417895098e0 Mon Sep 17 00:00:00 2001 From: altescy Date: Wed, 19 Feb 2025 18:02:49 +0900 Subject: [PATCH 4/6] add debatched function --- collatable/util.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/collatable/util.py b/collatable/util.py index dc58fa5..45d5c78 100644 --- a/collatable/util.py +++ b/collatable/util.py @@ -1,8 +1,8 @@ -from typing import Sequence, Type, cast +from typing import List, Mapping, Sequence, Type, cast import numpy -from collatable.typing import ArrayLike, ScalarT, TensorT +from collatable.typing import ArrayLike, DataArray, ScalarT, TensorT def stack_with_padding( @@ -34,3 +34,14 @@ def get_scalar_default_value(cls: Type[ScalarT]) -> ScalarT: if issubclass(cls, complex): return cast(ScalarT, 0.0 + 0.0j) raise TypeError(f"Unsupported type: {cls}") + + +def debatched(array: DataArray) -> List[DataArray]: + if isinstance(array, (Sequence, 
numpy.ndarray)): + return list(array) + if isinstance(array, Mapping): + keys = set(array) + debatched_values = {key: debatched(value) for key, value in array.items()} + batch_size = len(debatched_values[next(iter(keys))]) + return [{key: debatched_values[key][index] for key in keys} for index in range(batch_size)] + raise TypeError(f"Unsupported type: {type(array)}") From 5a6ee582559db439860219d44e8ac09d7cf22082 Mon Sep 17 00:00:00 2001 From: altescy Date: Wed, 19 Feb 2025 18:03:19 +0900 Subject: [PATCH 5/6] add reconstruct method --- collatable/extras/__init__.py | 2 +- collatable/extras/datamodule.py | 64 +++++++++++++++++++++++++-------- collatable/extras/indexer.py | 18 ++++++++-- 3 files changed, 67 insertions(+), 17 deletions(-) diff --git a/collatable/extras/__init__.py b/collatable/extras/__init__.py index 5a04d2a..4750f3f 100644 --- a/collatable/extras/__init__.py +++ b/collatable/extras/__init__.py @@ -1,4 +1,4 @@ -from collatable.extras.dataloader import DataLoader # noqa: F401 +from collatable.extras.dataloader import DataLoader, DefaultBatchSampler # noqa: F401 from collatable.extras.datamodule import ( # noqa: F401 DataModule, FieldConfig, diff --git a/collatable/extras/datamodule.py b/collatable/extras/datamodule.py index 60f3b27..fd09273 100644 --- a/collatable/extras/datamodule.py +++ b/collatable/extras/datamodule.py @@ -13,36 +13,47 @@ Sequence, TypeVar, Union, + cast, runtime_checkable, ) from collatable import Field, LabelField, TextField -from collatable.typing import DataArray, Scalar, Tensor +from collatable.typing import DataArray, IntTensor, Scalar, Tensor S = TypeVar("S") T = TypeVar("T") U = TypeVar("U") HashableT = TypeVar("HashableT", bound=Hashable) +HashableT_co = TypeVar("HashableT_co", bound=Hashable, covariant=True) HashableT_contra = TypeVar("HashableT_contra", bound=Hashable, contravariant=True) +IndexT = TypeVar("IndexT") IndexT_co = TypeVar("IndexT_co", bound=Union[Scalar, DataArray], covariant=True) @runtime_checkable -class 
IIndexer(Protocol[HashableT_contra, IndexT_co]): +class IIndexer(Protocol[HashableT, IndexT]): def __len__(self) -> int: ... - def __getitem__(self, value: HashableT_contra, /) -> int: ... + def __getitem__(self, value: HashableT, /) -> int: ... - def __call__(self, value: HashableT_contra, /) -> IndexT_co: ... + def __call__(self, value: HashableT, /) -> IndexT: ... + + def encode(self, value: HashableT, /) -> IndexT: ... + + def decode(self, index: IndexT, /) -> HashableT: ... @runtime_checkable -class ISequenceIndexer(Protocol[HashableT_contra, IndexT_co]): +class ISequenceIndexer(Protocol[HashableT, IndexT]): def __len__(self) -> int: ... - def __getitem__(self, value: HashableT_contra, /) -> int: ... + def __getitem__(self, value: HashableT, /) -> int: ... + + def __call__(self, value: Sequence[HashableT], /) -> IndexT: ... - def __call__(self, values: Sequence[HashableT_contra], /) -> IndexT_co: ... + def encode(self, value: Sequence[HashableT], /) -> IndexT: ... + + def decode(self, index: IndexT, /) -> Sequence[HashableT]: ... 
class FieldAccessor: @@ -59,6 +70,9 @@ class FieldTransform(Generic[S]): def __call__(self, obj: S) -> Field: raise NotImplementedError + def reconstruct(self, array: DataArray) -> S: + raise NotImplementedError + def build(self, dataset: Iterable[S]) -> None: pass @@ -67,7 +81,7 @@ def indexers(self) -> Mapping[str, IIndexer]: class TextFieldTransform(Generic[HashableT], FieldTransform[Union[str, Sequence[HashableT]]]): - _DEFAULT_TOKENIZER_PATTERN = re.compile(r"(?u)\b\w\w+\b") + _DEFAULT_TOKENIZER_PATTERN = re.compile(r"[^\s.,!?:;/]+(?:[-']\[^\s.,!?:;/]+)*|[.,!?:;/]") def __init__( self, @@ -95,9 +109,18 @@ def __call__(self, obj: Union[str, Sequence[HashableT]]) -> TextField: obj = self._tokenizer(obj) return TextField( obj, + indexer=self._indexer.encode, + padding_value=self._indexer[self._pad_token] if self._pad_token is not None else 0, + ) + + def reconstruct(self, array: DataArray) -> Sequence[HashableT]: + assert isinstance(array, Mapping) + field = TextField[HashableT].from_array( + array, indexer=self._indexer, padding_value=self._indexer[self._pad_token] if self._pad_token is not None else 0, ) + return field.tokens def build(self, dataset: Iterable[Union[str, Sequence[HashableT]]]) -> None: for special_token in self._special_tokens: @@ -105,7 +128,7 @@ def build(self, dataset: Iterable[Union[str, Sequence[HashableT]]]) -> None: for text in dataset: if isinstance(text, str): text = self._tokenizer(text) - self._indexer(text) + self._indexer.encode(text) def indexers(self) -> Mapping[str, IIndexer]: return {"tokens": self._indexer} @@ -118,14 +141,19 @@ def __init__( ) -> None: from .indexer import LabelIndexer - self._indexer: IIndexer[HashableT, int] = indexer or LabelIndexer[HashableT]() + self._indexer: IIndexer[HashableT, int] = indexer if indexer is not None else LabelIndexer[HashableT]() def __call__(self, obj: HashableT) -> LabelField: - return LabelField(obj, indexer=self._indexer) + return LabelField(obj, indexer=self._indexer.encode) + + def 
reconstruct(self, array: DataArray) -> HashableT: + array = cast(IntTensor, array) + field = LabelField[HashableT].from_array(array, indexer=self._indexer) + return field.label def build(self, dataset: Iterable[HashableT]) -> None: for label in dataset: - self._indexer[label] + self._indexer(label) def indexers(self) -> Mapping[str, IIndexer[HashableT, int]]: return {"labels": self._indexer} @@ -136,6 +164,10 @@ class FieldConfig(Generic[S, T]): accessor: Callable[[S], T] transform: FieldTransform[T] + @property + def reconstruct(self) -> Callable[[DataArray], T]: + return self.transform.reconstruct + class DataModule(Generic[T]): def __init__( @@ -151,6 +183,10 @@ def __init__( for name, transform in fields.items() } + @property + def fields(self) -> Mapping[str, FieldConfig[T, Any]]: + return self._fields + def build(self, dataset: Iterable[T]) -> None: for field in self._fields.values(): field.transform.build(field.accessor(obj) for obj in dataset) @@ -159,5 +195,5 @@ def __call__(self, dataset: Iterable[T]) -> Iterable[Dict[str, Field]]: for obj in dataset: yield {name: field.transform(field.accessor(obj)) for name, field in self._fields.items()} - def indexer(self, field: str, name: str) -> IIndexer: - return self._fields[field].transform.indexers()[name] + def reconstruct(self, array: Mapping[str, DataArray]) -> Mapping[str, Any]: + return {name: field.reconstruct(array[name]) for name, field in self._fields.items()} diff --git a/collatable/extras/indexer.py b/collatable/extras/indexer.py index 16809f9..ef8ba92 100644 --- a/collatable/extras/indexer.py +++ b/collatable/extras/indexer.py @@ -165,7 +165,7 @@ def from_documents( class TokenIndexer(Generic[ValueT], Indexer[ValueT]): - def __call__(self, tokens: Sequence[ValueT]) -> Dict[str, Tensor]: + def encode(self, tokens: Sequence[ValueT]) -> Mapping[str, Tensor]: token_ids = [self.get_index_by_value(value) for value in tokens] if self._bos_value is not None: token_ids = 
[self._value_to_index[self._bos_value]] + token_ids @@ -173,7 +173,21 @@ def __call__(self, tokens: Sequence[ValueT]) -> Dict[str, Tensor]: token_ids = token_ids + [self._value_to_index[self._eos_value]] return {"token_ids": numpy.array(token_ids, dtype=numpy.int64), "mask": numpy.ones(len(token_ids), dtype=bool)} + def decode(self, index: Mapping[str, Tensor]) -> Sequence[ValueT]: + token_ids = index["token_ids"] + mask = index["mask"] + return [self.get_value_by_index(token_id) for token_id, m in zip(token_ids, mask) if m] + + def __call__(self, tokens: Sequence[ValueT]) -> Mapping[str, Tensor]: + return self.encode(tokens) + class LabelIndexer(Generic[ValueT], Indexer[ValueT]): - def __call__(self, label: ValueT) -> int: + def encode(self, label: ValueT) -> int: return self.get_index_by_value(label) + + def decode(self, index: int) -> ValueT: + return self.get_value_by_index(index) + + def __call__(self, label: ValueT) -> int: + return self.encode(label) From 6b2c680d8374ac3f0003e9f4323c0104ebf4fa54 Mon Sep 17 00:00:00 2001 From: altescy Date: Wed, 19 Feb 2025 18:03:48 +0900 Subject: [PATCH 6/6] update readme --- README.md | 73 ++++++++++++++++++++++++++++++++----------------------- 1 file changed, 42 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 7cdbb7c..ab976b1 100644 --- a/README.md +++ b/README.md @@ -183,72 +183,83 @@ Execution result: ``` -### DataModule +### Reference Implementation + +The `extras` module provides a reference implementation to use `collatable` effectively.
+Here is an example of text-to-text task that encodes raw texts/labels into token +ids and decodes them back to raw texts/labels: ```python from dataclasses import dataclass -from typing import Sequence, Union +from typing import Mapping, Sequence, Union from collatable import LabelField, TextField -from collatable.extras import DataLoader, LabelIndexer, TokenIndexer +from collatable.extras import DataLoader, Dataset, DefaultBatchSampler, LabelIndexer, TokenIndexer from collatable.extras.datamodule import DataModule, LabelFieldTransform, TextFieldTransform +from collatable.util import debatched @dataclass class Text2TextExample: source: Union[str, Sequence[str]] target: Union[str, Sequence[str]] + language: str text2text_dataset = [ - Text2TextExample(source="how are you?", target="I am fine."), - Text2TextExample(source="what is your name?", target="My name is John."), - Text2TextExample(source="where are you?", target="I am in New York."), - Text2TextExample(source="what is the time?", target="It is 10:00 AM."), + Text2TextExample(source="how are you?", target="I am fine.", language="en"), + Text2TextExample(source="what is your name?", target="My name is John.", language="en"), + Text2TextExample(source="where are you?", target="I am in New-York.", language="en"), + Text2TextExample(source="what is the time?", target="It is 10:00 AM.", language="en"), + Text2TextExample(source="comment ça va?", target="Je vais bien.", language="fr"), ] -shared_token_indexer = TokenIndexer( - default="", - specials=["", ""], -) +shared_token_indexer = TokenIndexer(default="", specials=["", ""]) +language_indexer = LabelIndexer[str]() text2text_datamodule = DataModule[Text2TextExample]( fields={ "source": TextFieldTransform(indexer=shared_token_indexer, pad_token=""), "target": TextFieldTransform(indexer=shared_token_indexer, pad_token=""), + "language": LabelFieldTransform(indexer=language_indexer), } ) -with shared_token_indexer.context(train=True): +with 
shared_token_indexer.context(train=True), language_indexer.context(train=True): text2text_datamodule.build(text2text_dataset) -text2text_instances = list(text2text_datamodule(text2text_dataset)) -dataloader = DataLoader(batch_size=2) +dataloader = DataLoader(DefaultBatchSampler(batch_size=2)) + +text2text_instances = Dataset.from_iterable(text2text_datamodule(text2text_dataset)) + for batch in dataloader(text2text_instances): + print("Batch:") print(batch) + print("Reconstruction:") + for item in debatched(batch): + print(text2text_datamodule.reconstruct(item)) + print() ``` Execution result: -``` -{'target': { - 'token_ids': array([[12, 13, 0, 0], - [14, 8, 6, 15]]), - 'mask': array([[ True, True, False, False], - [ True, True, True, True]])}, - 'source': { - 'token_ids': array([[2, 3, 4, 0], - [5, 6, 7, 8]]), - 'mask': array([[ True, True, True, False], - [ True, True, True, True]])}} +```text +Batch: {'target': { - 'token_ids': array([[12, 16, 17, 18, 0], - [19, 6, 20, 21, 22]]), + 'token_ids': array([[16, 17, 18, 19, 0], + [20, 9, 7, 21, 19]]), 'mask': array([[ True, True, True, True, False], [ True, True, True, True, True]])}, - 'source': {'token_ids': array([[ 9, 3, 4, 0], - [ 5, 6, 10, 11]]), - 'mask': array([[ True, True, True, False], - [ True, True, True, True]])}} + 'language': array([0, 0], dtype=int32), + 'source': { + 'token_ids': array([[2, 3, 4, 5, 0], + [6, 7, 8, 9, 5]]), + 'mask': array([[ True, True, True, True, False], + [ True, True, True, True, True]])}} +Reconstruction: +{'source': ['how', 'are', 'you', '?'], 'target': ['I', 'am', 'fine', '.'], 'language': 'en'} +{'source': ['what', 'is', 'your', 'name', '?'], 'target': ['My', 'name', 'is', 'John', '.'], 'language': 'en'} + +... ```