71 changes: 71 additions & 0 deletions README.md
@@ -181,3 +181,74 @@ Execution result:
[-1, -1, -1],
[-1, -1, -1]]], dtype=int32)}
```


### DataModule

```python
from dataclasses import dataclass
from typing import Sequence, Union

from collatable.extras import DataLoader, TokenIndexer
from collatable.extras.datamodule import DataModule, TextFieldTransform


@dataclass
class Text2TextExample:
source: Union[str, Sequence[str]]
target: Union[str, Sequence[str]]


text2text_dataset = [
Text2TextExample(source="how are you?", target="I am fine."),
Text2TextExample(source="what is your name?", target="My name is John."),
Text2TextExample(source="where are you?", target="I am in New York."),
Text2TextExample(source="what is the time?", target="It is 10:00 AM."),
]

shared_token_indexer = TokenIndexer(
default="<unk>",
specials=["<pad>", "<unk>"],
)

text2text_datamodule = DataModule[Text2TextExample](
fields={
"source": TextFieldTransform(indexer=shared_token_indexer, pad_token="<pad>"),
"target": TextFieldTransform(indexer=shared_token_indexer, pad_token="<pad>"),
}
)

with shared_token_indexer.context(train=True):
text2text_datamodule.build(text2text_dataset)

text2text_instances = list(text2text_datamodule(text2text_dataset))

dataloader = DataLoader(batch_size=2)
for batch in dataloader(text2text_instances):
print(batch)
```

Execution result:

```
{'target': {
'token_ids': array([[12, 13, 0, 0],
[14, 8, 6, 15]]),
'mask': array([[ True, True, False, False],
[ True, True, True, True]])},
'source': {
'token_ids': array([[2, 3, 4, 0],
[5, 6, 7, 8]]),
'mask': array([[ True, True, True, False],
[ True, True, True, True]])}}
{'target': {
'token_ids': array([[12, 16, 17, 18, 0],
[19, 6, 20, 21, 22]]),
'mask': array([[ True, True, True, True, False],
[ True, True, True, True, True]])},
'source': {'token_ids': array([[ 9, 3, 4, 0],
[ 5, 6, 10, 11]]),
'mask': array([[ True, True, True, False],
[ True, True, True, True]])}}
```
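
The example above only exercises `TextFieldTransform`. `LabelFieldTransform` indexes label values the same way, and a `FieldConfig` lets a field read through an explicit accessor instead of the attribute matching its name. The following is a minimal classification sketch, not library code: the `ClassificationExample` dataclass and its data are illustrative, and it assumes `LabelIndexer` supports the same `context(train=True)` protocol as `TokenIndexer`.

```python
from dataclasses import dataclass

from collatable.extras import DataLoader, LabelIndexer, TokenIndexer
from collatable.extras.datamodule import (
    DataModule,
    FieldConfig,
    LabelFieldTransform,
    TextFieldTransform,
)


@dataclass
class ClassificationExample:
    text: str
    label: str


classification_dataset = [
    ClassificationExample(text="this movie is great", label="positive"),
    ClassificationExample(text="this movie is boring", label="negative"),
]

token_indexer = TokenIndexer(default="<unk>", specials=["<pad>", "<unk>"])
label_indexer = LabelIndexer()

classification_datamodule = DataModule[ClassificationExample](
    fields={
        # A bare transform is shorthand for FieldConfig with an accessor of the same name.
        "text": TextFieldTransform(indexer=token_indexer, pad_token="<pad>"),
        # FieldConfig decouples the batch key from where the value is read.
        "label": FieldConfig(
            accessor=lambda example: example.label,
            transform=LabelFieldTransform(indexer=label_indexer),
        ),
    }
)

with token_indexer.context(train=True), label_indexer.context(train=True):
    classification_datamodule.build(classification_dataset)

classification_instances = list(classification_datamodule(classification_dataset))

dataloader = DataLoader(batch_size=2)
for batch in dataloader(classification_instances):
    print(batch)
```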
7 changes: 7 additions & 0 deletions collatable/extras/__init__.py
@@ -1,3 +1,10 @@
from collatable.extras.dataloader import DataLoader # noqa: F401
from collatable.extras.datamodule import ( # noqa: F401
DataModule,
FieldConfig,
FieldTransform,
LabelFieldTransform,
TextFieldTransform,
)
from collatable.extras.dataset import Dataset # noqa: F401
from collatable.extras.indexer import Indexer, LabelIndexer, TokenIndexer # noqa: F401
163 changes: 163 additions & 0 deletions collatable/extras/datamodule.py
@@ -0,0 +1,163 @@
import re
from dataclasses import dataclass
from typing import (
Any,
Callable,
Dict,
Generic,
Hashable,
Iterable,
Mapping,
Optional,
Protocol,
Sequence,
TypeVar,
Union,
runtime_checkable,
)

from collatable import Field, LabelField, TextField
from collatable.typing import DataArray, Scalar, Tensor

S = TypeVar("S")
T = TypeVar("T")
HashableT = TypeVar("HashableT", bound=Hashable)
HashableT_contra = TypeVar("HashableT_contra", bound=Hashable, contravariant=True)
IndexT_co = TypeVar("IndexT_co", bound=Union[Scalar, DataArray], covariant=True)


@runtime_checkable
class IIndexer(Protocol[HashableT_contra, IndexT_co]):
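    """Encode a single hashable value; __getitem__ returns its index, __call__ its encoded form."""
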
def __len__(self) -> int: ...

def __getitem__(self, value: HashableT_contra, /) -> int: ...

def __call__(self, value: HashableT_contra, /) -> IndexT_co: ...


@runtime_checkable
class ISequenceIndexer(Protocol[HashableT_contra, IndexT_co]):
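    """Encode a whole sequence of hashable values at once."""
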
def __len__(self) -> int: ...

def __getitem__(self, value: HashableT_contra, /) -> int: ...

def __call__(self, values: Sequence[HashableT_contra], /) -> IndexT_co: ...


class FieldAccessor:
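    """Resolve a dotted path such as "meta.label" against mapping keys or object attributes."""
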
def __init__(self, field: str) -> None:
self._field = field.split(".")

def __call__(self, obj: Any) -> Any:
for part in self._field:
obj = obj[part] if isinstance(obj, Mapping) else getattr(obj, part)
return obj


class FieldTransform(Generic[S]):
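    """Convert one raw value into a Field; build() may first scan a dataset to grow vocabularies."""
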
def __call__(self, obj: S) -> Field:
raise NotImplementedError

def build(self, dataset: Iterable[S]) -> None:
pass

def indexers(self) -> Mapping[str, IIndexer]:
        return {attribute: value for attribute, value in self.__dict__.items() if isinstance(value, IIndexer)}


class TextFieldTransform(Generic[HashableT], FieldTransform[Union[str, Sequence[HashableT]]]):
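    """Tokenize and index text into a TextField, padding with the configured pad token."""
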
    _DEFAULT_TOKENIZER_PATTERN = re.compile(r"(?u)\b\w\w+\b")  # matches word tokens of two or more characters

def __init__(
self,
tokenizer: Optional[Callable[[str], Sequence[HashableT]]] = None,
pad_token: Optional[HashableT] = None,
unk_token: Optional[HashableT] = None,
special_tokens: Optional[Sequence[HashableT]] = None,
indexer: Optional[ISequenceIndexer[HashableT, Mapping[str, Tensor]]] = None,
) -> None:
from .indexer import TokenIndexer

self._tokenizer = tokenizer or (lambda text: self._DEFAULT_TOKENIZER_PATTERN.findall(text))
self._indexer: ISequenceIndexer[HashableT, Mapping[str, Tensor]] = (
indexer if indexer is not None else TokenIndexer[HashableT]()
)
self._pad_token = pad_token
self._special_tokens = special_tokens or []
if unk_token is not None and unk_token not in self._special_tokens:
self._special_tokens = [unk_token, *self._special_tokens]
if pad_token is not None and pad_token not in self._special_tokens:
self._special_tokens = [pad_token, *self._special_tokens]

def __call__(self, obj: Union[str, Sequence[HashableT]]) -> TextField:
if isinstance(obj, str):
obj = self._tokenizer(obj)
return TextField(
obj,
indexer=self._indexer,
padding_value=self._indexer[self._pad_token] if self._pad_token is not None else 0,
)

def build(self, dataset: Iterable[Union[str, Sequence[HashableT]]]) -> None:
        for special_token in self._special_tokens:
            self._indexer[special_token]  # indexing registers the special token in the vocabulary
for text in dataset:
if isinstance(text, str):
text = self._tokenizer(text)
self._indexer(text)

def indexers(self) -> Mapping[str, IIndexer]:
return {"tokens": self._indexer}


class LabelFieldTransform(FieldTransform[HashableT]):
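    """Index a hashable label value into a LabelField."""
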
def __init__(
self,
indexer: Optional[IIndexer[HashableT, int]] = None,
) -> None:
from .indexer import LabelIndexer

self._indexer: IIndexer[HashableT, int] = indexer or LabelIndexer[HashableT]()

def __call__(self, obj: HashableT) -> LabelField:
return LabelField(obj, indexer=self._indexer)

def build(self, dataset: Iterable[HashableT]) -> None:
        for label in dataset:
            self._indexer[label]  # indexing registers the label in the vocabulary

def indexers(self) -> Mapping[str, IIndexer[HashableT, int]]:
return {"labels": self._indexer}


@dataclass
class FieldConfig(Generic[S, T]):
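    """Pair the accessor that extracts a raw value with the transform that encodes it."""
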
accessor: Callable[[S], T]
transform: FieldTransform[T]


class DataModule(Generic[T]):
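    """Map field names to transforms, build their vocabularies, and turn examples into instances."""
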
def __init__(
self,
fields: Mapping[str, Union[FieldTransform, FieldConfig[T, Any]]],
) -> None:
self._fields: Mapping[str, FieldConfig[T, Any]] = {
name: (
FieldConfig(accessor=FieldAccessor(name), transform=transform)
if isinstance(transform, FieldTransform)
else transform
)
for name, transform in fields.items()
}

    def build(self, dataset: Iterable[T]) -> None:
        # The dataset is iterated once per field, so pass a re-iterable sequence
        # rather than a one-shot generator.
        for field in self._fields.values():
            field.transform.build(field.accessor(obj) for obj in dataset)

def __call__(self, dataset: Iterable[T]) -> Iterable[Dict[str, Field]]:
for obj in dataset:
yield {name: field.transform(field.accessor(obj)) for name, field in self._fields.items()}

def indexer(self, field: str, name: str) -> IIndexer:
return self._fields[field].transform.indexers()[name]
17 changes: 9 additions & 8 deletions collatable/fields/label_field.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
from typing import Callable, Mapping, Optional, TypeVar, Union
from typing import Callable, Generic, Hashable, Mapping, Optional, TypeVar

import numpy

from collatable.fields.field import Field
from collatable.typing import IntTensor

Self = TypeVar("Self", bound="LabelField")
LabelT = TypeVar("LabelT", bound=Hashable)


class LabelField(Field[IntTensor]):
class LabelField(Generic[LabelT], Field[IntTensor]):
__slots__ = ["_label", "_label_index"]

def __init__(
self,
label: Union[int, str],
label: LabelT,
*,
vocab: Optional[Mapping[str, int]] = None,
indexer: Optional[Callable[[str], int]] = None,
vocab: Optional[Mapping[LabelT, int]] = None,
indexer: Optional[Callable[[LabelT], int]] = None,
) -> None:
if isinstance(label, str) and vocab is None is indexer:
raise ValueError("LabelField with string labels requires vocab or indexer")
@@ -42,15 +43,15 @@ def __repr__(self) -> str:
return f"LabelField(label={self._label})"

@property
def label(self) -> Union[int, str]:
def label(self) -> LabelT:
return self._label

def as_array(self) -> IntTensor:
return numpy.array(self._label_index, dtype=numpy.int32)

@staticmethod
def _make_indexer(vocab: Mapping[str, int]) -> Callable[[str], int]:
def indexer(label: str) -> int:
def _make_indexer(vocab: Mapping[LabelT, int]) -> Callable[[LabelT], int]:
def indexer(label: LabelT) -> int:
return vocab[label]

return indexer