-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
192 lines (160 loc) · 6.96 KB
/
main.py
File metadata and controls
192 lines (160 loc) · 6.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import json
import time
import os
import base64
import asyncio
import aiohttp
import uuid
from aiohttp import FormData
from collections import defaultdict
try:
    from tqdm import tqdm
except ImportError:
    # tqdm is optional; without it progress is printed as plain text.
    HAS_TQDM = False
else:
    HAS_TQDM = True
class ProgressTracker:
    """Aggregate progress counter shared by all key workers.

    Shows a tqdm bar when ``HAS_TQDM`` is true, otherwise a single
    self-overwriting status line on stdout.
    """

    def __init__(self, total_images, total_keys):
        """Set up counters and the display.

        Args:
            total_images: images each key is expected to produce.
            total_keys: number of keys (workers) taking part.
        """
        self.total = total_images * total_keys
        self.completed = 0
        # Serializes updates coming from concurrently running workers.
        self.lock = asyncio.Lock()
        # Images finished so far, counted per key prefix.
        self.key_stats = defaultdict(int)
        # monotonic() is immune to wall-clock adjustments, keeping the rate sane.
        self.start_time = time.monotonic()
        if HAS_TQDM:
            self.pbar = tqdm(total=self.total, desc="Генерация изображений", unit="img")
        else:
            print(f"Всего изображений для генерации: {self.total}")

    async def update(self, key_prefix, current, total):
        """Record one finished image for ``key_prefix`` and refresh the display.

        Args:
            key_prefix: short identifier of the key that produced the image.
            current: the worker's own running count; not used here — the
                per-key figure shown comes from ``key_stats``.
            total: per-key target, displayed next to the per-key count.
        """
        async with self.lock:
            self.completed += 1
            self.key_stats[key_prefix] += 1
            done_for_key = self.key_stats[key_prefix]
            if HAS_TQDM:
                self.pbar.update(1)
                self.pbar.set_postfix_str(f"Ключ {key_prefix}: {done_for_key}/{total}")
            else:
                elapsed = time.monotonic() - self.start_time
                img_per_sec = self.completed / elapsed if elapsed > 0 else 0
                # Guard against a tracker created with zero expected images.
                fraction = self.completed / self.total if self.total else 0
                # A "\r" line without a newline stays in stdout's line buffer
                # unless flushed explicitly, so the status would never appear.
                print(f"\rПрогресс: {self.completed}/{self.total} ({fraction:.1%}) | "
                      f"{img_per_sec:.2f} img/sec | "
                      f"Ключ {key_prefix}: {done_for_key}/{total}", end="", flush=True)

    def close(self):
        """Finish the display: close the tqdm bar or terminate the status line."""
        if HAS_TQDM:
            self.pbar.close()
        else:
            print()
class Text2ImageAPI:
    """Small async client for the Fusion Brain text-to-image HTTP API.

    All requests of one instance are serialized by a semaphore and spaced at
    least ``MIN_REQUEST_INTERVAL`` seconds after the previous one.
    """

    # Minimum spacing between two consecutive requests of one client, seconds.
    MIN_REQUEST_INTERVAL = 0.1

    def __init__(self, url, fusion_brain_token, fusion_brain_key):
        """Create the client and its HTTP session.

        Args:
            url: API base URL; endpoint paths are appended to it verbatim,
                so it must end with '/'.
            fusion_brain_token: sent as ``X-Key: Key <token>``.
            fusion_brain_key: sent as ``X-Secret: Secret <key>``.
        """
        self.URL = url
        self.AUTH_HEADERS = {
            'X-Key': f'Key {fusion_brain_token}',
            'X-Secret': f'Secret {fusion_brain_key}',
        }
        # Released by close().
        self.session = aiohttp.ClientSession()
        self.request_semaphore = asyncio.Semaphore(1)
        # -inf so the very first request is never delayed.
        self.last_request_time = float('-inf')

    async def _throttled_request(self, method, url, **kwargs):
        """Send one request with the auth headers and return its decoded JSON.

        Args:
            method: bound session method such as ``self.session.get``.
            url: absolute request URL.
            **kwargs: forwarded to ``method`` (e.g. ``data=``).
        """
        async with self.request_semaphore:
            # One constant drives both the check and the sleep, so the spacing
            # between requests is consistently MIN_REQUEST_INTERVAL.
            elapsed = time.monotonic() - self.last_request_time
            if elapsed < self.MIN_REQUEST_INTERVAL:
                await asyncio.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
            self.last_request_time = time.monotonic()
            async with method(url, headers=self.AUTH_HEADERS, **kwargs) as response:
                return await response.json()

    async def get_model(self):
        """Return the id of the first listed model, or None.

        None is returned when the response is not a non-empty list (e.g. an
        error object), so callers can test the result instead of catching
        IndexError/KeyError.
        """
        data = await self._throttled_request(
            self.session.get,
            self.URL + 'key/api/v1/models'
        )
        if not isinstance(data, list) or not data:
            return None
        return data[0].get('id')

    async def generate(self, prompt, model, images=1, width=1024, height=1024,
                       negative_prompt="яркие цвета, кислотность, высокая контрастность"):
        """Submit a generation job and return its uuid, or None if absent.

        Args:
            prompt: text query (sent as ``generateParams.query``).
            model: model id, e.g. from ``get_model()``.
            images: sent as ``numImages``.
            width: image width in pixels.
            height: image height in pixels.
            negative_prompt: sent as ``negativePromptDecoder``.
        """
        params = {
            "type": "GENERATE",
            "numImages": int(images),
            "width": int(width),
            "height": int(height),
            "negativePromptDecoder": negative_prompt,
            "generateParams": {
                "query": str(prompt)
            }
        }
        data = FormData()
        data.add_field('model_id', str(model))
        data.add_field('params', json.dumps(params), content_type='application/json')
        response_data = await self._throttled_request(
            self.session.post,
            self.URL + 'key/api/v1/text2image/run',
            data=data
        )
        if not isinstance(response_data, dict):
            return None
        return response_data.get('uuid')

    async def check_generation(self, request_id, attempts=20, delay=5):
        """Poll a job until it is DONE.

        Args:
            request_id: uuid returned by ``generate``.
            attempts: maximum number of status requests.
            delay: seconds to wait between polls.

        Returns:
            The ``images`` value of the DONE response, or None when the job
            reports FAIL, the response carries no status, or ``attempts``
            polls pass without completion.
        """
        for _ in range(attempts):
            data = await self._throttled_request(
                self.session.get,
                self.URL + 'key/api/v1/text2image/status/' + request_id
            )
            status = data.get('status') if isinstance(data, dict) else None
            if status == 'DONE':
                return data.get('images')
            if status is None or status == 'FAIL':
                # An error body or a failed job will not turn into DONE;
                # stop instead of polling for the remaining attempts.
                return None
            await asyncio.sleep(delay)
        return None

    async def close(self):
        """Close the underlying HTTP session."""
        await self.session.close()
async def save_image(image_base64, output_dir="output", key_prefix=""):
    """Decode a base64 image and write it to a uniquely named file.

    The file name combines the current Unix time, ``key_prefix`` and a
    random 8-hex-digit suffix, so concurrent saves do not collide.

    Args:
        image_base64: base64-encoded image bytes.
        output_dir: target directory, created with parents if missing;
            an empty string means the current directory.
        key_prefix: tag embedded in the file name.

    Returns:
        Path of the written file.
    """
    unique_id = uuid.uuid4().hex[:8]
    filename = f"image_{int(time.time())}_{key_prefix}_{unique_id}.jpg"
    filepath = os.path.join(output_dir, filename)

    def _write():
        # Decode first so invalid input never leaves an empty file behind.
        image_data = base64.b64decode(image_base64)
        if output_dir:
            # exist_ok: several executor threads may create the dir at once.
            os.makedirs(output_dir, exist_ok=True)
        with open(filepath, "wb") as f:
            f.write(image_data)

    # Decoding and disk writes block; run them off the event loop so other
    # workers keep making progress meanwhile.
    await asyncio.get_running_loop().run_in_executor(None, _write)
    return filepath
async def worker(prompt, key_pair, total_images, output_dir, progress_tracker,
                 max_consecutive_failures=10, retry_delay=2):
    """Generate ``total_images`` images for ``prompt`` with one API key pair.

    Problems are reported on stdout instead of raised, so one bad key does
    not abort the other workers.

    Args:
        prompt: text prompt sent with every generation request.
        key_pair: "token:secret"; only the first ':' separates the two parts.
        total_images: number of images this key should produce.
        output_dir: directory passed to ``save_image``.
        progress_tracker: shared tracker notified after every saved image.
        max_consecutive_failures: give up on this key after this many failed
            attempts in a row; ``None`` retries forever.
        retry_delay: seconds to wait after a failed attempt before retrying.
    """
    token, sep, key = key_pair.partition(':')
    if not sep:
        print("\nMalformed key pair (expected 'token:secret'), skipping")
        return
    key_prefix = token[:6]
    api = Text2ImageAPI('https://api-key.fusionbrain.ai/', token, key)
    try:
        try:
            model_id = await api.get_model()
        except Exception as e:
            print(f"\nKey {key_prefix}...: Error getting model - {e}")
            return
        if not model_id:
            print(f"Key {key_prefix}...: Error getting model")
            return

        generated = 0
        failures = 0
        while generated < total_images:
            try:
                gen_uuid = await api.generate(prompt, model_id)
                if not gen_uuid:
                    print(f"Key {key_prefix}...: Failed to start generation")
                else:
                    images = await api.check_generation(gen_uuid)
                    if images:
                        await save_image(images[0], output_dir, key_prefix)
                        generated += 1
                        failures = 0
                        await progress_tracker.update(key_prefix, generated, total_images)
                        continue
                    print(f"\nKey {key_prefix}...: Generation failed, retrying...")
            except Exception as e:
                print(f"\nKey {key_prefix}...: Error - {str(e)}")

            # Every path reaching this point is a failed attempt.
            failures += 1
            if max_consecutive_failures is not None and failures >= max_consecutive_failures:
                print(f"\nKey {key_prefix}...: giving up after {failures} consecutive failures "
                      f"({generated}/{total_images} images saved)")
                return
            # Back off so a failing or exhausted key does not hammer the API.
            await asyncio.sleep(retry_delay)
    finally:
        await api.close()
async def main():
    """Run one generation worker per key pair listed in ``keys.txt``.

    Every non-empty line of ``keys.txt`` that contains ':' is treated as a
    "token:secret" pair. At most 20 workers run at once; each produces
    ``total_images_per_key`` images of the same prompt into ``output_dir``.
    A worker that crashes is reported after the run instead of aborting the
    others.
    """
    prompt = "Девушка курьер на электросамокате в красной повседневной одежде (свитшот, джинсы) с термокоробом за спиной, с телефоном, смотрит в камеру с дружелюбной улыбкой, фон — город, солнечный день, цветение сакуры, стиль — аниме с элементами реализма"
    total_images_per_key = 100
    output_dir = "output"
    try:
        # utf-8-sig also drops a leading BOM, which would otherwise end up
        # inside the first token.
        with open('keys.txt', 'r', encoding='utf-8-sig') as f:
            key_pairs = [line.strip() for line in f if line.strip() and ':' in line]
    except FileNotFoundError:
        print("Error: keys.txt file not found")
        return
    if not key_pairs:
        print("Error: keys.txt contains no 'token:secret' lines")
        return

    progress_tracker = ProgressTracker(total_images_per_key, len(key_pairs))
    global_semaphore = asyncio.Semaphore(20)

    async def throttled_worker(prompt, key_pair, total_images, output_dir, progress_tracker):
        # Cap the number of concurrent workers; the short pause before
        # releasing the slot staggers the start of the next worker.
        async with global_semaphore:
            await worker(prompt, key_pair, total_images, output_dir, progress_tracker)
            await asyncio.sleep(0.1)

    workers = [throttled_worker(prompt, key_pair, total_images_per_key, output_dir, progress_tracker)
               for key_pair in key_pairs]
    try:
        # return_exceptions=True: one crashing worker must not cancel the rest.
        results = await asyncio.gather(*workers, return_exceptions=True)
    finally:
        # Always finish the progress display, even if gather is interrupted.
        progress_tracker.close()

    for key_pair, result in zip(key_pairs, results):
        if isinstance(result, BaseException):
            # Show only the token prefix, never any part of the secret.
            key_prefix = key_pair.partition(':')[0][:6]
            print(f"Key {key_prefix}...: worker crashed - {result!r}")


if __name__ == '__main__':
    asyncio.run(main())