From 9926f44f3eb5992c79a7ac361a3f205d1360d8cd Mon Sep 17 00:00:00 2001 From: bolun365 Date: Wed, 5 Nov 2025 16:48:16 +0800 Subject: [PATCH 001/314] =?UTF-8?q?mnn=20lib=E5=BA=93build=E8=84=9A?= =?UTF-8?q?=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- build_lib.sh | 807 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 807 insertions(+) create mode 100644 build_lib.sh diff --git a/build_lib.sh b/build_lib.sh new file mode 100644 index 0000000000..c839b6e7b6 --- /dev/null +++ b/build_lib.sh @@ -0,0 +1,807 @@ +#!/bin/bash + +# MNN 统一构建脚本 +# 支持构建 Android、iOS、鸿蒙和 Python 版本的 MNN + +set -e + +# 颜色输出 +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# 默认配置 +BUILD_ANDROID=false +BUILD_IOS=false +BUILD_PYTHON=false +BUILD_IOS_SIMULATOR=false +BUILD_HARMONY=false +ANDROID_NDK="" +HARMONY_HOME="" +OUTPUT_DIR="mnn_builds" +CLEAN_BUILD=true +VERSION="" +PYTHON_CMD="python3" +DEPS_OPTIONS="" + +# 获取项目根目录 +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$SCRIPT_DIR" + +print_usage() { + echo "MNN 统一构建脚本" + echo "" + echo "用法: $0 [选项]" + echo "" + echo "构建目标:" + echo " --android 构建 Android 版本 (需要 --ndk)" + echo " --ios 构建 iOS 真机版本 (arm64)" + echo " --ios-simulator 构建 iOS 模拟器版本 (x86_64 + arm64)" + echo " --harmony 构建鸿蒙版本 (需要 --harmony-home 或设置 HARMONY_HOME)" + echo " --python 构建 Python 版本" + echo "" + echo "Android 选项:" + echo " --ndk PATH Android NDK 路径 (例如: ~/Library/Android/sdk/ndk/29.0.13599879)" + echo "" + echo "鸿蒙选项:" + echo " --harmony-home PATH 鸿蒙工具链路径 (例如: ~/Library/OpenHarmony/Sdk/native)" + echo " 如果未指定,将优先查找 ~/Library/OpenHarmony/Sdk/native" + echo "" + echo "Python 选项:" + echo " --python-deps OPTIONS Python 依赖选项,多个用逗号分隔" + echo " 可用选项: llm,opencl,cuda,torch,render,vulkan,internal,no_sse,openmp" + echo " --python-cmd CMD Python 命令 (默认: python3)" + echo "" + echo "通用选项:" + echo " -o, --output DIR 输出目录 (默认: 
mnn_builds)" + echo " -v, --version VERSION 版本号 (默认: 自动从源码读取)" + echo " --no-clean 不清理之前的构建目录" + echo " -h, --help 显示帮助信息" + echo "" + echo "示例:" + echo " # 构建所有平台" + echo " $0 --android --ios --harmony --python --ndk ~/Library/Android/sdk/ndk/29.0.13599879" + echo "" + echo " # 仅构建 Android" + echo " $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879" + echo "" + echo " # 构建鸿蒙版本" + echo " $0 --harmony --harmony-home ~/Library/OpenHarmony/Sdk/native" + echo "" + echo " # 构建 iOS 真机和模拟器" + echo " $0 --ios --ios-simulator" + echo "" + echo " # 构建 Python (带 LLM 支持)" + echo " $0 --python --python-deps llm,opencl" +} + +# 解析命令行参数 +while [[ $# -gt 0 ]]; do + case $1 in + --android) + BUILD_ANDROID=true + shift + ;; + --ios) + BUILD_IOS=true + shift + ;; + --ios-simulator) + BUILD_IOS_SIMULATOR=true + shift + ;; + --python) + BUILD_PYTHON=true + shift + ;; + --harmony) + BUILD_HARMONY=true + shift + ;; + --ndk) + ANDROID_NDK="$2" + shift 2 + ;; + --harmony-home) + HARMONY_HOME="$2" + shift 2 + ;; + --python-deps) + DEPS_OPTIONS="$2" + shift 2 + ;; + --python-cmd) + PYTHON_CMD="$2" + shift 2 + ;; + -o|--output) + OUTPUT_DIR="$2" + shift 2 + ;; + -v|--version) + VERSION="$2" + shift 2 + ;; + --no-clean) + CLEAN_BUILD=false + shift + ;; + -h|--help) + print_usage + exit 0 + ;; + *) + echo -e "${RED}错误: 未知选项 $1${NC}" + print_usage + exit 1 + ;; + esac +done + +# 检查是否至少选择了一个构建目标 +if [ "$BUILD_ANDROID" = false ] && [ "$BUILD_IOS" = false ] && [ "$BUILD_IOS_SIMULATOR" = false ] && [ "$BUILD_HARMONY" = false ] && [ "$BUILD_PYTHON" = false ]; then + echo -e "${RED}错误: 请至少选择一个构建目标 (--android, --ios, --ios-simulator, --harmony, 或 --python)${NC}" + print_usage + exit 1 +fi + +# 检查 Android NDK +if [ "$BUILD_ANDROID" = true ]; then + if [ -z "$ANDROID_NDK" ]; then + echo -e "${RED}错误: 构建 Android 必须指定 NDK 路径 (使用 --ndk)${NC}" + echo -e "${RED}示例: $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879${NC}" + exit 1 + fi + + # 展开路径中的 ~ + ANDROID_NDK="${ANDROID_NDK/#\~/$HOME}" + 
+ if [ ! -d "$ANDROID_NDK" ]; then + echo -e "${RED}错误: NDK 路径不存在: $ANDROID_NDK${NC}" + exit 1 + fi +fi + +# 查找鸿蒙工具链的函数 +find_harmony_toolchain() { + # 如果环境变量已设置,优先使用 + if [ -n "$HARMONY_HOME" ]; then + # 支持两种路径格式:直接指定或带 native 子目录 + if [ -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then + echo "$HARMONY_HOME" + return 0 + elif [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then + echo "$HARMONY_HOME" + return 0 + fi + fi + + # 优先查找 ~/Library/OpenHarmony/Sdk,支持版本号目录 + local sdk_base="$HOME/Library/OpenHarmony/Sdk" + if [ -d "$sdk_base" ]; then + # 查找所有版本号目录下的 native 或 toolchains + # 按版本号倒序排列,优先使用最新版本 + for version_dir in $(ls -d "$sdk_base"/* 2>/dev/null | sort -Vr); do + if [ -d "$version_dir" ]; then + # 尝试 native/build/cmake/ohos.toolchain.cmake + if [ -f "$version_dir/native/build/cmake/ohos.toolchain.cmake" ]; then + echo "$version_dir/native" + return 0 + fi + # 尝试 build/cmake/ohos.toolchain.cmake + if [ -f "$version_dir/build/cmake/ohos.toolchain.cmake" ]; then + echo "$version_dir" + return 0 + fi + fi + done + fi + + # 其他可能的路径 + local possible_paths=( + "$HOME/Library/OpenHarmony/Sdk/native" + "$HOME/HarmonyOS/Sdk/native" + "$HOME/.ohos/native" + "/opt/HarmonyOS/Sdk/native" + "/usr/local/HarmonyOS/Sdk/native" + "$HOME/Library/HarmonyOS/Sdk/native" + ) + + # 尝试查找 + for path in "${possible_paths[@]}"; do + if [ -n "$path" ] && [ -f "$path/build/cmake/ohos.toolchain.cmake" ]; then + echo "$path" + return 0 + fi + done + + # 限制搜索范围,只在 OpenHarmony/Sdk 目录下快速查找 + # 避免在整个 OpenHarmony 目录下递归查找,这可能会很慢 + local found=$(find "$HOME/Library/OpenHarmony/Sdk" -maxdepth 4 -type f -name "ohos.toolchain.cmake" 2>/dev/null | head -1) + if [ -n "$found" ]; then + # 从 ohos.toolchain.cmake 向上查找 native 目录或 SDK 根目录 + found=$(dirname "$found") + if [ "$(basename "$found")" = "cmake" ]; then + found=$(dirname "$found") + if [ "$(basename "$found")" = "build" ]; then + found=$(dirname "$found") + fi + fi + echo "$found" + return 0 + fi + + return 1 +} + +# 
检查鸿蒙工具链 +if [ "$BUILD_HARMONY" = true ]; then + # 展开路径中的 ~ + if [ -n "$HARMONY_HOME" ]; then + HARMONY_HOME="${HARMONY_HOME/#\~/$HOME}" + fi + + # 尝试查找工具链 + if [ -z "$HARMONY_HOME" ] || [ ! -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then + # 检查是否在 native 子目录下 + if [ -n "$HARMONY_HOME" ] && [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then + HARMONY_HOME="$HARMONY_HOME/native" + else + echo -e "${YELLOW}正在查找鸿蒙工具链...${NC}" + HARMONY_HOME=$(find_harmony_toolchain) + + if [ -z "$HARMONY_HOME" ]; then + echo -e "${RED}错误: 找不到鸿蒙工具链${NC}" + echo -e "${RED}默认查找路径: ~/Library/OpenHarmony/Sdk/*/native${NC}" + echo -e "${RED}请使用 --harmony-home 指定路径,或设置 HARMONY_HOME 环境变量${NC}" + echo -e "${RED}工具链文件应位于: /build/cmake/ohos.toolchain.cmake 或 /native/build/cmake/ohos.toolchain.cmake${NC}" + exit 1 + fi + + # 如果找到的是带 native 的路径,需要调整 + if [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then + HARMONY_HOME="$HARMONY_HOME/native" + fi + fi + + echo -e "${GREEN}找到鸿蒙工具链: $HARMONY_HOME${NC}" + fi + + # 验证工具链文件存在 + if [ ! 
-f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then + echo -e "${RED}错误: 鸿蒙工具链文件不存在: $HARMONY_HOME/build/cmake/ohos.toolchain.cmake${NC}" + exit 1 + fi +fi + +echo -e "${GREEN}========================================${NC}" +echo -e "${GREEN}MNN 统一构建脚本${NC}" +echo -e "${GREEN}========================================${NC}" +echo "项目根目录: $PROJECT_ROOT" +echo "输出目录: $OUTPUT_DIR" +echo "" +echo -e "${BLUE}构建目标:${NC}" +[ "$BUILD_ANDROID" = true ] && echo " ✓ Android" +[ "$BUILD_IOS" = true ] && echo " ✓ iOS (真机 arm64)" +[ "$BUILD_IOS_SIMULATOR" = true ] && echo " ✓ iOS (模拟器 x86_64 + arm64)" +[ "$BUILD_HARMONY" = true ] && echo " ✓ 鸿蒙 (arm64-v8a)" +[ "$BUILD_PYTHON" = true ] && echo " ✓ Python" +echo "" + +cd "$PROJECT_ROOT" +mkdir -p "$OUTPUT_DIR" + +# ============================================================================ +# 构建 Android 版本 +# ============================================================================ +if [ "$BUILD_ANDROID" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 Android 版本${NC}" + echo -e "${GREEN}========================================${NC}" + + export ANDROID_NDK + + ANDROID_BUILD_DIR="project/android" + ANDROID_OUTPUT_DIR="$OUTPUT_DIR/android" + + cd "$ANDROID_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的 Android 构建...${NC}" + rm -rf build_32 build_64 export + fi + + # 在项目根目录创建输出目录 + mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" + + # 构建 armeabi-v7a + echo -e "${BLUE}构建 armeabi-v7a...${NC}" + mkdir -p build_32 + cd build_32 + cmake ../../../ \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DANDROID_ABI="armeabi-v7a" \ + -DANDROID_STL=c++_static \ + -DANDROID_NATIVE_API_LEVEL=android-14 \ + -DANDROID_TOOLCHAIN=clang \ + -DMNN_USE_LOGCAT=false \ + -DMNN_USE_SSE=OFF \ + -DMNN_BUILD_TEST=ON \ + -DMNN_ARM82=OFF \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + 
-DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DMNN_SEP_BUILD=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ + -DNATIVE_LIBRARY_OUTPUT=. \ + -DNATIVE_INCLUDE_OUTPUT=. \ + > /dev/null + + make -j4 MNN > /dev/null + cd .. + + # 构建 arm64-v8a + echo -e "${BLUE}构建 arm64-v8a...${NC}" + mkdir -p build_64 + cd build_64 + cmake ../../../ \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DANDROID_ABI="arm64-v8a" \ + -DANDROID_STL=c++_static \ + -DANDROID_NATIVE_API_LEVEL=android-21 \ + -DANDROID_TOOLCHAIN=clang \ + -DMNN_USE_LOGCAT=false \ + -DMNN_BUILD_BENCHMARK=ON \ + -DMNN_USE_SSE=OFF \ + -DMNN_BUILD_TEST=ON \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DMNN_SEP_BUILD=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ + -DNATIVE_LIBRARY_OUTPUT=. \ + -DNATIVE_INCLUDE_OUTPUT=. \ + > /dev/null + + make -j4 MNN > /dev/null + cd .. + + # 导出文件 + echo -e "${BLUE}导出 Android 库文件...${NC}" + mkdir -p export/android/{armeabi-v7a,arm64-v8a}/libs export/android/include/MNN + + cp build_32/*.so export/android/armeabi-v7a/libs/ 2>/dev/null || true + cp build_64/*.so export/android/arm64-v8a/libs/ 2>/dev/null || true + cp -r ../../include/MNN/* export/android/include/MNN/ + + # 复制到统一输出目录 + # 如果目标路径存在但不是目录,先删除 + if [ -e "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ] && [ ! 
-d "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ]; then + rm -f "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" + fi + mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" + cp -r export/android/* "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR/" + + echo -e "${GREEN}Android 构建完成!${NC}" + echo "输出位置: $ANDROID_OUTPUT_DIR" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建 iOS 真机版本 +# ============================================================================ +if [ "$BUILD_IOS" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 iOS 真机版本${NC}" + echo -e "${GREEN}========================================${NC}" + + # 检查是否在 macOS 上 + if [[ "$OSTYPE" != "darwin"* ]]; then + echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" + exit 1 + fi + + # 检查 Xcode + if ! command -v xcodebuild &> /dev/null; then + echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" + exit 1 + fi + + IOS_BUILD_DIR="project/ios" + IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/device" + + cd "$IOS_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的 iOS 真机构建...${NC}" + rm -rf MNN-iOS-CPU-GPU/Static/ios_64 + find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_64/*" 2>/dev/null | xargs rm -f 2>/dev/null || true + fi + + mkdir -p MNN-iOS-CPU-GPU/Static + cd MNN-iOS-CPU-GPU/Static + + # 构建真机版本 (arm64) + echo -e "${BLUE}构建 iOS 真机版本 (arm64)...${NC}" + rm -rf ios_64 + mkdir ios_64 + cd ios_64 + cmake "$PROJECT_ROOT" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ + -DENABLE_BITCODE=0 \ + -DMNN_AAPL_FMWK=1 \ + -DMNN_SEP_BUILD=OFF \ + -DMNN_BUILD_SHARED_LIBS=false \ + -DMNN_USE_THREAD_POOL=OFF \ + -DPLATFORM=OS64 \ + -DARCHS="arm64" \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_METAL=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + 
-DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + > /dev/null + make MNN -j8 > /dev/null + cd .. + + # 输出真机版本 + mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" + rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" + cp -R ios_64/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" + + # 清理 + rm -rf ios_64 + + echo -e "${GREEN}iOS 真机构建完成!${NC}" + echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建 iOS 模拟器版本 +# ============================================================================ +if [ "$BUILD_IOS_SIMULATOR" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 iOS 模拟器版本${NC}" + echo -e "${GREEN}========================================${NC}" + + # 检查是否在 macOS 上 + if [[ "$OSTYPE" != "darwin"* ]]; then + echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" + exit 1 + fi + + # 检查 Xcode + if ! command -v xcodebuild &> /dev/null; then + echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" + exit 1 + fi + + IOS_BUILD_DIR="project/ios" + IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/simulator" + + cd "$IOS_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的 iOS 模拟器构建...${NC}" + rm -rf MNN-iOS-CPU-GPU/Static/ios_simulator* + find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_simulator*/*" 2>/dev/null | xargs rm -f 2>/dev/null || true + fi + + mkdir -p MNN-iOS-CPU-GPU/Static + cd MNN-iOS-CPU-GPU/Static + + # 构建 x86_64 模拟器版本(尝试构建,失败则跳过) + BUILD_X86_64=false + echo -e "${BLUE}构建 iOS 模拟器版本 (x86_64)...${NC}" + rm -rf ios_simulator_x86 + mkdir ios_simulator_x86 + cd ios_simulator_x86 + if cmake "$PROJECT_ROOT" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ + -DENABLE_BITCODE=0 \ + -DMNN_AAPL_FMWK=1 \ + -DMNN_SEP_BUILD=OFF \ + -DMNN_BUILD_SHARED_LIBS=false \ + -DMNN_USE_THREAD_POOL=OFF \ + -DPLATFORM=SIMULATOR64 \ + 
-DARCHS="x86_64" \ + -DMNN_ARM82=OFF \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_METAL=OFF \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + && make MNN -j8; then + if [ -f "MNN.framework/MNN" ]; then + BUILD_X86_64=true + echo -e "${GREEN}x86_64 模拟器构建成功${NC}" + cd .. + else + echo -e "${YELLOW}警告: x86_64 模拟器构建产物不存在,跳过此架构${NC}" + cd .. + rm -rf ios_simulator_x86 + fi + else + echo -e "${YELLOW}警告: x86_64 模拟器构建失败,跳过此架构(这在 Apple Silicon Mac 上是正常的)${NC}" + cd .. + rm -rf ios_simulator_x86 + fi + + # 构建 arm64 模拟器版本 + echo -e "${BLUE}构建 iOS 模拟器版本 (arm64)...${NC}" + rm -rf ios_simulator_arm64 + mkdir ios_simulator_arm64 + cd ios_simulator_arm64 + if ! cmake "$PROJECT_ROOT" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ + -DENABLE_BITCODE=0 \ + -DMNN_AAPL_FMWK=1 \ + -DMNN_SEP_BUILD=OFF \ + -DMNN_BUILD_SHARED_LIBS=false \ + -DMNN_USE_THREAD_POOL=OFF \ + -DPLATFORM=SIMULATOR64 \ + -DARCHS="arm64" \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_METAL=OFF \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON; then + echo -e "${RED}错误: arm64 模拟器 cmake 配置失败${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + if ! make MNN -j8; then + echo -e "${RED}错误: arm64 模拟器编译失败${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + cd .. + + # 验证构建产物 + if [ ! 
-f "ios_simulator_arm64/MNN.framework/MNN" ]; then + echo -e "${RED}错误: 未找到 arm64 模拟器框架文件${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + + # 合并模拟器架构 + echo -e "${BLUE}合并模拟器架构...${NC}" + rm -rf ios_simulator + mkdir ios_simulator + + if [ "$BUILD_X86_64" = true ] && [ -f "ios_simulator_x86/MNN.framework/MNN" ]; then + # 合并 x86_64 + arm64 + echo -e "${BLUE}合并 x86_64 + arm64 架构...${NC}" + cp -R ios_simulator_x86/MNN.framework ios_simulator/MNN.framework + mv ios_simulator/MNN.framework/MNN ios_simulator/MNN.framework/MNN_x86 + if ! lipo -create ios_simulator/MNN.framework/MNN_x86 ios_simulator_arm64/MNN.framework/MNN -output ios_simulator/MNN.framework/MNN; then + echo -e "${RED}错误: 合并架构失败${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + rm ios_simulator/MNN.framework/MNN_x86 + else + # 仅使用 arm64 + echo -e "${BLUE}仅使用 arm64 架构(x86_64 不可用)...${NC}" + cp -R ios_simulator_arm64/MNN.framework ios_simulator/MNN.framework + fi + + # 输出模拟器版本 + mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" + rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" + cp -R ios_simulator/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" + + # 清理临时目录 + rm -rf ios_simulator ios_simulator_x86 ios_simulator_arm64 + + echo -e "${GREEN}iOS 模拟器构建完成!${NC}" + echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建鸿蒙版本 +# ============================================================================ +if [ "$BUILD_HARMONY" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建鸿蒙版本${NC}" + echo -e "${GREEN}========================================${NC}" + + export HARMONY_HOME + + HARMONY_BUILD_DIR="project/harmony" + HARMONY_OUTPUT_DIR="$OUTPUT_DIR/harmony" + + cd "$HARMONY_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的鸿蒙构建...${NC}" + rm -rf build_64 export + fi + + # 在项目根目录创建输出目录 + mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" + + # 构建 arm64-v8a + 
echo -e "${BLUE}构建 arm64-v8a...${NC}" + mkdir -p build_64 + cd build_64 + + cmake "$PROJECT_ROOT" \ + -DCMAKE_TOOLCHAIN_FILE=$HARMONY_HOME/build/cmake/ohos.toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DOHOS_ARCH="arm64-v8a" \ + -DOHOS_STL=c++_static \ + -DMNN_USE_LOGCAT=true \ + -DMNN_BUILD_BENCHMARK=ON \ + -DMNN_USE_SSE=OFF \ + -DMNN_SUPPORT_BF16=OFF \ + -DMNN_BUILD_TEST=ON \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DMNN_SEP_BUILD=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + -DOHOS_PLATFORM_LEVEL=9 \ + -DNATIVE_LIBRARY_OUTPUT=. \ + -DNATIVE_INCLUDE_OUTPUT=. \ + > /dev/null + + make -j4 MNN > /dev/null + cd .. + + # 导出文件 + echo -e "${BLUE}导出鸿蒙库文件...${NC}" + mkdir -p export/harmony/arm64-v8a/libs export/harmony/include/MNN + + cp build_64/*.so export/harmony/arm64-v8a/libs/ 2>/dev/null || true + cp -r ../../include/MNN/* export/harmony/include/MNN/ + + # 复制到统一输出目录 + # 如果目标路径存在但不是目录,先删除 + if [ -e "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ] && [ ! -d "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ]; then + rm -f "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" + fi + mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" + cp -r export/harmony/* "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR/" + + echo -e "${GREEN}鸿蒙构建完成!${NC}" + echo "输出位置: $HARMONY_OUTPUT_DIR" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建 Python 版本 +# ============================================================================ +if [ "$BUILD_PYTHON" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 Python 版本${NC}" + echo -e "${GREEN}========================================${NC}" + + # 检查 Python + if ! 
command -v $PYTHON_CMD &> /dev/null; then + echo -e "${RED}错误: 找不到 Python 命令 '$PYTHON_CMD'${NC}" + exit 1 + fi + + PYTHON_OUTPUT_DIR="$OUTPUT_DIR/python" + + # 使用之前创建的 build_pymnn.sh 脚本 + if [ -f "$PROJECT_ROOT/build_pymnn.sh" ]; then + PYTHON_BUILD_ARGS="-o $PYTHON_OUTPUT_DIR" + [ -n "$VERSION" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -v $VERSION" + [ -n "$DEPS_OPTIONS" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -d $DEPS_OPTIONS" + [ "$CLEAN_BUILD" = false ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --no-clean" + PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --python $PYTHON_CMD" + + bash "$PROJECT_ROOT/build_pymnn.sh" $PYTHON_BUILD_ARGS + else + echo -e "${YELLOW}警告: 未找到 build_pymnn.sh,使用基本构建方式...${NC}" + + cd pymnn/pip_package + + # 构建依赖 + if [ -n "$DEPS_OPTIONS" ]; then + $PYTHON_CMD build_deps.py $DEPS_OPTIONS + else + $PYTHON_CMD build_deps.py + fi + + # 构建 wheel + $PYTHON_CMD -m pip install -U numpy wheel setuptools --quiet + rm -rf build dist + + BUILD_ARGS="" + [ -n "$VERSION" ] && BUILD_ARGS="--version $VERSION" + [ -n "$DEPS_OPTIONS" ] && BUILD_ARGS="$BUILD_ARGS --deps $DEPS_OPTIONS" + + $PYTHON_CMD setup.py bdist_wheel $BUILD_ARGS + + mkdir -p "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR" + cp dist/*.whl "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR/" + + cd "$PROJECT_ROOT" + fi + + echo -e "${GREEN}Python 构建完成!${NC}" + echo "输出位置: $PYTHON_OUTPUT_DIR" +fi + +# ============================================================================ +# 总结 +# ============================================================================ +echo "" +echo -e "${GREEN}========================================${NC}" +echo -e "${GREEN}所有构建完成!${NC}" +echo -e "${GREEN}========================================${NC}" +echo "" +echo "输出目录: $PROJECT_ROOT/$OUTPUT_DIR" +echo "" +[ "$BUILD_ANDROID" = true ] && echo -e "${GREEN}Android:${NC} $OUTPUT_DIR/android" +[ "$BUILD_IOS" = true ] && echo -e "${GREEN}iOS (真机):${NC} $OUTPUT_DIR/ios/device" +[ "$BUILD_IOS_SIMULATOR" = true ] && echo -e "${GREEN}iOS (模拟器):${NC} 
$OUTPUT_DIR/ios/simulator" +[ "$BUILD_HARMONY" = true ] && echo -e "${GREEN}鸿蒙:${NC} $OUTPUT_DIR/harmony" +[ "$BUILD_PYTHON" = true ] && echo -e "${GREEN}Python:${NC} $OUTPUT_DIR/python" +echo "" + + From 5c65900ddf48d11551c0e551e4269eb3a8a494cd Mon Sep 17 00:00:00 2001 From: HenryDen Date: Fri, 14 Nov 2025 10:06:20 +0800 Subject: [PATCH 002/314] Add a compile option and macro to default enable kleidiAI --- CMakeLists.txt | 1 + source/backend/cpu/arm/CMakeLists.txt | 3 +++ source/core/Backend.hpp | 6 ++++-- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d942aec59..f415ad6c6d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -258,6 +258,7 @@ option(MNN_VULKAN "Enable Vulkan" OFF) option(MNN_ARM82 "Enable ARMv8.2's FP16 Compute" ON) option(MNN_SUPPORT_FP16_ARMV7 "Enable ARMv8.2's FP16 Compute for armv7 arch, may cause library not valid for 32 bit cpu" OFF) option(MNN_KLEIDIAI "Enable KLEIDIAI" ON) +option(MNN_KLEIDIAI_DEFAULT_ON "Use KLEIDIAI kernels by default" OFF) option(MNN_ONEDNN "Enable oneDNN" OFF) option(MNN_AVX2 "Open AVX2 Compile for x86 if possible" ON) option(MNN_AVX512 "Enable AVX512" OFF) diff --git a/source/backend/cpu/arm/CMakeLists.txt b/source/backend/cpu/arm/CMakeLists.txt index 18fca54a4e..61ebce6bdc 100644 --- a/source/backend/cpu/arm/CMakeLists.txt +++ b/source/backend/cpu/arm/CMakeLists.txt @@ -36,6 +36,9 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR AR if (MNN_KLEIDIAI) include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/KleidiAI.cmake) download_kleidiai_and_collect_sources() + if(MNN_KLEIDIAI_DEFAULT_ON) + add_definitions(-DMNN_DEFAULT_USE_KLEIDIAI) + endif() endif() if (MNN_SME2) diff --git a/source/core/Backend.hpp b/source/core/Backend.hpp index 231ea4a137..ce2b450f98 100644 --- a/source/core/Backend.hpp +++ b/source/core/Backend.hpp @@ -66,9 +66,11 @@ struct RuntimeHint { // whether to use Arm sme2 cores when threads>1 bool useArmSme2Cores = true; - +#ifdef 
MNN_DEFAULT_USE_KLEIDIAI + bool enableKleidiAI = true; +#else bool enableKleidiAI = false; - +#endif // Use CPU Ids std::vector cpuIds; }; From a6360ded481d5eab249bbd9ef33b2cd6b583ca29 Mon Sep 17 00:00:00 2001 From: jianglinjun Date: Sat, 22 Nov 2025 01:12:23 +0800 Subject: [PATCH 003/314] =?UTF-8?q?fix:=20opencl=20depthwisedeconv?= =?UTF-8?q?=E7=AE=97=E5=AD=90=E4=BD=8D=E7=BD=AE=E5=AF=B9=E5=BA=94=E5=85=B3?= =?UTF-8?q?=E7=B3=BB=E9=94=99=E8=AF=AF=E5=92=8CKernel=E4=B8=8D=E8=A7=84?= =?UTF-8?q?=E8=8C=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../opencl/execution/cl/depthwise_deconv2d.cl | 31 ++++++++++--------- .../cl/depthwise_deconv2d_mnn_cl.cpp | 27 +++++++++------- .../opencl/execution/cl/opencl_source_map.hpp | 2 +- .../image/DepthwiseDeconvExecution.cpp | 11 ++++--- 4 files changed, 40 insertions(+), 31 deletions(-) diff --git a/source/backend/opencl/execution/cl/depthwise_deconv2d.cl b/source/backend/opencl/execution/cl/depthwise_deconv2d.cl index 42abb86870..8a9f449c1a 100644 --- a/source/backend/opencl/execution/cl/depthwise_deconv2d.cl +++ b/source/backend/opencl/execution/cl/depthwise_deconv2d.cl @@ -1,8 +1,12 @@ +#ifdef MNN_SUPPORT_FP16 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#endif + #define READ_INPUT_IMAGE(i, base) \ int in_width_value##i = in_width##i + base; \ in_width_value##i = \ select(in_idx + in_width_value##i, -1, (in_width_value##i < 0 || in_width_value##i >= input_shape.y)); \ - in##i = read_imagef(input, SAMPLER, (int2)(in_width_value##i, in_hb_value)); + in##i = RI_F(input, SAMPLER, (int2)(in_width_value##i, in_hb_value)); #define CALCULATE_OUTPUT(i) \ out##i = mad(in##i.x, weights0, out##i); \ @@ -38,23 +42,22 @@ __kernel void depthwise_deconv2d(GLOBAL_SIZE_3_DIMS __read_only image2d_t input, DEAL_NON_UNIFORM_DIM3(out_channel_blocks_idx, out_width_idx, out_batch_height_idx); #ifndef NO_BIAS - float4 out0 = read_imagef(bias, SAMPLER, (int2)(out_channel_blocks_idx, 0)); + FLOAT4 
out0 = RI_F(bias, SAMPLER, (int2)(out_channel_blocks_idx, 0)); #else - float4 out0 = (float4)(0.0); + FLOAT4 out0 = (FLOAT4)0; #endif const int out_batch_idx = out_batch_height_idx / output_shape.x; const int out_height_idx = out_batch_height_idx % output_shape.x; - int kernel_start_x = (out_width_idx + align_shape.y) / stride_shape.y; - int kernel_start_y = (out_height_idx + align_shape.x) / stride_shape.x; - + int kernel_start_x = max(0, (out_width_idx + align_shape.y) / stride_shape.y); + int kernel_start_y = max(0, (out_height_idx + align_shape.x) / stride_shape.x); int deal_kernel_width = kernel_shape.y - mad24(kernel_start_x, stride_shape.y, padding_shape.y) + out_width_idx - 1; int deal_kernel_height = kernel_shape.x - mad24(kernel_start_y, stride_shape.x, padding_shape.x) + out_height_idx - 1; int kernel_image_x; - float4 in0; - float4 weight; + FLOAT4 in0; + FLOAT4 weight; int in_width0; int in_idx, in_idy; for (int k_y = deal_kernel_height, idx_h = kernel_start_y; k_y >= 0; k_y -= stride_shape.x, idx_h++) { @@ -67,19 +70,19 @@ __kernel void depthwise_deconv2d(GLOBAL_SIZE_3_DIMS __read_only image2d_t input, READ_INPUT_IMAGE(0, 0); kernel_image_x = mad24(k_y, kernel_shape.y, k_x); - weight = read_imagef(weights, SAMPLER, (int2)(kernel_image_x, out_channel_blocks_idx)); + weight = RI_F(weights, SAMPLER, (int2)(kernel_image_x, out_channel_blocks_idx)); out0 = mad(in0, weight, out0); } + } #ifdef RELU - out0 = fmax(out0, (float4)0); + out0 = fmax(out0, (FLOAT4)0); #endif #ifdef RELU6 - out0 = clamp(out0, (float4)0, (float4)6); + out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6); #endif - const int output_image_x = mad24(out_channel_blocks_idx, output_shape.y, out_width_idx); - write_imagef(output, (int2)(output_image_x, out_batch_height_idx), out0); - } + const int output_image_x = mad24(out_channel_blocks_idx, output_shape.y, out_width_idx); + WI_F(output, (int2)(output_image_x, out_batch_height_idx), out0); } diff --git 
a/source/backend/opencl/execution/cl/depthwise_deconv2d_mnn_cl.cpp b/source/backend/opencl/execution/cl/depthwise_deconv2d_mnn_cl.cpp index 2e950a6f3a..42b03b3174 100644 --- a/source/backend/opencl/execution/cl/depthwise_deconv2d_mnn_cl.cpp +++ b/source/backend/opencl/execution/cl/depthwise_deconv2d_mnn_cl.cpp @@ -1,7 +1,10 @@ #include "opencl_source_map.hpp" namespace MNN { const char* depthwise_deconv2d = -"#define READ_INPUT_IMAGE(i, base) "" int in_width_value##i = in_width##i + base; "" in_width_value##i = "" select(in_idx + in_width_value##i, -1, (in_width_value##i < 0 || in_width_value##i >= input_shape.y)); "" in##i=read_imagef(input,SAMPLER,(int2)(in_width_value##i,in_hb_value));\n" +"#ifdef MNN_SUPPORT_FP16\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" +"#endif\n" +"#define READ_INPUT_IMAGE(i, base) "" int in_width_value##i = in_width##i + base; "" in_width_value##i = "" select(in_idx + in_width_value##i, -1, (in_width_value##i < 0 || in_width_value##i >= input_shape.y)); "" in##i=RI_F(input,SAMPLER,(int2)(in_width_value##i,in_hb_value));\n" "#define CALCULATE_OUTPUT(i) "" out##i = mad(in##i.x, weights0, out##i); "" out##i = mad(in##i.y, weights1, out##i); "" out##i = mad(in##i.z, weights2, out##i); "" out##i=mad(in##i.w,weights3,out##i);\n" "#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" "#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n" @@ -24,19 +27,19 @@ const char* depthwise_deconv2d = " const int out_batch_height_idx=get_global_id(2);\n" " DEAL_NON_UNIFORM_DIM3(out_channel_blocks_idx,out_width_idx,out_batch_height_idx);\n" " #ifndef NO_BIAS\n" -" float4 out0=read_imagef(bias,SAMPLER,(int2)(out_channel_blocks_idx,0));\n" +" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(out_channel_blocks_idx,0));\n" " #else\n" -" float4 
out0=(float4)(0.0);\n" +" FLOAT4 out0=(FLOAT4)0;\n" " #endif\n" " const int out_batch_idx=out_batch_height_idx/output_shape.x;\n" " const int out_height_idx=out_batch_height_idx % output_shape.x;\n" -" int kernel_start_x=(out_width_idx+align_shape.y)/stride_shape.y;\n" -" int kernel_start_y=(out_height_idx+align_shape.x)/stride_shape.x;\n" +" int kernel_start_x=max(0,(out_width_idx+align_shape.y)/stride_shape.y);\n" +" int kernel_start_y=max(0,(out_height_idx+align_shape.x)/stride_shape.x);\n" " int deal_kernel_width=kernel_shape.y-mad24(kernel_start_x,stride_shape.y,padding_shape.y)+out_width_idx-1;\n" " int deal_kernel_height=kernel_shape.x-mad24(kernel_start_y,stride_shape.x,padding_shape.x)+out_height_idx-1;\n" " int kernel_image_x;\n" -" float4 in0;\n" -" float4 weight;\n" +" FLOAT4 in0;\n" +" FLOAT4 weight;\n" " int in_width0;\n" " int in_idx,in_idy;\n" " for (int k_y=deal_kernel_height,idx_h=kernel_start_y; k_y >= 0; k_y -= stride_shape.x,idx_h++) {\n" @@ -47,18 +50,18 @@ const char* depthwise_deconv2d = " in_idx=mul24(out_channel_blocks_idx,input_shape.y);\n" " READ_INPUT_IMAGE(0,0);\n" " kernel_image_x=mad24(k_y,kernel_shape.y,k_x);\n" -" weight=read_imagef(weights,SAMPLER,(int2)(kernel_image_x,out_channel_blocks_idx));\n" +" weight=RI_F(weights,SAMPLER,(int2)(kernel_image_x,out_channel_blocks_idx));\n" " out0=mad(in0,weight,out0);\n" " }\n" +" }\n" "#ifdef RELU\n" -" out0=fmax(out0,(float4)0);\n" +" out0=fmax(out0,(FLOAT4)0);\n" "#endif\n" "#ifdef RELU6\n" -" out0=clamp(out0,(float4)0,(float4)6);\n" +" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n" "#endif\n" " const int output_image_x=mad24(out_channel_blocks_idx,output_shape.y,out_width_idx);\n" -" write_imagef(output,(int2)(output_image_x,out_batch_height_idx),out0);\n" -" }\n" +" WI_F(output,(int2)(output_image_x,out_batch_height_idx),out0);\n" "}\n" ; } diff --git a/source/backend/opencl/execution/cl/opencl_source_map.hpp b/source/backend/opencl/execution/cl/opencl_source_map.hpp index 
d799ac78c3..ecc30b112e 100644 --- a/source/backend/opencl/execution/cl/opencl_source_map.hpp +++ b/source/backend/opencl/execution/cl/opencl_source_map.hpp @@ -394,7 +394,7 @@ const std::map OpenCLProgramMd5Map = { "groupnorm_buf", "7f4b041b77ba98165ab624d94444f327" }, { "unary_subgroup_buf", "31e3768f899da6da45084f617b13c282" }, { "gemm", "5729018147348682e02762ed5ec14d0c" }, - { "depthwise_deconv2d", "5a3e5498276638d6b73cf7b5e19bd750" }, + { "depthwise_deconv2d", "810f69205dede9b38e4858aad621fa71" }, { "range", "97feaf25d837a325382c162ad77ae0ca" }, { "scale_buf", "9176b8e86fd4d326e7fa14640ce13b48" }, { "matmul_buf", "b66faece7f0591d49c289e5227d9f680" }, diff --git a/source/backend/opencl/execution/image/DepthwiseDeconvExecution.cpp b/source/backend/opencl/execution/image/DepthwiseDeconvExecution.cpp index 6beede6a49..7bdf56f205 100644 --- a/source/backend/opencl/execution/image/DepthwiseDeconvExecution.cpp +++ b/source/backend/opencl/execution/image/DepthwiseDeconvExecution.cpp @@ -110,13 +110,16 @@ ErrorCode DepthwiseDeconvExecution::onEncode(const std::vector &inputs const int paddingHeight = pad.second; const int paddingWidth = pad.first; - const int alignHeight = strideHeight - 1 - paddingHeight; - const int alignWidth = strideWidth - 1 - paddingWidth; - const int filterHeight = mResource->mConv2dCommonParams->kernelY(); const int filterWidth = mResource->mConv2dCommonParams->kernelX(); const int kernelSize = filterHeight * filterWidth; + const int transPadH = filterHeight - 1 - pad.second; + const int transPadW = filterWidth - 1 - pad.first; + + const int alignHeight = strideHeight - 1 - transPadH; + const int alignWidth = strideWidth - 1 - transPadW; + mGWS = {static_cast(channelBlocks), static_cast(outputWidth), static_cast(outputHeight * outputBatch)}; std::string info = std::to_string(inputChannels) + "_" + std::to_string(outputChannels) + "_" + std::to_string(filterHeight) + "_" + std::to_string(filterWidth) + "_" + std::to_string(strideHeight) + "_" + 
std::to_string(strideWidth); @@ -127,7 +130,7 @@ ErrorCode DepthwiseDeconvExecution::onEncode(const std::vector &inputs int inputImageShape[2] = {inputHeight, inputWidth}; int outputImageShape[2] = {outputHeight, outputWidth}; int strideShape[2] = {strideHeight, strideWidth}; - int paddingShape[2] = {paddingHeight, paddingWidth}; + int paddingShape[2] = {transPadH, transPadW}; int alignShape[2] = {alignHeight, alignWidth}; int kernelShape[2] = {filterHeight, filterWidth}; From 4f5a68b3fc4fac794cb7e1afe210073614c78ea9 Mon Sep 17 00:00:00 2001 From: jianglinjun Date: Sat, 22 Nov 2025 01:12:34 +0800 Subject: [PATCH 004/314] =?UTF-8?q?fix:=20vulkan=20image=20barrier?= =?UTF-8?q?=E7=BC=BA=E5=B0=91=E5=8F=82=E6=95=B0=EF=BC=8C=E4=BC=9A=E5=AF=BC?= =?UTF-8?q?=E8=87=B4Adreno=20(TM)=20830=E5=BF=85=E7=8E=B0=E5=8D=B7?= =?UTF-8?q?=E7=A7=AF=E7=BB=93=E6=9E=9C=E5=BC=82=E5=B8=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- source/backend/vulkan/component/VulkanImage.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/source/backend/vulkan/component/VulkanImage.cpp b/source/backend/vulkan/component/VulkanImage.cpp index ef1c66a143..21f17c7952 100644 --- a/source/backend/vulkan/component/VulkanImage.cpp +++ b/source/backend/vulkan/component/VulkanImage.cpp @@ -105,12 +105,16 @@ void VulkanImage::insertMemoryBarrier( ) { VkImageMemoryBarrier imageMemoryBarrier; ::memset(&imageMemoryBarrier, 0, sizeof(VkImageMemoryBarrier)); + imageMemoryBarrier.sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER; + imageMemoryBarrier.pNext = nullptr; imageMemoryBarrier.srcAccessMask = srcAccessMask; imageMemoryBarrier.dstAccessMask = dstAccessMask; imageMemoryBarrier.oldLayout = oldImageLayout; imageMemoryBarrier.newLayout = newImageLayout; imageMemoryBarrier.image = image; imageMemoryBarrier.subresourceRange = subresourceRange; + imageMemoryBarrier.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + imageMemoryBarrier.dstQueueFamilyIndex = 
VK_QUEUE_FAMILY_IGNORED; vkCmdPipelineBarrier( cmdbuffer, From 0a5ee52e7d209b3b5735419f52dd5dd366657478 Mon Sep 17 00:00:00 2001 From: ihb2032 <1355790728@qq.com> Date: Mon, 24 Nov 2025 07:12:19 +0000 Subject: [PATCH 005/314] opt(RVV): Optimize pack and unpack functions with intrinsics Optimize MNNPackC4, MNNPackC2 and MNNUnpackC4 using RVV intrinsics. Signed-off-by: ihb2032 <1355790728@qq.com> Co-authored-by: lyd1992 --- source/backend/cpu/riscv/rvv/MNNPackC2.cpp | 74 ++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNPackC4.cpp | 80 ++++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp | 55 ++++++++++++++ 3 files changed, 209 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNPackC2.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNPackC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNPackC2.cpp b/source/backend/cpu/riscv/rvv/MNNPackC2.cpp new file mode 100644 index 0000000000..9a74f8998d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNPackC2.cpp @@ -0,0 +1,74 @@ +#include + +void MNNPackC2(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { + int depthC2 = depth / 2; + int depthRemain = depthC2 * 2; + int remain = depth - depthRemain; + const float *srcOffset = src; + const float *srcChannel[2]; + + for (int z = 0; z < depthC2; ++z) { + float *dstZ = dst + z * areaOffset[1] * 2; + + for (int y = 0; y < 2; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 2; + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 0, 2 * sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 1, 2 * sizeof(float), vec, vl); + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 2; + dstPtr[0] = srcChannel[0][x]; + 
dstPtr[1] = srcChannel[1][x]; + } + + srcOffset += areaOffset[0] * 2; + } + + if (remain > 0) { + float *dstZ = dst + depthC2 * areaOffset[1] * 2; + + for (int y = 0; y < remain; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 2; + + for (int y = 0; y < remain; ++y) { + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + y, 2 * sizeof(float), vec, vl); + } + + vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); + for (int y = remain; y < 2; ++y) { + __riscv_vsse32_v_f32m8(dstPtr + y, 2 * sizeof(float), zero, vl); + } + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 2; + + for (int y = 0; y < remain; ++y) { + dstPtr[y] = srcChannel[y][x]; + } + + for (int y = remain; y < 2; ++y) { + dstPtr[y] = 0.0f; + } + } + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNPackC4.cpp b/source/backend/cpu/riscv/rvv/MNNPackC4.cpp new file mode 100644 index 0000000000..024e2c8c07 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNPackC4.cpp @@ -0,0 +1,80 @@ +#include + +void MNNPackC4(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { + int depthC4 = depth / 4; + int depthRemain = depthC4 * 4; + int remain = depth - depthRemain; + const float *srcOffset = src; + const float *srcChannel[4]; + + for (int z = 0; z < depthC4; ++z) { + float *dstZ = dst + z * areaOffset[1] * 4; + + for (int y = 0; y < 4; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 4; + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 0, 4 * sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 1, 4 * sizeof(float), vec, vl); + vec = 
__riscv_vle32_v_f32m8(srcChannel[2] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 2, 4 * sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[3] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 3, 4 * sizeof(float), vec, vl); + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 4; + dstPtr[0] = srcChannel[0][x]; + dstPtr[1] = srcChannel[1][x]; + dstPtr[2] = srcChannel[2][x]; + dstPtr[3] = srcChannel[3][x]; + } + + srcOffset += areaOffset[0] * 4; + } + + if (remain > 0) { + float *dstZ = dst + depthC4 * areaOffset[1] * 4; + + for (int y = 0; y < remain; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 4; + + for (int y = 0; y < remain; ++y) { + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), vec, vl); + } + + vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); + for (int y = remain; y < 4; ++y) { + __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), zero, vl); + } + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 4; + + for (int y = 0; y < remain; ++y) { + dstPtr[y] = srcChannel[y][x]; + } + + for (int y = remain; y < 4; ++y) { + dstPtr[y] = 0.0f; + } + } + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp b/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp new file mode 100644 index 0000000000..4676e6dede --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp @@ -0,0 +1,55 @@ +#include + +void MNNUnpackC4(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { + int depthC4 = depth / 4; + int depthRemain = depthC4 * 4; + int remain = depth - depthRemain; + const float *srcOffset = src; + + for (int z = 0; z < depthC4; ++z) { + float *dstZ[4]; + + for (int y = 0; y < 4; ++y) { + dstZ[y] = dst + (z * 4 + y) * areaOffset[1]; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + 
for (; x + vl <= area; x += vl) { + vfloat32m8_t vec = __riscv_vlse32_v_f32m8(srcOffset + 0, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[0] + x, vec, vl); + vec = __riscv_vlse32_v_f32m8(srcOffset + 1, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[1] + x, vec, vl); + vec = __riscv_vlse32_v_f32m8(srcOffset + 2, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[2] + x, vec, vl); + vec = __riscv_vlse32_v_f32m8(srcOffset + 3, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[3] + x, vec, vl); + srcOffset += 4 * vl; + } + + for (; x < area; ++x) { + dstZ[0][x] = srcOffset[0]; + dstZ[1][x] = srcOffset[1]; + dstZ[2][x] = srcOffset[2]; + dstZ[3][x] = srcOffset[3]; + srcOffset += (areaOffset[0] - area) * 4; + } + } + + if (remain > 0) { + float *dstZ = dst + depthC4 * areaOffset[1] * 4; + const float *srcBase = srcOffset; + + for (int y = 0; y < remain; ++y) { + float *dstChannel = dstZ + y * areaOffset[1]; + const float *srcChannel = srcBase + y; + + for (size_t x = 0; x < area; ++x) { + dstChannel[x] = srcChannel[0]; + srcChannel += 4; + } + } + } +} + From 89f47a28769596c3e2240c2cf53c509a134d3ce5 Mon Sep 17 00:00:00 2001 From: ihb2032 <1355790728@qq.com> Date: Tue, 25 Nov 2025 06:33:36 +0000 Subject: [PATCH 006/314] opt(RVV): Optimize transpose functions with intrinsics Optimize MNNTranspose16Bit and MNNTranspose32Bit using RVV intrinsics. 
Signed-off-by: ihb2032 <1355790728@qq.com> Co-authored-by: lyd1992 --- .../cpu/riscv/rvv/MNNTranspose16Bit.cpp | 26 +++++++++++++++++++ .../cpu/riscv/rvv/MNNTranspose32Bit.cpp | 25 ++++++++++++++++++ 2 files changed, 51 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp b/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp new file mode 100644 index 0000000000..7598d6f8ac --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp @@ -0,0 +1,26 @@ +#include + +void MNNTranspose16Bit(int16_t* dstO, const int16_t* srcO, int16_t* dim) { + int w = dim[0]; + int h = dim[1]; + int srcStride = dim[2]; + int dstStride = dim[3]; + ptrdiff_t srcStrideByte = srcStride * sizeof(int16_t); + + for (int i = 0; i < h; ++i) { + const int16_t* srcPtr = srcO + i; + int16_t* dstPtr = dstO + i * dstStride; + + int j = 0; + while (j < w) { + size_t vl = __riscv_vsetvl_e16m8(w - j); + vint16m8_t data = __riscv_vlse16_v_i16m8(srcPtr, srcStrideByte, vl); + __riscv_vse16_v_i16m8(dstPtr, data, vl); + srcPtr += vl * srcStride; + dstPtr += vl; + j += vl; + } + } +} + + diff --git a/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp b/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp new file mode 100644 index 0000000000..e5c5eb83e6 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp @@ -0,0 +1,25 @@ +#include + +void MNNTranspose32Bit(int32_t* dstO, const int32_t* srcO, int32_t* dim) { + int w = dim[0]; + int h = dim[1]; + int srcStride = dim[2]; + int dstStride = dim[3]; + ptrdiff_t srcStrideByte = srcStride * sizeof(int32_t); + + for (int i = 0; i < h; ++i) { + const int32_t* srcPtr = srcO + i; + int32_t* dstPtr = dstO + i * dstStride; + + int j = 0; + while (j < w) { + size_t vl = __riscv_vsetvl_e32m8(w - j); + vint32m8_t data = __riscv_vlse32_v_i32m8(srcPtr, srcStrideByte, vl); + 
__riscv_vse32_v_i32m8(dstPtr, data, vl); + srcPtr += vl * srcStride; + dstPtr += vl; + j += vl; + } + } +} + From aaf5b231cf9bc44361ad2c129100bdbcc444df05 Mon Sep 17 00:00:00 2001 From: ihb2032 <1355790728@qq.com> Date: Wed, 26 Nov 2025 08:46:18 +0000 Subject: [PATCH 007/314] opt(RVV): Optimize core math and stride functions with intrinsics Optimize the following functions using RVV intrinsics: MNNAxByClampBroadcastUnit, MNNScaleAndAddBias, MNNCopyC4WithStride, MNNAddC4WithStride Signed-off-by: ihb2032 <1355790728@qq.com> Co-authored-by: lyd1992 --- .../cpu/riscv/rvv/MNNAddC4WithStride.cpp | 29 +++++++++++ .../riscv/rvv/MNNAxByClampBroadcastUnit.cpp | 52 +++++++++++++++++++ .../cpu/riscv/rvv/MNNCopyC4WithStride.cpp | 22 ++++++++ .../cpu/riscv/rvv/MNNScaleAndAddBias.cpp | 42 +++++++++++++++ 4 files changed, 145 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp b/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp new file mode 100644 index 0000000000..59bb28a039 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp @@ -0,0 +1,29 @@ +#include + +void MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { + ptrdiff_t srcStrideByte = srcStride * sizeof(float); + ptrdiff_t dstStrideByte = dstStride * sizeof(float); + size_t vl; + + for (size_t i = count; i > 0; i -= vl) { + vl = __riscv_vsetvl_e32m8(i); + vfloat32m8_t vs = __riscv_vlse32_v_f32m8(source + 0, srcStrideByte, vl); + vfloat32m8_t vd = __riscv_vlse32_v_f32m8(dest + 0, dstStrideByte, vl); + vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); + __riscv_vsse32_v_f32m8(dest + 0, dstStrideByte, vd, vl); + vs = __riscv_vlse32_v_f32m8(source + 
1, srcStrideByte, vl); + vd = __riscv_vlse32_v_f32m8(dest + 1, dstStrideByte, vl); + vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); + __riscv_vsse32_v_f32m8(dest + 1, dstStrideByte, vd, vl); + vs = __riscv_vlse32_v_f32m8(source + 2, srcStrideByte, vl); + vd = __riscv_vlse32_v_f32m8(dest + 2, dstStrideByte, vl); + vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); + __riscv_vsse32_v_f32m8(dest + 2, dstStrideByte, vd, vl); + vs = __riscv_vlse32_v_f32m8(source + 3, srcStrideByte, vl); + vd = __riscv_vlse32_v_f32m8(dest + 3, dstStrideByte, vl); + vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); + __riscv_vsse32_v_f32m8(dest + 3, dstStrideByte, vd, vl); + source += vl * srcStride; + dest += vl * dstStride; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp b/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp new file mode 100644 index 0000000000..6d966789f7 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp @@ -0,0 +1,52 @@ +#include + +void MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) { + float beta = parameters[1]; + float minF = parameters[2]; + float maxF = parameters[3]; + const ptrdiff_t stride = 4 * sizeof(float); + + for (int y = 0; y < height; ++y) { + auto a = A + aStride * y; + auto b = B + 4 * y; + auto c = C + cStride * y; + float b0Beta = b[0] * beta; + float b1Beta = b[1] * beta; + float b2Beta = b[2] * beta; + float b3Beta = b[3] * beta; + size_t w = width; + + while (w > 0) { + size_t vl = __riscv_vsetvl_e32m8(w); + + vfloat32m8_t data = __riscv_vlse32_v_f32m8(a + 0, stride, vl); + data = __riscv_vfadd_vf_f32m8(data, b0Beta, vl); + data = __riscv_vfmax_vf_f32m8(data, minF, vl); + data = __riscv_vfmin_vf_f32m8(data, maxF, vl); + __riscv_vsse32_v_f32m8(c + 0, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(a + 1, stride, vl); + data = __riscv_vfadd_vf_f32m8(data, b1Beta, vl); + data = 
__riscv_vfmax_vf_f32m8(data, minF, vl); + data = __riscv_vfmin_vf_f32m8(data, maxF, vl); + __riscv_vsse32_v_f32m8(c + 1, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(a + 2, stride, vl); + data = __riscv_vfadd_vf_f32m8(data, b2Beta, vl); + data = __riscv_vfmax_vf_f32m8(data, minF, vl); + data = __riscv_vfmin_vf_f32m8(data, maxF, vl); + __riscv_vsse32_v_f32m8(c + 2, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(a + 3, stride, vl); + data = __riscv_vfadd_vf_f32m8(data, b3Beta, vl); + data = __riscv_vfmax_vf_f32m8(data, minF, vl); + data = __riscv_vfmin_vf_f32m8(data, maxF, vl); + __riscv_vsse32_v_f32m8(c + 3, stride, data, vl); + + a += 4 * vl; + c += 4 * vl; + w -= vl; + } + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp b/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp new file mode 100644 index 0000000000..3d8c4f13fc --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp @@ -0,0 +1,22 @@ +#include + +void MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { + ptrdiff_t srcStrideByte = srcStride * sizeof(float); + ptrdiff_t dstStrideByte = dstStride * sizeof(float); +size_t vl; + + for (size_t i = count; i > 0; i -= vl) { + vl = __riscv_vsetvl_e32m8(i); + vfloat32m8_t data = __riscv_vlse32_v_f32m8(source + 0, srcStrideByte, vl); + __riscv_vsse32_v_f32m8(dest + 0, dstStrideByte, data, vl); + data = __riscv_vlse32_v_f32m8(source + 1, srcStrideByte, vl); + __riscv_vsse32_v_f32m8(dest + 1, dstStrideByte, data, vl); + data = __riscv_vlse32_v_f32m8(source + 2, srcStrideByte, vl); + __riscv_vsse32_v_f32m8(dest + 2, dstStrideByte, data, vl); + data = __riscv_vlse32_v_f32m8(source + 3, srcStrideByte, vl); + __riscv_vsse32_v_f32m8(dest + 3, dstStrideByte, data, vl); + source += vl * srcStride; + dest += vl * dstStride; + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp b/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp new file mode 
100644 index 0000000000..10992f9d59 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp @@ -0,0 +1,42 @@ +#include + +void MNNScaleAndAddBias(float *dst, const float *src, const float *bias, const float *alpha, size_t planeNumber, size_t biasNumber) { + const ptrdiff_t stride = 4 * sizeof(float); + + for (size_t z = 0; z < biasNumber; ++z) { + float *dstZ = dst + z * planeNumber * 4; + const float *srcZ = src + z * planeNumber * 4; + const float *biasZ = bias + 4 * z; + const float *alphaZ = alpha + 4 * z; + float b0 = biasZ[0], b1 = biasZ[1], b2 = biasZ[2], b3 = biasZ[3]; + float a0 = alphaZ[0], a1 = alphaZ[1], a2 = alphaZ[2], a3 = alphaZ[3]; + + size_t n = planeNumber; + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t data = __riscv_vlse32_v_f32m8(srcZ + 0, stride, vl); + data = __riscv_vfmul_vf_f32m8(data, a0, vl); + data = __riscv_vfadd_vf_f32m8(data, b0, vl); + __riscv_vsse32_v_f32m8(dstZ + 0, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(srcZ + 1, stride, vl); + data = __riscv_vfmul_vf_f32m8(data, a1, vl); + data = __riscv_vfadd_vf_f32m8(data, b1, vl); + __riscv_vsse32_v_f32m8(dstZ + 1, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(srcZ + 2, stride, vl); + data = __riscv_vfmul_vf_f32m8(data, a2, vl); + data = __riscv_vfadd_vf_f32m8(data, b2, vl); + __riscv_vsse32_v_f32m8(dstZ + 2, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(srcZ + 3, stride, vl); + data = __riscv_vfmul_vf_f32m8(data, a3, vl); + data = __riscv_vfadd_vf_f32m8(data, b3, vl); + __riscv_vsse32_v_f32m8(dstZ + 3, stride, data, vl); + + srcZ += vl * 4; + dstZ += vl * 4; + n -= vl; + } + } +} From 90ad653a6abb5232b0dca6e45f608bf63cc75079 Mon Sep 17 00:00:00 2001 From: hacksang <985438046@qq.com> Date: Wed, 26 Nov 2025 23:01:59 +0800 Subject: [PATCH 008/314] fix a bug in compute mGroupWithComputeRate --- source/backend/cpu/CPUBackend.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/source/backend/cpu/CPUBackend.cpp 
b/source/backend/cpu/CPUBackend.cpp index ceb23910fd..64dd35008b 100644 --- a/source/backend/cpu/CPUBackend.cpp +++ b/source/backend/cpu/CPUBackend.cpp @@ -499,6 +499,7 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p currentRate *= decreaseRate; totalComputeRate += currentRate * selectSize; mGroupWithComputeRate.emplace_back(std::make_pair(currentRate * selectSize, selectSize)); + groupIndex--; } for (auto& g : mGroupWithComputeRate) { g.first = g.first / totalComputeRate; From 09c339c65cb73dda7c055c3e49e9026dfd998e3e Mon Sep 17 00:00:00 2001 From: ihb2032 <1355790728@qq.com> Date: Mon, 1 Dec 2025 01:21:02 +0000 Subject: [PATCH 009/314] opt(RVV): Optimize max and min float functions with intrinsics Optimize MNNMaxFloat and MNNMinFloat using RVV intrinsics. Signed-off-by: ihb2032 <1355790728@qq.com> Co-authored-by: lyd1992 --- source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp | 25 ++++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNMinFloat.cpp | 25 ++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNMinFloat.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp b/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp new file mode 100644 index 0000000000..183a38bb10 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp @@ -0,0 +1,25 @@ +#include +#include + +#define UNIT 4 + +void MNNMaxFloat(float *input, float *maxBuffer, int32_t inputCountUnit) { + const float init = -FLT_MAX; + for (int j = 0; j < UNIT; ++j) { + float local = init; + size_t i = 0; + + while (i < (size_t)inputCountUnit) { + size_t vl = __riscv_vsetvl_e32m8(inputCountUnit - i); + float *p0 = input + (i * UNIT * 2) + j * 2; + float *p1 = p0 + 1; + vfloat32m8_t v0 = __riscv_vlse32_v_f32m8(p0, UNIT * 2 * sizeof(float), vl); + vfloat32m8_t v1 = __riscv_vlse32_v_f32m8(p1, UNIT * 2 * sizeof(float), vl); + vfloat32m8_t vmax = 
__riscv_vfmax_vv_f32m8(v0, v1, vl); + vfloat32m1_t vred = __riscv_vfredmax_vs_f32m8_f32m1(vmax, __riscv_vfmv_s_f_f32m1(local, 1), vl); + local = __riscv_vfmv_f_s_f32m1_f32(vred); + i += vl; + } + maxBuffer[j] = local; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp b/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp new file mode 100644 index 0000000000..9e8ade8641 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp @@ -0,0 +1,25 @@ +#include +#include + +#define UNIT 4 + +void MNNMinFloat(float *input, float *minBuffer, int32_t inputCountUnit) { + const float init = FLT_MAX; + for (int j = 0; j < UNIT; ++j) { + float local = init; + size_t i = 0; + + while (i < (size_t)inputCountUnit) { + size_t vl = __riscv_vsetvl_e32m8(inputCountUnit - i); + float *p0 = input + (i * UNIT * 2) + j * 2; + float *p1 = p0 + 1; + vfloat32m8_t v0 = __riscv_vlse32_v_f32m8(p0, UNIT * 2 * sizeof(float), vl); + vfloat32m8_t v1 = __riscv_vlse32_v_f32m8(p1, UNIT * 2 * sizeof(float), vl); + vfloat32m8_t vmin = __riscv_vfmin_vv_f32m8(v0, v1, vl); + vfloat32m1_t vred = __riscv_vfredmin_vs_f32m8_f32m1(vmin, __riscv_vfmv_s_f_f32m1(local, 1), vl); + local = __riscv_vfmv_f_s_f32m1_f32(vred); + i += vl; + } + minBuffer[j] = local; + } +} From 672c5862392393c171f1513bf7994d3b95e2a6a1 Mon Sep 17 00:00:00 2001 From: ihb2032 <1355790728@qq.com> Date: Wed, 3 Dec 2025 07:06:07 +0000 Subject: [PATCH 010/314] opt(RVV): Optimize conv and strassen functions with intrinsics Optimize MNNConvRunForLineDepthwise, MNNDeconvRunForUnitDepthWise and MNNStrassenMergeCFunction using RVV intrinsics. 
Signed-off-by: ihb2032 <1355790728@qq.com> Co-authored-by: lyd1992 --- .../riscv/rvv/MNNConvRunForLineDepthwise.cpp | 48 +++++++++++++++++++ .../rvv/MNNDeconvRunForUnitDepthWise.cpp | 42 ++++++++++++++++ .../riscv/rvv/MNNStrassenMergeCFunction.cpp | 36 ++++++++++++++ 3 files changed, 126 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp b/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp new file mode 100644 index 0000000000..f82faf83f5 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp @@ -0,0 +1,48 @@ +#include + +void MNNConvRunForLineDepthwise( + float* dst, const float* src, const float* weight, + size_t width, size_t src_w_setup, + size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, + size_t height, size_t srcHStep, size_t dstHStep, + const float* bias, const float* parameters) { + float minV = parameters[0]; + float maxV = parameters[1]; + ptrdiff_t srcByteStride = src_w_setup * sizeof(float); + ptrdiff_t dstByteStride = 4 * sizeof(float); + + for (size_t y = 0; y < height; ++y) { + const float* srcY = src + y * srcHStep; + float* dstY = dst + y * dstHStep; + size_t dx = 0; + + while (dx < width) { + size_t vl = __riscv_vsetvl_e32m8(width - dx); + + for (int c = 0; c < 4; ++c) { + vfloat32m8_t acc = __riscv_vfmv_v_f_f32m8(bias[c], vl); + const float* srcBase = srcY + dx * src_w_setup + c; + const float* weightPtr = weight + c; + + for (size_t fy = 0; fy < fh; ++fy) { + const float* srcFy = srcBase + fy * dilateY_step; + + for (size_t fx = 0; fx < fw; ++fx) { + float w = *weightPtr; + weightPtr += 4; + const float* srcFx = srcFy + fx * dilateX_step; + vfloat32m8_t s = __riscv_vlse32_v_f32m8(srcFx, srcByteStride, vl); + acc = 
__riscv_vfmacc_vf_f32m8(acc, w, s, vl); + } + } + + acc = __riscv_vfmax_vf_f32m8(acc, minV, vl); + acc = __riscv_vfmin_vf_f32m8(acc, maxV, vl); + float* dstAddr = dstY + dx * 4 + c; + __riscv_vsse32_v_f32m8(dstAddr, dstByteStride, acc, vl); + } + + dx += vl; + } + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp b/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp new file mode 100644 index 0000000000..6658715e7e --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp @@ -0,0 +1,42 @@ +#include + +void MNNDeconvRunForUnitDepthWise( + const float* dst, float* src, const float* weight, + size_t fw, size_t fh, + size_t weightY_step, size_t dilateX_step, size_t dilateY_step) { + const ptrdiff_t wStride = 4 * sizeof(float); + const ptrdiff_t sStride = dilateX_step * sizeof(float); + float d0 = dst[0], d1 = dst[1], d2 = dst[2], d3 = dst[3]; + + for (size_t fy = 0; fy < fh; ++fy) { + float* srcY = src + fy * dilateY_step; + const float* weightY = weight + fy * weightY_step; + + size_t fx = 0; + while (fx < fw) { + size_t vl = __riscv_vsetvl_e32m8(fw - fx); + + vfloat32m8_t w = __riscv_vlse32_v_f32m8(weightY + 0 + fx * 4, wStride, vl); + vfloat32m8_t s = __riscv_vlse32_v_f32m8(srcY + 0 + fx * dilateX_step, sStride, vl); + s = __riscv_vfmacc_vf_f32m8(s, d0, w, vl); + __riscv_vsse32_v_f32m8(srcY + 0 + fx * dilateX_step, sStride, s, vl); + + w = __riscv_vlse32_v_f32m8(weightY + 1 + fx * 4, wStride, vl); + s = __riscv_vlse32_v_f32m8(srcY + 1 + fx * dilateX_step, sStride, vl); + s = __riscv_vfmacc_vf_f32m8(s, d1, w, vl); + __riscv_vsse32_v_f32m8(srcY + 1 + fx * dilateX_step, sStride, s, vl); + + w = __riscv_vlse32_v_f32m8(weightY + 2 + fx * 4, wStride, vl); + s = __riscv_vlse32_v_f32m8(srcY + 2 + fx * dilateX_step, sStride, vl); + s = __riscv_vfmacc_vf_f32m8(s, d2, w, vl); + __riscv_vsse32_v_f32m8(srcY + 2 + fx * dilateX_step, sStride, s, vl); + + w = __riscv_vlse32_v_f32m8(weightY + 3 + fx * 4, wStride, vl); + 
s = __riscv_vlse32_v_f32m8(srcY + 3 + fx * dilateX_step, sStride, vl); + s = __riscv_vfmacc_vf_f32m8(s, d3, w, vl); + __riscv_vsse32_v_f32m8(srcY + 3 + fx * dilateX_step, sStride, s, vl); + + fx += vl; + } + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp b/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp new file mode 100644 index 0000000000..8ab5bb89fa --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp @@ -0,0 +1,36 @@ +#include + +void MNNStrassenMergeCFunction(float *c11, float *c12, float *c21, float *c22, + float *xAddr, size_t cStride, size_t eSub, size_t hSub) { + for (int y = 0; y < hSub; ++y) { + float *c11Y = c11 + y * cStride; + float *c12Y = c12 + y * cStride; + float *c22Y = c22 + y * cStride; + float *c21Y = c21 + y * cStride; + float *xY = xAddr + y * eSub * 4; + size_t totalElements = eSub * 4; + size_t p = 0; + + while (p < totalElements) { + size_t vl = __riscv_vsetvl_e32m8(totalElements - p); + vfloat32m8_t t = __riscv_vle32_v_f32m8(xY + p, vl); + vfloat32m8_t tmp = __riscv_vle32_v_f32m8(c12Y + p, vl); + t = __riscv_vfadd_vv_f32m8(t, tmp, vl); + vfloat32m8_t c22v = __riscv_vle32_v_f32m8(c22Y + p, vl); + + tmp = __riscv_vle32_v_f32m8(c11Y + p, vl); + tmp = __riscv_vfadd_vv_f32m8(tmp, c22v, vl); + tmp = __riscv_vfadd_vv_f32m8(tmp, t, vl); + __riscv_vse32_v_f32m8(c12Y + p, tmp, vl); + + tmp = __riscv_vle32_v_f32m8(c21Y + p, vl); + tmp = __riscv_vfadd_vv_f32m8(t, tmp, vl); + __riscv_vse32_v_f32m8(c21Y + p, tmp, vl); + + c22v = __riscv_vfadd_vv_f32m8(c22v, tmp, vl); + __riscv_vse32_v_f32m8(c22Y + p, c22v, vl); + + p += vl; + } + } +} From b7268aa3dab754190ed7d86dc16b9bab02e73d12 Mon Sep 17 00:00:00 2001 From: ihb2032 <1355790728@qq.com> Date: Fri, 5 Dec 2025 06:38:57 +0000 Subject: [PATCH 011/314] opt(RVV): Optimize Softmax and ReluWithSlopeChannel with intrinsics Optimize MNNSoftmax and MNNReluWithSlopeChannel using RVV intrinsics. 
Signed-off-by: ihb2032 <1355790728@qq.com> Co-authored-by: lyd1992 --- .../cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp | 45 +++++++++++ source/backend/cpu/riscv/rvv/MNNSoftmax.cpp | 80 +++++++++++++++++++ 2 files changed, 125 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNSoftmax.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp b/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp new file mode 100644 index 0000000000..262f4cbfab --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp @@ -0,0 +1,45 @@ +#include + +void MNNReluWithSlopeChannel(float *dst, const float *src, + const float *slope, size_t sizeQuad, + size_t depthQuad) { + const ptrdiff_t stride = 4 * sizeof(float); + + for (size_t j = 0; j < depthQuad; ++j) { + const float *srcZ = src + 4 * j * sizeQuad; + float *dstZ = dst + 4 * j * sizeQuad; + float s0 = slope[4*j], s1 = slope[4*j + 1]; + float s2 = slope[4*j + 2], s3 = slope[4*j + 3]; + size_t i = 0; + while (i < sizeQuad) { + size_t vl = __riscv_vsetvl_e32m8(sizeQuad - i); + const float *srcBase = srcZ + 4*i; + float *dstBase = dstZ + 4*i; + + vfloat32m8_t v; + vbool4_t mask; + + v = __riscv_vlse32_v_f32m8(srcBase, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s0, vl); + __riscv_vsse32_v_f32m8(dstBase, stride, v, vl); + + v = __riscv_vlse32_v_f32m8(srcBase + 1, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s1, vl); + __riscv_vsse32_v_f32m8(dstBase + 1, stride, v, vl); + + v = __riscv_vlse32_v_f32m8(srcBase + 2, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s2, vl); + __riscv_vsse32_v_f32m8(dstBase + 2, stride, v, vl); + + v = __riscv_vlse32_v_f32m8(srcBase + 3, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = 
__riscv_vfmul_vf_f32m8_mu(mask, v, v, s3, vl); + __riscv_vsse32_v_f32m8(dstBase + 3, stride, v, vl); + + i += vl; + } + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp b/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp new file mode 100644 index 0000000000..f510058c83 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp @@ -0,0 +1,80 @@ +#include +#include + +void MNNSoftmax(float *dest, const float *source, size_t size) { + size_t n = size; + const float *sourcePtr = source; + float *destPtr = dest; + float maxValue = -FLT_MAX; + vfloat32m1_t maxVecValue = __riscv_vfmv_s_f_f32m1(maxValue, 1); + + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t vSrc = __riscv_vle32_v_f32m8(sourcePtr, vl); + maxVecValue = __riscv_vfredmax_vs_f32m8_f32m1(vSrc, maxVecValue, vl); + sourcePtr += vl; + n -= vl; + } + + maxValue = __riscv_vfmv_f_s_f32m1_f32(maxVecValue); + const float param = 0.6931471805599453f; + const float xLimit = 87.0f; + float sumValue = 0.f; + vfloat32m1_t sumVecValue = __riscv_vfmv_s_f_f32m1(sumValue, 1); + n = size; + sourcePtr = source; + destPtr = dest; + + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t vA = __riscv_vle32_v_f32m8(sourcePtr, vl); + vA = __riscv_vfsub_vf_f32m8(vA, maxValue, vl); + vA = __riscv_vfmax_vf_f32m8(vA, -xLimit, vl); + vA = __riscv_vfmin_vf_f32m8(vA, xLimit, vl); + + vfloat32m8_t vB = __riscv_vfdiv_vf_f32m8(vA, param, vl); + vint32m8_t vBI = __riscv_vfcvt_x_f_v_i32m8(vB, vl); + + vfloat32m8_t vC = __riscv_vreinterpret_v_i32m8_f32m8( + __riscv_vsll_vx_i32m8( + __riscv_vadd_vx_i32m8(vBI, 127, vl), 23, vl)); + + vB = __riscv_vfcvt_f_x_v_f32m8(vBI, vl); + vB = __riscv_vfnmsub_vf_f32m8(vB, param, vA, vl); + + vA = __riscv_vfmv_v_f_f32m8(1.0f / 120.0f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 1.0f / 24.0f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 1.0f / 6.0f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, 
vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 0.5f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl); + + vA = __riscv_vfmul_vv_f32m8(vC, vA, vl); + __riscv_vse32_v_f32m8(destPtr, vA, vl); + sumVecValue = __riscv_vfredosum_vs_f32m8_f32m1(vA, sumVecValue, vl); + + sourcePtr += vl; + destPtr += vl; + n -= vl; + } + + sumValue = __riscv_vfmv_f_s_f32m1_f32(sumVecValue); + float sumInv = 1.0f / sumValue; + n = size; + destPtr = dest; + + while (n > 0) + { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t vDest = __riscv_vle32_v_f32m8(destPtr, vl); + vDest = __riscv_vfmul_vf_f32m8(vDest, sumInv, vl); + __riscv_vse32_v_f32m8(destPtr, vDest, vl); + destPtr += vl; + n -= vl; + } +} From 44cd2f1c63c8bf7e1b78eb22186cf473a46b5748 Mon Sep 17 00:00:00 2001 From: ihb2032 <1355790728@qq.com> Date: Tue, 9 Dec 2025 01:07:46 +0000 Subject: [PATCH 012/314] opt(RVV): Optimize top1 functions with intrinsics Optimize MNNVectorTop1Float and MNNVectorTop1Int32 using RVV intrinsics. 
Signed-off-by: ihb2032 <1355790728@qq.com> Co-authored-by: lyd1992 --- .../cpu/riscv/rvv/MNNVectorTop1Float.cpp | 37 +++++++++++++++++++ .../cpu/riscv/rvv/MNNVectorTop1Int32.cpp | 37 +++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp b/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp new file mode 100644 index 0000000000..7332360ce8 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp @@ -0,0 +1,37 @@ +#include +#include + +#define UNIT 4 + +void MNNVectorTop1Float(float* input, float* maxValue, int32_t* maxIndex, size_t inputCountUnit) { + size_t n = inputCountUnit * UNIT; + float maxV = -FLT_MAX; + int32_t maxIdx = 0; + size_t vl; + + for (size_t i = 0; i < n; ) { + vl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t data = __riscv_vle32_v_f32m8(input + i, vl); + vfloat32m1_t scalar = __riscv_vfmv_s_f_f32m1(maxV, vl); + vfloat32m1_t result = __riscv_vfredmax_vs_f32m8_f32m1(data, scalar, vl); + maxV = __riscv_vfmv_f_s_f32m1_f32(result); + i += vl; + } + + for (size_t i = 0; i < n; ) { + vl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t data = __riscv_vle32_v_f32m8(input + i, vl); + vbool4_t mask = __riscv_vmfeq_vf_f32m8_b4(data, maxV, vl); + long first = __riscv_vfirst_m_b4(mask, vl); + + if (first >= 0) { + maxIdx = i + first; + break; + } + + i += vl; + } + + maxValue[0] = maxV; + maxIndex[0] = maxIdx; +} diff --git a/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp b/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp new file mode 100644 index 0000000000..8c199709ec --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp @@ -0,0 +1,37 @@ +#include +#include + +#define UNIT 4 + +void MNNVectorTop1Int32(int32_t* input, int32_t* maxValue, int32_t* maxIndex, size_t inputCountUnit) { + size_t n = inputCountUnit * UNIT; + 
int32_t maxV = INT32_MIN; + int32_t maxIdx = 0; + size_t vl; + + for (size_t i = 0; i < n; ) { + vl = __riscv_vsetvl_e32m8(n - i); + vint32m8_t data = __riscv_vle32_v_i32m8(input + i, vl); + vint32m1_t scalar = __riscv_vmv_s_x_i32m1(maxV, vl); + vint32m1_t result = __riscv_vredmax_vs_i32m8_i32m1(data, scalar, vl); + maxV = __riscv_vmv_x_s_i32m1_i32(result); + i += vl; + } + + for (size_t i = 0; i < n; ) { + vl = __riscv_vsetvl_e32m8(n - i); + vint32m8_t data = __riscv_vle32_v_i32m8(input + i, vl); + vbool4_t mask = __riscv_vmseq_vx_i32m8_b4(data, maxV, vl); + long first = __riscv_vfirst_m_b4(mask, vl); + + if (first >= 0) { + maxIdx = i + first; + break; + } + + i += vl; + } + + maxValue[0] = maxV; + maxIndex[0] = maxIdx; +} From f4fcff3436a9d95c5367699b1c53d4eb1cbc3b7d Mon Sep 17 00:00:00 2001 From: ihb2032 <1355790728@qq.com> Date: Fri, 12 Dec 2025 08:42:52 +0000 Subject: [PATCH 013/314] opt(RVV): Optimize resize functions with intrinsics Optimize CPUBilinearLineC4, CPUBilinearSampleC4, MNNBilinearLineC8, MNNBilinearSampleC8, MNNCubicLineC4, MNNCubicLineC16, MNNCubicSampleC4 and MNNCubicSampleC16 using RVV intrinsics. 
Signed-off-by: ihb2032 <1355790728@qq.com> Co-authored-by: lyd1992 --- .../cpu/riscv/rvv/CPUBilinearLineC4.cpp | 19 +++++ .../cpu/riscv/rvv/CPUBilinearSampleC4.cpp | 33 ++++++++ .../cpu/riscv/rvv/MNNBilinearLineC8.cpp | 40 ++++++++++ .../cpu/riscv/rvv/MNNBilinearSampleC8.cpp | 49 ++++++++++++ .../backend/cpu/riscv/rvv/MNNCubicLineC16.cpp | 53 +++++++++++++ .../backend/cpu/riscv/rvv/MNNCubicLineC4.cpp | 38 +++++++++ .../cpu/riscv/rvv/MNNCubicSampleC16.cpp | 79 +++++++++++++++++++ .../cpu/riscv/rvv/MNNCubicSampleC4.cpp | 62 +++++++++++++++ 8 files changed, 373 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp diff --git a/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp b/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp new file mode 100644 index 0000000000..a700016c31 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp @@ -0,0 +1,19 @@ +#include + +void CPUBilinearLineC4(float* dst, const float* A, const float* B, + const float* t, int8_t* zeroPoint, size_t number) { + float tf = *t; + float sf = 1.0f - tf; + size_t total = number << 2; + + size_t i = 0; + while (i < total) { + size_t vl = __riscv_vsetvl_e32m8(total - i); + vfloat32m8_t v = __riscv_vle32_v_f32m8(A + i, vl); + vfloat32m8_t result = __riscv_vfmul_vf_f32m8(v, sf, vl); + v = __riscv_vle32_v_f32m8(B + i, vl); + result = __riscv_vfmacc_vf_f32m8(result, tf, v, vl); + __riscv_vse32_v_f32m8(dst + i, result, vl); + i += vl; + } +} diff --git 
a/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp b/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp new file mode 100644 index 0000000000..5063c39bff --- /dev/null +++ b/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp @@ -0,0 +1,33 @@ +#include + +void CPUBilinearSampleC4(const float* src, float* dst, + const int32_t* position, const float* factor, + int8_t* zeroPoint, size_t number) { + const int pack = 4; + size_t i = 0; + + while (i < number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + i, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vfloat32m8_t vr = __riscv_vluxei32_v_f32m8(src, voff, vl); + vfloat32m8_t vsf = __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl); + vr = __riscv_vfmul_vv_f32m8(vr, vsf, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vsf = __riscv_vluxei32_v_f32m8(src, voff, vl); + vr = __riscv_vfmacc_vv_f32m8(vr, vf, vsf, vl); + __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, vr, vl); + } + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp new file mode 100644 index 0000000000..a26243bdb8 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp @@ -0,0 +1,40 @@ +#include + +void MNNBilinearLineC8(int8_t* dst, const int16_t* A, const int16_t* B, + const float* t, int8_t* zeroPoint, size_t number) { + int offset = *zeroPoint; + int8_t* dstPtr = dst; + + const int pack = 8; + const int16_t df = (int16_t)((*t) * 128.0f); + const int16_t sf = (int16_t)((1.0f - *t) * 128.0f); + const size_t total = number * pack; + const int32_t ROUND_HALF = 1 << 13; + + size_t vl; + for 
(size_t i = 0; i < total; i += vl) { + vl = __riscv_vsetvl_e16m4(total - i); + vint16m4_t v16 = __riscv_vle16_v_i16m4(A + i, vl); + vint32m8_t v32 = __riscv_vwmul_vx_i32m8(v16, sf, vl); + v16 = __riscv_vle16_v_i16m4(B + i, vl); + v32 = __riscv_vwmacc_vx_i32m8(v32, df, v16, vl); + + vbool4_t mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); + vint32m8_t tmp = __riscv_vadd_vx_i32m8(v32, ROUND_HALF, vl); + v32 = __riscv_vsub_vx_i32m8(v32, ROUND_HALF, vl); + v32 = __riscv_vmerge_vvm_i32m8(tmp, v32, mask, vl); + + tmp = __riscv_vsra_vx_i32m8(v32, 14, vl); + mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); + v32 = __riscv_vand_vx_i32m8(v32, 0x3FFF, vl); + vbool4_t hasRem = __riscv_vmsne_vx_i32m8_b4(v32, 0, vl); + mask = __riscv_vmand_mm_b4(mask, hasRem, vl); + + v32 = __riscv_vadd_vx_i32m8_mu(mask, tmp, tmp, 1, vl); + v32 = __riscv_vadd_vx_i32m8(v32, offset, vl); + v16 = __riscv_vnsra_wx_i16m4(v32, 0, vl); + vint8m2_t v8 = __riscv_vnsra_wx_i8m2(v16, 0, vl); + + __riscv_vse8_v_i8m2(dstPtr + i, v8, vl); + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp new file mode 100644 index 0000000000..bd111e3be4 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp @@ -0,0 +1,49 @@ +#include + +void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, + const int32_t* position, const float* factor, + int8_t* zeroPoint, size_t number) { + int16_t offset = (int16_t)(*zeroPoint); + const int pack = 8; + size_t i = 0; + + while (i < number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + i, vl); + vint16m4_t vdf = __riscv_vnsra_wx_i16m4( + __riscv_vfcvt_rtz_x_f_v_i32m8( + __riscv_vfmul_vf_f32m8(vf, 128.0f, vl), vl), 0, vl); + vint16m4_t vsf = __riscv_vnsra_wx_i16m4( + __riscv_vfcvt_rtz_x_f_v_i32m8( + __riscv_vfmul_vf_f32m8( + __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl), 128.0f, vl), vl), 0, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = 
__riscv_vadd_vx_u32m8( + __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 3, vl), + c, vl); + + vint16m4_t va = __riscv_vsub_vx_i16m4( + __riscv_vsext_vf2_i16m4( + __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); + + vint32m8_t vr = __riscv_vwmul_vv_i32m8(va, vsf, vl); + voff = __riscv_vadd_vx_u32m8( + __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 3, vl), + c, vl); + + vint16m4_t vb = __riscv_vsub_vx_i16m4( + __riscv_vsext_vf2_i16m4( + __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); + vr = __riscv_vwmacc_vv_i32m8(vr, vb, vdf, vl); + __riscv_vsse16_v_i16m4(dst + i * pack + c, 16, + __riscv_vnsra_wx_i16m4(vr, 0, vl), vl); + } + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp b/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp new file mode 100644 index 0000000000..fd6ce7a274 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp @@ -0,0 +1,53 @@ +#include + +void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, + const float* C, const float* D, float* t, + int8_t* zeroPoint, size_t number, + ssize_t minValue, ssize_t maxValue) { + const float f = *t; + const float t2 = f * f, t3 = t2 * f; + const float b0 = 1.0f - 2.25f * t2 + 1.25f * t3; + const float t1 = 1.0f - f, t1_2 = t1 * t1; + const float c0 = 1.0f - 2.25f * t1_2 + 1.25f * t1_2 * t1; + const float ta = 1.0f + f, ta2 = ta * ta; + const float a0 = 3.0f - 6.0f * ta + 3.75f * ta2 - 0.75f * ta2 * ta; + const float td = 2.0f - f, td2 = td * td; + const float d0 = 3.0f - 6.0f * td + 3.75f * td2 - 0.75f * td2 * td; + const int offset = *zeroPoint; + const int minVal = (int)minValue; + const int maxVal = (int)maxValue; + const size_t total = number << 4; + size_t i = 0; + + while (i < total) { + size_t vl = __riscv_vsetvl_e32m8(total - i); + vfloat32m8_t v, acc; + + v = __riscv_vle32_v_f32m8(A + i, vl); + acc = 
__riscv_vfmul_vf_f32m8(v, a0, vl); + + v = __riscv_vle32_v_f32m8(B + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, b0, v, vl); + + v = __riscv_vle32_v_f32m8(C + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, c0, v, vl); + + v = __riscv_vle32_v_f32m8(D + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, d0, v, vl); + + vfloat32m8_t half = __riscv_vfmv_v_f_f32m8(0.5f, vl); + vfloat32m8_t signHalf = __riscv_vfsgnj_vv_f32m8(half, acc, vl); + acc = __riscv_vfadd_vv_f32m8(acc, signHalf, vl); + + vint32m8_t vint = __riscv_vfcvt_rtz_x_f_v_i32m8(acc, vl); + vint = __riscv_vadd_vx_i32m8(vint, offset, vl); + vint = __riscv_vmax_vx_i32m8(vint, minVal, vl); + vint = __riscv_vmin_vx_i32m8(vint, maxVal, vl); + + vint16m4_t vi16 = __riscv_vncvt_x_x_w_i16m4(vint, vl); + vint8m2_t vi8 = __riscv_vncvt_x_x_w_i8m2(vi16, vl); + __riscv_vse8_v_i8m2(dst + i, vi8, vl); + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp b/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp new file mode 100644 index 0000000000..0da63ca0ff --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp @@ -0,0 +1,38 @@ +#include + +void MNNCubicLineC4(float* dst, const float* A, const float* B, + const float* C, const float* D, float* t, + int8_t* zeroPoint, size_t number, + ssize_t minValue, ssize_t maxValue) { + const float f = *t; + const float t2 = f * f, t3 = t2 * f; + const float b0 = 1.0f - 2.25f * t2 + 1.25f * t3; + const float t1 = 1.0f - f, t1_2 = t1 * t1; + const float c0 = 1.0f - 2.25f * t1_2 + 1.25f * t1_2 * t1; + const float ta = 1.0f + f, ta2 = ta * ta; + const float a0 = 3.0f - 6.0f * ta + 3.75f * ta2 - 0.75f * ta2 * ta; + const float td = 2.0f - f, td2 = td * td; + const float d0 = 3.0f - 6.0f * td + 3.75f * td2 - 0.75f * td2 * td; + const size_t total = number << 2; + size_t i = 0; + + while (i < total) { + size_t vl = __riscv_vsetvl_e32m8(total - i); + vfloat32m8_t v, acc; + + v = __riscv_vle32_v_f32m8(A + i, vl); + acc = __riscv_vfmul_vf_f32m8(v, a0, vl); + + v = 
__riscv_vle32_v_f32m8(B + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, b0, v, vl); + + v = __riscv_vle32_v_f32m8(C + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, c0, v, vl); + + v = __riscv_vle32_v_f32m8(D + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, d0, v, vl); + + __riscv_vse32_v_f32m8(dst + i, acc, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp b/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp new file mode 100644 index 0000000000..fd5b24a53d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp @@ -0,0 +1,79 @@ +#include + +void MNNCubicSampleC16(const int8_t* src, float* dst, + int32_t* position, const float* factor, + int8_t* zeroPoint, size_t number) { + const int pack = 16; + int8_t zp = *zeroPoint; + size_t i = 0; + + while (i < number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vint8m2_t vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vint16m4_t vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vfloat32m8_t vtmp = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl); + vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl); + vfloat32m8_t vc = vtmp; + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vfloat32m8_t vB = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl); + vb = 
__riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vtmp = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl); + vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl); + vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vtmp = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + va = __riscv_vfmacc_vf_f32m8(va, 0.75f, vtmp, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, -0.75f, vtmp, vl); + + va = __riscv_vfmadd_vv_f32m8(va, vt, vb, vl); + va = __riscv_vfmadd_vv_f32m8(va, vt, vc, vl); + va = __riscv_vfmadd_vv_f32m8(va, vt, vB, vl); + + __riscv_vsse32_v_f32m8(dst + i * pack + c, pack * sizeof(float), va, vl); + } + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp b/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp new file mode 100644 index 0000000000..78207e69e8 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp @@ -0,0 +1,62 @@ +#include + +void MNNCubicSampleC4(const float* src, float* dst, + int32_t* position, const float* factor, + int8_t* zeroPoint, size_t number) { + const int pack = 4; + size_t i = 0; + + while (i < number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vsll_vx_u32m8( + 
__riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vfloat32m8_t vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); + + vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl); + vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl); + vfloat32m8_t vc = vtmp; + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vfloat32m8_t vB = __riscv_vluxei32_v_f32m8(src, voff, vl); + + va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); + + va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl); + vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl); + vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); + + va = __riscv_vfmacc_vf_f32m8(va, 0.75f, vtmp, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, -0.75f, vtmp, vl); + + va = __riscv_vfmadd_vv_f32m8(va, vt, vb, vl); + va = __riscv_vfmadd_vv_f32m8(va, vt, vc, vl); + va = __riscv_vfmadd_vv_f32m8(va, vt, vB, vl); + + __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, va, vl); + } + + i += vl; + } +} From 8072a6299a92e2aacd293e26b06b09f5243d77e6 Mon Sep 17 00:00:00 2001 From: "weishan.wyf" Date: Mon, 15 Dec 2025 17:56:32 +0800 Subject: [PATCH 014/314] feat: add mnn supertonic support --- .../main/cpp/tts/include/mnn_tts_config.hpp | 4 + 
.../src/main/cpp/tts/include/mnn_tts_sdk.hpp | 3 +- .../supertonic/mnn_supertonic_tts_impl.hpp | 127 +++ .../src/main/cpp/tts/src/mnn_tts_config.cpp | 5 + .../app/src/main/cpp/tts/src/mnn_tts_sdk.cpp | 21 +- .../supertonic/mnn_supertonic_tts_impl.cpp | 902 ++++++++++++++++++ 6 files changed, 1055 insertions(+), 7 deletions(-) create mode 100644 apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/supertonic/mnn_supertonic_tts_impl.hpp create mode 100644 apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/supertonic/mnn_supertonic_tts_impl.cpp diff --git a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/mnn_tts_config.hpp b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/mnn_tts_config.hpp index 4f1903a5eb..8df703c026 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/mnn_tts_config.hpp +++ b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/mnn_tts_config.hpp @@ -46,4 +46,8 @@ class MNNTTSConfig std::string asset_folder_; std::string cache_folder_; int sample_rate_; + std::string precision_; + std::string speaker_id_; + int iter_steps_; + float speed_; }; \ No newline at end of file diff --git a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/mnn_tts_sdk.hpp b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/mnn_tts_sdk.hpp index 263e912e41..21246a0152 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/mnn_tts_sdk.hpp +++ b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/mnn_tts_sdk.hpp @@ -7,6 +7,7 @@ #include "wavfile.hpp" #include "mnn_tts_config.hpp" +#include "mnn_tts_impl_base.hpp" //#include "piper/mnn_piper_tts_impl.hpp" #include "bertvits2/mnn_bertvits2_tts_impl.hpp" @@ -22,4 +23,4 @@ class MNNTTSSDK private: int sample_rate_; std::shared_ptr impl_; -}; \ No newline at end of file +}; diff --git a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/supertonic/mnn_supertonic_tts_impl.hpp b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/supertonic/mnn_supertonic_tts_impl.hpp 
new file mode 100644 index 0000000000..24b2c72553 --- /dev/null +++ b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/supertonic/mnn_supertonic_tts_impl.hpp @@ -0,0 +1,127 @@ +#ifndef _HEADER_MNN_SUPERTONIC_TTS_IMPL_H_ +#define _HEADER_MNN_SUPERTONIC_TTS_IMPL_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "mnn_tts_impl_base.hpp" +#include "mnn_tts_logger.hpp" + +// Voice Style definition moved here +struct VoiceStyle +{ + std::vector> ttl; + std::vector> dp; + VoiceStyle() = default; + VoiceStyle(const std::vector> &ttl_data, + const std::vector> &dp_data) + : ttl(ttl_data), dp(dp_data) {} +}; + +/** + * @brief MNN C++ implementation of Supertonic TTS. + * + * This class handles the complete TTS pipeline including: + * 1. Text Processing (Normalization, cleaning, encoding) + * 2. MNN Model Inference (Duration Predictor, Text Encoder, Vector Estimator, + * Vocoder) + * 3. Audio Synthesis + */ +class MNNSupertonicTTSImpl : public MNNTTSImplBase +{ +public: + MNNSupertonicTTSImpl(const std::string &models_dir, const std::string &precision_dir, const std::string &speaker_id, int iter_steps, float speed); + + // Core Synthesis Interface + std::tuple Process(const std::string &text); + + // Overload for internal use or direct style passing + std::tuple synthesize(const std::string &text, const VoiceStyle &voice_styl, int steps, + float speed); + + // Save Audio + static bool save(const std::string &filename, + const std::vector &audio_data, int sample_rate); + +private: + // --- Configuration --- + std::string models_dir_; + std::string precision_dir_; + std::string cache_dir_; + std::string speaker_id_; + int iter_steps_; + float speed_; + + std::vector voice_ids_ = {"M1", "M2", "F1", "F2"}; + int sample_rate_; + int base_chunk_size_; + int chunk_compress_factor_; + int ldim_; + + // --- MNN Runtime --- + std::shared_ptr executor_; + + // --- MNN Modules --- + std::shared_ptr dp_module_; // 
Duration Predictor + std::shared_ptr te_module_; // Text Encoder + std::shared_ptr ve_module_; // Vector Estimator + std::shared_ptr vc_module_; // Vocoder + + // --- Internal Text Processor --- + class TextProcessor + { + public: + TextProcessor(const std::string &indexer_path); + std::vector encode(const std::string &text); + + private: + std::map unicode_to_index_; + std::map index_to_unicode_; + }; + std::unique_ptr text_processor_; + + // --- Private Helper Methods --- + + // Initialize all MNN models + void initializeModels(); + + // Predict phoneme durations + std::vector + predictDuration(const std::vector &text_ids, + const std::vector> &style_dp, + const std::vector &text_mask); + + // Encode text indices into embedding using Style Vector + std::vector + encodeText(const std::vector &text_ids, + const std::vector> &style_ttl, + const std::vector &text_mask); + + // Estimate audio vector (Flow Matching Step) + std::vector estimateVector(const std::vector &noisy_latent, + const std::vector &text_emb, + const std::vector &style_ttl_flat, + const std::vector &latent_mask, + const std::vector &text_mask, + int current_step, int total_step); + + // Vocode latent vector to audio + std::vector vocode(const std::vector &latent); + // --- Voice Style Management --- + void loadVoiceStyles(); + void loadVoiceStyle(const std::string &voice_name); + std::string preprocessText(const std::string &text); + + std::map voice_styles_; +}; + +#endif // _HEADER_MNN_SUPERTONIC_TTS_IMPL_H_ \ No newline at end of file diff --git a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/mnn_tts_config.cpp b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/mnn_tts_config.cpp index b2b94c8391..75490a6b8c 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/mnn_tts_config.cpp +++ b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/mnn_tts_config.cpp @@ -34,6 +34,11 @@ MNNTTSConfig::MNNTTSConfig(const std::string &config_json_path) asset_folder_ = 
get_value_from_json(raw_config_data_, "asset_folder"); cache_folder_ = get_value_from_json(raw_config_data_, "cache_folder"); sample_rate_ = get_value_from_json(raw_config_data_, "sample_rate"); + precision_ = get_value_from_json(raw_config_data_, "precision"); + speaker_id_ = get_value_from_json(raw_config_data_, "speaker_id"); + iter_steps_ = get_value_from_json(raw_config_data_, "iter_steps"); + speed_ = get_value_from_json(raw_config_data_, "speed"); + } catch (const std::runtime_error &e) { diff --git a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/mnn_tts_sdk.cpp b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/mnn_tts_sdk.cpp index 315e107647..478a5ffd71 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/mnn_tts_sdk.cpp +++ b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/mnn_tts_sdk.cpp @@ -1,9 +1,9 @@ - #include "mnn_tts_sdk.hpp" #include "piper/utf8.h" -#include +#include "supertonic/mnn_supertonic_tts_impl.hpp" #include // For std::wstring_convert and std::codecvt_utf8 #include +#include MNNTTSSDK::MNNTTSSDK(const std::string &config_folder) { @@ -14,15 +14,23 @@ MNNTTSSDK::MNNTTSSDK(const std::string &config_folder) auto assset_folder = config_folder + "/" + config.asset_folder_; auto cache_folder = config_folder + "/" + config.cache_folder_; sample_rate_ = config.sample_rate_; - if (model_type == "piper") { - impl_ = nullptr; + impl_ = nullptr; // std::make_shared(assset_folder, model_path, cache_folder); } else if (model_type == "bertvits") { - impl_ = std::make_shared(assset_folder, model_path, cache_folder); + impl_ = std::make_shared(assset_folder, model_path, cache_folder); + } + else if (model_type == "supertonic") + { + auto model_dir = config_folder; + std::string precision = config.precision_; + std::string speaker_id = config.speaker_id_; + int iter_steps = config.iter_steps_; + float speed = config.speed_; + impl_ = std::make_shared(model_dir, precision, speaker_id, iter_steps, speed); } else { @@ -35,7 +43,8 @@ 
std::tuple MNNTTSSDK::Process(const std::string &text) return impl_->Process(text); } -void MNNTTSSDK::WriteAudioToFile(const Audio &audio_data, const std::string &output_file_path) +void MNNTTSSDK::WriteAudioToFile(const Audio &audio_data, + const std::string &output_file_path) { std::ofstream audioFile(output_file_path, std::ios::binary); diff --git a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/supertonic/mnn_supertonic_tts_impl.cpp b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/supertonic/mnn_supertonic_tts_impl.cpp new file mode 100644 index 0000000000..46f28a08e0 --- /dev/null +++ b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/supertonic/mnn_supertonic_tts_impl.cpp @@ -0,0 +1,902 @@ +/** + * @file mnn_supertonic_tts_impl.cpp + * @brief MNN Supertonic TTS实现类 + */ + +#include "supertonic/mnn_supertonic_tts_impl.hpp" +#include "mnn_tts_logger.hpp" +#include "utils.hpp" +#include + +#include "piper/uni_algo.hpp" +#include +#include +#include +#include // for getenv +#include +#include +#include +#include +#include + +using json = nlohmann::json; +using namespace MNN::Express; + +namespace +{ // Helper namespace + + // Emoji ranges check (matching Python regex) + bool is_emoji(char32_t cp) + { + if (cp >= 0x1F600 && cp <= 0x1F64F) + return true; // emoticons + if (cp >= 0x1F300 && cp <= 0x1F5FF) + return true; // symbols & pictographs + if (cp >= 0x1F680 && cp <= 0x1F6FF) + return true; // transport & map + if (cp >= 0x1F700 && cp <= 0x1F77F) + return true; + if (cp >= 0x1F780 && cp <= 0x1F7FF) + return true; + if (cp >= 0x1F800 && cp <= 0x1F8FF) + return true; + if (cp >= 0x1F900 && cp <= 0x1F9FF) + return true; + if (cp >= 0x1FA00 && cp <= 0x1FA6F) + return true; + if (cp >= 0x1FA70 && cp <= 0x1FAFF) + return true; + if (cp >= 0x2600 && cp <= 0x26FF) + return true; + if (cp >= 0x2700 && cp <= 0x27BF) + return true; + if (cp >= 0x1F1E6 && cp <= 0x1F1FF) + return true; // flags? 
+ return false; + } + + // Combining diacritics check + bool is_combining_diacritic(char32_t cp) + { + // [\u0302\u0303\u0304\u0305\u0306\u0307\u0308\u030A\u030B\u030C\u0327\u0328\u0329\u032A\u032B\u032C\u032D\u032E\u032F] + // Simplified range check for performance (approximate 0300-036F block often + // used) But Python text.py is specific. Let's list them or use range if + // contiguous. They are mostly in 0x0300 block. + static const std::unordered_set diacritics = { + 0x0302, 0x0303, 0x0304, 0x0305, 0x0306, 0x0307, 0x0308, + 0x030A, 0x030B, 0x030C, 0x0327, 0x0328, 0x0329, 0x032A, + 0x032B, 0x032C, 0x032D, 0x032E, 0x032F}; + return diacritics.count(cp); + } + + // Special symbols check + bool is_special_symbol(char32_t cp) + { + // [♥☆♡©\\] + return cp == 0x2665 || cp == 0x2606 || cp == 0x2661 || cp == 0x00A9 || + cp == 0x005C; + } + + bool is_end_punctuation(char32_t cp) + { + // [.!?;:,'\"')\]}…。」』】〉》›»] + static const std::unordered_set puncts = { + '.', '!', '?', ';', ':', ',', '\'', + '"', ')', ']', '}', 0x2026, 0x3002, 0x300D, + 0x300F, 0x3011, 0x3009, 0x300B, 0x203A, 0x00BB}; + return puncts.count(cp); + } + + void utf8_append(std::string &s, char32_t cp) + { + if (cp < 0x80) + { + s.push_back(static_cast(cp)); + } + else if (cp < 0x800) + { + s.push_back(static_cast(0xC0 | (cp >> 6))); + s.push_back(static_cast(0x80 | (cp & 0x3F))); + } + else if (cp < 0x10000) + { + s.push_back(static_cast(0xE0 | (cp >> 12))); + s.push_back(static_cast(0x80 | ((cp >> 6) & 0x3F))); + s.push_back(static_cast(0x80 | (cp & 0x3F))); + } + else if (cp <= 0x10FFFF) + { + s.push_back(static_cast(0xF0 | (cp >> 18))); + s.push_back(static_cast(0x80 | ((cp >> 12) & 0x3F))); + s.push_back(static_cast(0x80 | ((cp >> 6) & 0x3F))); + s.push_back(static_cast(0x80 | (cp & 0x3F))); + } + } + +} // namespace + +// Text processor implementation +MNNSupertonicTTSImpl::TextProcessor::TextProcessor( + const std::string &indexer_path) +{ + std::ifstream file(indexer_path); + if 
(!file.is_open()) + { + throw std::runtime_error("Failed to open unicode indexer file: " + + indexer_path); + } + + json indexer_json; + file >> indexer_json; + file.close(); + + // Parse indexer + for (auto &item : indexer_json.items()) + { + uint16_t unicode_val = static_cast(std::stoi(item.key())); + int index = item.value().get(); + unicode_to_index_[unicode_val] = index; + index_to_unicode_[index] = unicode_val; + } + + PLOG(INFO, "Loaded unicode indexer with " + + std::to_string(unicode_to_index_.size()) + " entries"); +} + +std::vector +MNNSupertonicTTSImpl::TextProcessor::encode(const std::string &text) +{ + std::vector encoded; + + // 1. Normalize NFKD + std::string text_norm = una::norm::to_nfkd_utf8(text); + + // 2. Filter chars (Emojis, Diacritics, Special Symbols) and Map chars + std::string filtered; + filtered.reserve(text_norm.size()); + + auto view = una::ranges::utf8_view(text_norm); + for (auto it = view.begin(); it != view.end(); ++it) + { + char32_t cp = *it; + + if (is_emoji(cp)) + continue; + if (is_combining_diacritic(cp)) + continue; + if (is_special_symbol(cp)) + continue; + + // Char replacements + switch (cp) + { + case 0x2013: // – + case 0x2011: // ‑ + case 0x2014: // — + filtered.push_back('-'); + break; + case 0x00AF: // ¯ + case 0x005F: // _ + case 0x005B: // [ + case 0x005D: // ] + case 0x007C: // | + case 0x002F: // / + case 0x0023: // # + case 0x2192: // → + case 0x2190: // ← + filtered.push_back(' '); + break; + case 0x201C: // “ + case 0x201D: // ” + filtered.push_back('"'); + break; + case 0x2018: // ‘ + case 0x2019: // ’ + case 0x00B4: // ´ + case 0x0060: // ` + filtered.push_back('\''); + break; + default: + // Append as UTF-8 + std::string s; + utf8_append(s, cp); + filtered += s; + break; + } + } + + std::string t = filtered; + + // 3. String replacements (Expression replacements) + // "e.g.," -> "for example, " + // "i.e.," -> "that is, " + // "@" -> " at " + // Simple find/replace loop for these few replacements. 
+ auto replace_all = [](std::string &str, const std::string &from, + const std::string &to) + { + size_t start_pos = 0; + while ((start_pos = str.find(from, start_pos)) != std::string::npos) + { + str.replace(start_pos, from.length(), to); + start_pos += to.length(); + } + }; + + replace_all(t, "@", " at "); + replace_all(t, "e.g.,", "for example, "); + replace_all(t, "i.e.,", "that is, "); + + // 4. Regex replacements + // Spacing around punctuation + t = std::regex_replace(t, std::regex(" ,"), ","); + t = std::regex_replace(t, std::regex(" \\."), "."); + t = std::regex_replace(t, std::regex(" !"), "!"); + t = std::regex_replace(t, std::regex(" \\?"), "?"); + t = std::regex_replace(t, std::regex(" ;"), ";"); + t = std::regex_replace(t, std::regex(" :"), ":"); + t = std::regex_replace(t, std::regex(" '"), "'"); + + // Duplicate quotes + replace_all(t, "\"\"", "\""); + while (t.find("\"\"") != std::string::npos) + replace_all(t, "\"\"", "\""); + while (t.find("''") != std::string::npos) + replace_all(t, "''", "'"); + while (t.find("``") != std::string::npos) + replace_all(t, "``", "`"); + + // Remove extra spaces sequence + t = std::regex_replace(t, std::regex("\\s+"), " "); + + // Strip leading/trailing whitespace + t.erase(0, t.find_first_not_of(" ")); + t.erase(t.find_last_not_of(" ") + 1); + + // 5. Add period if needed + bool has_end_punct = false; + if (!t.empty()) + { + auto last_view = una::ranges::utf8_view(t); + // get last char + char32_t last_cp = 0; + for (auto cp : last_view) + last_cp = cp; + if (is_end_punctuation(last_cp)) + has_end_punct = true; + } + + if (!t.empty() && !has_end_punct) + { + t += "."; + } + + // 6. 
Encode to IDs + auto final_view = una::ranges::utf8_view(t); + for (char32_t cp : final_view) + { + uint16_t val = + static_cast(cp); // Potential truncation if outside BMP + auto it = unicode_to_index_.find(val); + if (it != unicode_to_index_.end()) + { + encoded.push_back(it->second); + } + else + { + // Handle unknown characters by mapping to index 0 (padding/unknown) + encoded.push_back(0); + } + } + + return encoded; +} + +// MNNSupertonicTTSImpl implementation +MNNSupertonicTTSImpl::MNNSupertonicTTSImpl( + const std::string &models_dir, + const std::string &precision_dir, + const std::string &speaker_id, + int iter_steps, + float speed) + : models_dir_(models_dir), + precision_dir_(precision_dir), + speaker_id_(speaker_id), + iter_steps_(iter_steps), + speed_(speed) +{ + + std::cout << "model_dir_" << models_dir_ << std::endl; + std::cout << "precsion_dir: " << precision_dir_ << std::endl; + std::cout << "cache_dir_: " << cache_dir_ << std::endl; + PLOG(INFO, "Initializing Supertonic TTS with models_dir: " + models_dir_); + + // Load config + std::string config_path = models_dir_ + "/mnn_models/tts.json"; + std::ifstream config_file(config_path); + if (!config_file.is_open()) + { + PLOG(ERROR, "Failed to open config file: " + config_path); + throw std::runtime_error("Failed to open config file: " + config_path); + } + json config; + config_file >> config; + + // Parse configuration + try + { + if (config.contains("ae")) + { + sample_rate_ = config["ae"].value("sample_rate", + 24000); // Default to 24000 if missing + base_chunk_size_ = config["ae"].value("base_chunk_size", 512); + } + else + { + sample_rate_ = 24000; + base_chunk_size_ = 512; + } + + // Attempt to retrieve TTL configuration + if (config.contains("ttl")) + { + chunk_compress_factor_ = config["ttl"].value("chunk_compress_factor", 6); + ldim_ = config["ttl"].value("latent_dim", 24); + } + else + { + // Fallback: Use default values or check specific nested keys + chunk_compress_factor_ = 6; + 
ldim_ = 24; + + if (config.contains("style_encoder") && + config["style_encoder"].contains("proj_in")) + { + chunk_compress_factor_ = config["style_encoder"]["proj_in"].value( + "chunk_compress_factor", 6); + ldim_ = config["style_encoder"]["proj_in"].value("ldim", 24); + } + } + + PLOG(INFO, "Config loaded: sample_rate=" + std::to_string(sample_rate_) + + ", base_chunk_size=" + std::to_string(base_chunk_size_) + + ", chunk_compress_factor=" + + std::to_string(chunk_compress_factor_) + + ", ldim=" + std::to_string(ldim_)); + } + catch (const std::exception &e) + { + PLOG(ERROR, "Error parsing config: " + std::string(e.what())); + throw; + } + + // Initialize text processor + std::string indexer_path = models_dir_ + "/mnn_models/unicode_indexer.json"; + text_processor_ = std::make_unique(indexer_path); + + // Initialize MNN inference engine + initializeModels(); + + // Load Voice Styles + loadVoiceStyles(); + + PLOG(INFO, "Supertonic TTS initialized successfully"); +} + +void MNNSupertonicTTSImpl::loadVoiceStyles() +{ + PLOG(INFO, "Loading voice styles..."); + + // Default voices + for (const auto &id : voice_ids_) + { + loadVoiceStyle(id); + } + PLOG(INFO, "Voice styles loaded successfully"); +} + +void MNNSupertonicTTSImpl::loadVoiceStyle(const std::string &voice_name) +{ + try + { + std::string style_path = + models_dir_ + "/mnn_models/voice_styles/" + voice_name + ".json"; + // Check if file exists, if not try old path structure + std::ifstream f_check(style_path); + if (!f_check.good()) + { + style_path = models_dir_ + "/voice_styles/" + voice_name + ".json"; + } + f_check.close(); + + std::ifstream style_file(style_path); + if (!style_file.is_open()) + { + PLOG(WARNING, "Failed to open style.json for voice: " + voice_name + + " at " + style_path); + return; + } + + json style_json; + style_file >> style_json; + style_file.close(); + + std::vector> ttl_vectors, dp_vectors; + + if (style_json.contains("style_ttl") && style_json.contains("style_dp")) + { + // 
Parse TTL + for (const auto &ttl_item : style_json["style_ttl"]["data"]) + { + for (const auto &vec : ttl_item) + { + std::vector ttl_vector; + for (const auto &val : vec) + ttl_vector.push_back(val.get()); + ttl_vectors.push_back(ttl_vector); + } + } + // Parse DP + for (const auto &dp_item : style_json["style_dp"]["data"]) + { + for (const auto &vec : dp_item) + { + std::vector dp_vector; + for (const auto &val : vec) + dp_vector.push_back(val.get()); + dp_vectors.push_back(dp_vector); + } + } + voice_styles_[voice_name] = VoiceStyle(ttl_vectors, dp_vectors); + PLOG(INFO, "Loaded voice style: " + voice_name); + } + } + catch (const std::exception &e) + { + PLOG(ERROR, "Error loading voice style " + voice_name + ": " + e.what()); + } +} + +std::string MNNSupertonicTTSImpl::preprocessText(const std::string &text) +{ + // Remove excess whitespace + std::string processed = std::regex_replace(text, std::regex("\\s+"), " "); + // Trim leading/trailing whitespace + processed = std::regex_replace(processed, std::regex("^\\s+|\\s+$"), ""); + processed = std::regex_replace(processed, std::regex("[\\.!?]+$"), "."); + return processed; +} + +std::tuple MNNSupertonicTTSImpl::Process(const std::string &text) +{ + const std::string &voice_name = speaker_id_; + int steps = iter_steps_; + float speed = speed_; + + if (voice_styles_.find(voice_name) == voice_styles_.end()) + { + PLOG(ERROR, "Voice style not found: " + voice_name); + throw std::runtime_error("Voice style not found: " + voice_name); + } + std::string processed_text = preprocessText(text); + return synthesize(processed_text, voice_styles_[voice_name], steps, speed); +} + +bool MNNSupertonicTTSImpl::save(const std::string &filename, + const std::vector &audio_data, + int sample_rate) +{ + return writeWavFile(filename, audio_data, sample_rate); +} + +std::tuple +MNNSupertonicTTSImpl::synthesize(const std::string &text, + const VoiceStyle &voice_style, int steps, + float speed) +{ + + auto default_ret = 
std::make_tuple(sample_rate_, std::vector(0)); + auto start_time = std::chrono::high_resolution_clock::now(); + + PLOG(INFO, "Synthesizing text: \"" + text + "\""); + + // 1. Text Encoding + std::vector text_ids = text_processor_->encode(text); + // Ensure non-empty text_ids; pad with 0 if necessary + if (text_ids.empty()) + { + text_ids.push_back(0); + } + + // Create text mask (all 1s) + std::vector text_mask(text_ids.size(), 1.0f); + + // 2. Duration Prediction + auto duration_outputs = predictDuration(text_ids, voice_style.dp, text_mask); + if (duration_outputs.empty()) + { + PLOG(ERROR, "Duration prediction failed"); + return default_ret; + } + + // 3. Process Duration to get total length + // The first element of duration_outputs is expected to be the total duration + float total_duration_sec = duration_outputs[0]; + + // Apply speed adjustment + if (speed > 0.0f) + { + total_duration_sec /= speed; + } + + int wav_len = static_cast(total_duration_sec * sample_rate_); + int chunk_size = base_chunk_size_ * chunk_compress_factor_; + int latent_len = (wav_len + chunk_size - 1) / chunk_size; + + // Ensure minimum length + if (latent_len < 1) + latent_len = 1; + + wav_len = latent_len * base_chunk_size_; // Recalculate aligned wav_len + + // 4. Generate Text Embedding + // This uses the text encoder model + std::vector text_emb = + encodeText(text_ids, voice_style.ttl, text_mask); + if (text_emb.empty()) + { + PLOG(ERROR, "Text encoding failed"); + return default_ret; + } + + // 5. 
Vector Estimation (Flow Matching) + // Flatten style.ttl for Vector Estimator input + std::vector style_ttl_flat; + if (!voice_style.ttl.empty()) + { + int ttl_dim0 = voice_style.ttl.size(); + int ttl_dim1 = voice_style.ttl[0].size(); + style_ttl_flat.reserve(ttl_dim0 * ttl_dim1); + for (const auto &vec : voice_style.ttl) + { + style_ttl_flat.insert(style_ttl_flat.end(), vec.begin(), vec.end()); + } + } + + // Initialize Latent Mask + std::vector latent_mask(latent_len, 1); + + // Generate Noisy Latent (Gaussian Noise) + int latent_dim = ldim_ * chunk_compress_factor_; + std::vector noisy_latent(1 * latent_dim * latent_len); + + std::random_device rd; + std::mt19937 gen(rd()); + std::normal_distribution dis(0.0f, 1.0f); + + for (size_t i = 0; i < noisy_latent.size(); ++i) + { + noisy_latent[i] = dis(gen); + } + + // Perform Flow Matching steps + std::vector estimated = noisy_latent; + for (int step = 0; step < steps; ++step) + { + estimated = estimateVector(estimated, text_emb, style_ttl_flat, latent_mask, + text_mask, step, steps); + if (estimated.empty()) + { + PLOG(ERROR, "Vector estimation failed at step " + std::to_string(step)); + return default_ret; + } + } + + // 6. 
Vocoder Synthesis + std::vector audio = vocode(estimated); + if (audio.empty()) + { + PLOG(ERROR, "Vocoding failed"); + return default_ret; + } + + auto end_time = std::chrono::high_resolution_clock::now(); + std::chrono::duration duration = end_time - start_time; + float time_cost = duration.count(); + float audio_duration = static_cast(audio.size()) / sample_rate_; + float rtf = time_cost / audio_duration; + PLOG(INFO, "RTF: " + std::to_string(rtf)); + + std::vector audio_int16; + for (int i = 0; i < audio.size(); i++) + { + audio_int16.push_back(audio[i] * 32768); + } + + return std::make_tuple(sample_rate_, audio_int16); +} + +// Core processing function +std::vector MNNSupertonicTTSImpl::predictDuration( + const std::vector &text_ids, + const std::vector> &style_dp, + const std::vector &text_mask) +{ + + MNN::Express::ExecutorScope scope(executor_); + + // Prepare inputs + std::vector inputs(3); + int num_tokens = static_cast(text_ids.size()); + int dp_size_0 = static_cast(style_dp.size()); + int dp_size_1 = static_cast(style_dp[0].size()); + + // Input 0: text_ids {1, num_tokens} NCHW + inputs[0] = + MNN::Express::_Input({1, num_tokens}, NCHW, halide_type_of()); + auto ptr0 = inputs[0]->writeMap(); + for (int i = 0; i < num_tokens; ++i) + { + ptr0[i] = text_ids[i]; + } + + // Input 1: style_dp {1, dp_size_0, dp_size_1} NCHW + inputs[1] = MNN::Express::_Input({1, dp_size_0, dp_size_1}, NCHW, + halide_type_of()); + auto ptr1 = inputs[1]->writeMap(); + for (int i = 0; i < dp_size_0; ++i) + { + for (int j = 0; j < dp_size_1; ++j) + { + ptr1[i * dp_size_1 + j] = style_dp[i][j]; + } + } + + // Input 2: text_mask {1, 1, num_tokens} NCHW + inputs[2] = + MNN::Express::_Input({1, 1, num_tokens}, NCHW, halide_type_of()); + auto ptr2 = inputs[2]->writeMap(); + for (int i = 0; i < num_tokens; ++i) + { + ptr2[i] = text_mask[i]; + } + + // Run Inference + std::vector outputs = dp_module_->onForward(inputs); + + // Process Output + if (outputs.empty()) + { + PLOG(ERROR, 
"Duration Predictor returned empty output"); + return {}; + } + + auto output = outputs[0]; + auto size = output->getInfo()->size; + std::vector result(size); + ::memcpy(result.data(), output->readMap(), size * sizeof(float)); + + return result; +} + +std::vector MNNSupertonicTTSImpl::encodeText( + const std::vector &text_ids, + const std::vector> &style_ttl, + const std::vector &text_mask) +{ + + MNN::Express::ExecutorScope scope(executor_); + + std::vector inputs(3); + + // Input 0: Text IDs + int num_tokens = static_cast(text_ids.size()); + inputs[0] = + MNN::Express::_Input({1, num_tokens}, NCHW, halide_type_of()); + auto ptr0 = inputs[0]->writeMap(); + for (int i = 0; i < num_tokens; ++i) + { + ptr0[i] = text_ids[i]; + } + + // Input 1: Style Output (TTL) + int ttl_size_0 = static_cast(style_ttl.size()); + int ttl_size_1 = static_cast(style_ttl[0].size()); + inputs[1] = MNN::Express::_Input({1, ttl_size_0, ttl_size_1}, NCHW, + halide_type_of()); + auto ptr1 = inputs[1]->writeMap(); + for (int i = 0; i < ttl_size_0; ++i) + { + for (int j = 0; j < ttl_size_1; ++j) + { + ptr1[i * ttl_size_1 + j] = style_ttl[i][j]; + } + } + + // Input 2: Text Mask + inputs[2] = + MNN::Express::_Input({1, 1, num_tokens}, NCHW, halide_type_of()); + auto ptr2 = inputs[2]->writeMap(); + for (int i = 0; i < num_tokens; ++i) + { + ptr2[i] = text_mask[i]; + } + + // Run Inference + std::vector outputs = te_module_->onForward(inputs); + + if (outputs.empty()) + { + PLOG(ERROR, "Text Encoder returned empty output"); + return {}; + } + + auto output = outputs[0]; + auto size = output->getInfo()->size; + std::vector result(size); + ::memcpy(result.data(), output->readMap(), size * sizeof(float)); + + return result; +} + +std::vector MNNSupertonicTTSImpl::estimateVector( + const std::vector &noisy_latent, const std::vector &text_emb, + const std::vector &style_ttl, const std::vector &latent_mask, + const std::vector &text_mask, int current_step, int total_step) +{ + + 
MNN::Express::ExecutorScope scope(executor_); + + // Prepare 7 inputs + std::vector inputs(7); + + // Shapes derived from mnn_estimator.cpp + int total_size = static_cast(noisy_latent.size()); + int latent_dim = ldim_ * chunk_compress_factor_; + int latent_len = total_size / latent_dim; + + // Input 0: noisy_latent {1, latent_dim, latent_len} NCHW + inputs[0] = MNN::Express::_Input({1, latent_dim, latent_len}, NCHW, + halide_type_of()); + ::memcpy(inputs[0]->writeMap(), noisy_latent.data(), + total_size * sizeof(float)); + + // Input 1: text_emb {1, channels, text_len} + int text_len = static_cast(text_mask.size()); + int text_channels = text_emb.size() / text_len; + + inputs[1] = MNN::Express::_Input({1, text_channels, text_len}, NCHW, + halide_type_of()); + ::memcpy(inputs[1]->writeMap(), text_emb.data(), + text_emb.size() * sizeof(float)); + + // Input 2: style_ttl {1, num_style_vectors, 256} NCHW + int style_len = static_cast(style_ttl.size()); + int style_dim = 256; + int num_style_vectors = style_len / style_dim; + + inputs[2] = MNN::Express::_Input({1, num_style_vectors, style_dim}, NCHW, + halide_type_of()); + ::memcpy(inputs[2]->writeMap(), style_ttl.data(), + style_ttl.size() * sizeof(float)); + + // Input 3: latent_mask {1, 1, latent_len} + inputs[3] = + MNN::Express::_Input({1, 1, latent_len}, NCHW, halide_type_of()); + auto ptr3 = inputs[3]->writeMap(); + for (int i = 0; i < latent_len; ++i) + ptr3[i] = static_cast(latent_mask[i]); + + // Input 4: text_mask {1, 1, text_len} + inputs[4] = + MNN::Express::_Input({1, 1, text_len}, NCHW, halide_type_of()); + ::memcpy(inputs[4]->writeMap(), text_mask.data(), + text_mask.size() * sizeof(float)); + + // Input 5: current_step {1} + inputs[5] = MNN::Express::_Input({1}, NCHW, halide_type_of()); + inputs[5]->writeMap()[0] = static_cast(current_step); + + // Input 6: total_step {1} + inputs[6] = MNN::Express::_Input({1}, NCHW, halide_type_of()); + inputs[6]->writeMap()[0] = static_cast(total_step); + + auto 
outputs = ve_module_->onForward(inputs); + + if (outputs.empty()) + { + PLOG(ERROR, "Vector Estimator returned empty output"); + return {}; + } + + auto output = outputs[0]; + auto size = output->getInfo()->size; + std::vector result(size); + ::memcpy(result.data(), output->readMap(), size * sizeof(float)); + return result; +} + +std::vector +MNNSupertonicTTSImpl::vocode(const std::vector &latent) +{ + MNN::Express::ExecutorScope scope(executor_); + + // Shape: {1, 144, latent_len} + // Latent dim 144. + int total_size = static_cast(latent.size()); + int dim = 144; + int len = total_size / dim; + + std::vector inputs(1); + inputs[0] = + MNN::Express::_Input({1, dim, len}, NCHW, halide_type_of()); + ::memcpy(inputs[0]->writeMap(), latent.data(), + total_size * sizeof(float)); + + auto outputs = vc_module_->onForward(inputs); + if (outputs.empty()) + return {}; + + auto output = outputs[0]; + auto size = output->getInfo()->size; + std::vector result(size); + ::memcpy(result.data(), output->readMap(), size * sizeof(float)); + return result; +} + +// std::tuple MNNSupertonicTTSImpl::Process(const std::string &text) +// { +// // Simplified Process implementation for compatibility with base class +// // interface +// VoiceStyle default_voice({{{0.1f, 0.2f, 0.3f}}}, {{{0.4f, 0.5f, 0.6f}}}); +// auto [audio_float, sample_rate, rtf] = +// synthesize(text, default_voice, 10, 1.0f); + +// // Convert float to int16_t +// Audio audio_int16(audio_float.size()); +// for (size_t i = 0; i < audio_float.size(); ++i) { +// audio_int16[i] = static_cast(audio_float[i] * 32767.0f); +// } + +// return std::make_tuple(sample_rate, audio_int16); +// } + +// Initialize MNN Models +void MNNSupertonicTTSImpl::initializeModels() +{ + PLOG(INFO, "Initializing models..."); + + // Set up Runtime (Executor) with Low Precision config + MNN::BackendConfig backendConfig; + backendConfig.precision = MNN::BackendConfig::Precision_Low; + backendConfig.memory = MNN::BackendConfig::Memory_Low; + + // 
Create Executor once + executor_ = std::shared_ptr( + MNN::Express::Executor::newExecutor(MNN_FORWARD_CPU, backendConfig, 4)); + + MNN::Express::ExecutorScope scope(executor_); + + // Load Models directly using Module::load + auto loadModule = [&](const std::string &filename, const std::string &name) + { + std::string path = + models_dir_ + "/mnn_models/" + precision_dir_ + "/" + filename; + std::vector inputs, + outputs; // Empty for auto-detection or not needed for load + auto module = std::shared_ptr( + MNN::Express::Module::load(inputs, outputs, path.c_str())); + if (!module) + { + PLOG(ERROR, "Failed to load " + name + ": " + path); + throw std::runtime_error("Failed to load model: " + path); + } + PLOG(INFO, "Successfully loaded " + name); + return module; + }; + + dp_module_ = loadModule("duration_predictor.mnn", "Duration Predictor"); + te_module_ = loadModule("text_encoder.mnn", "Text Encoder"); + ve_module_ = loadModule("vector_estimator.mnn", "Vector Estimator"); + vc_module_ = loadModule("vocoder.mnn", "Vocoder"); +} \ No newline at end of file From cafcadb9a312633109cf7359bbb54bcfdc47d053 Mon Sep 17 00:00:00 2001 From: zlaazlaa <2889827787@qq.com> Date: Thu, 18 Dec 2025 14:36:16 +0800 Subject: [PATCH 015/314] fix(diffusion): simplify export logic and fix dynamic axes Signed-off-by: zlaazlaa <2889827787@qq.com> --- docs/transformers/diffusion.md | 3 +- transformers/diffusion/export/onnx_export.py | 30 ++++++-------------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/docs/transformers/diffusion.md b/docs/transformers/diffusion.md index 7de27bb216..609793f806 100644 --- a/docs/transformers/diffusion.md +++ b/docs/transformers/diffusion.md @@ -20,7 +20,8 @@ https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1/tree/mai cd mnn_path/transformers/diffusion/export python onnx_export.py \ --model_path hf_sd_load_path \ - --output_path onnx_save_path + --output_path onnx_save_path \ + --opset 18 ``` 
注意,上述脚本需要依赖torch/onnx/diffusers等库,可以安装conda环境: ``` diff --git a/transformers/diffusion/export/onnx_export.py b/transformers/diffusion/export/onnx_export.py index 21f05e83be..5516eb2fcc 100644 --- a/transformers/diffusion/export/onnx_export.py +++ b/transformers/diffusion/export/onnx_export.py @@ -84,7 +84,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F num_tokens = pipeline.text_encoder.config.max_position_embeddings text_hidden_size = pipeline.text_encoder.config.hidden_size text_input = pipeline.tokenizer( - "A sample prompt", + ["A sample prompt", "A sample prompt"], padding="max_length", max_length=pipeline.tokenizer.model_max_length, truncation=True, @@ -97,9 +97,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path=output_path / "text_encoder" / "model.onnx", ordered_input_names=["input_ids"], output_names=["last_hidden_state", "pooler_output"], - dynamic_axes={ - "input_ids": {0: "batch", 1: "sequence"}, - }, + dynamic_axes=None, opset=opset, ) del pipeline.text_encoder @@ -117,13 +115,9 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F # False, ), output_path=unet_path, - ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], + ordered_input_names=["sample", "timestep", "encoder_hidden_states"], output_names=["out_sample"], # has to be different from "sample" for correct tracing - dynamic_axes={ - "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - "timestep": {0: "batch"}, - "encoder_hidden_states": {0: "batch", 1: "sequence"}, - }, + dynamic_axes=None, opset=opset, use_external_data_format=True, # UNet is > 2GB, so the weights need to be split ) @@ -149,7 +143,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F vae_in_channels = vae_encoder.config.in_channels vae_sample_size = vae_encoder.config.sample_size # need to get the raw tensor output (sample) from the 
encoder - vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample() + vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].mode() onnx_export( vae_encoder, model_args=( @@ -159,30 +153,24 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path=output_path / "vae_encoder" / "model.onnx", ordered_input_names=["sample", "return_dict"], output_names=["latent_sample"], - dynamic_axes={ - "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - }, + dynamic_axes=None, opset=opset, ) # VAE DECODER vae_decoder = pipeline.vae vae_latent_channels = vae_decoder.config.latent_channels - vae_out_channels = vae_decoder.config.out_channels # forward only through the decoder part - vae_decoder.forward = vae_encoder.decode + vae_decoder.forward = lambda latent: vae_decoder.decode(latent, return_dict=False)[0] onnx_export( vae_decoder, model_args=( torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), - False, ), output_path=output_path / "vae_decoder" / "model.onnx", - ordered_input_names=["latent_sample", "return_dict"], + ordered_input_names=["latent_sample"], output_names=["sample"], - dynamic_axes={ - "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - }, + dynamic_axes=None, opset=opset, ) del pipeline.vae From f31132f9d7f51c239f5c7f6b2dfe645d683a2e00 Mon Sep 17 00:00:00 2001 From: "zhaode.wzd" Date: Fri, 19 Dec 2025 09:45:24 +0800 Subject: [PATCH 016/314] [Sync] Update MNN to 3.3.1 and Sync Internal Improvements - Version: Upgrade to 3.3.1. - LLM: Refactor with Transformer-like config/model/tokenizer; optimize Tokenizer load speed; enhance KV Cache (mmap, disk storage, prefix cache); fix LoRA and VL model bugs. - CPU/SME: Support SME/SME2 instructions and mixed NEON/SME multi-threading; optimize RVV intrinsics and FP32 depth-wise kernels. 
- Metal/OpenCL: Support Metal KV Cache mmap and GPU family identification; refactor OpenCL Loop operators and fix offset bugs. - QNN/Android: Optimize QNN build/conversion; update Reranker API; fix Android packaging and SO loading issues. - General: Fix compilation for Arm32/Win/MinGW; update gitignore and expose public symbols; revert StrideSliceWrite behavior. --- .gitignore | 5 +- CMakeLists.txt | 9 +- cmake/KleidiAI.cmake | 21 +- docs/compile/other.md | 1 + docs/transformers/llm.md | 17 +- include/MNN/Interpreter.hpp | 22 +- include/MNN/MNNDefine.h | 6 +- project/android/qnnprepare.gradle | 28 +- pymnn/src/reranker.h | 11 + source/backend/arm82/Arm82Functions.cpp | 1244 ++++++++------ ...MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16.S | 109 +- ...GemmInt8AddBiasScale_ARMV82_w4_Unit_FP16.S | 106 +- .../sme2_asm/MNNPackedMatMulRemainFP16_SME2.S | 185 ++- source/backend/cpu/CPUAttention.cpp | 868 ++++++---- source/backend/cpu/CPUAttention.hpp | 29 +- source/backend/cpu/CPUBackend.cpp | 14 +- source/backend/cpu/CPUConvolution.cpp | 8 +- source/backend/cpu/CPUConvolution.hpp | 2 +- .../backend/cpu/CPUConvolutionDepthwise.cpp | 21 + source/backend/cpu/CPUKVCacheManager.cpp | 795 +++++++++ source/backend/cpu/CPUKVCacheManager.hpp | 140 ++ source/backend/cpu/CPUMatMul.cpp | 14 +- source/backend/cpu/CPURaster.cpp | 8 +- source/backend/cpu/CPURuntime.cpp | 2 + source/backend/cpu/CPURuntime.hpp | 1 + source/backend/cpu/CPUSoftmax.cpp | 10 +- source/backend/cpu/CPUTensorConvert.cpp | 2 +- source/backend/cpu/KVCacheManager.cpp | 758 --------- source/backend/cpu/KVCacheManager.hpp | 172 -- .../cpu/KleidiAIConvolutionDepthwise.cpp | 168 ++ .../cpu/KleidiAIConvolutionDepthwise.hpp | 37 + .../backend/cpu/arm/CommonOptFunctionNeon.cpp | 533 ++++-- .../MNNGemmInt8AddBiasScale_ARMV82_Unit.S | 159 +- .../MNNGemmInt8AddBiasScale_ARMV82_w4_Unit.S | 144 +- ...NNGemmInt8AddBiasScale16x32_SME2_w4_Fp16.S | 280 +++- ...NNGemmInt8AddBiasScale16x32_SME2_w4_Fp32.S | 445 ++++- 
...NNGemmInt8AddBiasScale16x32_SME2_w8_Fp16.S | 278 +++- ...NNGemmInt8AddBiasScale16x32_SME2_w8_Fp32.S | 762 ++++++++- ...NNGemmInt8AddBiasScaleHp128_SME2_w4_Fp16.S | 1 + ...NNGemmInt8AddBiasScaleHp128_SME2_w8_Fp16.S | 2 + .../sme2_asm/MNNPackedMatMulRemainFP32_SME2.S | 452 ++++- .../backend/cpu/compute/CommonOptFunction.cpp | 379 ++++- .../backend/cpu/compute/CommonOptFunction.h | 32 +- .../cpu/compute/ConvInt8TiledExecutor.cpp | 816 +++++++-- .../cpu/compute/ConvInt8TiledExecutor.hpp | 12 +- .../cpu/compute/ConvolutionTiledExecutor.cpp | 2 +- .../backend/cpu/compute/Int8FunctionsOpt.cpp | 86 +- source/backend/cpu/compute/Int8FunctionsOpt.h | 28 +- source/backend/cpu/riscv/rvv/MNNMatrixAdd.cpp | 26 + source/backend/cpu/riscv/rvv/MNNMatrixMax.cpp | 26 + source/backend/cpu/riscv/rvv/MNNMatrixSub.cpp | 26 + source/backend/cpu/x86_x64/AVX2Functions.cpp | 2 +- .../cpu/x86_x64/FunctionDispatcher.cpp | 9 +- .../cpu/x86_x64/avx/FunctionSummary.hpp | 2 +- source/backend/cpu/x86_x64/avx/GemmAVX2.cpp | 10 +- source/backend/cpu/x86_x64/avx/GemmInt8.cpp | 41 +- .../backend/cpu/x86_x64/avx/MathFunctions.cpp | 149 +- .../cpu/x86_x64/avx/PackedFunction.cpp | 9 +- .../cpu/x86_x64/avx/ReorderFunctions.cpp | 12 +- .../cpu/x86_x64/avx512/PackedFunction.cpp | 138 +- .../cpu/x86_x64/avx512/ReorderFunctions.cpp | 12 +- .../cpu/x86_x64/sse/FunctionSummary.hpp | 1 - source/backend/cpu/x86_x64/sse/GemmSSE.cpp | 6 +- .../backend/cpu/x86_x64/sse/MathFunctions.cpp | 65 - source/backend/metal/ConvSimdGroupShader.hpp | 1454 ++++++++++++----- source/backend/metal/MetalAttention.mm | 250 ++- source/backend/metal/MetalAttentionShader.hpp | 636 +++++-- source/backend/metal/MetalBackend.hpp | 10 +- source/backend/metal/MetalBackend.mm | 35 +- source/backend/metal/MetalBinary.mm | 4 +- source/backend/metal/MetalCast.mm | 12 +- source/backend/metal/MetalConvolution1x1.hpp | 4 + source/backend/metal/MetalConvolution1x1.mm | 197 ++- source/backend/metal/MetalKVCacheManager.hpp | 65 + 
source/backend/metal/MetalKVCacheManager.mm | 336 ++++ source/backend/metal/MetalRaster.mm | 20 +- source/backend/metal/MetalUnary.mm | 4 +- .../execution/buffer/LoopBufExecution.cpp | 932 +++++++---- .../execution/buffer/LoopBufExecution.hpp | 59 +- .../execution/buffer/UnaryBufExecution.cpp | 3 +- .../backend/opencl/execution/cl/gather_buf.cl | 46 - .../opencl/execution/cl/gather_buf_mnn_cl.cpp | 50 - source/backend/opencl/execution/cl/loop.cl | 208 ++- .../backend/opencl/execution/cl/loop_buf.cl | 469 ------ .../opencl/execution/cl/loop_buf_mnn_cl.cpp | 463 ------ .../opencl/execution/cl/loop_mnn_cl.cpp | 196 ++- .../opencl/execution/cl/opencl_source_map.hpp | 16 +- .../opencl/execution/image/LoopExecution.cpp | 1367 ++++++++-------- .../opencl/execution/image/LoopExecution.hpp | 65 +- source/backend/qnn/CMakeLists.txt | 10 +- source/core/Backend.hpp | 27 +- source/core/KVCacheManager.cpp | 113 ++ source/core/KVCacheManager.hpp | 96 ++ source/core/OpCommonUtils.hpp | 10 + source/core/Session.cpp | 9 + source/core/TensorUtils.hpp | 2 +- source/geometry/GeometryStridedSlice.cpp | 4 +- test/main.cpp | 6 + test/op/AttentionTest.cpp | 155 +- test/op/StridedSliceTest.cpp | 178 ++ test/speed/HybridConvSpeedTest.cpp | 46 +- tools/converter/source/common/cli.cpp | 12 +- tools/cpp/CMakeLists.txt | 4 +- tools/cpp/MNN2QNNModel.cpp | 2 +- tools/cpp/ModuleBasic.cpp | 14 +- tools/cpp/compilefornpu.cpp | 1 - tools/script/arm2binary.py | 88 +- transformers/llm/config.json | 5 +- transformers/llm/engine/CMakeLists.txt | 29 +- transformers/llm/engine/demo/llm_demo.cpp | 1 + .../llm/engine/demo/rollback_demo.cpp | 159 +- transformers/llm/engine/include/llm/llm.hpp | 7 +- .../llm/engine/include/llm/reranker.hpp | 24 +- transformers/llm/engine/src/kvmeta.hpp | 16 +- transformers/llm/engine/src/llm.cpp | 128 +- transformers/llm/engine/src/llmconfig.hpp | 8 +- transformers/llm/engine/src/omni.cpp | 5 + transformers/llm/engine/src/tokenizer.cpp | 335 ++-- 
transformers/llm/engine/src/tokenizer.hpp | 20 +- transformers/llm/engine/tools/llm_bench.cpp | 151 +- transformers/llm/eval/evaluate_perplexity.py | 13 +- transformers/llm/export/llmexport.py | 1171 ++----------- transformers/llm/export/utils/audio.py | 13 +- .../llm/export/utils/awq_quantizer.py | 39 +- transformers/llm/export/utils/config.py | 116 ++ transformers/llm/export/utils/eagle.py | 16 +- .../llm/export/utils/mnn_converter.py | 64 +- transformers/llm/export/utils/model.py | 380 +++++ transformers/llm/export/utils/model_mapper.py | 162 +- transformers/llm/export/utils/mtp.py | 29 +- .../llm/export/utils/smooth_quantizer.py | 9 +- transformers/llm/export/utils/talker.py | 13 +- transformers/llm/export/utils/token2wav.py | 1 - transformers/llm/export/utils/tokenizer.py | 376 +++++ transformers/llm/export/utils/transformers.py | 32 +- transformers/llm/export/utils/vision.py | 15 +- 136 files changed, 13586 insertions(+), 7187 deletions(-) create mode 100644 source/backend/cpu/CPUKVCacheManager.cpp create mode 100644 source/backend/cpu/CPUKVCacheManager.hpp delete mode 100644 source/backend/cpu/KVCacheManager.cpp delete mode 100644 source/backend/cpu/KVCacheManager.hpp create mode 100644 source/backend/cpu/KleidiAIConvolutionDepthwise.cpp create mode 100644 source/backend/cpu/KleidiAIConvolutionDepthwise.hpp create mode 100644 source/backend/cpu/riscv/rvv/MNNMatrixAdd.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNMatrixMax.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNMatrixSub.cpp create mode 100644 source/backend/metal/MetalKVCacheManager.hpp create mode 100644 source/backend/metal/MetalKVCacheManager.mm delete mode 100644 source/backend/opencl/execution/cl/gather_buf.cl delete mode 100644 source/backend/opencl/execution/cl/gather_buf_mnn_cl.cpp delete mode 100644 source/backend/opencl/execution/cl/loop_buf.cl delete mode 100644 source/backend/opencl/execution/cl/loop_buf_mnn_cl.cpp create mode 100644 source/core/KVCacheManager.cpp create 
mode 100644 source/core/KVCacheManager.hpp create mode 100644 transformers/llm/export/utils/config.py create mode 100644 transformers/llm/export/utils/model.py create mode 100644 transformers/llm/export/utils/tokenizer.py diff --git a/.gitignore b/.gitignore index 66f35d8e0a..a8edaf25f6 100644 --- a/.gitignore +++ b/.gitignore @@ -380,4 +380,7 @@ apps/mnncli/model_market_json_data.inc #kledi _deps #aicoding -.cursor \ No newline at end of file +.cursor + +# llm model +transformers/llm/export/model/ \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d942aec59..67502b606b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -277,7 +277,7 @@ if (NOT MNN_CUDA OR NOT CMAKE_SYSTEM_NAME MATCHES "^Linux") set(MNN_CUDA_PROFILE OFF) endif() -if (NOT MNN_QNN) +if (NOT MNN_QNN) set(MNN_QNN_ONLINE_FINALIZE OFF) endif() @@ -557,10 +557,13 @@ ENDIF() find_package(Threads) list(APPEND MNN_EXTRA_DEPENDS ${CMAKE_THREAD_LIBS_INIT}) if(WIN32) - if(NOT MSVC) + if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -fuse-ld=lld-link -lmsvcrt") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fuse-ld=lld-link -lmsvcrt") - else() + elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -lmsvcrt") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -lmsvcrt") + elseif(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /STACK:8388608") endif() endif() diff --git a/cmake/KleidiAI.cmake b/cmake/KleidiAI.cmake index c35f324b74..68072cf058 100644 --- a/cmake/KleidiAI.cmake +++ b/cmake/KleidiAI.cmake @@ -90,7 +90,9 @@ function (download_kleidiai_and_collect_sources) ${KLEIDIAI_SRC_DIR}/kai/ukernels/matmul/matmul_clamp_f16_f16p_f16p/ ${KLEIDIAI_SRC_DIR}/kai/ukernels/matmul/matmul_clamp_f16_f16_f16p/ ${KLEIDIAI_SRC_DIR}/kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/ - 
${KLEIDIAI_SRC_DIR}/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/) + ${KLEIDIAI_SRC_DIR}/kai/ukernels/matmul/imatmul_clamp_f16_f16p_f16p/ + ${KLEIDIAI_SRC_DIR}/kai/ukernels/dwconv/pack/ + ${KLEIDIAI_SRC_DIR}/kai/ukernels/dwconv/dwconv_f32_f32_f32p/) file(GLOB kleidiai_pack_sources "${KLEIDIAI_SRC_DIR}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32pscalef32_f16_neon.c" @@ -196,6 +198,17 @@ function (download_kleidiai_and_collect_sources) ) list(APPEND KLEIDIAI_FILES_SME2 ${matmul_clamp_f32_qsi8d32p_qai4c32p_sme2_sources}) + file(GLOB dwconv_pack_sources + "${KLEIDIAI_SRC_DIR}/kai/ukernels/dwconv/pack/*.c" + ) + list(APPEND KLEIDIAI_FILES_SME2 ${dwconv_pack_sources}) + + file(GLOB dwconv_f32_f32_f32p_sme2_sources + "${KLEIDIAI_SRC_DIR}/kai/ukernels/dwconv/dwconv_f32_f32_f32p/*.c" + "${KLEIDIAI_SRC_DIR}/kai/ukernels/dwconv/dwconv_f32_f32_f32p/*.S" + ) + list(APPEND KLEIDIAI_FILES_SME2 ${dwconv_f32_f32_f32p_sme2_sources}) + set_source_files_properties( ${MNN_SOURCES_KLEIDIAI} PROPERTIES COMPILE_OPTIONS @@ -208,6 +221,8 @@ function (download_kleidiai_and_collect_sources) set(MNN_SOURCES_KLEIDIAI "${MNN_SOURCES_KLEIDIAI}" PARENT_SCOPE) set(KLEIDIAI_FILES_SME2 "${KLEIDIAI_FILES_SME2}" PARENT_SCOPE) - # Define macro to indicate KleidiAI is enabled - add_definitions(-DMNN_KLEIDIAI_ENABLED=1) + # Define macro to indicate KleidiAI is enabled (only on aarch64) + if(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64)") + add_definitions(-DMNN_KLEIDIAI_ENABLED=1) + endif() endfunction() diff --git a/docs/compile/other.md b/docs/compile/other.md index b3edb55358..31a2734337 100644 --- a/docs/compile/other.md +++ b/docs/compile/other.md @@ -63,6 +63,7 @@ - `llm_demo` 大语言模型推理示例程序 - `diffusion_demo` 扩散模型示例程序 - `llm_bench` 大语言模型测评工具 + - `rollback_demo` 大语言模型kvcache回调示例工具 - `quantize_llm` 大语言模型feature map量化工具 ## 测试工具 - 相关编译选项 diff --git a/docs/transformers/llm.md b/docs/transformers/llm.md index fa4f60a851..e2ec2c94c3 100644 --- a/docs/transformers/llm.md +++ 
b/docs/transformers/llm.md @@ -403,12 +403,14 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt - thread_num: CPU推理使用硬件线程数,默认为:`4`; OpenCL推理时使用`68`(不是传统意义的线程数,代表的是opencl buffer存储和tuning wide模式) - precision: 推理使用精度策略,默认为:`"low"`,尽量使用`fp16` - memory: 推理使用内存策略,默认为:`"low"`,开启运行时量化 -- 与CPU动态量化相关的配置 - - dynamic_option: 推理时是否对feature map分blocksize/group进行量化。可选为:`0, 1, 2`,默认是`0`,含义如下: +- 与CPU动态量化相关的配置,提升精度、性能 + - dynamic_option: 推理时是否对feature map分blocksize/group进行量化。可选为:`0, 1, 2, 8, 9, 10`,默认是`0`,含义如下: - 0: feature map数据使用per channel量化 - 1: feature map数据使用per tensor量化 - 2: feature map数据用per block量化,blocksize等于权重量化时的blocksize,如果权重量化时没有使用per block量化,即使设置2,也不会对feature map做per block量化 - - prefer_decode: 是否希望有更快的解码(Decode)速度。可选:`true, false`,默认`false`。注意:当prompt长度小于300时,`true`条件下的Prefill速度会显著慢于`false`条件下时的性能。当prompt长度高于300时,`true`条件下的Prefill速度和`false`条件基本持平,Decode速度大约会快20%. 如果你希望在各种情况下Prefill速度和Decode速度更加均衡,建议设置该选项为`false`. + - 8+n(n=0,1,2): 该选项是为了加速LLM 推理时Decode性能。但是当prompt长度小于300时,Prefill速度会显著变慢。当prompt长度高于300时,Prefill速度不会变慢。 + - cpu_sme2_neon_division_ratio: 为了提高Arm SME后端多线程推理时性能,可根据模型、线程数定制化设置该参数。参数计算方式: Prefill阶段单个SME核和NEON核的工作量比例x:1,Decode阶段工作量比例y:1, + 则参数设置为8*x+y,x和y均是不大于7的正整数。41、49和33是常见的参数设置. 可以通过观察单线程推理时,SME后端相较于NEON后端的加速比来决定该参数的取值。默认是`41`. 
- Sampler配置 - sampler_type: 使用的sampler种类,目前支持`greedy`, `temperature`, `topK`, `topP`, `minP`, `tfs`, `typical`, `penalty`8种基本sampler,外加`mixed`(混合sampler,当选择`mixed`时,依次执行mixed_samplers中的sampler)。默认为`greedy`,但是建议使用`mixed`、`temperature`来增加输出多样性,或使用`penalty`来降低重复。 - mixed_samplers: 当`sampler_type`为`mixed`时有效,默认为`["topK", "tfs", "typical", "topP", "min_p", "temperature"]`, 模型计算得到的logits会依次经过这些sampler采样。 @@ -552,6 +554,15 @@ options: ./llm_bench -m ./Qwen2.5-1.5B-Instruct/config.json,./Qwen2.5-0.5B-Instruct/config.json -a cpu,opencl,metal -c 1,2 -t 8,12 -p 16,32 -n 10,20 -pg 8,16 -mmp 0 -rep 4 -kv true -fp ./test_result ``` +#### 多Prompt场景下KVCache选择性复用 +rollback_demo提供了多Prompt场景下自行选择复用部分kvcache的示例代码。 +```bash +./rollback_demo /path/to/model_dir/config.json /path/to/prompt.txt +``` +其中,prompt.txt需要包含至少三组prompt。 +- cache_prefix_in_disk需要设置为0或1。 +- cache_prefix_in_disk 设置1表示:第一段Prompt是后续Prompt的公共前缀Prompt,第二、三段Prompt分别是基于第一段Prompt后续的文本内容。第一次启动会将前缀Prompt的KVCache缓存在磁盘文件中。第二次启动会跳过公共前缀Prompt的Prefill,直接在磁盘中加载,提升Prefill速度。。 +- cache_prefix_in_disk 设置0表示:在多段Prompt下,如何删除不需要的KVCache,仅保留关联性的KVCache示例。 #### GPTQ权重 需要使用GPTQ权重,可以在导出模型时,使用`--gptq_path PATH`来指定的路径,使用如下: diff --git a/include/MNN/Interpreter.hpp b/include/MNN/Interpreter.hpp index d25af2e5a4..30a9a8a7af 100644 --- a/include/MNN/Interpreter.hpp +++ b/include/MNN/Interpreter.hpp @@ -226,11 +226,14 @@ class MNN_PUBLIC Interpreter { // Default is 50 CPU_LITTLECORE_DECREASE_RATE = 6, + // qkvQuantOption % 8: // 0: Do not quantize - // 1: Only quantize key, use int8 asymmetric quantization - // 2: Only quantize value, use fp8 quantization - // 3: quantize both key and value - // 4: quantize query, key and value, and use gemm int8 kernel to compute K*V + // 1: Q,K: Int8, V: Float + // 2: Q,K,V: Int8 + + // qkvQuantOption / 8: + // 0: don't use flash attention + // 1: use flash attention QKV_QUANT_OPTIONS = 7, // size limit of kvcache in memory (for a single layer) @@ -255,7 +258,13 @@ class MNN_PUBLIC Interpreter { 
CPU_SME2_INSTRUCTIONS = 15, // Enable KleidiAI - CPU_ENABLE_KLEIDIAI = 16 + CPU_ENABLE_KLEIDIAI = 16, + + // Set CPU SME2 NEON division ratio, default is 41 + CPU_SME2_NEON_DIVISION_RATIO = 17, + + // Set SME cores, default is 2, if supports sme + CPU_SME_CORES = 18 }; enum ExternalPathType { @@ -271,6 +280,9 @@ class MNN_PUBLIC Interpreter { // Path of the NPU Model directory EXTERNAL_NPU_FILE_DIR = 3, + // Path of the kvcache directory + EXTERNAL_PATH_PREFIXCACHE_DIR = 4, + // Other types ... }; diff --git a/include/MNN/MNNDefine.h b/include/MNN/MNNDefine.h index b8e391e3eb..b7e4eac092 100644 --- a/include/MNN/MNNDefine.h +++ b/include/MNN/MNNDefine.h @@ -59,7 +59,6 @@ if(!(success)){ \ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \ } -#ifndef MNN_BUILD_STATIC_LIBS #if defined(_MSC_VER) #if defined(BUILDING_MNN_DLL) #define MNN_PUBLIC __declspec(dllexport) @@ -71,13 +70,10 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \ #else #define MNN_PUBLIC __attribute__((visibility("default"))) #endif -#else -#define MNN_PUBLIC -#endif #define STR_IMP(x) #x #define STR(x) STR_IMP(x) #define MNN_VERSION_MAJOR 3 #define MNN_VERSION_MINOR 3 -#define MNN_VERSION_PATCH 0 +#define MNN_VERSION_PATCH 1 #define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." 
STR(MNN_VERSION_PATCH) #endif /* MNNDefine_h */ diff --git a/project/android/qnnprepare.gradle b/project/android/qnnprepare.gradle index da0c6538ef..20efe0bf6b 100644 --- a/project/android/qnnprepare.gradle +++ b/project/android/qnnprepare.gradle @@ -4,7 +4,7 @@ // QNN configuration ext { // QNN download settings - QNN_LIBS_URL = 'http://meta.alicdn.com/data/mnn/libs/qnn_inc_libs.zip' + QNN_LIBS_URL = 'http://meta.alicdn.com/data/mnn/libs/qnn_inc_libs_2_37.zip' } def qnnZipName = 'qnn_inc_libs.zip' @@ -67,28 +67,24 @@ task prepareQnnDeps { } def findQnnDirectory(File searchDir) { - // Look for directory containing both include and jniLibs (or lib) directories - def candidates = [] + // Mirror the shell script's approach: find the 'include' directory + // anywhere in the hierarchy, then use its parent as the QNN root. + // This is more reliable than requiring both include and lib to be present + // at the same level. + + def foundIncludeDir = null searchDir.eachDirRecurse { dir -> - def hasInclude = new File(dir, 'include').exists() - def hasLibs = new File(dir, 'jniLibs').exists() || new File(dir, 'lib').exists() - - if (hasInclude && hasLibs) { - candidates.add(dir) + if (foundIncludeDir == null && dir.name == 'include') { + foundIncludeDir = dir } } - if (candidates.isEmpty()) { - // Fallback: look for directory that contains include - searchDir.eachDirRecurse { dir -> - if (new File(dir, 'include').exists()) { - candidates.add(dir) - } - } + if (foundIncludeDir != null) { + return foundIncludeDir.parentFile } - return candidates.isEmpty() ? 
null : candidates[0] + return null } def copyQnnFiles(File sourceDir) { diff --git a/pymnn/src/reranker.h b/pymnn/src/reranker.h index 641ca9e3c7..d369cf4921 100644 --- a/pymnn/src/reranker.h +++ b/pymnn/src/reranker.h @@ -54,6 +54,16 @@ static PyObject* PyMNNReranker_setInstruct(Reranker *self, PyObject *args) { Py_RETURN_NONE; } +static PyObject* PyMNNReranker_load(Reranker *self, PyObject *args) { + if (!self->reranker) { + PyErr_SetString(PyExc_RuntimeError, "Reranker not initialized"); + Py_RETURN_NONE; + } + + self->reranker->load(); + Py_RETURN_NONE; +} + static PyObject* PyMNNReranker_compute_scores(Reranker *self, PyObject *args) { if (!self->reranker) { PyErr_SetString(PyExc_RuntimeError, "Reranker not initialized"); @@ -118,6 +128,7 @@ static PyObject* PyMNNReranker_get_llm(Reranker *self, PyObject *args) { static PyMethodDef PyMNNReranker_methods[] = { {"set_instruct", (PyCFunction)PyMNNReranker_setInstruct, METH_VARARGS, "Set instruction for the reranker."}, + {"load", (PyCFunction)PyMNNReranker_load, METH_VARARGS, "Load the reranker model."}, {"compute_scores", (PyCFunction)PyMNNReranker_compute_scores, METH_VARARGS, "Compute scores for documents given a query."}, {"get_llm", (PyCFunction)PyMNNReranker_get_llm, METH_VARARGS, "Get the underlying LLM instance for configuration."}, {NULL} /* Sentinel */ diff --git a/source/backend/arm82/Arm82Functions.cpp b/source/backend/arm82/Arm82Functions.cpp index c9b4a3ee90..4647871e47 100644 --- a/source/backend/arm82/Arm82Functions.cpp +++ b/source/backend/arm82/Arm82Functions.cpp @@ -443,15 +443,17 @@ void ARM82StrassenMerge(FLOAT16* c11, FLOAT16* c12, FLOAT16* c21, FLOAT16* c22, } void MNNUnpackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, size_t depth, int32_t* areaOffset) { + // [depth/8, srcAreaOffset, 8] -> [area, dstAreaOffset] int srcAreaOffset = areaOffset[0]; + int dstAreaOffset = areaOffset[1]; int c = (int)depth; - int cDiv4 = c / 8; - int cAlign = cDiv4 * 8; + int cDiv8 = c / 8; + 
int cAlign = cDiv8 * 8; int areaDiv4 = area / 4; int areaAlign = areaDiv4 * 4; if (areaAlign > 0) { - for (int ci = 0; ci < cDiv4; ++ci) { + for (int ci = 0; ci < cDiv8; ++ci) { auto srcH = src + ci * 8 * srcAreaOffset; auto dstH = dst + ci * 8; for (int hi = 0; hi < areaAlign; hi+=4) { @@ -460,10 +462,10 @@ void MNNUnpackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, si auto src2 = srcH + hi * 8 + 16; auto src3 = srcH + hi * 8 + 24; - auto dst0 = dstH + hi * c; - auto dst1 = dstH + hi * c + c; - auto dst2 = dstH + hi * c + 2 * c; - auto dst3 = dstH + hi * c + 3 * c; + auto dst0 = dstH + hi * dstAreaOffset; + auto dst1 = dstH + hi * dstAreaOffset + dstAreaOffset; + auto dst2 = dstH + hi * dstAreaOffset + 2 * dstAreaOffset; + auto dst3 = dstH + hi * dstAreaOffset + 3 * dstAreaOffset; vst1q_s16(dst0, vld1q_s16(src0)); vst1q_s16(dst1, vld1q_s16(src1)); vst1q_s16(dst2, vld1q_s16(src2)); @@ -472,12 +474,12 @@ void MNNUnpackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, si } } if (areaAlign < area) { - for (int ci = 0; ci < cDiv4; ++ci) { + for (int ci = 0; ci < cDiv8; ++ci) { auto srcH = src + 8 * ci * srcAreaOffset; auto dstH = dst + ci * 8; for (int hi = areaAlign; hi < area; ++hi) { auto src0 = srcH + hi * 8; - auto dst0 = dstH + hi * c; + auto dst0 = dstH + hi * dstAreaOffset; vst1q_s16(dst0, vld1q_s16(src0)); } } @@ -492,7 +494,7 @@ void MNNUnpackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, si for (int hi = 0; hi < area; ++hi) { auto srcHeight = srcAlign + hi * 8; - auto dstHeight = dstAlign + hi * c; + auto dstHeight = dstAlign + hi * dstAreaOffset; for (int ci = 0; ci < cReamin; ++ci) { dstHeight[ci] = srcHeight[ci]; @@ -934,495 +936,124 @@ static void _ArmBasicMNNPackC4ForMatMul_A_L8(int8_t* destOrigin, int8_t const** } } +inline void transpose_4x4_f32(float32x4_t& r0, float32x4_t& r1, float32x4_t& r2, float32x4_t& r3) { + // Stage 1: Transpose 2x2 blocks of float32 elements between pairs of adjacent rows. 
+ float32x4x2_t temp0 = vtrnq_f32(r0, r1); + float32x4x2_t temp1 = vtrnq_f32(r2, r3); + + // Intermediate state: + // temp0.val[0] = [A0, B0, A2, B2] + // temp0.val[1] = [A1, B1, A3, B3] + // temp1.val[0] = [C0, D0, C2, D2] + // temp1.val[1] = [C1, D1, C3, D3] + + // Stage 2: Manually swap the 64-bit blocks to finalize the transpose. + // This correctly simulates the non-existent 64-bit transpose/zip. + float64x2_t i0_f64 = vreinterpretq_f64_f32(temp0.val[0]); + float64x2_t i1_f64 = vreinterpretq_f64_f32(temp0.val[1]); + float64x2_t i2_f64 = vreinterpretq_f64_f32(temp1.val[0]); + float64x2_t i3_f64 = vreinterpretq_f64_f32(temp1.val[1]); + + // Combine the low 64 bits of i0 and i2 to form the first part of the result. + float32x4_t t0 = vreinterpretq_f32_f64(vcombine_f64(vget_low_f64(i0_f64), vget_low_f64(i2_f64))); + // Combine the low 64 bits of i1 and i3 for the second part. + float32x4_t t1 = vreinterpretq_f32_f64(vcombine_f64(vget_low_f64(i1_f64), vget_low_f64(i3_f64))); + // Combine the high 64 bits of i0 and i2 for the third part. + float32x4_t t2 = vreinterpretq_f32_f64(vcombine_f64(vget_high_f64(i0_f64), vget_high_f64(i2_f64))); + // Combine the high 64 bits of i1 and i3 for the final part. 
+ float32x4_t t3 = vreinterpretq_f32_f64(vcombine_f64(vget_high_f64(i1_f64), vget_high_f64(i3_f64))); + + r0 = t0; + r1 = t1; + r2 = t2; + r3 = t3; +} + static void Sme2MNNPackC4ForMatMul_A_FP16(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) { - int LP = FP16_SME2_MATMUL_LP; - int pack = 8; - // LP >= pack + const int lP = FP16_SME2_MATMUL_LP; + const int pack = 8; int number = info[0]; int eReal = info[1]; int eDest = info[2]; int offset = info[3]; - for (int n=0; n 7) { - auto source = sourceN; - auto dest = destN; - l -= 8; - auto e = eWork; - if (e == eDest) { - auto s0 = vld1q_f32((float*)(source)); // 00112233 - auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233 - auto s2 = vld1q_f32((float*)(source + 2 * srcStride0)); - auto s3 = vld1q_f32((float*)(source + 3 * srcStride0)); - - auto s4 = vld1q_f32((float*)(source + 4 * srcStride0)); - auto s5 = vld1q_f32((float*)(source + 5 * srcStride0)); - auto s6 = vld1q_f32((float*)(source + 6 * srcStride0)); - auto s7 = vld1q_f32((float*)(source + 7 * srcStride0)); - - auto s8 = vld1q_f32((float*)(source + 8 * srcStride0)); - auto s9 = vld1q_f32((float*)(source + 9 * srcStride0)); - auto s10 = vld1q_f32((float*)(source + 10 * srcStride0)); - auto s11 = vld1q_f32((float*)(source + 11 * srcStride0)); - - auto s12 = vld1q_f32((float*)(source + 12 * srcStride0)); - auto s13 = vld1q_f32((float*)(source + 13 * srcStride0)); - auto s14 = vld1q_f32((float*)(source + 14 * srcStride0)); - auto s15 = vld1q_f32((float*)(source + 15 * srcStride0)); - - auto zip1s01 = vzip1q_f32(s0, s1); // 00001111 - auto zip1s23 = vzip1q_f32(s2, s3); // 00001111 - auto zip1s45 = vzip1q_f32(s4, s5); // 00001111 - auto zip1s67 = vzip1q_f32(s6, s7); // 00001111 - auto zip1s89 = vzip1q_f32(s8, s9); // 00001111 - auto zip1s1011 = vzip1q_f32(s10, s11); // 00001111 - auto zip1s1213 = vzip1q_f32(s12, s13); // 00001111 - auto zip1s1415 = vzip1q_f32(s14, s15); // 00001111 - - auto zip2s01 = vzip2q_f32(s0, 
s1); // 22223333 - auto zip2s23 = vzip2q_f32(s2, s3); // 22223333 - auto zip2s45 = vzip2q_f32(s4, s5); // 22223333 - auto zip2s67 = vzip2q_f32(s6, s7); // 22223333 - auto zip2s89 = vzip2q_f32(s8, s9); // 22223333 - auto zip2s1011 = vzip2q_f32(s10, s11); // 22223333 - auto zip2s1213 = vzip2q_f32(s12, s13); // 22223333 - auto zip2s1415 = vzip2q_f32(s14, s15); // 22223333 - - auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000 - auto zip1s4567_01 = vzip1q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67); - auto zip1s891011_01 = vzip1q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011); - auto zip1s12131415_01 = vzip1q_f64((float64x2_t)zip1s1213, (float64x2_t)zip1s1415); - - auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111 - auto zip2s4567_01 = vzip2q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67); - auto zip2s891011_01 = vzip2q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011); - auto zip2s12131415_01 = vzip2q_f64((float64x2_t)zip1s1213, (float64x2_t)zip1s1415); - - auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222 - auto zip1s4567_23 = vzip1q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67); - auto zip1s891011_23 = vzip1q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011); - auto zip1s12131415_23 = vzip1q_f64((float64x2_t)zip2s1213, (float64x2_t)zip2s1415); - - auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333 - auto zip2s4567_23 = vzip2q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67); - auto zip2s891011_23 = vzip2q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011); - auto zip2s12131415_23 = vzip2q_f64((float64x2_t)zip2s1213, (float64x2_t)zip2s1415); - - vst1q_f64((float64_t*)dest, zip1s0123_01); - vst1q_f64((float64_t*)(dest + 8), zip1s4567_01); - vst1q_f64((float64_t*)(dest + 16), zip1s891011_01); - vst1q_f64((float64_t*)(dest + 24), zip1s12131415_01); - - vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01); - 
vst1q_f64((float64_t*)(dest + dstStride0 + 8), zip2s4567_01); - vst1q_f64((float64_t*)(dest + dstStride0 + 16), zip2s891011_01); - vst1q_f64((float64_t*)(dest + dstStride0 + 24), zip2s12131415_01); - - vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23); - vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 8), zip1s4567_23); - vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 16), zip1s891011_23); - vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 24), zip1s12131415_23); - - vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23); - vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 8), zip2s4567_23); - vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 16), zip2s891011_23); - vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 24), zip2s12131415_23); - - // dest += (4 * dstStride0); - // e -= eDest; - sourceN += (eReal * pack); - destN += (4 * dstStride0); - continue; - } - - if (e > 11) { - auto s0 = vld1q_f32((float*)(source)); // 00112233 - auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233 - auto s2 = vld1q_f32((float*)(source + 2 * srcStride0)); - auto s3 = vld1q_f32((float*)(source + 3 * srcStride0)); - - auto s4 = vld1q_f32((float*)(source + 4 * srcStride0)); - auto s5 = vld1q_f32((float*)(source + 5 * srcStride0)); - auto s6 = vld1q_f32((float*)(source + 6 * srcStride0)); - auto s7 = vld1q_f32((float*)(source + 7 * srcStride0)); - - auto s8 = vld1q_f32((float*)(source + 8 * srcStride0)); - auto s9 = vld1q_f32((float*)(source + 9 * srcStride0)); - auto s10 = vld1q_f32((float*)(source + 10 * srcStride0)); - auto s11 = vld1q_f32((float*)(source + 11 * srcStride0)); - - auto zip1s01 = vzip1q_f32(s0, s1); // 00001111 - auto zip1s23 = vzip1q_f32(s2, s3); // 00001111 - auto zip1s45 = vzip1q_f32(s4, s5); // 00001111 - auto zip1s67 = vzip1q_f32(s6, s7); // 00001111 - auto zip1s89 = vzip1q_f32(s8, s9); // 00001111 - auto zip1s1011 = vzip1q_f32(s10, s11); // 00001111 - - auto zip2s01 = vzip2q_f32(s0, s1); // 22223333 - auto zip2s23 = vzip2q_f32(s2, s3); // 
22223333 - auto zip2s45 = vzip2q_f32(s4, s5); // 22223333 - auto zip2s67 = vzip2q_f32(s6, s7); // 22223333 - auto zip2s89 = vzip2q_f32(s8, s9); // 22223333 - auto zip2s1011 = vzip2q_f32(s10, s11); // 22223333 - - auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000 - auto zip1s4567_01 = vzip1q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67); - auto zip1s891011_01 = vzip1q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011); - - auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111 - auto zip2s4567_01 = vzip2q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67); - auto zip2s891011_01 = vzip2q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011); - - auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222 - auto zip1s4567_23 = vzip1q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67); - auto zip1s891011_23 = vzip1q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011); - - auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333 - auto zip2s4567_23 = vzip2q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67); - auto zip2s891011_23 = vzip2q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011); - - vst1q_f64((float64_t*)dest, zip1s0123_01); - vst1q_f64((float64_t*)(dest + 8), zip1s4567_01); - vst1q_f64((float64_t*)(dest + 16), zip1s891011_01); - - vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01); - vst1q_f64((float64_t*)(dest + dstStride0 + 8), zip2s4567_01); - vst1q_f64((float64_t*)(dest + dstStride0 + 16), zip2s891011_01); - vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23); - vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 8), zip1s4567_23); - vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 16), zip1s891011_23); + float32x4_t v0, v1, v2, v3, v4, v5, v6, v7; - vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23); - vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 8), zip2s4567_23); - vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 16), 
zip2s891011_23); + for (int n = 0; n < number; ++n) { + int e = el[4 * n + 0]; + int l = el[4 * n + 1]; + int eOffset = el[4 * n + 2]; + int lOffset = el[4 * n + 3]; - dest += 24; - e -= 12; - source += (12 * srcStride0); + auto destBase = (FLOAT16*)destOrigin + lOffset * eDest + eOffset * lP; + auto sourceBase = (const FLOAT16*)(sourceGroup[n]); + + const int eTile = 8; + const int lTile = 8; + + const int eMain = e / eTile; + const int lMain = l / lTile; + + const size_t srcRowStride = (size_t)pack * offset; + const size_t srcColBlockStride = (size_t)eReal * pack; + const size_t dstColBlockStride = (size_t)eDest * lP; + + for (int y0 = 0; y0 < eMain; ++y0) { + const int yBase = y0 * eTile; + for (int x0 = 0; x0 < lMain; ++x0) { + const int xBase = x0 * lTile; + + const auto srcBlockBase = sourceBase + yBase * srcRowStride + x0 * srcColBlockStride; + + v0 = vld1q_f32((const float*)(srcBlockBase + 0 * srcRowStride)); + v1 = vld1q_f32((const float*)(srcBlockBase + 1 * srcRowStride)); + v2 = vld1q_f32((const float*)(srcBlockBase + 2 * srcRowStride)); + v3 = vld1q_f32((const float*)(srcBlockBase + 3 * srcRowStride)); + v4 = vld1q_f32((const float*)(srcBlockBase + 4 * srcRowStride)); + v5 = vld1q_f32((const float*)(srcBlockBase + 5 * srcRowStride)); + v6 = vld1q_f32((const float *)(srcBlockBase + 6 * srcRowStride)); + v7 = vld1q_f32((const float *)(srcBlockBase + 7 * srcRowStride)); + + transpose_4x4_f32(v0, v1, v2, v3); + transpose_4x4_f32(v4, v5, v6, v7); + + float* addr0 = (float*)(destBase + yBase * lP + (xBase / lP) * dstColBlockStride); + float* addr1= (float*)(destBase + yBase * lP + (xBase / lP + 1) * dstColBlockStride); + float* addr2= (float*)(destBase + yBase * lP + (xBase / lP + 2) * dstColBlockStride); + float* addr3= (float*)(destBase + yBase * lP + (xBase / lP + 3) * dstColBlockStride); + + vst1q_f32(addr0, v0); + vst1q_f32(addr0 + 4, v4); + vst1q_f32(addr1, v1); + vst1q_f32(addr1 + 4, v5); + vst1q_f32(addr2, v2); + vst1q_f32(addr2 + 4, v6); + 
vst1q_f32(addr3, v3); + vst1q_f32(addr3 + 4, v7); } + } - if (e > 7) { - auto s0 = vld1q_f32((float*)(source)); // 00112233 - auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233 - auto s2 = vld1q_f32((float*)(source + 2 * srcStride0)); - auto s3 = vld1q_f32((float*)(source + 3 * srcStride0)); - - auto s4 = vld1q_f32((float*)(source + 4 * srcStride0)); - auto s5 = vld1q_f32((float*)(source + 5 * srcStride0)); - auto s6 = vld1q_f32((float*)(source + 6 * srcStride0)); - auto s7 = vld1q_f32((float*)(source + 7 * srcStride0)); - - auto zip1s01 = vzip1q_f32(s0, s1); // 00001111 - auto zip1s23 = vzip1q_f32(s2, s3); // 00001111 - auto zip1s45 = vzip1q_f32(s4, s5); // 00001111 - auto zip1s67 = vzip1q_f32(s6, s7); // 00001111 - - auto zip2s01 = vzip2q_f32(s0, s1); // 22223333 - auto zip2s23 = vzip2q_f32(s2, s3); // 22223333 - auto zip2s45 = vzip2q_f32(s4, s5); // 22223333 - auto zip2s67 = vzip2q_f32(s6, s7); // 22223333 + const int eHandled = eMain * eTile; + const int lHandled = lMain * lTile; - auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000 - auto zip1s4567_01 = vzip1q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67); - - auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111 - auto zip2s4567_01 = vzip2q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67); - - auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222 - auto zip1s4567_23 = vzip1q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67); - - auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333 - auto zip2s4567_23 = vzip2q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67); - - vst1q_f64((float64_t*)dest, zip1s0123_01); - vst1q_f64((float64_t*)(dest + 8), zip1s4567_01); - - vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01); - vst1q_f64((float64_t*)(dest + dstStride0 + 8), zip2s4567_01); - - vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23); - 
vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 8), zip1s4567_23); - - vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23); - vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 8), zip2s4567_23); - - dest += 16; - e -= 8; - source += (8 * srcStride0); - } - - if (e > 3) { - auto s0 = vld1q_f32((float*)(source)); // 00112233 - auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233 - auto s2 = vld1q_f32((float*)(source + 2 * srcStride0)); - auto s3 = vld1q_f32((float*)(source + 3 * srcStride0)); - - auto zip1s01 = vzip1q_f32(s0, s1); // 00001111 - auto zip1s23 = vzip1q_f32(s2, s3); // 00001111 - - auto zip2s01 = vzip2q_f32(s0, s1); // 22223333 - auto zip2s23 = vzip2q_f32(s2, s3); // 22223333 - - auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000 - - auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111 - - auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222 - - auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333 - - vst1q_f64((float64_t*)dest, zip1s0123_01); - vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01); - vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23); - vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23); - - dest += 8; - e -= 4; - source += (4 * srcStride0); - } - while (e > 0) { - auto s0 = vld1q_f32((float*)(source)); // 00112233 - - ((float*)dest)[0] = s0[0]; - ((float*)(dest + dstStride0))[0] = s0[1]; - ((float*)(dest + 2 * dstStride0))[0] = s0[2]; - ((float*)(dest + 3 * dstStride0))[0] = s0[3]; - - dest += 2; - e -= 1; - source += srcStride0; - } - sourceN += (eReal * pack); - destN += (4 * dstStride0); - } // l>7 - - if (l > 3) { - auto source = sourceN; - auto dest = destN; - l -= 4; - auto e = eWork; - if (e == eDest) { - auto s0 = vld1_f32((float*)(source)); // 0011 - auto s1 = vld1_f32((float*)(source + srcStride0));// 0011 - auto s2 = vld1_f32((float*)(source + 2 * 
srcStride0)); - auto s3 = vld1_f32((float*)(source + 3 * srcStride0)); - - auto s4 = vld1_f32((float*)(source + 4 * srcStride0)); - auto s5 = vld1_f32((float*)(source + 5 * srcStride0)); - auto s6 = vld1_f32((float*)(source + 6 * srcStride0)); - auto s7 = vld1_f32((float*)(source + 7 * srcStride0)); - - auto s8 = vld1_f32((float*)(source + 8 * srcStride0)); - auto s9 = vld1_f32((float*)(source + 9 * srcStride0)); - auto s10 = vld1_f32((float*)(source + 10 * srcStride0)); - auto s11 = vld1_f32((float*)(source + 11 * srcStride0)); - - auto s12 = vld1_f32((float*)(source + 12 * srcStride0)); - auto s13 = vld1_f32((float*)(source + 13 * srcStride0)); - auto s14 = vld1_f32((float*)(source + 14 * srcStride0)); - auto s15 = vld1_f32((float*)(source + 15 * srcStride0)); - - auto zip1s01 = vzip1_f32(s0, s1); // 0000 - auto zip1s23 = vzip1_f32(s2, s3); // 0000 - auto zip1s45 = vzip1_f32(s4, s5); // 0000 - auto zip1s67 = vzip1_f32(s6, s7); // 0000 - auto zip1s89 = vzip1_f32(s8, s9); // 0000 - auto zip1s1011 = vzip1_f32(s10, s11); // 0000 - auto zip1s1213 = vzip1_f32(s12, s13); // 0000 - auto zip1s1415 = vzip1_f32(s14, s15); // 0000 - - auto zip2s01 = vzip2_f32(s0, s1); // 1111 - auto zip2s23 = vzip2_f32(s2, s3); // 1111 - auto zip2s45 = vzip2_f32(s4, s5); // 1111 - auto zip2s67 = vzip2_f32(s6, s7); // 1111 - auto zip2s89 = vzip2_f32(s8, s9); // 1111 - auto zip2s1011 = vzip2_f32(s10, s11); // 1111 - auto zip2s1213 = vzip2_f32(s12, s13); // 1111 - auto zip2s1415 = vzip2_f32(s14, s15); // 1111 - - vst1_f32((float32_t*)dest, zip1s01); - vst1_f32((float32_t*)(dest + 4), zip1s23); - vst1_f32((float32_t*)(dest + 8), zip1s45); - vst1_f32((float32_t*)(dest + 12), zip1s67); - vst1_f32((float32_t*)(dest + 16), zip1s89); - vst1_f32((float32_t*)(dest + 20), zip1s1011); - vst1_f32((float32_t*)(dest + 24), zip1s1213); - vst1_f32((float32_t*)(dest + 28), zip1s1415); - - vst1_f32((float32_t*)(dest + dstStride0), zip2s01); - vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23); - 
vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45); - vst1_f32((float32_t*)(dest + dstStride0 + 12), zip2s67); - vst1_f32((float32_t*)(dest + dstStride0 + 16), zip2s89); - vst1_f32((float32_t*)(dest + dstStride0 + 20), zip2s1011); - vst1_f32((float32_t*)(dest + dstStride0 + 24), zip2s1213); - vst1_f32((float32_t*)(dest + dstStride0 + 28), zip2s1415); - - - dest += 32; - e -= eDest; - } - - if (e > 11) { - auto s0 = vld1_f32((float*)(source)); // 0011 - auto s1 = vld1_f32((float*)(source + srcStride0));// 0011 - auto s2 = vld1_f32((float*)(source + 2 * srcStride0)); - auto s3 = vld1_f32((float*)(source + 3 * srcStride0)); - - auto s4 = vld1_f32((float*)(source + 4 * srcStride0)); - auto s5 = vld1_f32((float*)(source + 5 * srcStride0)); - auto s6 = vld1_f32((float*)(source + 6 * srcStride0)); - auto s7 = vld1_f32((float*)(source + 7 * srcStride0)); - - auto s8 = vld1_f32((float*)(source + 8 * srcStride0)); - auto s9 = vld1_f32((float*)(source + 9 * srcStride0)); - auto s10 = vld1_f32((float*)(source + 10 * srcStride0)); - auto s11 = vld1_f32((float*)(source + 11 * srcStride0)); - - auto zip1s01 = vzip1_f32(s0, s1); // 0000 - auto zip1s23 = vzip1_f32(s2, s3); // 0000 - auto zip1s45 = vzip1_f32(s4, s5); // 0000 - auto zip1s67 = vzip1_f32(s6, s7); // 0000 - auto zip1s89 = vzip1_f32(s8, s9); // 0000 - auto zip1s1011 = vzip1_f32(s10, s11); // 0000 - - auto zip2s01 = vzip2_f32(s0, s1); // 1111 - auto zip2s23 = vzip2_f32(s2, s3); // 1111 - auto zip2s45 = vzip2_f32(s4, s5); // 1111 - auto zip2s67 = vzip2_f32(s6, s7); // 1111 - auto zip2s89 = vzip2_f32(s8, s9); // 1111 - auto zip2s1011 = vzip2_f32(s10, s11); // 1111 - - vst1_f32((float32_t*)dest, zip1s01); - vst1_f32((float32_t*)(dest + 4), zip1s23); - vst1_f32((float32_t*)(dest + 8), zip1s45); - vst1_f32((float32_t*)(dest + 12), zip1s67); - vst1_f32((float32_t*)(dest + 16), zip1s89); - vst1_f32((float32_t*)(dest + 20), zip1s1011); - - vst1_f32((float32_t*)(dest + dstStride0), zip2s01); - vst1_f32((float32_t*)(dest + 
dstStride0 + 4), zip2s23); - vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45); - vst1_f32((float32_t*)(dest + dstStride0 + 12), zip2s67); - vst1_f32((float32_t*)(dest + dstStride0 + 16), zip2s89); - vst1_f32((float32_t*)(dest + dstStride0 + 20), zip2s1011); - - dest += 24; - e -= 12; - source += (12 * srcStride0); - } - - if (e > 7) { - auto s0 = vld1_f32((float*)(source)); // 0011 - auto s1 = vld1_f32((float*)(source + srcStride0));// 0011 - auto s2 = vld1_f32((float*)(source + 2 * srcStride0)); - auto s3 = vld1_f32((float*)(source + 3 * srcStride0)); - - auto s4 = vld1_f32((float*)(source + 4 * srcStride0)); - auto s5 = vld1_f32((float*)(source + 5 * srcStride0)); - auto s6 = vld1_f32((float*)(source + 6 * srcStride0)); - auto s7 = vld1_f32((float*)(source + 7 * srcStride0)); - - auto zip1s01 = vzip1_f32(s0, s1); // 0000 - auto zip1s23 = vzip1_f32(s2, s3); // 0000 - auto zip1s45 = vzip1_f32(s4, s5); // 0000 - auto zip1s67 = vzip1_f32(s6, s7); // 0000 - - auto zip2s01 = vzip2_f32(s0, s1); // 1111 - auto zip2s23 = vzip2_f32(s2, s3); // 1111 - auto zip2s45 = vzip2_f32(s4, s5); // 1111 - auto zip2s67 = vzip2_f32(s6, s7); // 1111 - - vst1_f32((float32_t*)dest, zip1s01); - vst1_f32((float32_t*)(dest + 4), zip1s23); - vst1_f32((float32_t*)(dest + 8), zip1s45); - vst1_f32((float32_t*)(dest + 12), zip1s67); - - vst1_f32((float32_t*)(dest + dstStride0), zip2s01); - vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23); - vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45); - vst1_f32((float32_t*)(dest + dstStride0 + 12), zip2s67); - - dest += 16; - e -= 8; - source += (8 * srcStride0); - } - - if (e > 3) { - auto s0 = vld1_f32((float*)(source)); // 0011 - auto s1 = vld1_f32((float*)(source + srcStride0));// 0011 - auto s2 = vld1_f32((float*)(source + 2 * srcStride0)); - auto s3 = vld1_f32((float*)(source + 3 * srcStride0)); - - auto zip1s01 = vzip1_f32(s0, s1); // 0000 - auto zip1s23 = vzip1_f32(s2, s3); // 0000 - - auto zip2s01 = vzip2_f32(s0, s1); // 1111 - 
auto zip2s23 = vzip2_f32(s2, s3); // 1111 - - vst1_f32((float32_t*)dest, zip1s01); - vst1_f32((float32_t*)(dest + 4), zip1s23); - - vst1_f32((float32_t*)(dest + dstStride0), zip2s01); - vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23); - - dest += 8; - e -= 4; - source += (4 * srcStride0); - } - if (e > 1) { - auto s0 = vld1_f32((float*)(source)); // 0011 - auto s1 = vld1_f32((float*)(source + srcStride0));// 0011 - - auto zip1s01 = vzip1_f32(s0, s1); // 0000 - - auto zip2s01 = vzip2_f32(s0, s1); // 1111 - - vst1_f32((float32_t*)dest, zip1s01); - - vst1_f32((float32_t*)(dest + dstStride0), zip2s01); - - dest += 4; - e -= 2; - source += (2 * srcStride0); - } - if (e > 0) { - auto s0 = vld1_f32((float*)(source)); // 0011 - - ((float*)dest)[0] = s0[0]; - ((float*)(dest + dstStride0))[0] = s0[1]; + // Process remaining rows + for (int y = eHandled; y < e; ++y) { + int yR = y % eDest; + for (int x = 0; x < l; ++x) { + int xR = x % pack; + int xC = x / pack; + destBase[(x / lP) * dstColBlockStride + yR * lP + (x % lP)] = sourceBase[xC * srcColBlockStride + y * srcRowStride + xR]; } - sourceN += 4; - destN += (2 * dstStride0); - } + } - auto source = (FLOAT16*)(sourceGroup[n]); - auto dest = (FLOAT16*)destOrigin + lOffset * eDest + eOffset * LP; - if (l > 0) { - auto e = eWork; - auto lRemain = lWork - l; - // if e < eDest, packed A -> [LU, eDest, LP] eDest=eP - for (int y=0; y 1) { + for (int s = 0; s < seqLen; ++s) { + const FLOAT16* keySrc = sourceFp16 + s * kvNumHead * headDim + kvHeadIdx * headDim; + int d = 0; + for (; d <= headDim - 16; d += 16) { + float16x8_t maxVec0 = vld1q_f16(maxKeyFp16 + d); + float16x8_t maxVec1 = vld1q_f16(maxKeyFp16 + d + 8); + float16x8_t srcVec0 = vld1q_f16(keySrc + d); + float16x8_t srcVec1 = vld1q_f16(keySrc + d + 8); + maxVec0 = vmaxq_f16(maxVec0, srcVec0); + maxVec1 = vmaxq_f16(maxVec1, srcVec1); + vst1q_f16(maxKeyFp16 + d, maxVec0); + vst1q_f16(maxKeyFp16 + d + 8, maxVec1); + } + for (; d <= headDim - 8; d += 8) { + 
float16x8_t maxVec = vld1q_f16(maxKeyFp16 + d); + float16x8_t srcVec = vld1q_f16(keySrc + d); + maxVec = vmaxq_f16(maxVec, srcVec); + vst1q_f16(maxKeyFp16 + d, maxVec); + } + for (; d < headDim; ++d) { + maxKeyFp16[d] = ALIMAX(maxKeyFp16[d], keySrc[d]); + } + } + } + + // Quant fp16 + for (int s = 0; s < seqLen; s++) { + const FLOAT16* keySrc = sourceFp16 + s * kvNumHead * headDim + kvHeadIdx * headDim; + + float16x8_t minVec = vdupq_n_f16(keySrc[0]); + float16x8_t maxVec = vdupq_n_f16(keySrc[0]); + + int d = 0; + for (; d <= headDim - 8; d += 8) { + float16x8_t srcVec = vld1q_f16(keySrc + d); + float16x8_t maxKeyVec = vld1q_f16(maxKeyFp16 + d); + float16x8_t keyDataF16 = vsubq_f16(srcVec, maxKeyVec); + + minVec = vminq_f16(minVec, keyDataF16); + maxVec = vmaxq_f16(maxVec, keyDataF16); + + float32x4_t keyDataF32Low = vcvt_f32_f16(vget_low_f16(keyDataF16)); + float32x4_t keyDataF32High = vcvt_f32_f16(vget_high_f16(keyDataF16)); + } + + FLOAT16 minKey = vminvq_f16(minVec); + FLOAT16 maxKey = vmaxvq_f16(maxVec); + + for (; d < headDim; ++d) { + auto keydata = keySrc[d] - maxKeyFp16[d]; + minKey = ALIMIN(minKey, keydata); + maxKey = ALIMAX(maxKey, keydata); + } + + int outIndex = (pastLength + s) / hP; + int inIndex = (pastLength + s) % hP; + + float range = (float)maxKey - (float)minKey; + float quantScaleVal = 0; + float biasVal = minKey + 128.0f * range / 255.0; + if (range <= 1e-6f) { + quantScaleVal = 0.f; + } else { + quantScaleVal = 255.0f / range; + } + + for (int k = 0; k < blockNum; ++k) { + int8_t* weightDstBase = dst + outIndex * blockNum * packedWeightStride1 + k * packedWeightStride1; + float* scaleDst = (float*)(weightDstBase + weightStride1); + float* biasDst = scaleDst + hP; + + scaleDst[inIndex] = range / 255.f; + biasDst[inIndex] = biasVal; + + float32x4_t scaleVecFp32 = vdupq_n_f32(quantScaleVal); + float32x4_t negMinKeyVecF32 = vdupq_n_f32(-(float)minKey); + + const FLOAT16* currentKeyBlock = keySrc + k * blockHeadDim; + const FLOAT16* 
currentMaxBlock = maxKeyFp16 + k * blockHeadDim; + + int32x4_t sumInt32_0 = vdupq_n_s32(0); + int32x4_t sumInt32_1 = vdupq_n_s32(0); + int headDimIdx = 0; + for (; headDimIdx <= blockHeadDim - 8; headDimIdx += 8) { + float16x8_t srcVecFp16 = vld1q_f16(currentKeyBlock + headDimIdx); + float16x8_t maxVecFp16 = vld1q_f16(currentMaxBlock + headDimIdx); + + float16x8_t keyDataF16 = vsubq_f16(srcVecFp16, maxVecFp16); + + float32x4_t keyDataLowFp32 = vcvt_f32_f16(vget_low_f16(keyDataF16)); + float32x4_t keyDataHighFp32 = vcvt_f32_f16(vget_high_f16(keyDataF16)); + + keyDataLowFp32 = vaddq_f32(keyDataLowFp32, negMinKeyVecF32); + keyDataHighFp32 = vaddq_f32(keyDataHighFp32, negMinKeyVecF32); + + keyDataLowFp32 = vmulq_f32(keyDataLowFp32, scaleVecFp32); + keyDataHighFp32 = vmulq_f32(keyDataHighFp32, scaleVecFp32); + + keyDataLowFp32 = vaddq_f32(keyDataLowFp32, neg128Vec); + keyDataHighFp32 = vaddq_f32(keyDataHighFp32, neg128Vec); + + int32x4_t keyDataLowInt32 = vcvtaq_s32_f32(keyDataLowFp32); + int32x4_t keyDataHighInt32 = vcvtaq_s32_f32(keyDataHighFp32); + + int16x4_t s16Low = vmovn_s32(keyDataLowInt32); + int16x4_t s16High = vmovn_s32(keyDataHighInt32); + + int16x8_t s16Combined = vcombine_s16(s16Low, s16High); + + // sum + sumInt32_0 = vaddq_s32(sumInt32_0, keyDataLowInt32); + sumInt32_1 = vaddq_s32(sumInt32_1, keyDataHighInt32); + + int8x8_t s8Vec = vqmovn_s16(s16Combined); + + if (lP == 8) { + int i = headDimIdx / lP; + int8_t* dstPtr = weightDstBase + i * weightStride2 + inIndex * lP; + vst1_s8(dstPtr, s8Vec); + } else if (lP == 4) { + vst1_s8(tempBuffer, s8Vec); + int iLow = headDimIdx / lP; + int iHigh = (headDimIdx + 4) / lP; + + int8_t* dstPtrLow = weightDstBase + iLow * weightStride2 + inIndex * lP; + int8_t* dstPtrHigh = weightDstBase + iHigh * weightStride2 + inIndex * lP; + + std::memcpy(dstPtrLow, tempBuffer, 4); + std::memcpy(dstPtrHigh, tempBuffer + 4, 4); + } else { + vst1_s8(tempBuffer, s8Vec); + for (int nk = 0; nk < 8; ++nk) { + int headDimCurr = 
headDimIdx + nk; + int i = headDimCurr / lP; + int j = headDimCurr % lP; + weightDstBase[i * weightStride2 + inIndex * lP + j] = tempBuffer[nk]; + } + } + + } + + int32_t sumInt32 = vaddvq_s32(sumInt32_0) + vaddvq_s32(sumInt32_1); + + for (; headDimIdx < blockHeadDim; ++headDimIdx) { + int i = headDimIdx / lP; + int j = headDimIdx % lP; + float keyVal = (float)currentKeyBlock[headDimIdx] - (float)currentMaxBlock[headDimIdx]; + float quantVal = (keyVal - minKey) * quantScaleVal - 128.0f; + int32_t roundedVal = static_cast(roundf(quantVal)); + int8_t finalVal = static_cast(std::max(-128, std::min(127, roundedVal))); + weightDstBase[i * weightStride2 + inIndex * lP + j] = finalVal; + sumInt32 += finalVal; + } + + // store sum + sumKeyPtr[outIndex * hP + inIndex] = sumInt32 * range / 255.f + (minKey * (float)(blockHeadDim) + 128.0f * range * (float)(blockHeadDim) / 255.0); + } + } +} + +static void MNNQuantAttentionValueFP16(int8_t* dst, const float* source, float* valueSum, int32_t* params) { + // float value src : [kvSeq,kvNumHead,headDim] + // int8_t value dest: [updiv(maxLength,flashAttentionBlockKv), updiv(headDim,hp),updiv(flashAttentionBlockKv,lp),hp,lp] + // float value sum: [updiv(maxLength,flashAttentionBlockKv), roundup(headDim,hp)] + int32_t kvNumHead = params[0]; + int32_t seqLen = params[1]; + int32_t headDim = params[2]; + int32_t blockNum = params[3]; + int32_t maxLength = params[4]; + + int32_t lP = params[5]; + int32_t hP = params[6]; + int32_t pastLength = params[7]; + int32_t kvHeadIdx = params[8]; + + int32_t flashAttentionBlockKv = params[9]; + + auto blockKvseq = UP_DIV(seqLen + pastLength, blockNum); + auto weightStride2 = lP * hP; + auto weightStride1 = UP_DIV(flashAttentionBlockKv, lP) * weightStride2; + + auto packedStride1 = (int)(weightStride1 + 2 * hP * sizeof(float)); + auto packedStride0 = UP_DIV(headDim, hP) * packedStride1; + + auto srcStride0 = kvNumHead * headDim; + + auto sourceFp16 = (FLOAT16*)source; + + // quant scale & bias + if 
(pastLength == 0) { + for (int d = 0; d < headDim; ++d) { + float* scalePtr = (float*)(dst + (d / hP) * packedStride1 + weightStride1) + (d % hP); + float* biasPtr = scalePtr + hP; + + // find min,max + float dMax = sourceFp16[d + kvHeadIdx * headDim]; + float dMin = dMax; + for (int s = 0; s < seqLen; ++s) { + float data = sourceFp16[s * srcStride0 + d + kvHeadIdx * headDim]; + dMax = ALIMAX(dMax, data); + dMin = ALIMIN(dMin, data); + } + + // scale & bias + float range = dMax - dMin; + if (range < 1e-6) { + scalePtr[0] = 0.f; + biasPtr[0] = dMax; + } else { + float scale = range / 255.f; + float bias = range / 255.f * 128.f + dMin; + scalePtr[0] = scale; + biasPtr[0] = bias; + } + } + } + + // copy the scale&bias to each blockKv + // pastLength == 0: First time prefill + // (seqLen + pastLength) % flashAttentionBlockKv == 0: Open a new blockKv + if (pastLength == 0 || (pastLength % flashAttentionBlockKv) == 0) { + int32_t d0 = UP_DIV(maxLength, flashAttentionBlockKv); + int32_t d1 = UP_DIV(headDim, hP); + for (int k = 0; k < d0; ++k) { + for (int r = 0; r < d1; ++r) { + float* scalePtr = (float*)(dst + k * packedStride0 + r * packedStride1 + weightStride1); + float* biasPtr = scalePtr + hP; + memcpy(scalePtr, dst + r * packedStride1 + weightStride1, hP * sizeof(float)); + memcpy(biasPtr, dst + r * packedStride1 + weightStride1 + hP * sizeof(float), hP * sizeof(float)); + } + } + } + + std::vector qScales(headDim); + std::vector qBiases(headDim); + std::vector deqScales(headDim); + std::vector deqBiases(headDim); + int8_t tmpQ[8]; + + for (int d = 0; d < headDim; ++d) { + float* scaleBase = (float*)(dst + (d / hP) * packedStride1 + weightStride1) + (d % hP); + float* biasBase = scaleBase + hP; + + float s_val = scaleBase[0]; + float b_val = biasBase[0]; + + deqScales[d] = s_val; + deqBiases[d] = b_val; + + bool is_small = s_val < 1e-6f; + qScales[d] = is_small ? 0.0f : (1.0f / s_val); + qBiases[d] = is_small ? 
0.0f : (-b_val / s_val); + } + + const __fp16* srcBasePtr = sourceFp16 + kvHeadIdx * headDim; + + const int32_t sumStride = ROUND_UP(headDim, hP); + + for (int s = 0; s < seqLen; ++s) { + int kvSeqIndx = s + pastLength; + + int blkIdx = kvSeqIndx / flashAttentionBlockKv; + int blkRem = kvSeqIndx % flashAttentionBlockKv; + + int idxInnerCommon = blkIdx * packedStride0 + (blkRem / lP) * weightStride2 + (blkRem % lP); + + float* curSumRow = valueSum + blkIdx * sumStride; + + const __fp16* srcRow = srcBasePtr + s * srcStride0; + + int d = 0; + for (; d <= headDim - 8; d += 8) { + // --- Load Source --- + float16x8_t vSrc16 = vld1q_f16(srcRow + d); + float32x4_t vSrc0 = vcvt_f32_f16(vget_low_f16(vSrc16)); + float32x4_t vSrc1 = vcvt_high_f32_f16(vSrc16); + + // --- Load Quant Params --- + float32x4_t vQs0 = vld1q_f32(&qScales[d]); + float32x4_t vQb0 = vld1q_f32(&qBiases[d]); + float32x4_t vQs1 = vld1q_f32(&qScales[d + 4]); + float32x4_t vQb1 = vld1q_f32(&qBiases[d + 4]); + + // --- Quantize: x * qs + qb --- + float32x4_t vRes0 = vaddq_f32(vmulq_f32(vSrc0, vQs0), vQb0); + float32x4_t vRes1 = vaddq_f32(vmulq_f32(vSrc1, vQs1), vQb1); + + // --- Round & Saturate --- + int32x4_t vInt32_0 = vcvtaq_s32_f32(vRes0); + int32x4_t vInt32_1 = vcvtaq_s32_f32(vRes1); + + int16x8_t vInt16 = vcombine_s16(vqmovn_s32(vInt32_0), vqmovn_s32(vInt32_1)); + int8x8_t vInt8 = vqmovn_s16(vInt16); // Clamp to [-128, 127] + + vst1_s8(tmpQ, vInt8); + for (int k = 0; k < 8; ++k) { + int cur_d = d + k; + int dstOffset = (cur_d / hP) * packedStride1 + idxInnerCommon + (cur_d % hP) * lP; + dst[dstOffset] = tmpQ[k]; + } + + int16x8_t vXq16 = vmovl_s8(vInt8); + float32x4_t vXqF0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vXq16))); + float32x4_t vXqF1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vXq16))); + + float32x4_t vDs0 = vld1q_f32(&deqScales[d]); + float32x4_t vDb0 = vld1q_f32(&deqBiases[d]); + float32x4_t vDs1 = vld1q_f32(&deqScales[d + 4]); + float32x4_t vDb1 = vld1q_f32(&deqBiases[d + 4]); + + // Dequant + 
float32x4_t vDeq0 = vaddq_f32(vmulq_f32(vXqF0, vDs0), vDb0); + float32x4_t vDeq1 = vaddq_f32(vmulq_f32(vXqF1, vDs1), vDb1); + + float* sumPtr = curSumRow + d; + vst1q_f32(sumPtr, vaddq_f32(vld1q_f32(sumPtr), vDeq0)); + vst1q_f32(sumPtr + 4, vaddq_f32(vld1q_f32(sumPtr + 4), vDeq1)); + } + + for (; d < headDim; ++d) { + float xf = (float)srcRow[d]; + + float val_f = xf * qScales[d] + qBiases[d]; + int32_t val_i = (int32_t)roundf(val_f); + if (val_i > 127) val_i = 127; + if (val_i < -128) val_i = -128; + int8_t xq = (int8_t)val_i; + + int dstOffset = (d / hP) * packedStride1 + idxInnerCommon + (d % hP) * lP; + dst[dstOffset] = xq; + + curSumRow[d] += ((float)xq * deqScales[d] + deqBiases[d]); + } + } + +/* + // Quant fp16 + for (int d = 0; d < headDim; ++d) { + // dst address + int idxBase = (d / hP) * packedStride1 + (d % hP) * lP; + int8_t* dstBase = dst + idxBase; + float* scaleBase = (float*)(dst + (d / hP) * packedStride1 + weightStride1) + (d % hP); + float* biasBase = scaleBase + hP; + float* sumBase = valueSum + (d / hP) * hP + (d % hP); + + float qscale = scaleBase[0] < 1e-6 ? 0 : 1.0f / scaleBase[0]; + float qbias = scaleBase[0] < 1e-6 ? 
0 : (-biasBase[0] / scaleBase[0]); + // quant + for (int s = 0; s < seqLen; ++s) { + int kvSeqIndx = s + pastLength; + int idxInner = (kvSeqIndx / flashAttentionBlockKv) * packedStride0 + (kvSeqIndx % flashAttentionBlockKv) / lP * weightStride2 + (kvSeqIndx % flashAttentionBlockKv) % lP; + float xf = sourceFp16[s * srcStride0 + d + kvHeadIdx * headDim]; + int8_t xq = ALIMAX(ALIMIN(127, static_cast(roundf(xf * qscale + qbias))), -128); + dstBase[idxInner] = xq; + + // sum + int idxSum = (kvSeqIndx / flashAttentionBlockKv) * ROUND_UP(headDim, hP); + sumBase[idxSum] += ((float)xq * scaleBase[0] + biasBase[0]); + } + } + */ +} + #endif // MNN_SUPPORT_TRANSFORMER_FUSE #ifdef MNN_LOW_MEMORY @@ -2267,9 +2284,8 @@ static void MNNAsyQuantInfo_FP16(float* scale, float* bias, float* qscale, float // dequant scale/bias : [EU, blockNum, step] // quant scale/bias: [blockNum, plane] if (info[7] == 1) { // scale&bias:[1] - ARM82CountMinMaxValue(src, dstMin, dstMax, kernelsize * stride0); - float maxval = *(FLOAT16*)dstMax; - float minval = *(FLOAT16*)dstMin; + FLOAT16 maxval, minval; + ARM82CountMinMaxValue(src, (float*)(&minval), (float*)(&maxval) , kernelsize * stride0); if (info[8] == 1 && (maxval - minval) > 1e-7) { if (minval > 0.f) { minval = 0.f; @@ -2390,6 +2406,230 @@ static void MNNAsyQuantInfo_FP16(float* scale, float* bias, float* qscale, float #endif // MNN_LOW_MEMORY +#define EXP_APPROX_MIN_INPUT vdupq_n_f32(-88.0f) +#define EXP_APPROX_MAX_INPUT vdupq_n_f32(88.0f) +#define EXP_APPROX_LN2 vdupq_n_f32(0.69314718056f) // ln(2) +#define EXP_APPROX_LN2_INV vdupq_n_f32(1.44269504089f) // 1/ln(2) +// Fourth-order polynomial approximation coefficients of exp(r): +// P(x) = c4*x^4 + c3*x^3 + c2*x^2 + c1*x + c0 +#define EXP_APPROX_C4 vdupq_n_f32(0.0416624f) +#define EXP_APPROX_C3 vdupq_n_f32(0.166665f) +#define EXP_APPROX_C2 vdupq_n_f32(0.500000f) +#define EXP_APPROX_C1 vdupq_n_f32(1.0f) +#define EXP_APPROX_C0 vdupq_n_f32(1.0f) + +#ifndef __aarch64__ +static inline float32x4_t 
vrndaq_f32_compat(float32x4_t x) { + float32x4_t sign = vbslq_f32(vdupq_n_u32(0x80000000), x, vdupq_n_f32(0.0f)); + return vcvtq_f32_s32(vcvtq_s32_f32(vaddq_f32(x, vbslq_f32(vcltq_f32(x, vdupq_n_f32(0.0f)), vdupq_n_f32(-0.5f), vdupq_n_f32(0.5f))))); +} +#endif + +static inline float32x4_t expApprox(float32x4_t x) { + x = vminq_f32(vmaxq_f32(x, EXP_APPROX_MIN_INPUT), EXP_APPROX_MAX_INPUT); + + float32x4_t k_float; + float32x4_t r; + float32x4_t exp_r; +#if defined(__aarch64__) + k_float = vrndaq_f32(vmulq_f32(x, EXP_APPROX_LN2_INV)); + + // r = x - k * ln(2) + r = vfmsq_f32(x, k_float, EXP_APPROX_LN2); + + // P(r) = (c0 + c2*r^2 + c4*r^4) + r*(c1 + c3*r^2) + float32x4_t r2 = vmulq_f32(r, r); + float32x4_t p_odd = vfmaq_f32(EXP_APPROX_C1, EXP_APPROX_C3, r2); + + float32x4_t p_even = vfmaq_f32(EXP_APPROX_C0, EXP_APPROX_C2, r2); + p_even = vfmaq_f32(p_even, EXP_APPROX_C4, vmulq_f32(r2, r2)); + exp_r = vfmaq_f32(p_even, p_odd, r); +#else + + k_float = vrndaq_f32_compat(vmulq_f32(x, EXP_APPROX_LN2_INV)); + + + r = vsubq_f32(x, vmulq_f32(k_float, EXP_APPROX_LN2)); + + // 2. c0 + r*(c1 + r*(c2 + r*(c3 + r*c4))) + exp_r = vmlaq_f32(EXP_APPROX_C3, EXP_APPROX_C4, r); // c3 + c4*r + exp_r = vmlaq_f32(EXP_APPROX_C2, exp_r, r); // c2 + r*(...) + exp_r = vmlaq_f32(EXP_APPROX_C1, exp_r, r); // c1 + r*(...) + exp_r = vmlaq_f32(EXP_APPROX_C0, exp_r, r); // c0 + r*(...) 
+ +#endif + + int32x4_t k_int = vcvtq_s32_f32(k_float); + int32x4_t k_shifted = vshlq_n_s32(k_int, 23); + return vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(exp_r), k_shifted)); +} +static void MNNSoftmaxFp16(float* dest, const float* source, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize, int kvSeqOffset, int validOffset, int pack, bool mask) { + const int reduceSize_8 = UP_DIV(reduceSize, 8); + auto softmaxDst = (FLOAT16*)dest; + auto softmaxSrc = (FLOAT16*)source; + + // source shape: [reduceSizeOuter, outside, reduceSizeInner] + // for C4, [up_div(reduceSize,4), outside,4] => reduceSizeOuter=up_div(reduceSize,4), reduceSizeInner=4 + // for C, [outside, reduceSize] => reduceSizeOuter=1, reduceSizeInner=reduceSize + + const int packUnit = 8; + int reduceSizeOuter = 1; + int reduceSizeInner = reduceSize; + int stride0 = packUnit; + if (pack > 1) { + reduceSizeOuter = UP_DIV(reduceSize, pack); + reduceSizeInner = pack; + stride0 = outside * reduceSizeInner; + } + + + for (int k = 0; k < outside; ++k) { + if (mask && kvSeqOffset > k + validOffset) { + if (updateScale){ + updateScale[k] = 1; + } + for (int j = 0; j < reduceSizeOuter; ++j) { + int i = 0; + for (; i < reduceSizeInner; i += packUnit) { + auto destPtr = softmaxDst + j * stride0 + k * reduceSizeInner + i; + vst1q_f16(destPtr, vdupq_n_f16(0.0f)); + } + if (i < reduceSizeInner) { + memset(softmaxDst + j * stride0 + k * reduceSizeInner + i, 0, (reduceSizeInner - i) * sizeof(__fp16)); + } + } + continue; + } + + const int validReduceSize = mask ? ALIMIN(reduceSize, k + (validOffset + 1) - kvSeqOffset) : reduceSize; + const int remain = validReduceSize % packUnit; + const int sizeDiv = validReduceSize / packUnit; + + // 1. 
newMax + float oldMax = -65504.0f; + if (runningMax) { + oldMax = runningMax[k]; + } + + auto newMaxVec = vdupq_n_f16(-65504.0f); + + for (int j = 0; j < sizeDiv; ++j) { + auto srcPtr = softmaxSrc + j * stride0 + k * reduceSizeInner; + float16x8_t srcVec = vld1q_f16(srcPtr); + newMaxVec = vmaxq_f16(newMaxVec, srcVec); + } + float newMax = vmaxvq_f16(newMaxVec); + + if (remain > 0) { + auto srcPtr = softmaxSrc + sizeDiv * stride0 + k * reduceSizeInner; + for (int i = 0; i < remain; ++i) { + newMax = ALIMAX(newMax, (float)srcPtr[i]); + } + } + + const float finalMax = ALIMAX(oldMax, newMax); + const float32x4_t finalMaxVec = vdupq_n_f32(finalMax); + + // 2. exp(x - finalMax) + float sum = 0.0f; + float32x4_t sumVec0 = vdupq_n_f32(0.0f); + float32x4_t sumVec1 = vdupq_n_f32(0.0f); + + for (int j = 0; j < sizeDiv; ++j) { + auto idx = j * stride0 + k * reduceSizeInner; + auto srcPtr = softmaxSrc + idx; + auto dstPtr = softmaxDst + idx; + + float16x8_t srcVec = vld1q_f16(srcPtr); + + // F16 -> F32 + float32x4_t srcLo = vcvt_f32_f16(vget_low_f16(srcVec)); + float32x4_t srcHi = vcvt_f32_f16(vget_high_f16(srcVec)); + + // sub max + srcLo = vsubq_f32(srcLo, finalMaxVec); + srcHi = vsubq_f32(srcHi, finalMaxVec); + + // exp + srcLo = expApprox(srcLo); + srcHi = expApprox(srcHi); + + // sum + sumVec0 = vaddq_f32(sumVec0, srcLo); + sumVec1 = vaddq_f32(sumVec1, srcHi); + + // F32 -> F16 and store + vst1q_f16(dstPtr, vcombine_f16(vcvt_f16_f32(srcLo), vcvt_f16_f32(srcHi))); + } + + if (remain > 0) { + auto idx = sizeDiv * stride0 + k * reduceSizeInner; + auto srcPtr = softmaxSrc + idx; + auto dstPtr = softmaxDst + idx; + + __fp16 tempDst[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + + for(int i = 0; i < remain; ++i) { + float val = expf((float)srcPtr[i] - finalMax); + sum += val; + tempDst[i] = (__fp16)val; + } + if (pack > 1) { + memcpy(dstPtr, tempDst, packUnit * sizeof(__fp16)); + } else { + memcpy(dstPtr, tempDst, remain * sizeof(__fp16)); + } + } + + sum += vaddvq_f32(sumVec0) + 
vaddvq_f32(sumVec1); + + // 3. update runningMax, runningSum, scale or normalize softmax results + if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) { + // update runningSum, runningMax, scale + float scaleForSum = expf(oldMax - finalMax); + runningSum[k] = runningSum[k] * scaleForSum + sum; + runningMax[k] = finalMax; + updateScale[k] = scaleForSum; + } else { + // Normalize softmax results + if (runningMax != nullptr && runningSum != nullptr) { + sum += runningSum[k] * expf(oldMax - finalMax); + } + float scale = 1.0f / (sum + 1e-20f); + float16x8_t scale_vec = vdupq_n_f16((__fp16)scale); + + for (int j = 0; j < sizeDiv; ++j) { + auto pDest = softmaxDst + j * stride0 + k * reduceSizeInner; + float16x8_t data = vld1q_f16(pDest); + data = vmulq_f16(data, scale_vec); + vst1q_f16(pDest, data); + } + + if (remain > 0) { + auto pDest = softmaxDst + sizeDiv * stride0 + k * reduceSizeInner; + for (int i = 0; i < remain; ++i) { + pDest[i] = (__fp16)((float)pDest[i] * scale); + } + } + } + + // 4. 
memset invalid positions to zero + if (pack > 1) { + if (validReduceSize % packUnit > 0) { + memset(softmaxDst + sizeDiv * stride0 + k * reduceSizeInner + (validReduceSize % packUnit), 0, (packUnit - (validReduceSize % packUnit)) * sizeof(__fp16)); + } + auto validDiv8 = UP_DIV(validReduceSize, packUnit); + auto allDiv8 = UP_DIV(reduceSize, packUnit); + for (int j = validDiv8; j < allDiv8; ++j) { + auto destPtr = softmaxDst + j * stride0 + k * reduceSizeInner; + memset(destPtr, 0, packUnit * sizeof(__fp16)); + } + } else { + memset(softmaxDst + k * reduceSizeInner + validReduceSize, 0, (reduceSize - validReduceSize) * sizeof(__fp16)); + } + } +} + static CoreFunctions* gInstance = nullptr; static CoreInt8Functions* gArm82CoreInt8Functions = nullptr; @@ -2401,11 +2641,11 @@ bool Arm82Functions::init() { gArm82CoreInt8Functions = new CoreInt8Functions; *gArm82CoreInt8Functions = *MNNGetInt8CoreFunctions(); gInstance->int8MatmulRelatedFunctions = origin->int8MatmulRelatedFunctions; - gInstance->sme2Int8MatmulRelatedFuncionsHp32 = origin->sme2Int8MatmulRelatedFuncionsHp32; { if (origin->supportSDot) { gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _Arm82MNNPackC4ForMatMul_A<12, 4>; - gInstance->supportSDot = true; + gInstance->arm82MatmulRelatedFunctions = origin->arm82MatmulRelatedFunctions; + gInstance->arm82MatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = _Arm82MNNPackC4ForMatMul_A<12, 4>; } if (origin->supportI8mm) { gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A_L8<10, 8>; @@ -2435,7 +2675,7 @@ bool Arm82Functions::init() { FUNC_PTR_ASSIGN(gInstance->MNNStrassenMergeCFunction, ARM82StrassenMerge); gInstance->MNNReorderWeightInt4 = origin->MNNReorderWeightInt4; gInstance->MNNSumWeightInt8 = origin->MNNSumWeightInt8; - gInstance->MNNSumWeightInt8SmeHp64 = origin->MNNSumWeightInt8SmeHp64; + gInstance->MNNSumWeightInt8SmeHp128 = origin->MNNSumWeightInt8SmeHp128; gInstance->penalty = 2.0f; 
FUNC_PTR_ASSIGN(gInstance->MNNScaleAndAddBias, MNNScaleAndAddBiasFP16); FUNC_PTR_ASSIGN(gInstance->MNNGridSampleComputeCord, MNNGridSampleComputeCordFP16); @@ -2456,12 +2696,13 @@ bool Arm82Functions::init() { FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Arm82MNNPackForMatMul_A); FUNC_PTR_ASSIGN(gInstance->MNNPackForMatMul_B, Arm82MNNPackForMatMul_B); - FUNC_PTR_ASSIGN(gInstance->MNNSoftmax, origin->MNNSoftmax); + FUNC_PTR_ASSIGN(gInstance->MNNSoftmax, MNNSoftmaxFp16); #if defined(__aarch64__) gInstance->supportFp16arith = origin->supportFp16arith; gInstance->supportSDot = origin->supportSDot; gInstance->supportI8mm = origin->supportI8mm; gInstance->supportSME2 = origin->supportSME2; + gInstance->smeCoreNumber = origin->smeCoreNumber; #ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM // Weight Dequant Gemm Kernels FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul_int8, MNNPackedMatMulFP16_int8); @@ -2478,6 +2719,7 @@ bool Arm82Functions::init() { if (origin->supportSDot) { FUNC_PTR_ASSIGN(gInstance->MNNGeneralIm2Col, MNNGeneralIm2col_Arm82); + gInstance->arm82MatmulRelatedFunctions.MNNGeneralIm2Col = MNNGeneralIm2col_Arm82; } if (origin->supportI8mm) { FUNC_PTR_ASSIGN(gInstance->MNNGeneralIm2Col, MNNGeneralIm2col_Arm86); @@ -2494,6 +2736,8 @@ bool Arm82Functions::init() { FUNC_PTR_ASSIGN(gInstance->MNNAttenPackAndConvertFp32, MNNAttenPackAndConvertFp32); FUNC_PTR_ASSIGN(gInstance->MNNAttenPackAndScaleSingleHead, MNNAttenPackAndScaleSingleHead); FUNC_PTR_ASSIGN(gInstance->MNNFlashAttentionUpdateBlockOutput, MNNFlashAttentionUpdateBlockOutput); + gInstance->MNNQuantAttentionKey = MNNQuantAttentionKeyFP16; + gInstance->MNNQuantAttentionValue = MNNQuantAttentionValueFP16; #endif // MNN_SUPPORT_TRANSFORMER_FUSE gInstance->MNNComputeMatMulForH_1 = _MNNComputeMatMulForH_1_FP16; @@ -2520,11 +2764,11 @@ bool Arm82Functions::init() { gInstance->int8MatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A; 
gInstance->int8MatmulRelatedFunctions.MNNGeneralIm2Col = gInstance->MNNGeneralIm2Col; } + #ifdef __aarch64__ #ifdef MNN_SME2 if (origin->supportSME2) { gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _Arm82MNNPackC4ForMatMul_A<16, 4>; - gInstance->sme2Int8MatmulRelatedFuncionsHp32.MNNPackC4Int8ForMatMul_A = _Arm82MNNPackC4ForMatMul_A<16, 4>; FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul, MNNPackedMatMulFP16_SME2); FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain, MNNPackedMatMulRemainFP16_SME2); @@ -2534,12 +2778,16 @@ bool Arm82Functions::init() { #ifdef MNN_LOW_MEMORY FUNC_PTR_ASSIGN(gInstance->MNNGeneralIm2Col, MNNGeneralIm2col_Fp16Sme2); - gInstance->sme2Int8MatmulRelatedFuncionsHp32.MNNGeneralIm2Col = MNNGeneralIm2col_Fp16Sme2; #endif } #endif // MNN_SME2 #endif // __aarch64__ + // Update the function pointers in the int8MatmulRelatedFunctions struct. + gInstance->int8MatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A; + gInstance->int8MatmulRelatedFunctions.MNNGeneralIm2Col = gInstance->MNNGeneralIm2Col; + + return true; } diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16.S b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16.S index a72cc35f18..c7acea666a 100644 --- a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16.S +++ b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16.S @@ -131,19 +131,20 @@ ldr x23, [x6, #56] // fp32minmax ldr x26, [x6, #64] // blockNum lsl x22, x7, #2 // eDest * SRC_UNIT -mov x14, #-32 +mov x25, #-32 cbz x27, TILE_12 -mov x14, #16 +sub x25, x22, #32 TILE_12: cmp x7, #12 blt TILE_8 sub x4, x4, #128 + mov x12, x2 + mov x6, x0 + mov x14, x5 mov x20, x9 mov x15, x8 // input kernel sum - mov x6, x27 // input dequant bias mov x21, x24 // input dequant scale - mov x12, #-320 L8LoopDz_TILE_12: mov x11, x1 mov x19, #0 @@ -157,8 +158,8 @@ 
TILE12_BLOCKNUM: SET_BIAS v28, v29, v30, v31 L8LoopSz_TILE_12: - ld1 {v3.16b, v4.16b}, [x2], #32 // weight - ld1 {v0.16b, v1.16b, v2.16b}, [x11], #48 // src + ld1 {v3.16b, v4.16b}, [x12], #32 // weight + ld1 {v0.16b, v1.16b, v2.16b}, [x11], x22 // src .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] @@ -168,6 +169,7 @@ TILE12_BLOCKNUM: .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] + .inst 0x4f82e070 // sdot v16.4s, v3.16b, v2.4b[0] .inst 0x4fa2e071 // sdot v17.4s, v3.16b, v2.4b[1] .inst 0x4f82e872 // sdot v18.4s, v3.16b, v2.4b[2] @@ -191,9 +193,8 @@ TILE12_BLOCKNUM: L8LoopSzEnd_TILE_12: L8Tile12Quan: - ld1 {v0.4s, v1.4s}, [x2], #32 // scale - ld1 {v2.4s, v3.4s, v4.4s}, [x8], #48 // input kernel sum - ld1 {v5.4s, v6.4s}, [x2], #32 // weight quan zeropoint + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x12], #64 // weight scale&bias + ld1 {v4.4s, v5.4s, v6.4s}, [x8], x22 // input kernel sum Int32ToFloat v8, v9, v10, v11 Int32ToFloat v12, v13, v14, v15 Int32ToFloat v16, v17, v18, v19 @@ -209,7 +210,7 @@ TILE12_BLOCKNUM: MUL_SCALE v1, v28, v29, v30, v31 ld1 {v0.4s, v1.4s}, [x24], #32 - ld1 {v7.4s}, [x24], x14 + ld1 {v7.4s}, [x24], x25 MUL_EXTRA_SCALE v0, v8, v9, v10, v11 MUL_EXTRA_SCALE v1, v12, v13, v14, v15 MUL_EXTRA_SCALE v7, v16, v17, v18, v19 @@ -218,34 +219,34 @@ TILE12_BLOCKNUM: MUL_EXTRA_SCALE v7, v28, v29, v30, v31 TILE12_L8_MLA_TERM: - MLA_WEIGHTZERO v8, v2, v5, 0 // tile:0, oc:0-3 - MLA_WEIGHTZERO v9, v2, v5, 1 // tile:1, oc:0-3 - MLA_WEIGHTZERO v10, v2, v5, 2 // tile:2, oc:0-3 - MLA_WEIGHTZERO v11, v2, v5, 3 // tile:3, oc:0-3 - MLA_WEIGHTZERO v12, v3, v5, 0 // tile:4, oc:0-3 - MLA_WEIGHTZERO v13, v3, v5, 1 // tile:5, oc:0-3 - MLA_WEIGHTZERO v14, v3, v5, 2 // tile:6, oc:0-3 - MLA_WEIGHTZERO v15, v3, v5, 3 // tile:7, oc:0-3 - MLA_WEIGHTZERO v16, v4, v5, 0 // tile:8, 
oc:0-3 - MLA_WEIGHTZERO v17, v4, v5, 1 // tile:9, oc:0-3 - MLA_WEIGHTZERO v18, v4, v5, 2 // tile:10, oc:0-3 - MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3 - - MLA_WEIGHTZERO v20, v2, v6, 0 // tile:0, oc:4-7 - MLA_WEIGHTZERO v21, v2, v6, 1 // tile:1, oc:4-7 - MLA_WEIGHTZERO v22, v2, v6, 2 // tile:2, oc:4-7 - MLA_WEIGHTZERO v23, v2, v6, 3 // tile:3, oc:4-7 - MLA_WEIGHTZERO v24, v3, v6, 0 // tile:4, oc:4-7 - MLA_WEIGHTZERO v25, v3, v6, 1 // tile:5, oc:4-7 - MLA_WEIGHTZERO v26, v3, v6, 2 // tile:6, oc:4-7 - MLA_WEIGHTZERO v27, v3, v6, 3 // tile:7, oc:4-7 - MLA_WEIGHTZERO v28, v4, v6, 0 // tile:8, oc:4-7 - MLA_WEIGHTZERO v29, v4, v6, 1 // tile:9, oc:4-7 - MLA_WEIGHTZERO v30, v4, v6, 2 // tile:10, oc:4-7 - MLA_WEIGHTZERO v31, v4, v6, 3 // tile:11, oc:4-7 + MLA_WEIGHTZERO v8, v4, v2, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v4, v2, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v4, v2, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v4, v2, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v5, v2, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v13, v5, v2, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v5, v2, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v15, v5, v2, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v16, v6, v2, 0 // tile:8, oc:0-3 + MLA_WEIGHTZERO v17, v6, v2, 1 // tile:9, oc:0-3 + MLA_WEIGHTZERO v18, v6, v2, 2 // tile:10, oc:0-3 + MLA_WEIGHTZERO v19, v6, v2, 3 // tile:11, oc:0-3 + + MLA_WEIGHTZERO v20, v4, v3, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v21, v4, v3, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v22, v4, v3, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v23, v4, v3, 3 // tile:3, oc:4-7 + MLA_WEIGHTZERO v24, v5, v3, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v25, v5, v3, 1 // tile:5, oc:4-7 + MLA_WEIGHTZERO v26, v5, v3, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v27, v5, v3, 3 // tile:7, oc:4-7 + MLA_WEIGHTZERO v28, v6, v3, 0 // tile:8, oc:4-7 + MLA_WEIGHTZERO v29, v6, v3, 1 // tile:9, oc:4-7 + MLA_WEIGHTZERO v30, v6, v3, 2 // tile:10, oc:4-7 + MLA_WEIGHTZERO v31, v6, v3, 3 // tile:11, oc:4-7 cbz x27, TILE12_ADD_DSTV - ld1 
{v0.4s, v1.4s, v2.4s}, [x27], #48 // input dequant bias + ld1 {v0.4s, v1.4s, v2.4s}, [x27], x22 // input dequant bias ld1 {v3.4s, v4.4s}, [x28], #32 // weight kernel sum MLA_WEIGHTZERO v8, v0, v3, 0 MLA_WEIGHTZERO v9, v0, v3, 1 @@ -284,9 +285,10 @@ TILE12_BLOCKNUM: ADD_FLOAT v16, v17, v18, v19, v0, v1, v2, v3 ADD_FLOAT v20, v21, v22, v23, v4, v5, v6, v7 ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64 - ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], x12 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10] ADD_FLOAT v24, v25, v26, v27, v0, v1, v2, v3 ADD_FLOAT v28, v29, v30, v31, v4, v5, v6, v7 + sub x10, x10, #320 TILE12_L8_ACCUM_BUFFER: add x19, x19, #1 @@ -297,11 +299,12 @@ TILE12_BLOCKNUM: st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64 st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64 st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64 - st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], x12 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10] + sub x10, x10, #320 b TILE12_BLOCKNUM TILE12_POST: - sub x5, x5, #1 + sub x14, x14, #1 cbz x9, TILE12_CVT_FP16 ld1 {v0.4s, v1.4s}, [x20], #32 ADD_BIAS_FLOAT v8, v9, v10, v11, v0 @@ -329,16 +332,32 @@ TILE12_BLOCKNUM: TILE12_STORE: - st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 - st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 - st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], x4 + st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x6], #64 + st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x6], #64 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x6], x4 L8Tile12LoopCheck: - cbz x5, End + cbz x14, Tile12End mov x8, x15 // revert input kernel sum mov x24, x21 // revert input dequant scale - mov x27, x6 // revert input dequant bias + cbz x27, L8LoopDz_TILE_12 + REVERT_INPUT_DEQUANT_BIAS x27, x19, x26, x22 b L8LoopDz_TILE_12 +Tile12End: + + add x0, x0, #192 + sub x7, x7, #12 + cbz x7, End + add x1, x1, #48 + add x8, x15, #48 + add x24, x21, #48 + add x4, x4, #128 // revert x4 + + cbz x27, TILE_8 + REVERT_INPUT_DEQUANT_BIAS x27, x19, x26, x22 + REVERT_WEIGHT_KERNEL_SUM x28, x14, x26, x5 + add x27, x27, #48 + 
TILE_8: mov x25, #0 cbz x27, TILE_Remain @@ -365,6 +384,7 @@ TILE8_BLOCKNUM: SET_BIAS v12, v13, v14, v15 SET_BIAS v16, v17, v18, v19 SET_BIAS v20, v21, v22, v23 + L8LoopSz_TILE_8: ld1 {v3.16b, v4.16b}, [x12], #32 // weight ld1 {v0.16b, v1.16b}, [x11], x22 // src @@ -542,6 +562,7 @@ TILE4_BLOCKNUM: .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] + subs x13, x13, #1 .inst 0x4f80e08c // sdot v12.4s, v4.16b, v0.4b[0] .inst 0x4fa0e08d // sdot v13.4s, v4.16b, v0.4b[1] diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16.S b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16.S index d8d9bd4df2..46cb8b98b7 100644 --- a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16.S +++ b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16.S @@ -131,19 +131,20 @@ ldr x23, [x6, #56] // fp32minmax ldr x26, [x6, #64] // blockNum lsl x22, x7, #2 // eDest * SRC_UNIT -mov x14, #-32 +mov x25, #-32 cbz x27, TILE_12 -mov x14, #16 +sub x25, x22, #32 TILE_12: cmp x7, #12 blt TILE_8 sub x4, x4, #128 + mov x12, x2 + mov x6, x0 + mov x14, x5 mov x20, x9 mov x15, x8 // input kernel sum - mov x6, x27 // input dequant bias mov x21, x24 // input dequant scale - mov x12, #-320 L8LoopDz_TILE_12: mov x11, x1 mov x19, #0 @@ -159,8 +160,8 @@ TILE12_BLOCKNUM: SET_BIAS v28, v29, v30, v31 L8LoopSz_TILE_12: - ld1 {v5.16b}, [x2], #16 // weight - ld1 {v0.16b, v1.16b, v2.16b}, [x11], #48 // src + ld1 {v5.16b}, [x12], #16 // weight + ld1 {v0.16b, v1.16b, v2.16b}, [x11], x22 // src // int4->int8 ushr v3.16b, v5.16b, #4 and v4.16b, v5.16b, v7.16b @@ -198,9 +199,8 @@ TILE12_BLOCKNUM: L8LoopSzEnd_TILE_12: L8Tile12Quan: - ld1 {v0.4s, v1.4s}, [x2], #32 // scale - ld1 {v2.4s, v3.4s, v4.4s}, [x8], #48 // input kernel sum - ld1 {v5.4s, v6.4s}, [x2], #32 // weight quan zeropoint 
+ ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x12], #64 // weight scale&bias + ld1 {v4.4s, v5.4s, v6.4s}, [x8], x22 // input kernel sum Int32ToFloat v8, v9, v10, v11 Int32ToFloat v12, v13, v14, v15 Int32ToFloat v16, v17, v18, v19 @@ -216,7 +216,7 @@ TILE12_BLOCKNUM: MUL_SCALE v1, v28, v29, v30, v31 ld1 {v0.4s, v1.4s}, [x24], #32 - ld1 {v7.4s}, [x24], x14 + ld1 {v7.4s}, [x24], x25 MUL_EXTRA_SCALE v0, v8, v9, v10, v11 MUL_EXTRA_SCALE v1, v12, v13, v14, v15 MUL_EXTRA_SCALE v7, v16, v17, v18, v19 @@ -225,34 +225,34 @@ TILE12_BLOCKNUM: MUL_EXTRA_SCALE v7, v28, v29, v30, v31 TILE12_L8_MLA_TERM: - MLA_WEIGHTZERO v8, v2, v5, 0 // tile:0, oc:0-3 - MLA_WEIGHTZERO v9, v2, v5, 1 // tile:1, oc:0-3 - MLA_WEIGHTZERO v10, v2, v5, 2 // tile:2, oc:0-3 - MLA_WEIGHTZERO v11, v2, v5, 3 // tile:3, oc:0-3 - MLA_WEIGHTZERO v12, v3, v5, 0 // tile:4, oc:0-3 - MLA_WEIGHTZERO v13, v3, v5, 1 // tile:5, oc:0-3 - MLA_WEIGHTZERO v14, v3, v5, 2 // tile:6, oc:0-3 - MLA_WEIGHTZERO v15, v3, v5, 3 // tile:7, oc:0-3 - MLA_WEIGHTZERO v16, v4, v5, 0 // tile:8, oc:0-3 - MLA_WEIGHTZERO v17, v4, v5, 1 // tile:9, oc:0-3 - MLA_WEIGHTZERO v18, v4, v5, 2 // tile:10, oc:0-3 - MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3 - - MLA_WEIGHTZERO v20, v2, v6, 0 // tile:0, oc:4-7 - MLA_WEIGHTZERO v21, v2, v6, 1 // tile:1, oc:4-7 - MLA_WEIGHTZERO v22, v2, v6, 2 // tile:2, oc:4-7 - MLA_WEIGHTZERO v23, v2, v6, 3 // tile:3, oc:4-7 - MLA_WEIGHTZERO v24, v3, v6, 0 // tile:4, oc:4-7 - MLA_WEIGHTZERO v25, v3, v6, 1 // tile:5, oc:4-7 - MLA_WEIGHTZERO v26, v3, v6, 2 // tile:6, oc:4-7 - MLA_WEIGHTZERO v27, v3, v6, 3 // tile:7, oc:4-7 - MLA_WEIGHTZERO v28, v4, v6, 0 // tile:8, oc:4-7 - MLA_WEIGHTZERO v29, v4, v6, 1 // tile:9, oc:4-7 - MLA_WEIGHTZERO v30, v4, v6, 2 // tile:10, oc:4-7 - MLA_WEIGHTZERO v31, v4, v6, 3 // tile:11, oc:4-7 + MLA_WEIGHTZERO v8, v4, v2, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v4, v2, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v4, v2, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v4, v2, 3 // tile:3, oc:0-3 + 
MLA_WEIGHTZERO v12, v5, v2, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v13, v5, v2, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v5, v2, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v15, v5, v2, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v16, v6, v2, 0 // tile:8, oc:0-3 + MLA_WEIGHTZERO v17, v6, v2, 1 // tile:9, oc:0-3 + MLA_WEIGHTZERO v18, v6, v2, 2 // tile:10, oc:0-3 + MLA_WEIGHTZERO v19, v6, v2, 3 // tile:11, oc:0-3 + + MLA_WEIGHTZERO v20, v4, v3, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v21, v4, v3, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v22, v4, v3, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v23, v4, v3, 3 // tile:3, oc:4-7 + MLA_WEIGHTZERO v24, v5, v3, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v25, v5, v3, 1 // tile:5, oc:4-7 + MLA_WEIGHTZERO v26, v5, v3, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v27, v5, v3, 3 // tile:7, oc:4-7 + MLA_WEIGHTZERO v28, v6, v3, 0 // tile:8, oc:4-7 + MLA_WEIGHTZERO v29, v6, v3, 1 // tile:9, oc:4-7 + MLA_WEIGHTZERO v30, v6, v3, 2 // tile:10, oc:4-7 + MLA_WEIGHTZERO v31, v6, v3, 3 // tile:11, oc:4-7 cbz x27, TILE12_ADD_DSTV - ld1 {v0.4s, v1.4s, v2.4s}, [x27], #48 // input dequant bias + ld1 {v0.4s, v1.4s, v2.4s}, [x27], x22 // input dequant bias ld1 {v3.4s, v4.4s}, [x28], #32 // weight kernel sum MLA_WEIGHTZERO v8, v0, v3, 0 MLA_WEIGHTZERO v9, v0, v3, 1 @@ -291,9 +291,10 @@ TILE12_BLOCKNUM: ADD_FLOAT v16, v17, v18, v19, v0, v1, v2, v3 ADD_FLOAT v20, v21, v22, v23, v4, v5, v6, v7 ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64 - ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], x12 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10] ADD_FLOAT v24, v25, v26, v27, v0, v1, v2, v3 ADD_FLOAT v28, v29, v30, v31, v4, v5, v6, v7 + sub x10, x10, #320 TILE12_L8_ACCUM_BUFFER: add x19, x19, #1 @@ -304,11 +305,12 @@ TILE12_BLOCKNUM: st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64 st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64 st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64 - st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], x12 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10] + sub x10, x10, #320 b TILE12_BLOCKNUM 
TILE12_POST: - sub x5, x5, #1 + sub x14, x14, #1 cbz x9, TILE12_CVT_FP16 ld1 {v0.4s, v1.4s}, [x20], #32 ADD_BIAS_FLOAT v8, v9, v10, v11, v0 @@ -336,16 +338,32 @@ TILE12_BLOCKNUM: TILE12_STORE: - st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 - st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 - st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], x4 + st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x6], #64 + st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x6], #64 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x6], x4 L8Tile12LoopCheck: - cbz x5, End + cbz x14, Tile12End mov x8, x15 // revert input kernel sum mov x24, x21 // revert input dequant scale - mov x27, x6 // revert input dequant bias + cbz x27, L8LoopDz_TILE_12 + REVERT_INPUT_DEQUANT_BIAS x27, x19, x26, x22 b L8LoopDz_TILE_12 +Tile12End: + + add x0, x0, #192 + sub x7, x7, #12 + cbz x7, End + add x1, x1, #48 + add x8, x15, #48 + add x24, x21, #48 + add x4, x4, #128 // revert x4 + + cbz x27, TILE_8 + REVERT_INPUT_DEQUANT_BIAS x27, x19, x26, x22 + REVERT_WEIGHT_KERNEL_SUM x28, x14, x26, x5 + add x27, x27, #48 + TILE_8: mov x25, #0 cbz x27, TILE_Remain diff --git a/source/backend/arm82/asm/arm64/sme2_asm/MNNPackedMatMulRemainFP16_SME2.S b/source/backend/arm82/asm/arm64/sme2_asm/MNNPackedMatMulRemainFP16_SME2.S index 25c1e3632c..795b6420df 100644 --- a/source/backend/arm82/asm/arm64/sme2_asm/MNNPackedMatMulRemainFP16_SME2.S +++ b/source/backend/arm82/asm/arm64/sme2_asm/MNNPackedMatMulRemainFP16_SME2.S @@ -32,7 +32,7 @@ ldr x10, [x4, #16] // h ldr x7, [x4, #24] // cStride ldr x19, [x4, #40] // bExtraStride -lsr x7, x7, #1 // cStride / sizeof(float16_t) + lsl x21, x3, #1 // eSize * lP mov w12, #0 @@ -65,6 +65,10 @@ cbz x5, ESIZE .inst 0x052223bf // dup z31.h, z29.h[0] ESIZE: // x3 <= eP +cmp x3, #1 +beq E1 + +lsr x7, x7, #1 // cStride / sizeof(float16_t) cmp x3, #16 blt LoopOcDiv8 @@ -277,6 +281,185 @@ beq End b LoopOcDiv8 // continue next ocDiv8 +E1: +cmp x3, #1 +blt End + +E1LoopH: +mov x8, x1 // A +mov x21, x9 // LU + +.inst 0xc00800ff // zero {za} +mov w11, #0 
+ +cbz x6, E1LoopL +// bias +lsl x4, x10, #3 +.inst 0x256447f1 // whilelt pn9.h, xzr, x4, vlx2 // oc to process, maximum is 64 +.inst 0xa04024c8 // ld1h {z8.h-z9.h}, pn9/z, [x6] +.inst 0x04265046 // addvl x6, x6, #2 + +.inst 0x6589a50c // fcvt z12.s, p1/m, z8.h +.inst 0x6489a50d // fcvtlt z13.s, p1/m, z8.h +.inst 0x6589a52e // fcvt z14.s, p1/m, z9.h +.inst 0x6489a52f // fcvtlt z15.s, p1/m, z9.h + +.inst 0x05ad6194 // zip1 z20.s, z12.s, z13.s +.inst 0x05ad6595 // zip2 z21.s, z12.s, z13.s +.inst 0x05af61d6 // zip1 z22.s, z14.s, z15.s +.inst 0x05af65d7 // zip2 z23.s, z14.s, z15.s +.inst 0xc1a17e80 // fadd za.s[w11, 0, VGx4], {z20.s-z23.s} + +E1LoopL: +.inst 0x8540c504 // ld1rw {z4.s}, p1/z, [x8] // A +.inst 0xa040a040 // ld1h {z0.h-z3.h}, pn8/z, [x2] // B +// [EP,LP] x [HP,LP] -> [EP,HP] +.inst 0xc1347000 // fdot za.s[w11, 0, VGx4], {z0.h-z3.h}, z4.h + +subs x21, x21, #1 +add x8, x8, x22 +.inst 0x04225082 // addvl x2, x2, #4 +bne E1LoopL + +add x2, x2, x19 // bExtraStride + +E1_To_FP16: +.inst 0xc0820000 // mova z0.s, p0/m, za0h.s[w12, 0] +.inst 0xc0822001 // mova z1.s, p0/m, za0h.s[w13, 0] +.inst 0xc0824002 // mova z2.s, p0/m, za0h.s[w14, 0] +.inst 0xc0826003 // mova z3.s, p0/m, za0h.s[w15, 0] + +.inst 0xc120e010 // fcvt z16.h, {z0.s-z1.s} +.inst 0xc120e051 // fcvt z17.h, {z2.s-z3.s} +cbz x5, E1Store +.inst 0xc17fc3d0 // fclamp {z16.h-z17.h}, z30.h, z31.h + +E1Store: +cmp x10, #8 +bge E1StoreH64 + +cmp x10, #1 +beq E1StoreH8 + +cmp x10, #2 +beq E1StoreH16 + +cmp x10, #3 +beq E1StoreH24 + +cmp x10, #4 +beq E1StoreH32 + +cmp x10, #5 +beq E1StoreH40 + +cmp x10, #6 +beq E1StoreH48 + +cmp x10, #7 +beq E1StoreH56 + +E1StoreH64: +add x21, x0, x7, LSL #1 // x0+2*x7 +add x11, x0, x7, LSL #2 // x0+4*x7 +add x20, x11, x7, LSL #1 // x20+6*x7 +.inst 0x05702212 // dup z18.q, z16.q[1] +.inst 0x05b02213 // dup z19.q, z16.q[2] +.inst 0x05f02214 // dup z20.q, z16.q[3] +.inst 0x05702235 // dup z21.q, z17.q[1] +.inst 0x05b02236 // dup z22.q, z17.q[2] +.inst 0x05f02237 // dup z23.q, 
z17.q[3] +.inst 0xe400f010 // st1b {z16.b}, p4, [x0] +.inst 0xe4075012 // st1b {z18.b}, p4, [x0, x7] +.inst 0xe400f2b3 // st1b {z19.b}, p4, [x21] +.inst 0xe40752b4 // st1b {z20.b}, p4, [x21, x7] +.inst 0xe400f171 // st1b {z17.b}, p4, [x11] +.inst 0xe4075175 // st1b {z21.b}, p4, [x11, x7] +.inst 0xe400f296 // st1b {z22.b}, p4, [x20] +.inst 0xe4075297 // st1b {z23.b}, p4, [x20, x7] +b E1H16_End + +E1StoreH56: +add x21, x0, x7, LSL #1 +add x11, x0, x7, LSL #2 +add x20, x11, x7, LSL #1 +.inst 0x05702212 // dup z18.q, z16.q[1] +.inst 0x05b02213 // dup z19.q, z16.q[2] +.inst 0x05f02214 // dup z20.q, z16.q[3] +.inst 0x05702235 // dup z21.q, z17.q[1] +.inst 0x05b02236 // dup z22.q, z17.q[2] +.inst 0xe400f010 // st1b {z16.b}, p4, [x0] +.inst 0xe4075012 // st1b {z18.b}, p4, [x0, x7] +.inst 0xe400f2b3 // st1b {z19.b}, p4, [x21] +.inst 0xe40752b4 // st1b {z20.b}, p4, [x21, x7] +.inst 0xe400f171 // st1b {z17.b}, p4, [x11] +.inst 0xe4075175 // st1b {z21.b}, p4, [x11, x7] +.inst 0xe400f296 // st1b {z22.b}, p4, [x20] +b End + +E1StoreH48: +add x21, x0, x7, LSL #1 +add x11, x0, x7, LSL #2 +.inst 0x05702212 // dup z18.q, z16.q[1] +.inst 0x05b02213 // dup z19.q, z16.q[2] +.inst 0x05f02214 // dup z20.q, z16.q[3] +.inst 0x05702235 // dup z21.q, z17.q[1] +.inst 0xe400f010 // st1b {z16.b}, p4, [x0] +.inst 0xe4075012 // st1b {z18.b}, p4, [x0, x7] +.inst 0xe400f2b3 // st1b {z19.b}, p4, [x21] +.inst 0xe40752b4 // st1b {z20.b}, p4, [x21, x7] +.inst 0xe400f171 // st1b {z17.b}, p4, [x11] +.inst 0xe4075175 // st1b {z21.b}, p4, [x11, x7] +b End + +E1StoreH40: +add x21, x0, x7, LSL #1 +add x11, x0, x7, LSL #2 +.inst 0x05702212 // dup z18.q, z16.q[1] +.inst 0x05b02213 // dup z19.q, z16.q[2] +.inst 0x05f02214 // dup z20.q, z16.q[3] +.inst 0xe400f010 // st1b {z16.b}, p4, [x0] +.inst 0xe4075012 // st1b {z18.b}, p4, [x0, x7] +.inst 0xe400f2b3 // st1b {z19.b}, p4, [x21] +.inst 0xe40752b4 // st1b {z20.b}, p4, [x21, x7] +.inst 0xe400f171 // st1b {z17.b}, p4, [x11] +b End + +E1StoreH32: +add x21, x0, x7, 
LSL #1 +.inst 0x05702212 // dup z18.q, z16.q[1] +.inst 0x05b02213 // dup z19.q, z16.q[2] +.inst 0x05f02214 // dup z20.q, z16.q[3] +.inst 0xe400f010 // st1b {z16.b}, p4, [x0] +.inst 0xe4075012 // st1b {z18.b}, p4, [x0, x7] +.inst 0xe400f2b3 // st1b {z19.b}, p4, [x21] +.inst 0xe40752b4 // st1b {z20.b}, p4, [x21, x7] +b End + +E1StoreH24: +add x21, x0, x7, LSL #1 +.inst 0x05702212 // dup z18.q, z16.q[1] +.inst 0x05b02213 // dup z19.q, z16.q[2] +.inst 0xe400f010 // st1b {z16.b}, p4, [x0] +.inst 0xe4075012 // st1b {z18.b}, p4, [x0, x7] +.inst 0xe400f2b3 // st1b {z19.b}, p4, [x21] +b End + +E1StoreH16: +.inst 0x05702212 // dup z18.q, z16.q[1] +.inst 0xe400f010 // st1b {z16.b}, p4, [x0] +.inst 0xe4075012 // st1b {z18.b}, p4, [x0, x7] +b End + +E1StoreH8: +.inst 0xe400f010 // st1b {z16.b}, p4, [x0] +b End + +E1H16_End: +subs x10, x10, #8 +add x0, x0, x7, LSL #3 +bne E1LoopH + End: .inst 0xd503467f // smstop diff --git a/source/backend/cpu/CPUAttention.cpp b/source/backend/cpu/CPUAttention.cpp index a12394f9ac..e24fb4e01f 100644 --- a/source/backend/cpu/CPUAttention.cpp +++ b/source/backend/cpu/CPUAttention.cpp @@ -17,6 +17,8 @@ #include "core/BufferAllocator.hpp" #include "core/TensorUtils.hpp" #include "core/OpCommonUtils.hpp" +#include "core/BufferAllocator.hpp" +#include "compute/ConvolutionTiledExecutor.hpp" #if defined (__aarch64__) #define FLOAT16_T __fp16 @@ -24,264 +26,201 @@ #define FLOAT16_T float #endif -#define MNN_FLASH_ATTENTION_BLOCK_SIZE 64 namespace MNN { template -void CPUAttention::pack_query(Tensor* query, int8_t* pack_q, int8_t* sum_q, int seq_len, int h, float q_scale) { - if (mUseGemmInt8) { // Shape of Query: numhead, [seqlen/eP8, headdim/lP8, eP8, lP8] - mMinQ[h] = query->host()[h * mHeadDim]; - mMaxQ[h] = query->host()[h * mHeadDim]; - for (int i = 0; i < seq_len; i++) { - T * query_src = query->host() + i * mNumHead * mHeadDim + h * mHeadDim; - for (int j = 0; j < mHeadDim; j++) { - mMinQ[h] = ALIMIN(mMinQ[h], query_src[j]); - mMaxQ[h] = 
ALIMAX(mMaxQ[h], query_src[j]); - } - } - mQueryScale[h] = (mMaxQ[h] - mMinQ[h]) / 255.0f; - mQueryZeroPoint[h] = -255.0f * mMinQ[h] / (mMaxQ[h] - mMinQ[h]) - 128.0; - for (int i = 0; i < seq_len; i++) { - T * query_src = query->host() + i * mNumHead * mHeadDim + h * mHeadDim; - float sumQ = 0; - int out_index = i / eP8; - int in_index = i % eP8; - for (int j = 0; j < mHeadDim; j++) { - int a = j / lP8; - int b = j % lP8; - int quant_res = (int)roundf(query_src[j] / mQueryScale[h] + mQueryZeroPoint[h]); - sumQ += quant_res; - *((int8_t*)pack_q + out_index * UP_DIV(mHeadDim, lP8) * eP8 * lP8 + a * eP8 * lP8 + in_index * lP8 + b) = quant_res; - } - *((float*)sum_q + out_index * eP8 + in_index) = sumQ * mQueryScale[h]; +static void _maskQK(float * qkPacked, const float* scale, size_t seqLen, size_t subKvSeqLen, int pack, int maskStride, int kvoffset, const float* sinksPtr, const int8_t* maskPtr, bool quantKey) { + auto source = (T*)qkPacked; + if (quantKey == false) { + auto elementSize = seqLen * ROUND_UP(subKvSeqLen, pack); + for (int i = 0; i < elementSize; ++i) { + float data = source[i] * scale[0]; + source[i] = data; } } - else { - // target: [seq_len/eP, mHeadDim/lP, eP, lP] - T * query_src = query->host(); - T * query_dst = reinterpret_cast(pack_q); - auto stride0 = ROUND_UP(mHeadDim, lP) * eP; - auto stride1 = eP * lP; - if (mHeadDim % lP) { - memset(query_dst, 0, ROUND_UP(mHeadDim, lP) * bytes * ROUND_UP(seq_len, eP)); - } - for (int i = 0; i < seq_len; i++) { - int out_index = i / eP; - int in_index = i % eP; - for (int j = 0; j < mHeadDim; j++) { - query_dst[out_index * stride0 + (j / lP) * stride1 + in_index * lP + (j % lP)] = query_src[i * mNumHead * mHeadDim + h * mHeadDim + j] * q_scale; + + // mask: [seq, kvseq] + // data: [UP_DIV(kvseq, pack), seq, pack] + if (sinksPtr != nullptr) { + auto mask = (T*)maskPtr; + for (int i = 0; i < UP_DIV(subKvSeqLen, pack); ++i) { + for (int j = 0; j < seqLen; ++j) { + for (int k = 0; k < pack; ++k) { + if (kvoffset 
+ i * pack + k > maskStride - 1) { + break; + } + source[i * seqLen * pack + j * pack + k] = source[i * seqLen * pack + j * pack + k] + mask[j * maskStride + kvoffset + i * pack + k]; + } } } } -} -template -void CPUAttention::unpack_QK(float * unpack_qk_dst, int8_t * pack_qk_src, int seq_len, int kv_seq_len) { - float * dst = unpack_qk_dst; - T * src = (T *)(pack_qk_src); - // [kv_seq_len/mPack, seq_len, mPack] -> [seq_len, kv_seq_len] - for (int i = 0; i < seq_len; i++) { - for (int j = 0; j < kv_seq_len; j++) { - int out_index = j / mPack; - int in_index = j % mPack; - dst[i * kv_seq_len + j] = src[out_index * seq_len * mPack + i * mPack + in_index]; - } - } } -template -static void pack_QK(int8_t * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_len, int eP, int lP, int bytes) { - T * dst = reinterpret_cast(pack_qk_dst); - float * src = reinterpret_cast(qk_src); - // [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len/lP, eP, lP] - auto stride0 = ROUND_UP(kv_seq_len, lP) * eP; - auto stride1 = eP * lP; - if (kv_seq_len % lP) { - memset(dst, 0, ROUND_UP(kv_seq_len, lP) * ROUND_UP(seq_len, eP) * bytes); - } - for (int i = 0; i < seq_len; i++) { - int out_index = i / eP; - int in_index = i % eP; - for (int j = 0; j < kv_seq_len; j++) { - dst[out_index * stride0 + (j / lP) * stride1 + in_index * lP + (j % lP)] = src[i * kv_seq_len + j]; - } - } -} +ErrorCode CPUAttention::onResize(const std::vector& inputs, const std::vector& outputs) { + auto gcore = static_cast(backend())->functions(); + auto core = static_cast(backend())->int8Functions(); + gcore->MNNGetMatMulPackMode(&eP, &lP, &hP); + mThreadNum = ((CPUBackend *)backend())->threadNumber(); + mPack = gcore->pack; + mBytes = gcore->bytes; + int qkvQuantOptions = static_cast(backend())->getRuntime()->hint().qkvQuantOption; + mUseFlashAttention = (qkvQuantOptions / 8 == 1); -template -static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale, float min_val, const Tensor* maskTensor, int 
offset, int startIndx, int processedKvLen, int extraSeq) { + // If slide window attention applied, quant key/value must be diabled + mQuantKey = inputs.size() < 5 && (qkvQuantOptions % 8 >= 1); + mQuantValue = inputs.size() < 5 && (qkvQuantOptions % 8 > 1) && mUseFlashAttention; + static_cast(backend())->int8Functions()->MNNGetGemmUnit(&hP8, &lP8, &eP8); - int endIndx = startIndx + processedKvLen; - if (maskTensor == nullptr) { - for (int i = 0; i < processedKvLen; i++) { - unpack_qk[i] = unpack_qk[i] * mScale; - } - return; + auto query = inputs[0]; + auto key = inputs[1]; + int seqLen = query->length(1); + int mBlockNum = 1; + mNumHead = query->length(2); + mHeadDim = query->length(3); + mKvNumHead = key->length(2); + mKVCacheManager->setAttenQuantKeyValue(mUseFlashAttention, mQuantKey, mQuantValue); + mKVCacheManager->onResize(mKvNumHead, mHeadDim); + + // Common buffer allocated + auto bufferAlloc = static_cast(backend())->getBufferAllocator(); + mPackQKV.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, mPack), seqLen, mPack * mBytes})); + backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC); + if (inputs.size() > 4 || mUseFlashAttention) { // needed by flash attention and sliding attention with sink + mRunningMax.reset(Tensor::createDevice({mThreadNum, seqLen * 4})); + mRunningSum.reset(Tensor::createDevice({mThreadNum, seqLen * 4})); + backend()->onAcquireBuffer(mRunningMax.get(), Backend::DYNAMIC); + backend()->onAcquireBuffer(mRunningSum.get(), Backend::DYNAMIC); } - const int8_t* mask = maskTensor->host(); - halide_type_t htype = maskTensor->getType(); - int maskSize = maskTensor->elementSize(); - - if (htype == halide_type_of()) { - // float mask - T* fpmask_ptr = (T*)mask; - if (maskSize == (seq_len + extraSeq) * (kv_seq_len + extraSeq)) { // sliding attention, mask shape: [seq_len, kv_seq_len] - for (int i = 0; i < seq_len; ++i) { - auto unpack_qki = unpack_qk + i * processedKvLen; - auto fpmask_ptri = fpmask_ptr + i * (kv_seq_len + 
extraSeq); - for (int j = startIndx; j < endIndx; ++j) { - unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale + fpmask_ptri[j]; - } - } - } else { // mask shape: [seq_len, seq_len] - for (int i = 0; i < seq_len; ++i) { - auto unpack_qki = unpack_qk + i * processedKvLen; - auto fpmask_ptri = fpmask_ptr + i * (seq_len + extraSeq); - - auto notMaskIndx = ALIMIN(endIndx, offset); - auto stMaskIndx = ALIMAX(startIndx, offset); - for (int j = startIndx; j < notMaskIndx; ++j) { - unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale; - } - for (int j = stMaskIndx; j < endIndx; ++j) { - unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale + fpmask_ptri[j - offset]; - } + if (mUseFlashAttention) { // extra buffer need by flash attention + mExpfDiffMax.reset(Tensor::createDevice({mThreadNum, seqLen * 4})); + mTempOut.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, mPack), seqLen, mPack * mBytes})); + backend()->onAcquireBuffer(mExpfDiffMax.get(), Backend::DYNAMIC); + backend()->onAcquireBuffer(mTempOut.get(), Backend::DYNAMIC); + } + if (mQuantKey) { + int outterSeqLen = UP_DIV(seqLen, eP8); + int outterHeadDim = UP_DIV(mHeadDim, lP8); + + size_t packedQSize = 0; + if (outterSeqLen > 0) { + int fullSeqBlocks = (seqLen / eP8); + packedQSize += (size_t)fullSeqBlocks * outterHeadDim * eP8 * lP8; + + int lastEUnit = seqLen % eP8; + if (lastEUnit != 0) { + packedQSize += (size_t)outterHeadDim * lastEUnit * lP8; } } - } else { - // int mask - int* mask_ptr = (int*)mask; - for (int i = 0; i < seq_len; ++i) { - for (int j = 0; j < processedKvLen; ++j) { - int maskIndex = i * kv_seq_len + startIndx +j; - if (mask_ptr[maskIndex]) { - unpack_qk[i * processedKvLen + j] = unpack_qk[i * processedKvLen + j] * mScale; - } else { - unpack_qk[i * processedKvLen + j] = min_val; - } + mPackQ.reset(Tensor::createDevice({mNumHead, (int32_t)packedQSize})); + backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC); + + mSumQ = bufferAlloc->alloc(mThreadNum 
* ROUND_UP(seqLen, eP8) * mBlockNum * sizeof(int32_t)); + mQueryScale = bufferAlloc->alloc(mNumHead * seqLen * mBlockNum * QUANT_INFO_BYTES); + mQueryZeroPoint = bufferAlloc->alloc(mNumHead * seqLen * mBlockNum * QUANT_INFO_BYTES); + mQueryQuantZero = bufferAlloc->alloc(mNumHead * seqLen * mBlockNum * QUANT_INFO_BYTES); + mQueryQuantScale = bufferAlloc->alloc(mNumHead * seqLen * mBlockNum * QUANT_INFO_BYTES); + mQuantQuery = bufferAlloc->alloc(seqLen * mNumHead * mHeadDim); + + if (mBlockNum > 1) { + mAccumBuffer = bufferAlloc->alloc(eP8 * hP8 * mThreadNum * QUANT_INFO_BYTES); + if (mAccumBuffer.invalid()) { + return OUT_OF_MEMORY; } } - } -} -typedef void(softmaxFunc)(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize); -template -static void softmaxQK(float* softmax_qk_addr, float* unpack_qk_addr, float* runningMax, float* runningSum, float* diffScale, const float* sinkPtr, softmaxFunc* sffunc, int seq_len, int kv_seq_len, int headIdx, bool isLastKvBlock) { + if (mSumQ.invalid() || mQueryScale.invalid() || mQueryQuantZero.invalid() || mQueryZeroPoint.invalid() || mQueryQuantScale.invalid() || mQuantQuery.invalid()) { + return OUT_OF_MEMORY; + } - // not sliding attention - if (sinkPtr == nullptr) { - sffunc(softmax_qk_addr, unpack_qk_addr, runningMax, runningSum, diffScale, seq_len, kv_seq_len); - return; - } + // post parameters for int8 gemm + mGemmRelu.reset(2 * sizeof(int32_t)); + if (!mGemmRelu.get()) { + MNN_ERROR("Allocate mGemmRelu buffer failed in CPU Attention"); + return OUT_OF_MEMORY; + } + ((float*)mGemmRelu.get())[0] = -std::numeric_limits().max(); + ((float*)mGemmRelu.get())[1] = std::numeric_limits().max(); + if (mBytes == 2) { + gcore->MNNFp32ToLowp((float*)mGemmRelu.get(), reinterpret_cast(mGemmRelu.get()), 2); + } - float sink = ((T*)sinkPtr)[headIdx]; - if (!runningMax && !runningSum) { // Do not use flash attention + // GemmInt8 kernels + if (mBytes == 4) { + mInt8GemmKernel 
= core->Int8GemmKernel; + } else { + mInt8GemmKernel = core->MNNGemmInt8AddBiasScale_Unit_FP16; + } - for (int i = 0; i < seq_len; ++i) { - float exprOffset[4] = {1, 0, -sink, 1.f}; - MNNExp(softmax_qk_addr + i * kv_seq_len, unpack_qk_addr + i * kv_seq_len, exprOffset, kv_seq_len); - for (int j = 0; j < kv_seq_len; ++j) { - softmax_qk_addr[i * kv_seq_len + j] /= exprOffset[3]; + if (mQuantValue) { + mQuantQK = bufferAlloc->alloc(mThreadNum * eP8 * ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, mPack)); + mQKScale = bufferAlloc->alloc(eP8 * QUANT_INFO_BYTES); + mQKBias = bufferAlloc->alloc(eP8 * QUANT_INFO_BYTES); + mSumQK = bufferAlloc->alloc(mThreadNum * eP8 * QUANT_INFO_BYTES); + + if (mQuantQK.invalid() || mQKScale.invalid() || mQKBias.invalid() || mSumQK.invalid()) { + return OUT_OF_MEMORY; } } - return; + } else { + mPackQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seqLen, eP), ROUND_UP(mHeadDim, lP), eP * mBytes})); + backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC); + backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC); } - // Use flash attention - if (isLastKvBlock) { - for (int i = 0; i < seq_len; ++i) { - runningSum[i] += expf(sink - runningMax[i]); - } - } - MNNSoftmax(softmax_qk_addr, unpack_qk_addr, runningMax, runningSum, diffScale, seq_len, kv_seq_len); -} + // release tensor + backend()->onReleaseBuffer(mPackQ.get(), Backend::DYNAMIC); + backend()->onReleaseBuffer(mPackQKV.get(), Backend::DYNAMIC); -template -static void unpack_QKV(int8_t* pack_qkv, int8_t* unpack_qkv, int mNumHead, int mHeadDim, int mPack, int seq_len) { - auto src_ptr = reinterpret_cast(pack_qkv); - auto dst_ptr = reinterpret_cast(unpack_qkv); - for (int i = 0; i < seq_len; i++) { - for (int j = 0; j < mHeadDim; j++) { - int a = j / mPack; - int b = j % mPack; - dst_ptr[i * mNumHead * mHeadDim + j] = src_ptr[a * seq_len * mPack + i * mPack + b]; - } + if (inputs.size() > 4 || mUseFlashAttention) { + backend()->onReleaseBuffer(mRunningMax.get(), Backend::DYNAMIC); 
+ backend()->onReleaseBuffer(mRunningSum.get(), Backend::DYNAMIC); } -} - -ErrorCode CPUAttention::onResize(const std::vector& inputs, const std::vector& outputs) { - auto core = static_cast(backend())->functions(); - core->MNNGetMatMulPackMode(&eP, &lP, &hP); - mThreadNum = ((CPUBackend *)backend())->threadNumber(); - mPack = core->pack; - bytes = core->bytes; - int qkvQuantOptions = static_cast(backend())->getRuntime()->hint().qkvQuantOption; - mUseGemmInt8 = (qkvQuantOptions % 8 == 4); - if (mUseGemmInt8) { - static_cast(backend())->int8Functions()->MNNGetGemmUnit(&hP8, &lP8, &eP8); + if (mUseFlashAttention) { + backend()->onReleaseBuffer(mExpfDiffMax.get(), Backend::DYNAMIC); + backend()->onReleaseBuffer(mTempOut.get(), Backend::DYNAMIC); } - auto query = inputs[0]; - auto key = inputs[1]; - int seq_len = query->length(1); - mNumHead = query->length(2); - mHeadDim = query->length(3); - mKvNumHead = key->length(2); - mKVCacheManager->onResize(mKvNumHead, mHeadDim); - if (mUseGemmInt8) { - mPackQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP8), UP_DIV(mHeadDim, lP8), eP8 * lP8})); - mSumQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP8), eP8})); - mPackQKV.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack})); - backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC); - backend()->onAcquireBuffer(mSumQ.get(), Backend::DYNAMIC); - backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC); - backend()->onReleaseBuffer(mPackQ.get(), Backend::DYNAMIC); - backend()->onReleaseBuffer(mSumQ.get(), Backend::DYNAMIC); - backend()->onReleaseBuffer(mPackQKV.get(), Backend::DYNAMIC); - mMinQ.resize(mNumHead); - mMaxQ.resize(mNumHead); - mQueryScale.resize(mNumHead); - mQueryZeroPoint.resize(mNumHead); - } else { - mPackQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mHeadDim, lP), eP * bytes})); - mPackQKV.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack * 
bytes})); - backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC); - backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC); - // flash attention - if (qkvQuantOptions / 8 == 1) { - mRunningMax.reset(Tensor::createDevice({mThreadNum, seq_len * 4})); - mRunningSum.reset(Tensor::createDevice({mThreadNum, seq_len * 4})); - mExpfDiffMax.reset(Tensor::createDevice({mThreadNum, seq_len * 4})); - mTempOut.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack * bytes})); - - backend()->onAcquireBuffer(mRunningMax.get(), Backend::DYNAMIC); - backend()->onAcquireBuffer(mRunningSum.get(), Backend::DYNAMIC); - backend()->onAcquireBuffer(mExpfDiffMax.get(), Backend::DYNAMIC); - backend()->onAcquireBuffer(mTempOut.get(), Backend::DYNAMIC); + // release memchunk + if (mQuantKey) { + bufferAlloc->free(mSumQ); + bufferAlloc->free(mQueryScale); + bufferAlloc->free(mQueryZeroPoint); + bufferAlloc->free(mQueryQuantScale); + bufferAlloc->free(mQueryQuantZero); + bufferAlloc->free(mQuantQuery); + if (mBlockNum > 1) { + bufferAlloc->free(mAccumBuffer); } + if (mQuantValue) { + bufferAlloc->free(mQuantQK); + bufferAlloc->free(mQKScale); + bufferAlloc->free(mQKBias); + bufferAlloc->free(mSumQK); + } + } - backend()->onReleaseBuffer(mPackQ.get(), Backend::DYNAMIC); - backend()->onReleaseBuffer(mPackQKV.get(), Backend::DYNAMIC); - - if (qkvQuantOptions / 8 == 1) { - backend()->onReleaseBuffer(mRunningMax.get(), Backend::DYNAMIC); - backend()->onReleaseBuffer(mRunningSum.get(), Backend::DYNAMIC); - backend()->onReleaseBuffer(mExpfDiffMax.get(), Backend::DYNAMIC); - backend()->onReleaseBuffer(mTempOut.get(), Backend::DYNAMIC); + // Only allocated for quantized Q&K + if (mQuantKey) { + if (mBytes == 4) { + mQuantFunc = core->MNNFloat2Int8; + } else { + mQuantFunc = core->DynamicQuanInput_ARM82; } + } return NO_ERROR; } ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std::vector& outputs) { - auto core = static_cast(backend())->functions(); 
- auto qkvQuantOptions = static_cast(backend())->getRuntime()->hint().qkvQuantOption; + auto gcore = static_cast(backend())->functions(); + auto core = static_cast(backend())->int8Functions(); auto query = inputs[0]; auto key = inputs[1]; auto value = inputs[2]; const Tensor* mask = nullptr; - int seq_len = query->length(1); + int seqLen = query->length(1); if (inputs.size() > 3) { mask = inputs[3]; } @@ -291,16 +230,16 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: MNN_ASSERT(sinks != nullptr); MNN_ASSERT(sinks->elementSize() == mNumHead) } - int tileCount = UP_DIV(mNumHead, mThreadNum); + int numHeadDiv = UP_DIV(mNumHead, mThreadNum); int group_size = mNumHead / mKvNumHead; // reduce the value of 'query' to avoid fp16 overflow float mScale = 1.0 / sqrt(mHeadDim); float q_scale = 1.0; - if (bytes == 2) { + if (mBytes == 2 && !mQuantKey) { // reduce the value of 'query' to 'query * FP16_QSCALE', avoid fp16 overflow FLOAT16_T minValue; FLOAT16_T maxValue; - core->MNNCountMaxMinValue(query->host(), (float*)(&minValue), (float*)(&maxValue), query->elementSize()); + gcore->MNNCountMaxMinValue(query->host(), (float*)(&minValue), (float*)(&maxValue), query->elementSize()); float maxV = maxValue; float minV = minValue; float absMax = ALIMAX(fabsf(maxV), fabsf(minV)); @@ -309,169 +248,436 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: } mScale /= q_scale; } - int insertLen = seq_len; + int insertLen = seqLen; if (mKVCache && mMeta != nullptr) { if (mMeta->previous == mMeta->remove) { mKVCacheManager->onClear(); - mKVCacheManager->onAlloc(mMeta->add); + mKVCacheManager->onAlloc(mMeta, seqLen); } else { MNN_ASSERT(mMeta->previous == mKVCacheManager->kvLength()); mKVCacheManager->onRealloc(mMeta); } - insertLen = mMeta->add; + insertLen = (int)mMeta->add; } else { mKVCacheManager->onClear(); - mKVCacheManager->onAlloc(seq_len); + mKVCacheManager->onAlloc(mMeta, seqLen); } + // Add the new kv to the kvcache - 
mKVCacheManager->onPushBack(key, value, insertLen); - int padSeqLength = seq_len - insertLen; - seq_len = insertLen; - int kv_seq_len = mKVCacheManager->kvLength(); - int max_len = mKVCacheManager->maxLength(); - bool quant_key = mKVCacheManager->config()->mQuantKey; - bool quant_value = mKVCacheManager->config()->mQuantValue; - - mBlockKV = (qkvQuantOptions / 8 == 1) ? ALIMIN(MNN_FLASH_ATTENTION_BLOCK_SIZE, kv_seq_len) : kv_seq_len; + mKVCacheManager->onUpdateKV(key, value, (int)insertLen); + + if (mUseFlashAttention) { + mBlockKV = ALIMIN(MNN_FLASH_ATTENTION_BLOCK_SIZE, mKVCacheManager->kvLength()); + } else { + mBlockKV = mKVCacheManager->kvLength(); + } + + // Constant Initialization + auto padSeqLength = seqLen - insertLen; + seqLen = insertLen; + int kvSeqLen = mKVCacheManager->kvLength(); + int maxLen = mKVCacheManager->maxLength(); int32_t units[2] = {eP, lP}; + const float* sinksPtr = sinks ? sinks->host() : nullptr; + int kvValidOffset = kvSeqLen - seqLen; // reuse_kv=true or decode, kvValidOffset>0 // Temporary tensors for intermediate results - std::shared_ptr unpackQK(Tensor::createDevice({mThreadNum, seq_len, mBlockKV})); - std::shared_ptr softmMaxQ(Tensor::createDevice({mThreadNum, seq_len, mBlockKV})); - std::shared_ptr newPackQK(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mBlockKV, lP), eP * bytes})); - std::shared_ptr dequantV(Tensor::createDevice({mKvNumHead, UP_DIV(mHeadDim, hP), kv_seq_len, hP * bytes})); - // mTempQKBlock.reset(Tensor::createDevice({mThreadNum, UP_DIV(mBlockKV, mPack), seq_len, mPack * bytes})); - std::shared_ptr tempQKBlock(Tensor::createDevice({mThreadNum, UP_DIV(mBlockKV, mPack), seq_len, mPack * bytes})); + std::shared_ptr unpackQK(Tensor::createDevice({mThreadNum, seqLen, mBlockKV})); + std::shared_ptr softmMaxQ(Tensor::createDevice({mThreadNum, seqLen, ROUND_UP(mBlockKV, mPack)})); // [mBlockKV/mPack, seqLen, mPack ] + std::shared_ptr newPackQK; + if (mQuantValue == false) { + 
newPackQK.reset(Tensor::createDevice({mThreadNum, eP * ROUND_UP(mBlockKV, lP) * mBytes})); + } else { + newPackQK.reset(Tensor::createDevice({mThreadNum, eP8 * ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP8)})); + } + std::shared_ptr mTempQKBlock(Tensor::createDevice({mThreadNum, UP_DIV(mBlockKV, mPack), seqLen, mPack * mBytes})); backend()->onAcquireBuffer(unpackQK.get(), Backend::STATIC); backend()->onAcquireBuffer(softmMaxQ.get(), Backend::STATIC); backend()->onAcquireBuffer(newPackQK.get(), Backend::STATIC); - backend()->onAcquireBuffer(tempQKBlock.get(), Backend::STATIC); - if (quant_value) { - backend()->onAcquireBuffer(dequantV.get(), Backend::STATIC); - mKVCacheManager->onDequantValue(dequantV.get()); + backend()->onAcquireBuffer(mTempQKBlock.get(), Backend::STATIC); + + // Quantize Q and initialize bias 0 + if (mQuantKey) { + mGemmBias.reset(ROUND_UP(ALIMAX(mBlockKV, mHeadDim), hP8) * QUANT_INFO_BYTES); + if (!mGemmBias.get()) { + MNN_ERROR("Allocate bias buffer failed in CPU Attention\n"); + return OUT_OF_MEMORY; + } + memset(mGemmBias.get(), 0, ROUND_UP(ALIMAX(mBlockKV, mHeadDim), hP8) * QUANT_INFO_BYTES); + + // Q: [seqLen,numHead,headDim] + // maxQ, minQ: [seqLen,numHead] + // scaleQ, zeroQ: [numHead, seqLen] + // quantQ: [seqLen,numHead,headDim] + auto queryPtr = query->host(); + int divPart = UP_DIV(seqLen * mNumHead, mThreadNum); + MNN_CONCURRENCY_BEGIN (tId, mThreadNum) { + size_t info[9] = {1, (size_t)mHeadDim, 1, 1, 1, 1, 1, 1, 0}; + auto remainLu = seqLen * mNumHead - tId * divPart; + if (remainLu > 0) { + remainLu = ALIMIN(divPart, remainLu); + for (int i = tId * divPart; i < tId * divPart + remainLu; ++i) { + + // address + auto srcFloatPtr = (float*)(queryPtr + i * mHeadDim * mBytes); + auto dstInt8Ptr = (int8_t*)(mQuantQuery.ptr() + i * mHeadDim); + auto quantScalePtr = (float*)(mQueryQuantScale.ptr() + i * QUANT_INFO_BYTES); + auto quantZeroPtr = (float*)(mQueryQuantZero.ptr() + i * QUANT_INFO_BYTES); + + // scaleQ, zeroQ, 
[seqLen,numHead]->[numHead,seqLen] + int indexQ = (i / mNumHead) + (i % mNumHead) * seqLen; + auto scalePtr = (float*)(mQueryScale.ptr() + indexQ * QUANT_INFO_BYTES); + auto zeroPtr = (float*)(mQueryZeroPoint.ptr() + indexQ * QUANT_INFO_BYTES); + + + // compute the quant/dequant scale/bias + gcore->MNNAsyQuantInfo(scalePtr, zeroPtr, quantScalePtr, quantZeroPtr, nullptr, nullptr, srcFloatPtr, info); + scalePtr[0] *= mScale; + zeroPtr[0] *= mScale; + + // quantize the float query to int8_t query + mQuantFunc(srcFloatPtr, dstInt8Ptr, UP_DIV(mHeadDim, gcore->pack), quantScalePtr, -128, 127, quantZeroPtr, 0); + } + } + } MNN_CONCURRENCY_END(); + + // source int8_t query: [seqLen,numHead,headDim] + // dest int8_t query: [numHead,seqLen/eP,headDim/lP,eP,lP] + + int outterSeqLen = UP_DIV(seqLen, eP8); + int outterHeadDim = UP_DIV(mHeadDim, lP8); + size_t outputOffset = 0; + + const int8_t* src_base_ptr = (const int8_t*)mQuantQuery.ptr(); + int8_t* dst_base_ptr = mPackQ->host(); + + for (int h = 0; h < mNumHead; ++h) { + for (int seqBlock = 0; seqBlock < outterSeqLen; ++seqBlock) { + int seqBase = seqBlock * eP8; + int eunit = std::min(eP8, seqLen - seqBase); + size_t currentSeqBlockSize = (size_t)outterHeadDim * eunit * lP8; + + for (int dimBlock = 0; dimBlock < outterHeadDim; ++dimBlock) { + int dimBase = dimBlock * lP8; + int headDimRemain = mHeadDim - dimBase; + int copyLen = std::min(lP8, headDimRemain); + + if (copyLen <= 0) { + continue; + } + + int8_t* dst_block_ptr = dst_base_ptr + + outputOffset + + (size_t)dimBlock * (eunit * lP8); + + const size_t src_row_stride = (size_t)mNumHead * mHeadDim; + + for (int seqLocal = 0; seqLocal < eunit; ++seqLocal) { + int innerSeq = seqBase + seqLocal; + + const int8_t* src_row_ptr = src_base_ptr + + (size_t)innerSeq * src_row_stride + + (size_t)h * mHeadDim + + dimBase; + + int8_t* dst_row_ptr = dst_block_ptr + seqLocal * lP8; + + std::memcpy(dst_row_ptr, src_row_ptr, copyLen); + } + if (copyLen < lP8) { + for (int seqLocal = 
0; seqLocal < eunit; ++seqLocal) { + int8_t* dst_pad_ptr = dst_block_ptr + seqLocal * lP8 + copyLen; + std::memset(dst_pad_ptr, 0, lP8 - copyLen); + } + } + } + outputOffset += currentSeqBlockSize; + } + } // Finish quantize Q + + if (mQuantValue) { + auto scalePtr = (float*)(mQKScale.ptr()); + auto zeroPtr = (float*)(mQKBias.ptr()); + for (int k = 0; k < eP8; ++k) { + scalePtr[k] = 1.f / 255.f; +#ifdef MNN_USE_SSE + zeroPtr[k] =0; +#else + zeroPtr[k] = 128.f / 255.f; +#endif + } + } + } - const float* sinksPtr = sinks ? sinks->host() : nullptr; + std::function mCompute = [=](int tId) { - auto qReordered = mPackQ->host() + tId * mPackQ->stride(0); - auto qkPacked = tempQKBlock->host() + tId * tempQKBlock->stride(0); - int8_t * sum_q = nullptr; + int8_t* qReordered = nullptr; + auto qkPacked = mTempQKBlock->host() + tId * mTempQKBlock->stride(0); auto qkFlatten = unpackQK->host() + tId * unpackQK->stride(0); auto qkSoftmax = softmMaxQ->host() + tId * softmMaxQ->stride(0); auto qkReordered = newPackQK->host() + tId * newPackQK->stride(0); auto qkvPacked = mPackQKV->host() + tId * mPackQKV->stride(0); - auto QxK = quant_key ? core->MNNPackedMatMul_int8 : core->MNNPackedMatMul; - auto QxK_remain = quant_key ? core->MNNPackedMatMulRemain_int8 : core->MNNPackedMatMulRemain; + int headIndex = tId * numHeadDiv; + int headsToCompute = ALIMIN(numHeadDiv, mNumHead - headIndex); // Flash Attention auto runningMax = mRunningMax ? (float*)(mRunningMax->host() + tId * mRunningMax->stride(0)) : nullptr; auto runningSum = mRunningSum ? (float*)(mRunningSum->host() + tId * mRunningSum->stride(0)) : nullptr; auto diffScale = mExpfDiffMax ? (float*)(mExpfDiffMax->host() + tId * mExpfDiffMax->stride(0)) : nullptr; auto outputPacked = mTempOut ? 
mTempOut->host() + tId * mTempOut->stride(0) : qkvPacked; - int head_index = tId * tileCount; - int kvBlocks = UP_DIV(kv_seq_len, mBlockKV); - if (mUseGemmInt8) { - qReordered = mPackQ->host() + tId * UP_DIV(seq_len, eP8) * UP_DIV(mHeadDim, lP8) * eP8 * lP8; - sum_q = mSumQ->host() + tId * UP_DIV(seq_len, eP8) * eP8 * 4; + int kvBlocks = UP_DIV(kvSeqLen, mBlockKV); + + QuanPostTreatParameters gemmParam4QxK, gemmParam4QKxV; // used by int8 gemm, allocated per thread. + SumByAxisParams sumParams4QxK, sumParams4QKxV; + float* qSumAddr = nullptr; + float* qScale = nullptr; + float* qBias = nullptr; + float* accumbuff = nullptr; + int32_t unitColBufferSize = 0; + if (mQuantKey) { + // parameters shared by all mBlockKV + gemmParam4QxK.blockNum = mBlockNum; + gemmParam4QxK.biasFloat = reinterpret_cast(mGemmBias.get()); + gemmParam4QxK.useInt8 = 0; + gemmParam4QxK.fp32minmax = reinterpret_cast(mGemmRelu.get()); + + sumParams4QxK.oneScale = 0; + sumParams4QxK.SRC_UNIT = lP8; + sumParams4QxK.blockNum = mBlockNum; + sumParams4QxK.DST_XUNIT = eP8; + sumParams4QxK.inputBlock = 0; + sumParams4QxK.kernelxy = 1; + // fixed + sumParams4QxK.LU = UP_DIV(mHeadDim, lP8); + sumParams4QxK.unitColBufferSize = ROUND_UP(mHeadDim, lP8) * eP8; + sumParams4QxK.kernelCountUnitDouble = UP_DIV(mHeadDim, lP8); + sumParams4QxK.valid = mHeadDim % lP8; + + + if (mBlockNum > 1) { + accumbuff = (float*)(mAccumBuffer.ptr() + tId * eP8 * hP8 * QUANT_INFO_BYTES); + } + unitColBufferSize = eP8 * ROUND_UP(mHeadDim, lP8); + + if (mQuantValue) { + gemmParam4QKxV.blockNum = mBlockNum; + gemmParam4QKxV.biasFloat = reinterpret_cast(mGemmBias.get()); + gemmParam4QKxV.useInt8 = 0; + gemmParam4QKxV.fp32minmax = reinterpret_cast(mGemmRelu.get()); + gemmParam4QKxV.inputScale = (float*)mQKScale.ptr(); + gemmParam4QKxV.inputBias = (float*)mQKBias.ptr(); + gemmParam4QKxV.srcKernelSum = (float*)(mSumQK.ptr() + tId * eP8 * QUANT_INFO_BYTES); + + sumParams4QKxV.oneScale = 0; + sumParams4QKxV.SRC_UNIT = lP8; + 
sumParams4QKxV.blockNum = mBlockNum; + sumParams4QKxV.DST_XUNIT = eP8; + sumParams4QKxV.inputBlock = 0; + sumParams4QKxV.kernelxy = 1; + sumParams4QKxV.unitColBufferSize = ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP8) * eP8; + sumParams4QKxV.kernelCountUnitDouble = UP_DIV(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP8); + } + } + + size_t vstride0 = ROUND_UP(mHeadDim, hP) * ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP); + if (mQuantValue) { + vstride0 = (ROUND_UP(mHeadDim, hP8) * ROUND_UP(mKVCacheManager->getFlashAttentionBlockKv(), lP8) + 2 * QUANT_INFO_BYTES * mBlockNum * ROUND_UP(mHeadDim, hP8)); } - for (int h = head_index; h < head_index + tileCount && h < mNumHead; h++) { + + // use for V + float const* srcPtr[1]; + // only used for quantized V + float vQuantScale[1] = {255.f}; + float vQuantBias[1] = {-128.f}; + int32_t infoInt8V[5]; + infoInt8V[0] = 1; // number + infoInt8V[2] = static_cast(sumParams4QKxV.unitColBufferSize); + infoInt8V[3] = 1; // stride + int32_t elInt8V[4] = {eP8, ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP8), 0, 0}; + + // only used for float V + int32_t infoFloatV[4]; + infoFloatV[0] = 1; // number + infoFloatV[1] = seqLen; // eReal + infoFloatV[3] = 1; // stride + int32_t elFloatV[4] = {seqLen, ROUND_UP(kvSeqLen, lP), 0, 0}; + + int offset[2] = {seqLen, mNumHead * mHeadDim}; + + for (int h = headIndex; h < headIndex + headsToCompute; h++) { + // Prepare for flash attention if (runningSum && runningMax) { - memset(runningSum, 0, mRunningSum->stride(0)); if (sinksPtr == nullptr) { - for (int k = 0; k < seq_len; ++k) { + memset(runningSum, 0, mRunningSum->stride(0)); + for (int k = 0; k < seqLen; ++k) { runningMax[k] = std::numeric_limits::lowest(); } } else { + for (int k = 0; k < seqLen; ++k) { + runningSum[k] = 1.f; // exp(sink-sink) + } float sinkVal; - if (bytes == 2) { + if (mBytes == 2) { sinkVal = ((FLOAT16_T*)sinksPtr)[h]; } else { - sinkVal =sinksPtr[h]; + sinkVal = sinksPtr[h]; } - for (int k = 0; k < seq_len; ++k) { + for (int k = 0; k < 
seqLen; ++k) { runningMax[k] = sinkVal; } } } - int kv_h = h / group_size; - int8_t * key_addr = mKVCacheManager->addrOfKey(kv_h); - int8_t * scale_addr = mKVCacheManager->addrOfScale(kv_h); - int8_t * zero_point_addr = mKVCacheManager->addrOfZeroPoint(kv_h); - int8_t * key_sum_addr = mKVCacheManager->addrOfKeySum(kv_h); - int8_t * value_addr = quant_value ? (dequantV->host() + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(kv_seq_len, lP) * hP * bytes) : mKVCacheManager->addrOfValue(kv_h); - if (mUseGemmInt8) { - if (bytes == 2) { - pack_query(query, qReordered, sum_q, seq_len, h, q_scale); - } else { - pack_query(query, qReordered, sum_q, seq_len, h, q_scale); - } + + // Compute the current addresses + int kvHeadIndex = h / group_size; + int8_t * keyAddr = mKVCacheManager->addrOfKey(kvHeadIndex); + int8_t * keySum = mKVCacheManager->addrOfKeySum(kvHeadIndex); + int8_t * valueAddr = mKVCacheManager->addrOfValue(kvHeadIndex); + float* valueSum = (float*)mKVCacheManager->addrOfValueSum(kvHeadIndex); + + // Get packed Q + if (mQuantKey == false) { + qReordered = mPackQ->host() + tId * mPackQ->stride(0); + gcore->MNNAttenPackAndScaleSingleHead((float*)qReordered, (float*)(query->host() + h * mHeadDim * mBytes), mHeadDim * mNumHead, &q_scale, units, seqLen, mHeadDim); } else { - core->MNNAttenPackAndScaleSingleHead((float*)qReordered, (float*)(query->host() + h * mHeadDim * bytes), mHeadDim * mNumHead, &q_scale, units, seq_len, mHeadDim); + qReordered = mPackQ->host() + h * mPackQ->stride(0); + qSumAddr = (float*)(mSumQ.ptr() + tId * ROUND_UP(seqLen, eP8) * mBlockNum * QUANT_INFO_BYTES); + qScale = (float*)(mQueryScale.ptr() + h * seqLen * mBlockNum * QUANT_INFO_BYTES); + qBias = (float*)(mQueryZeroPoint.ptr() + h * seqLen * mBlockNum * QUANT_INFO_BYTES); + gcore->MNNSumByAxisLForMatmul_A(qSumAddr, qReordered, qScale, seqLen, sumParams4QxK); } + + // Start computing for (int i = 0; i < kvBlocks; ++i) { - int subKvSeqLen = ALIMIN(mBlockKV, kv_seq_len - i * mBlockKV); - auto 
keyPtr = key_addr + i * UP_DIV(mBlockKV, hP) * ROUND_UP(mHeadDim, lP) * hP * bytes; - auto valuePtr = value_addr + i * UP_DIV(mBlockKV, lP) * hP * lP * bytes; - // query @ key - { - int loop_e = seq_len / eP; - int remain = seq_len % eP; - auto qStride0 = ROUND_UP(mHeadDim, lP) * eP * bytes; - size_t shapeParameters[7] = {(size_t)eP * lP * bytes, ROUND_UP((size_t)mHeadDim, lP), (size_t)subKvSeqLen, (size_t)seq_len * mPack * bytes, 0, 0, 0}; + int subKvSeqLen = ALIMIN(mBlockKV, kvSeqLen - i * mBlockKV); + // 1. query @ key + if (mQuantKey == false) { + auto keyPtr = keyAddr + i * UP_DIV(mBlockKV, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes; + int loop_e = seqLen / eP; + int remain = seqLen % eP; + auto qStride0 = ROUND_UP(mHeadDim, lP) * eP * mBytes; + size_t shapeParameters[7] = {(size_t)eP * lP * mBytes, ROUND_UP((size_t)mHeadDim, lP), (size_t)subKvSeqLen, (size_t)seqLen * mPack * mBytes, 0, 0, 0}; for (int ei = 0 ; ei < loop_e; ei++) { - QxK((float*)(qkPacked + (ei * eP * mPack) * bytes), (float*)(qReordered + ei * qStride0), (float*)keyPtr, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr); + gcore->MNNPackedMatMul((float*)(qkPacked + (ei * eP * mPack) * mBytes), (float*)(qReordered + ei * qStride0), (float*)keyPtr, shapeParameters, nullptr, nullptr, nullptr, nullptr); + } + gcore->MNNPackedMatMulRemain((float*)(qkPacked + (loop_e * eP * mPack) * mBytes), (float*)(qReordered + loop_e * qStride0), (float*)keyPtr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr); + } else { + auto eRemain = seqLen; + auto srcInt8 = qReordered; + auto dstInt8 = qkPacked; + auto keyPtr = keyAddr + i * UP_DIV(mBlockKV, hP8) * (ROUND_UP(mHeadDim, lP8) * hP8 + 2 * hP8 * QUANT_INFO_BYTES); + gemmParam4QxK.weightKernelSum = (float*)(keySum + i * mBlockKV * QUANT_INFO_BYTES); + gemmParam4QxK.inputScale = qScale; + gemmParam4QxK.inputBias = qBias; + gemmParam4QxK.srcKernelSum = qSumAddr; + while (eRemain > 0) { + auto eSize = ALIMIN(eP8, eRemain); + 
mInt8GemmKernel(dstInt8, srcInt8, keyPtr, UP_DIV(mHeadDim, lP8), mBytes * seqLen * mPack, UP_DIV(subKvSeqLen, mPack), &gemmParam4QxK, eSize); + eRemain -= eP8; + gemmParam4QxK.inputScale += eP8; + gemmParam4QxK.inputBias += eP8; + gemmParam4QxK.srcKernelSum += eP8; + srcInt8 += unitColBufferSize; + dstInt8 += eP8 * mPack * mBytes; + if (mBlockNum > 1) { + memset(accumbuff, 0, eP8 * hP8 * QUANT_INFO_BYTES); + gemmParam4QxK.accumBuffer = accumbuff; + } } - QxK_remain((float*)(qkPacked + (loop_e * eP * mPack) * bytes), (float*)(qReordered + loop_e * qStride0), (float*)keyPtr, remain, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr); } - // qk: [kv_seq_len/mPack, seq_len, mPack] -> [seq_len/eP, kv_seq_len, eP] + // 2. softmax scores + // qk: [kv_seq_len/mPack, seq_len, mPack] -> [seq_len/eP, kv_seq_len/lP, eP, lP] { - if(bytes == 2) { - if (seq_len == 1) { - core->MNNLowpToFp32((int16_t*)qkPacked, qkFlatten, seq_len * subKvSeqLen); - } else { - core->MNNAttenUnpackAndConvertFp16(qkFlatten, (float*)qkPacked, subKvSeqLen, seq_len, mPack); + if(mBytes == 2) { + if (!mQuantKey || sinksPtr != nullptr) { + _maskQK((float*)qkPacked, &mScale, seqLen, subKvSeqLen, mPack, kvSeqLen, i * mBlockKV,sinksPtr, mask->host(), mQuantKey); } - mask_QK(qkFlatten, seq_len, kv_seq_len, mScale, std::numeric_limits::lowest(), mask, kv_seq_len - seq_len, i * mBlockKV, subKvSeqLen, padSeqLength); - softmaxQK(qkSoftmax, qkFlatten, runningMax, runningSum, diffScale, sinksPtr, core->MNNSoftmax, seq_len, subKvSeqLen, h, i == kvBlocks - 1); - core->MNNAttenPackAndConvertFp32((float*)qkReordered, qkSoftmax, units, seq_len, subKvSeqLen); } else { - if (seq_len > 1) { - int32_t areaOffset[2] = {seq_len, seq_len}; - core->MNNUnpackCUnitTranspose(qkFlatten, (float*)qkPacked, seq_len, subKvSeqLen, areaOffset); - } else { - memcpy(qkFlatten, qkPacked, subKvSeqLen * sizeof(float)); + if (!mQuantKey || sinksPtr != nullptr) { + _maskQK((float*)qkPacked, &mScale, seqLen, 
subKvSeqLen, mPack, kvSeqLen, i * mBlockKV, sinksPtr, mask->host(), mQuantKey); } - mask_QK(qkFlatten, seq_len, kv_seq_len, mScale, std::numeric_limits::lowest(), mask, kv_seq_len - seq_len, i * mBlockKV, subKvSeqLen, padSeqLength); - softmaxQK(qkSoftmax, qkFlatten, runningMax, runningSum, diffScale, sinksPtr, core->MNNSoftmax, seq_len, subKvSeqLen, h, i == kvBlocks - 1); - MNNPackForMatMul_A((float*)qkReordered, qkSoftmax, seq_len, subKvSeqLen, eP, lP, bytes); } + bool useMask = (sinksPtr == nullptr); + gcore->MNNSoftmax(qkSoftmax, (float*)qkPacked, runningMax, runningSum, diffScale, seqLen, subKvSeqLen, i * mBlockKV, kvValidOffset, mPack, useMask); } - // qk @ v - // TODO: update qkvPacked using diffScale - size_t shapeParameters[7] = {(size_t)eP * lP * bytes, ROUND_UP((size_t)subKvSeqLen, lP), (size_t)mHeadDim, (size_t)seq_len * mPack * bytes, 0, 0, 0}; - size_t bExtraStride = (UP_DIV(max_len, lP) - UP_DIV(subKvSeqLen + i * mBlockKV, lP) + UP_DIV(i * mBlockKV, lP)) * hP * lP * bytes; - shapeParameters[5] = quant_value ? 0 : bExtraStride; - int loop_e = seq_len / eP; - int remain = seq_len % eP; - auto qkStride0 = ROUND_UP(subKvSeqLen, lP) * eP * bytes; - for (int ei = 0 ; ei < loop_e; ei++) { - core->MNNPackedMatMul((float*)(qkvPacked + (ei * eP * mPack) * bytes), (float*)(qkReordered + ei * qkStride0), (float*)valuePtr, shapeParameters, nullptr, nullptr, nullptr, nullptr); + // 3. qk @ v + auto qkStride0 = ROUND_UP(subKvSeqLen, lP) * eP * mBytes; + auto rowStart = (i * mBlockKV < kvValidOffset)? 0 : (i * mBlockKV - kvValidOffset); + + if (mQuantValue == false) { + auto valuePtr = valueAddr + i * vstride0 * mBytes; + size_t shapeParameters[7] = {(size_t)eP * lP * mBytes, ROUND_UP((size_t)subKvSeqLen, lP), (size_t)mHeadDim, (size_t)seqLen * mPack * mBytes, 0, 0, 0}; + size_t bExtraStride = (i < kvBlocks - 1) ? 
0 : (ROUND_UP(mKVCacheManager->getFlashAttentionBlockKv(), lP) - ROUND_UP(subKvSeqLen, lP)) * hP * mBytes; + shapeParameters[5] = bExtraStride; + + int loop_e = (seqLen - rowStart) / eP; + int remain = (seqLen - rowStart) % eP; + + int ei = 0; + elFloatV[0] = eP; + elFloatV[1] = ROUND_UP(subKvSeqLen, lP); + infoFloatV[2] = eP; + for ( ; ei < loop_e; ei++) { + srcPtr[0] = (float const*)((int8_t*)qkSoftmax + (ei * eP + rowStart) * mPack * mBytes); + gcore->MNNPackC4ForMatMul_A((float*)qkReordered, srcPtr, infoFloatV, elFloatV); + gcore->MNNPackedMatMul((float*)(qkvPacked + (ei * eP + rowStart) * mPack * mBytes), (float*)qkReordered, (float*)valuePtr, shapeParameters, nullptr, nullptr, nullptr, nullptr); + } + if (remain > 0) { + elFloatV[0] = remain; + infoFloatV[2] = remain; + srcPtr[0] = (float const*)((int8_t*)qkSoftmax + (loop_e * eP + rowStart) * mPack * mBytes); + shapeParameters[0] = remain * lP * mBytes; + gcore->MNNPackC4ForMatMul_A((float*)qkReordered, srcPtr, infoFloatV, elFloatV); + gcore->MNNPackedMatMulRemain((float*)(qkvPacked + (loop_e * eP + rowStart) * mPack * mBytes), (float*)qkReordered, (float*)valuePtr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr); + } + } else { // use int8 kernel to compute qk@ v + auto valuePtr = valueAddr + i * vstride0; + auto eRemain = seqLen - rowStart; + auto qkPtr = (int8_t*)(qkSoftmax) + rowStart * mPack * mBytes; // [UP_DIV(subKvSeqLen,pack),seqLen,pack] + auto qkvFloat = qkvPacked + rowStart * mPack * mBytes; + gemmParam4QKxV.weightKernelSum = valueSum + i * ROUND_UP(mHeadDim, hP8); + sumParams4QKxV.valid = subKvSeqLen % lP8; + sumParams4QKxV.LU = UP_DIV(subKvSeqLen, lP8); + + auto dstInt8Ptr = (int8_t*)mQuantQK.ptr() + tId * eP8 * ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, mPack); + srcPtr[0] = (const float*)(dstInt8Ptr); + + while (eRemain > 0) { + auto eSize = ALIMIN(eRemain, eP8); + + memset(dstInt8Ptr, 0, eP8 * ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, mPack)); + + infoInt8V[1] = eSize; // eReal + 
infoInt8V[4] = eSize; // e to process + elInt8V[0] = eSize; // e to process + + + for (int qi = 0; qi < UP_DIV(subKvSeqLen, mPack); ++qi) { + mQuantFunc((float*)(qkPtr + qi * seqLen * mPack * mBytes), dstInt8Ptr + qi * eSize * mPack, eSize, vQuantScale, -128, 127, vQuantBias, 0); + } + core->MNNPackC4Int8ForMatMul_A(qkReordered, (int8_t const **)srcPtr, infoInt8V, elInt8V); + // mSumQK + gcore->MNNSumByAxisLForMatmul_A(gemmParam4QKxV.srcKernelSum, qkReordered, (float*)mQKScale.ptr(), eSize, sumParams4QKxV); + mInt8GemmKernel(qkvFloat, qkReordered, valuePtr, UP_DIV(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP8), mBytes * seqLen * mPack, UP_DIV(mHeadDim, mPack), &gemmParam4QKxV, eSize); + + eRemain -= eSize; + qkPtr += (eSize * mPack * mBytes); + qkvFloat += (eSize * mPack * mBytes); + } } - core->MNNPackedMatMulRemain((float*)(qkvPacked + (loop_e * eP * mPack) * bytes), (float*)(qkReordered + loop_e * qkStride0), (float*)valuePtr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr); + // 4. 
flash attention, update each sub kvSeq's final results if (runningMax != nullptr && runningSum != nullptr && diffScale != nullptr) { - core->MNNFlashAttentionUpdateBlockOutput((float*)outputPacked, (float*)qkvPacked, diffScale, runningSum, UP_DIV(mHeadDim, mPack), seq_len, mPack, i, kvBlocks, mPackQKV->stride(0) / bytes, bytes); + gcore->MNNFlashAttentionUpdateBlockOutput((float*)outputPacked, (float*)qkvPacked, diffScale, runningSum, UP_DIV(mHeadDim, mPack), seqLen, mPack, i, kvBlocks, mPackQKV->stride(0) / mBytes, mBytes, rowStart); } } - // unpack: [head_dim/mPack, seq_len, mPack] -> [seq_len, num_head, head_dim] - auto dst_ptr = outputs[0]->host() + h * mHeadDim * bytes; - if (bytes == 2) { - unpack_QKV((int8_t*)outputPacked, dst_ptr, mNumHead, mHeadDim, mPack, seq_len); - } else { - unpack_QKV((int8_t*)outputPacked, dst_ptr, mNumHead, mHeadDim, mPack, seq_len); - } + // Final results writing: [head_dim/mPack, seq_len, mPack] -> [seq_len, num_head, head_dim] + auto dstPtr = outputs[0]->host() + h * mHeadDim * mBytes; + // offset = {seqLen, mNumHead * mHeadDim}; + gcore->MNNUnpackCUnitTranspose((float*)dstPtr, (float*)outputPacked, seqLen, mHeadDim, offset); } }; @@ -483,16 +689,14 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: backend()->onReleaseBuffer(unpackQK.get(), Backend::STATIC); backend()->onReleaseBuffer(softmMaxQ.get(), Backend::STATIC); backend()->onReleaseBuffer(newPackQK.get(), Backend::STATIC); - backend()->onReleaseBuffer(tempQKBlock.get(), Backend::STATIC); - if (quant_value){ - backend()->onReleaseBuffer(dequantV.get(), Backend::STATIC); - } + backend()->onReleaseBuffer(mTempQKBlock.get(), Backend::STATIC); + if (!mKVCache) { mKVCacheManager->onClear(); } auto ptr = outputs[0]->host(); - if (seq_len < outputs[0]->length(1)) { - ::memset(outputs[0]->host() + seq_len * mHeadDim * mNumHead * bytes, 0, (outputs[0]->length(1)-seq_len) * mHeadDim * mNumHead * bytes); + if (seqLen < outputs[0]->length(1)) { + 
::memset(outputs[0]->host() + seqLen * mHeadDim * mNumHead * mBytes, 0, (outputs[0]->length(1)-seqLen) * mHeadDim * mNumHead * mBytes); } return NO_ERROR; } @@ -512,24 +716,20 @@ CPUAttention::CPUAttention(Backend *backend, bool kv_cache) : Execution(backend) mPackQ.reset(Tensor::createDevice({1, 1, 1, 1})); mPackQKV.reset(Tensor::createDevice({1, 1, 1, 1})); MNN::KVCacheManager::KVCacheConfig kvconfig; - int qkvQuantOptions = static_cast(backend)->getRuntime()->hint().qkvQuantOption; - kvconfig.mUseInt8Kernel = (qkvQuantOptions % 8 == 4); // qkvQuantOption % 8: // 0: Do not quantize - // 1: Only quantize key, use int8 asymmetric quantization - // 2: Only quantize value, use fp8 quantization - // 3: quantize both key and value - // 4: quantize query, key and value, and use gemm int8 kernel to compute K*V + // 1: Q,K: Int8, V: Float32 + // 2: Q,K,V: Int8 // qkvQuantOption / 8: + // 0: do not use flash attention // 1: use flash attention - kvconfig.mQuantKey = (qkvQuantOptions % 8 == 4) || (qkvQuantOptions % 8 == 1) || (qkvQuantOptions % 8 == 3); - kvconfig.mQuantValue = (qkvQuantOptions % 8 == 4) || (qkvQuantOptions % 8 == 2); kvconfig.mKVCacheDir = static_cast(backend)->getRuntime()->hint().kvcacheDirPath; - kvconfig.mKVCacheSizeLimit = static_cast(backend)->getRuntime()->hint().kvcacheSizeLimit; + kvconfig.mPrefixCacheDir = static_cast(backend)->getRuntime()->hint().prefixcacheDirPath; kvconfig.mExpandChunk = 64; - mKVCacheManager.reset(new KVCacheManager(backend, kvconfig)); + kvconfig.mBlockNum = 1; + mKVCacheManager.reset(new CPUKVCacheManager(backend, kvconfig)); } CPUAttention::~CPUAttention() { diff --git a/source/backend/cpu/CPUAttention.hpp b/source/backend/cpu/CPUAttention.hpp index 8739fb711c..22f98a2925 100644 --- a/source/backend/cpu/CPUAttention.hpp +++ b/source/backend/cpu/CPUAttention.hpp @@ -14,8 +14,8 @@ #include #include "core/Execution.hpp" #include "core/OpCommonUtils.hpp" +#include "CPUKVCacheManager.hpp" #include "MNN/ErrorCode.hpp" -#include 
"KVCacheManager.hpp" namespace MNN { @@ -28,19 +28,32 @@ class CPUAttention : public Execution { virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; private: bool mKVCache = true; - bool mUseGemmInt8 = false; - int bytes = 4; + int mBytes = 4; int mThreadNum = 1; int mBlockKV = 512; int eP, lP, hP, mPack; // float matmul packing int eP8, lP8, hP8; // GemmInt8 packing int mNumHead, mKvNumHead, mHeadDim; - std::shared_ptr mPackQ, mPackQKV, mSumQ, mRunningMax, mRunningSum, mTempQKBlock, mTempOut, mExpfDiffMax; - std::shared_ptr mKVCacheManager = nullptr; - std::vector mMinQ, mMaxQ, mQueryScale, mQueryZeroPoint; - template void pack_query(Tensor* query, int8_t* pack_q, int8_t* sum_q, int seq_len, int h, float q_scale); - template void unpack_QK(float * unpack_qk_dst, int8_t * pack_qk_src, int seq_len, int kv_seq_len); KVMeta* mMeta; + + // common + std::shared_ptr mPackQ, mPackQKV, mRunningMax, mRunningSum, mTempQKBlock, mTempOut, mExpfDiffMax; + std::shared_ptr mKVCacheManager = nullptr; + bool mUseFlashAttention = true; + + // quant Query/Key/Value + bool mQuantKey = false; + bool mQuantValue = false; + int mBlockNum = 1; + MemChunk mSumQ; + MemChunk mQueryScale, mQueryZeroPoint, mQueryQuantScale, mQueryQuantZero; + MemChunk mQuantQuery, mAccumBuffer; + + MemChunk mQuantQK, mQKScale, mQKBias, mSumQK, mArray; + AutoStorage mGemmBias, mGemmRelu; + + std::function mQuantFunc; + decltype(CoreInt8Functions::Int8GemmKernel) mInt8GemmKernel; }; } // namespace MNN diff --git a/source/backend/cpu/CPUBackend.cpp b/source/backend/cpu/CPUBackend.cpp index ceb23910fd..0e0bc1f136 100644 --- a/source/backend/cpu/CPUBackend.cpp +++ b/source/backend/cpu/CPUBackend.cpp @@ -326,11 +326,7 @@ Backend* CPURuntime::onCreate(const BackendConfig* config, Backend* origin) cons auto core = MNNGetCoreFunctions(); if (core->supportFp16arith && precision == BackendConfig::Precision_Low) { res = new Arm82Backend(this, memory); - if (hint().useArmSme2Cores && 
core->supportSME2 && res->functions()->sme2Int8MatmulRelatedFuncionsHp32.Int8GemmKernel) { - res->mRelatedFunctions = &(res->functions()->sme2Int8MatmulRelatedFuncionsHp32); - } else { - res->mRelatedFunctions = &(res->functions()->int8MatmulRelatedFunctions); - } + res->mRelatedFunctions = &(res->functions()->int8MatmulRelatedFunctions); break; } #endif @@ -458,12 +454,8 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p mRuntime = const_cast(runtime); auto core = MNNGetCoreFunctions(); mThreadNumber = mRuntime->mThreadNumber; - if (mRuntime->hint().useArmSme2Cores && core->supportSME2 && core->sme2Int8MatmulRelatedFuncionsHp32.Int8GemmKernel) { - mThreadNumber = ALIMIN(2, mThreadNumber); - mRelatedFunctions = &core->sme2Int8MatmulRelatedFuncionsHp32; - } else { - mRelatedFunctions = &core->int8MatmulRelatedFunctions; - } + mRelatedFunctions = &core->int8MatmulRelatedFunctions; + // Compute Group Rate do { if (mThreadNumber <= 1 || mRuntime->mPower == BackendConfig::Power_Low) { diff --git a/source/backend/cpu/CPUConvolution.cpp b/source/backend/cpu/CPUConvolution.cpp index 12c9ceff06..5ecdfe6100 100644 --- a/source/backend/cpu/CPUConvolution.cpp +++ b/source/backend/cpu/CPUConvolution.cpp @@ -190,8 +190,10 @@ void CPUConvolution::MutableResourceInt8::updateInputOutputScale(std::vector(biasData[i] / (mInputScale * weightScale)) - mResource->mInt8WeightKernelSum[i] * (mInputZeroPoint + offset) + outputZeroPointFused; } } else { + auto outputScale = mResource->mWeightBits == 4 ? 1.f : mOutputScale; + int32_t outputZero = mResource->mWeightBits == 4 ? 
0 : mOutputZeroPoint; for (int i = 0; i < ocUp4; ++i) { - biasfloat[i] = (biasData[i] - mResource->mWeightKernelSum->host()[i] * (mInputZeroPoint + offset) * mInputScale) / mOutputScale + mOutputZeroPoint; + biasfloat[i] = (biasData[i] - mResource->mWeightKernelSum->host()[i] * (mInputZeroPoint + offset) * mInputScale) / outputScale + outputZero; } } @@ -224,9 +226,9 @@ std::shared_ptr CPUConvolution::makeResourceInt8(B auto scalePtr = resource->mOriginScale->host(); memset(scalePtr, 0, ocUpUnit * 2 * sizeof(float)); - resource->mActBits = 8; + resource->mWeightBits = 8; if (convParam->symmetricQuan()) { - resource->mActBits = convParam->symmetricQuan()->nbits(); + resource->mWeightBits = convParam->symmetricQuan()->nbits(); } const int8_t* weightSrc = nullptr; int weightSize = 0; diff --git a/source/backend/cpu/CPUConvolution.hpp b/source/backend/cpu/CPUConvolution.hpp index 26ca877602..13abecef0f 100644 --- a/source/backend/cpu/CPUConvolution.hpp +++ b/source/backend/cpu/CPUConvolution.hpp @@ -66,7 +66,7 @@ class CPUConvolution : public Execution { std::vector mReluThreshold; // relu or relu6 bool mRelu; - int mActBits; // quant bits + int mWeightBits; // quant bits bool mUseConvQuan = true; bool mWeightAsymmetricQuant = true; diff --git a/source/backend/cpu/CPUConvolutionDepthwise.cpp b/source/backend/cpu/CPUConvolutionDepthwise.cpp index 168b4193be..9472fbf9ca 100644 --- a/source/backend/cpu/CPUConvolutionDepthwise.cpp +++ b/source/backend/cpu/CPUConvolutionDepthwise.cpp @@ -15,6 +15,10 @@ #include "backend/cpu/compute/CommonOptFunction.h" #include "backend/cpu/compute/ConvOpt.h" +#ifdef MNN_KLEIDIAI_ENABLED +#include "backend/cpu/KleidiAIConvolutionDepthwise.hpp" +#endif //MNN_KLEIDIAI_ENABLED + namespace MNN { CPUConvolutionDepthwise::FloatExecution::FloatExecution(const Convolution2DCommon* common, Backend* b, const float* originWeight, size_t originWeightSize, @@ -276,6 +280,23 @@ class CPUConvolutionDepthwiseCreator : public CPUBackend::Creator { if 
(inputs.empty()) { return new CPUConvolutionDepthwise::FloatExecution(conv2d->common(), backend, originWeight, originWeightSize, originBias, originBiasSize); } +#ifdef MNN_KLEIDIAI_ENABLED + auto bytes = static_cast(backend)->functions()->bytes; + int kernel_height = conv2d->common()->kernelY(); + int kernel_width = conv2d->common()->kernelX(); + int strideY = conv2d->common()->strideY(); + int strideX = conv2d->common()->strideX(); + int dilateX = conv2d->common()->dilateX(); + int dilateY = conv2d->common()->dilateY(); + bool useKleidiAI = kernel_height ==3 && kernel_width ==3 && + strideY ==1 && strideX ==1 && + dilateX ==1 && dilateY ==1 && + bytes == 4; + useKleidiAI = backend->getRuntime()->hint().enableKleidiAI && useKleidiAI; + if(useKleidiAI) { + return new KleidiAIConvolutionDepthwise::KleidiAIDepthwiseExecution(conv2d->common(), backend, originWeight, originWeightSize, originBias, originBiasSize); + } +#endif //MNN_KLEIDIAI_ENABLED return new CPUConvolutionDepthwise::FloatExecution(conv2d->common(), backend, originWeight, originWeightSize, originBias, originBiasSize); } }; diff --git a/source/backend/cpu/CPUKVCacheManager.cpp b/source/backend/cpu/CPUKVCacheManager.cpp new file mode 100644 index 0000000000..7356e34dde --- /dev/null +++ b/source/backend/cpu/CPUKVCacheManager.cpp @@ -0,0 +1,795 @@ +// +// CPUKVCacheManager.cpp +// MNN +// +// Created by MNN on 2024/08/05. 
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE + +#include "CPUKVCacheManager.hpp" +#include "core/Concurrency.h" + +namespace MNN { + +/* +** @brief Expand the size of kvcache and copy it from the old tensor in memory to the new tensor in memory +** Finally reset the pointer to the new tensor +*/ +void CPUKVCacheManager::expandKVCacheInMem(int oldMaxLength) { + /*=================================== Key ===================================*/ + auto new_key = Tensor::createDevice({mKvNumHead, (int)mCurrentKeySizePerHead}); + mBackend->onAcquireBuffer(new_key, Backend::STATIC); + if (mQuantKey) { + memset(new_key->host(), 0, mKvNumHead * mCurrentKeySizePerHead); + } + for (int h = 0; h < mKvNumHead; h++) { + memcpy( + new_key->host() + h * mCurrentKeySizePerHead, + mPastKey->host() + h * mPastKey->stride(0), + mPastKey->stride(0) + ); + if (!mQuantKey && (new_key->stride(0) - mPastKey->stride(0)) > 0) { + memset(new_key->host() + h * new_key->stride(0) + mPastKey->stride(0), 0, (new_key->stride(0) - mPastKey->stride(0))); + } + } + mPastKey.reset(new_key); + /*=================================== Value ===================================*/ + auto newValue = Tensor::createDevice({mKvNumHead, (int)mCurrentValueSizePerHead}); + mBackend->onAcquireBuffer(newValue, Backend::STATIC); + + if (mUseFlashAttention) { // [mKvNumHead, UP_DIV(mMaxLength, mFlashAttentionUpperKv), UP_DIV(mHeadDim, hP), UP_DIV(mFlashAttentionUpperKv, lP), hP, lP] + for (int h = 0; h < mKvNumHead; h++) { + memset(newValue->host() + h * newValue->stride(0), 0, newValue->stride(0)); + memcpy( + newValue->host() + h * newValue->stride(0), + mPastValue->host() + h * mPastValue->stride(0), + mPastValue->stride(0) + ); + } + } else { + if (mQuantValue) { // [mKvNumHead, UP_DIV(mHeadDim, hP8), (UP_DIV(mMaxLength, lP8)*hP8*lP8+2*hP8*sizeof(float)) ] + auto currentWeightInside = ROUND_UP(mMaxLength, lP8) * hP8; + auto currentStride1 = currentWeightInside + 2 * 
mConfig.mBlockNum * hP8 * QUANT_INFO_BYTES; + auto currentStride0 = currentStride1 * UP_DIV(mHeadDim, hP8); + + auto prevWeightInside = ROUND_UP(oldMaxLength, lP8) * hP8; + auto prevStride1 = prevWeightInside + 2 * mConfig.mBlockNum * hP8 * QUANT_INFO_BYTES; + auto prevStride0 = prevStride1 * UP_DIV(mHeadDim, hP8); + for (int h = 0; h < mKvNumHead; ++h) { + for (int d = 0; d < UP_DIV(mHeadDim, hP8); ++d) { + auto dstPtr = newValue->host() + h * currentStride0 + d * currentStride1; + auto srcPtr = mPastValue->host() + h * prevStride0 + d * prevStride1; + + // initialize 0 for weightInt8 + memset(dstPtr, 0, currentWeightInside); + // copy inner side weightInt8 + memcpy(dstPtr, srcPtr, prevWeightInside); + // copy hP8 scale&bias + memcpy(dstPtr + currentWeightInside, srcPtr + prevWeightInside, 2 * mConfig.mBlockNum * hP8 * QUANT_INFO_BYTES); + } + } + } else { // [mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP] + auto currentStride1 = ROUND_UP(mMaxLength, lP) * hP * mBytes; + auto currentStride0 = ROUND_UP(mMaxLength, lP) * hP * UP_DIV(mHeadDim, hP) * mBytes; + + auto prevStride1 = ROUND_UP(oldMaxLength, lP) * hP * mBytes; + auto prevStride0 = ROUND_UP(oldMaxLength, lP) * hP * UP_DIV(mHeadDim, hP) * mBytes; + for (int h = 0; h < mKvNumHead; ++h) { + for (int d = 0; d < UP_DIV(mHeadDim, hP); ++d) { + auto dstPtr = newValue->host() + h * currentStride0 + d * currentStride1; + auto srcPtr = mPastValue->host() + h * prevStride0 + d * prevStride1; + + // initialize 0 for weight + if (lP > 1) { + memset(dstPtr, 0, currentStride1); + } + // copy inner side weight + memcpy(dstPtr, srcPtr, prevStride1); + } + } + } + } + mPastValue.reset(newValue); +} + +/* +** @brief Move the kvcache from memory to the memory-mapped kvcache files in disk +** Then release the memory buffer of old kvcache +*/ +void CPUKVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) { + /*=================================== Key ===================================*/ + size_t 
prevKeySizePerHead = 0; + if (mQuantKey) { + prevKeySizePerHead = ROUND_UP(oldMaxLength, hP8) * ROUND_UP(mHeadDim, lP8) + 2 * QUANT_INFO_BYTES * mConfig.mBlockNum * ROUND_UP(oldMaxLength, hP8); + } else { + prevKeySizePerHead = UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes; + } + if (mHeadDim % lP || mQuantKey) { + memset(mMapKeyAddr, 0, mKvNumHead * mCurrentKeySizePerHead); + } + for (int h = 0; h < mKvNumHead; h++) { + memcpy( + mMapKeyAddr + h * mCurrentKeySizePerHead, + mPastKey->host() + h * prevKeySizePerHead, + prevKeySizePerHead + ); + } + mBackend->onReleaseBuffer(mPastKey.get(), Backend::STATIC); + mPastKey.reset(); + /*=================================== Value ===================================*/ + { + size_t prevValueSizePerHead = 0; + if (mQuantValue) { + prevValueSizePerHead = UP_DIV(oldMaxLength, mFlashAttentionUpperKv) * (ROUND_UP(mHeadDim, hP8) * ROUND_UP(mFlashAttentionUpperKv, lP8) + 2 * QUANT_INFO_BYTES * mConfig.mBlockNum * ROUND_UP(mHeadDim, hP8)); + } else { + prevValueSizePerHead = UP_DIV(oldMaxLength, mFlashAttentionUpperKv) * (ROUND_UP(mHeadDim, hP) * ROUND_UP(mFlashAttentionUpperKv, lP) * mBytes); + } + if (lP > 1 || mQuantValue) { + memset(mMapValueAddr, 0, mKvNumHead * mCurrentValueSizePerHead); + } + + if (mUseFlashAttention) { + for (int h = 0; h < mKvNumHead; h++) { + memcpy( + mMapValueAddr + h * mCurrentValueSizePerHead, + mPastValue->host() + h * prevValueSizePerHead, + prevValueSizePerHead + ); + } + } else { + if (mQuantValue) { // [mKvNumHead, UP_DIV(mHeadDim, hP8), (UP_DIV(mMaxLength, lP8)*hP8*lP8+2*hP8*sizeof(float)) ] + auto currentWeightInside = ROUND_UP(mMaxLength, lP8) * hP8; + auto currentStride1 = currentWeightInside + 2 * mConfig.mBlockNum * hP8 * QUANT_INFO_BYTES; + auto currentStride0 = currentStride1 * UP_DIV(mHeadDim, hP8); + + auto prevWeightInside = ROUND_UP(oldMaxLength, lP8) * hP8; + auto prevStride1 = prevWeightInside + 2 * mConfig.mBlockNum * hP8 * QUANT_INFO_BYTES; + auto prevStride0 = 
prevStride1 * UP_DIV(mHeadDim, hP8); + for (int h = 0; h < mKvNumHead; ++h) { + for (int d = 0; d < UP_DIV(mHeadDim, hP8); ++d) { + auto dstPtr = mMapValueAddr + h * currentStride0 + d * currentStride1; + auto srcPtr = mPastValue->host() + h * prevStride0 + d * prevStride1; + + // initialize 0 for weightInt8 + memset(dstPtr, 0, currentWeightInside); + // copy inner side weightInt8 + memcpy(dstPtr, srcPtr, prevWeightInside); + // copy hP8 scale&bias + memcpy(dstPtr + currentWeightInside, srcPtr + prevWeightInside, 2 * mConfig.mBlockNum * hP8 * QUANT_INFO_BYTES); + } + } + } else { // [mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP] + auto currentStride1 = ROUND_UP(mMaxLength, lP) * hP * mBytes; + auto currentStride0 = ROUND_UP(mMaxLength, lP) * hP * UP_DIV(mHeadDim, hP) * mBytes; + + auto prevStride1 = ROUND_UP(oldMaxLength, lP) * hP * mBytes; + auto prevStride0 = ROUND_UP(oldMaxLength, lP) * hP * UP_DIV(mHeadDim, hP) * mBytes; + for (int h = 0; h < mKvNumHead; ++h) { + for (int d = 0; d < UP_DIV(mHeadDim, hP); ++d) { + auto dstPtr = mMapValueAddr + h * currentStride0 + d * currentStride1; + auto srcPtr = mPastValue->host() + h * prevStride0 + d * prevStride1; + + // initialize 0 for weight + if (lP > 1) { + memset(dstPtr, 0, currentStride1); + } + // copy inner side weight + memcpy(dstPtr, srcPtr, prevStride1); + } + } + } + } + mBackend->onReleaseBuffer(mPastValue.get(), Backend::STATIC); + mPastValue.reset(); + } +} + +/* +** @brief Expand the size of kvcache files in disk +*/ +void CPUKVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int oldValueSize, int keySize, int valueSize, file_t specKeyFile, file_t specValueFile) { + // Step 1: Copy the old kvcache from files to temporary buffers in memory + auto prevKeySizePerHead = oldKeySize / mKvNumHead; + auto prevValueSizePerHead = oldValueSize / mKvNumHead; + std::shared_ptr prevKey, prevValue; + prevKey.reset(Tensor::createDevice({mKvNumHead, prevKeySizePerHead})); + 
prevValue.reset(Tensor::createDevice({mKvNumHead, prevValueSizePerHead})); + + mBackend->onAcquireBuffer(prevKey.get(), Backend::STATIC); + mBackend->onAcquireBuffer(prevValue.get(), Backend::STATIC); + if (mHeadDim % lP) { + memset(prevKey->host(), 0, prevKey->length(0) * prevKey->stride(0)); + } + if (lP > 1) { + // can't be mMaxLenth % lP, since mMaxLength may be larger than seq_len for prefilling, we should ensure the (mMaxLength - seq_len)'s buffer is 0. + // computing L is seq_len + memset(prevValue->host(), 0, prevValue->length(0) * prevValue->stride(0)); + } + mmapKVCache(oldKeySize, oldValueSize, specKeyFile, specValueFile); + memcpy(prevKey->host(), mMapKeyAddr, oldKeySize); + memcpy(prevValue->host(), mMapValueAddr, oldValueSize); + // Step 2: Resize the kvcache files and remap them + unmapKVCache(oldKeySize, oldValueSize); + resetKVCacheFileSize(keySize, valueSize); + mmapKVCache(keySize, valueSize); + // Step 3: Move the kvcache from temporary buffers in memory to disk + memset(mMapKeyAddr, 0, keySize); + memset(mMapValueAddr, 0, valueSize); + + for (int h = 0; h < mKvNumHead; h++) { + memcpy(mMapKeyAddr + h * mCurrentKeySizePerHead, prevKey->host() + h * prevKeySizePerHead, prevKeySizePerHead); + } + + if (mUseFlashAttention) { + for (int h = 0; h < mKvNumHead; h++) { + memcpy(mMapValueAddr + h * mCurrentValueSizePerHead, prevValue->host() + h * prevValueSizePerHead, prevValueSizePerHead); + } + } else { + if (mQuantValue) { + auto currentWeightInside = ROUND_UP(mMaxLength, lP8) * hP8; + auto currentStride1 = currentWeightInside + 2 * mConfig.mBlockNum * hP8 * QUANT_INFO_BYTES; + auto currentStride0 = currentStride1 * UP_DIV(mHeadDim, hP8); + + auto prevWeightInside = ROUND_UP(oldMaxLength, lP8) * hP8; + auto prevStride1 = prevWeightInside + 2 * mConfig.mBlockNum * hP8 * QUANT_INFO_BYTES; + auto prevStride0 = prevStride1 * UP_DIV(mHeadDim, hP8); + + for (int h = 0; h < mKvNumHead; ++h) { + for (int d = 0; d < UP_DIV(mHeadDim, hP8); ++d) { + auto 
dstPtr = mMapValueAddr + h * currentStride0 + d * currentStride1; + auto srcPtr = prevValue->host() + h * prevStride0 + d * prevStride1; + + // initialize 0 for weightInt8 + memset(dstPtr, 0, currentWeightInside); + // copy inner side weightInt8 + memcpy(dstPtr, srcPtr, prevWeightInside); + // copy hP8 scale&bias + memcpy(dstPtr + currentWeightInside, srcPtr + prevWeightInside, 2 * mConfig.mBlockNum * hP8 * QUANT_INFO_BYTES); + } + } + } else { + auto currentStride1 = ROUND_UP(mMaxLength, lP) * hP * mBytes; + auto currentStride0 = ROUND_UP(mMaxLength, lP) * hP * UP_DIV(mHeadDim, hP) * mBytes; + + auto prevStride1 = ROUND_UP(oldMaxLength, lP) * hP * mBytes; + auto prevStride0 = ROUND_UP(oldMaxLength, lP) * hP * UP_DIV(mHeadDim, hP) * mBytes; + for (int h = 0; h < mKvNumHead; ++h) { + for (int d = 0; d < UP_DIV(mHeadDim, hP); ++d) { + auto dstPtr = mMapValueAddr + h * currentStride0 + d * currentStride1; + auto srcPtr = prevValue->host() + h * prevStride0 + d * prevStride1; + + // initialize 0 for weight + if (lP > 1) { + memset(dstPtr, 0, currentStride1); + } + // copy inner side weight + memcpy(dstPtr, srcPtr, prevStride1); + } + } + } + } + + // Step 4: Release the temporary buffers + mBackend->onReleaseBuffer(prevKey.get(), Backend::STATIC); + mBackend->onReleaseBuffer(prevValue.get(), Backend::STATIC); +} + +void CPUKVCacheManager::onResize(int kv_num_head, int head_dim) { + mKvNumHead = kv_num_head; + mHeadDim = head_dim; + auto core = static_cast(mBackend)->functions(); + core->MNNGetMatMulPackMode(&eP, &lP, &hP); + mBytes = core->bytes; + mThreadNum = static_cast(mBackend)->threadNumber(); + if (mThreadNum > mKvNumHead) { + mThreadNum = mKvNumHead; + } + + static_cast(mBackend)->int8Functions()->MNNGetGemmUnit(&hP8, &lP8, &eP8); + mQuantKeyFunc = core->MNNQuantAttentionKey; + mQuantValueFunc = core->MNNQuantAttentionValue; + +} + +void CPUKVCacheManager::onAlloc(KVMeta* meta, int seq_len) { + mMeta = meta; + + // load disk prefix kvcache + if(mMeta != nullptr 
&& mMeta->file_name.size() > 0 && mMeta->file_flag == KVMeta::PendingRead) { + // create new files + std::string pathk = MNNFilePathConcat(mConfig.mPrefixCacheDir, mMeta->file_name) + "_" + std::to_string(mMeta->layer_index) + ".k"; + std::string pathv = MNNFilePathConcat(mConfig.mPrefixCacheDir, mMeta->file_name) + "_" + std::to_string(mMeta->layer_index++) + ".v"; + mMeta->layer_index = mMeta->layer_index % mMeta->layer_nums; + auto old_key_fd = MNNOpenFile(pathk.c_str(), MNN_FILE_WRITE); + auto old_value_fd = MNNOpenFile(pathv.c_str(), MNN_FILE_WRITE); + if (old_key_fd == INVALID_FILE) { + MNN_PRINT("Failed to open the file: %s\n", pathk.c_str()); + } + if (old_value_fd == INVALID_FILE) { + MNN_PRINT("Failed to open the file: %s\n", pathv.c_str()); + } + + // get kv cache file info + auto oldKeySize = MNNGetFileSize(old_key_fd); + auto oldValueSize = MNNGetFileSize(old_value_fd); + + size_t oldMaxLength = 0; + if (mQuantKey || mQuantValue) { + MNN_ERROR("[Error]: Currently, kvcache save in disk not support quantized key/value\n"); + } else { + size_t oldKeyMaxLength = oldKeySize / (mKvNumHead * ROUND_UP(mHeadDim, lP) * mBytes); + size_t oldValueMaxLength = oldValueSize / (mKvNumHead * ROUND_UP(mHeadDim, hP) * mBytes); + oldMaxLength = ALIMIN(oldKeyMaxLength, oldValueMaxLength); + } + if(oldMaxLength < meta->seqlen_in_disk) { + MNN_ERROR("[Error]: Kvcache in disk size smaller than saved lengthInDiskToload:%d\n", (int)meta->seqlen_in_disk); + } + + if (mUseFlashAttention) { + setFlashAttentionUpperKv(MNN_FLASH_ATTENTION_BLOCK_SIZE); + } else { + setFlashAttentionUpperKv(mMaxLength); + } + int kv_seq_len = meta->add + meta->seqlen_in_disk; + mMaxLength = kv_seq_len > oldMaxLength ? 
kv_seq_len + mConfig.mExpandChunk : oldMaxLength; + size_t keySize = (size_t)mKvNumHead * ROUND_UP(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes; + size_t valueSize = (size_t)mKvNumHead * UP_DIV(mMaxLength, mFlashAttentionUpperKv) * (ROUND_UP(mHeadDim, hP) * ROUND_UP(mFlashAttentionUpperKv, lP) * mBytes); + + keySize = ALIMAX(keySize, oldKeySize); + valueSize = ALIMAX(valueSize, oldValueSize); + + if (mQuantKey) { + mCurrentKeySizePerHead = ROUND_UP(mMaxLength, hP8) * ROUND_UP(mHeadDim, lP8) + 2 * QUANT_INFO_BYTES * mConfig.mBlockNum * ROUND_UP(mMaxLength, hP8); + } else { + mCurrentKeySizePerHead = ROUND_UP(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes; + } + if (mQuantValue) { + mCurrentValueSizePerHead = UP_DIV(mMaxLength, mFlashAttentionUpperKv) * (ROUND_UP(mHeadDim, hP8) * ROUND_UP(mFlashAttentionUpperKv, lP8) + 2 * QUANT_INFO_BYTES * mConfig.mBlockNum * ROUND_UP(mHeadDim, hP8)); + } else { + mCurrentValueSizePerHead = UP_DIV(mMaxLength, mFlashAttentionUpperKv) * (ROUND_UP(mHeadDim, hP) * ROUND_UP(mFlashAttentionUpperKv, lP) * mBytes); + } + + createKVCacheFile(); + resetKVCacheFileSize(keySize, valueSize); + expandKVCacheInDisk(oldMaxLength, oldKeySize, oldValueSize, keySize, valueSize, old_key_fd, old_value_fd); + mPastLength = meta->seqlen_in_disk; + mKVCacheInDisk = true; + + return; + } + + int kv_seq_len = mMeta != nullptr ? (int)meta->add : seq_len; + mMaxLength = kv_seq_len + mConfig.mExpandChunk; + if (mUseFlashAttention) { + setFlashAttentionUpperKv(MNN_FLASH_ATTENTION_BLOCK_SIZE); + } else { + setFlashAttentionUpperKv(mMaxLength); + } + + // 1. 
compute size + if (mQuantKey) { + mCurrentKeySizePerHead = ROUND_UP(mMaxLength, hP8) * ROUND_UP(mHeadDim, lP8) + 2 * QUANT_INFO_BYTES * mConfig.mBlockNum * ROUND_UP(mMaxLength, hP8); + } else { + mCurrentKeySizePerHead = ROUND_UP(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes; + } + if (mQuantValue) { + mCurrentValueSizePerHead = UP_DIV(mMaxLength, mFlashAttentionUpperKv) * (ROUND_UP(mHeadDim, hP8) * ROUND_UP(mFlashAttentionUpperKv, lP8) + 2 * QUANT_INFO_BYTES * mConfig.mBlockNum * ROUND_UP(mHeadDim, hP8)); + } else { + mCurrentValueSizePerHead = UP_DIV(mMaxLength, mFlashAttentionUpperKv) * (ROUND_UP(mHeadDim, hP) * ROUND_UP(mFlashAttentionUpperKv, lP) * mBytes); + } + size_t keySize = (size_t)mKvNumHead * mCurrentKeySizePerHead; + size_t valueSize = (size_t)mKvNumHead * mCurrentValueSizePerHead; + + // 2. allocate buffer + + // case1: key&value size exceeds the limited size + // case2: multi prompts share a common prefix kv cache info + bool storeKvInDisk = !mConfig.mKVCacheDir.empty(); + bool sharePrefixKv = mMeta != nullptr && mMeta->file_name.size() > 0 && mMeta->file_flag == KVMeta::PendingWrite; + + if (sharePrefixKv) { + mSaveShareKvPrefix = true; + if(!MNNCreateDir(mConfig.mPrefixCacheDir.c_str())) { + MNN_PRINT("Failed to create prefix cache file dir: %s\n", mConfig.mPrefixCacheDir.c_str()); + } + } + if (storeKvInDisk || sharePrefixKv) { // store kv in disk + std::string keyStoredDst = ""; + std::string valueStoredDst = ""; + if(mMeta != nullptr) { + mBasePrefixFileName = MNNFilePathConcat(mConfig.mPrefixCacheDir, mMeta->file_name) + "_" + std::to_string(mMeta->layer_index); + keyStoredDst = sharePrefixKv ? mBasePrefixFileName + ".k" : ""; + valueStoredDst = sharePrefixKv ? 
mBasePrefixFileName + ".v" : ""; + mMeta->layer_index++; + mMeta->layer_index = mMeta->layer_index % mMeta->layer_nums; + } + createKVCacheFile(keyStoredDst, valueStoredDst); + resetKVCacheFileSize(keySize, valueSize); + mmapKVCache(keySize, valueSize); + mKVCacheInDisk = true; + } else { // store kv in memory + mPastKey.reset(Tensor::createDevice({mKvNumHead, (int)mCurrentKeySizePerHead})); + mPastValue.reset(Tensor::createDevice({mKvNumHead, (int)mCurrentValueSizePerHead})); + + mBackend->onAcquireBuffer(mPastKey.get(), Backend::STATIC); + mBackend->onAcquireBuffer(mPastValue.get(), Backend::STATIC); + + // initilize 0 + if ((mHeadDim % lP && !mQuantKey) || mQuantKey) { + memset(mPastKey->host(), 0, mPastKey->length(0) * mPastKey->stride(0)); + } + if (lP > 1 || mQuantValue) { // can't be mMaxLenth % lP, since mMaxLength may be larger than seq_len for prefilling, we should ensure the (mMaxLength - seq_len)'s buffer is 0. + memset(mPastValue->host(), 0, mPastValue->length(0) * mPastValue->stride(0)); + } + } + // scale, zero point and sum of key for quantization + if (mQuantKey) { // quant K + mKeySum.reset(Tensor::createDevice({mKvNumHead, ROUND_UP(mMaxLength, hP8) * QUANT_INFO_BYTES})); + mKeyMax.reset(Tensor::createDevice({mKvNumHead, mHeadDim * QUANT_INFO_BYTES})); + mBackend->onAcquireBuffer(mKeySum.get(), Backend::STATIC); + mBackend->onAcquireBuffer(mKeyMax.get(), Backend::STATIC); + + for (int ks = 0; ks < mKvNumHead * mHeadDim; ++ks) { + mKeyMax->host()[ks] = std::numeric_limits::lowest(); + } + if (mBytes == 2) { + auto core = static_cast(mBackend)->functions(); + core->MNNFp32ToLowp(mKeyMax->host(), (int16_t*)(mKeyMax->host()), mKvNumHead * mHeadDim); + } + } + if (mQuantValue) { + mValueSum.reset(Tensor::createDevice({mKvNumHead, (int)UP_DIV(mMaxLength, mFlashAttentionUpperKv), ROUND_UP(mHeadDim, hP8) * QUANT_INFO_BYTES})); + mBackend->onAcquireBuffer(mValueSum.get(), Backend::STATIC); + memset(mValueSum->host(), 0, mValueSum->stride(0) * 
mValueSum->length(0)); + } +} + +void CPUKVCacheManager::onRealloc(KVMeta* meta) { + auto kv_seq_len = meta->previous + meta->add - meta->remove + meta->computeReverseSize(); + if (kv_seq_len > mMaxLength) { + // Realloc + int oldMaxLength = mMaxLength; + mMaxLength = (int)kv_seq_len + mConfig.mExpandChunk; + if (mUseFlashAttention) { + setFlashAttentionUpperKv(MNN_FLASH_ATTENTION_BLOCK_SIZE); + } else { + setFlashAttentionUpperKv(mMaxLength); + } + size_t oldKeySize = (size_t)mKvNumHead * mCurrentKeySizePerHead; + size_t oldValueSize = (size_t)mKvNumHead * mCurrentValueSizePerHead; + + // update current key size per head + if (mQuantKey) { + mCurrentKeySizePerHead = ROUND_UP(mMaxLength, hP8) * ROUND_UP(mHeadDim, lP8) + 2 * QUANT_INFO_BYTES * mConfig.mBlockNum * ROUND_UP(mMaxLength, hP8); + } else { + mCurrentKeySizePerHead = UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes; + } + // update current value size per head + if (mQuantValue) { + mCurrentValueSizePerHead = UP_DIV(mMaxLength, mFlashAttentionUpperKv) * (ROUND_UP(mHeadDim, hP8) * ROUND_UP(mFlashAttentionUpperKv, lP8) + 2 * QUANT_INFO_BYTES * mConfig.mBlockNum * ROUND_UP(mHeadDim, hP8)); + } else { + mCurrentValueSizePerHead = UP_DIV(mMaxLength, mFlashAttentionUpperKv) * (ROUND_UP(mHeadDim, hP) * ROUND_UP(mFlashAttentionUpperKv, lP) * mBytes); + } + size_t keySize = (size_t)mKvNumHead * mCurrentKeySizePerHead; + size_t valueSize = (size_t)mKvNumHead * mCurrentValueSizePerHead; + + /*==== No limit for kvcache ====*/ + if (mKVCacheInDisk == false) { + expandKVCacheInMem(oldMaxLength); + } else { + expandKVCacheInDisk(oldMaxLength, oldKeySize, oldValueSize, keySize, valueSize); + } + /* No matter where is the kvcache, the scales and zero points are always in memory, since their size is very small */ + if (mQuantKey) { + auto newKeySumTensor = Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8}); + mBackend->onAcquireBuffer(newKeySumTensor, Backend::STATIC); + for (int h = 0; h < 
mKvNumHead; h++) { + memcpy(newKeySumTensor->host() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeySum->host() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); + } + mKeySum.reset(newKeySumTensor); + } + if (mQuantValue) { + auto newValueSumTensor = Tensor::createDevice({mKvNumHead, (int)UP_DIV(mMaxLength, mFlashAttentionUpperKv), ROUND_UP(mHeadDim, hP8) * QUANT_INFO_BYTES}); + mBackend->onAcquireBuffer(newValueSumTensor, Backend::STATIC); + auto remainSizePerHead = mValueSum->stride(0); + auto increSizePerHead = newValueSumTensor->stride(0) - mValueSum->stride(0); + for (int h = 0; h < mKvNumHead; ++h) { + memcpy(newValueSumTensor->host() + h * newValueSumTensor->stride(0) , mValueSum->host() + h * mValueSum->stride(0), remainSizePerHead); + // memset 0 + if (increSizePerHead > 0) { + memset(newValueSumTensor->host() + h * newValueSumTensor->stride(0) + remainSizePerHead, 0, increSizePerHead); + } + } + mValueSum.reset(newValueSumTensor); + } + } + // Remove + auto start = mPastLength - meta->remove; + if (0 == meta->n_reserve || mQuantKey || mQuantValue) { // n_reserve > 0 is not currently supported when K or V is quantized. 
+ mPastLength = start; + return; + } +#if 1 + auto dstIndex = start; + for (int n = 0; n < meta->n_reserve; ++n) { + auto begin = meta->reserve[2 * n]; + auto size = meta->reserve[2 * n + 1]; + auto srcIndex = start + begin; + if (mBytes == 2) { + moveKV(srcIndex, dstIndex, size); + } else { + moveKV(srcIndex, dstIndex, size); + } + dstIndex += size; + } + mPastLength = dstIndex; +#else + // Don't support not align reserve + auto align = hP; + auto dstStart = start; + auto lastValidSrcEnd = start; + for (int n=0; nn_reserve; ++n) { + auto lastEndAlign = UP_DIV(lastValidSrcEnd, align) * align; + auto begin = meta->reserve[2 * n]; + auto size = meta->reserve[2 * n + 1]; + auto startAlign = ((begin + start) / align) * align; + if (startAlign <= lastEndAlign) { + // Fullly reserve + dstStart = dstStart + size; + lastValidSrcEnd = begin + start + size; + continue; + } + auto end = begin + start + size; + auto endAlign = UP_DIV(end, align) * align; + + auto sizeUnit = (endAlign - startAlign) / align; + auto dstStartAlign = UP_DIV(dstStart, align) * align; + + //TODO: Support Quant +// mPastKey.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP})); + + // Move K + auto keyStride = UP_DIV(mMaxLength, align) * align * ROUND_UP(mHeadDim, lP); + auto dstKAddr = keyAddr() + dstStartAlign * ROUND_UP(mHeadDim, lP) * mBytes; + auto srcKAddr = keyAddr() + startAlign * ROUND_UP(mHeadDim, lP) * mBytes; + for (int i=0; ifile_name) + "_" + std::to_string(mMeta->layer_index) + ".k"; + std::string pathv = MNNFilePathConcat(mConfig.mPrefixCacheDir, mMeta->file_name) + "_" + std::to_string(mMeta->layer_index++) + ".v"; + mMeta->layer_index = mMeta->layer_index % mMeta->layer_nums; + + auto new_key_fd = MNNCreateFile(pathk.c_str()); + auto new_value_fd = MNNCreateFile(pathv.c_str()); + if (new_key_fd == INVALID_FILE) { + MNN_PRINT("Failed to create the file: %s\n", pathk.c_str()); + } + if (new_value_fd == INVALID_FILE) { + MNN_PRINT("Failed to create the file: 
%s\n", pathv.c_str()); + } + // set new file size + if (MNNSetFileSize(new_key_fd, keySize) != MNN::NO_ERROR || MNNSetFileSize(new_value_fd, valueSize) != MNN::NO_ERROR) { + MNN_PRINT("Failed to resize the kvcache files!\n"); + } + // mmap files + int8_t* mMapNewKeyAddr = (int8_t *)MNNMmapFile(new_key_fd, keySize); + if (mMapNewKeyAddr == nullptr) { + MNN_PRINT("Failed to memory-map the new kvcache!\n"); + } + int8_t* mMapNewValueAddr =(int8_t *)MNNMmapFile(new_value_fd, valueSize); + if (mMapNewValueAddr == nullptr) { + MNN_PRINT("Failed to memory-map the kvcache!\n"); + } + + // copy + memcpy(mMapNewKeyAddr, mMapKeyAddr, keySize); + memcpy(mMapNewValueAddr, mMapValueAddr, valueSize); + + // unmap new files + if (mMapNewKeyAddr != nullptr) { + MNNUnmapFile(mMapNewKeyAddr, keySize); + mMapNewKeyAddr = nullptr; + } + if (mMapNewValueAddr != nullptr) { + MNNUnmapFile(mMapNewValueAddr, valueSize); + mMapNewValueAddr = nullptr; + } + // close file + if (new_key_fd != INVALID_FILE) { + MNNCloseFile(new_key_fd); + new_key_fd = INVALID_FILE; + } + if (new_value_fd != INVALID_FILE) { + MNNCloseFile(new_value_fd); + new_value_fd = INVALID_FILE; + } +} + +void CPUKVCacheManager::onClear() { + if (mKVCacheInDisk) { + // mSaveShareKvPrefix also need unmap file + unmapKVCache(mCurrentKeySizePerHead * (size_t)mKvNumHead, mCurrentValueSizePerHead * (size_t)mKvNumHead); + if(mSaveShareKvPrefix) { + // set prefix cachefile validation + auto k_file = mBasePrefixFileName + ".k"; + if(MNNFileExist(k_file.c_str())) { + auto k_sync_file = mBasePrefixFileName + "_sync.k"; + MNNCreateFile(k_sync_file.c_str()); + } + auto v_file = mBasePrefixFileName + ".v"; + if(MNNFileExist(v_file.c_str())) { + auto v_sync_file = mBasePrefixFileName + "_sync.v"; + MNNCreateFile(v_sync_file.c_str()); + } + } else { + // delete temp kvcache file + removeKVCacheFile(); + } + mKVCacheInDisk = false; + } + mPastKey.reset(); + mPastValue.reset(); + mKeySum.reset(); + mKeyMax.reset(); + mValueSum.reset(); + 
mMaxLength = mPastLength = 0; +} + +template +void CPUKVCacheManager::ProcessKey(const Tensor* key, int seqLen, int kvHead) { + if (mQuantKey) { // [seqLen, headDim] -> [maxlen/hP8, blockNum, (headDim/blockNum)/lP8, hP8, lP8] + int8_t * keyDst = reinterpret_cast(addrOfKey(kvHead)); + float * sumDst = reinterpret_cast(addrOfKeySum(kvHead)); + + auto blockL = UP_DIV(mHeadDim, mConfig.mBlockNum); + auto weightStride1 = ROUND_UP(blockL, lP8) * hP8; + auto weightStride2 = lP8 * hP8; + auto packedWeightStride1 = weightStride1 + 2 * QUANT_INFO_BYTES * hP8; + + T* keyMax = reinterpret_cast(addrOfKeyMax(kvHead)); + int32_t params[] = {mKvNumHead, seqLen, mHeadDim, mConfig.mBlockNum, eP8, lP8, hP8, mPastLength, kvHead}; + mQuantKeyFunc(keyDst, key->host(), sumDst, (float*)keyMax, params); + } + else { // target: [maxlen/hP, headdim/lP, hP, lP] + T * key_dst = reinterpret_cast(addrOfKey(kvHead)); + auto stride0 = ROUND_UP(mHeadDim, lP) * hP; + auto stride1 = hP * lP; + for (int i = 0; i < seqLen; i++) { + T * key_src = key->host() + i * mKvNumHead * mHeadDim + kvHead * mHeadDim; + int out_index = (mPastLength + i) / hP; + int in_index = (mPastLength + i) % hP; + for (int j = 0; j < mHeadDim; j++) { + key_dst[out_index * stride0 + (j / lP) * stride1 + in_index * lP + (j % lP)] = key_src[j]; + } + } + } +} + +template +void CPUKVCacheManager::ProcessValue(const Tensor* value, int seqLen, int kvHead) { // [headdim/hP, maxlen, hP] + if (mQuantValue) { + int8_t* valueDst = reinterpret_cast(addrOfValue(kvHead)); + float* valueSum = reinterpret_cast(addrOfValueSum(kvHead)); + + int32_t params[] = {mKvNumHead, seqLen, mHeadDim, mConfig.mBlockNum, mMaxLength, lP8, hP8, mPastLength, kvHead, (int32_t)mFlashAttentionUpperKv}; + mQuantValueFunc(valueDst, value->host(), valueSum, params); + } + else { + // [mHeadDim/hP, mMaxLength/lP, hP, lP] + auto stride0 = ROUND_UP(mMaxLength, lP) * hP; + auto stride1 = hP * lP; + + auto weightStride2 = lP * hP; + auto weightStride1 = 
UP_DIV((int32_t)mFlashAttentionUpperKv, lP) * weightStride2; + auto weightStride0 = weightStride1 * UP_DIV(mHeadDim, hP); + + T * value_dst = reinterpret_cast(addrOfValue(kvHead)); + for (int i = 0; i < seqLen; i++) { + T * value_src = value->host() + i * mKvNumHead * mHeadDim + kvHead * mHeadDim; + // int seqLenOut = (mPastLength + i) / lP; + // int seqLenIn = (mPastLength + i) % lP; + + int kvSeqIndx = mPastLength + i; + int idxInner = (kvSeqIndx / (int32_t)mFlashAttentionUpperKv) * weightStride0 + (kvSeqIndx % (int32_t)mFlashAttentionUpperKv) / lP * weightStride2 + (kvSeqIndx % (int32_t)mFlashAttentionUpperKv) % lP; + for (int j = 0; j < mHeadDim; j++) { + int idxBase = (j / hP) * weightStride1 + (j % hP) * lP; + int out_index = j / hP; + int in_index = j % hP; + // value_dst[out_index * stride0 + seqLenOut * stride1 + in_index * lP + seqLenIn] = value_src[j]; + value_dst[idxBase + idxInner] = value_src[j]; + } + } + } +} + +size_t CPUKVCacheManager::keyIndex(int seq, int dim) const { + return (seq / hP) * ROUND_UP(mHeadDim, lP) * hP + + (dim / lP) * hP * lP + + (seq % hP) * lP + + (dim % lP); +} + +size_t CPUKVCacheManager::valueIndex(int seq, int dim) const { + return (dim / hP) * ROUND_UP(mMaxLength, lP) * hP + + (seq / lP) * hP * lP + + (dim % hP) * lP + + (seq % lP); +} + +template +void CPUKVCacheManager::moveKV(int src, int dst, int size) { + for (int h = 0; h < mKvNumHead; ++h) { + auto kPtr = reinterpret_cast(addrOfKey(h)); + auto vPtr = reinterpret_cast(addrOfValue(h)); + for (int i = 0; i < size; i++) { + for (int j = 0; j < mHeadDim; j++) { + kPtr[keyIndex(dst + i, j)] = kPtr[keyIndex(src + i, j)]; + vPtr[valueIndex(dst + i, j)] = vPtr[valueIndex(src + i, j)]; + } + } + } +} + +void CPUKVCacheManager::onUpdateKV(const Tensor * key, const Tensor * value, int add) { + auto core = static_cast(mBackend)->functions(); + int seq_len = add; + auto divPart = UP_DIV(mKvNumHead, 1); + MNN_CONCURRENCY_BEGIN(tId, 1) { + auto remainPart = mKvNumHead - tId * 
divPart; + if (remainPart > 0) { + remainPart = ALIMIN(divPart, remainPart); + int startIdx = tId * divPart; + int endIdx = startIdx + remainPart; + for (int h = startIdx; h < endIdx; ++h) { + if (mBytes == 2) { + ProcessKey(key, seq_len, h); + ProcessValue(value, seq_len, h); + } else { + ProcessKey(key, seq_len, h); + ProcessValue(value, seq_len, h); + } + } + } + } MNN_CONCURRENCY_END(); + mPastLength += seq_len; +} + +} // namespace MNN + +#endif // MNN_SUPPORT_TRANSFORMER_FUSE diff --git a/source/backend/cpu/CPUKVCacheManager.hpp b/source/backend/cpu/CPUKVCacheManager.hpp new file mode 100644 index 0000000000..a238b1ed2c --- /dev/null +++ b/source/backend/cpu/CPUKVCacheManager.hpp @@ -0,0 +1,140 @@ +// +// CPUKVCacheManager.hpp +// MNN +// +// Created by MNN on 2024/08/05. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE + +#ifndef CPU_KVCACHE_MANAGER_HPP +#define CPU_KVCACHE_MANAGER_HPP + +#include "core/KVCacheManager.hpp" +#include "backend/cpu/CPUBackend.hpp" +#include "backend/cpu/compute/CommonOptFunction.h" +#if defined (__aarch64__) +#define FLOAT16_T __fp16 +#else +#define FLOAT16_T float +#endif + +typedef uint8_t fp8_t; + +#define QUANT_INFO_BYTES 4 + +namespace MNN { + +class CPUKVCacheManager : public KVCacheManager{ +private: + int eP, lP, hP; // Packing mode for float matmul + int eP8, lP8, hP8; // Packing mode for int8 gemm kernel + int mThreadNum = 1; + + size_t mFlashAttentionUpperKv = 0; + + void expandKVCacheInMem(int oldMaxLength); + void moveKVCacheFromMemToDisk(int oldMaxLength); + void expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int oldValueSize, int keySize, int valueSize, file_t specKeyFile = INVALID_FILE, file_t specValueFile = INVALID_FILE); + template void ProcessKey(const Tensor* key, int seq_len, int kv_h); + template void ProcessValue(const Tensor* value, int seq_len, int kv_h); + template void moveKV(int src, int dst, int size); + size_t keyIndex(int seq, int dim) const; + 
size_t valueIndex(int seq, int dim) const; + void saveKVCacheInDisk(); + + // The key/value size must be updated on every alloc or realloc call. + size_t mCurrentKeySizePerHead = 0; + size_t mCurrentValueSizePerHead = 0; + + // flash attention + bool mUseFlashAttention = true; + + // quant Key/Value + bool mQuantValue = false; // Quantize values to int8 or not + bool mQuantKey = false; // Whether to use int8 gemm kernel in CPU attention + std::shared_ptr mKeySum; // numhead, [maxlen/hP8, hP8] + std::shared_ptr mValueSum; // numhead, [headDim/hP8, hP8] + std::shared_ptr mKeyMax; // {numhead, headDim} + decltype(CoreFunctions::MNNQuantAttentionKey) mQuantKeyFunc; + decltype(CoreFunctions::MNNQuantAttentionValue) mQuantValueFunc; +public: + CPUKVCacheManager(Backend * backend, KVCacheConfig & kvConfig): KVCacheManager(backend, kvConfig) { + // nothing todo + } + ~CPUKVCacheManager() { + onClear(); + } + const Tensor * keySum() { + return mKeySum.get(); + } + + uint8_t* keyAddr() { + int8_t * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host(); + return (uint8_t*)baseAddr; + } + uint8_t* valudAddr() { + int8_t * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host(); + return (uint8_t*)baseAddr; + } + int8_t * addrOfKey(int kv_h) { + int8_t * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host(); + return baseAddr + kv_h * mCurrentKeySizePerHead; + } + int8_t * addrOfValue(int kv_h) { + int8_t * baseAddr = mKVCacheInDisk ? 
mMapValueAddr : mPastValue->host(); + return baseAddr + kv_h * mCurrentValueSizePerHead; + + } + void setFlashAttentionUpperKv(size_t upperKv) { + mFlashAttentionUpperKv = upperKv; + } + size_t getFlashAttentionBlockKv() { + return mFlashAttentionUpperKv; + } + + void onPushBack(const Tensor * key, const Tensor * value, int add); + void onDequantValue(Tensor * dequantedValues); + void onUpdateKV(const Tensor * key, const Tensor * value, int add); + + // quant Key/Value + int8_t * addrOfKeySum(int kv_h) { + if (mQuantKey) { + return mKeySum->host() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; + }else { + return nullptr; + } + } + int8_t* addrOfKeyMax(int kvH) { + if (mQuantKey) { + return mKeyMax->host() + kvH * mHeadDim * mBytes; + } else { + return nullptr; + } + } + int8_t* addrOfValueSum(int kvH) { + if (mQuantValue) { + return mValueSum->host() + kvH * mValueSum->stride(0); + } else { + return nullptr; + } + } + void setAttenQuantKeyValue(bool useFlashAttention, bool quantKey, bool quantValue) { + mUseFlashAttention = useFlashAttention; + mQuantValue = quantValue; + mQuantKey = quantKey; + } + + virtual void onResize(int kv_num_head, int head_dim); + virtual void onClear(); + virtual void onAlloc(KVMeta* meta, int seq_len); + virtual void onRealloc(KVMeta* meta); + +}; + +} // namespace MNN + +#endif // CPU_KVCACHE_MANAGER_HPP + +#endif // MNN_SUPPORT_TRANSFORMER_FUSE diff --git a/source/backend/cpu/CPUMatMul.cpp b/source/backend/cpu/CPUMatMul.cpp index 06c8ae281a..4f0765f050 100644 --- a/source/backend/cpu/CPUMatMul.cpp +++ b/source/backend/cpu/CPUMatMul.cpp @@ -268,15 +268,19 @@ void CPUMatMul::execute(const float* APtr, const float* BPtr, float* CPtr, const } else { core->MNNPackedMatMulRemain((float*)TC, (float*)TA, (float*)TB, xC, parameters, postPtr, biasPtr, nullptr, nullptr); } - int area[] = { - eP, - mE - }; if (mTransposeC) { + int offsets[] = { + eP, // src area + mH // dst depth + }; // hC4, e, 4 -> e, h auto dst = (uint8_t*)CPtr + xStart * mH * 
core->bytes; - core->MNNUnpackCUnitTranspose((float*)dst, (const float*)TC, xC, mH, area); + core->MNNUnpackCUnitTranspose((float*)dst, (const float*)TC, xC, mH, offsets); } else { + int area[] = { + eP, + mE + }; // hC4, e, 4 -> h, e auto dst = (uint8_t*)CPtr + xStart * core->bytes; core->MNNUnpackCUnit((float*)dst, (const float*)TC, xC, mH, area); diff --git a/source/backend/cpu/CPURaster.cpp b/source/backend/cpu/CPURaster.cpp index f500a5bd58..3272086531 100644 --- a/source/backend/cpu/CPURaster.cpp +++ b/source/backend/cpu/CPURaster.cpp @@ -687,13 +687,19 @@ ErrorCode CPURaster::onExecute(const std::vector &____inputs, const st for (auto& iter : mTempInput) { tensorConvert(iter.first, iter.second, (int)bytes); } - if (mHasReduce || TensorUtils::getDescribe(output)->overlap) { + if (mHasReduce) { // Don't support reduce with multi thread now threadNum = 1; } if (!mUseThreads) { threadNum = 1; } + + // StrideSliceWrite should not use multi threads + auto outputDescribe = TensorUtils::getDescribe(output); + if (outputDescribe->overlap) { + threadNum = 1; + } MNN_CONCURRENCY_BEGIN(tId, threadNum) { for (int u=tId; usme2 = true; + cpuinfo_isa->smeCoreNumber = 2; } } #endif @@ -1367,6 +1368,7 @@ static void _getInfoAux(MNNCPUInfo* cpuinfo_isa) { } if (isa_features2 & CPUINFO_ARM_LINUX_FEATURE2_SME2) { cpuinfo_isa->sme2 = true; + cpuinfo_isa->smeCoreNumber = 1; } } #endif diff --git a/source/backend/cpu/CPURuntime.hpp b/source/backend/cpu/CPURuntime.hpp index a962af0c82..ebc6011d4d 100644 --- a/source/backend/cpu/CPURuntime.hpp +++ b/source/backend/cpu/CPURuntime.hpp @@ -24,6 +24,7 @@ struct MNNCPUInfo { bool sme2; std::vector groups; int cpuNumber = 0; + int smeCoreNumber = 0; }; using cpu_mask_t = unsigned long; int MNNSetSchedAffinity(const int* cpuIDs, int size); diff --git a/source/backend/cpu/CPUSoftmax.cpp b/source/backend/cpu/CPUSoftmax.cpp index 98c234443f..fbbb4c507f 100644 --- a/source/backend/cpu/CPUSoftmax.cpp +++ b/source/backend/cpu/CPUSoftmax.cpp @@ 
-76,7 +76,7 @@ int CPUSoftmax::_softmaxCommon(const uint8_t *srcData, uint8_t *dstData) { if (mTmpInput.ptr()) { tempInput = (float*)(mTmpInput.ptr() + tId * outsideStride * sizeof(float)); } - + if (mTmpOutput.ptr()) { tempOutput = (float*)(mTmpOutput.ptr() + tId * outsideStride * sizeof(float)); } @@ -200,7 +200,7 @@ int CPUSoftmax::_softmaxCommon(const uint8_t *srcData, uint8_t *dstData) { for (int v=0; vMNNFp32ToLowp((float*)tempOutput, (int16_t*)tempInput, outsideStride); MNNTranspose16Bit((int16_t*)dstO, (int16_t*)(tempInput), dims); } else if (mLowOrInt8 == 1) { @@ -288,7 +288,7 @@ ErrorCode CPUSoftmax::onResize(const std::vector &inputs, const std::v int threadNum = cpuBn->threadNumber(); auto buf = cpuBn->getBufferAllocator(); threadNum = ALIMIN(threadNum, outside); - + mTmpInput = buf->alloc(threadNum * inside * channel * sizeof(float)); if (mLowOrInt8 != 4) { mTmpOutput = buf->alloc(threadNum * inside * channel * sizeof(float)); @@ -350,7 +350,7 @@ class CPUSoftmaxCreator : public CPUBackend::Creator { virtual Execution *onCreate(const std::vector &inputs, const std::vector &outputs, const MNN::Op *op, Backend *backend) const override { return CPUSoftmax::create(op, backend); - + } }; diff --git a/source/backend/cpu/CPUTensorConvert.cpp b/source/backend/cpu/CPUTensorConvert.cpp index d05c0f3c8c..62f0a4f554 100644 --- a/source/backend/cpu/CPUTensorConvert.cpp +++ b/source/backend/cpu/CPUTensorConvert.cpp @@ -108,7 +108,7 @@ ErrorCode CPUTensorConverter::convert(const void* inputRaw, void* outputRaw, MNN if (1 == inside) { int offset[2] = { outside, - outside + channel }; int step = UP_DIV(outside, numberThread); int start = tId * step; diff --git a/source/backend/cpu/KVCacheManager.cpp b/source/backend/cpu/KVCacheManager.cpp deleted file mode 100644 index ce34d685dd..0000000000 --- a/source/backend/cpu/KVCacheManager.cpp +++ /dev/null @@ -1,758 +0,0 @@ -// -// KVCacheManager.cpp -// MNN -// -// Created by MNN on 2024/08/05. 
-// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef MNN_SUPPORT_TRANSFORMER_FUSE - -#include "KVCacheManager.hpp" -#include "core/Concurrency.h" - -namespace MNN { - -// Translate an address to a hex number string -static inline std::string addrToHex(void *addr) { - std::string result = ""; - uint64_t n = (uint64_t)addr; - for(int i = 15; i >= 0; i--) { - int t = (n >> (i * 4)) & 0x0f; - result.push_back((t < 10) ? ('0' + t) : ('A' + t - 10)); - } - return result; -} - -void KVCacheManager::createKVCacheFile() { - // Each layer has its own kvcache, so we have to create a key file and a value file for each layer and the file name must be unique - // Here we use the address of the mResource as the file name because the addresses of mResource in different layers are guaranteed to be different - std::string fileName = addrToHex(this); - std::string pathk = MNNFilePathConcat(mConfig.mKVCacheDir, fileName) + ".k"; - std::string pathv = MNNFilePathConcat(mConfig.mKVCacheDir, fileName) + ".v"; - mKeyCacheFD = MNNCreateFile(pathk.c_str()); - mValueCacheFD = MNNCreateFile(pathv.c_str()); - if (mKeyCacheFD == INVALID_FILE) { - MNN_PRINT("Failed to create the file: %s\n", pathk.c_str()); - } - if (mValueCacheFD == INVALID_FILE) { - MNN_PRINT("Failed to create the file: %s\n", pathv.c_str()); - } -} - -void KVCacheManager::removeKVCacheFile() { - std::string fileName = addrToHex(this); - std::string pathk = MNNFilePathConcat(mConfig.mKVCacheDir, fileName) + ".k"; - std::string pathv = MNNFilePathConcat(mConfig.mKVCacheDir, fileName) + ".v"; - if (mKeyCacheFD != INVALID_FILE) { - MNNCloseFile(mKeyCacheFD); - mKeyCacheFD = INVALID_FILE; - if (MNNRemoveFile(pathk.c_str()) != MNN::NO_ERROR) { - MNN_PRINT("Failed to remove the file: %s\n", pathk.c_str()); - } - } - if (mValueCacheFD != INVALID_FILE) { - MNNCloseFile(mValueCacheFD); - mValueCacheFD = INVALID_FILE; - if (MNNRemoveFile(pathv.c_str()) != MNN::NO_ERROR) { - MNN_PRINT("Failed to remove the file: %s\n", 
pathv.c_str()); - } - } -} - -void KVCacheManager::resetKVCacheFileSize(size_t keySize, size_t valueSize) { - if (MNNSetFileSize(mKeyCacheFD, keySize) != MNN::NO_ERROR || MNNSetFileSize(mValueCacheFD, valueSize) != MNN::NO_ERROR) { - MNN_PRINT("Failed to resize the kvcache files!\n"); - } -} - -/* -** @brief Memory-map the kvcache file -** @hint After memory-mapping, we can access the kvcache files with pointers, just like accessing memory buffer -** But the data actually resides in disk. -** The OS will set some kernel page cache and manage the data swaping, which we do not need to care. -*/ -void KVCacheManager::mmapKVCache(size_t keySize, size_t valueSize) -{ - if (mMapKeyAddr == nullptr) { - mMapKeyAddr = (int8_t *)MNNMmapFile(mKeyCacheFD, keySize); - if (mMapKeyAddr == nullptr) { - MNN_PRINT("Failed to memory-map the kvcache!\n"); - } - } - if (mMapValueAddr == nullptr) { - mMapValueAddr = (int8_t *)MNNMmapFile(mValueCacheFD, valueSize); - if (mMapValueAddr == nullptr) { - MNN_PRINT("Failed to memory-map the kvcache!\n"); - } - } -} - -void KVCacheManager::unmapKVCache(size_t keySize, size_t valueSize) -{ - if (mMapKeyAddr != nullptr) { - MNNUnmapFile(mMapKeyAddr, keySize); - mMapKeyAddr = nullptr; - } - if (mMapValueAddr != nullptr) { - MNNUnmapFile(mMapValueAddr, valueSize); - mMapValueAddr = nullptr; - } -} - -/* -** @brief Expand the size of kvcache and copy it from the old tensor in memory to the new tensor in memory -** Finally reset the pointer to the new tensor -*/ -void KVCacheManager::expandKVCacheInMem(int oldMaxLength) { - /*=================================== Key ===================================*/ - if (mConfig.mUseInt8Kernel) { - auto new_key = Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP8), UP_DIV(mHeadDim, lP8), hP8 * lP8}); - mBackend->onAcquireBuffer(new_key, Backend::STATIC); - for (int h = 0; h < mKvNumHead; h++) { - memcpy( - new_key->host() + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, - mPastKey->host() 
+ h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, - UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8 - ); - } - mPastKey.reset(new_key); - } - else if (mConfig.mQuantKey) { - auto new_key = Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP}); - mBackend->onAcquireBuffer(new_key, Backend::STATIC); - for (int h = 0; h < mKvNumHead; h++) { - memcpy( - new_key->host() + h * new_key->stride(0), - mPastKey->host() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP), - ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) - ); - } - mPastKey.reset(new_key); - } - else { - auto new_key = Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP}); - mBackend->onAcquireBuffer(new_key, Backend::STATIC); - for (int h = 0; h < mKvNumHead; h++) { - memcpy( - new_key->host() + h * new_key->stride(0) * mBytes, - mPastKey->host() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes, - ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes - ); - if ((new_key->stride(0) - mPastKey->stride(0)) > 0) { - memset(new_key->host() + h * new_key->stride(0) * mBytes + mPastKey->stride(0) * mBytes, 0, (new_key->stride(0) - mPastKey->stride(0)) * mBytes); - } - } - mPastKey.reset(new_key); - } - /*=================================== Value ===================================*/ - if (mConfig.mQuantValue) { - auto new_value = Tensor::createDevice({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP}); - mBackend->onAcquireBuffer(new_value, Backend::STATIC); - for (int h = 0; h < mKvNumHead; h++) { - for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { - memcpy( - new_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP, - mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP, - ROUND_UP(oldMaxLength, lP) * hP - ); - } - } - mPastValue.reset(new_value); - } - else { - auto new_value = 
Tensor::createDevice({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP}); - mBackend->onAcquireBuffer(new_value, Backend::STATIC); - for (int h = 0; h < mKvNumHead; h++) { - for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { - memcpy( - new_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes, - mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes, - ROUND_UP(oldMaxLength, lP) * hP * mBytes - ); - if ((new_value->stride(1) - mPastValue->stride(1)) > 0) { - memset(new_value->host() + (h * new_value->stride(0) + i * new_value->stride(1)) * mBytes + mPastValue->stride(1) * mBytes, 0, (new_value->stride(1) - mPastValue->stride(1)) * mBytes); - } - } - } - mPastValue.reset(new_value); - } -} - -/* -** @brief Move the kvcache from memory to the memory-mapped kvcache files in disk -** Then release the memory buffer of old kvcache -*/ -void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) { - /*=================================== Key ===================================*/ - if (mConfig.mUseInt8Kernel) { - for (int h = 0; h < mKvNumHead; h++) { - memcpy( - mMapKeyAddr + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, - mPastKey->host() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, - UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8 - ); - } - mBackend->onReleaseBuffer(mPastKey.get(), Backend::STATIC); - mPastKey.reset(); - } - if (mConfig.mQuantKey) { - for (int h = 0; h < mKvNumHead; h++) { - memcpy( - mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP, - mPastKey->host() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP, - UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP - ); - } - mBackend->onReleaseBuffer(mPastKey.get(), Backend::STATIC); - mPastKey.reset(); - } - else { - if (mHeadDim % lP) { - memset(mMapKeyAddr, 0, mKvNumHead * ROUND_UP(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * 
mBytes ); - } - for (int h = 0; h < mKvNumHead; h++) { - memcpy( - mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes, - mPastKey->host() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes, - UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes - ); - } - mBackend->onReleaseBuffer(mPastKey.get(), Backend::STATIC); - mPastKey.reset(); - } - /*=================================== Value ===================================*/ - if (mConfig.mQuantValue) { - for (int h = 0; h < mKvNumHead; h++) { - for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { - memcpy( - mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP, - mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP, - ROUND_UP(oldMaxLength, lP) * hP - ); - } - } - mBackend->onReleaseBuffer(mPastValue.get(), Backend::STATIC); - mPastValue.reset(); - } - else { - if (lP > 1) { - memset(mMapValueAddr, 0, mKvNumHead * ROUND_UP(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * mBytes); - } - for (int h = 0; h < mKvNumHead; h++) { - for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { - memcpy( - mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes, - mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes, - ROUND_UP(oldMaxLength, lP) * hP * mBytes - ); - } - } - mBackend->onReleaseBuffer(mPastValue.get(), Backend::STATIC); - mPastValue.reset(); - } -} - -/* -** @brief Expand the size of kvcache files in disk -*/ -void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int oldValueSize, int keySize, int valueSize) { - // Step 1: Copy the old kvcache from files to temporary buffers in memory - std::shared_ptr old_key, old_value; - if (mConfig.mUseInt8Kernel) { - old_key.reset(Tensor::createDevice({mKvNumHead, UP_DIV(oldMaxLength, hP8), UP_DIV(mHeadDim, lP8), hP8 * lP8})); - } else if (mConfig.mQuantKey) { - 
old_key.reset(Tensor::createDevice({mKvNumHead, UP_DIV(oldMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP})); - } else { - old_key.reset(Tensor::createDevice({mKvNumHead, UP_DIV(oldMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP})); - } - if (mConfig.mQuantValue) { - old_value.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(oldMaxLength, lP), hP, lP})); - } else { - old_value.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(oldMaxLength, lP), hP, lP})); - } - mBackend->onAcquireBuffer(old_key.get(), Backend::STATIC); - mBackend->onAcquireBuffer(old_value.get(), Backend::STATIC); - if (mHeadDim % lP) { - memset(old_key->host(), 0, old_key->length(0) * old_key->stride(0) * mBytes); - } - if (lP > 1) { - // can't be mMaxLenth % lP, since mMaxLength may be larger than seq_len for prefilling, we should ensure the (mMaxLength - seq_len)'s buffer is 0. - // computing L is seq_len - memset(old_value->host(), 0, old_value->length(0) * old_value->stride(0) * mBytes); - } - mmapKVCache(oldKeySize, oldValueSize); - memcpy(old_key->host(), mMapKeyAddr, oldKeySize); - memcpy(old_value->host(), mMapValueAddr, oldValueSize); - // Step 2: Resize the kvcache files and remap them - unmapKVCache(oldKeySize, oldValueSize); - resetKVCacheFileSize(keySize, valueSize); - mmapKVCache(keySize, valueSize); - // Step 3: Move the kvcache from temporary buffers in memory to disk - if (mConfig.mUseInt8Kernel) { - for (int h = 0; h < mKvNumHead; h++) { - memcpy( - mMapKeyAddr + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, - old_key->host() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, - UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8 - ); - } - } else if (mConfig.mQuantKey) { - for (int h = 0; h < mKvNumHead; h++) { - memcpy( - mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP, - old_key->host() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP, - UP_DIV(oldMaxLength, hP) * 
ROUND_UP(mHeadDim, lP) * hP - ); - } - } else { - for (int h = 0; h < mKvNumHead; h++) { - memcpy( - mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes, - old_key->host() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes, - UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes - ); - } - } - if (mConfig.mQuantValue) { - for (int h = 0; h < mKvNumHead; h++) { - for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { - memcpy( - mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP, - old_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP, - ROUND_UP(oldMaxLength, lP) * hP - ); - } - } - } else { - for (int h = 0; h < mKvNumHead; h++) { - for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { - memcpy( - mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes, - old_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes, - ROUND_UP(oldMaxLength, lP) * hP * mBytes - ); - } - } - } - // Step 4: Release the temporary buffers - mBackend->onReleaseBuffer(old_key.get(), Backend::STATIC); - mBackend->onReleaseBuffer(old_value.get(), Backend::STATIC); -} - -void KVCacheManager::onResize(int kv_num_head, int head_dim) { - mKvNumHead = kv_num_head; - mHeadDim = head_dim; - auto core = static_cast(mBackend)->functions(); - core->MNNGetMatMulPackMode(&eP, &lP, &hP); - mBytes = core->bytes; - mThreadNum = static_cast(mBackend)->threadNumber(); - if (mThreadNum > mKvNumHead) { - mThreadNum = mKvNumHead; - } - if (mConfig.mUseInt8Kernel) { - static_cast(mBackend)->int8Functions()->MNNGetGemmUnit(&hP8, &lP8, &eP8); - } -} - -void KVCacheManager::onAlloc(int kv_seq_len) { - mMaxLength = kv_seq_len + mConfig.mExpandChunk; - size_t keySize = 0, valueSize = 0; - if (mConfig.mUseInt8Kernel) { - keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8; - } else if (mConfig.mQuantKey) { - 
keySize = (size_t)mKvNumHead * ROUND_UP(mMaxLength, hP) * ROUND_UP(mHeadDim, lP); - } else { - keySize = (size_t)mKvNumHead * ROUND_UP(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes; - } - valueSize = (size_t)mKvNumHead * ROUND_UP(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * (mConfig.mQuantValue ? 1 : mBytes); - /*============== Put the kvcache in disk ===========*/ - if (mConfig.mKVCacheSizeLimit != -1 && keySize + valueSize > mConfig.mKVCacheSizeLimit) { - createKVCacheFile(); - resetKVCacheFileSize(keySize, valueSize); - mmapKVCache(keySize, valueSize); - mKVCacheInDisk = true; - } - /*============== Put the kvcache in memory ===========*/ - else { - if (mConfig.mUseInt8Kernel) { - mPastKey.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP8), UP_DIV(mHeadDim, lP8), hP8 * lP8})); - } else if (mConfig.mQuantKey) { - mPastKey.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP})); - } else { - mPastKey.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP})); - } - if (mConfig.mQuantValue) { - mPastValue.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP})); - } else { - mPastValue.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP})); - } - mBackend->onAcquireBuffer(mPastKey.get(), Backend::STATIC); - mBackend->onAcquireBuffer(mPastValue.get(), Backend::STATIC); - if (mHeadDim % lP) { - memset(mPastKey->host(), 0, mPastKey->length(0) * mPastKey->stride(0) * mBytes); - } - if (lP > 1) { // can't be mMaxLenth % lP, since mMaxLength may be larger than seq_len for prefilling, we should ensure the (mMaxLength - seq_len)'s buffer is 0. 
- memset(mPastValue->host(), 0, mPastValue->length(0) * mPastValue->stride(0) * mBytes); - } - } - // scale, zero point and sum of key for quantization - if (mConfig.mUseInt8Kernel) { - mKeyScale.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8})); - mKeyZeroPoint.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8})); - mKeySum.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8})); - mBackend->onAcquireBuffer(mKeyScale.get(), Backend::STATIC); - mBackend->onAcquireBuffer(mKeyZeroPoint.get(), Backend::STATIC); - mBackend->onAcquireBuffer(mKeySum.get(), Backend::STATIC); - } else if (mConfig.mQuantKey) { - mKeyScale.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), hP})); - mKeyZeroPoint.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), hP})); - mBackend->onAcquireBuffer(mKeyScale.get(), Backend::STATIC); - mBackend->onAcquireBuffer(mKeyZeroPoint.get(), Backend::STATIC); - } -} - -void KVCacheManager::onRealloc(const KVMeta* meta) { - auto kv_seq_len = meta->previous + meta->add - meta->remove + meta->computeReverseSize(); - if (kv_seq_len > mMaxLength) { - // Realloc - int oldMaxLength = mMaxLength; - mMaxLength = kv_seq_len + mConfig.mExpandChunk; - size_t oldKeySize, oldValueSize, keySize, valueSize; - if (mConfig.mUseInt8Kernel) { - oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8; - keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8; - } else if (mConfig.mQuantKey) { - oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP; - keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP; - } else { - oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes; - keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes; - } - oldValueSize = (size_t)mKvNumHead * 
UP_DIV(mHeadDim, hP) * ROUND_UP(oldMaxLength, lP) * hP * (mConfig.mQuantValue ? 1 : mBytes); - valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP * (mConfig.mQuantValue ? 1 : mBytes); - /*==== No limit for kvcache ====*/ - if (mConfig.mKVCacheSizeLimit == -1) { - expandKVCacheInMem(oldMaxLength); - } - /*==== Last time the kvcache is memory, now it should be in memory too ====*/ - else if (keySize + valueSize <= mConfig.mKVCacheSizeLimit) { - expandKVCacheInMem(oldMaxLength); - } - /*==== Last time the kvcache is in memory, but now it should be moved to disk ====*/ - else if (oldKeySize + oldValueSize <= mConfig.mKVCacheSizeLimit) { - createKVCacheFile(); - resetKVCacheFileSize(keySize, valueSize); - mmapKVCache(keySize, valueSize); - moveKVCacheFromMemToDisk(oldMaxLength); - mKVCacheInDisk = true; - } - /*==== Last time the kvcache is disk, now it should be in disk too ====*/ - else { - expandKVCacheInDisk(oldMaxLength, oldKeySize, oldValueSize, keySize, valueSize); - } - /* No matter where is the kvcache, the scales and zero points are always in memory, since their size is very small */ - if (mConfig.mUseInt8Kernel) { - auto new_scale = Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8}); - auto new_zeroPoint = Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8}); - auto new_sum = Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8}); - mBackend->onAcquireBuffer(new_scale, Backend::STATIC); - mBackend->onAcquireBuffer(new_zeroPoint, Backend::STATIC); - mBackend->onAcquireBuffer(new_sum, Backend::STATIC); - for (int h = 0; h < mKvNumHead; h++) { - memcpy(new_scale->host() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyScale->host() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); - memcpy(new_zeroPoint->host() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyZeroPoint->host() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); - 
memcpy(new_sum->host() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeySum->host() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); - } - mKeyScale.reset(new_scale); - mKeyZeroPoint.reset(new_zeroPoint); - mKeySum.reset(new_sum); - } else if (mConfig.mQuantKey) { - auto new_scale = Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), 1, hP}); - auto new_zeroPoint = Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), 1, hP}); - mBackend->onAcquireBuffer(new_scale, Backend::STATIC); - mBackend->onAcquireBuffer(new_zeroPoint, Backend::STATIC); - for (int h = 0; h < mKvNumHead; h++) { - memcpy(new_scale->host() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyScale->host() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes); - memcpy(new_zeroPoint->host() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyZeroPoint->host() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes); - } - mKeyScale.reset(new_scale); - mKeyZeroPoint.reset(new_zeroPoint); - } - } - // Remove - auto start = mPastLength - meta->remove; - if (0 == meta->n_reserve) { - mPastLength = start; - return; - } -#if 1 - auto dstIndex = start; - for (int n = 0; n < meta->n_reserve; ++n) { - auto begin = meta->reserve[2 * n]; - auto size = meta->reserve[2 * n + 1]; - auto srcIndex = start + begin; - if (mBytes == 2) { - moveKV(srcIndex, dstIndex, size); - } else { - moveKV(srcIndex, dstIndex, size); - } - dstIndex += size; - } - mPastLength = dstIndex; -#else - // Don't support not align reserve - auto align = hP; - auto dstStart = start; - auto lastValidSrcEnd = start; - for (int n=0; nn_reserve; ++n) { - auto lastEndAlign = UP_DIV(lastValidSrcEnd, align) * align; - auto begin = meta->reserve[2 * n]; - auto size = meta->reserve[2 * n + 1]; - auto startAlign = ((begin + start) / align) * align; - if (startAlign <= lastEndAlign) { - // Fullly reserve - dstStart = dstStart + size; - lastValidSrcEnd = begin + 
start + size; - continue; - } - auto end = begin + start + size; - auto endAlign = UP_DIV(end, align) * align; - - auto sizeUnit = (endAlign - startAlign) / align; - auto dstStartAlign = UP_DIV(dstStart, align) * align; - - //TODO: Support Quant -// mPastKey.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP})); - - // Move K - auto keyStride = UP_DIV(mMaxLength, align) * align * ROUND_UP(mHeadDim, lP); - auto dstKAddr = keyAddr() + dstStartAlign * ROUND_UP(mHeadDim, lP) * mBytes; - auto srcKAddr = keyAddr() + startAlign * ROUND_UP(mHeadDim, lP) * mBytes; - for (int i=0; i({mKvNumHead, UP_DIV(mHeadDim, hP), mMaxLength, hP})); - - // Move V - auto dstVAddr = valudAddr() + dstStartAlign * align * mBytes; - auto srcVAddr = valudAddr() + startAlign * align * mBytes; - auto number = mKvNumHead * UP_DIV(mHeadDim, align); - for (int i=0; i -void KVCacheManager::pack_key(const Tensor* key, int seq_len, int kv_h) { - if (mConfig.mUseInt8Kernel) { // [maxlen/hP8, headdim/lP8, hP8, lP8] - int8_t * key_dst = reinterpret_cast(addrOfKey(kv_h)); - float * scale_dst = reinterpret_cast(addrOfScale(kv_h)); - float * zeroPoint_dst = reinterpret_cast(addrOfZeroPoint(kv_h)); - float * sum_dst = reinterpret_cast(addrOfKeySum(kv_h)); - for (int s = 0; s < seq_len; s++) { - T * key_src = key->host() + s * mKvNumHead * mHeadDim + kv_h * mHeadDim; - float minKey = key_src[0]; - float maxKey = key_src[0]; - float sumKey = key_src[0]; - for (int d = 1; d < mHeadDim; d++) { - minKey = ALIMIN(minKey, key_src[d]); - maxKey = ALIMAX(maxKey, key_src[d]); - sumKey += key_src[d]; - } - int out_index = (mPastLength + s) / hP8; - int in_index = (mPastLength + s) % hP8; - scale_dst[out_index * hP8 + in_index] = (maxKey - minKey) / 255.0f; - zeroPoint_dst[out_index * hP8 + in_index] = -255.0f * minKey / (maxKey - minKey) - 128.0; - sum_dst[out_index * hP8 + in_index] = sumKey; - for (int d = 0; d < mHeadDim; d++) { - int i = d / lP8; - int j = d % lP8; - key_dst[out_index * 
UP_DIV(mHeadDim, lP8) * hP8 * lP8 + i * hP8 * lP8 + in_index * lP8 + j] = roundf((key_src[d] - minKey) / (maxKey - minKey) * 255.0f - 128.0f); - } - } - } - else if (mConfig.mQuantKey) { // [maxlen/hP, headdim, hP] - int8_t * key_dst = reinterpret_cast(addrOfKey(kv_h)); - T * scale_dst = reinterpret_cast(addrOfScale(kv_h)); - T * zeroPoint_dst = reinterpret_cast(addrOfZeroPoint(kv_h)); - for (int i = 0; i < seq_len; i++) { - T * key_src = key->host() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim; - int out_index = (mPastLength + i) / hP; - int in_index = (mPastLength + i) % hP; - T minKey, maxKey; - static_cast(mBackend)->functions()->MNNCountMaxMinValue((float*)key_src, (float*)&minKey, (float*)&maxKey, mHeadDim); - scale_dst[out_index * hP + in_index] = (maxKey - minKey) / 255.0f; - zeroPoint_dst[out_index * hP + in_index] = 128.0f * (maxKey - minKey) / 255.0f + minKey; - for (int j = 0; j < mHeadDim; j++) { - key_dst[out_index * mHeadDim * hP + j * hP + in_index] = roundf((key_src[j] - minKey) / (maxKey - minKey) * 255 - 128); - } - } - } - else { // target: [maxlen/hP, headdim/lP, hP, lP] - T * key_dst = reinterpret_cast(addrOfKey(kv_h)); - auto stride0 = ROUND_UP(mHeadDim, lP) * hP; - auto stride1 = hP * lP; - for (int i = 0; i < seq_len; i++) { - T * key_src = key->host() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim; - int out_index = (mPastLength + i) / hP; - int in_index = (mPastLength + i) % hP; - for (int j = 0; j < mHeadDim; j++) { - key_dst[out_index * stride0 + (j / lP) * stride1 + in_index * lP + (j % lP)] = key_src[j]; - } - } - } -} - -template -void KVCacheManager::pack_value(const Tensor* value, int seq_len, int kv_h) { // [headdim/hP, maxlen, hP] - if (mConfig.mQuantValue) { - fp8_t * value_dst = reinterpret_cast(addrOfValue(kv_h)); - uint8_t * buf = (uint8_t *)MNNMemoryAllocAlign(mHeadDim, MNN_MEMORY_ALIGN_DEFAULT); - for (int i = 0; i < seq_len; i++) { - T * value_src = value->host() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim; - if 
(sizeof(T) == 2) { - static_cast(mBackend)->functions()->MNNFp16ToFp8(buf, (uint16_t*)value_src, mHeadDim); - } else { - static_cast(mBackend)->functions()->MNNFp32ToFp8(buf, (float*)value_src, mHeadDim); - } - for (int j = 0; j < mHeadDim; j++) { - int out_index = j / hP; - int in_index = j % hP; - value_dst[out_index * mMaxLength * hP + (mPastLength + i) * hP + in_index] = buf[j]; - } - } - MNNMemoryFreeAlign(buf); - } - else { - // [mHeadDim/hP, mMaxLength/lP, hP, lP] - auto stride0 = ROUND_UP(mMaxLength, lP) * hP; - auto stride1 = hP * lP; - T * value_dst = reinterpret_cast(addrOfValue(kv_h)); - for (int i = 0; i < seq_len; i++) { - T * value_src = value->host() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim; - int seqLenOut = (mPastLength + i) / lP; - int seqLenIn = (mPastLength + i) % lP; - for (int j = 0; j < mHeadDim; j++) { - int out_index = j / hP; - int in_index = j % hP; - value_dst[out_index * stride0 + seqLenOut * stride1 + in_index * lP + seqLenIn] = value_src[j]; - } - } - } -} - -size_t KVCacheManager::keyIndex(int seq, int dim) const { - return (seq / hP) * ROUND_UP(mHeadDim, lP) * hP + - (dim / lP) * hP * lP + - (seq % hP) * lP + - (dim % lP); -} - -size_t KVCacheManager::valueIndex(int seq, int dim) const { - return (dim / hP) * ROUND_UP(mMaxLength, lP) * hP + - (seq / lP) * hP * lP + - (dim % hP) * lP + - (seq % lP); -} - -template -void KVCacheManager::moveKV(int src, int dst, int size) { - for (int h = 0; h < mKvNumHead; ++h) { - auto kPtr = reinterpret_cast(addrOfKey(h)); - auto vPtr = reinterpret_cast(addrOfValue(h)); - for (int i = 0; i < size; i++) { - for (int j = 0; j < mHeadDim; j++) { - kPtr[keyIndex(dst + i, j)] = kPtr[keyIndex(src + i, j)]; - vPtr[valueIndex(dst + i, j)] = vPtr[valueIndex(src + i, j)]; - } - } - } -} - -void KVCacheManager::onPushBack(const Tensor * key, const Tensor * value, int add) { - auto core = static_cast(mBackend)->functions(); - int seq_len = add; - int tileCount = UP_DIV(mKvNumHead, mThreadNum); - 
std::function packKV = [=](int tid) { - for (int kv_h = tid * tileCount; kv_h < (tid+1) * tileCount && kv_h < mKvNumHead; kv_h++) { - if (mBytes == 2) { - pack_key(key, seq_len, kv_h); - pack_value(value, seq_len, kv_h); - } else { - pack_key(key, seq_len, kv_h); - pack_value(value, seq_len, kv_h); - } - } - }; - MNN_CONCURRENCY_BEGIN(tid, mThreadNum) { - packKV((int)tid); - } - MNN_CONCURRENCY_END(); - mPastLength += seq_len; -} - -void KVCacheManager::onDequantValue(Tensor * dequantedValues) { - auto core = static_cast(mBackend)->functions(); - int tileCount = UP_DIV(mKvNumHead, mThreadNum); - std::function dequant = [=](int tid) { - for (int kv_h = tid * tileCount; kv_h < (tid+1) * tileCount && kv_h < mKvNumHead; kv_h++) { - int8_t * dst = dequantedValues->host() + kv_h * UP_DIV(mHeadDim, hP) * mPastLength * hP * mBytes; - int8_t * src = addrOfValue(kv_h); - for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { - if (mBytes == 2) { - core->MNNFp8ToFp16((uint16_t*)dst, (uint8_t*)src, mPastLength * hP); - } else { - core->MNNFp8ToFp32((float*)dst, (uint8_t*)src, mPastLength * hP); - } - dst += mPastLength * hP * mBytes; - src += mMaxLength * hP; - } - } - }; - MNN_CONCURRENCY_BEGIN(tid, mThreadNum) { - dequant((int)tid); - } - MNN_CONCURRENCY_END(); -} - -} // namespace MNN - -#endif // MNN_SUPPORT_TRANSFORMER_FUSE diff --git a/source/backend/cpu/KVCacheManager.hpp b/source/backend/cpu/KVCacheManager.hpp deleted file mode 100644 index 19bf38afbc..0000000000 --- a/source/backend/cpu/KVCacheManager.hpp +++ /dev/null @@ -1,172 +0,0 @@ -// -// KVCacheManager.hpp -// MNN -// -// Created by MNN on 2024/08/05. 
-// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef MNN_SUPPORT_TRANSFORMER_FUSE - -#ifndef KVCACHE_MANAGER_HPP -#define KVCACHE_MANAGER_HPP - -#include "core/Macro.h" -#include "core/MNNFileUtils.h" -#include "core/OpCommonUtils.hpp" -#include "backend/cpu/CPUBackend.hpp" -#include "backend/cpu/compute/CommonOptFunction.h" -#if defined (__aarch64__) -#define FLOAT16_T __fp16 -#else -#define FLOAT16_T float -#endif - -typedef uint8_t fp8_t; - -namespace MNN { - -class KVCacheManager : public NonCopyable{ -public: - struct KVCacheConfig { - bool mQuantKey = false; // Quantize keys to int8 or not - bool mQuantValue = false; // Quantize values to fp8 or not - bool mUseInt8Kernel = false; // Whether to use int8 gemm kernel in CPU attention - std::string mKVCacheDir = "/tmp"; // Path of the kvcache files in disk - size_t mKVCacheSizeLimit = -1; // The limit of the kvcache size - int mExpandChunk = 64; // Number of expand chunks when the buffer is full - }; -private: - Backend * mBackend; - KVCacheConfig mConfig; - std::shared_ptr mPastKey; // {numhead, [maxlen/hP, headdim, hP]} or {numhead, [maxlen/hP8, headdim/lP8, hP8, lP8]} - std::shared_ptr mPastValue; // numhead, [headdim/hP, maxlen, hP] - std::shared_ptr mKeyScale; // {numhead, [maxlen/hP, hP]} or {numhead, [maxlen/hP8, hP8]} - std::shared_ptr mKeyZeroPoint; // {numhead, [maxlen/hP, hP]} or {numhead, [maxlen/hP8, hP8]} - std::shared_ptr mKeySum; // numhead, [maxlen/hP8, hP8] - file_t mKeyCacheFD = INVALID_FILE; // The file descriptor of keys - file_t mValueCacheFD = INVALID_FILE; // The file descriptor of values - int8_t * mMapKeyAddr = nullptr; // Memory-mapped address of keys - int8_t * mMapValueAddr = nullptr; // Memory-mapped address of values - bool mKVCacheInDisk = false; // Whether the kvcache is in disk or in memory now - int mPastLength = 0; // Length of past kvcache - int mMaxLength = 0; // Capacity of current kvcache buffer (how many kv items can be stored at most) - int eP, lP, hP; // 
Packing mode for float matmul - int eP8, lP8, hP8; // Packing mode for int8 gemm kernel - int mBytes = 4, mThreadNum = 1; - int mKvNumHead = 0, mHeadDim = 0; - void createKVCacheFile(); - void removeKVCacheFile(); - void resetKVCacheFileSize(size_t keySize, size_t valueSize); - void mmapKVCache(size_t keySize, size_t valueSize); - void unmapKVCache(size_t keySize, size_t valueSize); - void expandKVCacheInMem(int oldMaxLength); - void moveKVCacheFromMemToDisk(int oldMaxLength); - void expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int oldValueSize, int keySize, int valueSize); - template void pack_key(const Tensor* key, int seq_len, int kv_h); - template void pack_value(const Tensor* value, int seq_len, int kv_h); - template void moveKV(int src, int dst, int size); - size_t keyIndex(int seq, int dim) const; - size_t valueIndex(int seq, int dim) const; -public: - KVCacheManager(Backend * backend, KVCacheConfig & kvConfig) { - mBackend = backend; - mConfig = kvConfig; - } - ~KVCacheManager() { - onClear(); - } - const Backend * backend() { - return mBackend; - } - const KVCacheConfig * config() { - return &mConfig; - } - const Tensor * key() { - return mPastKey.get(); - } - const Tensor * value() { - return mPastValue.get(); - } - const Tensor * scale() { - return mKeyScale.get(); - } - const Tensor * zeroPoint() { - return mKeyZeroPoint.get(); - } - const Tensor * keySum() { - return mKeySum.get(); - } - bool inDisk() { - return mKVCacheInDisk; - } - int kvLength() { - return mPastLength; - } - int maxLength() { - return mMaxLength; - } - uint8_t* keyAddr() { - int8_t * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host(); - return (uint8_t*)baseAddr; - } - uint8_t* valudAddr() { - int8_t * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host(); - return (uint8_t*)baseAddr; - } - int8_t * addrOfKey(int kv_h) { - int8_t * baseAddr = mKVCacheInDisk ? 
mMapKeyAddr : mPastKey->host(); - if (mConfig.mUseInt8Kernel) { - return baseAddr + kv_h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8; - } else if (mConfig.mQuantKey) { - return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP; - } else { - return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes; - } - } - int8_t * addrOfValue(int kv_h) { - int8_t * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host(); - if (mConfig.mQuantValue) { - return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP; - } else { - return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP * mBytes; - } - } - int8_t * addrOfScale(int kv_h) { - if (mConfig.mUseInt8Kernel) { - return mKeyScale->host() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; - } else if (mConfig.mQuantKey) { - return mKeyScale->host() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes; - } else { - return nullptr; - } - } - int8_t * addrOfZeroPoint(int kv_h) { - if (mConfig.mUseInt8Kernel) { - return mKeyZeroPoint->host() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; - } else if (mConfig.mQuantKey) { - return mKeyZeroPoint->host() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes; - } else { - return nullptr; - } - } - int8_t * addrOfKeySum(int kv_h) { - if (mConfig.mUseInt8Kernel) { - return mKeySum->host() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; - }else { - return nullptr; - } - } - void onResize(int kv_num_head, int head_dim); - void onAlloc(int kv_seq_len); - void onRealloc(const KVMeta* meta); - void onClear(); - void onPushBack(const Tensor * key, const Tensor * value, int add); - void onDequantValue(Tensor * dequantedValues); -}; - -} // namespace MNN - -#endif // KVCACHE_MANAGER_HPP - -#endif // MNN_SUPPORT_TRANSFORMER_FUSE diff --git a/source/backend/cpu/KleidiAIConvolutionDepthwise.cpp b/source/backend/cpu/KleidiAIConvolutionDepthwise.cpp new file mode 100644 index 0000000000..9ed81cc86a --- /dev/null +++ 
b/source/backend/cpu/KleidiAIConvolutionDepthwise.cpp @@ -0,0 +1,168 @@ +#include "backend/cpu/KleidiAIConvolutionDepthwise.hpp" + +#ifdef MNN_KLEIDIAI_ENABLED + +#include +#include "core/Concurrency.h" +#include "backend/cpu/compute/Int8FunctionsOpt.h" +#include "core/Macro.h" +#include "core/TensorUtils.hpp" +#include "backend/cpu/compute/CommonOptFunction.h" +#include "backend/cpu/compute/ConvOpt.h" + +namespace MNN { + template +void nchw_to_nhwc_optimized(const T* src, T* dst, + int batch, int channel, int height, int width) { + const int hw = height * width; + const int chw = channel * hw; + const int wc = width * channel; + + for (int n = 0; n < batch; ++n) { + const T* src_batch = src + n * chw; + T* dst_batch = dst + n * chw; + + for (int c = 0; c < channel; ++c) { + const T* src_channel = src_batch + c * hw; + + for (int h = 0; h < height; ++h) { + const T* src_row = src_channel + h * width; + T* dst_row = dst_batch + h * wc + c; + + for (int w = 0; w < width; ++w) { + dst_row[w * channel] = src_row[w]; + } + } + } + } +} + +KleidiAIConvolutionDepthwise::KleidiAIDepthwiseExecution::KleidiAIDepthwiseExecution(const Convolution2DCommon* common, Backend* b, + const float* originWeight, size_t originWeightSize, + const float* bias, size_t biasSize) + : MNN::CPUConvolution(common, b) { + int kernel_height = common->kernelY(); + int kernel_width = common->kernelX(); + int channels = common->outputCount(); + int packedRhsSize = kai_rhs_get_dst_size_dwconv_pack_x32p1vlx1b_x32_x32_sme(kernel_height, kernel_width, channels); + mPackedRhs.reset(Tensor::createDevice(std::vector{packedRhsSize})); + bool success = b->onAcquireBuffer(mPackedRhs.get(), Backend::STATIC); + if (!success) { + MNN_ERROR("Error for alloc memory for CPUConvolutionDepthwise\n"); + mValid = false; + return; + } + mNumber = ((CPUBackend*)b)->threadNumber(); + mWeightTemp.reset(Tensor::createDevice(std::vector{channels * kernel_height * kernel_width * (int)sizeof(float)})); + success = 
b->onAcquireBuffer(mWeightTemp.get(), Backend::STATIC); + if (!success) { + MNN_ERROR("Error for alloc memory for CPUConvolutionDepthwise\n"); + mValid = false; + return; + } + auto weightTempPtr = mWeightTemp->host(); + nchw_to_nhwc_optimized(originWeight, weightTempPtr, 1, channels, kernel_height, kernel_width); + kai_run_rhs_dwconv_pack_x32p1vlx1b_x32_x32_sme(kernel_height, kernel_width, kernel_height, kernel_width, channels, weightTempPtr, bias, mPackedRhs.get()->host()); + b->onReleaseBuffer(mWeightTemp.get(), Backend::STATIC); +} + +ErrorCode KleidiAIConvolutionDepthwise::KleidiAIDepthwiseExecution::onResize(const std::vector& inputs, + const std::vector& outputs) { + CPUConvolution::onResize(inputs, outputs); + auto input = inputs[0]; + auto output = outputs[0]; + TensorUtils::getDescribe(&mOutputNHWC)->dimensionFormat = MNN_DATA_FORMAT_NHWC; + mOutputNHWC.buffer().dimensions = 4; + mOutputNHWC.buffer().dim[0].extent = output->batch(); + mOutputNHWC.buffer().dim[1].extent = output->height(); + mOutputNHWC.buffer().dim[2].extent = output->width(); + mOutputNHWC.buffer().dim[3].extent = output->channel(); + mOutputNHWC.buffer().type = output->getType(); + auto success = backend()->onAcquireBuffer(&mOutputNHWC, Backend::DYNAMIC); + if (!success) { + return OUT_OF_MEMORY; + } + + TensorUtils::getDescribe(&mInputNHWC)->dimensionFormat = MNN_DATA_FORMAT_NHWC; + mInputNHWC.buffer().dimensions = 4; + mInputNHWC.buffer().dim[0].extent = input->batch(); + mInputNHWC.buffer().dim[1].extent = input->height(); + mInputNHWC.buffer().dim[2].extent = input->width(); + mInputNHWC.buffer().dim[3].extent = input->channel(); + mInputNHWC.buffer().type = input->getType(); + success = backend()->onAcquireBuffer(&mInputNHWC, Backend::DYNAMIC); + if (!success) { + return OUT_OF_MEMORY; + } + + backend()->onReleaseBuffer(&mOutputNHWC, Backend::DYNAMIC); + backend()->onReleaseBuffer(&mInputNHWC, Backend::DYNAMIC); + return NO_ERROR; +} + +ErrorCode 
KleidiAIConvolutionDepthwise::KleidiAIDepthwiseExecution::onExecute(const std::vector& inputs, + const std::vector& outputs) { + auto inputTensor = inputs[0]; + auto outputTensor = outputs[0]; + const auto srcOrigin = mInputNHWC.host(); + auto dstOrigin = mOutputNHWC.host(); + auto postData = getPostParameters(); + auto output_height = outputTensor->height(); + auto core = static_cast(backend())->functions(); + auto batch = inputTensor->batch(); + + MNN_CONCURRENCY_BEGIN(tId, mNumber) { + CPUTensorConverter::convert(inputTensor, &mInputNHWC, core, tId, mNumber); + } + MNN_CONCURRENCY_END(); + + //CPUTensorConverter::convert(inputTensor, &mInputNHWC, core); + + constexpr size_t rows_handled = 4; // no of rows kernel handles each time. + for(size_t b = 0; b < batch; b++) { + const auto srcOriginBatch = srcOrigin + b * inputTensor->height() * inputTensor->width() * inputTensor->channel() * sizeof(float); + auto dstOriginBatch = dstOrigin + b * outputTensor->height() * outputTensor->width() * outputTensor->channel() * sizeof(float); + for (size_t out_row = 0; out_row < output_height; out_row += rows_handled) { + // Variables below used to calculate start of input pointer. + const int start_in_row = out_row - mPadY; + const size_t pad_top = (start_in_row < 0) ? (-start_in_row) : 0; + const size_t in_row = (start_in_row < 0) ? 0 : start_in_row; + + // Calculate row strides for pointer. + const size_t in_row_stride_bytes = (inputTensor->width() * inputTensor->channel() * sizeof(float)); + const size_t out_row_stride_bytes = (outputTensor->width() * outputTensor->channel() * sizeof(float)); + + // Number of input rows that can be read, number of output rows to calculate. + const size_t valid_input_rows = (in_row < inputTensor->height()) ? (inputTensor->height() - in_row) : 0; + const size_t valid_out_rows = (outputTensor->height() - out_row); + + // Increment output/input pointers according to tile being calculated. 
+ auto out_offset = kai_get_dst_offset_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla( + out_row, out_row_stride_bytes); + auto in_offset = kai_get_src_offset_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla( + in_row, in_row_stride_bytes); + const auto inptr = (uint8_t*)srcOriginBatch + in_offset; + auto outptr = (uint8_t*)dstOriginBatch + out_offset; + + // NOTE: Kernel expects strides to be passed as bytes. + // f32_f32_f32p1vlx1b -> f32 output, f32 LHS, packed F32 rhs (with bias) as 1VL blocks. + // 3x3_s : 3x3 filter with stride 1 + // 4xc : 4 rows across all output channels (plane c) is produced. + kai_run_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla( + inptr, mPackedRhs.get()->host(), outptr, in_row_stride_bytes, inputTensor->channel() * sizeof(float), + out_row_stride_bytes, outputTensor->channel() * sizeof(float), valid_input_rows, valid_out_rows, + mPadX, pad_top, 0.0f, postData[2], postData[3]); + } + } + + MNN_CONCURRENCY_BEGIN(tId, mNumber) { + CPUTensorConverter::convert(&mOutputNHWC, outputTensor, core, tId, mNumber); + } + MNN_CONCURRENCY_END(); + //CPUTensorConverter::convert(&mOutputNHWC, outputTensor, core); + return NO_ERROR; +} + +} // namespace MNN + +#endif // defined(MNN_KLEIDIAI_ENABLED) diff --git a/source/backend/cpu/KleidiAIConvolutionDepthwise.hpp b/source/backend/cpu/KleidiAIConvolutionDepthwise.hpp new file mode 100644 index 0000000000..1a1ad7281b --- /dev/null +++ b/source/backend/cpu/KleidiAIConvolutionDepthwise.hpp @@ -0,0 +1,37 @@ +#ifndef KleidiAIConvolutionDepthwise_hpp +#define KleidiAIConvolutionDepthwise_hpp + +#ifdef MNN_KLEIDIAI_ENABLED + +#include "core/AutoStorage.h" +#include "backend/cpu/CPUConvolution.hpp" +#include "backend/cpu/compute/ConvolutionIntFactory.hpp" +#include "kai_rhs_dwconv_pack_x32p1vlx1b_x32_x32_sme.h" +#include "kai_dwconv_clamp_f32_f32_f32p1vlx1b_3x3_s1_4xc_sme2_mla.h" +#include "backend/cpu/CPUTensorConvert.hpp" + +namespace MNN { +class KleidiAIConvolutionDepthwise { + public: + 
class KleidiAIDepthwiseExecution : public CPUConvolution { + public: + KleidiAIDepthwiseExecution(const Convolution2DCommon *common, Backend *b, const float *originWeight, + size_t originWeightSize, const float *bias, size_t biasSize); + virtual ~KleidiAIDepthwiseExecution() = default; + virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; + virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; + + private: + int mNumber = 1; + std::shared_ptr mPackedRhs; + std::shared_ptr mWeightTemp; + Tensor mOutputNHWC; + Tensor mInputNHWC; + }; +}; + +} // namespace MNN + +#endif // defined(MNN_KLEIDIAI_ENABLED) + +#endif /* KleidiAIConvolutionDepthwise_hpp */ diff --git a/source/backend/cpu/arm/CommonOptFunctionNeon.cpp b/source/backend/cpu/arm/CommonOptFunctionNeon.cpp index b85069c0ca..4bc3e01315 100644 --- a/source/backend/cpu/arm/CommonOptFunctionNeon.cpp +++ b/source/backend/cpu/arm/CommonOptFunctionNeon.cpp @@ -8,6 +8,317 @@ extern "C" { void MNNTranspose32Bit4x4(int32_t* dstO, const int32_t* srcO, int32_t* dim); void MNNTranspose16Bit8x8(int16_t* dstO, const int16_t* srcO, int32_t* dim); } + +static inline float vmaxvq_f32_compat(float32x4_t v) { + #if defined(__aarch64__) + return vmaxvq_f32(v); + #else + float32x2_t p = vpmax_f32(vget_low_f32(v), vget_high_f32(v)); + p = vpmax_f32(p, p); + return vget_lane_f32(p, 0); + #endif + } + + static inline float vminvq_f32_compat(float32x4_t v) { + #if defined(__aarch64__) + return vminvq_f32(v); + #else + float32x2_t step1 = vpmin_f32(vget_low_f32(v), vget_high_f32(v)); + step1 = vpmin_f32(step1, step1); + return vget_lane_f32(step1, 0); + #endif + } + + static inline float vaddvq_f32_compat(float32x4_t v) { + #if defined(__aarch64__) + return vaddvq_f32(v); + #else + float32x2_t p = vpadd_f32(vget_low_f32(v), vget_high_f32(v)); + p = vpadd_f32(p, p); + return vget_lane_f32(p, 0); + #endif + } + +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE +#ifdef 
__aarch64__ +void MNNQuantAttentionKey(int8_t* dst, const float* source, float* sumKeyPtr, float* maxKeyPtr, int32_t* params) { + int32_t kvNumHead = params[0]; + int32_t seqLen = params[1]; + int32_t headDim = params[2]; + int32_t blockNum = params[3]; + int32_t eP = params[4]; + int32_t lP = params[5]; + int32_t hP = params[6]; + int32_t pastLength = params[7]; + int32_t kvHeadIdx = params[8]; + + auto blockHeadDim = UP_DIV(headDim, blockNum); + auto weightStride1 = ROUND_UP(blockHeadDim, lP) * hP; + auto weightStride2 = lP * hP; + auto packedWeightStride1 = weightStride1 + 2 * 4 * hP; + + int8_t tempBuffer[8]; + + if (seqLen > 1) { + // get max + for (int s = 0; s < seqLen; ++s) { + const float* keySrc = source + s * kvNumHead * headDim + kvHeadIdx * headDim; + int d = 0; + for (; d <= headDim - 8; d += 8) { + float32x4_t max_vec0 = vld1q_f32(maxKeyPtr + d); + float32x4_t max_vec1 = vld1q_f32(maxKeyPtr + d + 4); + float32x4_t src_vec0 = vld1q_f32(keySrc + d); + float32x4_t src_vec1 = vld1q_f32(keySrc + d + 4); + max_vec0 = vmaxq_f32(max_vec0, src_vec0); + max_vec1 = vmaxq_f32(max_vec1, src_vec1); + vst1q_f32(maxKeyPtr + d, max_vec0); + vst1q_f32(maxKeyPtr + d + 4, max_vec1); + } + for (; d <= headDim - 4; d += 4) { + float32x4_t max_vec = vld1q_f32(maxKeyPtr + d); + float32x4_t src_vec = vld1q_f32(keySrc + d); + max_vec = vmaxq_f32(max_vec, src_vec); + vst1q_f32(maxKeyPtr + d, max_vec); + } + for (; d < headDim; ++d) { + maxKeyPtr[d] = ALIMAX(maxKeyPtr[d], keySrc[d]); + } + } + } + + for (int s = 0; s < seqLen; s++) { + const float* keySrc = source + s * kvNumHead * headDim + kvHeadIdx * headDim; + + float32x4_t min_vec = vdupq_n_f32(keySrc[0] - maxKeyPtr[0]); + float32x4_t max_vec = vdupq_n_f32(keySrc[0] - maxKeyPtr[0]); + + int d = 0; + for (; d <= headDim - 4; d += 4) { + float32x4_t src_vec = vld1q_f32(keySrc + d); + float32x4_t max_key_vec = vld1q_f32(maxKeyPtr + d); + float32x4_t keydata_vec = vsubq_f32(src_vec, max_key_vec); + + min_vec = 
vminq_f32(min_vec, keydata_vec); + max_vec = vmaxq_f32(max_vec, keydata_vec); + } + // Reduction + float minKey = vminvq_f32_compat(min_vec); + float maxKey = vmaxvq_f32_compat(max_vec); + + // remain headDim + for (; d < headDim; ++d) { + auto keydata = keySrc[d] - maxKeyPtr[d]; + minKey = ALIMIN(minKey, keydata); + maxKey = ALIMAX(maxKey, keydata); + } + + int outIndex = (pastLength + s) / hP; + int inIndex = (pastLength + s) % hP; + + float range = maxKey - minKey; + float quantScaleVal = 0; + float biasVal = minKey + 128.0f * (range) / 255.0f; + if (range <= 1e-6f) { + quantScaleVal = 0.0f; + } else { + quantScaleVal = 255.0f / range; + } + + for (int k = 0; k < blockNum; ++k) { + int8_t* weightDstBase = dst + outIndex * blockNum * packedWeightStride1 + k * packedWeightStride1; + float* scaleDst = (float*)(weightDstBase + weightStride1); + float* biasDst = scaleDst + hP; + + scaleDst[inIndex] = range / 255.f; + biasDst[inIndex] = biasVal; + + float32x4_t scaleVec = vdupq_n_f32(quantScaleVal); + float32x4_t negBiasVec = vdupq_n_f32(-minKey); + float32x4_t neg128Vec = vdupq_n_f32(-128.0f); + + const float* currentKeyBlock = keySrc + k * blockHeadDim; + const float* currentMaxBlock = maxKeyPtr + k * blockHeadDim; + + int32x4_t sumInt32_0 = vdupq_n_s32(0); + int32x4_t sumInt32_1 = vdupq_n_s32(0); + int headDimIdx = 0; + for (; headDimIdx <= blockHeadDim - 8; headDimIdx += 8) { + float32x4_t srcVec0 = vld1q_f32(currentKeyBlock + headDimIdx); + float32x4_t srcVec1 = vld1q_f32(currentKeyBlock + headDimIdx + 4); + float32x4_t maxVec0 = vld1q_f32(currentMaxBlock + headDimIdx); + float32x4_t maxVec1 = vld1q_f32(currentMaxBlock + headDimIdx + 4); + + float32x4_t keyData0 = vsubq_f32(srcVec0, maxVec0); + float32x4_t keyData1 = vsubq_f32(srcVec1, maxVec1); + + keyData0 = vaddq_f32(keyData0, negBiasVec); + keyData1 = vaddq_f32(keyData1, negBiasVec); + + keyData0 = vmulq_f32(keyData0, scaleVec); + keyData1 = vmulq_f32(keyData1, scaleVec); + + keyData0 = vaddq_f32(keyData0, 
neg128Vec); + keyData1 = vaddq_f32(keyData1, neg128Vec); + + int32x4_t s32_0 = vcvtaq_s32_f32(keyData0); + int32x4_t s32_1 = vcvtaq_s32_f32(keyData1); + + sumInt32_0 = vaddq_s32(sumInt32_0, s32_0); + sumInt32_1 = vaddq_s32(sumInt32_1, s32_1); + + int16x4_t s16_0 = vmovn_s32(s32_0); + int16x4_t s16_1 = vmovn_s32(s32_1); + + int16x8_t s16Combined = vcombine_s16(s16_0, s16_1); + + int8x8_t s8Vec = vqmovn_s16(s16Combined); + + if (lP == 8) { + int i = headDimIdx / lP; + int8_t* dstPtr = weightDstBase + i * weightStride2 + inIndex * lP; + vst1_s8(dstPtr, s8Vec); + } else if (lP == 4) { + vst1_s8(tempBuffer, s8Vec); + int iLow = headDimIdx / lP; + int iHigh = (headDimIdx + 4) / lP; + + int8_t* dstPtrLow = weightDstBase + iLow * weightStride2 + inIndex * lP; + int8_t* dstPtrHigh = weightDstBase + iHigh * weightStride2 + inIndex * lP; + + std::memcpy(dstPtrLow, tempBuffer, 4); + std::memcpy(dstPtrHigh, tempBuffer + 4, 4); + } else { + vst1_s8(tempBuffer, s8Vec); + for (int k = 0; k < 8; ++k) { + int headDimCurr = headDimIdx + k; + int i = headDimCurr / lP; + int j = headDimCurr % lP; + weightDstBase[i * weightStride2 + inIndex * lP + j] = tempBuffer[k]; + } + } + } + + int32_t sumInt32 = vaddvq_s32(sumInt32_0) + vaddvq_s32(sumInt32_1); + + // remain L + for (; headDimIdx < blockHeadDim; ++headDimIdx) { + int i = headDimIdx / lP; + int j = headDimIdx % lP; + float keyVal = currentKeyBlock[headDimIdx] - currentMaxBlock[headDimIdx]; + float quant_val = (keyVal - minKey) * quantScaleVal - 128.0f; + int32_t rounded_val = static_cast(roundf(quant_val)); + int8_t finalVal = static_cast(std::max(-128, std::min(127, rounded_val))); + weightDstBase[i * weightStride2 + inIndex * lP + j] = finalVal; + sumInt32 += finalVal; + } + + // store sum + sumKeyPtr[outIndex * hP + inIndex] = (float)sumInt32 * range / 255.f + (minKey * blockHeadDim + 128.0f * range * blockHeadDim / 255.0f); + } + } +} + +void MNNQuantAttentionValue(int8_t* dst, const float* source, float* valueSum, int32_t* 
params) { + // float value src : [kvSeq,kvNumHead,headDim] + // int8_t value dest: [updiv(maxLength,flashAttentionBlockKv), updiv(headDim,hp),updiv(flashAttentionBlockKv,lp),hp,lp] + // float value sum: [updiv(maxLength,flashAttentionBlockKv), roundup(headDim,hp)] + int32_t kvNumHead = params[0]; + int32_t seqLen = params[1]; + int32_t headDim = params[2]; + int32_t blockNum = params[3]; + int32_t maxLength = params[4]; + + int32_t lP = params[5]; + int32_t hP = params[6]; + int32_t pastLength = params[7]; + int32_t kvHeadIdx = params[8]; + + int32_t flashAttentionBlockKv = params[9]; + + auto blockKvseq = UP_DIV(seqLen + pastLength, blockNum); + auto weightStride2 = lP * hP; + auto weightStride1 = UP_DIV(flashAttentionBlockKv, lP) * weightStride2; + + auto packedStride1 = (int)(weightStride1 + 2 * hP * sizeof(float)); + auto packedStride0 = UP_DIV(headDim, hP) * packedStride1; + + auto srcStride0 = kvNumHead * headDim; + + auto sourceFp32 = (float*)source; + + // quant scale & bias + if (pastLength == 0) { + for (int d = 0; d < headDim; ++d) { + float* scalePtr = (float*)(dst + (d / hP) * packedStride1 + weightStride1) + (d % hP); + float* biasPtr = scalePtr + hP; + + // find min,max + float dMax = sourceFp32[d + kvHeadIdx * headDim]; + float dMin = dMax; + for (int s = 0; s < seqLen; ++s) { + float data = sourceFp32[s * srcStride0 + d + kvHeadIdx * headDim]; + dMax = ALIMAX(dMax, data); + dMin = ALIMIN(dMin, data); + } + + // scale & bias + float range = dMax - dMin; + if (range < 1e-6) { + scalePtr[0] = 0.f; + biasPtr[0] = dMax; + } else { + float scale = range / 255.f; + float bias = range / 255.f * 128.f + dMin; + scalePtr[0] = scale; + biasPtr[0] = bias; + } + } + } + + // copy the scale&bias to each blockKv + // pastLength == 0: First time prefill + // pastLength % flashAttentionBlockKv == 0: Open a new blockKv + if (pastLength == 0 || (pastLength % flashAttentionBlockKv) == 0) { + int32_t d0 = UP_DIV(maxLength, flashAttentionBlockKv); + int32_t d1 = 
UP_DIV(headDim, hP); + for (int k = 0; k < d0; ++k) { + for (int r = 0; r < d1; ++r) { + float* scalePtr = (float*)(dst + k * packedStride0 + r * packedStride1 + weightStride1); + float* biasPtr = scalePtr + hP; + memcpy(scalePtr, dst + r * packedStride1 + weightStride1, hP * sizeof(float)); + memcpy(biasPtr, dst + r * packedStride1 + weightStride1 + hP * sizeof(float), hP * sizeof(float)); + } + } + } + + // Quant fp16 + for (int d = 0; d < headDim; ++d) { + // dst address + int idxBase = (d / hP) * packedStride1 + (d % hP) * lP; + int8_t* dstBase = dst + idxBase; + float* scaleBase = (float*)(dst + (d / hP) * packedStride1 + weightStride1) + (d % hP); + float* biasBase = scaleBase + hP; + float* sumBase = valueSum + (d / hP) * hP + (d % hP); + + float qscale = scaleBase[0] < 1e-6 ? 0 : 1.0f / scaleBase[0]; + float qbias = scaleBase[0] < 1e-6 ? 0 : (-biasBase[0] / scaleBase[0]); + // quant + for (int s = 0; s < seqLen; ++s) { + int kvSeqIndx = s + pastLength; + int idxInner = (kvSeqIndx / flashAttentionBlockKv) * packedStride0 + (kvSeqIndx % flashAttentionBlockKv) / lP * weightStride2 + (kvSeqIndx % flashAttentionBlockKv) % lP; + float xf = sourceFp32[s * srcStride0 + d + kvHeadIdx * headDim]; + int8_t xq = ALIMAX(ALIMIN(127, static_cast(roundf(xf * qscale + qbias))), -128); + dstBase[idxInner] = xq; + + // sum + int idxSum = (kvSeqIndx / flashAttentionBlockKv) * ROUND_UP(headDim, hP); + sumBase[idxSum] += ((float)xq * scaleBase[0] + biasBase[0]); + } + } +} +#endif // __aarch64__ +#endif // MNN_SUPPORT_TRANSFORMER_FUSE + void MNNTranspose32Bit(int32_t* dstO, const int32_t* srcO, int32_t* dim) { int w = dim[0]; int h = dim[1]; @@ -87,46 +398,35 @@ void MNNTranspose16Bit(int16_t* dstO, const int16_t* srcO, int32_t* dim) { #define EXP_APPROX_C0 vdupq_n_f32(1.0f) #ifndef __aarch64__ -static inline float32x4_t vrndaq_f32_compat(float32x4_t val) { - const float32x4_t v_zero = vdupq_n_f32(0.0f); - - float32x4_t v_truncated = vcvtq_f32_s32(vcvtq_s32_f32(val)); - - 
uint32x4_t v_is_positive_frac = vcgtq_f32(val, v_truncated); - uint32x4_t v_is_negative_frac = vcltq_f32(val, v_truncated); - - float32x4_t v_offset = vbslq_f32(v_is_positive_frac, vdupq_n_f32(1.0f), v_zero); - v_offset = vbslq_f32(v_is_negative_frac, vdupq_n_f32(-1.0f), v_offset); - - return vaddq_f32(v_truncated, v_offset); +static inline float32x4_t vrndaq_f32_compat(float32x4_t x) { + float32x4_t sign = vbslq_f32(vdupq_n_u32(0x80000000), x, vdupq_n_f32(0.0f)); + return vcvtq_f32_s32(vcvtq_s32_f32(vaddq_f32(x, vbslq_f32(vcltq_f32(x, vdupq_n_f32(0.0f)), vdupq_n_f32(-0.5f), vdupq_n_f32(0.5f))))); } - #endif static inline float32x4_t expApprox(float32x4_t x) { x = vminq_f32(vmaxq_f32(x, EXP_APPROX_MIN_INPUT), EXP_APPROX_MAX_INPUT); - + float32x4_t k_float; float32x4_t r; float32x4_t exp_r; - #if defined(__aarch64__) - // 1. x = k * ln(2) + r k_float = vrndaq_f32(vmulq_f32(x, EXP_APPROX_LN2_INV)); - + // r = x - k * ln(2) r = vfmsq_f32(x, k_float, EXP_APPROX_LN2); - // 2. c0 + r*(c1 + r*(c2 + r*(c3 + r*c4))) (Horner's method) - exp_r = vfmaq_f32(EXP_APPROX_C3, EXP_APPROX_C4, r); // c3 + c4*r - exp_r = vfmaq_f32(EXP_APPROX_C2, exp_r, r); // c2 + r*(...) - exp_r = vfmaq_f32(EXP_APPROX_C1, exp_r, r); // c1 + r*(...) - exp_r = vfmaq_f32(EXP_APPROX_C0, exp_r, r); // c0 + r*(...) 
+ // P(r) = (c0 + c2*r^2 + c4*r^4) + r*(c1 + c3*r^2) + float32x4_t r2 = vmulq_f32(r, r); + float32x4_t p_odd = vfmaq_f32(EXP_APPROX_C1, EXP_APPROX_C3, r2); + float32x4_t p_even = vfmaq_f32(EXP_APPROX_C0, EXP_APPROX_C2, r2); + p_even = vfmaq_f32(p_even, EXP_APPROX_C4, vmulq_f32(r2, r2)); + exp_r = vfmaq_f32(p_even, p_odd, r); #else k_float = vrndaq_f32_compat(vmulq_f32(x, EXP_APPROX_LN2_INV)); - + r = vsubq_f32(x, vmulq_f32(k_float, EXP_APPROX_LN2)); @@ -140,8 +440,7 @@ static inline float32x4_t expApprox(float32x4_t x) { int32x4_t k_int = vcvtq_s32_f32(k_float); int32x4_t k_shifted = vshlq_n_s32(k_int, 23); - float32x4_t result = vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(exp_r), k_shifted)); - return result; + return vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(exp_r), k_shifted)); } void MNNExpC8(float* dst, const float* src, float* offset, const float* parameters, size_t countC8) { @@ -195,7 +494,7 @@ void MNNExp(float* destPtr, const float* srcPtr, float* offset, size_t size) { srcPtr += 8; destPtr += 8; size -= 8; - + } while (size >= 4) { float32x4_t srcVec0 = vld1q_f32(srcPtr); @@ -242,7 +541,7 @@ void MNNExp(float* destPtr, const float* srcPtr, float* offset, size_t size) { srcPtr += 8; destPtr += 8; size -= 8; - + } while (size >= 4) { float32x4_t srcVec0 = vld1q_f32(srcPtr); @@ -424,100 +723,152 @@ void MNNPackForMatMul_A(float* dst, const float* src, size_t E, size_t L, size_t } } -void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) { +void MNNSoftmax(float* softmaxDst, const float* softmaxSrc, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize, int kvSeqOffset, int validOffset, int pack, bool mask) { + + // source shape: [reduceSizeOuter, outside, reduceSizeInner] + // for C4, [up_div(reduceSize,4), outside,4] => reduceSizeOuter=up_div(reduceSize,4), reduceSizeInner=4 + // for C, [outside, reduceSize] => 
reduceSizeOuter=1, reduceSizeInner=reduceSize + + const int packUnit = 4; + int reduceSizeOuter = 1; + int reduceSizeInner = reduceSize; + int stride0 = packUnit; + if (pack > 1) { + reduceSizeOuter = UP_DIV(reduceSize, pack); + reduceSizeInner = pack; + stride0 = outside * reduceSizeInner; + } + for (int k = 0; k < outside; ++k) { - auto source = input + k * reduceSize; - auto dest = softmaxDst + k * reduceSize; + if (mask && kvSeqOffset > k + validOffset) { + if (updateScale){ + updateScale[k] = 1; + } + for (int j = 0; j < reduceSizeOuter; ++j) { + int i = 0; + for (; i < reduceSizeInner; i += packUnit) { + auto destPtr = softmaxDst + j * stride0 + k * reduceSizeInner + i; + vst1q_f32(destPtr, vdupq_n_f32(0.0f)); + } + if (i < reduceSizeInner) { + memset(softmaxDst + j * stride0 + k * reduceSizeInner + i, 0, (reduceSizeInner - i) * sizeof(float)); + } + } + continue; + } - // new max - auto srcPtr = source; - auto size = reduceSize; - float32x4_t maxVec0 = vdupq_n_f32(source[0]); - auto maxVec1 = maxVec0; + const int validReduceSize = mask ? ALIMIN(reduceSize, k + (validOffset + 1) - kvSeqOffset) : reduceSize; + const int remain = validReduceSize % packUnit; + const int sizeDiv = validReduceSize / packUnit; - float oldMax = source[0]; + // 1. 
newMax + float oldMax = std::numeric_limits::lowest(); if (runningMax) { oldMax = runningMax[k]; } - while (size >= 8) { - float32x4_t srcVec0 = vld1q_f32(srcPtr); - float32x4_t srcVec1 = vld1q_f32(srcPtr + 4); - - maxVec0 = vmaxq_f32(maxVec0, srcVec0); - maxVec1 = vmaxq_f32(maxVec1, srcVec1); + float newMax = std::numeric_limits::lowest(); - srcPtr += 8; - size -= 8; + for (int j = 0; j < sizeDiv; ++j) { + auto srcPtr = softmaxSrc + j * stride0 + k * reduceSizeInner; + float32x4_t srcVec = vld1q_f32(srcPtr); + newMax = ALIMAX(newMax, vmaxvq_f32_compat(srcVec)); } - while (size >= 4) { - float32x4_t srcVec0 = vld1q_f32(srcPtr); - maxVec0 = vmaxq_f32(maxVec0, srcVec0); - srcPtr += 4; - size -= 4; + if (remain > 0) { + auto srcPtr = softmaxSrc + sizeDiv * stride0 + k * reduceSizeInner; + for (int i = 0; i < remain; ++i) { + newMax = ALIMAX(newMax, srcPtr[i]); + } } - maxVec0 = vmaxq_f32(maxVec0, maxVec1); - float32x2_t maxP = vpmax_f32(vget_low_f32(maxVec0), vget_high_f32(maxVec0)); - maxP = vpmax_f32(maxP, maxP); - auto newMax = vget_lane_f32(maxP, 0); - - while (size > 0) { - newMax = ALIMAX(newMax, srcPtr[0]); - srcPtr += 1; - size -= 1; + const float finalMax = ALIMAX(oldMax, newMax); + const float32x4_t finalMaxVec = vdupq_n_f32(finalMax); + + // 2. 
exp(x - finalMax) + float sum = 0.0f; + float32x4_t sumVec = vdupq_n_f32(0.0f); + + for (int j = 0; j < sizeDiv; ++j) { + auto idx = j * stride0 + k * reduceSizeInner; + auto srcPtr = softmaxSrc + idx; + auto dstPtr = softmaxDst + idx; + + float32x4_t srcVec = vld1q_f32(srcPtr); + // sub max + srcVec = vsubq_f32(srcVec, finalMaxVec); + // exp + srcVec = expApprox(srcVec); + // sum + sumVec = vaddq_f32(sumVec, srcVec); + // store + vst1q_f32(dstPtr, srcVec); } - newMax = ALIMAX(oldMax, newMax); - srcPtr = source; - auto destPtr = dest; - size = reduceSize; + if (remain > 0) { + auto idx = sizeDiv * stride0 + k * reduceSizeInner; + auto srcPtr = softmaxSrc + idx; + auto dstPtr = softmaxDst + idx; - float exprOffset[4] = { - 1.0f, - 0.0f, - 0.0f, - 0.0f - }; - exprOffset[2] = -newMax; + float tempDst[4] = {0.0f, 0.0f, 0.0f, 0.0f}; - // expf(xi-newmax) & new sum - MNNExp(destPtr, srcPtr, exprOffset, size); + for(int i = 0; i < remain; ++i) { + float val = expf(srcPtr[i] - finalMax); + sum += val; + tempDst[i] = val; + } + vst1q_f32(dstPtr, vld1q_f32(tempDst)); + } + sum += vaddvq_f32_compat(sumVec); + + // 3. 
if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) { - // update runningSum, runningMax, scale=expf(oldMax-newMax) - float newSum = exprOffset[3]; - runningSum[k] = runningSum[k] * expf(oldMax - newMax) + newSum; - runningMax[k] = newMax; - updateScale[k] = expf(oldMax - newMax); + // update runningSum, runningMax, scale + float scaleForSum = expf(oldMax - finalMax); + runningSum[k] = runningSum[k] * scaleForSum + sum; + runningMax[k] = finalMax; + updateScale[k] = scaleForSum; } else { - // Normalize - float sum = exprOffset[3]; + // Normalization + if (runningMax != nullptr && runningSum != nullptr) { + sum += runningSum[k] * expf(oldMax - finalMax); + } float scale = 1.0f / (sum + 1e-20f); - int count = reduceSize; - auto pDest = dest; + float32x4_t scale_vec = vdupq_n_f32(scale); - float32x4_t scaleVec = vdupq_n_f32(scale); - while (count >= 4) { + for (int j = 0; j < sizeDiv; ++j) { + auto pDest = softmaxDst + j * stride0 + k * reduceSizeInner; float32x4_t data = vld1q_f32(pDest); - data = vmulq_f32(data, scaleVec); + data = vmulq_f32(data, scale_vec); vst1q_f32(pDest, data); - - pDest += 4; - count -= 4; } + if (remain > 0) { + auto pDest = softmaxDst + sizeDiv * stride0 + k * reduceSizeInner; + for (int i = 0; i < remain; ++i) { + pDest[i] = pDest[i] * scale; + } + } + } - while (count > 0) { - *pDest *= scale; - pDest++; - count--; + // 4. 
memset 0 + if (pack > 1) { + if (validReduceSize % packUnit > 0) { + memset(softmaxDst + sizeDiv * stride0 + k * reduceSizeInner + (validReduceSize % packUnit), 0, (packUnit - (validReduceSize % packUnit)) * sizeof(float)); + } + auto validDiv4 = UP_DIV(validReduceSize, packUnit); + auto allDiv4 = UP_DIV(reduceSize, packUnit); + for (int j = validDiv4; j < allDiv4; ++j) { + auto destPtr = softmaxDst + j * stride0 + k * reduceSizeInner; + memset(destPtr, 0, packUnit * sizeof(float)); } + } else { + memset(softmaxDst + k * reduceSizeInner + validReduceSize, 0, (reduceSize - validReduceSize) * sizeof(float)); } } + return; } - #ifndef MNN_USE_NEON void MNNPackedSparseMatMulEpx1(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, unsigned int* NNZMap, int* dataOffsetMap) { diff --git a/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_ARMV82_Unit.S b/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_ARMV82_Unit.S index 94eb5cb2ba..d87090d8cc 100644 --- a/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_ARMV82_Unit.S +++ b/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_ARMV82_Unit.S @@ -145,26 +145,28 @@ ldr x27, [x6, #88] // inputBias ldr x10, [x6, #96] // accumBuffer ldr x26, [x6, #64] // blockNum lsl x22, x7, #2 // eDest * SRC_UNIT -mov x14, #-32 + +mov x25, #-32 add x23, x6, #16 // int8 max ptr cbz x28, TILE_12 ldr x23, [x6, #56] // fp32minmax cbz x27, TILE_12 -mov x14, #16 +sub x25, x22, #32 TILE_12: cmp x7, #12 blt TILE_8 sub x4, x4, #128 + mov x12, x2 + mov x6, x0 + mov x14, x5 mov x20, x9 mov x15, x8 // input kernel sum - mov x6, x27 // input dequant bias mov x21, x24 // input dequant scale - mov x12, #-320 cbnz x28, L8LoopDz_TILE_12 add x4, x4, #128 // int8 do not change L8LoopDz_TILE_12: - cmp x5, #2 + cmp x14, #2 blt L4LoopDz_TILE_12 mov x11, x1 mov x19, #0 @@ -179,8 +181,8 @@ TILE12_BLOCKNUM: SET_BIAS v28, v29, v30, v31 L8LoopSz_TILE_12: - ld1 {v3.16b, v4.16b}, [x2], 
#32 // weight - ld1 {v0.16b, v1.16b, v2.16b}, [x11], #48 // src + ld1 {v3.16b, v4.16b}, [x12], #32 // weight + ld1 {v0.16b, v1.16b, v2.16b}, [x11], x22 // src .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] @@ -190,6 +192,7 @@ TILE12_BLOCKNUM: .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] + .inst 0x4f82e070 // sdot v16.4s, v3.16b, v2.4b[0] .inst 0x4fa2e071 // sdot v17.4s, v3.16b, v2.4b[1] .inst 0x4f82e872 // sdot v18.4s, v3.16b, v2.4b[2] @@ -213,9 +216,8 @@ TILE12_BLOCKNUM: L8LoopSzEnd_TILE_12: L8Tile12Quan: - ld1 {v0.4s, v1.4s}, [x2], #32 // scale - ld1 {v2.4s, v3.4s, v4.4s}, [x8], #48 // input kernel sum - ld1 {v5.4s, v6.4s}, [x2], #32 // weight quan zeropoint + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x12], #64 // weight scale&bias + ld1 {v4.4s, v5.4s, v6.4s}, [x8], x22 // input kernel sum Int32ToFloat v8, v9, v10, v11 Int32ToFloat v12, v13, v14, v15 Int32ToFloat v16, v17, v18, v19 @@ -230,9 +232,9 @@ TILE12_BLOCKNUM: MUL_SCALE v1, v24, v25, v26, v27 MUL_SCALE v1, v28, v29, v30, v31 - cbz x21, TILE12_L8_MLA + cbz x21, TILE12_L8_MLA_TERM ld1 {v0.4s, v1.4s}, [x24], #32 - ld1 {v7.4s}, [x24], x14 + ld1 {v7.4s}, [x24], x25 MUL_EXTRA_SCALE v0, v8, v9, v10, v11 MUL_EXTRA_SCALE v1, v12, v13, v14, v15 MUL_EXTRA_SCALE v7, v16, v17, v18, v19 @@ -240,36 +242,36 @@ TILE12_BLOCKNUM: MUL_EXTRA_SCALE v1, v24, v25, v26, v27 MUL_EXTRA_SCALE v7, v28, v29, v30, v31 - TILE12_L8_MLA: - MLA_WEIGHTZERO v8, v2, v5, 0 // tile:0, oc:0-3 - MLA_WEIGHTZERO v9, v2, v5, 1 // tile:1, oc:0-3 - MLA_WEIGHTZERO v10, v2, v5, 2 // tile:2, oc:0-3 - MLA_WEIGHTZERO v11, v2, v5, 3 // tile:3, oc:0-3 - MLA_WEIGHTZERO v12, v3, v5, 0 // tile:4, oc:0-3 - MLA_WEIGHTZERO v13, v3, v5, 1 // tile:5, oc:0-3 - MLA_WEIGHTZERO v14, v3, v5, 2 // tile:6, oc:0-3 - MLA_WEIGHTZERO v15, v3, v5, 3 // tile:7, oc:0-3 - 
MLA_WEIGHTZERO v16, v4, v5, 0 // tile:8, oc:0-3 - MLA_WEIGHTZERO v17, v4, v5, 1 // tile:9, oc:0-3 - MLA_WEIGHTZERO v18, v4, v5, 2 // tile:10, oc:0-3 - MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3 - - MLA_WEIGHTZERO v20, v2, v6, 0 // tile:0, oc:4-7 - MLA_WEIGHTZERO v21, v2, v6, 1 // tile:1, oc:4-7 - MLA_WEIGHTZERO v22, v2, v6, 2 // tile:2, oc:4-7 - MLA_WEIGHTZERO v23, v2, v6, 3 // tile:3, oc:4-7 - MLA_WEIGHTZERO v24, v3, v6, 0 // tile:4, oc:4-7 - MLA_WEIGHTZERO v25, v3, v6, 1 // tile:5, oc:4-7 - MLA_WEIGHTZERO v26, v3, v6, 2 // tile:6, oc:4-7 - MLA_WEIGHTZERO v27, v3, v6, 3 // tile:7, oc:4-7 - MLA_WEIGHTZERO v28, v4, v6, 0 // tile:8, oc:4-7 - MLA_WEIGHTZERO v29, v4, v6, 1 // tile:9, oc:4-7 - MLA_WEIGHTZERO v30, v4, v6, 2 // tile:10, oc:4-7 - MLA_WEIGHTZERO v31, v4, v6, 3 // tile:11, oc:4-7 + TILE12_L8_MLA_TERM: + MLA_WEIGHTZERO v8, v4, v2, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v4, v2, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v4, v2, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v4, v2, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v5, v2, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v13, v5, v2, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v5, v2, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v15, v5, v2, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v16, v6, v2, 0 // tile:8, oc:0-3 + MLA_WEIGHTZERO v17, v6, v2, 1 // tile:9, oc:0-3 + MLA_WEIGHTZERO v18, v6, v2, 2 // tile:10, oc:0-3 + MLA_WEIGHTZERO v19, v6, v2, 3 // tile:11, oc:0-3 + + MLA_WEIGHTZERO v20, v4, v3, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v21, v4, v3, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v22, v4, v3, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v23, v4, v3, 3 // tile:3, oc:4-7 + MLA_WEIGHTZERO v24, v5, v3, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v25, v5, v3, 1 // tile:5, oc:4-7 + MLA_WEIGHTZERO v26, v5, v3, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v27, v5, v3, 3 // tile:7, oc:4-7 + MLA_WEIGHTZERO v28, v6, v3, 0 // tile:8, oc:4-7 + MLA_WEIGHTZERO v29, v6, v3, 1 // tile:9, oc:4-7 + MLA_WEIGHTZERO v30, v6, v3, 2 // tile:10, oc:4-7 + MLA_WEIGHTZERO 
v31, v6, v3, 3 // tile:11, oc:4-7 cbz x27, TILE12_ADD_DSTV - ld1 {v0.4s, v1.4s, v2.4s}, [x27], #48 // input dequant bias + ld1 {v0.4s, v1.4s, v2.4s}, [x27], x22 // input dequant bias ld1 {v3.4s, v4.4s}, [x28], #32 // weight kernel sum MLA_WEIGHTZERO v8, v0, v3, 0 MLA_WEIGHTZERO v9, v0, v3, 1 @@ -308,9 +310,10 @@ TILE12_BLOCKNUM: ADD_FLOAT v16, v17, v18, v19, v0, v1, v2, v3 ADD_FLOAT v20, v21, v22, v23, v4, v5, v6, v7 ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64 - ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], x12 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10] ADD_FLOAT v24, v25, v26, v27, v0, v1, v2, v3 ADD_FLOAT v28, v29, v30, v31, v4, v5, v6, v7 + sub x10, x10, #320 TILE12_L8_ACCUM_BUFFER: add x19, x19, #1 @@ -321,12 +324,13 @@ TILE12_BLOCKNUM: st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64 st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64 st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64 - st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], x12 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10] + sub x10, x10, #320 b TILE12_BLOCKNUM TILE12_POST: cbz x28, L8Tile12QuanUseInt8 - sub x5, x5, #2 + sub x14, x14, #2 cbz x9, TILE12_RELU ld1 {v0.4s, v1.4s}, [x20], #32 ADD_BIAS_FLOAT v8, v9, v10, v11, v0 @@ -349,16 +353,16 @@ TILE12_BLOCKNUM: sub x23, x23, #4 TILE12_STORE: - st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 - st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 - st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x4 - st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 - st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 - st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], x4 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x6], x4 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x6], #64 + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x6], #64 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x6], x4 b L8Tile12LoopCheck L8Tile12QuanUseInt8: - sub x5, x5, #2 + sub x14, x14, #2 ld1r {v7.4s}, [x23], #4 // int8 max ld1r {v6.4s}, [x23] // int8 
min ld1 {v0.4s, v1.4s}, [x9], #32 @@ -399,14 +403,15 @@ TILE12_BLOCKNUM: smin v19.16b, v7.16b, v19.16b smin v20.16b, v7.16b, v20.16b smin v21.16b, v7.16b, v21.16b - st1 {v16.16b, v17.16b, v18.16b}, [x0], x4 - st1 {v19.16b, v20.16b, v21.16b}, [x0], x4 + st1 {v16.16b, v17.16b, v18.16b}, [x6], x4 + st1 {v19.16b, v20.16b, v21.16b}, [x6], x4 L8Tile12LoopCheck: - cbz x5, End + cbz x14, Tile12End mov x8, x15 // revert input kernel sum mov x24, x21 // revert input dequant scale - mov x27, x6 // revert input dequant bias + cbz x27, L8LoopDz_TILE_12 + REVERT_INPUT_DEQUANT_BIAS x27, x19, x26, x22 b L8LoopDz_TILE_12 L4LoopDz_TILE_12: @@ -419,8 +424,8 @@ L4_TILE12_BLOCKNUM: SET_BIAS v16, v17, v18, v19 L4_LoopSz_TILE_12: - ld1 {v3.16b}, [x2] // weight - ld1 {v0.16b, v1.16b, v2.16b}, [x1], #48 // src + ld1 {v3.16b}, [x12] // weight + ld1 {v0.16b, v1.16b, v2.16b}, [x11], x22 // src .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] @@ -429,7 +434,7 @@ L4_TILE12_BLOCKNUM: .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] - add x2, x2, #32 // weight offset=lp*hp=32 + add x12, x12, #32 // weight offset=lp*hp=32 subs x13, x13, #1 .inst 0x4f82e070 // sdot v16.4s, v3.16b, v2.4b[0] .inst 0x4fa2e071 // sdot v17.4s, v3.16b, v2.4b[1] @@ -439,11 +444,11 @@ L4_TILE12_BLOCKNUM: L4_Tile12Quan: - ld1 {v0.4s}, [x2] // scale - add x2, x2, #32 - ld1 {v2.4s, v3.4s, v4.4s}, [x8], #48 // x kernel sum - ld1 {v5.4s}, [x2] // weight quan zeropoint - add x2, x2, #32 + ld1 {v0.4s}, [x12] // scale + add x12, x12, #32 + ld1 {v2.4s, v3.4s, v4.4s}, [x8], x22 // x kernel sum + ld1 {v5.4s}, [x12] // weight quan zeropoint + add x12, x12, #32 Int32ToFloat v8, v9, v10, v11 Int32ToFloat v12, v13, v14, v15 Int32ToFloat v16, v17, v18, v19 @@ -453,7 +458,7 @@ L4_TILE12_BLOCKNUM: cbz x21, TILE12_L4_MLA ld1 {v0.4s, v1.4s}, 
[x24], #32 - ld1 {v7.4s}, [x24], x14 + ld1 {v7.4s}, [x24], x25 MUL_EXTRA_SCALE v0, v8, v9, v10, v11 MUL_EXTRA_SCALE v1, v12, v13, v14, v15 MUL_EXTRA_SCALE v7, v16, v17, v18, v19 @@ -473,7 +478,7 @@ L4_TILE12_BLOCKNUM: MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3 cbz x27, L4_TILE12_ADD_DSTV - ld1 {v0.4s, v1.4s, v2.4s}, [x27], #48 // input dequant bias + ld1 {v0.4s, v1.4s, v2.4s}, [x27], x22 // input dequant bias ld1 {v3.4s}, [x28] // weight kernel sum MLA_WEIGHTZERO v8, v0, v3, 0 MLA_WEIGHTZERO v9, v0, v3, 1 @@ -526,10 +531,10 @@ L4_TILE12_BLOCKNUM: sub x23, x23, #4 L4_TILE12_STORE: - st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 - st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 - st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x4 - b End + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x6], x4 + b Tile12End L4Tile12QuanUseInt8: ld1r {v7.4s}, [x23], #4 // int8 max @@ -555,8 +560,24 @@ L4_TILE12_BLOCKNUM: smin v16.16b, v7.16b, v16.16b smin v17.16b, v7.16b, v17.16b smin v18.16b, v7.16b, v18.16b - st1 {v16.16b, v17.16b, v18.16b}, [x0], x4 - b End + st1 {v16.16b, v17.16b, v18.16b}, [x6], x4 + b Tile12UpdateAddr + +Tile12End: + add x4, x4, #128 // revert x4, int8 do not need + Tile12UpdateAddr: + add x0, x0, #192 + sub x7, x7, #12 + cbz x7, End + add x1, x1, #48 + add x8, x15, #48 + add x24, x21, #48 + + + cbz x27, TILE_8 + REVERT_INPUT_DEQUANT_BIAS x27, x19, x26, x22 + REVERT_WEIGHT_KERNEL_SUM x28, x14, x26, x5 + add x27, x27, #48 TILE_8: mov x25, #0 diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit.S b/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit.S index 1e8f6e1e88..d296821aa9 100644 --- a/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit.S +++ b/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit.S @@ -149,24 +149,23 @@ ldr x23, [x6, #56] // 
fp32minmax ldr x26, [x6, #64] // blockNum lsl x22, x7, #2 // eDest * SRC_UNIT -mov x14, #-32 +mov x25, #-32 cbz x27, TILE_12 -mov x14, #16 +sub x25, x22, #32 TILE_12: cmp x7, #12 blt TILE_8 sub x4, x4, #128 + mov x12, x2 + mov x6, x0 + mov x14, x5 mov x20, x9 mov x15, x8 // input kernel sum - mov x6, x27 // input dequant bias mov x21, x24 // input dequant scale - mov x12, #-320 - - cmp x5, #2 - blt L4LoopDz_TILE_12 - L8LoopDz_TILE_12: + cmp x14, #2 + blt L4LoopDz_TILE_12 mov x11, x1 mov x19, #0 TILE12_BLOCKNUM: @@ -181,8 +180,8 @@ TILE12_BLOCKNUM: SET_BIAS v28, v29, v30, v31 L8LoopSz_TILE_12: - ld1 {v5.16b}, [x2], #16 // weight - ld1 {v0.16b, v1.16b, v2.16b}, [x11], #48 // src + ld1 {v5.16b}, [x12], #16 // weight + ld1 {v0.16b, v1.16b, v2.16b}, [x11], x22 // src // int4->int8 ushr v3.16b, v5.16b, #4 and v4.16b, v5.16b, v7.16b @@ -220,9 +219,8 @@ TILE12_BLOCKNUM: L8LoopSzEnd_TILE_12: L8Tile12Quan: - ld1 {v0.4s, v1.4s}, [x2], #32 // scale - ld1 {v2.4s, v3.4s, v4.4s}, [x8], #48 // input kernel sum - ld1 {v5.4s, v6.4s}, [x2], #32 // weight quan zeropoint + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x12], #64 // weight scale&bias + ld1 {v4.4s, v5.4s, v6.4s}, [x8], x22 // input kernel sum Int32ToFloat v8, v9, v10, v11 Int32ToFloat v12, v13, v14, v15 Int32ToFloat v16, v17, v18, v19 @@ -238,7 +236,7 @@ TILE12_BLOCKNUM: MUL_SCALE v1, v28, v29, v30, v31 ld1 {v0.4s, v1.4s}, [x24], #32 - ld1 {v7.4s}, [x24], x14 + ld1 {v7.4s}, [x24], x25 MUL_EXTRA_SCALE v0, v8, v9, v10, v11 MUL_EXTRA_SCALE v1, v12, v13, v14, v15 MUL_EXTRA_SCALE v7, v16, v17, v18, v19 @@ -247,34 +245,34 @@ TILE12_BLOCKNUM: MUL_EXTRA_SCALE v7, v28, v29, v30, v31 TILE12_L8_MLA_TERM: - MLA_WEIGHTZERO v8, v2, v5, 0 // tile:0, oc:0-3 - MLA_WEIGHTZERO v9, v2, v5, 1 // tile:1, oc:0-3 - MLA_WEIGHTZERO v10, v2, v5, 2 // tile:2, oc:0-3 - MLA_WEIGHTZERO v11, v2, v5, 3 // tile:3, oc:0-3 - MLA_WEIGHTZERO v12, v3, v5, 0 // tile:4, oc:0-3 - MLA_WEIGHTZERO v13, v3, v5, 1 // tile:5, oc:0-3 - MLA_WEIGHTZERO v14, v3, v5, 2 // tile:6, 
oc:0-3 - MLA_WEIGHTZERO v15, v3, v5, 3 // tile:7, oc:0-3 - MLA_WEIGHTZERO v16, v4, v5, 0 // tile:8, oc:0-3 - MLA_WEIGHTZERO v17, v4, v5, 1 // tile:9, oc:0-3 - MLA_WEIGHTZERO v18, v4, v5, 2 // tile:10, oc:0-3 - MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3 - - MLA_WEIGHTZERO v20, v2, v6, 0 // tile:0, oc:4-7 - MLA_WEIGHTZERO v21, v2, v6, 1 // tile:1, oc:4-7 - MLA_WEIGHTZERO v22, v2, v6, 2 // tile:2, oc:4-7 - MLA_WEIGHTZERO v23, v2, v6, 3 // tile:3, oc:4-7 - MLA_WEIGHTZERO v24, v3, v6, 0 // tile:4, oc:4-7 - MLA_WEIGHTZERO v25, v3, v6, 1 // tile:5, oc:4-7 - MLA_WEIGHTZERO v26, v3, v6, 2 // tile:6, oc:4-7 - MLA_WEIGHTZERO v27, v3, v6, 3 // tile:7, oc:4-7 - MLA_WEIGHTZERO v28, v4, v6, 0 // tile:8, oc:4-7 - MLA_WEIGHTZERO v29, v4, v6, 1 // tile:9, oc:4-7 - MLA_WEIGHTZERO v30, v4, v6, 2 // tile:10, oc:4-7 - MLA_WEIGHTZERO v31, v4, v6, 3 // tile:11, oc:4-7 + MLA_WEIGHTZERO v8, v4, v2, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v4, v2, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v4, v2, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v4, v2, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v5, v2, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v13, v5, v2, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v5, v2, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v15, v5, v2, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v16, v6, v2, 0 // tile:8, oc:0-3 + MLA_WEIGHTZERO v17, v6, v2, 1 // tile:9, oc:0-3 + MLA_WEIGHTZERO v18, v6, v2, 2 // tile:10, oc:0-3 + MLA_WEIGHTZERO v19, v6, v2, 3 // tile:11, oc:0-3 + + MLA_WEIGHTZERO v20, v4, v3, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v21, v4, v3, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v22, v4, v3, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v23, v4, v3, 3 // tile:3, oc:4-7 + MLA_WEIGHTZERO v24, v5, v3, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v25, v5, v3, 1 // tile:5, oc:4-7 + MLA_WEIGHTZERO v26, v5, v3, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v27, v5, v3, 3 // tile:7, oc:4-7 + MLA_WEIGHTZERO v28, v6, v3, 0 // tile:8, oc:4-7 + MLA_WEIGHTZERO v29, v6, v3, 1 // tile:9, oc:4-7 + MLA_WEIGHTZERO v30, v6, v3, 2 // 
tile:10, oc:4-7 + MLA_WEIGHTZERO v31, v6, v3, 3 // tile:11, oc:4-7 cbz x27, TILE12_ADD_DSTV - ld1 {v0.4s, v1.4s, v2.4s}, [x27], #48 // input dequant bias + ld1 {v0.4s, v1.4s, v2.4s}, [x27], x22 // input dequant bias ld1 {v3.4s, v4.4s}, [x28], #32 // weight kernel sum MLA_WEIGHTZERO v8, v0, v3, 0 MLA_WEIGHTZERO v9, v0, v3, 1 @@ -313,9 +311,10 @@ TILE12_BLOCKNUM: ADD_FLOAT v16, v17, v18, v19, v0, v1, v2, v3 ADD_FLOAT v20, v21, v22, v23, v4, v5, v6, v7 ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64 - ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], x12 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10] ADD_FLOAT v24, v25, v26, v27, v0, v1, v2, v3 ADD_FLOAT v28, v29, v30, v31, v4, v5, v6, v7 + sub x10, x10, #320 TILE12_L8_ACCUM_BUFFER: add x19, x19, #1 @@ -326,11 +325,12 @@ TILE12_BLOCKNUM: st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64 st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64 st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64 - st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], x12 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10] + sub x10, x10, #320 b TILE12_BLOCKNUM TILE12_POST: - sub x5, x5, #2 + sub x14, x14, #2 cbz x9, TILE12_RELU ld1 {v0.4s, v1.4s}, [x20], #32 ADD_BIAS_FLOAT v8, v9, v10, v11, v0 @@ -353,19 +353,19 @@ TILE12_BLOCKNUM: sub x23, x23, #4 TILE12_STORE: - st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 - st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 - st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x4 - st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 - st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 - st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], x4 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x6], x4 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x6], #64 + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x6], #64 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x6], x4 L8Tile12LoopCheck: - cbz x5, End - cmp x5, #2 + cbz x14, Tile12End mov x8, x15 // revert input kernel sum mov x24, x21 // revert input 
dequant scale - mov x27, x6 // revert input dequant bias + cbz x27, L8LoopDz_TILE_12 + REVERT_INPUT_DEQUANT_BIAS x27, x19, x26, x22 bge L8LoopDz_TILE_12 L4LoopDz_TILE_12: @@ -380,8 +380,8 @@ L4_TILE12_BLOCKNUM: SET_BIAS v16, v17, v18, v19 L4_LoopSz_TILE_12: - ld1 {v5.16b}, [x2], #16 // weight - ld1 {v0.16b, v1.16b, v2.16b}, [x1], #48 // src + ld1 {v5.16b}, [x12], #16 // weight + ld1 {v0.16b, v1.16b, v2.16b}, [x11], x22 // src // int4->int8 ushr v3.16b, v5.16b, #4 @@ -402,11 +402,11 @@ L4_TILE12_BLOCKNUM: L4_Tile12Quan: - ld1 {v0.4s}, [x2] // scale - add x2, x2, #32 - ld1 {v2.4s, v3.4s, v4.4s}, [x8], #48 // x kernel sum - ld1 {v5.4s}, [x2] // weight quan zeropoint - add x2, x2, #32 + ld1 {v0.4s}, [x12] // scale + add x12, x12, #32 + ld1 {v2.4s, v3.4s, v4.4s}, [x8], x22 // x kernel sum + ld1 {v5.4s}, [x12] // weight quan zeropoint + add x12, x12, #32 Int32ToFloat v8, v9, v10, v11 Int32ToFloat v12, v13, v14, v15 Int32ToFloat v16, v17, v18, v19 @@ -415,7 +415,7 @@ L4_TILE12_BLOCKNUM: MUL_SCALE v0, v16, v17, v18, v19 ld1 {v0.4s, v1.4s}, [x24], #32 - ld1 {v7.4s}, [x24], x14 + ld1 {v7.4s}, [x24], x25 MUL_EXTRA_SCALE v0, v8, v9, v10, v11 MUL_EXTRA_SCALE v1, v12, v13, v14, v15 MUL_EXTRA_SCALE v7, v16, v17, v18, v19 @@ -435,7 +435,7 @@ L4_TILE12_BLOCKNUM: MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3 cbz x27, L4_TILE12_ADD_DSTV - ld1 {v0.4s, v1.4s, v2.4s}, [x27], #48 // input dequant bias + ld1 {v0.4s, v1.4s, v2.4s}, [x27], x22 // input dequant bias ld1 {v3.4s}, [x28] // weight kernel sum MLA_WEIGHTZERO v8, v0, v3, 0 MLA_WEIGHTZERO v9, v0, v3, 1 @@ -488,10 +488,24 @@ L4_TILE12_BLOCKNUM: sub x23, x23, #4 L4_TILE12_STORE: - st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 - st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 - st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x4 - b End + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x6], x4 + +Tile12End: + + add x0, x0, #192 + sub x7, x7, #12 + 
cbz x7, End + add x1, x1, #48 + add x8, x15, #48 + add x24, x21, #48 + add x4, x4, #128 // revert x4 + + cbz x27, TILE_8 + REVERT_INPUT_DEQUANT_BIAS x27, x19, x26, x22 + REVERT_WEIGHT_KERNEL_SUM x28, x14, x26, x5 + add x27, x27, #48 TILE_8: mov x25, #0 diff --git a/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScale16x32_SME2_w4_Fp16.S b/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScale16x32_SME2_w4_Fp16.S index b108851bdf..f43669adc9 100644 --- a/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScale16x32_SME2_w4_Fp16.S +++ b/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScale16x32_SME2_w4_Fp16.S @@ -86,12 +86,13 @@ mov x19, #32 // HP=32 .inst 0x253617e7 // whilelt p7.b, xzr, x22 // eSize * LP int8 valid .inst 0x25b347f2 // whilelt pn10.s, xzr, x19, vlx2 // 32 float valid .inst 0x2558e3e2 // ptrue p2.h - -mov x25, 0 // inputBlockNum=1 -cbz x27, ESIZE -mov x25, x22 // input block quant: realDstCount * sizeof(float) +.inst 0x2518e084 // ptrue p4.b, #4 // 4 int8_t valid +.inst 0x2518e125 // ptrue p5.b, vl16 // 16 int8_t valid +.inst 0x2558e3e6 // ptrue p6.h // all fp16 valid ESIZE: + cmp x7, #2 + ble TILE_2 mov x19, x13 // input kernel sum mov x21, x23 // input dequant scale mov x20, x27 // input dequant bias @@ -156,8 +157,8 @@ bne LoopL .inst 0xa0404b80 // ld1w {z0.s, z1.s}, pn10/z, [x28] // weight kernel sum .inst 0x80800442 // fmopa za2.s, p1/m, p0/m, z2.s, z0.s .inst 0x80810443 // fmopa za3.s, p1/m, p0/m, z2.s, z1.s - add x27, x27, x25 - add x23, x23, x25 + add x27, x27, x22 + add x23, x23, x22 .inst 0x043c505c // addvl x28, x28, #2 HP_DEQUANT: @@ -190,7 +191,7 @@ bne LoopL .inst 0xc0080022 // zero {za1.s} // inputScale x weightScale -> [16,16] - .inst 0x809e07e1 // fmopa za1.s, p1/m, p0/m, z31.s, z30.s + .inst 0x809e67e1 // fmopa za1.s, p1/m, p3/m, z31.s, z30.s mov w8, #1 mov w10, #3 // extract scale from za1.s @@ -319,6 +320,271 @@ bne LoopL mov x27, x20 b LoopH +TILE_2: + cmp x7, #1 + beq TILE_1 + mov x19, x13 // input 
kernel sum + mov x21, x23 // input dequant scale + mov x20, x27 // input dequant bias + mov x15, #32 // 2 * pack(8) * sizeof(float16) + .inst 0x252f17e5 // whilelt p5.b, xzr, x15 + + .inst 0x84c0b9de // ld1rh {z30.h}, p6/z, [x14] + .inst 0x84c1b9df // ld1rh {z31.h}, p6/z, [x14, #2] + + +LoopDz_TILE2: + .inst 0x25b8c01a // mov z26.s, #0 + .inst 0x25b8c01b // mov z27.s, #0 + .inst 0x25b8c01c // mov z28.s, #0 + .inst 0x25b8c01d // mov z29.s, #0 + mov w8, #0 + mov x11, x1 // src + mov x15, x26 // blockid +TILE2_BLOCKNUM: + mov x10, x3 // src_depth_quad + +.inst 0xc00800ff // zero {za} + + LoopSz_TILE2: + .inst 0xa4003d62 // ld1rqb {z2.b}, p7/z, [x11] // src + .inst 0xa400ac40 // ld1b {z0.b}, p3/z, [x2] // weight + // int4->int8 + .inst 0xc08a4004 // luti4 {z4.b-z5.b}, zt0, z0[0] + // matmul + .inst 0xc15210a0 // sdot za.s[w8, 0, VGx2], {z4.b-z5.b}, z2.b[0] + .inst 0xc15214a4 // sdot za.s[w8, 4, VGx2], {z4.b-z5.b}, z2.b[1] + subs x10, x10, #1 + add x11, x11, x22 + .inst 0x04225022 // addvl x2, x2, #1 + + bne LoopSz_TILE2 + + sub x15, x15, #1 + .inst 0xc0060808 // mova {z8.s-z9.s}, za.s[w8, 0, VGx2] + .inst 0xc006088a // mova {z10.s-z11.s}, za.s[w8, 4, VGx2] + .inst 0xc132e108 // scvtf {z8.s-z11.s}, {z8.s-z11.s} + + .inst 0xa0408040 // ld1b {z0.b-z3.b}, pn8/z, [x2] // weight scale&bias + .inst 0xa4003da4 // ld1rqb {z4.b}, p7/z, [x13] // input kernel sum + .inst 0xa4003ee5 // ld1rqb {z5.b}, p7/z, [x23] // input kernel scale + + .inst 0x64a52006 // fmul z6.s, z0.s, z5.s[0] // e0 + .inst 0x64a52027 // fmul z7.s, z1.s, z5.s[0] + .inst 0x64ad200c // fmul z12.s, z0.s, z5.s[1] // e1 + .inst 0x64ad202d // fmul z13.s, z1.s, z5.s[1] + + .inst 0x64a4005a // fmla z26.s, z2.s, z4.s[0] // e0 + .inst 0x64a4007b // fmla z27.s, z3.s, z4.s[0] + .inst 0x64ac005c // fmla z28.s, z2.s, z4.s[1] // e1 + .inst 0x64ac007d // fmla z29.s, z3.s, z4.s[1] + + .inst 0x65a60d1a // fmla z26.s, p3/m, z8.s, z6.s + .inst 0x65a70d3b // fmla z27.s, p3/m, z9.s, z7.s + .inst 0x65ac0d5c // fmla z28.s, p3/m, 
z10.s, z12.s + .inst 0x65ad0d7d // fmla z29.s, p3/m, z11.s, z13.s + .inst 0x04225082 // addvl x2, x2, #4 + add x13, x13, x22 + + cbz x27, TILE2_ADD_DSTV + .inst 0xa4003f65 // ld1rqb {z5.b}, p7/z, [x27] // input dequant bias + .inst 0xa0404b88 // ld1w {z8.s, z9.s}, pn10/z, [x28] // weight kernel sum + .inst 0x64a5011a // fmla z26.s, z8.s, z5.s[0] + .inst 0x64a5013b // fmla z27.s, z9.s, z5.s[0] + .inst 0x64ad011c // fmla z28.s, z8.s, z5.s[1] + .inst 0x64ad013d // fmla z29.s, z9.s, z5.s[1] + add x27, x27, x22 + add x23, x23, x22 + .inst 0x043c505c // addvl x28, x28, #2 + + TILE2_ADD_DSTV: + cbnz x15, TILE2_BLOCKNUM + + TILE2_STORE: + lsl x15, x5, #3 // ocRemain + .inst 0x25af47f1 // whilelt pn9.s, xzr, x15, vlx2 + .inst 0xa0404520 // ld1w {z0.s, z1.s}, pn9/z, [x9] // bias + .inst 0x6580035a // fadd z26.s, z26.s, z0.s + .inst 0x6581037b // fadd z27.s, z27.s, z1.s + .inst 0x6580039c // fadd z28.s, z28.s, z0.s + .inst 0x658103bd // fadd z29.s, z29.s, z1.s + .inst 0x04295049 // addvl x9, x9, #2 + + .inst 0xc120e342 // fcvt z2.h, {z26.s-z27.s} + .inst 0xc120e383 // fcvt z3.h, {z28.s-z29.s} + .inst 0xc17fc3c2 // fclamp {z2.h-z3.h}, z30.h, z31.h + .inst 0xc123d440 // zip {z0.q, z1.q}, z2.q, z3.q + + cmp x5, #4 + bge TILE2_STORE32 + cmp x5, #3 + beq TILE2_STORE24 + cmp x5, #2 + beq TILE2_STORE16 + + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + b End + + TILE2_STORE16: + add x11, x0, x4 + .inst 0x0522cc02 // mov z2.b, p3/m, z0.b + .inst 0x05240002 // ext z2.b, z2.b, z0.b, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe400f562 // st1b {z2.b}, p5, [x11] + b End + + TILE2_STORE24: + add x11, x0, x4 + add x10, x0, x4, LSL #1 + .inst 0x0522cc02 // mov z2.b, p3/m, z0.b + .inst 0x05240002 // ext z2.b, z2.b, z0.b, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe400f562 // st1b {z2.b}, p5, [x11] + .inst 0xe400f541 // st1b {z1.b}, p5, [x10] + b End + + TILE2_STORE32: + add x11, x0, x4 + add x10, x0, x4, LSL #1 + add x13, x11, x4, LSL #1 + .inst 0x0522cc02 // mov 
z2.b, p3/m, z0.b + .inst 0x0523cc23 // mov z3.b, p3/m, z1.b + .inst 0x05240002 // ext z2.b, z2.b, z0.b, #32 + .inst 0x05240023 // ext z3.b, z3.b, z1.b, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe400f562 // st1b {z2.b}, p5, [x11] + .inst 0xe400f541 // st1b {z1.b}, p5, [x10] + .inst 0xe400f5a3 // st1b {z3.b}, p5, [x13] + + TILE2_Dz_End: + subs x5, x5, #4 + add x0, x0, x4, LSL #2 + mov x13, x19 + mov x23, x21 + mov x27, x20 + beq End + b LoopDz_TILE2 + +TILE_1: + cmp x7, #1 + blt End + mov x19, x13 // input kernel sum + mov x21, x23 // input dequant scale + mov x20, x27 // input dequant bias + + .inst 0x84c0b9de // ld1rh {z30.h}, p6/z, [x14] + .inst 0x84c1b9df // ld1rh {z31.h}, p6/z, [x14, #2] + + +LoopDz_TILE1: + .inst 0x25b8c01c // mov z28.s, #0 + .inst 0x25b8c01d // mov z29.s, #0 + + mov w8, #0 + mov x11, x1 // src + mov x15, x26 // blockid +TILE1_BLOCKNUM: + mov x10, x3 // src_depth_quad + +.inst 0xc00800ff // zero {za} + + LoopSz_TILE_1: + .inst 0xa4003162 // ld1rqb {z2.b}, p4/z, [x11] // src + .inst 0xa400ac40 // ld1b {z0.b}, p3/z, [x2] // weight + // int4->int8 + .inst 0xc08a4004 // luti4 {z4.b-z5.b}, zt0, z0[0] + // matmul + .inst 0xc15210a0 // sdot za.s[w8, 0, VGx2], {z4.b-z5.b}, z2.b[0] + subs x10, x10, #1 + add x11, x11, x22 + .inst 0x04225022 // addvl x2, x2, #1 + + bne LoopSz_TILE_1 + +LoopSzEnd_TILE_1: + sub x15, x15, #1 + .inst 0xc0060808 // mova {z8.s-z9.s}, za.s[w8, 0, VGx2] + .inst 0x6594a108 // scvtf z8.s, p0/m, z8.s + .inst 0x6594a129 // scvtf z9.s, p0/m, z9.s + + .inst 0xa0408040 // ld1b {z0.b-z3.b}, pn8/z, [x2] // weight scale&bias + .inst 0x8540c1a4 // ld1rw {z4.s}, p0/z, [x13] // input kernel sum + .inst 0x8540c2e5 // ld1rw {z5.s}, p0/z, [x23] // input kernel scale + + .inst 0x65850800 // fmul z0.s, z0.s, z5.s + .inst 0x65850821 // fmul z1.s, z1.s, z5.s + .inst 0x64a4005c // fmla z28.s, z2.s, z4.s[0] + .inst 0x64a4007d // fmla z29.s, z3.s, z4.s[0] + .inst 0x65a0011c // fmla z28.s, p0/m, z8.s, z0.s + .inst 0x65a1013d // fmla 
z29.s, p0/m, z9.s, z1.s + .inst 0x04225082 // addvl x2, x2, #4 + add x13, x13, x22 + + cbz x27, TILE1_ADD_DSTV + .inst 0x8540c365 // ld1rw {z5.s}, p0/z, [x27] // input dequant bias + .inst 0xa0404b88 // ld1w {z8.s, z9.s}, pn10/z, [x28] // weight kernel sum + .inst 0x64a5011c // fmla z28.s, z8.s, z5.s[0] + .inst 0x64a5013d // fmla z29.s, z9.s, z5.s[0] + add x27, x27, x22 + add x23, x23, x22 + .inst 0x043c505c // addvl x28, x28, #2 + + TILE1_ADD_DSTV: + cbnz x15, TILE1_BLOCKNUM + + TILE1_STORE: + lsl x15, x5, #3 // ocRemain + .inst 0x25af47f1 // whilelt pn9.s, xzr, x15, vlx2 + .inst 0xa040453a // ld1w {z26.s, z27.s}, pn9/z, [x9] // bias + .inst 0x659a039c // fadd z28.s, z28.s, z26.s + .inst 0x659b03bd // fadd z29.s, z29.s, z27.s + .inst 0x04295049 // addvl x9, x9, #2 + .inst 0xc120e382 // fcvt z2.h, {z28.s-z29.s} + .inst 0x647f27c2 // fclamp z2.h, z30.h, z31.h + + cmp x5, #4 + bge TILE1_STORE32 + cmp x5, #3 + beq TILE1_STORE24 + cmp x5, #2 + beq TILE1_STORE16 + + .inst 0xe400f402 // st1b {z2.b}, p5, [x0] + b End + + TILE1_STORE16: + .inst 0x05702043 // dup z3.q, z2.q[1] + .inst 0xe400f402 // st1b {z2.b}, p5, [x0] + .inst 0xe4045403 // st1b {z3.b}, p5, [x0, x4] + b End + + TILE1_STORE24: + add x11, x0, x4, LSL #1 + .inst 0x05702043 // dup z3.q, z2.q[1] + .inst 0x05b02044 // dup z4.q, z2.q[2] + .inst 0xe400f402 // st1b {z2.b}, p5, [x0] + .inst 0xe4045403 // st1b {z3.b}, p5, [x0, x4] + .inst 0xe400f564 // st1b {z4.b}, p5, [x11] + b End + + TILE1_STORE32: + subs x5, x5, #4 + add x11, x0, x4, LSL #1 + .inst 0x05702043 // dup z3.q, z2.q[1] + .inst 0x05b02044 // dup z4.q, z2.q[2] + .inst 0x05f02045 // dup z5.q, z2.q[3] + .inst 0xe400f402 // st1b {z2.b}, p5, [x0] + .inst 0xe4045403 // st1b {z3.b}, p5, [x0, x4] + .inst 0xe400f564 // st1b {z4.b}, p5, [x11] + .inst 0xe4045565 // st1b {z5.b}, p5, [x11, x4] + + cbz x5, End + add x0, x0, x4, LSL #2 + mov x13, x19 + mov x23, x21 + mov x27, x20 + b LoopDz_TILE1 + End: .inst 0xd503467f // smstop diff --git 
a/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScale16x32_SME2_w4_Fp32.S b/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScale16x32_SME2_w4_Fp32.S index e1e96ff8c3..77a97eb92b 100644 --- a/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScale16x32_SME2_w4_Fp32.S +++ b/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScale16x32_SME2_w4_Fp32.S @@ -73,9 +73,9 @@ ldr x27, [x6, #88] // input bias ldr x8, [x6, #104] // indices ldr x14, [x6, #56] // float32 maxmin ptr +.inst 0xe11f8100 // ldr zt0, [x8] lsl x22, x7, #2 // eSize * GEMM_INT8_SRC_UNIT lsl x21, x7, #4 // eSize * pack * sizeof (float) -.inst 0xe11f8100 // ldr zt0, [x8] /* initialize predicates */ mov x19, #32 // HP=32 @@ -86,12 +86,15 @@ mov x19, #32 // HP=32 .inst 0x25207810 // ptrue pn8.b // all int8 valid .inst 0x253617e7 // whilelt p7.b, xzr, x22 // eSize * LP int8 valid .inst 0x25b347f2 // whilelt pn10.s, xzr, x19, vlx2 // 32 float valid +.inst 0x2518e084 // ptrue p4.b, #4 // 4 int8_t valid +.inst 0x2518e125 // ptrue p5.b, vl16 // 16 int8_t valid .inst 0x8540c1dc // ld1rw {z28.s}, p0/z, [x14] // float min .inst 0x8541c1dd // ld1rw {z29.s}, p0/z, [x14, #4] ESIZE: - mov x6, x0 // dst + cmp x7, #2 + ble TILE_2 mov x19, x13 // input kernel sum mov x21, x23 // input dequant scale mov x20, x27 // input dequant bias @@ -252,29 +255,29 @@ bne LoopL beq STORE8 STORE12: - add x10, x6, x4, LSL #1 // + 2*x4 - .inst 0xa0608cd0 // st1b {z16.b-z19.b}, pn11, [x6] - .inst 0xa0248cd4 // st1b {z20.b-z23.b}, pn11, [x6, x4] + add x10, x0, x4, LSL #1 // + 2*x4 + .inst 0xa0608c10 // st1b {z16.b-z19.b}, pn11, [x0] + .inst 0xa0248c14 // st1b {z20.b-z23.b}, pn11, [x0, x4] .inst 0xa0608d58 // st1b {z24.b-z27.b}, pn11, [x10] b End STORE8: - .inst 0xa0608cd0 // st1b {z16.b-z19.b}, pn11, [x6] - .inst 0xa0248cd4 // st1b {z20.b-z23.b}, pn11, [x6, x4] + .inst 0xa0608c10 // st1b {z16.b-z19.b}, pn11, [x0] + .inst 0xa0248c14 // st1b {z20.b-z23.b}, pn11, [x0, x4] b End STORE4: - .inst 0xa0608cd0 // st1b 
{z16.b-z19.b}, pn11, [x6] + .inst 0xa0608c10 // st1b {z16.b-z19.b}, pn11, [x0] b End STORE16: - add x10, x6, x4, LSL #1 // + 2*x4 + add x10, x0, x4, LSL #1 // + 2*x4 subs x5, x5, #4 - .inst 0xa0608cd0 // st1b {z16.b-z19.b}, pn11, [x6] - .inst 0xa0248cd4 // st1b {z20.b-z23.b}, pn11, [x6, x4] + .inst 0xa0608c10 // st1b {z16.b-z19.b}, pn11, [x0] + .inst 0xa0248c14 // st1b {z20.b-z23.b}, pn11, [x0, x4] .inst 0xa0608d58 // st1b {z24.b-z27.b}, pn11, [x10] .inst 0xa0248d40 // st1b {z0.b-z3.b}, pn11, [x10, x4] - add x6, x6, x4, LSL #2 + add x0, x0, x4, LSL #2 beq End /* oc:16~31 */ @@ -305,29 +308,29 @@ bne LoopL beq STORE24 STORE20: - .inst 0xa0608cd0 // st1b {z16.b-z19.b}, pn11, [x6] + .inst 0xa0608c10 // st1b {z16.b-z19.b}, pn11, [x0] b End STORE28: - add x10, x6, x4, LSL #1 // + 2*x4 - .inst 0xa0608cd0 // st1b {z16.b-z19.b}, pn11, [x6] - .inst 0xa0248cd4 // st1b {z20.b-z23.b}, pn11, [x6, x4] + add x10, x0, x4, LSL #1 // + 2*x4 + .inst 0xa0608c10 // st1b {z16.b-z19.b}, pn11, [x0] + .inst 0xa0248c14 // st1b {z20.b-z23.b}, pn11, [x0, x4] .inst 0xa0608d58 // st1b {z24.b-z27.b}, pn11, [x10] b End STORE24: - .inst 0xa0608cd0 // st1b {z16.b-z19.b}, pn11, [x6] - .inst 0xa0248cd4 // st1b {z20.b-z23.b}, pn11, [x6, x4] + .inst 0xa0608c10 // st1b {z16.b-z19.b}, pn11, [x0] + .inst 0xa0248c14 // st1b {z20.b-z23.b}, pn11, [x0, x4] b End STORE32: - add x10, x6, x4, LSL #1 // + 2*x4 + add x10, x0, x4, LSL #1 // + 2*x4 subs x5, x5, #4 - .inst 0xa0608cd0 // st1b {z16.b-z19.b}, pn11, [x6] - .inst 0xa0248cd4 // st1b {z20.b-z23.b}, pn11, [x6, x4] + .inst 0xa0608c10 // st1b {z16.b-z19.b}, pn11, [x0] + .inst 0xa0248c14 // st1b {z20.b-z23.b}, pn11, [x0, x4] .inst 0xa0608d58 // st1b {z24.b-z27.b}, pn11, [x10] .inst 0xa0248d40 // st1b {z0.b-z3.b}, pn11, [x10, x4] - add x6, x6, x4, LSL #2 + add x0, x0, x4, LSL #2 beq End // revert input scale/kernelSum @@ -336,6 +339,404 @@ bne LoopL mov x27, x20 b LoopH +TILE_2: + cmp x7, #1 + beq TILE_1 + mov x19, x13 // input kernel sum + mov x21, x23 // input 
dequant scale + mov x20, x27 // input dequant bias + + mov x15, #32 + .inst 0x252f17e5 // whilelt p5.b, xzr, x15 + +LoopDz_TILE2: + .inst 0x25b8c01a // mov z26.s, #0 + .inst 0x25b8c01b // mov z27.s, #0 + .inst 0x25b8c01e // mov z30.s, #0 + .inst 0x25b8c01f // mov z31.s, #0 + mov w8, #0 + mov x11, x1 // src + mov x15, x26 + +TILE2_BLOCKNUM: + mov x10, x3 // src_depth_quad + +.inst 0xc00800ff // zero {za} + + LoopSz_TILE2: + .inst 0xa4003d62 // ld1rqb {z2.b}, p7/z, [x11] // src + .inst 0xa400ac41 // ld1b {z1.b}, p3/z, [x2] // weight + // int4->int8 + .inst 0xc08a4024 // luti4 {z4.b-z5.b}, zt0, z1[0] + // matmul + .inst 0xc15210a0 // sdot za.s[w8, 0, VGx2], {z4.b-z5.b}, z2.b[0] + .inst 0xc15214a4 // sdot za.s[w8, 4, VGx2], {z4.b-z5.b}, z2.b[1] + subs x10, x10, #1 + add x11, x11, x22 + .inst 0x04225022 // addvl x2, x2, #1 + + bne LoopSz_TILE2 + + sub x15, x15, #1 + .inst 0xc0060808 // mova {z8.s-z9.s}, za.s[w8, 0, VGx2] + .inst 0xc006088a // mova {z10.s-z11.s}, za.s[w8, 4, VGx2] + .inst 0xc132e108 // scvtf {z8.s-z11.s}, {z8.s-z11.s} + + .inst 0xa0408040 // ld1b {z0.b-z3.b}, pn8/z, [x2] // weight scale&bias + .inst 0xa4003da4 // ld1rqb {z4.b}, p7/z, [x13] // input kernel sum + .inst 0xa4003ee5 // ld1rqb {z5.b}, p7/z, [x23] // input kernel scale + + .inst 0x64a52006 // fmul z6.s, z0.s, z5.s[0] // e0 + .inst 0x64a52027 // fmul z7.s, z1.s, z5.s[0] + .inst 0x64ad200c // fmul z12.s, z0.s, z5.s[1] // e1 + .inst 0x64ad202d // fmul z13.s, z1.s, z5.s[1] + + .inst 0x64a4005a // fmla z26.s, z2.s, z4.s[0] // e0 + .inst 0x64a4007b // fmla z27.s, z3.s, z4.s[0] + .inst 0x64ac005e // fmla z30.s, z2.s, z4.s[1] // e1 + .inst 0x64ac007f // fmla z31.s, z3.s, z4.s[1] + + .inst 0x65a60d1a // fmla z26.s, p3/m, z8.s, z6.s + .inst 0x65a70d3b // fmla z27.s, p3/m, z9.s, z7.s + .inst 0x65ac0d5e // fmla z30.s, p3/m, z10.s, z12.s + .inst 0x65ad0d7f // fmla z31.s, p3/m, z11.s, z13.s + .inst 0x04225082 // addvl x2, x2, #4 + add x13, x13, x22 + + cbz x27, TILE2_ADD_DSTV + .inst 0xa4003f65 // ld1rqb 
{z5.b}, p7/z, [x27] // input dequant bias + .inst 0xa0404b88 // ld1w {z8.s, z9.s}, pn10/z, [x28] // weight kernel sum + .inst 0x64a5011a // fmla z26.s, z8.s, z5.s[0] + .inst 0x64a5013b // fmla z27.s, z9.s, z5.s[0] + .inst 0x64ad011e // fmla z30.s, z8.s, z5.s[1] + .inst 0x64ad013f // fmla z31.s, z9.s, z5.s[1] + add x27, x27, x22 + add x23, x23, x22 + .inst 0x043c505c // addvl x28, x28, #2 + + TILE2_ADD_DSTV: + cbnz x15, TILE2_BLOCKNUM + + TILE2_STORE: + lsl x15, x5, #2 // ocRemain + .inst 0x25af47f1 // whilelt pn9.s, xzr, x15, vlx2 + .inst 0xa0404520 // ld1w {z0.s, z1.s}, pn9/z, [x9] // bias + .inst 0x6580035a // fadd z26.s, z26.s, z0.s + .inst 0x6581037b // fadd z27.s, z27.s, z1.s + .inst 0x658003de // fadd z30.s, z30.s, z0.s + .inst 0x658103ff // fadd z31.s, z31.s, z1.s + .inst 0x04295049 // addvl x9, x9, #2 + + .inst 0xc1bdc39a // fclamp {z26.s-z27.s}, z28.s, z29.s + .inst 0xc1bdc39e // fclamp {z30.s-z31.s}, z28.s, z29.s + // z0: 0~3,4_7 z1:8~11,12~15 + .inst 0xc13ed740 // zip {z0.q-z1.q}, z26.q, z30.q // (0,0)(0,1)(0,2)(0,3)(1,0)(1,1)(1,2)(1,3)...(1,12)(1,13)(1,14)(1,15) + .inst 0xc13fd762 // zip {z2.q-z3.q}, z27.q, z31.q // (0,16)(0,17)(0,18)(0,19)(1,16)(1,17)(1,18)(1,19)...(1,28)(1,29)(1,30)(1,31) + + cmp x5, #8 + bge TILE2_STORE32 + cmp x5, #7 + beq TILE2_STORE28 + cmp x5, #6 + beq TILE2_STORE24 + cmp x5, #5 + beq TILE2_STORE20 + cmp x5, #4 + beq TILE2_STORE16 + cmp x5, #3 + beq TILE2_STORE12 + cmp x5, #2 + beq TILE2_STORE8 + + TILE2_STORE4: + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + b End + + TILE2_STORE32: + add x10, x0, x4, LSL #1 + add x8, x0, x4, LSL #2 + add x13, x10, x4, LSL #2 + .inst 0x05640005 // ext z5.b, {z0.b, z1.b}, #32 + .inst 0x05640027 // ext z7.b, {z1.b, z2.b}, #32 + .inst 0x05640049 // ext z9.b, {z2.b, z3.b}, #32 + .inst 0x0564006b // ext z11.b, {z3.b, z4.b}, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe4045405 // st1b {z5.b}, p5, [x0, x4] + .inst 0xe400f541 // st1b {z1.b}, p5, [x10] + .inst 0xe4045547 // st1b {z7.b}, p5, 
[x10, x4] + .inst 0xe400f502 // st1b {z2.b}, p5, [x8] + .inst 0xe4045509 // st1b {z9.b}, p5, [x8, x4] + .inst 0xe400f5a3 // st1b {z3.b}, p5, [x13] + .inst 0xe40455ab // st1b {z11.b}, p5, [x13, x4] + b TILE2_Dz_End + + TILE2_STORE28: + add x10, x0, x4, LSL #1 + add x8, x0, x4, LSL #2 + add x13, x10, x4, LSL #2 + .inst 0x05640005 // ext z5.b, {z0.b, z1.b}, #32 + .inst 0x05640027 // ext z7.b, {z1.b, z2.b}, #32 + .inst 0x05640049 // ext z9.b, {z2.b, z3.b}, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe4045405 // st1b {z5.b}, p5, [x0, x4] + .inst 0xe400f541 // st1b {z1.b}, p5, [x10] + .inst 0xe4045547 // st1b {z7.b}, p5, [x10, x4] + .inst 0xe400f502 // st1b {z2.b}, p5, [x8] + .inst 0xe4045509 // st1b {z9.b}, p5, [x8, x4] + .inst 0xe400f5a3 // st1b {z3.b}, p5, [x13] + b End + + TILE2_STORE24: + add x10, x0, x4, LSL #1 + add x8, x0, x4, LSL #2 + .inst 0x05640005 // ext z5.b, {z0.b, z1.b}, #32 + .inst 0x05640027 // ext z7.b, {z1.b, z2.b}, #32 + .inst 0x05640049 // ext z9.b, {z2.b, z3.b}, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe4045405 // st1b {z5.b}, p5, [x0, x4] + .inst 0xe400f541 // st1b {z1.b}, p5, [x10] + .inst 0xe4045547 // st1b {z7.b}, p5, [x10, x4] + .inst 0xe400f502 // st1b {z2.b}, p5, [x8] + .inst 0xe4045509 // st1b {z9.b}, p5, [x8, x4] + b End + + TILE2_STORE20: + add x10, x0, x4, LSL #1 + add x8, x0, x4, LSL #2 + .inst 0x05640005 // ext z5.b, {z0.b, z1.b}, #32 + .inst 0x05640027 // ext z7.b, {z1.b, z2.b}, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe4045405 // st1b {z5.b}, p5, [x0, x4] + .inst 0xe400f541 // st1b {z1.b}, p5, [x10] + .inst 0xe4045547 // st1b {z7.b}, p5, [x10, x4] + .inst 0xe400f502 // st1b {z2.b}, p5, [x8] + b End + + TILE2_STORE16: + add x10, x0, x4, LSL #1 + .inst 0x05640005 // ext z5.b, {z0.b, z1.b}, #32 + .inst 0x05640027 // ext z7.b, {z1.b, z2.b}, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe4045405 // st1b {z5.b}, p5, [x0, x4] + .inst 0xe400f541 // st1b {z1.b}, p5, [x10] + .inst 
0xe4045547 // st1b {z7.b}, p5, [x10, x4] + b End + + TILE2_STORE12: + add x10, x0, x4, LSL #1 + .inst 0x05640005 // ext z5.b, {z0.b, z1.b}, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe4045405 // st1b {z5.b}, p5, [x0, x4] + .inst 0xe400f541 // st1b {z1.b}, p5, [x10] + b End + + TILE2_STORE8: + .inst 0x05640005 // ext z5.b, {z0.b, z1.b}, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe4045405 // st1b {z5.b}, p5, [x0, x4] + b End + + + TILE2_Dz_End: + subs x5, x5, #8 + add x0, x0, x4, LSL #3 + mov x13, x19 + mov x23, x21 + mov x27, x20 + beq End + b LoopDz_TILE2 + +TILE_1: + cmp x7, #1 + blt End + mov x19, x13 // input kernel sum + mov x21, x23 // input dequant scale + mov x20, x27 // input dequant bias + +LoopDz_TILE1: +.inst 0x25b8c01e // mov z30.s, #0 +.inst 0x25b8c01f // mov z31.s, #0 + mov w8, #0 + mov x11, x1 // src + mov x15, x26 // blockid +TILE1_BLOCKNUM: + mov x10, x3 // src_depth_quad + +.inst 0xc00800ff // zero {za} + + LoopSz_TILE_1: + .inst 0xa4003162 // ld1rqb {z2.b}, p4/z, [x11] // src + .inst 0xa400ac40 // ld1b {z0.b}, p3/z, [x2] // weight + // int4->int8 + .inst 0xc08a4004 // luti4 {z4.b-z5.b}, zt0, z0[0] + // matmul + .inst 0xc15210a0 // sdot za.s[w8, 0, VGx2], {z4.b-z5.b}, z2.b[0] + subs x10, x10, #1 + add x11, x11, x22 + .inst 0x04225022 // addvl x2, x2, #1 + + bne LoopSz_TILE_1 + +LoopSzEnd_TILE_1: + sub x15, x15, #1 + .inst 0xc0060808 // mova {z8.s-z9.s}, za.s[w8, 0, VGx2] + .inst 0x6594a108 // scvtf z8.s, p0/m, z8.s + .inst 0x6594a129 // scvtf z9.s, p0/m, z9.s + + .inst 0xa0408040 // ld1b {z0.b-z3.b}, pn8/z, [x2] // weight scale&bias + .inst 0x8540c1a4 // ld1rw {z4.s}, p0/z, [x13] // input kernel sum + .inst 0x8540c2e5 // ld1rw {z5.s}, p0/z, [x23] // input kernel scale + + .inst 0x65850800 // fmul z0.s, z0.s, z5.s + .inst 0x65850821 // fmul z1.s, z1.s, z5.s + .inst 0x64a4005e // fmla z30.s, z2.s, z4.s[0] + .inst 0x64a4007f // fmla z31.s, z3.s, z4.s[0] + .inst 0x65a0011e // fmla z30.s, p0/m, z8.s, z0.s + .inst 
0x65a1013f // fmla z31.s, p0/m, z9.s, z1.s + .inst 0x04225082 // addvl x2, x2, #4 + add x13, x13, x22 + + cbz x27, TILE1_ADD_DSTV + .inst 0x8540c365 // ld1rw {z5.s}, p0/z, [x27] // input dequant bias + .inst 0xa0404b88 // ld1w {z8.s, z9.s}, pn10/z, [x28] // weight kernel sum + .inst 0x64a5011e // fmla z30.s, z8.s, z5.s[0] + .inst 0x64a5013f // fmla z31.s, z9.s, z5.s[0] + add x27, x27, x22 + add x23, x23, x22 + .inst 0x043c505c // addvl x28, x28, #2 + + TILE1_ADD_DSTV: + cmp x15, #0 + bne TILE1_BLOCKNUM + + TILE1_STORE: + lsl x15, x5, #2 // ocRemain + .inst 0x25af47f1 // whilelt pn9.s, xzr, x15, vlx2 + .inst 0xa0404520 // ld1w {z0.s, z1.s}, pn9/z, [x9] // bias + .inst 0x04295049 // addvl x9, x9, #2 + .inst 0x658003de // fadd z30.s, z30.s, z0.s + .inst 0x658103ff // fadd z31.s, z31.s, z1.s + .inst 0xc1bdc39e // fclamp {z30.s-z31.s}, z28.s, z29.s + + cmp x5, #8 + bge TILE1_STORE32 + cmp x5, #7 + beq TILE1_STORE28 + cmp x5, #6 + beq TILE1_STORE24 + cmp x5, #5 + beq TILE1_STORE20 + cmp x5, #4 + beq TILE1_STORE16 + cmp x5, #3 + beq TILE1_STORE12 + cmp x5, #2 + beq TILE1_STORE8 + + TILE1_STORE4: + .inst 0xe400f41e // st1b {z30.b}, p5, [x0] + b End + + TILE1_STORE28: + add x10, x0, x4, LSL #1 + add x8, x0, x4, LSL #2 + add x13, x10, x4, LSL #2 + .inst 0x057023c4 // dup z4.q, z30.q[1] + .inst 0x05b023c5 // dup z5.q, z30.q[2] + .inst 0x05f023c6 // dup z6.q, z30.q[3] + .inst 0x057023e7 // dup z7.q, z31.q[1] + .inst 0x05b023e8 // dup z8.q, z31.q[2] + .inst 0xe400f41e // st1b {z30.b}, p5, [x0] + .inst 0xe4045404 // st1b {z4.b}, p5, [x0, x4] + .inst 0xe400f545 // st1b {z5.b}, p5, [x10] + .inst 0xe4045546 // st1b {z6.b}, p5, [x10, x4] + .inst 0xe400f51f // st1b {z31.b}, p5, [x8] + .inst 0xe4045507 // st1b {z7.b}, p5, [x8, x4] + .inst 0xe400f5a8 // st1b {z8.b}, p5, [x13] + b End + + TILE1_STORE24: + add x10, x0, x4, LSL #1 + add x8, x0, x4, LSL #2 + .inst 0x057023c4 // dup z4.q, z30.q[1] + .inst 0x05b023c5 // dup z5.q, z30.q[2] + .inst 0x05f023c6 // dup z6.q, z30.q[3] + .inst 
0x057023e7 // dup z7.q, z31.q[1] + .inst 0xe400f41e // st1b {z30.b}, p5, [x0] + .inst 0xe4045404 // st1b {z4.b}, p5, [x0, x4] + .inst 0xe400f545 // st1b {z5.b}, p5, [x10] + .inst 0xe4045546 // st1b {z6.b}, p5, [x10, x4] + .inst 0xe400f51f // st1b {z31.b}, p5, [x8] + .inst 0xe4045507 // st1b {z7.b}, p5, [x8, x4] + b End + + TILE1_STORE20: + add x10, x0, x4, LSL #1 + add x8, x0, x4, LSL #2 + .inst 0x057023c4 // dup z4.q, z30.q[1] + .inst 0x05b023c5 // dup z5.q, z30.q[2] + .inst 0x05f023c6 // dup z6.q, z30.q[3] + .inst 0xe400f41e // st1b {z30.b}, p5, [x0] + .inst 0xe4045404 // st1b {z4.b}, p5, [x0, x4] + .inst 0xe400f545 // st1b {z5.b}, p5, [x10] + .inst 0xe4045546 // st1b {z6.b}, p5, [x10, x4] + .inst 0xe400f51f // st1b {z31.b}, p5, [x8] + b End + + TILE1_STORE16: + add x10, x0, x4, LSL #1 + .inst 0x057023c4 // dup z4.q, z30.q[1] + .inst 0x05b023c5 // dup z5.q, z30.q[2] + .inst 0x05f023c6 // dup z6.q, z30.q[3] + .inst 0xe400f41e // st1b {z30.b}, p5, [x0] + .inst 0xe4045404 // st1b {z4.b}, p5, [x0, x4] + .inst 0xe400f545 // st1b {z5.b}, p5, [x10] + .inst 0xe4045546 // st1b {z6.b}, p5, [x10, x4] + b End + + TILE1_STORE12: + add x10, x0, x4, LSL #1 + .inst 0x057023c4 // dup z4.q, z30.q[1] + .inst 0x05b023c5 // dup z5.q, z30.q[2] + .inst 0xe400f41e // st1b {z30.b}, p5, [x0] + .inst 0xe4045404 // st1b {z4.b}, p5, [x0, x4] + .inst 0xe400f545 // st1b {z5.b}, p5, [x10] + b End + + TILE1_STORE8: + .inst 0x057023c4 // dup z4.q, z30.q[1] + .inst 0xe400f41e // st1b {z30.b}, p5, [x0] + .inst 0xe4045404 // st1b {z4.b}, p5, [x0, x4] + b End + + + TILE1_STORE32: + subs x5, x5, #8 + add x10, x0, x4, LSL #1 + add x8, x0, x4, LSL #2 + add x13, x10, x4, LSL #2 + .inst 0x057023c4 // dup z4.q, z30.q[1] + .inst 0x05b023c5 // dup z5.q, z30.q[2] + .inst 0x05f023c6 // dup z6.q, z30.q[3] + .inst 0x057023e7 // dup z7.q, z31.q[1] + .inst 0x05b023e8 // dup z8.q, z31.q[2] + .inst 0x05f023e9 // dup z9.q, z31.q[3] + .inst 0xe400f41e // st1b {z30.b}, p5, [x0] + .inst 0xe4045404 // st1b {z4.b}, p5, 
[x0, x4] + .inst 0xe400f545 // st1b {z5.b}, p5, [x10] + .inst 0xe4045546 // st1b {z6.b}, p5, [x10, x4] + .inst 0xe400f51f // st1b {z31.b}, p5, [x8] + .inst 0xe4045507 // st1b {z7.b}, p5, [x8, x4] + .inst 0xe400f5a8 // st1b {z8.b}, p5, [x13] + .inst 0xe40455a9 // st1b {z9.b}, p5, [x13, x4] + beq End + + add x0, x0, x4, LSL #3 + mov x13, x19 + mov x23, x21 + mov x27, x20 + b LoopDz_TILE1 + End: .inst 0xd503467f // smstop diff --git a/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScale16x32_SME2_w8_Fp16.S b/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScale16x32_SME2_w8_Fp16.S index 217402c003..e45d8ebee1 100644 --- a/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScale16x32_SME2_w8_Fp16.S +++ b/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScale16x32_SME2_w8_Fp16.S @@ -63,7 +63,6 @@ stp d14, d15, [sp, #32] .inst 0xd503477f // smstart - ldr x9, [x6, #8] // biasFloat ldr x13, [x6, #40] // srcKernelSum ldr x28, [x6, #48] // weightKernelSum @@ -86,12 +85,13 @@ mov x19, #32 // HP=32 .inst 0x253617e7 // whilelt p7.b, xzr, x22 // eSize * LP int8 valid .inst 0x25b347f2 // whilelt pn10.s, xzr, x19, vlx2 // 32 float valid .inst 0x2558e3e2 // ptrue p2.h - -mov x25, 0 // inputBlockNum=1 -cbz x27, ESIZE -mov x25, x22 // input block quant: realDstCount * sizeof(float) +.inst 0x2518e084 // ptrue p4.b, #4 // 4 int8_t valid +.inst 0x2518e125 // ptrue p5.b, vl16 // 16 int8_t valid +.inst 0x2558e3e6 // ptrue p6.h // all fp16 valid ESIZE: + cmp x7, #2 + ble TILE_2 mov x19, x13 // input kernel sum mov x21, x23 // input dequant scale mov x20, x27 // input dequant bias @@ -154,8 +154,8 @@ bne LoopL .inst 0xa0404b80 // ld1w {z0.s, z1.s}, pn10/z, [x28] // weight kernel sum .inst 0x80800442 // fmopa za2.s, p1/m, p0/m, z2.s, z0.s .inst 0x80810443 // fmopa za3.s, p1/m, p0/m, z2.s, z1.s - add x27, x27, x25 - add x23, x23, x25 + add x27, x27, x22 + add x23, x23, x22 .inst 0x043c505c // addvl x28, x28, #2 HP_DEQUANT: @@ -188,7 +188,7 @@ bne LoopL .inst 
0xc0080022 // zero {za1.s} // inputScale x weightScale -> [16,16] - .inst 0x809e07e1 // fmopa za1.s, p1/m, p0/m, z31.s, z30.s + .inst 0x809e67e1 // fmopa za1.s, p1/m, p3/m, z31.s, z30.s mov w8, #1 mov w10, #3 // extract scale from za1.s @@ -317,6 +317,268 @@ bne LoopL mov x27, x20 b LoopH +TILE_2: + cmp x7, #1 + beq TILE_1 + mov x19, x13 // input kernel sum + mov x21, x23 // input dequant scale + mov x20, x27 // input dequant bias + mov x15, #32 // 2 * pack(8) * sizeof(float16) + .inst 0x252f17e5 // whilelt p5.b, xzr, x15 + + .inst 0x84c0b9de // ld1rh {z30.h}, p6/z, [x14] + .inst 0x84c1b9df // ld1rh {z31.h}, p6/z, [x14, #2] + + +LoopDz_TILE2: + .inst 0x25b8c01a // mov z26.s, #0 + .inst 0x25b8c01b // mov z27.s, #0 + .inst 0x25b8c01c // mov z28.s, #0 + .inst 0x25b8c01d // mov z29.s, #0 + + mov w8, #0 + mov x11, x1 // src + mov x15, x26 // blockid +TILE2_BLOCKNUM: + mov x10, x3 // src_depth_quad + +.inst 0xc00800ff // zero {za} + + LoopSz_TILE2: + .inst 0xa4003d62 // ld1rqb {z2.b}, p7/z, [x11] // src + .inst 0xa0400044 // ld1b {z4.b-z5.b}, pn8/z, [x2] // weight + // matmul + .inst 0xc15210a0 // sdot za.s[w8, 0, VGx2], {z4.b-z5.b}, z2.b[0] + .inst 0xc15214a4 // sdot za.s[w8, 4, VGx2], {z4.b-z5.b}, z2.b[1] + subs x10, x10, #1 + add x11, x11, x22 + .inst 0x04225042 // addvl x2, x2, #2 + + bne LoopSz_TILE2 + + sub x15, x15, #1 + .inst 0xc0060808 // mova {z8.s-z9.s}, za.s[w8, 0, VGx2] + .inst 0xc006088a // mova {z10.s-z11.s}, za.s[w8, 4, VGx2] + .inst 0xc132e108 // scvtf {z8.s-z11.s}, {z8.s-z11.s} + + .inst 0xa0408040 // ld1b {z0.b-z3.b}, pn8/z, [x2] // weight scale&bias + .inst 0xa4003da4 // ld1rqb {z4.b}, p7/z, [x13] // input kernel sum + .inst 0xa4003ee5 // ld1rqb {z5.b}, p7/z, [x23] // input kernel scale + + .inst 0x64a52006 // fmul z6.s, z0.s, z5.s[0] // e0 + .inst 0x64a52027 // fmul z7.s, z1.s, z5.s[0] + .inst 0x64ad200c // fmul z12.s, z0.s, z5.s[1] // e1 + .inst 0x64ad202d // fmul z13.s, z1.s, z5.s[1] + + .inst 0x64a4005a // fmla z26.s, z2.s, z4.s[0] // e0 + .inst 
0x64a4007b // fmla z27.s, z3.s, z4.s[0] + .inst 0x64ac005c // fmla z28.s, z2.s, z4.s[1] // e1 + .inst 0x64ac007d // fmla z29.s, z3.s, z4.s[1] + + .inst 0x65a60d1a // fmla z26.s, p3/m, z8.s, z6.s + .inst 0x65a70d3b // fmla z27.s, p3/m, z9.s, z7.s + .inst 0x65ac0d5c // fmla z28.s, p3/m, z10.s, z12.s + .inst 0x65ad0d7d // fmla z29.s, p3/m, z11.s, z13.s + .inst 0x04225082 // addvl x2, x2, #4 + add x13, x13, x22 + + cbz x27, TILE2_ADD_DSTV + .inst 0xa4003f65 // ld1rqb {z5.b}, p7/z, [x27] // input dequant bias + .inst 0xa0404b88 // ld1w {z8.s, z9.s}, pn10/z, [x28] // weight kernel sum + .inst 0x64a5011a // fmla z26.s, z8.s, z5.s[0] + .inst 0x64a5013b // fmla z27.s, z9.s, z5.s[0] + .inst 0x64ad011c // fmla z28.s, z8.s, z5.s[1] + .inst 0x64ad013d // fmla z29.s, z9.s, z5.s[1] + add x27, x27, x22 + add x23, x23, x22 + .inst 0x043c505c // addvl x28, x28, #2 + + TILE2_ADD_DSTV: + cbnz x15, TILE2_BLOCKNUM + + TILE2_STORE: + lsl x15, x5, #3 // ocRemain + .inst 0x25af47f1 // whilelt pn9.s, xzr, x15, vlx2 + .inst 0xa0404520 // ld1w {z0.s, z1.s}, pn9/z, [x9] // bias + .inst 0x6580035a // fadd z26.s, z26.s, z0.s + .inst 0x6581037b // fadd z27.s, z27.s, z1.s + .inst 0x6580039c // fadd z28.s, z28.s, z0.s + .inst 0x658103bd // fadd z29.s, z29.s, z1.s + .inst 0x04295049 // addvl x9, x9, #2 + + .inst 0xc120e342 // fcvt z2.h, {z26.s-z27.s} + .inst 0xc120e383 // fcvt z3.h, {z28.s-z29.s} + .inst 0xc17fc3c2 // fclamp {z2.h-z3.h}, z30.h, z31.h + .inst 0xc123d440 // zip {z0.q, z1.q}, z2.q, z3.q + + cmp x5, #4 + bge TILE2_STORE32 + cmp x5, #3 + beq TILE2_STORE24 + cmp x5, #2 + beq TILE2_STORE16 + + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + b End + + TILE2_STORE16: + add x11, x0, x4 + .inst 0x0522cc02 // mov z2.b, p3/m, z0.b + .inst 0x05240002 // ext z2.b, z2.b, z0.b, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe400f562 // st1b {z2.b}, p5, [x11] + b End + + TILE2_STORE24: + add x11, x0, x4 + add x10, x0, x4, LSL #1 + .inst 0x0522cc02 // mov z2.b, p3/m, z0.b + .inst 0x05240002 // 
ext z2.b, z2.b, z0.b, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe400f562 // st1b {z2.b}, p5, [x11] + .inst 0xe400f541 // st1b {z1.b}, p5, [x10] + b End + + TILE2_STORE32: + add x11, x0, x4 + add x10, x0, x4, LSL #1 + add x13, x11, x4, LSL #1 + .inst 0x0522cc02 // mov z2.b, p3/m, z0.b + .inst 0x0523cc23 // mov z3.b, p3/m, z1.b + .inst 0x05240002 // ext z2.b, z2.b, z0.b, #32 + .inst 0x05240023 // ext z3.b, z3.b, z1.b, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe400f562 // st1b {z2.b}, p5, [x11] + .inst 0xe400f541 // st1b {z1.b}, p5, [x10] + .inst 0xe400f5a3 // st1b {z3.b}, p5, [x13] + + TILE2_Dz_End: + subs x5, x5, #4 + add x0, x0, x4, LSL #2 + mov x13, x19 + mov x23, x21 + mov x27, x20 + beq End + b LoopDz_TILE2 + +TILE_1: + cmp x7, #1 + blt End + mov x19, x13 // input kernel sum + mov x21, x23 // input dequant scale + mov x20, x27 // input dequant bias + + .inst 0x84c0b9de // ld1rh {z30.h}, p6/z, [x14] + .inst 0x84c1b9df // ld1rh {z31.h}, p6/z, [x14, #2] + + +LoopDz_TILE1: + .inst 0x25b8c01c // mov z28.s, #0 + .inst 0x25b8c01d // mov z29.s, #0 + + mov w8, #0 + mov x11, x1 // src + mov x15, x26 // blockid +TILE1_BLOCKNUM: + mov x10, x3 // src_depth_quad + +.inst 0xc00800ff // zero {za} + + LoopSz_TILE_1: + .inst 0xa4003162 // ld1rqb {z2.b}, p4/z, [x11] // src + .inst 0xa0400044 // ld1b {z4.b-z5.b}, pn8/z, [x2] // weight + // matmul + .inst 0xc15210a0 // sdot za.s[w8, 0, VGx2], {z4.b-z5.b}, z2.b[0] + subs x10, x10, #1 + add x11, x11, x22 + .inst 0x04225042 // addvl x2, x2, #2 + + bne LoopSz_TILE_1 + +LoopSzEnd_TILE_1: + sub x15, x15, #1 + .inst 0xc0060808 // mova {z8.s-z9.s}, za.s[w8, 0, VGx2] + .inst 0x6594a108 // scvtf z8.s, p0/m, z8.s + .inst 0x6594a129 // scvtf z9.s, p0/m, z9.s + + .inst 0xa040c040 // ld1w {z0.s-z3.s}, pn8/z, [x2] // weight scale&bias + .inst 0x8540c1a4 // ld1rw {z4.s}, p0/z, [x13] // input kernel sum + .inst 0x8540c2e5 // ld1rw {z5.s}, p0/z, [x23] // input kernel scale + + .inst 0x65850800 // fmul z0.s, z0.s, z5.s 
+ .inst 0x65850821 // fmul z1.s, z1.s, z5.s + .inst 0x64a4005c // fmla z28.s, z2.s, z4.s[0] + .inst 0x64a4007d // fmla z29.s, z3.s, z4.s[0] + .inst 0x65a0011c // fmla z28.s, p0/m, z8.s, z0.s + .inst 0x65a1013d // fmla z29.s, p0/m, z9.s, z1.s + .inst 0x04225082 // addvl x2, x2, #4 + add x13, x13, x22 + + cbz x27, TILE1_ADD_DSTV + .inst 0x8540c365 // ld1rw {z5.s}, p0/z, [x27] // input dequant bias + .inst 0xa0404b88 // ld1w {z8.s, z9.s}, pn10/z, [x28] // weight kernel sum + .inst 0x64a5011c // fmla z28.s, z8.s, z5.s[0] + .inst 0x64a5013d // fmla z29.s, z9.s, z5.s[0] + add x27, x27, x22 + add x23, x23, x22 + .inst 0x043c505c // addvl x28, x28, #2 + + TILE1_ADD_DSTV: + cbnz x15, TILE1_BLOCKNUM + + TILE1_STORE: + lsl x15, x5, #3 // ocRemain + .inst 0x25af47f1 // whilelt pn9.s, xzr, x15, vlx2 + .inst 0xa040453a // ld1w {z26.s, z27.s}, pn9/z, [x9] // bias + .inst 0x659a039c // fadd z28.s, z28.s, z26.s + .inst 0x659b03bd // fadd z29.s, z29.s, z27.s + .inst 0x04295049 // addvl x9, x9, #2 + .inst 0xc120e382 // fcvt z2.h, {z28.s-z29.s} + .inst 0x647f27c2 // fclamp z2.h, z30.h, z31.h + + cmp x5, #4 + bge TILE1_STORE32 + cmp x5, #3 + beq TILE1_STORE24 + cmp x5, #2 + beq TILE1_STORE16 + + .inst 0xe400f402 // st1b {z2.b}, p5, [x0] + b End + + TILE1_STORE16: + .inst 0x05702043 // dup z3.q, z2.q[1] + .inst 0xe400f402 // st1b {z2.b}, p5, [x0] + .inst 0xe4045403 // st1b {z3.b}, p5, [x0, x4] + b End + + TILE1_STORE24: + add x11, x0, x4, LSL #1 + .inst 0x05702043 // dup z3.q, z2.q[1] + .inst 0x05b02044 // dup z4.q, z2.q[2] + .inst 0xe400f402 // st1b {z2.b}, p5, [x0] + .inst 0xe4045403 // st1b {z3.b}, p5, [x0, x4] + .inst 0xe400f564 // st1b {z4.b}, p5, [x11] + b End + + TILE1_STORE32: + subs x5, x5, #4 + add x11, x0, x4, LSL #1 + .inst 0x05702043 // dup z3.q, z2.q[1] + .inst 0x05b02044 // dup z4.q, z2.q[2] + .inst 0x05f02045 // dup z5.q, z2.q[3] + .inst 0xe400f402 // st1b {z2.b}, p5, [x0] + .inst 0xe4045403 // st1b {z3.b}, p5, [x0, x4] + .inst 0xe400f564 // st1b {z4.b}, p5, [x11] + 
.inst 0xe4045565 // st1b {z5.b}, p5, [x11, x4] + + cbz x5, End + add x0, x0, x4, LSL #2 + mov x13, x19 + mov x23, x21 + mov x27, x20 + b LoopDz_TILE1 + End: .inst 0xd503467f // smstop diff --git a/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScale16x32_SME2_w8_Fp32.S b/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScale16x32_SME2_w8_Fp32.S index 8231e0a5ed..7020ff0329 100644 --- a/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScale16x32_SME2_w8_Fp32.S +++ b/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScale16x32_SME2_w8_Fp32.S @@ -113,7 +113,7 @@ struct QuanPostTreatParameters { // const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, // const QuanPostTreatParameters* parameters, size_t realDstCount); -//Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step x5:dst_depth_quad, x6: parameters, x7: realDstCount +//Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step x5:dst_depth_quad, x6: parameters, x7: realDstCount // sme2 Ep=16, LP=4, HP=16 stp x29, x30, [sp, #-320]! 
@@ -152,6 +152,8 @@ mov x19, #32 // HP=32 .inst 0x25207810 // ptrue pn8.b // all int8 valid .inst 0x253617e7 // whilelt p7.b, xzr, x22 // eSize * LP int8 valid .inst 0x25b347f2 // whilelt pn10.s, xzr, x19, vlx2 // 32 float valid +.inst 0x2518e082 // ptrue p2.b, #4 // 4 int8_t valid +.inst 0x2518e125 // ptrue p5.b, vl16 // 16 int8_t valid reluRead: /* relu min/max*/ @@ -166,7 +168,8 @@ add x14, x6, #16 .inst 0x84448ddc // ld1rb {z28.b}, p3/z, [x14, #4] ESIZE: - mov x6, x0 // dst + cmp x7, #2 + ble TILE_2 mov x19, x13 // input kernel sum mov x21, x23 // input dequant scale mov x20, x27 // input dequant bias @@ -327,29 +330,29 @@ bne LoopL beq STORE8 STORE12: - add x10, x6, x4, LSL #1 // + 2*x4 - .inst 0xa0608cd0 // st1b {z16.b-z19.b}, pn11, [x6] - .inst 0xa0248cd4 // st1b {z20.b-z23.b}, pn11, [x6, x4] + add x10, x0, x4, LSL #1 // + 2*x4 + .inst 0xa0608c10 // st1b {z16.b-z19.b}, pn11, [x0] + .inst 0xa0248c14 // st1b {z20.b-z23.b}, pn11, [x0, x4] .inst 0xa0608d58 // st1b {z24.b-z27.b}, pn11, [x10] b End STORE8: - .inst 0xa0608cd0 // st1b {z16.b-z19.b}, pn11, [x6] - .inst 0xa0248cd4 // st1b {z20.b-z23.b}, pn11, [x6, x4] + .inst 0xa0608c10 // st1b {z16.b-z19.b}, pn11, [x0] + .inst 0xa0248c14 // st1b {z20.b-z23.b}, pn11, [x0, x4] b End STORE4: - .inst 0xa0608cd0 // st1b {z16.b-z19.b}, pn11, [x6] + .inst 0xa0608c10 // st1b {z16.b-z19.b}, pn11, [x0] b End STORE16: - add x10, x6, x4, LSL #1 // + 2*x4 + add x10, x0, x4, LSL #1 // + 2*x4 subs x5, x5, #4 - .inst 0xa0608cd0 // st1b {z16.b-z19.b}, pn11, [x6] - .inst 0xa0248cd4 // st1b {z20.b-z23.b}, pn11, [x6, x4] + .inst 0xa0608c10 // st1b {z16.b-z19.b}, pn11, [x0] + .inst 0xa0248c14 // st1b {z20.b-z23.b}, pn11, [x0, x4] .inst 0xa0608d58 // st1b {z24.b-z27.b}, pn11, [x10] .inst 0xa0248d40 // st1b {z0.b-z3.b}, pn11, [x10, x4] - add x6, x6, x4, LSL #2 + add x0, x0, x4, LSL #2 beq End /* oc:16~31 */ @@ -380,29 +383,29 @@ bne LoopL beq STORE24 STORE20: - .inst 0xa0608cd0 // st1b {z16.b-z19.b}, pn11, [x6] + .inst 0xa0608c10 // st1b 
{z16.b-z19.b}, pn11, [x0] b End STORE28: - add x10, x6, x4, LSL #1 // + 2*x4 - .inst 0xa0608cd0 // st1b {z16.b-z19.b}, pn11, [x6] - .inst 0xa0248cd4 // st1b {z20.b-z23.b}, pn11, [x6, x4] + add x10, x0, x4, LSL #1 // + 2*x4 + .inst 0xa0608c10 // st1b {z16.b-z19.b}, pn11, [x0] + .inst 0xa0248c14 // st1b {z20.b-z23.b}, pn11, [x0, x4] .inst 0xa0608d58 // st1b {z24.b-z27.b}, pn11, [x10] b End STORE24: - .inst 0xa0608cd0 // st1b {z16.b-z19.b}, pn11, [x6] - .inst 0xa0248cd4 // st1b {z20.b-z23.b}, pn11, [x6, x4] + .inst 0xa0608c10 // st1b {z16.b-z19.b}, pn11, [x0] + .inst 0xa0248c14 // st1b {z20.b-z23.b}, pn11, [x0, x4] b End STORE32: - add x10, x6, x4, LSL #1 // + 2*x4 + add x10, x0, x4, LSL #1 // + 2*x4 subs x5, x5, #4 - .inst 0xa0608cd0 // st1b {z16.b-z19.b}, pn11, [x6] - .inst 0xa0248cd4 // st1b {z20.b-z23.b}, pn11, [x6, x4] + .inst 0xa0608c10 // st1b {z16.b-z19.b}, pn11, [x0] + .inst 0xa0248c14 // st1b {z20.b-z23.b}, pn11, [x0, x4] .inst 0xa0608d58 // st1b {z24.b-z27.b}, pn11, [x10] .inst 0xa0248d40 // st1b {z0.b-z3.b}, pn11, [x10, x4] - add x6, x6, x4, LSL #2 + add x0, x0, x4, LSL #2 beq End b HP_END @@ -442,29 +445,29 @@ bne LoopL beq HP_STORE_INT8_8 HP_STORE_INT8_4: - .inst 0xe400fcd0 // st1b {z16.b}, p7, [x6] + .inst 0xe400fc10 // st1b {z16.b}, p7, [x0] b End HP_STORE_INT8_12: - add x8, x6, x4, LSL #1 // + 2*x4 - .inst 0xe400fcd0 // st1b {z16.b}, p7, [x6] - .inst 0xe4045cd1 // st1b {z17.b}, p7, [x6, x4] + add x8, x0, x4, LSL #1 // + 2*x4 + .inst 0xe400fc10 // st1b {z16.b}, p7, [x0] + .inst 0xe4045c11 // st1b {z17.b}, p7, [x0, x4] .inst 0xe400fd12 // st1b {z18.b}, p7, [x8] b End HP_STORE_INT8_8: - .inst 0xe400fcd0 // st1b {z16.b}, p7, [x6] - .inst 0xe4045cd1 // st1b {z17.b}, p7, [x6, x4] + .inst 0xe400fc10 // st1b {z16.b}, p7, [x0] + .inst 0xe4045c11 // st1b {z17.b}, p7, [x0, x4] b End HP_STORE_INT8_16: - add x8, x6, x4, LSL #1 // + 2*x4 + add x8, x0, x4, LSL #1 // + 2*x4 subs x5, x5, #4 - .inst 0xe400fcd0 // st1b {z16.b}, p7, [x6] - .inst 0xe4045cd1 // st1b 
{z17.b}, p7, [x6, x4] + .inst 0xe400fc10 // st1b {z16.b}, p7, [x0] + .inst 0xe4045c11 // st1b {z17.b}, p7, [x0, x4] .inst 0xe400fd12 // st1b {z18.b}, p7, [x8] .inst 0xe4045d13 // st1b {z19.b}, p7, [x8, x4] - add x6, x6, x4, LSL #2 + add x0, x0, x4, LSL #2 beq End @@ -501,29 +504,29 @@ bne LoopL beq HP_STORE_INT8_24 HP_STORE_INT8_20: - .inst 0xe400fcd0 // st1b {z16.b}, p7, [x6] + .inst 0xe400fc10 // st1b {z16.b}, p7, [x0] b End HP_STORE_INT8_28: - add x8, x6, x4, LSL #1 // + 2*x4 - .inst 0xe400fcd0 // st1b {z16.b}, p7, [x6] - .inst 0xe4045cd1 // st1b {z17.b}, p7, [x6, x4] + add x8, x0, x4, LSL #1 // + 2*x4 + .inst 0xe400fc10 // st1b {z16.b}, p7, [x0] + .inst 0xe4045c11 // st1b {z17.b}, p7, [x0, x4] .inst 0xe400fd12 // st1b {z18.b}, p7, [x8] b End HP_STORE_INT8_24: - .inst 0xe400fcd0 // st1b {z16.b}, p7, [x6] - .inst 0xe4045cd1 // st1b {z17.b}, p7, [x6, x4] + .inst 0xe400fc10 // st1b {z16.b}, p7, [x0] + .inst 0xe4045c11 // st1b {z17.b}, p7, [x0, x4] b End HP_STORE_INT8_32: - add x8, x6, x4, LSL #1 // + 2*x4 + add x8, x0, x4, LSL #1 // + 2*x4 subs x5, x5, #4 - .inst 0xe400fcd0 // st1b {z16.b}, p7, [x6] - .inst 0xe4045cd1 // st1b {z17.b}, p7, [x6, x4] + .inst 0xe400fc10 // st1b {z16.b}, p7, [x0] + .inst 0xe4045c11 // st1b {z17.b}, p7, [x0, x4] .inst 0xe400fd12 // st1b {z18.b}, p7, [x8] .inst 0xe4045d13 // st1b {z19.b}, p7, [x8, x4] - add x6, x6, x4, LSL #2 + add x0, x0, x4, LSL #2 beq End @@ -534,6 +537,681 @@ bne LoopL mov x27, x20 b LoopH +TILE_2: + cmp x7, #1 + beq TILE_1 + mov x19, x13 // input kernel sum + mov x21, x23 // input dequant scale + mov x20, x27 // input dequant bias + + mov x15, #32 + .inst 0x252f17e5 // whilelt p5.b, xzr, x15 + +LoopDz_TILE2: + .inst 0x25b8c01a // mov z26.s, #0 + .inst 0x25b8c01b // mov z27.s, #0 + .inst 0x25b8c01e // mov z30.s, #0 + .inst 0x25b8c01f // mov z31.s, #0 + mov w8, #0 + mov x11, x1 // src + mov x15, x26 + +TILE2_BLOCKNUM: + mov x10, x3 // src_depth_quad + .inst 0xc00800ff // zero {za} + + LoopSz_TILE2: + .inst 0xa4003d62 
// ld1rqb {z2.b}, p7/z, [x11] // src + .inst 0xa0400044 // ld1b {z4.b-z5.b}, pn8/z, [x2] // weight + // matmul + .inst 0xc15210a0 // sdot za.s[w8, 0, VGx2], {z4.b-z5.b}, z2.b[0] + .inst 0xc15214a4 // sdot za.s[w8, 4, VGx2], {z4.b-z5.b}, z2.b[1] + subs x10, x10, #1 + add x11, x11, x22 + .inst 0x04225042 // addvl x2, x2, #2 + + bne LoopSz_TILE2 + + sub x15, x15, #1 + .inst 0xc0060808 // mova {z8.s-z9.s}, za.s[w8, 0, VGx2] + .inst 0xc006088a // mova {z10.s-z11.s}, za.s[w8, 4, VGx2] + .inst 0xc132e108 // scvtf {z8.s-z11.s}, {z8.s-z11.s} + + .inst 0xa0408040 // ld1b {z0.b-z3.b}, pn8/z, [x2] // weight scale&bias + .inst 0xa4003da4 // ld1rqb {z4.b}, p7/z, [x13] // input kernel sum + .inst 0xa4003ee5 // ld1rqb {z5.b}, p7/z, [x23] // input kernel scale + + .inst 0x64a52006 // fmul z6.s, z0.s, z5.s[0] // e0 + .inst 0x64a52027 // fmul z7.s, z1.s, z5.s[0] + .inst 0x64ad200c // fmul z12.s, z0.s, z5.s[1] // e1 + .inst 0x64ad202d // fmul z13.s, z1.s, z5.s[1] + + .inst 0x64a4005a // fmla z26.s, z2.s, z4.s[0] // e0 + .inst 0x64a4007b // fmla z27.s, z3.s, z4.s[0] + .inst 0x64ac005e // fmla z30.s, z2.s, z4.s[1] // e1 + .inst 0x64ac007f // fmla z31.s, z3.s, z4.s[1] + + .inst 0x65a60d1a // fmla z26.s, p3/m, z8.s, z6.s + .inst 0x65a70d3b // fmla z27.s, p3/m, z9.s, z7.s + .inst 0x65ac0d5e // fmla z30.s, p3/m, z10.s, z12.s + .inst 0x65ad0d7f // fmla z31.s, p3/m, z11.s, z13.s + .inst 0x04225082 // addvl x2, x2, #4 + add x13, x13, x22 + + cbz x27, TILE2_ADD_DSTV + .inst 0xa4003f65 // ld1rqb {z5.b}, p7/z, [x27] // input dequant bias + .inst 0xa0404b88 // ld1w {z8.s, z9.s}, pn10/z, [x28] // weight kernel sum + .inst 0x64a5011a // fmla z26.s, z8.s, z5.s[0] + .inst 0x64a5013b // fmla z27.s, z9.s, z5.s[0] + .inst 0x64ad011e // fmla z30.s, z8.s, z5.s[1] + .inst 0x64ad013f // fmla z31.s, z9.s, z5.s[1] + add x27, x27, x22 + add x23, x23, x22 + .inst 0x043c505c // addvl x28, x28, #2 + + TILE2_ADD_DSTV: + cbnz x15, TILE2_BLOCKNUM + + lsl x15, x5, #2 // ocRemain + .inst 0x25af47f1 // whilelt pn9.s, 
xzr, x15, vlx2 + .inst 0xa0404520 // ld1w {z0.s, z1.s}, pn9/z, [x9] // bias + .inst 0x6580035a // fadd z26.s, z26.s, z0.s + .inst 0x6581037b // fadd z27.s, z27.s, z1.s + .inst 0x658003de // fadd z30.s, z30.s, z0.s + .inst 0x658103ff // fadd z31.s, z31.s, z1.s + .inst 0x04295049 // addvl x9, x9, #2 + + cbz x28, TILE2_Int8_Output + + TILE2_STORE: + + .inst 0xc1bdc39a // fclamp {z26.s-z27.s}, z28.s, z29.s + .inst 0xc1bdc39e // fclamp {z30.s-z31.s}, z28.s, z29.s + // z0: 0~3,4_7 z1:8~11,12~15 + .inst 0xc13ed740 // zip {z0.q-z1.q}, z26.q, z30.q // (0,0)(0,1)(0,2)(0,3)(1,0)(1,1)(1,2)(1,3)...(1,12)(1,13)(1,14)(1,15) + .inst 0xc13fd762 // zip {z2.q-z3.q}, z27.q, z31.q // (0,16)(0,17)(0,18)(0,19)(1,16)(1,17)(1,18)(1,19)...(1,28)(1,29)(1,30)(1,31) + + cmp x5, #8 + bge TILE2_STORE32 + cmp x5, #7 + beq TILE2_STORE28 + cmp x5, #6 + beq TILE2_STORE24 + cmp x5, #5 + beq TILE2_STORE20 + cmp x5, #4 + beq TILE2_STORE16 + cmp x5, #3 + beq TILE2_STORE12 + cmp x5, #2 + beq TILE2_STORE8 + + TILE2_STORE4: + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + b End + + TILE2_STORE32: + add x10, x0, x4, LSL #1 + add x8, x0, x4, LSL #2 + add x13, x10, x4, LSL #2 + .inst 0x05640005 // ext z5.b, {z0.b, z1.b}, #32 + .inst 0x05640027 // ext z7.b, {z1.b, z2.b}, #32 + .inst 0x05640049 // ext z9.b, {z2.b, z3.b}, #32 + .inst 0x0564006b // ext z11.b, {z3.b, z4.b}, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe4045405 // st1b {z5.b}, p5, [x0, x4] + .inst 0xe400f541 // st1b {z1.b}, p5, [x10] + .inst 0xe4045547 // st1b {z7.b}, p5, [x10, x4] + .inst 0xe400f502 // st1b {z2.b}, p5, [x8] + .inst 0xe4045509 // st1b {z9.b}, p5, [x8, x4] + .inst 0xe400f5a3 // st1b {z3.b}, p5, [x13] + .inst 0xe40455ab // st1b {z11.b}, p5, [x13, x4] + b TILE2_Dz_End + + TILE2_STORE28: + add x10, x0, x4, LSL #1 + add x8, x0, x4, LSL #2 + add x13, x10, x4, LSL #2 + .inst 0x05640005 // ext z5.b, {z0.b, z1.b}, #32 + .inst 0x05640027 // ext z7.b, {z1.b, z2.b}, #32 + .inst 0x05640049 // ext z9.b, {z2.b, z3.b}, #32 + .inst 
0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe4045405 // st1b {z5.b}, p5, [x0, x4] + .inst 0xe400f541 // st1b {z1.b}, p5, [x10] + .inst 0xe4045547 // st1b {z7.b}, p5, [x10, x4] + .inst 0xe400f502 // st1b {z2.b}, p5, [x8] + .inst 0xe4045509 // st1b {z9.b}, p5, [x8, x4] + .inst 0xe400f5a3 // st1b {z3.b}, p5, [x13] + b End + + TILE2_STORE24: + add x10, x0, x4, LSL #1 + add x8, x0, x4, LSL #2 + .inst 0x05640005 // ext z5.b, {z0.b, z1.b}, #32 + .inst 0x05640027 // ext z7.b, {z1.b, z2.b}, #32 + .inst 0x05640049 // ext z9.b, {z2.b, z3.b}, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe4045405 // st1b {z5.b}, p5, [x0, x4] + .inst 0xe400f541 // st1b {z1.b}, p5, [x10] + .inst 0xe4045547 // st1b {z7.b}, p5, [x10, x4] + .inst 0xe400f502 // st1b {z2.b}, p5, [x8] + .inst 0xe4045509 // st1b {z9.b}, p5, [x8, x4] + b End + + TILE2_STORE20: + add x10, x0, x4, LSL #1 + add x8, x0, x4, LSL #2 + .inst 0x05640005 // ext z5.b, {z0.b, z1.b}, #32 + .inst 0x05640027 // ext z7.b, {z1.b, z2.b}, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe4045405 // st1b {z5.b}, p5, [x0, x4] + .inst 0xe400f541 // st1b {z1.b}, p5, [x10] + .inst 0xe4045547 // st1b {z7.b}, p5, [x10, x4] + .inst 0xe400f502 // st1b {z2.b}, p5, [x8] + b End + + TILE2_STORE16: + add x10, x0, x4, LSL #1 + .inst 0x05640005 // ext z5.b, {z0.b, z1.b}, #32 + .inst 0x05640027 // ext z7.b, {z1.b, z2.b}, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe4045405 // st1b {z5.b}, p5, [x0, x4] + .inst 0xe400f541 // st1b {z1.b}, p5, [x10] + .inst 0xe4045547 // st1b {z7.b}, p5, [x10, x4] + b End + + TILE2_STORE12: + add x10, x0, x4, LSL #1 + .inst 0x05640005 // ext z5.b, {z0.b, z1.b}, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe4045405 // st1b {z5.b}, p5, [x0, x4] + .inst 0xe400f541 // st1b {z1.b}, p5, [x10] + b End + + TILE2_STORE8: + .inst 0x05640005 // ext z5.b, {z0.b, z1.b}, #32 + .inst 0xe400f400 // st1b {z0.b}, p5, [x0] + .inst 0xe4045405 // st1b {z5.b}, p5, [x0, x4] + b End + + 
TILE2_Int8_Output: + + .inst 0x65912f44 // fcmlt p4.s, p3/z, z26.s, #0.0 + .inst 0x65912f66 // fcmlt p6.s, p3/z, z27.s, #0.0 + .inst 0x6599901a // fsub z26.s, p4/m, z26.s, #0.5 + .inst 0x6599981b // fsub z27.s, p6/m, z27.s, #0.5 + .inst 0x65902f44 // fcmge p4.s, p3/z, z26.s, #0.0 + .inst 0x65902f66 // fcmge p6.s, p3/z, z27.s, #0.0 + .inst 0x6598901a // fadd z26.s, p4/m, z26.s, #0.5 + .inst 0x6598981b // fadd z27.s, p6/m, z27.s, #0.5 + .inst 0x659caf5a // fcvtzs z26.s, p3/m, z26.s + .inst 0x659caf7b // fcvtzs z27.s, p3/m, z27.s + + .inst 0x65912fc4 // fcmlt p4.s, p3/z, z30.s, #0.0 + .inst 0x65912fe6 // fcmlt p6.s, p3/z, z31.s, #0.0 + .inst 0x6599901e // fsub z30.s, p4/m, z30.s, #0.5 + .inst 0x6599981f // fsub z31.s, p6/m, z31.s, #0.5 + .inst 0x65902fc4 // fcmge p4.s, p3/z, z30.s, #0.0 + .inst 0x65902fe6 // fcmge p6.s, p3/z, z31.s, #0.0 + .inst 0x6598901e // fadd z30.s, p4/m, z30.s, #0.5 + .inst 0x6598981f // fadd z31.s, p6/m, z31.s, #0.5 + .inst 0x659cafde // fcvtzs z30.s, p3/m, z30.s + .inst 0x659cafff // fcvtzs z31.s, p3/m, z31.s + + .inst 0x0520cf40 // mov z0.b, p3/m, z26.b + .inst 0x0521cf61 // mov z1.b, p3/m, z27.b + .inst 0x0522cfc2 // mov z2.b, p3/m, z30.b + .inst 0x0523cfe3 // mov z3.b, p3/m, z31.b + + .inst 0xc133e004 // sqcvt z4.b, {z0.s-z3.s} + .inst 0x05640086 // ext z6.b, {z4.b, z5.b}, #32 + .inst 0x05a66090 // zip1 z16.s, z4.s, z6.s + .inst 0x441dc390 // sclamp z16.b, z28.b, z29.b + + cmp x5, #8 + bge TILE2_32 + cmp x5, #7 + beq TILE2_28 + cmp x5, #6 + beq TILE2_24 + cmp x5, #5 + beq TILE2_20 + cmp x5, #4 + beq TILE2_16 + cmp x5, #3 + beq TILE2_12 + cmp x5, #2 + beq TILE2_8 + cmp x5, #1 + beq TILE2_4 + + TILE2_32: + add x8, x0, x4, LSL #1 // + 2*x4 + add x10, x0, x4, LSL #2 // + 4*x4 + add x14, x8, x4, LSL #2 // + 6*x4 + .inst 0x05610211 // ext z17.b, {z16.b, z17.b}, #8 + .inst 0x05620212 // ext z18.b, {z16.b, z17.b}, #16 + .inst 0x05630213 // ext z19.b, {z16.b, z17.b}, #24 + .inst 0x05640214 // ext z20.b, {z16.b, z17.b}, #32 + .inst 0x05650215 // ext 
z21.b, {z16.b, z17.b}, #40 + .inst 0x05660216 // ext z22.b, {z16.b, z17.b}, #48 + .inst 0x05670217 // ext z23.b, {z16.b, z17.b}, #56 + + .inst 0xe400fc10 // st1b {z16.b}, p7, [x0] + .inst 0xe4045c11 // st1b {z17.b}, p7, [x0, x4] + .inst 0xe400fd12 // st1b {z18.b}, p7, [x8] + .inst 0xe4045d13 // st1b {z19.b}, p7, [x8, x4] + .inst 0xe400fd54 // st1b {z20.b}, p7, [x10] + .inst 0xe4045d55 // st1b {z21.b}, p7, [x10, x4] + .inst 0xe400fdd6 // st1b {z22.b}, p7, [x14] + .inst 0xe4045dd7 // st1b {z23.b}, p7, [x14, x4] + b TILE2_Dz_End + + TILE2_28: + add x8, x0, x4, LSL #1 // + 2*x4 + add x10, x0, x4, LSL #2 // + 4*x4 + add x14, x8, x4, LSL #2 // + 6*x4 + .inst 0x05610211 // ext z17.b, {z16.b, z17.b}, #8 + .inst 0x05620212 // ext z18.b, {z16.b, z17.b}, #16 + .inst 0x05630213 // ext z19.b, {z16.b, z17.b}, #24 + .inst 0x05640214 // ext z20.b, {z16.b, z17.b}, #32 + .inst 0x05650215 // ext z21.b, {z16.b, z17.b}, #40 + .inst 0x05660216 // ext z22.b, {z16.b, z17.b}, #48 + .inst 0xe400fc10 // st1b {z16.b}, p7, [x0] + .inst 0xe4045c11 // st1b {z17.b}, p7, [x0, x4] + .inst 0xe400fd12 // st1b {z18.b}, p7, [x8] + .inst 0xe4045d13 // st1b {z19.b}, p7, [x8, x4] + .inst 0xe400fd54 // st1b {z20.b}, p7, [x10] + .inst 0xe4045d55 // st1b {z21.b}, p7, [x10, x4] + .inst 0xe400fdd6 // st1b {z22.b}, p7, [x14] + b End + + TILE2_24: + add x8, x0, x4, LSL #1 // + 2*x4 + add x10, x0, x4, LSL #2 // + 4*x4 + .inst 0x05610211 // ext z17.b, {z16.b, z17.b}, #8 + .inst 0x05620212 // ext z18.b, {z16.b, z17.b}, #16 + .inst 0x05630213 // ext z19.b, {z16.b, z17.b}, #24 + .inst 0x05640214 // ext z20.b, {z16.b, z17.b}, #32 + .inst 0x05650215 // ext z21.b, {z16.b, z17.b}, #40 + .inst 0xe400fc10 // st1b {z16.b}, p7, [x0] + .inst 0xe4045c11 // st1b {z17.b}, p7, [x0, x4] + .inst 0xe400fd12 // st1b {z18.b}, p7, [x8] + .inst 0xe4045d13 // st1b {z19.b}, p7, [x8, x4] + .inst 0xe400fd54 // st1b {z20.b}, p7, [x10] + .inst 0xe4045d55 // st1b {z21.b}, p7, [x10, x4] + b End + + TILE2_20: + add x8, x0, x4, LSL #1 // + 2*x4 + 
add x10, x0, x4, LSL #2 // + 4*x4 + .inst 0x05610211 // ext z17.b, {z16.b, z17.b}, #8 + .inst 0x05620212 // ext z18.b, {z16.b, z17.b}, #16 + .inst 0x05630213 // ext z19.b, {z16.b, z17.b}, #24 + .inst 0x05640214 // ext z20.b, {z16.b, z17.b}, #32 + .inst 0xe400fc10 // st1b {z16.b}, p7, [x0] + .inst 0xe4045c11 // st1b {z17.b}, p7, [x0, x4] + .inst 0xe400fd12 // st1b {z18.b}, p7, [x8] + .inst 0xe4045d13 // st1b {z19.b}, p7, [x8, x4] + .inst 0xe400fd54 // st1b {z20.b}, p7, [x10] + b End + + TILE2_16: + add x8, x0, x4, LSL #1 // + 2*x4 + .inst 0x05610211 // ext z17.b, {z16.b, z17.b}, #8 + .inst 0x05620212 // ext z18.b, {z16.b, z17.b}, #16 + .inst 0x05630213 // ext z19.b, {z16.b, z17.b}, #24 + .inst 0xe400fc10 // st1b {z16.b}, p7, [x0] + .inst 0xe4045c11 // st1b {z17.b}, p7, [x0, x4] + .inst 0xe400fd12 // st1b {z18.b}, p7, [x8] + .inst 0xe4045d13 // st1b {z19.b}, p7, [x8, x4] + b End + + TILE2_12: + add x8, x0, x4, LSL #1 // + 2*x4 + .inst 0x05610211 // ext z17.b, {z16.b, z17.b}, #8 + .inst 0x05620212 // ext z18.b, {z16.b, z17.b}, #16 + .inst 0xe400fc10 // st1b {z16.b}, p7, [x0] + .inst 0xe4045c11 // st1b {z17.b}, p7, [x0, x4] + .inst 0xe400fd12 // st1b {z18.b}, p7, [x8] + b End + + TILE2_8: + .inst 0x05610211 // ext z17.b, {z16.b, z17.b}, #8 + .inst 0xe400fc10 // st1b {z16.b}, p7, [x0] + .inst 0xe4045c11 // st1b {z17.b}, p7, [x0, x4] + b End + + TILE2_4: + .inst 0xe400fc10 // st1b {z16.b}, p7, [x0] + b End + + + TILE2_Dz_End: + subs x5, x5, #8 + cbz x5, End + add x0, x0, x4, LSL #3 + mov x13, x19 + mov x23, x21 + mov x27, x20 + b LoopDz_TILE2 + + +TILE_1: + cmp x7, #1 + blt End + mov x19, x13 // input kernel sum + mov x21, x23 // input dequant scale + mov x20, x27 // input dequant bias + +LoopDz_TILE1: + + .inst 0x25b8c01e // mov z30.s, #0 + .inst 0x25b8c01f // mov z31.s, #0 + + mov w8, #0 + mov x11, x1 // src + mov x15, x26 + +TILE1_BLOCKNUM: + mov x10, x3 // src_depth_quad + +.inst 0xc00800ff // zero {za} + + LoopSz_TILE_1: + .inst 0xa4002962 // ld1rqb {z2.b}, p2/z, 
[x11] // src + .inst 0xa0400044 // ld1b {z4.b-z5.b}, pn8/z, [x2] // weight + // matmul + .inst 0xc15210a0 // sdot za.s[w8, 0, VGx2], {z4.b-z5.b}, z2.b[0] + subs x10, x10, #1 + add x11, x11, x22 + .inst 0x04225042 // addvl x2, x2, #2 + + bne LoopSz_TILE_1 + +LoopSzEnd_TILE_1: + sub x15, x15, #1 + .inst 0xc0060808 // mova {z8.s-z9.s}, za.s[w8, 0, VGx2] + .inst 0x6594ad08 // scvtf z8.s, p3/m, z8.s + .inst 0x6594ad29 // scvtf z9.s, p3/m, z9.s + + .inst 0xa0408040 // ld1b {z0.b-z3.b}, pn8/z, [x2] // weight scale&bias + .inst 0x8540cda4 // ld1rw {z4.s}, p3/z, [x13] // input kernel sum + .inst 0x8540cee5 // ld1rw {z5.s}, p3/z, [x23] // input kernel scale + + .inst 0x65850800 // fmul z0.s, z0.s, z5.s + .inst 0x65850821 // fmul z1.s, z1.s, z5.s + .inst 0x64a4005e // fmla z30.s, z2.s, z4.s[0] + .inst 0x64a4007f // fmla z31.s, z3.s, z4.s[0] + .inst 0x65a00d1e // fmla z30.s, p3/m, z8.s, z0.s + .inst 0x65a10d3f // fmla z31.s, p3/m, z9.s, z1.s + .inst 0x04225082 // addvl x2, x2, #4 + add x13, x13, x22 + + cbz x27, TILE1_ADD_DSTV + .inst 0x8540cf65 // ld1rw {z5.s}, p3/z, [x27] // input dequant bias + .inst 0xa0404b88 // ld1w {z8.s, z9.s}, pn10/z, [x28] // weight kernel sum + .inst 0x64a5011e // fmla z30.s, z8.s, z5.s[0] + .inst 0x64a5013f // fmla z31.s, z9.s, z5.s[0] + add x27, x27, x22 + add x23, x23, x22 + .inst 0x043c505c // addvl x28, x28, #2 + + TILE1_ADD_DSTV: + cmp x15, #0 + bne TILE1_BLOCKNUM + + lsl x15, x5, #2 // ocRemain + .inst 0x25af47f1 // whilelt pn9.s, xzr, x15, vlx2 + .inst 0xa0404520 // ld1w {z0.s, z1.s}, pn9/z, [x9] // bias + .inst 0x658003de // fadd z30.s, z30.s, z0.s + .inst 0x658103ff // fadd z31.s, z31.s, z1.s + .inst 0x04295049 // addvl x9, x9, #2 + + cbz x28, TILE1_Int8_Output + mov x11, #16 // float output: 1*4*sizeof(float) + + TILE1_STORE: + .inst 0xc1bdc39e // fclamp {z30.s-z31.s}, z28.s, z29.s + + cmp x5, #8 + bge TILE1_STORE32 + cmp x5, #7 + beq TILE1_STORE28 + cmp x5, #6 + beq TILE1_STORE24 + cmp x5, #5 + beq TILE1_STORE20 + cmp x5, #4 + beq 
TILE1_STORE16 + cmp x5, #3 + beq TILE1_STORE12 + cmp x5, #2 + beq TILE1_STORE8 + + TILE1_STORE4: + .inst 0xe400f41e // st1b {z30.b}, p5, [x0] + b End + + TILE1_STORE28: + add x10, x0, x4, LSL #1 + add x8, x0, x4, LSL #2 + add x13, x10, x4, LSL #2 + .inst 0x057023c4 // dup z4.q, z30.q[1] + .inst 0x05b023c5 // dup z5.q, z30.q[2] + .inst 0x05f023c6 // dup z6.q, z30.q[3] + .inst 0x057023e7 // dup z7.q, z31.q[1] + .inst 0x05b023e8 // dup z8.q, z31.q[2] + .inst 0xe400f41e // st1b {z30.b}, p5, [x0] + .inst 0xe4045404 // st1b {z4.b}, p5, [x0, x4] + .inst 0xe400f545 // st1b {z5.b}, p5, [x10] + .inst 0xe4045546 // st1b {z6.b}, p5, [x10, x4] + .inst 0xe400f51f // st1b {z31.b}, p5, [x8] + .inst 0xe4045507 // st1b {z7.b}, p5, [x8, x4] + .inst 0xe400f5a8 // st1b {z8.b}, p5, [x13] + b End + + TILE1_STORE24: + add x10, x0, x4, LSL #1 + add x8, x0, x4, LSL #2 + .inst 0x057023c4 // dup z4.q, z30.q[1] + .inst 0x05b023c5 // dup z5.q, z30.q[2] + .inst 0x05f023c6 // dup z6.q, z30.q[3] + .inst 0x057023e7 // dup z7.q, z31.q[1] + .inst 0xe400f41e // st1b {z30.b}, p5, [x0] + .inst 0xe4045404 // st1b {z4.b}, p5, [x0, x4] + .inst 0xe400f545 // st1b {z5.b}, p5, [x10] + .inst 0xe4045546 // st1b {z6.b}, p5, [x10, x4] + .inst 0xe400f51f // st1b {z31.b}, p5, [x8] + .inst 0xe4045507 // st1b {z7.b}, p5, [x8, x4] + b End + + TILE1_STORE20: + add x10, x0, x4, LSL #1 + add x8, x0, x4, LSL #2 + .inst 0x057023c4 // dup z4.q, z30.q[1] + .inst 0x05b023c5 // dup z5.q, z30.q[2] + .inst 0x05f023c6 // dup z6.q, z30.q[3] + .inst 0xe400f41e // st1b {z30.b}, p5, [x0] + .inst 0xe4045404 // st1b {z4.b}, p5, [x0, x4] + .inst 0xe400f545 // st1b {z5.b}, p5, [x10] + .inst 0xe4045546 // st1b {z6.b}, p5, [x10, x4] + .inst 0xe400f51f // st1b {z31.b}, p5, [x8] + b End + + TILE1_STORE16: + add x10, x0, x4, LSL #1 + .inst 0x057023c4 // dup z4.q, z30.q[1] + .inst 0x05b023c5 // dup z5.q, z30.q[2] + .inst 0x05f023c6 // dup z6.q, z30.q[3] + .inst 0xe400f41e // st1b {z30.b}, p5, [x0] + .inst 0xe4045404 // st1b {z4.b}, p5, [x0, 
x4] + .inst 0xe400f545 // st1b {z5.b}, p5, [x10] + .inst 0xe4045546 // st1b {z6.b}, p5, [x10, x4] + b End + + TILE1_STORE12: + add x10, x0, x4, LSL #1 + .inst 0x057023c4 // dup z4.q, z30.q[1] + .inst 0x05b023c5 // dup z5.q, z30.q[2] + .inst 0xe400f41e // st1b {z30.b}, p5, [x0] + .inst 0xe4045404 // st1b {z4.b}, p5, [x0, x4] + .inst 0xe400f545 // st1b {z5.b}, p5, [x10] + b End + + TILE1_STORE8: + .inst 0x057023c4 // dup z4.q, z30.q[1] + .inst 0xe400f41e // st1b {z30.b}, p5, [x0] + .inst 0xe4045404 // st1b {z4.b}, p5, [x0, x4] + b End + + TILE1_STORE32: + add x10, x0, x4, LSL #1 + add x8, x0, x4, LSL #2 + add x13, x10, x4, LSL #2 + .inst 0x057023c4 // dup z4.q, z30.q[1] + .inst 0x05b023c5 // dup z5.q, z30.q[2] + .inst 0x05f023c6 // dup z6.q, z30.q[3] + .inst 0x057023e7 // dup z7.q, z31.q[1] + .inst 0x05b023e8 // dup z8.q, z31.q[2] + .inst 0x05f023e9 // dup z9.q, z31.q[3] + .inst 0xe400f41e // st1b {z30.b}, p5, [x0] + .inst 0xe4045404 // st1b {z4.b}, p5, [x0, x4] + .inst 0xe400f545 // st1b {z5.b}, p5, [x10] + .inst 0xe4045546 // st1b {z6.b}, p5, [x10, x4] + .inst 0xe400f51f // st1b {z31.b}, p5, [x8] + .inst 0xe4045507 // st1b {z7.b}, p5, [x8, x4] + .inst 0xe400f5a8 // st1b {z8.b}, p5, [x13] + .inst 0xe40455a9 // st1b {z9.b}, p5, [x13, x4] + b TILE1_End + + TILE1_Int8_Output: + mov x11, #4 // int8_t output: 1*4*sizeof(int8_t) + + .inst 0x65912fc4 // fcmlt p4.s, p3/z, z30.s, #0.0 + .inst 0x65912fe6 // fcmlt p6.s, p3/z, z31.s, #0.0 + .inst 0x6599901e // fsub z30.s, p4/m, z30.s, #0.5 + .inst 0x6599981f // fsub z31.s, p6/m, z31.s, #0.5 + .inst 0x65902fc4 // fcmge p4.s, p3/z, z30.s, #0.0 + .inst 0x65902fe6 // fcmge p6.s, p3/z, z31.s, #0.0 + .inst 0x6598901e // fadd z30.s, p4/m, z30.s, #0.5 + .inst 0x6598981f // fadd z31.s, p6/m, z31.s, #0.5 + .inst 0x659cafde // fcvtzs z30.s, p3/m, z30.s + .inst 0x659cafff // fcvtzs z31.s, p3/m, z31.s + + + .inst 0x453043db // sqxtnb z27.h, z30.s + .inst 0x453043fa // sqxtnb z26.h, z31.s + .inst 0x057a6b60 // uzp1 z0.h, z27.h, z26.h + + 
.inst 0x4528401a // sqxtnb z26.b, z0.h + .inst 0x053a6b50 // uzp1 z16.b, z26.b, z26.b + .inst 0x441dc390 // sclamp z16.b, z28.b, z29.b + .inst 0x052c2211 // dup z17.s, z16.s[1] + .inst 0x05342212 // dup z18.s, z16.s[2] + .inst 0x053c2213 // dup z19.s, z16.s[3] + .inst 0x05642214 // dup z20.s, z16.s[4] + .inst 0x056c2215 // dup z21.s, z16.s[5] + .inst 0x05742216 // dup z22.s, z16.s[6] + .inst 0x057c2217 // dup z23.s, z16.s[7] + + cmp x5, #8 + bge TILE1_32 + cmp x5, #7 + beq TILE1_28 + cmp x5, #6 + beq TILE1_24 + cmp x5, #5 + beq TILE1_20 + cmp x5, #4 + beq TILE1_16 + cmp x5, #3 + beq TILE1_12 + cmp x5, #2 + beq TILE1_8 + cmp x5, #1 + beq TILE1_4 + + TILE1_28: + add x8, x0, x4, LSL #1 // + 2*x4 + add x10, x0, x4, LSL #2 // + 4*x4 + add x14, x8, x4, LSL #2 // + 6*x4 + .inst 0xe400e810 // st1b {z16.b}, p2, [x0] + .inst 0xe4044811 // st1b {z17.b}, p2, [x0, x4] + .inst 0xe400e912 // st1b {z18.b}, p2, [x8] + .inst 0xe4044913 // st1b {z19.b}, p2, [x8, x4] + .inst 0xe400e954 // st1b {z20.b}, p2, [x10] + .inst 0xe4044955 // st1b {z21.b}, p2, [x10, x4] + .inst 0xe400e9d6 // st1b {z22.b}, p2, [x14] + b End + + TILE1_24: + add x8, x0, x4, LSL #1 // + 2*x4 + add x10, x0, x4, LSL #2 // + 4*x4 + .inst 0xe400e810 // st1b {z16.b}, p2, [x0] + .inst 0xe4044811 // st1b {z17.b}, p2, [x0, x4] + .inst 0xe400e912 // st1b {z18.b}, p2, [x8] + .inst 0xe4044913 // st1b {z19.b}, p2, [x8, x4] + .inst 0xe400e954 // st1b {z20.b}, p2, [x10] + .inst 0xe4044955 // st1b {z21.b}, p2, [x10, x4] + b End + + TILE1_20: + add x8, x0, x4, LSL #1 // + 2*x4 + add x10, x0, x4, LSL #2 // + 4*x4 + .inst 0xe400e810 // st1b {z16.b}, p2, [x0] + .inst 0xe4044811 // st1b {z17.b}, p2, [x0, x4] + .inst 0xe400e912 // st1b {z18.b}, p2, [x8] + .inst 0xe4044913 // st1b {z19.b}, p2, [x8, x4] + .inst 0xe400e954 // st1b {z20.b}, p2, [x10] + b End + + TILE1_16: + add x8, x0, x4, LSL #1 // + 2*x4 + .inst 0xe400e810 // st1b {z16.b}, p2, [x0] + .inst 0xe4044811 // st1b {z17.b}, p2, [x0, x4] + .inst 0xe400e912 // st1b {z18.b}, p2, 
[x8] + .inst 0xe4044913 // st1b {z19.b}, p2, [x8, x4] + b End + + TILE1_12: + add x8, x0, x4, LSL #1 // + 2*x4 + .inst 0xe400e810 // st1b {z16.b}, p2, [x0] + .inst 0xe4044811 // st1b {z17.b}, p2, [x0, x4] + .inst 0xe400e912 // st1b {z18.b}, p2, [x8] + b End + + TILE1_8: + .inst 0xe400e810 // st1b {z16.b}, p2, [x0] + .inst 0xe4044811 // st1b {z17.b}, p2, [x0, x4] + b End + + TILE1_4: + .inst 0xe400e810 // st1b {z16.b}, p2, [x0] + b End + + TILE1_32: + add x8, x0, x4, LSL #1 // + 2*x4 + add x10, x0, x4, LSL #2 // + 4*x4 + add x14, x8, x4, LSL #2 // + 6*x4 + .inst 0xe400e810 // st1b {z16.b}, p2, [x0] + .inst 0xe4044811 // st1b {z17.b}, p2, [x0, x4] + .inst 0xe400e912 // st1b {z18.b}, p2, [x8] + .inst 0xe4044913 // st1b {z19.b}, p2, [x8, x4] + .inst 0xe400e954 // st1b {z20.b}, p2, [x10] + .inst 0xe4044955 // st1b {z21.b}, p2, [x10, x4] + .inst 0xe400e9d6 // st1b {z22.b}, p2, [x14] + .inst 0xe40449d7 // st1b {z23.b}, p2, [x14, x4] + + TILE1_End: + subs x5, x5, #8 + add x0, x0, x4, LSL #3 + mov x13, x19 + mov x23, x21 + mov x27, x20 + beq End + b LoopDz_TILE1 + End: .inst 0xd503467f // smstop diff --git a/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScaleHp128_SME2_w4_Fp16.S b/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScaleHp128_SME2_w4_Fp16.S index 1ed3b2ed97..5fadff41b3 100644 --- a/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScaleHp128_SME2_w4_Fp16.S +++ b/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScaleHp128_SME2_w4_Fp16.S @@ -201,6 +201,7 @@ LoopSzEnd_TILE_1: cmp x15, #0 bne TILE1_BLOCKNUM + TILE1_STORE: .inst 0x84c0a9c8 // ld1rh {z8.h}, p2/z, [x14] .inst 0x84c1a9c9 // ld1rh {z9.h}, p2/z, [x14, #2] diff --git a/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScaleHp128_SME2_w8_Fp16.S b/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScaleHp128_SME2_w8_Fp16.S index 54e0a8a368..1aaca8cdea 100644 --- a/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScaleHp128_SME2_w8_Fp16.S +++ 
b/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScaleHp128_SME2_w8_Fp16.S @@ -66,6 +66,8 @@ mov x22, #48 .inst 0x2518e080 // ptrue p0.b, #4 // first 4 bytes .inst 0x2518e125 // ptrue p5.b, vl16 // first 16 bytes .inst 0x25207810 // ptrue pn8.b + + .inst 0x2598e3e1 // ptrue p1.s .inst 0x2558e3e2 // ptrue p2.h diff --git a/source/backend/cpu/arm/arm64/sme2_asm/MNNPackedMatMulRemainFP32_SME2.S b/source/backend/cpu/arm/arm64/sme2_asm/MNNPackedMatMulRemainFP32_SME2.S index 66c78583d6..0ca8aeca64 100644 --- a/source/backend/cpu/arm/arm64/sme2_asm/MNNPackedMatMulRemainFP32_SME2.S +++ b/source/backend/cpu/arm/arm64/sme2_asm/MNNPackedMatMulRemainFP32_SME2.S @@ -17,12 +17,13 @@ asm_function MNNPackedMatMulRemainFP32_SME2 //void MNNPackedMatMulRemainFP32_SME2(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); //Auto x0: C, x1:A, x2:B, x3:eSize, x4:parameter, x5:postParameters, x6:bias // parameter: {aStride, l, h, cStride, bExtraStride} -stp d14, d15, [sp, #(-16 * 6)]! +stp d14, d15, [sp, #(-16 * 8)]! 
stp d12, d13, [sp, #(16 * 1)] stp d10, d11, [sp, #(16 * 2)] stp d8, d9, [sp, #(16 * 3)] stp x21, x22, [sp, #(16 * 4)] stp x19, x20, [sp, #(16 * 5)] +stp x23, x24, [sp, #(16 * 6)] .inst 0xd503477f // smstart @@ -34,8 +35,6 @@ ldr x10, [x4, #16] // h ldr x7, [x4, #24] // cStride ldr x19, [x4, #40] // bExtraStride -lsr x7, x7, #2 // cStride/sizeof(float) - mov w12, #0 mov w13, #4 mov w14, #8 @@ -51,6 +50,7 @@ lsl x20, x3, #2 // x20: eSize * pack .inst 0x25a317e1 // whilelt p1.s, xzr, x3 .inst 0x25a07810 // ptrue pn8.s .inst 0x25b467f2 // whilelt pn10.s, xzr, x20, vlx4 // eSize * pack valid +.inst 0x2518e124 // ptrue p4.b, vl16 // Relu parameters @@ -60,6 +60,10 @@ cbz x5, ESIZE ESIZE: // x3 <= eP +cmp x3, #1 +beq E1 + +lsr x7, x7, #2 // cStride/sizeof(float) cmp x3, #16 blt LoopOcDiv4 @@ -95,7 +99,6 @@ LoopL: .inst 0x80830083 // fmopa za3.s, p0/m, p0/m, z4.s, z3.s subs x21, x21, #1 -// addvl x8, x8, #1 add x8, x8, x22 .inst 0x04225082 // addvl x2, x2, #4 bne LoopL @@ -337,15 +340,454 @@ ble End b LoopOcDiv4 +E1: +cmp x3, #1 +blt End + + +E1LoopH: +mov w11, #0 // could be modified in 'Store' +mov w12, #0 // could be modified in 'Store' +mov x8, x1 // A +mov x21, x9 // LU + +.inst 0xc00800ff // zero {za} + +cbz x6, E1LoopL +// bias +lsl x4, x10, #2 +.inst 0x25a467f1 // whilelt pn9.s, xzr, x4, vlx4 +.inst 0xa040c4d4 // ld1w {z20.s-z23.s}, pn9/z, [x6] +.inst 0x04265086 // addvl x6, x6, #4 +.inst 0xc1a17e80 // fadd za.s[w11, 0, VGx4], {z20.s-z23.s} + +E1LoopL: +.inst 0x8540c104 // ld1rw {z4.s}, p0/z, [x8] // A +.inst 0xa040c040 // ld1w {z0.s-z3.s}, pn8/z, [x2] // B +// [EP,LP] x [HP,LP] -> [EP,HP] +.inst 0xc1347800 // fmla za.s[w11, 0, VGx4], {z0.s-z3.s}, z4.s + +subs x21, x21, #1 +add x8, x8, x22 +.inst 0x04225082 // addvl x2, x2, #4 +bne E1LoopL + +add x2, x2, x19 // bExtraStride + +E1Post: +.inst 0xc0820000 // mova z0.s, p0/m, za0h.s[w12, 0] +.inst 0xc0822001 // mova z1.s, p0/m, za0h.s[w13, 0] +.inst 0xc0824002 // mova z2.s, p0/m, za0h.s[w14, 0] +.inst 0xc0826003 // 
mova z3.s, p0/m, za0h.s[w15, 0] + +cbz x5, E1Store +.inst 0xc1bfcbc0 // fclamp {z0.s-z3.s}, z30.s, z31.s + +E1Store: +cmp x10, #16 +bge E1StoreH64 + + +cmp x10, #1 +beq E1StoreH4 + +cmp x10, #2 +beq E1StoreH8 + +cmp x10, #3 +beq E1StoreH12 + +cmp x10, #4 +beq E1StoreH16 + +cmp x10, #5 +beq E1StoreH20 + +cmp x10, #6 +beq E1StoreH24 + +cmp x10, #7 +beq E1StoreH28 + +cmp x10, #8 +beq E1StoreH32 + +cmp x10, #9 +beq E1StoreH36 + +cmp x10, #10 +beq E1StoreH40 + +cmp x10, #11 +beq E1StoreH44 + +cmp x10, #12 +beq E1StoreH48 + +cmp x10, #13 +beq E1StoreH52 + +cmp x10, #14 +beq E1StoreH56 + +cmp x10, #15 +beq E1StoreH60 + + +E1StoreH64: +add x21, x0, x7, LSL #1 // 2*x7 +add x24, x0, x7, LSL #2 // 4*x7 +add x23, x21, x7, LSL #2 // 6*x7 +add x12, x0, x7, LSL #3 // 8*x7 +add x11, x12, x7, LSL #1 // 10*x7 +add x20, x24, x7, LSL #3 // 12*x7 +add x8, x23, x7, LSL #3 // 14*x7 +.inst 0x05702011 // dup z17.q, z0.q[1] +.inst 0x05b02012 // dup z18.q, z0.q[2] +.inst 0x05f02013 // dup z19.q, z0.q[3] +.inst 0x05702034 // dup z20.q, z1.q[1] +.inst 0x05b02035 // dup z21.q, z1.q[2] +.inst 0x05f02036 // dup z22.q, z1.q[3] +.inst 0x05702057 // dup z23.q, z2.q[1] +.inst 0x05b02058 // dup z24.q, z2.q[2] +.inst 0x05f02059 // dup z25.q, z2.q[3] +.inst 0x0570207a // dup z26.q, z3.q[1] +.inst 0x05b0207b // dup z27.q, z3.q[2] +.inst 0x05f0207c // dup z28.q, z3.q[3] + +.inst 0xe400f000 // st1b {z0.b}, p4, [x0] +.inst 0xe4075011 // st1b {z17.b}, p4, [x0, x7] +.inst 0xe400f2b2 // st1b {z18.b}, p4, [x21] +.inst 0xe40752b3 // st1b {z19.b}, p4, [x21, x7] +.inst 0xe400f301 // st1b {z1.b}, p4, [x24] +.inst 0xe4075314 // st1b {z20.b}, p4, [x24, x7] +.inst 0xe400f2f5 // st1b {z21.b}, p4, [x23] +.inst 0xe40752f6 // st1b {z22.b}, p4, [x23, x7] +.inst 0xe400f182 // st1b {z2.b}, p4, [x12] +.inst 0xe4075197 // st1b {z23.b}, p4, [x12, x7] +.inst 0xe400f178 // st1b {z24.b}, p4, [x11] +.inst 0xe4075179 // st1b {z25.b}, p4, [x11, x7] +.inst 0xe400f283 // st1b {z3.b}, p4, [x20] +.inst 0xe407529a // st1b {z26.b}, p4, 
[x20, x7] +.inst 0xe400f11b // st1b {z27.b}, p4, [x8] +.inst 0xe407511c // st1b {z28.b}, p4, [x8, x7] +b E1H16_End + +E1StoreH60: +add x21, x0, x7, LSL #1 // 2*x7 +add x24, x0, x7, LSL #2 // 4*x7 +add x23, x21, x7, LSL #2 // 6*x7 +add x12, x0, x7, LSL #3 // 8*x7 +add x11, x12, x7, LSL #1 // 10*x7 +add x20, x24, x7, LSL #3 // 12*x7 +add x8, x23, x7, LSL #3 // 14*x7 +.inst 0x05702011 // dup z17.q, z0.q[1] +.inst 0x05b02012 // dup z18.q, z0.q[2] +.inst 0x05f02013 // dup z19.q, z0.q[3] +.inst 0x05702034 // dup z20.q, z1.q[1] +.inst 0x05b02035 // dup z21.q, z1.q[2] +.inst 0x05f02036 // dup z22.q, z1.q[3] +.inst 0x05702057 // dup z23.q, z2.q[1] +.inst 0x05b02058 // dup z24.q, z2.q[2] +.inst 0x05f02059 // dup z25.q, z2.q[3] +.inst 0x0570207a // dup z26.q, z3.q[1] +.inst 0x05b0207b // dup z27.q, z3.q[2] + +.inst 0xe400f000 // st1b {z0.b}, p4, [x0] +.inst 0xe4075011 // st1b {z17.b}, p4, [x0, x7] +.inst 0xe400f2b2 // st1b {z18.b}, p4, [x21] +.inst 0xe40752b3 // st1b {z19.b}, p4, [x21, x7] +.inst 0xe400f301 // st1b {z1.b}, p4, [x24] +.inst 0xe4075314 // st1b {z20.b}, p4, [x24, x7] +.inst 0xe400f2f5 // st1b {z21.b}, p4, [x23] +.inst 0xe40752f6 // st1b {z22.b}, p4, [x23, x7] +.inst 0xe400f182 // st1b {z2.b}, p4, [x12] +.inst 0xe4075197 // st1b {z23.b}, p4, [x12, x7] +.inst 0xe400f178 // st1b {z24.b}, p4, [x11] +.inst 0xe4075179 // st1b {z25.b}, p4, [x11, x7] +.inst 0xe400f283 // st1b {z3.b}, p4, [x20] +.inst 0xe407529a // st1b {z26.b}, p4, [x20, x7] +.inst 0xe400f11b // st1b {z27.b}, p4, [x8] +b End + +E1StoreH56: +add x21, x0, x7, LSL #1 // 2*x7 +add x24, x0, x7, LSL #2 // 4*x7 +add x23, x21, x7, LSL #2 // 6*x7 +add x12, x0, x7, LSL #3 // 8*x7 +add x11, x12, x7, LSL #1 // 10*x7 +add x20, x24, x7, LSL #3 // 12*x7 +.inst 0x05702011 // dup z17.q, z0.q[1] +.inst 0x05b02012 // dup z18.q, z0.q[2] +.inst 0x05f02013 // dup z19.q, z0.q[3] +.inst 0x05702034 // dup z20.q, z1.q[1] +.inst 0x05b02035 // dup z21.q, z1.q[2] +.inst 0x05f02036 // dup z22.q, z1.q[3] +.inst 0x05702057 // dup 
z23.q, z2.q[1] +.inst 0x05b02058 // dup z24.q, z2.q[2] +.inst 0x05f02059 // dup z25.q, z2.q[3] +.inst 0x0570207a // dup z26.q, z3.q[1] +.inst 0xe400f000 // st1b {z0.b}, p4, [x0] +.inst 0xe4075011 // st1b {z17.b}, p4, [x0, x7] +.inst 0xe400f2b2 // st1b {z18.b}, p4, [x21] +.inst 0xe40752b3 // st1b {z19.b}, p4, [x21, x7] +.inst 0xe400f301 // st1b {z1.b}, p4, [x24] +.inst 0xe4075314 // st1b {z20.b}, p4, [x24, x7] +.inst 0xe400f2f5 // st1b {z21.b}, p4, [x23] +.inst 0xe40752f6 // st1b {z22.b}, p4, [x23, x7] +.inst 0xe400f182 // st1b {z2.b}, p4, [x12] +.inst 0xe4075197 // st1b {z23.b}, p4, [x12, x7] +.inst 0xe400f178 // st1b {z24.b}, p4, [x11] +.inst 0xe4075179 // st1b {z25.b}, p4, [x11, x7] +.inst 0xe400f283 // st1b {z3.b}, p4, [x20] +.inst 0xe407529a // st1b {z26.b}, p4, [x20, x7] +b End + +E1StoreH52: +add x21, x0, x7, LSL #1 // 2*x7 +add x24, x0, x7, LSL #2 // 4*x7 +add x23, x21, x7, LSL #2 // 6*x7 +add x12, x0, x7, LSL #3 // 8*x7 +add x11, x12, x7, LSL #1 // 10*x7 +add x20, x24, x7, LSL #3 // 12*x7 +.inst 0x05702011 // dup z17.q, z0.q[1] +.inst 0x05b02012 // dup z18.q, z0.q[2] +.inst 0x05f02013 // dup z19.q, z0.q[3] +.inst 0x05702034 // dup z20.q, z1.q[1] +.inst 0x05b02035 // dup z21.q, z1.q[2] +.inst 0x05f02036 // dup z22.q, z1.q[3] +.inst 0x05702057 // dup z23.q, z2.q[1] +.inst 0x05b02058 // dup z24.q, z2.q[2] +.inst 0x05f02059 // dup z25.q, z2.q[3] +.inst 0xe400f000 // st1b {z0.b}, p4, [x0] +.inst 0xe4075011 // st1b {z17.b}, p4, [x0, x7] +.inst 0xe400f2b2 // st1b {z18.b}, p4, [x21] +.inst 0xe40752b3 // st1b {z19.b}, p4, [x21, x7] +.inst 0xe400f301 // st1b {z1.b}, p4, [x24] +.inst 0xe4075314 // st1b {z20.b}, p4, [x24, x7] +.inst 0xe400f2f5 // st1b {z21.b}, p4, [x23] +.inst 0xe40752f6 // st1b {z22.b}, p4, [x23, x7] +.inst 0xe400f182 // st1b {z2.b}, p4, [x12] +.inst 0xe4075197 // st1b {z23.b}, p4, [x12, x7] +.inst 0xe400f178 // st1b {z24.b}, p4, [x11] +.inst 0xe4075179 // st1b {z25.b}, p4, [x11, x7] +.inst 0xe400f283 // st1b {z3.b}, p4, [x20] +b End + +E1StoreH48: 
+add x21, x0, x7, LSL #1 // 2*x7 +add x24, x0, x7, LSL #2 // 4*x7 +add x23, x21, x7, LSL #2 // 6*x7 +add x12, x0, x7, LSL #3 // 8*x7 +add x11, x12, x7, LSL #1 // 10*x7 +.inst 0x05702011 // dup z17.q, z0.q[1] +.inst 0x05b02012 // dup z18.q, z0.q[2] +.inst 0x05f02013 // dup z19.q, z0.q[3] +.inst 0x05702034 // dup z20.q, z1.q[1] +.inst 0x05b02035 // dup z21.q, z1.q[2] +.inst 0x05f02036 // dup z22.q, z1.q[3] +.inst 0x05702057 // dup z23.q, z2.q[1] +.inst 0x05b02058 // dup z24.q, z2.q[2] +.inst 0x05f02059 // dup z25.q, z2.q[3] +.inst 0xe400f000 // st1b {z0.b}, p4, [x0] +.inst 0xe4075011 // st1b {z17.b}, p4, [x0, x7] +.inst 0xe400f2b2 // st1b {z18.b}, p4, [x21] +.inst 0xe40752b3 // st1b {z19.b}, p4, [x21, x7] +.inst 0xe400f301 // st1b {z1.b}, p4, [x24] +.inst 0xe4075314 // st1b {z20.b}, p4, [x24, x7] +.inst 0xe400f2f5 // st1b {z21.b}, p4, [x23] +.inst 0xe40752f6 // st1b {z22.b}, p4, [x23, x7] +.inst 0xe400f182 // st1b {z2.b}, p4, [x12] +.inst 0xe4075197 // st1b {z23.b}, p4, [x12, x7] +.inst 0xe400f178 // st1b {z24.b}, p4, [x11] +.inst 0xe4075179 // st1b {z25.b}, p4, [x11, x7] +b End + +E1StoreH44: +add x21, x0, x7, LSL #1 // 2*x7 +add x24, x0, x7, LSL #2 // 4*x7 +add x23, x21, x7, LSL #2 // 6*x7 +add x12, x0, x7, LSL #3 // 8*x7 +add x11, x12, x7, LSL #1 // 10*x7 +.inst 0x05702011 // dup z17.q, z0.q[1] +.inst 0x05b02012 // dup z18.q, z0.q[2] +.inst 0x05f02013 // dup z19.q, z0.q[3] +.inst 0x05702034 // dup z20.q, z1.q[1] +.inst 0x05b02035 // dup z21.q, z1.q[2] +.inst 0x05f02036 // dup z22.q, z1.q[3] +.inst 0x05702057 // dup z23.q, z2.q[1] +.inst 0x05b02058 // dup z24.q, z2.q[2] +.inst 0xe400f000 // st1b {z0.b}, p4, [x0] +.inst 0xe4075011 // st1b {z17.b}, p4, [x0, x7] +.inst 0xe400f2b2 // st1b {z18.b}, p4, [x21] +.inst 0xe40752b3 // st1b {z19.b}, p4, [x21, x7] +.inst 0xe400f301 // st1b {z1.b}, p4, [x24] +.inst 0xe4075314 // st1b {z20.b}, p4, [x24, x7] +.inst 0xe400f2f5 // st1b {z21.b}, p4, [x23] +.inst 0xe40752f6 // st1b {z22.b}, p4, [x23, x7] +.inst 0xe400f182 // st1b 
{z2.b}, p4, [x12] +.inst 0xe4075197 // st1b {z23.b}, p4, [x12, x7] +.inst 0xe400f178 // st1b {z24.b}, p4, [x11] +b End + +E1StoreH40: +add x21, x0, x7, LSL #1 // 2*x7 +add x24, x0, x7, LSL #2 // 4*x7 +add x23, x21, x7, LSL #2 // 6*x7 +add x12, x0, x7, LSL #3 // 8*x7 +.inst 0x05702011 // dup z17.q, z0.q[1] +.inst 0x05b02012 // dup z18.q, z0.q[2] +.inst 0x05f02013 // dup z19.q, z0.q[3] +.inst 0x05702034 // dup z20.q, z1.q[1] +.inst 0x05b02035 // dup z21.q, z1.q[2] +.inst 0x05f02036 // dup z22.q, z1.q[3] +.inst 0x05702057 // dup z23.q, z2.q[1] +.inst 0xe400f000 // st1b {z0.b}, p4, [x0] +.inst 0xe4075011 // st1b {z17.b}, p4, [x0, x7] +.inst 0xe400f2b2 // st1b {z18.b}, p4, [x21] +.inst 0xe40752b3 // st1b {z19.b}, p4, [x21, x7] +.inst 0xe400f301 // st1b {z1.b}, p4, [x24] +.inst 0xe4075314 // st1b {z20.b}, p4, [x24, x7] +.inst 0xe400f2f5 // st1b {z21.b}, p4, [x23] +.inst 0xe40752f6 // st1b {z22.b}, p4, [x23, x7] +.inst 0xe400f182 // st1b {z2.b}, p4, [x12] +.inst 0xe4075197 // st1b {z23.b}, p4, [x12, x7] +b End + +E1StoreH36: +add x21, x0, x7, LSL #1 // 2*x7 +add x24, x0, x7, LSL #2 // 4*x7 +add x23, x21, x7, LSL #2 // 6*x7 +add x12, x0, x7, LSL #3 // 8*x7 +.inst 0x05702011 // dup z17.q, z0.q[1] +.inst 0x05b02012 // dup z18.q, z0.q[2] +.inst 0x05f02013 // dup z19.q, z0.q[3] +.inst 0x05702034 // dup z20.q, z1.q[1] +.inst 0x05b02035 // dup z21.q, z1.q[2] +.inst 0x05f02036 // dup z22.q, z1.q[3] +.inst 0xe400f000 // st1b {z0.b}, p4, [x0] +.inst 0xe4075011 // st1b {z17.b}, p4, [x0, x7] +.inst 0xe400f2b2 // st1b {z18.b}, p4, [x21] +.inst 0xe40752b3 // st1b {z19.b}, p4, [x21, x7] +.inst 0xe400f301 // st1b {z1.b}, p4, [x24] +.inst 0xe4075314 // st1b {z20.b}, p4, [x24, x7] +.inst 0xe400f2f5 // st1b {z21.b}, p4, [x23] +.inst 0xe40752f6 // st1b {z22.b}, p4, [x23, x7] +.inst 0xe400f182 // st1b {z2.b}, p4, [x12] +b End + +E1StoreH32: +add x21, x0, x7, LSL #1 // 2*x7 +add x24, x0, x7, LSL #2 // 4*x7 +add x23, x21, x7, LSL #2 // 6*x7 +.inst 0x05702011 // dup z17.q, z0.q[1] +.inst 
0x05b02012 // dup z18.q, z0.q[2] +.inst 0x05f02013 // dup z19.q, z0.q[3] +.inst 0x05702034 // dup z20.q, z1.q[1] +.inst 0x05b02035 // dup z21.q, z1.q[2] +.inst 0x05f02036 // dup z22.q, z1.q[3] +.inst 0xe400f000 // st1b {z0.b}, p4, [x0] +.inst 0xe4075011 // st1b {z17.b}, p4, [x0, x7] +.inst 0xe400f2b2 // st1b {z18.b}, p4, [x21] +.inst 0xe40752b3 // st1b {z19.b}, p4, [x21, x7] +.inst 0xe400f301 // st1b {z1.b}, p4, [x24] +.inst 0xe4075314 // st1b {z20.b}, p4, [x24, x7] +.inst 0xe400f2f5 // st1b {z21.b}, p4, [x23] +.inst 0xe40752f6 // st1b {z22.b}, p4, [x23, x7] +b End + +E1StoreH28: +add x21, x0, x7, LSL #1 // 2*x7 +add x24, x0, x7, LSL #2 // 4*x7 +add x23, x21, x7, LSL #2 // 6*x7 +.inst 0x05702011 // dup z17.q, z0.q[1] +.inst 0x05b02012 // dup z18.q, z0.q[2] +.inst 0x05f02013 // dup z19.q, z0.q[3] +.inst 0x05702034 // dup z20.q, z1.q[1] +.inst 0x05b02035 // dup z21.q, z1.q[2] +.inst 0xe400f000 // st1b {z0.b}, p4, [x0] +.inst 0xe4075011 // st1b {z17.b}, p4, [x0, x7] +.inst 0xe400f2b2 // st1b {z18.b}, p4, [x21] +.inst 0xe40752b3 // st1b {z19.b}, p4, [x21, x7] +.inst 0xe400f301 // st1b {z1.b}, p4, [x24] +.inst 0xe4075314 // st1b {z20.b}, p4, [x24, x7] +.inst 0xe400f2f5 // st1b {z21.b}, p4, [x23] +b End + +E1StoreH24: +add x21, x0, x7, LSL #1 // 2*x7 +add x24, x0, x7, LSL #2 // 4*x7 +.inst 0x05702011 // dup z17.q, z0.q[1] +.inst 0x05b02012 // dup z18.q, z0.q[2] +.inst 0x05f02013 // dup z19.q, z0.q[3] +.inst 0x05702034 // dup z20.q, z1.q[1] +.inst 0xe400f000 // st1b {z0.b}, p4, [x0] +.inst 0xe4075011 // st1b {z17.b}, p4, [x0, x7] +.inst 0xe400f2b2 // st1b {z18.b}, p4, [x21] +.inst 0xe40752b3 // st1b {z19.b}, p4, [x21, x7] +.inst 0xe400f301 // st1b {z1.b}, p4, [x24] +.inst 0xe4075314 // st1b {z20.b}, p4, [x24, x7] +b End + +E1StoreH20: +add x21, x0, x7, LSL #1 // 2*x7 +add x24, x0, x7, LSL #2 // 4*x7 +.inst 0x05702011 // dup z17.q, z0.q[1] +.inst 0x05b02012 // dup z18.q, z0.q[2] +.inst 0x05f02013 // dup z19.q, z0.q[3] +.inst 0xe400f000 // st1b {z0.b}, p4, [x0] +.inst 
0xe4075011 // st1b {z17.b}, p4, [x0, x7] +.inst 0xe400f2b2 // st1b {z18.b}, p4, [x21] +.inst 0xe40752b3 // st1b {z19.b}, p4, [x21, x7] +.inst 0xe400f301 // st1b {z1.b}, p4, [x24] +b End + +E1StoreH16: +add x21, x0, x7, LSL #1 +.inst 0x05702011 // dup z17.q, z0.q[1] +.inst 0x05b02012 // dup z18.q, z0.q[2] +.inst 0x05f02013 // dup z19.q, z0.q[3] +.inst 0xe400f000 // st1b {z0.b}, p4, [x0] +.inst 0xe4075011 // st1b {z17.b}, p4, [x0, x7] +.inst 0xe400f2b2 // st1b {z18.b}, p4, [x21] +.inst 0xe40752b3 // st1b {z19.b}, p4, [x21, x7] +b End + +E1StoreH12: +add x21, x0, x7, LSL #1 +.inst 0x05702011 // dup z17.q, z0.q[1] +.inst 0x05b02012 // dup z18.q, z0.q[2] +.inst 0xe400f000 // st1b {z0.b}, p4, [x0] +.inst 0xe4075011 // st1b {z17.b}, p4, [x0, x7] +.inst 0xe400f2b2 // st1b {z18.b}, p4, [x21] +b End + +E1StoreH8: +.inst 0x05702011 // dup z17.q, z0.q[1] +.inst 0xe400f000 // st1b {z0.b}, p4, [x0] +.inst 0xe4075011 // st1b {z17.b}, p4, [x0, x7] +b End + +E1StoreH4: +.inst 0xe400f000 // st1b {z0.b}, p4, [x0] +b End + +E1H16_End: +subs x10, x10, #16 +add x0, x0, x7, LSL #4 +bne E1LoopH + + End: .inst 0xd503467f // smstop +ldp x23, x24, [sp, #96] ldp x19, x20, [sp, #80] ldp x21, x22, [sp, #64] ldp d8, d9, [sp, #48] ldp d10, d11, [sp, #32] ldp d12, d13, [sp, #16] -ldp d14, d15, [sp], #96 +ldp d14, d15, [sp], #128 ret diff --git a/source/backend/cpu/compute/CommonOptFunction.cpp b/source/backend/cpu/compute/CommonOptFunction.cpp index 10bbca5585..d7d0d7fb34 100644 --- a/source/backend/cpu/compute/CommonOptFunction.cpp +++ b/source/backend/cpu/compute/CommonOptFunction.cpp @@ -536,8 +536,8 @@ static void MNNAsyQuantInfo_FP32(float* scale, float* bias, float* qscale, float } else { qscale[0] = 255.f / range; scale[0] = range / 255.f; - qbias[0] = roundf(-minval * 255.f / range)- 128.f; - bias[0] = -qbias[0] * scale[0]; + qbias[0] = -minval * 255.f / range - 128.f; + bias[0] = minval + 128.f * range / 255.f; } return; } @@ -1392,7 +1392,7 @@ void MNNAccumulateSequenceNumber (float* 
dst, const float* src, int size) { #ifdef MNN_SUPPORT_TRANSFORMER_FUSE -static void MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) { +static void MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes, int seqStart) { // source shape: [headDim/pack, seqLen, pack] // scale & normalizeScale shape: [seqLen] // dest shape: [headDim/pack, seqLen, pack] @@ -1400,7 +1400,8 @@ static void MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* sc if (idx > 0) { for (int j = 0; j < depthQuad; ++j) { - for (int i = 0; i < plane; ++i) { + int i = seqStart; + for (; i < plane; ++i) { auto dataNew = Vec::load(src + j * stride0 + i * pack); auto dataOld = Vec::load(dst + j * stride0 + i * pack); auto s = Vec(scale[i]); @@ -1463,6 +1464,9 @@ static void MNNAttenPackAndScaleSingleHead(float* dst, const float* srcHeadBase, dstBasePtr[(d + 6) * dstStrideDOuter] = sVec1[2]; dstBasePtr[(d + 7) * dstStrideDOuter] = sVec1[3]; } + for (; d < headDim; ++d) { + dstBasePtr[d * dstStrideDOuter] = srcRowPtr[d] * scaleVal; + } #else for (; d < headDim; ++d) { dstBasePtr[d * dstStrideDOuter] = srcRowPtr[d] * scaleVal; @@ -1470,10 +1474,175 @@ static void MNNAttenPackAndScaleSingleHead(float* dst, const float* srcHeadBase, #endif } } + +#ifndef __aarch64__ +void MNNQuantAttentionKey(int8_t* dst, const float* source, float* sumKeyPtr, float* maxKeyPtr, int32_t* params) { + int32_t kvNumHead = params[0]; + int32_t seqLen = params[1]; + int32_t headDim = params[2]; + int32_t blockNum = params[3]; + int32_t eP = params[4]; + int32_t lP = params[5]; + int32_t hP = params[6]; + int32_t pastLength = params[7]; + int32_t kvHeadIdx = params[8]; + + auto blockL = UP_DIV(headDim, blockNum); + auto weightStride1 = ROUND_UP(blockL, lP) * hP; + auto 
weightStride2 = lP * hP; + auto packedWeightStride1 = weightStride1 + 2 * 4 * hP; + + if (seqLen > 1) { + // get max + for (int s = 0; s < seqLen; ++s) { + const float* keySrc = source + s * kvNumHead * headDim + kvHeadIdx * headDim; + for (int d = 0; d < headDim; d++) { + maxKeyPtr[d] = ALIMAX(maxKeyPtr[d], keySrc[d]); + } + } + } + + for (int s = 0; s < seqLen; s++) { + const float* keySrc = source + s * kvNumHead * headDim + kvHeadIdx * headDim; + float minKey, maxKey; + minKey = keySrc[0] - maxKeyPtr[0]; + maxKey = keySrc[0] - maxKeyPtr[0]; + for (int d = 1; d < headDim; d++) { + auto keydata = keySrc[d] - maxKeyPtr[d]; + minKey = ALIMIN(minKey, keydata); + maxKey = ALIMAX(maxKey, keydata); + } + + int outIndex = (pastLength + s) / hP; + int inIndex = (pastLength + s) % hP; + + float sumKey = 0; + for (int k = 0; k < blockNum; ++k) { + int8_t* weightDst = dst + outIndex * blockNum * packedWeightStride1 + k * packedWeightStride1; + float* scaleDst = (float*)(weightDst + weightStride1); + float* biasDst = scaleDst + hP; + + scaleDst[inIndex] = (maxKey - minKey) / 255.0f; + biasDst[inIndex] = minKey + 128.f * (maxKey - minKey) / 255.f; + + for (int d = 0; d < blockL; d++) { + int i = d / lP; + int j = d % lP; + + int int8v = (int)(roundf((keySrc[d + k * blockL] - maxKeyPtr[d + k * blockL] - minKey) / (maxKey - minKey) * 255.0f - 128.0f)); + weightDst[i * weightStride2 + inIndex * lP + j] = int8v; + sumKey += (int8v * scaleDst[inIndex] + biasDst[inIndex]); + } + } + sumKeyPtr[outIndex * hP + inIndex] = sumKey; + } +} + +void MNNQuantAttentionValue(int8_t* dst, const float* source, float* valueSum, int32_t* params) { + // float value src : [kvSeq,kvNumHead,headDim] + // int8_t value dest: [updiv(maxLength,flashAttentionBlockKv), updiv(headDim,hp),updiv(flashAttentionBlockKv,lp),hp,lp] + // float value sum: [updiv(maxLength,flashAttentionBlockKv), roundup(headDim,hp)] + int32_t kvNumHead = params[0]; + int32_t seqLen = params[1]; + int32_t headDim = params[2]; + 
int32_t blockNum = params[3]; + int32_t maxLength = params[4]; + + int32_t lP = params[5]; + int32_t hP = params[6]; + int32_t pastLength = params[7]; + int32_t kvHeadIdx = params[8]; + + int32_t flashAttentionBlockKv = params[9]; + + auto blockKvseq = UP_DIV(seqLen + pastLength, blockNum); + auto weightStride2 = lP * hP; + auto weightStride1 = UP_DIV(flashAttentionBlockKv, lP) * weightStride2; + + auto packedStride1 = (int)(weightStride1 + 2 * hP * sizeof(float)); + auto packedStride0 = UP_DIV(headDim, hP) * packedStride1; + + auto srcStride0 = kvNumHead * headDim; + + auto sourceFp32 = (float*)source; + + // quant scale & bias + if (pastLength == 0) { + for (int d = 0; d < headDim; ++d) { + float* scalePtr = (float*)(dst + (d / hP) * packedStride1 + weightStride1) + (d % hP); + float* biasPtr = scalePtr + hP; + + // find min,max + float dMax = sourceFp32[d + kvHeadIdx * headDim]; + float dMin = dMax; + for (int s = 0; s < seqLen; ++s) { + float data = sourceFp32[s * srcStride0 + d + kvHeadIdx * headDim]; + dMax = ALIMAX(dMax, data); + dMin = ALIMIN(dMin, data); + } + + // scale & bias + float range = dMax - dMin; + if (range < 1e-6) { + scalePtr[0] = 0.f; + biasPtr[0] = dMax; + } else { + float scale = range / 255.f; + float bias = range / 255.f * 128.f + dMin; + scalePtr[0] = scale; + biasPtr[0] = bias; + } + } + } + + // copy the scale&bias to each blockKv + // pastLength == 0: First time prefill + // (seqLen + pastLength) % flashAttentionBlockKv == 0: Open a new blockKv + if (pastLength == 0 || (pastLength % flashAttentionBlockKv) == 0) { + int32_t d0 = UP_DIV(maxLength, flashAttentionBlockKv); + int32_t d1 = UP_DIV(headDim, hP); + for (int k = 0; k < d0; ++k) { + for (int r = 0; r < d1; ++r) { + float* scalePtr = (float*)(dst + k * packedStride0 + r * packedStride1 + weightStride1); + float* biasPtr = scalePtr + hP; + memcpy(scalePtr, dst + r * packedStride1 + weightStride1, hP * sizeof(float)); + memcpy(biasPtr, dst + r * packedStride1 + weightStride1 + hP * 
sizeof(float), hP * sizeof(float)); + } + } + } + + for (int d = 0; d < headDim; ++d) { + // dst address + int idxBase = (d / hP) * packedStride1 + (d % hP) * lP; + int8_t* dstBase = dst + idxBase; + float* scaleBase = (float*)(dst + (d / hP) * packedStride1 + weightStride1) + (d % hP); + float* biasBase = scaleBase + hP; + float* sumBase = valueSum + (d / hP) * hP + (d % hP); + + float qscale = scaleBase[0] < 1e-6 ? 0 : 1.0f / scaleBase[0]; + float qbias = scaleBase[0] < 1e-6 ? 0 : (-biasBase[0] / scaleBase[0]); + // quant + for (int s = 0; s < seqLen; ++s) { + int kvSeqIndx = s + pastLength; + int idxInner = (kvSeqIndx / flashAttentionBlockKv) * packedStride0 + (kvSeqIndx % flashAttentionBlockKv) / lP * weightStride2 + (kvSeqIndx % flashAttentionBlockKv) % lP; + float xf = sourceFp32[s * srcStride0 + d + kvHeadIdx * headDim]; + int8_t xq = ALIMAX(ALIMIN(127, static_cast(roundf(xf * qscale + qbias))), -128); + dstBase[idxInner] = xq; + + // sum + int idxSum = (kvSeqIndx / flashAttentionBlockKv) * ROUND_UP(headDim, hP); + sumBase[idxSum] += ((float)xq * scaleBase[0] + biasBase[0]); + } + } +} + +#endif + #endif // MNN_SUPPORT_TRANSFORMER_FUSE #ifndef MNN_USE_NEON + void MNNGetMatMulPackMode(int* eP, int *lP, int* hP) { *eP = 16; *lP = 1; @@ -2702,54 +2871,136 @@ void MNNExpC8(float* dest, const float* source, float* offset, const float* para offset[3] = summer; } -void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) { +void MNNSoftmax(float* softmaxDst, const float* softmaxSrc, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize, int kvSeqOffset, int validOffset, int pack, bool mask) { + + // source shape: [reduceSizeOuter, outside, reduceSizeInner] + // for C4, [up_div(reduceSize,4), outside,4] => reduceSizeOuter=up_div(reduceSize,4), reduceSizeInner=4 + // for C, [outside, reduceSize] => reduceSizeOuter=1, reduceSizeInner=reduceSize + + const int 
packUnit = 4; + int reduceSizeOuter = 1; + int reduceSizeInner = reduceSize; + int stride0 = packUnit; + if (pack > 1) { + reduceSizeOuter = UP_DIV(reduceSize, pack); + reduceSizeInner = pack; + stride0 = outside * reduceSizeInner; + } + + float exprOffset[4] = {1.0f, 0.0f, 0.0f, 0.0f }; for (int k = 0; k < outside; ++k) { - auto source = input + k * reduceSize; - auto dest = softmaxDst + k * reduceSize; + exprOffset[3] = 0.0f; // init sum to zero for each outer loop + if (mask && kvSeqOffset > k + validOffset) { + if (updateScale){ + updateScale[k] = 1; + } + for (int j = 0; j < reduceSizeOuter; ++j) { + int i = 0; + for (; i < reduceSizeInner; i += packUnit) { + auto destPtr = softmaxDst + j * stride0 + k * reduceSizeInner + i; + memset(destPtr, 0, packUnit * sizeof(float)); + } + if (i < reduceSizeInner) { + memset(softmaxDst + j * stride0 + k * reduceSizeInner + i, 0, (reduceSizeInner - i) * sizeof(float)); + } + } + continue; + } - float oldMax = source[0]; + const int validReduceSize = mask ? ALIMIN(reduceSize, k + (validOffset + 1) - kvSeqOffset) : reduceSize; + const int remain = validReduceSize % packUnit; + const int sizeDiv = validReduceSize / packUnit; + + // 1. newMax + float oldMax = std::numeric_limits::lowest(); if (runningMax) { oldMax = runningMax[k]; } - // find max value of current block - float blockMax =source[0]; - for (int i = 1; i < reduceSize; ++i) { - blockMax = ALIMAX(blockMax, source[i]); + float newMax = std::numeric_limits::lowest(); + + for (int j = 0; j < sizeDiv; ++j) { + auto srcPtr = softmaxSrc + j * stride0 + k * reduceSizeInner; + for (int i = 0; i < packUnit; ++i) { + newMax = ALIMAX(newMax, srcPtr[i]); + } } - float newMax = ALIMAX(oldMax, blockMax); - // caculate block's expr(xi-newmax) and update runningMax - float xLimit = 87, param = 0.6931471805599453; - float blockSum = 0.f; - for (int i = 0; i < reduceSize; ++i) { - auto x = source[i] - newMax; - x = x > -xLimit ? x : -xLimit; - x = x < xLimit ? 
x : xLimit; + if (remain > 0) { + auto srcPtr = softmaxSrc + sizeDiv * stride0 + k * reduceSizeInner; + for (int i = 0; i < remain; ++i) { + newMax = ALIMAX(newMax, srcPtr[i]); + } + } - int div = (x / param); - int div2 = (div + 127) << 23; - auto xReamin = x - div * param; - float expBasic = *(float*)(&div2); + const float finalMax = ALIMAX(oldMax, newMax); - auto t = xReamin; - auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f; - dest[i] = expBasic * expRemain; - blockSum += dest[i]; + // 2. exp(x - finalMax) + exprOffset[2] = -finalMax; + + for (int j = 0; j < sizeDiv; ++j) { + auto idx = j * stride0 + k * reduceSizeInner; + auto srcPtr = softmaxSrc + idx; + auto dstPtr = softmaxDst + idx; + MNNExp(dstPtr, srcPtr, exprOffset, packUnit); } + float sum = exprOffset[3]; + + if (remain > 0) { + auto idx = sizeDiv * stride0 + k * reduceSizeInner; + auto srcPtr = softmaxSrc + idx; + auto dstPtr = softmaxDst + idx; + + for(int i = 0; i < remain; ++i) { + float val = expf(srcPtr[i] - finalMax); + sum += val; + dstPtr[i] = val; + } + } + + // 3. 
if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) { - // update runningSum, runningMax, scale=expf(oldMax-newMax) - runningSum[k] = runningSum[k] * expf(oldMax - newMax) + blockSum; - runningMax[k] = newMax; - updateScale[k] = expf(oldMax - newMax); + // update runningSum, runningMax, scale + float scaleForSum = expf(oldMax - finalMax); + runningSum[k] = runningSum[k] * scaleForSum + sum; + runningMax[k] = finalMax; + updateScale[k] = scaleForSum; } else { - // Normalize - auto scale = 1.f / blockSum; - for (int i = 0; i < reduceSize; ++i) { - dest[i] *= scale; + // Normalization + if (runningMax != nullptr && runningSum != nullptr) { + sum += runningSum[k] * expf(oldMax - finalMax); + } + float scale = 1.0f / (sum + 1e-20f); + + for (int j = 0; j < sizeDiv; ++j) { + auto pDest = softmaxDst + j * stride0 + k * reduceSizeInner; + for (int i = 0; i < packUnit; ++i) { + pDest[i] = pDest[i] * scale; + } + } + if (remain > 0) { + auto pDest = softmaxDst + sizeDiv * stride0 + k * reduceSizeInner; + for (int i = 0; i < remain; ++i) { + pDest[i] = pDest[i] * scale; + } } } + + // 4. 
memset 0 + if (pack > 1) { + if (validReduceSize % packUnit > 0) { + memset(softmaxDst + sizeDiv * stride0 + k * reduceSizeInner + (validReduceSize % packUnit), 0, (packUnit - (validReduceSize % packUnit)) * sizeof(float)); + } + auto validDiv4 = UP_DIV(validReduceSize, packUnit); + auto allDiv4 = UP_DIV(reduceSize, packUnit); + for (int j = validDiv4; j < allDiv4; ++j) { + auto destPtr = softmaxDst + j * stride0 + k * reduceSizeInner; + memset(destPtr, 0, packUnit * sizeof(float)); + } + } else { + memset(softmaxDst + k * reduceSizeInner + validReduceSize, 0, (reduceSize - validReduceSize) * sizeof(float)); + } } } @@ -2764,6 +3015,8 @@ void MNNReluInt8(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint) } #endif // no MNN_USE_SSE + + void MNNExp(float* dst, const float* src, float* offset, size_t dataSize) { int countC8 = static_cast(dataSize) / 8; int remain = static_cast(dataSize) % 8; @@ -3380,9 +3633,10 @@ void MNNPackTranspose(float* dst, const float* src, size_t area, size_t depth, i int cDiv4 = c / 4; int cAlign = cDiv4 * 4; auto srcArea = areaOffset[0]; + auto dstDepthOffset = areaOffset[1]; for (int hi = 0; hi < area; ++hi) { const float* srcHeight = src + hi * 4; - float* dstHeight = dst + hi * c; + float* dstHeight = dst + hi * dstDepthOffset; for (int ci = 0; ci < cDiv4; ++ci) { Vec4::save(dstHeight + 4 * ci, Vec4::load(srcHeight + 4 * ci * srcArea)); } @@ -3398,7 +3652,7 @@ void MNNPackTranspose(float* dst, const float* src, size_t area, size_t depth, i for (int hi = 0; hi < area; ++hi) { const float* srcHeight = srcAlign + hi * 4; - float* dstHeight = dstAlign + hi * c; + float* dstHeight = dstAlign + hi * dstDepthOffset; for (int ci = 0; ci < cReamin; ++ci) { dstHeight[ci] = srcHeight[ci]; @@ -3789,18 +4043,18 @@ void MNNUnpackTransposeInt16(int16_t* dst, const int16_t* src, size_t area,size_ } } } -void MNNPackTransposeInt16(int16_t* dst, const int16_t* src, size_t area,size_t depth, int* areaOffset) { +void 
MNNPackTransposeInt16(int16_t* dst, const int16_t* src, size_t area,size_t depth, int* offset) { int c = (int)depth; int cDiv4 = c / 4; int cAlign = cDiv4 * 4; + int srcAreaOffset = offset[0]; + int dstDepthOffset = offset[1]; if (cAlign == c) { - int64_t* dst32 = (int64_t*)dst; - const int64_t* src32 = (int64_t*)src; for (int hi = 0; hi < area; ++hi) { - auto srcHeight = src32 + hi; - auto dstHeight = dst32 + hi * cDiv4; + auto srcHeight = (int64_t*)src + hi; + auto dstHeight = (int64_t*)(dst + hi * dstDepthOffset); for (int ci = 0; ci < cDiv4; ++ci) { - dstHeight[ci] = srcHeight[ci * areaOffset[0]]; + dstHeight[ci] = srcHeight[ci * srcAreaOffset]; } } return; @@ -3808,21 +4062,21 @@ void MNNPackTransposeInt16(int16_t* dst, const int16_t* src, size_t area,size_t for (int hi = 0; hi < area; ++hi) { auto srcHeight = src + hi * 4; - auto dstHeight = dst + hi * c; + auto dstHeight = dst + hi * dstDepthOffset; for (int ci = 0; ci < cDiv4; ++ci) { for (int i = 0; i < 4; ++i) { - dstHeight[ci * 4 + i] = srcHeight[4 * ci * areaOffset[0] + i]; + dstHeight[ci * 4 + i] = srcHeight[4 * ci * srcAreaOffset + i]; } } } int cReamin = c - cAlign; - auto srcAlign = src + areaOffset[0] * cAlign; + auto srcAlign = src + srcAreaOffset * cAlign; auto dstAlign = dst + cAlign; for (int hi = 0; hi < area; ++hi) { auto srcHeight = srcAlign + hi * 4; - auto dstHeight = dstAlign + hi * c; + auto dstHeight = dstAlign + hi * dstDepthOffset; for (int ci = 0; ci < cReamin; ++ci) { dstHeight[ci] = srcHeight[ci]; @@ -4224,7 +4478,9 @@ void MNNCoreFunctionInit() { #ifdef MNN_SUPPORT_TRANSFORMER_FUSE gCoreFunction->MNNAttenPackAndScaleSingleHead = MNNAttenPackAndScaleSingleHead; gCoreFunction->MNNFlashAttentionUpdateBlockOutput = MNNFlashAttentionUpdateBlockOutput; -#endif + gCoreFunction->MNNQuantAttentionKey = MNNQuantAttentionKey; + gCoreFunction->MNNQuantAttentionValue = MNNQuantAttentionValue; +#endif // MNN_SUPPORT_TRANSFORMER_FUSE gCoreFunction->MNNReluWithSlopeChannel = 
MNNReluWithSlopeChannel; gCoreFunction->MNNPoolingAvg = (decltype(gCoreFunction->MNNPoolingAvg))(poolingAvg); @@ -4258,6 +4514,7 @@ void MNNCoreFunctionInit() { gCoreFunction->supportSDot = gCPUInfo.dot; gCoreFunction->supportI8mm = gCPUInfo.i8mm; gCoreFunction->supportSME2 = gCPUInfo.sme2; + gCoreFunction->smeCoreNumber = gCPUInfo.smeCoreNumber; gCoreFunction->MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A; gCoreFunction->MNNReorderWeightInt4 = MNNReorderWeightInt4; gCoreFunction->MNNSumWeightInt8 = MNNSumWeightInt8; @@ -4265,6 +4522,8 @@ void MNNCoreFunctionInit() { if (gCoreFunction->supportSDot) { gCoreFunction->MNNReorderWeightInt4 = MNNReorderWeightInt4Arm82; gCoreFunction->MNNSumWeightInt8 = MNNSumWeightInt8Arm82; + gCoreFunction->arm82MatmulRelatedFunctions.MNNReorderWeightInt4 = MNNReorderWeightInt4Arm82; + gCoreFunction->arm82MatmulRelatedFunctions.MNNSumWeightInt8 = MNNSumWeightInt8Arm82; } if (gCoreFunction->supportI8mm) { gCoreFunction->MNNReorderWeightInt4 = MNNReorderWeightInt4Arm86; @@ -4287,33 +4546,28 @@ void MNNCoreFunctionInit() { #ifdef __aarch64__ if (gCoreFunction->supportSDot) { gCoreFunction->MNNGeneralIm2Col = MNNGeneralIm2col_Fp32Arm82; + gCoreFunction->arm82MatmulRelatedFunctions.MNNGeneralIm2Col = MNNGeneralIm2col_Fp32Arm82; } if (gCoreFunction->supportI8mm) { gCoreFunction->MNNGeneralIm2Col = MNNGeneralIm2col_Fp32Arm86; } #endif #endif - { // int8MatmulRelatedFunctions - gCoreFunction->int8MatmulRelatedFunctions.MNNReorderWeightInt4 = gCoreFunction->MNNReorderWeightInt4; - gCoreFunction->int8MatmulRelatedFunctions.MNNSumWeightInt8 = gCoreFunction->MNNSumWeightInt8; - gCoreFunction->int8MatmulRelatedFunctions.MNNGeneralIm2Col = gCoreFunction->MNNGeneralIm2Col; - } + + #ifdef __aarch64__ #ifdef MNN_SME2 if (gCoreFunction->supportSME2) { // Int8 Gemm related gCoreFunction->MNNSumWeightInt8 = MNNSumWeightInt8Sme2_Hp32; - gCoreFunction->MNNSumWeightInt8SmeHp64 = MNNSumWeightInt8Sme2_Hp128; + gCoreFunction->MNNSumWeightInt8SmeHp128 = 
MNNSumWeightInt8Sme2_Hp128; gCoreFunction->MNNReorderWeightInt4 = MNNReorderWeightInt4Sme2; - gCoreFunction->sme2Int8MatmulRelatedFuncionsHp32.MNNSumWeightInt8 = MNNSumWeightInt8Sme2_Hp32; - gCoreFunction->sme2Int8MatmulRelatedFuncionsHp32.MNNSumWeightInt8SmeHp64 = MNNSumWeightInt8Sme2_Hp128; - gCoreFunction->sme2Int8MatmulRelatedFuncionsHp32.MNNReorderWeightInt4 = MNNReorderWeightInt4Sme2; #ifdef MNN_LOW_MEMORY gCoreFunction->MNNGeneralIm2Col = MNNGeneralIm2col_Fp32Sme2; - gCoreFunction->sme2Int8MatmulRelatedFuncionsHp32.MNNGeneralIm2Col = MNNGeneralIm2col_Fp32Sme2; #endif + gCoreFunction->int8MatmulRelatedFunctions.MNNSumWeightInt8SmeHp128 = MNNSumWeightInt8Sme2_Hp128; // Float Gemm related gCoreFunction->MNNPackedMatMul = MNNPackedMatMulFP32_SME2; @@ -4324,6 +4578,13 @@ void MNNCoreFunctionInit() { } #endif // MNN_SME2 #endif // __aarch64__ + + + { // Update the function pointers in the int8MatmulRelatedFunctions struct. + gCoreFunction->int8MatmulRelatedFunctions.MNNReorderWeightInt4 = gCoreFunction->MNNReorderWeightInt4; + gCoreFunction->int8MatmulRelatedFunctions.MNNSumWeightInt8 = gCoreFunction->MNNSumWeightInt8; + gCoreFunction->int8MatmulRelatedFunctions.MNNGeneralIm2Col = gCoreFunction->MNNGeneralIm2Col; + } MNNCoreInt8FunctionInit(); MNNFunctionInit(); } diff --git a/source/backend/cpu/compute/CommonOptFunction.h b/source/backend/cpu/compute/CommonOptFunction.h index 2182bccc7f..7aeee8a246 100644 --- a/source/backend/cpu/compute/CommonOptFunction.h +++ b/source/backend/cpu/compute/CommonOptFunction.h @@ -18,6 +18,10 @@ #include "core/Macro.h" #include "backend/cpu/compute/Int8FunctionsOpt.h" +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE +#define MNN_FLASH_ATTENTION_BLOCK_SIZE 64 +#endif + extern "C" { #ifdef __aarch64__ #ifdef MNN_LOW_MEMORY @@ -41,6 +45,8 @@ void MNNPackedMatMulRemainFP32_SME2(float* C, const float* A, const float* B, si #endif // __aarch64__ +void MNNQuantAttentionKey(int8_t* dst, const float* source, float* sumKey, float* maxKey, int32_t* 
params); +void MNNQuantAttentionValue(int8_t* dst, const float* source, float* valueQuantInfo, int32_t* params); void MNNFp32ToFp8(uint8_t* dst, const float* src, size_t size); void MNNFp8ToFp32(float* dst, const uint8_t* src, size_t size); @@ -121,8 +127,8 @@ void MNNReluWithSlopeCommon(float* dst, const float* src, size_t size, float slo void MNNHardSwishCommon(float* dst, const float* src, size_t size); void MNNGeluCommon(float* dst, const float* src, size_t size); void MNNGeluStandardCommon(float* dst, const float* src, size_t size); -void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize); void MNNNorm(float* dest, const float* source, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm = false); +void MNNSoftmax(float* softmaxDst, const float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize, int kvSeqOffset, int validOffset, int pack = 1, bool mask = false); // Get Pack for MatMul's e , l , h , the pack number must be 1 or 4 * n void MNNGetMatMulPackMode(int* eP, int *lP, int* hP); @@ -231,7 +237,7 @@ namespace MNN { struct MatmulRelatedFunctions { // from coreFunctions void (*MNNSumWeightInt8)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP) = nullptr; - void (*MNNSumWeightInt8SmeHp64)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP) = nullptr; + void (*MNNSumWeightInt8SmeHp128)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP) = nullptr; void (*MNNReorderWeightInt4)(uint8_t* dest, const uint8_t* source, int32_t* shape, size_t size, float* kernelsum) = nullptr; void(*MNNGeneralIm2Col)(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack) = nullptr; @@ -248,6 +254,8 @@ struct MatmulRelatedFunctions { 
void(*MNNGemmInt8AddBiasScale_w4_Unit_FP32_DecodeMax)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; void(*Int8GemmKernel_W4)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; void(*MNNSumByAxisLForMatmul_A)(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams) = nullptr; + + int eP; }; struct CoreFunctions { @@ -262,6 +270,7 @@ struct CoreFunctions { bool supportSDot = false; bool supportI8mm = false; bool supportSME2 = false; + int smeCoreNumber = 0; /**MatMul Pack and Functions*/ void(*MNNGetMatMulPackMode)(int* eP, int *lP, int* hP); void(*MNNGetSparseMatMulPackMode)(int* eP, int *lP, int* hP); @@ -271,6 +280,13 @@ struct CoreFunctions { // parameters: e, l, h, CStride, AStride, BStride void(*MNNPackedMatMul)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void(*MNNPackedMatMulRemain)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); + // int8 matmul related + void(*MNNSumByAxisLForMatmul_A)(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams); + void(*MNNReorderWeightInt4)(uint8_t* dest, const uint8_t* source, int32_t* shape, size_t size, float* kernelsum); + void(*MNNSumWeightInt8)(float* kernlesum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP); + void(*MNNSumWeightInt8SmeHp128)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP); + + // cpu dynamic quant void(*MNNAbsMax)(const float* source, float* absmax, size_t 
src_depth_quad, size_t realSize, int pack) = nullptr; void(*MNNQuantScale)(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch) = nullptr; void(*MNNDynamicQuant)(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack, const float* bias) = nullptr; @@ -402,20 +418,18 @@ struct CoreFunctions { void(*MNN2BitcopyFast)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); void(*MNN1BitcopyFast)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); void(*MNNAccumulateSequenceNumber)(float* dst, const float* src, int size); - void(*MNNSumByAxisLForMatmul_A)(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams); - void(*MNNReorderWeightInt4)(uint8_t* dest, const uint8_t* source, int32_t* shape, size_t size, float* kernelsum); - void(*MNNSumWeightInt8)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP); - void(*MNNSumWeightInt8SmeHp64)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP); // Attention void(*MNNAttenUnpackAndConvertFp16)(float* dst, float* src, size_t depth, size_t planesize, int pack); void(*MNNAttenPackAndConvertFp32)(float* dst, float* src, const int32_t* units, size_t depth, size_t planesize); void(*MNNAttenPackAndScaleSingleHead)(float* dst, const float* srcHeadBase, size_t srcRowStride, const float* scale, const int32_t* units, size_t seqLen, size_t headDim); - void(*MNNFlashAttentionUpdateBlockOutput)(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes); - void(*MNNSoftmax)(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize); + void(*MNNFlashAttentionUpdateBlockOutput)(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int 
idx, int kvBlocks, int size, int bytes, int seqStart); + void(*MNNSoftmax)(float* softmaxDst, const float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize, int kvSeqOffset, int validOffset, int pack, bool mask); + void(*MNNQuantAttentionKey)(int8_t* dst, const float* source, float* sumKey, float* maxKey, int32_t* params); + void(*MNNQuantAttentionValue)(int8_t* dst, const float* source, float* valueQuantInfo, int32_t* params); MatmulRelatedFunctions int8MatmulRelatedFunctions; - MatmulRelatedFunctions sme2Int8MatmulRelatedFuncionsHp32; + MatmulRelatedFunctions arm82MatmulRelatedFunctions; }; void MNNCoreFunctionInit(); CoreFunctions* MNNGetCoreFunctions(); diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp index 412e3cd978..6babfceb22 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp @@ -18,8 +18,6 @@ #define QUANT_INFO_BYTES 4 #define WEIGHT_ONLINE_REORDER 8 -#define SME_DECODE_MAXHP 128 -#define SME_INT8MATMUL_EP 16 namespace MNN { ConvInt8TiledExecutor::ConvInt8TiledExecutor(Backend* backend, const Op* op): CPUConvolution(op->main_as_Convolution2D()->common(), backend) {} @@ -51,7 +49,7 @@ ErrorCode ConvInt8TiledExecutor::onResize(const std::vector& inputs, co void ConvInt8TiledExecutor::initializeConvInt8QuantInfo(std::shared_ptr &resourceInt8, const Convolution2D *conv2D, std::shared_ptr quanCommon) { // input/output scale&zeorpoint if (conv2D->symmetricQuan()) { - resourceInt8->mActBits = conv2D->symmetricQuan()->nbits(); + resourceInt8->mWeightBits = conv2D->symmetricQuan()->nbits(); } if (conv2D->bias() && (conv2D->quanParameter()->alpha() || quanCommon->alpha.get())) { resourceInt8->mUseConvQuan = false; @@ -186,7 +184,7 @@ void ConvInt8TiledExecutor::packWeightAndQuantInfo(int8_t* dstbuffer, const int8 } } -static void _computeReorderQuantInfo(std::shared_ptr resource, 
int outputCount, int kernelCount, int pack, AutoStorage& reorderedQuantInfo, float* ikernelSum, int HP, bool realInt4OrInt8, bool canUseInt4, const float* quanInfoPtr, bool asymmetric) { +static void _computeReorderQuantInfo(float* weightKernelSum, int32_t* paramsKernelSum, bool blockQuantInput, bool canUseInt4, bool asyQuantWeight, float* quanInfoPtr, int outputCount, int kernelCount, int pack, AutoStorage& reorderedQuantInfo, float* ikernelSum, int HP, bool realInt4OrInt8) { // Only used for dynamic quant: // copy gemm bias // copy/compute real dequant bias/scale @@ -194,7 +192,8 @@ static void _computeReorderQuantInfo(std::shared_ptrmBlockNum; + int blockNum = paramsKernelSum[0]; + int kernelSumSize = paramsKernelSum[1]; int scaleSize = blockNum * ocUp4; // pack size. int blockSize = kernelCount / blockNum; int originOffset = 0; @@ -208,13 +207,11 @@ static void _computeReorderQuantInfo(std::shared_ptrmWeightKernelSum->host(); - ::memset(weightKernelSum, 0, resource->mWeightKernelSum->size()); + ::memset(weightKernelSum, 0, kernelSumSize * QUANT_INFO_BYTES); - bool blockQuantInput = (resource->mWeightKernelSum->length(0) / QUANT_INFO_BYTES == ocUpHp) ? false : true; int ocDiv4 = UP_DIV(outputCount, pack); // resource->mWeightKernelSum: [hU,blocknum,hP] - if (asymmetric) { + if (asyQuantWeight) { for (int i = 0; i < outputCount; ++i) { float accum = 0.f; auto ocOutside = i / HP; @@ -267,14 +264,84 @@ static void _computeReorderQuantInfo(std::shared_ptr& divides, int oc, int threads, int pack, int planeSize, int divisionRatio, int smeCores) { + // workload + auto ocDivPack = UP_DIV(oc, pack); + auto workUnit = UP_DIV(ocDivPack, divisionRatio * smeCores + 1 * (threads - smeCores)); + int calOcMain = ALIMIN(ROUND_UP(workUnit * pack * smeCores * divisionRatio, GEMM_INT8_UNIT_SME2_128), oc); + if (calOcMain <= ocMain) { // The purpose of this function is to increase the value of ocMain. 
+ return; + } + ocMain = calOcMain; + ocBranch = oc - ocMain; + divides.assign(threads + 1, ocDivPack); + divides[0] = 0; + + // runtime UNIT for different core and different process(prefill or decode) + auto rtUnit4Sme = planeSize == 1? GEMM_INT8_UNIT_SME2_128 : GEMM_INT8_UNIT_SME2; + // mOcMain + auto ocPerSmeCore = ALIMIN(UP_DIV(UP_DIV(ROUND_UP(ocMain, pack), rtUnit4Sme), smeCores) * (rtUnit4Sme / pack), UP_DIV(ocMain, pack)); + for (int i = 0; i < smeCores; ++i) { + divides[i + 1] = ALIMIN(divides[i] + ocPerSmeCore, UP_DIV(ocMain, pack)); + } + + // ocRemain + if (ocBranch > 0) { + auto ocPerNeonCore = UP_DIV(UP_DIV(ROUND_UP(ocBranch, pack), GEMM_INT8_UNIT_ARM82), threads - smeCores) * (GEMM_INT8_UNIT_ARM82 / pack); + for (int i = smeCores + 1; i < threads + 1; ++i) { + divides[i] = ALIMIN(divides[i - 1] + ocPerNeonCore, ocDivPack); + } + } +} + +static inline void _getProportions(int totalProp, int& intensiveProp, int& lightProp) { + // compute the proportions of different kernels + lightProp = totalProp % 8; + intensiveProp = totalProp / 8 % 8; + if (lightProp == 0 && intensiveProp == 0) { + // pass + // Don't use mixed kernels + } else if (lightProp == 0) { + lightProp = 1; + } else if (intensiveProp == 0) { + intensiveProp = 6; + } else if (lightProp > intensiveProp) { + lightProp = 1; + } +} + +static inline void _computeDivides4Sme(std::vector& divides, int threads, int smeCoreNums, int size) { + divides.resize(threads + 1); + divides[0] = 0; + auto length = UP_DIV(size, smeCoreNums); + auto cur = length; + for (int i = 1; i < smeCoreNums + 1; ++i) { + divides[i] = cur; + cur = ALIMIN(cur + length, size); + } +} + +static inline void _updateMixedKernelFlag(bool &mixedKernel, bool &onlineReorderWeightSme, int threads, int eP, bool isDynamciQuant, bool postiveBothProp) { + mixedKernel = false; + if (threads >= 4 && eP == GEMM_INT8_DST_XUNIT_SME2 && isDynamciQuant && postiveBothProp) { + mixedKernel = true; + onlineReorderWeightSme = true; + } +} + 
DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const Op* op, std::shared_ptr quanCommon, bool isDynamicQuant) : ConvInt8TiledExecutor(backend, op) { // convolution info auto convOp = op->main_as_Convolution2D(); int kernelCount = mCommon->kernelX() * mCommon->kernelY(); int oc = convOp->common()->outputCount(); int ic = convOp->common()->inputCount(); + bool asyWeight = quanCommon ? quanCommon->asymmetric : false; + + mOcBranch = 0; + mOcMain = oc; int blockNum = 1; + int inputBlockNum = 1; if (quanCommon) { int dequantCnt = quanCommon->alphaSize; if (quanCommon->asymmetric) { @@ -288,42 +355,74 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O auto core = static_cast(backend)->int8Functions(); auto gcore = static_cast(backend)->functions(); const int threads = static_cast(backend)->threadNumber(); + const int pack = gcore->pack; + + // runtime hint auto option = static_cast(backend)->getRuntime()->hint().dynamicQuantOption; auto weightOnlineReorderOption = WEIGHT_ONLINE_REORDER & option; auto inputBlockQuantOption = option % WEIGHT_ONLINE_REORDER; + if (inputBlockQuantOption == 2) { + inputBlockNum = blockNum; + } + + _getProportions(static_cast(backend)->getRuntime()->hint().divisionRatio, mRatioPrefill, mRatioDecode); + mSmeCores = gcore->smeCoreNumber; mRelatedFunctions = *(static_cast(backend)->int8GemmFunctions()); + mArm82Functions = gcore->arm82MatmulRelatedFunctions; + + int UNITMain, SRC_UNITMain, DST_XUNITMain; + int UNITBranch = 0; int SRC_UNITBranch = 0, DST_XUNITBranch = 0; + mRelatedFunctions.MNNGetGemmUnit(&UNITMain, &SRC_UNITMain, &DST_XUNITMain); + + if (mArm82Functions.MNNGetGemmUnit != nullptr) { // exclude cpu does not support arm82 + mArm82Functions.MNNGetGemmUnit(&UNITBranch, &SRC_UNITBranch, &DST_XUNITBranch); + } - int UNIT, SRC_UNIT, DST_XUNIT; - mRelatedFunctions.MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); - int pack = gcore->pack; // prefer to maximum decode performance & the 
machine supports 'sme2' & the runtime backend is 'sme2' -> mOnlineReorderWeightSme=true - mOnlineReorderWeightSme = (weightOnlineReorderOption > 0 && DST_XUNIT == SME_INT8MATMUL_EP); + mOnlineReorderWeightSme = (weightOnlineReorderOption > 0 && DST_XUNITMain == GEMM_INT8_DST_XUNIT_SME2); if (isDynamicQuant == false) { mOnlineReorderWeightSme = false; } + _updateMixedKernelFlag(mMixedKernel, mOnlineReorderWeightSme, threads, DST_XUNITMain, isDynamicQuant, mRatioDecode&&mRatioPrefill); + + if (mMixedKernel) { + // total work: UP_DIV(oc, pack) + // (sme's work / neon's work) = divisionRatio + auto workUnit = UP_DIV(UP_DIV(oc, pack), mRatioDecode * mSmeCores + 1 * (threads - mSmeCores)); + mOcMain = ALIMIN(ROUND_UP(workUnit * pack * mSmeCores * mRatioDecode, GEMM_INT8_UNIT_SME2_128), oc);; + mOcBranch = oc - mOcMain; + } if (mOnlineReorderWeightSme) { - UNIT = SME_DECODE_MAXHP; + UNITMain = GEMM_INT8_UNIT_SME2_128; } // compute info - int ocUp4 = ROUND_UP(oc, pack); - int ocUpHp = ROUND_UP(oc, ALIMAX(UNIT, pack)); - int lU = UP_DIV(ic / blockNum, SRC_UNIT) * kernelCount; - int scaleSize = ocUp4 * blockNum; - std::vector shape = {blockNum, UP_DIV(oc, UNIT), lU, UNIT, SRC_UNIT}; + int ocUp4Main = ROUND_UP(mOcMain, pack); + int ocUpHpMain = ROUND_UP(mOcMain, UNITMain); + int lUMain = UP_DIV(ic / blockNum, SRC_UNITMain) * kernelCount; + int scaleSizeMain = ocUp4Main * blockNum; + + int ocUp4Branch = ROUND_UP(mOcBranch, pack); + int ocUpHpBranch = UNITBranch != 0 ? ROUND_UP(mOcBranch, UNITBranch) : 0; + int ocDivHpBranch = UNITBranch != 0 ? UP_DIV(mOcBranch, UNITBranch) : 0; + int lUBranch = UNITBranch != 0 ? 
UP_DIV(ic / blockNum, SRC_UNITBranch) * kernelCount : 0; + int scaleSizeBranch = ocUp4Branch * blockNum; + + std::vector shapeMain = {blockNum, UP_DIV(mOcMain, UNITMain), lUMain, UNITMain, SRC_UNITMain}; + std::vector shapeBranch = {blockNum, ocDivHpBranch, lUBranch, UNITBranch, SRC_UNITBranch}; mResourceInt8.reset(new CPUConvolution::ResourceInt8); - mResourceInt8->mWeightAsymmetricQuant = quanCommon ? quanCommon->asymmetric : false; - mResourceInt8->mActBits = 8; + mResourceInt8->mWeightAsymmetricQuant = asyWeight; + mResourceInt8->mWeightBits = 8; mResourceInt8->mBlockNum = blockNum; - if ((quanCommon && quanCommon->canUseInt4) || (convOp->symmetricQuan() && convOp->symmetricQuan()->nbits() <= 4)) { - shape[4] = SRC_UNIT / 2; - mResourceInt8->mActBits = 4; + if (quanCommon && quanCommon->canUseInt4) { + shapeMain[4] = SRC_UNITMain / 2; + shapeBranch[4] = SRC_UNITBranch / 2; + mResourceInt8->mWeightBits = 4; mResourceInt8->mWeightAsymmetricQuant = true; // offset: 8 from uint8_t } - if (isDynamicQuant) { - mResourceInt8->mDynamicQuant = true; - } + mResourceInt8->mDynamicQuant = isDynamicQuant ? 
true : false; + // Relu/Relu6 post parameters auto postPtr = getPostParameters(); mResourceInt8->mReluThreshold.resize(2); @@ -333,15 +432,14 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O gcore->MNNFp32ToLowp(mResourceInt8->mReluThreshold.data(), reinterpret_cast(mResourceInt8->mReluThreshold.data()), 2); } // buffer allocate - auto quantlen = 2 * blockNum * ROUND_UP(oc, UNIT) * QUANT_INFO_BYTES; - auto weightlen = shape[0] * shape[1] * shape[2] * shape[3] * shape[4]; - mResourceInt8->mWeightInt8.reset(Tensor::createDevice({weightlen + quantlen})); - mResourceInt8->mOriginBias.reset(Tensor::createDevice({ocUp4})); // float - if (inputBlockQuantOption != 2) { - mResourceInt8->mWeightKernelSum.reset(Tensor::createDevice({QUANT_INFO_BYTES * ocUpHp})); - } else { - mResourceInt8->mWeightKernelSum.reset(Tensor::createDevice({blockNum * QUANT_INFO_BYTES * ocUpHp})); - } + auto quantlenMain = 2 * blockNum * ROUND_UP(mOcMain, UNITMain) * QUANT_INFO_BYTES; + auto weightlenMain = shapeMain[0] * shapeMain[1] * shapeMain[2] * shapeMain[3] * shapeMain[4]; + auto quantlenBranch = 2 * blockNum * ocUpHpBranch * QUANT_INFO_BYTES; + auto weightlenBranch = shapeBranch[0] * shapeBranch[1] * shapeBranch[2] * shapeBranch[3] * shapeBranch[4]; + + mResourceInt8->mWeightInt8.reset(Tensor::createDevice({weightlenMain + quantlenMain + weightlenBranch + quantlenBranch})); + mResourceInt8->mOriginBias.reset(Tensor::createDevice({ocUp4Main + ocUpHpBranch})); // float + mResourceInt8->mWeightKernelSum.reset(Tensor::createDevice({inputBlockNum * QUANT_INFO_BYTES * (ocUpHpMain + ocUpHpBranch)})); auto res = backend->onAcquireBuffer(mResourceInt8->mOriginBias.get(), Backend::STATIC); res &= backend->onAcquireBuffer(mResourceInt8->mWeightKernelSum.get(), Backend::STATIC); @@ -357,18 +455,22 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O } // read weight, weight's scale&bias, convolution bias - 
::memset(mResourceInt8->mOriginBias->host(), 0, ocUp4 * sizeof(float)); + ::memset(mResourceInt8->mOriginBias->host(), 0, mResourceInt8->mOriginBias->size()); // dynamic quant - bool directReadInt4weight = (kernelCount == 1 && ROUND_UP(oc, UNIT) == oc && ROUND_UP(ic, SRC_UNIT) == ic); + bool directReadInt4weight = (kernelCount == 1 && ROUND_UP(mOcMain, UNITMain) == mOcMain && ROUND_UP(ic, SRC_UNITMain) == ic); // TODO:fix this + auto ocMain = mOcMain; + auto ocBranch = mOcBranch; auto target = mResourceInt8; - auto funcs = mRelatedFunctions; + auto funcsMain = mRelatedFunctions; + auto funcsBranch = mArm82Functions; auto needToReorderWeightOnline4Sme = mOnlineReorderWeightSme; // Save bias if (convOp->bias()) { - ::memcpy(mResourceInt8->mOriginBias->host(), convOp->bias()->data(), oc * sizeof(float)); + ::memcpy(mResourceInt8->mOriginBias->host(), convOp->bias()->data(), convOp->bias()->size() * sizeof(float)); } - auto function = [needToReorderWeightOnline4Sme, funcs, shape, UNIT, SRC_UNIT, DST_XUNIT, quanCommon, weightlen, scaleSize, directReadInt4weight, blockNum, ic, oc, kernelCount, pack, convOp, gcore, target]() -> int { + + auto reorderFunc = [=](decltype(mRelatedFunctions) funcs, std::vector shape, int UNIT, int SRC_UNIT, int DST_XUNIT, int weightlen, int scaleSize, int oc, int offsetTg, bool fastReadWeight, int8_t** addressPtr, weightSummerFuncion sumFunc) -> int { auto sh = shape; AutoStorage weightReordered(weightlen); AutoStorage reorderedQuantInfo(2 * scaleSize * QUANT_INFO_BYTES); @@ -379,26 +481,19 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O } memset(kernelsum.get(), 0, blockNum * ROUND_UP(oc, UNIT) * QUANT_INFO_BYTES); - const uint8_t* srcPtr = nullptr; - if (quanCommon) { - srcPtr = (uint8_t*)quanCommon->weight.get(); - } else { - srcPtr = (uint8_t*)convOp->symmetricQuan()->weight()->data(); - } /* 1. 
reorder weight */ - if (target->mActBits == 4 && directReadInt4weight) { + auto srcPtr = (uint8_t*)addressPtr[0]; + if (target->mWeightBits == 4 && fastReadWeight) { auto dstPtr = (uint8_t*)weightReordered.get(); ::memset(dstPtr, 0, weightlen); funcs.MNNReorderWeightInt4(dstPtr, srcPtr, sh.data(), sh.size(), (float*)kernelsum.get()); - } else { + } else { // int4 weight but oc/ic not packed int blocksize = ic * kernelCount / blockNum; int originOffset = 0; int32_t info[6] = {blockNum, oc, ic, kernelCount, UNIT, SRC_UNIT}; - if (target->mActBits == 4) { + if (target->mWeightBits == 4) { originOffset = -8; - auto weightLength = quanCommon ? quanCommon->weight.size() : convOp->symmetricQuan()->weight()->size(); -// auto srcPtr = reinterpret_cast(quanCommon->weight.get()); - std::vector tmpWeight(weightLength * 2); + std::vector tmpWeight(oc * ic * kernelCount); for (int j = 0; j < oc; ++j) { for (int k = 0; k < blockNum; ++k) { for (int i = 0; i < blocksize; ++i) { @@ -414,12 +509,10 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O MNN_ERROR("Weight reorder memory not enough!\n"); return -1; } - if (!needToReorderWeightOnline4Sme) { - reorderWeight(packedInt8weight.get(), (uint8_t*)tmpWeight.data(), info, 0, (float*)kernelsum.get(), funcs.MNNSumWeightInt8); - } else { - reorderWeight(packedInt8weight.get(), (uint8_t*)tmpWeight.data(), info, 0, (float*)kernelsum.get(), funcs.MNNSumWeightInt8SmeHp64); - } - // pack two int4 to int8 + + reorderWeight(packedInt8weight.get(), (uint8_t*)tmpWeight.data(), info, 0, (float*)kernelsum.get(), sumFunc); + + // pack two int4 to int8 int leng = weightlen * 2; auto srcint4Ptr = (uint8_t*)packedInt8weight.get(); auto dstint4Ptr = (uint8_t*)weightReordered.get(); @@ -430,7 +523,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O auto dst0 = dstint4Ptr + i * halfPermuteStride; for (int j = 0; j < halfPermuteStride; ++j) { int s0, s1, d; - if (DST_XUNIT == 
SME_INT8MATMUL_EP) { // SME2 + if (DST_XUNIT == GEMM_INT8_DST_XUNIT_SME2) { // SME2 s0 = src0[2 * j + 0]; s1 = src0[2 * j + 1]; d = s0 + (s1) * 16; @@ -443,11 +536,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O } } } else { // int8 weight - if (!needToReorderWeightOnline4Sme) { - reorderWeight((uint8_t*)weightReordered.get(), srcPtr, info, 0, (float*)kernelsum.get(), funcs.MNNSumWeightInt8); - } else { - reorderWeight((uint8_t*)weightReordered.get(), srcPtr, info, 0, (float*)kernelsum.get(), funcs.MNNSumWeightInt8SmeHp64); - } + reorderWeight((uint8_t*)weightReordered.get(), srcPtr, info, 0, (float*)kernelsum.get(), sumFunc); } } if (convOp->symmetricQuan() && convOp->symmetricQuan()->bias()) { @@ -463,28 +552,60 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O } /* 2. compute and order dequant scale&bias */ bool notConvertInt4ToInt8 = true; - bool canUseInt4 = ((quanCommon && quanCommon->canUseInt4) || (convOp->symmetricQuan() && convOp->symmetricQuan()->nbits() <= 4)) ? true : false; - if (quanCommon && quanCommon->canUseInt4 && !directReadInt4weight) { + if (target->mWeightBits == 4 && !fastReadWeight) { notConvertInt4ToInt8 = false; } - const float* quantInfoPtr = quanCommon ? quanCommon->alpha.get() : convOp->symmetricQuan()->scale()->data(); - bool asymmtric = (quanCommon && quanCommon->asymmetric) ? 
true: false; - _computeReorderQuantInfo(target, oc, kernelCount * ic, pack, reorderedQuantInfo, (float*)kernelsum.get(), UNIT, notConvertInt4ToInt8, canUseInt4, quantInfoPtr, asymmtric); + int32_t paramsKernelSum[2] = {blockNum, inputBlockNum * ROUND_UP(oc, UNIT)}; + float* weightKernelSum = (float*)addressPtr[2]; + float* quanScalePtr = (float*)addressPtr[3]; + _computeReorderQuantInfo(weightKernelSum, paramsKernelSum, (inputBlockQuantOption == 2), target->mWeightBits == 4, asyWeight, quanScalePtr, oc, kernelCount * ic, pack, reorderedQuantInfo, (float*)kernelsum.get(), UNIT, notConvertInt4ToInt8); /* 3. put weight and quantInfo together */ int32_t params[6] = {shape[0], shape[1], shape[2], shape[3], shape[4], ROUND_UP(oc, pack)}; - ConvInt8TiledExecutor::packWeightAndQuantInfo(target->mWeightInt8->host(), (int8_t*)weightReordered.get(), reorderedQuantInfo.get(), params, QUANT_INFO_BYTES); + int8_t* weightInt8 = addressPtr[1]; + + ConvInt8TiledExecutor::packWeightAndQuantInfo(weightInt8, (int8_t*)weightReordered.get(), reorderedQuantInfo.get(), params, QUANT_INFO_BYTES); + + return 0; + }; + auto function = [=]() -> int { + bool fastReadWeight = (kernelCount == 1 && ROUND_UP(ocMain, UNITMain) == ocMain && ROUND_UP(ic, SRC_UNITMain) == ic); + weightSummerFuncion sumFunc = funcsMain.MNNSumWeightInt8; + if (mOnlineReorderWeightSme) { + sumFunc = funcsMain.MNNSumWeightInt8SmeHp128; + } + + int8_t* addressPtr[4]; + addressPtr[0] = quanCommon? quanCommon->weight.get() : (int8_t*)convOp->symmetricQuan()->weight()->data(); + addressPtr[1] = target->mWeightInt8->host(); + addressPtr[2] = target->mWeightKernelSum->host(); + addressPtr[3] = quanCommon? 
(int8_t*) quanCommon->alpha.get() : (int8_t*)convOp->symmetricQuan()->scale()->data(); + + reorderFunc(funcsMain, shapeMain, UNITMain, SRC_UNITMain, DST_XUNITMain, weightlenMain, scaleSizeMain, ocMain, 0, fastReadWeight, addressPtr, sumFunc); + + if (ocBranch > 0) { + // update the address of weight source, weight destination, weight kernel sum and weight scale + addressPtr[0] += (target->mWeightBits == 4 ? ocMain * ic * kernelCount / 2 : ocMain * ic * kernelCount); // ocMain%2==0, so divides 2 directly + addressPtr[1] += (weightlenMain + quantlenMain); + addressPtr[2] += ROUND_UP(ocMain, UNITMain) * inputBlockNum * QUANT_INFO_BYTES; + addressPtr[3] += (quanCommon->asymmetric ? 2 * ocMain * blockNum * QUANT_INFO_BYTES : ocMain * blockNum * QUANT_INFO_BYTES); + sumFunc = funcsBranch.MNNSumWeightInt8; + + fastReadWeight = (kernelCount == 1 && ROUND_UP(ocBranch, UNITMain) == ocBranch && ROUND_UP(ic, SRC_UNITMain) == ic); + reorderFunc(funcsBranch, shapeBranch, UNITBranch, SRC_UNITBranch, DST_XUNITBranch, weightlenBranch, scaleSizeBranch, ocBranch, 1, fastReadWeight, addressPtr, sumFunc); + } return 0; }; + static_cast(backend)->enqueueTask(std::move(function)); if (!isDynamicQuant) { mResourceInt8->mDynamicQuant = false; - std::shared_ptr scaleAndBias(new float[ocUpHp * 2 * mBlockNum], [](void* ptr) { + std::shared_ptr scaleAndBias(new float[ocUpHpMain * 2 * mBlockNum], [](void* ptr) { delete [] (float*)ptr; }); - memset(scaleAndBias.get(), 0, ocUpHp * 2 * mBlockNum * sizeof(float)); + memset(scaleAndBias.get(), 0, ocUpHpMain * 2 * mBlockNum * sizeof(float)); int weightSize; bool weightAsy = false; @@ -515,7 +636,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O int scaleSize = quantCount / 2; for (int i = 0; i < scaleSize; ++i) { ((float*)scaleAndBias.get())[i] = quanCommon->alpha.get()[2 * i + 1]; - ((float*)scaleAndBias.get())[i + ocUpHp] = quanCommon->alpha.get()[2 * i]; + ((float*)scaleAndBias.get())[i + ocUpHpMain] = 
quanCommon->alpha.get()[2 * i]; } } } @@ -535,7 +656,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O if(convOp->symmetricQuan() && convOp->symmetricQuan()->method() == QuantizeAlgo_OVERFLOW_AWARE){ mGemmKernel = mRelatedFunctions.Int8GemmKernelFast; } - if (mResourceInt8->mActBits == 4) { + if (mResourceInt8->mWeightBits == 4) { mGemmKernel = mRelatedFunctions.Int8GemmKernel_W4; } #endif @@ -568,36 +689,51 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input mUseBatchQuan = false; mIm2ColBasedInt8 = true; m4BitPtq = false; - if (mResourceInt8->mDynamicQuant == false && mResourceInt8->mActBits == 4) { + if (mResourceInt8->mDynamicQuant == false && mResourceInt8->mWeightBits == 4) { m4BitPtq = true; } + + // backend info + auto core = static_cast(backend())->int8Functions(); + auto gcore =static_cast(backend())->functions(); + const int threads = static_cast(backend())->threadNumber(); + mRelatedFunctions = *(static_cast(backend())->int8GemmFunctions()); + mArm82Functions = gcore->arm82MatmulRelatedFunctions; + + // runtime hint auto option = static_cast(backend())->getRuntime()->hint().dynamicQuantOption; + mSmeCores = gcore->smeCoreNumber; auto inputBlockQuantOption = option % WEIGHT_ONLINE_REORDER; auto weightOnlineReorderOption = WEIGHT_ONLINE_REORDER & option; + + _getProportions(static_cast(backend())->getRuntime()->hint().divisionRatio, mRatioPrefill, mRatioDecode); + + // feature map info int batch = inputs[0]->batch(); int inC = inputs[0]->channel(); auto output = outputs[0]; + int kernelCount = mCommon->kernelY() * mCommon->kernelX(); int inputPlane = batch * inputs[0]->width() * inputs[0]->height(); auto planeSize = output->width() * output->height() * output->batch(); - auto core = static_cast(backend())->int8Functions(); - auto gcore =static_cast(backend())->functions(); - const int threads = static_cast(backend())->threadNumber(); - - mRelatedFunctions = 
*(static_cast(backend())->int8GemmFunctions()); int UNIT, SRC_UNIT, DST_XUNIT; mRelatedFunctions.MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); - mOnlineReorderWeightSme = (weightOnlineReorderOption > 0 && DST_XUNIT == SME_INT8MATMUL_EP); + + mOnlineReorderWeightSme = (weightOnlineReorderOption > 0 && DST_XUNIT == GEMM_INT8_DST_XUNIT_SME2); if (mResourceInt8->mDynamicQuant == false) { mOnlineReorderWeightSme = false; } + + _updateMixedKernelFlag(mMixedKernel, mOnlineReorderWeightSme, threads, DST_XUNIT, mResourceInt8->mDynamicQuant, mRatioDecode&&mRatioPrefill); + if (mOnlineReorderWeightSme && planeSize == 1) { // Decode, set runtime unit - UNIT = SME_DECODE_MAXHP; + UNIT = GEMM_INT8_UNIT_SME2_128; } + mGemmUnits[0] = UNIT; mGemmUnits[1] = SRC_UNIT; mGemmUnits[2] = DST_XUNIT; - int kernelCount = mCommon->kernelY() * mCommon->kernelX(); + bool fastway = (kernelCount == 1) && (output->width() == inputs[0]->width()) && (output->height() == inputs[0]->height()) && (mCommon->strideX() * mCommon->strideY()) == 1; if (inputPlane > 1) { mUseBatchQuan = true; @@ -613,7 +749,7 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input } } - float weightBytes = mResourceInt8->mActBits == 4 ? 0.5 : 1; + float weightBytes = mResourceInt8->mWeightBits == 4 ? 
0.5 : 1; mBlockNum = mResourceInt8->mBlockNum; CPUConvolution::onResize(inputs, outputs); @@ -633,7 +769,6 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input } int matmulUnits[3] = {UNIT, SRC_UNIT, DST_XUNIT}; ConvolutionTiledExecutor::setIm2ColParameter(mIm2ColParamter, mCommon, inputs[0], outputs[0], mPadX, mPadY, gcore, core, gcore->pack, matmulUnits); - // input scale buffer // Im2col info int im2colBytes = 1; @@ -648,6 +783,9 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input int tileLimit = 0; int outC = output->channel(); int outC4 = UP_DIV(outC, gcore->pack); + mOcMain = outC; + mOcBranch = 0; + const int pack = gcore->pack; auto kernelCountUnit = mIm2ColParamter.kernelCountUnit; mSplitByOc = true; @@ -655,7 +793,7 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input float flop = gcore->bytes * planeSize * (ROUND_UP(output->channel(), gcore->pack) * kernelCountUnit * SRC_UNIT / 1024.0 / 1024.0 / 1024.0); float ios = (((CPUBackend*)backend())->getTensorSize(outputs[0], true) + ((CPUBackend*)backend())->getTensorSize(inputs[0], true) + ((CPUBackend*)backend())->getTensorSize(mResourceInt8->mWeightInt8.get()) * weightBytes) / (1024.0 * 1024.0 * 1024.0); - if (threads < planeSize || mOnlineReorderWeightSme) { // Thread split by output nhw. + if ((threads < planeSize || mOnlineReorderWeightSme) && !mMixedKernel) { // Thread split by output nhw. 
tileLimit = ALIMIN(tileLimitByC, UP_DIV(planeSize, threads)); mIm2ColCount = UP_DIV(tileLimit, DST_XUNIT); auto DynamicDestUnit = DST_XUNIT * mIm2ColCount; @@ -663,40 +801,62 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input if (mTileCount > threads || (mOnlineReorderWeightSme && planeSize > 1)) { mSplitByOc = false; } - } + if (mSplitByOc) { tileLimit = ALIMIN(tileLimitByC, planeSize); mIm2ColCount = UP_DIV(tileLimit, DST_XUNIT); auto DynamicDestUnit = DST_XUNIT * mIm2ColCount; mTileCount = UP_DIV(planeSize, DynamicDestUnit); - auto ocPerThread = UP_DIV(outC4, threads); - auto threadNeed = UP_DIV(outC4, ocPerThread); - int totalWork = outC4; - int part = 1; - if (UNIT > gcore->pack) { // AVX512:UNIT=64,pack=16 - MNN_ASSERT(UNIT % gcore->pack == 0); - int ocDivUnit = UP_DIV(outC4 * gcore->pack, UNIT); - ocPerThread = UP_DIV(ocDivUnit, threads); - threadNeed = UP_DIV(ocDivUnit, ocPerThread); - totalWork = ocDivUnit; - part = UNIT / gcore->pack; - } - mThreadNums = ALIMIN(threads, threadNeed); - mDivides.resize(threads+1); mDivides[0] = 0; - static_cast(backend())->computeDivideSizes(totalWork, mDivides.data() + 1, flop / ios); - for (int i = 0; i < mDivides.size(); ++i) { - mDivides[i] *= part; + // output channel divided by threads + if (!mMixedKernel) { + auto ocPerThread = UP_DIV(outC4, threads); + auto threadNeed = UP_DIV(outC4, ocPerThread); + int totalWork = outC4; + int part = 1; + if (UNIT > gcore->pack) { // AVX512:UNIT=64,pack=16 + MNN_ASSERT(UNIT % gcore->pack == 0); + int ocDivUnit = UP_DIV(outC4 * gcore->pack, UNIT); + ocPerThread = UP_DIV(ocDivUnit, threads); + threadNeed = UP_DIV(ocDivUnit, ocPerThread); + totalWork = ocDivUnit; + part = UNIT / gcore->pack; + } + mThreadNums = ALIMIN(threads, threadNeed); + + if (threads >= 4 && DST_XUNIT == GEMM_INT8_DST_XUNIT_SME2 && mResourceInt8->mDynamicQuant) { + _computeDivides4Sme(mDivides, threads, mSmeCores, totalWork); + } else { + mDivides.resize(threads+1); + mDivides[0] = 0; + 
static_cast(backend())->computeDivideSizes(totalWork, mDivides.data() + 1, flop / ios); + } + for (int i = 0; i < mDivides.size(); ++i) { + mDivides[i] *= part; + } + } else { + // workload + mOcMain = 0; // initialize for mixed kernel, before calculate + calculateSmeNeonWorkDivision(mOcMain, mOcBranch, mDivides, outC, threads, pack, planeSize, mRatioDecode, mSmeCores); + mThreadNums = threads; } } if (!mSplitByOc) { mThreadNums = ALIMIN(threads, mTileCount); - mDivides.resize(threads+1); - mDivides[0] = 0; - static_cast(backend())->computeDivideSizes(mTileCount, mDivides.data() + 1, flop / ios); + if (threads >= 4&&DST_XUNIT==GEMM_INT8_DST_XUNIT_SME2&&mResourceInt8->mDynamicQuant&&!mMixedKernel) { + _computeDivides4Sme(mDivides, threads, mSmeCores, mTileCount); + } else { + mDivides.resize(threads+1); + mDivides[0] = 0; + static_cast(backend())->computeDivideSizes(mTileCount, mDivides.data() + 1, flop / ios); + } + } + mDividesTmp.resize(threads + 1); + if (mMixedKernel) { + mOriginSmeWork = mDivides[mSmeCores]; } int ocUp4 = ROUND_UP(outC, gcore->pack); int k = mThreadNums; @@ -740,12 +900,12 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input } #ifdef MNN_LOW_MEMORY - { // Dynamic Quant kernels + if (!mMixedKernel) { // Dynamic Quant kernels, use single gemm kernel. 
mGemmKernel = mRelatedFunctions.Int8GemmKernel; if (mOnlineReorderWeightSme && planeSize == 1) { mGemmKernel = mRelatedFunctions.MNNGemmInt8AddBiasScale_Unit_FP32_DecodeMax; } - if (mResourceInt8->mActBits == 4) { + if (mResourceInt8->mWeightBits == 4) { mGemmKernel = mRelatedFunctions.Int8GemmKernel_W4; if (mOnlineReorderWeightSme && planeSize == 1) { mGemmKernel = mRelatedFunctions.MNNGemmInt8AddBiasScale_w4_Unit_FP32_DecodeMax; @@ -757,7 +917,7 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input if (mOnlineReorderWeightSme && planeSize == 1) { mGemmKernel = mRelatedFunctions.MNNGemmInt8AddBiasScale_Unit_FP16_DecodeMax; } - if (mResourceInt8->mActBits == 4) { + if (mResourceInt8->mWeightBits == 4) { mGemmKernel = mRelatedFunctions.MNNGemmInt8AddBiasScale_w4_Unit_FP16; if (mOnlineReorderWeightSme && planeSize == 1) { mGemmKernel = mRelatedFunctions.MNNGemmInt8AddBiasScale_w4_Unit_FP16_DecodeMax; @@ -768,7 +928,46 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input } // A axisSum kernel - mSumByAxisLFunc = mRelatedFunctions.MNNSumByAxisLForMatmul_A; + } else { // use sme and neon gemmInt8 + // Fp32 + if (planeSize == 1) { // Decode + mGemmKernels.push_back(mRelatedFunctions.MNNGemmInt8AddBiasScale_Unit_FP32_DecodeMax); + mGemmKernels.push_back(mArm82Functions.Int8GemmKernel); + if (mResourceInt8->mWeightBits == 4) { + mGemmKernels[0] = mRelatedFunctions.MNNGemmInt8AddBiasScale_w4_Unit_FP32_DecodeMax; + mGemmKernels[1] = mArm82Functions.Int8GemmKernel_W4; + } + } else { // Prefill + mGemmKernels.push_back(mRelatedFunctions.Int8GemmKernel); + mGemmKernels.push_back(mArm82Functions.Int8GemmKernel); + if (mResourceInt8->mWeightBits == 4) { + mGemmKernels[0] = mRelatedFunctions.Int8GemmKernel_W4; + mGemmKernels[1] = mArm82Functions.Int8GemmKernel_W4; + } + } + mQuantFunc = core->MNNFloat2Int8; + + // fp16 + if (gcore->bytes == 2 && gcore->pack == 8) { + if (planeSize == 1) { // Decode + mGemmKernels[0] = 
mRelatedFunctions.MNNGemmInt8AddBiasScale_Unit_FP16_DecodeMax; + mGemmKernels[1] = mArm82Functions.MNNGemmInt8AddBiasScale_Unit_FP16; + if (mResourceInt8->mWeightBits == 4) { + mGemmKernels[0] = mRelatedFunctions.MNNGemmInt8AddBiasScale_w4_Unit_FP16_DecodeMax; + mGemmKernels[1] = mArm82Functions.MNNGemmInt8AddBiasScale_w4_Unit_FP16; + } + } else { // Prefill + mGemmKernels[0] = mRelatedFunctions.MNNGemmInt8AddBiasScale_Unit_FP16; + mGemmKernels[1] = mArm82Functions.MNNGemmInt8AddBiasScale_Unit_FP16; + if (mResourceInt8->mWeightBits == 4) { + mGemmKernels[0] = mRelatedFunctions.MNNGemmInt8AddBiasScale_w4_Unit_FP16; + mGemmKernels[1] = mArm82Functions.MNNGemmInt8AddBiasScale_w4_Unit_FP16; + } + } + mQuantFunc = core->DynamicQuanInput_ARM82; + mQuantAndReorderFunc = core->DynamicQuanInputAndReorder_ARM82; + } + // A axisSum kernel } mInputBlockNum = (inputBlockQuantOption == 2) ? mBlockNum : 1; @@ -835,16 +1034,30 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input return OUT_OF_MEMORY; } if (mOnlineReorderWeightSme && planeSize > 1) { // only prefill need - int weightlenNew = ROUND_UP(outC, SME_DECODE_MAXHP) * mBlockNum * ROUND_UP(ic / mBlockNum, SRC_UNIT) * kernelCount; - if (mResourceInt8->mActBits == 4) { - weightlenNew /= 2; + int ocProcessedBySme = mOcMain; + int ocProcessedByNeon = 0; + if (mMixedKernel && mRatioDecode != mRatioPrefill) { + auto workUnit = UP_DIV(outC4, mRatioPrefill * mSmeCores + 1 * (threads - mSmeCores)); + ocProcessedBySme = ALIMIN(ROUND_UP(workUnit * pack * mSmeCores * mRatioPrefill, GEMM_INT8_UNIT_SME2_128), outC); + ocProcessedBySme = ALIMAX(ocProcessedBySme, mOcMain); + ocProcessedByNeon = outC - ocProcessedBySme; + } + int weightlenSme = ROUND_UP(ocProcessedBySme, GEMM_INT8_UNIT_SME2_128) * mBlockNum * ROUND_UP(ic / mBlockNum, SRC_UNIT) * kernelCount; + int weightlenNeon = ROUND_UP(ocProcessedByNeon, 8) * mBlockNum * ROUND_UP(ic / mBlockNum, SRC_UNIT) * kernelCount; + if (mResourceInt8->mWeightBits == 4) { + 
weightlenSme /= 2; + weightlenNeon /= 2; } - mWeight4Prefill = bufferAlloc->alloc(weightlenNew + 2 * mBlockNum * ROUND_UP(outC, SME_DECODE_MAXHP) * QUANT_INFO_BYTES); + int scalebiasLenSme = 2 * mBlockNum * ROUND_UP(ocProcessedBySme, GEMM_INT8_UNIT_SME2_128) * QUANT_INFO_BYTES; + int scalebiasLenNeon = 2 * mBlockNum * ROUND_UP(ocProcessedByNeon, 8) * QUANT_INFO_BYTES; + + + mWeight4Prefill = bufferAlloc->alloc(weightlenSme + scalebiasLenSme + weightlenNeon + scalebiasLenNeon); if (mWeight4Prefill.invalid()) { return OUT_OF_MEMORY; } if (mInputBlockNum > 1) { // only in this case, need to use weight_kernel_sum - mWeightKernelSum4Prefill = bufferAlloc->alloc(ROUND_UP(outC, SME_DECODE_MAXHP) * mBlockNum * sizeof(float)); + mWeightKernelSum4Prefill = bufferAlloc->alloc(ROUND_UP(outC, GEMM_INT8_UNIT_SME2_128) * mBlockNum * sizeof(float)); if (mWeightKernelSum4Prefill.invalid()) { return OUT_OF_MEMORY; } @@ -899,18 +1112,13 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input if (mTempOutput.invalid()) { return OUT_OF_MEMORY; } + bufferAlloc->free(mTempOutput); } backend()->onReleaseBuffer(mTempIm2ColBuffer.get(), Backend::DYNAMIC); + backend()->onReleaseBuffer(mBatchQuantInfo.get(), Backend::DYNAMIC); backend()->onReleaseBuffer(mQuantInput.get(), Backend::DYNAMIC); backend()->onReleaseBuffer(mAccumBuffer.get(), Backend::DYNAMIC); - if (mBatchQuantInfo.get()) { - backend()->onReleaseBuffer(mBatchQuantInfo.get(), Backend::DYNAMIC); - } - - if (m4BitPtq) { - bufferAlloc->free(mTempOutput); - } return NO_ERROR; #else @@ -919,9 +1127,8 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input } -static void _onlineReorderWeight(int8_t* dst, int8_t* src, int hPSrc, int hPDst, int hU, int blockNum, int blockLu, int lp, bool int4weight) { - // The core assumption for this function's logic - // assert(hPSrc == hPDst * 4); +static void _onlineReorderWeightPackH128ToH32(int8_t* dst, int8_t* src, int hPSrc, int hPDst, int hU, int blockNum, int 
blockLu, int lp, bool int4weight) { + // hPSrc = 4 * hPDst int unitsize_ = hPDst * lp; if (int4weight) { @@ -1024,7 +1231,247 @@ static void _onlineReorderWeight(int8_t* dst, int8_t* src, int hPSrc, int hPDst, } } -static void _onlineReorderWeightKernelSum(float* dst, float* src, int blockNum, int hpSrc, int hpDst, int oc) { +static void _onlineReorderWeightPackH8ToH32(int8_t* dst, const int8_t* src, int blockLu, int lp, bool isInt4Weight, int srcH, int blockNum, int resOcBranch) { + constexpr int hPSrc = 8; + constexpr int hPDst = 32; + + int srcUnitLp = isInt4Weight ? lp / 2 : lp; + + const size_t srcUnitSize = (size_t)hPSrc * srcUnitLp; + const size_t dstUnitSize = (size_t)hPDst * srcUnitLp; + + const size_t srcStride1 = (size_t)blockLu * srcUnitSize + 2 * hPSrc * sizeof(float); + const size_t srcStride0 = (size_t)blockNum * srcStride1; + const size_t dstStride1 = (size_t)blockLu * dstUnitSize + 2 * hPDst * sizeof(float); + const size_t dstStride0 = (size_t)blockNum * dstStride1; + + const int hUDst = srcH / 4; + const int hTail = srcH % 4; + + for (int i = 0; i < hUDst; ++i) { + for (int k = 0; k < blockNum; ++k) { + auto weightSrcBase0 = src + (4 * i + 0) * srcStride0 + k * srcStride1; + auto weightSrcBase1 = src + (4 * i + 1) * srcStride0 + k * srcStride1; + auto weightSrcBase2 = src + (4 * i + 2) * srcStride0 + k * srcStride1; + auto weightSrcBase3 = src + (4 * i + 3) * srcStride0 + k * srcStride1; + auto weightDstBase = dst + i * dstStride0 + k * dstStride1; + + int lu = blockLu; + + // --- Reorder Weights --- + if (isInt4Weight) { + auto process_int4_block = [](uint8_t* dst_b, const uint8_t* src_b, size_t size) { + auto half_size = size / 2; + for (int s = 0; s < half_size; ++s) { + uint8_t p0 = src_b[2 * s]; + uint8_t p1 = src_b[2 * s + 1]; + dst_b[s] = (p1 & 0xF0) | (p0 >> 4); + dst_b[s + half_size] = (p1 << 4) | (p0 & 0x0F); + } + }; + while (lu >= 4) { + for (int j = 0; j < 4; ++j) { + const auto* srcPtr0 = (const uint8_t*)(weightSrcBase0 + j * 
srcUnitSize); + const auto* srcPtr1 = (const uint8_t*)(weightSrcBase1 + j * srcUnitSize); + const auto* srcPtr2 = (const uint8_t*)(weightSrcBase2 + j * srcUnitSize); + const auto* srcPtr3 = (const uint8_t*)(weightSrcBase3 + j * srcUnitSize); + auto* dstPtr = (uint8_t*)(weightDstBase + j * dstUnitSize); + + process_int4_block(dstPtr + 0 * srcUnitSize, srcPtr0, srcUnitSize); + process_int4_block(dstPtr + 1 * srcUnitSize, srcPtr1, srcUnitSize); + process_int4_block(dstPtr + 2 * srcUnitSize, srcPtr2, srcUnitSize); + process_int4_block(dstPtr + 3 * srcUnitSize, srcPtr3, srcUnitSize); + } + + weightSrcBase0 += 4 * srcUnitSize; + weightSrcBase1 += 4 * srcUnitSize; + weightSrcBase2 += 4 * srcUnitSize; + weightSrcBase3 += 4 * srcUnitSize; + weightDstBase += 4 * dstUnitSize; + lu -= 4; + } + + for (int j = 0; j < lu; ++j) { + const auto* srcPtr0 = (const uint8_t*)(weightSrcBase0); + const auto* srcPtr1 = (const uint8_t*)(weightSrcBase1); + const auto* srcPtr2 = (const uint8_t*)(weightSrcBase2); + const auto* srcPtr3 = (const uint8_t*)(weightSrcBase3); + auto* dstPtr = (uint8_t*)(weightDstBase); + + process_int4_block(dstPtr + 0 * srcUnitSize, srcPtr0, srcUnitSize); + process_int4_block(dstPtr + 1 * srcUnitSize, srcPtr1, srcUnitSize); + process_int4_block(dstPtr + 2 * srcUnitSize, srcPtr2, srcUnitSize); + process_int4_block(dstPtr + 3 * srcUnitSize, srcPtr3, srcUnitSize); + + weightSrcBase0 += srcUnitSize; + weightSrcBase1 += srcUnitSize; + weightSrcBase2 += srcUnitSize; + weightSrcBase3 += srcUnitSize; + weightDstBase += dstUnitSize; + } + } else { + while (lu >= 4) { + // j = 0 + memcpy(weightDstBase + 0 * dstUnitSize + 0 * srcUnitSize, weightSrcBase0 + 0 * srcUnitSize, srcUnitSize); + memcpy(weightDstBase + 0 * dstUnitSize + 1 * srcUnitSize, weightSrcBase1 + 0 * srcUnitSize, srcUnitSize); + memcpy(weightDstBase + 0 * dstUnitSize + 2 * srcUnitSize, weightSrcBase2 + 0 * srcUnitSize, srcUnitSize); + memcpy(weightDstBase + 0 * dstUnitSize + 3 * srcUnitSize, weightSrcBase3 + 0 
* srcUnitSize, srcUnitSize); + // j = 1 + memcpy(weightDstBase + 1 * dstUnitSize + 0 * srcUnitSize, weightSrcBase0 + 1 * srcUnitSize, srcUnitSize); + memcpy(weightDstBase + 1 * dstUnitSize + 1 * srcUnitSize, weightSrcBase1 + 1 * srcUnitSize, srcUnitSize); + memcpy(weightDstBase + 1 * dstUnitSize + 2 * srcUnitSize, weightSrcBase2 + 1 * srcUnitSize, srcUnitSize); + memcpy(weightDstBase + 1 * dstUnitSize + 3 * srcUnitSize, weightSrcBase3 + 1 * srcUnitSize, srcUnitSize); + // j = 2 + memcpy(weightDstBase + 2 * dstUnitSize + 0 * srcUnitSize, weightSrcBase0 + 2 * srcUnitSize, srcUnitSize); + memcpy(weightDstBase + 2 * dstUnitSize + 1 * srcUnitSize, weightSrcBase1 + 2 * srcUnitSize, srcUnitSize); + memcpy(weightDstBase + 2 * dstUnitSize + 2 * srcUnitSize, weightSrcBase2 + 2 * srcUnitSize, srcUnitSize); + memcpy(weightDstBase + 2 * dstUnitSize + 3 * srcUnitSize, weightSrcBase3 + 2 * srcUnitSize, srcUnitSize); + // j = 3 + memcpy(weightDstBase + 3 * dstUnitSize + 0 * srcUnitSize, weightSrcBase0 + 3 * srcUnitSize, srcUnitSize); + memcpy(weightDstBase + 3 * dstUnitSize + 1 * srcUnitSize, weightSrcBase1 + 3 * srcUnitSize, srcUnitSize); + memcpy(weightDstBase + 3 * dstUnitSize + 2 * srcUnitSize, weightSrcBase2 + 3 * srcUnitSize, srcUnitSize); + memcpy(weightDstBase + 3 * dstUnitSize + 3 * srcUnitSize, weightSrcBase3 + 3 * srcUnitSize, srcUnitSize); + + weightSrcBase0 += 4 * srcUnitSize; + weightSrcBase1 += 4 * srcUnitSize; + weightSrcBase2 += 4 * srcUnitSize; + weightSrcBase3 += 4 * srcUnitSize; + weightDstBase += 4 * dstUnitSize; + lu -= 4; + } + + for (int j = 0; j < lu; ++j) { + memcpy(weightDstBase + 0 * srcUnitSize, weightSrcBase0, srcUnitSize); + memcpy(weightDstBase + 1 * srcUnitSize, weightSrcBase1, srcUnitSize); + memcpy(weightDstBase + 2 * srcUnitSize, weightSrcBase2, srcUnitSize); + memcpy(weightDstBase + 3 * srcUnitSize, weightSrcBase3, srcUnitSize); + + weightSrcBase0 += srcUnitSize; + weightSrcBase1 += srcUnitSize; + weightSrcBase2 += srcUnitSize; + weightSrcBase3 
+= srcUnitSize; + weightDstBase += dstUnitSize; + } + } + + // --- Reorder scale and bias --- + const int scaleSrcSize = hPSrc * sizeof(float); + const int8_t* scaleSrcBase = src + (4 * i) * srcStride0 + k * srcStride1 + (size_t)blockLu * srcUnitSize; + int8_t* scaleDstBase = dst + i * dstStride0 + k * dstStride1 + (size_t)blockLu * dstUnitSize; + + memcpy(scaleDstBase + 0 * scaleSrcSize, scaleSrcBase + 0 * srcStride0, scaleSrcSize); + memcpy(scaleDstBase + 1 * scaleSrcSize, scaleSrcBase + 1 * srcStride0, scaleSrcSize); + memcpy(scaleDstBase + 2 * scaleSrcSize, scaleSrcBase + 2 * srcStride0, scaleSrcSize); + memcpy(scaleDstBase + 3 * scaleSrcSize, scaleSrcBase + 3 * srcStride0, scaleSrcSize); + + const int8_t* biasSrcBase = scaleSrcBase + scaleSrcSize; + int8_t* biasDstBase = scaleDstBase + hPDst * sizeof(float); + + memcpy(biasDstBase + 0 * scaleSrcSize, biasSrcBase + 0 * srcStride0, scaleSrcSize); + memcpy(biasDstBase + 1 * scaleSrcSize, biasSrcBase + 1 * srcStride0, scaleSrcSize); + memcpy(biasDstBase + 2 * scaleSrcSize, biasSrcBase + 2 * srcStride0, scaleSrcSize); + memcpy(biasDstBase + 3 * scaleSrcSize, biasSrcBase + 3 * srcStride0, scaleSrcSize); + } + } + + // --- 2. Process the tail --- + if (hTail > 0) { + // The last block starts at index hUDst. 
+ const int i = hUDst; + for (int k = 0; k < blockNum; ++k) { + const int8_t* srcBases[4] = {nullptr, nullptr, nullptr, nullptr}; + for(int j = 0; j < hTail; ++j) { + srcBases[j] = src + (4 * i + j) * srcStride0 + k * srcStride1; + } + + auto weightDstBase = dst + i * dstStride0 + k * dstStride1; + + int lu = blockLu; + + if (isInt4Weight) { + auto process_int4_block = [](uint8_t* dst_b, const uint8_t* src_b, size_t size) { + auto half_size = size / 2; + for (int s = 0; s < half_size; ++s) { + uint8_t p0 = src_b[2 * s]; + uint8_t p1 = src_b[2 * s + 1]; + dst_b[s] = (p1 & 0xF0) | (p0 >> 4); + dst_b[s + half_size] = (p1 << 4) | (p0 & 0x0F); + } + }; + while (lu --> 0) { + for (int j = 0; j < hTail; ++j) { + process_int4_block( + (uint8_t*)(weightDstBase + j * srcUnitSize), + (const uint8_t*)(srcBases[j]), + srcUnitSize + ); + } + // For the remaining part of the destination block, set 0 + + if (hTail < 4) { + memset(weightDstBase + hTail * srcUnitSize, 0, (4 - hTail) * srcUnitSize); + } + + for(int j=0; j 0) { + for (int j = 0; j < hTail; ++j) { + memcpy(weightDstBase + j * srcUnitSize, srcBases[j], srcUnitSize); + } + // Zero out the rest of the destination block + if (hTail < 4) { + memset(weightDstBase + hTail * srcUnitSize, 0, (4 - hTail) * srcUnitSize); + } + + for(int j=0; j 0) { + size_t resLp = isInt4Weight ? lp / 2 : lp; + size_t resChannels = ROUND_UP(resOcBranch, hPSrc); + size_t resDataLen = (size_t)blockNum * ((size_t)blockLu * resChannels * resLp + 2 * resChannels * sizeof(float)); + + // The source for residual data starts after ALL processed srcH blocks. + memcpy(dst + (size_t)hUDst * dstStride0 + (hTail > 0 ? 
dstStride0 : 0), + src + (size_t)srcH * srcStride0, + resDataLen); + } +} + +static void _onlineReorderWeightKernelSumH128ToH32(float* dst, float* src, int blockNum, int hpSrc, int hpDst, int oc) { // hpSrc = 4 * hpDst // src shape: [huSrc, blockNum, hpSrc] // dst shape: [huDst, blockNum, hpDst], where huDst = huSrc * 4 @@ -1050,6 +1497,37 @@ static void _onlineReorderWeightKernelSum(float* dst, float* src, int blockNum, } } +static void _onlineReorderWeightKernelSumH8ToH32(float* dst, float* src, int blockNum, int hpSrc, int hpDst, int ocNeedReorder, int ocPreserve) { + // hpDst = 4 * hpSrc + // src shape: [huSrc, blockNum, hpSrc], where huSrc = huDst * 4 + // dst shape: [huDst, blockNum, hpDst] + + auto huDst = UP_DIV(ocNeedReorder, hpDst); + + auto strideSrc = blockNum * hpSrc; + auto strideDst = blockNum * hpDst; + + for (int i = 0; i < huDst; ++i) { + for (int k = 0; k < blockNum; ++k) { + auto dstBase = dst + i * strideDst + k * hpDst; + + auto src0 = src + (4 * i + 0) * strideSrc + k * hpSrc; + auto src1 = src + (4 * i + 1) * strideSrc + k * hpSrc; + auto src2 = src + (4 * i + 2) * strideSrc + k * hpSrc; + auto src3 = src + (4 * i + 3) * strideSrc + k * hpSrc; + + memcpy(dstBase, src0, hpSrc * sizeof(float)); + memcpy(dstBase + hpSrc, src1, hpSrc * sizeof(float)); + memcpy(dstBase + 2 * hpSrc, src2, hpSrc * sizeof(float)); + memcpy(dstBase + 3 * hpSrc, src3, hpSrc * sizeof(float)); + } + } + + if (ocPreserve) { + memcpy(dst + huDst * strideDst, src + 4 * huDst * strideSrc, ROUND_UP(ocPreserve, hpSrc) * blockNum * sizeof(float)); + } +} + ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inputs, const std::vector& outputs) { const auto input = inputs[0]; auto output = outputs[0]; @@ -1115,6 +1593,10 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu } } + // Declare variables used in dynamic quantization + const int threads = static_cast(backend())->threadNumber(); + int dropBranch = 0; + #ifdef MNN_LOW_MEMORY auto 
BatchAsyDynamicQuant = [&](uint8_t* floatPtr, int32_t& inputZero, uint8_t* inputDequantScale, int LDiv4, int eCount, int innerSide, int32_t availableThreads, int8_t* dstInt8, uint8_t* inputDequantBias, int tId) { // if mIm2ColBasedInt8=false, input shape: [kernelsize,mBlockNum,blocklu,EP,LP] @@ -1154,7 +1636,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu } if (mToFuseInputbias2Bias) { // Decode - inputZero = qbias[0]; + inputZero = roundf(qbias[0]); auto updatedBiasPtr = (float*)(mBiasBufferFusedInputzero.ptr() + tId * ocUpHp * QUANT_INFO_BYTES); auto matmulBiasPtr = mResourceInt8->mOriginBias->host(); auto weightKernelSum = mResourceInt8->mWeightKernelSum->host(); @@ -1212,14 +1694,55 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu if (mOnlineReorderWeightSme && plane > 1) { - _onlineReorderWeight((int8_t*)mWeight4Prefill.ptr(), weightDataPtr, SME_DECODE_MAXHP, UNIT, UP_DIV(oc, SME_DECODE_MAXHP), mBlockNum, blockL, SRC_UNIT, mResourceInt8->mActBits == 4); - weightDataPtr = (int8_t*)mWeight4Prefill.ptr(); + _onlineReorderWeightPackH128ToH32((int8_t*)mWeight4Prefill.ptr(), weightDataPtr, GEMM_INT8_UNIT_SME2_128, UNIT, UP_DIV(mOcMain, GEMM_INT8_UNIT_SME2_128), mBlockNum, blockL, SRC_UNIT, mResourceInt8->mWeightBits == 4); + + int kernelSumMainSize = 0; + int kernelSumBranchSize = 0; if (dstBytes > 1 && mInputBlockNum > 1) { - _onlineReorderWeightKernelSum((float*)mWeightKernelSum4Prefill.ptr(), mResourceInt8->mWeightKernelSum->host(), mBlockNum, SME_DECODE_MAXHP, UNIT, oc); + _onlineReorderWeightKernelSumH128ToH32((float*)mWeightKernelSum4Prefill.ptr(), mResourceInt8->mWeightKernelSum->host(), mBlockNum, GEMM_INT8_UNIT_SME2_128, UNIT, mOcMain); + kernelSumMainSize = ROUND_UP(mOcMain, UNIT) * mBlockNum * QUANT_INFO_BYTES; + kernelSumBranchSize = ROUND_UP(mOcBranch, 8) * mBlockNum * QUANT_INFO_BYTES; + } + + // If change the workload distribution among SME and NEON cores. 
+ if (mMixedKernel && mRatioDecode != mRatioPrefill) { + auto offsetWeight = UP_DIV(mOcMain, GEMM_INT8_UNIT_SME2_128) * mBlockNum * blockL * SRC_UNIT * GEMM_INT8_UNIT_SME2_128; + if (mResourceInt8->mWeightBits == 4) { + offsetWeight /= 2; + } + offsetWeight += (ROUND_UP(mOcMain, GEMM_INT8_UNIT_SME2_128) * mBlockNum * 2 * sizeof(float)); + + // Don't change mOcMain&mOcBranch here. + int tmpMain = mOcMain; + int tmpBranch = mOcBranch; + calculateSmeNeonWorkDivision(tmpMain, tmpBranch, mDividesTmp, oc, threads, PackUnit, plane, mRatioPrefill, mSmeCores); + auto updatedSmeWork = mDividesTmp[mSmeCores]; + + + if (updatedSmeWork - mOriginSmeWork > 0 && ((updatedSmeWork - mOriginSmeWork) * 4 % 8 == 0)) { // To ensure pack=4, dropBranch % 2 == 0 + dropBranch = updatedSmeWork - mOriginSmeWork; // Ensure update "dropBranch" inner the loop. + memcpy(mDivides.data(), mDividesTmp.data(), (threads+1) * sizeof(float)); + dropBranch = mDivides[mSmeCores] - mOriginSmeWork; + _onlineReorderWeightPackH8ToH32((int8_t*)(mWeight4Prefill.ptr() + offsetWeight), weightDataPtr + offsetWeight, blockL, SRC_UNIT, mResourceInt8->mWeightBits == 4, (int)(dropBranch * PackUnit / 8), mBlockNum, (mDivides[threads] - mDivides[mSmeCores]) * PackUnit); + } + + if (dstBytes > 1 && mInputBlockNum > 1) { + if (dropBranch > 0) { + // reorder + _onlineReorderWeightKernelSumH8ToH32((float*)(mWeightKernelSum4Prefill.ptr() + kernelSumMainSize), (float*)(mResourceInt8->mWeightKernelSum->host() + kernelSumMainSize), mBlockNum, 8, UNIT, dropBranch * PackUnit, (mDivides[threads] - mDivides[mSmeCores]) * PackUnit); + } + } + } + + if (dropBranch == 0) { // If dropBranch == 0, it means that the arrangement of the weights processed by the Arm82 architecture remains unchanged. 
+ // copy + memcpy(mWeightKernelSum4Prefill.ptr() + kernelSumMainSize, mResourceInt8->mWeightKernelSum->host() + kernelSumMainSize, kernelSumBranchSize); } + + weightDataPtr = (int8_t*)mWeight4Prefill.ptr(); } #endif - if (mResourceInt8->mActBits == 4) { + if (mResourceInt8->mWeightBits == 4) { weightBytes = 0.5; weightStepY /= 2; } @@ -1270,6 +1793,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu QuanPostTreatParameters quanParam; quanParam.blockNum = mBlockNum; int32_t indices[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + quanParam.indices = indices; if (dstBytes != 1) { quanParam.useInt8 = 0; quanParam.fp32minmax = reluPtr; @@ -1286,7 +1810,6 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu quanParam.minValue = mMutableResource->mClampMin; } } - quanParam.indices = indices; auto weightPtrTid = weightDataPtr; quanParam.weightKernelSum = ptrY; quanParam.biasFloat = reinterpret_cast(biasPtr); @@ -1487,17 +2010,25 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu int ocIndex = PackUnit * mDivides[tId]; auto ocDivThread = ALIMIN(mDivides[tId + 1] - mDivides[tId], ocDiv4 - mDivides[tId]); - if (ocIndex < ocUp4) { + if (ocIndex < ocUp4 && ocDivThread > 0) { + decltype(mGemmKernel) gemmInt8; + if (mMixedKernel) { + gemmInt8 = tId < mSmeCores ? mGemmKernels[0] : mGemmKernels[1]; + } else { + gemmInt8 = mGemmKernel; + } auto im2colDstThread = im2colDst; float* ptrY = nullptr; if (dstBytes != 1) { - ptrY = mResourceInt8->mWeightKernelSum->host() + (ocIndex / UNIT) * UNIT * mInputBlockNum; + float* wkernelSum = (mOnlineReorderWeightSme && mInputBlockNum > 1 && plane > 1) ? 
(float*)mWeightKernelSum4Prefill.ptr() : mResourceInt8->mWeightKernelSum->host(); + ptrY = wkernelSum + ocIndex * mInputBlockNum; } QuanPostTreatParameters quanParam; quanParam.blockNum = mBlockNum; quanParam.weightKernelSum = ptrY; quanParam.biasFloat = reinterpret_cast(biasPtr + ocIndex * 4); int32_t indices[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + quanParam.indices = indices; if (dstBytes != 1) { quanParam.useInt8 = 0; quanParam.fp32minmax = reluPtr; @@ -1514,7 +2045,6 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu quanParam.minValue = mMutableResource->mClampMin; } } - quanParam.indices = indices; uint8_t* inputScale = nullptr; // input scale for batch dynamic quant. uint8_t* inputBias = nullptr; float* accumbuff = nullptr; @@ -1531,7 +2061,14 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu } auto outputInTilePtr = outputDataPtr + ocIndex * plane * dstBytes; - const auto weightPtrTid = weightDataPtr + static_cast(ocIndex * mBlockNum * blockL * SRC_UNIT * weightBytes + ocIndex * 2 * mBlockNum * QUANT_INFO_BYTES); + + auto weightSrc = weightDataPtr; + if (tId >= mSmeCores && dropBranch == 0 && mMixedKernel) { + weightSrc = mResourceInt8->mWeightInt8->host(); + } + + auto weightPtrTid = weightSrc + static_cast(ocIndex * mBlockNum * blockL * SRC_UNIT * weightBytes + ocIndex * 2 * mBlockNum * QUANT_INFO_BYTES); + int realDstCount = plane; auto ptrX = xKernelSumPtr; do { @@ -1543,7 +2080,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu memset(accumbuff, 0, UNIT * 4 * DST_XUNIT); quanParam.accumBuffer = accumbuff; } - mGemmKernel(outputInTilePtr, im2colDstThread, weightPtrTid, blockL, dstZStep * dstBytes, ocDivThread, &quanParam, step); + gemmInt8(outputInTilePtr, im2colDstThread, weightPtrTid, blockL, dstZStep * dstBytes, ocDivThread, &quanParam, step); ptrX += (step * mBlockNum); realDstCount-=step; outputInTilePtr += DST_XUNIT * PackUnit * dstBytes; @@ -1556,7 
+2093,6 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu MNN_CONCURRENCY_END(); }; - const int threads = static_cast(backend())->threadNumber(); if (!mSplitByOc) { MNN_CONCURRENCY_BEGIN(tId, threads) { if (mDivides[tId + 1] - mDivides[tId] > 0) { diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp index 0bbe951e19..55890dd2ae 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp @@ -76,22 +76,32 @@ class DenseConvInt8TiledExecutor : public ConvInt8TiledExecutor { // for 4Bit Ptq model MemChunk mTempOutput; std::vector mDivides; + std::vector mDividesTmp; + std::vector mGemmKernels; int mGemmUnits[3]; int mThreadNums; int mBlockNum = 1; int mInputBlockNum = 1; int mOcPerThread; + int mOcMain; + int mOcBranch = 0; + int mRatioPrefill; + int mRatioDecode; + int mSmeCores = 2; + int mOriginSmeWork = 0; + int mSizeInputBlockQuant; bool mSplitByOc; bool mUseBatchQuan; bool mIm2ColBasedInt8; - int mSizeInputBlockQuant; bool mToFuseInputbias2Bias; bool mOnlineReorderWeightSme = false; // for 4Bit Ptq model bool m4BitPtq = false; + bool mMixedKernel; MatmulRelatedFunctions mRelatedFunctions; + MatmulRelatedFunctions mArm82Functions; }; } // namespace MNN diff --git a/source/backend/cpu/compute/ConvolutionTiledExecutor.cpp b/source/backend/cpu/compute/ConvolutionTiledExecutor.cpp index 5a95dba8ba..2c200e42f1 100644 --- a/source/backend/cpu/compute/ConvolutionTiledExecutor.cpp +++ b/source/backend/cpu/compute/ConvolutionTiledExecutor.cpp @@ -132,7 +132,7 @@ void ConvolutionTiledExecutor:: setIm2ColParameter(ConvolutionCommon::Im2ColPara const auto srcCountUnit = UP_DIV(input->channel(), SRC_UNIT); dstIm2ColParamter.kernelCountUnit = srcCountUnit * kernelCount; dstIm2ColParamter.ic = srcCountUnit * SRC_UNIT; - + if (SRC_UNIT > pack) { // Carefully change it. 
dstIm2ColParamter.icup4 = ROUND_UP(input->channel(), pack); } else { diff --git a/source/backend/cpu/compute/Int8FunctionsOpt.cpp b/source/backend/cpu/compute/Int8FunctionsOpt.cpp index e418b9ad4d..9d37a47513 100644 --- a/source/backend/cpu/compute/Int8FunctionsOpt.cpp +++ b/source/backend/cpu/compute/Int8FunctionsOpt.cpp @@ -1432,7 +1432,7 @@ static void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, co float fp32min = 0, fp32max = 0; int weight_step_Z = src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) + 4 * 2 * GEMM_INT8_UNIT; int weight_step_Y = (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); - + if (0 == post->useInt8 && post->fp32minmax) { fp32min = (post->fp32minmax)[0]; fp32max = (post->fp32minmax)[1]; @@ -1441,7 +1441,7 @@ static void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, co float* biasPtr = (float*)post->biasFloat; auto accumbuff = post->accumBuffer; auto blockNum = post->blockNum; - + for (int dz = 0; dz < dst_depth_quad; ++dz) { auto dst_z = dst + dz * dst_step; auto accum_z = accumbuff; @@ -1453,7 +1453,9 @@ static void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, co const auto bias_dz = biasPtr + dz * GEMM_INT8_UNIT; const auto srcSumPtr = post->srcKernelSum + bk * realCount; - + + const auto inputScalePtr = post->inputBias ? 
post->inputScale + bk * realCount : post->inputScale; + for (int w = 0; w < realCount; ++w) { const auto src_x = src + bk * src_depth_quad * GEMM_INT8_SRC_UNIT * realCount + w * GEMM_INT8_SRC_UNIT; auto dst_x = dst_z + w * GEMM_INT8_UNIT * bytes; @@ -1475,7 +1477,7 @@ static void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, co for (int j = 0; j < GEMM_INT8_UNIT; ++j) { float value = dstTemp[j] * scale_dz[j] + srcSumPtr[w] * weightBias_dz[j]; if (post->inputScale) { - value = dstTemp[j] * scale_dz[j] * (post->inputScale + bk * realCount)[w] + srcSumPtr[w] * weightBias_dz[j]; + value = dstTemp[j] * scale_dz[j] * inputScalePtr[w] + srcSumPtr[w] * weightBias_dz[j]; } if (post->inputBias) { auto weightKernelSum = post->weightKernelSum + dz * (blockNum * GEMM_INT8_UNIT) + bk * GEMM_INT8_UNIT; @@ -1493,11 +1495,11 @@ static void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, co if (post->fp32minmax) { value = std::min(std::max(fp32min, value), fp32max); } - ((float*)dst_x)[j] = value; + ((float*)dst_x)[j] = value; } else { ((float*)accum_x)[j] = value; } - + } else { value += bias_dz[j]; value = ALIMAX(value, post->minValue); @@ -1535,8 +1537,9 @@ static void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src, const auto weightBias_dz = scale_dz + GEMM_INT8_UNIT; const auto bias_dz = biasPtr + dz * GEMM_INT8_UNIT; const auto srcSumPtr = post->srcKernelSum + bk * realCount; + const auto inputScalePtr = post->inputBias ? 
post->inputScale + bk * realCount : post->inputScale; for (int w = 0; w < realCount; ++w) { - const auto src_x = src + w * GEMM_INT8_SRC_UNIT; + const auto src_x = src + bk * src_depth_quad * GEMM_INT8_SRC_UNIT * realCount + w * GEMM_INT8_SRC_UNIT; auto dst_x = dst_z + w * GEMM_INT8_UNIT * bytes; auto accum_x = accum_z + w * GEMM_INT8_UNIT; int32_t dstTemp[4] = {0, 0, 0, 0}; @@ -1562,7 +1565,7 @@ static void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src, for (int j = 0; j < GEMM_INT8_UNIT; ++j) { float value = dstTemp[j] * scale_dz[j] + srcSumPtr[w] * weightBias_dz[j]; if (post->inputScale) { - value = dstTemp[j] * scale_dz[j] * (post->inputScale + bk * realCount)[w] + srcSumPtr[w] * weightBias_dz[j]; + value = dstTemp[j] * scale_dz[j] * inputScalePtr[w] + srcSumPtr[w] * weightBias_dz[j]; } if (post->inputBias) { auto weightKernelSum = post->weightKernelSum + dz * (blockNum * GEMM_INT8_UNIT) + bk * GEMM_INT8_UNIT; @@ -1579,8 +1582,8 @@ static void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src, } if (post->fp32minmax) { value = std::min(std::max(fp32min, value), fp32max); - } - ((float*)dst_x)[j] = value; + } + ((float*)dst_x)[j] = value; } else { ((float*)accum_x)[j] = value; } @@ -1685,7 +1688,7 @@ static void MNNLineDepthWiseInt8AddBiasScaleUnit3x3(int8_t* dst, const int8_t* s void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, ssize_t maxValue, const float* zeroPoint, ssize_t quanParamVec) { // quanParamVec: - // 00: scale is vector + // 01: scale is vector // 10: zero is vector // 11: both are vector float scale4[4] = {scalep[0], scalep[0], scalep[0], scalep[0] }; @@ -1704,7 +1707,7 @@ void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* } for (int i = 0; i < sizeQuad; ++i) { for (int j=0; j<4; ++j) { - int v = (int)roundf(src[4*i+j] * scale4[j]) + zero4[j]; + int v = (int)roundf(src[4*i+j] * scale4[j] + zero4[j]); if (v > maxValue) { v = 
maxValue; } @@ -2111,7 +2114,7 @@ static void _ArmBasicMNNPackC4ForMatMul_A(int8_t* destOrigin, int8_t const** sou int lRemain = l / 4; int lR4 = lR / 4; int lS = LUNIT - lR4; - + if (lastBag && e + eR < EP) { int elast = ALIMAX(eR + e, realDstCount % EP); dest = (int32_t*)(destOrigin + lC * elast * LP + lR + eC * info[2] + eR * LP); @@ -2168,7 +2171,7 @@ static void _ArmBasicMNNPackC4ForMatMul_A(int8_t* destOrigin, int8_t const** sou } source += eReal * step; } - + while (lRemain > 0) { int step = ALIMIN(lRemain, LUNIT); for (int x=0; xInt8GemmKernel = MNNGemmInt8AddBiasScale_16x4_Unit; gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_16x4_Unit_FAST; gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnit; + core->int8MatmulRelatedFunctions.eP = GEMM_INT8_DST_XUNIT; #ifdef MNN_LOW_MEMORY gCoreFunc->Int8GemmKernel_W4 = MNNGemmInt8AddBiasScale_16x4_w4_Unit; #endif @@ -2463,19 +2467,29 @@ void MNNCoreInt8FunctionInit() { gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_ARMV82_Unit; gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_ARMV82_Unit; gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnitSdot; + core->int8MatmulRelatedFunctions.eP = GEMM_INT8_DST_XUNIT_ARM82; // Im2Col - gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A_L4<12, 8>; + gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A_L4; // ConvDepthwise gCoreFunc->ConvDepthwise3x3LineInt8_ARM82 = MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3; core->MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A_ARM82; + + core->arm82MatmulRelatedFunctions.Int8GemmKernel = gCoreFunc->Int8GemmKernel; + core->arm82MatmulRelatedFunctions.Int8GemmKernelFast = gCoreFunc->Int8GemmKernelFast; + core->arm82MatmulRelatedFunctions.MNNGetGemmUnit = gCoreFunc->MNNGetGemmUnit; + core->arm82MatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = gCoreFunc->MNNPackC4Int8ForMatMul_A; + core->arm82MatmulRelatedFunctions.MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A_ARM82; #if defined(MNN_LOW_MEMORY) #ifdef 
MNN_USE_ARMV82 gCoreFunc->DynamicQuanInput_ARM82 = DynamicQuanInput_ARM82; gCoreFunc->MNNGemmInt8AddBiasScale_Unit_FP16 = MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16; gCoreFunc->MNNGemmInt8AddBiasScale_w4_Unit_FP16 = MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16; gCoreFunc->DynamicQuanInputAndReorder_ARM82 = DynamicQuanInputAndReorder_ARM82; + core->arm82MatmulRelatedFunctions.MNNGemmInt8AddBiasScale_Unit_FP16 = gCoreFunc->MNNGemmInt8AddBiasScale_Unit_FP16; + core->arm82MatmulRelatedFunctions.MNNGemmInt8AddBiasScale_w4_Unit_FP16 = gCoreFunc->MNNGemmInt8AddBiasScale_w4_Unit_FP16; #endif gCoreFunc->Int8GemmKernel_W4 = MNNGemmInt8AddBiasScale_ARMV82_w4_Unit; + core->arm82MatmulRelatedFunctions.Int8GemmKernel_W4 = gCoreFunc->Int8GemmKernel_W4; #endif } if (core->supportI8mm) { @@ -2483,18 +2497,19 @@ void MNNCoreInt8FunctionInit() { gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_ARMV86_Unit; gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_ARMV86_Unit; gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnitI8mm; + core->int8MatmulRelatedFunctions.eP = GEMM_INT8_DST_XUNIT_ARM86; core->MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A_ARM86; #if defined(MNN_LOW_MEMORY) gCoreFunc->Int8GemmKernel_W4 = MNNGemmInt8AddBiasScale_ARMV86_w4_Unit; - + #ifdef MNN_USE_ARMV82 gCoreFunc->MNNGemmInt8AddBiasScale_Unit_FP16 = MNNGemmInt8AddBiasScale_ARMV86_Unit_FP16; gCoreFunc->MNNGemmInt8AddBiasScale_w4_Unit_FP16 = MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16; #endif #endif // Im2Col - gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A<10, 8, 8>; + gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A; } #endif // __aarch64__ { @@ -2505,7 +2520,7 @@ void MNNCoreInt8FunctionInit() { core->int8MatmulRelatedFunctions.MNNGemmInt8AddBiasScale_w4_Unit_FP16 = gCoreFunc->MNNGemmInt8AddBiasScale_w4_Unit_FP16; core->int8MatmulRelatedFunctions.MNNGetGemmUnit = gCoreFunc->MNNGetGemmUnit; core->int8MatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = 
gCoreFunc->MNNPackC4Int8ForMatMul_A; - + core->int8MatmulRelatedFunctions.MNNSumByAxisLForMatmul_A = core->MNNSumByAxisLForMatmul_A; } @@ -2519,24 +2534,31 @@ void MNNCoreInt8FunctionInit() { gCoreFunc->MNNGemmInt8AddBiasScale_w4_Unit_FP16 = MNNGemmInt8AddBiasScale16x32_SME2_w4_Fp16; gCoreFunc->MNNGemmInt8AddBiasScale_Unit_FP16 = MNNGemmInt8AddBiasScale16x32_SME2_w8_Fp16; core->MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A_SME2; - gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A<16, 4, 32>; + gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A; gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale16x32_SME2_w8_Fp32; - core->sme2Int8MatmulRelatedFuncionsHp32.MNNGetGemmUnit = MNNGetGemmUnitSme2_HP32; - core->sme2Int8MatmulRelatedFuncionsHp32.Int8GemmKernel_W4 = MNNGemmInt8AddBiasScale16x32_SME2_w4_Fp32; - core->sme2Int8MatmulRelatedFuncionsHp32.Int8GemmKernel = MNNGemmInt8AddBiasScale16x32_SME2_w8_Fp32; - core->sme2Int8MatmulRelatedFuncionsHp32.MNNGemmInt8AddBiasScale_w4_Unit_FP16 = MNNGemmInt8AddBiasScale16x32_SME2_w4_Fp16; - core->sme2Int8MatmulRelatedFuncionsHp32.MNNGemmInt8AddBiasScale_Unit_FP16 = MNNGemmInt8AddBiasScale16x32_SME2_w8_Fp16; - core->sme2Int8MatmulRelatedFuncionsHp32.MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A_SME2; - core->sme2Int8MatmulRelatedFuncionsHp32.MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A<16, 4, 32>; - core->sme2Int8MatmulRelatedFuncionsHp32.Int8GemmKernelFast = MNNGemmInt8AddBiasScale16x32_SME2_w8_Fp32; - core->sme2Int8MatmulRelatedFuncionsHp32.MNNGemmInt8AddBiasScale_w4_Unit_FP16_DecodeMax = MNNGemmInt8AddBiasScaleHp128_SME2_w4_Fp16; - core->sme2Int8MatmulRelatedFuncionsHp32.MNNGemmInt8AddBiasScale_Unit_FP16_DecodeMax = MNNGemmInt8AddBiasScaleHp128_SME2_w8_Fp16; - core->sme2Int8MatmulRelatedFuncionsHp32.MNNGemmInt8AddBiasScale_Unit_FP32_DecodeMax = MNNGemmInt8AddBiasScaleHp128_SME2_w8_Fp32; - 
core->sme2Int8MatmulRelatedFuncionsHp32.MNNGemmInt8AddBiasScale_w4_Unit_FP32_DecodeMax = MNNGemmInt8AddBiasScaleHp128_SME2_w4_Fp32; + // Only Sme2 has + core->int8MatmulRelatedFunctions.MNNGemmInt8AddBiasScale_w4_Unit_FP16_DecodeMax = MNNGemmInt8AddBiasScaleHp128_SME2_w4_Fp16; + core->int8MatmulRelatedFunctions.MNNGemmInt8AddBiasScale_Unit_FP16_DecodeMax = MNNGemmInt8AddBiasScaleHp128_SME2_w8_Fp16; + core->int8MatmulRelatedFunctions.MNNGemmInt8AddBiasScale_Unit_FP32_DecodeMax = MNNGemmInt8AddBiasScaleHp128_SME2_w8_Fp32; + core->int8MatmulRelatedFunctions.MNNGemmInt8AddBiasScale_w4_Unit_FP32_DecodeMax = MNNGemmInt8AddBiasScaleHp128_SME2_w4_Fp32; + core->int8MatmulRelatedFunctions.eP = GEMM_INT8_DST_XUNIT_SME2; } #endif #endif + + { // Update the function pointers in the int8MatmulRelatedFunctions struct. + core->int8MatmulRelatedFunctions.Int8GemmKernel = gCoreFunc->Int8GemmKernel; + core->int8MatmulRelatedFunctions.Int8GemmKernelFast = gCoreFunc->Int8GemmKernelFast; + core->int8MatmulRelatedFunctions.Int8GemmKernel_W4 = gCoreFunc->Int8GemmKernel_W4; + core->int8MatmulRelatedFunctions.MNNGemmInt8AddBiasScale_Unit_FP16 = gCoreFunc->MNNGemmInt8AddBiasScale_Unit_FP16; + core->int8MatmulRelatedFunctions.MNNGemmInt8AddBiasScale_w4_Unit_FP16 = gCoreFunc->MNNGemmInt8AddBiasScale_w4_Unit_FP16; + core->int8MatmulRelatedFunctions.MNNGetGemmUnit = gCoreFunc->MNNGetGemmUnit; + core->int8MatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = gCoreFunc->MNNPackC4Int8ForMatMul_A; + + core->int8MatmulRelatedFunctions.MNNSumByAxisLForMatmul_A = core->MNNSumByAxisLForMatmul_A; + + } MNNInt8FunctionInit(); } CoreInt8Functions* MNNGetInt8CoreFunctions() { diff --git a/source/backend/cpu/compute/Int8FunctionsOpt.h b/source/backend/cpu/compute/Int8FunctionsOpt.h index 9ff7c23137..f20708e754 100644 --- a/source/backend/cpu/compute/Int8FunctionsOpt.h +++ b/source/backend/cpu/compute/Int8FunctionsOpt.h @@ -24,14 +24,30 @@ typedef SSIZE_T ssize_t; #define GEMM_INT8_SRC_UNIT 16 #ifndef MNN_USE_SSE 
#ifdef __aarch64__ - #define GEMM_INT8_DST_XUNIT 4 + #define GEMM_INT8_DST_XUNIT 4 #else - #define GEMM_INT8_DST_XUNIT 2 -#endif + #define GEMM_INT8_DST_XUNIT 2 + #endif #else #define GEMM_INT8_DST_XUNIT 4 #endif +/* CPU supports sdot */ +#define GEMM_INT8_UNIT_ARM82 8 +#define GEMM_INT8_SRC_UNIT_ARM82 4 +#define GEMM_INT8_DST_XUNIT_ARM82 12 + +/* CPU supports i8mm */ +#define GEMM_INT8_UNIT_ARM86 8 +#define GEMM_INT8_SRC_UNIT_ARM86 8 +#define GEMM_INT8_DST_XUNIT_ARM86 10 + +/* CPU supports sme2 */ +#define GEMM_INT8_UNIT_SME2 32 +#define GEMM_INT8_SRC_UNIT_SME2 4 +#define GEMM_INT8_DST_XUNIT_SME2 16 +#define GEMM_INT8_UNIT_SME2_128 128 + #ifdef __cplusplus extern "C" { #endif @@ -86,10 +102,8 @@ struct CoreInt8Functions { void(*Int8GemmKernelFast)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount); void(*MNNGetGemmUnit)(int* UNIT, int* SRC_UNIT, int* DST_XUNIT); void(*MNNPackC4Int8ForMatMul_A)(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, const int32_t* el); - void(*MNNGemmInt8AddBiasScale_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, - const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; - void(*MNNGemmInt8AddBiasScale_w4_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, - const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; + void(*MNNGemmInt8AddBiasScale_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; + void(*MNNGemmInt8AddBiasScale_w4_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const 
QuanPostTreatParameters* post, size_t realDstCount) = nullptr; void(*Int8GemmKernel_W4)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount); // sparse diff --git a/source/backend/cpu/riscv/rvv/MNNMatrixAdd.cpp b/source/backend/cpu/riscv/rvv/MNNMatrixAdd.cpp new file mode 100644 index 0000000000..513febe1ce --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNMatrixAdd.cpp @@ -0,0 +1,26 @@ +#include + +void MNNMatrixAdd(float *C, const float *A, const float *B, + size_t widthC4, size_t cStride, size_t aStride, + size_t bStride, size_t height) { + size_t total = widthC4 * 4; + for (size_t y = 0; y < height; ++y) { + auto a = A + aStride * y; + auto b = B + bStride * y; + auto c = C + cStride * y; + + size_t n = total; + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t va = __riscv_vle32_v_f32m8(a, vl); + vfloat32m8_t vb = __riscv_vle32_v_f32m8(b, vl); + vfloat32m8_t vc = __riscv_vfadd_vv_f32m8(va, vb, vl); + __riscv_vse32_v_f32m8(c, vc, vl); + + a += vl; + b += vl; + c += vl; + n -= vl; + } + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNMatrixMax.cpp b/source/backend/cpu/riscv/rvv/MNNMatrixMax.cpp new file mode 100644 index 0000000000..2d0f8c2493 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNMatrixMax.cpp @@ -0,0 +1,26 @@ +#include + +void MNNMatrixMax(float *C, const float *A, const float *B, + size_t widthC4, size_t cStride, size_t aStride, + size_t bStride, size_t height) { + size_t total = widthC4 * 4; + for (int y = 0; y < height; ++y) { + auto a = A + aStride * y; + auto b = B + bStride * y; + auto c = C + cStride * y; + + size_t n = total; + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t va = __riscv_vle32_v_f32m8(a, vl); + vfloat32m8_t vb = __riscv_vle32_v_f32m8(b, vl); + vfloat32m8_t vc = __riscv_vfmax_vv_f32m8(va, vb, vl); + __riscv_vse32_v_f32m8(c, vc, vl); + + a += vl; + b += vl; + 
c += vl; + n -= vl; + } + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNMatrixSub.cpp b/source/backend/cpu/riscv/rvv/MNNMatrixSub.cpp new file mode 100644 index 0000000000..06ffa4e461 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNMatrixSub.cpp @@ -0,0 +1,26 @@ +#include + +void MNNMatrixSub(float *C, const float *A, const float *B, + size_t widthC4, size_t cStride, size_t aStride, + size_t bStride, size_t height) { + size_t total = widthC4 * 4; + for (int y = 0; y < height; ++y) { + auto a = A + aStride * y; + auto b = B + bStride * y; + auto c = C + cStride * y; + + size_t n = total; + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t va = __riscv_vle32_v_f32m8(a, vl); + vfloat32m8_t vb = __riscv_vle32_v_f32m8(b, vl); + vfloat32m8_t vc = __riscv_vfsub_vv_f32m8(va, vb, vl); + __riscv_vse32_v_f32m8(c, vc, vl); + + a += vl; + b += vl; + c += vl; + n -= vl; + } + } +} diff --git a/source/backend/cpu/x86_x64/AVX2Functions.cpp b/source/backend/cpu/x86_x64/AVX2Functions.cpp index ce6ea6cb30..5ad97a136b 100644 --- a/source/backend/cpu/x86_x64/AVX2Functions.cpp +++ b/source/backend/cpu/x86_x64/AVX2Functions.cpp @@ -62,7 +62,6 @@ bool AVX2Functions::init(int cpuFlags) { coreFunction->MNNComputeMatMulForH_1 = _AVX_MNNComputeMatMulForH_1; // Dynamic Quant coreFunction->MNNCountMaxMinValue = _AVX_MNNCountMinMaxValue; - coreFunction->MNNSoftmax = _AVX_MNNSoftmax; // For Packed Functions @@ -112,6 +111,7 @@ bool AVX2Functions::init(int cpuFlags) { coreFunction->int8MatmulRelatedFunctions.Int8GemmKernel_W4 = gAVX2CoreInt8Functions->Int8GemmKernel_W4; coreFunction->int8MatmulRelatedFunctions.MNNGetGemmUnit = gAVX2CoreInt8Functions->MNNGetGemmUnit; coreFunction->int8MatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = gAVX2CoreInt8Functions->MNNPackC4Int8ForMatMul_A; + coreFunction->int8MatmulRelatedFunctions.eP = 4; } return true; } diff --git a/source/backend/cpu/x86_x64/FunctionDispatcher.cpp b/source/backend/cpu/x86_x64/FunctionDispatcher.cpp index 
0bdc3e42f2..139f2124a1 100644 --- a/source/backend/cpu/x86_x64/FunctionDispatcher.cpp +++ b/source/backend/cpu/x86_x64/FunctionDispatcher.cpp @@ -24,7 +24,7 @@ struct FunctionGroup { int lP = 1; int hP = 4; void (*MNNExpC8)(float* dest, const float* source, float* offset, const float* parameters, size_t countC8) = _SSE_MNNExpC8; - void (*MNNSoftmax)(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) = _SSE_MNNSoftmax; + void (*MNNSoftmax)(float* softmaxDst, const float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize, int kvSeqOffset, int validOffset, int pack, bool mask) = MNNSoftmax; void (*MNNReluInt8)(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint) = _SSE_MNNReluInt8; void (*MNNHardSwish)(float* dst, const float* src, size_t size) = _SSE_MNNHardSwish; void (*MNNGelu)(float* dst, const float* src, size_t size, float* parameters) = _SSE_MNNGelu; @@ -66,7 +66,6 @@ void MNNFunctionInit() { // Dynamic Quant coreFunction->MNNCountMaxMinValue = _SSE_MNNCountMinMaxValue; - coreFunction->MNNSoftmax = _SSE_MNNSoftmax; } #ifdef MNN_USE_AVX if (cpuFlags & libyuv::kCpuHasAVX2) { @@ -116,7 +115,7 @@ void MNNMaxPoolInt8_(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputW for (int y = 0; y < kernely; ++y) { for (int x = 0; x < kernelx; ++x) { const int8_t* inputPtr = srcPtr + pack* (x + inputWidth* y); - for (int idx = 0; idx < pack; ++idx) { + for (int idx = 0; idx < pack; ++idx) { results[idx] = std::max(results[idx], *(inputPtr + idx)); } } @@ -207,8 +206,8 @@ void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count) { _SSE_MNNInt8ToInt16(dest, source, count); } -void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) { - gFunc.MNNSoftmax(softmaxDst, input, runningMax, runningSum, updateScale, outside, reduceSize); +void MNNSoftmax(float* 
softmaxDst, const float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize, int kvSeqOffset, int validOffset, int pack, bool mask) { + gFunc.MNNSoftmax(softmaxDst, input, runningMax, runningSum, updateScale, outside, reduceSize, kvSeqOffset, validOffset, pack, mask); } void MNNNorm(float* dest, const float* source, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm) { diff --git a/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp b/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp index a2d35834e0..5fcc0076f3 100644 --- a/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp +++ b/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp @@ -53,7 +53,7 @@ void _AVX_MNNAsyQuantInfo(float* scale, float* bias, float* qscale, float* qbias void _AVX_MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el); void _AVX_MNNExpC8(float* dest, const float* source, float* offset, const float* parameters, size_t countC8); -void _AVX_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize); +void _AVX_MNNSoftmax(float* softmaxDst, const float* softmaxSrc, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize, int kvSeqOffset, int validOffset, int pack, bool mask); void _AVX_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minV, ssize_t maxV, const float* zeroPoint, ssize_t quanParamVec); void _AVX_MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t sizeQuad, const float* zeroPoint, ssize_t quanParamVec); void _AVX_MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dstO, const int8_t* srcO, const int8_t* weightO, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder); diff --git 
a/source/backend/cpu/x86_x64/avx/GemmAVX2.cpp b/source/backend/cpu/x86_x64/avx/GemmAVX2.cpp index 9f98b76ac8..d07616c1e9 100644 --- a/source/backend/cpu/x86_x64/avx/GemmAVX2.cpp +++ b/source/backend/cpu/x86_x64/avx/GemmAVX2.cpp @@ -220,8 +220,8 @@ void _AVX_MNNAsyQuantInfo(float* scale, float* bias, float* qscale, float* qbias } else { qscale[0] = 255.f / range; scale[0] = range / 255.f; - qbias[0] = roundf(-minval * 255.f / range)- 128.f; - bias[0] = -qbias[0] * scale[0]; + qbias[0] = -minval * 255.f / range- 128.f; + bias[0] = minval; } return; } @@ -262,14 +262,14 @@ void _AVX_MNNAsyQuantInfo(float* scale, float* bias, float* qscale, float* qbias quantScale4 = _mm_blendv_ps(quantScale4, _0f, mask); dequantScale4 = _mm_blendv_ps(dequantScale4, _0f, mask); - quantBias4 = _mm_round_ps(_mm_blendv_ps(quantBias4, _0f, mask), 0); + quantBias4 = _mm_blendv_ps(quantBias4, _0f, mask); dequantBias4 = _mm_blendv_ps(dequantBias4, max4, mask); _mm_storeu_ps(scalePtr, dequantScale4); _mm_storeu_ps(biasPtr, dequantBias4); _mm_storeu_ps(qscale + qind, quantScale4); _mm_storeu_ps(qbias + qind, quantBias4); - + realDstCount -= DST_XUNIT; qind += DST_XUNIT; scalePtr += (blockNum * DST_XUNIT); @@ -382,7 +382,7 @@ void _AVX_MNNDynamicQuant(const float* src, int8_t* dst, const float* scale, siz _mm256_storeu_si256((__m256i *)tmp, r0_8); dstPtr[0] = tmp[0]; dstPtr[1] = tmp[4]; - + // next round xcount--; scalePtr += 1; diff --git a/source/backend/cpu/x86_x64/avx/GemmInt8.cpp b/source/backend/cpu/x86_x64/avx/GemmInt8.cpp index 0931ff3dd9..87cd340106 100644 --- a/source/backend/cpu/x86_x64/avx/GemmInt8.cpp +++ b/source/backend/cpu/x86_x64/avx/GemmInt8.cpp @@ -87,7 +87,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const int weight_step_Y = (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H) / 2; int weight_step_Z = src_depth_quad * weight_step_Y + 2 * sizeof(float)* GEMMINT8_AVX2_H; const __m128i mask = _mm_set1_epi8(0xf); - + auto srcKernelSumPtr = post->srcKernelSum; 
__m256 kernelSum0, kernelSum1, kernelSum2, kernelSum3; auto neg128_f = _mm256_set1_ps(-128.f); @@ -160,7 +160,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const auto D1 = D01; auto D2 = D02; auto D3 = D03; - auto scaleValue = _mm256_loadu_ps(scale_dz); + auto scaleValue = _mm256_loadu_ps(scale_dz); auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz); auto f0 = _mm256_cvtepi32_ps(D0); @@ -380,12 +380,12 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const } } return; - } + } if (2 == realDst) { for (int dz = 0; dz < dst_depth_quad; ++dz) { auto dst_x = dst + dz * dst_step_tmp; auto accum_x = accumbuff; - + for (int bk = 0; bk < blockNum; ++bk) { // block's weight&scale&bias const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z; @@ -462,7 +462,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const auto biasValue = _mm256_loadu_ps(bias_dz); f0 = _mm256_add_ps(f0, biasValue); f1 = _mm256_add_ps(f1, biasValue); - } + } if (post->fp32minmax) { f0 = _mm256_min_ps(f0, fp32max); f1 = _mm256_min_ps(f1, fp32max); @@ -478,7 +478,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const } } return; - } + } if (1 == realDst) { for (int dz = 0; dz < dst_depth_quad; ++dz) { auto dst_x = dst + dz * dst_step_tmp; @@ -551,14 +551,14 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const f0 = _mm256_max_ps(f0, fp32min); } _mm256_storeu_ps(((float*)dst_x), f0); - + } else { _mm256_storeu_ps(((float*)accum_x) , f0); } } } return; - } + } } @@ -588,7 +588,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons int weight_step_Y = (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); int weight_step_Z = src_depth_quad * weight_step_Y + 2 * sizeof(float) * GEMMINT8_AVX2_H; - + auto srcKernelSumPtr = post->srcKernelSum; __m256 kernelSum0, kernelSum1, kernelSum2, kernelSum3; auto neg128_f = 
_mm256_set1_ps(-128.f); @@ -930,12 +930,12 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons } } return; - } + } if (2 == realDst) { for (int dz = 0; dz < dst_depth_quad; ++dz) { auto dst_x = dst + dz * dst_step_tmp; auto accum_x = accumbuff; - + for (int bk = 0; bk < blockNum; ++bk) { // block's weight&scale&bias const auto weight_dz = weight + dz * blockNum * weight_step_Z + bk * weight_step_Z; @@ -1032,7 +1032,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons auto biasValue = _mm256_loadu_ps(bias_dz); f0 = _mm256_add_ps(f0, biasValue); f1 = _mm256_add_ps(f1, biasValue); - } + } if (post->fp32minmax) { f0 = _mm256_min_ps(f0, fp32max); f1 = _mm256_min_ps(f1, fp32max); @@ -1049,7 +1049,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons } } return; - } + } if (1 == realDst) { for (int dz = 0; dz < dst_depth_quad; ++dz) { auto dst_x = dst + dz * dst_step_tmp; @@ -1134,7 +1134,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons f0 = _mm256_max_ps(f0, fp32min); } _mm256_storeu_ps(((float*)dst_x), f0); - + } else { _mm256_storeu_ps(((float*)accum_x) , f0); } @@ -1142,7 +1142,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons } } return; - } + } } void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) { @@ -1297,7 +1297,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src, // D2 = _mm256_add_epi32(D2, biasValue0); auto scaleValue = _mm256_loadu_ps(scale_dz); - + auto f0 = _mm256_cvtepi32_ps(D0); auto f1 = _mm256_cvtepi32_ps(D1); auto f2 = _mm256_cvtepi32_ps(D2); @@ -1332,7 +1332,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src, } } return; - } + } if (2 == realDst) { for 
(int dz = 0; dz < dst_depth_quad; ++dz) { const auto weight_dz = weight + dz * weight_step_Z; @@ -1386,7 +1386,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src, } } return; - } + } if (1 == realDst) { for (int dz = 0; dz < dst_depth_quad; ++dz) { const auto weight_dz = weight + dz * weight_step_Z; @@ -1493,7 +1493,7 @@ void _AVX_MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dstO, const int8_t* srcO, d1 = _mm256_cvtps_epi32(_mm256_round_ps(f1, 3)); d0 = _mm256_add_epi32(d0, offset); d1 = _mm256_add_epi32(d1, offset); - + d0 = _mm256_permute4x64_epi64(_mm256_packs_epi32(d0, d1), 0xD8); d0 = _mm256_min_epi16(d0, maxValue); d0 = _mm256_max_epi16(d0, minValue); @@ -1532,7 +1532,8 @@ void _AVX_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const fl m0 = _mm256_blendv_ps(plus, minus, m0); f0 = _mm256_add_ps(f0, m0); // 3: _MM_FROUND_TO_ZERO - auto d0 = _mm256_cvtps_epi32(_mm256_round_ps(f0, 3)); + auto r0 = _mm256_round_ps(f0, 3); + auto d0 = _mm256_cvtps_epi32(r0); d0 = _mm256_add_epi32(d0, offset); d0 = _mm256_packs_epi32(d0, _mm256_setzero_si256()); d0 = _mm256_permute4x64_epi64(d0, 0xD8); diff --git a/source/backend/cpu/x86_x64/avx/MathFunctions.cpp b/source/backend/cpu/x86_x64/avx/MathFunctions.cpp index fde08bbd31..25ca932c11 100644 --- a/source/backend/cpu/x86_x64/avx/MathFunctions.cpp +++ b/source/backend/cpu/x86_x64/avx/MathFunctions.cpp @@ -116,72 +116,129 @@ void _AVX_MNNExpC8(float* dest, const float* source, float* offset, const float* offset[3] = total; } +void _AVX_MNNSoftmax(float* softmaxDst, const float* softmaxSrc, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize, int kvSeqOffset, int validOffset, int pack, bool mask) { + const int packUnit = 8; + int reduceSizeOuter = 1; + int reduceSizeInner = reduceSize; + int stride0 = packUnit; + if (pack > 1) { + reduceSizeOuter = UP_DIV(reduceSize, pack); + reduceSizeInner = pack; + stride0 = outside * reduceSizeInner; + } 
-void _AVX_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) { - const float xLimit = 87.0f; - const float param = 0.6931471805599453f; // ln(2) - const float inv_param = 1.0f / param; - const int32_t exp_offset = 127; - const float exp_scale = 8388608.0f; // 2^23 + float tmp[8]; + float exprOffset[4] = {1.0f, 0.0f, 0.0f, 0.0f }; for (int k = 0; k < outside; ++k) { - float* source = input + k * reduceSize; - float* dest = softmaxDst + k * reduceSize; + exprOffset[3] = 0.0f; // init sum to zero for each outer loop + if (mask && kvSeqOffset > k + validOffset) { + if (updateScale){ + updateScale[k] = 1; + } + for (int j = 0; j < reduceSizeOuter; ++j) { + auto destPtr = softmaxDst + j * stride0 + k * reduceSizeInner; + memset(destPtr, 0, reduceSizeInner * sizeof(float)); + } + continue; + } - float tmpfloat8[8]; - int count = reduceSize/ 8; - int remain = count * 8; - // step 1: get maxValue - float maxValue = source[0]; + const int validReduceSize = mask ? ALIMIN(reduceSize, k + (validOffset + 1) - kvSeqOffset) : reduceSize; + const int remain = validReduceSize % packUnit; + const int sizeDiv = validReduceSize / packUnit; + const float floatLowest = std::numeric_limits::lowest(); - float oldMax = maxValue; + // 1. newMax + float oldMax = floatLowest; if (runningMax) { oldMax = runningMax[k]; } - if (count > 0) { - auto maxVal = _mm256_loadu_ps(source); - for (int i = 1; i < count; i++) { - maxVal = _mm256_max_ps(maxVal, _mm256_loadu_ps(source + i * 8)); - } - _mm256_storeu_ps(tmpfloat8, maxVal); - maxValue = tmpfloat8[0] > tmpfloat8[1] ? tmpfloat8[0] : tmpfloat8[1]; - for (int i = 2; i < 8; i++) { - maxValue = maxValue > tmpfloat8[i] ? 
maxValue : tmpfloat8[i]; + __m256 maxVec = _mm256_set1_ps(floatLowest); + for (int j = 0; j < sizeDiv; ++j) { + auto srcPtr = softmaxSrc + j * stride0 + k * reduceSizeInner; + __m256 srcVec = _mm256_loadu_ps(srcPtr); + maxVec = _mm256_max_ps(maxVec, srcVec); + } + _mm256_storeu_ps(tmp, maxVec); + float newMax = tmp[0]; + for (int i = 1; i < 8; ++i) { + newMax = ALIMAX(newMax, tmp[i]); + } + + if (remain > 0) { + auto srcPtr = softmaxSrc + sizeDiv * stride0 + k * reduceSizeInner; + for (int i = 0; i < remain; ++i) { + newMax = ALIMAX(newMax, srcPtr[i]); } } - for (int i = remain; i < reduceSize; i++) { - maxValue = maxValue > source[i] ? maxValue : source[i]; + + const float finalMax = ALIMAX(oldMax, newMax); + exprOffset[2] = -finalMax; + + // 2. exp(x - finalMax) and Sum + for (int j = 0; j < sizeDiv; ++j) { + auto idx = j * stride0 + k * reduceSizeInner; + auto srcPtr = softmaxSrc + idx; + auto dstPtr = softmaxDst + idx; + MNNExp(dstPtr, srcPtr, exprOffset, packUnit); } - float newMax = ALIMAX(oldMax, maxValue); + float sum = exprOffset[3]; + + if (remain > 0) { + auto idx = sizeDiv * stride0 + k * reduceSizeInner; + auto srcPtr = softmaxSrc + idx; + auto dstPtr = softmaxDst + idx; - // step 2: get exp(x - newMax) and sum(exp(x - newMax)) - float exprOffset[4] = {1.0f, 0.0f, 0.0f, 0.0f }; - exprOffset[2] = -newMax; - MNNExp(dest, source, exprOffset, reduceSize); - float sumValue = exprOffset[3]; + for (int i = 0; i < remain; ++i) { + float val = expf(srcPtr[i] - finalMax); + sum += val; + dstPtr[i] = val; + } + } + // 3. 
Normalization or update state if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) { - // === Step 3: Update running variables === - float scale = expf(oldMax - newMax); - runningSum[k] = runningSum[k] * scale + sumValue; - runningMax[k] = newMax; - updateScale[k] = scale; + float scaleForSum = expf(oldMax - finalMax); + runningSum[k] = runningSum[k] * scaleForSum + sum; + runningMax[k] = finalMax; + updateScale[k] = scaleForSum; } else { - // step 3: get x / sum and store - for (int i = 0; i < count; ++i) { - // using 1 / ((1 / x) * sum) instead x * (1 / sum) or x / sum for some bugs in intel cpu - auto x = _mm256_rcp_ps(_mm256_loadu_ps(dest + 8 * i)); - auto y = _mm256_set1_ps(sumValue); - auto z = _mm256_rcp_ps(_mm256_mul_ps(x, y)); - _mm256_storeu_ps(dest + 8 * i, z); + if (runningMax != nullptr && runningSum != nullptr) { + sum += runningSum[k] * expf(oldMax - finalMax); + } + float scale = 1.0f / (sum + 1e-20f); + __m256 scaleVec = _mm256_set1_ps(scale); + + for (int j = 0; j < sizeDiv; ++j) { + auto pDest = softmaxDst + j * stride0 + k * reduceSizeInner; + __m256 data = _mm256_loadu_ps(pDest); + data = _mm256_mul_ps(data, scaleVec); + _mm256_storeu_ps(pDest, data); } - auto scale = 1.f / sumValue; - for (int i = remain; i < reduceSize; i++) { - dest[i] *= scale; + if (remain > 0) { + auto pDest = softmaxDst + sizeDiv * stride0 + k * reduceSizeInner; + for (int i = 0; i < remain; ++i) { + pDest[i] *= scale; + } } } + + // 4. 
memset 0 + if (pack > 1) { + if (validReduceSize % pack > 0) { + memset(softmaxDst + (UP_DIV(validReduceSize, pack) - 1) * stride0 + k * reduceSizeInner + (validReduceSize % pack), 0, (pack - (validReduceSize % pack)) * sizeof(float)); + } + auto validOuter = UP_DIV(validReduceSize, pack); + auto allOuter = UP_DIV(reduceSize, pack); + for (int j = validOuter; j < allOuter; ++j) { + auto destPtr = softmaxDst + j * stride0 + k * reduceSizeInner; + memset(destPtr, 0, pack * sizeof(float)); + } + } else { + memset(softmaxDst + k * reduceSizeInner + validReduceSize, 0, (reduceSize - validReduceSize) * sizeof(float)); + } } } diff --git a/source/backend/cpu/x86_x64/avx/PackedFunction.cpp b/source/backend/cpu/x86_x64/avx/PackedFunction.cpp index 33a464d144..393b250671 100644 --- a/source/backend/cpu/x86_x64/avx/PackedFunction.cpp +++ b/source/backend/cpu/x86_x64/avx/PackedFunction.cpp @@ -50,12 +50,12 @@ void _AVX_MNNConvRunForLineDepthwise(float* dst, const float* src, const float* void _AVX_MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters); #ifdef MNN_SUPPORT_TRANSFORMER_FUSE -void _AVX_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes); -#endif +void _AVX_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes, int seqStart); +#endif } #ifdef MNN_SUPPORT_TRANSFORMER_FUSE -void _AVX_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) { +void _AVX_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int 
size, int bytes, int seqStart) { // source shape: [headDim/pack, seqLen, pack] // scale & normalizeScale shape: [seqLen] // dest shape: [headDim/pack, seqLen, pack] @@ -63,7 +63,8 @@ void _AVX_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scal if (idx > 0) { for (int j = 0; j < depthQuad; ++j) { - for (int i = 0; i < plane; ++i) { + int i = seqStart; + for (; i < plane; ++i) { auto dataNew = Vec::load(src + j * stride0 + i * pack); auto dataOld = Vec::load(dst + j * stride0 + i * pack); auto s = Vec(scale[i]); diff --git a/source/backend/cpu/x86_x64/avx/ReorderFunctions.cpp b/source/backend/cpu/x86_x64/avx/ReorderFunctions.cpp index dea230bfd1..a4cf16d210 100644 --- a/source/backend/cpu/x86_x64/avx/ReorderFunctions.cpp +++ b/source/backend/cpu/x86_x64/avx/ReorderFunctions.cpp @@ -36,7 +36,7 @@ void _AVX_MNNPackCUnit(float* dst, const float* src, size_t area, size_t depth, auto r5 = _mm256_loadu_ps(s + 5 * srcAreaOffset); auto r6 = _mm256_loadu_ps(s + 6 * srcAreaOffset); auto r7 = _mm256_loadu_ps(s + 7 * srcAreaOffset); - + TRANSPOSE_8x8; _mm256_storeu_ps(d + PACK_UNIT * 0, t0); @@ -247,15 +247,15 @@ void _AVX_MNNPackCUnitTranspose(float* dst, const float* src, size_t area, size_ } } -void _AVX_MNNUnpackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth, int* areaOffset) { +void _AVX_MNNUnpackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth, int* offset) { int c = (int)depth; int cDiv4 = c / PACK_UNIT; int cAlign = cDiv4 * PACK_UNIT; - auto srcAreaOffset = areaOffset[0]; - auto dstAreaOffset = areaOffset[1]; + auto srcAreaOffset = offset[0]; + auto dstDepthOffset = offset[1]; for (int hi = 0; hi < area; ++hi) { const float* srcHeight = src + hi * PACK_UNIT; - float* dstHeight = dst + hi * c; + float* dstHeight = dst + hi * dstDepthOffset; for (int ci = 0; ci < cDiv4; ++ci) { _mm256_storeu_ps(dstHeight + PACK_UNIT * ci, _mm256_loadu_ps(srcHeight + PACK_UNIT * ci * srcAreaOffset)); } @@ -271,7 +271,7 @@ 
void _AVX_MNNUnpackCUnitTranspose(float* dst, const float* src, size_t area, siz for (int hi = 0; hi < area; ++hi) { const float* srcHeight = srcAlign + hi * PACK_UNIT; - float* dstHeight = dstAlign + hi * c; + float* dstHeight = dstAlign + hi * dstDepthOffset; for (int ci = 0; ci < cReamin; ++ci) { dstHeight[ci] = srcHeight[ci]; diff --git a/source/backend/cpu/x86_x64/avx512/PackedFunction.cpp b/source/backend/cpu/x86_x64/avx512/PackedFunction.cpp index a8a56d8907..536445e9fc 100644 --- a/source/backend/cpu/x86_x64/avx512/PackedFunction.cpp +++ b/source/backend/cpu/x86_x64/avx512/PackedFunction.cpp @@ -160,7 +160,7 @@ static void _AVX512_MNNAsyQuantInfo(float* scale, float* bias, float* qscale, fl qscale[0] = 255.f / range; scale[0] = range / 255.f; qbias[0] = roundf(-minval * 255.f / range)- 128.f; - bias[0] = -qbias[0] * scale[0]; + bias[0] = minval; } return; } @@ -208,7 +208,7 @@ static void _AVX512_MNNAsyQuantInfo(float* scale, float* bias, float* qscale, fl _mm_storeu_ps(biasPtr, dequantBias4); _mm_storeu_ps(qscale + qind, quantScale4); _mm_storeu_ps(qbias + qind, quantBias4); - + realDstCount -= DST_XUNIT; qind += DST_XUNIT; scalePtr += (blockNum * DST_XUNIT); @@ -363,7 +363,7 @@ static void _AVX512_DynamicQuant(const float* src, int8_t* dst, const float* sca dstPtr[1] = tmp[4 * 1]; dstPtr[2] = tmp[4 * 2]; dstPtr[3] = tmp[4 * 3]; - + // next round xcount--; scalePtr += 1; @@ -1064,8 +1064,130 @@ static void _AVX512_MNNAdjustOptimalSparseKernel(int& sparseBlockOC, MNN::CoreFu } } +static void _AVX512_MNNSoftmax(float* softmaxDst, const float* softmaxSrc, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize, int kvSeqOffset, int validOffset, int pack, bool mask) { + const int packUnit = 16; + int reduceSizeOuter = 1; + int reduceSizeInner = reduceSize; + int stride0 = packUnit; + if (pack > 1) { + reduceSizeOuter = UP_DIV(reduceSize, pack); + reduceSizeInner = pack; + stride0 = outside * reduceSizeInner; + } + + float 
exprOffset[4] = {1.0f, 0.0f, 0.0f, 0.0f }; + for (int k = 0; k < outside; ++k) { + exprOffset[3] = 0.f; + if (mask && kvSeqOffset > k + validOffset) { + if (updateScale){ + updateScale[k] = 1; + } + for (int j = 0; j < reduceSizeOuter; ++j) { + auto destPtr = softmaxDst + j * stride0 + k * reduceSizeInner; + memset(destPtr, 0, reduceSizeInner * sizeof(float)); + } + continue; + } + + const int validReduceSize = mask ? ALIMIN(reduceSize, k + (validOffset + 1) - kvSeqOffset) : reduceSize; + const int remain = validReduceSize % packUnit; + const int sizeDiv = validReduceSize / packUnit; + const float floatLowest = std::numeric_limits::lowest(); + + // 1. newMax + float oldMax = floatLowest; + if (runningMax) { + oldMax = runningMax[k]; + } + + __m512 maxVec = _mm512_set1_ps(floatLowest); + for (int j = 0; j < sizeDiv; ++j) { + auto srcPtr = softmaxSrc + j * stride0 + k * reduceSizeInner; + __m512 srcVec = _mm512_loadu_ps(srcPtr); + maxVec = _mm512_max_ps(maxVec, srcVec); + } + float newMax = _mm512_reduce_max_ps(maxVec); + + if (remain > 0) { + auto srcPtr = softmaxSrc + sizeDiv * stride0 + k * reduceSizeInner; + for (int i = 0; i < remain; ++i) { + newMax = ALIMAX(newMax, srcPtr[i]); + } + } + + const float finalMax = ALIMAX(oldMax, newMax); + const __m512 finalMaxVec = _mm512_set1_ps(finalMax); + exprOffset[2] = -finalMax; + + // 2. exp(x - finalMax) and Sum + __m512 sumVec = _mm512_setzero_ps(); + for (int j = 0; j < sizeDiv; ++j) { + auto idx = j * stride0 + k * reduceSizeInner; + auto srcPtr = softmaxSrc + idx; + auto dstPtr = softmaxDst + idx; + + MNNExp(dstPtr, srcPtr, exprOffset, packUnit); + } + + float sum = exprOffset[3]; + + if (remain > 0) { + auto idx = sizeDiv * stride0 + k * reduceSizeInner; + auto srcPtr = softmaxSrc + idx; + auto dstPtr = softmaxDst + idx; + for (int i = 0; i < remain; ++i) { + float val = expf(srcPtr[i] - finalMax); + sum += val; + dstPtr[i] = val; + } + } + + // 3. 
Normalization or update state + if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) { + float scaleForSum = expf(oldMax - finalMax); + runningSum[k] = runningSum[k] * scaleForSum + sum; + runningMax[k] = finalMax; + updateScale[k] = scaleForSum; + } else { + if (runningMax != nullptr && runningSum != nullptr) { + sum += runningSum[k] * expf(oldMax - finalMax); + } + float scale = 1.0f / (sum + 1e-20f); + __m512 scaleVec = _mm512_set1_ps(scale); + + for (int j = 0; j < sizeDiv; ++j) { + auto pDest = softmaxDst + j * stride0 + k * reduceSizeInner; + __m512 data = _mm512_loadu_ps(pDest); + data = _mm512_mul_ps(data, scaleVec); + _mm512_storeu_ps(pDest, data); + } + if (remain > 0) { + auto pDest = softmaxDst + sizeDiv * stride0 + k * reduceSizeInner; + for (int i = 0; i < remain; ++i) { + pDest[i] *= scale; + } + } + } + + // 4. memset 0 for padding (逻辑不变) + if (pack > 1) { + if (validReduceSize % pack > 0) { + memset(softmaxDst + (UP_DIV(validReduceSize, pack) - 1) * stride0 + k * reduceSizeInner + (validReduceSize % pack), 0, (pack - (validReduceSize % pack)) * sizeof(float)); + } + auto validOuter = UP_DIV(validReduceSize, pack); + auto allOuter = UP_DIV(reduceSize, pack); + for (int j = validOuter; j < allOuter; ++j) { + auto destPtr = softmaxDst + j * stride0 + k * reduceSizeInner; + memset(destPtr, 0, pack * sizeof(float)); + } + } else { + memset(softmaxDst + k * reduceSizeInner + validReduceSize, 0, (reduceSize - validReduceSize) * sizeof(float)); + } + } +} + #ifdef MNN_SUPPORT_TRANSFORMER_FUSE -void _AVX512_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) { +void _AVX512_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes, int seqStart) { // source shape: [headDim/pack, seqLen, pack] // 
scale & normalizeScale shape: [seqLen] // dest shape: [headDim/pack, seqLen, pack] @@ -1073,7 +1195,8 @@ void _AVX512_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* s if (idx > 0) { for (int j = 0; j < depthQuad; ++j) { - for (int i = 0; i < plane; ++i) { + int i = seqStart; + for (; i < plane; ++i) { auto dataNew = Vec::load(src + j * stride0 + i * pack); auto dataOld = Vec::load(dst + j * stride0 + i * pack); auto s = Vec(scale[i]); @@ -1116,7 +1239,8 @@ void _AVX512_ExtraInit(void* functions) { coreFunction->MNNAsyQuantInfo = _AVX512_MNNAsyQuantInfo; coreFunction->MNNAsyQuantFunc = _AVX512_MNNAsyQuantFunc; coreFunction->MNNCountMaxMinValue = _AVX512_MNNCountMinMaxValue; - + coreFunction->MNNSoftmax = _AVX512_MNNSoftmax; + coreFunction->MNNConvRunForLineDepthwise = _AVX512_MNNConvRunForLineDepthwise; coreFunction->MNNAxByClampBroadcastUnit = _AVX512_MNNAxByClampBroadcastUnit; coreFunction->MNNStrassenMergeCFunction = _AVX512_MNNStrassenMergeCFunction; @@ -1134,7 +1258,7 @@ void _AVX512_ExtraInit(void* functions) { coreFunction->MNNGetSparseMatMulPackMode = _AVX512_MNNGetSparseMatMulPackMode; coreFunction->MNNAdjustOptimalSparseKernel = _AVX512_MNNAdjustOptimalSparseKernel; -#ifdef MNN_SUPPORT_TRANSFORMER_FUSE +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE coreFunction->MNNFlashAttentionUpdateBlockOutput = _AVX512_MNNFlashAttentionUpdateBlockOutput; #endif } diff --git a/source/backend/cpu/x86_x64/avx512/ReorderFunctions.cpp b/source/backend/cpu/x86_x64/avx512/ReorderFunctions.cpp index ff3185de1e..422fad8e65 100644 --- a/source/backend/cpu/x86_x64/avx512/ReorderFunctions.cpp +++ b/source/backend/cpu/x86_x64/avx512/ReorderFunctions.cpp @@ -46,7 +46,7 @@ void _AVX512_MNNPackCUnit(float* dst, const float* src, size_t area, size_t dept LOAD_CASE(15); #undef LOAD_CASE transpose16x16F(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14, r15); - + #define SAVE_CASE(i) _mm512_storeu_ps(d + PACK_UNIT * i, r##i) SAVE_CASE(0); SAVE_CASE(1); @@ -310,15 
+310,15 @@ void _AVX512_MNNPackCUnitTranspose(float* dst, const float* src, size_t area, si } } -void _AVX512_MNNUnpackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth, int* areaOffset) { +void _AVX512_MNNUnpackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth, int* offset) { int c = (int)depth; int cDiv4 = c / PACK_UNIT; int cAlign = cDiv4 * PACK_UNIT; - auto srcAreaOffset = areaOffset[0]; - auto dstAreaOffset = areaOffset[1]; + auto srcAreaOffset = offset[0]; + auto dstDepthOffset = offset[1]; for (int hi = 0; hi < area; ++hi) { const float* srcHeight = src + hi * PACK_UNIT; - float* dstHeight = dst + hi * c; + float* dstHeight = dst + hi * dstDepthOffset; for (int ci = 0; ci < cDiv4; ++ci) { _mm512_storeu_ps(dstHeight + PACK_UNIT * ci, _mm512_loadu_ps(srcHeight + PACK_UNIT * ci * srcAreaOffset)); } @@ -334,7 +334,7 @@ void _AVX512_MNNUnpackCUnitTranspose(float* dst, const float* src, size_t area, for (int hi = 0; hi < area; ++hi) { const float* srcHeight = srcAlign + hi * PACK_UNIT; - float* dstHeight = dstAlign + hi * c; + float* dstHeight = dstAlign + hi * dstDepthOffset; for (int ci = 0; ci < cReamin; ++ci) { dstHeight[ci] = srcHeight[ci]; diff --git a/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp b/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp index 5c389ee734..da9d20bfc7 100644 --- a/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp +++ b/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp @@ -81,7 +81,6 @@ void _SSE_MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count); void _SSE_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose); void _SSE_MNNReluInt8(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint); -void _SSE_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize); void _SSE_ExtraInit(void* functions); void _SSE_MNNNorm(float *dst, const 
float *src, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm); void _SSE_ImageProcessInit(void* functions, int cpuFlags); diff --git a/source/backend/cpu/x86_x64/sse/GemmSSE.cpp b/source/backend/cpu/x86_x64/sse/GemmSSE.cpp index 4e907981c8..8dcfae2eab 100644 --- a/source/backend/cpu/x86_x64/sse/GemmSSE.cpp +++ b/source/backend/cpu/x86_x64/sse/GemmSSE.cpp @@ -71,7 +71,7 @@ void _SSE_MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, s #ifdef MNN_LOW_MEMORY // Dynamic quant void _SSE_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) { - size_t srcStep = realSize * pack; + size_t srcStep = realSize * pack; __m128 mask = _mm_set1_ps(-0.0f); if (pack == 4) { // input c4 float tmp[4]; @@ -551,7 +551,7 @@ void _SSE_MNNAsyQuantInfo(float* scale, float* bias, float* qscale, float* qbias qscale[0] = 255.f / range; scale[0] = range / 255.f; qbias[0] = roundf(-minval * 255.f / range)- 128.f; - bias[0] = -qbias[0] * scale[0]; + bias[0] = minval; } return; } @@ -599,7 +599,7 @@ void _SSE_MNNAsyQuantInfo(float* scale, float* bias, float* qscale, float* qbias _mm_storeu_ps(biasPtr, dequantBias4); _mm_storeu_ps(qscale + qind, quantScale4); _mm_storeu_ps(qbias + qind, quantBias4); - + realDstCount -= DST_XUNIT; qind += DST_XUNIT; scalePtr += (blockNum * DST_XUNIT); diff --git a/source/backend/cpu/x86_x64/sse/MathFunctions.cpp b/source/backend/cpu/x86_x64/sse/MathFunctions.cpp index 87b903fa1d..e6ffc50359 100644 --- a/source/backend/cpu/x86_x64/sse/MathFunctions.cpp +++ b/source/backend/cpu/x86_x64/sse/MathFunctions.cpp @@ -69,71 +69,6 @@ void _SSE_MNNExpC8(float* dest, const float* source, float* offset, const float* offset[3] = total; } -void _SSE_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) { - const float xLimit = 87.0f; - const float param = 0.6931471805599453f; // ln(2) - const float inv_param = 
1.0f / param; - const int32_t exp_offset = 127; - const float exp_scale = 8388608.0f; // 2^23 - - for (int k = 0; k < outside; ++k) { - float* source = input + k * reduceSize; - float* dest = softmaxDst + k * reduceSize; - - float tmpfloat4[4]; - int count = static_cast(reduceSize / 4); - int remain = count * 4; - // step 1: get maxValue - float maxValue = source[0]; - float oldMax = maxValue; - if (runningMax) { - oldMax = runningMax[k]; - } - if (count > 0) { - auto maxVal = _mm_loadu_ps(source); - for (int i = 1; i < count; i++) { - maxVal = _mm_max_ps(maxVal, _mm_loadu_ps(source + i * 4)); - } - _mm_storeu_ps(tmpfloat4, maxVal); - maxValue = tmpfloat4[0] > tmpfloat4[1] ? tmpfloat4[0] : tmpfloat4[1]; - maxValue = maxValue > tmpfloat4[2] ? maxValue : tmpfloat4[2]; - maxValue = maxValue > tmpfloat4[3] ? maxValue : tmpfloat4[3]; - } - for (int i = remain; i < reduceSize; i++) { - maxValue = maxValue > source[i] ? maxValue : source[i]; - } - - float newMax = ALIMAX(oldMax, maxValue); - - // step 2: get exp(x - newMax) and sum(exp(x - newMax)) - float exprOffset[4] = {1.0f, 0.0f, 0.0f, 0.0f }; - exprOffset[2] = -newMax; - MNNExp(dest, source, exprOffset, reduceSize); - float sumValue = exprOffset[3]; - - if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) { - // === Step 3: Update running variables === - float scale = expf(oldMax - newMax); - runningSum[k] = runningSum[k] * scale + sumValue; - runningMax[k] = newMax; - updateScale[k] = scale; - } else { - // step 3: get x / sum and store - for (int i = 0; i < count; ++i) { - // using 1 / ((1 / x) * sum) instead x * (1 / sum) or x / sum for some bugs in intel cpu - auto x = _mm_rcp_ps(_mm_loadu_ps(dest + 4 * i)); - auto y = _mm_set1_ps(sumValue); - auto z = _mm_rcp_ps(_mm_mul_ps(x, y)); - _mm_storeu_ps(dest + 4 * i, z); - } - auto scale = 1.f / sumValue; - for (int i = remain; i < reduceSize; i++) { - dest[i] *= scale; - } - } - } -} - void _SSE_MNNGelu(float* dst, const float* src, size_t 
size, float* parameters) { // parameters[8] = {0.044715f, 0.79788458f, 378.f, 17325.f, 135135.f, 28.f, 3150.f, 62370.f}; auto var1 = _mm_set1_ps(parameters[0]); diff --git a/source/backend/metal/ConvSimdGroupShader.hpp b/source/backend/metal/ConvSimdGroupShader.hpp index b9c9ba981a..e47e034692 100644 --- a/source/backend/metal/ConvSimdGroupShader.hpp +++ b/source/backend/metal/ConvSimdGroupShader.hpp @@ -8,9 +8,10 @@ #if MNN_METAL_ENABLED -const char* gConv1x1W4SgMatrix = R"metal( +const char* gBasicConvPrefix = R"metal( #include #include + using namespace metal; typedef enum : int { None = 0, @@ -52,14 +53,126 @@ struct conv1x1_constants { float scale_coef; }; +namespace MNN { + typedef struct uchar4x2 { + private: + uchar2 v[4]; + public: + uchar4x2(uchar2 a) { + v[0] = a; v[1] = a; v[2] = a; v[3] = a; + } + uchar4x2(uchar2 a, uchar2 b, uchar2 c, uchar2 d) { + v[0] = a; v[1] = b; v[2] = c; v[3] = d; + } + + inline thread uchar2& operator[] (const int index) { + return v[index]; + } + inline device uchar2& operator[] (const int index) device { + return v[index]; + } + inline threadgroup uchar2& operator[] (const int index) threadgroup { + return v[index]; + } + + inline const thread uchar2& operator[] (const int index) const { + return v[index]; + } + inline const device uchar2& operator[] (const int index) const device { + return v[index]; + } + inline const threadgroup uchar2& operator[] (const int index) const threadgroup { + return v[index]; + } + + inline explicit operator half4x2() const { + return half4x2( half2(v[0]), half2(v[1]), half2(v[2]), half2(v[3]) ); + } + inline explicit operator half4x2() const device { + return half4x2( half2(v[0]), half2(v[1]), half2(v[2]), half2(v[3]) ); + } + inline explicit operator half4x2() const threadgroup { + return half4x2( half2(v[0]), half2(v[1]), half2(v[2]), half2(v[3]) ); + } + + inline explicit operator float4x2() const { + return float4x2( float2(v[0]), float2(v[1]), float2(v[2]), float2(v[3]) ); + } + inline 
explicit operator float4x2() const device { + return float4x2( float2(v[0]), float2(v[1]), float2(v[2]), float2(v[3]) ); + } + inline explicit operator float4x2() const threadgroup { + return float4x2( float2(v[0]), float2(v[1]), float2(v[2]), float2(v[3]) ); + } + } uchar4x2; + + typedef struct char4x4 { + private: + char4 v[4]; + public: + char4x4(char4 a) { + v[0] = a; v[1] = a; v[2] = a; v[3] = a; + } + char4x4(char4 a, char4 b, char4 c, char4 d) { + v[0] = a; v[1] = b; v[2] = c; v[3] = d; + } + + inline thread char4& operator[] (const int index) { + return v[index]; + } + inline device char4& operator[] (const int index) device { + return v[index]; + } + inline threadgroup char4& operator[] (const int index) threadgroup { + return v[index]; + } + + inline const thread char4& operator[] (const int index) const { + return v[index]; + } + inline const device char4& operator[] (const int index) const device { + return v[index]; + } + inline const threadgroup char4& operator[] (const int index) const threadgroup { + return v[index]; + } + + inline explicit operator half4x4() const { + return half4x4( half4(v[0]), half4(v[1]), half4(v[2]), half4(v[3]) ); + } + inline explicit operator half4x4() const device { + return half4x4( half4(v[0]), half4(v[1]), half4(v[2]), half4(v[3]) ); + } + inline explicit operator half4x4() const threadgroup { + return half4x4( half4(v[0]), half4(v[1]), half4(v[2]), half4(v[3]) ); + } + + inline explicit operator float4x4() const { + return float4x4( float4(v[0]), float4(v[1]), float4(v[2]), float4(v[3]) ); + } + inline explicit operator float4x4() const device { + return float4x4( float4(v[0]), float4(v[1]), float4(v[2]), float4(v[3]) ); + } + inline explicit operator float4x4() const threadgroup { + return float4x4( float4(v[0]), float4(v[1]), float4(v[2]), float4(v[3]) ); + } + } char4x4; +} + +#if MNN_METAL_FLOAT16_STORAGE +typedef simdgroup_half8x8 simdgroup_FTYPE8x8; +#else +typedef simdgroup_float8x8 simdgroup_FTYPE8x8; +#endif + 
#if MNN_METAL_FLOAT32_COMPUTER -typedef simdgroup_float8x8 simdgroup_T8x8; +typedef simdgroup_float8x8 simdgroup_FLOAT8x8; typedef float FLOAT; typedef float2 FLOAT2; typedef float4 FLOAT4; typedef float4x4 FLOAT4x4; #else -typedef simdgroup_half8x8 simdgroup_T8x8; +typedef simdgroup_half8x8 simdgroup_FLOAT8x8; typedef half FLOAT; typedef half2 FLOAT2; typedef half4 FLOAT4; @@ -71,9 +184,9 @@ typedef half4x4 FLOAT4x4; #define CONV_UNROLL_L (8) #define INIT_SIMDGROUP_MATRIX(a, b, d) \ - simdgroup_T8x8 sga[a];\ - simdgroup_T8x8 sgb[b];\ - simdgroup_T8x8 sgd[d];\ + simdgroup_FTYPE8x8 sga[a];\ + simdgroup_FTYPE8x8 sgb[b];\ + simdgroup_FLOAT8x8 sgd[d];\ for (int i = 0; i < d; i++){\ sgd[i] = make_filled_simdgroup_matrix(0.f);\ } @@ -89,7 +202,10 @@ typedef half4x4 FLOAT4x4; for(int i=0; i input: [K4, M32, K8] + ftype 1024~3071 ---> weight: [K4, K8, N64] + ftype 3072~3199 ---> scale/offset: [N64, 2] + // Write: + ftype 0~2047 ---> input: [M2, N2, N2, N2, M2, M8, N8] + */ + + threadgroup FLOAT4 sdata[768] = {(FLOAT)0.f}; + + INIT_SIMDGROUP_MATRIX(2, 4, 8); + + int rx = gid.x;// M/32 + int uz = gid.y;// N/64 + + // A:[4, 2, 16] + int ko = tiitg / 32;// 0~3 + int rcl = tiitg % 32;// 0~31 + int kl = rcl / 16;// 0~1 + int ml = rcl % 16;// 0~15 -> m + // B:[16, 2, 4] + int no = tiitg / 8;// 0~15 + int sl = tiitg % 8;// 0~7 + int kwl = sl / 4;// 0~1 + int nl = sl % 4;// 0~3 + + /** input: + threadgroup: [K4, M32, K8] -> [K4, M16, M2, K2, K4] + index: [ko, ml, M2, kl, K4] + each thread: M2K4 + layout: [K/4, M, K4] -> [K/32, K4, K2, M/32, M16, M2, K4] + index : [K/32, ko, kl, rx, ml, M2, K4] + */ + /** weight: + threadgroup: [K4, K8, N64] -> [K2, K4, K4, N16, N4] + index: [kwl, K4, K4, no, nl] + each thread: K4K4 + layout: [N/4, K/4, N4, K4] -> [N/64, N16, K/32, K2, K4, N4, K4] + index : [uz, no, K/32, kwl, K4, nl, K4] + */ + /** scale/offset: + layout:[N/4, block_size, 2, N4] -> [N/64, N16, block_size, 2, N4] + index : [uz, no, block_size, 2, nl] + */ + /** output: + 
threadgroup: [M32, N64] -> [M2, N2, N2, N2, M2, M8, N8] + index [kl, ko/2, ko%2, N2, ml/8, ml%8, N2, N4] + + each thread: M4N4 + layout: [N/4, M, N4] -> [N/64, N4, N4, M/32, M2, M16, N4] + index : [uz, ko, N4, rx, kl, ml, N4] + */ + + // boundary limit + + int idx_m20 = (rx * 16 + ml) * 2 + 0 < cst.input_size * cst.batch ? (rx * 16 + ml) * 2 + 0 : (cst.input_size * cst.batch - 1); + int idx_m21 = (rx * 16 + ml) * 2 + 1 < cst.input_size * cst.batch ? (rx * 16 + ml) * 2 + 1 : (cst.input_size * cst.batch - 1); + + int idx_k4 = 0 * 8 + ko * 2 + kl; + auto xy_in0 = in + idx_k4 * cst.input_size * cst.batch + idx_m20;// [K/4, M, K4] + auto xy_in1 = in + idx_k4 * cst.input_size * cst.batch + idx_m21;// [K/4, M, K4] + + int idx_wk4 = 0 * 8 + kwl * 4 + 0; + int idx_n4 = (uz * 16 + no) < cst.output_slice ? (uz * 16 + no) : (cst.output_slice - 1); + auto xy_wt = wt + (idx_n4 * cst.input_slice + idx_wk4) * 4 + nl;// [N/4, K/4, N4, K4] + + int idx_sa = (ko * 32 + ml * 2 + 0) * 2 + kl; + int idx_sb = 1024 + (kwl * 16 + 0) * 64 + no * 4 + nl; + int block = (cst.input_slice + cst.block_size - 1) / cst.block_size; + for (int bi=0; bi> 4); + w_dequant[i][1] = FLOAT(w_int4[0] & 0x000F); + w_dequant[i][2] = FLOAT(w_int4[1] >> 4); + w_dequant[i][3] = FLOAT(w_int4[1] & 0x000F); + } + FLOAT4 val = FLOAT4(dequant_bias0 - 8.0 * scale0); + w_dequant = w_dequant * scale0 + FLOAT4x4(val, val, val, val); + + #elif defined(W_QUANT_8) + #pragma unroll(4) + for (int i = 0; i < 4; ++i) { + auto w = xy_wt[(z + i) * 4]; + FLOAT4 w_fp32 = FLOAT4(FLOAT(w[0]), FLOAT(w[1]), FLOAT(w[2]), FLOAT(w[3])); + w_dequant[i] = w_fp32 * scale0 + dequant_bias0; + } + #endif + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(16) + for (int i = 0; i < 16; ++i) { + ((threadgroup ftype*)sdata)[idx_sb + 64*i] = ftype(w_dequant[i/4][i%4]); // K4K4 + } + + ((threadgroup ftype4*)sdata)[idx_sa] = (ftype4)*(xy_in0); + ((threadgroup ftype4*)sdata)[idx_sa + 2] = (ftype4)*(xy_in1); + + + + 
threadgroup_barrier(mem_flags::mem_threadgroup); + + /* + A: [K4, M32, K8] -> [K4, M2, M16, K8] + index: [ik, sgitg/2, sga[0~1]] + + B: [K4, K8, N64] -> [K4, K8, N2, N32] + index: [ik, sgitg%2, sgb[0~3]] + + sgitg: compute M2 and N2 + */ + threadgroup ftype * sdata_a = (threadgroup ftype*)sdata + 16*8*(sgitg/2); + threadgroup ftype * sdata_b = (threadgroup ftype*)sdata + 1024 + 32*(sgitg%2); + + #pragma unroll(4) + for (short ik = 0; ik < 4; ik++) { + simdgroup_load(sga[0], (const threadgroup ftype*)sdata_a + 256 * ik, 8); + simdgroup_load(sga[1], ((const threadgroup ftype*)sdata_a) + 256 * ik + 64, 8); + + simdgroup_load(sgb[0], ((threadgroup ftype*)sdata_b) + 512 * ik + 0, 64); + simdgroup_load(sgb[1], ((threadgroup ftype*)sdata_b) + 512 * ik + 8, 64); + simdgroup_load(sgb[2], ((threadgroup ftype*)sdata_b) + 512 * ik + 16, 64); + simdgroup_load(sgb[3], ((threadgroup ftype*)sdata_b) + 512 * ik + 24, 64); + + simdgroup_barrier(mem_flags::mem_none); + SIMDGROUP_MATRIX_FMA(2, 4); + + simdgroup_barrier(mem_flags::mem_none); + } + + xy_in0 += 8 * cst.input_size * cst.batch; + xy_in1 += 8 * cst.input_size * cst.batch; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + threadgroup FLOAT * sdata_c = (threadgroup FLOAT*)sdata + 512*sgitg; + + SIMDGROUP_MATRIX_STORE((threadgroup FLOAT*)sdata_c, 8); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // layout: [N/4, M, N4] -> [N/64, N4, N4, M/32, M2, M16, N4] + // index : [uz, ko, N4, rx, kl, ml, N4] + auto xy_out = out + ((uz * 4 + ko) * 4 + 0) * cst.output_size * cst.batch + (rx * 2 + kl) * 16 + ml;// [N/4, M, N4] + + // sdata [M2, N2, N2, N2, M2, M8, N8] + // index [kl, ko/2, ko%2, N2, ml/8, ml%8, N2, N4] + if((rx * 32 + kl * 16 + ml) < cst.input_size * cst.batch) { + if((uz * 4 + ko) * 4 < cst.output_slice) { + xy_out[0] = activate(ftype4(((threadgroup FLOAT4*)sdata)[(((kl * 4 + ko) * 2 + 0) * 16 + ml) * 2] + FLOAT4(biasTerms[(uz * 4 + ko) * 4])), cst.activation); + } + if((uz * 4 + ko) * 4 + 1 < 
cst.output_slice) { + xy_out[cst.output_size * cst.batch] = activate(ftype4(((threadgroup FLOAT4*)sdata)[(((kl * 4 + ko) * 2 + 0) * 16 + ml) * 2 + 1] + FLOAT4(biasTerms[(uz * 4 + ko) * 4 + 1])), cst.activation); + } + if((uz * 4 + ko) * 4 + 2 < cst.output_slice) { + xy_out[cst.output_size * cst.batch * 2] = activate(ftype4(((threadgroup FLOAT4*)sdata)[(((kl * 4 + ko) * 2 + 1) * 16 + ml) * 2] + FLOAT4(biasTerms[(uz * 4 + ko) * 4 + 2])), cst.activation); + } + if((uz * 4 + ko) * 4 + 3 < cst.output_slice) { + xy_out[cst.output_size * cst.batch * 3] = activate(ftype4(((threadgroup FLOAT4*)sdata)[(((kl * 4 + ko) * 2 + 1) * 16 + ml) * 2 + 1] + FLOAT4(biasTerms[(uz * 4 + ko) * 4 + 3])), cst.activation); + } + } +} + + kernel void conv1x1_gemm_32x64_wquant_sg(const device ftype2 *in [[buffer(0)]], device ftype4 *out [[buffer(1)]], constant conv1x1_constants& cst [[buffer(2)]], @@ -874,8 +1185,8 @@ kernel void conv1x1_gemm_32x64_wquant_sg(const device ftype2 *in [[bu for (int z = zmin; z < zmax; z += 2) { FLOAT2 data = (FLOAT2)*xy_in0; - ((threadgroup FLOAT*)sdata)[idx_sa] = data[0]; - ((threadgroup FLOAT*)sdata)[idx_sa + 1] = data[1]; + ((threadgroup ftype*)sdata)[idx_sa] = ftype(data[0]); + ((threadgroup ftype*)sdata)[idx_sa + 1] = ftype(data[1]); { #ifdef W_QUANT_4 @@ -887,25 +1198,25 @@ kernel void conv1x1_gemm_32x64_wquant_sg(const device ftype2 *in [[bu #endif FLOAT4 res = w4 * scale0[ni] + dequant_bias0[ni]; - // sdata[32 + 2* rcl + kl] = res; - ((threadgroup FLOAT*)sdata)[idx_sb] = res[0]; - ((threadgroup FLOAT*)sdata)[idx_sb + 64] = res[1]; - ((threadgroup FLOAT*)sdata)[idx_sb + 128] = res[2]; - ((threadgroup FLOAT*)sdata)[idx_sb + 192] = res[3]; + + ((threadgroup ftype*)sdata)[idx_sb] = ftype(res[0]); + ((threadgroup ftype*)sdata)[idx_sb + 64] = ftype(res[1]); + ((threadgroup ftype*)sdata)[idx_sb + 128] = ftype(res[2]); + ((threadgroup ftype*)sdata)[idx_sb + 192] = ftype(res[3]); } threadgroup_barrier(mem_flags::mem_threadgroup); - const threadgroup FLOAT * 
sdata_a = (const threadgroup FLOAT*)sdata + 16*8*(sgitg/2); - const threadgroup FLOAT * sdata_b = (const threadgroup FLOAT*)sdata + 32*8 + 32*(sgitg%2); + const threadgroup ftype * sdata_a = (const threadgroup ftype*)sdata + 16*8*(sgitg/2); + const threadgroup ftype * sdata_b = (const threadgroup ftype*)sdata + 32*8 + 32*(sgitg%2); - simdgroup_load(sga[0], (const threadgroup FLOAT*)sdata_a, 8); - simdgroup_load(sga[1], ((const threadgroup FLOAT*)sdata_a) + 64, 8); + simdgroup_load(sga[0], (const threadgroup ftype*)sdata_a, 8); + simdgroup_load(sga[1], ((const threadgroup ftype*)sdata_a) + 64, 8); - simdgroup_load(sgb[0], ((const threadgroup FLOAT*)sdata_b) + 0, 64); - simdgroup_load(sgb[1], ((const threadgroup FLOAT*)sdata_b) + 8, 64); - simdgroup_load(sgb[2], ((const threadgroup FLOAT*)sdata_b) + 16, 64); - simdgroup_load(sgb[3], ((const threadgroup FLOAT*)sdata_b) + 24, 64); + simdgroup_load(sgb[0], ((const threadgroup ftype*)sdata_b) + 0, 64); + simdgroup_load(sgb[1], ((const threadgroup ftype*)sdata_b) + 8, 64); + simdgroup_load(sgb[2], ((const threadgroup ftype*)sdata_b) + 16, 64); + simdgroup_load(sgb[3], ((const threadgroup ftype*)sdata_b) + 24, 64); SIMDGROUP_MATRIX_FMA(2, 4); @@ -940,79 +1251,665 @@ kernel void conv1x1_gemm_32x64_wquant_sg(const device ftype2 *in [[bu } )metal"; -const char* gConv1x1SgMatrix = R"metal( -#include -#include -using namespace metal; +const char* gConv1x1WfpSgMatrix = R"metal( +#ifdef USE_METAL_TENSOR_OPS +#include +#include +#endif -typedef enum : int { - None = 0, - ReLU = 1, - ReLU6 = 2, -} conv_activation_type; +kernel void conv1x1_w_dequant( + #ifdef W_QUANT_4 + const device uchar2 *wi [[buffer(0)]], + #elif defined(W_QUANT_8) + const device char4 *wi [[buffer(0)]], + #else + const device ftype4 *wi [[buffer(0)]],// [N/4, K/4, N4, K4] + #endif + device ftype4 *wf [[buffer(1)]],// [N/4, K/16, N4, K4, K4] + constant conv1x1_constants& cst [[buffer(2)]], + const device ftype4 *dequantScale [[buffer(3)]], + uint3 gid 
[[thread_position_in_grid]] +) { -inline ftype4 activate(ftype4 value, conv_activation_type type) { - switch (type) { - case ReLU: - return max(value, (ftype4)0); - case ReLU6: - return clamp(value, (ftype4)0, (ftype4)6); - default: // None - return value; + int idx_n = gid.x; // N + int idx_k16 = gid.y; // K/16 + + int idx_n4 = idx_n/4; + int idx_nl = idx_n%4; + int idx_k4 = idx_k16 * 4; + + if(idx_n4 >= cst.output_slice || idx_k4 >= cst.input_slice) { + return; } -} -struct conv1x1_constants { - int input_size; - int input_slice; - int output_width; - int output_height; - int output_size; - int output_slice; - int output_channel; - int batch; - int block_size; - conv_activation_type activation; - float scale_coef; -}; -#if MNN_METAL_FLOAT32_COMPUTER -typedef simdgroup_float8x8 simdgroup_T8x8; -typedef float FLOAT; -typedef float2 FLOAT2; -typedef float4 FLOAT4; -typedef float4x4 FLOAT4x4; + int block = (cst.input_slice + cst.block_size - 1) / cst.block_size; + + + int bi = idx_k4 / block; + // [N/4, cst.block_size, 2/*scale_bias*/, N4] + FLOAT scale = FLOAT(((const device ftype *)dequantScale)[((idx_n4 * cst.block_size + bi) * 2 + 0) * 4 + idx_nl]) / (FLOAT)cst.scale_coef; + FLOAT dequant_bias = FLOAT(((const device ftype *)dequantScale)[((idx_n4 * cst.block_size + bi) * 2 + 1) * 4 + idx_nl]) / (FLOAT)cst.scale_coef; + + auto xy_wi = wi + (idx_n4 * cst.input_slice + idx_k4) * 4 + idx_nl;// [N/4, K/4, N4, K4] + auto xy_wf = wf + ((idx_n4 * (cst.input_slice/4) + idx_k16) * 4 + idx_nl) * 4;// [N/4, K/4, N4, K4] + + #ifdef W_QUANT_4 + for(int k = 0; k < 4; k++) { + #if W_ALIGN_K16_PROTECT + { + if(idx_k4 + k >= cst.input_slice) { + xy_wf[k] = ftype4(0); + } else { + uchar2 w_int4 = xy_wi[4*k]; // [N/4, K/4, N4, K4] + FLOAT4 w4 = FLOAT4((float)(w_int4[0] >> 4) - 8, (float)(w_int4[0] & 15) - 8, (float)(w_int4[1] >> 4) - 8, (float)(w_int4[1] & 15) - 8); + FLOAT4 res = w4 * scale + dequant_bias; + xy_wf[k] = (ftype4)res; + } + } + #else + { + uchar2 w_int4 = xy_wi[4*k]; 
// [N/4, K/4, N4, K4] + FLOAT4 w4 = FLOAT4((float)(w_int4[0] >> 4) - 8, (float)(w_int4[0] & 15) - 8, (float)(w_int4[1] >> 4) - 8, (float)(w_int4[1] & 15) - 8); + FLOAT4 res = w4 * scale + dequant_bias; + xy_wf[k] = (ftype4)res; + } + #endif + } + #elif defined(W_QUANT_8) + for(int k = 0; k < 4; k++) { + #if W_ALIGN_K16_PROTECT + { + if(idx_k4 + k >= cst.input_slice) { + xy_wf[k] = ftype4(0); + } else { + char4 w_int4 = xy_wi[4*k]; // [N/4, K/4, N4, K4] + FLOAT4 w4 = FLOAT4((float)w_int4[0], (float)w_int4[1], (float)w_int4[2], (float)w_int4[3]); + FLOAT4 res = w4 * scale + dequant_bias; + xy_wf[k] = (ftype4)res; + } + } + #else + { + char4 w_int4 = xy_wi[4*k]; // [N/4, K/4, N4, K4] + FLOAT4 w4 = FLOAT4((float)w_int4[0], (float)w_int4[1], (float)w_int4[2], (float)w_int4[3]); + FLOAT4 res = w4 * scale + dequant_bias; + xy_wf[k] = (ftype4)res; + } + #endif + } + #endif + +} + +kernel void conv1x1_gemm_32x64_split_k_sg(const device ftype4 *in [[buffer(0)]], + device ftype4 *out [[buffer(1)]], + constant conv1x1_constants& cst [[buffer(2)]], + #ifdef W_QUANT_4 + const device MNN::uchar4x2 *wt [[buffer(3)]],// [N/4, K/16, N4, K4, K4] + #elif defined(W_QUANT_8) + const device MNN::char4x4 *wt [[buffer(3)]],// [N/4, K/16, N4, K4, K4] + #else + const device ftype4x4 *wt [[buffer(3)]],// [N/4, K/16, N4, K4, K4] + #endif + const device ftype4 *biasTerms [[buffer(4)]], + #if defined(W_QUANT_4) || defined(W_QUANT_8) + const device ftype *dequantScale [[buffer(5)]], + #endif + uint3 gid [[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + +#ifdef USE_METAL_TENSOR_OPS + +#ifdef LOOP_K64 + /* + // Read: + ftype 0~2047 ---> input: [M32, K64] + ftype 2048~6015 ---> weight: [N64, K64] + // Write: + FLOAT 0~2047 ---> input: [M32, N64] + */ + threadgroup ftype4 sdata[1536] = {0.f}; + + const int K = 64, M = 32, N = 64; + auto tI = tensor, tensor_inline>((threadgroup 
ftype*)sdata, dextents(K, M));//[M, K] + auto tW = tensor, tensor_inline>((threadgroup ftype*)sdata + 2048, dextents(K, N));//[N, K] + + mpp::tensor_ops::matmul2d< + mpp::tensor_ops::matmul2d_descriptor(M, N, K, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mmOps; + + auto cT = mmOps.get_destination_cooperative_tensor(); + + int rx = gid.x;// M/32 + int uz = gid.y;// N/64 + + // A:[16, 8] + int kl = tiitg / 8;// 0~15 + int ml = tiitg % 8;// 0~7 + + // B:[16, 2, 4] + int no = tiitg / 8;// 0~15 + int sl = tiitg % 8;// 0~7 + int kwl = sl / 4;// 0~1 + int nl = sl % 4;// 0~3 + + // C:[32, 4] + int mlc = tiitg / 4;// 0~31 + int nlc = tiitg % 4;// 0~3 + /** input: + threadgroup: [M32, K64] -> [M8, M4, K16, K4] + index: [ml, M4, kl, K4] + each thread: M4K4 + layout: [K/4, M, K4] -> [K/64, K16, M/32, M8, M4, K4] + index : [K/64, kl, rx, ml, M4, K4] + */ + /** weight: + threadgroup: [N64, K64] -> [N16, N4, K2, K32] + index: [no, nl, kwl, K32] + each thread: K2K16 + layout: [N/4, K/16, N4, K4, K4] -> [N/64, N16, K/64, K2, K2, N4, K4, K4] + index : [uz, no, K/64, kwl, K2, nl, K4, K4] + */ + /** scale/offset: + layout:[N/4, block_size, 2, N4] -> [N/64, N16, block_size, 2, N4] + index : [uz, no, block_size, 2, nl] + */ + /** output: + threadgroup: [M32, N64] -> [M32, N4, N16] + index [mlc, nlc, N16] + + each thread: N16 + layout: [N/4, M, N4] -> [N/64, N4, N4, M/32, M32, N4] + index : [uz, nlc, N4, rx, mlc, N4] + */ + + // boundary limit + int idx_m40 = (rx * 8 + ml) * 4 + 0 < cst.input_size * cst.batch ? (rx * 8 + ml) * 4 + 0 : (cst.input_size * cst.batch - 1); + int idx_m41 = (rx * 8 + ml) * 4 + 1 < cst.input_size * cst.batch ? (rx * 8 + ml) * 4 + 1 : (cst.input_size * cst.batch - 1); + int idx_m42 = (rx * 8 + ml) * 4 + 2 < cst.input_size * cst.batch ? (rx * 8 + ml) * 4 + 2 : (cst.input_size * cst.batch - 1); + int idx_m43 = (rx * 8 + ml) * 4 + 3 < cst.input_size * cst.batch ? 
(rx * 8 + ml) * 4 + 3 : (cst.input_size * cst.batch - 1); + + int idx_k4 = 0 * 16 + kl; + auto xy_in0 = in + idx_k4 * cst.input_size * cst.batch + idx_m40;// [K/4, M, K4] + auto xy_in1 = in + idx_k4 * cst.input_size * cst.batch + idx_m41;// [K/4, M, K4] + auto xy_in2 = in + idx_k4 * cst.input_size * cst.batch + idx_m42;// [K/4, M, K4] + auto xy_in3 = in + idx_k4 * cst.input_size * cst.batch + idx_m43;// [K/4, M, K4] + + int idx_wk16 = (0 * 2 + kwl) * 2 + 0; + + int idx_n4 = (uz * 16 + no) < cst.output_slice ? (uz * 16 + no) : (cst.output_slice - 1); + auto xy_wt = wt + (idx_n4 * (cst.input_slice/4) + idx_wk16) * 4 + nl;// [N/4, K/16, N4, K4, K4] + + int idx_sa = (ml * 4 + 0) * 16 + kl; // [M8, M4, K16] x [K4] + int idx_sb = 512 + ((no * 4 + nl) * 2 + kwl) * 8 + 0; // [N16 N4, K2, K8] x [K4] + int block = (cst.input_slice + cst.block_size - 1) / cst.block_size; + + for (int bi=0; bi, tensor_inline>((threadgroup FLOAT*)sdata, dextents(N, M)); // [M , N] + cT.store(tC); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // each thread: N16 + // layout: [N/4, M, N4] -> [N/64, N4, N4, M/32, M32, N4] + // index : [uz, nlc, N4, rx, mlc, N4] + + auto xy_out = out + ((uz * 4 + nlc) * 4 + 0) * cst.output_size * cst.batch + (rx * 32 + mlc);// [N/4, M, N4] + // sdata: [M32, N64] -> [M32, N4, N16] + // index [mlc, nlc, N16] + if((rx * 32 + mlc) < cst.input_size * cst.batch) { + if((uz * 4 + nlc) * 4 < cst.output_slice) { + xy_out[0] = activate(ftype4(((threadgroup FLOAT4*)sdata)[(mlc * 4 + nlc) * 4 + 0] + FLOAT4(biasTerms[(uz * 4 + nlc) * 4])), cst.activation); + } + if((uz * 4 + nlc) * 4 + 1 < cst.output_slice) { + xy_out[cst.output_size * cst.batch] = activate(ftype4(((threadgroup FLOAT4*)sdata)[(mlc * 4 + nlc) * 4 + 1] + FLOAT4(biasTerms[(uz * 4 + nlc) * 4 + 1])), cst.activation); + } + if((uz * 4 + nlc) * 4 + 2 < cst.output_slice) { + xy_out[cst.output_size * cst.batch * 2] = activate(ftype4(((threadgroup FLOAT4*)sdata)[(mlc * 4 + nlc) * 4 + 2] + FLOAT4(biasTerms[(uz 
* 4 + nlc) * 4 + 2])), cst.activation); + } + if((uz * 4 + nlc) * 4 + 3 < cst.output_slice) { + xy_out[cst.output_size * cst.batch * 3] = activate(ftype4(((threadgroup FLOAT4*)sdata)[(mlc * 4 + nlc) * 4 + 3] + FLOAT4(biasTerms[(uz * 4 + nlc) * 4 + 3])), cst.activation); + } + } #else -typedef simdgroup_half8x8 simdgroup_T8x8; -typedef half FLOAT; -typedef half2 FLOAT2; -typedef half4 FLOAT4; -typedef half4x4 FLOAT4x4; -#endif + /* + // Read: + ftype 0~1023 ---> input: [M32, K32] + ftype 1024~3071 ---> weight: [N64, K32] + // Write: + FLOAT 0~2047 ---> input: [M32, N64] + */ + threadgroup FLOAT4 sdata[800] = {0.f}; + const int K = 32, M = 32, N = 64; + auto tI = tensor, tensor_inline>((threadgroup ftype*)sdata, dextents(K, M));//[M, K] + auto tW = tensor, tensor_inline>((threadgroup ftype*)sdata + 1024, dextents(K, N));//[N, K] -#define SIMD_GROUP_WIDTH 32 -#define CONV_UNROLL (4) -#define CONV_UNROLL_L (8) + mpp::tensor_ops::matmul2d< + mpp::tensor_ops::matmul2d_descriptor(M, N, K, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mmOps; -#define INIT_SIMDGROUP_MATRIX(a, b, d) \ - simdgroup_T8x8 sga[a];\ - simdgroup_T8x8 sgb[b];\ - simdgroup_T8x8 sgd[d];\ - for (int i = 0; i < d; i++){\ - sgd[i] = make_filled_simdgroup_matrix(0.f);\ + auto cT = mmOps.get_destination_cooperative_tensor(); + + int rx = gid.x;// M/32 + int uz = gid.y;// N/64 + + // A:[8, 16] + int kl = tiitg / 16;// 0~7 + int ml = tiitg % 16;// 0~15 + + // B:[16, 4, 2] + int no = tiitg / 8;// 0~15 + int sl = tiitg % 8;// 0~7 + int nl = sl / 2;// 0~3 + int kwl = sl % 2;// 0~1 + + // C:[32, 4] + int mlc = tiitg / 4;// 0~31 + int nlc = tiitg % 4;// 0~3 + /** input: + threadgroup: [M32, K32] -> [M16, M2, K8, K4] + index: [ml, M2, kl, K4] + each thread: M2K4 + layout: [K/4, M, K4] -> [K/32, K8, M/32, M16, M2, K4] + index : [K/32, kl, rx, ml, M2, K4] + */ + /** weight: + threadgroup: [N64, K32] -> [N16 N4, K2, K16] + index: [no, nl, kwl, K16] + 
each thread: K4K4 + layout: [N/4, K/16, N4, K4, K4] -> [N/64, N16, K/32, K2, N4, K4, K4] + index : [uz, no, K/32, kwl, nl, K4, K4] + */ + /** scale/offset: + layout:[N/4, block_size, 2, N4] -> [N/64, N16, block_size, 2, N4] + index : [uz, no, block_size, 2, nl] + */ + /** output: + threadgroup: [M32, N64] -> [M32, N4, N16] + index [mlc, nlc, N16] + + each thread: N16 + layout: [N/4, M, N4] -> [N/64, N4, N4, M/32, M32, N4] + index : [uz, nlc, N4, rx, mlc, N4] + */ + + // boundary limit + int idx_m20 = (rx * 16 + ml) * 2 + 0 < cst.input_size * cst.batch ? (rx * 16 + ml) * 2 + 0 : (cst.input_size * cst.batch - 1); + int idx_m21 = (rx * 16 + ml) * 2 + 1 < cst.input_size * cst.batch ? (rx * 16 + ml) * 2 + 1 : (cst.input_size * cst.batch - 1); + + int idx_k4 = 0 * 8 + kl; + auto xy_in0 = in + idx_k4 * cst.input_size * cst.batch + idx_m20;// [K/4, M, K4] + auto xy_in1 = in + idx_k4 * cst.input_size * cst.batch + idx_m21;// [K/4, M, K4] + + int idx_wk16 = 0 * 2 + kwl; + + int idx_n4 = (uz * 16 + no) < cst.output_slice ? 
(uz * 16 + no) : (cst.output_slice - 1); + auto xy_wt = wt + (idx_n4 * (cst.input_slice/4) + idx_wk16) * 4 + nl;// [N/4, K/16, N4, K4, K4] + + int idx_sa = (ml * 2 + 0) * 8 + kl; // [M16, M2, K8] x [K4] + int idx_sb = 256 + ((no * 4 + nl) * 2 + kwl) * 4 + 0; // [N16 N4, K2, K4] x [K4] + int block = (cst.input_slice + cst.block_size - 1) / cst.block_size; + + for (int bi=0; bi> 4); + w_dequant[0][0] = temp[0]; + w_dequant[1][0] = temp[1]; + w_dequant[2][0] = temp[2]; + w_dequant[3][0] = temp[3]; + temp = FLOAT4(uchar4(w_int4[0][0], w_int4[1][0], w_int4[2][0], w_int4[3][0]) & 0x000F); + w_dequant[0][1] = temp[0]; + w_dequant[1][1] = temp[1]; + w_dequant[2][1] = temp[2]; + w_dequant[3][1] = temp[3]; + temp = FLOAT4(uchar4(w_int4[0][1], w_int4[1][1], w_int4[2][1], w_int4[3][1]) >> 4); + w_dequant[0][2] = temp[0]; + w_dequant[1][2] = temp[1]; + w_dequant[2][2] = temp[2]; + w_dequant[3][2] = temp[3]; + temp = FLOAT4(uchar4(w_int4[0][1], w_int4[1][1], w_int4[2][1], w_int4[3][1]) & 0x000F); + w_dequant[0][3] = temp[0]; + w_dequant[1][3] = temp[1]; + w_dequant[2][3] = temp[2]; + w_dequant[3][3] = temp[3]; + + FLOAT4 val = FLOAT4(dequant_bias0 - 8.0 * scale0); + w_dequant = w_dequant * scale0 + FLOAT4x4(val, val, val, val); + + #elif defined(W_QUANT_8) + auto w = xy_wt[z]; + FLOAT4x4 w_fp32 = FLOAT4x4(FLOAT4(w[0]), FLOAT4(w[1]), FLOAT4(w[2]), FLOAT4(w[3])); + for (int i = 0; i < 4; ++i) { + w_dequant[i] = w_fp32[i] * scale0 + dequant_bias0; + } + #else + auto w = xy_wt[z]; + w_dequant = FLOAT4x4((FLOAT4)w[0], (FLOAT4)w[1], (FLOAT4)w[2], (FLOAT4)w[3]); + #endif + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(4) + for (int i = 0; i < 4; ++i) { + ((threadgroup ftype4*)sdata)[idx_sb + i] = ftype4(w_dequant[i]); // K4K4 + } + + ((threadgroup ftype4*)sdata)[idx_sa] = (ftype4)*(xy_in0); + ((threadgroup ftype4*)sdata)[idx_sa + 8] = (ftype4)*(xy_in1); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + + auto sA = tI.slice(0, 0); + auto sB = 
tW.slice(0, 0); + + mmOps.run(sA, sB, cT); + + xy_in0 += 8 * cst.input_size * cst.batch; + xy_in1 += 8 * cst.input_size * cst.batch; + } } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto tC = tensor, tensor_inline>((threadgroup FLOAT*)sdata, dextents(N, M)); // [M , N] + cT.store(tC); -#define SIMDGROUP_MATRIX_FMA(a, b) \ - for(int j=0; j [N/64, N4, N4, M/32, M32, N4] + // index : [uz, nlc, N4, rx, mlc, N4] + + auto xy_out = out + ((uz * 4 + nlc) * 4 + 0) * cst.output_size * cst.batch + (rx * 32 + mlc);// [N/4, M, N4] + // sdata: [M32, N64] -> [M32, N4, N16] + // index [mlc, nlc, N16] + if((rx * 32 + mlc) < cst.input_size * cst.batch) { + if((uz * 4 + nlc) * 4 < cst.output_slice) { + xy_out[0] = activate(ftype4(((threadgroup FLOAT4*)sdata)[(mlc * 4 + nlc) * 4 + 0] + FLOAT4(biasTerms[(uz * 4 + nlc) * 4])), cst.activation); + } + if((uz * 4 + nlc) * 4 + 1 < cst.output_slice) { + xy_out[cst.output_size * cst.batch] = activate(ftype4(((threadgroup FLOAT4*)sdata)[(mlc * 4 + nlc) * 4 + 1] + FLOAT4(biasTerms[(uz * 4 + nlc) * 4 + 1])), cst.activation); + } + if((uz * 4 + nlc) * 4 + 2 < cst.output_slice) { + xy_out[cst.output_size * cst.batch * 2] = activate(ftype4(((threadgroup FLOAT4*)sdata)[(mlc * 4 + nlc) * 4 + 2] + FLOAT4(biasTerms[(uz * 4 + nlc) * 4 + 2])), cst.activation); + } + if((uz * 4 + nlc) * 4 + 3 < cst.output_slice) { + xy_out[cst.output_size * cst.batch * 3] = activate(ftype4(((threadgroup FLOAT4*)sdata)[(mlc * 4 + nlc) * 4 + 3] + FLOAT4(biasTerms[(uz * 4 + nlc) * 4 + 3])), cst.activation); + } } +#endif +#else + /* + // Read: + ftype 0~1023 ---> input: [K4, M32, K8] + ftype 1024~3071 ---> weight: [K4, K8, N64] + ftype 3072~3199 ---> scale/offset: [N64, 2] + // Write: + FLOAT 0~2047 ---> input: [M2, N2, N2, N2, M2, M8, N8] + */ + threadgroup FLOAT4 sdata[800] = {0.f}; + + simdgroup_half8x8 sga[2]; + simdgroup_half8x8 sgb[4]; + simdgroup_float8x8 sgd[8]; + for (int i = 0; i < 8; i++){ + sgd[i] = make_filled_simdgroup_matrix(0.f); + } + + int rx 
= gid.x;// M/32 + int uz = gid.y;// N/64 -#define SIMDGROUP_MATRIX_STORE(ptr, d) \ - for(int i=0; i m + // B:[16, 2, 4] + int no = tiitg / 8;// 0~15 + int sl = tiitg % 8;// 0~7 + int kwl = sl / 4;// 0~1 + int nl = sl % 4;// 0~3 + + /** input: + threadgroup: [K4, M32, K8] -> [K4, M16, M2, K2, K4] + index: [ko, ml, M2, kl, K4] + each thread: M2K4 + layout: [K/4, M, K4] -> [K/32, K4, K2, M/32, M16, M2, K4] + index : [K/32, ko, kl, rx, ml, M2, K4] + */ + /** weight: + threadgroup: [K4, K8, N64] -> [K2, K4, K4, N16, N4] + index: [kwl, K4, K4, no, nl] + each thread: K4K4 + layout: [N/4, K/16, N4, K4, K4] -> [N/64, N16, K/32, K2, N4, K4, K4] + index : [uz, no, K/32, kwl, nl, K4, K4] + */ + /** scale/offset: + layout:[N/4, block_size, 2, N4] -> [N/64, N16, block_size, 2, N4] + index : [uz, no, block_size, 2, nl] + */ + /** output: + threadgroup: [M32, N64] -> [M2, N2, N2, N2, M2, M8, N8] + index [kl, ko/2, ko%2, N2, ml/8, ml%8, N2, N4] + + each thread: N16 + layout: [N/4, M, N4] -> [N/64, N4, N4, M/32, M2, M16, N4] + index : [uz, ko, N4, rx, kl, ml, N4] + */ + + // boundary limit + int idx_m20 = (rx * 16 + ml) * 2 + 0 < cst.input_size * cst.batch ? (rx * 16 + ml) * 2 + 0 : (cst.input_size * cst.batch - 1); + int idx_m21 = (rx * 16 + ml) * 2 + 1 < cst.input_size * cst.batch ? (rx * 16 + ml) * 2 + 1 : (cst.input_size * cst.batch - 1); + + int idx_k4 = 0 * 8 + ko * 2 + kl; + auto xy_in0 = in + idx_k4 * cst.input_size * cst.batch + idx_m20;// [K/4, M, K4] + auto xy_in1 = in + idx_k4 * cst.input_size * cst.batch + idx_m21;// [K/4, M, K4] + + int idx_wk16 = 0 * 2 + kwl; + + int idx_n4 = (uz * 16 + no) < cst.output_slice ? 
(uz * 16 + no) : (cst.output_slice - 1); + auto xy_wt = wt + (idx_n4 * (cst.input_slice/4) + idx_wk16) * 4 + nl;// [N/4, K/16, N4, K4, K4] + + int idx_sa = (ko * 32 + ml * 2 + 0) * 2 + kl; + int idx_sb = 1024 + (kwl * 16 + 0) * 64 + no * 4 + nl; + int block = (cst.input_slice + cst.block_size - 1) / cst.block_size; + + for (int bi=0; bi> 4); + w_dequant[0][0] = temp[0]; + w_dequant[1][0] = temp[1]; + w_dequant[2][0] = temp[2]; + w_dequant[3][0] = temp[3]; + temp = FLOAT4(uchar4(w_int4[0][0], w_int4[1][0], w_int4[2][0], w_int4[3][0]) & 0x000F); + w_dequant[0][1] = temp[0]; + w_dequant[1][1] = temp[1]; + w_dequant[2][1] = temp[2]; + w_dequant[3][1] = temp[3]; + temp = FLOAT4(uchar4(w_int4[0][1], w_int4[1][1], w_int4[2][1], w_int4[3][1]) >> 4); + w_dequant[0][2] = temp[0]; + w_dequant[1][2] = temp[1]; + w_dequant[2][2] = temp[2]; + w_dequant[3][2] = temp[3]; + temp = FLOAT4(uchar4(w_int4[0][1], w_int4[1][1], w_int4[2][1], w_int4[3][1]) & 0x000F); + w_dequant[0][3] = temp[0]; + w_dequant[1][3] = temp[1]; + w_dequant[2][3] = temp[2]; + w_dequant[3][3] = temp[3]; + + FLOAT4 val = FLOAT4(dequant_bias0 - 8.0 * scale0); + w_dequant = w_dequant * scale0 + FLOAT4x4(val, val, val, val); + + #elif defined(W_QUANT_8) + auto w = xy_wt[z]; + FLOAT4x4 w_fp32 = FLOAT4x4(FLOAT4(w[0]), FLOAT4(w[1]), FLOAT4(w[2]), FLOAT4(w[3])); + for (int i = 0; i < 4; ++i) { + w_dequant[i] = w_fp32[i] * scale0 + dequant_bias0; + } + #else + auto w = xy_wt[z]; + w_dequant = FLOAT4x4((FLOAT4)w[0], (FLOAT4)w[1], (FLOAT4)w[2], (FLOAT4)w[3]); + #endif + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(16) + for (int i = 0; i < 16; ++i) { + ((threadgroup ftype*)sdata)[idx_sb + 64*i] = ftype(w_dequant[i/4][i%4]); // K4K4 + } + + ((threadgroup ftype4*)sdata)[idx_sa] = (ftype4)*(xy_in0); + ((threadgroup ftype4*)sdata)[idx_sa + 2] = (ftype4)*(xy_in1); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + /* + A: [K4, M32, K8] -> [K4, M2, M16, K8] + index: [ik, sgitg/2, sga[0~1]] + + 
B: [K4, K8, N64] -> [K4, K8, N2, N32] + index: [ik, sgitg%2, sgb[0~3]] + + sgitg: compute M2 and N2 + */ + threadgroup ftype * sdata_a = (threadgroup ftype*)sdata + 16*8*(sgitg/2); + threadgroup ftype * sdata_b = (threadgroup ftype*)sdata + 1024 + 32*(sgitg%2); + + #pragma unroll(4) + for (short ik = 0; ik < 4; ik++) { + simdgroup_load(sga[0], (const threadgroup ftype*)sdata_a + 256 * ik, 8); + simdgroup_load(sga[1], ((const threadgroup ftype*)sdata_a) + 256 * ik + 64, 8); + + simdgroup_load(sgb[0], ((threadgroup ftype*)sdata_b) + 512 * ik + 0, 64); + simdgroup_load(sgb[1], ((threadgroup ftype*)sdata_b) + 512 * ik + 8, 64); + simdgroup_load(sgb[2], ((threadgroup ftype*)sdata_b) + 512 * ik + 16, 64); + simdgroup_load(sgb[3], ((threadgroup ftype*)sdata_b) + 512 * ik + 24, 64); + + simdgroup_barrier(mem_flags::mem_none); + SIMDGROUP_MATRIX_FMA(2, 4); + + simdgroup_barrier(mem_flags::mem_none); + } + + xy_in0 += 8 * cst.input_size * cst.batch; + xy_in1 += 8 * cst.input_size * cst.batch; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + threadgroup FLOAT * sdata_c = (threadgroup FLOAT*)sdata + 512*sgitg; + + SIMDGROUP_MATRIX_STORE((threadgroup FLOAT*)sdata_c, 8); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // layout: [N/4, M, N4] -> [N/64, N4, N4, M/32, M2, M16, N4] + // index : [uz, ko, N4, rx, kl, ml, N4] + auto xy_out = out + ((uz * 4 + ko) * 4 + 0) * cst.output_size * cst.batch + (rx * 2 + kl) * 16 + ml;// [N/4, M, N4] + + // sdata [M2, N2, N2, N2, M2, M8, N8] + // index [kl, ko/2, ko%2, N2, ml/8, ml%8, N2, N4] + if((rx * 32 + kl * 16 + ml) < cst.input_size * cst.batch) { + if((uz * 4 + ko) * 4 < cst.output_slice) { + xy_out[0] = activate(ftype4(((threadgroup FLOAT4*)sdata)[(((kl * 4 + ko) * 2 + 0) * 16 + ml) * 2] + FLOAT4(biasTerms[(uz * 4 + ko) * 4])), cst.activation); + } + if((uz * 4 + ko) * 4 + 1 < cst.output_slice) { + xy_out[cst.output_size * cst.batch] = activate(ftype4(((threadgroup FLOAT4*)sdata)[(((kl * 4 + ko) * 2 + 0) * 16 + 
ml) * 2 + 1] + FLOAT4(biasTerms[(uz * 4 + ko) * 4 + 1])), cst.activation); + } + if((uz * 4 + ko) * 4 + 2 < cst.output_slice) { + xy_out[cst.output_size * cst.batch * 2] = activate(ftype4(((threadgroup FLOAT4*)sdata)[(((kl * 4 + ko) * 2 + 1) * 16 + ml) * 2] + FLOAT4(biasTerms[(uz * 4 + ko) * 4 + 2])), cst.activation); + } + if((uz * 4 + ko) * 4 + 3 < cst.output_slice) { + xy_out[cst.output_size * cst.batch * 3] = activate(ftype4(((threadgroup FLOAT4*)sdata)[(((kl * 4 + ko) * 2 + 1) * 16 + ml) * 2 + 1] + FLOAT4(biasTerms[(uz * 4 + ko) * 4 + 3])), cst.activation); + } } +#endif +} + kernel void conv1x1_gemm_16x16_sg(const device ftype4 *in [[buffer(0)]], device ftype4 *out [[buffer(1)]], @@ -1027,7 +1924,7 @@ kernel void conv1x1_gemm_16x16_sg(const device ftype4 *in [[buffer(0) ftype 0~127 ---> input: [M16, K8] ftype 128~255 ---> input: [K8, N16] // Write: - ftype 0~255 ---> input: [N2, M2, M8, N8] + FLOAT 0~255 ---> input: [N2, M2, M8, N8] */ threadgroup FLOAT4 sdata[64] = {0.f}; @@ -1047,21 +1944,21 @@ kernel void conv1x1_gemm_16x16_sg(const device ftype4 *in [[buffer(0) auto xy_out = out + (4 * uz + 2 * kl) * cst.output_size * cst.batch + idx_m;// [N/4, M, N4] for (int z = kl; z < cst.input_slice; z += 2) { - sdata[2* rcl + kl] = FLOAT4(*xy_in0); + ((threadgroup ftype4*)sdata)[2* rcl + kl] = (*xy_in0); xy_in0 += 2 * cst.input_size * cst.batch; FLOAT4 w4 = FLOAT4(xy_wt[4 * z]); // [N/4, K/4, N4, K4] - ((threadgroup FLOAT*)sdata)[128 + (kl * 4 + 0) * 16 + rcl] = w4[0]; - ((threadgroup FLOAT*)sdata)[128 + (kl * 4 + 1) * 16 + rcl] = w4[1]; - ((threadgroup FLOAT*)sdata)[128 + (kl * 4 + 2) * 16 + rcl] = w4[2]; - ((threadgroup FLOAT*)sdata)[128 + (kl * 4 + 3) * 16 + rcl] = w4[3]; + ((threadgroup ftype*)sdata)[128 + (kl * 4 + 0) * 16 + rcl] = ftype(w4[0]); + ((threadgroup ftype*)sdata)[128 + (kl * 4 + 1) * 16 + rcl] = ftype(w4[1]); + ((threadgroup ftype*)sdata)[128 + (kl * 4 + 2) * 16 + rcl] = ftype(w4[2]); + ((threadgroup ftype*)sdata)[128 + (kl * 4 + 3) * 16 + rcl] = 
ftype(w4[3]); threadgroup_barrier(mem_flags::mem_threadgroup); - simdgroup_load(sga[0], (const threadgroup FLOAT*)sdata, 8); - simdgroup_load(sga[1], ((const threadgroup FLOAT*)sdata) + 64, 8); - simdgroup_load(sgb[0], ((const threadgroup FLOAT*)sdata) + 128, 16); - simdgroup_load(sgb[1], ((const threadgroup FLOAT*)sdata) + 136, 16); + simdgroup_load(sga[0], (const threadgroup ftype*)sdata, 8); + simdgroup_load(sga[1], ((const threadgroup ftype*)sdata) + 64, 8); + simdgroup_load(sgb[0], ((const threadgroup ftype*)sdata) + 128, 16); + simdgroup_load(sgb[1], ((const threadgroup ftype*)sdata) + 136, 16); SIMDGROUP_MATRIX_FMA(2, 2); threadgroup_barrier(mem_flags::mem_threadgroup); @@ -1095,7 +1992,7 @@ kernel void conv1x1_gemm_32x16_sg(const device ftype4 *in [[buffer(0) ftype 0~255 ---> input: [M32, K8] ftype 256~383 ---> input: [K8, N16] // Write: - ftype 0~511 ---> input: [N2, M4, M8, N8] + FLOAT 0~511 ---> input: [N2, M4, M8, N8] */ threadgroup FLOAT4 sdata[128] = {0.f}; @@ -1122,23 +2019,23 @@ kernel void conv1x1_gemm_32x16_sg(const device ftype4 *in [[buffer(0) auto xy_out1 = out + (4 * uz + 2 * kl) * cst.output_size * cst.batch + idx_m1;// [N/4, M, N4] for (int z = kl; z < cst.input_slice; z += 2) { - sdata[2* rcl + kl] = (FLOAT4)*xy_in0; - sdata[32 + 2* rcl + kl] = (FLOAT4)*xy_in1; + ((threadgroup ftype4*)sdata)[2* rcl + kl] = *xy_in0; + ((threadgroup ftype4*)sdata)[32 + 2* rcl + kl] = *xy_in1; FLOAT4 w4 = FLOAT4(xy_wt[4*z]); // [N/4, K/4, N4, K4] - ((threadgroup FLOAT*)sdata)[256 + (kl * 4 + 0) * 16 + rcl] = w4[0]; - ((threadgroup FLOAT*)sdata)[256 + (kl * 4 + 1) * 16 + rcl] = w4[1]; - ((threadgroup FLOAT*)sdata)[256 + (kl * 4 + 2) * 16 + rcl] = w4[2]; - ((threadgroup FLOAT*)sdata)[256 + (kl * 4 + 3) * 16 + rcl] = w4[3]; + ((threadgroup ftype*)sdata)[256 + (kl * 4 + 0) * 16 + rcl] = ftype(w4[0]); + ((threadgroup ftype*)sdata)[256 + (kl * 4 + 1) * 16 + rcl] = ftype(w4[1]); + ((threadgroup ftype*)sdata)[256 + (kl * 4 + 2) * 16 + rcl] = ftype(w4[2]); + 
((threadgroup ftype*)sdata)[256 + (kl * 4 + 3) * 16 + rcl] = ftype(w4[3]); threadgroup_barrier(mem_flags::mem_threadgroup); - simdgroup_load(sga[0], (const threadgroup FLOAT*)sdata, 8); - simdgroup_load(sga[1], ((const threadgroup FLOAT*)sdata) + 64, 8); - simdgroup_load(sga[2], ((const threadgroup FLOAT*)sdata) + 128, 8); - simdgroup_load(sga[3], ((const threadgroup FLOAT*)sdata) + 192, 8); + simdgroup_load(sga[0], (const threadgroup ftype*)sdata, 8); + simdgroup_load(sga[1], ((const threadgroup ftype*)sdata) + 64, 8); + simdgroup_load(sga[2], ((const threadgroup ftype*)sdata) + 128, 8); + simdgroup_load(sga[3], ((const threadgroup ftype*)sdata) + 192, 8); - simdgroup_load(sgb[0], ((const threadgroup FLOAT*)sdata) + 256, 16); - simdgroup_load(sgb[1], ((const threadgroup FLOAT*)sdata) + 264, 16); + simdgroup_load(sgb[0], ((const threadgroup ftype*)sdata) + 256, 16); + simdgroup_load(sgb[1], ((const threadgroup ftype*)sdata) + 264, 16); SIMDGROUP_MATRIX_FMA(4, 2); threadgroup_barrier(mem_flags::mem_threadgroup); @@ -1173,60 +2070,7 @@ kernel void conv1x1_gemm_32x16_sg(const device ftype4 *in [[buffer(0) )metal"; -const char* gConv1x1SgReduce = R"metal( -#include -#include -using namespace metal; -typedef enum : int { - None = 0, - ReLU = 1, - ReLU6 = 2, -} conv_activation_type; - -inline ftype4 activate(ftype4 value, conv_activation_type type) { - switch (type) { - case ReLU: - return max(value, (ftype4)0); - case ReLU6: - return clamp(value, (ftype4)0, (ftype4)6); - default: // None - return value; - } -} - -struct conv1x1_constants { - int input_size; - int input_slice; - int output_width; - int output_height; - int output_size; - int output_slice; - int output_channel; - int batch; - int block_size; - conv_activation_type activation; - float scale_coef; -}; - -#if MNN_METAL_FLOAT32_COMPUTER -typedef simdgroup_float8x8 simdgroup_T8x8; -typedef float FLOAT; -typedef float2 FLOAT2; -typedef float4 FLOAT4; -typedef float4x4 FLOAT4x4; -#else -typedef simdgroup_half8x8 
simdgroup_T8x8; -typedef half FLOAT; -typedef half2 FLOAT2; -typedef half4 FLOAT4; -typedef half4x4 FLOAT4x4; -#endif - - -#define SIMD_GROUP_WIDTH 32 -#define CONV_UNROLL (4) -#define CONV_UNROLL_L (8) - +const char* gConv1x1WfpSgReduce = R"metal( kernel void conv1x1_z4_sg(const device ftype4 *in [[buffer(0)]], device ftype4 *out [[buffer(1)]], constant conv1x1_constants& cst [[buffer(2)]], @@ -1258,164 +2102,7 @@ kernel void conv1x1_z4_sg(const device ftype4 *in [[buffer(0)]], } )metal"; -const char* gConv1x1W4SgReduce = R"metal( -#include -#include -using namespace metal; -typedef enum : int { - None = 0, - ReLU = 1, - ReLU6 = 2, -} conv_activation_type; - -inline ftype4 activate(ftype4 value, conv_activation_type type) { - switch (type) { - case ReLU: - return max(value, (ftype4)0); - case ReLU6: - return clamp(value, (ftype4)0, (ftype4)6); - default: // None - return value; - } -} - -namespace MNN { - typedef struct uchar4x2 { - private: - uchar2 v[4]; - public: - uchar4x2(uchar2 a) { - v[0] = a; v[1] = a; v[2] = a; v[3] = a; - } - uchar4x2(uchar2 a, uchar2 b, uchar2 c, uchar2 d) { - v[0] = a; v[1] = b; v[2] = c; v[3] = d; - } - - inline thread uchar2& operator[] (const int index) { - return v[index]; - } - inline device uchar2& operator[] (const int index) device { - return v[index]; - } - inline threadgroup uchar2& operator[] (const int index) threadgroup { - return v[index]; - } - - inline const thread uchar2& operator[] (const int index) const { - return v[index]; - } - inline const device uchar2& operator[] (const int index) const device { - return v[index]; - } - inline const threadgroup uchar2& operator[] (const int index) const threadgroup { - return v[index]; - } - - inline explicit operator half4x2() const { - return half4x2( half2(v[0]), half2(v[1]), half2(v[2]), half2(v[3]) ); - } - inline explicit operator half4x2() const device { - return half4x2( half2(v[0]), half2(v[1]), half2(v[2]), half2(v[3]) ); - } - inline explicit operator half4x2() const 
threadgroup { - return half4x2( half2(v[0]), half2(v[1]), half2(v[2]), half2(v[3]) ); - } - - inline explicit operator float4x2() const { - return float4x2( float2(v[0]), float2(v[1]), float2(v[2]), float2(v[3]) ); - } - inline explicit operator float4x2() const device { - return float4x2( float2(v[0]), float2(v[1]), float2(v[2]), float2(v[3]) ); - } - inline explicit operator float4x2() const threadgroup { - return float4x2( float2(v[0]), float2(v[1]), float2(v[2]), float2(v[3]) ); - } - } uchar4x2; - - typedef struct char4x4 { - private: - char4 v[4]; - public: - char4x4(char4 a) { - v[0] = a; v[1] = a; v[2] = a; v[3] = a; - } - char4x4(char4 a, char4 b, char4 c, char4 d) { - v[0] = a; v[1] = b; v[2] = c; v[3] = d; - } - - inline thread char4& operator[] (const int index) { - return v[index]; - } - inline device char4& operator[] (const int index) device { - return v[index]; - } - inline threadgroup char4& operator[] (const int index) threadgroup { - return v[index]; - } - - inline const thread char4& operator[] (const int index) const { - return v[index]; - } - inline const device char4& operator[] (const int index) const device { - return v[index]; - } - inline const threadgroup char4& operator[] (const int index) const threadgroup { - return v[index]; - } - - inline explicit operator half4x4() const { - return half4x4( half4(v[0]), half4(v[1]), half4(v[2]), half4(v[3]) ); - } - inline explicit operator half4x4() const device { - return half4x4( half4(v[0]), half4(v[1]), half4(v[2]), half4(v[3]) ); - } - inline explicit operator half4x4() const threadgroup { - return half4x4( half4(v[0]), half4(v[1]), half4(v[2]), half4(v[3]) ); - } - - inline explicit operator float4x4() const { - return float4x4( float4(v[0]), float4(v[1]), float4(v[2]), float4(v[3]) ); - } - inline explicit operator float4x4() const device { - return float4x4( float4(v[0]), float4(v[1]), float4(v[2]), float4(v[3]) ); - } - inline explicit operator float4x4() const threadgroup { - return 
float4x4( float4(v[0]), float4(v[1]), float4(v[2]), float4(v[3]) ); - } - } char4x4; -} - -struct conv1x1_constants { - int input_size; - int input_slice; - int output_width; - int output_height; - int output_size; - int output_slice; - int output_channel; - int batch; - int block_size; - conv_activation_type activation; - float scale_coef; -}; - -#if MNN_METAL_FLOAT32_COMPUTER -typedef simdgroup_float8x8 simdgroup_T8x8; -typedef float FLOAT; -typedef float2 FLOAT2; -typedef float4 FLOAT4; -typedef float4x4 FLOAT4x4; -#else -typedef simdgroup_half8x8 simdgroup_T8x8; -typedef half FLOAT; -typedef half2 FLOAT2; -typedef half4 FLOAT4; -typedef half4x4 FLOAT4x4; -#endif - -#define SIMD_GROUP_WIDTH 32 -#define CONV_UNROLL (4) -#define CONV_UNROLL_L (8) +const char* gConv1x1WqSgReduce = R"metal( template kernel void conv1x1_gemv_g4mx_wquant_sg(const device ftype4 *in [[buffer(0)]], @@ -1510,6 +2197,11 @@ template [[host_name("conv1x1_gemv_g4m7_wquant_sg")]] kernel kernel_type_t conv1 template [[host_name("conv1x1_gemv_g4m8_wquant_sg")]] kernel kernel_type_t conv1x1_gemv_g4mx_wquant_sg<8>; template [[host_name("conv1x1_gemv_g4m9_wquant_sg")]] kernel kernel_type_t conv1x1_gemv_g4mx_wquant_sg<9>; template [[host_name("conv1x1_gemv_g4m10_wquant_sg")]] kernel kernel_type_t conv1x1_gemv_g4mx_wquant_sg<10>; +template [[host_name("conv1x1_gemv_g4m11_wquant_sg")]] kernel kernel_type_t conv1x1_gemv_g4mx_wquant_sg<11>; +template [[host_name("conv1x1_gemv_g4m12_wquant_sg")]] kernel kernel_type_t conv1x1_gemv_g4mx_wquant_sg<12>; +template [[host_name("conv1x1_gemv_g4m13_wquant_sg")]] kernel kernel_type_t conv1x1_gemv_g4mx_wquant_sg<13>; +template [[host_name("conv1x1_gemv_g4m14_wquant_sg")]] kernel kernel_type_t conv1x1_gemv_g4mx_wquant_sg<14>; +template [[host_name("conv1x1_gemv_g4m15_wquant_sg")]] kernel kernel_type_t conv1x1_gemv_g4mx_wquant_sg<15>; kernel void conv1x1_gemv_g8_wquant_sg(const device ftype4 *in [[buffer(0)]], device ftype4 *out [[buffer(1)]], diff --git 
a/source/backend/metal/MetalAttention.mm b/source/backend/metal/MetalAttention.mm index 4b55940d5e..6acdbeee52 100644 --- a/source/backend/metal/MetalAttention.mm +++ b/source/backend/metal/MetalAttention.mm @@ -14,6 +14,7 @@ #import "MetalAttentionShader.hpp" #include "MNN_generated.h" #include "core/OpCommonUtils.hpp" +#include "MetalKVCacheManager.hpp" #if MNN_METAL_ENABLED #ifdef MNN_SUPPORT_TRANSFORMER_FUSE @@ -21,13 +22,7 @@ namespace MNN { class AttentionBufExecution : public MetalExecution { public: - struct SharedCache { - std::shared_ptr mPastKey; - std::shared_ptr mPastValue; - int mPastLength = 0, mMaxLength = 0, mKv_seq_len = 0; - }; AttentionBufExecution(Backend *backend, bool kv_cache); - virtual ~AttentionBufExecution() = default; virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; @@ -37,24 +32,24 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { return true; } auto exe = new AttentionBufExecution(bn, mKVCache); - exe->mCache = mCache; + exe->mKVCacheManager = mKVCacheManager; *dst = exe; return true; } private: void _init(); - void reallocKVCache(); void compilerShader(const std::vector &inputs); void handleKVAllocMemory(); bool mKVCache; - std::shared_ptr mCache; + std::shared_ptr mKVCacheManager = nullptr; float mScale; - const int mExpandChunk = 64; bool mShortSeq = false; std::shared_ptr mTempQK, mTempSoftMax; int mNumHead = 0, mHeadDim = 0, mValueH = 0, mKvNumHead = 0; int mSeqLen; + // for simd/tensor maxtrix load alignment + int mKvAlignNum = 32; id mKernel_softmax = nil; id mKernel_qk = nil; @@ -70,6 +65,7 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { KVMeta* mMeta; bool mQkSimdReduce = false; bool mQkSimdMatrix = false; + bool mQkTensorMatrix = false; bool mSftmSimdReduce = false; bool mQkvSimdReduce = false; bool mQkvSimdMatrix = false; @@ -79,6 +75,8 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { int 
mBatch, mKvSeqLen, mKvMaxLen; int mQseqSplitNum = 1; std::shared_ptr mTempK, mTempV; + bool mKvInDisk; + }; struct Param { @@ -91,13 +89,13 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { float scale; int max_kv_len; int batch; + int kv_align_len; }; AttentionBufExecution::AttentionBufExecution(Backend *backend, bool kv_cahce) : MetalExecution(backend) , mKVCache(kv_cahce) { _init(); } void AttentionBufExecution::_init() { - mCache.reset(new SharedCache); auto mtbn = static_cast(backend()); auto context = (__bridge MNNMetalContext *)mtbn->context(); mMeta = (KVMeta*)(mtbn->getMetaPtr()); @@ -107,97 +105,15 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { mParamCopy = [context newDeviceBuffer:6 * sizeof(int) access:CPUWriteOnly]; mTempQK.reset(Tensor::createDevice({0, 0})); mTempSoftMax.reset(Tensor::createDevice({0, 0})); -} - -void AttentionBufExecution::reallocKVCache() { - if (!mKVCache) { - return; - } - auto kv_seq_len = mMeta->previous + mMeta->add - mMeta->remove + mMeta->computeReverseSize(); - auto mtbn = static_cast(backend()); - int byte = 4; - if(mtbn->useFp16InsteadFp32()) { - byte = 2; - } - - auto start = mCache->mPastLength - mMeta->remove; - // latest length larger than maxLen - if (kv_seq_len > mCache->mMaxLength) { - - // copy mPastLength including all remove/reverse to new buffer first - auto copy_len = mCache->mPastLength; - bool needCopy = copy_len > 0; - - size_t old_size = mKvNumHead * copy_len * mHeadDim * byte; - size_t old_piece_size = copy_len * byte; - size_t old_piece_stride = mCache->mMaxLength * byte; - - mCache->mMaxLength = kv_seq_len + mExpandChunk; - // past_key: [1, numhead, headdim, maxlen] - auto new_key = Tensor::createDevice({mCache->mMaxLength, mKvNumHead, mHeadDim}); - // past_value: [1, numhead, maxlen, headdim] - auto new_value = Tensor::createDevice({mKvNumHead, mHeadDim, mCache->mMaxLength}); - size_t size = mKvNumHead * mCache->mMaxLength * mHeadDim * byte; 
- auto res = backend()->onAcquireBuffer(new_key, Backend::STATIC); - res = res && backend()->onAcquireBuffer(new_value, Backend::STATIC); - if(!res) { - MNN_ERROR("attition kv cache realloc memory error:%d\n", res); - } - if (needCopy) { - auto newKeyBuf = MetalBackend::getBuffer(new_key); - auto new_key_ptr = (uint8_t*)[newKeyBuf.first contents] + newKeyBuf.second; - auto keyBuf = MetalBackend::getBuffer(mCache->mPastKey.get()); - auto key_ptr = (uint8_t*)[keyBuf.first contents] + keyBuf.second;; - ::memcpy(new_key_ptr, key_ptr, old_size); - - auto newValueBuf = MetalBackend::getBuffer(new_value); - auto new_value_ptr = (uint8_t*)[newValueBuf.first contents] + newValueBuf.second; - auto valueBuf = MetalBackend::getBuffer(mCache->mPastValue.get()); - auto value_ptr = (uint8_t*)[valueBuf.first contents] + valueBuf.second; - for(int i = 0; i < mKvNumHead * mHeadDim; i++) { - ::memcpy(new_value_ptr + i * mCache->mMaxLength * byte, value_ptr + i * old_piece_stride, old_piece_size); - } - } - mCache->mPastLength = (int)start; - - mCache->mPastKey.reset(new_key); - mCache->mPastValue.reset(new_value); - } - // Remove - { - if (0 == mMeta->n_reserve) { - mCache->mPastLength = start; - return; - } - - auto keyBuf = MetalBackend::getBuffer(mCache->mPastKey.get()); - auto key_ptr = (uint8_t*)[keyBuf.first contents] + keyBuf.second; - auto valueBuf = MetalBackend::getBuffer(mCache->mPastValue.get()); - auto value_ptr = (uint8_t*)[valueBuf.first contents] + valueBuf.second; - - auto src_start = start; - // TODO: need to ensure reserve info is sorted - for (int n = 0; n < mMeta->n_reserve; ++n) { - auto begin = mMeta->reserve[2 * n]; - auto length = mMeta->reserve[2 * n + 1]; - // past_key : [mCache->mPastLength, mKvNumHead, mHeadDim] - // past_value : [mKvNumHead, mHeadDim, mCache->mMaxLength] + MNN::MetalKVCacheManager::KVCacheConfig kvconfig; + kvconfig.mKVCacheDir = mtbn->getRuntime()->hint().kvcacheDirPath; + kvconfig.mPrefixCacheDir = 
mtbn->getRuntime()->hint().prefixcacheDirPath; + kvconfig.mExpandChunk = 64; + kvconfig.mKvAlignNum = mKvAlignNum; - auto copy_src_index = src_start + begin; - auto copy_dst_index = start; - for(int i = 0; i < length; i++) { - ::memcpy(key_ptr + (copy_dst_index + i) * mKvNumHead * mHeadDim * byte, key_ptr + (copy_src_index + i) * mKvNumHead * mHeadDim * byte, mKvNumHead * mHeadDim * byte); - } - for(int j = 0; j < mKvNumHead * mHeadDim; j++) { - for(int i = 0; i < length; i++) { - ::memcpy(value_ptr + (j * mCache->mMaxLength + copy_dst_index + i) * byte, value_ptr + (j * mCache->mMaxLength + copy_src_index + i) * byte, byte); - } - } - start += length; - } - mCache->mPastLength = (int)start; - } + mKVCacheManager.reset(new MetalKVCacheManager(backend(), kvconfig)); + mKvInDisk = !kvconfig.mKVCacheDir.empty(); } void AttentionBufExecution::compilerShader(const std::vector &inputs) { @@ -210,27 +126,27 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { std::string group_str = std::to_string(group_size); // Init Kernel - std::string T = "float"; - std::string T4 = "float4"; + std::string ftype = "float"; + std::string ftype4 = "float4"; if (mtbn->useFp16InsteadFp32()) { - T = "half"; - T4 = "half4"; + ftype = "half"; + ftype4 = "half4"; } std::vector qkKeys = { - {"matmul_qk_div_mask", T, group_str} + {"matmul_qk_div_mask", ftype, group_str} }; if(mHeadDim % 4 != 0) { qkKeys.emplace_back("HEAD_DIM_UNALIGNED_4"); } std::vector qkvKeys = { - {"matmul_qkv", T, group_str} + {"matmul_qkv", ftype, group_str} }; if(mQkvSimdReduce) { qkvKeys.emplace_back("SIMD_GROUP_REDUCE"); } std::vector qkPrefillKeys = { - {"matmul_qk_div_mask", T, group_str, "FOR_PREFILL"} + {"matmul_qk_div_mask", ftype, group_str, "FOR_PREFILL"} }; if(mHasMask) { if (mIsAddMask) { @@ -249,14 +165,31 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { qkPrefillKeys.emplace_back("SIMD_GROUP_MATRIX"); } std::vector qkvPrefillKeys = { - {"matmul_qkv", T, 
group_str, "FOR_PREFILL"} + {"matmul_qkv", ftype, group_str, "FOR_PREFILL"} }; if(mQkvSimdMatrix) { qkvPrefillKeys.emplace_back("SIMD_GROUP_MATRIX"); } + if (mtbn->useFp16InsteadFp32()) { + qkPrefillKeys.emplace_back("MNN_METAL_FLOAT16_STORAGE"); + qkvPrefillKeys.emplace_back("MNN_METAL_FLOAT16_STORAGE"); + } std::vector copyPastKeys = { - {"pastkv_copy", T, group_str} + {"pastkv_copy", ftype, group_str} + }; + std::vector shaders = { + "decode_qk", + "decode_qkv", + "prefill_qk", + "prefill_qkv", + "copy" }; + if(mQkTensorMatrix) { + shaders[2] = "prefill_qk_tensor"; + shaders[3] = "prefill_qkv_tensor"; + qkPrefillKeys.emplace_back("USE_METAL_TENSOR_OPS"); + qkvPrefillKeys.emplace_back("USE_METAL_TENSOR_OPS"); + } std::vector> keys = { qkKeys, qkvKeys, @@ -271,13 +204,7 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { gMatMulQKV, gCopyPastKV }; - std::vector shaders = { - "decode_qk", - "decode_qkv", - "prefill_qk", - "prefill_qkv", - "copy" - }; + std::vector> pipelines(keys.size()); for (int i=0; ifindPipeline(keys[i]); @@ -285,11 +212,11 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { // Rebuild Pipeline MTLCompileOptions *option = [[MTLCompileOptions alloc] init]; auto dic = [NSMutableDictionary dictionaryWithCapacity:0]; - [dic setValue:@(keys[i][1].c_str()) forKey:@"T"]; - [dic setValue:@(T4.c_str()) forKey:@"T4"]; + [dic setValue:@(keys[i][1].c_str()) forKey:@"ftype"]; + [dic setValue:@(ftype4.c_str()) forKey:@"ftype4"]; [dic setValue:@(keys[i][2].c_str()) forKey:@"GROUP_SIZE"]; for (int j=3; jmPastLength = mMeta != nullptr ? mMeta->previous : 0; - // kv-cache realloc function - reallocKVCache(); - mCache->mKv_seq_len = mCache->mPastLength + mSeqLen; - mKvSeqLen = mCache->mKv_seq_len; - mKvMaxLen = mCache->mMaxLength; + mKVCacheManager->setPastLength(mMeta != nullptr ? 
mMeta->previous : 0); + + if (mMeta->previous == mMeta->remove) { + mKVCacheManager->onClear(); + mKVCacheManager->onAlloc(mMeta, mSeqLen); + } else { + MNN_ASSERT(mMeta->previous == mKVCacheManager->kvLength()); + mKVCacheManager->onRealloc(mMeta); + } + + mKvSeqLen = mKVCacheManager->kvLength() + mSeqLen; + mKvMaxLen = mKVCacheManager->maxLength(); float useMemorySize = 1.0 * mKvMaxLen / 1024.0 * mSeqLen / 1024.0 * mBatch * mNumHead; // elementSize larger than 32M @@ -360,12 +293,13 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { if (mTempQK->length(1) != qSeqLenPiece * mKvMaxLen) { needMalloc = true; } - mTempQK->setLength(0, mBatch * mNumHead); - mTempQK->setLength(1, qSeqLenPiece * mKvMaxLen); - mTempSoftMax->setLength(0, mBatch * mNumHead); - mTempSoftMax->setLength(1, qSeqLenPiece * mKvMaxLen); if (needMalloc) { + mTempQK->setLength(0, mBatch * mNumHead); + mTempQK->setLength(1, qSeqLenPiece * mKvMaxLen); + mTempSoftMax->setLength(0, mBatch * mNumHead); + mTempSoftMax->setLength(1, qSeqLenPiece * mKvMaxLen); + auto res = backend()->onAcquireBuffer(mTempQK.get(), Backend::STATIC) && backend()->onAcquireBuffer(mTempSoftMax.get(), Backend::STATIC); if (!res) { MNN_ERROR("MNN::Metal: OUT_OF_MEMORY when execute attention metal %d\n", res); @@ -394,9 +328,11 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { mShortSeq = mSeqLen <= 10; mKvNumHead = key->shape()[2]; mKvSeqLen = key->shape()[1]; - mKvMaxLen = ROUND_UP(mKvSeqLen, 4); + // Align to mKvAlignNum, for simd/tensor matrix load + mKvMaxLen = ROUND_UP(mKvSeqLen, mKvAlignNum); if(mKVCache) { + mKVCacheManager->onResize(mKvNumHead, mHeadDim); return NO_ERROR; } @@ -442,9 +378,14 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { Tensor* tempTensorV; handleKVAllocMemory(); - if(mKVCache) { - tempTensorK = mCache->mPastKey.get(); - tempTensorV = mCache->mPastValue.get(); + id tempBufferK; + id tempBufferV; + if(mKvInDisk) { + 
tempBufferK = mKVCacheManager->getKeyBuffer(); + tempBufferV = mKVCacheManager->getValueBuffer(); + } else if(mKVCache) { + tempTensorK = mKVCacheManager->getKeyTensor(); + tempTensorV = mKVCacheManager->getValueTensor(); } else { tempTensorK = mTempK.get(); tempTensorV = mTempV.get(); @@ -453,11 +394,14 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { // whether use simdgroup bool supportSimdReduce = rt->supportSimdGroupReduce(); bool supportSimdMatrix = rt->supportSimdGroupMatrix(); + bool supportTensorMatrix = rt->supportTensorOps(); // decode and thread number not too large mQkSimdReduce = supportSimdReduce && mShortSeq; // loop_k can divide 8, thus avoid branch mQkSimdMatrix = supportSimdMatrix && mSeqLen >= 16 && mHeadDim % 8 == 0; + // 32x32x32 tensor block + mQkTensorMatrix = supportTensorMatrix && mSeqLen >= 128 && mHeadDim % 32 == 0; mSftmSimdReduce = supportSimdReduce; mQkvSimdReduce = supportSimdReduce && mShortSeq && mHeadDim * mNumHead < mKvSeqLen * 32; @@ -477,8 +421,8 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { // current new kv_len copyp[1] = key->shape()[1]; copyp[2] = mKvMaxLen; - copyp[3] = mCache->mPastLength * copyp[0]; - copyp[4] = mCache->mPastLength; + copyp[3] = mKVCacheManager->kvLength() * copyp[0]; + copyp[4] = mKVCacheManager->kvLength(); copyp[5] = mBatch; int copy_line = key->shape()[1]; @@ -486,8 +430,13 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { [encoder setComputePipelineState:pipeline]; MetalBackend::setTensor(key, encoder, 0); MetalBackend::setTensor(value, encoder, 1); - MetalBackend::setTensor(tempTensorK, encoder, 2); - MetalBackend::setTensor(tempTensorV, encoder, 3); + if(mKvInDisk) { + MetalBackend::setBuffer(tempBufferK, 0, encoder, 2); + MetalBackend::setBuffer(tempBufferV, 0, encoder, 3); + } else { + MetalBackend::setTensor(tempTensorK, encoder, 2); + MetalBackend::setTensor(tempTensorV, encoder, 3); + } [encoder 
setBuffer:mParamCopy offset:0 atIndex:4]; std::pair gl; @@ -510,6 +459,7 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { param->q_seq_piece_len = seqLenPiece; param->max_kv_len = mKvMaxLen; param->batch = mBatch; + param->kv_align_len = mKvAlignNum; } for(int seq_idx = 0; seq_idx < mQseqSplitNum; seq_idx++) { @@ -527,7 +477,11 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { // [mBatch, mNumHead, mSeqLen, mKvSeqLen] MetalBackend::setTensor(mTempQK.get(), encoder, 1); // [mKvSeqLen, mBatch, mKvNumHead, mHeadDim] - MetalBackend::setTensor(tempTensorK, encoder, 2); + if(mKvInDisk) { + MetalBackend::setBuffer(tempBufferK, 0, encoder, 2); + } else { + MetalBackend::setTensor(tempTensorK, encoder, 2); + } [encoder setBytes:&seq_idx length:sizeof(seq_idx) atIndex:3]; [encoder setBuffer:mParamQKV offset:0 atIndex:4]; if(mHasMask) { @@ -538,6 +492,8 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { std::pair gl; if(mShortSeq) { gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(seqLenPiece, decode_grid_y / group_size, mKvSeqLen)]; + } else if(mQkTensorMatrix) { + gl = std::make_pair(MTLSizeMake(UP_DIV(seqLenPiece, 32), UP_DIV(mKvSeqLen, 32), decode_grid_y), MTLSizeMake(128, 1, 1)); } else if(mQkSimdMatrix) { gl = std::make_pair(MTLSizeMake(UP_DIV(seqLenPiece, 16), UP_DIV(mKvSeqLen, 16), decode_grid_y), MTLSizeMake(32, 1, 1)); } else { @@ -553,16 +509,19 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { int inside = 1; int outside = mBatch * mNumHead * seqLenPiece; int axis = mKvSeqLen; + int axis_align = ROUND_UP(axis, mKvAlignNum); { auto softmax = (int*)mParamSoftmax.contents; // Inside, axis, outside, plane(invalid) softmax[0] = inside; softmax[1] = axis; softmax[2] = outside; - softmax[3] = 0; + softmax[3] = axis_align; } [encoder setComputePipelineState:mKernel_softmax]; + // [mBatch, mNumHead, mSeqLen, mKvSeqLen] 
MetalBackend::setTensor(mTempQK.get(), encoder, 0); + // [mBatch, mNumHead, mSeqLen, ROUND_UP(mKvSeqLen, mKvAlignNum)] MetalBackend::setTensor(mTempSoftMax.get(), encoder, 1); [encoder setBuffer:mParamSoftmax offset:0 atIndex:2]; @@ -587,29 +546,36 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { pipeline = mKernelPrefill_qkv; } [encoder setComputePipelineState:pipeline]; - // [mBatch, mNumHead, mSeqLen, mKvSeqLen] + // [mBatch, mNumHead, mSeqLen, ROUND_UP(mKvSeqLen, mKvAlignNum)] MetalBackend::setTensor(mTempSoftMax.get(), encoder, 0); // [mBatch, mSeqLen, mNumHead, mHeadDim] MetalBackend::setTensor(outputs[0], encoder, 1); // [mBatch, mKvNumHead, mHeadDim, mMaxSeqLen] - MetalBackend::setTensor(tempTensorV, encoder, 2); + if(mKvInDisk) { + MetalBackend::setBuffer(tempBufferV, 0, encoder, 2); + } else { + MetalBackend::setTensor(tempTensorV, encoder, 2); + } [encoder setBytes:&seq_idx length:sizeof(seq_idx) atIndex:3]; [encoder setBuffer:mParamQKV offset:0 atIndex:4]; std::pair gl; if(mQkvSimdReduce) { gl = std::make_pair(MTLSizeMake(seqLenPiece, mBatch * mNumHead, mHeadDim), MTLSizeMake(32, 1, 1)); + } else if(mQkTensorMatrix){ + gl = std::make_pair(MTLSizeMake(UP_DIV(seqLenPiece, 32), UP_DIV(mHeadDim, 32), mBatch * mNumHead), MTLSizeMake(128, 1, 1)); } else if(mQkvSimdMatrix){ gl = std::make_pair(MTLSizeMake(UP_DIV(seqLenPiece, 16), UP_DIV(mHeadDim, 16), mBatch * mNumHead), MTLSizeMake(32, 1, 1)); } else { gl = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(seqLenPiece, mBatch * mNumHead, mHeadDim)]; } +// printf("mBatch:%d, mNumHead:%d, mSeqLen:%d, mKvSeqLen:%d, mHeadDim:%d\n", mBatch, mNumHead, mSeqLen, mKvSeqLen, mHeadDim); [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second]; } } // Update status if(mKVCache) { - mCache->mPastLength += mSeqLen; + mKVCacheManager->setPastLength(mKVCacheManager->kvLength() + mSeqLen); } return; } diff --git a/source/backend/metal/MetalAttentionShader.hpp 
b/source/backend/metal/MetalAttentionShader.hpp index a28f5ffabd..1304a95e9a 100644 --- a/source/backend/metal/MetalAttentionShader.hpp +++ b/source/backend/metal/MetalAttentionShader.hpp @@ -10,6 +10,10 @@ #ifdef MNN_SUPPORT_TRANSFORMER_FUSE const char* gMatMulDivMask = R"metal( +#ifdef USE_METAL_TENSOR_OPS +#include +#include +#endif #include #include using namespace metal; @@ -23,16 +27,235 @@ struct Param { float scale; int max_kv_len; int batch; + int kv_align_len; }; + +#if MNN_METAL_FLOAT16_STORAGE +typedef simdgroup_half8x8 simdgroup_T8x8; +#else +typedef simdgroup_float8x8 simdgroup_T8x8; +#endif + #define SIMD_GROUP_WIDTH 32 -kernel void prefill_qk(const device T* input0 [[buffer(0)]], - device T* output [[buffer(1)]], - device T* past_key [[buffer(2)]], +#ifdef USE_METAL_TENSOR_OPS +kernel void prefill_qk_tensor(const device ftype4* input0 [[buffer(0)]], + device ftype* output [[buffer(1)]], + device ftype4* past_key [[buffer(2)]], + constant int &seq_idx [[buffer(3)]], + constant Param& param [[buffer(4)]], +#ifdef ADD_MASK + const device ftype* mask [[buffer(5)]], +#elif defined(SET_MASK) + const device int* mask [[buffer(5)]], +#endif + uint3 gid[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]] +) { + /* + // Read: + ftype 0~1023 ---> input: [M32, K32] + ftype 1024~2047 ---> input: [N32, K32] + // Write: + float 0~1023 ---> input: [M32, N32] + */ + threadgroup ftype sdata[2048] = {0.f}; + + const int K = 32, M = 32, N = 32; + const int tb_offset = M * K; + auto tA = tensor, tensor_inline>((threadgroup ftype*)sdata, dextents(K, M));//[M, K] + auto tB = tensor, tensor_inline>((threadgroup ftype*)sdata + tb_offset, dextents(K, N));//[N, K] + + mpp::tensor_ops::matmul2d< + mpp::tensor_ops::matmul2d_descriptor(M, N, K, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mmOps; + 
+ auto cT = mmOps.get_destination_cooperative_tensor(); + + // A: [32, 4] + int ml = tiitg / 4;// 0~31 + int kl = tiitg % 4;// 0~3 + + // B: [32, 4] + int nl = ml; + + // C: [32, 4] + int mcl = ml;// 0~31 + int ncl = kl;// 0~3 + + const int slq = gid.x; // q_seq_len/32 -> M/32 + const int slk = gid.y; // k_seq_len/32 -> N/32 + const int z = gid.z; // head_num * batch + + /** Q: + threadgroup: [M32, K32] -> [M32, K4, K2, K4] + index : [ml, kl, K2, K4] + each thread: K8 + layout: [B0, M, B1, K] -> [B0, M/32, M32, B1, K/32, K4, K2, K4] + index : [z/head_num, slq, ml, z%head_num, K/32, kl, K2, K4] + offset: ((z/head_num * q_seq_len + (slq * 32 + ml)) * head_num + z%head_num) * K/4 + (0 * 4 + kl) * 2 + 0 + */ + /** K: + threadgroup: [N32, K32] -> [M32, K4, K2, K4] + index : [nl, kl, K2, K4] + each thread: K8 + layout: [N, B/G, K] -> [N/32, N32, B/G, K/32, K4, K2, K4] + index : [slk, nl, B/G, K/32, kl, K2, K4] + offset: ((slk * 32 + nl) * B/G + z/G) * K/4 + (0 * 4 + kl) * 2 + 0 + */ + /** output: + threadgroup: [M32, N32] -> [M32, N4, N8] + each thread: N8 + layout: [B, M, N] -> [B, M/32, M32, N/32, N4, N8] + index : [z, slq, mcl, slk, ncl, N8] + offset: (z * q_seq_len + slq * 32 + mcl) * N + (slk * 4 + ncl) * 8 + 0 + */ + + int group = param.group; + int q_seq_len = param.query_seq_len; + int q_seq_piece_len = param.q_seq_piece_len; + int k_seq_len = param.key_seq_len; + int head_num = param.head_num; + int head_dim = param.head_dim; + + const int b = z / head_num; + const int hn = z % head_num; + int zin = hn / param.group; + + int idx_slq = seq_idx * q_seq_piece_len + slq * 32 + ml < q_seq_len ? seq_idx * q_seq_piece_len + slq * 32 + ml : q_seq_len - 1; + int idx_slk = slk * 32 + nl < k_seq_len ? 
slk * 32 + nl : k_seq_len - 1; + // [mBatch, mSeqLen, mNumHead, mHeadDim] + auto A_offset = input0 + ((b * q_seq_len + idx_slq) * head_num + hn) * head_dim / 4 + (0 * 4 + kl) * 2 + 0; + + // [mKvSeqLen, mBatch, mKvNumHead, mHeadDim] + auto B_offset = past_key + ((idx_slk * param.batch + b)* head_num / group + zin) * head_dim / 4 + (0 * 4 + kl) * 2 + 0; + + for(int i = 0; i < head_dim/4; i += 8){ + ((threadgroup ftype4*)sdata)[(ml * 4 + kl) * 2 + 0] = A_offset[i + 0]; + ((threadgroup ftype4*)sdata)[(ml * 4 + kl) * 2 + 1] = A_offset[i + 1]; + + ((threadgroup ftype4*)sdata)[256 + (nl * 4 + kl) * 2 + 0] = B_offset[i + 0]; + ((threadgroup ftype4*)sdata)[256 + (nl * 4 + kl) * 2 + 1] = B_offset[i + 1]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto sA = tA.slice(0, 0); + auto sB = tB.slice(0, 0); + + mmOps.run(sA, sB, cT); + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + auto tC = tensor, tensor_inline>((threadgroup float*)sdata, dextents(N, M)); // [M , N] + cT.store(tC); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // [M32, N4, N8] + auto sindex_base = (mcl * 4 + ncl) * 8 + 0; + + float Vscale = (float)param.scale; + + int base_k_idx = (slk * 4 + ncl) * 8 + 0; + auto xy_out = output + (z * q_seq_piece_len + slq * 32 + mcl) * k_seq_len + base_k_idx + 0; + if(slq * 32 + mcl < q_seq_piece_len && seq_idx * q_seq_piece_len + slq * 32 + mcl < q_seq_len) { + int ori_q_idx = seq_idx * q_seq_piece_len + slq * 32 + mcl; + if(base_k_idx + 0 < k_seq_len) { + auto out0 = ((threadgroup float*)sdata)[sindex_base + 0] * Vscale; + #ifdef ADD_MASK + auto mask_val = (base_k_idx + 0) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (base_k_idx + 0) - k_seq_len + q_seq_len)] : 0.0; + out0 = mask_val + out0; + #elif defined(SET_MASK) + out0 = mask[(ori_q_idx * k_seq_len + (base_k_idx + 0))] == 0 ? 
-FLT_MAX : out0; + #endif + xy_out[0] = out0; + } + if(base_k_idx + 1 < k_seq_len) { + auto out0 = ((threadgroup float*)sdata)[sindex_base + 1] * Vscale; + #ifdef ADD_MASK + auto mask_val = (base_k_idx + 1) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (base_k_idx + 1) - k_seq_len + q_seq_len)] : 0.0; + out0 = mask_val + out0; + #elif defined(SET_MASK) + out0 = mask[(ori_q_idx * k_seq_len + (base_k_idx + 1))] == 0 ? -FLT_MAX : out0; + #endif + xy_out[1] = out0; + } + if(base_k_idx + 2 < k_seq_len) { + auto out0 = ((threadgroup float*)sdata)[sindex_base + 2] * Vscale; + #ifdef ADD_MASK + auto mask_val = (base_k_idx + 2) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (base_k_idx + 2) - k_seq_len + q_seq_len)] : 0.0; + out0 = mask_val + out0; + #elif defined(SET_MASK) + out0 = mask[(ori_q_idx * k_seq_len + (base_k_idx + 2))] == 0 ? -FLT_MAX : out0; + #endif + xy_out[2] = out0; + } + if(base_k_idx + 3 < k_seq_len) { + auto out0 = ((threadgroup float*)sdata)[sindex_base + 3] * Vscale; + #ifdef ADD_MASK + auto mask_val = (base_k_idx + 3) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (base_k_idx + 3) - k_seq_len + q_seq_len)] : 0.0; + out0 = mask_val + out0; + #elif defined(SET_MASK) + out0 = mask[(ori_q_idx * k_seq_len + (base_k_idx + 3))] == 0 ? -FLT_MAX : out0; + #endif + xy_out[3] = out0; + } + if(base_k_idx + 4 < k_seq_len) { + auto out0 = ((threadgroup float*)sdata)[sindex_base + 4] * Vscale; + #ifdef ADD_MASK + auto mask_val = (base_k_idx + 4) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (base_k_idx + 4) - k_seq_len + q_seq_len)] : 0.0; + out0 = mask_val + out0; + #elif defined(SET_MASK) + out0 = mask[(ori_q_idx * k_seq_len + (base_k_idx + 4))] == 0 ? -FLT_MAX : out0; + #endif + xy_out[4] = out0; + } + if(base_k_idx + 5 < k_seq_len) { + auto out0 = ((threadgroup float*)sdata)[sindex_base + 5] * Vscale; + #ifdef ADD_MASK + auto mask_val = (base_k_idx + 5) >= k_seq_len - q_seq_len ? 
mask[(ori_q_idx * q_seq_len + (base_k_idx + 5) - k_seq_len + q_seq_len)] : 0.0; + out0 = mask_val + out0; + #elif defined(SET_MASK) + out0 = mask[(ori_q_idx * k_seq_len + (base_k_idx + 5))] == 0 ? -FLT_MAX : out0; + #endif + xy_out[5] = out0; + } + if(base_k_idx + 6 < k_seq_len) { + auto out0 = ((threadgroup float*)sdata)[sindex_base + 6] * Vscale; + #ifdef ADD_MASK + auto mask_val = (base_k_idx + 6) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (base_k_idx + 6) - k_seq_len + q_seq_len)] : 0.0; + out0 = mask_val + out0; + #elif defined(SET_MASK) + out0 = mask[(ori_q_idx * k_seq_len + (base_k_idx + 6))] == 0 ? -FLT_MAX : out0; + #endif + xy_out[6] = out0; + } + if(base_k_idx + 7 < k_seq_len) { + auto out0 = ((threadgroup float*)sdata)[sindex_base + 7] * Vscale; + #ifdef ADD_MASK + auto mask_val = (base_k_idx + 7) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (base_k_idx + 7) - k_seq_len + q_seq_len)] : 0.0; + out0 = mask_val + out0; + #elif defined(SET_MASK) + out0 = mask[(ori_q_idx * k_seq_len + (base_k_idx + 7))] == 0 ? 
-FLT_MAX : out0; + #endif + xy_out[7] = out0; + } + } + + + +} +#endif + +kernel void prefill_qk(const device ftype* input0 [[buffer(0)]], + device ftype* output [[buffer(1)]], + device ftype* past_key [[buffer(2)]], constant int &seq_idx [[buffer(3)]], constant Param& param [[buffer(4)]], #ifdef ADD_MASK - const device T* mask [[buffer(5)]], + const device ftype* mask [[buffer(5)]], #elif defined(SET_MASK) const device int* mask [[buffer(5)]], #endif @@ -45,6 +268,7 @@ kernel void prefill_qk(const device T* input0 [[buffer(0)]], uint3 gid[[thread_position_in_grid]] #endif ) { + #ifdef SIMD_GROUP_MATRIX /* @@ -52,15 +276,29 @@ kernel void prefill_qk(const device T* input0 [[buffer(0)]], ftype 0~127 ---> input: [M16, K8] ftype 128~255 ---> input: [K8, N16] // Write: - ftype 0~255 ---> input: [N2, M2, M8, N8] + float 0~255 ---> input: [N2, M2, M8, N8] */ - - simdgroup_float8x8 sga[2]; - simdgroup_float8x8 sgb[2]; + threadgroup float sdata[256] = {0.f}; + +#ifdef USE_METAL_TENSOR_OPS + + const int K = 8, M = 16, N = 16; + auto tA = tensor, tensor_inline>((threadgroup ftype*)sdata, dextents(K, M));//[M, K] + auto tB = tensor, tensor_inline>((threadgroup ftype*)sdata + 128, dextents(N, K));//[K, N] + + mpp::tensor_ops::matmul2d< + mpp::tensor_ops::matmul2d_descriptor(M, N, K, false, false, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<1>> mmOps; + + auto cT = mmOps.get_destination_cooperative_tensor(); +#else + simdgroup_T8x8 sga[2]; + simdgroup_T8x8 sgb[2]; simdgroup_float8x8 sgd[4]; for (int i = 0; i < 4; i++){ sgd[i] = make_filled_simdgroup_matrix(0.f); } +#endif int kl = tiitg % 2;// 0~1 int rcl = tiitg / 2;// 0~15 @@ -102,55 +340,74 @@ kernel void prefill_qk(const device T* input0 [[buffer(0)]], const int hn = z % head_num; int zin = hn / param.group; - threadgroup float sdata[256] = {0.f}; - int idx_slq = seq_idx * q_seq_piece_len + slq * 16 + rcl < q_seq_len ? 
seq_idx * q_seq_piece_len + slq * 16 + rcl : q_seq_len - 1; int idx_slk = slk * 16 + rcl < k_seq_len ? slk * 16 + rcl : k_seq_len - 1; // [mBatch, mSeqLen, mNumHead, mHeadDim] auto A_offset = input0 + ((b * q_seq_len + idx_slq) * head_num + hn) * head_dim + (0 * 2 + kl) * 4 + 0; + // [mKvSeqLen, mBatch, mKvNumHead, mHeadDim] auto B_offset = past_key + ((idx_slk * param.batch + b)* head_num / group + zin) * head_dim + 0 * 8 + kl * 4 + 0; for(int i = 0; i < head_dim; i += 8){ - sdata[rcl * 8 + kl * 4 + 0] = A_offset[i + 0]; - sdata[rcl * 8 + kl * 4 + 1] = A_offset[i + 1]; - sdata[rcl * 8 + kl * 4 + 2] = A_offset[i + 2]; - sdata[rcl * 8 + kl * 4 + 3] = A_offset[i + 3]; - - sdata[128 + (kl * 4 + 0) * 16 + rcl] = B_offset[i + 0]; - sdata[128 + (kl * 4 + 1) * 16 + rcl] = B_offset[i + 1]; - sdata[128 + (kl * 4 + 2) * 16 + rcl] = B_offset[i + 2]; - sdata[128 + (kl * 4 + 3) * 16 + rcl] = B_offset[i + 3]; + ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 0] = A_offset[i + 0]; + ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 1] = A_offset[i + 1]; + ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 2] = A_offset[i + 2]; + ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 3] = A_offset[i + 3]; + + ((threadgroup ftype*)sdata)[128 + (kl * 4 + 0) * 16 + rcl] = B_offset[i + 0]; + ((threadgroup ftype*)sdata)[128 + (kl * 4 + 1) * 16 + rcl] = B_offset[i + 1]; + ((threadgroup ftype*)sdata)[128 + (kl * 4 + 2) * 16 + rcl] = B_offset[i + 2]; + ((threadgroup ftype*)sdata)[128 + (kl * 4 + 3) * 16 + rcl] = B_offset[i + 3]; threadgroup_barrier(mem_flags::mem_threadgroup); - simdgroup_load(sga[0], (const threadgroup float*)sdata, 8); - simdgroup_load(sga[1], ((const threadgroup float*)sdata) + 64, 8); +#ifdef USE_METAL_TENSOR_OPS + auto sA = tA.slice(0, 0); + auto sB = tB.slice(0, 0); + + mmOps.run(sA, sB, cT); +#else + simdgroup_load(sga[0], (const threadgroup ftype*)sdata, 8); + simdgroup_load(sga[1], ((const threadgroup ftype*)sdata) + 64, 8); - simdgroup_load(sgb[0], ((const threadgroup 
float*)sdata) + 128, 16); - simdgroup_load(sgb[1], ((const threadgroup float*)sdata) + 136, 16); + simdgroup_load(sgb[0], ((const threadgroup ftype*)sdata) + 128, 16); + simdgroup_load(sgb[1], ((const threadgroup ftype*)sdata) + 136, 16); simdgroup_multiply_accumulate(sgd[0], sga[0], sgb[0], sgd[0]); simdgroup_multiply_accumulate(sgd[1], sga[1], sgb[0], sgd[1]); simdgroup_multiply_accumulate(sgd[2], sga[0], sgb[1], sgd[2]); simdgroup_multiply_accumulate(sgd[3], sga[1], sgb[1], sgd[3]); +#endif threadgroup_barrier(mem_flags::mem_threadgroup); } +#ifdef USE_METAL_TENSOR_OPS + + auto tC = tensor, tensor_inline>((threadgroup float*)sdata, dextents(N, M)); // [M , N] + cT.store(tC); +#else simdgroup_store(sgd[0], (threadgroup float*)sdata, 8); simdgroup_store(sgd[1], (threadgroup float*)sdata + 64, 8); simdgroup_store(sgd[2], (threadgroup float*)sdata + 128, 8); simdgroup_store(sgd[3], (threadgroup float*)sdata + 192, 8); - +#endif + threadgroup_barrier(mem_flags::mem_threadgroup); +#ifdef USE_METAL_TENSOR_OPS + // [M16, N2, N8] + auto sindex_base = (rcl * 2 + kl) * 8 + 0; +#else // [N2, M2, M8, N8] + auto sindex_base = (kl * 16 + rcl) * 8 + 0; +#endif + float Vscale = (float)param.scale; auto xy_out = output + (z * q_seq_piece_len + slq * 16 + rcl) * k_seq_len + slk * 16 + kl * 8 + 0; if(slq * 16 + rcl < q_seq_piece_len && seq_idx * q_seq_piece_len + slq * 16 + rcl < q_seq_len) { int ori_q_idx = seq_idx * q_seq_piece_len + slq * 16 + rcl; if(slk * 16 + kl * 8 + 0 < k_seq_len) { - auto out0 = sdata[(kl * 16 + rcl) * 8 + 0] * Vscale; + auto out0 = ((threadgroup float*)sdata)[sindex_base + 0] * Vscale; #ifdef ADD_MASK auto mask_val = (slk * 16 + kl * 8 + 0) >= k_seq_len - q_seq_len ? 
mask[(ori_q_idx * q_seq_len + (slk * 16 + kl * 8 + 0) - k_seq_len + q_seq_len)] : 0.0; out0 = mask_val + out0; @@ -160,7 +417,7 @@ kernel void prefill_qk(const device T* input0 [[buffer(0)]], xy_out[0] = out0; } if(slk * 16 + kl * 8 + 1 < k_seq_len) { - auto out0 = sdata[(kl * 16 + rcl) * 8 + 1] * Vscale; + auto out0 = ((threadgroup float*)sdata)[sindex_base + 1] * Vscale; #ifdef ADD_MASK auto mask_val = (slk * 16 + kl * 8 + 1) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (slk * 16 + kl * 8 + 1) - k_seq_len + q_seq_len)] : 0.0; out0 = mask_val + out0; @@ -170,7 +427,7 @@ kernel void prefill_qk(const device T* input0 [[buffer(0)]], xy_out[1] = out0; } if(slk * 16 + kl * 8 + 2 < k_seq_len) { - auto out0 = sdata[(kl * 16 + rcl) * 8 + 2] * Vscale; + auto out0 = ((threadgroup float*)sdata)[sindex_base + 2] * Vscale; #ifdef ADD_MASK auto mask_val = (slk * 16 + kl * 8 + 2) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (slk * 16 + kl * 8 + 2) - k_seq_len + q_seq_len)] : 0.0; out0 = mask_val + out0; @@ -180,7 +437,7 @@ kernel void prefill_qk(const device T* input0 [[buffer(0)]], xy_out[2] = out0; } if(slk * 16 + kl * 8 + 3 < k_seq_len) { - auto out0 = sdata[(kl * 16 + rcl) * 8 + 3] * Vscale; + auto out0 = ((threadgroup float*)sdata)[sindex_base + 3] * Vscale; #ifdef ADD_MASK auto mask_val = (slk * 16 + kl * 8 + 3) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (slk * 16 + kl * 8 + 3) - k_seq_len + q_seq_len)] : 0.0; out0 = mask_val + out0; @@ -190,7 +447,7 @@ kernel void prefill_qk(const device T* input0 [[buffer(0)]], xy_out[3] = out0; } if(slk * 16 + kl * 8 + 4 < k_seq_len) { - auto out0 = sdata[(kl * 16 + rcl) * 8 + 4] * Vscale; + auto out0 = ((threadgroup float*)sdata)[sindex_base + 4] * Vscale; #ifdef ADD_MASK auto mask_val = (slk * 16 + kl * 8 + 4) >= k_seq_len - q_seq_len ? 
mask[(ori_q_idx * q_seq_len + (slk * 16 + kl * 8 + 4) - k_seq_len + q_seq_len)] : 0.0; out0 = mask_val + out0; @@ -200,7 +457,7 @@ kernel void prefill_qk(const device T* input0 [[buffer(0)]], xy_out[4] = out0; } if(slk * 16 + kl * 8 + 5 < k_seq_len) { - auto out0 = sdata[(kl * 16 + rcl) * 8 + 5] * Vscale; + auto out0 = ((threadgroup float*)sdata)[sindex_base + 5] * Vscale; #ifdef ADD_MASK auto mask_val = (slk * 16 + kl * 8 + 5) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (slk * 16 + kl * 8 + 5) - k_seq_len + q_seq_len)] : 0.0; out0 = mask_val + out0; @@ -210,7 +467,7 @@ kernel void prefill_qk(const device T* input0 [[buffer(0)]], xy_out[5] = out0; } if(slk * 16 + kl * 8 + 6 < k_seq_len) { - auto out0 = sdata[(kl * 16 + rcl) * 8 + 6] * Vscale; + auto out0 = ((threadgroup float*)sdata)[sindex_base + 6] * Vscale; #ifdef ADD_MASK auto mask_val = (slk * 16 + kl * 8 + 6) >= k_seq_len - q_seq_len ? mask[(ori_q_idx * q_seq_len + (slk * 16 + kl * 8 + 6) - k_seq_len + q_seq_len)] : 0.0; out0 = mask_val + out0; @@ -220,7 +477,7 @@ kernel void prefill_qk(const device T* input0 [[buffer(0)]], xy_out[6] = out0; } if(slk * 16 + kl * 8 + 7 < k_seq_len) { - auto out0 = sdata[(kl * 16 + rcl) * 8 + 7] * Vscale; + auto out0 = ((threadgroup float*)sdata)[sindex_base + 7] * Vscale; #ifdef ADD_MASK auto mask_val = (slk * 16 + kl * 8 + 7) >= k_seq_len - q_seq_len ? 
mask[(ori_q_idx * q_seq_len + (slk * 16 + kl * 8 + 7) - k_seq_len + q_seq_len)] : 0.0; out0 = mask_val + out0; @@ -252,11 +509,11 @@ kernel void prefill_qk(const device T* input0 [[buffer(0)]], const int offset_head = y * head_dim; const int offset_head_kv = (hn / group) * head_dim; // [mBatch, mSeqLen, mNumHead, mHeadDim] - const device T* A_offset = input0 + (b * query_seq_len + q_idx) * offset + offset_head; + const device ftype* A_offset = input0 + (b * query_seq_len + q_idx) * offset + offset_head; float Vscale = (float)param.scale; // [mKvSeqLen, mBatch, mKvNumHead, mHeadDim] - device const T* B_offset = past_key + (z * param.batch + b) * offset / group + offset_head_kv; + device const ftype* B_offset = past_key + (z * param.batch + b) * offset / group + offset_head_kv; const int output_offset = y * param.q_seq_piece_len * key_seq_len; float out0 = 0.0; @@ -274,18 +531,18 @@ kernel void prefill_qk(const device T* input0 [[buffer(0)]], #elif defined(SET_MASK) out0 = mask[((q_idx + 0) * key_seq_len + (z + 0))] == 0 ? 
-FLT_MAX : out0; #endif - output[output_offset + x * key_seq_len + z] = (T)out0; + output[output_offset + x * key_seq_len + z] = (ftype)out0; #endif } -kernel void decode_qk(const device T* input0 [[buffer(0)]], - device T* output [[buffer(1)]], - device T* past_key [[buffer(2)]], +kernel void decode_qk(const device ftype* input0 [[buffer(0)]], + device ftype* output [[buffer(1)]], + device ftype* past_key [[buffer(2)]], // decode actually not compute in block constant int &seq_idx [[buffer(3)]], constant Param& param [[buffer(4)]], #ifdef ADD_MASK - const device T* mask [[buffer(5)]], + const device ftype* mask [[buffer(5)]], #elif defined(SET_MASK) const device int* mask [[buffer(5)]], #endif @@ -311,9 +568,9 @@ kernel void decode_qk(const device T* input0 [[buffer(0)]], const int offset_head_kv = kv_hn * head_dim; // [mBatch, mSeqLen, mNumHead, mHeadDim] - const device T* A_offset = input0 + (b * param.query_seq_len + x) * offset + offset_head; + const device ftype* A_offset = input0 + (b * param.query_seq_len + x) * offset + offset_head; // [mKvSeqLen, mBatch, mKvNumHead, mHeadDim] - device T* Pastkey_offset = past_key + (z * param.batch + b) * offset / group + offset_head_kv; + device ftype* Pastkey_offset = past_key + (z * param.batch + b) * offset / group + offset_head_kv; float Vscale = (float)param.scale; @@ -332,9 +589,9 @@ kernel void decode_qk(const device T* input0 [[buffer(0)]], #else { for(int i = 0; i < head_dim/4; i++){ - float4 B = float4(((const device T4*)Pastkey_offset)[i]); + float4 B = float4(((const device ftype4*)Pastkey_offset)[i]); for(int j = 0; j < group; j++) { - float4 A = float4(((const device T4*)(A_offset + head_dim * j))[i]); + float4 A = float4(((const device ftype4*)(A_offset + head_dim * j))[i]); out[j] += dot(A, B); } } @@ -352,7 +609,7 @@ kernel void decode_qk(const device T* input0 [[buffer(0)]], #elif SET_MASK out[j] = mask_val == 0 ? 
-FLT_MAX : out[j]; #endif - output[((y * group + j) * param.query_seq_len + x) * key_seq_len + z] = (T)out[j]; + output[((y * group + j) * param.query_seq_len + x) * key_seq_len + z] = (ftype)out[j]; } } @@ -371,10 +628,10 @@ struct Param { }; // Key: [batch, kv_seq_len, head_num / group * head_dim] -> [max_kv_len, batch, head_num / group * head_dim] // Value: [batch, kv_seq_len, head_num / group * head_dim] -> [batch, head_num / group * head_dim, max_kv_len] -kernel void copy(const device T* input0 [[buffer(0)]], - const device T* input1 [[buffer(1)]], - device T* output0 [[buffer(2)]], - device T* output1 [[buffer(3)]], +kernel void copy(const device ftype* input0 [[buffer(0)]], + const device ftype* input1 [[buffer(1)]], + device ftype* output0 [[buffer(2)]], + device ftype* output1 [[buffer(3)]], constant Param& param [[buffer(4)]], uint3 gid[[thread_position_in_grid]] ) { @@ -394,7 +651,10 @@ kernel void copy(const device T* input0 [[buffer(0)]], )metal"; const char* gMatMulQKV = R"metal( - +#ifdef USE_METAL_TENSOR_OPS +#include +#include +#endif #include #include using namespace metal; @@ -408,11 +668,154 @@ struct Param { float scale; int max_kv_len; int batch; + int kv_align_len; }; +#if MNN_METAL_FLOAT16_STORAGE +typedef simdgroup_half8x8 simdgroup_T8x8; +#else +typedef simdgroup_float8x8 simdgroup_T8x8; +#endif + +#ifdef USE_METAL_TENSOR_OPS +kernel void prefill_qkv_tensor(const device ftype* input0 [[buffer(0)]], + device ftype4* output [[buffer(1)]], + device ftype4* past_value [[buffer(2)]], + constant int &seq_idx [[buffer(3)]], + constant Param& param [[buffer(4)]], + uint3 gid[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]] +) { + /* + // Read: + ftype 0~1023 ---> input: [M32, K32] + ftype 1024~2047 ---> input: [N32, K32] + // Write: + float 0~1023 ---> input: [M32, N32] + */ + + threadgroup ftype sdata[2048] = {0.f}; + + const int 
K = 32, M = 32, N = 32; + const int tb_offset = M * K; + auto tA = tensor, tensor_inline>((threadgroup ftype*)sdata, dextents(K, M));//[M, K] + auto tB = tensor, tensor_inline>((threadgroup ftype*)sdata + tb_offset, dextents(K, N));//[N, K] + + mpp::tensor_ops::matmul2d< + mpp::tensor_ops::matmul2d_descriptor(M, N, K, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<4>> mmOps; + + auto cT = mmOps.get_destination_cooperative_tensor(); + + // QK:[32, 4] + int ml = tiitg / 4;// 0~31 + int kl = tiitg % 4;// 0~3 + + // V: [32, 4] + int nl = ml;// 0~31 + int kvl = kl;// 0~3 + + // QKV: [32, 4] + int mcl = ml;// 0~31 + int ncl = kl;// 0~3 + + const int sl = gid.x; // q_seq_len/32 -> M/32 + const int hm = gid.y; // head_dim/32 -> N/32 + const int z = gid.z; // head_num * batch + + /** QK: + threadgroup: [M32, K32] -> [M32, K4, K8] + index; [ml, kl, K8] + each thread: K8 + layout: [B, M, K] -> [B, M/32, M32, K/32, K4, K8] + index : [z, sl, ml, K/32, kl, K2, K4] + offset: (z * M + sl * 32 + ml) * K + (0 * 4 + kl) * 8 + 0 + */ + /** V: + threadgroup: [N32, K32] -> [N32, K4, K8] + index; [nl, kvl, K8] + each thread: K8 + layout: [B/G, N, K] -> [B/G, N/32, N32, K/32, K4, K8] + index : [zin, hm, nl, K/32, kvl, K2, K4] + offset: ((zin * head_dim + hm * 32 + nl) * param.max_kv_len/4 + (0 * 4 + kvl) * 2 + 0) + */ + /** output: + threadgroup: [M32, N32] -> [M32, N4, N8] + index: [mcl, ncl, N8] + each thread: N8 + layout: [B0, M, B1, N] -> [B0, M/32, M32, B1, N/32, N4, N8] + index : [B0, sl, mcl, B1, hm, ncl, N2, N4] + offset: ((b * q_seq_len + (sl * 32 + mcl)) * head_num + hn) * N/4 + (hm * 4 + ncl) * 2 + 0 + */ + + int group = param.group; + int q_seq_len = param.query_seq_len; + int q_seq_piece_len = param.q_seq_piece_len; + int value_seq_len = param.key_seq_len; + int align_value_len = ((value_seq_len + param.kv_align_len - 1) / param.kv_align_len) * param.kv_align_len; + + int head_num = param.head_num; + int head_dim = 
param.head_dim; + int b = z / head_num; + int hn = z % head_num; + int zin = b * (head_num / group) + hn / group; + + int idx_qk_sl = sl * 32 + ml < q_seq_piece_len ? (sl * 32 + ml) : q_seq_piece_len - 1; + + auto A_offset = input0 + (z * q_seq_piece_len + idx_qk_sl) * align_value_len + (0 * 4 + kl) * 8 + 0; + auto B_offset = past_value + (zin * head_dim + hm * 32 + nl) * param.max_kv_len / 4 + (0 * 4 + kvl) * 2 + 0; + + + for(int i = 0; i < (value_seq_len+3)/4; i += 8){ + ((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 0] = A_offset[4*i + 0]; + ((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 1] = A_offset[4*i + 1]; + ((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 2] = A_offset[4*i + 2]; + ((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 3] = A_offset[4*i + 3]; + ((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 4] = A_offset[4*i + 4]; + ((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 5] = A_offset[4*i + 5]; + ((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 6] = A_offset[4*i + 6]; + ((threadgroup ftype*)sdata)[(ml * 4 + kl) * 8 + 7] = A_offset[4*i + 7]; + + ((threadgroup ftype4*)sdata)[256 + (nl * 4 + kvl) * 2 + 0] = B_offset[i + 0]; + ((threadgroup ftype4*)sdata)[256 + (nl * 4 + kvl) * 2 + 1] = B_offset[i + 1]; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + auto sA = tA.slice(0, 0); + auto sB = tB.slice(0, 0); + + mmOps.run(sA, sB, cT); + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + auto tC = tensor, tensor_inline>((threadgroup float*)sdata, dextents(N, M)); // [M , N] + cT.store(tC); + threadgroup_barrier(mem_flags::mem_threadgroup); + + // [M32, N4, N2, n4] + auto sindex_base = (mcl * 4 + ncl) * 2 + 0; + + // [M32, N4, N8] + // [mBatch, mSeqLen, mNumHead, mHeadDim] + auto xy_out = output + ((b * q_seq_len + seq_idx * q_seq_piece_len + sl * 32 + mcl) * head_num + hn) * head_dim/4 + (hm * 4 + ncl) * 2 + 0; + if(sl * 32 + mcl < q_seq_piece_len && seq_idx * q_seq_piece_len + sl * 32 + mcl < q_seq_len) { + if((hm * 4 + ncl) * 2 + 0 < head_dim/4) 
{ + xy_out[0] = ftype4(((threadgroup float4*)sdata)[sindex_base + 0]); + } + if((hm * 4 + ncl) * 2 + 1 < head_dim/4) { + xy_out[1] = ftype4(((threadgroup float4*)sdata)[sindex_base + 1]); + } + } + +} +#endif + #define SIMD_GROUP_WIDTH 32 -kernel void prefill_qkv(const device T* input0 [[buffer(0)]], - device T* output [[buffer(1)]], - device T* past_value [[buffer(2)]], +kernel void prefill_qkv(const device ftype* input0 [[buffer(0)]], + device ftype* output [[buffer(1)]], + device ftype* past_value [[buffer(2)]], constant int &seq_idx [[buffer(3)]], constant Param& param [[buffer(4)]], #ifdef SIMD_GROUP_MATRIX @@ -432,19 +835,34 @@ kernel void prefill_qkv(const device T* input0 [[buffer(0)]], // Write: ftype 0~255 ---> input: [N2, M2, M8, N8] */ - - simdgroup_float8x8 sga[2]; - simdgroup_float8x8 sgb[2]; + + threadgroup float sdata[256] = {0.f}; + +#ifdef USE_METAL_TENSOR_OPS + + const int K = 8, M = 16, N = 16; + auto tA = tensor, tensor_inline>((threadgroup ftype*)sdata, dextents(K, M));//[M, K] + auto tB = tensor, tensor_inline>((threadgroup ftype*)sdata + 128, dextents(N, K));//[K, N] + + mpp::tensor_ops::matmul2d< + mpp::tensor_ops::matmul2d_descriptor(M, N, K, false, false, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<1>> mmOps; + + auto cT = mmOps.get_destination_cooperative_tensor(); +#else + simdgroup_T8x8 sga[2]; + simdgroup_T8x8 sgb[2]; simdgroup_float8x8 sgd[4]; for (int i = 0; i < 4; i++){ sgd[i] = make_filled_simdgroup_matrix(0.f); } +#endif - int kl = tiitg % 2;// 0~1 int rcl = tiitg / 2;// 0~15 + int kl = tiitg % 2;// 0~1 - int nl = tiitg % 4;// 0~3 - int kcl = tiitg / 4;// 0~7 + int nl = tiitg / 8;// 0~3 + int kcl = tiitg % 8;// 0~7 const int sl = gid.x; // q_seq_len/16 -> M/16 const int hm = gid.y; // head_dim/16 -> N/16 @@ -476,81 +894,100 @@ kernel void prefill_qkv(const device T* input0 [[buffer(0)]], int q_seq_len = param.query_seq_len; int q_seq_piece_len = param.q_seq_piece_len; int 
value_seq_len = param.key_seq_len; + int align_value_len = ((value_seq_len + param.kv_align_len - 1) / param.kv_align_len) * param.kv_align_len; int head_num = param.head_num; int head_dim = param.head_dim; int b = z / head_num; int hn = z % head_num; int zin = b * (head_num / group) + hn / group; - threadgroup float sdata[256] = {0.f}; - int idx_qk_sl = sl * 16 + rcl < q_seq_piece_len ? (sl * 16 + rcl) : q_seq_piece_len - 1; - auto A_offset = input0 + (z * q_seq_piece_len + idx_qk_sl) * value_seq_len + (0 * 2 + kl) * 4 + 0; + auto A_offset = input0 + (z * q_seq_piece_len + idx_qk_sl) * align_value_len + (0 * 2 + kl) * 4 + 0; auto B_offset = past_value + (zin * head_dim + hm * 16 + nl * 4 + 0) * param.max_kv_len + (0 * 8 + kcl); for(int i = 0; i < value_seq_len; i += 8){ - sdata[rcl * 8 + kl * 4 + 0] = (i + kl * 4 + 0 < value_seq_len) ? A_offset[i + 0] : 0.0; - sdata[rcl * 8 + kl * 4 + 1] = (i + kl * 4 + 1 < value_seq_len) ? A_offset[i + 1] : 0.0; - sdata[rcl * 8 + kl * 4 + 2] = (i + kl * 4 + 2 < value_seq_len) ? A_offset[i + 2] : 0.0; - sdata[rcl * 8 + kl * 4 + 3] = (i + kl * 4 + 3 < value_seq_len) ? A_offset[i + 3] : 0.0; + ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 0] = A_offset[i + 0]; + ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 1] = A_offset[i + 1]; + ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 2] = A_offset[i + 2]; + ((threadgroup ftype*)sdata)[rcl * 8 + kl * 4 + 3] = A_offset[i + 3]; - sdata[128 + kcl * 16 + nl * 4 + 0] = (i + kcl < value_seq_len && hm * 16 + nl * 4 + 0 < head_dim) ? B_offset[i + 0 * param.max_kv_len] : 0.0; - sdata[128 + kcl * 16 + nl * 4 + 1] = (i + kcl < value_seq_len && hm * 16 + nl * 4 + 1 < head_dim) ? B_offset[i + 1 * param.max_kv_len] : 0.0; - sdata[128 + kcl * 16 + nl * 4 + 2] = (i + kcl < value_seq_len && hm * 16 + nl * 4 + 2 < head_dim) ? B_offset[i + 2 * param.max_kv_len] : 0.0; - sdata[128 + kcl * 16 + nl * 4 + 3] = (i + kcl < value_seq_len && hm * 16 + nl * 4 + 3 < head_dim) ? 
B_offset[i + 3 * param.max_kv_len] : 0.0; - + ((threadgroup ftype*)sdata)[128 + kcl * 16 + nl * 4 + 0] = B_offset[i + 0 * param.max_kv_len]; + ((threadgroup ftype*)sdata)[128 + kcl * 16 + nl * 4 + 1] = B_offset[i + 1 * param.max_kv_len]; + ((threadgroup ftype*)sdata)[128 + kcl * 16 + nl * 4 + 2] = B_offset[i + 2 * param.max_kv_len]; + ((threadgroup ftype*)sdata)[128 + kcl * 16 + nl * 4 + 3] = B_offset[i + 3 * param.max_kv_len]; threadgroup_barrier(mem_flags::mem_threadgroup); - simdgroup_load(sga[0], (const threadgroup float*)sdata, 8); - simdgroup_load(sga[1], ((const threadgroup float*)sdata) + 64, 8); +#ifdef USE_METAL_TENSOR_OPS + auto sA = tA.slice(0, 0); + auto sB = tB.slice(0, 0); + + mmOps.run(sA, sB, cT); +#else + simdgroup_load(sga[0], (const threadgroup ftype*)sdata, 8); + simdgroup_load(sga[1], ((const threadgroup ftype*)sdata) + 64, 8); - simdgroup_load(sgb[0], ((const threadgroup float*)sdata) + 128, 16); - simdgroup_load(sgb[1], ((const threadgroup float*)sdata) + 136, 16); + simdgroup_load(sgb[0], ((const threadgroup ftype*)sdata) + 128, 16); + simdgroup_load(sgb[1], ((const threadgroup ftype*)sdata) + 136, 16); simdgroup_multiply_accumulate(sgd[0], sga[0], sgb[0], sgd[0]); simdgroup_multiply_accumulate(sgd[1], sga[1], sgb[0], sgd[1]); simdgroup_multiply_accumulate(sgd[2], sga[0], sgb[1], sgd[2]); simdgroup_multiply_accumulate(sgd[3], sga[1], sgb[1], sgd[3]); +#endif threadgroup_barrier(mem_flags::mem_threadgroup); } +#ifdef USE_METAL_TENSOR_OPS + + auto tC = tensor, tensor_inline>((threadgroup float*)sdata, dextents(N, M)); // [M , N] + cT.store(tC); +#else simdgroup_store(sgd[0], (threadgroup float*)sdata, 8); simdgroup_store(sgd[1], (threadgroup float*)sdata + 64, 8); simdgroup_store(sgd[2], (threadgroup float*)sdata + 128, 8); simdgroup_store(sgd[3], (threadgroup float*)sdata + 192, 8); - +#endif + threadgroup_barrier(mem_flags::mem_threadgroup); +#ifdef USE_METAL_TENSOR_OPS + // [M16, N2, N8] + auto sindex_base = (rcl * 2 + kl) * 8 + 0; +#else 
+ // [N2, M2, M8, N8] + auto sindex_base = (kl * 16 + rcl) * 8 + 0; +#endif + // [N2, M2, M8, N8] // [mBatch, mSeqLen, mNumHead, mHeadDim] auto xy_out = output + ((b * q_seq_len + seq_idx * q_seq_piece_len + sl * 16 + rcl) * head_num + hn) * head_dim + hm * 16 + kl * 8 + 0; if(sl * 16 + rcl < q_seq_piece_len && seq_idx * q_seq_piece_len + sl * 16 + rcl < q_seq_len) { if(hm * 16 + kl * 8 + 0 < head_dim) { - xy_out[0] = sdata[(kl * 16 + rcl) * 8 + 0]; + xy_out[0] = ((threadgroup float*)sdata)[sindex_base + 0]; } if(hm * 16 + kl * 8 + 1 < head_dim) { - xy_out[1] = sdata[(kl * 16 + rcl) * 8 + 1]; + xy_out[1] = ((threadgroup float*)sdata)[sindex_base + 1]; } if(hm * 16 + kl * 8 + 2 < head_dim) { - xy_out[2] = sdata[(kl * 16 + rcl) * 8 + 2]; + xy_out[2] = ((threadgroup float*)sdata)[sindex_base + 2]; } if(hm * 16 + kl * 8 + 3 < head_dim) { - xy_out[3] = sdata[(kl * 16 + rcl) * 8 + 3]; + xy_out[3] = ((threadgroup float*)sdata)[sindex_base + 3]; } if(hm * 16 + kl * 8 + 4 < head_dim) { - xy_out[4] = sdata[(kl * 16 + rcl) * 8 + 4]; + xy_out[4] = ((threadgroup float*)sdata)[sindex_base + 4]; } if(hm * 16 + kl * 8 + 5 < head_dim) { - xy_out[5] = sdata[(kl * 16 + rcl) * 8 + 5]; + xy_out[5] = ((threadgroup float*)sdata)[sindex_base + 5]; } if(hm * 16 + kl * 8 + 6 < head_dim) { - xy_out[6] = sdata[(kl * 16 + rcl) * 8 + 6]; + xy_out[6] = ((threadgroup float*)sdata)[sindex_base + 6]; } if(hm * 16 + kl * 8 + 7 < head_dim) { - xy_out[7] = sdata[(kl * 16 + rcl) * 8 + 7]; + xy_out[7] = ((threadgroup float*)sdata)[sindex_base + 7]; } } @@ -568,6 +1005,7 @@ kernel void prefill_qkv(const device T* input0 [[buffer(0)]], int value_seq_len = param.key_seq_len; int head_num = param.head_num; int head_dim = param.head_dim; + int align_value_len = ((value_seq_len + param.kv_align_len - 1) / param.kv_align_len) * param.kv_align_len; int b = y / head_num; int hn = y % head_num; @@ -578,8 +1016,8 @@ kernel void prefill_qkv(const device T* input0 [[buffer(0)]], const int offset_head = yin * 
head_dim + z; // [mBatch, mNumHead, mSeqLen, mKvSeqLen] - device const T *A_offset = input0 + (y * q_seq_piece_len + x) * value_seq_len; - device const T *B_offset = past_value + offset_head * param.max_kv_len; + device const ftype *A_offset = input0 + (y * q_seq_piece_len + x) * align_value_len; + device const ftype *B_offset = past_value + offset_head * param.max_kv_len; float out = 0.0; for(int i = 0; i < value_seq_len; ++i){ @@ -592,9 +1030,9 @@ kernel void prefill_qkv(const device T* input0 [[buffer(0)]], #endif } -kernel void decode_qkv(const device T* input0 [[buffer(0)]], - device T* output [[buffer(1)]], - device T* past_value [[buffer(2)]], +kernel void decode_qkv(const device ftype* input0 [[buffer(0)]], + device ftype* output [[buffer(1)]], + device ftype* past_value [[buffer(2)]], // docode actually not compute in block constant int &seq_idx [[buffer(3)]], constant Param& param [[buffer(4)]], @@ -621,11 +1059,12 @@ kernel void decode_qkv(const device T* input0 [[buffer(0)]], int yin = b * (head_num / group) + hn / group; int value_seq_len = param.key_seq_len; + int align_value_len = ((value_seq_len + param.kv_align_len - 1) / param.kv_align_len) * param.kv_align_len; const int offset_head = (yin * head_dim + z) * param.max_kv_len; - device const T *A_offset = input0 + (y * q_seq_len + x) * value_seq_len; - device T *Pastvalue_offset = past_value + offset_head; + device const ftype *A_offset = input0 + (y * q_seq_len + x) * align_value_len; + device ftype *Pastvalue_offset = past_value + offset_head; float out = 0; #ifdef SIMD_GROUP_REDUCE @@ -638,7 +1077,7 @@ kernel void decode_qkv(const device T* input0 [[buffer(0)]], out = simd_sum(out); if(tiisg == 0) { // [mBatch, mSeqLen, mNumHead, mHeadDim] - output[((b * q_seq_len + x) * head_num + hn) * head_dim + z] = (T)out; + output[((b * q_seq_len + x) * head_num + hn) * head_dim + z] = (ftype)out; } #else for(int i = 0; i < value_seq_len; i++){ @@ -647,7 +1086,7 @@ kernel void decode_qkv(const device T* 
input0 [[buffer(0)]], out += A * B; } - output[((b * q_seq_len + x) * head_num + hn) * head_dim + z] = (T)out; + output[((b * q_seq_len + x) * head_num + hn) * head_dim + z] = (ftype)out; #endif } )metal"; @@ -659,7 +1098,7 @@ struct softmax_shape { int inside_size; int axis_length; int outside_size; - int flat_length; + int axis_align_length; }; #define SIMD_GROUP_WIDTH 32 @@ -674,9 +1113,10 @@ kernel void softmax_plane_sg(const device ftype *in [[buffer(0)]], // simdgroup compute axis data if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return; - auto axis_off = gid.y * s.axis_length * s.inside_size + gid.x; - auto axis_in = in + axis_off; - auto axis_out = out + axis_off; + auto in_offset = gid.y * s.axis_length * s.inside_size + gid.x; + auto out_offset = gid.y * s.axis_align_length * s.inside_size + gid.x; + auto axis_in = in + in_offset; + auto axis_out = out + out_offset; // get max float max1 = -FLT_MAX; @@ -693,8 +1133,8 @@ kernel void softmax_plane_sg(const device ftype *in [[buffer(0)]], sum1 = simd_sum(sum1); // output - for (int i = tiisg; i < s.axis_length; i+=SIMD_GROUP_WIDTH) { - axis_out[i * s.inside_size] = ftype(exp(float(axis_in[i * s.inside_size]) - float(max1)) / sum1); + for (int i = tiisg; i < s.axis_align_length; i+=SIMD_GROUP_WIDTH) { + axis_out[i * s.inside_size] = i >= s.axis_length ? 
ftype(0.0) : ftype(exp(float(axis_in[i * s.inside_size]) - float(max1)) / sum1); } } diff --git a/source/backend/metal/MetalBackend.hpp b/source/backend/metal/MetalBackend.hpp index ecbfacb636..7f35782028 100644 --- a/source/backend/metal/MetalBackend.hpp +++ b/source/backend/metal/MetalBackend.hpp @@ -42,6 +42,9 @@ class MetalRuntime : public Runtime { bool supportSimdGroupMatrix() { return mSimdGroupMatrix; } + bool supportTensorOps() { + return mTensorOps; + } void setGpuMode(const int cl_mode_num); void setCommandQueue(id queue, bool userSync); id getCommandQueue() const { @@ -109,6 +112,7 @@ class MetalRuntime : public Runtime { private: bool mSimdGroupReduce; bool mSimdGroupMatrix; + bool mTensorOps; }; @@ -162,6 +166,7 @@ class MetalBackend : public Backend { static void setTensor(const MNN::Tensor* tensor, id encoder, int index); static void setMem(const MemChunk& chunk, id encoder, int index); static uint8_t* getMemPtr(const MemChunk& chunk); + static void setBuffer(id buffer, int offset, id encoder, int index); static std::pair, int> getBuffer(const MNN::Tensor* tensor); size_t getTensorSizeInBytes(const Tensor* tensor) const; virtual bool onSelectDynamicAllocator(int index, int maxIndex) override; @@ -170,7 +175,7 @@ class MetalBackend : public Backend { void returnConstBuffer(id buffer) const; id makeComputePipelineWithSourceOption(const char* csource, const char* cname, MTLCompileOptions *options) const; public: - MetalBackend(std::shared_ptr staticMem, const MetalRuntime* runtime, bool usefp16AsFp32, BackendConfig::MemoryMode mode); + MetalBackend(const MetalRuntime* runtime, bool usefp16AsFp32, BackendConfig::MemoryMode mode); virtual ~MetalBackend(); virtual Runtime* getRuntime() override { return (Runtime*)mRuntime; @@ -214,7 +219,7 @@ class MetalBackend : public Backend { BufferAllocator* getBufferPool() const; EagerBufferAllocator *getStaticBufferPool() const { - return mStaticBufferPool.get(); + return mRuntime->mStaticAllocator.get(); } id 
getCommandBufferForBufferCopy() const; @@ -260,7 +265,6 @@ class MetalBackend : public Backend { mutable id mComputeEncoder = nil; std::shared_ptr mBufferPool; std::shared_ptr mBufferPoolShapeImmutable; - std::shared_ptr mStaticBufferPool; private: void _resetDynamicMemory() const; diff --git a/source/backend/metal/MetalBackend.mm b/source/backend/metal/MetalBackend.mm index 9f285bee84..79f52ff2dc 100644 --- a/source/backend/metal/MetalBackend.mm +++ b/source/backend/metal/MetalBackend.mm @@ -11,6 +11,10 @@ #import #define METAL_CONST_BUFFER_LIMIT 128 #define METAL_SEPERATE_MAX_COUNT 2 +// overload of MTLGPUFamilyMetal3/MTLGPUFamilyMetal4 (not available in some environments) +#define MTLGPUFamilyMetal3_MNN 5001 +#define MTLGPUFamilyMetal4_MNN 5002 + #if MNN_METAL_ENABLED #include #import "backend/metal/MNNMetalContext.h" @@ -72,14 +76,13 @@ static void _MetalApplyTensor(uint8_t* host, size_t offset, Tensor* t) { map->insert(std::make_pair(t, c)); } -MetalBackend::MetalBackend(std::shared_ptr staticMem, const MetalRuntime* runtime, bool usefp16AsFp32, BackendConfig::MemoryMode mode) : Backend(MNN_FORWARD_METAL), +MetalBackend::MetalBackend(const MetalRuntime* runtime, bool usefp16AsFp32, BackendConfig::MemoryMode mode) : Backend(MNN_FORWARD_METAL), mEmptyMem(nil) { mRuntime = runtime; auto ctx = (__bridge MNNMetalContext *)runtime->context(); mBufferPool.reset(runtime->createDynamicAllocator(0, false)); mCurrentAllocator = mBufferPool.get(); - mStaticBufferPool = staticMem; mUseFloatAsFp16 = usefp16AsFp32; mMemoryMode = mode; mIsIphone = ctx.isIphone; @@ -182,8 +185,8 @@ MemChunk chunk() override { BufferAllocator* allocator = nullptr; switch (storageType) { case Backend::STATIC: { - buffer = mStaticBufferPool->alloc(size, false); - allocator = mStaticBufferPool.get(); + buffer = mRuntime->mStaticAllocator->alloc(size, false); + allocator = mRuntime->mStaticAllocator.get(); } break; case Backend::DYNAMIC: { buffer = mCurrentAllocator->alloc(size, false); @@ -812,7 
+815,9 @@ static void _execute(id encoder, const MetalBackend::C uint8_t* MetalBackend::getMemPtr(const MemChunk& chunk) { return (uint8_t*)((MetalRuntimeAllocator::MetalBufferAlloc *)chunk.first)->getBuffer().contents + chunk.second; } - +void MetalBackend::setBuffer(id buffer, int offset, id encoder, int index) { + [encoder setBuffer:buffer offset:offset atIndex:index]; +} std::pair, int> MetalBackend::getBuffer(const MNN::Tensor* tensor) { return std::make_pair(((MetalRuntimeAllocator::MetalBufferAlloc *)tensor->deviceId())->getBuffer(), TensorUtils::getDescribe(tensor)->extra.offset); } @@ -1002,8 +1007,23 @@ static void _execute(id encoder, const MetalBackend::C auto ctx = (__bridge MNNMetalContext *)mContext; std::shared_ptr allocator(new MetalRuntimeAllocator([ctx device])); mSimdGroupReduce = [[ctx device] supportsFamily:MTLGPUFamilyApple7]; - mSimdGroupReduce |= [[ctx device] supportsFamily:MTLGPUFamilyMetal3]; + mSimdGroupReduce |= [[ctx device] supportsFamily:(MTLGPUFamily)MTLGPUFamilyMetal3_MNN]; mSimdGroupMatrix = [[ctx device] supportsFamily:MTLGPUFamilyApple7]; + // Metal4 Support M1/A14 and later chips + mTensorOps = [[ctx device] supportsFamily:(MTLGPUFamily)MTLGPUFamilyMetal4_MNN]; + + // AI TensorCore device support from M5/A19 + bool noAICoreDevice = [[[ctx device] name] containsString:@"M1"] || \ + [[[ctx device] name] containsString:@"M2"] || \ + [[[ctx device] name] containsString:@"M3"] || \ + [[[ctx device] name] containsString:@"M4"] || \ + [[[ctx device] name] containsString:@"A14"] || \ + [[[ctx device] name] containsString:@"A15"] || \ + [[[ctx device] name] containsString:@"A16"] || \ + [[[ctx device] name] containsString:@"A17"] || \ + [[[ctx device] name] containsString:@"A18"]; + mTensorOps = mTensorOps && !noAICoreDevice; +// MNN_PRINT("Metal device name %s, open tensor: %d\n\n", [[[ctx device] name] UTF8String], mTensorOps); mStaticAllocator.reset(new EagerBufferAllocator(allocator)); mDynamic.resize(METAL_SEPERATE_MAX_COUNT); for 
(auto& buf : mDynamic) { @@ -1200,6 +1220,7 @@ virtual MemChunk onAlloc(size_t size, size_t align) override { auto mem = mOrigin->onAlloc(size, align); MNN_ASSERT(mem.second == 0); id buffer = [mDevice newBufferWithBytesNoCopy:mem.first length:size options:MTLResourceStorageModeShared deallocator:nil]; + auto wrap = new MetalRuntimeAllocator::MetalBufferAlloc(buffer); return MemChunk((void *)wrap, 0); } @@ -1236,7 +1257,7 @@ virtual void onRelease(MemChunk chunk) override { memory = config->memory; } bool useFp16AsFp32 = precision != BackendConfig::Precision_High; - auto backend = new MetalBackend(mStaticAllocator, this, useFp16AsFp32, memory); + auto backend = new MetalBackend(this, useFp16AsFp32, memory); backend->setMetaPtr(pMeta); return backend; } diff --git a/source/backend/metal/MetalBinary.mm b/source/backend/metal/MetalBinary.mm index 6854a55782..20ae0fafdc 100755 --- a/source/backend/metal/MetalBinary.mm +++ b/source/backend/metal/MetalBinary.mm @@ -90,7 +90,7 @@ #include #include using namespace metal; -kernel void main0(const device T0 *in0 [[buffer(0)]], +kernel void binary(const device T0 *in0 [[buffer(0)]], const device T1 *in1 [[buffer(1)]], device T2 *out [[buffer(2)]], constant int4& s [[buffer(3)]], uint gid [[thread_position_in_grid]]) { if ((int)gid >= s.z) return; auto V0 = in0[s.x * int(gid)]; @@ -132,7 +132,7 @@ kernel void main0(const device T0 *in0 [[buffer(0)]], @"T2" : T2, @"CUSTOM" : custom, }; - pipeline = mtbn->makeComputePipelineWithSourceOption(gBinaryTemplate, "main0", compileOptions); + pipeline = mtbn->makeComputePipelineWithSourceOption(gBinaryTemplate, "binary", compileOptions); mtbn->runtime()->insertPipeline(keys, pipeline); } if (nil == pipeline) { diff --git a/source/backend/metal/MetalCast.mm b/source/backend/metal/MetalCast.mm index b44e622c3c..efc8d80c29 100755 --- a/source/backend/metal/MetalCast.mm +++ b/source/backend/metal/MetalCast.mm @@ -17,7 +17,7 @@ R"glsl( #include using namespace metal; - kernel void 
main0(const device T0 *in [[buffer(0)]], + kernel void cast(const device T0 *in [[buffer(0)]], device T1 *out [[buffer(1)]], device uint4& s [[buffer(2)]], uint3 gid [[thread_position_in_grid]]) { @@ -206,7 +206,7 @@ static DataType _mapDataType(DataType src) { @"T1" : T1, @"TRANSOFRM" : TRANSOFRM }; - pipeline = mtbn->makeComputePipelineWithSourceOption(gCastTemplate, "main0", compileOptions); + pipeline = mtbn->makeComputePipelineWithSourceOption(gCastTemplate, "cast", compileOptions); mtbn->runtime()->insertPipeline(keys, pipeline); } if (nil == pipeline) { @@ -223,7 +223,7 @@ static DataType _mapDataType(DataType src) { #include #include using namespace metal; -kernel void main0(device T* uOutput [[buffer(0)]], const device int* uSelect [[buffer(1)]], const device T* uInput0 [[buffer(2)]], const device T* uInput1 [[buffer(3)]], constant int4& uStride [[buffer(4)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) +kernel void select(device T* uOutput [[buffer(0)]], const device int* uSelect [[buffer(1)]], const device T* uInput0 [[buffer(2)]], const device T* uInput1 [[buffer(3)]], constant int4& uStride [[buffer(4)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) { int i = int(gl_GlobalInvocationID.x); if (i < uStride.w) @@ -288,7 +288,7 @@ virtual void onEncode(const std::vector &inputs, const std::vectormakeComputePipelineWithSourceOption(gSelectTemplate, "main0", compileOptions); + pipeline = mtbn->makeComputePipelineWithSourceOption(gSelectTemplate, "select", compileOptions); mtbn->runtime()->insertPipeline(keys, pipeline); } if (nil == pipeline) { @@ -325,7 +325,7 @@ virtual void onEncode(const std::vector &inputs, const std::vector #include using namespace metal; -kernel void main0(device T* uOutput [[buffer(0)]], const device T* uStart [[buffer(1)]], const device T* uDelta [[buffer(2)]], constant int4& uSize [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) +kernel void range(device T* uOutput [[buffer(0)]], 
const device T* uStart [[buffer(1)]], const device T* uDelta [[buffer(2)]], constant int4& uSize [[buffer(3)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) { int i = int(gl_GlobalInvocationID.x); if(i < uSize.w) { @@ -348,7 +348,7 @@ kernel void main0(device T* uOutput [[buffer(0)]], const device T* uStart [[buff compileOptions.preprocessorMacros = @{ @"T" : T, }; - pipeline = mtbn->makeComputePipelineWithSourceOption(gRangeTemplate, "main0", compileOptions); + pipeline = mtbn->makeComputePipelineWithSourceOption(gRangeTemplate, "range", compileOptions); mtbn->runtime()->insertPipeline(keys, pipeline); } if (nil == pipeline) { diff --git a/source/backend/metal/MetalConvolution1x1.hpp b/source/backend/metal/MetalConvolution1x1.hpp index bda5a483fa..866686957f 100644 --- a/source/backend/metal/MetalConvolution1x1.hpp +++ b/source/backend/metal/MetalConvolution1x1.hpp @@ -26,6 +26,10 @@ class MetalConvolution1x1 : public MetalConvolutionCommon { MetalConvolution1x1(Backend *backend, const MNN::Op *op, std::shared_ptr weight, std::shared_ptr bias, std::shared_ptr dequantScale, int dequantBits, float scaleCoef); id mPipeline; std::pair mThreads; + id mDequantPipeline; + std::pair mDequantThreads; + bool mPreDequantWeight = false; + std::shared_ptr mTempWeight; }; } // namespace MNN diff --git a/source/backend/metal/MetalConvolution1x1.mm b/source/backend/metal/MetalConvolution1x1.mm index eed4142280..fea0cca036 100644 --- a/source/backend/metal/MetalConvolution1x1.mm +++ b/source/backend/metal/MetalConvolution1x1.mm @@ -102,34 +102,55 @@ param->scale_coef = mScaleCoef; int area = ob * ow * oh; // basic marco info + std::string ftype = "float"; std::string ftype2 = "float2"; std::string ftype4 = "float4"; + std::string ftype2x4 = "float2x4"; std::string ftype4x4 = "float4x4"; if (backend->useFp16InsteadFp32()) { + ftype = "half"; ftype2 = "half2"; ftype4 = "half4"; + ftype2x4 = "half2x4"; ftype4x4 = "half4x4"; } MTLCompileOptions *option = 
[[MTLCompileOptions alloc] init]; - auto dic = [NSMutableDictionary dictionaryWithCapacity:0]; - [dic setValue:@(ftype2.c_str()) forKey:@"ftype2"]; - [dic setValue:@(ftype4.c_str()) forKey:@"ftype4"]; - [dic setValue:@(ftype4x4.c_str()) forKey:@"ftype4x4"]; - [dic setValue:@"1" forKey:@"MNN_METAL_FLOAT32_COMPUTER"];; - - if(mDequantBits == 4) { - [dic setValue:@"1" forKey:@"W_QUANT_4"]; - } else if(mDequantBits == 8) { - [dic setValue:@"1" forKey:@"W_QUANT_8"]; + auto baseDic = [NSMutableDictionary dictionaryWithCapacity:0]; + [baseDic setValue:@(ftype.c_str()) forKey:@"ftype"]; + [baseDic setValue:@(ftype2.c_str()) forKey:@"ftype2"]; + [baseDic setValue:@(ftype4.c_str()) forKey:@"ftype4"]; + [baseDic setValue:@(ftype2x4.c_str()) forKey:@"ftype2x4"]; + [baseDic setValue:@(ftype4x4.c_str()) forKey:@"ftype4x4"]; + [baseDic setValue:@"1" forKey:@"MNN_METAL_FLOAT32_COMPUTER"]; + if (backend->useFp16InsteadFp32()) { + [baseDic setValue:@"1" forKey:@"MNN_METAL_FLOAT16_STORAGE"]; } - - option.preprocessorMacros = dic; std::vector baseKeys = {ftype4, "MNN_METAL_FLOAT32_COMPUTER"}; MetalRuntime* rt = (MetalRuntime *)backend->runtime(); + std::string basicShaderPrefix = gBasicConvPrefix; + + // if M is small, dequant weight in shader + // if device not support simdgroup matrix, only support dequant in shader + bool dequantInShader = (area < 128) || !(rt->supportSimdGroupMatrix()); + mPreDequantWeight = false; + #ifdef MNN_LOW_MEMORY - if (mDequantScaleBias.get()) { + if (mDequantScaleBias.get() && dequantInShader) { + //printf("inner dequant MNK: %d %d %d %d\n", area, oc, ic, blockSize); + + std::string sgmWqShader = gConv1x1WqSgMatrix; + std::string sgrWqShader = gConv1x1WqSgReduce; + + NSMutableDictionary *dic = [baseDic mutableCopy]; + if(mDequantBits == 4) { + [dic setValue:@"1" forKey:@"W_QUANT_4"]; + } else if(mDequantBits == 8) { + [dic setValue:@"1" forKey:@"W_QUANT_8"]; + } + option.preprocessorMacros = dic; + NSUInteger gid_x = UP_DIV(ow * oh, 4); NSUInteger gid_y 
= oc_4; NSUInteger gid_z = ob; @@ -147,7 +168,8 @@ } if(rt->supportSimdGroupReduce() && area <= short_seq) { baseKeys.emplace_back("conv1x1_wquant_sg_reduce"); - + + std::string sgrWqStr = basicShaderPrefix + sgrWqShader; if(area > 1) { auto keys = baseKeys; int piece = 1; @@ -168,7 +190,7 @@ keys.emplace_back(kernel_name); auto pipeline = rt->findPipeline(keys); if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gConv1x1W4SgReduce, kernel_name.c_str(), option); + pipeline = backend->makeComputePipelineWithSourceOption(sgrWqStr.c_str(), kernel_name.c_str(), option); rt->insertPipeline(keys, pipeline); } mPipeline = pipeline; @@ -179,7 +201,7 @@ keys.emplace_back("conv1x1_gemv_g16_wquant_sg"); auto pipeline = rt->findPipeline(keys); if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gConv1x1W4SgReduce, "conv1x1_gemv_g16_wquant_sg", option); + pipeline = backend->makeComputePipelineWithSourceOption(sgrWqStr.c_str(), "conv1x1_gemv_g16_wquant_sg", option); rt->insertPipeline(keys, pipeline); } mPipeline = pipeline; @@ -189,7 +211,7 @@ keys.emplace_back("conv1x1_gemv_g8_wquant_sg"); auto pipeline = rt->findPipeline(keys); if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gConv1x1W4SgReduce, "conv1x1_gemv_g8_wquant_sg", option); + pipeline = backend->makeComputePipelineWithSourceOption(sgrWqStr.c_str(), "conv1x1_gemv_g8_wquant_sg", option); rt->insertPipeline(keys, pipeline); } mPipeline = pipeline; @@ -200,15 +222,17 @@ } else if(rt->supportSimdGroupMatrix() && area > short_seq && oc > 8 && ic_4 % 8 == 0) { baseKeys.emplace_back("conv1x1_wquant_sg_matrix"); + std::string sgmWqStr = basicShaderPrefix + sgmWqShader; + // Generally threadgroup memory >= 16KB auto smem_size = [[context device] maxThreadgroupMemoryLength]; // choose different tile for different computation if(area >= 128 && oc >= 512 && area * oc > 512 * 2048 && smem_size >= 8192) { auto keys = baseKeys; - 
keys.emplace_back("conv1x1_gemm_32x64_wquant_sg"); + keys.emplace_back("conv1x1_gemm_32x64_wquant_split_k_sg"); auto pipeline = rt->findPipeline(keys); if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gConv1x1W4SgMatrix, "conv1x1_gemm_32x64_wquant_sg", option); + pipeline = backend->makeComputePipelineWithSourceOption(sgmWqStr.c_str(), "conv1x1_gemm_32x64_wquant_split_k_sg", option); rt->insertPipeline(keys, pipeline); } mPipeline = pipeline; @@ -219,7 +243,7 @@ keys.emplace_back("conv1x1_gemm_32x16_wquant_sg"); auto pipeline = rt->findPipeline(keys); if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gConv1x1W4SgMatrix, "conv1x1_gemm_32x16_wquant_sg", option); + pipeline = backend->makeComputePipelineWithSourceOption(sgmWqStr.c_str(), "conv1x1_gemm_32x16_wquant_sg", option); rt->insertPipeline(keys, pipeline); } mPipeline = pipeline; @@ -229,7 +253,7 @@ keys.emplace_back("conv1x1_gemm_16x32_wquant_sg"); auto pipeline = rt->findPipeline(keys); if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gConv1x1W4SgMatrix, "conv1x1_gemm_16x32_wquant_sg", option); + pipeline = backend->makeComputePipelineWithSourceOption(sgmWqStr.c_str(), "conv1x1_gemm_16x32_wquant_sg", option); rt->insertPipeline(keys, pipeline); } mPipeline = pipeline; @@ -245,18 +269,20 @@ keys.emplace_back(kernel_name); auto pipeline = rt->findPipeline(keys); if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gConv1x1W4SgMatrix, kernel_name.c_str(), option); + pipeline = backend->makeComputePipelineWithSourceOption(sgmWqStr.c_str(), kernel_name.c_str(), option); rt->insertPipeline(keys, pipeline); } mPipeline = pipeline; mThreads = std::make_pair(MTLSizeMake(UP_DIV(area, 8), UP_DIV(oc, oc_block), 1), MTLSizeMake(32, 1, 1)); } else { + std::string sgrWqStr = basicShaderPrefix + sgrWqShader; + auto keys = baseKeys; std::string kernel_name = "conv1x1_gemv_g4m" + std::to_string(area) + 
"_wquant_sg"; keys.emplace_back(kernel_name); auto pipeline = rt->findPipeline(keys); if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gConv1x1W4SgReduce, kernel_name.c_str(), option); + pipeline = backend->makeComputePipelineWithSourceOption(sgrWqStr.c_str(), kernel_name.c_str(), option); rt->insertPipeline(keys, pipeline); } mPipeline = pipeline; @@ -267,7 +293,7 @@ keys.emplace_back("conv1x1_gemm_16x16_wquant_sg"); auto pipeline = rt->findPipeline(keys); if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gConv1x1W4SgMatrix, "conv1x1_gemm_16x16_wquant_sg", option); + pipeline = backend->makeComputePipelineWithSourceOption(sgmWqStr.c_str(), "conv1x1_gemm_16x16_wquant_sg", option); rt->insertPipeline(keys, pipeline); } mPipeline = pipeline; @@ -309,7 +335,80 @@ return NO_ERROR; } #endif + + std::string sgmWfpShader = gConv1x1WfpSgMatrix; + std::string sgrWfpShader = gConv1x1WfpSgReduce; + + // Dequant using single shader + if (mDequantScaleBias.get()) { + baseKeys.emplace_back("conv1x1_dequant_weight_outter"); + std::string sgmWfpStr = basicShaderPrefix + sgmWfpShader; + + mPreDequantWeight = true; + { + NSMutableDictionary *dic = [baseDic mutableCopy]; + + if(mDequantBits == 4) { + [dic setValue:@"1" forKey:@"W_QUANT_4"]; + } else if(mDequantBits == 8) { + [dic setValue:@"1" forKey:@"W_QUANT_8"]; + } + if(ic % 16 != 0) { + [dic setValue:@"1" forKey:@"W_ALIGN_K16_PROTECT"]; + } + option.preprocessorMacros = dic; + + int bytes = backend->useFp16InsteadFp32() ? 
2 : 4; + // accquire space + mTempWeight.reset(Tensor::createDevice(std::vector{ROUND_UP(oc, 4) * ROUND_UP(ic, 16) * bytes})); + backend->onAcquireBuffer(mTempWeight.get(), Backend::DYNAMIC); + backend->onReleaseBuffer(mTempWeight.get(), Backend::DYNAMIC); + + auto keys = baseKeys; + keys.emplace_back("conv1x1_w_dequant"); + auto pipeline = rt->findPipeline(keys); + if (nil == pipeline) { + pipeline = backend->makeComputePipelineWithSourceOption(sgmWfpStr.c_str(), "conv1x1_w_dequant", option); + rt->insertPipeline(keys, pipeline); + } + mDequantPipeline = pipeline; + + mDequantThreads = [context computeBestGroupAndLocal:pipeline threads:MTLSizeMake(UP_DIV(oc, 1), UP_DIV(ic, 16), 1)]; + } + + { + auto keys = baseKeys; + keys.emplace_back("conv1x1_gemm_32x64_split_k_sg"); + + NSMutableDictionary *dic = [baseDic mutableCopy]; + if(rt->supportTensorOps()) { + [dic setValue:@"1" forKey:@"USE_METAL_TENSOR_OPS"]; + keys.emplace_back("USE_METAL_TENSOR_OPS"); + if(ic > oc && ic > 2048 && (ic / blockSize) % 64 == 0) { + [dic setValue:@"1" forKey:@"LOOP_K64"]; + keys.emplace_back("LOOP_K64"); + } + } + option.preprocessorMacros = dic; + + auto pipeline = rt->findPipeline(keys); + if (nil == pipeline) { + pipeline = backend->makeComputePipelineWithSourceOption(sgmWfpStr.c_str(), "conv1x1_gemm_32x64_split_k_sg", option); + rt->insertPipeline(keys, pipeline); + } + mPipeline = pipeline; + mThreads = std::make_pair(MTLSizeMake(UP_DIV(area, 32), UP_DIV(oc, 64), 1), MTLSizeMake(128, 1, 1)); + //printf("out dequant MNK: %d %d %d %d\n", area, oc, ic, blockSize); + } + + return NO_ERROR; + } + + option.preprocessorMacros = baseDic; + if(rt->supportSimdGroupMatrix()) { + std::string sgmWfpStr = basicShaderPrefix + sgmWfpShader; + baseKeys.emplace_back("conv1x1_float_sg_matrix"); // total computation not too small if(area >= 16 && ic_4 >= 4 && ic_4 % 2 == 0 && oc_4 >= 4 && area * ic_4 * oc_4 >= 64 * 64 * 64) { @@ -319,7 +418,7 @@ keys.emplace_back("conv1x1_gemm_32x16_sg"); auto pipeline 
= rt->findPipeline(keys); if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gConv1x1SgMatrix, "conv1x1_gemm_32x16_sg", option); + pipeline = backend->makeComputePipelineWithSourceOption(sgmWfpStr.c_str(), "conv1x1_gemm_32x16_sg", option); rt->insertPipeline(keys, pipeline); } mPipeline = pipeline; @@ -329,7 +428,7 @@ keys.emplace_back("conv1x1_gemm_16x16_sg"); auto pipeline = rt->findPipeline(keys); if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gConv1x1SgMatrix, "conv1x1_gemm_16x16_sg", option); + pipeline = backend->makeComputePipelineWithSourceOption(sgmWfpStr.c_str(), "conv1x1_gemm_16x16_sg", option); rt->insertPipeline(keys, pipeline); } mPipeline = pipeline; @@ -339,6 +438,8 @@ } } if(rt->supportSimdGroupReduce()) { + std::string sgrWfpStr = basicShaderPrefix + sgrWfpShader; + baseKeys.emplace_back("conv1x1_float_sg_reduce"); // do input_channel reduce auto magic_num = 4.0; // total threads pretty small and loop pretty large @@ -347,7 +448,7 @@ keys.emplace_back("conv1x1_z4_sg"); auto pipeline = rt->findPipeline(keys); if (nil == pipeline) { - pipeline = backend->makeComputePipelineWithSourceOption(gConv1x1SgReduce, "conv1x1_z4_sg", option); + pipeline = backend->makeComputePipelineWithSourceOption(sgrWfpStr.c_str(), "conv1x1_z4_sg", option); rt->insertPipeline(keys, pipeline); } mPipeline = pipeline; @@ -438,17 +539,39 @@ void MetalConvolution1x1::onEncode(const std::vector &inputs, const std::vector &outputs, id encoder) { auto input = inputs[0]; auto output = outputs[0]; - [encoder setComputePipelineState:mPipeline]; - [encoder setBuffer:(id)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0]; - [encoder setBuffer:(id)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1]; - [encoder setBuffer:mConstBuffer offset:0 
atIndex:2]; - MetalBackend::setTensor(mWeight.get(), encoder, 3); - MetalBackend::setTensor(mBias.get(), encoder, 4); - if (mDequantScaleBias) { - MetalBackend::setTensor(mDequantScaleBias.get(), encoder, 5); + if(mPreDequantWeight) { + // pre dequant weight pipeline + { + [encoder setComputePipelineState:mDequantPipeline]; + MetalBackend::setTensor(mWeight.get(), encoder, 0); + MetalBackend::setTensor(mTempWeight.get(), encoder, 1); + [encoder setBuffer:mConstBuffer offset:0 atIndex:2]; + MetalBackend::setTensor(mDequantScaleBias.get(), encoder, 3); + [encoder dispatchThreadgroups:mDequantThreads.first threadsPerThreadgroup:mDequantThreads.second]; + } + // convolution pipeline + { + [encoder setComputePipelineState:mPipeline]; + [encoder setBuffer:(id)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0]; + [encoder setBuffer:(id)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1]; + [encoder setBuffer:mConstBuffer offset:0 atIndex:2]; + MetalBackend::setTensor(mTempWeight.get(), encoder, 3); + MetalBackend::setTensor(mBias.get(), encoder, 4); + MetalBackend::setTensor(mDequantScaleBias.get(), encoder, 5); + [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second]; + } + } else { + [encoder setComputePipelineState:mPipeline]; + [encoder setBuffer:(id)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0]; + [encoder setBuffer:(id)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1]; + [encoder setBuffer:mConstBuffer offset:0 atIndex:2]; + MetalBackend::setTensor(mWeight.get(), encoder, 3); + MetalBackend::setTensor(mBias.get(), encoder, 4); + if (mDequantScaleBias) { + 
MetalBackend::setTensor(mDequantScaleBias.get(), encoder, 5); + } + [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second]; } - [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second]; - #ifdef MNN_METAL_DEBUG_INFO if(!static_cast(backend())->useFp16InsteadFp32()) { { diff --git a/source/backend/metal/MetalKVCacheManager.hpp b/source/backend/metal/MetalKVCacheManager.hpp new file mode 100644 index 0000000000..7da4c83114 --- /dev/null +++ b/source/backend/metal/MetalKVCacheManager.hpp @@ -0,0 +1,65 @@ + +// +// MetalKVCacheManager.hpp +// MNN +// +// Created by MNN on 2025/12/04. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE + +#ifndef METAL_KVCACHE_MANAGER_HPP +#define METAL_KVCACHE_MANAGER_HPP + +#import "core/Macro.h" +#import "core/MNNFileUtils.h" +#import "core/OpCommonUtils.hpp" +#import "core/KVCacheManager.hpp" + +namespace MNN { + +class MetalKVCacheManager : public KVCacheManager{ +private: + id mKeyBuffer; + id mValueBuffer; + size_t mCurrentTotalSize; + +private: + void expandKVCacheInDisk(size_t oldSize, size_t curSize, size_t old_piece_stride, size_t old_piece_size, size_t new_piece_stride, bool need_copy, file_t specKeyFile = INVALID_FILE, file_t specValueFile = INVALID_FILE); + void expandKVCacheInMem(size_t oldSize, size_t old_piece_stride, size_t old_piece_size, size_t new_piece_stride, bool need_copy); +public: + MetalKVCacheManager(Backend * backend, KVCacheConfig & kvConfig): KVCacheManager(backend, kvConfig) { + // nothing todo + } + ~MetalKVCacheManager() { + onClear(); + } + Tensor * getKeyTensor() { + return mPastKey.get(); + } + Tensor * getValueTensor() { + return mPastValue.get(); + } + id getKeyBuffer() { + return mKeyBuffer; + } + id getValueBuffer() { + return mValueBuffer; + } + + void setPastLength(int length) { + mPastLength = length; + } + + virtual void onResize(int kv_num_head, int head_dim); + virtual void onClear(); + 
virtual void onAlloc(KVMeta* meta, int seq_len); + virtual void onRealloc(KVMeta* meta); +}; + +} // namespace MNN + +#endif // METAL_KVCACHE_MANAGER_HPP + +#endif // MNN_SUPPORT_TRANSFORMER_FUSE diff --git a/source/backend/metal/MetalKVCacheManager.mm b/source/backend/metal/MetalKVCacheManager.mm new file mode 100644 index 0000000000..f67f90fd2e --- /dev/null +++ b/source/backend/metal/MetalKVCacheManager.mm @@ -0,0 +1,336 @@ +// +// MetalKVCacheManager.mm +// MNN +// +// Created by MNN on 2025/12/04. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE + +#import "backend/metal/MetalBackend.hpp" +#import "backend/metal/MNNMetalContext.h" +#import "MetalKVCacheManager.hpp" + +namespace MNN { + +void MetalKVCacheManager::onResize(int kv_num_head, int head_dim) { + mKvNumHead = kv_num_head; + mHeadDim = head_dim; +} + +void MetalKVCacheManager::onAlloc(KVMeta* meta, int seq_len) { + mMeta = meta; + auto mtbn = static_cast(mBackend); + auto context = (__bridge MNNMetalContext *)mtbn->context(); + + auto kv_seq_len = mMeta != nullptr ? 
mMeta->add : seq_len; + int byte = 4; + if(mtbn->useFp16InsteadFp32()) { + byte = 2; + } + // load disk prefix kvcache + if(mMeta != nullptr && mMeta->file_name.size() > 0 && mMeta->file_flag == KVMeta::PendingRead) { + // create new files + std::string pathk = MNNFilePathConcat(mConfig.mPrefixCacheDir, mMeta->file_name) + "_" + std::to_string(mMeta->layer_index) + ".k"; + std::string pathv = MNNFilePathConcat(mConfig.mPrefixCacheDir, mMeta->file_name) + "_" + std::to_string(mMeta->layer_index++) + ".v"; + mMeta->layer_index = mMeta->layer_index % mMeta->layer_nums; + auto old_key_fd = MNNOpenFile(pathk.c_str(), MNN_FILE_READ | MNN_FILE_WRITE); + auto old_value_fd = MNNOpenFile(pathv.c_str(), MNN_FILE_READ | MNN_FILE_WRITE); + if (old_key_fd == INVALID_FILE) { + MNN_PRINT("Failed to open the file: %s\n", pathk.c_str()); + } + if (old_value_fd == INVALID_FILE) { + MNN_PRINT("Failed to open the file: %s\n", pathv.c_str()); + } + + // get kv cache file info + auto oldKeySize = MNNGetFileSize(old_key_fd); + auto oldValueSize = MNNGetFileSize(old_value_fd); + auto oldTotalSize = ALIMIN(oldKeySize, oldValueSize); + if(oldKeySize != oldValueSize) { + MNN_ERROR("[Error]: Kvcache in disk size of key and value should equal with metal backend\n"); + } + size_t oldKeyMaxLength = oldKeySize / (mKvNumHead * mHeadDim * byte); + size_t oldValueMaxLength = oldValueSize / (mKvNumHead * mHeadDim * byte); + size_t oldMaxLength = ALIMIN(oldKeyMaxLength, oldValueMaxLength); + if(oldMaxLength < meta->seqlen_in_disk) { + MNN_ERROR("[Error]: Kvcache in disk size smaller than saved lengthInDiskToload:%d\n", (int)meta->seqlen_in_disk); + } + + int kv_seq_len = ROUND_UP(meta->add + meta->seqlen_in_disk, mConfig.mKvAlignNum); + mMaxLength = kv_seq_len > oldMaxLength ? 
ROUND_UP(meta->add + meta->seqlen_in_disk + mConfig.mExpandChunk, mConfig.mKvAlignNum) : oldMaxLength; + size_t totalSize = mKvNumHead * mMaxLength * mHeadDim * byte; + mCurrentTotalSize = totalSize; + + size_t old_piece_size = meta->seqlen_in_disk * byte; + size_t old_piece_stride = oldMaxLength * byte; + size_t new_piece_stride = mMaxLength * byte; + + mCurrentTotalSize = ALIMAX(mCurrentTotalSize, oldKeySize); + mCurrentTotalSize = ALIMAX(mCurrentTotalSize, oldValueSize); + + createKVCacheFile(); + resetKVCacheFileSize(mCurrentTotalSize, mCurrentTotalSize); + expandKVCacheInDisk(oldTotalSize, mCurrentTotalSize, old_piece_stride, old_piece_size, new_piece_stride, true, old_key_fd, old_value_fd); + + mPastLength = meta->seqlen_in_disk; + mKVCacheInDisk = true; + + return; + } + + // align max kv_seq_len to mKvAlignNum, for simd/tensor matrix load alignment + mMaxLength = ROUND_UP(kv_seq_len + mConfig.mExpandChunk, mConfig.mKvAlignNum); + size_t totalSize = mKvNumHead * mMaxLength * mHeadDim * byte; + mCurrentTotalSize = totalSize; + bool storeKvInDisk = !mConfig.mKVCacheDir.empty(); + bool sharePrefixKv = mMeta != nullptr && mMeta->file_name.size() > 0 && mMeta->file_flag == KVMeta::PendingWrite; + + if (sharePrefixKv) { + mSaveShareKvPrefix = true; + if(!MNNCreateDir(mConfig.mPrefixCacheDir.c_str())) { + MNN_PRINT("Failed to create prefix cache file dir: %s\n", mConfig.mPrefixCacheDir.c_str()); + } + } + + if(storeKvInDisk || sharePrefixKv) { + std::string keyStoredDst = ""; + std::string valueStoredDst = ""; + + if(mMeta != nullptr) { + mBasePrefixFileName = MNNFilePathConcat(mConfig.mPrefixCacheDir, mMeta->file_name) + "_" + std::to_string(mMeta->layer_index); + keyStoredDst = sharePrefixKv ? mBasePrefixFileName + ".k" : ""; + valueStoredDst = sharePrefixKv ? 
mBasePrefixFileName + ".v" : ""; + mMeta->layer_index++; + mMeta->layer_index = mMeta->layer_index % mMeta->layer_nums; + } + createKVCacheFile(keyStoredDst, valueStoredDst); + resetKVCacheFileSize(totalSize, totalSize); + mmapKVCache(totalSize, totalSize); + mKVCacheInDisk = true; + + mKeyBuffer = [[context device] newBufferWithBytesNoCopy:mMapKeyAddr length:totalSize options:MTLResourceStorageModeShared deallocator:nil]; + mValueBuffer = [[context device] newBufferWithBytesNoCopy:mMapValueAddr length:totalSize options:MTLResourceStorageModeShared deallocator:nil]; + + auto new_key_ptr = (uint8_t*)[mKeyBuffer contents]; + ::memset(new_key_ptr, 0, mMaxLength * mKvNumHead * mHeadDim * byte); + + auto new_value_ptr = (uint8_t*)[mValueBuffer contents]; + ::memset(new_value_ptr, 0, mMaxLength * mKvNumHead * mHeadDim * byte); + + } else { + // past_key: [maxlen, kvNumhead, headdim] + auto new_key = Tensor::createDevice({mMaxLength, mKvNumHead, mHeadDim}); + // past_value: [kvNumhead, headdim, maxlen] + auto new_value = Tensor::createDevice({mKvNumHead, mHeadDim, mMaxLength}); + + + auto res = mBackend->onAcquireBuffer(new_key, Backend::STATIC); + res = res && mBackend->onAcquireBuffer(new_value, Backend::STATIC); + if(!res) { + MNN_ERROR("attition kv cache alloc memory error:%d\n", res); + } + // memset for qkv matmul mad, in case dirty data + auto newKeyBuf = MetalBackend::getBuffer(new_key); + auto new_key_ptr = (uint8_t*)[newKeyBuf.first contents] + newKeyBuf.second; + ::memset(new_key_ptr, 0, mMaxLength * mKvNumHead * mHeadDim * byte); + + auto newValueBuf = MetalBackend::getBuffer(new_value); + auto new_value_ptr = (uint8_t*)[newValueBuf.first contents] + newValueBuf.second; + ::memset(new_value_ptr, 0, mMaxLength * mKvNumHead * mHeadDim * byte); + + mPastKey.reset(new_key); + mPastValue.reset(new_value); + } + +} +void MetalKVCacheManager::onRealloc(KVMeta* meta) { + mMeta = meta; + auto kv_seq_len = mMeta->previous + mMeta->add - mMeta->remove + 
mMeta->computeReverseSize(); + auto mtbn = static_cast(mBackend); + + int byte = 4; + if(mtbn->useFp16InsteadFp32()) { + byte = 2; + } + + auto start = mPastLength - mMeta->remove; + // latest length larger than maxLen + if (kv_seq_len > mMaxLength) { + + // copy mPastLength including all remove/reverse to new buffer first + auto copy_len = mPastLength; + bool needCopy = mPastLength > 0; + + size_t old_size = mKvNumHead * copy_len * mHeadDim * byte; + size_t old_piece_size = copy_len * byte; + size_t old_piece_stride = mMaxLength * byte; + + // align max kv_seq_len to mKvAlignNum, for simd/tensor matrix load alignment + mMaxLength = ROUND_UP(kv_seq_len + mConfig.mExpandChunk, mConfig.mKvAlignNum); + + auto oldTotalSize = mCurrentTotalSize; + size_t size = mKvNumHead * mMaxLength * mHeadDim * byte; + mCurrentTotalSize = size; + size_t new_piece_stride = mMaxLength * byte; + + mPastLength = (int)start; + + if(mKVCacheInDisk) { + expandKVCacheInDisk(oldTotalSize, mCurrentTotalSize, old_piece_stride, old_piece_size, new_piece_stride, needCopy); + } else { + expandKVCacheInMem(oldTotalSize, old_piece_stride, old_piece_size, new_piece_stride, needCopy); + } + } + + // Remove + { + if (0 == mMeta->n_reserve) { + mPastLength = start; + return; + } + + int8_t *key_ptr = nullptr; + int8_t *value_ptr = nullptr; + if(mKVCacheInDisk) { + key_ptr = mMapKeyAddr; + value_ptr = mMapValueAddr; + } else { + auto keyBuf = MetalBackend::getBuffer(mPastKey.get()); + key_ptr = (int8_t*)[keyBuf.first contents] + keyBuf.second; + auto valueBuf = MetalBackend::getBuffer(mPastValue.get()); + value_ptr = (int8_t*)[valueBuf.first contents] + valueBuf.second; + } + auto src_start = start; + // TODO: need to ensure reserve info is sorted + for (int n = 0; n < mMeta->n_reserve; ++n) { + auto begin = mMeta->reserve[2 * n]; + auto length = mMeta->reserve[2 * n + 1]; + // past_key : [mCache->mPastLength, mKvNumHead, mHeadDim] + // past_value : [mKvNumHead, mHeadDim, mCache->mMaxLength] + + auto 
copy_src_index = src_start + begin; + auto copy_dst_index = start; + for(int i = 0; i < length; i++) { + ::memcpy(key_ptr + (copy_dst_index + i) * mKvNumHead * mHeadDim * byte, key_ptr + (copy_src_index + i) * mKvNumHead * mHeadDim * byte, mKvNumHead * mHeadDim * byte); + } + for(int j = 0; j < mKvNumHead * mHeadDim; j++) { + for(int i = 0; i < length; i++) { + ::memcpy(value_ptr + (j * mMaxLength + copy_dst_index + i) * byte, value_ptr + (j * mMaxLength + copy_src_index + i) * byte, byte); + } + } + start += length; + } + mPastLength = (int)start; + } +} + +void MetalKVCacheManager::expandKVCacheInMem(size_t oldSize, size_t old_piece_stride, size_t old_piece_size, size_t new_piece_stride, bool need_copy) { + auto mtbn = static_cast(mBackend); + int byte = 4; + if(mtbn->useFp16InsteadFp32()) { + byte = 2; + } + // past_key: [maxlen, kvNumhead, headdim] + auto new_key = Tensor::createDevice({mMaxLength, mKvNumHead, mHeadDim}); + // past_value: [kvNumhead, headdim, maxlen] + auto new_value = Tensor::createDevice({mKvNumHead, mHeadDim, mMaxLength}); + + auto res = mBackend->onAcquireBuffer(new_key, Backend::STATIC); + res = res && mBackend->onAcquireBuffer(new_value, Backend::STATIC); + if(!res) { + MNN_ERROR("attition kv cache realloc memory error:%d\n", res); + } + + // memset for qkv matmul mad, in case dirty data + auto newKeyBuf = MetalBackend::getBuffer(new_key); + auto new_key_ptr = (uint8_t*)[newKeyBuf.first contents] + newKeyBuf.second; + ::memset(new_key_ptr, 0, mMaxLength * mKvNumHead * mHeadDim * byte); + + auto newValueBuf = MetalBackend::getBuffer(new_value); + auto new_value_ptr = (uint8_t*)[newValueBuf.first contents] + newValueBuf.second; + ::memset(new_value_ptr, 0, mMaxLength * mKvNumHead * mHeadDim * byte); + + if (need_copy) { + auto keyBuf = MetalBackend::getBuffer(mPastKey.get()); + auto key_ptr = (uint8_t*)[keyBuf.first contents] + keyBuf.second;; + ::memcpy(new_key_ptr, key_ptr, oldSize); + + auto valueBuf = 
MetalBackend::getBuffer(mPastValue.get()); + auto value_ptr = (uint8_t*)[valueBuf.first contents] + valueBuf.second; + for(int i = 0; i < mKvNumHead * mHeadDim; i++) { + ::memcpy(new_value_ptr + i * new_piece_stride, value_ptr + i * old_piece_stride, old_piece_size); + } + } + + mPastKey.reset(new_key); + mPastValue.reset(new_value); +} + +void MetalKVCacheManager::expandKVCacheInDisk(size_t oldSize, size_t curSize, size_t old_piece_stride, size_t old_piece_size, size_t new_piece_stride, bool need_copy, file_t specKeyFile, file_t specValueFile) { + auto mtbn = static_cast(mBackend); + auto context = (__bridge MNNMetalContext *)mtbn->context(); + + mmapKVCache(oldSize, oldSize, specKeyFile, specValueFile); + std::vector prevKey, prevValue; + prevKey.resize(oldSize); + prevValue.resize(oldSize); + memcpy(prevKey.data(), mMapKeyAddr, oldSize); + memcpy(prevValue.data(), mMapValueAddr, oldSize); + + unmapKVCache(oldSize, oldSize); + resetKVCacheFileSize(curSize, curSize); + mmapKVCache(curSize, curSize); + + // reset id + mKeyBuffer = [[context device] newBufferWithBytesNoCopy:mMapKeyAddr length:curSize options:MTLResourceStorageModeShared deallocator:nil]; + mValueBuffer = [[context device] newBufferWithBytesNoCopy:mMapValueAddr length:curSize options:MTLResourceStorageModeShared deallocator:nil]; + + + // Step 3: Move the kvcache from temporary buffers in memory to disk + memset(mMapKeyAddr, 0, curSize); + memset(mMapValueAddr, 0, curSize); + + if (need_copy) { + ::memcpy(mMapKeyAddr, prevKey.data(), oldSize); + for(int i = 0; i < mKvNumHead * mHeadDim; i++) { + ::memcpy(mMapValueAddr + i * new_piece_stride, prevValue.data() + i * old_piece_stride, old_piece_size); + } + } +} + +void MetalKVCacheManager::onClear() { + if (mKVCacheInDisk) { + mKeyBuffer = nil; + mValueBuffer = nil; + + // mSaveShareKvPrefix also need unmap file + unmapKVCache(mCurrentTotalSize, mCurrentTotalSize); + if(mSaveShareKvPrefix) { + // set prefix cachefile validation + auto k_file = 
mBasePrefixFileName + ".k"; + if(MNNFileExist(k_file.c_str())) { + auto k_sync_file = mBasePrefixFileName + "_sync.k"; + MNNCreateFile(k_sync_file.c_str()); + } + auto v_file = mBasePrefixFileName + ".v"; + if(MNNFileExist(v_file.c_str())) { + auto v_sync_file = mBasePrefixFileName + "_sync.v"; + MNNCreateFile(v_sync_file.c_str()); + } + } else { + // delete temp kvcache file + removeKVCacheFile(); + } + mKVCacheInDisk = false; + } + mPastKey.reset(); + mPastValue.reset(); + mMaxLength = 0; + mPastLength = 0; +} +} // namespace MNN + +#endif // MNN_SUPPORT_TRANSFORMER_FUSE + diff --git a/source/backend/metal/MetalRaster.mm b/source/backend/metal/MetalRaster.mm index 0c9d9c6b4e..76697b0532 100644 --- a/source/backend/metal/MetalRaster.mm +++ b/source/backend/metal/MetalRaster.mm @@ -69,7 +69,7 @@ static void writeSamplerInfo(SamplerInfo& info, const Tensor::InsideDescribe::Re uint4 size;//size[3] + totalSize uint4 extent;//dstStride[3]+dstOffset }; -kernel void main0(const device T *in [[buffer(0)]], +kernel void mblit(const device T *in [[buffer(0)]], device T *out [[buffer(1)]], const device uint4* buf [[buffer(2)]], uint3 tgid [[thread_position_in_grid]]) { @@ -98,7 +98,7 @@ kernel void main0(const device T *in [[buffer(0)]], uint4 size;//size[3] + totalSize uint4 extent;//dstStride[3]+dstOffset }; -kernel void main0(const device T *in [[buffer(0)]], +kernel void sblit(const device T *in [[buffer(0)]], device T *out [[buffer(1)]], constant SamplerInfo &info [[buffer(2)]], uint3 gid [[thread_position_in_grid]]) { @@ -119,7 +119,7 @@ kernel void main0(const device T *in [[buffer(0)]], uint4 size;//size[3] + totalSize uint4 extent;//dstStride[3]+dstOffset }; -kernel void main0(const device T *in [[buffer(0)]], +kernel void mraster(const device T *in [[buffer(0)]], device T *out [[buffer(1)]], const device uint4* buf [[buffer(2)]], uint3 tgid [[thread_position_in_grid]]) { @@ -183,7 +183,7 @@ kernel void main0(const device T *in [[buffer(0)]], uint4 size;//size[3] + 
totalSize uint4 extent;//dstStride[3]+dstOffset }; -kernel void main0(const device T *in [[buffer(0)]], +kernel void sraster(const device T *in [[buffer(0)]], device T *out [[buffer(1)]], const device uint4* buf [[buffer(2)]], uint3 gid [[thread_position_in_grid]]) { @@ -237,7 +237,7 @@ kernel void main0(const device T *in [[buffer(0)]], int4 value; uint4 size; }; -kernel void main0(device int4 *out [[buffer(0)]], +kernel void fill(device int4 *out [[buffer(0)]], constant MemsetInfo &info [[buffer(1)]], uint3 gid [[thread_position_in_grid]]) { if (gid.x < info.size.x) { @@ -267,9 +267,9 @@ kernel void main0(device int4 *out [[buffer(0)]], @"T" : @(unitName.c_str()), }; if (multiRegion) { - pipeline = mtbn->makeComputePipelineWithSourceOption(gMultiBlitMetal, "main0", compileOptions); + pipeline = mtbn->makeComputePipelineWithSourceOption(gMultiBlitMetal, "mblit", compileOptions); } else { - pipeline = mtbn->makeComputePipelineWithSourceOption(gSingleBlitMetal, "main0", compileOptions); + pipeline = mtbn->makeComputePipelineWithSourceOption(gSingleBlitMetal, "sblit", compileOptions); } mtbn->runtime()->insertPipeline(keys, pipeline); } @@ -325,7 +325,7 @@ kernel void main0(device int4 *out [[buffer(0)]], }; auto pipeline = mtbn->runtime()->findPipeline(keys); if (nil == pipeline) { - pipeline = mtbn->makeComputePipelineWithSourceOption(gFillInt4, "main0", nil); + pipeline = mtbn->makeComputePipelineWithSourceOption(gFillInt4, "fill", nil); mtbn->runtime()->insertPipeline(keys, pipeline); } mZeroPipeline = pipeline; @@ -452,9 +452,9 @@ kernel void main0(device int4 *out [[buffer(0)]], @"T" : @(unitName.c_str()), }; if(iter.second.size() == 1) { - pipeline = mtbn->makeComputePipelineWithSourceOption(gSingleRasterTemplate, "main0", options); + pipeline = mtbn->makeComputePipelineWithSourceOption(gSingleRasterTemplate, "sraster", options); } else { - pipeline = mtbn->makeComputePipelineWithSourceOption(gMultiRasterTemplate, "main0", options); + pipeline = 
mtbn->makeComputePipelineWithSourceOption(gMultiRasterTemplate, "mraster", options); } mtbn->runtime()->insertPipeline(keys, pipeline); } diff --git a/source/backend/metal/MetalUnary.mm b/source/backend/metal/MetalUnary.mm index f4ce7655d4..e9ab73302a 100755 --- a/source/backend/metal/MetalUnary.mm +++ b/source/backend/metal/MetalUnary.mm @@ -51,7 +51,7 @@ static inline float4 gelu(float4 value) { return result; } -kernel void main0(const device T *in [[buffer(0)]], \ +kernel void unary(const device T *in [[buffer(0)]], \ device T *out [[buffer(1)]], \ device unary_shape& s [[buffer(2)]], \ uint3 gid [[thread_position_in_grid]]) { \ @@ -162,7 +162,7 @@ kernel void main0(const device T *in [[buffer(0)]], \ @"T" : T, @"FUNC" : kernel, }; - pipeline = mtbn->makeComputePipelineWithSourceOption(gUnaryTemplate, "main0", compileOptions); + pipeline = mtbn->makeComputePipelineWithSourceOption(gUnaryTemplate, "unary", compileOptions); mtbn->runtime()->insertPipeline(keys, pipeline); } if (nil == pipeline) { diff --git a/source/backend/opencl/execution/buffer/LoopBufExecution.cpp b/source/backend/opencl/execution/buffer/LoopBufExecution.cpp index 87667d18b6..3f8566e273 100644 --- a/source/backend/opencl/execution/buffer/LoopBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/LoopBufExecution.cpp @@ -13,6 +13,124 @@ namespace MNN { namespace OpenCL { +static std::string getComputeOption(MNN::BinaryOpOperation type){ + std::string compute; + switch (type) { + case BinaryOpOperation_MUL: + compute = "in0*in1";break; + case BinaryOpOperation_ADD: + compute = "in0+in1";break; + case BinaryOpOperation_SUB: + compute = "in0-in1";break; + case BinaryOpOperation_REALDIV: + compute = "sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001))";break; + case BinaryOpOperation_MINIMUM: + compute = "in0>in1?in1:in0";break; + case BinaryOpOperation_MAXIMUM: + compute = "in0>in1?in0:in1";break; + case BinaryOpOperation_GREATER: + compute = 
"(float)(isgreater(in0,in1))";break; + case BinaryOpOperation_LESS: + compute = "(float)(isless(in0,in1))";break; + case BinaryOpOperation_LESS_EQUAL: + compute = "(float)(islessequal(in0,in1))";break; + case BinaryOpOperation_GREATER_EQUAL: + compute = "(float)(isgreaterequal(in0,in1))";break; + case BinaryOpOperation_EQUAL: + compute = "(float)(isequal(in0,in1))";break; + case BinaryOpOperation_FLOORDIV: + compute = "floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))";break; + case BinaryOpOperation_FLOORMOD: + compute = "in0-floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))*in1";break; + case BinaryOpOperation_POW: + compute = "pow(in0,in1)";break; + case BinaryOpOperation_SquaredDifference: + compute = "(in0-in1)*(in0-in1)";break; + case BinaryOpOperation_ATAN2: + compute = "(in1==(float)0?(sign(in0)*(float)(PI/2)):(atan(in0/in1)+(in1>(float)0?(float)0:sign(in0)*(float)PI)))";break; + case BinaryOpOperation_NOTEQUAL: + compute = "(float)(isnotequal(in0,in1))";break; + case BinaryOpOperation_MOD: + compute = "in0-floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))*in1";break; + default: + break; + } + return compute; +} + +static std::string getUnaryComputeOption(MNN::UnaryOpOperation type){ + std::string compute; + switch (type) { + case UnaryOpOperation_ABS: + compute = "fabs((float)(in))"; break; + case UnaryOpOperation_SQUARE: + compute = "in*in"; break; + case UnaryOpOperation_RSQRT: + compute = "rsqrt((float))(in)>(float)(0.000001)?(float))(in):(float)(0.000001))"; break; + case UnaryOpOperation_NEG: + compute = "-(in)"; break; + case UnaryOpOperation_EXP: + compute = "exp((float))(in))"; break; + case UnaryOpOperation_COS: + compute = "cos((float)(in))"; break; + case UnaryOpOperation_SIN: + compute = "sin((float)(in))"; break; + case UnaryOpOperation_TAN: + compute = "tan((float)(in))"; break; + case UnaryOpOperation_ATAN: + 
compute = "atan((float)(in))"; break; + case UnaryOpOperation_SQRT: + compute = "sqrt((float)(in))"; break; + case UnaryOpOperation_CEIL: + compute = "ceil((float)(in))"; break; + case UnaryOpOperation_RECIPROCAL: + compute = "native_recip((float)(in))"; break; + case UnaryOpOperation_LOG1P: + compute = "log1p((float)(in))"; break; + case UnaryOpOperation_LOG: + compute = "native_log((float)(in)>(float)(0.0000001)?(float)(in):(float)(0.0000001))"; break; + case UnaryOpOperation_FLOOR: + compute = "floor((float)(in))"; break; + case UnaryOpOperation_BNLL: + compute = "in>(float)((float)0)?(in+native_log(exp((float)(-(in)))+(float)(1.0))):(native_log(exp((float)(in))+(float)(1.0)))"; break; + case UnaryOpOperation_ACOSH: + compute = "acosh((float)(in))"; break; + case UnaryOpOperation_SINH: + compute = "sinh((float)(in))"; break; + case UnaryOpOperation_ASINH: + compute = "asinh((float)(in))"; break; + case UnaryOpOperation_ATANH: + compute = "atanh((float)(in))"; break; + case UnaryOpOperation_SIGN: + compute = "sign((float)(in))"; break; + case UnaryOpOperation_ROUND: + compute = "round((float)(in))"; break; + case UnaryOpOperation_COSH: + compute = "cosh((float)(in))"; break; + case UnaryOpOperation_ERF: + compute = "erf((float)(in))"; break; + case UnaryOpOperation_ERFC: + compute = "erfc((float)(in))"; break; + case UnaryOpOperation_EXPM1: + compute = "expm1((float)(in))"; break; + case UnaryOpOperation_SIGMOID: + compute = "native_recip((float)1+native_exp((float)(-in)))"; break; + case UnaryOpOperation_SILU: + compute = "((float)(in)*native_recip((float)1+native_exp((float)(-in))))"; break; + case UnaryOpOperation_TANH: + compute = "tanh((float)(in))"; break; + case UnaryOpOperation_HARDSWISH: + compute = "(float)(in)>(float)(-3.0f)?((float)(in)<(float)(3.0f)?(((float)(in)*((float)(in)+(float)3.0f))/(float)6.0f):(float)(in)):(float)(0.0f)"; break; + case UnaryOpOperation_GELU: + compute = "gelu((float)(in))"; break; + case UnaryOpOperation_GELU_STANDARD: + 
compute = "(erf((float)(in)*(float)0.7071067932881648)+(float)1.0)*(float)(in)*(float)0.5"; break; + default: + break; + } + return compute; +} + static void _setTensorStack(std::vector &result, const std::vector &inputs, const std::vector &outputs, const LoopParam *loop) { if (loop->inputIndexes() != nullptr) { @@ -25,113 +143,125 @@ static void _setTensorStack(std::vector &result, const std::vectortensorNumber()); - auto cmd = loop->commands()->GetAs(0); } -ErrorCode LoopGatherBufExecution::InitCommandOnEncode(const std::vector &inputs, const std::vector &outputs){ - auto cmd = mLoop->initCommand()->GetAs(0); - OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); - auto runTime = mOpenCLBackend->getOpenCLRuntime(); - if (cmd->op() == nullptr){ - Unit unit; - auto output = mTensors[cmd->indexes()->data()[0]]; - auto outputShape = tensorShapeFormat(output); - auto outputDes = TensorUtils::getDescribe(output); - int region[] = {outputShape[0], outputShape[3], outputShape[1], outputShape[2]};//nchw - if(MNN_DATA_FORMAT_NC4HW4 == outputDes->dimensionFormat){ - region[1] = ROUND_UP(outputShape[3], 4); + +ErrorCode LoopBufExecution::InitCommandOnEncode(){ + for (int i=0; iinitCommand()->size(); ++i) { + auto cmd = mLoop->initCommand()->GetAs(i); + OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); + auto runTime = mOpenCLBackend->getOpenCLRuntime(); + int mStride_src[4]; + int mStride_dst[4]; + int mStep[2]; + int mIter[2]; + if (cmd->op() == nullptr){ + Unit unit; + auto output = mTensors[cmd->indexes()->data()[0]]; + auto outputShape = tensorShapeFormat(output); + auto outputDes = TensorUtils::getDescribe(output); + int region[] = {outputShape[0], outputShape[3], outputShape[1], outputShape[2]};//nchw + if(MNN_DATA_FORMAT_NC4HW4 == outputDes->dimensionFormat){ + region[1] = ROUND_UP(outputShape[3], 4); + } + unit.kernel = runTime->buildKernel("loop", "set_zero", {}, mOpenCLBackend->getPrecision(), output, output); + unit.localWorkSize = {8, 8}; + 
unit.globalWorkSize = {(uint32_t)UP_DIV((region[2] * region[3]), 8)*8, + (uint32_t)UP_DIV((region[0] * region[1]), 8)*8}; + + int global_dim0 = region[2] * region[3]; + int global_dim1 = region[0] * region[1]; + + uint32_t idx = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(idx++, global_dim0); + ret |= unit.kernel->get().setArg(idx++, global_dim1); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); + MNN_CHECK_CL_SUCCESS(ret, "setArg set_zero buffer"); + mOpenCLBackend->recordKernel2d(unit.kernel, {(uint32_t)UP_DIV((region[2] * region[3]), 8)*8, + (uint32_t)UP_DIV((region[0] * region[1]), 8)*8}, {8, 8}); + mUnits.emplace_back(unit); + return NO_ERROR; } - unit.kernel = runTime->buildKernel("raster_buf", "buffer_set_zero", {}, mOpenCLBackend->getPrecision(), output, output); - unit.localWorkSize = {8, 8}; - unit.globalWorkSize = {(uint32_t)UP_DIV((region[2] * region[3]), 8)*8, - (uint32_t)UP_DIV((region[0] * region[1]), 8)*8}; - - int global_dim0 = region[2] * region[3]; - int global_dim1 = region[0] * region[1]; - - uint32_t idx = 0; - cl_int ret = CL_SUCCESS; - ret |= unit.kernel->get().setArg(idx++, global_dim0); - ret |= unit.kernel->get().setArg(idx++, global_dim1); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); - MNN_CHECK_CL_SUCCESS(ret, "setArg buffer_set_zero"); - mOpenCLBackend->recordKernel2d(unit.kernel, {(uint32_t)UP_DIV((region[2] * region[3]), 8)*8, - (uint32_t)UP_DIV((region[0] * region[1]), 8)*8}, {8, 8}); - mUnits.emplace_back(unit); - return NO_ERROR; - } - int x = cmd->size()->data()[0]; - int y = cmd->size()->data()[1]; - int z = cmd->size()->data()[2]; - int inputSize = mTensors[cmd->indexes()->data()[1]]->elementSize(); - - auto srcStride = cmd->view()->GetAs(1)->stride()->data(); - auto dstStride = cmd->view()->GetAs(0)->stride()->data(); - for (int i = 0; i < 3; ++i) { - mStride_src[i] = srcStride[i]; - mStride_dst[i] = dstStride[i]; - } - - mStride_src[3] = 0; - mStride_dst[3] = 0; - 
::memset(mStep, 0, 2 * sizeof(int)); - - // gather - { - Unit unit; - auto input = mTensors[cmd->indexes()->data()[1]]; - auto output = mTensors[cmd->indexes()->data()[0]]; - std::set buildOptions; + int x = cmd->size()->data()[0]; + int y = cmd->size()->data()[1]; + int z = cmd->size()->data()[2]; - unit.kernel = runTime->buildKernel("gather_buf", "batch_gather_buf", buildOptions, mOpenCLBackend->getPrecision(), input, output); - uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); - std::vector mGlobalWorkSize = {(uint32_t)(x * y), (uint32_t)(z), (uint32_t)(1)}; + int inputSize = mTensors[cmd->indexes()->data()[1]]->elementSize(); + int outputSize = mTensors[cmd->indexes()->data()[0]]->elementSize(); - uint32_t index = 0; - cl_int ret = CL_SUCCESS; - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); - ret |= unit.kernel->get().setArg(index++, openCLBuffer(output)); - ret |= unit.kernel->get().setArg(index++, openCLBuffer(input)); - ret |= unit.kernel->get().setArg(index++, x); - ret |= unit.kernel->get().setArg(index++, sizeof(mStride_src), mStride_src); - ret |= unit.kernel->get().setArg(index++, sizeof(mStride_dst), mStride_dst); - ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); - ret |= unit.kernel->get().setArg(index++, sizeof(mIter), mIter); - ret |= unit.kernel->get().setArg(index++, inputSize); - MNN_CHECK_CL_SUCCESS(ret, "setArg LoopInitGatherBufExecution"); + auto srcStride = cmd->view()->GetAs(1)->stride()->data(); + auto dstStride = cmd->view()->GetAs(0)->stride()->data(); + for (int i = 0; i < 3; ++i) { + mStride_src[i] = srcStride[i]; + mStride_dst[i] = dstStride[i]; + } - std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, "batch_gather_buf", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "gather_buf").first; + 
mStride_src[3] = 0; + mStride_dst[3] = 0; + ::memset(mStep, 0, 2 * sizeof(int)); - unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; - mUnits.emplace_back(unit); - mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + // gather + { + Unit unit; + auto input = mTensors[cmd->indexes()->data()[1]]; + auto output = mTensors[cmd->indexes()->data()[0]]; + std::set buildOptions; + + unit.kernel = runTime->buildKernel("loop", "batch_gather", buildOptions, mOpenCLBackend->getPrecision(), input, output); + uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); + std::vector mGlobalWorkSize = {(uint32_t)(x * y), (uint32_t)(z), (uint32_t)(1)}; + + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(output)); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(input)); + ret |= unit.kernel->get().setArg(index++, x); + ret |= unit.kernel->get().setArg(index++, 0); + ret |= unit.kernel->get().setArg(index++, sizeof(mStride_src), mStride_src); + ret |= unit.kernel->get().setArg(index++, sizeof(mStride_dst), mStride_dst); + ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); + ret |= unit.kernel->get().setArg(index++, inputSize); + ret |= unit.kernel->get().setArg(index++, outputSize); + MNN_CHECK_CL_SUCCESS(ret, "setArg LoopInitGatherBufExecution"); + + std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, "batch_gather", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop").first; + + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalWorkSize[0], 
mLocalWorkSize[1], mLocalWorkSize[2]}; + mUnits.emplace_back(unit); + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + } } return NO_ERROR; } -ErrorCode LoopGatherBufExecution::onEncode(const std::vector &inputs, const std::vector &outputs) { - auto cmd = mLoop->commands()->GetAs(0); +ErrorCode LoopBufExecution::LoopGather(const Tensor *output, int cmdIndex, int iter) { + auto cmd = mLoop->commands()->GetAs(cmdIndex); + auto op = cmd->op(); OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); auto runTime = mOpenCLBackend->getOpenCLRuntime(); - _setTensorStack(mTensors, inputs, outputs, mLoop); - mUnits.clear(); - mOffsetTensors.clear(); - - if(mLoop->initCommand() != nullptr){ - InitCommandOnEncode(inputs, outputs); - } int x = cmd->size()->data()[0]; int y = cmd->size()->data()[1]; int z = cmd->size()->data()[2]; - int n = mLoop->loopNumber(); + int n = mLoop->parallel() ? mLoop->loopNumber() : 1; + if(mLoop->commands()->size() == 1 && OpType_UnaryOp == op->type() && nullptr == op->main() && cmd->fuse() < 0){ + // only one gather + n = mLoop->loopNumber(); + } + + int mStride_src[4]; + int mStride_dst[4]; + int mStep[2]; + int mIter[2]; int inputSize = mTensors[cmd->indexes()->data()[1]]->elementSize(); + int outputSize = output->elementSize(); auto srcStride = cmd->view()->GetAs(1)->stride()->data(); auto dstStride = cmd->view()->GetAs(0)->stride()->data(); @@ -139,6 +269,11 @@ ErrorCode LoopGatherBufExecution::onEncode(const std::vector &inputs, mStride_src[i] = srcStride[i]; mStride_dst[i] = dstStride[i]; } + if(cmd->fuse() >= 0){ + mStride_dst[0] = y * z; + mStride_dst[1] = z; + mStride_dst[2] = 1; + } mStride_src[3] = cmd->view()->GetAs(1)->offset(); mStride_dst[3] = cmd->view()->GetAs(0)->offset(); @@ -146,72 +281,73 @@ ErrorCode LoopGatherBufExecution::onEncode(const std::vector &inputs, ::memcpy(mIter, cmd->iterIndexes()->data(), cmd->iterIndexes()->size() * sizeof(int)); // gather - { - Unit unit; - auto input = 
mTensors[cmd->indexes()->data()[1]]; - auto output = mTensors[cmd->indexes()->data()[0]]; - std::set buildOptions; - if (mIter[0] >= 0) { - buildOptions.emplace("-DOFFSET_DST"); - } - if (mIter[1] >= 0) { - buildOptions.emplace("-DOFFSET_SRC"); - } - - unit.kernel = runTime->buildKernel("gather_buf", "batch_gather_buf", buildOptions, mOpenCLBackend->getPrecision(), input, output); - uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); - std::vector mGlobalWorkSize = {(uint32_t)(x * y), (uint32_t)(z), (uint32_t)(n)}; - - uint32_t index = 0; - cl_int ret = CL_SUCCESS; - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); - ret |= unit.kernel->get().setArg(index++, openCLBuffer(output)); - ret |= unit.kernel->get().setArg(index++, openCLBuffer(input)); - for (int i = 0; i < cmd->iterIndexes()->size(); ++i) { - if (mIter[i] >= 0) { - ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->iterIndexes()->data()[i]])); - } + Unit unit; + auto input = mTensors[cmd->indexes()->data()[1]]; + std::set buildOptions; + + if(op->main() != nullptr){ + std::string compute = getUnaryComputeOption(cmd->op()->main_as_UnaryOp()->opType()); + buildOptions.emplace("-DUNARY_OPERATOR=" + compute); + } + if (mIter[0] >= 0) { + buildOptions.emplace("-DOFFSET_DST"); + } + if (mIter[1] >= 0) { + buildOptions.emplace("-DOFFSET_SRC"); + } + + unit.kernel = runTime->buildKernel("loop", "batch_gather", buildOptions, mOpenCLBackend->getPrecision(), input, output); + uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); + std::vector mGlobalWorkSize = {(uint32_t)(x * y), (uint32_t)(z), (uint32_t)(n)}; + + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); + ret 
|= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(output)); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(input)); + for (int i = 0; i < cmd->iterIndexes()->size(); ++i) { + if (mIter[i] >= 0) { + ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->iterIndexes()->data()[i]])); } - ret |= unit.kernel->get().setArg(index++, x); - ret |= unit.kernel->get().setArg(index++, sizeof(mStride_src), mStride_src); - ret |= unit.kernel->get().setArg(index++, sizeof(mStride_dst), mStride_dst); - ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); - ret |= unit.kernel->get().setArg(index++, sizeof(mIter), mIter); - ret |= unit.kernel->get().setArg(index++, inputSize); - MNN_CHECK_CL_SUCCESS(ret, "setArg LoopGatherBufExecution"); - - std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, "batch_gather_buf", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "gather_buf").first; - - unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; - mUnits.emplace_back(unit); - mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + } + ret |= unit.kernel->get().setArg(index++, x); + ret |= unit.kernel->get().setArg(index++, iter); + ret |= unit.kernel->get().setArg(index++, sizeof(mStride_src), mStride_src); + ret |= unit.kernel->get().setArg(index++, sizeof(mStride_dst), mStride_dst); + ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); + ret |= unit.kernel->get().setArg(index++, inputSize); + ret |= unit.kernel->get().setArg(index++, outputSize); + MNN_CHECK_CL_SUCCESS(ret, "setArg LoopGatherBufExecution"); + + std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, "batch_gather", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop").first; + + unit.globalWorkSize = 
{mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + mUnits.emplace_back(unit); + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + + if(cmd->fuse() >= 0){ + FuseOutput(cmdIndex, mStride_dst, cmd->size()->data()[0], cmd->size()->data()[1], cmd->size()->data()[2], n, iter); } return NO_ERROR; } -LoopBatchMatMulBufExecution::LoopBatchMatMulBufExecution(const LoopParam *loop, const MNN::Op *op, Backend *bn) -: CommonExecution(bn, op) { - mLoop = loop; - mTensors.resize(mLoop->tensorNumber()); -} - -ErrorCode LoopBatchMatMulBufExecution::onEncode(const std::vector &inputs, const std::vector &outputs) { - auto cmd = mLoop->commands()->GetAs(0); - mHasBias = cmd->indexes()->size() > 3; - mTransposeA = cmd->op()->main_as_MatMul()->transposeA(); - mTransposeB = cmd->op()->main_as_MatMul()->transposeB(); +ErrorCode LoopBufExecution::LoopBatchMatMul(const Tensor *output, int cmdIndex, int iter) { + auto cmd = mLoop->commands()->GetAs(cmdIndex); + bool mHasBias = cmd->indexes()->size() > 3; OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); auto runTime = mOpenCLBackend->getOpenCLRuntime(); - _setTensorStack(mTensors, inputs, outputs, mLoop); + + int mOffset[4]; + int mStep[4]; + int mIter[4]; mOffset[0] = cmd->view()->GetAs(0)->offset(); mOffset[1] = cmd->view()->GetAs(1)->offset(); mOffset[2] = cmd->view()->GetAs(2)->offset(); - mUnits.clear(); if (mHasBias) { mOffset[3] = cmd->view()->GetAs(3)->offset(); } @@ -221,87 +357,200 @@ ErrorCode LoopBatchMatMulBufExecution::onEncode(const std::vector &inp int e = cmd->size()->data()[0]; int l = cmd->size()->data()[1]; int h = cmd->size()->data()[2]; - int n = mLoop->loopNumber(); - - { - // matmul - Unit unit; - std::string KernelName = "batch_matmul"; - std::set buildOptions = mBuildOptions; - if (mHasBias) { - buildOptions.emplace("-DBIAS"); - } - if (mTransposeA) { - 
buildOptions.emplace("-DTRANSPOSE_A"); - } - if (mTransposeB) { - buildOptions.emplace("-DTRANSPOSE_B"); - } - buildOptions.emplace("-DH_LEAVES=" + std::to_string(h % 4)); - unit.kernel = runTime->buildKernel("loop", KernelName, buildOptions, mOpenCLBackend->getPrecision(), mTensors[cmd->indexes()->data()[1]], mTensors[cmd->indexes()->data()[0]]); - uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); - std::vector mGlobalWorkSize = {(uint32_t)(UP_DIV(h, 4)), (uint32_t)(UP_DIV(e, 4)),(uint32_t)(n)}; - - uint32_t index = 0; - cl_int ret = CL_SUCCESS; - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); - ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->indexes()->data()[0]])); - ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->indexes()->data()[1]])); - ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->indexes()->data()[2]])); - if (mHasBias) { - ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->indexes()->data()[3]])); - } - for (int i = 0; i < cmd->iterIndexes()->size(); ++i) { - if (mIter[i] >= 0) { - ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->iterIndexes()->data()[i]])); - } else { - ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->indexes()->data()[1]])); - } - } - ret |= unit.kernel->get().setArg(index++, e); - ret |= unit.kernel->get().setArg(index++, l); - ret |= unit.kernel->get().setArg(index++, h); - ret |= unit.kernel->get().setArg(index++, sizeof(mOffset), mOffset); - ret |= unit.kernel->get().setArg(index++, sizeof(mIter), mIter); - ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); - MNN_CHECK_CL_SUCCESS(ret, "setArg LoopBatchMatMulBufExecution"); - - std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, 
KernelName, unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop").first; - - unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; - mUnits.emplace_back(unit); - mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + int n = mLoop->parallel() ? mLoop->loopNumber() : 1; + // matmul + Unit unit; + std::string KernelName = "batch_matmul"; + std::set buildOptions; + if (mHasBias) { + buildOptions.emplace("-DBIAS"); + } + if (cmd->op()->main_as_MatMul()->transposeA()) { + buildOptions.emplace("-DTRANSPOSE_A"); + } + if (cmd->op()->main_as_MatMul()->transposeB()) { + buildOptions.emplace("-DTRANSPOSE_B"); + } + buildOptions.emplace("-DH_LEAVES=" + std::to_string(h % 4)); + unit.kernel = runTime->buildKernel("loop", KernelName, buildOptions, mOpenCLBackend->getPrecision(), mTensors[cmd->indexes()->data()[1]], mTensors[cmd->indexes()->data()[0]]); + uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); + std::vector mGlobalWorkSize = {(uint32_t)(UP_DIV(h, 4)), (uint32_t)(UP_DIV(e, 4)),(uint32_t)(n)}; + + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(output)); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->indexes()->data()[1]])); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->indexes()->data()[2]])); + if (mHasBias) { + ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->indexes()->data()[3]])); + } + for (int i = 0; i < cmd->iterIndexes()->size(); ++i) { + if (mIter[i] >= 0) { + ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->iterIndexes()->data()[i]])); + } else { + ret |= 
unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->indexes()->data()[1]])); + } + } + ret |= unit.kernel->get().setArg(index++, e); + ret |= unit.kernel->get().setArg(index++, l); + ret |= unit.kernel->get().setArg(index++, h); + ret |= unit.kernel->get().setArg(index++, iter); + ret |= unit.kernel->get().setArg(index++, sizeof(mOffset), mOffset); + ret |= unit.kernel->get().setArg(index++, sizeof(mIter), mIter); + ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); + MNN_CHECK_CL_SUCCESS(ret, "setArg LoopBatchMatMulBufExecution"); + + std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, KernelName, unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop").first; + + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + mUnits.emplace_back(unit); + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + + if(cmd->fuse() >= 0){ + int mStride_dst[4]; + mStride_dst[0] = h * e; + mStride_dst[1] = h; + mStride_dst[2] = 1; + mStride_dst[3] = 1; + FuseOutput(cmdIndex, mStride_dst, 1, e, h, n, iter); } return NO_ERROR; } -LoopBinaryBufExecution::LoopBinaryBufExecution(const LoopParam *loop, const std::string &compute, const MNN::Op *op, Backend *bn) - : CommonExecution(bn, op) { - mLoop = loop; - mTensors.resize(mLoop->tensorNumber()); - mBuildOptions.emplace("-DOPERATOR=" + compute); +ErrorCode LoopBufExecution::LoopBinary(const Tensor *output, int cmdIndex, int iter) { + auto cmd = mLoop->commands()->GetAs(cmdIndex); + std::string compute = getComputeOption(cmd->op()->main_as_BinaryOp()->opType()); + std::set buildOptions; + buildOptions.emplace("-DOPERATOR=" + compute); + if(cmd->op()->main_as_BinaryOp()->opType() == BinaryOpOperation_MOD && (output->getType().code == halide_type_int || output->getType().code == halide_type_uint)){ + buildOptions.emplace("-DINT_COMPUTE_MOD"); 
+ } + OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); + auto runTime = mOpenCLBackend->getOpenCLRuntime(); + + int mOffset[4]; + int mStep[4]; + int mIter[4]; + int mStride_src0[3]; + int mStride_src1[3]; + int mStride_dst[3]; + + Unit unit; + int z = cmd->size()->data()[0]; + int y = cmd->size()->data()[1]; + int x = cmd->size()->data()[2]; + int n = mLoop->parallel() ? mLoop->loopNumber() : 1; + int inputSize = mTensors[cmd->indexes()->data()[1]]->elementSize(); + int outputSize = output->elementSize(); + + auto src0Stride = cmd->view()->GetAs(1)->stride()->data(); + auto src1Stride = cmd->view()->GetAs(2)->stride()->data(); + auto dstStride = cmd->view()->GetAs(0)->stride()->data(); + for (int i = 0; i < 3; ++i) { + mStride_src0[i] = src0Stride[i]; + mStride_src1[i] = src1Stride[i]; + mStride_dst[i] = dstStride[i]; + } + if(cmd->fuse() >= 0){ + mStride_dst[0] = y * x; + mStride_dst[1] = x; + mStride_dst[2] = 1; + } + + auto input0 = mTensors[cmd->indexes()->data()[1]]; + auto input1 = mTensors[cmd->indexes()->data()[2]]; + + ::memcpy(mStep, cmd->steps()->data(), cmd->steps()->size() * sizeof(int)); + ::memcpy(mIter, cmd->iterIndexes()->data(), cmd->iterIndexes()->size() * sizeof(int)); + mOffset[0] = cmd->view()->GetAs(0)->offset(); + mOffset[1] = cmd->view()->GetAs(1)->offset(); + mOffset[2] = cmd->view()->GetAs(2)->offset(); + + if (mIter[0] >= 0) { + buildOptions.emplace("-DOFFSET_DST"); + } + if (mIter[1] >= 0) { + buildOptions.emplace("-DOFFSET_SRC0"); + } + if (mIter[2] >= 0) { + buildOptions.emplace("-DOFFSET_SRC1"); + } + unit.kernel = runTime->buildKernel("loop", "loop_binary", buildOptions, mOpenCLBackend->getPrecision(), input0, output); + uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); + + std::vector mGlobalWorkSize = {(uint32_t)(x), (uint32_t)(y), (uint32_t)(z*n)}; + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); + ret |= 
unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(output)); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(input0)); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(input1)); + for (int i = 0; i < cmd->iterIndexes()->size(); ++i) { + if (mIter[i] >= 0) { + ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->iterIndexes()->data()[i]])); + } + } + ret |= unit.kernel->get().setArg(index++, mStride_src0[0]); + ret |= unit.kernel->get().setArg(index++, mStride_src0[1]); + ret |= unit.kernel->get().setArg(index++, mStride_src0[2]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[0]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[1]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[2]); + ret |= unit.kernel->get().setArg(index++, mStride_dst[0]); + ret |= unit.kernel->get().setArg(index++, mStride_dst[1]); + ret |= unit.kernel->get().setArg(index++, mStride_dst[2]); + ret |= unit.kernel->get().setArg(index++, iter); + ret |= unit.kernel->get().setArg(index++, z); + ret |= unit.kernel->get().setArg(index++, sizeof(mOffset), mOffset); + ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); + ret |= unit.kernel->get().setArg(index++, outputSize); + MNN_CHECK_CL_SUCCESS(ret, "setArg LoopBinaryBufExecution"); + + std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, "loop_binary", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop").first; + + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + mUnits.emplace_back(unit); + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + + if(cmd->fuse() >= 0){ + FuseOutput(cmdIndex, mStride_dst, cmd->size()->data()[0], cmd->size()->data()[1], 
cmd->size()->data()[2], n, iter); + } + return NO_ERROR; } -ErrorCode LoopBinaryBufExecution::onEncode(const std::vector &inputs, const std::vector &outputs) { +ErrorCode LoopBufExecution::LoopCumsum(const Tensor *output) { auto cmd = mLoop->commands()->GetAs(0); - if(cmd->op()->main_as_BinaryOp()->opType() == BinaryOpOperation_MOD && (outputs[0]->getType().code == halide_type_int || outputs[0]->getType().code == halide_type_uint)){ - mBuildOptions.emplace("-DINT_COMPUTE_MOD"); + std::string compute = getComputeOption(cmd->op()->main_as_BinaryOp()->opType()); + std::set buildOptions; + buildOptions.emplace("-DOPERATOR=" + compute); + if(cmd->op()->main_as_BinaryOp()->opType() == BinaryOpOperation_MOD && (output->getType().code == halide_type_int || output->getType().code == halide_type_uint)){ + buildOptions.emplace("-DINT_COMPUTE_MOD"); } OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); auto runTime = mOpenCLBackend->getOpenCLRuntime(); - _setTensorStack(mTensors, inputs, outputs, mLoop); - mUnits.clear(); + + int mOffset[4]; + int mStep[4]; + int mIter[4]; + int mStride_src0[3]; + int mStride_src1[3]; + int mStride_dst[3]; Unit unit; int z = cmd->size()->data()[0]; int y = cmd->size()->data()[1]; int x = cmd->size()->data()[2]; - int n = mLoop->loopNumber(); + int n = mLoop->parallel() ? 
mLoop->loopNumber() : 1; int inputSize = mTensors[cmd->indexes()->data()[1]]->elementSize(); + int outputSize = output->elementSize(); auto src0Stride = cmd->view()->GetAs(1)->stride()->data(); auto src1Stride = cmd->view()->GetAs(2)->stride()->data(); @@ -311,61 +560,27 @@ ErrorCode LoopBinaryBufExecution::onEncode(const std::vector &inputs, mStride_src1[i] = src1Stride[i]; mStride_dst[i] = dstStride[i]; } + if(cmd->fuse() >= 0){ + mStride_dst[0] = y * x; + mStride_dst[1] = x; + mStride_dst[2] = 1; + } auto input0 = mTensors[cmd->indexes()->data()[1]]; auto input1 = mTensors[cmd->indexes()->data()[2]]; - auto output = mTensors[cmd->indexes()->data()[0]]; // cumsum // mTensors cmd->indexes()->data() = {2, 0, 1} -> {output, input0, input1}, output = input0 - if(!mLoop->parallel()){ - int loopNumber = mLoop->loopNumber(); - - ::memcpy(mStep, cmd->steps()->data(), cmd->steps()->size() * sizeof(int)); - mOffset[0] = cmd->view()->GetAs(0)->offset(); - mOffset[1] = cmd->view()->GetAs(1)->offset(); - mOffset[2] = cmd->view()->GetAs(2)->offset(); - unit.kernel = runTime->buildKernel("loop_buf", "loop_cumsum_buf", mBuildOptions, mOpenCLBackend->getPrecision(), input0, output); - uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); - - std::vector mGlobalWorkSize = {(uint32_t)(x), (uint32_t)(y), (uint32_t)(z)}; - uint32_t index = 0; - cl_int ret = CL_SUCCESS; - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); - ret |= unit.kernel->get().setArg(index++, openCLBuffer(output)); - ret |= unit.kernel->get().setArg(index++, openCLBuffer(input0)); - ret |= unit.kernel->get().setArg(index++, openCLBuffer(input1)); - ret |= unit.kernel->get().setArg(index++, mStride_src0[0]); - ret |= unit.kernel->get().setArg(index++, mStride_src0[1]); - ret |= unit.kernel->get().setArg(index++, mStride_src0[2]); - ret |= 
unit.kernel->get().setArg(index++, mStride_src1[0]); - ret |= unit.kernel->get().setArg(index++, mStride_src1[1]); - ret |= unit.kernel->get().setArg(index++, mStride_src1[2]); - ret |= unit.kernel->get().setArg(index++, mStride_dst[0]); - ret |= unit.kernel->get().setArg(index++, mStride_dst[1]); - ret |= unit.kernel->get().setArg(index++, mStride_dst[2]); - ret |= unit.kernel->get().setArg(index++, loopNumber); - ret |= unit.kernel->get().setArg(index++, sizeof(mOffset), mOffset); - ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); - MNN_CHECK_CL_SUCCESS(ret, "setArg LoopCumsumBufExecution"); - - std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, "loop_cumsum_buf", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop_buf").first; - - - unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; - mUnits.emplace_back(unit); - mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); - return NO_ERROR; - } + int loopNumber = mLoop->loopNumber(); - unit.kernel = runTime->buildKernel("loop_buf", "loop_binary_buf", mBuildOptions, mOpenCLBackend->getPrecision(), input0, output); + ::memcpy(mStep, cmd->steps()->data(), cmd->steps()->size() * sizeof(int)); + mOffset[0] = cmd->view()->GetAs(0)->offset(); + mOffset[1] = cmd->view()->GetAs(1)->offset(); + mOffset[2] = cmd->view()->GetAs(2)->offset(); + unit.kernel = runTime->buildKernel("loop", "loop_cumsum", buildOptions, mOpenCLBackend->getPrecision(), input0, output); uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); std::vector mGlobalWorkSize = {(uint32_t)(x), (uint32_t)(y), (uint32_t)(z)}; - uint32_t index = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); @@ -383,9 +598,105 @@ ErrorCode LoopBinaryBufExecution::onEncode(const std::vector &inputs, ret |= 
unit.kernel->get().setArg(index++, mStride_dst[0]); ret |= unit.kernel->get().setArg(index++, mStride_dst[1]); ret |= unit.kernel->get().setArg(index++, mStride_dst[2]); + ret |= unit.kernel->get().setArg(index++, loopNumber); + ret |= unit.kernel->get().setArg(index++, sizeof(mOffset), mOffset); + ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); + ret |= unit.kernel->get().setArg(index++, outputSize); + MNN_CHECK_CL_SUCCESS(ret, "setArg LoopCumsumBufExecution"); + + std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, "loop_cumsum", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop").first; + + + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + mUnits.emplace_back(unit); + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + + return NO_ERROR; +} + +ErrorCode LoopBufExecution::FuseOutput(int iter, int* inputStride, int sizeZ, int sizeY, int SizeX, int n, int n_offset) { + auto cmd = mLoop->commands()->GetAs(iter); + std::string compute = getComputeOption(MNN::BinaryOpOperation(cmd->fuse())); + std::set buildOptions; + buildOptions.emplace("-DOPERATOR=" + compute); + OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); + auto runTime = mOpenCLBackend->getOpenCLRuntime(); + + int mOffset[4]; + int mStep[4]; + int mIter[4]; + int mStride_src0[3]; + int mStride_src1[3]; + int mStride_dst[3]; + auto input = mFuseTensor.get(); + auto output = mTensors[cmd->indexes()->data()[0]]; + int outputSize = output->elementSize(); + + Unit unit; + int z = sizeZ; + int y = sizeY; + int x = SizeX; + + auto dstStride = cmd->view()->GetAs(0)->stride()->data(); + for (int i = 0; i < 3; ++i) { + mStride_src0[i] = dstStride[i]; + mStride_src1[i] = inputStride[i]; + mStride_dst[i] = dstStride[i]; + } + + for(int i = 0; i < 4; ++i){ + mStep[i] = cmd->steps()->data()[0]; + } + 
::memcpy(mIter, cmd->iterIndexes()->data(), cmd->iterIndexes()->size() * sizeof(int)); + mOffset[0] = cmd->view()->GetAs(0)->offset(); + mOffset[1] = cmd->view()->GetAs(0)->offset(); + mOffset[2] = cmd->view()->GetAs(0)->offset(); + + if (mIter[0] >= 0) { + buildOptions.emplace("-DOFFSET_DST"); + } + if (mIter[0] >= 0) { + buildOptions.emplace("-DOFFSET_SRC0"); + } + if (mIter[0] >= 0) { + buildOptions.emplace("-DOFFSET_SRC1"); + } + unit.kernel = runTime->buildKernel("loop", "loop_binary", buildOptions, mOpenCLBackend->getPrecision(), input, output); + uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); + + std::vector mGlobalWorkSize = {(uint32_t)(x), (uint32_t)(y), (uint32_t)(z*n)}; + + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(output)); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(output)); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(input)); + if (mIter[0] >= 0) { + ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->iterIndexes()->data()[0]])); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->iterIndexes()->data()[0]])); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->iterIndexes()->data()[0]])); + } + ret |= unit.kernel->get().setArg(index++, mStride_src0[0]); + ret |= unit.kernel->get().setArg(index++, mStride_src0[1]); + ret |= unit.kernel->get().setArg(index++, mStride_src0[2]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[0]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[1]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[2]); + ret |= unit.kernel->get().setArg(index++, mStride_dst[0]); + ret |= unit.kernel->get().setArg(index++, mStride_dst[1]); + 
ret |= unit.kernel->get().setArg(index++, mStride_dst[2]); + ret |= unit.kernel->get().setArg(index++, n_offset); + ret |= unit.kernel->get().setArg(index++, z); + ret |= unit.kernel->get().setArg(index++, sizeof(mOffset), mOffset); + ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); + ret |= unit.kernel->get().setArg(index++, outputSize); MNN_CHECK_CL_SUCCESS(ret, "setArg LoopBinaryBufExecution"); - std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, "loop_binary_buf", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop_buf").first; + std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, "loop_binary", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop").first; unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; @@ -394,6 +705,67 @@ ErrorCode LoopBinaryBufExecution::onEncode(const std::vector &inputs, return NO_ERROR; } +ErrorCode LoopBufExecution::onEncode(const std::vector &inputs, const std::vector &outputs){ + OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); + auto runTime = mOpenCLBackend->getOpenCLRuntime(); + _setTensorStack(mTensors, inputs, outputs, mLoop); + // Make Temp output buffer + int bufferUnitSize = mOpenCLBackend->getPrecision() != BackendConfig::Precision_High ? sizeof(half_float::half) : sizeof(float); + int mMaxFuseBufferSize = 0; + int loopNumber = mLoop->parallel() ? 
1 : mLoop->loopNumber(); + for (int i=0; icommands()->size(); ++i) { + auto cmd = mLoop->commands()->GetAs(i); + auto op = cmd->op(); + if (cmd->fuse() >= 0) { + // Make Temp output buffer + auto size = cmd->size()->data(); + if (cmd->op()->type() == OpType_MatMul) { + mMaxFuseBufferSize = std::max(mMaxFuseBufferSize, bufferUnitSize * size[0] * size[2]); + } else { + mMaxFuseBufferSize = std::max(mMaxFuseBufferSize, bufferUnitSize * size[0] * size[1] * size[2]); + } + } + } + if(mMaxFuseBufferSize != 0){ + mFuseTensor.reset(Tensor::createDevice({loopNumber * mMaxFuseBufferSize})); + mOpenCLBackend->onAcquireBuffer(mFuseTensor.get(), Backend::DYNAMIC); + mOpenCLBackend->onReleaseBuffer(mFuseTensor.get(), Backend::DYNAMIC); + } + mUnits.clear(); + if(mLoop->initCommand() != nullptr){ + InitCommandOnEncode(); + } + if (1 == mLoop->commands()->size()) { + auto cmd = mLoop->commands()->GetAs(0); + auto op = cmd->op(); + if (OpType_UnaryOp == op->type() && nullptr == op->main() && cmd->fuse() < 0) { + return LoopGather(mTensors[cmd->indexes()->data()[0]], 0, 0); + } + if(OpType_BinaryOp == op->type() && mLoop->parallel() == false && cmd->fuse() < 0){ + return LoopCumsum(mTensors[cmd->indexes()->data()[0]]); + } + } + for(int iter = 0; iter < loopNumber; ++iter){ + for (int index = 0; indexcommands()->size(); ++index) { + auto cmd = mLoop->commands()->GetAs(index); + auto op = cmd->op(); + Tensor *originOutput = mTensors[cmd->indexes()->data()[0]]; + Tensor *output = originOutput; + if(cmd->fuse() >= 0){ + output = mFuseTensor.get(); + } + if (OpType_UnaryOp == op->type()){ + LoopGather(output, index, iter); + }else if (OpType_MatMul == op->type()){ + LoopBatchMatMul(output, index, iter); + }else if(OpType_BinaryOp == op->type()){ + LoopBinary(output, index, iter); + } + } + } + return NO_ERROR; +} + class LoopBufCreator : public OpenCLBackend::Creator { public: virtual Execution *onCreate(const std::vector &inputs, const std::vector &outputs, @@ -408,61 +780,7 @@ class 
LoopBufCreator : public OpenCLBackend::Creator { if (nullptr == loop || loop->commands() == nullptr) { return nullptr; } - // Make Tensor Stack - if (1 == loop->commands()->size()) { - auto cmd = loop->commands()->GetAs(0); - auto subop = cmd->op(); - if (OpType_UnaryOp == subop->type() && nullptr == subop->main() && cmd->fuse() < 0) { - return new LoopGatherBufExecution(loop, op, backend); - } - if (OpType_MatMul == subop->type() && loop->parallel() && nullptr == loop->initCommand()) { - return new LoopBatchMatMulBufExecution(loop, op, backend); - } - if (OpType_BinaryOp == subop->type() && nullptr == loop->initCommand()) { - switch (subop->main_as_BinaryOp()->opType()) { - case BinaryOpOperation_MUL: - return new LoopBinaryBufExecution(loop, "in0*in1", op, backend); - case BinaryOpOperation_ADD: - return new LoopBinaryBufExecution(loop, "in0+in1", op, backend); - case BinaryOpOperation_SUB: - return new LoopBinaryBufExecution(loop, "in0-in1", op, backend); - case BinaryOpOperation_REALDIV: - return new LoopBinaryBufExecution(loop, "sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001))", op, backend); - case BinaryOpOperation_MINIMUM: - return new LoopBinaryBufExecution(loop, "in0>in1?in1:in0", op, backend); - case BinaryOpOperation_MAXIMUM: - return new LoopBinaryBufExecution(loop, "in0>in1?in0:in1", op, backend); - case BinaryOpOperation_GREATER: - return new LoopBinaryBufExecution(loop, "(float)(isgreater(in0,in1))", op, backend); - case BinaryOpOperation_LESS: - return new LoopBinaryBufExecution(loop, "(float)(isless(in0,in1))", op, backend); - case BinaryOpOperation_LESS_EQUAL: - return new LoopBinaryBufExecution(loop, "(float)(islessequal(in0,in1))", op, backend); - case BinaryOpOperation_GREATER_EQUAL: - return new LoopBinaryBufExecution(loop, "(float)(isgreaterequal(in0,in1))", op, backend); - case BinaryOpOperation_EQUAL: - return new LoopBinaryBufExecution(loop, "(float)(isequal(in0,in1))", op, backend); - case 
BinaryOpOperation_FLOORDIV: - return new LoopBinaryBufExecution(loop, "floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))", op, backend); - case BinaryOpOperation_FLOORMOD: - return new LoopBinaryBufExecution(loop, "in0-floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))*in1", op, backend); - case BinaryOpOperation_POW: - return new LoopBinaryBufExecution(loop, "pow(in0,in1)", op, backend); - case BinaryOpOperation_SquaredDifference: - return new LoopBinaryBufExecution(loop, "(in0-in1)*(in0-in1)", op, backend); - case BinaryOpOperation_ATAN2: - return new LoopBinaryBufExecution(loop, "(in1==(float)0?(sign(in0)*(float)(PI/2)):(atan(in0/in1)+(in1>(float)0?(float)0:sign(in0)*(float)PI)))", op, backend); - case BinaryOpOperation_NOTEQUAL: - return new LoopBinaryBufExecution(loop, "(float)(isnotequal(in0,in1))", op, backend); - case BinaryOpOperation_MOD: - return new LoopBinaryBufExecution(loop, "in0-floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))*in1", op, backend); - default: - break; - } - return nullptr; - } - } - return nullptr; + return new LoopBufExecution(loop, op, backend); } }; diff --git a/source/backend/opencl/execution/buffer/LoopBufExecution.hpp b/source/backend/opencl/execution/buffer/LoopBufExecution.hpp index 82665e9264..b033c3e6d3 100644 --- a/source/backend/opencl/execution/buffer/LoopBufExecution.hpp +++ b/source/backend/opencl/execution/buffer/LoopBufExecution.hpp @@ -16,60 +16,21 @@ namespace MNN { namespace OpenCL { -class LoopGatherBufExecution : public CommonExecution { +class LoopBufExecution : public CommonExecution{ public: - LoopGatherBufExecution(const LoopParam *loop, const MNN::Op *op, Backend *bn); - virtual ~LoopGatherBufExecution() = default; + LoopBufExecution(const LoopParam *loop, const MNN::Op *op, Backend *bn); + virtual ~LoopBufExecution() = default; virtual ErrorCode onEncode(const std::vector 
&inputs, const std::vector &outputs) override; - ErrorCode InitCommandOnEncode(const std::vector &inputs, const std::vector &outputs); - -private: - const LoopParam *mLoop; - std::vector mTensors; - std::vector> mTmpTensors; - std::vector> mOffsetTensors; - int mStride_src[4]; - int mStride_dst[4]; - int mStep[2]; - int mIter[2]; - std::set mBuildOptions; -}; - -class LoopBatchMatMulBufExecution : public CommonExecution { -public: - LoopBatchMatMulBufExecution(const LoopParam *loop, const MNN::Op *op, Backend *bn); - virtual ~LoopBatchMatMulBufExecution() = default; - virtual ErrorCode onEncode(const std::vector &inputs, const std::vector &outputs) override; - - -private: - const LoopParam *mLoop; - std::vector mTensors; - int mOffset[4]; - int mStep[4]; - int mIter[4]; - bool mHasBias = false; - bool mTransposeA = false; - bool mTransposeB = false; - std::set mBuildOptions; -}; - - -class LoopBinaryBufExecution : public CommonExecution { -public: - LoopBinaryBufExecution(const LoopParam *loop, const std::string &compute, const MNN::Op *op, Backend *bn); - virtual ~LoopBinaryBufExecution() = default; - virtual ErrorCode onEncode(const std::vector &inputs, const std::vector &outputs) override; - + ErrorCode InitCommandOnEncode(); + ErrorCode LoopGather(const Tensor *output, int cmdIndex, int iter); + ErrorCode LoopBatchMatMul(const Tensor *output, int cmdIndex, int iter); + ErrorCode LoopBinary(const Tensor *outputs, int cmdIndex, int iter); + ErrorCode LoopCumsum(const Tensor *output); + ErrorCode FuseOutput(int iter, int* inputStride, int sizeZ, int sizeY, int SizeX, int n, int n_offset); private: const LoopParam *mLoop; std::vector mTensors; - std::set mBuildOptions; - int mOffset[4]; - int mStep[4]; - int mStride_src0[3]; - int mStride_src1[3]; - int mStride_dst[3]; + std::shared_ptr mFuseTensor; }; } // namespace OpenCL diff --git a/source/backend/opencl/execution/buffer/UnaryBufExecution.cpp b/source/backend/opencl/execution/buffer/UnaryBufExecution.cpp index 
bed7966ae9..c653da3e62 100644 --- a/source/backend/opencl/execution/buffer/UnaryBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/UnaryBufExecution.cpp @@ -254,8 +254,7 @@ class UnaryBufCreator : public OpenCLBackend::Creator { return new UnaryBufExecution("gelu(convert_float4(in))", op, backend); case UnaryOpOperation_GELU_STANDARD: return new UnaryBufExecution("(erf(convert_float4(in)*(float4)0.7071067932881648)+(float4)1.0)*convert_float4(in)*(float4)0.5", op, backend); - - default: + default: break; } return nullptr; diff --git a/source/backend/opencl/execution/cl/gather_buf.cl b/source/backend/opencl/execution/cl/gather_buf.cl deleted file mode 100644 index ce41b9de89..0000000000 --- a/source/backend/opencl/execution/cl/gather_buf.cl +++ /dev/null @@ -1,46 +0,0 @@ -#ifdef MNN_SUPPORT_FP16 -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#endif - -__kernel void batch_gather_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2, - __global OUTPUT_TYPE* output, __global INPUT_TYPE* input, - #ifdef OFFSET_DST - __global int* offset_dst_ptr, - #endif - #ifdef OFFSET_SRC - __global int* offset_src_ptr, - #endif - __private const int x_size, - __private const int4 stride_src, - __private const int4 stride_dst, - __private const int2 steps, - __private const int2 iters, - __private const int inputSize) { - int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2)); - - if (pos.x < global_dim0 && pos.y < global_dim1 && pos.z < global_dim2) { - - int x = pos.x % x_size; - int y = pos.x / x_size; - - int2 index = (int2)(pos.z, pos.z); -#ifdef OFFSET_DST - index.x = offset_dst_ptr[pos.z]; -#endif - -#ifdef OFFSET_SRC - index.y = offset_src_ptr[pos.z]; -#endif - int2 offset = index * steps; - int src_offset = offset.y + stride_src.w + x * stride_src.x + y * stride_src.y + pos.y * stride_src.z; - int dst_offset = offset.x + stride_dst.w + x * stride_dst.x + y * stride_dst.y + pos.y * stride_dst.z; - - if(offset.x >= 0){ 
- if(offset.y >= 0 && offset.y < inputSize){ - output[dst_offset] = (OUTPUT_TYPE)input[src_offset]; - }else{ - output[dst_offset] = (OUTPUT_TYPE)(0); - } - } - } -} diff --git a/source/backend/opencl/execution/cl/gather_buf_mnn_cl.cpp b/source/backend/opencl/execution/cl/gather_buf_mnn_cl.cpp deleted file mode 100644 index c7d3543ac6..0000000000 --- a/source/backend/opencl/execution/cl/gather_buf_mnn_cl.cpp +++ /dev/null @@ -1,50 +0,0 @@ -#include "opencl_source_map.hpp" -namespace MNN { -#ifndef MNN_OPENCL_BUFFER_CLOSED -const char* gather_buf = -"#ifdef MNN_SUPPORT_FP16\n" -"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" -"#endif\n" -"__kernel void batch_gather_buf(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n" -" __global OUTPUT_TYPE* output,__global INPUT_TYPE* input,\n" -" #ifdef OFFSET_DST\n" -" __global int* offset_dst_ptr,\n" -" #endif\n" -" #ifdef OFFSET_SRC\n" -" __global int* offset_src_ptr,\n" -" #endif\n" -" __private const int x_size,\n" -" __private const int4 stride_src,\n" -" __private const int4 stride_dst,\n" -" __private const int2 steps,\n" -" __private const int2 iters,\n" -" __private const int inputSize) {\n" -" int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n" -" \n" -" if (pos.x= 0){\n" -" if(offset.y >= 0 && offset.y= global_size_dim0 || input2 >= global_size_dim1) { \ + return; \ + } +__kernel void set_zero( + GLOBAL_SIZE_2_DIMS + __global OUTPUT_TYPE *output + ) { + const int x = get_global_id(0); + const int y = get_global_id(1); + + DEAL_NON_UNIFORM_DIM2(x, y); + + output[y*global_size_dim0 + x] = (OUTPUT_TYPE)(0); +} __kernel void batch_matmul(__private int global_dim0, __private int global_dim1, __private int global_dim2, __global FLOAT* output, __global FLOAT* input_A, __global FLOAT* input_B, @@ -15,7 +33,7 @@ __kernel void batch_matmul(__private int global_dim0, __private int global_dim1, #endif __private const int e, __private const int l, - __private const int h, + 
__private const int h,__private const int iter, __private const int4 offsets, __private const int4 iters, __private const int4 steps) { @@ -23,6 +41,7 @@ __kernel void batch_matmul(__private int global_dim0, __private int global_dim1, if (pos.x < global_dim0 && pos.y < global_dim1 && pos.z < global_dim2) { pos.x <<= 2; pos.y <<= 2; + pos.z += iter; int4 index = (int4)(pos.z); if (iters.x >= 0) { index.x = offset_O[pos.z]; @@ -284,6 +303,9 @@ __kernel void pack(__private int global_dim0, __private int global_dim1, __priva } } +#ifndef UNARY_OPERATOR + #define UNARY_OPERATOR in +#endif __kernel void batch_gather(__private int global_dim0, __private int global_dim1, __private int global_dim2, __global OUTPUT_TYPE* output, __global INPUT_TYPE* input, #ifdef OFFSET_DST @@ -293,10 +315,12 @@ __kernel void batch_gather(__private int global_dim0, __private int global_dim1, __global int* offset_src_ptr, #endif __private const int x_size, + __private const int iter, __private const int4 stride_src, __private const int4 stride_dst, __private const int2 steps, - __private const int inputSize) { + __private const int inputSize, + __private const int outputSize) { int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2)); if (pos.x < global_dim0 && pos.y < global_dim1 && pos.z < global_dim2) { @@ -304,6 +328,7 @@ __kernel void batch_gather(__private int global_dim0, __private int global_dim1, int x = pos.x % x_size; int y = pos.x / x_size; + pos.z += iter; int2 index = (int2)(pos.z, pos.z); #ifdef OFFSET_DST index.x = offset_dst_ptr[pos.z]; @@ -313,11 +338,13 @@ __kernel void batch_gather(__private int global_dim0, __private int global_dim1, index.y = offset_src_ptr[pos.z]; #endif int2 offset = index * steps; - if(offset.x >= 0){ + int outputIndex = offset.x + stride_dst.w + x * stride_dst.x + y * stride_dst.y + pos.y * stride_dst.z; + if(outputIndex < outputSize && offset.x >= 0){ if(offset.y >= 0 && offset.y < inputSize){ - output[offset.x + stride_dst.w + x * 
stride_dst.x + y * stride_dst.y + pos.y * stride_dst.z] = (OUTPUT_TYPE)input[offset.y + stride_src.w + x * stride_src.x + y * stride_src.y + pos.y * stride_src.z]; + INPUT_TYPE in = input[offset.y + stride_src.w + x * stride_src.x + y * stride_src.y + pos.y * stride_src.z]; + output[outputIndex] = (OUTPUT_TYPE)(UNARY_OPERATOR); }else{ - output[offset.x + stride_dst.w + x * stride_dst.x + y * stride_dst.y + pos.y * stride_dst.z] = (OUTPUT_TYPE)(0); + output[outputIndex] = (OUTPUT_TYPE)(0); } } } @@ -326,108 +353,72 @@ __kernel void batch_gather(__private int global_dim0, __private int global_dim1, #ifndef OPERATOR #define OPERATOR in0 + in1 #endif -__kernel void broadcast_binary(__private int global_dim0, __private int global_dim1, __private int global_dim2, - __write_only image2d_t output, __read_only image2d_t input0, __read_only image2d_t input1, - __private const int8 src0_size, //(batch, channel, height, width) - __private const int4 src0C4_size, // nc4hw4 - __private const int8 src1_size, - __private const int4 src1C4_size, - __private const int8 dst_size, - __private const int dst_width, - __private const int dst_height, - __private const int dst_channel, - __private const int channel_block) { - int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2)); +__kernel void loop_binary(__private int global_dim0, __private int global_dim1, __private int global_dim2, + __global OUTPUT_TYPE* output, __global INPUT_TYPE* input0, __global INPUT_TYPE* input1, + #ifdef OFFSET_DST + __global int* offset_dst_ptr, + #endif + #ifdef OFFSET_SRC0 + __global int* offset_src0_ptr, + #endif + #ifdef OFFSET_SRC1 + __global int* offset_src1_ptr, + #endif + __private const int input0Stride0, + __private const int input0Stride1, + __private const int input0Stride2, + __private const int input1Stride0, + __private const int input1Stride1, + __private const int input1Stride2, + __private const int outputStride0, + __private const int outputStride1, + __private const int 
outputStride2, + __private const int iter, + __private const int zSize, + __private const int4 offsets, + __private const int4 steps, + __private const int outputSize + ) { + + const int x = get_global_id(0); + const int y = get_global_id(1); + const int zn = get_global_id(2); - if (pos.x < global_dim0 && pos.y < global_dim1 && pos.z < global_dim2) { + if (x < global_dim0 && y < global_dim1 && zn < global_dim2) { - const int wo = pos.x; - const int ho = pos.y; - const int co = pos.z % channel_block; - const int no = pos.z / channel_block; - int co4 = co << 2; - int4 covec = (int4)(co4 % dst_channel, (co4 + 1) % dst_channel, (co4 + 2) % dst_channel, (co4 + 3) % dst_channel); - int4 out_offset = ((no * dst_channel + covec) * dst_height + ho) * dst_width + wo; - int4 w = out_offset % (dst_size.s3 * dst_size.s4); out_offset /= (dst_size.s3 * dst_size.s4); - int4 h = out_offset % dst_size.s2; out_offset /= dst_size.s2; - int4 c = out_offset % dst_size.s1; out_offset /= dst_size.s1; - int4 n = out_offset % dst_size.s0; - #ifdef INT_COMPUTE_MOD - int4 in0, in1; - int* in0_ptr = (int*)&in0; - int* in1_ptr = (int*)&in1; - #else - float4 in0, in1; - float* in0_ptr = (float*)&in0; - float* in1_ptr = (float*)&in1; + int z = zn % zSize; + int n = zn / zSize; + n += iter; + int4 index = (int4)(n, n, n, n); + #ifdef OFFSET_DST + index.x = offset_dst_ptr[n]; + #endif + + #ifdef OFFSET_SRC0 + index.y = offset_src0_ptr[n]; #endif - { - int4 w0 = w % (src0_size.s3 * src0_size.s4); - int4 h0 = h % src0_size.s2; - int4 c0 = c % src0_size.s1; - int4 n0 = n % src0_size.s0; - int* w0_ptr = (int*)&w0; - int* h0_ptr = (int*)&h0; - int* c0_ptr = (int*)&c0; - int* n0_ptr = (int*)&n0; - for(int i = 0; i < 4; ++i){ - int c4offset = ((n0_ptr[i] * src0_size.s1 + c0_ptr[i]) * src0_size.s2 + h0_ptr[i]) * src0_size.s3 * src0_size.s4 + w0_ptr[i]; - int wc4 = c4offset % src0C4_size.x; c4offset /= src0C4_size.x; - int hc4 = c4offset % src0C4_size.y; c4offset /= src0C4_size.y; - int cc4 = c4offset % 
src0C4_size.z; c4offset /= src0C4_size.z; - int nc4 = c4offset % src0C4_size.w; - int cc4_offset = cc4 / 4; - int cc4_remain = cc4 % 4; - #ifdef INT_COMPUTE_MOD - int4 tmp = convert_int4(RI_DATA(input0, SAMPLER, (int2)(cc4_offset * src0C4_size.x + wc4, nc4 * src0C4_size.y + hc4))); - int *tmp_ptr = (int*)&tmp; - in0_ptr[i] = tmp_ptr[cc4_remain]; - #else - float4 tmp = convert_float4(RI_DATA(input0, SAMPLER, (int2)(cc4_offset * src0C4_size.x + wc4, nc4 * src0C4_size.y + hc4))); - float *tmp_ptr = (float*)&tmp; - in0_ptr[i] = tmp_ptr[cc4_remain]; - #endif - } - } - - { - int4 w0 = w % (src1_size.s3 * src1_size.s4); - int4 h0 = h % src1_size.s2; - int4 c0 = c % src1_size.s1; - int4 n0 = n % src1_size.s0; - int* w0_ptr = (int*)&w0; - int* h0_ptr = (int*)&h0; - int* c0_ptr = (int*)&c0; - int* n0_ptr = (int*)&n0; - for(int i = 0; i < 4; ++i){ - int c4offset = ((n0_ptr[i] * src1_size.s1 + c0_ptr[i]) * src1_size.s2 + h0_ptr[i]) * src1_size.s3 * src1_size.s4 + w0_ptr[i]; - int wc4 = c4offset % src1C4_size.x; c4offset /= src1C4_size.x; - int hc4 = c4offset % src1C4_size.y; c4offset /= src1C4_size.y; - int cc4 = c4offset % src1C4_size.z; c4offset /= src1C4_size.z; - int nc4 = c4offset % src1C4_size.w; - int cc4_offset = cc4 / 4; - int cc4_remain = cc4 % 4; - #ifdef INT_COMPUTE_MOD - int4 tmp = convert_int4(RI_DATA(input1, SAMPLER, (int2)(cc4_offset * src1C4_size.x + wc4, nc4 * src1C4_size.y + hc4))); - int *tmp_ptr = (int*)&tmp; - in1_ptr[i] = tmp_ptr[cc4_remain]; - #else - float4 tmp = convert_float4(RI_DATA(input1, SAMPLER, (int2)(cc4_offset * src1C4_size.x + wc4, nc4 * src1C4_size.y + hc4))); - float *tmp_ptr = (float*)&tmp; - in1_ptr[i] = tmp_ptr[cc4_remain]; - #endif - } - } + #ifdef OFFSET_SRC1 + index.z = offset_src1_ptr[n]; + #endif + int4 offset = index * steps + offsets; + int inputIndex0 = offset.y + z * input0Stride0 + y * input0Stride1 + x * input0Stride2; + int inputIndex1 = offset.z + z * input1Stride0 + y * input1Stride1 + x * input1Stride2; + int outputIndex 
= offset.x + z * outputStride0 + y * outputStride1 + x * outputStride2; #ifdef INT_COMPUTE_MOD - int4 out = in0 % in1; - out = ((out < (int4)0 && in1 > (int4)0) || (out > (int4)0 && in1 < (int4)0)) ? out + in1 : out; + int in0 = (int)input0[inputIndex0]; + int in1 = (int)input1[inputIndex1]; + int out = in0 % in1; + out = ((out < 0 && in1 > 0) || (out > 0 && in1 < 0)) ? out + in1 : out; #else - float4 out = OPERATOR; + float in0 = (float)input0[inputIndex0]; + float in1 = (float)input1[inputIndex1]; + float out = OPERATOR; #endif - - WI_DATA(output, (int2)(co * dst_width + wo, no * dst_height + ho), CONVERT_OUTPUT_I4(out)); + if(outputIndex < outputSize){ + output[outputIndex] = (OUTPUT_TYPE)out; + } } } @@ -444,7 +435,8 @@ __kernel void loop_cumsum(__private int global_dim0, __private int global_dim1, __private const int outputStride2, __private const int loopNumber, __private const int4 offsets, - __private const int4 steps + __private const int4 steps, + __private const int outputSize ) { const int x = get_global_id(0); @@ -457,20 +449,20 @@ __kernel void loop_cumsum(__private int global_dim0, __private int global_dim1, int inputIndex1 = z * input1Stride0 + y * input1Stride1 + x * input1Stride2; int outputIndex = z * outputStride0 + y * outputStride1 + x * outputStride2; - float4 in0 = 0; + float in0 = 0; if(offsets.z != offsets.y){ - in0.x = (float)input0[inputIndex0]; + in0 = (float)input0[inputIndex0]; } for(int i = 0; i < loopNumber; ++i){ int4 offset = (int4)i * steps + offsets; - float4 in1; - in1.x = (float)input1[inputIndex1 + offset.z]; - float4 out = OPERATOR; + float in1 = (float)input1[inputIndex1 + offset.z]; + float out = OPERATOR; - output[outputIndex + offset.x] = (OUTPUT_TYPE)out.x; - in0.x = out.x; + if(outputIndex + offset.x < outputSize){ + output[outputIndex + offset.x] = (OUTPUT_TYPE)out; + } + in0 = out; } } } - diff --git a/source/backend/opencl/execution/cl/loop_buf.cl b/source/backend/opencl/execution/cl/loop_buf.cl deleted file mode 
100644 index 7239281325..0000000000 --- a/source/backend/opencl/execution/cl/loop_buf.cl +++ /dev/null @@ -1,469 +0,0 @@ -#ifdef MNN_SUPPORT_FP16 -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#endif -#define PI 3.141592653589f -#ifndef WGSW - #define WGSW 32 // work-group handle size W dimension -#endif -#ifndef WGSC - #define WGSC 32 // work-group handle size C dimension -#endif -#ifndef WGSH - #define WGSH 32 // work-group handle size H dimension -#endif -#ifndef TSW - #define TSW 8 // thread handle size W dimension -#endif -#ifndef TSC - #define TSC 8 // thread handle size C dimension -#endif -#ifndef TSH - #define TSH 8 // thread handle size H dimension -#endif - -// [C4 N H 1 4] -> [N H C 1] -__kernel void tile_trans_3d_buf(__global INPUT_TYPE* input, - __global OUTPUT_TYPE* output, - __private const int widthPad, - __private const int heightPad, - __private const int channelPad, - __private const int batch, - __private const int width, - __private const int height, - __private const int channel -) { - int b = get_global_id(2); - - const int lidc = get_local_id(0); - const int lidh = get_local_id(1); - // group id - const int c = get_group_id(0) * WGSC; - const int h = get_group_id(1) * WGSH; - - int jc = lidc; - int ih = lidh; - - __local INPUT_TYPE4 localData[WGSH][WGSC/4];//h64c64 - - #pragma unroll - for(int i = 0; i < TSH; i++) { - #pragma unroll - for(int j = 0; j < TSC / 4; j++) { - int offset_h = i * WGSH / TSH + ih; - int offset_c = j * WGSC / TSC + jc ; - // [TSH, WGSH / TSH] [TSC / 4, WGSC / TSC, 4] - localData[offset_h][offset_c] = (h + offset_h >= height || c + 4 * offset_c >= channel) ? 
(INPUT_TYPE4)0 : vload4(0, input + ((b + (c/4+offset_c)*batch) * height + (h+offset_h)) * 4); - } - } - - barrier(CLK_LOCAL_MEM_FENCE); - - // C offset: [WGSC / TSC, TSC / 4] - // H offset: [WGSH / TSH, TSH] - int oc_base = jc * TSC / 4; - int oh_base = ih * TSH; - - //#pragma unroll - for(int i = 0; i < TSH; i++) { - int oh = oh_base + i; - - //#pragma unroll - for(int j = 0; j < TSC / 4; j++) { - int oc = oc_base + j; - - OUTPUT_TYPE4 value = CONVERT_OUTPUT4(localData[oh][oc]); - - vstore4(value, 0, output + ((b * heightPad + h + oh) * channelPad + c + 4 * oc)); - } - } -} -// [C4 N H W 4] -> [N C W H] -__kernel void tile_trans_4d_buf(__global INPUT_TYPE* input, - __global OUTPUT_TYPE* output, - __private const int widthPad, - __private const int heightPad, - __private const int channelPad, - __private const int batch, - __private const int width, - __private const int height, - __private const int channel -) { - int bc = get_global_id(2); - int b = bc % batch; - int c4 = bc / batch; - int c = c4 << 2; - - const int lidw = get_local_id(0); - const int lidh = get_local_id(1); - // group id - const int w = get_group_id(0) * WGSW; - const int h = get_group_id(1) * WGSH; - - int jw = lidw; - int ih = lidh; - - __local INPUT_TYPE4 localData[WGSH][WGSW];//w32h32c4 - - #pragma unroll - for(int i = 0; i < TSH; i++) { - #pragma unroll - for(int j = 0; j < TSW; j++) { - int offset_h = h + ih + i * WGSH/TSH; - int offset_w = w + jw + j * WGSW/TSW; - localData[ih + i * WGSH / TSH][jw + j * WGSW/TSW] = (offset_h >= height || offset_w >= width) ? 
(INPUT_TYPE4)0 : vload4(0, input + (((b + c4 * batch) * height + offset_h) * width + offset_w) * 4); - } - } - - barrier(CLK_LOCAL_MEM_FENCE); - - // c4w32h32 - int oh = ih * TSH >> 4; - int mh = ih & (16 / TSH - 1); - // TSW offset: [TSH / 4, TSW / 4, 16 / TSH] - int ow_base = jw * TSW; - int oh_offset = oh << 4; - - //#pragma unroll - for(int i = 0; i < TSH / 4; i++) { - //#pragma unroll - for(int j = 0; j < TSW / 4; j++) { - - // c4 - OUTPUT_TYPE16 value; - int ow = ow_base + (((i * TSW / 4) + j) * (16 / TSH) + mh); - - value.s0 = localData[0+oh_offset][ow].s0; - value.s1 = localData[1+oh_offset][ow].s0; - value.s2 = localData[2+oh_offset][ow].s0; - value.s3 = localData[3+oh_offset][ow].s0; - value.s4 = localData[4+oh_offset][ow].s0; - value.s5 = localData[5+oh_offset][ow].s0; - value.s6 = localData[6+oh_offset][ow].s0; - value.s7 = localData[7+oh_offset][ow].s0; - value.s8 = localData[8+oh_offset][ow].s0; - value.s9 = localData[9+oh_offset][ow].s0; - value.sa = localData[10+oh_offset][ow].s0; - value.sb = localData[11+oh_offset][ow].s0; - value.sc = localData[12+oh_offset][ow].s0; - value.sd = localData[13+oh_offset][ow].s0; - value.se = localData[14+oh_offset][ow].s0; - value.sf = localData[15+oh_offset][ow].s0; - vstore16(value, 0, output + (((b * channelPad + c + 0) * widthPad + w + ow) * heightPad + h + oh_offset)); - - if(c + 1 < channel) { - value.s0 = localData[0+oh_offset][ow].s1; - value.s1 = localData[1+oh_offset][ow].s1; - value.s2 = localData[2+oh_offset][ow].s1; - value.s3 = localData[3+oh_offset][ow].s1; - value.s4 = localData[4+oh_offset][ow].s1; - value.s5 = localData[5+oh_offset][ow].s1; - value.s6 = localData[6+oh_offset][ow].s1; - value.s7 = localData[7+oh_offset][ow].s1; - value.s8 = localData[8+oh_offset][ow].s1; - value.s9 = localData[9+oh_offset][ow].s1; - value.sa = localData[10+oh_offset][ow].s1; - value.sb = localData[11+oh_offset][ow].s1; - value.sc = localData[12+oh_offset][ow].s1; - value.sd = localData[13+oh_offset][ow].s1; - 
value.se = localData[14+oh_offset][ow].s1; - value.sf = localData[15+oh_offset][ow].s1; - vstore16(value, 0, output + (((b * channelPad + c + 1) * widthPad + w + ow) * heightPad + h + oh_offset)); - } - - if(c + 2 < channel) { - value.s0 = localData[0+oh_offset][ow].s2; - value.s1 = localData[1+oh_offset][ow].s2; - value.s2 = localData[2+oh_offset][ow].s2; - value.s3 = localData[3+oh_offset][ow].s2; - value.s4 = localData[4+oh_offset][ow].s2; - value.s5 = localData[5+oh_offset][ow].s2; - value.s6 = localData[6+oh_offset][ow].s2; - value.s7 = localData[7+oh_offset][ow].s2; - value.s8 = localData[8+oh_offset][ow].s2; - value.s9 = localData[9+oh_offset][ow].s2; - value.sa = localData[10+oh_offset][ow].s2; - value.sb = localData[11+oh_offset][ow].s2; - value.sc = localData[12+oh_offset][ow].s2; - value.sd = localData[13+oh_offset][ow].s2; - value.se = localData[14+oh_offset][ow].s2; - value.sf = localData[15+oh_offset][ow].s2; - vstore16(value, 0, output + (((b * channelPad + c + 2) * widthPad + w + ow) * heightPad + h + oh_offset)); - } - - if(c + 3 < channel) { - value.s0 = localData[0+oh_offset][ow].s3; - value.s1 = localData[1+oh_offset][ow].s3; - value.s2 = localData[2+oh_offset][ow].s3; - value.s3 = localData[3+oh_offset][ow].s3; - value.s4 = localData[4+oh_offset][ow].s3; - value.s5 = localData[5+oh_offset][ow].s3; - value.s6 = localData[6+oh_offset][ow].s3; - value.s7 = localData[7+oh_offset][ow].s3; - value.s8 = localData[8+oh_offset][ow].s3; - value.s9 = localData[9+oh_offset][ow].s3; - value.sa = localData[10+oh_offset][ow].s3; - value.sb = localData[11+oh_offset][ow].s3; - value.sc = localData[12+oh_offset][ow].s3; - value.sd = localData[13+oh_offset][ow].s3; - value.se = localData[14+oh_offset][ow].s3; - value.sf = localData[15+oh_offset][ow].s3; - vstore16(value, 0, output + (((b * channelPad + c + 3) * widthPad + w + ow) * heightPad + h + oh_offset)); - } - } - } -} - -__kernel void tile_buf(__private int global_dim0, __private int global_dim1, __private 
int global_dim2, - __global INPUT_TYPE* input, __global OUTPUT_TYPE* output, - __private const int widthPad, - __private const int heightPad, - __private const int channelPad, - __private const int batch, - __private const int width, - __private const int height, - __private const int channel){ - int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2)); - if (pos.x < global_dim0 && pos.y < global_dim1 && pos.z < global_dim2) { - const int b = pos.z % batch; - const int w = pos.x; - const int h = pos.y; - const int c_4 = pos.z / batch; - - const int c = c_4 << 2; - const int x_src_pitch = 4; - const int y_src_pitch = x_src_pitch * width; - const int b_src_pitch = y_src_pitch * height; - const int c_src_pitch = b_src_pitch * batch; - - bool outBound = (w >= width || h >= height || c >= channel); -#ifdef MNN_NHWC - #if defined(DIMENSION_3) && defined(TRANSPOSE) - // [N, W, H, 1] - const int c_dst_pitch = 1; - const int y_dst_pitch = c_dst_pitch * channelPad; - const int x_dst_pitch = y_dst_pitch * heightPad; - const int b_dst_pitch = x_dst_pitch * widthPad; - OUTPUT_TYPE4 value = outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0, input + b * b_src_pitch + c_4 * c_src_pitch + h * y_src_pitch + w * x_src_pitch)); - #elif defined(DIMENSION_4) && defined(TRANSPOSE) - // [N, H, C, W] - const int x_dst_pitch = 1; - const int c_dst_pitch = x_dst_pitch * widthPad; - const int y_dst_pitch = c_dst_pitch * channelPad; - const int b_dst_pitch = y_dst_pitch * heightPad; - OUTPUT_TYPE4 value = outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0, input + b * b_src_pitch + c_4 * c_src_pitch + h * y_src_pitch + w * x_src_pitch)); - #elif defined(DIMENSION_3) - // [N, H, W, 1] - const int c_dst_pitch = 1; - const int x_dst_pitch = c_dst_pitch * channelPad; - const int y_dst_pitch = x_dst_pitch * widthPad; - const int b_dst_pitch = y_dst_pitch * heightPad; - OUTPUT_TYPE4 value = outBound ? 
(OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0, input + b * b_src_pitch + c_4 * c_src_pitch + h * y_src_pitch + w * x_src_pitch)); - #else - // [N, H, W, C] - const int c_dst_pitch = 1; - const int x_dst_pitch = c_dst_pitch * channelPad; - const int y_dst_pitch = x_dst_pitch * widthPad; - const int b_dst_pitch = y_dst_pitch * heightPad; - OUTPUT_TYPE4 value = outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0, input + b * b_src_pitch + c_4 * c_src_pitch + h * y_src_pitch + w * x_src_pitch)); - #endif -#else - #if defined(DIMENSION_3) && defined(TRANSPOSE) - // [N, H, C, 1] - const int x_dst_pitch = 1; - const int c_dst_pitch = x_dst_pitch * widthPad; - const int y_dst_pitch = c_dst_pitch * channelPad; - const int b_dst_pitch = y_dst_pitch * heightPad; - OUTPUT_TYPE4 value = outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0, input + b * b_src_pitch + c_4 * c_src_pitch + h * y_src_pitch + w * x_src_pitch)); - - #elif defined(DIMENSION_4) && defined(TRANSPOSE) - // [N, C, W, H] - const int y_dst_pitch = 1; - const int x_dst_pitch = y_dst_pitch * heightPad; - const int c_dst_pitch = x_dst_pitch * widthPad; - const int b_dst_pitch = c_dst_pitch * channelPad; - OUTPUT_TYPE4 value = outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0, input + b * b_src_pitch + c_4 * c_src_pitch + h * y_src_pitch + w * x_src_pitch)); - #elif defined(DIMENSION_3) - // [N, C, H, 1] - const int x_dst_pitch = 1; - const int y_dst_pitch = x_dst_pitch * widthPad; - const int c_dst_pitch = y_dst_pitch * heightPad; - const int b_dst_pitch = c_dst_pitch * channelPad; - OUTPUT_TYPE4 value = outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0, input + b * b_src_pitch + c_4 * c_src_pitch + h * y_src_pitch + w * x_src_pitch)); - #else - // [N, C, H, W] - const int x_dst_pitch = 1; - const int y_dst_pitch = x_dst_pitch * widthPad; - const int c_dst_pitch = y_dst_pitch * heightPad; - const int b_dst_pitch = c_dst_pitch * channelPad; - OUTPUT_TYPE4 value = outBound ? 
(OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0, input + b * b_src_pitch + c_4 * c_src_pitch + h * y_src_pitch + w * x_src_pitch)); - #endif -#endif - - __global OUTPUT_TYPE* dst_ptr = output + b * b_dst_pitch + c * c_dst_pitch + h * y_dst_pitch + w * x_dst_pitch; - - dst_ptr[0] = value.x; - if(c + 1 >= channel)return; - dst_ptr[c_dst_pitch] = value.y; - if(c + 2 >= channel)return; - dst_ptr[2 * c_dst_pitch] = value.z; - if(c + 3 >= channel)return; - dst_ptr[3 * c_dst_pitch] = value.w; - } -} - -__kernel void pack_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2, - __global INPUT_TYPE* input, __global OUTPUT_TYPE* output, - __private const int widthPad, - __private const int heightPad, - __private const int channelPad, - __private const int batch, - __private const int width, - __private const int height, - __private const int channel){ - int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2)); - if (pos.x < global_dim0 && pos.y < global_dim1 && pos.z < global_dim2) { - - const int b = pos.z % batch; - const int w = pos.x; - const int h = pos.y; - const int c_4 = pos.z / batch; - - const int c = c_4 << 2; - if(w >= width || h >= height || c >= channel) { - return; - } - const int x_dst_pitch = 4; - const int y_dst_pitch = x_dst_pitch * width; - const int c_dst_pitch = y_dst_pitch * height; - const int b_dst_pitch = c_dst_pitch * ((channel + 3) / 4); -#ifdef MNN_NHWC - #if defined(TRANSPOSE) && defined(DIMENSION_3) - // [N, W, H, 1] - const int c_src_pitch = 1; - const int y_src_pitch = c_src_pitch; - const int x_src_pitch = y_src_pitch * heightPad; - const int b_src_pitch = x_src_pitch * widthPad; - #elif defined(TRANSPOSE) && defined(DIMENSION_4) - // [N, H, C, W] - const int x_src_pitch = 1; - const int c_src_pitch = x_src_pitch * widthPad; - const int y_src_pitch = c_src_pitch * channelPad; - const int b_src_pitch = y_src_pitch * heightPad; - #else - // [N, H, W, C] - const int c_src_pitch = 1; - const int 
x_src_pitch = c_src_pitch * channelPad; - const int y_src_pitch = x_src_pitch * widthPad; - const int b_src_pitch = y_src_pitch * heightPad; - #endif -#else - #if defined(TRANSPOSE) && defined(DIMENSION_3) - // dst:[N, C, H, 1] -> src:[N, H, C, 1] - const int x_src_pitch = 1; - const int c_src_pitch = x_src_pitch * widthPad; - const int y_src_pitch = c_src_pitch * channelPad; - const int b_src_pitch = y_src_pitch * heightPad; - #elif defined(TRANSPOSE) && defined(DIMENSION_4) - // dst:[N, C, H, W] -> src:[N, C, W, H] - const int y_src_pitch = 1; - const int x_src_pitch = y_src_pitch * heightPad; - const int c_src_pitch = x_src_pitch * widthPad; - const int b_src_pitch = c_src_pitch * channelPad; - #else - // [N, C, H, W] - const int x_src_pitch = 1; - const int y_src_pitch = x_src_pitch * widthPad; - const int c_src_pitch = y_src_pitch * heightPad; - const int b_src_pitch = c_src_pitch * channelPad; - #endif -#endif - __global INPUT_TYPE* src_ptr = input + b * b_src_pitch + c * c_src_pitch + h * y_src_pitch + w * x_src_pitch; - OUTPUT_TYPE4 value = (OUTPUT_TYPE4)0; - OUTPUT_TYPE *value_ptr = (OUTPUT_TYPE*)&value; - for(int i = 0; i < 4 && (i + c < channel); ++i){ - value_ptr[i] = (OUTPUT_TYPE)src_ptr[i * c_src_pitch]; - } - vstore4(value, 0, output + b * b_dst_pitch + c_4 * c_dst_pitch + h * y_dst_pitch + w * x_dst_pitch); - } -} - -#ifndef OPERATOR - #define OPERATOR in0 + in1 -#endif -__kernel void loop_binary_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2, - __global OUTPUT_TYPE* output, __global INPUT_TYPE* input0, __global INPUT_TYPE* input1, - __private const int input0Stride0, - __private const int input0Stride1, - __private const int input0Stride2, - __private const int input1Stride0, - __private const int input1Stride1, - __private const int input1Stride2, - __private const int outputStride0, - __private const int outputStride1, - __private const int outputStride2 - ) { - - const int x = get_global_id(0); - const int y 
= get_global_id(1); - const int z = get_global_id(2); - - if (x < global_dim0 && y < global_dim1 && z < global_dim2) { - - int inputIndex0 = z * input0Stride0 + y * input0Stride1 + x * input0Stride2; - int inputIndex1 = z * input1Stride0 + y * input1Stride1 + x * input1Stride2; - int outputIndex = z * outputStride0 + y * outputStride1 + x * outputStride2; - #ifdef INT_COMPUTE_MOD - int in0 = (int)input0[inputIndex0]; - int in1 = (int)input1[inputIndex1]; - int out = in0 % in1; - out = ((out < 0 && in1 > 0) || (out > 0 && in1 < 0)) ? out + in1 : out; - #else - float in0 = (float)input0[inputIndex0]; - float in1 = (float)input1[inputIndex1]; - float out = OPERATOR; - #endif - output[outputIndex] = (OUTPUT_TYPE)out; - } -} - -__kernel void loop_cumsum_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2, - __global OUTPUT_TYPE* output, __global INPUT_TYPE* input0, __global INPUT_TYPE* input1, - __private const int input0Stride0, - __private const int input0Stride1, - __private const int input0Stride2, - __private const int input1Stride0, - __private const int input1Stride1, - __private const int input1Stride2, - __private const int outputStride0, - __private const int outputStride1, - __private const int outputStride2, - __private const int loopNumber, - __private const int4 offsets, - __private const int4 steps - ) { - - const int x = get_global_id(0); - const int y = get_global_id(1); - const int z = get_global_id(2); - - if (x < global_dim0 && y < global_dim1 && z < global_dim2) { - - int inputIndex0 = z * input0Stride0 + y * input0Stride1 + x * input0Stride2; - int inputIndex1 = z * input1Stride0 + y * input1Stride1 + x * input1Stride2; - int outputIndex = z * outputStride0 + y * outputStride1 + x * outputStride2; - - float in0 = 0; - if(offsets.z != offsets.y){ - in0 = (float)input0[inputIndex0]; - } - - for(int i = 0; i < loopNumber; ++i){ - int4 offset = (int4)i * steps + offsets; - float in1 = (float)input1[inputIndex1 + 
offset.z]; - float out = OPERATOR; - - output[outputIndex + offset.x] = (OUTPUT_TYPE)out; - in0 = out; - } - } -} diff --git a/source/backend/opencl/execution/cl/loop_buf_mnn_cl.cpp b/source/backend/opencl/execution/cl/loop_buf_mnn_cl.cpp deleted file mode 100644 index d23338e428..0000000000 --- a/source/backend/opencl/execution/cl/loop_buf_mnn_cl.cpp +++ /dev/null @@ -1,463 +0,0 @@ -#include "opencl_source_map.hpp" -namespace MNN { -#ifndef MNN_OPENCL_BUFFER_CLOSED -const char* loop_buf = -"#ifdef MNN_SUPPORT_FP16\n" -"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" -"#endif\n" -"#define PI 3.141592653589f\n" -"#ifndef WGSW\n" -" #define WGSW 32 // work-group handle size W dimension\n" -"#endif\n" -"#ifndef WGSC\n" -" #define WGSC 32 // work-group handle size C dimension\n" -"#endif\n" -"#ifndef WGSH\n" -" #define WGSH 32 // work-group handle size H dimension\n" -"#endif\n" -"#ifndef TSW\n" -" #define TSW 8 // thread handle size W dimension\n" -"#endif\n" -"#ifndef TSC\n" -" #define TSC 8 // thread handle size C dimension\n" -"#endif\n" -"#ifndef TSH\n" -" #define TSH 8 // thread handle size H dimension\n" -"#endif\n" -"// [C4 N H 1 4] -> [N H C 1]\n" -"__kernel void tile_trans_3d_buf(__global INPUT_TYPE* input,\n" -" __global OUTPUT_TYPE* output,\n" -" __private const int widthPad,\n" -" __private const int heightPad,\n" -" __private const int channelPad,\n" -" __private const int batch,\n" -" __private const int width,\n" -" __private const int height,\n" -" __private const int channel\n" -") {\n" -" int b=get_global_id(2);\n" -" \n" -" const int lidc=get_local_id(0);\n" -" const int lidh=get_local_id(1);\n" -" // group id\n" -" const int c=get_group_id(0)*WGSC;\n" -" const int h=get_group_id(1)*WGSH;\n" -" int jc=lidc;\n" -" int ih=lidh;\n" -" \n" -" __local INPUT_TYPE4 localData[WGSH][WGSC/4];//h64c64\n" -" \n" -" #pragma unroll\n" -" for(int i=0; i= height || c+4*offset_c >= channel) ? 
(INPUT_TYPE4)0 : vload4(0,input+((b+(c/4+offset_c)*batch)*height+(h+offset_h))*4);\n" -" }\n" -" }\n" -" \n" -" barrier(CLK_LOCAL_MEM_FENCE);\n" -" \n" -" // C offset: [WGSC/TSC,TSC/4]\n" -" // H offset: [WGSH/TSH,TSH]\n" -" int oc_base=jc*TSC/4;\n" -" int oh_base=ih*TSH;\n" -" //#pragma unroll\n" -" for(int i=0; i [N C W H]\n" -"__kernel void tile_trans_4d_buf(__global INPUT_TYPE* input,\n" -" __global OUTPUT_TYPE* output,\n" -" __private const int widthPad,\n" -" __private const int heightPad,\n" -" __private const int channelPad,\n" -" __private const int batch,\n" -" __private const int width,\n" -" __private const int height,\n" -" __private const int channel\n" -") {\n" -" int bc=get_global_id(2);\n" -" int b=bc % batch;\n" -" int c4=bc/batch;\n" -" int c=c4 << 2;\n" -" \n" -" const int lidw=get_local_id(0);\n" -" const int lidh=get_local_id(1);\n" -" // group id\n" -" const int w=get_group_id(0)*WGSW;\n" -" const int h=get_group_id(1)*WGSH;\n" -" int jw=lidw;\n" -" int ih=lidh;\n" -" \n" -" __local INPUT_TYPE4 localData[WGSH][WGSW];//w32h32c4\n" -" \n" -" #pragma unroll\n" -" for(int i=0; i= height || offset_w >= width) ? (INPUT_TYPE4)0 : vload4(0,input+(((b+c4*batch)*height+offset_h)*width+offset_w)*4);\n" -" }\n" -" }\n" -" \n" -" barrier(CLK_LOCAL_MEM_FENCE);\n" -" \n" -" // c4w32h32\n" -" int oh=ih*TSH >> 4;\n" -" int mh=ih & (16/TSH-1);\n" -" // TSW offset: [TSH/4,TSW/4,16/TSH]\n" -" int ow_base=jw*TSW;\n" -" int oh_offset=oh << 4;\n" -" //#pragma unroll\n" -" for(int i=0; i= width || h >= height || c >= channel);\n" -"#ifdef MNN_NHWC\n" -" #if defined(DIMENSION_3) && defined(TRANSPOSE)\n" -" // [N,W,H,1]\n" -" const int c_dst_pitch=1;\n" -" const int y_dst_pitch=c_dst_pitch*channelPad;\n" -" const int x_dst_pitch=y_dst_pitch*heightPad;\n" -" const int b_dst_pitch=x_dst_pitch*widthPad;\n" -" OUTPUT_TYPE4 value=outBound ? 
(OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0,input+b*b_src_pitch+c_4*c_src_pitch+h*y_src_pitch+w*x_src_pitch));\n" -" #elif defined(DIMENSION_4) && defined(TRANSPOSE)\n" -" // [N,H,C,W]\n" -" const int x_dst_pitch=1;\n" -" const int c_dst_pitch=x_dst_pitch*widthPad;\n" -" const int y_dst_pitch=c_dst_pitch*channelPad;\n" -" const int b_dst_pitch=y_dst_pitch*heightPad;\n" -" OUTPUT_TYPE4 value=outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0,input+b*b_src_pitch+c_4*c_src_pitch+h*y_src_pitch+w*x_src_pitch));\n" -" #elif defined(DIMENSION_3)\n" -" // [N,H,W,1]\n" -" const int c_dst_pitch=1;\n" -" const int x_dst_pitch=c_dst_pitch*channelPad;\n" -" const int y_dst_pitch=x_dst_pitch*widthPad;\n" -" const int b_dst_pitch=y_dst_pitch*heightPad;\n" -" OUTPUT_TYPE4 value=outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0,input+b*b_src_pitch+c_4*c_src_pitch+h*y_src_pitch+w*x_src_pitch));\n" -" #else\n" -" // [N,H,W,C]\n" -" const int c_dst_pitch=1;\n" -" const int x_dst_pitch=c_dst_pitch*channelPad;\n" -" const int y_dst_pitch=x_dst_pitch*widthPad;\n" -" const int b_dst_pitch=y_dst_pitch*heightPad;\n" -" OUTPUT_TYPE4 value=outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0,input+b*b_src_pitch+c_4*c_src_pitch+h*y_src_pitch+w*x_src_pitch));\n" -" #endif\n" -"#else\n" -" #if defined(DIMENSION_3) && defined(TRANSPOSE)\n" -" // [N,H,C,1]\n" -" const int x_dst_pitch=1;\n" -" const int c_dst_pitch=x_dst_pitch*widthPad;\n" -" const int y_dst_pitch=c_dst_pitch*channelPad;\n" -" const int b_dst_pitch=y_dst_pitch*heightPad;\n" -" OUTPUT_TYPE4 value=outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0,input+b*b_src_pitch+c_4*c_src_pitch+h*y_src_pitch+w*x_src_pitch));\n" -" \n" -" #elif defined(DIMENSION_4) && defined(TRANSPOSE)\n" -" // [N,C,W,H]\n" -" const int y_dst_pitch=1;\n" -" const int x_dst_pitch=y_dst_pitch*heightPad;\n" -" const int c_dst_pitch=x_dst_pitch*widthPad;\n" -" const int b_dst_pitch=c_dst_pitch*channelPad;\n" -" OUTPUT_TYPE4 value=outBound ? 
(OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0,input+b*b_src_pitch+c_4*c_src_pitch+h*y_src_pitch+w*x_src_pitch));\n" -" #elif defined(DIMENSION_3)\n" -" // [N,C,H,1]\n" -" const int x_dst_pitch=1;\n" -" const int y_dst_pitch=x_dst_pitch*widthPad;\n" -" const int c_dst_pitch=y_dst_pitch*heightPad;\n" -" const int b_dst_pitch=c_dst_pitch*channelPad;\n" -" OUTPUT_TYPE4 value=outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0,input+b*b_src_pitch+c_4*c_src_pitch+h*y_src_pitch+w*x_src_pitch));\n" -" #else\n" -" // [N,C,H,W]\n" -" const int x_dst_pitch=1;\n" -" const int y_dst_pitch=x_dst_pitch*widthPad;\n" -" const int c_dst_pitch=y_dst_pitch*heightPad;\n" -" const int b_dst_pitch=c_dst_pitch*channelPad;\n" -" OUTPUT_TYPE4 value=outBound ? (OUTPUT_TYPE4)0 : CONVERT_OUTPUT4(vload4(0,input+b*b_src_pitch+c_4*c_src_pitch+h*y_src_pitch+w*x_src_pitch));\n" -" #endif\n" -"#endif\n" -" __global OUTPUT_TYPE* dst_ptr=output+b*b_dst_pitch+c*c_dst_pitch+h*y_dst_pitch+w*x_dst_pitch;\n" -" dst_ptr[0]=value.x;\n" -" if(c+1 >= channel)return;\n" -" dst_ptr[c_dst_pitch]=value.y;\n" -" if(c+2 >= channel)return;\n" -" dst_ptr[2*c_dst_pitch]=value.z;\n" -" if(c+3 >= channel)return;\n" -" dst_ptr[3*c_dst_pitch]=value.w;\n" -" }\n" -"}\n" -"__kernel void pack_buf(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n" -" __global INPUT_TYPE* input,__global OUTPUT_TYPE* output,\n" -" __private const int widthPad,\n" -" __private const int heightPad,\n" -" __private const int channelPad,\n" -" __private const int batch,\n" -" __private const int width,\n" -" __private const int height,\n" -" __private const int channel){\n" -" int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n" -" if (pos.x= width || h >= height || c >= channel) {\n" -" return;\n" -" }\n" -" const int x_dst_pitch=4;\n" -" const int y_dst_pitch=x_dst_pitch*width;\n" -" const int c_dst_pitch=y_dst_pitch*height;\n" -" const int b_dst_pitch=c_dst_pitch*((channel+3)/4);\n" -"#ifdef 
MNN_NHWC\n" -" #if defined(TRANSPOSE) && defined(DIMENSION_3)\n" -" // [N,W,H,1]\n" -" const int c_src_pitch=1;\n" -" const int y_src_pitch=c_src_pitch;\n" -" const int x_src_pitch=y_src_pitch*heightPad;\n" -" const int b_src_pitch=x_src_pitch*widthPad;\n" -" #elif defined(TRANSPOSE) && defined(DIMENSION_4)\n" -" // [N,H,C,W]\n" -" const int x_src_pitch=1;\n" -" const int c_src_pitch=x_src_pitch*widthPad;\n" -" const int y_src_pitch=c_src_pitch*channelPad;\n" -" const int b_src_pitch=y_src_pitch*heightPad;\n" -" #else\n" -" // [N,H,W,C]\n" -" const int c_src_pitch=1;\n" -" const int x_src_pitch=c_src_pitch*channelPad;\n" -" const int y_src_pitch=x_src_pitch*widthPad;\n" -" const int b_src_pitch=y_src_pitch*heightPad;\n" -" #endif\n" -"#else\n" -" #if defined(TRANSPOSE) && defined(DIMENSION_3)\n" -" // dst:[N,C,H,1] -> src:[N,H,C,1]\n" -" const int x_src_pitch=1;\n" -" const int c_src_pitch=x_src_pitch*widthPad;\n" -" const int y_src_pitch=c_src_pitch*channelPad;\n" -" const int b_src_pitch=y_src_pitch*heightPad;\n" -" #elif defined(TRANSPOSE) && defined(DIMENSION_4)\n" -" // dst:[N,C,H,W] -> src:[N,C,W,H]\n" -" const int y_src_pitch=1;\n" -" const int x_src_pitch=y_src_pitch*heightPad;\n" -" const int c_src_pitch=x_src_pitch*widthPad;\n" -" const int b_src_pitch=c_src_pitch*channelPad;\n" -" #else\n" -" // [N,C,H,W]\n" -" const int x_src_pitch=1;\n" -" const int y_src_pitch=x_src_pitch*widthPad;\n" -" const int c_src_pitch=y_src_pitch*heightPad;\n" -" const int b_src_pitch=c_src_pitch*channelPad;\n" -" #endif\n" -"#endif\n" -" __global INPUT_TYPE* src_ptr=input+b*b_src_pitch+c*c_src_pitch+h*y_src_pitch+w*x_src_pitch;\n" -" OUTPUT_TYPE4 value=(OUTPUT_TYPE4)0;\n" -" OUTPUT_TYPE *value_ptr=(OUTPUT_TYPE*)&value;\n" -" for(int i=0; i<4 && (i+c0) || (out>0 && in1<0)) ? 
out+in1 : out;\n" -" #else\n" -" float in0=(float)input0[inputIndex0];\n" -" float in1=(float)input1[inputIndex1];\n" -" float out=OPERATOR;\n" -" #endif\n" -" output[outputIndex]=(OUTPUT_TYPE)out;\n" -" }\n" -"}\n" -"__kernel void loop_cumsum_buf(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n" -" __global OUTPUT_TYPE* output,__global INPUT_TYPE* input0,__global INPUT_TYPE* input1,\n" -" __private const int input0Stride0,\n" -" __private const int input0Stride1,\n" -" __private const int input0Stride2,\n" -" __private const int input1Stride0,\n" -" __private const int input1Stride1,\n" -" __private const int input1Stride2,\n" -" __private const int outputStride0,\n" -" __private const int outputStride1,\n" -" __private const int outputStride2,\n" -" __private const int loopNumber,\n" -" __private const int4 offsets,\n" -" __private const int4 steps\n" -" ) {\n" -" \n" -" const int x=get_global_id(0);\n" -" const int y=get_global_id(1);\n" -" const int z=get_global_id(2);\n" -" \n" -" if (x= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n" +"__kernel void set_zero(\n" +" GLOBAL_SIZE_2_DIMS\n" +" __global OUTPUT_TYPE *output\n" +" ) {\n" +" const int x=get_global_id(0);\n" +" const int y=get_global_id(1);\n" +" \n" +" DEAL_NON_UNIFORM_DIM2(x,y);\n" +" \n" +" output[y*global_size_dim0+x]=(OUTPUT_TYPE)(0);\n" +"}\n" "__kernel void batch_matmul(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n" " __global FLOAT* output,__global FLOAT* input_A,__global FLOAT* input_B,\n" "#ifdef BIAS\n" @@ -17,7 +30,7 @@ const char* loop = "#endif\n" " __private const int e,\n" " __private const int l,\n" -" __private const int h,\n" +" __private const int h,__private const int iter,\n" " __private const int4 offsets,\n" " __private const int4 iters,\n" " __private const int4 steps) {\n" @@ -25,6 +38,7 @@ const char* loop = " if (pos.x= 0) {\n" " index.x=offset_O[pos.z];\n" @@ -273,6 +287,9 @@ const 
char* loop = " WI_DATA(output,(int2)(pos.y*width+w,pos.z*height+h),value);\n" " }\n" "}\n" +"#ifndef UNARY_OPERATOR\n" +" #define UNARY_OPERATOR in\n" +"#endif\n" "__kernel void batch_gather(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n" " __global OUTPUT_TYPE* output,__global INPUT_TYPE* input,\n" " #ifdef OFFSET_DST\n" @@ -282,16 +299,19 @@ const char* loop = " __global int* offset_src_ptr,\n" " #endif\n" " __private const int x_size,\n" +" __private const int iter,\n" " __private const int4 stride_src,\n" " __private const int4 stride_dst,\n" " __private const int2 steps,\n" -" __private const int inputSize) {\n" +" __private const int inputSize,\n" +" __private const int outputSize) {\n" " int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n" " \n" " if (pos.x= 0){\n" +" int outputIndex=offset.x+stride_dst.w+x*stride_dst.x+y*stride_dst.y+pos.y*stride_dst.z;\n" +" if(outputIndex= 0){\n" " if(offset.y >= 0 && offset.y(int4)0) || (out>(int4)0 && in1<(int4)0)) ? out+in1 : out;\n" +" int in0=(int)input0[inputIndex0];\n" +" int in1=(int)input1[inputIndex1];\n" +" int out=in0 % in1;\n" +" out=((out<0 && in1>0) || (out>0 && in1<0)) ? 
out+in1 : out;\n" " #else\n" -" float4 out=OPERATOR;\n" +" float in0=(float)input0[inputIndex0];\n" +" float in1=(float)input1[inputIndex1];\n" +" float out=OPERATOR;\n" " #endif\n" -" \n" -" WI_DATA(output,(int2)(co*dst_width+wo,no*dst_height+ho),CONVERT_OUTPUT_I4(out));\n" +" if(outputIndex OpenCLProgramMap = #ifndef MNN_OPENCL_BUFFER_CLOSED { "softmax_buf", softmax_buf }, #endif -#ifndef MNN_OPENCL_BUFFER_CLOSED - { "gather_buf", gather_buf }, -#endif #ifndef MNN_OPENCL_BUFFER_CLOSED #ifdef MNN_SUPPORT_INTEL_SUBGROUP { "conv_2d_c16_subgroup_buf", conv_2d_c16_subgroup_buf }, @@ -332,9 +323,6 @@ const std::map OpenCLProgramMap = #endif { "matmul", matmul }, { "binary", binary }, -#ifndef MNN_OPENCL_BUFFER_CLOSED - { "loop_buf", loop_buf }, -#endif { "roi_pooling", roi_pooling }, { "depthwise_conv2d", depthwise_conv2d }, { "layernorm", layernorm }, @@ -387,7 +375,7 @@ const std::map OpenCLProgramMd5Map = { "gemm_buf", "b030b6eacaf65a54e8eabee2755f892a" }, { "conv_2d_int", "985925b9f24d85fa38df2df9b01fafc5" }, { "copy_buffer_to_image2d", "a72ed287711f9bb78a2cfa9726a1fa92" }, - { "loop", "b739a26d78ebe48afd07e55244bdb260" }, + { "loop", "4849a55cd99f0ebab72a10527455341f" }, { "argmax_buf", "ae4a1ae3461b2758609022ac7569b11b" }, { "buffer_convert_subgroup_buf", "d968b717e537464a7fa08e742c9a0319" }, { "attention_buf", "7d05b22865927ca19dae5762ba6f1df9" }, @@ -405,7 +393,6 @@ const std::map OpenCLProgramMd5Map = { "winogradTransformDest2_3_1", "f2aaa52d652565e70a44868d4f6028e9" }, { "layernorm_buf", "5f6b88b29da72f51bdc85064b5663bb2" }, { "softmax_buf", "12052d403f3fa0cdfea2559296e88e6c" }, - { "gather_buf", "cb5cf89ff808f051ada3023876a402a4" }, { "conv_2d_c16_subgroup_buf", "81f9027f323b6890d08d49dab10a15e4" }, { "input_transe_buf", "c80482cd531add8582edc242bcbfa947" }, { "reduction_buf", "c16506adcebf7760a1a3c96ce0d386ee" }, @@ -415,7 +402,6 @@ const std::map OpenCLProgramMd5Map = { "buffer_convert_buf", "e633544642a1a9a61755c913cfe77017" }, { "matmul", 
"a3e51ece4be2eb0f28266718b313c24e" }, { "binary", "5683a6a6fd24660f0d05a70938fa6a62" }, - { "loop_buf", "0a3e7e970b69c27e15dbbe3dbda7c798" }, { "roi_pooling", "ba4a81b7ec7058d14afb377c18674a76" }, { "depthwise_conv2d", "a23dd590e0bdcdd60987e8bab5ed529f" }, { "layernorm", "bd457b4bd4f3c57818bc17e073b09e74" }, diff --git a/source/backend/opencl/execution/image/LoopExecution.cpp b/source/backend/opencl/execution/image/LoopExecution.cpp index 233fd7ca3c..d784ad9cdd 100644 --- a/source/backend/opencl/execution/image/LoopExecution.cpp +++ b/source/backend/opencl/execution/image/LoopExecution.cpp @@ -72,6 +72,124 @@ static void _PackTensor(cl::Buffer *input, Tensor *output, std::shared_ptrrecordKernel3d(kernelW, mGlobalWorkSize, mLocalWorkSize); } +static std::string getComputeOption(MNN::BinaryOpOperation type){ + std::string compute; + switch (type) { + case BinaryOpOperation_MUL: + compute = "in0*in1";break; + case BinaryOpOperation_ADD: + compute = "in0+in1";break; + case BinaryOpOperation_SUB: + compute = "in0-in1";break; + case BinaryOpOperation_REALDIV: + compute = "sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001))";break; + case BinaryOpOperation_MINIMUM: + compute = "in0>in1?in1:in0";break; + case BinaryOpOperation_MAXIMUM: + compute = "in0>in1?in0:in1";break; + case BinaryOpOperation_GREATER: + compute = "(float)(isgreater(in0,in1))";break; + case BinaryOpOperation_LESS: + compute = "(float)(isless(in0,in1))";break; + case BinaryOpOperation_LESS_EQUAL: + compute = "(float)(islessequal(in0,in1))";break; + case BinaryOpOperation_GREATER_EQUAL: + compute = "(float)(isgreaterequal(in0,in1))";break; + case BinaryOpOperation_EQUAL: + compute = "(float)(isequal(in0,in1))";break; + case BinaryOpOperation_FLOORDIV: + compute = "floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))";break; + case BinaryOpOperation_FLOORMOD: + compute = 
"in0-floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))*in1";break; + case BinaryOpOperation_POW: + compute = "pow(in0,in1)";break; + case BinaryOpOperation_SquaredDifference: + compute = "(in0-in1)*(in0-in1)";break; + case BinaryOpOperation_ATAN2: + compute = "(in1==(float)0?(sign(in0)*(float)(PI/2)):(atan(in0/in1)+(in1>(float)0?(float)0:sign(in0)*(float)PI)))";break; + case BinaryOpOperation_NOTEQUAL: + compute = "(float)(isnotequal(in0,in1))";break; + case BinaryOpOperation_MOD: + compute = "in0-floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))*in1";break; + default: + break; + } + return compute; +} + +static std::string getUnaryComputeOption(MNN::UnaryOpOperation type){ + std::string compute; + switch (type) { + case UnaryOpOperation_ABS: + compute = "fabs((float)(in))"; break; + case UnaryOpOperation_SQUARE: + compute = "in*in"; break; + case UnaryOpOperation_RSQRT: + compute = "rsqrt((float))(in)>(float)(0.000001)?(float))(in):(float)(0.000001))"; break; + case UnaryOpOperation_NEG: + compute = "-(in)"; break; + case UnaryOpOperation_EXP: + compute = "exp((float))(in))"; break; + case UnaryOpOperation_COS: + compute = "cos((float)(in))"; break; + case UnaryOpOperation_SIN: + compute = "sin((float)(in))"; break; + case UnaryOpOperation_TAN: + compute = "tan((float)(in))"; break; + case UnaryOpOperation_ATAN: + compute = "atan((float)(in))"; break; + case UnaryOpOperation_SQRT: + compute = "sqrt((float)(in))"; break; + case UnaryOpOperation_CEIL: + compute = "ceil((float)(in))"; break; + case UnaryOpOperation_RECIPROCAL: + compute = "native_recip((float)(in))"; break; + case UnaryOpOperation_LOG1P: + compute = "log1p((float)(in))"; break; + case UnaryOpOperation_LOG: + compute = "native_log((float)(in)>(float)(0.0000001)?(float)(in):(float)(0.0000001))"; break; + case UnaryOpOperation_FLOOR: + compute = "floor((float)(in))"; break; + case UnaryOpOperation_BNLL: + compute = 
"in>(float)((float)0)?(in+native_log(exp((float)(-(in)))+(float)(1.0))):(native_log(exp((float)(in))+(float)(1.0)))"; break; + case UnaryOpOperation_ACOSH: + compute = "acosh((float)(in))"; break; + case UnaryOpOperation_SINH: + compute = "sinh((float)(in))"; break; + case UnaryOpOperation_ASINH: + compute = "asinh((float)(in))"; break; + case UnaryOpOperation_ATANH: + compute = "atanh((float)(in))"; break; + case UnaryOpOperation_SIGN: + compute = "sign((float)(in))"; break; + case UnaryOpOperation_ROUND: + compute = "round((float)(in))"; break; + case UnaryOpOperation_COSH: + compute = "cosh((float)(in))"; break; + case UnaryOpOperation_ERF: + compute = "erf((float)(in))"; break; + case UnaryOpOperation_ERFC: + compute = "erfc((float)(in))"; break; + case UnaryOpOperation_EXPM1: + compute = "expm1((float)(in))"; break; + case UnaryOpOperation_SIGMOID: + compute = "native_recip((float)1+native_exp((float)(-in)))"; break; + case UnaryOpOperation_SILU: + compute = "((float)(in)*native_recip((float)1+native_exp((float)(-in))))"; break; + case UnaryOpOperation_TANH: + compute = "tanh((float)(in))"; break; + case UnaryOpOperation_HARDSWISH: + compute = "(float)(in)>(float)(-3.0f)?((float)(in)<(float)(3.0f)?(((float)(in)*((float)(in)+(float)3.0f))/(float)6.0f):(float)(in)):(float)(0.0f)"; break; + case UnaryOpOperation_GELU: + compute = "gelu((float)(in))"; break; + case UnaryOpOperation_GELU_STANDARD: + compute = "(erf((float)(in)*(float)0.7071067932881648)+(float)1.0)*(float)(in)*(float)0.5"; break; + default: + break; + } + return compute; +} + static void _setTensorStack(std::vector &result, const std::vector &inputs, const std::vector &outputs, const LoopParam *loop) { if (loop->inputIndexes() != nullptr) { @@ -84,462 +202,356 @@ static void _setTensorStack(std::vector &result, const std::vectortensorNumber()); +} - LoopGatherExecution::LoopGatherExecution(const LoopParam *loop, const MNN::Op *op, Backend *bn) - : CommonExecution(bn, op) { - mLoop = loop; - 
mTensors.resize(mLoop->tensorNumber()); - } -ErrorCode LoopGatherExecution::InitCommandOnEncode(const std::vector &inputs, const std::vector &outputs){ - auto cmd = mLoop->initCommand()->GetAs(0); +void LoopExecution::ImageToBufferAllTensor(){ OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); - auto runTime = mOpenCLBackend->getOpenCLRuntime(); - auto bufferPool = mOpenCLBackend->getBufferPool(); - auto bufferUnitSize = mOpenCLBackend->getPrecision() != BackendConfig::Precision_High ? sizeof(half_float::half) : sizeof(float); - - if (cmd->op() == nullptr){ + auto bufferPool = mOpenCLBackend->getBufferPool(); + int bufferUnitSize = mOpenCLBackend->getPrecision() != BackendConfig::Precision_High ? sizeof(half_float::half) : sizeof(float); + for(int i = 0; i < mTensors.size(); ++i){ + auto input = mTensors[i]; + std::vector Shape = tensorShapeFormat(input); + const int Channel = Shape.at(3); + const int Width = Shape.at(2); + const int Height = Shape.at(1); + const int Batch = Shape.at(0); + mTmpBuffers[input] = bufferPool->alloc(input->elementSize() * bufferUnitSize); + Unit unit; - auto output = mTensors[cmd->indexes()->data()[0]]; - auto outputShape = tensorShapeFormat(output); - int region[] = {outputShape[0], UP_DIV(outputShape[3], 4), outputShape[1], outputShape[2]};//nhwc - unit.kernel = runTime->buildKernel("raster", "image_set_zero", {}, mOpenCLBackend->getPrecision(), output, output); - unit.localWorkSize = {8, 8}; - unit.globalWorkSize = {(uint32_t)UP_DIV((region[1] * region[3]), 16)*16, - (uint32_t)UP_DIV((region[0] * region[2]), 16)*16}; + _TileTensor(input, mTmpBuffers[input], unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height,Channel, Batch, mOpenCLBackend, {}); + mUnits.emplace_back(unit); + } +} - int global_dim0 = region[1] * region[3]; - int global_dim1 = region[0] * region[2]; +void LoopExecution::BufferToImageOutputTensor(const std::vector &outputs){ + OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); + for(int 
i = 0; i < outputs.size(); ++i){ + auto output = outputs[i]; + std::vector Shape = tensorShapeFormat(output); + const int Channel = Shape.at(3); + const int Width = Shape.at(2); + const int Height = Shape.at(1); + const int Batch = Shape.at(0); + Unit unit; + _PackTensor(mTmpBuffers[output], output, unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, {}); + mUnits.emplace_back(unit); + } +} - uint32_t idx = 0; - cl_int ret = CL_SUCCESS; - ret |= unit.kernel->get().setArg(idx++, global_dim0); - ret |= unit.kernel->get().setArg(idx++, global_dim1); - ret |= unit.kernel->get().setArg(idx++, openCLImage(output)); - if(ret != CL_SUCCESS) +ErrorCode LoopExecution::InitCommandOnEncode(){ + for (int i=0; iinitCommand()->size(); ++i) { + auto cmd = mLoop->initCommand()->GetAs(i); + OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); + auto runTime = mOpenCLBackend->getOpenCLRuntime(); + int mStride_src[4]; + int mStride_dst[4]; + int mStep[2]; + int mIter[2]; + if (cmd->op() == nullptr){ + Unit unit; + auto output = mTensors[cmd->indexes()->data()[0]]; + auto outputShape = tensorShapeFormat(output); + auto outputDes = TensorUtils::getDescribe(output); + int region[] = {outputShape[0], outputShape[3], outputShape[1], outputShape[2]};//nchw + if(MNN_DATA_FORMAT_NC4HW4 == outputDes->dimensionFormat){ + region[1] = ROUND_UP(outputShape[3], 4); + } + unit.kernel = runTime->buildKernel("loop", "set_zero", {}, mOpenCLBackend->getPrecision(), output, output); + unit.localWorkSize = {8, 8}; + unit.globalWorkSize = {(uint32_t)UP_DIV((region[2] * region[3]), 8)*8, + (uint32_t)UP_DIV((region[0] * region[1]), 8)*8}; + + int global_dim0 = region[2] * region[3]; + int global_dim1 = region[0] * region[1]; + + uint32_t idx = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(idx++, global_dim0); + ret |= unit.kernel->get().setArg(idx++, global_dim1); + ret |= unit.kernel->get().setArg(idx++, *mTmpBuffers[output]); + 
MNN_CHECK_CL_SUCCESS(ret, "setArg set_zero"); + mOpenCLBackend->recordKernel2d(unit.kernel, {(uint32_t)UP_DIV((region[2] * region[3]), 8)*8, + (uint32_t)UP_DIV((region[0] * region[1]), 8)*8}, {8, 8}); + mUnits.emplace_back(unit); + return NO_ERROR; + } + int x = cmd->size()->data()[0]; + int y = cmd->size()->data()[1]; + int z = cmd->size()->data()[2]; + + int inputSize = mTensors[cmd->indexes()->data()[1]]->elementSize(); + int outputSize = mTensors[cmd->indexes()->data()[0]]->elementSize(); + + auto srcStride = cmd->view()->GetAs(1)->stride()->data(); + auto dstStride = cmd->view()->GetAs(0)->stride()->data(); + for (int i = 0; i < 3; ++i) { + mStride_src[i] = srcStride[i]; + mStride_dst[i] = dstStride[i]; + } + + mStride_src[3] = 0; + mStride_dst[3] = 0; + ::memset(mStep, 0, 2 * sizeof(int)); + + // gather { - MNN_PRINT("setArg err %d\n", (int)ret); + Unit unit; + auto input = mTensors[cmd->indexes()->data()[1]]; + auto output = mTensors[cmd->indexes()->data()[0]]; + std::set buildOptions; + + unit.kernel = runTime->buildKernel("loop", "batch_gather", buildOptions, mOpenCLBackend->getPrecision(), input, output); + uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); + std::vector mGlobalWorkSize = {(uint32_t)(x * y), (uint32_t)(z), (uint32_t)(1)}; + + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[output]); + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[input]); + ret |= unit.kernel->get().setArg(index++, x); + ret |= unit.kernel->get().setArg(index++, 0); + ret |= unit.kernel->get().setArg(index++, sizeof(mStride_src), mStride_src); + ret |= unit.kernel->get().setArg(index++, sizeof(mStride_dst), mStride_dst); + ret |= unit.kernel->get().setArg(index++, sizeof(mStep), 
mStep); + ret |= unit.kernel->get().setArg(index++, inputSize); + ret |= unit.kernel->get().setArg(index++, outputSize); + MNN_CHECK_CL_SUCCESS(ret, "setArg LoopInitGatherExecution"); + + std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, "batch_gather", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop").first; + + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + mUnits.emplace_back(unit); + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); } - mOpenCLBackend->recordKernel2d(unit.kernel, - {(uint32_t)UP_DIV((region[1] * region[3]), 16)*16, - (uint32_t)UP_DIV((region[0] * region[2]), 16)*16}, - {8, 8}); - mUnits.emplace_back(unit); - return NO_ERROR; } + return NO_ERROR; +} +ErrorCode LoopExecution::LoopGather(int cmdIndex, int iter) { + auto cmd = mLoop->commands()->GetAs(cmdIndex); + auto op = cmd->op(); + OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); + auto runTime = mOpenCLBackend->getOpenCLRuntime(); - mTmpInitBuffers.resize(2); int x = cmd->size()->data()[0]; int y = cmd->size()->data()[1]; int z = cmd->size()->data()[2]; + int n = mLoop->parallel() ? 
mLoop->loopNumber() : 1; + if(mLoop->commands()->size() == 1 && OpType_UnaryOp == op->type() && nullptr == op->main() && cmd->fuse() < 0){ + // only one gather + n = mLoop->loopNumber(); + } + + int mStride_src[4]; + int mStride_dst[4]; + int mStep[2]; + int mIter[2]; int inputSize = mTensors[cmd->indexes()->data()[1]]->elementSize(); - + int outputSize = mTensors[cmd->indexes()->data()[0]]->elementSize(); + auto srcStride = cmd->view()->GetAs(1)->stride()->data(); auto dstStride = cmd->view()->GetAs(0)->stride()->data(); - for(int i = 0; i < 3; ++i) { + for (int i = 0; i < 3; ++i) { mStride_src[i] = srcStride[i]; mStride_dst[i] = dstStride[i]; } - - mStride_src[3] = 0; - mStride_dst[3] = 0; - ::memset(mStep, 0, 2 * sizeof(int)); - - // tile input - { - auto input = mTensors[cmd->indexes()->data()[1]]; - std::vector Shape = tensorShapeFormat(input); - const int Channel = Shape.at(3); - const int Width = Shape.at(2); - const int Height = Shape.at(1); - const int Batch = Shape.at(0); - mTmpInitBuffers[1] = bufferPool->alloc(input->elementSize() * bufferUnitSize); - - Unit unit; - _TileTensor(mTensors[cmd->indexes()->data()[1]], mTmpInitBuffers[1], unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height,Channel, Batch, mOpenCLBackend, mBuildOptions); - mUnits.emplace_back(unit); - } - - // tile output - { - auto output = mTensors[cmd->indexes()->data()[0]]; - std::vector Shape = tensorShapeFormat(output); - const int Channel = Shape.at(3); - const int Width = Shape.at(2); - const int Height = Shape.at(1); - const int Batch = Shape.at(0); - mTmpInitBuffers[0] = bufferPool->alloc(output->elementSize() * bufferUnitSize); - - Unit unit; - _TileTensor(mTensors[cmd->indexes()->data()[0]], mTmpInitBuffers[0], unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height,Channel, Batch, mOpenCLBackend, mBuildOptions); - mUnits.emplace_back(unit); + if(cmd->fuse() >= 0){ + mStride_dst[0] = y * z; + mStride_dst[1] = z; + mStride_dst[2] = 1; } - + + 
mStride_src[3] = cmd->view()->GetAs(1)->offset(); + mStride_dst[3] = cmd->view()->GetAs(0)->offset(); + ::memcpy(mStep, cmd->steps()->data(), cmd->steps()->size() * sizeof(int)); + ::memcpy(mIter, cmd->iterIndexes()->data(), cmd->iterIndexes()->size() * sizeof(int)); + // gather - { - int offset_index = 0; - Unit unit; - std::string KernelName = "batch_gather"; - unit.kernel = runTime->buildKernel("loop", KernelName, mBuildOptions, mOpenCLBackend->getPrecision(), mTensors[cmd->indexes()->data()[1]], mTensors[cmd->indexes()->data()[0]]); - uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); - std::vector mGlobalWorkSize = {(uint32_t)(x * y), (uint32_t)(z), (uint32_t)(1)}; - - uint32_t index = 0; - cl_int ret = CL_SUCCESS; - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); - ret |= unit.kernel->get().setArg(index++, *mTmpInitBuffers[0]); - ret |= unit.kernel->get().setArg(index++, *mTmpInitBuffers[1]); - ret |= unit.kernel->get().setArg(index++, x); - ret |= unit.kernel->get().setArg(index++, sizeof(mStride_src), mStride_src); - ret |= unit.kernel->get().setArg(index++, sizeof(mStride_dst), mStride_dst); - ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); - ret |= unit.kernel->get().setArg(index++, inputSize); - MNN_CHECK_CL_SUCCESS(ret, "setArg LoopGatherExecution"); - - std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, KernelName, unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop").first; - - unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; - mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); - mUnits.emplace_back(unit); + Unit unit; + auto output = mTensors[cmd->indexes()->data()[0]]; + auto 
input = mTensors[cmd->indexes()->data()[1]]; + std::set buildOptions; + + if(op->main() != nullptr){ + std::string compute = getUnaryComputeOption(cmd->op()->main_as_UnaryOp()->opType()); + buildOptions.emplace("-DUNARY_OPERATOR=" + compute); } - - //pack output - { - auto output = mTensors[cmd->indexes()->data()[0]]; - std::vector Shape = tensorShapeFormat(output); - const int Channel = Shape.at(3); - const int Width = Shape.at(2); - const int Height = Shape.at(1); - const int Batch = Shape.at(0); - Unit unit; - _PackTensor(mTmpInitBuffers[0], mTensors[cmd->indexes()->data()[0]], unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, mBuildOptions); - mUnits.emplace_back(unit); + if (mIter[0] >= 0) { + buildOptions.emplace("-DOFFSET_DST"); } - - return NO_ERROR; -} - ErrorCode LoopGatherExecution::onEncode(const std::vector &inputs, const std::vector &outputs) { - auto cmd = mLoop->commands()->GetAs(0); - OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); - auto runTime = mOpenCLBackend->getOpenCLRuntime(); - auto bufferPool = mOpenCLBackend->getBufferPool(); - auto bufferUnitSize = mOpenCLBackend->getPrecision() != BackendConfig::Precision_High ? 
sizeof(half_float::half) : sizeof(float); - _setTensorStack(mTensors, inputs, outputs, mLoop); - mUnits.clear(); - - if(mLoop->initCommand() != nullptr){ - InitCommandOnEncode(inputs, outputs); - } - - mOffsetBuffers.clear(); - mTmpBuffers.resize(2); - int x = cmd->size()->data()[0]; - int y = cmd->size()->data()[1]; - int z = cmd->size()->data()[2]; - int n = mLoop->loopNumber(); - int inputSize = mTensors[cmd->indexes()->data()[1]]->elementSize(); - - auto srcStride = cmd->view()->GetAs(1)->stride()->data(); - auto dstStride = cmd->view()->GetAs(0)->stride()->data(); - for (int i = 0; i < 3; ++i) { - mStride_src[i] = srcStride[i]; - mStride_dst[i] = dstStride[i]; - } - - mStride_src[3] = cmd->view()->GetAs(1)->offset(); - mStride_dst[3] = cmd->view()->GetAs(0)->offset(); - ::memcpy(mStep, cmd->steps()->data(), cmd->steps()->size() * sizeof(int)); - ::memcpy(mIter, cmd->iterIndexes()->data(), cmd->iterIndexes()->size() * sizeof(int)); - - // tile input - { - auto input = mTensors[cmd->indexes()->data()[1]]; - std::vector Shape = tensorShapeFormat(input); - const int Channel = Shape.at(3); - const int Width = Shape.at(2); - const int Height = Shape.at(1); - const int Batch = Shape.at(0); - mTmpBuffers[1] = bufferPool->alloc(input->elementSize() * bufferUnitSize); - - Unit unit; - _TileTensor(mTensors[cmd->indexes()->data()[1]], mTmpBuffers[1], unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height,Channel, Batch, mOpenCLBackend, mBuildOptions); - mUnits.emplace_back(unit); - } - - // tile output - { - auto output = mTensors[cmd->indexes()->data()[0]]; - std::vector Shape = tensorShapeFormat(output); - const int Channel = Shape.at(3); - const int Width = Shape.at(2); - const int Height = Shape.at(1); - const int Batch = Shape.at(0); - mTmpBuffers[0] = bufferPool->alloc(output->elementSize() * bufferUnitSize); - - Unit unit; - _TileTensor(mTensors[cmd->indexes()->data()[0]], mTmpBuffers[0], unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, 
Height,Channel, Batch, mOpenCLBackend, mBuildOptions); - mUnits.emplace_back(unit); - } - - for(int i = 0; i < cmd->iterIndexes()->size(); ++i){ + if (mIter[1] >= 0) { + buildOptions.emplace("-DOFFSET_SRC"); + } + + unit.kernel = runTime->buildKernel("loop", "batch_gather", buildOptions, mOpenCLBackend->getPrecision(), input, output); + uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); + std::vector mGlobalWorkSize = {(uint32_t)(x * y), (uint32_t)(z), (uint32_t)(n)}; + + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); + if(cmd->fuse() >= 0){ + ret |= unit.kernel->get().setArg(index++, *mFuseBuffer); + }else{ + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[output]); + } + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[input]); + for (int i = 0; i < cmd->iterIndexes()->size(); ++i) { if (mIter[i] >= 0) { - auto input = mTensors[cmd->iterIndexes()->data()[i]]; - std::vector Shape = tensorShapeFormat(input); - const int Channel = Shape.at(3); - const int Width = Shape.at(2); - const int Height = Shape.at(1); - const int Batch = Shape.at(0); - mOffsetBuffers.emplace_back(bufferPool->alloc(input->elementSize() * bufferUnitSize)); - - Unit unit; - _TileTensor(input, mOffsetBuffers.back(), unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, mBuildOptions); - mUnits.emplace_back(unit); - } - } - - // gather - { - int offset_index = 0; - std::set buildOptions = mBuildOptions; - if (mIter[0] >= 0) { - buildOptions.emplace("-DOFFSET_DST"); - } - if (mIter[1] >= 0) { - buildOptions.emplace("-DOFFSET_SRC"); - } - Unit unit; - std::string KernelName = "batch_gather"; - unit.kernel = runTime->buildKernel("loop", KernelName, buildOptions, mOpenCLBackend->getPrecision(), 
mTensors[cmd->indexes()->data()[1]], mTensors[cmd->indexes()->data()[0]]); - uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); - std::vector mGlobalWorkSize = {(uint32_t)(x * y), (uint32_t)(z), (uint32_t)(n)}; - - uint32_t index = 0; - cl_int ret = CL_SUCCESS; - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); - ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[0]); - ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[1]); - for (int i = 0; i < cmd->iterIndexes()->size(); ++i) { - if (mIter[i] >= 0) { - ret |= unit.kernel->get().setArg(index++, *mOffsetBuffers[offset_index++]); - } + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[mTensors[cmd->iterIndexes()->data()[i]]]); } - ret |= unit.kernel->get().setArg(index++, x); - ret |= unit.kernel->get().setArg(index++, sizeof(mStride_src), mStride_src); - ret |= unit.kernel->get().setArg(index++, sizeof(mStride_dst), mStride_dst); - ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); - ret |= unit.kernel->get().setArg(index++, inputSize); - MNN_CHECK_CL_SUCCESS(ret, "setArg LoopGatherExecution"); - - std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, KernelName, unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop").first; - - unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; - mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); - mUnits.emplace_back(unit); - } - - //pack output - { - auto output = mTensors[cmd->indexes()->data()[0]]; - std::vector Shape = tensorShapeFormat(output); - const int Channel = Shape.at(3); - const int Width = Shape.at(2); - const int Height = Shape.at(1); - const int Batch = Shape.at(0); - Unit unit; - 
_PackTensor(mTmpBuffers[0], mTensors[cmd->indexes()->data()[0]], unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, mBuildOptions); - mUnits.emplace_back(unit); - } - - for (int i = 0; i < mTmpBuffers.size(); ++i) { - bufferPool->recycle(mTmpBuffers[i]); - } - for (int i = 0; i < mOffsetBuffers.size(); ++i) { - bufferPool->recycle(mOffsetBuffers[i]); - } - if(mLoop->initCommand() != nullptr){ - for (int i = 0; i < mTmpInitBuffers.size(); ++i) { - bufferPool->recycle(mTmpInitBuffers[i]); - } - } - - return NO_ERROR; - } - - -LoopBatchMatMulExecution::LoopBatchMatMulExecution(const LoopParam *loop, const MNN::Op *op, Backend *bn) - : CommonExecution(bn, op) { - mLoop = loop; - mTensors.resize(mLoop->tensorNumber()); + } + ret |= unit.kernel->get().setArg(index++, x); + ret |= unit.kernel->get().setArg(index++, iter); + ret |= unit.kernel->get().setArg(index++, sizeof(mStride_src), mStride_src); + ret |= unit.kernel->get().setArg(index++, sizeof(mStride_dst), mStride_dst); + ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); + ret |= unit.kernel->get().setArg(index++, inputSize); + ret |= unit.kernel->get().setArg(index++, outputSize); + MNN_CHECK_CL_SUCCESS(ret, "setArg LoopGatherExecution"); + + std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, "batch_gather", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop").first; + + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + mUnits.emplace_back(unit); + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + + if(cmd->fuse() >= 0){ + FuseOutput(cmdIndex, mStride_dst, cmd->size()->data()[0], cmd->size()->data()[1], cmd->size()->data()[2], n, iter); + } + return NO_ERROR; } -ErrorCode LoopBatchMatMulExecution::onEncode(const std::vector &inputs, const std::vector &outputs) { - 
auto cmd = mLoop->commands()->GetAs(0); - mHasBias = cmd->indexes()->size() > 3; - mTransposeA = cmd->op()->main_as_MatMul()->transposeA(); - mTransposeB = cmd->op()->main_as_MatMul()->transposeB(); - OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); - auto runTime = mOpenCLBackend->getOpenCLRuntime(); - auto bufferPool = mOpenCLBackend->getBufferPool(); - auto bufferUnitSize = mOpenCLBackend->getPrecision() != BackendConfig::Precision_High ? sizeof(half_float::half) : sizeof(float); - _setTensorStack(mTensors, inputs, outputs, mLoop); - - mOffset[0] = cmd->view()->GetAs(0)->offset(); - mOffset[1] = cmd->view()->GetAs(1)->offset(); - mOffset[2] = cmd->view()->GetAs(2)->offset(); - mUnits.clear(); - mOffsetBuffers.clear(); - mTmpBuffers.resize(3); - if (mHasBias) { - mTmpBuffers.resize(4); - mOffset[3] = cmd->view()->GetAs(3)->offset(); - } - - ::memcpy(mStep, cmd->steps()->data(), cmd->steps()->size() * sizeof(int)); - ::memcpy(mIter, cmd->iterIndexes()->data(), cmd->iterIndexes()->size() * sizeof(int)); - int e = cmd->size()->data()[0]; - int l = cmd->size()->data()[1]; - int h = cmd->size()->data()[2]; - int n = mLoop->loopNumber(); - // tile input - for (int i = 1; i < cmd->indexes()->size(); ++i) { - auto input = mTensors[cmd->indexes()->data()[i]]; - std::vector Shape = tensorShapeFormat(input); - const int Channel = Shape.at(3); - const int Width = Shape.at(2); - const int Height = Shape.at(1); - const int Batch = Shape.at(0); - mTmpBuffers[i] = bufferPool->alloc(Batch * Channel * ROUND_UP(Height, 4) * ROUND_UP(Width, 4) * bufferUnitSize); - Unit unit; - _TileTensor(input, mTmpBuffers[i], unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, mBuildOptions); - mUnits.emplace_back(unit); - } - - for(int i = 0; i < cmd->iterIndexes()->size(); ++i){ +ErrorCode LoopExecution::LoopBatchMatMul(int cmdIndex, int iter) { + auto cmd = mLoop->commands()->GetAs(cmdIndex); + bool mHasBias = cmd->indexes()->size() > 
3; + OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); + auto runTime = mOpenCLBackend->getOpenCLRuntime(); + + int mOffset[4]; + int mStep[4]; + int mIter[4]; + + mOffset[0] = cmd->view()->GetAs(0)->offset(); + mOffset[1] = cmd->view()->GetAs(1)->offset(); + mOffset[2] = cmd->view()->GetAs(2)->offset(); + if (mHasBias) { + mOffset[3] = cmd->view()->GetAs(3)->offset(); + } + + ::memcpy(mStep, cmd->steps()->data(), cmd->steps()->size() * sizeof(int)); + ::memcpy(mIter, cmd->iterIndexes()->data(), cmd->iterIndexes()->size() * sizeof(int)); + int e = cmd->size()->data()[0]; + int l = cmd->size()->data()[1]; + int h = cmd->size()->data()[2]; + int n = mLoop->parallel() ? mLoop->loopNumber() : 1; + // matmul + Unit unit; + std::string KernelName = "batch_matmul"; + std::set buildOptions; + if (mHasBias) { + buildOptions.emplace("-DBIAS"); + } + if (cmd->op()->main_as_MatMul()->transposeA()) { + buildOptions.emplace("-DTRANSPOSE_A"); + } + if (cmd->op()->main_as_MatMul()->transposeB()) { + buildOptions.emplace("-DTRANSPOSE_B"); + } + buildOptions.emplace("-DH_LEAVES=" + std::to_string(h % 4)); + unit.kernel = runTime->buildKernel("loop", KernelName, buildOptions, mOpenCLBackend->getPrecision(), mTensors[cmd->indexes()->data()[1]], mTensors[cmd->indexes()->data()[0]]); + uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); + std::vector mGlobalWorkSize = {(uint32_t)(UP_DIV(h, 4)), (uint32_t)(UP_DIV(e, 4)),(uint32_t)(n)}; + + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); + if(cmd->fuse() >= 0){ + ret |= unit.kernel->get().setArg(index++, *mFuseBuffer); + }else{ + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[mTensors[cmd->indexes()->data()[0]]]); + } + ret |= unit.kernel->get().setArg(index++, 
*mTmpBuffers[mTensors[cmd->indexes()->data()[1]]]); + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[mTensors[cmd->indexes()->data()[2]]]); + if (mHasBias) { + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[mTensors[cmd->indexes()->data()[3]]]); + } + for (int i = 0; i < cmd->iterIndexes()->size(); ++i) { if (mIter[i] >= 0) { - auto input = mTensors[cmd->iterIndexes()->data()[i]]; - std::vector Shape = tensorShapeFormat(input); - const int Channel = Shape.at(3); - const int Width = Shape.at(2); - const int Height = Shape.at(1); - const int Batch = Shape.at(0); - mOffsetBuffers.emplace_back(bufferPool->alloc(input->elementSize() * bufferUnitSize)); - - Unit unit; - _TileTensor(input, mOffsetBuffers.back(), unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, mBuildOptions); - mUnits.emplace_back(unit); - } - } - - // matmul - { - mTmpBuffers[0] = bufferPool->alloc(n * e * h * bufferUnitSize); - int offset_index = 0; - - Unit unit; - std::string KernelName = "batch_matmul"; - std::set buildOptions = mBuildOptions; - if (mHasBias) { - buildOptions.emplace("-DBIAS"); - } - if (mTransposeA) { - buildOptions.emplace("-DTRANSPOSE_A"); - } - if (mTransposeB) { - buildOptions.emplace("-DTRANSPOSE_B"); + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[mTensors[cmd->iterIndexes()->data()[i]]]); + } else { + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[mTensors[cmd->indexes()->data()[1]]]); } - buildOptions.emplace("-DH_LEAVES=" + std::to_string(h % 4)); - unit.kernel = runTime->buildKernel("loop", KernelName, buildOptions, mOpenCLBackend->getPrecision(), mTensors[cmd->indexes()->data()[1]], mTensors[cmd->indexes()->data()[0]]); - uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); - std::vector mGlobalWorkSize = {(uint32_t)(UP_DIV(h, 4)), (uint32_t)(UP_DIV(e, 4)),(uint32_t)(n)}; - - uint32_t index = 0; - cl_int ret = CL_SUCCESS; - ret |= 
unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); - ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[0]); - ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[1]); - ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[2]); - if (mHasBias) { - ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[3]); - } - for (int i = 0; i < cmd->iterIndexes()->size(); ++i) { - if (mIter[i] >= 0) { - ret |= unit.kernel->get().setArg(index++, *mOffsetBuffers[offset_index++]); - } else { - ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[0]); - } - } - ret |= unit.kernel->get().setArg(index++, e); - ret |= unit.kernel->get().setArg(index++, l); - ret |= unit.kernel->get().setArg(index++, h); - ret |= unit.kernel->get().setArg(index++, sizeof(mOffset), mOffset); - ret |= unit.kernel->get().setArg(index++, sizeof(mIter), mIter); - ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); - MNN_CHECK_CL_SUCCESS(ret, "setArg LoopBatchMatMulExecution"); - - std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, KernelName, unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop").first; - - unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; - mUnits.emplace_back(unit); - mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); - } - - //pack output - { - auto output = mTensors[cmd->indexes()->data()[0]]; - std::vector Shape = tensorShapeFormat(output); - const int Channel = Shape.at(3); - const int Width = Shape.at(2); - const int Height = Shape.at(1); - const int Batch = Shape.at(0); - Unit unit; - _PackTensor(mTmpBuffers[0], output, unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, mBuildOptions); - 
mUnits.emplace_back(unit); - } - - for (int i = 0; i < mTmpBuffers.size(); ++i) { - bufferPool->recycle(mTmpBuffers[i]); } - for (int i = 0; i < mOffsetBuffers.size(); ++i) { - bufferPool->recycle(mOffsetBuffers[i]); + ret |= unit.kernel->get().setArg(index++, e); + ret |= unit.kernel->get().setArg(index++, l); + ret |= unit.kernel->get().setArg(index++, h); + ret |= unit.kernel->get().setArg(index++, iter); + ret |= unit.kernel->get().setArg(index++, sizeof(mOffset), mOffset); + ret |= unit.kernel->get().setArg(index++, sizeof(mIter), mIter); + ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); + MNN_CHECK_CL_SUCCESS(ret, "setArg LoopBatchMatMulExecution"); + + std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, KernelName, unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop").first; + + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + mUnits.emplace_back(unit); + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + + if(cmd->fuse() >= 0){ + int mStride_dst[4]; + mStride_dst[0] = h * e; + mStride_dst[1] = h; + mStride_dst[2] = 1; + mStride_dst[3] = 1; + FuseOutput(cmdIndex, mStride_dst, 1, e, h, n, iter); } return NO_ERROR; } -LoopBinaryExecution::LoopBinaryExecution(const LoopParam *loop, const std::string &compute, const MNN::Op *op, Backend *bn) - : CommonExecution(bn, op) { - mLoop = loop; - mTensors.resize(mLoop->tensorNumber()); - mBuildOptions.emplace("-DOPERATOR=" + compute); -} - -ErrorCode LoopBinaryExecution::cumSumOnEncode(const std::vector &inputs, const std::vector &outputs) { - auto cmd = mLoop->commands()->GetAs(0); +ErrorCode LoopExecution::LoopBinary(int cmdIndex, int iter) { + auto cmd = mLoop->commands()->GetAs(cmdIndex); + auto output = mTensors[cmd->indexes()->data()[0]]; + auto input0 = mTensors[cmd->indexes()->data()[1]]; + auto input1 = 
mTensors[cmd->indexes()->data()[2]]; + std::string compute = getComputeOption(cmd->op()->main_as_BinaryOp()->opType()); + std::set buildOptions; + buildOptions.emplace("-DOPERATOR=" + compute); + if(cmd->op()->main_as_BinaryOp()->opType() == BinaryOpOperation_MOD && (output->getType().code == halide_type_int || output->getType().code == halide_type_uint)){ + buildOptions.emplace("-DINT_COMPUTE_MOD"); + } OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); - auto bufferPool = mOpenCLBackend->getBufferPool(); auto runTime = mOpenCLBackend->getOpenCLRuntime(); - auto bufferUnitSize = mOpenCLBackend->getPrecision() != BackendConfig::Precision_High ? sizeof(half_float::half) : sizeof(float); - mUnits.clear(); - mTmpBuffers.resize(2); - mOffset[0] = cmd->view()->GetAs(0)->offset(); - mOffset[1] = cmd->view()->GetAs(1)->offset(); - mOffset[2] = cmd->view()->GetAs(2)->offset(); - ::memcpy(mStep, cmd->steps()->data(), cmd->steps()->size() * sizeof(int)); - int loopNumber = mLoop->loopNumber(); + + int mOffset[4]; + int mStep[4]; + int mIter[4]; + int mStride_src0[3]; + int mStride_src1[3]; + int mStride_dst[3]; + + Unit unit; int z = cmd->size()->data()[0]; int y = cmd->size()->data()[1]; int x = cmd->size()->data()[2]; - int n = mLoop->loopNumber(); + int n = mLoop->parallel() ? 
mLoop->loopNumber() : 1; int inputSize = mTensors[cmd->indexes()->data()[1]]->elementSize(); + int outputSize = mTensors[cmd->indexes()->data()[0]]->elementSize(); auto src0Stride = cmd->view()->GetAs(1)->stride()->data(); auto src1Stride = cmd->view()->GetAs(2)->stride()->data(); @@ -549,227 +561,324 @@ ErrorCode LoopBinaryExecution::cumSumOnEncode(const std::vector &input mStride_src1[i] = src1Stride[i]; mStride_dst[i] = dstStride[i]; } - - // tile input - // mTensors cmd->indexes()->data() = {2, 0, 1} -> {output, input0, input1}, output = input0 - for (int i = 1; i < cmd->indexes()->size(); ++i) { - auto input = mTensors[cmd->indexes()->data()[i]]; - std::vector Shape = tensorShapeFormat(input); - const int Channel = Shape.at(3); - const int Width = Shape.at(2); - const int Height = Shape.at(1); - const int Batch = Shape.at(0); - mTmpBuffers[i - 1] = bufferPool->alloc(Batch * Channel * ROUND_UP(Height, 4) * ROUND_UP(Width, 4) * bufferUnitSize); - - Unit unit; - _TileTensor(input, mTmpBuffers[i - 1], unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, mBuildOptions); - mUnits.emplace_back(unit); + if(cmd->fuse() >= 0){ + mStride_dst[0] = y * x; + mStride_dst[1] = x; + mStride_dst[2] = 1; } - { - Unit unit; - std::set buildOptions = mBuildOptions; - unit.kernel = runTime->buildKernel("loop", "loop_cumsum", buildOptions, mOpenCLBackend->getPrecision(), mTensors[cmd->indexes()->data()[1]], mTensors[cmd->indexes()->data()[0]]); - uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); - - std::vector mGlobalWorkSize = {(uint32_t)(x), (uint32_t)(y), (uint32_t)(z)}; - uint32_t index = 0; - cl_int ret = CL_SUCCESS; - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); - ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); - ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[0]); // cumsum input0 == output 
-> mTmpBuffers[0] == mTmpBuffers[2] - ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[0]); - ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[1]); - ret |= unit.kernel->get().setArg(index++, mStride_src0[0]); - ret |= unit.kernel->get().setArg(index++, mStride_src0[1]); - ret |= unit.kernel->get().setArg(index++, mStride_src0[2]); - ret |= unit.kernel->get().setArg(index++, mStride_src1[0]); - ret |= unit.kernel->get().setArg(index++, mStride_src1[1]); - ret |= unit.kernel->get().setArg(index++, mStride_src1[2]); - ret |= unit.kernel->get().setArg(index++, mStride_dst[0]); - ret |= unit.kernel->get().setArg(index++, mStride_dst[1]); - ret |= unit.kernel->get().setArg(index++, mStride_dst[2]); - ret |= unit.kernel->get().setArg(index++, loopNumber); - ret |= unit.kernel->get().setArg(index++, sizeof(mOffset), mOffset); - ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); - MNN_CHECK_CL_SUCCESS(ret, "setArg LoopCumsumExecution"); - - std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, "loop_cumsum", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop").first; - - unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; - mUnits.emplace_back(unit); - mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + ::memcpy(mStep, cmd->steps()->data(), cmd->steps()->size() * sizeof(int)); + ::memcpy(mIter, cmd->iterIndexes()->data(), cmd->iterIndexes()->size() * sizeof(int)); + mOffset[0] = cmd->view()->GetAs(0)->offset(); + mOffset[1] = cmd->view()->GetAs(1)->offset(); + mOffset[2] = cmd->view()->GetAs(2)->offset(); + + if (mIter[0] >= 0) { + buildOptions.emplace("-DOFFSET_DST"); } + if (mIter[1] >= 0) { + buildOptions.emplace("-DOFFSET_SRC0"); + } + if (mIter[2] >= 0) { + buildOptions.emplace("-DOFFSET_SRC1"); + } + unit.kernel = runTime->buildKernel("loop", "loop_binary", 
buildOptions, mOpenCLBackend->getPrecision(), input0, output); + uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); - //pack output - { - auto output = mTensors[cmd->indexes()->data()[0]]; - std::vector Shape = tensorShapeFormat(output); - const int Channel = Shape.at(3); - const int Width = Shape.at(2); - const int Height = Shape.at(1); - const int Batch = Shape.at(0); - Unit unit; - _PackTensor(mTmpBuffers[0], output, unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, mBuildOptions); - mUnits.emplace_back(unit); + std::vector mGlobalWorkSize = {(uint32_t)(x), (uint32_t)(y), (uint32_t)(z*n)}; + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); + if(cmd->fuse() >= 0){ + ret |= unit.kernel->get().setArg(index++, *mFuseBuffer); + }else{ + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[output]); } + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[input0]); + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[input1]); + for (int i = 0; i < cmd->iterIndexes()->size(); ++i) { + if (mIter[i] >= 0) { + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[mTensors[cmd->iterIndexes()->data()[i]]]); + } + } + ret |= unit.kernel->get().setArg(index++, mStride_src0[0]); + ret |= unit.kernel->get().setArg(index++, mStride_src0[1]); + ret |= unit.kernel->get().setArg(index++, mStride_src0[2]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[0]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[1]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[2]); + ret |= unit.kernel->get().setArg(index++, mStride_dst[0]); + ret |= unit.kernel->get().setArg(index++, mStride_dst[1]); + ret |= unit.kernel->get().setArg(index++, mStride_dst[2]); + ret |= 
unit.kernel->get().setArg(index++, iter); + ret |= unit.kernel->get().setArg(index++, z); + ret |= unit.kernel->get().setArg(index++, sizeof(mOffset), mOffset); + ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); + ret |= unit.kernel->get().setArg(index++, outputSize); + MNN_CHECK_CL_SUCCESS(ret, "setArg LoopBinaryExecution"); - for (int i = 0; i < mTmpBuffers.size(); ++i) { - bufferPool->recycle(mTmpBuffers[i]); - } + std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, "loop_binary", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop").first; + + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + mUnits.emplace_back(unit); + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + if(cmd->fuse() >= 0){ + FuseOutput(cmdIndex, mStride_dst, cmd->size()->data()[0], cmd->size()->data()[1], cmd->size()->data()[2], n, iter); + } return NO_ERROR; } - -ErrorCode LoopBinaryExecution::onEncode(const std::vector &inputs, const std::vector &outputs) { - auto cmd = mLoop->commands()->GetAs(0); - if(cmd->op()->main_as_BinaryOp()->opType() == BinaryOpOperation_MOD && (outputs[0]->getType().code == halide_type_int || outputs[0]->getType().code == halide_type_uint)){ - mBuildOptions.emplace("-DINT_COMPUTE_MOD"); +ErrorCode LoopExecution::LoopCumsum() { + auto cmd = mLoop->commands()->GetAs(0); + std::string compute = getComputeOption(cmd->op()->main_as_BinaryOp()->opType()); + + auto output = mTensors[cmd->indexes()->data()[0]]; + auto input0 = mTensors[cmd->indexes()->data()[1]]; + auto input1 = mTensors[cmd->indexes()->data()[2]]; + std::set buildOptions; + buildOptions.emplace("-DOPERATOR=" + compute); + if(cmd->op()->main_as_BinaryOp()->opType() == BinaryOpOperation_MOD && (output->getType().code == halide_type_int || output->getType().code == halide_type_uint)){ + 
buildOptions.emplace("-DINT_COMPUTE_MOD"); } OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); auto runTime = mOpenCLBackend->getOpenCLRuntime(); - _setTensorStack(mTensors, inputs, outputs, mLoop); - // cumsum - if(!mLoop->parallel()) - return cumSumOnEncode(inputs, outputs); + int mOffset[4]; + int mStep[4]; + int mIter[4]; + int mStride_src0[3]; + int mStride_src1[3]; + int mStride_dst[3]; - mUnits.clear(); Unit unit; - auto input0 = mTensors[cmd->indexes()->data()[1]]; - std::vector Input0Shape = tensorShapeFormat(input0); - int Input0Size[4] = {Input0Shape.at(2), Input0Shape.at(1),Input0Shape.at(3),Input0Shape.at(0)}; - - auto input1 = mTensors[cmd->indexes()->data()[2]]; - std::vector Input1Shape = tensorShapeFormat(input1); - int Input1Size[4] = {Input1Shape.at(2), Input1Shape.at(1),Input1Shape.at(3),Input1Shape.at(0)}; - - auto output = mTensors[cmd->indexes()->data()[0]]; - std::vector Shape = tensorShapeFormat(output); + int z = cmd->size()->data()[0]; + int y = cmd->size()->data()[1]; + int x = cmd->size()->data()[2]; + int n = mLoop->parallel() ? 
mLoop->loopNumber() : 1; + int inputSize = mTensors[cmd->indexes()->data()[1]]->elementSize(); + int outputSize = mTensors[cmd->indexes()->data()[0]]->elementSize(); - bool broadcastInput0 = false; - bool broadcastInput1 = false; - int input0Shape[8] = {1, 1, 1, 1, 1, 1, 1, 1}; - int input1Shape[8] = {1, 1, 1, 1, 1, 1, 1, 1}; - int outputShape[8] = {1, 1, 1, 1, 1, 1, 1, 1}; + auto src0Stride = cmd->view()->GetAs(1)->stride()->data(); + auto src1Stride = cmd->view()->GetAs(2)->stride()->data(); + auto dstStride = cmd->view()->GetAs(0)->stride()->data(); + for (int i = 0; i < 3; ++i) { + mStride_src0[i] = src0Stride[i]; + mStride_src1[i] = src1Stride[i]; + mStride_dst[i] = dstStride[i]; + } + if(cmd->fuse() >= 0){ + mStride_dst[0] = y * x; + mStride_dst[1] = x; + mStride_dst[2] = 1; + } + + // cumsum + // mTensors cmd->indexes()->data() = {2, 0, 1} -> {output, input0, input1}, output = input0 + int loopNumber = mLoop->loopNumber(); + + ::memcpy(mStep, cmd->steps()->data(), cmd->steps()->size() * sizeof(int)); + mOffset[0] = cmd->view()->GetAs(0)->offset(); + mOffset[1] = cmd->view()->GetAs(1)->offset(); + mOffset[2] = cmd->view()->GetAs(2)->offset(); + unit.kernel = runTime->buildKernel("loop", "loop_cumsum", buildOptions, mOpenCLBackend->getPrecision(), input0, output); + uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); + + std::vector mGlobalWorkSize = {(uint32_t)(x), (uint32_t)(y), (uint32_t)(z)}; + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[output]); + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[input0]); + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[input1]); + ret |= unit.kernel->get().setArg(index++, mStride_src0[0]); + ret |= 
unit.kernel->get().setArg(index++, mStride_src0[1]); + ret |= unit.kernel->get().setArg(index++, mStride_src0[2]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[0]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[1]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[2]); + ret |= unit.kernel->get().setArg(index++, mStride_dst[0]); + ret |= unit.kernel->get().setArg(index++, mStride_dst[1]); + ret |= unit.kernel->get().setArg(index++, mStride_dst[2]); + ret |= unit.kernel->get().setArg(index++, loopNumber); + ret |= unit.kernel->get().setArg(index++, sizeof(mOffset), mOffset); + ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); + ret |= unit.kernel->get().setArg(index++, outputSize); + MNN_CHECK_CL_SUCCESS(ret, "setArg LoopCumsumExecution"); + + std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, "loop_cumsum", unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop").first; + + + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + mUnits.emplace_back(unit); + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + + return NO_ERROR; +} - int offset0 = output->dimensions() - input0->dimensions(); - int offset1 = output->dimensions() - input1->dimensions(); - for (int i = 0; i < input0->dimensions(); ++i) { - input0Shape[i + offset0] = input0->length(i); - } - for (int i = 0; i < input1->dimensions(); ++i) { - input1Shape[i + offset1] = input1->length(i); - } - for(int i =0;idimensions();++i){ - outputShape[i] = output->length(i); - } - if (TensorUtils::getDescribe(input0)->dimensionFormat == MNN::MNN_DATA_FORMAT_NHWC) - { - int iN = input0Shape[0]; - int iH = input0Shape[1]; - int iW = input0Shape[2]; - int iC = input0Shape[3]; - - if(input0->dimensions() > 4) - { - for(int i = 4; i < input0->dimensions(); i++) - { - iC *= input0Shape[i]; - } - } - 
input0Shape[0] = iN; - input0Shape[1] = iC; - input0Shape[2] = iH; - input0Shape[3] = iW; - input0Shape[4] = 1; - } - if (TensorUtils::getDescribe(input1)->dimensionFormat == MNN::MNN_DATA_FORMAT_NHWC) - { - int iN = input1Shape[0]; - int iH = input1Shape[1]; - int iW = input1Shape[2]; - int iC = input1Shape[3]; - - if(input1->dimensions() > 4) - { - for(int i = 4; i < input1->dimensions(); i++) - { - iC *= input1Shape[i]; - } - } - input1Shape[0] = iN; - input1Shape[1] = iC; - input1Shape[2] = iH; - input1Shape[3] = iW; - input1Shape[4] = 1; - } - if (TensorUtils::getDescribe(output)->dimensionFormat == MNN::MNN_DATA_FORMAT_NHWC) - { - int iN = outputShape[0]; - int iH = outputShape[1]; - int iW = outputShape[2]; - int iC = outputShape[3]; - - if(input1->dimensions() > 4) - { - for(int i = 4; i < input1->dimensions(); i++) - { - iC *= outputShape[i]; - } - } - input1Shape[0] = iN; - outputShape[1] = iC; - outputShape[2] = iH; - outputShape[3] = iW; - outputShape[4] = 1; +ErrorCode LoopExecution::FuseOutput(int iter, int* inputStride, int sizeZ, int sizeY, int SizeX, int n, int n_offset) { + auto cmd = mLoop->commands()->GetAs(iter); + std::string compute = getComputeOption(MNN::BinaryOpOperation(cmd->fuse())); + std::set buildOptions; + buildOptions.emplace("-DOPERATOR=" + compute); + OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); + auto runTime = mOpenCLBackend->getOpenCLRuntime(); + + int mOffset[4]; + int mStep[4]; + int mIter[4]; + int mStride_src0[3]; + int mStride_src1[3]; + int mStride_dst[3]; + auto output = mTensors[cmd->indexes()->data()[0]]; + int outputSize = output->elementSize(); + + Unit unit; + int z = sizeZ; + int y = sizeY; + int x = SizeX; + + auto dstStride = cmd->view()->GetAs(0)->stride()->data(); + for (int i = 0; i < 3; ++i) { + mStride_src0[i] = dstStride[i]; + mStride_src1[i] = inputStride[i]; + mStride_dst[i] = dstStride[i]; } - - const int Channel = Shape.at(3); - const int Width = Shape.at(2); - const int Height = 
Shape.at(1); - const int Batch = Shape.at(0); - const int ChannelBlock = UP_DIV(Channel, 4); - auto BuildOptions = mBuildOptions; - std::string KernelName = "broadcast_binary"; - unit.kernel = runTime->buildKernel("loop", KernelName, BuildOptions, mOpenCLBackend->getPrecision(), input0, output); + + for(int i = 0; i < 4; ++i){ + mStep[i] = cmd->steps()->data()[0]; + } + ::memcpy(mIter, cmd->iterIndexes()->data(), cmd->iterIndexes()->size() * sizeof(int)); + mOffset[0] = cmd->view()->GetAs(0)->offset(); + mOffset[1] = cmd->view()->GetAs(0)->offset(); + mOffset[2] = cmd->view()->GetAs(0)->offset(); + + if (mIter[0] >= 0) { + buildOptions.emplace("-DOFFSET_DST"); + } + if (mIter[0] >= 0) { + buildOptions.emplace("-DOFFSET_SRC0"); + } + if (mIter[0] >= 0) { + buildOptions.emplace("-DOFFSET_SRC1"); + } + unit.kernel = runTime->buildKernel("loop", "loop_binary", buildOptions, mOpenCLBackend->getPrecision(), output, output); uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); - - std::vector mGlobalWorkSize = {(uint32_t)(Width), (uint32_t)(Height), (uint32_t)(Batch * ChannelBlock)}; + + std::vector mGlobalWorkSize = {(uint32_t)(x), (uint32_t)(y), (uint32_t)(z*n)}; uint32_t index = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); - ret |= unit.kernel->get().setArg(index++, openCLImage(output)); - ret |= unit.kernel->get().setArg(index++, openCLImage(input0)); - ret |= unit.kernel->get().setArg(index++, openCLImage(input1)); - ret |= unit.kernel->get().setArg(index++, sizeof(input0Shape), input0Shape); - ret |= unit.kernel->get().setArg(index++, sizeof(Input0Size), Input0Size); - ret |= unit.kernel->get().setArg(index++, sizeof(input1Shape), input1Shape); - ret |= unit.kernel->get().setArg(index++, sizeof(Input1Size), Input1Size); - ret |= unit.kernel->get().setArg(index++, 
sizeof(outputShape), outputShape); - ret |= unit.kernel->get().setArg(index++, Width); - ret |= unit.kernel->get().setArg(index++, Height); - ret |= unit.kernel->get().setArg(index++, Channel); - ret |= unit.kernel->get().setArg(index++, ChannelBlock); + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[output]); + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[output]); + ret |= unit.kernel->get().setArg(index++, *mFuseBuffer); + if (mIter[0] >= 0) { + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[mTensors[cmd->iterIndexes()->data()[0]]]); + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[mTensors[cmd->iterIndexes()->data()[0]]]); + ret |= unit.kernel->get().setArg(index++, *mTmpBuffers[mTensors[cmd->iterIndexes()->data()[0]]]); + } + ret |= unit.kernel->get().setArg(index++, mStride_src0[0]); + ret |= unit.kernel->get().setArg(index++, mStride_src0[1]); + ret |= unit.kernel->get().setArg(index++, mStride_src0[2]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[0]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[1]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[2]); + ret |= unit.kernel->get().setArg(index++, mStride_dst[0]); + ret |= unit.kernel->get().setArg(index++, mStride_dst[1]); + ret |= unit.kernel->get().setArg(index++, mStride_dst[2]); + ret |= unit.kernel->get().setArg(index++, n_offset); + ret |= unit.kernel->get().setArg(index++, z); + ret |= unit.kernel->get().setArg(index++, sizeof(mOffset), mOffset); + ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); + ret |= unit.kernel->get().setArg(index++, outputSize); MNN_CHECK_CL_SUCCESS(ret, "setArg LoopBinaryExecution"); - std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, KernelName, unit.kernel, mOpenCLBackend->getCLTuneLevel(), "loop").first; + std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, "loop_binary", unit.kernel, 
mOpenCLBackend->getCLTuneLevel(), "loop").first; unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; - mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); mUnits.emplace_back(unit); - - + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); return NO_ERROR; } +ErrorCode LoopExecution::onEncode(const std::vector &inputs, const std::vector &outputs){ + OpenCLBackend *mOpenCLBackend = (OpenCLBackend *)backend(); + auto runTime = mOpenCLBackend->getOpenCLRuntime(); + _setTensorStack(mTensors, inputs, outputs, mLoop); + mUnits.clear(); + // convert all image to buffer + ImageToBufferAllTensor(); + // Make Temp output buffer + auto bufferPool = mOpenCLBackend->getBufferPool(); + int bufferUnitSize = mOpenCLBackend->getPrecision() != BackendConfig::Precision_High ? sizeof(half_float::half) : sizeof(float); + int mMaxFuseBufferSize = 0; + int loopNumber = mLoop->parallel() ? 
1 : mLoop->loopNumber(); + for (int i=0; icommands()->size(); ++i) { + auto cmd = mLoop->commands()->GetAs(i); + auto op = cmd->op(); + if (cmd->fuse() >= 0) { + // Make Temp output buffer + auto size = cmd->size()->data(); + if (cmd->op()->type() == OpType_MatMul) { + mMaxFuseBufferSize = std::max(mMaxFuseBufferSize, bufferUnitSize * size[0] * size[2]); + } else { + mMaxFuseBufferSize = std::max(mMaxFuseBufferSize, bufferUnitSize * size[0] * size[1] * size[2]); + } + } + } + if(mMaxFuseBufferSize != 0){ + mFuseBuffer = bufferPool->alloc(mMaxFuseBufferSize * bufferUnitSize); + } + if(mLoop->initCommand() != nullptr){ + InitCommandOnEncode(); + } + if (1 == mLoop->commands()->size()) { + auto cmd = mLoop->commands()->GetAs(0); + auto op = cmd->op(); + if (OpType_UnaryOp == op->type() && nullptr == op->main() && cmd->fuse() < 0) { + LoopGather(0, 0); + // convert all output buffer to image + BufferToImageOutputTensor(outputs); + return NO_ERROR; + } + if(OpType_BinaryOp == op->type() && mLoop->parallel() == false && cmd->fuse() < 0){ + LoopCumsum(); + // convert all output buffer to image + BufferToImageOutputTensor(outputs); + return NO_ERROR; + } + } + for(int iter = 0; iter < loopNumber; ++iter){ + for (int index = 0; indexcommands()->size(); ++index) { + auto cmd = mLoop->commands()->GetAs(index); + auto op = cmd->op(); + if (OpType_UnaryOp == op->type()){ + LoopGather(index, iter); + }else if (OpType_MatMul == op->type()){ + LoopBatchMatMul(index, iter); + }else if(OpType_BinaryOp == op->type()){ + LoopBinary(index, iter); + } + } + } + + // convert all output buffer to image + BufferToImageOutputTensor(outputs); + if(mMaxFuseBufferSize != 0){ + bufferPool->recycle(mFuseBuffer); + } + return NO_ERROR; +} class LoopCreator : public OpenCLBackend::Creator { public: @@ -779,61 +888,7 @@ class LoopCreator : public OpenCLBackend::Creator { if (nullptr == loop || loop->commands() == nullptr) { return nullptr; } - // Make Tensor Stack - if (1 == 
loop->commands()->size()) { - auto cmd = loop->commands()->GetAs(0); - auto subop = cmd->op(); - if (OpType_UnaryOp == subop->type() && nullptr == subop->main() && cmd->fuse() < 0) { - return new LoopGatherExecution(loop, op, backend); - } - if (OpType_MatMul == subop->type() && loop->parallel() && nullptr == loop->initCommand()) { - return new LoopBatchMatMulExecution(loop, op, backend); - } - if (OpType_BinaryOp == subop->type() && nullptr == loop->initCommand()) { - switch (subop->main_as_BinaryOp()->opType()) { - case BinaryOpOperation_MUL: - return new LoopBinaryExecution(loop, "in0*in1", op, backend); - case BinaryOpOperation_ADD: - return new LoopBinaryExecution(loop, "in0+in1", op, backend); - case BinaryOpOperation_SUB: - return new LoopBinaryExecution(loop, "in0-in1", op, backend); - case BinaryOpOperation_REALDIV: - return new LoopBinaryExecution(loop, "sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001))", op, backend); - case BinaryOpOperation_MINIMUM: - return new LoopBinaryExecution(loop, "in0>in1?in1:in0", op, backend); - case BinaryOpOperation_MAXIMUM: - return new LoopBinaryExecution(loop, "in0>in1?in0:in1", op, backend); - case BinaryOpOperation_GREATER: - return new LoopBinaryExecution(loop, "convert_float4(-isgreater(in0,in1))", op, backend); - case BinaryOpOperation_LESS: - return new LoopBinaryExecution(loop, "convert_float4(-isless(in0,in1))", op, backend); - case BinaryOpOperation_LESS_EQUAL: - return new LoopBinaryExecution(loop, "convert_float4(-islessequal(in0,in1))", op, backend); - case BinaryOpOperation_GREATER_EQUAL: - return new LoopBinaryExecution(loop, "convert_float4(-isgreaterequal(in0,in1))", op, backend); - case BinaryOpOperation_EQUAL: - return new LoopBinaryExecution(loop, "convert_float4(-isequal(in0,in1))", op, backend); - case BinaryOpOperation_FLOORDIV: - return new LoopBinaryExecution(loop, 
"floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))", op, backend); - case BinaryOpOperation_FLOORMOD: - return new LoopBinaryExecution(loop, "in0-floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))*in1", op, backend); - case BinaryOpOperation_POW: - return new LoopBinaryExecution(loop, "pow(in0,in1)", op, backend); - case BinaryOpOperation_SquaredDifference: - return new LoopBinaryExecution(loop, "(in0-in1)*(in0-in1)", op, backend); - case BinaryOpOperation_ATAN2: - return new LoopBinaryExecution(loop, "(in1==(float4)0?(sign(in0)*(float4)(PI/2)):(atan(in0/in1)+(in1>(float4)0?(float4)0:sign(in0)*(float4)PI)))", op, backend); - case BinaryOpOperation_NOTEQUAL: - return new LoopBinaryExecution(loop, "convert_float4(-isnotequal(in0,in1))", op, backend); - case BinaryOpOperation_MOD: - return new LoopBinaryExecution(loop, "in0-floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))*in1", op, backend); - default: - break; - } - return nullptr; - } - } - return nullptr; + return new LoopExecution(loop, op, backend); } }; diff --git a/source/backend/opencl/execution/image/LoopExecution.hpp b/source/backend/opencl/execution/image/LoopExecution.hpp index bf63654ba5..12c9c2b53a 100644 --- a/source/backend/opencl/execution/image/LoopExecution.hpp +++ b/source/backend/opencl/execution/image/LoopExecution.hpp @@ -15,63 +15,24 @@ namespace MNN { namespace OpenCL { -class LoopGatherExecution : public CommonExecution { +class LoopExecution : public CommonExecution{ public: - LoopGatherExecution(const LoopParam *loop, const MNN::Op *op, Backend *bn); - virtual ~LoopGatherExecution() = default; + LoopExecution(const LoopParam *loop, const MNN::Op *op, Backend *bn); + virtual ~LoopExecution() = default; virtual ErrorCode onEncode(const std::vector &inputs, const std::vector &outputs) override; - ErrorCode InitCommandOnEncode(const std::vector &inputs, 
const std::vector &outputs); - -private: - const LoopParam *mLoop; - std::vector mTensors; - std::vector mTmpInitBuffers; - std::vector mTmpBuffers; - std::vector mOffsetBuffers; - int mStride_src[4]; - int mStride_dst[4]; - int mStep[2]; - int mIter[2]; - std::set mBuildOptions; -}; - -class LoopBatchMatMulExecution : public CommonExecution { -public: - LoopBatchMatMulExecution(const LoopParam *loop, const MNN::Op *op, Backend *bn); - virtual ~LoopBatchMatMulExecution() = default; - virtual ErrorCode onEncode(const std::vector &inputs, const std::vector &outputs) override; - -private: - const LoopParam *mLoop; - std::vector mTensors; - std::vector mTmpBuffers; - std::vector mOffsetBuffers; - int mOffset[4]; - int mStep[4]; - int mIter[4]; - bool mHasBias = false; - bool mTransposeA = false; - bool mTransposeB = false; - std::set mBuildOptions; -}; - -class LoopBinaryExecution : public CommonExecution { -public: - LoopBinaryExecution(const LoopParam *loop, const std::string &compute, const MNN::Op *op, Backend *bn); - virtual ~LoopBinaryExecution() = default; - virtual ErrorCode onEncode(const std::vector &inputs, const std::vector &outputs) override; - ErrorCode cumSumOnEncode(const std::vector &inputs, const std::vector &outputs); - + void ImageToBufferAllTensor(); + void BufferToImageOutputTensor(const std::vector &outputs); + ErrorCode InitCommandOnEncode(); + ErrorCode LoopGather(int cmdIndex, int iter); + ErrorCode LoopBatchMatMul(int cmdIndex, int iter); + ErrorCode LoopBinary(int cmdIndex, int iter); + ErrorCode LoopCumsum(); + ErrorCode FuseOutput(int iter, int* inputStride, int sizeZ, int sizeY, int SizeX, int n, int n_offset); private: - int mOffset[4]; - int mStep[4]; - int mStride_src0[3]; - int mStride_src1[3]; - int mStride_dst[3]; const LoopParam *mLoop; std::vector mTensors; - std::set mBuildOptions; - std::vector mTmpBuffers; + cl::Buffer* mFuseBuffer; + std::map mTmpBuffers; }; } // namespace OpenCL diff --git a/source/backend/qnn/CMakeLists.txt 
b/source/backend/qnn/CMakeLists.txt index b3b9a67c2c..b2cf9aa7fc 100644 --- a/source/backend/qnn/CMakeLists.txt +++ b/source/backend/qnn/CMakeLists.txt @@ -3,7 +3,13 @@ option(MNN_QNN_CONVERT_MODE "Enable the Convert mode of the QNN backend." OFF) file(GLOB BACKEND_SRCS ${CMAKE_CURRENT_LIST_DIR}/backend/*.cpp) file(GLOB EXECUTION_SRCS ${CMAKE_CURRENT_LIST_DIR}/execution/*.cpp) set(MNN_QNN_SRCS ${BACKEND_SRCS} ${EXECUTION_SRCS}) -message(STATUS "QNN Root: $ENV{QNN_SDK_ROOT}") + +# Prefer CMake variable QNN_SDK_ROOT (passed via -DQNN_SDK_ROOT=...), fallback to environment +set(_QNN_ROOT "$ENV{QNN_SDK_ROOT}") +if (DEFINED QNN_SDK_ROOT AND NOT "${QNN_SDK_ROOT}" STREQUAL "") + set(_QNN_ROOT "${QNN_SDK_ROOT}") +endif() +message(STATUS "QNN Root: ${_QNN_ROOT}") if (MNN_QNN_CONVERT_MODE) file(GLOB CONVERTOR_SRCS ${CMAKE_CURRENT_LIST_DIR}/convertor/*.cpp) @@ -26,4 +32,4 @@ endif() target_include_directories(MNN_QNN PRIVATE ${CMAKE_CURRENT_LIST_DIR}/backend/) target_include_directories(MNN_QNN PRIVATE ${CMAKE_CURRENT_LIST_DIR}/convertor/) -target_include_directories(MNN_QNN PRIVATE $ENV{QNN_SDK_ROOT}/include/QNN/) +target_include_directories(MNN_QNN PRIVATE ${_QNN_ROOT}/include/QNN/) diff --git a/source/core/Backend.hpp b/source/core/Backend.hpp index 231ea4a137..bcf618c3c9 100644 --- a/source/core/Backend.hpp +++ b/source/core/Backend.hpp @@ -36,12 +36,11 @@ struct RuntimeHint { // qkvQuantOption % 8: // 0: Do not quantize - // 1: Only quantize key, use int8 asymmetric quantization - // 2: Only quantize value, use fp8 quantization - // 3: quantize both key and value - // 4: quantize query, key and value, and use gemm int8 kernel to compute K*V + // 1: Q,K: Int8, V: Float + // 2: Q,K,V: Int8 // qkvQuantOption / 8: + // 0: don't use flash attention // 1: use flash attention int qkvQuantOption = 8; @@ -53,8 +52,11 @@ struct RuntimeHint { int kvcacheSizeLimit = -1; // path of the kvcache directory - std::string kvcacheDirPath = "/tmp"; + std::string kvcacheDirPath = ""; + // 
path of the kvcache directory + std::string prefixcacheDirPath = "prefixcache"; + std::string midMemoryPath; std::string weightMemoryPath; int mmapFileSize = 1024; // MB @@ -63,7 +65,7 @@ struct RuntimeHint { // op encoder number for once commit int encorderNumForCommit = 10; int initThreadNumber = 0; - + // whether to use Arm sme2 cores when threads>1 bool useArmSme2Cores = true; @@ -71,6 +73,15 @@ struct RuntimeHint { // Use CPU Ids std::vector cpuIds; + + // Division ration between SME and NEON when runtime threads>=4 + // Default: 41, which means that in LLM inference, + // during the Prefill stage the workload + // per single SME core is six times that of NEON, + //while during the Decode stage it is the same (1×). + int divisionRatio = 41; + + int smeCores = 2; // Number of SME cores of the backend, default is 2, if supports sme }; /** abstract backend */ class Backend : public NonCopyable { @@ -120,7 +131,7 @@ class Backend : public NonCopyable { - releases memory when `onClearBuffer` is called or when the backend is deleted. */ DYNAMIC_SEPERATE, - + DYNAMIC_IN_EXECUTION }; @@ -411,7 +422,7 @@ class RuntimeCreator { virtual bool onGetDeviceInfo(const std::string& deviceKey, std::string& deviceValue) const { return false; } - + virtual bool onSetQuantInfo(const Op* op, const std::vector& inputs, const std::vector& outputs) const { return false; } diff --git a/source/core/KVCacheManager.cpp b/source/core/KVCacheManager.cpp new file mode 100644 index 0000000000..5ff3a1fee2 --- /dev/null +++ b/source/core/KVCacheManager.cpp @@ -0,0 +1,113 @@ +// +// KVCacheManager.cpp +// MNN +// +// Created by MNN on 2024/08/05. 
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE + +#include "KVCacheManager.hpp" +#include "core/Concurrency.h" + +namespace MNN { + +// Translate an address to a hex number string +static inline std::string addrToHex(void *addr) { + std::string result = ""; + uint64_t n = (uint64_t)addr; + for(int i = 15; i >= 0; i--) { + int t = (n >> (i * 4)) & 0x0f; + result.push_back((t < 10) ? ('0' + t) : ('A' + t - 10)); + } + return result; +} + +void KVCacheManager::createKVCacheFile(std::string keyPath, std::string valuePath) { + // Each layer has its own kvcache, so we have to create a key file and a value file for each layer and the file name must be unique + // Here we use the address of the mResource as the file name because the addresses of mResource in different layers are guaranteed to be different + std::string fileName = addrToHex(this); + mBaseFileName = MNNFilePathConcat(mConfig.mKVCacheDir, fileName); + + std::string pathk = keyPath.size() > 0 ? keyPath : mBaseFileName + ".k"; + std::string pathv = valuePath.size() > 0 ? 
valuePath : mBaseFileName + ".v"; + mKeyCacheFD = MNNCreateFile(pathk.c_str()); + mValueCacheFD = MNNCreateFile(pathv.c_str()); + if (mKeyCacheFD == INVALID_FILE) { + MNN_PRINT("Failed to create the file: %s\n", pathk.c_str()); + } + if (mValueCacheFD == INVALID_FILE) { + MNN_PRINT("Failed to create the file: %s\n", pathv.c_str()); + } +} + +void KVCacheManager::removeKVCacheFile() { + std::string pathk = mBaseFileName + ".k"; + std::string pathv = mBaseFileName + ".v"; + if (mKeyCacheFD != INVALID_FILE) { + MNNCloseFile(mKeyCacheFD); + mKeyCacheFD = INVALID_FILE; + if (MNNRemoveFile(pathk.c_str()) != MNN::NO_ERROR) { + MNN_PRINT("Failed to remove the file: %s\n", pathk.c_str()); + } + } + if (mValueCacheFD != INVALID_FILE) { + MNNCloseFile(mValueCacheFD); + mValueCacheFD = INVALID_FILE; + if (MNNRemoveFile(pathv.c_str()) != MNN::NO_ERROR) { + MNN_PRINT("Failed to remove the file: %s\n", pathv.c_str()); + } + } +} + +void KVCacheManager::resetKVCacheFileSize(size_t keySize, size_t valueSize) { + if (MNNSetFileSize(mKeyCacheFD, keySize) != MNN::NO_ERROR || MNNSetFileSize(mValueCacheFD, valueSize) != MNN::NO_ERROR) { + MNN_PRINT("Failed to resize the kvcache files!\n"); + } +} + +/* +** @brief Memory-map the kvcache file +** @hint After memory-mapping, we can access the kvcache files with pointers, just like accessing memory buffer +** But the data actually resides in disk. +** The OS will set some kernel page cache and manage the data swaping, which we do not need to care. +*/ +void KVCacheManager::mmapKVCache(size_t keySize, size_t valueSize, file_t specKeyFile, file_t specValueFile) +{ + // if keyFile or value file not given, use mKeyCacheFD or mValueCacheFD + auto keyFrom = specKeyFile != INVALID_FILE ? specKeyFile : mKeyCacheFD; + auto valueFrom = specValueFile != INVALID_FILE ? 
specValueFile : mValueCacheFD; + + if (mMapKeyAddr == nullptr) { + mMapKeyAddr = (int8_t *)MNNMmapFile(keyFrom, keySize); + if (mMapKeyAddr == nullptr) { + MNN_PRINT("Failed to memory-map the kvcache!\n"); + } + } + + if (mMapValueAddr == nullptr) { + mMapValueAddr = (int8_t *)MNNMmapFile(valueFrom, valueSize); + if (mMapValueAddr == nullptr) { + MNN_PRINT("Failed to memory-map the kvcache!\n"); + } + } +} + +void KVCacheManager::unmapKVCache(size_t keySize, size_t valueSize) +{ + if (mMapKeyAddr != nullptr) { + MNNMmapSync(mMapKeyAddr, keySize); + MNNUnmapFile(mMapKeyAddr, keySize); + mMapKeyAddr = nullptr; + } + if (mMapValueAddr != nullptr) { + MNNMmapSync(mMapValueAddr, valueSize); + MNNUnmapFile(mMapValueAddr, valueSize); + mMapValueAddr = nullptr; + } +} + +} // namespace MNN + +#endif // MNN_SUPPORT_TRANSFORMER_FUSE diff --git a/source/core/KVCacheManager.hpp b/source/core/KVCacheManager.hpp new file mode 100644 index 0000000000..918b96707e --- /dev/null +++ b/source/core/KVCacheManager.hpp @@ -0,0 +1,96 @@ +// +// KVCacheManager.hpp +// MNN +// +// Created by MNN on 2024/08/05. 
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE + +#ifndef KVCACHE_MANAGER_HPP +#define KVCACHE_MANAGER_HPP + +#include "core/Macro.h" +#include "core/MNNFileUtils.h" +#include "core/OpCommonUtils.hpp" + + +namespace MNN { + +class KVCacheManager : public NonCopyable{ +public: + struct KVCacheConfig { + std::string mKVCacheDir; // Path of the kvcache files in disk + std::string mPrefixCacheDir; // Path of the prefix prompt kvcache files in disk + int mExpandChunk = 64; // Number of expand chunks when the buffer is full + int mBlockNum = 1; + int mKvAlignNum; + }; +protected: + Backend * mBackend; + KVCacheConfig mConfig; + std::shared_ptr mPastKey; // {numhead, [maxlen/hP, headdim, hP]} or {numhead, [maxlen/hP8, headdim/lP8, hP8, lP8]} + std::shared_ptr mPastValue; // numhead, [headdim/hP, maxlen, hP] + file_t mKeyCacheFD = INVALID_FILE; // The file descriptor of keys + file_t mValueCacheFD = INVALID_FILE; // The file descriptor of values + int8_t * mMapKeyAddr = nullptr; // Memory-mapped address of keys + int8_t * mMapValueAddr = nullptr; // Memory-mapped address of values + bool mKVCacheInDisk = false; // Whether the kvcache is in disk or in memory now + bool mSaveShareKvPrefix = false; + int mPastLength = 0; // Length of past kvcache + int mMaxLength = 0; // Capacity of current kvcache buffer (how many kv items can be stored at most) + int mBytes = 4; + int mKvNumHead = 0, mHeadDim = 0; + KVMeta* mMeta; + std::string mBaseFileName; + std::string mBasePrefixFileName; + + void createKVCacheFile(std::string keyPath = "", std::string valuePath = ""); + void removeKVCacheFile(); + void resetKVCacheFileSize(size_t keySize, size_t valueSize); + void mmapKVCache(size_t keySize, size_t valueSize, file_t specKeyFile = INVALID_FILE, file_t specValueFile = INVALID_FILE); + void unmapKVCache(size_t keySize, size_t valueSize); + +public: + KVCacheManager(Backend * backend, KVCacheConfig & kvConfig) { + mBackend = backend; + 
mConfig = kvConfig; + } + ~KVCacheManager() { + // nothing todo + } + const Backend * backend() { + return mBackend; + } + const KVCacheConfig * config() { + return &mConfig; + } + const Tensor * key() { + return mPastKey.get(); + } + const Tensor * value() { + return mPastValue.get(); + } + + bool inDisk() { + return mKVCacheInDisk; + } + int kvLength() { + return mPastLength; + } + int maxLength() { + return mMaxLength; + } + + virtual void onResize(int kv_num_head, int head_dim) = 0; + virtual void onClear() = 0; + virtual void onAlloc(KVMeta* meta, int seq_len) = 0; + virtual void onRealloc(KVMeta* meta) = 0; +}; + +} // namespace MNN + +#endif // KVCACHE_MANAGER_HPP + +#endif // MNN_SUPPORT_TRANSFORMER_FUSE diff --git a/source/core/OpCommonUtils.hpp b/source/core/OpCommonUtils.hpp index ad5dd6348c..0740cc16b2 100644 --- a/source/core/OpCommonUtils.hpp +++ b/source/core/OpCommonUtils.hpp @@ -17,12 +17,22 @@ struct Op; struct CoreFunctions; #ifdef MNN_SUPPORT_TRANSFORMER_FUSE struct KVMeta { + enum { + NoChange, + PendingWrite, + PendingRead + } file_operation; size_t block = 4096; size_t previous = 0; size_t remove = 0; int* reserve = nullptr; int n_reserve = 0; size_t add = 0; + std::string file_name = ""; + int file_flag = NoChange; + int seqlen_in_disk = 0; + int layer_index = 0; + int layer_nums = 0; int computeReverseSize() const { int sum = 0; for (int i=0; i 0 ? 
true : false; break; + case Interpreter::CPU_SME2_NEON_DIVISION_RATIO: + runtimeHint.divisionRatio = value; + break; + case Interpreter::CPU_SME_CORES: + runtimeHint.smeCores = value; + break; default: break; } @@ -134,6 +140,9 @@ void Session::ModeGroup::setExternalPath(std::string path, int type) { case MNN::Interpreter::EXTERNAL_PATH_KVCACHE_DIR: runtimeHint.kvcacheDirPath = path; break; + case MNN::Interpreter::EXTERNAL_PATH_PREFIXCACHE_DIR: + runtimeHint.prefixcacheDirPath = path; + break; case MNN::Interpreter::EXTERNAL_FEATUREMAP_DIR: runtimeHint.midMemoryPath = path; break; diff --git a/source/core/TensorUtils.hpp b/source/core/TensorUtils.hpp index 2b9937244e..1342a669bd 100644 --- a/source/core/TensorUtils.hpp +++ b/source/core/TensorUtils.hpp @@ -100,7 +100,6 @@ struct Tensor::InsideDescribe { int useCount = 0; Usage usage = NORMAL; std::vector regions; - bool overlap = false; halide_dimension_t dims[MNN_MAX_TENSOR_DIM]; // TensorArray Attribute std::shared_ptr tensorArrayAttr; @@ -108,6 +107,7 @@ struct Tensor::InsideDescribe { std::shared_ptr quantAttr; bool applyQuant = false; bool isMutable = true; + bool overlap = false; // Only used by strideSliceWrite now int index = -1; int group = 0; int channel_pack_num = 4; diff --git a/source/geometry/GeometryStridedSlice.cpp b/source/geometry/GeometryStridedSlice.cpp index fffa158570..1cd685d3db 100644 --- a/source/geometry/GeometryStridedSlice.cpp +++ b/source/geometry/GeometryStridedSlice.cpp @@ -333,10 +333,10 @@ class GeometryStridedSlice : public GeometryComputer { reg.origin = write; } Tensor::InsideDescribe::Region region; - region.size[2] = (int)TensorUtils::getRawSize(input); + region.size[2] = input->elementSize(); region.origin = input; outputDes->regions.insert(outputDes->regions.begin(), region); - outputDes->overlap = true; + outputDes->overlap = true; // should use 1 thread for cpu backend } return true; } diff --git a/test/main.cpp b/test/main.cpp index 37d9d432fb..ce303fda4a 100644 --- 
a/test/main.cpp +++ b/test/main.cpp @@ -71,6 +71,11 @@ int main(int argc, char* argv[]) { enableKleidiAI = atoi(argv[8]) > 0 ? true : false; FUNC_PRINT(enableKleidiAI); } + int divisionRatio = 1; + if (argc > 9) { + divisionRatio = atoi(argv[9]); + FUNC_PRINT(divisionRatio); + } auto exe = MNN::Express::Executor::newExecutor(type, config, thread); if (exe == nullptr) { MNN_ERROR("Can't create executor with type:%d, exit!\n", type); @@ -82,6 +87,7 @@ int main(int argc, char* argv[]) { MNN::RuntimeHint hint; hint.dynamicQuantOption = dynamicOption; hint.enableKleidiAI = enableKleidiAI; + hint.divisionRatio = divisionRatio; scope.Current()->getRuntime().second->setRuntimeHint(hint); MNNTestSuite::get()->pStaus.memory = memory; MNNTestSuite::get()->pStaus.precision = precision; diff --git a/test/op/AttentionTest.cpp b/test/op/AttentionTest.cpp index d12491f72c..ad9ab32f14 100644 --- a/test/op/AttentionTest.cpp +++ b/test/op/AttentionTest.cpp @@ -24,15 +24,24 @@ int HeadDim = 128; const float diff_threshold = 0.001; const float diff_percent_threshold = 0.1; const int pastLength = 101; - -#define LOOP 30 +#define GENERATE_TOKENS 128 struct KVMeta { + enum { + NoChange, + PendingWrite, + PendingRead + } file_operation; size_t block = 4096; size_t previous = 0; size_t remove = 0; int* reserve = nullptr; int n_reserve = 0; size_t add = 0; + std::string file_name = ""; + int file_flag = NoChange; + int seqlen_in_disk = 0; + int layer_index = 0; + int layer_nums = 0; std::vector reserveHost; void sync() { int revertNumber = 0; @@ -48,7 +57,7 @@ struct KVMeta { }; static KVMeta gMeta; -static std::shared_ptr _makeAttentionModule() { +static std::shared_ptr _makeAttentionModule(int quant_qkv = 8) { auto Q = _Input(); auto K = _Input(); auto V = _Input(); @@ -70,6 +79,7 @@ static std::shared_ptr _makeAttentionModule() { config.backendConfig = &bnConfig; std::shared_ptr rtmgr(Executor::RuntimeManager::createRuntimeManager(config)); 
rtmgr->setHintPtr(MNN::Interpreter::KVCACHE_INFO, &gMeta); + rtmgr->setHint(MNN::Interpreter::QKV_QUANT_OPTIONS, quant_qkv); std::shared_ptr m(Module::load({}, {}, (uint8_t*)buffer.data(), buffer.size(), rtmgr)); return m; } @@ -104,7 +114,7 @@ static VARP _computeAttentionExpr(VARP Q, VARP K, VARP V, VARP mask, KVCache cac if (mask->getInfo()->type.code == halide_type_int) { mask = (_Scalar(1.0) - _Cast(mask)) * _Scalar(std::numeric_limits::lowest()); } - + Q = _Reshape(Q, {batch, seqLength, kvNumHead,group, headDim}); Q = _Transpose(Q, {0, 2, 3, 1, 4}); K = _Reshape(K, {batch, seqLength, kvNumHead, 1, headDim}); @@ -151,7 +161,7 @@ static std::vector< std::vector< std::vector > > generateRandTensor(int C if (precision == 2) { a[i][j][k] = ((i + j + k) % 10) * 0.002; } else { - a[i][j][k] = (float)rand() / (float)RAND_MAX * 10.0 * (rand() % 2 ? 1 : -1); + a[i][j][k] = ((i + j + k) % 10) * 0.16 - 5.6; } } } @@ -190,7 +200,7 @@ VARP vector_to_var(std::vector< std::vector > & a) { return var; } -static std::vector< std::vector< std::vector > > +static std::vector< std::vector< std::vector > > computeAttention ( std::vector< std::vector< std::vector > > & query, std::vector< std::vector< std::vector > > & key, @@ -225,21 +235,13 @@ computeAttention ( auto diff = kv_seq_len - seq_len; for (int i = 0; i < seq_len; i++) { for (int j = 0; j < seq_len; j++) { - if (mask[i][j] == 1) { - qk[i][j+diff] *= scale; - } else { - qk[i][j+diff] = std::numeric_limits::lowest(); - } + qk[i][j+diff] = qk[i][j+diff] * scale + (1.f - mask[i][j]) * std::numeric_limits::lowest(); } } } else { for (int i = 0; i < seq_len; i++) { for (int j = 0; j < kv_seq_len; j++) { - if (mask[i][j] == 1) { - qk[i][j] *= scale; - } else { - qk[i][j] = std::numeric_limits::lowest(); - } + qk[i][j] = qk[i][j] * scale + (1.f - mask[i][j]) * std::numeric_limits::lowest(); } } } @@ -291,7 +293,7 @@ class NaiveAttention { std::vector< std::vector< std::vector > > onExecute ( std::vector< std::vector< 
std::vector > > & query, std::vector< std::vector< std::vector > > & key, - std::vector< std::vector< std::vector > > & value, + std::vector< std::vector< std::vector > > & value, std::vector< std::vector > & mask, int seq_len ) { @@ -305,27 +307,34 @@ class NaiveAttention { }; class AttentionTest : public MNNTestCase { - protected: - std::vector< std::vector< std::vector > > query; - std::vector< std::vector< std::vector > > key; - std::vector< std::vector< std::vector > > value; - std::vector< std::vector > mask; - std::vector< std::vector< std::vector > > expected_result; - VARP Query, Key, Value, Mask, Output; +protected: + std::vector< std::vector< std::vector > > query; + std::vector< std::vector< std::vector > > key; + std::vector< std::vector< std::vector > > value; + std::vector< std::vector > mask; + std::vector< std::vector< std::vector > > expected_result; + VARP Query, Key, Value, Mask, Output; + VARP Query1, Key1, Value1, Mask1; public: AttentionTest() = default; virtual ~AttentionTest() = default; - - void generateInput(int seq_len, int precision) { + void generateInput(int seq_len, int precision, bool genDecodeInput = false) { query = generateRandTensor(seq_len, NumHead, HeadDim, precision); key = generateRandTensor(seq_len, KvNumHead, HeadDim, precision); value = generateRandTensor(seq_len, KvNumHead, HeadDim, precision); Query = vector_to_var(query); Key = vector_to_var(key); Value = vector_to_var(value); + if (genDecodeInput) { + auto vecquery = generateRandTensor(1, NumHead, HeadDim, precision); + auto veckey = generateRandTensor(1, KvNumHead, HeadDim, precision); + auto vecvalue = generateRandTensor(1, KvNumHead, HeadDim, precision); + Query1 = vector_to_var(vecquery); + Key1 = vector_to_var(veckey); + Value1 = vector_to_var(vecvalue); + } } - - void generateMask(int seq_len, int kv_seq_len) { + void generateMask(int seq_len, int kv_seq_len, bool genDecodeInput = false) { mask.resize(seq_len); for (int i = 0; i < seq_len; i++) { 
mask[i].resize(kv_seq_len); @@ -339,7 +348,16 @@ class AttentionTest : public MNNTestCase { } Mask = vector_to_var(mask); Mask = (_Scalar(1.0) - _Cast(Mask)) * _Scalar(std::numeric_limits::lowest()); - + if (genDecodeInput) { + std::vector> vecmask; + vecmask.resize(1); + vecmask[0].resize(gMeta.previous + 1); + for (int i = 0; i < gMeta.previous + 1; ++i) { + vecmask[0][i] = 1; + } + Mask1 = vector_to_var(vecmask); + Mask1 = (_Scalar(1.0) - _Cast(Mask1)) * _Scalar(std::numeric_limits::lowest()); + } } bool compareResult(int seq_len) { @@ -360,7 +378,7 @@ class AttentionTest : public MNNTestCase { Output->unMap(); return true; } - + virtual bool run(int precision) { srand(2024); // unit test 1 @@ -437,7 +455,7 @@ class AttentionTest : public MNNTestCase { } }; -class SpeedAttentionTest : public MNNTestCase { +class SpeedAttentionTest : public AttentionTest { protected: std::vector< std::vector< std::vector > > query; std::vector< std::vector< std::vector > > key; @@ -448,51 +466,42 @@ class SpeedAttentionTest : public MNNTestCase { public: SpeedAttentionTest() = default; virtual ~SpeedAttentionTest() = default; - + virtual bool run(int precision) { - srand(2024); - int seq_len[] = {200, 400, 800, 1000, 2000}; - // unit test 1 - for (int n = 0; n < 5; ++n) { - auto rt = ExecutorScope::Current()->getRuntime(); - MNN::KVMeta meta; - for (auto& iter : rt.first) { - iter.second->pMeta = &meta; - } - std::shared_ptr naiveAttention(new NaiveAttention); - std::shared_ptr attention(new MNN::OpT); - attention->type = MNN::OpType_Attention; - attention->main.type = MNN::OpParameter_AttentionParam; - attention->main.value = new MNN::AttentionParamT; - attention->main.AsAttentionParam()->kv_cache = true; - meta.add = seq_len[n]; - VARP Query = _Input({1, seq_len[n], NumHead, HeadDim}, NCHW, halide_type_of()); - VARP Key = _Input({1, seq_len[n], KvNumHead, HeadDim}, NCHW, halide_type_of()); - VARP Value = _Input({1, seq_len[n], KvNumHead, HeadDim}, NCHW, halide_type_of()); - 
VARP Mask = _Input({1, 1, seq_len[n], seq_len[n]}, NCHW, halide_type_of()); - auto Output = Variable::create(Expr::create(attention.get(), {Query, Key, Value, Mask})); - { - Query.fix(VARP::INPUT); - Key.fix(VARP::INPUT); - Value.fix(VARP::INPUT); - Mask.fix(VARP::INPUT); - { - Query->writeMap(); - Key->writeMap(); - Value->writeMap(); - Mask->writeMap(); - Output->readMap(); + std::vector seqs = {4096}; + std::shared_ptr naiveAttention(new NaiveAttention); + std::shared_ptr attention(new MNN::OpT); + attention->type = MNN::OpType_Attention; + attention->main.type = MNN::OpParameter_AttentionParam; + attention->main.value = new MNN::AttentionParamT; + attention->main.AsAttentionParam()->kv_cache = true; + /* 3 attention module */ + std::vector quantQKV = {8, 9, 10}; + std::vector testNames = {"float qkv", "quant qk", "quant qkv"}; + for (int n = 0; n < seqs.size(); ++n) { + int seq_len = seqs[n]; + MNN_PRINT(">>> seq_len=%d, decode_len=%d\n", seq_len, GENERATE_TOKENS); + generateInput(seqs[n], precision, true); + generateMask(seqs[n], seq_len, true); + for (int m = 0; m < testNames.size(); ++m) { + gMeta.previous = 0; + gMeta.add = seq_len; + auto _module = _makeAttentionModule(quantQKV[m]); + MNN::Timer t1; + for (int x = 0; x < 5; ++x) { + Output = _module->onForward({Query, Key, Value, Mask})[0]; } - MNN::Timer _t; - for (int i = 0; i < LOOP; ++i) { - Query->writeMap(); - Key->writeMap(); - Value->writeMap(); - Mask->writeMap(); - Output->readMap(); + auto time = (float)t1.durationInUs() / 1000.0f / 5.f; + MNN_PRINT("%s: prefill cost = %.2f\n", testNames[m].c_str(), time); + gMeta.sync(); + MNN::Timer t2; + for (int x = 0; x < GENERATE_TOKENS; ++x) { + gMeta.add = 1; + auto output2 = _module->onForward({Query1, Key1, Value1, Mask1})[0]; + gMeta.sync(); } - auto time = (float)_t.durationInUs() / 1000.0f; - MNN_PRINT("seq_len = %d, avg time = %f\n", seq_len[n], time / LOOP); + time = (float)t2.durationInUs() / 1000.0f; + MNN_PRINT("%s: decode cost = %f\n", 
testNames[m].c_str(), time); } } return true; diff --git a/test/op/StridedSliceTest.cpp b/test/op/StridedSliceTest.cpp index 7e95aa8460..92a5ac8e20 100644 --- a/test/op/StridedSliceTest.cpp +++ b/test/op/StridedSliceTest.cpp @@ -241,6 +241,184 @@ MNNTestSuiteRegister(SplitC4Test, "op/splitc4"); class StrideSliceWriteTest: public MNNTestCase { virtual bool run(int precision) { + + // Test Case: 1D Input + { + // 1. Input data + auto input = _Input({20}, NCHW); + auto begin = _Input({1}, NCHW); + auto end = _Input({1}, NCHW); + auto strided = _Input({1}, NCHW); + auto write = _Input({5}, NCHW); + auto size = 20; + const float inputData[] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 + }; + memcpy(input->writeMap(), inputData, size * sizeof(float)); + + // 2. (Slicing Parameters) + // Slice on Dim 0: from index 2 to 15 with stride 3 + // This will select indices: 2, 5, 8, 11, 14 + const int beginData[] = {2}; + memcpy(begin->writeMap(), beginData, 1 * sizeof(int)); + const int endData[] = {15}; + memcpy(end->writeMap(), endData, 1 * sizeof(int)); + const int strideData[] = {3}; + memcpy(strided->writeMap(), strideData, 1 * sizeof(int)); + + // 3. Write Tensor + // write element size = 5 ( (15-2)/3 rounded up ) + const float writeData[] = {99, 99, 99, 99, 99}; + memcpy(write->writeMap(), writeData, 5 * sizeof(float)); + + auto output = _StridedSliceWrite(input, begin, end, strided, write, 0, 0, 0, 0, 0); + + // 4. Expected Result + const std::vector expectedShape = {20}; + const std::vector expectedOutput = { + 0, 1, 99, 3, 4, 99, 6, 7, 99, 9, 10, 99, 12, 13, 99, 15, 16, 17, 18, 19 + }; + // Indices 2, 5, 8, 11, 14 have been replaced by 99. + + // 5. 
validate + if (!checkVector(output->getInfo()->dim.data(), expectedShape.data(), expectedShape.size(), 0)) { + MNN_PRINT("StrideSliceWrite shape test0 error\n"); + return false; + } + if (!checkVector(output->readMap(), expectedOutput.data(), expectedOutput.size(), 0.01)) { + MNN_PRINT("StrideSliceWrite test0 result error\n"); + return false; + } + } + + // Test Case: 2D Input + { + // 1. input data + auto input = _Input({6, 8}, NCHW); + auto begin = _Input({2}, NCHW); + auto end = _Input({2}, NCHW); + auto strided = _Input({2}, NCHW); + auto write = _Input({12}, NCHW); + auto size = 48; + const float inputData[] = { + 0, 0, 0, 0, 0, 0, 0, 0, + 10, 10, 10, 10, 10, 10, 10, 10, + 20, 20, 20, 20, 20, 20, 20, 20, + 30, 30, 30, 30, 30, 30, 30, 30, + 40, 40, 40, 40, 40, 40, 40, 40, + 50, 50, 50, 50, 50, 50, 50, 50 + }; + memcpy(input->writeMap(), inputData, size * sizeof(float)); + + // 2. Slicing Parameters + // Slice on Dim 0 (rows): from index 1 to 6 with stride 2 -> selects rows 1, 3, 5 + // Slice on Dim 1 (cols): from index 0 to 8 with stride 2 -> selects cols 0, 2, 4, 6 + const int beginData[] = {1, 0}; + memcpy(begin->writeMap(), beginData, 2 * sizeof(int)); + const int endData[] = {6, 8}; + memcpy(end->writeMap(), endData, 2 * sizeof(int)); + const int strideData[] = {2, 2}; + memcpy(strided->writeMap(), strideData, 2 * sizeof(int)); + + // 3. Write Tensor + // write element size = 3 (rows) * 4 (cols) = 12 + const float writeData[] = { + 77, 77, 77, 77, + 77, 77, 77, 77, + 77, 77, 77, 77 + }; + memcpy(write->writeMap(), writeData, 12 * sizeof(float)); + + auto output = _StridedSliceWrite(input, begin, end, strided, write, 0, 0, 0, 0, 0); + + // 4. 
Expected Result + const std::vector expectedShape = {6, 8}; + const std::vector expectedOutput = { + 0, 0, 0, 0, 0, 0, 0, 0, // row 0: unchanged + 77, 10, 77, 10, 77, 10, 77, 10, // row 1: selected cols replaced + 20, 20, 20, 20, 20, 20, 20, 20, // row 2: unchanged + 77, 30, 77, 30, 77, 30, 77, 30, // row 3: selected cols replaced + 40, 40, 40, 40, 40, 40, 40, 40, // row 4: unchanged + 77, 50, 77, 50, 77, 50, 77, 50 // row 5: selected cols replaced + }; + + // 5. validate + if (!checkVector(output->getInfo()->dim.data(), expectedShape.data(), expectedShape.size(), 0)) { + MNN_PRINT("StrideSliceWrite shape test0 error\n"); + return false; + } + if (!checkVector(output->readMap(), expectedOutput.data(), expectedOutput.size(), 0.01)) { + MNN_PRINT("StrideSliceWrite test0 result error\n"); + return false; + } + } + + // Test Case: 3D Input + { + auto input = _Input({4, 5, 6}, NCHW); + auto begin = _Input({3}, NCHW); + auto end = _Input({3}, NCHW); + auto strided = _Input({3}, NCHW); + auto write = _Input({20}, NCHW); + auto size = 120; + + // 1. Input data + const float inputData[] = { + // Plane 0 + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + // Plane 1 + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + // Plane 2 + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + // Plane 3 + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 + }; + memcpy(input->writeMap(), inputData, size * sizeof(float)); + + // 2. 
Slicing Parameters + // Slice on Dim 0: from index 1 to 4 with stride 2 -> selects planes 1, 3 + // Slice on Dim 1: from index 0 to 5 with stride 1 -> selects all 5 rows + // Slice on Dim 2: from index 2 to 6 with stride 3 -> selects columns 2, 5 + const int beginData[] = {1, 0, 2}; + memcpy(begin->writeMap(), beginData, 3 * sizeof(int)); + const int endData[] = {4, 5, 6}; + memcpy(end->writeMap(), endData, 3 * sizeof(int)); + const int strideData[] = {2, 1, 3}; + memcpy(strided->writeMap(), strideData, 3 * sizeof(int)); + + // 3. Write Tensor + // 2 (dim0) * 5 (dim1) * 2 (dim2) = 20 + const float writeData[] = { + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8 + }; + memcpy(write->writeMap(), writeData, 20 * sizeof(float)); + + auto output = _StridedSliceWrite(input, begin, end, strided, write, 0, 0, 0, 0, 0); + + // 4. Expected Result + const std::vector expectedShape = {4, 5, 6}; + const std::vector expectedOutput = { + // Plane 0 - remain the same + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + // Plane 1 - write the new element + 2, 2, 8, 2, 2, 8, 2, 2, 8, 2, 2, 8, 2, 2, 8, 2, 2, 8, 2, 2, 8, 2, 2, 8, 2, 2, 8, 2, 2, 8, + // Plane 2 - remain the same + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + // Plane 3 - write the new element + 4, 4, 8, 4, 4, 8, 4, 4, 8, 4, 4, 8, 4, 4, 8, 4, 4, 8, 4, 4, 8, 4, 4, 8, 4, 4, 8, 4, 4, 8 + }; + + // 5. 
validate + if (!checkVector(output->getInfo()->dim.data(), expectedShape.data(), expectedShape.size(), 0)) { + MNN_PRINT("StrideSliceWrite shape test0 error\n"); + return false; + } + if (!checkVector(output->readMap(), expectedOutput.data(), expectedOutput.size(), 0.01)) { + MNN_PRINT("StrideSliceWrite test0 result error\n"); + return false; + } + } { auto input = _Input({2, 3, 2, 12}, NCHW); auto begin = _Input({4}, NCHW); diff --git a/test/speed/HybridConvSpeedTest.cpp b/test/speed/HybridConvSpeedTest.cpp index 982f2e92ad..728a0d2b45 100644 --- a/test/speed/HybridConvSpeedTest.cpp +++ b/test/speed/HybridConvSpeedTest.cpp @@ -319,7 +319,7 @@ class ConvInt8BlockQuantTest : public HybridConvSpeedTestCommon { virtual bool run(int precision) { INTS strides = {1, 1}, dilate = {1, 1}, pad = {0, 0}, inputShape = {1, 17}; // {w, h} int batch[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}; - std::vector blocks = {0, 64, 32}; + std::vector blocks = {0, 32, 64}; std::vector> channels = {{320, 320}, {640, 200}, {128, 79}}; std::vector kernels = {1, 3}; @@ -350,7 +350,7 @@ class HybridConvInt8Test : public HybridConvSpeedTestCommon { INTS strides = {1, 1}, dilate = {1, 1}, pad = {0, 0}; // {w, h} int batch[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 21, 22, 23, 25, 26, 27, 28, 29, 30}; std::vector blocks = {0, 32, 128}; - std::vector> channels = {{3, 7}, {4, 18}, {5, 22}, {12, 16}, {8, 8}, {8, 9}, {8, 16}, {7, 20}, {9, 24}, {2048, 54}, {1, 10}, {20, 153}, {9, 18}, {64, 28}, {1496, 11}, {10, 9}}; + std::vector> channels = {{128, 2048}, {3, 7}, {4, 18}, {5, 22}, {12, 16}, {8, 8}, {8, 9}, {8, 16}, {7, 20}, {9, 24}, {2048, 54}, {1, 10}, {20, 153}, {9, 18}, {64, 28}, {1496, 11}, {10, 9}}; std::vector> inputShapes = {{1, 1}}; std::vector> kernels = {{1, 1}}; std::vector weightBits = {4, 8}; @@ -454,7 +454,49 @@ class PTQInt4Test: public PtqTestCommon { MNNTestSuiteRegister(PTQInt4Test, "op/int4Ptq"); #endif +class ConvInt8MixedKernelTest : public HybridConvSpeedTestCommon 
{ +public: + virtual bool run(int precision) { + INTS strides = {1, 1}, dilate = {1, 1}, pad = {0, 0}; // {w, h} + int batch[] = {1, 100}; + std::vector blocks = {0, 32, 128}; + std::vector> channels = {{1536, 1536}, {1536, 256}, {1536, 8960}, {8960, 1536}, {1536, 151936}, {896, 896}, {896, 128}, {4864, 896}, {896, 151936}, {200, 138}, {92, 92}, {126, 126}, {120, 1300}}; + for (int i = 0; i < 32; ++i) { // To test that every storage branch of 'Hp=128' is correct. + std::vector channel = {256, 4 * (i + 1)}; + channels.emplace_back(channel); + } + std::vector> inputShapes = {{1, 1}}; + std::vector> kernels = {{1, 1}}; + std::vector weightBits = {4, 8}; + int batchNum = sizeof(batch) / sizeof(int); + bool correct = true; + for (auto kernel: kernels) { + for (auto inputShape: inputShapes) { + for (auto block : blocks) { + for (auto& bits : weightBits) { + for (auto &channel: channels) { + if (dilate[0] > inputShape[0] || dilate[0] * (kernel[0] - 1) + 1 > inputShape[0] || dilate[0] * (kernel[1] - 1) + 1 > inputShape[1]) + continue; + if (block > 0 && channel[0] % block != 0) + continue; + for (int n = 0; n < batchNum; ++n) { + auto res = testKernel("Low memory mixed kernel test:", inputShape, kernel, channel, pad, strides, dilate, batch[n], bits, precision, false, block); + if (!res) { + MNN_ERROR("Error: low memory mixed kernel when bits=%d, n=%d, ic=%d, oc=%d, block=%d\n", bits, batch[n], channel[0], channel[1], block); + return false; + } + } + } + } + } + } + } + return true; + } +}; + MNNTestSuiteRegister(DenseConvInt8Test, "op/lowMemory/DenseConv"); MNNTestSuiteRegister(HybridConvInt8Test, "op/lowMemory/HybridConv"); MNNTestSuiteRegister(HybridConvSpeedInt8Test, "speed/HybridConv"); MNNTestSuiteRegister(ConvInt8BlockQuantTest, "op/lowMemory/blockConv"); +MNNTestSuiteRegister(ConvInt8MixedKernelTest, "op/lowMemory/mixedKernel"); diff --git a/tools/converter/source/common/cli.cpp b/tools/converter/source/common/cli.cpp index b7bb4f2d5a..a02ab2f808 100644 --- 
a/tools/converter/source/common/cli.cpp +++ b/tools/converter/source/common/cli.cpp @@ -470,8 +470,15 @@ bool Cli::initializeMNNConvertArgs(modelConfig &modelPath, int argc, char **argv if (result.count("fp16")) { modelPath.saveHalfFloat = true; } + if (result.count("weightQuantAsymmetric")) { + modelPath.weightQuantAsymmetric = result["weightQuantAsymmetric"].as(); + } if (result.count("hqq")) { - modelPath.useHQQ = true; + if(modelPath.weightQuantAsymmetric) { + modelPath.useHQQ = true; + } else { + std::cout << "Warning, MNN Convert only support Hqq with weight asymmetric quant! Disable Hqq currently" << std::endl; + } } if (result.count("forTraining")) { modelPath.forTraining = true; @@ -488,9 +495,6 @@ bool Cli::initializeMNNConvertArgs(modelConfig &modelPath, int argc, char **argv MNN_PRINT("Use HQQ to quant weight\n"); } } - if (result.count("weightQuantAsymmetric")) { - modelPath.weightQuantAsymmetric = result["weightQuantAsymmetric"].as(); - } if (result.count("weightQuantBlock")) { modelPath.weightQuantBlock = result["weightQuantBlock"].as(); } diff --git a/tools/cpp/CMakeLists.txt b/tools/cpp/CMakeLists.txt index ed89c1e177..9281ad52ab 100644 --- a/tools/cpp/CMakeLists.txt +++ b/tools/cpp/CMakeLists.txt @@ -44,10 +44,8 @@ list(APPEND MNN_CPP_TOOLS backendTest.out) add_executable(testModel.out ${CMAKE_CURRENT_LIST_DIR}/testModel.cpp) list(APPEND MNN_CPP_TOOLS testModel.out) -if (NOT WIN32) -add_executable(compilefornpu ${CMAKE_CURRENT_LIST_DIR}/compilefornpu.cpp) +add_executable(compilefornpu ${CMAKE_CURRENT_LIST_DIR}/compilefornpu.cpp ${CMAKE_CURRENT_LIST_DIR}/../../3rd_party/flatbuffers/src/util.cpp) list(APPEND MNN_CPP_TOOLS compilefornpu) -endif() if (MNN_QNN) diff --git a/tools/cpp/MNN2QNNModel.cpp b/tools/cpp/MNN2QNNModel.cpp index 435e98f6d0..216d8b95c0 100644 --- a/tools/cpp/MNN2QNNModel.cpp +++ b/tools/cpp/MNN2QNNModel.cpp @@ -255,7 +255,7 @@ int main(int argc, const char* argv[]) { } MNN::ScheduleConfig config; - config.type = MNN_FORWARD_NN; + 
config.type = MNN_CONVERT_QNN; std::shared_ptr rtmgr(Executor::RuntimeManager::createRuntimeManager(config)); rtmgr->setCache(curQnnModelDir.c_str()); MNN::Express::Module::Config mConfig; diff --git a/tools/cpp/ModuleBasic.cpp b/tools/cpp/ModuleBasic.cpp index 75d4c066cf..90fa6b80d3 100644 --- a/tools/cpp/ModuleBasic.cpp +++ b/tools/cpp/ModuleBasic.cpp @@ -34,7 +34,7 @@ static bool compareOutput(VARP output, const std::string& directName, const std: } if (nullptr == info || nullptr == ptr) { - MNN_ERROR("TESTERROR name:%s, info:%p, ptr:%p. size:%d\n", name.c_str(), info, ptr, info->size); + MNN_ERROR("TESTERROR name:%s, info:%p, ptr:%p. size:%zu\n", name.c_str(), info, ptr, info->size); return false; } @@ -93,6 +93,9 @@ static bool compareOutput(VARP output, const std::string& directName, const std: static inline std::vector parseIntList(const std::string& str, char delim) { std::vector result; + if (str.empty()) { + return result; + } std::ptrdiff_t p1 = 0, p2; while (1) { p2 = str.find(delim, p1); @@ -135,7 +138,7 @@ int main(int argc, char *argv[]) { _initTensorStatic(); } } - int repeatNumber = 1; + int repeatNumber = 2; bool shapeMutable = true; std::vector inputs; std::vector outputs; @@ -251,12 +254,17 @@ int main(int argc, char *argv[]) { if (argc > 10) { enableKleidiAI = atoi(argv[10]) > 0 ? 
true : false; } + int mixedRatio = 17; + if (argc > 11) { + mixedRatio = atoi(argv[11]); + } MNN_PRINT("\n"); FUNC_PRINT(precision); FUNC_PRINT(memory); FUNC_PRINT(power); FUNC_PRINT_ALL(cacheFileName, s); FUNC_PRINT(enableKleidiAI); + FUNC_PRINT(mixedRatio); // create session MNN::ScheduleConfig config; config.type = type; @@ -334,6 +342,8 @@ int main(int argc, char *argv[]) { if (runMask & 2048) { rtmgr->setExternalPath("tmp", Interpreter::EXTERNAL_FEATUREMAP_DIR); } + + rtmgr->setHint(Interpreter::CPU_SME2_NEON_DIVISION_RATIO, mixedRatio); // set npu model dir, npu model and mnn model in same path size_t pos = modelName.find_last_of("/\\"); std::string modelPath; diff --git a/tools/cpp/compilefornpu.cpp b/tools/cpp/compilefornpu.cpp index f600c05c37..f2a03e7b36 100644 --- a/tools/cpp/compilefornpu.cpp +++ b/tools/cpp/compilefornpu.cpp @@ -13,7 +13,6 @@ #include "shape/SizeComputer.hpp" #include "core/OpCommonUtils.hpp" #include "core/Schedule.hpp" -#include "MNN_generated.h" #include "rapidjson/document.h" #include diff --git a/tools/script/arm2binary.py b/tools/script/arm2binary.py index 578a687262..e90b7e473d 100644 --- a/tools/script/arm2binary.py +++ b/tools/script/arm2binary.py @@ -61,7 +61,7 @@ def should_be_converted(instruction_line): # 规则2:如果包含 p, pn, z, 或 za 寄存器,则必须转换 # 正则表达式已更新以包含 pn<数字> - if re.search(r'\b(p\d+|pn\d+|z\d+|za)', instruction_line): + if re.search(r'\b(p\d+|pn\d+|z\d+|za+|zt0)', instruction_line): return True # 如果以上条件都不满足,则不转换 @@ -100,11 +100,79 @@ def generate_equivalent_instructions(canonical_line): items_to_process.append(new_instr) return list(equivalents) +def expand_register_range(instruction): + """ + 查找并展开指令中的寄存器范围,如 {z26.s-z27.s} -> {z26.s,z27.s}。 + {z0.s-z3.s} 会被展开为 {z0.s,z1.s,z2.s,z3.s}。 + """ + # 正则表达式模式: + # \{ \s* - 匹配 '{' 和可选空格 + # ([a-zA-Z])(\d+) - 捕获组1(前缀), 捕获组2(起始编号) + # (\.\w+) - 捕获组3(后缀) + # \s*-\s* - 匹配 '-' 和可选空格 + # \1(\d+)\3 - 匹配相同的组1(前缀), 捕获组4(结束编号), 相同的组3(后缀) + # \s* \} - 匹配可选空格和 '}' + pattern = 
re.compile(r'\{\s*([a-zA-Z])(\d+)(\.\w+)\s*-\s*\1(\d+)\3\s*\}') + + # 定义一个替换函数,用于生成展开后的列表 + def replacer(match): + prefix, start_num_str, suffix, end_num_str = match.groups() + start_num, end_num = int(start_num_str), int(end_num_str) + + # 确保范围是有效的 + if start_num >= end_num: + return match.group(0) # 如果范围无效,则不替换 + + # 使用列表推导生成所有寄存器名 + regs = [f"{prefix}{i}{suffix}" for i in range(start_num, end_num + 1)] + + # 将列表连接成一个无空格的字符串,并用花括号包裹 + # 输出如: {z26.s,z27.s} + return f"{{{','.join(regs)}}}" + + # 使用 re.sub 和我们的替换函数来执行替换 + return pattern.sub(replacer, instruction) + +def normalize_instruction(instruction): + """ + 对汇编指令进行语义规范化。 + - 步骤 0: 移除分号后的注释。 + - 步骤 1: 展开寄存器范围 (例如, z26.s-z27.s)。 + - 步骤 2: 规范化数字,统一转为十进制格式。 + - 步骤 3: 仅保留助记符后的第一个空格,移除所有其他空格。 + """ + # 步骤 0: 移除注释 + instruction = instruction.split(';')[0].strip() + + # 步骤 1: 展开寄存器范围 + instruction = expand_register_range(instruction) + + # 步骤 2: 规范化数字 + match = re.search(r'#\s*(0x[0-9a-fA-F]+|[0-9]+)', instruction) + if match: + number_str = match.group(1) + try: + decimal_value = int(number_str, 0) + instruction = instruction.replace(match.group(0), f'#{decimal_value}') + except ValueError: + pass + + # 步骤 3: 规范化空格 + parts = instruction.split(' ', 1) + if len(parts) == 2: + mnemonic = parts[0] + operands = parts[1] + operands_no_space = operands.replace(' ', '') + return f"{mnemonic} {operands_no_space}" + else: + return instruction + def find_best_match(source_line, instruction_map): - matcher = difflib.SequenceMatcher(None, source_line) + matcher = difflib.SequenceMatcher(None, normalize_instruction(source_line)) best_match_key, highest_score = None, 0.0 for key in instruction_map.keys(): - matcher.set_seq2(key) + keyNormalized = normalize_instruction(key) + matcher.set_seq2(keyNormalized) score = matcher.ratio() if score > highest_score: highest_score, best_match_key = score, key return best_match_key, highest_score @@ -166,19 +234,19 @@ def process_assembly_file(s_file_path, instruction_map, output_file_path): # 
print(f" -> 全局最相似的匹配是 '{best_match_key}' (相似度: {score:.2%})") # 报告2: 所有助记符相同的匹配 - mnemonic_matches = find_mnemonic_matches(canonical_content, instruction_map) - if mnemonic_matches: - source_mnemonic = canonical_content.split()[0] - print(f" -> 在 Objdump 中找到以下助记符为 '{source_mnemonic}' 的指令:") - for m_match in mnemonic_matches: - print(f" - '{m_match}'") + # mnemonic_matches = find_mnemonic_matches(canonical_content, instruction_map) + # if mnemonic_matches: + # source_mnemonic = canonical_content.split()[0] + # print(f" -> 在 Objdump 中找到以下助记符为 '{source_mnemonic}' 的指令:") + # for m_match in mnemonic_matches: + # print(f" - '{m_match}'") if score > SIMILARITY_THRESHOLD: print(f"警告 (行 {line_num}): '{content_with_comment}' 与 '{best_match_key}' 的相似度为 {score:.2%},这里同样进行替换。请检查是否正确。") new_line = f"{indentation}.inst 0x{instruction_map[best_match_key]} // {content_with_comment}\n" f_out.write(new_line) else: - print(f"错误 (行 {line_num}): 无法为 '{content_with_comment}' 找到任何直接或等价的匹配项,请检查指令或手动添加支持。") + print(f"错误 (行 {line_num}): 无法为 '{content_with_comment}' 找到任何直接或等价的匹配项,最相似匹配是 {best_match_key}, 相似分是{score} 请检查指令或手动添加支持。") f_out.write(line) # 保持原样 def main(): diff --git a/transformers/llm/config.json b/transformers/llm/config.json index d508b467d9..e30ec50821 100755 --- a/transformers/llm/config.json +++ b/transformers/llm/config.json @@ -10,8 +10,7 @@ "use_mmap":"false", "is_batch_quant": 1, - + "reuse_kv": false, - "quant_kv": 0, - "kvcache_limit": -1 + "quant_kv": 0 } diff --git a/transformers/llm/engine/CMakeLists.txt b/transformers/llm/engine/CMakeLists.txt index fb7b2e5b93..40534ea06b 100644 --- a/transformers/llm/engine/CMakeLists.txt +++ b/transformers/llm/engine/CMakeLists.txt @@ -1,5 +1,7 @@ option(BUILD_MLS "Build PC Commandline." 
OFF) option(LLM_USE_MINJA "Use minja to apply template" ON) +option(MNN_LLM_BUILD_DEMO "Build LLM demo" ON) +option(LLM_SUPPORT_HTTP_RESOURCE "Support HTTP resource download" ON) set(LLM_DEPS ${MNN_DEPS}) if (MNN_BUILD_OPENCV) @@ -53,8 +55,11 @@ if (MNN_BUILD_AUDIO) add_definitions(-DLLM_SUPPORT_AUDIO) endif() -# Disable exceptions in httplib since MNN is compiled with -fno-exceptions -target_compile_definitions(llm PRIVATE CPPHTTPLIB_NO_EXCEPTIONS) +if (LLM_SUPPORT_HTTP_RESOURCE) + target_compile_definitions(llm PRIVATE LLM_SUPPORT_HTTP_RESOURCE) + # Disable exceptions in httplib since MNN is compiled with -fno-exceptions + target_compile_definitions(llm PRIVATE CPPHTTPLIB_NO_EXCEPTIONS) +endif() IF(CMAKE_SYSTEM_NAME MATCHES "^Android" AND NOT MNN_BUILD_FOR_ANDROID_COMMAND) IF(NOT NATIVE_INCLUDE_OUTPUT) @@ -71,15 +76,17 @@ INSTALL(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/include/ DESTINATION include FILES_M ENDIF() -add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/demo/llm_demo.cpp) -add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/demo/embedding_demo.cpp) -add_executable(reranker_demo ${CMAKE_CURRENT_LIST_DIR}/demo/reranker_demo.cpp) -add_executable(rollback_demo ${CMAKE_CURRENT_LIST_DIR}/demo/rollback_demo.cpp) -include(${CMAKE_CURRENT_LIST_DIR}/tools/CMakeLists.txt) -target_link_libraries(llm_demo ${LLM_DEPS}) -target_link_libraries(embedding_demo ${LLM_DEPS}) -target_link_libraries(reranker_demo ${LLM_DEPS}) -target_link_libraries(rollback_demo ${LLM_DEPS}) +if(MNN_LLM_BUILD_DEMO) + add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/demo/llm_demo.cpp) + add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/demo/embedding_demo.cpp) + add_executable(reranker_demo ${CMAKE_CURRENT_LIST_DIR}/demo/reranker_demo.cpp) + add_executable(rollback_demo ${CMAKE_CURRENT_LIST_DIR}/demo/rollback_demo.cpp) + include(${CMAKE_CURRENT_LIST_DIR}/tools/CMakeLists.txt) + target_link_libraries(llm_demo ${LLM_DEPS}) + target_link_libraries(embedding_demo ${LLM_DEPS}) + 
target_link_libraries(reranker_demo ${LLM_DEPS}) + target_link_libraries(rollback_demo ${LLM_DEPS}) +endif() if (BUILD_MLS) set(CMAKE_OSX_DEPLOYMENT_TARGET "13.0" CACHE STRING "Minimum macOS version" FORCE) diff --git a/transformers/llm/engine/demo/llm_demo.cpp b/transformers/llm/engine/demo/llm_demo.cpp index 78cce98a65..305ef2169b 100644 --- a/transformers/llm/engine/demo/llm_demo.cpp +++ b/transformers/llm/engine/demo/llm_demo.cpp @@ -105,6 +105,7 @@ static int benchmark(Llm* llm, const std::vector& prompts, int max_ if (prompt.substr(0, 1) == "#") { continue; } + if (max_token_number >= 0) { llm->response(prompt, &std::cout, nullptr, 0); while (!llm->stoped() && context->gen_seq_len < max_token_number) { diff --git a/transformers/llm/engine/demo/rollback_demo.cpp b/transformers/llm/engine/demo/rollback_demo.cpp index 65c3066307..5a483cd5c2 100644 --- a/transformers/llm/engine/demo/rollback_demo.cpp +++ b/transformers/llm/engine/demo/rollback_demo.cpp @@ -58,7 +58,7 @@ std::vector> parse_csv(const std::vector& return csv_data; } -static int benchmark(Llm* llm, const std::vector& prompts, int max_token_number) { +static int benchmark(Llm* llm, const std::vector& prompts, int max_token_number, bool is_prompt_cache) { if (prompts.size() < 3) { MNN_ERROR("Need larger than 3 inputs\n"); return 0; @@ -68,49 +68,109 @@ static int benchmark(Llm* llm, const std::vector& prompts, int max_ if (max_token_number <= 0) { max_token_number = 512; } - MNN_PRINT("Prefill\n"); - std::vector history; - for (int i = 0; i < 3; i++) { - const auto& prompt = prompts[i]; - llm->response(prompt, &std::cout, nullptr, 0); + + if(is_prompt_cache) { + MNN_PRINT("Prefix prompt cache demo\n"); + + auto prompt_base = prompts[0]; + auto prompt_add_0 = prompts[1]; + auto prompt_add_1 = prompts[2]; + std::vector history; + + // step 1: set prefix cache file name + llm->setPrefixCacheFile("model_prompt_config_mnnversion"); + // step 2: prefill prefix prompt + llm->response(prompt_base, &std::cout, 
nullptr, 0); + + + auto prompt_len = context->prompt_len; + auto decode_len = context->gen_seq_len; + auto prefill_time = context->prefill_us; + auto decode_time = context->decode_us; + auto sample_time = context->sample_us; + auto first_prefill_time = prefill_time; + // step 3: prompt_add_0 for response + llm->response(prompt_add_0); + + // step 4: erase first prompt_add_0 history history.emplace_back(llm->getCurrentHistory()); + llm->eraseHistory(prompt_len, history[0]); + + prompt_len += context->prompt_len; + decode_len += context->gen_seq_len; + prefill_time += context->prefill_us; + decode_time += context->decode_us; + sample_time += context->sample_us; + + // step 5: prompt_add_1 for response + llm->response(prompt_add_1); + + prompt_len += context->prompt_len; + decode_len += context->gen_seq_len; + prefill_time += context->prefill_us; + decode_time += context->decode_us; + sample_time += context->sample_us; + + float prefill_s = prefill_time / 1e6; + float decode_s = decode_time / 1e6; + float sample_s = sample_time / 1e6; + + MNN_PRINT("\n#################################\n"); + MNN_PRINT("prompt tokens num = %d\n", prompt_len); + MNN_PRINT("decode tokens num = %d\n", decode_len); + MNN_PRINT("first prefill time = %.2f s\n", (float)(first_prefill_time / 1e6)); + MNN_PRINT("prefill time = %.2f s\n", prefill_s); + MNN_PRINT(" decode time = %.2f s\n", decode_s); + MNN_PRINT(" sample time = %.2f s\n", sample_s); + MNN_PRINT("prefill speed = %.2f tok/s\n", prompt_len / prefill_s); + MNN_PRINT(" decode speed = %.2f tok/s\n", decode_len / decode_s); + MNN_PRINT("##################################\n"); + } else { + + MNN_PRINT("Prefill\n"); + std::vector history; + for (int i = 0; i < 3; i++) { + const auto& prompt = prompts[i]; + llm->response(prompt, &std::cout, nullptr, 0); + history.emplace_back(llm->getCurrentHistory()); + } + MNN_PRINT("\n"); + + MNN_PRINT("[LLM Test: Erase 1]\n"); + llm->eraseHistory(history[0], history[1]); + 
llm->response(prompts[prompts.size()-1], &std::cout, nullptr, 0); + while (!llm->stoped() && context->gen_seq_len < max_token_number) { + llm->generate(1); + } + MNN_PRINT("\n[LLM Test End]\n"); + + llm->eraseHistory(0, 0); + history.clear(); + for (int i = 0; i < 3; i++) { + const auto& prompt = prompts[i]; + llm->response(prompt, &std::cout, nullptr, 0); + history.emplace_back(llm->getCurrentHistory()); + } + MNN_PRINT("[LLM Test: Erase 2]\n"); + llm->eraseHistory(history[1], history[2]); + llm->response(prompts[prompts.size()-1], &std::cout, nullptr, 0); + while (!llm->stoped() && context->gen_seq_len < max_token_number) { + llm->generate(1); + } + MNN_PRINT("\n[LLM Test End]\n"); + MNN_PRINT("[LLM Test For Init]\n"); + llm->reset(); + llm->eraseHistory(0, 0); + llm->response(prompts[prompts.size()-1], &std::cout, nullptr, 0); + while (!llm->stoped() && context->gen_seq_len < max_token_number) { + llm->generate(1); + } + MNN_PRINT("\n[LLM Test End]\n"); } - MNN_PRINT("\n"); - - MNN_PRINT("[LLM Test: Erase 1]\n"); - llm->eraseHistory(history[0], history[1]); - llm->response(prompts[prompts.size()-1], &std::cout, nullptr, 0); - while (!llm->stoped() && context->gen_seq_len < max_token_number) { - llm->generate(1); - } - MNN_PRINT("\n[LLM Test End]\n"); - - llm->eraseHistory(0, 0); - history.clear(); - for (int i = 0; i < 3; i++) { - const auto& prompt = prompts[i]; - llm->response(prompt, &std::cout, nullptr, 0); - history.emplace_back(llm->getCurrentHistory()); - } - MNN_PRINT("[LLM Test: Erase 2]\n"); - llm->eraseHistory(history[1], history[2]); - llm->response(prompts[prompts.size()-1], &std::cout, nullptr, 0); - while (!llm->stoped() && context->gen_seq_len < max_token_number) { - llm->generate(1); - } - MNN_PRINT("\n[LLM Test End]\n"); - MNN_PRINT("[LLM Test For Init]\n"); - llm->reset(); - llm->eraseHistory(0, 0); - llm->response(prompts[prompts.size()-1], &std::cout, nullptr, 0); - while (!llm->stoped() && context->gen_seq_len < max_token_number) { - 
llm->generate(1); - } - MNN_PRINT("\n[LLM Test End]\n"); return 0; } -static int eval(Llm* llm, std::string prompt_file, int max_token_number) { +static int eval(Llm* llm, std::string prompt_file, int max_token_number, bool is_prompt_cache) { std::cout << "prompt file is " << prompt_file << std::endl; std::ifstream prompt_fs(prompt_file); std::vector prompts; @@ -125,12 +185,12 @@ static int eval(Llm* llm, std::string prompt_file, int max_token_number) { if (prompts.empty()) { return 1; } - return benchmark(llm, prompts, max_token_number); + return benchmark(llm, prompts, max_token_number, is_prompt_cache); } int main(int argc, const char* argv[]) { if (argc < 2) { - std::cout << "Usage: " << argv[0] << " config.json " << std::endl; + std::cout << "Usage: " << argv[0] << " config.json prompt.txt " << std::endl; return 0; } MNN::BackendConfig backendConfig; @@ -141,15 +201,26 @@ int main(int argc, const char* argv[]) { std::cout << "config path is " << config_path << std::endl; std::unique_ptr llm(Llm::createLLM(config_path)); llm->set_config("{\"tmp_path\":\"tmp\"}"); + llm->set_config("{\"prefix_cache_path\":\"prefixcache\"}"); { AUTOTIME; llm->load(); } - int max_token_number = -1; + std::string prompt_file = argv[2]; + + int enable_cache_prompt = 0; if (argc >= 4) { std::istringstream os(argv[3]); + os >> enable_cache_prompt; + if(enable_cache_prompt != 0 && enable_cache_prompt != 1) { + MNN_PRINT("[Warning]: cache_prefix_in_disk value only accept 0 or 1.\n"); + } + } + + int max_token_number = -1; + if (argc >= 5) { + std::istringstream os(argv[4]); os >> max_token_number; } - std::string prompt_file = argv[2]; - return eval(llm.get(), prompt_file, max_token_number); + return eval(llm.get(), prompt_file, max_token_number, enable_cache_prompt == 1); } diff --git a/transformers/llm/engine/include/llm/llm.hpp b/transformers/llm/engine/include/llm/llm.hpp index bf37d6a021..6ae61a5e35 100644 --- a/transformers/llm/engine/include/llm/llm.hpp +++ 
b/transformers/llm/engine/include/llm/llm.hpp @@ -112,6 +112,7 @@ class MNN_PUBLIC Llm { void setKVCacheInfo(size_t add, size_t remove, int* reserve = nullptr, int n_reserve = 0); size_t getCurrentHistory() const; void eraseHistory(size_t begin, size_t end); + bool setPrefixCacheFile(const std::string& filename, int flag = 0); virtual void response(const std::vector& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, int max_new_tokens = -1); void response(const std::string& user_content, std::ostream* os = &std::cout, const char* end_with = nullptr, int max_new_tokens = -1); void response(const ChatMessages& chat_prompts, std::ostream* os = &std::cout, const char* end_with = nullptr, int max_new_tokens = -1); @@ -179,7 +180,6 @@ class MNN_PUBLIC Llm { std::shared_ptr mGenerationStrategy; void setSpeculativeConfig(); void updateContext(int seq_len, int gen_len); - private: bool mInSpec = false; int mDraftLength = 4; @@ -187,6 +187,11 @@ class MNN_PUBLIC Llm { bool mAsync = true; int mBlockSize = 0; std::vector mValidBlockSize; + bool mPrefixCacheMode = false; + std::string mPrefixCacheFileName; + int mCallIndex; + int mPrefixLength; + bool mIsPrefixFileExist = false; }; // Embedding start diff --git a/transformers/llm/engine/include/llm/reranker.hpp b/transformers/llm/engine/include/llm/reranker.hpp index 4bded5f886..85fcefe1be 100644 --- a/transformers/llm/engine/include/llm/reranker.hpp +++ b/transformers/llm/engine/include/llm/reranker.hpp @@ -36,6 +36,11 @@ class RerankerBase { */ virtual void initialize(const std::string& config_path) = 0; + /** + * @brief Loads the reranker model after initialization. + */ + virtual void load() = 0; + /** * @brief Sets the instruction for the reranker. * @param instruct The instruction string. @@ -71,15 +76,21 @@ class Qwen3Reranker : public RerankerBase { */ Qwen3Reranker(const std::string& config_path) { initialize(config_path); - } + } /** - * @brief Initializes the LLM and token IDs. 
+ * @brief Initializes the LLM. * @param config_path The path to the LLM configuration. */ void initialize(const std::string& config_path) override { mLlm.reset(Llm::createLLM(config_path)); mLlm->set_config("{\"all_logits\":true}"); + } + + /** + * @brief Loads the LLM and initializes token IDs. + */ + void load() override { mLlm->load(); mTokenTrueId = mLlm->tokenizer_encode("yes")[0]; mTokenFalseId = mLlm->tokenizer_encode("no")[0]; @@ -213,7 +224,14 @@ class GteReranker : public RerankerBase { * @param config_path The path to the LLM configuration. */ void initialize(const std::string& config_path) override { - mLlm.reset(Embedding::createEmbedding(config_path)); + mLlm.reset(Embedding::createEmbedding(config_path, false)); + } + + /** + * @brief Loads the reranker model after initialization. + */ + void load() override { + mLlm->load(); } /** diff --git a/transformers/llm/engine/src/kvmeta.hpp b/transformers/llm/engine/src/kvmeta.hpp index ba882ee10d..9fb798fa6c 100644 --- a/transformers/llm/engine/src/kvmeta.hpp +++ b/transformers/llm/engine/src/kvmeta.hpp @@ -5,8 +5,8 @@ // Copyright © 2018, Alibaba Group Holding Limited // -#ifndef KVMATE_hpp -#define KVMATE_hpp +#ifndef KVMETA_hpp +#define KVMETA_hpp #include @@ -15,16 +15,26 @@ using namespace Express; namespace Transformer { struct KVMeta { + enum { + NoChange, + PendingWrite, + PendingRead + } file_operation; size_t block = 4096; size_t previous = 0; size_t remove = 0; int* reserve = nullptr; int n_reserve = 0; size_t add = 0; + std::string file_name = ""; + int file_flag = NoChange; + int seqlen_in_disk = 0; + int layer_index = 0; + int layer_nums = 0; std::vector reserveHost; void sync(); }; } } -#endif // KVMATE_hpp \ No newline at end of file +#endif // KVMATE_hpp diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp index 2ef4dcc1bb..61a3715569 100644 --- a/transformers/llm/engine/src/llm.cpp +++ b/transformers/llm/engine/src/llm.cpp @@ -24,6 +24,7 @@ #include 
"sampler.hpp" #include "omni.hpp" #include "speculative_decoding/generate.hpp" +#include "core/MNNFileUtils.h" // 0: no debug, 1: test op time, 2: print tensor info, 3: print tensor in output #define DEBUG_MODE 0 @@ -127,7 +128,9 @@ void Llm::setRuntimeHint(std::shared_ptr &rtg rtg->setHint(MNN::Interpreter::MEM_ALLOCATOR_TYPE, 0); rtg->setHint(MNN::Interpreter::QKV_QUANT_OPTIONS, mConfig->config_.value("quant_qkv", 8)); - rtg->setHint(MNN::Interpreter::KVCACHE_SIZE_LIMIT, mConfig->kvcache_limit()); + if (mConfig->reuse_kv() && mConfig->config_.value("quant_qkv", 8) == 10) { + rtg->setHint(MNN::Interpreter::QKV_QUANT_OPTIONS, 9); + } if (mConfig->use_cached_mmap()) { rtg->setHint(MNN::Interpreter::USE_CACHED_MMAP, 1); } @@ -135,30 +138,22 @@ void Llm::setRuntimeHint(std::shared_ptr &rtg if (mConfig->kvcache_mmap()) { rtg->setExternalPath(tmpPath, MNN::Interpreter::EXTERNAL_PATH_KVCACHE_DIR); } + auto cachePath = mConfig->prefix_cache_path(); + rtg->setExternalPath(cachePath, MNN::Interpreter::EXTERNAL_PATH_PREFIXCACHE_DIR); if (mConfig->use_mmap()) { rtg->setExternalPath(tmpPath, MNN::Interpreter::EXTERNAL_WEIGHT_DIR); } // set npu model dir rtg->setExternalPath(mConfig->npu_model_dir(), MNN::Interpreter::EXTERNAL_NPU_FILE_DIR); - auto dynamicOption = mConfig->dynamic_option(); - if (mConfig->dynamic_option()) { - rtg->setHint(MNN::Interpreter::DYNAMIC_QUANT_OPTIONS, mConfig->dynamic_option()); - } - if (mConfig->thread_num() > 7) { // if thread_num > 7, cpu dynamic quant use Arm86 kernels - rtg->setHint(MNN::Interpreter::CPU_SME2_INSTRUCTIONS, 0); - } else { - rtg->setHint(MNN::Interpreter::CPU_SME2_INSTRUCTIONS, 1); + rtg->setHint(MNN::Interpreter::DYNAMIC_QUANT_OPTIONS, mConfig->config_.value("dynamic_option", 0)); - } - if (mConfig->config_.value("prefer_decode", false)) { - dynamicOption = dynamicOption % 8 + 8; - rtg->setHint(MNN::Interpreter::DYNAMIC_QUANT_OPTIONS, dynamicOption); - } rtg->setHintPtr(Interpreter::KVCACHE_INFO, mMeta.get()); if 
(backend_type_convert(mConfig->backend_type()) != 0) { // not cpu std::string cacheFilePath = tmpPath.length() != 0 ? tmpPath : "."; rtg->setCache(cacheFilePath + "/mnn_cachefile.bin"); } + rtg->setHint(MNN::Interpreter::CPU_SME2_NEON_DIVISION_RATIO, mConfig->config_.value("cpu_sme2_neon_division_ratio", 41)); + rtg->setHint(MNN::Interpreter::CPU_SME_CORES, mConfig->config_.value("cpu_sme_core_num", 2)); } void Llm::initRuntime() { @@ -716,23 +711,58 @@ std::vector Llm::generate(const std::vector& input_ids, int max_tokens if (max_tokens < 0) { max_tokens = mConfig->max_new_tokens(); } + + bool passExecute = false; + if(mPrefixCacheMode) { + mCallIndex++; + + // first time execute generate function + if(mCallIndex == 1) { + passExecute = mIsPrefixFileExist; + + if(!mIsPrefixFileExist) { + // save prefix kvcache file + mMeta->file_name = mPrefixCacheFileName; + mMeta->file_flag = KVMeta::PendingWrite; // write + } else { + // first time and cachefile exist, pass this time + } + mPrefixLength = input_ids.size(); + } + // second time execute generate function + else if(mCallIndex == 2) { + // second time and cachefile exist, load prefix file + if(mIsPrefixFileExist) { + mMeta->file_name = mPrefixCacheFileName; + mMeta->file_flag = KVMeta::PendingRead; // read + mMeta->seqlen_in_disk = mPrefixLength; // set_length + } + } + } + mContext->history_tokens.insert(mContext->history_tokens.end(), input_ids.begin(), input_ids.end()); // push to history_ids_ - if (0 == mBlockSize || input_ids.size() <= mBlockSize) { - auto hidden_states = embedding(input_ids); - return generate(hidden_states, max_tokens); - } - int total_size = (int)input_ids.size(); - int loop_size = UP_DIV(total_size, mBlockSize); - for (int i = 0; i < loop_size; i++) { - auto start = i * mBlockSize; - auto end = (i+1) * mBlockSize; - if (end >= total_size) { - end = total_size; + if(!passExecute) { + if (0 == mBlockSize || input_ids.size() <= mBlockSize) { + auto hidden_states = embedding(input_ids); + 
return generate(hidden_states, max_tokens); + } + int total_size = (int)input_ids.size(); + int loop_size = UP_DIV(total_size, mBlockSize); + for (int i = 0; i < loop_size; i++) { + auto start = i * mBlockSize; + auto end = (i+1) * mBlockSize; + if (end >= total_size) { + end = total_size; + } + std::vector chunk_ids(input_ids.begin() + start, input_ids.begin() + end); + auto input_embeds = embedding(chunk_ids); + generate(input_embeds, 0); } - std::vector chunk_ids(input_ids.begin() + start, input_ids.begin() + end); - auto input_embeds = embedding(chunk_ids); - generate(input_embeds, 0); + } else { + // update states + updateContext((int)input_ids.size(), 0); } + generate(max_tokens); mContext->prompt_len = static_cast(input_ids.size()); return mContext->output_tokens; @@ -770,6 +800,7 @@ std::vector Llm::generate(MNN::Express::VARP input_embeds, int max_tokens) } int seqLen = input_embeds->getInfo()->dim[mSeqLenIndex]; mContext->prompt_len = seqLen; + Timer _t; forwardVec(input_embeds); if(mGenerateParam->outputs.size() < 1) { @@ -777,9 +808,24 @@ std::vector Llm::generate(MNN::Express::VARP input_embeds, int max_tokens) } updateContext(seqLen, 0); mContext->prefill_us += _t.durationInUs(); - MNN::Express::ExecutorScope::Current()->gc(); // after prefill + // prefix cache mode and response second time + if(mPrefixCacheMode && mCallIndex == 2) { + if(mIsPrefixFileExist) { + // when cachefile exist, after second time prefill, updata previous length + mMeta->previous += mMeta->seqlen_in_disk; + } + // recover meta status + mMeta->seqlen_in_disk = 0; + mMeta->file_name = ""; + mMeta->file_flag = KVMeta::NoChange; + mMeta->layer_index = 0; + // recover normal mode + mPrefixCacheMode = false; + } + + #if DEBUG_MODE == 3 { std::ofstream outFile("input_embeds.txt"); @@ -840,6 +886,7 @@ void Llm::response(const ChatMessages& chat_prompts, std::ostream* os, const cha Llm::Llm(std::shared_ptr config) : mConfig(config) { mContext.reset(new LlmContext); mMeta.reset(new 
KVMeta); + mMeta->layer_nums = mConfig->layer_nums(); mGenerateParam.reset(new GenerationParams); } @@ -889,6 +936,29 @@ std::vector Llm::getOutputs() const { return mGenerateParam->outputs; } +bool Llm::setPrefixCacheFile(const std::string& filename, int flag) { + mPrefixCacheFileName = filename; + mCallIndex = 0; + mPrefixCacheMode = true; + + + mIsPrefixFileExist = true; + // check kvcache, validate file existence + for(int i = 0; i < mConfig->layer_nums(); i++) { + auto k_file = MNNFilePathConcat(mConfig->prefix_cache_path(), mPrefixCacheFileName) + "_" + std::to_string(i) + "_sync.k"; + if(!MNNFileExist(k_file.c_str())) { + mIsPrefixFileExist = false; + break; + } + auto v_file = MNNFilePathConcat(mConfig->prefix_cache_path(), mPrefixCacheFileName) + "_" + std::to_string(i) + "_sync.v"; + if(!MNNFileExist(v_file.c_str())) { + mIsPrefixFileExist = false; + break; + } + } + return mIsPrefixFileExist; +} + bool Llm::reuse_kv() { return mConfig->reuse_kv(); } static inline bool needNewVar(VARP var, int axis, int seq_len, int kv_seq_len = 0) { diff --git a/transformers/llm/engine/src/llmconfig.hpp b/transformers/llm/engine/src/llmconfig.hpp index 9656a566c9..97eb044a40 100644 --- a/transformers/llm/engine/src/llmconfig.hpp +++ b/transformers/llm/engine/src/llmconfig.hpp @@ -359,10 +359,6 @@ class LlmConfig { if (mllm) return mllm_config_.value("memory", "low"); return config_.value("memory", "low"); } - - int kvcache_limit() const { - return config_.value("kvcache_limit", -1); - } // backend config end > // talker config start @@ -454,6 +450,10 @@ class LlmConfig { return config_.value("tmp_path", ""); } + std::string prefix_cache_path() const { + return config_.value("prefix_cache_path", "prefixcache"); + } + std::string system_prompt() const { return config_.value("system_prompt", ""); } diff --git a/transformers/llm/engine/src/omni.cpp b/transformers/llm/engine/src/omni.cpp index 4afb06a907..b783c8740a 100644 --- a/transformers/llm/engine/src/omni.cpp +++ 
b/transformers/llm/engine/src/omni.cpp @@ -11,6 +11,7 @@ #endif #include #include +#include #include #include #include "omni.hpp" @@ -19,7 +20,9 @@ #include "tokenizer.hpp" #include "diskembedding.hpp" #include "sampler.hpp" +#ifdef LLM_SUPPORT_HTTP_RESOURCE #include "httplib.h" +#endif #ifdef LLM_SUPPORT_VISION #include #endif @@ -759,6 +762,7 @@ std::vector Omni::multimodeProcess(const std::string& mode, std::string inf // std::cout << "hw: " << mVisionHeight << ", " << mVisionWidth << std::endl; // std::cout << "file: " << file_info << std::endl; } +#ifdef LLM_SUPPORT_HTTP_RESOURCE if (file_info.substr(0, 4) == "http") { std::regex url_regex(R"(^https?://([^/]+)(/.*))"); std::smatch url_match_result; @@ -784,6 +788,7 @@ std::vector Omni::multimodeProcess(const std::string& mode, std::string inf std::cerr << "Failed to download file. Status code: " << (res ? res->status : 0) << std::endl; } } +#endif if (mode == "img" && mConfig->is_visual()) { return visionProcess(file_info); } diff --git a/transformers/llm/engine/src/tokenizer.cpp b/transformers/llm/engine/src/tokenizer.cpp index 0c5642495b..bd1a465d7e 100644 --- a/transformers/llm/engine/src/tokenizer.cpp +++ b/transformers/llm/engine/src/tokenizer.cpp @@ -6,6 +6,8 @@ // #include +#define MNN_OPEN_TIME_TRACE 1 +#include #include "tokenizer.hpp" #include #include @@ -22,62 +24,42 @@ namespace MNN { namespace Transformer { // base64 -static const char* get_base64_chars() { - return "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; -} - -static inline bool is_base64(unsigned char c) { - return (isalnum(c) || (c == '+') || (c == '/')); -} - -static inline size_t one_char_len(const char *src) { - return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4]; -} +static const int kBase64DecodeTable[] = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 0-15 + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // 16-31 + -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1, -1, 62, -1, -1, -1, 63, // 32-47 (+, /) + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1, // 48-63 (0-9) + -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, // 64-79 (A-O) + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, -1, // 80-95 (P-Z) + -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, // 96-111 (a-o) + 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1 // 112-127 (p-z) +}; static std::string base64_decode(const std::string& str) { - int in_len = str.size(); - int i = 0; - int j = 0; - int in_ = 0; - unsigned char char_array_4[4], char_array_3[3]; + if (str.empty()) return ""; + size_t in_len = str.size(); std::string ret; + ret.reserve(in_len * 3 / 4 + 2); - while (in_len-- && ( str[in_] != '=') && is_base64(str[in_])) { - char_array_4[i++] = str[in_]; in_++; - if (i ==4) { - for (i = 0; i <4; i++) { - const char* base64_chars = get_base64_chars(); - char_array_4[i] = strchr(base64_chars, char_array_4[i]) - base64_chars; - } - char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - for (i = 0; (i < 3); i++) { - ret.push_back(char_array_3[i]); - } - i = 0; - } - } - if (i) { - for (j = i; j < 4; j++) { - char_array_4[j] = 0; - } - for (j = 0; j < 4; j++) { - const char* base64_chars = get_base64_chars(); - char_array_4[j] = strchr(base64_chars, char_array_4[j]) - base64_chars; - } - char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - for (j = 0; (j < i - 1); j++) { - ret.push_back(char_array_3[j]); + int val = 0, valb = -8; + for (unsigned char c : str) { + if (c > 127) continue; + int d = kBase64DecodeTable[c]; + if (d == -1) continue; + val = 
(val << 6) + d; + valb += 6; + if (valb >= 0) { + ret.push_back(char((val >> valb) & 0xFF)); + valb -= 8; } } return ret; } +static inline size_t one_char_len(const char *src) { + return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4]; +} + static inline void to_lower_case(std::string& str) { for (auto &c : str) { if (c >= 'A' && c <= 'Z') { @@ -87,6 +69,7 @@ static inline void to_lower_case(std::string& str) { } Tokenizer* Tokenizer::createTokenizer(const std::string& filename) { + AUTOTIME; Tokenizer* tokenizer = nullptr; // check file std::ifstream tok_file(filename); @@ -128,6 +111,7 @@ Tokenizer* Tokenizer::createTokenizer(const std::string& filename) { // load vocabs tokenizer->load_vocab(tok_file); tok_file.close(); + tokenizer->cache_special_tokens(); return tokenizer; } @@ -170,16 +154,31 @@ void Tokenizer::load_special(std::ifstream& tok_file) { } } +void Tokenizer::cache_special_tokens() { + special_tokens_cache_.clear(); + for (int id : special_tokens_) { + std::string token_str = decode(id); + if (!token_str.empty()) { + special_tokens_cache_.emplace_back(token_str, id); + } + } +} + std::vector Tokenizer::encode(const std::string& str) { std::vector ids = prefix_tokens_; - if (!special_tokens_.empty()) { + if (special_tokens_cache_.empty() && !special_tokens_.empty()) { + cache_special_tokens(); + } + + if (!special_tokens_cache_.empty()) { std::string text = str; size_t start = 0; for (size_t i = 0; i < text.length(); ++i) { - for (auto special_id : special_tokens_) { - const auto& token = decode(special_id); - if (token.empty()) continue; - if (i + token.length() <= text.length() && text.substr(i, token.length()) == token) { + for (const auto& pair : special_tokens_cache_) { + const std::string& token = pair.first; + int special_id = pair.second; + if (i + token.length() <= text.length() && + strncmp(text.c_str() + i, token.c_str(), token.length()) == 0) { if (i > start) { encode(text.substr(start, i - start), ids); } @@ -200,24 +199,33 @@ 
std::vector Tokenizer::encode(const std::string& str) { } bool Sentencepiece::load_vocab(std::ifstream& tok_file) { - std::string line, token; - std::getline(tok_file, line); + AUTOTIME; + std::string line; + if (!std::getline(tok_file, line)) return false; int vocab_len = std::stoi(line); - float score; - int type; sentence_pieces_.resize(vocab_len); + pieces_.reserve(vocab_len); + reserved_id_map_.reserve(vocab_len); + for (int index = 0; index < vocab_len; index++) { std::getline(tok_file, line); - std::istringstream line_str(line); - line_str >> token >> score >> type; - token = base64_decode(token); + + size_t first_space = line.find(' '); + if (first_space == std::string::npos) continue; + size_t second_space = line.find(' ', first_space + 1); + if (second_space == std::string::npos) continue; + + std::string token = base64_decode(line.substr(0, first_space)); + float score = std::strtof(line.c_str() + first_space + 1, nullptr); + int type = std::atoi(line.c_str() + second_space + 1); + auto piece_type = static_cast(type); - SentencePiece piece = {token, score, piece_type}; - sentence_pieces_[index] = std::move(piece); + sentence_pieces_[index] = {token, score, piece_type}; + string_view_ token_sv(sentence_pieces_[index].piece); if (piece_type == PieceType::NORMAL) { - pieces_.insert({token, index}); + pieces_.insert({token_sv, index}); } else { - reserved_id_map_.insert({token, index}); + reserved_id_map_.insert({token_sv, index}); if (piece_type == PieceType::UNKNOWN) { unk_id_ = index; } @@ -226,7 +234,7 @@ bool Sentencepiece::load_vocab(std::ifstream& tok_file) { return true; } -int Sentencepiece::piece_to_id(const std::string& piece) const { +int Sentencepiece::piece_to_id(string_view_ piece) const { auto it = reserved_id_map_.find(piece); if (it != reserved_id_map_.end()) { return it->second; @@ -285,8 +293,7 @@ Sentencepiece::EncodeResult Sentencepiece::bpe_encode(string_view_ normalized, f return; } const string_view_ piece(symbols[left].piece.data(), 
symbols[left].piece.size() + symbols[right].piece.size()); - std::string piece_str(piece.to_string()); - const auto it = pieces_.find(piece_str); + const auto it = pieces_.find(piece); if (it == pieces_.end()) { return; } @@ -367,8 +374,7 @@ Sentencepiece::EncodeResult Sentencepiece::bpe_encode(string_view_ normalized, f std::function resegment; resegment = [this, &resegment, &rev_merge](string_view_ w, EncodeResult *output) -> void { - std::string w_str(w.to_string()); - const int id = piece_to_id(w_str); + const int id = piece_to_id(w); // std::cout << "piece: " << w << ", id = " << id << std::endl; if (id == -1 || !is_unused(id)) { output->emplace_back(w, id); @@ -514,6 +520,9 @@ std::vector BertTokenizer::word_piece(const std::string& token) { std::string current = token; bool is_first_piece = true; + std::string candidate; + candidate.reserve(token.size() + 2); + while (!current.empty()) { int match_id = -1; size_t match_pos = 0; @@ -521,12 +530,12 @@ std::vector BertTokenizer::word_piece(const std::string& token) { // Try to find the longest matching piece in vocabulary // Start from the full length and work backwards for (size_t len = current.size(); len > 0; --len) { - std::string candidate = current.substr(0, len); - + candidate.clear(); // Add ## prefix for sub-word pieces (not the first piece) if (!is_first_piece) { - candidate = "##" + candidate; + candidate.append("##"); } + candidate.append(current.data(), len); auto vocab_it = encoder_.find(candidate); if (vocab_it != encoder_.end()) { @@ -626,14 +635,106 @@ void BertTokenizer::encode(const std::string& str, std::vector& ids) { } } -std::wstring utf8_to_wstring(const std::string& str) { - std::wstring_convert> myconv; - return myconv.from_bytes(str); +std::wstring utf8_to_wstring(const char* str, size_t len) { + if (len == 0) return std::wstring(); + + std::wstring wstr; + wstr.reserve(len); + + const char* p = str; + const char* end = str + len; + + while (p < end) { + unsigned char c = 
static_cast(*p); + + if (c < 0x80) { + wstr.push_back(static_cast(c)); + ++p; + } else if (c < 0xE0) { + if (p + 1 < end) { + wstr.push_back(static_cast( + ((c & 0x1F) << 6) | (static_cast(p[1]) & 0x3F) + )); + } + p += 2; + } else if (c < 0xF0) { + if (p + 2 < end) { + wstr.push_back(static_cast( + ((c & 0x0F) << 12) | + ((static_cast(p[1]) & 0x3F) << 6) | + (static_cast(p[2]) & 0x3F) + )); + } + p += 3; + } else if (c < 0xF8) { + if (p + 3 < end) { + unsigned int cp = ((c & 0x07) << 18) | + ((static_cast(p[1]) & 0x3F) << 12) | + ((static_cast(p[2]) & 0x3F) << 6) | + (static_cast(p[3]) & 0x3F); + + if (sizeof(wchar_t) == 2) { + // Windows: Surrogate pairs for code points > 0xFFFF + if (cp > 0xFFFF) { + cp -= 0x10000; + wstr.push_back(static_cast(0xD800 + (cp >> 10))); + wstr.push_back(static_cast(0xDC00 + (cp & 0x3FF))); + } else { + wstr.push_back(static_cast(cp)); + } + } else { + // Linux/macOS: Direct UTF-32 + wstr.push_back(static_cast(cp)); + } + } + p += 4; + } else { + ++p; + } + } + return wstr; } std::string wstring_to_utf8(const std::wstring& str) { - std::wstring_convert> myconv; - return myconv.to_bytes(str); + if (str.empty()) return std::string(); + std::string res; + res.reserve(str.size() * 3); + + const wchar_t* p = str.data(); + const wchar_t* end = p + str.size(); + + while (p < end) { + unsigned int cp = static_cast(*p); + p++; + if (sizeof(wchar_t) == 2) { + if (cp >= 0xD800 && cp <= 0xDBFF) { + if (p < end) { + unsigned int low = static_cast(*p); + if (low >= 0xDC00 && low <= 0xDFFF) { + cp = 0x10000 + ((cp - 0xD800) << 10) + (low - 0xDC00); + p++; + } + } + } + } + if (cp < 0x80) { + res.push_back(static_cast(cp)); + } else if (cp < 0x800) { + res.push_back(static_cast(0xC0 | (cp >> 6))); + res.push_back(static_cast(0x80 | (cp & 0x3F))); + } else if (cp < 0x10000) { + res.push_back(static_cast(0xE0 | (cp >> 12))); + res.push_back(static_cast(0x80 | ((cp >> 6) & 0x3F))); + res.push_back(static_cast(0x80 | (cp & 0x3F))); + } else { + 
res.push_back(static_cast(0xF0 | (cp >> 18))); + res.push_back(static_cast(0x80 | ((cp >> 12) & 0x3F))); + res.push_back(static_cast(0x80 | ((cp >> 6) & 0x3F))); + res.push_back(static_cast(0x80 | (cp & 0x3F))); + } + } + + return res; } // Given a token as a UTF8 string, encode each byte into an wchar_t @@ -648,48 +749,77 @@ void byte_encode_token(const std::string& token, } bool HuggingfaceTokenizer::load_vocab(std::ifstream& tok_file) { - std::string line, token; - // get nums - int vocab_len, merge_len; - std::getline(tok_file, line); - std::istringstream line_str(line); - line_str >> vocab_len >> merge_len; - // load vocab + std::string line; + line.reserve(256); // Reduce allocation during getline + + // 1. Get nums + int vocab_len = 0; + int merge_len = 0; + if (std::getline(tok_file, line)) { + std::istringstream line_str(line); + line_str >> vocab_len >> merge_len; + } + + // 2. Load vocab decoder_.resize(vocab_len); + encoder_.reserve(vocab_len); + for (int i = 0; i < vocab_len; i++) { std::getline(tok_file, line); - encoder_.insert({line, i}); - decoder_[i] = line; + // Move string to decoder to avoid copy, then use reference for encoder + decoder_[i] = std::move(line); + encoder_.emplace(decoder_[i], i); } - // load merge_rule + + // 3. Load merge_rules + bpe_ranks_.reserve(merge_len); for (int i = 0; i < merge_len; i++) { std::getline(tok_file, line); - int d = line.find(" "); - bpe_ranks_.insert({{utf8_to_wstring(line.substr(0, d)), - utf8_to_wstring(line.substr(d + 1))}, i}); + + size_t d = line.find(' '); + if (d != std::string::npos) { + // Use pointer-based conversion to avoid creating temporary substr strings + bpe_ranks_.emplace(std::make_pair( + utf8_to_wstring(line.data(), d), + utf8_to_wstring(line.data() + d + 1, line.size() - d - 1) + ), i); + } } - // bytes_to_unicode - auto _insert_range = [=](int start, int end) { + + // 4. 
bytes_to_unicode initialization + // Use a temporary local vector for O(1) access during construction + std::vector temp_map(256, 0); + + auto set_range = [&](int start, int end) { for (int c = start; c <= end; c++) { - b2u_.insert({uint8_t(c), wchar_t(c)}); + temp_map[c] = static_cast(c); } }; - b2u_.clear(); - _insert_range(L'!', L'~'); - _insert_range(L'¡', L'¬'); - _insert_range(L'®', L'ÿ'); + set_range(L'!', L'~'); + set_range(L'¡', L'¬'); + set_range(L'®', L'ÿ'); int n = 0; for (int b = 0; b < 256; b++) { - if (b2u_.find(uint8_t(b)) == b2u_.end()) { - b2u_.insert({uint8_t(b), wchar_t(256 + n)}); + if (temp_map[b] == 0) { + temp_map[b] = static_cast(256 + n); n++; } } - for (auto e : b2u_) { - u2b_.insert({e.second, e.first}); + + // Batch insert into member maps + b2u_.clear(); + u2b_.clear(); + // Hint: Assuming typical map implementations, insertion order matters slightly, + // but just bulk inserting is clean enough. + for (int i = 0; i < 256; ++i) { + uint8_t u8 = static_cast(i); + wchar_t wc = temp_map[i]; + b2u_.emplace(u8, wc); + u2b_.emplace(wc, u8); } + return true; } @@ -767,7 +897,7 @@ void HuggingfaceTokenizer::encode(const std::string& str, std::vector& ids) "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" // std::regex re("('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\\s\\w]+|\\s+)"); */ - std::regex re("('s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n[:alpha:][:digit:]]?[[:alpha:]]+|[[:digit:]]| ?[^\\s[:alpha:][:digit:]]+[\r\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", std::regex_constants::icase); + static const std::regex re("('s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n[:alpha:][:digit:]]?[[:alpha:]]+|[[:digit:]]| ?[^\\s[:alpha:][:digit:]]+[\r\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", std::regex_constants::icase); std::string input = str; std::vector result; @@ -799,7 +929,8 @@ std::string HuggingfaceTokenizer::decode(int id) { if (id >= decoder_.size()) { return ""; } - std::wstring w 
= utf8_to_wstring(decoder_.at(id)); + auto decode_utf8 = decoder_.at(id); + std::wstring w = utf8_to_wstring(decode_utf8.data(), decode_utf8.size()); std::string r; for (wchar_t c : w) { if (u2b_.find(c) != u2b_.end()) { diff --git a/transformers/llm/engine/src/tokenizer.hpp b/transformers/llm/engine/src/tokenizer.hpp index fce4f4b49b..c0a983bafd 100644 --- a/transformers/llm/engine/src/tokenizer.hpp +++ b/transformers/llm/engine/src/tokenizer.hpp @@ -78,14 +78,14 @@ class Trie { if (sizesecond, current_matched, current_it, ++it, end); } else { if (node.id!=-1) { return node.id; } - else { it = current_it; return current_matched;} + else { it = current_it; return current_matched;} } } public: @@ -142,12 +142,14 @@ class Tokenizer { std::vector encode(const std::string& str); virtual std::string decode(int id) = 0; protected: + void cache_special_tokens(); virtual void load_special(std::ifstream& file); virtual bool load_vocab(std::ifstream& file) = 0; virtual void encode(const std::string& str, std::vector& ids) = 0; std::vector special_tokens_; std::vector stop_tokens_; std::vector prefix_tokens_; + std::vector> special_tokens_cache_; private: std::string mTemplate; }; @@ -192,14 +194,14 @@ class Sentencepiece : public Tokenizer { // pieces from model std::vector sentence_pieces_; // piece -> id map for normal pieces - std::unordered_map pieces_; + std::unordered_map pieces_; // piece -> id map for control, unknown, and byte pieces - std::unordered_map reserved_id_map_; + std::unordered_map reserved_id_map_; private: float get_score(int id) const; bool is_unused(int id) const; bool is_control(int id) const; - int piece_to_id(const std::string& w) const; + int piece_to_id(string_view_ w) const; std::string byte_to_piece(unsigned char c) const; EncodeResult bpe_encode(string_view_ str, float alpha = 0.f); }; diff --git a/transformers/llm/engine/tools/llm_bench.cpp b/transformers/llm/engine/tools/llm_bench.cpp index 029077f58d..aec3af1b9e 100644 --- 
a/transformers/llm/engine/tools/llm_bench.cpp +++ b/transformers/llm/engine/tools/llm_bench.cpp @@ -27,6 +27,9 @@ struct RuntimeParameters { std::vector precision; std::vector memory; std::vector dynamicOption; + std::vector divisionRatioSme2Neon; + std::vector smeCoreNum; + std::vector quantAttention; }; struct TestParameters { @@ -47,6 +50,9 @@ struct CommandParameters { int precision; int memory; int dynamicOption; + int divisionRatioSme2Neon; + int smeCoreNum; + int quantAttention; int nPrompt; int nGenerate; @@ -54,19 +60,21 @@ struct CommandParameters { int nRepeat; std::string kvCache; std::string loadingTime; - }; static const RuntimeParameters runtimeParamsDefaults = { /* model */ { "./Qwen2.5-1.5B-Instruct" }, /* backends */ { 0 }, - /* threads */ { 4 }, - /* useMmap */ false, + /* threads */ { 4 }, + /* useMmap */ false, /* power */ { 0 }, /* precision */ { 2 }, /* memory */ { 2 }, - /* dynamicOption */ { 0 } + /* dynamicOption */ { 0 }, + /* quantAttention */ { 0 }, + /* divisionRatioSme2Neon*/ { 41 }, + /* smeCoreNum */ { 2 } }; @@ -93,6 +101,8 @@ struct commandParametersInstance { mCmdParam.precision = cmdParam.precision; mCmdParam.memory = cmdParam.memory; mCmdParam.dynamicOption = cmdParam.dynamicOption; + mCmdParam.divisionRatioSme2Neon = cmdParam.divisionRatioSme2Neon; + mCmdParam.quantAttention = cmdParam.quantAttention; mCmdParam.nPrompt = cmdParam.nPrompt; mCmdParam.nGenerate = cmdParam.nGenerate; @@ -100,6 +110,7 @@ struct commandParametersInstance { mCmdParam.nRepeat = cmdParam.nRepeat; mCmdParam.kvCache = cmdParam.kvCache; mCmdParam.loadingTime = cmdParam.loadingTime; + mCmdParam.smeCoreNum = cmdParam.smeCoreNum; } CommandParameters get_cmd_parameters() const { @@ -112,7 +123,10 @@ struct commandParametersInstance { mCmdParam.power == other.mCmdParam.power && mCmdParam.precision == other.mCmdParam.precision && mCmdParam.memory == other.mCmdParam.memory && - mCmdParam.dynamicOption == other.mCmdParam.dynamicOption; + mCmdParam.dynamicOption 
== other.mCmdParam.dynamicOption && + mCmdParam.quantAttention == other.mCmdParam.quantAttention && + mCmdParam.smeCoreNum == other.mCmdParam.smeCoreNum && + mCmdParam.divisionRatioSme2Neon == other.mCmdParam.divisionRatioSme2Neon; } }; @@ -163,19 +177,25 @@ struct TestInstance { int power; int memory; int dynamicOption; + int divisionRatioSme2Neon; + int smeCoreNum; + int quantAttention; TestInstance(const commandParametersInstance & instance) { - modelConfigFile = instance.mCmdParam.model; - threads = instance.mCmdParam.threads; - useMmap = instance.mCmdParam.useMmap; - nPrompt = instance.mCmdParam.nPrompt; - nGenerate = instance.mCmdParam.nGenerate; + modelConfigFile = instance.mCmdParam.model; + threads = instance.mCmdParam.threads; + useMmap = instance.mCmdParam.useMmap; + nPrompt = instance.mCmdParam.nPrompt; + nGenerate = instance.mCmdParam.nGenerate; backend = instance.mCmdParam.backend; precision = instance.mCmdParam.precision; memory = instance.mCmdParam.memory; power = instance.mCmdParam.power; dynamicOption = instance.mCmdParam.dynamicOption; + divisionRatioSme2Neon = instance.mCmdParam.divisionRatioSme2Neon; + smeCoreNum = instance.mCmdParam.smeCoreNum; + quantAttention = instance.mCmdParam.quantAttention; } std::vector getTokensPerSecond(int n_tokens, std::vector cost_us) const { @@ -291,7 +311,20 @@ struct markdownPrinter : public Printer { if (rp.dynamicOption.size() > 1) { fields.emplace_back("dynamicOption"); } + if (!(rp.divisionRatioSme2Neon.size() == 1 && rp.divisionRatioSme2Neon[0] == runtimeParamsDefaults.divisionRatioSme2Neon[0])) { + fields.emplace_back("divisionRatioSme2Neon"); + } + for (auto x: rp.quantAttention) { + if (x != 0) { + fields.emplace_back("quantAttention"); + break; + } + break; + } + if (!(rp.smeCoreNum.size() == 1 && rp.smeCoreNum[0] == runtimeParamsDefaults.smeCoreNum[0])) { + fields.emplace_back("smeCoreNum"); + } if (rp.useMmap) { fields.emplace_back("useMmap"); } @@ -379,6 +412,22 @@ struct markdownPrinter : public 
Printer { } else if (field == "useMmap") { if (t.useMmap) value = "true"; else value = "false"; + } else if (field == "divisionRatioSme2Neon") { + snprintf(buf, sizeof(buf), "%d", t.divisionRatioSme2Neon); + value = buf; + } else if (field == "smeCoreNum") { + snprintf(buf, sizeof(buf), "%d", t.smeCoreNum); + value = buf; + } else if (field == "quantAttention") { + snprintf(buf, sizeof(buf), "%d", t.quantAttention); +// value = buf; + if (t.quantAttention == 1) { + value = "Int8 Q,K"; + } else if (t.quantAttention == 2) { + value = "Int8 Q,K,V"; + } else { + + } } else { assert(false); @@ -444,6 +493,9 @@ static std::vector get_cmd_params_instances(const Run for (const auto & power : rp.power) for (const auto & nt : rp.threads) for (const auto & dyop : rp.dynamicOption) + for (const auto &mratio: rp.divisionRatioSme2Neon) + for (const auto &smeNum: rp.smeCoreNum) + for (const auto & quantAttn : rp.quantAttention) if (tp.kvCache == "true") { // MNN llm_demo test standard for (const auto & nPrompt : tp.nPrompt) { if (nPrompt == 0) { @@ -464,9 +516,12 @@ static std::vector get_cmd_params_instances(const Run tmpParam.nGenerate = nGenerate; tmpParam.useMmap = rp.useMmap; tmpParam.dynamicOption = dyop; + tmpParam.quantAttention = quantAttn; tmpParam.nRepeat = tp.nRepeat[0]; tmpParam.kvCache = "true"; tmpParam.loadingTime = tp.loadTime; + tmpParam.divisionRatioSme2Neon = mratio; + tmpParam.smeCoreNum = smeNum; auto instance = commandParametersInstance(tmpParam); instances.push_back(instance); } @@ -487,9 +542,12 @@ static std::vector get_cmd_params_instances(const Run tmpParam.precision = precision; tmpParam.memory = memory; tmpParam.dynamicOption = dyop; + tmpParam.quantAttention = quantAttn; tmpParam.nRepeat = tp.nRepeat[0]; tmpParam.kvCache = "false"; tmpParam.loadingTime = tp.loadTime; + tmpParam.divisionRatioSme2Neon = mratio; + tmpParam.smeCoreNum = smeNum; auto instance = commandParametersInstance(tmpParam); instances.push_back(instance); } @@ -505,9 +563,12 @@ 
static std::vector get_cmd_params_instances(const Run tmpParam.precision = precision; tmpParam.memory = memory; tmpParam.dynamicOption = dyop; + tmpParam.quantAttention = quantAttn; tmpParam.nRepeat = tp.nRepeat[0]; tmpParam.kvCache = "false"; tmpParam.loadingTime = tp.loadTime; + tmpParam.divisionRatioSme2Neon = mratio; + tmpParam.smeCoreNum = smeNum; auto instance = commandParametersInstance(tmpParam); instances.push_back(instance); } @@ -526,9 +587,12 @@ static std::vector get_cmd_params_instances(const Run tmpParam.precision = precision; tmpParam.memory = memory; tmpParam.dynamicOption = dyop; + tmpParam.quantAttention = quantAttn; tmpParam.nRepeat = tp.nRepeat[0]; tmpParam.kvCache = "false"; tmpParam.loadingTime = tp.loadTime; + tmpParam.divisionRatioSme2Neon = mratio; + tmpParam.smeCoreNum = smeNum; auto instance = commandParametersInstance(tmpParam); instances.push_back(instance); } @@ -568,7 +632,7 @@ static void printUsage(int /* argc */, char ** argv) { printf(" -h, --help\n"); printf(" -m, --model (default: ./Qwen2.5-1.5B-Instruct/config.json)\n"); printf(" -a, --backends (default: %s)\n", "cpu"); - printf(" -c, --precision (default: %s) | Note: (0:Normal(for cpu bakend, 'Nornal' is 'High'),1:High,2:Low)\n", join(runtimeParamsDefaults.precision, ",").c_str()); + printf(" -c, --precision (default: %s) | Note: (0:Normal(for cpu bakend, 'Normal' is 'High'),1:High,2:Low)\n", join(runtimeParamsDefaults.precision, ",").c_str()); printf(" -t, --threads (default: %s)\n", join(runtimeParamsDefaults.threads, ",").c_str()); printf(" -p, --n-prompt (default: %s)\n", join(testParamsDefaults.nPrompt, ",").c_str()); printf(" -n, --n-gen (default: %s)\n", join(testParamsDefaults.nGenerate, ",").c_str()); @@ -577,8 +641,11 @@ static void printUsage(int /* argc */, char ** argv) { printf(" -rep, --n-repeat (default: %s)\n", join(testParamsDefaults.nRepeat, ",").c_str()); printf(" -kv, --kv-cache (default: %s) | Note: if true: Every time the LLM model generates a new word, 
it utilizes the cached KV-cache\n", "false"); printf(" -fp, --file-print (default: %s)\n", "stdout"); + printf(" -scn, --sme-core-num (default: 2) | Note: Specify the number of smeCoreNum to use.\n"); printf(" -load, --loading-time (default: %s)\n", "true"); printf(" -dyo, --dynamicOption (default: 0) | Note: if set 8, trades higher memory usage for better decoding performance\n"); + printf(" -mr, --mixedSme2NeonRatio (default: 41) | Note: This parameter is intended to optimize multi-threaded inference performance on backends that support Arm SME instructions. The optimal ratio may vary across different models; we recommend trying values such as 41, 49, 33.\n"); + printf(" -qatten, --quant-attention <0|1> (default: 0) | Note: if 1, quantize attention's key value to int8; default 0\n"); } @@ -725,6 +792,28 @@ static bool parseCmdParams(int argc, char ** argv, RuntimeParameters & runtimePa } auto p = splitString(argv[i], splitDelim); testParams.loadTime = p[0]; + } else if (arg == "-mr" || arg == "--miexdSme2NeonRatio") { + if (++i >= argc) { + invalidParam = true; + break; + } + auto p = splitString(argv[i], splitDelim); + runtimeParams.divisionRatioSme2Neon.insert(runtimeParams.divisionRatioSme2Neon.end(), p.begin(), p.end()); + } else if (arg == "-scn" || arg == "--sme-core-num") { + if (++i >= argc) { + invalidParam = true; + break; + } + auto p = splitString(argv[i], splitDelim); + runtimeParams.smeCoreNum.insert(runtimeParams.smeCoreNum.end(), p.begin(), p.end()); + } else if (arg == "-qatten" || arg == "--quant-attention") { + // do nothing, reserved for future use + if (++i >= argc) { + invalidParam = true; + break; + } + auto p = splitString(argv[i], splitDelim); + runtimeParams.quantAttention.insert(runtimeParams.quantAttention.end(), p.begin(), p.end()); } else { invalidParam = true; @@ -770,6 +859,15 @@ static bool parseCmdParams(int argc, char ** argv, RuntimeParameters & runtimePa if (runtimeParams.dynamicOption.empty()) { runtimeParams.dynamicOption = 
runtimeParamsDefaults.dynamicOption; } + if (runtimeParams.divisionRatioSme2Neon.empty()) { + runtimeParams.divisionRatioSme2Neon = runtimeParamsDefaults.divisionRatioSme2Neon; + } + if (runtimeParams.smeCoreNum.empty()) { + runtimeParams.smeCoreNum = runtimeParamsDefaults.smeCoreNum; + } + if (runtimeParams.quantAttention.empty()) { + runtimeParams.quantAttention = runtimeParamsDefaults.quantAttention; + } if (testParams.nRepeat.empty()) { testParams.nRepeat = testParamsDefaults.nRepeat; } @@ -778,7 +876,7 @@ static bool parseCmdParams(int argc, char ** argv, RuntimeParameters & runtimePa } -static Llm* buildLLM(const std::string& config_path, int backend, int memory, int precision, int threads, int power, int dynamic_option, bool use_mmap) { +static Llm* buildLLM(const std::string& config_path, int backend, int memory, int precision, int threads, int power, int dynamic_option, bool use_mmap, int divisionRatioSme2Neon, int smeCoreNum, int promptLen, int quant_attention) { auto llmPtr = Llm::createLLM(config_path); llmPtr->set_config(R"({ "async":false @@ -813,11 +911,17 @@ static Llm* buildLLM(const std::string& config_path, int backend, int memory, in MNN_ERROR("thread_num for LLM config set error\n"); return nullptr; } - setSuccess &= llmPtr->set_config("{\"dynamic_option\":" + std::to_string(dynamic_option) + "}"); + auto doy = (promptLen <= 300 && promptLen != 0) ? 
(dynamic_option % 8) : (dynamic_option % 8 + 8); + setSuccess &= llmPtr->set_config("{\"dynamic_option\":" + std::to_string(doy) + "}"); if (!setSuccess) { MNN_ERROR("dynamic_option for LLM config set error\n"); return nullptr; } + setSuccess &= llmPtr->set_config("{\"quant_qkv\":" + std::to_string(quant_attention + 8) + "}"); + if (!setSuccess) { + MNN_ERROR("quant_qkv for LLM config set error\n"); + return nullptr; + } setSuccess &= llmPtr->set_config("{\"use_mmap\":" + mmap[use_mmap] + "}"); if (!setSuccess) { MNN_ERROR("use_mmap for LLM config set error\n"); @@ -828,9 +932,14 @@ static Llm* buildLLM(const std::string& config_path, int backend, int memory, in MNN_ERROR("tmp_path for LLM config set error\n"); return nullptr; } - setSuccess &= llmPtr->set_config("{\"prefer_decode\": false}"); // llm_bench use dynamic_option(-dyo) to control whether to use 'prefer_decode' + setSuccess &= llmPtr->set_config("{\"cpu_sme2_neon_division_ratio\":" + std::to_string(divisionRatioSme2Neon) + "}"); + if (!setSuccess) { + MNN_ERROR("cpu_sme2_neon_division_ratio for LLM config set error\n"); + return nullptr; + } + setSuccess &= llmPtr->set_config("{\"cpu_sme_core_num\":" + std::to_string(smeCoreNum) + "}"); if (!setSuccess) { - MNN_ERROR("prefer_decode for LLM config set error\n"); + MNN_ERROR("cpu_sme_core_num for LLM config set error\n"); return nullptr; } return llmPtr; @@ -868,7 +977,7 @@ int main(int argc, char ** argv) { auto executor = MNN::Express::Executor::newExecutor(MNN_FORWARD_CPU, backendConfig, 1); MNN::Express::ExecutorScope scope(executor); - auto llmPtr = buildLLM(instance.mCmdParam.model, instance.mCmdParam.backend, instance.mCmdParam.memory, instance.mCmdParam.precision, instance.mCmdParam.threads, instance.mCmdParam.power, instance.mCmdParam.dynamicOption, instance.mCmdParam.useMmap); + auto llmPtr = buildLLM(instance.mCmdParam.model, instance.mCmdParam.backend, instance.mCmdParam.memory, instance.mCmdParam.precision, instance.mCmdParam.threads, 
instance.mCmdParam.power, instance.mCmdParam.dynamicOption, instance.mCmdParam.useMmap, instance.mCmdParam.quantAttention, instance.mCmdParam.divisionRatioSme2Neon, instance.mCmdParam.smeCoreNum, instance.mCmdParam.nPrompt); std::unique_ptr llm(llmPtr); if (instance.mCmdParam.loadingTime == "true") { for (int k = 0; k < 3; ++k) { @@ -884,14 +993,14 @@ int main(int argc, char ** argv) { if (instance.mCmdParam.nGenerate > 0) { llm->set_config("{\"max_new_tokens\":1}"); } - + auto prompt_tokens = instance.mCmdParam.nPrompt; auto decodeTokens = instance.mCmdParam.nGenerate; // llm_demo test if (instance.mCmdParam.kvCache == "true") { std::vector tokens(prompt_tokens, 16); - + for (int i = 0; i < instance.mCmdParam.nRepeat + 1; ++i) { llm->response(tokens, nullptr, nullptr, decodeTokens); auto prefillTime = context->prefill_us; @@ -910,7 +1019,7 @@ int main(int argc, char ** argv) { // Cool std::this_thread::sleep_for(std::chrono::milliseconds(5)); } - + // llama.cpp llama-bench test if (instance.mCmdParam.kvCache == "false") { int tok = 16; @@ -931,7 +1040,7 @@ int main(int argc, char ** argv) { t.samplesUs.push_back(sampler_us); } } - + if (printHeader) { printer_->fout = outfile; printer_->printHeader(runtimeParams, testParams); @@ -940,7 +1049,7 @@ int main(int argc, char ** argv) { printer_->printPerformance(t); // Cool std::this_thread::sleep_for(std::chrono::milliseconds(5)); - + } } diff --git a/transformers/llm/eval/evaluate_perplexity.py b/transformers/llm/eval/evaluate_perplexity.py index 899523d46b..3819f56913 100644 --- a/transformers/llm/eval/evaluate_perplexity.py +++ b/transformers/llm/eval/evaluate_perplexity.py @@ -2,15 +2,17 @@ import argparse from tqdm import tqdm import MNN.llm as mnnllm -from datasets import load_dataset +from datasets import load_dataset, load_from_disk import torch import copy def main(args): # load model model = mnnllm.create(args.mnn_path) + model.set_config({"quant_qkv": args.quant_qkv}) + model.set_config({'all_logits': 
True}) model.load() - model.set_config({'all_logits': True, 'use_template': False}) + model.generate_init() # load dataset @@ -19,6 +21,7 @@ def main(args): dataset_dir = eval_dataset.split("/")[1] dataset = load_dataset(dataset_name, dataset_dir, split="test") + # dataset = load_from_disk("./wikitest-2-raw-v1") input_ids = model.tokenizer_encode("\n\n".join(dataset["text"])) stride = 512 context_length = stride + stride // 2 @@ -65,6 +68,12 @@ def main(args): group.add_argument( "-d", "--eval_dataset", type=str, default='wikitext/wikitext-2-raw-v1', help="Evaluation dataset, default is `wikitext/wikitext-2-raw-v1`." ) + group.add_argument( + "--quant-qkv", + type=int, + default=8, + help="Quantization bits for QKV, default is 8(not quant), if set 9, quant", + ) args = parser.parse_args() diff --git a/transformers/llm/export/llmexport.py b/transformers/llm/export/llmexport.py index b812d66236..56138091a2 100644 --- a/transformers/llm/export/llmexport.py +++ b/transformers/llm/export/llmexport.py @@ -1,10 +1,8 @@ import os import json import glob -import base64 import warnings import argparse -import importlib warnings.filterwarnings("ignore") os.environ['TOKENIZERS_PARALLELISM'] = 'false' @@ -12,19 +10,15 @@ import onnx import torch -import transformers -from packaging.version import Version -from typing import Optional, List -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer +from utils.model import LlmModel, EmbeddingModel +from utils.tokenizer import LlmTokenizer from utils.spinner import spinner_run from utils.custom_op import FakeLinear from utils.onnx_rebuilder import OnnxRebuilder -from utils.mnn_converter import MNNConveter +from utils.mnn_converter import MNNConverter from utils.awq_quantizer import AwqQuantizer from utils.smooth_quantizer import SmoothQuantizer -from utils.model_mapper import ModelMapper -from utils.transformers import Embedding, Rotary, Decoder, Lm from utils.torch_utils import onnx_export class 
LlmExporter(torch.nn.Module): @@ -37,17 +31,8 @@ def __init__(self, args): self.load_model(args.path) def init_from_args(self, args): - self.visual = None - self.audio = None - self.talker = None - self.mtp = None - self.scale_emb = None self.args = args - self.max_length = 1024 - self.stop_ids = [] - self.sliding_attn_layers = [] - self.attention_type = 'full' - self.sliding_window = 0 + self.max_new_tokens = 1024 self.dst_name = 'llm' # load config from args self.onnx_path = os.path.join(self.args.dst_path, 'onnx') @@ -55,111 +40,23 @@ def init_from_args(self, args): self.args.tokenizer_path = self.args.path if args.lm_quant_bit is None: self.args.lm_quant_bit = self.args.quant_bit + self.args.tie_word_embeddings = False # init export dst dir if not os.path.exists(self.args.dst_path): os.makedirs(self.args.dst_path) if not os.path.exists(self.onnx_path): os.makedirs(self.onnx_path) - def get_model_class(self, model_type: str): - MODEL_CLASS_MAPPING = { - 'qwen3_vl': 'Qwen3VLForConditionalGeneration', - 'qwen3_vl_moe': 'Qwen3VLMoeForConditionalGeneration', - 'qwen2_5_omni': 'Qwen2_5OmniForConditionalGeneration', - 'qwen2_5_vl': 'Qwen2_5_VLForConditionalGeneration', - 'qwen2_vl': 'Qwen2VLForConditionalGeneration', - 'qwen2_audio': 'Qwen2AudioForConditionalGeneration', - 'smolvlm': 'AutoModelForImageTextToText', - 'idefics3': 'AutoModelForVision2Seq', - } - if model_type is None or model_type not in MODEL_CLASS_MAPPING: - return AutoModelForCausalLM - class_name = MODEL_CLASS_MAPPING[model_type] - try: - module = importlib.import_module('transformers') - model_class = getattr(module, class_name) - return model_class - except (ImportError, AttributeError) as e: - print(f"Import '{class_name}' from transformer failed: {e}") - return AutoModelForCausalLM - - def load_pretrained(self, model_path: str): - try: - self.tokenizer = AutoTokenizer.from_pretrained(self.args.tokenizer_path, trust_remote_code=True, use_fast=False) - except: - try: - self.tokenizer = 
AutoTokenizer.from_pretrained(self.args.tokenizer_path, trust_remote_code=True) - except: - raise RuntimeError("Load tokenizer failed for ", model_path) - config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - model_type = getattr(config, 'model_type', None) - model_class = self.get_model_class(model_type) - if Version(transformers.__version__) >= Version("4.56.0"): - kwargs = { - 'dtype': 'auto', - 'trust_remote_code': True, - } - else: - kwargs = { - 'torch_dtype': 'auto', - 'trust_remote_code': True, - } - # special args - if model_type == 'internvl_chat': - kwargs.update({'use_flash_attn': False}) - try: - model = model_class.from_pretrained(model_path, **kwargs) - except: - try: - model = AutoModel.from_pretrained(model_path, **kwargs) - except: - raise RuntimeError("Load model failed for ", model_path) - # model & config - self.model = model.eval() - if model_type == 'qwen2_audio': - self.audio = self.model - self.model = self.audio.language_model - self.config = self.model.config - # LoRA - if self.args.lora_path is not None and not self.args.lora_split: - from peft import PeftModel - adapter = PeftModel.from_pretrained(self.model, model_id=self.args.lora_path) - self.model = adapter.merge_and_unload(progressbar=True) - - @staticmethod - def has_attr(obj, attr): - return hasattr(obj, attr) and getattr(obj, attr) is not None - @spinner_run(f'load pretrained model ', True) def load_model(self, model_path): - self.load_pretrained(model_path) - self.attention_mask_type = 'float' - # load tokenizer info - self.stop_ids.append(self.tokenizer.eos_token_id) - if hasattr(self.tokenizer, 'im_end_id'): - self.stop_ids.append(self.tokenizer.im_end_id) - try: - eot_id = self.tokenizer.encode('<|eot_id|>') - if len(eot_id) == 1: - self.stop_ids.append(eot_id[0]) - # gemma/gemma-2 - eot_id = self.tokenizer.encode('') - if len(eot_id) == 2 and eot_id[0] == 2: - self.stop_ids.append(eot_id[1]) - except: - pass - if hasattr(self.model, 'generation_config') 
and self.model.generation_config is not None: - eos_token_id = self.model.generation_config.eos_token_id - from collections.abc import Iterable - if isinstance(eos_token_id, int): - self.stop_ids.append(eos_token_id) - elif isinstance(eos_token_id, Iterable): - for id in eos_token_id: - self.stop_ids.append(id) - self.stop_ids = [stop_id for stop_id in self.stop_ids if stop_id is not None] - self.stop_ids = list(set(self.stop_ids)) - model_mapper = ModelMapper() - self.model_type, self.model_map = model_mapper.get_map(self.config) + self.model = LlmModel.from_pretrained(model_path, args=self.args) + self.tokenizer = LlmTokenizer.from_pretrained( + self.args.tokenizer_path, + model_type=self.model.config.model_type + ) + self.model.tokenizer = self.tokenizer + self.config = self.model.config + self.model_type = self.config.model_type if self.args.awq or self.args.smooth: self.model.float() @@ -171,465 +68,100 @@ def visit_module(module): for name, child in module.named_children(): visit_module(child) visit_module(self.model) - # print(self.config, self.model_type, self.model_map, self.model) - # load config info - ModelMapper.do_map(self, self.config, self.model_map['config']) - if not hasattr(self, 'num_key_value_heads') or self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - if not hasattr(self, 'rope_theta') or self.rope_theta is None: - self.rope_theta = 10000.0 - if not hasattr(self, 'rope_ratio') or self.rope_ratio is None: - self.rope_ratio = 1.0 - if not hasattr(self, 'head_dim') or self.head_dim is None: - if isinstance(self.num_attention_heads, list): - self.head_dim = [self.hidden_size // atten_head for atten_head in self.num_attention_heads] - else: - self.head_dim = self.hidden_size // self.num_attention_heads - # sliding attention layers - if hasattr(self, 'layer_types'): # gpt_oss - for i in range(len(self.layer_types)): - if self.layer_types[i] == 'sliding_attention': - self.sliding_attn_layers.append(i) - if 
len(self.sliding_attn_layers) >= self.num_hidden_layers: - self.attention_type = 'sliding' - elif len(self.sliding_attn_layers) > 0: - self.attention_type = 'mix' + # some export info - if isinstance(self.num_attention_heads, list): - self.past_kv_shape = [self.num_hidden_layers, 2, 1, 0, self.num_key_value_heads[0], self.head_dim] + if isinstance(self.config.num_attention_heads, list): + self.past_kv_shape = [self.config.num_hidden_layers, 2, 1, 0, self.config.num_key_value_heads[0], self.config.head_dim] else: - self.past_kv_shape = [self.num_hidden_layers, 2, 1, 0, self.num_key_value_heads, self.head_dim] - self.block_dynamic_axes = { - "inputs_embeds" : { 0: "seq_len" }, - "attention_mask" : { 2: "seq_len", 3: "seq_len" }, - "position_ids" : { 0: "seq_len" }, - "past_key_values" : { 1: "history_len" } - } + self.past_kv_shape = [self.config.num_hidden_layers, 2, 1, 0, self.config.num_key_value_heads, self.config.head_dim] + self.model_dynamic_axes = { "input_ids" : { 0: "seq_len" }, "attention_mask" : { 2: "seq_len", 3: "seq_len" }, "position_ids" : { 1: "seq_len" }, "past_key_values" : { 3: "history_len" } } - prompt_template = self.build_prompt_template() + self.llm_config = { - 'hidden_size' : self.hidden_size, - 'layer_nums' : self.num_hidden_layers, - 'attention_mask': self.attention_mask_type, - 'key_value_shape': self.past_kv_shape[1:], - "bos": prompt_template['bos'], - "system_prompt_template": prompt_template['system'].format(content='%s'), - 'user_prompt_template': prompt_template['user'].format(content='%s'), - 'assistant_prompt_template': prompt_template['assistant'].format(content='%s'), - 'is_visual': False, - 'attention_type': self.attention_type, + 'model_type': self.config.model_type, + 'hidden_size' : self.config.hidden_size, + 'attention_mask': 'float', # Will be determined by model later + 'attention_type': self.config.attention_type, } - if self.sliding_window > 0: - self.llm_config['sliding_window'] = self.sliding_window - if 'jinja' in 
prompt_template: - self.llm_config['jinja'] = prompt_template['jinja'] - # load modules - ModelMapper.do_map(self, self.model, self.model_map['model']) - - # rebuild modules - if self.lm_ is None: - out_features, in_features = self.embed_.weight.shape - self.lm_ = torch.nn.Linear(in_features, out_features) - self.lm_.weight = self.embed_.weight - elif not isinstance(self.lm_, torch.nn.Linear): - # for Baichuan2 - weight = self.lm_.weight - out_features, in_features = weight.shape - self.lm_ = torch.nn.Linear(in_features, out_features) - self.lm_.weight = weight - self.lm_.bias.data = torch.zeros(out_features, dtype=weight.dtype) - - if self.embed_.weight is self.lm_.weight: - import copy - embed_copy = copy.deepcopy(self.embed_) - self.embed = Embedding(embed_copy, self) - else: - self.embed = Embedding(self.embed_, self) + self.llm_config.update(self.model.get_config()) + if self.config.sliding_window > 0: + self.llm_config['sliding_window'] = self.config.sliding_window + if hasattr(self.tokenizer, 'get_chat_template'): + self.llm_config['jinja'] = { + 'chat_template': self.tokenizer.get_chat_template() + } + if self.tokenizer.bos_token: + self.llm_config['jinja']['bos'] = self.tokenizer.bos_token + if self.tokenizer.eos_token: + self.llm_config['jinja']['eos'] = self.tokenizer.eos_token # tie word embeddings - self.tie_word_embeddings = not self.args.seperate_embed and self.lm_.weight.equal(self.embed_.weight) - if self.tie_word_embeddings: - print("Tie word embeddings in lm, set lm quant bit to 8") - self.args.lm_quant_bit = 8 - - if 'gemma' in self.model_type: - self.scale_emb = self.embed.embed_scale + self.args.tie_word_embeddings = not self.args.seperate_embed and self.model.lm.lm.weight.equal(self.model.embed.embed.weight) + # Pass properties from model to exporter + self.visual = self.model.visual + self.audio = self.model.audio + self.talker = self.model.talker + self.mtp = self.model.mtp + self.scale_emb = self.model.scale_emb - # Rotary - self.rotary = 
Rotary(self) - self.blocks = [] - for block in self.blocks_.children(): - layer_id = len(self.blocks) - self.blocks.append(Decoder(block, layer_id, self)) - self.lm = Lm(self.lm_) - - # visual model - if self.visual is not None: - if self.args.export is not None: - self.visual.float() - from utils.vision import Vision - self.visual = Vision.get_vision(self.model_type)(self.visual, self) - if self.args.export is not None: - self.visual.float() - if hasattr(self, 'audio') and self.audio is not None: - from utils.audio import Audio - self.audio = Audio.get_audio(self.audio.config.model_type)(self.audio, self) - else: - self.audio = None - # talker model - if hasattr(self, 'talker') and self.talker is not None and \ - hasattr(self, 'token2wav') and self.token2wav is not None: - from utils.talker import Talker - self.talker = Talker.get_talker(self.model_type)(self.talker, self.token2wav, self) - # MTP model - if self.model_type == 'poi_qwen2_mtp': - self.mtp = [self.mtp1, self.mtp2] - if self.mtp is not None: - if self.args.export is not None: - for mtp_model in self.mtp: - mtp_model.float() - from utils.mtp import Mtp - self.mtp = Mtp.get_mtp(self.model_type)(self.mtp, self) return model_path - def full_attention_mask(self): - if self.token_len: - return torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32) - return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min - - def sliding_attention_mask(self, sliding_window: int): - if self.token_len: - sliding_mask = torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32) - num_tokens_to_mask = self.seq_len - sliding_window - if num_tokens_to_mask > 0: - sliding_mask[..., :num_tokens_to_mask] = torch.finfo(torch.float32).min - return sliding_mask - causal_mask = torch.tril(torch.ones(self.seq_len, self.seq_len, dtype=torch.bool)) - query_indices = torch.arange(self.seq_len).view(-1, 1) - key_indices = torch.arange(self.seq_len).view(1, -1) - window_mask = (key_indices > 
query_indices - sliding_window) - final_mask_bool = causal_mask & window_mask - sliding_mask = torch.where(final_mask_bool, 0.0, torch.finfo(torch.float32).min) - return sliding_mask.view(1, 1, self.seq_len, self.seq_len) - - def get_attention_mask(self) -> torch.Tensor: - if self.model_type == 'chatglm': - return self.chatglm_attention_mask() - if self.attention_type == 'full': - return self.full_attention_mask() - elif self.attention_type == 'sliding': - return self.sliding_attention_mask(self.sliding_window) - elif self.attention_type == 'mix': - full_mask = self.full_attention_mask() - sliding_mask = self.sliding_attention_mask(self.sliding_window) - return torch.stack([full_mask, sliding_mask], dim=0) - - return None - - def get_position_ids(self, input_ids = None) -> torch.Tensor: - if self.visual is not None and hasattr(self.visual, 'get_position_ids') and callable(getattr(self.visual, 'get_position_ids')): - return self.visual.get_position_ids(input_ids, self.seq_len, self.token_len) - if self.model_type == 'chatglm': - return self.chatglm_position_ids() - if self.token_len: - return torch.tensor([[self.seq_len - 1]], dtype=torch.int) - return torch.arange(self.seq_len, dtype=torch.int).unsqueeze(0) - - def chatglm_attention_mask(self): - if self.token_len: - return torch.zeros([1]).bool().reshape([1, 1, 1, 1]) - attention_mask = torch.zeros([self.seq_len, self.seq_len], dtype=torch.bool) - for i in range(self.seq_len - 1): - attention_mask[i][-1] = True - attention_mask = attention_mask.reshape([1, 1, self.seq_len, self.seq_len]) - return attention_mask - - def chatglm_position_ids(self): - if self.token_len: - return torch.tensor([self.context_len, self.token_len + 1]).reshape([1, 2, 1]) - position_ids_0 = torch.arange(self.seq_len, dtype=torch.int) - position_ids_1 = torch.zeros(self.seq_len, dtype=torch.int) - position_ids_0[-1] = position_ids_0[-2] - position_ids_1[-1] = 1 - position_ids = torch.stack([position_ids_0, position_ids_1]).view(1, 2, -1) - 
return position_ids - - def visual_embed(self, input_ids): - return self.visual.embed(input_ids) - - def audio_embed(self, input_ids): - return self.audio.embed(input_ids) - - def embedding(self, input_ids): - if self.visual is not None and self.token_len == 0: - input_embeds = self.visual_embed(input_ids) - elif self.audio is not None and self.token_len == 0: - input_embeds = self.audio_embed(input_ids) - else: - input_embeds = self.embed(input_ids) - return input_embeds - - def forward(self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - position_ids: torch.Tensor, - past_key_values: Optional[List[torch.Tensor]] = None, - logits_index: int = -1, - deepstack_embeds: torch.Tensor = None - ): - hidden_states = input_ids # llm forward without embedding - if self.scale_emb is not None: - hidden_states = hidden_states * self.scale_emb - presents = [None for i in range(len(self.blocks))] - eagle_hidden_states = [] - rotary_pos_emb = self.rotary(position_ids) - if self.args.test and rotary_pos_emb.dtype != hidden_states.dtype: - rotary_pos_emb = rotary_pos_emb.type(hidden_states.dtype) - - for i in range(len(self.blocks)): - # eagle3 hidden states - if i == len(self.blocks)-3 or i == len(self.blocks)//2 or i==2: - eagle_hidden_states.append(hidden_states) - if past_key_values is not None and past_key_values[i] is not None: - past_kv = past_key_values[i] - else: - past_kv = None - - # sliding or full attn mask - if self.attention_type == 'mix': - is_sliding = i in self.sliding_attn_layers - layer_attention_mask = attention_mask[int(is_sliding)] - else: - layer_attention_mask = attention_mask - - hidden_states, kv = self.blocks[i](hidden_states, rotary_pos_emb, layer_attention_mask, past_kv) - presents[i] = kv - if deepstack_embeds is not None and i in range(deepstack_embeds.shape[0]): - hidden_states += deepstack_embeds[i] - - talker_embeds = None - if hasattr(self, 'talker') and self.talker is not None: - talker_embeds = self.final_layernorm_(hidden_states) 
+ input_ids.permute([1, 0, 2]) - self.talker.add_talker_embeds(talker_embeds) - - final_layernorm = hidden_states - logits_index_long = logits_index.to(torch.int64) - if self.mtp is None: - hidden_states = hidden_states[:, logits_index_long:, :] - hidden_states = self.final_layernorm_(hidden_states) - # default: set hidden_state before lm_head as output node - final_layernorm = hidden_states - else: - # final_layernorm need compute all logists - if self.model_type == 'mimo': - final_layernorm = hidden_states # mimo - hidden_states = self.final_layernorm_(hidden_states) - if self.model_type == 'poi_qwen2_mtp': - final_layernorm = hidden_states # poi - hidden_states = hidden_states[:, logits_index_long:, :] - logits = self.lm(hidden_states) - if presents[0].shape == presents[-1].shape and None not in presents: - presents = torch.stack(presents) - self.seq_len += 1 - self.token_len += 1 - - if self.args.eagle_path is not None: - final_layernorm = torch.cat(eagle_hidden_states, dim=-1) - - return logits, final_layernorm, presents, talker_embeds - - # some test functions - def build_prompt_template(self): - template = { - 'bos': '', - 'system': '{content}', - 'user': '{content}', - 'assistant': '{content}' - } - if hasattr(self.tokenizer, 'get_chat_template'): - template['jinja'] = {} - template['jinja']['chat_template'] = self.tokenizer.get_chat_template() - if None != self.tokenizer.bos_token: - template['jinja']['bos'] = self.tokenizer.bos_token - if None != self.tokenizer.eos_token: - template['jinja']['eos'] = self.tokenizer.eos_token - if self.model_type == 'baichuan': - template['user'] = '{content}' - template['assistant'] = '{content}' - if self.model_type == 'chatglm': - template['user'] = '{content}[gMASK]' - if self.model_type == 'chatglm2' and 'codegeex' not in self.args.path: - template['user'] = '[Round 1]\n\n问:{content}\n\n' - template['assistant'] = '答:{content}\n\n' - if 'chatglm3' in self.args.path or 'glm-4' in self.args.path: - template['bos'] = 
'[gMASK]' - template['system'] = '<|system|>\n{content}\n' - template['user'] = '<|user|>\n{content}\n' - template['assistant'] = '<|assistant|>\n{content}\n' - if self.model_type == 'llama': - if 'Llama-2' in self.args.path: - template['bos'] = '[INST] ' - template['system'] = "<>\n{content}\n<>\n\n" - template['user'] = '{content} [/INST]' - template['assistant'] = "{content}"; - if 'Llama-3' in self.args.path: - template['system'] = '<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>' - template['user'] = '<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>' - template['assistant'] = '<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>' - if 'TinyLlama' in self.args.path: - template['bos'] = '' - template['system'] = '<|system|>\n{content}\n' - template['user'] = '<|user|>\n{content}\n' - template['assistant'] = '<|assistant|>\n{content}\n' - if 'Yi' in self.args.path or 'SmolLM2' in self.args.path: - template['system'] = '<|im_start|>system\n{content}<|im_end|>\n' - template['user'] = '<|im_start|>user\n{content}<|im_end|>\n' - template['assistant'] = '<|im_start|>assistant\n{content}<|im_end|>\n' - if self.model_type == 'gemma2': - template['bos'] = '' - template['system'] = 'system\n{content}\n' - template['user'] = 'user\n{content}\n' - template['assistant'] = 'model\n{content}\n' - if self.model_type == 'gemma': - template['bos'] = '' - if self.model_type == 'internlm': - template['user'] = '<|User|>:{content}\n' - template['assistant'] = '<|Bot|>:{content}\n' - if self.model_type == 'phi-msft': - template['user'] = 'Instruct: {content}\n' - template['assistant'] = 'Output:{content}\n' - if self.model_type == 'openelm': - template['bos'] = '' - if self.model_type == 'internvl_chat': - if 'Qwen' in self.config.llm_config._name_or_path: - print("[DEBUG] Use qwen prompt template") - template['system'] = '<|im_start|>system\n{content}<|im_end|>\n' - template['user'] = '<|im_start|>user\n{content}<|im_end|>\n' - 
template['assistant'] = '<|im_start|>assistant\n{content}<|im_end|>\n' - if self.model_type == 'phi3': - template['system'] = '<|im_start|>system<|im_sep|>{content}<|im_end|>' - template['user'] = '<|im_start|>user<|im_sep|>{content}<|im_end|>' - template['assistant'] = '<|im_start|>assistant<|im_sep|>{content}<|im_end|>' - - if self.model_type == "gemma3": - template['bos'] = 'user\n' - template['system'] = '{content}\n\n' - template['user'] = '{content}\n' - template['assistant'] = 'model\n{content}\nuser\n' - if self.model_type == "gemma3_text": - template['bos'] = 'user\n' - template['system'] = '{content}\n\n' - template['user'] = '{content}\n' - template['assistant'] = 'model\n{content}\nuser\n' - if self.model_type in ['idefics3', 'smolvlm']: - template['bos'] = '<|im_start|>' - template['system'] = 'System: {content}\n' - template['user'] = 'User:{content}\n' - template['assistant'] = 'Assistant:{content}\n' - if 'qwen' in self.model_type or 'mimo' in self.model_type: - template['system'] = '<|im_start|>system\n{content}<|im_end|>\n' - template['user'] = '<|im_start|>user\n{content}<|im_end|>\n' - template['assistant'] = '<|im_start|>assistant\n{content}<|im_end|>\n' - if 'DeepSeek' in self.args.path or 'deepseek' in self.args.path: - template['bos'] = '<|begin_of_sentence|>' - template['system'] = '{content}\n' - template['user'] = '\nUser: {content}\n' - template['assistant'] = '\nAssistant: {content}\n<|end_of_sentence|>' - if self.model_type == "ernie4_5": - template['bos'] = '<|begin_of_sentence|>' - template['user'] = 'User: {content}\n' - template['assistant'] = 'Assistant: {content}\n<|end_of_sentence|>' - return template - - def build_prompt(self, messages): - if hasattr(self.tokenizer, 'apply_chat_template'): - prompt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - ) - return prompt - template = self.build_prompt_template() - prompt = template['bos'] - for item in messages: - role, content = 
item['role'], item['content'] - if '{content}' in template[role]: - prompt += template[role].format(content=content) - else: - prompt += role + '\n' + content +'\n' - assistant_prefix = template['assistant'].split('{content}')[0] - return prompt + assistant_prefix - - def str_to_ids(self, prompt): - if self.visual is not None: - return self.visual.str_to_ids(prompt) - if self.audio is not None: - return self.audio.str_to_ids(prompt) - input_ids = self.tokenizer(prompt, add_special_tokens=False, return_tensors="pt")['input_ids'] - return input_ids - - def id_to_str(self, token_id): - try: - word = self.tokenizer.decode(int(token_id)) - except: - def contains_replacement(text): return '\uFFFD' in text - def decode_id(token_id): - return self.tokenizer.convert_tokens_to_string( - self.tokenizer._convert_id_to_token(int(token_id))) - def decode_ids(token_ids): - return self.tokenizer.convert_tokens_to_string( - self.tokenizer.convert_ids_to_tokens(token_ids)) - word = decode_id(int(token_id)) - # Smollm tokenizer will produce half chinese character, using buffer to decode - if contains_replacement(word): - self.decode_buffer.append(token_id) - buffer_txt = decode_ids(self.decode_buffer) - if not contains_replacement(buffer_txt): - word = buffer_txt - self.decode_buffer.clear() - else: - word = '' - return word - @torch.no_grad() def response(self, query): # self.imitate_quant() - self.decode_buffer = [] + self.model.decode_buffer = [] messages = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": query} ] - prompt = self.build_prompt(messages) - input_ids = self.str_to_ids(prompt) - self.seq_len = input_ids.numel() - self.context_len = self.seq_len - 2 - self.token_len = 0 - past_key_values = [None for i in range(self.num_hidden_layers)] - token_id = input_ids - while self.token_len < self.max_length: - attention_mask = self.get_attention_mask() - position_ids = self.get_position_ids(token_id) - input_ids = 
self.embedding(token_id) - deepstack_embeds = self.visual.deepstacks() if self.visual is not None else None - logits, _, past_key_values, _ = self.forward(input_ids, - attention_mask, - position_ids, - past_key_values, - deepstack_embeds = deepstack_embeds) + prompt = self.tokenizer.apply_chat_template(messages) + + # Use model's tokenizer methods for encoding + if self.model.visual is not None: + input_ids = self.model.visual.str_to_ids(prompt) + elif self.model.audio is not None: + input_ids = self.model.audio.str_to_ids(prompt) + else: + input_ids = self.tokenizer(prompt, add_special_tokens=False, return_tensors="pt")['input_ids'] + + seq_len = input_ids.numel() + new_tokens = 0 + past_key_values = [None for i in range(self.config.num_hidden_layers)] + + while new_tokens < self.max_new_tokens: + attention_mask = self.model.get_attention_mask(seq_len, new_tokens) + position_ids = self.model.get_position_ids(seq_len, new_tokens, input_ids) + input_embeds = self.model.embedding(input_ids) + deepstack_embeds = self.model.visual.deepstacks() if self.model.visual is not None else None + + logits, _, past_key_values, _ = self.model.forward( + input_ids=input_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + logits_index=torch.tensor([-1], dtype=torch.int32), + deepstack_embeds=deepstack_embeds + ) + token_id = torch.argmax(logits[:,-1,:]) - if token_id in self.stop_ids: + seq_len += 1 + new_tokens += 1 + if token_id in self.tokenizer.stop_ids: print("", end='\n') break - word = self.id_to_str(token_id) + + # Use tokenizer's method for decoding + word = self.tokenizer.id_to_str(token_id) print(word, end="", flush=True) + input_ids = token_id - if hasattr(self, 'talker') and self.talker is not None: - self.talker.generate() + if hasattr(self.model, 'talker') and self.model.talker is not None: + self.model.talker.generate() def export_mtp(self): if self.mtp is None: @@ -637,17 +169,17 @@ def export_mtp(self): 
mtp_onnx = self.mtp.export(self.onnx_path) if self.mnn_converter: self.mtp.unloaded_ops['/lm/lm_head/Linear'] = self.unloaded_ops['/lm/lm_head/Linear'] - MNNConveter(self, self.mtp.unloaded_ops).export(mtp_onnx) + MNNConverter(self, self.mtp.unloaded_ops).export(mtp_onnx) def export_eagle(self): if self.args.eagle_path is None: return from utils.eagle import Eagle - self.eagle = Eagle.get_eagle(self.model_type)(self.args.eagle_path, self) + self.eagle = Eagle.get_eagle(self.model_type)(self.args.eagle_path, self.model) eagle_onnx, eagle_fc_onnx = self.eagle.export(self.onnx_path) if self.mnn_converter: - MNNConveter(self, None).export(eagle_onnx) - MNNConveter(self, None).export(eagle_fc_onnx) + MNNConverter(self, None).export(eagle_onnx) + MNNConverter(self, None).export(eagle_fc_onnx) @spinner_run(f'export embedding to ') @@ -655,11 +187,11 @@ def export_embed(self): import ctypes from utils.torch_utils import quant as torch_quant - if hasattr(self, 'word_embeddings'): + if hasattr(self.model, 'word_embeddings'): # embedding model's embed - tensor_data = self.word_embeddings.weight.data + tensor_data = self.model.word_embeddings.weight.data else: - tensor_data = self.embed.embed.weight.data + tensor_data = self.model.embed.embed.weight.data format_bit = getattr(self.args, 'embed_bit', 16) @@ -717,7 +249,7 @@ def export_config(self, mnn_config = False): } if self.args.embed_bit < 16: config['embedding_file'] = f"embeddings_int{self.args.embed_bit}.bin" - if self.talker is not None: + if hasattr(self, 'talker') and self.talker is not None: config['system_prompt'] = "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech." 
config['talker_max_new_tokens'] = 2048 config['talker_speaker'] = "Chelsie" @@ -725,7 +257,7 @@ def export_config(self, mnn_config = False): config['dit_solver'] = 1 if self.model_type == "gemma3": config.update({'precision': "normal"}) - if self.visual is not None or self.audio is not None: + if (hasattr(self, 'visual') and self.visual is not None) or (hasattr(self, 'visual') and self.audio is not None): config['mllm'] = { 'backend_type': "cpu", "thread_num": 4, @@ -761,14 +293,14 @@ def quant_dequant(linear, quant_bit = self.args.quant_bit, quant_block = self.ar linear.weight.data = dq_weight return linear with torch.no_grad(): - for i in range(self.num_hidden_layers): - for name, child in self.blocks[i].self_attn.named_children(): + for i in range(self.config.num_hidden_layers): + for name, child in self.model.blocks[i].self_attn.named_children(): if isinstance(child, torch.nn.Linear): - setattr(self.blocks[i].self_attn, name, quant_dequant(child)) - for name, child in self.blocks[i].mlp.named_children(): + setattr(self.model.blocks[i].self_attn, name, quant_dequant(child)) + for name, child in self.model.blocks[i].mlp.named_children(): if isinstance(child, torch.nn.Linear): - setattr(self.blocks[i].mlp, name, quant_dequant(child)) - self.lm.lm = quant_dequant(self.lm.lm) + setattr(self.model.blocks[i].mlp, name, quant_dequant(child)) + self.model.lm.lm = quant_dequant(self.model.lm.lm) def unload_param(self): self.unloaded_ops = {} @@ -779,26 +311,26 @@ def build_faker(real, name): return faker # replace linear with fakelinear to save export memory and time with torch.no_grad(): - for i in range(len(self.blocks)): + for i in range(len(self.model.blocks)): # different kv cache shape in different layers - if isinstance(self.num_attention_heads, list): - self.blocks[i].self_attn.export_fused_attn = True - is_moe = hasattr(self.blocks[i].mlp, 'is_moe') and self.blocks[i].mlp.is_moe + if isinstance(self.config.num_attention_heads, list): + 
self.model.blocks[i].self_attn.export_fused_attn = True + is_moe = hasattr(self.model.blocks[i].mlp, 'is_moe') and self.model.blocks[i].mlp.is_moe if is_moe: - self.blocks[i].mlp.export_moe = True - for name, child in self.blocks[i].self_attn.named_children(): + self.model.blocks[i].mlp.export_moe = True + for name, child in self.model.blocks[i].self_attn.named_children(): if isinstance(child, torch.nn.Linear): - setattr(self.blocks[i].self_attn, name, build_faker(child, f'/layers.{i}/self_attn/{name}/Linear')) - for name, child in self.blocks[i].mlp.named_children(): + setattr(self.model.blocks[i].self_attn, name, build_faker(child, f'/layers.{i}/self_attn/{name}/Linear')) + for name, child in self.model.blocks[i].mlp.named_children(): if isinstance(child, torch.nn.Linear): - setattr(self.blocks[i].mlp, name, build_faker(child, f'/layers.{i}/mlp/{name}/Linear')) + setattr(self.model.blocks[i].mlp, name, build_faker(child, f'/layers.{i}/mlp/{name}/Linear')) if is_moe and isinstance(child, torch.nn.ModuleList): # experts self.experts.append(child) for j in range(len(child)): for name, cchild in child[j].named_children(): if isinstance(cchild, torch.nn.Linear): - setattr(self.blocks[i].mlp.experts[j], name, build_faker(cchild, f'/expert/{i}_{j}/{name}')) - self.lm.lm = build_faker(self.lm.lm, f'/lm/lm_head/Linear') + setattr(self.model.blocks[i].mlp.experts[j], name, build_faker(cchild, f'/expert/{i}_{j}/{name}')) + self.model.lm.lm = build_faker(self.model.lm.lm, f'/lm/lm_head/Linear') @spinner_run(f'export model weight to ') def onnx_load_param(self, onnx_path): @@ -815,18 +347,18 @@ def slim_onnx(self, onnx_model): def export_onnx(self): # unload linear weight to save export memory self.unload_param() - model = self - self.seq_len = 3 - self.token_len = 0 - input_ids = torch.arange(3, dtype=torch.long) - attention_mask = self.get_attention_mask() - position_ids = self.get_position_ids(input_ids) + model = self.model + seq_len = 3 + new_tokens = 0 + input_ids = 
torch.arange(seq_len, dtype=torch.long) + attention_mask = model.get_attention_mask(seq_len, new_tokens) + position_ids = model.get_position_ids(seq_len, new_tokens, input_ids) onnx_model = f'{self.onnx_path}/{self.dst_name}.onnx' # For export onnx, don't need image or audio's embedding - input_ids = self.embed(input_ids) + input_ids = model.embedding(input_ids) past_key_values = torch.zeros(self.past_kv_shape) logits_index = torch.tensor([-1], dtype=torch.int32) - if hasattr(self, 'talker') and self.talker is not None: + if hasattr(model, 'talker') and model.talker is not None: output_names = ['logits', 'hidden_states', 'presents', 'talker_embeds'] else: output_names = ['logits', 'hidden_states', 'presents'] @@ -834,7 +366,7 @@ def export_onnx(self): # Qwen3-VL if self.model_type in ['qwen3_vl', 'qwen3_vl_moe']: # add deepstack_embeds input - deepstack_embeds = torch.randn(3, 1, self.hidden_size) + deepstack_embeds = torch.randn(3, 1, self.config.hidden_size) onnx_export( model, (input_ids, attention_mask, position_ids, past_key_values, logits_index, deepstack_embeds), onnx_model, @@ -857,14 +389,12 @@ def export_onnx(self): return onnx_model def awq_quant(self): - self.awq_quantizer = AwqQuantizer(self) + self.awq_quantizer = AwqQuantizer(self.model) self.awq_quantizer.quantize() - self.is_awq_quantized = True def smooth_quant(self): - self.smooth_quantizer = SmoothQuantizer(model = self, act_bit=self.args.act_bit, act_sym=self.args.act_sym) + self.smooth_quantizer = SmoothQuantizer(model = self.model, act_bit=self.args.act_bit, act_sym=self.args.act_sym) self.smooth_quantizer.quantize() - self.is_smooth_quantized = True def export_vision(self): if self.visual is None: @@ -908,7 +438,7 @@ def export_talker(self): def export_language(self): # export_embedding - if self.mnn_converter and self.tie_word_embeddings: + if self.mnn_converter and self.args.tie_word_embeddings: pass # mnn tie_word_embeddings need't export embedding else: self.export_embed() @@ -918,7 
+448,9 @@ def export_language(self): if self.args.onnx_slim: self.slim_onnx(onnx_model) if self.mnn_converter: - MNNConveter(self, self.unloaded_ops).export(onnx_model) + tie_embeddings_info = MNNConverter(self, self.unloaded_ops).export(onnx_model) + if tie_embeddings_info is not None: + self.llm_config['tie_embeddings'] = tie_embeddings_info else: self.onnx_load_param(onnx_model) @@ -927,10 +459,8 @@ def export(self, export_type): self.awq_quant() if self.args.smooth: self.smooth_quant() - if self.args.hqq and self.args.sym: - self.args.sym = False export_mnn = export_type == 'mnn' - self.mnn_converter = MNNConveter(self) if export_mnn else None + self.mnn_converter = MNNConverter(self) if export_mnn else None self.export_talker() self.export_vision() self.export_audio() @@ -950,391 +480,43 @@ def export(self, export_type): @spinner_run(f'export tokenizer to ') def export_tokenizer(self): - # load tokenizer file - tokenizer_model = os.path.join(self.args.tokenizer_path, 'tokenizer.model') - ice_text_model = os.path.join(self.args.tokenizer_path, 'ice_text.model') - try: - import sentencepiece as spm - if os.path.exists(tokenizer_model): - self.sp_model = spm.SentencePieceProcessor(tokenizer_model) - elif os.path.exists(ice_text_model): - self.sp_model = spm.SentencePieceProcessor(ice_text_model) - else: - self.sp_model = None - except: - self.sp_model = None - merge_file = os.path.join(self.args.path, 'merges.txt') - if os.path.exists(merge_file): - self.merge_txt = merge_file - else: - self.merge_txt = None - # TOKENIZER MAGIC NUMBER - MAGIC_NUMBER = 430 - # TOKENIZER TYPE - SENTENCEPIECE = 0; TIKTOIKEN = 1; BERT = 2; HUGGINGFACE = 3 - def write_line(fp, *args): - for arg in args: - for token in arg: - fp.write(str(token) + ' ') - fp.write('\n') - def write_header(fp, type, speicals, prefix = []): - fp.write(f'{MAGIC_NUMBER} {type}\n') - fp.write(f'{len(speicals)} {len(self.stop_ids)} {len(prefix)}\n') - write_line(fp, speicals, self.stop_ids, prefix) - - 
file_path = os.path.join(self.args.dst_path, "tokenizer.txt") - special_list = list(self.tokenizer.added_tokens_decoder.keys()) - if hasattr(self.tokenizer, 'special_tokens'): - for k, v in self.tokenizer.special_tokens.items(): - special_list.append(v) - if hasattr(self.tokenizer, 'all_special_ids'): #gemma3 - special_list.extend(self.tokenizer.all_special_ids) - if hasattr(self.tokenizer, 'gmask_token_id'): - special_list.append(self.tokenizer.gmask_token_id) - if hasattr(self.model, 'generation_config') and self.model.generation_config is not None: - generation_config = self.model.generation_config - if hasattr(generation_config, 'user_token_id'): - special_list.append(generation_config.user_token_id) - if hasattr(generation_config, 'assistant_token_id'): - special_list.append(generation_config.assistant_token_id) - vocab_list = [] - prefix_list = [] - if hasattr(self.tokenizer, 'get_prefix_tokens'): - prefix_list = self.tokenizer.get_prefix_tokens() - - # Simple prefix token detection - if len(prefix_list) == 0: - try: - test_txt = 'A' - ids = self.tokenizer.encode(test_txt) - get_txt = self.tokenizer.decode(ids[-1]) - if len(ids) > 1 and get_txt == test_txt: - prefix_list += ids[:-1] - except Exception: - pass - - if self.sp_model is not None: - # senetencepiece - NORMAL = 1; UNKNOWN = 2; CONTROL = 3 - USER_DEFINED = 4; UNUSED = 5; BYTE = 6 - for i in range(self.sp_model.GetPieceSize()): - token = self.sp_model.IdToPiece(i) - score = self.sp_model.GetScore(i) - token_type = NORMAL - if self.sp_model.IsUnknown(i): - token_type = UNKNOWN - elif self.sp_model.IsControl(i): - token_type = CONTROL - elif self.sp_model.IsUnused(i): - token_type = UNUSED - elif self.sp_model.IsByte(i): - token_type = BYTE - if self.args.path == 'Chatglm_6b': - if '' in token: token = '\n' - if '<|tab|>' in token: token = '\t' - if '<|blank_' in token: token = ' ' * int(token[8:token.find('|>')]) - if '▁' in token: token = token.replace('▁', ' ') - token_encode = 
base64.b64encode(token.encode("utf-8")).decode("utf8") - vocab_list.append(f'{token_encode} {score} {token_type}\n') - # add special tokens to vocab_list - for index in special_list: - if index >= len(vocab_list): - try: - token = self.tokenizer.decode(index) - token_encode = base64.b64encode(token.encode("utf-8")).decode("utf8") - vocab_list.append(f'{token_encode} {0} {NORMAL}\n') - except: - pass - with open(file_path, "w", encoding="utf8") as fp: - write_header(fp, SENTENCEPIECE, special_list, prefix_list) - if self.model_type == "gemma3" or self.model_type == "gemma3-text": - fp.write(f'{len(vocab_list) + 1}\n') # len(vocab_list)==262144, self.tokenizer([262144])=='image_soft_token' is a special token - else: - fp.write(f'{len(vocab_list)}\n') - for vocab in vocab_list: - fp.write(vocab) - elif hasattr(self.tokenizer, 'mergeable_ranks'): - # tikton - vocab_list = [] - for k, v in self.tokenizer.mergeable_ranks.items(): - line = base64.b64encode(k).decode("utf8") + "\n" - vocab_list.append(line) - if hasattr(self.tokenizer, 'special_tokens'): - for k, v in self.tokenizer.special_tokens.items(): - line = base64.b64encode(k.encode("utf-8")).decode("utf8") + "\n" - vocab_list.append(line) - if hasattr(self.tokenizer, 'added_tokens_decoder'): - for k, v in self.tokenizer.added_tokens_decoder.items(): - line = base64.b64encode(v.__str__().encode("utf-8")).decode("utf8") + "\n" - vocab_list.append(line) - with open(file_path, "w", encoding="utf8") as fp: - write_header(fp, TIKTOIKEN, special_list, prefix_list) - fp.write(f'{len(vocab_list)}\n') - for vocab in vocab_list: - fp.write(vocab) - elif self.merge_txt is not None: - # huggingface tokenizer - merge_list = [] - vocab = self.tokenizer.get_vocab() - special_list = list(self.tokenizer.added_tokens_decoder.keys()) - vocab_list = ['' for i in range(len(vocab))] - # load vocab - for k, v in vocab.items(): - vocab_list[int(v)] = k - # load merge - with open(self.merge_txt, 'rt') as merge: - for line in 
merge.readlines(): - merge_list.append(line) - # write to tokenizer.txt - with open(file_path, "w", encoding="utf8") as fp: - write_header(fp, HUGGINGFACE, special_list) - fp.write(f'{len(vocab_list)} {len(merge_list)}\n') - for v in vocab_list: - fp.write(v + '\n') - for m in merge_list: - fp.write(m) - else: - # Determine tokenizer type based on tokenizer class and characteristics - tokenizer_class_name = type(self.tokenizer).__name__.lower() - vocab = self.tokenizer.get_vocab() - - # Check for SentencePiece-based tokenizers first - if ('xlmroberta' in tokenizer_class_name or - 'roberta' in tokenizer_class_name or - 'sentencepiece' in tokenizer_class_name or - hasattr(self.tokenizer, 'sp_model') or - (hasattr(self.tokenizer, 'vocab_file') and - self.tokenizer.vocab_file and 'sentencepiece' in self.tokenizer.vocab_file.lower()) or - # Check if tokenizer uses SentencePiece patterns (▁ prefix) - (len(vocab) > 0 and any('▁' in token for token in list(vocab.keys())[:100]))): - tokenizer_type = SENTENCEPIECE - print(f"Detected SentencePiece-based tokenizer: {tokenizer_class_name}") - elif 'bert' in tokenizer_class_name: - tokenizer_type = BERT - print(f"Detected BERT tokenizer: {tokenizer_class_name}") - else: - tokenizer_type = TIKTOIKEN - print(f"Detected TikToken tokenizer: {tokenizer_class_name}") - - vocab = self.tokenizer.get_vocab() - - if tokenizer_type == SENTENCEPIECE: - # Handle SentencePiece tokenizer (like XLMRoberta) - # Try to get SentencePiece model if available - sp_model_path = None - if hasattr(self.tokenizer, 'vocab_file') and self.tokenizer.vocab_file: - sp_model_path = self.tokenizer.vocab_file - elif hasattr(self.tokenizer, 'sp_model_kwargs'): - sp_model_path = getattr(self.tokenizer, 'vocab_file', None) - - if sp_model_path and os.path.exists(sp_model_path): - # Use existing SentencePiece export logic - print(f"Found SentencePiece model file: {sp_model_path}") - # This will be handled by the existing SentencePiece logic above - # For now, fall 
back to vocab-based export - pass - - # Export SentencePiece vocabulary in the correct format - vocab_list = [] - NORMAL = 1 # SentencePiece piece type - - for token, token_id in sorted(vocab.items(), key=lambda x: x[1]): - try: - # SentencePiece tokens are typically already properly encoded - token_bytes = token.encode('utf-8') - token_b64 = base64.b64encode(token_bytes).decode('utf-8') - # Format: token_base64 score piece_type - vocab_list.append(f'{token_b64} 0.0 {NORMAL}\n') - except Exception as e: - print(f"Warning: Failed to encode SentencePiece token '{token}': {e}") - # Use replacement character for problematic tokens - token_b64 = base64.b64encode('▁'.encode('utf-8')).decode('utf-8') - vocab_list.append(f'{token_b64} 0.0 {NORMAL}\n') - - with open(file_path, "w", encoding="utf8") as fp: - write_header(fp, SENTENCEPIECE, special_list, prefix_list) - fp.write(f'{len(vocab_list)}\n') - for vocab_line in vocab_list: - fp.write(vocab_line) - else: - # Handle BERT or TikToken tokenizer - # bert tokenizer - def unicode_to_byte(u: int): - # Handle special unicode mappings for BERT tokenizers - if u >= 256 and u <= 288: - return u - 256 - if u >= 289 and u <= 322: - return u - 162 - if u == 323: - return 173 - return u - - vocab_list = ['' for i in range(len(vocab))] - - # Process vocabulary with better UTF-8 handling - for k, v in vocab.items(): - if tokenizer_type == "BERT": - try: - # For BERT tokenizers, preserve the original token format - # Most BERT models already have proper UTF-8 encoded tokens - vocab_list[int(v)] = k.encode('utf-8') - except Exception as e: - # Fallback: try unicode_to_byte conversion for special cases - try: - vocab_list[int(v)] = bytes([unicode_to_byte(ord(c)) for c in k]) - except Exception as e2: - print(f"Warning: Failed to encode token '{k}' with id {v}: {e2}") - vocab_list[int(v)] = k.encode('utf-8', errors='replace') - else: - # Fallback: try unicode_to_byte conversion for special cases - try: - vocab_list[int(v)] = 
bytes([unicode_to_byte(ord(c)) for c in k]) - except Exception as e2: - print(f"Warning: Failed to encode token '{k}' with id {v}: {e2}") - vocab_list[int(v)] = k.encode('utf-8', errors='replace') - - special_list = list(self.tokenizer.added_tokens_decoder.keys()) - with open(file_path, "w", encoding="utf8") as fp: - write_header(fp, tokenizer_type, special_list) - fp.write(f'{len(vocab_list)}\n') - for v in vocab_list: - line = base64.b64encode(v).decode("utf8") + "\n" - fp.write(line) - return file_path + return self.tokenizer.export(self.args.dst_path) class EmbeddingExporter(LlmExporter): def __init__(self, args): super().__init__(args) - self.dst_name = 'reranker' if self.is_reranker else 'embedding' - - def word_embed(self, input_ids): - if hasattr(self, 'word_embeddings'): - return self.word_embeddings(input_ids.view(1, -1)) - return self.embed(input_ids.view(1, -1)) - - def bge_forward(self, inputs_embeds, attention_mask, position_ids): - # bert absolute position - inputs_embeds = inputs_embeds.reshape(1, -1, self.hidden_size) - position_embeddings = self.position_embeddings(position_ids) - embeddings = inputs_embeds + position_embeddings + self.token_type_embeddings - hidden_states = self.embedding_layernorm(embeddings) - for i in range(self.num_hidden_layers): - hidden_states = self.blocks[i](hidden_states, attention_mask)[0] - sentence_embeddings = hidden_states[:, 0] - sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1) - return sentence_embeddings - - def gte_reranker_forward(self, inputs_embeds, attention_mask, position_ids): - freqs = position_ids.float().reshape(-1, 1) * self.inv_freq - emb = torch.cat((freqs, freqs), dim=-1) - rope_embeds = torch.stack([emb.cos(), emb.sin()]).unsqueeze(-2).unsqueeze(1) - hidden_states = self.embedding_layernorm(inputs_embeds + self.token_type_embeddings) - for i in range(self.num_hidden_layers): - hidden_states = self.blocks[i](hidden_states, attention_mask, rope_embeds)[0] - 
pooled_output = self.lm(hidden_states) - logits = self.classifier(pooled_output) - return logits - - def gte_embedding_forward(self, inputs_embeds, attention_mask, position_ids): - # rope position - inputs_embeds = inputs_embeds.reshape(1, -1, self.hidden_size) - freqs = position_ids.float().reshape(-1, 1) * self.inv_freq - emb = torch.cat((freqs, freqs), dim=-1) - rope_embeds = torch.stack([emb.cos(), emb.sin()]).unsqueeze(-2).unsqueeze(1) - attention_bias = 1 - attention_mask.float() - hidden_states = self.embedding_layernorm(inputs_embeds + self.token_type_embeddings) - for i in range(self.num_hidden_layers): - hidden_states = self.blocks[i](hidden_states, attention_bias, rope_embeds)[0] - sentence_embeddings = hidden_states[:, 0] - sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1) - return sentence_embeddings - - def gte_forward(self, inputs_embeds, attention_mask, position_ids): - if self.is_reranker: - return self.gte_reranker_forward(inputs_embeds, attention_mask, position_ids) - return self.gte_embedding_forward(inputs_embeds, attention_mask, position_ids) - - def qwen3_forward(self, inputs_embeds, attention_mask, position_ids): - hidden_states = inputs_embeds - rotary_pos_emb = self.rotary(position_ids) - for i in range(len(self.blocks)): - hidden_states, _ = self.blocks[i](hidden_states, rotary_pos_emb, attention_mask, None) - last_hidden_states = hidden_states[:, -1, :] - last_hidden_states = self.final_layernorm_(last_hidden_states) - return last_hidden_states - - def forward(self, inputs_embeds, attention_mask, position_ids): - if self.model_type == 'bert': - return self.bge_forward(inputs_embeds, attention_mask, position_ids) - if self.model_type == 'new': - return self.gte_forward(inputs_embeds, attention_mask, position_ids) - if self.model_type == 'qwen3': - return self.qwen3_forward(inputs_embeds, attention_mask, position_ids) - raise RuntimeError(f'Not support embedding model: {self.model_type}!') def 
response(self, query): - self.eval() + self.model.eval() prompt = self.build_prompt(query) input_ids = self.tokenizer(prompt)['input_ids'] - self.seq_len = len(input_ids) + seq_len = len(input_ids) input_ids = torch.tensor(input_ids) - position_ids = self.get_position_ids() - attention_mask = self.get_attention_mask() - inputs_embeds = self.word_embed(input_ids) - res = self.forward(inputs_embeds, attention_mask, position_ids) - # print(res) + position_ids = self.model.get_position_ids(seq_len) + attention_mask = self.model.get_attention_mask(seq_len) + inputs_embeds = self.model.word_embed(input_ids) + res = self.model.forward(inputs_embeds, attention_mask, position_ids) + print(res, res.shape) return res - @spinner_run(f'load pretrained model ') + def build_prompt(self, content): + if self.config.model_type == 'bert': + return f'[CLS]{content}[SEP]' + if self.config.model_type == 'new': + return f' {content}' + if self.config.model_type == 'qwen3': + return f'{content}<|endoftext|>' + + @spinner_run(f'load pretrained model ', True) def load_model(self, model_path): - self.is_reranker = False - if 'Qwen3' in model_path: - self.token_len = 0 - super().load_model(model_path) - self.model.float() - self.llm_config["jinja"]["chat_template"] = self.build_prompt("{{ messages | map(attribute='content') | join('') }}") - return model_path - self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - self.config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - self.config._attn_implementation = 'eager' + self.model = EmbeddingModel.from_pretrained(model_path, args=self.args) + self.config = self.model.config self.model_type = self.config.model_type - if 'gte' in model_path and 'rank' in model_path: - self.is_reranker = True - from transformers import AutoModelForSequenceClassification - self.model = AutoModelForSequenceClassification.from_pretrained(model_path, config=self.config, trust_remote_code=True).float().eval() - 
self.classifier = self.model.classifier - self.model = self.model.new - else: - self.model = AutoModel.from_pretrained(model_path, config=self.config, trust_remote_code=True).float().eval() - transformer = self.model.encoder - self.lm_ = self.model.pooler - self.embed_ = self.model.embeddings - self.word_embeddings = self.embed_.word_embeddings - self.token_type_embeddings = self.embed_.token_type_embeddings.weight.data[0] - self.embedding_layernorm = self.embed_.LayerNorm - if hasattr(self.embed_, 'position_embeddings'): - self.position_embeddings = self.embed_.position_embeddings - self.hidden_size = self.word_embeddings.weight.shape[-1] - self.blocks = transformer.layer - if self.model_type == 'new': - self.inv_freq = self.embed_.rotary_emb.inv_freq - # some wrapper - self.stop_ids = [] - self.num_hidden_layers = len(self.blocks) - self.embed = self.embed_ - self.lm = self.lm_ - # some config for export - self.model_dynamic_axes = { - "input_ids" : { 1: "seq_len" }, - "position_ids" : { 1: "seq_len" }, - "attention_mask" : { 2: "seq_len", 3: "seq_len" } - } - self.attention_mask_type = 'int' + self.tokenizer = LlmTokenizer(model_path, self.model_type) self.llm_config = { - 'hidden_size' : self.hidden_size, - 'layer_nums' : self.num_hidden_layers, - 'attention_mask': self.attention_mask_type, - 'key_value_shape': [], + 'model_type': self.config.model_type, + 'hidden_size' : self.config.hidden_size, + 'attention_mask': 'int', "jinja": { "chat_template": self.build_prompt("{{ messages | map(attribute='content') | join('') }}") }, @@ -1343,17 +525,16 @@ def load_model(self, model_path): return model_path def export_reranker(self): - model = self.eval() - self.seq_len = 4 + seq_len = 4 input_ids = torch.arange(12, dtype=torch.long) - position_ids = self.get_position_ids() - attention_mask = self.get_attention_mask() - inputs_embeds = self.word_embed(input_ids) - inputs_embeds = inputs_embeds.reshape(3, 4, self.hidden_size) + position_ids = 
self.model.get_position_ids(seq_len) + attention_mask = self.model.get_attention_mask(seq_len) + inputs_embeds = self.model.word_embed(input_ids) + inputs_embeds = inputs_embeds.reshape(3, 4, self.config.hidden_size) attention_mask = torch.zeros(3, 1, 1, 4, dtype=torch.float) onnx_model = f'{self.onnx_path}/{self.dst_name}.onnx' onnx_export( - model, (inputs_embeds, attention_mask, position_ids), + self.model, (inputs_embeds, attention_mask, position_ids), onnx_model, input_names=[ 'input_ids', @@ -1370,18 +551,20 @@ def export_reranker(self): @spinner_run(f'export onnx model to ') def export_onnx(self): - self.unload_param() - if self.is_reranker: + if self.model_type == 'qwen3': + self.unload_param() + else: + self.unloaded_ops = None + if self.model.is_reranker: return self.export_reranker() - model = self.eval() - self.seq_len = 3 - input_ids = torch.arange(3, dtype=torch.long) - position_ids = self.get_position_ids() - attention_mask = self.get_attention_mask() - inputs_embeds = self.word_embed(input_ids) + seq_len = 3 + input_ids = torch.arange(seq_len, dtype=torch.long) + position_ids = self.model.get_position_ids(seq_len) + attention_mask = self.model.get_attention_mask(seq_len) + inputs_embeds = self.model.word_embed(input_ids) onnx_model = f'{self.onnx_path}/{self.dst_name}.onnx' onnx_export( - model, (inputs_embeds, attention_mask, position_ids), + self.model, (inputs_embeds, attention_mask, position_ids), onnx_model, input_names=[ 'input_ids', @@ -1389,7 +572,11 @@ def export_onnx(self): 'position_ids' ], output_names=['sentence_embeddings'], - dynamic_axes=self.model_dynamic_axes) + dynamic_axes={ + "input_ids" : { 1: "seq_len" }, + "position_ids" : { 1: "seq_len" }, + "attention_mask" : { 2: "seq_len", 3: "seq_len" } + }) return onnx_model def export(self, export_type): @@ -1401,8 +588,10 @@ def export(self, export_type): if self.args.onnx_slim: self.slim_onnx(onnx_model) if export_mnn: - transformer_fuse = not self.is_reranker - MNNConveter(self, 
self.unloaded_ops).export(onnx_model, transformer_fuse=transformer_fuse) + transformer_fuse = not self.model.is_reranker + tie_embeddings_info = MNNConverter(self, self.unloaded_ops).export(onnx_model, transformer_fuse=transformer_fuse) + if tie_embeddings_info is not None: + self.llm_config['tie_embeddings'] = tie_embeddings_info # delete onnx file try: for file in glob.glob(f'{self.onnx_path}/*'): @@ -1411,22 +600,6 @@ def export(self, export_type): except Exception as e: print(f"remove onnx error: {e}") - def build_prompt(self, content): - if self.model_type == 'bert': - return f'[CLS]{content}[SEP]' - if self.model_type == 'new': - return f' {content}' - if self.model_type == 'qwen3': - return f'{content}<|endoftext|>' - - def get_position_ids(self) -> torch.Tensor: - return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0) - - def get_attention_mask(self) -> torch.Tensor: - if self.model_type == 'qwen3': - return super().get_attention_mask() - return torch.ones([1, 1, self.seq_len, self.seq_len], dtype=torch.float) - def export(path, type = None, tokenizer_path = None, @@ -1545,4 +718,4 @@ def main(): llm_exporter.export(args.export) if __name__ == '__main__': - main() + main() \ No newline at end of file diff --git a/transformers/llm/export/utils/audio.py b/transformers/llm/export/utils/audio.py index 8330097256..02846d8e93 100644 --- a/transformers/llm/export/utils/audio.py +++ b/transformers/llm/export/utils/audio.py @@ -6,18 +6,21 @@ class Audio(torch.nn.Module): def __init__(self, audio, base): super().__init__() - self.model_type = base.model_type + self.model_type = base.config.model_type self.audio = audio self.embed_ = base.embed self.tokenizer = base.tokenizer - self.config = base.config - self.hidden_size = base.hidden_size - self.llm_config = base.llm_config + self.config = base.config.origin_config + self.hidden_size = base.config.hidden_size + self.llm_config = { 'is_audio': True } self.rope_ratio = 1.0 self.quant_bit = 16 
self.init_config() self.load() + def get_config(self): + return self.llm_config + @staticmethod def get_audio(model_type): audio_models = { @@ -29,7 +32,7 @@ def get_audio(model_type): return None def init_config(self): - self.llm_config['is_audio'] = True + pass def load(self): raise NotImplementedError diff --git a/transformers/llm/export/utils/awq_quantizer.py b/transformers/llm/export/utils/awq_quantizer.py index 3217c00171..ea4cb8d4fd 100644 --- a/transformers/llm/export/utils/awq_quantizer.py +++ b/transformers/llm/export/utils/awq_quantizer.py @@ -252,7 +252,7 @@ def _search_best_scale( x_mean = (x_sum / num_elements).to(inp.dtype) AwqQuantizer.clear_memory(x_sum) - + inp = inp.to(next(layers[0].parameters()).device) # [STEP 3]: Compute output of module with torch.no_grad(): @@ -628,7 +628,7 @@ def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert): if not any(key in name for key in modules_to_not_convert): filtered_layers[name] = linear_layer return filtered_layers - + @staticmethod def to_device(module, device): for child_name, child_module in module.named_children(): @@ -638,8 +638,8 @@ def to_device(module, device): sub_child.to(device) else: child_module.to(device) - - + + @staticmethod def get_named_linears(module): linears = {} @@ -654,10 +654,10 @@ def get_named_linears(module): if isinstance(mod, torch.nn.Linear): full_name = f"{child_name}.{name}" if name else child_name linears[full_name] = mod - + return linears - + @staticmethod def get_op_by_name(module, op_name): for child_name, child_module in module.named_children(): @@ -665,7 +665,7 @@ def get_op_by_name(module, op_name): return child_module if child_name == 'self_attn': for name, mod in child_module.named_children(): - if name != 'config': + if name != 'config': full_name = f"{child_name}.{name}" if full_name == op_name: return mod @@ -674,10 +674,10 @@ def get_op_by_name(module, op_name): full_name = f"{child_name}.{name}" if name else child_name if full_name == 
op_name: return mod - + if op_name == "": return module - + raise ValueError(f"Cannot find op {op_name} in module {module}") @staticmethod @@ -694,7 +694,7 @@ def get_calib_dataset( if data == "pileval": dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") elif data == "wikitext": - + dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split=split) else: dataset = load_dataset(data, split=split) @@ -759,7 +759,7 @@ def clear_memory(weight=None): gc.collect() torch.cuda.empty_cache() - + @staticmethod def get_op_name(module, op): if module is op: @@ -767,9 +767,9 @@ def get_op_name(module, op): for child_name, child_module in module.named_children(): if child_name == 'self_attn': if child_module is op: - return child_name + return child_name for name, mod in child_module.named_children(): - if name != 'config': + if name != 'config': if mod is op: return f"{child_name}.{name}" for sub_name, sub_mod in mod.named_modules(): @@ -779,12 +779,12 @@ def get_op_name(module, op): else: if child_module is op: return child_name - + for name, mod in child_module.named_modules(): if mod is op: full_name = f"{child_name}.{name}" if name else child_name return full_name - + raise ValueError(f"Cannot find op {op} in module {module}") @staticmethod @@ -811,14 +811,13 @@ def init_quant(self, n_samples=128, max_seq_len=512): inps = [] layer_kwargs = {} # build inps - self.model.seq_len = samples.numel() - self.model.context_len = samples.numel() - 2 - self.model.token_len = 0 + seq_len = samples.numel() + new_tokens = 0 best_device = AwqQuantizer.get_best_device() inps = self.model.embedding(samples).to(best_device) - position_ids = self.model.get_position_ids() + position_ids = self.model.get_position_ids(seq_len, new_tokens) rotary_pos_emb = self.model.rotary(position_ids) - attention_mask = self.model.get_attention_mask() + attention_mask = self.model.get_attention_mask(seq_len, new_tokens) layer_kwargs["rotary_pos_emb"] = rotary_pos_emb.to(best_device) 
layer_kwargs["attention_mask"] = attention_mask.to(best_device) del samples diff --git a/transformers/llm/export/utils/config.py b/transformers/llm/export/utils/config.py new file mode 100644 index 0000000000..569cab745e --- /dev/null +++ b/transformers/llm/export/utils/config.py @@ -0,0 +1,116 @@ +from transformers import PretrainedConfig, AutoConfig +from utils.model_mapper import ModelMapper +from typing import Optional, List, Dict, Any, Union +from dataclasses import dataclass, field, asdict + +# model config + +class LlmConfig(PretrainedConfig): + model_type = "llm_config" + + def __init__(self, **kwargs): + self.hidden_size = kwargs.pop("hidden_size", 0) + self.num_attention_heads = kwargs.pop("num_attention_heads", 0) + self.num_hidden_layers = kwargs.pop("num_hidden_layers", 0) + self.num_key_value_heads = kwargs.pop("num_key_value_heads", self.num_attention_heads) + self.head_dim = kwargs.pop("head_dim", self.hidden_size // self.num_attention_heads if self.num_attention_heads > 0 else 0) + self.rope_theta = kwargs.pop("rope_theta", 10000.0) + self.rope_ratio = kwargs.pop("rope_ratio", 1.0) + self.sliding_window = kwargs.pop("sliding_window", 0) + self.sliding_window = self.sliding_window if self.sliding_window is not None else 0 + self.layer_types = kwargs.pop("layer_types", []) + self.attention_type = kwargs.pop("attention_type", 'full') + self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + self.model_map = kwargs.pop("model_map", {}) + super().__init__(**kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, **kwargs) + + model_type, model_map = ModelMapper().get_map(config) + llm_config_kwargs = { + 'origin_config': config, + 'model_type': model_type, + 'model_map': model_map + } + llm_config = cls(**llm_config_kwargs) + # rename attribute for different models + ModelMapper.do_map(llm_config, config, 
model_map['config']) + + # Post-processing and setting defaults + if llm_config.num_key_value_heads is None: + llm_config.num_key_value_heads = llm_config.num_attention_heads + + if llm_config.rope_theta is None: + llm_config.rope_theta = 10000.0 + + if llm_config.rope_ratio is None: + llm_config.rope_ratio = 1.0 + + if llm_config.head_dim is None and llm_config.hidden_size > 0 and llm_config.num_attention_heads > 0: + if isinstance(llm_config.num_attention_heads, list): + llm_config.head_dim = [llm_config.hidden_size // atten_head for atten_head in llm_config.num_attention_heads] + else: + llm_config.head_dim = llm_config.hidden_size // llm_config.num_attention_heads + + # Determine attention type + sliding_attn_layers = [] + if hasattr(llm_config, 'layer_types') and llm_config.layer_types: + for i in range(len(llm_config.layer_types)): + if llm_config.layer_types[i] == 'sliding_attention': + sliding_attn_layers.append(i) + + if llm_config.num_hidden_layers > 0 and len(sliding_attn_layers) >= llm_config.num_hidden_layers: + llm_config.attention_type = 'sliding' + elif len(sliding_attn_layers) > 0: + llm_config.attention_type = 'mix' + else: + llm_config.attention_type = 'full' + + return llm_config + +# export config + +@dataclass +class VisionExportConfig: + """Configuration for vision-related capabilities.""" + image_mean: Optional[List[float]] = field(default_factory=list) + image_norm: Optional[List[float]] = field(default_factory=list) + image_size: Optional[Union[int, List[int]]] = None + image_size_unit: Optional[int] = None + vision_start: Optional[int] = None + vision_end: Optional[int] = None + image_pad: Optional[int] = None + num_grid_per_side: Optional[int] = None + has_deepstack: bool = False + image_max_size: Optional[int] = None + global_image: Optional[int] = None + vision_id_start_id: Optional[int] = None + vision_id_end_id: Optional[int] = None + vision_slice_start_id: Optional[int] = None + vision_slice_end_id: Optional[int] = None + 
+@dataclass +class LLMExportConfig: + """Top-level container for all export configurations.""" + is_audio: bool = False + is_visual: bool = False + has_talker: bool = False + attention_mask: str = 'float' + attention_type: str = 'full' + sliding_window: int = 0 + tie_embeddings: Optional[List[Union[int]]] = field(default_factory=list) + jinja: Dict[str, Any] = field(default_factory=dict) + vision: Optional[VisionExportConfig] = None + + def to_dict(self) -> Dict[str, Any]: + """Converts the configuration to a dictionary for JSON serialization.""" + nested_dict = asdict(self) + vision_data = nested_dict.pop('vision', None) + + if vision_data: + nested_dict.update(vision_data) + + final_dict = {key: value for key, value in nested_dict.items() if value is not None} + return final_dict \ No newline at end of file diff --git a/transformers/llm/export/utils/eagle.py b/transformers/llm/export/utils/eagle.py index 630c7f23fb..3f73e3e0e6 100644 --- a/transformers/llm/export/utils/eagle.py +++ b/transformers/llm/export/utils/eagle.py @@ -22,7 +22,7 @@ def __init__(self, eagle_path, base): config_file_path = eagle_path + "/config.json" self.eagle_config = PretrainedConfig.from_json_file(config_file_path) - self.model_type = base.model_type + self.model_type = base.config.model_type self.eagle_path = eagle_path self.config = base.config @@ -32,14 +32,12 @@ def __init__(self, eagle_path, base): self.rope_theta = 10000 self.rope_ratio = 1.0 self.head_dim = self.config.head_dim - self.config.model_type = base.model_type - self.config.model_map = base.model_map - self.hidden_size = base.hidden_size + self.hidden_size = self.config.hidden_size if self.eagle_config.hidden_size != self.hidden_size: raise RuntimeError(f'eagle_config hidden_size not equal: {self.eagle_config.hidden_size}, {self.hidden_size}!') - self.past_kv_shape = base.past_kv_shape - self.num_attention_heads = base.num_attention_heads - self.llm_config = base.llm_config + # self.past_kv_shape = base.past_kv_shape + 
self.num_attention_heads = self.config.num_attention_heads + self.past_kv_shape = [self.config.num_hidden_layers, 2, 1, 0, self.config.num_key_value_heads, self.config.head_dim] self.head_dim = self.config.head_dim self.num_key_value_heads = self.config.num_key_value_heads @@ -67,7 +65,7 @@ def __init__(self, eagle_path, base): # midlayer.input_layernorm self.midlayer.input_layernorm = RMSNorm(self.hidden_size, eps=self.eagle_config.rms_norm_eps) # midlayer.self_attn - self.midlayer.self_attn = Attention(None, 0, self.config) + self.midlayer.self_attn = Attention(None, 0, self.config, base.rotary, self.config.model_map) self.midlayer.self_attn.q_proj = nn.Linear(self.hidden_size * 2, self.num_attention_heads * self.head_dim, bias=False) self.midlayer.self_attn.k_proj = nn.Linear(self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False) self.midlayer.self_attn.v_proj = nn.Linear(self.hidden_size * 2, self.num_key_value_heads * self.head_dim, bias=False) @@ -122,7 +120,7 @@ def get_eagle(model_type): } if model_type in eagles: return eagles[model_type] - return None + return LlamaEagle @spinner_run(f'export onnx model to ') def export(self, onnx_path): diff --git a/transformers/llm/export/utils/mnn_converter.py b/transformers/llm/export/utils/mnn_converter.py index 70fa755072..f3711f7f7a 100644 --- a/transformers/llm/export/utils/mnn_converter.py +++ b/transformers/llm/export/utils/mnn_converter.py @@ -14,21 +14,18 @@ EXPORT_LOG = '.export.log' -class MNNConveter: - def __init__(self, config, weight_ops = None): +class MNNConverter: + def __init__(self, exporter, weight_ops = None): self.weight_ops = weight_ops - self.config = config - self.quant_block = config.args.quant_block - self.quant_bit = config.args.quant_bit - self.lm_quant_bit = config.args.lm_quant_bit - self.symmetric = config.args.sym - self.hqq = config.args.hqq + self.exporter = exporter + self.args = exporter.args self.mnn_weight_offset = 0 - if 
os.path.exists(config.args.mnnconvert): - self.mnnconvert = config.args.mnnconvert + if os.path.exists(self.args.mnnconvert): + self.mnnconvert = self.args.mnnconvert else: self.mnnconvert = None self.lm_weight = None + self.tie_embeddings_info = None def convert(self, convert_args): sfd = os.dup(1) @@ -72,7 +69,7 @@ def onnx2mnn(self, onnx_path, mnn_path, args = [], transformer_fuse = True, grou convert_args += ['--weightQuantAsymmetric=0'] if save_external_data: convert_args += ['--saveExternalData'] - if self.hqq: + if self.args.hqq: convert_args += ['--hqq'] convert_args += args self.convert(convert_args) @@ -121,15 +118,15 @@ def removeDupOps(self, mnn_path): def export(self, onnx_path, quant_bit = None, quant_block = None, transformer_fuse = True, group_conv_native = False, weight_sym = None): self.onnx_model_path = onnx_path self.mnn_name = os.path.basename(onnx_path).replace('.onnx', '.mnn') - self.mnn_model_path = os.path.join(self.config.args.dst_path, self.mnn_name) + self.mnn_model_path = os.path.join(self.args.dst_path, self.mnn_name) self.mnn_weight_path = f'{self.mnn_model_path}.weight' if self.weight_ops is None: if quant_bit is None: - quant_bit = self.quant_bit + quant_bit = self.args.quant_bit if quant_block is None: - quant_block = self.quant_block + quant_block = self.args.quant_block if weight_sym is None: - weight_sym = self.symmetric + weight_sym = self.args.sym if quant_bit == 16: quant_args = ['--fp16'] else: @@ -150,19 +147,20 @@ def export(self, onnx_path, quant_bit = None, quant_block = None, transformer_fu self.json2mnn(mnn_json, self.mnn_model_path) self.removeDupOps(self.mnn_model_path) self.mnn2json(self.mnn_model_path, mnn_json) - if self.config.args.gptq_path is not None: + if self.args.gptq_path is not None: self.apply_gptq(mnn_json) - if self.config.args.lora_path is not None and self.config.args.lora_split: + if self.args.lora_path is not None and self.args.lora_split: self.export_lora(mnn_json) - if self.config.args.smooth: + 
if self.args.smooth: self.export_smooth_quant(mnn_json) + return self.tie_embeddings_info def get_experts_graphs(self, experts): - hidden_states = torch.randn((1, self.config.hidden_size)) + hidden_states = torch.randn((1, self.exporter.config.hidden_size)) layers_num = len(experts) expert_num = len(experts[0]) dummy_expert = experts[0][0] - onnx_model = f'{self.config.onnx_path}/expert.onnx' + onnx_model = f'{self.args.onnx_path}/expert.onnx' onnx_export( dummy_expert, (hidden_states), onnx_model, @@ -207,14 +205,14 @@ def get_experts_graphs(self, experts): @spinner_run(f'apply gptq to ') def apply_gptq(self, mnn_json): - GPTQ(self.config.args.gptq_path).apply(mnn_json, self.mnn_weight_path) + GPTQ(self.args.gptq_path).apply(mnn_json, self.mnn_weight_path) return self.mnn_weight_path @spinner_run(f'export split lora to ') def export_lora(self, mnn_json): - lora_model = os.path.join(self.config.args.dst_path, 'lora.mnn') + lora_model = os.path.join(self.args.dst_path, 'lora.mnn') lora_json = f'{lora_model}.json' - LoRA(self.config.args.lora_path).apply(mnn_json, lora_json) + LoRA(self.args.lora_path).apply(mnn_json, lora_json) self.json2mnn(lora_json, lora_model) if os.path.exists(lora_json): os.remove(lora_json) @@ -222,16 +220,16 @@ def export_lora(self, mnn_json): @spinner_run(f'export smooth quant scale to ') def export_smooth_quant(self, mnn_json): - self.config.smooth_quantizer.apply(mnn_json) + self.exporter.smooth_quantizer.apply(mnn_json) self.json2mnn(mnn_json, self.mnn_model_path) return self.mnn_model_path @spinner_run(f'quant model weight to ', True) def rebuild(self, json_path): mnn_graph = json.load(open(json_path, 'rt')) - has_experts = len(self.config.experts) > 0 + has_experts = hasattr(self.args, 'experts') and len(self.exporter.experts) > 0 if has_experts: - subgraphs = self.get_experts_graphs(self.config.experts) + subgraphs = self.get_experts_graphs(self.exporter.experts) mnn_graph['subgraphs'] = subgraphs new_ops = [] # Load layernorm weight 
from external @@ -265,7 +263,7 @@ def rebuild(self, json_path): return self.mnn_weight_path def quant(self, weight, quant_bit, quant_block, symmetric): - q_weight, alpha = torch_quant(weight, quant_bit, quant_block, symmetric, self.config.args.awq, self.config.args.hqq) + q_weight, alpha = torch_quant(weight, quant_bit, quant_block, symmetric, self.args.awq, self.args.hqq) return q_weight, alpha def write_weight(self, data): @@ -403,19 +401,19 @@ def rebuild_linear(self, op, graph): (linear.bias is not None) == has_bias) is_lm = 'lm_head' in name - quant_bit = self.lm_quant_bit if is_lm else self.quant_bit - block_size = ic if self.quant_block == 0 else self.quant_block + quant_bit = self.args.lm_quant_bit if is_lm else self.args.quant_bit + block_size = ic if self.args.quant_block == 0 else self.args.quant_block if is_lm and self.lm_weight is not None: external, q_min, shape_int32, header_len = self.lm_weight else: - external, q_min, shape_int32, header_len = self.build_weight(linear, quant_bit, self.quant_block, self.symmetric) + external, q_min, shape_int32, header_len = self.build_weight(linear, quant_bit, self.args.quant_block, self.args.sym) if is_lm and self.lm_weight is None: self.lm_weight = [external, q_min, shape_int32, header_len] - if is_lm and self.config.tie_word_embeddings: + if is_lm and self.args.tie_word_embeddings: weight_offset = external[0] + header_len alpha_offset = external[0] + external[1] alpha_size = external[2] - self.config.llm_config['tie_embeddings'] = [weight_offset, alpha_offset, alpha_size, quant_bit, self.quant_block] + self.tie_embeddings_info = [weight_offset, alpha_offset, alpha_size, quant_bit, self.args.quant_block] origin_input = op['inputIndexes'] origin_output = op['outputIndexes'] @@ -460,7 +458,7 @@ def rebuild_linear(self, op, graph): if quant_bit == 16: quanParameter = { "type": 3 } else: - if self.symmetric: + if self.args.sym: aMin = 0 readType = 0 else: diff --git a/transformers/llm/export/utils/model.py 
b/transformers/llm/export/utils/model.py new file mode 100644 index 0000000000..a327109e65 --- /dev/null +++ b/transformers/llm/export/utils/model.py @@ -0,0 +1,380 @@ +import torch +import importlib +from packaging.version import Version +from transformers import PreTrainedModel, AutoConfig, AutoModel, AutoModelForCausalLM +from typing import Optional, List + +from utils.config import LlmConfig +from utils.tokenizer import LlmTokenizer +from utils.model_mapper import ModelMapper +from utils.transformers import Embedding, Rotary, Decoder, Lm + +class LlmModel(PreTrainedModel): + config_class = LlmConfig + + def __init__(self, config, args=None): + super().__init__(config) + self.config = config + self.args = args + self.tokenizer = None + self.model = None + self.visual = None + self.audio = None + self.talker = None + self.mtp = None + self.scale_emb = None + + def _init_weights(self, module): + pass + + def get_config(self): + llm_config = {} + models = ['visual', 'audio', 'talker'] + for m in models: + if hasattr(self, m) and getattr(self, m) is not None: + m_config = getattr(self, m).get_config() + llm_config.update(m_config) + return llm_config + + @staticmethod + def get_model_class(model_type: str): + # Same as in LlmExporter + MODEL_CLASS_MAPPING = { + 'qwen3_vl': 'Qwen3VLForConditionalGeneration', + 'qwen3_vl_moe': 'Qwen3VLMoeForConditionalGeneration', + 'qwen2_5_omni': 'Qwen2_5OmniForConditionalGeneration', + 'qwen2_5_vl': 'Qwen2_5_VLForConditionalGeneration', + 'qwen2_vl': 'Qwen2VLForConditionalGeneration', + 'qwen2_audio': 'Qwen2AudioForConditionalGeneration', + 'smolvlm': 'AutoModelForImageTextToText', + 'idefics3': 'AutoModelForVision2Seq', + } + if model_type is None or model_type not in MODEL_CLASS_MAPPING: + return AutoModelForCausalLM + class_name = MODEL_CLASS_MAPPING[model_type] + try: + module = importlib.import_module('transformers') + return getattr(module, class_name) + except (ImportError, AttributeError): + return AutoModelForCausalLM + + 
@classmethod + def from_pretrained(cls, pretrained_model_name_or_path, args=None, **kwargs): + config = LlmConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + model_type = config.model_type + model_class = cls.get_model_class(model_type) + + load_kwargs = {'trust_remote_code': True} + if Version(importlib.metadata.version("transformers")) >= Version("4.56.0"): + load_kwargs['dtype'] = 'auto' + else: + load_kwargs['torch_dtype'] = 'auto' + + if model_type == 'internvl_chat': + load_kwargs['use_flash_attn'] = False + + try: + original_model = model_class.from_pretrained(pretrained_model_name_or_path, **load_kwargs) + except Exception: + original_model = AutoModel.from_pretrained(pretrained_model_name_or_path, **load_kwargs) + + # LoRA + if args.lora_path is not None and not args.lora_split: + from peft import PeftModel + adapter = PeftModel.from_pretrained(original_model, model_id=args.lora_path) + original_model = adapter.merge_and_unload(progressbar=True) + + original_model = original_model.eval() + + model = cls(config, args) + + if model_type == 'qwen2_audio': + model.audio = original_model + original_model = original_model.language_model + + ModelMapper.do_map(model, original_model, config.model_map['model']) + + model.tokenizer = LlmTokenizer.from_pretrained( + pretrained_model_name_or_path, + model_type=model_type + ) + + # Rebuild modules + if model.lm is None: + out_features, in_features = model.embed.weight.shape + model.lm = torch.nn.Linear(in_features, out_features, bias=False) + model.lm.weight = model.embed.weight + elif not isinstance(model.lm, torch.nn.Linear): + weight = model.lm.weight + out_features, in_features = weight.shape + model.lm = torch.nn.Linear(in_features, out_features, bias=False) + model.lm.weight = weight + + model.embed = Embedding(model.embed, config) + model.rotary = Rotary(config) + model.blocks = torch.nn.ModuleList([ + Decoder(block, i, config, model.rotary, config.model_map) for i, block in 
enumerate(model.blocks.children()) + ]) + model.lm = Lm(model.lm) + + if 'gemma' in model_type: + model.scale_emb = model.embed.embedscale + + # Multi-modal parts + if model.visual is not None: + from utils.vision import Vision + # model.visual = Vision.get_vision(model_type)(model.visual, model) + model.visual = Vision.get_vision(model_type)(model.visual.float(), model).float() + if hasattr(model, 'audio') and model.audio is not None: + from utils.audio import Audio + model.audio = Audio.get_audio(model.audio.config.model_type)(model.audio, model) + if hasattr(model, 'talker') and model.talker is not None: + from utils.talker import Talker + model.talker = Talker.get_talker(model_type)(model.talker, model.token2wav, model) + if model_type == 'poi_qwen2_mtp': + model.mtp = [model.mtp1, model.mtp2] + if model.mtp is not None: + from utils.mtp import Mtp + model.mtp = Mtp.get_mtp(model_type)(model.mtp, model) + + return model + + def embedding(self, input_ids): + if self.visual is not None and len(input_ids) > 1: + return self.visual.embed(input_ids) + if self.audio is not None and len(input_ids) > 1: + return self.audio.embed(input_ids) + return self.embed(input_ids) + + def forward(self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Optional[List[torch.Tensor]] = None, + logits_index: torch.Tensor = torch.tensor([-1], dtype=torch.int32), + deepstack_embeds: torch.Tensor = None + ): + hidden_states = input_ids # llm forward without embedding + if self.scale_emb is not None: + hidden_states = hidden_states * self.scale_emb + presents = [None for i in range(len(self.blocks))] + eagle_hidden_states = [] + rotary_pos_emb = self.rotary(position_ids) + if self.args and self.args.test and rotary_pos_emb.dtype != hidden_states.dtype: + rotary_pos_emb = rotary_pos_emb.type(hidden_states.dtype) + + for i in range(len(self.blocks)): + # eagle3 hidden states + if self.args and self.args.eagle_path and (i == 
len(self.blocks)-3 or i == len(self.blocks)//2 or i==2): + eagle_hidden_states.append(hidden_states) + + past_kv = past_key_values[i] if past_key_values is not None and past_key_values[i] is not None else None + + # sliding or full attn mask + if self.config.attention_type == 'mix': + is_sliding = i in self.config.sliding_attn_layers + layer_attention_mask = attention_mask[int(is_sliding)] + else: + layer_attention_mask = attention_mask + + hidden_states, kv = self.blocks[i](hidden_states, rotary_pos_emb, layer_attention_mask, past_kv) + presents[i] = kv + if deepstack_embeds is not None and i in range(deepstack_embeds.shape[0]): + hidden_states += deepstack_embeds[i] + + talker_embeds = None + if hasattr(self, 'talker') and self.talker is not None: + talker_embeds = self.final_layernorm(hidden_states) + input_ids.permute([1, 0, 2]) + self.talker.add_talker_embeds(talker_embeds) + + final_layernorm = hidden_states + logits_index_long = logits_index.to(torch.int64) + if self.mtp is None: + hidden_states = hidden_states[:, logits_index_long:, :] + hidden_states = self.final_layernorm(hidden_states) + # default: set hidden_state before lm_head as output node + final_layernorm = hidden_states + else: + # final_layernorm need compute all logists + if self.config.model_type == 'mimo': + final_layernorm = hidden_states # mimo + hidden_states = self.final_layernorm(hidden_states) + if self.config.model_type == 'poi_qwen2_mtp': + final_layernorm = hidden_states # poi + hidden_states = hidden_states[:, logits_index_long:, :] + logits = self.lm(hidden_states) + if presents[0] is not None and presents[0].shape == presents[-1].shape and None not in presents: + presents = torch.stack(presents) + + if self.args and self.args.eagle_path is not None: + final_layernorm = torch.cat(eagle_hidden_states, dim=-1) + + return logits, final_layernorm, presents, talker_embeds + + def get_attention_mask(self, seq_len: int, new_tokens: int = 0): + if self.config.model_type == 'chatglm': + 
return self.chatglm_attention_mask() + if self.config.attention_type == 'full': + return self.full_attention_mask(seq_len, new_tokens) + elif self.config.attention_type == 'sliding': + return self.sliding_attention_mask(self.config.sliding_window, seq_len, new_tokens) + elif self.config.attention_type == 'mix': + full_mask = self.full_attention_mask(seq_len, new_tokens) + sliding_mask = self.sliding_attention_mask(self.config.sliding_window, seq_len, new_tokens) + return torch.stack([full_mask, sliding_mask], dim=0) + return None + + def full_attention_mask(self, seq_len, new_tokens): + if new_tokens: + return torch.zeros([1, 1, 1, seq_len], dtype=torch.float32) + return (1 - torch.tril(torch.ones([1, 1, seq_len, seq_len]))) * torch.finfo(torch.float32).min + + def sliding_attention_mask(self, sliding_window, seq_len, new_tokens): + if new_tokens: + sliding_mask = torch.zeros([1, 1, 1, seq_len], dtype=torch.float32) + num_tokens_to_mask = seq_len - sliding_window + if num_tokens_to_mask > 0: + sliding_mask[..., :num_tokens_to_mask] = torch.finfo(torch.float32).min + return sliding_mask + causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)) + query_indices = torch.arange(seq_len).view(-1, 1) + key_indices = torch.arange(seq_len).view(1, -1) + window_mask = (key_indices > query_indices - sliding_window) + final_mask_bool = causal_mask & window_mask + sliding_mask = torch.where(final_mask_bool, 0.0, torch.finfo(torch.float32).min) + return sliding_mask.view(1, 1, seq_len, seq_len) + + def get_position_ids(self, seq_len, new_tokens=0, input_ids=None): + if self.visual is not None and hasattr(self.visual, 'get_position_ids'): + return self.visual.get_position_ids(input_ids, seq_len, new_tokens) + if self.config.model_type == 'chatglm': + return self.chatglm_position_ids(seq_len, new_tokens) + if new_tokens: + return torch.tensor([[seq_len - 1]], dtype=torch.int) + return torch.arange(seq_len, dtype=torch.int).unsqueeze(0) + + def 
chatglm_attention_mask(self, seq_len, is_decode): + if is_decode: + return torch.zeros([1]).bool().reshape([1, 1, 1, 1]) + attention_mask = torch.zeros([seq_len, seq_len], dtype=torch.bool) + for i in range(seq_len - 1): + attention_mask[i][-1] = True + return attention_mask.reshape([1, 1, seq_len, seq_len]) + + def chatglm_position_ids(self, seq_len, new_tokens): + if new_tokens: + return torch.tensor([seq_len - 2, new_tokens + 1]).reshape([1, 2, 1]) + position_ids_0 = torch.arange(seq_len, dtype=torch.int) + position_ids_1 = torch.zeros(seq_len, dtype=torch.int) + position_ids_0[-1] = position_ids_0[-2] + position_ids_1[-1] = 1 + return torch.stack([position_ids_0, position_ids_1]).view(1, 2, -1) + +class EmbeddingModel(LlmModel): + def __init__(self, config, args=None): + super().__init__(config, args) + self.is_reranker = False + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, args=None, **kwargs): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) + model_type = config.model_type + if model_type == 'qwen3': + model = super(EmbeddingModel, cls).from_pretrained(pretrained_model_name_or_path, args=args).float().eval() + return model + # gte, bge + config._attn_implementation = 'eager' + model = cls(config, args) + if model_type == 'new' and 'NewForSequenceClassification' in config.architectures: + model.is_reranker = True + from transformers import AutoModelForSequenceClassification + origin_model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path, config=config, trust_remote_code=True).float().eval() + model.classifier = origin_model.classifier + origin_model = origin_model.new + else: + origin_model = AutoModel.from_pretrained(pretrained_model_name_or_path, config=config, trust_remote_code=True).float().eval() + + transformer = origin_model.encoder + model.lm = origin_model.pooler + model.embed = origin_model.embeddings + model.word_embeddings = 
model.embed.word_embeddings + model.token_type_embeddings = model.embed.token_type_embeddings.weight.data[0] + model.embedding_layernorm = model.embed.LayerNorm + if hasattr(model.embed, 'position_embeddings'): + model.position_embeddings = model.embed.position_embeddings + model.hidden_size = model.word_embeddings.weight.shape[-1] + model.blocks = transformer.layer + # some wrapper + model.num_hidden_layers = len(model.blocks) + return model + + def forward(self, inputs_embeds, attention_mask, position_ids): + if self.config.model_type == 'bert': + return self.bge_forward(inputs_embeds, attention_mask, position_ids) + if self.config.model_type == 'new': + return self.gte_forward(inputs_embeds, attention_mask, position_ids) + if self.config.model_type == 'qwen3': + return self.qwen3_forward(inputs_embeds, attention_mask, position_ids) + raise RuntimeError(f'Not support embedding model: {self.config.model_type}!') + + def word_embed(self, input_ids): + if hasattr(self, 'word_embeddings'): + return self.word_embeddings(input_ids.view(1, -1)) + return self.embed(input_ids.view(1, -1)) + + def bge_forward(self, inputs_embeds, attention_mask, position_ids): + inputs_embeds = inputs_embeds.reshape(1, -1, self.config.hidden_size) + position_embeddings = self.position_embeddings(position_ids) + embeddings = inputs_embeds + position_embeddings + self.token_type_embeddings + hidden_states = self.embedding_layernorm(embeddings) + for i in range(self.config.num_hidden_layers): + hidden_states = self.blocks[i](hidden_states, attention_mask)[0] + sentence_embeddings = hidden_states[:, 0] + sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1) + return sentence_embeddings + + def gte_reranker_forward(self, inputs_embeds, attention_mask, position_ids): + freqs = position_ids.float().reshape(-1, 1) * self.embed.rotary_emb.inv_freq + emb = torch.cat((freqs, freqs), dim=-1) + rope_embeds = torch.stack([emb.cos(), emb.sin()]).unsqueeze(-2).unsqueeze(1) + 
hidden_states = self.embedding_layernorm(inputs_embeds + self.token_type_embeddings) + for i in range(self.config.num_hidden_layers): + hidden_states = self.blocks[i](hidden_states, attention_mask, rope_embeds)[0] + pooled_output = self.lm(hidden_states) + logits = self.classifier(pooled_output) + return logits + + def gte_embedding_forward(self, inputs_embeds, attention_mask, position_ids): + inputs_embeds = inputs_embeds.reshape(1, -1, self.config.hidden_size) + freqs = position_ids.float().reshape(-1, 1) * self.embed.rotary_emb.inv_freq + emb = torch.cat((freqs, freqs), dim=-1) + rope_embeds = torch.stack([emb.cos(), emb.sin()]).unsqueeze(-2).unsqueeze(1) + attention_bias = 1 - attention_mask.float() + hidden_states = self.embedding_layernorm(inputs_embeds + self.token_type_embeddings) + for i in range(self.config.num_hidden_layers): + hidden_states = self.blocks[i](hidden_states, attention_bias, rope_embeds)[0] + sentence_embeddings = hidden_states[:, 0] + sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1) + return sentence_embeddings + + def gte_forward(self, inputs_embeds, attention_mask, position_ids): + if self.is_reranker: + return self.gte_reranker_forward(inputs_embeds, attention_mask, position_ids) + return self.gte_embedding_forward(inputs_embeds, attention_mask, position_ids) + + def qwen3_forward(self, inputs_embeds, attention_mask, position_ids): + hidden_states = inputs_embeds + rotary_pos_emb = self.rotary(position_ids) + for i in range(len(self.blocks)): + hidden_states, _ = self.blocks[i](hidden_states, rotary_pos_emb, attention_mask, None) + last_hidden_states = hidden_states[:, -1, :] + last_hidden_states = self.final_layernorm(last_hidden_states) + return last_hidden_states + + def get_position_ids(self, seq_len) -> torch.Tensor: + return torch.arange(seq_len, dtype=torch.long).unsqueeze(0) + + def get_attention_mask(self, seq_len) -> torch.Tensor: + if self.config.model_type == 'qwen3': + return 
super().get_attention_mask(seq_len, 0) + return torch.ones([1, 1, seq_len, seq_len], dtype=torch.float) \ No newline at end of file diff --git a/transformers/llm/export/utils/model_mapper.py b/transformers/llm/export/utils/model_mapper.py index 9c27ed7553..6ea5168bd2 100644 --- a/transformers/llm/export/utils/model_mapper.py +++ b/transformers/llm/export/utils/model_mapper.py @@ -55,10 +55,10 @@ def regist_deepseek_vl(self): 'num_key_value_heads': 'language_config.num_key_value_heads', }, 'model': { - 'lm_': 'language_model.lm_head', - 'embed_': 'language_model.model.embed_tokens', - 'blocks_': 'language_model.model.layers', - 'final_layernorm_': 'language_model.model.norm', + 'lm': 'language_model.lm_head', + 'embed': 'language_model.model.embed_tokens', + 'blocks': 'language_model.model.layers', + 'final_layernorm': 'language_model.model.norm', 'visual': 'vision_model' }, 'decoder': { @@ -88,10 +88,10 @@ def regist_qwen_omni(self): 'rope_scaling': 'thinker_config.text_config.rope_scaling' }, 'model': { - 'lm_': 'thinker.lm_head', - 'embed_': 'thinker.model.embed_tokens', - 'blocks_': 'thinker.model.layers', - 'final_layernorm_': 'thinker.model.norm', + 'lm': 'thinker.lm_head', + 'embed': 'thinker.model.embed_tokens', + 'blocks': 'thinker.model.layers', + 'final_layernorm': 'thinker.model.norm', 'visual': 'thinker.visual', 'audio': 'thinker.audio_tower', 'talker': 'talker', @@ -111,10 +111,10 @@ def regist_qwen(self): 'rope_theta': 'rotary_emb_base', }, 'model': { - 'lm_': 'lm_head', - 'embed_': 'transformer.wte', - 'blocks_': 'transformer.h', - 'final_layernorm_': 'transformer.ln_f', + 'lm': 'lm_head', + 'embed': 'transformer.wte', + 'blocks': 'transformer.h', + 'final_layernorm': 'transformer.ln_f', 'visual': 'transformer.visual' }, 'decoder': { @@ -221,10 +221,10 @@ def regist_glm(self): 'num_hidden_layers': 'num_layers' }, 'model': { - 'lm_': 'lm_head', - 'embed_': 'transformer.word_embeddings', - 'blocks_': 'transformer.layers', - 'final_layernorm_': 
'transformer.final_layernorm', + 'lm': 'lm_head', + 'embed': 'transformer.word_embeddings', + 'blocks': 'transformer.layers', + 'final_layernorm': 'transformer.final_layernorm', }, 'decoder': { 'self_attn': 'attention', @@ -249,10 +249,10 @@ def regist_glm2(self): 'rope_ratio': 'rope_ratio' }, 'model': { - 'lm_': 'transformer.output_layer', - 'embed_': 'transformer.embedding.word_embeddings', - 'blocks_': 'transformer.encoder.layers', - 'final_layernorm_': 'transformer.encoder.final_layernorm', + 'lm': 'transformer.output_layer', + 'embed': 'transformer.embedding.word_embeddings', + 'blocks': 'transformer.encoder.layers', + 'final_layernorm': 'transformer.encoder.final_layernorm', }, 'decoder': { 'self_attn': 'self_attention', @@ -276,10 +276,10 @@ def regist_phi(self): 'rotary_dim': 'rotary_dim' }, 'model': { - 'lm_': 'lm_head.linear', - 'embed_': 'transformer.embd.wte', - 'blocks_': 'transformer.h', - 'final_layernorm_': 'lm_head.ln', + 'lm': 'lm_head.linear', + 'embed': 'transformer.embd.wte', + 'blocks': 'transformer.h', + 'final_layernorm': 'lm_head.ln', }, 'decoder': { 'self_attn': 'mixer', @@ -304,10 +304,10 @@ def regist_phi3(self): 'num_key_value_heads': 'num_key_value_heads', }, 'model': { - 'lm_': 'lm_head', - 'embed_': 'model.embed_tokens', - 'blocks_': 'model.layers', - 'final_layernorm_': 'model.norm' + 'lm': 'lm_head', + 'embed': 'model.embed_tokens', + 'blocks': 'model.layers', + 'final_layernorm': 'model.norm' }, 'decoder': { 'self_attn': 'self_attn', @@ -332,10 +332,10 @@ def regist_intervl(self): 'num_key_value_heads': 'llm_config.num_key_value_heads', }, 'model': { - 'lm_': 'language_model.lm_head', - 'embed_': 'language_model.model.embed_tokens', - 'blocks_': 'language_model.model.layers', - 'final_layernorm_': 'language_model.model.norm', + 'lm': 'language_model.lm_head', + 'embed': 'language_model.model.embed_tokens', + 'blocks': 'language_model.model.layers', + 'final_layernorm': 'language_model.model.norm', 'visual': 'vision_model' }, 
'decoder': { @@ -384,10 +384,10 @@ def regist_gemma3(self): 'eoi_token_index': 'eoi_token_index', #'' }, 'model': { - 'lm_': 'language_model.lm_head', - 'embed_': 'language_model.model.embed_tokens', - 'blocks_': 'language_model.model.layers', - 'final_layernorm_': 'language_model.model.norm', + 'lm': 'language_model.lm_head', + 'embed': 'language_model.model.embed_tokens', + 'blocks': 'language_model.model.layers', + 'final_layernorm': 'language_model.model.norm', 'vision_tower': 'vision_tower', 'visual': 'vision_tower.vision_model', 'multi_modal_projector': 'multi_modal_projector' @@ -431,10 +431,10 @@ def regist_gemma3_text(self): 'sliding_window': 'sliding_window' }, 'model': { - 'lm_': 'lm_head', - 'embed_': 'model.embed_tokens', - 'blocks_': 'model.layers', - 'final_layernorm_': 'model.norm', + 'lm': 'lm_head', + 'embed': 'model.embed_tokens', + 'blocks': 'model.layers', + 'final_layernorm': 'model.norm', 'rotary_emb': 'model.rotary_emb', 'rotary_emb_local': 'model.rotary_emb_local' }, @@ -467,10 +467,10 @@ def register_openelm(self): 'rope_theta': 'rope_freq_constant' } openelm_model = { - 'lm_': 'lm_head', - 'embed_': 'transformer.token_embeddings', - 'blocks_': 'transformer.layers', - 'final_layernorm_': 'transformer.norm' + 'lm': 'lm_head', + 'embed': 'transformer.token_embeddings', + 'blocks': 'transformer.layers', + 'final_layernorm': 'transformer.norm' } openelm_decoder = { 'self_attn': 'attn', @@ -503,11 +503,12 @@ def regist_idefics3(self): 'rope_scaling': 'text_config.rope_scaling' } idefics3_model = { - 'lm_': 'lm_head', - 'embed_': 'model.text_model.embed_tokens', - 'blocks_': 'model.text_model.layers', - 'final_layernorm_': 'model.text_model.norm', - 'visual': 'model.vision_model' + 'lm': 'lm_head', + 'embed': 'model.text_model.embed_tokens', + 'blocks': 'model.text_model.layers', + 'final_layernorm': 'model.text_model.norm', + 'visual': 'model.vision_model', + 'visual.connector': 'model.connector' } idefics3_map = { 'config': idefics3_config, @@ 
-533,10 +534,10 @@ def regist_qwenvl(self): if TRANSFORMERS_VERSION <= '4.52.1': return qwen2vl_model = { - 'lm_': 'lm_head', - 'embed_': 'model.language_model.embed_tokens', - 'blocks_': 'model.language_model.layers', - 'final_layernorm_': 'model.language_model.norm', + 'lm': 'lm_head', + 'embed': 'model.language_model.embed_tokens', + 'blocks': 'model.language_model.layers', + 'final_layernorm': 'model.language_model.norm', 'visual': 'model.visual' } qwen2vl_map = { @@ -652,10 +653,10 @@ def regist_minicpmv(self): minicpmv_config = copy.deepcopy(self.default_config) minicpmv_config['scale_emb'] = 'scale_emb' minicpmv_model = { - 'lm_': 'llm.lm_head', - 'embed_': 'llm.model.embed_tokens', - 'blocks_': 'llm.model.layers', - 'final_layernorm_': 'llm.model.norm', + 'lm': 'llm.lm_head', + 'embed': 'llm.model.embed_tokens', + 'blocks': 'llm.model.layers', + 'final_layernorm': 'llm.model.norm', 'visual': 'vpm', 'resampler': 'resampler' } @@ -684,10 +685,10 @@ def init_default_map(self): 'max_position_embeddings': 'max_position_embeddings' } self.default_model = { - 'lm_': 'lm_head', - 'embed_': 'model.embed_tokens', - 'blocks_': 'model.layers', - 'final_layernorm_': 'model.norm', + 'lm': 'lm_head', + 'embed': 'model.embed_tokens', + 'blocks': 'model.layers', + 'final_layernorm': 'model.norm', 'visual': 'visual' } self.default_decoder = { @@ -711,14 +712,35 @@ def init_default_map(self): } @staticmethod - def do_map(dst, src, map): - for dst_attr, src_attr in map.items(): - attributes = src_attr.split('.') - obj = src - for attr in attributes: - if hasattr(obj, attr): - obj = getattr(obj, attr) + def do_map(dst, src, mapping): + # Sort mapping by key to ensure parents are set before children + # e.g., 'visual' is processed before 'visual.connector' for SmolVLM + for dst_path, src_path in sorted(mapping.items(), key=lambda x: x[0]): + # --- 1. 
Retrieve value from source --- + val = src + for attr in src_path.split('.'): + if hasattr(val, attr): + val = getattr(val, attr) else: - obj = None + val = None break - setattr(dst, dst_attr, obj) \ No newline at end of file + + # --- 2. Navigate to destination parent node --- + dst_parts = dst_path.split('.') + target = dst + + # Traverse to the second-to-last object + path_valid = True + for attr in dst_parts[:-1]: + if hasattr(target, attr): + target = getattr(target, attr) + if target is None: + path_valid = False + break + else: + path_valid = False + break + + # --- 3. Set value --- + if path_valid and target: + setattr(target, dst_parts[-1], val) \ No newline at end of file diff --git a/transformers/llm/export/utils/mtp.py b/transformers/llm/export/utils/mtp.py index 1193fd00e3..687eb4e017 100644 --- a/transformers/llm/export/utils/mtp.py +++ b/transformers/llm/export/utils/mtp.py @@ -1,8 +1,6 @@ -import math import torch import torch.nn as nn -import numpy as np -from typing import Optional, List, Tuple +from typing import Optional, Tuple from .transformers import Attention from utils.custom_op import FakeLinear @@ -12,21 +10,18 @@ class Mtp(torch.nn.Module): def __init__(self, mtp, base): super().__init__() - self.model_type = base.model_type + self.model_type = base.config.model_type self.mtp = mtp self.embed_ = base.embed self.lm_ = base.lm + self.rotary = base.rotary - self.config_ = base.config + self.config = base.config if not hasattr(base.config, 'head_dim'): - self.config_.head_dim = base.head_dim - self.config_.rotary = base.rotary - self.config_.model_type = base.model_type - self.config_.model_map = base.model_map - self.hidden_size = base.hidden_size - self.past_kv_shape = base.past_kv_shape - self.num_attention_heads = base.num_attention_heads - self.llm_config = base.llm_config + self.config.head_dim = base.head_dim + self.hidden_size = self.config.hidden_size + self.num_attention_heads = self.config.num_attention_heads + self.past_kv_shape 
= [self.config.num_hidden_layers, 2, 1, 0, self.config.num_key_value_heads, self.config.head_dim] self.load() self.unloaded_ops = {} @@ -99,7 +94,7 @@ def load(self): self.post_attention_layernorm = getattr(self.mtp[0], 'post_attention_layernorm') self.mlp = getattr(self.mtp[0], 'mlp') self.final_layernorm = getattr(self.mtp[0], 'final_layernorm') - self.self_attn = Attention(self.self_attn, 0, self.config_) + self.self_attn = Attention(self.self_attn, 0, self.config, self.rotary, self.config.model_map) def unload_param(self): def build_faker(real, name): @@ -137,7 +132,7 @@ def forward(self, residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - rotary_pos_emb = self.config_.rotary(position_ids) + rotary_pos_emb = self.rotary(position_ids) # Self Attention hidden_states, present_key_value = self.self_attn( @@ -189,7 +184,7 @@ def load(self): self.ori_attn = getattr(self.decode_layers[i], 'self_attn') self.post_attention_layernorm.append(getattr(self.decode_layers[i], 'post_attention_layernorm')) self.mlp.append(getattr(self.decode_layers[i], 'mlp')) - self.self_attn.append(Attention(self.ori_attn, i, self.config_)) + self.self_attn.append(Attention(self.ori_attn, i, self.config)) def unload_param(self): def build_faker(real, name): @@ -221,7 +216,7 @@ def forward(self, # [1, -1, self.hidden_size] mtp_hidden_states = [] - rotary_pos_emb = self.config_.rotary(position_ids) + rotary_pos_emb = self.rotary(position_ids) hidden_states = hidden_states.view(1, -1, self.hidden_size) hidden_states = hidden_states[:, 0 : input_embeds.size(0), :] diff --git a/transformers/llm/export/utils/smooth_quantizer.py b/transformers/llm/export/utils/smooth_quantizer.py index 97a8980c5c..fd3ef92c4b 100644 --- a/transformers/llm/export/utils/smooth_quantizer.py +++ b/transformers/llm/export/utils/smooth_quantizer.py @@ -164,13 +164,12 @@ def init_quant(self, n_samples=128, max_seq_len=512): def _get_first_input(self, sample): layer_kwargs = {} - self.model.seq_len 
= sample.numel() - self.model.context_len = sample.numel() - 2 - self.model.token_len = 0 + seq_len = sample.numel() + new_tokens = 0 inps = self.model.embedding(sample).to(self.best_device) - position_ids = self.model.get_position_ids(sample) + position_ids = self.model.get_position_ids(seq_len, new_tokens, sample) rotary_pos_emb = self.model.rotary(position_ids) - attention_mask = self.model.get_attention_mask() + attention_mask = self.model.get_attention_mask(seq_len, new_tokens, ) layer_kwargs["rotary_pos_emb"] = rotary_pos_emb.to(self.best_device) layer_kwargs["attention_mask"] = attention_mask.to(self.best_device) del sample diff --git a/transformers/llm/export/utils/talker.py b/transformers/llm/export/utils/talker.py index cd7485954f..aa1e9cd245 100644 --- a/transformers/llm/export/utils/talker.py +++ b/transformers/llm/export/utils/talker.py @@ -10,14 +10,14 @@ class Talker(torch.nn.Module): def __init__(self, talker, token2wav, base): super().__init__() - self.model_type = base.model_type + self.model_type = base.config.model_type self.thinker_embed = base.embed self.args = base.args self.talker = talker.float() self.token2wav = Qwen2_5OmniToken2Wav(token2wav, base) self.config = base.config - self.hidden_size = base.hidden_size - self.llm_config = base.llm_config + self.hidden_size = base.config.hidden_size + self.llm_config = { 'has_talker': True } self.rope_ratio = 1.0 self.quant_bit = 4 if self.hidden_size <= 2048: @@ -26,6 +26,9 @@ def __init__(self, talker, token2wav, base): self.init_config() self.load() + def get_config(self): + return self.llm_config + @staticmethod def get_talker(model_type): audio_models = { @@ -36,7 +39,7 @@ def get_talker(model_type): return None def init_config(self): - self.llm_config['has_talker'] = True + pass def load(self): raise NotImplementedError @@ -87,7 +90,7 @@ def forward(self, position_ids): class Qwen2_5OmniTalker(Talker): def __init__(self, talker, token2wav, base): super().__init__(talker, token2wav, base) - 
self.input_hidden_size = base.hidden_size + self.input_hidden_size = base.config.hidden_size self.seq_len = 0 self.token_len = 0 self.talker_embeds = [] diff --git a/transformers/llm/export/utils/token2wav.py b/transformers/llm/export/utils/token2wav.py index 436dbec26d..85b5b373cd 100644 --- a/transformers/llm/export/utils/token2wav.py +++ b/transformers/llm/export/utils/token2wav.py @@ -13,7 +13,6 @@ def __init__(self,token2wav, base): self.args = base.args self.token2wav = token2wav.float() self.config = base.config - self.llm_config = base.llm_config self.rope_ratio = 1.0 self.quant_bit = 8 self.load() diff --git a/transformers/llm/export/utils/tokenizer.py b/transformers/llm/export/utils/tokenizer.py new file mode 100644 index 0000000000..e7dba1f1b1 --- /dev/null +++ b/transformers/llm/export/utils/tokenizer.py @@ -0,0 +1,376 @@ +import os +import base64 +from transformers import PreTrainedTokenizer, AutoTokenizer + +class LlmTokenizer(PreTrainedTokenizer): + def __init__(self, tokenizer_path, model_type, **kwargs): + try: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True, use_fast=False) + except: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True, use_fast=True) + self.tokenizer_path = tokenizer_path + self.model_type = model_type + # stop_ids + self.stop_ids = [] + self.stop_ids.append(self.tokenizer.eos_token_id) + if hasattr(self.tokenizer, 'im_end_id'): + self.stop_ids.append(self.tokenizer.im_end_id) + try: + eot_id = self.tokenizer.encode('<|eot_id|>') + if len(eot_id) == 1: + self.stop_ids.append(eot_id[0]) + eot_id = self.tokenizer.encode('') + if len(eot_id) == 2 and eot_id[0] == 2: + self.stop_ids.append(eot_id[1]) + except: + pass + if hasattr(self.tokenizer, 'generation_config') and self.tokenizer.generation_config is not None: + eos_token_id = self.tokenizer.generation_config.eos_token_id + from collections.abc import Iterable + if isinstance(eos_token_id, int): + 
self.stop_ids.append(eos_token_id) + elif isinstance(eos_token_id, Iterable): + for id in eos_token_id: + self.stop_ids.append(id) + self.stop_ids = [stop_id for stop_id in self.stop_ids if stop_id is not None] + self.stop_ids = list(set(self.stop_ids)) + super().__init__(**kwargs) + + def __call__(self, *args, **kwargs): + return self.tokenizer(*args, **kwargs) + + def __getattr__(self, name): + if self.tokenizer and hasattr(self.tokenizer, name): + return getattr(self.tokenizer, name) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def _tokenize(self, text, **kwargs): + return self.tokenizer.tokenize(text, **kwargs) + + def _convert_token_to_id(self, token): + return self.tokenizer.convert_tokens_to_ids(token) + + def _convert_id_to_token(self, index): + return self.tokenizer.convert_ids_to_tokens(index) + + def get_vocab(self): + return self.tokenizer.get_vocab() + + @property + def vocab_size(self): + return self.tokenizer.vocab_size + + def id_to_str(self, token_id): + try: + word = self.tokenizer.decode(int(token_id)) + except: + def contains_replacement(text): return '\uFFFD' in text + def decode_id(token_id): + return self.tokenizer.convert_tokens_to_string( + self.tokenizer._convert_id_to_token(int(token_id))) + def decode_ids(token_ids): + return self.tokenizer.convert_tokens_to_string( + self.tokenizer.convert_ids_to_tokens(token_ids)) + word = decode_id(int(token_id)) + # Smollm tokenizer will produce half chinese character, using buffer to decode + if contains_replacement(word): + self.decode_buffer.append(token_id) + buffer_txt = decode_ids(self.decode_buffer) + if not contains_replacement(buffer_txt): + word = buffer_txt + self.decode_buffer.clear() + else: + word = '' + return word + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, model_type, **kwargs): + return cls(pretrained_model_name_or_path, model_type, **kwargs) + + def apply_chat_template(self, conversation, **kwargs): + if 
hasattr(self.tokenizer, 'apply_chat_template'): + return self.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True, **kwargs) + raise RuntimeError('Tokenizer no `apply_chat_template` funtion.') + + def save_vocabulary(self, save_directory, **kwargs): + file_path = os.path.join(save_directory, "tokenizer.txt") + # ... (rest of the save_vocabulary logic is unchanged) + return (file_path,) + + def get_chat_template(self, chat_template = None, tools = None): + return self.tokenizer.get_chat_template(chat_template, tools) + + def export(self, save_directory, model_path=None, model_type=None): + """ + Export tokenizer to MNN format with comprehensive tokenizer type support. + + Args: + save_directory: Directory to save the exported tokenizer + model_path: Optional model path for tokenizer file discovery + model_type: Optional model type for special handling + + Returns: + str: Path to the exported tokenizer file + """ + import os + import base64 + + # Use provided values or fall back to instance values + if model_path is None: + model_path = self.tokenizer_path + if model_type is None: + model_type = self.model_type + + # Create directory if it doesn't exist + os.makedirs(save_directory, exist_ok=True) + + # TOKENIZER MAGIC NUMBER + MAGIC_NUMBER = 430 + # TOKENIZER TYPE + SENTENCEPIECE = 0; TIKTOIKEN = 1; BERT = 2; HUGGINGFACE = 3 + + def write_line(fp, *args): + for arg in args: + for token in arg: + fp.write(str(token) + ' ') + fp.write('\n') + + def write_header(fp, type, speicals, prefix=[]): + fp.write(f'{MAGIC_NUMBER} {type}\n') + fp.write(f'{len(speicals)} {len(self.stop_ids)} {len(prefix)}\n') + write_line(fp, speicals, self.stop_ids, prefix) + + file_path = os.path.join(save_directory, "tokenizer.txt") + + # Collect special tokens from various sources + special_list = list(self.tokenizer.added_tokens_decoder.keys()) + if hasattr(self.tokenizer, 'special_tokens'): + for k, v in self.tokenizer.special_tokens.items(): + 
special_list.append(v) + if hasattr(self.tokenizer, 'all_special_ids'): + special_list.extend(self.tokenizer.all_special_ids) + if hasattr(self.tokenizer, 'gmask_token_id'): + special_list.append(self.tokenizer.gmask_token_id) + + # Handle generation_config special tokens + if hasattr(self.tokenizer, 'generation_config') and self.tokenizer.generation_config is not None: + generation_config = self.tokenizer.generation_config + if hasattr(generation_config, 'user_token_id'): + special_list.append(generation_config.user_token_id) + if hasattr(generation_config, 'assistant_token_id'): + special_list.append(generation_config.assistant_token_id) + + vocab_list = [] + prefix_list = [] + + # Get prefix tokens + if hasattr(self.tokenizer, 'get_prefix_tokens'): + prefix_list = self.tokenizer.get_prefix_tokens() + + # Simple prefix token detection + if len(prefix_list) == 0: + try: + test_txt = 'A' + ids = self.tokenizer.encode(test_txt) + get_txt = self.tokenizer.decode(ids[-1]) + if len(ids) > 1 and get_txt == test_txt: + prefix_list += ids[:-1] + except Exception: + pass + + # Load SentencePiece model if available + sp_model = None + tokenizer_model = os.path.join(model_path, 'tokenizer.model') + ice_text_model = os.path.join(model_path, 'ice_text.model') + + try: + import sentencepiece as spm + if os.path.exists(tokenizer_model): + sp_model = spm.SentencePieceProcessor(tokenizer_model) + elif os.path.exists(ice_text_model): + sp_model = spm.SentencePieceProcessor(ice_text_model) + except Exception: + sp_model = None + + # Check for merge file (BERT/HuggingFace tokenizers) + merge_file = os.path.join(model_path, 'merges.txt') + merge_txt = merge_file if os.path.exists(merge_file) else None + + if sp_model is not None: + # SentencePiece tokenizer export + NORMAL = 1; UNKNOWN = 2; CONTROL = 3 + USER_DEFINED = 4; UNUSED = 5; BYTE = 6 + + for i in range(sp_model.GetPieceSize()): + token = sp_model.IdToPiece(i) + score = sp_model.GetScore(i) + token_type = NORMAL + if 
sp_model.IsUnknown(i): + token_type = UNKNOWN + elif sp_model.IsControl(i): + token_type = CONTROL + elif sp_model.IsUnused(i): + token_type = UNUSED + elif sp_model.IsByte(i): + token_type = BYTE + + # Handle special cases for specific models + if model_path == 'Chatglm_6b': + if '' in token: token = '\n' + if '<|tab|>' in token: token = '\t' + if '<|blank_' in token: token = ' ' * int(token[8:token.find('|>')]) + if '▁' in token: token = token.replace('▁', ' ') + + token_encode = base64.b64encode(token.encode("utf-8")).decode("utf8") + vocab_list.append(f'{token_encode} {score} {token_type}\n') + + # Add special tokens to vocab_list + for index in special_list: + if index >= len(vocab_list): + try: + token = self.tokenizer.decode(index) + token_encode = base64.b64encode(token.encode("utf-8")).decode("utf8") + vocab_list.append(f'{token_encode} {0} {NORMAL}\n') + except: + pass + + # Write SentencePiece format + with open(file_path, "w", encoding="utf8") as fp: + write_header(fp, SENTENCEPIECE, special_list, prefix_list) + if model_type == "gemma3" or model_type == "gemma3-text": + fp.write(f'{len(vocab_list) + 1}\n') # +1 for image_soft_token + else: + fp.write(f'{len(vocab_list)}\n') + for vocab in vocab_list: + fp.write(vocab) + + elif hasattr(self.tokenizer, 'mergeable_ranks'): + # TikToken tokenizer export + vocab_list = [] + for k, v in self.tokenizer.mergeable_ranks.items(): + line = base64.b64encode(k).decode("utf8") + "\n" + vocab_list.append(line) + if hasattr(self.tokenizer, 'special_tokens'): + for k, v in self.tokenizer.special_tokens.items(): + line = base64.b64encode(k.encode("utf-8")).decode("utf8") + "\n" + vocab_list.append(line) + if hasattr(self.tokenizer, 'added_tokens_decoder'): + for k, v in self.tokenizer.added_tokens_decoder.items(): + line = base64.b64encode(v.__str__().encode("utf-8")).decode("utf8") + "\n" + vocab_list.append(line) + + # Write TikToken format + with open(file_path, "w", encoding="utf8") as fp: + write_header(fp, 
TIKTOIKEN, special_list, prefix_list) + fp.write(f'{len(vocab_list)}\n') + for vocab in vocab_list: + fp.write(vocab) + + elif merge_txt is not None: + # HuggingFace/BERT tokenizer export + merge_list = [] + vocab = self.tokenizer.get_vocab() + special_list = list(self.tokenizer.added_tokens_decoder.keys()) + vocab_list = ['' for i in range(len(vocab))] + + # Load vocab + for k, v in vocab.items(): + vocab_list[int(v)] = k + + # Load merge + with open(merge_txt, 'rt') as merge: + for line in merge.readlines(): + merge_list.append(line) + + # Write HuggingFace format + with open(file_path, "w", encoding="utf8") as fp: + write_header(fp, HUGGINGFACE, special_list) + fp.write(f'{len(vocab_list)} {len(merge_list)}\n') + for v in vocab_list: + fp.write(v + '\n') + for m in merge_list: + fp.write(m) + else: + # Auto-detect tokenizer type and export + tokenizer_class_name = type(self.tokenizer).__name__.lower() + vocab = self.tokenizer.get_vocab() + + # Check for SentencePiece-based tokenizers + if ('xlmroberta' in tokenizer_class_name or + 'roberta' in tokenizer_class_name or + 'sentencepiece' in tokenizer_class_name or + hasattr(self.tokenizer, 'sp_model') or + (hasattr(self.tokenizer, 'vocab_file') and + self.tokenizer.vocab_file and 'sentencepiece' in self.tokenizer.vocab_file.lower()) or + # Check for SentencePiece patterns (▁ prefix) + (len(vocab) > 0 and any('▁' in token for token in list(vocab.keys())[:100]))): + tokenizer_type = SENTENCEPIECE + print(f"Detected SentencePiece-based tokenizer: {tokenizer_class_name}") + elif 'bert' in tokenizer_class_name: + tokenizer_type = BERT + print(f"Detected BERT tokenizer: {tokenizer_class_name}") + else: + tokenizer_type = TIKTOIKEN + print(f"Detected TikToken tokenizer: {tokenizer_class_name}") + + vocab = self.tokenizer.get_vocab() + + if tokenizer_type == SENTENCEPIECE: + # Handle SentencePiece tokenizer + vocab_list = [] + NORMAL = 1 + + for token, token_id in sorted(vocab.items(), key=lambda x: x[1]): + try: + 
token_bytes = token.encode('utf-8') + token_b64 = base64.b64encode(token_bytes).decode('utf-8') + vocab_list.append(f'{token_b64} 0.0 {NORMAL}\n') + except Exception as e: + print(f"Warning: Failed to encode SentencePiece token '{token}': {e}") + token_b64 = base64.b64encode('▁'.encode('utf-8')).decode('utf-8') + vocab_list.append(f'{token_b64} 0.0 {NORMAL}\n') + + with open(file_path, "w", encoding="utf8") as fp: + write_header(fp, SENTENCEPIECE, special_list, prefix_list) + fp.write(f'{len(vocab_list)}\n') + for vocab_line in vocab_list: + fp.write(vocab_line) + else: + # Handle BERT or TikToken tokenizer + def unicode_to_byte(u: int): + # Handle special unicode mappings for BERT tokenizers + if u >= 256 and u <= 288: + return u - 256 + if u >= 289 and u <= 322: + return u - 162 + if u == 323: + return 173 + return u + + vocab_list = ['' for i in range(len(vocab))] + + for k, v in vocab.items(): + if tokenizer_type == BERT: + try: + vocab_list[int(v)] = k.encode('utf-8') + except Exception as e: + try: + vocab_list[int(v)] = bytes([unicode_to_byte(ord(c)) for c in k]) + except Exception as e2: + print(f"Warning: Failed to encode token '{k}' with id {v}: {e2}") + vocab_list[int(v)] = k.encode('utf-8', errors='replace') + else: + try: + vocab_list[int(v)] = bytes([unicode_to_byte(ord(c)) for c in k]) + except Exception as e2: + print(f"Warning: Failed to encode token '{k}' with id {v}: {e2}") + vocab_list[int(v)] = k.encode('utf-8', errors='replace') + + with open(file_path, "w", encoding="utf8") as fp: + write_header(fp, tokenizer_type, special_list) + fp.write(f'{len(vocab_list)}\n') + for v in vocab_list: + line = base64.b64encode(v).decode("utf8") + "\n" + fp.write(line) + + return file_path \ No newline at end of file diff --git a/transformers/llm/export/utils/transformers.py b/transformers/llm/export/utils/transformers.py index e792ec3748..1c3d527f5a 100644 --- a/transformers/llm/export/utils/transformers.py +++ b/transformers/llm/export/utils/transformers.py 
@@ -28,13 +28,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) class Attention(torch.nn.Module): - def __init__(self, attn, layer_id, config): + def __init__(self, attn, layer_id, config, rotary, mapper): super().__init__() self.export_fused_attn = False if config is None: return self.config = config self.fused_attn = FusedAttention(config.hidden_size, f'/layers.{layer_id}/self_attn/FusedAttention') self.layer_id = layer_id + self.rotary = rotary self.hidden_size = config.hidden_size self.head_dim = config.head_dim if isinstance(config.num_attention_heads, list): @@ -45,9 +46,8 @@ def __init__(self, attn, layer_id, config): self.num_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.rotary = config.rotary - ModelMapper.do_map(self, attn, config.model_map['attention']) + ModelMapper.do_map(self, attn, mapper['attention']) if hasattr(self, 'qkv_proj') and self.qkv_proj is not None: # split qkv linear to q, k, v @@ -315,7 +315,7 @@ def forward(self, position_ids): if self.theta_sections is not None: return self.mrope_forward(position_ids) position_ids = position_ids.float().reshape(-1, 1) - idx_theta = position_ids * self.theta + idx_theta = position_ids * self.theta.to(position_ids.device) rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)]) if self.model_type == 'ernie4_5': rotary_pos_emb = torch.stack((rotary_pos_emb, rotary_pos_emb), dim=-1) @@ -441,10 +441,10 @@ def forward(self, hidden_states: torch.Tensor, debug=False) -> torch.Tensor: return out class Mlp(torch.nn.Module): - def __init__(self, mlp, config, layer_id): + def __init__(self, mlp, mapper, layer_id): super().__init__() self.layer_id = layer_id - ModelMapper.do_map(self, mlp, config.model_map['mlp']) + ModelMapper.do_map(self, mlp, mapper['mlp']) 
self.is_moe = hasattr(self, 'experts') self.export_moe = False self.custom_moe = MoE(self.num_experts, self.top_k, layer_id) @@ -571,18 +571,16 @@ def forward(self, hidden_states: torch.Tensor): return final_hidden_states class Decoder(torch.nn.Module): - def __init__(self, decoder, layer_id, config): + def __init__(self, decoder, layer_id, config, rotary=None, mapper=None): super().__init__() - self.cross_decoder = False - ModelMapper.do_map(self, decoder, config.model_map['decoder']) - if 'mlp' in config.model_map: - self.mlp = Mlp(self.mlp, config, layer_id) - # mllama has cross_attn - if hasattr(self, 'cross_attn') and self.cross_attn is not None: - self.cross_decoder = True - self.self_attn = Attention(self.cross_attn, layer_id, config) - else: - self.self_attn = Attention(self.self_attn, layer_id, config) + if rotary is None: + rotary = config.rotary + if mapper is None: + mapper = config.model_map + ModelMapper.do_map(self, decoder, mapper['decoder']) + if 'mlp' in mapper: + self.mlp = Mlp(self.mlp, mapper, layer_id) + self.self_attn = Attention(self.self_attn, layer_id, config, rotary, mapper) self.hidden_size = config.hidden_size if hasattr(config, 'num_hidden_layers'): # minicpm diff --git a/transformers/llm/export/utils/vision.py b/transformers/llm/export/utils/vision.py index af0c6e7767..95c8718bbf 100644 --- a/transformers/llm/export/utils/vision.py +++ b/transformers/llm/export/utils/vision.py @@ -16,17 +16,20 @@ def __init__(self, visual, base): self.quant_block = 128 self.transformer_fuse = True self.group_conv_native = False - self.model_type = base.model_type + self.model_type = base.config.model_type self.visual = visual.eval() self.embed_ = base.embed self.tokenizer = base.tokenizer - self.config = base.config - self.hidden_size = base.hidden_size - self.llm_config = base.llm_config + self.config = base.config.origin_config + self.hidden_size = base.config.hidden_size + self.llm_config = { "is_visual": True } self.rope_ratio = 1.0 
self.init_config() self.load() + def get_config(self): + return self.llm_config + @staticmethod def get_vision(model_type): visual_models = { @@ -856,8 +859,8 @@ def __init__(self, visual, base): self.image_mean = np.array([0.5, 0.5, 0.5], dtype=np.float32) self.image_norm = np.array([0.5, 0.5, 0.5], dtype=np.float32) super().__init__(visual, base) - self.connector = base.model.model.connector.float() self.visual = self.visual.float() + self.connector = self.visual.connector.float() self.quant_bit = 8 self.transformer_fuse = False @@ -1389,4 +1392,4 @@ def export(self, onnx_path): "attention_mask": { 0: "num", 1: "size" }, "tgt_sizes": { 0: "num" } }) - return onnx_model \ No newline at end of file + return onnx_model From 2c2d7fa68675191bcdad9ebd9727e9c271b0db07 Mon Sep 17 00:00:00 2001 From: ihb2032 <1355790728@qq.com> Date: Fri, 19 Dec 2025 07:01:14 +0000 Subject: [PATCH 017/314] opt(RVV): Optimize blitter functions with intrinsics Optimize image blitter and basic format conversion functions using RVV intrinsics, including: - Gray to C3/C4 conversion - RGBA/BGRA to BGR/BGRA reordering - RGB/RGBA/BGR/BGRA to Gray conversion Signed-off-by: ihb2032 <1355790728@qq.com> Co-authored-by: lyd1992 --- source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp | 18 +++++++++++++++++ .../backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp | 13 ++++++++++++ source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp | 16 +++++++++++++++ source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp | 17 ++++++++++++++++ .../backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp | 20 +++++++++++++++++++ .../backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp | 17 ++++++++++++++++ source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp | 20 +++++++++++++++++++ 11 files changed, 201 
insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp new file mode 100644 index 0000000000..145cbea73f --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp @@ -0,0 +1,18 @@ +#include + +void MNNBGRAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp new file mode 100644 index 0000000000..d46fe6c85b --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNBGRAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = 
__riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp new file mode 100644 index 0000000000..684db6aed3 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNBRGToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, result, vl); + i += vl; + } +} \ No newline at end of file diff --git a/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp b/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp new file mode 100644 index 0000000000..9d524f13ca --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp @@ -0,0 +1,20 @@ +#include + +void MNNC3ToC4(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, 
vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); + + vuint8m8_t alpha = __riscv_vmv_v_x_u8m8(255, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, alpha, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp new file mode 100644 index 0000000000..952fcaf090 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp @@ -0,0 +1,13 @@ +#include + +void MNNGRAYToC3(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); + __riscv_vsse8_v_u8m8(dest + i * 3 + 0, 3, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 3 + 1, 3, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 3 + 2, 3, gray, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp new file mode 100644 index 0000000000..5ee4540f98 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp @@ -0,0 +1,16 @@ +#include + +void MNNGRAYToC4(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); + vuint8m8_t alpha = __riscv_vmv_v_x_u8m8(255, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 0, 4, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 1, 4, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 2, 4, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 3, 4, alpha, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp new file mode 100644 index 0000000000..f2b6c7a78d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp @@ -0,0 +1,17 @@ +#include + +void 
MNNRGBAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp new file mode 100644 index 0000000000..ddd67a7d8c --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp @@ -0,0 +1,20 @@ +#include + +void MNNRGBAToBGRA(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 3, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp new file mode 100644 index 0000000000..d56b58546d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNRGBAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); + vuint16m8_t 
sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp new file mode 100644 index 0000000000..7c6decf39e --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp @@ -0,0 +1,17 @@ +#include + +void MNNRGBToBGR(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp new file mode 100644 index 0000000000..1b946c33cc --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNRGBToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + 
+ vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, result, vl); + i += vl; + } +} From 93709a1b29647b45c4a62ee3396bedf95a5886f2 Mon Sep 17 00:00:00 2001 From: "zhaode.wzd" Date: Fri, 19 Dec 2025 16:01:30 +0800 Subject: [PATCH 018/314] [LLM:Feature] Support Context info File. --- docs/transformers/llm.md | 16 ++++++++++++++++ transformers/llm/engine/src/llm.cpp | 22 ++++++++++++++++++++++ transformers/llm/engine/src/llmconfig.hpp | 4 ++++ transformers/llm/engine/src/tokenizer.cpp | 9 ++++++--- transformers/llm/export/utils/config.py | 1 + transformers/llm/export/utils/model.py | 4 ++-- 6 files changed, 51 insertions(+), 5 deletions(-) diff --git a/docs/transformers/llm.md b/docs/transformers/llm.md index e2ec2c94c3..c37cb50afc 100644 --- a/docs/transformers/llm.md +++ b/docs/transformers/llm.md @@ -383,6 +383,7 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt - dit_model: 当使用Omni模型时,dit_model的实际路径为`base_dir + dit_model`,默认为`base_dir + 'dit.mnn'` - bigvgan_model: 当使用Omni模型时,bigvgan_model的实际路径为`base_dir + bigvgan_model`,默认为`base_dir + 'bigvgan.mnn'` - spk_dict: 当使用Omni模型时,spk_dict的实际路径为`base_dir + spk_dict`,默认为`base_dir + 'spk_dict.txt'` + - context_file: 配置上下文信息文件路径,实际路径为`base_dir + context_file`,默认`base_dir + 'context.json'`,内容格式为json格式的上下文信息,包含:如tools,enable_thinking等信息。 - 推理配置 - max_new_tokens: 生成时最大token数,默认为`512` - reuse_kv: 多轮对话时是否复用之前对话的`kv cache`,默认为`false`. 
@@ -477,6 +478,21 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt "is_single": true } ``` +- `context.json` + ```json + { + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_time", + "description": "获取当前时间" + } + } + ], + "enable_thinking": false + } + ``` #### 推理用法 `llm_demo`的用法如下: diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp index 61a3715569..53af11239a 100644 --- a/transformers/llm/engine/src/llm.cpp +++ b/transformers/llm/engine/src/llm.cpp @@ -234,6 +234,26 @@ bool Llm::load() { // init module status // 1. load vocab mTokenizer.reset(Tokenizer::createTokenizer(mConfig->tokenizer_file())); + // 2. load context + { + std::ifstream contextFile(mConfig->context_file()); + if (contextFile.is_open()) { + std::ostringstream contextStream; + contextStream << contextFile.rdbuf(); + auto contextStr = contextStream.str(); + // check valid json + rapidjson::Document contextDoc; + contextDoc.Parse(contextStr.c_str()); + if (!contextDoc.HasParseError()) { + std::string config_json = R"({ + "jinja": { + "context": )" + contextStr + R"( + } + })"; + mConfig->config_.merge(config_json.c_str()); + } + } + } mDiskEmbedding.reset(new DiskEmbedding(mConfig)); mPrompt.reset(Prompt::createPrompt(mContext, mConfig)); mSampler.reset(Sampler::createSampler(mContext, mConfig)); @@ -870,6 +890,8 @@ void Llm::response(const std::string& user_content, std::ostream* os, const char if (mConfig->use_template()) { prompt = mPrompt->applyTemplate(user_content, true); } + std::cout << "user_content: " << user_content << std::endl; + std::cout << "prompt: " << prompt << std::endl; std::vector input_ids = tokenizer_encode(prompt); response(input_ids, os, end_with, max_new_tokens); } diff --git a/transformers/llm/engine/src/llmconfig.hpp b/transformers/llm/engine/src/llmconfig.hpp index 97eb044a40..ad0686fb4a 100644 --- a/transformers/llm/engine/src/llmconfig.hpp +++ 
b/transformers/llm/engine/src/llmconfig.hpp @@ -315,6 +315,10 @@ class LlmConfig { std::string audio_model() const { return base_dir_ + config_.value("audio_model", "audio.mnn"); } + + std::string context_file() const { + return base_dir_ + config_.value("context_file", "context.json"); + } // model file config end > // < generate config start diff --git a/transformers/llm/engine/src/tokenizer.cpp b/transformers/llm/engine/src/tokenizer.cpp index bd1a465d7e..ffbb7d7c04 100644 --- a/transformers/llm/engine/src/tokenizer.cpp +++ b/transformers/llm/engine/src/tokenizer.cpp @@ -430,6 +430,9 @@ void Sentencepiece::encode(const std::string& str, std::vector& ids) { } std::string Sentencepiece::decode(int id) { + if (id < 0 || id >= static_cast(sentence_pieces_.size())) { + return ""; + } auto piece = sentence_pieces_[id].piece; int pos = piece.find("▁"); if (pos != -1) { @@ -481,7 +484,7 @@ void Tiktoken::encode(const std::string& str, std::vector& ids) { } std::string Tiktoken::decode(int id) { - if (id >= decoder_.size()) { + if (id < 0 || id >= static_cast(decoder_.size())) { return ""; } return decoder_[id]; @@ -503,7 +506,7 @@ bool BertTokenizer::load_vocab(std::ifstream& tok_file) { } std::string BertTokenizer::decode(int id) { - if (id >= decoder_.size()) { + if (id < 0 || id >= static_cast(decoder_.size())) { return ""; } return decoder_[id]; @@ -926,7 +929,7 @@ void HuggingfaceTokenizer::encode(const std::string& str, std::vector& ids) std::string HuggingfaceTokenizer::decode(int id) { // printf("decode id = %d, %lu, %s#\n", id, decoder_.size(), decoder_.at(id).c_str()); - if (id >= decoder_.size()) { + if (id < 0 || id >= static_cast(decoder_.size())) { return ""; } auto decode_utf8 = decoder_.at(id); diff --git a/transformers/llm/export/utils/config.py b/transformers/llm/export/utils/config.py index 569cab745e..b747ad6ee7 100644 --- a/transformers/llm/export/utils/config.py +++ b/transformers/llm/export/utils/config.py @@ -65,6 +65,7 @@ def 
from_pretrained(cls, pretrained_model_name_or_path, **kwargs): llm_config.attention_type = 'sliding' elif len(sliding_attn_layers) > 0: llm_config.attention_type = 'mix' + llm_config.sliding_attn_layers = sliding_attn_layers else: llm_config.attention_type = 'full' diff --git a/transformers/llm/export/utils/model.py b/transformers/llm/export/utils/model.py index a327109e65..63e14dec30 100644 --- a/transformers/llm/export/utils/model.py +++ b/transformers/llm/export/utils/model.py @@ -117,8 +117,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, args=None, **kwargs): ]) model.lm = Lm(model.lm) - if 'gemma' in model_type: - model.scale_emb = model.embed.embedscale + if 'gemma' in model_type and hasattr(model.embed, 'embed_scale'): + model.scale_emb = model.embed.embed_scale # Multi-modal parts if model.visual is not None: From 2ec9dbe5ccee13284d7f3ea89862333cff44a812 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8B=A5=E9=81=97?= Date: Fri, 19 Dec 2025 18:05:00 +0800 Subject: [PATCH 019/314] taoavatar support supertonic tts --- apps/Android/MnnTaoAvatar/README.md | 3 + apps/Android/MnnTaoAvatar/README_CN.md | 3 + apps/Android/MnnTaoAvatar/app/build.gradle | 7 +- .../app/src/main/cpp/tts/CMakeLists.txt | 15 +- .../main/cpp/tts/include/mnn_tts_config.hpp | 11 +- .../src/main/cpp/tts/include/mnn_tts_sdk.hpp | 9 +- .../supertonic/mnn_supertonic_tts_impl.hpp | 8 +- .../src/main/cpp/tts/src/mnn_tts_config.cpp | 36 +++- .../app/src/main/cpp/tts/src/mnn_tts_sdk.cpp | 84 +++++++-- .../supertonic/mnn_supertonic_tts_impl.cpp | 97 ++++++++-- .../src/main/cpp/tts_droid/tts_service.cpp | 13 +- .../src/main/cpp/tts_droid/tts_service.hpp | 4 +- .../main/cpp/tts_droid/tts_service_jni.cpp | 28 ++- .../java/com/taobao/meta/avatar/MHConfig.kt | 5 +- .../com/taobao/meta/avatar/MainActivity.kt | 2 +- .../taobao/meta/avatar/debug/DebugModule.kt | 4 +- .../meta/avatar/download/DownloadManager.kt | 174 ------------------ .../meta/avatar/download/DownloadModule.kt | 3 +- 
.../meta/avatar/settings/MainSettings.kt | 32 +++- .../avatar/settings/MainSettingsFragment.kt | 30 +++ .../com/taobao/meta/avatar/tts/TtsService.kt | 116 +++++++++++- .../app/src/main/res/values-zh/arrays.xml | 17 ++ .../app/src/main/res/values-zh/strings.xml | 6 + .../app/src/main/res/values/arrays.xml | 17 ++ .../app/src/main/res/values/strings.xml | 6 + .../src/main/res/xml/main_settings_prefs.xml | 20 ++ 26 files changed, 509 insertions(+), 241 deletions(-) delete mode 100644 apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/download/DownloadManager.kt create mode 100644 apps/Android/MnnTaoAvatar/app/src/main/res/values-zh/arrays.xml create mode 100644 apps/Android/MnnTaoAvatar/app/src/main/res/values/arrays.xml diff --git a/apps/Android/MnnTaoAvatar/README.md b/apps/Android/MnnTaoAvatar/README.md index 2eb1d9c3d5..5b8fc09a27 100644 --- a/apps/Android/MnnTaoAvatar/README.md +++ b/apps/Android/MnnTaoAvatar/README.md @@ -64,6 +64,9 @@ cd apps/Android/MnnTaoAvatar ``` ## Releases +## Version 0.0.2 ++ Click here to [download](https://meta.alicdn.com/data/mnn/avatar/mnn_taoavatar_0_0_2.apk) ++ support supertonic tts for TaoAvatar ## Version 0.0.1 + Click here to [download](https://meta.alicdn.com/data/mnn/avatar/mnn_taoavatar_0_0_1.apk) + this is our first public released version; you can chat with 3d avatar in the app with asr and tts; if you have any questions, please feel free to open an issue for assistance. 
diff --git a/apps/Android/MnnTaoAvatar/README_CN.md b/apps/Android/MnnTaoAvatar/README_CN.md index e6f7cc2af9..2f02e56bef 100644 --- a/apps/Android/MnnTaoAvatar/README_CN.md +++ b/apps/Android/MnnTaoAvatar/README_CN.md @@ -54,6 +54,9 @@ cd apps/Android/MnnTaoAvatar ## Releases +## Version 0.0.2 ++ 点击这里[下载](https://meta.alicdn.com/data/mnn/avatar/mnn_taoavatar_0_0_2.apk) ++ 新增对TaoAvatar的Supertonic TTS支持 ## Version 0.0.1 + 点击这里[下载](https://meta.alicdn.com/data/mnn/avatar/mnn_taoavatar_0_0_1.apk) + 这是我们首次公开发布的版本;您可以在应用程序中通过语音识别(ASR)和语音合成(TTS)与3D虚拟形象进行聊天;如果您有任何问题,请随时提交Issue以获得帮助。 diff --git a/apps/Android/MnnTaoAvatar/app/build.gradle b/apps/Android/MnnTaoAvatar/app/build.gradle index 423c5ab54e..45fc04e7a8 100644 --- a/apps/Android/MnnTaoAvatar/app/build.gradle +++ b/apps/Android/MnnTaoAvatar/app/build.gradle @@ -55,7 +55,7 @@ android { minSdk 26 targetSdk 35 versionCode 1 - versionName "0.0.1" + versionName "0.0.2" externalNativeBuild { cmake { @@ -99,9 +99,4 @@ dependencies { implementation 'com.squareup.retrofit2:retrofit:2.9.0' implementation 'com.squareup.okhttp3:logging-interceptor:4.9.3' implementation 'com.squareup.retrofit2:converter-scalars:2.9.0' - implementation "com.liulishuo.okdownload:okdownload:${okdownload_version}" - implementation "com.liulishuo.okdownload:sqlite:${okdownload_version}" - implementation "com.liulishuo.okdownload:okhttp:${okdownload_version}" - implementation "com.liulishuo.okdownload:filedownloader:${okdownload_version}" - implementation "com.liulishuo.okdownload:ktx:${okdownload_version}" } diff --git a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/CMakeLists.txt b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/CMakeLists.txt index 6432d02b92..c7ec0d4c2c 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/CMakeLists.txt +++ b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/CMakeLists.txt @@ -8,6 +8,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) option(BUILD_BERTVITS2 "Build BertVit2 TTS " ON) option(BUILD_PIPER "Build PIPER TTS " 
OFF) +option(BUILD_SUPERTONIC "Build Supertonic TTS " ON) include_directories( ${CMAKE_CURRENT_LIST_DIR}/include @@ -27,6 +28,12 @@ if(BUILD_PIPER) ) endif() +if(BUILD_SUPERTONIC) + include_directories( + ${CMAKE_CURRENT_LIST_DIR}/include/supertonic + ) +endif() + set(SHARED_SOURCE_FILES ${CMAKE_CURRENT_LIST_DIR}/src/mnn_tts_config.cpp ${CMAKE_CURRENT_LIST_DIR}/src/mnn_tts_sdk.cpp @@ -55,9 +62,15 @@ set(PIPER_SOURCE_FILES ) endif() +if(BUILD_SUPERTONIC) +set(SUPERTONIC_SOURCE_FILES + ${CMAKE_CURRENT_LIST_DIR}/src/supertonic/mnn_supertonic_tts_impl.cpp +) +endif() + if(BUILD_PIPER) add_subdirectory(third_party/piper/espeak-ng) endif() -add_library(${PROJECT_NAME} SHARED ${PIPER_SOURCE_FILES} ${BERTVITS2_SOURCE_FILES} ${SHARED_SOURCE_FILES}) +add_library(${PROJECT_NAME} SHARED ${PIPER_SOURCE_FILES} ${BERTVITS2_SOURCE_FILES} ${SUPERTONIC_SOURCE_FILES} ${SHARED_SOURCE_FILES}) target_link_libraries(${PROJECT_NAME} log MNN ) diff --git a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/mnn_tts_config.hpp b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/mnn_tts_config.hpp index 8df703c026..c90e028f0d 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/mnn_tts_config.hpp +++ b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/mnn_tts_config.hpp @@ -17,6 +17,13 @@ class MNNTTSConfig { public: explicit MNNTTSConfig(const std::string &config_file_path); + + // 支持参数覆盖的构造函数 + MNNTTSConfig(const std::string &config_file_path, + const std::map &overrides); + + // 应用参数覆盖 + void applyOverrides(const std::map &overrides); // 模板方法的实现必须放在头文件中或者在源文件中模板实例化 template @@ -46,8 +53,4 @@ class MNNTTSConfig std::string asset_folder_; std::string cache_folder_; int sample_rate_; - std::string precision_; - std::string speaker_id_; - int iter_steps_; - float speed_; }; \ No newline at end of file diff --git a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/mnn_tts_sdk.hpp b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/mnn_tts_sdk.hpp index 
21246a0152..15c7f39b99 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/mnn_tts_sdk.hpp +++ b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/mnn_tts_sdk.hpp @@ -14,13 +14,20 @@ class MNNTTSSDK { public: - MNNTTSSDK(const std::string &config_folder); + MNNTTSSDK(const std::string &config_folder, const std::string ¶ms_json = "{}"); // synthesize audio std::tuple Process(const std::string &text); void WriteAudioToFile(const Audio &audio_data, const std::string &output_file_path); + + // Set speaker ID dynamically (only for supertonic model) + void SetSpeakerId(const std::string &speaker_id); private: int sample_rate_; std::shared_ptr impl_; + std::string model_type_; // Store model type for SetSpeakerId + + // 辅助方法:JSON 字符串转 map + std::map parseJsonToMap(const std::string &json_str); }; diff --git a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/supertonic/mnn_supertonic_tts_impl.hpp b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/supertonic/mnn_supertonic_tts_impl.hpp index 24b2c72553..723dcb3382 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/supertonic/mnn_supertonic_tts_impl.hpp +++ b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/include/supertonic/mnn_supertonic_tts_impl.hpp @@ -39,7 +39,8 @@ struct VoiceStyle class MNNSupertonicTTSImpl : public MNNTTSImplBase { public: - MNNSupertonicTTSImpl(const std::string &models_dir, const std::string &precision_dir, const std::string &speaker_id, int iter_steps, float speed); + MNNSupertonicTTSImpl(const std::string &models_dir, + const std::map &overrides = {}); // Core Synthesis Interface std::tuple Process(const std::string &text); @@ -48,9 +49,8 @@ class MNNSupertonicTTSImpl : public MNNTTSImplBase std::tuple synthesize(const std::string &text, const VoiceStyle &voice_styl, int steps, float speed); - // Save Audio - static bool save(const std::string &filename, - const std::vector &audio_data, int sample_rate); + // Set speaker ID dynamically (no restart 
required) + void SetSpeakerId(const std::string &speaker_id); private: // --- Configuration --- diff --git a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/mnn_tts_config.cpp b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/mnn_tts_config.cpp index 75490a6b8c..dd8398fb3a 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/mnn_tts_config.cpp +++ b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/mnn_tts_config.cpp @@ -34,11 +34,6 @@ MNNTTSConfig::MNNTTSConfig(const std::string &config_json_path) asset_folder_ = get_value_from_json(raw_config_data_, "asset_folder"); cache_folder_ = get_value_from_json(raw_config_data_, "cache_folder"); sample_rate_ = get_value_from_json(raw_config_data_, "sample_rate"); - precision_ = get_value_from_json(raw_config_data_, "precision"); - speaker_id_ = get_value_from_json(raw_config_data_, "speaker_id"); - iter_steps_ = get_value_from_json(raw_config_data_, "iter_steps"); - speed_ = get_value_from_json(raw_config_data_, "speed"); - } catch (const std::runtime_error &e) { @@ -46,3 +41,34 @@ MNNTTSConfig::MNNTTSConfig(const std::string &config_json_path) throw std::runtime_error("Error in config file " + config_json_path + ": " + e.what()); } } + +// 新增:支持参数覆盖的构造函数 +MNNTTSConfig::MNNTTSConfig(const std::string &config_file_path, + const std::map &overrides) + : MNNTTSConfig(config_file_path) // 委托给原构造函数 +{ + // 应用参数覆盖 + applyOverrides(overrides); +} + +// 应用参数覆盖 +void MNNTTSConfig::applyOverrides(const std::map &overrides) { + if (overrides.empty()) { + return; + } + + for (const auto& [key, value] : overrides) { + try { + if (key == "model_type") { + model_type_ = value; + } else if (key == "sample_rate") { + sample_rate_ = std::stoi(value); + } + // 可以继续添加其他参数的覆盖逻辑 + } catch (const std::exception &e) { + // 忽略无法转换的参数,使用配置文件中的默认值 + std::cerr << "Warning: Failed to override parameter '" << key + << "' with value '" << value << "': " << e.what() << std::endl; + } + } +} diff --git 
a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/mnn_tts_sdk.cpp b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/mnn_tts_sdk.cpp index 478a5ffd71..2609ea067e 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/mnn_tts_sdk.cpp +++ b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/mnn_tts_sdk.cpp @@ -1,36 +1,77 @@ #include "mnn_tts_sdk.hpp" #include "piper/utf8.h" #include "supertonic/mnn_supertonic_tts_impl.hpp" +#include "mnn_tts_logger.hpp" +#include "nlohmann/json.hpp" #include // For std::wstring_convert and std::codecvt_utf8 #include #include -MNNTTSSDK::MNNTTSSDK(const std::string &config_folder) +using json = nlohmann::json; + +// 辅助方法:将 JSON 字符串解析为 std::map +std::map MNNTTSSDK::parseJsonToMap(const std::string &json_str) { + std::map result; + + if (json_str.empty() || json_str == "{}") { + return result; + } + + try { + json j = json::parse(json_str); + + // 遍历所有键值对,全部转换为 string + for (auto& [key, value] : j.items()) { + if (value.is_string()) { + result[key] = value.get(); + } else if (value.is_number_integer()) { + result[key] = std::to_string(value.get()); + } else if (value.is_number_float()) { + result[key] = std::to_string(value.get()); + } else if (value.is_boolean()) { + result[key] = value.get() ? "true" : "false"; + } else { + // 其他类型转为字符串 + result[key] = value.dump(); + } + } + } catch (const json::exception &e) { + PLOG(ERROR, "Failed to parse JSON: " + std::string(e.what())); + } + + return result; +} + +MNNTTSSDK::MNNTTSSDK(const std::string &config_folder, const std::string ¶ms_json) { std::string config_json_path = config_folder + "/config.json"; - auto config = MNNTTSConfig(config_json_path); - auto model_type = config.model_type_; + + // 1. 解析 JSON 为 map + auto overrides = parseJsonToMap(params_json); + + // 2. 
创建 MNNTTSConfig,传入 overrides + auto config = MNNTTSConfig(config_json_path, overrides); + + model_type_ = config.model_type_; auto model_path = config_folder + "/" + config.model_path_; auto assset_folder = config_folder + "/" + config.asset_folder_; auto cache_folder = config_folder + "/" + config.cache_folder_; sample_rate_ = config.sample_rate_; - if (model_type == "piper") + + if (model_type_ == "piper") { impl_ = nullptr; // std::make_shared(assset_folder, model_path, cache_folder); } - else if (model_type == "bertvits") + else if (model_type_ == "bertvits") { impl_ = std::make_shared(assset_folder, model_path, cache_folder); } - else if (model_type == "supertonic") + else if (model_type_ == "supertonic") { auto model_dir = config_folder; - std::string precision = config.precision_; - std::string speaker_id = config.speaker_id_; - int iter_steps = config.iter_steps_; - float speed = config.speed_; - impl_ = std::make_shared(model_dir, precision, speaker_id, iter_steps, speed); + // Pass overrides to MNNSupertonicTTSImpl, which will read precision, speaker_id, iter_steps, speed from config.json + impl_ = std::make_shared(model_dir, overrides); } else { @@ -54,3 +95,24 @@ void MNNTTSSDK::WriteAudioToFile(const Audio &audio_data, audioFile.write((const char *)audio_data.data(), sizeof(int16_t) * audio_data.size()); } + +void MNNTTSSDK::SetSpeakerId(const std::string &speaker_id) +{ + // Only supertonic model supports dynamic speaker_id change + if (model_type_ != "supertonic") + { + PLOG(WARNING, "SetSpeakerId is only supported for supertonic model, current model: " + model_type_); + return; + } + + // Cast to MNNSupertonicTTSImpl and call SetSpeakerId + auto supertonic_impl = std::dynamic_pointer_cast(impl_); + if (supertonic_impl) + { + supertonic_impl->SetSpeakerId(speaker_id); + } + else + { + PLOG(ERROR, "Failed to cast impl to MNNSupertonicTTSImpl"); + } +} diff --git 
a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/supertonic/mnn_supertonic_tts_impl.cpp b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/supertonic/mnn_supertonic_tts_impl.cpp index 46f28a08e0..af8df6e048 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/supertonic/mnn_supertonic_tts_impl.cpp +++ b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts/src/supertonic/mnn_supertonic_tts_impl.cpp @@ -29,7 +29,7 @@ namespace bool is_emoji(char32_t cp) { if (cp >= 0x1F600 && cp <= 0x1F64F) - return true; // emoticons + return true; // emoticons if (cp >= 0x1F300 && cp <= 0x1F5FF) return true; // symbols & pictographs if (cp >= 0x1F680 && cp <= 0x1F6FF) @@ -295,23 +295,82 @@ MNNSupertonicTTSImpl::TextProcessor::encode(const std::string &text) // MNNSupertonicTTSImpl implementation MNNSupertonicTTSImpl::MNNSupertonicTTSImpl( const std::string &models_dir, - const std::string &precision_dir, - const std::string &speaker_id, - int iter_steps, - float speed) - : models_dir_(models_dir), - precision_dir_(precision_dir), - speaker_id_(speaker_id), - iter_steps_(iter_steps), - speed_(speed) + const std::map &overrides) + : models_dir_(models_dir) { + PLOG(INFO, "Initializing Supertonic TTS with models_dir: " + models_dir_); + + // Load config.json to get precision, speaker_id, iter_steps, speed + std::string config_json_path = models_dir_ + "/config.json"; + json config_json; + + // Try to read config.json + std::ifstream config_json_file(config_json_path); + if (config_json_file.is_open()) { + try { + config_json_file >> config_json; + config_json_file.close(); + } catch (const std::exception &e) { + PLOG(WARNING, "Failed to parse config.json: " + std::string(e.what())); + } + } else { + PLOG(WARNING, "config.json not found, using defaults"); + } + + // Get precision: from overrides, then config.json, then default + if (overrides.find("precision") != overrides.end() && !overrides.at("precision").empty()) { + precision_dir_ = overrides.at("precision"); + } else 
if (config_json.contains("precision") && config_json["precision"].is_string()) { + precision_dir_ = config_json["precision"].get(); + } else { + precision_dir_ = "fp16"; // default + } + + // Get speaker_id: from overrides, then config.json, then default + if (overrides.find("speaker_id") != overrides.end() && !overrides.at("speaker_id").empty()) { + speaker_id_ = overrides.at("speaker_id"); + } else if (config_json.contains("speaker_id") && config_json["speaker_id"].is_string()) { + speaker_id_ = config_json["speaker_id"].get(); + } else { + speaker_id_ = "M1"; // default + } + + // Get iter_steps: from overrides, then config.json, then default + if (overrides.find("iter_steps") != overrides.end() && !overrides.at("iter_steps").empty()) { + try { + iter_steps_ = std::stoi(overrides.at("iter_steps")); + } catch (const std::exception &e) { + PLOG(WARNING, "Failed to parse iter_steps from overrides, using default"); + iter_steps_ = 10; // default + } + } else if (config_json.contains("iter_steps") && config_json["iter_steps"].is_number_integer()) { + iter_steps_ = config_json["iter_steps"].get(); + } else { + iter_steps_ = 10; // default + } + + // Get speed: from overrides, then config.json, then default + if (overrides.find("speed") != overrides.end() && !overrides.at("speed").empty()) { + try { + speed_ = std::stof(overrides.at("speed")); + } catch (const std::exception &e) { + PLOG(WARNING, "Failed to parse speed from overrides, using default"); + speed_ = 1.0f; // default + } + } else if (config_json.contains("speed") && config_json["speed"].is_number_float()) { + speed_ = config_json["speed"].get(); + } else { + speed_ = 1.0f; // default + } std::cout << "model_dir_" << models_dir_ << std::endl; std::cout << "precsion_dir: " << precision_dir_ << std::endl; + std::cout << "speaker_id: " << speaker_id_ << std::endl; + std::cout << "iter_steps: " << iter_steps_ << std::endl; + std::cout << "speed: " << speed_ << std::endl; std::cout << "cache_dir_: " << cache_dir_ 
<< std::endl; - PLOG(INFO, "Initializing Supertonic TTS with models_dir: " + models_dir_); - // Load config + // Load tts.json config std::string config_path = models_dir_ + "/mnn_models/tts.json"; std::ifstream config_file(config_path); if (!config_file.is_open()) @@ -482,11 +541,17 @@ std::tuple MNNSupertonicTTSImpl::Process(const std::string &text) return synthesize(processed_text, voice_styles_[voice_name], steps, speed); } -bool MNNSupertonicTTSImpl::save(const std::string &filename, - const std::vector &audio_data, - int sample_rate) +void MNNSupertonicTTSImpl::SetSpeakerId(const std::string &speaker_id) { - return writeWavFile(filename, audio_data, sample_rate); + // Validate speaker_id exists in voice_styles_ + if (voice_styles_.find(speaker_id) == voice_styles_.end()) + { + PLOG(ERROR, "Cannot set speaker_id to invalid value: " + speaker_id); + throw std::runtime_error("Invalid speaker_id: " + speaker_id); + } + + speaker_id_ = speaker_id; + PLOG(INFO, "Speaker ID changed to: " + speaker_id_); } std::tuple diff --git a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts_droid/tts_service.cpp b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts_droid/tts_service.cpp index 8f867fc46f..d762026ead 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts_droid/tts_service.cpp +++ b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts_droid/tts_service.cpp @@ -11,11 +11,12 @@ TTSService::~TTSService() { tts_ = nullptr; } -bool TTSService::LoadTtsResources(const char *resPath, const char* modelName, const char* cacheDir) { +bool TTSService::LoadTtsResources(const char *resPath, const char* modelName, + const char* cacheDir, const std::string ¶msJson) { MH_DEBUG("TTSService::LoadTtsResources resPath: %s", resPath); + MH_DEBUG("TTSService::LoadTtsResources paramsJson: %s", paramsJson.c_str()); if (!tts_) { - tts_ = std::make_shared( - std::string(resPath)); + tts_ = std::make_shared(std::string(resPath), paramsJson); } if (!tts_) { MH_ERROR("Failed to create TTSService."); @@ 
-42,6 +43,12 @@ void TTSService::SetIndex(int index) { current_index_ = index; } +void TTSService::SetSpeakerId(const std::string &speaker_id) { + if (tts_) { + tts_->SetSpeakerId(speaker_id); + } +} + std::vector TTSService::Process(const std::string &text, int id) { if (tts_ != nullptr && (!text.empty())) { auto audio = tts_->Process(text); diff --git a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts_droid/tts_service.hpp b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts_droid/tts_service.hpp index e90d323ae4..f6e66b30e7 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts_droid/tts_service.hpp +++ b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts_droid/tts_service.hpp @@ -10,9 +10,11 @@ namespace TaoAvatar { class TTSService { public: explicit TTSService(std::string language); - bool LoadTtsResources(const char *resPath, const char* modelName, const char* cacheDir); + bool LoadTtsResources(const char *resPath, const char* modelName, + const char* cacheDir, const std::string ¶msJson = "{}"); // 新增参数 std::vector Process(const std::string &text, int id); void SetIndex(int index); + void SetSpeakerId(const std::string &speaker_id); // 动态设置音色(仅英文模式) virtual ~TTSService(); private: std::shared_ptr tts_ = nullptr; diff --git a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts_droid/tts_service_jni.cpp b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts_droid/tts_service_jni.cpp index 96b7e692cb..e53bee77b9 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts_droid/tts_service_jni.cpp +++ b/apps/Android/MnnTaoAvatar/app/src/main/cpp/tts_droid/tts_service_jni.cpp @@ -31,19 +31,28 @@ Java_com_taobao_meta_avatar_tts_TtsService_nativeLoadResourcesFromFile(JNIEnv *e jlong nativePtr, jstring resourceDir, jstring modelName, - jstring cacheDir) { + jstring cacheDir, + jstring paramsJson) { // 新增:JSON 参数字符串 std::unique_lock lock(gTTSMutex); auto ttsService = reinterpret_cast(nativePtr); const char *resourceDirCStr = env->GetStringUTFChars(resourceDir, nullptr); const char 
*modelNameCStr = env->GetStringUTFChars(modelName, nullptr); const char *cacheDirCStr = env->GetStringUTFChars(cacheDir, nullptr); + const char *paramsJsonCStr = env->GetStringUTFChars(paramsJson, nullptr); + bool result = false; if (ttsService) { - result = ttsService->LoadTtsResources(resourceDirCStr, modelNameCStr, cacheDirCStr); + std::string paramsJsonStr = (paramsJsonCStr != nullptr) ? + std::string(paramsJsonCStr) : "{}"; + result = ttsService->LoadTtsResources(resourceDirCStr, modelNameCStr, + cacheDirCStr, paramsJsonStr); } - env->ReleaseStringUTFChars(modelName, modelNameCStr); + env->ReleaseStringUTFChars(resourceDir, resourceDirCStr); + env->ReleaseStringUTFChars(modelName, modelNameCStr); env->ReleaseStringUTFChars(cacheDir, cacheDirCStr); + env->ReleaseStringUTFChars(paramsJson, paramsJsonCStr); + return result ? JNI_TRUE : JNI_FALSE; } @@ -70,4 +79,17 @@ Java_com_taobao_meta_avatar_tts_TtsService_nativeSetCurrentIndex(JNIEnv *env, jo tts_service->SetIndex(index); } +JNIEXPORT void JNICALL +Java_com_taobao_meta_avatar_tts_TtsService_nativeSetSpeakerId(JNIEnv *env, jobject thiz, + jlong nativePtr, + jstring speakerId) { + std::unique_lock lock(gTTSMutex); + auto ttsService = reinterpret_cast(nativePtr); + const char *speakerIdCStr = env->GetStringUTFChars(speakerId, nullptr); + if (ttsService) { + ttsService->SetSpeakerId(std::string(speakerIdCStr)); + } + env->ReleaseStringUTFChars(speakerId, speakerIdCStr); +} + } \ No newline at end of file diff --git a/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/MHConfig.kt b/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/MHConfig.kt index 62cbc024bf..86482be54e 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/MHConfig.kt +++ b/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/MHConfig.kt @@ -20,9 +20,12 @@ object MHConfig { val TTS_MODEL_DIR get() = "${BASE_DIR}/bert-vits2-MNN/" + val TTS_MODEL_DIR_EN + get() = 
"${BASE_DIR}/supertonic-tts-mnn/" + val A2BS_MODEL_DIR get() = "${BASE_DIR}/UniTalker-MNN/" - +// /data/data/com.taobao.meta.avatar/files/.mnnmodels/modelscope/supertonic-tts-mnn val LLM_MODEL_DIR get() = "${BASE_DIR}/Qwen2.5-1.5B-Instruct-MNN" diff --git a/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/MainActivity.kt b/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/MainActivity.kt index fc7ab2c44a..2cbc8e6c73 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/MainActivity.kt +++ b/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/MainActivity.kt @@ -360,7 +360,7 @@ class MainActivity : AppCompatActivity(), } private suspend fun loadTTSModel() { - ttsService!!.init(MHConfig.TTS_MODEL_DIR) + ttsService!!.init(MHConfig.TTS_MODEL_DIR, context = this) } private suspend fun loadA2BSModel() { diff --git a/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/debug/DebugModule.kt b/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/debug/DebugModule.kt index fc3edf082b..b2109afe9b 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/debug/DebugModule.kt +++ b/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/debug/DebugModule.kt @@ -122,7 +122,7 @@ class DebugModule { var ttsService:TtsService? 
= null if (DEBUG_DISABLE_SERVICE_AUTO_START) { ttsService = TtsService() - ttsService.init(MHConfig.TTS_MODEL_DIR) + ttsService.init(MHConfig.TTS_MODEL_DIR, context = activity) ttsService.waitForInitComplete() } else { ttsService = activity.getTtsService() @@ -214,7 +214,7 @@ class DebugModule { } suspend fun testKokoroZhEn() { - ttsService.init(MHConfig.TTS_MODEL_DIR) + ttsService.init(MHConfig.TTS_MODEL_DIR, context = activity) ttsService.waitForInitComplete() val tts_path = "/data/local/tmp/kokoro-multi-lang-v1_0" val config = OfflineTtsConfig( diff --git a/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/download/DownloadManager.kt b/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/download/DownloadManager.kt deleted file mode 100644 index d4611b8cfb..0000000000 --- a/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/download/DownloadManager.kt +++ /dev/null @@ -1,174 +0,0 @@ -// Created by ruoyi.sjd on 2025/3/31. -// Copyright (c) 2024 Alibaba Group Holding Limited All rights reserved. 
- -package com.taobao.meta.avatar.download - -import android.content.Context -import android.util.Log -import com.alibaba.mnnllm.android.utils.FileUtils.formatFileSize -import com.liulishuo.okdownload.DownloadListener -import com.liulishuo.okdownload.DownloadTask -import com.liulishuo.okdownload.OkDownload -import com.liulishuo.okdownload.core.breakpoint.BreakpointInfo -import com.liulishuo.okdownload.core.breakpoint.BreakpointStoreOnSQLite -import com.liulishuo.okdownload.core.cause.EndCause -import com.liulishuo.okdownload.core.cause.ResumeFailedCause -import com.liulishuo.okdownload.core.dispatcher.DownloadDispatcher -import java.io.File - -class DownloadManager(private val context: Context) { - - companion object { - const val TAG = "DownloadManager" - private const val MODEL_URL = "https://meta.alicdn.com/data/mnn/avatar/qwen2.5-1.5b-instruct-int8-private.zip" - } - - init { - val builder = OkDownload.Builder(context) - .downloadStore(BreakpointStoreOnSQLite(context)) - OkDownload.setSingletonInstance(builder.build()) - DownloadDispatcher.setMaxParallelRunningCount(3); - } - - private var downloadCallback: DownloadCallback? 
= null - private var lastProgressTime: Long = 0 - private var lastProgressBytes: Long = 0 - - fun getDownloadPath(): String { - return context.filesDir.absolutePath + "/metahuman" - } - - fun getDownloadSuccessFlagPath(): String { - return context.filesDir.absolutePath + "/metahuman/success" - } - - fun setDownloadCallback(callback: DownloadCallback) { - downloadCallback = callback - } - - fun isDownloadComplete():Boolean { - val file = File(getDownloadSuccessFlagPath()) - return file.exists() - } - - fun download() { - val targetFile = File(getDownloadPath() + "_tmp") - val url = MODEL_URL - val filename = "metahuman-model.zip" - val task = DownloadTask.Builder(url, targetFile) - .setFilename(filename) - .setConnectionCount(1) - .setMinIntervalMillisCallbackProcess(100) - .setPassIfAlreadyCompleted(true) - .build() - var downloadSpeedStr = "0Bps" - task.enqueue(object : DownloadListener { - override fun taskStart(task: DownloadTask) { - Log.d(TAG, "taskStart") - } - - override fun connectTrialStart( - task: DownloadTask, - requestHeaderFields: MutableMap> - ) { - Log.d(TAG, "connectTrialStart") - downloadCallback?.onDownloadStart() - } - - override fun connectTrialEnd( - task: DownloadTask, - responseCode: Int, - responseHeaderFields: MutableMap> - ) { - Log.d(TAG, "connectTrialEnd") - } - - override fun downloadFromBeginning( - task: DownloadTask, - info: BreakpointInfo, - cause: ResumeFailedCause - ) { - Log.d(TAG, "downloadFromBeginning cause: $cause totalFileLength: ${task.info?.totalLength} " ) - } - - override fun downloadFromBreakpoint(task: DownloadTask, info: BreakpointInfo) { - Log.d(TAG, "downloadFromBreakpoint:") - downloadCallback?.onDownloadStart() - } - - override fun connectStart( - task: DownloadTask, - blockIndex: Int, - requestHeaderFields: MutableMap> - ) { - Log.d(TAG, "connectStart" ) - } - - override fun connectEnd( - task: DownloadTask, - blockIndex: Int, - responseCode: Int, - responseHeaderFields: MutableMap> - ) { - Log.d(TAG, 
"connectEnd" ) - } - - override fun fetchStart(task: DownloadTask, blockIndex: Int, contentLength: Long) { - Log.d(TAG, "fetchStart") - } - - override fun fetchProgress(task: DownloadTask, blockIndex: Int, increaseBytes: Long) { - if (task.info != null) { -// Log.d(TAG, "Info totalLength: ${task.info?.totalLength}" + -// " totalOffset: ${task.info?.totalOffset} " + -// "blockCount ${task.info?.blockCount} " + -// "realPercent: ${(task.info?.totalOffset?:0).toDouble().div((task.info?.totalLength?:1).toDouble()).times(100)}%" -// ) - val progressPercent = if (task.info!!.totalLength > 0) - (task.info!!.totalOffset.toDouble() / task.info!!.totalLength) * 100 else 0.0 - - val currentTime = System.currentTimeMillis() - if (lastProgressTime == 0L) { - lastProgressTime = currentTime - lastProgressBytes = task.info!!.totalOffset - } - val timeElapsed = currentTime - lastProgressTime - val bytesDownloaded = task.info!!.totalOffset - lastProgressBytes - if (timeElapsed > 1000 && bytesDownloaded > 0) { - val downloadSpeed = bytesDownloaded / timeElapsed * 1000 // bytes per second - downloadSpeedStr = "${formatFileSize(downloadSpeed)}ps" - lastProgressTime = currentTime - lastProgressBytes = task.info!!.totalOffset - } - downloadCallback?.onDownloadProgress(progressPercent, task.info!!.totalOffset, task.info!!.totalLength, downloadSpeedStr) - } - } - - override fun fetchEnd(task: DownloadTask, blockIndex: Int, contentLength: Long) { - Log.d(TAG, "fetchEnd" ) - } - - override fun taskEnd( - task: DownloadTask, - cause: EndCause, - realCause: java.lang.Exception? 
- ) { - if (cause == EndCause.COMPLETED) { - Log.d(TAG, "download complete: 100% ") - downloadCallback?.onDownloadComplete(true, task.file) - } else if (realCause != null) { - Log.e(TAG, "download end: $cause", realCause) - downloadCallback?.onDownloadError(realCause) - } else { - Log.d(TAG, "download end: $cause") - downloadCallback?.onDownloadComplete(false, task.file) - } - } - }) - } - fun cancelDownload() { - Log.d(TAG, "cancelDownload") - OkDownload.with().downloadDispatcher().cancelAll(); - } - -} diff --git a/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/download/DownloadModule.kt b/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/download/DownloadModule.kt index b25c2cd115..cbd899402a 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/download/DownloadModule.kt +++ b/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/download/DownloadModule.kt @@ -24,7 +24,8 @@ class DownloadModule(private val context: Activity) { "taobao-mnn/bert-vits2-MNN", "taobao-mnn/TaoAvatar-NNR-MNN", "taobao-mnn/sherpa-mnn-streaming-zipformer-bilingual-zh-en-2023-02-20", - "taobao-mnn/sherpa-mnn-streaming-zipformer-en-2023-02-21" + "taobao-mnn/sherpa-mnn-streaming-zipformer-en-2023-02-21", + "taobao-mnn/supertonic-tts-mnn" ) private val finishedSet = mutableSetOf() diff --git a/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/settings/MainSettings.kt b/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/settings/MainSettings.kt index bb1e97635b..d2bfd6ef63 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/settings/MainSettings.kt +++ b/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/settings/MainSettings.kt @@ -9,6 +9,9 @@ import com.taobao.meta.avatar.R object MainSettings { + private const val KEY_TTS_SPEAKER_ID = "tts_speaker_id" + private const val KEY_TTS_SPEED = "tts_speed" + private const val DEFAULT_SPEAKER_ID = 
"F1" fun getLlmPrompt(context: Context): String { val sharedPreferences = PreferenceManager.getDefaultSharedPreferences(context) @@ -24,5 +27,32 @@ object MainSettings { val sharedPreferences = PreferenceManager.getDefaultSharedPreferences(context) return sharedPreferences.getBoolean("show_debug_info", true) } - + + // TTS Speaker ID + fun getTtsSpeakerId(context: Context): String { + val sharedPreferences = PreferenceManager.getDefaultSharedPreferences(context) + return sharedPreferences.getString(KEY_TTS_SPEAKER_ID, DEFAULT_SPEAKER_ID) + ?: DEFAULT_SPEAKER_ID + } + + fun setTtsSpeakerId(context: Context, speakerId: String) { + PreferenceManager.getDefaultSharedPreferences(context) + .edit() + .putString(KEY_TTS_SPEAKER_ID, speakerId) + .apply() + } + + // TTS Speed (范围 0.5 - 2.0,存储时放大10倍为整数 5-20) + fun getTtsSpeed(context: Context): Float { + val sharedPreferences = PreferenceManager.getDefaultSharedPreferences(context) + val speedInt = sharedPreferences.getInt(KEY_TTS_SPEED, 10) + return speedInt / 10.0f + } + + fun setTtsSpeed(context: Context, speedInt: Int) { + PreferenceManager.getDefaultSharedPreferences(context) + .edit() + .putInt(KEY_TTS_SPEED, speedInt) + .apply() + } } \ No newline at end of file diff --git a/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/settings/MainSettingsFragment.kt b/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/settings/MainSettingsFragment.kt index 83196f626a..39452604d8 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/settings/MainSettingsFragment.kt +++ b/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/settings/MainSettingsFragment.kt @@ -2,10 +2,14 @@ // Copyright (c) 2024 Alibaba Group Holding Limited All rights reserved. 
package com.taobao.meta.avatar.settings + import android.os.Bundle +import android.widget.Toast import androidx.preference.EditTextPreference +import androidx.preference.ListPreference import androidx.preference.Preference import androidx.preference.PreferenceFragmentCompat +import androidx.preference.SeekBarPreference import com.taobao.meta.avatar.R import com.taobao.meta.avatar.utils.AppUtils @@ -29,5 +33,31 @@ class MainSettingsFragment : PreferenceFragmentCompat() { true } } + + // TTS Speaker ID 设置(仅英文模式支持,支持动态切换) + val speakerIdPref = findPreference("tts_speaker_id") + speakerIdPref?.apply { + val isChinese = AppUtils.isChinese() + // 只有英文模式才启用 speaker_id 设置 + isEnabled = !isChinese + if (!isChinese) { + value = MainSettings.getTtsSpeakerId(requireContext()) + setOnPreferenceChangeListener { _, newValue -> + val newSpeakerId = newValue as String + MainSettings.setTtsSpeakerId(requireContext(), newSpeakerId) + // TtsService 会自动监听 SharedPreferences 变化并应用 + Toast.makeText(requireContext(), + "Voice changed to $newSpeakerId", + Toast.LENGTH_SHORT).show() + true + } + } else { + summary = "Only available in English mode" + } + } + + // TTS Speed 设置(暂时隐藏,不支持设置) + val speedPref = findPreference("tts_speed") + speedPref?.isVisible = false } } \ No newline at end of file diff --git a/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/tts/TtsService.kt b/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/tts/TtsService.kt index b4a42b6f25..6ce12859b2 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/tts/TtsService.kt +++ b/apps/Android/MnnTaoAvatar/app/src/main/java/com/taobao/meta/avatar/tts/TtsService.kt @@ -1,13 +1,19 @@ package com.taobao.meta.avatar.tts +import android.content.Context +import android.content.SharedPreferences +import android.preference.PreferenceManager import android.util.Log import com.k2fsa.sherpa.mnn.GeneratedAudio +import com.taobao.meta.avatar.MHConfig import 
com.taobao.meta.avatar.debug.DebugModule +import com.taobao.meta.avatar.settings.MainSettings import com.taobao.meta.avatar.utils.AppUtils import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Deferred import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async +import org.json.JSONObject class TtsService { @@ -19,17 +25,29 @@ class TtsService { @Volatile private var isLoaded = false private var initDeferred: Deferred? = null + private var sharedPreferences: SharedPreferences? = null + private var applicationContext: Context? = null + private var currentSpeakerId: String? = null + private val preferenceChangeListener = SharedPreferences.OnSharedPreferenceChangeListener { _, key -> + if (key == "tts_speaker_id") { + handleSpeakerIdChange() + } + } init { ttsServiceNative = nativeCreateTTS(if (AppUtils.isChinese()) "zh" else "en") } fun destroy() { + // 取消注册 SharedPreferences 监听器 + sharedPreferences?.unregisterOnSharedPreferenceChangeListener(preferenceChangeListener) + sharedPreferences = null + applicationContext = null nativeDestroy(ttsServiceNative) ttsServiceNative = 0 } - suspend fun init(modelDir: String?): Boolean { + suspend fun init(modelDir: String?, context: Context? = null): Boolean { if (isLoaded) return true if (initDeferred == null) { initDeferred = CoroutineScope(Dispatchers.IO).async { @@ -38,19 +56,99 @@ class TtsService { sherpaTts?.init(null) return@async true } - nativeLoadResourcesFromFile(ttsServiceNative, - modelDir!!, + + // 1. 根据语言选择模型目录 + val isChinese = AppUtils.isChinese() + val actualModelDir = if (isChinese) { + MHConfig.TTS_MODEL_DIR + } else { + MHConfig.TTS_MODEL_DIR_EN + } + + // 2. 
构建参数覆盖 Map + val overrideParams = mutableMapOf() + context?.let { + // 只有英文模式支持 speaker_id + if (!isChinese) { + val speakerId = MainSettings.getTtsSpeakerId(it) + if (speakerId.isNotEmpty()) { + overrideParams["speaker_id"] = speakerId + } + } + + // speed 设置暂时不支持,使用 config.json 中的默认值 + // val speed = MainSettings.getTtsSpeed(it) + // overrideParams["speed"] = speed.toString() + } + + // 3. 序列化为 JSON + val paramsJson = if (overrideParams.isEmpty()) { + "{}" + } else { + JSONObject(overrideParams as Map<*, *>).toString() + } + + Log.d(TAG, "Loading TTS from: $actualModelDir with params: $paramsJson") + + nativeLoadResourcesFromFile( + ttsServiceNative, + actualModelDir, "", - "") + "", + paramsJson + ) true } } val result = initDeferred!!.await() if (result) { isLoaded = true + // 注册 SharedPreferences 监听器(当有 Context 时) + context?.let { + registerPreferenceListener(it) + // 初始化当前 speaker ID + if (!AppUtils.isChinese()) { + currentSpeakerId = MainSettings.getTtsSpeakerId(it) + } + } } return result } + + private fun registerPreferenceListener(context: Context) { + if (sharedPreferences == null) { + // 保存 applicationContext 避免内存泄漏 + applicationContext = context.applicationContext + sharedPreferences = PreferenceManager.getDefaultSharedPreferences(applicationContext) + sharedPreferences?.registerOnSharedPreferenceChangeListener(preferenceChangeListener) + Log.d(TAG, "Registered SharedPreferences listener for speaker ID changes") + } + } + + private fun handleSpeakerIdChange() { + if (!isLoaded) { + Log.d(TAG, "TtsService not loaded yet, speaker ID change will be applied after initialization") + return + } + + // 只在英文模式下处理 + if (AppUtils.isChinese()) { + return + } + + applicationContext?.let { ctx -> + val newSpeakerId = MainSettings.getTtsSpeakerId(ctx) + if (newSpeakerId != currentSpeakerId && newSpeakerId.isNotEmpty()) { + try { + setSpeakerId(newSpeakerId) + currentSpeakerId = newSpeakerId + Log.d(TAG, "Speaker ID changed to: $newSpeakerId") + } catch (e: 
Exception) { + Log.e(TAG, "Failed to set speaker ID: $newSpeakerId", e) + } + } + } + } suspend fun waitForInitComplete(): Boolean { if (isLoaded) return true @@ -64,6 +162,10 @@ class TtsService { nativeSetCurrentIndex(ttsServiceNative, index) } + fun setSpeakerId(speakerId: String) { + nativeSetSpeakerId(ttsServiceNative, speakerId) + } + fun process(text: String, id: Int): ShortArray { return nativeProcess(ttsServiceNative, text, id) } @@ -81,8 +183,10 @@ class TtsService { private external fun nativeDestroy(nativePtr: Long) private external fun nativeLoadResourcesFromFile(nativePtr: Long, resourceDir: String, - modelName:String, - mmapDir:String): Boolean + modelName: String, + mmapDir: String, + paramsJson: String): Boolean // 新增:JSON 格式的参数覆盖 + private external fun nativeSetSpeakerId(nativePtr: Long, speakerId: String) // 动态设置音色 private external fun nativeProcess(nativePtr: Long, text: String, id: Int): ShortArray companion object { diff --git a/apps/Android/MnnTaoAvatar/app/src/main/res/values-zh/arrays.xml b/apps/Android/MnnTaoAvatar/app/src/main/res/values-zh/arrays.xml new file mode 100644 index 0000000000..055934dae0 --- /dev/null +++ b/apps/Android/MnnTaoAvatar/app/src/main/res/values-zh/arrays.xml @@ -0,0 +1,17 @@ + + + + 男声1 + 男声2 + 女声1 + 女声2 + + + + M1 + M2 + F1 + F2 + + + diff --git a/apps/Android/MnnTaoAvatar/app/src/main/res/values-zh/strings.xml b/apps/Android/MnnTaoAvatar/app/src/main/res/values-zh/strings.xml index 0aec2952cc..9a674858a6 100755 --- a/apps/Android/MnnTaoAvatar/app/src/main/res/values-zh/strings.xml +++ b/apps/Android/MnnTaoAvatar/app/src/main/res/values-zh/strings.xml @@ -33,4 +33,10 @@ 由MNN提供本地推理服务 录音权限被拒绝,请开启 + 音色 + %s + 男声1 + 男声2 + 女声1 + 女声2 diff --git a/apps/Android/MnnTaoAvatar/app/src/main/res/values/arrays.xml b/apps/Android/MnnTaoAvatar/app/src/main/res/values/arrays.xml new file mode 100644 index 0000000000..c1f8ac1ca5 --- /dev/null +++ b/apps/Android/MnnTaoAvatar/app/src/main/res/values/arrays.xml @@ -0,0 +1,17 @@ + + + 
+ Male Voice 1 + Male Voice 2 + Female Voice 1 + Female Voice 2 + + + + M1 + M2 + F1 + F2 + + + diff --git a/apps/Android/MnnTaoAvatar/app/src/main/res/values/strings.xml b/apps/Android/MnnTaoAvatar/app/src/main/res/values/strings.xml index e2e164527e..aa46cceafd 100755 --- a/apps/Android/MnnTaoAvatar/app/src/main/res/values/strings.xml +++ b/apps/Android/MnnTaoAvatar/app/src/main/res/values/strings.xml @@ -33,5 +33,11 @@ Inference Local by MNN Record permission denied, please grant it + Voice + %s + Male Voice 1 + Male Voice 2 + Female Voice 1 + Female Voice 2 diff --git a/apps/Android/MnnTaoAvatar/app/src/main/res/xml/main_settings_prefs.xml b/apps/Android/MnnTaoAvatar/app/src/main/res/xml/main_settings_prefs.xml index cacc5b6508..3c13f689e4 100644 --- a/apps/Android/MnnTaoAvatar/app/src/main/res/xml/main_settings_prefs.xml +++ b/apps/Android/MnnTaoAvatar/app/src/main/res/xml/main_settings_prefs.xml @@ -1,7 +1,27 @@ + + + + + + + + + + + + + + Date: Fri, 19 Dec 2025 15:49:13 +0800 Subject: [PATCH 020/314] Project import generated by Copybara. GitOrigin-RevId: 0fda1298b1b377be72287ceadb418af333af7146 --- docs/Makefile | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 docs/Makefile diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000000..d4bb2cbb9e --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
+%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) From d6c55ca68eca51567e06aa45beb8ef2de94046b9 Mon Sep 17 00:00:00 2001 From: "zhaode.wzd" Date: Mon, 22 Dec 2025 09:56:21 +0800 Subject: [PATCH 021/314] [MNN:Refact] Delete not use file. --- MNN.sln | 0 docker_release.sh | 6 ------ docker_run.sh | 6 ------ 3 files changed, 12 deletions(-) delete mode 100644 MNN.sln delete mode 100755 docker_release.sh delete mode 100755 docker_run.sh diff --git a/MNN.sln b/MNN.sln deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/docker_release.sh b/docker_release.sh deleted file mode 100755 index 05ba96cf2d..0000000000 --- a/docker_release.sh +++ /dev/null @@ -1,6 +0,0 @@ -# using docker run release -docker start mnn_release -docker exec -i -e TEST_ID=$(pwd | awk -F "/" '{print $(NF-1)}') mnn_release bash <<'EOF' -cd ~/yanxing_zhaode/cise/space/$TEST_ID/source && ./release.sh pymnn -exit -EOF \ No newline at end of file diff --git a/docker_run.sh b/docker_run.sh deleted file mode 100755 index 6cff3321bc..0000000000 --- a/docker_run.sh +++ /dev/null @@ -1,6 +0,0 @@ -# using docker run test -docker start mnn_ci -docker exec -i -e TEST_ID=$(pwd | awk -F "/" '{print $(NF-1)}') mnn_ci bash <<'EOF' -cd ~/yanxing_zhaode/cise/space/$TEST_ID/source && ./test.sh linux -exit -EOF From 1f805fd9c2889d03fc979521b6062aa55ff7e314 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 09:56:50 +0800 Subject: [PATCH 022/314] Project import generated by Copybara. 
GitOrigin-RevId: fd90884f44c381932e3de8224bc69a0c327a3344 --- CMakeLists.txt | 1 - build_lib.sh | 807 ------------------ docs/transformers/diffusion.md | 3 +- source/backend/cpu/arm/CMakeLists.txt | 3 - .../cpu/riscv/rvv/CPUBilinearLineC4.cpp | 19 - .../cpu/riscv/rvv/CPUBilinearSampleC4.cpp | 33 - .../cpu/riscv/rvv/MNNAddC4WithStride.cpp | 29 - .../riscv/rvv/MNNAxByClampBroadcastUnit.cpp | 52 -- source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp | 18 - .../backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp | 20 - source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp | 20 - .../cpu/riscv/rvv/MNNBilinearLineC8.cpp | 40 - .../cpu/riscv/rvv/MNNBilinearSampleC8.cpp | 49 -- source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp | 20 - .../riscv/rvv/MNNConvRunForLineDepthwise.cpp | 48 -- .../cpu/riscv/rvv/MNNCopyC4WithStride.cpp | 22 - .../backend/cpu/riscv/rvv/MNNCubicLineC16.cpp | 53 -- .../backend/cpu/riscv/rvv/MNNCubicLineC4.cpp | 38 - .../cpu/riscv/rvv/MNNCubicSampleC16.cpp | 79 -- .../cpu/riscv/rvv/MNNCubicSampleC4.cpp | 62 -- .../rvv/MNNDeconvRunForUnitDepthWise.cpp | 42 - source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp | 13 - source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp | 16 - source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp | 25 - source/backend/cpu/riscv/rvv/MNNMinFloat.cpp | 25 - source/backend/cpu/riscv/rvv/MNNPackC2.cpp | 74 -- source/backend/cpu/riscv/rvv/MNNPackC4.cpp | 80 -- source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp | 17 - .../backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp | 20 - .../backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp | 20 - source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp | 17 - source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp | 20 - .../cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp | 45 - .../cpu/riscv/rvv/MNNScaleAndAddBias.cpp | 42 - source/backend/cpu/riscv/rvv/MNNSoftmax.cpp | 80 -- .../riscv/rvv/MNNStrassenMergeCFunction.cpp | 36 - .../cpu/riscv/rvv/MNNTranspose16Bit.cpp | 26 - .../cpu/riscv/rvv/MNNTranspose32Bit.cpp | 25 - source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp | 55 -- 
.../cpu/riscv/rvv/MNNVectorTop1Float.cpp | 37 - .../cpu/riscv/rvv/MNNVectorTop1Int32.cpp | 37 - source/core/Backend.hpp | 6 +- transformers/diffusion/export/onnx_export.py | 30 +- 43 files changed, 24 insertions(+), 2180 deletions(-) delete mode 100644 build_lib.sh delete mode 100644 source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNMinFloat.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNPackC2.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNPackC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp delete mode 100644 
source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNSoftmax.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f99e37ec1c..67502b606b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -258,7 +258,6 @@ option(MNN_VULKAN "Enable Vulkan" OFF) option(MNN_ARM82 "Enable ARMv8.2's FP16 Compute" ON) option(MNN_SUPPORT_FP16_ARMV7 "Enable ARMv8.2's FP16 Compute for armv7 arch, may cause library not valid for 32 bit cpu" OFF) option(MNN_KLEIDIAI "Enable KLEIDIAI" ON) -option(MNN_KLEIDIAI_DEFAULT_ON "Use KLEIDIAI kernels by default" OFF) option(MNN_ONEDNN "Enable oneDNN" OFF) option(MNN_AVX2 "Open AVX2 Compile for x86 if possible" ON) option(MNN_AVX512 "Enable AVX512" OFF) diff --git a/build_lib.sh b/build_lib.sh deleted file mode 100644 index c839b6e7b6..0000000000 --- a/build_lib.sh +++ /dev/null @@ -1,807 +0,0 @@ -#!/bin/bash - -# MNN 统一构建脚本 -# 支持构建 Android、iOS、鸿蒙和 Python 版本的 MNN - -set -e - -# 颜色输出 -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -BLUE='\033[0;34m' -NC='\033[0m' # No Color - -# 默认配置 -BUILD_ANDROID=false -BUILD_IOS=false -BUILD_PYTHON=false -BUILD_IOS_SIMULATOR=false -BUILD_HARMONY=false -ANDROID_NDK="" -HARMONY_HOME="" -OUTPUT_DIR="mnn_builds" -CLEAN_BUILD=true -VERSION="" -PYTHON_CMD="python3" 
-DEPS_OPTIONS="" - -# 获取项目根目录 -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -PROJECT_ROOT="$SCRIPT_DIR" - -print_usage() { - echo "MNN 统一构建脚本" - echo "" - echo "用法: $0 [选项]" - echo "" - echo "构建目标:" - echo " --android 构建 Android 版本 (需要 --ndk)" - echo " --ios 构建 iOS 真机版本 (arm64)" - echo " --ios-simulator 构建 iOS 模拟器版本 (x86_64 + arm64)" - echo " --harmony 构建鸿蒙版本 (需要 --harmony-home 或设置 HARMONY_HOME)" - echo " --python 构建 Python 版本" - echo "" - echo "Android 选项:" - echo " --ndk PATH Android NDK 路径 (例如: ~/Library/Android/sdk/ndk/29.0.13599879)" - echo "" - echo "鸿蒙选项:" - echo " --harmony-home PATH 鸿蒙工具链路径 (例如: ~/Library/OpenHarmony/Sdk/native)" - echo " 如果未指定,将优先查找 ~/Library/OpenHarmony/Sdk/native" - echo "" - echo "Python 选项:" - echo " --python-deps OPTIONS Python 依赖选项,多个用逗号分隔" - echo " 可用选项: llm,opencl,cuda,torch,render,vulkan,internal,no_sse,openmp" - echo " --python-cmd CMD Python 命令 (默认: python3)" - echo "" - echo "通用选项:" - echo " -o, --output DIR 输出目录 (默认: mnn_builds)" - echo " -v, --version VERSION 版本号 (默认: 自动从源码读取)" - echo " --no-clean 不清理之前的构建目录" - echo " -h, --help 显示帮助信息" - echo "" - echo "示例:" - echo " # 构建所有平台" - echo " $0 --android --ios --harmony --python --ndk ~/Library/Android/sdk/ndk/29.0.13599879" - echo "" - echo " # 仅构建 Android" - echo " $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879" - echo "" - echo " # 构建鸿蒙版本" - echo " $0 --harmony --harmony-home ~/Library/OpenHarmony/Sdk/native" - echo "" - echo " # 构建 iOS 真机和模拟器" - echo " $0 --ios --ios-simulator" - echo "" - echo " # 构建 Python (带 LLM 支持)" - echo " $0 --python --python-deps llm,opencl" -} - -# 解析命令行参数 -while [[ $# -gt 0 ]]; do - case $1 in - --android) - BUILD_ANDROID=true - shift - ;; - --ios) - BUILD_IOS=true - shift - ;; - --ios-simulator) - BUILD_IOS_SIMULATOR=true - shift - ;; - --python) - BUILD_PYTHON=true - shift - ;; - --harmony) - BUILD_HARMONY=true - shift - ;; - --ndk) - ANDROID_NDK="$2" - shift 2 - ;; - --harmony-home) - HARMONY_HOME="$2" - shift 2 - ;; - 
--python-deps) - DEPS_OPTIONS="$2" - shift 2 - ;; - --python-cmd) - PYTHON_CMD="$2" - shift 2 - ;; - -o|--output) - OUTPUT_DIR="$2" - shift 2 - ;; - -v|--version) - VERSION="$2" - shift 2 - ;; - --no-clean) - CLEAN_BUILD=false - shift - ;; - -h|--help) - print_usage - exit 0 - ;; - *) - echo -e "${RED}错误: 未知选项 $1${NC}" - print_usage - exit 1 - ;; - esac -done - -# 检查是否至少选择了一个构建目标 -if [ "$BUILD_ANDROID" = false ] && [ "$BUILD_IOS" = false ] && [ "$BUILD_IOS_SIMULATOR" = false ] && [ "$BUILD_HARMONY" = false ] && [ "$BUILD_PYTHON" = false ]; then - echo -e "${RED}错误: 请至少选择一个构建目标 (--android, --ios, --ios-simulator, --harmony, 或 --python)${NC}" - print_usage - exit 1 -fi - -# 检查 Android NDK -if [ "$BUILD_ANDROID" = true ]; then - if [ -z "$ANDROID_NDK" ]; then - echo -e "${RED}错误: 构建 Android 必须指定 NDK 路径 (使用 --ndk)${NC}" - echo -e "${RED}示例: $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879${NC}" - exit 1 - fi - - # 展开路径中的 ~ - ANDROID_NDK="${ANDROID_NDK/#\~/$HOME}" - - if [ ! -d "$ANDROID_NDK" ]; then - echo -e "${RED}错误: NDK 路径不存在: $ANDROID_NDK${NC}" - exit 1 - fi -fi - -# 查找鸿蒙工具链的函数 -find_harmony_toolchain() { - # 如果环境变量已设置,优先使用 - if [ -n "$HARMONY_HOME" ]; then - # 支持两种路径格式:直接指定或带 native 子目录 - if [ -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then - echo "$HARMONY_HOME" - return 0 - elif [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then - echo "$HARMONY_HOME" - return 0 - fi - fi - - # 优先查找 ~/Library/OpenHarmony/Sdk,支持版本号目录 - local sdk_base="$HOME/Library/OpenHarmony/Sdk" - if [ -d "$sdk_base" ]; then - # 查找所有版本号目录下的 native 或 toolchains - # 按版本号倒序排列,优先使用最新版本 - for version_dir in $(ls -d "$sdk_base"/* 2>/dev/null | sort -Vr); do - if [ -d "$version_dir" ]; then - # 尝试 native/build/cmake/ohos.toolchain.cmake - if [ -f "$version_dir/native/build/cmake/ohos.toolchain.cmake" ]; then - echo "$version_dir/native" - return 0 - fi - # 尝试 build/cmake/ohos.toolchain.cmake - if [ -f "$version_dir/build/cmake/ohos.toolchain.cmake" ]; then - 
echo "$version_dir" - return 0 - fi - fi - done - fi - - # 其他可能的路径 - local possible_paths=( - "$HOME/Library/OpenHarmony/Sdk/native" - "$HOME/HarmonyOS/Sdk/native" - "$HOME/.ohos/native" - "/opt/HarmonyOS/Sdk/native" - "/usr/local/HarmonyOS/Sdk/native" - "$HOME/Library/HarmonyOS/Sdk/native" - ) - - # 尝试查找 - for path in "${possible_paths[@]}"; do - if [ -n "$path" ] && [ -f "$path/build/cmake/ohos.toolchain.cmake" ]; then - echo "$path" - return 0 - fi - done - - # 限制搜索范围,只在 OpenHarmony/Sdk 目录下快速查找 - # 避免在整个 OpenHarmony 目录下递归查找,这可能会很慢 - local found=$(find "$HOME/Library/OpenHarmony/Sdk" -maxdepth 4 -type f -name "ohos.toolchain.cmake" 2>/dev/null | head -1) - if [ -n "$found" ]; then - # 从 ohos.toolchain.cmake 向上查找 native 目录或 SDK 根目录 - found=$(dirname "$found") - if [ "$(basename "$found")" = "cmake" ]; then - found=$(dirname "$found") - if [ "$(basename "$found")" = "build" ]; then - found=$(dirname "$found") - fi - fi - echo "$found" - return 0 - fi - - return 1 -} - -# 检查鸿蒙工具链 -if [ "$BUILD_HARMONY" = true ]; then - # 展开路径中的 ~ - if [ -n "$HARMONY_HOME" ]; then - HARMONY_HOME="${HARMONY_HOME/#\~/$HOME}" - fi - - # 尝试查找工具链 - if [ -z "$HARMONY_HOME" ] || [ ! 
-f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then - # 检查是否在 native 子目录下 - if [ -n "$HARMONY_HOME" ] && [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then - HARMONY_HOME="$HARMONY_HOME/native" - else - echo -e "${YELLOW}正在查找鸿蒙工具链...${NC}" - HARMONY_HOME=$(find_harmony_toolchain) - - if [ -z "$HARMONY_HOME" ]; then - echo -e "${RED}错误: 找不到鸿蒙工具链${NC}" - echo -e "${RED}默认查找路径: ~/Library/OpenHarmony/Sdk/*/native${NC}" - echo -e "${RED}请使用 --harmony-home 指定路径,或设置 HARMONY_HOME 环境变量${NC}" - echo -e "${RED}工具链文件应位于: /build/cmake/ohos.toolchain.cmake 或 /native/build/cmake/ohos.toolchain.cmake${NC}" - exit 1 - fi - - # 如果找到的是带 native 的路径,需要调整 - if [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then - HARMONY_HOME="$HARMONY_HOME/native" - fi - fi - - echo -e "${GREEN}找到鸿蒙工具链: $HARMONY_HOME${NC}" - fi - - # 验证工具链文件存在 - if [ ! -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then - echo -e "${RED}错误: 鸿蒙工具链文件不存在: $HARMONY_HOME/build/cmake/ohos.toolchain.cmake${NC}" - exit 1 - fi -fi - -echo -e "${GREEN}========================================${NC}" -echo -e "${GREEN}MNN 统一构建脚本${NC}" -echo -e "${GREEN}========================================${NC}" -echo "项目根目录: $PROJECT_ROOT" -echo "输出目录: $OUTPUT_DIR" -echo "" -echo -e "${BLUE}构建目标:${NC}" -[ "$BUILD_ANDROID" = true ] && echo " ✓ Android" -[ "$BUILD_IOS" = true ] && echo " ✓ iOS (真机 arm64)" -[ "$BUILD_IOS_SIMULATOR" = true ] && echo " ✓ iOS (模拟器 x86_64 + arm64)" -[ "$BUILD_HARMONY" = true ] && echo " ✓ 鸿蒙 (arm64-v8a)" -[ "$BUILD_PYTHON" = true ] && echo " ✓ Python" -echo "" - -cd "$PROJECT_ROOT" -mkdir -p "$OUTPUT_DIR" - -# ============================================================================ -# 构建 Android 版本 -# ============================================================================ -if [ "$BUILD_ANDROID" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建 Android 版本${NC}" - echo -e 
"${GREEN}========================================${NC}" - - export ANDROID_NDK - - ANDROID_BUILD_DIR="project/android" - ANDROID_OUTPUT_DIR="$OUTPUT_DIR/android" - - cd "$ANDROID_BUILD_DIR" - - # 清理 - if [ "$CLEAN_BUILD" = true ]; then - echo -e "${YELLOW}清理之前的 Android 构建...${NC}" - rm -rf build_32 build_64 export - fi - - # 在项目根目录创建输出目录 - mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" - - # 构建 armeabi-v7a - echo -e "${BLUE}构建 armeabi-v7a...${NC}" - mkdir -p build_32 - cd build_32 - cmake ../../../ \ - -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ - -DCMAKE_BUILD_TYPE=Release \ - -DANDROID_ABI="armeabi-v7a" \ - -DANDROID_STL=c++_static \ - -DANDROID_NATIVE_API_LEVEL=android-14 \ - -DANDROID_TOOLCHAIN=clang \ - -DMNN_USE_LOGCAT=false \ - -DMNN_USE_SSE=OFF \ - -DMNN_BUILD_TEST=ON \ - -DMNN_ARM82=OFF \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DMNN_SEP_BUILD=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ - -DNATIVE_LIBRARY_OUTPUT=. \ - -DNATIVE_INCLUDE_OUTPUT=. \ - > /dev/null - - make -j4 MNN > /dev/null - cd .. 
- - # 构建 arm64-v8a - echo -e "${BLUE}构建 arm64-v8a...${NC}" - mkdir -p build_64 - cd build_64 - cmake ../../../ \ - -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ - -DCMAKE_BUILD_TYPE=Release \ - -DANDROID_ABI="arm64-v8a" \ - -DANDROID_STL=c++_static \ - -DANDROID_NATIVE_API_LEVEL=android-21 \ - -DANDROID_TOOLCHAIN=clang \ - -DMNN_USE_LOGCAT=false \ - -DMNN_BUILD_BENCHMARK=ON \ - -DMNN_USE_SSE=OFF \ - -DMNN_BUILD_TEST=ON \ - -DMNN_ARM82=ON \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DMNN_SEP_BUILD=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ - -DNATIVE_LIBRARY_OUTPUT=. \ - -DNATIVE_INCLUDE_OUTPUT=. \ - > /dev/null - - make -j4 MNN > /dev/null - cd .. - - # 导出文件 - echo -e "${BLUE}导出 Android 库文件...${NC}" - mkdir -p export/android/{armeabi-v7a,arm64-v8a}/libs export/android/include/MNN - - cp build_32/*.so export/android/armeabi-v7a/libs/ 2>/dev/null || true - cp build_64/*.so export/android/arm64-v8a/libs/ 2>/dev/null || true - cp -r ../../include/MNN/* export/android/include/MNN/ - - # 复制到统一输出目录 - # 如果目标路径存在但不是目录,先删除 - if [ -e "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ] && [ ! 
-d "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ]; then - rm -f "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" - fi - mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" - cp -r export/android/* "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR/" - - echo -e "${GREEN}Android 构建完成!${NC}" - echo "输出位置: $ANDROID_OUTPUT_DIR" - cd "$PROJECT_ROOT" -fi - -# ============================================================================ -# 构建 iOS 真机版本 -# ============================================================================ -if [ "$BUILD_IOS" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建 iOS 真机版本${NC}" - echo -e "${GREEN}========================================${NC}" - - # 检查是否在 macOS 上 - if [[ "$OSTYPE" != "darwin"* ]]; then - echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" - exit 1 - fi - - # 检查 Xcode - if ! command -v xcodebuild &> /dev/null; then - echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" - exit 1 - fi - - IOS_BUILD_DIR="project/ios" - IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/device" - - cd "$IOS_BUILD_DIR" - - # 清理 - if [ "$CLEAN_BUILD" = true ]; then - echo -e "${YELLOW}清理之前的 iOS 真机构建...${NC}" - rm -rf MNN-iOS-CPU-GPU/Static/ios_64 - find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_64/*" 2>/dev/null | xargs rm -f 2>/dev/null || true - fi - - mkdir -p MNN-iOS-CPU-GPU/Static - cd MNN-iOS-CPU-GPU/Static - - # 构建真机版本 (arm64) - echo -e "${BLUE}构建 iOS 真机版本 (arm64)...${NC}" - rm -rf ios_64 - mkdir ios_64 - cd ios_64 - cmake "$PROJECT_ROOT" \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ - -DENABLE_BITCODE=0 \ - -DMNN_AAPL_FMWK=1 \ - -DMNN_SEP_BUILD=OFF \ - -DMNN_BUILD_SHARED_LIBS=false \ - -DMNN_USE_THREAD_POOL=OFF \ - -DPLATFORM=OS64 \ - -DARCHS="arm64" \ - -DMNN_ARM82=ON \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_METAL=ON \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - 
-DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - > /dev/null - make MNN -j8 > /dev/null - cd .. - - # 输出真机版本 - mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" - rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" - cp -R ios_64/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" - - # 清理 - rm -rf ios_64 - - echo -e "${GREEN}iOS 真机构建完成!${NC}" - echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" - cd "$PROJECT_ROOT" -fi - -# ============================================================================ -# 构建 iOS 模拟器版本 -# ============================================================================ -if [ "$BUILD_IOS_SIMULATOR" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建 iOS 模拟器版本${NC}" - echo -e "${GREEN}========================================${NC}" - - # 检查是否在 macOS 上 - if [[ "$OSTYPE" != "darwin"* ]]; then - echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" - exit 1 - fi - - # 检查 Xcode - if ! command -v xcodebuild &> /dev/null; then - echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" - exit 1 - fi - - IOS_BUILD_DIR="project/ios" - IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/simulator" - - cd "$IOS_BUILD_DIR" - - # 清理 - if [ "$CLEAN_BUILD" = true ]; then - echo -e "${YELLOW}清理之前的 iOS 模拟器构建...${NC}" - rm -rf MNN-iOS-CPU-GPU/Static/ios_simulator* - find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_simulator*/*" 2>/dev/null | xargs rm -f 2>/dev/null || true - fi - - mkdir -p MNN-iOS-CPU-GPU/Static - cd MNN-iOS-CPU-GPU/Static - - # 构建 x86_64 模拟器版本(尝试构建,失败则跳过) - BUILD_X86_64=false - echo -e "${BLUE}构建 iOS 模拟器版本 (x86_64)...${NC}" - rm -rf ios_simulator_x86 - mkdir ios_simulator_x86 - cd ios_simulator_x86 - if cmake "$PROJECT_ROOT" \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ - -DENABLE_BITCODE=0 \ - -DMNN_AAPL_FMWK=1 \ - -DMNN_SEP_BUILD=OFF \ - -DMNN_BUILD_SHARED_LIBS=false \ - -DMNN_USE_THREAD_POOL=OFF \ - -DPLATFORM=SIMULATOR64 \ - 
-DARCHS="x86_64" \ - -DMNN_ARM82=OFF \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_METAL=OFF \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - && make MNN -j8; then - if [ -f "MNN.framework/MNN" ]; then - BUILD_X86_64=true - echo -e "${GREEN}x86_64 模拟器构建成功${NC}" - cd .. - else - echo -e "${YELLOW}警告: x86_64 模拟器构建产物不存在,跳过此架构${NC}" - cd .. - rm -rf ios_simulator_x86 - fi - else - echo -e "${YELLOW}警告: x86_64 模拟器构建失败,跳过此架构(这在 Apple Silicon Mac 上是正常的)${NC}" - cd .. - rm -rf ios_simulator_x86 - fi - - # 构建 arm64 模拟器版本 - echo -e "${BLUE}构建 iOS 模拟器版本 (arm64)...${NC}" - rm -rf ios_simulator_arm64 - mkdir ios_simulator_arm64 - cd ios_simulator_arm64 - if ! cmake "$PROJECT_ROOT" \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ - -DENABLE_BITCODE=0 \ - -DMNN_AAPL_FMWK=1 \ - -DMNN_SEP_BUILD=OFF \ - -DMNN_BUILD_SHARED_LIBS=false \ - -DMNN_USE_THREAD_POOL=OFF \ - -DPLATFORM=SIMULATOR64 \ - -DARCHS="arm64" \ - -DMNN_ARM82=ON \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_METAL=OFF \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON; then - echo -e "${RED}错误: arm64 模拟器 cmake 配置失败${NC}" - cd "$PROJECT_ROOT" - exit 1 - fi - if ! make MNN -j8; then - echo -e "${RED}错误: arm64 模拟器编译失败${NC}" - cd "$PROJECT_ROOT" - exit 1 - fi - cd .. - - # 验证构建产物 - if [ ! 
-f "ios_simulator_arm64/MNN.framework/MNN" ]; then - echo -e "${RED}错误: 未找到 arm64 模拟器框架文件${NC}" - cd "$PROJECT_ROOT" - exit 1 - fi - - # 合并模拟器架构 - echo -e "${BLUE}合并模拟器架构...${NC}" - rm -rf ios_simulator - mkdir ios_simulator - - if [ "$BUILD_X86_64" = true ] && [ -f "ios_simulator_x86/MNN.framework/MNN" ]; then - # 合并 x86_64 + arm64 - echo -e "${BLUE}合并 x86_64 + arm64 架构...${NC}" - cp -R ios_simulator_x86/MNN.framework ios_simulator/MNN.framework - mv ios_simulator/MNN.framework/MNN ios_simulator/MNN.framework/MNN_x86 - if ! lipo -create ios_simulator/MNN.framework/MNN_x86 ios_simulator_arm64/MNN.framework/MNN -output ios_simulator/MNN.framework/MNN; then - echo -e "${RED}错误: 合并架构失败${NC}" - cd "$PROJECT_ROOT" - exit 1 - fi - rm ios_simulator/MNN.framework/MNN_x86 - else - # 仅使用 arm64 - echo -e "${BLUE}仅使用 arm64 架构(x86_64 不可用)...${NC}" - cp -R ios_simulator_arm64/MNN.framework ios_simulator/MNN.framework - fi - - # 输出模拟器版本 - mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" - rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" - cp -R ios_simulator/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" - - # 清理临时目录 - rm -rf ios_simulator ios_simulator_x86 ios_simulator_arm64 - - echo -e "${GREEN}iOS 模拟器构建完成!${NC}" - echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" - cd "$PROJECT_ROOT" -fi - -# ============================================================================ -# 构建鸿蒙版本 -# ============================================================================ -if [ "$BUILD_HARMONY" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建鸿蒙版本${NC}" - echo -e "${GREEN}========================================${NC}" - - export HARMONY_HOME - - HARMONY_BUILD_DIR="project/harmony" - HARMONY_OUTPUT_DIR="$OUTPUT_DIR/harmony" - - cd "$HARMONY_BUILD_DIR" - - # 清理 - if [ "$CLEAN_BUILD" = true ]; then - echo -e "${YELLOW}清理之前的鸿蒙构建...${NC}" - rm -rf build_64 export - fi - - # 在项目根目录创建输出目录 - mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" - - # 构建 arm64-v8a - 
echo -e "${BLUE}构建 arm64-v8a...${NC}" - mkdir -p build_64 - cd build_64 - - cmake "$PROJECT_ROOT" \ - -DCMAKE_TOOLCHAIN_FILE=$HARMONY_HOME/build/cmake/ohos.toolchain.cmake \ - -DCMAKE_BUILD_TYPE=Release \ - -DOHOS_ARCH="arm64-v8a" \ - -DOHOS_STL=c++_static \ - -DMNN_USE_LOGCAT=true \ - -DMNN_BUILD_BENCHMARK=ON \ - -DMNN_USE_SSE=OFF \ - -DMNN_SUPPORT_BF16=OFF \ - -DMNN_BUILD_TEST=ON \ - -DMNN_ARM82=ON \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DMNN_SEP_BUILD=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - -DOHOS_PLATFORM_LEVEL=9 \ - -DNATIVE_LIBRARY_OUTPUT=. \ - -DNATIVE_INCLUDE_OUTPUT=. \ - > /dev/null - - make -j4 MNN > /dev/null - cd .. - - # 导出文件 - echo -e "${BLUE}导出鸿蒙库文件...${NC}" - mkdir -p export/harmony/arm64-v8a/libs export/harmony/include/MNN - - cp build_64/*.so export/harmony/arm64-v8a/libs/ 2>/dev/null || true - cp -r ../../include/MNN/* export/harmony/include/MNN/ - - # 复制到统一输出目录 - # 如果目标路径存在但不是目录,先删除 - if [ -e "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ] && [ ! -d "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ]; then - rm -f "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" - fi - mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" - cp -r export/harmony/* "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR/" - - echo -e "${GREEN}鸿蒙构建完成!${NC}" - echo "输出位置: $HARMONY_OUTPUT_DIR" - cd "$PROJECT_ROOT" -fi - -# ============================================================================ -# 构建 Python 版本 -# ============================================================================ -if [ "$BUILD_PYTHON" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建 Python 版本${NC}" - echo -e "${GREEN}========================================${NC}" - - # 检查 Python - if ! 
command -v $PYTHON_CMD &> /dev/null; then - echo -e "${RED}错误: 找不到 Python 命令 '$PYTHON_CMD'${NC}" - exit 1 - fi - - PYTHON_OUTPUT_DIR="$OUTPUT_DIR/python" - - # 使用之前创建的 build_pymnn.sh 脚本 - if [ -f "$PROJECT_ROOT/build_pymnn.sh" ]; then - PYTHON_BUILD_ARGS="-o $PYTHON_OUTPUT_DIR" - [ -n "$VERSION" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -v $VERSION" - [ -n "$DEPS_OPTIONS" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -d $DEPS_OPTIONS" - [ "$CLEAN_BUILD" = false ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --no-clean" - PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --python $PYTHON_CMD" - - bash "$PROJECT_ROOT/build_pymnn.sh" $PYTHON_BUILD_ARGS - else - echo -e "${YELLOW}警告: 未找到 build_pymnn.sh,使用基本构建方式...${NC}" - - cd pymnn/pip_package - - # 构建依赖 - if [ -n "$DEPS_OPTIONS" ]; then - $PYTHON_CMD build_deps.py $DEPS_OPTIONS - else - $PYTHON_CMD build_deps.py - fi - - # 构建 wheel - $PYTHON_CMD -m pip install -U numpy wheel setuptools --quiet - rm -rf build dist - - BUILD_ARGS="" - [ -n "$VERSION" ] && BUILD_ARGS="--version $VERSION" - [ -n "$DEPS_OPTIONS" ] && BUILD_ARGS="$BUILD_ARGS --deps $DEPS_OPTIONS" - - $PYTHON_CMD setup.py bdist_wheel $BUILD_ARGS - - mkdir -p "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR" - cp dist/*.whl "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR/" - - cd "$PROJECT_ROOT" - fi - - echo -e "${GREEN}Python 构建完成!${NC}" - echo "输出位置: $PYTHON_OUTPUT_DIR" -fi - -# ============================================================================ -# 总结 -# ============================================================================ -echo "" -echo -e "${GREEN}========================================${NC}" -echo -e "${GREEN}所有构建完成!${NC}" -echo -e "${GREEN}========================================${NC}" -echo "" -echo "输出目录: $PROJECT_ROOT/$OUTPUT_DIR" -echo "" -[ "$BUILD_ANDROID" = true ] && echo -e "${GREEN}Android:${NC} $OUTPUT_DIR/android" -[ "$BUILD_IOS" = true ] && echo -e "${GREEN}iOS (真机):${NC} $OUTPUT_DIR/ios/device" -[ "$BUILD_IOS_SIMULATOR" = true ] && echo -e "${GREEN}iOS (模拟器):${NC} 
$OUTPUT_DIR/ios/simulator" -[ "$BUILD_HARMONY" = true ] && echo -e "${GREEN}鸿蒙:${NC} $OUTPUT_DIR/harmony" -[ "$BUILD_PYTHON" = true ] && echo -e "${GREEN}Python:${NC} $OUTPUT_DIR/python" -echo "" - - diff --git a/docs/transformers/diffusion.md b/docs/transformers/diffusion.md index 609793f806..7de27bb216 100644 --- a/docs/transformers/diffusion.md +++ b/docs/transformers/diffusion.md @@ -20,8 +20,7 @@ https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1/tree/mai cd mnn_path/transformers/diffusion/export python onnx_export.py \ --model_path hf_sd_load_path \ - --output_path onnx_save_path \ - --opset 18 + --output_path onnx_save_path ``` 注意,上述脚本需要依赖torch/onnx/diffusers等库,可以安装conda环境: ``` diff --git a/source/backend/cpu/arm/CMakeLists.txt b/source/backend/cpu/arm/CMakeLists.txt index 61ebce6bdc..18fca54a4e 100644 --- a/source/backend/cpu/arm/CMakeLists.txt +++ b/source/backend/cpu/arm/CMakeLists.txt @@ -36,9 +36,6 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR AR if (MNN_KLEIDIAI) include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/KleidiAI.cmake) download_kleidiai_and_collect_sources() - if(MNN_KLEIDIAI_DEFAULT_ON) - add_definitions(-DMNN_DEFAULT_USE_KLEIDIAI) - endif() endif() if (MNN_SME2) diff --git a/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp b/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp deleted file mode 100644 index a700016c31..0000000000 --- a/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp +++ /dev/null @@ -1,19 +0,0 @@ -#include - -void CPUBilinearLineC4(float* dst, const float* A, const float* B, - const float* t, int8_t* zeroPoint, size_t number) { - float tf = *t; - float sf = 1.0f - tf; - size_t total = number << 2; - - size_t i = 0; - while (i < total) { - size_t vl = __riscv_vsetvl_e32m8(total - i); - vfloat32m8_t v = __riscv_vle32_v_f32m8(A + i, vl); - vfloat32m8_t result = __riscv_vfmul_vf_f32m8(v, sf, vl); - v = __riscv_vle32_v_f32m8(B + i, vl); - result = __riscv_vfmacc_vf_f32m8(result, 
tf, v, vl); - __riscv_vse32_v_f32m8(dst + i, result, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp b/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp deleted file mode 100644 index 5063c39bff..0000000000 --- a/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp +++ /dev/null @@ -1,33 +0,0 @@ -#include - -void CPUBilinearSampleC4(const float* src, float* dst, - const int32_t* position, const float* factor, - int8_t* zeroPoint, size_t number) { - const int pack = 4; - size_t i = 0; - - while (i < number) { - size_t vl = __riscv_vsetvl_e32m8(number - i); - vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + i, vl); - - for (int c = 0; c < pack; c++) { - vuint32m8_t voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vfloat32m8_t vr = __riscv_vluxei32_v_f32m8(src, voff, vl); - vfloat32m8_t vsf = __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl); - vr = __riscv_vfmul_vv_f32m8(vr, vsf, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vsf = __riscv_vluxei32_v_f32m8(src, voff, vl); - vr = __riscv_vfmacc_vv_f32m8(vr, vf, vsf, vl); - __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, vr, vl); - } - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp b/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp deleted file mode 100644 index 59bb28a039..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp +++ /dev/null @@ -1,29 +0,0 @@ -#include - -void MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { - ptrdiff_t srcStrideByte = srcStride * sizeof(float); - ptrdiff_t dstStrideByte = dstStride * sizeof(float); - size_t vl; - - for (size_t i = count; i > 0; i -= vl) { - vl = 
__riscv_vsetvl_e32m8(i); - vfloat32m8_t vs = __riscv_vlse32_v_f32m8(source + 0, srcStrideByte, vl); - vfloat32m8_t vd = __riscv_vlse32_v_f32m8(dest + 0, dstStrideByte, vl); - vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); - __riscv_vsse32_v_f32m8(dest + 0, dstStrideByte, vd, vl); - vs = __riscv_vlse32_v_f32m8(source + 1, srcStrideByte, vl); - vd = __riscv_vlse32_v_f32m8(dest + 1, dstStrideByte, vl); - vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); - __riscv_vsse32_v_f32m8(dest + 1, dstStrideByte, vd, vl); - vs = __riscv_vlse32_v_f32m8(source + 2, srcStrideByte, vl); - vd = __riscv_vlse32_v_f32m8(dest + 2, dstStrideByte, vl); - vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); - __riscv_vsse32_v_f32m8(dest + 2, dstStrideByte, vd, vl); - vs = __riscv_vlse32_v_f32m8(source + 3, srcStrideByte, vl); - vd = __riscv_vlse32_v_f32m8(dest + 3, dstStrideByte, vl); - vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); - __riscv_vsse32_v_f32m8(dest + 3, dstStrideByte, vd, vl); - source += vl * srcStride; - dest += vl * dstStride; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp b/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp deleted file mode 100644 index 6d966789f7..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include - -void MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) { - float beta = parameters[1]; - float minF = parameters[2]; - float maxF = parameters[3]; - const ptrdiff_t stride = 4 * sizeof(float); - - for (int y = 0; y < height; ++y) { - auto a = A + aStride * y; - auto b = B + 4 * y; - auto c = C + cStride * y; - float b0Beta = b[0] * beta; - float b1Beta = b[1] * beta; - float b2Beta = b[2] * beta; - float b3Beta = b[3] * beta; - size_t w = width; - - while (w > 0) { - size_t vl = __riscv_vsetvl_e32m8(w); - - vfloat32m8_t data = __riscv_vlse32_v_f32m8(a + 0, stride, vl); - data = 
__riscv_vfadd_vf_f32m8(data, b0Beta, vl); - data = __riscv_vfmax_vf_f32m8(data, minF, vl); - data = __riscv_vfmin_vf_f32m8(data, maxF, vl); - __riscv_vsse32_v_f32m8(c + 0, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(a + 1, stride, vl); - data = __riscv_vfadd_vf_f32m8(data, b1Beta, vl); - data = __riscv_vfmax_vf_f32m8(data, minF, vl); - data = __riscv_vfmin_vf_f32m8(data, maxF, vl); - __riscv_vsse32_v_f32m8(c + 1, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(a + 2, stride, vl); - data = __riscv_vfadd_vf_f32m8(data, b2Beta, vl); - data = __riscv_vfmax_vf_f32m8(data, minF, vl); - data = __riscv_vfmin_vf_f32m8(data, maxF, vl); - __riscv_vsse32_v_f32m8(c + 2, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(a + 3, stride, vl); - data = __riscv_vfadd_vf_f32m8(data, b3Beta, vl); - data = __riscv_vfmax_vf_f32m8(data, minF, vl); - data = __riscv_vfmin_vf_f32m8(data, maxF, vl); - __riscv_vsse32_v_f32m8(c + 3, stride, data, vl); - - a += 4 * vl; - c += 4 * vl; - w -= vl; - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp deleted file mode 100644 index 145cbea73f..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include - -void MNNBGRAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp deleted file mode 100644 index d46fe6c85b..0000000000 --- 
a/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNBGRAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m4(count - i); - vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); - vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); - - channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); - - channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); - - channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); - __riscv_vse8_v_u8m4(dest + i, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp deleted file mode 100644 index 684db6aed3..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNBRGToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m4(count - i); - vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); - vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); - - channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); - - channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); - - vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); - __riscv_vse8_v_u8m4(dest + i, result, vl); - i += vl; - } -} \ No newline at end of file diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp deleted file mode 100644 index a26243bdb8..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include - -void MNNBilinearLineC8(int8_t* dst, 
const int16_t* A, const int16_t* B, - const float* t, int8_t* zeroPoint, size_t number) { - int offset = *zeroPoint; - int8_t* dstPtr = dst; - - const int pack = 8; - const int16_t df = (int16_t)((*t) * 128.0f); - const int16_t sf = (int16_t)((1.0f - *t) * 128.0f); - const size_t total = number * pack; - const int32_t ROUND_HALF = 1 << 13; - - size_t vl; - for (size_t i = 0; i < total; i += vl) { - vl = __riscv_vsetvl_e16m4(total - i); - vint16m4_t v16 = __riscv_vle16_v_i16m4(A + i, vl); - vint32m8_t v32 = __riscv_vwmul_vx_i32m8(v16, sf, vl); - v16 = __riscv_vle16_v_i16m4(B + i, vl); - v32 = __riscv_vwmacc_vx_i32m8(v32, df, v16, vl); - - vbool4_t mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); - vint32m8_t tmp = __riscv_vadd_vx_i32m8(v32, ROUND_HALF, vl); - v32 = __riscv_vsub_vx_i32m8(v32, ROUND_HALF, vl); - v32 = __riscv_vmerge_vvm_i32m8(tmp, v32, mask, vl); - - tmp = __riscv_vsra_vx_i32m8(v32, 14, vl); - mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); - v32 = __riscv_vand_vx_i32m8(v32, 0x3FFF, vl); - vbool4_t hasRem = __riscv_vmsne_vx_i32m8_b4(v32, 0, vl); - mask = __riscv_vmand_mm_b4(mask, hasRem, vl); - - v32 = __riscv_vadd_vx_i32m8_mu(mask, tmp, tmp, 1, vl); - v32 = __riscv_vadd_vx_i32m8(v32, offset, vl); - v16 = __riscv_vnsra_wx_i16m4(v32, 0, vl); - vint8m2_t v8 = __riscv_vnsra_wx_i8m2(v16, 0, vl); - - __riscv_vse8_v_i8m2(dstPtr + i, v8, vl); - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp deleted file mode 100644 index bd111e3be4..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include - -void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, - const int32_t* position, const float* factor, - int8_t* zeroPoint, size_t number) { - int16_t offset = (int16_t)(*zeroPoint); - const int pack = 8; - size_t i = 0; - - while (i < number) { - size_t vl = __riscv_vsetvl_e32m8(number - i); - vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + 
i, vl); - vint16m4_t vdf = __riscv_vnsra_wx_i16m4( - __riscv_vfcvt_rtz_x_f_v_i32m8( - __riscv_vfmul_vf_f32m8(vf, 128.0f, vl), vl), 0, vl); - vint16m4_t vsf = __riscv_vnsra_wx_i16m4( - __riscv_vfcvt_rtz_x_f_v_i32m8( - __riscv_vfmul_vf_f32m8( - __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl), 128.0f, vl), vl), 0, vl); - - for (int c = 0; c < pack; c++) { - vuint32m8_t voff = __riscv_vadd_vx_u32m8( - __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 3, vl), - c, vl); - - vint16m4_t va = __riscv_vsub_vx_i16m4( - __riscv_vsext_vf2_i16m4( - __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); - - vint32m8_t vr = __riscv_vwmul_vv_i32m8(va, vsf, vl); - voff = __riscv_vadd_vx_u32m8( - __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 3, vl), - c, vl); - - vint16m4_t vb = __riscv_vsub_vx_i16m4( - __riscv_vsext_vf2_i16m4( - __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); - vr = __riscv_vwmacc_vv_i32m8(vr, vb, vdf, vl); - __riscv_vsse16_v_i16m4(dst + i * pack + c, 16, - __riscv_vnsra_wx_i16m4(vr, 0, vl), vl); - } - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp b/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp deleted file mode 100644 index 9d524f13ca..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNC3ToC4(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); - - vuint8m8_t alpha = 
__riscv_vmv_v_x_u8m8(255, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, alpha, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp b/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp deleted file mode 100644 index f82faf83f5..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp +++ /dev/null @@ -1,48 +0,0 @@ -#include - -void MNNConvRunForLineDepthwise( - float* dst, const float* src, const float* weight, - size_t width, size_t src_w_setup, - size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, - size_t height, size_t srcHStep, size_t dstHStep, - const float* bias, const float* parameters) { - float minV = parameters[0]; - float maxV = parameters[1]; - ptrdiff_t srcByteStride = src_w_setup * sizeof(float); - ptrdiff_t dstByteStride = 4 * sizeof(float); - - for (size_t y = 0; y < height; ++y) { - const float* srcY = src + y * srcHStep; - float* dstY = dst + y * dstHStep; - size_t dx = 0; - - while (dx < width) { - size_t vl = __riscv_vsetvl_e32m8(width - dx); - - for (int c = 0; c < 4; ++c) { - vfloat32m8_t acc = __riscv_vfmv_v_f_f32m8(bias[c], vl); - const float* srcBase = srcY + dx * src_w_setup + c; - const float* weightPtr = weight + c; - - for (size_t fy = 0; fy < fh; ++fy) { - const float* srcFy = srcBase + fy * dilateY_step; - - for (size_t fx = 0; fx < fw; ++fx) { - float w = *weightPtr; - weightPtr += 4; - const float* srcFx = srcFy + fx * dilateX_step; - vfloat32m8_t s = __riscv_vlse32_v_f32m8(srcFx, srcByteStride, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, w, s, vl); - } - } - - acc = __riscv_vfmax_vf_f32m8(acc, minV, vl); - acc = __riscv_vfmin_vf_f32m8(acc, maxV, vl); - float* dstAddr = dstY + dx * 4 + c; - __riscv_vsse32_v_f32m8(dstAddr, dstByteStride, acc, vl); - } - - dx += vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp b/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp deleted file mode 100644 index 3d8c4f13fc..0000000000 
--- a/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp +++ /dev/null @@ -1,22 +0,0 @@ -#include - -void MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { - ptrdiff_t srcStrideByte = srcStride * sizeof(float); - ptrdiff_t dstStrideByte = dstStride * sizeof(float); -size_t vl; - - for (size_t i = count; i > 0; i -= vl) { - vl = __riscv_vsetvl_e32m8(i); - vfloat32m8_t data = __riscv_vlse32_v_f32m8(source + 0, srcStrideByte, vl); - __riscv_vsse32_v_f32m8(dest + 0, dstStrideByte, data, vl); - data = __riscv_vlse32_v_f32m8(source + 1, srcStrideByte, vl); - __riscv_vsse32_v_f32m8(dest + 1, dstStrideByte, data, vl); - data = __riscv_vlse32_v_f32m8(source + 2, srcStrideByte, vl); - __riscv_vsse32_v_f32m8(dest + 2, dstStrideByte, data, vl); - data = __riscv_vlse32_v_f32m8(source + 3, srcStrideByte, vl); - __riscv_vsse32_v_f32m8(dest + 3, dstStrideByte, data, vl); - source += vl * srcStride; - dest += vl * dstStride; - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp b/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp deleted file mode 100644 index fd6ce7a274..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp +++ /dev/null @@ -1,53 +0,0 @@ -#include - -void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, - const float* C, const float* D, float* t, - int8_t* zeroPoint, size_t number, - ssize_t minValue, ssize_t maxValue) { - const float f = *t; - const float t2 = f * f, t3 = t2 * f; - const float b0 = 1.0f - 2.25f * t2 + 1.25f * t3; - const float t1 = 1.0f - f, t1_2 = t1 * t1; - const float c0 = 1.0f - 2.25f * t1_2 + 1.25f * t1_2 * t1; - const float ta = 1.0f + f, ta2 = ta * ta; - const float a0 = 3.0f - 6.0f * ta + 3.75f * ta2 - 0.75f * ta2 * ta; - const float td = 2.0f - f, td2 = td * td; - const float d0 = 3.0f - 6.0f * td + 3.75f * td2 - 0.75f * td2 * td; - const int offset = *zeroPoint; - const int minVal = (int)minValue; - const int maxVal = (int)maxValue; - const 
size_t total = number << 4; - size_t i = 0; - - while (i < total) { - size_t vl = __riscv_vsetvl_e32m8(total - i); - vfloat32m8_t v, acc; - - v = __riscv_vle32_v_f32m8(A + i, vl); - acc = __riscv_vfmul_vf_f32m8(v, a0, vl); - - v = __riscv_vle32_v_f32m8(B + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, b0, v, vl); - - v = __riscv_vle32_v_f32m8(C + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, c0, v, vl); - - v = __riscv_vle32_v_f32m8(D + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, d0, v, vl); - - vfloat32m8_t half = __riscv_vfmv_v_f_f32m8(0.5f, vl); - vfloat32m8_t signHalf = __riscv_vfsgnj_vv_f32m8(half, acc, vl); - acc = __riscv_vfadd_vv_f32m8(acc, signHalf, vl); - - vint32m8_t vint = __riscv_vfcvt_rtz_x_f_v_i32m8(acc, vl); - vint = __riscv_vadd_vx_i32m8(vint, offset, vl); - vint = __riscv_vmax_vx_i32m8(vint, minVal, vl); - vint = __riscv_vmin_vx_i32m8(vint, maxVal, vl); - - vint16m4_t vi16 = __riscv_vncvt_x_x_w_i16m4(vint, vl); - vint8m2_t vi8 = __riscv_vncvt_x_x_w_i8m2(vi16, vl); - __riscv_vse8_v_i8m2(dst + i, vi8, vl); - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp b/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp deleted file mode 100644 index 0da63ca0ff..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp +++ /dev/null @@ -1,38 +0,0 @@ -#include - -void MNNCubicLineC4(float* dst, const float* A, const float* B, - const float* C, const float* D, float* t, - int8_t* zeroPoint, size_t number, - ssize_t minValue, ssize_t maxValue) { - const float f = *t; - const float t2 = f * f, t3 = t2 * f; - const float b0 = 1.0f - 2.25f * t2 + 1.25f * t3; - const float t1 = 1.0f - f, t1_2 = t1 * t1; - const float c0 = 1.0f - 2.25f * t1_2 + 1.25f * t1_2 * t1; - const float ta = 1.0f + f, ta2 = ta * ta; - const float a0 = 3.0f - 6.0f * ta + 3.75f * ta2 - 0.75f * ta2 * ta; - const float td = 2.0f - f, td2 = td * td; - const float d0 = 3.0f - 6.0f * td + 3.75f * td2 - 0.75f * td2 * td; - const size_t total = number << 2; - size_t i = 
0; - - while (i < total) { - size_t vl = __riscv_vsetvl_e32m8(total - i); - vfloat32m8_t v, acc; - - v = __riscv_vle32_v_f32m8(A + i, vl); - acc = __riscv_vfmul_vf_f32m8(v, a0, vl); - - v = __riscv_vle32_v_f32m8(B + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, b0, v, vl); - - v = __riscv_vle32_v_f32m8(C + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, c0, v, vl); - - v = __riscv_vle32_v_f32m8(D + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, d0, v, vl); - - __riscv_vse32_v_f32m8(dst + i, acc, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp b/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp deleted file mode 100644 index fd5b24a53d..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp +++ /dev/null @@ -1,79 +0,0 @@ -#include - -void MNNCubicSampleC16(const int8_t* src, float* dst, - int32_t* position, const float* factor, - int8_t* zeroPoint, size_t number) { - const int pack = 16; - int8_t zp = *zeroPoint; - size_t i = 0; - - while (i < number) { - size_t vl = __riscv_vsetvl_e32m8(number - i); - vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl); - - for (int c = 0; c < pack; c++) { - vuint32m8_t voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c, vl); - - vint8m2_t vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); - vint16m4_t vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); - vfloat32m8_t vtmp = __riscv_vfcvt_f_x_v_f32m8( - __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); - - vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl); - vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl); - vfloat32m8_t vc = vtmp; - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c, vl); - - vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); - vtmp_i16 = 
__riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); - vfloat32m8_t vB = __riscv_vfcvt_f_x_v_f32m8( - __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); - - va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c, vl); - - vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); - vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); - vtmp = __riscv_vfcvt_f_x_v_f32m8( - __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); - - va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl); - vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl); - vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c, vl); - - vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); - vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); - vtmp = __riscv_vfcvt_f_x_v_f32m8( - __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); - - va = __riscv_vfmacc_vf_f32m8(va, 0.75f, vtmp, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, -0.75f, vtmp, vl); - - va = __riscv_vfmadd_vv_f32m8(va, vt, vb, vl); - va = __riscv_vfmadd_vv_f32m8(va, vt, vc, vl); - va = __riscv_vfmadd_vv_f32m8(va, vt, vB, vl); - - __riscv_vsse32_v_f32m8(dst + i * pack + c, pack * sizeof(float), va, vl); - } - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp b/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp deleted file mode 100644 index 78207e69e8..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp +++ /dev/null @@ -1,62 +0,0 @@ -#include - -void MNNCubicSampleC4(const float* src, float* dst, - int32_t* position, const float* factor, - int8_t* zeroPoint, size_t number) { - const int pack = 4; - size_t i = 0; - - while (i < 
number) { - size_t vl = __riscv_vsetvl_e32m8(number - i); - vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl); - - for (int c = 0; c < pack; c++) { - vuint32m8_t voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vfloat32m8_t vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); - - vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl); - vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl); - vfloat32m8_t vc = vtmp; - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vfloat32m8_t vB = __riscv_vluxei32_v_f32m8(src, voff, vl); - - va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); - - va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl); - vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl); - vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); - - va = __riscv_vfmacc_vf_f32m8(va, 0.75f, vtmp, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, -0.75f, vtmp, vl); - - va = __riscv_vfmadd_vv_f32m8(va, vt, vb, vl); - va = __riscv_vfmadd_vv_f32m8(va, vt, vc, vl); - va = __riscv_vfmadd_vv_f32m8(va, vt, vB, vl); - - __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, va, vl); - } - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp 
b/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp deleted file mode 100644 index 6658715e7e..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include - -void MNNDeconvRunForUnitDepthWise( - const float* dst, float* src, const float* weight, - size_t fw, size_t fh, - size_t weightY_step, size_t dilateX_step, size_t dilateY_step) { - const ptrdiff_t wStride = 4 * sizeof(float); - const ptrdiff_t sStride = dilateX_step * sizeof(float); - float d0 = dst[0], d1 = dst[1], d2 = dst[2], d3 = dst[3]; - - for (size_t fy = 0; fy < fh; ++fy) { - float* srcY = src + fy * dilateY_step; - const float* weightY = weight + fy * weightY_step; - - size_t fx = 0; - while (fx < fw) { - size_t vl = __riscv_vsetvl_e32m8(fw - fx); - - vfloat32m8_t w = __riscv_vlse32_v_f32m8(weightY + 0 + fx * 4, wStride, vl); - vfloat32m8_t s = __riscv_vlse32_v_f32m8(srcY + 0 + fx * dilateX_step, sStride, vl); - s = __riscv_vfmacc_vf_f32m8(s, d0, w, vl); - __riscv_vsse32_v_f32m8(srcY + 0 + fx * dilateX_step, sStride, s, vl); - - w = __riscv_vlse32_v_f32m8(weightY + 1 + fx * 4, wStride, vl); - s = __riscv_vlse32_v_f32m8(srcY + 1 + fx * dilateX_step, sStride, vl); - s = __riscv_vfmacc_vf_f32m8(s, d1, w, vl); - __riscv_vsse32_v_f32m8(srcY + 1 + fx * dilateX_step, sStride, s, vl); - - w = __riscv_vlse32_v_f32m8(weightY + 2 + fx * 4, wStride, vl); - s = __riscv_vlse32_v_f32m8(srcY + 2 + fx * dilateX_step, sStride, vl); - s = __riscv_vfmacc_vf_f32m8(s, d2, w, vl); - __riscv_vsse32_v_f32m8(srcY + 2 + fx * dilateX_step, sStride, s, vl); - - w = __riscv_vlse32_v_f32m8(weightY + 3 + fx * 4, wStride, vl); - s = __riscv_vlse32_v_f32m8(srcY + 3 + fx * dilateX_step, sStride, vl); - s = __riscv_vfmacc_vf_f32m8(s, d3, w, vl); - __riscv_vsse32_v_f32m8(srcY + 3 + fx * dilateX_step, sStride, s, vl); - - fx += vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp deleted file mode 
100644 index 952fcaf090..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include - -void MNNGRAYToC3(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); - __riscv_vsse8_v_u8m8(dest + i * 3 + 0, 3, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 3 + 1, 3, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 3 + 2, 3, gray, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp deleted file mode 100644 index 5ee4540f98..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include - -void MNNGRAYToC4(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); - vuint8m8_t alpha = __riscv_vmv_v_x_u8m8(255, vl); - __riscv_vsse8_v_u8m8(dest + i * 4 + 0, 4, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 4 + 1, 4, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 4 + 2, 4, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 4 + 3, 4, alpha, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp b/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp deleted file mode 100644 index 183a38bb10..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include -#include - -#define UNIT 4 - -void MNNMaxFloat(float *input, float *maxBuffer, int32_t inputCountUnit) { - const float init = -FLT_MAX; - for (int j = 0; j < UNIT; ++j) { - float local = init; - size_t i = 0; - - while (i < (size_t)inputCountUnit) { - size_t vl = __riscv_vsetvl_e32m8(inputCountUnit - i); - float *p0 = input + (i * UNIT * 2) + j * 2; - float *p1 = p0 + 1; - vfloat32m8_t v0 = __riscv_vlse32_v_f32m8(p0, UNIT * 2 * 
sizeof(float), vl); - vfloat32m8_t v1 = __riscv_vlse32_v_f32m8(p1, UNIT * 2 * sizeof(float), vl); - vfloat32m8_t vmax = __riscv_vfmax_vv_f32m8(v0, v1, vl); - vfloat32m1_t vred = __riscv_vfredmax_vs_f32m8_f32m1(vmax, __riscv_vfmv_s_f_f32m1(local, 1), vl); - local = __riscv_vfmv_f_s_f32m1_f32(vred); - i += vl; - } - maxBuffer[j] = local; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp b/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp deleted file mode 100644 index 9e8ade8641..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include -#include - -#define UNIT 4 - -void MNNMinFloat(float *input, float *minBuffer, int32_t inputCountUnit) { - const float init = FLT_MAX; - for (int j = 0; j < UNIT; ++j) { - float local = init; - size_t i = 0; - - while (i < (size_t)inputCountUnit) { - size_t vl = __riscv_vsetvl_e32m8(inputCountUnit - i); - float *p0 = input + (i * UNIT * 2) + j * 2; - float *p1 = p0 + 1; - vfloat32m8_t v0 = __riscv_vlse32_v_f32m8(p0, UNIT * 2 * sizeof(float), vl); - vfloat32m8_t v1 = __riscv_vlse32_v_f32m8(p1, UNIT * 2 * sizeof(float), vl); - vfloat32m8_t vmin = __riscv_vfmin_vv_f32m8(v0, v1, vl); - vfloat32m1_t vred = __riscv_vfredmin_vs_f32m8_f32m1(vmin, __riscv_vfmv_s_f_f32m1(local, 1), vl); - local = __riscv_vfmv_f_s_f32m1_f32(vred); - i += vl; - } - minBuffer[j] = local; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNPackC2.cpp b/source/backend/cpu/riscv/rvv/MNNPackC2.cpp deleted file mode 100644 index 9a74f8998d..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNPackC2.cpp +++ /dev/null @@ -1,74 +0,0 @@ -#include - -void MNNPackC2(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { - int depthC2 = depth / 2; - int depthRemain = depthC2 * 2; - int remain = depth - depthRemain; - const float *srcOffset = src; - const float *srcChannel[2]; - - for (int z = 0; z < depthC2; ++z) { - float *dstZ = dst + z * areaOffset[1] * 2; - - for (int y = 0; y < 2; ++y) { - 
srcChannel[y] = srcOffset + areaOffset[0] * y; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - float *dstPtr = dstZ + x * 2; - vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 0, 2 * sizeof(float), vec, vl); - vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 1, 2 * sizeof(float), vec, vl); - } - - for (; x < area; ++x) { - float *dstPtr = dstZ + x * 2; - dstPtr[0] = srcChannel[0][x]; - dstPtr[1] = srcChannel[1][x]; - } - - srcOffset += areaOffset[0] * 2; - } - - if (remain > 0) { - float *dstZ = dst + depthC2 * areaOffset[1] * 2; - - for (int y = 0; y < remain; ++y) { - srcChannel[y] = srcOffset + areaOffset[0] * y; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - float *dstPtr = dstZ + x * 2; - - for (int y = 0; y < remain; ++y) { - vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + y, 2 * sizeof(float), vec, vl); - } - - vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); - for (int y = remain; y < 2; ++y) { - __riscv_vsse32_v_f32m8(dstPtr + y, 2 * sizeof(float), zero, vl); - } - } - - for (; x < area; ++x) { - float *dstPtr = dstZ + x * 2; - - for (int y = 0; y < remain; ++y) { - dstPtr[y] = srcChannel[y][x]; - } - - for (int y = remain; y < 2; ++y) { - dstPtr[y] = 0.0f; - } - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNPackC4.cpp b/source/backend/cpu/riscv/rvv/MNNPackC4.cpp deleted file mode 100644 index 024e2c8c07..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNPackC4.cpp +++ /dev/null @@ -1,80 +0,0 @@ -#include - -void MNNPackC4(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { - int depthC4 = depth / 4; - int depthRemain = depthC4 * 4; - int remain = depth - depthRemain; - const float *srcOffset = src; - const float *srcChannel[4]; - - for (int z = 0; z < depthC4; ++z) { - 
float *dstZ = dst + z * areaOffset[1] * 4; - - for (int y = 0; y < 4; ++y) { - srcChannel[y] = srcOffset + areaOffset[0] * y; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - float *dstPtr = dstZ + x * 4; - vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 0, 4 * sizeof(float), vec, vl); - vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 1, 4 * sizeof(float), vec, vl); - vec = __riscv_vle32_v_f32m8(srcChannel[2] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 2, 4 * sizeof(float), vec, vl); - vec = __riscv_vle32_v_f32m8(srcChannel[3] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 3, 4 * sizeof(float), vec, vl); - } - - for (; x < area; ++x) { - float *dstPtr = dstZ + x * 4; - dstPtr[0] = srcChannel[0][x]; - dstPtr[1] = srcChannel[1][x]; - dstPtr[2] = srcChannel[2][x]; - dstPtr[3] = srcChannel[3][x]; - } - - srcOffset += areaOffset[0] * 4; - } - - if (remain > 0) { - float *dstZ = dst + depthC4 * areaOffset[1] * 4; - - for (int y = 0; y < remain; ++y) { - srcChannel[y] = srcOffset + areaOffset[0] * y; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - float *dstPtr = dstZ + x * 4; - - for (int y = 0; y < remain; ++y) { - vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), vec, vl); - } - - vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); - for (int y = remain; y < 4; ++y) { - __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), zero, vl); - } - } - - for (; x < area; ++x) { - float *dstPtr = dstZ + x * 4; - - for (int y = 0; y < remain; ++y) { - dstPtr[y] = srcChannel[y][x]; - } - - for (int y = remain; y < 4; ++y) { - dstPtr[y] = 0.0f; - } - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp deleted file mode 100644 index 
f2b6c7a78d..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include - -void MNNRGBAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp deleted file mode 100644 index ddd67a7d8c..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNRGBAToBGRA(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 3, 4, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp deleted file mode 100644 index d56b58546d..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNRGBAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { 
- size_t vl = __riscv_vsetvl_e8m4(count - i); - vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); - vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); - - channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); - - channel = __riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); - - channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); - __riscv_vse8_v_u8m4(dest + i, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp deleted file mode 100644 index 7c6decf39e..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include - -void MNNRGBToBGR(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp deleted file mode 100644 index 1b946c33cc..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNRGBToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m4(count - i); - vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); - vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); - - channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); - sum = 
__riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); - - channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); - - vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); - __riscv_vse8_v_u8m4(dest + i, result, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp b/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp deleted file mode 100644 index 262f4cbfab..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp +++ /dev/null @@ -1,45 +0,0 @@ -#include - -void MNNReluWithSlopeChannel(float *dst, const float *src, - const float *slope, size_t sizeQuad, - size_t depthQuad) { - const ptrdiff_t stride = 4 * sizeof(float); - - for (size_t j = 0; j < depthQuad; ++j) { - const float *srcZ = src + 4 * j * sizeQuad; - float *dstZ = dst + 4 * j * sizeQuad; - float s0 = slope[4*j], s1 = slope[4*j + 1]; - float s2 = slope[4*j + 2], s3 = slope[4*j + 3]; - size_t i = 0; - while (i < sizeQuad) { - size_t vl = __riscv_vsetvl_e32m8(sizeQuad - i); - const float *srcBase = srcZ + 4*i; - float *dstBase = dstZ + 4*i; - - vfloat32m8_t v; - vbool4_t mask; - - v = __riscv_vlse32_v_f32m8(srcBase, stride, vl); - mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); - v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s0, vl); - __riscv_vsse32_v_f32m8(dstBase, stride, v, vl); - - v = __riscv_vlse32_v_f32m8(srcBase + 1, stride, vl); - mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); - v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s1, vl); - __riscv_vsse32_v_f32m8(dstBase + 1, stride, v, vl); - - v = __riscv_vlse32_v_f32m8(srcBase + 2, stride, vl); - mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); - v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s2, vl); - __riscv_vsse32_v_f32m8(dstBase + 2, stride, v, vl); - - v = __riscv_vlse32_v_f32m8(srcBase + 3, stride, vl); - mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); - v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s3, vl); - __riscv_vsse32_v_f32m8(dstBase + 3, 
stride, v, vl); - - i += vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp b/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp deleted file mode 100644 index 10992f9d59..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include - -void MNNScaleAndAddBias(float *dst, const float *src, const float *bias, const float *alpha, size_t planeNumber, size_t biasNumber) { - const ptrdiff_t stride = 4 * sizeof(float); - - for (size_t z = 0; z < biasNumber; ++z) { - float *dstZ = dst + z * planeNumber * 4; - const float *srcZ = src + z * planeNumber * 4; - const float *biasZ = bias + 4 * z; - const float *alphaZ = alpha + 4 * z; - float b0 = biasZ[0], b1 = biasZ[1], b2 = biasZ[2], b3 = biasZ[3]; - float a0 = alphaZ[0], a1 = alphaZ[1], a2 = alphaZ[2], a3 = alphaZ[3]; - - size_t n = planeNumber; - while (n > 0) { - size_t vl = __riscv_vsetvl_e32m8(n); - vfloat32m8_t data = __riscv_vlse32_v_f32m8(srcZ + 0, stride, vl); - data = __riscv_vfmul_vf_f32m8(data, a0, vl); - data = __riscv_vfadd_vf_f32m8(data, b0, vl); - __riscv_vsse32_v_f32m8(dstZ + 0, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(srcZ + 1, stride, vl); - data = __riscv_vfmul_vf_f32m8(data, a1, vl); - data = __riscv_vfadd_vf_f32m8(data, b1, vl); - __riscv_vsse32_v_f32m8(dstZ + 1, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(srcZ + 2, stride, vl); - data = __riscv_vfmul_vf_f32m8(data, a2, vl); - data = __riscv_vfadd_vf_f32m8(data, b2, vl); - __riscv_vsse32_v_f32m8(dstZ + 2, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(srcZ + 3, stride, vl); - data = __riscv_vfmul_vf_f32m8(data, a3, vl); - data = __riscv_vfadd_vf_f32m8(data, b3, vl); - __riscv_vsse32_v_f32m8(dstZ + 3, stride, data, vl); - - srcZ += vl * 4; - dstZ += vl * 4; - n -= vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp b/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp deleted file mode 100644 index f510058c83..0000000000 --- 
a/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp +++ /dev/null @@ -1,80 +0,0 @@ -#include -#include - -void MNNSoftmax(float *dest, const float *source, size_t size) { - size_t n = size; - const float *sourcePtr = source; - float *destPtr = dest; - float maxValue = -FLT_MAX; - vfloat32m1_t maxVecValue = __riscv_vfmv_s_f_f32m1(maxValue, 1); - - while (n > 0) { - size_t vl = __riscv_vsetvl_e32m8(n); - vfloat32m8_t vSrc = __riscv_vle32_v_f32m8(sourcePtr, vl); - maxVecValue = __riscv_vfredmax_vs_f32m8_f32m1(vSrc, maxVecValue, vl); - sourcePtr += vl; - n -= vl; - } - - maxValue = __riscv_vfmv_f_s_f32m1_f32(maxVecValue); - const float param = 0.6931471805599453f; - const float xLimit = 87.0f; - float sumValue = 0.f; - vfloat32m1_t sumVecValue = __riscv_vfmv_s_f_f32m1(sumValue, 1); - n = size; - sourcePtr = source; - destPtr = dest; - - while (n > 0) { - size_t vl = __riscv_vsetvl_e32m8(n); - vfloat32m8_t vA = __riscv_vle32_v_f32m8(sourcePtr, vl); - vA = __riscv_vfsub_vf_f32m8(vA, maxValue, vl); - vA = __riscv_vfmax_vf_f32m8(vA, -xLimit, vl); - vA = __riscv_vfmin_vf_f32m8(vA, xLimit, vl); - - vfloat32m8_t vB = __riscv_vfdiv_vf_f32m8(vA, param, vl); - vint32m8_t vBI = __riscv_vfcvt_x_f_v_i32m8(vB, vl); - - vfloat32m8_t vC = __riscv_vreinterpret_v_i32m8_f32m8( - __riscv_vsll_vx_i32m8( - __riscv_vadd_vx_i32m8(vBI, 127, vl), 23, vl)); - - vB = __riscv_vfcvt_f_x_v_f32m8(vBI, vl); - vB = __riscv_vfnmsub_vf_f32m8(vB, param, vA, vl); - - vA = __riscv_vfmv_v_f_f32m8(1.0f / 120.0f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 1.0f / 24.0f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 1.0f / 6.0f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 0.5f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl); - - vA = __riscv_vfmul_vv_f32m8(vC, vA, vl); - 
__riscv_vse32_v_f32m8(destPtr, vA, vl); - sumVecValue = __riscv_vfredosum_vs_f32m8_f32m1(vA, sumVecValue, vl); - - sourcePtr += vl; - destPtr += vl; - n -= vl; - } - - sumValue = __riscv_vfmv_f_s_f32m1_f32(sumVecValue); - float sumInv = 1.0f / sumValue; - n = size; - destPtr = dest; - - while (n > 0) - { - size_t vl = __riscv_vsetvl_e32m8(n); - vfloat32m8_t vDest = __riscv_vle32_v_f32m8(destPtr, vl); - vDest = __riscv_vfmul_vf_f32m8(vDest, sumInv, vl); - __riscv_vse32_v_f32m8(destPtr, vDest, vl); - destPtr += vl; - n -= vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp b/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp deleted file mode 100644 index 8ab5bb89fa..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include - -void MNNStrassenMergeCFunction(float *c11, float *c12, float *c21, float *c22, - float *xAddr, size_t cStride, size_t eSub, size_t hSub) { - for (int y = 0; y < hSub; ++y) { - float *c11Y = c11 + y * cStride; - float *c12Y = c12 + y * cStride; - float *c22Y = c22 + y * cStride; - float *c21Y = c21 + y * cStride; - float *xY = xAddr + y * eSub * 4; - size_t totalElements = eSub * 4; - size_t p = 0; - - while (p < totalElements) { - size_t vl = __riscv_vsetvl_e32m8(totalElements - p); - vfloat32m8_t t = __riscv_vle32_v_f32m8(xY + p, vl); - vfloat32m8_t tmp = __riscv_vle32_v_f32m8(c12Y + p, vl); - t = __riscv_vfadd_vv_f32m8(t, tmp, vl); - vfloat32m8_t c22v = __riscv_vle32_v_f32m8(c22Y + p, vl); - - tmp = __riscv_vle32_v_f32m8(c11Y + p, vl); - tmp = __riscv_vfadd_vv_f32m8(tmp, c22v, vl); - tmp = __riscv_vfadd_vv_f32m8(tmp, t, vl); - __riscv_vse32_v_f32m8(c12Y + p, tmp, vl); - - tmp = __riscv_vle32_v_f32m8(c21Y + p, vl); - tmp = __riscv_vfadd_vv_f32m8(t, tmp, vl); - __riscv_vse32_v_f32m8(c21Y + p, tmp, vl); - - c22v = __riscv_vfadd_vv_f32m8(c22v, tmp, vl); - __riscv_vse32_v_f32m8(c22Y + p, c22v, vl); - - p += vl; - } - } -} diff --git 
a/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp b/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp deleted file mode 100644 index 7598d6f8ac..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#include - -void MNNTranspose16Bit(int16_t* dstO, const int16_t* srcO, int16_t* dim) { - int w = dim[0]; - int h = dim[1]; - int srcStride = dim[2]; - int dstStride = dim[3]; - ptrdiff_t srcStrideByte = srcStride * sizeof(int16_t); - - for (int i = 0; i < h; ++i) { - const int16_t* srcPtr = srcO + i; - int16_t* dstPtr = dstO + i * dstStride; - - int j = 0; - while (j < w) { - size_t vl = __riscv_vsetvl_e16m8(w - j); - vint16m8_t data = __riscv_vlse16_v_i16m8(srcPtr, srcStrideByte, vl); - __riscv_vse16_v_i16m8(dstPtr, data, vl); - srcPtr += vl * srcStride; - dstPtr += vl; - j += vl; - } - } -} - - diff --git a/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp b/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp deleted file mode 100644 index e5c5eb83e6..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include - -void MNNTranspose32Bit(int32_t* dstO, const int32_t* srcO, int32_t* dim) { - int w = dim[0]; - int h = dim[1]; - int srcStride = dim[2]; - int dstStride = dim[3]; - ptrdiff_t srcStrideByte = srcStride * sizeof(int32_t); - - for (int i = 0; i < h; ++i) { - const int32_t* srcPtr = srcO + i; - int32_t* dstPtr = dstO + i * dstStride; - - int j = 0; - while (j < w) { - size_t vl = __riscv_vsetvl_e32m8(w - j); - vint32m8_t data = __riscv_vlse32_v_i32m8(srcPtr, srcStrideByte, vl); - __riscv_vse32_v_i32m8(dstPtr, data, vl); - srcPtr += vl * srcStride; - dstPtr += vl; - j += vl; - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp b/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp deleted file mode 100644 index 4676e6dede..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include - -void MNNUnpackC4(float *dst, 
const float *src, size_t area, size_t depth, int *areaOffset) { - int depthC4 = depth / 4; - int depthRemain = depthC4 * 4; - int remain = depth - depthRemain; - const float *srcOffset = src; - - for (int z = 0; z < depthC4; ++z) { - float *dstZ[4]; - - for (int y = 0; y < 4; ++y) { - dstZ[y] = dst + (z * 4 + y) * areaOffset[1]; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - vfloat32m8_t vec = __riscv_vlse32_v_f32m8(srcOffset + 0, 4 * sizeof(float), vl); - __riscv_vse32_v_f32m8(dstZ[0] + x, vec, vl); - vec = __riscv_vlse32_v_f32m8(srcOffset + 1, 4 * sizeof(float), vl); - __riscv_vse32_v_f32m8(dstZ[1] + x, vec, vl); - vec = __riscv_vlse32_v_f32m8(srcOffset + 2, 4 * sizeof(float), vl); - __riscv_vse32_v_f32m8(dstZ[2] + x, vec, vl); - vec = __riscv_vlse32_v_f32m8(srcOffset + 3, 4 * sizeof(float), vl); - __riscv_vse32_v_f32m8(dstZ[3] + x, vec, vl); - srcOffset += 4 * vl; - } - - for (; x < area; ++x) { - dstZ[0][x] = srcOffset[0]; - dstZ[1][x] = srcOffset[1]; - dstZ[2][x] = srcOffset[2]; - dstZ[3][x] = srcOffset[3]; - srcOffset += (areaOffset[0] - area) * 4; - } - } - - if (remain > 0) { - float *dstZ = dst + depthC4 * areaOffset[1] * 4; - const float *srcBase = srcOffset; - - for (int y = 0; y < remain; ++y) { - float *dstChannel = dstZ + y * areaOffset[1]; - const float *srcChannel = srcBase + y; - - for (size_t x = 0; x < area; ++x) { - dstChannel[x] = srcChannel[0]; - srcChannel += 4; - } - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp b/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp deleted file mode 100644 index 7332360ce8..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include -#include - -#define UNIT 4 - -void MNNVectorTop1Float(float* input, float* maxValue, int32_t* maxIndex, size_t inputCountUnit) { - size_t n = inputCountUnit * UNIT; - float maxV = -FLT_MAX; - int32_t maxIdx = 0; - size_t vl; - - for (size_t i = 
0; i < n; ) { - vl = __riscv_vsetvl_e32m8(n - i); - vfloat32m8_t data = __riscv_vle32_v_f32m8(input + i, vl); - vfloat32m1_t scalar = __riscv_vfmv_s_f_f32m1(maxV, vl); - vfloat32m1_t result = __riscv_vfredmax_vs_f32m8_f32m1(data, scalar, vl); - maxV = __riscv_vfmv_f_s_f32m1_f32(result); - i += vl; - } - - for (size_t i = 0; i < n; ) { - vl = __riscv_vsetvl_e32m8(n - i); - vfloat32m8_t data = __riscv_vle32_v_f32m8(input + i, vl); - vbool4_t mask = __riscv_vmfeq_vf_f32m8_b4(data, maxV, vl); - long first = __riscv_vfirst_m_b4(mask, vl); - - if (first >= 0) { - maxIdx = i + first; - break; - } - - i += vl; - } - - maxValue[0] = maxV; - maxIndex[0] = maxIdx; -} diff --git a/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp b/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp deleted file mode 100644 index 8c199709ec..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include -#include - -#define UNIT 4 - -void MNNVectorTop1Int32(int32_t* input, int32_t* maxValue, int32_t* maxIndex, size_t inputCountUnit) { - size_t n = inputCountUnit * UNIT; - int32_t maxV = INT32_MIN; - int32_t maxIdx = 0; - size_t vl; - - for (size_t i = 0; i < n; ) { - vl = __riscv_vsetvl_e32m8(n - i); - vint32m8_t data = __riscv_vle32_v_i32m8(input + i, vl); - vint32m1_t scalar = __riscv_vmv_s_x_i32m1(maxV, vl); - vint32m1_t result = __riscv_vredmax_vs_i32m8_i32m1(data, scalar, vl); - maxV = __riscv_vmv_x_s_i32m1_i32(result); - i += vl; - } - - for (size_t i = 0; i < n; ) { - vl = __riscv_vsetvl_e32m8(n - i); - vint32m8_t data = __riscv_vle32_v_i32m8(input + i, vl); - vbool4_t mask = __riscv_vmseq_vx_i32m8_b4(data, maxV, vl); - long first = __riscv_vfirst_m_b4(mask, vl); - - if (first >= 0) { - maxIdx = i + first; - break; - } - - i += vl; - } - - maxValue[0] = maxV; - maxIndex[0] = maxIdx; -} diff --git a/source/core/Backend.hpp b/source/core/Backend.hpp index 6850b6b4f6..bcf618c3c9 100644 --- a/source/core/Backend.hpp +++ 
b/source/core/Backend.hpp @@ -68,11 +68,9 @@ struct RuntimeHint { // whether to use Arm sme2 cores when threads>1 bool useArmSme2Cores = true; -#ifdef MNN_DEFAULT_USE_KLEIDIAI - bool enableKleidiAI = true; -#else + bool enableKleidiAI = false; -#endif + // Use CPU Ids std::vector cpuIds; diff --git a/transformers/diffusion/export/onnx_export.py b/transformers/diffusion/export/onnx_export.py index 5516eb2fcc..21f05e83be 100644 --- a/transformers/diffusion/export/onnx_export.py +++ b/transformers/diffusion/export/onnx_export.py @@ -84,7 +84,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F num_tokens = pipeline.text_encoder.config.max_position_embeddings text_hidden_size = pipeline.text_encoder.config.hidden_size text_input = pipeline.tokenizer( - ["A sample prompt", "A sample prompt"], + "A sample prompt", padding="max_length", max_length=pipeline.tokenizer.model_max_length, truncation=True, @@ -97,7 +97,9 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path=output_path / "text_encoder" / "model.onnx", ordered_input_names=["input_ids"], output_names=["last_hidden_state", "pooler_output"], - dynamic_axes=None, + dynamic_axes={ + "input_ids": {0: "batch", 1: "sequence"}, + }, opset=opset, ) del pipeline.text_encoder @@ -115,9 +117,13 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F # False, ), output_path=unet_path, - ordered_input_names=["sample", "timestep", "encoder_hidden_states"], + ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], output_names=["out_sample"], # has to be different from "sample" for correct tracing - dynamic_axes=None, + dynamic_axes={ + "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + "timestep": {0: "batch"}, + "encoder_hidden_states": {0: "batch", 1: "sequence"}, + }, opset=opset, use_external_data_format=True, # UNet is > 2GB, so the weights need to be split ) @@ -143,7 +149,7 @@ 
def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F vae_in_channels = vae_encoder.config.in_channels vae_sample_size = vae_encoder.config.sample_size # need to get the raw tensor output (sample) from the encoder - vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].mode() + vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample() onnx_export( vae_encoder, model_args=( @@ -153,24 +159,30 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path=output_path / "vae_encoder" / "model.onnx", ordered_input_names=["sample", "return_dict"], output_names=["latent_sample"], - dynamic_axes=None, + dynamic_axes={ + "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, opset=opset, ) # VAE DECODER vae_decoder = pipeline.vae vae_latent_channels = vae_decoder.config.latent_channels + vae_out_channels = vae_decoder.config.out_channels # forward only through the decoder part - vae_decoder.forward = lambda latent: vae_decoder.decode(latent, return_dict=False)[0] + vae_decoder.forward = vae_encoder.decode onnx_export( vae_decoder, model_args=( torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), + False, ), output_path=output_path / "vae_decoder" / "model.onnx", - ordered_input_names=["latent_sample"], + ordered_input_names=["latent_sample", "return_dict"], output_names=["sample"], - dynamic_axes=None, + dynamic_axes={ + "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, opset=opset, ) del pipeline.vae From 265e56cdce07beadac3ae68863d2acaa3a4fe81d Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:36:21 +0800 Subject: [PATCH 023/314] Merge pull request #4067 from ihb2032/opt/rvv-pixel-conv opt(RVV): Optimize blitter functions with intrinsics GitOrigin-RevId: a22d2d445a0d106f5c9201cbedd49c7b168225c6 --- 
source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp | 18 +++++++++++++++++ .../backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp | 13 ++++++++++++ source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp | 16 +++++++++++++++ source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp | 17 ++++++++++++++++ .../backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp | 20 +++++++++++++++++++ .../backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp | 17 ++++++++++++++++ source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp | 20 +++++++++++++++++++ 11 files changed, 201 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp new file mode 100644 index 0000000000..145cbea73f --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp @@ -0,0 +1,18 @@ +#include + +void MNNBGRAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); + 
__riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp new file mode 100644 index 0000000000..d46fe6c85b --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNBGRAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp new file mode 100644 index 0000000000..684db6aed3 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNBRGToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); + sum = 
__riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, result, vl); + i += vl; + } +} \ No newline at end of file diff --git a/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp b/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp new file mode 100644 index 0000000000..9d524f13ca --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp @@ -0,0 +1,20 @@ +#include + +void MNNC3ToC4(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); + + vuint8m8_t alpha = __riscv_vmv_v_x_u8m8(255, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, alpha, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp new file mode 100644 index 0000000000..952fcaf090 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp @@ -0,0 +1,13 @@ +#include + +void MNNGRAYToC3(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); + __riscv_vsse8_v_u8m8(dest + i * 3 + 0, 3, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 3 + 1, 3, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 3 + 2, 3, gray, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp new file mode 100644 index 0000000000..5ee4540f98 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp @@ -0,0 +1,16 @@ 
+#include + +void MNNGRAYToC4(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); + vuint8m8_t alpha = __riscv_vmv_v_x_u8m8(255, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 0, 4, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 1, 4, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 2, 4, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 3, 4, alpha, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp new file mode 100644 index 0000000000..f2b6c7a78d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp @@ -0,0 +1,17 @@ +#include + +void MNNRGBAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp new file mode 100644 index 0000000000..ddd67a7d8c --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp @@ -0,0 +1,20 @@ +#include + +void MNNRGBAToBGRA(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 
4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 3, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp new file mode 100644 index 0000000000..d56b58546d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNRGBAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp new file mode 100644 index 0000000000..7c6decf39e --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp @@ -0,0 +1,17 @@ +#include + +void MNNRGBToBGR(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); + i += vl; + } +} diff 
--git a/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp new file mode 100644 index 0000000000..1b946c33cc --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNRGBToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, result, vl); + i += vl; + } +} From d6e0798e5f6e437175623a32778131bcb0cdbb4f Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:41:13 +0800 Subject: [PATCH 024/314] Merge pull request #4053 from ihb2032/opt/rvv-resize-functions opt(RVV): Optimize resize functions with intrinsics GitOrigin-RevId: 824f1b9ad56f611613f801eaa7e1c2ae2d3fd307 --- .../cpu/riscv/rvv/CPUBilinearLineC4.cpp | 19 +++++ .../cpu/riscv/rvv/CPUBilinearSampleC4.cpp | 33 ++++++++ .../cpu/riscv/rvv/MNNBilinearLineC8.cpp | 40 ++++++++++ .../cpu/riscv/rvv/MNNBilinearSampleC8.cpp | 49 ++++++++++++ .../backend/cpu/riscv/rvv/MNNCubicLineC16.cpp | 53 +++++++++++++ .../backend/cpu/riscv/rvv/MNNCubicLineC4.cpp | 38 +++++++++ .../cpu/riscv/rvv/MNNCubicSampleC16.cpp | 79 +++++++++++++++++++ .../cpu/riscv/rvv/MNNCubicSampleC4.cpp | 62 +++++++++++++++ 8 files changed, 373 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp create mode 100644 
source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp diff --git a/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp b/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp new file mode 100644 index 0000000000..a700016c31 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp @@ -0,0 +1,19 @@ +#include + +void CPUBilinearLineC4(float* dst, const float* A, const float* B, + const float* t, int8_t* zeroPoint, size_t number) { + float tf = *t; + float sf = 1.0f - tf; + size_t total = number << 2; + + size_t i = 0; + while (i < total) { + size_t vl = __riscv_vsetvl_e32m8(total - i); + vfloat32m8_t v = __riscv_vle32_v_f32m8(A + i, vl); + vfloat32m8_t result = __riscv_vfmul_vf_f32m8(v, sf, vl); + v = __riscv_vle32_v_f32m8(B + i, vl); + result = __riscv_vfmacc_vf_f32m8(result, tf, v, vl); + __riscv_vse32_v_f32m8(dst + i, result, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp b/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp new file mode 100644 index 0000000000..5063c39bff --- /dev/null +++ b/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp @@ -0,0 +1,33 @@ +#include + +void CPUBilinearSampleC4(const float* src, float* dst, + const int32_t* position, const float* factor, + int8_t* zeroPoint, size_t number) { + const int pack = 4; + size_t i = 0; + + while (i < number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + i, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vfloat32m8_t vr = 
__riscv_vluxei32_v_f32m8(src, voff, vl); + vfloat32m8_t vsf = __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl); + vr = __riscv_vfmul_vv_f32m8(vr, vsf, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vsf = __riscv_vluxei32_v_f32m8(src, voff, vl); + vr = __riscv_vfmacc_vv_f32m8(vr, vf, vsf, vl); + __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, vr, vl); + } + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp new file mode 100644 index 0000000000..a26243bdb8 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp @@ -0,0 +1,40 @@ +#include + +void MNNBilinearLineC8(int8_t* dst, const int16_t* A, const int16_t* B, + const float* t, int8_t* zeroPoint, size_t number) { + int offset = *zeroPoint; + int8_t* dstPtr = dst; + + const int pack = 8; + const int16_t df = (int16_t)((*t) * 128.0f); + const int16_t sf = (int16_t)((1.0f - *t) * 128.0f); + const size_t total = number * pack; + const int32_t ROUND_HALF = 1 << 13; + + size_t vl; + for (size_t i = 0; i < total; i += vl) { + vl = __riscv_vsetvl_e16m4(total - i); + vint16m4_t v16 = __riscv_vle16_v_i16m4(A + i, vl); + vint32m8_t v32 = __riscv_vwmul_vx_i32m8(v16, sf, vl); + v16 = __riscv_vle16_v_i16m4(B + i, vl); + v32 = __riscv_vwmacc_vx_i32m8(v32, df, v16, vl); + + vbool4_t mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); + vint32m8_t tmp = __riscv_vadd_vx_i32m8(v32, ROUND_HALF, vl); + v32 = __riscv_vsub_vx_i32m8(v32, ROUND_HALF, vl); + v32 = __riscv_vmerge_vvm_i32m8(tmp, v32, mask, vl); + + tmp = __riscv_vsra_vx_i32m8(v32, 14, vl); + mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); + v32 = __riscv_vand_vx_i32m8(v32, 0x3FFF, vl); + vbool4_t hasRem = __riscv_vmsne_vx_i32m8_b4(v32, 0, vl); + mask = __riscv_vmand_mm_b4(mask, hasRem, vl); + + v32 = __riscv_vadd_vx_i32m8_mu(mask, tmp, tmp, 1, vl); + v32 = 
__riscv_vadd_vx_i32m8(v32, offset, vl); + v16 = __riscv_vnsra_wx_i16m4(v32, 0, vl); + vint8m2_t v8 = __riscv_vnsra_wx_i8m2(v16, 0, vl); + + __riscv_vse8_v_i8m2(dstPtr + i, v8, vl); + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp new file mode 100644 index 0000000000..bd111e3be4 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp @@ -0,0 +1,49 @@ +#include + +void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, + const int32_t* position, const float* factor, + int8_t* zeroPoint, size_t number) { + int16_t offset = (int16_t)(*zeroPoint); + const int pack = 8; + size_t i = 0; + + while (i < number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + i, vl); + vint16m4_t vdf = __riscv_vnsra_wx_i16m4( + __riscv_vfcvt_rtz_x_f_v_i32m8( + __riscv_vfmul_vf_f32m8(vf, 128.0f, vl), vl), 0, vl); + vint16m4_t vsf = __riscv_vnsra_wx_i16m4( + __riscv_vfcvt_rtz_x_f_v_i32m8( + __riscv_vfmul_vf_f32m8( + __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl), 128.0f, vl), vl), 0, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vadd_vx_u32m8( + __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 3, vl), + c, vl); + + vint16m4_t va = __riscv_vsub_vx_i16m4( + __riscv_vsext_vf2_i16m4( + __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); + + vint32m8_t vr = __riscv_vwmul_vv_i32m8(va, vsf, vl); + voff = __riscv_vadd_vx_u32m8( + __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 3, vl), + c, vl); + + vint16m4_t vb = __riscv_vsub_vx_i16m4( + __riscv_vsext_vf2_i16m4( + __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); + vr = __riscv_vwmacc_vv_i32m8(vr, vb, vdf, vl); + __riscv_vsse16_v_i16m4(dst + i * pack + c, 16, + __riscv_vnsra_wx_i16m4(vr, 0, vl), vl); + } + + i += vl; + } +} diff --git 
a/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp b/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp new file mode 100644 index 0000000000..fd6ce7a274 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp @@ -0,0 +1,53 @@ +#include + +void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, + const float* C, const float* D, float* t, + int8_t* zeroPoint, size_t number, + ssize_t minValue, ssize_t maxValue) { + const float f = *t; + const float t2 = f * f, t3 = t2 * f; + const float b0 = 1.0f - 2.25f * t2 + 1.25f * t3; + const float t1 = 1.0f - f, t1_2 = t1 * t1; + const float c0 = 1.0f - 2.25f * t1_2 + 1.25f * t1_2 * t1; + const float ta = 1.0f + f, ta2 = ta * ta; + const float a0 = 3.0f - 6.0f * ta + 3.75f * ta2 - 0.75f * ta2 * ta; + const float td = 2.0f - f, td2 = td * td; + const float d0 = 3.0f - 6.0f * td + 3.75f * td2 - 0.75f * td2 * td; + const int offset = *zeroPoint; + const int minVal = (int)minValue; + const int maxVal = (int)maxValue; + const size_t total = number << 4; + size_t i = 0; + + while (i < total) { + size_t vl = __riscv_vsetvl_e32m8(total - i); + vfloat32m8_t v, acc; + + v = __riscv_vle32_v_f32m8(A + i, vl); + acc = __riscv_vfmul_vf_f32m8(v, a0, vl); + + v = __riscv_vle32_v_f32m8(B + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, b0, v, vl); + + v = __riscv_vle32_v_f32m8(C + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, c0, v, vl); + + v = __riscv_vle32_v_f32m8(D + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, d0, v, vl); + + vfloat32m8_t half = __riscv_vfmv_v_f_f32m8(0.5f, vl); + vfloat32m8_t signHalf = __riscv_vfsgnj_vv_f32m8(half, acc, vl); + acc = __riscv_vfadd_vv_f32m8(acc, signHalf, vl); + + vint32m8_t vint = __riscv_vfcvt_rtz_x_f_v_i32m8(acc, vl); + vint = __riscv_vadd_vx_i32m8(vint, offset, vl); + vint = __riscv_vmax_vx_i32m8(vint, minVal, vl); + vint = __riscv_vmin_vx_i32m8(vint, maxVal, vl); + + vint16m4_t vi16 = __riscv_vncvt_x_x_w_i16m4(vint, vl); + vint8m2_t vi8 = __riscv_vncvt_x_x_w_i8m2(vi16, vl); + 
__riscv_vse8_v_i8m2(dst + i, vi8, vl); + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp b/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp new file mode 100644 index 0000000000..0da63ca0ff --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp @@ -0,0 +1,38 @@ +#include + +void MNNCubicLineC4(float* dst, const float* A, const float* B, + const float* C, const float* D, float* t, + int8_t* zeroPoint, size_t number, + ssize_t minValue, ssize_t maxValue) { + const float f = *t; + const float t2 = f * f, t3 = t2 * f; + const float b0 = 1.0f - 2.25f * t2 + 1.25f * t3; + const float t1 = 1.0f - f, t1_2 = t1 * t1; + const float c0 = 1.0f - 2.25f * t1_2 + 1.25f * t1_2 * t1; + const float ta = 1.0f + f, ta2 = ta * ta; + const float a0 = 3.0f - 6.0f * ta + 3.75f * ta2 - 0.75f * ta2 * ta; + const float td = 2.0f - f, td2 = td * td; + const float d0 = 3.0f - 6.0f * td + 3.75f * td2 - 0.75f * td2 * td; + const size_t total = number << 2; + size_t i = 0; + + while (i < total) { + size_t vl = __riscv_vsetvl_e32m8(total - i); + vfloat32m8_t v, acc; + + v = __riscv_vle32_v_f32m8(A + i, vl); + acc = __riscv_vfmul_vf_f32m8(v, a0, vl); + + v = __riscv_vle32_v_f32m8(B + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, b0, v, vl); + + v = __riscv_vle32_v_f32m8(C + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, c0, v, vl); + + v = __riscv_vle32_v_f32m8(D + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, d0, v, vl); + + __riscv_vse32_v_f32m8(dst + i, acc, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp b/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp new file mode 100644 index 0000000000..fd5b24a53d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp @@ -0,0 +1,79 @@ +#include + +void MNNCubicSampleC16(const int8_t* src, float* dst, + int32_t* position, const float* factor, + int8_t* zeroPoint, size_t number) { + const int pack = 16; + int8_t zp = *zeroPoint; + size_t i = 0; + + while (i < 
number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vint8m2_t vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vint16m4_t vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vfloat32m8_t vtmp = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl); + vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl); + vfloat32m8_t vc = vtmp; + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vfloat32m8_t vB = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vtmp = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl); + vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl); + vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vtmp_i8 = 
__riscv_vluxei32_v_i8m2(src, voff, vl); + vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vtmp = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + va = __riscv_vfmacc_vf_f32m8(va, 0.75f, vtmp, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, -0.75f, vtmp, vl); + + va = __riscv_vfmadd_vv_f32m8(va, vt, vb, vl); + va = __riscv_vfmadd_vv_f32m8(va, vt, vc, vl); + va = __riscv_vfmadd_vv_f32m8(va, vt, vB, vl); + + __riscv_vsse32_v_f32m8(dst + i * pack + c, pack * sizeof(float), va, vl); + } + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp b/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp new file mode 100644 index 0000000000..78207e69e8 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp @@ -0,0 +1,62 @@ +#include + +void MNNCubicSampleC4(const float* src, float* dst, + int32_t* position, const float* factor, + int8_t* zeroPoint, size_t number) { + const int pack = 4; + size_t i = 0; + + while (i < number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vfloat32m8_t vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); + + vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl); + vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl); + vfloat32m8_t vc = vtmp; + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vfloat32m8_t vB = __riscv_vluxei32_v_f32m8(src, voff, vl); + + va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + 
__riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); + + va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl); + vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl); + vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); + + va = __riscv_vfmacc_vf_f32m8(va, 0.75f, vtmp, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, -0.75f, vtmp, vl); + + va = __riscv_vfmadd_vv_f32m8(va, vt, vb, vl); + va = __riscv_vfmadd_vv_f32m8(va, vt, vc, vl); + va = __riscv_vfmadd_vv_f32m8(va, vt, vB, vl); + + __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, va, vl); + } + + i += vl; + } +} From 7230c3537637cc1dd966a10063b9605deb862b28 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:41:55 +0800 Subject: [PATCH 025/314] Merge pull request #4050 from ihb2032/opt/rvv-top1 opt(RVV): Optimize top1 functions with intrinsics GitOrigin-RevId: 070c444b927aab4db76297e217bfe92a4508b294 --- .../cpu/riscv/rvv/MNNVectorTop1Float.cpp | 37 +++++++++++++++++++ .../cpu/riscv/rvv/MNNVectorTop1Int32.cpp | 37 +++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp b/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp new file mode 100644 index 0000000000..7332360ce8 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp @@ -0,0 +1,37 @@ +#include +#include + +#define UNIT 4 + +void MNNVectorTop1Float(float* input, float* maxValue, int32_t* maxIndex, size_t inputCountUnit) { + size_t n = inputCountUnit * UNIT; + float 
maxV = -FLT_MAX; + int32_t maxIdx = 0; + size_t vl; + + for (size_t i = 0; i < n; ) { + vl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t data = __riscv_vle32_v_f32m8(input + i, vl); + vfloat32m1_t scalar = __riscv_vfmv_s_f_f32m1(maxV, vl); + vfloat32m1_t result = __riscv_vfredmax_vs_f32m8_f32m1(data, scalar, vl); + maxV = __riscv_vfmv_f_s_f32m1_f32(result); + i += vl; + } + + for (size_t i = 0; i < n; ) { + vl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t data = __riscv_vle32_v_f32m8(input + i, vl); + vbool4_t mask = __riscv_vmfeq_vf_f32m8_b4(data, maxV, vl); + long first = __riscv_vfirst_m_b4(mask, vl); + + if (first >= 0) { + maxIdx = i + first; + break; + } + + i += vl; + } + + maxValue[0] = maxV; + maxIndex[0] = maxIdx; +} diff --git a/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp b/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp new file mode 100644 index 0000000000..8c199709ec --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp @@ -0,0 +1,37 @@ +#include +#include + +#define UNIT 4 + +void MNNVectorTop1Int32(int32_t* input, int32_t* maxValue, int32_t* maxIndex, size_t inputCountUnit) { + size_t n = inputCountUnit * UNIT; + int32_t maxV = INT32_MIN; + int32_t maxIdx = 0; + size_t vl; + + for (size_t i = 0; i < n; ) { + vl = __riscv_vsetvl_e32m8(n - i); + vint32m8_t data = __riscv_vle32_v_i32m8(input + i, vl); + vint32m1_t scalar = __riscv_vmv_s_x_i32m1(maxV, vl); + vint32m1_t result = __riscv_vredmax_vs_i32m8_i32m1(data, scalar, vl); + maxV = __riscv_vmv_x_s_i32m1_i32(result); + i += vl; + } + + for (size_t i = 0; i < n; ) { + vl = __riscv_vsetvl_e32m8(n - i); + vint32m8_t data = __riscv_vle32_v_i32m8(input + i, vl); + vbool4_t mask = __riscv_vmseq_vx_i32m8_b4(data, maxV, vl); + long first = __riscv_vfirst_m_b4(mask, vl); + + if (first >= 0) { + maxIdx = i + first; + break; + } + + i += vl; + } + + maxValue[0] = maxV; + maxIndex[0] = maxIdx; +} From 64172875f9257e33e25536cfbd205b4f698071b1 Mon Sep 17 00:00:00 2001 From: MNNSyncBot 
Date: Mon, 22 Dec 2025 10:42:36 +0800 Subject: [PATCH 026/314] Merge pull request #4044 from ihb2032/opt/rvv-softmax-relu opt(RVV): Optimize Softmax and ReluWithSlopeChannel with intrinsics GitOrigin-RevId: 98d2f9db51b45bf1deda2fb22398e56b323b5ae2 --- .../cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp | 45 +++++++++++ source/backend/cpu/riscv/rvv/MNNSoftmax.cpp | 80 +++++++++++++++++++ 2 files changed, 125 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNSoftmax.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp b/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp new file mode 100644 index 0000000000..262f4cbfab --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp @@ -0,0 +1,45 @@ +#include + +void MNNReluWithSlopeChannel(float *dst, const float *src, + const float *slope, size_t sizeQuad, + size_t depthQuad) { + const ptrdiff_t stride = 4 * sizeof(float); + + for (size_t j = 0; j < depthQuad; ++j) { + const float *srcZ = src + 4 * j * sizeQuad; + float *dstZ = dst + 4 * j * sizeQuad; + float s0 = slope[4*j], s1 = slope[4*j + 1]; + float s2 = slope[4*j + 2], s3 = slope[4*j + 3]; + size_t i = 0; + while (i < sizeQuad) { + size_t vl = __riscv_vsetvl_e32m8(sizeQuad - i); + const float *srcBase = srcZ + 4*i; + float *dstBase = dstZ + 4*i; + + vfloat32m8_t v; + vbool4_t mask; + + v = __riscv_vlse32_v_f32m8(srcBase, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s0, vl); + __riscv_vsse32_v_f32m8(dstBase, stride, v, vl); + + v = __riscv_vlse32_v_f32m8(srcBase + 1, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s1, vl); + __riscv_vsse32_v_f32m8(dstBase + 1, stride, v, vl); + + v = __riscv_vlse32_v_f32m8(srcBase + 2, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, 
v, s2, vl); + __riscv_vsse32_v_f32m8(dstBase + 2, stride, v, vl); + + v = __riscv_vlse32_v_f32m8(srcBase + 3, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s3, vl); + __riscv_vsse32_v_f32m8(dstBase + 3, stride, v, vl); + + i += vl; + } + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp b/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp new file mode 100644 index 0000000000..f510058c83 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp @@ -0,0 +1,80 @@ +#include +#include + +void MNNSoftmax(float *dest, const float *source, size_t size) { + size_t n = size; + const float *sourcePtr = source; + float *destPtr = dest; + float maxValue = -FLT_MAX; + vfloat32m1_t maxVecValue = __riscv_vfmv_s_f_f32m1(maxValue, 1); + + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t vSrc = __riscv_vle32_v_f32m8(sourcePtr, vl); + maxVecValue = __riscv_vfredmax_vs_f32m8_f32m1(vSrc, maxVecValue, vl); + sourcePtr += vl; + n -= vl; + } + + maxValue = __riscv_vfmv_f_s_f32m1_f32(maxVecValue); + const float param = 0.6931471805599453f; + const float xLimit = 87.0f; + float sumValue = 0.f; + vfloat32m1_t sumVecValue = __riscv_vfmv_s_f_f32m1(sumValue, 1); + n = size; + sourcePtr = source; + destPtr = dest; + + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t vA = __riscv_vle32_v_f32m8(sourcePtr, vl); + vA = __riscv_vfsub_vf_f32m8(vA, maxValue, vl); + vA = __riscv_vfmax_vf_f32m8(vA, -xLimit, vl); + vA = __riscv_vfmin_vf_f32m8(vA, xLimit, vl); + + vfloat32m8_t vB = __riscv_vfdiv_vf_f32m8(vA, param, vl); + vint32m8_t vBI = __riscv_vfcvt_x_f_v_i32m8(vB, vl); + + vfloat32m8_t vC = __riscv_vreinterpret_v_i32m8_f32m8( + __riscv_vsll_vx_i32m8( + __riscv_vadd_vx_i32m8(vBI, 127, vl), 23, vl)); + + vB = __riscv_vfcvt_f_x_v_f32m8(vBI, vl); + vB = __riscv_vfnmsub_vf_f32m8(vB, param, vA, vl); + + vA = __riscv_vfmv_v_f_f32m8(1.0f / 120.0f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA 
= __riscv_vfadd_vf_f32m8(vA, 1.0f / 24.0f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 1.0f / 6.0f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 0.5f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl); + + vA = __riscv_vfmul_vv_f32m8(vC, vA, vl); + __riscv_vse32_v_f32m8(destPtr, vA, vl); + sumVecValue = __riscv_vfredosum_vs_f32m8_f32m1(vA, sumVecValue, vl); + + sourcePtr += vl; + destPtr += vl; + n -= vl; + } + + sumValue = __riscv_vfmv_f_s_f32m1_f32(sumVecValue); + float sumInv = 1.0f / sumValue; + n = size; + destPtr = dest; + + while (n > 0) + { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t vDest = __riscv_vle32_v_f32m8(destPtr, vl); + vDest = __riscv_vfmul_vf_f32m8(vDest, sumInv, vl); + __riscv_vse32_v_f32m8(destPtr, vDest, vl); + destPtr += vl; + n -= vl; + } +} From b40cd5239b5cbc8e8fbdbe7b09ff75be128e337a Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:42:54 +0800 Subject: [PATCH 027/314] Merge pull request #4042 from ihb2032/opt/rvv-conv-strassen opt(RVV): Optimize conv and strassen functions with intrinsics GitOrigin-RevId: bf461aa6e424685c2bc16570bc44220b65418ead --- .../riscv/rvv/MNNConvRunForLineDepthwise.cpp | 48 +++++++++++++++++++ .../rvv/MNNDeconvRunForUnitDepthWise.cpp | 42 ++++++++++++++++ .../riscv/rvv/MNNStrassenMergeCFunction.cpp | 36 ++++++++++++++ 3 files changed, 126 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp b/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp new file mode 100644 index 0000000000..f82faf83f5 --- /dev/null +++ 
b/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp @@ -0,0 +1,48 @@ +#include + +void MNNConvRunForLineDepthwise( + float* dst, const float* src, const float* weight, + size_t width, size_t src_w_setup, + size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, + size_t height, size_t srcHStep, size_t dstHStep, + const float* bias, const float* parameters) { + float minV = parameters[0]; + float maxV = parameters[1]; + ptrdiff_t srcByteStride = src_w_setup * sizeof(float); + ptrdiff_t dstByteStride = 4 * sizeof(float); + + for (size_t y = 0; y < height; ++y) { + const float* srcY = src + y * srcHStep; + float* dstY = dst + y * dstHStep; + size_t dx = 0; + + while (dx < width) { + size_t vl = __riscv_vsetvl_e32m8(width - dx); + + for (int c = 0; c < 4; ++c) { + vfloat32m8_t acc = __riscv_vfmv_v_f_f32m8(bias[c], vl); + const float* srcBase = srcY + dx * src_w_setup + c; + const float* weightPtr = weight + c; + + for (size_t fy = 0; fy < fh; ++fy) { + const float* srcFy = srcBase + fy * dilateY_step; + + for (size_t fx = 0; fx < fw; ++fx) { + float w = *weightPtr; + weightPtr += 4; + const float* srcFx = srcFy + fx * dilateX_step; + vfloat32m8_t s = __riscv_vlse32_v_f32m8(srcFx, srcByteStride, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, w, s, vl); + } + } + + acc = __riscv_vfmax_vf_f32m8(acc, minV, vl); + acc = __riscv_vfmin_vf_f32m8(acc, maxV, vl); + float* dstAddr = dstY + dx * 4 + c; + __riscv_vsse32_v_f32m8(dstAddr, dstByteStride, acc, vl); + } + + dx += vl; + } + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp b/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp new file mode 100644 index 0000000000..6658715e7e --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp @@ -0,0 +1,42 @@ +#include + +void MNNDeconvRunForUnitDepthWise( + const float* dst, float* src, const float* weight, + size_t fw, size_t fh, + size_t weightY_step, size_t dilateX_step, size_t dilateY_step) { + const 
ptrdiff_t wStride = 4 * sizeof(float); + const ptrdiff_t sStride = dilateX_step * sizeof(float); + float d0 = dst[0], d1 = dst[1], d2 = dst[2], d3 = dst[3]; + + for (size_t fy = 0; fy < fh; ++fy) { + float* srcY = src + fy * dilateY_step; + const float* weightY = weight + fy * weightY_step; + + size_t fx = 0; + while (fx < fw) { + size_t vl = __riscv_vsetvl_e32m8(fw - fx); + + vfloat32m8_t w = __riscv_vlse32_v_f32m8(weightY + 0 + fx * 4, wStride, vl); + vfloat32m8_t s = __riscv_vlse32_v_f32m8(srcY + 0 + fx * dilateX_step, sStride, vl); + s = __riscv_vfmacc_vf_f32m8(s, d0, w, vl); + __riscv_vsse32_v_f32m8(srcY + 0 + fx * dilateX_step, sStride, s, vl); + + w = __riscv_vlse32_v_f32m8(weightY + 1 + fx * 4, wStride, vl); + s = __riscv_vlse32_v_f32m8(srcY + 1 + fx * dilateX_step, sStride, vl); + s = __riscv_vfmacc_vf_f32m8(s, d1, w, vl); + __riscv_vsse32_v_f32m8(srcY + 1 + fx * dilateX_step, sStride, s, vl); + + w = __riscv_vlse32_v_f32m8(weightY + 2 + fx * 4, wStride, vl); + s = __riscv_vlse32_v_f32m8(srcY + 2 + fx * dilateX_step, sStride, vl); + s = __riscv_vfmacc_vf_f32m8(s, d2, w, vl); + __riscv_vsse32_v_f32m8(srcY + 2 + fx * dilateX_step, sStride, s, vl); + + w = __riscv_vlse32_v_f32m8(weightY + 3 + fx * 4, wStride, vl); + s = __riscv_vlse32_v_f32m8(srcY + 3 + fx * dilateX_step, sStride, vl); + s = __riscv_vfmacc_vf_f32m8(s, d3, w, vl); + __riscv_vsse32_v_f32m8(srcY + 3 + fx * dilateX_step, sStride, s, vl); + + fx += vl; + } + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp b/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp new file mode 100644 index 0000000000..8ab5bb89fa --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp @@ -0,0 +1,36 @@ +#include + +void MNNStrassenMergeCFunction(float *c11, float *c12, float *c21, float *c22, + float *xAddr, size_t cStride, size_t eSub, size_t hSub) { + for (int y = 0; y < hSub; ++y) { + float *c11Y = c11 + y * cStride; + float *c12Y = c12 + y * cStride; + float 
*c22Y = c22 + y * cStride; + float *c21Y = c21 + y * cStride; + float *xY = xAddr + y * eSub * 4; + size_t totalElements = eSub * 4; + size_t p = 0; + + while (p < totalElements) { + size_t vl = __riscv_vsetvl_e32m8(totalElements - p); + vfloat32m8_t t = __riscv_vle32_v_f32m8(xY + p, vl); + vfloat32m8_t tmp = __riscv_vle32_v_f32m8(c12Y + p, vl); + t = __riscv_vfadd_vv_f32m8(t, tmp, vl); + vfloat32m8_t c22v = __riscv_vle32_v_f32m8(c22Y + p, vl); + + tmp = __riscv_vle32_v_f32m8(c11Y + p, vl); + tmp = __riscv_vfadd_vv_f32m8(tmp, c22v, vl); + tmp = __riscv_vfadd_vv_f32m8(tmp, t, vl); + __riscv_vse32_v_f32m8(c12Y + p, tmp, vl); + + tmp = __riscv_vle32_v_f32m8(c21Y + p, vl); + tmp = __riscv_vfadd_vv_f32m8(t, tmp, vl); + __riscv_vse32_v_f32m8(c21Y + p, tmp, vl); + + c22v = __riscv_vfadd_vv_f32m8(c22v, tmp, vl); + __riscv_vse32_v_f32m8(c22Y + p, c22v, vl); + + p += vl; + } + } +} From 0012df844b185e6124ba9dcc80d02bed88a28073 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:43:07 +0800 Subject: [PATCH 028/314] Merge pull request #4036 from ihb2032/opt/rvv-minmax-float opt(RVV): Optimize max and min float functions with intrinsics GitOrigin-RevId: cf83302a16083000f569672536d270edb597b0a5 --- source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp | 25 ++++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNMinFloat.cpp | 25 ++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNMinFloat.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp b/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp new file mode 100644 index 0000000000..183a38bb10 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp @@ -0,0 +1,25 @@ +#include +#include + +#define UNIT 4 + +void MNNMaxFloat(float *input, float *maxBuffer, int32_t inputCountUnit) { + const float init = -FLT_MAX; + for (int j = 0; j < UNIT; ++j) { + float local = init; + size_t i = 0; + + while (i < 
(size_t)inputCountUnit) { + size_t vl = __riscv_vsetvl_e32m8(inputCountUnit - i); + float *p0 = input + (i * UNIT * 2) + j * 2; + float *p1 = p0 + 1; + vfloat32m8_t v0 = __riscv_vlse32_v_f32m8(p0, UNIT * 2 * sizeof(float), vl); + vfloat32m8_t v1 = __riscv_vlse32_v_f32m8(p1, UNIT * 2 * sizeof(float), vl); + vfloat32m8_t vmax = __riscv_vfmax_vv_f32m8(v0, v1, vl); + vfloat32m1_t vred = __riscv_vfredmax_vs_f32m8_f32m1(vmax, __riscv_vfmv_s_f_f32m1(local, 1), vl); + local = __riscv_vfmv_f_s_f32m1_f32(vred); + i += vl; + } + maxBuffer[j] = local; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp b/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp new file mode 100644 index 0000000000..9e8ade8641 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp @@ -0,0 +1,25 @@ +#include +#include + +#define UNIT 4 + +void MNNMinFloat(float *input, float *minBuffer, int32_t inputCountUnit) { + const float init = FLT_MAX; + for (int j = 0; j < UNIT; ++j) { + float local = init; + size_t i = 0; + + while (i < (size_t)inputCountUnit) { + size_t vl = __riscv_vsetvl_e32m8(inputCountUnit - i); + float *p0 = input + (i * UNIT * 2) + j * 2; + float *p1 = p0 + 1; + vfloat32m8_t v0 = __riscv_vlse32_v_f32m8(p0, UNIT * 2 * sizeof(float), vl); + vfloat32m8_t v1 = __riscv_vlse32_v_f32m8(p1, UNIT * 2 * sizeof(float), vl); + vfloat32m8_t vmin = __riscv_vfmin_vv_f32m8(v0, v1, vl); + vfloat32m1_t vred = __riscv_vfredmin_vs_f32m8_f32m1(vmin, __riscv_vfmv_s_f_f32m1(local, 1), vl); + local = __riscv_vfmv_f_s_f32m1_f32(vred); + i += vl; + } + minBuffer[j] = local; + } +} From dc2e3336feacaf2c182e02f38b68156736b40ba2 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:43:38 +0800 Subject: [PATCH 029/314] Merge pull request #4026 from ihb2032/opt/rvv-math-stride-ops opt(RVV): Optimize core math and stride functions with intrinsics GitOrigin-RevId: 1b2d4bd5da63d4c4f3a2e457c8a91f3dd47ebb99 --- .../cpu/riscv/rvv/MNNAddC4WithStride.cpp | 29 +++++++++++ 
.../riscv/rvv/MNNAxByClampBroadcastUnit.cpp | 52 +++++++++++++++++++ .../cpu/riscv/rvv/MNNCopyC4WithStride.cpp | 22 ++++++++ .../cpu/riscv/rvv/MNNScaleAndAddBias.cpp | 42 +++++++++++++++ 4 files changed, 145 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp b/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp new file mode 100644 index 0000000000..59bb28a039 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp @@ -0,0 +1,29 @@ +#include + +void MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { + ptrdiff_t srcStrideByte = srcStride * sizeof(float); + ptrdiff_t dstStrideByte = dstStride * sizeof(float); + size_t vl; + + for (size_t i = count; i > 0; i -= vl) { + vl = __riscv_vsetvl_e32m8(i); + vfloat32m8_t vs = __riscv_vlse32_v_f32m8(source + 0, srcStrideByte, vl); + vfloat32m8_t vd = __riscv_vlse32_v_f32m8(dest + 0, dstStrideByte, vl); + vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); + __riscv_vsse32_v_f32m8(dest + 0, dstStrideByte, vd, vl); + vs = __riscv_vlse32_v_f32m8(source + 1, srcStrideByte, vl); + vd = __riscv_vlse32_v_f32m8(dest + 1, dstStrideByte, vl); + vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); + __riscv_vsse32_v_f32m8(dest + 1, dstStrideByte, vd, vl); + vs = __riscv_vlse32_v_f32m8(source + 2, srcStrideByte, vl); + vd = __riscv_vlse32_v_f32m8(dest + 2, dstStrideByte, vl); + vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); + __riscv_vsse32_v_f32m8(dest + 2, dstStrideByte, vd, vl); + vs = __riscv_vlse32_v_f32m8(source + 3, srcStrideByte, vl); + vd = __riscv_vlse32_v_f32m8(dest + 3, dstStrideByte, vl); + vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); + __riscv_vsse32_v_f32m8(dest + 3, 
dstStrideByte, vd, vl); + source += vl * srcStride; + dest += vl * dstStride; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp b/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp new file mode 100644 index 0000000000..6d966789f7 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp @@ -0,0 +1,52 @@ +#include + +void MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) { + float beta = parameters[1]; + float minF = parameters[2]; + float maxF = parameters[3]; + const ptrdiff_t stride = 4 * sizeof(float); + + for (int y = 0; y < height; ++y) { + auto a = A + aStride * y; + auto b = B + 4 * y; + auto c = C + cStride * y; + float b0Beta = b[0] * beta; + float b1Beta = b[1] * beta; + float b2Beta = b[2] * beta; + float b3Beta = b[3] * beta; + size_t w = width; + + while (w > 0) { + size_t vl = __riscv_vsetvl_e32m8(w); + + vfloat32m8_t data = __riscv_vlse32_v_f32m8(a + 0, stride, vl); + data = __riscv_vfadd_vf_f32m8(data, b0Beta, vl); + data = __riscv_vfmax_vf_f32m8(data, minF, vl); + data = __riscv_vfmin_vf_f32m8(data, maxF, vl); + __riscv_vsse32_v_f32m8(c + 0, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(a + 1, stride, vl); + data = __riscv_vfadd_vf_f32m8(data, b1Beta, vl); + data = __riscv_vfmax_vf_f32m8(data, minF, vl); + data = __riscv_vfmin_vf_f32m8(data, maxF, vl); + __riscv_vsse32_v_f32m8(c + 1, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(a + 2, stride, vl); + data = __riscv_vfadd_vf_f32m8(data, b2Beta, vl); + data = __riscv_vfmax_vf_f32m8(data, minF, vl); + data = __riscv_vfmin_vf_f32m8(data, maxF, vl); + __riscv_vsse32_v_f32m8(c + 2, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(a + 3, stride, vl); + data = __riscv_vfadd_vf_f32m8(data, b3Beta, vl); + data = __riscv_vfmax_vf_f32m8(data, minF, vl); + data = __riscv_vfmin_vf_f32m8(data, maxF, vl); + __riscv_vsse32_v_f32m8(c + 3, 
stride, data, vl); + + a += 4 * vl; + c += 4 * vl; + w -= vl; + } + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp b/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp new file mode 100644 index 0000000000..3d8c4f13fc --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp @@ -0,0 +1,22 @@ +#include + +void MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { + ptrdiff_t srcStrideByte = srcStride * sizeof(float); + ptrdiff_t dstStrideByte = dstStride * sizeof(float); +size_t vl; + + for (size_t i = count; i > 0; i -= vl) { + vl = __riscv_vsetvl_e32m8(i); + vfloat32m8_t data = __riscv_vlse32_v_f32m8(source + 0, srcStrideByte, vl); + __riscv_vsse32_v_f32m8(dest + 0, dstStrideByte, data, vl); + data = __riscv_vlse32_v_f32m8(source + 1, srcStrideByte, vl); + __riscv_vsse32_v_f32m8(dest + 1, dstStrideByte, data, vl); + data = __riscv_vlse32_v_f32m8(source + 2, srcStrideByte, vl); + __riscv_vsse32_v_f32m8(dest + 2, dstStrideByte, data, vl); + data = __riscv_vlse32_v_f32m8(source + 3, srcStrideByte, vl); + __riscv_vsse32_v_f32m8(dest + 3, dstStrideByte, data, vl); + source += vl * srcStride; + dest += vl * dstStride; + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp b/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp new file mode 100644 index 0000000000..10992f9d59 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp @@ -0,0 +1,42 @@ +#include + +void MNNScaleAndAddBias(float *dst, const float *src, const float *bias, const float *alpha, size_t planeNumber, size_t biasNumber) { + const ptrdiff_t stride = 4 * sizeof(float); + + for (size_t z = 0; z < biasNumber; ++z) { + float *dstZ = dst + z * planeNumber * 4; + const float *srcZ = src + z * planeNumber * 4; + const float *biasZ = bias + 4 * z; + const float *alphaZ = alpha + 4 * z; + float b0 = biasZ[0], b1 = biasZ[1], b2 = biasZ[2], b3 = biasZ[3]; + float a0 = alphaZ[0], a1 = 
alphaZ[1], a2 = alphaZ[2], a3 = alphaZ[3]; + + size_t n = planeNumber; + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t data = __riscv_vlse32_v_f32m8(srcZ + 0, stride, vl); + data = __riscv_vfmul_vf_f32m8(data, a0, vl); + data = __riscv_vfadd_vf_f32m8(data, b0, vl); + __riscv_vsse32_v_f32m8(dstZ + 0, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(srcZ + 1, stride, vl); + data = __riscv_vfmul_vf_f32m8(data, a1, vl); + data = __riscv_vfadd_vf_f32m8(data, b1, vl); + __riscv_vsse32_v_f32m8(dstZ + 1, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(srcZ + 2, stride, vl); + data = __riscv_vfmul_vf_f32m8(data, a2, vl); + data = __riscv_vfadd_vf_f32m8(data, b2, vl); + __riscv_vsse32_v_f32m8(dstZ + 2, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(srcZ + 3, stride, vl); + data = __riscv_vfmul_vf_f32m8(data, a3, vl); + data = __riscv_vfadd_vf_f32m8(data, b3, vl); + __riscv_vsse32_v_f32m8(dstZ + 3, stride, data, vl); + + srcZ += vl * 4; + dstZ += vl * 4; + n -= vl; + } + } +} From 89fdf322e5691243629022e6d47b0fa31bcdbe5f Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:43:52 +0800 Subject: [PATCH 030/314] Merge pull request #4023 from ihb2032/feature/rvv-transpose-functions opt(RVV): Optimize transpose functions with intrinsics GitOrigin-RevId: 24f98cc6e50fac1e178be5c3c425a3f622343cd0 --- .../cpu/riscv/rvv/MNNTranspose16Bit.cpp | 26 +++++++++++++++++++ .../cpu/riscv/rvv/MNNTranspose32Bit.cpp | 25 ++++++++++++++++++ 2 files changed, 51 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp b/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp new file mode 100644 index 0000000000..7598d6f8ac --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp @@ -0,0 +1,26 @@ +#include + +void MNNTranspose16Bit(int16_t* dstO, const int16_t* srcO, int16_t* 
dim) { + int w = dim[0]; + int h = dim[1]; + int srcStride = dim[2]; + int dstStride = dim[3]; + ptrdiff_t srcStrideByte = srcStride * sizeof(int16_t); + + for (int i = 0; i < h; ++i) { + const int16_t* srcPtr = srcO + i; + int16_t* dstPtr = dstO + i * dstStride; + + int j = 0; + while (j < w) { + size_t vl = __riscv_vsetvl_e16m8(w - j); + vint16m8_t data = __riscv_vlse16_v_i16m8(srcPtr, srcStrideByte, vl); + __riscv_vse16_v_i16m8(dstPtr, data, vl); + srcPtr += vl * srcStride; + dstPtr += vl; + j += vl; + } + } +} + + diff --git a/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp b/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp new file mode 100644 index 0000000000..e5c5eb83e6 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp @@ -0,0 +1,25 @@ +#include + +void MNNTranspose32Bit(int32_t* dstO, const int32_t* srcO, int32_t* dim) { + int w = dim[0]; + int h = dim[1]; + int srcStride = dim[2]; + int dstStride = dim[3]; + ptrdiff_t srcStrideByte = srcStride * sizeof(int32_t); + + for (int i = 0; i < h; ++i) { + const int32_t* srcPtr = srcO + i; + int32_t* dstPtr = dstO + i * dstStride; + + int j = 0; + while (j < w) { + size_t vl = __riscv_vsetvl_e32m8(w - j); + vint32m8_t data = __riscv_vlse32_v_i32m8(srcPtr, srcStrideByte, vl); + __riscv_vse32_v_i32m8(dstPtr, data, vl); + srcPtr += vl * srcStride; + dstPtr += vl; + j += vl; + } + } +} + From 3cd1ab680dc17d1d78e0529ff12e061f28715948 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:44:24 +0800 Subject: [PATCH 031/314] Merge pull request #4021 from ihb2032/feature/rvv-opt opt(RVV): Optimize pack and unpack functions with intrinsics GitOrigin-RevId: 58b54e86481db0588b59c679c0bade51e04b0d38 --- source/backend/cpu/riscv/rvv/MNNPackC2.cpp | 74 ++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNPackC4.cpp | 80 ++++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp | 55 ++++++++++++++ 3 files changed, 209 insertions(+) create mode 100644 
source/backend/cpu/riscv/rvv/MNNPackC2.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNPackC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNPackC2.cpp b/source/backend/cpu/riscv/rvv/MNNPackC2.cpp new file mode 100644 index 0000000000..9a74f8998d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNPackC2.cpp @@ -0,0 +1,74 @@ +#include + +void MNNPackC2(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { + int depthC2 = depth / 2; + int depthRemain = depthC2 * 2; + int remain = depth - depthRemain; + const float *srcOffset = src; + const float *srcChannel[2]; + + for (int z = 0; z < depthC2; ++z) { + float *dstZ = dst + z * areaOffset[1] * 2; + + for (int y = 0; y < 2; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 2; + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 0, 2 * sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 1, 2 * sizeof(float), vec, vl); + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 2; + dstPtr[0] = srcChannel[0][x]; + dstPtr[1] = srcChannel[1][x]; + } + + srcOffset += areaOffset[0] * 2; + } + + if (remain > 0) { + float *dstZ = dst + depthC2 * areaOffset[1] * 2; + + for (int y = 0; y < remain; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 2; + + for (int y = 0; y < remain; ++y) { + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + y, 2 * sizeof(float), vec, vl); + } + + vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); + for (int y = remain; y < 2; ++y) { + __riscv_vsse32_v_f32m8(dstPtr + y, 2 * 
sizeof(float), zero, vl); + } + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 2; + + for (int y = 0; y < remain; ++y) { + dstPtr[y] = srcChannel[y][x]; + } + + for (int y = remain; y < 2; ++y) { + dstPtr[y] = 0.0f; + } + } + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNPackC4.cpp b/source/backend/cpu/riscv/rvv/MNNPackC4.cpp new file mode 100644 index 0000000000..024e2c8c07 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNPackC4.cpp @@ -0,0 +1,80 @@ +#include + +void MNNPackC4(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { + int depthC4 = depth / 4; + int depthRemain = depthC4 * 4; + int remain = depth - depthRemain; + const float *srcOffset = src; + const float *srcChannel[4]; + + for (int z = 0; z < depthC4; ++z) { + float *dstZ = dst + z * areaOffset[1] * 4; + + for (int y = 0; y < 4; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 4; + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 0, 4 * sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 1, 4 * sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[2] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 2, 4 * sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[3] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 3, 4 * sizeof(float), vec, vl); + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 4; + dstPtr[0] = srcChannel[0][x]; + dstPtr[1] = srcChannel[1][x]; + dstPtr[2] = srcChannel[2][x]; + dstPtr[3] = srcChannel[3][x]; + } + + srcOffset += areaOffset[0] * 4; + } + + if (remain > 0) { + float *dstZ = dst + depthC4 * areaOffset[1] * 4; + + for (int y = 0; y < remain; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + 
for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 4; + + for (int y = 0; y < remain; ++y) { + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), vec, vl); + } + + vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); + for (int y = remain; y < 4; ++y) { + __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), zero, vl); + } + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 4; + + for (int y = 0; y < remain; ++y) { + dstPtr[y] = srcChannel[y][x]; + } + + for (int y = remain; y < 4; ++y) { + dstPtr[y] = 0.0f; + } + } + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp b/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp new file mode 100644 index 0000000000..4676e6dede --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp @@ -0,0 +1,55 @@ +#include + +void MNNUnpackC4(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { + int depthC4 = depth / 4; + int depthRemain = depthC4 * 4; + int remain = depth - depthRemain; + const float *srcOffset = src; + + for (int z = 0; z < depthC4; ++z) { + float *dstZ[4]; + + for (int y = 0; y < 4; ++y) { + dstZ[y] = dst + (z * 4 + y) * areaOffset[1]; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + vfloat32m8_t vec = __riscv_vlse32_v_f32m8(srcOffset + 0, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[0] + x, vec, vl); + vec = __riscv_vlse32_v_f32m8(srcOffset + 1, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[1] + x, vec, vl); + vec = __riscv_vlse32_v_f32m8(srcOffset + 2, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[2] + x, vec, vl); + vec = __riscv_vlse32_v_f32m8(srcOffset + 3, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[3] + x, vec, vl); + srcOffset += 4 * vl; + } + + for (; x < area; ++x) { + dstZ[0][x] = srcOffset[0]; + dstZ[1][x] = srcOffset[1]; + dstZ[2][x] = srcOffset[2]; + dstZ[3][x] = srcOffset[3]; + srcOffset 
+= (areaOffset[0] - area) * 4; + } + } + + if (remain > 0) { + float *dstZ = dst + depthC4 * areaOffset[1] * 4; + const float *srcBase = srcOffset; + + for (int y = 0; y < remain; ++y) { + float *dstChannel = dstZ + y * areaOffset[1]; + const float *srcChannel = srcBase + y; + + for (size_t x = 0; x < area; ++x) { + dstChannel[x] = srcChannel[0]; + srcChannel += 4; + } + } + } +} + From 893a8ce33f5b45d58a6e084b4b9a4169f0a3f880 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:54:53 +0800 Subject: [PATCH 032/314] Merge pull request #4061 from zlaazlaa/fix_diffusion fix(diffusion): simplify export logic and fix dynamic axes GitOrigin-RevId: 4c1cd8ed04606b5302cf9807a42bcc034ebf7c1b --- docs/transformers/diffusion.md | 3 +- transformers/diffusion/export/onnx_export.py | 30 ++++++-------------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/docs/transformers/diffusion.md b/docs/transformers/diffusion.md index 7de27bb216..609793f806 100644 --- a/docs/transformers/diffusion.md +++ b/docs/transformers/diffusion.md @@ -20,7 +20,8 @@ https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1/tree/mai cd mnn_path/transformers/diffusion/export python onnx_export.py \ --model_path hf_sd_load_path \ - --output_path onnx_save_path + --output_path onnx_save_path \ + --opset 18 ``` 注意,上述脚本需要依赖torch/onnx/diffusers等库,可以安装conda环境: ``` diff --git a/transformers/diffusion/export/onnx_export.py b/transformers/diffusion/export/onnx_export.py index 21f05e83be..5516eb2fcc 100644 --- a/transformers/diffusion/export/onnx_export.py +++ b/transformers/diffusion/export/onnx_export.py @@ -84,7 +84,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F num_tokens = pipeline.text_encoder.config.max_position_embeddings text_hidden_size = pipeline.text_encoder.config.hidden_size text_input = pipeline.tokenizer( - "A sample prompt", + ["A sample prompt", "A sample prompt"], padding="max_length", 
max_length=pipeline.tokenizer.model_max_length, truncation=True, @@ -97,9 +97,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path=output_path / "text_encoder" / "model.onnx", ordered_input_names=["input_ids"], output_names=["last_hidden_state", "pooler_output"], - dynamic_axes={ - "input_ids": {0: "batch", 1: "sequence"}, - }, + dynamic_axes=None, opset=opset, ) del pipeline.text_encoder @@ -117,13 +115,9 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F # False, ), output_path=unet_path, - ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], + ordered_input_names=["sample", "timestep", "encoder_hidden_states"], output_names=["out_sample"], # has to be different from "sample" for correct tracing - dynamic_axes={ - "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - "timestep": {0: "batch"}, - "encoder_hidden_states": {0: "batch", 1: "sequence"}, - }, + dynamic_axes=None, opset=opset, use_external_data_format=True, # UNet is > 2GB, so the weights need to be split ) @@ -149,7 +143,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F vae_in_channels = vae_encoder.config.in_channels vae_sample_size = vae_encoder.config.sample_size # need to get the raw tensor output (sample) from the encoder - vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample() + vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].mode() onnx_export( vae_encoder, model_args=( @@ -159,30 +153,24 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path=output_path / "vae_encoder" / "model.onnx", ordered_input_names=["sample", "return_dict"], output_names=["latent_sample"], - dynamic_axes={ - "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - }, + dynamic_axes=None, opset=opset, ) # VAE DECODER vae_decoder = 
pipeline.vae vae_latent_channels = vae_decoder.config.latent_channels - vae_out_channels = vae_decoder.config.out_channels # forward only through the decoder part - vae_decoder.forward = vae_encoder.decode + vae_decoder.forward = lambda latent: vae_decoder.decode(latent, return_dict=False)[0] onnx_export( vae_decoder, model_args=( torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), - False, ), output_path=output_path / "vae_decoder" / "model.onnx", - ordered_input_names=["latent_sample", "return_dict"], + ordered_input_names=["latent_sample"], output_names=["sample"], - dynamic_axes={ - "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - }, + dynamic_axes=None, opset=opset, ) del pipeline.vae From 198bdfcfbc8779d453737d0f348cc184c5c35814 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 11:04:03 +0800 Subject: [PATCH 033/314] Merge pull request #3998 from bolun365/bolun365-patch-1 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit mnn lib库自动化build脚本 GitOrigin-RevId: ac1e2a9fd51ff3a9102660cda0d0731dfd849f95 --- build_lib.sh | 807 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 807 insertions(+) create mode 100644 build_lib.sh diff --git a/build_lib.sh b/build_lib.sh new file mode 100644 index 0000000000..c839b6e7b6 --- /dev/null +++ b/build_lib.sh @@ -0,0 +1,807 @@ +#!/bin/bash + +# MNN 统一构建脚本 +# 支持构建 Android、iOS、鸿蒙和 Python 版本的 MNN + +set -e + +# 颜色输出 +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# 默认配置 +BUILD_ANDROID=false +BUILD_IOS=false +BUILD_PYTHON=false +BUILD_IOS_SIMULATOR=false +BUILD_HARMONY=false +ANDROID_NDK="" +HARMONY_HOME="" +OUTPUT_DIR="mnn_builds" +CLEAN_BUILD=true +VERSION="" +PYTHON_CMD="python3" +DEPS_OPTIONS="" + +# 获取项目根目录 +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$SCRIPT_DIR" + +print_usage() { + echo "MNN 
统一构建脚本" + echo "" + echo "用法: $0 [选项]" + echo "" + echo "构建目标:" + echo " --android 构建 Android 版本 (需要 --ndk)" + echo " --ios 构建 iOS 真机版本 (arm64)" + echo " --ios-simulator 构建 iOS 模拟器版本 (x86_64 + arm64)" + echo " --harmony 构建鸿蒙版本 (需要 --harmony-home 或设置 HARMONY_HOME)" + echo " --python 构建 Python 版本" + echo "" + echo "Android 选项:" + echo " --ndk PATH Android NDK 路径 (例如: ~/Library/Android/sdk/ndk/29.0.13599879)" + echo "" + echo "鸿蒙选项:" + echo " --harmony-home PATH 鸿蒙工具链路径 (例如: ~/Library/OpenHarmony/Sdk/native)" + echo " 如果未指定,将优先查找 ~/Library/OpenHarmony/Sdk/native" + echo "" + echo "Python 选项:" + echo " --python-deps OPTIONS Python 依赖选项,多个用逗号分隔" + echo " 可用选项: llm,opencl,cuda,torch,render,vulkan,internal,no_sse,openmp" + echo " --python-cmd CMD Python 命令 (默认: python3)" + echo "" + echo "通用选项:" + echo " -o, --output DIR 输出目录 (默认: mnn_builds)" + echo " -v, --version VERSION 版本号 (默认: 自动从源码读取)" + echo " --no-clean 不清理之前的构建目录" + echo " -h, --help 显示帮助信息" + echo "" + echo "示例:" + echo " # 构建所有平台" + echo " $0 --android --ios --harmony --python --ndk ~/Library/Android/sdk/ndk/29.0.13599879" + echo "" + echo " # 仅构建 Android" + echo " $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879" + echo "" + echo " # 构建鸿蒙版本" + echo " $0 --harmony --harmony-home ~/Library/OpenHarmony/Sdk/native" + echo "" + echo " # 构建 iOS 真机和模拟器" + echo " $0 --ios --ios-simulator" + echo "" + echo " # 构建 Python (带 LLM 支持)" + echo " $0 --python --python-deps llm,opencl" +} + +# 解析命令行参数 +while [[ $# -gt 0 ]]; do + case $1 in + --android) + BUILD_ANDROID=true + shift + ;; + --ios) + BUILD_IOS=true + shift + ;; + --ios-simulator) + BUILD_IOS_SIMULATOR=true + shift + ;; + --python) + BUILD_PYTHON=true + shift + ;; + --harmony) + BUILD_HARMONY=true + shift + ;; + --ndk) + ANDROID_NDK="$2" + shift 2 + ;; + --harmony-home) + HARMONY_HOME="$2" + shift 2 + ;; + --python-deps) + DEPS_OPTIONS="$2" + shift 2 + ;; + --python-cmd) + PYTHON_CMD="$2" + shift 2 + ;; + -o|--output) + OUTPUT_DIR="$2" + shift 2 + ;; + 
-v|--version) + VERSION="$2" + shift 2 + ;; + --no-clean) + CLEAN_BUILD=false + shift + ;; + -h|--help) + print_usage + exit 0 + ;; + *) + echo -e "${RED}错误: 未知选项 $1${NC}" + print_usage + exit 1 + ;; + esac +done + +# 检查是否至少选择了一个构建目标 +if [ "$BUILD_ANDROID" = false ] && [ "$BUILD_IOS" = false ] && [ "$BUILD_IOS_SIMULATOR" = false ] && [ "$BUILD_HARMONY" = false ] && [ "$BUILD_PYTHON" = false ]; then + echo -e "${RED}错误: 请至少选择一个构建目标 (--android, --ios, --ios-simulator, --harmony, 或 --python)${NC}" + print_usage + exit 1 +fi + +# 检查 Android NDK +if [ "$BUILD_ANDROID" = true ]; then + if [ -z "$ANDROID_NDK" ]; then + echo -e "${RED}错误: 构建 Android 必须指定 NDK 路径 (使用 --ndk)${NC}" + echo -e "${RED}示例: $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879${NC}" + exit 1 + fi + + # 展开路径中的 ~ + ANDROID_NDK="${ANDROID_NDK/#\~/$HOME}" + + if [ ! -d "$ANDROID_NDK" ]; then + echo -e "${RED}错误: NDK 路径不存在: $ANDROID_NDK${NC}" + exit 1 + fi +fi + +# 查找鸿蒙工具链的函数 +find_harmony_toolchain() { + # 如果环境变量已设置,优先使用 + if [ -n "$HARMONY_HOME" ]; then + # 支持两种路径格式:直接指定或带 native 子目录 + if [ -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then + echo "$HARMONY_HOME" + return 0 + elif [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then + echo "$HARMONY_HOME" + return 0 + fi + fi + + # 优先查找 ~/Library/OpenHarmony/Sdk,支持版本号目录 + local sdk_base="$HOME/Library/OpenHarmony/Sdk" + if [ -d "$sdk_base" ]; then + # 查找所有版本号目录下的 native 或 toolchains + # 按版本号倒序排列,优先使用最新版本 + for version_dir in $(ls -d "$sdk_base"/* 2>/dev/null | sort -Vr); do + if [ -d "$version_dir" ]; then + # 尝试 native/build/cmake/ohos.toolchain.cmake + if [ -f "$version_dir/native/build/cmake/ohos.toolchain.cmake" ]; then + echo "$version_dir/native" + return 0 + fi + # 尝试 build/cmake/ohos.toolchain.cmake + if [ -f "$version_dir/build/cmake/ohos.toolchain.cmake" ]; then + echo "$version_dir" + return 0 + fi + fi + done + fi + + # 其他可能的路径 + local possible_paths=( + "$HOME/Library/OpenHarmony/Sdk/native" + 
"$HOME/HarmonyOS/Sdk/native" + "$HOME/.ohos/native" + "/opt/HarmonyOS/Sdk/native" + "/usr/local/HarmonyOS/Sdk/native" + "$HOME/Library/HarmonyOS/Sdk/native" + ) + + # 尝试查找 + for path in "${possible_paths[@]}"; do + if [ -n "$path" ] && [ -f "$path/build/cmake/ohos.toolchain.cmake" ]; then + echo "$path" + return 0 + fi + done + + # 限制搜索范围,只在 OpenHarmony/Sdk 目录下快速查找 + # 避免在整个 OpenHarmony 目录下递归查找,这可能会很慢 + local found=$(find "$HOME/Library/OpenHarmony/Sdk" -maxdepth 4 -type f -name "ohos.toolchain.cmake" 2>/dev/null | head -1) + if [ -n "$found" ]; then + # 从 ohos.toolchain.cmake 向上查找 native 目录或 SDK 根目录 + found=$(dirname "$found") + if [ "$(basename "$found")" = "cmake" ]; then + found=$(dirname "$found") + if [ "$(basename "$found")" = "build" ]; then + found=$(dirname "$found") + fi + fi + echo "$found" + return 0 + fi + + return 1 +} + +# 检查鸿蒙工具链 +if [ "$BUILD_HARMONY" = true ]; then + # 展开路径中的 ~ + if [ -n "$HARMONY_HOME" ]; then + HARMONY_HOME="${HARMONY_HOME/#\~/$HOME}" + fi + + # 尝试查找工具链 + if [ -z "$HARMONY_HOME" ] || [ ! -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then + # 检查是否在 native 子目录下 + if [ -n "$HARMONY_HOME" ] && [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then + HARMONY_HOME="$HARMONY_HOME/native" + else + echo -e "${YELLOW}正在查找鸿蒙工具链...${NC}" + HARMONY_HOME=$(find_harmony_toolchain) + + if [ -z "$HARMONY_HOME" ]; then + echo -e "${RED}错误: 找不到鸿蒙工具链${NC}" + echo -e "${RED}默认查找路径: ~/Library/OpenHarmony/Sdk/*/native${NC}" + echo -e "${RED}请使用 --harmony-home 指定路径,或设置 HARMONY_HOME 环境变量${NC}" + echo -e "${RED}工具链文件应位于: /build/cmake/ohos.toolchain.cmake 或 /native/build/cmake/ohos.toolchain.cmake${NC}" + exit 1 + fi + + # 如果找到的是带 native 的路径,需要调整 + if [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then + HARMONY_HOME="$HARMONY_HOME/native" + fi + fi + + echo -e "${GREEN}找到鸿蒙工具链: $HARMONY_HOME${NC}" + fi + + # 验证工具链文件存在 + if [ ! 
-f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then + echo -e "${RED}错误: 鸿蒙工具链文件不存在: $HARMONY_HOME/build/cmake/ohos.toolchain.cmake${NC}" + exit 1 + fi +fi + +echo -e "${GREEN}========================================${NC}" +echo -e "${GREEN}MNN 统一构建脚本${NC}" +echo -e "${GREEN}========================================${NC}" +echo "项目根目录: $PROJECT_ROOT" +echo "输出目录: $OUTPUT_DIR" +echo "" +echo -e "${BLUE}构建目标:${NC}" +[ "$BUILD_ANDROID" = true ] && echo " ✓ Android" +[ "$BUILD_IOS" = true ] && echo " ✓ iOS (真机 arm64)" +[ "$BUILD_IOS_SIMULATOR" = true ] && echo " ✓ iOS (模拟器 x86_64 + arm64)" +[ "$BUILD_HARMONY" = true ] && echo " ✓ 鸿蒙 (arm64-v8a)" +[ "$BUILD_PYTHON" = true ] && echo " ✓ Python" +echo "" + +cd "$PROJECT_ROOT" +mkdir -p "$OUTPUT_DIR" + +# ============================================================================ +# 构建 Android 版本 +# ============================================================================ +if [ "$BUILD_ANDROID" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 Android 版本${NC}" + echo -e "${GREEN}========================================${NC}" + + export ANDROID_NDK + + ANDROID_BUILD_DIR="project/android" + ANDROID_OUTPUT_DIR="$OUTPUT_DIR/android" + + cd "$ANDROID_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的 Android 构建...${NC}" + rm -rf build_32 build_64 export + fi + + # 在项目根目录创建输出目录 + mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" + + # 构建 armeabi-v7a + echo -e "${BLUE}构建 armeabi-v7a...${NC}" + mkdir -p build_32 + cd build_32 + cmake ../../../ \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DANDROID_ABI="armeabi-v7a" \ + -DANDROID_STL=c++_static \ + -DANDROID_NATIVE_API_LEVEL=android-14 \ + -DANDROID_TOOLCHAIN=clang \ + -DMNN_USE_LOGCAT=false \ + -DMNN_USE_SSE=OFF \ + -DMNN_BUILD_TEST=ON \ + -DMNN_ARM82=OFF \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + 
-DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DMNN_SEP_BUILD=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ + -DNATIVE_LIBRARY_OUTPUT=. \ + -DNATIVE_INCLUDE_OUTPUT=. \ + > /dev/null + + make -j4 MNN > /dev/null + cd .. + + # 构建 arm64-v8a + echo -e "${BLUE}构建 arm64-v8a...${NC}" + mkdir -p build_64 + cd build_64 + cmake ../../../ \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DANDROID_ABI="arm64-v8a" \ + -DANDROID_STL=c++_static \ + -DANDROID_NATIVE_API_LEVEL=android-21 \ + -DANDROID_TOOLCHAIN=clang \ + -DMNN_USE_LOGCAT=false \ + -DMNN_BUILD_BENCHMARK=ON \ + -DMNN_USE_SSE=OFF \ + -DMNN_BUILD_TEST=ON \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DMNN_SEP_BUILD=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ + -DNATIVE_LIBRARY_OUTPUT=. \ + -DNATIVE_INCLUDE_OUTPUT=. \ + > /dev/null + + make -j4 MNN > /dev/null + cd .. + + # 导出文件 + echo -e "${BLUE}导出 Android 库文件...${NC}" + mkdir -p export/android/{armeabi-v7a,arm64-v8a}/libs export/android/include/MNN + + cp build_32/*.so export/android/armeabi-v7a/libs/ 2>/dev/null || true + cp build_64/*.so export/android/arm64-v8a/libs/ 2>/dev/null || true + cp -r ../../include/MNN/* export/android/include/MNN/ + + # 复制到统一输出目录 + # 如果目标路径存在但不是目录,先删除 + if [ -e "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ] && [ ! 
-d "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ]; then + rm -f "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" + fi + mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" + cp -r export/android/* "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR/" + + echo -e "${GREEN}Android 构建完成!${NC}" + echo "输出位置: $ANDROID_OUTPUT_DIR" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建 iOS 真机版本 +# ============================================================================ +if [ "$BUILD_IOS" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 iOS 真机版本${NC}" + echo -e "${GREEN}========================================${NC}" + + # 检查是否在 macOS 上 + if [[ "$OSTYPE" != "darwin"* ]]; then + echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" + exit 1 + fi + + # 检查 Xcode + if ! command -v xcodebuild &> /dev/null; then + echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" + exit 1 + fi + + IOS_BUILD_DIR="project/ios" + IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/device" + + cd "$IOS_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的 iOS 真机构建...${NC}" + rm -rf MNN-iOS-CPU-GPU/Static/ios_64 + find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_64/*" 2>/dev/null | xargs rm -f 2>/dev/null || true + fi + + mkdir -p MNN-iOS-CPU-GPU/Static + cd MNN-iOS-CPU-GPU/Static + + # 构建真机版本 (arm64) + echo -e "${BLUE}构建 iOS 真机版本 (arm64)...${NC}" + rm -rf ios_64 + mkdir ios_64 + cd ios_64 + cmake "$PROJECT_ROOT" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ + -DENABLE_BITCODE=0 \ + -DMNN_AAPL_FMWK=1 \ + -DMNN_SEP_BUILD=OFF \ + -DMNN_BUILD_SHARED_LIBS=false \ + -DMNN_USE_THREAD_POOL=OFF \ + -DPLATFORM=OS64 \ + -DARCHS="arm64" \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_METAL=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + 
-DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + > /dev/null + make MNN -j8 > /dev/null + cd .. + + # 输出真机版本 + mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" + rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" + cp -R ios_64/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" + + # 清理 + rm -rf ios_64 + + echo -e "${GREEN}iOS 真机构建完成!${NC}" + echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建 iOS 模拟器版本 +# ============================================================================ +if [ "$BUILD_IOS_SIMULATOR" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 iOS 模拟器版本${NC}" + echo -e "${GREEN}========================================${NC}" + + # 检查是否在 macOS 上 + if [[ "$OSTYPE" != "darwin"* ]]; then + echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" + exit 1 + fi + + # 检查 Xcode + if ! command -v xcodebuild &> /dev/null; then + echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" + exit 1 + fi + + IOS_BUILD_DIR="project/ios" + IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/simulator" + + cd "$IOS_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的 iOS 模拟器构建...${NC}" + rm -rf MNN-iOS-CPU-GPU/Static/ios_simulator* + find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_simulator*/*" 2>/dev/null | xargs rm -f 2>/dev/null || true + fi + + mkdir -p MNN-iOS-CPU-GPU/Static + cd MNN-iOS-CPU-GPU/Static + + # 构建 x86_64 模拟器版本(尝试构建,失败则跳过) + BUILD_X86_64=false + echo -e "${BLUE}构建 iOS 模拟器版本 (x86_64)...${NC}" + rm -rf ios_simulator_x86 + mkdir ios_simulator_x86 + cd ios_simulator_x86 + if cmake "$PROJECT_ROOT" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ + -DENABLE_BITCODE=0 \ + -DMNN_AAPL_FMWK=1 \ + -DMNN_SEP_BUILD=OFF \ + -DMNN_BUILD_SHARED_LIBS=false \ + -DMNN_USE_THREAD_POOL=OFF \ + -DPLATFORM=SIMULATOR64 \ + 
-DARCHS="x86_64" \ + -DMNN_ARM82=OFF \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_METAL=OFF \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + && make MNN -j8; then + if [ -f "MNN.framework/MNN" ]; then + BUILD_X86_64=true + echo -e "${GREEN}x86_64 模拟器构建成功${NC}" + cd .. + else + echo -e "${YELLOW}警告: x86_64 模拟器构建产物不存在,跳过此架构${NC}" + cd .. + rm -rf ios_simulator_x86 + fi + else + echo -e "${YELLOW}警告: x86_64 模拟器构建失败,跳过此架构(这在 Apple Silicon Mac 上是正常的)${NC}" + cd .. + rm -rf ios_simulator_x86 + fi + + # 构建 arm64 模拟器版本 + echo -e "${BLUE}构建 iOS 模拟器版本 (arm64)...${NC}" + rm -rf ios_simulator_arm64 + mkdir ios_simulator_arm64 + cd ios_simulator_arm64 + if ! cmake "$PROJECT_ROOT" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ + -DENABLE_BITCODE=0 \ + -DMNN_AAPL_FMWK=1 \ + -DMNN_SEP_BUILD=OFF \ + -DMNN_BUILD_SHARED_LIBS=false \ + -DMNN_USE_THREAD_POOL=OFF \ + -DPLATFORM=SIMULATOR64 \ + -DARCHS="arm64" \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_METAL=OFF \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON; then + echo -e "${RED}错误: arm64 模拟器 cmake 配置失败${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + if ! make MNN -j8; then + echo -e "${RED}错误: arm64 模拟器编译失败${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + cd .. + + # 验证构建产物 + if [ ! 
-f "ios_simulator_arm64/MNN.framework/MNN" ]; then + echo -e "${RED}错误: 未找到 arm64 模拟器框架文件${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + + # 合并模拟器架构 + echo -e "${BLUE}合并模拟器架构...${NC}" + rm -rf ios_simulator + mkdir ios_simulator + + if [ "$BUILD_X86_64" = true ] && [ -f "ios_simulator_x86/MNN.framework/MNN" ]; then + # 合并 x86_64 + arm64 + echo -e "${BLUE}合并 x86_64 + arm64 架构...${NC}" + cp -R ios_simulator_x86/MNN.framework ios_simulator/MNN.framework + mv ios_simulator/MNN.framework/MNN ios_simulator/MNN.framework/MNN_x86 + if ! lipo -create ios_simulator/MNN.framework/MNN_x86 ios_simulator_arm64/MNN.framework/MNN -output ios_simulator/MNN.framework/MNN; then + echo -e "${RED}错误: 合并架构失败${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + rm ios_simulator/MNN.framework/MNN_x86 + else + # 仅使用 arm64 + echo -e "${BLUE}仅使用 arm64 架构(x86_64 不可用)...${NC}" + cp -R ios_simulator_arm64/MNN.framework ios_simulator/MNN.framework + fi + + # 输出模拟器版本 + mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" + rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" + cp -R ios_simulator/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" + + # 清理临时目录 + rm -rf ios_simulator ios_simulator_x86 ios_simulator_arm64 + + echo -e "${GREEN}iOS 模拟器构建完成!${NC}" + echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建鸿蒙版本 +# ============================================================================ +if [ "$BUILD_HARMONY" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建鸿蒙版本${NC}" + echo -e "${GREEN}========================================${NC}" + + export HARMONY_HOME + + HARMONY_BUILD_DIR="project/harmony" + HARMONY_OUTPUT_DIR="$OUTPUT_DIR/harmony" + + cd "$HARMONY_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的鸿蒙构建...${NC}" + rm -rf build_64 export + fi + + # 在项目根目录创建输出目录 + mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" + + # 构建 arm64-v8a + 
echo -e "${BLUE}构建 arm64-v8a...${NC}" + mkdir -p build_64 + cd build_64 + + cmake "$PROJECT_ROOT" \ + -DCMAKE_TOOLCHAIN_FILE=$HARMONY_HOME/build/cmake/ohos.toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DOHOS_ARCH="arm64-v8a" \ + -DOHOS_STL=c++_static \ + -DMNN_USE_LOGCAT=true \ + -DMNN_BUILD_BENCHMARK=ON \ + -DMNN_USE_SSE=OFF \ + -DMNN_SUPPORT_BF16=OFF \ + -DMNN_BUILD_TEST=ON \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DMNN_SEP_BUILD=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + -DOHOS_PLATFORM_LEVEL=9 \ + -DNATIVE_LIBRARY_OUTPUT=. \ + -DNATIVE_INCLUDE_OUTPUT=. \ + > /dev/null + + make -j4 MNN > /dev/null + cd .. + + # 导出文件 + echo -e "${BLUE}导出鸿蒙库文件...${NC}" + mkdir -p export/harmony/arm64-v8a/libs export/harmony/include/MNN + + cp build_64/*.so export/harmony/arm64-v8a/libs/ 2>/dev/null || true + cp -r ../../include/MNN/* export/harmony/include/MNN/ + + # 复制到统一输出目录 + # 如果目标路径存在但不是目录,先删除 + if [ -e "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ] && [ ! -d "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ]; then + rm -f "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" + fi + mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" + cp -r export/harmony/* "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR/" + + echo -e "${GREEN}鸿蒙构建完成!${NC}" + echo "输出位置: $HARMONY_OUTPUT_DIR" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建 Python 版本 +# ============================================================================ +if [ "$BUILD_PYTHON" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 Python 版本${NC}" + echo -e "${GREEN}========================================${NC}" + + # 检查 Python + if ! 
command -v $PYTHON_CMD &> /dev/null; then + echo -e "${RED}错误: 找不到 Python 命令 '$PYTHON_CMD'${NC}" + exit 1 + fi + + PYTHON_OUTPUT_DIR="$OUTPUT_DIR/python" + + # 使用之前创建的 build_pymnn.sh 脚本 + if [ -f "$PROJECT_ROOT/build_pymnn.sh" ]; then + PYTHON_BUILD_ARGS="-o $PYTHON_OUTPUT_DIR" + [ -n "$VERSION" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -v $VERSION" + [ -n "$DEPS_OPTIONS" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -d $DEPS_OPTIONS" + [ "$CLEAN_BUILD" = false ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --no-clean" + PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --python $PYTHON_CMD" + + bash "$PROJECT_ROOT/build_pymnn.sh" $PYTHON_BUILD_ARGS + else + echo -e "${YELLOW}警告: 未找到 build_pymnn.sh,使用基本构建方式...${NC}" + + cd pymnn/pip_package + + # 构建依赖 + if [ -n "$DEPS_OPTIONS" ]; then + $PYTHON_CMD build_deps.py $DEPS_OPTIONS + else + $PYTHON_CMD build_deps.py + fi + + # 构建 wheel + $PYTHON_CMD -m pip install -U numpy wheel setuptools --quiet + rm -rf build dist + + BUILD_ARGS="" + [ -n "$VERSION" ] && BUILD_ARGS="--version $VERSION" + [ -n "$DEPS_OPTIONS" ] && BUILD_ARGS="$BUILD_ARGS --deps $DEPS_OPTIONS" + + $PYTHON_CMD setup.py bdist_wheel $BUILD_ARGS + + mkdir -p "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR" + cp dist/*.whl "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR/" + + cd "$PROJECT_ROOT" + fi + + echo -e "${GREEN}Python 构建完成!${NC}" + echo "输出位置: $PYTHON_OUTPUT_DIR" +fi + +# ============================================================================ +# 总结 +# ============================================================================ +echo "" +echo -e "${GREEN}========================================${NC}" +echo -e "${GREEN}所有构建完成!${NC}" +echo -e "${GREEN}========================================${NC}" +echo "" +echo "输出目录: $PROJECT_ROOT/$OUTPUT_DIR" +echo "" +[ "$BUILD_ANDROID" = true ] && echo -e "${GREEN}Android:${NC} $OUTPUT_DIR/android" +[ "$BUILD_IOS" = true ] && echo -e "${GREEN}iOS (真机):${NC} $OUTPUT_DIR/ios/device" +[ "$BUILD_IOS_SIMULATOR" = true ] && echo -e "${GREEN}iOS (模拟器):${NC} 
$OUTPUT_DIR/ios/simulator" +[ "$BUILD_HARMONY" = true ] && echo -e "${GREEN}鸿蒙:${NC} $OUTPUT_DIR/harmony" +[ "$BUILD_PYTHON" = true ] && echo -e "${GREEN}Python:${NC} $OUTPUT_DIR/python" +echo "" + + From b1ae09cc18b58b3d8067da5c2fbb6ad4b4593a1b Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 11:04:54 +0800 Subject: [PATCH 034/314] Merge pull request #4009 from HenryDen/default_opt Add a compile option and macro to default enable kleidiAI GitOrigin-RevId: a3bc314f99ee49f10608550b92d4b37e0ca2d8f0 --- CMakeLists.txt | 1 + source/backend/cpu/arm/CMakeLists.txt | 3 +++ source/core/Backend.hpp | 6 ++++-- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 67502b606b..f99e37ec1c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -258,6 +258,7 @@ option(MNN_VULKAN "Enable Vulkan" OFF) option(MNN_ARM82 "Enable ARMv8.2's FP16 Compute" ON) option(MNN_SUPPORT_FP16_ARMV7 "Enable ARMv8.2's FP16 Compute for armv7 arch, may cause library not valid for 32 bit cpu" OFF) option(MNN_KLEIDIAI "Enable KLEIDIAI" ON) +option(MNN_KLEIDIAI_DEFAULT_ON "Use KLEIDIAI kernels by default" OFF) option(MNN_ONEDNN "Enable oneDNN" OFF) option(MNN_AVX2 "Open AVX2 Compile for x86 if possible" ON) option(MNN_AVX512 "Enable AVX512" OFF) diff --git a/source/backend/cpu/arm/CMakeLists.txt b/source/backend/cpu/arm/CMakeLists.txt index 18fca54a4e..61ebce6bdc 100644 --- a/source/backend/cpu/arm/CMakeLists.txt +++ b/source/backend/cpu/arm/CMakeLists.txt @@ -36,6 +36,9 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR AR if (MNN_KLEIDIAI) include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/KleidiAI.cmake) download_kleidiai_and_collect_sources() + if(MNN_KLEIDIAI_DEFAULT_ON) + add_definitions(-DMNN_DEFAULT_USE_KLEIDIAI) + endif() endif() if (MNN_SME2) diff --git a/source/core/Backend.hpp b/source/core/Backend.hpp index bcf618c3c9..6850b6b4f6 100644 --- a/source/core/Backend.hpp +++ b/source/core/Backend.hpp @@ -68,9 +68,11 
@@ struct RuntimeHint { // whether to use Arm sme2 cores when threads>1 bool useArmSme2Cores = true; - +#ifdef MNN_DEFAULT_USE_KLEIDIAI + bool enableKleidiAI = true; +#else bool enableKleidiAI = false; - +#endif // Use CPU Ids std::vector cpuIds; From adcd2cf3bb329032895217b2cef8c6beb2a65496 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 11:42:22 +0800 Subject: [PATCH 035/314] Merge branch feature/add_4th_groupchat into master Title: [Doc:Update] update dingtalk in README. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 本次代码评审的主要改动是对README文件中的钉钉群信息进行了更新,包括群号、状态以及删除了一些过时的信息。 Link: https://code.alibaba-inc.com/AliNN/AliNNPrivate/codereview/25029869 GitOrigin-RevId: 3e482c2332f0a4f4088ff8bdf75048eb51177330 --- README.md | 14 +++++++------- README_CN.md | 10 ++++------ README_JP.md | 9 +++++---- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 5fe168ed05..7959890c16 100644 --- a/README.md +++ b/README.md @@ -6,13 +6,13 @@ [![日本語バージョン](https://img.shields.io/badge/Language-%E6%97%A5%E6%9C%AC%E8%AA%9E-green)](README_JP.md) [![MNN Homepage](https://img.shields.io/badge/Homepage-Visit-green)](http://www.mnn.zone) -[![MNN Chat App](https://img.shields.io/badge/Apps-MNN_Chat-blue)](./apps/Android/MnnLlmChat/README.md) -[![TaoAvatar](https://img.shields.io/badge/Apps-MNN_TaoAvatar-blue)](./apps/Android/Mnn3dAvatar/README.md) +[![MNN Chat App](https://img.shields.io/badge/Apps-MNN_Chat-blue)](./apps/Android/MnnLlmChat/README.md) +[![TaoAvatar](https://img.shields.io/badge/Apps-MNN_TaoAvatar-blue)](./apps/Android/Mnn3dAvatar/README.md) ## News 🔥 - [2025/10/16] Support Qwen3-VL Series. -- [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! 
[MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md) +- [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! [MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md)

Icon

@@ -154,13 +154,13 @@ The group discussions are predominantly Chinese. But we welcome and will help En Dingtalk discussion groups: -Group #1 (Full): 23329087 +Group #4 (Available): 160170007549 -Group #2 (Full): 23350225 +Group #3 (Full) -Group #3: QR code: +Group #2 (Full): 23350225 -![MNN-3](doc/dingdingmnn3.png) +Group #1 (Full): 23329087 ## Historical Paper diff --git a/README_CN.md b/README_CN.md index edcf823a28..f769a1e14b 100644 --- a/README_CN.md +++ b/README_CN.md @@ -111,12 +111,10 @@ MNN适配的硬件架构与精度详见下表: ## 社区交流与反馈 钉钉群组: -- 钉钉群1:23329087 -- 钉钉群2:23350225 -- 钉钉群3:扫描二维码加入 - -![MNN-3](doc/dingdingmnn3.png) - +- 钉钉群3 (可加入): 160170007549 +- 钉钉群3 (已无法加入) +- 钉钉群2 (已满): 23350225 +- 钉钉群1 (已满): 23329087 ## 历史论文 diff --git a/README_JP.md b/README_JP.md index c2baa58d94..2f33def31a 100644 --- a/README_JP.md +++ b/README_JP.md @@ -117,13 +117,14 @@ MNN(テンソル計算エンジン)に基づいて、推論、トレーニ Dingtalkディスカッショングループ: -グループ#1(満員):23329087 -グループ#2(満員):23350225 +グループ#4 :160170007549 -グループ#3:QRコード: +グループ#3 (満員) -![MNN-3](doc/dingdingmnn3.png) +グループ#2(満員):23350225 + +グループ#1(満員):23329087 ## 歴史的な論文 From ffed9d73130bcd024f6509a7b73be541328f913b Mon Sep 17 00:00:00 2001 From: jxt1234 Date: Sun, 21 Dec 2025 14:27:40 +0800 Subject: [PATCH 036/314] Tools:Feature: Support dump per op cost for ModuleBasic --- tools/cpp/ExprDebug.hpp | 53 +++++++++++++++++++++++++---- tools/cpp/ModuleBasic.cpp | 46 ++++++++++--------------- transformers/llm/engine/src/llm.cpp | 21 +----------- 3 files changed, 66 insertions(+), 54 deletions(-) diff --git a/tools/cpp/ExprDebug.hpp b/tools/cpp/ExprDebug.hpp index 167e97c562..49e3db6156 100644 --- a/tools/cpp/ExprDebug.hpp +++ b/tools/cpp/ExprDebug.hpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #define DUMP_NUM_DATA(type) \ @@ -135,29 +136,69 @@ static void _initDebug() { struct TimeTraceInfo { - std::map>>> mTypes; + std::map>> mTypes; void begin(const MNN::OperatorInfo* info) { auto tIter = mTypes.find(info->type()); if (tIter == 
mTypes.end()) { - std::map>> _t; + std::map> _t; mTypes.insert(std::make_pair(info->type(), _t)); tIter = mTypes.find(info->type()); } mInserIter = tIter->second.find(info->name()); if (mInserIter == tIter->second.end()) { - std::vector> _t; - tIter->second.insert(std::make_pair(info->name(), _t)); + tIter->second.insert(std::make_pair(info->name(), std::make_tuple(0.0f, 0.0f, 0))); mInserIter = tIter->second.find(info->name()); } mTimer.reset(); } void end(const MNN::OperatorInfo* info) { auto timeInMs = (float)mTimer.durationInUs() / 1000.0f; - mInserIter->second.emplace_back(std::make_pair(timeInMs, info->flops())); + std::get<0>(mInserIter->second) += timeInMs; + std::get<1>(mInserIter->second) += info->flops(); + std::get<2>(mInserIter->second) ++; + } + void dump(bool dumpPerOp = false) { + if (dumpPerOp) { + auto cmp = [](const std::tuple& first, const std::tuple& second) { + return std::get<1>(first) > std::get<1>(second); + }; + std::priority_queue, std::vector>, decltype(cmp)> que(cmp); + for (auto& iter : mTypes) { + for (auto& t : iter.second) { + auto mergeType = t.first + " ["+iter.first +"]"; + auto unit = std::make_tuple(mergeType, std::get<0>(t.second), std::get<1>(t.second), std::get<2>(t.second)); + que.push(unit); + } + } + while (!que.empty()) { + auto& t = que.top(); + MNN_PRINT("%s : %.7f ms, FLOP: %.7f, COUNT: %d, Speed: %.7f GFlops\n", std::get<0>(t).c_str(), std::get<1>(t), std::get<2>(t), std::get<3>(t), std::get<2>(t) / std::get<1>(t)); + que.pop(); + } + return; + } + float opSummer = 0.0f; + float opFlopsSummber = 0.0f; + for (auto& iter : mTypes) { + float summer = 0.0f; + float summerflops = 0.0f; + int count = 0; + for (auto& t : iter.second) { + summer += std::get<0>(t.second); + summerflops += std::get<1>(t.second); + count += std::get<2>(t.second); + } + MNN_PRINT("%s : %.7f ms, FLOP: %.7f, COUNT: %d, Speed: %.7f GFlops\n", iter.first.c_str(), summer, summerflops, count, + summerflops / summer); + opSummer += summer; + 
opFlopsSummber += summerflops; + } + MNN_PRINT("OP Summer: %.7f ms, Flops: %.7f, Speed: %.7f GFlops\n", opSummer, opFlopsSummber, + opFlopsSummber / opSummer); } private: - std::map>>::iterator mInserIter; + std::map>::iterator mInserIter; MNN::Timer mTimer; }; static TimeTraceInfo* gTimeTraceInfo = nullptr; diff --git a/tools/cpp/ModuleBasic.cpp b/tools/cpp/ModuleBasic.cpp index 90fa6b80d3..5798bc6d26 100644 --- a/tools/cpp/ModuleBasic.cpp +++ b/tools/cpp/ModuleBasic.cpp @@ -499,10 +499,13 @@ int main(int argc, char *argv[]) { if (runTime > 0) { int t = runTime; - std::vector times(t, 0.0f); if (runMask & 4) { _initTimeTrace(); } + float minTime = std::numeric_limits::max(); + float maxTime = 0.0f; + float sum = 0.0f; + for (int i = 0; i < t; ++i) { Timer _l; auto out = net->onForward(inputs); @@ -510,41 +513,28 @@ int main(int argc, char *argv[]) { for (auto o : out) { ((MNN::Tensor*)o->getTensor())->wait(MNN::Tensor::MAP_TENSOR_READ, true); } - times[i] = _l.durationInUs() / 1000.0f; + auto time = _l.durationInUs() / 1000.0f; if (freq > 0.0f) { - float remainMs = (1000.0f / freq) - times[i]; + float remainMs = (1000.0f / freq) - time; if (remainMs > 0.0f) { std::this_thread::sleep_for(std::chrono::milliseconds((int)remainMs)); } } - } - if (nullptr != gTimeTraceInfo) { - float opSummer = 0.0f; - float opFlopsSummber = 0.0f; - for (auto& iter : gTimeTraceInfo->mTypes) { - float summer = 0.0f; - float summerflops = 0.0f; - for (auto& t : iter.second) { - for (auto& t0 : t.second) { - summer += t0.first; - summerflops += t0.second; - } - } - summer = summer / (float)t; - summerflops = summerflops / (float)t; - MNN_PRINT("%s : %.7f, FLOP: %.7f, Speed: %.7f GFlops\n", iter.first.c_str(), summer, summerflops, summerflops / summer); - opSummer += summer; - opFlopsSummber+= summerflops; + if (maxTime < time) { + maxTime = time; + } + if (minTime > time) { + minTime = time; } - MNN_PRINT("OP Summer: %.7f, Flops: %.7f, Speed: %.7f GFlops\n", opSummer, opFlopsSummber, 
opFlopsSummber/opSummer); - } - auto minTime = std::min_element(times.begin(), times.end()); - auto maxTime = std::max_element(times.begin(), times.end()); - float sum = 0.0f; - for (auto time : times) { sum += time; } - MNN_PRINT("Avg= %f ms, min= %f ms, max= %f ms\n", sum / (float)t, *minTime, *maxTime); + if (nullptr != gTimeTraceInfo) { + MNN_PRINT("Per Op Trace: \n"); + gTimeTraceInfo->dump(true); + MNN_PRINT("Per Type Trace: \n"); + gTimeTraceInfo->dump(false); + } + MNN_PRINT("Avg= %f ms, min= %f ms, max= %f ms\n", sum / (float)t, minTime, maxTime); } rtmgr->updateCache(); return 0; diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp index 53af11239a..63c590e0fd 100644 --- a/transformers/llm/engine/src/llm.cpp +++ b/transformers/llm/engine/src/llm.cpp @@ -915,26 +915,7 @@ Llm::Llm(std::shared_ptr config) : mConfig(config) { Llm::~Llm() { #if DEBUG_MODE == 1 if (nullptr != gTimeTraceInfo) { - float opSummer = 0.0f; - float opFlopsSummber = 0.0f; - for (auto& iter : gTimeTraceInfo->mTypes) { - float summer = 0.0f; - float summerflops = 0.0f; - for (auto& t : iter.second) { - for (auto& t0 : t.second) { - summer += t0.first; - summerflops += t0.second; - } - } - summer = summer; - summerflops = summerflops; - MNN_PRINT("%s : %.7f, FLOP: %.7f, Speed: %.7f GFlops\n", iter.first.c_str(), summer, summerflops, - summerflops / summer); - opSummer += summer; - opFlopsSummber += summerflops; - } - MNN_PRINT("OP Summer: %.7f, Flops: %.7f, Speed: %.7f GFlops\n", opSummer, opFlopsSummber, - opFlopsSummber / opSummer); + gTimeTraceInfo->dump(); } #endif mGenerateParam.reset(); From 528d2b3558d5df4b51d1c2e629c7a9bac8d51d3a Mon Sep 17 00:00:00 2001 From: jxt1234 Date: Sun, 21 Dec 2025 18:49:09 +0800 Subject: [PATCH 037/314] MNN:Speed: Optimize MatMul E=1 for case H = 24-31. 
Reduce unuseful operation, reduce memcpy for CPURNNSequenceGRU --- source/backend/cpu/CPURNNSequenceGRU.cpp | 70 ++++++++++--------- source/backend/cpu/CPURNNSequenceGRU.hpp | 15 +++- .../backend/cpu/compute/CommonOptFunction.cpp | 61 ++++++++++++++-- source/math/Vec.hpp | 3 +- 4 files changed, 106 insertions(+), 43 deletions(-) diff --git a/source/backend/cpu/CPURNNSequenceGRU.cpp b/source/backend/cpu/CPURNNSequenceGRU.cpp index daae8811c7..0bda660e9c 100644 --- a/source/backend/cpu/CPURNNSequenceGRU.cpp +++ b/source/backend/cpu/CPURNNSequenceGRU.cpp @@ -10,30 +10,26 @@ #include #include "backend/cpu/CPUBackend.hpp" #include "backend/cpu/compute/ConvOpt.h" -#include "backend/cpu/compute/CommonOptFunction.h" #include "core/TensorUtils.hpp" namespace MNN { // implement GRU cell function // Ref: tensorflow/python/ops/rnn_cell_impl.py -void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, const bool linearBeforeReset, - std::shared_ptr& hiddenState, const int numUnits, Tensor* gateWeight, Tensor* gateBias, +void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, const bool linearBeforeReset, uint8_t* hiddenStateInput, const int numUnits, Tensor* gateWeight, Tensor* gateBias, Tensor* candidateWeight, Tensor* candidateBias, Tensor* recurrentBias, std::shared_ptr& inputAndState, std::shared_ptr& gate, - std::shared_ptr& resetHt) { - auto bn = static_cast(backend()); - auto mulFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_MUL); - auto addFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_ADD); - auto subFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_SUB); - auto tanhFunction = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_TANH, bn->precisionMode()); - auto bytes = bn->functions()->bytes; - auto sigmoidFunc = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_SIGMOID, bn->precisionMode()); + 
std::shared_ptr& resetHt, uint8_t* hiddenStateOutput) { // gate is (z_t, r_t) + auto bytes = mRNNFunctions.bytes; + MNNBinaryExecute mulFunction = mRNNFunctions.mulFunction; + MNNBinaryExecute addFunction = mRNNFunctions.addFunction; + MNNBinaryExecute subFunction = mRNNFunctions.subFunction; + MNNUnaryExecute tanhFunction = mRNNFunctions.tanhFunction; + MNNUnaryExecute sigmoidFunction = mRNNFunctions.sigmoidFunction; auto inputAndStatePtr = inputAndState->host(); - auto hiddenStatePtr = hiddenState->host(); ::memcpy(inputAndStatePtr, input, inputLength * bytes); - ::memcpy(inputAndStatePtr + inputLength * bytes, hiddenStatePtr, numUnits * bytes); + ::memcpy(inputAndStatePtr + inputLength * bytes, hiddenStateInput, numUnits * bytes); inputAndState->setLength(1, inputLength + numUnits); // // [x_t, h_t-1] * [W_zr, R_zr]: (1, inputLength + numUnits) X (inputLength + numUnits, 2 * numUnits) @@ -42,9 +38,8 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, recurrentBias->setLength(1, 2 * numUnits); addFunction(gate->host(), gate->host(), recurrentBias->host(), 2*numUnits, -1); // (1, 2*numUnits) - const int gateSize = gate->elementSize(); auto gatePtr = gate->host(); - sigmoidFunc(gatePtr, gatePtr, gateSize); + sigmoidFunction(gatePtr, gatePtr, 2 * numUnits); // reset gate, // r_t is the second segment auto rtPtr = gatePtr + numUnits * bytes; @@ -52,7 +47,7 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, // calculate Rt (.) 
(Ht_1 * Rh + Rbh) auto recurrentHiddenBiasPtr = recurrentBias->host() + 2 * numUnits * bytes; auto rhWeightPtr = candidateWeight->host() + inputLength * numUnits * bytes; - mMatMulU2U->execute(hiddenState->host(), (float*)rhWeightPtr, resetHt->host(), (float*)recurrentHiddenBiasPtr); + mMatMulU2U->execute((float*)hiddenStateInput, (float*)rhWeightPtr, resetHt->host(), (float*)recurrentHiddenBiasPtr); mulFunction(resetHt->host(), rtPtr, resetHt->host(), numUnits, -1); // calculate Xt * Wh @@ -65,7 +60,7 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, // r_t: (1, numUnits) auto resetGatePtr = inputAndStatePtr + inputLength * bytes; // h_t1(1, numUnits) = r_t(1, numUnits) * h_t-1_(1, numUnits) - mulFunction(resetGatePtr, rtPtr, hiddenStatePtr, numUnits, -1); + mulFunction(resetGatePtr, rtPtr, hiddenStateInput, numUnits, -1); // deal with recurrent bias and linear_before_reset parameter auto recurrentBiasAddedPtr = inputAndStatePtr + (inputLength + numUnits) * bytes; auto recurrentHiddenBiasPtr = (float*)(recurrentBias->host() + 2 * numUnits * bytes); @@ -76,9 +71,9 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, } // h = (1-g)*t+g*h = t + g*(h-t) tanhFunction(resetHt->host(), rtPtr, numUnits); - subFunction(hiddenStatePtr, hiddenStatePtr, resetHt->host(), numUnits, -1); - mulFunction(hiddenStatePtr, hiddenStatePtr, gatePtr, numUnits, -1); - addFunction(hiddenStatePtr, hiddenStatePtr, resetHt->host(), numUnits, -1); + subFunction(hiddenStateOutput, hiddenStateInput, resetHt->host(), numUnits, -1); + mulFunction(hiddenStateOutput, hiddenStateOutput, gatePtr, numUnits, -1); + addFunction(hiddenStateOutput, hiddenStateOutput, resetHt->host(), numUnits, -1); inputAndState->setLength(1, inputLength + 2 * numUnits); } @@ -143,6 +138,13 @@ ErrorCode CPURNNSequenceGRU::onResize(const std::vector& inputs, const backend()->onReleaseBuffer(mInputAndState.get(), Backend::DYNAMIC); 
backend()->onReleaseBuffer(mGate.get(), Backend::DYNAMIC); backend()->onReleaseBuffer(mResetHt.get(), Backend::DYNAMIC); + auto bn = static_cast(backend()); + mRNNFunctions.mulFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_MUL); + mRNNFunctions.addFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_ADD); + mRNNFunctions.subFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_SUB); + mRNNFunctions.tanhFunction = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_TANH, bn->precisionMode()); + mRNNFunctions.bytes = bn->functions()->bytes; + mRNNFunctions.sigmoidFunction = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_SIGMOID, bn->precisionMode()); return NO_ERROR; } @@ -183,27 +185,29 @@ ErrorCode CPURNNSequenceGRU::onExecute(const std::vector& inputs, const const int inputCodeLength = input->length(2); // MNN_PRINT("inputSequenceLength:%d, batchSize:%d, inputCodeLength:%d, mNumUnits:%d, hiddenStateDataSize:%d\n", inputSequenceLength, batchSize, inputCodeLength, mNumUnits, hiddenStateDataSize); for (int b = 0; b < batchSize; ++b) { // swap order + auto hiddenStateInput = hiddenStatePtr; + auto hiddenStateOutput = hiddenStatePtr; if (inputSize > 1 + forwardParamNumber * (mIsBidirectionalRNN + 1)) { auto source = inputs[inputSize - 1]->host() + b * hiddenStateDataSize; - ::memcpy(hiddenStatePtr, source, hiddenStateDataSize); + hiddenStateInput = source; } else { ::memset(hiddenStatePtr, 0, hiddenStateDataSize); } for (int i = 0; i < inputSequenceLength; ++i) { const int inputOffset = i * SequenceStride + b * inputCodeLength; - runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, mHiddenState, mNumUnits, fwGateWeight, fwGateBias, - fwCandidateWeight, fwCandidateBias, fwRecurrentBias, mInputAndState, mGate, mResetHt); - if (mKeepAllOutputs) { - ::memcpy(outputPtr + (i * output->stride(0) + b * mNumUnits) * bytes, hiddenStatePtr, 
hiddenStateDataSize); + hiddenStateOutput = outputPtr + (i * output->stride(0) + b * mNumUnits) * bytes; } + runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, hiddenStateInput, mNumUnits, fwGateWeight, fwGateBias, + fwCandidateWeight, fwCandidateBias, fwRecurrentBias, mInputAndState, mGate, mResetHt, hiddenStateOutput); + + hiddenStateInput = hiddenStateOutput; } if ((mKeepAllOutputs && outputSize > 1) || !mKeepAllOutputs) { - ::memcpy(outputYhPtr, hiddenStatePtr, hiddenStateDataSize); + ::memcpy(outputYhPtr, hiddenStateOutput, hiddenStateDataSize); outputYhPtr += mNumUnits * bytes; } - } // backward rnn @@ -221,22 +225,24 @@ ErrorCode CPURNNSequenceGRU::onExecute(const std::vector& inputs, const auto outputBw = outputs[0]; auto const outputBwPtr = outputBw->host(); for (int b = 0; b < batchSize; ++b) { + auto hiddenStateInput = hiddenStatePtr; + auto hiddenStateOutput = hiddenStatePtr; if (inputSize > 1 + forwardParamNumber * 2) { auto source = inputs[inputSize - 1]->host() + (batchSize + b) * hiddenStateDataSize; - ::memcpy(hiddenStatePtr, source, hiddenStateDataSize); + hiddenStateInput = source; } else { ::memset(hiddenStatePtr, 0, hiddenStateDataSize); } for (int i = inputSequenceLength - 1; i >= 0; i--) { const int inputOffset = i * SequenceStride + b * inputCodeLength; - runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, mHiddenState, mNumUnits, bwGateWeight, bwGateBias, - bwCandidateWeight, bwCandidateBias, bwRecurrentBias, mInputAndState, mGate, mResetHt); if (mKeepAllOutputs) { - ::memcpy(outputBwPtr + (i * outputBw->stride(0) + (batchSize + b) * mNumUnits) * bytes, - hiddenStatePtr, hiddenStateDataSize); + hiddenStateOutput = outputBwPtr + (i * outputBw->stride(0) + (batchSize + b) * mNumUnits) * bytes; } + runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, hiddenStateInput, mNumUnits, bwGateWeight, bwGateBias, + bwCandidateWeight, bwCandidateBias, bwRecurrentBias, 
mInputAndState, mGate, mResetHt, hiddenStateOutput); + hiddenStateInput = hiddenStateOutput; } if ((mKeepAllOutputs && outputSize > 1) || !mKeepAllOutputs) { ::memcpy(outputYhPtr, hiddenStatePtr, hiddenStateDataSize); diff --git a/source/backend/cpu/CPURNNSequenceGRU.hpp b/source/backend/cpu/CPURNNSequenceGRU.hpp index 0987d13053..0125b9e8a1 100644 --- a/source/backend/cpu/CPURNNSequenceGRU.hpp +++ b/source/backend/cpu/CPURNNSequenceGRU.hpp @@ -11,6 +11,7 @@ #include "core/Execution.hpp" #include "CPUMatMul.hpp" +#include "backend/cpu/compute/CommonOptFunction.h" namespace MNN { class CPURNNSequenceGRU : public Execution { @@ -19,13 +20,20 @@ class CPURNNSequenceGRU : public Execution { virtual ~CPURNNSequenceGRU(); virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; - + struct RNNFuntions { + MNNBinaryExecute mulFunction; + MNNBinaryExecute addFunction; + MNNBinaryExecute subFunction; + MNNUnaryExecute tanhFunction; + MNNUnaryExecute sigmoidFunction; + int bytes; + }; private: void runRNNStep(const uint8_t* input, const int inputLength, const bool linearBeforeReset, - std::shared_ptr& hiddenState, const int numUnits, Tensor* gateWeight, Tensor* gateBias, + uint8_t* hiddenStateInput, const int numUnits, Tensor* gateWeight, Tensor* gateBias, Tensor* candidateWeight, Tensor* candidateBias, Tensor* recurrentBias, std::shared_ptr& inputAndState, std::shared_ptr& gate, - std::shared_ptr& resetHt); + std::shared_ptr& resetHt, uint8_t* hiddenStateOutput); bool mKeepAllOutputs; bool mIsBidirectionalRNN; bool mlinearBeforeReset; @@ -42,6 +50,7 @@ class CPURNNSequenceGRU : public Execution { std::shared_ptr mMatMulU2U; // For inputLength -> numUnit std::shared_ptr mMatMulI2U; + RNNFuntions mRNNFunctions; }; } // namespace MNN diff --git a/source/backend/cpu/compute/CommonOptFunction.cpp b/source/backend/cpu/compute/CommonOptFunction.cpp index 
d7d0d7fb34..9abefb6df4 100644 --- a/source/backend/cpu/compute/CommonOptFunction.cpp +++ b/source/backend/cpu/compute/CommonOptFunction.cpp @@ -3882,12 +3882,13 @@ void MNNVectorTop1Int32(int32_t* input, int32_t* maxValue, int32_t* maxIndex, si #endif -void MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId) { +void MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tIdL) { auto l = param->l; auto h = param->h; auto numberThread = param->numberThread; auto lC4 = l / 4; auto lR = lC4 * 4; + auto tId = (int)tIdL; if (param->BTranspose) { for (int y=tId; y= 8) { + if (0 == tId) { + auto bs = B + hEnd; + Vec4 sumValue0; + Vec4 sumValue1; + if (biasPtr != nullptr) { + sumValue0 = Vec4::load(biasPtr + hEnd + 0); + sumValue1 = Vec4::load(biasPtr + hEnd + 4); + } else { + sumValue0 = Vec4(0.0f); + sumValue1 = Vec4(0.0f); + } + auto srcY = A + hEnd * l; + for (int x=0; x= 4) { + if (0 == tId) { + auto bs = B + hEnd; + Vec4 sumValue0; + if (biasPtr != nullptr) { + sumValue0 = Vec4::load(biasPtr + hEnd + 0); + } else { + sumValue0 = Vec4(0.0f); + } + auto srcY = A + hEnd * l; + for (int x=0; x { using VecType = Vec; using VecTypeInt32 = Vec; float32x4_t value; - Vec() { - } + Vec() = default; Vec(const float v) { value = vdupq_n_f32(v); } From 10de18c80d5104f8f9431e5ee1e2e297196c6987 Mon Sep 17 00:00:00 2001 From: jxt1234 Date: Sun, 21 Dec 2025 23:11:27 +0800 Subject: [PATCH 038/314] MNN:Speed: Reduce ThreadPool Function Object Cost --- source/backend/cpu/CPUBackend.cpp | 7 +- source/backend/cpu/CPUBackend.hpp | 3 + source/backend/cpu/CPUBinary.cpp | 60 +-- source/backend/cpu/CPUBinary.hpp | 4 + source/backend/cpu/CPUMatMul.cpp | 28 +- source/backend/cpu/CPUMatMul.hpp | 7 +- source/backend/cpu/CPURaster.cpp | 619 +++++++++++++++--------------- source/backend/cpu/CPURaster.hpp | 3 +- source/backend/cpu/ThreadPool.cpp | 32 +- 
source/backend/cpu/ThreadPool.hpp | 6 +- source/core/Concurrency.h | 13 +- test/core/ThreadPoolTest.cpp | 6 +- 12 files changed, 420 insertions(+), 368 deletions(-) diff --git a/source/backend/cpu/CPUBackend.cpp b/source/backend/cpu/CPUBackend.cpp index 95cbd903b7..8d284aa33b 100644 --- a/source/backend/cpu/CPUBackend.cpp +++ b/source/backend/cpu/CPUBackend.cpp @@ -104,15 +104,14 @@ void CPURuntime::_bindCPUCore() const { #ifdef MNN_USE_THREAD_POOL if (nullptr != mThreadPool) { mThreadPool->active(); - mThreadPool->enqueue(std::make_pair([&](int i) { + ThreadPool::TASK task = std::make_pair([&](int i) { MNNSetSchedAffinity(lockCPUIndexes[i].first, lockCPUIndexes[i].second); - return 0; - }, mThreadNumber), mTaskIndex); + }, mThreadNumber); + mThreadPool->enqueue(&task, mTaskIndex); mThreadPool->deactive(); } #endif } - void CPURuntime::_resetThreadPool() const { mThreadNumber = std::max(1, mThreadNumber); mThreadNumber = std::min(mThreadNumber, MAX_THREAD_NUMBER); diff --git a/source/backend/cpu/CPUBackend.hpp b/source/backend/cpu/CPUBackend.hpp index 884036eb38..ec4c555dec 100644 --- a/source/backend/cpu/CPUBackend.hpp +++ b/source/backend/cpu/CPUBackend.hpp @@ -176,6 +176,9 @@ class CPUBackend : public Backend { #ifdef MNN_USE_THREAD_POOL inline int taskIndex() const {return mRuntime->mTaskIndex;} inline ThreadPool* threadPool() const {return mRuntime->mThreadPool;} + void enqueue(ThreadPool::TASK& task) const { + threadPool()->enqueue(&task, taskIndex()); + } #endif static void initCreatorMap(); static size_t getBytes(const Backend* backend, const Tensor* output); diff --git a/source/backend/cpu/CPUBinary.cpp b/source/backend/cpu/CPUBinary.cpp index 059e502d0b..61ccf4fca3 100644 --- a/source/backend/cpu/CPUBinary.cpp +++ b/source/backend/cpu/CPUBinary.cpp @@ -45,6 +45,37 @@ ErrorCode CPUBinary::onResize(const std::vector& inputs, const std::vec mThreadNum = threads; mWorkDiv = UP_DIV(mTotalSize, threads); } + int inpBytes = inputs[0]->getType().bytes(); + int 
outBytes = outputs[0]->getType().bytes(); + if (halide_type_float == inputs[0]->getType().code) { + inpBytes = static_cast(backend())->functions()->bytes; + } + if (halide_type_float == outputs[0]->getType().code) { + outBytes = static_cast(backend())->functions()->bytes; + } + bool outputInt = outputs[0]->getType().code == halide_type_int; + mTask = std::make_pair([this, inpBytes, outBytes, outputInt](int tId) { + int start = tId * mWorkDiv; + int realSize = ALIMIN(mWorkDiv, mTotalSize - start); + if (realSize > 0) { + auto inp0 = mInput0Ptr + start * inpBytes; + auto inp1 = mInput1Ptr + start * inpBytes; + if (mNeedBroadcastIndex == 0) { + inp0 = mInput0Ptr; + } else if (mNeedBroadcastIndex == 1) { + inp1 = mInput1Ptr; + } + auto out = mOutputPtr + start * outBytes; + mProc(out, inp0, inp1, realSize, mNeedBroadcastIndex); + if(mActivationType == 1 && outputInt) { + for(int i=0; i 0 ? val : 0; + ((int32_t *)out)[i] = res; + } + } + } + } , mThreadNum); return NO_ERROR; } @@ -67,31 +98,10 @@ ErrorCode CPUBinary::onExecute(const std::vector& inputs, const std::ve outBytes = static_cast(backend())->functions()->bytes; } auto precision = static_cast(backend())->precisionMode(); - - MNN_CONCURRENCY_BEGIN(tId, mThreadNum) { - int start = tId * mWorkDiv; - int realSize = ALIMIN(mWorkDiv, mTotalSize - start); - if (realSize > 0) { - auto inp0 = input0Ptr + start * inpBytes; - auto inp1 = input1Ptr + start * inpBytes; - if (mNeedBroadcastIndex == 0) { - inp0 = input0Ptr; - } else if (mNeedBroadcastIndex == 1) { - inp1 = input1Ptr; - } - auto out = outputPtr + start * outBytes; - mProc(out, inp0, inp1, realSize, mNeedBroadcastIndex); - if(mActivationType == 1 && output->getType().code == halide_type_int) { - for(int i=0; i 0 ? 
val : 0; - ((int32_t *)out)[i] = res; - } - } - } - } - MNN_CONCURRENCY_END(); - + mInput0Ptr = input0Ptr; + mInput1Ptr = input1Ptr; + mOutputPtr = outputPtr; + MNN_CONCURRENCY_ENQUEUE(mTask); if(mActivationType == 1 && output->getType().code == halide_type_float) { mActivationExe->onExecute(outputs, outputs); } diff --git a/source/backend/cpu/CPUBinary.hpp b/source/backend/cpu/CPUBinary.hpp index 9250df79ae..17cb3b5f47 100644 --- a/source/backend/cpu/CPUBinary.hpp +++ b/source/backend/cpu/CPUBinary.hpp @@ -33,6 +33,10 @@ class CPUBinary : public Execution { int mThreadNum; int mWorkDiv; std::shared_ptr mActivationExe; + std::pair, int> mTask; + uint8_t* mInput0Ptr = nullptr; + uint8_t* mOutputPtr = nullptr; + uint8_t* mInput1Ptr = nullptr; }; } // namespace MNN #endif /* CPUBinary_hpp */ diff --git a/source/backend/cpu/CPUMatMul.cpp b/source/backend/cpu/CPUMatMul.cpp index 4f0765f050..22b96a64ee 100644 --- a/source/backend/cpu/CPUMatMul.cpp +++ b/source/backend/cpu/CPUMatMul.cpp @@ -37,9 +37,8 @@ void CPUMatMul::_scheduleForVecE(int e, int l, int h) { param.BTranspose = mTransposeB; param.numberThread = numberThread; auto func = static_cast(backend())->functions()->MNNComputeMatMulForE_1; - mPreFunctions.emplace_back(std::make_pair([param, func]( - int tId, const float* A, const float* B, const float* biasPtr, float* C) { - func(A, B, C, biasPtr, ¶m, tId); + mPreFunctions.emplace_back(std::make_pair([param, func, this](int tId) { + func(mA, mB, mC, mBiasPtr, ¶m, tId); }, numberThread)); } @@ -54,9 +53,9 @@ void CPUMatMul::_scheduleForVec(int e, int l, int h) { auto func = static_cast(backend())->functions()->MNNComputeMatMulForH_1; // TODD: Support e = 1 MNN_ASSERT(h == 1); - mPreFunctions.emplace_back(std::make_pair([param, func]( - int tId, const float* A, const float* B, const float* biasPtr, float* C) { - func(A, B, C, biasPtr, ¶m, tId); + mPreFunctions.emplace_back(std::make_pair([param, func, this]( + int tId) { + func(mA, mB, mC, mBiasPtr, ¶m, tId); }, 
numberThread)); } @@ -100,8 +99,8 @@ ErrorCode CPUMatMul::onResize(const std::vector& inputs, const std::vec return OUT_OF_MEMORY; } - mPreFunctions.emplace_back(std::make_pair([BTPtrAlloc, l, h, this, core] (int tId, const float* APtr, const float* BPtr, const float* Bias, float* C) { - core->MNNPackForMatMul_B((float*)BTPtrAlloc.ptr(), BPtr, h, 1, l, mTransposeB); + mPreFunctions.emplace_back(std::make_pair([BTPtrAlloc, l, h, this, core] (int tId) { + core->MNNPackForMatMul_B((float*)BTPtrAlloc.ptr(), mB, h, 1, l, mTransposeB); } , 1)); bool useBias = false; MemChunk bdestAlloc; @@ -120,9 +119,9 @@ ErrorCode CPUMatMul::onResize(const std::vector& inputs, const std::vec } mTempBias = bdestAlloc; mPreFunctions.emplace_back(std::make_pair( - [biasLength, bdestAlloc, core](int tId, const float* APtr, const float* BPtr, const float* borigin, float* C) { + [biasLength, bdestAlloc, core, this](int tId) { ::memset(bdestAlloc.ptr(), 0, UP_DIV(biasLength, core->pack) * core->bytes * core->pack); - ::memcpy(bdestAlloc.ptr(), borigin, biasLength * core->bytes); + ::memcpy(bdestAlloc.ptr(), mBiasPtr, biasLength * core->bytes); }, 1)); } else { mUseBiasDirectly = true; @@ -167,11 +166,12 @@ ErrorCode CPUMatMul::onExecute(const std::vector& inputs, const std::ve } void CPUMatMul::execute(const float* APtr, const float* BPtr, float* CPtr, const float* biasPtr) { + mA = APtr; + mB = BPtr; + mC = CPtr; + mBiasPtr = biasPtr; for (auto& f : mPreFunctions) { - MNN_CONCURRENCY_BEGIN(tId, f.second) { - f.first(tId, APtr, BPtr, biasPtr, CPtr); - } - MNN_CONCURRENCY_END(); + MNN_CONCURRENCY_ENQUEUE(f); } if (mE > 0) { auto core = static_cast(backend())->functions(); diff --git a/source/backend/cpu/CPUMatMul.hpp b/source/backend/cpu/CPUMatMul.hpp index 872a77a9a8..48226795f0 100644 --- a/source/backend/cpu/CPUMatMul.hpp +++ b/source/backend/cpu/CPUMatMul.hpp @@ -29,7 +29,7 @@ class CPUMatMul : public Execution { bool mTransposeB; bool mTransposeC; bool mSupportMultiThread = false; - 
std::vector, int>> mPreFunctions; + std::vector, int>> mPreFunctions; bool mUseBiasDirectly = false; MemChunk mTempA; MemChunk mTempB; @@ -40,6 +40,11 @@ class CPUMatMul : public Execution { int mL; int mH; std::vector mPostParameters; + // For Execute Paramters + const float* mA = nullptr; + const float* mB = nullptr; + const float* mBiasPtr = nullptr; + float* mC = nullptr; }; } // namespace MNN diff --git a/source/backend/cpu/CPURaster.cpp b/source/backend/cpu/CPURaster.cpp index 3272086531..f64dafced3 100644 --- a/source/backend/cpu/CPURaster.cpp +++ b/source/backend/cpu/CPURaster.cpp @@ -49,227 +49,6 @@ struct ReduceInfo { } }; -ErrorCode CPURaster::onResize(const std::vector &____inputs, const std::vector &outputs) { - MNN_ASSERT(outputs.size() == 1); - auto output = outputs[0]; - OpCommonUtils::rasterInputReset(____inputs, outputs[0]); - auto des = TensorUtils::getDescribe(output); - auto outputDes = TensorUtils::getDescribe(output); - mNeedZero = !TensorUtils::regionIsFull(output); - mZeroPoint = 0; - mUseThreads = false; - if (outputDes->quantAttr != nullptr && outputDes->applyQuant) { -#ifdef MNN_USE_SSE - mZeroPoint = (int)outputDes->quantAttr->zero + 128; -#else - mZeroPoint = (int)outputDes->quantAttr->zero; -#endif - } - mTempInput.clear(); - mFastBlit.clear(); - mCacheRegions.clear(); - mTempOutput = nullptr; - auto midFormat = MNN_DATA_FORMAT_NCHW; - mTempInputCopy.clear(); - mFast = false; - auto core = static_cast(backend())->functions(); - mSingleConvert.type = 0; - // all_srcFormat == dstFormat == NC4HW4 : Fast Exe - if (outputDes->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) { - mFast = true; - for (int i=0; i< des->regions.size(); ++i) { - auto& slice = des->regions[i]; - if (TensorUtils::getDescribe(slice.origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { - mFast = false; - break; - } - if (!OpCommonUtils::canBlitFast(slice, output, core->pack, true)) { - mFast = false; - break; - } - } - if (mFast) { - mUseThreads = des->regions.size() > 16 
? true : false; - for (int i=0; i< des->regions.size(); ++i) { - auto& slice = des->regions[i]; - if (slice.origin == nullptr) { - continue; - } - Tensor::InsideDescribe::Region newRegion; - OpCommonUtils::turnToPackRegion(slice, newRegion, output, core->pack, true); - mFastBlit.emplace_back(std::make_pair(slice.origin, std::move(newRegion))); - } - return NO_ERROR; - } - } - // srcNum == 1 && srcFormat != dstFormat : Single Convert - if (des->regions.size() == 1) { - OpCommonUtils::turnRegion2Convert(des->regions[0], output, mSingleConvert); - if (mSingleConvert.type > 0) { - mUseThreads = (mSingleConvert.batch * mSingleConvert.channel * mSingleConvert.area > LAUNCH_MULTI_THREADS_WORKLOAD) ? true : false; - return NO_ERROR; - } - } - // Acquire Buffer for temp output - // TODO: optimize it - if (MNN_DATA_FORMAT_NC4HW4 == outputDes->dimensionFormat) { - mTempOutput.reset(new Tensor); - TensorUtils::setupTensorInfo(output, mTempOutput.get(), midFormat); - } - if (nullptr != mTempOutput) { - auto res = backend()->onAcquireBuffer(mTempOutput.get(), Backend::DYNAMIC); - if (!res) { - return OUT_OF_MEMORY; - } - } - // input is NC4HW4 add Convert - std::vector forRelease; - TensorUtils::FuseWrap fuseUtils; - for (int i=0; i< des->regions.size(); ++i) { - auto& slice = des->regions[i]; - auto origin = slice.origin; - if (nullptr == origin /*|| nullptr == origin->host()*/) { - continue; - } - // if tensor is not NC4HW4 or has been merged, don't need deal - if (TensorUtils::getDescribe(origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { - if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - mTempInputCopy.emplace_back(std::make_pair(origin, &slice)); - continue; - } - // if NC4HW4's C%4 == 0, change convert to transpose and fuse it - if (origin->batch() == 1 && origin->channel() % core->pack == 0) { - int channel = origin->channel(); - int area = 1; - // conv3d/pool3d will has 5 dims, area = depth * width * 
height, otherwise area = width * height - for (int d = 2; d < origin->dimensions(); d++) { - area *= origin->length(d); - } - Tensor::InsideDescribe::Region regionTmp; - regionTmp.src.offset = 0; - regionTmp.src.stride[0] = area * core->pack; - regionTmp.src.stride[1] = 1; - regionTmp.src.stride[2] = core->pack; - regionTmp.dst.offset = 0; - regionTmp.dst.stride[0] = area * core->pack; - regionTmp.dst.stride[1] = area; - regionTmp.dst.stride[2] = 1; - regionTmp.size[0] = channel / core->pack; - regionTmp.size[1] = core->pack; - regionTmp.size[2] = area; - regionTmp.origin = slice.origin; - bool merge = fuseUtils.match(regionTmp, slice); - if (merge) { - std::shared_ptr newSlice(new Tensor::InsideDescribe::Region); - *newSlice = slice; - fuseUtils.apply(regionTmp, *newSlice); - // cache the merged tensor - if (newSlice->size[0] * newSlice->size[1] * newSlice->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - mTempInputCopy.emplace_back(std::make_pair(origin, newSlice.get())); - mCacheRegions.emplace_back(newSlice); - continue; - } - } - auto cache = static_cast(backend())->getCache(); - auto tempTensor = cache->findCacheTensor(origin, midFormat); - //MNN_ASSERT(CPUBackend::getBytes(backend(), origin) == 4); - if (nullptr == tempTensor) { - std::shared_ptr newTensor(new Tensor); - TensorUtils::copyShape(origin, newTensor.get()); - TensorUtils::getDescribe(newTensor.get())->dimensionFormat = midFormat; - TensorUtils::getDescribe(newTensor.get())->quantAttr = TensorUtils::getDescribe(origin)->quantAttr; - TensorUtils::getDescribe(newTensor.get())->applyQuant = TensorUtils::getDescribe(origin)->applyQuant;; - newTensor->buffer().type = origin->getType(); - TensorUtils::setLinearLayout(newTensor.get()); - mTempInput.insert(std::make_pair(origin, newTensor.get())); - auto res = backend()->onAcquireBuffer(newTensor.get(), Backend::DYNAMIC); - if (!res) { - return OUT_OF_MEMORY; - } - tempTensor = newTensor.get(); - 
TensorUtils::getDescribe(tempTensor)->useCount = TensorUtils::getDescribe(origin)->useCount; - cache->pushCacheTensor(newTensor, origin, midFormat); - } - if (--TensorUtils::getDescribe(tempTensor)->useCount == 0) { - forRelease.emplace_back(tempTensor); - } - if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - mTempInputCopy.emplace_back(std::make_pair(tempTensor, &slice)); - } - for (auto t : forRelease) { - backend()->onReleaseBuffer(t, Backend::DYNAMIC); - } - if (nullptr != mTempOutput) { - backend()->onReleaseBuffer(mTempOutput.get(), Backend::DYNAMIC); - } - auto threadNumber = static_cast(backend())->threadNumber(); - mHasReduce = false; - ReduceInfo reduceInfo; - for (auto& iter : mTempInputCopy) { - if (reduceInfo.compute(*iter.second)) { - mHasReduce = true; - break; - } - } - if (mTempInputCopy.size() == 1 && threadNumber > 1 && (!mHasReduce)) { - // Split to multi region - auto region = mTempInputCopy[0].second; - if (region->size[0] * region->size[1] * region->size[2] < LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = false; - return NO_ERROR; - } - if (region->size[0] * region->size[1] * region->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - auto tensorPtr = mTempInputCopy[0].first; - int pos = -1; - for (int i=0; i<3; ++i) { - if (region->size[i] > 1) { - pos = i; - break; - } - } - if (-1 == pos) { - // Don't need divide - return NO_ERROR; - } - mTempInputCopy.clear(); - int divSize = UP_DIV(region->size[pos], threadNumber); - for (int i=0; i cacheRegPtr(new Tensor::InsideDescribe::Region); - auto& cacheReg = *cacheRegPtr; - int sta = i * divSize; - int fin = sta + divSize; - fin = std::min(fin, region->size[pos]); - if (fin <= sta) { - break; - } - for (int v=0; v<3; ++v) { - cacheReg.src.stride[v] = region->src.stride[v]; - cacheReg.dst.stride[v] = region->dst.stride[v]; - } - int curSize = fin - sta; - for (int v=0; vsize[v]; - } - cacheReg.size[pos] = curSize; 
- cacheReg.src.offset = region->src.offset + sta * region->src.stride[pos]; - cacheReg.dst.offset = region->dst.offset + sta * region->dst.stride[pos]; - for (int v=pos+1; v<3; ++v) { - cacheReg.size[v] = region->size[v]; - } - mTempInputCopy.emplace_back(std::make_pair(tensorPtr, cacheRegPtr.get())); - mCacheRegions.emplace_back(cacheRegPtr); - } - } - return NO_ERROR; -} static void _transpose(int32_t* dstO, const int32_t* srcO, const Tensor::InsideDescribe::Region& region, int bytes) { int dims[4], keepDim = -1; for (int i = 0; i < 3; i++) { @@ -324,15 +103,12 @@ static void _2BitcopyWithStrideC4(uint8_t* dstO, const uint8_t* srcO, int size, } } -void CPURaster::executeFaster(const std::vector &inputs, const std::vector &outputs) const { +void CPURaster::executeFaster(const std::vector &inputs, const std::vector &outputs) { auto input = inputs[0]; auto output = outputs[0]; auto bytes = CPUBackend::getBytes(backend(), output); auto core = static_cast(backend())->functions(); - auto threadNum = static_cast(backend())->threadNumber(); - if (mNeedZero) { - ::memset(output->host(), mZeroPoint, static_cast(backend())->getTensorSize(output) * bytes); - } + int threadNum = static_cast(backend())->threadNumber(); auto byteC4 = bytes * core->pack; auto C4proc = core->MNN4BitcopyWithStride; switch (byteC4) { @@ -352,7 +128,7 @@ void CPURaster::executeFaster(const std::vector &inputs, const std::ve if (!mUseThreads) { threadNum = 1; } - MNN_CONCURRENCY_BEGIN(tId, threadNum) { + mTasks.emplace_back(std::make_pair([threadNum, this, output, bytes, C4proc, byteC4](int tId) { for (int u=(int)tId; uhost() == nullptr) { @@ -393,8 +169,7 @@ void CPURaster::executeFaster(const std::vector &inputs, const std::ve } } } - } - MNN_CONCURRENCY_END(); + }, threadNum)); } static BlitProc _selectUnitProc(int bytes, int stride, int ds) { @@ -596,97 +371,307 @@ static void _blit(const Tensor::InsideDescribe::Region& slice, int bytes, const } } void CPURaster::tensorConvert(Tensor* input, 
Tensor* output, int bytes) { - auto& subIb = input->buffer(); - auto& subOb = output->buffer(); - auto source = TensorUtils::getDescribe(input)->dimensionFormat; - auto dest = TensorUtils::getDescribe(output)->dimensionFormat; - if (subIb.dimensions <= 1 || source == dest) { - ::memcpy(subOb.host, subIb.host, input->elementSize() * bytes); - return; - } - auto tup = CPUTensorConverter::splitDimensions(subIb, source); - int area = std::get<1>(tup), batch = std::get<0>(tup), channel = std::get<2>(tup); - const int bitLength = bytes; + std::pair, int> task; auto core = static_cast(backend())->functions(); auto threadNumber = static_cast(backend())->threadNumber(); if (!mUseThreads) { threadNumber = 1; } - MNN_CONCURRENCY_BEGIN(tId, threadNumber) { + task.first = [input, output, bytes, threadNumber, core](int tId) { + auto& subIb = input->buffer(); + auto& subOb = output->buffer(); + auto source = TensorUtils::getDescribe(input)->dimensionFormat; + auto dest = TensorUtils::getDescribe(output)->dimensionFormat; + if (subIb.dimensions <= 1 || source == dest) { + ::memcpy(subOb.host, subIb.host, input->elementSize() * bytes); + return; + } + auto tup = CPUTensorConverter::splitDimensions(subIb, source); + int area = std::get<1>(tup), batch = std::get<0>(tup), channel = std::get<2>(tup); + const int bitLength = bytes; CPUTensorConverter::convert(subIb.host, subOb.host, source, dest, batch, area, channel, bitLength, core, tId, threadNumber); }; - MNN_CONCURRENCY_END(); + task.second = threadNumber; + mTasks.emplace_back(task); } - - -ErrorCode CPURaster::onExecute(const std::vector &____inputs, const std::vector &outputs) { - void* mOutputPtr = nullptr; - if (nullptr != mTempOutput) { - mOutputPtr = mTempOutput->host(); - } else { - mOutputPtr = outputs[0]->host(); - } - if (mFast) { - executeFaster(____inputs, outputs); - return NO_ERROR; - } - auto core = static_cast(backend())->functions(); +ErrorCode CPURaster::onResize(const std::vector &____inputs, const std::vector 
&outputs) { + MNN_ASSERT(outputs.size() == 1); auto output = outputs[0]; + OpCommonUtils::rasterInputReset(____inputs, outputs[0]); + auto des = TensorUtils::getDescribe(output); + auto outputDes = TensorUtils::getDescribe(output); + mNeedZero = !TensorUtils::regionIsFull(output); + mZeroPoint = 0; + mUseThreads = false; + int threadNum = static_cast(backend())->threadNumber(); + if (outputDes->quantAttr != nullptr && outputDes->applyQuant) { +#ifdef MNN_USE_SSE + mZeroPoint = (int)outputDes->quantAttr->zero + 128; +#else + mZeroPoint = (int)outputDes->quantAttr->zero; +#endif + } size_t bytes = (size_t)(CPUBackend::getBytes(backend(), output)); - auto outputEleSize = static_cast(backend())->getTensorSize(output); - auto threadNum = static_cast(backend())->threadNumber(); - if (mSingleConvert.type > 0) { - auto realInput = ____inputs[0]; - int srcBatch = mSingleConvert.batch, srcChannel = mSingleConvert.channel, srcArea = mSingleConvert.area; - auto sourceFormat = TensorUtils::getDescribe(realInput)->dimensionFormat; - auto destFormat = TensorUtils::getDescribe(output)->dimensionFormat; - auto channelC4 = UP_DIV(srcChannel, core->pack); - auto batchStrideC4 = channelC4 * core->pack * srcArea * bytes; - auto batchStride = srcChannel * srcArea * bytes; - auto inputBatchStride = batchStride; - auto outputBatchStride = batchStride; - if (MNN_DATA_FORMAT_NC4HW4 == sourceFormat) { - if (realInput->dimensions() <= 1) { - ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); - return NO_ERROR; + mTempInput.clear(); + mFastBlit.clear(); + mCacheRegions.clear(); + mTempOutput = nullptr; + mTasks.clear(); + auto midFormat = MNN_DATA_FORMAT_NCHW; + mTempInputCopy.clear(); + mFast = false; + auto core = static_cast(backend())->functions(); + mSingleConvert.type = 0; + // all_srcFormat == dstFormat == NC4HW4 : Fast Exe + if (outputDes->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) { + mFast = true; + for (int i=0; i< des->regions.size(); ++i) { + auto& 
slice = des->regions[i]; + if (TensorUtils::getDescribe(slice.origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { + mFast = false; + break; } - inputBatchStride = batchStrideC4; - if (2 == mSingleConvert.type) { - destFormat = MNN_DATA_FORMAT_NHWC; - } else { - destFormat = MNN_DATA_FORMAT_NCHW; + if (!OpCommonUtils::canBlitFast(slice, output, core->pack, true)) { + mFast = false; + break; } - } else if (MNN_DATA_FORMAT_NC4HW4 == destFormat) { - if (output->dimensions() <= 1) { - ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); - return NO_ERROR; + } + if (mFast) { + mUseThreads = des->regions.size() > 16 ? true : false; + for (int i=0; i< des->regions.size(); ++i) { + auto& slice = des->regions[i]; + if (slice.origin == nullptr) { + continue; + } + Tensor::InsideDescribe::Region newRegion; + OpCommonUtils::turnToPackRegion(slice, newRegion, output, core->pack, true); + mFastBlit.emplace_back(std::make_pair(slice.origin, std::move(newRegion))); } - outputBatchStride = batchStrideC4; - if (2 == mSingleConvert.type) { - sourceFormat = MNN_DATA_FORMAT_NHWC; - } else { - sourceFormat = MNN_DATA_FORMAT_NCHW; + executeFaster(____inputs, outputs); + return NO_ERROR; + } + } + // srcNum == 1 && srcFormat != dstFormat : Single Convert + if (des->regions.size() == 1) { + OpCommonUtils::turnRegion2Convert(des->regions[0], output, mSingleConvert); + if (mSingleConvert.type > 0) { + std::pair, int> task; + mUseThreads = (mSingleConvert.batch * mSingleConvert.channel * mSingleConvert.area > LAUNCH_MULTI_THREADS_WORKLOAD) ? 
true : false; + auto realInput = ____inputs[0]; + int srcBatch = mSingleConvert.batch, srcChannel = mSingleConvert.channel, srcArea = mSingleConvert.area; + auto sourceFormat = TensorUtils::getDescribe(realInput)->dimensionFormat; + auto destFormat = TensorUtils::getDescribe(output)->dimensionFormat; + auto channelC4 = UP_DIV(srcChannel, core->pack); + auto batchStrideC4 = channelC4 * core->pack * srcArea * bytes; + auto batchStride = srcChannel * srcArea * bytes; + auto inputBatchStride = batchStride; + auto outputBatchStride = batchStride; + if (MNN_DATA_FORMAT_NC4HW4 == sourceFormat) { + if (realInput->dimensions() <= 1) { + task.first = [output, realInput, bytes](int tId) { + ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); + }; + task.second = 1; + mTasks.emplace_back(task); + return NO_ERROR; + } + inputBatchStride = batchStrideC4; + if (2 == mSingleConvert.type) { + destFormat = MNN_DATA_FORMAT_NHWC; + } else { + destFormat = MNN_DATA_FORMAT_NCHW; + } + } else if (MNN_DATA_FORMAT_NC4HW4 == destFormat) { + if (output->dimensions() <= 1) { + task.first = [output, realInput, bytes](int tId) { + ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); + }; + task.second = 1; + mTasks.emplace_back(task); + return NO_ERROR; + } + outputBatchStride = batchStrideC4; + if (2 == mSingleConvert.type) { + sourceFormat = MNN_DATA_FORMAT_NHWC; + } else { + sourceFormat = MNN_DATA_FORMAT_NCHW; + } + } + if (!mUseThreads) { + threadNum = 1; } + task.first = [realInput, output, sourceFormat, destFormat, srcBatch, srcArea, srcChannel, bytes, core, threadNum](int tId) { + CPUTensorConverter::convert(realInput->host(), output->host(), sourceFormat, destFormat, srcBatch, srcArea, srcChannel, bytes, core, tId, threadNum); + }; + task.second = threadNum; + mTasks.emplace_back(task); + return NO_ERROR; } - if (!mUseThreads) { - threadNum = 1; + } + // Acquire Buffer for temp output + // TODO: optimize it + if 
(MNN_DATA_FORMAT_NC4HW4 == outputDes->dimensionFormat) { + mTempOutput.reset(new Tensor); + TensorUtils::setupTensorInfo(output, mTempOutput.get(), midFormat); + } + if (nullptr != mTempOutput) { + auto res = backend()->onAcquireBuffer(mTempOutput.get(), Backend::DYNAMIC); + if (!res) { + return OUT_OF_MEMORY; } - MNN_CONCURRENCY_BEGIN(tId, threadNum) { - CPUTensorConverter::convert(realInput->host(), output->host(), sourceFormat, destFormat, srcBatch, srcArea, srcChannel, bytes, core, tId, threadNum); - }; - MNN_CONCURRENCY_END(); - return NO_ERROR; } - if (mNeedZero) { - if (mTempOutput == nullptr) { - ::memset(output->host(), mZeroPoint, outputEleSize * bytes); - } else { - ::memset(mTempOutput->host(), mZeroPoint, mTempOutput->elementSize() * bytes); + // input is NC4HW4 add Convert + std::vector forRelease; + TensorUtils::FuseWrap fuseUtils; + for (int i=0; i< des->regions.size(); ++i) { + auto& slice = des->regions[i]; + auto origin = slice.origin; + if (nullptr == origin /*|| nullptr == origin->host()*/) { + continue; + } + // if tensor is not NC4HW4 or has been merged, don't need deal + if (TensorUtils::getDescribe(origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { + if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + mTempInputCopy.emplace_back(std::make_pair(origin, &slice)); + continue; + } + // if NC4HW4's C%4 == 0, change convert to transpose and fuse it + if (origin->batch() == 1 && origin->channel() % core->pack == 0) { + int channel = origin->channel(); + int area = 1; + // conv3d/pool3d will has 5 dims, area = depth * width * height, otherwise area = width * height + for (int d = 2; d < origin->dimensions(); d++) { + area *= origin->length(d); + } + Tensor::InsideDescribe::Region regionTmp; + regionTmp.src.offset = 0; + regionTmp.src.stride[0] = area * core->pack; + regionTmp.src.stride[1] = 1; + regionTmp.src.stride[2] = core->pack; + regionTmp.dst.offset = 0; + 
regionTmp.dst.stride[0] = area * core->pack; + regionTmp.dst.stride[1] = area; + regionTmp.dst.stride[2] = 1; + regionTmp.size[0] = channel / core->pack; + regionTmp.size[1] = core->pack; + regionTmp.size[2] = area; + regionTmp.origin = slice.origin; + bool merge = fuseUtils.match(regionTmp, slice); + if (merge) { + std::shared_ptr newSlice(new Tensor::InsideDescribe::Region); + *newSlice = slice; + fuseUtils.apply(regionTmp, *newSlice); + // cache the merged tensor + if (newSlice->size[0] * newSlice->size[1] * newSlice->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + mTempInputCopy.emplace_back(std::make_pair(origin, newSlice.get())); + mCacheRegions.emplace_back(newSlice); + continue; + } + } + auto cache = static_cast(backend())->getCache(); + auto tempTensor = cache->findCacheTensor(origin, midFormat); + //MNN_ASSERT(CPUBackend::getBytes(backend(), origin) == 4); + if (nullptr == tempTensor) { + std::shared_ptr newTensor(new Tensor); + TensorUtils::copyShape(origin, newTensor.get()); + TensorUtils::getDescribe(newTensor.get())->dimensionFormat = midFormat; + TensorUtils::getDescribe(newTensor.get())->quantAttr = TensorUtils::getDescribe(origin)->quantAttr; + TensorUtils::getDescribe(newTensor.get())->applyQuant = TensorUtils::getDescribe(origin)->applyQuant;; + newTensor->buffer().type = origin->getType(); + TensorUtils::setLinearLayout(newTensor.get()); + mTempInput.insert(std::make_pair(origin, newTensor.get())); + auto res = backend()->onAcquireBuffer(newTensor.get(), Backend::DYNAMIC); + if (!res) { + return OUT_OF_MEMORY; + } + tempTensor = newTensor.get(); + TensorUtils::getDescribe(tempTensor)->useCount = TensorUtils::getDescribe(origin)->useCount; + cache->pushCacheTensor(newTensor, origin, midFormat); } + if (--TensorUtils::getDescribe(tempTensor)->useCount == 0) { + forRelease.emplace_back(tempTensor); + } + if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + 
mTempInputCopy.emplace_back(std::make_pair(tempTensor, &slice)); + } + for (auto t : forRelease) { + backend()->onReleaseBuffer(t, Backend::DYNAMIC); } + if (nullptr != mTempOutput) { + backend()->onReleaseBuffer(mTempOutput.get(), Backend::DYNAMIC); + } + auto threadNumber = static_cast(backend())->threadNumber(); + mHasReduce = false; + ReduceInfo reduceInfo; + for (auto& iter : mTempInputCopy) { + if (reduceInfo.compute(*iter.second)) { + mHasReduce = true; + break; + } + } + // Encode convert for (auto& iter : mTempInput) { tensorConvert(iter.first, iter.second, (int)bytes); } + do { + if (mTempInputCopy.size() == 1 && threadNumber > 1 && (!mHasReduce)) { + // Split to multi region + auto region = mTempInputCopy[0].second; + if (region->size[0] * region->size[1] * region->size[2] < LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = false; + break; + } + if (region->size[0] * region->size[1] * region->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + auto tensorPtr = mTempInputCopy[0].first; + int pos = -1; + for (int i=0; i<3; ++i) { + if (region->size[i] > 1) { + pos = i; + break; + } + } + if (-1 == pos) { + // Don't need divide + break; + } + mTempInputCopy.clear(); + int divSize = UP_DIV(region->size[pos], threadNumber); + for (int i=0; i cacheRegPtr(new Tensor::InsideDescribe::Region); + auto& cacheReg = *cacheRegPtr; + int sta = i * divSize; + int fin = sta + divSize; + fin = std::min(fin, region->size[pos]); + if (fin <= sta) { + break; + } + for (int v=0; v<3; ++v) { + cacheReg.src.stride[v] = region->src.stride[v]; + cacheReg.dst.stride[v] = region->dst.stride[v]; + } + int curSize = fin - sta; + for (int v=0; vsize[v]; + } + cacheReg.size[pos] = curSize; + cacheReg.src.offset = region->src.offset + sta * region->src.stride[pos]; + cacheReg.dst.offset = region->dst.offset + sta * region->dst.stride[pos]; + for (int v=pos+1; v<3; ++v) { + cacheReg.size[v] = region->size[v]; + } + mTempInputCopy.emplace_back(std::make_pair(tensorPtr, 
cacheRegPtr.get())); + mCacheRegions.emplace_back(cacheRegPtr); + } + } + } while (false); if (mHasReduce) { // Don't support reduce with multi thread now threadNum = 1; @@ -700,8 +685,13 @@ ErrorCode CPURaster::onExecute(const std::vector &____inputs, const st if (outputDescribe->overlap) { threadNum = 1; } - - MNN_CONCURRENCY_BEGIN(tId, threadNum) { + mTasks.emplace_back(std::make_pair([this, threadNum, output, bytes, core](int tId){ + void* mOutputPtr = nullptr; + if (nullptr != mTempOutput) { + mOutputPtr = mTempOutput->host(); + } else { + mOutputPtr = output->host(); + } for (int u=tId; u &____inputs, const st auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes; _blit(slice, (int)bytes, srcPtr, dstPtr, mHasReduce, core->MNNLowpToFp32, core->MNNFp32ToLowp); } - } - MNN_CONCURRENCY_END(); + }, threadNum)); if (nullptr != mTempOutput) { tensorConvert(mTempOutput.get(), output, (int)bytes); } return NO_ERROR; } + + +ErrorCode CPURaster::onExecute(const std::vector &____inputs, const std::vector &outputs) { + void* mOutputPtr = nullptr; + if (nullptr != mTempOutput) { + mOutputPtr = mTempOutput->host(); + } else { + mOutputPtr = outputs[0]->host(); + } + auto core = static_cast(backend())->functions(); + auto output = outputs[0]; + size_t bytes = (size_t)(CPUBackend::getBytes(backend(), output)); + auto outputEleSize = static_cast(backend())->getTensorSize(output); + auto threadNum = static_cast(backend())->threadNumber(); + if (mNeedZero) { + if (mTempOutput == nullptr) { + ::memset(output->host(), mZeroPoint, outputEleSize * bytes); + } else { + ::memset(mTempOutput->host(), mZeroPoint, mTempOutput->elementSize() * bytes); + } + } + for (auto& task : mTasks) { + MNN_CONCURRENCY_ENQUEUE(task); + } + return NO_ERROR; +} class CPULoop : public Execution { public: struct ThreadContainer { diff --git a/source/backend/cpu/CPURaster.hpp b/source/backend/cpu/CPURaster.hpp index 9df10700bd..bff149df52 100644 --- a/source/backend/cpu/CPURaster.hpp +++ 
b/source/backend/cpu/CPURaster.hpp @@ -24,7 +24,7 @@ class CPURaster : public Execution { virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; - void executeFaster(const std::vector &inputs, const std::vector &outputs) const; + void executeFaster(const std::vector &inputs, const std::vector &outputs); void tensorConvert(Tensor* input, Tensor* output, int bytes); private: std::map mTempInput; @@ -38,6 +38,7 @@ class CPURaster : public Execution { int32_t mZeroPoint = 0; bool mHasReduce = false; bool mUseThreads = false; + std::vector, int>> mTasks; }; } #endif diff --git a/source/backend/cpu/ThreadPool.cpp b/source/backend/cpu/ThreadPool.cpp index 15a2d8241c..d7765c4fbc 100644 --- a/source/backend/cpu/ThreadPool.cpp +++ b/source/backend/cpu/ThreadPool.cpp @@ -60,7 +60,7 @@ ThreadPool::ThreadPool(int numberThread) { while (mActiveCount > 0) { for (int i = 0; i < MNN_THREAD_POOL_MAX_TASKS; ++i) { if (*mTasks[i].second[threadIndex]) { - mTasks[i].first.first(threadIndex); + mTasks[i].first->first(threadIndex); { *mTasks[i].second[threadIndex] = false; } } } @@ -118,16 +118,18 @@ void ThreadPool::deactive() { mActiveCount--; } -void ThreadPool::enqueue(TASK&& task, int index) { +void ThreadPool::enqueue(TASK* taskp, int index) { + auto& task = *taskp; if (1 >= task.second || 0 > index) { for (int i = 0; i < task.second; ++i) { task.first(i); } return; } - enqueueInternal(std::move(task), index); + enqueueInternal(taskp, index); } -void ThreadPool::enqueueInternal(TASK&& task, int index) { +void ThreadPool::enqueueInternal(TASK* taskp, int index) { + auto& task = *taskp; if (mActiveCount == 0) { for (int i = 0; i < task.second; ++i) { task.first(i); @@ -135,24 +137,25 @@ void ThreadPool::enqueueInternal(TASK&& task, int index) { return; } int workSize = task.second; + TASK* tmpTask = nullptr; if (workSize > mNumberThread) { - 
mTasks[index].first = std::make_pair( - [workSize, &task, this](int tId) { - for (int v = tId; v < workSize; v += mNumberThread) { - task.first(v); - } - }, - mNumberThread); + tmpTask = new TASK; + *tmpTask = std::make_pair([workSize, &task, this](int tId) { + for (int v = tId; v < workSize; v += mNumberThread) { + task.first(v); + } + }, mNumberThread); + mTasks[index].first = tmpTask; workSize = mNumberThread; } else { - mTasks[index].first = std::move(task); + mTasks[index].first = taskp; } { for (int i = 1; i < workSize; ++i) { *mTasks[index].second[i] = true; } } - mTasks[index].first.first(0); + mTasks[index].first->first(0); bool complete = true; do { complete = true; @@ -165,6 +168,9 @@ void ThreadPool::enqueueInternal(TASK&& task, int index) { std::this_thread::yield(); // FUNC_PRINT(notComplete); } while (!complete); + if (nullptr != tmpTask) { + delete tmpTask; + } } } // namespace MNN #endif diff --git a/source/backend/cpu/ThreadPool.hpp b/source/backend/cpu/ThreadPool.hpp index 4bf23de1b0..8891da61b1 100644 --- a/source/backend/cpu/ThreadPool.hpp +++ b/source/backend/cpu/ThreadPool.hpp @@ -25,7 +25,7 @@ class MNN_PUBLIC ThreadPool { int numberThread() const { return mNumberThread; } - void enqueue(TASK&& task, int index); + void enqueue(TASK* task, int index); void active(); void deactive(); @@ -37,7 +37,7 @@ class MNN_PUBLIC ThreadPool { static void destroy(); private: - void enqueueInternal(TASK&& task, int index); + void enqueueInternal(TASK* task, int index); ThreadPool(int numberThread = 0); ~ThreadPool(); @@ -46,7 +46,7 @@ class MNN_PUBLIC ThreadPool { std::vector mTaskAvailable; std::atomic mStop = {false}; - std::vector>> mTasks; + std::vector>> mTasks; std::condition_variable mCondition; std::mutex mQueueMutex; diff --git a/source/core/Concurrency.h b/source/core/Concurrency.h index 73f5984e5a..7c06625fe4 100644 --- a/source/core/Concurrency.h +++ b/source/core/Concurrency.h @@ -12,6 +12,9 @@ #define LAUNCH_MULTI_THREADS_WORKLOAD 1e+5 #ifdef 
MNN_FORBIT_MULTI_THREADS +#define MNN_CONCURRENCY_ENQUEUE(task) \ +for (int __iter__ = 0; __iter__ < task.second; __iter__++) {task.first(__iter__);} + #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) for (int __iter__ = 0; __iter__ < __num__; __iter__++) { #define MNN_CONCURRENCY_END() } @@ -19,6 +22,8 @@ #include "backend/cpu/ThreadPool.hpp" #define MNN_STRINGIFY(a) #a +#define MNN_CONCURRENCY_ENQUEUE(task) ((CPUBackend*)backend())->enqueue(task) + #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) \ { \ std::pair, int> task; \ @@ -28,8 +33,7 @@ } \ ; \ auto cpuBn = (CPUBackend*)backend(); \ - auto thrPl = cpuBn->threadPool(); \ - thrPl->enqueue(std::move(task), cpuBn->taskIndex()); \ + cpuBn->enqueue(task); \ } #else @@ -38,6 +42,9 @@ #include #include +#define MNN_CONCURRENCY_ENQUEUE(task) \ +dispatch_apply(task.second, dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH, 0), ^(size_t __iter__) {task.first(__iter__);}); + #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) \ dispatch_apply(__num__, dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH, 0), ^(size_t __iter__) { #define MNN_CONCURRENCY_END() \ @@ -58,6 +65,8 @@ dispatch_apply(__num__, dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH, // Android #else #include +#define MNN_CONCURRENCY_ENQUEUE(task) \ +_Pragma("omp parallel for") for (int __iter__ = 0; __iter__ < task.second; __iter__++) {task.first(__iter__);} #define MNN_STRINGIFY(a) #a #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) \ diff --git a/test/core/ThreadPoolTest.cpp b/test/core/ThreadPoolTest.cpp index 6886f86e62..e010939e5f 100644 --- a/test/core/ThreadPoolTest.cpp +++ b/test/core/ThreadPoolTest.cpp @@ -26,11 +26,11 @@ class ThreadPoolTest : public MNNTestCase { auto workIndex = threadPool->acquireWorkIndex(); FUNC_PRINT(workIndex); threadPool->active(); - auto func = [](int index) { + ThreadPool::TASK task = std::make_pair([](int index) { FUNC_PRINT(index); std::this_thread::yield(); - }; - 
threadPool->enqueue(std::make_pair(std::move(func), 10), workIndex); + }, 10); + threadPool->enqueue(&task, workIndex); threadPool->deactive(); threadPool->releaseWorkIndex(workIndex); }); From 6a80bbeb5e5b771b525c384212e785534eb9bfff Mon Sep 17 00:00:00 2001 From: jxt1234 Date: Mon, 22 Dec 2025 11:26:37 +0800 Subject: [PATCH 039/314] Geometry:Speed: Ref regions instead of make copy, some case can't merge, Remove copy when reduce single axis --- source/core/OpCommonUtils.cpp | 91 ------------------- source/core/OpCommonUtils.hpp | 1 - source/core/TensorUtils.cpp | 12 +++ source/core/TensorUtils.hpp | 1 + source/geometry/GeometryComputerUtils.cpp | 4 +- source/geometry/GeometryComputerUtils.hpp | 2 +- source/geometry/GeometryReduce.cpp | 104 +++++++++++++++++++++- source/geometry/GeometryReshape.cpp | 11 +-- 8 files changed, 122 insertions(+), 104 deletions(-) diff --git a/source/core/OpCommonUtils.cpp b/source/core/OpCommonUtils.cpp index c80afaef87..a69263ffaa 100644 --- a/source/core/OpCommonUtils.cpp +++ b/source/core/OpCommonUtils.cpp @@ -386,98 +386,7 @@ void OpCommonUtils::broastCastComputeDim(int* dims, int* stride, int* iStride0, } } } -std::vector> OpCommonUtils::computeReduceDims(const std::vector& inputs, - const Op* op) { - // Compute axises - std::vector axises; - if (inputs.size() >= 2) { - auto size = inputs[1]->elementSize(); - auto dims = inputs[1]->host(); - for (int i = 0; i < size; ++i) { - axises.emplace_back(dims[i]); - } - } else { - auto reduct = op->main_as_ReductionParam(); - if (nullptr != reduct->dim()) { - for (int i = 0; i < reduct->dim()->size(); ++i) { - axises.emplace_back(reduct->dim()->data()[i]); - } - } - } - auto totalSize = TensorUtils::getRawSize(inputs[0]); - if (axises.empty()) { - return {std::make_tuple(1, totalSize, 1)}; - } - for (int i = 0; i < axises.size(); ++i) { - if (axises[i] < 0) { - axises[i] = inputs[0]->dimensions() + axises[i]; - if (axises[i] < 0) { - return {std::make_tuple(1, totalSize, 1)}; - } - } - } - 
// Cache for input's dims - std::vector lengths(inputs[0]->dimensions()); - for (int i = 0; i < lengths.size(); ++i) { - lengths[i] = inputs[0]->length(i); - } - std::vector> groupAxises; - { - // Merge adj axis - std::sort(axises.begin(), axises.end()); - int lastAxis = axises[0]; - int length = 1; - int start = axises[0]; - for (int i = 1; i < axises.size(); ++i) { - // MNN_PRINT("%d - %d\n", axises[i], lastAxis); - if (axises[i] - lastAxis == 1) { - length++; - } else { - groupAxises.emplace_back(std::make_pair(start, length)); - length = 1; - start = axises[i]; - } - lastAxis = axises[i]; - } - groupAxises.emplace_back(std::make_pair(start, length)); - } - - // Compute inside-outside-axis - std::vector> result; - for (int i = 0; i < groupAxises.size(); ++i) { - int outsideSize = 1; - int insideSize = 1; - int axisSize = 1; - auto start = groupAxises[i].first; - auto length = groupAxises[i].second; - if (start >= (int)lengths.size()) { - break; - } - for (int j = 0; j < start; ++j) { - outsideSize *= lengths[j]; - } - for (int j = start; j < start + length; ++j) { - if (j >= (int)lengths.size()) { - break; - } - axisSize *= lengths[j]; - lengths[j] = 1; - } - for (int j = start + length; j < lengths.size(); ++j) { - insideSize *= lengths[j]; - } - if (1 == axisSize) { - continue; - } - result.emplace_back(std::make_tuple(outsideSize, axisSize, insideSize)); - } - // FUNC_PRINT(result.size()); - if (result.empty()) { - result.emplace_back(std::make_tuple(1, 1, totalSize)); - } - return result; -} void OpCommonUtils::unravelIndexHelper(int32_t* coordinate, const int32_t* mod, int size, int indice) { int value = indice; diff --git a/source/core/OpCommonUtils.hpp b/source/core/OpCommonUtils.hpp index 0740cc16b2..8ec0628336 100644 --- a/source/core/OpCommonUtils.hpp +++ b/source/core/OpCommonUtils.hpp @@ -56,7 +56,6 @@ class MNN_PUBLIC OpCommonUtils { static bool supportDynamicInputMemory(MNNForwardType type); static void broastCastComputeDim(int* dims, int* stride, 
int* iStride0, int* iStride1, const Tensor* input0, const Tensor* input1, const Tensor* output); - static std::vector> computeReduceDims(const std::vector& inputs, const Op* op); static void unravelIndexHelper(int32_t* coordinate, const int32_t* mod, int size, int indice); static int computeStride(int32_t* strides, const int* shape, int length); diff --git a/source/core/TensorUtils.cpp b/source/core/TensorUtils.cpp index ae5b87143c..d233fc9d89 100644 --- a/source/core/TensorUtils.cpp +++ b/source/core/TensorUtils.cpp @@ -32,6 +32,18 @@ bool TensorUtils::regionIsFull(Tensor* input) { return regionSize == size; } +void TensorUtils::makeFullRef(Tensor* output, Tensor* input) { + auto des = TensorUtils::getDescribe(input); + auto outputDes = TensorUtils::getDescribe(output); + outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; + if (des->memoryType == Tensor::InsideDescribe::MEMORY_VIRTUAL) { + outputDes->regions = des->regions; + } else { + outputDes->regions = {makeFullSlice(input)}; + } +} + + Tensor::InsideDescribe::Region TensorUtils::makeFullSlice(Tensor* input) { Tensor::InsideDescribe::Region totalSlice; totalSlice.src.offset = 0; diff --git a/source/core/TensorUtils.hpp b/source/core/TensorUtils.hpp index 1342a669bd..a577fea05f 100644 --- a/source/core/TensorUtils.hpp +++ b/source/core/TensorUtils.hpp @@ -184,6 +184,7 @@ class MNN_PUBLIC TensorUtils { static void setupTensorInfo(const Tensor* tensor, Tensor* wrapTensor, MNN_DATA_FORMAT mMidFormat); static Tensor::InsideDescribe::Region makeFullSlice(Tensor* input); + static void makeFullRef(Tensor* output, Tensor* input); static bool regionIsFull(Tensor* input); static bool isCopyRegion(const Tensor::InsideDescribe::Region& region); static bool isTransposeRegion(const Tensor::InsideDescribe::Region& region); diff --git a/source/geometry/GeometryComputerUtils.cpp b/source/geometry/GeometryComputerUtils.cpp index 01a4e02ea2..85f64de55d 100644 --- a/source/geometry/GeometryComputerUtils.cpp +++ 
b/source/geometry/GeometryComputerUtils.cpp @@ -477,9 +477,9 @@ std::shared_ptr GeometryComputerUtils::makeBinary(int type, Tensor* inp return cmdP; } -std::shared_ptr GeometryComputerUtils::makeReduce(ReductionType type, Tensor* input0, Tensor* output) { +std::shared_ptr GeometryComputerUtils::makeReduce(ReductionType type, Tensor* input0, Tensor* output, int axis) { flatbuffers::FlatBufferBuilder builder(DEFAULT_ALLOCATE_SIZE); - auto vec = builder.CreateVector(std::vector{1}); + auto vec = builder.CreateVector(std::vector{axis}); ReductionParamBuilder builder_(builder); builder_.add_operation(type); builder_.add_keepDims(true); diff --git a/source/geometry/GeometryComputerUtils.hpp b/source/geometry/GeometryComputerUtils.hpp index c0dffdcdb1..97c4d5811f 100644 --- a/source/geometry/GeometryComputerUtils.hpp +++ b/source/geometry/GeometryComputerUtils.hpp @@ -18,7 +18,7 @@ class GeometryComputerUtils { static void addConvert(const CommandBuffer& srcBuffer, CommandBuffer& dstBuffer, GeometryComputer::Context& ctx); static std::shared_ptr makeCommand(flatbuffers::FlatBufferBuilder& builder, const std::vector& inputs, const std::vector& outputs); static std::shared_ptr makeBinary(int type, Tensor* input0, Tensor* input1, Tensor* output); - static std::shared_ptr makeReduce(ReductionType type, Tensor* input0, Tensor* output); + static std::shared_ptr makeReduce(ReductionType type, Tensor* input0, Tensor* output, int axis = 1); static std::shared_ptr makeUnary(UnaryOpOperation type, Tensor* input0, Tensor* output); static std::shared_ptr makeLayerNorm(Tensor* input0, Tensor* output, std::vector axis, float epsilon, std::vector gamma, std::vector beta, std::vector external, int group = 1, bool useRMS = false); static std::shared_ptr makeMatMul(Tensor* input0, Tensor* input1, Tensor* output, Tensor* Bias = nullptr, diff --git a/source/geometry/GeometryReduce.cpp b/source/geometry/GeometryReduce.cpp index c2a3bb4114..855f4bcf69 100644 --- 
a/source/geometry/GeometryReduce.cpp +++ b/source/geometry/GeometryReduce.cpp @@ -10,6 +10,83 @@ #include "geometry/GeometryComputerUtils.hpp" #include "core/OpCommonUtils.hpp" namespace MNN { +static std::vector> _computeReduceDims(const std::vector& inputs, + std::vector& axises) { + + auto totalSize = TensorUtils::getRawSize(inputs[0]); + if (axises.empty()) { + return {std::make_tuple(1, totalSize, 1)}; + } + for (int i = 0; i < axises.size(); ++i) { + if (axises[i] < 0) { + if (axises[i] < 0) { + return {std::make_tuple(1, totalSize, 1)}; + } + } + } + // Cache for input's dims + std::vector lengths(inputs[0]->dimensions()); + for (int i = 0; i < lengths.size(); ++i) { + lengths[i] = inputs[0]->length(i); + } + std::vector> groupAxises; + { + // Merge adj axis + std::sort(axises.begin(), axises.end()); + int lastAxis = axises[0]; + int length = 1; + int start = axises[0]; + for (int i = 1; i < axises.size(); ++i) { + // MNN_PRINT("%d - %d\n", axises[i], lastAxis); + if (axises[i] - lastAxis == 1) { + length++; + } else { + groupAxises.emplace_back(std::make_pair(start, length)); + length = 1; + start = axises[i]; + } + lastAxis = axises[i]; + } + groupAxises.emplace_back(std::make_pair(start, length)); + } + + // Compute inside-outside-axis + std::vector> result; + + for (int i = 0; i < groupAxises.size(); ++i) { + int outsideSize = 1; + int insideSize = 1; + int axisSize = 1; + auto start = groupAxises[i].first; + auto length = groupAxises[i].second; + if (start >= (int)lengths.size()) { + break; + } + for (int j = 0; j < start; ++j) { + outsideSize *= lengths[j]; + } + for (int j = start; j < start + length; ++j) { + if (j >= (int)lengths.size()) { + break; + } + axisSize *= lengths[j]; + lengths[j] = 1; + } + for (int j = start + length; j < lengths.size(); ++j) { + insideSize *= lengths[j]; + } + if (1 == axisSize) { + continue; + } + result.emplace_back(std::make_tuple(outsideSize, axisSize, insideSize)); + } + // FUNC_PRINT(result.size()); + if 
(result.empty()) { + result.emplace_back(std::make_tuple(1, 1, totalSize)); + } + return result; +} + class GeometryReduce : public GeometryComputer { public: virtual bool onCompute(const Op* op, const std::vector& inputs, const std::vector& outputs, @@ -18,6 +95,31 @@ class GeometryReduce : public GeometryComputer { MNN_ASSERT(inputs.size() >= 1); auto reduct = op->main_as_ReductionParam(); auto reductOp = reduct->operation(); + std::vector axises; + if (inputs.size() >= 2) { + auto size = inputs[1]->elementSize(); + auto dims = inputs[1]->host(); + for (int i = 0; i < size; ++i) { + axises.emplace_back(dims[i]); + } + } else { + auto reduct = op->main_as_ReductionParam(); + if (nullptr != reduct->dim()) { + for (int i = 0; i < reduct->dim()->size(); ++i) { + axises.emplace_back(reduct->dim()->data()[i]); + } + } + } + for (int i = 0; i < axises.size(); ++i) { + if (axises[i] < 0) { + axises[i] = inputs[0]->dimensions() + axises[i]; + } + } + if (1 == axises.size() && TensorUtils::getDescribe(inputs[0])->dimensionFormat != MNN_DATA_FORMAT_NC4HW4 && TensorUtils::getDescribe(outputs[0])->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { + auto cmd = GeometryComputerUtils::makeReduce(reductOp, inputs[0], outputs[0], axises[0]); + res.command.emplace_back(std::move(cmd)); + return true; + } // prod([]) = 1 if (inputs[0]->elementSize() == 0) { if(!context.allocTensor(outputs[0])) { @@ -39,7 +141,7 @@ class GeometryReduce : public GeometryComputer { } return true; } - auto reduceDims = OpCommonUtils::computeReduceDims(inputs, op); + auto reduceDims = _computeReduceDims(inputs, axises); Tensor* currentInput = inputs[0]; MNN_ASSERT(reduceDims.size() > 0); auto dimType = currentInput->getDimensionType(); diff --git a/source/geometry/GeometryReshape.cpp b/source/geometry/GeometryReshape.cpp index 88d98a24c9..1df3384e37 100644 --- a/source/geometry/GeometryReshape.cpp +++ b/source/geometry/GeometryReshape.cpp @@ -42,8 +42,7 @@ class GeometryReshape : public GeometryComputer { 
return true; } } - outputDes->regions = {TensorUtils::makeFullSlice(input)}; - outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; + TensorUtils::makeFullRef(output, input); return true; } }; @@ -75,10 +74,7 @@ class SingleGeometryComputer : public GeometryComputer { Context& context, CommandBuffer& res) const override { auto input = inputs[0]; auto output = outputs[0]; - auto inputDes = TensorUtils::getDescribe(input); - auto outputDes = TensorUtils::getDescribe(output); - outputDes->regions = {TensorUtils::makeFullSlice(input)}; - outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; + TensorUtils::makeFullRef(output, input); return true; } }; @@ -94,8 +90,7 @@ class CopyGeometryComputer : public GeometryComputer { outputDes->tensorArrayAttr = inputDes->tensorArrayAttr; return true; } - outputDes->regions = {TensorUtils::makeFullSlice(input)}; - outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; + TensorUtils::makeFullRef(output, input); } return true; } From b7a61d2b81f10994543ccc298a0bb5bbc2bedc4a Mon Sep 17 00:00:00 2001 From: jxt1234 Date: Mon, 22 Dec 2025 15:22:34 +0800 Subject: [PATCH 040/314] MNN:Speed: Unroll and use fma opt H_1 --- .../backend/cpu/compute/CommonOptFunction.cpp | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/source/backend/cpu/compute/CommonOptFunction.cpp b/source/backend/cpu/compute/CommonOptFunction.cpp index 9abefb6df4..c9bfcc2189 100644 --- a/source/backend/cpu/compute/CommonOptFunction.cpp +++ b/source/backend/cpu/compute/CommonOptFunction.cpp @@ -4031,14 +4031,33 @@ void MNNComputeMatMulForH_1(const float* A, const float* B, float* C, const floa if (nullptr != biasPtr) { biasValue = *biasPtr; } - auto lC4 = l / 4; - auto lR = lC4 * 4; + auto lC4 = l / 16; + auto lRO = lC4 * 16; for (int y=tId; y= 8) { + sumValue = Vec::fma(sumValue, Vec4::load(srcY + lR), Vec4::load(B + lR)); + sum1 = Vec::fma(sum1, Vec4::load(srcY + lR + 4), Vec4::load(B + lR + 4)); + lR += 
8; + } + if (l - lR >= 4) { + sumValue = Vec::fma(sumValue, Vec4::load(srcY + lR), Vec4::load(B + lR)); + lR += 4; + } + sum2 = sum2 + sum3; + sumValue = sumValue + sum1; + sumValue = sumValue + sum2; float sumSingle = sumValue[0] + sumValue[1] + sumValue[2] + sumValue[3]; for (int x=lR; x Date: Mon, 22 Dec 2025 15:55:30 +0800 Subject: [PATCH 041/314] MNN:Speed: Optimize CPU Binary Broadcast Compute --- source/backend/cpu/CPURaster.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/source/backend/cpu/CPURaster.cpp b/source/backend/cpu/CPURaster.cpp index f64dafced3..1339089347 100644 --- a/source/backend/cpu/CPURaster.cpp +++ b/source/backend/cpu/CPURaster.cpp @@ -1081,7 +1081,15 @@ class CPULoop : public Execution { auto stride2 = cmd->view()->GetAs(2)->stride()->data(); auto blit1 = _selectUnitProc(bytes, stride1[2], 1); auto blit2 = _selectUnitProc(bytes, stride2[2], 1); - if (cmd->size()->data()[2] == 1 || (stride1[2] == 1 && stride2[2] == 1)) { + if (cmd->size()->data()[2] == 1 || (stride1[2] <= 1 && stride2[2] <= 1 && (stride1[2] + stride1[1] != 0))) { + // Support elementwise or one src broadcast + int needBroadcastIndex = -1; + if (0 == stride1[2]) { + needBroadcastIndex = 0; + } + if (0 == stride2[2]) { + needBroadcastIndex = 1; + } for (int z=0; zsize()->data()[0]; ++z) { auto src0Z = src0 + z * stride1[0] * bytes; auto src1Z = src1 + z * stride2[0] * bytes; @@ -1090,7 +1098,7 @@ class CPULoop : public Execution { auto src0Y = src0Z + y * stride1[1] * bytes; auto src1Y = src1Z + y * stride2[1] * bytes; auto dstY = dstZ + y * stride0[1] * bytes; - proc(dstY, src0Y, src1Y, cmd->size()->data()[2], -1); + proc(dstY, src0Y, src1Y, cmd->size()->data()[2], needBroadcastIndex); } } } else { From 3d53ad165498c456ffd531ab415b251792802ddd Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 09:56:50 +0800 Subject: [PATCH 042/314] Project import generated by Copybara. 
GitOrigin-RevId: f936e7dcb1d1dbef608b5a01ad46ce1da8fca7de --- CMakeLists.txt | 1 - README.md | 14 +- README_CN.md | 10 +- README_JP.md | 9 +- build_lib.sh | 807 ------------------ docs/transformers/diffusion.md | 3 +- source/backend/cpu/CPUBackend.cpp | 8 +- source/backend/cpu/CPUBackend.hpp | 3 - source/backend/cpu/CPUBinary.cpp | 60 +- source/backend/cpu/CPUBinary.hpp | 4 - source/backend/cpu/CPUMatMul.cpp | 28 +- source/backend/cpu/CPUMatMul.hpp | 7 +- source/backend/cpu/CPURNNSequenceGRU.cpp | 70 +- source/backend/cpu/CPURNNSequenceGRU.hpp | 15 +- source/backend/cpu/CPURaster.cpp | 631 +++++++------- source/backend/cpu/CPURaster.hpp | 3 +- source/backend/cpu/ThreadPool.cpp | 32 +- source/backend/cpu/ThreadPool.hpp | 6 +- source/backend/cpu/arm/CMakeLists.txt | 3 - .../backend/cpu/compute/CommonOptFunction.cpp | 88 +- .../cpu/riscv/rvv/CPUBilinearLineC4.cpp | 19 - .../cpu/riscv/rvv/CPUBilinearSampleC4.cpp | 33 - .../cpu/riscv/rvv/MNNAddC4WithStride.cpp | 29 - .../riscv/rvv/MNNAxByClampBroadcastUnit.cpp | 52 -- source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp | 18 - .../backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp | 20 - source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp | 20 - .../cpu/riscv/rvv/MNNBilinearLineC8.cpp | 40 - .../cpu/riscv/rvv/MNNBilinearSampleC8.cpp | 49 -- source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp | 20 - .../riscv/rvv/MNNConvRunForLineDepthwise.cpp | 48 -- .../cpu/riscv/rvv/MNNCopyC4WithStride.cpp | 22 - .../backend/cpu/riscv/rvv/MNNCubicLineC16.cpp | 53 -- .../backend/cpu/riscv/rvv/MNNCubicLineC4.cpp | 38 - .../cpu/riscv/rvv/MNNCubicSampleC16.cpp | 79 -- .../cpu/riscv/rvv/MNNCubicSampleC4.cpp | 62 -- .../rvv/MNNDeconvRunForUnitDepthWise.cpp | 42 - source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp | 13 - source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp | 16 - source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp | 25 - source/backend/cpu/riscv/rvv/MNNMinFloat.cpp | 25 - source/backend/cpu/riscv/rvv/MNNPackC2.cpp | 74 -- source/backend/cpu/riscv/rvv/MNNPackC4.cpp | 80 -- 
source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp | 17 - .../backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp | 20 - .../backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp | 20 - source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp | 17 - source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp | 20 - .../cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp | 45 - .../cpu/riscv/rvv/MNNScaleAndAddBias.cpp | 42 - source/backend/cpu/riscv/rvv/MNNSoftmax.cpp | 80 -- .../riscv/rvv/MNNStrassenMergeCFunction.cpp | 36 - .../cpu/riscv/rvv/MNNTranspose16Bit.cpp | 26 - .../cpu/riscv/rvv/MNNTranspose32Bit.cpp | 25 - source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp | 55 -- .../cpu/riscv/rvv/MNNVectorTop1Float.cpp | 37 - .../cpu/riscv/rvv/MNNVectorTop1Int32.cpp | 37 - source/core/Backend.hpp | 6 +- source/core/Concurrency.h | 13 +- source/core/OpCommonUtils.cpp | 91 ++ source/core/OpCommonUtils.hpp | 1 + source/core/TensorUtils.cpp | 12 - source/core/TensorUtils.hpp | 1 - source/geometry/GeometryComputerUtils.cpp | 4 +- source/geometry/GeometryComputerUtils.hpp | 2 +- source/geometry/GeometryReduce.cpp | 104 +-- source/geometry/GeometryReshape.cpp | 11 +- source/math/Vec.hpp | 3 +- test/core/ThreadPoolTest.cpp | 6 +- tools/cpp/ExprDebug.hpp | 53 +- tools/cpp/ModuleBasic.cpp | 46 +- transformers/diffusion/export/onnx_export.py | 30 +- transformers/llm/engine/src/llm.cpp | 21 +- 73 files changed, 616 insertions(+), 2944 deletions(-) delete mode 100644 build_lib.sh delete mode 100644 source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp delete mode 100644 
source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNMinFloat.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNPackC2.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNPackC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNSoftmax.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp diff --git a/CMakeLists.txt 
b/CMakeLists.txt index f99e37ec1c..67502b606b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -258,7 +258,6 @@ option(MNN_VULKAN "Enable Vulkan" OFF) option(MNN_ARM82 "Enable ARMv8.2's FP16 Compute" ON) option(MNN_SUPPORT_FP16_ARMV7 "Enable ARMv8.2's FP16 Compute for armv7 arch, may cause library not valid for 32 bit cpu" OFF) option(MNN_KLEIDIAI "Enable KLEIDIAI" ON) -option(MNN_KLEIDIAI_DEFAULT_ON "Use KLEIDIAI kernels by default" OFF) option(MNN_ONEDNN "Enable oneDNN" OFF) option(MNN_AVX2 "Open AVX2 Compile for x86 if possible" ON) option(MNN_AVX512 "Enable AVX512" OFF) diff --git a/README.md b/README.md index 7959890c16..5fe168ed05 100644 --- a/README.md +++ b/README.md @@ -6,13 +6,13 @@ [![日本語バージョン](https://img.shields.io/badge/Language-%E6%97%A5%E6%9C%AC%E8%AA%9E-green)](README_JP.md) [![MNN Homepage](https://img.shields.io/badge/Homepage-Visit-green)](http://www.mnn.zone) -[![MNN Chat App](https://img.shields.io/badge/Apps-MNN_Chat-blue)](./apps/Android/MnnLlmChat/README.md) -[![TaoAvatar](https://img.shields.io/badge/Apps-MNN_TaoAvatar-blue)](./apps/Android/Mnn3dAvatar/README.md) +[![MNN Chat App](https://img.shields.io/badge/Apps-MNN_Chat-blue)](./apps/Android/MnnLlmChat/README.md) +[![TaoAvatar](https://img.shields.io/badge/Apps-MNN_TaoAvatar-blue)](./apps/Android/Mnn3dAvatar/README.md) ## News 🔥 - [2025/10/16] Support Qwen3-VL Series. -- [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! [MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md) +- [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! [MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md)

Icon

@@ -154,13 +154,13 @@ The group discussions are predominantly Chinese. But we welcome and will help En Dingtalk discussion groups: -Group #4 (Available): 160170007549 - -Group #3 (Full) +Group #1 (Full): 23329087 Group #2 (Full): 23350225 -Group #1 (Full): 23329087 +Group #3: QR code: + +![MNN-3](doc/dingdingmnn3.png) ## Historical Paper diff --git a/README_CN.md b/README_CN.md index f769a1e14b..edcf823a28 100644 --- a/README_CN.md +++ b/README_CN.md @@ -111,10 +111,12 @@ MNN适配的硬件架构与精度详见下表: ## 社区交流与反馈 钉钉群组: -- 钉钉群3 (可加入): 160170007549 -- 钉钉群3 (已无法加入) -- 钉钉群2 (已满): 23350225 -- 钉钉群1 (已满): 23329087 +- 钉钉群1:23329087 +- 钉钉群2:23350225 +- 钉钉群3:扫描二维码加入 + +![MNN-3](doc/dingdingmnn3.png) + ## 历史论文 diff --git a/README_JP.md b/README_JP.md index 2f33def31a..c2baa58d94 100644 --- a/README_JP.md +++ b/README_JP.md @@ -117,14 +117,13 @@ MNN(テンソル計算エンジン)に基づいて、推論、トレーニ Dingtalkディスカッショングループ: - -グループ#4 :160170007549 - -グループ#3 (満員) +グループ#1(満員):23329087 グループ#2(満員):23350225 -グループ#1(満員):23329087 +グループ#3:QRコード: + +![MNN-3](doc/dingdingmnn3.png) ## 歴史的な論文 diff --git a/build_lib.sh b/build_lib.sh deleted file mode 100644 index c839b6e7b6..0000000000 --- a/build_lib.sh +++ /dev/null @@ -1,807 +0,0 @@ -#!/bin/bash - -# MNN 统一构建脚本 -# 支持构建 Android、iOS、鸿蒙和 Python 版本的 MNN - -set -e - -# 颜色输出 -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -BLUE='\033[0;34m' -NC='\033[0m' # No Color - -# 默认配置 -BUILD_ANDROID=false -BUILD_IOS=false -BUILD_PYTHON=false -BUILD_IOS_SIMULATOR=false -BUILD_HARMONY=false -ANDROID_NDK="" -HARMONY_HOME="" -OUTPUT_DIR="mnn_builds" -CLEAN_BUILD=true -VERSION="" -PYTHON_CMD="python3" -DEPS_OPTIONS="" - -# 获取项目根目录 -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -PROJECT_ROOT="$SCRIPT_DIR" - -print_usage() { - echo "MNN 统一构建脚本" - echo "" - echo "用法: $0 [选项]" - echo "" - echo "构建目标:" - echo " --android 构建 Android 版本 (需要 --ndk)" - echo " --ios 构建 iOS 真机版本 (arm64)" - echo " --ios-simulator 构建 iOS 模拟器版本 (x86_64 + arm64)" - echo " --harmony 构建鸿蒙版本 (需要 
--harmony-home 或设置 HARMONY_HOME)" - echo " --python 构建 Python 版本" - echo "" - echo "Android 选项:" - echo " --ndk PATH Android NDK 路径 (例如: ~/Library/Android/sdk/ndk/29.0.13599879)" - echo "" - echo "鸿蒙选项:" - echo " --harmony-home PATH 鸿蒙工具链路径 (例如: ~/Library/OpenHarmony/Sdk/native)" - echo " 如果未指定,将优先查找 ~/Library/OpenHarmony/Sdk/native" - echo "" - echo "Python 选项:" - echo " --python-deps OPTIONS Python 依赖选项,多个用逗号分隔" - echo " 可用选项: llm,opencl,cuda,torch,render,vulkan,internal,no_sse,openmp" - echo " --python-cmd CMD Python 命令 (默认: python3)" - echo "" - echo "通用选项:" - echo " -o, --output DIR 输出目录 (默认: mnn_builds)" - echo " -v, --version VERSION 版本号 (默认: 自动从源码读取)" - echo " --no-clean 不清理之前的构建目录" - echo " -h, --help 显示帮助信息" - echo "" - echo "示例:" - echo " # 构建所有平台" - echo " $0 --android --ios --harmony --python --ndk ~/Library/Android/sdk/ndk/29.0.13599879" - echo "" - echo " # 仅构建 Android" - echo " $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879" - echo "" - echo " # 构建鸿蒙版本" - echo " $0 --harmony --harmony-home ~/Library/OpenHarmony/Sdk/native" - echo "" - echo " # 构建 iOS 真机和模拟器" - echo " $0 --ios --ios-simulator" - echo "" - echo " # 构建 Python (带 LLM 支持)" - echo " $0 --python --python-deps llm,opencl" -} - -# 解析命令行参数 -while [[ $# -gt 0 ]]; do - case $1 in - --android) - BUILD_ANDROID=true - shift - ;; - --ios) - BUILD_IOS=true - shift - ;; - --ios-simulator) - BUILD_IOS_SIMULATOR=true - shift - ;; - --python) - BUILD_PYTHON=true - shift - ;; - --harmony) - BUILD_HARMONY=true - shift - ;; - --ndk) - ANDROID_NDK="$2" - shift 2 - ;; - --harmony-home) - HARMONY_HOME="$2" - shift 2 - ;; - --python-deps) - DEPS_OPTIONS="$2" - shift 2 - ;; - --python-cmd) - PYTHON_CMD="$2" - shift 2 - ;; - -o|--output) - OUTPUT_DIR="$2" - shift 2 - ;; - -v|--version) - VERSION="$2" - shift 2 - ;; - --no-clean) - CLEAN_BUILD=false - shift - ;; - -h|--help) - print_usage - exit 0 - ;; - *) - echo -e "${RED}错误: 未知选项 $1${NC}" - print_usage - exit 1 - ;; - esac -done - -# 
检查是否至少选择了一个构建目标 -if [ "$BUILD_ANDROID" = false ] && [ "$BUILD_IOS" = false ] && [ "$BUILD_IOS_SIMULATOR" = false ] && [ "$BUILD_HARMONY" = false ] && [ "$BUILD_PYTHON" = false ]; then - echo -e "${RED}错误: 请至少选择一个构建目标 (--android, --ios, --ios-simulator, --harmony, 或 --python)${NC}" - print_usage - exit 1 -fi - -# 检查 Android NDK -if [ "$BUILD_ANDROID" = true ]; then - if [ -z "$ANDROID_NDK" ]; then - echo -e "${RED}错误: 构建 Android 必须指定 NDK 路径 (使用 --ndk)${NC}" - echo -e "${RED}示例: $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879${NC}" - exit 1 - fi - - # 展开路径中的 ~ - ANDROID_NDK="${ANDROID_NDK/#\~/$HOME}" - - if [ ! -d "$ANDROID_NDK" ]; then - echo -e "${RED}错误: NDK 路径不存在: $ANDROID_NDK${NC}" - exit 1 - fi -fi - -# 查找鸿蒙工具链的函数 -find_harmony_toolchain() { - # 如果环境变量已设置,优先使用 - if [ -n "$HARMONY_HOME" ]; then - # 支持两种路径格式:直接指定或带 native 子目录 - if [ -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then - echo "$HARMONY_HOME" - return 0 - elif [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then - echo "$HARMONY_HOME" - return 0 - fi - fi - - # 优先查找 ~/Library/OpenHarmony/Sdk,支持版本号目录 - local sdk_base="$HOME/Library/OpenHarmony/Sdk" - if [ -d "$sdk_base" ]; then - # 查找所有版本号目录下的 native 或 toolchains - # 按版本号倒序排列,优先使用最新版本 - for version_dir in $(ls -d "$sdk_base"/* 2>/dev/null | sort -Vr); do - if [ -d "$version_dir" ]; then - # 尝试 native/build/cmake/ohos.toolchain.cmake - if [ -f "$version_dir/native/build/cmake/ohos.toolchain.cmake" ]; then - echo "$version_dir/native" - return 0 - fi - # 尝试 build/cmake/ohos.toolchain.cmake - if [ -f "$version_dir/build/cmake/ohos.toolchain.cmake" ]; then - echo "$version_dir" - return 0 - fi - fi - done - fi - - # 其他可能的路径 - local possible_paths=( - "$HOME/Library/OpenHarmony/Sdk/native" - "$HOME/HarmonyOS/Sdk/native" - "$HOME/.ohos/native" - "/opt/HarmonyOS/Sdk/native" - "/usr/local/HarmonyOS/Sdk/native" - "$HOME/Library/HarmonyOS/Sdk/native" - ) - - # 尝试查找 - for path in "${possible_paths[@]}"; do - if [ -n "$path" ] 
&& [ -f "$path/build/cmake/ohos.toolchain.cmake" ]; then - echo "$path" - return 0 - fi - done - - # 限制搜索范围,只在 OpenHarmony/Sdk 目录下快速查找 - # 避免在整个 OpenHarmony 目录下递归查找,这可能会很慢 - local found=$(find "$HOME/Library/OpenHarmony/Sdk" -maxdepth 4 -type f -name "ohos.toolchain.cmake" 2>/dev/null | head -1) - if [ -n "$found" ]; then - # 从 ohos.toolchain.cmake 向上查找 native 目录或 SDK 根目录 - found=$(dirname "$found") - if [ "$(basename "$found")" = "cmake" ]; then - found=$(dirname "$found") - if [ "$(basename "$found")" = "build" ]; then - found=$(dirname "$found") - fi - fi - echo "$found" - return 0 - fi - - return 1 -} - -# 检查鸿蒙工具链 -if [ "$BUILD_HARMONY" = true ]; then - # 展开路径中的 ~ - if [ -n "$HARMONY_HOME" ]; then - HARMONY_HOME="${HARMONY_HOME/#\~/$HOME}" - fi - - # 尝试查找工具链 - if [ -z "$HARMONY_HOME" ] || [ ! -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then - # 检查是否在 native 子目录下 - if [ -n "$HARMONY_HOME" ] && [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then - HARMONY_HOME="$HARMONY_HOME/native" - else - echo -e "${YELLOW}正在查找鸿蒙工具链...${NC}" - HARMONY_HOME=$(find_harmony_toolchain) - - if [ -z "$HARMONY_HOME" ]; then - echo -e "${RED}错误: 找不到鸿蒙工具链${NC}" - echo -e "${RED}默认查找路径: ~/Library/OpenHarmony/Sdk/*/native${NC}" - echo -e "${RED}请使用 --harmony-home 指定路径,或设置 HARMONY_HOME 环境变量${NC}" - echo -e "${RED}工具链文件应位于: /build/cmake/ohos.toolchain.cmake 或 /native/build/cmake/ohos.toolchain.cmake${NC}" - exit 1 - fi - - # 如果找到的是带 native 的路径,需要调整 - if [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then - HARMONY_HOME="$HARMONY_HOME/native" - fi - fi - - echo -e "${GREEN}找到鸿蒙工具链: $HARMONY_HOME${NC}" - fi - - # 验证工具链文件存在 - if [ ! 
-f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then - echo -e "${RED}错误: 鸿蒙工具链文件不存在: $HARMONY_HOME/build/cmake/ohos.toolchain.cmake${NC}" - exit 1 - fi -fi - -echo -e "${GREEN}========================================${NC}" -echo -e "${GREEN}MNN 统一构建脚本${NC}" -echo -e "${GREEN}========================================${NC}" -echo "项目根目录: $PROJECT_ROOT" -echo "输出目录: $OUTPUT_DIR" -echo "" -echo -e "${BLUE}构建目标:${NC}" -[ "$BUILD_ANDROID" = true ] && echo " ✓ Android" -[ "$BUILD_IOS" = true ] && echo " ✓ iOS (真机 arm64)" -[ "$BUILD_IOS_SIMULATOR" = true ] && echo " ✓ iOS (模拟器 x86_64 + arm64)" -[ "$BUILD_HARMONY" = true ] && echo " ✓ 鸿蒙 (arm64-v8a)" -[ "$BUILD_PYTHON" = true ] && echo " ✓ Python" -echo "" - -cd "$PROJECT_ROOT" -mkdir -p "$OUTPUT_DIR" - -# ============================================================================ -# 构建 Android 版本 -# ============================================================================ -if [ "$BUILD_ANDROID" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建 Android 版本${NC}" - echo -e "${GREEN}========================================${NC}" - - export ANDROID_NDK - - ANDROID_BUILD_DIR="project/android" - ANDROID_OUTPUT_DIR="$OUTPUT_DIR/android" - - cd "$ANDROID_BUILD_DIR" - - # 清理 - if [ "$CLEAN_BUILD" = true ]; then - echo -e "${YELLOW}清理之前的 Android 构建...${NC}" - rm -rf build_32 build_64 export - fi - - # 在项目根目录创建输出目录 - mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" - - # 构建 armeabi-v7a - echo -e "${BLUE}构建 armeabi-v7a...${NC}" - mkdir -p build_32 - cd build_32 - cmake ../../../ \ - -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ - -DCMAKE_BUILD_TYPE=Release \ - -DANDROID_ABI="armeabi-v7a" \ - -DANDROID_STL=c++_static \ - -DANDROID_NATIVE_API_LEVEL=android-14 \ - -DANDROID_TOOLCHAIN=clang \ - -DMNN_USE_LOGCAT=false \ - -DMNN_USE_SSE=OFF \ - -DMNN_BUILD_TEST=ON \ - -DMNN_ARM82=OFF \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - 
-DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DMNN_SEP_BUILD=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ - -DNATIVE_LIBRARY_OUTPUT=. \ - -DNATIVE_INCLUDE_OUTPUT=. \ - > /dev/null - - make -j4 MNN > /dev/null - cd .. - - # 构建 arm64-v8a - echo -e "${BLUE}构建 arm64-v8a...${NC}" - mkdir -p build_64 - cd build_64 - cmake ../../../ \ - -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ - -DCMAKE_BUILD_TYPE=Release \ - -DANDROID_ABI="arm64-v8a" \ - -DANDROID_STL=c++_static \ - -DANDROID_NATIVE_API_LEVEL=android-21 \ - -DANDROID_TOOLCHAIN=clang \ - -DMNN_USE_LOGCAT=false \ - -DMNN_BUILD_BENCHMARK=ON \ - -DMNN_USE_SSE=OFF \ - -DMNN_BUILD_TEST=ON \ - -DMNN_ARM82=ON \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DMNN_SEP_BUILD=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ - -DNATIVE_LIBRARY_OUTPUT=. \ - -DNATIVE_INCLUDE_OUTPUT=. \ - > /dev/null - - make -j4 MNN > /dev/null - cd .. - - # 导出文件 - echo -e "${BLUE}导出 Android 库文件...${NC}" - mkdir -p export/android/{armeabi-v7a,arm64-v8a}/libs export/android/include/MNN - - cp build_32/*.so export/android/armeabi-v7a/libs/ 2>/dev/null || true - cp build_64/*.so export/android/arm64-v8a/libs/ 2>/dev/null || true - cp -r ../../include/MNN/* export/android/include/MNN/ - - # 复制到统一输出目录 - # 如果目标路径存在但不是目录,先删除 - if [ -e "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ] && [ ! 
-d "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ]; then - rm -f "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" - fi - mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" - cp -r export/android/* "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR/" - - echo -e "${GREEN}Android 构建完成!${NC}" - echo "输出位置: $ANDROID_OUTPUT_DIR" - cd "$PROJECT_ROOT" -fi - -# ============================================================================ -# 构建 iOS 真机版本 -# ============================================================================ -if [ "$BUILD_IOS" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建 iOS 真机版本${NC}" - echo -e "${GREEN}========================================${NC}" - - # 检查是否在 macOS 上 - if [[ "$OSTYPE" != "darwin"* ]]; then - echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" - exit 1 - fi - - # 检查 Xcode - if ! command -v xcodebuild &> /dev/null; then - echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" - exit 1 - fi - - IOS_BUILD_DIR="project/ios" - IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/device" - - cd "$IOS_BUILD_DIR" - - # 清理 - if [ "$CLEAN_BUILD" = true ]; then - echo -e "${YELLOW}清理之前的 iOS 真机构建...${NC}" - rm -rf MNN-iOS-CPU-GPU/Static/ios_64 - find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_64/*" 2>/dev/null | xargs rm -f 2>/dev/null || true - fi - - mkdir -p MNN-iOS-CPU-GPU/Static - cd MNN-iOS-CPU-GPU/Static - - # 构建真机版本 (arm64) - echo -e "${BLUE}构建 iOS 真机版本 (arm64)...${NC}" - rm -rf ios_64 - mkdir ios_64 - cd ios_64 - cmake "$PROJECT_ROOT" \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ - -DENABLE_BITCODE=0 \ - -DMNN_AAPL_FMWK=1 \ - -DMNN_SEP_BUILD=OFF \ - -DMNN_BUILD_SHARED_LIBS=false \ - -DMNN_USE_THREAD_POOL=OFF \ - -DPLATFORM=OS64 \ - -DARCHS="arm64" \ - -DMNN_ARM82=ON \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_METAL=ON \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - 
-DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - > /dev/null - make MNN -j8 > /dev/null - cd .. - - # 输出真机版本 - mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" - rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" - cp -R ios_64/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" - - # 清理 - rm -rf ios_64 - - echo -e "${GREEN}iOS 真机构建完成!${NC}" - echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" - cd "$PROJECT_ROOT" -fi - -# ============================================================================ -# 构建 iOS 模拟器版本 -# ============================================================================ -if [ "$BUILD_IOS_SIMULATOR" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建 iOS 模拟器版本${NC}" - echo -e "${GREEN}========================================${NC}" - - # 检查是否在 macOS 上 - if [[ "$OSTYPE" != "darwin"* ]]; then - echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" - exit 1 - fi - - # 检查 Xcode - if ! command -v xcodebuild &> /dev/null; then - echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" - exit 1 - fi - - IOS_BUILD_DIR="project/ios" - IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/simulator" - - cd "$IOS_BUILD_DIR" - - # 清理 - if [ "$CLEAN_BUILD" = true ]; then - echo -e "${YELLOW}清理之前的 iOS 模拟器构建...${NC}" - rm -rf MNN-iOS-CPU-GPU/Static/ios_simulator* - find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_simulator*/*" 2>/dev/null | xargs rm -f 2>/dev/null || true - fi - - mkdir -p MNN-iOS-CPU-GPU/Static - cd MNN-iOS-CPU-GPU/Static - - # 构建 x86_64 模拟器版本(尝试构建,失败则跳过) - BUILD_X86_64=false - echo -e "${BLUE}构建 iOS 模拟器版本 (x86_64)...${NC}" - rm -rf ios_simulator_x86 - mkdir ios_simulator_x86 - cd ios_simulator_x86 - if cmake "$PROJECT_ROOT" \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ - -DENABLE_BITCODE=0 \ - -DMNN_AAPL_FMWK=1 \ - -DMNN_SEP_BUILD=OFF \ - -DMNN_BUILD_SHARED_LIBS=false \ - -DMNN_USE_THREAD_POOL=OFF \ - -DPLATFORM=SIMULATOR64 \ - 
-DARCHS="x86_64" \ - -DMNN_ARM82=OFF \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_METAL=OFF \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - && make MNN -j8; then - if [ -f "MNN.framework/MNN" ]; then - BUILD_X86_64=true - echo -e "${GREEN}x86_64 模拟器构建成功${NC}" - cd .. - else - echo -e "${YELLOW}警告: x86_64 模拟器构建产物不存在,跳过此架构${NC}" - cd .. - rm -rf ios_simulator_x86 - fi - else - echo -e "${YELLOW}警告: x86_64 模拟器构建失败,跳过此架构(这在 Apple Silicon Mac 上是正常的)${NC}" - cd .. - rm -rf ios_simulator_x86 - fi - - # 构建 arm64 模拟器版本 - echo -e "${BLUE}构建 iOS 模拟器版本 (arm64)...${NC}" - rm -rf ios_simulator_arm64 - mkdir ios_simulator_arm64 - cd ios_simulator_arm64 - if ! cmake "$PROJECT_ROOT" \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ - -DENABLE_BITCODE=0 \ - -DMNN_AAPL_FMWK=1 \ - -DMNN_SEP_BUILD=OFF \ - -DMNN_BUILD_SHARED_LIBS=false \ - -DMNN_USE_THREAD_POOL=OFF \ - -DPLATFORM=SIMULATOR64 \ - -DARCHS="arm64" \ - -DMNN_ARM82=ON \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_METAL=OFF \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON; then - echo -e "${RED}错误: arm64 模拟器 cmake 配置失败${NC}" - cd "$PROJECT_ROOT" - exit 1 - fi - if ! make MNN -j8; then - echo -e "${RED}错误: arm64 模拟器编译失败${NC}" - cd "$PROJECT_ROOT" - exit 1 - fi - cd .. - - # 验证构建产物 - if [ ! 
-f "ios_simulator_arm64/MNN.framework/MNN" ]; then - echo -e "${RED}错误: 未找到 arm64 模拟器框架文件${NC}" - cd "$PROJECT_ROOT" - exit 1 - fi - - # 合并模拟器架构 - echo -e "${BLUE}合并模拟器架构...${NC}" - rm -rf ios_simulator - mkdir ios_simulator - - if [ "$BUILD_X86_64" = true ] && [ -f "ios_simulator_x86/MNN.framework/MNN" ]; then - # 合并 x86_64 + arm64 - echo -e "${BLUE}合并 x86_64 + arm64 架构...${NC}" - cp -R ios_simulator_x86/MNN.framework ios_simulator/MNN.framework - mv ios_simulator/MNN.framework/MNN ios_simulator/MNN.framework/MNN_x86 - if ! lipo -create ios_simulator/MNN.framework/MNN_x86 ios_simulator_arm64/MNN.framework/MNN -output ios_simulator/MNN.framework/MNN; then - echo -e "${RED}错误: 合并架构失败${NC}" - cd "$PROJECT_ROOT" - exit 1 - fi - rm ios_simulator/MNN.framework/MNN_x86 - else - # 仅使用 arm64 - echo -e "${BLUE}仅使用 arm64 架构(x86_64 不可用)...${NC}" - cp -R ios_simulator_arm64/MNN.framework ios_simulator/MNN.framework - fi - - # 输出模拟器版本 - mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" - rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" - cp -R ios_simulator/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" - - # 清理临时目录 - rm -rf ios_simulator ios_simulator_x86 ios_simulator_arm64 - - echo -e "${GREEN}iOS 模拟器构建完成!${NC}" - echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" - cd "$PROJECT_ROOT" -fi - -# ============================================================================ -# 构建鸿蒙版本 -# ============================================================================ -if [ "$BUILD_HARMONY" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建鸿蒙版本${NC}" - echo -e "${GREEN}========================================${NC}" - - export HARMONY_HOME - - HARMONY_BUILD_DIR="project/harmony" - HARMONY_OUTPUT_DIR="$OUTPUT_DIR/harmony" - - cd "$HARMONY_BUILD_DIR" - - # 清理 - if [ "$CLEAN_BUILD" = true ]; then - echo -e "${YELLOW}清理之前的鸿蒙构建...${NC}" - rm -rf build_64 export - fi - - # 在项目根目录创建输出目录 - mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" - - # 构建 arm64-v8a - 
echo -e "${BLUE}构建 arm64-v8a...${NC}" - mkdir -p build_64 - cd build_64 - - cmake "$PROJECT_ROOT" \ - -DCMAKE_TOOLCHAIN_FILE=$HARMONY_HOME/build/cmake/ohos.toolchain.cmake \ - -DCMAKE_BUILD_TYPE=Release \ - -DOHOS_ARCH="arm64-v8a" \ - -DOHOS_STL=c++_static \ - -DMNN_USE_LOGCAT=true \ - -DMNN_BUILD_BENCHMARK=ON \ - -DMNN_USE_SSE=OFF \ - -DMNN_SUPPORT_BF16=OFF \ - -DMNN_BUILD_TEST=ON \ - -DMNN_ARM82=ON \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DMNN_SEP_BUILD=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - -DOHOS_PLATFORM_LEVEL=9 \ - -DNATIVE_LIBRARY_OUTPUT=. \ - -DNATIVE_INCLUDE_OUTPUT=. \ - > /dev/null - - make -j4 MNN > /dev/null - cd .. - - # 导出文件 - echo -e "${BLUE}导出鸿蒙库文件...${NC}" - mkdir -p export/harmony/arm64-v8a/libs export/harmony/include/MNN - - cp build_64/*.so export/harmony/arm64-v8a/libs/ 2>/dev/null || true - cp -r ../../include/MNN/* export/harmony/include/MNN/ - - # 复制到统一输出目录 - # 如果目标路径存在但不是目录,先删除 - if [ -e "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ] && [ ! -d "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ]; then - rm -f "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" - fi - mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" - cp -r export/harmony/* "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR/" - - echo -e "${GREEN}鸿蒙构建完成!${NC}" - echo "输出位置: $HARMONY_OUTPUT_DIR" - cd "$PROJECT_ROOT" -fi - -# ============================================================================ -# 构建 Python 版本 -# ============================================================================ -if [ "$BUILD_PYTHON" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建 Python 版本${NC}" - echo -e "${GREEN}========================================${NC}" - - # 检查 Python - if ! 
command -v $PYTHON_CMD &> /dev/null; then - echo -e "${RED}错误: 找不到 Python 命令 '$PYTHON_CMD'${NC}" - exit 1 - fi - - PYTHON_OUTPUT_DIR="$OUTPUT_DIR/python" - - # 使用之前创建的 build_pymnn.sh 脚本 - if [ -f "$PROJECT_ROOT/build_pymnn.sh" ]; then - PYTHON_BUILD_ARGS="-o $PYTHON_OUTPUT_DIR" - [ -n "$VERSION" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -v $VERSION" - [ -n "$DEPS_OPTIONS" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -d $DEPS_OPTIONS" - [ "$CLEAN_BUILD" = false ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --no-clean" - PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --python $PYTHON_CMD" - - bash "$PROJECT_ROOT/build_pymnn.sh" $PYTHON_BUILD_ARGS - else - echo -e "${YELLOW}警告: 未找到 build_pymnn.sh,使用基本构建方式...${NC}" - - cd pymnn/pip_package - - # 构建依赖 - if [ -n "$DEPS_OPTIONS" ]; then - $PYTHON_CMD build_deps.py $DEPS_OPTIONS - else - $PYTHON_CMD build_deps.py - fi - - # 构建 wheel - $PYTHON_CMD -m pip install -U numpy wheel setuptools --quiet - rm -rf build dist - - BUILD_ARGS="" - [ -n "$VERSION" ] && BUILD_ARGS="--version $VERSION" - [ -n "$DEPS_OPTIONS" ] && BUILD_ARGS="$BUILD_ARGS --deps $DEPS_OPTIONS" - - $PYTHON_CMD setup.py bdist_wheel $BUILD_ARGS - - mkdir -p "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR" - cp dist/*.whl "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR/" - - cd "$PROJECT_ROOT" - fi - - echo -e "${GREEN}Python 构建完成!${NC}" - echo "输出位置: $PYTHON_OUTPUT_DIR" -fi - -# ============================================================================ -# 总结 -# ============================================================================ -echo "" -echo -e "${GREEN}========================================${NC}" -echo -e "${GREEN}所有构建完成!${NC}" -echo -e "${GREEN}========================================${NC}" -echo "" -echo "输出目录: $PROJECT_ROOT/$OUTPUT_DIR" -echo "" -[ "$BUILD_ANDROID" = true ] && echo -e "${GREEN}Android:${NC} $OUTPUT_DIR/android" -[ "$BUILD_IOS" = true ] && echo -e "${GREEN}iOS (真机):${NC} $OUTPUT_DIR/ios/device" -[ "$BUILD_IOS_SIMULATOR" = true ] && echo -e "${GREEN}iOS (模拟器):${NC} 
$OUTPUT_DIR/ios/simulator" -[ "$BUILD_HARMONY" = true ] && echo -e "${GREEN}鸿蒙:${NC} $OUTPUT_DIR/harmony" -[ "$BUILD_PYTHON" = true ] && echo -e "${GREEN}Python:${NC} $OUTPUT_DIR/python" -echo "" - - diff --git a/docs/transformers/diffusion.md b/docs/transformers/diffusion.md index 609793f806..7de27bb216 100644 --- a/docs/transformers/diffusion.md +++ b/docs/transformers/diffusion.md @@ -20,8 +20,7 @@ https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1/tree/mai cd mnn_path/transformers/diffusion/export python onnx_export.py \ --model_path hf_sd_load_path \ - --output_path onnx_save_path \ - --opset 18 + --output_path onnx_save_path ``` 注意,上述脚本需要依赖torch/onnx/diffusers等库,可以安装conda环境: ``` diff --git a/source/backend/cpu/CPUBackend.cpp b/source/backend/cpu/CPUBackend.cpp index 8d284aa33b..0e0bc1f136 100644 --- a/source/backend/cpu/CPUBackend.cpp +++ b/source/backend/cpu/CPUBackend.cpp @@ -104,14 +104,15 @@ void CPURuntime::_bindCPUCore() const { #ifdef MNN_USE_THREAD_POOL if (nullptr != mThreadPool) { mThreadPool->active(); - ThreadPool::TASK task = std::make_pair([&](int i) { + mThreadPool->enqueue(std::make_pair([&](int i) { MNNSetSchedAffinity(lockCPUIndexes[i].first, lockCPUIndexes[i].second); - }, mThreadNumber); - mThreadPool->enqueue(&task, mTaskIndex); + return 0; + }, mThreadNumber), mTaskIndex); mThreadPool->deactive(); } #endif } + void CPURuntime::_resetThreadPool() const { mThreadNumber = std::max(1, mThreadNumber); mThreadNumber = std::min(mThreadNumber, MAX_THREAD_NUMBER); @@ -490,7 +491,6 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p currentRate *= decreaseRate; totalComputeRate += currentRate * selectSize; mGroupWithComputeRate.emplace_back(std::make_pair(currentRate * selectSize, selectSize)); - groupIndex--; } for (auto& g : mGroupWithComputeRate) { g.first = g.first / totalComputeRate; diff --git a/source/backend/cpu/CPUBackend.hpp b/source/backend/cpu/CPUBackend.hpp index 
ec4c555dec..884036eb38 100644 --- a/source/backend/cpu/CPUBackend.hpp +++ b/source/backend/cpu/CPUBackend.hpp @@ -176,9 +176,6 @@ class CPUBackend : public Backend { #ifdef MNN_USE_THREAD_POOL inline int taskIndex() const {return mRuntime->mTaskIndex;} inline ThreadPool* threadPool() const {return mRuntime->mThreadPool;} - void enqueue(ThreadPool::TASK& task) const { - threadPool()->enqueue(&task, taskIndex()); - } #endif static void initCreatorMap(); static size_t getBytes(const Backend* backend, const Tensor* output); diff --git a/source/backend/cpu/CPUBinary.cpp b/source/backend/cpu/CPUBinary.cpp index 61ccf4fca3..059e502d0b 100644 --- a/source/backend/cpu/CPUBinary.cpp +++ b/source/backend/cpu/CPUBinary.cpp @@ -45,37 +45,6 @@ ErrorCode CPUBinary::onResize(const std::vector& inputs, const std::vec mThreadNum = threads; mWorkDiv = UP_DIV(mTotalSize, threads); } - int inpBytes = inputs[0]->getType().bytes(); - int outBytes = outputs[0]->getType().bytes(); - if (halide_type_float == inputs[0]->getType().code) { - inpBytes = static_cast(backend())->functions()->bytes; - } - if (halide_type_float == outputs[0]->getType().code) { - outBytes = static_cast(backend())->functions()->bytes; - } - bool outputInt = outputs[0]->getType().code == halide_type_int; - mTask = std::make_pair([this, inpBytes, outBytes, outputInt](int tId) { - int start = tId * mWorkDiv; - int realSize = ALIMIN(mWorkDiv, mTotalSize - start); - if (realSize > 0) { - auto inp0 = mInput0Ptr + start * inpBytes; - auto inp1 = mInput1Ptr + start * inpBytes; - if (mNeedBroadcastIndex == 0) { - inp0 = mInput0Ptr; - } else if (mNeedBroadcastIndex == 1) { - inp1 = mInput1Ptr; - } - auto out = mOutputPtr + start * outBytes; - mProc(out, inp0, inp1, realSize, mNeedBroadcastIndex); - if(mActivationType == 1 && outputInt) { - for(int i=0; i 0 ? 
val : 0; - ((int32_t *)out)[i] = res; - } - } - } - } , mThreadNum); return NO_ERROR; } @@ -98,10 +67,31 @@ ErrorCode CPUBinary::onExecute(const std::vector& inputs, const std::ve outBytes = static_cast(backend())->functions()->bytes; } auto precision = static_cast(backend())->precisionMode(); - mInput0Ptr = input0Ptr; - mInput1Ptr = input1Ptr; - mOutputPtr = outputPtr; - MNN_CONCURRENCY_ENQUEUE(mTask); + + MNN_CONCURRENCY_BEGIN(tId, mThreadNum) { + int start = tId * mWorkDiv; + int realSize = ALIMIN(mWorkDiv, mTotalSize - start); + if (realSize > 0) { + auto inp0 = input0Ptr + start * inpBytes; + auto inp1 = input1Ptr + start * inpBytes; + if (mNeedBroadcastIndex == 0) { + inp0 = input0Ptr; + } else if (mNeedBroadcastIndex == 1) { + inp1 = input1Ptr; + } + auto out = outputPtr + start * outBytes; + mProc(out, inp0, inp1, realSize, mNeedBroadcastIndex); + if(mActivationType == 1 && output->getType().code == halide_type_int) { + for(int i=0; i 0 ? val : 0; + ((int32_t *)out)[i] = res; + } + } + } + } + MNN_CONCURRENCY_END(); + if(mActivationType == 1 && output->getType().code == halide_type_float) { mActivationExe->onExecute(outputs, outputs); } diff --git a/source/backend/cpu/CPUBinary.hpp b/source/backend/cpu/CPUBinary.hpp index 17cb3b5f47..9250df79ae 100644 --- a/source/backend/cpu/CPUBinary.hpp +++ b/source/backend/cpu/CPUBinary.hpp @@ -33,10 +33,6 @@ class CPUBinary : public Execution { int mThreadNum; int mWorkDiv; std::shared_ptr mActivationExe; - std::pair, int> mTask; - uint8_t* mInput0Ptr = nullptr; - uint8_t* mOutputPtr = nullptr; - uint8_t* mInput1Ptr = nullptr; }; } // namespace MNN #endif /* CPUBinary_hpp */ diff --git a/source/backend/cpu/CPUMatMul.cpp b/source/backend/cpu/CPUMatMul.cpp index 22b96a64ee..4f0765f050 100644 --- a/source/backend/cpu/CPUMatMul.cpp +++ b/source/backend/cpu/CPUMatMul.cpp @@ -37,8 +37,9 @@ void CPUMatMul::_scheduleForVecE(int e, int l, int h) { param.BTranspose = mTransposeB; param.numberThread = numberThread; auto func = 
static_cast(backend())->functions()->MNNComputeMatMulForE_1; - mPreFunctions.emplace_back(std::make_pair([param, func, this](int tId) { - func(mA, mB, mC, mBiasPtr, ¶m, tId); + mPreFunctions.emplace_back(std::make_pair([param, func]( + int tId, const float* A, const float* B, const float* biasPtr, float* C) { + func(A, B, C, biasPtr, ¶m, tId); }, numberThread)); } @@ -53,9 +54,9 @@ void CPUMatMul::_scheduleForVec(int e, int l, int h) { auto func = static_cast(backend())->functions()->MNNComputeMatMulForH_1; // TODD: Support e = 1 MNN_ASSERT(h == 1); - mPreFunctions.emplace_back(std::make_pair([param, func, this]( - int tId) { - func(mA, mB, mC, mBiasPtr, ¶m, tId); + mPreFunctions.emplace_back(std::make_pair([param, func]( + int tId, const float* A, const float* B, const float* biasPtr, float* C) { + func(A, B, C, biasPtr, ¶m, tId); }, numberThread)); } @@ -99,8 +100,8 @@ ErrorCode CPUMatMul::onResize(const std::vector& inputs, const std::vec return OUT_OF_MEMORY; } - mPreFunctions.emplace_back(std::make_pair([BTPtrAlloc, l, h, this, core] (int tId) { - core->MNNPackForMatMul_B((float*)BTPtrAlloc.ptr(), mB, h, 1, l, mTransposeB); + mPreFunctions.emplace_back(std::make_pair([BTPtrAlloc, l, h, this, core] (int tId, const float* APtr, const float* BPtr, const float* Bias, float* C) { + core->MNNPackForMatMul_B((float*)BTPtrAlloc.ptr(), BPtr, h, 1, l, mTransposeB); } , 1)); bool useBias = false; MemChunk bdestAlloc; @@ -119,9 +120,9 @@ ErrorCode CPUMatMul::onResize(const std::vector& inputs, const std::vec } mTempBias = bdestAlloc; mPreFunctions.emplace_back(std::make_pair( - [biasLength, bdestAlloc, core, this](int tId) { + [biasLength, bdestAlloc, core](int tId, const float* APtr, const float* BPtr, const float* borigin, float* C) { ::memset(bdestAlloc.ptr(), 0, UP_DIV(biasLength, core->pack) * core->bytes * core->pack); - ::memcpy(bdestAlloc.ptr(), mBiasPtr, biasLength * core->bytes); + ::memcpy(bdestAlloc.ptr(), borigin, biasLength * core->bytes); }, 1)); } else { 
mUseBiasDirectly = true; @@ -166,12 +167,11 @@ ErrorCode CPUMatMul::onExecute(const std::vector& inputs, const std::ve } void CPUMatMul::execute(const float* APtr, const float* BPtr, float* CPtr, const float* biasPtr) { - mA = APtr; - mB = BPtr; - mC = CPtr; - mBiasPtr = biasPtr; for (auto& f : mPreFunctions) { - MNN_CONCURRENCY_ENQUEUE(f); + MNN_CONCURRENCY_BEGIN(tId, f.second) { + f.first(tId, APtr, BPtr, biasPtr, CPtr); + } + MNN_CONCURRENCY_END(); } if (mE > 0) { auto core = static_cast(backend())->functions(); diff --git a/source/backend/cpu/CPUMatMul.hpp b/source/backend/cpu/CPUMatMul.hpp index 48226795f0..872a77a9a8 100644 --- a/source/backend/cpu/CPUMatMul.hpp +++ b/source/backend/cpu/CPUMatMul.hpp @@ -29,7 +29,7 @@ class CPUMatMul : public Execution { bool mTransposeB; bool mTransposeC; bool mSupportMultiThread = false; - std::vector, int>> mPreFunctions; + std::vector, int>> mPreFunctions; bool mUseBiasDirectly = false; MemChunk mTempA; MemChunk mTempB; @@ -40,11 +40,6 @@ class CPUMatMul : public Execution { int mL; int mH; std::vector mPostParameters; - // For Execute Paramters - const float* mA = nullptr; - const float* mB = nullptr; - const float* mBiasPtr = nullptr; - float* mC = nullptr; }; } // namespace MNN diff --git a/source/backend/cpu/CPURNNSequenceGRU.cpp b/source/backend/cpu/CPURNNSequenceGRU.cpp index 0bda660e9c..daae8811c7 100644 --- a/source/backend/cpu/CPURNNSequenceGRU.cpp +++ b/source/backend/cpu/CPURNNSequenceGRU.cpp @@ -10,26 +10,30 @@ #include #include "backend/cpu/CPUBackend.hpp" #include "backend/cpu/compute/ConvOpt.h" +#include "backend/cpu/compute/CommonOptFunction.h" #include "core/TensorUtils.hpp" namespace MNN { // implement GRU cell function // Ref: tensorflow/python/ops/rnn_cell_impl.py -void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, const bool linearBeforeReset, uint8_t* hiddenStateInput, const int numUnits, Tensor* gateWeight, Tensor* gateBias, +void CPURNNSequenceGRU::runRNNStep(const 
uint8_t* input, const int inputLength, const bool linearBeforeReset, + std::shared_ptr& hiddenState, const int numUnits, Tensor* gateWeight, Tensor* gateBias, Tensor* candidateWeight, Tensor* candidateBias, Tensor* recurrentBias, std::shared_ptr& inputAndState, std::shared_ptr& gate, - std::shared_ptr& resetHt, uint8_t* hiddenStateOutput) { + std::shared_ptr& resetHt) { + auto bn = static_cast(backend()); + auto mulFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_MUL); + auto addFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_ADD); + auto subFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_SUB); + auto tanhFunction = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_TANH, bn->precisionMode()); + auto bytes = bn->functions()->bytes; + auto sigmoidFunc = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_SIGMOID, bn->precisionMode()); // gate is (z_t, r_t) - auto bytes = mRNNFunctions.bytes; - MNNBinaryExecute mulFunction = mRNNFunctions.mulFunction; - MNNBinaryExecute addFunction = mRNNFunctions.addFunction; - MNNBinaryExecute subFunction = mRNNFunctions.subFunction; - MNNUnaryExecute tanhFunction = mRNNFunctions.tanhFunction; - MNNUnaryExecute sigmoidFunction = mRNNFunctions.sigmoidFunction; auto inputAndStatePtr = inputAndState->host(); + auto hiddenStatePtr = hiddenState->host(); ::memcpy(inputAndStatePtr, input, inputLength * bytes); - ::memcpy(inputAndStatePtr + inputLength * bytes, hiddenStateInput, numUnits * bytes); + ::memcpy(inputAndStatePtr + inputLength * bytes, hiddenStatePtr, numUnits * bytes); inputAndState->setLength(1, inputLength + numUnits); // // [x_t, h_t-1] * [W_zr, R_zr]: (1, inputLength + numUnits) X (inputLength + numUnits, 2 * numUnits) @@ -38,8 +42,9 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, recurrentBias->setLength(1, 2 * numUnits); addFunction(gate->host(), gate->host(), 
recurrentBias->host(), 2*numUnits, -1); // (1, 2*numUnits) + const int gateSize = gate->elementSize(); auto gatePtr = gate->host(); - sigmoidFunction(gatePtr, gatePtr, 2 * numUnits); + sigmoidFunc(gatePtr, gatePtr, gateSize); // reset gate, // r_t is the second segment auto rtPtr = gatePtr + numUnits * bytes; @@ -47,7 +52,7 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, // calculate Rt (.) (Ht_1 * Rh + Rbh) auto recurrentHiddenBiasPtr = recurrentBias->host() + 2 * numUnits * bytes; auto rhWeightPtr = candidateWeight->host() + inputLength * numUnits * bytes; - mMatMulU2U->execute((float*)hiddenStateInput, (float*)rhWeightPtr, resetHt->host(), (float*)recurrentHiddenBiasPtr); + mMatMulU2U->execute(hiddenState->host(), (float*)rhWeightPtr, resetHt->host(), (float*)recurrentHiddenBiasPtr); mulFunction(resetHt->host(), rtPtr, resetHt->host(), numUnits, -1); // calculate Xt * Wh @@ -60,7 +65,7 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, // r_t: (1, numUnits) auto resetGatePtr = inputAndStatePtr + inputLength * bytes; // h_t1(1, numUnits) = r_t(1, numUnits) * h_t-1_(1, numUnits) - mulFunction(resetGatePtr, rtPtr, hiddenStateInput, numUnits, -1); + mulFunction(resetGatePtr, rtPtr, hiddenStatePtr, numUnits, -1); // deal with recurrent bias and linear_before_reset parameter auto recurrentBiasAddedPtr = inputAndStatePtr + (inputLength + numUnits) * bytes; auto recurrentHiddenBiasPtr = (float*)(recurrentBias->host() + 2 * numUnits * bytes); @@ -71,9 +76,9 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, } // h = (1-g)*t+g*h = t + g*(h-t) tanhFunction(resetHt->host(), rtPtr, numUnits); - subFunction(hiddenStateOutput, hiddenStateInput, resetHt->host(), numUnits, -1); - mulFunction(hiddenStateOutput, hiddenStateOutput, gatePtr, numUnits, -1); - addFunction(hiddenStateOutput, hiddenStateOutput, resetHt->host(), numUnits, -1); + subFunction(hiddenStatePtr, hiddenStatePtr, 
resetHt->host(), numUnits, -1); + mulFunction(hiddenStatePtr, hiddenStatePtr, gatePtr, numUnits, -1); + addFunction(hiddenStatePtr, hiddenStatePtr, resetHt->host(), numUnits, -1); inputAndState->setLength(1, inputLength + 2 * numUnits); } @@ -138,13 +143,6 @@ ErrorCode CPURNNSequenceGRU::onResize(const std::vector& inputs, const backend()->onReleaseBuffer(mInputAndState.get(), Backend::DYNAMIC); backend()->onReleaseBuffer(mGate.get(), Backend::DYNAMIC); backend()->onReleaseBuffer(mResetHt.get(), Backend::DYNAMIC); - auto bn = static_cast(backend()); - mRNNFunctions.mulFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_MUL); - mRNNFunctions.addFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_ADD); - mRNNFunctions.subFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_SUB); - mRNNFunctions.tanhFunction = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_TANH, bn->precisionMode()); - mRNNFunctions.bytes = bn->functions()->bytes; - mRNNFunctions.sigmoidFunction = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_SIGMOID, bn->precisionMode()); return NO_ERROR; } @@ -185,29 +183,27 @@ ErrorCode CPURNNSequenceGRU::onExecute(const std::vector& inputs, const const int inputCodeLength = input->length(2); // MNN_PRINT("inputSequenceLength:%d, batchSize:%d, inputCodeLength:%d, mNumUnits:%d, hiddenStateDataSize:%d\n", inputSequenceLength, batchSize, inputCodeLength, mNumUnits, hiddenStateDataSize); for (int b = 0; b < batchSize; ++b) { // swap order - auto hiddenStateInput = hiddenStatePtr; - auto hiddenStateOutput = hiddenStatePtr; if (inputSize > 1 + forwardParamNumber * (mIsBidirectionalRNN + 1)) { auto source = inputs[inputSize - 1]->host() + b * hiddenStateDataSize; - hiddenStateInput = source; + ::memcpy(hiddenStatePtr, source, hiddenStateDataSize); } else { ::memset(hiddenStatePtr, 0, hiddenStateDataSize); } for (int i = 0; i < inputSequenceLength; ++i) { const 
int inputOffset = i * SequenceStride + b * inputCodeLength; + runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, mHiddenState, mNumUnits, fwGateWeight, fwGateBias, + fwCandidateWeight, fwCandidateBias, fwRecurrentBias, mInputAndState, mGate, mResetHt); + if (mKeepAllOutputs) { - hiddenStateOutput = outputPtr + (i * output->stride(0) + b * mNumUnits) * bytes; + ::memcpy(outputPtr + (i * output->stride(0) + b * mNumUnits) * bytes, hiddenStatePtr, hiddenStateDataSize); } - runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, hiddenStateInput, mNumUnits, fwGateWeight, fwGateBias, - fwCandidateWeight, fwCandidateBias, fwRecurrentBias, mInputAndState, mGate, mResetHt, hiddenStateOutput); - - hiddenStateInput = hiddenStateOutput; } if ((mKeepAllOutputs && outputSize > 1) || !mKeepAllOutputs) { - ::memcpy(outputYhPtr, hiddenStateOutput, hiddenStateDataSize); + ::memcpy(outputYhPtr, hiddenStatePtr, hiddenStateDataSize); outputYhPtr += mNumUnits * bytes; } + } // backward rnn @@ -225,24 +221,22 @@ ErrorCode CPURNNSequenceGRU::onExecute(const std::vector& inputs, const auto outputBw = outputs[0]; auto const outputBwPtr = outputBw->host(); for (int b = 0; b < batchSize; ++b) { - auto hiddenStateInput = hiddenStatePtr; - auto hiddenStateOutput = hiddenStatePtr; if (inputSize > 1 + forwardParamNumber * 2) { auto source = inputs[inputSize - 1]->host() + (batchSize + b) * hiddenStateDataSize; - hiddenStateInput = source; + ::memcpy(hiddenStatePtr, source, hiddenStateDataSize); } else { ::memset(hiddenStatePtr, 0, hiddenStateDataSize); } for (int i = inputSequenceLength - 1; i >= 0; i--) { const int inputOffset = i * SequenceStride + b * inputCodeLength; + runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, mHiddenState, mNumUnits, bwGateWeight, bwGateBias, + bwCandidateWeight, bwCandidateBias, bwRecurrentBias, mInputAndState, mGate, mResetHt); if (mKeepAllOutputs) { - hiddenStateOutput = 
outputBwPtr + (i * outputBw->stride(0) + (batchSize + b) * mNumUnits) * bytes; + ::memcpy(outputBwPtr + (i * outputBw->stride(0) + (batchSize + b) * mNumUnits) * bytes, + hiddenStatePtr, hiddenStateDataSize); } - runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, hiddenStateInput, mNumUnits, bwGateWeight, bwGateBias, - bwCandidateWeight, bwCandidateBias, bwRecurrentBias, mInputAndState, mGate, mResetHt, hiddenStateOutput); - hiddenStateInput = hiddenStateOutput; } if ((mKeepAllOutputs && outputSize > 1) || !mKeepAllOutputs) { ::memcpy(outputYhPtr, hiddenStatePtr, hiddenStateDataSize); diff --git a/source/backend/cpu/CPURNNSequenceGRU.hpp b/source/backend/cpu/CPURNNSequenceGRU.hpp index 0125b9e8a1..0987d13053 100644 --- a/source/backend/cpu/CPURNNSequenceGRU.hpp +++ b/source/backend/cpu/CPURNNSequenceGRU.hpp @@ -11,7 +11,6 @@ #include "core/Execution.hpp" #include "CPUMatMul.hpp" -#include "backend/cpu/compute/CommonOptFunction.h" namespace MNN { class CPURNNSequenceGRU : public Execution { @@ -20,20 +19,13 @@ class CPURNNSequenceGRU : public Execution { virtual ~CPURNNSequenceGRU(); virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; - struct RNNFuntions { - MNNBinaryExecute mulFunction; - MNNBinaryExecute addFunction; - MNNBinaryExecute subFunction; - MNNUnaryExecute tanhFunction; - MNNUnaryExecute sigmoidFunction; - int bytes; - }; + private: void runRNNStep(const uint8_t* input, const int inputLength, const bool linearBeforeReset, - uint8_t* hiddenStateInput, const int numUnits, Tensor* gateWeight, Tensor* gateBias, + std::shared_ptr& hiddenState, const int numUnits, Tensor* gateWeight, Tensor* gateBias, Tensor* candidateWeight, Tensor* candidateBias, Tensor* recurrentBias, std::shared_ptr& inputAndState, std::shared_ptr& gate, - std::shared_ptr& resetHt, uint8_t* hiddenStateOutput); + std::shared_ptr& 
resetHt); bool mKeepAllOutputs; bool mIsBidirectionalRNN; bool mlinearBeforeReset; @@ -50,7 +42,6 @@ class CPURNNSequenceGRU : public Execution { std::shared_ptr mMatMulU2U; // For inputLength -> numUnit std::shared_ptr mMatMulI2U; - RNNFuntions mRNNFunctions; }; } // namespace MNN diff --git a/source/backend/cpu/CPURaster.cpp b/source/backend/cpu/CPURaster.cpp index 1339089347..3272086531 100644 --- a/source/backend/cpu/CPURaster.cpp +++ b/source/backend/cpu/CPURaster.cpp @@ -49,6 +49,227 @@ struct ReduceInfo { } }; +ErrorCode CPURaster::onResize(const std::vector &____inputs, const std::vector &outputs) { + MNN_ASSERT(outputs.size() == 1); + auto output = outputs[0]; + OpCommonUtils::rasterInputReset(____inputs, outputs[0]); + auto des = TensorUtils::getDescribe(output); + auto outputDes = TensorUtils::getDescribe(output); + mNeedZero = !TensorUtils::regionIsFull(output); + mZeroPoint = 0; + mUseThreads = false; + if (outputDes->quantAttr != nullptr && outputDes->applyQuant) { +#ifdef MNN_USE_SSE + mZeroPoint = (int)outputDes->quantAttr->zero + 128; +#else + mZeroPoint = (int)outputDes->quantAttr->zero; +#endif + } + mTempInput.clear(); + mFastBlit.clear(); + mCacheRegions.clear(); + mTempOutput = nullptr; + auto midFormat = MNN_DATA_FORMAT_NCHW; + mTempInputCopy.clear(); + mFast = false; + auto core = static_cast(backend())->functions(); + mSingleConvert.type = 0; + // all_srcFormat == dstFormat == NC4HW4 : Fast Exe + if (outputDes->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) { + mFast = true; + for (int i=0; i< des->regions.size(); ++i) { + auto& slice = des->regions[i]; + if (TensorUtils::getDescribe(slice.origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { + mFast = false; + break; + } + if (!OpCommonUtils::canBlitFast(slice, output, core->pack, true)) { + mFast = false; + break; + } + } + if (mFast) { + mUseThreads = des->regions.size() > 16 ? 
true : false; + for (int i=0; i< des->regions.size(); ++i) { + auto& slice = des->regions[i]; + if (slice.origin == nullptr) { + continue; + } + Tensor::InsideDescribe::Region newRegion; + OpCommonUtils::turnToPackRegion(slice, newRegion, output, core->pack, true); + mFastBlit.emplace_back(std::make_pair(slice.origin, std::move(newRegion))); + } + return NO_ERROR; + } + } + // srcNum == 1 && srcFormat != dstFormat : Single Convert + if (des->regions.size() == 1) { + OpCommonUtils::turnRegion2Convert(des->regions[0], output, mSingleConvert); + if (mSingleConvert.type > 0) { + mUseThreads = (mSingleConvert.batch * mSingleConvert.channel * mSingleConvert.area > LAUNCH_MULTI_THREADS_WORKLOAD) ? true : false; + return NO_ERROR; + } + } + // Acquire Buffer for temp output + // TODO: optimize it + if (MNN_DATA_FORMAT_NC4HW4 == outputDes->dimensionFormat) { + mTempOutput.reset(new Tensor); + TensorUtils::setupTensorInfo(output, mTempOutput.get(), midFormat); + } + if (nullptr != mTempOutput) { + auto res = backend()->onAcquireBuffer(mTempOutput.get(), Backend::DYNAMIC); + if (!res) { + return OUT_OF_MEMORY; + } + } + // input is NC4HW4 add Convert + std::vector forRelease; + TensorUtils::FuseWrap fuseUtils; + for (int i=0; i< des->regions.size(); ++i) { + auto& slice = des->regions[i]; + auto origin = slice.origin; + if (nullptr == origin /*|| nullptr == origin->host()*/) { + continue; + } + // if tensor is not NC4HW4 or has been merged, don't need deal + if (TensorUtils::getDescribe(origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { + if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + mTempInputCopy.emplace_back(std::make_pair(origin, &slice)); + continue; + } + // if NC4HW4's C%4 == 0, change convert to transpose and fuse it + if (origin->batch() == 1 && origin->channel() % core->pack == 0) { + int channel = origin->channel(); + int area = 1; + // conv3d/pool3d will has 5 dims, area = depth * width * 
height, otherwise area = width * height + for (int d = 2; d < origin->dimensions(); d++) { + area *= origin->length(d); + } + Tensor::InsideDescribe::Region regionTmp; + regionTmp.src.offset = 0; + regionTmp.src.stride[0] = area * core->pack; + regionTmp.src.stride[1] = 1; + regionTmp.src.stride[2] = core->pack; + regionTmp.dst.offset = 0; + regionTmp.dst.stride[0] = area * core->pack; + regionTmp.dst.stride[1] = area; + regionTmp.dst.stride[2] = 1; + regionTmp.size[0] = channel / core->pack; + regionTmp.size[1] = core->pack; + regionTmp.size[2] = area; + regionTmp.origin = slice.origin; + bool merge = fuseUtils.match(regionTmp, slice); + if (merge) { + std::shared_ptr newSlice(new Tensor::InsideDescribe::Region); + *newSlice = slice; + fuseUtils.apply(regionTmp, *newSlice); + // cache the merged tensor + if (newSlice->size[0] * newSlice->size[1] * newSlice->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + mTempInputCopy.emplace_back(std::make_pair(origin, newSlice.get())); + mCacheRegions.emplace_back(newSlice); + continue; + } + } + auto cache = static_cast(backend())->getCache(); + auto tempTensor = cache->findCacheTensor(origin, midFormat); + //MNN_ASSERT(CPUBackend::getBytes(backend(), origin) == 4); + if (nullptr == tempTensor) { + std::shared_ptr newTensor(new Tensor); + TensorUtils::copyShape(origin, newTensor.get()); + TensorUtils::getDescribe(newTensor.get())->dimensionFormat = midFormat; + TensorUtils::getDescribe(newTensor.get())->quantAttr = TensorUtils::getDescribe(origin)->quantAttr; + TensorUtils::getDescribe(newTensor.get())->applyQuant = TensorUtils::getDescribe(origin)->applyQuant;; + newTensor->buffer().type = origin->getType(); + TensorUtils::setLinearLayout(newTensor.get()); + mTempInput.insert(std::make_pair(origin, newTensor.get())); + auto res = backend()->onAcquireBuffer(newTensor.get(), Backend::DYNAMIC); + if (!res) { + return OUT_OF_MEMORY; + } + tempTensor = newTensor.get(); + 
TensorUtils::getDescribe(tempTensor)->useCount = TensorUtils::getDescribe(origin)->useCount; + cache->pushCacheTensor(newTensor, origin, midFormat); + } + if (--TensorUtils::getDescribe(tempTensor)->useCount == 0) { + forRelease.emplace_back(tempTensor); + } + if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + mTempInputCopy.emplace_back(std::make_pair(tempTensor, &slice)); + } + for (auto t : forRelease) { + backend()->onReleaseBuffer(t, Backend::DYNAMIC); + } + if (nullptr != mTempOutput) { + backend()->onReleaseBuffer(mTempOutput.get(), Backend::DYNAMIC); + } + auto threadNumber = static_cast(backend())->threadNumber(); + mHasReduce = false; + ReduceInfo reduceInfo; + for (auto& iter : mTempInputCopy) { + if (reduceInfo.compute(*iter.second)) { + mHasReduce = true; + break; + } + } + if (mTempInputCopy.size() == 1 && threadNumber > 1 && (!mHasReduce)) { + // Split to multi region + auto region = mTempInputCopy[0].second; + if (region->size[0] * region->size[1] * region->size[2] < LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = false; + return NO_ERROR; + } + if (region->size[0] * region->size[1] * region->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + auto tensorPtr = mTempInputCopy[0].first; + int pos = -1; + for (int i=0; i<3; ++i) { + if (region->size[i] > 1) { + pos = i; + break; + } + } + if (-1 == pos) { + // Don't need divide + return NO_ERROR; + } + mTempInputCopy.clear(); + int divSize = UP_DIV(region->size[pos], threadNumber); + for (int i=0; i cacheRegPtr(new Tensor::InsideDescribe::Region); + auto& cacheReg = *cacheRegPtr; + int sta = i * divSize; + int fin = sta + divSize; + fin = std::min(fin, region->size[pos]); + if (fin <= sta) { + break; + } + for (int v=0; v<3; ++v) { + cacheReg.src.stride[v] = region->src.stride[v]; + cacheReg.dst.stride[v] = region->dst.stride[v]; + } + int curSize = fin - sta; + for (int v=0; vsize[v]; + } + cacheReg.size[pos] = curSize; 
+ cacheReg.src.offset = region->src.offset + sta * region->src.stride[pos]; + cacheReg.dst.offset = region->dst.offset + sta * region->dst.stride[pos]; + for (int v=pos+1; v<3; ++v) { + cacheReg.size[v] = region->size[v]; + } + mTempInputCopy.emplace_back(std::make_pair(tensorPtr, cacheRegPtr.get())); + mCacheRegions.emplace_back(cacheRegPtr); + } + } + return NO_ERROR; +} static void _transpose(int32_t* dstO, const int32_t* srcO, const Tensor::InsideDescribe::Region& region, int bytes) { int dims[4], keepDim = -1; for (int i = 0; i < 3; i++) { @@ -103,12 +324,15 @@ static void _2BitcopyWithStrideC4(uint8_t* dstO, const uint8_t* srcO, int size, } } -void CPURaster::executeFaster(const std::vector &inputs, const std::vector &outputs) { +void CPURaster::executeFaster(const std::vector &inputs, const std::vector &outputs) const { auto input = inputs[0]; auto output = outputs[0]; auto bytes = CPUBackend::getBytes(backend(), output); auto core = static_cast(backend())->functions(); - int threadNum = static_cast(backend())->threadNumber(); + auto threadNum = static_cast(backend())->threadNumber(); + if (mNeedZero) { + ::memset(output->host(), mZeroPoint, static_cast(backend())->getTensorSize(output) * bytes); + } auto byteC4 = bytes * core->pack; auto C4proc = core->MNN4BitcopyWithStride; switch (byteC4) { @@ -128,7 +352,7 @@ void CPURaster::executeFaster(const std::vector &inputs, const std::ve if (!mUseThreads) { threadNum = 1; } - mTasks.emplace_back(std::make_pair([threadNum, this, output, bytes, C4proc, byteC4](int tId) { + MNN_CONCURRENCY_BEGIN(tId, threadNum) { for (int u=(int)tId; uhost() == nullptr) { @@ -169,7 +393,8 @@ void CPURaster::executeFaster(const std::vector &inputs, const std::ve } } } - }, threadNum)); + } + MNN_CONCURRENCY_END(); } static BlitProc _selectUnitProc(int bytes, int stride, int ds) { @@ -371,307 +596,97 @@ static void _blit(const Tensor::InsideDescribe::Region& slice, int bytes, const } } void CPURaster::tensorConvert(Tensor* input, 
Tensor* output, int bytes) { - std::pair, int> task; + auto& subIb = input->buffer(); + auto& subOb = output->buffer(); + auto source = TensorUtils::getDescribe(input)->dimensionFormat; + auto dest = TensorUtils::getDescribe(output)->dimensionFormat; + if (subIb.dimensions <= 1 || source == dest) { + ::memcpy(subOb.host, subIb.host, input->elementSize() * bytes); + return; + } + auto tup = CPUTensorConverter::splitDimensions(subIb, source); + int area = std::get<1>(tup), batch = std::get<0>(tup), channel = std::get<2>(tup); + const int bitLength = bytes; auto core = static_cast(backend())->functions(); auto threadNumber = static_cast(backend())->threadNumber(); if (!mUseThreads) { threadNumber = 1; } - task.first = [input, output, bytes, threadNumber, core](int tId) { - auto& subIb = input->buffer(); - auto& subOb = output->buffer(); - auto source = TensorUtils::getDescribe(input)->dimensionFormat; - auto dest = TensorUtils::getDescribe(output)->dimensionFormat; - if (subIb.dimensions <= 1 || source == dest) { - ::memcpy(subOb.host, subIb.host, input->elementSize() * bytes); - return; - } - auto tup = CPUTensorConverter::splitDimensions(subIb, source); - int area = std::get<1>(tup), batch = std::get<0>(tup), channel = std::get<2>(tup); - const int bitLength = bytes; + MNN_CONCURRENCY_BEGIN(tId, threadNumber) { CPUTensorConverter::convert(subIb.host, subOb.host, source, dest, batch, area, channel, bitLength, core, tId, threadNumber); }; - task.second = threadNumber; - mTasks.emplace_back(task); + MNN_CONCURRENCY_END(); } -ErrorCode CPURaster::onResize(const std::vector &____inputs, const std::vector &outputs) { - MNN_ASSERT(outputs.size() == 1); - auto output = outputs[0]; - OpCommonUtils::rasterInputReset(____inputs, outputs[0]); - auto des = TensorUtils::getDescribe(output); - auto outputDes = TensorUtils::getDescribe(output); - mNeedZero = !TensorUtils::regionIsFull(output); - mZeroPoint = 0; - mUseThreads = false; - int threadNum = 
static_cast(backend())->threadNumber(); - if (outputDes->quantAttr != nullptr && outputDes->applyQuant) { -#ifdef MNN_USE_SSE - mZeroPoint = (int)outputDes->quantAttr->zero + 128; -#else - mZeroPoint = (int)outputDes->quantAttr->zero; -#endif - } - size_t bytes = (size_t)(CPUBackend::getBytes(backend(), output)); - mTempInput.clear(); - mFastBlit.clear(); - mCacheRegions.clear(); - mTempOutput = nullptr; - mTasks.clear(); - auto midFormat = MNN_DATA_FORMAT_NCHW; - mTempInputCopy.clear(); - mFast = false; - auto core = static_cast(backend())->functions(); - mSingleConvert.type = 0; - // all_srcFormat == dstFormat == NC4HW4 : Fast Exe - if (outputDes->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) { - mFast = true; - for (int i=0; i< des->regions.size(); ++i) { - auto& slice = des->regions[i]; - if (TensorUtils::getDescribe(slice.origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { - mFast = false; - break; - } - if (!OpCommonUtils::canBlitFast(slice, output, core->pack, true)) { - mFast = false; - break; - } - } - if (mFast) { - mUseThreads = des->regions.size() > 16 ? true : false; - for (int i=0; i< des->regions.size(); ++i) { - auto& slice = des->regions[i]; - if (slice.origin == nullptr) { - continue; - } - Tensor::InsideDescribe::Region newRegion; - OpCommonUtils::turnToPackRegion(slice, newRegion, output, core->pack, true); - mFastBlit.emplace_back(std::make_pair(slice.origin, std::move(newRegion))); - } - executeFaster(____inputs, outputs); - return NO_ERROR; - } - } - // srcNum == 1 && srcFormat != dstFormat : Single Convert - if (des->regions.size() == 1) { - OpCommonUtils::turnRegion2Convert(des->regions[0], output, mSingleConvert); - if (mSingleConvert.type > 0) { - std::pair, int> task; - mUseThreads = (mSingleConvert.batch * mSingleConvert.channel * mSingleConvert.area > LAUNCH_MULTI_THREADS_WORKLOAD) ? 
true : false; - auto realInput = ____inputs[0]; - int srcBatch = mSingleConvert.batch, srcChannel = mSingleConvert.channel, srcArea = mSingleConvert.area; - auto sourceFormat = TensorUtils::getDescribe(realInput)->dimensionFormat; - auto destFormat = TensorUtils::getDescribe(output)->dimensionFormat; - auto channelC4 = UP_DIV(srcChannel, core->pack); - auto batchStrideC4 = channelC4 * core->pack * srcArea * bytes; - auto batchStride = srcChannel * srcArea * bytes; - auto inputBatchStride = batchStride; - auto outputBatchStride = batchStride; - if (MNN_DATA_FORMAT_NC4HW4 == sourceFormat) { - if (realInput->dimensions() <= 1) { - task.first = [output, realInput, bytes](int tId) { - ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); - }; - task.second = 1; - mTasks.emplace_back(task); - return NO_ERROR; - } - inputBatchStride = batchStrideC4; - if (2 == mSingleConvert.type) { - destFormat = MNN_DATA_FORMAT_NHWC; - } else { - destFormat = MNN_DATA_FORMAT_NCHW; - } - } else if (MNN_DATA_FORMAT_NC4HW4 == destFormat) { - if (output->dimensions() <= 1) { - task.first = [output, realInput, bytes](int tId) { - ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); - }; - task.second = 1; - mTasks.emplace_back(task); - return NO_ERROR; - } - outputBatchStride = batchStrideC4; - if (2 == mSingleConvert.type) { - sourceFormat = MNN_DATA_FORMAT_NHWC; - } else { - sourceFormat = MNN_DATA_FORMAT_NCHW; - } - } - if (!mUseThreads) { - threadNum = 1; - } - task.first = [realInput, output, sourceFormat, destFormat, srcBatch, srcArea, srcChannel, bytes, core, threadNum](int tId) { - CPUTensorConverter::convert(realInput->host(), output->host(), sourceFormat, destFormat, srcBatch, srcArea, srcChannel, bytes, core, tId, threadNum); - }; - task.second = threadNum; - mTasks.emplace_back(task); - return NO_ERROR; - } - } - // Acquire Buffer for temp output - // TODO: optimize it - if (MNN_DATA_FORMAT_NC4HW4 == outputDes->dimensionFormat) { 
- mTempOutput.reset(new Tensor); - TensorUtils::setupTensorInfo(output, mTempOutput.get(), midFormat); - } + + +ErrorCode CPURaster::onExecute(const std::vector &____inputs, const std::vector &outputs) { + void* mOutputPtr = nullptr; if (nullptr != mTempOutput) { - auto res = backend()->onAcquireBuffer(mTempOutput.get(), Backend::DYNAMIC); - if (!res) { - return OUT_OF_MEMORY; - } + mOutputPtr = mTempOutput->host(); + } else { + mOutputPtr = outputs[0]->host(); } - // input is NC4HW4 add Convert - std::vector forRelease; - TensorUtils::FuseWrap fuseUtils; - for (int i=0; i< des->regions.size(); ++i) { - auto& slice = des->regions[i]; - auto origin = slice.origin; - if (nullptr == origin /*|| nullptr == origin->host()*/) { - continue; - } - // if tensor is not NC4HW4 or has been merged, don't need deal - if (TensorUtils::getDescribe(origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { - if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; + if (mFast) { + executeFaster(____inputs, outputs); + return NO_ERROR; + } + auto core = static_cast(backend())->functions(); + auto output = outputs[0]; + size_t bytes = (size_t)(CPUBackend::getBytes(backend(), output)); + auto outputEleSize = static_cast(backend())->getTensorSize(output); + auto threadNum = static_cast(backend())->threadNumber(); + if (mSingleConvert.type > 0) { + auto realInput = ____inputs[0]; + int srcBatch = mSingleConvert.batch, srcChannel = mSingleConvert.channel, srcArea = mSingleConvert.area; + auto sourceFormat = TensorUtils::getDescribe(realInput)->dimensionFormat; + auto destFormat = TensorUtils::getDescribe(output)->dimensionFormat; + auto channelC4 = UP_DIV(srcChannel, core->pack); + auto batchStrideC4 = channelC4 * core->pack * srcArea * bytes; + auto batchStride = srcChannel * srcArea * bytes; + auto inputBatchStride = batchStride; + auto outputBatchStride = batchStride; + if (MNN_DATA_FORMAT_NC4HW4 == sourceFormat) { + if 
(realInput->dimensions() <= 1) { + ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); + return NO_ERROR; } - mTempInputCopy.emplace_back(std::make_pair(origin, &slice)); - continue; - } - // if NC4HW4's C%4 == 0, change convert to transpose and fuse it - if (origin->batch() == 1 && origin->channel() % core->pack == 0) { - int channel = origin->channel(); - int area = 1; - // conv3d/pool3d will has 5 dims, area = depth * width * height, otherwise area = width * height - for (int d = 2; d < origin->dimensions(); d++) { - area *= origin->length(d); + inputBatchStride = batchStrideC4; + if (2 == mSingleConvert.type) { + destFormat = MNN_DATA_FORMAT_NHWC; + } else { + destFormat = MNN_DATA_FORMAT_NCHW; } - Tensor::InsideDescribe::Region regionTmp; - regionTmp.src.offset = 0; - regionTmp.src.stride[0] = area * core->pack; - regionTmp.src.stride[1] = 1; - regionTmp.src.stride[2] = core->pack; - regionTmp.dst.offset = 0; - regionTmp.dst.stride[0] = area * core->pack; - regionTmp.dst.stride[1] = area; - regionTmp.dst.stride[2] = 1; - regionTmp.size[0] = channel / core->pack; - regionTmp.size[1] = core->pack; - regionTmp.size[2] = area; - regionTmp.origin = slice.origin; - bool merge = fuseUtils.match(regionTmp, slice); - if (merge) { - std::shared_ptr newSlice(new Tensor::InsideDescribe::Region); - *newSlice = slice; - fuseUtils.apply(regionTmp, *newSlice); - // cache the merged tensor - if (newSlice->size[0] * newSlice->size[1] * newSlice->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - mTempInputCopy.emplace_back(std::make_pair(origin, newSlice.get())); - mCacheRegions.emplace_back(newSlice); - continue; + } else if (MNN_DATA_FORMAT_NC4HW4 == destFormat) { + if (output->dimensions() <= 1) { + ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); + return NO_ERROR; } - } - auto cache = static_cast(backend())->getCache(); - auto tempTensor = cache->findCacheTensor(origin, midFormat); - 
//MNN_ASSERT(CPUBackend::getBytes(backend(), origin) == 4); - if (nullptr == tempTensor) { - std::shared_ptr newTensor(new Tensor); - TensorUtils::copyShape(origin, newTensor.get()); - TensorUtils::getDescribe(newTensor.get())->dimensionFormat = midFormat; - TensorUtils::getDescribe(newTensor.get())->quantAttr = TensorUtils::getDescribe(origin)->quantAttr; - TensorUtils::getDescribe(newTensor.get())->applyQuant = TensorUtils::getDescribe(origin)->applyQuant;; - newTensor->buffer().type = origin->getType(); - TensorUtils::setLinearLayout(newTensor.get()); - mTempInput.insert(std::make_pair(origin, newTensor.get())); - auto res = backend()->onAcquireBuffer(newTensor.get(), Backend::DYNAMIC); - if (!res) { - return OUT_OF_MEMORY; + outputBatchStride = batchStrideC4; + if (2 == mSingleConvert.type) { + sourceFormat = MNN_DATA_FORMAT_NHWC; + } else { + sourceFormat = MNN_DATA_FORMAT_NCHW; } - tempTensor = newTensor.get(); - TensorUtils::getDescribe(tempTensor)->useCount = TensorUtils::getDescribe(origin)->useCount; - cache->pushCacheTensor(newTensor, origin, midFormat); } - if (--TensorUtils::getDescribe(tempTensor)->useCount == 0) { - forRelease.emplace_back(tempTensor); + if (!mUseThreads) { + threadNum = 1; } - if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - mTempInputCopy.emplace_back(std::make_pair(tempTensor, &slice)); - } - for (auto t : forRelease) { - backend()->onReleaseBuffer(t, Backend::DYNAMIC); - } - if (nullptr != mTempOutput) { - backend()->onReleaseBuffer(mTempOutput.get(), Backend::DYNAMIC); + MNN_CONCURRENCY_BEGIN(tId, threadNum) { + CPUTensorConverter::convert(realInput->host(), output->host(), sourceFormat, destFormat, srcBatch, srcArea, srcChannel, bytes, core, tId, threadNum); + }; + MNN_CONCURRENCY_END(); + return NO_ERROR; } - auto threadNumber = static_cast(backend())->threadNumber(); - mHasReduce = false; - ReduceInfo reduceInfo; - for (auto& iter : mTempInputCopy) { - if 
(reduceInfo.compute(*iter.second)) { - mHasReduce = true; - break; + if (mNeedZero) { + if (mTempOutput == nullptr) { + ::memset(output->host(), mZeroPoint, outputEleSize * bytes); + } else { + ::memset(mTempOutput->host(), mZeroPoint, mTempOutput->elementSize() * bytes); } } - // Encode convert for (auto& iter : mTempInput) { tensorConvert(iter.first, iter.second, (int)bytes); } - do { - if (mTempInputCopy.size() == 1 && threadNumber > 1 && (!mHasReduce)) { - // Split to multi region - auto region = mTempInputCopy[0].second; - if (region->size[0] * region->size[1] * region->size[2] < LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = false; - break; - } - if (region->size[0] * region->size[1] * region->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - auto tensorPtr = mTempInputCopy[0].first; - int pos = -1; - for (int i=0; i<3; ++i) { - if (region->size[i] > 1) { - pos = i; - break; - } - } - if (-1 == pos) { - // Don't need divide - break; - } - mTempInputCopy.clear(); - int divSize = UP_DIV(region->size[pos], threadNumber); - for (int i=0; i cacheRegPtr(new Tensor::InsideDescribe::Region); - auto& cacheReg = *cacheRegPtr; - int sta = i * divSize; - int fin = sta + divSize; - fin = std::min(fin, region->size[pos]); - if (fin <= sta) { - break; - } - for (int v=0; v<3; ++v) { - cacheReg.src.stride[v] = region->src.stride[v]; - cacheReg.dst.stride[v] = region->dst.stride[v]; - } - int curSize = fin - sta; - for (int v=0; vsize[v]; - } - cacheReg.size[pos] = curSize; - cacheReg.src.offset = region->src.offset + sta * region->src.stride[pos]; - cacheReg.dst.offset = region->dst.offset + sta * region->dst.stride[pos]; - for (int v=pos+1; v<3; ++v) { - cacheReg.size[v] = region->size[v]; - } - mTempInputCopy.emplace_back(std::make_pair(tensorPtr, cacheRegPtr.get())); - mCacheRegions.emplace_back(cacheRegPtr); - } - } - } while (false); if (mHasReduce) { // Don't support reduce with multi thread now threadNum = 1; @@ -685,13 +700,8 @@ ErrorCode 
CPURaster::onResize(const std::vector &____inputs, const std if (outputDescribe->overlap) { threadNum = 1; } - mTasks.emplace_back(std::make_pair([this, threadNum, output, bytes, core](int tId){ - void* mOutputPtr = nullptr; - if (nullptr != mTempOutput) { - mOutputPtr = mTempOutput->host(); - } else { - mOutputPtr = output->host(); - } + + MNN_CONCURRENCY_BEGIN(tId, threadNum) { for (int u=tId; u &____inputs, const std auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes; _blit(slice, (int)bytes, srcPtr, dstPtr, mHasReduce, core->MNNLowpToFp32, core->MNNFp32ToLowp); } - }, threadNum)); - if (nullptr != mTempOutput) { - tensorConvert(mTempOutput.get(), output, (int)bytes); } - return NO_ERROR; -} - - -ErrorCode CPURaster::onExecute(const std::vector &____inputs, const std::vector &outputs) { - void* mOutputPtr = nullptr; + MNN_CONCURRENCY_END(); if (nullptr != mTempOutput) { - mOutputPtr = mTempOutput->host(); - } else { - mOutputPtr = outputs[0]->host(); - } - auto core = static_cast(backend())->functions(); - auto output = outputs[0]; - size_t bytes = (size_t)(CPUBackend::getBytes(backend(), output)); - auto outputEleSize = static_cast(backend())->getTensorSize(output); - auto threadNum = static_cast(backend())->threadNumber(); - if (mNeedZero) { - if (mTempOutput == nullptr) { - ::memset(output->host(), mZeroPoint, outputEleSize * bytes); - } else { - ::memset(mTempOutput->host(), mZeroPoint, mTempOutput->elementSize() * bytes); - } - } - for (auto& task : mTasks) { - MNN_CONCURRENCY_ENQUEUE(task); + tensorConvert(mTempOutput.get(), output, (int)bytes); } return NO_ERROR; } @@ -1081,15 +1066,7 @@ class CPULoop : public Execution { auto stride2 = cmd->view()->GetAs(2)->stride()->data(); auto blit1 = _selectUnitProc(bytes, stride1[2], 1); auto blit2 = _selectUnitProc(bytes, stride2[2], 1); - if (cmd->size()->data()[2] == 1 || (stride1[2] <= 1 && stride2[2] <= 1 && (stride1[2] + stride1[1] != 0))) { - // Support elementwise or one src broadcast - int 
needBroadcastIndex = -1; - if (0 == stride1[2]) { - needBroadcastIndex = 0; - } - if (0 == stride2[2]) { - needBroadcastIndex = 1; - } + if (cmd->size()->data()[2] == 1 || (stride1[2] == 1 && stride2[2] == 1)) { for (int z=0; zsize()->data()[0]; ++z) { auto src0Z = src0 + z * stride1[0] * bytes; auto src1Z = src1 + z * stride2[0] * bytes; @@ -1098,7 +1075,7 @@ class CPULoop : public Execution { auto src0Y = src0Z + y * stride1[1] * bytes; auto src1Y = src1Z + y * stride2[1] * bytes; auto dstY = dstZ + y * stride0[1] * bytes; - proc(dstY, src0Y, src1Y, cmd->size()->data()[2], needBroadcastIndex); + proc(dstY, src0Y, src1Y, cmd->size()->data()[2], -1); } } } else { diff --git a/source/backend/cpu/CPURaster.hpp b/source/backend/cpu/CPURaster.hpp index bff149df52..9df10700bd 100644 --- a/source/backend/cpu/CPURaster.hpp +++ b/source/backend/cpu/CPURaster.hpp @@ -24,7 +24,7 @@ class CPURaster : public Execution { virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; - void executeFaster(const std::vector &inputs, const std::vector &outputs); + void executeFaster(const std::vector &inputs, const std::vector &outputs) const; void tensorConvert(Tensor* input, Tensor* output, int bytes); private: std::map mTempInput; @@ -38,7 +38,6 @@ class CPURaster : public Execution { int32_t mZeroPoint = 0; bool mHasReduce = false; bool mUseThreads = false; - std::vector, int>> mTasks; }; } #endif diff --git a/source/backend/cpu/ThreadPool.cpp b/source/backend/cpu/ThreadPool.cpp index d7765c4fbc..15a2d8241c 100644 --- a/source/backend/cpu/ThreadPool.cpp +++ b/source/backend/cpu/ThreadPool.cpp @@ -60,7 +60,7 @@ ThreadPool::ThreadPool(int numberThread) { while (mActiveCount > 0) { for (int i = 0; i < MNN_THREAD_POOL_MAX_TASKS; ++i) { if (*mTasks[i].second[threadIndex]) { - mTasks[i].first->first(threadIndex); + mTasks[i].first.first(threadIndex); { 
*mTasks[i].second[threadIndex] = false; } } } @@ -118,18 +118,16 @@ void ThreadPool::deactive() { mActiveCount--; } -void ThreadPool::enqueue(TASK* taskp, int index) { - auto& task = *taskp; +void ThreadPool::enqueue(TASK&& task, int index) { if (1 >= task.second || 0 > index) { for (int i = 0; i < task.second; ++i) { task.first(i); } return; } - enqueueInternal(taskp, index); + enqueueInternal(std::move(task), index); } -void ThreadPool::enqueueInternal(TASK* taskp, int index) { - auto& task = *taskp; +void ThreadPool::enqueueInternal(TASK&& task, int index) { if (mActiveCount == 0) { for (int i = 0; i < task.second; ++i) { task.first(i); @@ -137,25 +135,24 @@ void ThreadPool::enqueueInternal(TASK* taskp, int index) { return; } int workSize = task.second; - TASK* tmpTask = nullptr; if (workSize > mNumberThread) { - tmpTask = new TASK; - *tmpTask = std::make_pair([workSize, &task, this](int tId) { - for (int v = tId; v < workSize; v += mNumberThread) { - task.first(v); - } - }, mNumberThread); - mTasks[index].first = tmpTask; + mTasks[index].first = std::make_pair( + [workSize, &task, this](int tId) { + for (int v = tId; v < workSize; v += mNumberThread) { + task.first(v); + } + }, + mNumberThread); workSize = mNumberThread; } else { - mTasks[index].first = taskp; + mTasks[index].first = std::move(task); } { for (int i = 1; i < workSize; ++i) { *mTasks[index].second[i] = true; } } - mTasks[index].first->first(0); + mTasks[index].first.first(0); bool complete = true; do { complete = true; @@ -168,9 +165,6 @@ void ThreadPool::enqueueInternal(TASK* taskp, int index) { std::this_thread::yield(); // FUNC_PRINT(notComplete); } while (!complete); - if (nullptr != tmpTask) { - delete tmpTask; - } } } // namespace MNN #endif diff --git a/source/backend/cpu/ThreadPool.hpp b/source/backend/cpu/ThreadPool.hpp index 8891da61b1..4bf23de1b0 100644 --- a/source/backend/cpu/ThreadPool.hpp +++ b/source/backend/cpu/ThreadPool.hpp @@ -25,7 +25,7 @@ class MNN_PUBLIC ThreadPool { int 
numberThread() const { return mNumberThread; } - void enqueue(TASK* task, int index); + void enqueue(TASK&& task, int index); void active(); void deactive(); @@ -37,7 +37,7 @@ class MNN_PUBLIC ThreadPool { static void destroy(); private: - void enqueueInternal(TASK* task, int index); + void enqueueInternal(TASK&& task, int index); ThreadPool(int numberThread = 0); ~ThreadPool(); @@ -46,7 +46,7 @@ class MNN_PUBLIC ThreadPool { std::vector mTaskAvailable; std::atomic mStop = {false}; - std::vector>> mTasks; + std::vector>> mTasks; std::condition_variable mCondition; std::mutex mQueueMutex; diff --git a/source/backend/cpu/arm/CMakeLists.txt b/source/backend/cpu/arm/CMakeLists.txt index 61ebce6bdc..18fca54a4e 100644 --- a/source/backend/cpu/arm/CMakeLists.txt +++ b/source/backend/cpu/arm/CMakeLists.txt @@ -36,9 +36,6 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR AR if (MNN_KLEIDIAI) include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/KleidiAI.cmake) download_kleidiai_and_collect_sources() - if(MNN_KLEIDIAI_DEFAULT_ON) - add_definitions(-DMNN_DEFAULT_USE_KLEIDIAI) - endif() endif() if (MNN_SME2) diff --git a/source/backend/cpu/compute/CommonOptFunction.cpp b/source/backend/cpu/compute/CommonOptFunction.cpp index c9bfcc2189..d7d0d7fb34 100644 --- a/source/backend/cpu/compute/CommonOptFunction.cpp +++ b/source/backend/cpu/compute/CommonOptFunction.cpp @@ -3882,13 +3882,12 @@ void MNNVectorTop1Int32(int32_t* input, int32_t* maxValue, int32_t* maxIndex, si #endif -void MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tIdL) { +void MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId) { auto l = param->l; auto h = param->h; auto numberThread = param->numberThread; auto lC4 = l / 4; auto lR = lC4 * 4; - auto tId = (int)tIdL; if (param->BTranspose) { for (int y=tId; y= 8) { - if (0 == tId) { - auto bs = B + hEnd; 
- Vec4 sumValue0; - Vec4 sumValue1; - if (biasPtr != nullptr) { - sumValue0 = Vec4::load(biasPtr + hEnd + 0); - sumValue1 = Vec4::load(biasPtr + hEnd + 4); - } else { - sumValue0 = Vec4(0.0f); - sumValue1 = Vec4(0.0f); - } - auto srcY = A + hEnd * l; - for (int x=0; x= 4) { - if (0 == tId) { - auto bs = B + hEnd; - Vec4 sumValue0; - if (biasPtr != nullptr) { - sumValue0 = Vec4::load(biasPtr + hEnd + 0); - } else { - sumValue0 = Vec4(0.0f); - } - auto srcY = A + hEnd * l; - for (int x=0; x= 8) { - sumValue = Vec::fma(sumValue, Vec4::load(srcY + lR), Vec4::load(B + lR)); - sum1 = Vec::fma(sum1, Vec4::load(srcY + lR + 4), Vec4::load(B + lR + 4)); - lR += 8; - } - if (l - lR >= 4) { - sumValue = Vec::fma(sumValue, Vec4::load(srcY + lR), Vec4::load(B + lR)); - lR += 4; - } - sum2 = sum2 + sum3; - sumValue = sumValue + sum1; - sumValue = sumValue + sum2; + sumValue = sumValue + Vec4::load(srcY + 4 * x) * Vec4::load(B + 4 * x); + } float sumSingle = sumValue[0] + sumValue[1] + sumValue[2] + sumValue[3]; for (int x=lR; x - -void CPUBilinearLineC4(float* dst, const float* A, const float* B, - const float* t, int8_t* zeroPoint, size_t number) { - float tf = *t; - float sf = 1.0f - tf; - size_t total = number << 2; - - size_t i = 0; - while (i < total) { - size_t vl = __riscv_vsetvl_e32m8(total - i); - vfloat32m8_t v = __riscv_vle32_v_f32m8(A + i, vl); - vfloat32m8_t result = __riscv_vfmul_vf_f32m8(v, sf, vl); - v = __riscv_vle32_v_f32m8(B + i, vl); - result = __riscv_vfmacc_vf_f32m8(result, tf, v, vl); - __riscv_vse32_v_f32m8(dst + i, result, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp b/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp deleted file mode 100644 index 5063c39bff..0000000000 --- a/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp +++ /dev/null @@ -1,33 +0,0 @@ -#include - -void CPUBilinearSampleC4(const float* src, float* dst, - const int32_t* position, const float* factor, - int8_t* zeroPoint, size_t number) 
{ - const int pack = 4; - size_t i = 0; - - while (i < number) { - size_t vl = __riscv_vsetvl_e32m8(number - i); - vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + i, vl); - - for (int c = 0; c < pack; c++) { - vuint32m8_t voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vfloat32m8_t vr = __riscv_vluxei32_v_f32m8(src, voff, vl); - vfloat32m8_t vsf = __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl); - vr = __riscv_vfmul_vv_f32m8(vr, vsf, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vsf = __riscv_vluxei32_v_f32m8(src, voff, vl); - vr = __riscv_vfmacc_vv_f32m8(vr, vf, vsf, vl); - __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, vr, vl); - } - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp b/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp deleted file mode 100644 index 59bb28a039..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp +++ /dev/null @@ -1,29 +0,0 @@ -#include - -void MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { - ptrdiff_t srcStrideByte = srcStride * sizeof(float); - ptrdiff_t dstStrideByte = dstStride * sizeof(float); - size_t vl; - - for (size_t i = count; i > 0; i -= vl) { - vl = __riscv_vsetvl_e32m8(i); - vfloat32m8_t vs = __riscv_vlse32_v_f32m8(source + 0, srcStrideByte, vl); - vfloat32m8_t vd = __riscv_vlse32_v_f32m8(dest + 0, dstStrideByte, vl); - vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); - __riscv_vsse32_v_f32m8(dest + 0, dstStrideByte, vd, vl); - vs = __riscv_vlse32_v_f32m8(source + 1, srcStrideByte, vl); - vd = __riscv_vlse32_v_f32m8(dest + 1, dstStrideByte, vl); - vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); - __riscv_vsse32_v_f32m8(dest + 1, dstStrideByte, vd, vl); - vs = 
__riscv_vlse32_v_f32m8(source + 2, srcStrideByte, vl); - vd = __riscv_vlse32_v_f32m8(dest + 2, dstStrideByte, vl); - vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); - __riscv_vsse32_v_f32m8(dest + 2, dstStrideByte, vd, vl); - vs = __riscv_vlse32_v_f32m8(source + 3, srcStrideByte, vl); - vd = __riscv_vlse32_v_f32m8(dest + 3, dstStrideByte, vl); - vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); - __riscv_vsse32_v_f32m8(dest + 3, dstStrideByte, vd, vl); - source += vl * srcStride; - dest += vl * dstStride; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp b/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp deleted file mode 100644 index 6d966789f7..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include - -void MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) { - float beta = parameters[1]; - float minF = parameters[2]; - float maxF = parameters[3]; - const ptrdiff_t stride = 4 * sizeof(float); - - for (int y = 0; y < height; ++y) { - auto a = A + aStride * y; - auto b = B + 4 * y; - auto c = C + cStride * y; - float b0Beta = b[0] * beta; - float b1Beta = b[1] * beta; - float b2Beta = b[2] * beta; - float b3Beta = b[3] * beta; - size_t w = width; - - while (w > 0) { - size_t vl = __riscv_vsetvl_e32m8(w); - - vfloat32m8_t data = __riscv_vlse32_v_f32m8(a + 0, stride, vl); - data = __riscv_vfadd_vf_f32m8(data, b0Beta, vl); - data = __riscv_vfmax_vf_f32m8(data, minF, vl); - data = __riscv_vfmin_vf_f32m8(data, maxF, vl); - __riscv_vsse32_v_f32m8(c + 0, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(a + 1, stride, vl); - data = __riscv_vfadd_vf_f32m8(data, b1Beta, vl); - data = __riscv_vfmax_vf_f32m8(data, minF, vl); - data = __riscv_vfmin_vf_f32m8(data, maxF, vl); - __riscv_vsse32_v_f32m8(c + 1, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(a + 2, stride, vl); - data = 
__riscv_vfadd_vf_f32m8(data, b2Beta, vl); - data = __riscv_vfmax_vf_f32m8(data, minF, vl); - data = __riscv_vfmin_vf_f32m8(data, maxF, vl); - __riscv_vsse32_v_f32m8(c + 2, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(a + 3, stride, vl); - data = __riscv_vfadd_vf_f32m8(data, b3Beta, vl); - data = __riscv_vfmax_vf_f32m8(data, minF, vl); - data = __riscv_vfmin_vf_f32m8(data, maxF, vl); - __riscv_vsse32_v_f32m8(c + 3, stride, data, vl); - - a += 4 * vl; - c += 4 * vl; - w -= vl; - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp deleted file mode 100644 index 145cbea73f..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include - -void MNNBGRAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp deleted file mode 100644 index d46fe6c85b..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNBGRAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m4(count - i); - vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); - vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); - - channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); 
- - channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); - - channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); - __riscv_vse8_v_u8m4(dest + i, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp deleted file mode 100644 index 684db6aed3..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNBRGToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m4(count - i); - vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); - vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); - - channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); - - channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); - - vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); - __riscv_vse8_v_u8m4(dest + i, result, vl); - i += vl; - } -} \ No newline at end of file diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp deleted file mode 100644 index a26243bdb8..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include - -void MNNBilinearLineC8(int8_t* dst, const int16_t* A, const int16_t* B, - const float* t, int8_t* zeroPoint, size_t number) { - int offset = *zeroPoint; - int8_t* dstPtr = dst; - - const int pack = 8; - const int16_t df = (int16_t)((*t) * 128.0f); - const int16_t sf = (int16_t)((1.0f - *t) * 128.0f); - const size_t total = number * pack; - const int32_t ROUND_HALF = 1 << 13; - - size_t vl; - for (size_t i = 0; i < total; i += vl) { - vl = __riscv_vsetvl_e16m4(total - i); - vint16m4_t v16 = __riscv_vle16_v_i16m4(A + i, vl); - vint32m8_t v32 = 
__riscv_vwmul_vx_i32m8(v16, sf, vl); - v16 = __riscv_vle16_v_i16m4(B + i, vl); - v32 = __riscv_vwmacc_vx_i32m8(v32, df, v16, vl); - - vbool4_t mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); - vint32m8_t tmp = __riscv_vadd_vx_i32m8(v32, ROUND_HALF, vl); - v32 = __riscv_vsub_vx_i32m8(v32, ROUND_HALF, vl); - v32 = __riscv_vmerge_vvm_i32m8(tmp, v32, mask, vl); - - tmp = __riscv_vsra_vx_i32m8(v32, 14, vl); - mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); - v32 = __riscv_vand_vx_i32m8(v32, 0x3FFF, vl); - vbool4_t hasRem = __riscv_vmsne_vx_i32m8_b4(v32, 0, vl); - mask = __riscv_vmand_mm_b4(mask, hasRem, vl); - - v32 = __riscv_vadd_vx_i32m8_mu(mask, tmp, tmp, 1, vl); - v32 = __riscv_vadd_vx_i32m8(v32, offset, vl); - v16 = __riscv_vnsra_wx_i16m4(v32, 0, vl); - vint8m2_t v8 = __riscv_vnsra_wx_i8m2(v16, 0, vl); - - __riscv_vse8_v_i8m2(dstPtr + i, v8, vl); - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp deleted file mode 100644 index bd111e3be4..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include - -void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, - const int32_t* position, const float* factor, - int8_t* zeroPoint, size_t number) { - int16_t offset = (int16_t)(*zeroPoint); - const int pack = 8; - size_t i = 0; - - while (i < number) { - size_t vl = __riscv_vsetvl_e32m8(number - i); - vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + i, vl); - vint16m4_t vdf = __riscv_vnsra_wx_i16m4( - __riscv_vfcvt_rtz_x_f_v_i32m8( - __riscv_vfmul_vf_f32m8(vf, 128.0f, vl), vl), 0, vl); - vint16m4_t vsf = __riscv_vnsra_wx_i16m4( - __riscv_vfcvt_rtz_x_f_v_i32m8( - __riscv_vfmul_vf_f32m8( - __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl), 128.0f, vl), vl), 0, vl); - - for (int c = 0; c < pack; c++) { - vuint32m8_t voff = __riscv_vadd_vx_u32m8( - __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 3, vl), - 
c, vl); - - vint16m4_t va = __riscv_vsub_vx_i16m4( - __riscv_vsext_vf2_i16m4( - __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); - - vint32m8_t vr = __riscv_vwmul_vv_i32m8(va, vsf, vl); - voff = __riscv_vadd_vx_u32m8( - __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 3, vl), - c, vl); - - vint16m4_t vb = __riscv_vsub_vx_i16m4( - __riscv_vsext_vf2_i16m4( - __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); - vr = __riscv_vwmacc_vv_i32m8(vr, vb, vdf, vl); - __riscv_vsse16_v_i16m4(dst + i * pack + c, 16, - __riscv_vnsra_wx_i16m4(vr, 0, vl), vl); - } - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp b/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp deleted file mode 100644 index 9d524f13ca..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNC3ToC4(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); - - vuint8m8_t alpha = __riscv_vmv_v_x_u8m8(255, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, alpha, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp b/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp deleted file mode 100644 index f82faf83f5..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp +++ /dev/null @@ -1,48 +0,0 @@ -#include - -void MNNConvRunForLineDepthwise( - float* dst, const float* src, const float* weight, - size_t width, size_t src_w_setup, - size_t 
fw, size_t fh, size_t dilateX_step, size_t dilateY_step, - size_t height, size_t srcHStep, size_t dstHStep, - const float* bias, const float* parameters) { - float minV = parameters[0]; - float maxV = parameters[1]; - ptrdiff_t srcByteStride = src_w_setup * sizeof(float); - ptrdiff_t dstByteStride = 4 * sizeof(float); - - for (size_t y = 0; y < height; ++y) { - const float* srcY = src + y * srcHStep; - float* dstY = dst + y * dstHStep; - size_t dx = 0; - - while (dx < width) { - size_t vl = __riscv_vsetvl_e32m8(width - dx); - - for (int c = 0; c < 4; ++c) { - vfloat32m8_t acc = __riscv_vfmv_v_f_f32m8(bias[c], vl); - const float* srcBase = srcY + dx * src_w_setup + c; - const float* weightPtr = weight + c; - - for (size_t fy = 0; fy < fh; ++fy) { - const float* srcFy = srcBase + fy * dilateY_step; - - for (size_t fx = 0; fx < fw; ++fx) { - float w = *weightPtr; - weightPtr += 4; - const float* srcFx = srcFy + fx * dilateX_step; - vfloat32m8_t s = __riscv_vlse32_v_f32m8(srcFx, srcByteStride, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, w, s, vl); - } - } - - acc = __riscv_vfmax_vf_f32m8(acc, minV, vl); - acc = __riscv_vfmin_vf_f32m8(acc, maxV, vl); - float* dstAddr = dstY + dx * 4 + c; - __riscv_vsse32_v_f32m8(dstAddr, dstByteStride, acc, vl); - } - - dx += vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp b/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp deleted file mode 100644 index 3d8c4f13fc..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp +++ /dev/null @@ -1,22 +0,0 @@ -#include - -void MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { - ptrdiff_t srcStrideByte = srcStride * sizeof(float); - ptrdiff_t dstStrideByte = dstStride * sizeof(float); -size_t vl; - - for (size_t i = count; i > 0; i -= vl) { - vl = __riscv_vsetvl_e32m8(i); - vfloat32m8_t data = __riscv_vlse32_v_f32m8(source + 0, srcStrideByte, vl); - __riscv_vsse32_v_f32m8(dest + 0, 
dstStrideByte, data, vl); - data = __riscv_vlse32_v_f32m8(source + 1, srcStrideByte, vl); - __riscv_vsse32_v_f32m8(dest + 1, dstStrideByte, data, vl); - data = __riscv_vlse32_v_f32m8(source + 2, srcStrideByte, vl); - __riscv_vsse32_v_f32m8(dest + 2, dstStrideByte, data, vl); - data = __riscv_vlse32_v_f32m8(source + 3, srcStrideByte, vl); - __riscv_vsse32_v_f32m8(dest + 3, dstStrideByte, data, vl); - source += vl * srcStride; - dest += vl * dstStride; - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp b/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp deleted file mode 100644 index fd6ce7a274..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp +++ /dev/null @@ -1,53 +0,0 @@ -#include - -void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, - const float* C, const float* D, float* t, - int8_t* zeroPoint, size_t number, - ssize_t minValue, ssize_t maxValue) { - const float f = *t; - const float t2 = f * f, t3 = t2 * f; - const float b0 = 1.0f - 2.25f * t2 + 1.25f * t3; - const float t1 = 1.0f - f, t1_2 = t1 * t1; - const float c0 = 1.0f - 2.25f * t1_2 + 1.25f * t1_2 * t1; - const float ta = 1.0f + f, ta2 = ta * ta; - const float a0 = 3.0f - 6.0f * ta + 3.75f * ta2 - 0.75f * ta2 * ta; - const float td = 2.0f - f, td2 = td * td; - const float d0 = 3.0f - 6.0f * td + 3.75f * td2 - 0.75f * td2 * td; - const int offset = *zeroPoint; - const int minVal = (int)minValue; - const int maxVal = (int)maxValue; - const size_t total = number << 4; - size_t i = 0; - - while (i < total) { - size_t vl = __riscv_vsetvl_e32m8(total - i); - vfloat32m8_t v, acc; - - v = __riscv_vle32_v_f32m8(A + i, vl); - acc = __riscv_vfmul_vf_f32m8(v, a0, vl); - - v = __riscv_vle32_v_f32m8(B + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, b0, v, vl); - - v = __riscv_vle32_v_f32m8(C + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, c0, v, vl); - - v = __riscv_vle32_v_f32m8(D + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, d0, v, vl); - - vfloat32m8_t half = 
__riscv_vfmv_v_f_f32m8(0.5f, vl); - vfloat32m8_t signHalf = __riscv_vfsgnj_vv_f32m8(half, acc, vl); - acc = __riscv_vfadd_vv_f32m8(acc, signHalf, vl); - - vint32m8_t vint = __riscv_vfcvt_rtz_x_f_v_i32m8(acc, vl); - vint = __riscv_vadd_vx_i32m8(vint, offset, vl); - vint = __riscv_vmax_vx_i32m8(vint, minVal, vl); - vint = __riscv_vmin_vx_i32m8(vint, maxVal, vl); - - vint16m4_t vi16 = __riscv_vncvt_x_x_w_i16m4(vint, vl); - vint8m2_t vi8 = __riscv_vncvt_x_x_w_i8m2(vi16, vl); - __riscv_vse8_v_i8m2(dst + i, vi8, vl); - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp b/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp deleted file mode 100644 index 0da63ca0ff..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp +++ /dev/null @@ -1,38 +0,0 @@ -#include - -void MNNCubicLineC4(float* dst, const float* A, const float* B, - const float* C, const float* D, float* t, - int8_t* zeroPoint, size_t number, - ssize_t minValue, ssize_t maxValue) { - const float f = *t; - const float t2 = f * f, t3 = t2 * f; - const float b0 = 1.0f - 2.25f * t2 + 1.25f * t3; - const float t1 = 1.0f - f, t1_2 = t1 * t1; - const float c0 = 1.0f - 2.25f * t1_2 + 1.25f * t1_2 * t1; - const float ta = 1.0f + f, ta2 = ta * ta; - const float a0 = 3.0f - 6.0f * ta + 3.75f * ta2 - 0.75f * ta2 * ta; - const float td = 2.0f - f, td2 = td * td; - const float d0 = 3.0f - 6.0f * td + 3.75f * td2 - 0.75f * td2 * td; - const size_t total = number << 2; - size_t i = 0; - - while (i < total) { - size_t vl = __riscv_vsetvl_e32m8(total - i); - vfloat32m8_t v, acc; - - v = __riscv_vle32_v_f32m8(A + i, vl); - acc = __riscv_vfmul_vf_f32m8(v, a0, vl); - - v = __riscv_vle32_v_f32m8(B + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, b0, v, vl); - - v = __riscv_vle32_v_f32m8(C + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, c0, v, vl); - - v = __riscv_vle32_v_f32m8(D + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, d0, v, vl); - - __riscv_vse32_v_f32m8(dst + i, acc, vl); - i += vl; - } -} 
diff --git a/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp b/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp deleted file mode 100644 index fd5b24a53d..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp +++ /dev/null @@ -1,79 +0,0 @@ -#include - -void MNNCubicSampleC16(const int8_t* src, float* dst, - int32_t* position, const float* factor, - int8_t* zeroPoint, size_t number) { - const int pack = 16; - int8_t zp = *zeroPoint; - size_t i = 0; - - while (i < number) { - size_t vl = __riscv_vsetvl_e32m8(number - i); - vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl); - - for (int c = 0; c < pack; c++) { - vuint32m8_t voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c, vl); - - vint8m2_t vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); - vint16m4_t vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); - vfloat32m8_t vtmp = __riscv_vfcvt_f_x_v_f32m8( - __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); - - vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl); - vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl); - vfloat32m8_t vc = vtmp; - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c, vl); - - vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); - vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); - vfloat32m8_t vB = __riscv_vfcvt_f_x_v_f32m8( - __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); - - va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c, vl); - - vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); - vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); - vtmp = 
__riscv_vfcvt_f_x_v_f32m8( - __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); - - va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl); - vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl); - vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c, vl); - - vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); - vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); - vtmp = __riscv_vfcvt_f_x_v_f32m8( - __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); - - va = __riscv_vfmacc_vf_f32m8(va, 0.75f, vtmp, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, -0.75f, vtmp, vl); - - va = __riscv_vfmadd_vv_f32m8(va, vt, vb, vl); - va = __riscv_vfmadd_vv_f32m8(va, vt, vc, vl); - va = __riscv_vfmadd_vv_f32m8(va, vt, vB, vl); - - __riscv_vsse32_v_f32m8(dst + i * pack + c, pack * sizeof(float), va, vl); - } - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp b/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp deleted file mode 100644 index 78207e69e8..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp +++ /dev/null @@ -1,62 +0,0 @@ -#include - -void MNNCubicSampleC4(const float* src, float* dst, - int32_t* position, const float* factor, - int8_t* zeroPoint, size_t number) { - const int pack = 4; - size_t i = 0; - - while (i < number) { - size_t vl = __riscv_vsetvl_e32m8(number - i); - vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl); - - for (int c = 0; c < pack; c++) { - vuint32m8_t voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vfloat32m8_t vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); - - vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl); - vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl); - 
vfloat32m8_t vc = vtmp; - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vfloat32m8_t vB = __riscv_vluxei32_v_f32m8(src, voff, vl); - - va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); - - va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl); - vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl); - vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); - - va = __riscv_vfmacc_vf_f32m8(va, 0.75f, vtmp, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, -0.75f, vtmp, vl); - - va = __riscv_vfmadd_vv_f32m8(va, vt, vb, vl); - va = __riscv_vfmadd_vv_f32m8(va, vt, vc, vl); - va = __riscv_vfmadd_vv_f32m8(va, vt, vB, vl); - - __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, va, vl); - } - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp b/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp deleted file mode 100644 index 6658715e7e..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include - -void MNNDeconvRunForUnitDepthWise( - const float* dst, float* src, const float* weight, - size_t fw, size_t fh, - size_t weightY_step, size_t dilateX_step, size_t dilateY_step) { - const ptrdiff_t wStride = 4 * sizeof(float); - const ptrdiff_t sStride = dilateX_step * sizeof(float); - float d0 = dst[0], d1 = 
dst[1], d2 = dst[2], d3 = dst[3]; - - for (size_t fy = 0; fy < fh; ++fy) { - float* srcY = src + fy * dilateY_step; - const float* weightY = weight + fy * weightY_step; - - size_t fx = 0; - while (fx < fw) { - size_t vl = __riscv_vsetvl_e32m8(fw - fx); - - vfloat32m8_t w = __riscv_vlse32_v_f32m8(weightY + 0 + fx * 4, wStride, vl); - vfloat32m8_t s = __riscv_vlse32_v_f32m8(srcY + 0 + fx * dilateX_step, sStride, vl); - s = __riscv_vfmacc_vf_f32m8(s, d0, w, vl); - __riscv_vsse32_v_f32m8(srcY + 0 + fx * dilateX_step, sStride, s, vl); - - w = __riscv_vlse32_v_f32m8(weightY + 1 + fx * 4, wStride, vl); - s = __riscv_vlse32_v_f32m8(srcY + 1 + fx * dilateX_step, sStride, vl); - s = __riscv_vfmacc_vf_f32m8(s, d1, w, vl); - __riscv_vsse32_v_f32m8(srcY + 1 + fx * dilateX_step, sStride, s, vl); - - w = __riscv_vlse32_v_f32m8(weightY + 2 + fx * 4, wStride, vl); - s = __riscv_vlse32_v_f32m8(srcY + 2 + fx * dilateX_step, sStride, vl); - s = __riscv_vfmacc_vf_f32m8(s, d2, w, vl); - __riscv_vsse32_v_f32m8(srcY + 2 + fx * dilateX_step, sStride, s, vl); - - w = __riscv_vlse32_v_f32m8(weightY + 3 + fx * 4, wStride, vl); - s = __riscv_vlse32_v_f32m8(srcY + 3 + fx * dilateX_step, sStride, vl); - s = __riscv_vfmacc_vf_f32m8(s, d3, w, vl); - __riscv_vsse32_v_f32m8(srcY + 3 + fx * dilateX_step, sStride, s, vl); - - fx += vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp deleted file mode 100644 index 952fcaf090..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include - -void MNNGRAYToC3(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); - __riscv_vsse8_v_u8m8(dest + i * 3 + 0, 3, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 3 + 1, 3, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 3 + 2, 3, gray, vl); - i += vl; - } -} diff 
--git a/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp deleted file mode 100644 index 5ee4540f98..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include - -void MNNGRAYToC4(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); - vuint8m8_t alpha = __riscv_vmv_v_x_u8m8(255, vl); - __riscv_vsse8_v_u8m8(dest + i * 4 + 0, 4, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 4 + 1, 4, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 4 + 2, 4, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 4 + 3, 4, alpha, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp b/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp deleted file mode 100644 index 183a38bb10..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include -#include - -#define UNIT 4 - -void MNNMaxFloat(float *input, float *maxBuffer, int32_t inputCountUnit) { - const float init = -FLT_MAX; - for (int j = 0; j < UNIT; ++j) { - float local = init; - size_t i = 0; - - while (i < (size_t)inputCountUnit) { - size_t vl = __riscv_vsetvl_e32m8(inputCountUnit - i); - float *p0 = input + (i * UNIT * 2) + j * 2; - float *p1 = p0 + 1; - vfloat32m8_t v0 = __riscv_vlse32_v_f32m8(p0, UNIT * 2 * sizeof(float), vl); - vfloat32m8_t v1 = __riscv_vlse32_v_f32m8(p1, UNIT * 2 * sizeof(float), vl); - vfloat32m8_t vmax = __riscv_vfmax_vv_f32m8(v0, v1, vl); - vfloat32m1_t vred = __riscv_vfredmax_vs_f32m8_f32m1(vmax, __riscv_vfmv_s_f_f32m1(local, 1), vl); - local = __riscv_vfmv_f_s_f32m1_f32(vred); - i += vl; - } - maxBuffer[j] = local; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp b/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp deleted file mode 100644 index 9e8ade8641..0000000000 --- 
a/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include -#include - -#define UNIT 4 - -void MNNMinFloat(float *input, float *minBuffer, int32_t inputCountUnit) { - const float init = FLT_MAX; - for (int j = 0; j < UNIT; ++j) { - float local = init; - size_t i = 0; - - while (i < (size_t)inputCountUnit) { - size_t vl = __riscv_vsetvl_e32m8(inputCountUnit - i); - float *p0 = input + (i * UNIT * 2) + j * 2; - float *p1 = p0 + 1; - vfloat32m8_t v0 = __riscv_vlse32_v_f32m8(p0, UNIT * 2 * sizeof(float), vl); - vfloat32m8_t v1 = __riscv_vlse32_v_f32m8(p1, UNIT * 2 * sizeof(float), vl); - vfloat32m8_t vmin = __riscv_vfmin_vv_f32m8(v0, v1, vl); - vfloat32m1_t vred = __riscv_vfredmin_vs_f32m8_f32m1(vmin, __riscv_vfmv_s_f_f32m1(local, 1), vl); - local = __riscv_vfmv_f_s_f32m1_f32(vred); - i += vl; - } - minBuffer[j] = local; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNPackC2.cpp b/source/backend/cpu/riscv/rvv/MNNPackC2.cpp deleted file mode 100644 index 9a74f8998d..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNPackC2.cpp +++ /dev/null @@ -1,74 +0,0 @@ -#include - -void MNNPackC2(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { - int depthC2 = depth / 2; - int depthRemain = depthC2 * 2; - int remain = depth - depthRemain; - const float *srcOffset = src; - const float *srcChannel[2]; - - for (int z = 0; z < depthC2; ++z) { - float *dstZ = dst + z * areaOffset[1] * 2; - - for (int y = 0; y < 2; ++y) { - srcChannel[y] = srcOffset + areaOffset[0] * y; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - float *dstPtr = dstZ + x * 2; - vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 0, 2 * sizeof(float), vec, vl); - vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 1, 2 * sizeof(float), vec, vl); - } - - for (; x < area; ++x) { - float *dstPtr = dstZ + x * 2; - dstPtr[0] = 
srcChannel[0][x]; - dstPtr[1] = srcChannel[1][x]; - } - - srcOffset += areaOffset[0] * 2; - } - - if (remain > 0) { - float *dstZ = dst + depthC2 * areaOffset[1] * 2; - - for (int y = 0; y < remain; ++y) { - srcChannel[y] = srcOffset + areaOffset[0] * y; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - float *dstPtr = dstZ + x * 2; - - for (int y = 0; y < remain; ++y) { - vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + y, 2 * sizeof(float), vec, vl); - } - - vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); - for (int y = remain; y < 2; ++y) { - __riscv_vsse32_v_f32m8(dstPtr + y, 2 * sizeof(float), zero, vl); - } - } - - for (; x < area; ++x) { - float *dstPtr = dstZ + x * 2; - - for (int y = 0; y < remain; ++y) { - dstPtr[y] = srcChannel[y][x]; - } - - for (int y = remain; y < 2; ++y) { - dstPtr[y] = 0.0f; - } - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNPackC4.cpp b/source/backend/cpu/riscv/rvv/MNNPackC4.cpp deleted file mode 100644 index 024e2c8c07..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNPackC4.cpp +++ /dev/null @@ -1,80 +0,0 @@ -#include - -void MNNPackC4(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { - int depthC4 = depth / 4; - int depthRemain = depthC4 * 4; - int remain = depth - depthRemain; - const float *srcOffset = src; - const float *srcChannel[4]; - - for (int z = 0; z < depthC4; ++z) { - float *dstZ = dst + z * areaOffset[1] * 4; - - for (int y = 0; y < 4; ++y) { - srcChannel[y] = srcOffset + areaOffset[0] * y; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - float *dstPtr = dstZ + x * 4; - vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 0, 4 * sizeof(float), vec, vl); - vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 1, 4 * sizeof(float), vec, vl); - 
vec = __riscv_vle32_v_f32m8(srcChannel[2] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 2, 4 * sizeof(float), vec, vl); - vec = __riscv_vle32_v_f32m8(srcChannel[3] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 3, 4 * sizeof(float), vec, vl); - } - - for (; x < area; ++x) { - float *dstPtr = dstZ + x * 4; - dstPtr[0] = srcChannel[0][x]; - dstPtr[1] = srcChannel[1][x]; - dstPtr[2] = srcChannel[2][x]; - dstPtr[3] = srcChannel[3][x]; - } - - srcOffset += areaOffset[0] * 4; - } - - if (remain > 0) { - float *dstZ = dst + depthC4 * areaOffset[1] * 4; - - for (int y = 0; y < remain; ++y) { - srcChannel[y] = srcOffset + areaOffset[0] * y; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - float *dstPtr = dstZ + x * 4; - - for (int y = 0; y < remain; ++y) { - vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), vec, vl); - } - - vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); - for (int y = remain; y < 4; ++y) { - __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), zero, vl); - } - } - - for (; x < area; ++x) { - float *dstPtr = dstZ + x * 4; - - for (int y = 0; y < remain; ++y) { - dstPtr[y] = srcChannel[y][x]; - } - - for (int y = remain; y < 4; ++y) { - dstPtr[y] = 0.0f; - } - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp deleted file mode 100644 index f2b6c7a78d..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include - -void MNNRGBAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, 
vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp deleted file mode 100644 index ddd67a7d8c..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNRGBAToBGRA(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 3, 4, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp deleted file mode 100644 index d56b58546d..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNRGBAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m4(count - i); - vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); - vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); - - channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); - - channel = __riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); - - channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); - __riscv_vse8_v_u8m4(dest + i, channel, vl); - i += vl; - } -} 
diff --git a/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp deleted file mode 100644 index 7c6decf39e..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include - -void MNNRGBToBGR(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp deleted file mode 100644 index 1b946c33cc..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNRGBToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m4(count - i); - vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); - vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); - - channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); - - channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); - - vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); - __riscv_vse8_v_u8m4(dest + i, result, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp b/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp deleted file mode 100644 index 262f4cbfab..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp +++ /dev/null @@ -1,45 
+0,0 @@ -#include - -void MNNReluWithSlopeChannel(float *dst, const float *src, - const float *slope, size_t sizeQuad, - size_t depthQuad) { - const ptrdiff_t stride = 4 * sizeof(float); - - for (size_t j = 0; j < depthQuad; ++j) { - const float *srcZ = src + 4 * j * sizeQuad; - float *dstZ = dst + 4 * j * sizeQuad; - float s0 = slope[4*j], s1 = slope[4*j + 1]; - float s2 = slope[4*j + 2], s3 = slope[4*j + 3]; - size_t i = 0; - while (i < sizeQuad) { - size_t vl = __riscv_vsetvl_e32m8(sizeQuad - i); - const float *srcBase = srcZ + 4*i; - float *dstBase = dstZ + 4*i; - - vfloat32m8_t v; - vbool4_t mask; - - v = __riscv_vlse32_v_f32m8(srcBase, stride, vl); - mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); - v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s0, vl); - __riscv_vsse32_v_f32m8(dstBase, stride, v, vl); - - v = __riscv_vlse32_v_f32m8(srcBase + 1, stride, vl); - mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); - v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s1, vl); - __riscv_vsse32_v_f32m8(dstBase + 1, stride, v, vl); - - v = __riscv_vlse32_v_f32m8(srcBase + 2, stride, vl); - mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); - v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s2, vl); - __riscv_vsse32_v_f32m8(dstBase + 2, stride, v, vl); - - v = __riscv_vlse32_v_f32m8(srcBase + 3, stride, vl); - mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); - v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s3, vl); - __riscv_vsse32_v_f32m8(dstBase + 3, stride, v, vl); - - i += vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp b/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp deleted file mode 100644 index 10992f9d59..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include - -void MNNScaleAndAddBias(float *dst, const float *src, const float *bias, const float *alpha, size_t planeNumber, size_t biasNumber) { - const ptrdiff_t stride = 4 * sizeof(float); - - for (size_t z = 0; z < biasNumber; ++z) { - float *dstZ = dst 
+ z * planeNumber * 4; - const float *srcZ = src + z * planeNumber * 4; - const float *biasZ = bias + 4 * z; - const float *alphaZ = alpha + 4 * z; - float b0 = biasZ[0], b1 = biasZ[1], b2 = biasZ[2], b3 = biasZ[3]; - float a0 = alphaZ[0], a1 = alphaZ[1], a2 = alphaZ[2], a3 = alphaZ[3]; - - size_t n = planeNumber; - while (n > 0) { - size_t vl = __riscv_vsetvl_e32m8(n); - vfloat32m8_t data = __riscv_vlse32_v_f32m8(srcZ + 0, stride, vl); - data = __riscv_vfmul_vf_f32m8(data, a0, vl); - data = __riscv_vfadd_vf_f32m8(data, b0, vl); - __riscv_vsse32_v_f32m8(dstZ + 0, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(srcZ + 1, stride, vl); - data = __riscv_vfmul_vf_f32m8(data, a1, vl); - data = __riscv_vfadd_vf_f32m8(data, b1, vl); - __riscv_vsse32_v_f32m8(dstZ + 1, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(srcZ + 2, stride, vl); - data = __riscv_vfmul_vf_f32m8(data, a2, vl); - data = __riscv_vfadd_vf_f32m8(data, b2, vl); - __riscv_vsse32_v_f32m8(dstZ + 2, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(srcZ + 3, stride, vl); - data = __riscv_vfmul_vf_f32m8(data, a3, vl); - data = __riscv_vfadd_vf_f32m8(data, b3, vl); - __riscv_vsse32_v_f32m8(dstZ + 3, stride, data, vl); - - srcZ += vl * 4; - dstZ += vl * 4; - n -= vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp b/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp deleted file mode 100644 index f510058c83..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp +++ /dev/null @@ -1,80 +0,0 @@ -#include -#include - -void MNNSoftmax(float *dest, const float *source, size_t size) { - size_t n = size; - const float *sourcePtr = source; - float *destPtr = dest; - float maxValue = -FLT_MAX; - vfloat32m1_t maxVecValue = __riscv_vfmv_s_f_f32m1(maxValue, 1); - - while (n > 0) { - size_t vl = __riscv_vsetvl_e32m8(n); - vfloat32m8_t vSrc = __riscv_vle32_v_f32m8(sourcePtr, vl); - maxVecValue = __riscv_vfredmax_vs_f32m8_f32m1(vSrc, maxVecValue, vl); - sourcePtr += vl; - n -= vl; - } - - 
maxValue = __riscv_vfmv_f_s_f32m1_f32(maxVecValue); - const float param = 0.6931471805599453f; - const float xLimit = 87.0f; - float sumValue = 0.f; - vfloat32m1_t sumVecValue = __riscv_vfmv_s_f_f32m1(sumValue, 1); - n = size; - sourcePtr = source; - destPtr = dest; - - while (n > 0) { - size_t vl = __riscv_vsetvl_e32m8(n); - vfloat32m8_t vA = __riscv_vle32_v_f32m8(sourcePtr, vl); - vA = __riscv_vfsub_vf_f32m8(vA, maxValue, vl); - vA = __riscv_vfmax_vf_f32m8(vA, -xLimit, vl); - vA = __riscv_vfmin_vf_f32m8(vA, xLimit, vl); - - vfloat32m8_t vB = __riscv_vfdiv_vf_f32m8(vA, param, vl); - vint32m8_t vBI = __riscv_vfcvt_x_f_v_i32m8(vB, vl); - - vfloat32m8_t vC = __riscv_vreinterpret_v_i32m8_f32m8( - __riscv_vsll_vx_i32m8( - __riscv_vadd_vx_i32m8(vBI, 127, vl), 23, vl)); - - vB = __riscv_vfcvt_f_x_v_f32m8(vBI, vl); - vB = __riscv_vfnmsub_vf_f32m8(vB, param, vA, vl); - - vA = __riscv_vfmv_v_f_f32m8(1.0f / 120.0f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 1.0f / 24.0f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 1.0f / 6.0f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 0.5f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl); - - vA = __riscv_vfmul_vv_f32m8(vC, vA, vl); - __riscv_vse32_v_f32m8(destPtr, vA, vl); - sumVecValue = __riscv_vfredosum_vs_f32m8_f32m1(vA, sumVecValue, vl); - - sourcePtr += vl; - destPtr += vl; - n -= vl; - } - - sumValue = __riscv_vfmv_f_s_f32m1_f32(sumVecValue); - float sumInv = 1.0f / sumValue; - n = size; - destPtr = dest; - - while (n > 0) - { - size_t vl = __riscv_vsetvl_e32m8(n); - vfloat32m8_t vDest = __riscv_vle32_v_f32m8(destPtr, vl); - vDest = __riscv_vfmul_vf_f32m8(vDest, sumInv, vl); - __riscv_vse32_v_f32m8(destPtr, vDest, vl); - destPtr += vl; - n -= vl; - } -} diff --git 
a/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp b/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp deleted file mode 100644 index 8ab5bb89fa..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include - -void MNNStrassenMergeCFunction(float *c11, float *c12, float *c21, float *c22, - float *xAddr, size_t cStride, size_t eSub, size_t hSub) { - for (int y = 0; y < hSub; ++y) { - float *c11Y = c11 + y * cStride; - float *c12Y = c12 + y * cStride; - float *c22Y = c22 + y * cStride; - float *c21Y = c21 + y * cStride; - float *xY = xAddr + y * eSub * 4; - size_t totalElements = eSub * 4; - size_t p = 0; - - while (p < totalElements) { - size_t vl = __riscv_vsetvl_e32m8(totalElements - p); - vfloat32m8_t t = __riscv_vle32_v_f32m8(xY + p, vl); - vfloat32m8_t tmp = __riscv_vle32_v_f32m8(c12Y + p, vl); - t = __riscv_vfadd_vv_f32m8(t, tmp, vl); - vfloat32m8_t c22v = __riscv_vle32_v_f32m8(c22Y + p, vl); - - tmp = __riscv_vle32_v_f32m8(c11Y + p, vl); - tmp = __riscv_vfadd_vv_f32m8(tmp, c22v, vl); - tmp = __riscv_vfadd_vv_f32m8(tmp, t, vl); - __riscv_vse32_v_f32m8(c12Y + p, tmp, vl); - - tmp = __riscv_vle32_v_f32m8(c21Y + p, vl); - tmp = __riscv_vfadd_vv_f32m8(t, tmp, vl); - __riscv_vse32_v_f32m8(c21Y + p, tmp, vl); - - c22v = __riscv_vfadd_vv_f32m8(c22v, tmp, vl); - __riscv_vse32_v_f32m8(c22Y + p, c22v, vl); - - p += vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp b/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp deleted file mode 100644 index 7598d6f8ac..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#include - -void MNNTranspose16Bit(int16_t* dstO, const int16_t* srcO, int16_t* dim) { - int w = dim[0]; - int h = dim[1]; - int srcStride = dim[2]; - int dstStride = dim[3]; - ptrdiff_t srcStrideByte = srcStride * sizeof(int16_t); - - for (int i = 0; i < h; ++i) { - const int16_t* srcPtr = srcO + i; - int16_t* 
dstPtr = dstO + i * dstStride; - - int j = 0; - while (j < w) { - size_t vl = __riscv_vsetvl_e16m8(w - j); - vint16m8_t data = __riscv_vlse16_v_i16m8(srcPtr, srcStrideByte, vl); - __riscv_vse16_v_i16m8(dstPtr, data, vl); - srcPtr += vl * srcStride; - dstPtr += vl; - j += vl; - } - } -} - - diff --git a/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp b/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp deleted file mode 100644 index e5c5eb83e6..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include - -void MNNTranspose32Bit(int32_t* dstO, const int32_t* srcO, int32_t* dim) { - int w = dim[0]; - int h = dim[1]; - int srcStride = dim[2]; - int dstStride = dim[3]; - ptrdiff_t srcStrideByte = srcStride * sizeof(int32_t); - - for (int i = 0; i < h; ++i) { - const int32_t* srcPtr = srcO + i; - int32_t* dstPtr = dstO + i * dstStride; - - int j = 0; - while (j < w) { - size_t vl = __riscv_vsetvl_e32m8(w - j); - vint32m8_t data = __riscv_vlse32_v_i32m8(srcPtr, srcStrideByte, vl); - __riscv_vse32_v_i32m8(dstPtr, data, vl); - srcPtr += vl * srcStride; - dstPtr += vl; - j += vl; - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp b/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp deleted file mode 100644 index 4676e6dede..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include - -void MNNUnpackC4(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { - int depthC4 = depth / 4; - int depthRemain = depthC4 * 4; - int remain = depth - depthRemain; - const float *srcOffset = src; - - for (int z = 0; z < depthC4; ++z) { - float *dstZ[4]; - - for (int y = 0; y < 4; ++y) { - dstZ[y] = dst + (z * 4 + y) * areaOffset[1]; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - vfloat32m8_t vec = __riscv_vlse32_v_f32m8(srcOffset + 0, 4 * sizeof(float), vl); - __riscv_vse32_v_f32m8(dstZ[0] + x, vec, vl); 
- vec = __riscv_vlse32_v_f32m8(srcOffset + 1, 4 * sizeof(float), vl); - __riscv_vse32_v_f32m8(dstZ[1] + x, vec, vl); - vec = __riscv_vlse32_v_f32m8(srcOffset + 2, 4 * sizeof(float), vl); - __riscv_vse32_v_f32m8(dstZ[2] + x, vec, vl); - vec = __riscv_vlse32_v_f32m8(srcOffset + 3, 4 * sizeof(float), vl); - __riscv_vse32_v_f32m8(dstZ[3] + x, vec, vl); - srcOffset += 4 * vl; - } - - for (; x < area; ++x) { - dstZ[0][x] = srcOffset[0]; - dstZ[1][x] = srcOffset[1]; - dstZ[2][x] = srcOffset[2]; - dstZ[3][x] = srcOffset[3]; - srcOffset += (areaOffset[0] - area) * 4; - } - } - - if (remain > 0) { - float *dstZ = dst + depthC4 * areaOffset[1] * 4; - const float *srcBase = srcOffset; - - for (int y = 0; y < remain; ++y) { - float *dstChannel = dstZ + y * areaOffset[1]; - const float *srcChannel = srcBase + y; - - for (size_t x = 0; x < area; ++x) { - dstChannel[x] = srcChannel[0]; - srcChannel += 4; - } - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp b/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp deleted file mode 100644 index 7332360ce8..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include -#include - -#define UNIT 4 - -void MNNVectorTop1Float(float* input, float* maxValue, int32_t* maxIndex, size_t inputCountUnit) { - size_t n = inputCountUnit * UNIT; - float maxV = -FLT_MAX; - int32_t maxIdx = 0; - size_t vl; - - for (size_t i = 0; i < n; ) { - vl = __riscv_vsetvl_e32m8(n - i); - vfloat32m8_t data = __riscv_vle32_v_f32m8(input + i, vl); - vfloat32m1_t scalar = __riscv_vfmv_s_f_f32m1(maxV, vl); - vfloat32m1_t result = __riscv_vfredmax_vs_f32m8_f32m1(data, scalar, vl); - maxV = __riscv_vfmv_f_s_f32m1_f32(result); - i += vl; - } - - for (size_t i = 0; i < n; ) { - vl = __riscv_vsetvl_e32m8(n - i); - vfloat32m8_t data = __riscv_vle32_v_f32m8(input + i, vl); - vbool4_t mask = __riscv_vmfeq_vf_f32m8_b4(data, maxV, vl); - long first = __riscv_vfirst_m_b4(mask, vl); - - if (first >= 0) { 
- maxIdx = i + first; - break; - } - - i += vl; - } - - maxValue[0] = maxV; - maxIndex[0] = maxIdx; -} diff --git a/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp b/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp deleted file mode 100644 index 8c199709ec..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include -#include - -#define UNIT 4 - -void MNNVectorTop1Int32(int32_t* input, int32_t* maxValue, int32_t* maxIndex, size_t inputCountUnit) { - size_t n = inputCountUnit * UNIT; - int32_t maxV = INT32_MIN; - int32_t maxIdx = 0; - size_t vl; - - for (size_t i = 0; i < n; ) { - vl = __riscv_vsetvl_e32m8(n - i); - vint32m8_t data = __riscv_vle32_v_i32m8(input + i, vl); - vint32m1_t scalar = __riscv_vmv_s_x_i32m1(maxV, vl); - vint32m1_t result = __riscv_vredmax_vs_i32m8_i32m1(data, scalar, vl); - maxV = __riscv_vmv_x_s_i32m1_i32(result); - i += vl; - } - - for (size_t i = 0; i < n; ) { - vl = __riscv_vsetvl_e32m8(n - i); - vint32m8_t data = __riscv_vle32_v_i32m8(input + i, vl); - vbool4_t mask = __riscv_vmseq_vx_i32m8_b4(data, maxV, vl); - long first = __riscv_vfirst_m_b4(mask, vl); - - if (first >= 0) { - maxIdx = i + first; - break; - } - - i += vl; - } - - maxValue[0] = maxV; - maxIndex[0] = maxIdx; -} diff --git a/source/core/Backend.hpp b/source/core/Backend.hpp index 6850b6b4f6..bcf618c3c9 100644 --- a/source/core/Backend.hpp +++ b/source/core/Backend.hpp @@ -68,11 +68,9 @@ struct RuntimeHint { // whether to use Arm sme2 cores when threads>1 bool useArmSme2Cores = true; -#ifdef MNN_DEFAULT_USE_KLEIDIAI - bool enableKleidiAI = true; -#else + bool enableKleidiAI = false; -#endif + // Use CPU Ids std::vector cpuIds; diff --git a/source/core/Concurrency.h b/source/core/Concurrency.h index 7c06625fe4..73f5984e5a 100644 --- a/source/core/Concurrency.h +++ b/source/core/Concurrency.h @@ -12,9 +12,6 @@ #define LAUNCH_MULTI_THREADS_WORKLOAD 1e+5 #ifdef MNN_FORBIT_MULTI_THREADS -#define 
MNN_CONCURRENCY_ENQUEUE(task) \ -for (int __iter__ = 0; __iter__ < task.second; __iter__++) {task.first(__iter__);} - #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) for (int __iter__ = 0; __iter__ < __num__; __iter__++) { #define MNN_CONCURRENCY_END() } @@ -22,8 +19,6 @@ for (int __iter__ = 0; __iter__ < task.second; __iter__++) {task.first(__iter__) #include "backend/cpu/ThreadPool.hpp" #define MNN_STRINGIFY(a) #a -#define MNN_CONCURRENCY_ENQUEUE(task) ((CPUBackend*)backend())->enqueue(task) - #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) \ { \ std::pair, int> task; \ @@ -33,7 +28,8 @@ for (int __iter__ = 0; __iter__ < task.second; __iter__++) {task.first(__iter__) } \ ; \ auto cpuBn = (CPUBackend*)backend(); \ - cpuBn->enqueue(task); \ + auto thrPl = cpuBn->threadPool(); \ + thrPl->enqueue(std::move(task), cpuBn->taskIndex()); \ } #else @@ -42,9 +38,6 @@ for (int __iter__ = 0; __iter__ < task.second; __iter__++) {task.first(__iter__) #include #include -#define MNN_CONCURRENCY_ENQUEUE(task) \ -dispatch_apply(task.second, dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH, 0), ^(size_t __iter__) {task.first(__iter__);}); - #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) \ dispatch_apply(__num__, dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH, 0), ^(size_t __iter__) { #define MNN_CONCURRENCY_END() \ @@ -65,8 +58,6 @@ dispatch_apply(__num__, dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH, // Android #else #include -#define MNN_CONCURRENCY_ENQUEUE(task) \ -_Pragma("omp parallel for") for (int __iter__ = 0; __iter__ < task.second; __iter__++) {task.first(__iter__);} #define MNN_STRINGIFY(a) #a #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) \ diff --git a/source/core/OpCommonUtils.cpp b/source/core/OpCommonUtils.cpp index a69263ffaa..c80afaef87 100644 --- a/source/core/OpCommonUtils.cpp +++ b/source/core/OpCommonUtils.cpp @@ -386,7 +386,98 @@ void OpCommonUtils::broastCastComputeDim(int* dims, int* stride, int* iStride0, } } } +std::vector> 
OpCommonUtils::computeReduceDims(const std::vector& inputs, + const Op* op) { + // Compute axises + std::vector axises; + if (inputs.size() >= 2) { + auto size = inputs[1]->elementSize(); + auto dims = inputs[1]->host(); + for (int i = 0; i < size; ++i) { + axises.emplace_back(dims[i]); + } + } else { + auto reduct = op->main_as_ReductionParam(); + if (nullptr != reduct->dim()) { + for (int i = 0; i < reduct->dim()->size(); ++i) { + axises.emplace_back(reduct->dim()->data()[i]); + } + } + } + auto totalSize = TensorUtils::getRawSize(inputs[0]); + if (axises.empty()) { + return {std::make_tuple(1, totalSize, 1)}; + } + for (int i = 0; i < axises.size(); ++i) { + if (axises[i] < 0) { + axises[i] = inputs[0]->dimensions() + axises[i]; + if (axises[i] < 0) { + return {std::make_tuple(1, totalSize, 1)}; + } + } + } + // Cache for input's dims + std::vector lengths(inputs[0]->dimensions()); + for (int i = 0; i < lengths.size(); ++i) { + lengths[i] = inputs[0]->length(i); + } + std::vector> groupAxises; + { + // Merge adj axis + std::sort(axises.begin(), axises.end()); + int lastAxis = axises[0]; + int length = 1; + int start = axises[0]; + for (int i = 1; i < axises.size(); ++i) { + // MNN_PRINT("%d - %d\n", axises[i], lastAxis); + if (axises[i] - lastAxis == 1) { + length++; + } else { + groupAxises.emplace_back(std::make_pair(start, length)); + length = 1; + start = axises[i]; + } + lastAxis = axises[i]; + } + groupAxises.emplace_back(std::make_pair(start, length)); + } + + // Compute inside-outside-axis + std::vector> result; + for (int i = 0; i < groupAxises.size(); ++i) { + int outsideSize = 1; + int insideSize = 1; + int axisSize = 1; + auto start = groupAxises[i].first; + auto length = groupAxises[i].second; + if (start >= (int)lengths.size()) { + break; + } + for (int j = 0; j < start; ++j) { + outsideSize *= lengths[j]; + } + for (int j = start; j < start + length; ++j) { + if (j >= (int)lengths.size()) { + break; + } + axisSize *= lengths[j]; + lengths[j] = 1; 
+ } + for (int j = start + length; j < lengths.size(); ++j) { + insideSize *= lengths[j]; + } + if (1 == axisSize) { + continue; + } + result.emplace_back(std::make_tuple(outsideSize, axisSize, insideSize)); + } + // FUNC_PRINT(result.size()); + if (result.empty()) { + result.emplace_back(std::make_tuple(1, 1, totalSize)); + } + return result; +} void OpCommonUtils::unravelIndexHelper(int32_t* coordinate, const int32_t* mod, int size, int indice) { int value = indice; diff --git a/source/core/OpCommonUtils.hpp b/source/core/OpCommonUtils.hpp index 8ec0628336..0740cc16b2 100644 --- a/source/core/OpCommonUtils.hpp +++ b/source/core/OpCommonUtils.hpp @@ -56,6 +56,7 @@ class MNN_PUBLIC OpCommonUtils { static bool supportDynamicInputMemory(MNNForwardType type); static void broastCastComputeDim(int* dims, int* stride, int* iStride0, int* iStride1, const Tensor* input0, const Tensor* input1, const Tensor* output); + static std::vector> computeReduceDims(const std::vector& inputs, const Op* op); static void unravelIndexHelper(int32_t* coordinate, const int32_t* mod, int size, int indice); static int computeStride(int32_t* strides, const int* shape, int length); diff --git a/source/core/TensorUtils.cpp b/source/core/TensorUtils.cpp index d233fc9d89..ae5b87143c 100644 --- a/source/core/TensorUtils.cpp +++ b/source/core/TensorUtils.cpp @@ -32,18 +32,6 @@ bool TensorUtils::regionIsFull(Tensor* input) { return regionSize == size; } -void TensorUtils::makeFullRef(Tensor* output, Tensor* input) { - auto des = TensorUtils::getDescribe(input); - auto outputDes = TensorUtils::getDescribe(output); - outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; - if (des->memoryType == Tensor::InsideDescribe::MEMORY_VIRTUAL) { - outputDes->regions = des->regions; - } else { - outputDes->regions = {makeFullSlice(input)}; - } -} - - Tensor::InsideDescribe::Region TensorUtils::makeFullSlice(Tensor* input) { Tensor::InsideDescribe::Region totalSlice; totalSlice.src.offset = 0; diff 
--git a/source/core/TensorUtils.hpp b/source/core/TensorUtils.hpp index a577fea05f..1342a669bd 100644 --- a/source/core/TensorUtils.hpp +++ b/source/core/TensorUtils.hpp @@ -184,7 +184,6 @@ class MNN_PUBLIC TensorUtils { static void setupTensorInfo(const Tensor* tensor, Tensor* wrapTensor, MNN_DATA_FORMAT mMidFormat); static Tensor::InsideDescribe::Region makeFullSlice(Tensor* input); - static void makeFullRef(Tensor* output, Tensor* input); static bool regionIsFull(Tensor* input); static bool isCopyRegion(const Tensor::InsideDescribe::Region& region); static bool isTransposeRegion(const Tensor::InsideDescribe::Region& region); diff --git a/source/geometry/GeometryComputerUtils.cpp b/source/geometry/GeometryComputerUtils.cpp index 85f64de55d..01a4e02ea2 100644 --- a/source/geometry/GeometryComputerUtils.cpp +++ b/source/geometry/GeometryComputerUtils.cpp @@ -477,9 +477,9 @@ std::shared_ptr GeometryComputerUtils::makeBinary(int type, Tensor* inp return cmdP; } -std::shared_ptr GeometryComputerUtils::makeReduce(ReductionType type, Tensor* input0, Tensor* output, int axis) { +std::shared_ptr GeometryComputerUtils::makeReduce(ReductionType type, Tensor* input0, Tensor* output) { flatbuffers::FlatBufferBuilder builder(DEFAULT_ALLOCATE_SIZE); - auto vec = builder.CreateVector(std::vector{axis}); + auto vec = builder.CreateVector(std::vector{1}); ReductionParamBuilder builder_(builder); builder_.add_operation(type); builder_.add_keepDims(true); diff --git a/source/geometry/GeometryComputerUtils.hpp b/source/geometry/GeometryComputerUtils.hpp index 97c4d5811f..c0dffdcdb1 100644 --- a/source/geometry/GeometryComputerUtils.hpp +++ b/source/geometry/GeometryComputerUtils.hpp @@ -18,7 +18,7 @@ class GeometryComputerUtils { static void addConvert(const CommandBuffer& srcBuffer, CommandBuffer& dstBuffer, GeometryComputer::Context& ctx); static std::shared_ptr makeCommand(flatbuffers::FlatBufferBuilder& builder, const std::vector& inputs, const std::vector& outputs); static 
std::shared_ptr makeBinary(int type, Tensor* input0, Tensor* input1, Tensor* output); - static std::shared_ptr makeReduce(ReductionType type, Tensor* input0, Tensor* output, int axis = 1); + static std::shared_ptr makeReduce(ReductionType type, Tensor* input0, Tensor* output); static std::shared_ptr makeUnary(UnaryOpOperation type, Tensor* input0, Tensor* output); static std::shared_ptr makeLayerNorm(Tensor* input0, Tensor* output, std::vector axis, float epsilon, std::vector gamma, std::vector beta, std::vector external, int group = 1, bool useRMS = false); static std::shared_ptr makeMatMul(Tensor* input0, Tensor* input1, Tensor* output, Tensor* Bias = nullptr, diff --git a/source/geometry/GeometryReduce.cpp b/source/geometry/GeometryReduce.cpp index 855f4bcf69..c2a3bb4114 100644 --- a/source/geometry/GeometryReduce.cpp +++ b/source/geometry/GeometryReduce.cpp @@ -10,83 +10,6 @@ #include "geometry/GeometryComputerUtils.hpp" #include "core/OpCommonUtils.hpp" namespace MNN { -static std::vector> _computeReduceDims(const std::vector& inputs, - std::vector& axises) { - - auto totalSize = TensorUtils::getRawSize(inputs[0]); - if (axises.empty()) { - return {std::make_tuple(1, totalSize, 1)}; - } - for (int i = 0; i < axises.size(); ++i) { - if (axises[i] < 0) { - if (axises[i] < 0) { - return {std::make_tuple(1, totalSize, 1)}; - } - } - } - // Cache for input's dims - std::vector lengths(inputs[0]->dimensions()); - for (int i = 0; i < lengths.size(); ++i) { - lengths[i] = inputs[0]->length(i); - } - std::vector> groupAxises; - { - // Merge adj axis - std::sort(axises.begin(), axises.end()); - int lastAxis = axises[0]; - int length = 1; - int start = axises[0]; - for (int i = 1; i < axises.size(); ++i) { - // MNN_PRINT("%d - %d\n", axises[i], lastAxis); - if (axises[i] - lastAxis == 1) { - length++; - } else { - groupAxises.emplace_back(std::make_pair(start, length)); - length = 1; - start = axises[i]; - } - lastAxis = axises[i]; - } - 
groupAxises.emplace_back(std::make_pair(start, length)); - } - - // Compute inside-outside-axis - std::vector> result; - - for (int i = 0; i < groupAxises.size(); ++i) { - int outsideSize = 1; - int insideSize = 1; - int axisSize = 1; - auto start = groupAxises[i].first; - auto length = groupAxises[i].second; - if (start >= (int)lengths.size()) { - break; - } - for (int j = 0; j < start; ++j) { - outsideSize *= lengths[j]; - } - for (int j = start; j < start + length; ++j) { - if (j >= (int)lengths.size()) { - break; - } - axisSize *= lengths[j]; - lengths[j] = 1; - } - for (int j = start + length; j < lengths.size(); ++j) { - insideSize *= lengths[j]; - } - if (1 == axisSize) { - continue; - } - result.emplace_back(std::make_tuple(outsideSize, axisSize, insideSize)); - } - // FUNC_PRINT(result.size()); - if (result.empty()) { - result.emplace_back(std::make_tuple(1, 1, totalSize)); - } - return result; -} - class GeometryReduce : public GeometryComputer { public: virtual bool onCompute(const Op* op, const std::vector& inputs, const std::vector& outputs, @@ -95,31 +18,6 @@ class GeometryReduce : public GeometryComputer { MNN_ASSERT(inputs.size() >= 1); auto reduct = op->main_as_ReductionParam(); auto reductOp = reduct->operation(); - std::vector axises; - if (inputs.size() >= 2) { - auto size = inputs[1]->elementSize(); - auto dims = inputs[1]->host(); - for (int i = 0; i < size; ++i) { - axises.emplace_back(dims[i]); - } - } else { - auto reduct = op->main_as_ReductionParam(); - if (nullptr != reduct->dim()) { - for (int i = 0; i < reduct->dim()->size(); ++i) { - axises.emplace_back(reduct->dim()->data()[i]); - } - } - } - for (int i = 0; i < axises.size(); ++i) { - if (axises[i] < 0) { - axises[i] = inputs[0]->dimensions() + axises[i]; - } - } - if (1 == axises.size() && TensorUtils::getDescribe(inputs[0])->dimensionFormat != MNN_DATA_FORMAT_NC4HW4 && TensorUtils::getDescribe(outputs[0])->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { - auto cmd = 
GeometryComputerUtils::makeReduce(reductOp, inputs[0], outputs[0], axises[0]); - res.command.emplace_back(std::move(cmd)); - return true; - } // prod([]) = 1 if (inputs[0]->elementSize() == 0) { if(!context.allocTensor(outputs[0])) { @@ -141,7 +39,7 @@ class GeometryReduce : public GeometryComputer { } return true; } - auto reduceDims = _computeReduceDims(inputs, axises); + auto reduceDims = OpCommonUtils::computeReduceDims(inputs, op); Tensor* currentInput = inputs[0]; MNN_ASSERT(reduceDims.size() > 0); auto dimType = currentInput->getDimensionType(); diff --git a/source/geometry/GeometryReshape.cpp b/source/geometry/GeometryReshape.cpp index 1df3384e37..88d98a24c9 100644 --- a/source/geometry/GeometryReshape.cpp +++ b/source/geometry/GeometryReshape.cpp @@ -42,7 +42,8 @@ class GeometryReshape : public GeometryComputer { return true; } } - TensorUtils::makeFullRef(output, input); + outputDes->regions = {TensorUtils::makeFullSlice(input)}; + outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; return true; } }; @@ -74,7 +75,10 @@ class SingleGeometryComputer : public GeometryComputer { Context& context, CommandBuffer& res) const override { auto input = inputs[0]; auto output = outputs[0]; - TensorUtils::makeFullRef(output, input); + auto inputDes = TensorUtils::getDescribe(input); + auto outputDes = TensorUtils::getDescribe(output); + outputDes->regions = {TensorUtils::makeFullSlice(input)}; + outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; return true; } }; @@ -90,7 +94,8 @@ class CopyGeometryComputer : public GeometryComputer { outputDes->tensorArrayAttr = inputDes->tensorArrayAttr; return true; } - TensorUtils::makeFullRef(output, input); + outputDes->regions = {TensorUtils::makeFullSlice(input)}; + outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; } return true; } diff --git a/source/math/Vec.hpp b/source/math/Vec.hpp index cc9354a7f1..6839ab83b0 100644 --- a/source/math/Vec.hpp +++ b/source/math/Vec.hpp @@ -372,7 +372,8 
@@ struct Vec { using VecType = Vec; using VecTypeInt32 = Vec; float32x4_t value; - Vec() = default; + Vec() { + } Vec(const float v) { value = vdupq_n_f32(v); } diff --git a/test/core/ThreadPoolTest.cpp b/test/core/ThreadPoolTest.cpp index e010939e5f..6886f86e62 100644 --- a/test/core/ThreadPoolTest.cpp +++ b/test/core/ThreadPoolTest.cpp @@ -26,11 +26,11 @@ class ThreadPoolTest : public MNNTestCase { auto workIndex = threadPool->acquireWorkIndex(); FUNC_PRINT(workIndex); threadPool->active(); - ThreadPool::TASK task = std::make_pair([](int index) { + auto func = [](int index) { FUNC_PRINT(index); std::this_thread::yield(); - }, 10); - threadPool->enqueue(&task, workIndex); + }; + threadPool->enqueue(std::make_pair(std::move(func), 10), workIndex); threadPool->deactive(); threadPool->releaseWorkIndex(workIndex); }); diff --git a/tools/cpp/ExprDebug.hpp b/tools/cpp/ExprDebug.hpp index 49e3db6156..167e97c562 100644 --- a/tools/cpp/ExprDebug.hpp +++ b/tools/cpp/ExprDebug.hpp @@ -1,7 +1,6 @@ #include #include #include -#include #include #include #define DUMP_NUM_DATA(type) \ @@ -136,69 +135,29 @@ static void _initDebug() { struct TimeTraceInfo { - std::map>> mTypes; + std::map>>> mTypes; void begin(const MNN::OperatorInfo* info) { auto tIter = mTypes.find(info->type()); if (tIter == mTypes.end()) { - std::map> _t; + std::map>> _t; mTypes.insert(std::make_pair(info->type(), _t)); tIter = mTypes.find(info->type()); } mInserIter = tIter->second.find(info->name()); if (mInserIter == tIter->second.end()) { - tIter->second.insert(std::make_pair(info->name(), std::make_tuple(0.0f, 0.0f, 0))); + std::vector> _t; + tIter->second.insert(std::make_pair(info->name(), _t)); mInserIter = tIter->second.find(info->name()); } mTimer.reset(); } void end(const MNN::OperatorInfo* info) { auto timeInMs = (float)mTimer.durationInUs() / 1000.0f; - std::get<0>(mInserIter->second) += timeInMs; - std::get<1>(mInserIter->second) += info->flops(); - std::get<2>(mInserIter->second) ++; - } - void 
dump(bool dumpPerOp = false) { - if (dumpPerOp) { - auto cmp = [](const std::tuple& first, const std::tuple& second) { - return std::get<1>(first) > std::get<1>(second); - }; - std::priority_queue, std::vector>, decltype(cmp)> que(cmp); - for (auto& iter : mTypes) { - for (auto& t : iter.second) { - auto mergeType = t.first + " ["+iter.first +"]"; - auto unit = std::make_tuple(mergeType, std::get<0>(t.second), std::get<1>(t.second), std::get<2>(t.second)); - que.push(unit); - } - } - while (!que.empty()) { - auto& t = que.top(); - MNN_PRINT("%s : %.7f ms, FLOP: %.7f, COUNT: %d, Speed: %.7f GFlops\n", std::get<0>(t).c_str(), std::get<1>(t), std::get<2>(t), std::get<3>(t), std::get<2>(t) / std::get<1>(t)); - que.pop(); - } - return; - } - float opSummer = 0.0f; - float opFlopsSummber = 0.0f; - for (auto& iter : mTypes) { - float summer = 0.0f; - float summerflops = 0.0f; - int count = 0; - for (auto& t : iter.second) { - summer += std::get<0>(t.second); - summerflops += std::get<1>(t.second); - count += std::get<2>(t.second); - } - MNN_PRINT("%s : %.7f ms, FLOP: %.7f, COUNT: %d, Speed: %.7f GFlops\n", iter.first.c_str(), summer, summerflops, count, - summerflops / summer); - opSummer += summer; - opFlopsSummber += summerflops; - } - MNN_PRINT("OP Summer: %.7f ms, Flops: %.7f, Speed: %.7f GFlops\n", opSummer, opFlopsSummber, - opFlopsSummber / opSummer); + mInserIter->second.emplace_back(std::make_pair(timeInMs, info->flops())); } private: - std::map>::iterator mInserIter; + std::map>>::iterator mInserIter; MNN::Timer mTimer; }; static TimeTraceInfo* gTimeTraceInfo = nullptr; diff --git a/tools/cpp/ModuleBasic.cpp b/tools/cpp/ModuleBasic.cpp index 5798bc6d26..90fa6b80d3 100644 --- a/tools/cpp/ModuleBasic.cpp +++ b/tools/cpp/ModuleBasic.cpp @@ -499,13 +499,10 @@ int main(int argc, char *argv[]) { if (runTime > 0) { int t = runTime; + std::vector times(t, 0.0f); if (runMask & 4) { _initTimeTrace(); } - float minTime = std::numeric_limits::max(); - float maxTime = 0.0f; 
- float sum = 0.0f; - for (int i = 0; i < t; ++i) { Timer _l; auto out = net->onForward(inputs); @@ -513,28 +510,41 @@ int main(int argc, char *argv[]) { for (auto o : out) { ((MNN::Tensor*)o->getTensor())->wait(MNN::Tensor::MAP_TENSOR_READ, true); } - auto time = _l.durationInUs() / 1000.0f; + times[i] = _l.durationInUs() / 1000.0f; if (freq > 0.0f) { - float remainMs = (1000.0f / freq) - time; + float remainMs = (1000.0f / freq) - times[i]; if (remainMs > 0.0f) { std::this_thread::sleep_for(std::chrono::milliseconds((int)remainMs)); } } - if (maxTime < time) { - maxTime = time; - } - if (minTime > time) { - minTime = time; - } - sum += time; } if (nullptr != gTimeTraceInfo) { - MNN_PRINT("Per Op Trace: \n"); - gTimeTraceInfo->dump(true); - MNN_PRINT("Per Type Trace: \n"); - gTimeTraceInfo->dump(false); + float opSummer = 0.0f; + float opFlopsSummber = 0.0f; + for (auto& iter : gTimeTraceInfo->mTypes) { + float summer = 0.0f; + float summerflops = 0.0f; + for (auto& t : iter.second) { + for (auto& t0 : t.second) { + summer += t0.first; + summerflops += t0.second; + } + } + summer = summer / (float)t; + summerflops = summerflops / (float)t; + MNN_PRINT("%s : %.7f, FLOP: %.7f, Speed: %.7f GFlops\n", iter.first.c_str(), summer, summerflops, summerflops / summer); + opSummer += summer; + opFlopsSummber+= summerflops; + } + MNN_PRINT("OP Summer: %.7f, Flops: %.7f, Speed: %.7f GFlops\n", opSummer, opFlopsSummber, opFlopsSummber/opSummer); + } + auto minTime = std::min_element(times.begin(), times.end()); + auto maxTime = std::max_element(times.begin(), times.end()); + float sum = 0.0f; + for (auto time : times) { + sum += time; } - MNN_PRINT("Avg= %f ms, min= %f ms, max= %f ms\n", sum / (float)t, minTime, maxTime); + MNN_PRINT("Avg= %f ms, min= %f ms, max= %f ms\n", sum / (float)t, *minTime, *maxTime); } rtmgr->updateCache(); return 0; diff --git a/transformers/diffusion/export/onnx_export.py b/transformers/diffusion/export/onnx_export.py index 5516eb2fcc..21f05e83be 
100644 --- a/transformers/diffusion/export/onnx_export.py +++ b/transformers/diffusion/export/onnx_export.py @@ -84,7 +84,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F num_tokens = pipeline.text_encoder.config.max_position_embeddings text_hidden_size = pipeline.text_encoder.config.hidden_size text_input = pipeline.tokenizer( - ["A sample prompt", "A sample prompt"], + "A sample prompt", padding="max_length", max_length=pipeline.tokenizer.model_max_length, truncation=True, @@ -97,7 +97,9 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path=output_path / "text_encoder" / "model.onnx", ordered_input_names=["input_ids"], output_names=["last_hidden_state", "pooler_output"], - dynamic_axes=None, + dynamic_axes={ + "input_ids": {0: "batch", 1: "sequence"}, + }, opset=opset, ) del pipeline.text_encoder @@ -115,9 +117,13 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F # False, ), output_path=unet_path, - ordered_input_names=["sample", "timestep", "encoder_hidden_states"], + ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], output_names=["out_sample"], # has to be different from "sample" for correct tracing - dynamic_axes=None, + dynamic_axes={ + "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + "timestep": {0: "batch"}, + "encoder_hidden_states": {0: "batch", 1: "sequence"}, + }, opset=opset, use_external_data_format=True, # UNet is > 2GB, so the weights need to be split ) @@ -143,7 +149,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F vae_in_channels = vae_encoder.config.in_channels vae_sample_size = vae_encoder.config.sample_size # need to get the raw tensor output (sample) from the encoder - vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].mode() + vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, 
return_dict)[0].sample() onnx_export( vae_encoder, model_args=( @@ -153,24 +159,30 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path=output_path / "vae_encoder" / "model.onnx", ordered_input_names=["sample", "return_dict"], output_names=["latent_sample"], - dynamic_axes=None, + dynamic_axes={ + "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, opset=opset, ) # VAE DECODER vae_decoder = pipeline.vae vae_latent_channels = vae_decoder.config.latent_channels + vae_out_channels = vae_decoder.config.out_channels # forward only through the decoder part - vae_decoder.forward = lambda latent: vae_decoder.decode(latent, return_dict=False)[0] + vae_decoder.forward = vae_encoder.decode onnx_export( vae_decoder, model_args=( torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), + False, ), output_path=output_path / "vae_decoder" / "model.onnx", - ordered_input_names=["latent_sample"], + ordered_input_names=["latent_sample", "return_dict"], output_names=["sample"], - dynamic_axes=None, + dynamic_axes={ + "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, opset=opset, ) del pipeline.vae diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp index 63c590e0fd..53af11239a 100644 --- a/transformers/llm/engine/src/llm.cpp +++ b/transformers/llm/engine/src/llm.cpp @@ -915,7 +915,26 @@ Llm::Llm(std::shared_ptr config) : mConfig(config) { Llm::~Llm() { #if DEBUG_MODE == 1 if (nullptr != gTimeTraceInfo) { - gTimeTraceInfo->dump(); + float opSummer = 0.0f; + float opFlopsSummber = 0.0f; + for (auto& iter : gTimeTraceInfo->mTypes) { + float summer = 0.0f; + float summerflops = 0.0f; + for (auto& t : iter.second) { + for (auto& t0 : t.second) { + summer += t0.first; + summerflops += t0.second; + } + } + summer = summer; + summerflops = summerflops; + MNN_PRINT("%s : %.7f, FLOP: %.7f, Speed: %.7f GFlops\n", 
iter.first.c_str(), summer, summerflops, + summerflops / summer); + opSummer += summer; + opFlopsSummber += summerflops; + } + MNN_PRINT("OP Summer: %.7f, Flops: %.7f, Speed: %.7f GFlops\n", opSummer, opFlopsSummber, + opFlopsSummber / opSummer); } #endif mGenerateParam.reset(); From 77d9089b97c55a2a2c8753adb753c9bc61d4de90 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:36:21 +0800 Subject: [PATCH 043/314] Merge pull request #4067 from ihb2032/opt/rvv-pixel-conv opt(RVV): Optimize blitter functions with intrinsics GitOrigin-RevId: 784bb542822e52ae67f017cb2adeaad7ce43c267 --- source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp | 18 +++++++++++++++++ .../backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp | 13 ++++++++++++ source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp | 16 +++++++++++++++ source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp | 17 ++++++++++++++++ .../backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp | 20 +++++++++++++++++++ .../backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp | 17 ++++++++++++++++ source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp | 20 +++++++++++++++++++ 11 files changed, 201 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp create mode 100644 
source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp new file mode 100644 index 0000000000..145cbea73f --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp @@ -0,0 +1,18 @@ +#include + +void MNNBGRAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp new file mode 100644 index 0000000000..d46fe6c85b --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNBGRAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp new file mode 100644 index 0000000000..684db6aed3 --- 
/dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNBRGToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, result, vl); + i += vl; + } +} \ No newline at end of file diff --git a/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp b/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp new file mode 100644 index 0000000000..9d524f13ca --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp @@ -0,0 +1,20 @@ +#include + +void MNNC3ToC4(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); + + vuint8m8_t alpha = __riscv_vmv_v_x_u8m8(255, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, alpha, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp new file mode 100644 index 0000000000..952fcaf090 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp @@ -0,0 +1,13 @@ +#include + +void MNNGRAYToC3(const unsigned char* source, unsigned 
char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); + __riscv_vsse8_v_u8m8(dest + i * 3 + 0, 3, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 3 + 1, 3, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 3 + 2, 3, gray, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp new file mode 100644 index 0000000000..5ee4540f98 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp @@ -0,0 +1,16 @@ +#include + +void MNNGRAYToC4(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); + vuint8m8_t alpha = __riscv_vmv_v_x_u8m8(255, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 0, 4, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 1, 4, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 2, 4, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 3, 4, alpha, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp new file mode 100644 index 0000000000..f2b6c7a78d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp @@ -0,0 +1,17 @@ +#include + +void MNNRGBAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp 
b/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp new file mode 100644 index 0000000000..ddd67a7d8c --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp @@ -0,0 +1,20 @@ +#include + +void MNNRGBAToBGRA(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 3, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp new file mode 100644 index 0000000000..d56b58546d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNRGBAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp new file mode 100644 index 0000000000..7c6decf39e --- /dev/null +++ 
b/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp @@ -0,0 +1,17 @@ +#include + +void MNNRGBToBGR(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp new file mode 100644 index 0000000000..1b946c33cc --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNRGBToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, result, vl); + i += vl; + } +} From 23e8e2aaa5b29ddb602df9333806df9175ee8fce Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:41:13 +0800 Subject: [PATCH 044/314] Merge pull request #4053 from ihb2032/opt/rvv-resize-functions opt(RVV): Optimize resize functions with intrinsics GitOrigin-RevId: e55248749f6c5b8c7c7d5b67d734f79943569955 --- .../cpu/riscv/rvv/CPUBilinearLineC4.cpp | 19 +++++ .../cpu/riscv/rvv/CPUBilinearSampleC4.cpp | 33 ++++++++ 
.../cpu/riscv/rvv/MNNBilinearLineC8.cpp | 40 ++++++++++ .../cpu/riscv/rvv/MNNBilinearSampleC8.cpp | 49 ++++++++++++ .../backend/cpu/riscv/rvv/MNNCubicLineC16.cpp | 53 +++++++++++++ .../backend/cpu/riscv/rvv/MNNCubicLineC4.cpp | 38 +++++++++ .../cpu/riscv/rvv/MNNCubicSampleC16.cpp | 79 +++++++++++++++++++ .../cpu/riscv/rvv/MNNCubicSampleC4.cpp | 62 +++++++++++++++ 8 files changed, 373 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp diff --git a/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp b/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp new file mode 100644 index 0000000000..a700016c31 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp @@ -0,0 +1,19 @@ +#include + +void CPUBilinearLineC4(float* dst, const float* A, const float* B, + const float* t, int8_t* zeroPoint, size_t number) { + float tf = *t; + float sf = 1.0f - tf; + size_t total = number << 2; + + size_t i = 0; + while (i < total) { + size_t vl = __riscv_vsetvl_e32m8(total - i); + vfloat32m8_t v = __riscv_vle32_v_f32m8(A + i, vl); + vfloat32m8_t result = __riscv_vfmul_vf_f32m8(v, sf, vl); + v = __riscv_vle32_v_f32m8(B + i, vl); + result = __riscv_vfmacc_vf_f32m8(result, tf, v, vl); + __riscv_vse32_v_f32m8(dst + i, result, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp b/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp new file mode 100644 index 0000000000..5063c39bff --- /dev/null +++ 
b/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp @@ -0,0 +1,33 @@ +#include + +void CPUBilinearSampleC4(const float* src, float* dst, + const int32_t* position, const float* factor, + int8_t* zeroPoint, size_t number) { + const int pack = 4; + size_t i = 0; + + while (i < number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + i, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vfloat32m8_t vr = __riscv_vluxei32_v_f32m8(src, voff, vl); + vfloat32m8_t vsf = __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl); + vr = __riscv_vfmul_vv_f32m8(vr, vsf, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vsf = __riscv_vluxei32_v_f32m8(src, voff, vl); + vr = __riscv_vfmacc_vv_f32m8(vr, vf, vsf, vl); + __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, vr, vl); + } + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp new file mode 100644 index 0000000000..a26243bdb8 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp @@ -0,0 +1,40 @@ +#include + +void MNNBilinearLineC8(int8_t* dst, const int16_t* A, const int16_t* B, + const float* t, int8_t* zeroPoint, size_t number) { + int offset = *zeroPoint; + int8_t* dstPtr = dst; + + const int pack = 8; + const int16_t df = (int16_t)((*t) * 128.0f); + const int16_t sf = (int16_t)((1.0f - *t) * 128.0f); + const size_t total = number * pack; + const int32_t ROUND_HALF = 1 << 13; + + size_t vl; + for (size_t i = 0; i < total; i += vl) { + vl = __riscv_vsetvl_e16m4(total - i); + vint16m4_t v16 = __riscv_vle16_v_i16m4(A + i, vl); + vint32m8_t v32 = __riscv_vwmul_vx_i32m8(v16, sf, 
vl); + v16 = __riscv_vle16_v_i16m4(B + i, vl); + v32 = __riscv_vwmacc_vx_i32m8(v32, df, v16, vl); + + vbool4_t mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); + vint32m8_t tmp = __riscv_vadd_vx_i32m8(v32, ROUND_HALF, vl); + v32 = __riscv_vsub_vx_i32m8(v32, ROUND_HALF, vl); + v32 = __riscv_vmerge_vvm_i32m8(tmp, v32, mask, vl); + + tmp = __riscv_vsra_vx_i32m8(v32, 14, vl); + mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); + v32 = __riscv_vand_vx_i32m8(v32, 0x3FFF, vl); + vbool4_t hasRem = __riscv_vmsne_vx_i32m8_b4(v32, 0, vl); + mask = __riscv_vmand_mm_b4(mask, hasRem, vl); + + v32 = __riscv_vadd_vx_i32m8_mu(mask, tmp, tmp, 1, vl); + v32 = __riscv_vadd_vx_i32m8(v32, offset, vl); + v16 = __riscv_vnsra_wx_i16m4(v32, 0, vl); + vint8m2_t v8 = __riscv_vnsra_wx_i8m2(v16, 0, vl); + + __riscv_vse8_v_i8m2(dstPtr + i, v8, vl); + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp new file mode 100644 index 0000000000..bd111e3be4 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp @@ -0,0 +1,49 @@ +#include + +void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, + const int32_t* position, const float* factor, + int8_t* zeroPoint, size_t number) { + int16_t offset = (int16_t)(*zeroPoint); + const int pack = 8; + size_t i = 0; + + while (i < number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + i, vl); + vint16m4_t vdf = __riscv_vnsra_wx_i16m4( + __riscv_vfcvt_rtz_x_f_v_i32m8( + __riscv_vfmul_vf_f32m8(vf, 128.0f, vl), vl), 0, vl); + vint16m4_t vsf = __riscv_vnsra_wx_i16m4( + __riscv_vfcvt_rtz_x_f_v_i32m8( + __riscv_vfmul_vf_f32m8( + __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl), 128.0f, vl), vl), 0, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vadd_vx_u32m8( + __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 3, vl), + c, vl); + + vint16m4_t va = 
__riscv_vsub_vx_i16m4( + __riscv_vsext_vf2_i16m4( + __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); + + vint32m8_t vr = __riscv_vwmul_vv_i32m8(va, vsf, vl); + voff = __riscv_vadd_vx_u32m8( + __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 3, vl), + c, vl); + + vint16m4_t vb = __riscv_vsub_vx_i16m4( + __riscv_vsext_vf2_i16m4( + __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); + vr = __riscv_vwmacc_vv_i32m8(vr, vb, vdf, vl); + __riscv_vsse16_v_i16m4(dst + i * pack + c, 16, + __riscv_vnsra_wx_i16m4(vr, 0, vl), vl); + } + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp b/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp new file mode 100644 index 0000000000..fd6ce7a274 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp @@ -0,0 +1,53 @@ +#include + +void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, + const float* C, const float* D, float* t, + int8_t* zeroPoint, size_t number, + ssize_t minValue, ssize_t maxValue) { + const float f = *t; + const float t2 = f * f, t3 = t2 * f; + const float b0 = 1.0f - 2.25f * t2 + 1.25f * t3; + const float t1 = 1.0f - f, t1_2 = t1 * t1; + const float c0 = 1.0f - 2.25f * t1_2 + 1.25f * t1_2 * t1; + const float ta = 1.0f + f, ta2 = ta * ta; + const float a0 = 3.0f - 6.0f * ta + 3.75f * ta2 - 0.75f * ta2 * ta; + const float td = 2.0f - f, td2 = td * td; + const float d0 = 3.0f - 6.0f * td + 3.75f * td2 - 0.75f * td2 * td; + const int offset = *zeroPoint; + const int minVal = (int)minValue; + const int maxVal = (int)maxValue; + const size_t total = number << 4; + size_t i = 0; + + while (i < total) { + size_t vl = __riscv_vsetvl_e32m8(total - i); + vfloat32m8_t v, acc; + + v = __riscv_vle32_v_f32m8(A + i, vl); + acc = __riscv_vfmul_vf_f32m8(v, a0, vl); + + v = __riscv_vle32_v_f32m8(B + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, b0, v, vl); + + v = __riscv_vle32_v_f32m8(C + i, vl); + acc = 
__riscv_vfmacc_vf_f32m8(acc, c0, v, vl); + + v = __riscv_vle32_v_f32m8(D + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, d0, v, vl); + + vfloat32m8_t half = __riscv_vfmv_v_f_f32m8(0.5f, vl); + vfloat32m8_t signHalf = __riscv_vfsgnj_vv_f32m8(half, acc, vl); + acc = __riscv_vfadd_vv_f32m8(acc, signHalf, vl); + + vint32m8_t vint = __riscv_vfcvt_rtz_x_f_v_i32m8(acc, vl); + vint = __riscv_vadd_vx_i32m8(vint, offset, vl); + vint = __riscv_vmax_vx_i32m8(vint, minVal, vl); + vint = __riscv_vmin_vx_i32m8(vint, maxVal, vl); + + vint16m4_t vi16 = __riscv_vncvt_x_x_w_i16m4(vint, vl); + vint8m2_t vi8 = __riscv_vncvt_x_x_w_i8m2(vi16, vl); + __riscv_vse8_v_i8m2(dst + i, vi8, vl); + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp b/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp new file mode 100644 index 0000000000..0da63ca0ff --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp @@ -0,0 +1,38 @@ +#include + +void MNNCubicLineC4(float* dst, const float* A, const float* B, + const float* C, const float* D, float* t, + int8_t* zeroPoint, size_t number, + ssize_t minValue, ssize_t maxValue) { + const float f = *t; + const float t2 = f * f, t3 = t2 * f; + const float b0 = 1.0f - 2.25f * t2 + 1.25f * t3; + const float t1 = 1.0f - f, t1_2 = t1 * t1; + const float c0 = 1.0f - 2.25f * t1_2 + 1.25f * t1_2 * t1; + const float ta = 1.0f + f, ta2 = ta * ta; + const float a0 = 3.0f - 6.0f * ta + 3.75f * ta2 - 0.75f * ta2 * ta; + const float td = 2.0f - f, td2 = td * td; + const float d0 = 3.0f - 6.0f * td + 3.75f * td2 - 0.75f * td2 * td; + const size_t total = number << 2; + size_t i = 0; + + while (i < total) { + size_t vl = __riscv_vsetvl_e32m8(total - i); + vfloat32m8_t v, acc; + + v = __riscv_vle32_v_f32m8(A + i, vl); + acc = __riscv_vfmul_vf_f32m8(v, a0, vl); + + v = __riscv_vle32_v_f32m8(B + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, b0, v, vl); + + v = __riscv_vle32_v_f32m8(C + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, c0, v, vl); + + v 
= __riscv_vle32_v_f32m8(D + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, d0, v, vl); + + __riscv_vse32_v_f32m8(dst + i, acc, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp b/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp new file mode 100644 index 0000000000..fd5b24a53d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp @@ -0,0 +1,79 @@ +#include + +void MNNCubicSampleC16(const int8_t* src, float* dst, + int32_t* position, const float* factor, + int8_t* zeroPoint, size_t number) { + const int pack = 16; + int8_t zp = *zeroPoint; + size_t i = 0; + + while (i < number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vint8m2_t vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vint16m4_t vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vfloat32m8_t vtmp = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl); + vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl); + vfloat32m8_t vc = vtmp; + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vfloat32m8_t vB = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl); + voff = 
__riscv_vadd_vx_u32m8(voff, c, vl); + + vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vtmp = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl); + vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl); + vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vtmp = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + va = __riscv_vfmacc_vf_f32m8(va, 0.75f, vtmp, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, -0.75f, vtmp, vl); + + va = __riscv_vfmadd_vv_f32m8(va, vt, vb, vl); + va = __riscv_vfmadd_vv_f32m8(va, vt, vc, vl); + va = __riscv_vfmadd_vv_f32m8(va, vt, vB, vl); + + __riscv_vsse32_v_f32m8(dst + i * pack + c, pack * sizeof(float), va, vl); + } + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp b/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp new file mode 100644 index 0000000000..78207e69e8 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp @@ -0,0 +1,62 @@ +#include + +void MNNCubicSampleC4(const float* src, float* dst, + int32_t* position, const float* factor, + int8_t* zeroPoint, size_t number) { + const int pack = 4; + size_t i = 0; + + while (i < number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vfloat32m8_t vtmp = 
__riscv_vluxei32_v_f32m8(src, voff, vl); + + vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl); + vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl); + vfloat32m8_t vc = vtmp; + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vfloat32m8_t vB = __riscv_vluxei32_v_f32m8(src, voff, vl); + + va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); + + va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl); + vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl); + vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); + + va = __riscv_vfmacc_vf_f32m8(va, 0.75f, vtmp, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, -0.75f, vtmp, vl); + + va = __riscv_vfmadd_vv_f32m8(va, vt, vb, vl); + va = __riscv_vfmadd_vv_f32m8(va, vt, vc, vl); + va = __riscv_vfmadd_vv_f32m8(va, vt, vB, vl); + + __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, va, vl); + } + + i += vl; + } +} From 77bc18bde7b9db1586411f8bffcdef1fb9906966 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:41:55 +0800 Subject: [PATCH 045/314] Merge pull request #4050 from ihb2032/opt/rvv-top1 opt(RVV): Optimize top1 functions with intrinsics GitOrigin-RevId: f9f777c193cac1b7cf3201eb2bf789c782f31ca7 --- .../cpu/riscv/rvv/MNNVectorTop1Float.cpp | 37 +++++++++++++++++++ .../cpu/riscv/rvv/MNNVectorTop1Int32.cpp | 37 +++++++++++++++++++ 2 
files changed, 74 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp b/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp new file mode 100644 index 0000000000..7332360ce8 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp @@ -0,0 +1,37 @@ +#include +#include + +#define UNIT 4 + +void MNNVectorTop1Float(float* input, float* maxValue, int32_t* maxIndex, size_t inputCountUnit) { + size_t n = inputCountUnit * UNIT; + float maxV = -FLT_MAX; + int32_t maxIdx = 0; + size_t vl; + + for (size_t i = 0; i < n; ) { + vl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t data = __riscv_vle32_v_f32m8(input + i, vl); + vfloat32m1_t scalar = __riscv_vfmv_s_f_f32m1(maxV, vl); + vfloat32m1_t result = __riscv_vfredmax_vs_f32m8_f32m1(data, scalar, vl); + maxV = __riscv_vfmv_f_s_f32m1_f32(result); + i += vl; + } + + for (size_t i = 0; i < n; ) { + vl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t data = __riscv_vle32_v_f32m8(input + i, vl); + vbool4_t mask = __riscv_vmfeq_vf_f32m8_b4(data, maxV, vl); + long first = __riscv_vfirst_m_b4(mask, vl); + + if (first >= 0) { + maxIdx = i + first; + break; + } + + i += vl; + } + + maxValue[0] = maxV; + maxIndex[0] = maxIdx; +} diff --git a/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp b/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp new file mode 100644 index 0000000000..8c199709ec --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp @@ -0,0 +1,37 @@ +#include +#include + +#define UNIT 4 + +void MNNVectorTop1Int32(int32_t* input, int32_t* maxValue, int32_t* maxIndex, size_t inputCountUnit) { + size_t n = inputCountUnit * UNIT; + int32_t maxV = INT32_MIN; + int32_t maxIdx = 0; + size_t vl; + + for (size_t i = 0; i < n; ) { + vl = __riscv_vsetvl_e32m8(n - i); + vint32m8_t data = __riscv_vle32_v_i32m8(input + i, vl); + vint32m1_t 
scalar = __riscv_vmv_s_x_i32m1(maxV, vl); + vint32m1_t result = __riscv_vredmax_vs_i32m8_i32m1(data, scalar, vl); + maxV = __riscv_vmv_x_s_i32m1_i32(result); + i += vl; + } + + for (size_t i = 0; i < n; ) { + vl = __riscv_vsetvl_e32m8(n - i); + vint32m8_t data = __riscv_vle32_v_i32m8(input + i, vl); + vbool4_t mask = __riscv_vmseq_vx_i32m8_b4(data, maxV, vl); + long first = __riscv_vfirst_m_b4(mask, vl); + + if (first >= 0) { + maxIdx = i + first; + break; + } + + i += vl; + } + + maxValue[0] = maxV; + maxIndex[0] = maxIdx; +} From fa5fc3123eed0259b68509e86b72936575752ba2 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:42:36 +0800 Subject: [PATCH 046/314] Merge pull request #4044 from ihb2032/opt/rvv-softmax-relu opt(RVV): Optimize Softmax and ReluWithSlopeChannel with intrinsics GitOrigin-RevId: 07b2b4e3b678f2b440bb954b58760d47d7c54689 --- .../cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp | 45 +++++++++++ source/backend/cpu/riscv/rvv/MNNSoftmax.cpp | 80 +++++++++++++++++++ 2 files changed, 125 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNSoftmax.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp b/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp new file mode 100644 index 0000000000..262f4cbfab --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp @@ -0,0 +1,45 @@ +#include + +void MNNReluWithSlopeChannel(float *dst, const float *src, + const float *slope, size_t sizeQuad, + size_t depthQuad) { + const ptrdiff_t stride = 4 * sizeof(float); + + for (size_t j = 0; j < depthQuad; ++j) { + const float *srcZ = src + 4 * j * sizeQuad; + float *dstZ = dst + 4 * j * sizeQuad; + float s0 = slope[4*j], s1 = slope[4*j + 1]; + float s2 = slope[4*j + 2], s3 = slope[4*j + 3]; + size_t i = 0; + while (i < sizeQuad) { + size_t vl = __riscv_vsetvl_e32m8(sizeQuad - i); + const float *srcBase = srcZ + 4*i; + float 
*dstBase = dstZ + 4*i; + + vfloat32m8_t v; + vbool4_t mask; + + v = __riscv_vlse32_v_f32m8(srcBase, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s0, vl); + __riscv_vsse32_v_f32m8(dstBase, stride, v, vl); + + v = __riscv_vlse32_v_f32m8(srcBase + 1, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s1, vl); + __riscv_vsse32_v_f32m8(dstBase + 1, stride, v, vl); + + v = __riscv_vlse32_v_f32m8(srcBase + 2, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s2, vl); + __riscv_vsse32_v_f32m8(dstBase + 2, stride, v, vl); + + v = __riscv_vlse32_v_f32m8(srcBase + 3, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s3, vl); + __riscv_vsse32_v_f32m8(dstBase + 3, stride, v, vl); + + i += vl; + } + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp b/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp new file mode 100644 index 0000000000..f510058c83 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp @@ -0,0 +1,80 @@ +#include +#include + +void MNNSoftmax(float *dest, const float *source, size_t size) { + size_t n = size; + const float *sourcePtr = source; + float *destPtr = dest; + float maxValue = -FLT_MAX; + vfloat32m1_t maxVecValue = __riscv_vfmv_s_f_f32m1(maxValue, 1); + + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t vSrc = __riscv_vle32_v_f32m8(sourcePtr, vl); + maxVecValue = __riscv_vfredmax_vs_f32m8_f32m1(vSrc, maxVecValue, vl); + sourcePtr += vl; + n -= vl; + } + + maxValue = __riscv_vfmv_f_s_f32m1_f32(maxVecValue); + const float param = 0.6931471805599453f; + const float xLimit = 87.0f; + float sumValue = 0.f; + vfloat32m1_t sumVecValue = __riscv_vfmv_s_f_f32m1(sumValue, 1); + n = size; + sourcePtr = source; + destPtr = dest; + + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t vA = 
__riscv_vle32_v_f32m8(sourcePtr, vl); + vA = __riscv_vfsub_vf_f32m8(vA, maxValue, vl); + vA = __riscv_vfmax_vf_f32m8(vA, -xLimit, vl); + vA = __riscv_vfmin_vf_f32m8(vA, xLimit, vl); + + vfloat32m8_t vB = __riscv_vfdiv_vf_f32m8(vA, param, vl); + vint32m8_t vBI = __riscv_vfcvt_x_f_v_i32m8(vB, vl); + + vfloat32m8_t vC = __riscv_vreinterpret_v_i32m8_f32m8( + __riscv_vsll_vx_i32m8( + __riscv_vadd_vx_i32m8(vBI, 127, vl), 23, vl)); + + vB = __riscv_vfcvt_f_x_v_f32m8(vBI, vl); + vB = __riscv_vfnmsub_vf_f32m8(vB, param, vA, vl); + + vA = __riscv_vfmv_v_f_f32m8(1.0f / 120.0f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 1.0f / 24.0f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 1.0f / 6.0f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 0.5f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl); + + vA = __riscv_vfmul_vv_f32m8(vC, vA, vl); + __riscv_vse32_v_f32m8(destPtr, vA, vl); + sumVecValue = __riscv_vfredosum_vs_f32m8_f32m1(vA, sumVecValue, vl); + + sourcePtr += vl; + destPtr += vl; + n -= vl; + } + + sumValue = __riscv_vfmv_f_s_f32m1_f32(sumVecValue); + float sumInv = 1.0f / sumValue; + n = size; + destPtr = dest; + + while (n > 0) + { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t vDest = __riscv_vle32_v_f32m8(destPtr, vl); + vDest = __riscv_vfmul_vf_f32m8(vDest, sumInv, vl); + __riscv_vse32_v_f32m8(destPtr, vDest, vl); + destPtr += vl; + n -= vl; + } +} From 98d21c2a75b5fc98d973a33e12ebdb1184f7ea60 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:42:54 +0800 Subject: [PATCH 047/314] Merge pull request #4042 from ihb2032/opt/rvv-conv-strassen opt(RVV): Optimize conv and strassen functions with intrinsics GitOrigin-RevId: 29b59dacf57d9d4fb4438209ac292956ae59b134 --- .../riscv/rvv/MNNConvRunForLineDepthwise.cpp | 
48 +++++++++++++++++++ .../rvv/MNNDeconvRunForUnitDepthWise.cpp | 42 ++++++++++++++++ .../riscv/rvv/MNNStrassenMergeCFunction.cpp | 36 ++++++++++++++ 3 files changed, 126 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp b/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp new file mode 100644 index 0000000000..f82faf83f5 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp @@ -0,0 +1,48 @@ +#include + +void MNNConvRunForLineDepthwise( + float* dst, const float* src, const float* weight, + size_t width, size_t src_w_setup, + size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, + size_t height, size_t srcHStep, size_t dstHStep, + const float* bias, const float* parameters) { + float minV = parameters[0]; + float maxV = parameters[1]; + ptrdiff_t srcByteStride = src_w_setup * sizeof(float); + ptrdiff_t dstByteStride = 4 * sizeof(float); + + for (size_t y = 0; y < height; ++y) { + const float* srcY = src + y * srcHStep; + float* dstY = dst + y * dstHStep; + size_t dx = 0; + + while (dx < width) { + size_t vl = __riscv_vsetvl_e32m8(width - dx); + + for (int c = 0; c < 4; ++c) { + vfloat32m8_t acc = __riscv_vfmv_v_f_f32m8(bias[c], vl); + const float* srcBase = srcY + dx * src_w_setup + c; + const float* weightPtr = weight + c; + + for (size_t fy = 0; fy < fh; ++fy) { + const float* srcFy = srcBase + fy * dilateY_step; + + for (size_t fx = 0; fx < fw; ++fx) { + float w = *weightPtr; + weightPtr += 4; + const float* srcFx = srcFy + fx * dilateX_step; + vfloat32m8_t s = __riscv_vlse32_v_f32m8(srcFx, srcByteStride, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, w, s, vl); + } + } + + acc = __riscv_vfmax_vf_f32m8(acc, minV, vl); + acc = __riscv_vfmin_vf_f32m8(acc, 
maxV, vl); + float* dstAddr = dstY + dx * 4 + c; + __riscv_vsse32_v_f32m8(dstAddr, dstByteStride, acc, vl); + } + + dx += vl; + } + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp b/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp new file mode 100644 index 0000000000..6658715e7e --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp @@ -0,0 +1,42 @@ +#include + +void MNNDeconvRunForUnitDepthWise( + const float* dst, float* src, const float* weight, + size_t fw, size_t fh, + size_t weightY_step, size_t dilateX_step, size_t dilateY_step) { + const ptrdiff_t wStride = 4 * sizeof(float); + const ptrdiff_t sStride = dilateX_step * sizeof(float); + float d0 = dst[0], d1 = dst[1], d2 = dst[2], d3 = dst[3]; + + for (size_t fy = 0; fy < fh; ++fy) { + float* srcY = src + fy * dilateY_step; + const float* weightY = weight + fy * weightY_step; + + size_t fx = 0; + while (fx < fw) { + size_t vl = __riscv_vsetvl_e32m8(fw - fx); + + vfloat32m8_t w = __riscv_vlse32_v_f32m8(weightY + 0 + fx * 4, wStride, vl); + vfloat32m8_t s = __riscv_vlse32_v_f32m8(srcY + 0 + fx * dilateX_step, sStride, vl); + s = __riscv_vfmacc_vf_f32m8(s, d0, w, vl); + __riscv_vsse32_v_f32m8(srcY + 0 + fx * dilateX_step, sStride, s, vl); + + w = __riscv_vlse32_v_f32m8(weightY + 1 + fx * 4, wStride, vl); + s = __riscv_vlse32_v_f32m8(srcY + 1 + fx * dilateX_step, sStride, vl); + s = __riscv_vfmacc_vf_f32m8(s, d1, w, vl); + __riscv_vsse32_v_f32m8(srcY + 1 + fx * dilateX_step, sStride, s, vl); + + w = __riscv_vlse32_v_f32m8(weightY + 2 + fx * 4, wStride, vl); + s = __riscv_vlse32_v_f32m8(srcY + 2 + fx * dilateX_step, sStride, vl); + s = __riscv_vfmacc_vf_f32m8(s, d2, w, vl); + __riscv_vsse32_v_f32m8(srcY + 2 + fx * dilateX_step, sStride, s, vl); + + w = __riscv_vlse32_v_f32m8(weightY + 3 + fx * 4, wStride, vl); + s = __riscv_vlse32_v_f32m8(srcY + 3 + fx * dilateX_step, sStride, vl); + s = __riscv_vfmacc_vf_f32m8(s, d3, w, vl); + 
__riscv_vsse32_v_f32m8(srcY + 3 + fx * dilateX_step, sStride, s, vl); + + fx += vl; + } + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp b/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp new file mode 100644 index 0000000000..8ab5bb89fa --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp @@ -0,0 +1,36 @@ +#include + +void MNNStrassenMergeCFunction(float *c11, float *c12, float *c21, float *c22, + float *xAddr, size_t cStride, size_t eSub, size_t hSub) { + for (int y = 0; y < hSub; ++y) { + float *c11Y = c11 + y * cStride; + float *c12Y = c12 + y * cStride; + float *c22Y = c22 + y * cStride; + float *c21Y = c21 + y * cStride; + float *xY = xAddr + y * eSub * 4; + size_t totalElements = eSub * 4; + size_t p = 0; + + while (p < totalElements) { + size_t vl = __riscv_vsetvl_e32m8(totalElements - p); + vfloat32m8_t t = __riscv_vle32_v_f32m8(xY + p, vl); + vfloat32m8_t tmp = __riscv_vle32_v_f32m8(c12Y + p, vl); + t = __riscv_vfadd_vv_f32m8(t, tmp, vl); + vfloat32m8_t c22v = __riscv_vle32_v_f32m8(c22Y + p, vl); + + tmp = __riscv_vle32_v_f32m8(c11Y + p, vl); + tmp = __riscv_vfadd_vv_f32m8(tmp, c22v, vl); + tmp = __riscv_vfadd_vv_f32m8(tmp, t, vl); + __riscv_vse32_v_f32m8(c12Y + p, tmp, vl); + + tmp = __riscv_vle32_v_f32m8(c21Y + p, vl); + tmp = __riscv_vfadd_vv_f32m8(t, tmp, vl); + __riscv_vse32_v_f32m8(c21Y + p, tmp, vl); + + c22v = __riscv_vfadd_vv_f32m8(c22v, tmp, vl); + __riscv_vse32_v_f32m8(c22Y + p, c22v, vl); + + p += vl; + } + } +} From c677acc10f1e00adcff0aad5bd8f1d1d06bddc88 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:43:07 +0800 Subject: [PATCH 048/314] Merge pull request #4036 from ihb2032/opt/rvv-minmax-float opt(RVV): Optimize max and min float functions with intrinsics GitOrigin-RevId: 826e9dd9b4bb8b260d29bc9574840b83ec8e9154 --- source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp | 25 ++++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNMinFloat.cpp | 25 
++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNMinFloat.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp b/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp new file mode 100644 index 0000000000..183a38bb10 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp @@ -0,0 +1,25 @@ +#include +#include + +#define UNIT 4 + +void MNNMaxFloat(float *input, float *maxBuffer, int32_t inputCountUnit) { + const float init = -FLT_MAX; + for (int j = 0; j < UNIT; ++j) { + float local = init; + size_t i = 0; + + while (i < (size_t)inputCountUnit) { + size_t vl = __riscv_vsetvl_e32m8(inputCountUnit - i); + float *p0 = input + (i * UNIT * 2) + j * 2; + float *p1 = p0 + 1; + vfloat32m8_t v0 = __riscv_vlse32_v_f32m8(p0, UNIT * 2 * sizeof(float), vl); + vfloat32m8_t v1 = __riscv_vlse32_v_f32m8(p1, UNIT * 2 * sizeof(float), vl); + vfloat32m8_t vmax = __riscv_vfmax_vv_f32m8(v0, v1, vl); + vfloat32m1_t vred = __riscv_vfredmax_vs_f32m8_f32m1(vmax, __riscv_vfmv_s_f_f32m1(local, 1), vl); + local = __riscv_vfmv_f_s_f32m1_f32(vred); + i += vl; + } + maxBuffer[j] = local; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp b/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp new file mode 100644 index 0000000000..9e8ade8641 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp @@ -0,0 +1,25 @@ +#include +#include + +#define UNIT 4 + +void MNNMinFloat(float *input, float *minBuffer, int32_t inputCountUnit) { + const float init = FLT_MAX; + for (int j = 0; j < UNIT; ++j) { + float local = init; + size_t i = 0; + + while (i < (size_t)inputCountUnit) { + size_t vl = __riscv_vsetvl_e32m8(inputCountUnit - i); + float *p0 = input + (i * UNIT * 2) + j * 2; + float *p1 = p0 + 1; + vfloat32m8_t v0 = __riscv_vlse32_v_f32m8(p0, UNIT * 2 * sizeof(float), vl); + vfloat32m8_t v1 = __riscv_vlse32_v_f32m8(p1, UNIT * 2 * sizeof(float), vl); + vfloat32m8_t 
vmin = __riscv_vfmin_vv_f32m8(v0, v1, vl); + vfloat32m1_t vred = __riscv_vfredmin_vs_f32m8_f32m1(vmin, __riscv_vfmv_s_f_f32m1(local, 1), vl); + local = __riscv_vfmv_f_s_f32m1_f32(vred); + i += vl; + } + minBuffer[j] = local; + } +} From 90bd50dec4b33ff3829f75f0cdf3a1fec704af6c Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:43:38 +0800 Subject: [PATCH 049/314] Merge pull request #4026 from ihb2032/opt/rvv-math-stride-ops opt(RVV): Optimize core math and stride functions with intrinsics GitOrigin-RevId: 3036cf5098b26250c04bef3b93801f9f0caf62a6 --- .../cpu/riscv/rvv/MNNAddC4WithStride.cpp | 29 +++++++++++ .../riscv/rvv/MNNAxByClampBroadcastUnit.cpp | 52 +++++++++++++++++++ .../cpu/riscv/rvv/MNNCopyC4WithStride.cpp | 22 ++++++++ .../cpu/riscv/rvv/MNNScaleAndAddBias.cpp | 42 +++++++++++++++ 4 files changed, 145 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp b/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp new file mode 100644 index 0000000000..59bb28a039 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp @@ -0,0 +1,29 @@ +#include + +void MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { + ptrdiff_t srcStrideByte = srcStride * sizeof(float); + ptrdiff_t dstStrideByte = dstStride * sizeof(float); + size_t vl; + + for (size_t i = count; i > 0; i -= vl) { + vl = __riscv_vsetvl_e32m8(i); + vfloat32m8_t vs = __riscv_vlse32_v_f32m8(source + 0, srcStrideByte, vl); + vfloat32m8_t vd = __riscv_vlse32_v_f32m8(dest + 0, dstStrideByte, vl); + vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); + __riscv_vsse32_v_f32m8(dest + 0, dstStrideByte, vd, vl); + vs = 
__riscv_vlse32_v_f32m8(source + 1, srcStrideByte, vl); + vd = __riscv_vlse32_v_f32m8(dest + 1, dstStrideByte, vl); + vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); + __riscv_vsse32_v_f32m8(dest + 1, dstStrideByte, vd, vl); + vs = __riscv_vlse32_v_f32m8(source + 2, srcStrideByte, vl); + vd = __riscv_vlse32_v_f32m8(dest + 2, dstStrideByte, vl); + vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); + __riscv_vsse32_v_f32m8(dest + 2, dstStrideByte, vd, vl); + vs = __riscv_vlse32_v_f32m8(source + 3, srcStrideByte, vl); + vd = __riscv_vlse32_v_f32m8(dest + 3, dstStrideByte, vl); + vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); + __riscv_vsse32_v_f32m8(dest + 3, dstStrideByte, vd, vl); + source += vl * srcStride; + dest += vl * dstStride; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp b/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp new file mode 100644 index 0000000000..6d966789f7 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp @@ -0,0 +1,52 @@ +#include + +void MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) { + float beta = parameters[1]; + float minF = parameters[2]; + float maxF = parameters[3]; + const ptrdiff_t stride = 4 * sizeof(float); + + for (int y = 0; y < height; ++y) { + auto a = A + aStride * y; + auto b = B + 4 * y; + auto c = C + cStride * y; + float b0Beta = b[0] * beta; + float b1Beta = b[1] * beta; + float b2Beta = b[2] * beta; + float b3Beta = b[3] * beta; + size_t w = width; + + while (w > 0) { + size_t vl = __riscv_vsetvl_e32m8(w); + + vfloat32m8_t data = __riscv_vlse32_v_f32m8(a + 0, stride, vl); + data = __riscv_vfadd_vf_f32m8(data, b0Beta, vl); + data = __riscv_vfmax_vf_f32m8(data, minF, vl); + data = __riscv_vfmin_vf_f32m8(data, maxF, vl); + __riscv_vsse32_v_f32m8(c + 0, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(a + 1, stride, vl); + data = __riscv_vfadd_vf_f32m8(data, b1Beta, 
vl); + data = __riscv_vfmax_vf_f32m8(data, minF, vl); + data = __riscv_vfmin_vf_f32m8(data, maxF, vl); + __riscv_vsse32_v_f32m8(c + 1, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(a + 2, stride, vl); + data = __riscv_vfadd_vf_f32m8(data, b2Beta, vl); + data = __riscv_vfmax_vf_f32m8(data, minF, vl); + data = __riscv_vfmin_vf_f32m8(data, maxF, vl); + __riscv_vsse32_v_f32m8(c + 2, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(a + 3, stride, vl); + data = __riscv_vfadd_vf_f32m8(data, b3Beta, vl); + data = __riscv_vfmax_vf_f32m8(data, minF, vl); + data = __riscv_vfmin_vf_f32m8(data, maxF, vl); + __riscv_vsse32_v_f32m8(c + 3, stride, data, vl); + + a += 4 * vl; + c += 4 * vl; + w -= vl; + } + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp b/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp new file mode 100644 index 0000000000..3d8c4f13fc --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp @@ -0,0 +1,22 @@ +#include + +void MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { + ptrdiff_t srcStrideByte = srcStride * sizeof(float); + ptrdiff_t dstStrideByte = dstStride * sizeof(float); +size_t vl; + + for (size_t i = count; i > 0; i -= vl) { + vl = __riscv_vsetvl_e32m8(i); + vfloat32m8_t data = __riscv_vlse32_v_f32m8(source + 0, srcStrideByte, vl); + __riscv_vsse32_v_f32m8(dest + 0, dstStrideByte, data, vl); + data = __riscv_vlse32_v_f32m8(source + 1, srcStrideByte, vl); + __riscv_vsse32_v_f32m8(dest + 1, dstStrideByte, data, vl); + data = __riscv_vlse32_v_f32m8(source + 2, srcStrideByte, vl); + __riscv_vsse32_v_f32m8(dest + 2, dstStrideByte, data, vl); + data = __riscv_vlse32_v_f32m8(source + 3, srcStrideByte, vl); + __riscv_vsse32_v_f32m8(dest + 3, dstStrideByte, data, vl); + source += vl * srcStride; + dest += vl * dstStride; + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp b/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp 
new file mode 100644 index 0000000000..10992f9d59 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp @@ -0,0 +1,42 @@ +#include + +void MNNScaleAndAddBias(float *dst, const float *src, const float *bias, const float *alpha, size_t planeNumber, size_t biasNumber) { + const ptrdiff_t stride = 4 * sizeof(float); + + for (size_t z = 0; z < biasNumber; ++z) { + float *dstZ = dst + z * planeNumber * 4; + const float *srcZ = src + z * planeNumber * 4; + const float *biasZ = bias + 4 * z; + const float *alphaZ = alpha + 4 * z; + float b0 = biasZ[0], b1 = biasZ[1], b2 = biasZ[2], b3 = biasZ[3]; + float a0 = alphaZ[0], a1 = alphaZ[1], a2 = alphaZ[2], a3 = alphaZ[3]; + + size_t n = planeNumber; + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t data = __riscv_vlse32_v_f32m8(srcZ + 0, stride, vl); + data = __riscv_vfmul_vf_f32m8(data, a0, vl); + data = __riscv_vfadd_vf_f32m8(data, b0, vl); + __riscv_vsse32_v_f32m8(dstZ + 0, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(srcZ + 1, stride, vl); + data = __riscv_vfmul_vf_f32m8(data, a1, vl); + data = __riscv_vfadd_vf_f32m8(data, b1, vl); + __riscv_vsse32_v_f32m8(dstZ + 1, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(srcZ + 2, stride, vl); + data = __riscv_vfmul_vf_f32m8(data, a2, vl); + data = __riscv_vfadd_vf_f32m8(data, b2, vl); + __riscv_vsse32_v_f32m8(dstZ + 2, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(srcZ + 3, stride, vl); + data = __riscv_vfmul_vf_f32m8(data, a3, vl); + data = __riscv_vfadd_vf_f32m8(data, b3, vl); + __riscv_vsse32_v_f32m8(dstZ + 3, stride, data, vl); + + srcZ += vl * 4; + dstZ += vl * 4; + n -= vl; + } + } +} From 98e98e38ba2c520b8bf17fd2edaeed717858ca3a Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:43:52 +0800 Subject: [PATCH 050/314] Merge pull request #4023 from ihb2032/feature/rvv-transpose-functions opt(RVV): Optimize transpose functions with intrinsics GitOrigin-RevId: 72c11be4d128d054363c863f375136f3972a2ab0 --- 
.../cpu/riscv/rvv/MNNTranspose16Bit.cpp | 26 +++++++++++++++++++ .../cpu/riscv/rvv/MNNTranspose32Bit.cpp | 25 ++++++++++++++++++ 2 files changed, 51 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp b/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp new file mode 100644 index 0000000000..7598d6f8ac --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp @@ -0,0 +1,26 @@ +#include + +void MNNTranspose16Bit(int16_t* dstO, const int16_t* srcO, int16_t* dim) { + int w = dim[0]; + int h = dim[1]; + int srcStride = dim[2]; + int dstStride = dim[3]; + ptrdiff_t srcStrideByte = srcStride * sizeof(int16_t); + + for (int i = 0; i < h; ++i) { + const int16_t* srcPtr = srcO + i; + int16_t* dstPtr = dstO + i * dstStride; + + int j = 0; + while (j < w) { + size_t vl = __riscv_vsetvl_e16m8(w - j); + vint16m8_t data = __riscv_vlse16_v_i16m8(srcPtr, srcStrideByte, vl); + __riscv_vse16_v_i16m8(dstPtr, data, vl); + srcPtr += vl * srcStride; + dstPtr += vl; + j += vl; + } + } +} + + diff --git a/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp b/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp new file mode 100644 index 0000000000..e5c5eb83e6 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp @@ -0,0 +1,25 @@ +#include + +void MNNTranspose32Bit(int32_t* dstO, const int32_t* srcO, int32_t* dim) { + int w = dim[0]; + int h = dim[1]; + int srcStride = dim[2]; + int dstStride = dim[3]; + ptrdiff_t srcStrideByte = srcStride * sizeof(int32_t); + + for (int i = 0; i < h; ++i) { + const int32_t* srcPtr = srcO + i; + int32_t* dstPtr = dstO + i * dstStride; + + int j = 0; + while (j < w) { + size_t vl = __riscv_vsetvl_e32m8(w - j); + vint32m8_t data = __riscv_vlse32_v_i32m8(srcPtr, srcStrideByte, vl); + __riscv_vse32_v_i32m8(dstPtr, data, vl); + srcPtr += vl * srcStride; + dstPtr += vl; + 
j += vl; + } + } +} + From 308b8cd6e8ba6e0857c8c9996470fa0ae028c44a Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:44:24 +0800 Subject: [PATCH 051/314] Merge pull request #4021 from ihb2032/feature/rvv-opt opt(RVV): Optimize pack and unpack functions with intrinsics GitOrigin-RevId: d9f4036b55096f812885e84113c96876e431147e --- source/backend/cpu/riscv/rvv/MNNPackC2.cpp | 74 ++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNPackC4.cpp | 80 ++++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp | 55 ++++++++++++++ 3 files changed, 209 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNPackC2.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNPackC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNPackC2.cpp b/source/backend/cpu/riscv/rvv/MNNPackC2.cpp new file mode 100644 index 0000000000..9a74f8998d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNPackC2.cpp @@ -0,0 +1,74 @@ +#include + +void MNNPackC2(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { + int depthC2 = depth / 2; + int depthRemain = depthC2 * 2; + int remain = depth - depthRemain; + const float *srcOffset = src; + const float *srcChannel[2]; + + for (int z = 0; z < depthC2; ++z) { + float *dstZ = dst + z * areaOffset[1] * 2; + + for (int y = 0; y < 2; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 2; + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 0, 2 * sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 1, 2 * sizeof(float), vec, vl); + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 2; + dstPtr[0] = srcChannel[0][x]; + dstPtr[1] = srcChannel[1][x]; + } + + srcOffset += areaOffset[0] * 2; + } + + 
if (remain > 0) { + float *dstZ = dst + depthC2 * areaOffset[1] * 2; + + for (int y = 0; y < remain; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 2; + + for (int y = 0; y < remain; ++y) { + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + y, 2 * sizeof(float), vec, vl); + } + + vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); + for (int y = remain; y < 2; ++y) { + __riscv_vsse32_v_f32m8(dstPtr + y, 2 * sizeof(float), zero, vl); + } + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 2; + + for (int y = 0; y < remain; ++y) { + dstPtr[y] = srcChannel[y][x]; + } + + for (int y = remain; y < 2; ++y) { + dstPtr[y] = 0.0f; + } + } + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNPackC4.cpp b/source/backend/cpu/riscv/rvv/MNNPackC4.cpp new file mode 100644 index 0000000000..024e2c8c07 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNPackC4.cpp @@ -0,0 +1,80 @@ +#include + +void MNNPackC4(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { + int depthC4 = depth / 4; + int depthRemain = depthC4 * 4; + int remain = depth - depthRemain; + const float *srcOffset = src; + const float *srcChannel[4]; + + for (int z = 0; z < depthC4; ++z) { + float *dstZ = dst + z * areaOffset[1] * 4; + + for (int y = 0; y < 4; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 4; + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 0, 4 * sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 1, 4 * sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[2] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 2, 4 * 
sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[3] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 3, 4 * sizeof(float), vec, vl); + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 4; + dstPtr[0] = srcChannel[0][x]; + dstPtr[1] = srcChannel[1][x]; + dstPtr[2] = srcChannel[2][x]; + dstPtr[3] = srcChannel[3][x]; + } + + srcOffset += areaOffset[0] * 4; + } + + if (remain > 0) { + float *dstZ = dst + depthC4 * areaOffset[1] * 4; + + for (int y = 0; y < remain; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 4; + + for (int y = 0; y < remain; ++y) { + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), vec, vl); + } + + vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); + for (int y = remain; y < 4; ++y) { + __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), zero, vl); + } + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 4; + + for (int y = 0; y < remain; ++y) { + dstPtr[y] = srcChannel[y][x]; + } + + for (int y = remain; y < 4; ++y) { + dstPtr[y] = 0.0f; + } + } + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp b/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp new file mode 100644 index 0000000000..4676e6dede --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp @@ -0,0 +1,55 @@ +#include + +void MNNUnpackC4(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { + int depthC4 = depth / 4; + int depthRemain = depthC4 * 4; + int remain = depth - depthRemain; + const float *srcOffset = src; + + for (int z = 0; z < depthC4; ++z) { + float *dstZ[4]; + + for (int y = 0; y < 4; ++y) { + dstZ[y] = dst + (z * 4 + y) * areaOffset[1]; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + vfloat32m8_t vec = __riscv_vlse32_v_f32m8(srcOffset 
+ 0, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[0] + x, vec, vl); + vec = __riscv_vlse32_v_f32m8(srcOffset + 1, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[1] + x, vec, vl); + vec = __riscv_vlse32_v_f32m8(srcOffset + 2, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[2] + x, vec, vl); + vec = __riscv_vlse32_v_f32m8(srcOffset + 3, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[3] + x, vec, vl); + srcOffset += 4 * vl; + } + + for (; x < area; ++x) { + dstZ[0][x] = srcOffset[0]; + dstZ[1][x] = srcOffset[1]; + dstZ[2][x] = srcOffset[2]; + dstZ[3][x] = srcOffset[3]; + srcOffset += (areaOffset[0] - area) * 4; + } + } + + if (remain > 0) { + float *dstZ = dst + depthC4 * areaOffset[1] * 4; + const float *srcBase = srcOffset; + + for (int y = 0; y < remain; ++y) { + float *dstChannel = dstZ + y * areaOffset[1]; + const float *srcChannel = srcBase + y; + + for (size_t x = 0; x < area; ++x) { + dstChannel[x] = srcChannel[0]; + srcChannel += 4; + } + } + } +} + From fe508afc19eeea368d7b07cdc7a71dd8cf5c149e Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:54:53 +0800 Subject: [PATCH 052/314] Merge pull request #4061 from zlaazlaa/fix_diffusion fix(diffusion): simplify export logic and fix dynamic axes GitOrigin-RevId: cc6faf47f33d462e2e1ac613ec710ce55c39a86a --- docs/transformers/diffusion.md | 3 +- transformers/diffusion/export/onnx_export.py | 30 ++++++-------------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/docs/transformers/diffusion.md b/docs/transformers/diffusion.md index 7de27bb216..609793f806 100644 --- a/docs/transformers/diffusion.md +++ b/docs/transformers/diffusion.md @@ -20,7 +20,8 @@ https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1/tree/mai cd mnn_path/transformers/diffusion/export python onnx_export.py \ --model_path hf_sd_load_path \ - --output_path onnx_save_path + --output_path onnx_save_path \ + --opset 18 ``` 注意,上述脚本需要依赖torch/onnx/diffusers等库,可以安装conda环境: ``` diff 
--git a/transformers/diffusion/export/onnx_export.py b/transformers/diffusion/export/onnx_export.py index 21f05e83be..5516eb2fcc 100644 --- a/transformers/diffusion/export/onnx_export.py +++ b/transformers/diffusion/export/onnx_export.py @@ -84,7 +84,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F num_tokens = pipeline.text_encoder.config.max_position_embeddings text_hidden_size = pipeline.text_encoder.config.hidden_size text_input = pipeline.tokenizer( - "A sample prompt", + ["A sample prompt", "A sample prompt"], padding="max_length", max_length=pipeline.tokenizer.model_max_length, truncation=True, @@ -97,9 +97,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path=output_path / "text_encoder" / "model.onnx", ordered_input_names=["input_ids"], output_names=["last_hidden_state", "pooler_output"], - dynamic_axes={ - "input_ids": {0: "batch", 1: "sequence"}, - }, + dynamic_axes=None, opset=opset, ) del pipeline.text_encoder @@ -117,13 +115,9 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F # False, ), output_path=unet_path, - ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], + ordered_input_names=["sample", "timestep", "encoder_hidden_states"], output_names=["out_sample"], # has to be different from "sample" for correct tracing - dynamic_axes={ - "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - "timestep": {0: "batch"}, - "encoder_hidden_states": {0: "batch", 1: "sequence"}, - }, + dynamic_axes=None, opset=opset, use_external_data_format=True, # UNet is > 2GB, so the weights need to be split ) @@ -149,7 +143,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F vae_in_channels = vae_encoder.config.in_channels vae_sample_size = vae_encoder.config.sample_size # need to get the raw tensor output (sample) from the encoder - vae_encoder.forward = lambda sample, return_dict: 
vae_encoder.encode(sample, return_dict)[0].sample() + vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].mode() onnx_export( vae_encoder, model_args=( @@ -159,30 +153,24 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path=output_path / "vae_encoder" / "model.onnx", ordered_input_names=["sample", "return_dict"], output_names=["latent_sample"], - dynamic_axes={ - "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - }, + dynamic_axes=None, opset=opset, ) # VAE DECODER vae_decoder = pipeline.vae vae_latent_channels = vae_decoder.config.latent_channels - vae_out_channels = vae_decoder.config.out_channels # forward only through the decoder part - vae_decoder.forward = vae_encoder.decode + vae_decoder.forward = lambda latent: vae_decoder.decode(latent, return_dict=False)[0] onnx_export( vae_decoder, model_args=( torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), - False, ), output_path=output_path / "vae_decoder" / "model.onnx", - ordered_input_names=["latent_sample", "return_dict"], + ordered_input_names=["latent_sample"], output_names=["sample"], - dynamic_axes={ - "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - }, + dynamic_axes=None, opset=opset, ) del pipeline.vae From bd530d04724aea97990eb650d6e899dfd8d4c729 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 11:04:03 +0800 Subject: [PATCH 053/314] Merge pull request #3998 from bolun365/bolun365-patch-1 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit mnn lib库自动化build脚本 GitOrigin-RevId: 9bac02d0d7bbb82f6a2cd42b01789f5efbdefd8c --- build_lib.sh | 807 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 807 insertions(+) create mode 100644 build_lib.sh diff --git a/build_lib.sh b/build_lib.sh new file mode 100644 index 0000000000..c839b6e7b6 --- /dev/null +++ b/build_lib.sh @@ 
-0,0 +1,807 @@ +#!/bin/bash + +# MNN 统一构建脚本 +# 支持构建 Android、iOS、鸿蒙和 Python 版本的 MNN + +set -e + +# 颜色输出 +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# 默认配置 +BUILD_ANDROID=false +BUILD_IOS=false +BUILD_PYTHON=false +BUILD_IOS_SIMULATOR=false +BUILD_HARMONY=false +ANDROID_NDK="" +HARMONY_HOME="" +OUTPUT_DIR="mnn_builds" +CLEAN_BUILD=true +VERSION="" +PYTHON_CMD="python3" +DEPS_OPTIONS="" + +# 获取项目根目录 +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$SCRIPT_DIR" + +print_usage() { + echo "MNN 统一构建脚本" + echo "" + echo "用法: $0 [选项]" + echo "" + echo "构建目标:" + echo " --android 构建 Android 版本 (需要 --ndk)" + echo " --ios 构建 iOS 真机版本 (arm64)" + echo " --ios-simulator 构建 iOS 模拟器版本 (x86_64 + arm64)" + echo " --harmony 构建鸿蒙版本 (需要 --harmony-home 或设置 HARMONY_HOME)" + echo " --python 构建 Python 版本" + echo "" + echo "Android 选项:" + echo " --ndk PATH Android NDK 路径 (例如: ~/Library/Android/sdk/ndk/29.0.13599879)" + echo "" + echo "鸿蒙选项:" + echo " --harmony-home PATH 鸿蒙工具链路径 (例如: ~/Library/OpenHarmony/Sdk/native)" + echo " 如果未指定,将优先查找 ~/Library/OpenHarmony/Sdk/native" + echo "" + echo "Python 选项:" + echo " --python-deps OPTIONS Python 依赖选项,多个用逗号分隔" + echo " 可用选项: llm,opencl,cuda,torch,render,vulkan,internal,no_sse,openmp" + echo " --python-cmd CMD Python 命令 (默认: python3)" + echo "" + echo "通用选项:" + echo " -o, --output DIR 输出目录 (默认: mnn_builds)" + echo " -v, --version VERSION 版本号 (默认: 自动从源码读取)" + echo " --no-clean 不清理之前的构建目录" + echo " -h, --help 显示帮助信息" + echo "" + echo "示例:" + echo " # 构建所有平台" + echo " $0 --android --ios --harmony --python --ndk ~/Library/Android/sdk/ndk/29.0.13599879" + echo "" + echo " # 仅构建 Android" + echo " $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879" + echo "" + echo " # 构建鸿蒙版本" + echo " $0 --harmony --harmony-home ~/Library/OpenHarmony/Sdk/native" + echo "" + echo " # 构建 iOS 真机和模拟器" + echo " $0 --ios --ios-simulator" + echo "" + echo " # 构建 Python (带 LLM 支持)" + echo 
" $0 --python --python-deps llm,opencl" +} + +# 解析命令行参数 +while [[ $# -gt 0 ]]; do + case $1 in + --android) + BUILD_ANDROID=true + shift + ;; + --ios) + BUILD_IOS=true + shift + ;; + --ios-simulator) + BUILD_IOS_SIMULATOR=true + shift + ;; + --python) + BUILD_PYTHON=true + shift + ;; + --harmony) + BUILD_HARMONY=true + shift + ;; + --ndk) + ANDROID_NDK="$2" + shift 2 + ;; + --harmony-home) + HARMONY_HOME="$2" + shift 2 + ;; + --python-deps) + DEPS_OPTIONS="$2" + shift 2 + ;; + --python-cmd) + PYTHON_CMD="$2" + shift 2 + ;; + -o|--output) + OUTPUT_DIR="$2" + shift 2 + ;; + -v|--version) + VERSION="$2" + shift 2 + ;; + --no-clean) + CLEAN_BUILD=false + shift + ;; + -h|--help) + print_usage + exit 0 + ;; + *) + echo -e "${RED}错误: 未知选项 $1${NC}" + print_usage + exit 1 + ;; + esac +done + +# 检查是否至少选择了一个构建目标 +if [ "$BUILD_ANDROID" = false ] && [ "$BUILD_IOS" = false ] && [ "$BUILD_IOS_SIMULATOR" = false ] && [ "$BUILD_HARMONY" = false ] && [ "$BUILD_PYTHON" = false ]; then + echo -e "${RED}错误: 请至少选择一个构建目标 (--android, --ios, --ios-simulator, --harmony, 或 --python)${NC}" + print_usage + exit 1 +fi + +# 检查 Android NDK +if [ "$BUILD_ANDROID" = true ]; then + if [ -z "$ANDROID_NDK" ]; then + echo -e "${RED}错误: 构建 Android 必须指定 NDK 路径 (使用 --ndk)${NC}" + echo -e "${RED}示例: $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879${NC}" + exit 1 + fi + + # 展开路径中的 ~ + ANDROID_NDK="${ANDROID_NDK/#\~/$HOME}" + + if [ ! 
-d "$ANDROID_NDK" ]; then + echo -e "${RED}错误: NDK 路径不存在: $ANDROID_NDK${NC}" + exit 1 + fi +fi + +# 查找鸿蒙工具链的函数 +find_harmony_toolchain() { + # 如果环境变量已设置,优先使用 + if [ -n "$HARMONY_HOME" ]; then + # 支持两种路径格式:直接指定或带 native 子目录 + if [ -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then + echo "$HARMONY_HOME" + return 0 + elif [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then + echo "$HARMONY_HOME" + return 0 + fi + fi + + # 优先查找 ~/Library/OpenHarmony/Sdk,支持版本号目录 + local sdk_base="$HOME/Library/OpenHarmony/Sdk" + if [ -d "$sdk_base" ]; then + # 查找所有版本号目录下的 native 或 toolchains + # 按版本号倒序排列,优先使用最新版本 + for version_dir in $(ls -d "$sdk_base"/* 2>/dev/null | sort -Vr); do + if [ -d "$version_dir" ]; then + # 尝试 native/build/cmake/ohos.toolchain.cmake + if [ -f "$version_dir/native/build/cmake/ohos.toolchain.cmake" ]; then + echo "$version_dir/native" + return 0 + fi + # 尝试 build/cmake/ohos.toolchain.cmake + if [ -f "$version_dir/build/cmake/ohos.toolchain.cmake" ]; then + echo "$version_dir" + return 0 + fi + fi + done + fi + + # 其他可能的路径 + local possible_paths=( + "$HOME/Library/OpenHarmony/Sdk/native" + "$HOME/HarmonyOS/Sdk/native" + "$HOME/.ohos/native" + "/opt/HarmonyOS/Sdk/native" + "/usr/local/HarmonyOS/Sdk/native" + "$HOME/Library/HarmonyOS/Sdk/native" + ) + + # 尝试查找 + for path in "${possible_paths[@]}"; do + if [ -n "$path" ] && [ -f "$path/build/cmake/ohos.toolchain.cmake" ]; then + echo "$path" + return 0 + fi + done + + # 限制搜索范围,只在 OpenHarmony/Sdk 目录下快速查找 + # 避免在整个 OpenHarmony 目录下递归查找,这可能会很慢 + local found=$(find "$HOME/Library/OpenHarmony/Sdk" -maxdepth 4 -type f -name "ohos.toolchain.cmake" 2>/dev/null | head -1) + if [ -n "$found" ]; then + # 从 ohos.toolchain.cmake 向上查找 native 目录或 SDK 根目录 + found=$(dirname "$found") + if [ "$(basename "$found")" = "cmake" ]; then + found=$(dirname "$found") + if [ "$(basename "$found")" = "build" ]; then + found=$(dirname "$found") + fi + fi + echo "$found" + return 0 + fi + + return 1 +} + +# 检查鸿蒙工具链 
+if [ "$BUILD_HARMONY" = true ]; then + # 展开路径中的 ~ + if [ -n "$HARMONY_HOME" ]; then + HARMONY_HOME="${HARMONY_HOME/#\~/$HOME}" + fi + + # 尝试查找工具链 + if [ -z "$HARMONY_HOME" ] || [ ! -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then + # 检查是否在 native 子目录下 + if [ -n "$HARMONY_HOME" ] && [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then + HARMONY_HOME="$HARMONY_HOME/native" + else + echo -e "${YELLOW}正在查找鸿蒙工具链...${NC}" + HARMONY_HOME=$(find_harmony_toolchain) + + if [ -z "$HARMONY_HOME" ]; then + echo -e "${RED}错误: 找不到鸿蒙工具链${NC}" + echo -e "${RED}默认查找路径: ~/Library/OpenHarmony/Sdk/*/native${NC}" + echo -e "${RED}请使用 --harmony-home 指定路径,或设置 HARMONY_HOME 环境变量${NC}" + echo -e "${RED}工具链文件应位于: /build/cmake/ohos.toolchain.cmake 或 /native/build/cmake/ohos.toolchain.cmake${NC}" + exit 1 + fi + + # 如果找到的是带 native 的路径,需要调整 + if [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then + HARMONY_HOME="$HARMONY_HOME/native" + fi + fi + + echo -e "${GREEN}找到鸿蒙工具链: $HARMONY_HOME${NC}" + fi + + # 验证工具链文件存在 + if [ ! 
-f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then + echo -e "${RED}错误: 鸿蒙工具链文件不存在: $HARMONY_HOME/build/cmake/ohos.toolchain.cmake${NC}" + exit 1 + fi +fi + +echo -e "${GREEN}========================================${NC}" +echo -e "${GREEN}MNN 统一构建脚本${NC}" +echo -e "${GREEN}========================================${NC}" +echo "项目根目录: $PROJECT_ROOT" +echo "输出目录: $OUTPUT_DIR" +echo "" +echo -e "${BLUE}构建目标:${NC}" +[ "$BUILD_ANDROID" = true ] && echo " ✓ Android" +[ "$BUILD_IOS" = true ] && echo " ✓ iOS (真机 arm64)" +[ "$BUILD_IOS_SIMULATOR" = true ] && echo " ✓ iOS (模拟器 x86_64 + arm64)" +[ "$BUILD_HARMONY" = true ] && echo " ✓ 鸿蒙 (arm64-v8a)" +[ "$BUILD_PYTHON" = true ] && echo " ✓ Python" +echo "" + +cd "$PROJECT_ROOT" +mkdir -p "$OUTPUT_DIR" + +# ============================================================================ +# 构建 Android 版本 +# ============================================================================ +if [ "$BUILD_ANDROID" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 Android 版本${NC}" + echo -e "${GREEN}========================================${NC}" + + export ANDROID_NDK + + ANDROID_BUILD_DIR="project/android" + ANDROID_OUTPUT_DIR="$OUTPUT_DIR/android" + + cd "$ANDROID_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的 Android 构建...${NC}" + rm -rf build_32 build_64 export + fi + + # 在项目根目录创建输出目录 + mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" + + # 构建 armeabi-v7a + echo -e "${BLUE}构建 armeabi-v7a...${NC}" + mkdir -p build_32 + cd build_32 + cmake ../../../ \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DANDROID_ABI="armeabi-v7a" \ + -DANDROID_STL=c++_static \ + -DANDROID_NATIVE_API_LEVEL=android-14 \ + -DANDROID_TOOLCHAIN=clang \ + -DMNN_USE_LOGCAT=false \ + -DMNN_USE_SSE=OFF \ + -DMNN_BUILD_TEST=ON \ + -DMNN_ARM82=OFF \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + 
-DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DMNN_SEP_BUILD=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ + -DNATIVE_LIBRARY_OUTPUT=. \ + -DNATIVE_INCLUDE_OUTPUT=. \ + > /dev/null + + make -j4 MNN > /dev/null + cd .. + + # 构建 arm64-v8a + echo -e "${BLUE}构建 arm64-v8a...${NC}" + mkdir -p build_64 + cd build_64 + cmake ../../../ \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DANDROID_ABI="arm64-v8a" \ + -DANDROID_STL=c++_static \ + -DANDROID_NATIVE_API_LEVEL=android-21 \ + -DANDROID_TOOLCHAIN=clang \ + -DMNN_USE_LOGCAT=false \ + -DMNN_BUILD_BENCHMARK=ON \ + -DMNN_USE_SSE=OFF \ + -DMNN_BUILD_TEST=ON \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DMNN_SEP_BUILD=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ + -DNATIVE_LIBRARY_OUTPUT=. \ + -DNATIVE_INCLUDE_OUTPUT=. \ + > /dev/null + + make -j4 MNN > /dev/null + cd .. + + # 导出文件 + echo -e "${BLUE}导出 Android 库文件...${NC}" + mkdir -p export/android/{armeabi-v7a,arm64-v8a}/libs export/android/include/MNN + + cp build_32/*.so export/android/armeabi-v7a/libs/ 2>/dev/null || true + cp build_64/*.so export/android/arm64-v8a/libs/ 2>/dev/null || true + cp -r ../../include/MNN/* export/android/include/MNN/ + + # 复制到统一输出目录 + # 如果目标路径存在但不是目录,先删除 + if [ -e "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ] && [ ! 
-d "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ]; then + rm -f "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" + fi + mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" + cp -r export/android/* "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR/" + + echo -e "${GREEN}Android 构建完成!${NC}" + echo "输出位置: $ANDROID_OUTPUT_DIR" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建 iOS 真机版本 +# ============================================================================ +if [ "$BUILD_IOS" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 iOS 真机版本${NC}" + echo -e "${GREEN}========================================${NC}" + + # 检查是否在 macOS 上 + if [[ "$OSTYPE" != "darwin"* ]]; then + echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" + exit 1 + fi + + # 检查 Xcode + if ! command -v xcodebuild &> /dev/null; then + echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" + exit 1 + fi + + IOS_BUILD_DIR="project/ios" + IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/device" + + cd "$IOS_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的 iOS 真机构建...${NC}" + rm -rf MNN-iOS-CPU-GPU/Static/ios_64 + find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_64/*" 2>/dev/null | xargs rm -f 2>/dev/null || true + fi + + mkdir -p MNN-iOS-CPU-GPU/Static + cd MNN-iOS-CPU-GPU/Static + + # 构建真机版本 (arm64) + echo -e "${BLUE}构建 iOS 真机版本 (arm64)...${NC}" + rm -rf ios_64 + mkdir ios_64 + cd ios_64 + cmake "$PROJECT_ROOT" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ + -DENABLE_BITCODE=0 \ + -DMNN_AAPL_FMWK=1 \ + -DMNN_SEP_BUILD=OFF \ + -DMNN_BUILD_SHARED_LIBS=false \ + -DMNN_USE_THREAD_POOL=OFF \ + -DPLATFORM=OS64 \ + -DARCHS="arm64" \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_METAL=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + 
-DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + > /dev/null + make MNN -j8 > /dev/null + cd .. + + # 输出真机版本 + mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" + rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" + cp -R ios_64/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" + + # 清理 + rm -rf ios_64 + + echo -e "${GREEN}iOS 真机构建完成!${NC}" + echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建 iOS 模拟器版本 +# ============================================================================ +if [ "$BUILD_IOS_SIMULATOR" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 iOS 模拟器版本${NC}" + echo -e "${GREEN}========================================${NC}" + + # 检查是否在 macOS 上 + if [[ "$OSTYPE" != "darwin"* ]]; then + echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" + exit 1 + fi + + # 检查 Xcode + if ! command -v xcodebuild &> /dev/null; then + echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" + exit 1 + fi + + IOS_BUILD_DIR="project/ios" + IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/simulator" + + cd "$IOS_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的 iOS 模拟器构建...${NC}" + rm -rf MNN-iOS-CPU-GPU/Static/ios_simulator* + find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_simulator*/*" 2>/dev/null | xargs rm -f 2>/dev/null || true + fi + + mkdir -p MNN-iOS-CPU-GPU/Static + cd MNN-iOS-CPU-GPU/Static + + # 构建 x86_64 模拟器版本(尝试构建,失败则跳过) + BUILD_X86_64=false + echo -e "${BLUE}构建 iOS 模拟器版本 (x86_64)...${NC}" + rm -rf ios_simulator_x86 + mkdir ios_simulator_x86 + cd ios_simulator_x86 + if cmake "$PROJECT_ROOT" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ + -DENABLE_BITCODE=0 \ + -DMNN_AAPL_FMWK=1 \ + -DMNN_SEP_BUILD=OFF \ + -DMNN_BUILD_SHARED_LIBS=false \ + -DMNN_USE_THREAD_POOL=OFF \ + -DPLATFORM=SIMULATOR64 \ + 
-DARCHS="x86_64" \ + -DMNN_ARM82=OFF \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_METAL=OFF \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + && make MNN -j8; then + if [ -f "MNN.framework/MNN" ]; then + BUILD_X86_64=true + echo -e "${GREEN}x86_64 模拟器构建成功${NC}" + cd .. + else + echo -e "${YELLOW}警告: x86_64 模拟器构建产物不存在,跳过此架构${NC}" + cd .. + rm -rf ios_simulator_x86 + fi + else + echo -e "${YELLOW}警告: x86_64 模拟器构建失败,跳过此架构(这在 Apple Silicon Mac 上是正常的)${NC}" + cd .. + rm -rf ios_simulator_x86 + fi + + # 构建 arm64 模拟器版本 + echo -e "${BLUE}构建 iOS 模拟器版本 (arm64)...${NC}" + rm -rf ios_simulator_arm64 + mkdir ios_simulator_arm64 + cd ios_simulator_arm64 + if ! cmake "$PROJECT_ROOT" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ + -DENABLE_BITCODE=0 \ + -DMNN_AAPL_FMWK=1 \ + -DMNN_SEP_BUILD=OFF \ + -DMNN_BUILD_SHARED_LIBS=false \ + -DMNN_USE_THREAD_POOL=OFF \ + -DPLATFORM=SIMULATOR64 \ + -DARCHS="arm64" \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_METAL=OFF \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON; then + echo -e "${RED}错误: arm64 模拟器 cmake 配置失败${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + if ! make MNN -j8; then + echo -e "${RED}错误: arm64 模拟器编译失败${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + cd .. + + # 验证构建产物 + if [ ! 
-f "ios_simulator_arm64/MNN.framework/MNN" ]; then + echo -e "${RED}错误: 未找到 arm64 模拟器框架文件${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + + # 合并模拟器架构 + echo -e "${BLUE}合并模拟器架构...${NC}" + rm -rf ios_simulator + mkdir ios_simulator + + if [ "$BUILD_X86_64" = true ] && [ -f "ios_simulator_x86/MNN.framework/MNN" ]; then + # 合并 x86_64 + arm64 + echo -e "${BLUE}合并 x86_64 + arm64 架构...${NC}" + cp -R ios_simulator_x86/MNN.framework ios_simulator/MNN.framework + mv ios_simulator/MNN.framework/MNN ios_simulator/MNN.framework/MNN_x86 + if ! lipo -create ios_simulator/MNN.framework/MNN_x86 ios_simulator_arm64/MNN.framework/MNN -output ios_simulator/MNN.framework/MNN; then + echo -e "${RED}错误: 合并架构失败${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + rm ios_simulator/MNN.framework/MNN_x86 + else + # 仅使用 arm64 + echo -e "${BLUE}仅使用 arm64 架构(x86_64 不可用)...${NC}" + cp -R ios_simulator_arm64/MNN.framework ios_simulator/MNN.framework + fi + + # 输出模拟器版本 + mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" + rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" + cp -R ios_simulator/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" + + # 清理临时目录 + rm -rf ios_simulator ios_simulator_x86 ios_simulator_arm64 + + echo -e "${GREEN}iOS 模拟器构建完成!${NC}" + echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建鸿蒙版本 +# ============================================================================ +if [ "$BUILD_HARMONY" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建鸿蒙版本${NC}" + echo -e "${GREEN}========================================${NC}" + + export HARMONY_HOME + + HARMONY_BUILD_DIR="project/harmony" + HARMONY_OUTPUT_DIR="$OUTPUT_DIR/harmony" + + cd "$HARMONY_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的鸿蒙构建...${NC}" + rm -rf build_64 export + fi + + # 在项目根目录创建输出目录 + mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" + + # 构建 arm64-v8a + 
echo -e "${BLUE}构建 arm64-v8a...${NC}" + mkdir -p build_64 + cd build_64 + + cmake "$PROJECT_ROOT" \ + -DCMAKE_TOOLCHAIN_FILE=$HARMONY_HOME/build/cmake/ohos.toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DOHOS_ARCH="arm64-v8a" \ + -DOHOS_STL=c++_static \ + -DMNN_USE_LOGCAT=true \ + -DMNN_BUILD_BENCHMARK=ON \ + -DMNN_USE_SSE=OFF \ + -DMNN_SUPPORT_BF16=OFF \ + -DMNN_BUILD_TEST=ON \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DMNN_SEP_BUILD=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + -DOHOS_PLATFORM_LEVEL=9 \ + -DNATIVE_LIBRARY_OUTPUT=. \ + -DNATIVE_INCLUDE_OUTPUT=. \ + > /dev/null + + make -j4 MNN > /dev/null + cd .. + + # 导出文件 + echo -e "${BLUE}导出鸿蒙库文件...${NC}" + mkdir -p export/harmony/arm64-v8a/libs export/harmony/include/MNN + + cp build_64/*.so export/harmony/arm64-v8a/libs/ 2>/dev/null || true + cp -r ../../include/MNN/* export/harmony/include/MNN/ + + # 复制到统一输出目录 + # 如果目标路径存在但不是目录,先删除 + if [ -e "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ] && [ ! -d "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ]; then + rm -f "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" + fi + mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" + cp -r export/harmony/* "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR/" + + echo -e "${GREEN}鸿蒙构建完成!${NC}" + echo "输出位置: $HARMONY_OUTPUT_DIR" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建 Python 版本 +# ============================================================================ +if [ "$BUILD_PYTHON" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 Python 版本${NC}" + echo -e "${GREEN}========================================${NC}" + + # 检查 Python + if ! 
command -v $PYTHON_CMD &> /dev/null; then + echo -e "${RED}错误: 找不到 Python 命令 '$PYTHON_CMD'${NC}" + exit 1 + fi + + PYTHON_OUTPUT_DIR="$OUTPUT_DIR/python" + + # 使用之前创建的 build_pymnn.sh 脚本 + if [ -f "$PROJECT_ROOT/build_pymnn.sh" ]; then + PYTHON_BUILD_ARGS="-o $PYTHON_OUTPUT_DIR" + [ -n "$VERSION" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -v $VERSION" + [ -n "$DEPS_OPTIONS" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -d $DEPS_OPTIONS" + [ "$CLEAN_BUILD" = false ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --no-clean" + PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --python $PYTHON_CMD" + + bash "$PROJECT_ROOT/build_pymnn.sh" $PYTHON_BUILD_ARGS + else + echo -e "${YELLOW}警告: 未找到 build_pymnn.sh,使用基本构建方式...${NC}" + + cd pymnn/pip_package + + # 构建依赖 + if [ -n "$DEPS_OPTIONS" ]; then + $PYTHON_CMD build_deps.py $DEPS_OPTIONS + else + $PYTHON_CMD build_deps.py + fi + + # 构建 wheel + $PYTHON_CMD -m pip install -U numpy wheel setuptools --quiet + rm -rf build dist + + BUILD_ARGS="" + [ -n "$VERSION" ] && BUILD_ARGS="--version $VERSION" + [ -n "$DEPS_OPTIONS" ] && BUILD_ARGS="$BUILD_ARGS --deps $DEPS_OPTIONS" + + $PYTHON_CMD setup.py bdist_wheel $BUILD_ARGS + + mkdir -p "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR" + cp dist/*.whl "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR/" + + cd "$PROJECT_ROOT" + fi + + echo -e "${GREEN}Python 构建完成!${NC}" + echo "输出位置: $PYTHON_OUTPUT_DIR" +fi + +# ============================================================================ +# 总结 +# ============================================================================ +echo "" +echo -e "${GREEN}========================================${NC}" +echo -e "${GREEN}所有构建完成!${NC}" +echo -e "${GREEN}========================================${NC}" +echo "" +echo "输出目录: $PROJECT_ROOT/$OUTPUT_DIR" +echo "" +[ "$BUILD_ANDROID" = true ] && echo -e "${GREEN}Android:${NC} $OUTPUT_DIR/android" +[ "$BUILD_IOS" = true ] && echo -e "${GREEN}iOS (真机):${NC} $OUTPUT_DIR/ios/device" +[ "$BUILD_IOS_SIMULATOR" = true ] && echo -e "${GREEN}iOS (模拟器):${NC} 
$OUTPUT_DIR/ios/simulator" +[ "$BUILD_HARMONY" = true ] && echo -e "${GREEN}鸿蒙:${NC} $OUTPUT_DIR/harmony" +[ "$BUILD_PYTHON" = true ] && echo -e "${GREEN}Python:${NC} $OUTPUT_DIR/python" +echo "" + + From 6cd5795b0fb154607ed94d911bb6c4c115f13cef Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 11:04:54 +0800 Subject: [PATCH 054/314] Merge pull request #4009 from HenryDen/default_opt Add a compile option and macro to default enable kleidiAI GitOrigin-RevId: d252203d159374844e90bfe13589b9c0c36f62ee --- CMakeLists.txt | 1 + source/backend/cpu/arm/CMakeLists.txt | 3 +++ source/core/Backend.hpp | 6 ++++-- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 67502b606b..f99e37ec1c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -258,6 +258,7 @@ option(MNN_VULKAN "Enable Vulkan" OFF) option(MNN_ARM82 "Enable ARMv8.2's FP16 Compute" ON) option(MNN_SUPPORT_FP16_ARMV7 "Enable ARMv8.2's FP16 Compute for armv7 arch, may cause library not valid for 32 bit cpu" OFF) option(MNN_KLEIDIAI "Enable KLEIDIAI" ON) +option(MNN_KLEIDIAI_DEFAULT_ON "Use KLEIDIAI kernels by default" OFF) option(MNN_ONEDNN "Enable oneDNN" OFF) option(MNN_AVX2 "Open AVX2 Compile for x86 if possible" ON) option(MNN_AVX512 "Enable AVX512" OFF) diff --git a/source/backend/cpu/arm/CMakeLists.txt b/source/backend/cpu/arm/CMakeLists.txt index 18fca54a4e..61ebce6bdc 100644 --- a/source/backend/cpu/arm/CMakeLists.txt +++ b/source/backend/cpu/arm/CMakeLists.txt @@ -36,6 +36,9 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR AR if (MNN_KLEIDIAI) include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/KleidiAI.cmake) download_kleidiai_and_collect_sources() + if(MNN_KLEIDIAI_DEFAULT_ON) + add_definitions(-DMNN_DEFAULT_USE_KLEIDIAI) + endif() endif() if (MNN_SME2) diff --git a/source/core/Backend.hpp b/source/core/Backend.hpp index bcf618c3c9..6850b6b4f6 100644 --- a/source/core/Backend.hpp +++ b/source/core/Backend.hpp @@ -68,9 +68,11 
@@ struct RuntimeHint { // whether to use Arm sme2 cores when threads>1 bool useArmSme2Cores = true; - +#ifdef MNN_DEFAULT_USE_KLEIDIAI + bool enableKleidiAI = true; +#else bool enableKleidiAI = false; - +#endif // Use CPU Ids std::vector cpuIds; From d4812d3a68863cf85c9a1e50bc068c15df1ffbe0 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 11:42:22 +0800 Subject: [PATCH 055/314] Merge branch feature/add_4th_groupchat into master Title: [Doc:Update] update dingtalk in README. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 本次代码评审的主要改动是对README文件中的钉钉群信息进行了更新,包括群号、状态以及删除了一些过时的信息。 Link: https://code.alibaba-inc.com/AliNN/AliNNPrivate/codereview/25029869 GitOrigin-RevId: 323623143de7fac53e2a4683e9a3c2090f392ae6 --- README.md | 14 +++++++------- README_CN.md | 10 ++++------ README_JP.md | 9 +++++---- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 5fe168ed05..7959890c16 100644 --- a/README.md +++ b/README.md @@ -6,13 +6,13 @@ [![日本語バージョン](https://img.shields.io/badge/Language-%E6%97%A5%E6%9C%AC%E8%AA%9E-green)](README_JP.md) [![MNN Homepage](https://img.shields.io/badge/Homepage-Visit-green)](http://www.mnn.zone) -[![MNN Chat App](https://img.shields.io/badge/Apps-MNN_Chat-blue)](./apps/Android/MnnLlmChat/README.md) -[![TaoAvatar](https://img.shields.io/badge/Apps-MNN_TaoAvatar-blue)](./apps/Android/Mnn3dAvatar/README.md) +[![MNN Chat App](https://img.shields.io/badge/Apps-MNN_Chat-blue)](./apps/Android/MnnLlmChat/README.md) +[![TaoAvatar](https://img.shields.io/badge/Apps-MNN_TaoAvatar-blue)](./apps/Android/Mnn3dAvatar/README.md) ## News 🔥 - [2025/10/16] Support Qwen3-VL Series. -- [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! 
[MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md) +- [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! [MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md)

Icon

@@ -154,13 +154,13 @@ The group discussions are predominantly Chinese. But we welcome and will help En Dingtalk discussion groups: -Group #1 (Full): 23329087 +Group #4 (Available): 160170007549 -Group #2 (Full): 23350225 +Group #3 (Full) -Group #3: QR code: +Group #2 (Full): 23350225 -![MNN-3](doc/dingdingmnn3.png) +Group #1 (Full): 23329087 ## Historical Paper diff --git a/README_CN.md b/README_CN.md index edcf823a28..f769a1e14b 100644 --- a/README_CN.md +++ b/README_CN.md @@ -111,12 +111,10 @@ MNN适配的硬件架构与精度详见下表: ## 社区交流与反馈 钉钉群组: -- 钉钉群1:23329087 -- 钉钉群2:23350225 -- 钉钉群3:扫描二维码加入 - -![MNN-3](doc/dingdingmnn3.png) - +- 钉钉群3 (可加入): 160170007549 +- 钉钉群3 (已无法加入) +- 钉钉群2 (已满): 23350225 +- 钉钉群1 (已满): 23329087 ## 历史论文 diff --git a/README_JP.md b/README_JP.md index c2baa58d94..2f33def31a 100644 --- a/README_JP.md +++ b/README_JP.md @@ -117,13 +117,14 @@ MNN(テンソル計算エンジン)に基づいて、推論、トレーニ Dingtalkディスカッショングループ: -グループ#1(満員):23329087 -グループ#2(満員):23350225 +グループ#4 :160170007549 -グループ#3:QRコード: +グループ#3 (満員) -![MNN-3](doc/dingdingmnn3.png) +グループ#2(満員):23350225 + +グループ#1(満員):23329087 ## 歴史的な論文 From 5d57463911d11249c73dd9d651344d53fe8a7063 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 15:02:01 +0800 Subject: [PATCH 056/314] Merge pull request #4027 from codefuturedalao/master [BugFix] fix a bug in compute mGroupWithComputeRate GitOrigin-RevId: 0a30b5c040bc34aff1de94e7fa571ebb8f2c20fa --- source/backend/cpu/CPUBackend.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/source/backend/cpu/CPUBackend.cpp b/source/backend/cpu/CPUBackend.cpp index 0e0bc1f136..95cbd903b7 100644 --- a/source/backend/cpu/CPUBackend.cpp +++ b/source/backend/cpu/CPUBackend.cpp @@ -491,6 +491,7 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p currentRate *= decreaseRate; totalComputeRate += currentRate * selectSize; mGroupWithComputeRate.emplace_back(std::make_pair(currentRate * selectSize, selectSize)); + groupIndex--; } for (auto& g : mGroupWithComputeRate) { 
g.first = g.first / totalComputeRate; From b969229773594b61fa1a9175d0d3bf7101574db0 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 19:18:35 +0800 Subject: [PATCH 057/314] Merge pull request #4076 from jxt1234/feature/smallmodel_opt Feature/smallmodel opt GitOrigin-RevId: 5610add6e64c6d49f8b984d0d744c85f206f2be7 --- source/backend/cpu/CPUBackend.cpp | 7 +- source/backend/cpu/CPUBackend.hpp | 3 + source/backend/cpu/CPUBinary.cpp | 60 +- source/backend/cpu/CPUBinary.hpp | 4 + source/backend/cpu/CPUMatMul.cpp | 28 +- source/backend/cpu/CPUMatMul.hpp | 7 +- source/backend/cpu/CPURNNSequenceGRU.cpp | 70 +- source/backend/cpu/CPURNNSequenceGRU.hpp | 15 +- source/backend/cpu/CPURaster.cpp | 631 +++++++++--------- source/backend/cpu/CPURaster.hpp | 3 +- source/backend/cpu/ThreadPool.cpp | 32 +- source/backend/cpu/ThreadPool.hpp | 6 +- .../backend/cpu/compute/CommonOptFunction.cpp | 88 ++- source/core/Concurrency.h | 13 +- source/core/OpCommonUtils.cpp | 91 --- source/core/OpCommonUtils.hpp | 1 - source/core/TensorUtils.cpp | 12 + source/core/TensorUtils.hpp | 1 + source/geometry/GeometryComputerUtils.cpp | 4 +- source/geometry/GeometryComputerUtils.hpp | 2 +- source/geometry/GeometryReduce.cpp | 104 ++- source/geometry/GeometryReshape.cpp | 11 +- source/math/Vec.hpp | 3 +- test/core/ThreadPoolTest.cpp | 6 +- tools/cpp/ExprDebug.hpp | 53 +- tools/cpp/ModuleBasic.cpp | 46 +- transformers/llm/engine/src/llm.cpp | 21 +- 27 files changed, 747 insertions(+), 575 deletions(-) diff --git a/source/backend/cpu/CPUBackend.cpp b/source/backend/cpu/CPUBackend.cpp index 95cbd903b7..8d284aa33b 100644 --- a/source/backend/cpu/CPUBackend.cpp +++ b/source/backend/cpu/CPUBackend.cpp @@ -104,15 +104,14 @@ void CPURuntime::_bindCPUCore() const { #ifdef MNN_USE_THREAD_POOL if (nullptr != mThreadPool) { mThreadPool->active(); - mThreadPool->enqueue(std::make_pair([&](int i) { + ThreadPool::TASK task = std::make_pair([&](int i) { MNNSetSchedAffinity(lockCPUIndexes[i].first, 
lockCPUIndexes[i].second); - return 0; - }, mThreadNumber), mTaskIndex); + }, mThreadNumber); + mThreadPool->enqueue(&task, mTaskIndex); mThreadPool->deactive(); } #endif } - void CPURuntime::_resetThreadPool() const { mThreadNumber = std::max(1, mThreadNumber); mThreadNumber = std::min(mThreadNumber, MAX_THREAD_NUMBER); diff --git a/source/backend/cpu/CPUBackend.hpp b/source/backend/cpu/CPUBackend.hpp index 884036eb38..ec4c555dec 100644 --- a/source/backend/cpu/CPUBackend.hpp +++ b/source/backend/cpu/CPUBackend.hpp @@ -176,6 +176,9 @@ class CPUBackend : public Backend { #ifdef MNN_USE_THREAD_POOL inline int taskIndex() const {return mRuntime->mTaskIndex;} inline ThreadPool* threadPool() const {return mRuntime->mThreadPool;} + void enqueue(ThreadPool::TASK& task) const { + threadPool()->enqueue(&task, taskIndex()); + } #endif static void initCreatorMap(); static size_t getBytes(const Backend* backend, const Tensor* output); diff --git a/source/backend/cpu/CPUBinary.cpp b/source/backend/cpu/CPUBinary.cpp index 059e502d0b..61ccf4fca3 100644 --- a/source/backend/cpu/CPUBinary.cpp +++ b/source/backend/cpu/CPUBinary.cpp @@ -45,6 +45,37 @@ ErrorCode CPUBinary::onResize(const std::vector& inputs, const std::vec mThreadNum = threads; mWorkDiv = UP_DIV(mTotalSize, threads); } + int inpBytes = inputs[0]->getType().bytes(); + int outBytes = outputs[0]->getType().bytes(); + if (halide_type_float == inputs[0]->getType().code) { + inpBytes = static_cast(backend())->functions()->bytes; + } + if (halide_type_float == outputs[0]->getType().code) { + outBytes = static_cast(backend())->functions()->bytes; + } + bool outputInt = outputs[0]->getType().code == halide_type_int; + mTask = std::make_pair([this, inpBytes, outBytes, outputInt](int tId) { + int start = tId * mWorkDiv; + int realSize = ALIMIN(mWorkDiv, mTotalSize - start); + if (realSize > 0) { + auto inp0 = mInput0Ptr + start * inpBytes; + auto inp1 = mInput1Ptr + start * inpBytes; + if (mNeedBroadcastIndex == 0) { + inp0 = 
mInput0Ptr; + } else if (mNeedBroadcastIndex == 1) { + inp1 = mInput1Ptr; + } + auto out = mOutputPtr + start * outBytes; + mProc(out, inp0, inp1, realSize, mNeedBroadcastIndex); + if(mActivationType == 1 && outputInt) { + for(int i=0; i 0 ? val : 0; + ((int32_t *)out)[i] = res; + } + } + } + } , mThreadNum); return NO_ERROR; } @@ -67,31 +98,10 @@ ErrorCode CPUBinary::onExecute(const std::vector& inputs, const std::ve outBytes = static_cast(backend())->functions()->bytes; } auto precision = static_cast(backend())->precisionMode(); - - MNN_CONCURRENCY_BEGIN(tId, mThreadNum) { - int start = tId * mWorkDiv; - int realSize = ALIMIN(mWorkDiv, mTotalSize - start); - if (realSize > 0) { - auto inp0 = input0Ptr + start * inpBytes; - auto inp1 = input1Ptr + start * inpBytes; - if (mNeedBroadcastIndex == 0) { - inp0 = input0Ptr; - } else if (mNeedBroadcastIndex == 1) { - inp1 = input1Ptr; - } - auto out = outputPtr + start * outBytes; - mProc(out, inp0, inp1, realSize, mNeedBroadcastIndex); - if(mActivationType == 1 && output->getType().code == halide_type_int) { - for(int i=0; i 0 ? 
val : 0; - ((int32_t *)out)[i] = res; - } - } - } - } - MNN_CONCURRENCY_END(); - + mInput0Ptr = input0Ptr; + mInput1Ptr = input1Ptr; + mOutputPtr = outputPtr; + MNN_CONCURRENCY_ENQUEUE(mTask); if(mActivationType == 1 && output->getType().code == halide_type_float) { mActivationExe->onExecute(outputs, outputs); } diff --git a/source/backend/cpu/CPUBinary.hpp b/source/backend/cpu/CPUBinary.hpp index 9250df79ae..17cb3b5f47 100644 --- a/source/backend/cpu/CPUBinary.hpp +++ b/source/backend/cpu/CPUBinary.hpp @@ -33,6 +33,10 @@ class CPUBinary : public Execution { int mThreadNum; int mWorkDiv; std::shared_ptr mActivationExe; + std::pair, int> mTask; + uint8_t* mInput0Ptr = nullptr; + uint8_t* mOutputPtr = nullptr; + uint8_t* mInput1Ptr = nullptr; }; } // namespace MNN #endif /* CPUBinary_hpp */ diff --git a/source/backend/cpu/CPUMatMul.cpp b/source/backend/cpu/CPUMatMul.cpp index 4f0765f050..22b96a64ee 100644 --- a/source/backend/cpu/CPUMatMul.cpp +++ b/source/backend/cpu/CPUMatMul.cpp @@ -37,9 +37,8 @@ void CPUMatMul::_scheduleForVecE(int e, int l, int h) { param.BTranspose = mTransposeB; param.numberThread = numberThread; auto func = static_cast(backend())->functions()->MNNComputeMatMulForE_1; - mPreFunctions.emplace_back(std::make_pair([param, func]( - int tId, const float* A, const float* B, const float* biasPtr, float* C) { - func(A, B, C, biasPtr, ¶m, tId); + mPreFunctions.emplace_back(std::make_pair([param, func, this](int tId) { + func(mA, mB, mC, mBiasPtr, ¶m, tId); }, numberThread)); } @@ -54,9 +53,9 @@ void CPUMatMul::_scheduleForVec(int e, int l, int h) { auto func = static_cast(backend())->functions()->MNNComputeMatMulForH_1; // TODD: Support e = 1 MNN_ASSERT(h == 1); - mPreFunctions.emplace_back(std::make_pair([param, func]( - int tId, const float* A, const float* B, const float* biasPtr, float* C) { - func(A, B, C, biasPtr, ¶m, tId); + mPreFunctions.emplace_back(std::make_pair([param, func, this]( + int tId) { + func(mA, mB, mC, mBiasPtr, ¶m, tId); }, 
numberThread)); } @@ -100,8 +99,8 @@ ErrorCode CPUMatMul::onResize(const std::vector& inputs, const std::vec return OUT_OF_MEMORY; } - mPreFunctions.emplace_back(std::make_pair([BTPtrAlloc, l, h, this, core] (int tId, const float* APtr, const float* BPtr, const float* Bias, float* C) { - core->MNNPackForMatMul_B((float*)BTPtrAlloc.ptr(), BPtr, h, 1, l, mTransposeB); + mPreFunctions.emplace_back(std::make_pair([BTPtrAlloc, l, h, this, core] (int tId) { + core->MNNPackForMatMul_B((float*)BTPtrAlloc.ptr(), mB, h, 1, l, mTransposeB); } , 1)); bool useBias = false; MemChunk bdestAlloc; @@ -120,9 +119,9 @@ ErrorCode CPUMatMul::onResize(const std::vector& inputs, const std::vec } mTempBias = bdestAlloc; mPreFunctions.emplace_back(std::make_pair( - [biasLength, bdestAlloc, core](int tId, const float* APtr, const float* BPtr, const float* borigin, float* C) { + [biasLength, bdestAlloc, core, this](int tId) { ::memset(bdestAlloc.ptr(), 0, UP_DIV(biasLength, core->pack) * core->bytes * core->pack); - ::memcpy(bdestAlloc.ptr(), borigin, biasLength * core->bytes); + ::memcpy(bdestAlloc.ptr(), mBiasPtr, biasLength * core->bytes); }, 1)); } else { mUseBiasDirectly = true; @@ -167,11 +166,12 @@ ErrorCode CPUMatMul::onExecute(const std::vector& inputs, const std::ve } void CPUMatMul::execute(const float* APtr, const float* BPtr, float* CPtr, const float* biasPtr) { + mA = APtr; + mB = BPtr; + mC = CPtr; + mBiasPtr = biasPtr; for (auto& f : mPreFunctions) { - MNN_CONCURRENCY_BEGIN(tId, f.second) { - f.first(tId, APtr, BPtr, biasPtr, CPtr); - } - MNN_CONCURRENCY_END(); + MNN_CONCURRENCY_ENQUEUE(f); } if (mE > 0) { auto core = static_cast(backend())->functions(); diff --git a/source/backend/cpu/CPUMatMul.hpp b/source/backend/cpu/CPUMatMul.hpp index 872a77a9a8..48226795f0 100644 --- a/source/backend/cpu/CPUMatMul.hpp +++ b/source/backend/cpu/CPUMatMul.hpp @@ -29,7 +29,7 @@ class CPUMatMul : public Execution { bool mTransposeB; bool mTransposeC; bool mSupportMultiThread = false; - 
std::vector, int>> mPreFunctions; + std::vector, int>> mPreFunctions; bool mUseBiasDirectly = false; MemChunk mTempA; MemChunk mTempB; @@ -40,6 +40,11 @@ class CPUMatMul : public Execution { int mL; int mH; std::vector mPostParameters; + // For Execute Paramters + const float* mA = nullptr; + const float* mB = nullptr; + const float* mBiasPtr = nullptr; + float* mC = nullptr; }; } // namespace MNN diff --git a/source/backend/cpu/CPURNNSequenceGRU.cpp b/source/backend/cpu/CPURNNSequenceGRU.cpp index daae8811c7..0bda660e9c 100644 --- a/source/backend/cpu/CPURNNSequenceGRU.cpp +++ b/source/backend/cpu/CPURNNSequenceGRU.cpp @@ -10,30 +10,26 @@ #include #include "backend/cpu/CPUBackend.hpp" #include "backend/cpu/compute/ConvOpt.h" -#include "backend/cpu/compute/CommonOptFunction.h" #include "core/TensorUtils.hpp" namespace MNN { // implement GRU cell function // Ref: tensorflow/python/ops/rnn_cell_impl.py -void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, const bool linearBeforeReset, - std::shared_ptr& hiddenState, const int numUnits, Tensor* gateWeight, Tensor* gateBias, +void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, const bool linearBeforeReset, uint8_t* hiddenStateInput, const int numUnits, Tensor* gateWeight, Tensor* gateBias, Tensor* candidateWeight, Tensor* candidateBias, Tensor* recurrentBias, std::shared_ptr& inputAndState, std::shared_ptr& gate, - std::shared_ptr& resetHt) { - auto bn = static_cast(backend()); - auto mulFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_MUL); - auto addFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_ADD); - auto subFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_SUB); - auto tanhFunction = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_TANH, bn->precisionMode()); - auto bytes = bn->functions()->bytes; - auto sigmoidFunc = 
bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_SIGMOID, bn->precisionMode()); + std::shared_ptr& resetHt, uint8_t* hiddenStateOutput) { // gate is (z_t, r_t) + auto bytes = mRNNFunctions.bytes; + MNNBinaryExecute mulFunction = mRNNFunctions.mulFunction; + MNNBinaryExecute addFunction = mRNNFunctions.addFunction; + MNNBinaryExecute subFunction = mRNNFunctions.subFunction; + MNNUnaryExecute tanhFunction = mRNNFunctions.tanhFunction; + MNNUnaryExecute sigmoidFunction = mRNNFunctions.sigmoidFunction; auto inputAndStatePtr = inputAndState->host(); - auto hiddenStatePtr = hiddenState->host(); ::memcpy(inputAndStatePtr, input, inputLength * bytes); - ::memcpy(inputAndStatePtr + inputLength * bytes, hiddenStatePtr, numUnits * bytes); + ::memcpy(inputAndStatePtr + inputLength * bytes, hiddenStateInput, numUnits * bytes); inputAndState->setLength(1, inputLength + numUnits); // // [x_t, h_t-1] * [W_zr, R_zr]: (1, inputLength + numUnits) X (inputLength + numUnits, 2 * numUnits) @@ -42,9 +38,8 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, recurrentBias->setLength(1, 2 * numUnits); addFunction(gate->host(), gate->host(), recurrentBias->host(), 2*numUnits, -1); // (1, 2*numUnits) - const int gateSize = gate->elementSize(); auto gatePtr = gate->host(); - sigmoidFunc(gatePtr, gatePtr, gateSize); + sigmoidFunction(gatePtr, gatePtr, 2 * numUnits); // reset gate, // r_t is the second segment auto rtPtr = gatePtr + numUnits * bytes; @@ -52,7 +47,7 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, // calculate Rt (.) 
(Ht_1 * Rh + Rbh) auto recurrentHiddenBiasPtr = recurrentBias->host() + 2 * numUnits * bytes; auto rhWeightPtr = candidateWeight->host() + inputLength * numUnits * bytes; - mMatMulU2U->execute(hiddenState->host(), (float*)rhWeightPtr, resetHt->host(), (float*)recurrentHiddenBiasPtr); + mMatMulU2U->execute((float*)hiddenStateInput, (float*)rhWeightPtr, resetHt->host(), (float*)recurrentHiddenBiasPtr); mulFunction(resetHt->host(), rtPtr, resetHt->host(), numUnits, -1); // calculate Xt * Wh @@ -65,7 +60,7 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, // r_t: (1, numUnits) auto resetGatePtr = inputAndStatePtr + inputLength * bytes; // h_t1(1, numUnits) = r_t(1, numUnits) * h_t-1_(1, numUnits) - mulFunction(resetGatePtr, rtPtr, hiddenStatePtr, numUnits, -1); + mulFunction(resetGatePtr, rtPtr, hiddenStateInput, numUnits, -1); // deal with recurrent bias and linear_before_reset parameter auto recurrentBiasAddedPtr = inputAndStatePtr + (inputLength + numUnits) * bytes; auto recurrentHiddenBiasPtr = (float*)(recurrentBias->host() + 2 * numUnits * bytes); @@ -76,9 +71,9 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, } // h = (1-g)*t+g*h = t + g*(h-t) tanhFunction(resetHt->host(), rtPtr, numUnits); - subFunction(hiddenStatePtr, hiddenStatePtr, resetHt->host(), numUnits, -1); - mulFunction(hiddenStatePtr, hiddenStatePtr, gatePtr, numUnits, -1); - addFunction(hiddenStatePtr, hiddenStatePtr, resetHt->host(), numUnits, -1); + subFunction(hiddenStateOutput, hiddenStateInput, resetHt->host(), numUnits, -1); + mulFunction(hiddenStateOutput, hiddenStateOutput, gatePtr, numUnits, -1); + addFunction(hiddenStateOutput, hiddenStateOutput, resetHt->host(), numUnits, -1); inputAndState->setLength(1, inputLength + 2 * numUnits); } @@ -143,6 +138,13 @@ ErrorCode CPURNNSequenceGRU::onResize(const std::vector& inputs, const backend()->onReleaseBuffer(mInputAndState.get(), Backend::DYNAMIC); 
backend()->onReleaseBuffer(mGate.get(), Backend::DYNAMIC); backend()->onReleaseBuffer(mResetHt.get(), Backend::DYNAMIC); + auto bn = static_cast(backend()); + mRNNFunctions.mulFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_MUL); + mRNNFunctions.addFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_ADD); + mRNNFunctions.subFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_SUB); + mRNNFunctions.tanhFunction = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_TANH, bn->precisionMode()); + mRNNFunctions.bytes = bn->functions()->bytes; + mRNNFunctions.sigmoidFunction = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_SIGMOID, bn->precisionMode()); return NO_ERROR; } @@ -183,27 +185,29 @@ ErrorCode CPURNNSequenceGRU::onExecute(const std::vector& inputs, const const int inputCodeLength = input->length(2); // MNN_PRINT("inputSequenceLength:%d, batchSize:%d, inputCodeLength:%d, mNumUnits:%d, hiddenStateDataSize:%d\n", inputSequenceLength, batchSize, inputCodeLength, mNumUnits, hiddenStateDataSize); for (int b = 0; b < batchSize; ++b) { // swap order + auto hiddenStateInput = hiddenStatePtr; + auto hiddenStateOutput = hiddenStatePtr; if (inputSize > 1 + forwardParamNumber * (mIsBidirectionalRNN + 1)) { auto source = inputs[inputSize - 1]->host() + b * hiddenStateDataSize; - ::memcpy(hiddenStatePtr, source, hiddenStateDataSize); + hiddenStateInput = source; } else { ::memset(hiddenStatePtr, 0, hiddenStateDataSize); } for (int i = 0; i < inputSequenceLength; ++i) { const int inputOffset = i * SequenceStride + b * inputCodeLength; - runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, mHiddenState, mNumUnits, fwGateWeight, fwGateBias, - fwCandidateWeight, fwCandidateBias, fwRecurrentBias, mInputAndState, mGate, mResetHt); - if (mKeepAllOutputs) { - ::memcpy(outputPtr + (i * output->stride(0) + b * mNumUnits) * bytes, hiddenStatePtr, 
hiddenStateDataSize); + hiddenStateOutput = outputPtr + (i * output->stride(0) + b * mNumUnits) * bytes; } + runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, hiddenStateInput, mNumUnits, fwGateWeight, fwGateBias, + fwCandidateWeight, fwCandidateBias, fwRecurrentBias, mInputAndState, mGate, mResetHt, hiddenStateOutput); + + hiddenStateInput = hiddenStateOutput; } if ((mKeepAllOutputs && outputSize > 1) || !mKeepAllOutputs) { - ::memcpy(outputYhPtr, hiddenStatePtr, hiddenStateDataSize); + ::memcpy(outputYhPtr, hiddenStateOutput, hiddenStateDataSize); outputYhPtr += mNumUnits * bytes; } - } // backward rnn @@ -221,22 +225,24 @@ ErrorCode CPURNNSequenceGRU::onExecute(const std::vector& inputs, const auto outputBw = outputs[0]; auto const outputBwPtr = outputBw->host(); for (int b = 0; b < batchSize; ++b) { + auto hiddenStateInput = hiddenStatePtr; + auto hiddenStateOutput = hiddenStatePtr; if (inputSize > 1 + forwardParamNumber * 2) { auto source = inputs[inputSize - 1]->host() + (batchSize + b) * hiddenStateDataSize; - ::memcpy(hiddenStatePtr, source, hiddenStateDataSize); + hiddenStateInput = source; } else { ::memset(hiddenStatePtr, 0, hiddenStateDataSize); } for (int i = inputSequenceLength - 1; i >= 0; i--) { const int inputOffset = i * SequenceStride + b * inputCodeLength; - runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, mHiddenState, mNumUnits, bwGateWeight, bwGateBias, - bwCandidateWeight, bwCandidateBias, bwRecurrentBias, mInputAndState, mGate, mResetHt); if (mKeepAllOutputs) { - ::memcpy(outputBwPtr + (i * outputBw->stride(0) + (batchSize + b) * mNumUnits) * bytes, - hiddenStatePtr, hiddenStateDataSize); + hiddenStateOutput = outputBwPtr + (i * outputBw->stride(0) + (batchSize + b) * mNumUnits) * bytes; } + runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, hiddenStateInput, mNumUnits, bwGateWeight, bwGateBias, + bwCandidateWeight, bwCandidateBias, bwRecurrentBias, 
mInputAndState, mGate, mResetHt, hiddenStateOutput); + hiddenStateInput = hiddenStateOutput; } if ((mKeepAllOutputs && outputSize > 1) || !mKeepAllOutputs) { ::memcpy(outputYhPtr, hiddenStatePtr, hiddenStateDataSize); diff --git a/source/backend/cpu/CPURNNSequenceGRU.hpp b/source/backend/cpu/CPURNNSequenceGRU.hpp index 0987d13053..0125b9e8a1 100644 --- a/source/backend/cpu/CPURNNSequenceGRU.hpp +++ b/source/backend/cpu/CPURNNSequenceGRU.hpp @@ -11,6 +11,7 @@ #include "core/Execution.hpp" #include "CPUMatMul.hpp" +#include "backend/cpu/compute/CommonOptFunction.h" namespace MNN { class CPURNNSequenceGRU : public Execution { @@ -19,13 +20,20 @@ class CPURNNSequenceGRU : public Execution { virtual ~CPURNNSequenceGRU(); virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; - + struct RNNFuntions { + MNNBinaryExecute mulFunction; + MNNBinaryExecute addFunction; + MNNBinaryExecute subFunction; + MNNUnaryExecute tanhFunction; + MNNUnaryExecute sigmoidFunction; + int bytes; + }; private: void runRNNStep(const uint8_t* input, const int inputLength, const bool linearBeforeReset, - std::shared_ptr& hiddenState, const int numUnits, Tensor* gateWeight, Tensor* gateBias, + uint8_t* hiddenStateInput, const int numUnits, Tensor* gateWeight, Tensor* gateBias, Tensor* candidateWeight, Tensor* candidateBias, Tensor* recurrentBias, std::shared_ptr& inputAndState, std::shared_ptr& gate, - std::shared_ptr& resetHt); + std::shared_ptr& resetHt, uint8_t* hiddenStateOutput); bool mKeepAllOutputs; bool mIsBidirectionalRNN; bool mlinearBeforeReset; @@ -42,6 +50,7 @@ class CPURNNSequenceGRU : public Execution { std::shared_ptr mMatMulU2U; // For inputLength -> numUnit std::shared_ptr mMatMulI2U; + RNNFuntions mRNNFunctions; }; } // namespace MNN diff --git a/source/backend/cpu/CPURaster.cpp b/source/backend/cpu/CPURaster.cpp index 3272086531..1339089347 100644 --- 
a/source/backend/cpu/CPURaster.cpp +++ b/source/backend/cpu/CPURaster.cpp @@ -49,227 +49,6 @@ struct ReduceInfo { } }; -ErrorCode CPURaster::onResize(const std::vector &____inputs, const std::vector &outputs) { - MNN_ASSERT(outputs.size() == 1); - auto output = outputs[0]; - OpCommonUtils::rasterInputReset(____inputs, outputs[0]); - auto des = TensorUtils::getDescribe(output); - auto outputDes = TensorUtils::getDescribe(output); - mNeedZero = !TensorUtils::regionIsFull(output); - mZeroPoint = 0; - mUseThreads = false; - if (outputDes->quantAttr != nullptr && outputDes->applyQuant) { -#ifdef MNN_USE_SSE - mZeroPoint = (int)outputDes->quantAttr->zero + 128; -#else - mZeroPoint = (int)outputDes->quantAttr->zero; -#endif - } - mTempInput.clear(); - mFastBlit.clear(); - mCacheRegions.clear(); - mTempOutput = nullptr; - auto midFormat = MNN_DATA_FORMAT_NCHW; - mTempInputCopy.clear(); - mFast = false; - auto core = static_cast(backend())->functions(); - mSingleConvert.type = 0; - // all_srcFormat == dstFormat == NC4HW4 : Fast Exe - if (outputDes->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) { - mFast = true; - for (int i=0; i< des->regions.size(); ++i) { - auto& slice = des->regions[i]; - if (TensorUtils::getDescribe(slice.origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { - mFast = false; - break; - } - if (!OpCommonUtils::canBlitFast(slice, output, core->pack, true)) { - mFast = false; - break; - } - } - if (mFast) { - mUseThreads = des->regions.size() > 16 ? 
true : false; - for (int i=0; i< des->regions.size(); ++i) { - auto& slice = des->regions[i]; - if (slice.origin == nullptr) { - continue; - } - Tensor::InsideDescribe::Region newRegion; - OpCommonUtils::turnToPackRegion(slice, newRegion, output, core->pack, true); - mFastBlit.emplace_back(std::make_pair(slice.origin, std::move(newRegion))); - } - return NO_ERROR; - } - } - // srcNum == 1 && srcFormat != dstFormat : Single Convert - if (des->regions.size() == 1) { - OpCommonUtils::turnRegion2Convert(des->regions[0], output, mSingleConvert); - if (mSingleConvert.type > 0) { - mUseThreads = (mSingleConvert.batch * mSingleConvert.channel * mSingleConvert.area > LAUNCH_MULTI_THREADS_WORKLOAD) ? true : false; - return NO_ERROR; - } - } - // Acquire Buffer for temp output - // TODO: optimize it - if (MNN_DATA_FORMAT_NC4HW4 == outputDes->dimensionFormat) { - mTempOutput.reset(new Tensor); - TensorUtils::setupTensorInfo(output, mTempOutput.get(), midFormat); - } - if (nullptr != mTempOutput) { - auto res = backend()->onAcquireBuffer(mTempOutput.get(), Backend::DYNAMIC); - if (!res) { - return OUT_OF_MEMORY; - } - } - // input is NC4HW4 add Convert - std::vector forRelease; - TensorUtils::FuseWrap fuseUtils; - for (int i=0; i< des->regions.size(); ++i) { - auto& slice = des->regions[i]; - auto origin = slice.origin; - if (nullptr == origin /*|| nullptr == origin->host()*/) { - continue; - } - // if tensor is not NC4HW4 or has been merged, don't need deal - if (TensorUtils::getDescribe(origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { - if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - mTempInputCopy.emplace_back(std::make_pair(origin, &slice)); - continue; - } - // if NC4HW4's C%4 == 0, change convert to transpose and fuse it - if (origin->batch() == 1 && origin->channel() % core->pack == 0) { - int channel = origin->channel(); - int area = 1; - // conv3d/pool3d will has 5 dims, area = depth * width * 
height, otherwise area = width * height - for (int d = 2; d < origin->dimensions(); d++) { - area *= origin->length(d); - } - Tensor::InsideDescribe::Region regionTmp; - regionTmp.src.offset = 0; - regionTmp.src.stride[0] = area * core->pack; - regionTmp.src.stride[1] = 1; - regionTmp.src.stride[2] = core->pack; - regionTmp.dst.offset = 0; - regionTmp.dst.stride[0] = area * core->pack; - regionTmp.dst.stride[1] = area; - regionTmp.dst.stride[2] = 1; - regionTmp.size[0] = channel / core->pack; - regionTmp.size[1] = core->pack; - regionTmp.size[2] = area; - regionTmp.origin = slice.origin; - bool merge = fuseUtils.match(regionTmp, slice); - if (merge) { - std::shared_ptr newSlice(new Tensor::InsideDescribe::Region); - *newSlice = slice; - fuseUtils.apply(regionTmp, *newSlice); - // cache the merged tensor - if (newSlice->size[0] * newSlice->size[1] * newSlice->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - mTempInputCopy.emplace_back(std::make_pair(origin, newSlice.get())); - mCacheRegions.emplace_back(newSlice); - continue; - } - } - auto cache = static_cast(backend())->getCache(); - auto tempTensor = cache->findCacheTensor(origin, midFormat); - //MNN_ASSERT(CPUBackend::getBytes(backend(), origin) == 4); - if (nullptr == tempTensor) { - std::shared_ptr newTensor(new Tensor); - TensorUtils::copyShape(origin, newTensor.get()); - TensorUtils::getDescribe(newTensor.get())->dimensionFormat = midFormat; - TensorUtils::getDescribe(newTensor.get())->quantAttr = TensorUtils::getDescribe(origin)->quantAttr; - TensorUtils::getDescribe(newTensor.get())->applyQuant = TensorUtils::getDescribe(origin)->applyQuant;; - newTensor->buffer().type = origin->getType(); - TensorUtils::setLinearLayout(newTensor.get()); - mTempInput.insert(std::make_pair(origin, newTensor.get())); - auto res = backend()->onAcquireBuffer(newTensor.get(), Backend::DYNAMIC); - if (!res) { - return OUT_OF_MEMORY; - } - tempTensor = newTensor.get(); - 
TensorUtils::getDescribe(tempTensor)->useCount = TensorUtils::getDescribe(origin)->useCount; - cache->pushCacheTensor(newTensor, origin, midFormat); - } - if (--TensorUtils::getDescribe(tempTensor)->useCount == 0) { - forRelease.emplace_back(tempTensor); - } - if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - mTempInputCopy.emplace_back(std::make_pair(tempTensor, &slice)); - } - for (auto t : forRelease) { - backend()->onReleaseBuffer(t, Backend::DYNAMIC); - } - if (nullptr != mTempOutput) { - backend()->onReleaseBuffer(mTempOutput.get(), Backend::DYNAMIC); - } - auto threadNumber = static_cast(backend())->threadNumber(); - mHasReduce = false; - ReduceInfo reduceInfo; - for (auto& iter : mTempInputCopy) { - if (reduceInfo.compute(*iter.second)) { - mHasReduce = true; - break; - } - } - if (mTempInputCopy.size() == 1 && threadNumber > 1 && (!mHasReduce)) { - // Split to multi region - auto region = mTempInputCopy[0].second; - if (region->size[0] * region->size[1] * region->size[2] < LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = false; - return NO_ERROR; - } - if (region->size[0] * region->size[1] * region->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - auto tensorPtr = mTempInputCopy[0].first; - int pos = -1; - for (int i=0; i<3; ++i) { - if (region->size[i] > 1) { - pos = i; - break; - } - } - if (-1 == pos) { - // Don't need divide - return NO_ERROR; - } - mTempInputCopy.clear(); - int divSize = UP_DIV(region->size[pos], threadNumber); - for (int i=0; i cacheRegPtr(new Tensor::InsideDescribe::Region); - auto& cacheReg = *cacheRegPtr; - int sta = i * divSize; - int fin = sta + divSize; - fin = std::min(fin, region->size[pos]); - if (fin <= sta) { - break; - } - for (int v=0; v<3; ++v) { - cacheReg.src.stride[v] = region->src.stride[v]; - cacheReg.dst.stride[v] = region->dst.stride[v]; - } - int curSize = fin - sta; - for (int v=0; vsize[v]; - } - cacheReg.size[pos] = curSize; 
- cacheReg.src.offset = region->src.offset + sta * region->src.stride[pos]; - cacheReg.dst.offset = region->dst.offset + sta * region->dst.stride[pos]; - for (int v=pos+1; v<3; ++v) { - cacheReg.size[v] = region->size[v]; - } - mTempInputCopy.emplace_back(std::make_pair(tensorPtr, cacheRegPtr.get())); - mCacheRegions.emplace_back(cacheRegPtr); - } - } - return NO_ERROR; -} static void _transpose(int32_t* dstO, const int32_t* srcO, const Tensor::InsideDescribe::Region& region, int bytes) { int dims[4], keepDim = -1; for (int i = 0; i < 3; i++) { @@ -324,15 +103,12 @@ static void _2BitcopyWithStrideC4(uint8_t* dstO, const uint8_t* srcO, int size, } } -void CPURaster::executeFaster(const std::vector &inputs, const std::vector &outputs) const { +void CPURaster::executeFaster(const std::vector &inputs, const std::vector &outputs) { auto input = inputs[0]; auto output = outputs[0]; auto bytes = CPUBackend::getBytes(backend(), output); auto core = static_cast(backend())->functions(); - auto threadNum = static_cast(backend())->threadNumber(); - if (mNeedZero) { - ::memset(output->host(), mZeroPoint, static_cast(backend())->getTensorSize(output) * bytes); - } + int threadNum = static_cast(backend())->threadNumber(); auto byteC4 = bytes * core->pack; auto C4proc = core->MNN4BitcopyWithStride; switch (byteC4) { @@ -352,7 +128,7 @@ void CPURaster::executeFaster(const std::vector &inputs, const std::ve if (!mUseThreads) { threadNum = 1; } - MNN_CONCURRENCY_BEGIN(tId, threadNum) { + mTasks.emplace_back(std::make_pair([threadNum, this, output, bytes, C4proc, byteC4](int tId) { for (int u=(int)tId; uhost() == nullptr) { @@ -393,8 +169,7 @@ void CPURaster::executeFaster(const std::vector &inputs, const std::ve } } } - } - MNN_CONCURRENCY_END(); + }, threadNum)); } static BlitProc _selectUnitProc(int bytes, int stride, int ds) { @@ -596,97 +371,307 @@ static void _blit(const Tensor::InsideDescribe::Region& slice, int bytes, const } } void CPURaster::tensorConvert(Tensor* input, 
Tensor* output, int bytes) { - auto& subIb = input->buffer(); - auto& subOb = output->buffer(); - auto source = TensorUtils::getDescribe(input)->dimensionFormat; - auto dest = TensorUtils::getDescribe(output)->dimensionFormat; - if (subIb.dimensions <= 1 || source == dest) { - ::memcpy(subOb.host, subIb.host, input->elementSize() * bytes); - return; - } - auto tup = CPUTensorConverter::splitDimensions(subIb, source); - int area = std::get<1>(tup), batch = std::get<0>(tup), channel = std::get<2>(tup); - const int bitLength = bytes; + std::pair, int> task; auto core = static_cast(backend())->functions(); auto threadNumber = static_cast(backend())->threadNumber(); if (!mUseThreads) { threadNumber = 1; } - MNN_CONCURRENCY_BEGIN(tId, threadNumber) { + task.first = [input, output, bytes, threadNumber, core](int tId) { + auto& subIb = input->buffer(); + auto& subOb = output->buffer(); + auto source = TensorUtils::getDescribe(input)->dimensionFormat; + auto dest = TensorUtils::getDescribe(output)->dimensionFormat; + if (subIb.dimensions <= 1 || source == dest) { + ::memcpy(subOb.host, subIb.host, input->elementSize() * bytes); + return; + } + auto tup = CPUTensorConverter::splitDimensions(subIb, source); + int area = std::get<1>(tup), batch = std::get<0>(tup), channel = std::get<2>(tup); + const int bitLength = bytes; CPUTensorConverter::convert(subIb.host, subOb.host, source, dest, batch, area, channel, bitLength, core, tId, threadNumber); }; - MNN_CONCURRENCY_END(); + task.second = threadNumber; + mTasks.emplace_back(task); } - - -ErrorCode CPURaster::onExecute(const std::vector &____inputs, const std::vector &outputs) { - void* mOutputPtr = nullptr; - if (nullptr != mTempOutput) { - mOutputPtr = mTempOutput->host(); - } else { - mOutputPtr = outputs[0]->host(); - } - if (mFast) { - executeFaster(____inputs, outputs); - return NO_ERROR; - } - auto core = static_cast(backend())->functions(); +ErrorCode CPURaster::onResize(const std::vector &____inputs, const std::vector 
&outputs) { + MNN_ASSERT(outputs.size() == 1); auto output = outputs[0]; + OpCommonUtils::rasterInputReset(____inputs, outputs[0]); + auto des = TensorUtils::getDescribe(output); + auto outputDes = TensorUtils::getDescribe(output); + mNeedZero = !TensorUtils::regionIsFull(output); + mZeroPoint = 0; + mUseThreads = false; + int threadNum = static_cast(backend())->threadNumber(); + if (outputDes->quantAttr != nullptr && outputDes->applyQuant) { +#ifdef MNN_USE_SSE + mZeroPoint = (int)outputDes->quantAttr->zero + 128; +#else + mZeroPoint = (int)outputDes->quantAttr->zero; +#endif + } size_t bytes = (size_t)(CPUBackend::getBytes(backend(), output)); - auto outputEleSize = static_cast(backend())->getTensorSize(output); - auto threadNum = static_cast(backend())->threadNumber(); - if (mSingleConvert.type > 0) { - auto realInput = ____inputs[0]; - int srcBatch = mSingleConvert.batch, srcChannel = mSingleConvert.channel, srcArea = mSingleConvert.area; - auto sourceFormat = TensorUtils::getDescribe(realInput)->dimensionFormat; - auto destFormat = TensorUtils::getDescribe(output)->dimensionFormat; - auto channelC4 = UP_DIV(srcChannel, core->pack); - auto batchStrideC4 = channelC4 * core->pack * srcArea * bytes; - auto batchStride = srcChannel * srcArea * bytes; - auto inputBatchStride = batchStride; - auto outputBatchStride = batchStride; - if (MNN_DATA_FORMAT_NC4HW4 == sourceFormat) { - if (realInput->dimensions() <= 1) { - ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); - return NO_ERROR; + mTempInput.clear(); + mFastBlit.clear(); + mCacheRegions.clear(); + mTempOutput = nullptr; + mTasks.clear(); + auto midFormat = MNN_DATA_FORMAT_NCHW; + mTempInputCopy.clear(); + mFast = false; + auto core = static_cast(backend())->functions(); + mSingleConvert.type = 0; + // all_srcFormat == dstFormat == NC4HW4 : Fast Exe + if (outputDes->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) { + mFast = true; + for (int i=0; i< des->regions.size(); ++i) { + auto& 
slice = des->regions[i]; + if (TensorUtils::getDescribe(slice.origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { + mFast = false; + break; } - inputBatchStride = batchStrideC4; - if (2 == mSingleConvert.type) { - destFormat = MNN_DATA_FORMAT_NHWC; - } else { - destFormat = MNN_DATA_FORMAT_NCHW; + if (!OpCommonUtils::canBlitFast(slice, output, core->pack, true)) { + mFast = false; + break; } - } else if (MNN_DATA_FORMAT_NC4HW4 == destFormat) { - if (output->dimensions() <= 1) { - ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); - return NO_ERROR; + } + if (mFast) { + mUseThreads = des->regions.size() > 16 ? true : false; + for (int i=0; i< des->regions.size(); ++i) { + auto& slice = des->regions[i]; + if (slice.origin == nullptr) { + continue; + } + Tensor::InsideDescribe::Region newRegion; + OpCommonUtils::turnToPackRegion(slice, newRegion, output, core->pack, true); + mFastBlit.emplace_back(std::make_pair(slice.origin, std::move(newRegion))); } - outputBatchStride = batchStrideC4; - if (2 == mSingleConvert.type) { - sourceFormat = MNN_DATA_FORMAT_NHWC; - } else { - sourceFormat = MNN_DATA_FORMAT_NCHW; + executeFaster(____inputs, outputs); + return NO_ERROR; + } + } + // srcNum == 1 && srcFormat != dstFormat : Single Convert + if (des->regions.size() == 1) { + OpCommonUtils::turnRegion2Convert(des->regions[0], output, mSingleConvert); + if (mSingleConvert.type > 0) { + std::pair, int> task; + mUseThreads = (mSingleConvert.batch * mSingleConvert.channel * mSingleConvert.area > LAUNCH_MULTI_THREADS_WORKLOAD) ? 
true : false; + auto realInput = ____inputs[0]; + int srcBatch = mSingleConvert.batch, srcChannel = mSingleConvert.channel, srcArea = mSingleConvert.area; + auto sourceFormat = TensorUtils::getDescribe(realInput)->dimensionFormat; + auto destFormat = TensorUtils::getDescribe(output)->dimensionFormat; + auto channelC4 = UP_DIV(srcChannel, core->pack); + auto batchStrideC4 = channelC4 * core->pack * srcArea * bytes; + auto batchStride = srcChannel * srcArea * bytes; + auto inputBatchStride = batchStride; + auto outputBatchStride = batchStride; + if (MNN_DATA_FORMAT_NC4HW4 == sourceFormat) { + if (realInput->dimensions() <= 1) { + task.first = [output, realInput, bytes](int tId) { + ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); + }; + task.second = 1; + mTasks.emplace_back(task); + return NO_ERROR; + } + inputBatchStride = batchStrideC4; + if (2 == mSingleConvert.type) { + destFormat = MNN_DATA_FORMAT_NHWC; + } else { + destFormat = MNN_DATA_FORMAT_NCHW; + } + } else if (MNN_DATA_FORMAT_NC4HW4 == destFormat) { + if (output->dimensions() <= 1) { + task.first = [output, realInput, bytes](int tId) { + ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); + }; + task.second = 1; + mTasks.emplace_back(task); + return NO_ERROR; + } + outputBatchStride = batchStrideC4; + if (2 == mSingleConvert.type) { + sourceFormat = MNN_DATA_FORMAT_NHWC; + } else { + sourceFormat = MNN_DATA_FORMAT_NCHW; + } } + if (!mUseThreads) { + threadNum = 1; + } + task.first = [realInput, output, sourceFormat, destFormat, srcBatch, srcArea, srcChannel, bytes, core, threadNum](int tId) { + CPUTensorConverter::convert(realInput->host(), output->host(), sourceFormat, destFormat, srcBatch, srcArea, srcChannel, bytes, core, tId, threadNum); + }; + task.second = threadNum; + mTasks.emplace_back(task); + return NO_ERROR; } - if (!mUseThreads) { - threadNum = 1; + } + // Acquire Buffer for temp output + // TODO: optimize it + if 
(MNN_DATA_FORMAT_NC4HW4 == outputDes->dimensionFormat) { + mTempOutput.reset(new Tensor); + TensorUtils::setupTensorInfo(output, mTempOutput.get(), midFormat); + } + if (nullptr != mTempOutput) { + auto res = backend()->onAcquireBuffer(mTempOutput.get(), Backend::DYNAMIC); + if (!res) { + return OUT_OF_MEMORY; } - MNN_CONCURRENCY_BEGIN(tId, threadNum) { - CPUTensorConverter::convert(realInput->host(), output->host(), sourceFormat, destFormat, srcBatch, srcArea, srcChannel, bytes, core, tId, threadNum); - }; - MNN_CONCURRENCY_END(); - return NO_ERROR; } - if (mNeedZero) { - if (mTempOutput == nullptr) { - ::memset(output->host(), mZeroPoint, outputEleSize * bytes); - } else { - ::memset(mTempOutput->host(), mZeroPoint, mTempOutput->elementSize() * bytes); + // input is NC4HW4 add Convert + std::vector forRelease; + TensorUtils::FuseWrap fuseUtils; + for (int i=0; i< des->regions.size(); ++i) { + auto& slice = des->regions[i]; + auto origin = slice.origin; + if (nullptr == origin /*|| nullptr == origin->host()*/) { + continue; + } + // if tensor is not NC4HW4 or has been merged, don't need deal + if (TensorUtils::getDescribe(origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { + if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + mTempInputCopy.emplace_back(std::make_pair(origin, &slice)); + continue; } + // if NC4HW4's C%4 == 0, change convert to transpose and fuse it + if (origin->batch() == 1 && origin->channel() % core->pack == 0) { + int channel = origin->channel(); + int area = 1; + // conv3d/pool3d will has 5 dims, area = depth * width * height, otherwise area = width * height + for (int d = 2; d < origin->dimensions(); d++) { + area *= origin->length(d); + } + Tensor::InsideDescribe::Region regionTmp; + regionTmp.src.offset = 0; + regionTmp.src.stride[0] = area * core->pack; + regionTmp.src.stride[1] = 1; + regionTmp.src.stride[2] = core->pack; + regionTmp.dst.offset = 0; + regionTmp.dst.stride[0] 
= area * core->pack; + regionTmp.dst.stride[1] = area; + regionTmp.dst.stride[2] = 1; + regionTmp.size[0] = channel / core->pack; + regionTmp.size[1] = core->pack; + regionTmp.size[2] = area; + regionTmp.origin = slice.origin; + bool merge = fuseUtils.match(regionTmp, slice); + if (merge) { + std::shared_ptr newSlice(new Tensor::InsideDescribe::Region); + *newSlice = slice; + fuseUtils.apply(regionTmp, *newSlice); + // cache the merged tensor + if (newSlice->size[0] * newSlice->size[1] * newSlice->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + mTempInputCopy.emplace_back(std::make_pair(origin, newSlice.get())); + mCacheRegions.emplace_back(newSlice); + continue; + } + } + auto cache = static_cast(backend())->getCache(); + auto tempTensor = cache->findCacheTensor(origin, midFormat); + //MNN_ASSERT(CPUBackend::getBytes(backend(), origin) == 4); + if (nullptr == tempTensor) { + std::shared_ptr newTensor(new Tensor); + TensorUtils::copyShape(origin, newTensor.get()); + TensorUtils::getDescribe(newTensor.get())->dimensionFormat = midFormat; + TensorUtils::getDescribe(newTensor.get())->quantAttr = TensorUtils::getDescribe(origin)->quantAttr; + TensorUtils::getDescribe(newTensor.get())->applyQuant = TensorUtils::getDescribe(origin)->applyQuant;; + newTensor->buffer().type = origin->getType(); + TensorUtils::setLinearLayout(newTensor.get()); + mTempInput.insert(std::make_pair(origin, newTensor.get())); + auto res = backend()->onAcquireBuffer(newTensor.get(), Backend::DYNAMIC); + if (!res) { + return OUT_OF_MEMORY; + } + tempTensor = newTensor.get(); + TensorUtils::getDescribe(tempTensor)->useCount = TensorUtils::getDescribe(origin)->useCount; + cache->pushCacheTensor(newTensor, origin, midFormat); + } + if (--TensorUtils::getDescribe(tempTensor)->useCount == 0) { + forRelease.emplace_back(tempTensor); + } + if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + 
mTempInputCopy.emplace_back(std::make_pair(tempTensor, &slice)); + } + for (auto t : forRelease) { + backend()->onReleaseBuffer(t, Backend::DYNAMIC); + } + if (nullptr != mTempOutput) { + backend()->onReleaseBuffer(mTempOutput.get(), Backend::DYNAMIC); } + auto threadNumber = static_cast(backend())->threadNumber(); + mHasReduce = false; + ReduceInfo reduceInfo; + for (auto& iter : mTempInputCopy) { + if (reduceInfo.compute(*iter.second)) { + mHasReduce = true; + break; + } + } + // Encode convert for (auto& iter : mTempInput) { tensorConvert(iter.first, iter.second, (int)bytes); } + do { + if (mTempInputCopy.size() == 1 && threadNumber > 1 && (!mHasReduce)) { + // Split to multi region + auto region = mTempInputCopy[0].second; + if (region->size[0] * region->size[1] * region->size[2] < LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = false; + break; + } + if (region->size[0] * region->size[1] * region->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + auto tensorPtr = mTempInputCopy[0].first; + int pos = -1; + for (int i=0; i<3; ++i) { + if (region->size[i] > 1) { + pos = i; + break; + } + } + if (-1 == pos) { + // Don't need divide + break; + } + mTempInputCopy.clear(); + int divSize = UP_DIV(region->size[pos], threadNumber); + for (int i=0; i cacheRegPtr(new Tensor::InsideDescribe::Region); + auto& cacheReg = *cacheRegPtr; + int sta = i * divSize; + int fin = sta + divSize; + fin = std::min(fin, region->size[pos]); + if (fin <= sta) { + break; + } + for (int v=0; v<3; ++v) { + cacheReg.src.stride[v] = region->src.stride[v]; + cacheReg.dst.stride[v] = region->dst.stride[v]; + } + int curSize = fin - sta; + for (int v=0; vsize[v]; + } + cacheReg.size[pos] = curSize; + cacheReg.src.offset = region->src.offset + sta * region->src.stride[pos]; + cacheReg.dst.offset = region->dst.offset + sta * region->dst.stride[pos]; + for (int v=pos+1; v<3; ++v) { + cacheReg.size[v] = region->size[v]; + } + mTempInputCopy.emplace_back(std::make_pair(tensorPtr, 
cacheRegPtr.get())); + mCacheRegions.emplace_back(cacheRegPtr); + } + } + } while (false); if (mHasReduce) { // Don't support reduce with multi thread now threadNum = 1; @@ -700,8 +685,13 @@ ErrorCode CPURaster::onExecute(const std::vector &____inputs, const st if (outputDescribe->overlap) { threadNum = 1; } - - MNN_CONCURRENCY_BEGIN(tId, threadNum) { + mTasks.emplace_back(std::make_pair([this, threadNum, output, bytes, core](int tId){ + void* mOutputPtr = nullptr; + if (nullptr != mTempOutput) { + mOutputPtr = mTempOutput->host(); + } else { + mOutputPtr = output->host(); + } for (int u=tId; u &____inputs, const st auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes; _blit(slice, (int)bytes, srcPtr, dstPtr, mHasReduce, core->MNNLowpToFp32, core->MNNFp32ToLowp); } - } - MNN_CONCURRENCY_END(); + }, threadNum)); if (nullptr != mTempOutput) { tensorConvert(mTempOutput.get(), output, (int)bytes); } return NO_ERROR; } + + +ErrorCode CPURaster::onExecute(const std::vector &____inputs, const std::vector &outputs) { + void* mOutputPtr = nullptr; + if (nullptr != mTempOutput) { + mOutputPtr = mTempOutput->host(); + } else { + mOutputPtr = outputs[0]->host(); + } + auto core = static_cast(backend())->functions(); + auto output = outputs[0]; + size_t bytes = (size_t)(CPUBackend::getBytes(backend(), output)); + auto outputEleSize = static_cast(backend())->getTensorSize(output); + auto threadNum = static_cast(backend())->threadNumber(); + if (mNeedZero) { + if (mTempOutput == nullptr) { + ::memset(output->host(), mZeroPoint, outputEleSize * bytes); + } else { + ::memset(mTempOutput->host(), mZeroPoint, mTempOutput->elementSize() * bytes); + } + } + for (auto& task : mTasks) { + MNN_CONCURRENCY_ENQUEUE(task); + } + return NO_ERROR; +} class CPULoop : public Execution { public: struct ThreadContainer { @@ -1066,7 +1081,15 @@ class CPULoop : public Execution { auto stride2 = cmd->view()->GetAs(2)->stride()->data(); auto blit1 = _selectUnitProc(bytes, stride1[2], 1); auto 
blit2 = _selectUnitProc(bytes, stride2[2], 1); - if (cmd->size()->data()[2] == 1 || (stride1[2] == 1 && stride2[2] == 1)) { + if (cmd->size()->data()[2] == 1 || (stride1[2] <= 1 && stride2[2] <= 1 && (stride1[2] + stride1[1] != 0))) { + // Support elementwise or one src broadcast + int needBroadcastIndex = -1; + if (0 == stride1[2]) { + needBroadcastIndex = 0; + } + if (0 == stride2[2]) { + needBroadcastIndex = 1; + } for (int z=0; zsize()->data()[0]; ++z) { auto src0Z = src0 + z * stride1[0] * bytes; auto src1Z = src1 + z * stride2[0] * bytes; @@ -1075,7 +1098,7 @@ class CPULoop : public Execution { auto src0Y = src0Z + y * stride1[1] * bytes; auto src1Y = src1Z + y * stride2[1] * bytes; auto dstY = dstZ + y * stride0[1] * bytes; - proc(dstY, src0Y, src1Y, cmd->size()->data()[2], -1); + proc(dstY, src0Y, src1Y, cmd->size()->data()[2], needBroadcastIndex); } } } else { diff --git a/source/backend/cpu/CPURaster.hpp b/source/backend/cpu/CPURaster.hpp index 9df10700bd..bff149df52 100644 --- a/source/backend/cpu/CPURaster.hpp +++ b/source/backend/cpu/CPURaster.hpp @@ -24,7 +24,7 @@ class CPURaster : public Execution { virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; - void executeFaster(const std::vector &inputs, const std::vector &outputs) const; + void executeFaster(const std::vector &inputs, const std::vector &outputs); void tensorConvert(Tensor* input, Tensor* output, int bytes); private: std::map mTempInput; @@ -38,6 +38,7 @@ class CPURaster : public Execution { int32_t mZeroPoint = 0; bool mHasReduce = false; bool mUseThreads = false; + std::vector, int>> mTasks; }; } #endif diff --git a/source/backend/cpu/ThreadPool.cpp b/source/backend/cpu/ThreadPool.cpp index 15a2d8241c..d7765c4fbc 100644 --- a/source/backend/cpu/ThreadPool.cpp +++ b/source/backend/cpu/ThreadPool.cpp @@ -60,7 +60,7 @@ ThreadPool::ThreadPool(int numberThread) { 
while (mActiveCount > 0) { for (int i = 0; i < MNN_THREAD_POOL_MAX_TASKS; ++i) { if (*mTasks[i].second[threadIndex]) { - mTasks[i].first.first(threadIndex); + mTasks[i].first->first(threadIndex); { *mTasks[i].second[threadIndex] = false; } } } @@ -118,16 +118,18 @@ void ThreadPool::deactive() { mActiveCount--; } -void ThreadPool::enqueue(TASK&& task, int index) { +void ThreadPool::enqueue(TASK* taskp, int index) { + auto& task = *taskp; if (1 >= task.second || 0 > index) { for (int i = 0; i < task.second; ++i) { task.first(i); } return; } - enqueueInternal(std::move(task), index); + enqueueInternal(taskp, index); } -void ThreadPool::enqueueInternal(TASK&& task, int index) { +void ThreadPool::enqueueInternal(TASK* taskp, int index) { + auto& task = *taskp; if (mActiveCount == 0) { for (int i = 0; i < task.second; ++i) { task.first(i); @@ -135,24 +137,25 @@ void ThreadPool::enqueueInternal(TASK&& task, int index) { return; } int workSize = task.second; + TASK* tmpTask = nullptr; if (workSize > mNumberThread) { - mTasks[index].first = std::make_pair( - [workSize, &task, this](int tId) { - for (int v = tId; v < workSize; v += mNumberThread) { - task.first(v); - } - }, - mNumberThread); + tmpTask = new TASK; + *tmpTask = std::make_pair([workSize, &task, this](int tId) { + for (int v = tId; v < workSize; v += mNumberThread) { + task.first(v); + } + }, mNumberThread); + mTasks[index].first = tmpTask; workSize = mNumberThread; } else { - mTasks[index].first = std::move(task); + mTasks[index].first = taskp; } { for (int i = 1; i < workSize; ++i) { *mTasks[index].second[i] = true; } } - mTasks[index].first.first(0); + mTasks[index].first->first(0); bool complete = true; do { complete = true; @@ -165,6 +168,9 @@ void ThreadPool::enqueueInternal(TASK&& task, int index) { std::this_thread::yield(); // FUNC_PRINT(notComplete); } while (!complete); + if (nullptr != tmpTask) { + delete tmpTask; + } } } // namespace MNN #endif diff --git a/source/backend/cpu/ThreadPool.hpp 
b/source/backend/cpu/ThreadPool.hpp index 4bf23de1b0..8891da61b1 100644 --- a/source/backend/cpu/ThreadPool.hpp +++ b/source/backend/cpu/ThreadPool.hpp @@ -25,7 +25,7 @@ class MNN_PUBLIC ThreadPool { int numberThread() const { return mNumberThread; } - void enqueue(TASK&& task, int index); + void enqueue(TASK* task, int index); void active(); void deactive(); @@ -37,7 +37,7 @@ class MNN_PUBLIC ThreadPool { static void destroy(); private: - void enqueueInternal(TASK&& task, int index); + void enqueueInternal(TASK* task, int index); ThreadPool(int numberThread = 0); ~ThreadPool(); @@ -46,7 +46,7 @@ class MNN_PUBLIC ThreadPool { std::vector mTaskAvailable; std::atomic mStop = {false}; - std::vector>> mTasks; + std::vector>> mTasks; std::condition_variable mCondition; std::mutex mQueueMutex; diff --git a/source/backend/cpu/compute/CommonOptFunction.cpp b/source/backend/cpu/compute/CommonOptFunction.cpp index d7d0d7fb34..c9bfcc2189 100644 --- a/source/backend/cpu/compute/CommonOptFunction.cpp +++ b/source/backend/cpu/compute/CommonOptFunction.cpp @@ -3882,12 +3882,13 @@ void MNNVectorTop1Int32(int32_t* input, int32_t* maxValue, int32_t* maxIndex, si #endif -void MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId) { +void MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tIdL) { auto l = param->l; auto h = param->h; auto numberThread = param->numberThread; auto lC4 = l / 4; auto lR = lC4 * 4; + auto tId = (int)tIdL; if (param->BTranspose) { for (int y=tId; y= 8) { + if (0 == tId) { + auto bs = B + hEnd; + Vec4 sumValue0; + Vec4 sumValue1; + if (biasPtr != nullptr) { + sumValue0 = Vec4::load(biasPtr + hEnd + 0); + sumValue1 = Vec4::load(biasPtr + hEnd + 4); + } else { + sumValue0 = Vec4(0.0f); + sumValue1 = Vec4(0.0f); + } + auto srcY = A + hEnd * l; + for (int x=0; x= 4) { + if (0 == tId) { + auto bs = B + hEnd; + Vec4 sumValue0; 
+ if (biasPtr != nullptr) { + sumValue0 = Vec4::load(biasPtr + hEnd + 0); + } else { + sumValue0 = Vec4(0.0f); + } + auto srcY = A + hEnd * l; + for (int x=0; x= 8) { + sumValue = Vec::fma(sumValue, Vec4::load(srcY + lR), Vec4::load(B + lR)); + sum1 = Vec::fma(sum1, Vec4::load(srcY + lR + 4), Vec4::load(B + lR + 4)); + lR += 8; + } + if (l - lR >= 4) { + sumValue = Vec::fma(sumValue, Vec4::load(srcY + lR), Vec4::load(B + lR)); + lR += 4; + } + sum2 = sum2 + sum3; + sumValue = sumValue + sum1; + sumValue = sumValue + sum2; float sumSingle = sumValue[0] + sumValue[1] + sumValue[2] + sumValue[3]; for (int x=lR; xenqueue(task) + #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) \ { \ std::pair, int> task; \ @@ -28,8 +33,7 @@ } \ ; \ auto cpuBn = (CPUBackend*)backend(); \ - auto thrPl = cpuBn->threadPool(); \ - thrPl->enqueue(std::move(task), cpuBn->taskIndex()); \ + cpuBn->enqueue(task); \ } #else @@ -38,6 +42,9 @@ #include #include +#define MNN_CONCURRENCY_ENQUEUE(task) \ +dispatch_apply(task.second, dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH, 0), ^(size_t __iter__) {task.first(__iter__);}); + #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) \ dispatch_apply(__num__, dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH, 0), ^(size_t __iter__) { #define MNN_CONCURRENCY_END() \ @@ -58,6 +65,8 @@ dispatch_apply(__num__, dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH, // Android #else #include +#define MNN_CONCURRENCY_ENQUEUE(task) \ +_Pragma("omp parallel for") for (int __iter__ = 0; __iter__ < task.second; __iter__++) {task.first(__iter__);} #define MNN_STRINGIFY(a) #a #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) \ diff --git a/source/core/OpCommonUtils.cpp b/source/core/OpCommonUtils.cpp index c80afaef87..a69263ffaa 100644 --- a/source/core/OpCommonUtils.cpp +++ b/source/core/OpCommonUtils.cpp @@ -386,98 +386,7 @@ void OpCommonUtils::broastCastComputeDim(int* dims, int* stride, int* iStride0, } } } -std::vector> 
OpCommonUtils::computeReduceDims(const std::vector& inputs, - const Op* op) { - // Compute axises - std::vector axises; - if (inputs.size() >= 2) { - auto size = inputs[1]->elementSize(); - auto dims = inputs[1]->host(); - for (int i = 0; i < size; ++i) { - axises.emplace_back(dims[i]); - } - } else { - auto reduct = op->main_as_ReductionParam(); - if (nullptr != reduct->dim()) { - for (int i = 0; i < reduct->dim()->size(); ++i) { - axises.emplace_back(reduct->dim()->data()[i]); - } - } - } - auto totalSize = TensorUtils::getRawSize(inputs[0]); - if (axises.empty()) { - return {std::make_tuple(1, totalSize, 1)}; - } - for (int i = 0; i < axises.size(); ++i) { - if (axises[i] < 0) { - axises[i] = inputs[0]->dimensions() + axises[i]; - if (axises[i] < 0) { - return {std::make_tuple(1, totalSize, 1)}; - } - } - } - // Cache for input's dims - std::vector lengths(inputs[0]->dimensions()); - for (int i = 0; i < lengths.size(); ++i) { - lengths[i] = inputs[0]->length(i); - } - std::vector> groupAxises; - { - // Merge adj axis - std::sort(axises.begin(), axises.end()); - int lastAxis = axises[0]; - int length = 1; - int start = axises[0]; - for (int i = 1; i < axises.size(); ++i) { - // MNN_PRINT("%d - %d\n", axises[i], lastAxis); - if (axises[i] - lastAxis == 1) { - length++; - } else { - groupAxises.emplace_back(std::make_pair(start, length)); - length = 1; - start = axises[i]; - } - lastAxis = axises[i]; - } - groupAxises.emplace_back(std::make_pair(start, length)); - } - - // Compute inside-outside-axis - std::vector> result; - for (int i = 0; i < groupAxises.size(); ++i) { - int outsideSize = 1; - int insideSize = 1; - int axisSize = 1; - auto start = groupAxises[i].first; - auto length = groupAxises[i].second; - if (start >= (int)lengths.size()) { - break; - } - for (int j = 0; j < start; ++j) { - outsideSize *= lengths[j]; - } - for (int j = start; j < start + length; ++j) { - if (j >= (int)lengths.size()) { - break; - } - axisSize *= lengths[j]; - lengths[j] = 1; 
- } - for (int j = start + length; j < lengths.size(); ++j) { - insideSize *= lengths[j]; - } - if (1 == axisSize) { - continue; - } - result.emplace_back(std::make_tuple(outsideSize, axisSize, insideSize)); - } - // FUNC_PRINT(result.size()); - if (result.empty()) { - result.emplace_back(std::make_tuple(1, 1, totalSize)); - } - return result; -} void OpCommonUtils::unravelIndexHelper(int32_t* coordinate, const int32_t* mod, int size, int indice) { int value = indice; diff --git a/source/core/OpCommonUtils.hpp b/source/core/OpCommonUtils.hpp index 0740cc16b2..8ec0628336 100644 --- a/source/core/OpCommonUtils.hpp +++ b/source/core/OpCommonUtils.hpp @@ -56,7 +56,6 @@ class MNN_PUBLIC OpCommonUtils { static bool supportDynamicInputMemory(MNNForwardType type); static void broastCastComputeDim(int* dims, int* stride, int* iStride0, int* iStride1, const Tensor* input0, const Tensor* input1, const Tensor* output); - static std::vector> computeReduceDims(const std::vector& inputs, const Op* op); static void unravelIndexHelper(int32_t* coordinate, const int32_t* mod, int size, int indice); static int computeStride(int32_t* strides, const int* shape, int length); diff --git a/source/core/TensorUtils.cpp b/source/core/TensorUtils.cpp index ae5b87143c..d233fc9d89 100644 --- a/source/core/TensorUtils.cpp +++ b/source/core/TensorUtils.cpp @@ -32,6 +32,18 @@ bool TensorUtils::regionIsFull(Tensor* input) { return regionSize == size; } +void TensorUtils::makeFullRef(Tensor* output, Tensor* input) { + auto des = TensorUtils::getDescribe(input); + auto outputDes = TensorUtils::getDescribe(output); + outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; + if (des->memoryType == Tensor::InsideDescribe::MEMORY_VIRTUAL) { + outputDes->regions = des->regions; + } else { + outputDes->regions = {makeFullSlice(input)}; + } +} + + Tensor::InsideDescribe::Region TensorUtils::makeFullSlice(Tensor* input) { Tensor::InsideDescribe::Region totalSlice; totalSlice.src.offset = 0; diff 
--git a/source/core/TensorUtils.hpp b/source/core/TensorUtils.hpp index 1342a669bd..a577fea05f 100644 --- a/source/core/TensorUtils.hpp +++ b/source/core/TensorUtils.hpp @@ -184,6 +184,7 @@ class MNN_PUBLIC TensorUtils { static void setupTensorInfo(const Tensor* tensor, Tensor* wrapTensor, MNN_DATA_FORMAT mMidFormat); static Tensor::InsideDescribe::Region makeFullSlice(Tensor* input); + static void makeFullRef(Tensor* output, Tensor* input); static bool regionIsFull(Tensor* input); static bool isCopyRegion(const Tensor::InsideDescribe::Region& region); static bool isTransposeRegion(const Tensor::InsideDescribe::Region& region); diff --git a/source/geometry/GeometryComputerUtils.cpp b/source/geometry/GeometryComputerUtils.cpp index 01a4e02ea2..85f64de55d 100644 --- a/source/geometry/GeometryComputerUtils.cpp +++ b/source/geometry/GeometryComputerUtils.cpp @@ -477,9 +477,9 @@ std::shared_ptr GeometryComputerUtils::makeBinary(int type, Tensor* inp return cmdP; } -std::shared_ptr GeometryComputerUtils::makeReduce(ReductionType type, Tensor* input0, Tensor* output) { +std::shared_ptr GeometryComputerUtils::makeReduce(ReductionType type, Tensor* input0, Tensor* output, int axis) { flatbuffers::FlatBufferBuilder builder(DEFAULT_ALLOCATE_SIZE); - auto vec = builder.CreateVector(std::vector{1}); + auto vec = builder.CreateVector(std::vector{axis}); ReductionParamBuilder builder_(builder); builder_.add_operation(type); builder_.add_keepDims(true); diff --git a/source/geometry/GeometryComputerUtils.hpp b/source/geometry/GeometryComputerUtils.hpp index c0dffdcdb1..97c4d5811f 100644 --- a/source/geometry/GeometryComputerUtils.hpp +++ b/source/geometry/GeometryComputerUtils.hpp @@ -18,7 +18,7 @@ class GeometryComputerUtils { static void addConvert(const CommandBuffer& srcBuffer, CommandBuffer& dstBuffer, GeometryComputer::Context& ctx); static std::shared_ptr makeCommand(flatbuffers::FlatBufferBuilder& builder, const std::vector& inputs, const std::vector& outputs); static 
std::shared_ptr makeBinary(int type, Tensor* input0, Tensor* input1, Tensor* output); - static std::shared_ptr makeReduce(ReductionType type, Tensor* input0, Tensor* output); + static std::shared_ptr makeReduce(ReductionType type, Tensor* input0, Tensor* output, int axis = 1); static std::shared_ptr makeUnary(UnaryOpOperation type, Tensor* input0, Tensor* output); static std::shared_ptr makeLayerNorm(Tensor* input0, Tensor* output, std::vector axis, float epsilon, std::vector gamma, std::vector beta, std::vector external, int group = 1, bool useRMS = false); static std::shared_ptr makeMatMul(Tensor* input0, Tensor* input1, Tensor* output, Tensor* Bias = nullptr, diff --git a/source/geometry/GeometryReduce.cpp b/source/geometry/GeometryReduce.cpp index c2a3bb4114..855f4bcf69 100644 --- a/source/geometry/GeometryReduce.cpp +++ b/source/geometry/GeometryReduce.cpp @@ -10,6 +10,83 @@ #include "geometry/GeometryComputerUtils.hpp" #include "core/OpCommonUtils.hpp" namespace MNN { +static std::vector> _computeReduceDims(const std::vector& inputs, + std::vector& axises) { + + auto totalSize = TensorUtils::getRawSize(inputs[0]); + if (axises.empty()) { + return {std::make_tuple(1, totalSize, 1)}; + } + for (int i = 0; i < axises.size(); ++i) { + if (axises[i] < 0) { + if (axises[i] < 0) { + return {std::make_tuple(1, totalSize, 1)}; + } + } + } + // Cache for input's dims + std::vector lengths(inputs[0]->dimensions()); + for (int i = 0; i < lengths.size(); ++i) { + lengths[i] = inputs[0]->length(i); + } + std::vector> groupAxises; + { + // Merge adj axis + std::sort(axises.begin(), axises.end()); + int lastAxis = axises[0]; + int length = 1; + int start = axises[0]; + for (int i = 1; i < axises.size(); ++i) { + // MNN_PRINT("%d - %d\n", axises[i], lastAxis); + if (axises[i] - lastAxis == 1) { + length++; + } else { + groupAxises.emplace_back(std::make_pair(start, length)); + length = 1; + start = axises[i]; + } + lastAxis = axises[i]; + } + 
groupAxises.emplace_back(std::make_pair(start, length)); + } + + // Compute inside-outside-axis + std::vector> result; + + for (int i = 0; i < groupAxises.size(); ++i) { + int outsideSize = 1; + int insideSize = 1; + int axisSize = 1; + auto start = groupAxises[i].first; + auto length = groupAxises[i].second; + if (start >= (int)lengths.size()) { + break; + } + for (int j = 0; j < start; ++j) { + outsideSize *= lengths[j]; + } + for (int j = start; j < start + length; ++j) { + if (j >= (int)lengths.size()) { + break; + } + axisSize *= lengths[j]; + lengths[j] = 1; + } + for (int j = start + length; j < lengths.size(); ++j) { + insideSize *= lengths[j]; + } + if (1 == axisSize) { + continue; + } + result.emplace_back(std::make_tuple(outsideSize, axisSize, insideSize)); + } + // FUNC_PRINT(result.size()); + if (result.empty()) { + result.emplace_back(std::make_tuple(1, 1, totalSize)); + } + return result; +} + class GeometryReduce : public GeometryComputer { public: virtual bool onCompute(const Op* op, const std::vector& inputs, const std::vector& outputs, @@ -18,6 +95,31 @@ class GeometryReduce : public GeometryComputer { MNN_ASSERT(inputs.size() >= 1); auto reduct = op->main_as_ReductionParam(); auto reductOp = reduct->operation(); + std::vector axises; + if (inputs.size() >= 2) { + auto size = inputs[1]->elementSize(); + auto dims = inputs[1]->host(); + for (int i = 0; i < size; ++i) { + axises.emplace_back(dims[i]); + } + } else { + auto reduct = op->main_as_ReductionParam(); + if (nullptr != reduct->dim()) { + for (int i = 0; i < reduct->dim()->size(); ++i) { + axises.emplace_back(reduct->dim()->data()[i]); + } + } + } + for (int i = 0; i < axises.size(); ++i) { + if (axises[i] < 0) { + axises[i] = inputs[0]->dimensions() + axises[i]; + } + } + if (1 == axises.size() && TensorUtils::getDescribe(inputs[0])->dimensionFormat != MNN_DATA_FORMAT_NC4HW4 && TensorUtils::getDescribe(outputs[0])->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { + auto cmd = 
GeometryComputerUtils::makeReduce(reductOp, inputs[0], outputs[0], axises[0]); + res.command.emplace_back(std::move(cmd)); + return true; + } // prod([]) = 1 if (inputs[0]->elementSize() == 0) { if(!context.allocTensor(outputs[0])) { @@ -39,7 +141,7 @@ class GeometryReduce : public GeometryComputer { } return true; } - auto reduceDims = OpCommonUtils::computeReduceDims(inputs, op); + auto reduceDims = _computeReduceDims(inputs, axises); Tensor* currentInput = inputs[0]; MNN_ASSERT(reduceDims.size() > 0); auto dimType = currentInput->getDimensionType(); diff --git a/source/geometry/GeometryReshape.cpp b/source/geometry/GeometryReshape.cpp index 88d98a24c9..1df3384e37 100644 --- a/source/geometry/GeometryReshape.cpp +++ b/source/geometry/GeometryReshape.cpp @@ -42,8 +42,7 @@ class GeometryReshape : public GeometryComputer { return true; } } - outputDes->regions = {TensorUtils::makeFullSlice(input)}; - outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; + TensorUtils::makeFullRef(output, input); return true; } }; @@ -75,10 +74,7 @@ class SingleGeometryComputer : public GeometryComputer { Context& context, CommandBuffer& res) const override { auto input = inputs[0]; auto output = outputs[0]; - auto inputDes = TensorUtils::getDescribe(input); - auto outputDes = TensorUtils::getDescribe(output); - outputDes->regions = {TensorUtils::makeFullSlice(input)}; - outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; + TensorUtils::makeFullRef(output, input); return true; } }; @@ -94,8 +90,7 @@ class CopyGeometryComputer : public GeometryComputer { outputDes->tensorArrayAttr = inputDes->tensorArrayAttr; return true; } - outputDes->regions = {TensorUtils::makeFullSlice(input)}; - outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; + TensorUtils::makeFullRef(output, input); } return true; } diff --git a/source/math/Vec.hpp b/source/math/Vec.hpp index 6839ab83b0..cc9354a7f1 100644 --- a/source/math/Vec.hpp +++ b/source/math/Vec.hpp @@ -372,8 +372,7 
@@ struct Vec { using VecType = Vec; using VecTypeInt32 = Vec; float32x4_t value; - Vec() { - } + Vec() = default; Vec(const float v) { value = vdupq_n_f32(v); } diff --git a/test/core/ThreadPoolTest.cpp b/test/core/ThreadPoolTest.cpp index 6886f86e62..e010939e5f 100644 --- a/test/core/ThreadPoolTest.cpp +++ b/test/core/ThreadPoolTest.cpp @@ -26,11 +26,11 @@ class ThreadPoolTest : public MNNTestCase { auto workIndex = threadPool->acquireWorkIndex(); FUNC_PRINT(workIndex); threadPool->active(); - auto func = [](int index) { + ThreadPool::TASK task = std::make_pair([](int index) { FUNC_PRINT(index); std::this_thread::yield(); - }; - threadPool->enqueue(std::make_pair(std::move(func), 10), workIndex); + }, 10); + threadPool->enqueue(&task, workIndex); threadPool->deactive(); threadPool->releaseWorkIndex(workIndex); }); diff --git a/tools/cpp/ExprDebug.hpp b/tools/cpp/ExprDebug.hpp index 167e97c562..49e3db6156 100644 --- a/tools/cpp/ExprDebug.hpp +++ b/tools/cpp/ExprDebug.hpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #define DUMP_NUM_DATA(type) \ @@ -135,29 +136,69 @@ static void _initDebug() { struct TimeTraceInfo { - std::map>>> mTypes; + std::map>> mTypes; void begin(const MNN::OperatorInfo* info) { auto tIter = mTypes.find(info->type()); if (tIter == mTypes.end()) { - std::map>> _t; + std::map> _t; mTypes.insert(std::make_pair(info->type(), _t)); tIter = mTypes.find(info->type()); } mInserIter = tIter->second.find(info->name()); if (mInserIter == tIter->second.end()) { - std::vector> _t; - tIter->second.insert(std::make_pair(info->name(), _t)); + tIter->second.insert(std::make_pair(info->name(), std::make_tuple(0.0f, 0.0f, 0))); mInserIter = tIter->second.find(info->name()); } mTimer.reset(); } void end(const MNN::OperatorInfo* info) { auto timeInMs = (float)mTimer.durationInUs() / 1000.0f; - mInserIter->second.emplace_back(std::make_pair(timeInMs, info->flops())); + std::get<0>(mInserIter->second) += timeInMs; + 
std::get<1>(mInserIter->second) += info->flops(); + std::get<2>(mInserIter->second) ++; + } + void dump(bool dumpPerOp = false) { + if (dumpPerOp) { + auto cmp = [](const std::tuple& first, const std::tuple& second) { + return std::get<1>(first) > std::get<1>(second); + }; + std::priority_queue, std::vector>, decltype(cmp)> que(cmp); + for (auto& iter : mTypes) { + for (auto& t : iter.second) { + auto mergeType = t.first + " ["+iter.first +"]"; + auto unit = std::make_tuple(mergeType, std::get<0>(t.second), std::get<1>(t.second), std::get<2>(t.second)); + que.push(unit); + } + } + while (!que.empty()) { + auto& t = que.top(); + MNN_PRINT("%s : %.7f ms, FLOP: %.7f, COUNT: %d, Speed: %.7f GFlops\n", std::get<0>(t).c_str(), std::get<1>(t), std::get<2>(t), std::get<3>(t), std::get<2>(t) / std::get<1>(t)); + que.pop(); + } + return; + } + float opSummer = 0.0f; + float opFlopsSummber = 0.0f; + for (auto& iter : mTypes) { + float summer = 0.0f; + float summerflops = 0.0f; + int count = 0; + for (auto& t : iter.second) { + summer += std::get<0>(t.second); + summerflops += std::get<1>(t.second); + count += std::get<2>(t.second); + } + MNN_PRINT("%s : %.7f ms, FLOP: %.7f, COUNT: %d, Speed: %.7f GFlops\n", iter.first.c_str(), summer, summerflops, count, + summerflops / summer); + opSummer += summer; + opFlopsSummber += summerflops; + } + MNN_PRINT("OP Summer: %.7f ms, Flops: %.7f, Speed: %.7f GFlops\n", opSummer, opFlopsSummber, + opFlopsSummber / opSummer); } private: - std::map>>::iterator mInserIter; + std::map>::iterator mInserIter; MNN::Timer mTimer; }; static TimeTraceInfo* gTimeTraceInfo = nullptr; diff --git a/tools/cpp/ModuleBasic.cpp b/tools/cpp/ModuleBasic.cpp index 90fa6b80d3..5798bc6d26 100644 --- a/tools/cpp/ModuleBasic.cpp +++ b/tools/cpp/ModuleBasic.cpp @@ -499,10 +499,13 @@ int main(int argc, char *argv[]) { if (runTime > 0) { int t = runTime; - std::vector times(t, 0.0f); if (runMask & 4) { _initTimeTrace(); } + float minTime = std::numeric_limits::max(); + 
float maxTime = 0.0f; + float sum = 0.0f; + for (int i = 0; i < t; ++i) { Timer _l; auto out = net->onForward(inputs); @@ -510,41 +513,28 @@ int main(int argc, char *argv[]) { for (auto o : out) { ((MNN::Tensor*)o->getTensor())->wait(MNN::Tensor::MAP_TENSOR_READ, true); } - times[i] = _l.durationInUs() / 1000.0f; + auto time = _l.durationInUs() / 1000.0f; if (freq > 0.0f) { - float remainMs = (1000.0f / freq) - times[i]; + float remainMs = (1000.0f / freq) - time; if (remainMs > 0.0f) { std::this_thread::sleep_for(std::chrono::milliseconds((int)remainMs)); } } - } - if (nullptr != gTimeTraceInfo) { - float opSummer = 0.0f; - float opFlopsSummber = 0.0f; - for (auto& iter : gTimeTraceInfo->mTypes) { - float summer = 0.0f; - float summerflops = 0.0f; - for (auto& t : iter.second) { - for (auto& t0 : t.second) { - summer += t0.first; - summerflops += t0.second; - } - } - summer = summer / (float)t; - summerflops = summerflops / (float)t; - MNN_PRINT("%s : %.7f, FLOP: %.7f, Speed: %.7f GFlops\n", iter.first.c_str(), summer, summerflops, summerflops / summer); - opSummer += summer; - opFlopsSummber+= summerflops; + if (maxTime < time) { + maxTime = time; + } + if (minTime > time) { + minTime = time; } - MNN_PRINT("OP Summer: %.7f, Flops: %.7f, Speed: %.7f GFlops\n", opSummer, opFlopsSummber, opFlopsSummber/opSummer); - } - auto minTime = std::min_element(times.begin(), times.end()); - auto maxTime = std::max_element(times.begin(), times.end()); - float sum = 0.0f; - for (auto time : times) { sum += time; } - MNN_PRINT("Avg= %f ms, min= %f ms, max= %f ms\n", sum / (float)t, *minTime, *maxTime); + if (nullptr != gTimeTraceInfo) { + MNN_PRINT("Per Op Trace: \n"); + gTimeTraceInfo->dump(true); + MNN_PRINT("Per Type Trace: \n"); + gTimeTraceInfo->dump(false); + } + MNN_PRINT("Avg= %f ms, min= %f ms, max= %f ms\n", sum / (float)t, minTime, maxTime); } rtmgr->updateCache(); return 0; diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp index 
53af11239a..63c590e0fd 100644 --- a/transformers/llm/engine/src/llm.cpp +++ b/transformers/llm/engine/src/llm.cpp @@ -915,26 +915,7 @@ Llm::Llm(std::shared_ptr config) : mConfig(config) { Llm::~Llm() { #if DEBUG_MODE == 1 if (nullptr != gTimeTraceInfo) { - float opSummer = 0.0f; - float opFlopsSummber = 0.0f; - for (auto& iter : gTimeTraceInfo->mTypes) { - float summer = 0.0f; - float summerflops = 0.0f; - for (auto& t : iter.second) { - for (auto& t0 : t.second) { - summer += t0.first; - summerflops += t0.second; - } - } - summer = summer; - summerflops = summerflops; - MNN_PRINT("%s : %.7f, FLOP: %.7f, Speed: %.7f GFlops\n", iter.first.c_str(), summer, summerflops, - summerflops / summer); - opSummer += summer; - opFlopsSummber += summerflops; - } - MNN_PRINT("OP Summer: %.7f, Flops: %.7f, Speed: %.7f GFlops\n", opSummer, opFlopsSummber, - opFlopsSummber / opSummer); + gTimeTraceInfo->dump(); } #endif mGenerateParam.reset(); From 238690d4536a29f77d2d2576ca75edc3837a6d2f Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Tue, 23 Dec 2025 12:36:55 +0800 Subject: [PATCH 058/314] Merge branch feature/metal_backgroup_issue into master Title: [Metal Feature] check UI Status for metal command commit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 本次代码评审主要增加了对执行状态的检查和错误处理,并引入了新的日志打印方式以提高调试和监控能力。 Link: https://code.alibaba-inc.com/AliNN/AliNNPrivate/codereview/24965986 GitOrigin-RevId: b7ad051c324c1b7d4aa231fc062f2f5d8e7f7a0f --- express/Expr.cpp | 16 +++++++++ express/Utils.cpp | 17 +++++++++ express/module/Module.cpp | 26 ++++++++++++++ source/backend/metal/MetalBackend.mm | 33 +++++++++++++++++ source/core/Backend.hpp | 1 + source/core/Tensor.cpp | 8 +++++ transformers/llm/engine/demo/llm_demo.cpp | 36 +++++++++---------- transformers/llm/engine/include/llm/llm.hpp | 9 +++++ transformers/llm/engine/src/llm.cpp | 18 +++++++++- .../engine/src/speculative_decoding/eagle.cpp | 16 +++++++++ .../src/speculative_decoding/generate.cpp | 
13 ++++++- .../src/speculative_decoding/lookahead.cpp | 12 +++++++ .../engine/src/speculative_decoding/mtp.cpp | 12 +++++++ 13 files changed, 197 insertions(+), 20 deletions(-) diff --git a/express/Expr.cpp b/express/Expr.cpp index a735adbe3e..d9244e8de4 100644 --- a/express/Expr.cpp +++ b/express/Expr.cpp @@ -813,12 +813,28 @@ void* Variable::readInternal(bool forShape) { // The Varp will not be created as input, so we just need copy once return inside->mHostTensor->host(); } + inside->mHostTensor = new Tensor; TensorUtils::copyShape(originTensor, inside->mHostTensor, true); inside->mHostTensor->buffer().type = originTensor->getType(); inside->mHostTensor->buffer().host = (uint8_t*)MNNMemoryAllocAlign(inside->mHostTensor->size(), MNN_MEMORY_ALIGN_DEFAULT); TensorUtils::getDescribe(inside->mHostTensor)->memoryType = Tensor::InsideDescribe::MEMORY_HOST; originTensor->copyToHostTensor(inside->mHostTensor); + bool hasNoExecution = false; + if (nullptr != originTensor) { + auto backend = TensorUtils::getDescribeOrigin(originTensor)->getBackend(); + if (nullptr != backend) { + // Try to sync to check execution status + int syncResult = backend->onSync(Tensor::MAP_TENSOR_READ, false, originTensor); + if (NO_EXECUTION == syncResult) { + hasNoExecution = true; + } + } + } + if (hasNoExecution) { + MNN_PRINT("\nWarning, Backend has stop execute, return nullptr for current varp\n"); + return nullptr; + } return inside->mHostTensor->host(); } return originTensor->buffer().host; diff --git a/express/Utils.cpp b/express/Utils.cpp index f71f2c997b..6aac549ece 100644 --- a/express/Utils.cpp +++ b/express/Utils.cpp @@ -181,6 +181,23 @@ void* Executor::ComputeCache::mapOutput(int offset, Tensor* dest) { if (0 == tensor->usize()) { return nullptr; } + + bool hasNoExecution = false; + if (nullptr != tensor) { + auto backend = TensorUtils::getDescribeOrigin(tensor)->getBackend(); + if (nullptr != backend) { + // Try to sync to check execution status + int syncResult = 
backend->onSync(Tensor::MAP_TENSOR_READ, false, tensor); + if (NO_EXECUTION == syncResult) { + hasNoExecution = true; + } + } + } + if (hasNoExecution) { + MNN_PRINT("\nWarning, Backend has stop execute, return nullptr for current varp\n"); + return nullptr; + } + Utils::allocMemoryForHostTensor(dest); if(nullptr != dest->host()) { tensor->copyToHostTensor(dest); diff --git a/express/module/Module.cpp b/express/module/Module.cpp index f8b4728153..3c54ca89c3 100644 --- a/express/module/Module.cpp +++ b/express/module/Module.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include "core/OpCommonUtils.hpp" #include "PipelineModule.hpp" #include "core/FileLoader.hpp" @@ -17,6 +18,7 @@ #include "Utils.hpp" #include "RuntimeAttr.hpp" #include "ModuleInside.hpp" +#include "core/TensorUtils.hpp" #include #ifdef MNN_INTERNAL_ENABLED #include "internal/auth/ModelAuth.hpp" @@ -221,6 +223,30 @@ class NetModule : public Module { Executor::RuntimeExecuteWrap wrap(mInfo->runTimeManager->getInside()->mRuntime); outputs = mModule->onForward(inputs); } + + // Check execution status after forward + if (!outputs.empty()) { + bool hasNoExecution = false; + for (auto& v : outputs) { + auto t = Utils::getTensor(v); + if (nullptr != t) { + auto backend = TensorUtils::getDescribeOrigin(t)->getBackend(); + if (nullptr != backend) { + // Try to sync to check execution status + int syncResult = backend->onSync(Tensor::MAP_TENSOR_READ, false, t); + if (NO_EXECUTION == syncResult) { + hasNoExecution = true; + break; + } + } + } + } + if (hasNoExecution) { + MNN_PRINT("Warning, Backend has stop execute, return empty output vector varps\n"); + outputs.clear(); + } + } + #ifdef MNN_INTERNAL_ENABLED do { if (outputs.empty()) { diff --git a/source/backend/metal/MetalBackend.mm b/source/backend/metal/MetalBackend.mm index 79f52ff2dc..885808fc44 100644 --- a/source/backend/metal/MetalBackend.mm +++ b/source/backend/metal/MetalBackend.mm @@ -15,6 +15,7 @@ #define MTLGPUFamilyMetal3_MNN 5001 
#define MTLGPUFamilyMetal4_MNN 5002 +#define CHECK_IOS_UI_STATUS #if MNN_METAL_ENABLED #include #import "backend/metal/MNNMetalContext.h" @@ -22,6 +23,9 @@ #import "core/TensorUtils.hpp" #include "MetalCache_generated.h" #include "core/MNNFileUtils.h" +#if defined(TARGET_OS_IPHONE) && TARGET_OS_IPHONE +#import +#endif int MNNMetalGetTensorContent(MNNMetalTensorContent* content, void* tensor) { if (nullptr == content || nullptr == tensor) { return 0; @@ -776,6 +780,9 @@ static void _execute(id encoder, const MetalBackend::C MNN_ASSERT(false); // should not be handled here } int MetalBackend::onSync(Tensor::MapType mtype, bool toCpu, const Tensor* dstTensor) { + if (mRuntime->pExecutionStatus == NO_EXECUTION) { + return NO_EXECUTION; + } flushEncoder(); auto ctx = (__bridge MNNMetalContext *)context(); commit_net(); @@ -824,6 +831,19 @@ static void _execute(id encoder, const MetalBackend::C void MetalBackend::commit() const { +#ifdef CHECK_IOS_UI_STATUS +#if defined(TARGET_OS_IPHONE) && TARGET_OS_IPHONE + if ([UIApplication sharedApplication].applicationState == UIApplicationStateBackground || [UIApplication sharedApplication].applicationState == UIApplicationStateInactive) { + mRuntime->pExecutionStatus = NO_EXECUTION; + _commandBuffer = nil; + if (!mSupportDeferEncode) { + _commandBuffer_net = nil; + } + return; + } +#endif +#endif + mRuntime->pExecutionStatus = NO_ERROR; if (nil != _commandBuffer && _commandBuffer.status < MTLCommandBufferStatusCommitted) { [_commandBuffer commit]; mRuntime->_waiting = _commandBuffer; @@ -836,6 +856,19 @@ static void _execute(id encoder, const MetalBackend::C } void MetalBackend::commit_net() const { +#ifdef CHECK_IOS_UI_STATUS +#if defined(TARGET_OS_IPHONE) && TARGET_OS_IPHONE + if ([UIApplication sharedApplication].applicationState == UIApplicationStateBackground || [UIApplication sharedApplication].applicationState == UIApplicationStateInactive) { + mRuntime->pExecutionStatus = NO_EXECUTION; + _commandBuffer_net = nil; + if 
(!mSupportDeferEncode) { + _commandBuffer = nil; + } + return; + } +#endif +#endif + mRuntime->pExecutionStatus = NO_ERROR; if (nil != _commandBuffer_net && _commandBuffer_net.status < MTLCommandBufferStatusCommitted) { [_commandBuffer_net commit]; mRuntime->_waiting = _commandBuffer_net; diff --git a/source/core/Backend.hpp b/source/core/Backend.hpp index 6850b6b4f6..e463f251f7 100644 --- a/source/core/Backend.hpp +++ b/source/core/Backend.hpp @@ -395,6 +395,7 @@ class Runtime : public NonCopyable { } mutable int pCurrentStatus = 0; // NO_ERROR + mutable int pExecutionStatus = 0; // NO_ERROR // TODO: Move to Backend void* pMeta = nullptr; diff --git a/source/core/Tensor.cpp b/source/core/Tensor.cpp index 18bf5ec7a6..664fa6b790 100644 --- a/source/core/Tensor.cpp +++ b/source/core/Tensor.cpp @@ -430,6 +430,14 @@ void* Tensor::map(MapType mtype, DimensionType dtype) { return mBuffer.host; } + if (mtype == Tensor::MAP_TENSOR_READ) { + int syncResult = bn->onSync(mtype, false, this); + if (NO_EXECUTION == syncResult) { + MNN_PRINT("Warning, Backend has stop execute, return nullptr for tensor map addr\n"); + return nullptr; + } + } + auto mapPtr = bn->onMapTensor(mtype, dtype, this); if(mapPtr != nullptr) { // Get mapPtr in specific backend diff --git a/transformers/llm/engine/demo/llm_demo.cpp b/transformers/llm/engine/demo/llm_demo.cpp index 305ef2169b..ec0f39c146 100644 --- a/transformers/llm/engine/demo/llm_demo.cpp +++ b/transformers/llm/engine/demo/llm_demo.cpp @@ -135,21 +135,21 @@ static int benchmark(Llm* llm, const std::vector& prompts, int max_ if (context->audio_input_s > 0.0f) { audio_speed = context->audio_input_s / audio_s; } - printf("\n#################################\n"); - printf("prompt tokens num = %d\n", prompt_len); - printf("decode tokens num = %d\n", decode_len); - printf(" vision time = %.2f s\n", vision_s); - printf(" pixels_mp = %.2f MP\n", context->pixels_mp); - printf(" audio process time = %.2f s\n", audio_s); - printf(" audio input time 
= %.2f s\n", context->audio_input_s); - printf("prefill time = %.2f s\n", prefill_s); - printf(" decode time = %.2f s\n", decode_s); - printf(" sample time = %.2f s\n", sample_s); - printf("prefill speed = %.2f tok/s\n", prompt_len / prefill_s); - printf(" decode speed = %.2f tok/s\n", decode_len / decode_s); - printf(" vision speed = %.3f MP/s\n", vision_speed); - printf(" audio RTF = %.3f \n", audio_s / context->audio_input_s); - printf("##################################\n"); + MNN_PRINT("\n#################################\n"); + MNN_PRINT("prompt tokens num = %d\n", prompt_len); + MNN_PRINT("decode tokens num = %d\n", decode_len); + MNN_PRINT(" vision time = %.2f s\n", vision_s); + MNN_PRINT(" pixels_mp = %.2f MP\n", context->pixels_mp); + MNN_PRINT(" audio process time = %.2f s\n", audio_s); + MNN_PRINT(" audio input time = %.2f s\n", context->audio_input_s); + MNN_PRINT("prefill time = %.2f s\n", prefill_s); + MNN_PRINT(" decode time = %.2f s\n", decode_s); + MNN_PRINT(" sample time = %.2f s\n", sample_s); + MNN_PRINT("prefill speed = %.2f tok/s\n", prompt_len / prefill_s); + MNN_PRINT(" decode speed = %.2f tok/s\n", decode_len / decode_s); + MNN_PRINT(" vision speed = %.3f MP/s\n", vision_speed); + MNN_PRINT(" audio RTF = %.3f \n", audio_s / context->audio_input_s); + MNN_PRINT("##################################\n"); return 0; } @@ -165,12 +165,12 @@ static int ceval(Llm* llm, const std::vector& lines, std::string fi prompt += "\nC. " + elements[4]; prompt += "\nD. 
" + elements[5]; prompt += "\n\n"; - printf("%s", prompt.c_str()); - printf("## 进度: %d / %lu\n", i, lines.size() - 1); + MNN_PRINT("%s", prompt.c_str()); + MNN_PRINT("## 进度: %d / %lu\n", i, lines.size() - 1); std::ostringstream lineOs; llm->response(prompt.c_str(), &lineOs); auto line = lineOs.str(); - printf("%s", line.c_str()); + MNN_PRINT("%s", line.c_str()); answers.push_back(line); } { diff --git a/transformers/llm/engine/include/llm/llm.hpp b/transformers/llm/engine/include/llm/llm.hpp index 6ae61a5e35..20eff94be9 100644 --- a/transformers/llm/engine/include/llm/llm.hpp +++ b/transformers/llm/engine/include/llm/llm.hpp @@ -59,6 +59,13 @@ enum TuneType { // op encoder number for commit OP_ENCODER_NUMBER = 0, }; +enum class LlmStatus { + RUNNING = 0, + NORMAL_FINISHED = 1, + MAX_TOKENS_FINISHED = 2, + USER_CANCEL = 3, + INTERNAL_ERROR = 4, +}; enum class MatchStrictLevel : int; enum class NgramSelectRule : int; @@ -84,6 +91,8 @@ struct LlmContext { std::vector history_tokens; std::vector output_tokens; std::string generate_str; + // llm status + LlmStatus status; }; struct GenerationParams; class MNN_PUBLIC Llm { diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp index 63c590e0fd..c0cabd4414 100644 --- a/transformers/llm/engine/src/llm.cpp +++ b/transformers/llm/engine/src/llm.cpp @@ -467,6 +467,7 @@ std::vector Llm::forwardRaw(Express::VARP hiddenState, Express::V std::vector outputs = selectModule->onForward(inputs); if (outputs.empty()) { + mContext->status = LlmStatus::INTERNAL_ERROR; return outputs; } if (!mAsync) { @@ -592,6 +593,9 @@ std::vector Llm::forwardVec(MNN::Express::VARP input_embeds) { auto attention_mask = gen_attention_mask(blockSize); auto position_ids = gen_position_ids(blockSize); logits = forwardRaw(embed, attention_mask, position_ids); + if(logits.empty()) { + return logits; + } updateContext(blockSize, 0); } bool hasPad = false; @@ -623,6 +627,9 @@ std::vector Llm::forwardVec(MNN::Express::VARP 
input_embeds) { auto attention_mask = gen_attention_mask(forwardSize); auto position_ids = gen_position_ids(forwardSize); logits = forwardRaw(input_embeds, attention_mask, position_ids); + if(logits.empty()) { + return logits; + } } updateContext(-blockSize * blockNumber, 0); if (hasPad) { @@ -676,6 +683,7 @@ void Llm::generate_init(std::ostream* os, const char* end_with) { mContext->decode_us = 0; mContext->current_token = -1; mContext->sample_us = 0; + mContext->status = LlmStatus::RUNNING; if (!mConfig->reuse_kv()) { mContext->all_seq_len = 0; mContext->history_tokens.clear(); @@ -824,6 +832,7 @@ std::vector Llm::generate(MNN::Express::VARP input_embeds, int max_tokens) Timer _t; forwardVec(input_embeds); if(mGenerateParam->outputs.size() < 1) { + mContext->status = LlmStatus::INTERNAL_ERROR; return {}; } updateContext(seqLen, 0); @@ -1132,7 +1141,14 @@ VARP Llm::gen_position_ids(int seq_len) { } bool Llm::is_stop(int token_id) { - return mTokenizer->is_stop(token_id); + if (mContext->status == LlmStatus::USER_CANCEL || mContext->status == LlmStatus::INTERNAL_ERROR) { + return true; + } + bool stop = mTokenizer->is_stop(token_id); + if (stop) { + mContext->status = LlmStatus::NORMAL_FINISHED; + } + return stop; } } // namespace Transformer } // namespace MNN diff --git a/transformers/llm/engine/src/speculative_decoding/eagle.cpp b/transformers/llm/engine/src/speculative_decoding/eagle.cpp index b4c892fd97..15548c64d3 100644 --- a/transformers/llm/engine/src/speculative_decoding/eagle.cpp +++ b/transformers/llm/engine/src/speculative_decoding/eagle.cpp @@ -328,9 +328,22 @@ void EagleGeneration::generate(GenerationParams& param) { std::vector accpetLens; auto newTokens = 0, steps = 0; while (true) { + if(mContext->status == LlmStatus::USER_CANCEL) { + break; + } steps++; MNN::Timer _dt; auto decodingInfo = treeDecoding(draftInfo); + for (auto o : decodingInfo) { + if(nullptr == o->readMap()) { + mContext->status = LlmStatus::INTERNAL_ERROR; + break; + } + } + 
if(decodingInfo.empty()) { + break; + } + treeDecodingTime += _dt.durationInUs(); auto acceptInfo = evaluatePosterior(draftInfo, decodingInfo[0]); newTokens += acceptInfo.acceptTokens.size(); @@ -352,6 +365,9 @@ void EagleGeneration::generate(GenerationParams& param) { eagleGenerateTime += _gt.durationInUs(); } mContext->decode_us += _t.durationInUs(); + if(newTokens >= param.max_new_tokens) { + mContext->status = LlmStatus::MAX_TOKENS_FINISHED; + } #if EAGLE_DEBUG printf("\n### Tree Decoding Time: %f s, Eagle Generate Time: %f s\n", (float)treeDecodingTime / 1000000.0, (float)eagleGenerateTime / 1000000.0); printf("\n### Tree Decoding Avg Time: %f ms, steps: %d\n", (float)treeDecodingTime / 1000.0 / steps, steps); diff --git a/transformers/llm/engine/src/speculative_decoding/generate.cpp b/transformers/llm/engine/src/speculative_decoding/generate.cpp index 31d3a3b9f7..4ed01b1f5c 100644 --- a/transformers/llm/engine/src/speculative_decoding/generate.cpp +++ b/transformers/llm/engine/src/speculative_decoding/generate.cpp @@ -43,6 +43,9 @@ void ArGeneration::generate(GenerationParams& param) { int max_token = param.max_new_tokens; int len = 0; while (len < max_token) { + if(mContext->status == LlmStatus::USER_CANCEL) { + break; + } AUTOTIME; // Update gen seq mContext->current_token = mLlm->sample(param.outputs[0], param.validLogitStart, param.validLogitSize); @@ -63,9 +66,14 @@ void ArGeneration::generate(GenerationParams& param) { *mContext->os << decodeStr; *mContext->os << std::flush; } - // Compute Next Logits auto outputs = mLlm->forwardVec({mContext->current_token}); + for (auto o : outputs) { + if(nullptr == o->readMap()) { + mContext->status = LlmStatus::INTERNAL_ERROR; + break; + } + } if(outputs.empty()) { break; } @@ -74,6 +82,9 @@ void ArGeneration::generate(GenerationParams& param) { mContext->decode_us += _t.durationInUs(); len++; } + if(len >= max_token) { + mContext->status = LlmStatus::MAX_TOKENS_FINISHED; + } } int Generation::draftVerify(VARP 
logits, const std::vector &drafts, bool& stop) { diff --git a/transformers/llm/engine/src/speculative_decoding/lookahead.cpp b/transformers/llm/engine/src/speculative_decoding/lookahead.cpp index cf4c2a5c79..d8ce38037e 100644 --- a/transformers/llm/engine/src/speculative_decoding/lookahead.cpp +++ b/transformers/llm/engine/src/speculative_decoding/lookahead.cpp @@ -89,6 +89,9 @@ void LookaheadGeneration::generate(GenerationParams& param) { int verify_len = mLlm->mDraftLength + 1; while (len < max_token) { + if(mContext->status == LlmStatus::USER_CANCEL) { + break; + } MNN::Timer _t; std::vector drafts; drafts.push_back(mContext->current_token); @@ -126,6 +129,12 @@ void LookaheadGeneration::generate(GenerationParams& param) { AUTOTIME; // do draft token parallel verify auto outputs = mLlm->forwardVec(drafts); + for (auto o : outputs) { + if(nullptr == o->readMap()) { + mContext->status = LlmStatus::INTERNAL_ERROR; + break; + } + } if(outputs.empty()) { break; } @@ -192,6 +201,9 @@ void LookaheadGeneration::generate(GenerationParams& param) { } } } + if(len >= max_token) { + mContext->status = LlmStatus::MAX_TOKENS_FINISHED; + } #ifdef DUMP_PROFILE_INFO // adopt speculative decoding rate float spl_rate = 100.0 * spl_count / (spl_count + arg_count); diff --git a/transformers/llm/engine/src/speculative_decoding/mtp.cpp b/transformers/llm/engine/src/speculative_decoding/mtp.cpp index aefc4a5aa7..f5c6e0261a 100644 --- a/transformers/llm/engine/src/speculative_decoding/mtp.cpp +++ b/transformers/llm/engine/src/speculative_decoding/mtp.cpp @@ -151,6 +151,9 @@ void MtpGeneration::generate(GenerationParams& param) { int spl_count = 0; while (len < max_token) { + if(mContext->status == LlmStatus::USER_CANCEL) { + break; + } MNN::Timer _t; std::vector drafts; drafts.push_back(mContext->current_token); @@ -171,6 +174,12 @@ void MtpGeneration::generate(GenerationParams& param) { AUTOTIME; // do draft token parallel verify auto outputs = mLlm->forwardVec(drafts); + for (auto o : 
outputs) { + if(nullptr == o->readMap()) { + mContext->status = LlmStatus::INTERNAL_ERROR; + break; + } + } if (outputs.size() < 2) { break; } @@ -238,6 +247,9 @@ void MtpGeneration::generate(GenerationParams& param) { } } } + if(len >= max_token) { + mContext->status = LlmStatus::MAX_TOKENS_FINISHED; + } #ifdef DUMP_PROFILE_INFO // draft accept rate if adopt speculative decoding float spl_accept_rate = 100.0 * spl_accept / spl_decode; From 73bfaa46fe4f9c0e24e0238073f5ab1a94edca6c Mon Sep 17 00:00:00 2001 From: ihb2032 <1355790728@qq.com> Date: Tue, 23 Dec 2025 06:59:30 +0000 Subject: [PATCH 059/314] opt(RVV): Optimize CV color conversion functions with intrinsics Optimize MNNC3 and MNNNV21 related color conversion functions using RVV intrinsics, including C3 to YUV/XYZ/HSV/BGR555/BGR565 and NV21 to RGBA/RGB/BGRA/BGR. Signed-off-by: ihb2032 <1355790728@qq.com> Co-authored-by: lyd1992 --- .../backend/cpu/riscv/rvv/MNNC3ToBGR555.cpp | 32 +++++++ .../backend/cpu/riscv/rvv/MNNC3ToBGR565.cpp | 32 +++++++ source/backend/cpu/riscv/rvv/MNNC3ToHSV.cpp | 89 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNC3ToXYZ.cpp | 71 +++++++++++++++ source/backend/cpu/riscv/rvv/MNNC3ToYUV.cpp | 87 ++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNNV21ToBGR.cpp | 85 ++++++++++++++++++ .../backend/cpu/riscv/rvv/MNNNV21ToBGRA.cpp | 89 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNNV21ToRGB.cpp | 85 ++++++++++++++++++ .../backend/cpu/riscv/rvv/MNNNV21ToRGBA.cpp | 89 +++++++++++++++++++ 9 files changed, 659 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNC3ToBGR555.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNC3ToBGR565.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNC3ToHSV.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNC3ToXYZ.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNC3ToYUV.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNNV21ToBGR.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNNV21ToBGRA.cpp create mode 
100644 source/backend/cpu/riscv/rvv/MNNNV21ToRGB.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNNV21ToRGBA.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNC3ToBGR555.cpp b/source/backend/cpu/riscv/rvv/MNNC3ToBGR555.cpp new file mode 100644 index 0000000000..bcaee2fb98 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNC3ToBGR555.cpp @@ -0,0 +1,32 @@ +#include + +void MNNC3ToBGR555(const unsigned char* source, unsigned char* dest, + size_t count, bool bgr) { + unsigned short* dest16 = reinterpret_cast(dest); + size_t i = 0; + int rOffset = bgr ? 2 : 0; + int bOffset = bgr ? 0 : 2; + + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + bOffset, 3, vl); + vuint8m4_t shifted = __riscv_vsrl_vx_u8m4(channel, 3, vl); + vuint16m8_t result = __riscv_vzext_vf2_u16m8(shifted, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); + vuint8m4_t masked = __riscv_vand_vx_u8m4(channel, 0xF8, vl); + vuint16m8_t wide = __riscv_vzext_vf2_u16m8(masked, vl); + wide = __riscv_vsll_vx_u16m8(wide, 2, vl); + result = __riscv_vor_vv_u16m8(result, wide, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + rOffset, 3, vl); + masked = __riscv_vand_vx_u8m4(channel, 0xF8, vl); + wide = __riscv_vzext_vf2_u16m8(masked, vl); + wide = __riscv_vsll_vx_u16m8(wide, 7, vl); + result = __riscv_vor_vv_u16m8(result, wide, vl); + + __riscv_vse16_v_u16m8(dest16 + i, result, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNC3ToBGR565.cpp b/source/backend/cpu/riscv/rvv/MNNC3ToBGR565.cpp new file mode 100644 index 0000000000..dcf99158c6 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNC3ToBGR565.cpp @@ -0,0 +1,32 @@ +#include + +void MNNC3ToBGR565(const unsigned char* source, unsigned char* dest, + size_t count, bool bgr) { + unsigned short* dest16 = reinterpret_cast(dest); + size_t i = 0; + int rOffset = bgr ? 2 : 0; + int bOffset = bgr ? 
0 : 2; + + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + bOffset, 3, vl); + vuint8m4_t shifted = __riscv_vsrl_vx_u8m4(channel, 3, vl); + vuint16m8_t result = __riscv_vzext_vf2_u16m8(shifted, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); + vuint8m4_t masked = __riscv_vand_vx_u8m4(channel, 0xFC, vl); + vuint16m8_t wide = __riscv_vzext_vf2_u16m8(masked, vl); + wide = __riscv_vsll_vx_u16m8(wide, 3, vl); + result = __riscv_vor_vv_u16m8(result, wide, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + rOffset, 3, vl); + masked = __riscv_vand_vx_u8m4(channel, 0xF8, vl); + wide = __riscv_vzext_vf2_u16m8(masked, vl); + wide = __riscv_vsll_vx_u16m8(wide, 8, vl); + result = __riscv_vor_vv_u16m8(result, wide, vl); + + __riscv_vse16_v_u16m8(dest16 + i, result, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNC3ToHSV.cpp b/source/backend/cpu/riscv/rvv/MNNC3ToHSV.cpp new file mode 100644 index 0000000000..0dc857b914 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNC3ToHSV.cpp @@ -0,0 +1,89 @@ +#include + +void MNNC3ToHSV(const unsigned char* source, unsigned char* dest, + size_t count, bool bgr, bool full) { + const float hrange = full ? 256.0f : 180.0f; + const float hscale = hrange / 6.0f; + const int hrangeI = full ? 
256 : 180; + size_t i = 0; + + while (i < count) { + size_t vl = __riscv_vsetvl_e8m2(count - i); + + vuint8m2_t vrU8 = __riscv_vlse8_v_u8m2(source + 3 * i + 0, 3, vl); + vuint8m2_t vgU8 = __riscv_vlse8_v_u8m2(source + 3 * i + 1, 3, vl); + vuint8m2_t vbU8 = __riscv_vlse8_v_u8m2(source + 3 * i + 2, 3, vl); + if (bgr) { + vuint8m2_t tmp = vrU8; + vrU8 = vbU8; + vbU8 = tmp; + } + + vuint8m2_t vmaxU8 = __riscv_vmaxu_vv_u8m2( + __riscv_vmaxu_vv_u8m2(vrU8, vgU8, vl), vbU8, vl); + vuint8m2_t vminU8 = __riscv_vminu_vv_u8m2( + __riscv_vminu_vv_u8m2(vrU8, vgU8, vl), vbU8, vl); + vuint8m2_t vdiffU8 = __riscv_vsub_vv_u8m2(vmaxU8, vminU8, vl); + + vint16m4_t vr = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(vrU8, vl)); + vint16m4_t vg = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(vgU8, vl)); + vint16m4_t vb = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(vbU8, vl)); + vint16m4_t vdiff = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(vdiffU8, vl)); + + vbool4_t maskR = __riscv_vmseq_vv_u8m2_b4(vmaxU8, vrU8, vl); + vbool4_t maskG = __riscv_vmseq_vv_u8m2_b4(vmaxU8, vgU8, vl); + vbool4_t maskDiffZero = __riscv_vmseq_vx_u8m2_b4(vdiffU8, 0, vl); + vbool4_t maskVZero = __riscv_vmseq_vx_u8m2_b4(vmaxU8, 0, vl); + + vint16m4_t sum16 = __riscv_vadd_vv_i16m4( + __riscv_vsub_vv_i16m4(vr, vg, vl), + __riscv_vsll_vx_i16m4(vdiff, 2, vl), vl); + vint16m4_t temp16 = __riscv_vadd_vv_i16m4( + __riscv_vsub_vv_i16m4(vb, vr, vl), + __riscv_vsll_vx_i16m4(vdiff, 1, vl), vl); + sum16 = __riscv_vmerge_vvm_i16m4(sum16, temp16, maskG, vl); + sum16 = __riscv_vmerge_vvm_i16m4(sum16, __riscv_vsub_vv_i16m4(vg, vb, vl), maskR, vl); + + vfloat32m8_t sumF = __riscv_vfcvt_f_x_v_f32m8(__riscv_vsext_vf2_i32m8(sum16, vl), vl); + vfloat32m8_t diffF = __riscv_vfcvt_f_xu_v_f32m8(__riscv_vzext_vf4_u32m8(vdiffU8, vl), vl); + + sumF = __riscv_vfmul_vf_f32m8(sumF, hscale, vl); + sumF = __riscv_vfdiv_vv_f32m8(sumF, __riscv_vfmax_vf_f32m8(diffF, 1.0f, vl), vl); + sumF = 
__riscv_vfmerge_vfm_f32m8(sumF, 0.0f, maskDiffZero, vl); + + sumF = __riscv_vfadd_vf_f32m8(sumF, 0.5f, vl); + vint32m8_t sum = __riscv_vfcvt_rtz_x_f_v_i32m8(sumF, vl); + + vbool4_t isNegFrac = __riscv_vmflt_vf_f32m8_b4(sumF, 0.0f, vl); + vfloat32m8_t sumBack = __riscv_vfcvt_f_x_v_f32m8(sum, vl); + vbool4_t notInt = __riscv_vmfne_vv_f32m8_b4(sumF, sumBack, vl); + vbool4_t floorAdjust = __riscv_vmand_mm_b4(isNegFrac , notInt, vl); + sum = __riscv_vsub_vx_i32m8_mu(floorAdjust, sum, sum, 1, vl); + + vbool4_t hNeg = __riscv_vmslt_vx_i32m8_b4(sum, 0, vl); + sum = __riscv_vadd_vx_i32m8_mu(hNeg, sum, sum, hrangeI, vl); + + sum = __riscv_vmin_vx_i32m8(__riscv_vmax_vx_i32m8(sum, 0, vl), hrangeI - 1, vl); + sum16 = __riscv_vnsra_wx_i16m4(sum, 0, vl); + vuint8m2_t result = __riscv_vnsrl_wx_u8m2(__riscv_vreinterpret_v_i16m4_u16m4(sum16), 0, vl); + __riscv_vsse8_v_u8m2(dest + 3 * i + 0, 3, result, vl); + + sumF = __riscv_vfcvt_f_xu_v_f32m8(__riscv_vzext_vf4_u32m8(vmaxU8, vl), vl); + sumF = __riscv_vfdiv_vv_f32m8( + __riscv_vfmul_vf_f32m8(diffF, 255.0f, vl), + __riscv_vfmax_vf_f32m8(sumF, 1.0f, vl), vl); + sumF = __riscv_vfmerge_vfm_f32m8(sumF, 0.0f, maskVZero, vl); + + sumF = __riscv_vfadd_vf_f32m8(sumF, 0.5f, vl); + sum = __riscv_vfcvt_rtz_x_f_v_i32m8(sumF, vl); + + sum = __riscv_vmin_vx_i32m8(__riscv_vmax_vx_i32m8(sum, 0, vl), 255, vl); + sum16 = __riscv_vnsra_wx_i16m4(sum, 0, vl); + result = __riscv_vnsrl_wx_u8m2(__riscv_vreinterpret_v_i16m4_u16m4(sum16), 0, vl); + __riscv_vsse8_v_u8m2(dest + 3 * i + 1, 3, result, vl); + + __riscv_vsse8_v_u8m2(dest + 3 * i + 2, 3, vmaxU8, vl); + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNC3ToXYZ.cpp b/source/backend/cpu/riscv/rvv/MNNC3ToXYZ.cpp new file mode 100644 index 0000000000..88c8fc7483 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNC3ToXYZ.cpp @@ -0,0 +1,71 @@ +#include +#include + +void MNNC3ToXYZ(const unsigned char* source, unsigned char* dest, + size_t count, bool bgr) { + static const int coeffs[] = { + 
1689, 1465, 739, + 871, 2929, 296, + 79, 488, 3892 + }; + + int r0 = 0, r1 = 3, r2 = 6, b0 = 2, b1 = 5, b2 = 8; + if (bgr) { + std::swap(r0, b0); + std::swap(r1, b1); + std::swap(r2, b2); + } + + int16_t C0 = coeffs[r0], C1 = coeffs[1], C2 = coeffs[b0], + C3 = coeffs[r1], C4 = coeffs[4], C5 = coeffs[b1], + C6 = coeffs[r2], C7 = coeffs[7], C8 = coeffs[b2]; + + size_t i = 0; + const int32_t rounding = 1 << 11; + + while (i < count) { + size_t vl = __riscv_vsetvl_e8m2(count - i); + vuint8m2_t vrU8 = __riscv_vlse8_v_u8m2(source + 3 * i + 0, 3, vl); + vuint8m2_t vgU8 = __riscv_vlse8_v_u8m2(source + 3 * i + 1, 3, vl); + vuint8m2_t vbU8 = __riscv_vlse8_v_u8m2(source + 3 * i + 2, 3, vl); + + vint16m4_t vr = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(vrU8, vl)); + vint16m4_t vg = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(vgU8, vl)); + vint16m4_t vb = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(vbU8, vl)); + + vint32m8_t sum = __riscv_vwmul_vx_i32m8(vr, C0, vl); + sum = __riscv_vwmacc_vx_i32m8(sum, C1, vg, vl); + sum = __riscv_vwmacc_vx_i32m8(sum, C2, vb, vl); + sum = __riscv_vadd_vx_i32m8(sum, rounding, vl); + sum = __riscv_vsra_vx_i32m8(sum, 12, vl); + sum = __riscv_vmax_vx_i32m8(sum, 0, vl); + sum = __riscv_vmin_vx_i32m8(sum, 255, vl); + vint16m4_t sum16 = __riscv_vnsra_wx_i16m4(sum, 0, vl); + vuint8m2_t result = __riscv_vnsrl_wx_u8m2(__riscv_vreinterpret_v_i16m4_u16m4(sum16), 0, vl); + __riscv_vsse8_v_u8m2(dest + 3 * i + 0, 3, result, vl); + + sum = __riscv_vwmul_vx_i32m8(vr, C3, vl); + sum = __riscv_vwmacc_vx_i32m8(sum, C4, vg, vl); + sum = __riscv_vwmacc_vx_i32m8(sum, C5, vb, vl); + sum = __riscv_vadd_vx_i32m8(sum, rounding, vl); + sum = __riscv_vsra_vx_i32m8(sum, 12, vl); + sum = __riscv_vmax_vx_i32m8(sum, 0, vl); + sum = __riscv_vmin_vx_i32m8(sum, 255, vl); + sum16 = __riscv_vnsra_wx_i16m4(sum, 0, vl); + result = __riscv_vnsrl_wx_u8m2(__riscv_vreinterpret_v_i16m4_u16m4(sum16), 0, vl); + __riscv_vsse8_v_u8m2(dest + 3 * i 
+ 1, 3, result, vl); + + sum = __riscv_vwmul_vx_i32m8(vr, C6, vl); + sum = __riscv_vwmacc_vx_i32m8(sum, C7, vg, vl); + sum = __riscv_vwmacc_vx_i32m8(sum, C8, vb, vl); + sum = __riscv_vadd_vx_i32m8(sum, rounding, vl); + sum = __riscv_vsra_vx_i32m8(sum, 12, vl); + sum = __riscv_vmax_vx_i32m8(sum, 0, vl); + sum = __riscv_vmin_vx_i32m8(sum, 255, vl); + sum16 = __riscv_vnsra_wx_i16m4(sum, 0, vl); + result = __riscv_vnsrl_wx_u8m2(__riscv_vreinterpret_v_i16m4_u16m4(sum16), 0, vl); + __riscv_vsse8_v_u8m2(dest + 3 * i + 2, 3, result, vl); + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNC3ToYUV.cpp b/source/backend/cpu/riscv/rvv/MNNC3ToYUV.cpp new file mode 100644 index 0000000000..7f4963f325 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNC3ToYUV.cpp @@ -0,0 +1,87 @@ +#include +#include + +void MNNC3ToYUV(const unsigned char* source, unsigned char* dest, + size_t count, bool bgr, bool yuv) { + static const int coeffs[] = { + // Y + 4899, 9617, 1868, + // Cr + 8192, -6860, -1332, + // Cb + -2765, -5427, 8192, + // U + -2412, -4734, 7146, + // V + 10076, -8438, -1638 + }; + + int r0 = 0, r1 = 3, r2 = 6, + g0 = 1, g1 = 4, g2 = 7, + b0 = 2, b1 = 5, b2 = 8; + if (yuv) { + r1 = 9, r2 = 12; + g1 = 10, g2 = 13; + b1 = 11, b2 = 14; + } + if (bgr) { + std::swap(r0, b0); + std::swap(r1, b1); + std::swap(r2, b2); + } + + int16_t C0 = coeffs[r0], C1 = coeffs[g0], C2 = coeffs[b0], + C3 = coeffs[r1], C4 = coeffs[g1], C5 = coeffs[b1], + C6 = coeffs[r2], C7 = coeffs[g2], C8 = coeffs[b2]; + + size_t i = 0; + const int32_t rounding = 1 << 13; + + while (i < count) { + size_t vl = __riscv_vsetvl_e8m2(count - i); + vuint8m2_t vrU8 = __riscv_vlse8_v_u8m2(source + 3 * i + 0, 3, vl); + vuint8m2_t vgU8 = __riscv_vlse8_v_u8m2(source + 3 * i + 1, 3, vl); + vuint8m2_t vbU8 = __riscv_vlse8_v_u8m2(source + 3 * i + 2, 3, vl); + + vint16m4_t vr = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(vrU8, vl)); + vint16m4_t vg = 
__riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(vgU8, vl)); + vint16m4_t vb = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(vbU8, vl)); + + vint32m8_t sum = __riscv_vwmul_vx_i32m8(vr, C0, vl); + sum = __riscv_vwmacc_vx_i32m8(sum, C1, vg, vl); + sum = __riscv_vwmacc_vx_i32m8(sum, C2, vb, vl); + sum = __riscv_vadd_vx_i32m8(sum, rounding, vl); + sum = __riscv_vsra_vx_i32m8(sum, 14, vl); + sum = __riscv_vmax_vx_i32m8(sum, 0, vl); + sum = __riscv_vmin_vx_i32m8(sum, 255, vl); + vint16m4_t sum16 = __riscv_vnsra_wx_i16m4(sum, 0, vl); + vuint8m2_t result = __riscv_vnsrl_wx_u8m2(__riscv_vreinterpret_v_i16m4_u16m4(sum16), 0, vl); + __riscv_vsse8_v_u8m2(dest + 3 * i + 0, 3, result, vl); + + sum = __riscv_vwmul_vx_i32m8(vr, C3, vl); + sum = __riscv_vwmacc_vx_i32m8(sum, C4, vg, vl); + sum = __riscv_vwmacc_vx_i32m8(sum, C5, vb, vl); + sum = __riscv_vadd_vx_i32m8(sum, rounding, vl); + sum = __riscv_vsra_vx_i32m8(sum, 14, vl); + sum = __riscv_vadd_vx_i32m8(sum, 128, vl); + sum = __riscv_vmax_vx_i32m8(sum, 0, vl); + sum = __riscv_vmin_vx_i32m8(sum, 255, vl); + sum16 = __riscv_vnsra_wx_i16m4(sum, 0, vl); + result = __riscv_vnsrl_wx_u8m2(__riscv_vreinterpret_v_i16m4_u16m4(sum16), 0, vl); + __riscv_vsse8_v_u8m2(dest + 3 * i + 1, 3, result, vl); + + sum = __riscv_vwmul_vx_i32m8(vr, C6, vl); + sum = __riscv_vwmacc_vx_i32m8(sum, C7, vg, vl); + sum = __riscv_vwmacc_vx_i32m8(sum, C8, vb, vl); + sum = __riscv_vadd_vx_i32m8(sum, rounding, vl); + sum = __riscv_vsra_vx_i32m8(sum, 14, vl); + sum = __riscv_vadd_vx_i32m8(sum, 128, vl); + sum = __riscv_vmax_vx_i32m8(sum, 0, vl); + sum = __riscv_vmin_vx_i32m8(sum, 255, vl); + sum16 = __riscv_vnsra_wx_i16m4(sum, 0, vl); + result = __riscv_vnsrl_wx_u8m2(__riscv_vreinterpret_v_i16m4_u16m4(sum16), 0, vl); + __riscv_vsse8_v_u8m2(dest + 3 * i + 2, 3, result, vl); + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNNV21ToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNNV21ToBGR.cpp new file mode 100644 index 
0000000000..8348388018 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNNV21ToBGR.cpp @@ -0,0 +1,85 @@ +#include +#include + +void MNNNV21ToBGR(const unsigned char* source, unsigned char* dest, size_t count) { + const unsigned char* y = source; + const unsigned char* uv = source + count; + size_t i = 0; + + while (i < count) { + size_t vl = __riscv_vsetvl_e8m2(count - i); + vl = vl & ~1UL; + if (vl == 0) break; + size_t vlHalf = vl / 2; + vuint8m2_t dupIdx = __riscv_vsrl_vx_u8m2(__riscv_vid_v_u8m2(vl), 1, vl); + + vuint8m2_t channel8 = __riscv_vle8_v_u8m2(y + i, vl); + vint16m4_t y16 = __riscv_vreinterpret_v_u16m4_i16m4( + __riscv_vzext_vf2_u16m4(channel8, vl)); + + vuint8m1_t half8 = __riscv_vlse8_v_u8m1(uv + (i / 2) * 2, 2, vlHalf); + channel8 = __riscv_vrgather_vv_u8m2( + __riscv_vlmul_ext_v_u8m1_u8m2(half8), dupIdx, vl); + vint16m4_t v16 = __riscv_vsub_vx_i16m4( + __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(channel8, vl)), + 128, vl); + + half8 = __riscv_vlse8_v_u8m1(uv + (i / 2) * 2 + 1, 2, vlHalf); + channel8 = __riscv_vrgather_vv_u8m2( + __riscv_vlmul_ext_v_u8m1_u8m2(half8), dupIdx, vl); + vint16m4_t u16 = __riscv_vsub_vx_i16m4( + __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(channel8, vl)), + 128, vl); + + vint32m8_t y32 = __riscv_vsll_vx_i32m8( + __riscv_vwcvt_x_x_v_i32m8(y16, vl), 6, vl); + vint32m8_t v32 = __riscv_vwcvt_x_x_v_i32m8(v16, vl); + vint32m8_t u32 = __riscv_vwcvt_x_x_v_i32m8(u16, vl); + + vint32m8_t calc32 = __riscv_vmacc_vx_i32m8(y32, 73, v32, vl); + calc32 = __riscv_vsra_vx_i32m8(calc32, 6, vl); + calc32 = __riscv_vmax_vx_i32m8(calc32, 0, vl); + calc32 = __riscv_vmin_vx_i32m8(calc32, 255, vl); + vint16m4_t res16 = __riscv_vncvt_x_x_w_i16m4(calc32, vl); + channel8 = __riscv_vncvt_x_x_w_u8m2( + __riscv_vreinterpret_v_i16m4_u16m4(res16), vl); + __riscv_vsse8_v_u8m2(dest + 3 * i + 2, 3, channel8, vl); + + calc32 = __riscv_vnmsac_vx_i32m8(y32, 25, u32, vl); + calc32 = __riscv_vnmsac_vx_i32m8(calc32, 37, v32, vl); + 
calc32 = __riscv_vsra_vx_i32m8(calc32, 6, vl); + calc32 = __riscv_vmax_vx_i32m8(calc32, 0, vl); + calc32 = __riscv_vmin_vx_i32m8(calc32, 255, vl); + res16 = __riscv_vncvt_x_x_w_i16m4(calc32, vl); + channel8 = __riscv_vncvt_x_x_w_u8m2( + __riscv_vreinterpret_v_i16m4_u16m4(res16), vl); + __riscv_vsse8_v_u8m2(dest + 3 * i + 1, 3, channel8, vl); + + calc32 = __riscv_vmacc_vx_i32m8(y32, 130, u32, vl); + calc32 = __riscv_vsra_vx_i32m8(calc32, 6, vl); + calc32 = __riscv_vmax_vx_i32m8(calc32, 0, vl); + calc32 = __riscv_vmin_vx_i32m8(calc32, 255, vl); + res16 = __riscv_vncvt_x_x_w_i16m4(calc32, vl); + channel8 = __riscv_vncvt_x_x_w_u8m2( + __riscv_vreinterpret_v_i16m4_u16m4(res16), vl); + __riscv_vsse8_v_u8m2(dest + 3 * i + 0, 3, channel8, vl); + + i += vl; + } + + for (; i < count; ++i) { + int Y = y[i]; + int U = (int)uv[(i / 2) * 2 + 1] - 128; + int V = (int)uv[(i / 2) * 2 + 0] - 128; + Y = Y << 6; + int R = (Y + 73 * V) >> 6; + int G = (Y - 25 * U - 37 * V) >> 6; + int B = (Y + 130 * U) >> 6; + R = std::min(std::max(R, 0), 255); + G = std::min(std::max(G, 0), 255); + B = std::min(std::max(B, 0), 255); + dest[3 * i + 2] = (uint8_t)R; + dest[3 * i + 1] = (uint8_t)G; + dest[3 * i + 0] = (uint8_t)B; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNNV21ToBGRA.cpp b/source/backend/cpu/riscv/rvv/MNNNV21ToBGRA.cpp new file mode 100644 index 0000000000..d39c281aeb --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNNV21ToBGRA.cpp @@ -0,0 +1,89 @@ +#include +#include + +void MNNNV21ToBGRA(const unsigned char* source, unsigned char* dest, size_t count) { + const unsigned char* y = source; + const unsigned char* uv = source + count; + size_t i = 0; + + while (i < count) { + size_t vl = __riscv_vsetvl_e8m2(count - i); + vl = vl & ~1UL; + if (vl == 0) break; + size_t vlHalf = vl / 2; + vuint8m2_t dupIdx = __riscv_vsrl_vx_u8m2(__riscv_vid_v_u8m2(vl), 1, vl); + + vuint8m2_t channel8 = __riscv_vle8_v_u8m2(y + i, vl); + vint16m4_t y16 = __riscv_vreinterpret_v_u16m4_i16m4( + 
__riscv_vzext_vf2_u16m4(channel8, vl)); + + vuint8m1_t half8 = __riscv_vlse8_v_u8m1(uv + (i / 2) * 2, 2, vlHalf); + channel8 = __riscv_vrgather_vv_u8m2( + __riscv_vlmul_ext_v_u8m1_u8m2(half8), dupIdx, vl); + vint16m4_t v16 = __riscv_vsub_vx_i16m4( + __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(channel8, vl)), + 128, vl); + + half8 = __riscv_vlse8_v_u8m1(uv + (i / 2) * 2 + 1, 2, vlHalf); + channel8 = __riscv_vrgather_vv_u8m2( + __riscv_vlmul_ext_v_u8m1_u8m2(half8), dupIdx, vl); + vint16m4_t u16 = __riscv_vsub_vx_i16m4( + __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(channel8, vl)), + 128, vl); + + vint32m8_t y32 = __riscv_vsll_vx_i32m8( + __riscv_vwcvt_x_x_v_i32m8(y16, vl), 6, vl); + vint32m8_t v32 = __riscv_vwcvt_x_x_v_i32m8(v16, vl); + vint32m8_t u32 = __riscv_vwcvt_x_x_v_i32m8(u16, vl); + + vint32m8_t calc32 = __riscv_vmacc_vx_i32m8(y32, 73, v32, vl); + calc32 = __riscv_vsra_vx_i32m8(calc32, 6, vl); + calc32 = __riscv_vmax_vx_i32m8(calc32, 0, vl); + calc32 = __riscv_vmin_vx_i32m8(calc32, 255, vl); + vint16m4_t res16 = __riscv_vncvt_x_x_w_i16m4(calc32, vl); + channel8 = __riscv_vncvt_x_x_w_u8m2( + __riscv_vreinterpret_v_i16m4_u16m4(res16), vl); + __riscv_vsse8_v_u8m2(dest + 4 * i + 2, 4, channel8, vl); + + calc32 = __riscv_vnmsac_vx_i32m8(y32, 25, u32, vl); + calc32 = __riscv_vnmsac_vx_i32m8(calc32, 37, v32, vl); + calc32 = __riscv_vsra_vx_i32m8(calc32, 6, vl); + calc32 = __riscv_vmax_vx_i32m8(calc32, 0, vl); + calc32 = __riscv_vmin_vx_i32m8(calc32, 255, vl); + res16 = __riscv_vncvt_x_x_w_i16m4(calc32, vl); + channel8 = __riscv_vncvt_x_x_w_u8m2( + __riscv_vreinterpret_v_i16m4_u16m4(res16), vl); + __riscv_vsse8_v_u8m2(dest + 4 * i + 1, 4, channel8, vl); + + calc32 = __riscv_vmacc_vx_i32m8(y32, 130, u32, vl); + calc32 = __riscv_vsra_vx_i32m8(calc32, 6, vl); + calc32 = __riscv_vmax_vx_i32m8(calc32, 0, vl); + calc32 = __riscv_vmin_vx_i32m8(calc32, 255, vl); + res16 = __riscv_vncvt_x_x_w_i16m4(calc32, vl); + channel8 = 
__riscv_vncvt_x_x_w_u8m2( + __riscv_vreinterpret_v_i16m4_u16m4(res16), vl); + __riscv_vsse8_v_u8m2(dest + 4 * i + 0, 4, channel8, vl); + + channel8 = __riscv_vmv_v_x_u8m2(255, vl); + __riscv_vsse8_v_u8m2(dest + 4 * i + 3, 4, channel8, vl); + + i += vl; + } + + for (; i < count; ++i) { + int Y = y[i]; + int U = (int)uv[(i / 2) * 2 + 1] - 128; + int V = (int)uv[(i / 2) * 2 + 0] - 128; + Y = Y << 6; + int R = (Y + 73 * V) >> 6; + int G = (Y - 25 * U - 37 * V) >> 6; + int B = (Y + 130 * U) >> 6; + R = std::min(std::max(R, 0), 255); + G = std::min(std::max(G, 0), 255); + B = std::min(std::max(B, 0), 255); + dest[4 * i + 2] = (uint8_t)R; + dest[4 * i + 1] = (uint8_t)G; + dest[4 * i + 0] = (uint8_t)B; + dest[4 * i + 3] = 255; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNNV21ToRGB.cpp b/source/backend/cpu/riscv/rvv/MNNNV21ToRGB.cpp new file mode 100644 index 0000000000..cca29e5020 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNNV21ToRGB.cpp @@ -0,0 +1,85 @@ +#include +#include + +void MNNNV21ToRGB(const unsigned char* source, unsigned char* dest, size_t count) { + const unsigned char* y = source; + const unsigned char* uv = source + count; + size_t i = 0; + + while (i < count) { + size_t vl = __riscv_vsetvl_e8m2(count - i); + vl = vl & ~1UL; + if (vl == 0) break; + size_t vlHalf = vl / 2; + vuint8m2_t dupIdx = __riscv_vsrl_vx_u8m2(__riscv_vid_v_u8m2(vl), 1, vl); + + vuint8m2_t channel8 = __riscv_vle8_v_u8m2(y + i, vl); + vint16m4_t y16 = __riscv_vreinterpret_v_u16m4_i16m4( + __riscv_vzext_vf2_u16m4(channel8, vl)); + + vuint8m1_t half8 = __riscv_vlse8_v_u8m1(uv + (i / 2) * 2, 2, vlHalf); + channel8 = __riscv_vrgather_vv_u8m2( + __riscv_vlmul_ext_v_u8m1_u8m2(half8), dupIdx, vl); + vint16m4_t v16 = __riscv_vsub_vx_i16m4( + __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(channel8, vl)), + 128, vl); + + half8 = __riscv_vlse8_v_u8m1(uv + (i / 2) * 2 + 1, 2, vlHalf); + channel8 = __riscv_vrgather_vv_u8m2( + __riscv_vlmul_ext_v_u8m1_u8m2(half8), dupIdx, vl); 
+ vint16m4_t u16 = __riscv_vsub_vx_i16m4( + __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(channel8, vl)), + 128, vl); + + vint32m8_t y32 = __riscv_vsll_vx_i32m8( + __riscv_vwcvt_x_x_v_i32m8(y16, vl), 6, vl); + vint32m8_t v32 = __riscv_vwcvt_x_x_v_i32m8(v16, vl); + vint32m8_t u32 = __riscv_vwcvt_x_x_v_i32m8(u16, vl); + + vint32m8_t calc32 = __riscv_vmacc_vx_i32m8(y32, 73, v32, vl); + calc32 = __riscv_vsra_vx_i32m8(calc32, 6, vl); + calc32 = __riscv_vmax_vx_i32m8(calc32, 0, vl); + calc32 = __riscv_vmin_vx_i32m8(calc32, 255, vl); + vint16m4_t res16 = __riscv_vncvt_x_x_w_i16m4(calc32, vl); + channel8 = __riscv_vncvt_x_x_w_u8m2( + __riscv_vreinterpret_v_i16m4_u16m4(res16), vl); + __riscv_vsse8_v_u8m2(dest + 3 * i + 0, 3, channel8, vl); + + calc32 = __riscv_vnmsac_vx_i32m8(y32, 25, u32, vl); + calc32 = __riscv_vnmsac_vx_i32m8(calc32, 37, v32, vl); + calc32 = __riscv_vsra_vx_i32m8(calc32, 6, vl); + calc32 = __riscv_vmax_vx_i32m8(calc32, 0, vl); + calc32 = __riscv_vmin_vx_i32m8(calc32, 255, vl); + res16 = __riscv_vncvt_x_x_w_i16m4(calc32, vl); + channel8 = __riscv_vncvt_x_x_w_u8m2( + __riscv_vreinterpret_v_i16m4_u16m4(res16), vl); + __riscv_vsse8_v_u8m2(dest + 3 * i + 1, 3, channel8, vl); + + calc32 = __riscv_vmacc_vx_i32m8(y32, 130, u32, vl); + calc32 = __riscv_vsra_vx_i32m8(calc32, 6, vl); + calc32 = __riscv_vmax_vx_i32m8(calc32, 0, vl); + calc32 = __riscv_vmin_vx_i32m8(calc32, 255, vl); + res16 = __riscv_vncvt_x_x_w_i16m4(calc32, vl); + channel8 = __riscv_vncvt_x_x_w_u8m2( + __riscv_vreinterpret_v_i16m4_u16m4(res16), vl); + __riscv_vsse8_v_u8m2(dest + 3 * i + 2, 3, channel8, vl); + + i += vl; + } + + for (; i < count; ++i) { + int Y = y[i]; + int U = (int)uv[(i / 2) * 2 + 1] - 128; + int V = (int)uv[(i / 2) * 2 + 0] - 128; + Y = Y << 6; + int R = (Y + 73 * V) >> 6; + int G = (Y - 25 * U - 37 * V) >> 6; + int B = (Y + 130 * U) >> 6; + R = std::min(std::max(R, 0), 255); + G = std::min(std::max(G, 0), 255); + B = std::min(std::max(B, 0), 255); + dest[3 * i + 
0] = (uint8_t)R; + dest[3 * i + 1] = (uint8_t)G; + dest[3 * i + 2] = (uint8_t)B; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNNV21ToRGBA.cpp b/source/backend/cpu/riscv/rvv/MNNNV21ToRGBA.cpp new file mode 100644 index 0000000000..dc6e2f0090 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNNV21ToRGBA.cpp @@ -0,0 +1,89 @@ +#include +#include + +void MNNNV21ToRGBA(const unsigned char* source, unsigned char* dest, size_t count) { + const unsigned char* y = source; + const unsigned char* uv = source + count; + size_t i = 0; + + while (i < count) { + size_t vl = __riscv_vsetvl_e8m2(count - i); + vl = vl & ~1UL; + if (vl == 0) break; + size_t vlHalf = vl / 2; + vuint8m2_t dupIdx = __riscv_vsrl_vx_u8m2(__riscv_vid_v_u8m2(vl), 1, vl); + + vuint8m2_t channel8 = __riscv_vle8_v_u8m2(y + i, vl); + vint16m4_t y16 = __riscv_vreinterpret_v_u16m4_i16m4( + __riscv_vzext_vf2_u16m4(channel8, vl)); + + vuint8m1_t half8 = __riscv_vlse8_v_u8m1(uv + (i / 2) * 2, 2, vlHalf); + channel8 = __riscv_vrgather_vv_u8m2( + __riscv_vlmul_ext_v_u8m1_u8m2(half8), dupIdx, vl); + vint16m4_t v16 = __riscv_vsub_vx_i16m4( + __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(channel8, vl)), + 128, vl); + + half8 = __riscv_vlse8_v_u8m1(uv + (i / 2) * 2 + 1, 2, vlHalf); + channel8 = __riscv_vrgather_vv_u8m2( + __riscv_vlmul_ext_v_u8m1_u8m2(half8), dupIdx, vl); + vint16m4_t u16 = __riscv_vsub_vx_i16m4( + __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(channel8, vl)), + 128, vl); + + vint32m8_t y32 = __riscv_vsll_vx_i32m8( + __riscv_vwcvt_x_x_v_i32m8(y16, vl), 6, vl); + vint32m8_t v32 = __riscv_vwcvt_x_x_v_i32m8(v16, vl); + vint32m8_t u32 = __riscv_vwcvt_x_x_v_i32m8(u16, vl); + + vint32m8_t calc32 = __riscv_vmacc_vx_i32m8(y32, 73, v32, vl); + calc32 = __riscv_vsra_vx_i32m8(calc32, 6, vl); + calc32 = __riscv_vmax_vx_i32m8(calc32, 0, vl); + calc32 = __riscv_vmin_vx_i32m8(calc32, 255, vl); + vint16m4_t res16 = __riscv_vncvt_x_x_w_i16m4(calc32, vl); + channel8 = 
__riscv_vncvt_x_x_w_u8m2( + __riscv_vreinterpret_v_i16m4_u16m4(res16), vl); + __riscv_vsse8_v_u8m2(dest + 4 * i + 0, 4, channel8, vl); + + calc32 = __riscv_vnmsac_vx_i32m8(y32, 25, u32, vl); + calc32 = __riscv_vnmsac_vx_i32m8(calc32, 37, v32, vl); + calc32 = __riscv_vsra_vx_i32m8(calc32, 6, vl); + calc32 = __riscv_vmax_vx_i32m8(calc32, 0, vl); + calc32 = __riscv_vmin_vx_i32m8(calc32, 255, vl); + res16 = __riscv_vncvt_x_x_w_i16m4(calc32, vl); + channel8 = __riscv_vncvt_x_x_w_u8m2( + __riscv_vreinterpret_v_i16m4_u16m4(res16), vl); + __riscv_vsse8_v_u8m2(dest + 4 * i + 1, 4, channel8, vl); + + calc32 = __riscv_vmacc_vx_i32m8(y32, 130, u32, vl); + calc32 = __riscv_vsra_vx_i32m8(calc32, 6, vl); + calc32 = __riscv_vmax_vx_i32m8(calc32, 0, vl); + calc32 = __riscv_vmin_vx_i32m8(calc32, 255, vl); + res16 = __riscv_vncvt_x_x_w_i16m4(calc32, vl); + channel8 = __riscv_vncvt_x_x_w_u8m2( + __riscv_vreinterpret_v_i16m4_u16m4(res16), vl); + __riscv_vsse8_v_u8m2(dest + 4 * i + 2, 4, channel8, vl); + + channel8 = __riscv_vmv_v_x_u8m2(255, vl); + __riscv_vsse8_v_u8m2(dest + 4 * i + 3, 4, channel8, vl); + + i += vl; + } + + for (; i < count; ++i) { + int Y = y[i]; + int U = (int)uv[(i / 2) * 2 + 1] - 128; + int V = (int)uv[(i / 2) * 2 + 0] - 128; + Y = Y << 6; + int R = (Y + 73 * V) >> 6; + int G = (Y - 25 * U - 37 * V) >> 6; + int B = (Y + 130 * U) >> 6; + R = std::min(std::max(R, 0), 255); + G = std::min(std::max(G, 0), 255); + B = std::min(std::max(B, 0), 255); + dest[4 * i + 0] = (uint8_t)R; + dest[4 * i + 1] = (uint8_t)G; + dest[4 * i + 2] = (uint8_t)B; + dest[4 * i + 3] = 255; + } +} From db2eb9abcf68eb107d35e92685e03a3471a4320d Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Tue, 23 Dec 2025 15:47:54 +0800 Subject: [PATCH 060/314] Merge branch feature/fix_sync into master Title: [Bugfix:CI] Fix duplicate msg when sync to github. 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 这段代码在 `copybara_sync.sh` 脚本中新增了一个功能,用于检测并跳过从 GitHub 导入的 commits,通过识别包含 `GitOrigin-RevId` 的 commit 来确定上次同步点,并从该点之后的第一个非导入 commit 开始进行同步。 Link: https://code.alibaba-inc.com/AliNN/AliNNPrivate/codereview/25060238 GitOrigin-RevId: ca65f11f52c1b76a826cbdc260a063d1467a8f35 --- docs/compile/cmake.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/compile/cmake.md b/docs/compile/cmake.md index 6513e38fad..91a9c03959 100644 --- a/docs/compile/cmake.md +++ b/docs/compile/cmake.md @@ -101,4 +101,5 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下: | MNN_BUILD_LLM_OMNI | 若构建基于MNN的llm库和demo,是否支持图像和音频输入功能,默认为`OFF` 。仅在MNN_BUILD_LLM 打开时生效。开启时 MNN_BUILD_OPENCV , MNN_IMGCODECS , MNN_BUILD_AUDIO 同时打开| | MNN_BUILD_DIFFUSION | 是否构建基于MNN的diffusion demo,默认为`OFF` . 打开时MNN_BUILD_OPENCV , MNN_IMGCODECS, MNN_LOW_MEMORY, MNN_SUPPORT_TRANSFORMER_FUSE 同步开启| | MNN_KLEIDIAI | 是否集成ARM的klediAI加速库,默认为`ON` | +| MNN_KLEIDIAI_DEFAULT_ON | 是否默认使用KLEIDIAI的Kernel, 默认为`OFF` | | MNN_USE_RVV | 是否启用RISC-V向量扩展支持,默认为`OFF` | From a92edc9b42124b96368db4d2da1155afd67afe4a Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 09:56:50 +0800 Subject: [PATCH 061/314] Project import generated by Copybara. 
GitOrigin-RevId: fd90884f44c381932e3de8224bc69a0c327a3344 --- CMakeLists.txt | 1 - README.md | 14 +- README_CN.md | 10 +- README_JP.md | 9 +- build_lib.sh | 807 ------------------ docs/compile/cmake.md | 1 - docs/transformers/diffusion.md | 3 +- express/Expr.cpp | 16 - express/Utils.cpp | 17 - express/module/Module.cpp | 26 - source/backend/cpu/CPUBackend.cpp | 8 +- source/backend/cpu/CPUBackend.hpp | 3 - source/backend/cpu/CPUBinary.cpp | 60 +- source/backend/cpu/CPUBinary.hpp | 4 - source/backend/cpu/CPUMatMul.cpp | 28 +- source/backend/cpu/CPUMatMul.hpp | 7 +- source/backend/cpu/CPURNNSequenceGRU.cpp | 70 +- source/backend/cpu/CPURNNSequenceGRU.hpp | 15 +- source/backend/cpu/CPURaster.cpp | 631 +++++++------- source/backend/cpu/CPURaster.hpp | 3 +- source/backend/cpu/ThreadPool.cpp | 32 +- source/backend/cpu/ThreadPool.hpp | 6 +- source/backend/cpu/arm/CMakeLists.txt | 3 - .../backend/cpu/compute/CommonOptFunction.cpp | 88 +- .../cpu/riscv/rvv/CPUBilinearLineC4.cpp | 19 - .../cpu/riscv/rvv/CPUBilinearSampleC4.cpp | 33 - .../cpu/riscv/rvv/MNNAddC4WithStride.cpp | 29 - .../riscv/rvv/MNNAxByClampBroadcastUnit.cpp | 52 -- source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp | 18 - .../backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp | 20 - source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp | 20 - .../cpu/riscv/rvv/MNNBilinearLineC8.cpp | 40 - .../cpu/riscv/rvv/MNNBilinearSampleC8.cpp | 49 -- source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp | 20 - .../riscv/rvv/MNNConvRunForLineDepthwise.cpp | 48 -- .../cpu/riscv/rvv/MNNCopyC4WithStride.cpp | 22 - .../backend/cpu/riscv/rvv/MNNCubicLineC16.cpp | 53 -- .../backend/cpu/riscv/rvv/MNNCubicLineC4.cpp | 38 - .../cpu/riscv/rvv/MNNCubicSampleC16.cpp | 79 -- .../cpu/riscv/rvv/MNNCubicSampleC4.cpp | 62 -- .../rvv/MNNDeconvRunForUnitDepthWise.cpp | 42 - source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp | 13 - source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp | 16 - source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp | 25 - source/backend/cpu/riscv/rvv/MNNMinFloat.cpp | 25 
- source/backend/cpu/riscv/rvv/MNNPackC2.cpp | 74 -- source/backend/cpu/riscv/rvv/MNNPackC4.cpp | 80 -- source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp | 17 - .../backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp | 20 - .../backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp | 20 - source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp | 17 - source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp | 20 - .../cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp | 45 - .../cpu/riscv/rvv/MNNScaleAndAddBias.cpp | 42 - source/backend/cpu/riscv/rvv/MNNSoftmax.cpp | 80 -- .../riscv/rvv/MNNStrassenMergeCFunction.cpp | 36 - .../cpu/riscv/rvv/MNNTranspose16Bit.cpp | 26 - .../cpu/riscv/rvv/MNNTranspose32Bit.cpp | 25 - source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp | 55 -- .../cpu/riscv/rvv/MNNVectorTop1Float.cpp | 37 - .../cpu/riscv/rvv/MNNVectorTop1Int32.cpp | 37 - source/backend/metal/MetalBackend.mm | 33 - source/core/Backend.hpp | 7 +- source/core/Concurrency.h | 13 +- source/core/OpCommonUtils.cpp | 91 ++ source/core/OpCommonUtils.hpp | 1 + source/core/Tensor.cpp | 8 - source/core/TensorUtils.cpp | 12 - source/core/TensorUtils.hpp | 1 - source/geometry/GeometryComputerUtils.cpp | 4 +- source/geometry/GeometryComputerUtils.hpp | 2 +- source/geometry/GeometryReduce.cpp | 104 +-- source/geometry/GeometryReshape.cpp | 11 +- source/math/Vec.hpp | 3 +- test/core/ThreadPoolTest.cpp | 6 +- tools/cpp/ExprDebug.hpp | 53 +- tools/cpp/ModuleBasic.cpp | 46 +- transformers/diffusion/export/onnx_export.py | 30 +- transformers/llm/engine/demo/llm_demo.cpp | 36 +- transformers/llm/engine/include/llm/llm.hpp | 9 - transformers/llm/engine/src/llm.cpp | 39 +- .../engine/src/speculative_decoding/eagle.cpp | 16 - .../src/speculative_decoding/generate.cpp | 13 +- .../src/speculative_decoding/lookahead.cpp | 12 - .../engine/src/speculative_decoding/mtp.cpp | 12 - 85 files changed, 636 insertions(+), 3142 deletions(-) delete mode 100644 build_lib.sh delete mode 100644 source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp delete mode 100644 
source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNMinFloat.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNPackC2.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNPackC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp delete mode 100644 
source/backend/cpu/riscv/rvv/MNNSoftmax.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f99e37ec1c..67502b606b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -258,7 +258,6 @@ option(MNN_VULKAN "Enable Vulkan" OFF) option(MNN_ARM82 "Enable ARMv8.2's FP16 Compute" ON) option(MNN_SUPPORT_FP16_ARMV7 "Enable ARMv8.2's FP16 Compute for armv7 arch, may cause library not valid for 32 bit cpu" OFF) option(MNN_KLEIDIAI "Enable KLEIDIAI" ON) -option(MNN_KLEIDIAI_DEFAULT_ON "Use KLEIDIAI kernels by default" OFF) option(MNN_ONEDNN "Enable oneDNN" OFF) option(MNN_AVX2 "Open AVX2 Compile for x86 if possible" ON) option(MNN_AVX512 "Enable AVX512" OFF) diff --git a/README.md b/README.md index 7959890c16..5fe168ed05 100644 --- a/README.md +++ b/README.md @@ -6,13 +6,13 @@ [![日本語バージョン](https://img.shields.io/badge/Language-%E6%97%A5%E6%9C%AC%E8%AA%9E-green)](README_JP.md) [![MNN Homepage](https://img.shields.io/badge/Homepage-Visit-green)](http://www.mnn.zone) -[![MNN Chat App](https://img.shields.io/badge/Apps-MNN_Chat-blue)](./apps/Android/MnnLlmChat/README.md) -[![TaoAvatar](https://img.shields.io/badge/Apps-MNN_TaoAvatar-blue)](./apps/Android/Mnn3dAvatar/README.md) +[![MNN Chat App](https://img.shields.io/badge/Apps-MNN_Chat-blue)](./apps/Android/MnnLlmChat/README.md) +[![TaoAvatar](https://img.shields.io/badge/Apps-MNN_TaoAvatar-blue)](./apps/Android/Mnn3dAvatar/README.md) ## News 🔥 - [2025/10/16] Support Qwen3-VL Series. 
-- [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! [MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md) +- [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! [MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md)

Icon

@@ -154,13 +154,13 @@ The group discussions are predominantly Chinese. But we welcome and will help En Dingtalk discussion groups: -Group #4 (Available): 160170007549 - -Group #3 (Full) +Group #1 (Full): 23329087 Group #2 (Full): 23350225 -Group #1 (Full): 23329087 +Group #3: QR code: + +![MNN-3](doc/dingdingmnn3.png) ## Historical Paper diff --git a/README_CN.md b/README_CN.md index f769a1e14b..edcf823a28 100644 --- a/README_CN.md +++ b/README_CN.md @@ -111,10 +111,12 @@ MNN适配的硬件架构与精度详见下表: ## 社区交流与反馈 钉钉群组: -- 钉钉群3 (可加入): 160170007549 -- 钉钉群3 (已无法加入) -- 钉钉群2 (已满): 23350225 -- 钉钉群1 (已满): 23329087 +- 钉钉群1:23329087 +- 钉钉群2:23350225 +- 钉钉群3:扫描二维码加入 + +![MNN-3](doc/dingdingmnn3.png) + ## 历史论文 diff --git a/README_JP.md b/README_JP.md index 2f33def31a..c2baa58d94 100644 --- a/README_JP.md +++ b/README_JP.md @@ -117,14 +117,13 @@ MNN(テンソル計算エンジン)に基づいて、推論、トレーニ Dingtalkディスカッショングループ: - -グループ#4 :160170007549 - -グループ#3 (満員) +グループ#1(満員):23329087 グループ#2(満員):23350225 -グループ#1(満員):23329087 +グループ#3:QRコード: + +![MNN-3](doc/dingdingmnn3.png) ## 歴史的な論文 diff --git a/build_lib.sh b/build_lib.sh deleted file mode 100644 index c839b6e7b6..0000000000 --- a/build_lib.sh +++ /dev/null @@ -1,807 +0,0 @@ -#!/bin/bash - -# MNN 统一构建脚本 -# 支持构建 Android、iOS、鸿蒙和 Python 版本的 MNN - -set -e - -# 颜色输出 -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -BLUE='\033[0;34m' -NC='\033[0m' # No Color - -# 默认配置 -BUILD_ANDROID=false -BUILD_IOS=false -BUILD_PYTHON=false -BUILD_IOS_SIMULATOR=false -BUILD_HARMONY=false -ANDROID_NDK="" -HARMONY_HOME="" -OUTPUT_DIR="mnn_builds" -CLEAN_BUILD=true -VERSION="" -PYTHON_CMD="python3" -DEPS_OPTIONS="" - -# 获取项目根目录 -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -PROJECT_ROOT="$SCRIPT_DIR" - -print_usage() { - echo "MNN 统一构建脚本" - echo "" - echo "用法: $0 [选项]" - echo "" - echo "构建目标:" - echo " --android 构建 Android 版本 (需要 --ndk)" - echo " --ios 构建 iOS 真机版本 (arm64)" - echo " --ios-simulator 构建 iOS 模拟器版本 (x86_64 + arm64)" - echo " --harmony 构建鸿蒙版本 (需要 
--harmony-home 或设置 HARMONY_HOME)" - echo " --python 构建 Python 版本" - echo "" - echo "Android 选项:" - echo " --ndk PATH Android NDK 路径 (例如: ~/Library/Android/sdk/ndk/29.0.13599879)" - echo "" - echo "鸿蒙选项:" - echo " --harmony-home PATH 鸿蒙工具链路径 (例如: ~/Library/OpenHarmony/Sdk/native)" - echo " 如果未指定,将优先查找 ~/Library/OpenHarmony/Sdk/native" - echo "" - echo "Python 选项:" - echo " --python-deps OPTIONS Python 依赖选项,多个用逗号分隔" - echo " 可用选项: llm,opencl,cuda,torch,render,vulkan,internal,no_sse,openmp" - echo " --python-cmd CMD Python 命令 (默认: python3)" - echo "" - echo "通用选项:" - echo " -o, --output DIR 输出目录 (默认: mnn_builds)" - echo " -v, --version VERSION 版本号 (默认: 自动从源码读取)" - echo " --no-clean 不清理之前的构建目录" - echo " -h, --help 显示帮助信息" - echo "" - echo "示例:" - echo " # 构建所有平台" - echo " $0 --android --ios --harmony --python --ndk ~/Library/Android/sdk/ndk/29.0.13599879" - echo "" - echo " # 仅构建 Android" - echo " $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879" - echo "" - echo " # 构建鸿蒙版本" - echo " $0 --harmony --harmony-home ~/Library/OpenHarmony/Sdk/native" - echo "" - echo " # 构建 iOS 真机和模拟器" - echo " $0 --ios --ios-simulator" - echo "" - echo " # 构建 Python (带 LLM 支持)" - echo " $0 --python --python-deps llm,opencl" -} - -# 解析命令行参数 -while [[ $# -gt 0 ]]; do - case $1 in - --android) - BUILD_ANDROID=true - shift - ;; - --ios) - BUILD_IOS=true - shift - ;; - --ios-simulator) - BUILD_IOS_SIMULATOR=true - shift - ;; - --python) - BUILD_PYTHON=true - shift - ;; - --harmony) - BUILD_HARMONY=true - shift - ;; - --ndk) - ANDROID_NDK="$2" - shift 2 - ;; - --harmony-home) - HARMONY_HOME="$2" - shift 2 - ;; - --python-deps) - DEPS_OPTIONS="$2" - shift 2 - ;; - --python-cmd) - PYTHON_CMD="$2" - shift 2 - ;; - -o|--output) - OUTPUT_DIR="$2" - shift 2 - ;; - -v|--version) - VERSION="$2" - shift 2 - ;; - --no-clean) - CLEAN_BUILD=false - shift - ;; - -h|--help) - print_usage - exit 0 - ;; - *) - echo -e "${RED}错误: 未知选项 $1${NC}" - print_usage - exit 1 - ;; - esac -done - -# 
检查是否至少选择了一个构建目标 -if [ "$BUILD_ANDROID" = false ] && [ "$BUILD_IOS" = false ] && [ "$BUILD_IOS_SIMULATOR" = false ] && [ "$BUILD_HARMONY" = false ] && [ "$BUILD_PYTHON" = false ]; then - echo -e "${RED}错误: 请至少选择一个构建目标 (--android, --ios, --ios-simulator, --harmony, 或 --python)${NC}" - print_usage - exit 1 -fi - -# 检查 Android NDK -if [ "$BUILD_ANDROID" = true ]; then - if [ -z "$ANDROID_NDK" ]; then - echo -e "${RED}错误: 构建 Android 必须指定 NDK 路径 (使用 --ndk)${NC}" - echo -e "${RED}示例: $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879${NC}" - exit 1 - fi - - # 展开路径中的 ~ - ANDROID_NDK="${ANDROID_NDK/#\~/$HOME}" - - if [ ! -d "$ANDROID_NDK" ]; then - echo -e "${RED}错误: NDK 路径不存在: $ANDROID_NDK${NC}" - exit 1 - fi -fi - -# 查找鸿蒙工具链的函数 -find_harmony_toolchain() { - # 如果环境变量已设置,优先使用 - if [ -n "$HARMONY_HOME" ]; then - # 支持两种路径格式:直接指定或带 native 子目录 - if [ -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then - echo "$HARMONY_HOME" - return 0 - elif [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then - echo "$HARMONY_HOME" - return 0 - fi - fi - - # 优先查找 ~/Library/OpenHarmony/Sdk,支持版本号目录 - local sdk_base="$HOME/Library/OpenHarmony/Sdk" - if [ -d "$sdk_base" ]; then - # 查找所有版本号目录下的 native 或 toolchains - # 按版本号倒序排列,优先使用最新版本 - for version_dir in $(ls -d "$sdk_base"/* 2>/dev/null | sort -Vr); do - if [ -d "$version_dir" ]; then - # 尝试 native/build/cmake/ohos.toolchain.cmake - if [ -f "$version_dir/native/build/cmake/ohos.toolchain.cmake" ]; then - echo "$version_dir/native" - return 0 - fi - # 尝试 build/cmake/ohos.toolchain.cmake - if [ -f "$version_dir/build/cmake/ohos.toolchain.cmake" ]; then - echo "$version_dir" - return 0 - fi - fi - done - fi - - # 其他可能的路径 - local possible_paths=( - "$HOME/Library/OpenHarmony/Sdk/native" - "$HOME/HarmonyOS/Sdk/native" - "$HOME/.ohos/native" - "/opt/HarmonyOS/Sdk/native" - "/usr/local/HarmonyOS/Sdk/native" - "$HOME/Library/HarmonyOS/Sdk/native" - ) - - # 尝试查找 - for path in "${possible_paths[@]}"; do - if [ -n "$path" ] 
&& [ -f "$path/build/cmake/ohos.toolchain.cmake" ]; then - echo "$path" - return 0 - fi - done - - # 限制搜索范围,只在 OpenHarmony/Sdk 目录下快速查找 - # 避免在整个 OpenHarmony 目录下递归查找,这可能会很慢 - local found=$(find "$HOME/Library/OpenHarmony/Sdk" -maxdepth 4 -type f -name "ohos.toolchain.cmake" 2>/dev/null | head -1) - if [ -n "$found" ]; then - # 从 ohos.toolchain.cmake 向上查找 native 目录或 SDK 根目录 - found=$(dirname "$found") - if [ "$(basename "$found")" = "cmake" ]; then - found=$(dirname "$found") - if [ "$(basename "$found")" = "build" ]; then - found=$(dirname "$found") - fi - fi - echo "$found" - return 0 - fi - - return 1 -} - -# 检查鸿蒙工具链 -if [ "$BUILD_HARMONY" = true ]; then - # 展开路径中的 ~ - if [ -n "$HARMONY_HOME" ]; then - HARMONY_HOME="${HARMONY_HOME/#\~/$HOME}" - fi - - # 尝试查找工具链 - if [ -z "$HARMONY_HOME" ] || [ ! -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then - # 检查是否在 native 子目录下 - if [ -n "$HARMONY_HOME" ] && [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then - HARMONY_HOME="$HARMONY_HOME/native" - else - echo -e "${YELLOW}正在查找鸿蒙工具链...${NC}" - HARMONY_HOME=$(find_harmony_toolchain) - - if [ -z "$HARMONY_HOME" ]; then - echo -e "${RED}错误: 找不到鸿蒙工具链${NC}" - echo -e "${RED}默认查找路径: ~/Library/OpenHarmony/Sdk/*/native${NC}" - echo -e "${RED}请使用 --harmony-home 指定路径,或设置 HARMONY_HOME 环境变量${NC}" - echo -e "${RED}工具链文件应位于: /build/cmake/ohos.toolchain.cmake 或 /native/build/cmake/ohos.toolchain.cmake${NC}" - exit 1 - fi - - # 如果找到的是带 native 的路径,需要调整 - if [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then - HARMONY_HOME="$HARMONY_HOME/native" - fi - fi - - echo -e "${GREEN}找到鸿蒙工具链: $HARMONY_HOME${NC}" - fi - - # 验证工具链文件存在 - if [ ! 
-f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then - echo -e "${RED}错误: 鸿蒙工具链文件不存在: $HARMONY_HOME/build/cmake/ohos.toolchain.cmake${NC}" - exit 1 - fi -fi - -echo -e "${GREEN}========================================${NC}" -echo -e "${GREEN}MNN 统一构建脚本${NC}" -echo -e "${GREEN}========================================${NC}" -echo "项目根目录: $PROJECT_ROOT" -echo "输出目录: $OUTPUT_DIR" -echo "" -echo -e "${BLUE}构建目标:${NC}" -[ "$BUILD_ANDROID" = true ] && echo " ✓ Android" -[ "$BUILD_IOS" = true ] && echo " ✓ iOS (真机 arm64)" -[ "$BUILD_IOS_SIMULATOR" = true ] && echo " ✓ iOS (模拟器 x86_64 + arm64)" -[ "$BUILD_HARMONY" = true ] && echo " ✓ 鸿蒙 (arm64-v8a)" -[ "$BUILD_PYTHON" = true ] && echo " ✓ Python" -echo "" - -cd "$PROJECT_ROOT" -mkdir -p "$OUTPUT_DIR" - -# ============================================================================ -# 构建 Android 版本 -# ============================================================================ -if [ "$BUILD_ANDROID" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建 Android 版本${NC}" - echo -e "${GREEN}========================================${NC}" - - export ANDROID_NDK - - ANDROID_BUILD_DIR="project/android" - ANDROID_OUTPUT_DIR="$OUTPUT_DIR/android" - - cd "$ANDROID_BUILD_DIR" - - # 清理 - if [ "$CLEAN_BUILD" = true ]; then - echo -e "${YELLOW}清理之前的 Android 构建...${NC}" - rm -rf build_32 build_64 export - fi - - # 在项目根目录创建输出目录 - mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" - - # 构建 armeabi-v7a - echo -e "${BLUE}构建 armeabi-v7a...${NC}" - mkdir -p build_32 - cd build_32 - cmake ../../../ \ - -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ - -DCMAKE_BUILD_TYPE=Release \ - -DANDROID_ABI="armeabi-v7a" \ - -DANDROID_STL=c++_static \ - -DANDROID_NATIVE_API_LEVEL=android-14 \ - -DANDROID_TOOLCHAIN=clang \ - -DMNN_USE_LOGCAT=false \ - -DMNN_USE_SSE=OFF \ - -DMNN_BUILD_TEST=ON \ - -DMNN_ARM82=OFF \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - 
-DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DMNN_SEP_BUILD=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ - -DNATIVE_LIBRARY_OUTPUT=. \ - -DNATIVE_INCLUDE_OUTPUT=. \ - > /dev/null - - make -j4 MNN > /dev/null - cd .. - - # 构建 arm64-v8a - echo -e "${BLUE}构建 arm64-v8a...${NC}" - mkdir -p build_64 - cd build_64 - cmake ../../../ \ - -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ - -DCMAKE_BUILD_TYPE=Release \ - -DANDROID_ABI="arm64-v8a" \ - -DANDROID_STL=c++_static \ - -DANDROID_NATIVE_API_LEVEL=android-21 \ - -DANDROID_TOOLCHAIN=clang \ - -DMNN_USE_LOGCAT=false \ - -DMNN_BUILD_BENCHMARK=ON \ - -DMNN_USE_SSE=OFF \ - -DMNN_BUILD_TEST=ON \ - -DMNN_ARM82=ON \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DMNN_SEP_BUILD=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ - -DNATIVE_LIBRARY_OUTPUT=. \ - -DNATIVE_INCLUDE_OUTPUT=. \ - > /dev/null - - make -j4 MNN > /dev/null - cd .. - - # 导出文件 - echo -e "${BLUE}导出 Android 库文件...${NC}" - mkdir -p export/android/{armeabi-v7a,arm64-v8a}/libs export/android/include/MNN - - cp build_32/*.so export/android/armeabi-v7a/libs/ 2>/dev/null || true - cp build_64/*.so export/android/arm64-v8a/libs/ 2>/dev/null || true - cp -r ../../include/MNN/* export/android/include/MNN/ - - # 复制到统一输出目录 - # 如果目标路径存在但不是目录,先删除 - if [ -e "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ] && [ ! 
-d "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ]; then - rm -f "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" - fi - mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" - cp -r export/android/* "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR/" - - echo -e "${GREEN}Android 构建完成!${NC}" - echo "输出位置: $ANDROID_OUTPUT_DIR" - cd "$PROJECT_ROOT" -fi - -# ============================================================================ -# 构建 iOS 真机版本 -# ============================================================================ -if [ "$BUILD_IOS" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建 iOS 真机版本${NC}" - echo -e "${GREEN}========================================${NC}" - - # 检查是否在 macOS 上 - if [[ "$OSTYPE" != "darwin"* ]]; then - echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" - exit 1 - fi - - # 检查 Xcode - if ! command -v xcodebuild &> /dev/null; then - echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" - exit 1 - fi - - IOS_BUILD_DIR="project/ios" - IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/device" - - cd "$IOS_BUILD_DIR" - - # 清理 - if [ "$CLEAN_BUILD" = true ]; then - echo -e "${YELLOW}清理之前的 iOS 真机构建...${NC}" - rm -rf MNN-iOS-CPU-GPU/Static/ios_64 - find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_64/*" 2>/dev/null | xargs rm -f 2>/dev/null || true - fi - - mkdir -p MNN-iOS-CPU-GPU/Static - cd MNN-iOS-CPU-GPU/Static - - # 构建真机版本 (arm64) - echo -e "${BLUE}构建 iOS 真机版本 (arm64)...${NC}" - rm -rf ios_64 - mkdir ios_64 - cd ios_64 - cmake "$PROJECT_ROOT" \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ - -DENABLE_BITCODE=0 \ - -DMNN_AAPL_FMWK=1 \ - -DMNN_SEP_BUILD=OFF \ - -DMNN_BUILD_SHARED_LIBS=false \ - -DMNN_USE_THREAD_POOL=OFF \ - -DPLATFORM=OS64 \ - -DARCHS="arm64" \ - -DMNN_ARM82=ON \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_METAL=ON \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - 
-DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - > /dev/null - make MNN -j8 > /dev/null - cd .. - - # 输出真机版本 - mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" - rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" - cp -R ios_64/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" - - # 清理 - rm -rf ios_64 - - echo -e "${GREEN}iOS 真机构建完成!${NC}" - echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" - cd "$PROJECT_ROOT" -fi - -# ============================================================================ -# 构建 iOS 模拟器版本 -# ============================================================================ -if [ "$BUILD_IOS_SIMULATOR" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建 iOS 模拟器版本${NC}" - echo -e "${GREEN}========================================${NC}" - - # 检查是否在 macOS 上 - if [[ "$OSTYPE" != "darwin"* ]]; then - echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" - exit 1 - fi - - # 检查 Xcode - if ! command -v xcodebuild &> /dev/null; then - echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" - exit 1 - fi - - IOS_BUILD_DIR="project/ios" - IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/simulator" - - cd "$IOS_BUILD_DIR" - - # 清理 - if [ "$CLEAN_BUILD" = true ]; then - echo -e "${YELLOW}清理之前的 iOS 模拟器构建...${NC}" - rm -rf MNN-iOS-CPU-GPU/Static/ios_simulator* - find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_simulator*/*" 2>/dev/null | xargs rm -f 2>/dev/null || true - fi - - mkdir -p MNN-iOS-CPU-GPU/Static - cd MNN-iOS-CPU-GPU/Static - - # 构建 x86_64 模拟器版本(尝试构建,失败则跳过) - BUILD_X86_64=false - echo -e "${BLUE}构建 iOS 模拟器版本 (x86_64)...${NC}" - rm -rf ios_simulator_x86 - mkdir ios_simulator_x86 - cd ios_simulator_x86 - if cmake "$PROJECT_ROOT" \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ - -DENABLE_BITCODE=0 \ - -DMNN_AAPL_FMWK=1 \ - -DMNN_SEP_BUILD=OFF \ - -DMNN_BUILD_SHARED_LIBS=false \ - -DMNN_USE_THREAD_POOL=OFF \ - -DPLATFORM=SIMULATOR64 \ - 
-DARCHS="x86_64" \ - -DMNN_ARM82=OFF \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_METAL=OFF \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - && make MNN -j8; then - if [ -f "MNN.framework/MNN" ]; then - BUILD_X86_64=true - echo -e "${GREEN}x86_64 模拟器构建成功${NC}" - cd .. - else - echo -e "${YELLOW}警告: x86_64 模拟器构建产物不存在,跳过此架构${NC}" - cd .. - rm -rf ios_simulator_x86 - fi - else - echo -e "${YELLOW}警告: x86_64 模拟器构建失败,跳过此架构(这在 Apple Silicon Mac 上是正常的)${NC}" - cd .. - rm -rf ios_simulator_x86 - fi - - # 构建 arm64 模拟器版本 - echo -e "${BLUE}构建 iOS 模拟器版本 (arm64)...${NC}" - rm -rf ios_simulator_arm64 - mkdir ios_simulator_arm64 - cd ios_simulator_arm64 - if ! cmake "$PROJECT_ROOT" \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ - -DENABLE_BITCODE=0 \ - -DMNN_AAPL_FMWK=1 \ - -DMNN_SEP_BUILD=OFF \ - -DMNN_BUILD_SHARED_LIBS=false \ - -DMNN_USE_THREAD_POOL=OFF \ - -DPLATFORM=SIMULATOR64 \ - -DARCHS="arm64" \ - -DMNN_ARM82=ON \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_METAL=OFF \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON; then - echo -e "${RED}错误: arm64 模拟器 cmake 配置失败${NC}" - cd "$PROJECT_ROOT" - exit 1 - fi - if ! make MNN -j8; then - echo -e "${RED}错误: arm64 模拟器编译失败${NC}" - cd "$PROJECT_ROOT" - exit 1 - fi - cd .. - - # 验证构建产物 - if [ ! 
-f "ios_simulator_arm64/MNN.framework/MNN" ]; then - echo -e "${RED}错误: 未找到 arm64 模拟器框架文件${NC}" - cd "$PROJECT_ROOT" - exit 1 - fi - - # 合并模拟器架构 - echo -e "${BLUE}合并模拟器架构...${NC}" - rm -rf ios_simulator - mkdir ios_simulator - - if [ "$BUILD_X86_64" = true ] && [ -f "ios_simulator_x86/MNN.framework/MNN" ]; then - # 合并 x86_64 + arm64 - echo -e "${BLUE}合并 x86_64 + arm64 架构...${NC}" - cp -R ios_simulator_x86/MNN.framework ios_simulator/MNN.framework - mv ios_simulator/MNN.framework/MNN ios_simulator/MNN.framework/MNN_x86 - if ! lipo -create ios_simulator/MNN.framework/MNN_x86 ios_simulator_arm64/MNN.framework/MNN -output ios_simulator/MNN.framework/MNN; then - echo -e "${RED}错误: 合并架构失败${NC}" - cd "$PROJECT_ROOT" - exit 1 - fi - rm ios_simulator/MNN.framework/MNN_x86 - else - # 仅使用 arm64 - echo -e "${BLUE}仅使用 arm64 架构(x86_64 不可用)...${NC}" - cp -R ios_simulator_arm64/MNN.framework ios_simulator/MNN.framework - fi - - # 输出模拟器版本 - mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" - rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" - cp -R ios_simulator/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" - - # 清理临时目录 - rm -rf ios_simulator ios_simulator_x86 ios_simulator_arm64 - - echo -e "${GREEN}iOS 模拟器构建完成!${NC}" - echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" - cd "$PROJECT_ROOT" -fi - -# ============================================================================ -# 构建鸿蒙版本 -# ============================================================================ -if [ "$BUILD_HARMONY" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建鸿蒙版本${NC}" - echo -e "${GREEN}========================================${NC}" - - export HARMONY_HOME - - HARMONY_BUILD_DIR="project/harmony" - HARMONY_OUTPUT_DIR="$OUTPUT_DIR/harmony" - - cd "$HARMONY_BUILD_DIR" - - # 清理 - if [ "$CLEAN_BUILD" = true ]; then - echo -e "${YELLOW}清理之前的鸿蒙构建...${NC}" - rm -rf build_64 export - fi - - # 在项目根目录创建输出目录 - mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" - - # 构建 arm64-v8a - 
echo -e "${BLUE}构建 arm64-v8a...${NC}" - mkdir -p build_64 - cd build_64 - - cmake "$PROJECT_ROOT" \ - -DCMAKE_TOOLCHAIN_FILE=$HARMONY_HOME/build/cmake/ohos.toolchain.cmake \ - -DCMAKE_BUILD_TYPE=Release \ - -DOHOS_ARCH="arm64-v8a" \ - -DOHOS_STL=c++_static \ - -DMNN_USE_LOGCAT=true \ - -DMNN_BUILD_BENCHMARK=ON \ - -DMNN_USE_SSE=OFF \ - -DMNN_SUPPORT_BF16=OFF \ - -DMNN_BUILD_TEST=ON \ - -DMNN_ARM82=ON \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DMNN_SEP_BUILD=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - -DOHOS_PLATFORM_LEVEL=9 \ - -DNATIVE_LIBRARY_OUTPUT=. \ - -DNATIVE_INCLUDE_OUTPUT=. \ - > /dev/null - - make -j4 MNN > /dev/null - cd .. - - # 导出文件 - echo -e "${BLUE}导出鸿蒙库文件...${NC}" - mkdir -p export/harmony/arm64-v8a/libs export/harmony/include/MNN - - cp build_64/*.so export/harmony/arm64-v8a/libs/ 2>/dev/null || true - cp -r ../../include/MNN/* export/harmony/include/MNN/ - - # 复制到统一输出目录 - # 如果目标路径存在但不是目录,先删除 - if [ -e "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ] && [ ! -d "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ]; then - rm -f "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" - fi - mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" - cp -r export/harmony/* "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR/" - - echo -e "${GREEN}鸿蒙构建完成!${NC}" - echo "输出位置: $HARMONY_OUTPUT_DIR" - cd "$PROJECT_ROOT" -fi - -# ============================================================================ -# 构建 Python 版本 -# ============================================================================ -if [ "$BUILD_PYTHON" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建 Python 版本${NC}" - echo -e "${GREEN}========================================${NC}" - - # 检查 Python - if ! 
command -v $PYTHON_CMD &> /dev/null; then - echo -e "${RED}错误: 找不到 Python 命令 '$PYTHON_CMD'${NC}" - exit 1 - fi - - PYTHON_OUTPUT_DIR="$OUTPUT_DIR/python" - - # 使用之前创建的 build_pymnn.sh 脚本 - if [ -f "$PROJECT_ROOT/build_pymnn.sh" ]; then - PYTHON_BUILD_ARGS="-o $PYTHON_OUTPUT_DIR" - [ -n "$VERSION" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -v $VERSION" - [ -n "$DEPS_OPTIONS" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -d $DEPS_OPTIONS" - [ "$CLEAN_BUILD" = false ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --no-clean" - PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --python $PYTHON_CMD" - - bash "$PROJECT_ROOT/build_pymnn.sh" $PYTHON_BUILD_ARGS - else - echo -e "${YELLOW}警告: 未找到 build_pymnn.sh,使用基本构建方式...${NC}" - - cd pymnn/pip_package - - # 构建依赖 - if [ -n "$DEPS_OPTIONS" ]; then - $PYTHON_CMD build_deps.py $DEPS_OPTIONS - else - $PYTHON_CMD build_deps.py - fi - - # 构建 wheel - $PYTHON_CMD -m pip install -U numpy wheel setuptools --quiet - rm -rf build dist - - BUILD_ARGS="" - [ -n "$VERSION" ] && BUILD_ARGS="--version $VERSION" - [ -n "$DEPS_OPTIONS" ] && BUILD_ARGS="$BUILD_ARGS --deps $DEPS_OPTIONS" - - $PYTHON_CMD setup.py bdist_wheel $BUILD_ARGS - - mkdir -p "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR" - cp dist/*.whl "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR/" - - cd "$PROJECT_ROOT" - fi - - echo -e "${GREEN}Python 构建完成!${NC}" - echo "输出位置: $PYTHON_OUTPUT_DIR" -fi - -# ============================================================================ -# 总结 -# ============================================================================ -echo "" -echo -e "${GREEN}========================================${NC}" -echo -e "${GREEN}所有构建完成!${NC}" -echo -e "${GREEN}========================================${NC}" -echo "" -echo "输出目录: $PROJECT_ROOT/$OUTPUT_DIR" -echo "" -[ "$BUILD_ANDROID" = true ] && echo -e "${GREEN}Android:${NC} $OUTPUT_DIR/android" -[ "$BUILD_IOS" = true ] && echo -e "${GREEN}iOS (真机):${NC} $OUTPUT_DIR/ios/device" -[ "$BUILD_IOS_SIMULATOR" = true ] && echo -e "${GREEN}iOS (模拟器):${NC} 
$OUTPUT_DIR/ios/simulator" -[ "$BUILD_HARMONY" = true ] && echo -e "${GREEN}鸿蒙:${NC} $OUTPUT_DIR/harmony" -[ "$BUILD_PYTHON" = true ] && echo -e "${GREEN}Python:${NC} $OUTPUT_DIR/python" -echo "" - - diff --git a/docs/compile/cmake.md b/docs/compile/cmake.md index 91a9c03959..6513e38fad 100644 --- a/docs/compile/cmake.md +++ b/docs/compile/cmake.md @@ -101,5 +101,4 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下: | MNN_BUILD_LLM_OMNI | 若构建基于MNN的llm库和demo,是否支持图像和音频输入功能,默认为`OFF` 。仅在MNN_BUILD_LLM 打开时生效。开启时 MNN_BUILD_OPENCV , MNN_IMGCODECS , MNN_BUILD_AUDIO 同时打开| | MNN_BUILD_DIFFUSION | 是否构建基于MNN的diffusion demo,默认为`OFF` . 打开时MNN_BUILD_OPENCV , MNN_IMGCODECS, MNN_LOW_MEMORY, MNN_SUPPORT_TRANSFORMER_FUSE 同步开启| | MNN_KLEIDIAI | 是否集成ARM的klediAI加速库,默认为`ON` | -| MNN_KLEIDIAI_DEFAULT_ON | 是否默认使用KLEIDIAI的Kernel, 默认为`OFF` | | MNN_USE_RVV | 是否启用RISC-V向量扩展支持,默认为`OFF` | diff --git a/docs/transformers/diffusion.md b/docs/transformers/diffusion.md index 609793f806..7de27bb216 100644 --- a/docs/transformers/diffusion.md +++ b/docs/transformers/diffusion.md @@ -20,8 +20,7 @@ https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1/tree/mai cd mnn_path/transformers/diffusion/export python onnx_export.py \ --model_path hf_sd_load_path \ - --output_path onnx_save_path \ - --opset 18 + --output_path onnx_save_path ``` 注意,上述脚本需要依赖torch/onnx/diffusers等库,可以安装conda环境: ``` diff --git a/express/Expr.cpp b/express/Expr.cpp index d9244e8de4..a735adbe3e 100644 --- a/express/Expr.cpp +++ b/express/Expr.cpp @@ -813,28 +813,12 @@ void* Variable::readInternal(bool forShape) { // The Varp will not be created as input, so we just need copy once return inside->mHostTensor->host(); } - inside->mHostTensor = new Tensor; TensorUtils::copyShape(originTensor, inside->mHostTensor, true); inside->mHostTensor->buffer().type = originTensor->getType(); inside->mHostTensor->buffer().host = (uint8_t*)MNNMemoryAllocAlign(inside->mHostTensor->size(), MNN_MEMORY_ALIGN_DEFAULT); 
TensorUtils::getDescribe(inside->mHostTensor)->memoryType = Tensor::InsideDescribe::MEMORY_HOST; originTensor->copyToHostTensor(inside->mHostTensor); - bool hasNoExecution = false; - if (nullptr != originTensor) { - auto backend = TensorUtils::getDescribeOrigin(originTensor)->getBackend(); - if (nullptr != backend) { - // Try to sync to check execution status - int syncResult = backend->onSync(Tensor::MAP_TENSOR_READ, false, originTensor); - if (NO_EXECUTION == syncResult) { - hasNoExecution = true; - } - } - } - if (hasNoExecution) { - MNN_PRINT("\nWarning, Backend has stop execute, return nullptr for current varp\n"); - return nullptr; - } return inside->mHostTensor->host(); } return originTensor->buffer().host; diff --git a/express/Utils.cpp b/express/Utils.cpp index 6aac549ece..f71f2c997b 100644 --- a/express/Utils.cpp +++ b/express/Utils.cpp @@ -181,23 +181,6 @@ void* Executor::ComputeCache::mapOutput(int offset, Tensor* dest) { if (0 == tensor->usize()) { return nullptr; } - - bool hasNoExecution = false; - if (nullptr != tensor) { - auto backend = TensorUtils::getDescribeOrigin(tensor)->getBackend(); - if (nullptr != backend) { - // Try to sync to check execution status - int syncResult = backend->onSync(Tensor::MAP_TENSOR_READ, false, tensor); - if (NO_EXECUTION == syncResult) { - hasNoExecution = true; - } - } - } - if (hasNoExecution) { - MNN_PRINT("\nWarning, Backend has stop execute, return nullptr for current varp\n"); - return nullptr; - } - Utils::allocMemoryForHostTensor(dest); if(nullptr != dest->host()) { tensor->copyToHostTensor(dest); diff --git a/express/module/Module.cpp b/express/module/Module.cpp index 3c54ca89c3..f8b4728153 100644 --- a/express/module/Module.cpp +++ b/express/module/Module.cpp @@ -9,7 +9,6 @@ #include #include #include -#include #include "core/OpCommonUtils.hpp" #include "PipelineModule.hpp" #include "core/FileLoader.hpp" @@ -18,7 +17,6 @@ #include "Utils.hpp" #include "RuntimeAttr.hpp" #include "ModuleInside.hpp" -#include 
"core/TensorUtils.hpp" #include #ifdef MNN_INTERNAL_ENABLED #include "internal/auth/ModelAuth.hpp" @@ -223,30 +221,6 @@ class NetModule : public Module { Executor::RuntimeExecuteWrap wrap(mInfo->runTimeManager->getInside()->mRuntime); outputs = mModule->onForward(inputs); } - - // Check execution status after forward - if (!outputs.empty()) { - bool hasNoExecution = false; - for (auto& v : outputs) { - auto t = Utils::getTensor(v); - if (nullptr != t) { - auto backend = TensorUtils::getDescribeOrigin(t)->getBackend(); - if (nullptr != backend) { - // Try to sync to check execution status - int syncResult = backend->onSync(Tensor::MAP_TENSOR_READ, false, t); - if (NO_EXECUTION == syncResult) { - hasNoExecution = true; - break; - } - } - } - } - if (hasNoExecution) { - MNN_PRINT("Warning, Backend has stop execute, return empty output vector varps\n"); - outputs.clear(); - } - } - #ifdef MNN_INTERNAL_ENABLED do { if (outputs.empty()) { diff --git a/source/backend/cpu/CPUBackend.cpp b/source/backend/cpu/CPUBackend.cpp index 8d284aa33b..0e0bc1f136 100644 --- a/source/backend/cpu/CPUBackend.cpp +++ b/source/backend/cpu/CPUBackend.cpp @@ -104,14 +104,15 @@ void CPURuntime::_bindCPUCore() const { #ifdef MNN_USE_THREAD_POOL if (nullptr != mThreadPool) { mThreadPool->active(); - ThreadPool::TASK task = std::make_pair([&](int i) { + mThreadPool->enqueue(std::make_pair([&](int i) { MNNSetSchedAffinity(lockCPUIndexes[i].first, lockCPUIndexes[i].second); - }, mThreadNumber); - mThreadPool->enqueue(&task, mTaskIndex); + return 0; + }, mThreadNumber), mTaskIndex); mThreadPool->deactive(); } #endif } + void CPURuntime::_resetThreadPool() const { mThreadNumber = std::max(1, mThreadNumber); mThreadNumber = std::min(mThreadNumber, MAX_THREAD_NUMBER); @@ -490,7 +491,6 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p currentRate *= decreaseRate; totalComputeRate += currentRate * selectSize; mGroupWithComputeRate.emplace_back(std::make_pair(currentRate 
* selectSize, selectSize)); - groupIndex--; } for (auto& g : mGroupWithComputeRate) { g.first = g.first / totalComputeRate; diff --git a/source/backend/cpu/CPUBackend.hpp b/source/backend/cpu/CPUBackend.hpp index ec4c555dec..884036eb38 100644 --- a/source/backend/cpu/CPUBackend.hpp +++ b/source/backend/cpu/CPUBackend.hpp @@ -176,9 +176,6 @@ class CPUBackend : public Backend { #ifdef MNN_USE_THREAD_POOL inline int taskIndex() const {return mRuntime->mTaskIndex;} inline ThreadPool* threadPool() const {return mRuntime->mThreadPool;} - void enqueue(ThreadPool::TASK& task) const { - threadPool()->enqueue(&task, taskIndex()); - } #endif static void initCreatorMap(); static size_t getBytes(const Backend* backend, const Tensor* output); diff --git a/source/backend/cpu/CPUBinary.cpp b/source/backend/cpu/CPUBinary.cpp index 61ccf4fca3..059e502d0b 100644 --- a/source/backend/cpu/CPUBinary.cpp +++ b/source/backend/cpu/CPUBinary.cpp @@ -45,37 +45,6 @@ ErrorCode CPUBinary::onResize(const std::vector& inputs, const std::vec mThreadNum = threads; mWorkDiv = UP_DIV(mTotalSize, threads); } - int inpBytes = inputs[0]->getType().bytes(); - int outBytes = outputs[0]->getType().bytes(); - if (halide_type_float == inputs[0]->getType().code) { - inpBytes = static_cast(backend())->functions()->bytes; - } - if (halide_type_float == outputs[0]->getType().code) { - outBytes = static_cast(backend())->functions()->bytes; - } - bool outputInt = outputs[0]->getType().code == halide_type_int; - mTask = std::make_pair([this, inpBytes, outBytes, outputInt](int tId) { - int start = tId * mWorkDiv; - int realSize = ALIMIN(mWorkDiv, mTotalSize - start); - if (realSize > 0) { - auto inp0 = mInput0Ptr + start * inpBytes; - auto inp1 = mInput1Ptr + start * inpBytes; - if (mNeedBroadcastIndex == 0) { - inp0 = mInput0Ptr; - } else if (mNeedBroadcastIndex == 1) { - inp1 = mInput1Ptr; - } - auto out = mOutputPtr + start * outBytes; - mProc(out, inp0, inp1, realSize, mNeedBroadcastIndex); - if(mActivationType 
== 1 && outputInt) { - for(int i=0; i 0 ? val : 0; - ((int32_t *)out)[i] = res; - } - } - } - } , mThreadNum); return NO_ERROR; } @@ -98,10 +67,31 @@ ErrorCode CPUBinary::onExecute(const std::vector& inputs, const std::ve outBytes = static_cast(backend())->functions()->bytes; } auto precision = static_cast(backend())->precisionMode(); - mInput0Ptr = input0Ptr; - mInput1Ptr = input1Ptr; - mOutputPtr = outputPtr; - MNN_CONCURRENCY_ENQUEUE(mTask); + + MNN_CONCURRENCY_BEGIN(tId, mThreadNum) { + int start = tId * mWorkDiv; + int realSize = ALIMIN(mWorkDiv, mTotalSize - start); + if (realSize > 0) { + auto inp0 = input0Ptr + start * inpBytes; + auto inp1 = input1Ptr + start * inpBytes; + if (mNeedBroadcastIndex == 0) { + inp0 = input0Ptr; + } else if (mNeedBroadcastIndex == 1) { + inp1 = input1Ptr; + } + auto out = outputPtr + start * outBytes; + mProc(out, inp0, inp1, realSize, mNeedBroadcastIndex); + if(mActivationType == 1 && output->getType().code == halide_type_int) { + for(int i=0; i 0 ? 
val : 0; + ((int32_t *)out)[i] = res; + } + } + } + } + MNN_CONCURRENCY_END(); + if(mActivationType == 1 && output->getType().code == halide_type_float) { mActivationExe->onExecute(outputs, outputs); } diff --git a/source/backend/cpu/CPUBinary.hpp b/source/backend/cpu/CPUBinary.hpp index 17cb3b5f47..9250df79ae 100644 --- a/source/backend/cpu/CPUBinary.hpp +++ b/source/backend/cpu/CPUBinary.hpp @@ -33,10 +33,6 @@ class CPUBinary : public Execution { int mThreadNum; int mWorkDiv; std::shared_ptr mActivationExe; - std::pair, int> mTask; - uint8_t* mInput0Ptr = nullptr; - uint8_t* mOutputPtr = nullptr; - uint8_t* mInput1Ptr = nullptr; }; } // namespace MNN #endif /* CPUBinary_hpp */ diff --git a/source/backend/cpu/CPUMatMul.cpp b/source/backend/cpu/CPUMatMul.cpp index 22b96a64ee..4f0765f050 100644 --- a/source/backend/cpu/CPUMatMul.cpp +++ b/source/backend/cpu/CPUMatMul.cpp @@ -37,8 +37,9 @@ void CPUMatMul::_scheduleForVecE(int e, int l, int h) { param.BTranspose = mTransposeB; param.numberThread = numberThread; auto func = static_cast(backend())->functions()->MNNComputeMatMulForE_1; - mPreFunctions.emplace_back(std::make_pair([param, func, this](int tId) { - func(mA, mB, mC, mBiasPtr, ¶m, tId); + mPreFunctions.emplace_back(std::make_pair([param, func]( + int tId, const float* A, const float* B, const float* biasPtr, float* C) { + func(A, B, C, biasPtr, ¶m, tId); }, numberThread)); } @@ -53,9 +54,9 @@ void CPUMatMul::_scheduleForVec(int e, int l, int h) { auto func = static_cast(backend())->functions()->MNNComputeMatMulForH_1; // TODD: Support e = 1 MNN_ASSERT(h == 1); - mPreFunctions.emplace_back(std::make_pair([param, func, this]( - int tId) { - func(mA, mB, mC, mBiasPtr, ¶m, tId); + mPreFunctions.emplace_back(std::make_pair([param, func]( + int tId, const float* A, const float* B, const float* biasPtr, float* C) { + func(A, B, C, biasPtr, ¶m, tId); }, numberThread)); } @@ -99,8 +100,8 @@ ErrorCode CPUMatMul::onResize(const std::vector& inputs, const std::vec return 
OUT_OF_MEMORY; } - mPreFunctions.emplace_back(std::make_pair([BTPtrAlloc, l, h, this, core] (int tId) { - core->MNNPackForMatMul_B((float*)BTPtrAlloc.ptr(), mB, h, 1, l, mTransposeB); + mPreFunctions.emplace_back(std::make_pair([BTPtrAlloc, l, h, this, core] (int tId, const float* APtr, const float* BPtr, const float* Bias, float* C) { + core->MNNPackForMatMul_B((float*)BTPtrAlloc.ptr(), BPtr, h, 1, l, mTransposeB); } , 1)); bool useBias = false; MemChunk bdestAlloc; @@ -119,9 +120,9 @@ ErrorCode CPUMatMul::onResize(const std::vector& inputs, const std::vec } mTempBias = bdestAlloc; mPreFunctions.emplace_back(std::make_pair( - [biasLength, bdestAlloc, core, this](int tId) { + [biasLength, bdestAlloc, core](int tId, const float* APtr, const float* BPtr, const float* borigin, float* C) { ::memset(bdestAlloc.ptr(), 0, UP_DIV(biasLength, core->pack) * core->bytes * core->pack); - ::memcpy(bdestAlloc.ptr(), mBiasPtr, biasLength * core->bytes); + ::memcpy(bdestAlloc.ptr(), borigin, biasLength * core->bytes); }, 1)); } else { mUseBiasDirectly = true; @@ -166,12 +167,11 @@ ErrorCode CPUMatMul::onExecute(const std::vector& inputs, const std::ve } void CPUMatMul::execute(const float* APtr, const float* BPtr, float* CPtr, const float* biasPtr) { - mA = APtr; - mB = BPtr; - mC = CPtr; - mBiasPtr = biasPtr; for (auto& f : mPreFunctions) { - MNN_CONCURRENCY_ENQUEUE(f); + MNN_CONCURRENCY_BEGIN(tId, f.second) { + f.first(tId, APtr, BPtr, biasPtr, CPtr); + } + MNN_CONCURRENCY_END(); } if (mE > 0) { auto core = static_cast(backend())->functions(); diff --git a/source/backend/cpu/CPUMatMul.hpp b/source/backend/cpu/CPUMatMul.hpp index 48226795f0..872a77a9a8 100644 --- a/source/backend/cpu/CPUMatMul.hpp +++ b/source/backend/cpu/CPUMatMul.hpp @@ -29,7 +29,7 @@ class CPUMatMul : public Execution { bool mTransposeB; bool mTransposeC; bool mSupportMultiThread = false; - std::vector, int>> mPreFunctions; + std::vector, int>> mPreFunctions; bool mUseBiasDirectly = false; MemChunk mTempA; 
MemChunk mTempB; @@ -40,11 +40,6 @@ class CPUMatMul : public Execution { int mL; int mH; std::vector mPostParameters; - // For Execute Paramters - const float* mA = nullptr; - const float* mB = nullptr; - const float* mBiasPtr = nullptr; - float* mC = nullptr; }; } // namespace MNN diff --git a/source/backend/cpu/CPURNNSequenceGRU.cpp b/source/backend/cpu/CPURNNSequenceGRU.cpp index 0bda660e9c..daae8811c7 100644 --- a/source/backend/cpu/CPURNNSequenceGRU.cpp +++ b/source/backend/cpu/CPURNNSequenceGRU.cpp @@ -10,26 +10,30 @@ #include #include "backend/cpu/CPUBackend.hpp" #include "backend/cpu/compute/ConvOpt.h" +#include "backend/cpu/compute/CommonOptFunction.h" #include "core/TensorUtils.hpp" namespace MNN { // implement GRU cell function // Ref: tensorflow/python/ops/rnn_cell_impl.py -void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, const bool linearBeforeReset, uint8_t* hiddenStateInput, const int numUnits, Tensor* gateWeight, Tensor* gateBias, +void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, const bool linearBeforeReset, + std::shared_ptr& hiddenState, const int numUnits, Tensor* gateWeight, Tensor* gateBias, Tensor* candidateWeight, Tensor* candidateBias, Tensor* recurrentBias, std::shared_ptr& inputAndState, std::shared_ptr& gate, - std::shared_ptr& resetHt, uint8_t* hiddenStateOutput) { + std::shared_ptr& resetHt) { + auto bn = static_cast(backend()); + auto mulFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_MUL); + auto addFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_ADD); + auto subFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_SUB); + auto tanhFunction = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_TANH, bn->precisionMode()); + auto bytes = bn->functions()->bytes; + auto sigmoidFunc = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_SIGMOID, bn->precisionMode()); // 
gate is (z_t, r_t) - auto bytes = mRNNFunctions.bytes; - MNNBinaryExecute mulFunction = mRNNFunctions.mulFunction; - MNNBinaryExecute addFunction = mRNNFunctions.addFunction; - MNNBinaryExecute subFunction = mRNNFunctions.subFunction; - MNNUnaryExecute tanhFunction = mRNNFunctions.tanhFunction; - MNNUnaryExecute sigmoidFunction = mRNNFunctions.sigmoidFunction; auto inputAndStatePtr = inputAndState->host(); + auto hiddenStatePtr = hiddenState->host(); ::memcpy(inputAndStatePtr, input, inputLength * bytes); - ::memcpy(inputAndStatePtr + inputLength * bytes, hiddenStateInput, numUnits * bytes); + ::memcpy(inputAndStatePtr + inputLength * bytes, hiddenStatePtr, numUnits * bytes); inputAndState->setLength(1, inputLength + numUnits); // // [x_t, h_t-1] * [W_zr, R_zr]: (1, inputLength + numUnits) X (inputLength + numUnits, 2 * numUnits) @@ -38,8 +42,9 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, recurrentBias->setLength(1, 2 * numUnits); addFunction(gate->host(), gate->host(), recurrentBias->host(), 2*numUnits, -1); // (1, 2*numUnits) + const int gateSize = gate->elementSize(); auto gatePtr = gate->host(); - sigmoidFunction(gatePtr, gatePtr, 2 * numUnits); + sigmoidFunc(gatePtr, gatePtr, gateSize); // reset gate, // r_t is the second segment auto rtPtr = gatePtr + numUnits * bytes; @@ -47,7 +52,7 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, // calculate Rt (.) 
(Ht_1 * Rh + Rbh) auto recurrentHiddenBiasPtr = recurrentBias->host() + 2 * numUnits * bytes; auto rhWeightPtr = candidateWeight->host() + inputLength * numUnits * bytes; - mMatMulU2U->execute((float*)hiddenStateInput, (float*)rhWeightPtr, resetHt->host(), (float*)recurrentHiddenBiasPtr); + mMatMulU2U->execute(hiddenState->host(), (float*)rhWeightPtr, resetHt->host(), (float*)recurrentHiddenBiasPtr); mulFunction(resetHt->host(), rtPtr, resetHt->host(), numUnits, -1); // calculate Xt * Wh @@ -60,7 +65,7 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, // r_t: (1, numUnits) auto resetGatePtr = inputAndStatePtr + inputLength * bytes; // h_t1(1, numUnits) = r_t(1, numUnits) * h_t-1_(1, numUnits) - mulFunction(resetGatePtr, rtPtr, hiddenStateInput, numUnits, -1); + mulFunction(resetGatePtr, rtPtr, hiddenStatePtr, numUnits, -1); // deal with recurrent bias and linear_before_reset parameter auto recurrentBiasAddedPtr = inputAndStatePtr + (inputLength + numUnits) * bytes; auto recurrentHiddenBiasPtr = (float*)(recurrentBias->host() + 2 * numUnits * bytes); @@ -71,9 +76,9 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, } // h = (1-g)*t+g*h = t + g*(h-t) tanhFunction(resetHt->host(), rtPtr, numUnits); - subFunction(hiddenStateOutput, hiddenStateInput, resetHt->host(), numUnits, -1); - mulFunction(hiddenStateOutput, hiddenStateOutput, gatePtr, numUnits, -1); - addFunction(hiddenStateOutput, hiddenStateOutput, resetHt->host(), numUnits, -1); + subFunction(hiddenStatePtr, hiddenStatePtr, resetHt->host(), numUnits, -1); + mulFunction(hiddenStatePtr, hiddenStatePtr, gatePtr, numUnits, -1); + addFunction(hiddenStatePtr, hiddenStatePtr, resetHt->host(), numUnits, -1); inputAndState->setLength(1, inputLength + 2 * numUnits); } @@ -138,13 +143,6 @@ ErrorCode CPURNNSequenceGRU::onResize(const std::vector& inputs, const backend()->onReleaseBuffer(mInputAndState.get(), Backend::DYNAMIC); 
backend()->onReleaseBuffer(mGate.get(), Backend::DYNAMIC); backend()->onReleaseBuffer(mResetHt.get(), Backend::DYNAMIC); - auto bn = static_cast(backend()); - mRNNFunctions.mulFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_MUL); - mRNNFunctions.addFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_ADD); - mRNNFunctions.subFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_SUB); - mRNNFunctions.tanhFunction = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_TANH, bn->precisionMode()); - mRNNFunctions.bytes = bn->functions()->bytes; - mRNNFunctions.sigmoidFunction = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_SIGMOID, bn->precisionMode()); return NO_ERROR; } @@ -185,29 +183,27 @@ ErrorCode CPURNNSequenceGRU::onExecute(const std::vector& inputs, const const int inputCodeLength = input->length(2); // MNN_PRINT("inputSequenceLength:%d, batchSize:%d, inputCodeLength:%d, mNumUnits:%d, hiddenStateDataSize:%d\n", inputSequenceLength, batchSize, inputCodeLength, mNumUnits, hiddenStateDataSize); for (int b = 0; b < batchSize; ++b) { // swap order - auto hiddenStateInput = hiddenStatePtr; - auto hiddenStateOutput = hiddenStatePtr; if (inputSize > 1 + forwardParamNumber * (mIsBidirectionalRNN + 1)) { auto source = inputs[inputSize - 1]->host() + b * hiddenStateDataSize; - hiddenStateInput = source; + ::memcpy(hiddenStatePtr, source, hiddenStateDataSize); } else { ::memset(hiddenStatePtr, 0, hiddenStateDataSize); } for (int i = 0; i < inputSequenceLength; ++i) { const int inputOffset = i * SequenceStride + b * inputCodeLength; + runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, mHiddenState, mNumUnits, fwGateWeight, fwGateBias, + fwCandidateWeight, fwCandidateBias, fwRecurrentBias, mInputAndState, mGate, mResetHt); + if (mKeepAllOutputs) { - hiddenStateOutput = outputPtr + (i * output->stride(0) + b * mNumUnits) * bytes; + 
::memcpy(outputPtr + (i * output->stride(0) + b * mNumUnits) * bytes, hiddenStatePtr, hiddenStateDataSize); } - runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, hiddenStateInput, mNumUnits, fwGateWeight, fwGateBias, - fwCandidateWeight, fwCandidateBias, fwRecurrentBias, mInputAndState, mGate, mResetHt, hiddenStateOutput); - - hiddenStateInput = hiddenStateOutput; } if ((mKeepAllOutputs && outputSize > 1) || !mKeepAllOutputs) { - ::memcpy(outputYhPtr, hiddenStateOutput, hiddenStateDataSize); + ::memcpy(outputYhPtr, hiddenStatePtr, hiddenStateDataSize); outputYhPtr += mNumUnits * bytes; } + } // backward rnn @@ -225,24 +221,22 @@ ErrorCode CPURNNSequenceGRU::onExecute(const std::vector& inputs, const auto outputBw = outputs[0]; auto const outputBwPtr = outputBw->host(); for (int b = 0; b < batchSize; ++b) { - auto hiddenStateInput = hiddenStatePtr; - auto hiddenStateOutput = hiddenStatePtr; if (inputSize > 1 + forwardParamNumber * 2) { auto source = inputs[inputSize - 1]->host() + (batchSize + b) * hiddenStateDataSize; - hiddenStateInput = source; + ::memcpy(hiddenStatePtr, source, hiddenStateDataSize); } else { ::memset(hiddenStatePtr, 0, hiddenStateDataSize); } for (int i = inputSequenceLength - 1; i >= 0; i--) { const int inputOffset = i * SequenceStride + b * inputCodeLength; + runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, mHiddenState, mNumUnits, bwGateWeight, bwGateBias, + bwCandidateWeight, bwCandidateBias, bwRecurrentBias, mInputAndState, mGate, mResetHt); if (mKeepAllOutputs) { - hiddenStateOutput = outputBwPtr + (i * outputBw->stride(0) + (batchSize + b) * mNumUnits) * bytes; + ::memcpy(outputBwPtr + (i * outputBw->stride(0) + (batchSize + b) * mNumUnits) * bytes, + hiddenStatePtr, hiddenStateDataSize); } - runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, hiddenStateInput, mNumUnits, bwGateWeight, bwGateBias, - bwCandidateWeight, bwCandidateBias, bwRecurrentBias, 
mInputAndState, mGate, mResetHt, hiddenStateOutput); - hiddenStateInput = hiddenStateOutput; } if ((mKeepAllOutputs && outputSize > 1) || !mKeepAllOutputs) { ::memcpy(outputYhPtr, hiddenStatePtr, hiddenStateDataSize); diff --git a/source/backend/cpu/CPURNNSequenceGRU.hpp b/source/backend/cpu/CPURNNSequenceGRU.hpp index 0125b9e8a1..0987d13053 100644 --- a/source/backend/cpu/CPURNNSequenceGRU.hpp +++ b/source/backend/cpu/CPURNNSequenceGRU.hpp @@ -11,7 +11,6 @@ #include "core/Execution.hpp" #include "CPUMatMul.hpp" -#include "backend/cpu/compute/CommonOptFunction.h" namespace MNN { class CPURNNSequenceGRU : public Execution { @@ -20,20 +19,13 @@ class CPURNNSequenceGRU : public Execution { virtual ~CPURNNSequenceGRU(); virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; - struct RNNFuntions { - MNNBinaryExecute mulFunction; - MNNBinaryExecute addFunction; - MNNBinaryExecute subFunction; - MNNUnaryExecute tanhFunction; - MNNUnaryExecute sigmoidFunction; - int bytes; - }; + private: void runRNNStep(const uint8_t* input, const int inputLength, const bool linearBeforeReset, - uint8_t* hiddenStateInput, const int numUnits, Tensor* gateWeight, Tensor* gateBias, + std::shared_ptr& hiddenState, const int numUnits, Tensor* gateWeight, Tensor* gateBias, Tensor* candidateWeight, Tensor* candidateBias, Tensor* recurrentBias, std::shared_ptr& inputAndState, std::shared_ptr& gate, - std::shared_ptr& resetHt, uint8_t* hiddenStateOutput); + std::shared_ptr& resetHt); bool mKeepAllOutputs; bool mIsBidirectionalRNN; bool mlinearBeforeReset; @@ -50,7 +42,6 @@ class CPURNNSequenceGRU : public Execution { std::shared_ptr mMatMulU2U; // For inputLength -> numUnit std::shared_ptr mMatMulI2U; - RNNFuntions mRNNFunctions; }; } // namespace MNN diff --git a/source/backend/cpu/CPURaster.cpp b/source/backend/cpu/CPURaster.cpp index 1339089347..3272086531 100644 --- 
a/source/backend/cpu/CPURaster.cpp +++ b/source/backend/cpu/CPURaster.cpp @@ -49,6 +49,227 @@ struct ReduceInfo { } }; +ErrorCode CPURaster::onResize(const std::vector &____inputs, const std::vector &outputs) { + MNN_ASSERT(outputs.size() == 1); + auto output = outputs[0]; + OpCommonUtils::rasterInputReset(____inputs, outputs[0]); + auto des = TensorUtils::getDescribe(output); + auto outputDes = TensorUtils::getDescribe(output); + mNeedZero = !TensorUtils::regionIsFull(output); + mZeroPoint = 0; + mUseThreads = false; + if (outputDes->quantAttr != nullptr && outputDes->applyQuant) { +#ifdef MNN_USE_SSE + mZeroPoint = (int)outputDes->quantAttr->zero + 128; +#else + mZeroPoint = (int)outputDes->quantAttr->zero; +#endif + } + mTempInput.clear(); + mFastBlit.clear(); + mCacheRegions.clear(); + mTempOutput = nullptr; + auto midFormat = MNN_DATA_FORMAT_NCHW; + mTempInputCopy.clear(); + mFast = false; + auto core = static_cast(backend())->functions(); + mSingleConvert.type = 0; + // all_srcFormat == dstFormat == NC4HW4 : Fast Exe + if (outputDes->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) { + mFast = true; + for (int i=0; i< des->regions.size(); ++i) { + auto& slice = des->regions[i]; + if (TensorUtils::getDescribe(slice.origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { + mFast = false; + break; + } + if (!OpCommonUtils::canBlitFast(slice, output, core->pack, true)) { + mFast = false; + break; + } + } + if (mFast) { + mUseThreads = des->regions.size() > 16 ? 
true : false; + for (int i=0; i< des->regions.size(); ++i) { + auto& slice = des->regions[i]; + if (slice.origin == nullptr) { + continue; + } + Tensor::InsideDescribe::Region newRegion; + OpCommonUtils::turnToPackRegion(slice, newRegion, output, core->pack, true); + mFastBlit.emplace_back(std::make_pair(slice.origin, std::move(newRegion))); + } + return NO_ERROR; + } + } + // srcNum == 1 && srcFormat != dstFormat : Single Convert + if (des->regions.size() == 1) { + OpCommonUtils::turnRegion2Convert(des->regions[0], output, mSingleConvert); + if (mSingleConvert.type > 0) { + mUseThreads = (mSingleConvert.batch * mSingleConvert.channel * mSingleConvert.area > LAUNCH_MULTI_THREADS_WORKLOAD) ? true : false; + return NO_ERROR; + } + } + // Acquire Buffer for temp output + // TODO: optimize it + if (MNN_DATA_FORMAT_NC4HW4 == outputDes->dimensionFormat) { + mTempOutput.reset(new Tensor); + TensorUtils::setupTensorInfo(output, mTempOutput.get(), midFormat); + } + if (nullptr != mTempOutput) { + auto res = backend()->onAcquireBuffer(mTempOutput.get(), Backend::DYNAMIC); + if (!res) { + return OUT_OF_MEMORY; + } + } + // input is NC4HW4 add Convert + std::vector forRelease; + TensorUtils::FuseWrap fuseUtils; + for (int i=0; i< des->regions.size(); ++i) { + auto& slice = des->regions[i]; + auto origin = slice.origin; + if (nullptr == origin /*|| nullptr == origin->host()*/) { + continue; + } + // if tensor is not NC4HW4 or has been merged, don't need deal + if (TensorUtils::getDescribe(origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { + if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + mTempInputCopy.emplace_back(std::make_pair(origin, &slice)); + continue; + } + // if NC4HW4's C%4 == 0, change convert to transpose and fuse it + if (origin->batch() == 1 && origin->channel() % core->pack == 0) { + int channel = origin->channel(); + int area = 1; + // conv3d/pool3d will has 5 dims, area = depth * width * 
height, otherwise area = width * height + for (int d = 2; d < origin->dimensions(); d++) { + area *= origin->length(d); + } + Tensor::InsideDescribe::Region regionTmp; + regionTmp.src.offset = 0; + regionTmp.src.stride[0] = area * core->pack; + regionTmp.src.stride[1] = 1; + regionTmp.src.stride[2] = core->pack; + regionTmp.dst.offset = 0; + regionTmp.dst.stride[0] = area * core->pack; + regionTmp.dst.stride[1] = area; + regionTmp.dst.stride[2] = 1; + regionTmp.size[0] = channel / core->pack; + regionTmp.size[1] = core->pack; + regionTmp.size[2] = area; + regionTmp.origin = slice.origin; + bool merge = fuseUtils.match(regionTmp, slice); + if (merge) { + std::shared_ptr newSlice(new Tensor::InsideDescribe::Region); + *newSlice = slice; + fuseUtils.apply(regionTmp, *newSlice); + // cache the merged tensor + if (newSlice->size[0] * newSlice->size[1] * newSlice->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + mTempInputCopy.emplace_back(std::make_pair(origin, newSlice.get())); + mCacheRegions.emplace_back(newSlice); + continue; + } + } + auto cache = static_cast(backend())->getCache(); + auto tempTensor = cache->findCacheTensor(origin, midFormat); + //MNN_ASSERT(CPUBackend::getBytes(backend(), origin) == 4); + if (nullptr == tempTensor) { + std::shared_ptr newTensor(new Tensor); + TensorUtils::copyShape(origin, newTensor.get()); + TensorUtils::getDescribe(newTensor.get())->dimensionFormat = midFormat; + TensorUtils::getDescribe(newTensor.get())->quantAttr = TensorUtils::getDescribe(origin)->quantAttr; + TensorUtils::getDescribe(newTensor.get())->applyQuant = TensorUtils::getDescribe(origin)->applyQuant;; + newTensor->buffer().type = origin->getType(); + TensorUtils::setLinearLayout(newTensor.get()); + mTempInput.insert(std::make_pair(origin, newTensor.get())); + auto res = backend()->onAcquireBuffer(newTensor.get(), Backend::DYNAMIC); + if (!res) { + return OUT_OF_MEMORY; + } + tempTensor = newTensor.get(); + 
TensorUtils::getDescribe(tempTensor)->useCount = TensorUtils::getDescribe(origin)->useCount; + cache->pushCacheTensor(newTensor, origin, midFormat); + } + if (--TensorUtils::getDescribe(tempTensor)->useCount == 0) { + forRelease.emplace_back(tempTensor); + } + if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + mTempInputCopy.emplace_back(std::make_pair(tempTensor, &slice)); + } + for (auto t : forRelease) { + backend()->onReleaseBuffer(t, Backend::DYNAMIC); + } + if (nullptr != mTempOutput) { + backend()->onReleaseBuffer(mTempOutput.get(), Backend::DYNAMIC); + } + auto threadNumber = static_cast(backend())->threadNumber(); + mHasReduce = false; + ReduceInfo reduceInfo; + for (auto& iter : mTempInputCopy) { + if (reduceInfo.compute(*iter.second)) { + mHasReduce = true; + break; + } + } + if (mTempInputCopy.size() == 1 && threadNumber > 1 && (!mHasReduce)) { + // Split to multi region + auto region = mTempInputCopy[0].second; + if (region->size[0] * region->size[1] * region->size[2] < LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = false; + return NO_ERROR; + } + if (region->size[0] * region->size[1] * region->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + auto tensorPtr = mTempInputCopy[0].first; + int pos = -1; + for (int i=0; i<3; ++i) { + if (region->size[i] > 1) { + pos = i; + break; + } + } + if (-1 == pos) { + // Don't need divide + return NO_ERROR; + } + mTempInputCopy.clear(); + int divSize = UP_DIV(region->size[pos], threadNumber); + for (int i=0; i cacheRegPtr(new Tensor::InsideDescribe::Region); + auto& cacheReg = *cacheRegPtr; + int sta = i * divSize; + int fin = sta + divSize; + fin = std::min(fin, region->size[pos]); + if (fin <= sta) { + break; + } + for (int v=0; v<3; ++v) { + cacheReg.src.stride[v] = region->src.stride[v]; + cacheReg.dst.stride[v] = region->dst.stride[v]; + } + int curSize = fin - sta; + for (int v=0; vsize[v]; + } + cacheReg.size[pos] = curSize; 
+ cacheReg.src.offset = region->src.offset + sta * region->src.stride[pos]; + cacheReg.dst.offset = region->dst.offset + sta * region->dst.stride[pos]; + for (int v=pos+1; v<3; ++v) { + cacheReg.size[v] = region->size[v]; + } + mTempInputCopy.emplace_back(std::make_pair(tensorPtr, cacheRegPtr.get())); + mCacheRegions.emplace_back(cacheRegPtr); + } + } + return NO_ERROR; +} static void _transpose(int32_t* dstO, const int32_t* srcO, const Tensor::InsideDescribe::Region& region, int bytes) { int dims[4], keepDim = -1; for (int i = 0; i < 3; i++) { @@ -103,12 +324,15 @@ static void _2BitcopyWithStrideC4(uint8_t* dstO, const uint8_t* srcO, int size, } } -void CPURaster::executeFaster(const std::vector &inputs, const std::vector &outputs) { +void CPURaster::executeFaster(const std::vector &inputs, const std::vector &outputs) const { auto input = inputs[0]; auto output = outputs[0]; auto bytes = CPUBackend::getBytes(backend(), output); auto core = static_cast(backend())->functions(); - int threadNum = static_cast(backend())->threadNumber(); + auto threadNum = static_cast(backend())->threadNumber(); + if (mNeedZero) { + ::memset(output->host(), mZeroPoint, static_cast(backend())->getTensorSize(output) * bytes); + } auto byteC4 = bytes * core->pack; auto C4proc = core->MNN4BitcopyWithStride; switch (byteC4) { @@ -128,7 +352,7 @@ void CPURaster::executeFaster(const std::vector &inputs, const std::ve if (!mUseThreads) { threadNum = 1; } - mTasks.emplace_back(std::make_pair([threadNum, this, output, bytes, C4proc, byteC4](int tId) { + MNN_CONCURRENCY_BEGIN(tId, threadNum) { for (int u=(int)tId; uhost() == nullptr) { @@ -169,7 +393,8 @@ void CPURaster::executeFaster(const std::vector &inputs, const std::ve } } } - }, threadNum)); + } + MNN_CONCURRENCY_END(); } static BlitProc _selectUnitProc(int bytes, int stride, int ds) { @@ -371,307 +596,97 @@ static void _blit(const Tensor::InsideDescribe::Region& slice, int bytes, const } } void CPURaster::tensorConvert(Tensor* input, 
Tensor* output, int bytes) { - std::pair, int> task; + auto& subIb = input->buffer(); + auto& subOb = output->buffer(); + auto source = TensorUtils::getDescribe(input)->dimensionFormat; + auto dest = TensorUtils::getDescribe(output)->dimensionFormat; + if (subIb.dimensions <= 1 || source == dest) { + ::memcpy(subOb.host, subIb.host, input->elementSize() * bytes); + return; + } + auto tup = CPUTensorConverter::splitDimensions(subIb, source); + int area = std::get<1>(tup), batch = std::get<0>(tup), channel = std::get<2>(tup); + const int bitLength = bytes; auto core = static_cast(backend())->functions(); auto threadNumber = static_cast(backend())->threadNumber(); if (!mUseThreads) { threadNumber = 1; } - task.first = [input, output, bytes, threadNumber, core](int tId) { - auto& subIb = input->buffer(); - auto& subOb = output->buffer(); - auto source = TensorUtils::getDescribe(input)->dimensionFormat; - auto dest = TensorUtils::getDescribe(output)->dimensionFormat; - if (subIb.dimensions <= 1 || source == dest) { - ::memcpy(subOb.host, subIb.host, input->elementSize() * bytes); - return; - } - auto tup = CPUTensorConverter::splitDimensions(subIb, source); - int area = std::get<1>(tup), batch = std::get<0>(tup), channel = std::get<2>(tup); - const int bitLength = bytes; + MNN_CONCURRENCY_BEGIN(tId, threadNumber) { CPUTensorConverter::convert(subIb.host, subOb.host, source, dest, batch, area, channel, bitLength, core, tId, threadNumber); }; - task.second = threadNumber; - mTasks.emplace_back(task); + MNN_CONCURRENCY_END(); } -ErrorCode CPURaster::onResize(const std::vector &____inputs, const std::vector &outputs) { - MNN_ASSERT(outputs.size() == 1); - auto output = outputs[0]; - OpCommonUtils::rasterInputReset(____inputs, outputs[0]); - auto des = TensorUtils::getDescribe(output); - auto outputDes = TensorUtils::getDescribe(output); - mNeedZero = !TensorUtils::regionIsFull(output); - mZeroPoint = 0; - mUseThreads = false; - int threadNum = 
static_cast(backend())->threadNumber(); - if (outputDes->quantAttr != nullptr && outputDes->applyQuant) { -#ifdef MNN_USE_SSE - mZeroPoint = (int)outputDes->quantAttr->zero + 128; -#else - mZeroPoint = (int)outputDes->quantAttr->zero; -#endif - } - size_t bytes = (size_t)(CPUBackend::getBytes(backend(), output)); - mTempInput.clear(); - mFastBlit.clear(); - mCacheRegions.clear(); - mTempOutput = nullptr; - mTasks.clear(); - auto midFormat = MNN_DATA_FORMAT_NCHW; - mTempInputCopy.clear(); - mFast = false; - auto core = static_cast(backend())->functions(); - mSingleConvert.type = 0; - // all_srcFormat == dstFormat == NC4HW4 : Fast Exe - if (outputDes->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) { - mFast = true; - for (int i=0; i< des->regions.size(); ++i) { - auto& slice = des->regions[i]; - if (TensorUtils::getDescribe(slice.origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { - mFast = false; - break; - } - if (!OpCommonUtils::canBlitFast(slice, output, core->pack, true)) { - mFast = false; - break; - } - } - if (mFast) { - mUseThreads = des->regions.size() > 16 ? true : false; - for (int i=0; i< des->regions.size(); ++i) { - auto& slice = des->regions[i]; - if (slice.origin == nullptr) { - continue; - } - Tensor::InsideDescribe::Region newRegion; - OpCommonUtils::turnToPackRegion(slice, newRegion, output, core->pack, true); - mFastBlit.emplace_back(std::make_pair(slice.origin, std::move(newRegion))); - } - executeFaster(____inputs, outputs); - return NO_ERROR; - } - } - // srcNum == 1 && srcFormat != dstFormat : Single Convert - if (des->regions.size() == 1) { - OpCommonUtils::turnRegion2Convert(des->regions[0], output, mSingleConvert); - if (mSingleConvert.type > 0) { - std::pair, int> task; - mUseThreads = (mSingleConvert.batch * mSingleConvert.channel * mSingleConvert.area > LAUNCH_MULTI_THREADS_WORKLOAD) ? 
true : false; - auto realInput = ____inputs[0]; - int srcBatch = mSingleConvert.batch, srcChannel = mSingleConvert.channel, srcArea = mSingleConvert.area; - auto sourceFormat = TensorUtils::getDescribe(realInput)->dimensionFormat; - auto destFormat = TensorUtils::getDescribe(output)->dimensionFormat; - auto channelC4 = UP_DIV(srcChannel, core->pack); - auto batchStrideC4 = channelC4 * core->pack * srcArea * bytes; - auto batchStride = srcChannel * srcArea * bytes; - auto inputBatchStride = batchStride; - auto outputBatchStride = batchStride; - if (MNN_DATA_FORMAT_NC4HW4 == sourceFormat) { - if (realInput->dimensions() <= 1) { - task.first = [output, realInput, bytes](int tId) { - ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); - }; - task.second = 1; - mTasks.emplace_back(task); - return NO_ERROR; - } - inputBatchStride = batchStrideC4; - if (2 == mSingleConvert.type) { - destFormat = MNN_DATA_FORMAT_NHWC; - } else { - destFormat = MNN_DATA_FORMAT_NCHW; - } - } else if (MNN_DATA_FORMAT_NC4HW4 == destFormat) { - if (output->dimensions() <= 1) { - task.first = [output, realInput, bytes](int tId) { - ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); - }; - task.second = 1; - mTasks.emplace_back(task); - return NO_ERROR; - } - outputBatchStride = batchStrideC4; - if (2 == mSingleConvert.type) { - sourceFormat = MNN_DATA_FORMAT_NHWC; - } else { - sourceFormat = MNN_DATA_FORMAT_NCHW; - } - } - if (!mUseThreads) { - threadNum = 1; - } - task.first = [realInput, output, sourceFormat, destFormat, srcBatch, srcArea, srcChannel, bytes, core, threadNum](int tId) { - CPUTensorConverter::convert(realInput->host(), output->host(), sourceFormat, destFormat, srcBatch, srcArea, srcChannel, bytes, core, tId, threadNum); - }; - task.second = threadNum; - mTasks.emplace_back(task); - return NO_ERROR; - } - } - // Acquire Buffer for temp output - // TODO: optimize it - if (MNN_DATA_FORMAT_NC4HW4 == outputDes->dimensionFormat) { 
- mTempOutput.reset(new Tensor); - TensorUtils::setupTensorInfo(output, mTempOutput.get(), midFormat); - } + + +ErrorCode CPURaster::onExecute(const std::vector &____inputs, const std::vector &outputs) { + void* mOutputPtr = nullptr; if (nullptr != mTempOutput) { - auto res = backend()->onAcquireBuffer(mTempOutput.get(), Backend::DYNAMIC); - if (!res) { - return OUT_OF_MEMORY; - } + mOutputPtr = mTempOutput->host(); + } else { + mOutputPtr = outputs[0]->host(); } - // input is NC4HW4 add Convert - std::vector forRelease; - TensorUtils::FuseWrap fuseUtils; - for (int i=0; i< des->regions.size(); ++i) { - auto& slice = des->regions[i]; - auto origin = slice.origin; - if (nullptr == origin /*|| nullptr == origin->host()*/) { - continue; - } - // if tensor is not NC4HW4 or has been merged, don't need deal - if (TensorUtils::getDescribe(origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { - if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; + if (mFast) { + executeFaster(____inputs, outputs); + return NO_ERROR; + } + auto core = static_cast(backend())->functions(); + auto output = outputs[0]; + size_t bytes = (size_t)(CPUBackend::getBytes(backend(), output)); + auto outputEleSize = static_cast(backend())->getTensorSize(output); + auto threadNum = static_cast(backend())->threadNumber(); + if (mSingleConvert.type > 0) { + auto realInput = ____inputs[0]; + int srcBatch = mSingleConvert.batch, srcChannel = mSingleConvert.channel, srcArea = mSingleConvert.area; + auto sourceFormat = TensorUtils::getDescribe(realInput)->dimensionFormat; + auto destFormat = TensorUtils::getDescribe(output)->dimensionFormat; + auto channelC4 = UP_DIV(srcChannel, core->pack); + auto batchStrideC4 = channelC4 * core->pack * srcArea * bytes; + auto batchStride = srcChannel * srcArea * bytes; + auto inputBatchStride = batchStride; + auto outputBatchStride = batchStride; + if (MNN_DATA_FORMAT_NC4HW4 == sourceFormat) { + if 
(realInput->dimensions() <= 1) { + ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); + return NO_ERROR; } - mTempInputCopy.emplace_back(std::make_pair(origin, &slice)); - continue; - } - // if NC4HW4's C%4 == 0, change convert to transpose and fuse it - if (origin->batch() == 1 && origin->channel() % core->pack == 0) { - int channel = origin->channel(); - int area = 1; - // conv3d/pool3d will has 5 dims, area = depth * width * height, otherwise area = width * height - for (int d = 2; d < origin->dimensions(); d++) { - area *= origin->length(d); + inputBatchStride = batchStrideC4; + if (2 == mSingleConvert.type) { + destFormat = MNN_DATA_FORMAT_NHWC; + } else { + destFormat = MNN_DATA_FORMAT_NCHW; } - Tensor::InsideDescribe::Region regionTmp; - regionTmp.src.offset = 0; - regionTmp.src.stride[0] = area * core->pack; - regionTmp.src.stride[1] = 1; - regionTmp.src.stride[2] = core->pack; - regionTmp.dst.offset = 0; - regionTmp.dst.stride[0] = area * core->pack; - regionTmp.dst.stride[1] = area; - regionTmp.dst.stride[2] = 1; - regionTmp.size[0] = channel / core->pack; - regionTmp.size[1] = core->pack; - regionTmp.size[2] = area; - regionTmp.origin = slice.origin; - bool merge = fuseUtils.match(regionTmp, slice); - if (merge) { - std::shared_ptr newSlice(new Tensor::InsideDescribe::Region); - *newSlice = slice; - fuseUtils.apply(regionTmp, *newSlice); - // cache the merged tensor - if (newSlice->size[0] * newSlice->size[1] * newSlice->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - mTempInputCopy.emplace_back(std::make_pair(origin, newSlice.get())); - mCacheRegions.emplace_back(newSlice); - continue; + } else if (MNN_DATA_FORMAT_NC4HW4 == destFormat) { + if (output->dimensions() <= 1) { + ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); + return NO_ERROR; } - } - auto cache = static_cast(backend())->getCache(); - auto tempTensor = cache->findCacheTensor(origin, midFormat); - 
//MNN_ASSERT(CPUBackend::getBytes(backend(), origin) == 4); - if (nullptr == tempTensor) { - std::shared_ptr newTensor(new Tensor); - TensorUtils::copyShape(origin, newTensor.get()); - TensorUtils::getDescribe(newTensor.get())->dimensionFormat = midFormat; - TensorUtils::getDescribe(newTensor.get())->quantAttr = TensorUtils::getDescribe(origin)->quantAttr; - TensorUtils::getDescribe(newTensor.get())->applyQuant = TensorUtils::getDescribe(origin)->applyQuant;; - newTensor->buffer().type = origin->getType(); - TensorUtils::setLinearLayout(newTensor.get()); - mTempInput.insert(std::make_pair(origin, newTensor.get())); - auto res = backend()->onAcquireBuffer(newTensor.get(), Backend::DYNAMIC); - if (!res) { - return OUT_OF_MEMORY; + outputBatchStride = batchStrideC4; + if (2 == mSingleConvert.type) { + sourceFormat = MNN_DATA_FORMAT_NHWC; + } else { + sourceFormat = MNN_DATA_FORMAT_NCHW; } - tempTensor = newTensor.get(); - TensorUtils::getDescribe(tempTensor)->useCount = TensorUtils::getDescribe(origin)->useCount; - cache->pushCacheTensor(newTensor, origin, midFormat); } - if (--TensorUtils::getDescribe(tempTensor)->useCount == 0) { - forRelease.emplace_back(tempTensor); + if (!mUseThreads) { + threadNum = 1; } - if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - mTempInputCopy.emplace_back(std::make_pair(tempTensor, &slice)); - } - for (auto t : forRelease) { - backend()->onReleaseBuffer(t, Backend::DYNAMIC); - } - if (nullptr != mTempOutput) { - backend()->onReleaseBuffer(mTempOutput.get(), Backend::DYNAMIC); + MNN_CONCURRENCY_BEGIN(tId, threadNum) { + CPUTensorConverter::convert(realInput->host(), output->host(), sourceFormat, destFormat, srcBatch, srcArea, srcChannel, bytes, core, tId, threadNum); + }; + MNN_CONCURRENCY_END(); + return NO_ERROR; } - auto threadNumber = static_cast(backend())->threadNumber(); - mHasReduce = false; - ReduceInfo reduceInfo; - for (auto& iter : mTempInputCopy) { - if 
(reduceInfo.compute(*iter.second)) { - mHasReduce = true; - break; + if (mNeedZero) { + if (mTempOutput == nullptr) { + ::memset(output->host(), mZeroPoint, outputEleSize * bytes); + } else { + ::memset(mTempOutput->host(), mZeroPoint, mTempOutput->elementSize() * bytes); } } - // Encode convert for (auto& iter : mTempInput) { tensorConvert(iter.first, iter.second, (int)bytes); } - do { - if (mTempInputCopy.size() == 1 && threadNumber > 1 && (!mHasReduce)) { - // Split to multi region - auto region = mTempInputCopy[0].second; - if (region->size[0] * region->size[1] * region->size[2] < LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = false; - break; - } - if (region->size[0] * region->size[1] * region->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - auto tensorPtr = mTempInputCopy[0].first; - int pos = -1; - for (int i=0; i<3; ++i) { - if (region->size[i] > 1) { - pos = i; - break; - } - } - if (-1 == pos) { - // Don't need divide - break; - } - mTempInputCopy.clear(); - int divSize = UP_DIV(region->size[pos], threadNumber); - for (int i=0; i cacheRegPtr(new Tensor::InsideDescribe::Region); - auto& cacheReg = *cacheRegPtr; - int sta = i * divSize; - int fin = sta + divSize; - fin = std::min(fin, region->size[pos]); - if (fin <= sta) { - break; - } - for (int v=0; v<3; ++v) { - cacheReg.src.stride[v] = region->src.stride[v]; - cacheReg.dst.stride[v] = region->dst.stride[v]; - } - int curSize = fin - sta; - for (int v=0; vsize[v]; - } - cacheReg.size[pos] = curSize; - cacheReg.src.offset = region->src.offset + sta * region->src.stride[pos]; - cacheReg.dst.offset = region->dst.offset + sta * region->dst.stride[pos]; - for (int v=pos+1; v<3; ++v) { - cacheReg.size[v] = region->size[v]; - } - mTempInputCopy.emplace_back(std::make_pair(tensorPtr, cacheRegPtr.get())); - mCacheRegions.emplace_back(cacheRegPtr); - } - } - } while (false); if (mHasReduce) { // Don't support reduce with multi thread now threadNum = 1; @@ -685,13 +700,8 @@ ErrorCode 
CPURaster::onResize(const std::vector &____inputs, const std if (outputDescribe->overlap) { threadNum = 1; } - mTasks.emplace_back(std::make_pair([this, threadNum, output, bytes, core](int tId){ - void* mOutputPtr = nullptr; - if (nullptr != mTempOutput) { - mOutputPtr = mTempOutput->host(); - } else { - mOutputPtr = output->host(); - } + + MNN_CONCURRENCY_BEGIN(tId, threadNum) { for (int u=tId; u &____inputs, const std auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes; _blit(slice, (int)bytes, srcPtr, dstPtr, mHasReduce, core->MNNLowpToFp32, core->MNNFp32ToLowp); } - }, threadNum)); - if (nullptr != mTempOutput) { - tensorConvert(mTempOutput.get(), output, (int)bytes); } - return NO_ERROR; -} - - -ErrorCode CPURaster::onExecute(const std::vector &____inputs, const std::vector &outputs) { - void* mOutputPtr = nullptr; + MNN_CONCURRENCY_END(); if (nullptr != mTempOutput) { - mOutputPtr = mTempOutput->host(); - } else { - mOutputPtr = outputs[0]->host(); - } - auto core = static_cast(backend())->functions(); - auto output = outputs[0]; - size_t bytes = (size_t)(CPUBackend::getBytes(backend(), output)); - auto outputEleSize = static_cast(backend())->getTensorSize(output); - auto threadNum = static_cast(backend())->threadNumber(); - if (mNeedZero) { - if (mTempOutput == nullptr) { - ::memset(output->host(), mZeroPoint, outputEleSize * bytes); - } else { - ::memset(mTempOutput->host(), mZeroPoint, mTempOutput->elementSize() * bytes); - } - } - for (auto& task : mTasks) { - MNN_CONCURRENCY_ENQUEUE(task); + tensorConvert(mTempOutput.get(), output, (int)bytes); } return NO_ERROR; } @@ -1081,15 +1066,7 @@ class CPULoop : public Execution { auto stride2 = cmd->view()->GetAs(2)->stride()->data(); auto blit1 = _selectUnitProc(bytes, stride1[2], 1); auto blit2 = _selectUnitProc(bytes, stride2[2], 1); - if (cmd->size()->data()[2] == 1 || (stride1[2] <= 1 && stride2[2] <= 1 && (stride1[2] + stride1[1] != 0))) { - // Support elementwise or one src broadcast - int 
needBroadcastIndex = -1; - if (0 == stride1[2]) { - needBroadcastIndex = 0; - } - if (0 == stride2[2]) { - needBroadcastIndex = 1; - } + if (cmd->size()->data()[2] == 1 || (stride1[2] == 1 && stride2[2] == 1)) { for (int z=0; zsize()->data()[0]; ++z) { auto src0Z = src0 + z * stride1[0] * bytes; auto src1Z = src1 + z * stride2[0] * bytes; @@ -1098,7 +1075,7 @@ class CPULoop : public Execution { auto src0Y = src0Z + y * stride1[1] * bytes; auto src1Y = src1Z + y * stride2[1] * bytes; auto dstY = dstZ + y * stride0[1] * bytes; - proc(dstY, src0Y, src1Y, cmd->size()->data()[2], needBroadcastIndex); + proc(dstY, src0Y, src1Y, cmd->size()->data()[2], -1); } } } else { diff --git a/source/backend/cpu/CPURaster.hpp b/source/backend/cpu/CPURaster.hpp index bff149df52..9df10700bd 100644 --- a/source/backend/cpu/CPURaster.hpp +++ b/source/backend/cpu/CPURaster.hpp @@ -24,7 +24,7 @@ class CPURaster : public Execution { virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; - void executeFaster(const std::vector &inputs, const std::vector &outputs); + void executeFaster(const std::vector &inputs, const std::vector &outputs) const; void tensorConvert(Tensor* input, Tensor* output, int bytes); private: std::map mTempInput; @@ -38,7 +38,6 @@ class CPURaster : public Execution { int32_t mZeroPoint = 0; bool mHasReduce = false; bool mUseThreads = false; - std::vector, int>> mTasks; }; } #endif diff --git a/source/backend/cpu/ThreadPool.cpp b/source/backend/cpu/ThreadPool.cpp index d7765c4fbc..15a2d8241c 100644 --- a/source/backend/cpu/ThreadPool.cpp +++ b/source/backend/cpu/ThreadPool.cpp @@ -60,7 +60,7 @@ ThreadPool::ThreadPool(int numberThread) { while (mActiveCount > 0) { for (int i = 0; i < MNN_THREAD_POOL_MAX_TASKS; ++i) { if (*mTasks[i].second[threadIndex]) { - mTasks[i].first->first(threadIndex); + mTasks[i].first.first(threadIndex); { 
*mTasks[i].second[threadIndex] = false; } } } @@ -118,18 +118,16 @@ void ThreadPool::deactive() { mActiveCount--; } -void ThreadPool::enqueue(TASK* taskp, int index) { - auto& task = *taskp; +void ThreadPool::enqueue(TASK&& task, int index) { if (1 >= task.second || 0 > index) { for (int i = 0; i < task.second; ++i) { task.first(i); } return; } - enqueueInternal(taskp, index); + enqueueInternal(std::move(task), index); } -void ThreadPool::enqueueInternal(TASK* taskp, int index) { - auto& task = *taskp; +void ThreadPool::enqueueInternal(TASK&& task, int index) { if (mActiveCount == 0) { for (int i = 0; i < task.second; ++i) { task.first(i); @@ -137,25 +135,24 @@ void ThreadPool::enqueueInternal(TASK* taskp, int index) { return; } int workSize = task.second; - TASK* tmpTask = nullptr; if (workSize > mNumberThread) { - tmpTask = new TASK; - *tmpTask = std::make_pair([workSize, &task, this](int tId) { - for (int v = tId; v < workSize; v += mNumberThread) { - task.first(v); - } - }, mNumberThread); - mTasks[index].first = tmpTask; + mTasks[index].first = std::make_pair( + [workSize, &task, this](int tId) { + for (int v = tId; v < workSize; v += mNumberThread) { + task.first(v); + } + }, + mNumberThread); workSize = mNumberThread; } else { - mTasks[index].first = taskp; + mTasks[index].first = std::move(task); } { for (int i = 1; i < workSize; ++i) { *mTasks[index].second[i] = true; } } - mTasks[index].first->first(0); + mTasks[index].first.first(0); bool complete = true; do { complete = true; @@ -168,9 +165,6 @@ void ThreadPool::enqueueInternal(TASK* taskp, int index) { std::this_thread::yield(); // FUNC_PRINT(notComplete); } while (!complete); - if (nullptr != tmpTask) { - delete tmpTask; - } } } // namespace MNN #endif diff --git a/source/backend/cpu/ThreadPool.hpp b/source/backend/cpu/ThreadPool.hpp index 8891da61b1..4bf23de1b0 100644 --- a/source/backend/cpu/ThreadPool.hpp +++ b/source/backend/cpu/ThreadPool.hpp @@ -25,7 +25,7 @@ class MNN_PUBLIC ThreadPool { int 
numberThread() const { return mNumberThread; } - void enqueue(TASK* task, int index); + void enqueue(TASK&& task, int index); void active(); void deactive(); @@ -37,7 +37,7 @@ class MNN_PUBLIC ThreadPool { static void destroy(); private: - void enqueueInternal(TASK* task, int index); + void enqueueInternal(TASK&& task, int index); ThreadPool(int numberThread = 0); ~ThreadPool(); @@ -46,7 +46,7 @@ class MNN_PUBLIC ThreadPool { std::vector mTaskAvailable; std::atomic mStop = {false}; - std::vector>> mTasks; + std::vector>> mTasks; std::condition_variable mCondition; std::mutex mQueueMutex; diff --git a/source/backend/cpu/arm/CMakeLists.txt b/source/backend/cpu/arm/CMakeLists.txt index 61ebce6bdc..18fca54a4e 100644 --- a/source/backend/cpu/arm/CMakeLists.txt +++ b/source/backend/cpu/arm/CMakeLists.txt @@ -36,9 +36,6 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR AR if (MNN_KLEIDIAI) include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/KleidiAI.cmake) download_kleidiai_and_collect_sources() - if(MNN_KLEIDIAI_DEFAULT_ON) - add_definitions(-DMNN_DEFAULT_USE_KLEIDIAI) - endif() endif() if (MNN_SME2) diff --git a/source/backend/cpu/compute/CommonOptFunction.cpp b/source/backend/cpu/compute/CommonOptFunction.cpp index c9bfcc2189..d7d0d7fb34 100644 --- a/source/backend/cpu/compute/CommonOptFunction.cpp +++ b/source/backend/cpu/compute/CommonOptFunction.cpp @@ -3882,13 +3882,12 @@ void MNNVectorTop1Int32(int32_t* input, int32_t* maxValue, int32_t* maxIndex, si #endif -void MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tIdL) { +void MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId) { auto l = param->l; auto h = param->h; auto numberThread = param->numberThread; auto lC4 = l / 4; auto lR = lC4 * 4; - auto tId = (int)tIdL; if (param->BTranspose) { for (int y=tId; y= 8) { - if (0 == tId) { - auto bs = B + hEnd; 
- Vec4 sumValue0; - Vec4 sumValue1; - if (biasPtr != nullptr) { - sumValue0 = Vec4::load(biasPtr + hEnd + 0); - sumValue1 = Vec4::load(biasPtr + hEnd + 4); - } else { - sumValue0 = Vec4(0.0f); - sumValue1 = Vec4(0.0f); - } - auto srcY = A + hEnd * l; - for (int x=0; x= 4) { - if (0 == tId) { - auto bs = B + hEnd; - Vec4 sumValue0; - if (biasPtr != nullptr) { - sumValue0 = Vec4::load(biasPtr + hEnd + 0); - } else { - sumValue0 = Vec4(0.0f); - } - auto srcY = A + hEnd * l; - for (int x=0; x= 8) { - sumValue = Vec::fma(sumValue, Vec4::load(srcY + lR), Vec4::load(B + lR)); - sum1 = Vec::fma(sum1, Vec4::load(srcY + lR + 4), Vec4::load(B + lR + 4)); - lR += 8; - } - if (l - lR >= 4) { - sumValue = Vec::fma(sumValue, Vec4::load(srcY + lR), Vec4::load(B + lR)); - lR += 4; - } - sum2 = sum2 + sum3; - sumValue = sumValue + sum1; - sumValue = sumValue + sum2; + sumValue = sumValue + Vec4::load(srcY + 4 * x) * Vec4::load(B + 4 * x); + } float sumSingle = sumValue[0] + sumValue[1] + sumValue[2] + sumValue[3]; for (int x=lR; x - -void CPUBilinearLineC4(float* dst, const float* A, const float* B, - const float* t, int8_t* zeroPoint, size_t number) { - float tf = *t; - float sf = 1.0f - tf; - size_t total = number << 2; - - size_t i = 0; - while (i < total) { - size_t vl = __riscv_vsetvl_e32m8(total - i); - vfloat32m8_t v = __riscv_vle32_v_f32m8(A + i, vl); - vfloat32m8_t result = __riscv_vfmul_vf_f32m8(v, sf, vl); - v = __riscv_vle32_v_f32m8(B + i, vl); - result = __riscv_vfmacc_vf_f32m8(result, tf, v, vl); - __riscv_vse32_v_f32m8(dst + i, result, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp b/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp deleted file mode 100644 index 5063c39bff..0000000000 --- a/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp +++ /dev/null @@ -1,33 +0,0 @@ -#include - -void CPUBilinearSampleC4(const float* src, float* dst, - const int32_t* position, const float* factor, - int8_t* zeroPoint, size_t number) 
{ - const int pack = 4; - size_t i = 0; - - while (i < number) { - size_t vl = __riscv_vsetvl_e32m8(number - i); - vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + i, vl); - - for (int c = 0; c < pack; c++) { - vuint32m8_t voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vfloat32m8_t vr = __riscv_vluxei32_v_f32m8(src, voff, vl); - vfloat32m8_t vsf = __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl); - vr = __riscv_vfmul_vv_f32m8(vr, vsf, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vsf = __riscv_vluxei32_v_f32m8(src, voff, vl); - vr = __riscv_vfmacc_vv_f32m8(vr, vf, vsf, vl); - __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, vr, vl); - } - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp b/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp deleted file mode 100644 index 59bb28a039..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp +++ /dev/null @@ -1,29 +0,0 @@ -#include - -void MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { - ptrdiff_t srcStrideByte = srcStride * sizeof(float); - ptrdiff_t dstStrideByte = dstStride * sizeof(float); - size_t vl; - - for (size_t i = count; i > 0; i -= vl) { - vl = __riscv_vsetvl_e32m8(i); - vfloat32m8_t vs = __riscv_vlse32_v_f32m8(source + 0, srcStrideByte, vl); - vfloat32m8_t vd = __riscv_vlse32_v_f32m8(dest + 0, dstStrideByte, vl); - vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); - __riscv_vsse32_v_f32m8(dest + 0, dstStrideByte, vd, vl); - vs = __riscv_vlse32_v_f32m8(source + 1, srcStrideByte, vl); - vd = __riscv_vlse32_v_f32m8(dest + 1, dstStrideByte, vl); - vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); - __riscv_vsse32_v_f32m8(dest + 1, dstStrideByte, vd, vl); - vs = 
__riscv_vlse32_v_f32m8(source + 2, srcStrideByte, vl); - vd = __riscv_vlse32_v_f32m8(dest + 2, dstStrideByte, vl); - vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); - __riscv_vsse32_v_f32m8(dest + 2, dstStrideByte, vd, vl); - vs = __riscv_vlse32_v_f32m8(source + 3, srcStrideByte, vl); - vd = __riscv_vlse32_v_f32m8(dest + 3, dstStrideByte, vl); - vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); - __riscv_vsse32_v_f32m8(dest + 3, dstStrideByte, vd, vl); - source += vl * srcStride; - dest += vl * dstStride; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp b/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp deleted file mode 100644 index 6d966789f7..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include - -void MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) { - float beta = parameters[1]; - float minF = parameters[2]; - float maxF = parameters[3]; - const ptrdiff_t stride = 4 * sizeof(float); - - for (int y = 0; y < height; ++y) { - auto a = A + aStride * y; - auto b = B + 4 * y; - auto c = C + cStride * y; - float b0Beta = b[0] * beta; - float b1Beta = b[1] * beta; - float b2Beta = b[2] * beta; - float b3Beta = b[3] * beta; - size_t w = width; - - while (w > 0) { - size_t vl = __riscv_vsetvl_e32m8(w); - - vfloat32m8_t data = __riscv_vlse32_v_f32m8(a + 0, stride, vl); - data = __riscv_vfadd_vf_f32m8(data, b0Beta, vl); - data = __riscv_vfmax_vf_f32m8(data, minF, vl); - data = __riscv_vfmin_vf_f32m8(data, maxF, vl); - __riscv_vsse32_v_f32m8(c + 0, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(a + 1, stride, vl); - data = __riscv_vfadd_vf_f32m8(data, b1Beta, vl); - data = __riscv_vfmax_vf_f32m8(data, minF, vl); - data = __riscv_vfmin_vf_f32m8(data, maxF, vl); - __riscv_vsse32_v_f32m8(c + 1, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(a + 2, stride, vl); - data = 
__riscv_vfadd_vf_f32m8(data, b2Beta, vl); - data = __riscv_vfmax_vf_f32m8(data, minF, vl); - data = __riscv_vfmin_vf_f32m8(data, maxF, vl); - __riscv_vsse32_v_f32m8(c + 2, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(a + 3, stride, vl); - data = __riscv_vfadd_vf_f32m8(data, b3Beta, vl); - data = __riscv_vfmax_vf_f32m8(data, minF, vl); - data = __riscv_vfmin_vf_f32m8(data, maxF, vl); - __riscv_vsse32_v_f32m8(c + 3, stride, data, vl); - - a += 4 * vl; - c += 4 * vl; - w -= vl; - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp deleted file mode 100644 index 145cbea73f..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include - -void MNNBGRAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp deleted file mode 100644 index d46fe6c85b..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNBGRAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m4(count - i); - vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); - vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); - - channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); 
- - channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); - - channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); - __riscv_vse8_v_u8m4(dest + i, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp deleted file mode 100644 index 684db6aed3..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNBRGToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m4(count - i); - vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); - vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); - - channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); - - channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); - - vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); - __riscv_vse8_v_u8m4(dest + i, result, vl); - i += vl; - } -} \ No newline at end of file diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp deleted file mode 100644 index a26243bdb8..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include - -void MNNBilinearLineC8(int8_t* dst, const int16_t* A, const int16_t* B, - const float* t, int8_t* zeroPoint, size_t number) { - int offset = *zeroPoint; - int8_t* dstPtr = dst; - - const int pack = 8; - const int16_t df = (int16_t)((*t) * 128.0f); - const int16_t sf = (int16_t)((1.0f - *t) * 128.0f); - const size_t total = number * pack; - const int32_t ROUND_HALF = 1 << 13; - - size_t vl; - for (size_t i = 0; i < total; i += vl) { - vl = __riscv_vsetvl_e16m4(total - i); - vint16m4_t v16 = __riscv_vle16_v_i16m4(A + i, vl); - vint32m8_t v32 = 
__riscv_vwmul_vx_i32m8(v16, sf, vl); - v16 = __riscv_vle16_v_i16m4(B + i, vl); - v32 = __riscv_vwmacc_vx_i32m8(v32, df, v16, vl); - - vbool4_t mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); - vint32m8_t tmp = __riscv_vadd_vx_i32m8(v32, ROUND_HALF, vl); - v32 = __riscv_vsub_vx_i32m8(v32, ROUND_HALF, vl); - v32 = __riscv_vmerge_vvm_i32m8(tmp, v32, mask, vl); - - tmp = __riscv_vsra_vx_i32m8(v32, 14, vl); - mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); - v32 = __riscv_vand_vx_i32m8(v32, 0x3FFF, vl); - vbool4_t hasRem = __riscv_vmsne_vx_i32m8_b4(v32, 0, vl); - mask = __riscv_vmand_mm_b4(mask, hasRem, vl); - - v32 = __riscv_vadd_vx_i32m8_mu(mask, tmp, tmp, 1, vl); - v32 = __riscv_vadd_vx_i32m8(v32, offset, vl); - v16 = __riscv_vnsra_wx_i16m4(v32, 0, vl); - vint8m2_t v8 = __riscv_vnsra_wx_i8m2(v16, 0, vl); - - __riscv_vse8_v_i8m2(dstPtr + i, v8, vl); - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp deleted file mode 100644 index bd111e3be4..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include - -void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, - const int32_t* position, const float* factor, - int8_t* zeroPoint, size_t number) { - int16_t offset = (int16_t)(*zeroPoint); - const int pack = 8; - size_t i = 0; - - while (i < number) { - size_t vl = __riscv_vsetvl_e32m8(number - i); - vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + i, vl); - vint16m4_t vdf = __riscv_vnsra_wx_i16m4( - __riscv_vfcvt_rtz_x_f_v_i32m8( - __riscv_vfmul_vf_f32m8(vf, 128.0f, vl), vl), 0, vl); - vint16m4_t vsf = __riscv_vnsra_wx_i16m4( - __riscv_vfcvt_rtz_x_f_v_i32m8( - __riscv_vfmul_vf_f32m8( - __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl), 128.0f, vl), vl), 0, vl); - - for (int c = 0; c < pack; c++) { - vuint32m8_t voff = __riscv_vadd_vx_u32m8( - __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 3, vl), - 
c, vl); - - vint16m4_t va = __riscv_vsub_vx_i16m4( - __riscv_vsext_vf2_i16m4( - __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); - - vint32m8_t vr = __riscv_vwmul_vv_i32m8(va, vsf, vl); - voff = __riscv_vadd_vx_u32m8( - __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 3, vl), - c, vl); - - vint16m4_t vb = __riscv_vsub_vx_i16m4( - __riscv_vsext_vf2_i16m4( - __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); - vr = __riscv_vwmacc_vv_i32m8(vr, vb, vdf, vl); - __riscv_vsse16_v_i16m4(dst + i * pack + c, 16, - __riscv_vnsra_wx_i16m4(vr, 0, vl), vl); - } - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp b/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp deleted file mode 100644 index 9d524f13ca..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNC3ToC4(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); - - vuint8m8_t alpha = __riscv_vmv_v_x_u8m8(255, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, alpha, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp b/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp deleted file mode 100644 index f82faf83f5..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp +++ /dev/null @@ -1,48 +0,0 @@ -#include - -void MNNConvRunForLineDepthwise( - float* dst, const float* src, const float* weight, - size_t width, size_t src_w_setup, - size_t 
fw, size_t fh, size_t dilateX_step, size_t dilateY_step, - size_t height, size_t srcHStep, size_t dstHStep, - const float* bias, const float* parameters) { - float minV = parameters[0]; - float maxV = parameters[1]; - ptrdiff_t srcByteStride = src_w_setup * sizeof(float); - ptrdiff_t dstByteStride = 4 * sizeof(float); - - for (size_t y = 0; y < height; ++y) { - const float* srcY = src + y * srcHStep; - float* dstY = dst + y * dstHStep; - size_t dx = 0; - - while (dx < width) { - size_t vl = __riscv_vsetvl_e32m8(width - dx); - - for (int c = 0; c < 4; ++c) { - vfloat32m8_t acc = __riscv_vfmv_v_f_f32m8(bias[c], vl); - const float* srcBase = srcY + dx * src_w_setup + c; - const float* weightPtr = weight + c; - - for (size_t fy = 0; fy < fh; ++fy) { - const float* srcFy = srcBase + fy * dilateY_step; - - for (size_t fx = 0; fx < fw; ++fx) { - float w = *weightPtr; - weightPtr += 4; - const float* srcFx = srcFy + fx * dilateX_step; - vfloat32m8_t s = __riscv_vlse32_v_f32m8(srcFx, srcByteStride, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, w, s, vl); - } - } - - acc = __riscv_vfmax_vf_f32m8(acc, minV, vl); - acc = __riscv_vfmin_vf_f32m8(acc, maxV, vl); - float* dstAddr = dstY + dx * 4 + c; - __riscv_vsse32_v_f32m8(dstAddr, dstByteStride, acc, vl); - } - - dx += vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp b/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp deleted file mode 100644 index 3d8c4f13fc..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp +++ /dev/null @@ -1,22 +0,0 @@ -#include - -void MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { - ptrdiff_t srcStrideByte = srcStride * sizeof(float); - ptrdiff_t dstStrideByte = dstStride * sizeof(float); -size_t vl; - - for (size_t i = count; i > 0; i -= vl) { - vl = __riscv_vsetvl_e32m8(i); - vfloat32m8_t data = __riscv_vlse32_v_f32m8(source + 0, srcStrideByte, vl); - __riscv_vsse32_v_f32m8(dest + 0, 
dstStrideByte, data, vl); - data = __riscv_vlse32_v_f32m8(source + 1, srcStrideByte, vl); - __riscv_vsse32_v_f32m8(dest + 1, dstStrideByte, data, vl); - data = __riscv_vlse32_v_f32m8(source + 2, srcStrideByte, vl); - __riscv_vsse32_v_f32m8(dest + 2, dstStrideByte, data, vl); - data = __riscv_vlse32_v_f32m8(source + 3, srcStrideByte, vl); - __riscv_vsse32_v_f32m8(dest + 3, dstStrideByte, data, vl); - source += vl * srcStride; - dest += vl * dstStride; - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp b/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp deleted file mode 100644 index fd6ce7a274..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp +++ /dev/null @@ -1,53 +0,0 @@ -#include - -void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, - const float* C, const float* D, float* t, - int8_t* zeroPoint, size_t number, - ssize_t minValue, ssize_t maxValue) { - const float f = *t; - const float t2 = f * f, t3 = t2 * f; - const float b0 = 1.0f - 2.25f * t2 + 1.25f * t3; - const float t1 = 1.0f - f, t1_2 = t1 * t1; - const float c0 = 1.0f - 2.25f * t1_2 + 1.25f * t1_2 * t1; - const float ta = 1.0f + f, ta2 = ta * ta; - const float a0 = 3.0f - 6.0f * ta + 3.75f * ta2 - 0.75f * ta2 * ta; - const float td = 2.0f - f, td2 = td * td; - const float d0 = 3.0f - 6.0f * td + 3.75f * td2 - 0.75f * td2 * td; - const int offset = *zeroPoint; - const int minVal = (int)minValue; - const int maxVal = (int)maxValue; - const size_t total = number << 4; - size_t i = 0; - - while (i < total) { - size_t vl = __riscv_vsetvl_e32m8(total - i); - vfloat32m8_t v, acc; - - v = __riscv_vle32_v_f32m8(A + i, vl); - acc = __riscv_vfmul_vf_f32m8(v, a0, vl); - - v = __riscv_vle32_v_f32m8(B + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, b0, v, vl); - - v = __riscv_vle32_v_f32m8(C + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, c0, v, vl); - - v = __riscv_vle32_v_f32m8(D + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, d0, v, vl); - - vfloat32m8_t half = 
__riscv_vfmv_v_f_f32m8(0.5f, vl); - vfloat32m8_t signHalf = __riscv_vfsgnj_vv_f32m8(half, acc, vl); - acc = __riscv_vfadd_vv_f32m8(acc, signHalf, vl); - - vint32m8_t vint = __riscv_vfcvt_rtz_x_f_v_i32m8(acc, vl); - vint = __riscv_vadd_vx_i32m8(vint, offset, vl); - vint = __riscv_vmax_vx_i32m8(vint, minVal, vl); - vint = __riscv_vmin_vx_i32m8(vint, maxVal, vl); - - vint16m4_t vi16 = __riscv_vncvt_x_x_w_i16m4(vint, vl); - vint8m2_t vi8 = __riscv_vncvt_x_x_w_i8m2(vi16, vl); - __riscv_vse8_v_i8m2(dst + i, vi8, vl); - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp b/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp deleted file mode 100644 index 0da63ca0ff..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp +++ /dev/null @@ -1,38 +0,0 @@ -#include - -void MNNCubicLineC4(float* dst, const float* A, const float* B, - const float* C, const float* D, float* t, - int8_t* zeroPoint, size_t number, - ssize_t minValue, ssize_t maxValue) { - const float f = *t; - const float t2 = f * f, t3 = t2 * f; - const float b0 = 1.0f - 2.25f * t2 + 1.25f * t3; - const float t1 = 1.0f - f, t1_2 = t1 * t1; - const float c0 = 1.0f - 2.25f * t1_2 + 1.25f * t1_2 * t1; - const float ta = 1.0f + f, ta2 = ta * ta; - const float a0 = 3.0f - 6.0f * ta + 3.75f * ta2 - 0.75f * ta2 * ta; - const float td = 2.0f - f, td2 = td * td; - const float d0 = 3.0f - 6.0f * td + 3.75f * td2 - 0.75f * td2 * td; - const size_t total = number << 2; - size_t i = 0; - - while (i < total) { - size_t vl = __riscv_vsetvl_e32m8(total - i); - vfloat32m8_t v, acc; - - v = __riscv_vle32_v_f32m8(A + i, vl); - acc = __riscv_vfmul_vf_f32m8(v, a0, vl); - - v = __riscv_vle32_v_f32m8(B + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, b0, v, vl); - - v = __riscv_vle32_v_f32m8(C + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, c0, v, vl); - - v = __riscv_vle32_v_f32m8(D + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, d0, v, vl); - - __riscv_vse32_v_f32m8(dst + i, acc, vl); - i += vl; - } -} 
diff --git a/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp b/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp deleted file mode 100644 index fd5b24a53d..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp +++ /dev/null @@ -1,79 +0,0 @@ -#include - -void MNNCubicSampleC16(const int8_t* src, float* dst, - int32_t* position, const float* factor, - int8_t* zeroPoint, size_t number) { - const int pack = 16; - int8_t zp = *zeroPoint; - size_t i = 0; - - while (i < number) { - size_t vl = __riscv_vsetvl_e32m8(number - i); - vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl); - - for (int c = 0; c < pack; c++) { - vuint32m8_t voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c, vl); - - vint8m2_t vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); - vint16m4_t vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); - vfloat32m8_t vtmp = __riscv_vfcvt_f_x_v_f32m8( - __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); - - vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl); - vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl); - vfloat32m8_t vc = vtmp; - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c, vl); - - vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); - vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); - vfloat32m8_t vB = __riscv_vfcvt_f_x_v_f32m8( - __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); - - va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c, vl); - - vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); - vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); - vtmp = 
__riscv_vfcvt_f_x_v_f32m8( - __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); - - va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl); - vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl); - vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c, vl); - - vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); - vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); - vtmp = __riscv_vfcvt_f_x_v_f32m8( - __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); - - va = __riscv_vfmacc_vf_f32m8(va, 0.75f, vtmp, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, -0.75f, vtmp, vl); - - va = __riscv_vfmadd_vv_f32m8(va, vt, vb, vl); - va = __riscv_vfmadd_vv_f32m8(va, vt, vc, vl); - va = __riscv_vfmadd_vv_f32m8(va, vt, vB, vl); - - __riscv_vsse32_v_f32m8(dst + i * pack + c, pack * sizeof(float), va, vl); - } - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp b/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp deleted file mode 100644 index 78207e69e8..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp +++ /dev/null @@ -1,62 +0,0 @@ -#include - -void MNNCubicSampleC4(const float* src, float* dst, - int32_t* position, const float* factor, - int8_t* zeroPoint, size_t number) { - const int pack = 4; - size_t i = 0; - - while (i < number) { - size_t vl = __riscv_vsetvl_e32m8(number - i); - vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl); - - for (int c = 0; c < pack; c++) { - vuint32m8_t voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vfloat32m8_t vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); - - vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl); - vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl); - 
vfloat32m8_t vc = vtmp; - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vfloat32m8_t vB = __riscv_vluxei32_v_f32m8(src, voff, vl); - - va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); - - va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl); - vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl); - vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); - - va = __riscv_vfmacc_vf_f32m8(va, 0.75f, vtmp, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, -0.75f, vtmp, vl); - - va = __riscv_vfmadd_vv_f32m8(va, vt, vb, vl); - va = __riscv_vfmadd_vv_f32m8(va, vt, vc, vl); - va = __riscv_vfmadd_vv_f32m8(va, vt, vB, vl); - - __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, va, vl); - } - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp b/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp deleted file mode 100644 index 6658715e7e..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include - -void MNNDeconvRunForUnitDepthWise( - const float* dst, float* src, const float* weight, - size_t fw, size_t fh, - size_t weightY_step, size_t dilateX_step, size_t dilateY_step) { - const ptrdiff_t wStride = 4 * sizeof(float); - const ptrdiff_t sStride = dilateX_step * sizeof(float); - float d0 = dst[0], d1 = 
dst[1], d2 = dst[2], d3 = dst[3]; - - for (size_t fy = 0; fy < fh; ++fy) { - float* srcY = src + fy * dilateY_step; - const float* weightY = weight + fy * weightY_step; - - size_t fx = 0; - while (fx < fw) { - size_t vl = __riscv_vsetvl_e32m8(fw - fx); - - vfloat32m8_t w = __riscv_vlse32_v_f32m8(weightY + 0 + fx * 4, wStride, vl); - vfloat32m8_t s = __riscv_vlse32_v_f32m8(srcY + 0 + fx * dilateX_step, sStride, vl); - s = __riscv_vfmacc_vf_f32m8(s, d0, w, vl); - __riscv_vsse32_v_f32m8(srcY + 0 + fx * dilateX_step, sStride, s, vl); - - w = __riscv_vlse32_v_f32m8(weightY + 1 + fx * 4, wStride, vl); - s = __riscv_vlse32_v_f32m8(srcY + 1 + fx * dilateX_step, sStride, vl); - s = __riscv_vfmacc_vf_f32m8(s, d1, w, vl); - __riscv_vsse32_v_f32m8(srcY + 1 + fx * dilateX_step, sStride, s, vl); - - w = __riscv_vlse32_v_f32m8(weightY + 2 + fx * 4, wStride, vl); - s = __riscv_vlse32_v_f32m8(srcY + 2 + fx * dilateX_step, sStride, vl); - s = __riscv_vfmacc_vf_f32m8(s, d2, w, vl); - __riscv_vsse32_v_f32m8(srcY + 2 + fx * dilateX_step, sStride, s, vl); - - w = __riscv_vlse32_v_f32m8(weightY + 3 + fx * 4, wStride, vl); - s = __riscv_vlse32_v_f32m8(srcY + 3 + fx * dilateX_step, sStride, vl); - s = __riscv_vfmacc_vf_f32m8(s, d3, w, vl); - __riscv_vsse32_v_f32m8(srcY + 3 + fx * dilateX_step, sStride, s, vl); - - fx += vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp deleted file mode 100644 index 952fcaf090..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include - -void MNNGRAYToC3(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); - __riscv_vsse8_v_u8m8(dest + i * 3 + 0, 3, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 3 + 1, 3, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 3 + 2, 3, gray, vl); - i += vl; - } -} diff 
--git a/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp deleted file mode 100644 index 5ee4540f98..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include - -void MNNGRAYToC4(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); - vuint8m8_t alpha = __riscv_vmv_v_x_u8m8(255, vl); - __riscv_vsse8_v_u8m8(dest + i * 4 + 0, 4, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 4 + 1, 4, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 4 + 2, 4, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 4 + 3, 4, alpha, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp b/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp deleted file mode 100644 index 183a38bb10..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include -#include - -#define UNIT 4 - -void MNNMaxFloat(float *input, float *maxBuffer, int32_t inputCountUnit) { - const float init = -FLT_MAX; - for (int j = 0; j < UNIT; ++j) { - float local = init; - size_t i = 0; - - while (i < (size_t)inputCountUnit) { - size_t vl = __riscv_vsetvl_e32m8(inputCountUnit - i); - float *p0 = input + (i * UNIT * 2) + j * 2; - float *p1 = p0 + 1; - vfloat32m8_t v0 = __riscv_vlse32_v_f32m8(p0, UNIT * 2 * sizeof(float), vl); - vfloat32m8_t v1 = __riscv_vlse32_v_f32m8(p1, UNIT * 2 * sizeof(float), vl); - vfloat32m8_t vmax = __riscv_vfmax_vv_f32m8(v0, v1, vl); - vfloat32m1_t vred = __riscv_vfredmax_vs_f32m8_f32m1(vmax, __riscv_vfmv_s_f_f32m1(local, 1), vl); - local = __riscv_vfmv_f_s_f32m1_f32(vred); - i += vl; - } - maxBuffer[j] = local; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp b/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp deleted file mode 100644 index 9e8ade8641..0000000000 --- 
a/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include -#include - -#define UNIT 4 - -void MNNMinFloat(float *input, float *minBuffer, int32_t inputCountUnit) { - const float init = FLT_MAX; - for (int j = 0; j < UNIT; ++j) { - float local = init; - size_t i = 0; - - while (i < (size_t)inputCountUnit) { - size_t vl = __riscv_vsetvl_e32m8(inputCountUnit - i); - float *p0 = input + (i * UNIT * 2) + j * 2; - float *p1 = p0 + 1; - vfloat32m8_t v0 = __riscv_vlse32_v_f32m8(p0, UNIT * 2 * sizeof(float), vl); - vfloat32m8_t v1 = __riscv_vlse32_v_f32m8(p1, UNIT * 2 * sizeof(float), vl); - vfloat32m8_t vmin = __riscv_vfmin_vv_f32m8(v0, v1, vl); - vfloat32m1_t vred = __riscv_vfredmin_vs_f32m8_f32m1(vmin, __riscv_vfmv_s_f_f32m1(local, 1), vl); - local = __riscv_vfmv_f_s_f32m1_f32(vred); - i += vl; - } - minBuffer[j] = local; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNPackC2.cpp b/source/backend/cpu/riscv/rvv/MNNPackC2.cpp deleted file mode 100644 index 9a74f8998d..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNPackC2.cpp +++ /dev/null @@ -1,74 +0,0 @@ -#include - -void MNNPackC2(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { - int depthC2 = depth / 2; - int depthRemain = depthC2 * 2; - int remain = depth - depthRemain; - const float *srcOffset = src; - const float *srcChannel[2]; - - for (int z = 0; z < depthC2; ++z) { - float *dstZ = dst + z * areaOffset[1] * 2; - - for (int y = 0; y < 2; ++y) { - srcChannel[y] = srcOffset + areaOffset[0] * y; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - float *dstPtr = dstZ + x * 2; - vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 0, 2 * sizeof(float), vec, vl); - vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 1, 2 * sizeof(float), vec, vl); - } - - for (; x < area; ++x) { - float *dstPtr = dstZ + x * 2; - dstPtr[0] = 
srcChannel[0][x]; - dstPtr[1] = srcChannel[1][x]; - } - - srcOffset += areaOffset[0] * 2; - } - - if (remain > 0) { - float *dstZ = dst + depthC2 * areaOffset[1] * 2; - - for (int y = 0; y < remain; ++y) { - srcChannel[y] = srcOffset + areaOffset[0] * y; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - float *dstPtr = dstZ + x * 2; - - for (int y = 0; y < remain; ++y) { - vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + y, 2 * sizeof(float), vec, vl); - } - - vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); - for (int y = remain; y < 2; ++y) { - __riscv_vsse32_v_f32m8(dstPtr + y, 2 * sizeof(float), zero, vl); - } - } - - for (; x < area; ++x) { - float *dstPtr = dstZ + x * 2; - - for (int y = 0; y < remain; ++y) { - dstPtr[y] = srcChannel[y][x]; - } - - for (int y = remain; y < 2; ++y) { - dstPtr[y] = 0.0f; - } - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNPackC4.cpp b/source/backend/cpu/riscv/rvv/MNNPackC4.cpp deleted file mode 100644 index 024e2c8c07..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNPackC4.cpp +++ /dev/null @@ -1,80 +0,0 @@ -#include - -void MNNPackC4(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { - int depthC4 = depth / 4; - int depthRemain = depthC4 * 4; - int remain = depth - depthRemain; - const float *srcOffset = src; - const float *srcChannel[4]; - - for (int z = 0; z < depthC4; ++z) { - float *dstZ = dst + z * areaOffset[1] * 4; - - for (int y = 0; y < 4; ++y) { - srcChannel[y] = srcOffset + areaOffset[0] * y; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - float *dstPtr = dstZ + x * 4; - vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 0, 4 * sizeof(float), vec, vl); - vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 1, 4 * sizeof(float), vec, vl); - 
vec = __riscv_vle32_v_f32m8(srcChannel[2] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 2, 4 * sizeof(float), vec, vl); - vec = __riscv_vle32_v_f32m8(srcChannel[3] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 3, 4 * sizeof(float), vec, vl); - } - - for (; x < area; ++x) { - float *dstPtr = dstZ + x * 4; - dstPtr[0] = srcChannel[0][x]; - dstPtr[1] = srcChannel[1][x]; - dstPtr[2] = srcChannel[2][x]; - dstPtr[3] = srcChannel[3][x]; - } - - srcOffset += areaOffset[0] * 4; - } - - if (remain > 0) { - float *dstZ = dst + depthC4 * areaOffset[1] * 4; - - for (int y = 0; y < remain; ++y) { - srcChannel[y] = srcOffset + areaOffset[0] * y; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - float *dstPtr = dstZ + x * 4; - - for (int y = 0; y < remain; ++y) { - vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), vec, vl); - } - - vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); - for (int y = remain; y < 4; ++y) { - __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), zero, vl); - } - } - - for (; x < area; ++x) { - float *dstPtr = dstZ + x * 4; - - for (int y = 0; y < remain; ++y) { - dstPtr[y] = srcChannel[y][x]; - } - - for (int y = remain; y < 4; ++y) { - dstPtr[y] = 0.0f; - } - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp deleted file mode 100644 index f2b6c7a78d..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include - -void MNNRGBAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, 
vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp deleted file mode 100644 index ddd67a7d8c..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNRGBAToBGRA(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 3, 4, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp deleted file mode 100644 index d56b58546d..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNRGBAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m4(count - i); - vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); - vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); - - channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); - - channel = __riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); - - channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); - __riscv_vse8_v_u8m4(dest + i, channel, vl); - i += vl; - } -} 
diff --git a/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp deleted file mode 100644 index 7c6decf39e..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include - -void MNNRGBToBGR(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp deleted file mode 100644 index 1b946c33cc..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNRGBToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m4(count - i); - vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); - vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); - - channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); - - channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); - - vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); - __riscv_vse8_v_u8m4(dest + i, result, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp b/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp deleted file mode 100644 index 262f4cbfab..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp +++ /dev/null @@ -1,45 
+0,0 @@ -#include - -void MNNReluWithSlopeChannel(float *dst, const float *src, - const float *slope, size_t sizeQuad, - size_t depthQuad) { - const ptrdiff_t stride = 4 * sizeof(float); - - for (size_t j = 0; j < depthQuad; ++j) { - const float *srcZ = src + 4 * j * sizeQuad; - float *dstZ = dst + 4 * j * sizeQuad; - float s0 = slope[4*j], s1 = slope[4*j + 1]; - float s2 = slope[4*j + 2], s3 = slope[4*j + 3]; - size_t i = 0; - while (i < sizeQuad) { - size_t vl = __riscv_vsetvl_e32m8(sizeQuad - i); - const float *srcBase = srcZ + 4*i; - float *dstBase = dstZ + 4*i; - - vfloat32m8_t v; - vbool4_t mask; - - v = __riscv_vlse32_v_f32m8(srcBase, stride, vl); - mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); - v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s0, vl); - __riscv_vsse32_v_f32m8(dstBase, stride, v, vl); - - v = __riscv_vlse32_v_f32m8(srcBase + 1, stride, vl); - mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); - v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s1, vl); - __riscv_vsse32_v_f32m8(dstBase + 1, stride, v, vl); - - v = __riscv_vlse32_v_f32m8(srcBase + 2, stride, vl); - mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); - v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s2, vl); - __riscv_vsse32_v_f32m8(dstBase + 2, stride, v, vl); - - v = __riscv_vlse32_v_f32m8(srcBase + 3, stride, vl); - mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); - v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s3, vl); - __riscv_vsse32_v_f32m8(dstBase + 3, stride, v, vl); - - i += vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp b/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp deleted file mode 100644 index 10992f9d59..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include - -void MNNScaleAndAddBias(float *dst, const float *src, const float *bias, const float *alpha, size_t planeNumber, size_t biasNumber) { - const ptrdiff_t stride = 4 * sizeof(float); - - for (size_t z = 0; z < biasNumber; ++z) { - float *dstZ = dst 
+ z * planeNumber * 4; - const float *srcZ = src + z * planeNumber * 4; - const float *biasZ = bias + 4 * z; - const float *alphaZ = alpha + 4 * z; - float b0 = biasZ[0], b1 = biasZ[1], b2 = biasZ[2], b3 = biasZ[3]; - float a0 = alphaZ[0], a1 = alphaZ[1], a2 = alphaZ[2], a3 = alphaZ[3]; - - size_t n = planeNumber; - while (n > 0) { - size_t vl = __riscv_vsetvl_e32m8(n); - vfloat32m8_t data = __riscv_vlse32_v_f32m8(srcZ + 0, stride, vl); - data = __riscv_vfmul_vf_f32m8(data, a0, vl); - data = __riscv_vfadd_vf_f32m8(data, b0, vl); - __riscv_vsse32_v_f32m8(dstZ + 0, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(srcZ + 1, stride, vl); - data = __riscv_vfmul_vf_f32m8(data, a1, vl); - data = __riscv_vfadd_vf_f32m8(data, b1, vl); - __riscv_vsse32_v_f32m8(dstZ + 1, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(srcZ + 2, stride, vl); - data = __riscv_vfmul_vf_f32m8(data, a2, vl); - data = __riscv_vfadd_vf_f32m8(data, b2, vl); - __riscv_vsse32_v_f32m8(dstZ + 2, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(srcZ + 3, stride, vl); - data = __riscv_vfmul_vf_f32m8(data, a3, vl); - data = __riscv_vfadd_vf_f32m8(data, b3, vl); - __riscv_vsse32_v_f32m8(dstZ + 3, stride, data, vl); - - srcZ += vl * 4; - dstZ += vl * 4; - n -= vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp b/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp deleted file mode 100644 index f510058c83..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp +++ /dev/null @@ -1,80 +0,0 @@ -#include -#include - -void MNNSoftmax(float *dest, const float *source, size_t size) { - size_t n = size; - const float *sourcePtr = source; - float *destPtr = dest; - float maxValue = -FLT_MAX; - vfloat32m1_t maxVecValue = __riscv_vfmv_s_f_f32m1(maxValue, 1); - - while (n > 0) { - size_t vl = __riscv_vsetvl_e32m8(n); - vfloat32m8_t vSrc = __riscv_vle32_v_f32m8(sourcePtr, vl); - maxVecValue = __riscv_vfredmax_vs_f32m8_f32m1(vSrc, maxVecValue, vl); - sourcePtr += vl; - n -= vl; - } - - 
maxValue = __riscv_vfmv_f_s_f32m1_f32(maxVecValue); - const float param = 0.6931471805599453f; - const float xLimit = 87.0f; - float sumValue = 0.f; - vfloat32m1_t sumVecValue = __riscv_vfmv_s_f_f32m1(sumValue, 1); - n = size; - sourcePtr = source; - destPtr = dest; - - while (n > 0) { - size_t vl = __riscv_vsetvl_e32m8(n); - vfloat32m8_t vA = __riscv_vle32_v_f32m8(sourcePtr, vl); - vA = __riscv_vfsub_vf_f32m8(vA, maxValue, vl); - vA = __riscv_vfmax_vf_f32m8(vA, -xLimit, vl); - vA = __riscv_vfmin_vf_f32m8(vA, xLimit, vl); - - vfloat32m8_t vB = __riscv_vfdiv_vf_f32m8(vA, param, vl); - vint32m8_t vBI = __riscv_vfcvt_x_f_v_i32m8(vB, vl); - - vfloat32m8_t vC = __riscv_vreinterpret_v_i32m8_f32m8( - __riscv_vsll_vx_i32m8( - __riscv_vadd_vx_i32m8(vBI, 127, vl), 23, vl)); - - vB = __riscv_vfcvt_f_x_v_f32m8(vBI, vl); - vB = __riscv_vfnmsub_vf_f32m8(vB, param, vA, vl); - - vA = __riscv_vfmv_v_f_f32m8(1.0f / 120.0f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 1.0f / 24.0f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 1.0f / 6.0f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 0.5f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl); - - vA = __riscv_vfmul_vv_f32m8(vC, vA, vl); - __riscv_vse32_v_f32m8(destPtr, vA, vl); - sumVecValue = __riscv_vfredosum_vs_f32m8_f32m1(vA, sumVecValue, vl); - - sourcePtr += vl; - destPtr += vl; - n -= vl; - } - - sumValue = __riscv_vfmv_f_s_f32m1_f32(sumVecValue); - float sumInv = 1.0f / sumValue; - n = size; - destPtr = dest; - - while (n > 0) - { - size_t vl = __riscv_vsetvl_e32m8(n); - vfloat32m8_t vDest = __riscv_vle32_v_f32m8(destPtr, vl); - vDest = __riscv_vfmul_vf_f32m8(vDest, sumInv, vl); - __riscv_vse32_v_f32m8(destPtr, vDest, vl); - destPtr += vl; - n -= vl; - } -} diff --git 
a/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp b/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp deleted file mode 100644 index 8ab5bb89fa..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include - -void MNNStrassenMergeCFunction(float *c11, float *c12, float *c21, float *c22, - float *xAddr, size_t cStride, size_t eSub, size_t hSub) { - for (int y = 0; y < hSub; ++y) { - float *c11Y = c11 + y * cStride; - float *c12Y = c12 + y * cStride; - float *c22Y = c22 + y * cStride; - float *c21Y = c21 + y * cStride; - float *xY = xAddr + y * eSub * 4; - size_t totalElements = eSub * 4; - size_t p = 0; - - while (p < totalElements) { - size_t vl = __riscv_vsetvl_e32m8(totalElements - p); - vfloat32m8_t t = __riscv_vle32_v_f32m8(xY + p, vl); - vfloat32m8_t tmp = __riscv_vle32_v_f32m8(c12Y + p, vl); - t = __riscv_vfadd_vv_f32m8(t, tmp, vl); - vfloat32m8_t c22v = __riscv_vle32_v_f32m8(c22Y + p, vl); - - tmp = __riscv_vle32_v_f32m8(c11Y + p, vl); - tmp = __riscv_vfadd_vv_f32m8(tmp, c22v, vl); - tmp = __riscv_vfadd_vv_f32m8(tmp, t, vl); - __riscv_vse32_v_f32m8(c12Y + p, tmp, vl); - - tmp = __riscv_vle32_v_f32m8(c21Y + p, vl); - tmp = __riscv_vfadd_vv_f32m8(t, tmp, vl); - __riscv_vse32_v_f32m8(c21Y + p, tmp, vl); - - c22v = __riscv_vfadd_vv_f32m8(c22v, tmp, vl); - __riscv_vse32_v_f32m8(c22Y + p, c22v, vl); - - p += vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp b/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp deleted file mode 100644 index 7598d6f8ac..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#include - -void MNNTranspose16Bit(int16_t* dstO, const int16_t* srcO, int16_t* dim) { - int w = dim[0]; - int h = dim[1]; - int srcStride = dim[2]; - int dstStride = dim[3]; - ptrdiff_t srcStrideByte = srcStride * sizeof(int16_t); - - for (int i = 0; i < h; ++i) { - const int16_t* srcPtr = srcO + i; - int16_t* 
dstPtr = dstO + i * dstStride; - - int j = 0; - while (j < w) { - size_t vl = __riscv_vsetvl_e16m8(w - j); - vint16m8_t data = __riscv_vlse16_v_i16m8(srcPtr, srcStrideByte, vl); - __riscv_vse16_v_i16m8(dstPtr, data, vl); - srcPtr += vl * srcStride; - dstPtr += vl; - j += vl; - } - } -} - - diff --git a/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp b/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp deleted file mode 100644 index e5c5eb83e6..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include - -void MNNTranspose32Bit(int32_t* dstO, const int32_t* srcO, int32_t* dim) { - int w = dim[0]; - int h = dim[1]; - int srcStride = dim[2]; - int dstStride = dim[3]; - ptrdiff_t srcStrideByte = srcStride * sizeof(int32_t); - - for (int i = 0; i < h; ++i) { - const int32_t* srcPtr = srcO + i; - int32_t* dstPtr = dstO + i * dstStride; - - int j = 0; - while (j < w) { - size_t vl = __riscv_vsetvl_e32m8(w - j); - vint32m8_t data = __riscv_vlse32_v_i32m8(srcPtr, srcStrideByte, vl); - __riscv_vse32_v_i32m8(dstPtr, data, vl); - srcPtr += vl * srcStride; - dstPtr += vl; - j += vl; - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp b/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp deleted file mode 100644 index 4676e6dede..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include - -void MNNUnpackC4(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { - int depthC4 = depth / 4; - int depthRemain = depthC4 * 4; - int remain = depth - depthRemain; - const float *srcOffset = src; - - for (int z = 0; z < depthC4; ++z) { - float *dstZ[4]; - - for (int y = 0; y < 4; ++y) { - dstZ[y] = dst + (z * 4 + y) * areaOffset[1]; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - vfloat32m8_t vec = __riscv_vlse32_v_f32m8(srcOffset + 0, 4 * sizeof(float), vl); - __riscv_vse32_v_f32m8(dstZ[0] + x, vec, vl); 
- vec = __riscv_vlse32_v_f32m8(srcOffset + 1, 4 * sizeof(float), vl); - __riscv_vse32_v_f32m8(dstZ[1] + x, vec, vl); - vec = __riscv_vlse32_v_f32m8(srcOffset + 2, 4 * sizeof(float), vl); - __riscv_vse32_v_f32m8(dstZ[2] + x, vec, vl); - vec = __riscv_vlse32_v_f32m8(srcOffset + 3, 4 * sizeof(float), vl); - __riscv_vse32_v_f32m8(dstZ[3] + x, vec, vl); - srcOffset += 4 * vl; - } - - for (; x < area; ++x) { - dstZ[0][x] = srcOffset[0]; - dstZ[1][x] = srcOffset[1]; - dstZ[2][x] = srcOffset[2]; - dstZ[3][x] = srcOffset[3]; - srcOffset += (areaOffset[0] - area) * 4; - } - } - - if (remain > 0) { - float *dstZ = dst + depthC4 * areaOffset[1] * 4; - const float *srcBase = srcOffset; - - for (int y = 0; y < remain; ++y) { - float *dstChannel = dstZ + y * areaOffset[1]; - const float *srcChannel = srcBase + y; - - for (size_t x = 0; x < area; ++x) { - dstChannel[x] = srcChannel[0]; - srcChannel += 4; - } - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp b/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp deleted file mode 100644 index 7332360ce8..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include -#include - -#define UNIT 4 - -void MNNVectorTop1Float(float* input, float* maxValue, int32_t* maxIndex, size_t inputCountUnit) { - size_t n = inputCountUnit * UNIT; - float maxV = -FLT_MAX; - int32_t maxIdx = 0; - size_t vl; - - for (size_t i = 0; i < n; ) { - vl = __riscv_vsetvl_e32m8(n - i); - vfloat32m8_t data = __riscv_vle32_v_f32m8(input + i, vl); - vfloat32m1_t scalar = __riscv_vfmv_s_f_f32m1(maxV, vl); - vfloat32m1_t result = __riscv_vfredmax_vs_f32m8_f32m1(data, scalar, vl); - maxV = __riscv_vfmv_f_s_f32m1_f32(result); - i += vl; - } - - for (size_t i = 0; i < n; ) { - vl = __riscv_vsetvl_e32m8(n - i); - vfloat32m8_t data = __riscv_vle32_v_f32m8(input + i, vl); - vbool4_t mask = __riscv_vmfeq_vf_f32m8_b4(data, maxV, vl); - long first = __riscv_vfirst_m_b4(mask, vl); - - if (first >= 0) { 
- maxIdx = i + first; - break; - } - - i += vl; - } - - maxValue[0] = maxV; - maxIndex[0] = maxIdx; -} diff --git a/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp b/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp deleted file mode 100644 index 8c199709ec..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include -#include - -#define UNIT 4 - -void MNNVectorTop1Int32(int32_t* input, int32_t* maxValue, int32_t* maxIndex, size_t inputCountUnit) { - size_t n = inputCountUnit * UNIT; - int32_t maxV = INT32_MIN; - int32_t maxIdx = 0; - size_t vl; - - for (size_t i = 0; i < n; ) { - vl = __riscv_vsetvl_e32m8(n - i); - vint32m8_t data = __riscv_vle32_v_i32m8(input + i, vl); - vint32m1_t scalar = __riscv_vmv_s_x_i32m1(maxV, vl); - vint32m1_t result = __riscv_vredmax_vs_i32m8_i32m1(data, scalar, vl); - maxV = __riscv_vmv_x_s_i32m1_i32(result); - i += vl; - } - - for (size_t i = 0; i < n; ) { - vl = __riscv_vsetvl_e32m8(n - i); - vint32m8_t data = __riscv_vle32_v_i32m8(input + i, vl); - vbool4_t mask = __riscv_vmseq_vx_i32m8_b4(data, maxV, vl); - long first = __riscv_vfirst_m_b4(mask, vl); - - if (first >= 0) { - maxIdx = i + first; - break; - } - - i += vl; - } - - maxValue[0] = maxV; - maxIndex[0] = maxIdx; -} diff --git a/source/backend/metal/MetalBackend.mm b/source/backend/metal/MetalBackend.mm index 885808fc44..79f52ff2dc 100644 --- a/source/backend/metal/MetalBackend.mm +++ b/source/backend/metal/MetalBackend.mm @@ -15,7 +15,6 @@ #define MTLGPUFamilyMetal3_MNN 5001 #define MTLGPUFamilyMetal4_MNN 5002 -#define CHECK_IOS_UI_STATUS #if MNN_METAL_ENABLED #include #import "backend/metal/MNNMetalContext.h" @@ -23,9 +22,6 @@ #import "core/TensorUtils.hpp" #include "MetalCache_generated.h" #include "core/MNNFileUtils.h" -#if defined(TARGET_OS_IPHONE) && TARGET_OS_IPHONE -#import -#endif int MNNMetalGetTensorContent(MNNMetalTensorContent* content, void* tensor) { if (nullptr == content || nullptr == tensor) { return 
0; @@ -780,9 +776,6 @@ static void _execute(id encoder, const MetalBackend::C MNN_ASSERT(false); // should not be handled here } int MetalBackend::onSync(Tensor::MapType mtype, bool toCpu, const Tensor* dstTensor) { - if (mRuntime->pExecutionStatus == NO_EXECUTION) { - return NO_EXECUTION; - } flushEncoder(); auto ctx = (__bridge MNNMetalContext *)context(); commit_net(); @@ -831,19 +824,6 @@ static void _execute(id encoder, const MetalBackend::C void MetalBackend::commit() const { -#ifdef CHECK_IOS_UI_STATUS -#if defined(TARGET_OS_IPHONE) && TARGET_OS_IPHONE - if ([UIApplication sharedApplication].applicationState == UIApplicationStateBackground || [UIApplication sharedApplication].applicationState == UIApplicationStateInactive) { - mRuntime->pExecutionStatus = NO_EXECUTION; - _commandBuffer = nil; - if (!mSupportDeferEncode) { - _commandBuffer_net = nil; - } - return; - } -#endif -#endif - mRuntime->pExecutionStatus = NO_ERROR; if (nil != _commandBuffer && _commandBuffer.status < MTLCommandBufferStatusCommitted) { [_commandBuffer commit]; mRuntime->_waiting = _commandBuffer; @@ -856,19 +836,6 @@ static void _execute(id encoder, const MetalBackend::C } void MetalBackend::commit_net() const { -#ifdef CHECK_IOS_UI_STATUS -#if defined(TARGET_OS_IPHONE) && TARGET_OS_IPHONE - if ([UIApplication sharedApplication].applicationState == UIApplicationStateBackground || [UIApplication sharedApplication].applicationState == UIApplicationStateInactive) { - mRuntime->pExecutionStatus = NO_EXECUTION; - _commandBuffer_net = nil; - if (!mSupportDeferEncode) { - _commandBuffer = nil; - } - return; - } -#endif -#endif - mRuntime->pExecutionStatus = NO_ERROR; if (nil != _commandBuffer_net && _commandBuffer_net.status < MTLCommandBufferStatusCommitted) { [_commandBuffer_net commit]; mRuntime->_waiting = _commandBuffer_net; diff --git a/source/core/Backend.hpp b/source/core/Backend.hpp index e463f251f7..bcf618c3c9 100644 --- a/source/core/Backend.hpp +++ b/source/core/Backend.hpp @@ 
-68,11 +68,9 @@ struct RuntimeHint { // whether to use Arm sme2 cores when threads>1 bool useArmSme2Cores = true; -#ifdef MNN_DEFAULT_USE_KLEIDIAI - bool enableKleidiAI = true; -#else + bool enableKleidiAI = false; -#endif + // Use CPU Ids std::vector cpuIds; @@ -395,7 +393,6 @@ class Runtime : public NonCopyable { } mutable int pCurrentStatus = 0; // NO_ERROR - mutable int pExecutionStatus = 0; // NO_ERROR // TODO: Move to Backend void* pMeta = nullptr; diff --git a/source/core/Concurrency.h b/source/core/Concurrency.h index 7c06625fe4..73f5984e5a 100644 --- a/source/core/Concurrency.h +++ b/source/core/Concurrency.h @@ -12,9 +12,6 @@ #define LAUNCH_MULTI_THREADS_WORKLOAD 1e+5 #ifdef MNN_FORBIT_MULTI_THREADS -#define MNN_CONCURRENCY_ENQUEUE(task) \ -for (int __iter__ = 0; __iter__ < task.second; __iter__++) {task.first(__iter__);} - #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) for (int __iter__ = 0; __iter__ < __num__; __iter__++) { #define MNN_CONCURRENCY_END() } @@ -22,8 +19,6 @@ for (int __iter__ = 0; __iter__ < task.second; __iter__++) {task.first(__iter__) #include "backend/cpu/ThreadPool.hpp" #define MNN_STRINGIFY(a) #a -#define MNN_CONCURRENCY_ENQUEUE(task) ((CPUBackend*)backend())->enqueue(task) - #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) \ { \ std::pair, int> task; \ @@ -33,7 +28,8 @@ for (int __iter__ = 0; __iter__ < task.second; __iter__++) {task.first(__iter__) } \ ; \ auto cpuBn = (CPUBackend*)backend(); \ - cpuBn->enqueue(task); \ + auto thrPl = cpuBn->threadPool(); \ + thrPl->enqueue(std::move(task), cpuBn->taskIndex()); \ } #else @@ -42,9 +38,6 @@ for (int __iter__ = 0; __iter__ < task.second; __iter__++) {task.first(__iter__) #include #include -#define MNN_CONCURRENCY_ENQUEUE(task) \ -dispatch_apply(task.second, dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH, 0), ^(size_t __iter__) {task.first(__iter__);}); - #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) \ dispatch_apply(__num__, 
dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH, 0), ^(size_t __iter__) { #define MNN_CONCURRENCY_END() \ @@ -65,8 +58,6 @@ dispatch_apply(__num__, dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH, // Android #else #include -#define MNN_CONCURRENCY_ENQUEUE(task) \ -_Pragma("omp parallel for") for (int __iter__ = 0; __iter__ < task.second; __iter__++) {task.first(__iter__);} #define MNN_STRINGIFY(a) #a #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) \ diff --git a/source/core/OpCommonUtils.cpp b/source/core/OpCommonUtils.cpp index a69263ffaa..c80afaef87 100644 --- a/source/core/OpCommonUtils.cpp +++ b/source/core/OpCommonUtils.cpp @@ -386,7 +386,98 @@ void OpCommonUtils::broastCastComputeDim(int* dims, int* stride, int* iStride0, } } } +std::vector> OpCommonUtils::computeReduceDims(const std::vector& inputs, + const Op* op) { + // Compute axises + std::vector axises; + if (inputs.size() >= 2) { + auto size = inputs[1]->elementSize(); + auto dims = inputs[1]->host(); + for (int i = 0; i < size; ++i) { + axises.emplace_back(dims[i]); + } + } else { + auto reduct = op->main_as_ReductionParam(); + if (nullptr != reduct->dim()) { + for (int i = 0; i < reduct->dim()->size(); ++i) { + axises.emplace_back(reduct->dim()->data()[i]); + } + } + } + auto totalSize = TensorUtils::getRawSize(inputs[0]); + if (axises.empty()) { + return {std::make_tuple(1, totalSize, 1)}; + } + for (int i = 0; i < axises.size(); ++i) { + if (axises[i] < 0) { + axises[i] = inputs[0]->dimensions() + axises[i]; + if (axises[i] < 0) { + return {std::make_tuple(1, totalSize, 1)}; + } + } + } + // Cache for input's dims + std::vector lengths(inputs[0]->dimensions()); + for (int i = 0; i < lengths.size(); ++i) { + lengths[i] = inputs[0]->length(i); + } + std::vector> groupAxises; + { + // Merge adj axis + std::sort(axises.begin(), axises.end()); + int lastAxis = axises[0]; + int length = 1; + int start = axises[0]; + for (int i = 1; i < axises.size(); ++i) { + // MNN_PRINT("%d - %d\n", 
axises[i], lastAxis); + if (axises[i] - lastAxis == 1) { + length++; + } else { + groupAxises.emplace_back(std::make_pair(start, length)); + length = 1; + start = axises[i]; + } + lastAxis = axises[i]; + } + groupAxises.emplace_back(std::make_pair(start, length)); + } + + // Compute inside-outside-axis + std::vector> result; + for (int i = 0; i < groupAxises.size(); ++i) { + int outsideSize = 1; + int insideSize = 1; + int axisSize = 1; + auto start = groupAxises[i].first; + auto length = groupAxises[i].second; + if (start >= (int)lengths.size()) { + break; + } + for (int j = 0; j < start; ++j) { + outsideSize *= lengths[j]; + } + for (int j = start; j < start + length; ++j) { + if (j >= (int)lengths.size()) { + break; + } + axisSize *= lengths[j]; + lengths[j] = 1; + } + for (int j = start + length; j < lengths.size(); ++j) { + insideSize *= lengths[j]; + } + if (1 == axisSize) { + continue; + } + result.emplace_back(std::make_tuple(outsideSize, axisSize, insideSize)); + } + // FUNC_PRINT(result.size()); + if (result.empty()) { + result.emplace_back(std::make_tuple(1, 1, totalSize)); + } + return result; +} void OpCommonUtils::unravelIndexHelper(int32_t* coordinate, const int32_t* mod, int size, int indice) { int value = indice; diff --git a/source/core/OpCommonUtils.hpp b/source/core/OpCommonUtils.hpp index 8ec0628336..0740cc16b2 100644 --- a/source/core/OpCommonUtils.hpp +++ b/source/core/OpCommonUtils.hpp @@ -56,6 +56,7 @@ class MNN_PUBLIC OpCommonUtils { static bool supportDynamicInputMemory(MNNForwardType type); static void broastCastComputeDim(int* dims, int* stride, int* iStride0, int* iStride1, const Tensor* input0, const Tensor* input1, const Tensor* output); + static std::vector> computeReduceDims(const std::vector& inputs, const Op* op); static void unravelIndexHelper(int32_t* coordinate, const int32_t* mod, int size, int indice); static int computeStride(int32_t* strides, const int* shape, int length); diff --git a/source/core/Tensor.cpp 
b/source/core/Tensor.cpp index 664fa6b790..18bf5ec7a6 100644 --- a/source/core/Tensor.cpp +++ b/source/core/Tensor.cpp @@ -430,14 +430,6 @@ void* Tensor::map(MapType mtype, DimensionType dtype) { return mBuffer.host; } - if (mtype == Tensor::MAP_TENSOR_READ) { - int syncResult = bn->onSync(mtype, false, this); - if (NO_EXECUTION == syncResult) { - MNN_PRINT("Warning, Backend has stop execute, return nullptr for tensor map addr\n"); - return nullptr; - } - } - auto mapPtr = bn->onMapTensor(mtype, dtype, this); if(mapPtr != nullptr) { // Get mapPtr in specific backend diff --git a/source/core/TensorUtils.cpp b/source/core/TensorUtils.cpp index d233fc9d89..ae5b87143c 100644 --- a/source/core/TensorUtils.cpp +++ b/source/core/TensorUtils.cpp @@ -32,18 +32,6 @@ bool TensorUtils::regionIsFull(Tensor* input) { return regionSize == size; } -void TensorUtils::makeFullRef(Tensor* output, Tensor* input) { - auto des = TensorUtils::getDescribe(input); - auto outputDes = TensorUtils::getDescribe(output); - outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; - if (des->memoryType == Tensor::InsideDescribe::MEMORY_VIRTUAL) { - outputDes->regions = des->regions; - } else { - outputDes->regions = {makeFullSlice(input)}; - } -} - - Tensor::InsideDescribe::Region TensorUtils::makeFullSlice(Tensor* input) { Tensor::InsideDescribe::Region totalSlice; totalSlice.src.offset = 0; diff --git a/source/core/TensorUtils.hpp b/source/core/TensorUtils.hpp index a577fea05f..1342a669bd 100644 --- a/source/core/TensorUtils.hpp +++ b/source/core/TensorUtils.hpp @@ -184,7 +184,6 @@ class MNN_PUBLIC TensorUtils { static void setupTensorInfo(const Tensor* tensor, Tensor* wrapTensor, MNN_DATA_FORMAT mMidFormat); static Tensor::InsideDescribe::Region makeFullSlice(Tensor* input); - static void makeFullRef(Tensor* output, Tensor* input); static bool regionIsFull(Tensor* input); static bool isCopyRegion(const Tensor::InsideDescribe::Region& region); static bool isTransposeRegion(const 
Tensor::InsideDescribe::Region& region); diff --git a/source/geometry/GeometryComputerUtils.cpp b/source/geometry/GeometryComputerUtils.cpp index 85f64de55d..01a4e02ea2 100644 --- a/source/geometry/GeometryComputerUtils.cpp +++ b/source/geometry/GeometryComputerUtils.cpp @@ -477,9 +477,9 @@ std::shared_ptr GeometryComputerUtils::makeBinary(int type, Tensor* inp return cmdP; } -std::shared_ptr GeometryComputerUtils::makeReduce(ReductionType type, Tensor* input0, Tensor* output, int axis) { +std::shared_ptr GeometryComputerUtils::makeReduce(ReductionType type, Tensor* input0, Tensor* output) { flatbuffers::FlatBufferBuilder builder(DEFAULT_ALLOCATE_SIZE); - auto vec = builder.CreateVector(std::vector{axis}); + auto vec = builder.CreateVector(std::vector{1}); ReductionParamBuilder builder_(builder); builder_.add_operation(type); builder_.add_keepDims(true); diff --git a/source/geometry/GeometryComputerUtils.hpp b/source/geometry/GeometryComputerUtils.hpp index 97c4d5811f..c0dffdcdb1 100644 --- a/source/geometry/GeometryComputerUtils.hpp +++ b/source/geometry/GeometryComputerUtils.hpp @@ -18,7 +18,7 @@ class GeometryComputerUtils { static void addConvert(const CommandBuffer& srcBuffer, CommandBuffer& dstBuffer, GeometryComputer::Context& ctx); static std::shared_ptr makeCommand(flatbuffers::FlatBufferBuilder& builder, const std::vector& inputs, const std::vector& outputs); static std::shared_ptr makeBinary(int type, Tensor* input0, Tensor* input1, Tensor* output); - static std::shared_ptr makeReduce(ReductionType type, Tensor* input0, Tensor* output, int axis = 1); + static std::shared_ptr makeReduce(ReductionType type, Tensor* input0, Tensor* output); static std::shared_ptr makeUnary(UnaryOpOperation type, Tensor* input0, Tensor* output); static std::shared_ptr makeLayerNorm(Tensor* input0, Tensor* output, std::vector axis, float epsilon, std::vector gamma, std::vector beta, std::vector external, int group = 1, bool useRMS = false); static std::shared_ptr 
makeMatMul(Tensor* input0, Tensor* input1, Tensor* output, Tensor* Bias = nullptr, diff --git a/source/geometry/GeometryReduce.cpp b/source/geometry/GeometryReduce.cpp index 855f4bcf69..c2a3bb4114 100644 --- a/source/geometry/GeometryReduce.cpp +++ b/source/geometry/GeometryReduce.cpp @@ -10,83 +10,6 @@ #include "geometry/GeometryComputerUtils.hpp" #include "core/OpCommonUtils.hpp" namespace MNN { -static std::vector> _computeReduceDims(const std::vector& inputs, - std::vector& axises) { - - auto totalSize = TensorUtils::getRawSize(inputs[0]); - if (axises.empty()) { - return {std::make_tuple(1, totalSize, 1)}; - } - for (int i = 0; i < axises.size(); ++i) { - if (axises[i] < 0) { - if (axises[i] < 0) { - return {std::make_tuple(1, totalSize, 1)}; - } - } - } - // Cache for input's dims - std::vector lengths(inputs[0]->dimensions()); - for (int i = 0; i < lengths.size(); ++i) { - lengths[i] = inputs[0]->length(i); - } - std::vector> groupAxises; - { - // Merge adj axis - std::sort(axises.begin(), axises.end()); - int lastAxis = axises[0]; - int length = 1; - int start = axises[0]; - for (int i = 1; i < axises.size(); ++i) { - // MNN_PRINT("%d - %d\n", axises[i], lastAxis); - if (axises[i] - lastAxis == 1) { - length++; - } else { - groupAxises.emplace_back(std::make_pair(start, length)); - length = 1; - start = axises[i]; - } - lastAxis = axises[i]; - } - groupAxises.emplace_back(std::make_pair(start, length)); - } - - // Compute inside-outside-axis - std::vector> result; - - for (int i = 0; i < groupAxises.size(); ++i) { - int outsideSize = 1; - int insideSize = 1; - int axisSize = 1; - auto start = groupAxises[i].first; - auto length = groupAxises[i].second; - if (start >= (int)lengths.size()) { - break; - } - for (int j = 0; j < start; ++j) { - outsideSize *= lengths[j]; - } - for (int j = start; j < start + length; ++j) { - if (j >= (int)lengths.size()) { - break; - } - axisSize *= lengths[j]; - lengths[j] = 1; - } - for (int j = start + length; j < 
lengths.size(); ++j) { - insideSize *= lengths[j]; - } - if (1 == axisSize) { - continue; - } - result.emplace_back(std::make_tuple(outsideSize, axisSize, insideSize)); - } - // FUNC_PRINT(result.size()); - if (result.empty()) { - result.emplace_back(std::make_tuple(1, 1, totalSize)); - } - return result; -} - class GeometryReduce : public GeometryComputer { public: virtual bool onCompute(const Op* op, const std::vector& inputs, const std::vector& outputs, @@ -95,31 +18,6 @@ class GeometryReduce : public GeometryComputer { MNN_ASSERT(inputs.size() >= 1); auto reduct = op->main_as_ReductionParam(); auto reductOp = reduct->operation(); - std::vector axises; - if (inputs.size() >= 2) { - auto size = inputs[1]->elementSize(); - auto dims = inputs[1]->host(); - for (int i = 0; i < size; ++i) { - axises.emplace_back(dims[i]); - } - } else { - auto reduct = op->main_as_ReductionParam(); - if (nullptr != reduct->dim()) { - for (int i = 0; i < reduct->dim()->size(); ++i) { - axises.emplace_back(reduct->dim()->data()[i]); - } - } - } - for (int i = 0; i < axises.size(); ++i) { - if (axises[i] < 0) { - axises[i] = inputs[0]->dimensions() + axises[i]; - } - } - if (1 == axises.size() && TensorUtils::getDescribe(inputs[0])->dimensionFormat != MNN_DATA_FORMAT_NC4HW4 && TensorUtils::getDescribe(outputs[0])->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { - auto cmd = GeometryComputerUtils::makeReduce(reductOp, inputs[0], outputs[0], axises[0]); - res.command.emplace_back(std::move(cmd)); - return true; - } // prod([]) = 1 if (inputs[0]->elementSize() == 0) { if(!context.allocTensor(outputs[0])) { @@ -141,7 +39,7 @@ class GeometryReduce : public GeometryComputer { } return true; } - auto reduceDims = _computeReduceDims(inputs, axises); + auto reduceDims = OpCommonUtils::computeReduceDims(inputs, op); Tensor* currentInput = inputs[0]; MNN_ASSERT(reduceDims.size() > 0); auto dimType = currentInput->getDimensionType(); diff --git a/source/geometry/GeometryReshape.cpp 
b/source/geometry/GeometryReshape.cpp index 1df3384e37..88d98a24c9 100644 --- a/source/geometry/GeometryReshape.cpp +++ b/source/geometry/GeometryReshape.cpp @@ -42,7 +42,8 @@ class GeometryReshape : public GeometryComputer { return true; } } - TensorUtils::makeFullRef(output, input); + outputDes->regions = {TensorUtils::makeFullSlice(input)}; + outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; return true; } }; @@ -74,7 +75,10 @@ class SingleGeometryComputer : public GeometryComputer { Context& context, CommandBuffer& res) const override { auto input = inputs[0]; auto output = outputs[0]; - TensorUtils::makeFullRef(output, input); + auto inputDes = TensorUtils::getDescribe(input); + auto outputDes = TensorUtils::getDescribe(output); + outputDes->regions = {TensorUtils::makeFullSlice(input)}; + outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; return true; } }; @@ -90,7 +94,8 @@ class CopyGeometryComputer : public GeometryComputer { outputDes->tensorArrayAttr = inputDes->tensorArrayAttr; return true; } - TensorUtils::makeFullRef(output, input); + outputDes->regions = {TensorUtils::makeFullSlice(input)}; + outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; } return true; } diff --git a/source/math/Vec.hpp b/source/math/Vec.hpp index cc9354a7f1..6839ab83b0 100644 --- a/source/math/Vec.hpp +++ b/source/math/Vec.hpp @@ -372,7 +372,8 @@ struct Vec { using VecType = Vec; using VecTypeInt32 = Vec; float32x4_t value; - Vec() = default; + Vec() { + } Vec(const float v) { value = vdupq_n_f32(v); } diff --git a/test/core/ThreadPoolTest.cpp b/test/core/ThreadPoolTest.cpp index e010939e5f..6886f86e62 100644 --- a/test/core/ThreadPoolTest.cpp +++ b/test/core/ThreadPoolTest.cpp @@ -26,11 +26,11 @@ class ThreadPoolTest : public MNNTestCase { auto workIndex = threadPool->acquireWorkIndex(); FUNC_PRINT(workIndex); threadPool->active(); - ThreadPool::TASK task = std::make_pair([](int index) { + auto func = [](int index) { FUNC_PRINT(index); 
std::this_thread::yield(); - }, 10); - threadPool->enqueue(&task, workIndex); + }; + threadPool->enqueue(std::make_pair(std::move(func), 10), workIndex); threadPool->deactive(); threadPool->releaseWorkIndex(workIndex); }); diff --git a/tools/cpp/ExprDebug.hpp b/tools/cpp/ExprDebug.hpp index 49e3db6156..167e97c562 100644 --- a/tools/cpp/ExprDebug.hpp +++ b/tools/cpp/ExprDebug.hpp @@ -1,7 +1,6 @@ #include #include #include -#include #include #include #define DUMP_NUM_DATA(type) \ @@ -136,69 +135,29 @@ static void _initDebug() { struct TimeTraceInfo { - std::map>> mTypes; + std::map>>> mTypes; void begin(const MNN::OperatorInfo* info) { auto tIter = mTypes.find(info->type()); if (tIter == mTypes.end()) { - std::map> _t; + std::map>> _t; mTypes.insert(std::make_pair(info->type(), _t)); tIter = mTypes.find(info->type()); } mInserIter = tIter->second.find(info->name()); if (mInserIter == tIter->second.end()) { - tIter->second.insert(std::make_pair(info->name(), std::make_tuple(0.0f, 0.0f, 0))); + std::vector> _t; + tIter->second.insert(std::make_pair(info->name(), _t)); mInserIter = tIter->second.find(info->name()); } mTimer.reset(); } void end(const MNN::OperatorInfo* info) { auto timeInMs = (float)mTimer.durationInUs() / 1000.0f; - std::get<0>(mInserIter->second) += timeInMs; - std::get<1>(mInserIter->second) += info->flops(); - std::get<2>(mInserIter->second) ++; - } - void dump(bool dumpPerOp = false) { - if (dumpPerOp) { - auto cmp = [](const std::tuple& first, const std::tuple& second) { - return std::get<1>(first) > std::get<1>(second); - }; - std::priority_queue, std::vector>, decltype(cmp)> que(cmp); - for (auto& iter : mTypes) { - for (auto& t : iter.second) { - auto mergeType = t.first + " ["+iter.first +"]"; - auto unit = std::make_tuple(mergeType, std::get<0>(t.second), std::get<1>(t.second), std::get<2>(t.second)); - que.push(unit); - } - } - while (!que.empty()) { - auto& t = que.top(); - MNN_PRINT("%s : %.7f ms, FLOP: %.7f, COUNT: %d, Speed: %.7f 
GFlops\n", std::get<0>(t).c_str(), std::get<1>(t), std::get<2>(t), std::get<3>(t), std::get<2>(t) / std::get<1>(t)); - que.pop(); - } - return; - } - float opSummer = 0.0f; - float opFlopsSummber = 0.0f; - for (auto& iter : mTypes) { - float summer = 0.0f; - float summerflops = 0.0f; - int count = 0; - for (auto& t : iter.second) { - summer += std::get<0>(t.second); - summerflops += std::get<1>(t.second); - count += std::get<2>(t.second); - } - MNN_PRINT("%s : %.7f ms, FLOP: %.7f, COUNT: %d, Speed: %.7f GFlops\n", iter.first.c_str(), summer, summerflops, count, - summerflops / summer); - opSummer += summer; - opFlopsSummber += summerflops; - } - MNN_PRINT("OP Summer: %.7f ms, Flops: %.7f, Speed: %.7f GFlops\n", opSummer, opFlopsSummber, - opFlopsSummber / opSummer); + mInserIter->second.emplace_back(std::make_pair(timeInMs, info->flops())); } private: - std::map>::iterator mInserIter; + std::map>>::iterator mInserIter; MNN::Timer mTimer; }; static TimeTraceInfo* gTimeTraceInfo = nullptr; diff --git a/tools/cpp/ModuleBasic.cpp b/tools/cpp/ModuleBasic.cpp index 5798bc6d26..90fa6b80d3 100644 --- a/tools/cpp/ModuleBasic.cpp +++ b/tools/cpp/ModuleBasic.cpp @@ -499,13 +499,10 @@ int main(int argc, char *argv[]) { if (runTime > 0) { int t = runTime; + std::vector times(t, 0.0f); if (runMask & 4) { _initTimeTrace(); } - float minTime = std::numeric_limits::max(); - float maxTime = 0.0f; - float sum = 0.0f; - for (int i = 0; i < t; ++i) { Timer _l; auto out = net->onForward(inputs); @@ -513,28 +510,41 @@ int main(int argc, char *argv[]) { for (auto o : out) { ((MNN::Tensor*)o->getTensor())->wait(MNN::Tensor::MAP_TENSOR_READ, true); } - auto time = _l.durationInUs() / 1000.0f; + times[i] = _l.durationInUs() / 1000.0f; if (freq > 0.0f) { - float remainMs = (1000.0f / freq) - time; + float remainMs = (1000.0f / freq) - times[i]; if (remainMs > 0.0f) { std::this_thread::sleep_for(std::chrono::milliseconds((int)remainMs)); } } - if (maxTime < time) { - maxTime = time; - } - if 
(minTime > time) { - minTime = time; - } - sum += time; } if (nullptr != gTimeTraceInfo) { - MNN_PRINT("Per Op Trace: \n"); - gTimeTraceInfo->dump(true); - MNN_PRINT("Per Type Trace: \n"); - gTimeTraceInfo->dump(false); + float opSummer = 0.0f; + float opFlopsSummber = 0.0f; + for (auto& iter : gTimeTraceInfo->mTypes) { + float summer = 0.0f; + float summerflops = 0.0f; + for (auto& t : iter.second) { + for (auto& t0 : t.second) { + summer += t0.first; + summerflops += t0.second; + } + } + summer = summer / (float)t; + summerflops = summerflops / (float)t; + MNN_PRINT("%s : %.7f, FLOP: %.7f, Speed: %.7f GFlops\n", iter.first.c_str(), summer, summerflops, summerflops / summer); + opSummer += summer; + opFlopsSummber+= summerflops; + } + MNN_PRINT("OP Summer: %.7f, Flops: %.7f, Speed: %.7f GFlops\n", opSummer, opFlopsSummber, opFlopsSummber/opSummer); + } + auto minTime = std::min_element(times.begin(), times.end()); + auto maxTime = std::max_element(times.begin(), times.end()); + float sum = 0.0f; + for (auto time : times) { + sum += time; } - MNN_PRINT("Avg= %f ms, min= %f ms, max= %f ms\n", sum / (float)t, minTime, maxTime); + MNN_PRINT("Avg= %f ms, min= %f ms, max= %f ms\n", sum / (float)t, *minTime, *maxTime); } rtmgr->updateCache(); return 0; diff --git a/transformers/diffusion/export/onnx_export.py b/transformers/diffusion/export/onnx_export.py index 5516eb2fcc..21f05e83be 100644 --- a/transformers/diffusion/export/onnx_export.py +++ b/transformers/diffusion/export/onnx_export.py @@ -84,7 +84,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F num_tokens = pipeline.text_encoder.config.max_position_embeddings text_hidden_size = pipeline.text_encoder.config.hidden_size text_input = pipeline.tokenizer( - ["A sample prompt", "A sample prompt"], + "A sample prompt", padding="max_length", max_length=pipeline.tokenizer.model_max_length, truncation=True, @@ -97,7 +97,9 @@ def convert_models(model_path: str, output_path: str, opset: 
int, fp16: bool = F output_path=output_path / "text_encoder" / "model.onnx", ordered_input_names=["input_ids"], output_names=["last_hidden_state", "pooler_output"], - dynamic_axes=None, + dynamic_axes={ + "input_ids": {0: "batch", 1: "sequence"}, + }, opset=opset, ) del pipeline.text_encoder @@ -115,9 +117,13 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F # False, ), output_path=unet_path, - ordered_input_names=["sample", "timestep", "encoder_hidden_states"], + ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], output_names=["out_sample"], # has to be different from "sample" for correct tracing - dynamic_axes=None, + dynamic_axes={ + "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + "timestep": {0: "batch"}, + "encoder_hidden_states": {0: "batch", 1: "sequence"}, + }, opset=opset, use_external_data_format=True, # UNet is > 2GB, so the weights need to be split ) @@ -143,7 +149,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F vae_in_channels = vae_encoder.config.in_channels vae_sample_size = vae_encoder.config.sample_size # need to get the raw tensor output (sample) from the encoder - vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].mode() + vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample() onnx_export( vae_encoder, model_args=( @@ -153,24 +159,30 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path=output_path / "vae_encoder" / "model.onnx", ordered_input_names=["sample", "return_dict"], output_names=["latent_sample"], - dynamic_axes=None, + dynamic_axes={ + "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, opset=opset, ) # VAE DECODER vae_decoder = pipeline.vae vae_latent_channels = vae_decoder.config.latent_channels + vae_out_channels = vae_decoder.config.out_channels # forward only 
through the decoder part - vae_decoder.forward = lambda latent: vae_decoder.decode(latent, return_dict=False)[0] + vae_decoder.forward = vae_encoder.decode onnx_export( vae_decoder, model_args=( torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), + False, ), output_path=output_path / "vae_decoder" / "model.onnx", - ordered_input_names=["latent_sample"], + ordered_input_names=["latent_sample", "return_dict"], output_names=["sample"], - dynamic_axes=None, + dynamic_axes={ + "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, opset=opset, ) del pipeline.vae diff --git a/transformers/llm/engine/demo/llm_demo.cpp b/transformers/llm/engine/demo/llm_demo.cpp index ec0f39c146..305ef2169b 100644 --- a/transformers/llm/engine/demo/llm_demo.cpp +++ b/transformers/llm/engine/demo/llm_demo.cpp @@ -135,21 +135,21 @@ static int benchmark(Llm* llm, const std::vector& prompts, int max_ if (context->audio_input_s > 0.0f) { audio_speed = context->audio_input_s / audio_s; } - MNN_PRINT("\n#################################\n"); - MNN_PRINT("prompt tokens num = %d\n", prompt_len); - MNN_PRINT("decode tokens num = %d\n", decode_len); - MNN_PRINT(" vision time = %.2f s\n", vision_s); - MNN_PRINT(" pixels_mp = %.2f MP\n", context->pixels_mp); - MNN_PRINT(" audio process time = %.2f s\n", audio_s); - MNN_PRINT(" audio input time = %.2f s\n", context->audio_input_s); - MNN_PRINT("prefill time = %.2f s\n", prefill_s); - MNN_PRINT(" decode time = %.2f s\n", decode_s); - MNN_PRINT(" sample time = %.2f s\n", sample_s); - MNN_PRINT("prefill speed = %.2f tok/s\n", prompt_len / prefill_s); - MNN_PRINT(" decode speed = %.2f tok/s\n", decode_len / decode_s); - MNN_PRINT(" vision speed = %.3f MP/s\n", vision_speed); - MNN_PRINT(" audio RTF = %.3f \n", audio_s / context->audio_input_s); - MNN_PRINT("##################################\n"); + printf("\n#################################\n"); + printf("prompt tokens num = %d\n", 
prompt_len); + printf("decode tokens num = %d\n", decode_len); + printf(" vision time = %.2f s\n", vision_s); + printf(" pixels_mp = %.2f MP\n", context->pixels_mp); + printf(" audio process time = %.2f s\n", audio_s); + printf(" audio input time = %.2f s\n", context->audio_input_s); + printf("prefill time = %.2f s\n", prefill_s); + printf(" decode time = %.2f s\n", decode_s); + printf(" sample time = %.2f s\n", sample_s); + printf("prefill speed = %.2f tok/s\n", prompt_len / prefill_s); + printf(" decode speed = %.2f tok/s\n", decode_len / decode_s); + printf(" vision speed = %.3f MP/s\n", vision_speed); + printf(" audio RTF = %.3f \n", audio_s / context->audio_input_s); + printf("##################################\n"); return 0; } @@ -165,12 +165,12 @@ static int ceval(Llm* llm, const std::vector& lines, std::string fi prompt += "\nC. " + elements[4]; prompt += "\nD. " + elements[5]; prompt += "\n\n"; - MNN_PRINT("%s", prompt.c_str()); - MNN_PRINT("## 进度: %d / %lu\n", i, lines.size() - 1); + printf("%s", prompt.c_str()); + printf("## 进度: %d / %lu\n", i, lines.size() - 1); std::ostringstream lineOs; llm->response(prompt.c_str(), &lineOs); auto line = lineOs.str(); - MNN_PRINT("%s", line.c_str()); + printf("%s", line.c_str()); answers.push_back(line); } { diff --git a/transformers/llm/engine/include/llm/llm.hpp b/transformers/llm/engine/include/llm/llm.hpp index 20eff94be9..6ae61a5e35 100644 --- a/transformers/llm/engine/include/llm/llm.hpp +++ b/transformers/llm/engine/include/llm/llm.hpp @@ -59,13 +59,6 @@ enum TuneType { // op encoder number for commit OP_ENCODER_NUMBER = 0, }; -enum class LlmStatus { - RUNNING = 0, - NORMAL_FINISHED = 1, - MAX_TOKENS_FINISHED = 2, - USER_CANCEL = 3, - INTERNAL_ERROR = 4, -}; enum class MatchStrictLevel : int; enum class NgramSelectRule : int; @@ -91,8 +84,6 @@ struct LlmContext { std::vector history_tokens; std::vector output_tokens; std::string generate_str; - // llm status - LlmStatus status; }; struct GenerationParams; class 
MNN_PUBLIC Llm { diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp index c0cabd4414..53af11239a 100644 --- a/transformers/llm/engine/src/llm.cpp +++ b/transformers/llm/engine/src/llm.cpp @@ -467,7 +467,6 @@ std::vector Llm::forwardRaw(Express::VARP hiddenState, Express::V std::vector outputs = selectModule->onForward(inputs); if (outputs.empty()) { - mContext->status = LlmStatus::INTERNAL_ERROR; return outputs; } if (!mAsync) { @@ -593,9 +592,6 @@ std::vector Llm::forwardVec(MNN::Express::VARP input_embeds) { auto attention_mask = gen_attention_mask(blockSize); auto position_ids = gen_position_ids(blockSize); logits = forwardRaw(embed, attention_mask, position_ids); - if(logits.empty()) { - return logits; - } updateContext(blockSize, 0); } bool hasPad = false; @@ -627,9 +623,6 @@ std::vector Llm::forwardVec(MNN::Express::VARP input_embeds) { auto attention_mask = gen_attention_mask(forwardSize); auto position_ids = gen_position_ids(forwardSize); logits = forwardRaw(input_embeds, attention_mask, position_ids); - if(logits.empty()) { - return logits; - } } updateContext(-blockSize * blockNumber, 0); if (hasPad) { @@ -683,7 +676,6 @@ void Llm::generate_init(std::ostream* os, const char* end_with) { mContext->decode_us = 0; mContext->current_token = -1; mContext->sample_us = 0; - mContext->status = LlmStatus::RUNNING; if (!mConfig->reuse_kv()) { mContext->all_seq_len = 0; mContext->history_tokens.clear(); @@ -832,7 +824,6 @@ std::vector Llm::generate(MNN::Express::VARP input_embeds, int max_tokens) Timer _t; forwardVec(input_embeds); if(mGenerateParam->outputs.size() < 1) { - mContext->status = LlmStatus::INTERNAL_ERROR; return {}; } updateContext(seqLen, 0); @@ -924,7 +915,26 @@ Llm::Llm(std::shared_ptr config) : mConfig(config) { Llm::~Llm() { #if DEBUG_MODE == 1 if (nullptr != gTimeTraceInfo) { - gTimeTraceInfo->dump(); + float opSummer = 0.0f; + float opFlopsSummber = 0.0f; + for (auto& iter : gTimeTraceInfo->mTypes) { + float 
summer = 0.0f; + float summerflops = 0.0f; + for (auto& t : iter.second) { + for (auto& t0 : t.second) { + summer += t0.first; + summerflops += t0.second; + } + } + summer = summer; + summerflops = summerflops; + MNN_PRINT("%s : %.7f, FLOP: %.7f, Speed: %.7f GFlops\n", iter.first.c_str(), summer, summerflops, + summerflops / summer); + opSummer += summer; + opFlopsSummber += summerflops; + } + MNN_PRINT("OP Summer: %.7f, Flops: %.7f, Speed: %.7f GFlops\n", opSummer, opFlopsSummber, + opFlopsSummber / opSummer); } #endif mGenerateParam.reset(); @@ -1141,14 +1151,7 @@ VARP Llm::gen_position_ids(int seq_len) { } bool Llm::is_stop(int token_id) { - if (mContext->status == LlmStatus::USER_CANCEL || mContext->status == LlmStatus::INTERNAL_ERROR) { - return true; - } - bool stop = mTokenizer->is_stop(token_id); - if (stop) { - mContext->status = LlmStatus::NORMAL_FINISHED; - } - return stop; + return mTokenizer->is_stop(token_id); } } // namespace Transformer } // namespace MNN diff --git a/transformers/llm/engine/src/speculative_decoding/eagle.cpp b/transformers/llm/engine/src/speculative_decoding/eagle.cpp index 15548c64d3..b4c892fd97 100644 --- a/transformers/llm/engine/src/speculative_decoding/eagle.cpp +++ b/transformers/llm/engine/src/speculative_decoding/eagle.cpp @@ -328,22 +328,9 @@ void EagleGeneration::generate(GenerationParams& param) { std::vector accpetLens; auto newTokens = 0, steps = 0; while (true) { - if(mContext->status == LlmStatus::USER_CANCEL) { - break; - } steps++; MNN::Timer _dt; auto decodingInfo = treeDecoding(draftInfo); - for (auto o : decodingInfo) { - if(nullptr == o->readMap()) { - mContext->status = LlmStatus::INTERNAL_ERROR; - break; - } - } - if(decodingInfo.empty()) { - break; - } - treeDecodingTime += _dt.durationInUs(); auto acceptInfo = evaluatePosterior(draftInfo, decodingInfo[0]); newTokens += acceptInfo.acceptTokens.size(); @@ -365,9 +352,6 @@ void EagleGeneration::generate(GenerationParams& param) { eagleGenerateTime += 
_gt.durationInUs(); } mContext->decode_us += _t.durationInUs(); - if(newTokens >= param.max_new_tokens) { - mContext->status = LlmStatus::MAX_TOKENS_FINISHED; - } #if EAGLE_DEBUG printf("\n### Tree Decoding Time: %f s, Eagle Generate Time: %f s\n", (float)treeDecodingTime / 1000000.0, (float)eagleGenerateTime / 1000000.0); printf("\n### Tree Decoding Avg Time: %f ms, steps: %d\n", (float)treeDecodingTime / 1000.0 / steps, steps); diff --git a/transformers/llm/engine/src/speculative_decoding/generate.cpp b/transformers/llm/engine/src/speculative_decoding/generate.cpp index 4ed01b1f5c..31d3a3b9f7 100644 --- a/transformers/llm/engine/src/speculative_decoding/generate.cpp +++ b/transformers/llm/engine/src/speculative_decoding/generate.cpp @@ -43,9 +43,6 @@ void ArGeneration::generate(GenerationParams& param) { int max_token = param.max_new_tokens; int len = 0; while (len < max_token) { - if(mContext->status == LlmStatus::USER_CANCEL) { - break; - } AUTOTIME; // Update gen seq mContext->current_token = mLlm->sample(param.outputs[0], param.validLogitStart, param.validLogitSize); @@ -66,14 +63,9 @@ void ArGeneration::generate(GenerationParams& param) { *mContext->os << decodeStr; *mContext->os << std::flush; } + // Compute Next Logits auto outputs = mLlm->forwardVec({mContext->current_token}); - for (auto o : outputs) { - if(nullptr == o->readMap()) { - mContext->status = LlmStatus::INTERNAL_ERROR; - break; - } - } if(outputs.empty()) { break; } @@ -82,9 +74,6 @@ void ArGeneration::generate(GenerationParams& param) { mContext->decode_us += _t.durationInUs(); len++; } - if(len >= max_token) { - mContext->status = LlmStatus::MAX_TOKENS_FINISHED; - } } int Generation::draftVerify(VARP logits, const std::vector &drafts, bool& stop) { diff --git a/transformers/llm/engine/src/speculative_decoding/lookahead.cpp b/transformers/llm/engine/src/speculative_decoding/lookahead.cpp index d8ce38037e..cf4c2a5c79 100644 --- a/transformers/llm/engine/src/speculative_decoding/lookahead.cpp 
+++ b/transformers/llm/engine/src/speculative_decoding/lookahead.cpp @@ -89,9 +89,6 @@ void LookaheadGeneration::generate(GenerationParams& param) { int verify_len = mLlm->mDraftLength + 1; while (len < max_token) { - if(mContext->status == LlmStatus::USER_CANCEL) { - break; - } MNN::Timer _t; std::vector drafts; drafts.push_back(mContext->current_token); @@ -129,12 +126,6 @@ void LookaheadGeneration::generate(GenerationParams& param) { AUTOTIME; // do draft token parallel verify auto outputs = mLlm->forwardVec(drafts); - for (auto o : outputs) { - if(nullptr == o->readMap()) { - mContext->status = LlmStatus::INTERNAL_ERROR; - break; - } - } if(outputs.empty()) { break; } @@ -201,9 +192,6 @@ void LookaheadGeneration::generate(GenerationParams& param) { } } } - if(len >= max_token) { - mContext->status = LlmStatus::MAX_TOKENS_FINISHED; - } #ifdef DUMP_PROFILE_INFO // adopt speculative decoding rate float spl_rate = 100.0 * spl_count / (spl_count + arg_count); diff --git a/transformers/llm/engine/src/speculative_decoding/mtp.cpp b/transformers/llm/engine/src/speculative_decoding/mtp.cpp index f5c6e0261a..aefc4a5aa7 100644 --- a/transformers/llm/engine/src/speculative_decoding/mtp.cpp +++ b/transformers/llm/engine/src/speculative_decoding/mtp.cpp @@ -151,9 +151,6 @@ void MtpGeneration::generate(GenerationParams& param) { int spl_count = 0; while (len < max_token) { - if(mContext->status == LlmStatus::USER_CANCEL) { - break; - } MNN::Timer _t; std::vector drafts; drafts.push_back(mContext->current_token); @@ -174,12 +171,6 @@ void MtpGeneration::generate(GenerationParams& param) { AUTOTIME; // do draft token parallel verify auto outputs = mLlm->forwardVec(drafts); - for (auto o : outputs) { - if(nullptr == o->readMap()) { - mContext->status = LlmStatus::INTERNAL_ERROR; - break; - } - } if (outputs.size() < 2) { break; } @@ -247,9 +238,6 @@ void MtpGeneration::generate(GenerationParams& param) { } } } - if(len >= max_token) { - mContext->status = 
LlmStatus::MAX_TOKENS_FINISHED; - } #ifdef DUMP_PROFILE_INFO // draft accept rate if adopt speculative decoding float spl_accept_rate = 100.0 * spl_accept / spl_decode; From 03835a171a772c0f11dacd16719a26b3d5b0a555 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:36:21 +0800 Subject: [PATCH 062/314] Merge pull request #4067 from ihb2032/opt/rvv-pixel-conv opt(RVV): Optimize blitter functions with intrinsics GitOrigin-RevId: a22d2d445a0d106f5c9201cbedd49c7b168225c6 --- source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp | 18 +++++++++++++++++ .../backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp | 13 ++++++++++++ source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp | 16 +++++++++++++++ source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp | 17 ++++++++++++++++ .../backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp | 20 +++++++++++++++++++ .../backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp | 17 ++++++++++++++++ source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp | 20 +++++++++++++++++++ 11 files changed, 201 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp diff 
--git a/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp new file mode 100644 index 0000000000..145cbea73f --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp @@ -0,0 +1,18 @@ +#include + +void MNNBGRAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp new file mode 100644 index 0000000000..d46fe6c85b --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNBGRAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp new file mode 100644 index 0000000000..684db6aed3 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNBRGToGRAY(const 
unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, result, vl); + i += vl; + } +} \ No newline at end of file diff --git a/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp b/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp new file mode 100644 index 0000000000..9d524f13ca --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp @@ -0,0 +1,20 @@ +#include + +void MNNC3ToC4(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); + + vuint8m8_t alpha = __riscv_vmv_v_x_u8m8(255, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, alpha, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp new file mode 100644 index 0000000000..952fcaf090 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp @@ -0,0 +1,13 @@ +#include + +void MNNGRAYToC3(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + 
vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); + __riscv_vsse8_v_u8m8(dest + i * 3 + 0, 3, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 3 + 1, 3, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 3 + 2, 3, gray, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp new file mode 100644 index 0000000000..5ee4540f98 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp @@ -0,0 +1,16 @@ +#include + +void MNNGRAYToC4(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); + vuint8m8_t alpha = __riscv_vmv_v_x_u8m8(255, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 0, 4, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 1, 4, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 2, 4, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 3, 4, alpha, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp new file mode 100644 index 0000000000..f2b6c7a78d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp @@ -0,0 +1,17 @@ +#include + +void MNNRGBAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp new file mode 100644 index 0000000000..ddd67a7d8c --- /dev/null +++ 
b/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp @@ -0,0 +1,20 @@ +#include + +void MNNRGBAToBGRA(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 3, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp new file mode 100644 index 0000000000..d56b58546d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNRGBAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp new file mode 100644 index 0000000000..7c6decf39e --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp @@ -0,0 +1,17 @@ +#include + +void MNNRGBToBGR(const unsigned char* source, unsigned char* dest, size_t count) 
{ + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp new file mode 100644 index 0000000000..1b946c33cc --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNRGBToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, result, vl); + i += vl; + } +} From 0ed8746c597f1cbd230767d52265f0cce516b4e0 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:41:13 +0800 Subject: [PATCH 063/314] Merge pull request #4053 from ihb2032/opt/rvv-resize-functions opt(RVV): Optimize resize functions with intrinsics GitOrigin-RevId: 824f1b9ad56f611613f801eaa7e1c2ae2d3fd307 --- .../cpu/riscv/rvv/CPUBilinearLineC4.cpp | 19 +++++ .../cpu/riscv/rvv/CPUBilinearSampleC4.cpp | 33 ++++++++ .../cpu/riscv/rvv/MNNBilinearLineC8.cpp | 40 ++++++++++ .../cpu/riscv/rvv/MNNBilinearSampleC8.cpp | 49 ++++++++++++ .../backend/cpu/riscv/rvv/MNNCubicLineC16.cpp | 53 
+++++++++++++ .../backend/cpu/riscv/rvv/MNNCubicLineC4.cpp | 38 +++++++++ .../cpu/riscv/rvv/MNNCubicSampleC16.cpp | 79 +++++++++++++++++++ .../cpu/riscv/rvv/MNNCubicSampleC4.cpp | 62 +++++++++++++++ 8 files changed, 373 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp diff --git a/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp b/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp new file mode 100644 index 0000000000..a700016c31 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp @@ -0,0 +1,19 @@ +#include + +void CPUBilinearLineC4(float* dst, const float* A, const float* B, + const float* t, int8_t* zeroPoint, size_t number) { + float tf = *t; + float sf = 1.0f - tf; + size_t total = number << 2; + + size_t i = 0; + while (i < total) { + size_t vl = __riscv_vsetvl_e32m8(total - i); + vfloat32m8_t v = __riscv_vle32_v_f32m8(A + i, vl); + vfloat32m8_t result = __riscv_vfmul_vf_f32m8(v, sf, vl); + v = __riscv_vle32_v_f32m8(B + i, vl); + result = __riscv_vfmacc_vf_f32m8(result, tf, v, vl); + __riscv_vse32_v_f32m8(dst + i, result, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp b/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp new file mode 100644 index 0000000000..5063c39bff --- /dev/null +++ b/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp @@ -0,0 +1,33 @@ +#include + +void CPUBilinearSampleC4(const float* src, float* dst, + const int32_t* position, const float* factor, + int8_t* 
zeroPoint, size_t number) { + const int pack = 4; + size_t i = 0; + + while (i < number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + i, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vfloat32m8_t vr = __riscv_vluxei32_v_f32m8(src, voff, vl); + vfloat32m8_t vsf = __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl); + vr = __riscv_vfmul_vv_f32m8(vr, vsf, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vsf = __riscv_vluxei32_v_f32m8(src, voff, vl); + vr = __riscv_vfmacc_vv_f32m8(vr, vf, vsf, vl); + __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, vr, vl); + } + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp new file mode 100644 index 0000000000..a26243bdb8 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp @@ -0,0 +1,40 @@ +#include + +void MNNBilinearLineC8(int8_t* dst, const int16_t* A, const int16_t* B, + const float* t, int8_t* zeroPoint, size_t number) { + int offset = *zeroPoint; + int8_t* dstPtr = dst; + + const int pack = 8; + const int16_t df = (int16_t)((*t) * 128.0f); + const int16_t sf = (int16_t)((1.0f - *t) * 128.0f); + const size_t total = number * pack; + const int32_t ROUND_HALF = 1 << 13; + + size_t vl; + for (size_t i = 0; i < total; i += vl) { + vl = __riscv_vsetvl_e16m4(total - i); + vint16m4_t v16 = __riscv_vle16_v_i16m4(A + i, vl); + vint32m8_t v32 = __riscv_vwmul_vx_i32m8(v16, sf, vl); + v16 = __riscv_vle16_v_i16m4(B + i, vl); + v32 = __riscv_vwmacc_vx_i32m8(v32, df, v16, vl); + + vbool4_t mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); + vint32m8_t tmp = 
__riscv_vadd_vx_i32m8(v32, ROUND_HALF, vl); + v32 = __riscv_vsub_vx_i32m8(v32, ROUND_HALF, vl); + v32 = __riscv_vmerge_vvm_i32m8(tmp, v32, mask, vl); + + tmp = __riscv_vsra_vx_i32m8(v32, 14, vl); + mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); + v32 = __riscv_vand_vx_i32m8(v32, 0x3FFF, vl); + vbool4_t hasRem = __riscv_vmsne_vx_i32m8_b4(v32, 0, vl); + mask = __riscv_vmand_mm_b4(mask, hasRem, vl); + + v32 = __riscv_vadd_vx_i32m8_mu(mask, tmp, tmp, 1, vl); + v32 = __riscv_vadd_vx_i32m8(v32, offset, vl); + v16 = __riscv_vnsra_wx_i16m4(v32, 0, vl); + vint8m2_t v8 = __riscv_vnsra_wx_i8m2(v16, 0, vl); + + __riscv_vse8_v_i8m2(dstPtr + i, v8, vl); + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp new file mode 100644 index 0000000000..bd111e3be4 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp @@ -0,0 +1,49 @@ +#include + +void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, + const int32_t* position, const float* factor, + int8_t* zeroPoint, size_t number) { + int16_t offset = (int16_t)(*zeroPoint); + const int pack = 8; + size_t i = 0; + + while (i < number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + i, vl); + vint16m4_t vdf = __riscv_vnsra_wx_i16m4( + __riscv_vfcvt_rtz_x_f_v_i32m8( + __riscv_vfmul_vf_f32m8(vf, 128.0f, vl), vl), 0, vl); + vint16m4_t vsf = __riscv_vnsra_wx_i16m4( + __riscv_vfcvt_rtz_x_f_v_i32m8( + __riscv_vfmul_vf_f32m8( + __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl), 128.0f, vl), vl), 0, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vadd_vx_u32m8( + __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 3, vl), + c, vl); + + vint16m4_t va = __riscv_vsub_vx_i16m4( + __riscv_vsext_vf2_i16m4( + __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); + + vint32m8_t vr = __riscv_vwmul_vv_i32m8(va, vsf, vl); + voff = 
__riscv_vadd_vx_u32m8( + __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 3, vl), + c, vl); + + vint16m4_t vb = __riscv_vsub_vx_i16m4( + __riscv_vsext_vf2_i16m4( + __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); + vr = __riscv_vwmacc_vv_i32m8(vr, vb, vdf, vl); + __riscv_vsse16_v_i16m4(dst + i * pack + c, 16, + __riscv_vnsra_wx_i16m4(vr, 0, vl), vl); + } + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp b/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp new file mode 100644 index 0000000000..fd6ce7a274 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp @@ -0,0 +1,53 @@ +#include + +void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, + const float* C, const float* D, float* t, + int8_t* zeroPoint, size_t number, + ssize_t minValue, ssize_t maxValue) { + const float f = *t; + const float t2 = f * f, t3 = t2 * f; + const float b0 = 1.0f - 2.25f * t2 + 1.25f * t3; + const float t1 = 1.0f - f, t1_2 = t1 * t1; + const float c0 = 1.0f - 2.25f * t1_2 + 1.25f * t1_2 * t1; + const float ta = 1.0f + f, ta2 = ta * ta; + const float a0 = 3.0f - 6.0f * ta + 3.75f * ta2 - 0.75f * ta2 * ta; + const float td = 2.0f - f, td2 = td * td; + const float d0 = 3.0f - 6.0f * td + 3.75f * td2 - 0.75f * td2 * td; + const int offset = *zeroPoint; + const int minVal = (int)minValue; + const int maxVal = (int)maxValue; + const size_t total = number << 4; + size_t i = 0; + + while (i < total) { + size_t vl = __riscv_vsetvl_e32m8(total - i); + vfloat32m8_t v, acc; + + v = __riscv_vle32_v_f32m8(A + i, vl); + acc = __riscv_vfmul_vf_f32m8(v, a0, vl); + + v = __riscv_vle32_v_f32m8(B + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, b0, v, vl); + + v = __riscv_vle32_v_f32m8(C + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, c0, v, vl); + + v = __riscv_vle32_v_f32m8(D + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, d0, v, vl); + + vfloat32m8_t half = 
__riscv_vfmv_v_f_f32m8(0.5f, vl); + vfloat32m8_t signHalf = __riscv_vfsgnj_vv_f32m8(half, acc, vl); + acc = __riscv_vfadd_vv_f32m8(acc, signHalf, vl); + + vint32m8_t vint = __riscv_vfcvt_rtz_x_f_v_i32m8(acc, vl); + vint = __riscv_vadd_vx_i32m8(vint, offset, vl); + vint = __riscv_vmax_vx_i32m8(vint, minVal, vl); + vint = __riscv_vmin_vx_i32m8(vint, maxVal, vl); + + vint16m4_t vi16 = __riscv_vncvt_x_x_w_i16m4(vint, vl); + vint8m2_t vi8 = __riscv_vncvt_x_x_w_i8m2(vi16, vl); + __riscv_vse8_v_i8m2(dst + i, vi8, vl); + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp b/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp new file mode 100644 index 0000000000..0da63ca0ff --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp @@ -0,0 +1,38 @@ +#include + +void MNNCubicLineC4(float* dst, const float* A, const float* B, + const float* C, const float* D, float* t, + int8_t* zeroPoint, size_t number, + ssize_t minValue, ssize_t maxValue) { + const float f = *t; + const float t2 = f * f, t3 = t2 * f; + const float b0 = 1.0f - 2.25f * t2 + 1.25f * t3; + const float t1 = 1.0f - f, t1_2 = t1 * t1; + const float c0 = 1.0f - 2.25f * t1_2 + 1.25f * t1_2 * t1; + const float ta = 1.0f + f, ta2 = ta * ta; + const float a0 = 3.0f - 6.0f * ta + 3.75f * ta2 - 0.75f * ta2 * ta; + const float td = 2.0f - f, td2 = td * td; + const float d0 = 3.0f - 6.0f * td + 3.75f * td2 - 0.75f * td2 * td; + const size_t total = number << 2; + size_t i = 0; + + while (i < total) { + size_t vl = __riscv_vsetvl_e32m8(total - i); + vfloat32m8_t v, acc; + + v = __riscv_vle32_v_f32m8(A + i, vl); + acc = __riscv_vfmul_vf_f32m8(v, a0, vl); + + v = __riscv_vle32_v_f32m8(B + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, b0, v, vl); + + v = __riscv_vle32_v_f32m8(C + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, c0, v, vl); + + v = __riscv_vle32_v_f32m8(D + i, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, d0, v, vl); + + __riscv_vse32_v_f32m8(dst + i, acc, vl); + i += vl; + } +} diff 
--git a/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp b/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp new file mode 100644 index 0000000000..fd5b24a53d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp @@ -0,0 +1,79 @@ +#include + +void MNNCubicSampleC16(const int8_t* src, float* dst, + int32_t* position, const float* factor, + int8_t* zeroPoint, size_t number) { + const int pack = 16; + int8_t zp = *zeroPoint; + size_t i = 0; + + while (i < number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vint8m2_t vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vint16m4_t vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vfloat32m8_t vtmp = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl); + vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl); + vfloat32m8_t vc = vtmp; + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vfloat32m8_t vB = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vtmp = 
__riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl); + vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl); + vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vtmp = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + va = __riscv_vfmacc_vf_f32m8(va, 0.75f, vtmp, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, -0.75f, vtmp, vl); + + va = __riscv_vfmadd_vv_f32m8(va, vt, vb, vl); + va = __riscv_vfmadd_vv_f32m8(va, vt, vc, vl); + va = __riscv_vfmadd_vv_f32m8(va, vt, vB, vl); + + __riscv_vsse32_v_f32m8(dst + i * pack + c, pack * sizeof(float), va, vl); + } + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp b/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp new file mode 100644 index 0000000000..78207e69e8 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp @@ -0,0 +1,62 @@ +#include + +void MNNCubicSampleC4(const float* src, float* dst, + int32_t* position, const float* factor, + int8_t* zeroPoint, size_t number) { + const int pack = 4; + size_t i = 0; + + while (i < number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vfloat32m8_t vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); + + vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl); + vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl); + 
vfloat32m8_t vc = vtmp; + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vfloat32m8_t vB = __riscv_vluxei32_v_f32m8(src, voff, vl); + + va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); + + va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl); + vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl); + vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); + + va = __riscv_vfmacc_vf_f32m8(va, 0.75f, vtmp, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, -0.75f, vtmp, vl); + + va = __riscv_vfmadd_vv_f32m8(va, vt, vb, vl); + va = __riscv_vfmadd_vv_f32m8(va, vt, vc, vl); + va = __riscv_vfmadd_vv_f32m8(va, vt, vB, vl); + + __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, va, vl); + } + + i += vl; + } +} From 17c0061dfe6561551f1090db14a12fb001e95760 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:41:55 +0800 Subject: [PATCH 064/314] Merge pull request #4050 from ihb2032/opt/rvv-top1 opt(RVV): Optimize top1 functions with intrinsics GitOrigin-RevId: 070c444b927aab4db76297e217bfe92a4508b294 --- .../cpu/riscv/rvv/MNNVectorTop1Float.cpp | 37 +++++++++++++++++++ .../cpu/riscv/rvv/MNNVectorTop1Int32.cpp | 37 +++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp create mode 100644 
source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp b/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp new file mode 100644 index 0000000000..7332360ce8 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp @@ -0,0 +1,37 @@ +#include +#include + +#define UNIT 4 + +void MNNVectorTop1Float(float* input, float* maxValue, int32_t* maxIndex, size_t inputCountUnit) { + size_t n = inputCountUnit * UNIT; + float maxV = -FLT_MAX; + int32_t maxIdx = 0; + size_t vl; + + for (size_t i = 0; i < n; ) { + vl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t data = __riscv_vle32_v_f32m8(input + i, vl); + vfloat32m1_t scalar = __riscv_vfmv_s_f_f32m1(maxV, vl); + vfloat32m1_t result = __riscv_vfredmax_vs_f32m8_f32m1(data, scalar, vl); + maxV = __riscv_vfmv_f_s_f32m1_f32(result); + i += vl; + } + + for (size_t i = 0; i < n; ) { + vl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t data = __riscv_vle32_v_f32m8(input + i, vl); + vbool4_t mask = __riscv_vmfeq_vf_f32m8_b4(data, maxV, vl); + long first = __riscv_vfirst_m_b4(mask, vl); + + if (first >= 0) { + maxIdx = i + first; + break; + } + + i += vl; + } + + maxValue[0] = maxV; + maxIndex[0] = maxIdx; +} diff --git a/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp b/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp new file mode 100644 index 0000000000..8c199709ec --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp @@ -0,0 +1,37 @@ +#include +#include + +#define UNIT 4 + +void MNNVectorTop1Int32(int32_t* input, int32_t* maxValue, int32_t* maxIndex, size_t inputCountUnit) { + size_t n = inputCountUnit * UNIT; + int32_t maxV = INT32_MIN; + int32_t maxIdx = 0; + size_t vl; + + for (size_t i = 0; i < n; ) { + vl = __riscv_vsetvl_e32m8(n - i); + vint32m8_t data = __riscv_vle32_v_i32m8(input + i, vl); + vint32m1_t scalar = __riscv_vmv_s_x_i32m1(maxV, vl); + vint32m1_t result = __riscv_vredmax_vs_i32m8_i32m1(data, scalar, vl); + maxV = 
__riscv_vmv_x_s_i32m1_i32(result); + i += vl; + } + + for (size_t i = 0; i < n; ) { + vl = __riscv_vsetvl_e32m8(n - i); + vint32m8_t data = __riscv_vle32_v_i32m8(input + i, vl); + vbool4_t mask = __riscv_vmseq_vx_i32m8_b4(data, maxV, vl); + long first = __riscv_vfirst_m_b4(mask, vl); + + if (first >= 0) { + maxIdx = i + first; + break; + } + + i += vl; + } + + maxValue[0] = maxV; + maxIndex[0] = maxIdx; +} From 812bb6011eb67c3cbdfb5065ccf63dc91517ebf8 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:42:36 +0800 Subject: [PATCH 065/314] Merge pull request #4044 from ihb2032/opt/rvv-softmax-relu opt(RVV): Optimize Softmax and ReluWithSlopeChannel with intrinsics GitOrigin-RevId: 98d2f9db51b45bf1deda2fb22398e56b323b5ae2 --- .../cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp | 45 +++++++++++ source/backend/cpu/riscv/rvv/MNNSoftmax.cpp | 80 +++++++++++++++++++ 2 files changed, 125 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNSoftmax.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp b/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp new file mode 100644 index 0000000000..262f4cbfab --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp @@ -0,0 +1,45 @@ +#include + +void MNNReluWithSlopeChannel(float *dst, const float *src, + const float *slope, size_t sizeQuad, + size_t depthQuad) { + const ptrdiff_t stride = 4 * sizeof(float); + + for (size_t j = 0; j < depthQuad; ++j) { + const float *srcZ = src + 4 * j * sizeQuad; + float *dstZ = dst + 4 * j * sizeQuad; + float s0 = slope[4*j], s1 = slope[4*j + 1]; + float s2 = slope[4*j + 2], s3 = slope[4*j + 3]; + size_t i = 0; + while (i < sizeQuad) { + size_t vl = __riscv_vsetvl_e32m8(sizeQuad - i); + const float *srcBase = srcZ + 4*i; + float *dstBase = dstZ + 4*i; + + vfloat32m8_t v; + vbool4_t mask; + + v = __riscv_vlse32_v_f32m8(srcBase, stride, vl); + mask = 
__riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s0, vl); + __riscv_vsse32_v_f32m8(dstBase, stride, v, vl); + + v = __riscv_vlse32_v_f32m8(srcBase + 1, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s1, vl); + __riscv_vsse32_v_f32m8(dstBase + 1, stride, v, vl); + + v = __riscv_vlse32_v_f32m8(srcBase + 2, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s2, vl); + __riscv_vsse32_v_f32m8(dstBase + 2, stride, v, vl); + + v = __riscv_vlse32_v_f32m8(srcBase + 3, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s3, vl); + __riscv_vsse32_v_f32m8(dstBase + 3, stride, v, vl); + + i += vl; + } + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp b/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp new file mode 100644 index 0000000000..f510058c83 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp @@ -0,0 +1,80 @@ +#include +#include + +void MNNSoftmax(float *dest, const float *source, size_t size) { + size_t n = size; + const float *sourcePtr = source; + float *destPtr = dest; + float maxValue = -FLT_MAX; + vfloat32m1_t maxVecValue = __riscv_vfmv_s_f_f32m1(maxValue, 1); + + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t vSrc = __riscv_vle32_v_f32m8(sourcePtr, vl); + maxVecValue = __riscv_vfredmax_vs_f32m8_f32m1(vSrc, maxVecValue, vl); + sourcePtr += vl; + n -= vl; + } + + maxValue = __riscv_vfmv_f_s_f32m1_f32(maxVecValue); + const float param = 0.6931471805599453f; + const float xLimit = 87.0f; + float sumValue = 0.f; + vfloat32m1_t sumVecValue = __riscv_vfmv_s_f_f32m1(sumValue, 1); + n = size; + sourcePtr = source; + destPtr = dest; + + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t vA = __riscv_vle32_v_f32m8(sourcePtr, vl); + vA = __riscv_vfsub_vf_f32m8(vA, maxValue, vl); + vA = __riscv_vfmax_vf_f32m8(vA, 
-xLimit, vl); + vA = __riscv_vfmin_vf_f32m8(vA, xLimit, vl); + + vfloat32m8_t vB = __riscv_vfdiv_vf_f32m8(vA, param, vl); + vint32m8_t vBI = __riscv_vfcvt_x_f_v_i32m8(vB, vl); + + vfloat32m8_t vC = __riscv_vreinterpret_v_i32m8_f32m8( + __riscv_vsll_vx_i32m8( + __riscv_vadd_vx_i32m8(vBI, 127, vl), 23, vl)); + + vB = __riscv_vfcvt_f_x_v_f32m8(vBI, vl); + vB = __riscv_vfnmsub_vf_f32m8(vB, param, vA, vl); + + vA = __riscv_vfmv_v_f_f32m8(1.0f / 120.0f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 1.0f / 24.0f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 1.0f / 6.0f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 0.5f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl); + vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); + vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl); + + vA = __riscv_vfmul_vv_f32m8(vC, vA, vl); + __riscv_vse32_v_f32m8(destPtr, vA, vl); + sumVecValue = __riscv_vfredosum_vs_f32m8_f32m1(vA, sumVecValue, vl); + + sourcePtr += vl; + destPtr += vl; + n -= vl; + } + + sumValue = __riscv_vfmv_f_s_f32m1_f32(sumVecValue); + float sumInv = 1.0f / sumValue; + n = size; + destPtr = dest; + + while (n > 0) + { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t vDest = __riscv_vle32_v_f32m8(destPtr, vl); + vDest = __riscv_vfmul_vf_f32m8(vDest, sumInv, vl); + __riscv_vse32_v_f32m8(destPtr, vDest, vl); + destPtr += vl; + n -= vl; + } +} From be7d0bb20181376570f255670a1e3a3b0f757409 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:42:54 +0800 Subject: [PATCH 066/314] Merge pull request #4042 from ihb2032/opt/rvv-conv-strassen opt(RVV): Optimize conv and strassen functions with intrinsics GitOrigin-RevId: bf461aa6e424685c2bc16570bc44220b65418ead --- .../riscv/rvv/MNNConvRunForLineDepthwise.cpp | 48 +++++++++++++++++++ .../rvv/MNNDeconvRunForUnitDepthWise.cpp | 42 ++++++++++++++++ 
.../riscv/rvv/MNNStrassenMergeCFunction.cpp | 36 ++++++++++++++ 3 files changed, 126 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp b/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp new file mode 100644 index 0000000000..f82faf83f5 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp @@ -0,0 +1,48 @@ +#include + +void MNNConvRunForLineDepthwise( + float* dst, const float* src, const float* weight, + size_t width, size_t src_w_setup, + size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, + size_t height, size_t srcHStep, size_t dstHStep, + const float* bias, const float* parameters) { + float minV = parameters[0]; + float maxV = parameters[1]; + ptrdiff_t srcByteStride = src_w_setup * sizeof(float); + ptrdiff_t dstByteStride = 4 * sizeof(float); + + for (size_t y = 0; y < height; ++y) { + const float* srcY = src + y * srcHStep; + float* dstY = dst + y * dstHStep; + size_t dx = 0; + + while (dx < width) { + size_t vl = __riscv_vsetvl_e32m8(width - dx); + + for (int c = 0; c < 4; ++c) { + vfloat32m8_t acc = __riscv_vfmv_v_f_f32m8(bias[c], vl); + const float* srcBase = srcY + dx * src_w_setup + c; + const float* weightPtr = weight + c; + + for (size_t fy = 0; fy < fh; ++fy) { + const float* srcFy = srcBase + fy * dilateY_step; + + for (size_t fx = 0; fx < fw; ++fx) { + float w = *weightPtr; + weightPtr += 4; + const float* srcFx = srcFy + fx * dilateX_step; + vfloat32m8_t s = __riscv_vlse32_v_f32m8(srcFx, srcByteStride, vl); + acc = __riscv_vfmacc_vf_f32m8(acc, w, s, vl); + } + } + + acc = __riscv_vfmax_vf_f32m8(acc, minV, vl); + acc = __riscv_vfmin_vf_f32m8(acc, maxV, vl); + float* dstAddr = dstY + dx * 4 + c; + __riscv_vsse32_v_f32m8(dstAddr, 
dstByteStride, acc, vl); + } + + dx += vl; + } + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp b/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp new file mode 100644 index 0000000000..6658715e7e --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp @@ -0,0 +1,42 @@ +#include + +void MNNDeconvRunForUnitDepthWise( + const float* dst, float* src, const float* weight, + size_t fw, size_t fh, + size_t weightY_step, size_t dilateX_step, size_t dilateY_step) { + const ptrdiff_t wStride = 4 * sizeof(float); + const ptrdiff_t sStride = dilateX_step * sizeof(float); + float d0 = dst[0], d1 = dst[1], d2 = dst[2], d3 = dst[3]; + + for (size_t fy = 0; fy < fh; ++fy) { + float* srcY = src + fy * dilateY_step; + const float* weightY = weight + fy * weightY_step; + + size_t fx = 0; + while (fx < fw) { + size_t vl = __riscv_vsetvl_e32m8(fw - fx); + + vfloat32m8_t w = __riscv_vlse32_v_f32m8(weightY + 0 + fx * 4, wStride, vl); + vfloat32m8_t s = __riscv_vlse32_v_f32m8(srcY + 0 + fx * dilateX_step, sStride, vl); + s = __riscv_vfmacc_vf_f32m8(s, d0, w, vl); + __riscv_vsse32_v_f32m8(srcY + 0 + fx * dilateX_step, sStride, s, vl); + + w = __riscv_vlse32_v_f32m8(weightY + 1 + fx * 4, wStride, vl); + s = __riscv_vlse32_v_f32m8(srcY + 1 + fx * dilateX_step, sStride, vl); + s = __riscv_vfmacc_vf_f32m8(s, d1, w, vl); + __riscv_vsse32_v_f32m8(srcY + 1 + fx * dilateX_step, sStride, s, vl); + + w = __riscv_vlse32_v_f32m8(weightY + 2 + fx * 4, wStride, vl); + s = __riscv_vlse32_v_f32m8(srcY + 2 + fx * dilateX_step, sStride, vl); + s = __riscv_vfmacc_vf_f32m8(s, d2, w, vl); + __riscv_vsse32_v_f32m8(srcY + 2 + fx * dilateX_step, sStride, s, vl); + + w = __riscv_vlse32_v_f32m8(weightY + 3 + fx * 4, wStride, vl); + s = __riscv_vlse32_v_f32m8(srcY + 3 + fx * dilateX_step, sStride, vl); + s = __riscv_vfmacc_vf_f32m8(s, d3, w, vl); + __riscv_vsse32_v_f32m8(srcY + 3 + fx * dilateX_step, sStride, s, vl); + + fx += vl; + } + } +} 
diff --git a/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp b/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp new file mode 100644 index 0000000000..8ab5bb89fa --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp @@ -0,0 +1,36 @@ +#include + +void MNNStrassenMergeCFunction(float *c11, float *c12, float *c21, float *c22, + float *xAddr, size_t cStride, size_t eSub, size_t hSub) { + for (int y = 0; y < hSub; ++y) { + float *c11Y = c11 + y * cStride; + float *c12Y = c12 + y * cStride; + float *c22Y = c22 + y * cStride; + float *c21Y = c21 + y * cStride; + float *xY = xAddr + y * eSub * 4; + size_t totalElements = eSub * 4; + size_t p = 0; + + while (p < totalElements) { + size_t vl = __riscv_vsetvl_e32m8(totalElements - p); + vfloat32m8_t t = __riscv_vle32_v_f32m8(xY + p, vl); + vfloat32m8_t tmp = __riscv_vle32_v_f32m8(c12Y + p, vl); + t = __riscv_vfadd_vv_f32m8(t, tmp, vl); + vfloat32m8_t c22v = __riscv_vle32_v_f32m8(c22Y + p, vl); + + tmp = __riscv_vle32_v_f32m8(c11Y + p, vl); + tmp = __riscv_vfadd_vv_f32m8(tmp, c22v, vl); + tmp = __riscv_vfadd_vv_f32m8(tmp, t, vl); + __riscv_vse32_v_f32m8(c12Y + p, tmp, vl); + + tmp = __riscv_vle32_v_f32m8(c21Y + p, vl); + tmp = __riscv_vfadd_vv_f32m8(t, tmp, vl); + __riscv_vse32_v_f32m8(c21Y + p, tmp, vl); + + c22v = __riscv_vfadd_vv_f32m8(c22v, tmp, vl); + __riscv_vse32_v_f32m8(c22Y + p, c22v, vl); + + p += vl; + } + } +} From 2140c1557d5ad4a6e19f34f54617962acee12a73 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:43:07 +0800 Subject: [PATCH 067/314] Merge pull request #4036 from ihb2032/opt/rvv-minmax-float opt(RVV): Optimize max and min float functions with intrinsics GitOrigin-RevId: cf83302a16083000f569672536d270edb597b0a5 --- source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp | 25 ++++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNMinFloat.cpp | 25 ++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 
source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNMinFloat.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp b/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp new file mode 100644 index 0000000000..183a38bb10 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp @@ -0,0 +1,25 @@ +#include +#include + +#define UNIT 4 + +void MNNMaxFloat(float *input, float *maxBuffer, int32_t inputCountUnit) { + const float init = -FLT_MAX; + for (int j = 0; j < UNIT; ++j) { + float local = init; + size_t i = 0; + + while (i < (size_t)inputCountUnit) { + size_t vl = __riscv_vsetvl_e32m8(inputCountUnit - i); + float *p0 = input + (i * UNIT * 2) + j * 2; + float *p1 = p0 + 1; + vfloat32m8_t v0 = __riscv_vlse32_v_f32m8(p0, UNIT * 2 * sizeof(float), vl); + vfloat32m8_t v1 = __riscv_vlse32_v_f32m8(p1, UNIT * 2 * sizeof(float), vl); + vfloat32m8_t vmax = __riscv_vfmax_vv_f32m8(v0, v1, vl); + vfloat32m1_t vred = __riscv_vfredmax_vs_f32m8_f32m1(vmax, __riscv_vfmv_s_f_f32m1(local, 1), vl); + local = __riscv_vfmv_f_s_f32m1_f32(vred); + i += vl; + } + maxBuffer[j] = local; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp b/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp new file mode 100644 index 0000000000..9e8ade8641 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp @@ -0,0 +1,25 @@ +#include +#include + +#define UNIT 4 + +void MNNMinFloat(float *input, float *minBuffer, int32_t inputCountUnit) { + const float init = FLT_MAX; + for (int j = 0; j < UNIT; ++j) { + float local = init; + size_t i = 0; + + while (i < (size_t)inputCountUnit) { + size_t vl = __riscv_vsetvl_e32m8(inputCountUnit - i); + float *p0 = input + (i * UNIT * 2) + j * 2; + float *p1 = p0 + 1; + vfloat32m8_t v0 = __riscv_vlse32_v_f32m8(p0, UNIT * 2 * sizeof(float), vl); + vfloat32m8_t v1 = __riscv_vlse32_v_f32m8(p1, UNIT * 2 * sizeof(float), vl); + vfloat32m8_t vmin = __riscv_vfmin_vv_f32m8(v0, v1, vl); + vfloat32m1_t vred = 
__riscv_vfredmin_vs_f32m8_f32m1(vmin, __riscv_vfmv_s_f_f32m1(local, 1), vl); + local = __riscv_vfmv_f_s_f32m1_f32(vred); + i += vl; + } + minBuffer[j] = local; + } +} From b682012db79b3bc41309be13ba23ed341bffccd5 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:43:38 +0800 Subject: [PATCH 068/314] Merge pull request #4026 from ihb2032/opt/rvv-math-stride-ops opt(RVV): Optimize core math and stride functions with intrinsics GitOrigin-RevId: 1b2d4bd5da63d4c4f3a2e457c8a91f3dd47ebb99 --- .../cpu/riscv/rvv/MNNAddC4WithStride.cpp | 29 +++++++++++ .../riscv/rvv/MNNAxByClampBroadcastUnit.cpp | 52 +++++++++++++++++++ .../cpu/riscv/rvv/MNNCopyC4WithStride.cpp | 22 ++++++++ .../cpu/riscv/rvv/MNNScaleAndAddBias.cpp | 42 +++++++++++++++ 4 files changed, 145 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp b/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp new file mode 100644 index 0000000000..59bb28a039 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp @@ -0,0 +1,29 @@ +#include + +void MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { + ptrdiff_t srcStrideByte = srcStride * sizeof(float); + ptrdiff_t dstStrideByte = dstStride * sizeof(float); + size_t vl; + + for (size_t i = count; i > 0; i -= vl) { + vl = __riscv_vsetvl_e32m8(i); + vfloat32m8_t vs = __riscv_vlse32_v_f32m8(source + 0, srcStrideByte, vl); + vfloat32m8_t vd = __riscv_vlse32_v_f32m8(dest + 0, dstStrideByte, vl); + vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); + __riscv_vsse32_v_f32m8(dest + 0, dstStrideByte, vd, vl); + vs = __riscv_vlse32_v_f32m8(source + 1, srcStrideByte, vl); + vd = 
__riscv_vlse32_v_f32m8(dest + 1, dstStrideByte, vl); + vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); + __riscv_vsse32_v_f32m8(dest + 1, dstStrideByte, vd, vl); + vs = __riscv_vlse32_v_f32m8(source + 2, srcStrideByte, vl); + vd = __riscv_vlse32_v_f32m8(dest + 2, dstStrideByte, vl); + vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); + __riscv_vsse32_v_f32m8(dest + 2, dstStrideByte, vd, vl); + vs = __riscv_vlse32_v_f32m8(source + 3, srcStrideByte, vl); + vd = __riscv_vlse32_v_f32m8(dest + 3, dstStrideByte, vl); + vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); + __riscv_vsse32_v_f32m8(dest + 3, dstStrideByte, vd, vl); + source += vl * srcStride; + dest += vl * dstStride; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp b/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp new file mode 100644 index 0000000000..6d966789f7 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp @@ -0,0 +1,52 @@ +#include + +void MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) { + float beta = parameters[1]; + float minF = parameters[2]; + float maxF = parameters[3]; + const ptrdiff_t stride = 4 * sizeof(float); + + for (int y = 0; y < height; ++y) { + auto a = A + aStride * y; + auto b = B + 4 * y; + auto c = C + cStride * y; + float b0Beta = b[0] * beta; + float b1Beta = b[1] * beta; + float b2Beta = b[2] * beta; + float b3Beta = b[3] * beta; + size_t w = width; + + while (w > 0) { + size_t vl = __riscv_vsetvl_e32m8(w); + + vfloat32m8_t data = __riscv_vlse32_v_f32m8(a + 0, stride, vl); + data = __riscv_vfadd_vf_f32m8(data, b0Beta, vl); + data = __riscv_vfmax_vf_f32m8(data, minF, vl); + data = __riscv_vfmin_vf_f32m8(data, maxF, vl); + __riscv_vsse32_v_f32m8(c + 0, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(a + 1, stride, vl); + data = __riscv_vfadd_vf_f32m8(data, b1Beta, vl); + data = __riscv_vfmax_vf_f32m8(data, minF, vl); + data = 
__riscv_vfmin_vf_f32m8(data, maxF, vl); + __riscv_vsse32_v_f32m8(c + 1, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(a + 2, stride, vl); + data = __riscv_vfadd_vf_f32m8(data, b2Beta, vl); + data = __riscv_vfmax_vf_f32m8(data, minF, vl); + data = __riscv_vfmin_vf_f32m8(data, maxF, vl); + __riscv_vsse32_v_f32m8(c + 2, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(a + 3, stride, vl); + data = __riscv_vfadd_vf_f32m8(data, b3Beta, vl); + data = __riscv_vfmax_vf_f32m8(data, minF, vl); + data = __riscv_vfmin_vf_f32m8(data, maxF, vl); + __riscv_vsse32_v_f32m8(c + 3, stride, data, vl); + + a += 4 * vl; + c += 4 * vl; + w -= vl; + } + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp b/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp new file mode 100644 index 0000000000..3d8c4f13fc --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp @@ -0,0 +1,22 @@ +#include + +void MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { + ptrdiff_t srcStrideByte = srcStride * sizeof(float); + ptrdiff_t dstStrideByte = dstStride * sizeof(float); +size_t vl; + + for (size_t i = count; i > 0; i -= vl) { + vl = __riscv_vsetvl_e32m8(i); + vfloat32m8_t data = __riscv_vlse32_v_f32m8(source + 0, srcStrideByte, vl); + __riscv_vsse32_v_f32m8(dest + 0, dstStrideByte, data, vl); + data = __riscv_vlse32_v_f32m8(source + 1, srcStrideByte, vl); + __riscv_vsse32_v_f32m8(dest + 1, dstStrideByte, data, vl); + data = __riscv_vlse32_v_f32m8(source + 2, srcStrideByte, vl); + __riscv_vsse32_v_f32m8(dest + 2, dstStrideByte, data, vl); + data = __riscv_vlse32_v_f32m8(source + 3, srcStrideByte, vl); + __riscv_vsse32_v_f32m8(dest + 3, dstStrideByte, data, vl); + source += vl * srcStride; + dest += vl * dstStride; + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp b/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp new file mode 100644 index 0000000000..10992f9d59 --- /dev/null 
+++ b/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp @@ -0,0 +1,42 @@ +#include + +void MNNScaleAndAddBias(float *dst, const float *src, const float *bias, const float *alpha, size_t planeNumber, size_t biasNumber) { + const ptrdiff_t stride = 4 * sizeof(float); + + for (size_t z = 0; z < biasNumber; ++z) { + float *dstZ = dst + z * planeNumber * 4; + const float *srcZ = src + z * planeNumber * 4; + const float *biasZ = bias + 4 * z; + const float *alphaZ = alpha + 4 * z; + float b0 = biasZ[0], b1 = biasZ[1], b2 = biasZ[2], b3 = biasZ[3]; + float a0 = alphaZ[0], a1 = alphaZ[1], a2 = alphaZ[2], a3 = alphaZ[3]; + + size_t n = planeNumber; + while (n > 0) { + size_t vl = __riscv_vsetvl_e32m8(n); + vfloat32m8_t data = __riscv_vlse32_v_f32m8(srcZ + 0, stride, vl); + data = __riscv_vfmul_vf_f32m8(data, a0, vl); + data = __riscv_vfadd_vf_f32m8(data, b0, vl); + __riscv_vsse32_v_f32m8(dstZ + 0, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(srcZ + 1, stride, vl); + data = __riscv_vfmul_vf_f32m8(data, a1, vl); + data = __riscv_vfadd_vf_f32m8(data, b1, vl); + __riscv_vsse32_v_f32m8(dstZ + 1, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(srcZ + 2, stride, vl); + data = __riscv_vfmul_vf_f32m8(data, a2, vl); + data = __riscv_vfadd_vf_f32m8(data, b2, vl); + __riscv_vsse32_v_f32m8(dstZ + 2, stride, data, vl); + + data = __riscv_vlse32_v_f32m8(srcZ + 3, stride, vl); + data = __riscv_vfmul_vf_f32m8(data, a3, vl); + data = __riscv_vfadd_vf_f32m8(data, b3, vl); + __riscv_vsse32_v_f32m8(dstZ + 3, stride, data, vl); + + srcZ += vl * 4; + dstZ += vl * 4; + n -= vl; + } + } +} From 4297496fd014f13b42ba90ba870ffb41d433041b Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:43:52 +0800 Subject: [PATCH 069/314] Merge pull request #4023 from ihb2032/feature/rvv-transpose-functions opt(RVV): Optimize transpose functions with intrinsics GitOrigin-RevId: 24f98cc6e50fac1e178be5c3c425a3f622343cd0 --- .../cpu/riscv/rvv/MNNTranspose16Bit.cpp | 26 
+++++++++++++++++++ .../cpu/riscv/rvv/MNNTranspose32Bit.cpp | 25 ++++++++++++++++++ 2 files changed, 51 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp b/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp new file mode 100644 index 0000000000..7598d6f8ac --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp @@ -0,0 +1,26 @@ +#include + +void MNNTranspose16Bit(int16_t* dstO, const int16_t* srcO, int16_t* dim) { + int w = dim[0]; + int h = dim[1]; + int srcStride = dim[2]; + int dstStride = dim[3]; + ptrdiff_t srcStrideByte = srcStride * sizeof(int16_t); + + for (int i = 0; i < h; ++i) { + const int16_t* srcPtr = srcO + i; + int16_t* dstPtr = dstO + i * dstStride; + + int j = 0; + while (j < w) { + size_t vl = __riscv_vsetvl_e16m8(w - j); + vint16m8_t data = __riscv_vlse16_v_i16m8(srcPtr, srcStrideByte, vl); + __riscv_vse16_v_i16m8(dstPtr, data, vl); + srcPtr += vl * srcStride; + dstPtr += vl; + j += vl; + } + } +} + + diff --git a/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp b/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp new file mode 100644 index 0000000000..e5c5eb83e6 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp @@ -0,0 +1,25 @@ +#include + +void MNNTranspose32Bit(int32_t* dstO, const int32_t* srcO, int32_t* dim) { + int w = dim[0]; + int h = dim[1]; + int srcStride = dim[2]; + int dstStride = dim[3]; + ptrdiff_t srcStrideByte = srcStride * sizeof(int32_t); + + for (int i = 0; i < h; ++i) { + const int32_t* srcPtr = srcO + i; + int32_t* dstPtr = dstO + i * dstStride; + + int j = 0; + while (j < w) { + size_t vl = __riscv_vsetvl_e32m8(w - j); + vint32m8_t data = __riscv_vlse32_v_i32m8(srcPtr, srcStrideByte, vl); + __riscv_vse32_v_i32m8(dstPtr, data, vl); + srcPtr += vl * srcStride; + dstPtr += vl; + j += vl; + } + } +} + From 
e65ef9a82618f6a0b81dd4913e5a8b4cad985142 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:44:24 +0800 Subject: [PATCH 070/314] Merge pull request #4021 from ihb2032/feature/rvv-opt opt(RVV): Optimize pack and unpack functions with intrinsics GitOrigin-RevId: 58b54e86481db0588b59c679c0bade51e04b0d38 --- source/backend/cpu/riscv/rvv/MNNPackC2.cpp | 74 ++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNPackC4.cpp | 80 ++++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp | 55 ++++++++++++++ 3 files changed, 209 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNPackC2.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNPackC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNPackC2.cpp b/source/backend/cpu/riscv/rvv/MNNPackC2.cpp new file mode 100644 index 0000000000..9a74f8998d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNPackC2.cpp @@ -0,0 +1,74 @@ +#include + +void MNNPackC2(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { + int depthC2 = depth / 2; + int depthRemain = depthC2 * 2; + int remain = depth - depthRemain; + const float *srcOffset = src; + const float *srcChannel[2]; + + for (int z = 0; z < depthC2; ++z) { + float *dstZ = dst + z * areaOffset[1] * 2; + + for (int y = 0; y < 2; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 2; + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 0, 2 * sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 1, 2 * sizeof(float), vec, vl); + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 2; + dstPtr[0] = srcChannel[0][x]; + dstPtr[1] = srcChannel[1][x]; + } + + srcOffset += areaOffset[0] * 2; + } + + if (remain > 0) { + float 
*dstZ = dst + depthC2 * areaOffset[1] * 2; + + for (int y = 0; y < remain; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 2; + + for (int y = 0; y < remain; ++y) { + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + y, 2 * sizeof(float), vec, vl); + } + + vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); + for (int y = remain; y < 2; ++y) { + __riscv_vsse32_v_f32m8(dstPtr + y, 2 * sizeof(float), zero, vl); + } + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 2; + + for (int y = 0; y < remain; ++y) { + dstPtr[y] = srcChannel[y][x]; + } + + for (int y = remain; y < 2; ++y) { + dstPtr[y] = 0.0f; + } + } + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNPackC4.cpp b/source/backend/cpu/riscv/rvv/MNNPackC4.cpp new file mode 100644 index 0000000000..024e2c8c07 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNPackC4.cpp @@ -0,0 +1,80 @@ +#include + +void MNNPackC4(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { + int depthC4 = depth / 4; + int depthRemain = depthC4 * 4; + int remain = depth - depthRemain; + const float *srcOffset = src; + const float *srcChannel[4]; + + for (int z = 0; z < depthC4; ++z) { + float *dstZ = dst + z * areaOffset[1] * 4; + + for (int y = 0; y < 4; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 4; + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 0, 4 * sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 1, 4 * sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[2] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 2, 4 * sizeof(float), vec, vl); + vec = 
__riscv_vle32_v_f32m8(srcChannel[3] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 3, 4 * sizeof(float), vec, vl); + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 4; + dstPtr[0] = srcChannel[0][x]; + dstPtr[1] = srcChannel[1][x]; + dstPtr[2] = srcChannel[2][x]; + dstPtr[3] = srcChannel[3][x]; + } + + srcOffset += areaOffset[0] * 4; + } + + if (remain > 0) { + float *dstZ = dst + depthC4 * areaOffset[1] * 4; + + for (int y = 0; y < remain; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 4; + + for (int y = 0; y < remain; ++y) { + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), vec, vl); + } + + vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); + for (int y = remain; y < 4; ++y) { + __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), zero, vl); + } + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 4; + + for (int y = 0; y < remain; ++y) { + dstPtr[y] = srcChannel[y][x]; + } + + for (int y = remain; y < 4; ++y) { + dstPtr[y] = 0.0f; + } + } + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp b/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp new file mode 100644 index 0000000000..4676e6dede --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp @@ -0,0 +1,55 @@ +#include + +void MNNUnpackC4(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { + int depthC4 = depth / 4; + int depthRemain = depthC4 * 4; + int remain = depth - depthRemain; + const float *srcOffset = src; + + for (int z = 0; z < depthC4; ++z) { + float *dstZ[4]; + + for (int y = 0; y < 4; ++y) { + dstZ[y] = dst + (z * 4 + y) * areaOffset[1]; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + vfloat32m8_t vec = __riscv_vlse32_v_f32m8(srcOffset + 0, 4 * sizeof(float), vl); + 
__riscv_vse32_v_f32m8(dstZ[0] + x, vec, vl); + vec = __riscv_vlse32_v_f32m8(srcOffset + 1, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[1] + x, vec, vl); + vec = __riscv_vlse32_v_f32m8(srcOffset + 2, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[2] + x, vec, vl); + vec = __riscv_vlse32_v_f32m8(srcOffset + 3, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[3] + x, vec, vl); + srcOffset += 4 * vl; + } + + for (; x < area; ++x) { + dstZ[0][x] = srcOffset[0]; + dstZ[1][x] = srcOffset[1]; + dstZ[2][x] = srcOffset[2]; + dstZ[3][x] = srcOffset[3]; + srcOffset += (areaOffset[0] - area) * 4; + } + } + + if (remain > 0) { + float *dstZ = dst + depthC4 * areaOffset[1] * 4; + const float *srcBase = srcOffset; + + for (int y = 0; y < remain; ++y) { + float *dstChannel = dstZ + y * areaOffset[1]; + const float *srcChannel = srcBase + y; + + for (size_t x = 0; x < area; ++x) { + dstChannel[x] = srcChannel[0]; + srcChannel += 4; + } + } + } +} + From d18def649b4516ee9148a904147d0c79b4913126 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:54:53 +0800 Subject: [PATCH 071/314] Merge pull request #4061 from zlaazlaa/fix_diffusion fix(diffusion): simplify export logic and fix dynamic axes GitOrigin-RevId: 4c1cd8ed04606b5302cf9807a42bcc034ebf7c1b --- docs/transformers/diffusion.md | 3 +- transformers/diffusion/export/onnx_export.py | 30 ++++++-------------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/docs/transformers/diffusion.md b/docs/transformers/diffusion.md index 7de27bb216..609793f806 100644 --- a/docs/transformers/diffusion.md +++ b/docs/transformers/diffusion.md @@ -20,7 +20,8 @@ https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1/tree/mai cd mnn_path/transformers/diffusion/export python onnx_export.py \ --model_path hf_sd_load_path \ - --output_path onnx_save_path + --output_path onnx_save_path \ + --opset 18 ``` 注意,上述脚本需要依赖torch/onnx/diffusers等库,可以安装conda环境: ``` diff --git 
a/transformers/diffusion/export/onnx_export.py b/transformers/diffusion/export/onnx_export.py index 21f05e83be..5516eb2fcc 100644 --- a/transformers/diffusion/export/onnx_export.py +++ b/transformers/diffusion/export/onnx_export.py @@ -84,7 +84,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F num_tokens = pipeline.text_encoder.config.max_position_embeddings text_hidden_size = pipeline.text_encoder.config.hidden_size text_input = pipeline.tokenizer( - "A sample prompt", + ["A sample prompt", "A sample prompt"], padding="max_length", max_length=pipeline.tokenizer.model_max_length, truncation=True, @@ -97,9 +97,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path=output_path / "text_encoder" / "model.onnx", ordered_input_names=["input_ids"], output_names=["last_hidden_state", "pooler_output"], - dynamic_axes={ - "input_ids": {0: "batch", 1: "sequence"}, - }, + dynamic_axes=None, opset=opset, ) del pipeline.text_encoder @@ -117,13 +115,9 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F # False, ), output_path=unet_path, - ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], + ordered_input_names=["sample", "timestep", "encoder_hidden_states"], output_names=["out_sample"], # has to be different from "sample" for correct tracing - dynamic_axes={ - "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - "timestep": {0: "batch"}, - "encoder_hidden_states": {0: "batch", 1: "sequence"}, - }, + dynamic_axes=None, opset=opset, use_external_data_format=True, # UNet is > 2GB, so the weights need to be split ) @@ -149,7 +143,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F vae_in_channels = vae_encoder.config.in_channels vae_sample_size = vae_encoder.config.sample_size # need to get the raw tensor output (sample) from the encoder - vae_encoder.forward = lambda sample, return_dict: 
vae_encoder.encode(sample, return_dict)[0].sample() + vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].mode() onnx_export( vae_encoder, model_args=( @@ -159,30 +153,24 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path=output_path / "vae_encoder" / "model.onnx", ordered_input_names=["sample", "return_dict"], output_names=["latent_sample"], - dynamic_axes={ - "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - }, + dynamic_axes=None, opset=opset, ) # VAE DECODER vae_decoder = pipeline.vae vae_latent_channels = vae_decoder.config.latent_channels - vae_out_channels = vae_decoder.config.out_channels # forward only through the decoder part - vae_decoder.forward = vae_encoder.decode + vae_decoder.forward = lambda latent: vae_decoder.decode(latent, return_dict=False)[0] onnx_export( vae_decoder, model_args=( torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), - False, ), output_path=output_path / "vae_decoder" / "model.onnx", - ordered_input_names=["latent_sample", "return_dict"], + ordered_input_names=["latent_sample"], output_names=["sample"], - dynamic_axes={ - "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - }, + dynamic_axes=None, opset=opset, ) del pipeline.vae From 3e4234f3dd1659e1835c780e477bac0ce770a15b Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 11:04:03 +0800 Subject: [PATCH 072/314] Merge pull request #3998 from bolun365/bolun365-patch-1 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit mnn lib库自动化build脚本 GitOrigin-RevId: ac1e2a9fd51ff3a9102660cda0d0731dfd849f95 --- build_lib.sh | 807 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 807 insertions(+) create mode 100644 build_lib.sh diff --git a/build_lib.sh b/build_lib.sh new file mode 100644 index 0000000000..c839b6e7b6 --- /dev/null +++ b/build_lib.sh @@ 
-0,0 +1,807 @@ +#!/bin/bash + +# MNN 统一构建脚本 +# 支持构建 Android、iOS、鸿蒙和 Python 版本的 MNN + +set -e + +# 颜色输出 +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# 默认配置 +BUILD_ANDROID=false +BUILD_IOS=false +BUILD_PYTHON=false +BUILD_IOS_SIMULATOR=false +BUILD_HARMONY=false +ANDROID_NDK="" +HARMONY_HOME="" +OUTPUT_DIR="mnn_builds" +CLEAN_BUILD=true +VERSION="" +PYTHON_CMD="python3" +DEPS_OPTIONS="" + +# 获取项目根目录 +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$SCRIPT_DIR" + +print_usage() { + echo "MNN 统一构建脚本" + echo "" + echo "用法: $0 [选项]" + echo "" + echo "构建目标:" + echo " --android 构建 Android 版本 (需要 --ndk)" + echo " --ios 构建 iOS 真机版本 (arm64)" + echo " --ios-simulator 构建 iOS 模拟器版本 (x86_64 + arm64)" + echo " --harmony 构建鸿蒙版本 (需要 --harmony-home 或设置 HARMONY_HOME)" + echo " --python 构建 Python 版本" + echo "" + echo "Android 选项:" + echo " --ndk PATH Android NDK 路径 (例如: ~/Library/Android/sdk/ndk/29.0.13599879)" + echo "" + echo "鸿蒙选项:" + echo " --harmony-home PATH 鸿蒙工具链路径 (例如: ~/Library/OpenHarmony/Sdk/native)" + echo " 如果未指定,将优先查找 ~/Library/OpenHarmony/Sdk/native" + echo "" + echo "Python 选项:" + echo " --python-deps OPTIONS Python 依赖选项,多个用逗号分隔" + echo " 可用选项: llm,opencl,cuda,torch,render,vulkan,internal,no_sse,openmp" + echo " --python-cmd CMD Python 命令 (默认: python3)" + echo "" + echo "通用选项:" + echo " -o, --output DIR 输出目录 (默认: mnn_builds)" + echo " -v, --version VERSION 版本号 (默认: 自动从源码读取)" + echo " --no-clean 不清理之前的构建目录" + echo " -h, --help 显示帮助信息" + echo "" + echo "示例:" + echo " # 构建所有平台" + echo " $0 --android --ios --harmony --python --ndk ~/Library/Android/sdk/ndk/29.0.13599879" + echo "" + echo " # 仅构建 Android" + echo " $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879" + echo "" + echo " # 构建鸿蒙版本" + echo " $0 --harmony --harmony-home ~/Library/OpenHarmony/Sdk/native" + echo "" + echo " # 构建 iOS 真机和模拟器" + echo " $0 --ios --ios-simulator" + echo "" + echo " # 构建 Python (带 LLM 支持)" + echo 
" $0 --python --python-deps llm,opencl" +} + +# 解析命令行参数 +while [[ $# -gt 0 ]]; do + case $1 in + --android) + BUILD_ANDROID=true + shift + ;; + --ios) + BUILD_IOS=true + shift + ;; + --ios-simulator) + BUILD_IOS_SIMULATOR=true + shift + ;; + --python) + BUILD_PYTHON=true + shift + ;; + --harmony) + BUILD_HARMONY=true + shift + ;; + --ndk) + ANDROID_NDK="$2" + shift 2 + ;; + --harmony-home) + HARMONY_HOME="$2" + shift 2 + ;; + --python-deps) + DEPS_OPTIONS="$2" + shift 2 + ;; + --python-cmd) + PYTHON_CMD="$2" + shift 2 + ;; + -o|--output) + OUTPUT_DIR="$2" + shift 2 + ;; + -v|--version) + VERSION="$2" + shift 2 + ;; + --no-clean) + CLEAN_BUILD=false + shift + ;; + -h|--help) + print_usage + exit 0 + ;; + *) + echo -e "${RED}错误: 未知选项 $1${NC}" + print_usage + exit 1 + ;; + esac +done + +# 检查是否至少选择了一个构建目标 +if [ "$BUILD_ANDROID" = false ] && [ "$BUILD_IOS" = false ] && [ "$BUILD_IOS_SIMULATOR" = false ] && [ "$BUILD_HARMONY" = false ] && [ "$BUILD_PYTHON" = false ]; then + echo -e "${RED}错误: 请至少选择一个构建目标 (--android, --ios, --ios-simulator, --harmony, 或 --python)${NC}" + print_usage + exit 1 +fi + +# 检查 Android NDK +if [ "$BUILD_ANDROID" = true ]; then + if [ -z "$ANDROID_NDK" ]; then + echo -e "${RED}错误: 构建 Android 必须指定 NDK 路径 (使用 --ndk)${NC}" + echo -e "${RED}示例: $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879${NC}" + exit 1 + fi + + # 展开路径中的 ~ + ANDROID_NDK="${ANDROID_NDK/#\~/$HOME}" + + if [ ! 
-d "$ANDROID_NDK" ]; then + echo -e "${RED}错误: NDK 路径不存在: $ANDROID_NDK${NC}" + exit 1 + fi +fi + +# 查找鸿蒙工具链的函数 +find_harmony_toolchain() { + # 如果环境变量已设置,优先使用 + if [ -n "$HARMONY_HOME" ]; then + # 支持两种路径格式:直接指定或带 native 子目录 + if [ -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then + echo "$HARMONY_HOME" + return 0 + elif [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then + echo "$HARMONY_HOME" + return 0 + fi + fi + + # 优先查找 ~/Library/OpenHarmony/Sdk,支持版本号目录 + local sdk_base="$HOME/Library/OpenHarmony/Sdk" + if [ -d "$sdk_base" ]; then + # 查找所有版本号目录下的 native 或 toolchains + # 按版本号倒序排列,优先使用最新版本 + for version_dir in $(ls -d "$sdk_base"/* 2>/dev/null | sort -Vr); do + if [ -d "$version_dir" ]; then + # 尝试 native/build/cmake/ohos.toolchain.cmake + if [ -f "$version_dir/native/build/cmake/ohos.toolchain.cmake" ]; then + echo "$version_dir/native" + return 0 + fi + # 尝试 build/cmake/ohos.toolchain.cmake + if [ -f "$version_dir/build/cmake/ohos.toolchain.cmake" ]; then + echo "$version_dir" + return 0 + fi + fi + done + fi + + # 其他可能的路径 + local possible_paths=( + "$HOME/Library/OpenHarmony/Sdk/native" + "$HOME/HarmonyOS/Sdk/native" + "$HOME/.ohos/native" + "/opt/HarmonyOS/Sdk/native" + "/usr/local/HarmonyOS/Sdk/native" + "$HOME/Library/HarmonyOS/Sdk/native" + ) + + # 尝试查找 + for path in "${possible_paths[@]}"; do + if [ -n "$path" ] && [ -f "$path/build/cmake/ohos.toolchain.cmake" ]; then + echo "$path" + return 0 + fi + done + + # 限制搜索范围,只在 OpenHarmony/Sdk 目录下快速查找 + # 避免在整个 OpenHarmony 目录下递归查找,这可能会很慢 + local found=$(find "$HOME/Library/OpenHarmony/Sdk" -maxdepth 4 -type f -name "ohos.toolchain.cmake" 2>/dev/null | head -1) + if [ -n "$found" ]; then + # 从 ohos.toolchain.cmake 向上查找 native 目录或 SDK 根目录 + found=$(dirname "$found") + if [ "$(basename "$found")" = "cmake" ]; then + found=$(dirname "$found") + if [ "$(basename "$found")" = "build" ]; then + found=$(dirname "$found") + fi + fi + echo "$found" + return 0 + fi + + return 1 +} + +# 检查鸿蒙工具链 
+if [ "$BUILD_HARMONY" = true ]; then + # 展开路径中的 ~ + if [ -n "$HARMONY_HOME" ]; then + HARMONY_HOME="${HARMONY_HOME/#\~/$HOME}" + fi + + # 尝试查找工具链 + if [ -z "$HARMONY_HOME" ] || [ ! -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then + # 检查是否在 native 子目录下 + if [ -n "$HARMONY_HOME" ] && [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then + HARMONY_HOME="$HARMONY_HOME/native" + else + echo -e "${YELLOW}正在查找鸿蒙工具链...${NC}" + HARMONY_HOME=$(find_harmony_toolchain) + + if [ -z "$HARMONY_HOME" ]; then + echo -e "${RED}错误: 找不到鸿蒙工具链${NC}" + echo -e "${RED}默认查找路径: ~/Library/OpenHarmony/Sdk/*/native${NC}" + echo -e "${RED}请使用 --harmony-home 指定路径,或设置 HARMONY_HOME 环境变量${NC}" + echo -e "${RED}工具链文件应位于: /build/cmake/ohos.toolchain.cmake 或 /native/build/cmake/ohos.toolchain.cmake${NC}" + exit 1 + fi + + # 如果找到的是带 native 的路径,需要调整 + if [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then + HARMONY_HOME="$HARMONY_HOME/native" + fi + fi + + echo -e "${GREEN}找到鸿蒙工具链: $HARMONY_HOME${NC}" + fi + + # 验证工具链文件存在 + if [ ! 
-f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then + echo -e "${RED}错误: 鸿蒙工具链文件不存在: $HARMONY_HOME/build/cmake/ohos.toolchain.cmake${NC}" + exit 1 + fi +fi + +echo -e "${GREEN}========================================${NC}" +echo -e "${GREEN}MNN 统一构建脚本${NC}" +echo -e "${GREEN}========================================${NC}" +echo "项目根目录: $PROJECT_ROOT" +echo "输出目录: $OUTPUT_DIR" +echo "" +echo -e "${BLUE}构建目标:${NC}" +[ "$BUILD_ANDROID" = true ] && echo " ✓ Android" +[ "$BUILD_IOS" = true ] && echo " ✓ iOS (真机 arm64)" +[ "$BUILD_IOS_SIMULATOR" = true ] && echo " ✓ iOS (模拟器 x86_64 + arm64)" +[ "$BUILD_HARMONY" = true ] && echo " ✓ 鸿蒙 (arm64-v8a)" +[ "$BUILD_PYTHON" = true ] && echo " ✓ Python" +echo "" + +cd "$PROJECT_ROOT" +mkdir -p "$OUTPUT_DIR" + +# ============================================================================ +# 构建 Android 版本 +# ============================================================================ +if [ "$BUILD_ANDROID" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 Android 版本${NC}" + echo -e "${GREEN}========================================${NC}" + + export ANDROID_NDK + + ANDROID_BUILD_DIR="project/android" + ANDROID_OUTPUT_DIR="$OUTPUT_DIR/android" + + cd "$ANDROID_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的 Android 构建...${NC}" + rm -rf build_32 build_64 export + fi + + # 在项目根目录创建输出目录 + mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" + + # 构建 armeabi-v7a + echo -e "${BLUE}构建 armeabi-v7a...${NC}" + mkdir -p build_32 + cd build_32 + cmake ../../../ \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DANDROID_ABI="armeabi-v7a" \ + -DANDROID_STL=c++_static \ + -DANDROID_NATIVE_API_LEVEL=android-14 \ + -DANDROID_TOOLCHAIN=clang \ + -DMNN_USE_LOGCAT=false \ + -DMNN_USE_SSE=OFF \ + -DMNN_BUILD_TEST=ON \ + -DMNN_ARM82=OFF \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + 
-DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DMNN_SEP_BUILD=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ + -DNATIVE_LIBRARY_OUTPUT=. \ + -DNATIVE_INCLUDE_OUTPUT=. \ + > /dev/null + + make -j4 MNN > /dev/null + cd .. + + # 构建 arm64-v8a + echo -e "${BLUE}构建 arm64-v8a...${NC}" + mkdir -p build_64 + cd build_64 + cmake ../../../ \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DANDROID_ABI="arm64-v8a" \ + -DANDROID_STL=c++_static \ + -DANDROID_NATIVE_API_LEVEL=android-21 \ + -DANDROID_TOOLCHAIN=clang \ + -DMNN_USE_LOGCAT=false \ + -DMNN_BUILD_BENCHMARK=ON \ + -DMNN_USE_SSE=OFF \ + -DMNN_BUILD_TEST=ON \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DMNN_SEP_BUILD=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ + -DNATIVE_LIBRARY_OUTPUT=. \ + -DNATIVE_INCLUDE_OUTPUT=. \ + > /dev/null + + make -j4 MNN > /dev/null + cd .. + + # 导出文件 + echo -e "${BLUE}导出 Android 库文件...${NC}" + mkdir -p export/android/{armeabi-v7a,arm64-v8a}/libs export/android/include/MNN + + cp build_32/*.so export/android/armeabi-v7a/libs/ 2>/dev/null || true + cp build_64/*.so export/android/arm64-v8a/libs/ 2>/dev/null || true + cp -r ../../include/MNN/* export/android/include/MNN/ + + # 复制到统一输出目录 + # 如果目标路径存在但不是目录,先删除 + if [ -e "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ] && [ ! 
-d "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ]; then + rm -f "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" + fi + mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" + cp -r export/android/* "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR/" + + echo -e "${GREEN}Android 构建完成!${NC}" + echo "输出位置: $ANDROID_OUTPUT_DIR" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建 iOS 真机版本 +# ============================================================================ +if [ "$BUILD_IOS" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 iOS 真机版本${NC}" + echo -e "${GREEN}========================================${NC}" + + # 检查是否在 macOS 上 + if [[ "$OSTYPE" != "darwin"* ]]; then + echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" + exit 1 + fi + + # 检查 Xcode + if ! command -v xcodebuild &> /dev/null; then + echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" + exit 1 + fi + + IOS_BUILD_DIR="project/ios" + IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/device" + + cd "$IOS_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的 iOS 真机构建...${NC}" + rm -rf MNN-iOS-CPU-GPU/Static/ios_64 + find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_64/*" 2>/dev/null | xargs rm -f 2>/dev/null || true + fi + + mkdir -p MNN-iOS-CPU-GPU/Static + cd MNN-iOS-CPU-GPU/Static + + # 构建真机版本 (arm64) + echo -e "${BLUE}构建 iOS 真机版本 (arm64)...${NC}" + rm -rf ios_64 + mkdir ios_64 + cd ios_64 + cmake "$PROJECT_ROOT" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ + -DENABLE_BITCODE=0 \ + -DMNN_AAPL_FMWK=1 \ + -DMNN_SEP_BUILD=OFF \ + -DMNN_BUILD_SHARED_LIBS=false \ + -DMNN_USE_THREAD_POOL=OFF \ + -DPLATFORM=OS64 \ + -DARCHS="arm64" \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_METAL=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + 
-DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + > /dev/null + make MNN -j8 > /dev/null + cd .. + + # 输出真机版本 + mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" + rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" + cp -R ios_64/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" + + # 清理 + rm -rf ios_64 + + echo -e "${GREEN}iOS 真机构建完成!${NC}" + echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建 iOS 模拟器版本 +# ============================================================================ +if [ "$BUILD_IOS_SIMULATOR" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 iOS 模拟器版本${NC}" + echo -e "${GREEN}========================================${NC}" + + # 检查是否在 macOS 上 + if [[ "$OSTYPE" != "darwin"* ]]; then + echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" + exit 1 + fi + + # 检查 Xcode + if ! command -v xcodebuild &> /dev/null; then + echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" + exit 1 + fi + + IOS_BUILD_DIR="project/ios" + IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/simulator" + + cd "$IOS_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的 iOS 模拟器构建...${NC}" + rm -rf MNN-iOS-CPU-GPU/Static/ios_simulator* + find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_simulator*/*" 2>/dev/null | xargs rm -f 2>/dev/null || true + fi + + mkdir -p MNN-iOS-CPU-GPU/Static + cd MNN-iOS-CPU-GPU/Static + + # 构建 x86_64 模拟器版本(尝试构建,失败则跳过) + BUILD_X86_64=false + echo -e "${BLUE}构建 iOS 模拟器版本 (x86_64)...${NC}" + rm -rf ios_simulator_x86 + mkdir ios_simulator_x86 + cd ios_simulator_x86 + if cmake "$PROJECT_ROOT" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ + -DENABLE_BITCODE=0 \ + -DMNN_AAPL_FMWK=1 \ + -DMNN_SEP_BUILD=OFF \ + -DMNN_BUILD_SHARED_LIBS=false \ + -DMNN_USE_THREAD_POOL=OFF \ + -DPLATFORM=SIMULATOR64 \ + 
-DARCHS="x86_64" \ + -DMNN_ARM82=OFF \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_METAL=OFF \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + && make MNN -j8; then + if [ -f "MNN.framework/MNN" ]; then + BUILD_X86_64=true + echo -e "${GREEN}x86_64 模拟器构建成功${NC}" + cd .. + else + echo -e "${YELLOW}警告: x86_64 模拟器构建产物不存在,跳过此架构${NC}" + cd .. + rm -rf ios_simulator_x86 + fi + else + echo -e "${YELLOW}警告: x86_64 模拟器构建失败,跳过此架构(这在 Apple Silicon Mac 上是正常的)${NC}" + cd .. + rm -rf ios_simulator_x86 + fi + + # 构建 arm64 模拟器版本 + echo -e "${BLUE}构建 iOS 模拟器版本 (arm64)...${NC}" + rm -rf ios_simulator_arm64 + mkdir ios_simulator_arm64 + cd ios_simulator_arm64 + if ! cmake "$PROJECT_ROOT" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ + -DENABLE_BITCODE=0 \ + -DMNN_AAPL_FMWK=1 \ + -DMNN_SEP_BUILD=OFF \ + -DMNN_BUILD_SHARED_LIBS=false \ + -DMNN_USE_THREAD_POOL=OFF \ + -DPLATFORM=SIMULATOR64 \ + -DARCHS="arm64" \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_METAL=OFF \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON; then + echo -e "${RED}错误: arm64 模拟器 cmake 配置失败${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + if ! make MNN -j8; then + echo -e "${RED}错误: arm64 模拟器编译失败${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + cd .. + + # 验证构建产物 + if [ ! 
-f "ios_simulator_arm64/MNN.framework/MNN" ]; then + echo -e "${RED}错误: 未找到 arm64 模拟器框架文件${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + + # 合并模拟器架构 + echo -e "${BLUE}合并模拟器架构...${NC}" + rm -rf ios_simulator + mkdir ios_simulator + + if [ "$BUILD_X86_64" = true ] && [ -f "ios_simulator_x86/MNN.framework/MNN" ]; then + # 合并 x86_64 + arm64 + echo -e "${BLUE}合并 x86_64 + arm64 架构...${NC}" + cp -R ios_simulator_x86/MNN.framework ios_simulator/MNN.framework + mv ios_simulator/MNN.framework/MNN ios_simulator/MNN.framework/MNN_x86 + if ! lipo -create ios_simulator/MNN.framework/MNN_x86 ios_simulator_arm64/MNN.framework/MNN -output ios_simulator/MNN.framework/MNN; then + echo -e "${RED}错误: 合并架构失败${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + rm ios_simulator/MNN.framework/MNN_x86 + else + # 仅使用 arm64 + echo -e "${BLUE}仅使用 arm64 架构(x86_64 不可用)...${NC}" + cp -R ios_simulator_arm64/MNN.framework ios_simulator/MNN.framework + fi + + # 输出模拟器版本 + mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" + rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" + cp -R ios_simulator/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" + + # 清理临时目录 + rm -rf ios_simulator ios_simulator_x86 ios_simulator_arm64 + + echo -e "${GREEN}iOS 模拟器构建完成!${NC}" + echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建鸿蒙版本 +# ============================================================================ +if [ "$BUILD_HARMONY" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建鸿蒙版本${NC}" + echo -e "${GREEN}========================================${NC}" + + export HARMONY_HOME + + HARMONY_BUILD_DIR="project/harmony" + HARMONY_OUTPUT_DIR="$OUTPUT_DIR/harmony" + + cd "$HARMONY_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的鸿蒙构建...${NC}" + rm -rf build_64 export + fi + + # 在项目根目录创建输出目录 + mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" + + # 构建 arm64-v8a + 
echo -e "${BLUE}构建 arm64-v8a...${NC}" + mkdir -p build_64 + cd build_64 + + cmake "$PROJECT_ROOT" \ + -DCMAKE_TOOLCHAIN_FILE=$HARMONY_HOME/build/cmake/ohos.toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DOHOS_ARCH="arm64-v8a" \ + -DOHOS_STL=c++_static \ + -DMNN_USE_LOGCAT=true \ + -DMNN_BUILD_BENCHMARK=ON \ + -DMNN_USE_SSE=OFF \ + -DMNN_SUPPORT_BF16=OFF \ + -DMNN_BUILD_TEST=ON \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DMNN_SEP_BUILD=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + -DOHOS_PLATFORM_LEVEL=9 \ + -DNATIVE_LIBRARY_OUTPUT=. \ + -DNATIVE_INCLUDE_OUTPUT=. \ + > /dev/null + + make -j4 MNN > /dev/null + cd .. + + # 导出文件 + echo -e "${BLUE}导出鸿蒙库文件...${NC}" + mkdir -p export/harmony/arm64-v8a/libs export/harmony/include/MNN + + cp build_64/*.so export/harmony/arm64-v8a/libs/ 2>/dev/null || true + cp -r ../../include/MNN/* export/harmony/include/MNN/ + + # 复制到统一输出目录 + # 如果目标路径存在但不是目录,先删除 + if [ -e "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ] && [ ! -d "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ]; then + rm -f "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" + fi + mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" + cp -r export/harmony/* "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR/" + + echo -e "${GREEN}鸿蒙构建完成!${NC}" + echo "输出位置: $HARMONY_OUTPUT_DIR" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建 Python 版本 +# ============================================================================ +if [ "$BUILD_PYTHON" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 Python 版本${NC}" + echo -e "${GREEN}========================================${NC}" + + # 检查 Python + if ! 
command -v $PYTHON_CMD &> /dev/null; then + echo -e "${RED}错误: 找不到 Python 命令 '$PYTHON_CMD'${NC}" + exit 1 + fi + + PYTHON_OUTPUT_DIR="$OUTPUT_DIR/python" + + # 使用之前创建的 build_pymnn.sh 脚本 + if [ -f "$PROJECT_ROOT/build_pymnn.sh" ]; then + PYTHON_BUILD_ARGS="-o $PYTHON_OUTPUT_DIR" + [ -n "$VERSION" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -v $VERSION" + [ -n "$DEPS_OPTIONS" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -d $DEPS_OPTIONS" + [ "$CLEAN_BUILD" = false ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --no-clean" + PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --python $PYTHON_CMD" + + bash "$PROJECT_ROOT/build_pymnn.sh" $PYTHON_BUILD_ARGS + else + echo -e "${YELLOW}警告: 未找到 build_pymnn.sh,使用基本构建方式...${NC}" + + cd pymnn/pip_package + + # 构建依赖 + if [ -n "$DEPS_OPTIONS" ]; then + $PYTHON_CMD build_deps.py $DEPS_OPTIONS + else + $PYTHON_CMD build_deps.py + fi + + # 构建 wheel + $PYTHON_CMD -m pip install -U numpy wheel setuptools --quiet + rm -rf build dist + + BUILD_ARGS="" + [ -n "$VERSION" ] && BUILD_ARGS="--version $VERSION" + [ -n "$DEPS_OPTIONS" ] && BUILD_ARGS="$BUILD_ARGS --deps $DEPS_OPTIONS" + + $PYTHON_CMD setup.py bdist_wheel $BUILD_ARGS + + mkdir -p "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR" + cp dist/*.whl "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR/" + + cd "$PROJECT_ROOT" + fi + + echo -e "${GREEN}Python 构建完成!${NC}" + echo "输出位置: $PYTHON_OUTPUT_DIR" +fi + +# ============================================================================ +# 总结 +# ============================================================================ +echo "" +echo -e "${GREEN}========================================${NC}" +echo -e "${GREEN}所有构建完成!${NC}" +echo -e "${GREEN}========================================${NC}" +echo "" +echo "输出目录: $PROJECT_ROOT/$OUTPUT_DIR" +echo "" +[ "$BUILD_ANDROID" = true ] && echo -e "${GREEN}Android:${NC} $OUTPUT_DIR/android" +[ "$BUILD_IOS" = true ] && echo -e "${GREEN}iOS (真机):${NC} $OUTPUT_DIR/ios/device" +[ "$BUILD_IOS_SIMULATOR" = true ] && echo -e "${GREEN}iOS (模拟器):${NC} 
$OUTPUT_DIR/ios/simulator" +[ "$BUILD_HARMONY" = true ] && echo -e "${GREEN}鸿蒙:${NC} $OUTPUT_DIR/harmony" +[ "$BUILD_PYTHON" = true ] && echo -e "${GREEN}Python:${NC} $OUTPUT_DIR/python" +echo "" + + From 4944ab9def8c861e368c3e9b0538f7141408fd98 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 11:04:54 +0800 Subject: [PATCH 073/314] Merge pull request #4009 from HenryDen/default_opt Add a compile option and macro to default enable kleidiAI GitOrigin-RevId: a3bc314f99ee49f10608550b92d4b37e0ca2d8f0 --- CMakeLists.txt | 1 + source/backend/cpu/arm/CMakeLists.txt | 3 +++ source/core/Backend.hpp | 6 ++++-- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 67502b606b..f99e37ec1c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -258,6 +258,7 @@ option(MNN_VULKAN "Enable Vulkan" OFF) option(MNN_ARM82 "Enable ARMv8.2's FP16 Compute" ON) option(MNN_SUPPORT_FP16_ARMV7 "Enable ARMv8.2's FP16 Compute for armv7 arch, may cause library not valid for 32 bit cpu" OFF) option(MNN_KLEIDIAI "Enable KLEIDIAI" ON) +option(MNN_KLEIDIAI_DEFAULT_ON "Use KLEIDIAI kernels by default" OFF) option(MNN_ONEDNN "Enable oneDNN" OFF) option(MNN_AVX2 "Open AVX2 Compile for x86 if possible" ON) option(MNN_AVX512 "Enable AVX512" OFF) diff --git a/source/backend/cpu/arm/CMakeLists.txt b/source/backend/cpu/arm/CMakeLists.txt index 18fca54a4e..61ebce6bdc 100644 --- a/source/backend/cpu/arm/CMakeLists.txt +++ b/source/backend/cpu/arm/CMakeLists.txt @@ -36,6 +36,9 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR AR if (MNN_KLEIDIAI) include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/KleidiAI.cmake) download_kleidiai_and_collect_sources() + if(MNN_KLEIDIAI_DEFAULT_ON) + add_definitions(-DMNN_DEFAULT_USE_KLEIDIAI) + endif() endif() if (MNN_SME2) diff --git a/source/core/Backend.hpp b/source/core/Backend.hpp index bcf618c3c9..6850b6b4f6 100644 --- a/source/core/Backend.hpp +++ b/source/core/Backend.hpp @@ -68,9 +68,11 
@@ struct RuntimeHint { // whether to use Arm sme2 cores when threads>1 bool useArmSme2Cores = true; - +#ifdef MNN_DEFAULT_USE_KLEIDIAI + bool enableKleidiAI = true; +#else bool enableKleidiAI = false; - +#endif // Use CPU Ids std::vector cpuIds; From 3704aee7152a2f30ee733bd2416893637c9dfb57 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 11:42:22 +0800 Subject: [PATCH 074/314] Merge branch feature/add_4th_groupchat into master Title: [Doc:Update] update dingtalk in README. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 本次代码评审的主要改动是对README文件中的钉钉群信息进行了更新,包括群号、状态以及删除了一些过时的信息。 Link: https://code.alibaba-inc.com/AliNN/AliNNPrivate/codereview/25029869 GitOrigin-RevId: 3e482c2332f0a4f4088ff8bdf75048eb51177330 --- README.md | 14 +++++++------- README_CN.md | 10 ++++------ README_JP.md | 9 +++++---- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 5fe168ed05..7959890c16 100644 --- a/README.md +++ b/README.md @@ -6,13 +6,13 @@ [![日本語バージョン](https://img.shields.io/badge/Language-%E6%97%A5%E6%9C%AC%E8%AA%9E-green)](README_JP.md) [![MNN Homepage](https://img.shields.io/badge/Homepage-Visit-green)](http://www.mnn.zone) -[![MNN Chat App](https://img.shields.io/badge/Apps-MNN_Chat-blue)](./apps/Android/MnnLlmChat/README.md) -[![TaoAvatar](https://img.shields.io/badge/Apps-MNN_TaoAvatar-blue)](./apps/Android/Mnn3dAvatar/README.md) +[![MNN Chat App](https://img.shields.io/badge/Apps-MNN_Chat-blue)](./apps/Android/MnnLlmChat/README.md) +[![TaoAvatar](https://img.shields.io/badge/Apps-MNN_TaoAvatar-blue)](./apps/Android/Mnn3dAvatar/README.md) ## News 🔥 - [2025/10/16] Support Qwen3-VL Series. -- [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! 
[MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md) +- [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! [MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md)

Icon

@@ -154,13 +154,13 @@ The group discussions are predominantly Chinese. But we welcome and will help En Dingtalk discussion groups: -Group #1 (Full): 23329087 +Group #4 (Available): 160170007549 -Group #2 (Full): 23350225 +Group #3 (Full) -Group #3: QR code: +Group #2 (Full): 23350225 -![MNN-3](doc/dingdingmnn3.png) +Group #1 (Full): 23329087 ## Historical Paper diff --git a/README_CN.md b/README_CN.md index edcf823a28..f769a1e14b 100644 --- a/README_CN.md +++ b/README_CN.md @@ -111,12 +111,10 @@ MNN适配的硬件架构与精度详见下表: ## 社区交流与反馈 钉钉群组: -- 钉钉群1:23329087 -- 钉钉群2:23350225 -- 钉钉群3:扫描二维码加入 - -![MNN-3](doc/dingdingmnn3.png) - +- 钉钉群3 (可加入): 160170007549 +- 钉钉群3 (已无法加入) +- 钉钉群2 (已满): 23350225 +- 钉钉群1 (已满): 23329087 ## 历史论文 diff --git a/README_JP.md b/README_JP.md index c2baa58d94..2f33def31a 100644 --- a/README_JP.md +++ b/README_JP.md @@ -117,13 +117,14 @@ MNN(テンソル計算エンジン)に基づいて、推論、トレーニ Dingtalkディスカッショングループ: -グループ#1(満員):23329087 -グループ#2(満員):23350225 +グループ#4 :160170007549 -グループ#3:QRコード: +グループ#3 (満員) -![MNN-3](doc/dingdingmnn3.png) +グループ#2(満員):23350225 + +グループ#1(満員):23329087 ## 歴史的な論文 From 2c8637115a71cce4d8df8c07c0df58e736c55963 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 09:56:50 +0800 Subject: [PATCH 075/314] Project import generated by Copybara. 
GitOrigin-RevId: f936e7dcb1d1dbef608b5a01ad46ce1da8fca7de --- CMakeLists.txt | 1 - README.md | 14 +- README_CN.md | 10 +- README_JP.md | 9 +- build_lib.sh | 807 ------------------ docs/transformers/diffusion.md | 3 +- source/backend/cpu/arm/CMakeLists.txt | 3 - .../cpu/riscv/rvv/CPUBilinearLineC4.cpp | 19 - .../cpu/riscv/rvv/CPUBilinearSampleC4.cpp | 33 - .../cpu/riscv/rvv/MNNAddC4WithStride.cpp | 29 - .../riscv/rvv/MNNAxByClampBroadcastUnit.cpp | 52 -- source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp | 18 - .../backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp | 20 - source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp | 20 - .../cpu/riscv/rvv/MNNBilinearLineC8.cpp | 40 - .../cpu/riscv/rvv/MNNBilinearSampleC8.cpp | 49 -- source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp | 20 - .../riscv/rvv/MNNConvRunForLineDepthwise.cpp | 48 -- .../cpu/riscv/rvv/MNNCopyC4WithStride.cpp | 22 - .../backend/cpu/riscv/rvv/MNNCubicLineC16.cpp | 53 -- .../backend/cpu/riscv/rvv/MNNCubicLineC4.cpp | 38 - .../cpu/riscv/rvv/MNNCubicSampleC16.cpp | 79 -- .../cpu/riscv/rvv/MNNCubicSampleC4.cpp | 62 -- .../rvv/MNNDeconvRunForUnitDepthWise.cpp | 42 - source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp | 13 - source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp | 16 - source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp | 25 - source/backend/cpu/riscv/rvv/MNNMinFloat.cpp | 25 - source/backend/cpu/riscv/rvv/MNNPackC2.cpp | 74 -- source/backend/cpu/riscv/rvv/MNNPackC4.cpp | 80 -- source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp | 17 - .../backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp | 20 - .../backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp | 20 - source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp | 17 - source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp | 20 - .../cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp | 45 - .../cpu/riscv/rvv/MNNScaleAndAddBias.cpp | 42 - source/backend/cpu/riscv/rvv/MNNSoftmax.cpp | 80 -- .../riscv/rvv/MNNStrassenMergeCFunction.cpp | 36 - .../cpu/riscv/rvv/MNNTranspose16Bit.cpp | 26 - .../cpu/riscv/rvv/MNNTranspose32Bit.cpp | 25 - 
source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp | 55 -- .../cpu/riscv/rvv/MNNVectorTop1Float.cpp | 37 - .../cpu/riscv/rvv/MNNVectorTop1Int32.cpp | 37 - source/core/Backend.hpp | 6 +- transformers/diffusion/export/onnx_export.py | 30 +- 46 files changed, 41 insertions(+), 2196 deletions(-) delete mode 100644 build_lib.sh delete mode 100644 source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNMinFloat.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNPackC2.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNPackC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp delete mode 100644 
source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNSoftmax.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp delete mode 100644 source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f99e37ec1c..67502b606b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -258,7 +258,6 @@ option(MNN_VULKAN "Enable Vulkan" OFF) option(MNN_ARM82 "Enable ARMv8.2's FP16 Compute" ON) option(MNN_SUPPORT_FP16_ARMV7 "Enable ARMv8.2's FP16 Compute for armv7 arch, may cause library not valid for 32 bit cpu" OFF) option(MNN_KLEIDIAI "Enable KLEIDIAI" ON) -option(MNN_KLEIDIAI_DEFAULT_ON "Use KLEIDIAI kernels by default" OFF) option(MNN_ONEDNN "Enable oneDNN" OFF) option(MNN_AVX2 "Open AVX2 Compile for x86 if possible" ON) option(MNN_AVX512 "Enable AVX512" OFF) diff --git a/README.md b/README.md index 7959890c16..5fe168ed05 100644 --- a/README.md +++ b/README.md @@ -6,13 +6,13 @@ [![日本語バージョン](https://img.shields.io/badge/Language-%E6%97%A5%E6%9C%AC%E8%AA%9E-green)](README_JP.md) [![MNN Homepage](https://img.shields.io/badge/Homepage-Visit-green)](http://www.mnn.zone) -[![MNN Chat App](https://img.shields.io/badge/Apps-MNN_Chat-blue)](./apps/Android/MnnLlmChat/README.md) 
-[![TaoAvatar](https://img.shields.io/badge/Apps-MNN_TaoAvatar-blue)](./apps/Android/Mnn3dAvatar/README.md) +[![MNN Chat App](https://img.shields.io/badge/Apps-MNN_Chat-blue)](./apps/Android/MnnLlmChat/README.md) +[![TaoAvatar](https://img.shields.io/badge/Apps-MNN_TaoAvatar-blue)](./apps/Android/Mnn3dAvatar/README.md) ## News 🔥 - [2025/10/16] Support Qwen3-VL Series. -- [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! [MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md) +- [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! [MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md)

Icon

@@ -154,13 +154,13 @@ The group discussions are predominantly Chinese. But we welcome and will help En Dingtalk discussion groups: -Group #4 (Available): 160170007549 - -Group #3 (Full) +Group #1 (Full): 23329087 Group #2 (Full): 23350225 -Group #1 (Full): 23329087 +Group #3: QR code: + +![MNN-3](doc/dingdingmnn3.png) ## Historical Paper diff --git a/README_CN.md b/README_CN.md index f769a1e14b..edcf823a28 100644 --- a/README_CN.md +++ b/README_CN.md @@ -111,10 +111,12 @@ MNN适配的硬件架构与精度详见下表: ## 社区交流与反馈 钉钉群组: -- 钉钉群3 (可加入): 160170007549 -- 钉钉群3 (已无法加入) -- 钉钉群2 (已满): 23350225 -- 钉钉群1 (已满): 23329087 +- 钉钉群1:23329087 +- 钉钉群2:23350225 +- 钉钉群3:扫描二维码加入 + +![MNN-3](doc/dingdingmnn3.png) + ## 历史论文 diff --git a/README_JP.md b/README_JP.md index 2f33def31a..c2baa58d94 100644 --- a/README_JP.md +++ b/README_JP.md @@ -117,14 +117,13 @@ MNN(テンソル計算エンジン)に基づいて、推論、トレーニ Dingtalkディスカッショングループ: - -グループ#4 :160170007549 - -グループ#3 (満員) +グループ#1(満員):23329087 グループ#2(満員):23350225 -グループ#1(満員):23329087 +グループ#3:QRコード: + +![MNN-3](doc/dingdingmnn3.png) ## 歴史的な論文 diff --git a/build_lib.sh b/build_lib.sh deleted file mode 100644 index c839b6e7b6..0000000000 --- a/build_lib.sh +++ /dev/null @@ -1,807 +0,0 @@ -#!/bin/bash - -# MNN 统一构建脚本 -# 支持构建 Android、iOS、鸿蒙和 Python 版本的 MNN - -set -e - -# 颜色输出 -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -BLUE='\033[0;34m' -NC='\033[0m' # No Color - -# 默认配置 -BUILD_ANDROID=false -BUILD_IOS=false -BUILD_PYTHON=false -BUILD_IOS_SIMULATOR=false -BUILD_HARMONY=false -ANDROID_NDK="" -HARMONY_HOME="" -OUTPUT_DIR="mnn_builds" -CLEAN_BUILD=true -VERSION="" -PYTHON_CMD="python3" -DEPS_OPTIONS="" - -# 获取项目根目录 -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -PROJECT_ROOT="$SCRIPT_DIR" - -print_usage() { - echo "MNN 统一构建脚本" - echo "" - echo "用法: $0 [选项]" - echo "" - echo "构建目标:" - echo " --android 构建 Android 版本 (需要 --ndk)" - echo " --ios 构建 iOS 真机版本 (arm64)" - echo " --ios-simulator 构建 iOS 模拟器版本 (x86_64 + arm64)" - echo " --harmony 构建鸿蒙版本 (需要 
--harmony-home 或设置 HARMONY_HOME)" - echo " --python 构建 Python 版本" - echo "" - echo "Android 选项:" - echo " --ndk PATH Android NDK 路径 (例如: ~/Library/Android/sdk/ndk/29.0.13599879)" - echo "" - echo "鸿蒙选项:" - echo " --harmony-home PATH 鸿蒙工具链路径 (例如: ~/Library/OpenHarmony/Sdk/native)" - echo " 如果未指定,将优先查找 ~/Library/OpenHarmony/Sdk/native" - echo "" - echo "Python 选项:" - echo " --python-deps OPTIONS Python 依赖选项,多个用逗号分隔" - echo " 可用选项: llm,opencl,cuda,torch,render,vulkan,internal,no_sse,openmp" - echo " --python-cmd CMD Python 命令 (默认: python3)" - echo "" - echo "通用选项:" - echo " -o, --output DIR 输出目录 (默认: mnn_builds)" - echo " -v, --version VERSION 版本号 (默认: 自动从源码读取)" - echo " --no-clean 不清理之前的构建目录" - echo " -h, --help 显示帮助信息" - echo "" - echo "示例:" - echo " # 构建所有平台" - echo " $0 --android --ios --harmony --python --ndk ~/Library/Android/sdk/ndk/29.0.13599879" - echo "" - echo " # 仅构建 Android" - echo " $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879" - echo "" - echo " # 构建鸿蒙版本" - echo " $0 --harmony --harmony-home ~/Library/OpenHarmony/Sdk/native" - echo "" - echo " # 构建 iOS 真机和模拟器" - echo " $0 --ios --ios-simulator" - echo "" - echo " # 构建 Python (带 LLM 支持)" - echo " $0 --python --python-deps llm,opencl" -} - -# 解析命令行参数 -while [[ $# -gt 0 ]]; do - case $1 in - --android) - BUILD_ANDROID=true - shift - ;; - --ios) - BUILD_IOS=true - shift - ;; - --ios-simulator) - BUILD_IOS_SIMULATOR=true - shift - ;; - --python) - BUILD_PYTHON=true - shift - ;; - --harmony) - BUILD_HARMONY=true - shift - ;; - --ndk) - ANDROID_NDK="$2" - shift 2 - ;; - --harmony-home) - HARMONY_HOME="$2" - shift 2 - ;; - --python-deps) - DEPS_OPTIONS="$2" - shift 2 - ;; - --python-cmd) - PYTHON_CMD="$2" - shift 2 - ;; - -o|--output) - OUTPUT_DIR="$2" - shift 2 - ;; - -v|--version) - VERSION="$2" - shift 2 - ;; - --no-clean) - CLEAN_BUILD=false - shift - ;; - -h|--help) - print_usage - exit 0 - ;; - *) - echo -e "${RED}错误: 未知选项 $1${NC}" - print_usage - exit 1 - ;; - esac -done - -# 
检查是否至少选择了一个构建目标 -if [ "$BUILD_ANDROID" = false ] && [ "$BUILD_IOS" = false ] && [ "$BUILD_IOS_SIMULATOR" = false ] && [ "$BUILD_HARMONY" = false ] && [ "$BUILD_PYTHON" = false ]; then - echo -e "${RED}错误: 请至少选择一个构建目标 (--android, --ios, --ios-simulator, --harmony, 或 --python)${NC}" - print_usage - exit 1 -fi - -# 检查 Android NDK -if [ "$BUILD_ANDROID" = true ]; then - if [ -z "$ANDROID_NDK" ]; then - echo -e "${RED}错误: 构建 Android 必须指定 NDK 路径 (使用 --ndk)${NC}" - echo -e "${RED}示例: $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879${NC}" - exit 1 - fi - - # 展开路径中的 ~ - ANDROID_NDK="${ANDROID_NDK/#\~/$HOME}" - - if [ ! -d "$ANDROID_NDK" ]; then - echo -e "${RED}错误: NDK 路径不存在: $ANDROID_NDK${NC}" - exit 1 - fi -fi - -# 查找鸿蒙工具链的函数 -find_harmony_toolchain() { - # 如果环境变量已设置,优先使用 - if [ -n "$HARMONY_HOME" ]; then - # 支持两种路径格式:直接指定或带 native 子目录 - if [ -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then - echo "$HARMONY_HOME" - return 0 - elif [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then - echo "$HARMONY_HOME" - return 0 - fi - fi - - # 优先查找 ~/Library/OpenHarmony/Sdk,支持版本号目录 - local sdk_base="$HOME/Library/OpenHarmony/Sdk" - if [ -d "$sdk_base" ]; then - # 查找所有版本号目录下的 native 或 toolchains - # 按版本号倒序排列,优先使用最新版本 - for version_dir in $(ls -d "$sdk_base"/* 2>/dev/null | sort -Vr); do - if [ -d "$version_dir" ]; then - # 尝试 native/build/cmake/ohos.toolchain.cmake - if [ -f "$version_dir/native/build/cmake/ohos.toolchain.cmake" ]; then - echo "$version_dir/native" - return 0 - fi - # 尝试 build/cmake/ohos.toolchain.cmake - if [ -f "$version_dir/build/cmake/ohos.toolchain.cmake" ]; then - echo "$version_dir" - return 0 - fi - fi - done - fi - - # 其他可能的路径 - local possible_paths=( - "$HOME/Library/OpenHarmony/Sdk/native" - "$HOME/HarmonyOS/Sdk/native" - "$HOME/.ohos/native" - "/opt/HarmonyOS/Sdk/native" - "/usr/local/HarmonyOS/Sdk/native" - "$HOME/Library/HarmonyOS/Sdk/native" - ) - - # 尝试查找 - for path in "${possible_paths[@]}"; do - if [ -n "$path" ] 
&& [ -f "$path/build/cmake/ohos.toolchain.cmake" ]; then - echo "$path" - return 0 - fi - done - - # 限制搜索范围,只在 OpenHarmony/Sdk 目录下快速查找 - # 避免在整个 OpenHarmony 目录下递归查找,这可能会很慢 - local found=$(find "$HOME/Library/OpenHarmony/Sdk" -maxdepth 4 -type f -name "ohos.toolchain.cmake" 2>/dev/null | head -1) - if [ -n "$found" ]; then - # 从 ohos.toolchain.cmake 向上查找 native 目录或 SDK 根目录 - found=$(dirname "$found") - if [ "$(basename "$found")" = "cmake" ]; then - found=$(dirname "$found") - if [ "$(basename "$found")" = "build" ]; then - found=$(dirname "$found") - fi - fi - echo "$found" - return 0 - fi - - return 1 -} - -# 检查鸿蒙工具链 -if [ "$BUILD_HARMONY" = true ]; then - # 展开路径中的 ~ - if [ -n "$HARMONY_HOME" ]; then - HARMONY_HOME="${HARMONY_HOME/#\~/$HOME}" - fi - - # 尝试查找工具链 - if [ -z "$HARMONY_HOME" ] || [ ! -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then - # 检查是否在 native 子目录下 - if [ -n "$HARMONY_HOME" ] && [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then - HARMONY_HOME="$HARMONY_HOME/native" - else - echo -e "${YELLOW}正在查找鸿蒙工具链...${NC}" - HARMONY_HOME=$(find_harmony_toolchain) - - if [ -z "$HARMONY_HOME" ]; then - echo -e "${RED}错误: 找不到鸿蒙工具链${NC}" - echo -e "${RED}默认查找路径: ~/Library/OpenHarmony/Sdk/*/native${NC}" - echo -e "${RED}请使用 --harmony-home 指定路径,或设置 HARMONY_HOME 环境变量${NC}" - echo -e "${RED}工具链文件应位于: /build/cmake/ohos.toolchain.cmake 或 /native/build/cmake/ohos.toolchain.cmake${NC}" - exit 1 - fi - - # 如果找到的是带 native 的路径,需要调整 - if [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then - HARMONY_HOME="$HARMONY_HOME/native" - fi - fi - - echo -e "${GREEN}找到鸿蒙工具链: $HARMONY_HOME${NC}" - fi - - # 验证工具链文件存在 - if [ ! 
-f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then - echo -e "${RED}错误: 鸿蒙工具链文件不存在: $HARMONY_HOME/build/cmake/ohos.toolchain.cmake${NC}" - exit 1 - fi -fi - -echo -e "${GREEN}========================================${NC}" -echo -e "${GREEN}MNN 统一构建脚本${NC}" -echo -e "${GREEN}========================================${NC}" -echo "项目根目录: $PROJECT_ROOT" -echo "输出目录: $OUTPUT_DIR" -echo "" -echo -e "${BLUE}构建目标:${NC}" -[ "$BUILD_ANDROID" = true ] && echo " ✓ Android" -[ "$BUILD_IOS" = true ] && echo " ✓ iOS (真机 arm64)" -[ "$BUILD_IOS_SIMULATOR" = true ] && echo " ✓ iOS (模拟器 x86_64 + arm64)" -[ "$BUILD_HARMONY" = true ] && echo " ✓ 鸿蒙 (arm64-v8a)" -[ "$BUILD_PYTHON" = true ] && echo " ✓ Python" -echo "" - -cd "$PROJECT_ROOT" -mkdir -p "$OUTPUT_DIR" - -# ============================================================================ -# 构建 Android 版本 -# ============================================================================ -if [ "$BUILD_ANDROID" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建 Android 版本${NC}" - echo -e "${GREEN}========================================${NC}" - - export ANDROID_NDK - - ANDROID_BUILD_DIR="project/android" - ANDROID_OUTPUT_DIR="$OUTPUT_DIR/android" - - cd "$ANDROID_BUILD_DIR" - - # 清理 - if [ "$CLEAN_BUILD" = true ]; then - echo -e "${YELLOW}清理之前的 Android 构建...${NC}" - rm -rf build_32 build_64 export - fi - - # 在项目根目录创建输出目录 - mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" - - # 构建 armeabi-v7a - echo -e "${BLUE}构建 armeabi-v7a...${NC}" - mkdir -p build_32 - cd build_32 - cmake ../../../ \ - -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ - -DCMAKE_BUILD_TYPE=Release \ - -DANDROID_ABI="armeabi-v7a" \ - -DANDROID_STL=c++_static \ - -DANDROID_NATIVE_API_LEVEL=android-14 \ - -DANDROID_TOOLCHAIN=clang \ - -DMNN_USE_LOGCAT=false \ - -DMNN_USE_SSE=OFF \ - -DMNN_BUILD_TEST=ON \ - -DMNN_ARM82=OFF \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - 
-DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DMNN_SEP_BUILD=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ - -DNATIVE_LIBRARY_OUTPUT=. \ - -DNATIVE_INCLUDE_OUTPUT=. \ - > /dev/null - - make -j4 MNN > /dev/null - cd .. - - # 构建 arm64-v8a - echo -e "${BLUE}构建 arm64-v8a...${NC}" - mkdir -p build_64 - cd build_64 - cmake ../../../ \ - -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ - -DCMAKE_BUILD_TYPE=Release \ - -DANDROID_ABI="arm64-v8a" \ - -DANDROID_STL=c++_static \ - -DANDROID_NATIVE_API_LEVEL=android-21 \ - -DANDROID_TOOLCHAIN=clang \ - -DMNN_USE_LOGCAT=false \ - -DMNN_BUILD_BENCHMARK=ON \ - -DMNN_USE_SSE=OFF \ - -DMNN_BUILD_TEST=ON \ - -DMNN_ARM82=ON \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DMNN_SEP_BUILD=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ - -DNATIVE_LIBRARY_OUTPUT=. \ - -DNATIVE_INCLUDE_OUTPUT=. \ - > /dev/null - - make -j4 MNN > /dev/null - cd .. - - # 导出文件 - echo -e "${BLUE}导出 Android 库文件...${NC}" - mkdir -p export/android/{armeabi-v7a,arm64-v8a}/libs export/android/include/MNN - - cp build_32/*.so export/android/armeabi-v7a/libs/ 2>/dev/null || true - cp build_64/*.so export/android/arm64-v8a/libs/ 2>/dev/null || true - cp -r ../../include/MNN/* export/android/include/MNN/ - - # 复制到统一输出目录 - # 如果目标路径存在但不是目录,先删除 - if [ -e "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ] && [ ! 
-d "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ]; then - rm -f "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" - fi - mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" - cp -r export/android/* "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR/" - - echo -e "${GREEN}Android 构建完成!${NC}" - echo "输出位置: $ANDROID_OUTPUT_DIR" - cd "$PROJECT_ROOT" -fi - -# ============================================================================ -# 构建 iOS 真机版本 -# ============================================================================ -if [ "$BUILD_IOS" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建 iOS 真机版本${NC}" - echo -e "${GREEN}========================================${NC}" - - # 检查是否在 macOS 上 - if [[ "$OSTYPE" != "darwin"* ]]; then - echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" - exit 1 - fi - - # 检查 Xcode - if ! command -v xcodebuild &> /dev/null; then - echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" - exit 1 - fi - - IOS_BUILD_DIR="project/ios" - IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/device" - - cd "$IOS_BUILD_DIR" - - # 清理 - if [ "$CLEAN_BUILD" = true ]; then - echo -e "${YELLOW}清理之前的 iOS 真机构建...${NC}" - rm -rf MNN-iOS-CPU-GPU/Static/ios_64 - find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_64/*" 2>/dev/null | xargs rm -f 2>/dev/null || true - fi - - mkdir -p MNN-iOS-CPU-GPU/Static - cd MNN-iOS-CPU-GPU/Static - - # 构建真机版本 (arm64) - echo -e "${BLUE}构建 iOS 真机版本 (arm64)...${NC}" - rm -rf ios_64 - mkdir ios_64 - cd ios_64 - cmake "$PROJECT_ROOT" \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ - -DENABLE_BITCODE=0 \ - -DMNN_AAPL_FMWK=1 \ - -DMNN_SEP_BUILD=OFF \ - -DMNN_BUILD_SHARED_LIBS=false \ - -DMNN_USE_THREAD_POOL=OFF \ - -DPLATFORM=OS64 \ - -DARCHS="arm64" \ - -DMNN_ARM82=ON \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_METAL=ON \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - 
-DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - > /dev/null - make MNN -j8 > /dev/null - cd .. - - # 输出真机版本 - mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" - rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" - cp -R ios_64/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" - - # 清理 - rm -rf ios_64 - - echo -e "${GREEN}iOS 真机构建完成!${NC}" - echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" - cd "$PROJECT_ROOT" -fi - -# ============================================================================ -# 构建 iOS 模拟器版本 -# ============================================================================ -if [ "$BUILD_IOS_SIMULATOR" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建 iOS 模拟器版本${NC}" - echo -e "${GREEN}========================================${NC}" - - # 检查是否在 macOS 上 - if [[ "$OSTYPE" != "darwin"* ]]; then - echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" - exit 1 - fi - - # 检查 Xcode - if ! command -v xcodebuild &> /dev/null; then - echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" - exit 1 - fi - - IOS_BUILD_DIR="project/ios" - IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/simulator" - - cd "$IOS_BUILD_DIR" - - # 清理 - if [ "$CLEAN_BUILD" = true ]; then - echo -e "${YELLOW}清理之前的 iOS 模拟器构建...${NC}" - rm -rf MNN-iOS-CPU-GPU/Static/ios_simulator* - find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_simulator*/*" 2>/dev/null | xargs rm -f 2>/dev/null || true - fi - - mkdir -p MNN-iOS-CPU-GPU/Static - cd MNN-iOS-CPU-GPU/Static - - # 构建 x86_64 模拟器版本(尝试构建,失败则跳过) - BUILD_X86_64=false - echo -e "${BLUE}构建 iOS 模拟器版本 (x86_64)...${NC}" - rm -rf ios_simulator_x86 - mkdir ios_simulator_x86 - cd ios_simulator_x86 - if cmake "$PROJECT_ROOT" \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ - -DENABLE_BITCODE=0 \ - -DMNN_AAPL_FMWK=1 \ - -DMNN_SEP_BUILD=OFF \ - -DMNN_BUILD_SHARED_LIBS=false \ - -DMNN_USE_THREAD_POOL=OFF \ - -DPLATFORM=SIMULATOR64 \ - 
-DARCHS="x86_64" \ - -DMNN_ARM82=OFF \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_METAL=OFF \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - && make MNN -j8; then - if [ -f "MNN.framework/MNN" ]; then - BUILD_X86_64=true - echo -e "${GREEN}x86_64 模拟器构建成功${NC}" - cd .. - else - echo -e "${YELLOW}警告: x86_64 模拟器构建产物不存在,跳过此架构${NC}" - cd .. - rm -rf ios_simulator_x86 - fi - else - echo -e "${YELLOW}警告: x86_64 模拟器构建失败,跳过此架构(这在 Apple Silicon Mac 上是正常的)${NC}" - cd .. - rm -rf ios_simulator_x86 - fi - - # 构建 arm64 模拟器版本 - echo -e "${BLUE}构建 iOS 模拟器版本 (arm64)...${NC}" - rm -rf ios_simulator_arm64 - mkdir ios_simulator_arm64 - cd ios_simulator_arm64 - if ! cmake "$PROJECT_ROOT" \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ - -DENABLE_BITCODE=0 \ - -DMNN_AAPL_FMWK=1 \ - -DMNN_SEP_BUILD=OFF \ - -DMNN_BUILD_SHARED_LIBS=false \ - -DMNN_USE_THREAD_POOL=OFF \ - -DPLATFORM=SIMULATOR64 \ - -DARCHS="arm64" \ - -DMNN_ARM82=ON \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_METAL=OFF \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON; then - echo -e "${RED}错误: arm64 模拟器 cmake 配置失败${NC}" - cd "$PROJECT_ROOT" - exit 1 - fi - if ! make MNN -j8; then - echo -e "${RED}错误: arm64 模拟器编译失败${NC}" - cd "$PROJECT_ROOT" - exit 1 - fi - cd .. - - # 验证构建产物 - if [ ! 
-f "ios_simulator_arm64/MNN.framework/MNN" ]; then - echo -e "${RED}错误: 未找到 arm64 模拟器框架文件${NC}" - cd "$PROJECT_ROOT" - exit 1 - fi - - # 合并模拟器架构 - echo -e "${BLUE}合并模拟器架构...${NC}" - rm -rf ios_simulator - mkdir ios_simulator - - if [ "$BUILD_X86_64" = true ] && [ -f "ios_simulator_x86/MNN.framework/MNN" ]; then - # 合并 x86_64 + arm64 - echo -e "${BLUE}合并 x86_64 + arm64 架构...${NC}" - cp -R ios_simulator_x86/MNN.framework ios_simulator/MNN.framework - mv ios_simulator/MNN.framework/MNN ios_simulator/MNN.framework/MNN_x86 - if ! lipo -create ios_simulator/MNN.framework/MNN_x86 ios_simulator_arm64/MNN.framework/MNN -output ios_simulator/MNN.framework/MNN; then - echo -e "${RED}错误: 合并架构失败${NC}" - cd "$PROJECT_ROOT" - exit 1 - fi - rm ios_simulator/MNN.framework/MNN_x86 - else - # 仅使用 arm64 - echo -e "${BLUE}仅使用 arm64 架构(x86_64 不可用)...${NC}" - cp -R ios_simulator_arm64/MNN.framework ios_simulator/MNN.framework - fi - - # 输出模拟器版本 - mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" - rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" - cp -R ios_simulator/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" - - # 清理临时目录 - rm -rf ios_simulator ios_simulator_x86 ios_simulator_arm64 - - echo -e "${GREEN}iOS 模拟器构建完成!${NC}" - echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" - cd "$PROJECT_ROOT" -fi - -# ============================================================================ -# 构建鸿蒙版本 -# ============================================================================ -if [ "$BUILD_HARMONY" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建鸿蒙版本${NC}" - echo -e "${GREEN}========================================${NC}" - - export HARMONY_HOME - - HARMONY_BUILD_DIR="project/harmony" - HARMONY_OUTPUT_DIR="$OUTPUT_DIR/harmony" - - cd "$HARMONY_BUILD_DIR" - - # 清理 - if [ "$CLEAN_BUILD" = true ]; then - echo -e "${YELLOW}清理之前的鸿蒙构建...${NC}" - rm -rf build_64 export - fi - - # 在项目根目录创建输出目录 - mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" - - # 构建 arm64-v8a - 
echo -e "${BLUE}构建 arm64-v8a...${NC}" - mkdir -p build_64 - cd build_64 - - cmake "$PROJECT_ROOT" \ - -DCMAKE_TOOLCHAIN_FILE=$HARMONY_HOME/build/cmake/ohos.toolchain.cmake \ - -DCMAKE_BUILD_TYPE=Release \ - -DOHOS_ARCH="arm64-v8a" \ - -DOHOS_STL=c++_static \ - -DMNN_USE_LOGCAT=true \ - -DMNN_BUILD_BENCHMARK=ON \ - -DMNN_USE_SSE=OFF \ - -DMNN_SUPPORT_BF16=OFF \ - -DMNN_BUILD_TEST=ON \ - -DMNN_ARM82=ON \ - -DMNN_LOW_MEMORY=ON \ - -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ - -DMNN_BUILD_LLM=ON \ - -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ - -DMNN_BUILD_DIFFUSION=ON \ - -DMNN_OPENCL=OFF \ - -DMNN_SEP_BUILD=OFF \ - -DLLM_SUPPORT_AUDIO=ON \ - -DMNN_BUILD_AUDIO=ON \ - -DLLM_SUPPORT_VISION=ON \ - -DMNN_BUILD_OPENCV=ON \ - -DMNN_IMGCODECS=ON \ - -DOHOS_PLATFORM_LEVEL=9 \ - -DNATIVE_LIBRARY_OUTPUT=. \ - -DNATIVE_INCLUDE_OUTPUT=. \ - > /dev/null - - make -j4 MNN > /dev/null - cd .. - - # 导出文件 - echo -e "${BLUE}导出鸿蒙库文件...${NC}" - mkdir -p export/harmony/arm64-v8a/libs export/harmony/include/MNN - - cp build_64/*.so export/harmony/arm64-v8a/libs/ 2>/dev/null || true - cp -r ../../include/MNN/* export/harmony/include/MNN/ - - # 复制到统一输出目录 - # 如果目标路径存在但不是目录,先删除 - if [ -e "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ] && [ ! -d "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ]; then - rm -f "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" - fi - mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" - cp -r export/harmony/* "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR/" - - echo -e "${GREEN}鸿蒙构建完成!${NC}" - echo "输出位置: $HARMONY_OUTPUT_DIR" - cd "$PROJECT_ROOT" -fi - -# ============================================================================ -# 构建 Python 版本 -# ============================================================================ -if [ "$BUILD_PYTHON" = true ]; then - echo -e "${GREEN}========================================${NC}" - echo -e "${GREEN}开始构建 Python 版本${NC}" - echo -e "${GREEN}========================================${NC}" - - # 检查 Python - if ! 
command -v $PYTHON_CMD &> /dev/null; then - echo -e "${RED}错误: 找不到 Python 命令 '$PYTHON_CMD'${NC}" - exit 1 - fi - - PYTHON_OUTPUT_DIR="$OUTPUT_DIR/python" - - # 使用之前创建的 build_pymnn.sh 脚本 - if [ -f "$PROJECT_ROOT/build_pymnn.sh" ]; then - PYTHON_BUILD_ARGS="-o $PYTHON_OUTPUT_DIR" - [ -n "$VERSION" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -v $VERSION" - [ -n "$DEPS_OPTIONS" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -d $DEPS_OPTIONS" - [ "$CLEAN_BUILD" = false ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --no-clean" - PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --python $PYTHON_CMD" - - bash "$PROJECT_ROOT/build_pymnn.sh" $PYTHON_BUILD_ARGS - else - echo -e "${YELLOW}警告: 未找到 build_pymnn.sh,使用基本构建方式...${NC}" - - cd pymnn/pip_package - - # 构建依赖 - if [ -n "$DEPS_OPTIONS" ]; then - $PYTHON_CMD build_deps.py $DEPS_OPTIONS - else - $PYTHON_CMD build_deps.py - fi - - # 构建 wheel - $PYTHON_CMD -m pip install -U numpy wheel setuptools --quiet - rm -rf build dist - - BUILD_ARGS="" - [ -n "$VERSION" ] && BUILD_ARGS="--version $VERSION" - [ -n "$DEPS_OPTIONS" ] && BUILD_ARGS="$BUILD_ARGS --deps $DEPS_OPTIONS" - - $PYTHON_CMD setup.py bdist_wheel $BUILD_ARGS - - mkdir -p "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR" - cp dist/*.whl "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR/" - - cd "$PROJECT_ROOT" - fi - - echo -e "${GREEN}Python 构建完成!${NC}" - echo "输出位置: $PYTHON_OUTPUT_DIR" -fi - -# ============================================================================ -# 总结 -# ============================================================================ -echo "" -echo -e "${GREEN}========================================${NC}" -echo -e "${GREEN}所有构建完成!${NC}" -echo -e "${GREEN}========================================${NC}" -echo "" -echo "输出目录: $PROJECT_ROOT/$OUTPUT_DIR" -echo "" -[ "$BUILD_ANDROID" = true ] && echo -e "${GREEN}Android:${NC} $OUTPUT_DIR/android" -[ "$BUILD_IOS" = true ] && echo -e "${GREEN}iOS (真机):${NC} $OUTPUT_DIR/ios/device" -[ "$BUILD_IOS_SIMULATOR" = true ] && echo -e "${GREEN}iOS (模拟器):${NC} 
$OUTPUT_DIR/ios/simulator" -[ "$BUILD_HARMONY" = true ] && echo -e "${GREEN}鸿蒙:${NC} $OUTPUT_DIR/harmony" -[ "$BUILD_PYTHON" = true ] && echo -e "${GREEN}Python:${NC} $OUTPUT_DIR/python" -echo "" - - diff --git a/docs/transformers/diffusion.md b/docs/transformers/diffusion.md index 609793f806..7de27bb216 100644 --- a/docs/transformers/diffusion.md +++ b/docs/transformers/diffusion.md @@ -20,8 +20,7 @@ https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1/tree/mai cd mnn_path/transformers/diffusion/export python onnx_export.py \ --model_path hf_sd_load_path \ - --output_path onnx_save_path \ - --opset 18 + --output_path onnx_save_path ``` 注意,上述脚本需要依赖torch/onnx/diffusers等库,可以安装conda环境: ``` diff --git a/source/backend/cpu/arm/CMakeLists.txt b/source/backend/cpu/arm/CMakeLists.txt index 61ebce6bdc..18fca54a4e 100644 --- a/source/backend/cpu/arm/CMakeLists.txt +++ b/source/backend/cpu/arm/CMakeLists.txt @@ -36,9 +36,6 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR AR if (MNN_KLEIDIAI) include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/KleidiAI.cmake) download_kleidiai_and_collect_sources() - if(MNN_KLEIDIAI_DEFAULT_ON) - add_definitions(-DMNN_DEFAULT_USE_KLEIDIAI) - endif() endif() if (MNN_SME2) diff --git a/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp b/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp deleted file mode 100644 index a700016c31..0000000000 --- a/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp +++ /dev/null @@ -1,19 +0,0 @@ -#include - -void CPUBilinearLineC4(float* dst, const float* A, const float* B, - const float* t, int8_t* zeroPoint, size_t number) { - float tf = *t; - float sf = 1.0f - tf; - size_t total = number << 2; - - size_t i = 0; - while (i < total) { - size_t vl = __riscv_vsetvl_e32m8(total - i); - vfloat32m8_t v = __riscv_vle32_v_f32m8(A + i, vl); - vfloat32m8_t result = __riscv_vfmul_vf_f32m8(v, sf, vl); - v = __riscv_vle32_v_f32m8(B + i, vl); - result = __riscv_vfmacc_vf_f32m8(result, 
tf, v, vl); - __riscv_vse32_v_f32m8(dst + i, result, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp b/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp deleted file mode 100644 index 5063c39bff..0000000000 --- a/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp +++ /dev/null @@ -1,33 +0,0 @@ -#include - -void CPUBilinearSampleC4(const float* src, float* dst, - const int32_t* position, const float* factor, - int8_t* zeroPoint, size_t number) { - const int pack = 4; - size_t i = 0; - - while (i < number) { - size_t vl = __riscv_vsetvl_e32m8(number - i); - vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + i, vl); - - for (int c = 0; c < pack; c++) { - vuint32m8_t voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vfloat32m8_t vr = __riscv_vluxei32_v_f32m8(src, voff, vl); - vfloat32m8_t vsf = __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl); - vr = __riscv_vfmul_vv_f32m8(vr, vsf, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vsf = __riscv_vluxei32_v_f32m8(src, voff, vl); - vr = __riscv_vfmacc_vv_f32m8(vr, vf, vsf, vl); - __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, vr, vl); - } - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp b/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp deleted file mode 100644 index 59bb28a039..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNAddC4WithStride.cpp +++ /dev/null @@ -1,29 +0,0 @@ -#include - -void MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { - ptrdiff_t srcStrideByte = srcStride * sizeof(float); - ptrdiff_t dstStrideByte = dstStride * sizeof(float); - size_t vl; - - for (size_t i = count; i > 0; i -= vl) { - vl = 
__riscv_vsetvl_e32m8(i); - vfloat32m8_t vs = __riscv_vlse32_v_f32m8(source + 0, srcStrideByte, vl); - vfloat32m8_t vd = __riscv_vlse32_v_f32m8(dest + 0, dstStrideByte, vl); - vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); - __riscv_vsse32_v_f32m8(dest + 0, dstStrideByte, vd, vl); - vs = __riscv_vlse32_v_f32m8(source + 1, srcStrideByte, vl); - vd = __riscv_vlse32_v_f32m8(dest + 1, dstStrideByte, vl); - vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); - __riscv_vsse32_v_f32m8(dest + 1, dstStrideByte, vd, vl); - vs = __riscv_vlse32_v_f32m8(source + 2, srcStrideByte, vl); - vd = __riscv_vlse32_v_f32m8(dest + 2, dstStrideByte, vl); - vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); - __riscv_vsse32_v_f32m8(dest + 2, dstStrideByte, vd, vl); - vs = __riscv_vlse32_v_f32m8(source + 3, srcStrideByte, vl); - vd = __riscv_vlse32_v_f32m8(dest + 3, dstStrideByte, vl); - vd = __riscv_vfadd_vv_f32m8(vd, vs, vl); - __riscv_vsse32_v_f32m8(dest + 3, dstStrideByte, vd, vl); - source += vl * srcStride; - dest += vl * dstStride; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp b/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp deleted file mode 100644 index 6d966789f7..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNAxByClampBroadcastUnit.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include - -void MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) { - float beta = parameters[1]; - float minF = parameters[2]; - float maxF = parameters[3]; - const ptrdiff_t stride = 4 * sizeof(float); - - for (int y = 0; y < height; ++y) { - auto a = A + aStride * y; - auto b = B + 4 * y; - auto c = C + cStride * y; - float b0Beta = b[0] * beta; - float b1Beta = b[1] * beta; - float b2Beta = b[2] * beta; - float b3Beta = b[3] * beta; - size_t w = width; - - while (w > 0) { - size_t vl = __riscv_vsetvl_e32m8(w); - - vfloat32m8_t data = __riscv_vlse32_v_f32m8(a + 0, stride, vl); - data = 
__riscv_vfadd_vf_f32m8(data, b0Beta, vl); - data = __riscv_vfmax_vf_f32m8(data, minF, vl); - data = __riscv_vfmin_vf_f32m8(data, maxF, vl); - __riscv_vsse32_v_f32m8(c + 0, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(a + 1, stride, vl); - data = __riscv_vfadd_vf_f32m8(data, b1Beta, vl); - data = __riscv_vfmax_vf_f32m8(data, minF, vl); - data = __riscv_vfmin_vf_f32m8(data, maxF, vl); - __riscv_vsse32_v_f32m8(c + 1, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(a + 2, stride, vl); - data = __riscv_vfadd_vf_f32m8(data, b2Beta, vl); - data = __riscv_vfmax_vf_f32m8(data, minF, vl); - data = __riscv_vfmin_vf_f32m8(data, maxF, vl); - __riscv_vsse32_v_f32m8(c + 2, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(a + 3, stride, vl); - data = __riscv_vfadd_vf_f32m8(data, b3Beta, vl); - data = __riscv_vfmax_vf_f32m8(data, minF, vl); - data = __riscv_vfmin_vf_f32m8(data, maxF, vl); - __riscv_vsse32_v_f32m8(c + 3, stride, data, vl); - - a += 4 * vl; - c += 4 * vl; - w -= vl; - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp deleted file mode 100644 index 145cbea73f..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include - -void MNNBGRAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp deleted file mode 100644 index d46fe6c85b..0000000000 --- 
a/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNBGRAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m4(count - i); - vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); - vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); - - channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); - - channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); - - channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); - __riscv_vse8_v_u8m4(dest + i, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp deleted file mode 100644 index 684db6aed3..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNBRGToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m4(count - i); - vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); - vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); - - channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); - - channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); - - vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); - __riscv_vse8_v_u8m4(dest + i, result, vl); - i += vl; - } -} \ No newline at end of file diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp deleted file mode 100644 index a26243bdb8..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include - -void MNNBilinearLineC8(int8_t* dst, 
const int16_t* A, const int16_t* B, - const float* t, int8_t* zeroPoint, size_t number) { - int offset = *zeroPoint; - int8_t* dstPtr = dst; - - const int pack = 8; - const int16_t df = (int16_t)((*t) * 128.0f); - const int16_t sf = (int16_t)((1.0f - *t) * 128.0f); - const size_t total = number * pack; - const int32_t ROUND_HALF = 1 << 13; - - size_t vl; - for (size_t i = 0; i < total; i += vl) { - vl = __riscv_vsetvl_e16m4(total - i); - vint16m4_t v16 = __riscv_vle16_v_i16m4(A + i, vl); - vint32m8_t v32 = __riscv_vwmul_vx_i32m8(v16, sf, vl); - v16 = __riscv_vle16_v_i16m4(B + i, vl); - v32 = __riscv_vwmacc_vx_i32m8(v32, df, v16, vl); - - vbool4_t mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); - vint32m8_t tmp = __riscv_vadd_vx_i32m8(v32, ROUND_HALF, vl); - v32 = __riscv_vsub_vx_i32m8(v32, ROUND_HALF, vl); - v32 = __riscv_vmerge_vvm_i32m8(tmp, v32, mask, vl); - - tmp = __riscv_vsra_vx_i32m8(v32, 14, vl); - mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); - v32 = __riscv_vand_vx_i32m8(v32, 0x3FFF, vl); - vbool4_t hasRem = __riscv_vmsne_vx_i32m8_b4(v32, 0, vl); - mask = __riscv_vmand_mm_b4(mask, hasRem, vl); - - v32 = __riscv_vadd_vx_i32m8_mu(mask, tmp, tmp, 1, vl); - v32 = __riscv_vadd_vx_i32m8(v32, offset, vl); - v16 = __riscv_vnsra_wx_i16m4(v32, 0, vl); - vint8m2_t v8 = __riscv_vnsra_wx_i8m2(v16, 0, vl); - - __riscv_vse8_v_i8m2(dstPtr + i, v8, vl); - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp deleted file mode 100644 index bd111e3be4..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include - -void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, - const int32_t* position, const float* factor, - int8_t* zeroPoint, size_t number) { - int16_t offset = (int16_t)(*zeroPoint); - const int pack = 8; - size_t i = 0; - - while (i < number) { - size_t vl = __riscv_vsetvl_e32m8(number - i); - vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + 
i, vl); - vint16m4_t vdf = __riscv_vnsra_wx_i16m4( - __riscv_vfcvt_rtz_x_f_v_i32m8( - __riscv_vfmul_vf_f32m8(vf, 128.0f, vl), vl), 0, vl); - vint16m4_t vsf = __riscv_vnsra_wx_i16m4( - __riscv_vfcvt_rtz_x_f_v_i32m8( - __riscv_vfmul_vf_f32m8( - __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl), 128.0f, vl), vl), 0, vl); - - for (int c = 0; c < pack; c++) { - vuint32m8_t voff = __riscv_vadd_vx_u32m8( - __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 3, vl), - c, vl); - - vint16m4_t va = __riscv_vsub_vx_i16m4( - __riscv_vsext_vf2_i16m4( - __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); - - vint32m8_t vr = __riscv_vwmul_vv_i32m8(va, vsf, vl); - voff = __riscv_vadd_vx_u32m8( - __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 3, vl), - c, vl); - - vint16m4_t vb = __riscv_vsub_vx_i16m4( - __riscv_vsext_vf2_i16m4( - __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); - vr = __riscv_vwmacc_vv_i32m8(vr, vb, vdf, vl); - __riscv_vsse16_v_i16m4(dst + i * pack + c, 16, - __riscv_vnsra_wx_i16m4(vr, 0, vl), vl); - } - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp b/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp deleted file mode 100644 index 9d524f13ca..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNC3ToC4(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); - - vuint8m8_t alpha = 
__riscv_vmv_v_x_u8m8(255, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, alpha, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp b/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp deleted file mode 100644 index f82faf83f5..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNConvRunForLineDepthwise.cpp +++ /dev/null @@ -1,48 +0,0 @@ -#include - -void MNNConvRunForLineDepthwise( - float* dst, const float* src, const float* weight, - size_t width, size_t src_w_setup, - size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, - size_t height, size_t srcHStep, size_t dstHStep, - const float* bias, const float* parameters) { - float minV = parameters[0]; - float maxV = parameters[1]; - ptrdiff_t srcByteStride = src_w_setup * sizeof(float); - ptrdiff_t dstByteStride = 4 * sizeof(float); - - for (size_t y = 0; y < height; ++y) { - const float* srcY = src + y * srcHStep; - float* dstY = dst + y * dstHStep; - size_t dx = 0; - - while (dx < width) { - size_t vl = __riscv_vsetvl_e32m8(width - dx); - - for (int c = 0; c < 4; ++c) { - vfloat32m8_t acc = __riscv_vfmv_v_f_f32m8(bias[c], vl); - const float* srcBase = srcY + dx * src_w_setup + c; - const float* weightPtr = weight + c; - - for (size_t fy = 0; fy < fh; ++fy) { - const float* srcFy = srcBase + fy * dilateY_step; - - for (size_t fx = 0; fx < fw; ++fx) { - float w = *weightPtr; - weightPtr += 4; - const float* srcFx = srcFy + fx * dilateX_step; - vfloat32m8_t s = __riscv_vlse32_v_f32m8(srcFx, srcByteStride, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, w, s, vl); - } - } - - acc = __riscv_vfmax_vf_f32m8(acc, minV, vl); - acc = __riscv_vfmin_vf_f32m8(acc, maxV, vl); - float* dstAddr = dstY + dx * 4 + c; - __riscv_vsse32_v_f32m8(dstAddr, dstByteStride, acc, vl); - } - - dx += vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp b/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp deleted file mode 100644 index 3d8c4f13fc..0000000000 
--- a/source/backend/cpu/riscv/rvv/MNNCopyC4WithStride.cpp +++ /dev/null @@ -1,22 +0,0 @@ -#include - -void MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { - ptrdiff_t srcStrideByte = srcStride * sizeof(float); - ptrdiff_t dstStrideByte = dstStride * sizeof(float); -size_t vl; - - for (size_t i = count; i > 0; i -= vl) { - vl = __riscv_vsetvl_e32m8(i); - vfloat32m8_t data = __riscv_vlse32_v_f32m8(source + 0, srcStrideByte, vl); - __riscv_vsse32_v_f32m8(dest + 0, dstStrideByte, data, vl); - data = __riscv_vlse32_v_f32m8(source + 1, srcStrideByte, vl); - __riscv_vsse32_v_f32m8(dest + 1, dstStrideByte, data, vl); - data = __riscv_vlse32_v_f32m8(source + 2, srcStrideByte, vl); - __riscv_vsse32_v_f32m8(dest + 2, dstStrideByte, data, vl); - data = __riscv_vlse32_v_f32m8(source + 3, srcStrideByte, vl); - __riscv_vsse32_v_f32m8(dest + 3, dstStrideByte, data, vl); - source += vl * srcStride; - dest += vl * dstStride; - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp b/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp deleted file mode 100644 index fd6ce7a274..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp +++ /dev/null @@ -1,53 +0,0 @@ -#include - -void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, - const float* C, const float* D, float* t, - int8_t* zeroPoint, size_t number, - ssize_t minValue, ssize_t maxValue) { - const float f = *t; - const float t2 = f * f, t3 = t2 * f; - const float b0 = 1.0f - 2.25f * t2 + 1.25f * t3; - const float t1 = 1.0f - f, t1_2 = t1 * t1; - const float c0 = 1.0f - 2.25f * t1_2 + 1.25f * t1_2 * t1; - const float ta = 1.0f + f, ta2 = ta * ta; - const float a0 = 3.0f - 6.0f * ta + 3.75f * ta2 - 0.75f * ta2 * ta; - const float td = 2.0f - f, td2 = td * td; - const float d0 = 3.0f - 6.0f * td + 3.75f * td2 - 0.75f * td2 * td; - const int offset = *zeroPoint; - const int minVal = (int)minValue; - const int maxVal = (int)maxValue; - const 
size_t total = number << 4; - size_t i = 0; - - while (i < total) { - size_t vl = __riscv_vsetvl_e32m8(total - i); - vfloat32m8_t v, acc; - - v = __riscv_vle32_v_f32m8(A + i, vl); - acc = __riscv_vfmul_vf_f32m8(v, a0, vl); - - v = __riscv_vle32_v_f32m8(B + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, b0, v, vl); - - v = __riscv_vle32_v_f32m8(C + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, c0, v, vl); - - v = __riscv_vle32_v_f32m8(D + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, d0, v, vl); - - vfloat32m8_t half = __riscv_vfmv_v_f_f32m8(0.5f, vl); - vfloat32m8_t signHalf = __riscv_vfsgnj_vv_f32m8(half, acc, vl); - acc = __riscv_vfadd_vv_f32m8(acc, signHalf, vl); - - vint32m8_t vint = __riscv_vfcvt_rtz_x_f_v_i32m8(acc, vl); - vint = __riscv_vadd_vx_i32m8(vint, offset, vl); - vint = __riscv_vmax_vx_i32m8(vint, minVal, vl); - vint = __riscv_vmin_vx_i32m8(vint, maxVal, vl); - - vint16m4_t vi16 = __riscv_vncvt_x_x_w_i16m4(vint, vl); - vint8m2_t vi8 = __riscv_vncvt_x_x_w_i8m2(vi16, vl); - __riscv_vse8_v_i8m2(dst + i, vi8, vl); - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp b/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp deleted file mode 100644 index 0da63ca0ff..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp +++ /dev/null @@ -1,38 +0,0 @@ -#include - -void MNNCubicLineC4(float* dst, const float* A, const float* B, - const float* C, const float* D, float* t, - int8_t* zeroPoint, size_t number, - ssize_t minValue, ssize_t maxValue) { - const float f = *t; - const float t2 = f * f, t3 = t2 * f; - const float b0 = 1.0f - 2.25f * t2 + 1.25f * t3; - const float t1 = 1.0f - f, t1_2 = t1 * t1; - const float c0 = 1.0f - 2.25f * t1_2 + 1.25f * t1_2 * t1; - const float ta = 1.0f + f, ta2 = ta * ta; - const float a0 = 3.0f - 6.0f * ta + 3.75f * ta2 - 0.75f * ta2 * ta; - const float td = 2.0f - f, td2 = td * td; - const float d0 = 3.0f - 6.0f * td + 3.75f * td2 - 0.75f * td2 * td; - const size_t total = number << 2; - size_t i = 
0; - - while (i < total) { - size_t vl = __riscv_vsetvl_e32m8(total - i); - vfloat32m8_t v, acc; - - v = __riscv_vle32_v_f32m8(A + i, vl); - acc = __riscv_vfmul_vf_f32m8(v, a0, vl); - - v = __riscv_vle32_v_f32m8(B + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, b0, v, vl); - - v = __riscv_vle32_v_f32m8(C + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, c0, v, vl); - - v = __riscv_vle32_v_f32m8(D + i, vl); - acc = __riscv_vfmacc_vf_f32m8(acc, d0, v, vl); - - __riscv_vse32_v_f32m8(dst + i, acc, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp b/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp deleted file mode 100644 index fd5b24a53d..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp +++ /dev/null @@ -1,79 +0,0 @@ -#include - -void MNNCubicSampleC16(const int8_t* src, float* dst, - int32_t* position, const float* factor, - int8_t* zeroPoint, size_t number) { - const int pack = 16; - int8_t zp = *zeroPoint; - size_t i = 0; - - while (i < number) { - size_t vl = __riscv_vsetvl_e32m8(number - i); - vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl); - - for (int c = 0; c < pack; c++) { - vuint32m8_t voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c, vl); - - vint8m2_t vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); - vint16m4_t vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); - vfloat32m8_t vtmp = __riscv_vfcvt_f_x_v_f32m8( - __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); - - vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl); - vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl); - vfloat32m8_t vc = vtmp; - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c, vl); - - vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); - vtmp_i16 = 
__riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); - vfloat32m8_t vB = __riscv_vfcvt_f_x_v_f32m8( - __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); - - va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c, vl); - - vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); - vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); - vtmp = __riscv_vfcvt_f_x_v_f32m8( - __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); - - va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl); - vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl); - vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c, vl); - - vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); - vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); - vtmp = __riscv_vfcvt_f_x_v_f32m8( - __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); - - va = __riscv_vfmacc_vf_f32m8(va, 0.75f, vtmp, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, -0.75f, vtmp, vl); - - va = __riscv_vfmadd_vv_f32m8(va, vt, vb, vl); - va = __riscv_vfmadd_vv_f32m8(va, vt, vc, vl); - va = __riscv_vfmadd_vv_f32m8(va, vt, vB, vl); - - __riscv_vsse32_v_f32m8(dst + i * pack + c, pack * sizeof(float), va, vl); - } - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp b/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp deleted file mode 100644 index 78207e69e8..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp +++ /dev/null @@ -1,62 +0,0 @@ -#include - -void MNNCubicSampleC4(const float* src, float* dst, - int32_t* position, const float* factor, - int8_t* zeroPoint, size_t number) { - const int pack = 4; - size_t i = 0; - - while (i < 
number) { - size_t vl = __riscv_vsetvl_e32m8(number - i); - vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl); - - for (int c = 0; c < pack; c++) { - vuint32m8_t voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vfloat32m8_t vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); - - vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl); - vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl); - vfloat32m8_t vc = vtmp; - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vfloat32m8_t vB = __riscv_vluxei32_v_f32m8(src, voff, vl); - - va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); - - va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl); - vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl); - vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl); - - voff = __riscv_vsll_vx_u32m8( - __riscv_vreinterpret_v_i32m8_u32m8( - __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl); - voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); - vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl); - - va = __riscv_vfmacc_vf_f32m8(va, 0.75f, vtmp, vl); - vb = __riscv_vfmacc_vf_f32m8(vb, -0.75f, vtmp, vl); - - va = __riscv_vfmadd_vv_f32m8(va, vt, vb, vl); - va = __riscv_vfmadd_vv_f32m8(va, vt, vc, vl); - va = __riscv_vfmadd_vv_f32m8(va, vt, vB, vl); - - __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, va, vl); - } - - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp 
b/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp deleted file mode 100644 index 6658715e7e..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNDeconvRunForUnitDepthWise.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include - -void MNNDeconvRunForUnitDepthWise( - const float* dst, float* src, const float* weight, - size_t fw, size_t fh, - size_t weightY_step, size_t dilateX_step, size_t dilateY_step) { - const ptrdiff_t wStride = 4 * sizeof(float); - const ptrdiff_t sStride = dilateX_step * sizeof(float); - float d0 = dst[0], d1 = dst[1], d2 = dst[2], d3 = dst[3]; - - for (size_t fy = 0; fy < fh; ++fy) { - float* srcY = src + fy * dilateY_step; - const float* weightY = weight + fy * weightY_step; - - size_t fx = 0; - while (fx < fw) { - size_t vl = __riscv_vsetvl_e32m8(fw - fx); - - vfloat32m8_t w = __riscv_vlse32_v_f32m8(weightY + 0 + fx * 4, wStride, vl); - vfloat32m8_t s = __riscv_vlse32_v_f32m8(srcY + 0 + fx * dilateX_step, sStride, vl); - s = __riscv_vfmacc_vf_f32m8(s, d0, w, vl); - __riscv_vsse32_v_f32m8(srcY + 0 + fx * dilateX_step, sStride, s, vl); - - w = __riscv_vlse32_v_f32m8(weightY + 1 + fx * 4, wStride, vl); - s = __riscv_vlse32_v_f32m8(srcY + 1 + fx * dilateX_step, sStride, vl); - s = __riscv_vfmacc_vf_f32m8(s, d1, w, vl); - __riscv_vsse32_v_f32m8(srcY + 1 + fx * dilateX_step, sStride, s, vl); - - w = __riscv_vlse32_v_f32m8(weightY + 2 + fx * 4, wStride, vl); - s = __riscv_vlse32_v_f32m8(srcY + 2 + fx * dilateX_step, sStride, vl); - s = __riscv_vfmacc_vf_f32m8(s, d2, w, vl); - __riscv_vsse32_v_f32m8(srcY + 2 + fx * dilateX_step, sStride, s, vl); - - w = __riscv_vlse32_v_f32m8(weightY + 3 + fx * 4, wStride, vl); - s = __riscv_vlse32_v_f32m8(srcY + 3 + fx * dilateX_step, sStride, vl); - s = __riscv_vfmacc_vf_f32m8(s, d3, w, vl); - __riscv_vsse32_v_f32m8(srcY + 3 + fx * dilateX_step, sStride, s, vl); - - fx += vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp deleted file mode 
100644 index 952fcaf090..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include - -void MNNGRAYToC3(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); - __riscv_vsse8_v_u8m8(dest + i * 3 + 0, 3, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 3 + 1, 3, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 3 + 2, 3, gray, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp deleted file mode 100644 index 5ee4540f98..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include - -void MNNGRAYToC4(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); - vuint8m8_t alpha = __riscv_vmv_v_x_u8m8(255, vl); - __riscv_vsse8_v_u8m8(dest + i * 4 + 0, 4, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 4 + 1, 4, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 4 + 2, 4, gray, vl); - __riscv_vsse8_v_u8m8(dest + i * 4 + 3, 4, alpha, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp b/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp deleted file mode 100644 index 183a38bb10..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNMaxFloat.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include -#include - -#define UNIT 4 - -void MNNMaxFloat(float *input, float *maxBuffer, int32_t inputCountUnit) { - const float init = -FLT_MAX; - for (int j = 0; j < UNIT; ++j) { - float local = init; - size_t i = 0; - - while (i < (size_t)inputCountUnit) { - size_t vl = __riscv_vsetvl_e32m8(inputCountUnit - i); - float *p0 = input + (i * UNIT * 2) + j * 2; - float *p1 = p0 + 1; - vfloat32m8_t v0 = __riscv_vlse32_v_f32m8(p0, UNIT * 2 * 
sizeof(float), vl); - vfloat32m8_t v1 = __riscv_vlse32_v_f32m8(p1, UNIT * 2 * sizeof(float), vl); - vfloat32m8_t vmax = __riscv_vfmax_vv_f32m8(v0, v1, vl); - vfloat32m1_t vred = __riscv_vfredmax_vs_f32m8_f32m1(vmax, __riscv_vfmv_s_f_f32m1(local, 1), vl); - local = __riscv_vfmv_f_s_f32m1_f32(vred); - i += vl; - } - maxBuffer[j] = local; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp b/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp deleted file mode 100644 index 9e8ade8641..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNMinFloat.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include -#include - -#define UNIT 4 - -void MNNMinFloat(float *input, float *minBuffer, int32_t inputCountUnit) { - const float init = FLT_MAX; - for (int j = 0; j < UNIT; ++j) { - float local = init; - size_t i = 0; - - while (i < (size_t)inputCountUnit) { - size_t vl = __riscv_vsetvl_e32m8(inputCountUnit - i); - float *p0 = input + (i * UNIT * 2) + j * 2; - float *p1 = p0 + 1; - vfloat32m8_t v0 = __riscv_vlse32_v_f32m8(p0, UNIT * 2 * sizeof(float), vl); - vfloat32m8_t v1 = __riscv_vlse32_v_f32m8(p1, UNIT * 2 * sizeof(float), vl); - vfloat32m8_t vmin = __riscv_vfmin_vv_f32m8(v0, v1, vl); - vfloat32m1_t vred = __riscv_vfredmin_vs_f32m8_f32m1(vmin, __riscv_vfmv_s_f_f32m1(local, 1), vl); - local = __riscv_vfmv_f_s_f32m1_f32(vred); - i += vl; - } - minBuffer[j] = local; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNPackC2.cpp b/source/backend/cpu/riscv/rvv/MNNPackC2.cpp deleted file mode 100644 index 9a74f8998d..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNPackC2.cpp +++ /dev/null @@ -1,74 +0,0 @@ -#include - -void MNNPackC2(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { - int depthC2 = depth / 2; - int depthRemain = depthC2 * 2; - int remain = depth - depthRemain; - const float *srcOffset = src; - const float *srcChannel[2]; - - for (int z = 0; z < depthC2; ++z) { - float *dstZ = dst + z * areaOffset[1] * 2; - - for (int y = 0; y < 2; ++y) { - 
srcChannel[y] = srcOffset + areaOffset[0] * y; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - float *dstPtr = dstZ + x * 2; - vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 0, 2 * sizeof(float), vec, vl); - vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 1, 2 * sizeof(float), vec, vl); - } - - for (; x < area; ++x) { - float *dstPtr = dstZ + x * 2; - dstPtr[0] = srcChannel[0][x]; - dstPtr[1] = srcChannel[1][x]; - } - - srcOffset += areaOffset[0] * 2; - } - - if (remain > 0) { - float *dstZ = dst + depthC2 * areaOffset[1] * 2; - - for (int y = 0; y < remain; ++y) { - srcChannel[y] = srcOffset + areaOffset[0] * y; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - float *dstPtr = dstZ + x * 2; - - for (int y = 0; y < remain; ++y) { - vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + y, 2 * sizeof(float), vec, vl); - } - - vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); - for (int y = remain; y < 2; ++y) { - __riscv_vsse32_v_f32m8(dstPtr + y, 2 * sizeof(float), zero, vl); - } - } - - for (; x < area; ++x) { - float *dstPtr = dstZ + x * 2; - - for (int y = 0; y < remain; ++y) { - dstPtr[y] = srcChannel[y][x]; - } - - for (int y = remain; y < 2; ++y) { - dstPtr[y] = 0.0f; - } - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNPackC4.cpp b/source/backend/cpu/riscv/rvv/MNNPackC4.cpp deleted file mode 100644 index 024e2c8c07..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNPackC4.cpp +++ /dev/null @@ -1,80 +0,0 @@ -#include - -void MNNPackC4(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { - int depthC4 = depth / 4; - int depthRemain = depthC4 * 4; - int remain = depth - depthRemain; - const float *srcOffset = src; - const float *srcChannel[4]; - - for (int z = 0; z < depthC4; ++z) { - 
float *dstZ = dst + z * areaOffset[1] * 4; - - for (int y = 0; y < 4; ++y) { - srcChannel[y] = srcOffset + areaOffset[0] * y; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - float *dstPtr = dstZ + x * 4; - vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 0, 4 * sizeof(float), vec, vl); - vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 1, 4 * sizeof(float), vec, vl); - vec = __riscv_vle32_v_f32m8(srcChannel[2] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 2, 4 * sizeof(float), vec, vl); - vec = __riscv_vle32_v_f32m8(srcChannel[3] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + 3, 4 * sizeof(float), vec, vl); - } - - for (; x < area; ++x) { - float *dstPtr = dstZ + x * 4; - dstPtr[0] = srcChannel[0][x]; - dstPtr[1] = srcChannel[1][x]; - dstPtr[2] = srcChannel[2][x]; - dstPtr[3] = srcChannel[3][x]; - } - - srcOffset += areaOffset[0] * 4; - } - - if (remain > 0) { - float *dstZ = dst + depthC4 * areaOffset[1] * 4; - - for (int y = 0; y < remain; ++y) { - srcChannel[y] = srcOffset + areaOffset[0] * y; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - float *dstPtr = dstZ + x * 4; - - for (int y = 0; y < remain; ++y) { - vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); - __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), vec, vl); - } - - vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); - for (int y = remain; y < 4; ++y) { - __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), zero, vl); - } - } - - for (; x < area; ++x) { - float *dstPtr = dstZ + x * 4; - - for (int y = 0; y < remain; ++y) { - dstPtr[y] = srcChannel[y][x]; - } - - for (int y = remain; y < 4; ++y) { - dstPtr[y] = 0.0f; - } - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp deleted file mode 100644 index 
f2b6c7a78d..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include - -void MNNRGBAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp deleted file mode 100644 index ddd67a7d8c..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNRGBAToBGRA(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 4 * i + 3, 4, vl); - __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp deleted file mode 100644 index d56b58546d..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNRGBAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { 
- size_t vl = __riscv_vsetvl_e8m4(count - i); - vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); - vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); - - channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); - - channel = __riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); - - channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); - __riscv_vse8_v_u8m4(dest + i, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp deleted file mode 100644 index 7c6decf39e..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include - -void MNNRGBToBGR(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m8(count - i); - vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); - - channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); - __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp deleted file mode 100644 index 1b946c33cc..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void MNNRGBToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { - size_t i = 0; - while (i < count) { - size_t vl = __riscv_vsetvl_e8m4(count - i); - vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); - vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); - - channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); - sum = 
__riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); - - channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); - sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); - - vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); - __riscv_vse8_v_u8m4(dest + i, result, vl); - i += vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp b/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp deleted file mode 100644 index 262f4cbfab..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp +++ /dev/null @@ -1,45 +0,0 @@ -#include - -void MNNReluWithSlopeChannel(float *dst, const float *src, - const float *slope, size_t sizeQuad, - size_t depthQuad) { - const ptrdiff_t stride = 4 * sizeof(float); - - for (size_t j = 0; j < depthQuad; ++j) { - const float *srcZ = src + 4 * j * sizeQuad; - float *dstZ = dst + 4 * j * sizeQuad; - float s0 = slope[4*j], s1 = slope[4*j + 1]; - float s2 = slope[4*j + 2], s3 = slope[4*j + 3]; - size_t i = 0; - while (i < sizeQuad) { - size_t vl = __riscv_vsetvl_e32m8(sizeQuad - i); - const float *srcBase = srcZ + 4*i; - float *dstBase = dstZ + 4*i; - - vfloat32m8_t v; - vbool4_t mask; - - v = __riscv_vlse32_v_f32m8(srcBase, stride, vl); - mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); - v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s0, vl); - __riscv_vsse32_v_f32m8(dstBase, stride, v, vl); - - v = __riscv_vlse32_v_f32m8(srcBase + 1, stride, vl); - mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); - v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s1, vl); - __riscv_vsse32_v_f32m8(dstBase + 1, stride, v, vl); - - v = __riscv_vlse32_v_f32m8(srcBase + 2, stride, vl); - mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); - v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s2, vl); - __riscv_vsse32_v_f32m8(dstBase + 2, stride, v, vl); - - v = __riscv_vlse32_v_f32m8(srcBase + 3, stride, vl); - mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); - v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s3, vl); - __riscv_vsse32_v_f32m8(dstBase + 3, 
stride, v, vl); - - i += vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp b/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp deleted file mode 100644 index 10992f9d59..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNScaleAndAddBias.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include - -void MNNScaleAndAddBias(float *dst, const float *src, const float *bias, const float *alpha, size_t planeNumber, size_t biasNumber) { - const ptrdiff_t stride = 4 * sizeof(float); - - for (size_t z = 0; z < biasNumber; ++z) { - float *dstZ = dst + z * planeNumber * 4; - const float *srcZ = src + z * planeNumber * 4; - const float *biasZ = bias + 4 * z; - const float *alphaZ = alpha + 4 * z; - float b0 = biasZ[0], b1 = biasZ[1], b2 = biasZ[2], b3 = biasZ[3]; - float a0 = alphaZ[0], a1 = alphaZ[1], a2 = alphaZ[2], a3 = alphaZ[3]; - - size_t n = planeNumber; - while (n > 0) { - size_t vl = __riscv_vsetvl_e32m8(n); - vfloat32m8_t data = __riscv_vlse32_v_f32m8(srcZ + 0, stride, vl); - data = __riscv_vfmul_vf_f32m8(data, a0, vl); - data = __riscv_vfadd_vf_f32m8(data, b0, vl); - __riscv_vsse32_v_f32m8(dstZ + 0, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(srcZ + 1, stride, vl); - data = __riscv_vfmul_vf_f32m8(data, a1, vl); - data = __riscv_vfadd_vf_f32m8(data, b1, vl); - __riscv_vsse32_v_f32m8(dstZ + 1, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(srcZ + 2, stride, vl); - data = __riscv_vfmul_vf_f32m8(data, a2, vl); - data = __riscv_vfadd_vf_f32m8(data, b2, vl); - __riscv_vsse32_v_f32m8(dstZ + 2, stride, data, vl); - - data = __riscv_vlse32_v_f32m8(srcZ + 3, stride, vl); - data = __riscv_vfmul_vf_f32m8(data, a3, vl); - data = __riscv_vfadd_vf_f32m8(data, b3, vl); - __riscv_vsse32_v_f32m8(dstZ + 3, stride, data, vl); - - srcZ += vl * 4; - dstZ += vl * 4; - n -= vl; - } - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp b/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp deleted file mode 100644 index f510058c83..0000000000 --- 
a/source/backend/cpu/riscv/rvv/MNNSoftmax.cpp +++ /dev/null @@ -1,80 +0,0 @@ -#include -#include - -void MNNSoftmax(float *dest, const float *source, size_t size) { - size_t n = size; - const float *sourcePtr = source; - float *destPtr = dest; - float maxValue = -FLT_MAX; - vfloat32m1_t maxVecValue = __riscv_vfmv_s_f_f32m1(maxValue, 1); - - while (n > 0) { - size_t vl = __riscv_vsetvl_e32m8(n); - vfloat32m8_t vSrc = __riscv_vle32_v_f32m8(sourcePtr, vl); - maxVecValue = __riscv_vfredmax_vs_f32m8_f32m1(vSrc, maxVecValue, vl); - sourcePtr += vl; - n -= vl; - } - - maxValue = __riscv_vfmv_f_s_f32m1_f32(maxVecValue); - const float param = 0.6931471805599453f; - const float xLimit = 87.0f; - float sumValue = 0.f; - vfloat32m1_t sumVecValue = __riscv_vfmv_s_f_f32m1(sumValue, 1); - n = size; - sourcePtr = source; - destPtr = dest; - - while (n > 0) { - size_t vl = __riscv_vsetvl_e32m8(n); - vfloat32m8_t vA = __riscv_vle32_v_f32m8(sourcePtr, vl); - vA = __riscv_vfsub_vf_f32m8(vA, maxValue, vl); - vA = __riscv_vfmax_vf_f32m8(vA, -xLimit, vl); - vA = __riscv_vfmin_vf_f32m8(vA, xLimit, vl); - - vfloat32m8_t vB = __riscv_vfdiv_vf_f32m8(vA, param, vl); - vint32m8_t vBI = __riscv_vfcvt_x_f_v_i32m8(vB, vl); - - vfloat32m8_t vC = __riscv_vreinterpret_v_i32m8_f32m8( - __riscv_vsll_vx_i32m8( - __riscv_vadd_vx_i32m8(vBI, 127, vl), 23, vl)); - - vB = __riscv_vfcvt_f_x_v_f32m8(vBI, vl); - vB = __riscv_vfnmsub_vf_f32m8(vB, param, vA, vl); - - vA = __riscv_vfmv_v_f_f32m8(1.0f / 120.0f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 1.0f / 24.0f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 1.0f / 6.0f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 0.5f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl); - vA = __riscv_vfmul_vv_f32m8(vA, vB, vl); - vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl); - - vA = __riscv_vfmul_vv_f32m8(vC, vA, vl); - 
__riscv_vse32_v_f32m8(destPtr, vA, vl); - sumVecValue = __riscv_vfredosum_vs_f32m8_f32m1(vA, sumVecValue, vl); - - sourcePtr += vl; - destPtr += vl; - n -= vl; - } - - sumValue = __riscv_vfmv_f_s_f32m1_f32(sumVecValue); - float sumInv = 1.0f / sumValue; - n = size; - destPtr = dest; - - while (n > 0) - { - size_t vl = __riscv_vsetvl_e32m8(n); - vfloat32m8_t vDest = __riscv_vle32_v_f32m8(destPtr, vl); - vDest = __riscv_vfmul_vf_f32m8(vDest, sumInv, vl); - __riscv_vse32_v_f32m8(destPtr, vDest, vl); - destPtr += vl; - n -= vl; - } -} diff --git a/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp b/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp deleted file mode 100644 index 8ab5bb89fa..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNStrassenMergeCFunction.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include - -void MNNStrassenMergeCFunction(float *c11, float *c12, float *c21, float *c22, - float *xAddr, size_t cStride, size_t eSub, size_t hSub) { - for (int y = 0; y < hSub; ++y) { - float *c11Y = c11 + y * cStride; - float *c12Y = c12 + y * cStride; - float *c22Y = c22 + y * cStride; - float *c21Y = c21 + y * cStride; - float *xY = xAddr + y * eSub * 4; - size_t totalElements = eSub * 4; - size_t p = 0; - - while (p < totalElements) { - size_t vl = __riscv_vsetvl_e32m8(totalElements - p); - vfloat32m8_t t = __riscv_vle32_v_f32m8(xY + p, vl); - vfloat32m8_t tmp = __riscv_vle32_v_f32m8(c12Y + p, vl); - t = __riscv_vfadd_vv_f32m8(t, tmp, vl); - vfloat32m8_t c22v = __riscv_vle32_v_f32m8(c22Y + p, vl); - - tmp = __riscv_vle32_v_f32m8(c11Y + p, vl); - tmp = __riscv_vfadd_vv_f32m8(tmp, c22v, vl); - tmp = __riscv_vfadd_vv_f32m8(tmp, t, vl); - __riscv_vse32_v_f32m8(c12Y + p, tmp, vl); - - tmp = __riscv_vle32_v_f32m8(c21Y + p, vl); - tmp = __riscv_vfadd_vv_f32m8(t, tmp, vl); - __riscv_vse32_v_f32m8(c21Y + p, tmp, vl); - - c22v = __riscv_vfadd_vv_f32m8(c22v, tmp, vl); - __riscv_vse32_v_f32m8(c22Y + p, c22v, vl); - - p += vl; - } - } -} diff --git 
a/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp b/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp deleted file mode 100644 index 7598d6f8ac..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNTranspose16Bit.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#include - -void MNNTranspose16Bit(int16_t* dstO, const int16_t* srcO, int16_t* dim) { - int w = dim[0]; - int h = dim[1]; - int srcStride = dim[2]; - int dstStride = dim[3]; - ptrdiff_t srcStrideByte = srcStride * sizeof(int16_t); - - for (int i = 0; i < h; ++i) { - const int16_t* srcPtr = srcO + i; - int16_t* dstPtr = dstO + i * dstStride; - - int j = 0; - while (j < w) { - size_t vl = __riscv_vsetvl_e16m8(w - j); - vint16m8_t data = __riscv_vlse16_v_i16m8(srcPtr, srcStrideByte, vl); - __riscv_vse16_v_i16m8(dstPtr, data, vl); - srcPtr += vl * srcStride; - dstPtr += vl; - j += vl; - } - } -} - - diff --git a/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp b/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp deleted file mode 100644 index e5c5eb83e6..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNTranspose32Bit.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include - -void MNNTranspose32Bit(int32_t* dstO, const int32_t* srcO, int32_t* dim) { - int w = dim[0]; - int h = dim[1]; - int srcStride = dim[2]; - int dstStride = dim[3]; - ptrdiff_t srcStrideByte = srcStride * sizeof(int32_t); - - for (int i = 0; i < h; ++i) { - const int32_t* srcPtr = srcO + i; - int32_t* dstPtr = dstO + i * dstStride; - - int j = 0; - while (j < w) { - size_t vl = __riscv_vsetvl_e32m8(w - j); - vint32m8_t data = __riscv_vlse32_v_i32m8(srcPtr, srcStrideByte, vl); - __riscv_vse32_v_i32m8(dstPtr, data, vl); - srcPtr += vl * srcStride; - dstPtr += vl; - j += vl; - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp b/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp deleted file mode 100644 index 4676e6dede..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include - -void MNNUnpackC4(float *dst, 
const float *src, size_t area, size_t depth, int *areaOffset) { - int depthC4 = depth / 4; - int depthRemain = depthC4 * 4; - int remain = depth - depthRemain; - const float *srcOffset = src; - - for (int z = 0; z < depthC4; ++z) { - float *dstZ[4]; - - for (int y = 0; y < 4; ++y) { - dstZ[y] = dst + (z * 4 + y) * areaOffset[1]; - } - - size_t x = 0; - size_t vl = __riscv_vsetvl_e32m8(area); - - for (; x + vl <= area; x += vl) { - vfloat32m8_t vec = __riscv_vlse32_v_f32m8(srcOffset + 0, 4 * sizeof(float), vl); - __riscv_vse32_v_f32m8(dstZ[0] + x, vec, vl); - vec = __riscv_vlse32_v_f32m8(srcOffset + 1, 4 * sizeof(float), vl); - __riscv_vse32_v_f32m8(dstZ[1] + x, vec, vl); - vec = __riscv_vlse32_v_f32m8(srcOffset + 2, 4 * sizeof(float), vl); - __riscv_vse32_v_f32m8(dstZ[2] + x, vec, vl); - vec = __riscv_vlse32_v_f32m8(srcOffset + 3, 4 * sizeof(float), vl); - __riscv_vse32_v_f32m8(dstZ[3] + x, vec, vl); - srcOffset += 4 * vl; - } - - for (; x < area; ++x) { - dstZ[0][x] = srcOffset[0]; - dstZ[1][x] = srcOffset[1]; - dstZ[2][x] = srcOffset[2]; - dstZ[3][x] = srcOffset[3]; - srcOffset += (areaOffset[0] - area) * 4; - } - } - - if (remain > 0) { - float *dstZ = dst + depthC4 * areaOffset[1] * 4; - const float *srcBase = srcOffset; - - for (int y = 0; y < remain; ++y) { - float *dstChannel = dstZ + y * areaOffset[1]; - const float *srcChannel = srcBase + y; - - for (size_t x = 0; x < area; ++x) { - dstChannel[x] = srcChannel[0]; - srcChannel += 4; - } - } - } -} - diff --git a/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp b/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp deleted file mode 100644 index 7332360ce8..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNVectorTop1Float.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include -#include - -#define UNIT 4 - -void MNNVectorTop1Float(float* input, float* maxValue, int32_t* maxIndex, size_t inputCountUnit) { - size_t n = inputCountUnit * UNIT; - float maxV = -FLT_MAX; - int32_t maxIdx = 0; - size_t vl; - - for (size_t i = 
0; i < n; ) { - vl = __riscv_vsetvl_e32m8(n - i); - vfloat32m8_t data = __riscv_vle32_v_f32m8(input + i, vl); - vfloat32m1_t scalar = __riscv_vfmv_s_f_f32m1(maxV, vl); - vfloat32m1_t result = __riscv_vfredmax_vs_f32m8_f32m1(data, scalar, vl); - maxV = __riscv_vfmv_f_s_f32m1_f32(result); - i += vl; - } - - for (size_t i = 0; i < n; ) { - vl = __riscv_vsetvl_e32m8(n - i); - vfloat32m8_t data = __riscv_vle32_v_f32m8(input + i, vl); - vbool4_t mask = __riscv_vmfeq_vf_f32m8_b4(data, maxV, vl); - long first = __riscv_vfirst_m_b4(mask, vl); - - if (first >= 0) { - maxIdx = i + first; - break; - } - - i += vl; - } - - maxValue[0] = maxV; - maxIndex[0] = maxIdx; -} diff --git a/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp b/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp deleted file mode 100644 index 8c199709ec..0000000000 --- a/source/backend/cpu/riscv/rvv/MNNVectorTop1Int32.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include -#include - -#define UNIT 4 - -void MNNVectorTop1Int32(int32_t* input, int32_t* maxValue, int32_t* maxIndex, size_t inputCountUnit) { - size_t n = inputCountUnit * UNIT; - int32_t maxV = INT32_MIN; - int32_t maxIdx = 0; - size_t vl; - - for (size_t i = 0; i < n; ) { - vl = __riscv_vsetvl_e32m8(n - i); - vint32m8_t data = __riscv_vle32_v_i32m8(input + i, vl); - vint32m1_t scalar = __riscv_vmv_s_x_i32m1(maxV, vl); - vint32m1_t result = __riscv_vredmax_vs_i32m8_i32m1(data, scalar, vl); - maxV = __riscv_vmv_x_s_i32m1_i32(result); - i += vl; - } - - for (size_t i = 0; i < n; ) { - vl = __riscv_vsetvl_e32m8(n - i); - vint32m8_t data = __riscv_vle32_v_i32m8(input + i, vl); - vbool4_t mask = __riscv_vmseq_vx_i32m8_b4(data, maxV, vl); - long first = __riscv_vfirst_m_b4(mask, vl); - - if (first >= 0) { - maxIdx = i + first; - break; - } - - i += vl; - } - - maxValue[0] = maxV; - maxIndex[0] = maxIdx; -} diff --git a/source/core/Backend.hpp b/source/core/Backend.hpp index 6850b6b4f6..bcf618c3c9 100644 --- a/source/core/Backend.hpp +++ 
b/source/core/Backend.hpp @@ -68,11 +68,9 @@ struct RuntimeHint { // whether to use Arm sme2 cores when threads>1 bool useArmSme2Cores = true; -#ifdef MNN_DEFAULT_USE_KLEIDIAI - bool enableKleidiAI = true; -#else + bool enableKleidiAI = false; -#endif + // Use CPU Ids std::vector cpuIds; diff --git a/transformers/diffusion/export/onnx_export.py b/transformers/diffusion/export/onnx_export.py index 5516eb2fcc..21f05e83be 100644 --- a/transformers/diffusion/export/onnx_export.py +++ b/transformers/diffusion/export/onnx_export.py @@ -84,7 +84,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F num_tokens = pipeline.text_encoder.config.max_position_embeddings text_hidden_size = pipeline.text_encoder.config.hidden_size text_input = pipeline.tokenizer( - ["A sample prompt", "A sample prompt"], + "A sample prompt", padding="max_length", max_length=pipeline.tokenizer.model_max_length, truncation=True, @@ -97,7 +97,9 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path=output_path / "text_encoder" / "model.onnx", ordered_input_names=["input_ids"], output_names=["last_hidden_state", "pooler_output"], - dynamic_axes=None, + dynamic_axes={ + "input_ids": {0: "batch", 1: "sequence"}, + }, opset=opset, ) del pipeline.text_encoder @@ -115,9 +117,13 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F # False, ), output_path=unet_path, - ordered_input_names=["sample", "timestep", "encoder_hidden_states"], + ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], output_names=["out_sample"], # has to be different from "sample" for correct tracing - dynamic_axes=None, + dynamic_axes={ + "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + "timestep": {0: "batch"}, + "encoder_hidden_states": {0: "batch", 1: "sequence"}, + }, opset=opset, use_external_data_format=True, # UNet is > 2GB, so the weights need to be split ) @@ -143,7 +149,7 @@ 
def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F vae_in_channels = vae_encoder.config.in_channels vae_sample_size = vae_encoder.config.sample_size # need to get the raw tensor output (sample) from the encoder - vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].mode() + vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample() onnx_export( vae_encoder, model_args=( @@ -153,24 +159,30 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path=output_path / "vae_encoder" / "model.onnx", ordered_input_names=["sample", "return_dict"], output_names=["latent_sample"], - dynamic_axes=None, + dynamic_axes={ + "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, opset=opset, ) # VAE DECODER vae_decoder = pipeline.vae vae_latent_channels = vae_decoder.config.latent_channels + vae_out_channels = vae_decoder.config.out_channels # forward only through the decoder part - vae_decoder.forward = lambda latent: vae_decoder.decode(latent, return_dict=False)[0] + vae_decoder.forward = vae_encoder.decode onnx_export( vae_decoder, model_args=( torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), + False, ), output_path=output_path / "vae_decoder" / "model.onnx", - ordered_input_names=["latent_sample"], + ordered_input_names=["latent_sample", "return_dict"], output_names=["sample"], - dynamic_axes=None, + dynamic_axes={ + "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, opset=opset, ) del pipeline.vae From 8f5d15bdd3e455c39bcc6a7a882ad837c0ef41b7 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:36:21 +0800 Subject: [PATCH 076/314] Merge pull request #4067 from ihb2032/opt/rvv-pixel-conv opt(RVV): Optimize blitter functions with intrinsics GitOrigin-RevId: 784bb542822e52ae67f017cb2adeaad7ce43c267 --- 
source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp | 18 +++++++++++++++++ .../backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp | 13 ++++++++++++ source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp | 16 +++++++++++++++ source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp | 17 ++++++++++++++++ .../backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp | 20 +++++++++++++++++++ .../backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp | 20 +++++++++++++++++++ source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp | 17 ++++++++++++++++ source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp | 20 +++++++++++++++++++ 11 files changed, 201 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp new file mode 100644 index 0000000000..145cbea73f --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBGRAToBGR.cpp @@ -0,0 +1,18 @@ +#include + +void MNNBGRAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); + 
__riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp new file mode 100644 index 0000000000..d46fe6c85b --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBGRAToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNBGRAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp new file mode 100644 index 0000000000..684db6aed3 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBGRToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNBRGToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); + sum = 
__riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, result, vl); + i += vl; + } +} \ No newline at end of file diff --git a/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp b/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp new file mode 100644 index 0000000000..9d524f13ca --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNC3ToC4.cpp @@ -0,0 +1,20 @@ +#include + +void MNNC3ToC4(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); + + vuint8m8_t alpha = __riscv_vmv_v_x_u8m8(255, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, alpha, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp new file mode 100644 index 0000000000..952fcaf090 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNGRAYToC3.cpp @@ -0,0 +1,13 @@ +#include + +void MNNGRAYToC3(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); + __riscv_vsse8_v_u8m8(dest + i * 3 + 0, 3, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 3 + 1, 3, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 3 + 2, 3, gray, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp b/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp new file mode 100644 index 0000000000..5ee4540f98 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNGRAYToC4.cpp @@ -0,0 +1,16 @@ 
+#include + +void MNNGRAYToC4(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t gray = __riscv_vle8_v_u8m8(source + i, vl); + vuint8m8_t alpha = __riscv_vmv_v_x_u8m8(255, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 0, 4, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 1, 4, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 2, 4, gray, vl); + __riscv_vsse8_v_u8m8(dest + i * 4 + 3, 4, alpha, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp new file mode 100644 index 0000000000..f2b6c7a78d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBAToBGR.cpp @@ -0,0 +1,17 @@ +#include + +void MNNRGBAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp new file mode 100644 index 0000000000..ddd67a7d8c --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBAToBGRA.cpp @@ -0,0 +1,20 @@ +#include + +void MNNRGBAToBGRA(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 4 * i + 2, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 0, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 1, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 1, 
4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 0, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 2, 4, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 4 * i + 3, 4, vl); + __riscv_vsse8_v_u8m8(dest + 4 * i + 3, 4, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp new file mode 100644 index 0000000000..d56b58546d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBAToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNRGBAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 4 * i + 0, 4, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 1, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 4 * i + 2, 4, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + channel = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, channel, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp new file mode 100644 index 0000000000..7c6decf39e --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBToBGR.cpp @@ -0,0 +1,17 @@ +#include + +void MNNRGBToBGR(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m8(count - i); + vuint8m8_t channel = __riscv_vlse8_v_u8m8(source + 3 * i + 2, 3, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 0, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 1, 3, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 1, 3, channel, vl); + + channel = __riscv_vlse8_v_u8m8(source + 3 * i + 0, 3, vl); + __riscv_vsse8_v_u8m8(dest + 3 * i + 2, 3, channel, vl); + i += vl; + } +} diff 
--git a/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp b/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp new file mode 100644 index 0000000000..1b946c33cc --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNRGBToGRAY.cpp @@ -0,0 +1,20 @@ +#include + +void MNNRGBToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { + size_t i = 0; + while (i < count) { + size_t vl = __riscv_vsetvl_e8m4(count - i); + vuint8m4_t channel = __riscv_vlse8_v_u8m4(source + 3 * i + 0, 3, vl); + vuint16m8_t sum = __riscv_vwmulu_vx_u16m8(channel, 19, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 1, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 38, channel, vl); + + channel = __riscv_vlse8_v_u8m4(source + 3 * i + 2, 3, vl); + sum = __riscv_vwmaccu_vx_u16m8(sum, 7, channel, vl); + + vuint8m4_t result = __riscv_vnsrl_wx_u8m4(sum, 6, vl); + __riscv_vse8_v_u8m4(dest + i, result, vl); + i += vl; + } +} From 7e8a2bee8b5fd9ff5c72b44bfc73cb1ce2f8a94a Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:41:13 +0800 Subject: [PATCH 077/314] Merge pull request #4053 from ihb2032/opt/rvv-resize-functions opt(RVV): Optimize resize functions with intrinsics GitOrigin-RevId: e55248749f6c5b8c7c7d5b67d734f79943569955 --- .../cpu/riscv/rvv/CPUBilinearLineC4.cpp | 19 +++++ .../cpu/riscv/rvv/CPUBilinearSampleC4.cpp | 33 ++++++++ .../cpu/riscv/rvv/MNNBilinearLineC8.cpp | 40 ++++++++++ .../cpu/riscv/rvv/MNNBilinearSampleC8.cpp | 49 ++++++++++++ .../backend/cpu/riscv/rvv/MNNCubicLineC16.cpp | 53 +++++++++++++ .../backend/cpu/riscv/rvv/MNNCubicLineC4.cpp | 38 +++++++++ .../cpu/riscv/rvv/MNNCubicSampleC16.cpp | 79 +++++++++++++++++++ .../cpu/riscv/rvv/MNNCubicSampleC4.cpp | 62 +++++++++++++++ 8 files changed, 373 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp create mode 100644 
source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC16.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicLineC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC16.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNCubicSampleC4.cpp diff --git a/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp b/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp new file mode 100644 index 0000000000..a700016c31 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/CPUBilinearLineC4.cpp @@ -0,0 +1,19 @@ +#include + +void CPUBilinearLineC4(float* dst, const float* A, const float* B, + const float* t, int8_t* zeroPoint, size_t number) { + float tf = *t; + float sf = 1.0f - tf; + size_t total = number << 2; + + size_t i = 0; + while (i < total) { + size_t vl = __riscv_vsetvl_e32m8(total - i); + vfloat32m8_t v = __riscv_vle32_v_f32m8(A + i, vl); + vfloat32m8_t result = __riscv_vfmul_vf_f32m8(v, sf, vl); + v = __riscv_vle32_v_f32m8(B + i, vl); + result = __riscv_vfmacc_vf_f32m8(result, tf, v, vl); + __riscv_vse32_v_f32m8(dst + i, result, vl); + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp b/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp new file mode 100644 index 0000000000..5063c39bff --- /dev/null +++ b/source/backend/cpu/riscv/rvv/CPUBilinearSampleC4.cpp @@ -0,0 +1,33 @@ +#include + +void CPUBilinearSampleC4(const float* src, float* dst, + const int32_t* position, const float* factor, + int8_t* zeroPoint, size_t number) { + const int pack = 4; + size_t i = 0; + + while (i < number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + i, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vfloat32m8_t vr = 
__riscv_vluxei32_v_f32m8(src, voff, vl); + vfloat32m8_t vsf = __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl); + vr = __riscv_vfmul_vv_f32m8(vr, vsf, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl); + vsf = __riscv_vluxei32_v_f32m8(src, voff, vl); + vr = __riscv_vfmacc_vv_f32m8(vr, vf, vsf, vl); + __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, vr, vl); + } + + i += vl; + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp new file mode 100644 index 0000000000..a26243bdb8 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBilinearLineC8.cpp @@ -0,0 +1,40 @@ +#include + +void MNNBilinearLineC8(int8_t* dst, const int16_t* A, const int16_t* B, + const float* t, int8_t* zeroPoint, size_t number) { + int offset = *zeroPoint; + int8_t* dstPtr = dst; + + const int pack = 8; + const int16_t df = (int16_t)((*t) * 128.0f); + const int16_t sf = (int16_t)((1.0f - *t) * 128.0f); + const size_t total = number * pack; + const int32_t ROUND_HALF = 1 << 13; + + size_t vl; + for (size_t i = 0; i < total; i += vl) { + vl = __riscv_vsetvl_e16m4(total - i); + vint16m4_t v16 = __riscv_vle16_v_i16m4(A + i, vl); + vint32m8_t v32 = __riscv_vwmul_vx_i32m8(v16, sf, vl); + v16 = __riscv_vle16_v_i16m4(B + i, vl); + v32 = __riscv_vwmacc_vx_i32m8(v32, df, v16, vl); + + vbool4_t mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); + vint32m8_t tmp = __riscv_vadd_vx_i32m8(v32, ROUND_HALF, vl); + v32 = __riscv_vsub_vx_i32m8(v32, ROUND_HALF, vl); + v32 = __riscv_vmerge_vvm_i32m8(tmp, v32, mask, vl); + + tmp = __riscv_vsra_vx_i32m8(v32, 14, vl); + mask = __riscv_vmslt_vx_i32m8_b4(v32, 0, vl); + v32 = __riscv_vand_vx_i32m8(v32, 0x3FFF, vl); + vbool4_t hasRem = __riscv_vmsne_vx_i32m8_b4(v32, 0, vl); + mask = __riscv_vmand_mm_b4(mask, hasRem, vl); + + v32 = __riscv_vadd_vx_i32m8_mu(mask, tmp, tmp, 1, vl); + v32 = 
__riscv_vadd_vx_i32m8(v32, offset, vl); + v16 = __riscv_vnsra_wx_i16m4(v32, 0, vl); + vint8m2_t v8 = __riscv_vnsra_wx_i8m2(v16, 0, vl); + + __riscv_vse8_v_i8m2(dstPtr + i, v8, vl); + } +} diff --git a/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp b/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp new file mode 100644 index 0000000000..bd111e3be4 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNBilinearSampleC8.cpp @@ -0,0 +1,49 @@ +#include + +void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, + const int32_t* position, const float* factor, + int8_t* zeroPoint, size_t number) { + int16_t offset = (int16_t)(*zeroPoint); + const int pack = 8; + size_t i = 0; + + while (i < number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vf = __riscv_vle32_v_f32m8(factor + i, vl); + vint16m4_t vdf = __riscv_vnsra_wx_i16m4( + __riscv_vfcvt_rtz_x_f_v_i32m8( + __riscv_vfmul_vf_f32m8(vf, 128.0f, vl), vl), 0, vl); + vint16m4_t vsf = __riscv_vnsra_wx_i16m4( + __riscv_vfcvt_rtz_x_f_v_i32m8( + __riscv_vfmul_vf_f32m8( + __riscv_vfrsub_vf_f32m8(vf, 1.0f, vl), 128.0f, vl), vl), 0, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vadd_vx_u32m8( + __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i, 8, vl)), 3, vl), + c, vl); + + vint16m4_t va = __riscv_vsub_vx_i16m4( + __riscv_vsext_vf2_i16m4( + __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); + + vint32m8_t vr = __riscv_vwmul_vv_i32m8(va, vsf, vl); + voff = __riscv_vadd_vx_u32m8( + __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 2*i + 1, 8, vl)), 3, vl), + c, vl); + + vint16m4_t vb = __riscv_vsub_vx_i16m4( + __riscv_vsext_vf2_i16m4( + __riscv_vluxei32_v_i8m2(src, voff, vl), vl), offset, vl); + vr = __riscv_vwmacc_vv_i32m8(vr, vb, vdf, vl); + __riscv_vsse16_v_i16m4(dst + i * pack + c, 16, + __riscv_vnsra_wx_i16m4(vr, 0, vl), vl); + } + + i += vl; + } +} diff --git 
#include <riscv_vector.h>

// Vertical pass of bicubic resampling for C16-packed int8 output.
// Combines four pre-computed rows A, B, C, D (float, 16 channels per pixel)
// with scalar weights derived from t = *t, then rounds, re-applies the
// quantization zero point, clamps to [minValue, maxValue] and narrows to int8.
// The weight formulas are the Keys cubic convolution kernel with a = -0.75:
//   |x| <= 1:  1.25|x|^3 - 2.25|x|^2 + 1         (b0 at |x|=t, c0 at |x|=1-t)
//   1 < |x| < 2: -0.75|x|^3 + 3.75|x|^2 - 6|x| + 3  (a0 at 1+t, d0 at 2-t)
void MNNCubicLineC16(int8_t* dst, const float* A, const float* B,
                     const float* C, const float* D, float* t,
                     int8_t* zeroPoint, size_t number,
                     ssize_t minValue, ssize_t maxValue) {
    const float f = *t;
    // Weight of row B (distance t from the sample point).
    const float t2 = f * f, t3 = t2 * f;
    const float b0 = 1.0f - 2.25f * t2 + 1.25f * t3;
    // Weight of row C (distance 1 - t).
    const float t1 = 1.0f - f, t1_2 = t1 * t1;
    const float c0 = 1.0f - 2.25f * t1_2 + 1.25f * t1_2 * t1;
    // Weight of row A (distance 1 + t).
    const float ta = 1.0f + f, ta2 = ta * ta;
    const float a0 = 3.0f - 6.0f * ta + 3.75f * ta2 - 0.75f * ta2 * ta;
    // Weight of row D (distance 2 - t).
    const float td = 2.0f - f, td2 = td * td;
    const float d0 = 3.0f - 6.0f * td + 3.75f * td2 - 0.75f * td2 * td;
    const int offset = *zeroPoint;  // quantization zero point added back below
    const int minVal = (int)minValue;
    const int maxVal = (int)maxValue;
    const size_t total = number << 4;  // 16 floats per pixel, processed flat
    size_t i = 0;

    while (i < total) {
        size_t vl = __riscv_vsetvl_e32m8(total - i);
        vfloat32m8_t v, acc;

        // acc = a0*A + b0*B + c0*C + d0*D
        v = __riscv_vle32_v_f32m8(A + i, vl);
        acc = __riscv_vfmul_vf_f32m8(v, a0, vl);

        v = __riscv_vle32_v_f32m8(B + i, vl);
        acc = __riscv_vfmacc_vf_f32m8(acc, b0, v, vl);

        v = __riscv_vle32_v_f32m8(C + i, vl);
        acc = __riscv_vfmacc_vf_f32m8(acc, c0, v, vl);

        v = __riscv_vle32_v_f32m8(D + i, vl);
        acc = __riscv_vfmacc_vf_f32m8(acc, d0, v, vl);

        // Round half away from zero: add +/-0.5 with the sign of acc
        // (vfsgnj copies acc's sign onto 0.5), then truncate toward zero.
        vfloat32m8_t half = __riscv_vfmv_v_f_f32m8(0.5f, vl);
        vfloat32m8_t signHalf = __riscv_vfsgnj_vv_f32m8(half, acc, vl);
        acc = __riscv_vfadd_vv_f32m8(acc, signHalf, vl);

        vint32m8_t vint = __riscv_vfcvt_rtz_x_f_v_i32m8(acc, vl);
        vint = __riscv_vadd_vx_i32m8(vint, offset, vl);
        // Clamp before narrowing so the i32 -> i16 -> i8 chain cannot wrap.
        vint = __riscv_vmax_vx_i32m8(vint, minVal, vl);
        vint = __riscv_vmin_vx_i32m8(vint, maxVal, vl);

        vint16m4_t vi16 = __riscv_vncvt_x_x_w_i16m4(vint, vl);
        vint8m2_t vi8 = __riscv_vncvt_x_x_w_i8m2(vi16, vl);
        __riscv_vse8_v_i8m2(dst + i, vi8, vl);

        i += vl;
    }
}
number) { + size_t vl = __riscv_vsetvl_e32m8(number - i); + vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl); + + for (int c = 0; c < pack; c++) { + vuint32m8_t voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vint8m2_t vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vint16m4_t vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vfloat32m8_t vtmp = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl); + vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl); + vfloat32m8_t vc = vtmp; + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vfloat32m8_t vB = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vtmp_i8 = __riscv_vluxei32_v_i8m2(src, voff, vl); + vtmp_i16 = __riscv_vwsub_vx_i16m4(vtmp_i8, zp, vl); + vtmp = __riscv_vfcvt_f_x_v_f32m8( + __riscv_vwcvt_x_x_v_i32m8(vtmp_i16, vl), vl); + + va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl); + vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl); + vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl); + vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl); + + voff = __riscv_vsll_vx_u32m8( + __riscv_vreinterpret_v_i32m8_u32m8( + __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl); + voff = __riscv_vadd_vx_u32m8(voff, c, vl); + + vtmp_i8 = 
#include <riscv_vector.h>

// Horizontal bicubic sampling for C4-packed float data.
// For each of `number` output pixels, gathers four source taps
// (indices position[4k..4k+3], each scaled by 16 bytes = one C4 pixel)
// and evaluates the Keys a = -0.75 cubic in Horner form:
//   dst = ((va*t + vb)*t + vc)*t + tap1
// where va/vb/vc are linear combinations of the four taps built tap by tap.
// `zeroPoint` is unused in this float variant (kept for signature symmetry
// with the int8 versions — presumably; confirm against the dispatch table).
void MNNCubicSampleC4(const float* src, float* dst,
                      int32_t* position, const float* factor,
                      int8_t* zeroPoint, size_t number) {
    const int pack = 4;  // channels per pixel in the C4 layout
    size_t i = 0;

    while (i < number) {
        size_t vl = __riscv_vsetvl_e32m8(number - i);
        // Per-output fractional position t in [0,1].
        vfloat32m8_t vt = __riscv_vle32_v_f32m8(factor + i, vl);

        // One channel at a time; taps are gathered with indexed loads.
        for (int c = 0; c < pack; c++) {
            // Tap 0 (x-1): byte offset = position[4k] * 16 + c*4.
            vuint32m8_t voff = __riscv_vsll_vx_u32m8(
                __riscv_vreinterpret_v_i32m8_u32m8(
                    __riscv_vlse32_v_i32m8(position + 4*i + 0, 16, vl)), 4, vl);
            voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl);
            vfloat32m8_t vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl);

            // Start the cubic/quadratic/linear coefficient accumulators.
            vfloat32m8_t va = __riscv_vfmul_vf_f32m8(vtmp, -0.75f, vl);
            vfloat32m8_t vb = __riscv_vfmul_vf_f32m8(vtmp, 1.5f, vl);
            vfloat32m8_t vc = vtmp;  // will become 0.75*(tap2 - tap0)

            // Tap 1 (x): also the constant term of the polynomial.
            voff = __riscv_vsll_vx_u32m8(
                __riscv_vreinterpret_v_i32m8_u32m8(
                    __riscv_vlse32_v_i32m8(position + 4*i + 1, 16, vl)), 4, vl);
            voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl);
            vfloat32m8_t vB = __riscv_vluxei32_v_f32m8(src, voff, vl);

            va = __riscv_vfmacc_vf_f32m8(va, 1.25f, vB, vl);
            vb = __riscv_vfmacc_vf_f32m8(vb, -2.25f, vB, vl);

            // Tap 2 (x+1).
            voff = __riscv_vsll_vx_u32m8(
                __riscv_vreinterpret_v_i32m8_u32m8(
                    __riscv_vlse32_v_i32m8(position + 4*i + 2, 16, vl)), 4, vl);
            voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl);
            vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl);

            va = __riscv_vfmacc_vf_f32m8(va, -1.25f, vtmp, vl);
            vb = __riscv_vfmacc_vf_f32m8(vb, 1.5f, vtmp, vl);
            vc = __riscv_vfsub_vv_f32m8(vtmp, vc, vl);
            vc = __riscv_vfmul_vf_f32m8(vc, 0.75f, vl);

            // Tap 3 (x+2).
            voff = __riscv_vsll_vx_u32m8(
                __riscv_vreinterpret_v_i32m8_u32m8(
                    __riscv_vlse32_v_i32m8(position + 4*i + 3, 16, vl)), 4, vl);
            voff = __riscv_vadd_vx_u32m8(voff, c * 4, vl);
            vtmp = __riscv_vluxei32_v_f32m8(src, voff, vl);

            va = __riscv_vfmacc_vf_f32m8(va, 0.75f, vtmp, vl);
            vb = __riscv_vfmacc_vf_f32m8(vb, -0.75f, vtmp, vl);

            // Horner evaluation: ((va*t + vb)*t + vc)*t + tap1.
            va = __riscv_vfmadd_vv_f32m8(va, vt, vb, vl);
            va = __riscv_vfmadd_vv_f32m8(va, vt, vc, vl);
            va = __riscv_vfmadd_vv_f32m8(va, vt, vB, vl);

            // Scatter into the interleaved C4 output (stride 16 bytes).
            __riscv_vsse32_v_f32m8(dst + i * pack + c, 16, va, vl);
        }

        i += vl;
    }
}
#include <riscv_vector.h>
#include <stdint.h>

// Top-1 search over inputCountUnit * 4 int32 elements: writes the maximum
// value to maxValue[0] and the index of its first occurrence to maxIndex[0].
// Two vector passes: a max reduction, then a first-match scan.
void MNNVectorTop1Int32(int32_t* input, int32_t* maxValue, int32_t* maxIndex, size_t inputCountUnit) {
    const size_t total = inputCountUnit * 4;  // UNIT == 4 elements per unit
    int32_t best = INT32_MIN;
    int32_t bestPos = 0;

    // Pass 1: reduce the whole buffer to its maximum value.
    for (size_t pos = 0; pos < total;) {
        size_t vl = __riscv_vsetvl_e32m8(total - pos);
        vint32m8_t chunk = __riscv_vle32_v_i32m8(input + pos, vl);
        // Seed the reduction with the running best so chunks chain together.
        vint32m1_t seed = __riscv_vmv_s_x_i32m1(best, vl);
        vint32m1_t red = __riscv_vredmax_vs_i32m8_i32m1(chunk, seed, vl);
        best = __riscv_vmv_x_s_i32m1_i32(red);
        pos += vl;
    }

    // Pass 2: locate the first element equal to that maximum.
    for (size_t pos = 0; pos < total;) {
        size_t vl = __riscv_vsetvl_e32m8(total - pos);
        vint32m8_t chunk = __riscv_vle32_v_i32m8(input + pos, vl);
        vbool4_t hits = __riscv_vmseq_vx_i32m8_b4(chunk, best, vl);
        long lane = __riscv_vfirst_m_b4(hits, vl);  // -1 when no lane matches

        if (lane >= 0) {
            bestPos = pos + lane;
            break;
        }

        pos += vl;
    }

    maxValue[0] = best;
    maxIndex[0] = bestPos;
}
Date: Mon, 22 Dec 2025 10:42:36 +0800 Subject: [PATCH 079/314] Merge pull request #4044 from ihb2032/opt/rvv-softmax-relu opt(RVV): Optimize Softmax and ReluWithSlopeChannel with intrinsics GitOrigin-RevId: 07b2b4e3b678f2b440bb954b58760d47d7c54689 --- .../cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp | 45 +++++++++++ source/backend/cpu/riscv/rvv/MNNSoftmax.cpp | 80 +++++++++++++++++++ 2 files changed, 125 insertions(+) create mode 100644 source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNSoftmax.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp b/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp new file mode 100644 index 0000000000..262f4cbfab --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNReluWithSlopeChannel.cpp @@ -0,0 +1,45 @@ +#include + +void MNNReluWithSlopeChannel(float *dst, const float *src, + const float *slope, size_t sizeQuad, + size_t depthQuad) { + const ptrdiff_t stride = 4 * sizeof(float); + + for (size_t j = 0; j < depthQuad; ++j) { + const float *srcZ = src + 4 * j * sizeQuad; + float *dstZ = dst + 4 * j * sizeQuad; + float s0 = slope[4*j], s1 = slope[4*j + 1]; + float s2 = slope[4*j + 2], s3 = slope[4*j + 3]; + size_t i = 0; + while (i < sizeQuad) { + size_t vl = __riscv_vsetvl_e32m8(sizeQuad - i); + const float *srcBase = srcZ + 4*i; + float *dstBase = dstZ + 4*i; + + vfloat32m8_t v; + vbool4_t mask; + + v = __riscv_vlse32_v_f32m8(srcBase, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s0, vl); + __riscv_vsse32_v_f32m8(dstBase, stride, v, vl); + + v = __riscv_vlse32_v_f32m8(srcBase + 1, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, v, s1, vl); + __riscv_vsse32_v_f32m8(dstBase + 1, stride, v, vl); + + v = __riscv_vlse32_v_f32m8(srcBase + 2, stride, vl); + mask = __riscv_vmflt_vf_f32m8_b4(v, 0.0f, vl); + v = __riscv_vfmul_vf_f32m8_mu(mask, v, 
#include <riscv_vector.h>
#include <float.h>

// Numerically-stable softmax over a flat float buffer of `size` elements:
//   dest[i] = exp(source[i] - max(source)) / sum_j exp(source[j] - max(source))
// Three vectorized passes: (1) max reduction, (2) exp + store + ordered sum,
// (3) scale by 1/sum. exp is approximated as 2^k * P(r) where
// k = round(x / ln2), r = x - k*ln2, and P is the degree-5 Taylor series of e^r.
void MNNSoftmax(float *dest, const float *source, size_t size) {
    size_t n = size;
    const float *sourcePtr = source;
    float *destPtr = dest;
    float maxValue = -FLT_MAX;
    vfloat32m1_t maxVecValue = __riscv_vfmv_s_f_f32m1(maxValue, 1);

    // Pass 1: global maximum (for the exp shift).
    while (n > 0) {
        size_t vl = __riscv_vsetvl_e32m8(n);
        vfloat32m8_t vSrc = __riscv_vle32_v_f32m8(sourcePtr, vl);
        maxVecValue = __riscv_vfredmax_vs_f32m8_f32m1(vSrc, maxVecValue, vl);
        sourcePtr += vl;
        n -= vl;
    }

    maxValue = __riscv_vfmv_f_s_f32m1_f32(maxVecValue);
    const float param = 0.6931471805599453f;  // ln(2)
    const float xLimit = 87.0f;               // |x| bound to keep 2^k finite in fp32
    float sumValue = 0.f;
    vfloat32m1_t sumVecValue = __riscv_vfmv_s_f_f32m1(sumValue, 1);
    n = size;
    sourcePtr = source;
    destPtr = dest;

    // Pass 2: dest[i] = exp(source[i] - max), accumulating the sum.
    while (n > 0) {
        size_t vl = __riscv_vsetvl_e32m8(n);
        vfloat32m8_t vA = __riscv_vle32_v_f32m8(sourcePtr, vl);
        vA = __riscv_vfsub_vf_f32m8(vA, maxValue, vl);
        // Clamp the exponent argument so the bit-built 2^k cannot overflow.
        vA = __riscv_vfmax_vf_f32m8(vA, -xLimit, vl);
        vA = __riscv_vfmin_vf_f32m8(vA, xLimit, vl);

        // k = round(x / ln2); vfcvt_x uses the current rounding mode
        // (round-to-nearest by default).
        vfloat32m8_t vB = __riscv_vfdiv_vf_f32m8(vA, param, vl);
        vint32m8_t vBI = __riscv_vfcvt_x_f_v_i32m8(vB, vl);

        // 2^k built directly in the fp32 exponent field: (k + 127) << 23.
        vfloat32m8_t vC = __riscv_vreinterpret_v_i32m8_f32m8(
            __riscv_vsll_vx_i32m8(
                __riscv_vadd_vx_i32m8(vBI, 127, vl), 23, vl));

        // r = x - k*ln2  (vfnmsub: vB = -(vB * param) + vA), |r| <= ln2/2.
        vB = __riscv_vfcvt_f_x_v_f32m8(vBI, vl);
        vB = __riscv_vfnmsub_vf_f32m8(vB, param, vA, vl);

        // P(r) = 1 + r(1 + r(1/2 + r(1/6 + r(1/24 + r/120)))) — Horner.
        vA = __riscv_vfmv_v_f_f32m8(1.0f / 120.0f, vl);
        vA = __riscv_vfmul_vv_f32m8(vA, vB, vl);
        vA = __riscv_vfadd_vf_f32m8(vA, 1.0f / 24.0f, vl);
        vA = __riscv_vfmul_vv_f32m8(vA, vB, vl);
        vA = __riscv_vfadd_vf_f32m8(vA, 1.0f / 6.0f, vl);
        vA = __riscv_vfmul_vv_f32m8(vA, vB, vl);
        vA = __riscv_vfadd_vf_f32m8(vA, 0.5f, vl);
        vA = __riscv_vfmul_vv_f32m8(vA, vB, vl);
        vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl);
        vA = __riscv_vfmul_vv_f32m8(vA, vB, vl);
        vA = __riscv_vfadd_vf_f32m8(vA, 1.0f, vl);

        // exp(x) ~= 2^k * P(r); store and fold into the ordered (deterministic) sum.
        vA = __riscv_vfmul_vv_f32m8(vC, vA, vl);
        __riscv_vse32_v_f32m8(destPtr, vA, vl);
        sumVecValue = __riscv_vfredosum_vs_f32m8_f32m1(vA, sumVecValue, vl);

        sourcePtr += vl;
        destPtr += vl;
        n -= vl;
    }

    // Pass 3: normalize by the reciprocal of the sum.
    sumValue = __riscv_vfmv_f_s_f32m1_f32(sumVecValue);
    float sumInv = 1.0f / sumValue;
    n = size;
    destPtr = dest;

    while (n > 0) {
        size_t vl = __riscv_vsetvl_e32m8(n);
        vfloat32m8_t vDest = __riscv_vle32_v_f32m8(destPtr, vl);
        vDest = __riscv_vfmul_vf_f32m8(vDest, sumInv, vl);
        __riscv_vse32_v_f32m8(destPtr, vDest, vl);
        destPtr += vl;
        n -= vl;
    }
}
#include <riscv_vector.h>

// Depthwise convolution over one line of C4-packed float data, vectorized
// across output positions (x) with strided loads; channels handled one at a
// time. Output is clamped to [parameters[0], parameters[1]] (fused activation).
// Strides (src_w_setup, dilate*_step, srcHStep, dstHStep) are in floats —
// presumably pre-multiplied by the pack size by the caller; confirm there.
void MNNConvRunForLineDepthwise(
    float* dst, const float* src, const float* weight,
    size_t width, size_t src_w_setup,
    size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step,
    size_t height, size_t srcHStep, size_t dstHStep,
    const float* bias, const float* parameters) {
    float minV = parameters[0];  // post-activation lower clamp bound
    float maxV = parameters[1];  // post-activation upper clamp bound
    // Element stride between neighboring output positions in the source.
    ptrdiff_t srcByteStride = src_w_setup * sizeof(float);
    // Output is C4-interleaved: 4 floats per position.
    ptrdiff_t dstByteStride = 4 * sizeof(float);

    for (size_t y = 0; y < height; ++y) {
        const float* srcY = src + y * srcHStep;
        float* dstY = dst + y * dstHStep;
        size_t dx = 0;

        while (dx < width) {
            size_t vl = __riscv_vsetvl_e32m8(width - dx);

            // One of the 4 packed channels per pass, vl output positions wide.
            for (int c = 0; c < 4; ++c) {
                vfloat32m8_t acc = __riscv_vfmv_v_f_f32m8(bias[c], vl);
                const float* srcBase = srcY + dx * src_w_setup + c;
                // Weights are C4-interleaved too, hence the += 4 below.
                const float* weightPtr = weight + c;

                // Accumulate over the fh x fw kernel window.
                for (size_t fy = 0; fy < fh; ++fy) {
                    const float* srcFy = srcBase + fy * dilateY_step;

                    for (size_t fx = 0; fx < fw; ++fx) {
                        float w = *weightPtr;
                        weightPtr += 4;
                        const float* srcFx = srcFy + fx * dilateX_step;
                        vfloat32m8_t s = __riscv_vlse32_v_f32m8(srcFx, srcByteStride, vl);
                        acc = __riscv_vfmacc_vf_f32m8(acc, w, s, vl);
                    }
                }

                // Fused min/max activation clamp, then scatter into C4 output.
                acc = __riscv_vfmax_vf_f32m8(acc, minV, vl);
                acc = __riscv_vfmin_vf_f32m8(acc, maxV, vl);
                float* dstAddr = dstY + dx * 4 + c;
                __riscv_vsse32_v_f32m8(dstAddr, dstByteStride, acc, vl);
            }

            dx += vl;
        }
    }
}
#include <riscv_vector.h>

// Strassen post-combine step, applied row by row over the four C-quadrants.
// With t = X + C12, the in-place updates are (order matters — each line uses
// the value written by the previous one):
//   C12 <- C11 + C22 + t
//   C21 <- C21 + t
//   C22 <- C22 + C21(new)
// xAddr rows are eSub*4 floats (C4 pack), quadrant rows stride by cStride.
void MNNStrassenMergeCFunction(float *c11, float *c12, float *c21, float *c22,
                               float *xAddr, size_t cStride, size_t eSub, size_t hSub) {
    for (int y = 0; y < hSub; ++y) {
        float *c11Y = c11 + y * cStride;
        float *c12Y = c12 + y * cStride;
        float *c22Y = c22 + y * cStride;
        float *c21Y = c21 + y * cStride;
        float *xY = xAddr + y * eSub * 4;
        size_t totalElements = eSub * 4;  // eSub positions x 4 packed channels
        size_t p = 0;

        while (p < totalElements) {
            size_t vl = __riscv_vsetvl_e32m8(totalElements - p);
            // t = X + C12 (C12's old value is consumed here, then overwritten).
            vfloat32m8_t t = __riscv_vle32_v_f32m8(xY + p, vl);
            vfloat32m8_t tmp = __riscv_vle32_v_f32m8(c12Y + p, vl);
            t = __riscv_vfadd_vv_f32m8(t, tmp, vl);
            // Keep C22's old value in a register; it feeds both C12 and C22.
            vfloat32m8_t c22v = __riscv_vle32_v_f32m8(c22Y + p, vl);

            // C12 = C11 + C22 + t
            tmp = __riscv_vle32_v_f32m8(c11Y + p, vl);
            tmp = __riscv_vfadd_vv_f32m8(tmp, c22v, vl);
            tmp = __riscv_vfadd_vv_f32m8(tmp, t, vl);
            __riscv_vse32_v_f32m8(c12Y + p, tmp, vl);

            // C21 = t + C21
            tmp = __riscv_vle32_v_f32m8(c21Y + p, vl);
            tmp = __riscv_vfadd_vv_f32m8(t, tmp, vl);
            __riscv_vse32_v_f32m8(c21Y + p, tmp, vl);

            // C22 = C22 + C21(new); tmp still holds the freshly-written C21.
            c22v = __riscv_vfadd_vv_f32m8(c22v, tmp, vl);
            __riscv_vse32_v_f32m8(c22Y + p, c22v, vl);

            p += vl;
        }
    }
}
#include <riscv_vector.h>
#include <float.h>

// Per-channel minimum over UNIT(=4)-channel data laid out as pairs:
// each unit contributes 8 floats, channel ch occupying offsets ch*2 and
// ch*2 + 1. minBuffer[ch] receives min over both pair members of every unit.
void MNNMinFloat(float *input, float *minBuffer, int32_t inputCountUnit) {
    for (int ch = 0; ch < 4; ++ch) {  // UNIT == 4 channels
        float best = FLT_MAX;
        size_t unit = 0;

        while (unit < (size_t)inputCountUnit) {
            size_t vl = __riscv_vsetvl_e32m8(inputCountUnit - unit);
            // Strided loads pick this channel's pair from consecutive units
            // (unit stride = 8 floats).
            float *lo = input + (unit * 4 * 2) + ch * 2;
            vfloat32m8_t first = __riscv_vlse32_v_f32m8(lo, 4 * 2 * sizeof(float), vl);
            vfloat32m8_t second = __riscv_vlse32_v_f32m8(lo + 1, 4 * 2 * sizeof(float), vl);
            // Pairwise min, then fold into the running scalar minimum.
            vfloat32m8_t pairMin = __riscv_vfmin_vv_f32m8(first, second, vl);
            vfloat32m1_t seed = __riscv_vfmv_s_f_f32m1(best, 1);
            best = __riscv_vfmv_f_s_f32m1_f32(
                __riscv_vfredmin_vs_f32m8_f32m1(pairMin, seed, vl));
            unit += vl;
        }

        minBuffer[ch] = best;
    }
}
#include <riscv_vector.h>
#include <stddef.h>

// Element-wise add of C4-packed blocks: for each of `count` positions,
// dest[0..3] += source[0..3], with source/dest advancing by srcStride /
// dstStride floats per position. Vectorized across positions via strided
// loads/stores, one packed channel at a time.
void MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
    const ptrdiff_t srcBytes = srcStride * sizeof(float);
    const ptrdiff_t dstBytes = dstStride * sizeof(float);
    size_t remaining = count;

    while (remaining > 0) {
        size_t vl = __riscv_vsetvl_e32m8(remaining);

        // Load/accumulate/store each of the 4 packed channels across vl positions.
        for (int ch = 0; ch < 4; ++ch) {
            vfloat32m8_t s = __riscv_vlse32_v_f32m8(source + ch, srcBytes, vl);
            vfloat32m8_t d = __riscv_vlse32_v_f32m8(dest + ch, dstBytes, vl);
            d = __riscv_vfadd_vv_f32m8(d, s, vl);
            __riscv_vsse32_v_f32m8(dest + ch, dstBytes, d, vl);
        }

        source += vl * srcStride;
        dest += vl * dstStride;
        remaining -= vl;
    }
}
#include <riscv_vector.h>
#include <stddef.h>

// Copy `count` C4-packed elements: dest[0..3] = source[0..3] per position,
// with source/dest advancing by srcStride / dstStride floats per position.
// Vectorized across positions with strided loads/stores, channel by channel.
void MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
    const ptrdiff_t srcBytes = srcStride * sizeof(float);
    const ptrdiff_t dstBytes = dstStride * sizeof(float);
    size_t remaining = count;

    while (remaining > 0) {
        size_t vl = __riscv_vsetvl_e32m8(remaining);

        // Move each of the 4 packed channels across vl positions.
        for (int ch = 0; ch < 4; ++ch) {
            vfloat32m8_t lane = __riscv_vlse32_v_f32m8(source + ch, srcBytes, vl);
            __riscv_vsse32_v_f32m8(dest + ch, dstBytes, lane, vl);
        }

        source += vl * srcStride;
        dest += vl * dstStride;
        remaining -= vl;
    }
}
#include <riscv_vector.h>
#include <stdint.h>
#include <stddef.h>

// 2-D transpose of 16-bit elements. dim = {w, h, srcStride, dstStride}
// (strides in elements). Row i of the output is gathered as column i of the
// input via a strided vector load.
void MNNTranspose16Bit(int16_t* dstO, const int16_t* srcO, int16_t* dim) {
    const int width = dim[0];
    const int height = dim[1];
    const int srcStride = dim[2];
    const int dstStride = dim[3];
    const ptrdiff_t srcStep = srcStride * sizeof(int16_t);

    for (int row = 0; row < height; ++row) {
        const int16_t* colPtr = srcO + row;          // walks down source column `row`
        int16_t* rowPtr = dstO + row * dstStride;    // fills output row `row`

        for (int col = 0; col < width;) {
            size_t vl = __riscv_vsetvl_e16m8(width - col);
            vint16m8_t gathered = __riscv_vlse16_v_i16m8(colPtr, srcStep, vl);
            __riscv_vse16_v_i16m8(rowPtr, gathered, vl);
            colPtr += vl * srcStride;
            rowPtr += vl;
            col += vl;
        }
    }
}
source/backend/cpu/riscv/rvv/MNNPackC2.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNPackC4.cpp create mode 100644 source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp diff --git a/source/backend/cpu/riscv/rvv/MNNPackC2.cpp b/source/backend/cpu/riscv/rvv/MNNPackC2.cpp new file mode 100644 index 0000000000..9a74f8998d --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNPackC2.cpp @@ -0,0 +1,74 @@ +#include + +void MNNPackC2(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { + int depthC2 = depth / 2; + int depthRemain = depthC2 * 2; + int remain = depth - depthRemain; + const float *srcOffset = src; + const float *srcChannel[2]; + + for (int z = 0; z < depthC2; ++z) { + float *dstZ = dst + z * areaOffset[1] * 2; + + for (int y = 0; y < 2; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 2; + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 0, 2 * sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 1, 2 * sizeof(float), vec, vl); + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 2; + dstPtr[0] = srcChannel[0][x]; + dstPtr[1] = srcChannel[1][x]; + } + + srcOffset += areaOffset[0] * 2; + } + + if (remain > 0) { + float *dstZ = dst + depthC2 * areaOffset[1] * 2; + + for (int y = 0; y < remain; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 2; + + for (int y = 0; y < remain; ++y) { + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + y, 2 * sizeof(float), vec, vl); + } + + vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); + for (int y = remain; y < 2; ++y) { + __riscv_vsse32_v_f32m8(dstPtr + y, 2 * 
sizeof(float), zero, vl); + } + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 2; + + for (int y = 0; y < remain; ++y) { + dstPtr[y] = srcChannel[y][x]; + } + + for (int y = remain; y < 2; ++y) { + dstPtr[y] = 0.0f; + } + } + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNPackC4.cpp b/source/backend/cpu/riscv/rvv/MNNPackC4.cpp new file mode 100644 index 0000000000..024e2c8c07 --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNPackC4.cpp @@ -0,0 +1,80 @@ +#include + +void MNNPackC4(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { + int depthC4 = depth / 4; + int depthRemain = depthC4 * 4; + int remain = depth - depthRemain; + const float *srcOffset = src; + const float *srcChannel[4]; + + for (int z = 0; z < depthC4; ++z) { + float *dstZ = dst + z * areaOffset[1] * 4; + + for (int y = 0; y < 4; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 4; + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[0] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 0, 4 * sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[1] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 1, 4 * sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[2] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 2, 4 * sizeof(float), vec, vl); + vec = __riscv_vle32_v_f32m8(srcChannel[3] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + 3, 4 * sizeof(float), vec, vl); + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 4; + dstPtr[0] = srcChannel[0][x]; + dstPtr[1] = srcChannel[1][x]; + dstPtr[2] = srcChannel[2][x]; + dstPtr[3] = srcChannel[3][x]; + } + + srcOffset += areaOffset[0] * 4; + } + + if (remain > 0) { + float *dstZ = dst + depthC4 * areaOffset[1] * 4; + + for (int y = 0; y < remain; ++y) { + srcChannel[y] = srcOffset + areaOffset[0] * y; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + 
for (; x + vl <= area; x += vl) { + float *dstPtr = dstZ + x * 4; + + for (int y = 0; y < remain; ++y) { + vfloat32m8_t vec = __riscv_vle32_v_f32m8(srcChannel[y] + x, vl); + __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), vec, vl); + } + + vfloat32m8_t zero = __riscv_vfmv_v_f_f32m8(0.0f, vl); + for (int y = remain; y < 4; ++y) { + __riscv_vsse32_v_f32m8(dstPtr + y, 4 * sizeof(float), zero, vl); + } + } + + for (; x < area; ++x) { + float *dstPtr = dstZ + x * 4; + + for (int y = 0; y < remain; ++y) { + dstPtr[y] = srcChannel[y][x]; + } + + for (int y = remain; y < 4; ++y) { + dstPtr[y] = 0.0f; + } + } + } +} + diff --git a/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp b/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp new file mode 100644 index 0000000000..4676e6dede --- /dev/null +++ b/source/backend/cpu/riscv/rvv/MNNUnpackC4.cpp @@ -0,0 +1,55 @@ +#include + +void MNNUnpackC4(float *dst, const float *src, size_t area, size_t depth, int *areaOffset) { + int depthC4 = depth / 4; + int depthRemain = depthC4 * 4; + int remain = depth - depthRemain; + const float *srcOffset = src; + + for (int z = 0; z < depthC4; ++z) { + float *dstZ[4]; + + for (int y = 0; y < 4; ++y) { + dstZ[y] = dst + (z * 4 + y) * areaOffset[1]; + } + + size_t x = 0; + size_t vl = __riscv_vsetvl_e32m8(area); + + for (; x + vl <= area; x += vl) { + vfloat32m8_t vec = __riscv_vlse32_v_f32m8(srcOffset + 0, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[0] + x, vec, vl); + vec = __riscv_vlse32_v_f32m8(srcOffset + 1, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[1] + x, vec, vl); + vec = __riscv_vlse32_v_f32m8(srcOffset + 2, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[2] + x, vec, vl); + vec = __riscv_vlse32_v_f32m8(srcOffset + 3, 4 * sizeof(float), vl); + __riscv_vse32_v_f32m8(dstZ[3] + x, vec, vl); + srcOffset += 4 * vl; + } + + for (; x < area; ++x) { + dstZ[0][x] = srcOffset[0]; + dstZ[1][x] = srcOffset[1]; + dstZ[2][x] = srcOffset[2]; + dstZ[3][x] = srcOffset[3]; + srcOffset 
+= (areaOffset[0] - area) * 4; + } + } + + if (remain > 0) { + float *dstZ = dst + depthC4 * areaOffset[1] * 4; + const float *srcBase = srcOffset; + + for (int y = 0; y < remain; ++y) { + float *dstChannel = dstZ + y * areaOffset[1]; + const float *srcChannel = srcBase + y; + + for (size_t x = 0; x < area; ++x) { + dstChannel[x] = srcChannel[0]; + srcChannel += 4; + } + } + } +} + From 712b4a308d81fe99702c59d03a6469e95eb15f3d Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 10:54:53 +0800 Subject: [PATCH 085/314] Merge pull request #4061 from zlaazlaa/fix_diffusion fix(diffusion): simplify export logic and fix dynamic axes GitOrigin-RevId: cc6faf47f33d462e2e1ac613ec710ce55c39a86a --- docs/transformers/diffusion.md | 3 +- transformers/diffusion/export/onnx_export.py | 30 ++++++-------------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/docs/transformers/diffusion.md b/docs/transformers/diffusion.md index 7de27bb216..609793f806 100644 --- a/docs/transformers/diffusion.md +++ b/docs/transformers/diffusion.md @@ -20,7 +20,8 @@ https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1/tree/mai cd mnn_path/transformers/diffusion/export python onnx_export.py \ --model_path hf_sd_load_path \ - --output_path onnx_save_path + --output_path onnx_save_path \ + --opset 18 ``` 注意,上述脚本需要依赖torch/onnx/diffusers等库,可以安装conda环境: ``` diff --git a/transformers/diffusion/export/onnx_export.py b/transformers/diffusion/export/onnx_export.py index 21f05e83be..5516eb2fcc 100644 --- a/transformers/diffusion/export/onnx_export.py +++ b/transformers/diffusion/export/onnx_export.py @@ -84,7 +84,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F num_tokens = pipeline.text_encoder.config.max_position_embeddings text_hidden_size = pipeline.text_encoder.config.hidden_size text_input = pipeline.tokenizer( - "A sample prompt", + ["A sample prompt", "A sample prompt"], padding="max_length", 
max_length=pipeline.tokenizer.model_max_length, truncation=True, @@ -97,9 +97,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path=output_path / "text_encoder" / "model.onnx", ordered_input_names=["input_ids"], output_names=["last_hidden_state", "pooler_output"], - dynamic_axes={ - "input_ids": {0: "batch", 1: "sequence"}, - }, + dynamic_axes=None, opset=opset, ) del pipeline.text_encoder @@ -117,13 +115,9 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F # False, ), output_path=unet_path, - ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], + ordered_input_names=["sample", "timestep", "encoder_hidden_states"], output_names=["out_sample"], # has to be different from "sample" for correct tracing - dynamic_axes={ - "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - "timestep": {0: "batch"}, - "encoder_hidden_states": {0: "batch", 1: "sequence"}, - }, + dynamic_axes=None, opset=opset, use_external_data_format=True, # UNet is > 2GB, so the weights need to be split ) @@ -149,7 +143,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F vae_in_channels = vae_encoder.config.in_channels vae_sample_size = vae_encoder.config.sample_size # need to get the raw tensor output (sample) from the encoder - vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample() + vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].mode() onnx_export( vae_encoder, model_args=( @@ -159,30 +153,24 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F output_path=output_path / "vae_encoder" / "model.onnx", ordered_input_names=["sample", "return_dict"], output_names=["latent_sample"], - dynamic_axes={ - "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - }, + dynamic_axes=None, opset=opset, ) # VAE DECODER vae_decoder = 
pipeline.vae vae_latent_channels = vae_decoder.config.latent_channels - vae_out_channels = vae_decoder.config.out_channels # forward only through the decoder part - vae_decoder.forward = vae_encoder.decode + vae_decoder.forward = lambda latent: vae_decoder.decode(latent, return_dict=False)[0] onnx_export( vae_decoder, model_args=( torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), - False, ), output_path=output_path / "vae_decoder" / "model.onnx", - ordered_input_names=["latent_sample", "return_dict"], + ordered_input_names=["latent_sample"], output_names=["sample"], - dynamic_axes={ - "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, - }, + dynamic_axes=None, opset=opset, ) del pipeline.vae From c26819df2d6e78a10453ab460ea2219cbf7a21ce Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 11:04:03 +0800 Subject: [PATCH 086/314] Merge pull request #3998 from bolun365/bolun365-patch-1 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit mnn lib库自动化build脚本 GitOrigin-RevId: 9bac02d0d7bbb82f6a2cd42b01789f5efbdefd8c --- build_lib.sh | 807 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 807 insertions(+) create mode 100644 build_lib.sh diff --git a/build_lib.sh b/build_lib.sh new file mode 100644 index 0000000000..c839b6e7b6 --- /dev/null +++ b/build_lib.sh @@ -0,0 +1,807 @@ +#!/bin/bash + +# MNN 统一构建脚本 +# 支持构建 Android、iOS、鸿蒙和 Python 版本的 MNN + +set -e + +# 颜色输出 +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# 默认配置 +BUILD_ANDROID=false +BUILD_IOS=false +BUILD_PYTHON=false +BUILD_IOS_SIMULATOR=false +BUILD_HARMONY=false +ANDROID_NDK="" +HARMONY_HOME="" +OUTPUT_DIR="mnn_builds" +CLEAN_BUILD=true +VERSION="" +PYTHON_CMD="python3" +DEPS_OPTIONS="" + +# 获取项目根目录 +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$SCRIPT_DIR" + +print_usage() { + echo "MNN 
统一构建脚本" + echo "" + echo "用法: $0 [选项]" + echo "" + echo "构建目标:" + echo " --android 构建 Android 版本 (需要 --ndk)" + echo " --ios 构建 iOS 真机版本 (arm64)" + echo " --ios-simulator 构建 iOS 模拟器版本 (x86_64 + arm64)" + echo " --harmony 构建鸿蒙版本 (需要 --harmony-home 或设置 HARMONY_HOME)" + echo " --python 构建 Python 版本" + echo "" + echo "Android 选项:" + echo " --ndk PATH Android NDK 路径 (例如: ~/Library/Android/sdk/ndk/29.0.13599879)" + echo "" + echo "鸿蒙选项:" + echo " --harmony-home PATH 鸿蒙工具链路径 (例如: ~/Library/OpenHarmony/Sdk/native)" + echo " 如果未指定,将优先查找 ~/Library/OpenHarmony/Sdk/native" + echo "" + echo "Python 选项:" + echo " --python-deps OPTIONS Python 依赖选项,多个用逗号分隔" + echo " 可用选项: llm,opencl,cuda,torch,render,vulkan,internal,no_sse,openmp" + echo " --python-cmd CMD Python 命令 (默认: python3)" + echo "" + echo "通用选项:" + echo " -o, --output DIR 输出目录 (默认: mnn_builds)" + echo " -v, --version VERSION 版本号 (默认: 自动从源码读取)" + echo " --no-clean 不清理之前的构建目录" + echo " -h, --help 显示帮助信息" + echo "" + echo "示例:" + echo " # 构建所有平台" + echo " $0 --android --ios --harmony --python --ndk ~/Library/Android/sdk/ndk/29.0.13599879" + echo "" + echo " # 仅构建 Android" + echo " $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879" + echo "" + echo " # 构建鸿蒙版本" + echo " $0 --harmony --harmony-home ~/Library/OpenHarmony/Sdk/native" + echo "" + echo " # 构建 iOS 真机和模拟器" + echo " $0 --ios --ios-simulator" + echo "" + echo " # 构建 Python (带 LLM 支持)" + echo " $0 --python --python-deps llm,opencl" +} + +# 解析命令行参数 +while [[ $# -gt 0 ]]; do + case $1 in + --android) + BUILD_ANDROID=true + shift + ;; + --ios) + BUILD_IOS=true + shift + ;; + --ios-simulator) + BUILD_IOS_SIMULATOR=true + shift + ;; + --python) + BUILD_PYTHON=true + shift + ;; + --harmony) + BUILD_HARMONY=true + shift + ;; + --ndk) + ANDROID_NDK="$2" + shift 2 + ;; + --harmony-home) + HARMONY_HOME="$2" + shift 2 + ;; + --python-deps) + DEPS_OPTIONS="$2" + shift 2 + ;; + --python-cmd) + PYTHON_CMD="$2" + shift 2 + ;; + -o|--output) + OUTPUT_DIR="$2" + shift 2 + ;; + 
-v|--version) + VERSION="$2" + shift 2 + ;; + --no-clean) + CLEAN_BUILD=false + shift + ;; + -h|--help) + print_usage + exit 0 + ;; + *) + echo -e "${RED}错误: 未知选项 $1${NC}" + print_usage + exit 1 + ;; + esac +done + +# 检查是否至少选择了一个构建目标 +if [ "$BUILD_ANDROID" = false ] && [ "$BUILD_IOS" = false ] && [ "$BUILD_IOS_SIMULATOR" = false ] && [ "$BUILD_HARMONY" = false ] && [ "$BUILD_PYTHON" = false ]; then + echo -e "${RED}错误: 请至少选择一个构建目标 (--android, --ios, --ios-simulator, --harmony, 或 --python)${NC}" + print_usage + exit 1 +fi + +# 检查 Android NDK +if [ "$BUILD_ANDROID" = true ]; then + if [ -z "$ANDROID_NDK" ]; then + echo -e "${RED}错误: 构建 Android 必须指定 NDK 路径 (使用 --ndk)${NC}" + echo -e "${RED}示例: $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879${NC}" + exit 1 + fi + + # 展开路径中的 ~ + ANDROID_NDK="${ANDROID_NDK/#\~/$HOME}" + + if [ ! -d "$ANDROID_NDK" ]; then + echo -e "${RED}错误: NDK 路径不存在: $ANDROID_NDK${NC}" + exit 1 + fi +fi + +# 查找鸿蒙工具链的函数 +find_harmony_toolchain() { + # 如果环境变量已设置,优先使用 + if [ -n "$HARMONY_HOME" ]; then + # 支持两种路径格式:直接指定或带 native 子目录 + if [ -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then + echo "$HARMONY_HOME" + return 0 + elif [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then + echo "$HARMONY_HOME" + return 0 + fi + fi + + # 优先查找 ~/Library/OpenHarmony/Sdk,支持版本号目录 + local sdk_base="$HOME/Library/OpenHarmony/Sdk" + if [ -d "$sdk_base" ]; then + # 查找所有版本号目录下的 native 或 toolchains + # 按版本号倒序排列,优先使用最新版本 + for version_dir in $(ls -d "$sdk_base"/* 2>/dev/null | sort -Vr); do + if [ -d "$version_dir" ]; then + # 尝试 native/build/cmake/ohos.toolchain.cmake + if [ -f "$version_dir/native/build/cmake/ohos.toolchain.cmake" ]; then + echo "$version_dir/native" + return 0 + fi + # 尝试 build/cmake/ohos.toolchain.cmake + if [ -f "$version_dir/build/cmake/ohos.toolchain.cmake" ]; then + echo "$version_dir" + return 0 + fi + fi + done + fi + + # 其他可能的路径 + local possible_paths=( + "$HOME/Library/OpenHarmony/Sdk/native" + 
"$HOME/HarmonyOS/Sdk/native" + "$HOME/.ohos/native" + "/opt/HarmonyOS/Sdk/native" + "/usr/local/HarmonyOS/Sdk/native" + "$HOME/Library/HarmonyOS/Sdk/native" + ) + + # 尝试查找 + for path in "${possible_paths[@]}"; do + if [ -n "$path" ] && [ -f "$path/build/cmake/ohos.toolchain.cmake" ]; then + echo "$path" + return 0 + fi + done + + # 限制搜索范围,只在 OpenHarmony/Sdk 目录下快速查找 + # 避免在整个 OpenHarmony 目录下递归查找,这可能会很慢 + local found=$(find "$HOME/Library/OpenHarmony/Sdk" -maxdepth 4 -type f -name "ohos.toolchain.cmake" 2>/dev/null | head -1) + if [ -n "$found" ]; then + # 从 ohos.toolchain.cmake 向上查找 native 目录或 SDK 根目录 + found=$(dirname "$found") + if [ "$(basename "$found")" = "cmake" ]; then + found=$(dirname "$found") + if [ "$(basename "$found")" = "build" ]; then + found=$(dirname "$found") + fi + fi + echo "$found" + return 0 + fi + + return 1 +} + +# 检查鸿蒙工具链 +if [ "$BUILD_HARMONY" = true ]; then + # 展开路径中的 ~ + if [ -n "$HARMONY_HOME" ]; then + HARMONY_HOME="${HARMONY_HOME/#\~/$HOME}" + fi + + # 尝试查找工具链 + if [ -z "$HARMONY_HOME" ] || [ ! -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then + # 检查是否在 native 子目录下 + if [ -n "$HARMONY_HOME" ] && [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then + HARMONY_HOME="$HARMONY_HOME/native" + else + echo -e "${YELLOW}正在查找鸿蒙工具链...${NC}" + HARMONY_HOME=$(find_harmony_toolchain) + + if [ -z "$HARMONY_HOME" ]; then + echo -e "${RED}错误: 找不到鸿蒙工具链${NC}" + echo -e "${RED}默认查找路径: ~/Library/OpenHarmony/Sdk/*/native${NC}" + echo -e "${RED}请使用 --harmony-home 指定路径,或设置 HARMONY_HOME 环境变量${NC}" + echo -e "${RED}工具链文件应位于: /build/cmake/ohos.toolchain.cmake 或 /native/build/cmake/ohos.toolchain.cmake${NC}" + exit 1 + fi + + # 如果找到的是带 native 的路径,需要调整 + if [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then + HARMONY_HOME="$HARMONY_HOME/native" + fi + fi + + echo -e "${GREEN}找到鸿蒙工具链: $HARMONY_HOME${NC}" + fi + + # 验证工具链文件存在 + if [ ! 
-f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then + echo -e "${RED}错误: 鸿蒙工具链文件不存在: $HARMONY_HOME/build/cmake/ohos.toolchain.cmake${NC}" + exit 1 + fi +fi + +echo -e "${GREEN}========================================${NC}" +echo -e "${GREEN}MNN 统一构建脚本${NC}" +echo -e "${GREEN}========================================${NC}" +echo "项目根目录: $PROJECT_ROOT" +echo "输出目录: $OUTPUT_DIR" +echo "" +echo -e "${BLUE}构建目标:${NC}" +[ "$BUILD_ANDROID" = true ] && echo " ✓ Android" +[ "$BUILD_IOS" = true ] && echo " ✓ iOS (真机 arm64)" +[ "$BUILD_IOS_SIMULATOR" = true ] && echo " ✓ iOS (模拟器 x86_64 + arm64)" +[ "$BUILD_HARMONY" = true ] && echo " ✓ 鸿蒙 (arm64-v8a)" +[ "$BUILD_PYTHON" = true ] && echo " ✓ Python" +echo "" + +cd "$PROJECT_ROOT" +mkdir -p "$OUTPUT_DIR" + +# ============================================================================ +# 构建 Android 版本 +# ============================================================================ +if [ "$BUILD_ANDROID" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 Android 版本${NC}" + echo -e "${GREEN}========================================${NC}" + + export ANDROID_NDK + + ANDROID_BUILD_DIR="project/android" + ANDROID_OUTPUT_DIR="$OUTPUT_DIR/android" + + cd "$ANDROID_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的 Android 构建...${NC}" + rm -rf build_32 build_64 export + fi + + # 在项目根目录创建输出目录 + mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" + + # 构建 armeabi-v7a + echo -e "${BLUE}构建 armeabi-v7a...${NC}" + mkdir -p build_32 + cd build_32 + cmake ../../../ \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DANDROID_ABI="armeabi-v7a" \ + -DANDROID_STL=c++_static \ + -DANDROID_NATIVE_API_LEVEL=android-14 \ + -DANDROID_TOOLCHAIN=clang \ + -DMNN_USE_LOGCAT=false \ + -DMNN_USE_SSE=OFF \ + -DMNN_BUILD_TEST=ON \ + -DMNN_ARM82=OFF \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + 
-DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DMNN_SEP_BUILD=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ + -DNATIVE_LIBRARY_OUTPUT=. \ + -DNATIVE_INCLUDE_OUTPUT=. \ + > /dev/null + + make -j4 MNN > /dev/null + cd .. + + # 构建 arm64-v8a + echo -e "${BLUE}构建 arm64-v8a...${NC}" + mkdir -p build_64 + cd build_64 + cmake ../../../ \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DANDROID_ABI="arm64-v8a" \ + -DANDROID_STL=c++_static \ + -DANDROID_NATIVE_API_LEVEL=android-21 \ + -DANDROID_TOOLCHAIN=clang \ + -DMNN_USE_LOGCAT=false \ + -DMNN_BUILD_BENCHMARK=ON \ + -DMNN_USE_SSE=OFF \ + -DMNN_BUILD_TEST=ON \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DMNN_SEP_BUILD=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ + -DNATIVE_LIBRARY_OUTPUT=. \ + -DNATIVE_INCLUDE_OUTPUT=. \ + > /dev/null + + make -j4 MNN > /dev/null + cd .. + + # 导出文件 + echo -e "${BLUE}导出 Android 库文件...${NC}" + mkdir -p export/android/{armeabi-v7a,arm64-v8a}/libs export/android/include/MNN + + cp build_32/*.so export/android/armeabi-v7a/libs/ 2>/dev/null || true + cp build_64/*.so export/android/arm64-v8a/libs/ 2>/dev/null || true + cp -r ../../include/MNN/* export/android/include/MNN/ + + # 复制到统一输出目录 + # 如果目标路径存在但不是目录,先删除 + if [ -e "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ] && [ ! 
-d "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ]; then + rm -f "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" + fi + mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" + cp -r export/android/* "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR/" + + echo -e "${GREEN}Android 构建完成!${NC}" + echo "输出位置: $ANDROID_OUTPUT_DIR" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建 iOS 真机版本 +# ============================================================================ +if [ "$BUILD_IOS" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 iOS 真机版本${NC}" + echo -e "${GREEN}========================================${NC}" + + # 检查是否在 macOS 上 + if [[ "$OSTYPE" != "darwin"* ]]; then + echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" + exit 1 + fi + + # 检查 Xcode + if ! command -v xcodebuild &> /dev/null; then + echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" + exit 1 + fi + + IOS_BUILD_DIR="project/ios" + IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/device" + + cd "$IOS_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的 iOS 真机构建...${NC}" + rm -rf MNN-iOS-CPU-GPU/Static/ios_64 + find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_64/*" 2>/dev/null | xargs rm -f 2>/dev/null || true + fi + + mkdir -p MNN-iOS-CPU-GPU/Static + cd MNN-iOS-CPU-GPU/Static + + # 构建真机版本 (arm64) + echo -e "${BLUE}构建 iOS 真机版本 (arm64)...${NC}" + rm -rf ios_64 + mkdir ios_64 + cd ios_64 + cmake "$PROJECT_ROOT" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ + -DENABLE_BITCODE=0 \ + -DMNN_AAPL_FMWK=1 \ + -DMNN_SEP_BUILD=OFF \ + -DMNN_BUILD_SHARED_LIBS=false \ + -DMNN_USE_THREAD_POOL=OFF \ + -DPLATFORM=OS64 \ + -DARCHS="arm64" \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_METAL=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + 
-DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + > /dev/null + make MNN -j8 > /dev/null + cd .. + + # 输出真机版本 + mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" + rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" + cp -R ios_64/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" + + # 清理 + rm -rf ios_64 + + echo -e "${GREEN}iOS 真机构建完成!${NC}" + echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建 iOS 模拟器版本 +# ============================================================================ +if [ "$BUILD_IOS_SIMULATOR" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 iOS 模拟器版本${NC}" + echo -e "${GREEN}========================================${NC}" + + # 检查是否在 macOS 上 + if [[ "$OSTYPE" != "darwin"* ]]; then + echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}" + exit 1 + fi + + # 检查 Xcode + if ! command -v xcodebuild &> /dev/null; then + echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}" + exit 1 + fi + + IOS_BUILD_DIR="project/ios" + IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/simulator" + + cd "$IOS_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的 iOS 模拟器构建...${NC}" + rm -rf MNN-iOS-CPU-GPU/Static/ios_simulator* + find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_simulator*/*" 2>/dev/null | xargs rm -f 2>/dev/null || true + fi + + mkdir -p MNN-iOS-CPU-GPU/Static + cd MNN-iOS-CPU-GPU/Static + + # 构建 x86_64 模拟器版本(尝试构建,失败则跳过) + BUILD_X86_64=false + echo -e "${BLUE}构建 iOS 模拟器版本 (x86_64)...${NC}" + rm -rf ios_simulator_x86 + mkdir ios_simulator_x86 + cd ios_simulator_x86 + if cmake "$PROJECT_ROOT" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ + -DENABLE_BITCODE=0 \ + -DMNN_AAPL_FMWK=1 \ + -DMNN_SEP_BUILD=OFF \ + -DMNN_BUILD_SHARED_LIBS=false \ + -DMNN_USE_THREAD_POOL=OFF \ + -DPLATFORM=SIMULATOR64 \ + 
-DARCHS="x86_64" \ + -DMNN_ARM82=OFF \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_METAL=OFF \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + && make MNN -j8; then + if [ -f "MNN.framework/MNN" ]; then + BUILD_X86_64=true + echo -e "${GREEN}x86_64 模拟器构建成功${NC}" + cd .. + else + echo -e "${YELLOW}警告: x86_64 模拟器构建产物不存在,跳过此架构${NC}" + cd .. + rm -rf ios_simulator_x86 + fi + else + echo -e "${YELLOW}警告: x86_64 模拟器构建失败,跳过此架构(这在 Apple Silicon Mac 上是正常的)${NC}" + cd .. + rm -rf ios_simulator_x86 + fi + + # 构建 arm64 模拟器版本 + echo -e "${BLUE}构建 iOS 模拟器版本 (arm64)...${NC}" + rm -rf ios_simulator_arm64 + mkdir ios_simulator_arm64 + cd ios_simulator_arm64 + if ! cmake "$PROJECT_ROOT" \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \ + -DENABLE_BITCODE=0 \ + -DMNN_AAPL_FMWK=1 \ + -DMNN_SEP_BUILD=OFF \ + -DMNN_BUILD_SHARED_LIBS=false \ + -DMNN_USE_THREAD_POOL=OFF \ + -DPLATFORM=SIMULATOR64 \ + -DARCHS="arm64" \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_METAL=OFF \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON; then + echo -e "${RED}错误: arm64 模拟器 cmake 配置失败${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + if ! make MNN -j8; then + echo -e "${RED}错误: arm64 模拟器编译失败${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + cd .. + + # 验证构建产物 + if [ ! 
-f "ios_simulator_arm64/MNN.framework/MNN" ]; then + echo -e "${RED}错误: 未找到 arm64 模拟器框架文件${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + + # 合并模拟器架构 + echo -e "${BLUE}合并模拟器架构...${NC}" + rm -rf ios_simulator + mkdir ios_simulator + + if [ "$BUILD_X86_64" = true ] && [ -f "ios_simulator_x86/MNN.framework/MNN" ]; then + # 合并 x86_64 + arm64 + echo -e "${BLUE}合并 x86_64 + arm64 架构...${NC}" + cp -R ios_simulator_x86/MNN.framework ios_simulator/MNN.framework + mv ios_simulator/MNN.framework/MNN ios_simulator/MNN.framework/MNN_x86 + if ! lipo -create ios_simulator/MNN.framework/MNN_x86 ios_simulator_arm64/MNN.framework/MNN -output ios_simulator/MNN.framework/MNN; then + echo -e "${RED}错误: 合并架构失败${NC}" + cd "$PROJECT_ROOT" + exit 1 + fi + rm ios_simulator/MNN.framework/MNN_x86 + else + # 仅使用 arm64 + echo -e "${BLUE}仅使用 arm64 架构(x86_64 不可用)...${NC}" + cp -R ios_simulator_arm64/MNN.framework ios_simulator/MNN.framework + fi + + # 输出模拟器版本 + mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR" + rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework" + cp -R ios_simulator/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/" + + # 清理临时目录 + rm -rf ios_simulator ios_simulator_x86 ios_simulator_arm64 + + echo -e "${GREEN}iOS 模拟器构建完成!${NC}" + echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建鸿蒙版本 +# ============================================================================ +if [ "$BUILD_HARMONY" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建鸿蒙版本${NC}" + echo -e "${GREEN}========================================${NC}" + + export HARMONY_HOME + + HARMONY_BUILD_DIR="project/harmony" + HARMONY_OUTPUT_DIR="$OUTPUT_DIR/harmony" + + cd "$HARMONY_BUILD_DIR" + + # 清理 + if [ "$CLEAN_BUILD" = true ]; then + echo -e "${YELLOW}清理之前的鸿蒙构建...${NC}" + rm -rf build_64 export + fi + + # 在项目根目录创建输出目录 + mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" + + # 构建 arm64-v8a + 
echo -e "${BLUE}构建 arm64-v8a...${NC}" + mkdir -p build_64 + cd build_64 + + cmake "$PROJECT_ROOT" \ + -DCMAKE_TOOLCHAIN_FILE=$HARMONY_HOME/build/cmake/ohos.toolchain.cmake \ + -DCMAKE_BUILD_TYPE=Release \ + -DOHOS_ARCH="arm64-v8a" \ + -DOHOS_STL=c++_static \ + -DMNN_USE_LOGCAT=true \ + -DMNN_BUILD_BENCHMARK=ON \ + -DMNN_USE_SSE=OFF \ + -DMNN_SUPPORT_BF16=OFF \ + -DMNN_BUILD_TEST=ON \ + -DMNN_ARM82=ON \ + -DMNN_LOW_MEMORY=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ + -DMNN_BUILD_LLM=ON \ + -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \ + -DMNN_BUILD_DIFFUSION=ON \ + -DMNN_OPENCL=OFF \ + -DMNN_SEP_BUILD=OFF \ + -DLLM_SUPPORT_AUDIO=ON \ + -DMNN_BUILD_AUDIO=ON \ + -DLLM_SUPPORT_VISION=ON \ + -DMNN_BUILD_OPENCV=ON \ + -DMNN_IMGCODECS=ON \ + -DOHOS_PLATFORM_LEVEL=9 \ + -DNATIVE_LIBRARY_OUTPUT=. \ + -DNATIVE_INCLUDE_OUTPUT=. \ + > /dev/null + + make -j4 MNN > /dev/null + cd .. + + # 导出文件 + echo -e "${BLUE}导出鸿蒙库文件...${NC}" + mkdir -p export/harmony/arm64-v8a/libs export/harmony/include/MNN + + cp build_64/*.so export/harmony/arm64-v8a/libs/ 2>/dev/null || true + cp -r ../../include/MNN/* export/harmony/include/MNN/ + + # 复制到统一输出目录 + # 如果目标路径存在但不是目录,先删除 + if [ -e "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ] && [ ! -d "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ]; then + rm -f "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" + fi + mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" + cp -r export/harmony/* "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR/" + + echo -e "${GREEN}鸿蒙构建完成!${NC}" + echo "输出位置: $HARMONY_OUTPUT_DIR" + cd "$PROJECT_ROOT" +fi + +# ============================================================================ +# 构建 Python 版本 +# ============================================================================ +if [ "$BUILD_PYTHON" = true ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}开始构建 Python 版本${NC}" + echo -e "${GREEN}========================================${NC}" + + # 检查 Python + if ! 
command -v $PYTHON_CMD &> /dev/null; then + echo -e "${RED}错误: 找不到 Python 命令 '$PYTHON_CMD'${NC}" + exit 1 + fi + + PYTHON_OUTPUT_DIR="$OUTPUT_DIR/python" + + # 使用之前创建的 build_pymnn.sh 脚本 + if [ -f "$PROJECT_ROOT/build_pymnn.sh" ]; then + PYTHON_BUILD_ARGS="-o $PYTHON_OUTPUT_DIR" + [ -n "$VERSION" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -v $VERSION" + [ -n "$DEPS_OPTIONS" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -d $DEPS_OPTIONS" + [ "$CLEAN_BUILD" = false ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --no-clean" + PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --python $PYTHON_CMD" + + bash "$PROJECT_ROOT/build_pymnn.sh" $PYTHON_BUILD_ARGS + else + echo -e "${YELLOW}警告: 未找到 build_pymnn.sh,使用基本构建方式...${NC}" + + cd pymnn/pip_package + + # 构建依赖 + if [ -n "$DEPS_OPTIONS" ]; then + $PYTHON_CMD build_deps.py $DEPS_OPTIONS + else + $PYTHON_CMD build_deps.py + fi + + # 构建 wheel + $PYTHON_CMD -m pip install -U numpy wheel setuptools --quiet + rm -rf build dist + + BUILD_ARGS="" + [ -n "$VERSION" ] && BUILD_ARGS="--version $VERSION" + [ -n "$DEPS_OPTIONS" ] && BUILD_ARGS="$BUILD_ARGS --deps $DEPS_OPTIONS" + + $PYTHON_CMD setup.py bdist_wheel $BUILD_ARGS + + mkdir -p "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR" + cp dist/*.whl "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR/" + + cd "$PROJECT_ROOT" + fi + + echo -e "${GREEN}Python 构建完成!${NC}" + echo "输出位置: $PYTHON_OUTPUT_DIR" +fi + +# ============================================================================ +# 总结 +# ============================================================================ +echo "" +echo -e "${GREEN}========================================${NC}" +echo -e "${GREEN}所有构建完成!${NC}" +echo -e "${GREEN}========================================${NC}" +echo "" +echo "输出目录: $PROJECT_ROOT/$OUTPUT_DIR" +echo "" +[ "$BUILD_ANDROID" = true ] && echo -e "${GREEN}Android:${NC} $OUTPUT_DIR/android" +[ "$BUILD_IOS" = true ] && echo -e "${GREEN}iOS (真机):${NC} $OUTPUT_DIR/ios/device" +[ "$BUILD_IOS_SIMULATOR" = true ] && echo -e "${GREEN}iOS (模拟器):${NC} 
$OUTPUT_DIR/ios/simulator" +[ "$BUILD_HARMONY" = true ] && echo -e "${GREEN}鸿蒙:${NC} $OUTPUT_DIR/harmony" +[ "$BUILD_PYTHON" = true ] && echo -e "${GREEN}Python:${NC} $OUTPUT_DIR/python" +echo "" + + From 7180b2de8cd1912f85de6bbde8755226d069d9b5 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 11:04:54 +0800 Subject: [PATCH 087/314] Merge pull request #4009 from HenryDen/default_opt Add a compile option and macro to default enable kleidiAI GitOrigin-RevId: d252203d159374844e90bfe13589b9c0c36f62ee --- CMakeLists.txt | 1 + source/backend/cpu/arm/CMakeLists.txt | 3 +++ source/core/Backend.hpp | 6 ++++-- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 67502b606b..f99e37ec1c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -258,6 +258,7 @@ option(MNN_VULKAN "Enable Vulkan" OFF) option(MNN_ARM82 "Enable ARMv8.2's FP16 Compute" ON) option(MNN_SUPPORT_FP16_ARMV7 "Enable ARMv8.2's FP16 Compute for armv7 arch, may cause library not valid for 32 bit cpu" OFF) option(MNN_KLEIDIAI "Enable KLEIDIAI" ON) +option(MNN_KLEIDIAI_DEFAULT_ON "Use KLEIDIAI kernels by default" OFF) option(MNN_ONEDNN "Enable oneDNN" OFF) option(MNN_AVX2 "Open AVX2 Compile for x86 if possible" ON) option(MNN_AVX512 "Enable AVX512" OFF) diff --git a/source/backend/cpu/arm/CMakeLists.txt b/source/backend/cpu/arm/CMakeLists.txt index 18fca54a4e..61ebce6bdc 100644 --- a/source/backend/cpu/arm/CMakeLists.txt +++ b/source/backend/cpu/arm/CMakeLists.txt @@ -36,6 +36,9 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR AR if (MNN_KLEIDIAI) include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/KleidiAI.cmake) download_kleidiai_and_collect_sources() + if(MNN_KLEIDIAI_DEFAULT_ON) + add_definitions(-DMNN_DEFAULT_USE_KLEIDIAI) + endif() endif() if (MNN_SME2) diff --git a/source/core/Backend.hpp b/source/core/Backend.hpp index bcf618c3c9..6850b6b4f6 100644 --- a/source/core/Backend.hpp +++ b/source/core/Backend.hpp @@ -68,9 +68,11 
@@ struct RuntimeHint { // whether to use Arm sme2 cores when threads>1 bool useArmSme2Cores = true; - +#ifdef MNN_DEFAULT_USE_KLEIDIAI + bool enableKleidiAI = true; +#else bool enableKleidiAI = false; - +#endif // Use CPU Ids std::vector cpuIds; From 7bdfb0d003161d9e9f5d2c8d4d67caed1d72d91f Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 11:42:22 +0800 Subject: [PATCH 088/314] Merge branch feature/add_4th_groupchat into master Title: [Doc:Update] update dingtalk in README. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 本次代码评审的主要改动是对README文件中的钉钉群信息进行了更新,包括群号、状态以及删除了一些过时的信息。 Link: https://code.alibaba-inc.com/AliNN/AliNNPrivate/codereview/25029869 GitOrigin-RevId: 323623143de7fac53e2a4683e9a3c2090f392ae6 --- README.md | 14 +++++++------- README_CN.md | 10 ++++------ README_JP.md | 9 +++++---- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 5fe168ed05..7959890c16 100644 --- a/README.md +++ b/README.md @@ -6,13 +6,13 @@ [![日本語バージョン](https://img.shields.io/badge/Language-%E6%97%A5%E6%9C%AC%E8%AA%9E-green)](README_JP.md) [![MNN Homepage](https://img.shields.io/badge/Homepage-Visit-green)](http://www.mnn.zone) -[![MNN Chat App](https://img.shields.io/badge/Apps-MNN_Chat-blue)](./apps/Android/MnnLlmChat/README.md) -[![TaoAvatar](https://img.shields.io/badge/Apps-MNN_TaoAvatar-blue)](./apps/Android/Mnn3dAvatar/README.md) +[![MNN Chat App](https://img.shields.io/badge/Apps-MNN_Chat-blue)](./apps/Android/MnnLlmChat/README.md) +[![TaoAvatar](https://img.shields.io/badge/Apps-MNN_TaoAvatar-blue)](./apps/Android/Mnn3dAvatar/README.md) ## News 🔥 - [2025/10/16] Support Qwen3-VL Series. -- [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! 
[MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md) +- [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! [MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md)

Icon

@@ -154,13 +154,13 @@ The group discussions are predominantly Chinese. But we welcome and will help En Dingtalk discussion groups: -Group #1 (Full): 23329087 +Group #4 (Available): 160170007549 -Group #2 (Full): 23350225 +Group #3 (Full) -Group #3: QR code: +Group #2 (Full): 23350225 -![MNN-3](doc/dingdingmnn3.png) +Group #1 (Full): 23329087 ## Historical Paper diff --git a/README_CN.md b/README_CN.md index edcf823a28..f769a1e14b 100644 --- a/README_CN.md +++ b/README_CN.md @@ -111,12 +111,10 @@ MNN适配的硬件架构与精度详见下表: ## 社区交流与反馈 钉钉群组: -- 钉钉群1:23329087 -- 钉钉群2:23350225 -- 钉钉群3:扫描二维码加入 - -![MNN-3](doc/dingdingmnn3.png) - +- 钉钉群3 (可加入): 160170007549 +- 钉钉群3 (已无法加入) +- 钉钉群2 (已满): 23350225 +- 钉钉群1 (已满): 23329087 ## 历史论文 diff --git a/README_JP.md b/README_JP.md index c2baa58d94..2f33def31a 100644 --- a/README_JP.md +++ b/README_JP.md @@ -117,13 +117,14 @@ MNN(テンソル計算エンジン)に基づいて、推論、トレーニ Dingtalkディスカッショングループ: -グループ#1(満員):23329087 -グループ#2(満員):23350225 +グループ#4 :160170007549 -グループ#3:QRコード: +グループ#3 (満員) -![MNN-3](doc/dingdingmnn3.png) +グループ#2(満員):23350225 + +グループ#1(満員):23329087 ## 歴史的な論文 From 4ebb5cc0a17a275fd83597d4a19c79e1b12c96f2 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 15:02:01 +0800 Subject: [PATCH 089/314] Merge pull request #4027 from codefuturedalao/master [BugFix] fix a bug in compute mGroupWithComputeRate GitOrigin-RevId: 0a30b5c040bc34aff1de94e7fa571ebb8f2c20fa --- source/backend/cpu/CPUBackend.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/source/backend/cpu/CPUBackend.cpp b/source/backend/cpu/CPUBackend.cpp index 0e0bc1f136..95cbd903b7 100644 --- a/source/backend/cpu/CPUBackend.cpp +++ b/source/backend/cpu/CPUBackend.cpp @@ -491,6 +491,7 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p currentRate *= decreaseRate; totalComputeRate += currentRate * selectSize; mGroupWithComputeRate.emplace_back(std::make_pair(currentRate * selectSize, selectSize)); + groupIndex--; } for (auto& g : mGroupWithComputeRate) { 
g.first = g.first / totalComputeRate; From edf9165248cc421e986851d3bb1aea404b22adbb Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Mon, 22 Dec 2025 19:18:35 +0800 Subject: [PATCH 090/314] Merge pull request #4076 from jxt1234/feature/smallmodel_opt Feature/smallmodel opt GitOrigin-RevId: 5610add6e64c6d49f8b984d0d744c85f206f2be7 --- source/backend/cpu/CPUBackend.cpp | 7 +- source/backend/cpu/CPUBackend.hpp | 3 + source/backend/cpu/CPUBinary.cpp | 60 +- source/backend/cpu/CPUBinary.hpp | 4 + source/backend/cpu/CPUMatMul.cpp | 28 +- source/backend/cpu/CPUMatMul.hpp | 7 +- source/backend/cpu/CPURNNSequenceGRU.cpp | 70 +- source/backend/cpu/CPURNNSequenceGRU.hpp | 15 +- source/backend/cpu/CPURaster.cpp | 631 +++++++++--------- source/backend/cpu/CPURaster.hpp | 3 +- source/backend/cpu/ThreadPool.cpp | 32 +- source/backend/cpu/ThreadPool.hpp | 6 +- .../backend/cpu/compute/CommonOptFunction.cpp | 88 ++- source/core/Concurrency.h | 13 +- source/core/OpCommonUtils.cpp | 91 --- source/core/OpCommonUtils.hpp | 1 - source/core/TensorUtils.cpp | 12 + source/core/TensorUtils.hpp | 1 + source/geometry/GeometryComputerUtils.cpp | 4 +- source/geometry/GeometryComputerUtils.hpp | 2 +- source/geometry/GeometryReduce.cpp | 104 ++- source/geometry/GeometryReshape.cpp | 11 +- source/math/Vec.hpp | 3 +- test/core/ThreadPoolTest.cpp | 6 +- tools/cpp/ExprDebug.hpp | 53 +- tools/cpp/ModuleBasic.cpp | 46 +- transformers/llm/engine/src/llm.cpp | 21 +- 27 files changed, 747 insertions(+), 575 deletions(-) diff --git a/source/backend/cpu/CPUBackend.cpp b/source/backend/cpu/CPUBackend.cpp index 95cbd903b7..8d284aa33b 100644 --- a/source/backend/cpu/CPUBackend.cpp +++ b/source/backend/cpu/CPUBackend.cpp @@ -104,15 +104,14 @@ void CPURuntime::_bindCPUCore() const { #ifdef MNN_USE_THREAD_POOL if (nullptr != mThreadPool) { mThreadPool->active(); - mThreadPool->enqueue(std::make_pair([&](int i) { + ThreadPool::TASK task = std::make_pair([&](int i) { MNNSetSchedAffinity(lockCPUIndexes[i].first, 
lockCPUIndexes[i].second); - return 0; - }, mThreadNumber), mTaskIndex); + }, mThreadNumber); + mThreadPool->enqueue(&task, mTaskIndex); mThreadPool->deactive(); } #endif } - void CPURuntime::_resetThreadPool() const { mThreadNumber = std::max(1, mThreadNumber); mThreadNumber = std::min(mThreadNumber, MAX_THREAD_NUMBER); diff --git a/source/backend/cpu/CPUBackend.hpp b/source/backend/cpu/CPUBackend.hpp index 884036eb38..ec4c555dec 100644 --- a/source/backend/cpu/CPUBackend.hpp +++ b/source/backend/cpu/CPUBackend.hpp @@ -176,6 +176,9 @@ class CPUBackend : public Backend { #ifdef MNN_USE_THREAD_POOL inline int taskIndex() const {return mRuntime->mTaskIndex;} inline ThreadPool* threadPool() const {return mRuntime->mThreadPool;} + void enqueue(ThreadPool::TASK& task) const { + threadPool()->enqueue(&task, taskIndex()); + } #endif static void initCreatorMap(); static size_t getBytes(const Backend* backend, const Tensor* output); diff --git a/source/backend/cpu/CPUBinary.cpp b/source/backend/cpu/CPUBinary.cpp index 059e502d0b..61ccf4fca3 100644 --- a/source/backend/cpu/CPUBinary.cpp +++ b/source/backend/cpu/CPUBinary.cpp @@ -45,6 +45,37 @@ ErrorCode CPUBinary::onResize(const std::vector& inputs, const std::vec mThreadNum = threads; mWorkDiv = UP_DIV(mTotalSize, threads); } + int inpBytes = inputs[0]->getType().bytes(); + int outBytes = outputs[0]->getType().bytes(); + if (halide_type_float == inputs[0]->getType().code) { + inpBytes = static_cast(backend())->functions()->bytes; + } + if (halide_type_float == outputs[0]->getType().code) { + outBytes = static_cast(backend())->functions()->bytes; + } + bool outputInt = outputs[0]->getType().code == halide_type_int; + mTask = std::make_pair([this, inpBytes, outBytes, outputInt](int tId) { + int start = tId * mWorkDiv; + int realSize = ALIMIN(mWorkDiv, mTotalSize - start); + if (realSize > 0) { + auto inp0 = mInput0Ptr + start * inpBytes; + auto inp1 = mInput1Ptr + start * inpBytes; + if (mNeedBroadcastIndex == 0) { + inp0 = 
mInput0Ptr; + } else if (mNeedBroadcastIndex == 1) { + inp1 = mInput1Ptr; + } + auto out = mOutputPtr + start * outBytes; + mProc(out, inp0, inp1, realSize, mNeedBroadcastIndex); + if(mActivationType == 1 && outputInt) { + for(int i=0; i 0 ? val : 0; + ((int32_t *)out)[i] = res; + } + } + } + } , mThreadNum); return NO_ERROR; } @@ -67,31 +98,10 @@ ErrorCode CPUBinary::onExecute(const std::vector& inputs, const std::ve outBytes = static_cast(backend())->functions()->bytes; } auto precision = static_cast(backend())->precisionMode(); - - MNN_CONCURRENCY_BEGIN(tId, mThreadNum) { - int start = tId * mWorkDiv; - int realSize = ALIMIN(mWorkDiv, mTotalSize - start); - if (realSize > 0) { - auto inp0 = input0Ptr + start * inpBytes; - auto inp1 = input1Ptr + start * inpBytes; - if (mNeedBroadcastIndex == 0) { - inp0 = input0Ptr; - } else if (mNeedBroadcastIndex == 1) { - inp1 = input1Ptr; - } - auto out = outputPtr + start * outBytes; - mProc(out, inp0, inp1, realSize, mNeedBroadcastIndex); - if(mActivationType == 1 && output->getType().code == halide_type_int) { - for(int i=0; i 0 ? 
val : 0; - ((int32_t *)out)[i] = res; - } - } - } - } - MNN_CONCURRENCY_END(); - + mInput0Ptr = input0Ptr; + mInput1Ptr = input1Ptr; + mOutputPtr = outputPtr; + MNN_CONCURRENCY_ENQUEUE(mTask); if(mActivationType == 1 && output->getType().code == halide_type_float) { mActivationExe->onExecute(outputs, outputs); } diff --git a/source/backend/cpu/CPUBinary.hpp b/source/backend/cpu/CPUBinary.hpp index 9250df79ae..17cb3b5f47 100644 --- a/source/backend/cpu/CPUBinary.hpp +++ b/source/backend/cpu/CPUBinary.hpp @@ -33,6 +33,10 @@ class CPUBinary : public Execution { int mThreadNum; int mWorkDiv; std::shared_ptr mActivationExe; + std::pair, int> mTask; + uint8_t* mInput0Ptr = nullptr; + uint8_t* mOutputPtr = nullptr; + uint8_t* mInput1Ptr = nullptr; }; } // namespace MNN #endif /* CPUBinary_hpp */ diff --git a/source/backend/cpu/CPUMatMul.cpp b/source/backend/cpu/CPUMatMul.cpp index 4f0765f050..22b96a64ee 100644 --- a/source/backend/cpu/CPUMatMul.cpp +++ b/source/backend/cpu/CPUMatMul.cpp @@ -37,9 +37,8 @@ void CPUMatMul::_scheduleForVecE(int e, int l, int h) { param.BTranspose = mTransposeB; param.numberThread = numberThread; auto func = static_cast(backend())->functions()->MNNComputeMatMulForE_1; - mPreFunctions.emplace_back(std::make_pair([param, func]( - int tId, const float* A, const float* B, const float* biasPtr, float* C) { - func(A, B, C, biasPtr, ¶m, tId); + mPreFunctions.emplace_back(std::make_pair([param, func, this](int tId) { + func(mA, mB, mC, mBiasPtr, ¶m, tId); }, numberThread)); } @@ -54,9 +53,9 @@ void CPUMatMul::_scheduleForVec(int e, int l, int h) { auto func = static_cast(backend())->functions()->MNNComputeMatMulForH_1; // TODD: Support e = 1 MNN_ASSERT(h == 1); - mPreFunctions.emplace_back(std::make_pair([param, func]( - int tId, const float* A, const float* B, const float* biasPtr, float* C) { - func(A, B, C, biasPtr, ¶m, tId); + mPreFunctions.emplace_back(std::make_pair([param, func, this]( + int tId) { + func(mA, mB, mC, mBiasPtr, ¶m, tId); }, 
numberThread)); } @@ -100,8 +99,8 @@ ErrorCode CPUMatMul::onResize(const std::vector& inputs, const std::vec return OUT_OF_MEMORY; } - mPreFunctions.emplace_back(std::make_pair([BTPtrAlloc, l, h, this, core] (int tId, const float* APtr, const float* BPtr, const float* Bias, float* C) { - core->MNNPackForMatMul_B((float*)BTPtrAlloc.ptr(), BPtr, h, 1, l, mTransposeB); + mPreFunctions.emplace_back(std::make_pair([BTPtrAlloc, l, h, this, core] (int tId) { + core->MNNPackForMatMul_B((float*)BTPtrAlloc.ptr(), mB, h, 1, l, mTransposeB); } , 1)); bool useBias = false; MemChunk bdestAlloc; @@ -120,9 +119,9 @@ ErrorCode CPUMatMul::onResize(const std::vector& inputs, const std::vec } mTempBias = bdestAlloc; mPreFunctions.emplace_back(std::make_pair( - [biasLength, bdestAlloc, core](int tId, const float* APtr, const float* BPtr, const float* borigin, float* C) { + [biasLength, bdestAlloc, core, this](int tId) { ::memset(bdestAlloc.ptr(), 0, UP_DIV(biasLength, core->pack) * core->bytes * core->pack); - ::memcpy(bdestAlloc.ptr(), borigin, biasLength * core->bytes); + ::memcpy(bdestAlloc.ptr(), mBiasPtr, biasLength * core->bytes); }, 1)); } else { mUseBiasDirectly = true; @@ -167,11 +166,12 @@ ErrorCode CPUMatMul::onExecute(const std::vector& inputs, const std::ve } void CPUMatMul::execute(const float* APtr, const float* BPtr, float* CPtr, const float* biasPtr) { + mA = APtr; + mB = BPtr; + mC = CPtr; + mBiasPtr = biasPtr; for (auto& f : mPreFunctions) { - MNN_CONCURRENCY_BEGIN(tId, f.second) { - f.first(tId, APtr, BPtr, biasPtr, CPtr); - } - MNN_CONCURRENCY_END(); + MNN_CONCURRENCY_ENQUEUE(f); } if (mE > 0) { auto core = static_cast(backend())->functions(); diff --git a/source/backend/cpu/CPUMatMul.hpp b/source/backend/cpu/CPUMatMul.hpp index 872a77a9a8..48226795f0 100644 --- a/source/backend/cpu/CPUMatMul.hpp +++ b/source/backend/cpu/CPUMatMul.hpp @@ -29,7 +29,7 @@ class CPUMatMul : public Execution { bool mTransposeB; bool mTransposeC; bool mSupportMultiThread = false; - 
std::vector, int>> mPreFunctions; + std::vector, int>> mPreFunctions; bool mUseBiasDirectly = false; MemChunk mTempA; MemChunk mTempB; @@ -40,6 +40,11 @@ class CPUMatMul : public Execution { int mL; int mH; std::vector mPostParameters; + // For Execute Paramters + const float* mA = nullptr; + const float* mB = nullptr; + const float* mBiasPtr = nullptr; + float* mC = nullptr; }; } // namespace MNN diff --git a/source/backend/cpu/CPURNNSequenceGRU.cpp b/source/backend/cpu/CPURNNSequenceGRU.cpp index daae8811c7..0bda660e9c 100644 --- a/source/backend/cpu/CPURNNSequenceGRU.cpp +++ b/source/backend/cpu/CPURNNSequenceGRU.cpp @@ -10,30 +10,26 @@ #include #include "backend/cpu/CPUBackend.hpp" #include "backend/cpu/compute/ConvOpt.h" -#include "backend/cpu/compute/CommonOptFunction.h" #include "core/TensorUtils.hpp" namespace MNN { // implement GRU cell function // Ref: tensorflow/python/ops/rnn_cell_impl.py -void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, const bool linearBeforeReset, - std::shared_ptr& hiddenState, const int numUnits, Tensor* gateWeight, Tensor* gateBias, +void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, const bool linearBeforeReset, uint8_t* hiddenStateInput, const int numUnits, Tensor* gateWeight, Tensor* gateBias, Tensor* candidateWeight, Tensor* candidateBias, Tensor* recurrentBias, std::shared_ptr& inputAndState, std::shared_ptr& gate, - std::shared_ptr& resetHt) { - auto bn = static_cast(backend()); - auto mulFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_MUL); - auto addFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_ADD); - auto subFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_SUB); - auto tanhFunction = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_TANH, bn->precisionMode()); - auto bytes = bn->functions()->bytes; - auto sigmoidFunc = 
bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_SIGMOID, bn->precisionMode()); + std::shared_ptr& resetHt, uint8_t* hiddenStateOutput) { // gate is (z_t, r_t) + auto bytes = mRNNFunctions.bytes; + MNNBinaryExecute mulFunction = mRNNFunctions.mulFunction; + MNNBinaryExecute addFunction = mRNNFunctions.addFunction; + MNNBinaryExecute subFunction = mRNNFunctions.subFunction; + MNNUnaryExecute tanhFunction = mRNNFunctions.tanhFunction; + MNNUnaryExecute sigmoidFunction = mRNNFunctions.sigmoidFunction; auto inputAndStatePtr = inputAndState->host(); - auto hiddenStatePtr = hiddenState->host(); ::memcpy(inputAndStatePtr, input, inputLength * bytes); - ::memcpy(inputAndStatePtr + inputLength * bytes, hiddenStatePtr, numUnits * bytes); + ::memcpy(inputAndStatePtr + inputLength * bytes, hiddenStateInput, numUnits * bytes); inputAndState->setLength(1, inputLength + numUnits); // // [x_t, h_t-1] * [W_zr, R_zr]: (1, inputLength + numUnits) X (inputLength + numUnits, 2 * numUnits) @@ -42,9 +38,8 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, recurrentBias->setLength(1, 2 * numUnits); addFunction(gate->host(), gate->host(), recurrentBias->host(), 2*numUnits, -1); // (1, 2*numUnits) - const int gateSize = gate->elementSize(); auto gatePtr = gate->host(); - sigmoidFunc(gatePtr, gatePtr, gateSize); + sigmoidFunction(gatePtr, gatePtr, 2 * numUnits); // reset gate, // r_t is the second segment auto rtPtr = gatePtr + numUnits * bytes; @@ -52,7 +47,7 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, // calculate Rt (.) 
(Ht_1 * Rh + Rbh) auto recurrentHiddenBiasPtr = recurrentBias->host() + 2 * numUnits * bytes; auto rhWeightPtr = candidateWeight->host() + inputLength * numUnits * bytes; - mMatMulU2U->execute(hiddenState->host(), (float*)rhWeightPtr, resetHt->host(), (float*)recurrentHiddenBiasPtr); + mMatMulU2U->execute((float*)hiddenStateInput, (float*)rhWeightPtr, resetHt->host(), (float*)recurrentHiddenBiasPtr); mulFunction(resetHt->host(), rtPtr, resetHt->host(), numUnits, -1); // calculate Xt * Wh @@ -65,7 +60,7 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, // r_t: (1, numUnits) auto resetGatePtr = inputAndStatePtr + inputLength * bytes; // h_t1(1, numUnits) = r_t(1, numUnits) * h_t-1_(1, numUnits) - mulFunction(resetGatePtr, rtPtr, hiddenStatePtr, numUnits, -1); + mulFunction(resetGatePtr, rtPtr, hiddenStateInput, numUnits, -1); // deal with recurrent bias and linear_before_reset parameter auto recurrentBiasAddedPtr = inputAndStatePtr + (inputLength + numUnits) * bytes; auto recurrentHiddenBiasPtr = (float*)(recurrentBias->host() + 2 * numUnits * bytes); @@ -76,9 +71,9 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength, } // h = (1-g)*t+g*h = t + g*(h-t) tanhFunction(resetHt->host(), rtPtr, numUnits); - subFunction(hiddenStatePtr, hiddenStatePtr, resetHt->host(), numUnits, -1); - mulFunction(hiddenStatePtr, hiddenStatePtr, gatePtr, numUnits, -1); - addFunction(hiddenStatePtr, hiddenStatePtr, resetHt->host(), numUnits, -1); + subFunction(hiddenStateOutput, hiddenStateInput, resetHt->host(), numUnits, -1); + mulFunction(hiddenStateOutput, hiddenStateOutput, gatePtr, numUnits, -1); + addFunction(hiddenStateOutput, hiddenStateOutput, resetHt->host(), numUnits, -1); inputAndState->setLength(1, inputLength + 2 * numUnits); } @@ -143,6 +138,13 @@ ErrorCode CPURNNSequenceGRU::onResize(const std::vector& inputs, const backend()->onReleaseBuffer(mInputAndState.get(), Backend::DYNAMIC); 
backend()->onReleaseBuffer(mGate.get(), Backend::DYNAMIC); backend()->onReleaseBuffer(mResetHt.get(), Backend::DYNAMIC); + auto bn = static_cast(backend()); + mRNNFunctions.mulFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_MUL); + mRNNFunctions.addFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_ADD); + mRNNFunctions.subFunction = bn->functions()->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_SUB); + mRNNFunctions.tanhFunction = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_TANH, bn->precisionMode()); + mRNNFunctions.bytes = bn->functions()->bytes; + mRNNFunctions.sigmoidFunction = bn->functions()->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_SIGMOID, bn->precisionMode()); return NO_ERROR; } @@ -183,27 +185,29 @@ ErrorCode CPURNNSequenceGRU::onExecute(const std::vector& inputs, const const int inputCodeLength = input->length(2); // MNN_PRINT("inputSequenceLength:%d, batchSize:%d, inputCodeLength:%d, mNumUnits:%d, hiddenStateDataSize:%d\n", inputSequenceLength, batchSize, inputCodeLength, mNumUnits, hiddenStateDataSize); for (int b = 0; b < batchSize; ++b) { // swap order + auto hiddenStateInput = hiddenStatePtr; + auto hiddenStateOutput = hiddenStatePtr; if (inputSize > 1 + forwardParamNumber * (mIsBidirectionalRNN + 1)) { auto source = inputs[inputSize - 1]->host() + b * hiddenStateDataSize; - ::memcpy(hiddenStatePtr, source, hiddenStateDataSize); + hiddenStateInput = source; } else { ::memset(hiddenStatePtr, 0, hiddenStateDataSize); } for (int i = 0; i < inputSequenceLength; ++i) { const int inputOffset = i * SequenceStride + b * inputCodeLength; - runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, mHiddenState, mNumUnits, fwGateWeight, fwGateBias, - fwCandidateWeight, fwCandidateBias, fwRecurrentBias, mInputAndState, mGate, mResetHt); - if (mKeepAllOutputs) { - ::memcpy(outputPtr + (i * output->stride(0) + b * mNumUnits) * bytes, hiddenStatePtr, 
hiddenStateDataSize); + hiddenStateOutput = outputPtr + (i * output->stride(0) + b * mNumUnits) * bytes; } + runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, hiddenStateInput, mNumUnits, fwGateWeight, fwGateBias, + fwCandidateWeight, fwCandidateBias, fwRecurrentBias, mInputAndState, mGate, mResetHt, hiddenStateOutput); + + hiddenStateInput = hiddenStateOutput; } if ((mKeepAllOutputs && outputSize > 1) || !mKeepAllOutputs) { - ::memcpy(outputYhPtr, hiddenStatePtr, hiddenStateDataSize); + ::memcpy(outputYhPtr, hiddenStateOutput, hiddenStateDataSize); outputYhPtr += mNumUnits * bytes; } - } // backward rnn @@ -221,22 +225,24 @@ ErrorCode CPURNNSequenceGRU::onExecute(const std::vector& inputs, const auto outputBw = outputs[0]; auto const outputBwPtr = outputBw->host(); for (int b = 0; b < batchSize; ++b) { + auto hiddenStateInput = hiddenStatePtr; + auto hiddenStateOutput = hiddenStatePtr; if (inputSize > 1 + forwardParamNumber * 2) { auto source = inputs[inputSize - 1]->host() + (batchSize + b) * hiddenStateDataSize; - ::memcpy(hiddenStatePtr, source, hiddenStateDataSize); + hiddenStateInput = source; } else { ::memset(hiddenStatePtr, 0, hiddenStateDataSize); } for (int i = inputSequenceLength - 1; i >= 0; i--) { const int inputOffset = i * SequenceStride + b * inputCodeLength; - runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, mHiddenState, mNumUnits, bwGateWeight, bwGateBias, - bwCandidateWeight, bwCandidateBias, bwRecurrentBias, mInputAndState, mGate, mResetHt); if (mKeepAllOutputs) { - ::memcpy(outputBwPtr + (i * outputBw->stride(0) + (batchSize + b) * mNumUnits) * bytes, - hiddenStatePtr, hiddenStateDataSize); + hiddenStateOutput = outputBwPtr + (i * outputBw->stride(0) + (batchSize + b) * mNumUnits) * bytes; } + runRNNStep(inputPtr + inputOffset * bytes, inputCodeLength, mlinearBeforeReset, hiddenStateInput, mNumUnits, bwGateWeight, bwGateBias, + bwCandidateWeight, bwCandidateBias, bwRecurrentBias, 
mInputAndState, mGate, mResetHt, hiddenStateOutput); + hiddenStateInput = hiddenStateOutput; } if ((mKeepAllOutputs && outputSize > 1) || !mKeepAllOutputs) { ::memcpy(outputYhPtr, hiddenStatePtr, hiddenStateDataSize); diff --git a/source/backend/cpu/CPURNNSequenceGRU.hpp b/source/backend/cpu/CPURNNSequenceGRU.hpp index 0987d13053..0125b9e8a1 100644 --- a/source/backend/cpu/CPURNNSequenceGRU.hpp +++ b/source/backend/cpu/CPURNNSequenceGRU.hpp @@ -11,6 +11,7 @@ #include "core/Execution.hpp" #include "CPUMatMul.hpp" +#include "backend/cpu/compute/CommonOptFunction.h" namespace MNN { class CPURNNSequenceGRU : public Execution { @@ -19,13 +20,20 @@ class CPURNNSequenceGRU : public Execution { virtual ~CPURNNSequenceGRU(); virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; - + struct RNNFuntions { + MNNBinaryExecute mulFunction; + MNNBinaryExecute addFunction; + MNNBinaryExecute subFunction; + MNNUnaryExecute tanhFunction; + MNNUnaryExecute sigmoidFunction; + int bytes; + }; private: void runRNNStep(const uint8_t* input, const int inputLength, const bool linearBeforeReset, - std::shared_ptr& hiddenState, const int numUnits, Tensor* gateWeight, Tensor* gateBias, + uint8_t* hiddenStateInput, const int numUnits, Tensor* gateWeight, Tensor* gateBias, Tensor* candidateWeight, Tensor* candidateBias, Tensor* recurrentBias, std::shared_ptr& inputAndState, std::shared_ptr& gate, - std::shared_ptr& resetHt); + std::shared_ptr& resetHt, uint8_t* hiddenStateOutput); bool mKeepAllOutputs; bool mIsBidirectionalRNN; bool mlinearBeforeReset; @@ -42,6 +50,7 @@ class CPURNNSequenceGRU : public Execution { std::shared_ptr mMatMulU2U; // For inputLength -> numUnit std::shared_ptr mMatMulI2U; + RNNFuntions mRNNFunctions; }; } // namespace MNN diff --git a/source/backend/cpu/CPURaster.cpp b/source/backend/cpu/CPURaster.cpp index 3272086531..1339089347 100644 --- 
a/source/backend/cpu/CPURaster.cpp +++ b/source/backend/cpu/CPURaster.cpp @@ -49,227 +49,6 @@ struct ReduceInfo { } }; -ErrorCode CPURaster::onResize(const std::vector &____inputs, const std::vector &outputs) { - MNN_ASSERT(outputs.size() == 1); - auto output = outputs[0]; - OpCommonUtils::rasterInputReset(____inputs, outputs[0]); - auto des = TensorUtils::getDescribe(output); - auto outputDes = TensorUtils::getDescribe(output); - mNeedZero = !TensorUtils::regionIsFull(output); - mZeroPoint = 0; - mUseThreads = false; - if (outputDes->quantAttr != nullptr && outputDes->applyQuant) { -#ifdef MNN_USE_SSE - mZeroPoint = (int)outputDes->quantAttr->zero + 128; -#else - mZeroPoint = (int)outputDes->quantAttr->zero; -#endif - } - mTempInput.clear(); - mFastBlit.clear(); - mCacheRegions.clear(); - mTempOutput = nullptr; - auto midFormat = MNN_DATA_FORMAT_NCHW; - mTempInputCopy.clear(); - mFast = false; - auto core = static_cast(backend())->functions(); - mSingleConvert.type = 0; - // all_srcFormat == dstFormat == NC4HW4 : Fast Exe - if (outputDes->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) { - mFast = true; - for (int i=0; i< des->regions.size(); ++i) { - auto& slice = des->regions[i]; - if (TensorUtils::getDescribe(slice.origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { - mFast = false; - break; - } - if (!OpCommonUtils::canBlitFast(slice, output, core->pack, true)) { - mFast = false; - break; - } - } - if (mFast) { - mUseThreads = des->regions.size() > 16 ? 
true : false; - for (int i=0; i< des->regions.size(); ++i) { - auto& slice = des->regions[i]; - if (slice.origin == nullptr) { - continue; - } - Tensor::InsideDescribe::Region newRegion; - OpCommonUtils::turnToPackRegion(slice, newRegion, output, core->pack, true); - mFastBlit.emplace_back(std::make_pair(slice.origin, std::move(newRegion))); - } - return NO_ERROR; - } - } - // srcNum == 1 && srcFormat != dstFormat : Single Convert - if (des->regions.size() == 1) { - OpCommonUtils::turnRegion2Convert(des->regions[0], output, mSingleConvert); - if (mSingleConvert.type > 0) { - mUseThreads = (mSingleConvert.batch * mSingleConvert.channel * mSingleConvert.area > LAUNCH_MULTI_THREADS_WORKLOAD) ? true : false; - return NO_ERROR; - } - } - // Acquire Buffer for temp output - // TODO: optimize it - if (MNN_DATA_FORMAT_NC4HW4 == outputDes->dimensionFormat) { - mTempOutput.reset(new Tensor); - TensorUtils::setupTensorInfo(output, mTempOutput.get(), midFormat); - } - if (nullptr != mTempOutput) { - auto res = backend()->onAcquireBuffer(mTempOutput.get(), Backend::DYNAMIC); - if (!res) { - return OUT_OF_MEMORY; - } - } - // input is NC4HW4 add Convert - std::vector forRelease; - TensorUtils::FuseWrap fuseUtils; - for (int i=0; i< des->regions.size(); ++i) { - auto& slice = des->regions[i]; - auto origin = slice.origin; - if (nullptr == origin /*|| nullptr == origin->host()*/) { - continue; - } - // if tensor is not NC4HW4 or has been merged, don't need deal - if (TensorUtils::getDescribe(origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { - if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - mTempInputCopy.emplace_back(std::make_pair(origin, &slice)); - continue; - } - // if NC4HW4's C%4 == 0, change convert to transpose and fuse it - if (origin->batch() == 1 && origin->channel() % core->pack == 0) { - int channel = origin->channel(); - int area = 1; - // conv3d/pool3d will has 5 dims, area = depth * width * 
height, otherwise area = width * height - for (int d = 2; d < origin->dimensions(); d++) { - area *= origin->length(d); - } - Tensor::InsideDescribe::Region regionTmp; - regionTmp.src.offset = 0; - regionTmp.src.stride[0] = area * core->pack; - regionTmp.src.stride[1] = 1; - regionTmp.src.stride[2] = core->pack; - regionTmp.dst.offset = 0; - regionTmp.dst.stride[0] = area * core->pack; - regionTmp.dst.stride[1] = area; - regionTmp.dst.stride[2] = 1; - regionTmp.size[0] = channel / core->pack; - regionTmp.size[1] = core->pack; - regionTmp.size[2] = area; - regionTmp.origin = slice.origin; - bool merge = fuseUtils.match(regionTmp, slice); - if (merge) { - std::shared_ptr newSlice(new Tensor::InsideDescribe::Region); - *newSlice = slice; - fuseUtils.apply(regionTmp, *newSlice); - // cache the merged tensor - if (newSlice->size[0] * newSlice->size[1] * newSlice->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - mTempInputCopy.emplace_back(std::make_pair(origin, newSlice.get())); - mCacheRegions.emplace_back(newSlice); - continue; - } - } - auto cache = static_cast(backend())->getCache(); - auto tempTensor = cache->findCacheTensor(origin, midFormat); - //MNN_ASSERT(CPUBackend::getBytes(backend(), origin) == 4); - if (nullptr == tempTensor) { - std::shared_ptr newTensor(new Tensor); - TensorUtils::copyShape(origin, newTensor.get()); - TensorUtils::getDescribe(newTensor.get())->dimensionFormat = midFormat; - TensorUtils::getDescribe(newTensor.get())->quantAttr = TensorUtils::getDescribe(origin)->quantAttr; - TensorUtils::getDescribe(newTensor.get())->applyQuant = TensorUtils::getDescribe(origin)->applyQuant;; - newTensor->buffer().type = origin->getType(); - TensorUtils::setLinearLayout(newTensor.get()); - mTempInput.insert(std::make_pair(origin, newTensor.get())); - auto res = backend()->onAcquireBuffer(newTensor.get(), Backend::DYNAMIC); - if (!res) { - return OUT_OF_MEMORY; - } - tempTensor = newTensor.get(); - 
TensorUtils::getDescribe(tempTensor)->useCount = TensorUtils::getDescribe(origin)->useCount; - cache->pushCacheTensor(newTensor, origin, midFormat); - } - if (--TensorUtils::getDescribe(tempTensor)->useCount == 0) { - forRelease.emplace_back(tempTensor); - } - if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - mTempInputCopy.emplace_back(std::make_pair(tempTensor, &slice)); - } - for (auto t : forRelease) { - backend()->onReleaseBuffer(t, Backend::DYNAMIC); - } - if (nullptr != mTempOutput) { - backend()->onReleaseBuffer(mTempOutput.get(), Backend::DYNAMIC); - } - auto threadNumber = static_cast(backend())->threadNumber(); - mHasReduce = false; - ReduceInfo reduceInfo; - for (auto& iter : mTempInputCopy) { - if (reduceInfo.compute(*iter.second)) { - mHasReduce = true; - break; - } - } - if (mTempInputCopy.size() == 1 && threadNumber > 1 && (!mHasReduce)) { - // Split to multi region - auto region = mTempInputCopy[0].second; - if (region->size[0] * region->size[1] * region->size[2] < LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = false; - return NO_ERROR; - } - if (region->size[0] * region->size[1] * region->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { - mUseThreads = true; - } - auto tensorPtr = mTempInputCopy[0].first; - int pos = -1; - for (int i=0; i<3; ++i) { - if (region->size[i] > 1) { - pos = i; - break; - } - } - if (-1 == pos) { - // Don't need divide - return NO_ERROR; - } - mTempInputCopy.clear(); - int divSize = UP_DIV(region->size[pos], threadNumber); - for (int i=0; i cacheRegPtr(new Tensor::InsideDescribe::Region); - auto& cacheReg = *cacheRegPtr; - int sta = i * divSize; - int fin = sta + divSize; - fin = std::min(fin, region->size[pos]); - if (fin <= sta) { - break; - } - for (int v=0; v<3; ++v) { - cacheReg.src.stride[v] = region->src.stride[v]; - cacheReg.dst.stride[v] = region->dst.stride[v]; - } - int curSize = fin - sta; - for (int v=0; vsize[v]; - } - cacheReg.size[pos] = curSize; 
- cacheReg.src.offset = region->src.offset + sta * region->src.stride[pos]; - cacheReg.dst.offset = region->dst.offset + sta * region->dst.stride[pos]; - for (int v=pos+1; v<3; ++v) { - cacheReg.size[v] = region->size[v]; - } - mTempInputCopy.emplace_back(std::make_pair(tensorPtr, cacheRegPtr.get())); - mCacheRegions.emplace_back(cacheRegPtr); - } - } - return NO_ERROR; -} static void _transpose(int32_t* dstO, const int32_t* srcO, const Tensor::InsideDescribe::Region& region, int bytes) { int dims[4], keepDim = -1; for (int i = 0; i < 3; i++) { @@ -324,15 +103,12 @@ static void _2BitcopyWithStrideC4(uint8_t* dstO, const uint8_t* srcO, int size, } } -void CPURaster::executeFaster(const std::vector &inputs, const std::vector &outputs) const { +void CPURaster::executeFaster(const std::vector &inputs, const std::vector &outputs) { auto input = inputs[0]; auto output = outputs[0]; auto bytes = CPUBackend::getBytes(backend(), output); auto core = static_cast(backend())->functions(); - auto threadNum = static_cast(backend())->threadNumber(); - if (mNeedZero) { - ::memset(output->host(), mZeroPoint, static_cast(backend())->getTensorSize(output) * bytes); - } + int threadNum = static_cast(backend())->threadNumber(); auto byteC4 = bytes * core->pack; auto C4proc = core->MNN4BitcopyWithStride; switch (byteC4) { @@ -352,7 +128,7 @@ void CPURaster::executeFaster(const std::vector &inputs, const std::ve if (!mUseThreads) { threadNum = 1; } - MNN_CONCURRENCY_BEGIN(tId, threadNum) { + mTasks.emplace_back(std::make_pair([threadNum, this, output, bytes, C4proc, byteC4](int tId) { for (int u=(int)tId; uhost() == nullptr) { @@ -393,8 +169,7 @@ void CPURaster::executeFaster(const std::vector &inputs, const std::ve } } } - } - MNN_CONCURRENCY_END(); + }, threadNum)); } static BlitProc _selectUnitProc(int bytes, int stride, int ds) { @@ -596,97 +371,307 @@ static void _blit(const Tensor::InsideDescribe::Region& slice, int bytes, const } } void CPURaster::tensorConvert(Tensor* input, 
Tensor* output, int bytes) { - auto& subIb = input->buffer(); - auto& subOb = output->buffer(); - auto source = TensorUtils::getDescribe(input)->dimensionFormat; - auto dest = TensorUtils::getDescribe(output)->dimensionFormat; - if (subIb.dimensions <= 1 || source == dest) { - ::memcpy(subOb.host, subIb.host, input->elementSize() * bytes); - return; - } - auto tup = CPUTensorConverter::splitDimensions(subIb, source); - int area = std::get<1>(tup), batch = std::get<0>(tup), channel = std::get<2>(tup); - const int bitLength = bytes; + std::pair, int> task; auto core = static_cast(backend())->functions(); auto threadNumber = static_cast(backend())->threadNumber(); if (!mUseThreads) { threadNumber = 1; } - MNN_CONCURRENCY_BEGIN(tId, threadNumber) { + task.first = [input, output, bytes, threadNumber, core](int tId) { + auto& subIb = input->buffer(); + auto& subOb = output->buffer(); + auto source = TensorUtils::getDescribe(input)->dimensionFormat; + auto dest = TensorUtils::getDescribe(output)->dimensionFormat; + if (subIb.dimensions <= 1 || source == dest) { + ::memcpy(subOb.host, subIb.host, input->elementSize() * bytes); + return; + } + auto tup = CPUTensorConverter::splitDimensions(subIb, source); + int area = std::get<1>(tup), batch = std::get<0>(tup), channel = std::get<2>(tup); + const int bitLength = bytes; CPUTensorConverter::convert(subIb.host, subOb.host, source, dest, batch, area, channel, bitLength, core, tId, threadNumber); }; - MNN_CONCURRENCY_END(); + task.second = threadNumber; + mTasks.emplace_back(task); } - - -ErrorCode CPURaster::onExecute(const std::vector &____inputs, const std::vector &outputs) { - void* mOutputPtr = nullptr; - if (nullptr != mTempOutput) { - mOutputPtr = mTempOutput->host(); - } else { - mOutputPtr = outputs[0]->host(); - } - if (mFast) { - executeFaster(____inputs, outputs); - return NO_ERROR; - } - auto core = static_cast(backend())->functions(); +ErrorCode CPURaster::onResize(const std::vector &____inputs, const std::vector 
&outputs) { + MNN_ASSERT(outputs.size() == 1); auto output = outputs[0]; + OpCommonUtils::rasterInputReset(____inputs, outputs[0]); + auto des = TensorUtils::getDescribe(output); + auto outputDes = TensorUtils::getDescribe(output); + mNeedZero = !TensorUtils::regionIsFull(output); + mZeroPoint = 0; + mUseThreads = false; + int threadNum = static_cast(backend())->threadNumber(); + if (outputDes->quantAttr != nullptr && outputDes->applyQuant) { +#ifdef MNN_USE_SSE + mZeroPoint = (int)outputDes->quantAttr->zero + 128; +#else + mZeroPoint = (int)outputDes->quantAttr->zero; +#endif + } size_t bytes = (size_t)(CPUBackend::getBytes(backend(), output)); - auto outputEleSize = static_cast(backend())->getTensorSize(output); - auto threadNum = static_cast(backend())->threadNumber(); - if (mSingleConvert.type > 0) { - auto realInput = ____inputs[0]; - int srcBatch = mSingleConvert.batch, srcChannel = mSingleConvert.channel, srcArea = mSingleConvert.area; - auto sourceFormat = TensorUtils::getDescribe(realInput)->dimensionFormat; - auto destFormat = TensorUtils::getDescribe(output)->dimensionFormat; - auto channelC4 = UP_DIV(srcChannel, core->pack); - auto batchStrideC4 = channelC4 * core->pack * srcArea * bytes; - auto batchStride = srcChannel * srcArea * bytes; - auto inputBatchStride = batchStride; - auto outputBatchStride = batchStride; - if (MNN_DATA_FORMAT_NC4HW4 == sourceFormat) { - if (realInput->dimensions() <= 1) { - ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); - return NO_ERROR; + mTempInput.clear(); + mFastBlit.clear(); + mCacheRegions.clear(); + mTempOutput = nullptr; + mTasks.clear(); + auto midFormat = MNN_DATA_FORMAT_NCHW; + mTempInputCopy.clear(); + mFast = false; + auto core = static_cast(backend())->functions(); + mSingleConvert.type = 0; + // all_srcFormat == dstFormat == NC4HW4 : Fast Exe + if (outputDes->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) { + mFast = true; + for (int i=0; i< des->regions.size(); ++i) { + auto& 
slice = des->regions[i]; + if (TensorUtils::getDescribe(slice.origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { + mFast = false; + break; } - inputBatchStride = batchStrideC4; - if (2 == mSingleConvert.type) { - destFormat = MNN_DATA_FORMAT_NHWC; - } else { - destFormat = MNN_DATA_FORMAT_NCHW; + if (!OpCommonUtils::canBlitFast(slice, output, core->pack, true)) { + mFast = false; + break; } - } else if (MNN_DATA_FORMAT_NC4HW4 == destFormat) { - if (output->dimensions() <= 1) { - ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); - return NO_ERROR; + } + if (mFast) { + mUseThreads = des->regions.size() > 16 ? true : false; + for (int i=0; i< des->regions.size(); ++i) { + auto& slice = des->regions[i]; + if (slice.origin == nullptr) { + continue; + } + Tensor::InsideDescribe::Region newRegion; + OpCommonUtils::turnToPackRegion(slice, newRegion, output, core->pack, true); + mFastBlit.emplace_back(std::make_pair(slice.origin, std::move(newRegion))); } - outputBatchStride = batchStrideC4; - if (2 == mSingleConvert.type) { - sourceFormat = MNN_DATA_FORMAT_NHWC; - } else { - sourceFormat = MNN_DATA_FORMAT_NCHW; + executeFaster(____inputs, outputs); + return NO_ERROR; + } + } + // srcNum == 1 && srcFormat != dstFormat : Single Convert + if (des->regions.size() == 1) { + OpCommonUtils::turnRegion2Convert(des->regions[0], output, mSingleConvert); + if (mSingleConvert.type > 0) { + std::pair, int> task; + mUseThreads = (mSingleConvert.batch * mSingleConvert.channel * mSingleConvert.area > LAUNCH_MULTI_THREADS_WORKLOAD) ? 
true : false; + auto realInput = ____inputs[0]; + int srcBatch = mSingleConvert.batch, srcChannel = mSingleConvert.channel, srcArea = mSingleConvert.area; + auto sourceFormat = TensorUtils::getDescribe(realInput)->dimensionFormat; + auto destFormat = TensorUtils::getDescribe(output)->dimensionFormat; + auto channelC4 = UP_DIV(srcChannel, core->pack); + auto batchStrideC4 = channelC4 * core->pack * srcArea * bytes; + auto batchStride = srcChannel * srcArea * bytes; + auto inputBatchStride = batchStride; + auto outputBatchStride = batchStride; + if (MNN_DATA_FORMAT_NC4HW4 == sourceFormat) { + if (realInput->dimensions() <= 1) { + task.first = [output, realInput, bytes](int tId) { + ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); + }; + task.second = 1; + mTasks.emplace_back(task); + return NO_ERROR; + } + inputBatchStride = batchStrideC4; + if (2 == mSingleConvert.type) { + destFormat = MNN_DATA_FORMAT_NHWC; + } else { + destFormat = MNN_DATA_FORMAT_NCHW; + } + } else if (MNN_DATA_FORMAT_NC4HW4 == destFormat) { + if (output->dimensions() <= 1) { + task.first = [output, realInput, bytes](int tId) { + ::memcpy(output->host(), realInput->host(), realInput->elementSize() * bytes); + }; + task.second = 1; + mTasks.emplace_back(task); + return NO_ERROR; + } + outputBatchStride = batchStrideC4; + if (2 == mSingleConvert.type) { + sourceFormat = MNN_DATA_FORMAT_NHWC; + } else { + sourceFormat = MNN_DATA_FORMAT_NCHW; + } } + if (!mUseThreads) { + threadNum = 1; + } + task.first = [realInput, output, sourceFormat, destFormat, srcBatch, srcArea, srcChannel, bytes, core, threadNum](int tId) { + CPUTensorConverter::convert(realInput->host(), output->host(), sourceFormat, destFormat, srcBatch, srcArea, srcChannel, bytes, core, tId, threadNum); + }; + task.second = threadNum; + mTasks.emplace_back(task); + return NO_ERROR; } - if (!mUseThreads) { - threadNum = 1; + } + // Acquire Buffer for temp output + // TODO: optimize it + if 
(MNN_DATA_FORMAT_NC4HW4 == outputDes->dimensionFormat) { + mTempOutput.reset(new Tensor); + TensorUtils::setupTensorInfo(output, mTempOutput.get(), midFormat); + } + if (nullptr != mTempOutput) { + auto res = backend()->onAcquireBuffer(mTempOutput.get(), Backend::DYNAMIC); + if (!res) { + return OUT_OF_MEMORY; } - MNN_CONCURRENCY_BEGIN(tId, threadNum) { - CPUTensorConverter::convert(realInput->host(), output->host(), sourceFormat, destFormat, srcBatch, srcArea, srcChannel, bytes, core, tId, threadNum); - }; - MNN_CONCURRENCY_END(); - return NO_ERROR; } - if (mNeedZero) { - if (mTempOutput == nullptr) { - ::memset(output->host(), mZeroPoint, outputEleSize * bytes); - } else { - ::memset(mTempOutput->host(), mZeroPoint, mTempOutput->elementSize() * bytes); + // input is NC4HW4 add Convert + std::vector forRelease; + TensorUtils::FuseWrap fuseUtils; + for (int i=0; i< des->regions.size(); ++i) { + auto& slice = des->regions[i]; + auto origin = slice.origin; + if (nullptr == origin /*|| nullptr == origin->host()*/) { + continue; + } + // if tensor is not NC4HW4 or has been merged, don't need deal + if (TensorUtils::getDescribe(origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { + if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + mTempInputCopy.emplace_back(std::make_pair(origin, &slice)); + continue; } + // if NC4HW4's C%4 == 0, change convert to transpose and fuse it + if (origin->batch() == 1 && origin->channel() % core->pack == 0) { + int channel = origin->channel(); + int area = 1; + // conv3d/pool3d will has 5 dims, area = depth * width * height, otherwise area = width * height + for (int d = 2; d < origin->dimensions(); d++) { + area *= origin->length(d); + } + Tensor::InsideDescribe::Region regionTmp; + regionTmp.src.offset = 0; + regionTmp.src.stride[0] = area * core->pack; + regionTmp.src.stride[1] = 1; + regionTmp.src.stride[2] = core->pack; + regionTmp.dst.offset = 0; + regionTmp.dst.stride[0] 
= area * core->pack; + regionTmp.dst.stride[1] = area; + regionTmp.dst.stride[2] = 1; + regionTmp.size[0] = channel / core->pack; + regionTmp.size[1] = core->pack; + regionTmp.size[2] = area; + regionTmp.origin = slice.origin; + bool merge = fuseUtils.match(regionTmp, slice); + if (merge) { + std::shared_ptr newSlice(new Tensor::InsideDescribe::Region); + *newSlice = slice; + fuseUtils.apply(regionTmp, *newSlice); + // cache the merged tensor + if (newSlice->size[0] * newSlice->size[1] * newSlice->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + mTempInputCopy.emplace_back(std::make_pair(origin, newSlice.get())); + mCacheRegions.emplace_back(newSlice); + continue; + } + } + auto cache = static_cast(backend())->getCache(); + auto tempTensor = cache->findCacheTensor(origin, midFormat); + //MNN_ASSERT(CPUBackend::getBytes(backend(), origin) == 4); + if (nullptr == tempTensor) { + std::shared_ptr newTensor(new Tensor); + TensorUtils::copyShape(origin, newTensor.get()); + TensorUtils::getDescribe(newTensor.get())->dimensionFormat = midFormat; + TensorUtils::getDescribe(newTensor.get())->quantAttr = TensorUtils::getDescribe(origin)->quantAttr; + TensorUtils::getDescribe(newTensor.get())->applyQuant = TensorUtils::getDescribe(origin)->applyQuant;; + newTensor->buffer().type = origin->getType(); + TensorUtils::setLinearLayout(newTensor.get()); + mTempInput.insert(std::make_pair(origin, newTensor.get())); + auto res = backend()->onAcquireBuffer(newTensor.get(), Backend::DYNAMIC); + if (!res) { + return OUT_OF_MEMORY; + } + tempTensor = newTensor.get(); + TensorUtils::getDescribe(tempTensor)->useCount = TensorUtils::getDescribe(origin)->useCount; + cache->pushCacheTensor(newTensor, origin, midFormat); + } + if (--TensorUtils::getDescribe(tempTensor)->useCount == 0) { + forRelease.emplace_back(tempTensor); + } + if (slice.size[0] * slice.size[1] * slice.size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + 
mTempInputCopy.emplace_back(std::make_pair(tempTensor, &slice)); + } + for (auto t : forRelease) { + backend()->onReleaseBuffer(t, Backend::DYNAMIC); + } + if (nullptr != mTempOutput) { + backend()->onReleaseBuffer(mTempOutput.get(), Backend::DYNAMIC); } + auto threadNumber = static_cast(backend())->threadNumber(); + mHasReduce = false; + ReduceInfo reduceInfo; + for (auto& iter : mTempInputCopy) { + if (reduceInfo.compute(*iter.second)) { + mHasReduce = true; + break; + } + } + // Encode convert for (auto& iter : mTempInput) { tensorConvert(iter.first, iter.second, (int)bytes); } + do { + if (mTempInputCopy.size() == 1 && threadNumber > 1 && (!mHasReduce)) { + // Split to multi region + auto region = mTempInputCopy[0].second; + if (region->size[0] * region->size[1] * region->size[2] < LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = false; + break; + } + if (region->size[0] * region->size[1] * region->size[2] > LAUNCH_MULTI_THREADS_WORKLOAD) { + mUseThreads = true; + } + auto tensorPtr = mTempInputCopy[0].first; + int pos = -1; + for (int i=0; i<3; ++i) { + if (region->size[i] > 1) { + pos = i; + break; + } + } + if (-1 == pos) { + // Don't need divide + break; + } + mTempInputCopy.clear(); + int divSize = UP_DIV(region->size[pos], threadNumber); + for (int i=0; i cacheRegPtr(new Tensor::InsideDescribe::Region); + auto& cacheReg = *cacheRegPtr; + int sta = i * divSize; + int fin = sta + divSize; + fin = std::min(fin, region->size[pos]); + if (fin <= sta) { + break; + } + for (int v=0; v<3; ++v) { + cacheReg.src.stride[v] = region->src.stride[v]; + cacheReg.dst.stride[v] = region->dst.stride[v]; + } + int curSize = fin - sta; + for (int v=0; vsize[v]; + } + cacheReg.size[pos] = curSize; + cacheReg.src.offset = region->src.offset + sta * region->src.stride[pos]; + cacheReg.dst.offset = region->dst.offset + sta * region->dst.stride[pos]; + for (int v=pos+1; v<3; ++v) { + cacheReg.size[v] = region->size[v]; + } + mTempInputCopy.emplace_back(std::make_pair(tensorPtr, 
cacheRegPtr.get())); + mCacheRegions.emplace_back(cacheRegPtr); + } + } + } while (false); if (mHasReduce) { // Don't support reduce with multi thread now threadNum = 1; @@ -700,8 +685,13 @@ ErrorCode CPURaster::onExecute(const std::vector &____inputs, const st if (outputDescribe->overlap) { threadNum = 1; } - - MNN_CONCURRENCY_BEGIN(tId, threadNum) { + mTasks.emplace_back(std::make_pair([this, threadNum, output, bytes, core](int tId){ + void* mOutputPtr = nullptr; + if (nullptr != mTempOutput) { + mOutputPtr = mTempOutput->host(); + } else { + mOutputPtr = output->host(); + } for (int u=tId; u &____inputs, const st auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes; _blit(slice, (int)bytes, srcPtr, dstPtr, mHasReduce, core->MNNLowpToFp32, core->MNNFp32ToLowp); } - } - MNN_CONCURRENCY_END(); + }, threadNum)); if (nullptr != mTempOutput) { tensorConvert(mTempOutput.get(), output, (int)bytes); } return NO_ERROR; } + + +ErrorCode CPURaster::onExecute(const std::vector &____inputs, const std::vector &outputs) { + void* mOutputPtr = nullptr; + if (nullptr != mTempOutput) { + mOutputPtr = mTempOutput->host(); + } else { + mOutputPtr = outputs[0]->host(); + } + auto core = static_cast(backend())->functions(); + auto output = outputs[0]; + size_t bytes = (size_t)(CPUBackend::getBytes(backend(), output)); + auto outputEleSize = static_cast(backend())->getTensorSize(output); + auto threadNum = static_cast(backend())->threadNumber(); + if (mNeedZero) { + if (mTempOutput == nullptr) { + ::memset(output->host(), mZeroPoint, outputEleSize * bytes); + } else { + ::memset(mTempOutput->host(), mZeroPoint, mTempOutput->elementSize() * bytes); + } + } + for (auto& task : mTasks) { + MNN_CONCURRENCY_ENQUEUE(task); + } + return NO_ERROR; +} class CPULoop : public Execution { public: struct ThreadContainer { @@ -1066,7 +1081,15 @@ class CPULoop : public Execution { auto stride2 = cmd->view()->GetAs(2)->stride()->data(); auto blit1 = _selectUnitProc(bytes, stride1[2], 1); auto 
blit2 = _selectUnitProc(bytes, stride2[2], 1); - if (cmd->size()->data()[2] == 1 || (stride1[2] == 1 && stride2[2] == 1)) { + if (cmd->size()->data()[2] == 1 || (stride1[2] <= 1 && stride2[2] <= 1 && (stride1[2] + stride1[1] != 0))) { + // Support elementwise or one src broadcast + int needBroadcastIndex = -1; + if (0 == stride1[2]) { + needBroadcastIndex = 0; + } + if (0 == stride2[2]) { + needBroadcastIndex = 1; + } for (int z=0; zsize()->data()[0]; ++z) { auto src0Z = src0 + z * stride1[0] * bytes; auto src1Z = src1 + z * stride2[0] * bytes; @@ -1075,7 +1098,7 @@ class CPULoop : public Execution { auto src0Y = src0Z + y * stride1[1] * bytes; auto src1Y = src1Z + y * stride2[1] * bytes; auto dstY = dstZ + y * stride0[1] * bytes; - proc(dstY, src0Y, src1Y, cmd->size()->data()[2], -1); + proc(dstY, src0Y, src1Y, cmd->size()->data()[2], needBroadcastIndex); } } } else { diff --git a/source/backend/cpu/CPURaster.hpp b/source/backend/cpu/CPURaster.hpp index 9df10700bd..bff149df52 100644 --- a/source/backend/cpu/CPURaster.hpp +++ b/source/backend/cpu/CPURaster.hpp @@ -24,7 +24,7 @@ class CPURaster : public Execution { virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; - void executeFaster(const std::vector &inputs, const std::vector &outputs) const; + void executeFaster(const std::vector &inputs, const std::vector &outputs); void tensorConvert(Tensor* input, Tensor* output, int bytes); private: std::map mTempInput; @@ -38,6 +38,7 @@ class CPURaster : public Execution { int32_t mZeroPoint = 0; bool mHasReduce = false; bool mUseThreads = false; + std::vector, int>> mTasks; }; } #endif diff --git a/source/backend/cpu/ThreadPool.cpp b/source/backend/cpu/ThreadPool.cpp index 15a2d8241c..d7765c4fbc 100644 --- a/source/backend/cpu/ThreadPool.cpp +++ b/source/backend/cpu/ThreadPool.cpp @@ -60,7 +60,7 @@ ThreadPool::ThreadPool(int numberThread) { 
while (mActiveCount > 0) { for (int i = 0; i < MNN_THREAD_POOL_MAX_TASKS; ++i) { if (*mTasks[i].second[threadIndex]) { - mTasks[i].first.first(threadIndex); + mTasks[i].first->first(threadIndex); { *mTasks[i].second[threadIndex] = false; } } } @@ -118,16 +118,18 @@ void ThreadPool::deactive() { mActiveCount--; } -void ThreadPool::enqueue(TASK&& task, int index) { +void ThreadPool::enqueue(TASK* taskp, int index) { + auto& task = *taskp; if (1 >= task.second || 0 > index) { for (int i = 0; i < task.second; ++i) { task.first(i); } return; } - enqueueInternal(std::move(task), index); + enqueueInternal(taskp, index); } -void ThreadPool::enqueueInternal(TASK&& task, int index) { +void ThreadPool::enqueueInternal(TASK* taskp, int index) { + auto& task = *taskp; if (mActiveCount == 0) { for (int i = 0; i < task.second; ++i) { task.first(i); @@ -135,24 +137,25 @@ void ThreadPool::enqueueInternal(TASK&& task, int index) { return; } int workSize = task.second; + TASK* tmpTask = nullptr; if (workSize > mNumberThread) { - mTasks[index].first = std::make_pair( - [workSize, &task, this](int tId) { - for (int v = tId; v < workSize; v += mNumberThread) { - task.first(v); - } - }, - mNumberThread); + tmpTask = new TASK; + *tmpTask = std::make_pair([workSize, &task, this](int tId) { + for (int v = tId; v < workSize; v += mNumberThread) { + task.first(v); + } + }, mNumberThread); + mTasks[index].first = tmpTask; workSize = mNumberThread; } else { - mTasks[index].first = std::move(task); + mTasks[index].first = taskp; } { for (int i = 1; i < workSize; ++i) { *mTasks[index].second[i] = true; } } - mTasks[index].first.first(0); + mTasks[index].first->first(0); bool complete = true; do { complete = true; @@ -165,6 +168,9 @@ void ThreadPool::enqueueInternal(TASK&& task, int index) { std::this_thread::yield(); // FUNC_PRINT(notComplete); } while (!complete); + if (nullptr != tmpTask) { + delete tmpTask; + } } } // namespace MNN #endif diff --git a/source/backend/cpu/ThreadPool.hpp 
b/source/backend/cpu/ThreadPool.hpp index 4bf23de1b0..8891da61b1 100644 --- a/source/backend/cpu/ThreadPool.hpp +++ b/source/backend/cpu/ThreadPool.hpp @@ -25,7 +25,7 @@ class MNN_PUBLIC ThreadPool { int numberThread() const { return mNumberThread; } - void enqueue(TASK&& task, int index); + void enqueue(TASK* task, int index); void active(); void deactive(); @@ -37,7 +37,7 @@ class MNN_PUBLIC ThreadPool { static void destroy(); private: - void enqueueInternal(TASK&& task, int index); + void enqueueInternal(TASK* task, int index); ThreadPool(int numberThread = 0); ~ThreadPool(); @@ -46,7 +46,7 @@ class MNN_PUBLIC ThreadPool { std::vector mTaskAvailable; std::atomic mStop = {false}; - std::vector>> mTasks; + std::vector>> mTasks; std::condition_variable mCondition; std::mutex mQueueMutex; diff --git a/source/backend/cpu/compute/CommonOptFunction.cpp b/source/backend/cpu/compute/CommonOptFunction.cpp index d7d0d7fb34..c9bfcc2189 100644 --- a/source/backend/cpu/compute/CommonOptFunction.cpp +++ b/source/backend/cpu/compute/CommonOptFunction.cpp @@ -3882,12 +3882,13 @@ void MNNVectorTop1Int32(int32_t* input, int32_t* maxValue, int32_t* maxIndex, si #endif -void MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId) { +void MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tIdL) { auto l = param->l; auto h = param->h; auto numberThread = param->numberThread; auto lC4 = l / 4; auto lR = lC4 * 4; + auto tId = (int)tIdL; if (param->BTranspose) { for (int y=tId; y= 8) { + if (0 == tId) { + auto bs = B + hEnd; + Vec4 sumValue0; + Vec4 sumValue1; + if (biasPtr != nullptr) { + sumValue0 = Vec4::load(biasPtr + hEnd + 0); + sumValue1 = Vec4::load(biasPtr + hEnd + 4); + } else { + sumValue0 = Vec4(0.0f); + sumValue1 = Vec4(0.0f); + } + auto srcY = A + hEnd * l; + for (int x=0; x= 4) { + if (0 == tId) { + auto bs = B + hEnd; + Vec4 sumValue0; 
+ if (biasPtr != nullptr) { + sumValue0 = Vec4::load(biasPtr + hEnd + 0); + } else { + sumValue0 = Vec4(0.0f); + } + auto srcY = A + hEnd * l; + for (int x=0; x= 8) { + sumValue = Vec::fma(sumValue, Vec4::load(srcY + lR), Vec4::load(B + lR)); + sum1 = Vec::fma(sum1, Vec4::load(srcY + lR + 4), Vec4::load(B + lR + 4)); + lR += 8; + } + if (l - lR >= 4) { + sumValue = Vec::fma(sumValue, Vec4::load(srcY + lR), Vec4::load(B + lR)); + lR += 4; + } + sum2 = sum2 + sum3; + sumValue = sumValue + sum1; + sumValue = sumValue + sum2; float sumSingle = sumValue[0] + sumValue[1] + sumValue[2] + sumValue[3]; for (int x=lR; xenqueue(task) + #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) \ { \ std::pair, int> task; \ @@ -28,8 +33,7 @@ } \ ; \ auto cpuBn = (CPUBackend*)backend(); \ - auto thrPl = cpuBn->threadPool(); \ - thrPl->enqueue(std::move(task), cpuBn->taskIndex()); \ + cpuBn->enqueue(task); \ } #else @@ -38,6 +42,9 @@ #include #include +#define MNN_CONCURRENCY_ENQUEUE(task) \ +dispatch_apply(task.second, dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH, 0), ^(size_t __iter__) {task.first(__iter__);}); + #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) \ dispatch_apply(__num__, dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH, 0), ^(size_t __iter__) { #define MNN_CONCURRENCY_END() \ @@ -58,6 +65,8 @@ dispatch_apply(__num__, dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH, // Android #else #include +#define MNN_CONCURRENCY_ENQUEUE(task) \ +_Pragma("omp parallel for") for (int __iter__ = 0; __iter__ < task.second; __iter__++) {task.first(__iter__);} #define MNN_STRINGIFY(a) #a #define MNN_CONCURRENCY_BEGIN(__iter__, __num__) \ diff --git a/source/core/OpCommonUtils.cpp b/source/core/OpCommonUtils.cpp index c80afaef87..a69263ffaa 100644 --- a/source/core/OpCommonUtils.cpp +++ b/source/core/OpCommonUtils.cpp @@ -386,98 +386,7 @@ void OpCommonUtils::broastCastComputeDim(int* dims, int* stride, int* iStride0, } } } -std::vector> 
OpCommonUtils::computeReduceDims(const std::vector& inputs, - const Op* op) { - // Compute axises - std::vector axises; - if (inputs.size() >= 2) { - auto size = inputs[1]->elementSize(); - auto dims = inputs[1]->host(); - for (int i = 0; i < size; ++i) { - axises.emplace_back(dims[i]); - } - } else { - auto reduct = op->main_as_ReductionParam(); - if (nullptr != reduct->dim()) { - for (int i = 0; i < reduct->dim()->size(); ++i) { - axises.emplace_back(reduct->dim()->data()[i]); - } - } - } - auto totalSize = TensorUtils::getRawSize(inputs[0]); - if (axises.empty()) { - return {std::make_tuple(1, totalSize, 1)}; - } - for (int i = 0; i < axises.size(); ++i) { - if (axises[i] < 0) { - axises[i] = inputs[0]->dimensions() + axises[i]; - if (axises[i] < 0) { - return {std::make_tuple(1, totalSize, 1)}; - } - } - } - // Cache for input's dims - std::vector lengths(inputs[0]->dimensions()); - for (int i = 0; i < lengths.size(); ++i) { - lengths[i] = inputs[0]->length(i); - } - std::vector> groupAxises; - { - // Merge adj axis - std::sort(axises.begin(), axises.end()); - int lastAxis = axises[0]; - int length = 1; - int start = axises[0]; - for (int i = 1; i < axises.size(); ++i) { - // MNN_PRINT("%d - %d\n", axises[i], lastAxis); - if (axises[i] - lastAxis == 1) { - length++; - } else { - groupAxises.emplace_back(std::make_pair(start, length)); - length = 1; - start = axises[i]; - } - lastAxis = axises[i]; - } - groupAxises.emplace_back(std::make_pair(start, length)); - } - - // Compute inside-outside-axis - std::vector> result; - for (int i = 0; i < groupAxises.size(); ++i) { - int outsideSize = 1; - int insideSize = 1; - int axisSize = 1; - auto start = groupAxises[i].first; - auto length = groupAxises[i].second; - if (start >= (int)lengths.size()) { - break; - } - for (int j = 0; j < start; ++j) { - outsideSize *= lengths[j]; - } - for (int j = start; j < start + length; ++j) { - if (j >= (int)lengths.size()) { - break; - } - axisSize *= lengths[j]; - lengths[j] = 1; 
- } - for (int j = start + length; j < lengths.size(); ++j) { - insideSize *= lengths[j]; - } - if (1 == axisSize) { - continue; - } - result.emplace_back(std::make_tuple(outsideSize, axisSize, insideSize)); - } - // FUNC_PRINT(result.size()); - if (result.empty()) { - result.emplace_back(std::make_tuple(1, 1, totalSize)); - } - return result; -} void OpCommonUtils::unravelIndexHelper(int32_t* coordinate, const int32_t* mod, int size, int indice) { int value = indice; diff --git a/source/core/OpCommonUtils.hpp b/source/core/OpCommonUtils.hpp index 0740cc16b2..8ec0628336 100644 --- a/source/core/OpCommonUtils.hpp +++ b/source/core/OpCommonUtils.hpp @@ -56,7 +56,6 @@ class MNN_PUBLIC OpCommonUtils { static bool supportDynamicInputMemory(MNNForwardType type); static void broastCastComputeDim(int* dims, int* stride, int* iStride0, int* iStride1, const Tensor* input0, const Tensor* input1, const Tensor* output); - static std::vector> computeReduceDims(const std::vector& inputs, const Op* op); static void unravelIndexHelper(int32_t* coordinate, const int32_t* mod, int size, int indice); static int computeStride(int32_t* strides, const int* shape, int length); diff --git a/source/core/TensorUtils.cpp b/source/core/TensorUtils.cpp index ae5b87143c..d233fc9d89 100644 --- a/source/core/TensorUtils.cpp +++ b/source/core/TensorUtils.cpp @@ -32,6 +32,18 @@ bool TensorUtils::regionIsFull(Tensor* input) { return regionSize == size; } +void TensorUtils::makeFullRef(Tensor* output, Tensor* input) { + auto des = TensorUtils::getDescribe(input); + auto outputDes = TensorUtils::getDescribe(output); + outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; + if (des->memoryType == Tensor::InsideDescribe::MEMORY_VIRTUAL) { + outputDes->regions = des->regions; + } else { + outputDes->regions = {makeFullSlice(input)}; + } +} + + Tensor::InsideDescribe::Region TensorUtils::makeFullSlice(Tensor* input) { Tensor::InsideDescribe::Region totalSlice; totalSlice.src.offset = 0; diff 
--git a/source/core/TensorUtils.hpp b/source/core/TensorUtils.hpp index 1342a669bd..a577fea05f 100644 --- a/source/core/TensorUtils.hpp +++ b/source/core/TensorUtils.hpp @@ -184,6 +184,7 @@ class MNN_PUBLIC TensorUtils { static void setupTensorInfo(const Tensor* tensor, Tensor* wrapTensor, MNN_DATA_FORMAT mMidFormat); static Tensor::InsideDescribe::Region makeFullSlice(Tensor* input); + static void makeFullRef(Tensor* output, Tensor* input); static bool regionIsFull(Tensor* input); static bool isCopyRegion(const Tensor::InsideDescribe::Region& region); static bool isTransposeRegion(const Tensor::InsideDescribe::Region& region); diff --git a/source/geometry/GeometryComputerUtils.cpp b/source/geometry/GeometryComputerUtils.cpp index 01a4e02ea2..85f64de55d 100644 --- a/source/geometry/GeometryComputerUtils.cpp +++ b/source/geometry/GeometryComputerUtils.cpp @@ -477,9 +477,9 @@ std::shared_ptr GeometryComputerUtils::makeBinary(int type, Tensor* inp return cmdP; } -std::shared_ptr GeometryComputerUtils::makeReduce(ReductionType type, Tensor* input0, Tensor* output) { +std::shared_ptr GeometryComputerUtils::makeReduce(ReductionType type, Tensor* input0, Tensor* output, int axis) { flatbuffers::FlatBufferBuilder builder(DEFAULT_ALLOCATE_SIZE); - auto vec = builder.CreateVector(std::vector{1}); + auto vec = builder.CreateVector(std::vector{axis}); ReductionParamBuilder builder_(builder); builder_.add_operation(type); builder_.add_keepDims(true); diff --git a/source/geometry/GeometryComputerUtils.hpp b/source/geometry/GeometryComputerUtils.hpp index c0dffdcdb1..97c4d5811f 100644 --- a/source/geometry/GeometryComputerUtils.hpp +++ b/source/geometry/GeometryComputerUtils.hpp @@ -18,7 +18,7 @@ class GeometryComputerUtils { static void addConvert(const CommandBuffer& srcBuffer, CommandBuffer& dstBuffer, GeometryComputer::Context& ctx); static std::shared_ptr makeCommand(flatbuffers::FlatBufferBuilder& builder, const std::vector& inputs, const std::vector& outputs); static 
std::shared_ptr makeBinary(int type, Tensor* input0, Tensor* input1, Tensor* output); - static std::shared_ptr makeReduce(ReductionType type, Tensor* input0, Tensor* output); + static std::shared_ptr makeReduce(ReductionType type, Tensor* input0, Tensor* output, int axis = 1); static std::shared_ptr makeUnary(UnaryOpOperation type, Tensor* input0, Tensor* output); static std::shared_ptr makeLayerNorm(Tensor* input0, Tensor* output, std::vector axis, float epsilon, std::vector gamma, std::vector beta, std::vector external, int group = 1, bool useRMS = false); static std::shared_ptr makeMatMul(Tensor* input0, Tensor* input1, Tensor* output, Tensor* Bias = nullptr, diff --git a/source/geometry/GeometryReduce.cpp b/source/geometry/GeometryReduce.cpp index c2a3bb4114..855f4bcf69 100644 --- a/source/geometry/GeometryReduce.cpp +++ b/source/geometry/GeometryReduce.cpp @@ -10,6 +10,83 @@ #include "geometry/GeometryComputerUtils.hpp" #include "core/OpCommonUtils.hpp" namespace MNN { +static std::vector> _computeReduceDims(const std::vector& inputs, + std::vector& axises) { + + auto totalSize = TensorUtils::getRawSize(inputs[0]); + if (axises.empty()) { + return {std::make_tuple(1, totalSize, 1)}; + } + for (int i = 0; i < axises.size(); ++i) { + if (axises[i] < 0) { + if (axises[i] < 0) { + return {std::make_tuple(1, totalSize, 1)}; + } + } + } + // Cache for input's dims + std::vector lengths(inputs[0]->dimensions()); + for (int i = 0; i < lengths.size(); ++i) { + lengths[i] = inputs[0]->length(i); + } + std::vector> groupAxises; + { + // Merge adj axis + std::sort(axises.begin(), axises.end()); + int lastAxis = axises[0]; + int length = 1; + int start = axises[0]; + for (int i = 1; i < axises.size(); ++i) { + // MNN_PRINT("%d - %d\n", axises[i], lastAxis); + if (axises[i] - lastAxis == 1) { + length++; + } else { + groupAxises.emplace_back(std::make_pair(start, length)); + length = 1; + start = axises[i]; + } + lastAxis = axises[i]; + } + 
groupAxises.emplace_back(std::make_pair(start, length)); + } + + // Compute inside-outside-axis + std::vector> result; + + for (int i = 0; i < groupAxises.size(); ++i) { + int outsideSize = 1; + int insideSize = 1; + int axisSize = 1; + auto start = groupAxises[i].first; + auto length = groupAxises[i].second; + if (start >= (int)lengths.size()) { + break; + } + for (int j = 0; j < start; ++j) { + outsideSize *= lengths[j]; + } + for (int j = start; j < start + length; ++j) { + if (j >= (int)lengths.size()) { + break; + } + axisSize *= lengths[j]; + lengths[j] = 1; + } + for (int j = start + length; j < lengths.size(); ++j) { + insideSize *= lengths[j]; + } + if (1 == axisSize) { + continue; + } + result.emplace_back(std::make_tuple(outsideSize, axisSize, insideSize)); + } + // FUNC_PRINT(result.size()); + if (result.empty()) { + result.emplace_back(std::make_tuple(1, 1, totalSize)); + } + return result; +} + class GeometryReduce : public GeometryComputer { public: virtual bool onCompute(const Op* op, const std::vector& inputs, const std::vector& outputs, @@ -18,6 +95,31 @@ class GeometryReduce : public GeometryComputer { MNN_ASSERT(inputs.size() >= 1); auto reduct = op->main_as_ReductionParam(); auto reductOp = reduct->operation(); + std::vector axises; + if (inputs.size() >= 2) { + auto size = inputs[1]->elementSize(); + auto dims = inputs[1]->host(); + for (int i = 0; i < size; ++i) { + axises.emplace_back(dims[i]); + } + } else { + auto reduct = op->main_as_ReductionParam(); + if (nullptr != reduct->dim()) { + for (int i = 0; i < reduct->dim()->size(); ++i) { + axises.emplace_back(reduct->dim()->data()[i]); + } + } + } + for (int i = 0; i < axises.size(); ++i) { + if (axises[i] < 0) { + axises[i] = inputs[0]->dimensions() + axises[i]; + } + } + if (1 == axises.size() && TensorUtils::getDescribe(inputs[0])->dimensionFormat != MNN_DATA_FORMAT_NC4HW4 && TensorUtils::getDescribe(outputs[0])->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { + auto cmd = 
GeometryComputerUtils::makeReduce(reductOp, inputs[0], outputs[0], axises[0]); + res.command.emplace_back(std::move(cmd)); + return true; + } // prod([]) = 1 if (inputs[0]->elementSize() == 0) { if(!context.allocTensor(outputs[0])) { @@ -39,7 +141,7 @@ class GeometryReduce : public GeometryComputer { } return true; } - auto reduceDims = OpCommonUtils::computeReduceDims(inputs, op); + auto reduceDims = _computeReduceDims(inputs, axises); Tensor* currentInput = inputs[0]; MNN_ASSERT(reduceDims.size() > 0); auto dimType = currentInput->getDimensionType(); diff --git a/source/geometry/GeometryReshape.cpp b/source/geometry/GeometryReshape.cpp index 88d98a24c9..1df3384e37 100644 --- a/source/geometry/GeometryReshape.cpp +++ b/source/geometry/GeometryReshape.cpp @@ -42,8 +42,7 @@ class GeometryReshape : public GeometryComputer { return true; } } - outputDes->regions = {TensorUtils::makeFullSlice(input)}; - outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; + TensorUtils::makeFullRef(output, input); return true; } }; @@ -75,10 +74,7 @@ class SingleGeometryComputer : public GeometryComputer { Context& context, CommandBuffer& res) const override { auto input = inputs[0]; auto output = outputs[0]; - auto inputDes = TensorUtils::getDescribe(input); - auto outputDes = TensorUtils::getDescribe(output); - outputDes->regions = {TensorUtils::makeFullSlice(input)}; - outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; + TensorUtils::makeFullRef(output, input); return true; } }; @@ -94,8 +90,7 @@ class CopyGeometryComputer : public GeometryComputer { outputDes->tensorArrayAttr = inputDes->tensorArrayAttr; return true; } - outputDes->regions = {TensorUtils::makeFullSlice(input)}; - outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL; + TensorUtils::makeFullRef(output, input); } return true; } diff --git a/source/math/Vec.hpp b/source/math/Vec.hpp index 6839ab83b0..cc9354a7f1 100644 --- a/source/math/Vec.hpp +++ b/source/math/Vec.hpp @@ -372,8 +372,7 
@@ struct Vec { using VecType = Vec; using VecTypeInt32 = Vec; float32x4_t value; - Vec() { - } + Vec() = default; Vec(const float v) { value = vdupq_n_f32(v); } diff --git a/test/core/ThreadPoolTest.cpp b/test/core/ThreadPoolTest.cpp index 6886f86e62..e010939e5f 100644 --- a/test/core/ThreadPoolTest.cpp +++ b/test/core/ThreadPoolTest.cpp @@ -26,11 +26,11 @@ class ThreadPoolTest : public MNNTestCase { auto workIndex = threadPool->acquireWorkIndex(); FUNC_PRINT(workIndex); threadPool->active(); - auto func = [](int index) { + ThreadPool::TASK task = std::make_pair([](int index) { FUNC_PRINT(index); std::this_thread::yield(); - }; - threadPool->enqueue(std::make_pair(std::move(func), 10), workIndex); + }, 10); + threadPool->enqueue(&task, workIndex); threadPool->deactive(); threadPool->releaseWorkIndex(workIndex); }); diff --git a/tools/cpp/ExprDebug.hpp b/tools/cpp/ExprDebug.hpp index 167e97c562..49e3db6156 100644 --- a/tools/cpp/ExprDebug.hpp +++ b/tools/cpp/ExprDebug.hpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #define DUMP_NUM_DATA(type) \ @@ -135,29 +136,69 @@ static void _initDebug() { struct TimeTraceInfo { - std::map>>> mTypes; + std::map>> mTypes; void begin(const MNN::OperatorInfo* info) { auto tIter = mTypes.find(info->type()); if (tIter == mTypes.end()) { - std::map>> _t; + std::map> _t; mTypes.insert(std::make_pair(info->type(), _t)); tIter = mTypes.find(info->type()); } mInserIter = tIter->second.find(info->name()); if (mInserIter == tIter->second.end()) { - std::vector> _t; - tIter->second.insert(std::make_pair(info->name(), _t)); + tIter->second.insert(std::make_pair(info->name(), std::make_tuple(0.0f, 0.0f, 0))); mInserIter = tIter->second.find(info->name()); } mTimer.reset(); } void end(const MNN::OperatorInfo* info) { auto timeInMs = (float)mTimer.durationInUs() / 1000.0f; - mInserIter->second.emplace_back(std::make_pair(timeInMs, info->flops())); + std::get<0>(mInserIter->second) += timeInMs; + 
std::get<1>(mInserIter->second) += info->flops(); + std::get<2>(mInserIter->second) ++; + } + void dump(bool dumpPerOp = false) { + if (dumpPerOp) { + auto cmp = [](const std::tuple& first, const std::tuple& second) { + return std::get<1>(first) > std::get<1>(second); + }; + std::priority_queue, std::vector>, decltype(cmp)> que(cmp); + for (auto& iter : mTypes) { + for (auto& t : iter.second) { + auto mergeType = t.first + " ["+iter.first +"]"; + auto unit = std::make_tuple(mergeType, std::get<0>(t.second), std::get<1>(t.second), std::get<2>(t.second)); + que.push(unit); + } + } + while (!que.empty()) { + auto& t = que.top(); + MNN_PRINT("%s : %.7f ms, FLOP: %.7f, COUNT: %d, Speed: %.7f GFlops\n", std::get<0>(t).c_str(), std::get<1>(t), std::get<2>(t), std::get<3>(t), std::get<2>(t) / std::get<1>(t)); + que.pop(); + } + return; + } + float opSummer = 0.0f; + float opFlopsSummber = 0.0f; + for (auto& iter : mTypes) { + float summer = 0.0f; + float summerflops = 0.0f; + int count = 0; + for (auto& t : iter.second) { + summer += std::get<0>(t.second); + summerflops += std::get<1>(t.second); + count += std::get<2>(t.second); + } + MNN_PRINT("%s : %.7f ms, FLOP: %.7f, COUNT: %d, Speed: %.7f GFlops\n", iter.first.c_str(), summer, summerflops, count, + summerflops / summer); + opSummer += summer; + opFlopsSummber += summerflops; + } + MNN_PRINT("OP Summer: %.7f ms, Flops: %.7f, Speed: %.7f GFlops\n", opSummer, opFlopsSummber, + opFlopsSummber / opSummer); } private: - std::map>>::iterator mInserIter; + std::map>::iterator mInserIter; MNN::Timer mTimer; }; static TimeTraceInfo* gTimeTraceInfo = nullptr; diff --git a/tools/cpp/ModuleBasic.cpp b/tools/cpp/ModuleBasic.cpp index 90fa6b80d3..5798bc6d26 100644 --- a/tools/cpp/ModuleBasic.cpp +++ b/tools/cpp/ModuleBasic.cpp @@ -499,10 +499,13 @@ int main(int argc, char *argv[]) { if (runTime > 0) { int t = runTime; - std::vector times(t, 0.0f); if (runMask & 4) { _initTimeTrace(); } + float minTime = std::numeric_limits::max(); + 
float maxTime = 0.0f; + float sum = 0.0f; + for (int i = 0; i < t; ++i) { Timer _l; auto out = net->onForward(inputs); @@ -510,41 +513,28 @@ int main(int argc, char *argv[]) { for (auto o : out) { ((MNN::Tensor*)o->getTensor())->wait(MNN::Tensor::MAP_TENSOR_READ, true); } - times[i] = _l.durationInUs() / 1000.0f; + auto time = _l.durationInUs() / 1000.0f; if (freq > 0.0f) { - float remainMs = (1000.0f / freq) - times[i]; + float remainMs = (1000.0f / freq) - time; if (remainMs > 0.0f) { std::this_thread::sleep_for(std::chrono::milliseconds((int)remainMs)); } } - } - if (nullptr != gTimeTraceInfo) { - float opSummer = 0.0f; - float opFlopsSummber = 0.0f; - for (auto& iter : gTimeTraceInfo->mTypes) { - float summer = 0.0f; - float summerflops = 0.0f; - for (auto& t : iter.second) { - for (auto& t0 : t.second) { - summer += t0.first; - summerflops += t0.second; - } - } - summer = summer / (float)t; - summerflops = summerflops / (float)t; - MNN_PRINT("%s : %.7f, FLOP: %.7f, Speed: %.7f GFlops\n", iter.first.c_str(), summer, summerflops, summerflops / summer); - opSummer += summer; - opFlopsSummber+= summerflops; + if (maxTime < time) { + maxTime = time; + } + if (minTime > time) { + minTime = time; } - MNN_PRINT("OP Summer: %.7f, Flops: %.7f, Speed: %.7f GFlops\n", opSummer, opFlopsSummber, opFlopsSummber/opSummer); - } - auto minTime = std::min_element(times.begin(), times.end()); - auto maxTime = std::max_element(times.begin(), times.end()); - float sum = 0.0f; - for (auto time : times) { sum += time; } - MNN_PRINT("Avg= %f ms, min= %f ms, max= %f ms\n", sum / (float)t, *minTime, *maxTime); + if (nullptr != gTimeTraceInfo) { + MNN_PRINT("Per Op Trace: \n"); + gTimeTraceInfo->dump(true); + MNN_PRINT("Per Type Trace: \n"); + gTimeTraceInfo->dump(false); + } + MNN_PRINT("Avg= %f ms, min= %f ms, max= %f ms\n", sum / (float)t, minTime, maxTime); } rtmgr->updateCache(); return 0; diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp index 
53af11239a..63c590e0fd 100644 --- a/transformers/llm/engine/src/llm.cpp +++ b/transformers/llm/engine/src/llm.cpp @@ -915,26 +915,7 @@ Llm::Llm(std::shared_ptr config) : mConfig(config) { Llm::~Llm() { #if DEBUG_MODE == 1 if (nullptr != gTimeTraceInfo) { - float opSummer = 0.0f; - float opFlopsSummber = 0.0f; - for (auto& iter : gTimeTraceInfo->mTypes) { - float summer = 0.0f; - float summerflops = 0.0f; - for (auto& t : iter.second) { - for (auto& t0 : t.second) { - summer += t0.first; - summerflops += t0.second; - } - } - summer = summer; - summerflops = summerflops; - MNN_PRINT("%s : %.7f, FLOP: %.7f, Speed: %.7f GFlops\n", iter.first.c_str(), summer, summerflops, - summerflops / summer); - opSummer += summer; - opFlopsSummber += summerflops; - } - MNN_PRINT("OP Summer: %.7f, Flops: %.7f, Speed: %.7f GFlops\n", opSummer, opFlopsSummber, - opFlopsSummber / opSummer); + gTimeTraceInfo->dump(); } #endif mGenerateParam.reset(); From afd359a702ecabfab675dee7b0504526d9cc4322 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Tue, 23 Dec 2025 12:36:55 +0800 Subject: [PATCH 091/314] Merge branch feature/metal_backgroup_issue into master Title: [Metal Feature] check UI Status for metal command commit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 本次代码评审主要增加了对执行状态的检查和错误处理,并引入了新的日志打印方式以提高调试和监控能力。 Link: https://code.alibaba-inc.com/AliNN/AliNNPrivate/codereview/24965986 GitOrigin-RevId: b7ad051c324c1b7d4aa231fc062f2f5d8e7f7a0f --- express/Expr.cpp | 16 +++++++++ express/Utils.cpp | 17 +++++++++ express/module/Module.cpp | 26 ++++++++++++++ source/backend/metal/MetalBackend.mm | 33 +++++++++++++++++ source/core/Backend.hpp | 1 + source/core/Tensor.cpp | 8 +++++ transformers/llm/engine/demo/llm_demo.cpp | 36 +++++++++---------- transformers/llm/engine/include/llm/llm.hpp | 9 +++++ transformers/llm/engine/src/llm.cpp | 18 +++++++++- .../engine/src/speculative_decoding/eagle.cpp | 16 +++++++++ .../src/speculative_decoding/generate.cpp | 
13 ++++++- .../src/speculative_decoding/lookahead.cpp | 12 +++++++ .../engine/src/speculative_decoding/mtp.cpp | 12 +++++++ 13 files changed, 197 insertions(+), 20 deletions(-) diff --git a/express/Expr.cpp b/express/Expr.cpp index a735adbe3e..d9244e8de4 100644 --- a/express/Expr.cpp +++ b/express/Expr.cpp @@ -813,12 +813,28 @@ void* Variable::readInternal(bool forShape) { // The Varp will not be created as input, so we just need copy once return inside->mHostTensor->host(); } + inside->mHostTensor = new Tensor; TensorUtils::copyShape(originTensor, inside->mHostTensor, true); inside->mHostTensor->buffer().type = originTensor->getType(); inside->mHostTensor->buffer().host = (uint8_t*)MNNMemoryAllocAlign(inside->mHostTensor->size(), MNN_MEMORY_ALIGN_DEFAULT); TensorUtils::getDescribe(inside->mHostTensor)->memoryType = Tensor::InsideDescribe::MEMORY_HOST; originTensor->copyToHostTensor(inside->mHostTensor); + bool hasNoExecution = false; + if (nullptr != originTensor) { + auto backend = TensorUtils::getDescribeOrigin(originTensor)->getBackend(); + if (nullptr != backend) { + // Try to sync to check execution status + int syncResult = backend->onSync(Tensor::MAP_TENSOR_READ, false, originTensor); + if (NO_EXECUTION == syncResult) { + hasNoExecution = true; + } + } + } + if (hasNoExecution) { + MNN_PRINT("\nWarning, Backend has stop execute, return nullptr for current varp\n"); + return nullptr; + } return inside->mHostTensor->host(); } return originTensor->buffer().host; diff --git a/express/Utils.cpp b/express/Utils.cpp index f71f2c997b..6aac549ece 100644 --- a/express/Utils.cpp +++ b/express/Utils.cpp @@ -181,6 +181,23 @@ void* Executor::ComputeCache::mapOutput(int offset, Tensor* dest) { if (0 == tensor->usize()) { return nullptr; } + + bool hasNoExecution = false; + if (nullptr != tensor) { + auto backend = TensorUtils::getDescribeOrigin(tensor)->getBackend(); + if (nullptr != backend) { + // Try to sync to check execution status + int syncResult = 
backend->onSync(Tensor::MAP_TENSOR_READ, false, tensor); + if (NO_EXECUTION == syncResult) { + hasNoExecution = true; + } + } + } + if (hasNoExecution) { + MNN_PRINT("\nWarning, Backend has stop execute, return nullptr for current varp\n"); + return nullptr; + } + Utils::allocMemoryForHostTensor(dest); if(nullptr != dest->host()) { tensor->copyToHostTensor(dest); diff --git a/express/module/Module.cpp b/express/module/Module.cpp index f8b4728153..3c54ca89c3 100644 --- a/express/module/Module.cpp +++ b/express/module/Module.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include "core/OpCommonUtils.hpp" #include "PipelineModule.hpp" #include "core/FileLoader.hpp" @@ -17,6 +18,7 @@ #include "Utils.hpp" #include "RuntimeAttr.hpp" #include "ModuleInside.hpp" +#include "core/TensorUtils.hpp" #include #ifdef MNN_INTERNAL_ENABLED #include "internal/auth/ModelAuth.hpp" @@ -221,6 +223,30 @@ class NetModule : public Module { Executor::RuntimeExecuteWrap wrap(mInfo->runTimeManager->getInside()->mRuntime); outputs = mModule->onForward(inputs); } + + // Check execution status after forward + if (!outputs.empty()) { + bool hasNoExecution = false; + for (auto& v : outputs) { + auto t = Utils::getTensor(v); + if (nullptr != t) { + auto backend = TensorUtils::getDescribeOrigin(t)->getBackend(); + if (nullptr != backend) { + // Try to sync to check execution status + int syncResult = backend->onSync(Tensor::MAP_TENSOR_READ, false, t); + if (NO_EXECUTION == syncResult) { + hasNoExecution = true; + break; + } + } + } + } + if (hasNoExecution) { + MNN_PRINT("Warning, Backend has stop execute, return empty output vector varps\n"); + outputs.clear(); + } + } + #ifdef MNN_INTERNAL_ENABLED do { if (outputs.empty()) { diff --git a/source/backend/metal/MetalBackend.mm b/source/backend/metal/MetalBackend.mm index 79f52ff2dc..885808fc44 100644 --- a/source/backend/metal/MetalBackend.mm +++ b/source/backend/metal/MetalBackend.mm @@ -15,6 +15,7 @@ #define MTLGPUFamilyMetal3_MNN 5001 
#define MTLGPUFamilyMetal4_MNN 5002 +#define CHECK_IOS_UI_STATUS #if MNN_METAL_ENABLED #include #import "backend/metal/MNNMetalContext.h" @@ -22,6 +23,9 @@ #import "core/TensorUtils.hpp" #include "MetalCache_generated.h" #include "core/MNNFileUtils.h" +#if defined(TARGET_OS_IPHONE) && TARGET_OS_IPHONE +#import +#endif int MNNMetalGetTensorContent(MNNMetalTensorContent* content, void* tensor) { if (nullptr == content || nullptr == tensor) { return 0; @@ -776,6 +780,9 @@ static void _execute(id encoder, const MetalBackend::C MNN_ASSERT(false); // should not be handled here } int MetalBackend::onSync(Tensor::MapType mtype, bool toCpu, const Tensor* dstTensor) { + if (mRuntime->pExecutionStatus == NO_EXECUTION) { + return NO_EXECUTION; + } flushEncoder(); auto ctx = (__bridge MNNMetalContext *)context(); commit_net(); @@ -824,6 +831,19 @@ static void _execute(id encoder, const MetalBackend::C void MetalBackend::commit() const { +#ifdef CHECK_IOS_UI_STATUS +#if defined(TARGET_OS_IPHONE) && TARGET_OS_IPHONE + if ([UIApplication sharedApplication].applicationState == UIApplicationStateBackground || [UIApplication sharedApplication].applicationState == UIApplicationStateInactive) { + mRuntime->pExecutionStatus = NO_EXECUTION; + _commandBuffer = nil; + if (!mSupportDeferEncode) { + _commandBuffer_net = nil; + } + return; + } +#endif +#endif + mRuntime->pExecutionStatus = NO_ERROR; if (nil != _commandBuffer && _commandBuffer.status < MTLCommandBufferStatusCommitted) { [_commandBuffer commit]; mRuntime->_waiting = _commandBuffer; @@ -836,6 +856,19 @@ static void _execute(id encoder, const MetalBackend::C } void MetalBackend::commit_net() const { +#ifdef CHECK_IOS_UI_STATUS +#if defined(TARGET_OS_IPHONE) && TARGET_OS_IPHONE + if ([UIApplication sharedApplication].applicationState == UIApplicationStateBackground || [UIApplication sharedApplication].applicationState == UIApplicationStateInactive) { + mRuntime->pExecutionStatus = NO_EXECUTION; + _commandBuffer_net = nil; + if 
(!mSupportDeferEncode) { + _commandBuffer = nil; + } + return; + } +#endif +#endif + mRuntime->pExecutionStatus = NO_ERROR; if (nil != _commandBuffer_net && _commandBuffer_net.status < MTLCommandBufferStatusCommitted) { [_commandBuffer_net commit]; mRuntime->_waiting = _commandBuffer_net; diff --git a/source/core/Backend.hpp b/source/core/Backend.hpp index 6850b6b4f6..e463f251f7 100644 --- a/source/core/Backend.hpp +++ b/source/core/Backend.hpp @@ -395,6 +395,7 @@ class Runtime : public NonCopyable { } mutable int pCurrentStatus = 0; // NO_ERROR + mutable int pExecutionStatus = 0; // NO_ERROR // TODO: Move to Backend void* pMeta = nullptr; diff --git a/source/core/Tensor.cpp b/source/core/Tensor.cpp index 18bf5ec7a6..664fa6b790 100644 --- a/source/core/Tensor.cpp +++ b/source/core/Tensor.cpp @@ -430,6 +430,14 @@ void* Tensor::map(MapType mtype, DimensionType dtype) { return mBuffer.host; } + if (mtype == Tensor::MAP_TENSOR_READ) { + int syncResult = bn->onSync(mtype, false, this); + if (NO_EXECUTION == syncResult) { + MNN_PRINT("Warning, Backend has stop execute, return nullptr for tensor map addr\n"); + return nullptr; + } + } + auto mapPtr = bn->onMapTensor(mtype, dtype, this); if(mapPtr != nullptr) { // Get mapPtr in specific backend diff --git a/transformers/llm/engine/demo/llm_demo.cpp b/transformers/llm/engine/demo/llm_demo.cpp index 305ef2169b..ec0f39c146 100644 --- a/transformers/llm/engine/demo/llm_demo.cpp +++ b/transformers/llm/engine/demo/llm_demo.cpp @@ -135,21 +135,21 @@ static int benchmark(Llm* llm, const std::vector& prompts, int max_ if (context->audio_input_s > 0.0f) { audio_speed = context->audio_input_s / audio_s; } - printf("\n#################################\n"); - printf("prompt tokens num = %d\n", prompt_len); - printf("decode tokens num = %d\n", decode_len); - printf(" vision time = %.2f s\n", vision_s); - printf(" pixels_mp = %.2f MP\n", context->pixels_mp); - printf(" audio process time = %.2f s\n", audio_s); - printf(" audio input time 
= %.2f s\n", context->audio_input_s); - printf("prefill time = %.2f s\n", prefill_s); - printf(" decode time = %.2f s\n", decode_s); - printf(" sample time = %.2f s\n", sample_s); - printf("prefill speed = %.2f tok/s\n", prompt_len / prefill_s); - printf(" decode speed = %.2f tok/s\n", decode_len / decode_s); - printf(" vision speed = %.3f MP/s\n", vision_speed); - printf(" audio RTF = %.3f \n", audio_s / context->audio_input_s); - printf("##################################\n"); + MNN_PRINT("\n#################################\n"); + MNN_PRINT("prompt tokens num = %d\n", prompt_len); + MNN_PRINT("decode tokens num = %d\n", decode_len); + MNN_PRINT(" vision time = %.2f s\n", vision_s); + MNN_PRINT(" pixels_mp = %.2f MP\n", context->pixels_mp); + MNN_PRINT(" audio process time = %.2f s\n", audio_s); + MNN_PRINT(" audio input time = %.2f s\n", context->audio_input_s); + MNN_PRINT("prefill time = %.2f s\n", prefill_s); + MNN_PRINT(" decode time = %.2f s\n", decode_s); + MNN_PRINT(" sample time = %.2f s\n", sample_s); + MNN_PRINT("prefill speed = %.2f tok/s\n", prompt_len / prefill_s); + MNN_PRINT(" decode speed = %.2f tok/s\n", decode_len / decode_s); + MNN_PRINT(" vision speed = %.3f MP/s\n", vision_speed); + MNN_PRINT(" audio RTF = %.3f \n", audio_s / context->audio_input_s); + MNN_PRINT("##################################\n"); return 0; } @@ -165,12 +165,12 @@ static int ceval(Llm* llm, const std::vector& lines, std::string fi prompt += "\nC. " + elements[4]; prompt += "\nD. 
" + elements[5]; prompt += "\n\n"; - printf("%s", prompt.c_str()); - printf("## 进度: %d / %lu\n", i, lines.size() - 1); + MNN_PRINT("%s", prompt.c_str()); + MNN_PRINT("## 进度: %d / %lu\n", i, lines.size() - 1); std::ostringstream lineOs; llm->response(prompt.c_str(), &lineOs); auto line = lineOs.str(); - printf("%s", line.c_str()); + MNN_PRINT("%s", line.c_str()); answers.push_back(line); } { diff --git a/transformers/llm/engine/include/llm/llm.hpp b/transformers/llm/engine/include/llm/llm.hpp index 6ae61a5e35..20eff94be9 100644 --- a/transformers/llm/engine/include/llm/llm.hpp +++ b/transformers/llm/engine/include/llm/llm.hpp @@ -59,6 +59,13 @@ enum TuneType { // op encoder number for commit OP_ENCODER_NUMBER = 0, }; +enum class LlmStatus { + RUNNING = 0, + NORMAL_FINISHED = 1, + MAX_TOKENS_FINISHED = 2, + USER_CANCEL = 3, + INTERNAL_ERROR = 4, +}; enum class MatchStrictLevel : int; enum class NgramSelectRule : int; @@ -84,6 +91,8 @@ struct LlmContext { std::vector history_tokens; std::vector output_tokens; std::string generate_str; + // llm status + LlmStatus status; }; struct GenerationParams; class MNN_PUBLIC Llm { diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp index 63c590e0fd..c0cabd4414 100644 --- a/transformers/llm/engine/src/llm.cpp +++ b/transformers/llm/engine/src/llm.cpp @@ -467,6 +467,7 @@ std::vector Llm::forwardRaw(Express::VARP hiddenState, Express::V std::vector outputs = selectModule->onForward(inputs); if (outputs.empty()) { + mContext->status = LlmStatus::INTERNAL_ERROR; return outputs; } if (!mAsync) { @@ -592,6 +593,9 @@ std::vector Llm::forwardVec(MNN::Express::VARP input_embeds) { auto attention_mask = gen_attention_mask(blockSize); auto position_ids = gen_position_ids(blockSize); logits = forwardRaw(embed, attention_mask, position_ids); + if(logits.empty()) { + return logits; + } updateContext(blockSize, 0); } bool hasPad = false; @@ -623,6 +627,9 @@ std::vector Llm::forwardVec(MNN::Express::VARP 
input_embeds) { auto attention_mask = gen_attention_mask(forwardSize); auto position_ids = gen_position_ids(forwardSize); logits = forwardRaw(input_embeds, attention_mask, position_ids); + if(logits.empty()) { + return logits; + } } updateContext(-blockSize * blockNumber, 0); if (hasPad) { @@ -676,6 +683,7 @@ void Llm::generate_init(std::ostream* os, const char* end_with) { mContext->decode_us = 0; mContext->current_token = -1; mContext->sample_us = 0; + mContext->status = LlmStatus::RUNNING; if (!mConfig->reuse_kv()) { mContext->all_seq_len = 0; mContext->history_tokens.clear(); @@ -824,6 +832,7 @@ std::vector Llm::generate(MNN::Express::VARP input_embeds, int max_tokens) Timer _t; forwardVec(input_embeds); if(mGenerateParam->outputs.size() < 1) { + mContext->status = LlmStatus::INTERNAL_ERROR; return {}; } updateContext(seqLen, 0); @@ -1132,7 +1141,14 @@ VARP Llm::gen_position_ids(int seq_len) { } bool Llm::is_stop(int token_id) { - return mTokenizer->is_stop(token_id); + if (mContext->status == LlmStatus::USER_CANCEL || mContext->status == LlmStatus::INTERNAL_ERROR) { + return true; + } + bool stop = mTokenizer->is_stop(token_id); + if (stop) { + mContext->status = LlmStatus::NORMAL_FINISHED; + } + return stop; } } // namespace Transformer } // namespace MNN diff --git a/transformers/llm/engine/src/speculative_decoding/eagle.cpp b/transformers/llm/engine/src/speculative_decoding/eagle.cpp index b4c892fd97..15548c64d3 100644 --- a/transformers/llm/engine/src/speculative_decoding/eagle.cpp +++ b/transformers/llm/engine/src/speculative_decoding/eagle.cpp @@ -328,9 +328,22 @@ void EagleGeneration::generate(GenerationParams& param) { std::vector accpetLens; auto newTokens = 0, steps = 0; while (true) { + if(mContext->status == LlmStatus::USER_CANCEL) { + break; + } steps++; MNN::Timer _dt; auto decodingInfo = treeDecoding(draftInfo); + for (auto o : decodingInfo) { + if(nullptr == o->readMap()) { + mContext->status = LlmStatus::INTERNAL_ERROR; + break; + } + } + 
if(decodingInfo.empty()) { + break; + } + treeDecodingTime += _dt.durationInUs(); auto acceptInfo = evaluatePosterior(draftInfo, decodingInfo[0]); newTokens += acceptInfo.acceptTokens.size(); @@ -352,6 +365,9 @@ void EagleGeneration::generate(GenerationParams& param) { eagleGenerateTime += _gt.durationInUs(); } mContext->decode_us += _t.durationInUs(); + if(newTokens >= param.max_new_tokens) { + mContext->status = LlmStatus::MAX_TOKENS_FINISHED; + } #if EAGLE_DEBUG printf("\n### Tree Decoding Time: %f s, Eagle Generate Time: %f s\n", (float)treeDecodingTime / 1000000.0, (float)eagleGenerateTime / 1000000.0); printf("\n### Tree Decoding Avg Time: %f ms, steps: %d\n", (float)treeDecodingTime / 1000.0 / steps, steps); diff --git a/transformers/llm/engine/src/speculative_decoding/generate.cpp b/transformers/llm/engine/src/speculative_decoding/generate.cpp index 31d3a3b9f7..4ed01b1f5c 100644 --- a/transformers/llm/engine/src/speculative_decoding/generate.cpp +++ b/transformers/llm/engine/src/speculative_decoding/generate.cpp @@ -43,6 +43,9 @@ void ArGeneration::generate(GenerationParams& param) { int max_token = param.max_new_tokens; int len = 0; while (len < max_token) { + if(mContext->status == LlmStatus::USER_CANCEL) { + break; + } AUTOTIME; // Update gen seq mContext->current_token = mLlm->sample(param.outputs[0], param.validLogitStart, param.validLogitSize); @@ -63,9 +66,14 @@ void ArGeneration::generate(GenerationParams& param) { *mContext->os << decodeStr; *mContext->os << std::flush; } - // Compute Next Logits auto outputs = mLlm->forwardVec({mContext->current_token}); + for (auto o : outputs) { + if(nullptr == o->readMap()) { + mContext->status = LlmStatus::INTERNAL_ERROR; + break; + } + } if(outputs.empty()) { break; } @@ -74,6 +82,9 @@ void ArGeneration::generate(GenerationParams& param) { mContext->decode_us += _t.durationInUs(); len++; } + if(len >= max_token) { + mContext->status = LlmStatus::MAX_TOKENS_FINISHED; + } } int Generation::draftVerify(VARP 
logits, const std::vector &drafts, bool& stop) { diff --git a/transformers/llm/engine/src/speculative_decoding/lookahead.cpp b/transformers/llm/engine/src/speculative_decoding/lookahead.cpp index cf4c2a5c79..d8ce38037e 100644 --- a/transformers/llm/engine/src/speculative_decoding/lookahead.cpp +++ b/transformers/llm/engine/src/speculative_decoding/lookahead.cpp @@ -89,6 +89,9 @@ void LookaheadGeneration::generate(GenerationParams& param) { int verify_len = mLlm->mDraftLength + 1; while (len < max_token) { + if(mContext->status == LlmStatus::USER_CANCEL) { + break; + } MNN::Timer _t; std::vector drafts; drafts.push_back(mContext->current_token); @@ -126,6 +129,12 @@ void LookaheadGeneration::generate(GenerationParams& param) { AUTOTIME; // do draft token parallel verify auto outputs = mLlm->forwardVec(drafts); + for (auto o : outputs) { + if(nullptr == o->readMap()) { + mContext->status = LlmStatus::INTERNAL_ERROR; + break; + } + } if(outputs.empty()) { break; } @@ -192,6 +201,9 @@ void LookaheadGeneration::generate(GenerationParams& param) { } } } + if(len >= max_token) { + mContext->status = LlmStatus::MAX_TOKENS_FINISHED; + } #ifdef DUMP_PROFILE_INFO // adopt speculative decoding rate float spl_rate = 100.0 * spl_count / (spl_count + arg_count); diff --git a/transformers/llm/engine/src/speculative_decoding/mtp.cpp b/transformers/llm/engine/src/speculative_decoding/mtp.cpp index aefc4a5aa7..f5c6e0261a 100644 --- a/transformers/llm/engine/src/speculative_decoding/mtp.cpp +++ b/transformers/llm/engine/src/speculative_decoding/mtp.cpp @@ -151,6 +151,9 @@ void MtpGeneration::generate(GenerationParams& param) { int spl_count = 0; while (len < max_token) { + if(mContext->status == LlmStatus::USER_CANCEL) { + break; + } MNN::Timer _t; std::vector drafts; drafts.push_back(mContext->current_token); @@ -171,6 +174,12 @@ void MtpGeneration::generate(GenerationParams& param) { AUTOTIME; // do draft token parallel verify auto outputs = mLlm->forwardVec(drafts); + for (auto o : 
outputs) { + if(nullptr == o->readMap()) { + mContext->status = LlmStatus::INTERNAL_ERROR; + break; + } + } if (outputs.size() < 2) { break; } @@ -238,6 +247,9 @@ void MtpGeneration::generate(GenerationParams& param) { } } } + if(len >= max_token) { + mContext->status = LlmStatus::MAX_TOKENS_FINISHED; + } #ifdef DUMP_PROFILE_INFO // draft accept rate if adopt speculative decoding float spl_accept_rate = 100.0 * spl_accept / spl_decode; From 731c263a96febc92dc817de0232de8a5914e217b Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Tue, 23 Dec 2025 15:47:54 +0800 Subject: [PATCH 092/314] Merge branch feature/fix_sync into master Title: [Bugfix:CI] Fix duplicate msg when sync to github. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 这段代码在 `copybara_sync.sh` 脚本中新增了一个功能,用于检测并跳过从 GitHub 导入的 commits,通过识别包含 `GitOrigin-RevId` 的 commit 来确定上次同步点,并从该点之后的第一个非导入 commit 开始进行同步。 Link: https://code.alibaba-inc.com/AliNN/AliNNPrivate/codereview/25060238 GitOrigin-RevId: ca65f11f52c1b76a826cbdc260a063d1467a8f35 --- docs/compile/cmake.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/compile/cmake.md b/docs/compile/cmake.md index 6513e38fad..91a9c03959 100644 --- a/docs/compile/cmake.md +++ b/docs/compile/cmake.md @@ -101,4 +101,5 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下: | MNN_BUILD_LLM_OMNI | 若构建基于MNN的llm库和demo,是否支持图像和音频输入功能,默认为`OFF` 。仅在MNN_BUILD_LLM 打开时生效。开启时 MNN_BUILD_OPENCV , MNN_IMGCODECS , MNN_BUILD_AUDIO 同时打开| | MNN_BUILD_DIFFUSION | 是否构建基于MNN的diffusion demo,默认为`OFF` . 
打开时MNN_BUILD_OPENCV , MNN_IMGCODECS, MNN_LOW_MEMORY, MNN_SUPPORT_TRANSFORMER_FUSE 同步开启| | MNN_KLEIDIAI | 是否集成ARM的klediAI加速库,默认为`ON` | +| MNN_KLEIDIAI_DEFAULT_ON | 是否默认使用KLEIDIAI的Kernel, 默认为`OFF` | | MNN_USE_RVV | 是否启用RISC-V向量扩展支持,默认为`OFF` | From 38681b52594c4e9f4554a7b59d8617ecab2e0406 Mon Sep 17 00:00:00 2001 From: MNNSyncBot Date: Tue, 23 Dec 2025 16:08:39 +0800 Subject: [PATCH 093/314] =?UTF-8?q?Merge=20branch=20feature/opencl=5Fmmap?= =?UTF-8?q?=5Fsupport=20into=20master=20Title:=20[feature:opencl]opencl?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=B0=86=E6=9D=83=E9=87=8D=E5=AD=98=E5=82=A8?= =?UTF-8?q?=E5=88=B0=E5=8D=95=E4=B8=AA=E6=96=87=E4=BB=B6=E4=B8=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 本次代码评审的主要内容是对OpenCL后端进行了优化,引入了`MmapPool`以支持内存映射池管理,并在多个执行单元中增加了对内存映射错误的检查与处理,同时调整了部分数据传输和转换逻辑以提高性能和稳定性。 Link: https://code.alibaba-inc.com/AliNN/AliNNPrivate/codereview/24702294 GitOrigin-RevId: 1a83d5da23cbb011d0cf522cdc6d49f5778c0999 --- source/backend/opencl/core/MmapPool.cpp | 282 +++++++++++++++ source/backend/opencl/core/MmapPool.hpp | 103 ++++++ source/backend/opencl/core/OpenCLBackend.cpp | 68 +++- source/backend/opencl/core/OpenCLBackend.hpp | 10 +- .../execution/buffer/ConvBufExecution.cpp | 194 +++++----- .../execution/buffer/ConvBufExecution.hpp | 4 +- .../buffer/ConvBufLowMemoryExecution.cpp | 340 +++++++++++------- .../execution/buffer/ConvBufWinograd.cpp | 103 +++--- .../buffer/ConvSubgroupBufExecution.cpp | 31 +- .../execution/buffer/DeconvBufExecution.cpp | 47 ++- .../buffer/DepthwiseConvBufExecution.cpp | 37 +- .../DepthwiseConvSubgroupBufExecution.cpp | 112 +++--- .../buffer/GroupNormBufExecution.cpp | 80 ++--- .../buffer/LayerNormBufExecution.cpp | 77 ++-- .../execution/buffer/ReluBufExecution.cpp | 29 +- .../execution/buffer/ScaleBufExecution.cpp | 73 ++-- 16 files changed, 1064 insertions(+), 526 deletions(-) create mode 100644 source/backend/opencl/core/MmapPool.cpp create mode 100644 
source/backend/opencl/core/MmapPool.hpp diff --git a/source/backend/opencl/core/MmapPool.cpp b/source/backend/opencl/core/MmapPool.cpp new file mode 100644 index 0000000000..8b8e893628 --- /dev/null +++ b/source/backend/opencl/core/MmapPool.cpp @@ -0,0 +1,282 @@ +// +// MmapPool.cpp +// MNN +// +// Created by MNN on 2025/12/02. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#include "backend/opencl/core/MmapPool.hpp" +namespace MNN { +namespace OpenCL { +// only support static memory +OpenCLMmapAllocator::OpenCLMmapAllocator(const char* dirName, const char* prefix, const char* posfix, bool autoRemove) { + if (nullptr != dirName) { + mFileName = dirName; + if (!MNNCreateDir(dirName)) { + MNN_ERROR("%s not exist\n", dirName); + } + } + if (nullptr != prefix) { + mPrefix = prefix; + } + if (nullptr != posfix) { + mPosfix = posfix; + } + mRemove = autoRemove; +} + +std::string OpenCLMmapAllocator::onAlloc(size_t size) { + MNN_ASSERT(size > 0); + MNN_ASSERT(!mSynced); + std::string name = mPrefix + std::to_string(mAllocTimes) + "." 
+ mPosfix; + std::string fileName = MNNFilePathConcat(mFileName, name); + file_t file; + if (MNNFileExist(fileName.c_str())) { + file = MNNOpenFile(fileName.c_str(), MNN_FILE_READ | MNN_FILE_WRITE); + } else { + file = MNNCreateFile(fileName.c_str()); + auto code = MNNSetFileSize(file, size); + if (NO_ERROR != code) { + MNN_ERROR("Set File size %lu error= %d\n", size, code); + } + mNewMmap = true; + } + mCache.insert(std::make_pair(fileName, std::make_tuple(file, size))); + mAllocTimes++; + return fileName; +} +bool OpenCLMmapAllocator::read(std::string fileName, size_t offset, size_t size, void* buffer){ + auto iter = mCache.find(fileName); + if (iter == mCache.end()) { + MNN_ASSERT(false); + MNN_ERROR("Invalid mmap for OpenCLMmapAllocator\n"); + return false; + } + file_t file = std::get<0>(iter->second); + auto ret = MNNSetFilePointer(file, offset); + if (ret != NO_ERROR) { + return false; + } + auto readSize = MNNReadFile(file, buffer, size); + if (readSize != size) { + return false; + } +} +bool OpenCLMmapAllocator::write(std::string fileName, size_t offset, size_t size, void* buffer){ + auto iter = mCache.find(fileName); + if (iter == mCache.end()) { + MNN_ASSERT(false); + MNN_ERROR("Invalid unMmap for OpenCLMmapAllocator\n"); + return false; + } + file_t file = std::get<0>(iter->second); + auto ret = MNNSetFilePointer(file, offset); + if (ret != NO_ERROR) { + return false; + } + auto writeSize = MNNWriteFile(file, buffer, size); + if (writeSize != size) { + return false; + } +} +void OpenCLMmapAllocator::sync() { + if (!mRemove && mNewMmap) { + std::string cacheName = mPrefix + "sync." 
+ mPosfix; + std::string fileName = MNNFilePathConcat(mFileName, cacheName); + MNNCreateFile(fileName.c_str()); + } +} + +std::shared_ptr MmapPool::allocBuffer(size_t size, bool separate) { + if (!separate) { + auto iter = mFreeBufferList.lower_bound(size); + if (iter != mFreeBufferList.end()) { + auto buffer = iter->second->buffer; + mFreeBufferList.erase(iter); + return buffer; + } + } + std::string fileName; + for(auto iter : mFileInfo){ + if(mFileSize - iter.second >= size){ + fileName = iter.first; + } + } + if(fileName.length() == 0){ + //need open new file + fileName = mOrigin->onAlloc(mFileSize); + mFileInfo.insert(std::make_pair(fileName, 0)); + } + + std::shared_ptr node(new OpenCLMmapBufferNode); + cl_int ret = CL_SUCCESS; + mTotalSize += size; + node->fileName = fileName; + node->size = size; + node->buffer.reset(new cl::Buffer(mContext, mFlag, size, NULL, &ret)); + node->offset = mFileInfo[fileName]; + mFileInfo[fileName] += size; + if (nullptr == node->buffer.get() || ret != CL_SUCCESS) { + MNN_ERROR("Alloc Buffer %lu error, code:%d \n", size, ret); + return nullptr; + } + if(mUseCachedMmap > 1){ + auto CLptr = mCommand.enqueueMapBuffer(*node->buffer.get(), CL_TRUE, CL_MAP_WRITE, 0, node->size); + if(CLptr == nullptr){ + MNN_ERROR("map buffer %d error\n", node->size); + return nullptr; + } + mOrigin->read(fileName, node->offset, node->size, CLptr); + mCommand.enqueueUnmapMemObject(*node->buffer.get(), CLptr); + } + mAllBuffer.insert(std::make_pair(node->buffer.get(), node)); + return node->buffer; +} + +std::shared_ptr MmapPool::allocImage(size_t w, size_t h, cl_channel_type type, bool separate) { + if (!separate) { + int minWaste = 0; + auto findIter = mFreeImageList.end(); + for (auto iterP = mFreeImageList.begin(); iterP != mFreeImageList.end(); iterP++) { + auto& iter = *iterP; + if (iter->w >= w && iter->h >= h && iter->type == type) { + int waste = iter->w * iter->h - w * h; + if (minWaste == 0 || waste < minWaste) { + findIter = iterP; + 
minWaste = waste; + } + } + } + if (findIter != mFreeImageList.end()) { + auto image = (*findIter)->image; + mFreeImageList.erase(findIter); + return image; + } + } + + std::shared_ptr node(new OpenCLMmapImageNode); + cl_int ret = CL_SUCCESS; + size_t row_pitch, slice_pitch; + node->w = w; + node->h = h; + node->type = type; + node->image.reset(new cl::Image2D(mContext, mFlag, cl::ImageFormat(CL_RGBA, type), w, h, 0, nullptr, &ret)); + if (nullptr == node->image.get() || ret != CL_SUCCESS) { + MNN_ERROR("Alloc Image %d x %d error, code:%d \n", w, h, ret); + return nullptr; + } + auto CLptr = mCommand.enqueueMapImage(*node->image.get(), CL_TRUE, CL_MAP_WRITE, {0, 0, 0}, {w, h, 1}, &row_pitch, &slice_pitch); + if(CLptr == nullptr){ + MNN_ERROR("map Image %d x %d error\n", w, h); + return nullptr; + } + size_t size = h * row_pitch; + + std::string fileName; + for(auto iter : mFileInfo){ + if(mFileSize - iter.second >= size){ + fileName = iter.first; + } + } + if(fileName.length() == 0){ + //need open new file + fileName = mOrigin->onAlloc(mFileSize); + mFileInfo.insert(std::make_pair(fileName, 0)); + } + node->fileName = fileName; + node->size = size; + node->offset = mFileInfo[fileName]; + mFileInfo[fileName] += size; + if(mUseCachedMmap > 1){ + mOrigin->read(fileName, node->offset, node->size, CLptr); + } + + mCommand.enqueueUnmapMemObject(*node->image.get(), CLptr); + mAllImage.insert(std::make_pair(node->image.get(), node)); + return node->image; +} + +void MmapPool::recycle(cl::Buffer* buffer, bool release) { + auto iter = mAllBuffer.find(buffer); + if (iter == mAllBuffer.end()) { + MNN_ERROR("Error for recycle buffer\n"); + return; + } + if (release) { + mAllBuffer.erase(iter); + return; + } + mFreeBufferList.insert(std::make_pair(iter->second->size, iter->second)); +} + +void MmapPool::recycle(cl::Image* image, bool release) { + auto iter = mAllImage.find(image); + if (iter == mAllImage.end()) { + MNN_ERROR("Error for recycle image\n"); + return; + } + if 
(release) { + mAllImage.erase(iter); + return; + } + mFreeImageList.push_back(iter->second); +} + +void MmapPool::clear() { + mFreeBufferList.clear(); + mFreeImageList.clear(); + mAllBuffer.clear(); + mAllImage.clear(); + mTotalSize = 0; +} + +void MmapPool::releaseFreeList() { + for(auto mf : mFreeBufferList){ + auto iter = mAllBuffer.find(mf.second->buffer.get()); + if (iter != mAllBuffer.end()) { + mAllBuffer.erase(iter); + } + } + mFreeBufferList.clear(); + + for(auto mf : mFreeImageList){ + auto iter = mAllImage.find(mf->image.get()); + if (iter != mAllImage.end()) { + mAllImage.erase(iter); + } + } + mFreeImageList.clear(); +} + +void MmapPool::sync() { + if(mHasSync){ + return; + } + if(mUseCachedMmap == 1){ + for(auto iter : mAllBuffer){ + auto node = iter.second; + auto CLptr = mCommand.enqueueMapBuffer(*node->buffer.get(), CL_TRUE, CL_MAP_WRITE, 0, node->size); + if(CLptr == nullptr){ + MNN_ERROR("map buffer %d error\n", node->size); + continue; + } + mOrigin->write(node->fileName, node->offset, node->size, CLptr); + mCommand.enqueueUnmapMemObject(*node->buffer.get(), CLptr); + } + for(auto iter : mAllImage){ + auto node = iter.second; + size_t row_pitch, slice_pitch; + size_t w = node->w; + size_t h = node->h; + auto CLptr = mCommand.enqueueMapImage(*node->image.get(), CL_TRUE, CL_MAP_WRITE, {0, 0, 0}, {w, h, 1}, &row_pitch, &slice_pitch); + mOrigin->write(node->fileName, node->offset, node->size, CLptr); + mCommand.enqueueUnmapMemObject(*node->image.get(), CLptr); + } + } + mOrigin->sync(); + mHasSync = true; +}; + +} // namespace OpenCL +} // namespace MNN diff --git a/source/backend/opencl/core/MmapPool.hpp b/source/backend/opencl/core/MmapPool.hpp new file mode 100644 index 0000000000..84c79f18fb --- /dev/null +++ b/source/backend/opencl/core/MmapPool.hpp @@ -0,0 +1,103 @@ +// +// MmapPool.hpp +// MNN +// +// Created by MNN on 2025/12/02. 
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifndef MmapPool_hpp +#define MmapPool_hpp + +#include +#include +#include +#include +#include +#include "core/NonCopyable.hpp" +#include "backend/opencl/core/runtime/OpenCLWrapper.hpp" +#include "core/BufferAllocator.hpp" +#include "core/MNNFileUtils.h" + +namespace MNN { +namespace OpenCL { +struct OpenCLMmapBufferNode{ + OpenCLMmapBufferNode(){}; + std::string fileName; + size_t offset; + size_t size; + std::shared_ptr buffer; +}; + +struct OpenCLMmapImageNode { + OpenCLMmapImageNode(){}; + std::string fileName; + size_t offset; + size_t size; + int w; + int h; + cl_channel_type type; + std::shared_ptr image; +}; + +class OpenCLMmapAllocator { +private: + std::map> mCache; + std::string mFileName; + std::string mPrefix; + std::string mPosfix; + int mAllocTimes = 0; + bool mRemove; + bool mNewMmap = false; + +public: + OpenCLMmapAllocator(const char* dirName, const char* prefix, const char* posfix, bool autoRemove); + ~ OpenCLMmapAllocator() { + for (auto& iter : mCache) { + MNNCloseFile(std::get<0>(iter.second)); + if (mRemove) { + MNNRemoveFile(iter.first.c_str()); + } + } + } + std::string onAlloc(size_t size); + bool read(std::string fileName, size_t offset, size_t size, void* buffer); + bool write(std::string fileName, size_t offset, size_t size, void* buffer); + void sync(); +}; + +class MmapPool : public NonCopyable { +public: + MmapPool(std::shared_ptr origin, cl::Context& context, cl::CommandQueue& command, cl_mem_flags flags, int useCacheMmap) : mOrigin(origin), mContext(context), mCommand(command), mFlag(flags), mUseCachedMmap(useCacheMmap) {} + + std::shared_ptr allocBuffer(size_t size, bool separate = false); + std::shared_ptr allocImage(size_t w, size_t h, cl_channel_type type, bool separate = false); + void recycle(cl::Buffer* buffer, bool release = false); + void recycle(cl::Image* image, bool release = false); + void clear(); + void releaseFreeList(); + void sync(); + size_t totalSize() { 
return mTotalSize; } + +private: + std::map> mAllBuffer; + std::multimap> mFreeBufferList; + std::map> mAllImage; + std::list> mFreeImageList; + std::map mFileInfo; + std::shared_ptr mOrigin; + + cl::Context& mContext; + cl::CommandQueue& mCommand; + cl_mem_flags mFlag; + size_t mTotalSize = 0; + int mUseCachedMmap; + bool mHasSync = false; + size_t mFileSize = 1024*1024*1024; +}; + + +} // namespace OpenCL +} // namespace MNN + +#endif /* MmapPool_hpp */ diff --git a/source/backend/opencl/core/OpenCLBackend.cpp b/source/backend/opencl/core/OpenCLBackend.cpp index 19af039ce2..52c6635cda 100644 --- a/source/backend/opencl/core/OpenCLBackend.cpp +++ b/source/backend/opencl/core/OpenCLBackend.cpp @@ -17,6 +17,7 @@ #include #include "core/Macro.h" #include "runtime/OpenCLTuneInfo.hpp" +#include "core/MNNFileUtils.h" #ifdef __ANDROID__ #include #endif @@ -66,6 +67,7 @@ CLRuntime::CLRuntime(const Backend::Info& info){ CLRuntime::~CLRuntime() { mImagePool = nullptr; mBufferPool = nullptr; + mMmapPool = nullptr; mOpenCLRuntime = nullptr; delete mTunedInfo; } @@ -215,7 +217,24 @@ Backend* CLRuntime::onCreate(const BackendConfig* config, Backend* origin) const if(precision > 2 || precision < 0){ precision = BackendConfig::Precision_High; } - auto backend = new OpenCLBackend(precision, memory, mInfo.gpuMode, mImagePool, mBufferPool, this); + + if (hint().weightMemoryPath.size() > 0 && mMmapPool.get() == nullptr) { + // Only support set weightmap dir once + // forward_type, precision_type, memory_type, power_type + std::string prefix = "1_0_0_0_"; + std::string posfix = "opencl.weight"; + auto syncPath = prefix + "sync." 
+ posfix; + bool autoRemove = true; + if (hint().useCachedMmap) { + autoRemove = false; + std::string fileName = MNNFilePathConcat(hint().weightMemoryPath, syncPath); + const_cast(hint()).useCachedMmap += MNNFileExist(fileName.c_str()); + } + std::shared_ptr mmap; + mmap.reset(new OpenCLMmapAllocator(hint().weightMemoryPath.c_str(), prefix.c_str(), posfix.c_str(), autoRemove)); + mMmapPool.reset(new MmapPool(mmap, mOpenCLRuntime->context(), mOpenCLRuntime->commandQueue(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, hint().useCachedMmap)); + } + auto backend = new OpenCLBackend(precision, memory, mInfo.gpuMode, mImagePool, mBufferPool, mMmapPool, this); backend->setMetaPtr(pMeta); return backend; } @@ -223,6 +242,9 @@ Backend* CLRuntime::onCreate(const BackendConfig* config, Backend* origin) const void CLRuntime::onGabageCollect(int level) { mImagePool->releaseFreeList(); mBufferPool->releaseFreeList(); + if(mMmapPool != nullptr){ + mMmapPool->releaseFreeList(); + } } float CLRuntime::onGetMemoryInMB() { @@ -241,7 +263,7 @@ std::map, OpenCLBackend::Creator*>* gCreator() { return creators; }; -OpenCLBackend::OpenCLBackend(BackendConfig::PrecisionMode precision, BackendConfig::MemoryMode memory, int gpuMode, std::shared_ptrimgPool, std::shared_ptr bufPool, const CLRuntime *runtime) +OpenCLBackend::OpenCLBackend(BackendConfig::PrecisionMode precision, BackendConfig::MemoryMode memory, int gpuMode, std::shared_ptrimgPool, std::shared_ptr bufPool, std::shared_ptr mmapPool, const CLRuntime *runtime) : Backend(MNN_FORWARD_OPENCL) { mGpuMode = gpuMode; @@ -264,6 +286,7 @@ OpenCLBackend::OpenCLBackend(BackendConfig::PrecisionMode precision, BackendConf setGpuMode(gpuMode); mStaticImagePool = imgPool; mStaticBufferPool = bufPool; + mStaticAllocatorMMap = mmapPool; if(mOpenCLRuntime.get()){ if(mOpenCLRuntime->isCreateError() == true) { mIsCreateError = true; @@ -334,6 +357,30 @@ class CLMemReleaseBuffer : public Backend::MemObj { BufferPool* mBufferPool; }; +class 
CLMemReleaseMmapBuffer : public Backend::MemObj { +public: + CLMemReleaseMmapBuffer(cl::Buffer* bId, MmapPool* mmapPool) { + mBuffer = bId; + mMmapPool = mmapPool; + } + CLMemReleaseMmapBuffer(cl::Image* bId, MmapPool* mmapPool) { + mImage = bId; + mMmapPool = mmapPool; + } + virtual ~ CLMemReleaseMmapBuffer() { + if(mBuffer != nullptr){ + mMmapPool->recycle(mBuffer); + } + if(mImage != nullptr){ + mMmapPool->recycle(mImage); + } + } +private: + cl::Buffer* mBuffer = nullptr; + cl::Image* mImage = nullptr; + MmapPool* mMmapPool = nullptr; +}; + class CLMemReleaseImage : public Backend::MemObj { public: CLMemReleaseImage(cl::Image* bId, ImagePool* bufferPool) { @@ -422,10 +469,16 @@ Backend::MemObj* OpenCLBackend::onAcquire(const Tensor* nativeTensor, StorageTyp return new CLReleaseExecutionBuffer(node, mExecutionBufferPool.get()); } MNN_ASSERT(storageType == STATIC); - - auto buffer = mStaticBufferPool->alloc(size*typeSize); - ((Tensor*)nativeTensor)->buffer().device = (uint64_t)buffer; // fix - return new CLMemReleaseBuffer(buffer, mStaticBufferPool.get()); + if(mCLRuntime->hint().useCachedMmap && mStaticAllocatorMMap.get() != nullptr) + { + auto buffer = mStaticAllocatorMMap->allocBuffer(size*typeSize).get(); + ((Tensor*)nativeTensor)->buffer().device = (uint64_t)buffer; // fix + return new CLMemReleaseMmapBuffer(buffer, mStaticAllocatorMMap.get()); + }else{ + auto buffer = mStaticBufferPool->alloc(size*typeSize); + ((Tensor*)nativeTensor)->buffer().device = (uint64_t)buffer; // fix + return new CLMemReleaseBuffer(buffer, mStaticBufferPool.get()); + } } else #endif /* MNN_OPENCL_BUFFER_CLOSED */ @@ -493,6 +546,9 @@ bool OpenCLBackend::onSelectDynamicAllocator(int index, int maxIndex) { bool OpenCLBackend::onClearBuffer() { mImagePool->clear(); mBufferPool->clear(); + if(mStaticAllocatorMMap.get() != nullptr){ + mStaticAllocatorMMap.get()->sync(); + } if(mMapMem.second != nullptr) { #ifdef MNN_OPENCL_SVM_ENABLE if(mUseSvm) diff --git 
a/source/backend/opencl/core/OpenCLBackend.hpp b/source/backend/opencl/core/OpenCLBackend.hpp index 665670d028..e51c32d5f2 100644 --- a/source/backend/opencl/core/OpenCLBackend.hpp +++ b/source/backend/opencl/core/OpenCLBackend.hpp @@ -15,7 +15,9 @@ #include #include +#include "core/BufferAllocator.hpp" #include "backend/opencl/core/BufferPool.hpp" +#include "backend/opencl/core/MmapPool.hpp" #include "backend/opencl/core/ImageBufferConvertor.hpp" #include "backend/opencl/core/BufferConvertor.hpp" #include "backend/opencl/core/ImagePool.hpp" @@ -68,6 +70,7 @@ class CLRuntime : public Runtime { std::shared_ptr mOpenCLRuntime; std::shared_ptr mImagePool; std::shared_ptr mBufferPool; + mutable std::shared_ptr mMmapPool; BackendConfig::PrecisionMode mPrecision; BackendConfig::MemoryMode mMemory; bool mCLRuntimeError = false; @@ -79,7 +82,7 @@ class CLRuntime : public Runtime { class OpenCLBackend : public Backend { public: - OpenCLBackend(BackendConfig::PrecisionMode precision, BackendConfig::MemoryMode memory, int gpuMode, std::shared_ptrimgPool, std::shared_ptr bufPool, const CLRuntime *runtime); + OpenCLBackend(BackendConfig::PrecisionMode precision, BackendConfig::MemoryMode memory, int gpuMode, std::shared_ptrimgPool, std::shared_ptr bufPool, std::shared_ptr mmapPool, const CLRuntime *runtime); ~OpenCLBackend(); OpenCLRuntime *getOpenCLRuntime(); @@ -110,6 +113,10 @@ class OpenCLBackend : public Backend { BufferPool *getBufferPool() const { return mBufferPool; } + + std::shared_ptr getStaticAllocatorMMap() const { + return mStaticAllocatorMMap; + } virtual bool onSelectDynamicAllocator(int index, int maxIndex) override; BackendConfig::PrecisionMode getPrecision() const { @@ -163,6 +170,7 @@ class OpenCLBackend : public Backend { ImagePool* mImagePool; BufferPool* mBufferPool; + std::shared_ptr mStaticAllocatorMMap; std::shared_ptr mExecutionBufferPool; std::shared_ptr mImagePoolFirst; diff --git a/source/backend/opencl/execution/buffer/ConvBufExecution.cpp 
b/source/backend/opencl/execution/buffer/ConvBufExecution.cpp index a9e8561409..fee8459ced 100644 --- a/source/backend/opencl/execution/buffer/ConvBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/ConvBufExecution.cpp @@ -35,27 +35,27 @@ ConvBufCommonExecution::ConvBufCommonExecution(const Convolution2D *conv2dParams mResource.reset(new ConvBufResource); mResource->mBias.reset(Tensor::createDevice({1, 1, 1, ROUND_UP(biasSize, 32)})); backend->onAcquireBuffer(mResource->mBias.get(), Backend::STATIC); - cl::Buffer &biasBuffer = openCLBuffer(mResource->mBias.get()); - - cl_int res; - auto biasPtrCL = openclBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( - biasBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); - if(biasPtrCL != nullptr && res == CL_SUCCESS){ - ::memset(biasPtrCL, 0, buffer_size); - if (nullptr != conv2dParams->bias()) { - const float *biasDataPtr = conv2dParams->bias()->data(); - if(openclBackend->getPrecision() != BackendConfig::Precision_High){ - for(int i=0; igetRuntime()->hint().useCachedMmap <= 1){ + cl::Buffer &biasBuffer = openCLBuffer(mResource->mBias.get()); + cl_int res; + auto biasPtrCL = openclBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(biasBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); + if(biasPtrCL != nullptr && res == CL_SUCCESS){ + ::memset(biasPtrCL, 0, buffer_size); + if (nullptr != conv2dParams->bias()) { + const float *biasDataPtr = conv2dParams->bias()->data(); + if(openclBackend->getPrecision() != BackendConfig::Precision_High){ + for(int i=0; igetOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(biasBuffer, biasPtrCL); } - openclBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(biasBuffer, biasPtrCL); } ConvBufCommonExecution::ConvBufCommonExecution(const Op *op, Backend *backend, bool isExtra) { @@ -79,46 +79,48 @@ ConvBufCommonExecution::ConvBufCommonExecution(const Op *op, Backend *backend, b mResource.reset(new 
ConvBufResource); mResource->mBias.reset(Tensor::createDevice({1, 1, 1, ROUND_UP(biasSize, 32)})); backend->onAcquireBuffer(mResource->mBias.get(), Backend::STATIC); - cl::Buffer &biasBuffer = openCLBuffer(mResource->mBias.get()); - - auto biasPtrCL = openclBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(biasBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); - if(biasPtrCL != nullptr && res == CL_SUCCESS){ - ::memset(biasPtrCL, 0, buffer_size); - if (nullptr != conv2dParams->bias()) { - const float *biasDataPtr = conv2dParams->bias()->data(); - if(openclBackend->getPrecision() != BackendConfig::Precision_High){ - for(int i=0; igetRuntime()->hint().useCachedMmap <= 1){ + cl::Buffer &biasBuffer = openCLBuffer(mResource->mBias.get()); + auto biasPtrCL = openclBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(biasBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); + if(biasPtrCL != nullptr && res == CL_SUCCESS){ + ::memset(biasPtrCL, 0, buffer_size); + if (nullptr != conv2dParams->bias()) { + const float *biasDataPtr = conv2dParams->bias()->data(); + if(openclBackend->getPrecision() != BackendConfig::Precision_High){ + for(int i=0; igetOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(biasBuffer, biasPtrCL); } - openclBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(biasBuffer, biasPtrCL); if(isExtra){ const PRelu* preluParam = flatbuffers::GetRoot(op->main_as_Extra()->attr()->GetAs(1)->tensor()->uint8s()->data()); const float *slopeDataPtr = preluParam->slope()->data(); mResource->mSlope.reset(Tensor::createDevice({1, 1, 1, ROUND_UP(biasSize, 32)})); backend->onAcquireBuffer(mResource->mSlope.get(), Backend::STATIC); - cl::Buffer &slopeBuffer = openCLBuffer(mResource->mSlope.get()); - - auto slopePtrCL = openclBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(slopeBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); - if(slopePtrCL != nullptr && res == 
CL_SUCCESS){ - if(openclBackend->getPrecision() != BackendConfig::Precision_High){ - for(int i=0; igetRuntime()->hint().useCachedMmap <= 1){ + cl::Buffer &slopeBuffer = openCLBuffer(mResource->mSlope.get()); + auto slopePtrCL = openclBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(slopeBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); + if(slopePtrCL != nullptr && res == CL_SUCCESS){ + if(openclBackend->getPrecision() != BackendConfig::Precision_High){ + for(int i=0; igetOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(slopeBuffer, slopePtrCL); } - openclBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(slopeBuffer, slopePtrCL); } } @@ -221,74 +223,76 @@ ConvBufExecution::ConvBufExecution(const std::vector &inputs, const st } if (mResource->mConv1x1Opt) { int buffer_size = ROUND_UP(mResource->mOutputChannel, mResource->mAlignN) * ROUND_UP(mResource->mInputChannel, mResource->mAlignK); - mResource->mFilter.reset( - Tensor::createDevice({buffer_size})); + mResource->mFilter.reset(Tensor::createDevice({buffer_size})); mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC); - - if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { - buffer_size *= sizeof(half_float::half); - } else { - buffer_size *= sizeof(float); - } - - cl::Buffer &filterBuffer = openCLBuffer(mResource->mFilter.get()); - cl_int error; - auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( - filterBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error); - if(nullptr != ptrCL && error == CL_SUCCESS){ - memset((void *)ptrCL, 0, buffer_size); + + if(mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { - // [Ci, Co] ( [K, N] ) - for (int o = 0; o < mResource->mOutputChannel; o++) { - for (int i = 0; i < mResource->mInputChannel; i++) { - ((half_float::half *)ptrCL)[i * 
ROUND_UP(mResource->mOutputChannel, mResource->mAlignN) + o] = (half_float::half)(mFilterDataPtr[o * mResource->mInputChannel + i]); - } - } + buffer_size *= sizeof(half_float::half); } else { - for (int o = 0; o < mResource->mOutputChannel; o++) { - for (int i = 0; i < mResource->mInputChannel; i++) { - ((float *)ptrCL)[i * ROUND_UP(mResource->mOutputChannel, mResource->mAlignN) + o] = (mFilterDataPtr[o * mResource->mInputChannel + i]); + buffer_size *= sizeof(float); + } + cl::Buffer &filterBuffer = openCLBuffer(mResource->mFilter.get()); + cl_int error; + auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( + filterBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error); + if(nullptr != ptrCL && error == CL_SUCCESS){ + memset((void *)ptrCL, 0, buffer_size); + if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { + // [Ci, Co] ( [K, N] ) + for (int o = 0; o < mResource->mOutputChannel; o++) { + for (int i = 0; i < mResource->mInputChannel; i++) { + ((half_float::half *)ptrCL)[i * ROUND_UP(mResource->mOutputChannel, mResource->mAlignN) + o] = (half_float::half)(mFilterDataPtr[o * mResource->mInputChannel + i]); + } + } + } else { + for (int o = 0; o < mResource->mOutputChannel; o++) { + for (int i = 0; i < mResource->mInputChannel; i++) { + ((float *)ptrCL)[i * ROUND_UP(mResource->mOutputChannel, mResource->mAlignN) + o] = (mFilterDataPtr[o * mResource->mInputChannel + i]); + } } } + }else{ + MNN_ERROR("Map error filterPtrCL == nullptr \n"); } - }else{ - MNN_ERROR("Map error filterPtrCL == nullptr \n"); + mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBuffer, ptrCL); } - mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBuffer, ptrCL); } else { mResource->mFilter.reset( Tensor::createDevice({ROUND_UP(mResource->mOutputChannel, 4) * ROUND_UP(mResource->mInputChannel, 4) * mResource->mKernelWidth * mResource->mKernelHeight})); if (mFilterDataPtr 
!= nullptr) { std::vector filterImageShape{ROUND_UP(mResource->mInputChannel, 4), (UP_DIV(mResource->mOutputChannel, 4) * mResource->mKernelWidth * mResource->mKernelHeight)}; - std::shared_ptr filterBuffer( - Tensor::createDevice({mResource->mOutputChannel, ROUND_UP(mResource->mInputChannel, 4), mResource->mKernelWidth, mResource->mKernelHeight})); - - int buffer_size = filterBuffer->elementSize() * sizeof(float); - cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size); - filterBuffer->buffer().device = (uint64_t)(&filterBufferCL); - - cl_int res; - auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); - if(ptrCL != nullptr && res == CL_SUCCESS) { - ::memset(ptrCL, 0, buffer_size); - const int copy_size = mResource->mKernelWidth * mResource->mKernelHeight * sizeof(float); - for(int oc=0; ocmOutputChannel; oc++) { - for(int ic=0; icmInputChannel; ic++) { - ::memcpy((float *)ptrCL + (oc * ROUND_UP(mResource->mInputChannel, 4) + ic) * mResource->mKernelWidth * mResource->mKernelHeight, mFilterDataPtr + (oc * mResource->mInputChannel + ic) * mResource->mKernelWidth * mResource->mKernelHeight, copy_size); + mResource->mFilter.reset(Tensor::createDevice({filterImageShape[1] * 4 * filterImageShape[0]})); + mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC); + + if(mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ + std::shared_ptr filterBuffer(Tensor::createDevice({mResource->mOutputChannel, ROUND_UP(mResource->mInputChannel, 4), mResource->mKernelWidth, mResource->mKernelHeight})); + + int buffer_size = filterBuffer->elementSize() * sizeof(float); + cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size); + filterBuffer->buffer().device = (uint64_t)(&filterBufferCL); + + cl_int res; + auto 
ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); + if(ptrCL != nullptr && res == CL_SUCCESS) { + ::memset(ptrCL, 0, buffer_size); + const int copy_size = mResource->mKernelWidth * mResource->mKernelHeight * sizeof(float); + for(int oc=0; ocmOutputChannel; oc++) { + for(int ic=0; icmInputChannel; ic++) { + ::memcpy((float *)ptrCL + (oc * ROUND_UP(mResource->mInputChannel, 4) + ic) * mResource->mKernelWidth * mResource->mKernelHeight, mFilterDataPtr + (oc * mResource->mInputChannel + ic) * mResource->mKernelWidth * mResource->mKernelHeight, copy_size); + } } + }else{ + MNN_ERROR("Map error ptrCL == nullptr \n"); } - }else{ - MNN_ERROR("Map error ptrCL == nullptr \n"); + mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBufferCL, ptrCL); + + MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()}; + + bool needTrans = true; + bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mResource->mFilter.get(), mOpenCLBackend->getPrecision(), needTrans); } - mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBufferCL, ptrCL); - - mResource->mFilter.reset(Tensor::createDevice({filterImageShape[1] * 4 * filterImageShape[0]})); - mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC); - MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()}; - - bool needTrans = true; - bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mResource->mFilter.get(), mOpenCLBackend->getPrecision(), needTrans); } } diff --git a/source/backend/opencl/execution/buffer/ConvBufExecution.hpp b/source/backend/opencl/execution/buffer/ConvBufExecution.hpp index 8cbcaf75f2..11a18300dd 100644 --- a/source/backend/opencl/execution/buffer/ConvBufExecution.hpp +++ 
b/source/backend/opencl/execution/buffer/ConvBufExecution.hpp @@ -21,8 +21,8 @@ struct ConvBufResource { const Convolution2DCommon *mConv2dCommonParams; const Convolution2D *mConv2dParams; std::shared_ptr mKernelBuffer; - std::shared_ptr mKernelImage; - std::shared_ptr dequantScaleOffset; + std::shared_ptr mKernelImage; + std::shared_ptr mDequantScaleOffsetBuffer; std::shared_ptr mFilter; std::shared_ptr mBias; std::shared_ptr mSlope; diff --git a/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp b/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp index 0502fe9005..03b9f64fbe 100644 --- a/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp +++ b/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp @@ -28,7 +28,7 @@ void ConvBufLowMemoryExecution::getInfoFromOpLowMemory(void *weight_ptr) { } // src of alpha in CPU float * dequantAlpha = quanCommon->alpha.get(); - int totalCount = quanCommon->alpha.size(); + int totalCount = quanCommon->alphaSize; int soSize = 1; if (quanCommon->asymmetric) { soSize = 2; @@ -39,95 +39,110 @@ void ConvBufLowMemoryExecution::getInfoFromOpLowMemory(void *weight_ptr) { mResource->mBlockSize = totalCount / numAlpha; // set mDequantScale mDequantOffset int numAlphaPack = ROUND_UP(numAlpha, 4); + int fpBytes = mOpenCLBackend->fpBytes(); + int buffer_size = mResource->mBlockSize * numAlphaPack * fpBytes * soSize + sizeof(float); - mResource->dequantScaleOffset.reset(Tensor::createDevice({ROUND_UP(mResource->mBlockSize, 4), numAlphaPack, soSize})); - mOpenCLBackend->onAcquireBuffer(mResource->dequantScaleOffset.get(), Backend::STATIC); - cl::Buffer &dequantScaleOffsetBuffer = openCLBuffer(mResource->dequantScaleOffset.get()); + auto staticMapAlloc = mOpenCLBackend->getStaticAllocatorMMap(); + if(mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr){ + mResource->mDequantScaleOffsetBuffer = staticMapAlloc.get()->allocBuffer(buffer_size); + }else{ 
+ mResource->mDequantScaleOffsetBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size)); + } // transfer data from src in cpu to dst in gpu - int fpBytes = mOpenCLBackend->fpBytes(); cl_int resBias, resScaleOffset; - - int mapSize = mResource->mBlockSize * numAlphaPack * fpBytes * soSize; - void * dequantScaleOffsetBufferMap = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(dequantScaleOffsetBuffer, true, CL_MAP_WRITE, 0, mapSize, nullptr, nullptr, &resScaleOffset); float coef = 1.0; - if(fpBytes == 2) { - float max_data = 0.0f; - if (quanCommon->asymmetric){ - for (int i = 0; i < numAlpha; ++i) { - auto srcZ = dequantAlpha + i * mResource->mBlockSize * 2; - for(int j = 0; j < mResource->mBlockSize; ++j){ - float s = fabsf(srcZ[2*j+0]); - float b = fabsf(srcZ[2*j+1]); - float temp = ALIMAX(s, b); - if(temp > max_data) { - max_data = temp; - } - } - } + + void * dequantScaleOffsetBufferMap = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*mResource->mDequantScaleOffsetBuffer.get(), true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &resScaleOffset); + if(mOpenCLBackend->getRuntime()->hint().useCachedMmap > 1){ + if(fpBytes == 2){ + float* coefMapPtr = (float*)(((half_float::half*)dequantScaleOffsetBufferMap) + (numAlphaPack * mResource->mBlockSize * soSize)); + coef = coefMapPtr[0]; }else{ - for (int i = 0; i < numAlpha; ++i) { - auto srcZ = dequantAlpha + i * mResource->mBlockSize; - for(int j = 0; j < mResource->mBlockSize; ++j){ - float s = fabsf(srcZ[j]); - if(s > max_data) { - max_data = s; - } - } - } + coef = ((float *)dequantScaleOffsetBufferMap)[(numAlphaPack * mResource->mBlockSize * soSize)]; } - if(abs(max_data) >= 0.000001f){ - coef = 1000.0f / max_data; - } - if (dequantScaleOffsetBufferMap != nullptr && resScaleOffset == CL_SUCCESS) { - if (quanCommon->asymmetric) { + }else{ + if(fpBytes == 2) { + float max_data = 0.0f; + if 
(quanCommon->asymmetric){ for (int i = 0; i < numAlpha; ++i) { auto srcZ = dequantAlpha + i * mResource->mBlockSize * 2; for(int j = 0; j < mResource->mBlockSize; ++j){ - float o = srcZ[2*j+0]; - float s = srcZ[2*j+1]; - ((half_float::half*)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i) * 2] = (half_float::half)(s * coef); - ((half_float::half*)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i) * 2 + 1] = (half_float::half)(o * coef); + float s = fabsf(srcZ[2*j+0]); + float b = fabsf(srcZ[2*j+1]); + float temp = ALIMAX(s, b); + if(temp > max_data) { + max_data = temp; + } } } - } else { + }else{ for (int i = 0; i < numAlpha; ++i) { auto srcZ = dequantAlpha + i * mResource->mBlockSize; for(int j = 0; j < mResource->mBlockSize; ++j){ - ((half_float::half*)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i)] = (half_float::half)(srcZ[j] * coef); + float s = fabsf(srcZ[j]); + if(s > max_data) { + max_data = s; + } } } } - } else { - MNN_ERROR("Map error dequantBufferMap == nullptr \n"); - MNN_ASSERT(false); - } - } else{ - if (dequantScaleOffsetBufferMap != nullptr && resScaleOffset == CL_SUCCESS) { - if (quanCommon->asymmetric) { - for (int i = 0; i < numAlpha; ++i) { - auto srcZ = dequantAlpha + i * mResource->mBlockSize * 2; - for(int j = 0; j < mResource->mBlockSize; ++j){ - float o = srcZ[2*j+0]; - float s = srcZ[2*j+1]; - ((float *)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i) * 2] = s * coef; - ((float *)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i) * 2 + 1] = o * coef; + if(abs(max_data) >= 0.000001f){ + coef = 1000.0f / max_data; + } + if (dequantScaleOffsetBufferMap != nullptr && resScaleOffset == CL_SUCCESS) { + if (quanCommon->asymmetric) { + for (int i = 0; i < numAlpha; ++i) { + auto srcZ = dequantAlpha + i * mResource->mBlockSize * 2; + for(int j = 0; j < mResource->mBlockSize; ++j){ + float o = srcZ[2*j+0]; + float s = srcZ[2*j+1]; + ((half_float::half*)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i) * 2] = (half_float::half)(s * 
coef); + ((half_float::half*)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i) * 2 + 1] = (half_float::half)(o * coef); + } + } + } else { + for (int i = 0; i < numAlpha; ++i) { + auto srcZ = dequantAlpha + i * mResource->mBlockSize; + for(int j = 0; j < mResource->mBlockSize; ++j){ + ((half_float::half*)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i)] = (half_float::half)(srcZ[j] * coef); + } } } + float* coefMapPtr = (float*)(((half_float::half*)dequantScaleOffsetBufferMap) + (numAlphaPack * mResource->mBlockSize * soSize)); + coefMapPtr[0] = coef; } else { - for (int i = 0; i < numAlpha; ++i) { - auto srcZ = dequantAlpha + i * mResource->mBlockSize; - for(int j = 0; j < mResource->mBlockSize; ++j){ - ((float *)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i)] = srcZ[j] * coef; + MNN_ERROR("Map error dequantBufferMap == nullptr \n"); + MNN_ASSERT(false); + } + } else{ + if (dequantScaleOffsetBufferMap != nullptr && resScaleOffset == CL_SUCCESS) { + if (quanCommon->asymmetric) { + for (int i = 0; i < numAlpha; ++i) { + auto srcZ = dequantAlpha + i * mResource->mBlockSize * 2; + for(int j = 0; j < mResource->mBlockSize; ++j){ + float o = srcZ[2*j+0]; + float s = srcZ[2*j+1]; + ((float *)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i) * 2] = s * coef; + ((float *)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i) * 2 + 1] = o * coef; + } + } + } else { + for (int i = 0; i < numAlpha; ++i) { + auto srcZ = dequantAlpha + i * mResource->mBlockSize; + for(int j = 0; j < mResource->mBlockSize; ++j){ + ((float *)dequantScaleOffsetBufferMap)[(j * numAlphaPack + i)] = srcZ[j] * coef; + } } } + ((float *)dequantScaleOffsetBufferMap)[(numAlphaPack * mResource->mBlockSize * soSize)] = coef; + } else { + MNN_ERROR("Map error dequantBufferMap == nullptr \n"); + MNN_ASSERT(false); } - } else { - MNN_ERROR("Map error dequantBufferMap == nullptr \n"); - MNN_ASSERT(false); } } mResource->mCoef = coef; - 
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(dequantScaleOffsetBuffer, dequantScaleOffsetBufferMap); + mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mResource->mDequantScaleOffsetBuffer.get(), dequantScaleOffsetBufferMap); // set mFilterDataPtr mFilterDataPtr = (void *)quanCommon->weight.get(); } @@ -203,7 +218,7 @@ void ConvBufLowMemoryExecution::set1x1WeightLowMemory() { } else{ getInfoFromOpLowMemory(nullptr); } - cl_int res; + cl_int res = CL_SUCCESS; std::shared_ptr filterBuffer(Tensor::createDevice({ROUND_UP(mResource->mOutputChannel, PACK_COUT), ROUND_UP(mResource->mInputChannel, PACK_CIN), 1, 1})); size_t buffer_size = filterBuffer->usize() / sizeof(float); size_t cpy_size = mResource->mOutputChannel * mResource->mInputChannel; @@ -216,35 +231,72 @@ void ConvBufLowMemoryExecution::set1x1WeightLowMemory() { } else if(mResource->mNumQuantBit == 8){ actual_packCin /= 2; } else {/* More types to be supported. */} - cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size); - void *mapPtr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); - if(mapPtr != nullptr && res == CL_SUCCESS){ + if(mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ + cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size); + void *mapPtr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); + if(mapPtr != nullptr && res == CL_SUCCESS){ + if(preAllocGpuMem){ + getInfoFromOpLowMemory(mapPtr); + } else{ + ::memcpy(mapPtr, mFilterDataPtr, cpy_size); + } + } else { + MNN_ERROR("set1x1WeightLowMemory: Map error ptrCL == nullptr \n"); + MNN_ASSERT(false); + } + 
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBufferCL, mapPtr); + // Use Image load weights + if(UP_DIV(mResource->mInputChannel, actual_packCin) <= 16384 && ROUND_UP(mResource->mOutputChannel, PACK_COUT) <= 16384){ + mResource->mUseImage = true; + } + auto staticMapAlloc = mOpenCLBackend->getStaticAllocatorMMap(); + if(mResource->mUseImage){ + size_t w = UP_DIV(mResource->mInputChannel, actual_packCin); + size_t h = UP_DIV(mResource->mOutputChannel, PACK_COUT); + if(mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr){ + mResource->mKernelImage = staticMapAlloc.get()->allocImage(w, h, CL_SIGNED_INT32); + }else{ + mResource->mKernelImage.reset(new cl::Image2D(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE, cl::ImageFormat(CL_RGBA, CL_SIGNED_INT32), w, h, 0, nullptr, &res)); + } + if (nullptr == mResource->mKernelImage.get() || res != CL_SUCCESS) { + MNN_ERROR("Alloc Image %d x %d error, code:%d \n", (int)w, (int)h, (int)res); + } + }else{ + if(mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr){ + mResource->mKernelBuffer = staticMapAlloc.get()->allocBuffer(buffer_size); + }else{ + mResource->mKernelBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size)); + } + } + convertToQuantWeight1x1Buffer(filterBufferCL); + }else { if(preAllocGpuMem){ - getInfoFromOpLowMemory(mapPtr); - } else{ - ::memcpy(mapPtr, mFilterDataPtr, cpy_size); + getInfoFromOpLowMemory(nullptr); } - } else { - MNN_ERROR("set1x1WeightLowMemory: Map error ptrCL == nullptr \n"); - MNN_ASSERT(false); - } - mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBufferCL, mapPtr); - - // Use Image load weights - if(UP_DIV(mResource->mInputChannel, actual_packCin) <= 16384 && ROUND_UP(mResource->mOutputChannel, PACK_COUT) <= 16384){ - mResource->mUseImage = true; - } - if(mResource->mUseImage){ - 
size_t w = UP_DIV(mResource->mInputChannel, actual_packCin); - size_t h = UP_DIV(mResource->mOutputChannel, PACK_COUT); - mResource->mKernelImage.reset(new cl::Image2D(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE, cl::ImageFormat(CL_RGBA, CL_SIGNED_INT32), w, h, 0, nullptr, &res)); - if (nullptr == mResource->mKernelImage.get() || res != CL_SUCCESS) { - MNN_ERROR("Alloc Image %d x %d error, code:%d \n", (int)w, (int)h, (int)res); + // Use Image load weights + if(UP_DIV(mResource->mInputChannel, actual_packCin) <= 16384 && ROUND_UP(mResource->mOutputChannel, PACK_COUT) <= 16384){ + mResource->mUseImage = true; + } + auto staticMapAlloc = mOpenCLBackend->getStaticAllocatorMMap(); + if(mResource->mUseImage){ + size_t w = UP_DIV(mResource->mInputChannel, actual_packCin); + size_t h = UP_DIV(mResource->mOutputChannel, PACK_COUT); + if(mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr){ + mResource->mKernelImage = staticMapAlloc.get()->allocImage(w, h, CL_SIGNED_INT32); + }else{ + mResource->mKernelImage.reset(new cl::Image2D(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE, cl::ImageFormat(CL_RGBA, CL_SIGNED_INT32), w, h, 0, nullptr, &res)); + } + if (nullptr == mResource->mKernelImage.get() || res != CL_SUCCESS) { + MNN_ERROR("Alloc Image %d x %d error, code:%d \n", (int)w, (int)h, (int)res); + } + }else{ + if(mOpenCLBackend->getRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr){ + mResource->mKernelBuffer = staticMapAlloc.get()->allocBuffer(buffer_size); + }else{ + mResource->mKernelBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size)); + } } - }else{ - mResource->mKernelBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size)); } - convertToQuantWeight1x1Buffer(filterBufferCL); } // set mFilter for the general kernels void 
ConvBufLowMemoryExecution::setGeneralWeightLowMemory() { @@ -258,47 +310,63 @@ void ConvBufLowMemoryExecution::setGeneralWeightLowMemory() { } else{ getInfoFromOpLowMemory(nullptr); } - std::shared_ptr filterBuffer(Tensor::createDevice({ROUND_UP(mResource->mOutputChannel, 4), mResource->mInputChannel, mResource->mKernelWidth, mResource->mKernelHeight})); - size_t buffer_size = filterBuffer->usize() / sizeof(float); - size_t cpy_size = mResource->mOutputChannel * mResource->mInputChannel * mResource->mKernelWidth * mResource->mKernelHeight; - if (mResource->mNumQuantBit == 4){ - buffer_size /= 2; - cpy_size = UP_DIV(cpy_size, 2); - } - cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size); - filterBuffer->buffer().device = (uint64_t)(&filterBufferCL); - // map and pack data from filterDataPtr - cl_int res; - auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); - if(ptrCL != nullptr && res == CL_SUCCESS) { - if(preAllocGpuMem){ - getInfoFromOpLowMemory(ptrCL); - } else{ - ::memcpy(ptrCL, mFilterDataPtr, cpy_size); + + if(mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ + std::shared_ptr filterBuffer(Tensor::createDevice({ROUND_UP(mResource->mOutputChannel, 4), mResource->mInputChannel, mResource->mKernelWidth, mResource->mKernelHeight})); + size_t buffer_size = filterBuffer->usize() / sizeof(float); + size_t cpy_size = mResource->mOutputChannel * mResource->mInputChannel * mResource->mKernelWidth * mResource->mKernelHeight; + if (mResource->mNumQuantBit == 4){ + buffer_size /= 2; + cpy_size = UP_DIV(cpy_size, 2); } - } else { - MNN_ERROR("setGeneralWeightLowMemory: Map error ptrCL == nullptr \n"); - } - mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBufferCL, ptrCL); - // convert to NC4HW4 - if (mResource->mNumQuantBit == 8) { - // ROUND_UP(IC, 4), 
UP_DIV(OC, 4) * mKernelWidth * mKernelHeight - mResource->mFilter.reset(Tensor::createDevice({1, UP_DIV(mResource->mOutputChannel, 4) * mResource->mKernelWidth * mResource->mKernelHeight, 1, 4 * ROUND_UP(mResource->mInputChannel, 4)})); - mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC); - MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()}; - // filterBuffer shape: {OC, ROUND_UP(IC, 4), mKernelWidth, mKernelHeight} - bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mResource->mFilter.get(), mOpenCLBackend->getPrecision(), false, true, true, mResource->mNumQuantBit); - } else if (mResource->mNumQuantBit == 4){ - // ROUND_UP(IC, 4), UP_DIV(OC, 4) * mKernelWidth * mKernelHeight - // For int4 case, data stored in mFilter should be uint8_t, - // while "Tensor::createDevice" occupies more memory than "Tensor::createDevice". - // Therefore, we use "Tensor::createDevice" currently, leaving "Tensor::createDevice" to be supported. 
- mResource->mFilter.reset(Tensor::createDevice({1, UP_DIV(mResource->mOutputChannel, 4) * mResource->mKernelWidth * mResource->mKernelHeight, 1, 2 * ROUND_UP(mResource->mInputChannel, 4)})); - mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC); + cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size); + filterBuffer->buffer().device = (uint64_t)(&filterBufferCL); + // map and pack data from filterDataPtr + cl_int res; + auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); + if(ptrCL != nullptr && res == CL_SUCCESS) { + if(preAllocGpuMem){ + getInfoFromOpLowMemory(ptrCL); + } else{ + ::memcpy(ptrCL, mFilterDataPtr, cpy_size); + } + } else { + MNN_ERROR("setGeneralWeightLowMemory: Map error ptrCL == nullptr \n"); + } + mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBufferCL, ptrCL); + if (mResource->mNumQuantBit == 8) { + // ROUND_UP(IC, 4), UP_DIV(OC, 4) * mKernelWidth * mKernelHeight + mResource->mFilter.reset(Tensor::createDevice({1, UP_DIV(mResource->mOutputChannel, 4) * mResource->mKernelWidth * mResource->mKernelHeight, 1, 4 * ROUND_UP(mResource->mInputChannel, 4)})); + mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC); + } else if (mResource->mNumQuantBit == 4){ + // ROUND_UP(IC, 4), UP_DIV(OC, 4) * mKernelWidth * mKernelHeight + // For int4 case, data stored in mFilter should be uint8_t, + // while "Tensor::createDevice" occupies more memory than "Tensor::createDevice". + // Therefore, we use "Tensor::createDevice" currently, leaving "Tensor::createDevice" to be supported. 
+ mResource->mFilter.reset(Tensor::createDevice({1, UP_DIV(mResource->mOutputChannel, 4) * mResource->mKernelWidth * mResource->mKernelHeight, 1, 2 * ROUND_UP(mResource->mInputChannel, 4)})); + mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC); + } + // convert to NC4HW4 MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()}; - // filterBuffer shape: {OC, ROUND_UP(IC, 4), mKernelWidth, mKernelHeight} bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mResource->mFilter.get(), mOpenCLBackend->getPrecision(), false, true, true, mResource->mNumQuantBit); - } else {/* More types to be supported. */} + }else{ + if(preAllocGpuMem){ + getInfoFromOpLowMemory(nullptr); + } + if (mResource->mNumQuantBit == 8) { + // ROUND_UP(IC, 4), UP_DIV(OC, 4) * mKernelWidth * mKernelHeight + mResource->mFilter.reset(Tensor::createDevice({1, UP_DIV(mResource->mOutputChannel, 4) * mResource->mKernelWidth * mResource->mKernelHeight, 1, 4 * ROUND_UP(mResource->mInputChannel, 4)})); + mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC); + } else if (mResource->mNumQuantBit == 4){ + // ROUND_UP(IC, 4), UP_DIV(OC, 4) * mKernelWidth * mKernelHeight + // For int4 case, data stored in mFilter should be uint8_t, + // while "Tensor::createDevice" occupies more memory than "Tensor::createDevice". + // Therefore, we use "Tensor::createDevice" currently, leaving "Tensor::createDevice" to be supported. 
+ mResource->mFilter.reset(Tensor::createDevice({1, UP_DIV(mResource->mOutputChannel, 4) * mResource->mKernelWidth * mResource->mKernelHeight, 1, 2 * ROUND_UP(mResource->mInputChannel, 4)})); + mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC); + } + } + } // select the fastest kernel for the general cases by tuning void ConvBufLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor * output) { @@ -355,7 +423,7 @@ void ConvBufLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor ret |= kernel[knl_idx]->get().setArg(idx++, globalWorkSize[knl_idx][1]); ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(input)); ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->mFilter.get())); - ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->dequantScaleOffset.get())); + ret |= kernel[knl_idx]->get().setArg(idx++, *mResource->mDequantScaleOffsetBuffer.get()); ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(output)); ret |= kernel[knl_idx]->get().setArg(idx++, sizeof(inputImageShape), inputImageShape); @@ -402,7 +470,7 @@ void ConvBufLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->mFilter.get())); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->dequantScaleOffset.get())); + ret |= unit.kernel->get().setArg(idx++, *mResource->mDequantScaleOffsetBuffer.get()); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); ret |= unit.kernel->get().setArg(idx++, sizeof(inputImageShape), inputImageShape); @@ -497,7 +565,7 @@ void ConvBufLowMemoryExecution::useFPWeightGemmLowMemory(Tensor * input, Tensor 
}else{ ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelBuffer.get()); } - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->dequantScaleOffset.get())); + ret |= unit.kernel->get().setArg(idx++, *mResource->mDequantScaleOffsetBuffer.get()); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmWeightTensor.get())); ret |= unit.kernel->get().setArg(idx++, static_cast(mResource->mInputChannel)); ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannel4Align)); @@ -650,7 +718,7 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu }else{ ret |= kernel->get().setArg(idx++, *mResource->mKernelBuffer.get()); } - ret |= kernel->get().setArg(idx++, openCLBuffer(mResource->dequantScaleOffset.get())); + ret |= kernel->get().setArg(idx++, *mResource->mDequantScaleOffsetBuffer.get()); ret |= kernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); ret |= kernel->get().setArg(idx++, openCLBuffer(output)); ret |= kernel->get().setArg(idx++, static_cast(outputChannelBlocks)); @@ -686,7 +754,7 @@ void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * outpu }else{ ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelBuffer.get()); } - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->dequantScaleOffset.get())); + ret |= unit.kernel->get().setArg(idx++, *mResource->mDequantScaleOffsetBuffer.get()); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); ret |= unit.kernel->get().setArg(idx++, static_cast(outputChannelBlocks)); @@ -813,7 +881,7 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu }else{ ret |= kernel->get().setArg(idx++, *mResource->mKernelBuffer.get()); } - ret |= kernel->get().setArg(idx++, openCLBuffer(mResource->dequantScaleOffset.get())); + ret |= kernel->get().setArg(idx++, *mResource->mDequantScaleOffsetBuffer.get()); ret |= 
kernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); ret |= kernel->get().setArg(idx++, openCLBuffer(mConvGemmOutTensor.get())); ret |= kernel->get().setArg(idx++, static_cast(outputChannelAlign8)); @@ -848,7 +916,7 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu }else{ ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelBuffer.get()); } - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->dequantScaleOffset.get())); + ret |= unit.kernel->get().setArg(idx++, *mResource->mDequantScaleOffsetBuffer.get()); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmOutTensor.get())); ret |= unit.kernel->get().setArg(idx++, static_cast(outputChannelAlign8)); @@ -900,7 +968,7 @@ void ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * outpu }else{ ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelBuffer.get()); } - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->dequantScaleOffset.get())); + ret |= unit.kernel->get().setArg(idx++, *mResource->mDequantScaleOffsetBuffer.get()); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); ret |= unit.kernel->get().setArg(idx++, static_cast(global_y)); diff --git a/source/backend/opencl/execution/buffer/ConvBufWinograd.cpp b/source/backend/opencl/execution/buffer/ConvBufWinograd.cpp index 3d26b1c6f1..54f320485c 100644 --- a/source/backend/opencl/execution/buffer/ConvBufWinograd.cpp +++ b/source/backend/opencl/execution/buffer/ConvBufWinograd.cpp @@ -113,21 +113,22 @@ ConvBufWinograd::ConvBufWinograd(const MNN::Op* op, Backend* backend) : CommonEx mResource->mBias.reset(Tensor::createDevice({1, 1, 1, (int)ALIGN_UP4(mCo)})); mOpenCLBackend->onAcquireBuffer(mResource->mBias.get(), Backend::STATIC); - cl::Buffer &bias_buffer = *(cl::Buffer 
*)mResource->mBias->buffer().device; - - auto bias_ptr = queue.enqueueMapBuffer(bias_buffer, CL_TRUE, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &ret_code); - if(bias_ptr == nullptr || ret_code) { - MNN_ERROR("clBuffer map error!\n"); - } - ::memset(bias_ptr, 0, buffer_size); - if(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { - for(int i=0; ibias()->data()[i]; + if(mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ + cl::Buffer &bias_buffer = *(cl::Buffer *)mResource->mBias->buffer().device; + auto bias_ptr = queue.enqueueMapBuffer(bias_buffer, CL_TRUE, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &ret_code); + if(bias_ptr == nullptr || ret_code) { + MNN_ERROR("clBuffer map error!\n"); } - } else { - ::memcpy(bias_ptr, conv2D->bias()->data(), mCo*sizeof(float)); + ::memset(bias_ptr, 0, buffer_size); + if(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { + for(int i=0; ibias()->data()[i]; + } + } else { + ::memcpy(bias_ptr, conv2D->bias()->data(), mCo*sizeof(float)); + } + queue.enqueueUnmapMemObject(bias_buffer, bias_ptr); } - queue.enqueueUnmapMemObject(bias_buffer, bias_ptr); auto ocC16 = UP_DIV(mCo, 16); @@ -155,21 +156,22 @@ ConvBufWinograd::ConvBufWinograd(const MNN::Op* op, Backend* backend) : CommonEx cl::Buffer& weightBuffer = *(cl::Buffer*)mResource->mWeight->buffer().device; - auto weight_ptr = - queue.enqueueMapBuffer(weightBuffer, CL_TRUE, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &ret_code); - if (weight_ptr != nullptr && ret_code == CL_SUCCESS) { - if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { - for (int i = 0; i < weightDest->elementSize(); i++) { - ((half_float::half*)weight_ptr)[i] = (half_float::half)(weightDest->host()[i]); + if(mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ + auto weight_ptr = queue.enqueueMapBuffer(weightBuffer, CL_TRUE, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &ret_code); + if (weight_ptr != nullptr && ret_code == CL_SUCCESS) 
{ + if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { + for (int i = 0; i < weightDest->elementSize(); i++) { + ((half_float::half*)weight_ptr)[i] = (half_float::half)(weightDest->host()[i]); + } + } else { + ::memcpy(weight_ptr, weightDest->host(), buffer_size); } } else { - ::memcpy(weight_ptr, weightDest->host(), buffer_size); + MNN_ERROR("Map error weightPtr == nullptr \n"); } - } else { - MNN_ERROR("Map error weightPtr == nullptr \n"); + + queue.enqueueUnmapMemObject(weightBuffer, weight_ptr); } - - queue.enqueueUnmapMemObject(weightBuffer, weight_ptr); }else #endif /* MNN_SUPPORT_INTEL_SUBGROUP */ { @@ -185,21 +187,22 @@ ConvBufWinograd::ConvBufWinograd(const MNN::Op* op, Backend* backend) : CommonEx mResource->mBias.reset(Tensor::createDevice({1, 1, 1, (int)ALIGN_UP4(mCo)})); mOpenCLBackend->onAcquireBuffer(mResource->mBias.get(), Backend::STATIC); - cl::Buffer &bias_buffer = *(cl::Buffer *)mResource->mBias->buffer().device; - - auto bias_ptr = queue.enqueueMapBuffer(bias_buffer, CL_TRUE, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &ret_code); - if(bias_ptr == nullptr || ret_code) { - MNN_ERROR("clBuffer map error!\n"); - } - ::memset(bias_ptr, 0, buffer_size); - if(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { - for(int i=0; ibias()->data()[i]; + if(mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ + cl::Buffer &bias_buffer = *(cl::Buffer *)mResource->mBias->buffer().device; + auto bias_ptr = queue.enqueueMapBuffer(bias_buffer, CL_TRUE, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &ret_code); + if(bias_ptr == nullptr || ret_code) { + MNN_ERROR("clBuffer map error!\n"); } - } else { - ::memcpy(bias_ptr, conv2D->bias()->data(), mCo*sizeof(float)); + ::memset(bias_ptr, 0, buffer_size); + if(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { + for(int i=0; ibias()->data()[i]; + } + } else { + ::memcpy(bias_ptr, conv2D->bias()->data(), mCo*sizeof(float)); + } + 
queue.enqueueUnmapMemObject(bias_buffer, bias_ptr); } - queue.enqueueUnmapMemObject(bias_buffer, bias_ptr); int unit = UNIT; int kernelSize = kx; @@ -223,19 +226,21 @@ ConvBufWinograd::ConvBufWinograd(const MNN::Op* op, Backend* backend) : CommonEx mResource->mWeight.reset(Tensor::createDevice({alpha * alpha * ROUND_UP(mCo, mResource->mAlignN) * ROUND_UP(mCi, mResource->mAlignK)}));//NHWC mOpenCLBackend->onAcquireBuffer(mResource->mWeight.get(), Backend::STATIC); - buffer_size = mCo * mCi * ky * kx * sizeof(float); - cl::Buffer& weightBufferCL = openCLBuffer(tmpFilterTensor.get()); - - cl_int res; - auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(weightBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); - if(ptrCL != nullptr && res == CL_SUCCESS) { - ::memcpy(ptrCL, filterDataPtr, buffer_size); - }else{ - MNN_ERROR("Map weightBufferCL error:%d, ptrCL == nullptr \n", res); + if(mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ + buffer_size = mCo * mCi * ky * kx * sizeof(float); + cl::Buffer& weightBufferCL = openCLBuffer(tmpFilterTensor.get()); + + cl_int res; + auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(weightBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); + if(ptrCL != nullptr && res == CL_SUCCESS) { + ::memcpy(ptrCL, filterDataPtr, buffer_size); + }else{ + MNN_ERROR("Map weightBufferCL error:%d, ptrCL == nullptr \n", res); + } + mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(weightBufferCL, ptrCL); + + convertWeightFormat(weightBufferCL, mResource->mAlignK, mResource->mAlignN); } - mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(weightBufferCL, ptrCL); - - convertWeightFormat(weightBufferCL, mResource->mAlignK, mResource->mAlignN); } } diff --git a/source/backend/opencl/execution/buffer/ConvSubgroupBufExecution.cpp b/source/backend/opencl/execution/buffer/ConvSubgroupBufExecution.cpp index 
c2c26fa1a1..bd61923a5d 100644 --- a/source/backend/opencl/execution/buffer/ConvSubgroupBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/ConvSubgroupBufExecution.cpp @@ -111,7 +111,7 @@ ConvSubgroupBuf::ConvSubgroupBuf(const std::vector &inputs, const std: int weightSize = 0; std::shared_ptr quanCommon; ConvolutionCommon::getConvParameters(&quanCommon, backend, op, &FilterDataPtr, &weightSize); - if (FilterDataPtr != nullptr) { + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1 && FilterDataPtr != nullptr) { std::shared_ptr sourceWeight( Tensor::create(std::vector{mResource->mOutputChannel, mResource->mInputChannel, mResource->mKernelWidth, mResource->mKernelHeight}, (void *)FilterDataPtr, Tensor::CAFFE)); @@ -164,24 +164,25 @@ ConvSubgroupBuf::ConvSubgroupBuf(const std::vector &inputs, const std: cl::Buffer &biasBuffer = openCLBuffer(mResource->mBias.get()); cl_int res; - auto biasPtrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( - biasBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); - if (biasPtrCL != nullptr && res == CL_SUCCESS) { - ::memset(biasPtrCL, 0, buffer_size); - if (nullptr != conv2dParams->bias()) { - const float *biasDataPtr = conv2dParams->bias()->data(); - if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { - for (int i = 0; i < biasSize; i++) { - ((half_float::half *)biasPtrCL)[i] = (half_float::half)(biasDataPtr[i]); + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ + auto biasPtrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(biasBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); + if (biasPtrCL != nullptr && res == CL_SUCCESS) { + ::memset(biasPtrCL, 0, buffer_size); + if (nullptr != conv2dParams->bias()) { + const float *biasDataPtr = conv2dParams->bias()->data(); + if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { + for (int i = 0; i < biasSize; i++) { + ((half_float::half 
*)biasPtrCL)[i] = (half_float::half)(biasDataPtr[i]); + } + } else { + ::memcpy(biasPtrCL, biasDataPtr, biasSize * sizeof(float)); } - } else { - ::memcpy(biasPtrCL, biasDataPtr, biasSize * sizeof(float)); } + } else { + MNN_ERROR("Map error biasPtrCL == nullptr \n"); } - } else { - MNN_ERROR("Map error biasPtrCL == nullptr \n"); + mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(biasBuffer, biasPtrCL); } - mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(biasBuffer, biasPtrCL); } if (mResource->mConv2dCommonParams->relu()) { diff --git a/source/backend/opencl/execution/buffer/DeconvBufExecution.cpp b/source/backend/opencl/execution/buffer/DeconvBufExecution.cpp index 3da5c87cd8..b2410f9615 100644 --- a/source/backend/opencl/execution/buffer/DeconvBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/DeconvBufExecution.cpp @@ -40,32 +40,31 @@ DeconvBufExecution::DeconvBufExecution(const std::vector &inputs, cons int inputChannel = weightSize / (kernelWidth * kernelHeight * outputChannel); std::vector filterShape{outputChannel, inputChannel, kernelHeight, kernelWidth}; std::vector filterImageShape{(int)inputChannel, (int)UP_DIV(outputChannel, 4) * kernelWidth * kernelHeight}; - std::vector filterDataPtrTransformed; - filterDataPtrTransformed.resize(weightSize); - IOHW2OIHW(filterDataPtr, filterDataPtrTransformed.data(), outputChannel, inputChannel, kernelHeight, - kernelWidth); - - std::shared_ptr filterBuffer( - Tensor::createDevice({outputChannel, inputChannel, kernelHeight, kernelWidth})); - - size_t buffer_size = filterBuffer->elementSize() * sizeof(float); - cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_ONLY | CL_MEM_ALLOC_HOST_PTR, buffer_size); - filterBuffer->buffer().device = (uint64_t)(&filterBufferCL); - cl_int error; - auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, 
nullptr, &error); - if(ptrCL != nullptr && error == CL_SUCCESS){ - ::memcpy(ptrCL, filterDataPtrTransformed.data(), filterBuffer->size()); - }else{ - MNN_ERROR("Map error ptrCL == nullptr \n"); - } - mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBufferCL, ptrCL); - mResource->mFilter.reset(Tensor::createDevice({1, filterImageShape[1], 1, 4 * filterImageShape[0]})); mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC); - MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()}; - - bool needTrans = true; - bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mResource->mFilter.get(), mOpenCLBackend->getPrecision(), needTrans); + + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ + std::vector filterDataPtrTransformed; + filterDataPtrTransformed.resize(weightSize); + IOHW2OIHW(filterDataPtr, filterDataPtrTransformed.data(), outputChannel, inputChannel, kernelHeight, kernelWidth); + + std::shared_ptr filterBuffer(Tensor::createDevice({outputChannel, inputChannel, kernelHeight, kernelWidth})); + + size_t buffer_size = filterBuffer->elementSize() * sizeof(float); + cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_ONLY | CL_MEM_ALLOC_HOST_PTR, buffer_size); + filterBuffer->buffer().device = (uint64_t)(&filterBufferCL); + cl_int error; + auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error); + if(ptrCL != nullptr && error == CL_SUCCESS){ + ::memcpy(ptrCL, filterDataPtrTransformed.data(), filterBuffer->size()); + }else{ + MNN_ERROR("Map error ptrCL == nullptr \n"); + } + mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBufferCL, ptrCL); + MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()}; + bool needTrans = true; + 
bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mResource->mFilter.get(), mOpenCLBackend->getPrecision(), needTrans); + } mResource->mBuildOptions.emplace("-DBIAS"); if (conv2dCommonParams->relu() == true) { mResource->mBuildOptions.emplace("-DRELU"); diff --git a/source/backend/opencl/execution/buffer/DepthwiseConvBufExecution.cpp b/source/backend/opencl/execution/buffer/DepthwiseConvBufExecution.cpp index 14f7b590e2..97722f94ac 100644 --- a/source/backend/opencl/execution/buffer/DepthwiseConvBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/DepthwiseConvBufExecution.cpp @@ -28,7 +28,7 @@ DepthwiseConvBufExecution::DepthwiseConvBufExecution(const std::vector int outputChannel = mResource->mConv2dCommonParams->outputCount(); std::vector filterShape{1, outputChannel, kernelHeight, kernelWidth}; - std::vector filterImageShape{(int)kernelHeight * kernelWidth, (int)UP_DIV(outputChannel, 4)}; + int filterImageShape[2] = {(int)kernelHeight * kernelWidth, (int)UP_DIV(outputChannel, 4)}; const float* filterDataPtr = nullptr; @@ -37,26 +37,25 @@ DepthwiseConvBufExecution::DepthwiseConvBufExecution(const std::vector ConvolutionCommon::getConvParameters(&quanCommon, backend, op, &filterDataPtr, &filterDataSize); mResource->mFilter.reset(Tensor::createDevice({1, ROUND_UP(filterImageShape[1], 2)/*for kernel C8 read*/, 1, 4 * filterImageShape[0]})); - std::shared_ptr filterBuffer(Tensor::createDevice(filterShape)); - - size_t buffer_size = filterBuffer->elementSize() * sizeof(float); - cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size); - filterBuffer->buffer().device = (uint64_t)(&filterBufferCL); - cl_int error; - auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error); - if(ptrCL != nullptr && error == CL_SUCCESS){ - ::memcpy(ptrCL, filterDataPtr, 
filterBuffer->size()); - }else{ - MNN_ERROR("Map error ptrCL == nullptr \n"); - } - mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBufferCL, ptrCL); - mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC); - MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()}; - - bool needTrans = true; - bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::DW_CONV2D_FILTER, mResource->mFilter.get(), mOpenCLBackend->getPrecision(), needTrans); + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ + std::shared_ptr filterBuffer(Tensor::createDevice(filterShape)); + size_t buffer_size = filterBuffer->elementSize() * sizeof(float); + cl::Buffer filterBufferCL(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size); + filterBuffer->buffer().device = (uint64_t)(&filterBufferCL); + cl_int error; + auto ptrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(filterBufferCL, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error); + if(ptrCL != nullptr && error == CL_SUCCESS){ + ::memcpy(ptrCL, filterDataPtr, filterBuffer->size()); + }else{ + MNN_ERROR("Map error ptrCL == nullptr \n"); + } + mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBufferCL, ptrCL); + MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()}; + bool needTrans = true; + bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::DW_CONV2D_FILTER, mResource->mFilter.get(), mOpenCLBackend->getPrecision(), needTrans); + } if (mResource->mConv2dCommonParams->relu() == true) { mResource->mBuildOptions.emplace("-DRELU"); } else if (mResource->mConv2dCommonParams->relu6() == true) { diff --git a/source/backend/opencl/execution/buffer/DepthwiseConvSubgroupBufExecution.cpp b/source/backend/opencl/execution/buffer/DepthwiseConvSubgroupBufExecution.cpp index da7f3d1562..a08eea1e53 
100644 --- a/source/backend/opencl/execution/buffer/DepthwiseConvSubgroupBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/DepthwiseConvSubgroupBufExecution.cpp @@ -32,51 +32,53 @@ DepthwiseConvSubgroupBufExecution::DepthwiseConvSubgroupBufExecution(const std:: // create tensor for intel filter mResource->mFilter.reset(Tensor::createDevice(std::vector{1, UP_DIV(outputChannel, 16), kernelWidth * kernelHeight, 16})); auto res = mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC); - cl_int ret_code; - if (!res) { - mValid = false; - return; - } - const float *filterDataPtr = nullptr; - int filterDataSize = 0; - std::shared_ptr quanCommon; - ConvolutionCommon::getConvParameters(&quanCommon, backend, op, &filterDataPtr, &filterDataSize); - if (filterDataPtr != nullptr) { - std::shared_ptr sourceWeight(Tensor::create( - std::vector{1, outputChannel, kernelWidth, kernelHeight}, - (void *)filterDataPtr, Tensor::CAFFE)); - std::shared_ptr destWeight(Tensor::create(std::vector{1, UP_DIV(outputChannel, 16), kernelWidth * kernelHeight, 16})); - - transformWeight(destWeight.get(), sourceWeight.get()); - auto weightDestSize = destWeight->size(); - - auto buffer_size = destWeight->elementSize(); - if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { - buffer_size *= sizeof(half_float::half); - } else { - buffer_size *= sizeof(float); + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ + cl_int ret_code; + if (!res) { + mValid = false; + return; } - - cl::Buffer &weightBuffer = *(cl::Buffer *)mResource->mFilter->buffer().device; - - auto runTime = mOpenCLBackend->getOpenCLRuntime(); - auto queue = runTime->commandQueue(); - - auto weight_ptr = queue.enqueueMapBuffer(weightBuffer, CL_TRUE, CL_MAP_WRITE, 0, buffer_size, nullptr, - nullptr, &ret_code); - if (weight_ptr != nullptr && ret_code == CL_SUCCESS) { + const float *filterDataPtr = nullptr; + int filterDataSize = 0; + std::shared_ptr quanCommon; + 
ConvolutionCommon::getConvParameters(&quanCommon, backend, op, &filterDataPtr, &filterDataSize); + if (filterDataPtr != nullptr) { + std::shared_ptr sourceWeight(Tensor::create( + std::vector{1, outputChannel, kernelWidth, kernelHeight}, + (void *)filterDataPtr, Tensor::CAFFE)); + std::shared_ptr destWeight(Tensor::create(std::vector{1, UP_DIV(outputChannel, 16), kernelWidth * kernelHeight, 16})); + + transformWeight(destWeight.get(), sourceWeight.get()); + auto weightDestSize = destWeight->size(); + + auto buffer_size = destWeight->elementSize(); if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { - for (int i = 0; i < destWeight->elementSize(); i++) { - ((half_float::half *)weight_ptr)[i] = (half_float::half)(destWeight->host()[i]); + buffer_size *= sizeof(half_float::half); + } else { + buffer_size *= sizeof(float); + } + + cl::Buffer &weightBuffer = *(cl::Buffer *)mResource->mFilter->buffer().device; + + auto runTime = mOpenCLBackend->getOpenCLRuntime(); + auto queue = runTime->commandQueue(); + + auto weight_ptr = queue.enqueueMapBuffer(weightBuffer, CL_TRUE, CL_MAP_WRITE, 0, buffer_size, nullptr, + nullptr, &ret_code); + if (weight_ptr != nullptr && ret_code == CL_SUCCESS) { + if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { + for (int i = 0; i < destWeight->elementSize(); i++) { + ((half_float::half *)weight_ptr)[i] = (half_float::half)(destWeight->host()[i]); + } + } else { + ::memcpy(weight_ptr, destWeight->host(), buffer_size); } } else { - ::memcpy(weight_ptr, destWeight->host(), buffer_size); + MNN_ERROR("Map error weightPtr == nullptr \n"); } - } else { - MNN_ERROR("Map error weightPtr == nullptr \n"); + + queue.enqueueUnmapMemObject(weightBuffer, weight_ptr); } - - queue.enqueueUnmapMemObject(weightBuffer, weight_ptr); } } { @@ -91,26 +93,26 @@ DepthwiseConvSubgroupBufExecution::DepthwiseConvSubgroupBufExecution(const std:: mResource->mBias.reset(Tensor::createDevice({1, 1, 1, ROUND_UP(biasSize, 16)})); 
backend->onAcquireBuffer(mResource->mBias.get(), Backend::STATIC); cl::Buffer &biasBuffer = openCLBuffer(mResource->mBias.get()); - - cl_int res; - auto biasPtrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( - biasBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); - if (biasPtrCL != nullptr && res == CL_SUCCESS) { - ::memset(biasPtrCL, 0, buffer_size); - if (nullptr != mResource->mConv2dParams->bias()) { - const float *biasDataPtr = mResource->mConv2dParams->bias()->data(); - if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { - for (int i = 0; i < biasSize; i++) { - ((half_float::half *)biasPtrCL)[i] = (half_float::half)(biasDataPtr[i]); + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ + cl_int res; + auto biasPtrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(biasBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); + if (biasPtrCL != nullptr && res == CL_SUCCESS) { + ::memset(biasPtrCL, 0, buffer_size); + if (nullptr != mResource->mConv2dParams->bias()) { + const float *biasDataPtr = mResource->mConv2dParams->bias()->data(); + if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { + for (int i = 0; i < biasSize; i++) { + ((half_float::half *)biasPtrCL)[i] = (half_float::half)(biasDataPtr[i]); + } + } else { + ::memcpy(biasPtrCL, biasDataPtr, biasSize * sizeof(float)); } - } else { - ::memcpy(biasPtrCL, biasDataPtr, biasSize * sizeof(float)); } + } else { + MNN_ERROR("Map error biasPtrCL == nullptr \n"); } - } else { - MNN_ERROR("Map error biasPtrCL == nullptr \n"); + mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(biasBuffer, biasPtrCL); } - mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(biasBuffer, biasPtrCL); } if (mResource->mConv2dCommonParams->relu() == true) { diff --git a/source/backend/opencl/execution/buffer/GroupNormBufExecution.cpp 
b/source/backend/opencl/execution/buffer/GroupNormBufExecution.cpp index f545e865e6..221f124556 100644 --- a/source/backend/opencl/execution/buffer/GroupNormBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/GroupNormBufExecution.cpp @@ -30,29 +30,6 @@ GroupNormBufExecution::GroupNormBufExecution(const MNN::Op* op, Backend* backend if (!status) { MNN_ERROR("Out of memory when gamma is acquired in GroupNorm.\n"); } - - cl::Buffer &gammaBuffer = openCLBuffer(mGammaTensor.get()); - - cl_int res; - auto GammaPtrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( - gammaBuffer, true, CL_MAP_WRITE, 0, ALIGN_UP4(size) * bufferUnitSize, nullptr, nullptr, &res); - if(GammaPtrCL != nullptr && res == CL_SUCCESS){ - if(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High){ - for (int i = 0; i < size; i++) { - ((half_float::half*)GammaPtrCL)[i] = (half_float::half)(group_norm_param->gamma()->data()[i]); - } - for(int i=size; igamma()->data(), size * sizeof(float)); - } - } else { - MNN_ERROR("GroupNorm Gamma map error:%d\n", res); - } - - if (group_norm_param->beta()->size() != size) { MNN_ERROR("Size of gamma and beta are not match in GroupNorm.\n"); } @@ -61,29 +38,48 @@ GroupNormBufExecution::GroupNormBufExecution(const MNN::Op* op, Backend* backend if (!status) { MNN_ERROR("Out of memory when beta is acquired in GroupNorm.\n"); } - - cl::Buffer &betaBuffer = openCLBuffer(mBetaTensor.get()); - - auto BetaPtrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( - betaBuffer, true, CL_MAP_WRITE, 0, ALIGN_UP4(size) * bufferUnitSize, nullptr, nullptr, &res); - if(BetaPtrCL != nullptr && res == CL_SUCCESS){ - if(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High){ - for (int i = 0; i < size; i++) { - ((half_float::half*)BetaPtrCL)[i] = (half_float::half)(group_norm_param->beta()->data()[i]); + + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ + cl_int res; + cl::Buffer &gammaBuffer = 
openCLBuffer(mGammaTensor.get()); + auto GammaPtrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(gammaBuffer, true, CL_MAP_WRITE, 0, ALIGN_UP4(size) * bufferUnitSize, nullptr, nullptr, &res); + if(GammaPtrCL != nullptr && res == CL_SUCCESS){ + if(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High){ + for (int i = 0; i < size; i++) { + ((half_float::half*)GammaPtrCL)[i] = (half_float::half)(group_norm_param->gamma()->data()[i]); + } + for(int i=size; igamma()->data(), size * sizeof(float)); } - for(int i=size; igetOpenCLRuntime()->commandQueue().enqueueMapBuffer(betaBuffer, true, CL_MAP_WRITE, 0, ALIGN_UP4(size) * bufferUnitSize, nullptr, nullptr, &res); + if(BetaPtrCL != nullptr && res == CL_SUCCESS){ + if(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High){ + for (int i = 0; i < size; i++) { + ((half_float::half*)BetaPtrCL)[i] = (half_float::half)(group_norm_param->beta()->data()[i]); + } + for(int i=size; ibeta()->data(), size * sizeof(float)); } - }else{ - ::memset(BetaPtrCL, 0, ALIGN_UP4(size) * sizeof(float)); - ::memcpy(BetaPtrCL, group_norm_param->beta()->data(), size * sizeof(float)); + } else { + MNN_ERROR("GroupNorm Beta map error:%d\n", res); } - } else { - MNN_ERROR("GroupNorm Beta map error:%d\n", res); + + mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(gammaBuffer, GammaPtrCL); + mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(betaBuffer, BetaPtrCL); } - - mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(gammaBuffer, GammaPtrCL); - mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(betaBuffer, BetaPtrCL); } } diff --git a/source/backend/opencl/execution/buffer/LayerNormBufExecution.cpp b/source/backend/opencl/execution/buffer/LayerNormBufExecution.cpp index f9f6e9f720..cdc0f63a28 100644 --- a/source/backend/opencl/execution/buffer/LayerNormBufExecution.cpp +++ 
b/source/backend/opencl/execution/buffer/LayerNormBufExecution.cpp @@ -39,54 +39,67 @@ LayerNormBufExecution::LayerNormBufExecution(const std::vector &inputs gammasize = layer_norm_param->external()->data()[1] / sizeof(float); } + auto staticMapAlloc = mOpenCLBackend->getStaticAllocatorMMap(); if(mResource->has_gamma_beta_){ { auto error = CL_SUCCESS; int size = gammasize; - mResource->mGammaBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, ALIGN_UP4(size) * bufferUnitSize)); - auto GammaPtrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*(mResource->mGammaBuffer.get()), true, CL_MAP_WRITE, 0, ALIGN_UP4(size) * bufferUnitSize, nullptr, nullptr, &error); - const float* gamma_data = layer_norm_param->gamma()->data(); - if(GammaPtrCL != nullptr && error == CL_SUCCESS){ - if(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High){ - for (int i = 0; i < size; i++) - { - ((half_float::half*)GammaPtrCL)[i] = (half_float::half)(gamma_data[i]); - } - for(int i=size; igetRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr){ + mResource->mGammaBuffer = staticMapAlloc.get()->allocBuffer(ALIGN_UP4(size) * bufferUnitSize); + }else{ + mResource->mGammaBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, ALIGN_UP4(size) * bufferUnitSize)); + } + if(mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ + auto GammaPtrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*(mResource->mGammaBuffer.get()), true, CL_MAP_WRITE, 0, ALIGN_UP4(size) * bufferUnitSize, nullptr, nullptr, &error); + const float* gamma_data = layer_norm_param->gamma()->data(); + if(GammaPtrCL != nullptr && error == CL_SUCCESS){ + if(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High){ + for (int i = 0; i < size; i++) + { + ((half_float::half*)GammaPtrCL)[i] = (half_float::half)(gamma_data[i]); + } + 
for(int i=size; igetOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mResource->mGammaBuffer.get(), GammaPtrCL); } - mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mResource->mGammaBuffer.get(), GammaPtrCL); } { auto error = CL_SUCCESS; int size = gammasize; - mResource->mBetaBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, ALIGN_UP4(size) * bufferUnitSize)); - auto BetaPtrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*(mResource->mBetaBuffer.get()), true, CL_MAP_WRITE, 0, ALIGN_UP4(size) * bufferUnitSize, nullptr, nullptr, &error); - const float* beta_data = layer_norm_param->beta()->data(); - if(BetaPtrCL != nullptr && error == CL_SUCCESS){ - if(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High){ - for (int i = 0; i < size; i++) - { - ((half_float::half*)BetaPtrCL)[i] = (half_float::half)(beta_data[i]); - } - for(int i=size; igetRuntime()->hint().useCachedMmap && staticMapAlloc != nullptr){ + mResource->mBetaBuffer = staticMapAlloc.get()->allocBuffer(ALIGN_UP4(size) * bufferUnitSize); + }else{ + mResource->mBetaBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, ALIGN_UP4(size) * bufferUnitSize)); + } + if(mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ + auto BetaPtrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*(mResource->mBetaBuffer.get()), true, CL_MAP_WRITE, 0, ALIGN_UP4(size) * bufferUnitSize, nullptr, nullptr, &error); + const float* beta_data = layer_norm_param->beta()->data(); + if(BetaPtrCL != nullptr && error == CL_SUCCESS){ + if(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High){ + for (int i = 0; i < size; i++) + { + ((half_float::half*)BetaPtrCL)[i] = (half_float::half)(beta_data[i]); + } + for(int i=size; igetOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mResource->mBetaBuffer.get(), 
BetaPtrCL); } - mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*mResource->mBetaBuffer.get(), BetaPtrCL); } } } diff --git a/source/backend/opencl/execution/buffer/ReluBufExecution.cpp b/source/backend/opencl/execution/buffer/ReluBufExecution.cpp index 40d55c5e9e..0e097cafca 100644 --- a/source/backend/opencl/execution/buffer/ReluBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/ReluBufExecution.cpp @@ -31,24 +31,25 @@ ReluBufExecution::ReluBufExecution(const std::vector &inputs, const MN mOpenCLBackend->onAcquireBuffer(mPreluParam.get(), Backend::STATIC); cl::Buffer &preluBuffer = openCLBuffer(mPreluParam.get()); cl_int error; - auto preluDataPtrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( - preluBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error); - if(preluDataPtrCL != nullptr && error == CL_SUCCESS){ - if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { - for(int i=0; igetRuntime()->hint().useCachedMmap <= 1){ + auto preluDataPtrCL = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(preluBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error); + if(preluDataPtrCL != nullptr && error == CL_SUCCESS){ + if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { + for(int i=0; igetOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(preluBuffer, preluDataPtrCL); } - mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(preluBuffer, preluDataPtrCL); } ReluBufExecution::~ReluBufExecution() { diff --git a/source/backend/opencl/execution/buffer/ScaleBufExecution.cpp b/source/backend/opencl/execution/buffer/ScaleBufExecution.cpp index 4d3c8bd1a0..0c27ec3793 100644 --- a/source/backend/opencl/execution/buffer/ScaleBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/ScaleBufExecution.cpp @@ -35,27 +35,28 @@ ScaleBufExecution::ScaleBufExecution(const std::vector &inputs, const 
mScale.reset(Tensor::createDevice({1, 1, 1, ALIGN_UP4(scaleSize)})); backend->onAcquireBuffer(mScale.get(), Backend::STATIC); - - cl::Buffer &scaleBuffer = openCLBuffer(mScale.get()); - cl_int error; - auto scalePtrCL = openclBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( - scaleBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error); - if(nullptr != scalePtrCL && error == CL_SUCCESS){ - if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { - for (int i = 0; i < scaleSize; i++) { - ((half_float::half *)scalePtrCL)[i] = (half_float::half)(scaleDataPtr[i]); - } - for(int i=scaleSize; igetRuntime()->hint().useCachedMmap <= 1){ + cl::Buffer &scaleBuffer = openCLBuffer(mScale.get()); + cl_int error; + auto scalePtrCL = openclBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(scaleBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error); + if(nullptr != scalePtrCL && error == CL_SUCCESS){ + if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { + for (int i = 0; i < scaleSize; i++) { + ((half_float::half *)scalePtrCL)[i] = (half_float::half)(scaleDataPtr[i]); + } + for(int i=scaleSize; igetOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(scaleBuffer, scalePtrCL); } - openclBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(scaleBuffer, scalePtrCL); std::set buildOptions; if (nullptr != scaleParams->biasData() && nullptr != scaleParams->biasData()->data()) { @@ -72,27 +73,27 @@ ScaleBufExecution::ScaleBufExecution(const std::vector &inputs, const mBias.reset(Tensor::createDevice({1, 1, 1, ALIGN_UP4(biasSize)})); backend->onAcquireBuffer(mBias.get(), Backend::STATIC); - cl::Buffer &biasBuffer = openCLBuffer(mBias.get()); - cl_int error; - auto biasPtrCL = openclBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer( - biasBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error); - if(nullptr != biasPtrCL && error == CL_SUCCESS){ - if 
(mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { - for (int i = 0; i < biasSize; i++) { - ((half_float::half *)biasPtrCL)[i] = (half_float::half)(biasDataPtr[i]); + if (mOpenCLBackend->getRuntime()->hint().useCachedMmap <= 1){ + cl::Buffer &biasBuffer = openCLBuffer(mBias.get()); + cl_int error; + auto biasPtrCL = openclBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(biasBuffer, true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error); + if(nullptr != biasPtrCL && error == CL_SUCCESS){ + if (mOpenCLBackend->getPrecision() != BackendConfig::Precision_High) { + for (int i = 0; i < biasSize; i++) { + ((half_float::half *)biasPtrCL)[i] = (half_float::half)(biasDataPtr[i]); + } + for(int i=biasSize; igetOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(biasBuffer, biasPtrCL); } - openclBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(biasBuffer, biasPtrCL); - mBuildOptions.emplace("-DBIAS"); mHasBias = true; } From 023dbea88172f2dab0cbbdc93658267f5a71027b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8B=A5=E9=81=97?= Date: Tue, 23 Dec 2025 19:56:44 +0800 Subject: [PATCH 094/314] refactor mdoel downloader --- apps/Android/MnnLlmChat/app/build.gradle | 2 + apps/Android/MnnLlmChat/settings.gradle | 3 + apps/frameworks/mnn_tts/CMakeLists.txt | 23 +- .../com/taobao/meta/avatar/tts/TtsService.kt | 5 + apps/frameworks/mnn_tts/demo/android/BUILD.md | 389 +++++++ .../mnn_tts/demo/android/QUICKREF.md | 179 ++++ .../frameworks/mnn_tts/demo/android/README.md | 109 ++ .../demo/android/TTS_INTEGRATION_GUIDE.md | 404 ++++++++ .../mnn_tts/demo/android/build.gradle | 2 + apps/frameworks/mnn_tts/demo/android/build.sh | 181 ++++ .../java/com/mnn/tts/demo/DemoActivity.kt | 89 +- .../java/com/mnn/tts/demo/ModelAdapter.kt | 49 + .../demo/android/res/layout/activity_demo.xml | 17 + .../demo/android/res/layout/item_model.xml | 35 + .../mnn_tts/demo/android/settings.gradle | 2 - .../demo/android/src/main/AndroidManifest.xml | 29 +- 
.../com/alibaba/mnn/tts/demo/MainActivity.kt | 245 ++++- .../com/alibaba/mnn/tts/demo/MnnTtsService.kt | 210 ++++ .../mnn/tts/demo/MnnTtsSettingsActivity.kt | 43 + .../com/alibaba/mnn/tts/demo/ModelAdapter.kt | 205 ++++ .../com/alibaba/mnn/tts/demo/ModelConfig.kt | 6 + .../java/com/mnn/tts/demo/MnnTtsService.kt | 420 ++++++++ .../main/res/drawable/bg_filter_border.xml | 6 + .../src/main/res/drawable/bg_play_button.xml | 4 + .../main/res/drawable/bg_voice_spinner.xml | 7 + .../src/main/res/layout/activity_main.xml | 108 +- .../main/res/layout/activity_tts_settings.xml | 30 + .../src/main/res/layout/item_model.xml | 75 ++ .../src/main/res/layout/spinner_item_dark.xml | 8 + .../android/src/main/res/values/colors.xml | 6 +- .../android/src/main/res/values/strings.xml | 6 +- .../android/src/main/res/values/themes.xml | 8 +- .../android/src/main/res/xml/tts_engine.xml | 3 + .../mnn_tts/include/mnn_tts_config.hpp | 7 + .../mnn_tts/include/mnn_tts_sdk.hpp | 9 +- .../supertonic/mnn_supertonic_tts_impl.hpp | 125 +++ .../mnn_tts/src/android/tts_service.cpp | 6 + .../mnn_tts/src/android/tts_service.hpp | 1 + .../mnn_tts/src/android/tts_service_jni.cpp | 12 + .../frameworks/mnn_tts/src/mnn_tts_config.cpp | 31 + apps/frameworks/mnn_tts/src/mnn_tts_sdk.cpp | 83 +- .../supertonic/mnn_supertonic_tts_impl.cpp | 950 ++++++++++++++++++ apps/frameworks/mnn_tts/tests/test_main.cpp | 11 + .../model_downloader/android/build.gradle | 54 + .../android/src/main/AndroidManifest.xml | 15 + .../api/download/DownloadCoroutineManager.kt | 0 .../mls/api/download/DownloadExecutor.kt | 0 .../mls/api/download/DownloadFileUtils.kt | 0 .../api/download/DownloadForegroundService.kt | 0 .../alibaba/mls/api/download/DownloadInfo.kt | 0 .../mls/api/download/DownloadListener.kt | 0 .../api/download/DownloadPausedException.kt | 0 .../api/download/DownloadPersistentData.kt | 0 .../alibaba/mls/api/download/DownloadState.kt | 0 .../mls/api/download/FileDownloadTask.kt | 0 
.../mls/api/download/ModelDownloadManager.kt | 0 .../mls/api/download/ModelFileDownloader.kt | 0 .../mls/api/download/ModelRepoDownloader.kt | 0 .../api/download/hf/HfFileMetadataUtils.kt | 0 .../mls/api/download/hf/HfModelDownloader.kt | 0 .../mls/api/download/hf/HfShaVerifier.kt | 0 .../mls/api/download/ml/MLModelDownloader.kt | 0 .../mls/api/download/ms/MsModelDownloader.kt | 0 .../model_downloader/cpp/CMakeLists.txt | 96 ++ .../cpp/include/dl_config.hpp | 30 + .../cpp}/include/file_utils.hpp | 8 +- .../cpp}/include/hf_api_client.hpp | 2 +- .../cpp}/include/hf_file_metadata.hpp | 4 +- .../cpp}/include/hf_file_metadata_utils.hpp | 4 +- .../cpp}/include/hf_model_downloader.hpp | 4 +- .../cpp}/include/hf_sha_verifier.hpp | 4 +- .../cpp}/include/log_utils.hpp | 38 +- .../cpp}/include/ml_api_client.hpp | 4 +- .../cpp}/include/ml_model_downloader.hpp | 4 +- .../cpp}/include/model_download_manager.hpp | 4 +- .../cpp}/include/model_file_downloader.hpp | 2 +- .../cpp}/include/model_market_data.hpp | 4 +- .../cpp}/include/model_name_utils.hpp | 10 +- .../cpp}/include/model_repo_downloader.hpp | 4 +- .../cpp}/include/model_sources.hpp | 4 +- .../cpp}/include/ms_api_client.hpp | 4 +- .../cpp}/include/ms_model_downloader.hpp | 4 +- .../model_downloader/cpp}/src/file_utils.cpp | 24 +- .../cpp}/src/hf_api_client.cpp | 13 +- .../cpp}/src/hf_file_metadata_utils.cpp | 4 +- .../cpp}/src/hf_model_downloader.cpp | 9 +- .../cpp}/src/hf_sha_verifier.cpp | 4 +- .../model_downloader/cpp}/src/log_utils.cpp | 4 +- .../cpp}/src/ml_api_client.cpp | 10 +- .../cpp}/src/ml_model_downloader.cpp | 4 +- .../cpp}/src/model_download_manager.cpp | 4 +- .../cpp}/src/model_file_downloader.cpp | 2 +- .../cpp}/src/model_market_data.cpp | 4 +- .../cpp}/src/model_name_utils.cpp | 18 +- .../cpp}/src/model_repo_downloader.cpp | 4 +- .../cpp}/src/model_sources.cpp | 4 +- .../cpp}/src/ms_api_client.cpp | 11 +- .../cpp}/src/ms_model_downloader.cpp | 50 +- apps/mnncli/CMakeLists.txt | 209 +--- 
apps/mnncli/build.sh | 4 +- apps/mnncli/include/cli_download_listener.hpp | 4 +- apps/mnncli/include/mnncli_server.hpp | 2 +- apps/mnncli/include/model_repository.hpp | 6 +- apps/mnncli/include/user_interface.hpp | 4 + apps/mnncli/src/cli_config_manager.cpp | 4 +- apps/mnncli/src/cli_download_listener.cpp | 27 +- apps/mnncli/src/file_utils_config.cpp | 20 +- .../handlers/benchmark_command_handler.cpp | 4 + .../src/handlers/config_command_handler.cpp | 4 + .../src/handlers/delete_command_handler.cpp | 4 + .../src/handlers/download_command_handler.cpp | 4 + .../src/handlers/info_command_handler.cpp | 4 + .../src/handlers/list_command_handler.cpp | 4 + .../handlers/model_info_command_handler.cpp | 4 + .../src/handlers/run_command_handler.cpp | 4 + .../src/handlers/search_command_handler.cpp | 4 + .../src/handlers/serve_command_handler.cpp | 4 + apps/mnncli/src/local_model_utils.cpp | 17 +- apps/mnncli/src/mnncli.cpp | 4 +- apps/mnncli/src/model_manager.cpp | 41 +- apps/mnncli/src/model_repository.cpp | 6 +- 121 files changed, 4564 insertions(+), 439 deletions(-) create mode 100644 apps/frameworks/mnn_tts/demo/android/BUILD.md create mode 100644 apps/frameworks/mnn_tts/demo/android/QUICKREF.md create mode 100644 apps/frameworks/mnn_tts/demo/android/README.md create mode 100644 apps/frameworks/mnn_tts/demo/android/TTS_INTEGRATION_GUIDE.md create mode 100755 apps/frameworks/mnn_tts/demo/android/build.sh create mode 100644 apps/frameworks/mnn_tts/demo/android/java/com/mnn/tts/demo/ModelAdapter.kt create mode 100644 apps/frameworks/mnn_tts/demo/android/res/layout/item_model.xml create mode 100644 apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/MnnTtsService.kt create mode 100644 apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/MnnTtsSettingsActivity.kt create mode 100644 apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/ModelAdapter.kt create mode 100644 
apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/ModelConfig.kt create mode 100644 apps/frameworks/mnn_tts/demo/android/src/main/java/com/mnn/tts/demo/MnnTtsService.kt create mode 100644 apps/frameworks/mnn_tts/demo/android/src/main/res/drawable/bg_filter_border.xml create mode 100644 apps/frameworks/mnn_tts/demo/android/src/main/res/drawable/bg_play_button.xml create mode 100644 apps/frameworks/mnn_tts/demo/android/src/main/res/drawable/bg_voice_spinner.xml create mode 100644 apps/frameworks/mnn_tts/demo/android/src/main/res/layout/activity_tts_settings.xml create mode 100644 apps/frameworks/mnn_tts/demo/android/src/main/res/layout/item_model.xml create mode 100644 apps/frameworks/mnn_tts/demo/android/src/main/res/layout/spinner_item_dark.xml create mode 100644 apps/frameworks/mnn_tts/demo/android/src/main/res/xml/tts_engine.xml create mode 100644 apps/frameworks/mnn_tts/include/supertonic/mnn_supertonic_tts_impl.hpp create mode 100644 apps/frameworks/mnn_tts/src/supertonic/mnn_supertonic_tts_impl.cpp create mode 100644 apps/frameworks/model_downloader/android/build.gradle create mode 100644 apps/frameworks/model_downloader/android/src/main/AndroidManifest.xml rename apps/{Android/MnnLlmChat/app => frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/DownloadCoroutineManager.kt (100%) rename apps/{Android/MnnLlmChat/app => frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/DownloadExecutor.kt (100%) rename apps/{Android/MnnLlmChat/app => frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/DownloadFileUtils.kt (100%) rename apps/{Android/MnnLlmChat/app => frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/DownloadForegroundService.kt (100%) rename apps/{Android/MnnLlmChat/app => frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/DownloadInfo.kt (100%) rename apps/{Android/MnnLlmChat/app => 
frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/DownloadListener.kt (100%) rename apps/{Android/MnnLlmChat/app => frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/DownloadPausedException.kt (100%) rename apps/{Android/MnnLlmChat/app => frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/DownloadPersistentData.kt (100%) rename apps/{Android/MnnLlmChat/app => frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/DownloadState.kt (100%) rename apps/{Android/MnnLlmChat/app => frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/FileDownloadTask.kt (100%) rename apps/{Android/MnnLlmChat/app => frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/ModelDownloadManager.kt (100%) rename apps/{Android/MnnLlmChat/app => frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/ModelFileDownloader.kt (100%) rename apps/{Android/MnnLlmChat/app => frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/ModelRepoDownloader.kt (100%) rename apps/{Android/MnnLlmChat/app => frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/hf/HfFileMetadataUtils.kt (100%) rename apps/{Android/MnnLlmChat/app => frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/hf/HfModelDownloader.kt (100%) rename apps/{Android/MnnLlmChat/app => frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/hf/HfShaVerifier.kt (100%) rename apps/{Android/MnnLlmChat/app => frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/ml/MLModelDownloader.kt (100%) rename apps/{Android/MnnLlmChat/app => frameworks/model_downloader/android}/src/main/java/com/alibaba/mls/api/download/ms/MsModelDownloader.kt (100%) create mode 100644 apps/frameworks/model_downloader/cpp/CMakeLists.txt create mode 100644 
apps/frameworks/model_downloader/cpp/include/dl_config.hpp rename apps/{mnncli => frameworks/model_downloader/cpp}/include/file_utils.hpp (92%) rename apps/{mnncli => frameworks/model_downloader/cpp}/include/hf_api_client.hpp (97%) rename apps/{mnncli => frameworks/model_downloader/cpp}/include/hf_file_metadata.hpp (91%) rename apps/{mnncli => frameworks/model_downloader/cpp}/include/hf_file_metadata_utils.hpp (97%) rename apps/{mnncli => frameworks/model_downloader/cpp}/include/hf_model_downloader.hpp (98%) rename apps/{mnncli => frameworks/model_downloader/cpp}/include/hf_sha_verifier.hpp (96%) rename apps/{mnncli => frameworks/model_downloader/cpp}/include/log_utils.hpp (70%) rename apps/{mnncli => frameworks/model_downloader/cpp}/include/ml_api_client.hpp (96%) rename apps/{mnncli => frameworks/model_downloader/cpp}/include/ml_model_downloader.hpp (98%) rename apps/{mnncli => frameworks/model_downloader/cpp}/include/model_download_manager.hpp (98%) rename apps/{mnncli => frameworks/model_downloader/cpp}/include/model_file_downloader.hpp (99%) rename apps/{mnncli => frameworks/model_downloader/cpp}/include/model_market_data.hpp (73%) rename apps/{mnncli => frameworks/model_downloader/cpp}/include/model_name_utils.hpp (71%) rename apps/{mnncli => frameworks/model_downloader/cpp}/include/model_repo_downloader.hpp (98%) rename apps/{mnncli => frameworks/model_downloader/cpp}/include/model_sources.hpp (94%) rename apps/{mnncli => frameworks/model_downloader/cpp}/include/ms_api_client.hpp (96%) rename apps/{mnncli => frameworks/model_downloader/cpp}/include/ms_model_downloader.hpp (97%) rename apps/{mnncli => frameworks/model_downloader/cpp}/src/file_utils.cpp (91%) rename apps/{mnncli => frameworks/model_downloader/cpp}/src/hf_api_client.cpp (98%) rename apps/{mnncli => frameworks/model_downloader/cpp}/src/hf_file_metadata_utils.cpp (99%) rename apps/{mnncli => frameworks/model_downloader/cpp}/src/hf_model_downloader.cpp (99%) rename apps/{mnncli => 
frameworks/model_downloader/cpp}/src/hf_sha_verifier.cpp (99%) rename apps/{mnncli => frameworks/model_downloader/cpp}/src/log_utils.cpp (98%) rename apps/{mnncli => frameworks/model_downloader/cpp}/src/ml_api_client.cpp (96%) rename apps/{mnncli => frameworks/model_downloader/cpp}/src/ml_model_downloader.cpp (99%) rename apps/{mnncli => frameworks/model_downloader/cpp}/src/model_download_manager.cpp (99%) rename apps/{mnncli => frameworks/model_downloader/cpp}/src/model_file_downloader.cpp (99%) rename apps/{mnncli => frameworks/model_downloader/cpp}/src/model_market_data.cpp (86%) rename apps/{mnncli => frameworks/model_downloader/cpp}/src/model_name_utils.cpp (93%) rename apps/{mnncli => frameworks/model_downloader/cpp}/src/model_repo_downloader.cpp (99%) rename apps/{mnncli => frameworks/model_downloader/cpp}/src/model_sources.cpp (97%) rename apps/{mnncli => frameworks/model_downloader/cpp}/src/ms_api_client.cpp (96%) rename apps/{mnncli => frameworks/model_downloader/cpp}/src/ms_model_downloader.cpp (94%) diff --git a/apps/Android/MnnLlmChat/app/build.gradle b/apps/Android/MnnLlmChat/app/build.gradle index 3ba3cd16c8..ab40bcefd3 100644 --- a/apps/Android/MnnLlmChat/app/build.gradle +++ b/apps/Android/MnnLlmChat/app/build.gradle @@ -164,6 +164,8 @@ dependencies { // MNN TTS Framework implementation project(':mnn_tts') + implementation project(':model_downloader') + androidTestImplementation 'androidx.test.ext:junit:1.2.1' androidTestImplementation 'androidx.test.espresso:espresso-core:3.6.1' diff --git a/apps/Android/MnnLlmChat/settings.gradle b/apps/Android/MnnLlmChat/settings.gradle index ac77acaf29..91bf25974e 100644 --- a/apps/Android/MnnLlmChat/settings.gradle +++ b/apps/Android/MnnLlmChat/settings.gradle @@ -19,3 +19,6 @@ rootProject.name = "MnnLlmChat" include ':app' include ':mnn_tts' project(':mnn_tts').projectDir = new File('../../frameworks/mnn_tts/android') +include ':model_downloader' +project(':model_downloader').projectDir = new 
File('../../frameworks/model_downloader/android') + diff --git a/apps/frameworks/mnn_tts/CMakeLists.txt b/apps/frameworks/mnn_tts/CMakeLists.txt index aef81fc3c1..5027dca996 100644 --- a/apps/frameworks/mnn_tts/CMakeLists.txt +++ b/apps/frameworks/mnn_tts/CMakeLists.txt @@ -7,6 +7,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) option(BUILD_BERTVITS2 "Build BertVit2 TTS " ON) option(BUILD_PIPER "Build PIPER TTS " OFF) +option(BUILD_SUPERTONIC "Build Supertonic TTS " ON) option(BUILD_ANDROID "Build for Android" OFF) if(ANDROID) @@ -40,11 +41,17 @@ endif() if(BUILD_PIPER) include_directories( - ${CMAKE_CURRENT_LIST_DIR}/include/piper + ${CMAKE_CURRENT_LIST_DIR}/include/piper ${CMAKE_CURRENT_LIST_DIR}/third_party/piper/espeak-ng/src/include/espeak-ng/ ) endif() +if(BUILD_SUPERTONIC) + include_directories( + ${CMAKE_CURRENT_LIST_DIR}/include/supertonic + ) +endif() + set(SHARED_SOURCE_FILES ${CMAKE_CURRENT_LIST_DIR}/src/mnn_tts_config.cpp ${CMAKE_CURRENT_LIST_DIR}/src/mnn_tts_sdk.cpp @@ -74,6 +81,12 @@ if(BUILD_PIPER) ) endif() +if(BUILD_SUPERTONIC) + set(SUPERTONIC_SOURCE_FILES + ${CMAKE_CURRENT_LIST_DIR}/src/supertonic/mnn_supertonic_tts_impl.cpp + ) +endif() + if(BUILD_ANDROID) set(ANDROID_SOURCE_FILES ${CMAKE_CURRENT_LIST_DIR}/src/android/tts_service.cpp @@ -85,11 +98,15 @@ if(BUILD_PIPER) add_subdirectory(third_party/piper/espeak-ng) endif() -add_library(${PROJECT_NAME} SHARED ${PIPER_SOURCE_FILES} ${BERTVITS2_SOURCE_FILES} ${SHARED_SOURCE_FILES} ${ANDROID_SOURCE_FILES}) +add_library(${PROJECT_NAME} SHARED ${PIPER_SOURCE_FILES} ${BERTVITS2_SOURCE_FILES} ${SUPERTONIC_SOURCE_FILES} ${SHARED_SOURCE_FILES} ${ANDROID_SOURCE_FILES}) # Add 16KB page size support for Android if(BUILD_ANDROID) target_link_options(${PROJECT_NAME} PRIVATE "-Wl,-z,max-page-size=16384") endif() -target_link_libraries(${PROJECT_NAME} log MNN ) \ No newline at end of file +if(BUILD_ANDROID) + target_link_libraries(${PROJECT_NAME} log MNN) +else() + target_link_libraries(${PROJECT_NAME} MNN) +endif() \ 
No newline at end of file diff --git a/apps/frameworks/mnn_tts/android/java/com/taobao/meta/avatar/tts/TtsService.kt b/apps/frameworks/mnn_tts/android/java/com/taobao/meta/avatar/tts/TtsService.kt index 1837cd6ea6..8e48549c5b 100644 --- a/apps/frameworks/mnn_tts/android/java/com/taobao/meta/avatar/tts/TtsService.kt +++ b/apps/frameworks/mnn_tts/android/java/com/taobao/meta/avatar/tts/TtsService.kt @@ -56,6 +56,10 @@ class TtsService { return nativeProcess(ttsServiceNative, text, id) } + fun setSpeakerId(speakerId: String) { + nativeSetSpeakerId(ttsServiceNative, speakerId) + } + fun setLanguage(language: String) { if (currentLanguage != language) { currentLanguage = language @@ -74,6 +78,7 @@ class TtsService { resourceDir: String, modelName:String, mmapDir:String): Boolean + private external fun nativeSetSpeakerId(nativePtr: Long, speakerId: String) private external fun nativeProcess(nativePtr: Long, text: String, id: Int): ShortArray companion object { diff --git a/apps/frameworks/mnn_tts/demo/android/BUILD.md b/apps/frameworks/mnn_tts/demo/android/BUILD.md new file mode 100644 index 0000000000..85630c1616 --- /dev/null +++ b/apps/frameworks/mnn_tts/demo/android/BUILD.md @@ -0,0 +1,389 @@ +# MNN TTS Android Demo 构建文档 + +## 项目概述 + +MNN TTS Android Demo 是基于 MNN (Mobile Neural Network) 框架的文本转语音 (Text-to-Speech) 演示应用。该应用展示了如何在 Android 平台上使用 MNN TTS SDK 进行语音合成。 + +## 项目结构 + +``` +mnn_tts/ +├── android/ # MNN TTS Android 库模块 +│ ├── build.gradle # 库模块构建配置 +│ ├── java/ # Java/Kotlin 源代码 +│ └── src/ # 原生 C++ 源代码 +├── demo/android/ # Android Demo 应用 +│ ├── build.gradle # 应用构建配置 +│ ├── settings.gradle # Gradle 项目设置 +│ ├── src/ # 应用源代码 +│ │ └── main/ +│ │ ├── java/ # Kotlin 源代码 +│ │ └── res/ # Android 资源文件 +│ └── build/ # 构建输出目录 +├── include/ # C++ 头文件 +├── src/ # C++ 源代码实现 +└── CMakeLists.txt # CMake 构建配置 +``` + +## 前置要求 + +### 必需的软件和工具 + +1. **Android Studio** (推荐版本: Arctic Fox 或更高) + - 下载地址: https://developer.android.com/studio + +2. 
**Android SDK** + - Compile SDK: 35 + - Min SDK: 21 (Android 5.0) + - Target SDK: 35 + - Build Tools: 最新版本 + +3. **Android NDK** + - 版本: 27.2.12479018 (推荐) + - NDK 用于编译 C++ 代码 + +4. **Java Development Kit (JDK)** + - 版本: JDK 17 或更高 + - 用于 Gradle 构建 + +5. **Gradle** + - 版本: 8.9 (通过 Gradle Wrapper 自动管理) + +6. **CMake** + - 版本: 3.22.1 或更高 + - 用于构建原生 C++ 代码 + +### 依赖的 MNN 库 + +项目依赖于预编译的 MNN 静态库,位置: +``` +/Users/songjinde/git/MNNX/MNN/project/android/build_64/lib/libMNN.so +``` + +如果该库不存在,需要先构建 MNN 核心库: +```bash +cd /Users/songjinde/git/MNNX/MNN/project/android +./build_64.sh +``` + +## 构建步骤 + +### 方法 1: 使用 Gradle 命令行 (推荐) + +1. **进入项目目录** + ```bash + cd /Users/songjinde/git/MNNX/MNN/apps/frameworks/mnn_tts/demo/android + ``` + +2. **清理之前的构建 (可选)** + ```bash + ./gradlew clean + ``` + +3. **构建 Debug APK** + ```bash + ./gradlew assembleDebug + ``` + +4. **构建 Release APK** + ```bash + ./gradlew assembleRelease + ``` + +5. **查看构建输出** + ```bash + ls -lh build/outputs/apk/debug/ + ``` + + 生成的 APK 文件: + - **Debug**: `build/outputs/apk/debug/MNNTTSDemo-arm64-v8a-debug.apk` + - **Release**: `build/outputs/apk/release/MNNTTSDemo-arm64-v8a-release.apk` + +### 方法 2: 使用 Android Studio + +1. **打开项目** + - 启动 Android Studio + - 选择 "Open an Existing Project" + - 导航到 `/Users/songjinde/git/MNNX/MNN/apps/frameworks/mnn_tts/demo/android` + - 点击 "OK" + +2. **Gradle 同步** + - Android Studio 会自动开始 Gradle 同步 + - 如果没有自动同步,点击 "File" > "Sync Project with Gradle Files" + +3. **配置构建变体** + - 在左下角选择 "Build Variants" + - 选择 "debug" 或 "release" + +4. **构建 APK** + - 点击 "Build" > "Build Bundle(s) / APK(s)" > "Build APK(s)" + - 或者使用快捷键: Ctrl+Shift+A (Windows/Linux) 或 Cmd+Shift+A (Mac) + +5. 
**查看构建结果** + - 构建成功后会显示通知 + - 点击 "locate" 查看 APK 文件位置 + +## 构建配置说明 + +### 应用配置 (demo/android/build.gradle) + +```gradle +android { + namespace 'com.alibaba.mnn.tts.demo' + compileSdk 35 // 编译 SDK 版本 + + defaultConfig { + applicationId "com.alibaba.mnn.tts.demo" + minSdk 21 // 最低支持 Android 5.0 + targetSdk 35 // 目标 SDK + versionCode 1 // 应用版本号 + versionName "1.0" // 应用版本名称 + } + + splits { + abi { + enable true + reset() + include 'arm64-v8a' // 仅构建 ARM64 版本 + universalApk false // 不生成通用 APK + } + } +} +``` + +### 库配置 (android/build.gradle) + +```gradle +android { + namespace 'com.alibaba.mnn.tts' + compileSdk 34 + ndkVersion "27.2.12479018" // NDK 版本 + + externalNativeBuild { + cmake { + path file('../CMakeLists.txt') // CMake 配置文件 + version '3.22.1' // CMake 版本 + } + } +} +``` + +### CMake 配置 (CMakeLists.txt) + +关键配置选项: +- `BUILD_BERTVITS2`: 构建 BertVits2 TTS (默认 ON) +- `BUILD_PIPER`: 构建 PIPER TTS (默认 OFF) +- `BUILD_SUPERTONIC`: 构建 Supertonic TTS (默认 ON) +- `BUILD_ANDROID`: Android 平台标志 (自动检测) + +## 依赖库说明 + +### Android 依赖 + +```gradle +dependencies { + implementation project(':mnn_tts') // MNN TTS 库 + implementation 'androidx.appcompat:appcompat:1.6.1' + implementation 'com.google.android.material:material:1.10.0' + implementation 'androidx.constraintlayout:constraintlayout:2.1.4' + implementation 'androidx.lifecycle:lifecycle-runtime-ktx:2.7.0' + implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3' + implementation 'androidx.core:core-ktx:1.16.0' + implementation 'androidx.recyclerview:recyclerview:1.3.2' + implementation 'androidx.cardview:cardview:1.0.0' +} +``` + +### 原生库 + +- **libMNN.so**: MNN 核心推理引擎 +- **libmnn_tts.so**: MNN TTS SDK 实现 +- **libc++_shared.so**: C++ 标准库 + +## 安装和运行 + +### 安装到设备 + +1. **使用 Gradle 命令** + ```bash + ./gradlew installDebug + ``` + +2. **使用 adb 命令** + ```bash + adb install build/outputs/apk/debug/MNNTTSDemo-arm64-v8a-debug.apk + ``` + +3. 
**使用 Android Studio** + - 点击工具栏的 "Run" 按钮 (绿色三角形) + - 选择目标设备 + - 应用会自动安装并启动 + +### 运行应用 + +1. **启动应用** + - 在设备上找到 "MNNTTSDemo" 应用图标 + - 点击启动 + +2. **使用 adb 启动** + ```bash + adb shell am start -n com.alibaba.mnn.tts.demo/.MainActivity + ``` + +## 常见问题和解决方案 + +### 1. NDK 未找到 + +**错误信息**: NDK not configured + +**解决方案**: +```bash +# 在 local.properties 中配置 NDK 路径 +echo "ndk.dir=/Users/songjinde/Library/Android/sdk/ndk/27.2.12479018" >> local.properties +``` + +### 2. MNN 库未找到 + +**错误信息**: libMNN.so not found + +**解决方案**: +```bash +# 先构建 MNN 核心库 +cd /Users/songjinde/git/MNNX/MNN/project/android +./build_64.sh +``` + +### 3. Gradle 同步失败 + +**错误信息**: Failed to sync Gradle project + +**解决方案**: +```bash +# 清理 Gradle 缓存 +./gradlew clean +rm -rf .gradle +./gradlew build --refresh-dependencies +``` + +### 4. CMake 构建失败 + +**错误信息**: CMake build failed + +**解决方案**: +- 检查 NDK 版本是否正确 +- 确保 CMake 版本 >= 3.22.1 +- 检查 MNN 库是否存在 + +### 5. ABI 不匹配 + +**错误信息**: INSTALL_FAILED_NO_MATCHING_ABIS + +**解决方案**: +- 应用仅支持 ARM64 (arm64-v8a) 设备 +- 确保测试设备是 ARM64 架构 +- 或修改 build.gradle 添加其他 ABI 支持 + +## 性能优化建议 + +### Release 构建优化 + +1. **启用代码混淆** + ```gradle + buildTypes { + release { + minifyEnabled true + proguardFiles getDefaultProguardFile('proguard-android-optimize.txt') + } + } + ``` + +2. **启用资源缩减** + ```gradle + buildTypes { + release { + shrinkResources true + } + } + ``` + +3. **使用 Release NDK 构建** + - Release 构建会自动使用优化的原生库 + +### 运行时优化 + +1. **模型加载**: 首次加载模型时间较长,建议使用异步加载 +2. **内存管理**: 及时释放不再使用的模型资源 +3. **线程池**: 使用合理的线程数量进行推理 + +## 技术架构 + +### 应用架构 + +``` +MainActivity.kt +├── ModelAdapter.kt # 模型列表适配器 +├── AudioChunksPlayer.kt # 音频播放器 +└── MNN TTS SDK + ├── BertVits2 TTS # BertVits2 语音合成 + ├── Supertonic TTS # Supertonic 语音合成 + └── MNN Engine # MNN 推理引擎 +``` + +### 关键功能 + +1. **文本转语音**: 输入文本,生成语音音频 +2. **模型管理**: 支持多种 TTS 模型切换 +3. **音频播放**: 实时播放生成的语音 +4. 
**性能监控**: 显示推理时间和资源使用 + +## 调试技巧 + +### 查看日志 + +```bash +# 查看应用日志 +adb logcat -s MNN_TTS:* AndroidRuntime:E + +# 查看原生日志 +adb logcat -s DEBUG:* native:* +``` + +### 性能分析 + +1. **使用 Android Profiler** + - 在 Android Studio 中打开 "View" > "Tool Windows" > "Profiler" + - 监控 CPU、内存和网络使用 + +2. **使用 Systrace** + ```bash + python systrace.py -t 10 -o trace.html sched freq idle + ``` + +## 参考资源 + +- **MNN 官方文档**: https://www.yuque.com/mnn/cn +- **MNN GitHub**: https://github.com/alibaba/MNN +- **Android 开发指南**: https://developer.android.com/guide +- **NDK 开发指南**: https://developer.android.com/ndk + +## 版本信息 + +- **应用版本**: 1.0 +- **MNN 版本**: Latest +- **最低 Android 版本**: 5.0 (API 21) +- **目标 Android 版本**: 14.0 (API 35) +- **支持的架构**: ARM64 (arm64-v8a) + +## 许可证 + +本项目遵循 MNN 项目的许可证条款。 + +## 联系方式 + +如有问题或建议,请联系 MNN 项目维护者或提交 Issue。 + +--- + +**最后更新**: 2025-12-21 +**构建状态**: ✅ 成功 +**生成的 APK**: `build/outputs/apk/debug/MNNTTSDemo-arm64-v8a-debug.apk` (15 MB) diff --git a/apps/frameworks/mnn_tts/demo/android/QUICKREF.md b/apps/frameworks/mnn_tts/demo/android/QUICKREF.md new file mode 100644 index 0000000000..7561b10c11 --- /dev/null +++ b/apps/frameworks/mnn_tts/demo/android/QUICKREF.md @@ -0,0 +1,179 @@ +# MNN TTS Android Demo - 快速参考 + +## 🚀 快速构建 + +```bash +cd /Users/songjinde/git/MNNX/MNN/apps/frameworks/mnn_tts/demo/android + +# 使用构建脚本 (推荐) +./build.sh # 构建 Debug APK +./build.sh release # 构建 Release APK +./build.sh install # 构建并安装到设备 +./build.sh clean # 清理构建 + +# 使用 Gradle 命令 +./gradlew assembleDebug # 构建 Debug +./gradlew assembleRelease # 构建 Release +./gradlew installDebug # 安装 Debug +./gradlew clean # 清理 +``` + +## 📦 构建输出 + +| 构建类型 | APK 路径 | 大小 | +|---------|---------|------| +| Debug | `build/outputs/apk/debug/MNNTTSDemo-arm64-v8a-debug.apk` | ~15 MB | +| Release | `build/outputs/apk/release/MNNTTSDemo-arm64-v8a-release-unsigned.apk` | ~8 MB | + +## 📱 设备要求 + +- **最低版本**: Android 5.0 (API 21) +- **目标版本**: Android 14 (API 35) +- **架构**: ARM64 (arm64-v8a) +- **权限**: 无特殊权限要求 + +## 
🛠️ 开发工具 + +| 工具 | 版本 | +|-----|------| +| Android Studio | Arctic Fox+ | +| Gradle | 8.9 | +| NDK | 27.2.12479018 | +| CMake | 3.22.1+ | +| Kotlin | 1.9.22 | +| JDK | 17+ | + +## 📂 项目配置文件 + +| 文件 | 用途 | +|-----|------| +| `build.gradle` | 应用构建配置 | +| `settings.gradle` | 项目模块配置 | +| `CMakeLists.txt` | 原生代码构建配置 | +| `local.properties` | 本地 SDK/NDK 路径 | +| `gradle.properties` | Gradle 属性配置 | + +## 🔧 常用命令 + +### Gradle 任务 + +```bash +./gradlew tasks # 查看所有任务 +./gradlew build # 完整构建 +./gradlew clean build # 清理并构建 +./gradlew assembleDebug --info # 详细构建日志 +./gradlew assembleDebug --scan # 构建分析 +``` + +### ADB 命令 + +```bash +# 安装 +adb install -r build/outputs/apk/debug/MNNTTSDemo-arm64-v8a-debug.apk + +# 卸载 +adb uninstall com.alibaba.mnn.tts.demo + +# 启动 +adb shell am start -n com.alibaba.mnn.tts.demo/.MainActivity + +# 停止 +adb shell am force-stop com.alibaba.mnn.tts.demo + +# 查看日志 +adb logcat -s MNN_TTS:* AndroidRuntime:E + +# 清除数据 +adb shell pm clear com.alibaba.mnn.tts.demo +``` + +## 🐛 调试技巧 + +### 查看构建配置 + +```bash +./gradlew app:dependencies # 查看依赖树 +./gradlew :mnn_tts:tasks # 查看库模块任务 +``` + +### 检查 APK 内容 + +```bash +unzip -l build/outputs/apk/debug/MNNTTSDemo-arm64-v8a-debug.apk +``` + +### 查看 APK 信息 + +```bash +aapt dump badging build/outputs/apk/debug/MNNTTSDemo-arm64-v8a-debug.apk +``` + +## 🔍 故障排查 + +### 问题: MNN 库未找到 + +```bash +# 检查库是否存在 +ls -la ../../../project/android/build_64/lib/libMNN.so + +# 如果不存在,构建 MNN 库 +cd ../../../project/android +./build_64.sh +``` + +### 问题: NDK 未配置 + +```bash +# 创建或编辑 local.properties +echo "ndk.dir=$HOME/Library/Android/sdk/ndk/27.2.12479018" >> local.properties +echo "sdk.dir=$HOME/Library/Android/sdk" >> local.properties +``` + +### 问题: Gradle 同步失败 + +```bash +# 清理并重新同步 +./gradlew clean +rm -rf .gradle build +./gradlew build --refresh-dependencies +``` + +## 📊 构建时间 + +| 操作 | 预计时间 | +|-----|---------| +| Clean | ~5 秒 | +| 首次构建 | ~2-3 分钟 | +| 增量构建 | ~30-60 秒 | +| 安装到设备 | ~10 秒 | + +## 🎯 关键文件 + +``` +demo/android/ +├── build.sh # 
构建脚本 ⭐ +├── BUILD.md # 详细构建文档 📄 +├── README.md # 快速开始 📖 +├── QUICKREF.md # 本文件 📋 +├── build.gradle # 构建配置 ⚙️ +├── settings.gradle # 项目设置 ⚙️ +└── src/main/ + ├── java/ # Kotlin 代码 + ├── res/ # 资源文件 + └── AndroidManifest.xml # 清单文件 +``` + +## 🔗 相关链接 + +- **MNN 文档**: https://www.yuque.com/mnn/cn +- **Android 开发**: https://developer.android.com +- **Gradle 文档**: https://docs.gradle.org +- **Kotlin 文档**: https://kotlinlang.org + +## 📝 版本历史 + +- **v1.0** (2025-12-21): 初始版本,支持 BertVits2 和 Supertonic TTS + +--- + +**提示**: 详细的构建说明请参考 [BUILD.md](BUILD.md) diff --git a/apps/frameworks/mnn_tts/demo/android/README.md b/apps/frameworks/mnn_tts/demo/android/README.md new file mode 100644 index 0000000000..7b2a725a3e --- /dev/null +++ b/apps/frameworks/mnn_tts/demo/android/README.md @@ -0,0 +1,109 @@ +# MNN TTS Android Demo - 快速开始 + +## 一键构建 + +### 前置条件 + +1. 安装 Android Studio +2. 安装 NDK 27.2.12479018 +3. 确保 MNN 库已构建 (位于 `../../../project/android/build_64/lib/libMNN.so`) + +### 构建命令 + +```bash +cd /Users/songjinde/git/MNNX/MNN/apps/frameworks/mnn_tts/demo/android + +# 清理构建 +./gradlew clean + +# 构建 Debug APK +./gradlew assembleDebug + +# 安装到设备 +./gradlew installDebug +``` + +### 输出位置 + +``` +build/outputs/apk/debug/MNNTTSDemo-arm64-v8a-debug.apk +``` + +## 系统要求 + +- **最低 Android 版本**: Android 5.0 (API 21) +- **目标 Android 版本**: Android 14 (API 35) +- **支持架构**: ARM64 (arm64-v8a) +- **APK 大小**: 约 15 MB + +## 功能特性 + +- ✅ BertVits2 TTS 语音合成 +- ✅ Supertonic TTS 语音合成 +- ✅ 多模型支持 +- ✅ 实时音频播放 +- ✅ 模型列表管理 + +## 项目结构 + +``` +demo/android/ +├── build.gradle # 应用构建配置 +├── settings.gradle # Gradle 项目设置 +├── src/main/ +│ ├── java/ # Kotlin 源代码 +│ │ └── com/alibaba/mnn/tts/demo/ +│ │ ├── MainActivity.kt +│ │ ├── ModelAdapter.kt +│ │ └── audio/AudioChunksPlayer.kt +│ ├── res/ # Android 资源 +│ └── AndroidManifest.xml +└── build/ # 构建输出 +``` + +## 依赖模块 + +``` +:app (demo) +└── :mnn_tts (库模块) + └── MNN (原生库) +``` + +## 常见问题 + +### Q: 构建失败,提示 NDK 未找到? 
+A: 在项目根目录创建 `local.properties` 文件,添加: +```properties +ndk.dir=/path/to/your/ndk +sdk.dir=/path/to/your/sdk +``` + +### Q: 运行时提示库文件未找到? +A: 先构建 MNN 核心库: +```bash +cd ../../../project/android +./build_64.sh +``` + +### Q: 安装失败,提示 INSTALL_FAILED_NO_MATCHING_ABIS? +A: 确保设备是 ARM64 架构,或修改 build.gradle 添加其他 ABI 支持。 + +## 更多文档 + +详细的构建文档请查看: [BUILD.md](BUILD.md) + +## 技术栈 + +- **语言**: Kotlin + C++17 +- **构建工具**: Gradle 8.9 + CMake 3.22.1 +- **框架**: MNN (Mobile Neural Network) +- **UI**: Material Design Components + +## 开发者信息 + +基于 MNN 深度学习框架开发的 TTS 演示应用。 + +--- + +**构建时间**: 约 1 分钟 +**最后验证**: 2025-12-21 ✅ diff --git a/apps/frameworks/mnn_tts/demo/android/TTS_INTEGRATION_GUIDE.md b/apps/frameworks/mnn_tts/demo/android/TTS_INTEGRATION_GUIDE.md new file mode 100644 index 0000000000..4c550c2645 --- /dev/null +++ b/apps/frameworks/mnn_tts/demo/android/TTS_INTEGRATION_GUIDE.md @@ -0,0 +1,404 @@ +# MNN TTS 注册为 Android 系统 TTS 服务指南 + +## 概述 + +要将 MNN TTS 注册为 Android 系统 TTS 服务,需要实现 Android 的 `TextToSpeechService` 抽象类,并正确配置相关文件。 + +## 实现步骤 + +### 1. 创建 TTS 服务类 + +需要创建一个继承自 `android.speech.tts.TextToSpeechService` 的服务类: + +**文件位置**: `src/main/java/com/alibaba/mnn/tts/demo/MnnTtsService.kt` + +```kotlin +package com.alibaba.mnn.tts.demo + +import android.speech.tts.SynthesisCallback +import android.speech.tts.SynthesisRequest +import android.speech.tts.TextToSpeech +import android.speech.tts.TextToSpeechService +import android.util.Log +import com.taobao.meta.avatar.tts.TtsService +import java.io.File + +class MnnTtsService : TextToSpeechService() { + + private var ttsService: TtsService? 
= null + private var isInitialized = false + private val defaultModelPath = "/data/local/tmp/tts_models/default" + + companion object { + private const val TAG = "MnnTtsService" + private val SUPPORTED_LANGUAGES = setOf("zh-CN", "zh_CN", "cmn-Hans-CN") + } + + override fun onCreate() { + super.onCreate() + Log.d(TAG, "MnnTtsService created") + } + + override fun onIsLanguageAvailable(lang: String?, country: String?, variant: String?): Int { + Log.d(TAG, "onIsLanguageAvailable: lang=$lang, country=$country, variant=$variant") + + // 检查是否支持中文 + val locale = buildLocaleString(lang, country) + return when { + SUPPORTED_LANGUAGES.contains(locale) -> TextToSpeech.LANG_COUNTRY_AVAILABLE + lang == "zh" -> TextToSpeech.LANG_AVAILABLE + else -> TextToSpeech.LANG_NOT_SUPPORTED + } + } + + override fun onGetLanguage(): Array { + Log.d(TAG, "onGetLanguage") + // 返回默认语言:中文(中国) + return arrayOf("zh", "CHN", "") + } + + override fun onLoadLanguage(lang: String?, country: String?, variant: String?): Int { + Log.d(TAG, "onLoadLanguage: lang=$lang, country=$country, variant=$variant") + + val locale = buildLocaleString(lang, country) + if (!SUPPORTED_LANGUAGES.contains(locale) && lang != "zh") { + return TextToSpeech.LANG_NOT_SUPPORTED + } + + // 初始化 TTS 引擎 + if (!isInitialized) { + initializeTtsEngine() + } + + return if (isInitialized) { + TextToSpeech.LANG_COUNTRY_AVAILABLE + } else { + TextToSpeech.ERROR + } + } + + override fun onStop() { + Log.d(TAG, "onStop") + // 停止当前的合成任务 + } + + override fun onSynthesizeText(request: SynthesisRequest?, callback: SynthesisCallback?) 
{ + if (request == null || callback == null) { + Log.e(TAG, "Invalid synthesis request or callback") + return + } + + val text = request.charSequenceText?.toString() ?: request.text + if (text.isNullOrEmpty()) { + callback.error() + return + } + + Log.d(TAG, "onSynthesizeText: text=$text, language=${request.language}, country=${request.country}") + + try { + // 确保 TTS 引擎已初始化 + if (!isInitialized) { + initializeTtsEngine() + } + + if (!isInitialized || ttsService == null) { + Log.e(TAG, "TTS engine not initialized") + callback.error() + return + } + + // 等待初始化完成 + val isReady = ttsService?.waitForInitComplete() ?: false + if (!isReady) { + Log.e(TAG, "TTS engine not ready") + callback.error() + return + } + + // 开始合成 + val sampleRate = 44100 + callback.start(sampleRate, android.media.AudioFormat.ENCODING_PCM_16BIT, 1) + + // 使用 TTS 服务处理文本 + val audioData = ttsService?.process(text, 0) + + if (audioData != null && audioData.isNotEmpty()) { + Log.d(TAG, "Generated ${audioData.size} audio samples") + + // 将 FloatArray 转换为 ByteArray (PCM 16-bit) + val maxBufferSize = callback.maxBufferSize + val byteBuffer = ByteArray(maxBufferSize) + var offset = 0 + + for (sample in audioData) { + // 转换 float 到 16-bit PCM + val pcmValue = (sample * 32767f).toInt().coerceIn(-32768, 32767).toShort() + + // 写入字节(小端序) + byteBuffer[offset++] = (pcmValue.toInt() and 0xFF).toByte() + byteBuffer[offset++] = ((pcmValue.toInt() shr 8) and 0xFF).toByte() + + // 当缓冲区满时,发送数据 + if (offset >= maxBufferSize - 2) { + callback.audioAvailable(byteBuffer, 0, offset) + offset = 0 + } + } + + // 发送剩余数据 + if (offset > 0) { + callback.audioAvailable(byteBuffer, 0, offset) + } + + callback.done() + Log.d(TAG, "Synthesis completed successfully") + } else { + Log.e(TAG, "No audio data generated") + callback.error() + } + + } catch (e: Exception) { + Log.e(TAG, "Error during synthesis", e) + callback.error() + } + } + + private fun initializeTtsEngine() { + try { + Log.d(TAG, "Initializing TTS engine with model: 
$defaultModelPath") + + // 检查模型文件是否存在 + val modelDir = File(defaultModelPath) + if (!modelDir.exists() || !modelDir.isDirectory) { + Log.e(TAG, "Model directory not found: $defaultModelPath") + return + } + + val configFile = File(modelDir, "config.json") + if (!configFile.exists()) { + Log.e(TAG, "config.json not found in model directory") + return + } + + // 初始化 TTS 服务 + ttsService = TtsService() + val initResult = ttsService?.init(defaultModelPath) ?: false + + if (initResult) { + isInitialized = true + Log.d(TAG, "TTS engine initialized successfully") + } else { + Log.e(TAG, "Failed to initialize TTS engine") + ttsService = null + } + + } catch (e: Exception) { + Log.e(TAG, "Error initializing TTS engine", e) + ttsService = null + isInitialized = false + } + } + + override fun onDestroy() { + super.onDestroy() + Log.d(TAG, "MnnTtsService destroyed") + + try { + ttsService?.destroy() + ttsService = null + isInitialized = false + } catch (e: Exception) { + Log.e(TAG, "Error destroying TTS service", e) + } + } + + private fun buildLocaleString(lang: String?, country: String?): String { + return when { + lang.isNullOrEmpty() -> "" + country.isNullOrEmpty() -> lang + else -> "$lang-$country" + } + } +} +``` + +### 2. 创建 TTS 设置 Activity + +**文件位置**: `src/main/java/com/alibaba/mnn/tts/demo/MnnTtsSettingsActivity.kt` + +```kotlin +package com.alibaba.mnn.tts.demo + +import android.os.Bundle +import android.widget.TextView +import androidx.appcompat.app.AppCompatActivity +import java.io.File + +class MnnTtsSettingsActivity : AppCompatActivity() { + + override fun onCreate(savedInstanceState: Bundle?) 
{ + super.onCreate(savedInstanceState) + setContentView(R.layout.activity_tts_settings) + + // 显示当前配置信息 + val infoText = findViewById(R.id.settingsInfoText) + val modelPath = "/data/local/tmp/tts_models/default" + val modelExists = File(modelPath).exists() + + infoText.text = """ + MNN TTS Engine Settings + + Model Path: $modelPath + Model Status: ${if (modelExists) "Available" else "Not Found"} + + Supported Languages: + - Chinese (China): zh-CN + + Note: Please ensure TTS models are placed in: + /data/local/tmp/tts_models/ + """.trimIndent() + } +} +``` + +### 3. 创建设置界面布局 + +**文件位置**: `src/main/res/layout/activity_tts_settings.xml` + +```xml + + + + + + +``` + +### 4. 创建 TTS 引擎配置文件 + +**文件位置**: `src/main/res/xml/tts_engine.xml` + +```xml + + +``` + +**注意**: +- 简化的配置文件只包含必需的 `settingsActivity` 属性 +- `` 标签中的属性(如 `android:locale`, `android:gender` 等)在较低 API 级别不支持,已移除 +- 语言支持通过 `MnnTtsService` 中的 `onIsLanguageAvailable()` 和 `onGetLanguage()` 方法实现 + +### 5. 更新 AndroidManifest.xml + +AndroidManifest.xml 已经包含了必要的配置: + +```xml + + + + + + + + + + + + + + +``` + +### 6. 更新 strings.xml + +**文件位置**: `src/main/res/values/strings.xml` + +添加以下字符串资源: + +```xml +MNN TTS Engine +MNN TTS Settings +``` + +## 使用方法 + +### 1. 准备模型文件 + +将 TTS 模型文件放置到设备的以下目录: +``` +/data/local/tmp/tts_models/default/ +├── config.json +├── model.mnn +└── (其他模型文件) +``` + +使用 adb 命令推送模型: +```bash +adb push /path/to/model /data/local/tmp/tts_models/default/ +``` + +### 2. 安装应用 + +```bash +./gradlew installDebug +``` + +### 3. 在系统设置中启用 + +1. 打开 **设置** → **系统** → **语言和输入法** → **文字转语音输出** +2. 选择 **首选引擎** → **MNN TTS Engine** +3. 点击设置图标可以查看引擎配置信息 + +### 4. 测试 TTS + +在系统 TTS 设置页面点击"播放"按钮测试,或使用以下代码: + +```kotlin +val tts = TextToSpeech(context) { status -> + if (status == TextToSpeech.SUCCESS) { + tts.language = Locale.CHINA + tts.speak("你好,这是MNN TTS测试", TextToSpeech.QUEUE_FLUSH, null, null) + } +} +``` + +## 关键要点 + +1. **服务生命周期**: `TextToSpeechService` 由系统管理,会在需要时创建和销毁 +2. **线程安全**: 合成方法可能在后台线程调用,需要注意线程安全 +3. 
**音频格式**: 必须使用 PCM 16-bit 格式输出音频数据 +4. **语言支持**: 通过 `onIsLanguageAvailable` 声明支持的语言 +5. **模型路径**: 确保模型文件路径可访问且包含必要的配置文件 + +## 调试技巧 + +1. 使用 `adb logcat -s MnnTtsService` 查看 TTS 服务日志 +2. 检查 `/data/local/tmp/tts_models/` 目录权限 +3. 确保应用有必要的权限(已在 AndroidManifest.xml 中声明) +4. 在系统 TTS 设置中测试引擎是否正常工作 + +## 常见问题 + +### Q: 系统设置中看不到 MNN TTS Engine +**A**: 检查 AndroidManifest.xml 中的 service 配置,确保 `android:exported="true"` 且包含正确的 intent-filter + +### diff --git a/apps/frameworks/mnn_tts/demo/android/build.gradle b/apps/frameworks/mnn_tts/demo/android/build.gradle index 46eec8adea..2d685d5018 100644 --- a/apps/frameworks/mnn_tts/demo/android/build.gradle +++ b/apps/frameworks/mnn_tts/demo/android/build.gradle @@ -54,6 +54,8 @@ dependencies { implementation 'org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3' implementation 'androidx.core:core-ktx:1.12.0' implementation 'androidx.core:core-ktx:1.16.0' + implementation 'androidx.recyclerview:recyclerview:1.3.2' + implementation 'androidx.cardview:cardview:1.0.0' testImplementation 'junit:junit:4.13.2' androidTestImplementation 'androidx.test.ext:junit:1.1.5' androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1' diff --git a/apps/frameworks/mnn_tts/demo/android/build.sh b/apps/frameworks/mnn_tts/demo/android/build.sh new file mode 100755 index 0000000000..c2a52d95d2 --- /dev/null +++ b/apps/frameworks/mnn_tts/demo/android/build.sh @@ -0,0 +1,181 @@ +#!/bin/bash +# MNN TTS Android Demo 构建脚本 +# 使用方法: ./build.sh [debug|release|clean|install] + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +# 颜色输出 +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +function print_info() { + echo -e "${GREEN}[INFO]${NC} $1" +} + +function print_warn() { + echo -e "${YELLOW}[WARN]${NC} $1" +} + +function print_error() { + echo -e "${RED}[ERROR]${NC} $1" +} + +function check_prerequisites() { + print_info "检查前置条件..." 
+ + # 检查 MNN 库 + MNN_LIB_PATH="../../../project/android/build_64/lib/libMNN.so" + if [ ! -f "$MNN_LIB_PATH" ]; then + print_error "MNN 库未找到: $MNN_LIB_PATH" + print_info "请先构建 MNN 库:" + print_info " cd ../../../project/android" + print_info " ./build_64.sh" + exit 1 + fi + + # 检查 Java + if ! command -v java &> /dev/null; then + print_error "Java 未安装" + exit 1 + fi + + # 检查 Gradle Wrapper + if [ ! -f "./gradlew" ]; then + print_error "Gradle Wrapper 未找到" + exit 1 + fi + + print_info "前置条件检查通过 ✓" +} + +function clean_build() { + print_info "清理构建目录..." + ./gradlew clean + print_info "清理完成 ✓" +} + +function build_debug() { + print_info "开始构建 Debug APK..." + ./gradlew assembleDebug + + APK_PATH="build/outputs/apk/debug/MNNTTSDemo-arm64-v8a-debug.apk" + if [ -f "$APK_PATH" ]; then + APK_SIZE=$(ls -lh "$APK_PATH" | awk '{print $5}') + print_info "构建成功! ✓" + print_info "APK 位置: $APK_PATH" + print_info "APK 大小: $APK_SIZE" + else + print_error "构建失败,APK 未生成" + exit 1 + fi +} + +function build_release() { + print_info "开始构建 Release APK..." + ./gradlew assembleRelease + + APK_PATH="build/outputs/apk/release/MNNTTSDemo-arm64-v8a-release-unsigned.apk" + if [ -f "$APK_PATH" ]; then + APK_SIZE=$(ls -lh "$APK_PATH" | awk '{print $5}') + print_info "构建成功! ✓" + print_info "APK 位置: $APK_PATH" + print_info "APK 大小: $APK_SIZE" + print_warn "注意: Release APK 未签名,需要签名后才能发布" + else + print_error "构建失败,APK 未生成" + exit 1 + fi +} + +function install_debug() { + print_info "安装 Debug APK 到设备..." + + # 检查设备连接 + if ! command -v adb &> /dev/null; then + print_error "adb 未找到,请确保 Android SDK Platform-Tools 已安装" + exit 1 + fi + + DEVICE_COUNT=$(adb devices | grep -v "List" | grep "device$" | wc -l) + if [ "$DEVICE_COUNT" -eq 0 ]; then + print_error "未检测到 Android 设备" + print_info "请确保:" + print_info " 1. 设备已通过 USB 连接" + print_info " 2. 设备已开启 USB 调试" + print_info " 3. 
已授权计算机进行 USB 调试" + exit 1 + fi + + print_info "检测到 $DEVICE_COUNT 个设备" + + APK_PATH="build/outputs/apk/debug/MNNTTSDemo-arm64-v8a-debug.apk" + if [ ! -f "$APK_PATH" ]; then + print_warn "APK 不存在,先构建..." + build_debug + fi + + print_info "正在安装..." + ./gradlew installDebug + + print_info "安装完成! ✓" + print_info "应用包名: com.alibaba.mnn.tts.demo" + print_info "" + print_info "启动应用:" + print_info " adb shell am start -n com.alibaba.mnn.tts.demo/.MainActivity" +} + +function show_usage() { + echo "MNN TTS Android Demo 构建脚本" + echo "" + echo "使用方法: $0 [command]" + echo "" + echo "可用命令:" + echo " debug - 构建 Debug APK (默认)" + echo " release - 构建 Release APK" + echo " clean - 清理构建目录" + echo " install - 构建并安装 Debug APK 到设备" + echo " help - 显示此帮助信息" + echo "" + echo "示例:" + echo " $0 # 构建 Debug APK" + echo " $0 debug # 构建 Debug APK" + echo " $0 release # 构建 Release APK" + echo " $0 clean # 清理构建" + echo " $0 install # 安装到设备" + echo "" +} + +# 主流程 +case "${1:-debug}" in + debug) + check_prerequisites + build_debug + ;; + release) + check_prerequisites + build_release + ;; + clean) + clean_build + ;; + install) + check_prerequisites + install_debug + ;; + help|--help|-h) + show_usage + ;; + *) + print_error "未知命令: $1" + echo "" + show_usage + exit 1 + ;; +esac + +print_info "所有操作完成! 
✓" diff --git a/apps/frameworks/mnn_tts/demo/android/java/com/mnn/tts/demo/DemoActivity.kt b/apps/frameworks/mnn_tts/demo/android/java/com/mnn/tts/demo/DemoActivity.kt index f408b8a314..9e65090bde 100644 --- a/apps/frameworks/mnn_tts/demo/android/java/com/mnn/tts/demo/DemoActivity.kt +++ b/apps/frameworks/mnn_tts/demo/android/java/com/mnn/tts/demo/DemoActivity.kt @@ -3,6 +3,7 @@ package com.mnn.tts.demo import android.Manifest import android.content.pm.PackageManager import android.os.Bundle +import android.util.Log import android.widget.Button import android.widget.EditText import android.widget.RadioGroup @@ -10,15 +11,24 @@ import android.widget.Toast import androidx.appcompat.app.AppCompatActivity import androidx.core.app.ActivityCompat import androidx.core.content.ContextCompat +import androidx.lifecycle.lifecycleScope +import androidx.recyclerview.widget.LinearLayoutManager +import androidx.recyclerview.widget.RecyclerView +import com.alibaba.mnn.tts.demo.R import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import java.io.File class DemoActivity : AppCompatActivity() { private lateinit var ttsDemo: TtsServiceDemo private lateinit var textInput: EditText private lateinit var speakButton: Button private lateinit var languageGroup: RadioGroup + private lateinit var modelRecyclerView: RecyclerView + private lateinit var modelAdapter: ModelAdapter + private var selectedModelPath: String? = null override fun onCreate(savedInstanceState: Bundle?) 
{ super.onCreate(savedInstanceState) @@ -27,9 +37,13 @@ class DemoActivity : AppCompatActivity() { textInput = findViewById(R.id.textInput) speakButton = findViewById(R.id.speakButton) languageGroup = findViewById(R.id.languageGroup) + modelRecyclerView = findViewById(R.id.modelRecyclerView) ttsDemo = TtsServiceDemo(this) + // 初始化模型列表 + initModelList() + speakButton.setOnClickListener { val text = textInput.text.toString() if (text.isNotEmpty()) { @@ -46,6 +60,68 @@ class DemoActivity : AppCompatActivity() { checkPermissionAndInitialize() } + private fun initModelList() { + modelAdapter = ModelAdapter { modelPath -> + selectedModelPath = modelPath + loadTtsModel(modelPath) + } + modelRecyclerView.layoutManager = LinearLayoutManager(this) + modelRecyclerView.adapter = modelAdapter + + lifecycleScope.launch { + val models = scanTtsModels() + modelAdapter.updateModels(models) + if (models.isEmpty()) { + Toast.makeText(this@DemoActivity, "No TTS models found in /data/local/tmp/tts_models", Toast.LENGTH_LONG).show() + } + } + } + + private suspend fun scanTtsModels(): List = withContext(Dispatchers.IO) { + val modelsDir = File("/data/local/tmp/tts_models") + val modelList = mutableListOf() + + try { + if (modelsDir.exists() && modelsDir.isDirectory) { + modelsDir.listFiles()?.forEach { file -> + if (file.isDirectory) { + // 检查是否包含 config.json(可选,也可以直接添加所有文件夹) + val configFile = File(file, "config.json") + if (configFile.exists()) { + modelList.add(file.absolutePath) + Log.d(TAG, "Found model: ${file.absolutePath}") + } else { + // 如果没有 config.json,也添加文件夹(可能是有效的模型目录) + modelList.add(file.absolutePath) + Log.d(TAG, "Found model directory (no config.json): ${file.absolutePath}") + } + } + } + } else { + Log.w(TAG, "Models directory does not exist: ${modelsDir.absolutePath}") + } + } catch (e: Exception) { + Log.e(TAG, "Error scanning models directory", e) + } + + modelList.sorted() + } + + private fun loadTtsModel(modelPath: String) { + lifecycleScope.launch { + try { + 
Toast.makeText(this@DemoActivity, "Loading model: ${File(modelPath).name}...", Toast.LENGTH_SHORT).show() + Log.d(TAG, "Loading TTS model: $modelPath") + ttsDemo.initialize(modelPath) + speakButton.isEnabled = true + Toast.makeText(this@DemoActivity, "Model loaded: ${File(modelPath).name}", Toast.LENGTH_SHORT).show() + } catch (e: Exception) { + Log.e(TAG, "Error loading TTS model", e) + Toast.makeText(this@DemoActivity, "Failed to load model: ${e.message}", Toast.LENGTH_SHORT).show() + } + } + } + private fun checkPermissionAndInitialize() { if (ContextCompat.checkSelfPermission( this, @@ -63,14 +139,10 @@ class DemoActivity : AppCompatActivity() { } private fun initializeTts() { - CoroutineScope(Dispatchers.Main).launch { - try { - val modelDir = getExternalFilesDir(null)?.absolutePath + "/tts_models" - ttsDemo.initialize(modelDir) - speakButton.isEnabled = true - } catch (e: Exception) { - Toast.makeText(this@DemoActivity, "Failed to initialize TTS", Toast.LENGTH_SHORT).show() - } + // 不再在这里初始化,而是等待用户从列表中选择模型 + // 如果已经有选中的模型,则加载它 + if (selectedModelPath != null) { + loadTtsModel(selectedModelPath!!) 
} } @@ -97,5 +169,6 @@ class DemoActivity : AppCompatActivity() { companion object { private const val PERMISSION_REQUEST_CODE = 1001 + private const val TAG = "DemoActivity" } } \ No newline at end of file diff --git a/apps/frameworks/mnn_tts/demo/android/java/com/mnn/tts/demo/ModelAdapter.kt b/apps/frameworks/mnn_tts/demo/android/java/com/mnn/tts/demo/ModelAdapter.kt new file mode 100644 index 0000000000..cf2107f2b5 --- /dev/null +++ b/apps/frameworks/mnn_tts/demo/android/java/com/mnn/tts/demo/ModelAdapter.kt @@ -0,0 +1,49 @@ +package com.mnn.tts.demo + +import android.view.LayoutInflater +import android.view.View +import android.view.ViewGroup +import android.widget.TextView +import androidx.recyclerview.widget.RecyclerView +import com.alibaba.mnn.tts.demo.R +import java.io.File + +class ModelAdapter( + private val onModelSelected: (String) -> Unit +) : RecyclerView.Adapter() { + + private var models: List = emptyList() + + fun updateModels(newModels: List) { + models = newModels + notifyDataSetChanged() + } + + override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): ModelViewHolder { + val view = LayoutInflater.from(parent.context) + .inflate(R.layout.item_model, parent, false) + return ModelViewHolder(view) + } + + override fun onBindViewHolder(holder: ModelViewHolder, position: Int) { + val modelPath = models[position] + val modelFile = File(modelPath) + holder.bind(modelFile.name, modelPath) + holder.itemView.setOnClickListener { + onModelSelected(modelPath) + } + } + + override fun getItemCount(): Int = models.size + + class ModelViewHolder(itemView: View) : RecyclerView.ViewHolder(itemView) { + private val modelNameText: TextView = itemView.findViewById(R.id.modelNameText) + private val modelPathText: TextView = itemView.findViewById(R.id.modelPathText) + + fun bind(modelName: String, modelPath: String) { + modelNameText.text = modelName + modelPathText.text = modelPath + } + } +} + diff --git 
a/apps/frameworks/mnn_tts/demo/android/res/layout/activity_demo.xml b/apps/frameworks/mnn_tts/demo/android/res/layout/activity_demo.xml index 5cec422343..8780c80c53 100644 --- a/apps/frameworks/mnn_tts/demo/android/res/layout/activity_demo.xml +++ b/apps/frameworks/mnn_tts/demo/android/res/layout/activity_demo.xml @@ -5,6 +5,23 @@ android:orientation="vertical" android:padding="16dp"> + + + + + + + + + + + + + + + + diff --git a/apps/frameworks/mnn_tts/demo/android/settings.gradle b/apps/frameworks/mnn_tts/demo/android/settings.gradle index a4357bbe15..d245580d63 100644 --- a/apps/frameworks/mnn_tts/demo/android/settings.gradle +++ b/apps/frameworks/mnn_tts/demo/android/settings.gradle @@ -19,7 +19,5 @@ dependencyResolutionManagement { } rootProject.name = "MNNTTSDemo" -include ':app' include ':mnn_tts' -project(':app').projectDir = new File('./app') project(':mnn_tts').projectDir = new File('../../android') \ No newline at end of file diff --git a/apps/frameworks/mnn_tts/demo/android/src/main/AndroidManifest.xml b/apps/frameworks/mnn_tts/demo/android/src/main/AndroidManifest.xml index c22d93fd44..4fee5143c9 100644 --- a/apps/frameworks/mnn_tts/demo/android/src/main/AndroidManifest.xml +++ b/apps/frameworks/mnn_tts/demo/android/src/main/AndroidManifest.xml @@ -6,7 +6,11 @@ - + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/MainActivity.kt b/apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/MainActivity.kt index 728dceada6..b374b04f56 100644 --- a/apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/MainActivity.kt +++ b/apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/MainActivity.kt @@ -2,12 +2,22 @@ package com.alibaba.mnn.tts.demo import android.os.Bundle import android.util.Log +import android.view.View +import android.widget.AdapterView +import android.widget.ArrayAdapter import 
android.widget.Button import android.widget.EditText +import android.widget.Spinner import android.widget.TextView +import org.json.JSONObject import androidx.appcompat.app.AppCompatActivity import androidx.lifecycle.lifecycleScope +import androidx.recyclerview.widget.LinearLayoutManager +import androidx.recyclerview.widget.RecyclerView +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import java.io.File import com.taobao.meta.avatar.tts.TtsService import com.alibaba.mnn.tts.demo.audio.AudioChunksPlayer @@ -16,15 +26,25 @@ class MainActivity : AppCompatActivity() { private lateinit var resultText: TextView private lateinit var inputText: EditText private lateinit var processButton: Button + private lateinit var modelRecyclerView: RecyclerView + private lateinit var languageSpinner: Spinner private lateinit var ttsService: TtsService private lateinit var audioPlayer: AudioChunksPlayer + private lateinit var modelAdapter: ModelAdapter + + // State + private var allModels: List> = emptyList() + private var selectedModelPath: String? = null + private var currentSpeakerId: String = "" + private var currentLanguage: String = "en" + private var isTtsInitialized = false override fun onCreate(savedInstanceState: Bundle?) 
{ super.onCreate(savedInstanceState) setContentView(R.layout.activity_main) initViews() - initTtsService() + initModelList() initAudioPlayer() setupTtsTest() } @@ -33,28 +53,228 @@ class MainActivity : AppCompatActivity() { resultText = findViewById(R.id.resultText) inputText = findViewById(R.id.inputText) processButton = findViewById(R.id.processButton) + modelRecyclerView = findViewById(R.id.modelRecyclerView) + languageSpinner = findViewById(R.id.languageSpinner) } - private fun initTtsService() { + private fun initModelList() { + modelAdapter = ModelAdapter( + onModelSelected = { modelPath, config -> + if (selectedModelPath != modelPath) { + selectedModelPath = modelPath + currentSpeakerId = "" // Reset or pick first? + loadTtsModel(modelPath) + } + }, + onSpeakerSelected = { speakerId -> + currentSpeakerId = speakerId + Log.d("TTS_TEST", "Speaker selected: $speakerId") + }, + onPlayClicked = { modelPath, speakerId -> + val text = inputText.text.toString().trim() + if (text.isEmpty()) { + resultText.text = "Please enter some text" + return@ModelAdapter + } + + currentSpeakerId = speakerId + if (selectedModelPath == modelPath && isTtsInitialized) { + processTtsText(text) + } else { + // Load and then play + selectedModelPath = modelPath + loadTtsModelAndPlay(modelPath, text) + } + } + ) + modelRecyclerView.layoutManager = LinearLayoutManager(this) + modelRecyclerView.adapter = modelAdapter + + lifecycleScope.launch { + allModels = scanTtsModels() + setupLanguageFilter() + } + } + + private fun loadTtsModelAndPlay(modelPath: String, text: String) { + if (isTtsInitialized) { + try { + ttsService.destroy() + isTtsInitialized = false + } catch (e: Exception) { + Log.e("TTS_TEST", "Error destroying previous TTS service", e) + } + } + + ttsService = TtsService() + + lifecycleScope.launch { + try { + resultText.text = "Loading model: ${File(modelPath).name}..." 
+ val initResult = ttsService.init(modelPath) + if (initResult) { + isTtsInitialized = true + if (currentLanguage.isNotEmpty()) { + ttsService.setLanguage(currentLanguage) + } + processTtsText(text) + } else { + resultText.text = "Failed to load model: ${File(modelPath).name}" + } + } catch (e: Exception) { + Log.e("TTS_TEST", "Error initializing TTS service", e) + resultText.text = "Error loading model: ${e.message}" + } + } + } + + private suspend fun scanTtsModels(): List> = withContext(Dispatchers.IO) { + val modelsDir = File("/data/local/tmp/tts_models") + val modelList = mutableListOf>() + + try { + if (modelsDir.exists() && modelsDir.isDirectory) { + modelsDir.listFiles()?.forEach { file -> + if (file.isDirectory) { + // Check if it contains config.json + val configFile = File(file, "config.json") + if (configFile.exists()) { + val config = readModelConfig(file.absolutePath) + modelList.add(file.absolutePath to config) + Log.d("TTS_TEST", "Found model: ${file.absolutePath}") + } + } + } + } else { + Log.w("TTS_TEST", "Models directory does not exist: ${modelsDir.absolutePath}") + } + } catch (e: Exception) { + Log.e("TTS_TEST", "Error scanning models directory", e) + } + + modelList.sortedBy { it.first } + } + + private fun setupLanguageFilter() { + // Collect all unique languages + val languages = mutableSetOf() + allModels.forEach { (_, config) -> + if (config.languages.isNotEmpty()) { + languages.addAll(config.languages) + } + } + + // Convert to list and sort + val langList = languages.toList().sorted().toMutableList() + if (langList.isEmpty()) langList.add("en") // Default fallback + + val langAdapter = ArrayAdapter(this, android.R.layout.simple_spinner_item, langList) + langAdapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item) + languageSpinner.adapter = langAdapter + + languageSpinner.onItemSelectedListener = object : AdapterView.OnItemSelectedListener { + override fun onItemSelected(parent: AdapterView<*>?, view: View?, 
position: Int, id: Long) { + currentLanguage = langList[position] + // Apply filter + filterModels(currentLanguage) + + // Update TTS service language if initialized + if (isTtsInitialized) { + ttsService.setLanguage(currentLanguage) + } + } + override fun onNothingSelected(parent: AdapterView<*>?) {} + } + + // Initial filter + if (langList.isNotEmpty()) { + currentLanguage = langList[0] + filterModels(currentLanguage) + } + } + + private fun filterModels(language: String) { + val filtered = allModels.filter { (_, config) -> + config.languages.isEmpty() || config.languages.contains(language) + }.map { it.first } + + modelAdapter.updateModels(filtered) + + if (filtered.isEmpty()) { + resultText.text = "No models found for language: $language" + } else { + resultText.text = "Found ${filtered.size} models" + } + } + + private fun loadTtsModel(modelPath: String) { + if (isTtsInitialized) { + try { + ttsService.destroy() + isTtsInitialized = false + } catch (e: Exception) { + Log.e("TTS_TEST", "Error destroying previous TTS service", e) + } + } + ttsService = TtsService() + lifecycleScope.launch { try { - val modelDir = "/data/local/tmp/test_new_tts/bert-vits/" - val initResult = ttsService.init(modelDir) + resultText.text = "Loading model: ${File(modelPath).name}..." 
+ Log.d("TTS_TEST", "Initializing TTS Service with model: $modelPath") + val initResult = ttsService.init(modelPath) if (initResult) { + isTtsInitialized = true Log.d("TTS_TEST", "TTS Service initialized successfully") - resultText.text = "TTS Service ready" + resultText.text = "Model loaded: ${File(modelPath).name}\nTTS Service ready" + + // Set initial language if selected + if (currentLanguage.isNotEmpty()) { + ttsService.setLanguage(currentLanguage) + } } else { Log.e("TTS_TEST", "TTS Service initialization failed") - resultText.text = "TTS Service initialization failed" + resultText.text = "Failed to load model: ${File(modelPath).name}" } } catch (e: Exception) { Log.e("TTS_TEST", "Error initializing TTS service", e) - resultText.text = "Error: ${e.message}" + resultText.text = "Error loading model: ${e.message}" } } } + private fun readModelConfig(modelPath: String): ModelConfig { + try { + val configFile = File(modelPath, "config.json") + if (configFile.exists()) { + val content = configFile.readText() + val json = JSONObject(content) + + val speakers = mutableListOf() + if (json.has("speakers")) { + val speakersJson = json.getJSONArray("speakers") + for (i in 0 until speakersJson.length()) { + speakers.add(speakersJson.getString(i)) + } + } + + val languages = mutableListOf() + if (json.has("languages")) { + val languagesJson = json.getJSONArray("languages") + for (i in 0 until languagesJson.length()) { + languages.add(languagesJson.getString(i)) + } + } + + return ModelConfig(speakers, languages) + } + } catch (e: Exception) { + Log.e("TTS_TEST", "Error reading config.json", e) + } + return ModelConfig() + } + private fun initAudioPlayer() { audioPlayer = AudioChunksPlayer() audioPlayer.sampleRate = 44100 // Common TTS sample rate @@ -63,6 +283,11 @@ class MainActivity : AppCompatActivity() { private fun setupTtsTest() { processButton.setOnClickListener { + if (!isTtsInitialized || selectedModelPath == null) { + resultText.text = "Please select a model 
first" + return@setOnClickListener + } + val text = inputText.text.toString().trim() if (text.isNotEmpty()) { processTtsText(text) @@ -86,6 +311,12 @@ class MainActivity : AppCompatActivity() { resultText.text = "Processing: $text" // Process text with TTS + if (currentLanguage.isNotEmpty()) { + ttsService.setLanguage(currentLanguage) + } + if (currentSpeakerId.isNotEmpty()) { + ttsService.setSpeakerId(currentSpeakerId) + } val audioData = ttsService.process(text, 0) Log.d("TTS_TEST", "Generated audio data with ${audioData.size} samples") diff --git a/apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/MnnTtsService.kt b/apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/MnnTtsService.kt new file mode 100644 index 0000000000..8044a6b33b --- /dev/null +++ b/apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/MnnTtsService.kt @@ -0,0 +1,210 @@ +package com.alibaba.mnn.tts.demo + +import android.speech.tts.SynthesisCallback +import android.speech.tts.SynthesisRequest +import android.speech.tts.TextToSpeech +import android.speech.tts.TextToSpeechService +import android.util.Log +import com.taobao.meta.avatar.tts.TtsService +import java.io.File + +class MnnTtsService : TextToSpeechService() { + + private var ttsService: TtsService? 
= null + private var isInitialized = false + private val defaultModelPath = "/data/local/tmp/tts_models/default" + + companion object { + private const val TAG = "MnnTtsService" + private val SUPPORTED_LANGUAGES = setOf("zh-CN", "zh_CN", "cmn-Hans-CN") + } + + override fun onCreate() { + super.onCreate() + Log.d(TAG, "MnnTtsService created") + } + + override fun onIsLanguageAvailable(lang: String?, country: String?, variant: String?): Int { + Log.d(TAG, "onIsLanguageAvailable: lang=$lang, country=$country, variant=$variant") + + // 检查是否支持中文 + val locale = buildLocaleString(lang, country) + return when { + SUPPORTED_LANGUAGES.contains(locale) -> TextToSpeech.LANG_COUNTRY_AVAILABLE + lang == "zh" -> TextToSpeech.LANG_AVAILABLE + else -> TextToSpeech.LANG_NOT_SUPPORTED + } + } + + override fun onGetLanguage(): Array { + Log.d(TAG, "onGetLanguage") + // 返回默认语言:中文(中国) + return arrayOf("zh", "CHN", "") + } + + override fun onLoadLanguage(lang: String?, country: String?, variant: String?): Int { + Log.d(TAG, "onLoadLanguage: lang=$lang, country=$country, variant=$variant") + + val locale = buildLocaleString(lang, country) + if (!SUPPORTED_LANGUAGES.contains(locale) && lang != "zh") { + return TextToSpeech.LANG_NOT_SUPPORTED + } + + // 初始化 TTS 引擎 + if (!isInitialized) { + initializeTtsEngine() + } + + return if (isInitialized) { + TextToSpeech.LANG_COUNTRY_AVAILABLE + } else { + TextToSpeech.ERROR + } + } + + override fun onStop() { + Log.d(TAG, "onStop") + // 停止当前的合成任务 + } + + override fun onSynthesizeText(request: SynthesisRequest?, callback: SynthesisCallback?) 
{ + if (request == null || callback == null) { + Log.e(TAG, "Invalid synthesis request or callback") + return + } + + val text = request.charSequenceText?.toString() ?: request.text + if (text.isNullOrEmpty()) { + callback.error() + return + } + + Log.d(TAG, "onSynthesizeText: text=$text, language=${request.language}, country=${request.country}") + + try { + // 确保 TTS 引擎已初始化 + if (!isInitialized) { + initializeTtsEngine() + } + + if (!isInitialized || ttsService == null) { + Log.e(TAG, "TTS engine not initialized") + callback.error() + return + } + + // TTS 引擎已初始化,直接使用 + + // 开始合成 + val sampleRate = 44100 + callback.start(sampleRate, android.media.AudioFormat.ENCODING_PCM_16BIT, 1) + + // 使用 TTS 服务处理文本 + val audioData = ttsService?.process(text, 0) + + if (audioData != null && audioData.isNotEmpty()) { + Log.d(TAG, "Generated ${audioData.size} audio samples") + + // 将 FloatArray 转换为 ByteArray (PCM 16-bit) + val maxBufferSize = callback.maxBufferSize + val byteBuffer = ByteArray(maxBufferSize) + var offset = 0 + + for (sample in audioData) { + // 转换 float 到 16-bit PCM + val pcmValue = (sample * 32767f).toInt().coerceIn(-32768, 32767).toShort() + + // 写入字节(小端序) + byteBuffer[offset++] = (pcmValue.toInt() and 0xFF).toByte() + byteBuffer[offset++] = ((pcmValue.toInt() shr 8) and 0xFF).toByte() + + // 当缓冲区满时,发送数据 + if (offset >= maxBufferSize - 2) { + callback.audioAvailable(byteBuffer, 0, offset) + offset = 0 + } + } + + // 发送剩余数据 + if (offset > 0) { + callback.audioAvailable(byteBuffer, 0, offset) + } + + callback.done() + Log.d(TAG, "Synthesis completed successfully") + } else { + Log.e(TAG, "No audio data generated") + callback.error() + } + + } catch (e: Exception) { + Log.e(TAG, "Error during synthesis", e) + callback.error() + } + } + + private fun initializeTtsEngine() { + try { + Log.d(TAG, "Initializing TTS engine with model: $defaultModelPath") + + // 检查模型文件是否存在 + val modelDir = File(defaultModelPath) + if (!modelDir.exists() || !modelDir.isDirectory) { + 
Log.e(TAG, "Model directory not found: $defaultModelPath") + return + } + + val configFile = File(modelDir, "config.json") + if (!configFile.exists()) { + Log.e(TAG, "config.json not found in model directory") + return + } + + // 初始化 TTS 服务(同步调用) + ttsService = TtsService() + // 注意:这里假设 init 方法有同步版本,如果没有需要使用 runBlocking + val initResult = try { + kotlinx.coroutines.runBlocking { + ttsService?.init(defaultModelPath) ?: false + } + } catch (e: Exception) { + Log.e(TAG, "Error during TTS init", e) + false + } + + if (initResult) { + isInitialized = true + Log.d(TAG, "TTS engine initialized successfully") + } else { + Log.e(TAG, "Failed to initialize TTS engine") + ttsService = null + } + + } catch (e: Exception) { + Log.e(TAG, "Error initializing TTS engine", e) + ttsService = null + isInitialized = false + } + } + + override fun onDestroy() { + super.onDestroy() + Log.d(TAG, "MnnTtsService destroyed") + + try { + ttsService?.destroy() + ttsService = null + isInitialized = false + } catch (e: Exception) { + Log.e(TAG, "Error destroying TTS service", e) + } + } + + private fun buildLocaleString(lang: String?, country: String?): String { + return when { + lang.isNullOrEmpty() -> "" + country.isNullOrEmpty() -> lang + else -> "$lang-$country" + } + } +} diff --git a/apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/MnnTtsSettingsActivity.kt b/apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/MnnTtsSettingsActivity.kt new file mode 100644 index 0000000000..a7cdbfa35c --- /dev/null +++ b/apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/MnnTtsSettingsActivity.kt @@ -0,0 +1,43 @@ +package com.alibaba.mnn.tts.demo + +import android.os.Bundle +import android.widget.TextView +import androidx.appcompat.app.AppCompatActivity +import java.io.File + +class MnnTtsSettingsActivity : AppCompatActivity() { + + override fun onCreate(savedInstanceState: Bundle?) 
{ + super.onCreate(savedInstanceState) + setContentView(R.layout.activity_tts_settings) + + // 显示当前配置信息 + val infoText = findViewById(R.id.settingsInfoText) + val modelPath = "/data/local/tmp/tts_models/default" + val modelDir = File(modelPath) + val modelExists = modelDir.exists() && modelDir.isDirectory + val configExists = File(modelDir, "config.json").exists() + + infoText.text = """ + MNN TTS Engine Settings + + Model Path: $modelPath + Model Directory: ${if (modelExists) "✓ Found" else "✗ Not Found"} + Config File: ${if (configExists) "✓ Found" else "✗ Not Found"} + + Supported Languages: + - Chinese (China): zh-CN + + Note: Please ensure TTS models are placed in: + /data/local/tmp/tts_models/default/ + + Required files: + - config.json + - model.mnn + - (other model files) + + Use adb to push models: + adb push /path/to/model /data/local/tmp/tts_models/default/ + """.trimIndent() + } +} diff --git a/apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/ModelAdapter.kt b/apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/ModelAdapter.kt new file mode 100644 index 0000000000..d92be40fc5 --- /dev/null +++ b/apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/ModelAdapter.kt @@ -0,0 +1,205 @@ +package com.alibaba.mnn.tts.demo + +import android.view.LayoutInflater +import android.view.View +import android.view.ViewGroup +import android.widget.TextView +import androidx.recyclerview.widget.RecyclerView +import java.io.File + +class ModelAdapter( + private val onModelSelected: (String, ModelConfig) -> Unit, + private val onSpeakerSelected: (String) -> Unit, + private val onPlayClicked: (String, String) -> Unit +) : RecyclerView.Adapter() { + + private var models: List = emptyList() + private var selectedPosition = -1 + + fun updateModels(newModels: List) { + models = newModels + notifyDataSetChanged() + } + + override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): ModelViewHolder { + 
val view = LayoutInflater.from(parent.context) + .inflate(R.layout.item_model, parent, false) + return ModelViewHolder(view) + } + + override fun onBindViewHolder(holder: ModelViewHolder, position: Int) { + val modelPath = models[position] + holder.bind(modelPath, position == selectedPosition) + + holder.itemView.setOnClickListener { + if (selectedPosition != holder.adapterPosition) { + val previousSelected = selectedPosition + selectedPosition = holder.adapterPosition + notifyItemChanged(previousSelected) + notifyItemChanged(selectedPosition) + + // Notify selection + holder.loadConfig(modelPath)?.let { config -> + onModelSelected(modelPath, config) + } + } + } + + holder.setSpeakerListener { speakerId -> + if (position == selectedPosition) { + onSpeakerSelected(speakerId) + } + } + + holder.setPlayListener { modelPath, speakerId -> + onPlayClicked(modelPath, speakerId) + } + } + + override fun getItemCount(): Int = models.size + + class ModelViewHolder(itemView: View) : RecyclerView.ViewHolder(itemView) { + private val modelNameText: TextView = itemView.findViewById(R.id.modelNameText) + private val modelDescText: TextView = itemView.findViewById(R.id.modelDescText) + private val modelRadioButton: android.widget.RadioButton = itemView.findViewById(R.id.modelRadioButton) + private val voiceSpinner: android.widget.Spinner = itemView.findViewById(R.id.voiceSpinner) + private val playButton: android.view.View = itemView.findViewById(R.id.playButton) + private val cardContent: android.view.View = itemView.findViewById(R.id.cardContent) + + private var currentModelPath: String? = null + private var currentConfig: ModelConfig? = null + private var speakerListener: ((String) -> Unit)? = null + private var playListener: ((String, String) -> Unit)? 
= null + private var currentSpeakerIndex: Int = 0 + + fun bind(modelPath: String, isSelected: Boolean) { + if (currentModelPath != modelPath) { + currentModelPath = modelPath + currentConfig = null + currentSpeakerIndex = 0 + voiceSpinner.adapter = null + } + + val file = File(modelPath) + modelNameText.text = file.name + modelDescText.text = modelPath + modelRadioButton.isChecked = isSelected + + // Reload config for each bind if needed (or rely on loadConfig cache) + val config = loadConfig(modelPath) + + // Highlight selected item + if (isSelected) { + cardContent.setBackgroundResource(R.drawable.bg_filter_border) + playButton.visibility = View.VISIBLE + + // Only show speaker selection if config has speakers + if (config != null && config.speakers.isNotEmpty()) { + voiceSpinner.visibility = View.VISIBLE + setupSpinner(config) + } else { + voiceSpinner.visibility = View.GONE + voiceSpinner.adapter = null + } + } else { + cardContent.background = null + voiceSpinner.visibility = View.GONE + playButton.visibility = View.GONE + voiceSpinner.adapter = null + } + + playButton.setOnClickListener { + val speakerId = if (config != null && config.speakers.isNotEmpty()) { + config.speakers[currentSpeakerIndex] + } else "" + playListener?.invoke(modelPath, speakerId) + } + } + + fun loadConfig(modelPath: String): ModelConfig? 
{ + if (currentConfig != null) return currentConfig + try { + val configFile = File(modelPath, "config.json") + if (configFile.exists()) { + val content = configFile.readText() + val json = org.json.JSONObject(content) + val speakers = mutableListOf() + if (json.has("speakers")) { + val arr = json.getJSONArray("speakers") + for (i in 0 until arr.length()) speakers.add(arr.getString(i)) + } + val languages = mutableListOf() + if (json.has("languages")) { + val arr = json.getJSONArray("languages") + for (i in 0 until arr.length()) languages.add(arr.getString(i)) + } + currentConfig = ModelConfig(speakers, languages) + } else { + currentConfig = ModelConfig() + } + } catch (e: Exception) { + currentConfig = ModelConfig() + } + return currentConfig + } + + private fun setupSpinner(config: ModelConfig?) { + config ?: return + if (config.speakers.isEmpty()) { + voiceSpinner.visibility = View.GONE + return + } + + val adapter = android.widget.ArrayAdapter(itemView.context, R.layout.spinner_item_dark, config.speakers) + adapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item) + voiceSpinner.adapter = adapter + voiceSpinner.setSelection(currentSpeakerIndex) + + voiceSpinner.onItemSelectedListener = object : android.widget.AdapterView.OnItemSelectedListener { + override fun onItemSelected(parent: android.widget.AdapterView<*>?, view: View?, position: Int, id: Long) { + currentSpeakerIndex = position + val speakerId = config.speakers[position] + speakerListener?.invoke(speakerId) + } + override fun onNothingSelected(parent: android.widget.AdapterView<*>?) 
{} + } + } + + fun setSpeakerListener(listener: (String) -> Unit) { + this.speakerListener = listener + } + + fun setPlayListener(listener: (String, String) -> Unit) { + this.playListener = listener + } + } +} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/ModelConfig.kt b/apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/ModelConfig.kt new file mode 100644 index 0000000000..044cf779be --- /dev/null +++ b/apps/frameworks/mnn_tts/demo/android/src/main/java/com/alibaba/mnn/tts/demo/ModelConfig.kt @@ -0,0 +1,6 @@ +package com.alibaba.mnn.tts.demo + +data class ModelConfig( + val speakers: List = emptyList(), + val languages: List = emptyList() +) diff --git a/apps/frameworks/mnn_tts/demo/android/src/main/java/com/mnn/tts/demo/MnnTtsService.kt b/apps/frameworks/mnn_tts/demo/android/src/main/java/com/mnn/tts/demo/MnnTtsService.kt new file mode 100644 index 0000000000..01af186de0 --- /dev/null +++ b/apps/frameworks/mnn_tts/demo/android/src/main/java/com/mnn/tts/demo/MnnTtsService.kt @@ -0,0 +1,420 @@ +package com.mnn.tts.demo + +import android.media.AudioFormat +import android.speech.tts.SynthesisCallback +import android.speech.tts.SynthesisRequest +import android.speech.tts.TextToSpeech +import android.speech.tts.TextToSpeechService +import android.speech.tts.Voice +import android.util.Log +import com.taobao.meta.avatar.tts.TtsService +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.isActive +import kotlinx.coroutines.launch +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import java.io.File +import java.util.Locale +import org.json.JSONObject + +class MnnTtsService : TextToSpeechService() { + + // 使用 SupervisorJob,这样如果一个子协程崩了,不会导致整个 Scope 失效 + private 
val serviceScope = CoroutineScope(Dispatchers.IO + SupervisorJob()) + + // 互斥锁,确保同一时间只有一个合成任务在跑推理 + private val synthesisMutex = Mutex() + + // 当前正在运行的合成任务 Job,用于 onStop 时取消 + private var synthesisJob: Job? = null + + private var ttsService: TtsService? = null + private var currentLanguage: String = "zh" // 默认中文(两字母代码,用于内部处理) + private var currentLanguageOriginal: String = "zh" // 原始语言代码(可能三字母,用于返回给系统) + private var currentCountry: String = "" // 当前国家代码 + private var modelPath: String? = null + private var sampleRate: Int = 16000 // 从 config.json 读取的采样率,默认 16000 + + companion object { + private const val TAG = "MnnTtsService" + private const val PREFS_NAME = "mnn_tts_prefs" + private const val KEY_MODEL_PATH = "model_path" + // 注意:/data/local/tmp 通常需要 Root 权限或 Debug 模式才能访问 + private const val DEFAULT_MODEL_PATH = "/data/local/tmp/tts_models" + + // 音频参数常量 + private const val DEFAULT_SAMPLE_RATE = 16000 // 默认采样率 + private const val ENCODING = AudioFormat.ENCODING_PCM_16BIT + private const val CHANNEL_COUNT = 1 + } + + override fun onCreate() { + super.onCreate() + Log.d(TAG, "MnnTtsService onCreate") + // 异步加载模型,避免阻塞主线程导致 ANR + serviceScope.launch { + loadModelPath() + initializeTtsService() + } + } + + override fun onDestroy() { + super.onDestroy() + Log.d(TAG, "MnnTtsService onDestroy") + serviceScope.cancel() // 取消所有协程 + try { + ttsService?.destroy() + } catch (e: Exception) { + Log.e(TAG, "Error destroying TTS service", e) + } + ttsService = null + } + + // --- 核心方法:合成文本 --- + override fun onSynthesizeText(request: SynthesisRequest?, callback: SynthesisCallback?) { + // 使用最高级别的日志,确保能看到 + Log.e(TAG, "🔥🔥🔥🔥🔥 [onSynthesizeText 被调用!] Lang: ${request?.language}, Text: ${request?.charSequenceText?.take(50)}") + if (callback == null || request == null) { + Log.e(TAG, "❌ onSynthesizeText: callback or request is null!") + return + } + + val text = request.charSequenceText?.toString() ?: "" + + Log.d(TAG, "[合成开始] Text: \"$text\", Lang: ${request.language}") + + // 1. 
如果上一个任务还在跑,先取消它 + runCatching { synthesisJob?.cancel() } + + // 2. 启动新的协程任务 + synthesisJob = serviceScope.launch { + // 使用 Mutex 锁住,防止多线程并发调用底层 C++ 引擎导致 crash + synthesisMutex.withLock { + try { + // 检查是否已被取消 (比如用户刚点播放立刻点了暂停) + if (!isActive) return@withLock + + // 检查 TTS 引擎状态 + val tts = ttsService + if (tts == null) { + Log.e(TAG, "TTS service is null, attempting re-init") + initializeTtsService() + if (ttsService == null) { + callback.error() + return@withLock + } + } + + // 设置语言 + val reqLang = request.language ?: currentLanguage + val langCode = if (reqLang.lowercase().contains("en")) "en" else "zh" + ttsService?.setLanguage(langCode) + + // 等待 TTS 服务初始化完成(关键修复:确保模型已加载) + val isReady = ttsService?.waitForInitComplete() ?: false + if (!isReady) { + Log.e(TAG, "TTS service not ready after waiting") + callback.error() + return@withLock + } + + // 执行推理 (耗时操作) + // 注意:这里假设 tts.process 是同步阻塞的 + val audioData = ttsService?.process(text, 0) + + // 再次检查取消状态 + if (!isActive) return@withLock + + if (audioData == null || audioData.isEmpty()) { + Log.w(TAG, "Generated audio data is empty") + callback.error() + return@withLock + } + + // 3. 
开始向系统写入数据 (全程在 IO 线程,不要切换到 Main) + + // Step A: 告诉系统准备接收音频 + // start 返回 ERROR 表示系统侧可能已断开 + // 使用从 config.json 读取的采样率 + if (callback.start(sampleRate, ENCODING, CHANNEL_COUNT) != TextToSpeech.SUCCESS) { + Log.w(TAG, "callback.start failed, system may have aborted") + return@withLock + } + + // Step B: 格式转换 ShortArray -> ByteArray (Little Endian) + val byteArray = ByteArray(audioData.size * 2) + for (i in audioData.indices) { + val sample = audioData[i].toInt() + // 低8位 + byteArray[i * 2] = (sample and 0xFF).toByte() + // 高8位 + byteArray[i * 2 + 1] = ((sample shr 8) and 0xFF).toByte() + } + + // Step C: 写入数据 + // 这里是一次性写入。如果数据量极大,建议分块写入(chunked) + val maxBufferSize = callback.maxBufferSize + var offset = 0 + while (offset < byteArray.size && isActive) { + val bytesToWrite = Math.min(maxBufferSize, byteArray.size - offset) + val result = callback.audioAvailable(byteArray, offset, bytesToWrite) + + if (result != TextToSpeech.SUCCESS) { + Log.w(TAG, "callback.audioAvailable failed") + return@withLock + } + offset += bytesToWrite + } + + // Step D: 结束 + if (isActive) { + callback.done() + Log.d(TAG, "[合成完成] Sent ${byteArray.size} bytes") + } + + } catch (e: Exception) { + Log.e(TAG, "Synthesis critical error", e) + // 防止崩溃传递给系统 + if (isActive) callback.error() + } + } + } + } + + // --- 核心方法:停止合成 --- + override fun onStop() { + Log.d(TAG, "onStop called - INTERRUPT") + // 关键:立即取消当前的协程任务 + synthesisJob?.cancel() + synthesisJob = null + } + +// --- 修改 1: 优化语言检查,支持三字母代码 (ISO 639-2) --- + override fun onIsLanguageAvailable(lang: String?, country: String?, variant: String?): Int { + Log.e(TAG, "🔥系统询问语言检查: lang=$lang, country=$country, variant=$variant") + + // 简单的模糊匹配:只要包含 "zh", "cn", "en", "eng" 就认为支持 + val l = (lang ?: "").lowercase() + if (l.contains("zh") || l.contains("cn") || l.contains("en") || l.contains("eng")) { + return TextToSpeech.LANG_COUNTRY_AVAILABLE + } + + // 暂时为了调试,还是返回成功,但正常应该返回 NOT_SUPPORTED + return TextToSpeech.LANG_COUNTRY_AVAILABLE + } + + // --- 修改 
2: 加载语言也强制通过 --- + override fun onLoadLanguage(lang: String?, country: String?, variant: String?): Int { + Log.e(TAG, "🔥系统请求加载语言: lang=$lang, country=$country, variant=$variant") + + // 保存原始语言代码(可能三字母),用于 onGetLanguage 返回 + currentLanguageOriginal = lang ?: "zh" + + // 转换三字母代码为两字母代码(ISO 639-2 -> ISO 639-1),用于内部处理 + val langCode = when { + lang == null -> "zh" + lang.lowercase().contains("eng") || lang.lowercase().contains("en") -> "en" + lang.lowercase().contains("zh") || lang.lowercase().contains("cn") -> "zh" + else -> lang.take(2).lowercase() // 取前两个字符 + } + currentLanguage = langCode + currentCountry = country ?: "" + Log.e(TAG, "🔥设置 currentLanguage = $currentLanguage (内部), currentLanguageOriginal = $currentLanguageOriginal (返回系统), currentCountry = $currentCountry") + return TextToSpeech.LANG_COUNTRY_AVAILABLE + } + + + override fun onGetLanguage(): Array { + // 返回格式: [language, country, variant] + // 【关键】返回当前实际使用的语言代码(可能是三字母),以匹配系统请求 + val country = when (currentLanguage) { + "en" -> "USA" + "zh" -> "CN" + else -> currentCountry.ifEmpty { "CN" } + } + return arrayOf(currentLanguageOriginal, country, "") + } + + // --------------------------------------------------------- + // 新增:必须告诉系统你有具体的"发音人",否则高版本安卓不理你 + // --------------------------------------------------------- + + override fun onGetVoices(): List { + // 定义一个中文发音人 + val zhVoice = Voice( + "mnn_zh_voice", // 唯一ID + Locale.CHINA, // 对应的 Locale + Voice.QUALITY_HIGH, // 质量 + Voice.LATENCY_NORMAL, // 延迟 + false, // 是否需要网络 + setOf("male") // 特征 (male/female) + ) + + // 定义一个英文发音人 (对应系统请求的 eng-USA) + val enVoice = Voice( + "mnn_en_voice", + Locale.US, + Voice.QUALITY_HIGH, + Voice.LATENCY_NORMAL, + false, + setOf("female") + ) + + Log.e(TAG, "🔥系统获取 Voice 列表: [mnn_zh_voice, mnn_en_voice]") + return listOf(zhVoice, enVoice) + } + + override fun onIsValidVoiceName(voiceName: String?): Int { + val result = if (voiceName == "mnn_zh_voice" || voiceName == "mnn_en_voice") { + TextToSpeech.SUCCESS + } else { + 
TextToSpeech.ERROR + } + Log.d(TAG, "系统验证 Voice 名称: $voiceName -> $result") + return result + } + + override fun onLoadVoice(voiceName: String?): Int { + Log.e(TAG, "🔥🔥🔥 系统请求加载 Voice: $voiceName") + + // 根据 Voice 名字切换你的模型或参数 + if (voiceName == "mnn_en_voice") { + currentLanguage = "en" // 内部使用两字母 + currentLanguageOriginal = "eng" // 返回给系统使用三字母 + currentCountry = "USA" // 设置对应的国家代码 + } else { + currentLanguage = "zh" + currentLanguageOriginal = "zh" + currentCountry = "CHN" + } + + // 【关键修复】确保 TTS 服务已经初始化完成 + // 如果还在初始化中,等待一下(最多等待 3 秒) + if (ttsService == null) { + Log.i(TAG, "TTS service ready after waiting ms") + } else { + Log.i(TAG, "TTS service already initialized") + } + + Log.e(TAG, "🔥🔥🔥 onLoadVoice 返回 SUCCESS, currentLanguage=$currentLanguage, currentCountry=$currentCountry") + return TextToSpeech.SUCCESS + } + + override fun onGetDefaultVoiceNameFor(lang: String?, country: String?, variant: String?): String? { + // 当系统请求英语时,默认返回英文 Voice 的名字 + val checkLang = (lang ?: "").lowercase() + val voiceName = if (checkLang.contains("en") || checkLang.contains("eng")) { + "mnn_en_voice" + } else { + "mnn_zh_voice" + } + Log.d(TAG, "系统请求默认 Voice: lang=$lang -> $voiceName") + return voiceName + } + + // --- 初始化与路径逻辑 (保持原逻辑优化) --- + private fun loadModelPath() { + val prefs = getSharedPreferences(PREFS_NAME, MODE_PRIVATE) + modelPath = prefs.getString(KEY_MODEL_PATH, null) + + if (modelPath.isNullOrEmpty()) { + modelPath = findDefaultModel() + } + Log.i(TAG, "Model Path resolved to: $modelPath") + } + + private fun findDefaultModel(): String? 
{ + // 优先检查 App 私有目录 (更安全) + val privatePath = File(getExternalFilesDir(null), "tts_models") + if (privatePath.exists()) return privatePath.absolutePath + + // 检查原始路径 + val legacyPath = File(DEFAULT_MODEL_PATH) + if (legacyPath.exists() && legacyPath.isDirectory) { + // 简单的查找逻辑 + legacyPath.listFiles()?.firstOrNull { it.isDirectory }?.let { + return it.absolutePath + } + } + return null + } + +private suspend fun initializeTtsService() { + // 1. 【关键修复】将可变的成员变量赋值给不可变的局部变量 + // 这样编译器就确定 path 在这个函数里永远不会变了 + val path = modelPath + + if (path.isNullOrEmpty()) { + Log.e(TAG, "Skipping init: No model path") + return + } + + // 避免重复初始化 + if (ttsService != null) return + + try { + // 从 config.json 读取采样率 + loadSampleRateFromConfig(path) + + val service = TtsService() + + // 2. 这里使用局部变量 path,它已经被智能转换为非空 String 了 + val success = service.init(path) + + if (success) { + ttsService = service + Log.i(TAG, "TTS Engine Initialized! Sample rate: $sampleRate Hz") + } else { + Log.e(TAG, "TTS Engine Init Failed (return false)") + } + } catch (e: Exception) { + Log.e(TAG, "TTS Engine Init Exception", e) + } + } + + // 从 config.json 读取采样率 + private fun loadSampleRateFromConfig(modelPath: String) { + try { + val configFile = File(modelPath, "config.json") + if (configFile.exists() && configFile.isFile) { + val configContent = configFile.readText() + val configJson = JSONObject(configContent) + + if (configJson.has("sample_rate")) { + sampleRate = configJson.getInt("sample_rate") + Log.i(TAG, "Loaded sample rate from config.json: $sampleRate Hz") + } else { + Log.w(TAG, "config.json does not contain 'sample_rate', using default: $DEFAULT_SAMPLE_RATE Hz") + sampleRate = DEFAULT_SAMPLE_RATE + } + } else { + Log.w(TAG, "config.json not found at $modelPath/config.json, using default: $DEFAULT_SAMPLE_RATE Hz") + sampleRate = DEFAULT_SAMPLE_RATE + } + } catch (e: Exception) { + Log.e(TAG, "Error reading sample rate from config.json", e) + sampleRate = DEFAULT_SAMPLE_RATE + } + } + + // 供外部 
Activity 调用更新模型路径 + fun updateModelPath(path: String) { + val prefs = getSharedPreferences(PREFS_NAME, MODE_PRIVATE) + prefs.edit().putString(KEY_MODEL_PATH, path).apply() + + // 重启服务逻辑 + serviceScope.launch { + synthesisMutex.withLock { + ttsService?.destroy() + ttsService = null + modelPath = path + // 重新读取采样率并初始化服务 + initializeTtsService() + } + } + } +} \ No newline at end of file diff --git a/apps/frameworks/mnn_tts/demo/android/src/main/res/drawable/bg_filter_border.xml b/apps/frameworks/mnn_tts/demo/android/src/main/res/drawable/bg_filter_border.xml new file mode 100644 index 0000000000..27062798b9 --- /dev/null +++ b/apps/frameworks/mnn_tts/demo/android/src/main/res/drawable/bg_filter_border.xml @@ -0,0 +1,6 @@ + + + + + + diff --git a/apps/frameworks/mnn_tts/demo/android/src/main/res/drawable/bg_play_button.xml b/apps/frameworks/mnn_tts/demo/android/src/main/res/drawable/bg_play_button.xml new file mode 100644 index 0000000000..a31da11c8d --- /dev/null +++ b/apps/frameworks/mnn_tts/demo/android/src/main/res/drawable/bg_play_button.xml @@ -0,0 +1,4 @@ + + + + diff --git a/apps/frameworks/mnn_tts/demo/android/src/main/res/drawable/bg_voice_spinner.xml b/apps/frameworks/mnn_tts/demo/android/src/main/res/drawable/bg_voice_spinner.xml new file mode 100644 index 0000000000..b4cff90e9d --- /dev/null +++ b/apps/frameworks/mnn_tts/demo/android/src/main/res/drawable/bg_voice_spinner.xml @@ -0,0 +1,7 @@ + + + + + + + diff --git a/apps/frameworks/mnn_tts/demo/android/src/main/res/layout/activity_main.xml b/apps/frameworks/mnn_tts/demo/android/src/main/res/layout/activity_main.xml index dfe6daf6fc..3ede753d77 100644 --- a/apps/frameworks/mnn_tts/demo/android/src/main/res/layout/activity_main.xml +++ b/apps/frameworks/mnn_tts/demo/android/src/main/res/layout/activity_main.xml @@ -3,42 +3,98 @@ xmlns:app="http://schemas.android.com/apk/res-auto" android:layout_width="match_parent" android:layout_height="match_parent" - android:padding="16dp"> + android:padding="16dp" + 
android:background="#121212"> - + app:layout_constraintStart_toStartOf="parent" /> -