-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
130 lines (105 loc) · 4.32 KB
/
utils.py
File metadata and controls
130 lines (105 loc) · 4.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
# Fix the global RNG seed at import time so any random sampling done by this
# module (or by callers relying on the global `random` state) is reproducible.
random.seed(0)
################################
##### Activation functions #####
################################
def forward_with_cache(model, inputs, module, no_grad=True):
    """Run a forward pass of `model` and capture the output activation of `module`.

    A temporary forward hook is registered on `module`; `inputs` is fed to
    `model` as keyword arguments, and the first activation captured by the
    hook is returned. Modules that return tuples (e.g. transformer layers)
    are unwrapped to their first element.

    Args:
        model: the model to run; called as ``model(**inputs)``.
        inputs: dict of keyword arguments for the model's forward pass.
        module: submodule of `model` whose output activation is wanted.
        no_grad: if True (default), run the forward pass under ``torch.no_grad()``.

    Returns:
        The captured activation tensor from `module`'s first invocation.
    """
    cache = []

    def hook(hooked_module, hook_input, output):
        # Keep only the tensor part of tuple-valued module outputs.
        cache.append(output[0] if isinstance(output, tuple) else output)

    hook_handle = module.register_forward_hook(hook)
    # BUGFIX: remove the hook in a `finally` block so it does not leak onto
    # the module if the forward pass raises (the original left it attached).
    try:
        if no_grad:
            with torch.no_grad():
                _ = model(**inputs)
        else:
            _ = model(**inputs)
    finally:
        hook_handle.remove()
    return cache[0]
#######################################
##### Model and data loading code #####
#######################################
def get_params(model, layer_ids, param_ids):
    """Collect selected parameters from selected transformer layers.

    For each layer index in `layer_ids`, keeps the parameters whose position
    in that layer's ``parameters()`` iteration order appears in `param_ids`.

    Returns:
        A flat list of parameter tensors, ordered by layer then position.
    """
    selected = []
    for lid in layer_ids:
        layer = model.model.layers[lid]
        selected.extend(
            p for idx, p in enumerate(layer.parameters()) if idx in param_ids
        )
    return selected
def load_model(model_name_or_path):
    """Load a causal LM and its tokenizer, configured for left-padded batching.

    Uses the "auto" dtype on bf16-capable GPUs and fp16 otherwise. Special
    token ids the tokenizer may lack (pad/mask/sep/cls) are all aliased to
    the EOS token id, and padding is set to the left side.

    Args:
        model_name_or_path: HuggingFace hub id or local checkpoint path.

    Returns:
        (model, tokenizer) tuple.
    """
    bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = "auto" if bf16_ok else torch.float16

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=dtype,
        trust_remote_code=True,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path, trust_remote_code=True, use_fast=False
    )

    # EOS stands in for every special token; left padding suits generation.
    eos = tokenizer.eos_token_id
    tokenizer.pad_token_id = eos
    tokenizer.padding_side = "left"
    tokenizer.mask_token_id = eos
    tokenizer.sep_token_id = eos
    tokenizer.cls_token_id = eos

    return model, tokenizer
def get_data(forget_corpora, retain_corpora, min_len=50, max_len=2000, batch_size=4):
    """Load forget/retain corpora as lists of text batches.

    Corpus resolution inside `get_dataset`:
      * "wikitext"      -> wikitext-2-raw-v1 test split (texts longer than `min_len`)
      * anything else   -> local bio_remove_dataset.jsonl texts (same length filter)

    Args:
        forget_corpora: iterable of corpus names to forget.
        retain_corpora: iterable of corpus names to retain.
        min_len: minimum character length for a text to be kept.
        max_len: NOTE(review): currently unused — kept for interface compatibility.
        batch_size: number of texts per batch.

    Returns:
        A pair (forget_sets, retain_sets); each entry is a list of batches,
        each batch a list of up to `batch_size` strings.
    """
    def get_dataset(name):
        data = []
        if name == "wikitext":
            from datasets import load_dataset
            raw_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test", cache_dir='/egr/research-optml/wangc168/cache')
            for x in raw_data:
                if len(x['text']) > min_len:
                    data.append(str(x['text']))
        else:
            # Any non-wikitext name falls back to the local bio corpus.
            path = '/egr/research-optml/wangc168/watermark/IBM_Run_test/dataset/bio_remove_dataset.jsonl'
            # BUGFIX: context manager closes the file handle (original leaked it);
            # also removed the dead, misspelled `lable = 0` assignment.
            with open(path, "r") as f:
                for line in f:
                    raw_text = json.loads(line)['text']
                    if len(raw_text) > min_len:
                        data.append(str(raw_text))
        # Labeled size log, consistent with get_star1k_data.
        print(f"{name} data size:", len(data))
        data = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
        return data
    return (
        [get_dataset(c) for c in forget_corpora],
        [get_dataset(c) for c in retain_corpora],
    )
def get_star1k_data(forget_corpora, retain_corpora, min_len=50, max_len=2000, batch_size=4):
    """Load forget/retain corpora, supporting the STAR-1 safety set for forgetting.

    Corpus resolution inside `get_dataset`:
      * "wikitext"                     -> wikitext-2-raw-v1 test split (texts > min_len)
      * "safety" (forget corpora only) -> STAR-1 "question" fields (> min_len)
      * anything else                  -> local bio_remove_dataset.jsonl texts

    Args:
        forget_corpora: iterable of corpus names to forget.
        retain_corpora: iterable of corpus names to retain.
        min_len: minimum character length for a text to be kept.
        max_len: NOTE(review): currently unused — kept for interface compatibility.
        batch_size: number of texts per batch.

    Returns:
        A pair (forget_sets, retain_sets); each entry is a list of batches,
        each batch a list of up to `batch_size` strings.

    NOTE(review): a *retain* corpus named "safety" falls through to the bio
    dataset branch — confirm that asymmetry is intended.
    """
    def get_dataset(name, is_forget=False):
        data = []
        if name == "wikitext":
            from datasets import load_dataset
            raw_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test", cache_dir='/egr/research-optml/wangc168/cache')
            for x in raw_data:
                if len(x['text']) > min_len:
                    data.append(str(x['text']))
        elif is_forget and name == "safety":
            from datasets import load_dataset
            # make sure you've logged in to HuggingFace CLI if required
            raw_data = load_dataset("UCSC-VLAA/STAR-1", split="train")
            for x in raw_data:
                if "question" in x and len(x["question"]) > min_len:
                    data.append(str(x["question"]))
        else:
            path = '/egr/research-optml/wangc168/watermark/IBM_Run_test/dataset/bio_remove_dataset.jsonl'
            # BUGFIX: context manager closes the file handle (original leaked it).
            with open(path, "r") as f:
                for line in f:
                    raw_text = json.loads(line)['text']
                    if len(raw_text) > min_len:
                        data.append(str(raw_text))
        print(f"{name} data size:", len(data))
        data = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
        return data
    return (
        [get_dataset(c, is_forget=True) for c in forget_corpora],
        [get_dataset(c, is_forget=False) for c in retain_corpora],
    )