73 changes: 42 additions & 31 deletions README.md
@@ -183,72 +183,83 @@ Execution result:
```


### DataModule
### Reference Implementation

The `extras` module provides a reference implementation for using `collatable` effectively.
Here is an example of a text-to-text task that encodes raw texts/labels into token
IDs and decodes them back to raw texts/labels:

```python
from dataclasses import dataclass
from typing import Sequence, Union
from typing import Mapping, Sequence, Union

from collatable import LabelField, TextField
from collatable.extras import DataLoader, LabelIndexer, TokenIndexer
from collatable.extras import DataLoader, Dataset, DefaultBatchSampler, LabelIndexer, TokenIndexer
from collatable.extras.datamodule import DataModule, LabelFieldTransform, TextFieldTransform
from collatable.util import debatched


@dataclass
class Text2TextExample:
source: Union[str, Sequence[str]]
target: Union[str, Sequence[str]]
language: str


text2text_dataset = [
Text2TextExample(source="how are you?", target="I am fine."),
Text2TextExample(source="what is your name?", target="My name is John."),
Text2TextExample(source="where are you?", target="I am in New York."),
Text2TextExample(source="what is the time?", target="It is 10:00 AM."),
Text2TextExample(source="how are you?", target="I am fine.", language="en"),
Text2TextExample(source="what is your name?", target="My name is John.", language="en"),
Text2TextExample(source="where are you?", target="I am in New-York.", language="en"),
Text2TextExample(source="what is the time?", target="It is 10:00 AM.", language="en"),
Text2TextExample(source="comment ça va?", target="Je vais bien.", language="fr"),
]

shared_token_indexer = TokenIndexer(
default="<unk>",
specials=["<pad>", "<unk>"],
)
shared_token_indexer = TokenIndexer(default="<unk>", specials=["<pad>", "<unk>"])
language_indexer = LabelIndexer[str]()

text2text_datamodule = DataModule[Text2TextExample](
fields={
"source": TextFieldTransform(indexer=shared_token_indexer, pad_token="<pad>"),
"target": TextFieldTransform(indexer=shared_token_indexer, pad_token="<pad>"),
"language": LabelFieldTransform(indexer=language_indexer),
}
)

with shared_token_indexer.context(train=True):
with shared_token_indexer.context(train=True), language_indexer.context(train=True):
text2text_datamodule.build(text2text_dataset)

text2text_instances = list(text2text_datamodule(text2text_dataset))

dataloader = DataLoader(batch_size=2)
dataloader = DataLoader(DefaultBatchSampler(batch_size=2))

text2text_instances = Dataset.from_iterable(text2text_datamodule(text2text_dataset))

for batch in dataloader(text2text_instances):
print("Batch:")
print(batch)
print("Reconstruction:")
for item in debatched(batch):
print(text2text_datamodule.reconstruct(item))
print()
```

Execution result:

```
{'target': {
'token_ids': array([[12, 13, 0, 0],
[14, 8, 6, 15]]),
'mask': array([[ True, True, False, False],
[ True, True, True, True]])},
'source': {
'token_ids': array([[2, 3, 4, 0],
[5, 6, 7, 8]]),
'mask': array([[ True, True, True, False],
[ True, True, True, True]])}}
```text
Batch:
{'target': {
'token_ids': array([[12, 16, 17, 18, 0],
[19, 6, 20, 21, 22]]),
'token_ids': array([[16, 17, 18, 19, 0],
[20, 9, 7, 21, 19]]),
'mask': array([[ True, True, True, True, False],
[ True, True, True, True, True]])},
'source': {'token_ids': array([[ 9, 3, 4, 0],
[ 5, 6, 10, 11]]),
'mask': array([[ True, True, True, False],
[ True, True, True, True]])}}
'language': array([0, 0], dtype=int32),
'source': {
'token_ids': array([[2, 3, 4, 5, 0],
[6, 7, 8, 9, 5]]),
'mask': array([[ True, True, True, True, False],
[ True, True, True, True, True]])}}
Reconstruction:
{'source': ['how', 'are', 'you', '?'], 'target': ['I', 'am', 'fine', '.'], 'language': 'en'}
{'source': ['what', 'is', 'your', 'name', '?'], 'target': ['My', 'name', 'is', 'John', '.'], 'language': 'en'}

...
```
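With this change, batching policy moves out of `DataLoader` and into the sampler. Below is a minimal sketch of the remaining `DefaultBatchSampler` flags; it assumes it runs after the example above, so `text2text_instances` is already built:

```python
from collatable.extras import DataLoader, DefaultBatchSampler

# Shuffle the instances on each pass and drop the trailing incomplete
# batch; both flags now live on DefaultBatchSampler, not DataLoader.
sampler = DefaultBatchSampler(batch_size=2, shuffle=True, drop_last=True)
dataloader = DataLoader(sampler)

batches = dataloader(text2text_instances)
# The sampler computes the number of batches up front, so len() works:
# 5 instances with batch_size=2 and drop_last=True yields 2 batches.
print(len(batches))

for batch in batches:
    ...  # each batch is a dict of collated arrays, as in the output above
```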
2 changes: 1 addition & 1 deletion collatable/extras/__init__.py
@@ -1,4 +1,4 @@
from collatable.extras.dataloader import DataLoader # noqa: F401
from collatable.extras.dataloader import DataLoader, DefaultBatchSampler # noqa: F401
from collatable.extras.datamodule import ( # noqa: F401
DataModule,
FieldConfig,
91 changes: 56 additions & 35 deletions collatable/extras/dataloader.py
@@ -1,66 +1,87 @@
import math
import random
from typing import Dict, Iterator, Mapping, Optional, Sequence
from typing import Dict, Iterable, Iterator, Mapping, Optional, Protocol, Sequence, TypeVar

from collatable.collator import Collator
from collatable.fields import Field
from collatable.typing import DataArray

T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)

class BatchIterator:

class SizedIterator(Protocol[T_co]):
def __len__(self) -> int: ...

def __next__(self) -> T_co: ...

def __iter__(self) -> Iterator[T_co]: ...


class BatchIterator(SizedIterator[Dict[str, DataArray]]):
def __init__(
self,
dataset: Sequence[Mapping[str, Field]],
batch_size: int = 1,
shuffle: bool = False,
drop_last: bool = False,
indices: Iterable[Sequence[int]],
num_batches: int,
collator: Optional[Collator] = None,
) -> None:
self._dataset = dataset
self._batch_size = batch_size
self._shuffle = shuffle
self._drop_last = drop_last
self._offset = 0
self._indices = iter(indices)
self._num_batches = num_batches
self._collator = collator or Collator()
self._indices = list(range(len(self._dataset)))
if self._shuffle:
random.shuffle(self._indices)

def __len__(self) -> int:
if self._drop_last:
return len(self._dataset) // self._batch_size
return math.ceil(len(self._dataset) / self._batch_size)
return self._num_batches

def __next__(self) -> Dict[str, DataArray]:
if self._offset >= len(self._dataset):
raise StopIteration
if self._offset + self._batch_size > len(self._dataset):
if self._drop_last:
raise StopIteration
batch_indices = self._indices[self._offset :]
else:
batch_indices = self._indices[self._offset : self._offset + self._batch_size]
self._offset += self._batch_size
return self._collator([self._dataset[i] for i in batch_indices])
indices = next(self._indices)
return self._collator([self._dataset[i] for i in indices])

def __iter__(self) -> Iterator[Dict[str, DataArray]]:
return self


class DataLoader:
class IBatchSampler(Protocol):
def __call__(self, dataset: Sequence) -> SizedIterator[Mapping[str, DataArray]]: ...


class DefaultBatchSampler:
def __init__(
self, batch_size: int = 1, shuffle: bool = False, drop_last: bool = False, collator: Optional[Collator] = None
self,
batch_size: int = 1,
shuffle: bool = False,
drop_last: bool = False,
) -> None:
self._batch_size = batch_size
self._shuffle = shuffle
self._drop_last = drop_last
self._collator = collator or Collator()

def __call__(self, dataset: Sequence[Mapping[str, Field]]) -> BatchIterator:
return BatchIterator(
dataset,
batch_size=self._batch_size,
shuffle=self._shuffle,
drop_last=self._drop_last,
collator=self._collator,
def __call__(self, dataset: Sequence) -> BatchIterator:
num_batches = (
len(dataset) // self._batch_size if self._drop_last else math.ceil(len(dataset) / self._batch_size)
)
indices = list(range(len(dataset)))
if self._shuffle:
random.shuffle(indices)

def iter_batches() -> Iterator[Sequence[int]]:
for batch_index in range(num_batches):
start_index = batch_index * self._batch_size
end_index = start_index + self._batch_size
yield indices[start_index:end_index]

return BatchIterator(dataset, iter_batches(), num_batches)


class DataLoader:
def __init__(
self,
sampler: Optional[IBatchSampler] = None,
collator: Optional[Collator] = None,
) -> None:
self._sampler = sampler or DefaultBatchSampler()
self._collator = collator or Collator()

def __call__(self, dataset: Sequence[Mapping[str, Field]]) -> SizedIterator[Mapping[str, DataArray]]:
return self._sampler(dataset)
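Since `DataLoader` now depends only on the `IBatchSampler` protocol, alternative batching strategies can be plugged in by reusing `BatchIterator`. A minimal sketch under that assumption; the `ReversedBatchSampler` name and its back-to-front ordering are illustrative, not part of this PR:

```python
import math
from typing import Iterator, Sequence

from collatable.extras import DataLoader
from collatable.extras.dataloader import BatchIterator


class ReversedBatchSampler:
    """Illustrative sampler that batches the dataset back to front."""

    def __init__(self, batch_size: int = 1) -> None:
        self._batch_size = batch_size

    def __call__(self, dataset: Sequence) -> BatchIterator:
        num_batches = math.ceil(len(dataset) / self._batch_size)
        indices = list(reversed(range(len(dataset))))

        def iter_batches() -> Iterator[Sequence[int]]:
            for batch_index in range(num_batches):
                start = batch_index * self._batch_size
                yield indices[start : start + self._batch_size]

        return BatchIterator(dataset, iter_batches(), num_batches)


# Structural typing: conforming to IBatchSampler requires no inheritance.
dataloader = DataLoader(ReversedBatchSampler(batch_size=2))
```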
64 changes: 50 additions & 14 deletions collatable/extras/datamodule.py
Original file line number Diff line number Diff line change
@@ -13,36 +13,47 @@
Sequence,
TypeVar,
Union,
cast,
runtime_checkable,
)

from collatable import Field, LabelField, TextField
from collatable.typing import DataArray, Scalar, Tensor
from collatable.typing import DataArray, IntTensor, Scalar, Tensor

S = TypeVar("S")
T = TypeVar("T")
U = TypeVar("U")
HashableT = TypeVar("HashableT", bound=Hashable)
HashableT_co = TypeVar("HashableT_co", bound=Hashable, covariant=True)
HashableT_contra = TypeVar("HashableT_contra", bound=Hashable, contravariant=True)
IndexT = TypeVar("IndexT")
IndexT_co = TypeVar("IndexT_co", bound=Union[Scalar, DataArray], covariant=True)


@runtime_checkable
class IIndexer(Protocol[HashableT_contra, IndexT_co]):
class IIndexer(Protocol[HashableT, IndexT]):
def __len__(self) -> int: ...

def __getitem__(self, value: HashableT_contra, /) -> int: ...
def __getitem__(self, value: HashableT, /) -> int: ...

def __call__(self, value: HashableT_contra, /) -> IndexT_co: ...
def __call__(self, value: HashableT, /) -> IndexT: ...

def encode(self, value: HashableT, /) -> IndexT: ...

def decode(self, index: IndexT, /) -> HashableT: ...


@runtime_checkable
class ISequenceIndexer(Protocol[HashableT_contra, IndexT_co]):
class ISequenceIndexer(Protocol[HashableT, IndexT]):
def __len__(self) -> int: ...

def __getitem__(self, value: HashableT_contra, /) -> int: ...
def __getitem__(self, value: HashableT, /) -> int: ...

def __call__(self, value: Sequence[HashableT], /) -> IndexT: ...

def __call__(self, values: Sequence[HashableT_contra], /) -> IndexT_co: ...
def encode(self, value: Sequence[HashableT], /) -> IndexT: ...

def decode(self, index: IndexT, /) -> Sequence[HashableT]: ...


class FieldAccessor:
@@ -59,6 +70,9 @@ class FieldTransform(Generic[S]):
def __call__(self, obj: S) -> Field:
raise NotImplementedError

def reconstruct(self, array: DataArray) -> S:
raise NotImplementedError

def build(self, dataset: Iterable[S]) -> None:
pass

Expand All @@ -67,7 +81,7 @@ def indexers(self) -> Mapping[str, IIndexer]:


class TextFieldTransform(Generic[HashableT], FieldTransform[Union[str, Sequence[HashableT]]]):
_DEFAULT_TOKENIZER_PATTERN = re.compile(r"(?u)\b\w\w+\b")
_DEFAULT_TOKENIZER_PATTERN = re.compile(r"[^\s.,!?:;/]+(?:[-'][^\s.,!?:;/]+)*|[.,!?:;/]")

def __init__(
self,
@@ -95,17 +109,26 @@ def __call__(self, obj: Union[str, Sequence[HashableT]]) -> TextField:
obj = self._tokenizer(obj)
return TextField(
obj,
indexer=self._indexer.encode,
padding_value=self._indexer[self._pad_token] if self._pad_token is not None else 0,
)

def reconstruct(self, array: DataArray) -> Sequence[HashableT]:
assert isinstance(array, Mapping)
field = TextField[HashableT].from_array(
array,
indexer=self._indexer,
padding_value=self._indexer[self._pad_token] if self._pad_token is not None else 0,
)
return field.tokens

def build(self, dataset: Iterable[Union[str, Sequence[HashableT]]]) -> None:
for special_token in self._special_tokens:
self._indexer[special_token]
for text in dataset:
if isinstance(text, str):
text = self._tokenizer(text)
self._indexer(text)
self._indexer.encode(text)

def indexers(self) -> Mapping[str, IIndexer]:
return {"tokens": self._indexer}
@@ -118,14 +141,19 @@ def __init__(
) -> None:
from .indexer import LabelIndexer

self._indexer: IIndexer[HashableT, int] = indexer or LabelIndexer[HashableT]()
self._indexer: IIndexer[HashableT, int] = indexer if indexer is not None else LabelIndexer[HashableT]()

def __call__(self, obj: HashableT) -> LabelField:
return LabelField(obj, indexer=self._indexer)
return LabelField(obj, indexer=self._indexer.encode)

def reconstruct(self, array: DataArray) -> HashableT:
array = cast(IntTensor, array)
field = LabelField[HashableT].from_array(array, indexer=self._indexer)
return field.label

def build(self, dataset: Iterable[HashableT]) -> None:
for label in dataset:
self._indexer[label]
self._indexer(label)

def indexers(self) -> Mapping[str, IIndexer[HashableT, int]]:
return {"labels": self._indexer}
@@ -136,6 +164,10 @@ class FieldConfig(Generic[S, T]):
accessor: Callable[[S], T]
transform: FieldTransform[T]

@property
def reconstruct(self) -> Callable[[DataArray], T]:
return self.transform.reconstruct


class DataModule(Generic[T]):
def __init__(
@@ -151,6 +183,10 @@ def __init__(
for name, transform in fields.items()
}

@property
def fields(self) -> Mapping[str, FieldConfig[T, Any]]:
return self._fields

def build(self, dataset: Iterable[T]) -> None:
for field in self._fields.values():
field.transform.build(field.accessor(obj) for obj in dataset)
@@ -159,5 +195,5 @@ def __call__(self, dataset: Iterable[T]) -> Iterable[Dict[str, Field]]:
for obj in dataset:
yield {name: field.transform(field.accessor(obj)) for name, field in self._fields.items()}

def indexer(self, field: str, name: str) -> IIndexer:
return self._fields[field].transform.indexers()[name]
def reconstruct(self, array: Mapping[str, DataArray]) -> Mapping[str, Any]:
return {name: field.reconstruct(array[name]) for name, field in self._fields.items()}
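Because the indexer protocols now require `encode` and `decode`, any object providing them can back a transform's round trip. Here is a minimal sketch of a hand-rolled label indexer; the `FrozenLabelIndexer` name and its fixed vocabulary are illustrative assumptions, not part of this PR:

```python
from typing import Dict, List


class FrozenLabelIndexer:
    """Illustrative IIndexer[str, int] over a fixed label set."""

    def __init__(self, labels: List[str]) -> None:
        self._labels = labels
        self._index: Dict[str, int] = {label: i for i, label in enumerate(labels)}

    def __len__(self) -> int:
        return len(self._labels)

    def __getitem__(self, value: str) -> int:
        return self._index[value]

    def __call__(self, value: str) -> int:
        return self.encode(value)

    def encode(self, value: str) -> int:
        return self._index[value]

    def decode(self, index: int) -> str:
        return self._labels[index]
```

Passing `FrozenLabelIndexer(["en", "fr"])` as the `indexer` of a `LabelFieldTransform` would then exercise `encode` during field construction and `decode` during `reconstruct`.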