diff --git a/src/synthid_text/hashing_function.py b/src/synthid_text/hashing_function.py index 62493fa..523b8ec 100644 --- a/src/synthid_text/hashing_function.py +++ b/src/synthid_text/hashing_function.py @@ -26,7 +26,8 @@ def accumulate_hash( ) -> torch.LongTensor: """Accumulate hash of data on current hash. - Method uses adapted linear congruential generator with newlib/musl parameters. + Method uses adapted linear congruential generator (LCG) with newlib/musl + parameters. This function has following property - f(x, data[T]) = f(f(x, data[:T - 1]), data[T]) diff --git a/src/synthid_text/logits_processing.py b/src/synthid_text/logits_processing.py index 256a758..3866c9d 100644 --- a/src/synthid_text/logits_processing.py +++ b/src/synthid_text/logits_processing.py @@ -16,12 +16,11 @@ """Logit processor for supporting watermarking in HF model.""" from collections.abc import Sequence - +import hashlib +from synthid_text import hashing_function import torch import transformers -from synthid_text import hashing_function - def update_scores( scores: torch.FloatTensor, @@ -78,11 +77,13 @@ def update_scores_distortionary( for i in range(depth): g_values_at_depth = g_values[:, :, i] g_mass_at_depth = (g_values_at_depth * probs).sum(axis=1, keepdims=True) - coeff_not_in_g = (1 - g_mass_at_depth)**(num_leaves - 1) - coeff_in_g = (1 - (1 - g_mass_at_depth)**(num_leaves)) / g_mass_at_depth + coeff_not_in_g = (1 - g_mass_at_depth) ** (num_leaves - 1) + coeff_in_g = (1 - (1 - g_mass_at_depth) ** (num_leaves)) / g_mass_at_depth coeffs = torch.where( torch.logical_and(g_values_at_depth == 1, probs > 0), - coeff_in_g, coeff_not_in_g) + coeff_in_g, + coeff_not_in_g, + ) probs = probs * coeffs log_probs = torch.log(probs) @@ -129,9 +130,6 @@ class SynthIDLogitsProcessor(transformers.LogitsProcessor): Logits processor updates the provided scores based on the binary g values assigned to each possible ngram and watermarking key combination hashed into an int64 keys. 
- - A random sampling table is pre-computed and modulo table size is applied to - map from ngram keys (int64) to g values. """ def __init__( @@ -139,23 +137,19 @@ def __init__( *, ngram_len: int, keys: Sequence[int], - sampling_table_size: int, - sampling_table_seed: int, context_history_size: int, temperature: float, top_k: int, device: torch.device, skip_first_ngram_calls: bool = False, apply_top_k: bool = True, - num_leaves: int = 2 + num_leaves: int = 2, ): """Initializes the logits processor. Args: ngram_len: Ngram length. keys: A sequence of watermarking keys, one for each depth. - sampling_table_size: Size of the sampling table. - sampling_table_seed: Random seed to generate the sampling table. context_history_size: Size of the tensor to keep track of seen contexts. temperature: Temperature to use for scaling the scores. top_k: Top k to use for sampling the scores. @@ -167,21 +161,16 @@ def __init__( self.ngram_len = ngram_len self.keys = torch.tensor(keys, device=device) - generator = torch.Generator(device=device).manual_seed(sampling_table_seed) - # A random sampling table is pre-computed and modulo table size is applied - # to map from a hash of ngram keys to g values, this is similar to the - # hashtable implementation used in - # https://github.com/facebookresearch/three_bricks. We note that the - # hashing employed in this repository is different from that used to - # watermark the Gemini App, and hence the detectors trained based on the - # hashing in this repository will not transfer to text generated by - # the Gemini App. - self.sampling_table = torch.randint( - low=0, - high=2, - size=(sampling_table_size,), - generator=generator, - device=device, + # Hash the keys to a string to be used as initialization vector (IV) + # for the hash function. Very important to have an unpredictable IV. + self.hash_iv = hashlib.sha256( + self.keys.to(torch.long).numpy().tobytes() + ).digest() + + # Assuming that the platform supports int64. 
+ torch_long_max = torch.iinfo(torch.int64).max + self.hash_iv = ( + int.from_bytes(self.hash_iv, byteorder="big") % torch_long_max ) self.context_history_size = context_history_size self.device = device @@ -302,8 +291,8 @@ def watermarked_call( ) # ngram_keys shape [batch_size, top_k, depth] - # 3. Sample g values. - g_values = self.sample_g_values(ngram_keys) + # 3. Sample g values from the iterated hash of the ngram keys. + g_values = self.get_gvals(ngram_keys) # g_values shape [batch_size, top_k, depth] # 4. Modify scores. @@ -336,6 +325,36 @@ ) return updated_watermarked_scores, top_k_indices, scores_top_k + def get_gvals( + self, + ngram_keys: torch.LongTensor, + num_apply_hash: int = 12, + shift: int = 0, + ) -> torch.LongTensor: + """Samples g values from the computed ngram keys. + + To derive the gvals we repeatedly apply the hash function to + the ngram keys, shifting after each round, and take one bit of the result. + + Args: + ngram_keys: Random keys (batch_size, num_ngrams, depth). + num_apply_hash: Number of times to apply the hash function. + shift: Number of bits to shift the hash result. + + Returns: + G values (batch_size, num_ngrams, depth). + """ + + shift = shift or (64 // num_apply_hash) + + for _ in range(num_apply_hash): + ngram_keys = ( + hashing_function.accumulate_hash(ngram_keys, torch.LongTensor([1])) + >> shift + ) + + return (ngram_keys >> 30) % 2 + def compute_ngram_keys( self, ngrams: torch.LongTensor, @@ -360,7 +379,10 @@ ) batch_size, _, _ = ngrams.shape - hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long) + # Initialize hash result with the same hash_iv for all batch entries. 
+ hash_result = torch.full( + (batch_size,), self.hash_iv, dtype=torch.long, device=self.device + ) # hash_result shape [batch_size,] # ngrams shape [batch_size, num_ngrams, ngram_len] hash_result = torch.vmap( @@ -394,7 +416,10 @@ def _compute_keys( """ batch_size, _ = n_minus_1_grams.shape - hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long) + # Initialize hash result with the same hash_iv for all batch entries. + hash_result = torch.full( + (batch_size,), self.hash_iv, dtype=torch.long, device=self.device + ) # First hash n_minus_1 gram, for each batch entry we have a single # n_minus_1 gram context. # hash_result shape [batch_size] @@ -422,24 +447,6 @@ def _compute_keys( # hash_result shape should be [batch_size, num_indices, depth] return hash_result, hash_result_with_just_context - def sample_g_values(self, ngram_keys: torch.LongTensor) -> torch.LongTensor: - """Samples g values from Bernoulli distribution. - - It is not possible to pass random keys in a vectorized way in torch. Instead - we pre-compute a random sampling table, and use apply modulo table size to - map from ngram keys (int64) to g values. - - Args: - ngram_keys: Random keys (batch_size, num_ngrams, depth). - - Returns: - G values (batch_size, num_ngrams, depth). 
- """ - (sampling_table_size,) = self.sampling_table.shape - sampling_table = self.sampling_table.reshape((1, 1, sampling_table_size)) - ngram_keys = ngram_keys % sampling_table_size - return torch.take_along_dim(sampling_table, indices=ngram_keys, dim=2) - def _check_input_ids_shape(self, input_ids: torch.LongTensor): """Checks the shape of input ids.""" if len(input_ids.shape) != 2: @@ -463,7 +470,7 @@ def compute_g_values( self._check_input_ids_shape(input_ids) ngrams = input_ids.unfold(dimension=1, size=self.ngram_len, step=1) ngram_keys = self.compute_ngram_keys(ngrams) - return self.sample_g_values(ngram_keys) + return self.get_gvals(ngram_keys) def compute_context_repetition_mask( self, @@ -497,7 +504,10 @@ def compute_context_repetition_mask( are_repeated_contexts = [] for i in range(num_contexts): context = contexts[:, i, :] - hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long) + # Initialize hash result with the same hash_iv for all batch entries. + hash_result = torch.full( + (batch_size,), self.hash_iv, dtype=torch.long, device=self.device + ) context_hash = hashing_function.accumulate_hash(hash_result, context)[ :, None ] diff --git a/src/synthid_text/logits_processing_test.py b/src/synthid_text/logits_processing_test.py index 0176291..51de7b5 100644 --- a/src/synthid_text/logits_processing_test.py +++ b/src/synthid_text/logits_processing_test.py @@ -19,12 +19,11 @@ from absl.testing import parameterized import immutabledict import numpy as np -import torch -import tqdm - -from synthid_text import logits_processing from synthid_text import g_value_expectations +from synthid_text import logits_processing from synthid_text import torch_testing +import torch +import tqdm def does_mean_g_value_matches_theoretical( @@ -68,8 +67,6 @@ def does_mean_g_value_matches_theoretical( logits_processor = logits_processing.SynthIDLogitsProcessor( ngram_len=ngram_len, keys=keys, - sampling_table_size=2**16, - sampling_table_seed=0, 
context_history_size=context_history_size, device=device, top_k=vocab_size, @@ -147,8 +144,6 @@ def test_g_value_uniformity_for_random_ngrams( watermarking_config = immutabledict.immutabledict({ 'ngram_len': ngram_len, 'keys': np.random.randint(low=0, high=2**16, size=(num_layers,)), - 'sampling_table_size': 2**16, - 'sampling_table_seed': 0, 'context_history_size': 512, 'device': device, }) @@ -187,8 +182,6 @@ def test_g_values_uniformity_across_vocab_size(self, vocab_size, num_layers): watermarking_config = immutabledict.immutabledict({ 'ngram_len': ngram_len, 'keys': np.random.randint(low=0, high=2**16, size=(num_layers,)), - 'sampling_table_size': 2**16, - 'sampling_table_seed': 0, 'context_history_size': 512, 'device': device, }) @@ -209,7 +202,7 @@ def test_g_values_uniformity_across_vocab_size(self, vocab_size, num_layers): ), ) - g_values = logits_processor.sample_g_values(ngram_keys) + g_values = logits_processor.get_gvals(ngram_keys) # g_values shape should be [batch_size, vocab_size, num_layers] g_values_mean = torch.mean(torch.mean(g_values.float(), dim=1)) self.assertAlmostEqual(g_values_mean, 0.5, delta=0.001) @@ -227,8 +220,6 @@ def test_distributional_convergence(self): watermarking_config = immutabledict.immutabledict({ 'ngram_len': 5, 'keys': np.random.randint(0, 10**9, size=(1,), dtype=np.int64), - 'sampling_table_size': 2**16, - 'sampling_table_seed': 0, 'context_history_size': 1024, 'device': device, }) @@ -302,19 +293,26 @@ def test_distributional_convergence(self): ), ) def test_bias_from_logits_processor( - self, vocab_size, ngram_len, num_layers, atol, num_leaves: int = 2, + self, + vocab_size, + ngram_len, + num_layers, + atol, + num_leaves: int = 2, ): """Check if watermarked distribution converges to input distribution.""" device = torch_testing.torch_device() mean, expected, passes = does_mean_g_value_matches_theoretical( vocab_size=vocab_size, ngram_len=ngram_len, - batch_size=20_000, - keys=[np.random.randint(0, 10**9) for _ in 
range(num_layers)], + batch_size=50_000, + keys=[1], atol=atol, device=device, num_leaves=num_leaves, ) + print('Mean', mean) + print('Expected', expected) self.assertTrue(passes) @@ -334,8 +332,6 @@ def set_up_logits_processor( watermarking_config = immutabledict.immutabledict({ 'ngram_len': ngram_len, 'keys': np.random.randint(low=0, high=2**16, size=(num_layers,)), - 'sampling_table_size': 2**16, - 'sampling_table_seed': 0, 'context_history_size': 512, 'device': device, }) diff --git a/src/synthid_text/synthid_mixin.py b/src/synthid_text/synthid_mixin.py index a78404d..b4017dc 100644 --- a/src/synthid_text/synthid_mixin.py +++ b/src/synthid_text/synthid_mixin.py @@ -19,11 +19,10 @@ from typing import Any, Optional, Union import immutabledict +from synthid_text import logits_processing import torch import transformers -from synthid_text import logits_processing - DEFAULT_WATERMARKING_CONFIG = immutabledict.immutabledict({ "ngram_len": 5, # This corresponds to H=4 context window size in the paper. @@ -59,8 +58,6 @@ 90, 960, ], - "sampling_table_size": 2**16, - "sampling_table_seed": 0, "context_history_size": 1024, "device": ( torch.device("cuda:0") @@ -212,10 +209,10 @@ def _sample( ) if has_eos_stopping_criteria and pad_token_id is None: raise ValueError( - "`stopping_criteria` is not empty, `pad_token_id` must be set in " - "`generation_config`. See " - "https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig" - "for more on how to configure the `pad_token_id`." + "`stopping_criteria` is not empty, `pad_token_id` must be set in" + " `generation_config`. See" + " https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfigfor" + " more on how to configure the `pad_token_id`." ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None