-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathautolabel.py
More file actions
71 lines (58 loc) · 2.43 KB
/
Copy pathautolabel.py
File metadata and controls
71 lines (58 loc) · 2.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import logging
import os

import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
class AutoLabel:
    """Caption an image with a BLIP conditional-generation model.

    The ``INPUT_TYPES`` / ``RETURN_TYPES`` / ``FUNCTION`` / ``CATEGORY``
    attributes presumably follow the ComfyUI custom-node convention (the
    module also exports ``NODE_CLASS_MAPPINGS``) -- confirm against the host.
    """

    # inference_mode choice -> (device, dtype). ``None`` keeps the dtype that
    # ``from_pretrained`` picks by default. Unknown modes fall back to CPU,
    # matching the catch-all ``else`` branch of the earlier implementation.
    _INFERENCE_SETTINGS = {
        "gpu_float16": ("cuda", torch.float16),
        "gpu": ("cuda", None),
        "cpu": ("cpu", None),
    }

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's required inputs and their widget defaults.

        ``repo_id`` is handed straight to ``from_pretrained``, so it may be a
        hub repo id or a local model directory.
        """
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": ("STRING", {"default": "a photography of"}),
                "repo_id": ("STRING", {"default": "Salesforce/blip-image-captioning-base"}),
                "inference_mode": (["gpu_float16", "gpu", "cpu"],),
                "get_model_online": ("BOOLEAN", {"default": True},),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("main_object_description",)
    FUNCTION = "generate_caption"
    CATEGORY = "AutoLabel"

    def tensor_to_image(self, tensor):
        """Convert an image tensor into a PIL RGB image.

        Values are scaled by 255 and clamped, so they are expected to lie in
        [0, 1]. The layout is assumed to be channels-last, presumably
        (batch, H, W, C) as ComfyUI passes ``IMAGE`` tensors -- confirm.
        Only the first image of a batch is converted, because the node
        returns a single caption string.
        """
        tensor = tensor.detach().cpu()
        # Peel leading batch dims by indexing rather than squeeze(): squeeze()
        # also drops H or W when either is 1, and leaves batches >1 as 4-D
        # arrays that Image.fromarray rejects.
        while tensor.ndim > 3:
            tensor = tensor[0]
        image_np = tensor.mul(255).clamp(0, 255).byte().numpy()
        if image_np.ndim == 3 and image_np.shape[-1] == 1:
            # (H, W, 1) is not accepted by fromarray; treat it as grayscale.
            image_np = image_np[..., 0]
        # Let Pillow infer the mode (the ``mode=`` argument is deprecated) and
        # normalise grayscale/RGBA inputs to the RGB that BLIP expects.
        return Image.fromarray(image_np).convert("RGB")

    def generate_caption(self, image, prompt, repo_id, inference_mode, get_model_online):
        """Generate a caption for ``image``, conditioned on ``prompt``.

        Args:
            image: Image tensor accepted by ``tensor_to_image``.
            prompt: Text prefix the generated caption continues.
            repo_id: Hub repo id or local directory for the BLIP checkpoint.
            inference_mode: "gpu_float16", "gpu" or "cpu"; any other value
                runs on CPU.
            get_model_online: When False, load only from the local cache /
                directory (``local_files_only=True``) and never hit the network.

        Returns:
            A 1-tuple with the decoded caption. On any failure while loading
            or running the model, a 1-tuple with a fixed error message; the
            exception is logged with its traceback.

        Raises:
            ValueError: if ``image`` is None or ``repo_id`` is empty.
        """
        if image is None:
            raise ValueError("Need an image")
        if not repo_id:
            raise ValueError("Need a repo_id or local_model_path")

        pil_image = self.tensor_to_image(image)
        device, dtype = self._INFERENCE_SETTINGS.get(inference_mode, ("cpu", None))

        # Scope offline mode to these two loads. Setting the process-wide
        # TRANSFORMERS_OFFLINE env var (and never clearing it) would either
        # leak offline mode into every later load in the process, or do nothing
        # if the library had already read it at import time.
        load_kwargs = {"local_files_only": not get_model_online}
        model_kwargs = dict(load_kwargs)
        if dtype is not None:
            model_kwargs["torch_dtype"] = dtype

        try:
            processor = BlipProcessor.from_pretrained(repo_id, **load_kwargs)
            model = BlipForConditionalGeneration.from_pretrained(repo_id, **model_kwargs).to(device)
            inputs = processor(pil_image, prompt, return_tensors="pt")
            # The dtype only matters in half precision; the BatchFeature casts
            # only its floating-point tensors (pixel values), not token ids.
            inputs = inputs.to(device, dtype) if dtype is not None else inputs.to(device)
            out = model.generate(**inputs)
            description = processor.decode(out[0], skip_special_tokens=True)
        except Exception:
            # Deliberate best-effort boundary: keep the graph running with an
            # error string, but record the full traceback instead of print().
            logging.getLogger(__name__).exception(
                "AutoLabel caption generation failed for %r", repo_id
            )
            return ("Error occurred during caption generation",)
        return (description,)
# Module-level registry tables, presumably consumed by the host's custom-node
# loader (ComfyUI convention -- confirm):
#   NODE_CLASS_MAPPINGS        node type name -> implementing class
#   NODE_DISPLAY_NAME_MAPPINGS node type name -> human-readable label
NODE_CLASS_MAPPINGS = dict(AutoLabel=AutoLabel)
NODE_DISPLAY_NAME_MAPPINGS = dict(AutoLabel="Auto Label")