diff --git a/CMakeLists.txt b/CMakeLists.txt
index b9e2deb5..1204fcfe 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -89,6 +89,27 @@ endif()
 
 if(WITH_NVIDIA)
     add_compile_definitions(WITH_NVIDIA=1)
+    if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+        if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
+            set(CMAKE_CUDA_ARCHITECTURES "native")
+        else()
+            # Detect GPU architecture via `nvidia-smi`.
+            execute_process(
+                COMMAND nvidia-smi --query-gpu=compute_cap --format=csv,noheader
+                OUTPUT_VARIABLE _gpu_cc OUTPUT_STRIP_TRAILING_WHITESPACE
+                ERROR_QUIET RESULT_VARIABLE _nvsmi_result
+            )
+            if(_nvsmi_result EQUAL 0 AND _gpu_cc)
+                string(REGEX MATCH "^[0-9]+\\.[0-9]+" _first_cc "${_gpu_cc}")
+                string(REPLACE "." "" _arch "${_first_cc}")
+                set(CMAKE_CUDA_ARCHITECTURES "${_arch}")
+                message(STATUS "Auto-detected CUDA architecture: `sm_${_arch}`.")
+            else()
+                message(WARNING "Could not detect GPU architecture; defaulting to `sm_80`.")
+                set(CMAKE_CUDA_ARCHITECTURES "80")
+            endif()
+        endif()
+    endif()
     enable_language(CUDA)
     find_package(CUDAToolkit REQUIRED)
 endif()
diff --git a/scripts/_generate_legacy_c.py b/scripts/_generate_legacy_c.py
new file mode 100644
index 00000000..4001169d
--- /dev/null
+++ b/scripts/_generate_legacy_c.py
@@ -0,0 +1,190 @@
+from _operator_utils import snake_to_pascal
+
+# Override `PascalCase` names to match InfiniCore's existing C API conventions.
+_C_API_NAME_OVERRIDES = {
+    "rms_norm": "RMSNorm",
+    "swiglu": "SwiGLU",
+}
+
+# Override which constructor/call overload index to use per operator.
+# Default is `-1` (the last one, typically the simplest).
+_CONSTRUCTOR_INDEX_OVERRIDES = {
+    "rms_norm": 0,
+}
+
+_CALL_INDEX_OVERRIDES = {}
+
+
+def generate_legacy_c(operator, paths):
+    # The C++ class name from InfiniOps (e.g. `RmsNorm`, `Swiglu`).
+    cpp_name = snake_to_pascal(operator.name)
+    # The C API name, which may differ from the C++ class name.
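+    # Illustrative note: with the overrides above, `rms_norm` keeps the C++
+    # class name `RmsNorm` but emits C symbols such as
+    # `infiniopCreateRMSNormDescriptor`, whereas an operator without an
+    # override would emit `infiniopCreateRmsNormDescriptor` instead.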
+    pascal_name = _C_API_NAME_OVERRIDES.get(operator.name, cpp_name)
+    constructor_index = _CONSTRUCTOR_INDEX_OVERRIDES.get(operator.name, -1)
+    call_index = _CALL_INDEX_OVERRIDES.get(operator.name, -1)
+
+    def _generate_source(operator):
+        return f"""#include "../../handle.h"
+#include "../../tensor.h"
+#include "infiniop/ops/{operator.name}.h"
+#include "base/{operator.name}.h"
+#include "make.h"
+
+static infini::ops::DataType DataTypeFromInfiniDType(
+    const infiniDtype_t& dtype) {{
+  static constexpr infini::ops::ConstexprMap<infiniDtype_t,
+                                             infini::ops::DataType, 12>
+      kInfiniDTypeToDataType{{
+          {{{{{{INFINI_DTYPE_I8, infini::ops::DataType::kInt8}},
+            {{INFINI_DTYPE_I16, infini::ops::DataType::kInt16}},
+            {{INFINI_DTYPE_I32, infini::ops::DataType::kInt32}},
+            {{INFINI_DTYPE_I64, infini::ops::DataType::kInt64}},
+            {{INFINI_DTYPE_U8, infini::ops::DataType::kUInt8}},
+            {{INFINI_DTYPE_U16, infini::ops::DataType::kUInt16}},
+            {{INFINI_DTYPE_U32, infini::ops::DataType::kUInt32}},
+            {{INFINI_DTYPE_U64, infini::ops::DataType::kUInt64}},
+            {{INFINI_DTYPE_F16, infini::ops::DataType::kFloat16}},
+            {{INFINI_DTYPE_BF16, infini::ops::DataType::kBFloat16}},
+            {{INFINI_DTYPE_F32, infini::ops::DataType::kFloat32}},
+            {{INFINI_DTYPE_F64, infini::ops::DataType::kFloat64}}}}}}}};
+
+  return kInfiniDTypeToDataType.at(dtype);
+}}
+
+static infini::ops::Device::Type DeviceTypeFromInfiniDevice(
+    const infiniDevice_t& device) {{
+  static constexpr infini::ops::ConstexprMap<
+      infiniDevice_t, infini::ops::Device::Type,
+      static_cast<std::size_t>(INFINI_DEVICE_TYPE_COUNT)>
+      kInfiniDeviceToDeviceType{{
+          {{{{{{INFINI_DEVICE_CPU, infini::ops::Device::Type::kCpu}},
+            {{INFINI_DEVICE_NVIDIA, infini::ops::Device::Type::kNvidia}},
+            {{INFINI_DEVICE_CAMBRICON, infini::ops::Device::Type::kCambricon}},
+            {{INFINI_DEVICE_ASCEND, infini::ops::Device::Type::kAscend}},
+            {{INFINI_DEVICE_METAX, infini::ops::Device::Type::kMetax}},
+            {{INFINI_DEVICE_MOORE, infini::ops::Device::Type::kMoore}},
+            {{INFINI_DEVICE_ILUVATAR, infini::ops::Device::Type::kIluvatar}},
+            {{INFINI_DEVICE_KUNLUN, infini::ops::Device::Type::kKunlun}},
+            {{INFINI_DEVICE_HYGON, infini::ops::Device::Type::kHygon}},
+            {{INFINI_DEVICE_QY, infini::ops::Device::Type::kQy}}}}}}}};
+
+  return kInfiniDeviceToDeviceType.at(device);
+}}
+
+__INFINI_C {_generate_create_func_def(operator)}
+
+__INFINI_C {_generate_get_workspace_size_func_def(operator)}
+
+__INFINI_C {_generate_call_func_def(operator)}
+
+__INFINI_C {_generate_destroy_func_def(operator)}
+"""
+
+    def _generate_header(operator):
+        return f"""#ifndef __INFINIOP_{operator.name.upper()}_API_H__
+#define __INFINIOP_{operator.name.upper()}_API_H__
+
+#include "../operator_descriptor.h"
+
+typedef struct InfiniopDescriptor *infiniop{pascal_name}Descriptor_t;
+
+__INFINI_C __export {_generate_create_func_decl(operator)};
+
+__INFINI_C __export {_generate_get_workspace_size_func_decl(operator)};
+
+__INFINI_C __export {_generate_call_func_decl(operator)};
+
+__INFINI_C __export {_generate_destroy_func_decl(operator)};
+
+#endif
+"""
+
+    def _generate_create_func_def(operator):
+        constructor = operator.constructors[constructor_index]
+
+        return f"""{_generate_create_func_decl(operator)} {{
+    *desc_ptr = reinterpret_cast<infiniop{pascal_name}Descriptor_t>(infini::ops::Make{cpp_name}({{}}, {_generate_arguments(constructor)}).release());
+
+    return INFINI_STATUS_SUCCESS;
+}}"""
+
+    def _generate_get_workspace_size_func_def(operator):
+        return f"""{_generate_get_workspace_size_func_decl(operator)} {{
+    *size = 0; // desc->workspace_size();
+
+    return INFINI_STATUS_SUCCESS;
+}}"""
+
+    def _generate_call_func_def(operator):
+        call = operator.calls[call_index]
+
+        return f"""{_generate_call_func_decl(operator)} {{
+    auto *op = reinterpret_cast<infini::ops::{cpp_name}*>(desc);
+    op->set_stream(stream);
+    op->set_workspace(workspace);
+    op->set_workspace_size_in_bytes(workspace_size);
+    static_cast<const infini::ops::{cpp_name}&>(*op)({_generate_arguments(call, is_data=True)});
+
+    return INFINI_STATUS_SUCCESS;
+}}"""
+
+    def _generate_destroy_func_def(operator):
+        return f"""{_generate_destroy_func_decl(operator)} {{
+    delete reinterpret_cast<infini::ops::{cpp_name}*>(desc);
+
+    return INFINI_STATUS_SUCCESS;
+}}"""
+
+    def _generate_create_func_decl(operator):
+        constructor = operator.constructors[constructor_index]
+        params = _generate_params(constructor)
+
+        return f"infiniStatus_t infiniopCreate{pascal_name}Descriptor(infiniopHandle_t handle, infiniop{pascal_name}Descriptor_t *desc_ptr, {params})"
+
+    def _generate_get_workspace_size_func_decl(operator):
+        return f"infiniStatus_t infiniopGet{pascal_name}WorkspaceSize(infiniop{pascal_name}Descriptor_t desc, size_t *size)"
+
+    def _generate_call_func_decl(operator):
+        call = operator.calls[call_index]
+        params = _generate_params(call, call=True)
+        params = params.replace("void * stream, ", "")
+
+        return f"infiniStatus_t infiniop{pascal_name}(infiniop{pascal_name}Descriptor_t desc, void *workspace, size_t workspace_size, {params}, void *stream)"
+
+    def _generate_destroy_func_decl(operator):
+        return f"infiniStatus_t infiniopDestroy{pascal_name}Descriptor(infiniop{pascal_name}Descriptor_t desc)"
+
+    def _generate_params(node, call=False):
+        arguments = tuple(node.get_arguments())
+        arguments = (arguments[-1], *arguments[:-1])
+
+        def _handle_tensor(spelling):
+            if call:
+                return spelling.replace("Tensor", "void *")
+
+            return spelling.replace("Tensor", "infiniopTensorDescriptor_t")
+
+        def _handle_std_optional(spelling):
+            return spelling.replace("std::optional<", "").replace(">", "")
+
+        return ", ".join(
+            f"{_handle_std_optional(_handle_tensor(arg.type.spelling))} {arg.spelling}"
+            for arg in arguments
+        )
+
+    def _generate_arguments(node, is_data=False):
+        return ", ".join(
+            _generate_tensor_caster(arg.spelling, is_data=is_data)
+            if "Tensor" in arg.type.spelling
+            else arg.spelling
+            for arg in node.get_arguments()
+            if arg.spelling != "handle" and arg.spelling != "stream"
+        )
+
+    def _generate_tensor_caster(name, is_data=False):
+        if is_data:
+            return f"infini::ops::Tensor(const_cast<void*>({name}), infini::ops::Tensor::Shape{{}})"
+
+        return f"infini::ops::Tensor{{nullptr, {name}->shape(), DataTypeFromInfiniDType({name}->dtype()), infini::ops::Device{{DeviceTypeFromInfiniDevice(handle->device), handle->device_id}}, {name}->strides()}}"
+
+    return _generate_source(operator), _generate_header(operator)
diff --git a/scripts/_generate_pybind11.py b/scripts/_generate_pybind11.py
new file mode 100644
index 00000000..89a902d2
--- /dev/null
+++ b/scripts/_generate_pybind11.py
@@ -0,0 +1,104 @@
+from _operator_utils import snake_to_pascal
+
+
+def generate_pybind11(operator):
+    def _generate_params(node):
+        return (
+            ", ".join(
+                f"{arg.type.spelling} {arg.spelling}"
+                for arg in node.get_arguments()
+                if arg.spelling != "stream"
+            )
+            .replace("const Tensor", "py::object")
+            .replace("Tensor", "py::object")
+        )
+
+    def _generate_arguments(node):
+        return ", ".join(
+            f"TensorFromPybind11Handle({arg.spelling})"
+            if "Tensor" in arg.type.spelling
+            else arg.spelling
+            for arg in node.get_arguments()
+            if arg.spelling != "stream"
+        )
+
+    op_name = operator.name
+
+    def _generate_init(constructor):
+        constructor_params = _generate_params(constructor)
+
+        return f"""      .def(py::init([]({constructor_params}) {{
+        return std::unique_ptr<Self>{{static_cast<Self*>(Self::make({_generate_arguments(constructor)}).release())}};
+      }}))"""
+
+    def _generate_py_args(node):
+        return ", ".join(
+            f'py::arg("{arg.spelling}")'
+            for arg in node.get_arguments()
+            if arg.spelling != "stream"
+        )
+
+    def _generate_call(op_name, call, method=True):
+        call_params = _generate_params(call)
+        call_args = _generate_arguments(call)
+
+        if not method:
+            params = (
+                f"{call_params}, std::size_t implementation_index"
+                if call_params
+                else "std::size_t implementation_index"
+            )
+            py_args = _generate_py_args(call)
+            py_args_str = f"{py_args}, " if py_args else ""
+
+            return f"""  m.def("{op_name}", []({params}) {{
+    Config config;
+    config.set_implementation_index(implementation_index);
+    return Self::call({{}}, config, {call_args});
+  }}, {py_args_str}py::kw_only(), py::arg("implementation_index") = 0);"""
+
+        return f"""      .def("__call__", [](const Self& self, {call_params}) {{
+        return static_cast<const Operator<Self>&>(self)({call_args});
+      }})"""
+
+    inits = "\n".join(
+        _generate_init(constructor) for constructor in operator.constructors
+    )
+    calls = "\n".join(_generate_call(operator.name, call) for call in operator.calls)
+    callers = "\n".join(
+        _generate_call(operator.name, call, method=False) for call in operator.calls
+    )
+
+    pascal_case_op_name = snake_to_pascal(op_name)
+
+    return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_
+#define INFINI_OPS_BINDINGS_{op_name.upper()}_H_
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "base/{op_name}.h"
+#include "config.h"
+#include "pybind11_utils.h"
+
+namespace py = pybind11;
+
+namespace infini::ops {{
+
+void Bind{pascal_case_op_name}(py::module& m) {{
+  using Self = {pascal_case_op_name};
+
+  py::class_<Self>(m, "{pascal_case_op_name}")
+{inits}
+{calls}
+      .def_static("active_implementation_indices", [](const std::string& device) {{
+        return Self::active_implementation_indices(DeviceTypeFromString(device));
+      }});
+
+{callers}
+}}
+
+}} // namespace infini::ops
+
+#endif
+"""
diff --git a/scripts/_generate_shared_lib.py b/scripts/_generate_shared_lib.py
new file mode 100644
index 00000000..d951acf7
--- /dev/null
+++ b/scripts/_generate_shared_lib.py
@@ -0,0 +1,79 @@
+from _operator_utils import snake_to_pascal
+
+
+def generate_shared_lib(operator, paths):
+    cpp_name = snake_to_pascal(operator.name)
+
+    def _generate_impl_includes():
+        return "\n".join(
+            f'#include "{str(path).removeprefix("src/")}"' for path in paths
+        )
+
+    def _generate_params(node):
+        return ", ".join(
+            f"{arg.type.spelling} {arg.spelling}" for arg in node.get_arguments()
+        )
+
+    def _generate_arguments(node):
+        return ", ".join(arg.spelling for arg in node.get_arguments())
+
+    def _generate_make_decl(constructor):
+        params = _generate_params(constructor)
+
+        if params:
+            params = f"const Config& config, {params}"
+        else:
+            params = "const Config& config"
+
+        return f"std::unique_ptr Make{cpp_name}({params})"
+
+    def _generate_make_def(constructor):
+        args = _generate_arguments(constructor)
+        make_args = f"config, {args}" if args else "config"
+
+        return f"""{_generate_make_decl(constructor)} {{
+  return Operator<{cpp_name}>::make({make_args});
+}}"""
+
+    impl_includes = _generate_impl_includes()
+
+    make_defs = "\n\n".join(_generate_make_def(c) for c in operator.constructors)
+
+    source = f"""#include "base/{operator.name}.h"
+{impl_includes}
+
+namespace infini::ops {{
+
+{make_defs}
+
+}} // namespace infini::ops
+"""
+
+    make_decls = "\n\n".join(
+        f"{_generate_make_decl(c)};" for c in operator.constructors
+    )
+
+    return source, make_decls
+
+
+def generate_shared_lib_header(all_decls):
+    combined = "\n\n".join(all_decls)
+
+    return f"""#ifndef INFINI_OPS_MAKE_H_
+#define INFINI_OPS_MAKE_H_
+
+#include <memory>
+#include <optional>
+
+#include "config.h"
+#include "operator.h"
+#include "tensor.h"
+
+namespace infini::ops {{
+
+{combined}
+
+}} // namespace infini::ops
+
+#endif
+"""
diff --git a/scripts/_operator_utils.py b/scripts/_operator_utils.py
new file mode 100644
index 00000000..2cb52bba
--- /dev/null
+++ b/scripts/_operator_utils.py
@@ -0,0 +1,103 @@
+import pathlib
+import shutil
+import subprocess
+
+import clang.cindex
+from clang.cindex import CursorKind
+
+_SRC_DIR = pathlib.Path("src")
+
+_BASE_DIR = _SRC_DIR / "base"
+
+
+def snake_to_pascal(snake_str):
+    return "".join(word.capitalize() for word in snake_str.split("_"))
+
+
+class Operator:
+    def __init__(self, name, constructors, calls):
+        self.name = name
+
+        self.constructors = constructors
+
+        self.calls = calls
+
+
+class OperatorExtractor:
+    def __call__(self, op_name):
+        def _get_system_include_flags():
+            def _get_compilers():
+                compilers = []
+
+                for compiler in ("clang++", "g++"):
+                    if shutil.which(compiler) is not None:
+                        compilers.append(compiler)
+
+                return compilers
+
+            system_include_flags = []
+
+            for compiler in _get_compilers():
+                for line in subprocess.getoutput(
+                    f"{compiler} -E -x c++ -v /dev/null"
+                ).splitlines():
+                    if not line.startswith(" "):
+                        continue
+
+                    system_include_flags.append("-isystem")
+                    system_include_flags.append(line.strip())
+
+            return system_include_flags
+
+        system_include_flags = _get_system_include_flags()
+
+        index = clang.cindex.Index.create()
+        args = ("-std=c++17", "-x", "c++", "-I", "src") + tuple(system_include_flags)
+        translation_unit = index.parse(f"src/base/{op_name}.h", args=args)
+
+        nodes = tuple(type(self)._find(translation_unit.cursor, op_name))
+
+        constructors = []
+        calls = []
+
+        for node in nodes:
+            if node.kind == CursorKind.CONSTRUCTOR:
+                constructors.append(node)
+            elif node.kind == CursorKind.CXX_METHOD and node.spelling == "operator()":
+                calls.append(node)
+
+        return Operator(op_name, constructors, calls)
+
+    @staticmethod
+    def _find(node, op_name):
+        pascal_case_op_name = snake_to_pascal(op_name)
+
+        if (
+            node.semantic_parent
+            and node.semantic_parent.spelling == pascal_case_op_name
+        ):
+            yield node
+
+        for child in node.get_children():
+            yield from OperatorExtractor._find(child, op_name)
+
+
+def get_all_ops(devices):
+    ops = {}
+
+    for file_path in _BASE_DIR.iterdir():
+        if not file_path.is_file():
+            continue
+
+        op_name = file_path.stem
+
+        ops[op_name] = []
+
+        for file_path in _SRC_DIR.rglob("*"):
+            if not file_path.is_file() or file_path.parent.parent.name not in devices:
+                continue
+
+            if f"class Operator<{snake_to_pascal(op_name)}" in file_path.read_text():
+                ops[op_name].append(file_path)
+
+    return ops
diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py
index 5aa8896e..f2d4ea43 100644
--- a/scripts/generate_wrappers.py
+++ b/scripts/generate_wrappers.py
@@ -1,16 +1,12 @@
 import argparse
 import json
 import pathlib
-import shutil
-import subprocess
 import textwrap
 
-import clang.cindex
-from clang.cindex import CursorKind
-
-_SRC_DIR = pathlib.Path("src")
-
-_BASE_DIR = _SRC_DIR / "base"
+from _operator_utils import OperatorExtractor, get_all_ops, snake_to_pascal
+from _generate_pybind11 import generate_pybind11
+from _generate_legacy_c import generate_legacy_c
+from _generate_shared_lib import generate_shared_lib, generate_shared_lib_header
 
 _GENERATION_DIR = pathlib.Path("generated")
 
@@ -22,377 +18,6 @@ _INDENTATION = "    "
 
-
-class _OperatorExtractor:
-    def __call__(self, op_name):
-        def _get_system_include_flags():
-            def _get_compilers():
-                compilers = []
-
-                for compiler in ("clang++", "g++"):
-                    if shutil.which(compiler) is not None:
-                        compilers.append(compiler)
-
-                return compilers
-
-            system_include_flags = []
-
-            for compiler in _get_compilers():
-                for line in subprocess.getoutput(
-                    f"{compiler} -E -x c++ -v /dev/null"
-                ).splitlines():
-                    if not line.startswith(" "):
-                        continue
-
-                    system_include_flags.append("-isystem")
-                    system_include_flags.append(line.strip())
-
-            return system_include_flags
-
-        system_include_flags = _get_system_include_flags()
-
-        index = clang.cindex.Index.create()
-        args = ("-std=c++17", "-x", "c++", "-I", "src") + tuple(system_include_flags)
-        translation_unit = index.parse(f"src/base/{op_name}.h", args=args)
-
-        nodes = tuple(type(self)._find(translation_unit.cursor, op_name))
-
-        constructors = []
-        calls = []
-
-        for node in nodes:
-            if node.kind == CursorKind.CONSTRUCTOR:
-                constructors.append(node)
-            elif node.kind == CursorKind.CXX_METHOD and node.spelling == "operator()":
-                calls.append(node)
-
-        return _Operator(op_name, constructors, calls)
-
-    @staticmethod
-    def _find(node, op_name):
-        pascal_case_op_name = _snake_to_pascal(op_name)
-
-        if (
-            node.semantic_parent
-            and node.semantic_parent.spelling == pascal_case_op_name
-        ):
-            yield node
-
-        for child in node.get_children():
-            yield from _OperatorExtractor._find(child, op_name)
-
-
-class _Operator:
-    def __init__(self, name, constructors, calls):
-        self.name = name
-
-        self.constructors = constructors
-
-        self.calls = calls
-
-
-def _generate_pybind11(operator):
-    def _generate_params(node):
-        return (
-            ", ".join(
-                f"{arg.type.spelling} {arg.spelling}"
-                for arg in node.get_arguments()
-                if arg.spelling != "stream"
-            )
-            .replace("const Tensor", "py::object")
-            .replace("Tensor", "py::object")
-        )
-
-    def _generate_arguments(node):
-        return ", ".join(
-            f"TensorFromPybind11Handle({arg.spelling})"
-            if "Tensor" in arg.type.spelling
-            else arg.spelling
-            for arg in node.get_arguments()
-            if arg.spelling != "stream"
-        )
-
-    op_name = operator.name
-
-    def _generate_init(constructor):
-        constructor_params = _generate_params(constructor)
-
-        return f"""      .def(py::init([]({constructor_params}) {{
-        return std::unique_ptr<Self>{{static_cast<Self*>(Self::make({_generate_arguments(constructor)}).release())}};
-      }}))"""
-
-    def _generate_py_args(node):
-        return ", ".join(
-            f'py::arg("{arg.spelling}")'
-            for arg in node.get_arguments()
-            if arg.spelling != "stream"
-        )
-
-    def _generate_call(op_name, call, method=True):
-        call_params = _generate_params(call)
-        call_args = _generate_arguments(call)
-
-        if not method:
-            params = (
-                f"{call_params}, std::size_t implementation_index"
-                if call_params
-                else "std::size_t implementation_index"
-            )
-            py_args = _generate_py_args(call)
-            py_args_str = f"{py_args}, " if py_args else ""
-
-            return f"""  m.def("{op_name}", []({params}) {{
-    Config config;
-    config.set_implementation_index(implementation_index);
-    return Self::call({{}}, config, {call_args});
-  }}, {py_args_str}py::kw_only(), py::arg("implementation_index") = 0);"""
-
-        return f"""      .def("__call__", [](const Self& self, {call_params}) {{
-        return static_cast<const Operator<Self>&>(self)({call_args});
-      }})"""
-
-    inits = "\n".join(
-        _generate_init(constructor) for constructor in operator.constructors
-    )
-    calls = "\n".join(_generate_call(operator.name, call) for call in operator.calls)
-    callers = "\n".join(
-        _generate_call(operator.name, call, method=False) for call in operator.calls
-    )
-
-    pascal_case_op_name = _snake_to_pascal(op_name)
-
-    return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_
-#define INFINI_OPS_BINDINGS_{op_name.upper()}_H_
-
-#include <pybind11/pybind11.h>
-#include <pybind11/stl.h>
-
-#include "base/{op_name}.h"
-#include "config.h"
-#include "pybind11_utils.h"
-
-namespace py = pybind11;
-
-namespace infini::ops {{
-
-void Bind{pascal_case_op_name}(py::module& m) {{
-  using Self = {pascal_case_op_name};
-
-  py::class_<Self>(m, "{pascal_case_op_name}")
-{inits}
-{calls}
-      .def_static("active_implementation_indices", [](const std::string& device) {{
-        return Self::active_implementation_indices(DeviceTypeFromString(device));
-      }});
-
-{callers}
-}}
-
-}} // namespace infini::ops
-
-#endif
-"""
-
-
-def _generate_legacy_c(operator, paths):
-    def _generate_source(operator):
-        impl_includes = "\n".join(
-            f'#include "{str(path).removeprefix("src/")}"' for path in paths
-        )
-
-        return f"""#include "../../handle.h"
-#include "../../tensor.h"
-#include "infiniop/ops/{operator.name.lower()}.h"
-{impl_includes}
-
-static infini::ops::DataType DataTypeFromInfiniDType(
-    const infiniDtype_t& dtype) {{
-  static constexpr infini::ops::ConstexprMap<infiniDtype_t,
-                                             infini::ops::DataType, 12>
-      kInfiniDTypeToDataType{{
-          {{{{{{INFINI_DTYPE_I8, infini::ops::DataType::kInt8}},
-            {{INFINI_DTYPE_I16, infini::ops::DataType::kInt16}},
-            {{INFINI_DTYPE_I32, infini::ops::DataType::kInt32}},
-            {{INFINI_DTYPE_I64, infini::ops::DataType::kInt64}},
-            {{INFINI_DTYPE_U8, infini::ops::DataType::kUInt8}},
-            {{INFINI_DTYPE_U16, infini::ops::DataType::kUInt16}},
-            {{INFINI_DTYPE_U32, infini::ops::DataType::kUInt32}},
-            {{INFINI_DTYPE_U64, infini::ops::DataType::kUInt64}},
-            {{INFINI_DTYPE_F16, infini::ops::DataType::kFloat16}},
-            {{INFINI_DTYPE_BF16, infini::ops::DataType::kBFloat16}},
-            {{INFINI_DTYPE_F32, infini::ops::DataType::kFloat32}},
-            {{INFINI_DTYPE_F64, infini::ops::DataType::kFloat64}}}}}}}};
-
-  return kInfiniDTypeToDataType.at(dtype);
-}}
-
-static infini::ops::Device::Type DeviceTypeFromInfiniDevice(
-    const infiniDevice_t& device) {{
-  static constexpr infini::ops::ConstexprMap<
-      infiniDevice_t, infini::ops::Device::Type,
-      static_cast<std::size_t>(INFINI_DEVICE_TYPE_COUNT)>
-      kInfiniDeviceToDeviceType{{
-          {{{{{{INFINI_DEVICE_CPU, infini::ops::Device::Type::kCpu}},
-            {{INFINI_DEVICE_NVIDIA, infini::ops::Device::Type::kNvidia}},
-            {{INFINI_DEVICE_CAMBRICON, infini::ops::Device::Type::kCambricon}},
-            {{INFINI_DEVICE_ASCEND, infini::ops::Device::Type::kAscend}},
-            {{INFINI_DEVICE_METAX, infini::ops::Device::Type::kMetax}},
-            {{INFINI_DEVICE_MOORE, infini::ops::Device::Type::kMoore}},
-            {{INFINI_DEVICE_ILUVATAR, infini::ops::Device::Type::kIluvatar}},
-            {{INFINI_DEVICE_KUNLUN, infini::ops::Device::Type::kKunlun}},
-            {{INFINI_DEVICE_HYGON, infini::ops::Device::Type::kHygon}},
-            {{INFINI_DEVICE_QY, infini::ops::Device::Type::kQy}}}}}}}};
-
-  return kInfiniDeviceToDeviceType.at(device);
-}}
-
-__C {_generate_create_func_def(operator)}
-
-__C {_generate_get_workspace_size_func_def(operator)}
-
-__C {_generate_call_func_def(operator)}
-
-__C {_generate_destroy_func_def(operator)}
-"""
-
-    def _generate_header(operator):
-        return f"""#ifndef __INFINIOP_{operator.name.upper()}_API_H__
-#define __INFINIOP_{operator.name.upper()}_API_H__
-
-#include "base/{operator.name.lower()}.h"
-
-typedef struct infini::ops::Operator *infiniop{operator.name}Descriptor_t;
-
-__C __export {_generate_create_func_decl(operator)};
-
-__C __export {_generate_get_workspace_size_func_decl(operator)};
-
-__C __export {_generate_call_func_decl(operator)};
-
-__C __export {_generate_destroy_func_decl(operator)};
-
-#endif
-"""
-
-    def _generate_create_func_def(operator):
-        name = operator.name
-        constructor = operator.constructors[-1]
-
-        return f"""{_generate_create_func_decl(operator)} {{
-    *desc_ptr = infini::ops::Operator::make({_generate_arguments(constructor)}).release();
-
-    return INFINI_STATUS_SUCCESS;
-}}"""
-
-    def _generate_get_workspace_size_func_def(operator):
-        return f"""{_generate_get_workspace_size_func_decl(operator)} {{
-    *size = 0; // desc->workspace_size();
-
-    return INFINI_STATUS_SUCCESS;
-}}"""
-
-    def _generate_call_func_def(operator):
-        call = operator.calls[-1]
-
-        return f"""{_generate_call_func_decl(operator)} {{
-    (*desc)(stream, {_generate_arguments(call, is_data=True)});
-
-    return INFINI_STATUS_SUCCESS;
-}}"""
-
-    def _generate_destroy_func_def(operator):
-        return f"""{_generate_destroy_func_decl(operator)} {{
-    delete desc;
-
-    return INFINI_STATUS_SUCCESS;
-}}"""
-
-    def _generate_create_func_decl(operator):
-        name = operator.name
-        constructor = operator.constructors[-1]
-        params = _generate_params(constructor)
-
-        return f"infiniStatus_t infiniopCreate{name}Descriptor(infiniopHandle_t handle, infiniop{name}Descriptor_t *desc_ptr, {params})"
-
-    def _generate_get_workspace_size_func_decl(operator):
-        name = operator.name
-
-        return f"infiniStatus_t infiniopGet{name}WorkspaceSize(infiniop{name}Descriptor_t desc, size_t *size)"
-
-    def _generate_call_func_decl(operator):
-        name = operator.name
-        call = operator.calls[-1]
-        params = _generate_params(call, call=True)
-        params = params.replace("void * stream, ", "")
-
-        return f"infiniStatus_t infiniop{name}(infiniop{name}Descriptor_t desc, void *workspace, size_t workspace_size, {params}, void *stream)"
-
-    def _generate_destroy_func_decl(operator):
-        name = operator.name
-
-        return f"infiniStatus_t infiniopDestroy{name}Descriptor(infiniop{name}Descriptor_t desc)"
-
-    def _generate_params(node, call=False):
-        arguments = tuple(node.get_arguments())
-
-        arguments = (arguments[-1], *arguments[:-1])
-
-        def _handle_tensor(spelling):
-            if call:
-                return spelling.replace("Tensor", "void *")
-            return spelling.replace("Tensor", "infiniopTensorDescriptor_t")
-
-        def _handle_std_optional(spelling):
-            return spelling.replace("std::optional<", "").replace(">", "")
-
-        return ", ".join(
-            f"{_handle_std_optional(_handle_tensor(arg.type.spelling))} {arg.spelling}"
-            for arg in arguments
-        )
-
-    def _generate_arguments(node, is_data=False):
-        return ", ".join(
-            _generate_tensor_caster(arg.spelling, is_data=is_data)
-            if "Tensor" in arg.type.spelling
-            else arg.spelling
-            for arg in node.get_arguments()
-            if arg.spelling != "handle" and arg.spelling != "stream"
-        )
-
-    def _generate_tensor_caster(name, is_data=False):
-        if is_data:
-            return f"infini::ops::Tensor(const_cast<void*>({name}), infini::ops::Tensor::Shape{{}})"
-
-        return f"infini::ops::Tensor{{nullptr, {name}->shape(), DataTypeFromInfiniDType({name}->dtype()), infini::ops::Device{{DeviceTypeFromInfiniDevice(handle->device), handle->device_id}}, {name}->strides()}}"
-
-    return _generate_source(operator), _generate_header(operator)
-
-
-def _snake_to_pascal(snake_str):
-    return "".join(word.capitalize() for word in snake_str.split("_"))
-
-
-def _get_all_ops(devices):
-    ops = {}
-
-    for file_path in _BASE_DIR.iterdir():
-        if not file_path.is_file():
-            continue
-
-        op_name = file_path.stem
-
-        ops[op_name] = []
-
-        for file_path in _SRC_DIR.rglob("*"):
-            if not file_path.is_file() or file_path.parent.parent.name not in devices:
-                continue
-
-            if f"class Operator<{_snake_to_pascal(op_name)}" in file_path.read_text():
-                ops[op_name].append(file_path)
-
-    return ops
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="An automatic wrapper generator.")
@@ -415,29 +40,36 @@ def _get_all_ops(devices):
     if ops_json.exists():
         ops = json.loads(ops_json.read_text())
     else:
-        ops = _get_all_ops(args.devices)
+        ops = get_all_ops(args.devices)
 
     header_paths = []
     bind_func_names = []
+    shared_lib_decls = []
 
     for op_name, impl_paths in ops.items():
-        extractor = _OperatorExtractor()
+        extractor = OperatorExtractor()
         operator = extractor(op_name)
 
         source_path = _GENERATED_SRC_DIR / op_name
         header_name = f"{op_name}.h"
-        bind_func_name = f"Bind{_snake_to_pascal(op_name)}"
+        bind_func_name = f"Bind{snake_to_pascal(op_name)}"
 
-        (_BINDINGS_DIR / header_name).write_text(_generate_pybind11(operator))
+        (_BINDINGS_DIR / header_name).write_text(generate_pybind11(operator))
 
-        legacy_c_source, legacy_c_header = _generate_legacy_c(operator, impl_paths)
+        legacy_c_source, legacy_c_header = generate_legacy_c(operator, impl_paths)
         source_path.mkdir(exist_ok=True)
         (_GENERATED_SRC_DIR / op_name / "operator.cc").write_text(legacy_c_source)
         (_INCLUDE_DIR / header_name).write_text(legacy_c_header)
 
+        sl_source, sl_decls = generate_shared_lib(operator, impl_paths)
+        (_GENERATED_SRC_DIR / op_name / "make.cc").write_text(sl_source)
+        shared_lib_decls.append(sl_decls)
+
         header_paths.append(header_name)
         bind_func_names.append(bind_func_name)
 
+    (_INCLUDE_DIR / "make.h").write_text(generate_shared_lib_header(shared_lib_decls))
+
     impl_includes = "\n".join(
         f'#include "{impl_path}"'
         for impl_paths in ops.values()
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 0b56341b..a3f8c322 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -174,6 +174,14 @@ endif()
 
 target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
 
+file(GLOB FACTORY_SOURCES "${PROJECT_SOURCE_DIR}/generated/src/*/make.cc")
+if(FACTORY_SOURCES)
+    if(WITH_NVIDIA OR WITH_ILUVATAR)
+        set_source_files_properties(${FACTORY_SOURCES} PROPERTIES LANGUAGE CUDA)
+    endif()
+    target_sources(infiniops PRIVATE ${FACTORY_SOURCES})
+endif()
+
 if(GENERATE_PYTHON_BINDINGS)
     find_package(Python COMPONENTS Interpreter REQUIRED)
     execute_process(
diff --git a/src/base/rms_norm.h b/src/base/rms_norm.h
index dc28f0aa..ed087863 100644
--- a/src/base/rms_norm.h
+++ b/src/base/rms_norm.h
@@ -14,6 +14,7 @@ class RmsNorm : public Operator<RmsNorm> {
   RmsNorm(const Tensor input, const Tensor weight, float eps, Tensor out)
       : input_shape_{input.shape()},
         out_shape_{out.shape()},
+        dtype_{out.dtype()},
        input_strides_{input.strides()},
         out_strides_{out.strides()},
         eps_{eps},
@@ -41,6 +42,8 @@ class RmsNorm : public Operator<RmsNorm> {
 
   Tensor::Shape out_shape_;
 
+  const DataType dtype_;
+
   Tensor::Strides input_strides_;
   Tensor::Strides out_strides_;
diff --git a/src/common/generic_utils.h b/src/common/generic_utils.h
index 795f2fb7..b34ce8b0 100644
--- a/src/common/generic_utils.h
+++ b/src/common/generic_utils.h
@@ -5,9 +5,9 @@
 
 namespace infini::ops::utils {
 
-std::size_t IndexToOffset(std::size_t flat_index, std::size_t ndim,
-                          const std::size_t* shape,
-                          const std::ptrdiff_t* strides) {
+inline std::size_t IndexToOffset(std::size_t flat_index, std::size_t ndim,
+                                 const std::size_t* shape,
+                                 const std::ptrdiff_t* strides) {
   std::size_t res = 0;
   for (std::size_t i = ndim; i-- > 0;) {
     res += (flat_index % shape[i]) * strides[i];
diff --git a/src/cpu/causal_softmax/causal_softmax.h b/src/cpu/causal_softmax/causal_softmax.h
index 14848ee4..ee64ec6c 100644
--- a/src/cpu/causal_softmax/causal_softmax.h
+++ b/src/cpu/causal_softmax/causal_softmax.h
@@ -19,7 +19,7 @@ class Operator<CausalSoftmax> : public CausalSoftmax,
 
   void operator()(const Tensor input, Tensor out) const override {
     DispatchFunc(
-        out.dtype(),
+        dtype_,
         [&](auto tag) {
           using T = typename decltype(tag)::type;
           Compute<T>(input, out);
diff --git a/src/cpu/gemm/gemm.h b/src/cpu/gemm/gemm.h
index a4dfb989..1b3ebb32 100644
--- a/src/cpu/gemm/gemm.h
+++ b/src/cpu/gemm/gemm.h
@@ -32,7 +32,7 @@ class Operator<Gemm> : public Gemm,
                   std::optional<float> beta, std::optional<bool> trans_a,
                   std::optional<bool> trans_b, Tensor c) const override {
     DispatchFunc(
-        c.dtype(),
+        c_type_,
         [&](auto tag) {
           using T = typename decltype(tag)::type;
           Compute<T>(a, b, alpha, beta, trans_a, trans_b, c);
diff --git a/src/cpu/rms_norm/rms_norm.h b/src/cpu/rms_norm/rms_norm.h
index 9cae419e..c7a1a484 100644
--- a/src/cpu/rms_norm/rms_norm.h
+++ b/src/cpu/rms_norm/rms_norm.h
@@ -20,7 +20,7 @@ class Operator<RmsNorm> : public RmsNorm,
   void operator()(const Tensor input, const Tensor weight, float eps,
                   Tensor out) const override {
     DispatchFunc(
-        out.dtype(),
+        dtype_,
         [&](auto tag) {
           using T = typename decltype(tag)::type;
           Compute<T>(input, weight, eps, out);
diff --git a/src/cuda/causal_softmax/kernel.h b/src/cuda/causal_softmax/kernel.h
index 7c7ac871..dd3df98b 100644
--- a/src/cuda/causal_softmax/kernel.h
+++ b/src/cuda/causal_softmax/kernel.h
@@ -7,6 +7,7 @@
 #include "base/causal_softmax.h"
 #include "cuda/causal_softmax/kernel.cuh"
 #include "cuda/kernel_commons.cuh"
+#include "cuda/runtime_utils.h"
 #include "data_type.h"
 #include "dispatcher.h"
 
@@ -29,14 +30,11 @@ class CudaCausalSoftmax : public CausalSoftmax {
 
     dim3 grid(static_cast<uint32_t>(seq_len_),
              static_cast<uint32_t>(batch_size_));
 
-    assert(out.dtype() == input.dtype());
-
     int block_size = RuntimeUtils::GetOptimalBlockSize();
 
     DispatchFunc, ReducedFloatTypes>, AllCudaBlockSizes>(
-        // TODO: Output dtype should use the one passed in during construction.
-        {static_cast(out.dtype()), block_size},
+        {static_cast(dtype_), block_size},
         [&](auto list_tag) {
           using T = TypeMapType<ListGet<0>(list_tag)>;
           constexpr int kBlockSize = ListGet<1>(list_tag);
diff --git a/src/cuda/gemm/blas.h b/src/cuda/gemm/blas.h
index 264d52c1..a53ab63e 100644
--- a/src/cuda/gemm/blas.h
+++ b/src/cuda/gemm/blas.h
@@ -42,25 +42,25 @@ class BlasGemm : public Gemm {
     const auto& trans_b_value{trans_b.value_or(trans_b_)};
     auto op_a{GetOpA(trans_a_value, trans_b_value)};
     auto op_b{GetOpB(trans_a_value, trans_b_value)};
-    const void* alpha_ptr{GetAlphaPtr(alpha_value, c.dtype())};
-    const void* beta_ptr{GetBetaPtr(beta_value, c.dtype())};
+    const void* alpha_ptr{GetAlphaPtr(alpha_value, c_type_)};
+    const void* beta_ptr{GetBetaPtr(beta_value, c_type_)};
 
     Backend::BlasGemmStridedBatchedEx(
         GetHandle(), op_a, op_b, swap_a_and_b_ ? n_ : m_,
         swap_a_and_b_ ? m_ : n_, k_, alpha_ptr,
         swap_a_and_b_ ? b.data() : a.data(),
-        BlasUtils::GetDataType(swap_a_and_b_ ? b.dtype()
-                                             : a.dtype()),
+        BlasUtils::GetDataType(swap_a_and_b_ ? b_type_
+                                             : a_type_),
         swap_a_and_b_ ? ldb_ : lda_,
         swap_a_and_b_ ? batch_stride_b_ : batch_stride_a_,
         swap_a_and_b_ ? a.data() : b.data(),
-        BlasUtils::GetDataType(swap_a_and_b_ ? a.dtype()
-                                             : b.dtype()),
+        BlasUtils::GetDataType(swap_a_and_b_ ? a_type_
+                                             : b_type_),
         swap_a_and_b_ ? lda_ : ldb_,
         swap_a_and_b_ ? batch_stride_a_ : batch_stride_b_, beta_ptr, c.data(),
-        BlasUtils::GetDataType(c.dtype()), ldc_,
+        BlasUtils::GetDataType(c_type_), ldc_,
         batch_stride_c_, batch_count_,
-        BlasUtils::GetComputeType(c.dtype()),
+        BlasUtils::GetComputeType(c_type_),
         Backend::BLAS_GEMM_DEFAULT);
   }
 
diff --git a/src/cuda/rms_norm/kernel.h b/src/cuda/rms_norm/kernel.h
index 14146edc..62a58dd2 100644
--- a/src/cuda/rms_norm/kernel.h
+++ b/src/cuda/rms_norm/kernel.h
@@ -32,13 +32,11 @@ class CudaRmsNorm : public RmsNorm {
 
     uint32_t num_blocks = static_cast<uint32_t>(batch_size_ * nhead_);
 
-    assert(out.dtype() == input.dtype() && out.dtype() == weight.dtype());
-
     int block_size = RuntimeUtils::GetOptimalBlockSize();
 
     DispatchFunc, ReducedFloatTypes>, AllCudaBlockSizes>(
-        {static_cast(out.dtype()), block_size},
+        {static_cast(dtype_), block_size},
         [&](auto list_tag) {
           using T = TypeMapType<ListGet<0>(list_tag)>;
           constexpr int kBlockSize = ListGet<1>(list_tag);
diff --git a/src/nvidia/gemm/cublaslt.h b/src/nvidia/gemm/cublaslt.h
index 38de8507..9e1d3610 100644
--- a/src/nvidia/gemm/cublaslt.h
+++ b/src/nvidia/gemm/cublaslt.h
@@ -56,11 +56,10 @@ class Operator<Gemm> : public Gemm {
     const auto* a_ptr{swap_a_and_b_ ? b.data() : a.data()};
     const auto* b_ptr{swap_a_and_b_ ? a.data() : b.data()};
     const auto a_dtype{BlasUtils::GetDataType(
-        swap_a_and_b_ ? b.dtype() : a.dtype())};
+        swap_a_and_b_ ? b_type_ : a_type_)};
     const auto b_dtype{BlasUtils::GetDataType(
-        swap_a_and_b_ ? a.dtype() : b.dtype())};
-    const auto c_dtype{
-        BlasUtils::GetDataType(c.dtype())};
+        swap_a_and_b_ ? a_type_ : b_type_)};
+    const auto c_dtype{BlasUtils::GetDataType(c_type_)};
     const auto a_ld{static_cast(swap_a_and_b_ ? ldb_ : lda_)};
     const auto b_ld{static_cast(swap_a_and_b_ ? lda_ : ldb_)};
     const auto c_ld{static_cast(ldc_)};
@@ -72,7 +71,7 @@ class Operator<Gemm> : public Gemm {
 
     cublasLtMatmulDesc_t op_desc{};
     auto status = cublasLtMatmulDescCreate(
-        &op_desc, BlasUtils::GetComputeType(c.dtype()),
+        &op_desc, BlasUtils::GetComputeType(c_type_),
         CUDA_R_32F);
     assert(status == CUBLAS_STATUS_SUCCESS &&
            "failed to create cuBLASLt matmul descriptor");
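
Usage sketch (editor's illustration, not part of the patch): assuming the
generated legacy C API for `rms_norm` with the `RMSNorm` name override, a
caller would drive the four generated entry points roughly as below. The
`handle`, descriptor, and pointer variables are hypothetical, status returns
are unchecked, and the exact parameter list depends on the extracted
constructor (overload 0 per `_CONSTRUCTOR_INDEX_OVERRIDES`).

    infiniopRMSNormDescriptor_t desc = NULL;
    /* Constructor arguments are rotated so the output descriptor comes
       first, per `_generate_params`. */
    infiniopCreateRMSNormDescriptor(handle, &desc, out_desc, input_desc,
                                    weight_desc, eps);

    size_t workspace_size = 0;
    infiniopGetRMSNormWorkspaceSize(desc, &workspace_size);

    /* Tensors are passed as raw device pointers; `stream` is the backend
       stream threaded through by the generated wrapper. */
    infiniopRMSNorm(desc, workspace, workspace_size, out, input, weight, eps,
                    stream);

    infiniopDestroyRMSNormDescriptor(desc);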