Skip to content
Merged
41 changes: 30 additions & 11 deletions subwiz/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from collections import defaultdict
from typing import Callable

from importlib.metadata import version

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
import torch
Expand All @@ -32,9 +34,11 @@


MODEL_REPO = "HadrianSecurity/subwiz"
MODEL_FILE = "model.pt"
TOKENIZER_FILE = "tokenizer.json"
MODEL_FILE = "model_v2.pt"
TOKENIZER_FILE = "tokenizer_v2.json"
CONFIG_FILE = "config.json"
# Change revision when realeasing new weights
REVISION = "9a2c505d0312ad6938b27d9b4338020fe37883e8"


def get_model_and_tokenizer(
Expand All @@ -55,13 +59,22 @@ def get_model_and_tokenizer(
disable_progress_bars()

model_path = hf_hub_download(
repo_id=MODEL_REPO, filename=MODEL_FILE, force_download=force_download
repo_id=MODEL_REPO,
filename=MODEL_FILE,
force_download=force_download,
revision=REVISION,
)
tokenizer_path = hf_hub_download(
repo_id=MODEL_REPO, filename=TOKENIZER_FILE, force_download=force_download
repo_id=MODEL_REPO,
filename=TOKENIZER_FILE,
force_download=force_download,
revision=REVISION,
)
hf_hub_download(
repo_id=MODEL_REPO, filename=CONFIG_FILE, force_download=force_download
repo_id=MODEL_REPO,
filename=CONFIG_FILE,
force_download=force_download,
revision=REVISION,
)
if quiet:
enable_progress_bars()
Expand Down Expand Up @@ -102,13 +115,19 @@ def run_inference(
Set of predicted domain objects
"""

apex = next(iter(input_domains)).apex_domain
subs = [dom.subdomain for dom in input_domains]
tokenizer_input = ",".join(sorted(subs)) + "[DELIM]"
# TODO: pick a different subset, if some were out of context last iteration
apex_domain = next(iter(input_domains)).apex_domain
subdomains_tokenizer_input = ",".join(sorted(subs)) + "[DELIM]"
apex_tokenizer_input = apex_domain + "[DELIM]"

subs_x = tokenizer.encode(subdomains_tokenizer_input)
apex_x = tokenizer.encode(apex_tokenizer_input)

# Trim subs to account for the apex part, grab last part
subs_x = subs_x[-(gpt_model.config.block_size - len(apex_x)) :]

x = tokenizer.encode(tokenizer_input)
x = [1] * (gpt_model.config.block_size - len(x)) + x
x = apex_x + subs_x
x = [gpt_model.pad_token] * (gpt_model.config.block_size - len(x)) + x
x = torch.tensor(x)

blocked_outputs = {dom.subdomain for dom in blocked_domains}
Expand All @@ -128,7 +147,7 @@ def run_inference(
for pred in predictions
}

predictions: set[str] = {sub + "." + apex for sub in predictions}
predictions: set[str] = {sub + "." + apex_domain for sub in predictions}

predicted_domains: set[Domain] = set()
for pred in predictions:
Expand Down
35 changes: 34 additions & 1 deletion subwiz/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def __init__(self, config: GPTConfig):
self.end_token = self.tokenizer("[END]")["input_ids"][0]
self.comma_token = self.tokenizer(",")["input_ids"][0]
self.delim_token = self.tokenizer("[DELIM]")["input_ids"][0]
self.pad_token = self.tokenizer("[PAD]")["input_ids"][0]

self.transformer = nn.ModuleDict(
dict(
Expand Down Expand Up @@ -382,6 +383,27 @@ def device(self) -> str:
# assign model inputs to the right device
return next(self.lm_head.parameters()).device.type

def _trim_subdomains(
self,
sequences: torch.Tensor,
apex_unpadded_position: int,
num_tokens_generated: int,
) -> torch.Tensor:
if num_tokens_generated == 0:
return sequences

trimming_position = (
apex_unpadded_position
if num_tokens_generated > apex_unpadded_position
else num_tokens_generated
)

sequences = torch.cat(
(sequences[:, :trimming_position], sequences[:, trimming_position + 1 :]),
dim=1,
)
return sequences

@torch.no_grad()
def generate(
self,
Expand Down Expand Up @@ -426,8 +448,15 @@ def generate(
run_uuid = uuid.uuid4()

idx = idx.to(self.device)
num_initial_pad_tokens = (idx == self.pad_token).sum().item()
sequences = idx.unsqueeze(0)

apex_padded_position = (sequences == self.delim_token).nonzero(as_tuple=True)[
1
][0]

apex_unpadded_position = 1 + apex_padded_position - num_initial_pad_tokens

probabilities = torch.tensor([1.0], device=self.device)

finished_sequences = torch.tensor([], device=self.device)
Expand All @@ -441,7 +470,11 @@ def generate(
on_iteration()

# trim the sequences down to block size
sequences = sequences[:, -self.config.block_size :]
sequences = self._trim_subdomains(
sequences,
apex_unpadded_position=apex_unpadded_position,
num_tokens_generated=i,
)

# remove any invalid subdomain starts
outputs = self.tokenizer.batch_decode(sequences[:, -i:])
Expand Down
8 changes: 7 additions & 1 deletion tests/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@ def test_languages():
no_resolve=True,
)
print(results)
assert "english.hadrian.io" in results
assert {
"english.hadrian.io",
"french.hadrian.io",
"spanish.hadrian.io",
"portuguese.hadrian.io",
"dutch.hadrian.io",
} & set(results)


def test_numbers():
Expand Down