Conversation
- Fix curl global init thread safety: use singleton CurlGlobalManager - Fix cross-rack detection: use global_rank_rack_map_ instead of gpus_per_server_ - Initialize WorkloadConfig members with default values - Optimize dependency tracking from O(n²) to O(n) using map lookup - Add error return values to OxcFlowOutput functions - Rename static debug counters for clarity - Add DP workload test file - Update design document with Mermaid diagrams Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
delete build directory
|
Anthony seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account. You have signed the CLA already but the status is still pending? Let us recheck it. |
There was a problem hiding this comment.
Pull request overview
This pull request integrates optical cross-connect (OXC) technology into the SimAI simulation framework, enabling optimized collective communication operations for large-scale AI training. The PR adds a complete OXC subsystem including build infrastructure, comprehensive documentation, and implementation of HTTP-based communication with an external OXC service.
Changes:
- Added OXC build support to build scripts with new "oxc" compilation mode
- Implemented complete OXC flow generation system with HTTP client, flow generator, and output formatting
- Added comprehensive 1,315-line design documentation in Chinese detailing architecture, data structures, and usage
- Integrated external RankTable configuration support for flexible topology specification
Reviewed changes
Copilot reviewed 15 out of 16 changed files in this pull request and generated 15 comments.
Show a summary per file
| File | Description |
|---|---|
| scripts/build.sh | Added OXC compilation mode and cleaned up trailing whitespace |
| astra-sim-alibabacloud/build.sh | Integrated OXC build directory and compilation support |
| astra-sim-alibabacloud/build/simai_oxc/ | New build scripts and CMake configuration for OXC module |
| astra-sim-alibabacloud/astra-sim/system/OxcTypes.h | Core data structures for RankTable, flow entries, and communication types |
| astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcMain.cc | Main entry point with argument parsing and workload processing |
| astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcHttpClient.* | HTTP client for OXC API communication using libcurl |
| astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcFlowGenerator.* | Flow generation logic with OXC and native algorithm support |
| astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcFlowOutput.* | Output formatting for CSV, JSON, and summary files |
| docs/SimAI-OXC-Design.md | Comprehensive design documentation with architecture diagrams |
| example/microAllReduce_dp.txt | Example workload file for testing |
| .gitignore | Added OXC build artifacts to ignore list |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| class CurlGlobalManager { | ||
| public: | ||
| static CurlGlobalManager& instance() { | ||
| static CurlGlobalManager instance; | ||
| return instance; | ||
| } | ||
|
|
||
| bool isInitialized() const { return initialized_; } | ||
|
|
||
| private: | ||
| CurlGlobalManager() : initialized_(false) { | ||
| CURLcode res = curl_global_init(CURL_GLOBAL_DEFAULT); | ||
| if (res == CURLE_OK) { | ||
| initialized_ = true; | ||
| } else { | ||
| std::cerr << "[OXC] Warning: curl_global_init failed: " | ||
| << curl_easy_strerror(res) << std::endl; | ||
| } | ||
| } | ||
|
|
||
| ~CurlGlobalManager() { | ||
| if (initialized_) { | ||
| curl_global_cleanup(); | ||
| } | ||
| } | ||
|
|
||
| // 禁止拷贝和移动 | ||
| CurlGlobalManager(const CurlGlobalManager&) = delete; | ||
| CurlGlobalManager& operator=(const CurlGlobalManager&) = delete; | ||
|
|
||
| bool initialized_; | ||
| }; |
There was a problem hiding this comment.
The curl_global_init and curl_global_cleanup functions should not be called multiple times. While the CurlGlobalManager singleton pattern is correct, there's a potential issue: if multiple OxcHttpClient instances are destroyed and recreated in different threads, the destructor could be called while another thread is still using curl. Consider using a reference counting mechanism or ensuring that curl_global_cleanup is only called when all clients are destroyed.
| int tp_size = (tp_size_ > 0 && tp_size_ <= total_gpus) ? tp_size_ : 1; | ||
| int dp_size = (dp_size_ > 0 && dp_size_ <= total_gpus) ? dp_size_ : 1; | ||
| int ep_size = (ep_size_ > 0 && ep_size_ <= total_gpus) ? ep_size_ : 1; | ||
|
|
||
| int TP_nums = total_gpus / tp_size; // TP 组数量 | ||
| int DP_nums = total_gpus / dp_size; // DP 组数量 | ||
|
|
||
| // 计算 DP_EP_size (如果 EP_size > 1) | ||
| int dp_ep_size = (ep_size > 1 && dp_size >= ep_size) ? (dp_size / ep_size) : 1; |
There was a problem hiding this comment.
Missing error handling for division by zero. If tp_size_ is 0, the calculation total_gpus / tp_size will cause undefined behavior. While there's a guard at line 176, it only ensures tp_size > 0 for local variable tp_size, not for the original tp_size_ member. The same applies to dp_size_ and ep_size_. Consider validating these parameters in the constructor.
| if (step > 0) { | ||
| // 依赖于前一个step | ||
| flow.depends_on.push_back(flow.flow_id - num_ranks); |
There was a problem hiding this comment.
The dependency calculation at line 449 assumes flows are stored sequentially and uses flow.flow_id - num_ranks to find the previous step's flow. This is fragile because it relies on the specific ordering of flow IDs. If flows are generated in a different order or if there are gaps in flow IDs, this will create incorrect dependencies. Consider using a more explicit dependency tracking mechanism.
|
|
||
| if (entries.empty()) { | ||
| std::cerr << "[OXC] Warning: Empty response from OXC API for operation " | ||
| << ctx.operation_id << ", error: " << http_client_.getLastError() | ||
| << std::endl; |
There was a problem hiding this comment.
Inconsistent error handling: When OXC API returns empty entries, the code logs a warning and falls back to native implementation. However, if the HTTP request itself fails (curl error), the response is empty but the error might be logged by curl. Consider explicitly checking http_client_.getLastError() is not empty and including it in the fallback decision/logging.
| if (entries.empty()) { | |
| std::cerr << "[OXC] Warning: Empty response from OXC API for operation " | |
| << ctx.operation_id << ", error: " << http_client_.getLastError() | |
| << std::endl; | |
| std::string last_error = http_client_.getLastError(); | |
| if (entries.empty()) { | |
| if (!last_error.empty()) { | |
| std::cerr << "[OXC] Error: HTTP/OXC API call failed for operation " | |
| << ctx.operation_id << ", error: " << last_error | |
| << ". Falling back to native implementation." << std::endl; | |
| } else { | |
| std::cerr << "[OXC] Warning: Empty response from OXC API for operation " | |
| << ctx.operation_id | |
| << " with no HTTP client error reported. Falling back to native implementation." | |
| << std::endl; | |
| } |
| @@ -0,0 +1,46 @@ | |||
| #!/bin/bash | |||
|
|
|||
| # Absolue path to this script | |||
There was a problem hiding this comment.
Typo in comment: "Absolue" should be "Absolute"
| # Absolue path to this script | |
| # Absolute path to this script |
| uint64_t chunk_size = ctx.data_size / num_ranks; | ||
|
|
||
| // ReduceScatter阶段 | ||
| for (int step = 0; step < num_ranks - 1; ++step) { | ||
| for (int i = 0; i < num_ranks; ++i) { | ||
| OutputFlow flow; | ||
| flow.operation_id = ctx.operation_id; | ||
| flow.layer_name = ctx.layer_name; | ||
| flow.phase = ctx.phase; | ||
| flow.comm_type = ctx.comm_type; | ||
| flow.group_type = ctx.group_type; | ||
| flow.flow_id = getNextFlowId(); | ||
| flow.src = comm_group_ranks[i]; | ||
| flow.dst = comm_group_ranks[(i + 1) % num_ranks]; | ||
| flow.flow_size = chunk_size; | ||
| flow.step = step; | ||
|
|
||
| if (step > 0) { | ||
| // 依赖于前一个step | ||
| flow.depends_on.push_back(flow.flow_id - num_ranks); | ||
| } | ||
|
|
||
| flows.push_back(flow); | ||
| } | ||
| } | ||
|
|
||
| // AllGather阶段 | ||
| for (int step = 0; step < num_ranks - 1; ++step) { | ||
| for (int i = 0; i < num_ranks; ++i) { | ||
| OutputFlow flow; | ||
| flow.operation_id = ctx.operation_id; | ||
| flow.layer_name = ctx.layer_name; | ||
| flow.phase = ctx.phase; | ||
| flow.comm_type = ctx.comm_type; | ||
| flow.group_type = ctx.group_type; | ||
| flow.flow_id = getNextFlowId(); | ||
| flow.src = comm_group_ranks[i]; | ||
| flow.dst = comm_group_ranks[(i + 1) % num_ranks]; | ||
| flow.flow_size = chunk_size; | ||
| flow.step = (num_ranks - 1) + step; | ||
|
|
||
| // 依赖于前一个step | ||
| flow.depends_on.push_back(flow.flow_id - num_ranks); | ||
|
|
||
| flows.push_back(flow); | ||
| } | ||
| } | ||
| break; | ||
| } | ||
|
|
||
| case CommType::ALL_GATHER: { | ||
| // AllGather: 每个rank发送数据到所有其他rank | ||
| uint64_t chunk_size = ctx.data_size / num_ranks; | ||
|
|
||
| for (int step = 0; step < num_ranks - 1; ++step) { | ||
| for (int i = 0; i < num_ranks; ++i) { | ||
| OutputFlow flow; | ||
| flow.operation_id = ctx.operation_id; | ||
| flow.layer_name = ctx.layer_name; | ||
| flow.phase = ctx.phase; | ||
| flow.comm_type = ctx.comm_type; | ||
| flow.group_type = ctx.group_type; | ||
| flow.flow_id = getNextFlowId(); | ||
| flow.src = comm_group_ranks[i]; | ||
| flow.dst = comm_group_ranks[(i + 1) % num_ranks]; | ||
| flow.flow_size = chunk_size; | ||
| flow.step = step; | ||
|
|
||
| if (step > 0) { | ||
| flow.depends_on.push_back(flow.flow_id - num_ranks); | ||
| } | ||
|
|
||
| flows.push_back(flow); | ||
| } | ||
| } | ||
| break; | ||
| } | ||
|
|
||
| case CommType::REDUCE_SCATTER: { | ||
| // ReduceScatter: 类似AllGather但带reduce | ||
| uint64_t chunk_size = ctx.data_size / num_ranks; | ||
|
|
||
| for (int step = 0; step < num_ranks - 1; ++step) { | ||
| for (int i = 0; i < num_ranks; ++i) { | ||
| OutputFlow flow; | ||
| flow.operation_id = ctx.operation_id; | ||
| flow.layer_name = ctx.layer_name; | ||
| flow.phase = ctx.phase; | ||
| flow.comm_type = ctx.comm_type; | ||
| flow.group_type = ctx.group_type; | ||
| flow.flow_id = getNextFlowId(); | ||
| flow.src = comm_group_ranks[i]; | ||
| flow.dst = comm_group_ranks[(i + 1) % num_ranks]; | ||
| flow.flow_size = chunk_size; | ||
| flow.step = step; | ||
|
|
||
| if (step > 0) { | ||
| flow.depends_on.push_back(flow.flow_id - num_ranks); | ||
| } | ||
|
|
||
| flows.push_back(flow); | ||
| } | ||
| } | ||
| break; | ||
| } | ||
|
|
||
| case CommType::ALL_TO_ALL: { | ||
| // AllToAll: 每个rank发送不同数据到每个其他rank | ||
| uint64_t chunk_size = ctx.data_size / (num_ranks * num_ranks); |
There was a problem hiding this comment.
Potential integer overflow in chunk size calculation. When ctx.data_size is very large and num_ranks is small, or when calculating chunk_size for ALL_TO_ALL (line 538), there's no check for overflow. Consider adding overflow checks or using uint64_t consistently to prevent potential issues.
| if (json[pos] == '{') brace_count++; | ||
| else if (json[pos] == '}') brace_count--; | ||
| pos++; | ||
| } |
There was a problem hiding this comment.
Memory safety issue: The loop at line 210 uses nested braces counting to extract JSON objects, but if the JSON is malformed and has mismatched braces, the loop could read beyond the string bounds. While there's a check pos < json.size(), malformed JSON with extra closing braces could cause the substring operation to have invalid indices. Consider adding boundary checks before substring operations.
| } | |
| } | |
| // 在提取子字符串之前进行边界和完整性检查,防止格式错误的 JSON 导致越界 | |
| if (brace_count != 0 || pos <= rank_start || pos > json.size()) { | |
| std::cerr << "Error: malformed RankTable JSON (unmatched braces or invalid object range)." << std::endl; | |
| break; | |
| } |
| std::string OxcHttpClient::httpPost(const std::string& url, const std::string& json_body) { | ||
| CURL* curl = curl_easy_init(); | ||
| if (!curl) { | ||
| last_error_ = "Failed to initialize CURL"; | ||
| return ""; | ||
| } | ||
|
|
||
| std::string response; | ||
| struct curl_slist* headers = nullptr; | ||
| headers = curl_slist_append(headers, "Content-Type: application/json"); | ||
|
|
||
| curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); | ||
| curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_body.c_str()); | ||
| curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, json_body.size()); | ||
| curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); | ||
| curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeCallback); | ||
| curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response); | ||
| curl_easy_setopt(curl, CURLOPT_TIMEOUT, timeout_seconds_); | ||
|
|
||
| CURLcode res = curl_easy_perform(curl); | ||
|
|
||
| if (res != CURLE_OK) { | ||
| last_error_ = std::string("CURL error: ") + curl_easy_strerror(res); | ||
| response = ""; | ||
| } | ||
|
|
||
| curl_slist_free_all(headers); | ||
| curl_easy_cleanup(curl); | ||
|
|
||
| return response; |
There was a problem hiding this comment.
Potential resource leak: If curl_easy_init() succeeds but subsequent operations fail before curl_easy_cleanup() is called, the curl handle will leak. Consider using RAII (e.g., std::unique_ptr with a custom deleter) to ensure cleanup even when exceptions or early returns occur.
| // 构建操作间依赖关系 | ||
| std::map<int, std::vector<int>> op_dependencies; | ||
| for (size_t i = 1; i < operations.size(); ++i) { | ||
| // 简单的顺序依赖:每个操作依赖于前一个操作 | ||
| op_dependencies[operations[i].operation_id].push_back(operations[i-1].operation_id); |
There was a problem hiding this comment.
The dependency tracking in operation-level dependencies (lines 111-114) assumes sequential dependencies where each operation depends only on the previous one. However, the actual dependencies might be more complex based on the flow dependencies tracked in the flows themselves. This simplified dependency graph may not accurately represent the actual execution constraints.
| // 构建操作间依赖关系 | |
| std::map<int, std::vector<int>> op_dependencies; | |
| for (size_t i = 1; i < operations.size(); ++i) { | |
| // 简单的顺序依赖:每个操作依赖于前一个操作 | |
| op_dependencies[operations[i].operation_id].push_back(operations[i-1].operation_id); | |
| // 构建操作间依赖关系:基于每个操作自身记录的依赖关系 | |
| std::map<int, std::vector<int>> op_dependencies; | |
| for (const auto& op : operations) { | |
| // 使用 OperationContext 中的 depends_on_ops 作为真实的依赖 | |
| auto& deps = op_dependencies[op.operation_id]; | |
| deps.insert(deps.end(), op.depends_on_ops.begin(), op.depends_on_ops.end()); |
| while (pos < s.size() && s[pos] != '"') { | ||
| if (s[pos] == '\\' && pos + 1 < s.size()) { | ||
| pos++; // 跳过转义字符 | ||
| } | ||
| result += s[pos++]; | ||
| } | ||
| if (pos < s.size()) pos++; // 跳过结尾的 " |
There was a problem hiding this comment.
The parseJsonString function doesn't properly handle all JSON escape sequences. It only handles the backslash but doesn't decode escaped characters like \n, \t, ", \, etc. This could lead to incorrect parsing of strings containing these escape sequences. Consider implementing proper JSON string unescaping.
| while (pos < s.size() && s[pos] != '"') { | |
| if (s[pos] == '\\' && pos + 1 < s.size()) { | |
| pos++; // 跳过转义字符 | |
| } | |
| result += s[pos++]; | |
| } | |
| if (pos < s.size()) pos++; // 跳过结尾的 " | |
| // 辅助函数:将 Unicode 码点以 UTF-8 编码追加到 result | |
| auto appendCodePointUtf8 = [&result](unsigned int codePoint) { | |
| if (codePoint <= 0x7F) { | |
| result.push_back(static_cast<char>(codePoint)); | |
| } else if (codePoint <= 0x7FF) { | |
| result.push_back(static_cast<char>(0xC0 | ((codePoint >> 6) & 0x1F))); | |
| result.push_back(static_cast<char>(0x80 | (codePoint & 0x3F))); | |
| } else if (codePoint <= 0xFFFF) { | |
| result.push_back(static_cast<char>(0xE0 | ((codePoint >> 12) & 0x0F))); | |
| result.push_back(static_cast<char>(0x80 | ((codePoint >> 6) & 0x3F))); | |
| result.push_back(static_cast<char>(0x80 | (codePoint & 0x3F))); | |
| } else if (codePoint <= 0x10FFFF) { | |
| result.push_back(static_cast<char>(0xF0 | ((codePoint >> 18) & 0x07))); | |
| result.push_back(static_cast<char>(0x80 | ((codePoint >> 12) & 0x3F))); | |
| result.push_back(static_cast<char>(0x80 | ((codePoint >> 6) & 0x3F))); | |
| result.push_back(static_cast<char>(0x80 | (codePoint & 0x3F))); | |
| } | |
| }; | |
| auto hexValue = [](char c) -> int { | |
| if (c >= '0' && c <= '9') return c - '0'; | |
| if (c >= 'a' && c <= 'f') return 10 + (c - 'a'); | |
| if (c >= 'A' && c <= 'F') return 10 + (c - 'A'); | |
| return -1; | |
| }; | |
| while (pos < s.size()) { | |
| char c = s[pos]; | |
| // 遇到结束引号,跳出 | |
| if (c == '"') { | |
| pos++; // 跳过结尾的 " | |
| break; | |
| } | |
| if (c == '\\' && pos + 1 < s.size()) { | |
| char esc = s[pos + 1]; | |
| switch (esc) { | |
| case '"': | |
| result.push_back('"'); | |
| pos += 2; | |
| break; | |
| case '\\': | |
| result.push_back('\\'); | |
| pos += 2; | |
| break; | |
| case '/': | |
| result.push_back('/'); | |
| pos += 2; | |
| break; | |
| case 'b': | |
| result.push_back('\b'); | |
| pos += 2; | |
| break; | |
| case 'f': | |
| result.push_back('\f'); | |
| pos += 2; | |
| break; | |
| case 'n': | |
| result.push_back('\n'); | |
| pos += 2; | |
| break; | |
| case 'r': | |
| result.push_back('\r'); | |
| pos += 2; | |
| break; | |
| case 't': | |
| result.push_back('\t'); | |
| pos += 2; | |
| break; | |
| case 'u': { | |
| // 处理 \uXXXX 形式的 Unicode 转义 | |
| if (pos + 6 <= s.size()) { | |
| unsigned int codePoint = 0; | |
| bool ok = true; | |
| for (int i = 0; i < 4; ++i) { | |
| int v = hexValue(s[pos + 2 + i]); | |
| if (v < 0) { | |
| ok = false; | |
| break; | |
| } | |
| codePoint = (codePoint << 4) | static_cast<unsigned int>(v); | |
| } | |
| if (ok) { | |
| appendCodePointUtf8(codePoint); | |
| } else { | |
| // 非法转义,尽量保留原始内容 | |
| result.push_back('\\'); | |
| result.push_back('u'); | |
| for (int i = 0; i < 4 && pos + 2 + i < s.size(); ++i) { | |
| result.push_back(s[pos + 2 + i]); | |
| } | |
| } | |
| pos += 6; // 跳过 \uXXXX | |
| } else { | |
| // 不完整的 \u 转义,按字面量处理 | |
| result.push_back('\\'); | |
| result.push_back('u'); | |
| pos += 2; | |
| } | |
| break; | |
| } | |
| default: | |
| // 未知转义,按字面字符处理 esc | |
| result.push_back(esc); | |
| pos += 2; | |
| break; | |
| } | |
| continue; | |
| } | |
| // 普通字符,直接追加 | |
| result.push_back(c); | |
| pos++; | |
| } |
算法集成第一版