Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions api/docs/developer/PROCESS-MAX-RETRIES-PLAN.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# 任务最大重试次数功能实现计划

## 一、需求背景

### 1.1 当前系统现状

当前系统已实现了基于 `processTimeout` 的超时机制:
- Worker 通过 `take` 接口拉取任务时可指定 `processTimeout`
- 超时后任务自动标记为完成(status_code: 408)
- 使用 Redis 键过期 + Keyspace Notifications 实现

**现有问题**:
- 任务超时后直接标记为完成,无法重试
- 对于临时性失败(网络抖动、服务暂时不可用),无法自动恢复
- 用户需要手动重新提交失败的任务

### 1.2 新需求

在 `take` 接口中增加 `processMaxRetries` 参数,实现以下功能:
- 指定任务的最大重试次数(例如:3次)
- 任务超时后不直接完成,而是重新放回队列
- 达到最大重试次数后才标记为最终失败
- 每次重试都有独立的超时时间
-
---

## 二、设计方案

### 2.1 核心思路

**基于现有超时机制扩展**:
- 复用现有的 Redis 键过期 + Keyspace Notifications 机制
- 在 Redis 中记录任务的当前重试次数和最大重试次数
- 超时时的处理逻辑:
- 检查任务是否已过期(expireTime),已过期则直接失败,不重试
- 检查重试次数是否达到上限,未达到则重新入队,达到上限则标记失败

### 2.2 处理流程

1. **Take 阶段**:初始化重试追踪(retry=0, max_retry=N, process_timeout=T)
2. **超时触发**:监听 timeout 键过期事件
3. **过期检查**:检查任务是否已过期(expireTime),已过期则直接标记失败,不重试
4. **判断重试**:检查当前重试次数是否达到上限(-1 表示无限重试)
5. **执行动作**:未达上限则重新入队并增加重试计数,达到上限则标记失败并清理追踪键

---

## 三、详细设计

### 3.1 数据模型

#### Take 类新增字段

**文件**: `openai-api/src/main/java/com/theokanning/openai/queue/Take.java`

```java
private Integer processMaxRetries; // null/0: 不重试, -1: 无限重试, >0: 具体次数
```

#### Redis 数据结构

**新增键**:
```
retry:{taskId} # 当前重试次数(从1开始)
max_retry:{taskId} # 最大重试次数(-1表示无限)
process_timeout:{taskId} # 原始超时值(用于重新设置)
```

**现有键**:
```
timeout:{taskId} # 超时控制键(已存在)
```

### 3.2 核心实现

**涉及文件**:`QueueService.java`

**主要改动**:
1. **Take 流程**:增加参数校验和重试追踪初始化
2. **超时处理**:改造监听器回调,增加重试逻辑判断
3. **任务重入队**:查询任务,检查过期,更新 Redis 队列和重试计数
4. **Complete 流程**:清理重试追踪键

**关键约束**:
- 在线队列(level = 0)不支持重试
- 任务已过期(expireTime)则不重试,直接标记失败
- 数据库状态保持不变,仅操作 Redis

---

## 四、性能影响评估

### Redis 操作增量

**每次 take(启用重试)**:
- 现有:N 个任务 × 1 次 SET(timeout键)
- 新增:N 个任务 × 3 次 SET(retry键 + max_retry键 + process_timeout键)
- **影响**:Redis 操作量增加 3倍,但使用 Pipeline 批量操作,延迟可控

### 超时处理延迟

- 现有流程:超时 → 直接 complete(~10ms)
- 新流程:超时 → 读取配置 → requeue(~50ms)
- **影响**:可接受

### 队列长度

- 重试任务会放回队列头部,不影响正常任务处理
- 建议设置合理的 `processMaxRetries`(≤10)或使用任务过期时间控制
4 changes: 2 additions & 2 deletions api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@

<!-- Bella Framework Versions -->
<bella-openapi.version>1.2.73</bella-openapi.version>
<bella-openai.version>0.24.1</bella-openai.version>
<bella-openai.version>0.24.3</bella-openai.version>

<!-- Utility Library Versions -->
<guava.version>30.1-jre</guava.version>
<jedis.version>5.1.5</jedis.version>
Expand Down
94 changes: 77 additions & 17 deletions api/src/main/java/com/ke/bella/batch/service/QueueService.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
import com.theokanning.openai.queue.Put;
import com.theokanning.openai.queue.Take;
import com.theokanning.openai.queue.Task;
import io.micrometer.core.instrument.MeterRegistry;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import net.jodah.expiringmap.ExpirationListener;
import net.jodah.expiringmap.ExpiringMap;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import io.micrometer.core.instrument.MeterRegistry;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
Expand Down Expand Up @@ -89,6 +89,8 @@ public class QueueService {

private static final String LOAD_LOCK_PREFIX = "queue:load:lock:";
private static final String TIMEOUT_KEY_PREFIX = "timeout:";
private static final String RETRY_COUNT_PREFIX = "retry:";
private static final String MAX_RETRY_PREFIX = "max_retry:";
private static final String KEYSPACE_PATTERN = "__keyevent@0__:expired";

@Autowired
Expand Down Expand Up @@ -229,7 +231,7 @@ public Map<String, List<Task>> take(Take take) {
.collect(Collectors.toList());

if(!allTasks.isEmpty() && take.getProcessTimeout() > 0) {
trackProcessTimeout(allTasks, take.getProcessTimeout());
trackTask(allTasks, take);
}

tasksByQueue.forEach((queue, tasks)
Expand All @@ -238,7 +240,7 @@ public Map<String, List<Task>> take(Take take) {
}

public void complete(String taskId, Map<String, Object> result) {
untrackProcessTimeout(taskId);
unTrackTask(taskId);
int level = IDGenerator.parseLevel(taskId);
if(QueueLevel.isOnline(level)) {
QueueMetadataDB queueMeta = queueRepo.findMetadataById(IDGenerator.parseQueueId(taskId));
Expand Down Expand Up @@ -433,13 +435,15 @@ private void releaseLock(String lockKey) {
}
}

private void trackProcessTimeout(List<Task> tasks, long timeout) {
private void trackTask(List<Task> tasks, Take take) {
try (Jedis jedis = jedisPool.getResource()) {
Pipeline pipeline = jedis.pipelined();

for (Task task : tasks) {
String taskId = task.getTaskId();
pipeline.setex(TIMEOUT_KEY_PREFIX + taskId, timeout, taskId);
long ttl = (task.getExpireTime() - System.currentTimeMillis()) / 1000;
pipeline.setex(TIMEOUT_KEY_PREFIX + taskId, take.getProcessTimeout(), taskId);
pipeline.setex(MAX_RETRY_PREFIX + taskId, ttl, String.valueOf(take.getProcessMaxRetries()));
}

pipeline.sync();
Expand All @@ -448,12 +452,6 @@ private void trackProcessTimeout(List<Task> tasks, long timeout) {
}
}

private void untrackProcessTimeout(String taskId) {
try (Jedis jedis = jedisPool.getResource()) {
jedis.del(TIMEOUT_KEY_PREFIX + taskId);
}
}

@SneakyThrows
private void taskProcessExpireListener() {
JedisPubSub pubSub = new JedisPubSub() {
Expand All @@ -464,13 +462,15 @@ public void onPMessage(String pattern, String channel, String message) {
}
String taskId = message.substring(TIMEOUT_KEY_PREFIX.length());
TaskExecutor.submit(() -> {
try {
Map<String, Object> result = new HashMap<>();
result.put("status_code", "408");
result.put("request_id", taskId);
complete(taskId, result);
try (Jedis jedis = jedisPool.getResource()) {
if(shouldRetry(taskId, jedis)) {
reEnqueueTask(taskId, jedis);
} else {
completeWithTimeout(taskId);
}
} catch (Exception e) {
log.error("Failed to handle process timeout for task: {}", taskId, e);
log.error("Failed to handle timeout for task: {}", taskId, e);
completeWithTimeout(taskId);
}
});
}
Expand All @@ -485,4 +485,64 @@ public void onPMessage(String pattern, String channel, String message) {
}
}

private boolean shouldRetry(String taskId, Jedis jedis) {
String maxRetryStr = jedis.get(MAX_RETRY_PREFIX + taskId);
if(maxRetryStr == null) {
return false;
}
int maxRetries = Integer.parseInt(maxRetryStr);
if(maxRetries == 0) {
return false;
}

if(maxRetries == -1) {
return true;
}

String retryCountStr = jedis.get(RETRY_COUNT_PREFIX + taskId);
int currentRetryCount = retryCountStr != null ? Integer.parseInt(retryCountStr) : 0;
return currentRetryCount < maxRetries;
}

private void completeWithTimeout(String taskId) {
try {
Map<String, Object> result = new HashMap<>();
result.put("status_code", "408");
result.put("request_id", taskId);
complete(taskId, result);
} catch (Exception e) {
log.error("Failed to complete task {} with timeout", taskId, e);
}
}

private void reEnqueueTask(String taskId, Jedis jedis) {
QueueDB queueDB = queueRepo.findTask(taskId);
if(queueDB == null) {
throw new IllegalStateException("Task not found: " + taskId);
}
Task task = queueRepo.parseTask(queueDB);
if(task.isExpire()) {
throw new IllegalStateException("Task expired: " + taskId);
}

getQueue(task.getFullQueueName()).add(task);

String retryCountStr = jedis.get(RETRY_COUNT_PREFIX + taskId);
int currentRetryCount = retryCountStr != null ? Integer.parseInt(retryCountStr) : 0;
long ttl = TimeUtils.toSeconds(queueDB.getExpiredAt());
jedis.setex(RETRY_COUNT_PREFIX + taskId, ttl, String.valueOf(currentRetryCount + 1));
}

private void unTrackTask(String taskId) {
try (Jedis jedis = jedisPool.getResource()) {
jedis.del(
TIMEOUT_KEY_PREFIX + taskId,
RETRY_COUNT_PREFIX + taskId,
MAX_RETRY_PREFIX + taskId
);
} catch (Exception e) {
log.error("Failed to cleanup task tracking for task: {}", taskId, e);
}
}

}
9 changes: 9 additions & 0 deletions api/src/main/java/com/ke/bella/batch/utils/TimeUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,13 @@ public static LocalDateTime parseTimestamp(String timestampStr) {
public static String formatTimestamp(LocalDateTime dateTime) {
return dateTime.format(TIMESTAMP_FORMATTER);
}

public static long toSeconds(LocalDateTime dateTime) {
if(dateTime == null) {
return 0;
}
long targetEpochMilli = toEpochMilli(dateTime);
long currentEpochMilli = System.currentTimeMillis();
return (targetEpochMilli - currentEpochMilli) / 1000;
}
}
Loading