From 6c1940ca5e664c3a894c747a4ae4c800434d011a Mon Sep 17 00:00:00 2001 From: Dat-Boi-Arjun <56085610+Dat-Boi-Arjun@users.noreply.github.com> Date: Sat, 27 Mar 2021 17:36:31 -0700 Subject: [PATCH 1/4] Specified encoding for reading txt files --- syfertext/data/readers/language_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syfertext/data/readers/language_modeling.py b/syfertext/data/readers/language_modeling.py index 01eee596..fef2f9eb 100644 --- a/syfertext/data/readers/language_modeling.py +++ b/syfertext/data/readers/language_modeling.py @@ -21,7 +21,7 @@ def read(self, dataset_meta): data_path = pathlib.Path(data_path) # Open the text file to read and encode its text - with data_path.open() as f: + with data_path.open(encoding='utf-8') as f: # Read all lines for line in f.readlines(): From 1315516f9e31444c1ee269da6b7580a42bd77f62 Mon Sep 17 00:00:00 2001 From: Dat-Boi-Arjun <56085610+Dat-Boi-Arjun@users.noreply.github.com> Date: Sat, 27 Mar 2021 17:39:23 -0700 Subject: [PATCH 2/4] Added support for BERT MLM --- syfertext/data/iterators/bert_loader.py | 83 +++++++++++++++++++++++++ syfertext/encoders/bert_encoder.py | 16 +++++ 2 files changed, 99 insertions(+) create mode 100644 syfertext/data/iterators/bert_loader.py create mode 100644 syfertext/encoders/bert_encoder.py diff --git a/syfertext/data/iterators/bert_loader.py b/syfertext/data/iterators/bert_loader.py new file mode 100644 index 00000000..9a75eff3 --- /dev/null +++ b/syfertext/data/iterators/bert_loader.py @@ -0,0 +1,83 @@ +from typing import Dict, List +from torch import LongTensor +from transformers import DataCollatorForLanguageModeling + + +class BERTIterator: + + def __init__(self, dataset_reader, batch_size: int, sentence_len: int): + self.dataset_reader = dataset_reader + self.batch_size = batch_size + self.sentence_len = sentence_len + + self.data_collator = DataCollatorForLanguageModeling( + tokenizer=self.dataset_reader.encoder.tokenizer_ref, + mlm = True, + mlm_probability = 0.15) + + def load(self, dataset_meta) -> LongTensor: + self.dataset_reader.read(dataset_meta) + + #In case user wants to display the data + return self.dataset_reader.encoded_text + + def __iter__(self): + + self.index = 0 + + return self + + def __next__(self): + + if self.index + self.batch_size > self.num_examples: + raise StopIteration + + batch_examples = [] + + for i in range(self.batch_size): + example = self._load_example() + batch_examples.append(example) + + batch = self._collate(batch_examples=batch_examples) + + return batch + + @property + def num_examples(self): + """Returns that number of non-overlapping examples + in the dataset + """ + + num_examples = (len(self.dataset_reader.encoded_text) - 1) // self.sentence_len + + return num_examples + + @property + def num_batches(self): + """Returns the total number of batches. The last batch + is dropped if its size is less than self.batch_size. + """ + + num_batches = self.num_examples // self.batch_size + + return num_batches + + def _load_example(self) -> LongTensor: + + # LongTensor containing the dataset + dataset = self.dataset_reader.encoded_text + + #Getting an example - sequence of length 'sentence_len' + example = dataset.narrow( + dim=0, start=self.index * self.sentence_len, length=self.sentence_len + ) + + self.index += 1 + + return example + + def _collate(self, batch_examples: List) -> Dict: + + return self.data_collator(batch_examples) + + diff --git a/syfertext/encoders/bert_encoder.py b/syfertext/encoders/bert_encoder.py new file mode 100644 index 00000000..1f52774d --- /dev/null +++ b/syfertext/encoders/bert_encoder.py @@ -0,0 +1,16 @@ +from typing import Dict, List +from transformers import BertTokenizer + +class BERTEncoder: + + def __init__(self): + self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + + def __call__(self, text:List) -> Dict: + inputs = self.tokenizer(text) + return {"token_ids": inputs["input_ids"]} + + @property + def tokenizer_ref(self): + #decorator method so tokenizer can't be modified + return self.tokenizer \ No newline at end of file From 77af30e0f62779e4a54dcb68f3b9a9bd0da87237 Mon Sep 17 00:00:00 2001 From: Dat-Boi-Arjun <56085610+Dat-Boi-Arjun@users.noreply.github.com> Date: Thu, 22 Apr 2021 23:28:24 -0700 Subject: [PATCH 3/4] Made changes according to PR --- examples/local/bert_iterator/bert_mlm.ipynb | 249 ++++++++++++++++++++ syfertext/data/iterators/bert_loader.py | 25 +- 2 files changed, 260 insertions(+), 14 deletions(-) create mode 100644 examples/local/bert_iterator/bert_mlm.ipynb diff --git a/examples/local/bert_iterator/bert_mlm.ipynb b/examples/local/bert_iterator/bert_mlm.ipynb new file mode 100644 index 00000000..14b56e21 --- /dev/null +++ b/examples/local/bert_iterator/bert_mlm.ipynb @@ -0,0 +1,249 @@ +{ + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.6-final" + }, + "orig_nbformat": 2, + "kernelspec": { + "name": "python3", + "display_name": "Python 3.8.6 64-bit", + "metadata": { + "interpreter": { + "hash": "5202b1b321302d3e244bf56e867ff8fe1ef9c7446c57e95c118c3d2a6f0522ba" + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 2, + "cells": [ + { + "source": [ + "## This notebook trains a local version of the BERT MLM Model" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import transformers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from syfertext.data.metas.language_modeling import TextDatasetMeta\n", + "from syfertext.data.readers.language_modeling import TextReader\n", + "from syfertext.data.iterators.bert_loader import BERTIterator\n", + "from syfertext.encoders.bert_encoder import BERTEncoder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if torch.cuda.is_available():\n", + " torch.device(\"cuda\")\n", + "\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + " \n", + "print(torch.cuda.get_device_properties(device))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "encoder = BERTEncoder()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = transformers.BertForMaskedLM.from_pretrained(\"bert-base-uncased\")\n", + "model.to(device)\n", + "print(\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = transformers.AdamW(model.parameters(), lr=2e-5, eps=1e-8)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "meta = TextDatasetMeta(train_path=\"PATH TO TRAIN DATA\", \n", + " valid_path=\"PATH TO VALIDATION DATA\", \n", + " test_path=\"PATH TO TEST DATA\")\n", + "\n", + "model_save_path = \"./mlm_model.pt\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_reader = TextReader(encoder=encoder, mode='train')\n", + "train_loader = BERTIterator(batch_size=20, sentence_len=35, dataset_reader=train_reader)\n", + "train_loader.load(meta)\n", + "num_epochs = 3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "scheduler = transformers.get_linear_schedule_with_warmup(optimizer, \n", + " num_warmup_steps=0, \n", + " num_training_steps=train_loader.num_examples * num_epochs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "val_reader = TextReader(encoder=encoder, mode='valid')\n", + "val_loader = BERTIterator(batch_size=10, sentence_len=35, dataset_reader=val_reader)\n", + "val_loader.load(meta)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate(loader, model):\n", + " total_loss = 0.\n", + "\n", + " with torch.no_grad():\n", + " for data in loader:\n", + " inputs = data[\"input_ids\"].to(device)\n", + " labels = data[\"labels\"].to(device)\n", + "\n", + " outputs = model(input_ids=inputs, labels=labels)\n", + " total_loss += len(inputs) * outputs.loss.item()\n", + "\n", + " return total_loss / loader.num_examples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "torch.manual_seed(42)\n", + "\n", + "total_batches = train_loader.num_batches\n", + "\n", + "#Change this depending on how often you want training updates\n", + "log_interval = 200\n", + "\n", + "for epoch in range(1, num_epochs + 1):\n", + " model.train()\n", + " print(f\"=========EPOCH {epoch}=========\")\n", + "\n", + " for batch_num, data in enumerate(train_loader):\n", + " inputs = data[\"input_ids\"].to(device)\n", + " labels = data[\"labels\"].to(device)\n", + "\n", + " model.zero_grad()\n", + "\n", + " outputs = model(input_ids=inputs, labels=labels)\n", + " loss = outputs.loss\n", + " loss.backward()\n", + "\n", + " optimizer.step()\n", + " scheduler.step()\n", + "\n", + " if (batch_num % log_interval == 0):\n", + " print(f\"Batch {batch_num}/{total_batches} | Loss: {loss.item()}\")\n", + "\n", + " model.eval()\n", + " val_loss = evaluate(val_loader, model)\n", + " print(\"-------------------\")\n", + " print(f\"Val Loss for Epoch {epoch}: {val_loss}\")\n", + " print(\"-------------------\")\n", + "\n", + "print(f\"Done training! Saving model to {model_save_path}\")\n", + "torch.save(model.state_dict(), model_save_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "pred_model = transformers.BertForMaskedLM.from_pretrained(\"bert-base-uncased\")\n", + "print(\"Base model loaded\")\n", + "pred_model.load_state_dict(torch.load(model_save_path))\n", + "pred_model.eval().to(device)\n", + "print(\"Trained state initialized\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_reader = TextReader(encoder=encoder, mode='test')\n", + "test_loader = BERTIterator(batch_size=10, sentence_len=35, dataset_reader=test_reader)\n", + "test_loader.load(meta)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_loss = evaluate(test_loader, pred_model)\n", + "print(f\"Test Loss: {test_loss}\")" + ] + } + ] +} \ No newline at end of file diff --git a/syfertext/data/iterators/bert_loader.py b/syfertext/data/iterators/bert_loader.py index 9a75eff3..194f8846 100644 --- a/syfertext/data/iterators/bert_loader.py +++ b/syfertext/data/iterators/bert_loader.py @@ -10,17 +10,9 @@ def __init__(self, dataset_reader, batch_size: int, sentence_len: int): self.batch_size = batch_size self.sentence_len = sentence_len - self.data_collator = DataCollatorForLanguageModeling( - tokenizer=self.dataset_reader.encoder.tokenizer_ref, - mlm = True, - mlm_probability = 0.15) - - def load(self, dataset_meta) -> LongTensor: + def load(self, dataset_meta): self.dataset_reader.read(dataset_meta) - #In case user wants to display the data - return self.dataset_reader.encoded_text - def __iter__(self): self.index = 0 @@ -28,9 +20,6 @@ def __iter__(self): return self def __next__(self): - - if self.index + self.batch_size > self.num_examples: - raise StopIteration batch_examples = [] @@ -48,7 +37,7 @@ def num_examples(self): in the dataset """ - num_examples = (len(self.dataset_reader.encoded_text) - 1) // self.sentence_len + num_examples = len(self.dataset_reader.encoded_text) // self.sentence_len return num_examples @@ -62,6 +51,9 @@ def num_batches(self): return num_batches + def __len__(self): + return self.num_batches + def _load_example(self) -> LongTensor: # LongTensor containing the dataset @@ -78,6 +70,11 @@ def _load_example(self) -> LongTensor: def _collate(self, batch_examples: List) -> Dict: - return self.data_collator(batch_examples) + data_collator = DataCollatorForLanguageModeling( + tokenizer=self.dataset_reader.encoder.tokenizer_ref, + mlm = True, + mlm_probability = 0.15) + + return data_collator(batch_examples) From 8df9c4807501b613df79773b391ded22c101e0b0 Mon Sep 17 00:00:00 2001 From: Dat-Boi-Arjun <56085610+Dat-Boi-Arjun@users.noreply.github.com> Date: Sat, 1 May 2021 18:18:14 -0700 Subject: [PATCH 4/4] Removed tokenizer_ref property --- syfertext/data/iterators/bert_loader.py | 2 +- syfertext/encoders/bert_encoder.py | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/syfertext/data/iterators/bert_loader.py b/syfertext/data/iterators/bert_loader.py index 194f8846..dd036bbd 100644 --- a/syfertext/data/iterators/bert_loader.py +++ b/syfertext/data/iterators/bert_loader.py @@ -71,7 +71,7 @@ def _load_example(self) -> LongTensor: def _collate(self, batch_examples: List) -> Dict: data_collator = DataCollatorForLanguageModeling( - tokenizer=self.dataset_reader.encoder.tokenizer_ref, + tokenizer=self.dataset_reader.encoder.tokenizer, mlm = True, mlm_probability = 0.15) diff --git a/syfertext/encoders/bert_encoder.py b/syfertext/encoders/bert_encoder.py index 1f52774d..e918143f 100644 --- a/syfertext/encoders/bert_encoder.py +++ b/syfertext/encoders/bert_encoder.py @@ -8,9 +8,4 @@ def __init__(self): def __call__(self, text:List) -> Dict: inputs = self.tokenizer(text) - return {"token_ids": inputs["input_ids"]} - - @property - def tokenizer_ref(self): - #decorator method so tokenizer can't be modified - return self.tokenizer \ No newline at end of file + return {"token_ids": inputs["input_ids"]} \ No newline at end of file