diff --git a/.gitignore b/.gitignore index e4a9790f..b08a6df4 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,6 @@ test/log/ *.log .cur* .DS_Store + +# OXC build artifacts +astra-sim-alibabacloud/build/simai_oxc/build/ diff --git a/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/CMakeLists.txt b/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/CMakeLists.txt new file mode 100644 index 00000000..78732b6d --- /dev/null +++ b/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/CMakeLists.txt @@ -0,0 +1,19 @@ +# CMake requirement +cmake_minimum_required(VERSION 3.15) + +# 项目名称和设置 +project(SimAI_oxc) + +# 查找源文件 +file(GLOB SOURCES "*.cc") +file(GLOB HEADERS "*.h") +include_directories("${PROJECT_SOURCE_DIR}/../../../") + +# 查找 libcurl +find_package(CURL REQUIRED) + +# 设置可执行文件 +add_executable(SimAI_oxc ${SOURCES} ${HEADERS}) + +# 链接库 +target_link_libraries(SimAI_oxc AstraSim CURL::libcurl) diff --git a/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcFlowGenerator.cc b/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcFlowGenerator.cc new file mode 100644 index 00000000..d273cbf4 --- /dev/null +++ b/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcFlowGenerator.cc @@ -0,0 +1,618 @@ +/* + * Copyright (c) 2024, Alibaba Group; + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "OxcFlowGenerator.h" +#include +#include +#include +#include +#include + +namespace OXC { + +OxcFlowGenerator::OxcFlowGenerator( + const std::string& oxc_server_url, + int num_gpus, + int gpus_per_server, + int tp_size, + int dp_size, + int ep_size, + int pp_size) + : oxc_server_url_(oxc_server_url), + alg_name_("ALGO_OXC_RING"), + num_gpus_(num_gpus), + gpus_per_server_(gpus_per_server), + tp_size_(tp_size), + dp_size_(dp_size), + ep_size_(ep_size), + pp_size_(pp_size), + global_flow_id_(0), + global_operation_id_(0), + comm_domains_set_(false), + external_ranktable_set_(false) { + http_client_.initialize(oxc_server_url); + // 初始化全局通信域(一次性计算) + initCommDomains(); + // 注意:RankTable 必须通过 setRankTable() 从外部文件加载 +} + +OxcFlowGenerator::~OxcFlowGenerator() { +} + +void OxcFlowGenerator::setAlgorithm(const std::string& alg_name) { + alg_name_ = alg_name; +} + +void OxcFlowGenerator::setRankTable(const RankTable& ranktable) { + global_ranktable_ = ranktable; + external_ranktable_set_ = true; + std::cout << "[OXC] External RankTable set with " << ranktable.rank_count << " ranks" << std::endl; +} + +void OxcFlowGenerator::setRankRackMap(const std::map& rank_rack_map) { + global_rank_rack_map_ = rank_rack_map; + std::cout << "[OXC] External RankRackMap set with " << rank_rack_map.size() << " entries" << std::endl; +} + +bool OxcFlowGenerator::hasExternalRankTable() const { + return external_ranktable_set_; +} + +bool OxcFlowGenerator::isOxcSupported(CommType comm_type) const { + // 目前OXC只支持AllReduce + return comm_type == CommType::ALL_REDUCE; +} + +int OxcFlowGenerator::getNextFlowId() { + return global_flow_id_++; +} + +int OxcFlowGenerator::getNextOperationId() { + return global_operation_id_++; +} + +const std::vector& OxcFlowGenerator::getAllFlows() const { + return all_flows_; +} + +const std::vector& OxcFlowGenerator::getAllOperations() const { + return all_operations_; +} + +std::vector> OxcFlowGenerator::buildCommDomains( + GroupType group_type, + int total_gpus) { + + // 如果已经缓存了通信域,直接返回 + if (comm_domains_set_) { + auto it = comm_domains_.find(group_type); + if (it != comm_domains_.end()) { + return it->second; + } + // 如果没有找到该类型的通信域,返回空 + return {}; + } + + // 如果没有缓存,动态计算(兼容旧代码) + return computeCommDomains(group_type, total_gpus); +} + +void OxcFlowGenerator::printCommDomainDetails( + const std::string& name, + const std::vector>& domains, + int max_groups_to_print) { + + std::cout << " " << name << ": " << domains.size() << " groups"; + if (!domains.empty()) { + std::cout << ", " << domains[0].size() << " ranks/group"; + } + std::cout << std::endl; + + // 打印每个组的详细信息(限制打印数量避免输出过多) + int groups_to_print = std::min(static_cast(domains.size()), max_groups_to_print); + for (int i = 0; i < groups_to_print; ++i) { + std::cout << " Group " << i << ": ["; + for (size_t j = 0; j < domains[i].size(); ++j) { + if (j > 0) std::cout << ", "; + std::cout << domains[i][j]; + } + std::cout << "]" << std::endl; + } + if (static_cast(domains.size()) > max_groups_to_print) { + std::cout << " ... (" << (domains.size() - max_groups_to_print) + << " more groups)" << std::endl; + } +} + +void OxcFlowGenerator::initCommDomains() { + // 一次性计算所有类型的通信域并缓存 + comm_domains_[GroupType::TP] = computeCommDomains(GroupType::TP, num_gpus_); + comm_domains_[GroupType::DP] = computeCommDomains(GroupType::DP, num_gpus_); + comm_domains_[GroupType::EP] = computeCommDomains(GroupType::EP, num_gpus_); + comm_domains_[GroupType::DP_EP] = computeCommDomains(GroupType::DP_EP, num_gpus_); + comm_domains_set_ = true; + + // 打印通信域详细信息 + std::cout << "[OXC] ========== Communication Domains ==========" << std::endl; + std::cout << "[OXC] Configuration: total_gpus=" << num_gpus_ + << ", gpus_per_server=" << gpus_per_server_ + << ", TP=" << tp_size_ + << ", DP=" << dp_size_ + << ", EP=" << ep_size_ + << ", PP=" << pp_size_ << std::endl; + + // 打印每种通信域的详细信息(最多打印4个组) + printCommDomainDetails("TP", comm_domains_[GroupType::TP], 4); + printCommDomainDetails("DP", comm_domains_[GroupType::DP], 4); + printCommDomainDetails("EP", comm_domains_[GroupType::EP], 4); + printCommDomainDetails("DP_EP", comm_domains_[GroupType::DP_EP], 4); + + std::cout << "[OXC] ============================================" << std::endl; +} + +std::vector> OxcFlowGenerator::computeCommDomains( + GroupType group_type, + int total_gpus) { + + std::vector> domains; + + // 仿照 MockNcclGroup 的通信域创建逻辑 + // 约束: TP_size * DP_size * PP_size = total_gpus + // 约束: EP_size * DP_EP_size = DP_size + + // 参数保护:防止除零和无效参数 + int tp_size = (tp_size_ > 0 && tp_size_ <= total_gpus) ? tp_size_ : 1; + int dp_size = (dp_size_ > 0 && dp_size_ <= total_gpus) ? dp_size_ : 1; + int ep_size = (ep_size_ > 0 && ep_size_ <= total_gpus) ? ep_size_ : 1; + + int TP_nums = total_gpus / tp_size; // TP 组数量 + int DP_nums = total_gpus / dp_size; // DP 组数量 + + // 计算 DP_EP_size (如果 EP_size > 1) + int dp_ep_size = (ep_size > 1 && dp_size >= ep_size) ? (dp_size / ep_size) : 1; + + switch (group_type) { + case GroupType::TP: { + // TP组:连续的 rank 组成一个 TP 组 + // 例如 TP_size=4, total=16: [[0,1,2,3], [4,5,6,7], [8,9,10,11], [12,13,14,15]] + if (tp_size > 1) { + for (int i = 0; i < TP_nums; ++i) { + std::vector domain; + for (int j = 0; j < tp_size; ++j) { + int rank = i * tp_size + j; + domain.push_back(rank); + } + domains.push_back(domain); + } + } + break; + } + case GroupType::DP: { + // DP组:跨 TP 组的相同位置 rank 组成 DP 组 + // 例如 DP_size=4, DP_nums=4, total=16: [[0,4,8,12], [1,5,9,13], [2,6,10,14], [3,7,11,15]] + if (dp_size > 1) { + for (int i = 0; i < DP_nums; ++i) { + std::vector domain; + for (int j = 0; j < dp_size; ++j) { + int rank = i + j * DP_nums; + domain.push_back(rank); + } + domains.push_back(domain); + } + } + break; + } + case GroupType::EP: { + // EP组:基于 TP 组,跨多个连续 TP 组选择相同位置的 rank + // 例如 EP_size=2, TP_size=4, total=16: + // TP组: [0,1,2,3], [4,5,6,7], [8,9,10,11], [12,13,14,15] + // EP组: [0,4], [1,5], [2,6], [3,7], [8,12], [9,13], [10,14], [11,15] + if (ep_size > 1 && TP_nums >= ep_size) { + // 先构建所有 TP 组 + std::vector> tp_groups; + for (int i = 0; i < TP_nums; ++i) { + std::vector tp_group; + for (int j = 0; j < tp_size; ++j) { + tp_group.push_back(i * tp_size + j); + } + tp_groups.push_back(tp_group); + } + + // 每 EP_size 个连续 TP 组形成一组 EP 域 + for (int i = 0; i < TP_nums; i += ep_size) { + // 对于 TP 组内的每个位置 + for (int k = 0; k < tp_size; ++k) { + std::vector domain; + // 从 EP_size 个连续 TP 组中取相同位置的 rank + for (int l = i; l < i + ep_size && l < TP_nums; ++l) { + domain.push_back(tp_groups[l][k]); + } + if (domain.size() > 1) { + domains.push_back(domain); + } + } + } + } + break; + } + case GroupType::DP_EP: { + // DP_EP组:类似 EP,但步长为 EP_size + // 例如 DP_EP_size=2, EP_size=2, TP_size=4, total=16: + // DP_EP组: [0,8], [1,9], [2,10], [3,11], [4,12], [5,13], [6,14], [7,15] + if (dp_ep_size > 1 && ep_size > 0) { + // 先构建所有 TP 组 + std::vector> tp_groups; + for (int i = 0; i < TP_nums; ++i) { + std::vector tp_group; + for (int j = 0; j < tp_size; ++j) { + tp_group.push_back(i * tp_size + j); + } + tp_groups.push_back(tp_group); + } + + // 每隔 EP_size 个 TP 组取一个,共取 DP_EP_size 个 + for (int i = 0; i < ep_size && i < TP_nums; ++i) { + for (int k = 0; k < tp_size; ++k) { + std::vector domain; + for (int l = i; l < TP_nums; l += ep_size) { + domain.push_back(tp_groups[l][k]); + } + if (domain.size() > 1) { + domains.push_back(domain); + } + } + } + } + break; + } + default: { + // 默认:所有GPU在一个组 + std::vector domain; + for (int i = 0; i < total_gpus; ++i) { + domain.push_back(i); + } + domains.push_back(domain); + break; + } + } + + return domains; +} + +std::vector OxcFlowGenerator::convertOxcResponse( + const std::vector& entries, + const OperationContext& ctx) { + + std::vector flows; + + // 按step分组,用于建立依赖关系 + // 使用 map> 优化查找效率 O(1) + std::map> step_dst_to_flow_id; + int base_flow_id = global_flow_id_; + + for (const auto& entry : entries) { + OutputFlow flow; + flow.operation_id = ctx.operation_id; + flow.layer_name = ctx.layer_name; + flow.phase = ctx.phase; + flow.comm_type = ctx.comm_type; + flow.group_type = ctx.group_type; + flow.flow_id = getNextFlowId(); + flow.src = entry.src_rank; + flow.dst = entry.dst_rank; + flow.flow_size = entry.datasize; + flow.step = entry.step; + + // 设置依赖:依赖于前一个step中目标为当前源的流 + if (entry.step > 0) { + auto step_it = step_dst_to_flow_id.find(entry.step - 1); + if (step_it != step_dst_to_flow_id.end()) { + auto dst_it = step_it->second.find(entry.src_rank); + if (dst_it != step_it->second.end()) { + flow.depends_on.push_back(dst_it->second); + } + } + } + + // 记录当前流的 dst -> flow_id 映射,供后续 step 查找 + step_dst_to_flow_id[entry.step][entry.dst_rank] = flow.flow_id; + flows.push_back(flow); + } + + return flows; +} + +std::vector OxcFlowGenerator::generateAllReduceViaOxc( + const OperationContext& ctx, + const std::vector& comm_group_ranks) { + + // 检查是否跨 rack 通信 + // OXC 算法只适用于跨 rack 的通信,同一 rack 内的通信使用原生算法 + std::set racks; + for (int rank : comm_group_ranks) { + std::string rank_str = std::to_string(rank); + auto it = global_rank_rack_map_.find(rank_str); + if (it != global_rank_rack_map_.end()) { + racks.insert(it->second); + } else { + // 如果 rank_rack_map 中没有该 rank,使用 gpus_per_server_ 计算作为回退 + int rack_id = rank / gpus_per_server_; + racks.insert("rack_" + std::to_string(rack_id)); + } + } + + if (racks.size() <= 1) { + // 所有 rank 都在同一个 rack,使用原生算法 + static int native_log_count = 0; + if (native_log_count < 5) { + std::cout << "[OXC] Using NATIVE (same rack): op=" << ctx.operation_id + << ", racks=" << racks.size() << std::endl; + native_log_count++; + } + return generateViaNative(ctx, comm_group_ranks); + } + + // 跨 rack 通信,调用 OXC API + static int oxc_log_count = 0; + if (oxc_log_count < 5) { + std::cout << "[OXC] Calling OXC API (cross-rack): op=" << ctx.operation_id + << ", racks=" << racks.size() << ", ranks=["; + for (size_t i = 0; i < std::min(comm_group_ranks.size(), static_cast(4)); ++i) { + if (i > 0) std::cout << ","; + std::cout << comm_group_ranks[i]; + } + std::cout << "]" << std::endl; + oxc_log_count++; + } + + OxcAllReduceRequest request; + + // 使用全局 RankTable + request.ranktable = global_ranktable_; + + // 构建dpCommDomain - 使用传入的通信组 + request.dpCommDomain.push_back(comm_group_ranks); + + // 设置通信量 + request.commDomainVolume = static_cast(ctx.data_size); + + // 使用全局 rank 到 rack 的映射 + request.rankIdRackIdMap = global_rank_rack_map_; + + // 设置算法名称 + request.algName = alg_name_; + + // 调用OXC API + std::vector entries = http_client_.callAllReduceApi(request); + + if (entries.empty()) { + std::cerr << "[OXC] Warning: Empty response from OXC API for operation " + << ctx.operation_id << ", error: " << http_client_.getLastError() + << std::endl; + // 回退到原生实现 + return generateViaNative(ctx, comm_group_ranks); + } + + std::cout << "[OXC] Received " << entries.size() << " flow entries for " + << ctx.layer_name << " " << phaseToString(ctx.phase) << std::endl; + + return convertOxcResponse(entries, ctx); +} + +std::vector OxcFlowGenerator::generateViaNative( + const OperationContext& ctx, + const std::vector& comm_group_ranks) { + + std::vector flows; + int num_ranks = static_cast(comm_group_ranks.size()); + + if (num_ranks <= 1) { + return flows; + } + + // 简单的Ring算法实现 + switch (ctx.comm_type) { + case CommType::ALL_REDUCE: { + // AllReduce = ReduceScatter + AllGather + // 简化:生成Ring通信模式 + uint64_t chunk_size = ctx.data_size / num_ranks; + + // ReduceScatter阶段 + for (int step = 0; step < num_ranks - 1; ++step) { + for (int i = 0; i < num_ranks; ++i) { + OutputFlow flow; + flow.operation_id = ctx.operation_id; + flow.layer_name = ctx.layer_name; + flow.phase = ctx.phase; + flow.comm_type = ctx.comm_type; + flow.group_type = ctx.group_type; + flow.flow_id = getNextFlowId(); + flow.src = comm_group_ranks[i]; + flow.dst = comm_group_ranks[(i + 1) % num_ranks]; + flow.flow_size = chunk_size; + flow.step = step; + + if (step > 0) { + // 依赖于前一个step + flow.depends_on.push_back(flow.flow_id - num_ranks); + } + + flows.push_back(flow); + } + } + + // AllGather阶段 + for (int step = 0; step < num_ranks - 1; ++step) { + for (int i = 0; i < num_ranks; ++i) { + OutputFlow flow; + flow.operation_id = ctx.operation_id; + flow.layer_name = ctx.layer_name; + flow.phase = ctx.phase; + flow.comm_type = ctx.comm_type; + flow.group_type = ctx.group_type; + flow.flow_id = getNextFlowId(); + flow.src = comm_group_ranks[i]; + flow.dst = comm_group_ranks[(i + 1) % num_ranks]; + flow.flow_size = chunk_size; + flow.step = (num_ranks - 1) + step; + + // 依赖于前一个step + flow.depends_on.push_back(flow.flow_id - num_ranks); + + flows.push_back(flow); + } + } + break; + } + + case CommType::ALL_GATHER: { + // AllGather: 每个rank发送数据到所有其他rank + uint64_t chunk_size = ctx.data_size / num_ranks; + + for (int step = 0; step < num_ranks - 1; ++step) { + for (int i = 0; i < num_ranks; ++i) { + OutputFlow flow; + flow.operation_id = ctx.operation_id; + flow.layer_name = ctx.layer_name; + flow.phase = ctx.phase; + flow.comm_type = ctx.comm_type; + flow.group_type = ctx.group_type; + flow.flow_id = getNextFlowId(); + flow.src = comm_group_ranks[i]; + flow.dst = comm_group_ranks[(i + 1) % num_ranks]; + flow.flow_size = chunk_size; + flow.step = step; + + if (step > 0) { + flow.depends_on.push_back(flow.flow_id - num_ranks); + } + + flows.push_back(flow); + } + } + break; + } + + case CommType::REDUCE_SCATTER: { + // ReduceScatter: 类似AllGather但带reduce + uint64_t chunk_size = ctx.data_size / num_ranks; + + for (int step = 0; step < num_ranks - 1; ++step) { + for (int i = 0; i < num_ranks; ++i) { + OutputFlow flow; + flow.operation_id = ctx.operation_id; + flow.layer_name = ctx.layer_name; + flow.phase = ctx.phase; + flow.comm_type = ctx.comm_type; + flow.group_type = ctx.group_type; + flow.flow_id = getNextFlowId(); + flow.src = comm_group_ranks[i]; + flow.dst = comm_group_ranks[(i + 1) % num_ranks]; + flow.flow_size = chunk_size; + flow.step = step; + + if (step > 0) { + flow.depends_on.push_back(flow.flow_id - num_ranks); + } + + flows.push_back(flow); + } + } + break; + } + + case CommType::ALL_TO_ALL: { + // AllToAll: 每个rank发送不同数据到每个其他rank + uint64_t chunk_size = ctx.data_size / (num_ranks * num_ranks); + + for (int src_idx = 0; src_idx < num_ranks; ++src_idx) { + for (int dst_idx = 0; dst_idx < num_ranks; ++dst_idx) { + if (src_idx == dst_idx) continue; + + OutputFlow flow; + flow.operation_id = ctx.operation_id; + flow.layer_name = ctx.layer_name; + flow.phase = ctx.phase; + flow.comm_type = ctx.comm_type; + flow.group_type = ctx.group_type; + flow.flow_id = getNextFlowId(); + flow.src = comm_group_ranks[src_idx]; + flow.dst = comm_group_ranks[dst_idx]; + flow.flow_size = chunk_size; + flow.step = 0; // AllToAll可以并行执行 + + flows.push_back(flow); + } + } + break; + } + + default: + break; + } + + return flows; +} + +std::vector OxcFlowGenerator::generateFlows( + const OperationContext& ctx, + const std::vector& comm_group_ranks) { + + std::vector flows; + + if (comm_group_ranks.size() <= 1) { + return flows; + } + + // 记录操作 + OperationContext op_ctx = ctx; + op_ctx.operation_id = getNextOperationId(); + op_ctx.base_flow_id = global_flow_id_; + + // 调试输出(每1000个操作输出一次,避免输出过多) + static int op_log_count = 0; + bool should_debug = (op_log_count < 10) || (op_log_count % 1000 == 0); + + if (should_debug) { + std::cout << "[OXC DEBUG] Op " << op_ctx.operation_id + << ": comm_type=" << commTypeToString(ctx.comm_type) + << ", group_type=" << groupTypeToString(ctx.group_type) + << ", phase=" << phaseToString(ctx.phase) + << ", ranks=["; + for (size_t i = 0; i < std::min(comm_group_ranks.size(), static_cast(4)); ++i) { + if (i > 0) std::cout << ","; + std::cout << comm_group_ranks[i]; + } + if (comm_group_ranks.size() > 4) std::cout << "..."; + std::cout << "]" << std::endl; + } + op_log_count++; + + if (isOxcSupported(ctx.comm_type)) { + flows = generateAllReduceViaOxc(op_ctx, comm_group_ranks); + } else { + flows = generateViaNative(op_ctx, comm_group_ranks); + } + + op_ctx.flow_count = static_cast(flows.size()); + all_operations_.push_back(op_ctx); + + // 添加到全局流列表 + all_flows_.insert(all_flows_.end(), flows.begin(), flows.end()); + + return flows; +} + +} // namespace OXC diff --git a/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcFlowGenerator.h b/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcFlowGenerator.h new file mode 100644 index 00000000..e0e95ca1 --- /dev/null +++ b/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcFlowGenerator.h @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2024, Alibaba Group; + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __OXC_FLOW_GENERATOR_H__ +#define __OXC_FLOW_GENERATOR_H__ + +#include +#include +#include +#include "astra-sim/system/OxcTypes.h" +#include "OxcHttpClient.h" + +namespace OXC { + +class OxcFlowGenerator { +public: + OxcFlowGenerator( + const std::string& oxc_server_url, + int num_gpus, + int gpus_per_server, + int tp_size, + int dp_size, + int ep_size, + int pp_size + ); + + ~OxcFlowGenerator(); + + // 为一个集合通信操作生成流 + std::vector generateFlows( + const OperationContext& ctx, + const std::vector& comm_group_ranks + ); + + // 检查OXC是否支持该操作类型 + bool isOxcSupported(CommType comm_type) const; + + // 获取所有生成的流 + const std::vector& getAllFlows() const; + + // 获取所有操作上下文 + const std::vector& getAllOperations() const; + + // 获取下一个流ID + int getNextFlowId(); + + // 获取下一个操作ID + int getNextOperationId(); + + // 设置OXC算法名称 + void setAlgorithm(const std::string& alg_name); + + // 设置外部 RankTable(从 JSON 文件加载) + void setRankTable(const RankTable& ranktable); + + // 设置外部 rank 到 rack 映射 + void setRankRackMap(const std::map& rank_rack_map); + + // 检查是否已设置外部 RankTable + bool hasExternalRankTable() const; + + // 从 MockNcclGroup 设置通信域 + // group_type -> list of comm domains, each domain is a list of ranks + void setCommDomainsFromMockNccl( + const std::map>>& domains + ); + + // 获取指定类型的通信域 + std::vector> getCommDomains(GroupType group_type) const; + + // 构建通信组的rank列表(如果没有从 MockNcclGroup 设置,则自己计算) + std::vector> buildCommDomains( + GroupType group_type, + int total_gpus + ); + +private: + // 通过OXC API生成AllReduce流 + std::vector generateAllReduceViaOxc( + const OperationContext& ctx, + const std::vector& comm_group_ranks + ); + + // 使用原生方式生成流(用于OXC不支持的操作) + std::vector generateViaNative( + const OperationContext& ctx, + const std::vector& comm_group_ranks + ); + + // 初始化全局通信域(在构造函数中调用一次) + void initCommDomains(); + + // 打印通信域详细信息 + void printCommDomainDetails( + const std::string& name, + const std::vector>& domains, + int max_groups_to_print = 4 + ); + + // 计算通信域(内部方法) + std::vector> computeCommDomains( + GroupType group_type, + int total_gpus + ); + + // 将OXC响应转换为OutputFlow + std::vector convertOxcResponse( + const std::vector& entries, + const OperationContext& ctx + ); + + OxcHttpClient http_client_; + std::string oxc_server_url_; + std::string alg_name_; + int num_gpus_; + int gpus_per_server_; + int tp_size_; + int dp_size_; + int ep_size_; + int pp_size_; + + int global_flow_id_; + int global_operation_id_; + std::vector all_flows_; + std::vector all_operations_; + + // 全局 RankTable(仿真任务全局唯一) + RankTable global_ranktable_; + // 全局 rank 到 rack 的映射 + std::map global_rank_rack_map_; + + // 从 MockNcclGroup 获取的通信域 + std::map>> comm_domains_; + bool comm_domains_set_; + + // 是否使用外部 RankTable + bool external_ranktable_set_; +}; + +} // namespace OXC + +#endif // __OXC_FLOW_GENERATOR_H__ diff --git a/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcFlowOutput.cc b/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcFlowOutput.cc new file mode 100644 index 00000000..d2b868eb --- /dev/null +++ b/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcFlowOutput.cc @@ -0,0 +1,221 @@ +/* + * Copyright (c) 2024, Alibaba Group; + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "OxcFlowOutput.h" +#include +#include +#include +#include + +namespace OXC { + +OxcFlowOutput::OxcFlowOutput(const std::string& output_prefix) + : output_prefix_(output_prefix) { +} + +OxcFlowOutput::~OxcFlowOutput() { +} + +bool OxcFlowOutput::writeFlowMatrices(const std::vector& flows) { + std::string filename = output_prefix_ + "_flows.csv"; + std::ofstream ofs(filename); + + if (!ofs.is_open()) { + std::cerr << "[OXC] Error: Cannot open file " << filename << " for writing" << std::endl; + return false; + } + + // 写入CSV头 + ofs << "op_id,layer,phase,comm_type,group,flow_id,src,dst,size,step,depends_on" << std::endl; + + // 写入每个流 + for (const auto& flow : flows) { + ofs << flow.operation_id << "," + << flow.layer_name << "," + << phaseToString(flow.phase) << "," + << commTypeToString(flow.comm_type) << "," + << groupTypeToString(flow.group_type) << "," + << flow.flow_id << "," + << flow.src << "," + << flow.dst << "," + << flow.flow_size << "," + << flow.step << ","; + + // 写入依赖列表 + ofs << "\"["; + for (size_t i = 0; i < flow.depends_on.size(); ++i) { + if (i > 0) ofs << ","; + ofs << flow.depends_on[i]; + } + ofs << "]\"" << std::endl; + } + + ofs.close(); + std::cout << "[OXC] Flow matrices written to " << filename << std::endl; + return true; +} + +bool OxcFlowOutput::writeDependencyGraph( + const std::vector& operations, + const std::vector& flows) { + + std::string filename = output_prefix_ + "_deps.json"; + std::ofstream ofs(filename); + + if (!ofs.is_open()) { + std::cerr << "[OXC] Error: Cannot open file " << filename << " for writing" << std::endl; + return false; + } + + ofs << "{" << std::endl; + + // 写入操作列表 + ofs << " \"operations\": [" << std::endl; + for (size_t i = 0; i < operations.size(); ++i) { + const auto& op = operations[i]; + ofs << " {" << std::endl; + ofs << " \"op_id\": " << op.operation_id << "," << std::endl; + ofs << " \"layer\": \"" << op.layer_name << "\"," << std::endl; + ofs << " \"layer_index\": " << op.layer_index << "," << std::endl; + ofs << " \"phase\": \"" << phaseToString(op.phase) << "\"," << std::endl; + ofs << " \"type\": \"" << commTypeToString(op.comm_type) << "\"," << std::endl; + ofs << " \"group\": \"" << groupTypeToString(op.group_type) << "\"," << std::endl; + ofs << " \"data_size\": " << op.data_size << "," << std::endl; + ofs << " \"flow_count\": " << op.flow_count << "," << std::endl; + ofs << " \"depends_on\": ["; + for (size_t j = 0; j < op.depends_on_ops.size(); ++j) { + if (j > 0) ofs << ", "; + ofs << op.depends_on_ops[j]; + } + ofs << "]" << std::endl; + ofs << " }"; + if (i < operations.size() - 1) ofs << ","; + ofs << std::endl; + } + ofs << " ]," << std::endl; + + // 构建操作间依赖关系 + std::map> op_dependencies; + for (size_t i = 1; i < operations.size(); ++i) { + // 简单的顺序依赖:每个操作依赖于前一个操作 + op_dependencies[operations[i].operation_id].push_back(operations[i-1].operation_id); + } + + // 写入依赖关系 + ofs << " \"dependencies\": {" << std::endl; + bool first = true; + for (const auto& pair : op_dependencies) { + if (!first) ofs << "," << std::endl; + first = false; + ofs << " \"" << pair.first << "\": ["; + for (size_t i = 0; i < pair.second.size(); ++i) { + if (i > 0) ofs << ", "; + ofs << pair.second[i]; + } + ofs << "]"; + } + ofs << std::endl; + ofs << " }" << std::endl; + + ofs << "}" << std::endl; + + ofs.close(); + std::cout << "[OXC] Dependency graph written to " << filename << std::endl; + return true; +} + +bool OxcFlowOutput::writeSummary( + const WorkloadConfig& config, + const std::vector& operations, + const std::vector& flows, + const std::string& oxc_url, + const std::string& alg_name) { + + std::string filename = output_prefix_ + "_summary.txt"; + std::ofstream ofs(filename); + + if (!ofs.is_open()) { + std::cerr << "[OXC] Error: Cannot open file " << filename << " for writing" << std::endl; + return false; + } + + ofs << "SimAI-OXC Flow Generation Summary" << std::endl; + ofs << "=================================" << std::endl; + ofs << std::endl; + + ofs << "Workload Configuration:" << std::endl; + ofs << " Parallelism Policy: " << config.parallelism_policy << std::endl; + ofs << " Total GPUs: " << config.all_gpus << std::endl; + ofs << " GPUs per Server: " << config.gpus_per_server << std::endl; + ofs << " TP Size: " << config.model_parallel_npu_group << std::endl; + ofs << " EP Size: " << config.ep_size << std::endl; + ofs << " PP Size: " << config.pp_size << std::endl; + ofs << " VPP: " << config.vpp << std::endl; + ofs << " GA: " << config.ga << std::endl; + ofs << " Number of Layers: " << config.num_layers << std::endl; + ofs << std::endl; + + ofs << "OXC Configuration:" << std::endl; + ofs << " Server URL: " << oxc_url << std::endl; + ofs << " Algorithm: " << alg_name << std::endl; + ofs << std::endl; + + // 统计各类型操作数量 + std::map op_counts; + std::map oxc_op_counts; + for (const auto& op : operations) { + op_counts[op.comm_type]++; + if (op.comm_type == CommType::ALL_REDUCE) { + oxc_op_counts[op.comm_type]++; + } + } + + ofs << "Operations Processed:" << std::endl; + ofs << " Total Operations: " << operations.size() << std::endl; + for (const auto& pair : op_counts) { + std::string type_str = commTypeToString(pair.first); + bool is_oxc = (pair.first == CommType::ALL_REDUCE); + ofs << " - " << type_str << ": " << pair.second; + if (is_oxc) { + ofs << " (OXC)"; + } else { + ofs << " (Native)"; + } + ofs << std::endl; + } + ofs << std::endl; + + ofs << "Flow Statistics:" << std::endl; + ofs << " Total Flows Generated: " << flows.size() << std::endl; + + // 统计依赖数量 + int total_deps = 0; + for (const auto& flow : flows) { + total_deps += static_cast(flow.depends_on.size()); + } + ofs << " Total Dependencies: " << total_deps << std::endl; + ofs << std::endl; + + ofs << "Output Files:" << std::endl; + ofs << " Flow Matrix: " << output_prefix_ << "_flows.csv" << std::endl; + ofs << " Dependency Graph: " << output_prefix_ << "_deps.json" << std::endl; + ofs << " Summary: " << output_prefix_ << "_summary.txt" << std::endl; + + ofs.close(); + std::cout << "[OXC] Summary written to " << filename << std::endl; + return true; +} + +} // namespace OXC diff --git a/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcFlowOutput.h b/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcFlowOutput.h new file mode 100644 index 00000000..86341c04 --- /dev/null +++ b/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcFlowOutput.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024, Alibaba Group; + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __OXC_FLOW_OUTPUT_H__ +#define __OXC_FLOW_OUTPUT_H__ + +#include +#include +#include "astra-sim/system/OxcTypes.h" + +namespace OXC { + +class OxcFlowOutput { +public: + OxcFlowOutput(const std::string& output_prefix); + ~OxcFlowOutput(); + + // 写入流矩阵到CSV文件,返回是否成功 + bool writeFlowMatrices(const std::vector& flows); + + // 写入依赖图到JSON文件,返回是否成功 + bool writeDependencyGraph( + const std::vector& operations, + const std::vector& flows + ); + + // 写入摘要信息,返回是否成功 + bool writeSummary( + const WorkloadConfig& config, + const std::vector& operations, + const std::vector& flows, + const std::string& oxc_url, + const std::string& alg_name + ); + +private: + std::string output_prefix_; +}; + +} // namespace OXC + +#endif // __OXC_FLOW_OUTPUT_H__ diff --git a/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcHttpClient.cc b/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcHttpClient.cc new file mode 100644 index 00000000..fc086841 --- /dev/null +++ b/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcHttpClient.cc @@ -0,0 +1,326 @@ +/* + * Copyright (c) 2024, Alibaba Group; + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "OxcHttpClient.h" +#include +#include +#include +#include +#include + +namespace OXC { + +// CURL 全局初始化管理器(线程安全,整个程序生命周期只初始化一次) +class CurlGlobalManager { +public: + static CurlGlobalManager& instance() { + static CurlGlobalManager instance; + return instance; + } + + bool isInitialized() const { return initialized_; } + +private: + CurlGlobalManager() : initialized_(false) { + CURLcode res = curl_global_init(CURL_GLOBAL_DEFAULT); + if (res == CURLE_OK) { + initialized_ = true; + } else { + std::cerr << "[OXC] Warning: curl_global_init failed: " + << curl_easy_strerror(res) << std::endl; + } + } + + ~CurlGlobalManager() { + if (initialized_) { + curl_global_cleanup(); + } + } + + // 禁止拷贝和移动 + CurlGlobalManager(const CurlGlobalManager&) = delete; + CurlGlobalManager& operator=(const CurlGlobalManager&) = delete; + + bool initialized_; +}; + +// libcurl 回调函数,用于接收响应数据 +static size_t writeCallback(void* contents, size_t size, size_t nmemb, void* userp) { + size_t total_size = size * nmemb; + std::string* response = static_cast(userp); + response->append(static_cast(contents), total_size); + return total_size; +} + +OxcHttpClient::OxcHttpClient() + : timeout_seconds_(30), initialized_(false) { + // 确保 CURL 全局初始化(线程安全,只执行一次) + CurlGlobalManager::instance(); +} + +OxcHttpClient::~OxcHttpClient() { + // 不再调用 curl_global_cleanup(),由 CurlGlobalManager 在程序退出时处理 +} + +bool OxcHttpClient::initialize(const std::string& base_url) { + base_url_ = base_url; + // 移除末尾的斜杠 + while (!base_url_.empty() && base_url_.back() == '/') { + base_url_.pop_back(); + } + initialized_ = true; + return true; +} + +void OxcHttpClient::setTimeout(int timeout_seconds) { + timeout_seconds_ = timeout_seconds; +} + +std::string OxcHttpClient::getLastError() const { + return last_error_; +} + +std::string OxcHttpClient::httpPost(const std::string& url, const std::string& json_body) { + CURL* curl = curl_easy_init(); + if (!curl) { + last_error_ = "Failed to initialize CURL"; + return ""; + } + + std::string response; + struct curl_slist* headers = nullptr; + headers = curl_slist_append(headers, "Content-Type: application/json"); + + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_body.c_str()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, json_body.size()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response); + curl_easy_setopt(curl, CURLOPT_TIMEOUT, timeout_seconds_); + + CURLcode res = curl_easy_perform(curl); + + if (res != CURLE_OK) { + last_error_ = std::string("CURL error: ") + curl_easy_strerror(res); + response = ""; + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + + return response; +} + +std::string OxcHttpClient::buildRequestJson(const OxcAllReduceRequest& request) { + std::ostringstream oss; + oss << "{"; + + // ranktable + oss << "\"ranktable\":{"; + oss << "\"version\":\"" << request.ranktable.version << "\","; + oss << "\"status\":\"" << request.ranktable.status << "\","; + oss << "\"rank_count\":" << request.ranktable.rank_count << ","; + oss << "\"rank_list\":["; + + for (size_t i = 0; i < request.ranktable.rank_list.size(); ++i) { + const auto& rank = request.ranktable.rank_list[i]; + if (i > 0) oss << ","; + oss << "{"; + oss << "\"rank_id\":" << rank.rank_id << ","; + oss << "\"device_id\":" << rank.device_id << ","; + oss << "\"local_id\":" << rank.local_id << ","; + oss << "\"level_list\":["; + + for (size_t j = 0; j < rank.level_list.size(); ++j) { + const auto& level = rank.level_list[j]; + if (j > 0) oss << ","; + oss << "{"; + oss << "\"net_layer\":" << level.net_layer << ","; + oss << "\"net_instance_id\":\"" << level.net_instance_id << "\","; + oss << "\"net_type\":\"" << level.net_type << "\","; + oss << "\"net_attr\":\"" << level.net_attr << "\","; + oss << "\"rank_addr_list\":["; + + for (size_t k = 0; k < level.rank_addr_list.size(); ++k) { + const auto& addr = level.rank_addr_list[k]; + if (k > 0) oss << ","; + oss << "{"; + oss << "\"addr_type\":\"" << addr.addr_type << "\","; + oss << "\"addr\":\"" << addr.addr << "\","; + oss << "\"ports\":["; + for (size_t p = 0; p < addr.ports.size(); ++p) { + if (p > 0) oss << ","; + oss << "\"" << addr.ports[p] << "\""; + } + oss << "],"; + oss << "\"plane_id\":\"" << addr.plane_id << "\""; + oss << "}"; + } + oss << "]"; + oss << "}"; + } + oss << "]"; + oss << "}"; + } + oss << "]"; + oss << "},"; + + // dpCommDomain + oss << "\"dpCommDomain\":["; + for (size_t i = 0; i < request.dpCommDomain.size(); ++i) { + if (i > 0) oss << ","; + oss << "["; + for (size_t j = 0; j < request.dpCommDomain[i].size(); ++j) { + if (j > 0) oss << ","; + oss << request.dpCommDomain[i][j]; + } + oss << "]"; + } + oss << "],"; + + // commDomainVolume + oss << "\"commDomainVolume\":" << request.commDomainVolume << ","; + + // rankIdRackIdMap + oss << "\"rankIdRackIdMap\":{"; + bool first = true; + for (const auto& pair : request.rankIdRackIdMap) { + if (!first) oss << ","; + first = false; + oss << "\"" << pair.first << "\":\"" << pair.second << "\""; + } + oss << "},"; + + // algName + oss << "\"algName\":\"" << request.algName << "\""; + + oss << "}"; + return oss.str(); +} + +std::vector OxcHttpClient::parseResponseJson(const std::string& response) { + std::vector entries; + + // 简单的JSON数组解析: [[src, dst, step, datasize], ...] + // 跳过空白和开头的 '[' + size_t pos = 0; + while (pos < response.size() && (response[pos] == ' ' || response[pos] == '\n' || response[pos] == '\t')) { + pos++; + } + + if (pos >= response.size() || response[pos] != '[') { + last_error_ = "Invalid response format: expected '[' at start"; + return entries; + } + pos++; // 跳过 '[' + + while (pos < response.size()) { + // 跳过空白 + while (pos < response.size() && (response[pos] == ' ' || response[pos] == '\n' || response[pos] == '\t' || response[pos] == ',')) { + pos++; + } + + if (pos >= response.size() || response[pos] == ']') { + break; // 数组结束 + } + + if (response[pos] != '[') { + last_error_ = "Invalid response format: expected '[' for inner array"; + return entries; + } + pos++; // 跳过内部数组的 '[' + + // 解析四个数字: src, dst, step, datasize + std::vector values; + while (pos < response.size() && response[pos] != ']') { + // 跳过空白和逗号 + while (pos < response.size() && (response[pos] == ' ' || response[pos] == ',' || response[pos] == '\n' || response[pos] == '\t')) { + pos++; + } + + if (response[pos] == ']') break; + + // 解析数字 + bool negative = false; + if (response[pos] == '-') { + negative = true; + pos++; + } + + int64_t num = 0; + while (pos < response.size() && response[pos] >= '0' && response[pos] <= '9') { + num = num * 10 + (response[pos] - '0'); + pos++; + } + if (negative) num = -num; + values.push_back(num); + } + + if (values.size() >= 4) { + OxcFlowEntry entry; + entry.src_rank = static_cast(values[0]); + entry.dst_rank = static_cast(values[1]); + entry.step = static_cast(values[2]); + entry.datasize = static_cast(values[3]); + entries.push_back(entry); + } + + // 跳过 ']' + if (pos < response.size() && response[pos] == ']') { + pos++; + } + } + + return entries; +} + +std::vector OxcHttpClient::callAllReduceApi(const OxcAllReduceRequest& request) { + if (!initialized_) { + last_error_ = "HTTP client not initialized"; + return {}; + } + + std::string url = base_url_ + "/api/oxc/allreduce"; + std::string json_body = buildRequestJson(request); + + std::cout << "[OXC] Calling API: " << url << std::endl; + // 调试:输出请求体的前500个字符 + std::cout << "[OXC] Request body (first 500 chars): " + << json_body.substr(0, std::min(static_cast(500), json_body.size())) << std::endl; + + std::string response = httpPost(url, json_body); + + if (response.empty()) { + std::cerr << "[OXC] Empty response from API" << std::endl; + return {}; + } + + // 打印响应内容 + std::cout << "[OXC] Response (first 1000 chars): " + << response.substr(0, std::min(static_cast(1000), response.size())) << std::endl; + + std::vector entries = parseResponseJson(response); + + if (entries.empty() && !response.empty()) { + // 解析失败,输出调试信息 + std::cerr << "[OXC] Parse failed. Response (first 200 chars): " + << response.substr(0, std::min(static_cast(200), response.size())) << std::endl; + } + + return entries; +} + +} // namespace OXC diff --git a/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcHttpClient.h b/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcHttpClient.h new file mode 100644 index 00000000..7247662a --- /dev/null +++ b/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcHttpClient.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2024, Alibaba Group; + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __OXC_HTTP_CLIENT_H__ +#define __OXC_HTTP_CLIENT_H__ + +#include +#include +#include "astra-sim/system/OxcTypes.h" + +namespace OXC { + +class OxcHttpClient { +public: + OxcHttpClient(); + ~OxcHttpClient(); + + // 初始化,设置服务器URL + bool initialize(const std::string& base_url); + + // 调用 OXC AllReduce API + // 成功返回流条目列表,失败返回空列表 + std::vector callAllReduceApi(const OxcAllReduceRequest& request); + + // 获取最后的错误信息 + std::string getLastError() const; + + // 设置超时时间(秒) + void setTimeout(int timeout_seconds); + +private: + // 构建请求JSON + std::string buildRequestJson(const OxcAllReduceRequest& request); + + // 解析响应JSON + std::vector parseResponseJson(const std::string& response); + + // 执行HTTP POST请求 + std::string httpPost(const std::string& url, const std::string& json_body); + + std::string base_url_; + std::string last_error_; + int timeout_seconds_; + bool initialized_; +}; + +} // namespace OXC + +#endif // __OXC_HTTP_CLIENT_H__ diff --git a/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcMain.cc b/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcMain.cc new file mode 100644 index 00000000..eca53595 --- /dev/null +++ b/astra-sim-alibabacloud/astra-sim/network_frontend/oxc/OxcMain.cc @@ -0,0 +1,637 @@ +/* + * Copyright (c) 2024, Alibaba Group; + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "astra-sim/system/OxcTypes.h" +#include "OxcFlowGenerator.h" +#include "OxcFlowOutput.h" + +using namespace std; +using namespace OXC; + +// 命令行参数结构 +struct OxcParams { + string workload_path; + string ranktable_path; // 外部 RankTable JSON 文件路径 + int num_gpus; + int gpus_per_server; + string oxc_url; + string oxc_algo; + string output_prefix; + + OxcParams() + : num_gpus(16), + gpus_per_server(8), + oxc_url("http://localhost:8080"), + oxc_algo("ALGO_OXC_RING"), + output_prefix("./results/oxc_output") {} +}; + +void printUsage(const char* prog_name) { + cout << "Usage: " << prog_name << " [options]" << endl; + cout << "Options:" << endl; + cout << " -w, --workload Path to workload file (required)" << endl; + cout << " -ranktable Path to RankTable JSON file (required)" << endl; + cout << " -g, --gpus Number of GPUs (default: 16)" << endl; + cout << " -g_p_s, --gpus-per-server GPUs per server (default: 8)" << endl; + cout << " -oxc_url OXC server URL (default: http://localhost:8080)" << endl; + cout << " -oxc_algo OXC algorithm (default: ALGO_OXC_RING)" << endl; + cout << " -o, --output Output file prefix (default: ./results/oxc_output)" << endl; + cout << " -h, --help Show this help message" << endl; +} + +int parseArgs(int argc, char* argv[], OxcParams& params) { + // 手动解析所有参数,避免 getopt 与自定义参数冲突 + for (int i = 1; i < argc; ++i) { + string arg = argv[i]; + if ((arg == "-w" || arg == "--workload") && i + 1 < argc) { + params.workload_path = argv[++i]; + } else if ((arg == "-g" || arg == "--gpus") && i + 1 < argc) { + params.num_gpus = atoi(argv[++i]); + } else if ((arg == "-g_p_s" || arg == "--gpus-per-server") && i + 1 < argc) { + params.gpus_per_server = atoi(argv[++i]); + } else if ((arg == "-o" || arg == "--output") && i + 1 < argc) { + params.output_prefix = argv[++i]; + } else if ((arg == "-ranktable" || arg == "--ranktable") && i + 1 < argc) { + params.ranktable_path = argv[++i]; + } else if (arg == "-oxc_url" && i + 1 < argc) { + params.oxc_url = argv[++i]; + } else if (arg == "-oxc_algo" && i + 1 < argc) { + params.oxc_algo = argv[++i]; + } else if (arg == "-h" || arg == "--help") { + printUsage(argv[0]); + return 1; + } + } + + if (params.workload_path.empty()) { + cerr << "Error: Workload path is required" << endl; + printUsage(argv[0]); + return -1; + } + + if (params.ranktable_path.empty()) { + cerr << "Error: RankTable path is required" << endl; + cerr << " Use generate_ranktable.py to create a RankTable JSON file" << endl; + printUsage(argv[0]); + return -1; + } + + return 0; +} + +// ============== JSON 解析辅助函数 ============== + +// 跳过空白字符 +static size_t skipWhitespace(const string& s, size_t pos) { + while (pos < s.size() && (s[pos] == ' ' || s[pos] == '\n' || s[pos] == '\t' || s[pos] == '\r')) { + pos++; + } + return pos; +} + +// 解析 JSON 字符串值 +static string parseJsonString(const string& s, size_t& pos) { + pos = skipWhitespace(s, pos); + if (pos >= s.size() || s[pos] != '"') return ""; + pos++; // 跳过开头的 " + + string result; + while (pos < s.size() && s[pos] != '"') { + if (s[pos] == '\\' && pos + 1 < s.size()) { + pos++; // 跳过转义字符 + } + result += s[pos++]; + } + if (pos < s.size()) pos++; // 跳过结尾的 " + return result; +} + +// 解析 JSON 整数值 +static int64_t parseJsonInt(const string& s, size_t& pos) { + pos = skipWhitespace(s, pos); + bool negative = false; + if (pos < s.size() && s[pos] == '-') { + negative = true; + pos++; + } + int64_t result = 0; + while (pos < s.size() && s[pos] >= '0' && s[pos] <= '9') { + result = result * 10 + (s[pos] - '0'); + pos++; + } + return negative ? -result : result; +} + +// 查找 JSON 键 +static size_t findJsonKey(const string& s, size_t pos, const string& key) { + string search = "\"" + key + "\""; + size_t found = s.find(search, pos); + if (found != string::npos) { + found += search.size(); + // 跳过冒号 + found = skipWhitespace(s, found); + if (found < s.size() && s[found] == ':') { + found++; + } + } + return found; +} + +// 解析 RankTable JSON 文件 +bool parseRankTableJson(const string& filepath, RankTable& ranktable, map& rank_rack_map) { + ifstream file(filepath); + if (!file.is_open()) { + cerr << "Error: Cannot open RankTable file: " << filepath << endl; + return false; + } + + // 读取整个文件 + stringstream buffer; + buffer << file.rdbuf(); + string json = buffer.str(); + file.close(); + + size_t pos = 0; + + // 解析 version + pos = findJsonKey(json, 0, "version"); + if (pos != string::npos) { + ranktable.version = parseJsonString(json, pos); + } + + // 解析 status + pos = findJsonKey(json, 0, "status"); + if (pos != string::npos) { + ranktable.status = parseJsonString(json, pos); + } + + // 解析 rank_count + pos = findJsonKey(json, 0, "rank_count"); + if (pos != string::npos) { + ranktable.rank_count = static_cast(parseJsonInt(json, pos)); + } + + // 解析 rank_list + pos = findJsonKey(json, 0, "rank_list"); + if (pos == string::npos) { + cerr << "Error: rank_list not found in RankTable JSON" << endl; + return false; + } + + // 找到 rank_list 数组的开始 + pos = skipWhitespace(json, pos); + if (pos >= json.size() || json[pos] != '[') { + cerr << "Error: rank_list is not an array" << endl; + return false; + } + pos++; // 跳过 '[' + + // 解析每个 rank + while (pos < json.size()) { + pos = skipWhitespace(json, pos); + if (pos >= json.size() || json[pos] == ']') break; + if (json[pos] == ',') { pos++; continue; } + + if (json[pos] != '{') break; + + // 找到这个 rank 对象的结束位置 + int brace_count = 1; + size_t rank_start = pos; + pos++; + while (pos < json.size() && brace_count > 0) { + if (json[pos] == '{') brace_count++; + else if (json[pos] == '}') brace_count--; + pos++; + } + string rank_json = json.substr(rank_start, pos - rank_start); + + RankInfo rank_info; + size_t rpos = 0; + + // 解析 rank_id + rpos = findJsonKey(rank_json, 0, "rank_id"); + if (rpos != string::npos) { + rank_info.rank_id = static_cast(parseJsonInt(rank_json, rpos)); + } + + // 解析 device_id + rpos = findJsonKey(rank_json, 0, "device_id"); + if (rpos != string::npos) { + rank_info.device_id = static_cast(parseJsonInt(rank_json, rpos)); + } + + // 解析 local_id + rpos = findJsonKey(rank_json, 0, "local_id"); + if (rpos != string::npos) { + rank_info.local_id = static_cast(parseJsonInt(rank_json, rpos)); + } + + // 解析 level_list + rpos = findJsonKey(rank_json, 0, "level_list"); + if (rpos != string::npos) { + rpos = skipWhitespace(rank_json, rpos); + if (rpos < rank_json.size() && rank_json[rpos] == '[') { + rpos++; + + while (rpos < rank_json.size()) { + rpos = skipWhitespace(rank_json, rpos); + if (rpos >= rank_json.size() || rank_json[rpos] == ']') break; + if (rank_json[rpos] == ',') { rpos++; continue; } + + if (rank_json[rpos] != '{') break; + + // 找到这个 level 对象的结束位置 + int level_brace = 1; + size_t level_start = rpos; + rpos++; + while (rpos < rank_json.size() && level_brace > 0) { + if (rank_json[rpos] == '{') level_brace++; + else if (rank_json[rpos] == '}') level_brace--; + rpos++; + } + string level_json = rank_json.substr(level_start, rpos - level_start); + + LevelInfo level; + size_t lpos = 0; + + lpos = findJsonKey(level_json, 0, "net_layer"); + if (lpos != string::npos) { + level.net_layer = static_cast(parseJsonInt(level_json, lpos)); + } + + lpos = findJsonKey(level_json, 0, "net_instance_id"); + if (lpos != string::npos) { + level.net_instance_id = parseJsonString(level_json, lpos); + } + + lpos = findJsonKey(level_json, 0, "net_type"); + if (lpos != string::npos) { + level.net_type = parseJsonString(level_json, lpos); + } + + lpos = findJsonKey(level_json, 0, "net_attr"); + if (lpos != string::npos) { + level.net_attr = parseJsonString(level_json, lpos); + } + + // 解析 rank_addr_list + lpos = findJsonKey(level_json, 0, "rank_addr_list"); + if (lpos != string::npos) { + lpos = skipWhitespace(level_json, lpos); + if (lpos < level_json.size() && level_json[lpos] == '[') { + lpos++; + + while (lpos < level_json.size()) { + lpos = skipWhitespace(level_json, lpos); + if (lpos >= level_json.size() || level_json[lpos] == ']') break; + if (level_json[lpos] == ',') { lpos++; continue; } + + if (level_json[lpos] != '{') break; + + // 找到这个 addr 对象的结束位置 + int addr_brace = 1; + size_t addr_start = lpos; + lpos++; + while (lpos < level_json.size() && addr_brace > 0) { + if (level_json[lpos] == '{') addr_brace++; + else if (level_json[lpos] == '}') addr_brace--; + lpos++; + } + string addr_json = level_json.substr(addr_start, lpos - addr_start); + + RankAddr addr; + size_t apos = 0; + + apos = findJsonKey(addr_json, 0, "addr_type"); + if (apos != string::npos) { + addr.addr_type = parseJsonString(addr_json, apos); + } + + apos = findJsonKey(addr_json, 0, "addr"); + if (apos != string::npos) { + addr.addr = parseJsonString(addr_json, apos); + } + + apos = findJsonKey(addr_json, 0, "plane_id"); + if (apos != string::npos) { + addr.plane_id = parseJsonString(addr_json, apos); + } + + // 解析 ports 数组 + apos = findJsonKey(addr_json, 0, "ports"); + if (apos != string::npos) { + apos = skipWhitespace(addr_json, apos); + if (apos < addr_json.size() && addr_json[apos] == '[') { + apos++; + while (apos < addr_json.size()) { + apos = skipWhitespace(addr_json, apos); + if (apos >= addr_json.size() || addr_json[apos] == ']') break; + if (addr_json[apos] == ',') { apos++; continue; } + if (addr_json[apos] == '"') { + string port = parseJsonString(addr_json, apos); + if (!port.empty()) { + addr.ports.push_back(port); + } + } else { + break; + } + } + } + } + + level.rank_addr_list.push_back(addr); + } + } + } + + rank_info.level_list.push_back(level); + } + } + } + + ranktable.rank_list.push_back(rank_info); + } + + // 自动生成 rank_rack_map(基于 level_list 中的 net_instance_id) + for (const auto& rank : ranktable.rank_list) { + if (!rank.level_list.empty()) { + // 使用第一个 level 的 net_instance_id 作为 rack_id + rank_rack_map[to_string(rank.rank_id)] = rank.level_list[0].net_instance_id; + } + } + + cout << "[OXC] RankTable loaded from: " << filepath << endl; + cout << "[OXC] Version: " << ranktable.version << endl; + cout << "[OXC] Rank count: " << ranktable.rank_count << endl; + cout << "[OXC] Parsed ranks: " << ranktable.rank_list.size() << endl; + + return true; +} + +// ============== 工作负载解析 ============== + +// 解析工作负载文件 +WorkloadConfig parseWorkload(const string& workload_path) { + WorkloadConfig config; + ifstream file(workload_path); + + if (!file.is_open()) { + cerr << "Error: Cannot open workload file: " << workload_path << endl; + return config; + } + + string line; + + // 读取第一行:并行策略和配置 + if (getline(file, line)) { + istringstream iss(line); + string token; + + // 解析并行策略 + iss >> config.parallelism_policy; + + // 解析键值对 + while (iss >> token) { + if (token == "model_parallel_NPU_group:") { + iss >> config.model_parallel_npu_group; + } else if (token == "ep:") { + iss >> config.ep_size; + } else if (token == "pp:") { + iss >> config.pp_size; + } else if (token == "vpp:") { + iss >> config.vpp; + } else if (token == "ga:") { + iss >> config.ga; + } else if (token == "all_gpus:") { + iss >> config.all_gpus; + } + } + } + + // 读取第二行:层数 + if (getline(file, line)) { + config.num_layers = stoi(line); + } + + // 读取层定义 + while (getline(file, line)) { + if (line.empty()) continue; + + istringstream iss(line); + LayerCommInfo layer; + + string layer_name; + int dependency; + uint64_t fwd_compute, ig_compute, wg_compute, wg_update; + string fwd_type_str, ig_type_str, wg_type_str; + uint64_t fwd_size, ig_size, wg_size; + + // 解析层定义 + // 格式: layer_name dependency fwd_compute fwd_type fwd_size ig_compute ig_type ig_size wg_compute wg_type wg_size wg_update + iss >> layer_name >> dependency + >> fwd_compute >> fwd_type_str >> fwd_size + >> ig_compute >> ig_type_str >> ig_size + >> wg_compute >> wg_type_str >> wg_size + >> wg_update; + + layer.layer_name = layer_name; + layer.layer_index = static_cast(config.layers.size()); + + // 解析通信类型和组类型 + layer.fwd_comm_type = parseCommType(fwd_type_str); + layer.fwd_group_type = parseGroupType(fwd_type_str, TrainingPhase::FORWARD_PASS); + layer.fwd_comm_size = fwd_size; + + layer.ig_comm_type = parseCommType(ig_type_str); + layer.ig_group_type = parseGroupType(ig_type_str, TrainingPhase::INPUT_GRADIENT); + layer.ig_comm_size = ig_size; + + layer.wg_comm_type = parseCommType(wg_type_str); + layer.wg_group_type = parseGroupType(wg_type_str, TrainingPhase::WEIGHT_GRADIENT); + layer.wg_comm_size = wg_size; + + config.layers.push_back(layer); + } + + file.close(); + return config; +} + +int main(int argc, char* argv[]) { + OxcParams params; + + int ret = parseArgs(argc, argv, params); + if (ret != 0) { + return ret; + } + + cout << "SimAI-OXC Flow Generator" << endl; + cout << "========================" << endl; + cout << "Workload: " << params.workload_path << endl; + cout << "GPUs: " << params.num_gpus << endl; + cout << "GPUs per Server: " << params.gpus_per_server << endl; + cout << "RankTable: " << params.ranktable_path << endl; + cout << "OXC URL: " << params.oxc_url << endl; + cout << "OXC Algorithm: " << params.oxc_algo << endl; + cout << "Output Prefix: " << params.output_prefix << endl; + cout << endl; + + // 解析工作负载 + cout << "[OXC] Parsing workload..." << endl; + WorkloadConfig config = parseWorkload(params.workload_path); + + if (config.layers.empty()) { + cerr << "Error: No layers found in workload" << endl; + return -1; + } + + // 更新配置 + config.all_gpus = params.num_gpus; + config.gpus_per_server = params.gpus_per_server; + + cout << "[OXC] Workload parsed: " << config.num_layers << " layers" << endl; + cout << "[OXC] Parallelism: TP=" << config.model_parallel_npu_group + << ", EP=" << config.ep_size + << ", PP=" << config.pp_size << endl; + + // 计算DP大小 + int dp_size = params.num_gpus / config.model_parallel_npu_group; + if (config.ep_size > 1) { + dp_size = dp_size / config.ep_size; + } + + // 创建流生成器 + OxcFlowGenerator flow_gen( + params.oxc_url, + params.num_gpus, + params.gpus_per_server, + config.model_parallel_npu_group, + dp_size, + config.ep_size, + config.pp_size + ); + flow_gen.setAlgorithm(params.oxc_algo); + + // 加载外部 RankTable(必需) + cout << "[OXC] Loading RankTable from: " << params.ranktable_path << endl; + RankTable external_ranktable; + map external_rank_rack_map; + + if (!parseRankTableJson(params.ranktable_path, external_ranktable, external_rank_rack_map)) { + cerr << "Error: Failed to load RankTable from: " << params.ranktable_path << endl; + return -1; + } + flow_gen.setRankTable(external_ranktable); + flow_gen.setRankRackMap(external_rank_rack_map); + + // 处理每一层 + cout << "[OXC] Generating flows..." << endl; + int prev_op_id = -1; + + for (const auto& layer : config.layers) { + // 处理前向传播通信 + if (layer.fwd_comm_type != CommType::NONE && layer.fwd_comm_size > 0) { + OperationContext ctx; + ctx.layer_name = layer.layer_name; + ctx.layer_index = layer.layer_index; + ctx.phase = TrainingPhase::FORWARD_PASS; + ctx.comm_type = layer.fwd_comm_type; + ctx.group_type = layer.fwd_group_type; + ctx.data_size = layer.fwd_comm_size; + if (prev_op_id >= 0) { + ctx.depends_on_ops.push_back(prev_op_id); + } + + // 构建通信组 + auto domains = flow_gen.buildCommDomains(layer.fwd_group_type, params.num_gpus); + for (const auto& domain : domains) { + flow_gen.generateFlows(ctx, domain); + } + if (!flow_gen.getAllOperations().empty()) { + prev_op_id = flow_gen.getAllOperations().back().operation_id; + } + } + + // 处理输入梯度通信 + if (layer.ig_comm_type != CommType::NONE && layer.ig_comm_size > 0) { + OperationContext ctx; + ctx.layer_name = layer.layer_name; + ctx.layer_index = layer.layer_index; + ctx.phase = TrainingPhase::INPUT_GRADIENT; + ctx.comm_type = layer.ig_comm_type; + ctx.group_type = layer.ig_group_type; + ctx.data_size = layer.ig_comm_size; + if (prev_op_id >= 0) { + ctx.depends_on_ops.push_back(prev_op_id); + } + + auto domains = flow_gen.buildCommDomains(layer.ig_group_type, params.num_gpus); + for (const auto& domain : domains) { + flow_gen.generateFlows(ctx, domain); + } + if (!flow_gen.getAllOperations().empty()) { + prev_op_id = flow_gen.getAllOperations().back().operation_id; + } + } + + // 处理权重梯度通信 + if (layer.wg_comm_type != CommType::NONE && layer.wg_comm_size > 0) { + OperationContext ctx; + ctx.layer_name = layer.layer_name; + ctx.layer_index = layer.layer_index; + ctx.phase = TrainingPhase::WEIGHT_GRADIENT; + ctx.comm_type = layer.wg_comm_type; + ctx.group_type = layer.wg_group_type; + ctx.data_size = layer.wg_comm_size; + if (prev_op_id >= 0) { + ctx.depends_on_ops.push_back(prev_op_id); + } + + auto domains = flow_gen.buildCommDomains(layer.wg_group_type, params.num_gpus); + for (const auto& domain : domains) { + flow_gen.generateFlows(ctx, domain); + } + if (!flow_gen.getAllOperations().empty()) { + prev_op_id = flow_gen.getAllOperations().back().operation_id; + } + } + } + + // 输出结果 + cout << "[OXC] Writing output files..." << endl; + OxcFlowOutput output(params.output_prefix); + + const auto& all_flows = flow_gen.getAllFlows(); + const auto& all_ops = flow_gen.getAllOperations(); + + output.writeFlowMatrices(all_flows); + output.writeDependencyGraph(all_ops, all_flows); + output.writeSummary(config, all_ops, all_flows, params.oxc_url, params.oxc_algo); + + cout << endl; + cout << "SimAI-OXC completed successfully." << endl; + cout << " Total Operations: " << all_ops.size() << endl; + cout << " Total Flows: " << all_flows.size() << endl; + + return 0; +} diff --git a/astra-sim-alibabacloud/astra-sim/system/OxcTypes.h b/astra-sim-alibabacloud/astra-sim/system/OxcTypes.h new file mode 100644 index 00000000..a99f8f9b --- /dev/null +++ b/astra-sim-alibabacloud/astra-sim/system/OxcTypes.h @@ -0,0 +1,234 @@ +/* + * Copyright (c) 2024, Alibaba Group; + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __OXC_TYPES_H__ +#define __OXC_TYPES_H__ + +#include +#include +#include +#include + +namespace OXC { + +// RankTable 相关结构体 - 匹配 Java API 格式 + +struct RankAddr { + std::string addr_type; // "EID" + std::string addr; // "000000000000002000100000df001001" + std::vector ports; // ["0/0"] + std::string plane_id; // "plane0" +}; + +struct LevelInfo { + int net_layer; // 0 + std::string net_instance_id; // "superpod1_1" + std::string net_type; // "TOPO_FILE_DESC" + std::string net_attr; // "" + std::vector rank_addr_list; +}; + +struct RankInfo { + int rank_id; + int device_id; + int local_id; + std::vector level_list; +}; + +struct RankTable { + std::string version = "2.0"; + std::string status = "completed"; + int rank_count; + std::vector rank_list; +}; + +// OXC AllReduce API 请求结构 +struct OxcAllReduceRequest { + RankTable ranktable; + std::vector> dpCommDomain; + double commDomainVolume; + std::map rankIdRackIdMap; + std::string algName; // "ALGO_OXC_RING", "ALGO_OXC_HD", "ALGO_OXC_NB" +}; + +// OXC API 响应 - 单个流条目 +// 响应格式: [[src_rank, dst_rank, step, datasize], ...] +struct OxcFlowEntry { + int src_rank; + int dst_rank; + int step; + uint64_t datasize; +}; + +// 通信类型枚举 +enum class CommType { + NONE, + ALL_REDUCE, + ALL_GATHER, + REDUCE_SCATTER, + ALL_TO_ALL, + ALL_REDUCE_ALL_TO_ALL +}; + +// 通信组类型枚举 +enum class GroupType { + TP, // Tensor Parallelism + DP, // Data Parallelism + PP, // Pipeline Parallelism + EP, // Expert Parallelism + DP_EP, // Combined DP and EP + NONE +}; + +// 训练阶段枚举 +enum class TrainingPhase { + FORWARD_PASS, + INPUT_GRADIENT, + WEIGHT_GRADIENT +}; + +// 层通信信息 +struct LayerCommInfo { + std::string layer_name; + int layer_index; + + // 前向传播通信 + CommType fwd_comm_type; + GroupType fwd_group_type; + uint64_t fwd_comm_size; + + // 输入梯度通信 + CommType ig_comm_type; + GroupType ig_group_type; + uint64_t ig_comm_size; + + // 权重梯度通信 + CommType wg_comm_type; + GroupType wg_group_type; + uint64_t wg_comm_size; +}; + +// 工作负载配置 +struct WorkloadConfig { + std::string parallelism_policy; + int model_parallel_npu_group = 1; // TP size + int ep_size = 1; // EP size + int pp_size = 1; // PP size + int vpp = 1; // Virtual PP + int ga = 1; // Gradient Accumulation + int all_gpus = 0; // Total GPUs + int gpus_per_server = 8; // GPUs per server + int num_layers = 0; + std::vector layers; +}; + +// 输出流结构 +struct OutputFlow { + int operation_id; + std::string layer_name; + TrainingPhase phase; + CommType comm_type; + GroupType group_type; + int flow_id; + int src; + int dst; + uint64_t flow_size; + int step; + std::vector depends_on; // 该流依赖的流ID列表 +}; + +// 操作上下文 +struct OperationContext { + int operation_id; + std::string layer_name; + int layer_index; + TrainingPhase phase; + CommType comm_type; + GroupType group_type; + uint64_t data_size; + int base_flow_id; + int flow_count; + std::vector depends_on_ops; // 依赖的操作ID列表 +}; + +// 辅助函数:将枚举转换为字符串 +inline std::string commTypeToString(CommType type) { + switch (type) { + case CommType::NONE: return "NONE"; + case CommType::ALL_REDUCE: return "ALLREDUCE"; + case CommType::ALL_GATHER: return "ALLGATHER"; + case CommType::REDUCE_SCATTER: return "REDUCESCATTER"; + case CommType::ALL_TO_ALL: return "ALLTOALL"; + case CommType::ALL_REDUCE_ALL_TO_ALL: return "ALLREDUCEALLTOALL"; + default: return "UNKNOWN"; + } +} + +inline std::string groupTypeToString(GroupType type) { + switch (type) { + case GroupType::TP: return "TP"; + case GroupType::DP: return "DP"; + case GroupType::PP: return "PP"; + case GroupType::EP: return "EP"; + case GroupType::DP_EP: return "DP_EP"; + case GroupType::NONE: return "NONE"; + default: return "UNKNOWN"; + } +} + +inline std::string phaseToString(TrainingPhase phase) { + switch (phase) { + case TrainingPhase::FORWARD_PASS: return "fwd"; + case TrainingPhase::INPUT_GRADIENT: return "ig"; + case TrainingPhase::WEIGHT_GRADIENT: return "wg"; + default: return "unknown"; + } +} + +// 从字符串解析通信类型 +inline CommType parseCommType(const std::string& str) { + if (str.find("ALLREDUCE") != std::string::npos && + str.find("ALLTOALL") == std::string::npos) { + return CommType::ALL_REDUCE; + } else if (str.find("ALLGATHER") != std::string::npos) { + return CommType::ALL_GATHER; + } else if (str.find("REDUCESCATTER") != std::string::npos) { + return CommType::REDUCE_SCATTER; + } else if (str.find("ALLTOALL") != std::string::npos && + str.find("ALLREDUCE") == std::string::npos) { + return CommType::ALL_TO_ALL; + } else if (str.find("ALLREDUCEALLTOALL") != std::string::npos) { + return CommType::ALL_REDUCE_ALL_TO_ALL; + } + return CommType::NONE; +} + +// 从字符串解析组类型 +inline GroupType parseGroupType(const std::string& str, TrainingPhase phase) { + if (str.find("_DP_EP") != std::string::npos) { + return GroupType::DP_EP; + } else if (str.find("_EP") != std::string::npos) { + return GroupType::EP; + } + // 默认:权重梯度用DP,其他用TP + if (phase == TrainingPhase::WEIGHT_GRADIENT) { + return GroupType::DP; + } + return GroupType::TP; +} + +} // namespace OXC + +#endif // __OXC_TYPES_H__ diff --git a/astra-sim-alibabacloud/build.sh b/astra-sim-alibabacloud/build.sh index 5007f353..40e9bd94 100755 --- a/astra-sim-alibabacloud/build.sh +++ b/astra-sim-alibabacloud/build.sh @@ -3,6 +3,7 @@ SCRIPT_DIR=$(dirname "$(realpath $0)") NS3_BUILD_DIR="${SCRIPT_DIR:?}"/build/astra_ns3 SIMAI_PHY_BUILD_DIR="${SCRIPT_DIR:?}"/build/simai_phy SIMAI_ANALYTICAL_BUILD_DIR="${SCRIPT_DIR:?}"/build/simai_analytical +SIMAI_OXC_BUILD_DIR="${SCRIPT_DIR:?}"/build/simai_oxc SIM_LOG_DIR=/etc/astra-sim # Functions @@ -18,6 +19,9 @@ function cleanup_build { "analytical") cd "${SIMAI_ANALYTICAL_BUILD_DIR}" ./build.sh -l;; + "oxc") + cd "${SIMAI_OXC_BUILD_DIR}" + ./build.sh -l;; esac } @@ -33,6 +37,9 @@ function cleanup_result { "analytical") cd "${SIMAI_ANALYTICAL_BUILD_DIR}" ./build.sh -lr;; + "oxc") + cd "${SIMAI_OXC_BUILD_DIR}" + ./build.sh -lr;; esac } @@ -43,7 +50,7 @@ function compile { mkdir -p "${SIM_LOG_DIR}"/config/ mkdir -p "${SIM_LOG_DIR}"/topo/ mkdir -p "${SIM_LOG_DIR}"/results/ - local option="$1" + local option="$1" cd "${BUILD_DIR}" || exit case "$option" in "ns3") @@ -55,6 +62,9 @@ function compile { "analytical") cd "${SIMAI_ANALYTICAL_BUILD_DIR}" ./build.sh -c;; + "oxc") + cd "${SIMAI_OXC_BUILD_DIR}" + ./build.sh -c;; esac } @@ -68,7 +78,7 @@ case "$1" in compile "$2";; -h|--help|*) printf -- "help message\n" - printf -- "-c|--compile mode supported ns3/phy/analytical (example:./build.sh -c ns3)\n" + printf -- "-c|--compile mode supported ns3/phy/analytical/oxc (example:./build.sh -c ns3)\n" printf -- "-l|--clean (example:./build.sh -l ns3)\n" printf -- "-lr|--clean-result mode (example:./build.sh -lr ns3)\n" esac \ No newline at end of file diff --git a/astra-sim-alibabacloud/build/simai_oxc/CMakeLists.txt b/astra-sim-alibabacloud/build/simai_oxc/CMakeLists.txt new file mode 100644 index 00000000..d9d53a0d --- /dev/null +++ b/astra-sim-alibabacloud/build/simai_oxc/CMakeLists.txt @@ -0,0 +1,26 @@ +# CMake requirement +cmake_minimum_required(VERSION 3.15) + +# C++ requirement +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED True) + +# Compiler requirement +if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 5.3) + message(FATAL_ERROR "g++ (GNU) version should be greater than 5.3, but found ${CMAKE_CXX_COMPILER_VERSION}") + endif() +endif() + + +# Setup project +project (AstraSim_OXC) + +# Use analytical mode to exclude MPI-dependent files +set(USE_ANALYTICAL TRUE) + +# Compile AstraSim library +add_subdirectory("${PROJECT_SOURCE_DIR}/../../" AstraSim) + +# Compile OXC binary +add_subdirectory ("${PROJECT_SOURCE_DIR}/../../astra-sim/network_frontend/oxc" simai_oxc) diff --git a/astra-sim-alibabacloud/build/simai_oxc/build.sh b/astra-sim-alibabacloud/build/simai_oxc/build.sh new file mode 100755 index 00000000..7b0b9b05 --- /dev/null +++ b/astra-sim-alibabacloud/build/simai_oxc/build.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# Absolue path to this script +SCRIPT_DIR=$(dirname "$(realpath $0)") + +# Absolute paths to useful directories +BUILD_DIR="${SCRIPT_DIR:?}"/build/ +RESULT_DIR="${SCRIPT_DIR:?}"/result/ +BIN_DIR="${BUILD_DIR}"/simai_oxc/ +BINARY="./SimAI_oxc" + +# Functions +function cleanup_build { + rm -rf "${BUILD_DIR}" +} + +function cleanup_result { + rm -rf "${RESULT_DIR}" +} + +function setup { + mkdir -p "${BUILD_DIR}" + mkdir -p "${RESULT_DIR}" +} + +function compile { + cd "${BUILD_DIR}" || exit + cmake -DUSE_OXC=TRUE .. + make +} + + +# Main Script +case "$1" in +-l|--clean) + cleanup_build;; +-lr|--clean-result) + cleanup_build + cleanup_result;; +-c|--compile) + setup + compile;; +-h|--help|*) + echo "SimAI_oxc build script." + echo "Run $0 -c to compile.";; +esac diff --git a/docs/SimAI-OXC-Design.md b/docs/SimAI-OXC-Design.md new file mode 100644 index 00000000..17a80fb3 --- /dev/null +++ b/docs/SimAI-OXC-Design.md @@ -0,0 +1,1315 @@ +# SimAI-OXC 集成设计文档 + +## 目录 + +1. [背景与动机](#1-背景与动机) +2. [训练流程与通信模式](#2-训练流程与通信模式) +3. [系统架构](#3-系统架构) +4. [核心数据结构](#4-核心数据结构) +5. [模块设计](#5-模块设计) +6. [流程设计](#6-流程设计) +7. [接口规范](#7-接口规范) +8. [使用指南](#8-使用指南) + +--- + +## 1. 背景与动机 + +### 1.1 项目背景 + +#### SimAI 简介 + +**SimAI** 是阿里巴巴开发的全栈高精度 AI 大规模训练模拟器,已被 NSDI'25 Spring 接收。它提供了对整个 LLM 训练过程的详细建模和仿真,涵盖框架层、集合通信层和网络层。 + +SimAI 支持三种运行模式: +- **SimAI-Analytical**:使用总线带宽(busbw)抽象进行快速仿真,适合快速评估 +- **SimAI-Simulation**:基于 NS3 的全栈仿真,提供细粒度网络通信建模 +- **SimAI-Physical**:物理流量生成模式,用于 CPU RDMA 集群环境 + +SimAI 的核心组件包括: +- **AICB**(AI Communication Benchmark):工作负载生成和测试 +- **SimCCL**:集合通信库仿真 +- **astra-sim-alibabacloud**:扩展自 astra-sim,支持 NCCL 算法 +- **ns-3-alibabacloud**:网络仿真后端 + +#### optical_hccl_system 简介 + +**optical_hccl_system** 是一个基于光交叉连接(OXC, Optical Cross-Connect)的集合通信优化系统。该系统通过光交换技术实现更高效的跨机架通信,特别针对大规模 AI 训练场景中的 AllReduce 操作进行优化。 + +OXC 系统提供多种集合通信算法: +- **ALGO_OXC_RING**:基于环形拓扑的 AllReduce 算法 +- **ALGO_OXC_HD**:Halving-Doubling 算法 +- **ALGO_OXC_NB**:非阻塞算法 + +该系统以 Java 服务的形式运行,通过 HTTP REST API 接收请求并返回优化后的通信流调度方案。 + +### 1.2 集成动机 + +在大规模 AI 训练中,跨机架的集合通信(如 AllReduce)是主要的性能瓶颈之一: + +```mermaid +flowchart TD + subgraph Problem[传统电交换网络的问题] + P1[带宽受限:跨机架链路带宽通常低于机架内带宽] + P2[延迟较高:多跳电交换引入额外延迟] + P3[拥塞严重:大规模 AllReduce 容易造成网络拥塞] + end + + Problem --> Solution + + subgraph Solution[OXC 技术优势] + S1[高带宽:光链路提供更高的传输带宽] + S2[低延迟:减少电光转换开销] + S3[灵活调度:动态配置光路径优化通信模式] + end +``` + +### 1.3 设计目标 + +| 目标 | 描述 | +|------|------| +| **无缝集成** | 将 OXC 算法集成到 SimAI 仿真框架中 | +| **灵活配置** | 支持外部 RankTable 配置,适应不同拓扑 | +| **混合策略** | OXC 处理跨机架通信,原生算法处理机架内通信 | +| **可扩展性** | 支持大规模 GPU 集群仿真 | + +--- + +## 2. 训练流程与通信模式 + +### 2.1 训练流程概述 + +大模型训练的一个迭代包含三个主要阶段: + +```mermaid +flowchart TD + A[加载 Mini-batch 数据] --> B + + subgraph B[前向传播 Forward Pass] + B1[Layer 1 计算] --> B2[Layer 2 计算] --> B3[...] --> B4[Layer N 计算] --> B5[计算 Loss] + end + + B --> C + + subgraph C[反向传播 Backward Pass] + C1[计算 Loss 梯度] --> C2[Layer N 梯度] --> C3[...] --> C4[Layer 2 梯度] --> C5[Layer 1 梯度] + end + + C --> D + + subgraph D[参数更新 Weight Update] + D1[梯度同步 DP AllReduce] --> D2[优化器更新参数] + end + + D --> E[迭代完成] +``` + +### 2.2 并行策略与通信阶段 + +#### 2.2.1 TP (Tensor Parallelism) - 张量并行 + +TP 将单层的参数切分到多个 GPU,发生在每一层的前向和反向传播中。 + +```mermaid +flowchart TD + subgraph Input[输入层] + X[输入 X 在所有 GPU 上相同] + end + + subgraph Forward[前向传播] + direction LR + G0[GPU 0
Y0=X@W0
列切分] + G1[GPU 1
Y1=X@W1] + G2[GPU 2
Y2=X@W2] + G3[GPU 3
Y3=X@W3] + end + + X --> G0 & G1 & G2 & G3 + + G0 & G1 & G2 & G3 --> AG[AllGather
Y = Y0,Y1,Y2,Y3
收集所有部分结果] + + subgraph Backward[反向传播] + direction LR + D0[GPU 0
dW0=X.T@dY0] + D1[GPU 1
dW1=X.T@dY1] + D2[GPU 2
dW2=X.T@dY2] + D3[GPU 3
dW3=X.T@dY3] + end + + AG --> D0 & D1 & D2 & D3 + + D0 & D1 & D2 & D3 --> AR[AllReduce
同步输入梯度 dX] +``` + +#### 2.2.2 DP (Data Parallelism) - 数据并行 + +DP 每个 GPU 有完整模型副本,处理不同数据,发生在反向传播结束后。 + +```mermaid +flowchart TD + subgraph Data[数据分片] + direction LR + D0[GPU 0
Data 0] + D1[GPU 1
Data 1] + D2[GPU 2
Data 2] + D3[GPU 3
Data 3] + end + + subgraph Forward[前向传播] + direction LR + F0[Forward
Data0] + F1[Forward
Data1] + F2[Forward
Data2] + F3[Forward
Data3] + end + + D0 --> F0 + D1 --> F1 + D2 --> F2 + D3 --> F3 + + subgraph Backward[反向传播] + direction LR + B0[Backward
Grad0] + B1[Backward
Grad1] + B2[Backward
Grad2] + B3[Backward
Grad3] + end + + F0 --> B0 + F1 --> B1 + F2 --> B2 + F3 --> B3 + + B0 & B1 & B2 & B3 --> AR[AllReduce
Grad_avg = Σ/4
同步所有 GPU 的梯度] + + subgraph Update[参数更新] + direction LR + U0[W=W-lr*Grad] + U1[W=W-lr*Grad] + U2[W=W-lr*Grad] + U3[W=W-lr*Grad] + end + + AR --> U0 & U1 & U2 & U3 +``` + +#### 2.2.3 PP (Pipeline Parallelism) - 流水线并行 + +PP 将模型按层切分到不同 GPU,发生在层与层之间的数据传递。 + +```mermaid +flowchart LR + subgraph Stage0[Stage 0
Layer 0-3] + F0[F0] --> B0[B0] + end + + subgraph Stage1[Stage 1
Layer 4-7] + F1[F0] --> B1[B0] + end + + subgraph Stage2[Stage 2
Layer 8-11] + F2[F0] --> B2[B0] + end + + subgraph Stage3[Stage 3
Layer 12-15] + F3[F0+Loss] --> B3[B0] + end + + F0 -->|P2P Send| F1 + F1 -->|P2P Send| F2 + F2 -->|P2P Send| F3 + + B3 -->|P2P Send| B2 + B2 -->|P2P Send| B1 + B1 -->|P2P Send| B0 +``` + +**说明**: PP 使用 P2P Send/Recv,不使用集合通信。上图展示 1F1B 调度中 Micro-batch 0 的前向和反向传播流程。 + +#### 2.2.4 EP (Expert Parallelism) - 专家并行 + +EP 用于 MoE (Mixture of Experts) 模型,发生在 MoE 层的 token 路由。 + +```mermaid +flowchart TD + subgraph Experts[专家分布] + direction LR + E0[GPU 0
Expert 0,1] + E1[GPU 1
Expert 2,3] + E2[GPU 2
Expert 4,5] + E3[GPU 3
Expert 6,7] + end + + subgraph Router[Token 路由] + direction LR + R0[Router
决定去向] + R1[Router
决定去向] + R2[Router
决定去向] + R3[Router
决定去向] + end + + E0 --> R0 + E1 --> R1 + E2 --> R2 + E3 --> R3 + + R0 & R1 & R2 & R3 --> A2A1[AlltoAll
交换 tokens 到目标专家] + + subgraph Compute[专家计算] + direction LR + C0[Expert 0,1
计算] + C1[Expert 2,3
计算] + C2[Expert 4,5
计算] + C3[Expert 6,7
计算] + end + + A2A1 --> C0 & C1 & C2 & C3 + + C0 & C1 & C2 & C3 --> A2A2[AlltoAll
返回结果到原 GPU] +``` + +### 2.3 通信动作总结 + +```mermaid +flowchart TD + subgraph Collectives[集合通信操作] + AR[AllReduce] + AG[AllGather] + RS[ReduceScatter] + A2A[AlltoAll] + end + + subgraph Usage[使用场景] + AR --> AR1[TP 反向传播] + AR --> AR2[DP 梯度同步] + AG --> AG1[TP 前向传播] + AG --> AG2[TP 参数收集] + RS --> RS1[ZeRO 梯度分片] + A2A --> A2A1[MoE Token路由] + end + + Note[AllReduce = ReduceScatter + AllGather
用于需要所有节点都获得完整聚合结果的场景] +``` + +### 2.4 通信动作发生阶段总结表 + +| 通信动作 | 并行策略 | 发生阶段 | 通信组 | 是否跨机架 | +|----------|----------|----------|--------|------------| +| AllReduce | TP | 前向/反向每层 | TP 组 | 通常机架内 | +| AllReduce | DP | 反向传播后 | DP 组 | **跨机架** | +| AllGather | TP | 前向传播 | TP 组 | 通常机架内 | +| ReduceScatter | TP/ZeRO | 反向传播 | TP/DP 组 | 视情况 | +| AlltoAll | EP | MoE 层前后 | EP 组 | 视情况 | +| P2P Send/Recv | PP | 层间传递 | 相邻 Stage | 视情况 | + +**关键结论**: + +- **TP 通信**:发生在每层,但通常在机架内(同一服务器的 GPU) +- **DP 通信**:发生在反向传播后,通常跨机架 → **OXC 优化目标** +- **PP 通信**:使用 P2P,不使用集合通信 +- **EP 通信**:使用 AlltoAll,在 MoE 层 + +--- + +## 3. 系统架构 + +### 3.1 整体架构图 + +```mermaid +flowchart TB + subgraph User[用户层] + GR[generate_ranktable.py] + GW[generate_workload.py] + BIN[SimAI_oxc 二进制] + end + + subgraph Input[输入层] + RT[RankTable.json
拓扑配置] + WL[Workload.txt
训练负载] + end + + GR --> RT + GW --> WL + RT --> Core + WL --> Core + + subgraph Core[SimAI-OXC 核心层] + subgraph Main[OxcMain.cc] + PA[参数解析
parseArgs] + RTP[RankTable 解析器] + WLP[Workload 解析器] + end + + subgraph Gen[OxcFlowGenerator] + CD[通信域管理
CommDomains] + FS[流生成策略
OXC/Native] + DEP[依赖追踪
Dependencies] + end + + subgraph HTTP[OxcHttpClient] + POST[HTTP POST
libcurl] + JSON[JSON 序列化
Request] + PARSE[响应解析
Response] + end + + Main --> Gen + Gen --> HTTP + end + + BIN --> Main + + HTTP -->|POST /api/oxc/allreduce| OXC + + subgraph OXC[OXC 服务层 - 外部] + subgraph Java[optical_hccl_system
localhost:8080] + RING[ALGO_OXC_RING
环形算法] + HD[ALGO_OXC_HD
Halving-Doubling] + NB[ALGO_OXC_NB
非阻塞算法] + end + end + + OXC --> Output + + subgraph Output[输出层] + CSV[flows.csv
流矩阵] + DEPS[deps.json
依赖图] + SUM[summary.txt
统计摘要] + end +``` + +### 3.2 模块职责 + +| 模块 | 职责 | 关键文件 | +|------|------|----------| +| **OxcMain** | 程序入口、参数解析、流程编排 | `OxcMain.cc` | +| **OxcFlowGenerator** | 流生成核心逻辑、通信域管理 | `OxcFlowGenerator.h/cc` | +| **OxcHttpClient** | HTTP 通信、JSON 序列化 | `OxcHttpClient.h/cc` | +| **OxcFlowOutput** | 结果输出、格式化 | `OxcFlowOutput.h/cc` | +| **OxcTypes** | 共享数据结构定义 | `OxcTypes.h` | + +### 3.3 文件结构 + +``` +astra-sim-alibabacloud/ +├── astra-sim/ +│ ├── network_frontend/ +│ │ └── oxc/ # OXC 前端模块 +│ │ ├── CMakeLists.txt # 编译配置 +│ │ ├── OxcMain.cc # 主入口 +│ │ ├── OxcFlowGenerator.h # 流生成器头文件 +│ │ ├── OxcFlowGenerator.cc # 流生成器实现 +│ │ ├── OxcHttpClient.h # HTTP 客户端头文件 +│ │ ├── OxcHttpClient.cc # HTTP 客户端实现 +│ │ ├── OxcFlowOutput.h # 输出模块头文件 +│ │ └── OxcFlowOutput.cc # 输出模块实现 +│ └── system/ +│ └── OxcTypes.h # 共享数据结构 +├── build/ +│ └── simai_oxc/ # OXC 编译目录 +│ ├── build.sh # 编译脚本 +│ └── CMakeLists.txt # CMake 配置 +└── ... + +aicb/scripts/ +├── generate_ranktable.py # RankTable 生成脚本 +└── generate_dp_allreduce_workload.py # DP AllReduce 工作负载生成 +``` + +--- + +## 4. 核心数据结构 + +### 4.1 RankTable 结构 + +RankTable 是描述 GPU 集群拓扑的核心配置,定义在 `OxcTypes.h` 中: + +```cpp +// 网络地址信息 +struct RankAddr { + std::string addr_type; // 地址类型,如 "EID" + std::string addr; // 地址值,如 "000000000000002000100000df001001" + std::vector ports; // 端口列表,如 ["0/0"] + std::string plane_id; // 平面ID,如 "plane0" +}; + +// 网络层级信息 +struct LevelInfo { + int net_layer; // 网络层级,0 表示机架级 + std::string net_instance_id; // 网络实例ID,如 "rack_0" + std::string net_type; // 网络类型,如 "TOPO_FILE_DESC" + std::string net_attr; // 网络属性 + std::vector rank_addr_list; // 地址列表 +}; + +// 单个 Rank 的信息 +struct RankInfo { + int rank_id; // 全局 Rank ID + int device_id; // 设备 ID + int local_id; // 本地 ID(服务器内) + std::vector level_list; // 网络层级列表 +}; + +// 完整的 RankTable +struct RankTable { + std::string version = "2.0"; // 版本号 + std::string status = "completed"; // 状态 + int rank_count; // Rank 总数 + std::vector rank_list; // Rank 列表 +}; +``` + +### 4.2 RankTable JSON 示例 + +```json +{ + "version": "2.0", + "status": "completed", + "rank_count": 16, + "rank_list": [ + { + "rank_id": 0, + "device_id": 0, + "local_id": 0, + "level_list": [ + { + "net_layer": 0, + "net_instance_id": "rack_0", + "net_type": "TOPO_FILE_DESC", + "net_attr": "", + "rank_addr_list": [ + { + "addr_type": "EID", + "addr": "000000000000002000100000df001001", + "ports": ["0/0"], + "plane_id": "plane0" + } + ] + } + ] + } + ] +} +``` + +### 4.3 OXC API 请求结构 + +```cpp +struct OxcAllReduceRequest { + RankTable ranktable; // 集群拓扑 + std::vector> dpCommDomain; // DP 通信域 + double commDomainVolume; // 通信数据量(字节) + std::map rankIdRackIdMap; // Rank 到 Rack 映射 + std::string algName; // 算法名称 +}; +``` + +**算法选项**: +| 算法名称 | 描述 | +|----------|------| +| `ALGO_OXC_RING` | 基于环形拓扑的 AllReduce | +| `ALGO_OXC_HD` | Halving-Doubling 算法 | +| `ALGO_OXC_NB` | 非阻塞算法 | + +### 4.4 OXC API 响应结构 + +```cpp +// OXC 服务返回的流条目 +struct OxcFlowEntry { + int src_rank; // 源 Rank + int dst_rank; // 目标 Rank + int step; // 步骤编号 + uint64_t datasize; // 数据大小(字节) +}; +``` + +**响应 JSON 格式**: +```json +[ + [0, 8, 0, 16777216], + [8, 0, 0, 16777216], + [1, 9, 0, 16777216], + ... +] +``` + +### 4.5 输出流结构 + +```cpp +struct OutputFlow { + int operation_id; // 操作 ID + std::string layer_name; // 层名称 + std::string comm_type; // 通信类型 + std::string group_type; // 组类型(DP/TP/EP/PP) + int flow_id; // 流 ID + int src; // 源 Rank + int dst; // 目标 Rank + uint64_t flow_size; // 流大小 + int step; // 步骤 + std::vector depends_on; // 依赖的流 ID 列表 +}; +``` + +### 4.6 操作上下文 + +```cpp +struct OperationContext { + int operation_id; // 操作 ID + std::string layer_name; // 层名称 + std::string phase; // 阶段(fwd/bwd/wg) + CommType comm_type; // 通信类型 + GroupType group_type; // 组类型 + uint64_t msg_size; // 消息大小 + std::vector comm_group; // 通信组成员 +}; +``` + +### 4.7 通信类型枚举 + +```cpp +enum class CommType { + ALLREDUCE, + ALLGATHER, + REDUCESCATTER, + ALLTOALL, + BROADCAST, + REDUCE, + P2P_SEND, + P2P_RECV +}; + +enum class GroupType { + TP, // Tensor Parallelism + DP, // Data Parallelism + EP, // Expert Parallelism + PP // Pipeline Parallelism +}; +``` + +--- + +## 5. 模块设计 + +### 5.1 OxcMain 模块 + +**职责**:程序入口、参数解析、流程编排 + +```mermaid +flowchart TD + MAIN[main] --> PA[parseArgs
解析命令行参数] + + PA --> |"-w/--workload"| W[工作负载文件] + PA --> |"-g/--gpus"| G[GPU 总数] + PA --> |"-g_p_s"| GPS[每服务器 GPU 数] + PA --> |"-ranktable"| RT[RankTable JSON 文件 必需] + PA --> |"-oxc_url"| URL[OXC 服务 URL] + PA --> |"-oxc_algo"| ALG[OXC 算法名称] + PA --> |"-o/--output"| OUT[输出文件前缀] + + PA --> PRT[parseRankTableJson
解析 RankTable 并生成 rank_rack_map] + PRT --> PWL[parseWorkload
解析工作负载文件] + PWL --> GEN[OxcFlowGenerator
创建流生成器实例] + GEN --> FLOW[generateFlows
为每个操作生成流] + FLOW --> OUTPUT[OxcFlowOutput
输出结果文件] +``` + +**关键函数**: + +```cpp +// 解析 RankTable JSON 文件 +bool parseRankTableJson( + const std::string& filepath, + RankTable& ranktable, + std::map& rank_rack_map +); + +// rank_rack_map 自动生成逻辑 +for (const auto& rank : ranktable.rank_list) { + if (!rank.level_list.empty()) { + // 使用第一个 level 的 net_instance_id 作为 rack_id + rank_rack_map[to_string(rank.rank_id)] = rank.level_list[0].net_instance_id; + } +} +``` + +### 5.2 OxcFlowGenerator 模块 + +**职责**:核心流生成逻辑、通信域管理、OXC/Native 策略选择 + +```mermaid +flowchart TD + subgraph OxcFlowGenerator + subgraph Management[管理模块] + CD[通信域管理
CommDomains] + ID[ID 管理
Flow/Op IDs] + end + + subgraph Strategy[策略模块] + FS[流生成策略
OXC / Native] + DEP[依赖追踪
Dependencies] + end + + GF[generateFlows] + + GF --> CHECK{isOxcSupported?} + CHECK -->|是| OXC[generateAllReduceViaOxc] + CHECK -->|否| NATIVE[generateViaNative] + + CD --> GF + ID --> GF + OXC --> DEP + NATIVE --> DEP + end +``` + +**OXC 支持判断**: + +```cpp +bool OxcFlowGenerator::isOxcSupported(CommType comm_type) const { + // 目前仅 DP AllReduce 使用 OXC + return comm_type == CommType::ALLREDUCE; +} +``` + +**通信域构建**: + +```cpp +// 示例:TP=8, DP=2, 总共 16 GPU +// DP 通信域: [[0,8], [1,9], [2,10], [3,11], [4,12], [5,13], [6,14], [7,15]] +// 每个域内的 rank 需要一起做 AllReduce + +std::vector> buildCommDomains(GroupType group_type, int total_gpus) { + std::vector> domains; + + if (group_type == GroupType::DP) { + // DP 域:相同 TP 位置的 rank 组成一个域 + for (int tp_idx = 0; tp_idx < tp_size_; tp_idx++) { + std::vector domain; + for (int dp_idx = 0; dp_idx < dp_size_; dp_idx++) { + domain.push_back(tp_idx + dp_idx * tp_size_); + } + domains.push_back(domain); + } + } + return domains; +} +``` + +### 5.3 OxcHttpClient 模块 + +**职责**:HTTP 通信、JSON 序列化/反序列化 + +```mermaid +flowchart LR + subgraph OxcHttpClient + API[callAllReduceApi] + + API --> BUILD[buildRequestJson
OxcAllReduceRequest → JSON] + BUILD --> POST[httpPost
libcurl POST请求] + POST --> PARSE[parseResponseJson
JSON → OxcFlowEntry] + end + + POST -->|POST /api/oxc/allreduce| SERVER[OXC Server] + SERVER --> POST + + Note[依赖: libcurl
超时: 30秒可配置] +``` + +### 5.4 OxcFlowOutput 模块 + +**职责**:结果输出、格式化 + +```mermaid +flowchart TD + subgraph OxcFlowOutput + CSV[writeFlowsCsv] + JSON[writeDependenciesJson] + SUM[writeSummary] + + CSV --> CSV_OUT[flows.csv
op_id,layer,phase,comm_type,
group,flow_id,src,dst,size,step,depends_on] + + JSON --> JSON_OUT[deps.json
operations: ...
dependencies: ...] + + SUM --> SUM_OUT[summary.txt
工作负载配置
OXC 配置
操作统计
流统计] + end +``` + +--- + +## 6. 流程设计 + +### 6.1 整体执行流程 + +```mermaid +flowchart TD + START([开始执行]) --> PA[1. 解析命令行参数
parseArgs] + + PA --> RT[2. 加载 RankTable
parseRankTableJson] + RT_FILE[(ranktable.json
外部配置文件)] --> RT + RT -->|自动生成 rank_rack_map| WL + + WL[3. 解析工作负载
parseWorkload] + WL_FILE[(workload.txt
训练负载定义)] --> WL + + WL --> INIT[4. 初始化流生成器
OxcFlowGenerator
设置 RankTable
设置通信域] + + INIT --> LOOP[5. 遍历所有操作
for each operation] + + LOOP --> CHECK{OXC 支持?} + CHECK -->|是| OXC_API[OXC API] + CHECK -->|否| NATIVE[Native] + + OXC_API --> GEN[6. 生成流并追踪依赖
generateFlows] + NATIVE --> GEN + + GEN --> MORE{还有更多操作?} + MORE -->|是| LOOP + MORE -->|否| OUTPUT + + OUTPUT[7. 输出结果文件
flows.csv
deps.json
summary.txt] + + OUTPUT --> END([执行完成]) +``` + +### 6.2 OXC API 调用流程 + +```mermaid +sequenceDiagram + participant SimAI as SimAI_oxc + participant OXC as optical_hccl_system + + Note over SimAI: 1. 构建请求 + SimAI->>SimAI: OxcAllReduceRequest {
ranktable: {...},
dpCommDomain: [[0,8],...],
commDomainVolume: 16MB,
rankIdRackIdMap: {...},
algName: "ALGO_OXC_RING"
} + + SimAI->>OXC: 2. HTTP POST /api/oxc/allreduce + + Note over OXC: 3. 计算最优流调度
• 分析拓扑结构
• 选择算法
• 生成流矩阵 + + OXC-->>SimAI: 4. 返回流列表
[[src, dst, step, size], ...] + + Note over SimAI: 5. 转换为 OutputFlow
for each entry:
OutputFlow {
flow_id: next_id++,
src: entry.src_rank,
dst: entry.dst_rank,
step: entry.step,
flow_size: entry.size,
depends_on: [...]
} +``` + +### 6.3 依赖追踪机制 + +```mermaid +flowchart TD + subgraph IntraOp[操作内依赖 - 同一 AllReduce 内] + direction LR + subgraph Step0[Step 0] + F0[Flow0
0→8] + F1[Flow1
8→0] + end + subgraph Step1[Step 1] + F4[Flow4
0→4] + F5[Flow5
8→12] + end + subgraph Step2[Step 2] + F8[Flow8
0→2] + F9[Flow9
8→10] + end + F0 --> F4 --> F8 + F1 --> F5 --> F9 + end + + subgraph InterOp[操作间依赖 - 不同操作之间] + direction LR + OP0[Operation 0
Layer0 AllReduce
Flow 0-31] + OP1[Operation 1
Layer1 AllGather
Flow 32-55] + OP2[Operation 2
Layer1 AllReduce
Flow 56-87] + + OP0 -->|最后一个流 → 第一个流| OP1 + OP1 -->|最后一个流 → 第一个流| OP2 + end + + Rule1[依赖规则1: Step N 的流依赖于 Step N-1 中相同 src_rank 的流] + Rule2[依赖规则2: Operation N 的第一个流依赖于 Operation N-1 的最后一个流] +``` + +### 6.4 Native 流生成(非 OXC 操作) + +```mermaid +flowchart TD + subgraph Config[配置] + COMM[通信组: 0, 1, 2, 3 - 4个rank] + SIZE[数据大小: 64MB] + STEP_SIZE[每步数据: 64MB / 4 = 16MB] + end + + subgraph RingAllGather[Ring AllGather 流程] + subgraph S0[Step 0] + S0_0[0 → 1 16MB] + S0_1[1 → 2 16MB] + S0_2[2 → 3 16MB] + S0_3[3 → 0 16MB] + end + + subgraph S1[Step 1] + S1_0[0 → 1 16MB] + S1_1[1 → 2 16MB] + S1_2[2 → 3 16MB] + S1_3[3 → 0 16MB] + end + + subgraph S2[Step 2] + S2_0[0 → 1 16MB] + S2_1[1 → 2 16MB] + S2_2[2 → 3 16MB] + S2_3[3 → 0 16MB] + end + + S0_3 -.->|依赖| S1_0 + S0_0 -.->|依赖| S1_1 + S0_1 -.->|依赖| S1_2 + S0_2 -.->|依赖| S1_3 + + S1_3 -.->|依赖| S2_0 + S1_0 -.->|依赖| S2_1 + S1_1 -.->|依赖| S2_2 + S1_2 -.->|依赖| S2_3 + end + + TOTAL[总流数: 4 ranks × 3 steps = 12 flows] +``` + +--- + +## 7. 接口规范 + +### 7.1 命令行接口 + +```bash +./bin/SimAI_oxc [选项] +``` + +| 参数 | 简写 | 必需 | 默认值 | 描述 | +|------|------|------|--------|------| +| `--workload` | `-w` | 是 | - | 工作负载文件路径 | +| `--gpus` | `-g` | 是 | - | GPU 总数 | +| `--gpus-per-server` | `-g_p_s` | 否 | 8 | 每服务器 GPU 数 | +| `--ranktable` | - | 是 | - | RankTable JSON 文件路径 | +| `--oxc-url` | - | 否 | http://localhost:8080 | OXC 服务 URL | +| `--oxc-algo` | - | 否 | ALGO_OXC_RING | OXC 算法名称 | +| `--output` | `-o` | 否 | ./results/oxc | 输出文件前缀 | +| `--tp-size` | - | 否 | 8 | Tensor Parallelism 大小 | +| `--dp-size` | - | 否 | 自动计算 | Data Parallelism 大小 | +| `--ep-size` | - | 否 | 1 | Expert Parallelism 大小 | +| `--pp-size` | - | 否 | 1 | Pipeline Parallelism 大小 | + +### 7.2 OXC HTTP API + +#### 7.2.1 AllReduce 端点 + +**请求**: +``` +POST /api/oxc/allreduce +Content-Type: application/json +``` + +**请求体**: +```json +{ + "ranktable": { + "version": "2.0", + "status": "completed", + "rank_count": 16, + "rank_list": [...] + }, + "dpCommDomain": [[0, 8], [1, 9], ...], + "commDomainVolume": 16777216, + "rankIdRackIdMap": { + "0": "rack_0", + "1": "rack_0", + "8": "rack_1", + "9": "rack_1" + }, + "algName": "ALGO_OXC_RING" +} +``` + +**响应**: +```json +[ + [0, 8, 0, 16777216], + [8, 0, 0, 16777216], + [1, 9, 0, 16777216], + [9, 1, 0, 16777216], + ... +] +``` + +**响应字段说明**: +| 索引 | 字段 | 类型 | 描述 | +|------|------|------|------| +| 0 | src_rank | int | 源 Rank ID | +| 1 | dst_rank | int | 目标 Rank ID | +| 2 | step | int | 步骤编号(从 0 开始) | +| 3 | datasize | int64 | 数据大小(字节) | + +### 7.3 输出文件格式 + +#### 7.3.1 流矩阵 CSV (`*_flows.csv`) + +```csv +op_id,layer,phase,comm_type,group,flow_id,src,dst,size,step,depends_on +0,embedding,fwd,ALLREDUCE,DP,0,0,8,16777216,0,[] +0,embedding,fwd,ALLREDUCE,DP,1,8,0,16777216,0,[] +0,embedding,fwd,ALLREDUCE,DP,2,1,9,16777216,0,[] +0,embedding,fwd,ALLREDUCE,DP,3,9,1,16777216,0,[] +0,embedding,fwd,ALLREDUCE,DP,4,0,4,16777216,1,[0] +0,embedding,fwd,ALLREDUCE,DP,5,8,12,16777216,1,[1] +``` + +**字段说明**: +| 字段 | 类型 | 描述 | +|------|------|------| +| op_id | int | 操作 ID | +| layer | string | 层名称 | +| phase | string | 阶段(fwd/bwd/wg) | +| comm_type | string | 通信类型 | +| group | string | 组类型(DP/TP/EP/PP) | +| flow_id | int | 全局流 ID | +| src | int | 源 Rank | +| dst | int | 目标 Rank | +| size | int64 | 数据大小(字节) | +| step | int | 步骤编号 | +| depends_on | list | 依赖的流 ID 列表 | + +#### 7.3.2 依赖图 JSON (`*_deps.json`) + +```json +{ + "operations": [ + { + "op_id": 0, + "layer": "embedding", + "phase": "fwd", + "comm_type": "ALLREDUCE", + "group_type": "DP", + "flow_count": 32, + "first_flow_id": 0, + "last_flow_id": 31 + }, + { + "op_id": 1, + "layer": "attention", + "phase": "fwd", + "comm_type": "ALLGATHER", + "group_type": "TP", + "flow_count": 24, + "first_flow_id": 32, + "last_flow_id": 55 + } + ], + "operation_dependencies": { + "1": [0], + "2": [1] + }, + "flow_dependencies": { + "4": [0], + "5": [1], + "32": [31] + } +} +``` + +#### 7.3.3 摘要文件 (`*_summary.txt`) + +``` +SimAI-OXC Flow Generation Summary +================================= + +Workload Configuration: + Parallelism Policy: HYBRID_TRANSFORMER_FWD_IN_BCKWD + Total GPUs: 16 + GPUs per Server: 8 + TP Size: 8 + EP Size: 1 + PP Size: 1 + +OXC Configuration: + Server URL: http://localhost:8080 + Algorithm: ALGO_OXC_RING + +Operations Processed: + Total Operations: 6168 + - ALLREDUCE: 24 (OXC) + - ALLGATHER: 4096 (Native) + - REDUCESCATTER: 2048 (Native) + +Flow Statistics: + Total Flows Generated: 345888 + Total Dependencies: 296592 + +Output Files: + Flow Matrix: results/oxc_test_flows.csv + Dependency Graph: results/oxc_test_deps.json + Summary: results/oxc_test_summary.txt +``` + +### 7.4 RankTable 生成脚本接口 + +```bash +python3 aicb/scripts/generate_ranktable.py [选项] +``` + +| 参数 | 必需 | 默认值 | 描述 | +|------|------|--------|------| +| `--num-gpus` | 是 | - | GPU 总数 | +| `--gpus-per-server` | 否 | 8 | 每服务器 GPU 数 | +| `--output` | 否 | ranktable.json | 输出文件路径 | +| `--addr-prefix` | 否 | 00000000000000200010 | 地址前缀 | + +**示例**: +```bash +# 生成 16 GPU 的 RankTable +python3 aicb/scripts/generate_ranktable.py \ + --num-gpus 16 \ + --gpus-per-server 8 \ + --output ./config/ranktable_16gpu.json +``` + +--- + +## 8. 使用指南 + +### 8.1 环境准备 + +#### 8.1.1 依赖安装 + +```bash +# Ubuntu/Debian +sudo apt-get update +sudo apt-get install -y build-essential cmake libcurl4-openssl-dev + +# CentOS/RHEL +sudo yum install -y gcc gcc-c++ cmake libcurl-devel +``` + +#### 8.1.2 启动 OXC 服务 + +```bash +# 确保 optical_hccl_system 服务已启动 +cd /path/to/optical_hccl_system +java -jar optical_hccl_system.jar + +# 验证服务状态 +curl http://localhost:8080/health +``` + +### 8.2 编译 + +```bash +# 进入 SimAI 目录 +cd /path/to/SimAI + +# 编译 OXC 模块 +./scripts/build.sh -c oxc + +# 验证编译结果 +ls -la ./bin/SimAI_oxc +``` + +### 8.3 配置文件准备 + +#### 8.3.1 生成 RankTable + +```bash +# 生成 16 GPU(2 服务器 × 8 GPU)的 RankTable +python3 aicb/scripts/generate_ranktable.py \ + --num-gpus 16 \ + --gpus-per-server 8 \ + --output ./config/ranktable.json + +# 生成 128 GPU 的 RankTable +python3 aicb/scripts/generate_ranktable.py \ + --num-gpus 128 \ + --gpus-per-server 8 \ + --output ./config/ranktable_128gpu.json +``` + +#### 8.3.2 准备工作负载文件 + +使用 AICB 生成工作负载: + +```bash +# 生成 Megatron 风格的工作负载 +python3 aicb/workload_generator/generate_megatron_workload.py \ + --model-size 7B \ + --tp 8 \ + --dp 2 \ + --output ./workloads/megatron_7b.txt +``` + +或使用示例工作负载: + +```bash +# 使用内置示例 +cp ./example/workload_analytical.txt ./workloads/ +``` + +### 8.4 运行仿真 + +#### 8.4.1 基本运行 + +```bash +./bin/SimAI_oxc \ + -w ./workloads/megatron_7b.txt \ + -g 16 \ + -g_p_s 8 \ + -ranktable ./config/ranktable.json \ + -oxc_url http://localhost:8080 \ + -oxc_algo ALGO_OXC_RING \ + -o ./results/oxc_test +``` + +#### 8.4.2 使用不同算法 + +```bash +# 使用 Halving-Doubling 算法 +./bin/SimAI_oxc \ + -w ./workloads/megatron_7b.txt \ + -g 16 \ + -g_p_s 8 \ + -ranktable ./config/ranktable.json \ + -oxc_algo ALGO_OXC_HD \ + -o ./results/oxc_hd + +# 使用非阻塞算法 +./bin/SimAI_oxc \ + -w ./workloads/megatron_7b.txt \ + -g 16 \ + -g_p_s 8 \ + -ranktable ./config/ranktable.json \ + -oxc_algo ALGO_OXC_NB \ + -o ./results/oxc_nb +``` + +#### 8.4.3 大规模仿真 + +```bash +# 128 GPU 仿真 +./bin/SimAI_oxc \ + -w ./workloads/large_model.txt \ + -g 128 \ + -g_p_s 8 \ + -ranktable ./config/ranktable_128gpu.json \ + -oxc_url http://localhost:8080 \ + -oxc_algo ALGO_OXC_RING \ + -o ./results/oxc_128gpu +``` + +### 8.5 结果分析 + +#### 8.5.1 查看摘要 + +```bash +cat ./results/oxc_test_summary.txt +``` + +#### 8.5.2 分析流矩阵 + +```bash +# 查看前 20 行 +head -20 ./results/oxc_test_flows.csv + +# 统计各类型操作数量 +awk -F',' 'NR>1 {print $4}' ./results/oxc_test_flows.csv | sort | uniq -c +``` + +#### 8.5.3 可视化依赖图 + +```python +import json +import matplotlib.pyplot as plt +import networkx as nx + +# 加载依赖图 +with open('./results/oxc_test_deps.json', 'r') as f: + deps = json.load(f) + +# 创建有向图 +G = nx.DiGraph() +for op in deps['operations']: + G.add_node(op['op_id'], label=f"{op['layer']}\n{op['comm_type']}") + +for op_id, dep_list in deps['operation_dependencies'].items(): + for dep in dep_list: + G.add_edge(dep, int(op_id)) + +# 绘制 +plt.figure(figsize=(12, 8)) +pos = nx.spring_layout(G) +nx.draw(G, pos, with_labels=True, node_color='lightblue', + node_size=500, font_size=8, arrows=True) +plt.savefig('./results/dependency_graph.png') +``` + +### 8.6 故障排除 + +#### 8.6.1 常见问题 + +| 问题 | 可能原因 | 解决方案 | +|------|----------|----------| +| OXC API 连接失败 | 服务未启动 | 检查 `curl http://localhost:8080/health` | +| RankTable 解析错误 | JSON 格式错误 | 使用 `python -m json.tool ranktable.json` 验证 | +| 流数量为 0 | 工作负载无 AllReduce | 检查工作负载文件内容 | +| 编译失败 | 缺少 libcurl | 安装 `libcurl4-openssl-dev` | + +#### 8.6.2 调试模式 + +```bash +# 启用详细日志 +AS_LOG_LEVEL=DEBUG ./bin/SimAI_oxc \ + -w ./workloads/test.txt \ + -g 16 \ + -g_p_s 8 \ + -ranktable ./config/ranktable.json \ + -o ./results/debug_test +``` + +#### 8.6.3 验证 OXC 服务 + +```bash +# 手动测试 OXC API +curl -X POST http://localhost:8080/api/oxc/allreduce \ + -H "Content-Type: application/json" \ + -d '{ + "ranktable": {"version": "2.0", "status": "completed", "rank_count": 2, "rank_list": []}, + "dpCommDomain": [[0, 1]], + "commDomainVolume": 1048576, + "rankIdRackIdMap": {"0": "rack_0", "1": "rack_1"}, + "algName": "ALGO_OXC_RING" + }' +``` + +--- + +## 附录 + +### A. 术语表 + +| 术语 | 全称 | 描述 | +|------|------|------| +| OXC | Optical Cross-Connect | 光交叉连接,用于光网络中的信号交换 | +| TP | Tensor Parallelism | 张量并行,将单层参数切分到多个 GPU | +| DP | Data Parallelism | 数据并行,每个 GPU 处理不同数据批次 | +| EP | Expert Parallelism | 专家并行,用于 MoE 模型 | +| PP | Pipeline Parallelism | 流水线并行,将模型按层切分 | +| AllReduce | - | 集合通信操作,聚合所有节点数据 | +| AllGather | - | 集合通信操作,收集所有节点数据 | +| ReduceScatter | - | 集合通信操作,聚合并分发数据 | + +### B. 参考资料 + +1. SimAI 论文:[NSDI'25 Spring - SimAI](https://ennanzhai.github.io/pub/nsdi25spring-simai.pdf) +2. AICB 文档:https://github.com/aliyun/aicb +3. optical_hccl_system 文档:内部文档 +4. NCCL 算法:https://docs.nvidia.com/deeplearning/nccl/ + +### C. 版本历史 + +| 版本 | 日期 | 描述 | +|------|------|------| +| 1.0 | 2024-01 | 初始版本,支持 OXC AllReduce | +| 1.1 | 2024-02 | 添加外部 RankTable 支持 | \ No newline at end of file diff --git a/example/microAllReduce_dp.txt b/example/microAllReduce_dp.txt new file mode 100644 index 00000000..7edc891d --- /dev/null +++ b/example/microAllReduce_dp.txt @@ -0,0 +1,4 @@ +HYBRID_TRANSFORMER_FWD_IN_BCKWD model_parallel_NPU_group: 8 ep: 1 pp: 1 vpp: 8 ga: 1 all_gpus: 16 checkpoints: 0 checkpoint_initiates: 0 +2 +embedding_layer -1 556000 ALLREDUCE 16777216 1 NONE 0 1 ALLREDUCE 67108864 1 +transformer_layer -1 556000 ALLREDUCE 16777216 1 NONE 0 1 ALLREDUCE 134217728 1 diff --git a/scripts/build.sh b/scripts/build.sh index 702f246a..c15d768b 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -7,10 +7,11 @@ SIMAI_DIR="${ROOT_DIR:?}"/astra-sim-alibabacloud SOURCE_NS3_BIN_DIR="${SIMAI_DIR:?}"/extern/network_backend/ns3-interface/simulation/build/scratch/ns3.36.1-AstraSimNetwork-debug SOURCE_ANA_BIN_DIR="${SIMAI_DIR:?}"/build/simai_analytical/build/simai_analytical/SimAI_analytical SOURCE_PHY_BIN_DIR="${SIMAI_DIR:?}"/build/simai_phy/build/simai_phynet/SimAI_phynet +SOURCE_OXC_BIN_DIR="${SIMAI_DIR:?}"/build/simai_oxc/build/simai_oxc/SimAI_oxc TARGET_BIN_DIR="${SCRIPT_DIR:?}"/../bin function compile { - local option="$1" + local option="$1" case "$option" in "ns3") mkdir -p "${TARGET_BIN_DIR:?}" @@ -22,7 +23,7 @@ function compile { cp -r "${NS3_DIR:?}"/* "${SIMAI_DIR:?}"/extern/network_backend/ns3-interface cd "${SIMAI_DIR:?}" ./build.sh -lr ns3 - ./build.sh -c ns3 + ./build.sh -c ns3 ln -s "${SOURCE_NS3_BIN_DIR:?}" "${TARGET_BIN_DIR:?}"/SimAI_simulator;; "phy") mkdir -p "${TARGET_BIN_DIR:?}" @@ -31,7 +32,7 @@ function compile { fi cd "${SIMAI_DIR:?}" ./build.sh -lr phy - ./build.sh -c phy + ./build.sh -c phy ln -s "${SOURCE_PHY_BIN_DIR:?}" "${TARGET_BIN_DIR:?}"/SimAI_phynet;; "analytical") mkdir -p "${TARGET_BIN_DIR:?}" @@ -41,8 +42,18 @@ function compile { fi cd "${SIMAI_DIR:?}" ./build.sh -lr analytical - ./build.sh -c analytical + ./build.sh -c analytical ln -s "${SOURCE_ANA_BIN_DIR:?}" "${TARGET_BIN_DIR:?}"/SimAI_analytical;; + "oxc") + mkdir -p "${TARGET_BIN_DIR:?}" + mkdir -p "${ROOT_DIR:?}"/results + if [ -L "${TARGET_BIN_DIR:?}/SimAI_oxc" ]; then + rm -rf "${TARGET_BIN_DIR:?}"/SimAI_oxc + fi + cd "${SIMAI_DIR:?}" + ./build.sh -lr oxc + ./build.sh -c oxc + ln -s "${SOURCE_OXC_BIN_DIR:?}" "${TARGET_BIN_DIR:?}"/SimAI_oxc;; esac } @@ -68,6 +79,12 @@ function cleanup_build { fi cd "${SIMAI_DIR:?}" ./build.sh -lr analytical;; + "oxc") + if [ -L "${TARGET_BIN_DIR:?}/SimAI_oxc" ]; then + rm -rf "${TARGET_BIN_DIR:?}"/SimAI_oxc + fi + cd "${SIMAI_DIR:?}" + ./build.sh -lr oxc;; esac } @@ -79,7 +96,7 @@ case "$1" in compile "$2";; -h|--help|*) printf -- "help message\n" - printf -- "-c|--compile mode supported ns3/phy/analytical (example:./build.sh -c ns3)\n" + printf -- "-c|--compile mode supported ns3/phy/analytical/oxc (example:./build.sh -c ns3)\n" printf -- "-l|--clean (example:./build.sh -l ns3)\n" printf -- "-lr|--clean-result mode (example:./build.sh -lr ns3)\n" esac \ No newline at end of file