Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/synthid_text/hashing_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def accumulate_hash(
) -> torch.LongTensor:
"""Accumulate hash of data on current hash.

Method uses adapted linear congruential generator with newlib/musl parameters.
Method uses adapted linear congruential generator (LCG) with newlib/musl
parameters.

This function has following property -
f(x, data[T]) = f(f(x, data[:T - 1]), data[T])
Expand Down
116 changes: 63 additions & 53 deletions src/synthid_text/logits_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
"""Logit processor for supporting watermarking in HF model."""

from collections.abc import Sequence

import hashlib
from synthid_text import hashing_function
import torch
import transformers

from synthid_text import hashing_function


def update_scores(
scores: torch.FloatTensor,
Expand Down Expand Up @@ -78,11 +77,13 @@ def update_scores_distortionary(
for i in range(depth):
g_values_at_depth = g_values[:, :, i]
g_mass_at_depth = (g_values_at_depth * probs).sum(axis=1, keepdims=True)
coeff_not_in_g = (1 - g_mass_at_depth)**(num_leaves - 1)
coeff_in_g = (1 - (1 - g_mass_at_depth)**(num_leaves)) / g_mass_at_depth
coeff_not_in_g = (1 - g_mass_at_depth) ** (num_leaves - 1)
coeff_in_g = (1 - (1 - g_mass_at_depth) ** (num_leaves)) / g_mass_at_depth
coeffs = torch.where(
torch.logical_and(g_values_at_depth == 1, probs > 0),
coeff_in_g, coeff_not_in_g)
coeff_in_g,
coeff_not_in_g,
)
probs = probs * coeffs

log_probs = torch.log(probs)
Expand Down Expand Up @@ -129,33 +130,26 @@ class SynthIDLogitsProcessor(transformers.LogitsProcessor):
Logits processor updates the provided scores based on the binary g values
assigned to each possible ngram and watermarking key combination hashed into
an int64 keys.

A random sampling table is pre-computed and modulo table size is applied to
map from ngram keys (int64) to g values.
"""

def __init__(
self,
*,
ngram_len: int,
keys: Sequence[int],
sampling_table_size: int,
sampling_table_seed: int,
context_history_size: int,
temperature: float,
top_k: int,
device: torch.device,
skip_first_ngram_calls: bool = False,
apply_top_k: bool = True,
num_leaves: int = 2
num_leaves: int = 2,
):
"""Initializes the logits processor.

Args:
ngram_len: Ngram length.
keys: A sequence of watermarking keys, one for each depth.
sampling_table_size: Size of the sampling table.
sampling_table_seed: Random seed to generate the sampling table.
context_history_size: Size of the tensor to keep track of seen contexts.
temperature: Temperature to use for scaling the scores.
top_k: Top k to use for sampling the scores.
Expand All @@ -167,21 +161,16 @@ def __init__(
self.ngram_len = ngram_len
self.keys = torch.tensor(keys, device=device)

generator = torch.Generator(device=device).manual_seed(sampling_table_seed)
# A random sampling table is pre-computed and modulo table size is applied
# to map from a hash of ngram keys to g values, this is similar to the
# hashtable implementation used in
# https://github.com/facebookresearch/three_bricks. We note that the
# hashing employed in this repository is different from that used to
# watermark the Gemini App, and hence the detectors trained based on the
# hashing in this repository will not transfer to text generated by
# the Gemini App.
self.sampling_table = torch.randint(
low=0,
high=2,
size=(sampling_table_size,),
generator=generator,
device=device,
# Hash the keys into a byte digest to be used as initialization vector (IV)
# for the hash function. Very important to have an unpredictable IV.
self.hash_iv = hashlib.sha256(
self.keys.to(torch.long).numpy().tobytes()
).digest()

# Assuming that the platform supports int64.
torch_long_max = torch.iinfo(torch.int64).max
self.hash_iv = (
int.from_bytes(self.hash_iv, byteorder="big") % torch_long_max
)
self.context_history_size = context_history_size
self.device = device
Expand Down Expand Up @@ -302,8 +291,8 @@ def watermarked_call(
)
# ngram_keys shape [batch_size, top_k, depth]

# 3. Sample g values.
g_values = self.sample_g_values(ngram_keys)
# 3. Sample g values by taking the lowest bit of the hash.
g_values = self.get_gvals(ngram_keys)
# g_values shape [batch_size, top_k, depth]

# 4. Modify scores.
Expand Down Expand Up @@ -336,6 +325,36 @@ def watermarked_call(
)
return updated_watermarked_scores, top_k_indices, scores_top_k

def get_gvals(
    self,
    ngram_keys: torch.LongTensor,
    num_apply_hash: int = 12,
    shift: int = 0,
) -> torch.LongTensor:
  """Derives binary g values from the computed ngram keys.

  The keys are re-hashed `num_apply_hash` times with
  `hashing_function.accumulate_hash` (using a constant salt of 1) and
  right-shifted by `shift` bits after every round; the g value is bit 30 of
  the final accumulated key.

  Args:
    ngram_keys: Random keys (batch_size, num_ngrams, depth).
    num_apply_hash: Number of times to apply the hash function.
    shift: Number of bits to shift after each hash round. The default of 0
      selects 64 // num_apply_hash. NOTE(review): an explicit shift of 0 is
      indistinguishable from the default because of the `or` below.

  Returns:
    G values in {0, 1} with shape (batch_size, num_ngrams, depth).
  """
  # `shift or ...`: treat 0 as "use the default derived from the hash count".
  shift = shift or (64 // num_apply_hash)

  # Constant salt, hoisted out of the loop and created on the same device as
  # the keys: the original `torch.LongTensor([1])` is always a CPU tensor and
  # would raise a device-mismatch error for CUDA `ngram_keys`.
  salt = torch.tensor([1], dtype=torch.long, device=ngram_keys.device)

  for _ in range(num_apply_hash):
    ngram_keys = hashing_function.accumulate_hash(ngram_keys, salt) >> shift

  # Bit 30 of the final accumulated key is the g value.
  return (ngram_keys >> 30) % 2

def compute_ngram_keys(
self,
ngrams: torch.LongTensor,
Expand All @@ -360,7 +379,10 @@ def compute_ngram_keys(
)
batch_size, _, _ = ngrams.shape

hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
# Initialize hash result with the same hash_iv for all batch entries.
hash_result = torch.full(
(batch_size,), self.hash_iv, dtype=torch.long, device=self.device
)
# hash_result shape [batch_size,]
# ngrams shape [batch_size, num_ngrams, ngram_len]
hash_result = torch.vmap(
Expand Down Expand Up @@ -394,7 +416,10 @@ def _compute_keys(
"""
batch_size, _ = n_minus_1_grams.shape

hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
# Initialize hash result with the same hash_iv for all batch entries.
hash_result = torch.full(
(batch_size,), self.hash_iv, dtype=torch.long, device=self.device
)
# First hash n_minus_1 gram, for each batch entry we have a single
# n_minus_1 gram context.
# hash_result shape [batch_size]
Expand Down Expand Up @@ -422,24 +447,6 @@ def _compute_keys(
# hash_result shape should be [batch_size, num_indices, depth]
return hash_result, hash_result_with_just_context

def sample_g_values(self, ngram_keys: torch.LongTensor) -> torch.LongTensor:
  """Samples g values from Bernoulli distribution.

  Vectorized per-key random sampling is not possible in torch, so a random
  sampling table is precomputed and each ngram key is mapped into it by
  taking the key modulo the table size.

  Args:
    ngram_keys: Random keys (batch_size, num_ngrams, depth).

  Returns:
    G values (batch_size, num_ngrams, depth).
  """
  table_size = self.sampling_table.shape[0]
  # Broadcastable view so a single table serves every (batch, ngram) slot.
  lookup_table = self.sampling_table.view(1, 1, table_size)
  indices = ngram_keys % table_size
  return torch.take_along_dim(lookup_table, indices=indices, dim=2)

def _check_input_ids_shape(self, input_ids: torch.LongTensor):
"""Checks the shape of input ids."""
if len(input_ids.shape) != 2:
Expand All @@ -463,7 +470,7 @@ def compute_g_values(
self._check_input_ids_shape(input_ids)
ngrams = input_ids.unfold(dimension=1, size=self.ngram_len, step=1)
ngram_keys = self.compute_ngram_keys(ngrams)
return self.sample_g_values(ngram_keys)
return self.get_gvals(ngram_keys)

def compute_context_repetition_mask(
self,
Expand Down Expand Up @@ -497,7 +504,10 @@ def compute_context_repetition_mask(
are_repeated_contexts = []
for i in range(num_contexts):
context = contexts[:, i, :]
hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
# Initialize hash result with the same hash_iv for all batch entries.
hash_result = torch.full(
(batch_size,), self.hash_iv, dtype=torch.long, device=self.device
)
context_hash = hashing_function.accumulate_hash(hash_result, context)[
:, None
]
Expand Down
32 changes: 14 additions & 18 deletions src/synthid_text/logits_processing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@
from absl.testing import parameterized
import immutabledict
import numpy as np
import torch
import tqdm

from synthid_text import logits_processing
from synthid_text import g_value_expectations
from synthid_text import logits_processing
from synthid_text import torch_testing
import torch
import tqdm


def does_mean_g_value_matches_theoretical(
Expand Down Expand Up @@ -68,8 +67,6 @@ def does_mean_g_value_matches_theoretical(
logits_processor = logits_processing.SynthIDLogitsProcessor(
ngram_len=ngram_len,
keys=keys,
sampling_table_size=2**16,
sampling_table_seed=0,
context_history_size=context_history_size,
device=device,
top_k=vocab_size,
Expand Down Expand Up @@ -147,8 +144,6 @@ def test_g_value_uniformity_for_random_ngrams(
watermarking_config = immutabledict.immutabledict({
'ngram_len': ngram_len,
'keys': np.random.randint(low=0, high=2**16, size=(num_layers,)),
'sampling_table_size': 2**16,
'sampling_table_seed': 0,
'context_history_size': 512,
'device': device,
})
Expand Down Expand Up @@ -187,8 +182,6 @@ def test_g_values_uniformity_across_vocab_size(self, vocab_size, num_layers):
watermarking_config = immutabledict.immutabledict({
'ngram_len': ngram_len,
'keys': np.random.randint(low=0, high=2**16, size=(num_layers,)),
'sampling_table_size': 2**16,
'sampling_table_seed': 0,
'context_history_size': 512,
'device': device,
})
Expand All @@ -209,7 +202,7 @@ def test_g_values_uniformity_across_vocab_size(self, vocab_size, num_layers):
),
)

g_values = logits_processor.sample_g_values(ngram_keys)
g_values = logits_processor.get_gvals(ngram_keys)
# g_values shape should be [batch_size, vocab_size, num_layers]
g_values_mean = torch.mean(torch.mean(g_values.float(), dim=1))
self.assertAlmostEqual(g_values_mean, 0.5, delta=0.001)
Expand All @@ -227,8 +220,6 @@ def test_distributional_convergence(self):
watermarking_config = immutabledict.immutabledict({
'ngram_len': 5,
'keys': np.random.randint(0, 10**9, size=(1,), dtype=np.int64),
'sampling_table_size': 2**16,
'sampling_table_seed': 0,
'context_history_size': 1024,
'device': device,
})
Expand Down Expand Up @@ -302,19 +293,26 @@ def test_distributional_convergence(self):
),
)
def test_bias_from_logits_processor(
self, vocab_size, ngram_len, num_layers, atol, num_leaves: int = 2,
self,
vocab_size,
ngram_len,
num_layers,
atol,
num_leaves: int = 2,
):
"""Check if watermarked distribution converges to input distribution."""
device = torch_testing.torch_device()
mean, expected, passes = does_mean_g_value_matches_theoretical(
vocab_size=vocab_size,
ngram_len=ngram_len,
batch_size=20_000,
keys=[np.random.randint(0, 10**9) for _ in range(num_layers)],
batch_size=50_000,
keys=[1],
atol=atol,
device=device,
num_leaves=num_leaves,
)
print('Mean', mean)
print('Expected', expected)
self.assertTrue(passes)


Expand All @@ -334,8 +332,6 @@ def set_up_logits_processor(
watermarking_config = immutabledict.immutabledict({
'ngram_len': ngram_len,
'keys': np.random.randint(low=0, high=2**16, size=(num_layers,)),
'sampling_table_size': 2**16,
'sampling_table_seed': 0,
'context_history_size': 512,
'device': device,
})
Expand Down
13 changes: 5 additions & 8 deletions src/synthid_text/synthid_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@
from typing import Any, Optional, Union

import immutabledict
from synthid_text import logits_processing
import torch
import transformers

from synthid_text import logits_processing


DEFAULT_WATERMARKING_CONFIG = immutabledict.immutabledict({
"ngram_len": 5, # This corresponds to H=4 context window size in the paper.
Expand Down Expand Up @@ -59,8 +58,6 @@
90,
960,
],
"sampling_table_size": 2**16,
"sampling_table_seed": 0,
"context_history_size": 1024,
"device": (
torch.device("cuda:0")
Expand Down Expand Up @@ -212,10 +209,10 @@ def _sample(
)
if has_eos_stopping_criteria and pad_token_id is None:
raise ValueError(
"`stopping_criteria` is not empty, `pad_token_id` must be set in "
"`generation_config`. See "
"https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig"
"for more on how to configure the `pad_token_id`."
"`stopping_criteria` is not empty, `pad_token_id` must be set in"
" `generation_config`. See"
" https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig"
" for more on how to configure the `pad_token_id`."
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
Expand Down