From cfa0bbc0fb3d4ae54a58191a860c9bf436633a66 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 9 Apr 2026 03:21:51 +0000 Subject: [PATCH 01/16] refactor: extract shared operator utilities into `_operator_utils.py` Move `OperatorExtractor`, `Operator`, `snake_to_pascal`, and `get_all_ops` out of `generate_wrappers.py` into a reusable module. --- scripts/_operator_utils.py | 103 ++++++++++++++++++++++++++++++++ scripts/generate_wrappers.py | 110 ++--------------------------------- 2 files changed, 108 insertions(+), 105 deletions(-) create mode 100644 scripts/_operator_utils.py diff --git a/scripts/_operator_utils.py b/scripts/_operator_utils.py new file mode 100644 index 00000000..2cb52bba --- /dev/null +++ b/scripts/_operator_utils.py @@ -0,0 +1,103 @@ +import pathlib +import shutil +import subprocess + +import clang.cindex +from clang.cindex import CursorKind + +_SRC_DIR = pathlib.Path("src") + +_BASE_DIR = _SRC_DIR / "base" + + +def snake_to_pascal(snake_str): + return "".join(word.capitalize() for word in snake_str.split("_")) + + +class Operator: + def __init__(self, name, constructors, calls): + self.name = name + + self.constructors = constructors + + self.calls = calls + + +class OperatorExtractor: + def __call__(self, op_name): + def _get_system_include_flags(): + def _get_compilers(): + compilers = [] + + for compiler in ("clang++", "g++"): + if shutil.which(compiler) is not None: + compilers.append(compiler) + + return compilers + + system_include_flags = [] + + for compiler in _get_compilers(): + for line in subprocess.getoutput( + f"{compiler} -E -x c++ -v /dev/null" + ).splitlines(): + if not line.startswith(" "): + continue + + system_include_flags.append("-isystem") + system_include_flags.append(line.strip()) + + return system_include_flags + + system_include_flags = _get_system_include_flags() + + index = clang.cindex.Index.create() + args = ("-std=c++17", "-x", "c++", "-I", "src") + tuple(system_include_flags) + translation_unit = index.parse(f"src/base/{op_name}.h", args=args) + + nodes = tuple(type(self)._find(translation_unit.cursor, op_name)) + + constructors = [] + calls = [] + + for node in nodes: + if node.kind == CursorKind.CONSTRUCTOR: + constructors.append(node) + elif node.kind == CursorKind.CXX_METHOD and node.spelling == "operator()": + calls.append(node) + + return Operator(op_name, constructors, calls) + + @staticmethod + def _find(node, op_name): + pascal_case_op_name = snake_to_pascal(op_name) + + if ( + node.semantic_parent + and node.semantic_parent.spelling == pascal_case_op_name + ): + yield node + + for child in node.get_children(): + yield from OperatorExtractor._find(child, op_name) + + +def get_all_ops(devices): + ops = {} + + for file_path in _BASE_DIR.iterdir(): + if not file_path.is_file(): + continue + + op_name = file_path.stem + + ops[op_name] = [] + + for file_path in _SRC_DIR.rglob("*"): + if not file_path.is_file() or file_path.parent.parent.name not in devices: + continue + + if f"class Operator<{snake_to_pascal(op_name)}" in file_path.read_text(): + ops[op_name].append(file_path) + + return ops diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 5aa8896e..271eb69c 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -1,16 +1,9 @@ import argparse import json import pathlib -import shutil -import subprocess import textwrap -import clang.cindex -from clang.cindex import CursorKind - -_SRC_DIR = pathlib.Path("src") - -_BASE_DIR = _SRC_DIR / "base" +from _operator_utils import OperatorExtractor, get_all_ops, snake_to_pascal _GENERATION_DIR = pathlib.Path("generated") @@ -23,74 +16,6 @@ _INDENTATION = " " -class _OperatorExtractor: - def __call__(self, op_name): - def _get_system_include_flags(): - def _get_compilers(): - compilers = [] - - for compiler in ("clang++", "g++"): - if shutil.which(compiler) is not None: - compilers.append(compiler) - - return compilers - - system_include_flags = [] - - for compiler in _get_compilers(): - for line in subprocess.getoutput( - f"{compiler} -E -x c++ -v /dev/null" - ).splitlines(): - if not line.startswith(" "): - continue - - system_include_flags.append("-isystem") - system_include_flags.append(line.strip()) - - return system_include_flags - - system_include_flags = _get_system_include_flags() - - index = clang.cindex.Index.create() - args = ("-std=c++17", "-x", "c++", "-I", "src") + tuple(system_include_flags) - translation_unit = index.parse(f"src/base/{op_name}.h", args=args) - - nodes = tuple(type(self)._find(translation_unit.cursor, op_name)) - - constructors = [] - calls = [] - - for node in nodes: - if node.kind == CursorKind.CONSTRUCTOR: - constructors.append(node) - elif node.kind == CursorKind.CXX_METHOD and node.spelling == "operator()": - calls.append(node) - - return _Operator(op_name, constructors, calls) - - @staticmethod - def _find(node, op_name): - pascal_case_op_name = _snake_to_pascal(op_name) - - if ( - node.semantic_parent - and node.semantic_parent.spelling == pascal_case_op_name - ): - yield node - - for child in node.get_children(): - yield from _OperatorExtractor._find(child, op_name) - - -class _Operator: - def __init__(self, name, constructors, calls): - self.name = name - - self.constructors = constructors - - self.calls = calls - - def _generate_pybind11(operator): def _generate_params(node): return ( @@ -159,7 +84,7 @@ def _generate_call(op_name, call, method=True): _generate_call(operator.name, call, method=False) for call in operator.calls ) - pascal_case_op_name = _snake_to_pascal(op_name) + pascal_case_op_name = snake_to_pascal(op_name) return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_ #define INFINI_OPS_BINDINGS_{op_name.upper()}_H_ @@ -368,31 +293,6 @@ def _generate_tensor_caster(name, is_data=False): return _generate_source(operator), _generate_header(operator) -def _snake_to_pascal(snake_str): - return "".join(word.capitalize() for word in snake_str.split("_")) - - -def _get_all_ops(devices): - ops = {} - - for file_path in _BASE_DIR.iterdir(): - if not file_path.is_file(): - continue - - op_name = file_path.stem - - ops[op_name] = [] - - for file_path in _SRC_DIR.rglob("*"): - if not file_path.is_file() or file_path.parent.parent.name not in devices: - continue - - if f"class Operator<{_snake_to_pascal(op_name)}" in file_path.read_text(): - ops[op_name].append(file_path) - - return ops - - if __name__ == "__main__": parser = argparse.ArgumentParser(description="An automatic wrapper generator.") @@ -415,18 +315,18 @@ def _get_all_ops(devices): if ops_json.exists(): ops = json.loads(ops_json.read_text()) else: - ops = _get_all_ops(args.devices) + ops = get_all_ops(args.devices) header_paths = [] bind_func_names = [] for op_name, impl_paths in ops.items(): - extractor = _OperatorExtractor() + extractor = OperatorExtractor() operator = extractor(op_name) source_path = _GENERATED_SRC_DIR / op_name header_name = f"{op_name}.h" - bind_func_name = f"Bind{_snake_to_pascal(op_name)}" + bind_func_name = f"Bind{snake_to_pascal(op_name)}" (_BINDINGS_DIR / header_name).write_text(_generate_pybind11(operator)) From d035ee8bec332746f46ff145a019067afd2f98fd Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 9 Apr 2026 03:27:37 +0000 Subject: [PATCH 02/16] refactor: extract pybind11 generator into `_generate_pybind11.py` --- scripts/_generate_pybind11.py | 104 +++++++++++++++++++++++++++++++++ scripts/generate_wrappers.py | 106 +--------------------------------- 2 files changed, 106 insertions(+), 104 deletions(-) create mode 100644 scripts/_generate_pybind11.py diff --git a/scripts/_generate_pybind11.py b/scripts/_generate_pybind11.py new file mode 100644 index 00000000..89a902d2 --- /dev/null +++ b/scripts/_generate_pybind11.py @@ -0,0 +1,104 @@ +from _operator_utils import snake_to_pascal + + +def generate_pybind11(operator): + def _generate_params(node): + return ( + ", ".join( + f"{arg.type.spelling} {arg.spelling}" + for arg in node.get_arguments() + if arg.spelling != "stream" + ) + .replace("const Tensor", "py::object") + .replace("Tensor", "py::object") + ) + + def _generate_arguments(node): + return ", ".join( + f"TensorFromPybind11Handle({arg.spelling})" + if "Tensor" in arg.type.spelling + else arg.spelling + for arg in node.get_arguments() + if arg.spelling != "stream" + ) + + op_name = operator.name + + def _generate_init(constructor): + constructor_params = _generate_params(constructor) + + return f""" .def(py::init([]({constructor_params}) {{ + return std::unique_ptr{{static_cast(Self::make({_generate_arguments(constructor)}).release())}}; + }}))""" + + def _generate_py_args(node): + return ", ".join( + f'py::arg("{arg.spelling}")' + for arg in node.get_arguments() + if arg.spelling != "stream" + ) + + def _generate_call(op_name, call, method=True): + call_params = _generate_params(call) + call_args = _generate_arguments(call) + + if not method: + params = ( + f"{call_params}, std::size_t implementation_index" + if call_params + else "std::size_t implementation_index" + ) + py_args = _generate_py_args(call) + py_args_str = f"{py_args}, " if py_args else "" + + return f""" m.def("{op_name}", []({params}) {{ + Config config; + config.set_implementation_index(implementation_index); + return Self::call({{}}, config, {call_args}); + }}, {py_args_str}py::kw_only(), py::arg("implementation_index") = 0);""" + + return f""" .def("__call__", [](const Self& self, {call_params}) {{ + return static_cast&>(self)({call_args}); + }})""" + + inits = "\n".join( + _generate_init(constructor) for constructor in operator.constructors + ) + calls = "\n".join(_generate_call(operator.name, call) for call in operator.calls) + callers = "\n".join( + _generate_call(operator.name, call, method=False) for call in operator.calls + ) + + pascal_case_op_name = snake_to_pascal(op_name) + + return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_ +#define INFINI_OPS_BINDINGS_{op_name.upper()}_H_ + +#include +#include + +#include "base/{op_name}.h" +#include "config.h" +#include "pybind11_utils.h" + +namespace py = pybind11; + +namespace infini::ops {{ + +void Bind{pascal_case_op_name}(py::module& m) {{ + using Self = {pascal_case_op_name}; + + py::class_(m, "{pascal_case_op_name}") +{inits} +{calls} + .def_static("active_implementation_indices", [](const std::string& device) {{ + return Self::active_implementation_indices(DeviceTypeFromString(device)); + }}); + +{callers} +}} + +}} // namespace infini::ops + +#endif +""" diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 271eb69c..dfc13979 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -4,6 +4,7 @@ import textwrap from _operator_utils import OperatorExtractor, get_all_ops, snake_to_pascal +from _generate_pybind11 import generate_pybind11 _GENERATION_DIR = pathlib.Path("generated") @@ -16,109 +17,6 @@ _INDENTATION = " " -def _generate_pybind11(operator): - def _generate_params(node): - return ( - ", ".join( - f"{arg.type.spelling} {arg.spelling}" - for arg in node.get_arguments() - if arg.spelling != "stream" - ) - .replace("const Tensor", "py::object") - .replace("Tensor", "py::object") - ) - - def _generate_arguments(node): - return ", ".join( - f"TensorFromPybind11Handle({arg.spelling})" - if "Tensor" in arg.type.spelling - else arg.spelling - for arg in node.get_arguments() - if arg.spelling != "stream" - ) - - op_name = operator.name - - def _generate_init(constructor): - constructor_params = _generate_params(constructor) - - return f""" .def(py::init([]({constructor_params}) {{ - return std::unique_ptr{{static_cast(Self::make({_generate_arguments(constructor)}).release())}}; - }}))""" - - def _generate_py_args(node): - return ", ".join( - f'py::arg("{arg.spelling}")' - for arg in node.get_arguments() - if arg.spelling != "stream" - ) - - def _generate_call(op_name, call, method=True): - call_params = _generate_params(call) - call_args = _generate_arguments(call) - - if not method: - params = ( - f"{call_params}, std::size_t implementation_index" - if call_params - else "std::size_t implementation_index" - ) - py_args = _generate_py_args(call) - py_args_str = f"{py_args}, " if py_args else "" - - return f""" m.def("{op_name}", []({params}) {{ - Config config; - config.set_implementation_index(implementation_index); - return Self::call({{}}, config, {call_args}); - }}, {py_args_str}py::kw_only(), py::arg("implementation_index") = 0);""" - - return f""" .def("__call__", [](const Self& self, {call_params}) {{ - return static_cast&>(self)({call_args}); - }})""" - - inits = "\n".join( - _generate_init(constructor) for constructor in operator.constructors - ) - calls = "\n".join(_generate_call(operator.name, call) for call in operator.calls) - callers = "\n".join( - _generate_call(operator.name, call, method=False) for call in operator.calls - ) - - pascal_case_op_name = snake_to_pascal(op_name) - - return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_ -#define INFINI_OPS_BINDINGS_{op_name.upper()}_H_ - -#include -#include - -#include "base/{op_name}.h" -#include "config.h" -#include "pybind11_utils.h" - -namespace py = pybind11; - -namespace infini::ops {{ - -void Bind{pascal_case_op_name}(py::module& m) {{ - using Self = {pascal_case_op_name}; - - py::class_(m, "{pascal_case_op_name}") -{inits} -{calls} - .def_static("active_implementation_indices", [](const std::string& device) {{ - return Self::active_implementation_indices(DeviceTypeFromString(device)); - }}); - -{callers} -}} - -}} // namespace infini::ops - -#endif -""" - - def _generate_legacy_c(operator, paths): def _generate_source(operator): impl_includes = "\n".join( @@ -328,7 +226,7 @@ def _generate_tensor_caster(name, is_data=False): header_name = f"{op_name}.h" bind_func_name = f"Bind{snake_to_pascal(op_name)}" - (_BINDINGS_DIR / header_name).write_text(_generate_pybind11(operator)) + (_BINDINGS_DIR / header_name).write_text(generate_pybind11(operator)) legacy_c_source, legacy_c_header = _generate_legacy_c(operator, impl_paths) source_path.mkdir(exist_ok=True) From cb348904ac52805862b0de2aa710f57127cc356f Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 9 Apr 2026 04:47:08 +0000 Subject: [PATCH 03/16] refactor: extract legacy C generator into `_generate_legacy_c.py` Also fix the generated output to be compatible with InfiniCore: - Use `__INFINI_C` instead of `__C`. - Use `InfiniopDescriptor *` as the descriptor base type. - Include `../operator_descriptor.h` instead of InfiniOps base headers. - Use PascalCase for C API names (e.g. `infiniopCreateGemmDescriptor`). - Set stream/workspace via `Handle` setters instead of passing as arguments. - Use `reinterpret_cast` for descriptor type conversions. --- scripts/_generate_legacy_c.py | 175 +++++++++++++++++++++++++++++++++ scripts/generate_wrappers.py | 178 +--------------------------------- 2 files changed, 177 insertions(+), 176 deletions(-) create mode 100644 scripts/_generate_legacy_c.py diff --git a/scripts/_generate_legacy_c.py b/scripts/_generate_legacy_c.py new file mode 100644 index 00000000..2397d5b7 --- /dev/null +++ b/scripts/_generate_legacy_c.py @@ -0,0 +1,175 @@ +from _operator_utils import snake_to_pascal + + +def generate_legacy_c(operator, paths): + pascal_name = snake_to_pascal(operator.name) + + def _generate_source(operator): + impl_includes = "\n".join( + f'#include "{str(path).removeprefix("src/")}"' for path in paths + ) + + return f"""#include "../../handle.h" +#include "../../tensor.h" +#include "infiniop/ops/{operator.name}.h" +{impl_includes} + +static infini::ops::DataType DataTypeFromInfiniDType( + const infiniDtype_t& dtype) {{ + static constexpr infini::ops::ConstexprMap + kInfiniDTypeToDataType{{ + {{{{{{INFINI_DTYPE_I8, infini::ops::DataType::kInt8}}, + {{INFINI_DTYPE_I16, infini::ops::DataType::kInt16}}, + {{INFINI_DTYPE_I32, infini::ops::DataType::kInt32}}, + {{INFINI_DTYPE_I64, infini::ops::DataType::kInt64}}, + {{INFINI_DTYPE_U8, infini::ops::DataType::kUInt8}}, + {{INFINI_DTYPE_U16, infini::ops::DataType::kUInt16}}, + {{INFINI_DTYPE_U32, infini::ops::DataType::kUInt32}}, + {{INFINI_DTYPE_U64, infini::ops::DataType::kUInt64}}, + {{INFINI_DTYPE_F16, infini::ops::DataType::kFloat16}}, + {{INFINI_DTYPE_BF16, infini::ops::DataType::kBFloat16}}, + {{INFINI_DTYPE_F32, infini::ops::DataType::kFloat32}}, + {{INFINI_DTYPE_F64, infini::ops::DataType::kFloat64}}}}}}}}; + + return kInfiniDTypeToDataType.at(dtype); +}} + +static infini::ops::Device::Type DeviceTypeFromInfiniDevice( + const infiniDevice_t& device) {{ + static constexpr infini::ops::ConstexprMap< + infiniDevice_t, infini::ops::Device::Type, + static_cast(INFINI_DEVICE_TYPE_COUNT)> + kInfiniDeviceToDeviceType{{ + {{{{{{INFINI_DEVICE_CPU, infini::ops::Device::Type::kCpu}}, + {{INFINI_DEVICE_NVIDIA, infini::ops::Device::Type::kNvidia}}, + {{INFINI_DEVICE_CAMBRICON, infini::ops::Device::Type::kCambricon}}, + {{INFINI_DEVICE_ASCEND, infini::ops::Device::Type::kAscend}}, + {{INFINI_DEVICE_METAX, infini::ops::Device::Type::kMetax}}, + {{INFINI_DEVICE_MOORE, infini::ops::Device::Type::kMoore}}, + {{INFINI_DEVICE_ILUVATAR, infini::ops::Device::Type::kIluvatar}}, + {{INFINI_DEVICE_KUNLUN, infini::ops::Device::Type::kKunlun}}, + {{INFINI_DEVICE_HYGON, infini::ops::Device::Type::kHygon}}, + {{INFINI_DEVICE_QY, infini::ops::Device::Type::kQy}}}}}}}}; + + return kInfiniDeviceToDeviceType.at(device); +}} + +__INFINI_C {_generate_create_func_def(operator)} + +__INFINI_C {_generate_get_workspace_size_func_def(operator)} + +__INFINI_C {_generate_call_func_def(operator)} + +__INFINI_C {_generate_destroy_func_def(operator)} +""" + + def _generate_header(operator): + return f"""#ifndef __INFINIOP_{operator.name.upper()}_API_H__ +#define __INFINIOP_{operator.name.upper()}_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniop{pascal_name}Descriptor_t; + +__INFINI_C __export {_generate_create_func_decl(operator)}; + +__INFINI_C __export {_generate_get_workspace_size_func_decl(operator)}; + +__INFINI_C __export {_generate_call_func_decl(operator)}; + +__INFINI_C __export {_generate_destroy_func_decl(operator)}; + +#endif +""" + + def _generate_create_func_def(operator): + constructor = operator.constructors[-1] + + return f"""{_generate_create_func_decl(operator)} {{ + *desc_ptr = reinterpret_cast(infini::ops::Operator::make({_generate_arguments(constructor)}).release()); + + return INFINI_STATUS_SUCCESS; +}}""" + + def _generate_get_workspace_size_func_def(operator): + return f"""{_generate_get_workspace_size_func_decl(operator)} {{ + *size = 0; // desc->workspace_size(); + + return INFINI_STATUS_SUCCESS; +}}""" + + def _generate_call_func_def(operator): + call = operator.calls[-1] + + return f"""{_generate_call_func_decl(operator)} {{ + auto *op = reinterpret_cast *>(desc); + op->set_stream(stream); + op->set_workspace(workspace); + op->set_workspace_size_in_bytes(workspace_size); + const auto &op_ref = *op; + op_ref({_generate_arguments(call, is_data=True)}); + + return INFINI_STATUS_SUCCESS; +}}""" + + def _generate_destroy_func_def(operator): + return f"""{_generate_destroy_func_decl(operator)} {{ + delete reinterpret_cast *>(desc); + + return INFINI_STATUS_SUCCESS; +}}""" + + def _generate_create_func_decl(operator): + constructor = operator.constructors[-1] + params = _generate_params(constructor) + + return f"infiniStatus_t infiniopCreate{pascal_name}Descriptor(infiniopHandle_t handle, infiniop{pascal_name}Descriptor_t *desc_ptr, {params})" + + def _generate_get_workspace_size_func_decl(operator): + return f"infiniStatus_t infiniopGet{pascal_name}WorkspaceSize(infiniop{pascal_name}Descriptor_t desc, size_t *size)" + + def _generate_call_func_decl(operator): + call = operator.calls[-1] + params = _generate_params(call, call=True) + params = params.replace("void * stream, ", "") + + return f"infiniStatus_t infiniop{pascal_name}(infiniop{pascal_name}Descriptor_t desc, void *workspace, size_t workspace_size, {params}, void *stream)" + + def _generate_destroy_func_decl(operator): + return f"infiniStatus_t infiniopDestroy{pascal_name}Descriptor(infiniop{pascal_name}Descriptor_t desc)" + + def _generate_params(node, call=False): + arguments = tuple(node.get_arguments()) + arguments = (arguments[-1], *arguments[:-1]) + + def _handle_tensor(spelling): + if call: + return spelling.replace("Tensor", "void *") + + return spelling.replace("Tensor", "infiniopTensorDescriptor_t") + + def _handle_std_optional(spelling): + return spelling.replace("std::optional<", "").replace(">", "") + + return ", ".join( + f"{_handle_std_optional(_handle_tensor(arg.type.spelling))} {arg.spelling}" + for arg in arguments + ) + + def _generate_arguments(node, is_data=False): + return ", ".join( + _generate_tensor_caster(arg.spelling, is_data=is_data) + if "Tensor" in arg.type.spelling + else arg.spelling + for arg in node.get_arguments() + if arg.spelling != "handle" and arg.spelling != "stream" + ) + + def _generate_tensor_caster(name, is_data=False): + if is_data: + return f"infini::ops::Tensor(const_cast({name}), infini::ops::Tensor::Shape{{}})" + + return f"infini::ops::Tensor{{nullptr, {name}->shape(), DataTypeFromInfiniDType({name}->dtype()), infini::ops::Device{{DeviceTypeFromInfiniDevice(handle->device), handle->device_id}}, {name}->strides()}}" + + return _generate_source(operator), _generate_header(operator) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index dfc13979..679e5bab 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -5,6 +5,7 @@ from _operator_utils import OperatorExtractor, get_all_ops, snake_to_pascal from _generate_pybind11 import generate_pybind11 +from _generate_legacy_c import generate_legacy_c _GENERATION_DIR = pathlib.Path("generated") @@ -16,181 +17,6 @@ _INDENTATION = " " - -def _generate_legacy_c(operator, paths): - def _generate_source(operator): - impl_includes = "\n".join( - f'#include "{str(path).removeprefix("src/")}"' for path in paths - ) - - return f"""#include "../../handle.h" -#include "../../tensor.h" -#include "infiniop/ops/{operator.name.lower()}.h" -{impl_includes} - -static infini::ops::DataType DataTypeFromInfiniDType( - const infiniDtype_t& dtype) {{ - static constexpr infini::ops::ConstexprMap - kInfiniDTypeToDataType{{ - {{{{{{INFINI_DTYPE_I8, infini::ops::DataType::kInt8}}, - {{INFINI_DTYPE_I16, infini::ops::DataType::kInt16}}, - {{INFINI_DTYPE_I32, infini::ops::DataType::kInt32}}, - {{INFINI_DTYPE_I64, infini::ops::DataType::kInt64}}, - {{INFINI_DTYPE_U8, infini::ops::DataType::kUInt8}}, - {{INFINI_DTYPE_U16, infini::ops::DataType::kUInt16}}, - {{INFINI_DTYPE_U32, infini::ops::DataType::kUInt32}}, - {{INFINI_DTYPE_U64, infini::ops::DataType::kUInt64}}, - {{INFINI_DTYPE_F16, infini::ops::DataType::kFloat16}}, - {{INFINI_DTYPE_BF16, infini::ops::DataType::kBFloat16}}, - {{INFINI_DTYPE_F32, infini::ops::DataType::kFloat32}}, - {{INFINI_DTYPE_F64, infini::ops::DataType::kFloat64}}}}}}}}; - - return kInfiniDTypeToDataType.at(dtype); -}} - -static infini::ops::Device::Type DeviceTypeFromInfiniDevice( - const infiniDevice_t& device) {{ - static constexpr infini::ops::ConstexprMap< - infiniDevice_t, infini::ops::Device::Type, - static_cast(INFINI_DEVICE_TYPE_COUNT)> - kInfiniDeviceToDeviceType{{ - {{{{{{INFINI_DEVICE_CPU, infini::ops::Device::Type::kCpu}}, - {{INFINI_DEVICE_NVIDIA, infini::ops::Device::Type::kNvidia}}, - {{INFINI_DEVICE_CAMBRICON, infini::ops::Device::Type::kCambricon}}, - {{INFINI_DEVICE_ASCEND, infini::ops::Device::Type::kAscend}}, - {{INFINI_DEVICE_METAX, infini::ops::Device::Type::kMetax}}, - {{INFINI_DEVICE_MOORE, infini::ops::Device::Type::kMoore}}, - {{INFINI_DEVICE_ILUVATAR, infini::ops::Device::Type::kIluvatar}}, - {{INFINI_DEVICE_KUNLUN, infini::ops::Device::Type::kKunlun}}, - {{INFINI_DEVICE_HYGON, infini::ops::Device::Type::kHygon}}, - {{INFINI_DEVICE_QY, infini::ops::Device::Type::kQy}}}}}}}}; - - return kInfiniDeviceToDeviceType.at(device); -}} - -__C {_generate_create_func_def(operator)} - -__C {_generate_get_workspace_size_func_def(operator)} - -__C {_generate_call_func_def(operator)} - -__C {_generate_destroy_func_def(operator)} -""" - - def _generate_header(operator): - return f"""#ifndef __INFINIOP_{operator.name.upper()}_API_H__ -#define __INFINIOP_{operator.name.upper()}_API_H__ - -#include "base/{operator.name.lower()}.h" - -typedef struct infini::ops::Operator *infiniop{operator.name}Descriptor_t; - -__C __export {_generate_create_func_decl(operator)}; - -__C __export {_generate_get_workspace_size_func_decl(operator)}; - -__C __export {_generate_call_func_decl(operator)}; - -__C __export {_generate_destroy_func_decl(operator)}; - -#endif -""" - - def _generate_create_func_def(operator): - name = operator.name - constructor = operator.constructors[-1] - - return f"""{_generate_create_func_decl(operator)} {{ - *desc_ptr = infini::ops::Operator::make({_generate_arguments(constructor)}).release(); - - return INFINI_STATUS_SUCCESS; -}}""" - - def _generate_get_workspace_size_func_def(operator): - return f"""{_generate_get_workspace_size_func_decl(operator)} {{ - *size = 0; // desc->workspace_size(); - - return INFINI_STATUS_SUCCESS; -}}""" - - def _generate_call_func_def(operator): - call = operator.calls[-1] - - return f"""{_generate_call_func_decl(operator)} {{ - (*desc)(stream, {_generate_arguments(call, is_data=True)}); - - return INFINI_STATUS_SUCCESS; -}}""" - - def _generate_destroy_func_def(operator): - return f"""{_generate_destroy_func_decl(operator)} {{ - delete desc; - - return INFINI_STATUS_SUCCESS; -}}""" - - def _generate_create_func_decl(operator): - name = operator.name - constructor = operator.constructors[-1] - params = _generate_params(constructor) - - return f"infiniStatus_t infiniopCreate{name}Descriptor(infiniopHandle_t handle, infiniop{name}Descriptor_t *desc_ptr, {params})" - - def _generate_get_workspace_size_func_decl(operator): - name = operator.name - - return f"infiniStatus_t infiniopGet{name}WorkspaceSize(infiniop{name}Descriptor_t desc, size_t *size)" - - def _generate_call_func_decl(operator): - name = operator.name - call = operator.calls[-1] - params = _generate_params(call, call=True) - params = params.replace("void * stream, ", "") - - return f"infiniStatus_t infiniop{name}(infiniop{name}Descriptor_t desc, void *workspace, size_t workspace_size, {params}, void *stream)" - - def _generate_destroy_func_decl(operator): - name = operator.name - - return f"infiniStatus_t infiniopDestroy{name}Descriptor(infiniop{name}Descriptor_t desc)" - - def _generate_params(node, call=False): - arguments = tuple(node.get_arguments()) - - arguments = (arguments[-1], *arguments[:-1]) - - def _handle_tensor(spelling): - if call: - return spelling.replace("Tensor", "void *") - return spelling.replace("Tensor", "infiniopTensorDescriptor_t") - - def _handle_std_optional(spelling): - return spelling.replace("std::optional<", "").replace(">", "") - - return ", ".join( - f"{_handle_std_optional(_handle_tensor(arg.type.spelling))} {arg.spelling}" - for arg in arguments - ) - - def _generate_arguments(node, is_data=False): - return ", ".join( - _generate_tensor_caster(arg.spelling, is_data=is_data) - if "Tensor" in arg.type.spelling - else arg.spelling - for arg in node.get_arguments() - if arg.spelling != "handle" and arg.spelling != "stream" - ) - - def _generate_tensor_caster(name, is_data=False): - if is_data: - return f"infini::ops::Tensor(const_cast({name}), infini::ops::Tensor::Shape{{}})" - - return f"infini::ops::Tensor{{nullptr, {name}->shape(), DataTypeFromInfiniDType({name}->dtype()), infini::ops::Device{{DeviceTypeFromInfiniDevice(handle->device), handle->device_id}}, {name}->strides()}}" - - return _generate_source(operator), _generate_header(operator) - - if __name__ == "__main__": parser = argparse.ArgumentParser(description="An automatic wrapper generator.") @@ -228,7 +54,7 @@ def _generate_tensor_caster(name, is_data=False): (_BINDINGS_DIR / header_name).write_text(generate_pybind11(operator)) - legacy_c_source, legacy_c_header = _generate_legacy_c(operator, impl_paths) + legacy_c_source, legacy_c_header = generate_legacy_c(operator, impl_paths) source_path.mkdir(exist_ok=True) (_GENERATED_SRC_DIR / op_name / "operator.cc").write_text(legacy_c_source) (_INCLUDE_DIR / header_name).write_text(legacy_c_header) From 8a919485937e1880548830615756ee75ffa0bade Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 9 Apr 2026 05:00:50 +0000 Subject: [PATCH 04/16] fix: add `inline` to `IndexToOffset` in `generic_utils.h` Without `inline`, this function defined in a header causes multiple definition errors when included from multiple translation units. --- src/common/generic_utils.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/common/generic_utils.h b/src/common/generic_utils.h index 795f2fb7..b34ce8b0 100644 --- a/src/common/generic_utils.h +++ b/src/common/generic_utils.h @@ -5,9 +5,9 @@ namespace infini::ops::utils { -std::size_t IndexToOffset(std::size_t flat_index, std::size_t ndim, - const std::size_t* shape, - const std::ptrdiff_t* strides) { +inline std::size_t IndexToOffset(std::size_t flat_index, std::size_t ndim, + const std::size_t* shape, + const std::ptrdiff_t* strides) { std::size_t res = 0; for (std::size_t i = ndim; i-- > 0;) { res += (flat_index % shape[i]) * strides[i]; From 7b6d8ef23e520cff92d060cb107a93942608ff32 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 9 Apr 2026 05:01:03 +0000 Subject: [PATCH 05/16] fix: wrap device-specific includes with `#ifdef` guards in legacy C generator The generated `operator.cc` now guards device implementation includes (e.g. CUDA headers) with `ENABLE_*_API` preprocessor checks, matching InfiniCore's conditional compilation pattern. --- scripts/_generate_legacy_c.py | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/scripts/_generate_legacy_c.py b/scripts/_generate_legacy_c.py index 2397d5b7..c92837e1 100644 --- a/scripts/_generate_legacy_c.py +++ b/scripts/_generate_legacy_c.py @@ -4,10 +4,39 @@ def generate_legacy_c(operator, paths): pascal_name = snake_to_pascal(operator.name) + # Map InfiniOps device directory names to InfiniCore preprocessor guards. + _DEVICE_GUARDS = { + "cpu": "ENABLE_CPU_API", + "nvidia": "ENABLE_NVIDIA_API", + "cambricon": "ENABLE_CAMBRICON_API", + "ascend": "ENABLE_ASCEND_API", + "metax": "ENABLE_METAX_API", + "moore": "ENABLE_MOORE_API", + "iluvatar": "ENABLE_ILUVATAR_API", + "kunlun": "ENABLE_KUNLUN_API", + "hygon": "ENABLE_HYGON_API", + "qy": "ENABLE_QY_API", + } + + def _generate_guarded_includes(): + lines = [] + + for path in paths: + rel = str(path).removeprefix("src/") + device = rel.split("/")[0] + guard = _DEVICE_GUARDS.get(device) + + if guard: + lines.append(f"#ifdef {guard}") + lines.append(f'#include "{rel}"') + lines.append("#endif") + else: + lines.append(f'#include "{rel}"') + + return "\n".join(lines) + def _generate_source(operator): - impl_includes = "\n".join( - f'#include "{str(path).removeprefix("src/")}"' for path in paths - ) + impl_includes = _generate_guarded_includes() return f"""#include "../../handle.h" #include "../../tensor.h" From 20436f05d3aca2be2c97423be3cfcb87226266fa Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 9 Apr 2026 05:02:06 +0000 Subject: [PATCH 06/16] fix: add C API name overrides for `RMSNorm` and `SwiGLU` InfiniCore uses non-standard casing for some operators (e.g. `RMSNorm` instead of `RmsNorm`). Add override mapping to preserve compatibility. --- scripts/_generate_legacy_c.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/scripts/_generate_legacy_c.py b/scripts/_generate_legacy_c.py index c92837e1..8d074c6e 100644 --- a/scripts/_generate_legacy_c.py +++ b/scripts/_generate_legacy_c.py @@ -1,8 +1,17 @@ from _operator_utils import snake_to_pascal +# Override PascalCase names to match InfiniCore's existing C API conventions. +_C_API_NAME_OVERRIDES = { + "rms_norm": "RMSNorm", + "swiglu": "SwiGLU", +} + def generate_legacy_c(operator, paths): - pascal_name = snake_to_pascal(operator.name) + # The C++ class name from InfiniOps (e.g. `RmsNorm`, `Swiglu`). + cpp_name = snake_to_pascal(operator.name) + # The C API name, which may differ from the C++ class name. + pascal_name = _C_API_NAME_OVERRIDES.get(operator.name, cpp_name) # Map InfiniOps device directory names to InfiniCore preprocessor guards. _DEVICE_GUARDS = { @@ -116,7 +125,7 @@ def _generate_create_func_def(operator): constructor = operator.constructors[-1] return f"""{_generate_create_func_decl(operator)} {{ - *desc_ptr = reinterpret_cast(infini::ops::Operator::make({_generate_arguments(constructor)}).release()); + *desc_ptr = reinterpret_cast(infini::ops::Operator::make({_generate_arguments(constructor)}).release()); return INFINI_STATUS_SUCCESS; }}""" @@ -132,7 +141,7 @@ def _generate_call_func_def(operator): call = operator.calls[-1] return f"""{_generate_call_func_decl(operator)} {{ - auto *op = reinterpret_cast *>(desc); + auto *op = reinterpret_cast *>(desc); op->set_stream(stream); op->set_workspace(workspace); op->set_workspace_size_in_bytes(workspace_size); @@ -144,7 +153,7 @@ def _generate_call_func_def(operator): def _generate_destroy_func_def(operator): return f"""{_generate_destroy_func_decl(operator)} {{ - delete reinterpret_cast *>(desc); + delete reinterpret_cast *>(desc); return INFINI_STATUS_SUCCESS; }}""" From d2eb8bb26b14d1c9187c98a07d689c5c6275a4f2 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 9 Apr 2026 07:36:26 +0000 Subject: [PATCH 07/16] fix: add missing `#include` for `RuntimeUtils` in `cuda/causal_softmax/kernel.h` --- src/cuda/causal_softmax/kernel.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/cuda/causal_softmax/kernel.h b/src/cuda/causal_softmax/kernel.h index 7c7ac871..cffa0713 100644 --- a/src/cuda/causal_softmax/kernel.h +++ b/src/cuda/causal_softmax/kernel.h @@ -7,6 +7,7 @@ #include "base/causal_softmax.h" #include "cuda/causal_softmax/kernel.cuh" #include "cuda/kernel_commons.cuh" +#include "cuda/runtime_utils.h" #include "data_type.h" #include "dispatcher.h" From b163eb49249e6c0f1610c496cd5a17dde56960a4 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 9 Apr 2026 08:18:22 +0000 Subject: [PATCH 08/16] refactor: use stored member variables instead of call-time tensor metadata in `operator()` The `operator()` overloads should only use data pointers from call-time `Tensor` arguments, not metadata like `dtype()`. This ensures correct behavior when the C API bridge passes data-only tensors. - `Gemm`: use `a_type_`/`b_type_`/`c_type_` instead of `a.dtype()`/etc. in both cuBLAS and cuBLASLt implementations - `CausalSoftmax`: use `dtype_` instead of `out.dtype()` - `RmsNorm`: add `dtype_` member to base class, use it instead of `out.dtype()` in CUDA and CPU implementations --- src/base/rms_norm.h | 5 ++++- src/cpu/causal_softmax/causal_softmax.h | 2 +- src/cpu/gemm/gemm.h | 2 +- src/cpu/rms_norm/rms_norm.h | 2 +- src/cuda/causal_softmax/kernel.h | 5 +---- src/cuda/gemm/blas.h | 16 ++++++++-------- src/cuda/rms_norm/kernel.h | 4 +--- src/nvidia/gemm/cublaslt.h | 9 ++++----- 8 files changed, 21 insertions(+), 24 deletions(-) diff --git a/src/base/rms_norm.h b/src/base/rms_norm.h index dc28f0aa..33b5c9bd 100644 --- a/src/base/rms_norm.h +++ b/src/base/rms_norm.h @@ -12,7 +12,8 @@ namespace infini::ops { class RmsNorm : public Operator { public: RmsNorm(const Tensor input, const Tensor weight, float eps, Tensor out) - : input_shape_{input.shape()}, + : dtype_{out.dtype()}, + input_shape_{input.shape()}, out_shape_{out.shape()}, input_strides_{input.strides()}, out_strides_{out.strides()}, @@ -37,6 +38,8 @@ class RmsNorm : public Operator { } protected: + const DataType dtype_; + Tensor::Shape input_shape_; Tensor::Shape out_shape_; diff --git a/src/cpu/causal_softmax/causal_softmax.h b/src/cpu/causal_softmax/causal_softmax.h index 14848ee4..ee64ec6c 100644 --- a/src/cpu/causal_softmax/causal_softmax.h +++ b/src/cpu/causal_softmax/causal_softmax.h @@ -19,7 +19,7 @@ class Operator : public CausalSoftmax, void operator()(const Tensor input, Tensor out) const override { DispatchFunc( - out.dtype(), + dtype_, [&](auto tag) { using T = typename decltype(tag)::type; Compute(input, out); diff --git a/src/cpu/gemm/gemm.h b/src/cpu/gemm/gemm.h index a4dfb989..1b3ebb32 100644 --- a/src/cpu/gemm/gemm.h +++ b/src/cpu/gemm/gemm.h @@ -32,7 +32,7 @@ class Operator : public Gemm, std::optional beta, std::optional trans_a, std::optional trans_b, Tensor c) const override { DispatchFunc( - c.dtype(), + c_type_, [&](auto tag) { using T = typename decltype(tag)::type; Compute(a, b, alpha, beta, trans_a, trans_b, c); diff --git a/src/cpu/rms_norm/rms_norm.h b/src/cpu/rms_norm/rms_norm.h index 9cae419e..c7a1a484 100644 --- a/src/cpu/rms_norm/rms_norm.h +++ b/src/cpu/rms_norm/rms_norm.h @@ -20,7 +20,7 @@ class Operator : public RmsNorm, void operator()(const Tensor input, const Tensor weight, float eps, Tensor out) const override { DispatchFunc( - out.dtype(), + dtype_, [&](auto tag) { using T = typename decltype(tag)::type; Compute(input, weight, eps, out); diff --git a/src/cuda/causal_softmax/kernel.h b/src/cuda/causal_softmax/kernel.h index cffa0713..dd3df98b 100644 --- a/src/cuda/causal_softmax/kernel.h +++ b/src/cuda/causal_softmax/kernel.h @@ -30,14 +30,11 @@ class CudaCausalSoftmax : public CausalSoftmax { dim3 grid(static_cast(seq_len_), static_cast(batch_size_)); - assert(out.dtype() == input.dtype()); - int block_size = RuntimeUtils::GetOptimalBlockSize(); DispatchFunc, ReducedFloatTypes>, AllCudaBlockSizes>( - // TODO: Output dtype should use the one passed in during construction. - {static_cast(out.dtype()), block_size}, + {static_cast(dtype_), block_size}, [&](auto list_tag) { using T = TypeMapType(list_tag)>; constexpr int kBlockSize = ListGet<1>(list_tag); diff --git a/src/cuda/gemm/blas.h b/src/cuda/gemm/blas.h index 264d52c1..a53ab63e 100644 --- a/src/cuda/gemm/blas.h +++ b/src/cuda/gemm/blas.h @@ -42,25 +42,25 @@ class BlasGemm : public Gemm { const auto& trans_b_value{trans_b.value_or(trans_b_)}; auto op_a{GetOpA(trans_a_value, trans_b_value)}; auto op_b{GetOpB(trans_a_value, trans_b_value)}; - const void* alpha_ptr{GetAlphaPtr(alpha_value, c.dtype())}; - const void* beta_ptr{GetBetaPtr(beta_value, c.dtype())}; + const void* alpha_ptr{GetAlphaPtr(alpha_value, c_type_)}; + const void* beta_ptr{GetBetaPtr(beta_value, c_type_)}; Backend::BlasGemmStridedBatchedEx( GetHandle(), op_a, op_b, swap_a_and_b_ ? n_ : m_, swap_a_and_b_ ? m_ : n_, k_, alpha_ptr, swap_a_and_b_ ? b.data() : a.data(), - BlasUtils::GetDataType(swap_a_and_b_ ? b.dtype() - : a.dtype()), + BlasUtils::GetDataType(swap_a_and_b_ ? b_type_ + : a_type_), swap_a_and_b_ ? ldb_ : lda_, swap_a_and_b_ ? batch_stride_b_ : batch_stride_a_, swap_a_and_b_ ? a.data() : b.data(), - BlasUtils::GetDataType(swap_a_and_b_ ? a.dtype() - : b.dtype()), + BlasUtils::GetDataType(swap_a_and_b_ ? a_type_ + : b_type_), swap_a_and_b_ ? lda_ : ldb_, swap_a_and_b_ ? batch_stride_a_ : batch_stride_b_, beta_ptr, c.data(), - BlasUtils::GetDataType(c.dtype()), ldc_, + BlasUtils::GetDataType(c_type_), ldc_, batch_stride_c_, batch_count_, - BlasUtils::GetComputeType(c.dtype()), + BlasUtils::GetComputeType(c_type_), Backend::BLAS_GEMM_DEFAULT); } diff --git a/src/cuda/rms_norm/kernel.h b/src/cuda/rms_norm/kernel.h index 14146edc..62a58dd2 100644 --- a/src/cuda/rms_norm/kernel.h +++ b/src/cuda/rms_norm/kernel.h @@ -32,13 +32,11 @@ class CudaRmsNorm : public RmsNorm { uint32_t num_blocks = static_cast(batch_size_ * nhead_); - assert(out.dtype() == input.dtype() && out.dtype() == weight.dtype()); - int block_size = RuntimeUtils::GetOptimalBlockSize(); DispatchFunc, ReducedFloatTypes>, AllCudaBlockSizes>( - {static_cast(out.dtype()), block_size}, + {static_cast(dtype_), block_size}, [&](auto list_tag) { using T = TypeMapType(list_tag)>; constexpr int kBlockSize = ListGet<1>(list_tag); diff --git a/src/nvidia/gemm/cublaslt.h b/src/nvidia/gemm/cublaslt.h index 38de8507..9e1d3610 100644 --- a/src/nvidia/gemm/cublaslt.h +++ b/src/nvidia/gemm/cublaslt.h @@ -56,11 +56,10 @@ class Operator : public Gemm { const auto* a_ptr{swap_a_and_b_ ? b.data() : a.data()}; const auto* b_ptr{swap_a_and_b_ ? a.data() : b.data()}; const auto a_dtype{BlasUtils::GetDataType( - swap_a_and_b_ ? b.dtype() : a.dtype())}; + swap_a_and_b_ ? b_type_ : a_type_)}; const auto b_dtype{BlasUtils::GetDataType( - swap_a_and_b_ ? a.dtype() : b.dtype())}; - const auto c_dtype{ - BlasUtils::GetDataType(c.dtype())}; + swap_a_and_b_ ? a_type_ : b_type_)}; + const auto c_dtype{BlasUtils::GetDataType(c_type_)}; const auto a_ld{static_cast(swap_a_and_b_ ? ldb_ : lda_)}; const auto b_ld{static_cast(swap_a_and_b_ ? lda_ : ldb_)}; const auto c_ld{static_cast(ldc_)}; @@ -72,7 +71,7 @@ class Operator : public Gemm { cublasLtMatmulDesc_t op_desc{}; auto status = cublasLtMatmulDescCreate( - &op_desc, BlasUtils::GetComputeType(c.dtype()), + &op_desc, BlasUtils::GetComputeType(c_type_), CUDA_R_32F); assert(status == CUBLAS_STATUS_SUCCESS && "failed to create cuBLASLt matmul descriptor"); From 9dc2bc07a323915fe3bf9179e26d949d63018c65 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 9 Apr 2026 08:43:07 +0000 Subject: [PATCH 09/16] refactor: use explicit index mappings for constructor/call selection Replace `constructors[-1]`/`calls[-1]` with configurable index overrides (`_CONSTRUCTOR_INDEX_OVERRIDES`, `_CALL_INDEX_OVERRIDES`) so each operator can specify which constructor and `operator()` overload to use for the generated C API. For example, `RmsNorm` uses constructor index `0` to include the `eps` parameter. --- scripts/_generate_legacy_c.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/scripts/_generate_legacy_c.py b/scripts/_generate_legacy_c.py index 8d074c6e..c1eaeddb 100644 --- a/scripts/_generate_legacy_c.py +++ b/scripts/_generate_legacy_c.py @@ -6,12 +6,22 @@ "swiglu": "SwiGLU", } +# Override which constructor/call overload index to use per operator. +# Default is -1 (the last one, typically the simplest). +_CONSTRUCTOR_INDEX_OVERRIDES = { + "rms_norm": 0, +} + +_CALL_INDEX_OVERRIDES = {} + def generate_legacy_c(operator, paths): # The C++ class name from InfiniOps (e.g. `RmsNorm`, `Swiglu`). cpp_name = snake_to_pascal(operator.name) # The C API name, which may differ from the C++ class name. pascal_name = _C_API_NAME_OVERRIDES.get(operator.name, cpp_name) + constructor_index = _CONSTRUCTOR_INDEX_OVERRIDES.get(operator.name, -1) + call_index = _CALL_INDEX_OVERRIDES.get(operator.name, -1) # Map InfiniOps device directory names to InfiniCore preprocessor guards. _DEVICE_GUARDS = { @@ -122,7 +132,7 @@ def _generate_header(operator): """ def _generate_create_func_def(operator): - constructor = operator.constructors[-1] + constructor = operator.constructors[constructor_index] return f"""{_generate_create_func_decl(operator)} {{ *desc_ptr = reinterpret_cast(infini::ops::Operator::make({_generate_arguments(constructor)}).release()); @@ -138,7 +148,7 @@ def _generate_get_workspace_size_func_def(operator): }}""" def _generate_call_func_def(operator): - call = operator.calls[-1] + call = operator.calls[call_index] return f"""{_generate_call_func_decl(operator)} {{ auto *op = reinterpret_cast *>(desc); @@ -159,7 +169,7 @@ def _generate_destroy_func_def(operator): }}""" def _generate_create_func_decl(operator): - constructor = operator.constructors[-1] + constructor = operator.constructors[constructor_index] params = _generate_params(constructor) return f"infiniStatus_t infiniopCreate{pascal_name}Descriptor(infiniopHandle_t handle, infiniop{pascal_name}Descriptor_t *desc_ptr, {params})" @@ -168,7 +178,7 @@ def _generate_get_workspace_size_func_decl(operator): return f"infiniStatus_t infiniopGet{pascal_name}WorkspaceSize(infiniop{pascal_name}Descriptor_t desc, size_t *size)" def _generate_call_func_decl(operator): - call = operator.calls[-1] + call = operator.calls[call_index] params = _generate_params(call, call=True) params = params.replace("void * stream, ", "") From 2024c59261d38a66833aafe760d5d4f8e2cd4ed6 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 10 Apr 2026 07:05:26 +0000 Subject: [PATCH 10/16] feat: add shared library factory function generator Add `_generate_shared_lib.py` that generates non-template `Make{Op}` factory functions wrapping `Operator::make()`. These are compiled into `libinfiniops.so`, allowing consumers to construct operators without needing device-specific headers (e.g., CUDA). Generated files: - `generated/src/{op}/make.cc`: factory function definitions - `generated/include/make.h`: combined header with declarations --- scripts/_generate_shared_lib.py | 77 +++++++++++++++++++++++++++++++++ scripts/generate_wrappers.py | 8 ++++ src/CMakeLists.txt | 8 ++++ 3 files changed, 93 insertions(+) create mode 100644 scripts/_generate_shared_lib.py diff --git a/scripts/_generate_shared_lib.py b/scripts/_generate_shared_lib.py new file mode 100644 index 00000000..9bdca1d7 --- /dev/null +++ b/scripts/_generate_shared_lib.py @@ -0,0 +1,77 @@ +from _operator_utils import snake_to_pascal + + +def generate_shared_lib(operator, paths): + cpp_name = snake_to_pascal(operator.name) + + def _generate_impl_includes(): + return "\n".join( + f'#include "{str(path).removeprefix("src/")}"' for path in paths + ) + + def _generate_params(node): + return ", ".join( + f"{arg.type.spelling} {arg.spelling}" for arg in node.get_arguments() + ) + + def _generate_arguments(node): + return ", ".join(arg.spelling for arg in node.get_arguments()) + + def _generate_make_decl(constructor): + params = _generate_params(constructor) + if params: + params = f"const Config& config, {params}" + else: + params = "const Config& config" + return f"std::unique_ptr Make{cpp_name}({params})" + + def _generate_make_def(constructor): + args = _generate_arguments(constructor) + make_args = f"config, {args}" if args else "config" + return f"""{_generate_make_decl(constructor)} {{ + return Operator<{cpp_name}>::make({make_args}); +}}""" + + impl_includes = _generate_impl_includes() + + make_defs = "\n\n".join( + _generate_make_def(c) for c in operator.constructors + ) + + source = f"""#include "base/{operator.name}.h" +{impl_includes} + +namespace infini::ops {{ + +{make_defs} + +}} // namespace infini::ops +""" + + make_decls = "\n\n".join( + f"{_generate_make_decl(c)};" for c in operator.constructors + ) + + return source, make_decls + + +def generate_shared_lib_header(all_decls): + combined = "\n\n".join(all_decls) + return f"""#ifndef INFINI_OPS_MAKE_H_ +#define INFINI_OPS_MAKE_H_ + +#include +#include + +#include "config.h" +#include "operator.h" +#include "tensor.h" + +namespace infini::ops {{ + +{combined} + +}} // namespace infini::ops + +#endif +""" diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 679e5bab..f2d4ea43 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -6,6 +6,7 @@ from _operator_utils import OperatorExtractor, get_all_ops, snake_to_pascal from _generate_pybind11 import generate_pybind11 from _generate_legacy_c import generate_legacy_c +from _generate_shared_lib import generate_shared_lib, generate_shared_lib_header _GENERATION_DIR = pathlib.Path("generated") @@ -43,6 +44,7 @@ header_paths = [] bind_func_names = [] + shared_lib_decls = [] for op_name, impl_paths in ops.items(): extractor = OperatorExtractor() @@ -59,9 +61,15 @@ (_GENERATED_SRC_DIR / op_name / "operator.cc").write_text(legacy_c_source) (_INCLUDE_DIR / header_name).write_text(legacy_c_header) + sl_source, sl_decls = generate_shared_lib(operator, impl_paths) + (_GENERATED_SRC_DIR / op_name / "make.cc").write_text(sl_source) + shared_lib_decls.append(sl_decls) + header_paths.append(header_name) bind_func_names.append(bind_func_name) + (_INCLUDE_DIR / "make.h").write_text(generate_shared_lib_header(shared_lib_decls)) + impl_includes = "\n".join( f'#include "{impl_path}"' for impl_paths in ops.values() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0b56341b..a3f8c322 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -174,6 +174,14 @@ endif() target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +file(GLOB FACTORY_SOURCES "${PROJECT_SOURCE_DIR}/generated/src/*/make.cc") +if(FACTORY_SOURCES) + if(WITH_NVIDIA OR WITH_ILUVATAR) + set_source_files_properties(${FACTORY_SOURCES} PROPERTIES LANGUAGE CUDA) + endif() + target_sources(infiniops PRIVATE ${FACTORY_SOURCES}) +endif() + if(GENERATE_PYTHON_BINDINGS) find_package(Python COMPONENTS Interpreter REQUIRED) execute_process( From 66db32ffa892abb6980395e07f13504446ad81b2 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 10 Apr 2026 07:54:04 +0000 Subject: [PATCH 11/16] refactor: use factory functions in legacy C generator Update the legacy C generator to call `Make{Op}` factory functions from `libinfiniops.so` instead of `Operator::make()`. The generated wrappers no longer include device-specific headers and can be compiled with g++ (no CUDA needed). --- scripts/_generate_legacy_c.py | 45 +++++------------------------------ 1 file changed, 6 insertions(+), 39 deletions(-) diff --git a/scripts/_generate_legacy_c.py b/scripts/_generate_legacy_c.py index c1eaeddb..33279110 100644 --- a/scripts/_generate_legacy_c.py +++ b/scripts/_generate_legacy_c.py @@ -23,44 +23,12 @@ def generate_legacy_c(operator, paths): constructor_index = _CONSTRUCTOR_INDEX_OVERRIDES.get(operator.name, -1) call_index = _CALL_INDEX_OVERRIDES.get(operator.name, -1) - # Map InfiniOps device directory names to InfiniCore preprocessor guards. - _DEVICE_GUARDS = { - "cpu": "ENABLE_CPU_API", - "nvidia": "ENABLE_NVIDIA_API", - "cambricon": "ENABLE_CAMBRICON_API", - "ascend": "ENABLE_ASCEND_API", - "metax": "ENABLE_METAX_API", - "moore": "ENABLE_MOORE_API", - "iluvatar": "ENABLE_ILUVATAR_API", - "kunlun": "ENABLE_KUNLUN_API", - "hygon": "ENABLE_HYGON_API", - "qy": "ENABLE_QY_API", - } - - def _generate_guarded_includes(): - lines = [] - - for path in paths: - rel = str(path).removeprefix("src/") - device = rel.split("/")[0] - guard = _DEVICE_GUARDS.get(device) - - if guard: - lines.append(f"#ifdef {guard}") - lines.append(f'#include "{rel}"') - lines.append("#endif") - else: - lines.append(f'#include "{rel}"') - - return "\n".join(lines) - def _generate_source(operator): - impl_includes = _generate_guarded_includes() - return f"""#include "../../handle.h" #include "../../tensor.h" #include "infiniop/ops/{operator.name}.h" -{impl_includes} +#include "base/{operator.name}.h" +#include "make.h" static infini::ops::DataType DataTypeFromInfiniDType( const infiniDtype_t& dtype) {{ @@ -135,7 +103,7 @@ def _generate_create_func_def(operator): constructor = operator.constructors[constructor_index] return f"""{_generate_create_func_decl(operator)} {{ - *desc_ptr = reinterpret_cast(infini::ops::Operator::make({_generate_arguments(constructor)}).release()); + *desc_ptr = reinterpret_cast(infini::ops::Make{cpp_name}({{}}, {_generate_arguments(constructor)}).release()); return INFINI_STATUS_SUCCESS; }}""" @@ -151,19 +119,18 @@ def _generate_call_func_def(operator): call = operator.calls[call_index] return f"""{_generate_call_func_decl(operator)} {{ - auto *op = reinterpret_cast *>(desc); + auto *op = reinterpret_cast(desc); op->set_stream(stream); op->set_workspace(workspace); op->set_workspace_size_in_bytes(workspace_size); - const auto &op_ref = *op; - op_ref({_generate_arguments(call, is_data=True)}); + static_cast(*op)({_generate_arguments(call, is_data=True)}); return INFINI_STATUS_SUCCESS; }}""" def _generate_destroy_func_def(operator): return f"""{_generate_destroy_func_decl(operator)} {{ - delete reinterpret_cast *>(desc); + delete reinterpret_cast(desc); return INFINI_STATUS_SUCCESS; }}""" From dd691862ea2c521ad394ad37bf05428d6040cc88 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 10 Apr 2026 09:28:52 +0000 Subject: [PATCH 12/16] style: use Markdown backtick-fencing for identifiers in comments --- scripts/_generate_legacy_c.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/_generate_legacy_c.py b/scripts/_generate_legacy_c.py index 33279110..4001169d 100644 --- a/scripts/_generate_legacy_c.py +++ b/scripts/_generate_legacy_c.py @@ -1,13 +1,13 @@ from _operator_utils import snake_to_pascal -# Override PascalCase names to match InfiniCore's existing C API conventions. +# Override `PascalCase` names to match InfiniCore's existing C API conventions. _C_API_NAME_OVERRIDES = { "rms_norm": "RMSNorm", "swiglu": "SwiGLU", } # Override which constructor/call overload index to use per operator. -# Default is -1 (the last one, typically the simplest). +# Default is `-1` (the last one, typically the simplest). _CONSTRUCTOR_INDEX_OVERRIDES = { "rms_norm": 0, } From 7d481eadfc00419df1663b52f6e9d72a470120a3 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 10 Apr 2026 09:32:13 +0000 Subject: [PATCH 13/16] style: apply `ruff format` to `_generate_shared_lib.py` --- scripts/_generate_shared_lib.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/scripts/_generate_shared_lib.py b/scripts/_generate_shared_lib.py index 9bdca1d7..4c528d44 100644 --- a/scripts/_generate_shared_lib.py +++ b/scripts/_generate_shared_lib.py @@ -34,9 +34,7 @@ def _generate_make_def(constructor): impl_includes = _generate_impl_includes() - make_defs = "\n\n".join( - _generate_make_def(c) for c in operator.constructors - ) + make_defs = "\n\n".join(_generate_make_def(c) for c in operator.constructors) source = f"""#include "base/{operator.name}.h" {impl_includes} From ed784a76b0ac94009320a2f90b2c392e89e18073 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 10 Apr 2026 09:45:16 +0000 Subject: [PATCH 14/16] style: add blank lines around control flow and before returns --- scripts/_generate_shared_lib.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scripts/_generate_shared_lib.py b/scripts/_generate_shared_lib.py index 4c528d44..d951acf7 100644 --- a/scripts/_generate_shared_lib.py +++ b/scripts/_generate_shared_lib.py @@ -19,15 +19,18 @@ def _generate_arguments(node): def _generate_make_decl(constructor): params = _generate_params(constructor) + if params: params = f"const Config& config, {params}" else: params = "const Config& config" + return f"std::unique_ptr Make{cpp_name}({params})" def _generate_make_def(constructor): args = _generate_arguments(constructor) make_args = f"config, {args}" if args else "config" + return f"""{_generate_make_decl(constructor)} {{ return Operator<{cpp_name}>::make({make_args}); }}""" @@ -55,6 +58,7 @@ def _generate_make_def(constructor): def generate_shared_lib_header(all_decls): combined = "\n\n".join(all_decls) + return f"""#ifndef INFINI_OPS_MAKE_H_ #define INFINI_OPS_MAKE_H_ From 96043ba8d2abb72d96fe94dcdcf92df65c84d9fa Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 10 Apr 2026 09:51:03 +0000 Subject: [PATCH 15/16] style: reorder `dtype_` member to match `Tensor` field order --- src/base/rms_norm.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/base/rms_norm.h b/src/base/rms_norm.h index 33b5c9bd..ed087863 100644 --- a/src/base/rms_norm.h +++ b/src/base/rms_norm.h @@ -12,9 +12,9 @@ namespace infini::ops { class RmsNorm : public Operator { public: RmsNorm(const Tensor input, const Tensor weight, float eps, Tensor out) - : dtype_{out.dtype()}, - input_shape_{input.shape()}, + : input_shape_{input.shape()}, out_shape_{out.shape()}, + dtype_{out.dtype()}, input_strides_{input.strides()}, out_strides_{out.strides()}, eps_{eps}, @@ -38,12 +38,12 @@ class RmsNorm : public Operator { } protected: - const DataType dtype_; - Tensor::Shape input_shape_; Tensor::Shape out_shape_; + const DataType dtype_; + Tensor::Strides input_strides_; Tensor::Strides out_strides_; From a5127a01f731a24e1b6ec29fdb707b7f73b8e8fc Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Thu, 16 Apr 2026 07:21:07 +0000 Subject: [PATCH 16/16] feat: add CUDA architecture auto-detection for NVIDIA builds Automatically detect the GPU's compute capability when `CMAKE_CUDA_ARCHITECTURES` is not explicitly set. Uses CMake 3.24+ `"native"` mode when available; falls back to querying `nvidia-smi` on older CMake versions, defaulting to `sm_80` if detection fails. --- CMakeLists.txt | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index b9e2deb5..1204fcfe 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -89,6 +89,27 @@ endif() if(WITH_NVIDIA) add_compile_definitions(WITH_NVIDIA=1) + if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") + set(CMAKE_CUDA_ARCHITECTURES "native") + else() + # Detect GPU architecture via `nvidia-smi`. + execute_process( + COMMAND nvidia-smi --query-gpu=compute_cap --format=csv,noheader + OUTPUT_VARIABLE _gpu_cc OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET RESULT_VARIABLE _nvsmi_result + ) + if(_nvsmi_result EQUAL 0 AND _gpu_cc) + string(REGEX MATCH "^[0-9]+\\.[0-9]+" _first_cc "${_gpu_cc}") + string(REPLACE "." "" _arch "${_first_cc}") + set(CMAKE_CUDA_ARCHITECTURES "${_arch}") + message(STATUS "Auto-detected CUDA architecture: `sm_${_arch}`.") + else() + message(WARNING "Could not detect GPU architecture; defaulting to `sm_80`.") + set(CMAKE_CUDA_ARCHITECTURES "80") + endif() + endif() + endif() enable_language(CUDA) find_package(CUDAToolkit REQUIRED) endif()