diff --git a/subwiz/main.py b/subwiz/main.py index 4e0a1a3..e200d59 100644 --- a/subwiz/main.py +++ b/subwiz/main.py @@ -12,6 +12,8 @@ from collections import defaultdict from typing import Callable +from importlib.metadata import version + from huggingface_hub import hf_hub_download from huggingface_hub.utils import disable_progress_bars, enable_progress_bars import torch @@ -32,9 +34,11 @@ MODEL_REPO = "HadrianSecurity/subwiz" -MODEL_FILE = "model.pt" -TOKENIZER_FILE = "tokenizer.json" +MODEL_FILE = "model_v2.pt" +TOKENIZER_FILE = "tokenizer_v2.json" CONFIG_FILE = "config.json" +# Change revision when releasing new weights +REVISION = "9a2c505d0312ad6938b27d9b4338020fe37883e8" def get_model_and_tokenizer( @@ -55,13 +59,22 @@ def get_model_and_tokenizer( disable_progress_bars() model_path = hf_hub_download( - repo_id=MODEL_REPO, filename=MODEL_FILE, force_download=force_download + repo_id=MODEL_REPO, + filename=MODEL_FILE, + force_download=force_download, + revision=REVISION, ) tokenizer_path = hf_hub_download( - repo_id=MODEL_REPO, filename=TOKENIZER_FILE, force_download=force_download + repo_id=MODEL_REPO, + filename=TOKENIZER_FILE, + force_download=force_download, + revision=REVISION, ) hf_hub_download( - repo_id=MODEL_REPO, filename=CONFIG_FILE, force_download=force_download + repo_id=MODEL_REPO, + filename=CONFIG_FILE, + force_download=force_download, + revision=REVISION, ) if quiet: enable_progress_bars() @@ -102,13 +115,19 @@ def run_inference( Set of predicted domain objects """ - apex = next(iter(input_domains)).apex_domain subs = [dom.subdomain for dom in input_domains] - tokenizer_input = ",".join(sorted(subs)) + "[DELIM]" - # TODO: pick a different subset, if some were out of context last iteration + apex_domain = next(iter(input_domains)).apex_domain + subdomains_tokenizer_input = ",".join(sorted(subs)) + "[DELIM]" + apex_tokenizer_input = apex_domain + "[DELIM]" + + subs_x = tokenizer.encode(subdomains_tokenizer_input) + apex_x = 
tokenizer.encode(apex_tokenizer_input) + + # Trim subs to account for the apex part, grab last part + subs_x = subs_x[-(gpt_model.config.block_size - len(apex_x)) :] - x = tokenizer.encode(tokenizer_input) - x = [1] * (gpt_model.config.block_size - len(x)) + x + x = apex_x + subs_x + x = [gpt_model.pad_token] * (gpt_model.config.block_size - len(x)) + x x = torch.tensor(x) blocked_outputs = {dom.subdomain for dom in blocked_domains} @@ -128,7 +147,7 @@ def run_inference( for pred in predictions } - predictions: set[str] = {sub + "." + apex for sub in predictions} + predictions: set[str] = {sub + "." + apex_domain for sub in predictions} predicted_domains: set[Domain] = set() for pred in predictions: diff --git a/subwiz/model.py b/subwiz/model.py index f230d45..ddcade2 100644 --- a/subwiz/model.py +++ b/subwiz/model.py @@ -248,6 +248,7 @@ def __init__(self, config: GPTConfig): self.end_token = self.tokenizer("[END]")["input_ids"][0] self.comma_token = self.tokenizer(",")["input_ids"][0] self.delim_token = self.tokenizer("[DELIM]")["input_ids"][0] + self.pad_token = self.tokenizer("[PAD]")["input_ids"][0] self.transformer = nn.ModuleDict( dict( @@ -382,6 +383,27 @@ def device(self) -> str: # assign model inputs to the right device return next(self.lm_head.parameters()).device.type + def _trim_subdomains( + self, + sequences: torch.Tensor, + apex_unpadded_position: int, + num_tokens_generated: int, + ) -> torch.Tensor: + if num_tokens_generated == 0: + return sequences + + trimming_position = ( + apex_unpadded_position + if num_tokens_generated > apex_unpadded_position + else num_tokens_generated + ) + + sequences = torch.cat( + (sequences[:, :trimming_position], sequences[:, trimming_position + 1 :]), + dim=1, + ) + return sequences + @torch.no_grad() def generate( self, @@ -426,8 +448,15 @@ def generate( run_uuid = uuid.uuid4() idx = idx.to(self.device) + num_initial_pad_tokens = (idx == self.pad_token).sum().item() sequences = idx.unsqueeze(0) + 
apex_padded_position = (sequences == self.delim_token).nonzero(as_tuple=True)[ + 1 + ][0] + + apex_unpadded_position = 1 + apex_padded_position - num_initial_pad_tokens + probabilities = torch.tensor([1.0], device=self.device) finished_sequences = torch.tensor([], device=self.device) @@ -441,7 +470,11 @@ def generate( on_iteration() # trim the sequences down to block size - sequences = sequences[:, -self.config.block_size :] + sequences = self._trim_subdomains( + sequences, + apex_unpadded_position=apex_unpadded_position, + num_tokens_generated=i, + ) # remove any invalid subdomain starts outputs = self.tokenizer.batch_decode(sequences[:, -i:]) diff --git a/tests/test_results.py b/tests/test_results.py index 901699a..9cda550 100644 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -19,7 +19,13 @@ def test_languages(): no_resolve=True, ) print(results) - assert "english.hadrian.io" in results + assert { + "english.hadrian.io", + "french.hadrian.io", + "spanish.hadrian.io", + "portuguese.hadrian.io", + "dutch.hadrian.io", + } & set(results) def test_numbers():