-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathutils.py
More file actions
133 lines (117 loc) · 5.06 KB
/
utils.py
File metadata and controls
133 lines (117 loc) · 5.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from pathlib import Path
from icv_src.icv_datasets.vqa_dataset import load_okvqa_ds, load_vqav2_ds
from icv_src.icv_datasets.caption_dataset import load_coco_ds
from icv_src.metrics import (
postprocess_ok_vqa_generation,
postprocess_vqa_generation,
)
from lmm_icl_interface import (
Idefics2Interface,
IdeficsInterface,
LMMPromptManager,
OpenFlamingoInterface,
)
def get_icv_cpk_path(result_dir, model_name, dataset_name, run_name):
    """Return the checkpoint directory of one ICV run.

    The path is ``<result_dir>/model_cpk/<dataset_name>/<model_name>/<run_name>``.
    Note that here the dataset component precedes the model component, the
    reverse of the ordering used by ``get_inference_paths``.

    Args:
        result_dir: Root results directory (``str`` or ``Path``).
        model_name: Model identifier used as a path component.
        dataset_name: Dataset identifier used as a path component.
        run_name: Run identifier used as the final path component.

    Returns:
        The checkpoint directory as a ``Path`` (not created on disk).
    """
    return Path(result_dir).joinpath("model_cpk", dataset_name, model_name, run_name)
def get_inference_paths(result_dir, model_name, dataset_name, run_name):
    """Return the output locations for one inference run.

    The run directory is
    ``<result_dir>/inference/<model_name>/<dataset_name>/<run_name>``.

    Args:
        result_dir: Root results directory (``str`` or ``Path``).
        model_name: Model identifier used as a path component.
        dataset_name: Dataset identifier used as a path component.
        run_name: Run identifier used as the final path component.

    Returns:
        Tuple ``(save_dir, meta_info_dir, metric_file_path)``: the run
        directory, its ``meta_info`` subdirectory, and its ``result.json``
        file. Nothing is created on disk.
    """
    run_dir = Path(result_dir).joinpath("inference", model_name, dataset_name, run_name)
    return run_dir, run_dir / "meta_info", run_dir / "result.json"
def init_interface(cfg):
    """Create the prompt manager and the LMM interface selected by ``cfg.lmm.name``.

    Supported models are ``"idefics-9b"``, any name containing
    ``"openflamingo"`` (compared case-insensitively) and ``"idefics2-8b-base"``.
    Model files are resolved relative to ``cfg.model_cpk_dir``.

    Args:
        cfg: Config object exposing ``model_cpk_dir`` plus the ``prompt`` and
            ``lmm`` sections read below.

    Returns:
        Tuple ``(prompt_manager, interface, processor)``, where ``processor``
        is ``interface.processor``.

    Raises:
        ValueError: If ``cfg.lmm.name`` does not name a supported model.
    """
    model_cpk_dir = Path(cfg.model_cpk_dir)
    # NOTE: the prompt config key is spelled "label_filed"; it is read as-is.
    prompt_manager = LMMPromptManager(
        cfg.prompt.prompt_template,
        cfg.prompt.column_token_map,
        label_field=cfg.prompt.label_filed,
        sep_token=cfg.prompt.sep_token,
        query_prompt_template=cfg.prompt.query_prompt_template,
    )
    if cfg.lmm.name == "idefics-9b":
        interface = IdeficsInterface(
            model_name_or_path=model_cpk_dir / cfg.lmm.model_name,
            precision=cfg.lmm.precision,
            device=cfg.lmm.device,
            prompt_manager=prompt_manager,
            instruction=cfg.prompt.instruction,
            image_field=cfg.prompt.image_field,
            label_field=cfg.prompt.label_filed,
        )
    elif "openflamingo" in cfg.lmm.name.lower():
        interface = OpenFlamingoInterface(
            lang_encoder_path=model_cpk_dir / cfg.lmm.lang_encoder_path,
            tokenizer_path=model_cpk_dir / cfg.lmm.tokenizer_path,
            flamingo_checkpoint_dir=model_cpk_dir / cfg.lmm.hf_root,
            cross_attn_every_n_layers=cfg.lmm.cross_attn_every_n_layers,
            hf_root=cfg.lmm.hf_root,
            precision=cfg.lmm.precision,
            device=cfg.lmm.device,
            prompt_manager=prompt_manager,
            instruction=cfg.prompt.instruction,
            image_field=cfg.prompt.image_field,
            label_field=cfg.prompt.label_filed,
            load_from_local=True,
            init_device=cfg.lmm.init_device,
        )
    elif cfg.lmm.name == "idefics2-8b-base":
        interface = Idefics2Interface(
            model_name_or_path=model_cpk_dir / cfg.lmm.model_name,
            precision=cfg.lmm.precision,
            device=cfg.lmm.device,
            prompt_manager=prompt_manager,
            instruction=cfg.prompt.instruction,
            image_field=cfg.prompt.image_field,
            label_field=cfg.prompt.label_filed,
        )
    else:
        # Without this branch an unknown name would reach the return with
        # `interface` unbound and fail with an opaque UnboundLocalError.
        raise ValueError(f"{cfg.lmm.name=} is not a supported model")
    return prompt_manager, interface, interface.processor
def init_dataset(cfg, split):
    """Load the dataset named by ``cfg.data_cfg.task.datasets.name`` and pick its post-processor.

    Supported names are ``"vqav2"``, ``"okvqa"`` and ``"coco2017"``.

    Args:
        cfg: Config object whose ``data_cfg.task.datasets`` section holds the
            dataset name and the paths passed to the matching loader.
        split: Split identifier forwarded to the loader.

    Returns:
        Tuple ``(ds, post_process_fun)``: the loaded dataset and the
        module-level function used to clean model generations for it
        (``vqa_postprocess``, ``ok_vq_postprocess`` or ``caption_postprocess``).

    Raises:
        ValueError: If the dataset name is not one of the supported ones.
    """
    ds_cfg = cfg.data_cfg.task.datasets
    name = ds_cfg.name

    if name == "vqav2":
        dataset = load_vqav2_ds(
            ds_cfg.root_dir,
            ds_cfg.train_coco_dataset_root,
            ds_cfg.val_coco_dataset_root,
            split,
            val_ann_file=ds_cfg.val_ann_path,
        )
        return dataset, vqa_postprocess

    if name == "okvqa":
        dataset = load_okvqa_ds(
            ds_cfg.root_dir,
            ds_cfg.train_coco_dataset_root,
            ds_cfg.val_coco_dataset_root,
            split,
        )
        return dataset, ok_vq_postprocess

    if name == "coco2017":
        dataset = load_coco_ds(
            train_coco_dataset_root=ds_cfg.train_coco_dataset_root,
            val_coco_dataset_root=ds_cfg.val_coco_dataset_root,
            train_coco_annotation_file=ds_cfg.train_coco_annotation_file,
            val_coco_annotation_file=ds_cfg.val_coco_annotation_file,
            split=split,
        )
        return dataset, caption_postprocess

    # The "=" specifier echoes the expression text into the message, so the
    # full config path is spelled out here rather than the local alias.
    raise ValueError(f"{cfg.data_cfg.task.datasets.name=} error")
def caption_postprocess(text, model_name):
    """Clean a raw caption generation for the given model family.

    For Flamingo models everything from the first ``"Output"`` onward is
    dropped; for IDEFICS models everything from the first ``"Caption"``
    onward is dropped and newlines are removed. In both cases double quotes
    are removed and surrounding whitespace is stripped.

    Args:
        text: Raw generated text.
        model_name: Model identifier. It is matched case-insensitively, in
            line with ``init_interface``, so e.g. ``"OpenFlamingo-9B"`` is
            treated as a Flamingo model.

    Returns:
        The cleaned caption string.

    Raises:
        ValueError: If ``model_name`` names neither a Flamingo nor an IDEFICS model.
    """
    name = model_name.lower()
    if "flamingo" in name:
        return text.split("Output", 1)[0].replace('"', "").strip()
    if "idefics" in name:
        return text.split("Caption", 1)[0].replace('"', "").replace("\n", "").strip()
    # Previously fell through and returned None silently.
    raise ValueError(f"Unsupported model for caption post-processing: {model_name!r}")
def vqa_postprocess(text, model_name):
    """Clean a raw VQA answer generation for the given model family.

    The text is passed through ``postprocess_vqa_generation`` and then
    whitespace-stripped; for IDEFICS models all newlines are removed as well.

    Args:
        text: Raw generated text.
        model_name: Model identifier. It is matched case-insensitively, in
            line with ``init_interface``, so e.g. ``"OpenFlamingo-9B"`` is
            treated as a Flamingo model.

    Returns:
        The cleaned answer string.

    Raises:
        ValueError: If ``model_name`` names neither a Flamingo nor an IDEFICS model.
    """
    name = model_name.lower()
    if "flamingo" in name:
        return postprocess_vqa_generation(text).strip()
    if "idefics" in name:
        return postprocess_vqa_generation(text).replace("\n", "").strip()
    # Previously fell through and returned None silently.
    raise ValueError(f"Unsupported model for VQA post-processing: {model_name!r}")
def ok_vq_postprocess(text, model_name):
    """Clean a raw OK-VQA answer generation for the given model family.

    The text is passed through ``postprocess_ok_vqa_generation`` and then
    whitespace-stripped; for IDEFICS models all newlines are removed as well.

    Args:
        text: Raw generated text.
        model_name: Model identifier. It is matched case-insensitively, in
            line with ``init_interface``, so e.g. ``"OpenFlamingo-9B"`` is
            treated as a Flamingo model.

    Returns:
        The cleaned answer string.

    Raises:
        ValueError: If ``model_name`` names neither a Flamingo nor an IDEFICS model.
    """
    name = model_name.lower()
    if "flamingo" in name:
        return postprocess_ok_vqa_generation(text).strip()
    if "idefics" in name:
        return postprocess_ok_vqa_generation(text).replace("\n", "").strip()
    # Previously fell through and returned None silently.
    raise ValueError(f"Unsupported model for OK-VQA post-processing: {model_name!r}")