import gc
import os
import queue
import threading
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import cv2
import numpy as np
import onnxruntime as ort

from configs import LABELS, SESSIONS_PER_GPU, gpu_ids
from post_process import PPDocLayoutPostProcess
from pre_process import PPDocLayoutPreProcess
def build_batch_outputs_from_pred(pred, batch_size: int) -> List[Dict[str, np.ndarray]]:
    """Split flattened batch predictions into per-image output dicts.

    pred[0] boxes_all: (max_total, 7) detections for the whole batch, flattened
    pred[1] counts:    (N,) number of detections belonging to each image
    pred[2] masks_all: (max_total, 200, 200) optional per-detection masks

    Returns ``[{"boxes": boxes_i, "masks": masks_i?}, ...]``, one dict per image.

    Raises:
        RuntimeError: when ``len(counts)`` does not match ``batch_size``.
    """
    boxes_all = pred[0]
    counts = pred[1].astype(int).tolist()
    if len(counts) != batch_size:
        raise RuntimeError(f"counts length {len(counts)} != batch_size {batch_size}")
    # masks are optional: present only when pred has a non-None third entry
    masks_all = pred[2] if len(pred) >= 3 and pred[2] is not None else None

    batch_outputs: List[Dict[str, np.ndarray]] = []
    start = 0
    for count in counts:
        end = start + int(count)
        entry: Dict[str, np.ndarray] = {"boxes": np.array(boxes_all[start:end])}
        if masks_all is not None:
            entry["masks"] = np.array(masks_all[start:end])
        batch_outputs.append(entry)
        start = end
    return batch_outputs
# ---------------- GPU provider choose (ROCm first) ----------------
def _pick_gpu_provider(rocm_first: bool = True) -> Optional[str]:
    """Return the preferred available GPU execution provider, or None.

    Checks ORT's available providers for ROCm and CUDA; ``rocm_first``
    controls which one wins when both are present.
    """
    available = set(ort.get_available_providers())
    # ROCm / CUDA provider names as registered in ONNX Runtime
    preference = ["ROCMExecutionProvider", "CUDAExecutionProvider"]
    if not rocm_first:
        preference.reverse()
    for provider in preference:
        if provider in available:
            return provider
    return None
def _make_provider_list(gpu_id: int, rocm_first: bool = True) -> List[Any]:
    """Build the ORT providers config for one session pinned to ``gpu_id``.

    Returns a CPU-only list when no GPU provider is available; otherwise the
    GPU provider (with its ``device_id`` option) comes first, with CPU kept
    as a fallback. Most ORT builds accept "device_id" for the CUDA/ROCm EP.
    """
    provider = _pick_gpu_provider(rocm_first=rocm_first)
    if provider is None:
        # no GPU execution provider in this ORT build — fall back to CPU only
        return ["CPUExecutionProvider"]
    gpu_entry = (provider, {"device_id": int(gpu_id)})
    return [gpu_entry, "CPUExecutionProvider"]
# ---------------- High concurrency multi-session pool ----------------
@dataclass
class _Job:
    """A single inference request queued for a worker thread."""

    imgs: List[np.ndarray]  # batch of images to infer; an empty list is the shutdown sentinel
    conf_thres: float  # confidence threshold forwarded to postprocessing
    kwargs: Dict[str, Any]  # extra keyword arguments passed through to postprocess
    done: threading.Event  # set by the worker once result/error has been filled
    result: Any = None  # populated on success (per-image detection lists)
    error: Optional[BaseException] = None  # populated when the worker raised
class PPDocLayoutMultiGPUPool:
    """Multi-GPU, multi-session, high-concurrency inference pool.

    - ``gpu_ids=[0, 1]``: creates ``num_sessions_per_gpu`` ONNX Runtime
      sessions on each GPU (usually equal to the desired concurrency).
    - ``run(imgs)``: accepts a single image or a batch; an idle worker thread
      picks the job up and executes it on its own session.
    - Provider preference: ROCm first, then CUDA, finally CPU.
    """

    def __init__(
        self,
        model_path: str,
        gpu_ids: Optional[List[int]] = None,
        num_sessions_per_gpu: int = 1,
        img_size: Tuple[int, int] = (800, 800),
        rocm_first: bool = True,
        intra_op_num_threads: int = 1,
        inter_op_num_threads: int = 1,
        labels_meta_key: str = "character",
    ):
        self.model_path = model_path
        self.gpu_ids = gpu_ids or [0]
        self.num_sessions_per_gpu = int(max(1, num_sessions_per_gpu))
        self.img_size = img_size
        self.rocm_first = rocm_first
        self.labels = LABELS
        self.label2id = {name: idx for idx, name in enumerate(self.labels)}
        # Shared pre/post processors (pure Python; usually thread-safe. If your
        # implementations keep mutable state, give each worker its own copy.)
        self.pre = PPDocLayoutPreProcess(img_size=img_size)
        # scale_size keeps the original call convention:
        # [img_size[1], img_size[0]] => [W, H]
        self.post = PPDocLayoutPostProcess(labels=self.labels, scale_size=[img_size[1], img_size[0]])
        self._job_q: "queue.Queue[_Job]" = queue.Queue()
        self._workers: List[threading.Thread] = []
        self._stop = False
        # Build the sessions (one lock each) and start the worker threads.
        self._sessions: List[ort.InferenceSession] = []
        self._session_locks: List[threading.Lock] = []
        so = ort.SessionOptions()
        # so.intra_op_num_threads = int(intra_op_num_threads)
        # so.inter_op_num_threads = int(inter_op_num_threads)
        # Knobs that can curb ORT memory-arena growth (left disabled by default):
        # so.enable_mem_pattern = False
        # so.enable_cpu_mem_arena = False
        # so.add_session_config_entry("session.disable_prepacking", "1")
        for gid in self.gpu_ids:
            providers = _make_provider_list(gid, rocm_first=self.rocm_first)
            for _ in range(self.num_sessions_per_gpu):
                sess = ort.InferenceSession(
                    self.model_path,
                    sess_options=so,
                    providers=providers,
                )
                self._sessions.append(sess)
                self._session_locks.append(threading.Lock())
        # NOTE(review): the original __init__ then built one extra throwaway
        # session with profiling enabled (reusing the stale `providers` from
        # the last loop iteration) and printed its profile path — debug
        # leftover that wasted a whole model's worth of memory; removed.
        # One worker thread per session (simple + effective).
        for idx in range(len(self._sessions)):
            t = threading.Thread(target=self._worker_loop, args=(idx,), daemon=True)
            t.start()
            self._workers.append(t)

    def close(self):
        """Stop all workers; ORT sessions are released by garbage collection."""
        self._stop = True
        # One sentinel job (empty imgs) per worker unblocks their queue waits.
        for _ in self._workers:
            self._job_q.put(_Job(imgs=[], conf_thres=0.0, kwargs={}, done=threading.Event()))
        # Wait briefly for a clean exit; workers are daemon threads, so a
        # stuck one will not block interpreter shutdown.
        for t in self._workers:
            t.join(timeout=1.0)
        # no explicit session close in the ORT Python API; let GC handle it

    def run(
        self,
        imgs: List[np.ndarray] | np.ndarray,
        conf_thres: float = 0.5,
        **post_kwargs,
    ):
        """Run inference on a single image or a batch.

        Usage::

            pool.run(imgs, conf_thres=0.5, layout_nms=True, ...)

        imgs: a single ``np.ndarray`` or a list of ``np.ndarray``
        Returns:
            single image -> list[dict]
            batch        -> list[list[dict]] (one list per image)
        Raises:
            whatever the worker thread raised during inference/postprocess.
        """
        single = isinstance(imgs, np.ndarray)
        imgs_list = [imgs] if single else imgs
        job = _Job(
            imgs=imgs_list,
            conf_thres=float(conf_thres),
            kwargs=post_kwargs,
            done=threading.Event(),
        )
        self._job_q.put(job)
        job.done.wait()
        if job.error:
            raise job.error
        return job.result[0] if single else job.result

    def _worker_loop(self, sess_idx: int):
        """Consume jobs from the shared queue using session ``sess_idx``."""
        sess = self._sessions[sess_idx]
        lock = self._session_locks[sess_idx]
        while True:
            job = self._job_q.get()
            if self._stop:
                job.done.set()
                break
            if not job.imgs:
                # sentinel pushed by close() to wake blocked workers
                job.done.set()
                continue
            try:
                with lock:
                    outputs = self._run_on_session(sess, job.imgs, job.conf_thres, **job.kwargs)
                job.result = outputs
            except BaseException as e:
                # propagate anything (incl. KeyboardInterrupt) to the caller of run()
                job.error = e
            finally:
                job.done.set()

    def _run_on_session(self, sess: ort.InferenceSession, imgs: List[np.ndarray], conf_thres: float, **post_kwargs):
        """Preprocess -> ONNX inference -> postprocess for one batch of images."""
        # preprocess (native batch)
        ori_datas, batch_inputs = self.pre.batch(imgs)
        ort_inputs = {
            "im_shape": np.asarray(batch_inputs[0]),
            "image": np.asarray(batch_inputs[1]),
            "scale_factor": np.asarray(batch_inputs[2]),
        }
        # inference
        pred = sess.run(None, ort_inputs)
        # convert flattened predictions into per-image dicts for postprocessing
        batch_outputs = build_batch_outputs_from_pred(pred, batch_size=len(imgs))
        # postprocess — defaults kept consistent with the original call site
        call_kwargs = dict(
            threshold=conf_thres,
            layout_nms=True,
            layout_shape_mode="auto",
            filter_overlap_boxes=True,
            skip_order_labels=None,
        )
        call_kwargs.update(post_kwargs)
        boxes_b, scores_b, class_names_b = self.post(
            batch_outputs=batch_outputs,
            datas=ori_datas,
            **call_kwargs,
        )
        # format results: one list of detection dicts per image
        outputs: List[List[Dict[str, Any]]] = []
        for i in range(len(imgs)):
            out_i = []
            for box, score, name in zip(boxes_b[i], scores_b[i], class_names_b[i]):
                out_i.append(
                    {
                        "label": name,
                        "coordinate": [int(box[0]), int(box[1]), int(box[2]), int(box[3])],
                        "score": float(score),
                        "cls_id": int(self.label2id.get(name, -1)),
                    }
                )
            outputs.append(out_i)
        # Drop the large intermediates eagerly and force a collection so the
        # per-inference memory footprint does not accumulate across calls.
        del pred, batch_inputs, ort_inputs, batch_outputs, ori_datas
        gc.collect()
        return outputs
# ---------------- Convenience wrapper function ----------------
def create_layout_engine(
    model_path: str,
    gpu_ids: Optional[List[int]] = None,
    num_sessions_per_gpu: int = 2,
    img_size: Tuple[int, int] = (800, 800),
    rocm_first: bool = True,
) -> PPDocLayoutMultiGPUPool:
    """Public entry point for building an inference pool.

    Example::

        engine = create_layout_engine("./cc_v3.onnx", gpu_ids=[0, 1], num_sessions_per_gpu=2)
        res = engine.run(img)              # single image
        res_b = engine.run([img1, img2])   # batch
    """
    pool = PPDocLayoutMultiGPUPool(
        model_path=model_path,
        gpu_ids=gpu_ids,
        num_sessions_per_gpu=num_sessions_per_gpu,
        img_size=img_size,
        rocm_first=rocm_first,
    )
    return pool
# ---------------- Example usage ----------------
if __name__ == "__main__":
    engine = create_layout_engine(
        "./cc_v3.onnx",
        gpu_ids=gpu_ids,
        num_sessions_per_gpu=SESSIONS_PER_GPU,
        img_size=(800, 800),
        rocm_first=True,
    )
    img = cv2.imread("img.png")
    t0 = time.perf_counter()
    out = engine.run(img, conf_thres=0.5)
    t1 = time.perf_counter()
    print(f"latency: {(t1-t0)*1000:.2f} ms")
    print(out)
    engine.close()
    # NOTE(review): the original file ended with stray statements after
    # close() — a fresh ort.SessionOptions() with enable_mem_pattern and
    # enable_cpu_mem_arena set to True — apparently pasted from an issue
    # report about memory growth during CPU inference. They were never used
    # and have been removed. To actually limit arena growth, set those two
    # flags to False on the SessionOptions used to build the sessions in
    # PPDocLayoutMultiGPUPool.__init__ (see the commented knobs there).
# --- Pasted issue report (commented out so the file parses; kept for reference) ---
# 问题描述 / Problem description: with PP-DocLayout v3 preloaded and running
#   CPU inference, memory grows with every inference call until system RAM is exhausted.
# 运行环境 / Runtime environment: Windows, Python 3.11
# 复现代码 / Reproduction code: (the example usage above)
# 可能解决方案 / Possible solutions: disable the ORT CPU memory arena and memory
#   pattern (enable_cpu_mem_arena=False, enable_mem_pattern=False) on the
#   SessionOptions used when creating the InferenceSession.