diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 7b706bd..5e3c55d 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -79,13 +79,13 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.AUTO_RELEASE }}
with:
- tag_name: v1.0.3
+ tag_name: ${{ github.ref_name }}
draft: true
files: |
AICodingOfficer_windows.zip
AICodingOfficer_mac_x86.zip
AICodingOfficer_linux.zip
- name: 🎉AICO
+ name: "AICO: A cutting-edge artificial intelligence text coding officer"
body: |
## Bug Fixes
- Fix a bug #1
diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
new file mode 100644
index 0000000..93858c0
--- /dev/null
+++ b/.github/workflows/pr-test.yml
@@ -0,0 +1,47 @@
+name: PR Test
+
+on:
+ pull_request:
+ branches: [ main ]
+ paths-ignore:
+ - '**.md'
+ - 'docs/**'
+ - '.gitignore'
+
+jobs:
+ test:
+ name: Build Test
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ windows-latest, macos-latest, ubuntu-latest ]
+ python-version: ["3.10"]
+
+ steps:
+ - name: Check out
+ uses: actions/checkout@v3
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v3
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -r requirements.txt
+ pip install pyinstaller
+
+ - name: Build Test
+ run: pyinstaller build.spec
+
+ - name: Verify build artifacts
+ shell: bash
+ run: |
+ if [ "${{ runner.os }}" = "Windows" ]; then
+ test -f "dist/AICodingOfficer.exe" || exit 1
+ elif [ "${{ runner.os }}" = "macOS" ]; then
+ test -d "dist/AICodingOfficer.app" || exit 1
+ elif [ "${{ runner.os }}" = "Linux" ]; then
+ test -f "dist/AICodingOfficer" || exit 1
+ fi
diff --git a/build.spec b/build.spec
index cf04019..5f98520 100644
--- a/build.spec
+++ b/build.spec
@@ -49,5 +49,5 @@ app = BUNDLE(
name='AICodingOfficer.app',
icon='image/icon.icns',
bundle_identifier=None,
- version='1.0.3'
+ version='1.0.5'
)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 2a51196..c4d5792 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -12,7 +12,7 @@
project = 'AI Coding Officer'
copyright = '2024, Jianjun Xiao'
author = 'Jianjun Xiao'
-release = 'v1.0.3'
+release = 'v1.0.5'
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
diff --git a/src/core.py b/src/core.py
index 69df0c5..350979a 100644
--- a/src/core.py
+++ b/src/core.py
@@ -114,8 +114,8 @@ def __init__(self):
self.myAutoCodingWindow = MyAutoCodingWindow()
self.stackWidget.addWidget(self.myAutoCodingWindow)
- self.myMainWindow = MyMainWindow()
- self.stackWidget.addWidget(self.myMainWindow)
+ # self.myMainWindow = MyMainWindow()
+ # self.stackWidget.addWidget(self.myMainWindow)
# initialize layout
self.initLayout()
@@ -570,9 +570,14 @@ def prepare_prompt(self):
for topic_id in topic_reply_tree_dict:
for reply_tree in topic_reply_tree_dict[topic_id]['reply_tree']:
- prompt_content = r"""您将看到一组论坛中的话题和回帖,您的任务是优先根据下面的编码表中的含义解释对每个回帖提取一组标签“codes”(只有当编码表中没有合适的标签时才输出“NULL”),并以中文举例说明提取标签的理由,注意将理由翻译为中文列出。结果以JSON格式的数组输出:[{"reply_id":"1234","tags":[],"reason":[]},{"reply_id":"2345","tags":[],"reason":[]}],注意只输出JSON,不要包括其他内容!tags和reason中的内容一一对应,请根据实际情况填写,不要直接复制粘贴。
- 编码表:\n
- """
+ if self.language == 'Chinese':
+ prompt_content = r"""您将看到一组论坛中的话题和回帖,您的任务是优先根据下面的编码表中的含义解释对每个回帖提取一组标签“codes”(只有当编码表中没有合适的标签时才输出“NULL”),并以中文举例说明提取标签的理由,注意将理由翻译为中文列出。结果以JSON格式的数组输出:[{"reply_id":"1234","tags":[],"reason":[]},{"reply_id":"2345","tags":[],"reason":[]}],注意只输出JSON,不要包括其他内容!tags和reason中的内容一一对应,请根据实际情况填写,不要直接复制粘贴。
+ 编码表:\n
+ """
+ elif self.language == 'English':
+ prompt_content = r"""您将看到一组论坛中的话题和回帖,您的任务是优先根据下面的编码表中的含义解释对每个回帖提取一组标签“codes”(只有当编码表中没有合适的标签时才输出“NULL”),并以英文举例说明提取标签的理由,注意将理由翻译为英文列出。结果以JSON格式的数组输出:[{"reply_id":"1234","tags":[],"reason":[]},{"reply_id":"2345","tags":[],"reason":[]}],注意只输出JSON,不要包括其他内容!tags和reason中的内容一一对应,请根据实际情况填写,不要直接复制粘贴。
+ 编码表:\n
+ """
prompt_content += r"""
{encode_table_latex}
\n\n话题:\n
diff --git a/src/gui/setting.py b/src/gui/setting.py
index b95d8cd..1b7fd38 100644
--- a/src/gui/setting.py
+++ b/src/gui/setting.py
@@ -1,11 +1,11 @@
from PySide6.QtCore import Qt
-from PySide6.QtWidgets import QLabel, QVBoxLayout, QHBoxLayout, QFrame
+from PySide6.QtWidgets import QLabel, QVBoxLayout, QHBoxLayout, QFrame, QMessageBox
from PySide6.QtGui import QIcon
from qfluentwidgets import LineEdit, PushButton, FluentIcon, PrimaryPushButton, EditableComboBox
from src.module.resource import getResource
-from src.module.config import localDBFilePath, logFolder
-
+from src.module.config import localDBFilePath, logFolder, readConfig, configFile
+from src.module.aihubmix import AiHubMixAPI
class SettingWindow(object):
def setupUI(self, this_window):
@@ -33,15 +33,38 @@ def setupUI(self, this_window):
self.languageCard = self.settingCard(self.languageTitle, self.languageInfo, self.language, "full")
- # 选择模型
+ # 设置模型API key
+ self.modelApiTitle = QLabel("API Key")
+ self.modelApiInfo = QLabel("Please enter your AiHubMix API key.")
+
+ self.modelApiKey = LineEdit(self)
+ self.modelApiKey.setFixedWidth(200)
+ self.modelApiKey.setClearButtonEnabled(True)
+ self.modelApiKey.setText(readConfig().get("APIkey", "api_key"))
+
+ # 添加验证按钮
+ self.validateApiButton = PushButton("Validate", self)
+ self.validateApiButton.setFixedWidth(80)
+ self.validateApiButton.clicked.connect(self.validateApiKey)
- self.modelTypeTitle = QLabel("GPT model")
+ # 创建水平布局来放置API Key输入框和验证按钮
+ self.apiKeyLayout = QHBoxLayout()
+ self.apiKeyLayout.addWidget(self.modelApiKey)
+ self.apiKeyLayout.addWidget(self.validateApiButton)
+ self.apiKeyLayout.addStretch()
- self.modelTypeInfo = QLabel("Select the GPT model to use, different models have different performance.")
+ self.apiKeyFrame = QFrame()
+ self.apiKeyFrame.setLayout(self.apiKeyLayout)
+
+ self.modelApiCard = self.settingCard(self.modelApiTitle, self.modelApiInfo, self.apiKeyFrame, "full")
+
+ # 选择模型
+ self.modelTypeTitle = QLabel("AI Model")
+ self.modelTypeInfo = QLabel("Select the AI model to use. The list shows models available with your API key.")
self.modelTypeInfo.setObjectName("cardInfoLabel")
- self.modelTypeUrl = QLabel("To read document.")
+ self.modelTypeUrl = QLabel("View documentation")
self.modelTypeUrl.setOpenExternalLinks(True)
self.modelInfoLayout = QHBoxLayout()
@@ -57,10 +80,25 @@ def setupUI(self, this_window):
self.modelType = EditableComboBox(self)
self.modelType.setMinimumWidth(200)
self.modelType.setMaximumWidth(200)
- self.modelType.addItems(["moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"])
- self.modelType.setText("moonshot-v1-8k") # 设置默认值为 "Kimi"
+
+ # 初始化时加载已保存的模型
+ current_model = readConfig().get("AICO", "model")
+ self.modelType.setText(current_model)
+
+ # 添加刷新按钮
+ self.refreshModelsButton = PushButton("Refresh", self)
+ self.refreshModelsButton.setFixedWidth(80)
+ self.refreshModelsButton.clicked.connect(self.refreshAvailableModels)
+
+ self.modelTypeLayout = QHBoxLayout()
+ self.modelTypeLayout.addWidget(self.modelType)
+ self.modelTypeLayout.addWidget(self.refreshModelsButton)
+ self.modelTypeLayout.addStretch()
- self.modelTypeCard = self.settingCard(self.modelTypeTitle, self.dateInfoFrame, self.modelType, "full")
+ self.modelTypeFrame = QFrame()
+ self.modelTypeFrame.setLayout(self.modelTypeLayout)
+
+ self.modelTypeCard = self.settingCard(self.modelTypeTitle, self.dateInfoFrame, self.modelTypeFrame, "full")
# 选择线程数
self.threadTitle = QLabel("Thread count")
@@ -76,17 +114,6 @@ def setupUI(self, this_window):
self.threadCard = self.settingCard(self.threadTitle, self.threadInfo, self.threadCount, "full")
- # 设置模型API key
-
- self.modelApiTitle = QLabel("GPT API Key")
- self.modelApiInfo = QLabel("Please enter the API key of the Kimi open platform.")
-
- self.modelApiKey = LineEdit(self)
- self.modelApiKey.setFixedWidth(200)
- self.modelApiKey.setClearButtonEnabled(True)
-
- self.modelApiCard = self.settingCard(self.modelApiTitle, self.modelApiInfo, self.modelApiKey, "full")
-
# 本地数据库文件夹
self.localDBTitle = QLabel("Local database")
@@ -137,6 +164,8 @@ def setupUI(self, this_window):
layout.addSpacing(12)
layout.addLayout(self.buttonLayout)
+ self.applyButton.clicked.connect(self.saveSettings)
+
def settingCard(self, card_title, card_info, card_func, size):
card_title.setObjectName("cardTitleLabel")
card_info.setObjectName("cardInfoLabel")
@@ -182,3 +211,64 @@ def tutorialCard(self, card_token, card_explain):
self.card.setLayout(self.tutorialLayout)
return self.card
+
+ def validateApiKey(self):
+ """验证API密钥"""
+ api_key = self.modelApiKey.text().strip()
+ if not api_key:
+ QMessageBox.warning(self, "Error", "Please enter your API key first!")
+ return
+
+ api = AiHubMixAPI()
+ if api.validate_api_key():
+ QMessageBox.information(self, "Success", "API Key is valid!")
+ self.refreshAvailableModels()
+ else:
+ QMessageBox.warning(self, "Error", "Invalid API key. Please check and try again.")
+
+ def refreshAvailableModels(self):
+ """刷新可用模型列表"""
+ api = AiHubMixAPI()
+ models = api.get_available_models()
+
+ if not models:
+ QMessageBox.warning(self, "Error", "Failed to fetch available models. Please check your API key and try again.")
+ return
+
+ current_model = self.modelType.text()
+ self.modelType.clear()
+ self.modelType.addItems(models)
+
+ # 保持当前选择的模型(如果它仍然可用)
+ if current_model in models:
+ self.modelType.setText(current_model)
+ else:
+ self.modelType.setText(models[0])
+
+ def saveSettings(self):
+ """保存设置"""
+ config = readConfig()
+
+ # 保存API密钥
+ api_key = self.modelApiKey.text().strip()
+ if api_key:
+ config.set("APIkey", "api_key", api_key)
+
+ # 保存选择的模型
+ selected_model = self.modelType.text()
+ if selected_model:
+ config.set("AICO", "model", selected_model)
+
+ # 保存语言设置
+ config.set("Language", "language", self.language.text())
+
+ # 保存线程数
+ thread_count = self.threadCount.text()
+ if thread_count.isdigit() and 1 <= int(thread_count) <= 8:
+ config.set("Thread", "thread_count", thread_count)
+
+ # 写入配置文件
+ with open(configFile(), "w", encoding="utf-8") as f:
+ config.write(f)
+
+ QMessageBox.information(self, "Success", "Settings saved successfully!")
diff --git a/src/module/aihubmix.py b/src/module/aihubmix.py
new file mode 100644
index 0000000..a1dcd71
--- /dev/null
+++ b/src/module/aihubmix.py
@@ -0,0 +1,108 @@
+import requests
+from src.module.config import readConfig
+
+class APIError(Exception):
+ """API错误的自定义异常类"""
+ def __init__(self, status_code, error_type, message):
+ self.status_code = status_code
+ self.error_type = error_type
+ self.message = message
+ super().__init__(self.message)
+
+class AiHubMixAPI:
+ BASE_URL = "https://api.aihubmix.com/v1"
+
+ # HTTP状态码及其对应的错误描述
+ ERROR_MESSAGES = {
+ 400: "请求格式错误,不能被服务器理解。通常意味着客户端错误。",
+ 401: "API密钥验证未通过。你需要验证你的API密钥是否正确,其他原因",
+ 403: "一般是权限不足。",
+ 404: "请求的资源未找到。你可能正在尝试访问一个不存在的端点。",
+ 413: "请求体太大。你可能需要减小你的请求体容量。",
+ 429: "由于频繁的请求超过限制,你已经超过了你的速率限制。",
+ 500: "服务器内部的错误。这可能是OpenAI服务器的问题,不是你的问题。",
+ 503: "服务器暂时不可用。这可能是由于OpenAI正在进行维护或者服务器过载。"
+ }
+
+ def __init__(self):
+ self.api_key = readConfig().get("APIkey", "api_key")
+ self.headers = {
+ 'Content-Type': 'application/json',
+ 'Authorization': f'Bearer {self.api_key}'
+ }
+
+ def _handle_error_response(self, response):
+ """处理API错误响应"""
+ status_code = response.status_code
+ try:
+ error_data = response.json().get('error', {})
+ error_type = error_data.get('type', 'unknown_error')
+ error_message = error_data.get('message', '未知错误')
+ except ValueError:
+ error_type = 'parse_error'
+ error_message = '无法解析错误响应'
+
+ # 获取详细的错误描述
+ error_description = self.ERROR_MESSAGES.get(status_code, '未知错误')
+
+ # 组合完整的错误信息
+ full_error_message = f"HTTP {status_code}: {error_description}\n具体错误: {error_message}"
+
+ raise APIError(status_code, error_type, full_error_message)
+
+ def _make_request(self, method, endpoint, **kwargs):
+ """统一的请求处理方法"""
+ try:
+ url = f"{self.BASE_URL}/{endpoint.lstrip('/')}"
+ response = requests.request(method, url, headers=self.headers, **kwargs)
+
+ if response.status_code != 200:
+ self._handle_error_response(response)
+
+ return response.json()
+ except requests.exceptions.RequestException as e:
+ raise APIError(0, 'network_error', f"网络请求错误: {str(e)}")
+ except Exception as e:
+ raise APIError(0, 'unknown_error', f"未知错误: {str(e)}")
+
+ def get_available_models(self):
+ """获取当前API Key支持的所有模型列表"""
+ try:
+ response = self._make_request('GET', '/models')
+ models = response.get('data', [])
+ return [model['id'] for model in models if model.get('available', True)]
+ except APIError as e:
+ print(f"获取模型列表失败: {e.message}")
+ return []
+
+ def get_model_info(self, model_id):
+ """获取特定模型的详细信息"""
+ try:
+ response = self._make_request('GET', f'/models/{model_id}')
+ return response.get('data', {})
+ except APIError as e:
+ print(f"获取模型信息失败: {e.message}")
+ return {}
+
+ def validate_api_key(self):
+ """验证API Key是否有效"""
+ try:
+ self._make_request('GET', '/models')
+ return True
+ except APIError:
+ return False
+
+ def chat_completion(self, model, messages, temperature=0.7, max_tokens=None):
+ """统一的聊天完成接口"""
+ try:
+ data = {
+ "model": model,
+ "messages": messages,
+ "temperature": temperature
+ }
+ if max_tokens is not None:
+ data["max_tokens"] = max_tokens
+
+ return self._make_request('POST', '/chat/completions', json=data)
+ except APIError as e:
+ return {"error": e.message}
diff --git a/src/module/coding.py b/src/module/coding.py
index ec95d1f..af98312 100644
--- a/src/module/coding.py
+++ b/src/module/coding.py
@@ -1,12 +1,13 @@
+import re
import time
import arrow
import json
import sqlite3
-import requests
import threading
from PySide6.QtCore import QObject, Signal, QThread
-from queue import Queue
+from queue import Queue, Empty
from src.module.config import localDBFilePath, readConfig
+from src.module.aihubmix import AiHubMixAPI
language = readConfig().get("Language", "language")
@@ -27,9 +28,10 @@ def __init__(self, limit):
self._stop_event = threading.Event()
def run(self):
- self._stop_event.clear() # 重置停止事件
+ self._stop_event.clear()
self.running_signal.emit(True)
- main_coding(self._stop_event, self.output_signal, self.thread_results, self.THREAD_COUNT, self.DATABASE_PATH, self.TABLE_NAME, self.LABEL_COLUMN_NAME, self.LIMIT, self.DEFAULT_NODE_RECOGNITION_PROMPT)
+ main_coding(self._stop_event, self.output_signal, self.thread_results, self.THREAD_COUNT, self.DATABASE_PATH,
+ self.TABLE_NAME, self.LABEL_COLUMN_NAME, self.LIMIT, self.DEFAULT_NODE_RECOGNITION_PROMPT)
self.running_signal.emit(False)
def stop(self):
@@ -39,166 +41,176 @@ def stop(self):
def __del__(self):
self.stop()
-def send_request(url, headers, data):
- error_type_description = {
- "content_filter": "内容审查拒绝,您的输入或生成内容可能包含不安全或敏感内容,请您避免输入易产生敏感内容的提示语,谢谢",
- "invalid_request_error": "请求无效,通常是您请求格式错误或者缺少必要参数,请检查后重试",
- "invalid_authentication_error": "鉴权失败,请检查 apikey 是否正确,请修改后重试",
- "exceeded_current_quota_error": "账户异常,请检查您的账户余额",
- "permission_denied_error": "访问其他用户信息的行为不被允许,请检查",
- "resource_not_found_error": "不存在此模型或者没有授权访问此模型,请检查后重试",
- "engine_overloaded_error": "当前并发请求过多,节点限流中,请稍后重试;建议充值升级 tier,享受更丝滑的体验",
- "exceeded_current_quota_error": "账户额度不足,请检查账户余额,保证账户余额可匹配您 tokens 的消耗费用后重试",
- "rate_limit_reached_error": "请求触发了账户并发个数的限制,请等待指定时间后重试",
- "server_error": "解析文件失败,请重试",
- "unexpected_output": "内部错误,请联系管理员",
- }
-
- try:
- response = requests.post(url, headers=headers, json=data)
- response.raise_for_status() # 将触发异常的HTTP错误
- return response.json()
- except requests.exceptions.HTTPError as e:
- error_code = e.response.status_code
- if language == "Chinese":
- error_type = e.response.json().get("error", {}).get("type", "未知错误")
- error_message = e.response.json().get("error", {}).get("message", "未知错误")
- response_message = {"error": f"HTTP {error_code}", "message": error_message, "description": error_type_description.get(error_type, "未知错误")}
- else:
- error_type = e.response.json().get("error", {}).get("type", "unknown error")
- error_message = e.response.json().get("error", {}).get("message", "unknown error")
- response_message = {"error": f"HTTP {error_code}", "message": error_message, "description": error_type_description.get(error_type, "unknown error")}
- return response_message
- except requests.exceptions.RequestException as e:
- if language == "Chinese":
- return {"error": "网络错误", "message": str(e)}
- else:
- return {"error": "Network error", "message": str(e)}
- except Exception as e:
- if language == "Chinese":
- return {"error": "未知错误", "message": str(e)}
- else:
- return {"error": "unknown error", "message": str(e)}
-
def get_code_from_gpt(output_signal, prompt_content):
- MOONSHOT_API_KEY = readConfig().get("APIkey", "api_key")
- MODEL_TYPE = readConfig().get("AICO", "model")
- url = 'https://api.moonshot.cn/v1/chat/completions'
-
- headers = {
- 'Content-Type': 'application/json',
- 'Authorization': f'Bearer {MOONSHOT_API_KEY}'
- }
-
- data = {
- "model": MODEL_TYPE,
- "messages": [
- {
- "role": "system",
- "content": "您将看到一组论坛中的话题和回帖,您的任务是优先根据下面的编码表中的含义解释对每个回帖提取一组标签,并在一组 JSON 对象中输出。",
- },
- {
- "role": "user",
- "content": prompt_content,
- "partial": True
- },
- # {
- # "role": "assistant",
- # "content": "",
- # "partial": True
- # },
- ],
- "temperature": 0.3,
- "max_tokens": 1000,
- }
-
- response = send_request(url, headers, data)
+ """调用AI模型进行代码生成"""
+ api = AiHubMixAPI()
+ model = readConfig().get("AICO", "model")
+
+ if not model:
+ error_msg = "未选择AI模型" if language == "Chinese" else "No AI model selected"
+ output_signal.emit(f"[Error] [{arrow.now().format('YYYY-MM-DD HH:mm:ss')}] {error_msg}")
+ return None
+
+ messages = [
+ {
+ "role": "system",
+ "content": "您将看到一组论坛中的话题和回帖,您的任务是优先根据下面的编码表中的含义解释对每个回帖提取一组标签,并在一组 JSON 对象中输出。",
+ },
+ {
+ "role": "user",
+ "content": prompt_content
+ }
+ ]
+
+ response = api.chat_completion(model, messages)
+
if 'error' in response:
+ timestamp = arrow.now().format('YYYY-MM-DD HH:mm:ss')
if language == "Chinese":
- output_signal.emit("[提示] [" + arrow.now().format('YYYY-MM-DD HH:mm:ss') + "] [GPT 返回]:" + str(response))
+ output_signal.emit(f"[错误] [{timestamp}] {response['error']}")
else:
- output_signal.emit("[Notice] [" + arrow.now().format('YYYY-MM-DD HH:mm:ss') + "] [GPT returned]:" + str(response))
- time.sleep(1)
+ output_signal.emit(f"[Error] [{timestamp}] {response['error']}")
return None
+
return response
def parse_gpt_response(response):
+ """解析AI模型的响应"""
try:
- res = response['choices'][0]['message']['content'].replace('```json', '').replace('```', '').replace(' ', '').replace('\n', '')
- return str(res)
+ if not response or 'choices' not in response:
+ return None
+ content = response['choices'][0]['message']['content']
+ # 去除content中的换行、空格等字符
+ content = re.sub(r'[\n\r\s]+', ' ', str(content)).strip()
+ return content if content else None
except Exception as e:
- print('parse_gpt_response ', e)
- return 'None'
+ print(f"解析响应失败: {str(e)}")
+ return None
def encode_data(output_signal, record, default_node_recognition_prompt, label):
- prompt_content = default_node_recognition_prompt + record['prompt_content']
- response = get_code_from_gpt(output_signal, prompt_content)
- return {
- 'index': record['index'],
- label: 'None' if response == None else parse_gpt_response(response),
- 'orign_response': str(response)
- }
+ """编码数据"""
+ try:
+ prompt_content = record[0] # 获取 prompt_content 列的值
+ response = get_code_from_gpt(output_signal, prompt_content)
+ if response:
+ prompt_code = parse_gpt_response(response)
+ if prompt_code:
+ return prompt_code, json.dumps(response) # 使用 json.dumps 确保 response 被正确序列化
+ except Exception as e:
+ print(f"编码失败: {str(e)}")
+ return None, None
def worker(stop_event, output_signal, input_queue, output_dict, db_path, table_name, label, default_node_recognition_prompt):
- conn = sqlite3.connect(db_path)
- cursor = conn.cursor()
-
- while not input_queue.empty():
+ """工作线程处理函数"""
+ while not stop_event.is_set():
+ try:
+ # 使用timeout参数,这样可以更频繁地检查stop_event
+ try:
+ record = input_queue.get(timeout=0.1)
+ except Empty:
+ continue
+
+ # 再次检查stop_event,如果设置了就立即退出
+ if stop_event.is_set():
+ break
+
+ # 输出当前处理进度
+ timestamp = arrow.now().format('YYYY-MM-DD HH:mm:ss')
+ if language == "Chinese":
+ output_signal.emit(f"[提示] [{timestamp}] [当前线程]:{threading.current_thread().name},剩余 {input_queue.qsize()} 项待处理")
+ else:
+ output_signal.emit(f"[Notice] [{timestamp}] [Current thread]: {threading.current_thread().name}, {input_queue.qsize()} items remaining")
+ except Empty:
+ break
+
+ # 如果设置了stop_event,跳过处理直接退出
if stop_event.is_set():
break
- record = input_queue.get()
- if language == "Chinese":
- output_signal.emit("[提示] [{}] [当前线程]:{},剩余 {} 项待处理".format(arrow.now().format('YYYY-MM-DD HH:mm:ss'), threading.current_thread().name, input_queue.qsize()))
- else:
- output_signal.emit("[Notice] [{}] [Current thread]: {},remaining {} items to process".format(arrow.now().format('YYYY-MM-DD HH:mm:ss'), threading.current_thread().name, input_queue.qsize()))
- try:
- encoded_record = encode_data(output_signal, record, default_node_recognition_prompt, label)
- if encoded_record[label] != 'None':
- cursor.execute("UPDATE {} SET `{}` = ?, `prompt_code_orign` = ? WHERE `index` = ?".format(table_name, label),
- (encoded_record[label], encoded_record['orign_response'], encoded_record['index']))
+
+ prompt_code, prompt_code_orign = encode_data(output_signal, record, default_node_recognition_prompt, label)
+
+ # 再次检查stop_event
+ if stop_event.is_set():
+ break
+
+ if prompt_code and prompt_code_orign:
+ try:
+ conn = sqlite3.connect(db_path)
+ cursor = conn.cursor()
+ cursor.execute(
+ "UPDATE prompt SET prompt_code = ?, prompt_code_orign = ? WHERE prompt_content = ?",
+ (prompt_code, prompt_code_orign, record[0])
+ )
conn.commit()
- output_dict[encoded_record['index']] = encoded_record
- else:
+ conn.close()
+ except Exception as e:
+ timestamp = arrow.now().format('YYYY-MM-DD HH:mm:ss')
if language == "Chinese":
- output_signal.emit("[提示] [" + arrow.now().format('YYYY-MM-DD HH:mm:ss') + "] [编码失败]:无法从 GPT 获取编码")
+ output_signal.emit(f"[错误] [{timestamp}] 数据库更新失败: {str(e)}")
else:
- output_signal.emit("[Notice] [" + arrow.now().format('YYYY-MM-DD HH:mm:ss') + "] [Encoding failed]: Unable to get code from GPT")
- except sqlite3.Error as e:
- print("An error occurred:", e.args[0])
- conn.close()
+ output_signal.emit(f"[Error] [{timestamp}] Database update failed: {str(e)}")
+ else:
+ timestamp = arrow.now().format('YYYY-MM-DD HH:mm:ss')
+ if language == "Chinese":
+ output_signal.emit(f"[警告] [{timestamp}] 编码失败")
+ else:
+ output_signal.emit(f"[Warning] [{timestamp}] Encoding failed")
def fetch_data_from_database(db_path, table_name, label, limit=-1):
+ """从数据库获取数据"""
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
try:
if limit == -1:
- cursor.execute('SELECT * FROM {} WHERE prompt_content > 0 AND `{}` = "None"'.format(table_name, label))
+ cursor.execute("SELECT prompt_content FROM prompt WHERE prompt_code IS NULL OR prompt_code = 'None'")
else:
- cursor.execute('SELECT * FROM {} WHERE prompt_content > 0 AND `{}` = "None" LIMIT {}'.format(table_name, label, limit))
- columns = [column[0] for column in cursor.description]
+ cursor.execute("SELECT prompt_content FROM prompt WHERE prompt_code IS NULL OR prompt_code = 'None' LIMIT ?",
+ (limit,))
records = cursor.fetchall()
- data_list = [dict(zip(columns, record)) for record in records]
+ return records
finally:
conn.close()
- return data_list
def main_coding(stop_event, output_signal, thread_results, THREAD_COUNT, DATABASE_PATH, TABLE_NAME, LABEL_COLUMN_NAME, LIMIT, DEFAULT_NODE_RECOGNITION_PROMPT):
- input_queue = Queue()
- data_list = fetch_data_from_database(DATABASE_PATH, TABLE_NAME, LABEL_COLUMN_NAME, LIMIT)
+ """主编码处理函数"""
+ records = fetch_data_from_database(DATABASE_PATH, TABLE_NAME, LABEL_COLUMN_NAME, LIMIT)
+ if not records:
+ if language == "Chinese":
+ output_signal.emit(f"[提示] [{arrow.now().format('YYYY-MM-DD HH:mm:ss')}] 没有需要处理的数据")
+ else:
+ output_signal.emit(f"[Notice] [{arrow.now().format('YYYY-MM-DD HH:mm:ss')}] No data to process")
+ return
- for record in data_list:
+ input_queue = Queue()
+ for record in records:
input_queue.put(record)
- threads = []
- for _ in range(THREAD_COUNT):
- t = threading.Thread(target=worker, args=(stop_event, output_signal, input_queue, thread_results, DATABASE_PATH, TABLE_NAME, LABEL_COLUMN_NAME, DEFAULT_NODE_RECOGNITION_PROMPT))
- threads.append(t)
+ threads = []
+ for _ in range(min(THREAD_COUNT, len(records))):
+ t = threading.Thread(
+ target=worker,
+ args=(stop_event, output_signal, input_queue, thread_results, DATABASE_PATH,
+ TABLE_NAME, LABEL_COLUMN_NAME, DEFAULT_NODE_RECOGNITION_PROMPT)
+ )
+ t.daemon = True # 设置为守护线程,这样主程序退出时线程会自动结束
t.start()
+ threads.append(t)
- for t in threads:
- t.join()
+ # 使用超时等待,这样可以响应停止信号
+ while threads:
+ for t in threads[:]:
+ t.join(timeout=0.1) # 等待0.1秒
+ if not t.is_alive():
+ threads.remove(t)
+
+ # 如果设置了stop_event,就不再等待其他线程
+ if stop_event.is_set():
+ break
- if language == "Chinese":
- output_signal.emit("[提示] [{}] [编码统计]:剩余 {} 项待处理".format(arrow.now().format('YYYY-MM-DD HH:mm:ss'), len(fetch_data_from_database(DATABASE_PATH, TABLE_NAME, LABEL_COLUMN_NAME))))
- else:
- output_signal.emit("[Notice] [{}] [Coding statistics]:Total {} items remaining to be processed".format(arrow.now().format('YYYY-MM-DD HH:mm:ss'), len(fetch_data_from_database(DATABASE_PATH, TABLE_NAME, LABEL_COLUMN_NAME))))
\ No newline at end of file
+ # 只有在正常完成时才显示统计信息
+ if not stop_event.is_set():
+ remaining = len(fetch_data_from_database(DATABASE_PATH, TABLE_NAME, LABEL_COLUMN_NAME))
+ if language == "Chinese":
+ output_signal.emit(f"[提示] [{arrow.now().format('YYYY-MM-DD HH:mm:ss')}] [编码统计]:剩余 {remaining} 项待处理")
+ else:
+ output_signal.emit(f"[Notice] [{arrow.now().format('YYYY-MM-DD HH:mm:ss')}] [Coding statistics]:Total {remaining} items remaining to be processed")
\ No newline at end of file
diff --git a/src/module/config.py b/src/module/config.py
index 8b29493..281a031 100644
--- a/src/module/config.py
+++ b/src/module/config.py
@@ -74,7 +74,7 @@ def initConfig(config_file):
config.set("APIkey", "api_key", "")
config.add_section("AICO")
- config.set("AICO", "model", "moonshot-v1-8k")
+ config.set("AICO", "model", "") # 默认为空,由用户选择
config.add_section("Thread")
config.set("Thread", "thread_count", "1")
diff --git a/src/module/version.py b/src/module/version.py
index 6335ff6..e98dace 100644
--- a/src/module/version.py
+++ b/src/module/version.py
@@ -3,7 +3,7 @@
def currentVersion():
- current_version = "1.0.3"
+ current_version = "1.0.5"
return current_version
diff --git a/version.txt b/version.txt
index b96c45e..97c6964 100644
--- a/version.txt
+++ b/version.txt
@@ -31,7 +31,7 @@ VSVersionInfo(
StringStruct(u'FileVersion', u'1'),
StringStruct(u'LegalCopyright', u'Copyright (C) 2024 Jianjun Xiao'),
StringStruct(u'ProductName', u'AICodingOfficer'),
- StringStruct(u'ProductVersion', u'1.0.3')])
+ StringStruct(u'ProductVersion', u'1.0.5')])
]),
VarFileInfo([VarStruct(u'Translation', [2052, 1200])])
]