diff --git a/api/docs/developer/PROCESS-MAX-RETRIES-PLAN.md b/api/docs/developer/PROCESS-MAX-RETRIES-PLAN.md new file mode 100644 index 0000000..2e15cfb --- /dev/null +++ b/api/docs/developer/PROCESS-MAX-RETRIES-PLAN.md @@ -0,0 +1,109 @@ +# 任务最大重试次数功能实现计划 + +## 一、需求背景 + +### 1.1 当前系统现状 + +当前系统已实现了基于 `processTimeout` 的超时机制: +- Worker 通过 `take` 接口拉取任务时可指定 `processTimeout` +- 超时后任务自动标记为完成(status_code: 408) +- 使用 Redis 键过期 + Keyspace Notifications 实现 + +**现有问题**: +- 任务超时后直接标记为完成,无法重试 +- 对于临时性失败(网络抖动、服务暂时不可用),无法自动恢复 +- 用户需要手动重新提交失败的任务 + +### 1.2 新需求 + +在 `take` 接口中增加 `processMaxRetries` 参数,实现以下功能: +- 指定任务的最大重试次数(例如:3次) +- 任务超时后不直接完成,而是重新放回队列 +- 达到最大重试次数后才标记为最终失败 +- 每次重试都有独立的超时时间 +- +--- + +## 二、设计方案 + +### 2.1 核心思路 + +**基于现有超时机制扩展**: +- 复用现有的 Redis 键过期 + Keyspace Notifications 机制 +- 在 Redis 中记录任务的当前重试次数和最大重试次数 +- 超时时的处理逻辑: + - 检查任务是否已过期(expireTime),已过期则直接失败,不重试 + - 检查重试次数是否达到上限,未达到则重新入队,达到上限则标记失败 + +### 2.2 处理流程 + +1. **Take 阶段**:初始化重试追踪(retry=0, max_retry=N, process_timeout=T) +2. **超时触发**:监听 timeout 键过期事件 +3. **过期检查**:检查任务是否已过期(expireTime),已过期则直接标记失败,不重试 +4. **判断重试**:检查当前重试次数是否达到上限(-1 表示无限重试) +5. **执行动作**:未达上限则重新入队并增加重试计数,达到上限则标记失败并清理追踪键 + +--- + +## 三、详细设计 + +### 3.1 数据模型 + +#### Take 类新增字段 + +**文件**: `openai-api/src/main/java/com/theokanning/openai/queue/Take.java` + +```java +private Integer processMaxRetries; // null/0: 不重试, -1: 无限重试, >0: 具体次数 +``` + +#### Redis 数据结构 + +**新增键**: +``` +retry:{taskId} # 当前重试次数(从1开始) +max_retry:{taskId} # 最大重试次数(-1表示无限) +process_timeout:{taskId} # 原始超时值(用于重新设置) +``` + +**现有键**: +``` +timeout:{taskId} # 超时控制键(已存在) +``` + +### 3.2 核心实现 + +**涉及文件**:`QueueService.java` + +**主要改动**: +1. **Take 流程**:增加参数校验和重试追踪初始化 +2. **超时处理**:改造监听器回调,增加重试逻辑判断 +3. **任务重入队**:查询任务,检查过期,更新 Redis 队列和重试计数 +4. **Complete 流程**:清理重试追踪键 + +**关键约束**: +- 在线队列(level = 0)不支持重试 +- 任务已过期(expireTime)则不重试,直接标记失败 +- 数据库状态保持不变,仅操作 Redis + +--- + +## 四、性能影响评估 + +### Redis 操作增量 + +**每次 take(启用重试)**: +- 现有:N 个任务 × 1 次 SET(timeout键) +- 新增:N 个任务 × 3 次 SET(retry键 + max_retry键 + process_timeout键) +- **影响**:Redis 操作量增加 3倍,但使用 Pipeline 批量操作,延迟可控 + +### 超时处理延迟 + +- 现有流程:超时 → 直接 complete(~10ms) +- 新流程:超时 → 读取配置 → requeue(~50ms) +- **影响**:可接受 + +### 队列长度 + +- 重试任务会放回队列头部,不影响正常任务处理 +- 建议设置合理的 `processMaxRetries`(≤10)或使用任务过期时间控制 diff --git a/api/pom.xml b/api/pom.xml index 0794828..fbeca81 100644 --- a/api/pom.xml +++ b/api/pom.xml @@ -37,8 +37,8 @@ 1.2.73 - 0.24.1 - + 0.24.3 + 30.1-jre 5.1.5 diff --git a/api/src/main/java/com/ke/bella/batch/service/QueueService.java b/api/src/main/java/com/ke/bella/batch/service/QueueService.java index 90d6df8..ea8e9b7 100644 --- a/api/src/main/java/com/ke/bella/batch/service/QueueService.java +++ b/api/src/main/java/com/ke/bella/batch/service/QueueService.java @@ -24,6 +24,7 @@ import com.theokanning.openai.queue.Put; import com.theokanning.openai.queue.Take; import com.theokanning.openai.queue.Task; +import io.micrometer.core.instrument.MeterRegistry; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import net.jodah.expiringmap.ExpirationListener; @@ -31,7 +32,6 @@ import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.MapUtils; import org.apache.commons.lang3.StringUtils; -import io.micrometer.core.instrument.MeterRegistry; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; @@ -89,6 +89,8 @@ public class QueueService { private static final String LOAD_LOCK_PREFIX = "queue:load:lock:"; private static final String TIMEOUT_KEY_PREFIX = "timeout:"; + private static final String RETRY_COUNT_PREFIX = "retry:"; + private static final String MAX_RETRY_PREFIX = "max_retry:"; private static final String KEYSPACE_PATTERN = "__keyevent@0__:expired"; @Autowired @@ -229,7 +231,7 @@ public Map> take(Take take) { .collect(Collectors.toList()); if(!allTasks.isEmpty() && take.getProcessTimeout() > 0) { - trackProcessTimeout(allTasks, take.getProcessTimeout()); + trackTask(allTasks, take); } tasksByQueue.forEach((queue, tasks) @@ -238,7 +240,7 @@ public Map> take(Take take) { } public void complete(String taskId, Map result) { - untrackProcessTimeout(taskId); + unTrackTask(taskId); int level = IDGenerator.parseLevel(taskId); if(QueueLevel.isOnline(level)) { QueueMetadataDB queueMeta = queueRepo.findMetadataById(IDGenerator.parseQueueId(taskId)); @@ -433,13 +435,15 @@ private void releaseLock(String lockKey) { } } - private void trackProcessTimeout(List tasks, long timeout) { + private void trackTask(List tasks, Take take) { try (Jedis jedis = jedisPool.getResource()) { Pipeline pipeline = jedis.pipelined(); for (Task task : tasks) { String taskId = task.getTaskId(); - pipeline.setex(TIMEOUT_KEY_PREFIX + taskId, timeout, taskId); + long ttl = (task.getExpireTime() - System.currentTimeMillis()) / 1000; + pipeline.setex(TIMEOUT_KEY_PREFIX + taskId, take.getProcessTimeout(), taskId); + pipeline.setex(MAX_RETRY_PREFIX + taskId, ttl, String.valueOf(take.getProcessMaxRetries())); } pipeline.sync(); @@ -448,12 +452,6 @@ private void trackProcessTimeout(List tasks, long timeout) { } } - private void untrackProcessTimeout(String taskId) { - try (Jedis jedis = jedisPool.getResource()) { - jedis.del(TIMEOUT_KEY_PREFIX + taskId); - } - } - @SneakyThrows private void taskProcessExpireListener() { JedisPubSub pubSub = new JedisPubSub() { @@ -464,13 +462,15 @@ public void onPMessage(String pattern, String channel, String message) { } String taskId = message.substring(TIMEOUT_KEY_PREFIX.length()); TaskExecutor.submit(() -> { - try { - Map result = new HashMap<>(); - result.put("status_code", "408"); - result.put("request_id", taskId); - complete(taskId, result); + try (Jedis jedis = jedisPool.getResource()) { + if(shouldRetry(taskId, jedis)) { + reEnqueueTask(taskId, jedis); + } else { + completeWithTimeout(taskId); + } } catch (Exception e) { - log.error("Failed to handle process timeout for task: {}", taskId, e); + log.error("Failed to handle timeout for task: {}", taskId, e); + completeWithTimeout(taskId); } }); } @@ -485,4 +485,64 @@ public void onPMessage(String pattern, String channel, String message) { } } + private boolean shouldRetry(String taskId, Jedis jedis) { + String maxRetryStr = jedis.get(MAX_RETRY_PREFIX + taskId); + if(maxRetryStr == null) { + return false; + } + int maxRetries = Integer.parseInt(maxRetryStr); + if(maxRetries == 0) { + return false; + } + + if(maxRetries == -1) { + return true; + } + + String retryCountStr = jedis.get(RETRY_COUNT_PREFIX + taskId); + int currentRetryCount = retryCountStr != null ? Integer.parseInt(retryCountStr) : 0; + return currentRetryCount < maxRetries; + } + + private void completeWithTimeout(String taskId) { + try { + Map result = new HashMap<>(); + result.put("status_code", "408"); + result.put("request_id", taskId); + complete(taskId, result); + } catch (Exception e) { + log.error("Failed to complete task {} with timeout", taskId, e); + } + } + + private void reEnqueueTask(String taskId, Jedis jedis) { + QueueDB queueDB = queueRepo.findTask(taskId); + if(queueDB == null) { + throw new IllegalStateException("Task not found: " + taskId); + } + Task task = queueRepo.parseTask(queueDB); + if(task.isExpire()) { + throw new IllegalStateException("Task expired: " + taskId); + } + + getQueue(task.getFullQueueName()).add(task); + + String retryCountStr = jedis.get(RETRY_COUNT_PREFIX + taskId); + int currentRetryCount = retryCountStr != null ? Integer.parseInt(retryCountStr) : 0; + long ttl = TimeUtils.toSeconds(queueDB.getExpiredAt()); + jedis.setex(RETRY_COUNT_PREFIX + taskId, ttl, String.valueOf(currentRetryCount + 1)); + } + + private void unTrackTask(String taskId) { + try (Jedis jedis = jedisPool.getResource()) { + jedis.del( + TIMEOUT_KEY_PREFIX + taskId, + RETRY_COUNT_PREFIX + taskId, + MAX_RETRY_PREFIX + taskId + ); + } catch (Exception e) { + log.error("Failed to cleanup task tracking for task: {}", taskId, e); + } + } + } diff --git a/api/src/main/java/com/ke/bella/batch/utils/TimeUtils.java b/api/src/main/java/com/ke/bella/batch/utils/TimeUtils.java index b5cf623..93bf912 100644 --- a/api/src/main/java/com/ke/bella/batch/utils/TimeUtils.java +++ b/api/src/main/java/com/ke/bella/batch/utils/TimeUtils.java @@ -30,4 +30,13 @@ public static LocalDateTime parseTimestamp(String timestampStr) { public static String formatTimestamp(LocalDateTime dateTime) { return dateTime.format(TIMESTAMP_FORMATTER); } + + public static long toSeconds(LocalDateTime dateTime) { + if(dateTime == null) { + return 0; + } + long targetEpochMilli = toEpochMilli(dateTime); + long currentEpochMilli = System.currentTimeMillis(); + return (targetEpochMilli - currentEpochMilli) / 1000; + } } diff --git a/api/src/test/java/com/ke/bella/batch/service/QueueServiceRetryTest.java b/api/src/test/java/com/ke/bella/batch/service/QueueServiceRetryTest.java new file mode 100644 index 0000000..39ef29e --- /dev/null +++ b/api/src/test/java/com/ke/bella/batch/service/QueueServiceRetryTest.java @@ -0,0 +1,230 @@ +package com.ke.bella.batch.service; + +import com.ke.bella.batch.db.repo.QueueRepo; +import com.ke.bella.batch.tables.pojos.QueueDB; +import com.theokanning.openai.queue.Task; +import org.junit.Test; +import redis.clients.jedis.Jedis; + +import java.lang.reflect.Method; +import java.time.LocalDateTime; +import java.util.HashMap; +import java.util.concurrent.BlockingQueue; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +/** + * 测试 QueueService 中的重试相关逻辑 + * 注意:这些测试使用反射调用私有方法 + */ +public class QueueServiceRetryTest { + + /** + * 测试 reEnqueueTask 方法 - 首次重试场景 + * 验证逻辑: 1. 任务存在且未过期 2. retry 键不存在(首次重试) 3. 应该设置 retry:taskId = 1 + */ + @Test + public void testReEnqueueTask_FirstRetry() throws Exception { + // 准备测试数据 + String taskId = "TASK-1-1-C-260302120000-0001-000001"; + + // Mock 对象 + QueueRepo queueRepo = mock(QueueRepo.class); + Jedis jedis = mock(Jedis.class); + BlockingQueue mockQueue = mock(BlockingQueue.class); + + // Mock 数据库返回任务 + QueueDB queueDB = createMockQueueDB(taskId, LocalDateTime.now().plusHours(1)); + when(queueRepo.findTask(taskId)).thenReturn(queueDB); + + // Mock 解析任务 + Task task = createMockTask(taskId, "test-queue", 1); + when(queueRepo.parseTask(queueDB)).thenReturn(task); + + // Mock Redis 操作 - 首次重试,retry 键不存在 + when(jedis.get("retry:" + taskId)).thenReturn(null); + + // 创建 QueueService 实例并注入依赖 + QueueService queueService = createQueueServiceWithMocks(queueRepo, mockQueue); + + // 调用私有方法 + Method method = QueueService.class.getDeclaredMethod("reEnqueueTask", String.class, Jedis.class); + method.setAccessible(true); + method.invoke(queueService, taskId, jedis); + + // 验证任务被重新入队 + verify(mockQueue).add(task); + + // 验证重试次数被设置为 1 + verify(jedis).setex(eq("retry:" + taskId), anyLong(), eq("1")); + } + + /** + * 测试 reEnqueueTask 方法 - 第二次重试场景 + * 验证逻辑: 1. 任务存在且未过期 2. retry 键已存在且值为 1 3. 应该更新 retry:taskId = 2 + */ + @Test + public void testReEnqueueTask_SecondRetry() throws Exception { + String taskId = "TASK-1-1-C-260302120000-0001-000001"; + + QueueRepo queueRepo = mock(QueueRepo.class); + Jedis jedis = mock(Jedis.class); + BlockingQueue mockQueue = mock(BlockingQueue.class); + + QueueDB queueDB = createMockQueueDB(taskId, LocalDateTime.now().plusHours(1)); + when(queueRepo.findTask(taskId)).thenReturn(queueDB); + + Task task = createMockTask(taskId, "test-queue", 1); + when(queueRepo.parseTask(queueDB)).thenReturn(task); + + // Mock Redis 操作 - 第二次重试,retry 键值为 1 + when(jedis.get("retry:" + taskId)).thenReturn("1"); + + QueueService queueService = createQueueServiceWithMocks(queueRepo, mockQueue); + + Method method = QueueService.class.getDeclaredMethod("reEnqueueTask", String.class, Jedis.class); + method.setAccessible(true); + method.invoke(queueService, taskId, jedis); + + verify(mockQueue).add(task); + verify(jedis).setex(eq("retry:" + taskId), anyLong(), eq("2")); + } + + /** + * 测试 reEnqueueTask 方法 - 任务不存在 + * 验证逻辑: 1. 数据库返回 null 2. 应该抛出 IllegalStateException + */ + @Test + public void testReEnqueueTask_TaskNotFound() throws Exception { + String taskId = "TASK-1-1-C-260302120000-0001-000001"; + + QueueRepo queueRepo = mock(QueueRepo.class); + Jedis jedis = mock(Jedis.class); + + when(queueRepo.findTask(taskId)).thenReturn(null); + + QueueService queueService = createQueueServiceWithMocks(queueRepo, null); + + Method method = QueueService.class.getDeclaredMethod("reEnqueueTask", String.class, Jedis.class); + method.setAccessible(true); + + try { + method.invoke(queueService, taskId, jedis); + fail("Should throw IllegalStateException"); + } catch (Exception e) { + assertTrue(e.getCause() instanceof IllegalStateException); + assertTrue(e.getCause().getMessage().contains("Task not found")); + } + } + + /** + * 测试 reEnqueueTask 方法 - 任务已过期 + * 验证逻辑: 1. 任务的 expireTime 已经过期 2. 应该抛出 IllegalStateException + */ + @Test + public void testReEnqueueTask_TaskExpired() throws Exception { + String taskId = "TASK-1-1-C-260302120000-0001-000001"; + + QueueRepo queueRepo = mock(QueueRepo.class); + Jedis jedis = mock(Jedis.class); + + QueueDB queueDB = createMockQueueDB(taskId, LocalDateTime.now().minusHours(1)); + when(queueRepo.findTask(taskId)).thenReturn(queueDB); + + Task task = createMockTask(taskId, "test-queue", 1); + task.setExpireTime(System.currentTimeMillis() - 3600000); // 1小时前过期 + when(queueRepo.parseTask(queueDB)).thenReturn(task); + + QueueService queueService = createQueueServiceWithMocks(queueRepo, null); + + Method method = QueueService.class.getDeclaredMethod("reEnqueueTask", String.class, Jedis.class); + method.setAccessible(true); + + try { + method.invoke(queueService, taskId, jedis); + fail("Should throw IllegalStateException"); + } catch (Exception e) { + assertTrue(e.getCause() instanceof IllegalStateException); + assertTrue(e.getCause().getMessage().contains("Task expired")); + } + } + + /** + * 测试 unTrackTask 方法 - 正常清理 + * 验证逻辑: 1. 调用 unTrackTask 2. 应该删除所有追踪键 + */ + @Test + public void testUnTrackTask_Success() throws Exception { + String taskId = "TASK-1-1-C-260302120000-0001-000001"; + + Jedis jedis = mock(Jedis.class); + redis.clients.jedis.JedisPool jedisPool = mock(redis.clients.jedis.JedisPool.class); + when(jedisPool.getResource()).thenReturn(jedis); + + QueueService queueService = new QueueService(); + injectField(queueService, "jedisPool", jedisPool); + + Method method = QueueService.class.getDeclaredMethod("unTrackTask", String.class); + method.setAccessible(true); + method.invoke(queueService, taskId); + + // 验证删除了所有追踪键 + verify(jedis).del( + eq("timeout:" + taskId), + eq("retry:" + taskId), + eq("max_retry:" + taskId) + ); + + // 验证 Jedis 资源被关闭 + verify(jedis).close(); + } + + // ============== 辅助方法 ============== + + private QueueService createQueueServiceWithMocks(QueueRepo queueRepo, BlockingQueue mockQueue) throws Exception { + QueueService queueService = new QueueService(); + injectField(queueService, "queueRepo", queueRepo); + + if(mockQueue != null) { + // 注入 QUEUE_CACHE + java.lang.reflect.Field cacheField = QueueService.class.getDeclaredField("QUEUE_CACHE"); + cacheField.setAccessible(true); + + @SuppressWarnings("unchecked") + com.google.common.cache.Cache> cache = + com.google.common.cache.CacheBuilder.newBuilder().build(); + + cache.put("test-queue:1", mockQueue); + cacheField.set(queueService, cache); + } + + return queueService; + } + + private void injectField(Object target, String fieldName, Object value) throws Exception { + java.lang.reflect.Field field = QueueService.class.getDeclaredField(fieldName); + field.setAccessible(true); + field.set(target, value); + } + + private QueueDB createMockQueueDB(String taskId, LocalDateTime expiredAt) { + QueueDB queueDB = new QueueDB(); + queueDB.setTaskId(taskId); + queueDB.setQueue("test-queue"); + queueDB.setExpiredAt(expiredAt); + queueDB.setStatus("queued"); + return queueDB; + } + + private Task createMockTask(String taskId, String queue, int level) { + Task task = new Task(); + task.setTaskId(taskId); + task.setQueue(queue); + task.setLevel(level); + task.setExpireTime(System.currentTimeMillis() + 3600000); // 1小时后过期 + task.setData(new HashMap<>()); + return task; + } +} diff --git a/api/src/test/java/com/ke/bella/batch/service/QueueServiceTest.java b/api/src/test/java/com/ke/bella/batch/service/QueueServiceTest.java index 1d28eda..c34e204 100644 --- a/api/src/test/java/com/ke/bella/batch/service/QueueServiceTest.java +++ b/api/src/test/java/com/ke/bella/batch/service/QueueServiceTest.java @@ -10,6 +10,7 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; @@ -20,11 +21,19 @@ import java.util.Map; import java.util.concurrent.BlockingQueue; -import static org.junit.Assert.*; -import static org.mockito.ArgumentMatchers.*; -import static org.mockito.Mockito.*; - -import org.mockito.ArgumentCaptor; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) public class QueueServiceTest {