diff --git a/api/docs/developer/PROCESS-MAX-RETRIES-PLAN.md b/api/docs/developer/PROCESS-MAX-RETRIES-PLAN.md
new file mode 100644
index 0000000..2e15cfb
--- /dev/null
+++ b/api/docs/developer/PROCESS-MAX-RETRIES-PLAN.md
@@ -0,0 +1,109 @@
+# 任务最大重试次数功能实现计划
+
+## 一、需求背景
+
+### 1.1 当前系统现状
+
+当前系统已实现了基于 `processTimeout` 的超时机制:
+- Worker 通过 `take` 接口拉取任务时可指定 `processTimeout`
+- 超时后任务自动标记为完成(status_code: 408)
+- 使用 Redis 键过期 + Keyspace Notifications 实现
+
+**现有问题**:
+- 任务超时后直接标记为完成,无法重试
+- 对于临时性失败(网络抖动、服务暂时不可用),无法自动恢复
+- 用户需要手动重新提交失败的任务
+
+### 1.2 新需求
+
+在 `take` 接口中增加 `processMaxRetries` 参数,实现以下功能:
+- 指定任务的最大重试次数(例如:3次)
+- 任务超时后不直接完成,而是重新放回队列
+- 达到最大重试次数后才标记为最终失败
+- 每次重试都有独立的超时时间
+
+---
+
+## 二、设计方案
+
+### 2.1 核心思路
+
+**基于现有超时机制扩展**:
+- 复用现有的 Redis 键过期 + Keyspace Notifications 机制
+- 在 Redis 中记录任务的当前重试次数和最大重试次数
+- 超时时的处理逻辑:
+ - 检查任务是否已过期(expireTime),已过期则直接失败,不重试
+ - 检查重试次数是否达到上限,未达到则重新入队,达到上限则标记失败
+
+### 2.2 处理流程
+
+1. **Take 阶段**:初始化重试追踪(retry=0, max_retry=N, process_timeout=T)
+2. **超时触发**:监听 timeout 键过期事件
+3. **过期检查**:检查任务是否已过期(expireTime),已过期则直接标记失败,不重试
+4. **判断重试**:检查当前重试次数是否达到上限(-1 表示无限重试)
+5. **执行动作**:未达上限则重新入队并增加重试计数,达到上限则标记失败并清理追踪键
+
+---
+
+## 三、详细设计
+
+### 3.1 数据模型
+
+#### Take 类新增字段
+
+**文件**: `openai-api/src/main/java/com/theokanning/openai/queue/Take.java`
+
+```java
+private Integer processMaxRetries; // null/0: 不重试, -1: 无限重试, >0: 具体次数
+```
+
+#### Redis 数据结构
+
+**新增键**:
+```
+retry:{taskId} # 已重试次数(Take 时为 0,首次重试后为 1)
+max_retry:{taskId} # 最大重试次数(-1表示无限)
+process_timeout:{taskId} # 原始超时值(用于重新设置)
+```
+
+**现有键**:
+```
+timeout:{taskId} # 超时控制键(已存在)
+```
+
+### 3.2 核心实现
+
+**涉及文件**:`QueueService.java`
+
+**主要改动**:
+1. **Take 流程**:增加参数校验和重试追踪初始化
+2. **超时处理**:改造监听器回调,增加重试逻辑判断
+3. **任务重入队**:查询任务,检查过期,更新 Redis 队列和重试计数
+4. **Complete 流程**:清理重试追踪键
+
+**关键约束**:
+- 在线队列(level = 0)不支持重试
+- 任务已过期(expireTime)则不重试,直接标记失败
+- 数据库状态保持不变,仅操作 Redis
+
+---
+
+## 四、性能影响评估
+
+### Redis 操作增量
+
+**每次 take(启用重试)**:
+- 现有:N 个任务 × 1 次 SET(timeout键)
+- 新增:N 个任务 × 3 次 SET(retry键 + max_retry键 + process_timeout键)
+- **影响**:每个任务的 SET 次数由 1 次增至 4 次(增加 3 倍),但使用 Pipeline 批量操作,延迟可控
+
+### 超时处理延迟
+
+- 现有流程:超时 → 直接 complete(~10ms)
+- 新流程:超时 → 读取配置 → requeue(~50ms)
+- **影响**:可接受
+
+### 队列长度
+
+- 重试任务会放回队列头部,不影响正常任务处理
+- 建议设置合理的 `processMaxRetries`(≤10)或使用任务过期时间控制
diff --git a/api/pom.xml b/api/pom.xml
index 0794828..fbeca81 100644
--- a/api/pom.xml
+++ b/api/pom.xml
@@ -37,8 +37,8 @@
1.2.73
- 0.24.1
-
+ 0.24.3
+
30.1-jre
5.1.5
diff --git a/api/src/main/java/com/ke/bella/batch/service/QueueService.java b/api/src/main/java/com/ke/bella/batch/service/QueueService.java
index 90d6df8..ea8e9b7 100644
--- a/api/src/main/java/com/ke/bella/batch/service/QueueService.java
+++ b/api/src/main/java/com/ke/bella/batch/service/QueueService.java
@@ -24,6 +24,7 @@
import com.theokanning.openai.queue.Put;
import com.theokanning.openai.queue.Take;
import com.theokanning.openai.queue.Task;
+import io.micrometer.core.instrument.MeterRegistry;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import net.jodah.expiringmap.ExpirationListener;
@@ -31,7 +32,6 @@
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
-import io.micrometer.core.instrument.MeterRegistry;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@@ -89,6 +89,8 @@ public class QueueService {
private static final String LOAD_LOCK_PREFIX = "queue:load:lock:";
private static final String TIMEOUT_KEY_PREFIX = "timeout:";
+ private static final String RETRY_COUNT_PREFIX = "retry:";
+ private static final String MAX_RETRY_PREFIX = "max_retry:";
private static final String KEYSPACE_PATTERN = "__keyevent@0__:expired";
@Autowired
@@ -229,7 +231,7 @@ public Map> take(Take take) {
.collect(Collectors.toList());
if(!allTasks.isEmpty() && take.getProcessTimeout() > 0) {
- trackProcessTimeout(allTasks, take.getProcessTimeout());
+ trackTask(allTasks, take);
}
tasksByQueue.forEach((queue, tasks)
@@ -238,7 +240,7 @@ public Map> take(Take take) {
}
public void complete(String taskId, Map result) {
- untrackProcessTimeout(taskId);
+ unTrackTask(taskId);
int level = IDGenerator.parseLevel(taskId);
if(QueueLevel.isOnline(level)) {
QueueMetadataDB queueMeta = queueRepo.findMetadataById(IDGenerator.parseQueueId(taskId));
@@ -433,13 +435,15 @@ private void releaseLock(String lockKey) {
}
}
- private void trackProcessTimeout(List tasks, long timeout) {
+ private void trackTask(List tasks, Take take) {
try (Jedis jedis = jedisPool.getResource()) {
Pipeline pipeline = jedis.pipelined();
for (Task task : tasks) {
String taskId = task.getTaskId();
- pipeline.setex(TIMEOUT_KEY_PREFIX + taskId, timeout, taskId);
+ long ttl = (task.getExpireTime() - System.currentTimeMillis()) / 1000;
+ pipeline.setex(TIMEOUT_KEY_PREFIX + taskId, take.getProcessTimeout(), taskId);
+ pipeline.setex(MAX_RETRY_PREFIX + taskId, ttl, String.valueOf(take.getProcessMaxRetries()));
}
pipeline.sync();
@@ -448,12 +452,6 @@ private void trackProcessTimeout(List tasks, long timeout) {
}
}
- private void untrackProcessTimeout(String taskId) {
- try (Jedis jedis = jedisPool.getResource()) {
- jedis.del(TIMEOUT_KEY_PREFIX + taskId);
- }
- }
-
@SneakyThrows
private void taskProcessExpireListener() {
JedisPubSub pubSub = new JedisPubSub() {
@@ -464,13 +462,15 @@ public void onPMessage(String pattern, String channel, String message) {
}
String taskId = message.substring(TIMEOUT_KEY_PREFIX.length());
TaskExecutor.submit(() -> {
- try {
- Map result = new HashMap<>();
- result.put("status_code", "408");
- result.put("request_id", taskId);
- complete(taskId, result);
+ try (Jedis jedis = jedisPool.getResource()) {
+ if(shouldRetry(taskId, jedis)) {
+ reEnqueueTask(taskId, jedis);
+ } else {
+ completeWithTimeout(taskId);
+ }
} catch (Exception e) {
- log.error("Failed to handle process timeout for task: {}", taskId, e);
+ log.error("Failed to handle timeout for task: {}", taskId, e);
+ completeWithTimeout(taskId);
}
});
}
@@ -485,4 +485,64 @@ public void onPMessage(String pattern, String channel, String message) {
}
}
+    /**
+     * Decides whether a task whose processing timeout just fired should be put back on its queue.
+     *
+     * <p>Reads {@code max_retry:{taskId}} (the retry limit) and, for a positive limit,
+     * {@code retry:{taskId}} (the number of retries already recorded; a missing key counts as 0).
+     *
+     * @param taskId id of the timed-out task
+     * @param jedis  Redis connection supplied by the caller
+     * @return {@code true} when the limit is -1 (unlimited) or the recorded retry count is below
+     *         the limit; {@code false} when the limit key is missing, non-numeric, or 0
+     * @throws NumberFormatException if the stored retry count is not a valid integer
+     */
+    private boolean shouldRetry(String taskId, Jedis jedis) {
+        String maxRetryStr = jedis.get(MAX_RETRY_PREFIX + taskId);
+        if(maxRetryStr == null) {
+            return false;
+        }
+
+        final int maxRetries;
+        try {
+            maxRetries = Integer.parseInt(maxRetryStr);
+        } catch (NumberFormatException e) {
+            // trackTask stores String.valueOf(take.getProcessMaxRetries()); when that getter
+            // returns a null reference the stored text is the literal "null" (presumably the
+            // "retries not configured" case). Treat any non-numeric limit as "do not retry"
+            // instead of letting the exception reach the caller's error path.
+            return false;
+        }
+
+        if(maxRetries == 0) {
+            return false;
+        }
+        if(maxRetries == -1) {
+            return true;
+        }
+
+        String retryCountStr = jedis.get(RETRY_COUNT_PREFIX + taskId);
+        int currentRetryCount = retryCountStr != null ? Integer.parseInt(retryCountStr) : 0;
+        return currentRetryCount < maxRetries;
+    }
+
+    /**
+     * Completes a task as timed out: calls {@code complete} with a result map holding
+     * {@code status_code} = "408" and {@code request_id} = the task id.
+     * An exception raised while completing is logged and not rethrown.
+     *
+     * @param taskId id of the task to finish
+     */
+    private void completeWithTimeout(String taskId) {
+        Map timeoutResult = new HashMap<>();
+        timeoutResult.put("status_code", "408");
+        timeoutResult.put("request_id", taskId);
+        try {
+            complete(taskId, timeoutResult);
+        } catch (Exception e) {
+            log.error("Failed to complete task {} with timeout", taskId, e);
+        }
+    }
+
+    /**
+     * Puts a timed-out task back on its queue and records one more retry in Redis.
+     *
+     * <p>The retry counter ({@code retry:{taskId}}) is written before the task is added to the
+     * queue. If the Redis write fails, the task has not been re-enqueued yet, so a caller that
+     * reacts to the exception by completing the task cannot end up with a task that is both
+     * queued again and completed.
+     *
+     * @param taskId id of the task to retry
+     * @param jedis  Redis connection supplied by the caller
+     * @throws IllegalStateException if the task cannot be found, is expired, or has no whole
+     *                               second of lifetime left for the counter's TTL
+     * @throws NumberFormatException if the stored retry count is not a valid integer
+     */
+    private void reEnqueueTask(String taskId, Jedis jedis) {
+        QueueDB queueDB = queueRepo.findTask(taskId);
+        if(queueDB == null) {
+            throw new IllegalStateException("Task not found: " + taskId);
+        }
+        Task task = queueRepo.parseTask(queueDB);
+        if(task.isExpire()) {
+            throw new IllegalStateException("Task expired: " + taskId);
+        }
+
+        // TimeUtils.toSeconds returns 0 for a null timestamp and a value <= 0 when less than a
+        // second remains; SETEX rejects a non-positive TTL, so treat that as already expired.
+        long ttl = TimeUtils.toSeconds(queueDB.getExpiredAt());
+        if(ttl <= 0) {
+            throw new IllegalStateException("Task expired: " + taskId);
+        }
+
+        String retryKey = RETRY_COUNT_PREFIX + taskId;
+        String retryCountStr = jedis.get(retryKey);
+        int currentRetryCount = retryCountStr != null ? Integer.parseInt(retryCountStr) : 0;
+        jedis.setex(retryKey, ttl, String.valueOf(currentRetryCount + 1));
+
+        // Enqueue last, only after the retry has been recorded successfully.
+        getQueue(task.getFullQueueName()).add(task);
+    }
+
+    /**
+     * Deletes the Redis tracking keys of a task in one {@code DEL} call: the timeout key,
+     * the retry-count key and the max-retry key.
+     * Any exception is logged and not rethrown.
+     *
+     * @param taskId id of the task whose tracking keys are removed
+     */
+    private void unTrackTask(String taskId) {
+        final String timeoutKey = TIMEOUT_KEY_PREFIX + taskId;
+        final String retryCountKey = RETRY_COUNT_PREFIX + taskId;
+        final String maxRetryKey = MAX_RETRY_PREFIX + taskId;
+        try (Jedis jedis = jedisPool.getResource()) {
+            jedis.del(timeoutKey, retryCountKey, maxRetryKey);
+        } catch (Exception e) {
+            log.error("Failed to cleanup task tracking for task: {}", taskId, e);
+        }
+    }
+
}
diff --git a/api/src/main/java/com/ke/bella/batch/utils/TimeUtils.java b/api/src/main/java/com/ke/bella/batch/utils/TimeUtils.java
index b5cf623..93bf912 100644
--- a/api/src/main/java/com/ke/bella/batch/utils/TimeUtils.java
+++ b/api/src/main/java/com/ke/bella/batch/utils/TimeUtils.java
@@ -30,4 +30,13 @@ public static LocalDateTime parseTimestamp(String timestampStr) {
public static String formatTimestamp(LocalDateTime dateTime) {
return dateTime.format(TIMESTAMP_FORMATTER);
}
+
+    /**
+     * Returns the whole seconds between the current system time and {@code dateTime}
+     * (as converted by {@code toEpochMilli}), truncated toward zero.
+     * The result is negative when {@code dateTime} lies more than a second in the past.
+     *
+     * @param dateTime target point in time; may be {@code null}
+     * @return remaining seconds until {@code dateTime}, or 0 when {@code dateTime} is {@code null}
+     */
+    public static long toSeconds(LocalDateTime dateTime) {
+        if(dateTime != null) {
+            long remainingMillis = toEpochMilli(dateTime) - System.currentTimeMillis();
+            return remainingMillis / 1000;
+        }
+        return 0;
+    }
}
diff --git a/api/src/test/java/com/ke/bella/batch/service/QueueServiceRetryTest.java b/api/src/test/java/com/ke/bella/batch/service/QueueServiceRetryTest.java
new file mode 100644
index 0000000..39ef29e
--- /dev/null
+++ b/api/src/test/java/com/ke/bella/batch/service/QueueServiceRetryTest.java
@@ -0,0 +1,230 @@
+package com.ke.bella.batch.service;
+
+import com.ke.bella.batch.db.repo.QueueRepo;
+import com.ke.bella.batch.tables.pojos.QueueDB;
+import com.theokanning.openai.queue.Task;
+import org.junit.Test;
+import redis.clients.jedis.Jedis;
+
+import java.lang.reflect.Method;
+import java.time.LocalDateTime;
+import java.util.HashMap;
+import java.util.concurrent.BlockingQueue;
+
+import static org.junit.Assert.*;
+import static org.mockito.ArgumentMatchers.*;
+import static org.mockito.Mockito.*;
+
+/**
+ * Tests the retry-related logic in QueueService.
+ * Note: these tests invoke private methods via reflection.
+ */
+public class QueueServiceRetryTest {
+
+ /**
+ * Tests reEnqueueTask - first retry.
+ * Verifies: 1. the task exists and is not expired 2. the retry key does not exist (first retry) 3. retry:taskId is set to 1
+ */
+ @Test
+ public void testReEnqueueTask_FirstRetry() throws Exception {
+ // Prepare test data
+ String taskId = "TASK-1-1-C-260302120000-0001-000001";
+
+ // Mock objects
+ QueueRepo queueRepo = mock(QueueRepo.class);
+ Jedis jedis = mock(Jedis.class);
+ BlockingQueue mockQueue = mock(BlockingQueue.class);
+
+ // Mock the repository returning the task
+ QueueDB queueDB = createMockQueueDB(taskId, LocalDateTime.now().plusHours(1));
+ when(queueRepo.findTask(taskId)).thenReturn(queueDB);
+
+ // Mock task parsing
+ Task task = createMockTask(taskId, "test-queue", 1);
+ when(queueRepo.parseTask(queueDB)).thenReturn(task);
+
+ // Mock Redis - first retry, the retry key does not exist
+ when(jedis.get("retry:" + taskId)).thenReturn(null);
+
+ // Create the QueueService instance and inject dependencies
+ QueueService queueService = createQueueServiceWithMocks(queueRepo, mockQueue);
+
+ // Invoke the private method
+ Method method = QueueService.class.getDeclaredMethod("reEnqueueTask", String.class, Jedis.class);
+ method.setAccessible(true);
+ method.invoke(queueService, taskId, jedis);
+
+ // Verify the task was re-enqueued
+ verify(mockQueue).add(task);
+
+ // Verify the retry count was set to 1
+ verify(jedis).setex(eq("retry:" + taskId), anyLong(), eq("1"));
+ }
+
+ /**
+ * Tests reEnqueueTask - second retry.
+ * Verifies: 1. the task exists and is not expired 2. the retry key exists with value 1 3. retry:taskId is updated to 2
+ */
+ @Test
+ public void testReEnqueueTask_SecondRetry() throws Exception {
+ String taskId = "TASK-1-1-C-260302120000-0001-000001";
+
+ QueueRepo queueRepo = mock(QueueRepo.class);
+ Jedis jedis = mock(Jedis.class);
+ BlockingQueue mockQueue = mock(BlockingQueue.class);
+
+ QueueDB queueDB = createMockQueueDB(taskId, LocalDateTime.now().plusHours(1));
+ when(queueRepo.findTask(taskId)).thenReturn(queueDB);
+
+ Task task = createMockTask(taskId, "test-queue", 1);
+ when(queueRepo.parseTask(queueDB)).thenReturn(task);
+
+ // Mock Redis - second retry, the retry key holds 1
+ when(jedis.get("retry:" + taskId)).thenReturn("1");
+
+ QueueService queueService = createQueueServiceWithMocks(queueRepo, mockQueue);
+
+ Method method = QueueService.class.getDeclaredMethod("reEnqueueTask", String.class, Jedis.class);
+ method.setAccessible(true);
+ method.invoke(queueService, taskId, jedis);
+
+ verify(mockQueue).add(task);
+ verify(jedis).setex(eq("retry:" + taskId), anyLong(), eq("2"));
+ }
+
+ /**
+ * Tests reEnqueueTask - task does not exist.
+ * Verifies: 1. the repository returns null 2. an IllegalStateException is thrown
+ */
+ @Test
+ public void testReEnqueueTask_TaskNotFound() throws Exception {
+ String taskId = "TASK-1-1-C-260302120000-0001-000001";
+
+ QueueRepo queueRepo = mock(QueueRepo.class);
+ Jedis jedis = mock(Jedis.class);
+
+ when(queueRepo.findTask(taskId)).thenReturn(null);
+
+ QueueService queueService = createQueueServiceWithMocks(queueRepo, null);
+
+ Method method = QueueService.class.getDeclaredMethod("reEnqueueTask", String.class, Jedis.class);
+ method.setAccessible(true);
+
+ try {
+ method.invoke(queueService, taskId, jedis);
+ fail("Should throw IllegalStateException");
+ } catch (Exception e) {
+ assertTrue(e.getCause() instanceof IllegalStateException);
+ assertTrue(e.getCause().getMessage().contains("Task not found"));
+ }
+ }
+
+ /**
+ * Tests reEnqueueTask - task already expired.
+ * Verifies: 1. the task's expireTime is in the past 2. an IllegalStateException is thrown
+ */
+ @Test
+ public void testReEnqueueTask_TaskExpired() throws Exception {
+ String taskId = "TASK-1-1-C-260302120000-0001-000001";
+
+ QueueRepo queueRepo = mock(QueueRepo.class);
+ Jedis jedis = mock(Jedis.class);
+
+ QueueDB queueDB = createMockQueueDB(taskId, LocalDateTime.now().minusHours(1));
+ when(queueRepo.findTask(taskId)).thenReturn(queueDB);
+
+ Task task = createMockTask(taskId, "test-queue", 1);
+ task.setExpireTime(System.currentTimeMillis() - 3600000); // expired 1 hour ago
+ when(queueRepo.parseTask(queueDB)).thenReturn(task);
+
+ QueueService queueService = createQueueServiceWithMocks(queueRepo, null);
+
+ Method method = QueueService.class.getDeclaredMethod("reEnqueueTask", String.class, Jedis.class);
+ method.setAccessible(true);
+
+ try {
+ method.invoke(queueService, taskId, jedis);
+ fail("Should throw IllegalStateException");
+ } catch (Exception e) {
+ assertTrue(e.getCause() instanceof IllegalStateException);
+ assertTrue(e.getCause().getMessage().contains("Task expired"));
+ }
+ }
+
+ /**
+ * Tests unTrackTask - normal cleanup.
+ * Verifies: 1. unTrackTask is invoked 2. all tracking keys are deleted
+ */
+ @Test
+ public void testUnTrackTask_Success() throws Exception {
+ String taskId = "TASK-1-1-C-260302120000-0001-000001";
+
+ Jedis jedis = mock(Jedis.class);
+ redis.clients.jedis.JedisPool jedisPool = mock(redis.clients.jedis.JedisPool.class);
+ when(jedisPool.getResource()).thenReturn(jedis);
+
+ QueueService queueService = new QueueService();
+ injectField(queueService, "jedisPool", jedisPool);
+
+ Method method = QueueService.class.getDeclaredMethod("unTrackTask", String.class);
+ method.setAccessible(true);
+ method.invoke(queueService, taskId);
+
+ // Verify all tracking keys were deleted
+ verify(jedis).del(
+ eq("timeout:" + taskId),
+ eq("retry:" + taskId),
+ eq("max_retry:" + taskId)
+ );
+
+ // Verify the Jedis resource was closed
+ verify(jedis).close();
+ }
+
+ // ============== Helper methods ==============
+
+ private QueueService createQueueServiceWithMocks(QueueRepo queueRepo, BlockingQueue mockQueue) throws Exception {
+ QueueService queueService = new QueueService();
+ injectField(queueService, "queueRepo", queueRepo);
+
+ if(mockQueue != null) {
+ // Inject QUEUE_CACHE
+ java.lang.reflect.Field cacheField = QueueService.class.getDeclaredField("QUEUE_CACHE");
+ cacheField.setAccessible(true);
+
+ @SuppressWarnings("unchecked")
+ com.google.common.cache.Cache> cache =
+ com.google.common.cache.CacheBuilder.newBuilder().build();
+
+ cache.put("test-queue:1", mockQueue);
+ cacheField.set(queueService, cache);
+ }
+
+ return queueService;
+ }
+
+ private void injectField(Object target, String fieldName, Object value) throws Exception {
+ java.lang.reflect.Field field = QueueService.class.getDeclaredField(fieldName);
+ field.setAccessible(true);
+ field.set(target, value);
+ }
+
+ private QueueDB createMockQueueDB(String taskId, LocalDateTime expiredAt) {
+ QueueDB queueDB = new QueueDB();
+ queueDB.setTaskId(taskId);
+ queueDB.setQueue("test-queue");
+ queueDB.setExpiredAt(expiredAt);
+ queueDB.setStatus("queued");
+ return queueDB;
+ }
+
+ private Task createMockTask(String taskId, String queue, int level) {
+ Task task = new Task();
+ task.setTaskId(taskId);
+ task.setQueue(queue);
+ task.setLevel(level);
+ task.setExpireTime(System.currentTimeMillis() + 3600000); // expires in 1 hour
+ task.setData(new HashMap<>());
+ return task;
+ }
+}
diff --git a/api/src/test/java/com/ke/bella/batch/service/QueueServiceTest.java b/api/src/test/java/com/ke/bella/batch/service/QueueServiceTest.java
index 1d28eda..c34e204 100644
--- a/api/src/test/java/com/ke/bella/batch/service/QueueServiceTest.java
+++ b/api/src/test/java/com/ke/bella/batch/service/QueueServiceTest.java
@@ -10,6 +10,7 @@
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
@@ -20,11 +21,19 @@
import java.util.Map;
import java.util.concurrent.BlockingQueue;
-import static org.junit.Assert.*;
-import static org.mockito.ArgumentMatchers.*;
-import static org.mockito.Mockito.*;
-
-import org.mockito.ArgumentCaptor;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.mockStatic;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public class QueueServiceTest {