diff --git a/.github/workflows/llm-pr-review.yml b/.github/workflows/llm-pr-review.yml
new file mode 100644
index 0000000000..2f03442c9b
--- /dev/null
+++ b/.github/workflows/llm-pr-review.yml
@@ -0,0 +1,24 @@
+name: LLM Code Review
+
+on:
+ pull_request:
+ types: [opened, reopened, synchronize]
+
+permissions:
+ contents: read
+ pull-requests: write
+
+jobs:
+ review:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout Repo
+ uses: actions/checkout@v4
+
+ - name: LLM Code Review
+ uses: wangzhaode/MNNCodeReviewer@v1.0.0
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+ OPENAI_API_ENDPOINT: https://maas-api.ai-yuanjing.com/openapi/compatible-mode/v1
+ MODEL: glm-5
\ No newline at end of file
diff --git a/.github/workflows/pymnn_release.yml b/.github/workflows/pymnn_release.yml
index 856b551308..ae02b95a9d 100644
--- a/.github/workflows/pymnn_release.yml
+++ b/.github/workflows/pymnn_release.yml
@@ -18,7 +18,7 @@ jobs:
- { os: ubuntu-latest, arch: x86_64, build: 'cp*-manylinux*' }
- { os: ubuntu-24.04-arm, arch: aarch64, build: 'cp*-manylinux*' }
- { os: windows-latest, arch: AMD64, build: 'cp*' }
- - { os: macos-13, arch: x86_64, build: 'cp*' }
+ - { os: macos-14, arch: x86_64, build: 'cp*' }
- { os: macos-14, arch: arm64, build: 'cp*' }
steps:
@@ -39,7 +39,7 @@ jobs:
run: python -m pip install pipx
- name: Build wheels
- uses: pypa/cibuildwheel@v2.16.5
+ uses: pypa/cibuildwheel@v2.22.0
env:
CIBW_ARCHS_MACOS: ${{ matrix.arch }}
CIBW_ARCHS_LINUX: ${{ matrix.arch }}
@@ -69,6 +69,7 @@ jobs:
publish_wheels:
permissions:
contents: none
+ id-token: write
name: Upload
needs: [build_wheels]
runs-on: ubuntu-latest
@@ -86,5 +87,4 @@ jobs:
- uses: pypa/gh-action-pypi-publish@release/v1
with:
- password: ${{ secrets.PYPI_API_TOKEN }}
skip_existing: true
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 66f35d8e0a..52df887af2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -375,9 +375,24 @@ datasets/*
source/backend/qnn/3rdParty/include
project/android/.cxx
pymnn/android/.cxx/
+pymnn/android/.cxx/abi_configuration_5u53tc49.json
apps/mnncli/.cursorrules
apps/mnncli/model_market_json_data.inc
#kledi
_deps
#aicoding
-.cursor
\ No newline at end of file
+.cursor
+
+# llm model
+transformers/llm/export/model/
+apps/Android/.qoder/settings.json
+apps/Android/MnnLlmChatOld
+
+transformers/llm/export/tmp/
+
+# iOS
+apps/iOS/MNNLLMChat/Chat/
+apps/iOS/MNNLLMChat/swift-transformers/
+apps/iOS/MNNLLMChat/MNNLLMiOS/LocalModel/Qwen3-4B-MNN
+apps/iOS/MNNLLMChat/MNNLLMiOS/LocalModel/Qwen3-0.6B-MNN
+apps/iOS/MNNLLMChat/MNNLLMiOS/LocalModel/Qwen2.5-Omni-3B-MNN
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 18b326339e..5023ee2356 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -80,6 +80,7 @@ option(MNN_LOW_MEMORY "Build MNN support low memory for weight quant model." OFF
option(MNN_CPU_WEIGHT_DEQUANT_GEMM "Build MNN CPU weight dequant related gemm kernels." OFF)
option(MNN_BUILD_AUDIO "Build audio api in MNN." OFF)
option(MNN_SME2 "Use Arm sme2 instructions" ON)
+option(MNN_METAL_TENSOR "Use Metal4 tensor instructions" ON)
if (MNN_BUILD_MINI)
set(MNN_SKIPBUILD_GEOMETRY ON CACHE BOOL "" FORCE)
@@ -258,6 +259,7 @@ option(MNN_VULKAN "Enable Vulkan" OFF)
option(MNN_ARM82 "Enable ARMv8.2's FP16 Compute" ON)
option(MNN_SUPPORT_FP16_ARMV7 "Enable ARMv8.2's FP16 Compute for armv7 arch, may cause library not valid for 32 bit cpu" OFF)
option(MNN_KLEIDIAI "Enable KLEIDIAI" ON)
+option(MNN_KLEIDIAI_DEFAULT_ON "Use KLEIDIAI kernels by default" OFF)
option(MNN_ONEDNN "Enable oneDNN" OFF)
option(MNN_AVX2 "Open AVX2 Compile for x86 if possible" ON)
option(MNN_AVX512 "Enable AVX512" OFF)
@@ -277,7 +279,7 @@ if (NOT MNN_CUDA OR NOT CMAKE_SYSTEM_NAME MATCHES "^Linux")
set(MNN_CUDA_PROFILE OFF)
endif()
-if (NOT MNN_QNN)
+if (NOT MNN_QNN)
set(MNN_QNN_ONLINE_FINALIZE OFF)
endif()
@@ -373,6 +375,9 @@ endif()
IF(MNN_DEBUG_MEMORY)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address")
+
+ set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fsanitize=address")
+ set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -fsanitize=address")
endif()
set(MNN_DEPS "")
@@ -549,6 +554,7 @@ ENDIF()
IF(MNN_BUILD_DIFFUSION)
file(GLOB MNN_DIFFUSION_HDRS ${CMAKE_CURRENT_SOURCE_DIR}/transformers/diffusion/engine/include/diffusion/*)
list(APPEND MNN_EXTRA_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/transformers/diffusion/engine/include/diffusion/diffusion.hpp)
+ list(APPEND MNN_EXTRA_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/transformers/diffusion/engine/include/diffusion/sana_llm.hpp)
ENDIF()
@@ -936,6 +942,11 @@ if (NOT MNN_BUILD_SHARED_LIBS)
endif()
list(APPEND MNN_TARGETS MNN)
list(REMOVE_ITEM MNN_TARGETS MNN)
+
+# Cache MNN_DEPS and MNN_INCLUDES for external projects
+set(MNN_LIBS ${MNN_DEPS} CACHE INTERNAL "MNN targets")
+set(MNN_INCLUDE_DIRS ${MNN_INCLUDES} CACHE INTERNAL "MNN include directories")
+
IF(MNN_BUILD_DEMO)
include(${CMAKE_CURRENT_LIST_DIR}/demo/exec/CMakeLists.txt)
ENDIF()
diff --git a/MNN.sln b/MNN.sln
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/README.md b/README.md
index 5fe168ed05..8bb0dce106 100644
--- a/README.md
+++ b/README.md
@@ -6,13 +6,27 @@
[](README_JP.md)
[](http://www.mnn.zone)
-[](./apps/Android/MnnLlmChat/README.md)
-[](./apps/Android/Mnn3dAvatar/README.md)
-
+[](./apps/Android/MnnLlmChat/README.md)
+[](./apps/Android/Mnn3dAvatar/README.md)
+[](./apps/sana/README.md)
## News 🔥
+- [2026/03/05] Support Qwen3.5 Series.
+
+
+
+
+
+- [2026/02/13] MNN-Sana-Edit-V2 is now available at [apps](./apps/sana/README.md), offering cartoon-style photo editing based on Sana.
+
+
+
+
+
+ History News
+
- [2025/10/16] Support Qwen3-VL Series.
-- [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! [MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md)
+- [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! [MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md)
@@ -24,10 +38,6 @@
-
-
- History News
-
- [2025/04/30] android app support qwen3 and dark mode [MNN Chat App](./apps/Android/MnnLlmChat/README.md#releases).
@@ -154,13 +164,13 @@ The group discussions are predominantly Chinese. But we welcome and will help En
Dingtalk discussion groups:
-Group #1 (Full): 23329087
+Group #4 (Available): 160170007549
-Group #2 (Full): 23350225
+Group #3 (Full)
-Group #3: QR code:
+Group #2 (Full): 23350225
-
+Group #1 (Full): 23329087
## Historical Paper
diff --git a/README_CN.md b/README_CN.md
index edcf823a28..f769a1e14b 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -111,12 +111,10 @@ MNN适配的硬件架构与精度详见下表:
## 社区交流与反馈
钉钉群组:
-- 钉钉群1:23329087
-- 钉钉群2:23350225
-- 钉钉群3:扫描二维码加入
-
-
-
+- 钉钉群4 (可加入): 160170007549
+- 钉钉群3 (已无法加入)
+- 钉钉群2 (已满): 23350225
+- 钉钉群1 (已满): 23329087
## 历史论文
diff --git a/README_JP.md b/README_JP.md
index c2baa58d94..2f33def31a 100644
--- a/README_JP.md
+++ b/README_JP.md
@@ -117,13 +117,14 @@ MNN(テンソル計算エンジン)に基づいて、推論、トレーニ
Dingtalkディスカッショングループ:
-グループ#1(満員):23329087
-グループ#2(満員):23350225
+グループ#4(参加可能):160170007549
-グループ#3:QRコード:
+グループ#3 (満員)
-
+グループ#2(満員):23350225
+
+グループ#1(満員):23329087
## 歴史的な論文
diff --git a/build_lib.sh b/build_lib.sh
new file mode 100644
index 0000000000..c839b6e7b6
--- /dev/null
+++ b/build_lib.sh
@@ -0,0 +1,807 @@
+#!/bin/bash
+
+# MNN 统一构建脚本
+# 支持构建 Android、iOS、鸿蒙和 Python 版本的 MNN
+
+set -e
+
+# 颜色输出
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+BLUE='\033[0;34m'
+NC='\033[0m' # No Color
+
+# 默认配置
+BUILD_ANDROID=false
+BUILD_IOS=false
+BUILD_PYTHON=false
+BUILD_IOS_SIMULATOR=false
+BUILD_HARMONY=false
+ANDROID_NDK=""
+HARMONY_HOME=""
+OUTPUT_DIR="mnn_builds"
+CLEAN_BUILD=true
+VERSION=""
+PYTHON_CMD="python3"
+DEPS_OPTIONS=""
+
+# 获取项目根目录
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PROJECT_ROOT="$SCRIPT_DIR"
+
+print_usage() {
+ echo "MNN 统一构建脚本"
+ echo ""
+ echo "用法: $0 [选项]"
+ echo ""
+ echo "构建目标:"
+ echo " --android 构建 Android 版本 (需要 --ndk)"
+ echo " --ios 构建 iOS 真机版本 (arm64)"
+ echo " --ios-simulator 构建 iOS 模拟器版本 (x86_64 + arm64)"
+ echo " --harmony 构建鸿蒙版本 (需要 --harmony-home 或设置 HARMONY_HOME)"
+ echo " --python 构建 Python 版本"
+ echo ""
+ echo "Android 选项:"
+ echo " --ndk PATH Android NDK 路径 (例如: ~/Library/Android/sdk/ndk/29.0.13599879)"
+ echo ""
+ echo "鸿蒙选项:"
+ echo " --harmony-home PATH 鸿蒙工具链路径 (例如: ~/Library/OpenHarmony/Sdk/native)"
+ echo " 如果未指定,将优先查找 ~/Library/OpenHarmony/Sdk/native"
+ echo ""
+ echo "Python 选项:"
+ echo " --python-deps OPTIONS Python 依赖选项,多个用逗号分隔"
+ echo " 可用选项: llm,opencl,cuda,torch,render,vulkan,internal,no_sse,openmp"
+ echo " --python-cmd CMD Python 命令 (默认: python3)"
+ echo ""
+ echo "通用选项:"
+ echo " -o, --output DIR 输出目录 (默认: mnn_builds)"
+ echo " -v, --version VERSION 版本号 (默认: 自动从源码读取)"
+ echo " --no-clean 不清理之前的构建目录"
+ echo " -h, --help 显示帮助信息"
+ echo ""
+ echo "示例:"
+ echo " # 构建所有平台"
+ echo " $0 --android --ios --harmony --python --ndk ~/Library/Android/sdk/ndk/29.0.13599879"
+ echo ""
+ echo " # 仅构建 Android"
+ echo " $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879"
+ echo ""
+ echo " # 构建鸿蒙版本"
+ echo " $0 --harmony --harmony-home ~/Library/OpenHarmony/Sdk/native"
+ echo ""
+ echo " # 构建 iOS 真机和模拟器"
+ echo " $0 --ios --ios-simulator"
+ echo ""
+ echo " # 构建 Python (带 LLM 支持)"
+ echo " $0 --python --python-deps llm,opencl"
+}
+
+# 解析命令行参数
+while [[ $# -gt 0 ]]; do
+ case $1 in
+ --android)
+ BUILD_ANDROID=true
+ shift
+ ;;
+ --ios)
+ BUILD_IOS=true
+ shift
+ ;;
+ --ios-simulator)
+ BUILD_IOS_SIMULATOR=true
+ shift
+ ;;
+ --python)
+ BUILD_PYTHON=true
+ shift
+ ;;
+ --harmony)
+ BUILD_HARMONY=true
+ shift
+ ;;
+ --ndk)
+ ANDROID_NDK="$2"
+ shift 2
+ ;;
+ --harmony-home)
+ HARMONY_HOME="$2"
+ shift 2
+ ;;
+ --python-deps)
+ DEPS_OPTIONS="$2"
+ shift 2
+ ;;
+ --python-cmd)
+ PYTHON_CMD="$2"
+ shift 2
+ ;;
+ -o|--output)
+ OUTPUT_DIR="$2"
+ shift 2
+ ;;
+ -v|--version)
+ VERSION="$2"
+ shift 2
+ ;;
+ --no-clean)
+ CLEAN_BUILD=false
+ shift
+ ;;
+ -h|--help)
+ print_usage
+ exit 0
+ ;;
+ *)
+ echo -e "${RED}错误: 未知选项 $1${NC}"
+ print_usage
+ exit 1
+ ;;
+ esac
+done
+
+# 检查是否至少选择了一个构建目标
+if [ "$BUILD_ANDROID" = false ] && [ "$BUILD_IOS" = false ] && [ "$BUILD_IOS_SIMULATOR" = false ] && [ "$BUILD_HARMONY" = false ] && [ "$BUILD_PYTHON" = false ]; then
+ echo -e "${RED}错误: 请至少选择一个构建目标 (--android, --ios, --ios-simulator, --harmony, 或 --python)${NC}"
+ print_usage
+ exit 1
+fi
+
+# 检查 Android NDK
+if [ "$BUILD_ANDROID" = true ]; then
+ if [ -z "$ANDROID_NDK" ]; then
+ echo -e "${RED}错误: 构建 Android 必须指定 NDK 路径 (使用 --ndk)${NC}"
+ echo -e "${RED}示例: $0 --android --ndk ~/Library/Android/sdk/ndk/29.0.13599879${NC}"
+ exit 1
+ fi
+
+ # 展开路径中的 ~
+ ANDROID_NDK="${ANDROID_NDK/#\~/$HOME}"
+
+ if [ ! -d "$ANDROID_NDK" ]; then
+ echo -e "${RED}错误: NDK 路径不存在: $ANDROID_NDK${NC}"
+ exit 1
+ fi
+fi
+
+# 查找鸿蒙工具链的函数
+find_harmony_toolchain() {
+ # 如果环境变量已设置,优先使用
+ if [ -n "$HARMONY_HOME" ]; then
+ # 支持两种路径格式:直接指定或带 native 子目录
+ if [ -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then
+ echo "$HARMONY_HOME"
+ return 0
+ elif [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then
+ echo "$HARMONY_HOME"
+ return 0
+ fi
+ fi
+
+ # 优先查找 ~/Library/OpenHarmony/Sdk,支持版本号目录
+ local sdk_base="$HOME/Library/OpenHarmony/Sdk"
+ if [ -d "$sdk_base" ]; then
+ # 查找所有版本号目录下的 native 或 toolchains
+ # 按版本号倒序排列,优先使用最新版本
+ for version_dir in $(ls -d "$sdk_base"/* 2>/dev/null | sort -Vr); do
+ if [ -d "$version_dir" ]; then
+ # 尝试 native/build/cmake/ohos.toolchain.cmake
+ if [ -f "$version_dir/native/build/cmake/ohos.toolchain.cmake" ]; then
+ echo "$version_dir/native"
+ return 0
+ fi
+ # 尝试 build/cmake/ohos.toolchain.cmake
+ if [ -f "$version_dir/build/cmake/ohos.toolchain.cmake" ]; then
+ echo "$version_dir"
+ return 0
+ fi
+ fi
+ done
+ fi
+
+ # 其他可能的路径
+ local possible_paths=(
+ "$HOME/Library/OpenHarmony/Sdk/native"
+ "$HOME/HarmonyOS/Sdk/native"
+ "$HOME/.ohos/native"
+ "/opt/HarmonyOS/Sdk/native"
+ "/usr/local/HarmonyOS/Sdk/native"
+ "$HOME/Library/HarmonyOS/Sdk/native"
+ )
+
+ # 尝试查找
+ for path in "${possible_paths[@]}"; do
+ if [ -n "$path" ] && [ -f "$path/build/cmake/ohos.toolchain.cmake" ]; then
+ echo "$path"
+ return 0
+ fi
+ done
+
+ # 限制搜索范围,只在 OpenHarmony/Sdk 目录下快速查找
+ # 避免在整个 OpenHarmony 目录下递归查找,这可能会很慢
+ local found=$(find "$HOME/Library/OpenHarmony/Sdk" -maxdepth 4 -type f -name "ohos.toolchain.cmake" 2>/dev/null | head -1)
+ if [ -n "$found" ]; then
+ # 从 ohos.toolchain.cmake 向上查找 native 目录或 SDK 根目录
+ found=$(dirname "$found")
+ if [ "$(basename "$found")" = "cmake" ]; then
+ found=$(dirname "$found")
+ if [ "$(basename "$found")" = "build" ]; then
+ found=$(dirname "$found")
+ fi
+ fi
+ echo "$found"
+ return 0
+ fi
+
+ return 1
+}
+
+# 检查鸿蒙工具链
+if [ "$BUILD_HARMONY" = true ]; then
+ # 展开路径中的 ~
+ if [ -n "$HARMONY_HOME" ]; then
+ HARMONY_HOME="${HARMONY_HOME/#\~/$HOME}"
+ fi
+
+ # 尝试查找工具链
+ if [ -z "$HARMONY_HOME" ] || [ ! -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then
+ # 检查是否在 native 子目录下
+ if [ -n "$HARMONY_HOME" ] && [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then
+ HARMONY_HOME="$HARMONY_HOME/native"
+ else
+ echo -e "${YELLOW}正在查找鸿蒙工具链...${NC}"
+ HARMONY_HOME=$(find_harmony_toolchain)
+
+ if [ -z "$HARMONY_HOME" ]; then
+ echo -e "${RED}错误: 找不到鸿蒙工具链${NC}"
+ echo -e "${RED}默认查找路径: ~/Library/OpenHarmony/Sdk/*/native${NC}"
+ echo -e "${RED}请使用 --harmony-home 指定路径,或设置 HARMONY_HOME 环境变量${NC}"
+ echo -e "${RED}工具链文件应位于: /build/cmake/ohos.toolchain.cmake 或 /native/build/cmake/ohos.toolchain.cmake${NC}"
+ exit 1
+ fi
+
+ # 如果找到的是带 native 的路径,需要调整
+ if [ -f "$HARMONY_HOME/native/build/cmake/ohos.toolchain.cmake" ]; then
+ HARMONY_HOME="$HARMONY_HOME/native"
+ fi
+ fi
+
+ echo -e "${GREEN}找到鸿蒙工具链: $HARMONY_HOME${NC}"
+ fi
+
+ # 验证工具链文件存在
+ if [ ! -f "$HARMONY_HOME/build/cmake/ohos.toolchain.cmake" ]; then
+ echo -e "${RED}错误: 鸿蒙工具链文件不存在: $HARMONY_HOME/build/cmake/ohos.toolchain.cmake${NC}"
+ exit 1
+ fi
+fi
+
+echo -e "${GREEN}========================================${NC}"
+echo -e "${GREEN}MNN 统一构建脚本${NC}"
+echo -e "${GREEN}========================================${NC}"
+echo "项目根目录: $PROJECT_ROOT"
+echo "输出目录: $OUTPUT_DIR"
+echo ""
+echo -e "${BLUE}构建目标:${NC}"
+[ "$BUILD_ANDROID" = true ] && echo " ✓ Android"
+[ "$BUILD_IOS" = true ] && echo " ✓ iOS (真机 arm64)"
+[ "$BUILD_IOS_SIMULATOR" = true ] && echo " ✓ iOS (模拟器 x86_64 + arm64)"
+[ "$BUILD_HARMONY" = true ] && echo " ✓ 鸿蒙 (arm64-v8a)"
+[ "$BUILD_PYTHON" = true ] && echo " ✓ Python"
+echo ""
+
+cd "$PROJECT_ROOT"
+mkdir -p "$OUTPUT_DIR"
+
+# ============================================================================
+# 构建 Android 版本
+# ============================================================================
+if [ "$BUILD_ANDROID" = true ]; then
+ echo -e "${GREEN}========================================${NC}"
+ echo -e "${GREEN}开始构建 Android 版本${NC}"
+ echo -e "${GREEN}========================================${NC}"
+
+ export ANDROID_NDK
+
+ ANDROID_BUILD_DIR="project/android"
+ ANDROID_OUTPUT_DIR="$OUTPUT_DIR/android"
+
+ cd "$ANDROID_BUILD_DIR"
+
+ # 清理
+ if [ "$CLEAN_BUILD" = true ]; then
+ echo -e "${YELLOW}清理之前的 Android 构建...${NC}"
+ rm -rf build_32 build_64 export
+ fi
+
+ # 在项目根目录创建输出目录
+ mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR"
+
+ # 构建 armeabi-v7a
+ echo -e "${BLUE}构建 armeabi-v7a...${NC}"
+ mkdir -p build_32
+ cd build_32
+ cmake ../../../ \
+ -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
+ -DCMAKE_BUILD_TYPE=Release \
+ -DANDROID_ABI="armeabi-v7a" \
+ -DANDROID_STL=c++_static \
+ -DANDROID_NATIVE_API_LEVEL=android-14 \
+ -DANDROID_TOOLCHAIN=clang \
+ -DMNN_USE_LOGCAT=false \
+ -DMNN_USE_SSE=OFF \
+ -DMNN_BUILD_TEST=ON \
+ -DMNN_ARM82=OFF \
+ -DMNN_LOW_MEMORY=ON \
+ -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \
+ -DMNN_BUILD_LLM=ON \
+ -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \
+ -DMNN_BUILD_DIFFUSION=ON \
+ -DMNN_OPENCL=OFF \
+ -DMNN_SEP_BUILD=OFF \
+ -DLLM_SUPPORT_AUDIO=ON \
+ -DMNN_BUILD_AUDIO=ON \
+ -DLLM_SUPPORT_VISION=ON \
+ -DMNN_BUILD_OPENCV=ON \
+ -DMNN_IMGCODECS=ON \
+ -DMNN_BUILD_FOR_ANDROID_COMMAND=true \
+ -DNATIVE_LIBRARY_OUTPUT=. \
+ -DNATIVE_INCLUDE_OUTPUT=. \
+ > /dev/null
+
+ make -j4 MNN > /dev/null
+ cd ..
+
+ # 构建 arm64-v8a
+ echo -e "${BLUE}构建 arm64-v8a...${NC}"
+ mkdir -p build_64
+ cd build_64
+ cmake ../../../ \
+ -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
+ -DCMAKE_BUILD_TYPE=Release \
+ -DANDROID_ABI="arm64-v8a" \
+ -DANDROID_STL=c++_static \
+ -DANDROID_NATIVE_API_LEVEL=android-21 \
+ -DANDROID_TOOLCHAIN=clang \
+ -DMNN_USE_LOGCAT=false \
+ -DMNN_BUILD_BENCHMARK=ON \
+ -DMNN_USE_SSE=OFF \
+ -DMNN_BUILD_TEST=ON \
+ -DMNN_ARM82=ON \
+ -DMNN_LOW_MEMORY=ON \
+ -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \
+ -DMNN_BUILD_LLM=ON \
+ -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \
+ -DMNN_BUILD_DIFFUSION=ON \
+ -DMNN_OPENCL=OFF \
+ -DMNN_SEP_BUILD=OFF \
+ -DLLM_SUPPORT_AUDIO=ON \
+ -DMNN_BUILD_AUDIO=ON \
+ -DLLM_SUPPORT_VISION=ON \
+ -DMNN_BUILD_OPENCV=ON \
+ -DMNN_IMGCODECS=ON \
+ -DMNN_BUILD_FOR_ANDROID_COMMAND=true \
+ -DNATIVE_LIBRARY_OUTPUT=. \
+ -DNATIVE_INCLUDE_OUTPUT=. \
+ > /dev/null
+
+ make -j4 MNN > /dev/null
+ cd ..
+
+ # 导出文件
+ echo -e "${BLUE}导出 Android 库文件...${NC}"
+ mkdir -p export/android/{armeabi-v7a,arm64-v8a}/libs export/android/include/MNN
+
+ cp build_32/*.so export/android/armeabi-v7a/libs/ 2>/dev/null || true
+ cp build_64/*.so export/android/arm64-v8a/libs/ 2>/dev/null || true
+ cp -r ../../include/MNN/* export/android/include/MNN/
+
+ # 复制到统一输出目录
+ # 如果目标路径存在但不是目录,先删除
+ if [ -e "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ] && [ ! -d "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR" ]; then
+ rm -f "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR"
+ fi
+ mkdir -p "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR"
+ cp -r export/android/* "$PROJECT_ROOT/$ANDROID_OUTPUT_DIR/"
+
+ echo -e "${GREEN}Android 构建完成!${NC}"
+ echo "输出位置: $ANDROID_OUTPUT_DIR"
+ cd "$PROJECT_ROOT"
+fi
+
+# ============================================================================
+# 构建 iOS 真机版本
+# ============================================================================
+if [ "$BUILD_IOS" = true ]; then
+ echo -e "${GREEN}========================================${NC}"
+ echo -e "${GREEN}开始构建 iOS 真机版本${NC}"
+ echo -e "${GREEN}========================================${NC}"
+
+ # 检查是否在 macOS 上
+ if [[ "$OSTYPE" != "darwin"* ]]; then
+ echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}"
+ exit 1
+ fi
+
+ # 检查 Xcode
+ if ! command -v xcodebuild &> /dev/null; then
+ echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}"
+ exit 1
+ fi
+
+ IOS_BUILD_DIR="project/ios"
+ IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/device"
+
+ cd "$IOS_BUILD_DIR"
+
+ # 清理
+ if [ "$CLEAN_BUILD" = true ]; then
+ echo -e "${YELLOW}清理之前的 iOS 真机构建...${NC}"
+ rm -rf MNN-iOS-CPU-GPU/Static/ios_64
+ find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_64/*" 2>/dev/null | xargs rm -f 2>/dev/null || true
+ fi
+
+ mkdir -p MNN-iOS-CPU-GPU/Static
+ cd MNN-iOS-CPU-GPU/Static
+
+ # 构建真机版本 (arm64)
+ echo -e "${BLUE}构建 iOS 真机版本 (arm64)...${NC}"
+ rm -rf ios_64
+ mkdir ios_64
+ cd ios_64
+ cmake "$PROJECT_ROOT" \
+ -DCMAKE_BUILD_TYPE=Release \
+ -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \
+ -DENABLE_BITCODE=0 \
+ -DMNN_AAPL_FMWK=1 \
+ -DMNN_SEP_BUILD=OFF \
+ -DMNN_BUILD_SHARED_LIBS=false \
+ -DMNN_USE_THREAD_POOL=OFF \
+ -DPLATFORM=OS64 \
+ -DARCHS="arm64" \
+ -DMNN_ARM82=ON \
+ -DMNN_LOW_MEMORY=ON \
+ -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \
+ -DMNN_BUILD_LLM=ON \
+ -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \
+ -DMNN_METAL=ON \
+ -DMNN_BUILD_DIFFUSION=ON \
+ -DMNN_OPENCL=OFF \
+ -DLLM_SUPPORT_AUDIO=ON \
+ -DMNN_BUILD_AUDIO=ON \
+ -DLLM_SUPPORT_VISION=ON \
+ -DMNN_BUILD_OPENCV=ON \
+ -DMNN_IMGCODECS=ON \
+ > /dev/null
+ make MNN -j8 > /dev/null
+ cd ..
+
+ # 输出真机版本
+ mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR"
+ rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework"
+ cp -R ios_64/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/"
+
+ # 清理
+ rm -rf ios_64
+
+ echo -e "${GREEN}iOS 真机构建完成!${NC}"
+ echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework"
+ cd "$PROJECT_ROOT"
+fi
+
+# ============================================================================
+# 构建 iOS 模拟器版本
+# ============================================================================
+if [ "$BUILD_IOS_SIMULATOR" = true ]; then
+ echo -e "${GREEN}========================================${NC}"
+ echo -e "${GREEN}开始构建 iOS 模拟器版本${NC}"
+ echo -e "${GREEN}========================================${NC}"
+
+ # 检查是否在 macOS 上
+ if [[ "$OSTYPE" != "darwin"* ]]; then
+ echo -e "${RED}错误: iOS 构建只能在 macOS 上执行${NC}"
+ exit 1
+ fi
+
+ # 检查 Xcode
+ if ! command -v xcodebuild &> /dev/null; then
+ echo -e "${RED}错误: 找不到 Xcode,请先安装 Xcode${NC}"
+ exit 1
+ fi
+
+ IOS_BUILD_DIR="project/ios"
+ IOS_OUTPUT_DIR="$OUTPUT_DIR/ios/simulator"
+
+ cd "$IOS_BUILD_DIR"
+
+ # 清理
+ if [ "$CLEAN_BUILD" = true ]; then
+ echo -e "${YELLOW}清理之前的 iOS 模拟器构建...${NC}"
+ rm -rf MNN-iOS-CPU-GPU/Static/ios_simulator*
+ find "$PROJECT_ROOT" -name "CMakeCache.txt" -path "*/ios_simulator*/*" 2>/dev/null | xargs rm -f 2>/dev/null || true
+ fi
+
+ mkdir -p MNN-iOS-CPU-GPU/Static
+ cd MNN-iOS-CPU-GPU/Static
+
+ # 构建 x86_64 模拟器版本(尝试构建,失败则跳过)
+ BUILD_X86_64=false
+ echo -e "${BLUE}构建 iOS 模拟器版本 (x86_64)...${NC}"
+ rm -rf ios_simulator_x86
+ mkdir ios_simulator_x86
+ cd ios_simulator_x86
+ if cmake "$PROJECT_ROOT" \
+ -DCMAKE_BUILD_TYPE=Release \
+ -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \
+ -DENABLE_BITCODE=0 \
+ -DMNN_AAPL_FMWK=1 \
+ -DMNN_SEP_BUILD=OFF \
+ -DMNN_BUILD_SHARED_LIBS=false \
+ -DMNN_USE_THREAD_POOL=OFF \
+ -DPLATFORM=SIMULATOR64 \
+ -DARCHS="x86_64" \
+ -DMNN_ARM82=OFF \
+ -DMNN_LOW_MEMORY=ON \
+ -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \
+ -DMNN_BUILD_LLM=ON \
+ -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \
+ -DMNN_METAL=OFF \
+ -DMNN_BUILD_DIFFUSION=ON \
+ -DMNN_OPENCL=OFF \
+ -DLLM_SUPPORT_AUDIO=ON \
+ -DMNN_BUILD_AUDIO=ON \
+ -DLLM_SUPPORT_VISION=ON \
+ -DMNN_BUILD_OPENCV=ON \
+ -DMNN_IMGCODECS=ON \
+ && make MNN -j8; then
+ if [ -f "MNN.framework/MNN" ]; then
+ BUILD_X86_64=true
+ echo -e "${GREEN}x86_64 模拟器构建成功${NC}"
+ cd ..
+ else
+ echo -e "${YELLOW}警告: x86_64 模拟器构建产物不存在,跳过此架构${NC}"
+ cd ..
+ rm -rf ios_simulator_x86
+ fi
+ else
+ echo -e "${YELLOW}警告: x86_64 模拟器构建失败,跳过此架构(这在 Apple Silicon Mac 上是正常的)${NC}"
+ cd ..
+ rm -rf ios_simulator_x86
+ fi
+
+ # 构建 arm64 模拟器版本
+ echo -e "${BLUE}构建 iOS 模拟器版本 (arm64)...${NC}"
+ rm -rf ios_simulator_arm64
+ mkdir ios_simulator_arm64
+ cd ios_simulator_arm64
+ if ! cmake "$PROJECT_ROOT" \
+ -DCMAKE_BUILD_TYPE=Release \
+ -DCMAKE_TOOLCHAIN_FILE=$PROJECT_ROOT/cmake/ios.toolchain.cmake \
+ -DENABLE_BITCODE=0 \
+ -DMNN_AAPL_FMWK=1 \
+ -DMNN_SEP_BUILD=OFF \
+ -DMNN_BUILD_SHARED_LIBS=false \
+ -DMNN_USE_THREAD_POOL=OFF \
+ -DPLATFORM=SIMULATOR64 \
+ -DARCHS="arm64" \
+ -DMNN_ARM82=ON \
+ -DMNN_LOW_MEMORY=ON \
+ -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \
+ -DMNN_BUILD_LLM=ON \
+ -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \
+ -DMNN_METAL=OFF \
+ -DMNN_BUILD_DIFFUSION=ON \
+ -DMNN_OPENCL=OFF \
+ -DLLM_SUPPORT_AUDIO=ON \
+ -DMNN_BUILD_AUDIO=ON \
+ -DLLM_SUPPORT_VISION=ON \
+ -DMNN_BUILD_OPENCV=ON \
+ -DMNN_IMGCODECS=ON; then
+ echo -e "${RED}错误: arm64 模拟器 cmake 配置失败${NC}"
+ cd "$PROJECT_ROOT"
+ exit 1
+ fi
+ if ! make MNN -j8; then
+ echo -e "${RED}错误: arm64 模拟器编译失败${NC}"
+ cd "$PROJECT_ROOT"
+ exit 1
+ fi
+ cd ..
+
+ # 验证构建产物
+ if [ ! -f "ios_simulator_arm64/MNN.framework/MNN" ]; then
+ echo -e "${RED}错误: 未找到 arm64 模拟器框架文件${NC}"
+ cd "$PROJECT_ROOT"
+ exit 1
+ fi
+
+ # 合并模拟器架构
+ echo -e "${BLUE}合并模拟器架构...${NC}"
+ rm -rf ios_simulator
+ mkdir ios_simulator
+
+ if [ "$BUILD_X86_64" = true ] && [ -f "ios_simulator_x86/MNN.framework/MNN" ]; then
+ # 合并 x86_64 + arm64
+ echo -e "${BLUE}合并 x86_64 + arm64 架构...${NC}"
+ cp -R ios_simulator_x86/MNN.framework ios_simulator/MNN.framework
+ mv ios_simulator/MNN.framework/MNN ios_simulator/MNN.framework/MNN_x86
+ if ! lipo -create ios_simulator/MNN.framework/MNN_x86 ios_simulator_arm64/MNN.framework/MNN -output ios_simulator/MNN.framework/MNN; then
+ echo -e "${RED}错误: 合并架构失败${NC}"
+ cd "$PROJECT_ROOT"
+ exit 1
+ fi
+ rm ios_simulator/MNN.framework/MNN_x86
+ else
+ # 仅使用 arm64
+ echo -e "${BLUE}仅使用 arm64 架构(x86_64 不可用)...${NC}"
+ cp -R ios_simulator_arm64/MNN.framework ios_simulator/MNN.framework
+ fi
+
+ # 输出模拟器版本
+ mkdir -p "$PROJECT_ROOT/$IOS_OUTPUT_DIR"
+ rm -rf "$PROJECT_ROOT/$IOS_OUTPUT_DIR/MNN.framework"
+ cp -R ios_simulator/MNN.framework "$PROJECT_ROOT/$IOS_OUTPUT_DIR/"
+
+ # 清理临时目录
+ rm -rf ios_simulator ios_simulator_x86 ios_simulator_arm64
+
+ echo -e "${GREEN}iOS 模拟器构建完成!${NC}"
+ echo "输出位置: $IOS_OUTPUT_DIR/MNN.framework"
+ cd "$PROJECT_ROOT"
+fi
+
+# ============================================================================
+# 构建鸿蒙版本
+# ============================================================================
+if [ "$BUILD_HARMONY" = true ]; then
+ echo -e "${GREEN}========================================${NC}"
+ echo -e "${GREEN}开始构建鸿蒙版本${NC}"
+ echo -e "${GREEN}========================================${NC}"
+
+ export HARMONY_HOME
+
+ HARMONY_BUILD_DIR="project/harmony"
+ HARMONY_OUTPUT_DIR="$OUTPUT_DIR/harmony"
+
+ cd "$HARMONY_BUILD_DIR"
+
+ # 清理
+ if [ "$CLEAN_BUILD" = true ]; then
+ echo -e "${YELLOW}清理之前的鸿蒙构建...${NC}"
+ rm -rf build_64 export
+ fi
+
+ # 在项目根目录创建输出目录
+ mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR"
+
+ # 构建 arm64-v8a
+ echo -e "${BLUE}构建 arm64-v8a...${NC}"
+ mkdir -p build_64
+ cd build_64
+
+ cmake "$PROJECT_ROOT" \
+ -DCMAKE_TOOLCHAIN_FILE=$HARMONY_HOME/build/cmake/ohos.toolchain.cmake \
+ -DCMAKE_BUILD_TYPE=Release \
+ -DOHOS_ARCH="arm64-v8a" \
+ -DOHOS_STL=c++_static \
+ -DMNN_USE_LOGCAT=true \
+ -DMNN_BUILD_BENCHMARK=ON \
+ -DMNN_USE_SSE=OFF \
+ -DMNN_SUPPORT_BF16=OFF \
+ -DMNN_BUILD_TEST=ON \
+ -DMNN_ARM82=ON \
+ -DMNN_LOW_MEMORY=ON \
+ -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \
+ -DMNN_BUILD_LLM=ON \
+ -DMNN_CPU_WEIGHT_DEQUANT_GEMM=ON \
+ -DMNN_BUILD_DIFFUSION=ON \
+ -DMNN_OPENCL=OFF \
+ -DMNN_SEP_BUILD=OFF \
+ -DLLM_SUPPORT_AUDIO=ON \
+ -DMNN_BUILD_AUDIO=ON \
+ -DLLM_SUPPORT_VISION=ON \
+ -DMNN_BUILD_OPENCV=ON \
+ -DMNN_IMGCODECS=ON \
+ -DOHOS_PLATFORM_LEVEL=9 \
+ -DNATIVE_LIBRARY_OUTPUT=. \
+ -DNATIVE_INCLUDE_OUTPUT=. \
+ > /dev/null
+
+ make -j4 MNN > /dev/null
+ cd ..
+
+ # 导出文件
+ echo -e "${BLUE}导出鸿蒙库文件...${NC}"
+ mkdir -p export/harmony/arm64-v8a/libs export/harmony/include/MNN
+
+ cp build_64/*.so export/harmony/arm64-v8a/libs/ 2>/dev/null || true
+ cp -r ../../include/MNN/* export/harmony/include/MNN/
+
+ # 复制到统一输出目录
+ # 如果目标路径存在但不是目录,先删除
+ if [ -e "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ] && [ ! -d "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR" ]; then
+ rm -f "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR"
+ fi
+ mkdir -p "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR"
+ cp -r export/harmony/* "$PROJECT_ROOT/$HARMONY_OUTPUT_DIR/"
+
+ echo -e "${GREEN}鸿蒙构建完成!${NC}"
+ echo "输出位置: $HARMONY_OUTPUT_DIR"
+ cd "$PROJECT_ROOT"
+fi
+
+# ============================================================================
+# 构建 Python 版本
+# ============================================================================
+if [ "$BUILD_PYTHON" = true ]; then
+ echo -e "${GREEN}========================================${NC}"
+ echo -e "${GREEN}开始构建 Python 版本${NC}"
+ echo -e "${GREEN}========================================${NC}"
+
+ # 检查 Python
+ if ! command -v $PYTHON_CMD &> /dev/null; then
+ echo -e "${RED}错误: 找不到 Python 命令 '$PYTHON_CMD'${NC}"
+ exit 1
+ fi
+
+ PYTHON_OUTPUT_DIR="$OUTPUT_DIR/python"
+
+ # 使用之前创建的 build_pymnn.sh 脚本
+ if [ -f "$PROJECT_ROOT/build_pymnn.sh" ]; then
+ PYTHON_BUILD_ARGS="-o $PYTHON_OUTPUT_DIR"
+ [ -n "$VERSION" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -v $VERSION"
+ [ -n "$DEPS_OPTIONS" ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS -d $DEPS_OPTIONS"
+ [ "$CLEAN_BUILD" = false ] && PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --no-clean"
+ PYTHON_BUILD_ARGS="$PYTHON_BUILD_ARGS --python $PYTHON_CMD"
+
+ bash "$PROJECT_ROOT/build_pymnn.sh" $PYTHON_BUILD_ARGS
+ else
+ echo -e "${YELLOW}警告: 未找到 build_pymnn.sh,使用基本构建方式...${NC}"
+
+ cd pymnn/pip_package
+
+ # 构建依赖
+ if [ -n "$DEPS_OPTIONS" ]; then
+ $PYTHON_CMD build_deps.py $DEPS_OPTIONS
+ else
+ $PYTHON_CMD build_deps.py
+ fi
+
+ # 构建 wheel
+ $PYTHON_CMD -m pip install -U numpy wheel setuptools --quiet
+ rm -rf build dist
+
+ BUILD_ARGS=""
+ [ -n "$VERSION" ] && BUILD_ARGS="--version $VERSION"
+ [ -n "$DEPS_OPTIONS" ] && BUILD_ARGS="$BUILD_ARGS --deps $DEPS_OPTIONS"
+
+ $PYTHON_CMD setup.py bdist_wheel $BUILD_ARGS
+
+ mkdir -p "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR"
+ cp dist/*.whl "$PROJECT_ROOT/$PYTHON_OUTPUT_DIR/"
+
+ cd "$PROJECT_ROOT"
+ fi
+
+ echo -e "${GREEN}Python 构建完成!${NC}"
+ echo "输出位置: $PYTHON_OUTPUT_DIR"
+fi
+
+# ============================================================================
+# 总结
+# ============================================================================
+echo ""
+echo -e "${GREEN}========================================${NC}"
+echo -e "${GREEN}所有构建完成!${NC}"
+echo -e "${GREEN}========================================${NC}"
+echo ""
+echo "输出目录: $PROJECT_ROOT/$OUTPUT_DIR"
+echo ""
+[ "$BUILD_ANDROID" = true ] && echo -e "${GREEN}Android:${NC} $OUTPUT_DIR/android"
+[ "$BUILD_IOS" = true ] && echo -e "${GREEN}iOS (真机):${NC} $OUTPUT_DIR/ios/device"
+[ "$BUILD_IOS_SIMULATOR" = true ] && echo -e "${GREEN}iOS (模拟器):${NC} $OUTPUT_DIR/ios/simulator"
+[ "$BUILD_HARMONY" = true ] && echo -e "${GREEN}鸿蒙:${NC} $OUTPUT_DIR/harmony"
+[ "$BUILD_PYTHON" = true ] && echo -e "${GREEN}Python:${NC} $OUTPUT_DIR/python"
+echo ""
+
+
diff --git a/docker_release.sh b/docker_release.sh
deleted file mode 100755
index 05ba96cf2d..0000000000
--- a/docker_release.sh
+++ /dev/null
@@ -1,6 +0,0 @@
-# using docker run release
-docker start mnn_release
-docker exec -i -e TEST_ID=$(pwd | awk -F "/" '{print $(NF-1)}') mnn_release bash <<'EOF'
-cd ~/yanxing_zhaode/cise/space/$TEST_ID/source && ./release.sh pymnn
-exit
-EOF
\ No newline at end of file
diff --git a/docker_run.sh b/docker_run.sh
deleted file mode 100755
index 6cff3321bc..0000000000
--- a/docker_run.sh
+++ /dev/null
@@ -1,6 +0,0 @@
-# using docker run test
-docker start mnn_ci
-docker exec -i -e TEST_ID=$(pwd | awk -F "/" '{print $(NF-1)}') mnn_ci bash <<'EOF'
-cd ~/yanxing_zhaode/cise/space/$TEST_ID/source && ./test.sh linux
-exit
-EOF
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000000..d4bb2cbb9e
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = .
+BUILDDIR = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/compile/cmake.md b/docs/compile/cmake.md
index 6513e38fad..138a9ddd49 100644
--- a/docs/compile/cmake.md
+++ b/docs/compile/cmake.md
@@ -39,6 +39,7 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下:
| MNN_INTERNAL | 是否构建MNN的一些内部功能,如:日志;默认为`OFF` |
| MNN_JNI | 是否构建MNN的JNI支持,默认为`OFF` |
| MNN_METAL | 是否构建`Metal`后端,默认为`OFF` |
+| MNN_METAL_TENSOR | 是否启用`Metal Tensor`接口,该宏仅在`MNN_METAL=ON`时生效,默认为`ON` |
| MNN_OPENCL | 是否构建`OpenCL`后端,默认为`OFF` |
| MNN_OPENGL | 是否构建`OpenGL`后端,默认为`OFF` |
| MNN_VULKAN | 是否构建`Vulkan`后端,默认为`OFF` |
@@ -101,4 +102,5 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下:
| MNN_BUILD_LLM_OMNI | 若构建基于MNN的llm库和demo,是否支持图像和音频输入功能,默认为`OFF` 。仅在MNN_BUILD_LLM 打开时生效。开启时 MNN_BUILD_OPENCV , MNN_IMGCODECS , MNN_BUILD_AUDIO 同时打开|
| MNN_BUILD_DIFFUSION | 是否构建基于MNN的diffusion demo,默认为`OFF` . 打开时MNN_BUILD_OPENCV , MNN_IMGCODECS, MNN_LOW_MEMORY, MNN_SUPPORT_TRANSFORMER_FUSE 同步开启|
| MNN_KLEIDIAI | 是否集成ARM的klediAI加速库,默认为`ON` |
+| MNN_KLEIDIAI_DEFAULT_ON | 是否默认使用KLEIDIAI的Kernel, 默认为`OFF` |
| MNN_USE_RVV | 是否启用RISC-V向量扩展支持,默认为`OFF` |
diff --git a/docs/compile/other.md b/docs/compile/other.md
index b3edb55358..31a2734337 100644
--- a/docs/compile/other.md
+++ b/docs/compile/other.md
@@ -63,6 +63,7 @@
- `llm_demo` 大语言模型推理示例程序
- `diffusion_demo` 扩散模型示例程序
- `llm_bench` 大语言模型测评工具
+ - `rollback_demo` 大语言模型kvcache回调示例工具
- `quantize_llm` 大语言模型feature map量化工具
## 测试工具
- 相关编译选项
diff --git a/docs/faq.md b/docs/faq.md
index c1c7addef0..7078e3caf1 100644
--- a/docs/faq.md
+++ b/docs/faq.md
@@ -83,7 +83,7 @@ MNN 一般以动态库形式使用,里面有大量自注册函数,如果需
- GCC: -Wl,--whole-archive MNN -Wl,--no-whole-archive
- OSX(Xcode): -Wl,-force_load MNN
-- Window(Visio-Studio): /WHOLEARCHIVE:MNN
+- Windows(Visual-Studio): /WHOLEARCHIVE:MNN
## 模型转换
@@ -317,3 +317,25 @@ GPU 后端调用 copy 的时间包含两个部分
3. 重新编译`MNN`库文件, `Convert`等所有工具;
4. 使用新的工具重新转换模型;
5. 在端侧使用新模型和新的`MNN`库文件进行部署;
+
+### 如何裁剪MNN库体积
+1. 如果模型输入形状固定,且不包含控制流算子,可以在编译MNN时增加 -DMNN_BUILD_MINI=ON ,这样形状计算和几何计算相关的代码会被裁剪,从而减小库体积
+2. 如果只希望MNN运行指定的一系列模型,可以按如下方式先获取算子列表,再根据算子列表裁剪MNN
+
+```
+# 产出 op.txt
+./GetMNNInfo x0.mnn
+# 追加 op.txt
+./GetMNNInfo x1.mnn
+./GetMNNInfo x2.mnn
+......
+```
+
+```
+cd ${MNN_ROOT}
+python3 tools/script/prue_mnn_ops.py op.txt .
+# 执行完成后会发现 source/backend/cpu/CPUOPRegister.cpp , source/geometry/GeometryOPRegister.cpp 等注册函数被修改了,说明生效
+```
+
+- 如果 MNN 编译为静态库,裁剪方案不影响静态库大小,但可以减小最终集成体积
+- 如果需要编译模型转换工具等,不要启用裁剪方案。建议仅在交叉编译目标部署设备的二进制库时使用裁剪方案。
diff --git a/docs/tools/visual.md b/docs/tools/visual.md
index 3b4494eb3c..02d6629720 100644
--- a/docs/tools/visual.md
+++ b/docs/tools/visual.md
@@ -2,7 +2,7 @@
可视化的效果:

-在详细调研了市面上比较主流的可视化工具后,`Netron`是一款受众面较多、兼容多款模型的可视化模型,同时它还具有跨平台支持、`Python`模块支持的能力。因此,在研究了一番它的设计和架构并考虑后续`MNN`自身的演进,我们决定**官方维护`MNN`模型的可视化能力并将其作为`Pull Request`合并,大家可以放心使用啦。**
+在详细调研了市面上比较主流的可视化工具后,`Netron`是一款受众面较广、兼容多款模型的可视化工具,同时它还具有跨平台支持、`Python`模块支持的能力。因此,在研究了一番它的设计和架构并考虑后续`MNN`自身的演进,我们决定**官方维护`MNN`模型的可视化能力并将其作为`Pull Request`合并,大家可以放心使用啦。**
## 功能列表
- 支持加载`.mnn`模型 。
diff --git a/docs/transformers/diffusion.md b/docs/transformers/diffusion.md
index 7de27bb216..609793f806 100644
--- a/docs/transformers/diffusion.md
+++ b/docs/transformers/diffusion.md
@@ -20,7 +20,8 @@ https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1/tree/mai
cd mnn_path/transformers/diffusion/export
python onnx_export.py \
--model_path hf_sd_load_path \
- --output_path onnx_save_path
+ --output_path onnx_save_path \
+ --opset 18
```
注意,上述脚本需要依赖torch/onnx/diffusers等库,可以安装conda环境:
```
diff --git a/docs/transformers/llm.md b/docs/transformers/llm.md
index fa4f60a851..389cd9364c 100644
--- a/docs/transformers/llm.md
+++ b/docs/transformers/llm.md
@@ -34,7 +34,7 @@
* **关键产物**:脚本会生成一个包含 `llm.mnn`, `llm.mnn.weight`, `tokenizer.txt`, `embeddings_bf16.bin`【可能存在】, `llm_config.json`, `config.json` 等文件的模型目录。
4. **(可选)高级功能**:
- * **量化**:通过 `--quant_bit 4` 和 `--quant_block 128` 等参数可以调节量化的Bits数,默认为`4 bit , block size 64`。通过 `--hqq` 或 `--awq` 可以启用对应算法以提升量化后的模型精度,一般建议增加`--hqq`
+ * **量化**:通过 `--quant_bit 4` 和 `--quant_block 128` 等参数可以调节量化的Bits数,默认为`4 bit , block size 64`。通过 `--hqq` 或 `--awq` 或 `--omni` 可以启用对应算法以提升量化后的模型精度,一般建议增加`--hqq`
* **LoRA**:通过 `--lora_path` 合并或分离 LoRA 权重。
* **Embeding**:对于目前主流的8b以下模型,采用了`Tie-Embeding`技术,默认不会导出`embeddings_bf16.bin`,而是复用`llm.mnn.weight`中的`lm`权重,需要提升embed精度可以设置 `--seperate_embed` 分离出`embeddings_bf16.bin`。
* **GPTQ**:通过 `--gptq_path` 应用预量化好的 GPTQ 权重。
@@ -190,7 +190,7 @@ python llmexport.py \
usage: llmexport.py [-h] --path PATH [--type TYPE] [--tokenizer_path TOKENIZER_PATH] [--lora_path LORA_PATH]
[--gptq_path GPTQ_PATH] [--dst_path DST_PATH] [--verbose] [--test TEST] [--export EXPORT]
[--onnx_slim] [--quant_bit QUANT_BIT] [--quant_block QUANT_BLOCK]
- [--lm_quant_bit LM_QUANT_BIT] [--mnnconvert MNNCONVERT] [--ppl] [--awq] [--sym] [--seperate_embed]
+ [--lm_quant_bit LM_QUANT_BIT] [--mnnconvert MNNCONVERT] [--ppl] [--awq] [--omni] [--sym] [--seperate_embed]
[--lora_split]
llm_exporter
@@ -383,15 +383,23 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
- dit_model: 当使用Omni模型时,dit_model的实际路径为`base_dir + dit_model`,默认为`base_dir + 'dit.mnn'`
- bigvgan_model: 当使用Omni模型时,bigvgan_model的实际路径为`base_dir + bigvgan_model`,默认为`base_dir + 'bigvgan.mnn'`
- spk_dict: 当使用Omni模型时,spk_dict的实际路径为`base_dir + spk_dict`,默认为`base_dir + 'spk_dict.txt'`
+ - context_file: 配置上下文信息文件路径,实际路径为`base_dir + context_file`,默认`base_dir + 'context.json'`,内容格式为json格式的上下文信息,包含:如tools,enable_thinking等信息。
- 推理配置
- max_new_tokens: 生成时最大token数,默认为`512`
- reuse_kv: 多轮对话时是否复用之前对话的`kv cache`,默认为`false`.
- - quant_qkv: CPU attention 算子中`query, key, value`是否量化,可选为:`0, 1, 2, 3, 4`,默认为`0`,含义如下:
- - 0: key和value都不量化
- - 1: 使用非对称8bit量化存储key
- - 2: 使用fp8格式量化存储value
- - 3: 使用非对称8bit量化存储key,使用fp8格式量化存储value
- - 4: 量化kv的同时使用非对称8bit量化query,并使用int8矩阵乘计算Q*K
+ - quant_qkv: 选项废弃,请使用 `attention_mode`
+ - attention_mode:
+ - CPU attention 算子中`query, key, value`是否量化,可选为:`0, 1, 2, 8, 9, 10`,默认为`8`,含义如下:
+ - 0: 运行时不使用Flash Attention, query, key, value均不量化
+ - 1: 运行时不使用Flash Attention, query和key使用8bit非对称量化,value不量化
+ - 2: 运行时不使用Flash Attention, query, key, value均使用8bit非对称量化
+ - 8: 运行时使用Flash Attention, query, key, value均不量化
+ - 9: 运行时使用Flash Attention, query和key使用8bit非对称量化,value不量化
+ - 10: 运行时使用Flash Attention, query, key, value均使用8bit非对称量化
+ - GPU attention 算子中是否使用Flash Attention,可选为:`0, 8, 16`,默认为`8`,目前仅支持Metal后端,含义如下:
+ - 0: 运行时不使用Flash Attention, 朴素Attention实现,内存占用高,上下文较长时不推荐
+ - 8: 运行时使用Flash Attention, 在算子层面分步实现,性能接近设为0,内存占用比设为0小
+ - 16: 运行时使用Flash Attention, 在算子层面单算子融合实现,内存占用最小,性能比设为8稍慢一些
- use_mmap: 是否使用mmap方式,在内存不足时将权重写入磁盘,避免溢出,默认为false,手机上建议设成true
- chunk: 限制每次最大处理的token数,高于此值将分块运行,以减少内存占用,eg: chunk: 128
- chunk_limits: 限制每次处理的token数,不在此范围内将分拆或者补零处理,eg: chunk_limits: [128, 1] , 存在 chunk_limits 时,chunk 配置无效
@@ -403,12 +411,14 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
- thread_num: CPU推理使用硬件线程数,默认为:`4`; OpenCL推理时使用`68`(不是传统意义的线程数,代表的是opencl buffer存储和tuning wide模式)
- precision: 推理使用精度策略,默认为:`"low"`,尽量使用`fp16`
- memory: 推理使用内存策略,默认为:`"low"`,开启运行时量化
-- 与CPU动态量化相关的配置
- - dynamic_option: 推理时是否对feature map分blocksize/group进行量化。可选为:`0, 1, 2`,默认是`0`,含义如下:
+- 与CPU动态量化相关的配置,提升精度、性能
+ - dynamic_option: 推理时是否对feature map分blocksize/group进行量化。可选为:`0, 1, 2, 8, 9, 10`,默认是`0`,含义如下:
- 0: feature map数据使用per channel量化
- 1: feature map数据使用per tensor量化
- 2: feature map数据用per block量化,blocksize等于权重量化时的blocksize,如果权重量化时没有使用per block量化,即使设置2,也不会对feature map做per block量化
- - prefer_decode: 是否希望有更快的解码(Decode)速度。可选:`true, false`,默认`false`。注意:当prompt长度小于300时,`true`条件下的Prefill速度会显著慢于`false`条件下时的性能。当prompt长度高于300时,`true`条件下的Prefill速度和`false`条件基本持平,Decode速度大约会快20%. 如果你希望在各种情况下Prefill速度和Decode速度更加均衡,建议设置该选项为`false`.
+ - 8+n(n=0,1,2): 该选项是为了加速LLM 推理时Decode性能。但是当prompt长度小于300时,Prefill速度会显著变慢。当prompt长度高于300时,Prefill速度不会变慢。
+ - cpu_sme2_neon_division_ratio: 为了提高Arm SME后端多线程推理时性能,可根据模型、线程数定制化设置该参数。参数计算方式: Prefill阶段单个SME核和NEON核的工作量比例x:1,Decode阶段工作量比例y:1,
+ 则参数设置为8*x+y,x和y均是不大于7的正整数。41、49和33是常见的参数设置. 可以通过观察单线程推理时,SME后端相较于NEON后端的加速比来决定该参数的取值。默认是`41`.
- Sampler配置
- sampler_type: 使用的sampler种类,目前支持`greedy`, `temperature`, `topK`, `topP`, `minP`, `tfs`, `typical`, `penalty`8种基本sampler,外加`mixed`(混合sampler,当选择`mixed`时,依次执行mixed_samplers中的sampler)。默认为`greedy`,但是建议使用`mixed`、`temperature`来增加输出多样性,或使用`penalty`来降低重复。
- mixed_samplers: 当`sampler_type`为`mixed`时有效,默认为`["topK", "tfs", "typical", "topP", "min_p", "temperature"]`, 模型计算得到的logits会依次经过这些sampler采样。
@@ -475,6 +485,21 @@ node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
"is_single": true
}
```
+- `context.json`
+ ```json
+ {
+ "tools": [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_current_time",
+ "description": "获取当前时间"
+ }
+ }
+ ],
+ "enable_thinking": false
+ }
+ ```
#### 推理用法
`llm_demo`的用法如下:
@@ -552,6 +577,15 @@ options:
./llm_bench -m ./Qwen2.5-1.5B-Instruct/config.json,./Qwen2.5-0.5B-Instruct/config.json -a cpu,opencl,metal -c 1,2 -t 8,12 -p 16,32 -n 10,20 -pg 8,16 -mmp 0 -rep 4 -kv true -fp ./test_result
```
+#### 多Prompt场景下KVCache选择性复用
+rollback_demo提供了多Prompt场景下自行选择复用部分kvcache的示例代码。
+```bash
+./rollback_demo /path/to/model_dir/config.json /path/to/prompt.txt
+```
+其中,prompt.txt需要包含至少三组prompt。
+- cache_prefix_in_disk需要设置为0或1。
+- cache_prefix_in_disk 设置1表示:第一段Prompt是后续Prompt的公共前缀Prompt,第二、三段Prompt分别是基于第一段Prompt后续的文本内容。第一次启动会将前缀Prompt的KVCache缓存在磁盘文件中。第二次启动会跳过公共前缀Prompt的Prefill,直接在磁盘中加载,提升Prefill速度。
+- cache_prefix_in_disk 设置0表示:在多段Prompt下,如何删除不需要的KVCache,仅保留关联性的KVCache示例。
#### GPTQ权重
需要使用GPTQ权重,可以在导出模型时,使用`--gptq_path PATH`来指定的路径,使用如下:
@@ -771,11 +805,29 @@ print(out_ids)
### LLM 模型导出
NPU运行LLM需要特定的量化格式,需要按如下参数以导出 mnn
+llmexport脚本导出在NPU上运行的模型时,必须使用的选项有:
+- --generate_for_npu: 导出在NPU上运行的模型
+- --seperate_embed: NPU必须使用embedding层和lm层分开存储
+- --sym: 目前NPU仅支持权重对称量化
+用于提高量化精度可以使用的选项,选择其一即可,不可以同时使用:
+- --smooth 使用Smooth量化算法提高精度
+- --omni:使用Omni量化算法提高精度
+部分选项说明:
+- QNN已经支持了feature map使用非对称量化,转模型时可以不使用`--act_sym`,即该选项可视情况加或者不加;
+- NPU目前仅支持feature map使用16bit量化以提高模型精度,所以转模型时加上选项`--act_bit=16`;
+- 经过测试,截止2026年1月,仅仅在高通8Gen5芯片上使用QNN推理时,权重是4bit量化且group=64时,模型性能会比权重8bit量化,group=0时更好。
+- 如果是要转换出在QNN上运行的LLM模型,LM层也会量化,该层的权重量化参数和其他Linear层一致
+- 模型用于量化的校准数据集来源于HuggingFace的wikitext数据集,如果你想要使用指定的多个prompt作为校准数据集,可以使用`--calib_data`选项
+
`--smooth --act_bit=16 --quant_block=0 --lm_quant_bit=16 --quant_bit=4 --seperate_embed --sym --act_sym`
eg:
```
-python3 llmexport.py --path /Users/xtjiang/.cache/modelscope/hub/models/Qwen/Qwen3-4B --export mnn --smooth --act_bit=16 --quant_block=0 --lm_quant_bit=16 --seperate_embed --quant_bit=4 --sym --act_sym
+python3 llmexport.py --path /YourPath/Download/models/Qwen/Qwen3-4B --export mnn --smooth --act_bit=16 --quant_block=0 --lm_quant_bit=16 --seperate_embed --quant_bit=4 --sym --act_sym
+```
+或者你也可以自定义校准数据集,并使用Omni算法提高量化精度:
+```
+python llmexport.py --path /YourPath/Download/models/Qwen/Qwen3-0.6B --export mnn --quant_block 64 --quant_bit 4 --generate_for_npu --seperate_embed --act_bit=16 --sym --omni --hqq --calib_data /Your/prompt.txt
```
### QNN LLM
@@ -844,7 +896,7 @@ mkdir build_64 && cd build_64
```
ANDROID_WORKING_DIR=/data/local/tmp/MNN/
-HEXAGON_ARCH=v75
+HEXAGON_ARCH=75
adb push ${QNN_SDK_ROOT}/lib/aarch64-android/libQnnHtp.so ${ANDROID_WORKING_DIR}
adb push ${QNN_SDK_ROOT}/lib/aarch64-android/libQnnHtpV${HEXAGON_ARCH}Stub.so ${ANDROID_WORKING_DIR}
adb push ${QNN_SDK_ROOT}/lib/hexagon-v${HEXAGON_ARCH}/unsigned/libQnnHtpV${HEXAGON_ARCH}Skel.so ${ANDROID_WORKING_DIR}
diff --git a/express/Expr.cpp b/express/Expr.cpp
index a735adbe3e..4080b7a9b6 100644
--- a/express/Expr.cpp
+++ b/express/Expr.cpp
@@ -237,7 +237,7 @@ EXPRP Expr::create(const OpT* op, std::vector inputs, int outputSize) {
return create(std::move(info), nullptr, VARP::INPUT);
}
if (OpType_Const == op->type || OpType_TrainableParam == op->type) {
- if (!op->externalPath.empty()) {
+ if (!op->externalPath.empty() || (!op->main.AsBlob()->external.empty())) {
flatbuffers::FlatBufferBuilder builder;
auto offset = Op::Pack(builder, op);
builder.Finish(offset);
@@ -813,12 +813,28 @@ void* Variable::readInternal(bool forShape) {
// The Varp will not be created as input, so we just need copy once
return inside->mHostTensor->host();
}
+
inside->mHostTensor = new Tensor;
TensorUtils::copyShape(originTensor, inside->mHostTensor, true);
inside->mHostTensor->buffer().type = originTensor->getType();
inside->mHostTensor->buffer().host = (uint8_t*)MNNMemoryAllocAlign(inside->mHostTensor->size(), MNN_MEMORY_ALIGN_DEFAULT);
TensorUtils::getDescribe(inside->mHostTensor)->memoryType = Tensor::InsideDescribe::MEMORY_HOST;
originTensor->copyToHostTensor(inside->mHostTensor);
+ bool hasNoExecution = false;
+ if (nullptr != originTensor) {
+ auto backend = TensorUtils::getDescribeOrigin(originTensor)->getBackend();
+ if (nullptr != backend) {
+ // Try to sync to check execution status
+ int syncResult = backend->onSync(Tensor::MAP_TENSOR_READ, false, originTensor);
+ if (NO_EXECUTION == syncResult) {
+ hasNoExecution = true;
+ }
+ }
+ }
+ if (hasNoExecution) {
+ MNN_PRINT("\nWarning, Backend has stop execute, return nullptr for current varp\n");
+ return nullptr;
+ }
return inside->mHostTensor->host();
}
return originTensor->buffer().host;
diff --git a/express/MathOp.cpp b/express/MathOp.cpp
index 903a42fa61..ac04f27ee0 100644
--- a/express/MathOp.cpp
+++ b/express/MathOp.cpp
@@ -1137,21 +1137,31 @@ VARP _UnravelIndex(VARP indices, VARP dims) {
VARP _ScatterNd(VARP indices, VARP updates, VARP shape, int reducetion) {
std::unique_ptr op(new OpT);
- op->main.type = OpParameter_BinaryOp;
op->type = OpType_ScatterNd;
- auto param = new BinaryOpT;
- param->opType = (BinaryOpOperation)reducetion;
- op->main.value = param;
+ if (reducetion != -1) {
+ op->main.type = OpParameter_BinaryOp;
+ auto param = new BinaryOpT;
+ param->opType = (BinaryOpOperation)reducetion;
+ op->main.value = param;
+ } else {
+ op->main.type = OpParameter_NONE;
+ op->main.value = nullptr;
+ }
return (Variable::create(Expr::create(std::move(op), {indices, updates, shape})));
}
VARP _ScatterNd(VARP indices, VARP updates, VARP shape, VARP input, int reducetion) {
std::unique_ptr op(new OpT);
- op->main.type = OpParameter_BinaryOp;
op->type = OpType_ScatterNd;
- auto param = new BinaryOpT;
- param->opType = (BinaryOpOperation)reducetion;
- op->main.value = param;
+ if (reducetion != -1) {
+ op->main.type = OpParameter_BinaryOp;
+ auto param = new BinaryOpT;
+ param->opType = (BinaryOpOperation)reducetion;
+ op->main.value = param;
+ } else {
+ op->main.type = OpParameter_NONE;
+ op->main.value = nullptr;
+ }
return (Variable::create(Expr::create(std::move(op), {indices, updates, shape, input})));
}
VARP _ScatterNd(VARP indices, VARP updates, VARP shape) {
diff --git a/express/Utils.cpp b/express/Utils.cpp
index f71f2c997b..6aac549ece 100644
--- a/express/Utils.cpp
+++ b/express/Utils.cpp
@@ -181,6 +181,23 @@ void* Executor::ComputeCache::mapOutput(int offset, Tensor* dest) {
if (0 == tensor->usize()) {
return nullptr;
}
+
+ bool hasNoExecution = false;
+ if (nullptr != tensor) {
+ auto backend = TensorUtils::getDescribeOrigin(tensor)->getBackend();
+ if (nullptr != backend) {
+ // Try to sync to check execution status
+ int syncResult = backend->onSync(Tensor::MAP_TENSOR_READ, false, tensor);
+ if (NO_EXECUTION == syncResult) {
+ hasNoExecution = true;
+ }
+ }
+ }
+ if (hasNoExecution) {
+ MNN_PRINT("\nWarning, Backend has stop execute, return nullptr for current varp\n");
+ return nullptr;
+ }
+
Utils::allocMemoryForHostTensor(dest);
if(nullptr != dest->host()) {
tensor->copyToHostTensor(dest);
diff --git a/express/module/MoEModule.cpp b/express/module/MoEModule.cpp
index 9ae9f3192a..2f51586a3f 100644
--- a/express/module/MoEModule.cpp
+++ b/express/module/MoEModule.cpp
@@ -20,8 +20,9 @@ std::vector MoEModule::onForward(const std::vector
auto routingWeights = inputs[1];
auto selectedExperts = inputs[2];
auto selectedDim = selectedExperts->getInfo()->dim;
- const int seqlen = selectedDim[1];
- const int topK = selectedDim[2];
+ int ranks = static_cast(selectedDim.size());
+ const int seqlen = selectedDim[ranks - 2];
+ const int topK = selectedDim[ranks - 1];
MNN_ASSERT(topK == mTopK);
auto selectedPtr = selectedExperts->readMap();
// decode
diff --git a/express/module/Module.cpp b/express/module/Module.cpp
index f8b4728153..3c54ca89c3 100644
--- a/express/module/Module.cpp
+++ b/express/module/Module.cpp
@@ -9,6 +9,7 @@
#include
#include
#include
+#include
#include "core/OpCommonUtils.hpp"
#include "PipelineModule.hpp"
#include "core/FileLoader.hpp"
@@ -17,6 +18,7 @@
#include "Utils.hpp"
#include "RuntimeAttr.hpp"
#include "ModuleInside.hpp"
+#include "core/TensorUtils.hpp"
#include
#ifdef MNN_INTERNAL_ENABLED
#include "internal/auth/ModelAuth.hpp"
@@ -221,6 +223,30 @@ class NetModule : public Module {
Executor::RuntimeExecuteWrap wrap(mInfo->runTimeManager->getInside()->mRuntime);
outputs = mModule->onForward(inputs);
}
+
+ // Check execution status after forward
+ if (!outputs.empty()) {
+ bool hasNoExecution = false;
+ for (auto& v : outputs) {
+ auto t = Utils::getTensor(v);
+ if (nullptr != t) {
+ auto backend = TensorUtils::getDescribeOrigin(t)->getBackend();
+ if (nullptr != backend) {
+ // Try to sync to check execution status
+ int syncResult = backend->onSync(Tensor::MAP_TENSOR_READ, false, t);
+ if (NO_EXECUTION == syncResult) {
+ hasNoExecution = true;
+ break;
+ }
+ }
+ }
+ }
+ if (hasNoExecution) {
+ MNN_PRINT("Warning, Backend has stop execute, return empty output vector varps\n");
+ outputs.clear();
+ }
+ }
+
#ifdef MNN_INTERNAL_ENABLED
do {
if (outputs.empty()) {
diff --git a/express/module/StaticModule.cpp b/express/module/StaticModule.cpp
index f99c2fdaf3..980aeb36b7 100644
--- a/express/module/StaticModule.cpp
+++ b/express/module/StaticModule.cpp
@@ -133,7 +133,9 @@ static std::vector> preRearrangeWeights( // NOLIN
}
break;
}
- case MNN::OpType_Attention: {
+ case MNN::OpType_Attention:
+ case MNN::OpType_LinearAttention:
+ {
exe.reset(backend->onCreate({}, {}, op));
if (exe.get() == nullptr) {
exe.reset(backupBackend->onCreate({}, {}, op));
@@ -149,13 +151,29 @@ static std::vector> preRearrangeWeights( // NOLIN
break;
}
case MNN::OpType_LayerNorm: {
- std::shared_ptr tmpstorage;
- exe.reset(OpCommonUtils::createExecutionWithExternal(backend, info.inputs, info.outputs, op, &loader, tmpstorage));
- if (exe.get() == nullptr) {
- exe.reset(OpCommonUtils::createExecutionWithExternal(backupBackend, info.inputs, info.outputs, op, &loader, tmpstorage));
+ if (!base_executions.empty() && op->name()) {
+ auto iter = base_executions.find(op->name()->str());
+ if (iter != base_executions.end()) {
+ auto base_exe = iter->second.get();
+ Execution* copyExecution = nullptr;
+ base_exe->onClone(backend, op, ©Execution);
+ if (copyExecution == nullptr) {
+ base_exe->onClone(backupBackend, op, ©Execution);
+ }
+ if (copyExecution != nullptr && copyExecution->onClone(nullptr, op, nullptr)) {
+ exe.reset(copyExecution);
+ }
+ }
}
- if (nullptr == exe) {
- break;
+ if (exe == nullptr) {
+ std::shared_ptr tmpstorage;
+ exe.reset(OpCommonUtils::createExecutionWithExternal(backend, info.inputs, info.outputs, op, &loader, tmpstorage));
+ if (exe.get() == nullptr) {
+ exe.reset(OpCommonUtils::createExecutionWithExternal(backupBackend, info.inputs, info.outputs, op, &loader, tmpstorage));
+ }
+ if (nullptr == exe) {
+ break;
+ }
}
// The exe can't clone
if (!exe->onClone(nullptr, op, nullptr)) {
diff --git a/include/MNN/Interpreter.hpp b/include/MNN/Interpreter.hpp
index d25af2e5a4..dfc3f11d78 100644
--- a/include/MNN/Interpreter.hpp
+++ b/include/MNN/Interpreter.hpp
@@ -119,7 +119,7 @@ class MNN_PUBLIC Interpreter {
*/
static Interpreter* createFromBuffer(const void* buffer, size_t size);
~Interpreter();
-
+
/**
* @brief destroy Interpreter
* @param model given Interpreter to release.
@@ -153,18 +153,18 @@ class MNN_PUBLIC Interpreter {
Session_Backend_Auto = 9, // Auto Determine the Op type by MNN
/** Determine static memory whether recyle in resizeSession or just cache the memory */
- Session_Memory_Collect = 10, // Recycle static memory when session resize in case memory explosion
+ Session_Memory_Collect = 10, // Recycle static memory when session resize in case memory explosion
Session_Memory_Cache = 11, // Cache the static memory for next forward usage
/** Determine whether use codegen function */
Session_Codegen_Disable = 12, // Disable codegen in case extra build codegen cost
Session_Codegen_Enable = 13, // Enable codegen
-
+
/** Dynamic Reisze Optimization */
Session_Resize_Check = 14, // Open Trace for resize
Session_Resize_Fix = 15, // Apply Resize Optimization
-
- /** Set for Module's traceOrOptimize API.
+
+ /** Set for Module's traceOrOptimize API.
Module_Forward_Seperate:
when inputs is not empty , Module's onForward will only infer shape and alloc memory.
when inputs is empty , Module's onForward will only runSession to compute content.
@@ -199,7 +199,7 @@ class MNN_PUBLIC Interpreter {
* If resize session generate new cache info, try to rewrite cache file.
* If resize session do not generate any new cache info, just do nothing.
* @param session given session
- * @param flag Protected param, not used now
+ * @param flag Protected param, not used now
*/
ErrorCode updateCacheFile(Session *session, int flag = 0);
@@ -226,12 +226,15 @@ class MNN_PUBLIC Interpreter {
// Default is 50
CPU_LITTLECORE_DECREASE_RATE = 6,
+ // attentionOption % 8:
// 0: Do not quantize
- // 1: Only quantize key, use int8 asymmetric quantization
- // 2: Only quantize value, use fp8 quantization
- // 3: quantize both key and value
- // 4: quantize query, key and value, and use gemm int8 kernel to compute K*V
- QKV_QUANT_OPTIONS = 7,
+ // 1: Q,K: Int8, V: Float
+ // 2: Q,K,V: Int8
+
+ // attentionOption / 8:
+ // 0: don't use flash attention
+ // 1: use flash attention
+ ATTENTION_OPTION = 7,
// size limit of kvcache in memory (for a single layer)
// if the size of kvcache exceeds the limit, it will be moved to disk
@@ -244,7 +247,7 @@ class MNN_PUBLIC Interpreter {
// mmap allocate file size, KB
MMAP_FILE_SIZE = 11,
USE_CACHED_MMAP = 12,
-
+
// Multi-Thread Load module, default is 0 (don't use other Thread)
INIT_THREAD_NUMBER = 13,
@@ -255,13 +258,19 @@ class MNN_PUBLIC Interpreter {
CPU_SME2_INSTRUCTIONS = 15,
// Enable KleidiAI
- CPU_ENABLE_KLEIDIAI = 16
+ CPU_ENABLE_KLEIDIAI = 16,
+
+ // Set CPU SME2 NEON division ratio, default is 41
+ CPU_SME2_NEON_DIVISION_RATIO = 17,
+
+ // Set SME cores, default is 2, if supports sme
+ CPU_SME_CORES = 18
};
enum ExternalPathType {
// Path of the kvcache directory
EXTERNAL_PATH_KVCACHE_DIR = 0,
-
+
// Mid Buffer Cache File
EXTERNAL_FEATUREMAP_DIR = 1,
@@ -271,6 +280,9 @@ class MNN_PUBLIC Interpreter {
// Path of the NPU Model directory
EXTERNAL_NPU_FILE_DIR = 3,
+ // Path of the kvcache directory
+ EXTERNAL_PATH_PREFIXCACHE_DIR = 4,
+
// Other types ...
};
@@ -283,10 +295,10 @@ class MNN_PUBLIC Interpreter {
// Use loop instead of raster + compute if possible
GEOMETRCOMPUTEMASK_USELOOP = 1 << 2,
-
+
// Support Geometry Cache, if shape changed, will try recompute, and then run compute if failed
GEOMETRCOMPUTEMASK_OPENCACHE = 1 << 3,
-
+
// Full option open mask, for example, if want to close useloop, can set mask as (GEOMETRCOMPUTEMASK_ALL - GEOMETRCOMPUTEMASK_USELOOP)
GEOMETRCOMPUTEMASK_ALL = 0xFFFF,
};
@@ -357,7 +369,7 @@ class MNN_PUBLIC Interpreter {
*/
void resizeSession(Session* session, int needRelloc);
-
+
/**
* @brief call this function if don't need resize or create session any more, it will save a few memory that equal
* to the size of model buffer
@@ -447,7 +459,7 @@ class MNN_PUBLIC Interpreter {
RuntimeManager::getInfo: 0: no resize, 1: re-malloc, 2: resize
*/
RESIZE_STATUS = 3,
-
+
/** Mode / NumberThread, int* */
THREAD_NUMBER = 4,
diff --git a/include/MNN/MNNDefine.h b/include/MNN/MNNDefine.h
index b8e391e3eb..dfe50fa640 100644
--- a/include/MNN/MNNDefine.h
+++ b/include/MNN/MNNDefine.h
@@ -59,7 +59,6 @@
if(!(success)){ \
MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \
}
-#ifndef MNN_BUILD_STATIC_LIBS
#if defined(_MSC_VER)
#if defined(BUILDING_MNN_DLL)
#define MNN_PUBLIC __declspec(dllexport)
@@ -71,13 +70,10 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \
#else
#define MNN_PUBLIC __attribute__((visibility("default")))
#endif
-#else
-#define MNN_PUBLIC
-#endif
#define STR_IMP(x) #x
#define STR(x) STR_IMP(x)
#define MNN_VERSION_MAJOR 3
-#define MNN_VERSION_MINOR 3
-#define MNN_VERSION_PATCH 0
+#define MNN_VERSION_MINOR 4
+#define MNN_VERSION_PATCH 1
#define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH)
#endif /* MNNDefine_h */
diff --git a/prepare_qnn_deps.sh b/prepare_qnn_deps.sh
index 2e3086a684..ef57355b94 100755
--- a/prepare_qnn_deps.sh
+++ b/prepare_qnn_deps.sh
@@ -7,7 +7,7 @@ set -e
# --- Configuration ---
# URL for QNN libraries zip file
-QNN_LIBS_URL='http://meta.alicdn.com/data/mnn/libs/qnn_inc_libs.zip'
+QNN_LIBS_URL='http://meta.alicdn.com/data/mnn/libs/qnn_inc_libs_2_37.zip'
# Project root is the current directory where the script is run
PROJECT_ROOT=$(pwd)
# Temporary directory for downloads and extraction
@@ -21,15 +21,47 @@ QNN_DEST_DIR="$PROJECT_ROOT/source/backend/qnn/3rdParty"
echo "BUILD_QNN is ON. Preparing QNN dependencies..."
-# 1. Download the QNN zip file if it doesn't exist
+# Fast-path: if destination already prepared, skip download/unpack/copy
+DEST_INCLUDE_DIR="$QNN_DEST_DIR/include"
+DEST_LIB_DIR="$QNN_DEST_DIR/lib"
+if [ -d "$DEST_INCLUDE_DIR" ] && [ -n "$(find "$DEST_INCLUDE_DIR" -mindepth 1 -print -quit 2>/dev/null)" ]; then
+ echo "Detected existing QNN SDK at: $QNN_DEST_DIR"
+ # Ensure env vars are set even when skipping work
+ QNN_SDK_ROOT_PATH="$(cd "$QNN_DEST_DIR" && pwd)"
+ export QNN_SDK_ROOT="$QNN_SDK_ROOT_PATH"
+ ENV_FILE="$PROJECT_ROOT/.qnn_env"
+ echo "export QNN_SDK_ROOT=\"$QNN_SDK_ROOT_PATH\"" > "$ENV_FILE"
+ echo "Set QNN_SDK_ROOT to: $QNN_SDK_ROOT_PATH"
+ echo "You can add it to your shell by running: source $ENV_FILE"
+ exit 0
+fi
+
+# 1. Download the QNN zip file if it doesn't exist or is corrupted
mkdir -p "$BUILD_DIR"
-if [ ! -f "$QNN_ZIP_FILE" ]; then
+
+download_qnn_zip() {
echo "Downloading QNN dependencies from ${QNN_LIBS_URL}"
- # Use curl to download, following redirects (-L) and showing progress
- curl -L -o "$QNN_ZIP_FILE" "$QNN_LIBS_URL"
+ curl -fL --retry 3 --retry-delay 2 -o "$QNN_ZIP_FILE" "$QNN_LIBS_URL"
echo "Downloaded to: ${QNN_ZIP_FILE}"
+}
+
+validate_zip() {
+ unzip -tq "$QNN_ZIP_FILE" >/dev/null 2>&1
+}
+
+if [ ! -f "$QNN_ZIP_FILE" ]; then
+ download_qnn_zip
else
echo "Using cached zip: ${QNN_ZIP_FILE}"
+ if ! validate_zip; then
+ echo "Cached zip appears to be invalid or corrupted. Re-downloading..."
+ rm -f "$QNN_ZIP_FILE"
+ download_qnn_zip
+ if ! validate_zip; then
+ echo "Error: Downloaded zip is invalid. Please try again later." >&2
+ exit 1
+ fi
+ fi
fi
# 2. Unpack the zip file into a clean temporary directory
@@ -57,8 +89,6 @@ EXTRACTED_QNN_DIR=$(dirname "$INCLUDE_DIR")
echo "Found QNN content in: $EXTRACTED_QNN_DIR"
# 4. Copy headers and libraries to their final destination
-DEST_INCLUDE_DIR="$QNN_DEST_DIR/include"
-DEST_LIB_DIR="$QNN_DEST_DIR/lib"
echo "Creating destination directories..."
mkdir -p "$DEST_INCLUDE_DIR"
@@ -85,8 +115,16 @@ else
echo "Warning: No 'lib' or 'jniLibs' directory found in $EXTRACTED_QNN_DIR"
fi
-# 5. Clean up temporary build directory
+# 5. Clean up temporary unzip directory but keep the cached zip for future runs
echo "Cleaning up temporary files..."
-rm -rf "$BUILD_DIR"
+rm -rf "$QNN_TMP_DIR"
echo "QNN dependencies preparation completed successfully."
+
+# 6. Export QNN_SDK_ROOT for current shell and persist to .qnn_env for future shells
+QNN_SDK_ROOT_PATH="$(cd "$QNN_DEST_DIR" && pwd)"
+export QNN_SDK_ROOT="$QNN_SDK_ROOT_PATH"
+ENV_FILE="$PROJECT_ROOT/.qnn_env"
+echo "export QNN_SDK_ROOT=\"$QNN_SDK_ROOT_PATH\"" > "$ENV_FILE"
+echo "Set QNN_SDK_ROOT to: $QNN_SDK_ROOT_PATH"
+echo "You can add it to your shell by running: source $ENV_FILE"
diff --git a/pymnn/examples/MNNLlm/llm_example.py b/pymnn/examples/MNNLlm/llm_example.py
index ec96c0afb0..47925e6d54 100644
--- a/pymnn/examples/MNNLlm/llm_example.py
+++ b/pymnn/examples/MNNLlm/llm_example.py
@@ -7,13 +7,15 @@
config_path = sys.argv[1]
# create model
-qwen = llm.create(config_path)
+model = llm.create(config_path)
# load model
-qwen.load()
+model.load()
# response stream
-out = qwen.response('你好', True)
-print(out)
+print('>>> Model Status: ', model.context.status)
+out = model.response('你好', True)
+print('>>> Model Status: ', model.context.status)
-out_ids = qwen.generate([151644, 872, 198, 108386, 151645, 198, 151644, 77091])
+# generate
+out_ids = model.generate([151644, 872, 198, 108386, 151645, 198, 151644, 77091])
print(out_ids)
diff --git a/pymnn/examples/MNNLlm/vllm_exmaple.py b/pymnn/examples/MNNLlm/vlm_example.py
similarity index 100%
rename from pymnn/examples/MNNLlm/vllm_exmaple.py
rename to pymnn/examples/MNNLlm/vlm_example.py
diff --git a/pymnn/pip_package/MNN/__init__.py b/pymnn/pip_package/MNN/__init__.py
index 89ed46b145..6910552898 100644
--- a/pymnn/pip_package/MNN/__init__.py
+++ b/pymnn/pip_package/MNN/__init__.py
@@ -9,4 +9,5 @@
from . import optim
from . import numpy
from . import cv
-from . import audio
\ No newline at end of file
+from . import audio
+from . import llm
diff --git a/pymnn/pip_package/MNN/llm/__init__.py b/pymnn/pip_package/MNN/llm/__init__.py
index cf6626d645..d8291b4526 100644
--- a/pymnn/pip_package/MNN/llm/__init__.py
+++ b/pymnn/pip_package/MNN/llm/__init__.py
@@ -1,4 +1,18 @@
import _mnncengine.llm as _F
+from enum import IntEnum
+
+class LlmStatus(IntEnum):
+ RUNNING = 0
+ NORMAL_FINISHED = 1
+ MAX_TOKENS_FINISHED = 2
+ USER_CANCEL = 3
+ INTERNAL_ERROR = 4
+
+ def __str__(self):
+ return "{}.{}".format(self.__class__.__name__, self.name)
+
+ def __repr__(self):
+ return "{}.{}".format(self.__class__.__name__, self.name)
class Context:
def __init__(self, llm_obj):
@@ -116,6 +130,7 @@ def output_tokens(self):
def output_tokens(self, value):
self.update(output_tokens=value)
+
@property
def generate_str(self):
return self._data.get('generate_str', '')
@@ -124,8 +139,19 @@ def generate_str(self):
def generate_str(self, value):
self.update(generate_str=value)
+ @property
+ def status(self):
+ return LlmStatus(self._data.get('status', 0))
+
+ @status.setter
+ def status(self, value):
+ if isinstance(value, LlmStatus):
+ self.update(status=int(value))
+ else:
+ self.update(status=int(value))
+
def __repr__(self):
- return f"Context({self._data})"
+ return "Context({})".format(self._data)
class Llm:
@@ -425,4 +451,4 @@ def create(config_path, embedding_model = False):
>>> llm = mllm.create('./qwen-1.8b-int4/config.json')
'''
c_obj = _F.create(config_path, embedding_model)
- return Llm(c_obj)
\ No newline at end of file
+ return Llm(c_obj)
diff --git a/pymnn/pip_package/setup.py b/pymnn/pip_package/setup.py
index 52018354c8..ca478e95a8 100644
--- a/pymnn/pip_package/setup.py
+++ b/pymnn/pip_package/setup.py
@@ -61,7 +61,7 @@ def report(*args):
""" print information """
print(*args)
-package_name = 'MNN'
+package_name = 'mnn'
USE_INTERNAL = False
USE_TRT = False
USE_CUDA = False
@@ -91,20 +91,20 @@ def report(*args):
print ("USE_RENDER:", USE_RENDER)
if os.path.isdir('../../schema/private'):
- package_name += '_Internal'
+ package_name += '_internal'
else:
USE_INTERNAL = False
if USE_TRT:
- package_name += '_TRT'
+ package_name += '_trt'
if USE_CUDA:
- package_name += '_CUDA'
+ package_name += '_cuda'
if USE_VULKAN:
- package_name += '_VULKAN'
+ package_name += '_vulkan'
if USE_OPENCL:
- package_name += '_OPENCL'
+ package_name += '_opencl'
if USE_RENDER:
- package_name += '_RENDER'
+ package_name += '_render'
print ('Building with python wheel with package name ', package_name)
diff --git a/pymnn/src/llm.h b/pymnn/src/llm.h
index 2e7319857b..f20bf2f26f 100644
--- a/pymnn/src/llm.h
+++ b/pymnn/src/llm.h
@@ -174,7 +174,7 @@ static PyObject* PyMNNLLM_response(LLM *self, PyObject *args) {
PyObject* content = nullptr;
int stream = 0;
- int max_new_tokens = 2048;
+ int max_new_tokens = -1;
if (!PyArg_ParseTuple(args, "O|ii", &content, &stream, &max_new_tokens)) {
MNN_PRINT("[MNNLLM] response: PyArg_ParseTuple failed\n");
@@ -368,6 +368,9 @@ static PyObject* PyMNNLLM_get_context(LLM *self, PyObject *args) {
PyDict_SetItemString(dict, "generate_str", string2Object(context->generate_str));
+ // llm status
+ PyDict_SetItemString(dict, "status", PyLong_FromLong((int)context->status));
+
return dict;
}
@@ -429,6 +432,11 @@ static PyObject* PyMNNLLM_set_context(LLM *self, PyObject *args) {
context->generate_str = object2String(generate_str);
}
+ PyObject* status = PyDict_GetItemString(dict, "status");
+ if (status && PyLong_Check(status)) {
+ context->status = (MNN::Transformer::LlmStatus)PyLong_AsLong(status);
+ }
+
Py_RETURN_NONE;
}
diff --git a/pymnn/src/reranker.h b/pymnn/src/reranker.h
index 641ca9e3c7..d369cf4921 100644
--- a/pymnn/src/reranker.h
+++ b/pymnn/src/reranker.h
@@ -54,6 +54,16 @@ static PyObject* PyMNNReranker_setInstruct(Reranker *self, PyObject *args) {
Py_RETURN_NONE;
}
+static PyObject* PyMNNReranker_load(Reranker *self, PyObject *args) {
+ if (!self->reranker) {
+ PyErr_SetString(PyExc_RuntimeError, "Reranker not initialized");
+ Py_RETURN_NONE;
+ }
+
+ self->reranker->load();
+ Py_RETURN_NONE;
+}
+
static PyObject* PyMNNReranker_compute_scores(Reranker *self, PyObject *args) {
if (!self->reranker) {
PyErr_SetString(PyExc_RuntimeError, "Reranker not initialized");
@@ -118,6 +128,7 @@ static PyObject* PyMNNReranker_get_llm(Reranker *self, PyObject *args) {
static PyMethodDef PyMNNReranker_methods[] = {
{"set_instruct", (PyCFunction)PyMNNReranker_setInstruct, METH_VARARGS, "Set instruction for the reranker."},
+ {"load", (PyCFunction)PyMNNReranker_load, METH_VARARGS, "Load the reranker model."},
{"compute_scores", (PyCFunction)PyMNNReranker_compute_scores, METH_VARARGS, "Compute scores for documents given a query."},
{"get_llm", (PyCFunction)PyMNNReranker_get_llm, METH_VARARGS, "Get the underlying LLM instance for configuration."},
{NULL} /* Sentinel */
diff --git a/pymnn/update_mnn_wrapper_assets.sh b/pymnn/update_mnn_wrapper_assets.sh
index c72b82a4d2..ac1fb8b452 100755
--- a/pymnn/update_mnn_wrapper_assets.sh
+++ b/pymnn/update_mnn_wrapper_assets.sh
@@ -33,9 +33,9 @@ rm -rf tools
cat __init__.py | sed '/from . import tools/d' > __init__.py.tmp
mv __init__.py.tmp __init__.py
-rm -rf llm
-cat __init__.py | sed '/from . import llm/d' > __init__.py.tmp
-mv __init__.py.tmp __init__.py
+# rm -rf llm
+# cat __init__.py | sed '/from . import llm/d' > __init__.py.tmp
+# mv __init__.py.tmp __init__.py
rm -rf audio
cat __init__.py | sed '/from . import audio/d' > __init__.py.tmp
diff --git a/schema/current/MNN_generated.h b/schema/current/MNN_generated.h
index 377ea2dce8..d9f33ef0b1 100644
--- a/schema/current/MNN_generated.h
+++ b/schema/current/MNN_generated.h
@@ -27,6 +27,9 @@ struct StringVecT;
struct AttentionParam;
struct AttentionParamT;
+struct LinearAttentionParam;
+struct LinearAttentionParamT;
+
struct FmhaV2Param;
struct FmhaV2ParamT;
@@ -77,6 +80,8 @@ inline const flatbuffers::TypeTable *StringVecTypeTable();
inline const flatbuffers::TypeTable *AttentionParamTypeTable();
+inline const flatbuffers::TypeTable *LinearAttentionParamTypeTable();
+
inline const flatbuffers::TypeTable *FmhaV2ParamTypeTable();
inline const flatbuffers::TypeTable *FmhcaParamTypeTable();
@@ -278,6 +283,7 @@ enum OpType {
OpType_SeqLen2Spatial = 302,
OpType_SplitGeLU = 303,
OpType_GroupNorm = 304,
+ OpType_LinearAttention = 305,
OpType_Extra = 512,
OpType_ConvInt8 = 513,
OpType_Int8ToFloat = 514,
@@ -292,7 +298,7 @@ enum OpType {
OpType_MAX = OpType_GridSample
};
-inline const OpType (&EnumValuesOpType())[182] {
+inline const OpType (&EnumValuesOpType())[183] {
static const OpType values[] = {
OpType_AbsVal,
OpType_QuantizedAdd,
@@ -466,6 +472,7 @@ inline const OpType (&EnumValuesOpType())[182] {
OpType_SeqLen2Spatial,
OpType_SplitGeLU,
OpType_GroupNorm,
+ OpType_LinearAttention,
OpType_Extra,
OpType_ConvInt8,
OpType_Int8ToFloat,
@@ -787,7 +794,7 @@ inline const char * const *EnumNamesOpType() {
"SeqLen2Spatial",
"SplitGeLU",
"GroupNorm",
- "",
+ "LinearAttention",
"",
"",
"",
@@ -1199,11 +1206,12 @@ enum OpParameter {
OpParameter_FmhcaParam = 97,
OpParameter_AttentionParam = 98,
OpParameter_StftParam = 99,
+ OpParameter_LinearAttentionParam = 100,
OpParameter_MIN = OpParameter_NONE,
- OpParameter_MAX = OpParameter_StftParam
+ OpParameter_MAX = OpParameter_LinearAttentionParam
};
-inline const OpParameter (&EnumValuesOpParameter())[100] {
+inline const OpParameter (&EnumValuesOpParameter())[101] {
static const OpParameter values[] = {
OpParameter_NONE,
OpParameter_QuantizedAdd,
@@ -1304,7 +1312,8 @@ inline const OpParameter (&EnumValuesOpParameter())[100] {
OpParameter_FmhaV2Param,
OpParameter_FmhcaParam,
OpParameter_AttentionParam,
- OpParameter_StftParam
+ OpParameter_StftParam,
+ OpParameter_LinearAttentionParam
};
return values;
}
@@ -1411,13 +1420,14 @@ inline const char * const *EnumNamesOpParameter() {
"FmhcaParam",
"AttentionParam",
"StftParam",
+ "LinearAttentionParam",
nullptr
};
return names;
}
inline const char *EnumNameOpParameter(OpParameter e) {
- if (e < OpParameter_NONE || e > OpParameter_StftParam) return "";
+ if (e < OpParameter_NONE || e > OpParameter_LinearAttentionParam) return "";
const size_t index = static_cast(e);
return EnumNamesOpParameter()[index];
}
@@ -1822,6 +1832,10 @@ template<> struct OpParameterTraits {
static const OpParameter enum_value = OpParameter_StftParam;
};
+template<> struct OpParameterTraits {
+ static const OpParameter enum_value = OpParameter_LinearAttentionParam;
+};
+
struct OpParameterUnion {
OpParameter type;
void *value;
@@ -2645,6 +2659,14 @@ struct OpParameterUnion {
return type == OpParameter_StftParam ?
reinterpret_cast(value) : nullptr;
}
+ LinearAttentionParamT *AsLinearAttentionParam() {
+ return type == OpParameter_LinearAttentionParam ?
+ reinterpret_cast(value) : nullptr;
+ }
+ const LinearAttentionParamT *AsLinearAttentionParam() const {
+ return type == OpParameter_LinearAttentionParam ?
+ reinterpret_cast(value) : nullptr;
+ }
};
bool VerifyOpParameter(flatbuffers::Verifier &verifier, const void *obj, OpParameter type);
@@ -3005,6 +3027,115 @@ inline flatbuffers::Offset CreateAttentionParam(
flatbuffers::Offset CreateAttentionParam(flatbuffers::FlatBufferBuilder &_fbb, const AttentionParamT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct LinearAttentionParamT : public flatbuffers::NativeTable {
+ typedef LinearAttentionParam TableType;
+ std::string attn_type;
+ int32_t num_k_heads;
+ int32_t num_v_heads;
+ int32_t head_k_dim;
+ int32_t head_v_dim;
+ bool use_qk_l2norm;
+ LinearAttentionParamT()
+ : num_k_heads(0),
+ num_v_heads(0),
+ head_k_dim(0),
+ head_v_dim(0),
+ use_qk_l2norm(false) {
+ }
+};
+
+struct LinearAttentionParam FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef LinearAttentionParamT NativeTableType;
+ static const flatbuffers::TypeTable *MiniReflectTypeTable() {
+ return LinearAttentionParamTypeTable();
+ }
+ const flatbuffers::String *attn_type() const {
+ return GetPointer(4);
+ }
+ int32_t num_k_heads() const {
+ return GetField(6, 0);
+ }
+ int32_t num_v_heads() const {
+ return GetField(8, 0);
+ }
+ int32_t head_k_dim() const {
+ return GetField(10, 0);
+ }
+ int32_t head_v_dim() const {
+ return GetField(12, 0);
+ }
+ bool use_qk_l2norm() const {
+ return GetField(14, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, 4) &&
+ verifier.VerifyString(attn_type()) &&
+ VerifyField(verifier, 6) &&
+ VerifyField(verifier, 8) &&
+ VerifyField(verifier, 10) &&
+ VerifyField(verifier, 12) &&
+ VerifyField(verifier, 14) &&
+ verifier.EndTable();
+ }
+ LinearAttentionParamT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(LinearAttentionParamT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LinearAttentionParamT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct LinearAttentionParamBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_attn_type(flatbuffers::Offset attn_type) {
+ fbb_.AddOffset(4, attn_type);
+ }
+ void add_num_k_heads(int32_t num_k_heads) {
+ fbb_.AddElement(6, num_k_heads, 0);
+ }
+ void add_num_v_heads(int32_t num_v_heads) {
+ fbb_.AddElement(8, num_v_heads, 0);
+ }
+ void add_head_k_dim(int32_t head_k_dim) {
+ fbb_.AddElement(10, head_k_dim, 0);
+ }
+ void add_head_v_dim(int32_t head_v_dim) {
+ fbb_.AddElement(12, head_v_dim, 0);
+ }
+ void add_use_qk_l2norm(bool use_qk_l2norm) {
+ fbb_.AddElement(14, static_cast(use_qk_l2norm), 0);
+ }
+ explicit LinearAttentionParamBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ LinearAttentionParamBuilder &operator=(const LinearAttentionParamBuilder &);
+ flatbuffers::Offset Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset CreateLinearAttentionParam(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset attn_type = 0,
+ int32_t num_k_heads = 0,
+ int32_t num_v_heads = 0,
+ int32_t head_k_dim = 0,
+ int32_t head_v_dim = 0,
+ bool use_qk_l2norm = false) {
+ LinearAttentionParamBuilder builder_(_fbb);
+ builder_.add_head_v_dim(head_v_dim);
+ builder_.add_head_k_dim(head_k_dim);
+ builder_.add_num_v_heads(num_v_heads);
+ builder_.add_num_k_heads(num_k_heads);
+ builder_.add_attn_type(attn_type);
+ builder_.add_use_qk_l2norm(use_qk_l2norm);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset CreateLinearAttentionParam(flatbuffers::FlatBufferBuilder &_fbb, const LinearAttentionParamT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct FmhaV2ParamT : public flatbuffers::NativeTable {
typedef FmhaV2Param TableType;
int32_t heads;
@@ -3971,6 +4102,9 @@ struct Op FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const StftParam *main_as_StftParam() const {
return main_type() == OpParameter_StftParam ? static_cast(main()) : nullptr;
}
+ const LinearAttentionParam *main_as_LinearAttentionParam() const {
+ return main_type() == OpParameter_LinearAttentionParam ? static_cast(main()) : nullptr;
+ }
const flatbuffers::String *name() const {
return GetPointer(10);
}
@@ -4404,6 +4538,10 @@ template<> inline const StftParam *Op::main_as() const {
return main_as_StftParam();
}
+template<> inline const LinearAttentionParam *Op::main_as() const {
+ return main_as_LinearAttentionParam();
+}
+
struct OpBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -5227,6 +5365,47 @@ inline flatbuffers::Offset CreateAttentionParam(flatbuffers::Fla
_kv_cache);
}
+inline LinearAttentionParamT *LinearAttentionParam::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new LinearAttentionParamT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void LinearAttentionParam::UnPackTo(LinearAttentionParamT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = attn_type(); if (_e) _o->attn_type = _e->str(); };
+ { auto _e = num_k_heads(); _o->num_k_heads = _e; };
+ { auto _e = num_v_heads(); _o->num_v_heads = _e; };
+ { auto _e = head_k_dim(); _o->head_k_dim = _e; };
+ { auto _e = head_v_dim(); _o->head_v_dim = _e; };
+ { auto _e = use_qk_l2norm(); _o->use_qk_l2norm = _e; };
+}
+
+inline flatbuffers::Offset LinearAttentionParam::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LinearAttentionParamT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateLinearAttentionParam(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset CreateLinearAttentionParam(flatbuffers::FlatBufferBuilder &_fbb, const LinearAttentionParamT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LinearAttentionParamT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _attn_type = _o->attn_type.empty() ? 0 : _fbb.CreateString(_o->attn_type);
+ auto _num_k_heads = _o->num_k_heads;
+ auto _num_v_heads = _o->num_v_heads;
+ auto _head_k_dim = _o->head_k_dim;
+ auto _head_v_dim = _o->head_v_dim;
+ auto _use_qk_l2norm = _o->use_qk_l2norm;
+ return MNN::CreateLinearAttentionParam(
+ _fbb,
+ _attn_type,
+ _num_k_heads,
+ _num_v_heads,
+ _head_k_dim,
+ _head_v_dim,
+ _use_qk_l2norm);
+}
+
inline FmhaV2ParamT *FmhaV2Param::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new FmhaV2ParamT();
UnPackTo(_o, _resolver);
@@ -6163,6 +6342,10 @@ inline bool VerifyOpParameter(flatbuffers::Verifier &verifier, const void *obj,
auto ptr = reinterpret_cast(obj);
return verifier.VerifyTable(ptr);
}
+ case OpParameter_LinearAttentionParam: {
+ auto ptr = reinterpret_cast(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -6577,6 +6760,10 @@ inline void *OpParameterUnion::UnPack(const void *obj, OpParameter type, const f
auto ptr = reinterpret_cast(obj);
return ptr->UnPack(resolver);
}
+ case OpParameter_LinearAttentionParam: {
+ auto ptr = reinterpret_cast(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -6979,6 +7166,10 @@ inline flatbuffers::Offset OpParameterUnion::Pack(flatbuffers::FlatBufferB
auto ptr = reinterpret_cast(value);
return CreateStftParam(_fbb, ptr, _rehasher).Union();
}
+ case OpParameter_LinearAttentionParam: {
+ auto ptr = reinterpret_cast(value);
+ return CreateLinearAttentionParam(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -7381,6 +7572,10 @@ inline OpParameterUnion::OpParameterUnion(const OpParameterUnion &u) FLATBUFFERS
value = new StftParamT(*reinterpret_cast(u.value));
break;
}
+ case OpParameter_LinearAttentionParam: {
+ value = new LinearAttentionParamT(*reinterpret_cast(u.value));
+ break;
+ }
default:
break;
}
@@ -7883,6 +8078,11 @@ inline void OpParameterUnion::Reset() {
delete ptr;
break;
}
+ case OpParameter_LinearAttentionParam: {
+ auto ptr = reinterpret_cast(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
@@ -8072,12 +8272,13 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 },
+ { flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 }
};
static const flatbuffers::TypeFunction type_refs[] = {
OpTypeTypeTable
};
- static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 299, 300, 301, 302, 303, 304, 512, 513, 514, 515, 517, 518, 600, 601, 603, 604 };
+ static const int64_t values[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 299, 300, 301, 302, 303, 304, 305, 512, 513, 514, 515, 517, 518, 600, 601, 603, 604 };
static const char * const names[] = {
"AbsVal",
"QuantizedAdd",
@@ -8251,6 +8452,7 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
"SeqLen2Spatial",
"SplitGeLU",
"GroupNorm",
+ "LinearAttention",
"Extra",
"ConvInt8",
"Int8ToFloat",
@@ -8263,7 +8465,7 @@ inline const flatbuffers::TypeTable *OpTypeTypeTable() {
"GridSample"
};
static const flatbuffers::TypeTable tt = {
- flatbuffers::ST_ENUM, 182, type_codes, type_refs, values, names
+ flatbuffers::ST_ENUM, 183, type_codes, type_refs, values, names
};
return &tt;
}
@@ -8369,7 +8571,8 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
{ flatbuffers::ET_SEQUENCE, 0, 95 },
{ flatbuffers::ET_SEQUENCE, 0, 96 },
{ flatbuffers::ET_SEQUENCE, 0, 97 },
- { flatbuffers::ET_SEQUENCE, 0, 98 }
+ { flatbuffers::ET_SEQUENCE, 0, 98 },
+ { flatbuffers::ET_SEQUENCE, 0, 99 }
};
static const flatbuffers::TypeFunction type_refs[] = {
QuantizedAddTypeTable,
@@ -8470,7 +8673,8 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
FmhaV2ParamTypeTable,
FmhcaParamTypeTable,
AttentionParamTypeTable,
- StftParamTypeTable
+ StftParamTypeTable,
+ LinearAttentionParamTypeTable
};
static const char * const names[] = {
"NONE",
@@ -8572,10 +8776,11 @@ inline const flatbuffers::TypeTable *OpParameterTypeTable() {
"FmhaV2Param",
"FmhcaParam",
"AttentionParam",
- "StftParam"
+ "StftParam",
+ "LinearAttentionParam"
};
static const flatbuffers::TypeTable tt = {
- flatbuffers::ST_UNION, 100, type_codes, type_refs, nullptr, names
+ flatbuffers::ST_UNION, 101, type_codes, type_refs, nullptr, names
};
return &tt;
}
@@ -8698,6 +8903,29 @@ inline const flatbuffers::TypeTable *AttentionParamTypeTable() {
return &tt;
}
+inline const flatbuffers::TypeTable *LinearAttentionParamTypeTable() {
+ static const flatbuffers::TypeCode type_codes[] = {
+ { flatbuffers::ET_STRING, 0, -1 },
+ { flatbuffers::ET_INT, 0, -1 },
+ { flatbuffers::ET_INT, 0, -1 },
+ { flatbuffers::ET_INT, 0, -1 },
+ { flatbuffers::ET_INT, 0, -1 },
+ { flatbuffers::ET_BOOL, 0, -1 }
+ };
+ static const char * const names[] = {
+ "attn_type",
+ "num_k_heads",
+ "num_v_heads",
+ "head_k_dim",
+ "head_v_dim",
+ "use_qk_l2norm"
+ };
+ static const flatbuffers::TypeTable tt = {
+ flatbuffers::ST_TABLE, 6, type_codes, nullptr, nullptr, names
+ };
+ return &tt;
+}
+
inline const flatbuffers::TypeTable *FmhaV2ParamTypeTable() {
static const flatbuffers::TypeCode type_codes[] = {
{ flatbuffers::ET_INT, 0, -1 }
diff --git a/schema/default/MNN.fbs b/schema/default/MNN.fbs
index eb78b1224d..c15df9deaa 100644
--- a/schema/default/MNN.fbs
+++ b/schema/default/MNN.fbs
@@ -194,6 +194,7 @@ enum OpType : int {
SeqLen2Spatial = 302,
SplitGeLU = 303,
GroupNorm = 304,
+ LinearAttention = 305,
Extra = 512,
// quantization
@@ -231,6 +232,15 @@ table AttentionParam {
kv_cache: bool = true;
}
+table LinearAttentionParam {
+ attn_type: string;
+ num_k_heads: int;
+ num_v_heads: int;
+ head_k_dim: int;
+ head_v_dim: int;
+ use_qk_l2norm: bool;
+}
+
table FmhaV2Param {
heads: int;
}
@@ -421,7 +431,8 @@ union OpParameter {
FmhaV2Param,
FmhcaParam,
AttentionParam,
- StftParam
+ StftParam,
+ LinearAttentionParam
}
table Op {
diff --git a/source/backend/arm82/Arm82Functions.cpp b/source/backend/arm82/Arm82Functions.cpp
index c9b4a3ee90..b9ec3e368f 100644
--- a/source/backend/arm82/Arm82Functions.cpp
+++ b/source/backend/arm82/Arm82Functions.cpp
@@ -443,15 +443,17 @@ void ARM82StrassenMerge(FLOAT16* c11, FLOAT16* c12, FLOAT16* c21, FLOAT16* c22,
}
void MNNUnpackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, size_t depth, int32_t* areaOffset) {
+ // [depth/8, srcAreaOffset, 8] -> [area, dstAreaOffset]
int srcAreaOffset = areaOffset[0];
+ int dstAreaOffset = areaOffset[1];
int c = (int)depth;
- int cDiv4 = c / 8;
- int cAlign = cDiv4 * 8;
+ int cDiv8 = c / 8;
+ int cAlign = cDiv8 * 8;
int areaDiv4 = area / 4;
int areaAlign = areaDiv4 * 4;
if (areaAlign > 0) {
- for (int ci = 0; ci < cDiv4; ++ci) {
+ for (int ci = 0; ci < cDiv8; ++ci) {
auto srcH = src + ci * 8 * srcAreaOffset;
auto dstH = dst + ci * 8;
for (int hi = 0; hi < areaAlign; hi+=4) {
@@ -460,10 +462,10 @@ void MNNUnpackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, si
auto src2 = srcH + hi * 8 + 16;
auto src3 = srcH + hi * 8 + 24;
- auto dst0 = dstH + hi * c;
- auto dst1 = dstH + hi * c + c;
- auto dst2 = dstH + hi * c + 2 * c;
- auto dst3 = dstH + hi * c + 3 * c;
+ auto dst0 = dstH + hi * dstAreaOffset;
+ auto dst1 = dstH + hi * dstAreaOffset + dstAreaOffset;
+ auto dst2 = dstH + hi * dstAreaOffset + 2 * dstAreaOffset;
+ auto dst3 = dstH + hi * dstAreaOffset + 3 * dstAreaOffset;
vst1q_s16(dst0, vld1q_s16(src0));
vst1q_s16(dst1, vld1q_s16(src1));
vst1q_s16(dst2, vld1q_s16(src2));
@@ -472,12 +474,12 @@ void MNNUnpackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, si
}
}
if (areaAlign < area) {
- for (int ci = 0; ci < cDiv4; ++ci) {
+ for (int ci = 0; ci < cDiv8; ++ci) {
auto srcH = src + 8 * ci * srcAreaOffset;
auto dstH = dst + ci * 8;
for (int hi = areaAlign; hi < area; ++hi) {
auto src0 = srcH + hi * 8;
- auto dst0 = dstH + hi * c;
+ auto dst0 = dstH + hi * dstAreaOffset;
vst1q_s16(dst0, vld1q_s16(src0));
}
}
@@ -492,7 +494,7 @@ void MNNUnpackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, si
for (int hi = 0; hi < area; ++hi) {
auto srcHeight = srcAlign + hi * 8;
- auto dstHeight = dstAlign + hi * c;
+ auto dstHeight = dstAlign + hi * dstAreaOffset;
for (int ci = 0; ci < cReamin; ++ci) {
dstHeight[ci] = srcHeight[ci];
@@ -934,495 +936,124 @@ static void _ArmBasicMNNPackC4ForMatMul_A_L8(int8_t* destOrigin, int8_t const**
}
}
+inline void transpose_4x4_f32(float32x4_t& r0, float32x4_t& r1, float32x4_t& r2, float32x4_t& r3) {
+ // Stage 1: Transpose 2x2 blocks of float32 elements between pairs of adjacent rows.
+ float32x4x2_t temp0 = vtrnq_f32(r0, r1);
+ float32x4x2_t temp1 = vtrnq_f32(r2, r3);
+
+ // Intermediate state:
+ // temp0.val[0] = [A0, B0, A2, B2]
+ // temp0.val[1] = [A1, B1, A3, B3]
+ // temp1.val[0] = [C0, D0, C2, D2]
+ // temp1.val[1] = [C1, D1, C3, D3]
+
+ // Stage 2: Manually swap the 64-bit blocks to finalize the transpose.
+ // This correctly simulates the non-existent 64-bit transpose/zip.
+ float64x2_t i0_f64 = vreinterpretq_f64_f32(temp0.val[0]);
+ float64x2_t i1_f64 = vreinterpretq_f64_f32(temp0.val[1]);
+ float64x2_t i2_f64 = vreinterpretq_f64_f32(temp1.val[0]);
+ float64x2_t i3_f64 = vreinterpretq_f64_f32(temp1.val[1]);
+
+ // Combine the low 64 bits of i0 and i2 to form the first part of the result.
+ float32x4_t t0 = vreinterpretq_f32_f64(vcombine_f64(vget_low_f64(i0_f64), vget_low_f64(i2_f64)));
+ // Combine the low 64 bits of i1 and i3 for the second part.
+ float32x4_t t1 = vreinterpretq_f32_f64(vcombine_f64(vget_low_f64(i1_f64), vget_low_f64(i3_f64)));
+ // Combine the high 64 bits of i0 and i2 for the third part.
+ float32x4_t t2 = vreinterpretq_f32_f64(vcombine_f64(vget_high_f64(i0_f64), vget_high_f64(i2_f64)));
+ // Combine the high 64 bits of i1 and i3 for the final part.
+ float32x4_t t3 = vreinterpretq_f32_f64(vcombine_f64(vget_high_f64(i1_f64), vget_high_f64(i3_f64)));
+
+ r0 = t0;
+ r1 = t1;
+ r2 = t2;
+ r3 = t3;
+}
+
static void Sme2MNNPackC4ForMatMul_A_FP16(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) {
- int LP = FP16_SME2_MATMUL_LP;
- int pack = 8;
- // LP >= pack
+ const int lP = FP16_SME2_MATMUL_LP;
+ const int pack = 8;
int number = info[0];
int eReal = info[1];
int eDest = info[2];
int offset = info[3];
- for (int n=0; n 7) {
- auto source = sourceN;
- auto dest = destN;
- l -= 8;
- auto e = eWork;
- if (e == eDest) {
- auto s0 = vld1q_f32((float*)(source)); // 00112233
- auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233
- auto s2 = vld1q_f32((float*)(source + 2 * srcStride0));
- auto s3 = vld1q_f32((float*)(source + 3 * srcStride0));
-
- auto s4 = vld1q_f32((float*)(source + 4 * srcStride0));
- auto s5 = vld1q_f32((float*)(source + 5 * srcStride0));
- auto s6 = vld1q_f32((float*)(source + 6 * srcStride0));
- auto s7 = vld1q_f32((float*)(source + 7 * srcStride0));
-
- auto s8 = vld1q_f32((float*)(source + 8 * srcStride0));
- auto s9 = vld1q_f32((float*)(source + 9 * srcStride0));
- auto s10 = vld1q_f32((float*)(source + 10 * srcStride0));
- auto s11 = vld1q_f32((float*)(source + 11 * srcStride0));
-
- auto s12 = vld1q_f32((float*)(source + 12 * srcStride0));
- auto s13 = vld1q_f32((float*)(source + 13 * srcStride0));
- auto s14 = vld1q_f32((float*)(source + 14 * srcStride0));
- auto s15 = vld1q_f32((float*)(source + 15 * srcStride0));
-
- auto zip1s01 = vzip1q_f32(s0, s1); // 00001111
- auto zip1s23 = vzip1q_f32(s2, s3); // 00001111
- auto zip1s45 = vzip1q_f32(s4, s5); // 00001111
- auto zip1s67 = vzip1q_f32(s6, s7); // 00001111
- auto zip1s89 = vzip1q_f32(s8, s9); // 00001111
- auto zip1s1011 = vzip1q_f32(s10, s11); // 00001111
- auto zip1s1213 = vzip1q_f32(s12, s13); // 00001111
- auto zip1s1415 = vzip1q_f32(s14, s15); // 00001111
-
- auto zip2s01 = vzip2q_f32(s0, s1); // 22223333
- auto zip2s23 = vzip2q_f32(s2, s3); // 22223333
- auto zip2s45 = vzip2q_f32(s4, s5); // 22223333
- auto zip2s67 = vzip2q_f32(s6, s7); // 22223333
- auto zip2s89 = vzip2q_f32(s8, s9); // 22223333
- auto zip2s1011 = vzip2q_f32(s10, s11); // 22223333
- auto zip2s1213 = vzip2q_f32(s12, s13); // 22223333
- auto zip2s1415 = vzip2q_f32(s14, s15); // 22223333
-
- auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000
- auto zip1s4567_01 = vzip1q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
- auto zip1s891011_01 = vzip1q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011);
- auto zip1s12131415_01 = vzip1q_f64((float64x2_t)zip1s1213, (float64x2_t)zip1s1415);
-
- auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111
- auto zip2s4567_01 = vzip2q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
- auto zip2s891011_01 = vzip2q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011);
- auto zip2s12131415_01 = vzip2q_f64((float64x2_t)zip1s1213, (float64x2_t)zip1s1415);
-
- auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222
- auto zip1s4567_23 = vzip1q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
- auto zip1s891011_23 = vzip1q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011);
- auto zip1s12131415_23 = vzip1q_f64((float64x2_t)zip2s1213, (float64x2_t)zip2s1415);
-
- auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333
- auto zip2s4567_23 = vzip2q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
- auto zip2s891011_23 = vzip2q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011);
- auto zip2s12131415_23 = vzip2q_f64((float64x2_t)zip2s1213, (float64x2_t)zip2s1415);
-
- vst1q_f64((float64_t*)dest, zip1s0123_01);
- vst1q_f64((float64_t*)(dest + 8), zip1s4567_01);
- vst1q_f64((float64_t*)(dest + 16), zip1s891011_01);
- vst1q_f64((float64_t*)(dest + 24), zip1s12131415_01);
-
- vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01);
- vst1q_f64((float64_t*)(dest + dstStride0 + 8), zip2s4567_01);
- vst1q_f64((float64_t*)(dest + dstStride0 + 16), zip2s891011_01);
- vst1q_f64((float64_t*)(dest + dstStride0 + 24), zip2s12131415_01);
-
- vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23);
- vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 8), zip1s4567_23);
- vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 16), zip1s891011_23);
- vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 24), zip1s12131415_23);
-
- vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23);
- vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 8), zip2s4567_23);
- vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 16), zip2s891011_23);
- vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 24), zip2s12131415_23);
-
- // dest += (4 * dstStride0);
- // e -= eDest;
- sourceN += (eReal * pack);
- destN += (4 * dstStride0);
- continue;
- }
-
- if (e > 11) {
- auto s0 = vld1q_f32((float*)(source)); // 00112233
- auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233
- auto s2 = vld1q_f32((float*)(source + 2 * srcStride0));
- auto s3 = vld1q_f32((float*)(source + 3 * srcStride0));
- auto s4 = vld1q_f32((float*)(source + 4 * srcStride0));
- auto s5 = vld1q_f32((float*)(source + 5 * srcStride0));
- auto s6 = vld1q_f32((float*)(source + 6 * srcStride0));
- auto s7 = vld1q_f32((float*)(source + 7 * srcStride0));
+ float32x4_t v0, v1, v2, v3, v4, v5, v6, v7;
- auto s8 = vld1q_f32((float*)(source + 8 * srcStride0));
- auto s9 = vld1q_f32((float*)(source + 9 * srcStride0));
- auto s10 = vld1q_f32((float*)(source + 10 * srcStride0));
- auto s11 = vld1q_f32((float*)(source + 11 * srcStride0));
-
- auto zip1s01 = vzip1q_f32(s0, s1); // 00001111
- auto zip1s23 = vzip1q_f32(s2, s3); // 00001111
- auto zip1s45 = vzip1q_f32(s4, s5); // 00001111
- auto zip1s67 = vzip1q_f32(s6, s7); // 00001111
- auto zip1s89 = vzip1q_f32(s8, s9); // 00001111
- auto zip1s1011 = vzip1q_f32(s10, s11); // 00001111
-
- auto zip2s01 = vzip2q_f32(s0, s1); // 22223333
- auto zip2s23 = vzip2q_f32(s2, s3); // 22223333
- auto zip2s45 = vzip2q_f32(s4, s5); // 22223333
- auto zip2s67 = vzip2q_f32(s6, s7); // 22223333
- auto zip2s89 = vzip2q_f32(s8, s9); // 22223333
- auto zip2s1011 = vzip2q_f32(s10, s11); // 22223333
-
- auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000
- auto zip1s4567_01 = vzip1q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
- auto zip1s891011_01 = vzip1q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011);
-
- auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111
- auto zip2s4567_01 = vzip2q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
- auto zip2s891011_01 = vzip2q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011);
-
- auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222
- auto zip1s4567_23 = vzip1q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
- auto zip1s891011_23 = vzip1q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011);
-
- auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333
- auto zip2s4567_23 = vzip2q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
- auto zip2s891011_23 = vzip2q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011);
-
- vst1q_f64((float64_t*)dest, zip1s0123_01);
- vst1q_f64((float64_t*)(dest + 8), zip1s4567_01);
- vst1q_f64((float64_t*)(dest + 16), zip1s891011_01);
-
- vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01);
- vst1q_f64((float64_t*)(dest + dstStride0 + 8), zip2s4567_01);
- vst1q_f64((float64_t*)(dest + dstStride0 + 16), zip2s891011_01);
-
- vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23);
- vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 8), zip1s4567_23);
- vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 16), zip1s891011_23);
-
- vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23);
- vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 8), zip2s4567_23);
- vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 16), zip2s891011_23);
+ for (int n = 0; n < number; ++n) {
+ int e = el[4 * n + 0];
+ int l = el[4 * n + 1];
+ int eOffset = el[4 * n + 2];
+ int lOffset = el[4 * n + 3];
- dest += 24;
- e -= 12;
- source += (12 * srcStride0);
+ auto destBase = (FLOAT16*)destOrigin + lOffset * eDest + eOffset * lP;
+ auto sourceBase = (const FLOAT16*)(sourceGroup[n]);
+
+ const int eTile = 8;
+ const int lTile = 8;
+
+ const int eMain = e / eTile;
+ const int lMain = l / lTile;
+
+ const size_t srcRowStride = (size_t)pack * offset;
+ const size_t srcColBlockStride = (size_t)eReal * pack;
+ const size_t dstColBlockStride = (size_t)eDest * lP;
+
+ for (int y0 = 0; y0 < eMain; ++y0) {
+ const int yBase = y0 * eTile;
+ for (int x0 = 0; x0 < lMain; ++x0) {
+ const int xBase = x0 * lTile;
+
+ const auto srcBlockBase = sourceBase + yBase * srcRowStride + x0 * srcColBlockStride;
+
+ v0 = vld1q_f32((const float*)(srcBlockBase + 0 * srcRowStride));
+ v1 = vld1q_f32((const float*)(srcBlockBase + 1 * srcRowStride));
+ v2 = vld1q_f32((const float*)(srcBlockBase + 2 * srcRowStride));
+ v3 = vld1q_f32((const float*)(srcBlockBase + 3 * srcRowStride));
+ v4 = vld1q_f32((const float*)(srcBlockBase + 4 * srcRowStride));
+ v5 = vld1q_f32((const float*)(srcBlockBase + 5 * srcRowStride));
+ v6 = vld1q_f32((const float *)(srcBlockBase + 6 * srcRowStride));
+ v7 = vld1q_f32((const float *)(srcBlockBase + 7 * srcRowStride));
+
+ transpose_4x4_f32(v0, v1, v2, v3);
+ transpose_4x4_f32(v4, v5, v6, v7);
+
+ float* addr0 = (float*)(destBase + yBase * lP + (xBase / lP) * dstColBlockStride);
+ float* addr1= (float*)(destBase + yBase * lP + (xBase / lP + 1) * dstColBlockStride);
+ float* addr2= (float*)(destBase + yBase * lP + (xBase / lP + 2) * dstColBlockStride);
+ float* addr3= (float*)(destBase + yBase * lP + (xBase / lP + 3) * dstColBlockStride);
+
+ vst1q_f32(addr0, v0);
+ vst1q_f32(addr0 + 4, v4);
+ vst1q_f32(addr1, v1);
+ vst1q_f32(addr1 + 4, v5);
+ vst1q_f32(addr2, v2);
+ vst1q_f32(addr2 + 4, v6);
+ vst1q_f32(addr3, v3);
+ vst1q_f32(addr3 + 4, v7);
}
+ }
- if (e > 7) {
- auto s0 = vld1q_f32((float*)(source)); // 00112233
- auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233
- auto s2 = vld1q_f32((float*)(source + 2 * srcStride0));
- auto s3 = vld1q_f32((float*)(source + 3 * srcStride0));
-
- auto s4 = vld1q_f32((float*)(source + 4 * srcStride0));
- auto s5 = vld1q_f32((float*)(source + 5 * srcStride0));
- auto s6 = vld1q_f32((float*)(source + 6 * srcStride0));
- auto s7 = vld1q_f32((float*)(source + 7 * srcStride0));
-
- auto zip1s01 = vzip1q_f32(s0, s1); // 00001111
- auto zip1s23 = vzip1q_f32(s2, s3); // 00001111
- auto zip1s45 = vzip1q_f32(s4, s5); // 00001111
- auto zip1s67 = vzip1q_f32(s6, s7); // 00001111
-
- auto zip2s01 = vzip2q_f32(s0, s1); // 22223333
- auto zip2s23 = vzip2q_f32(s2, s3); // 22223333
- auto zip2s45 = vzip2q_f32(s4, s5); // 22223333
- auto zip2s67 = vzip2q_f32(s6, s7); // 22223333
-
- auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000
- auto zip1s4567_01 = vzip1q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
-
- auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111
- auto zip2s4567_01 = vzip2q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
-
- auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222
- auto zip1s4567_23 = vzip1q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
+ const int eHandled = eMain * eTile;
+ const int lHandled = lMain * lTile;
- auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333
- auto zip2s4567_23 = vzip2q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
-
- vst1q_f64((float64_t*)dest, zip1s0123_01);
- vst1q_f64((float64_t*)(dest + 8), zip1s4567_01);
-
- vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01);
- vst1q_f64((float64_t*)(dest + dstStride0 + 8), zip2s4567_01);
-
- vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23);
- vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 8), zip1s4567_23);
-
- vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23);
- vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 8), zip2s4567_23);
-
- dest += 16;
- e -= 8;
- source += (8 * srcStride0);
- }
-
- if (e > 3) {
- auto s0 = vld1q_f32((float*)(source)); // 00112233
- auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233
- auto s2 = vld1q_f32((float*)(source + 2 * srcStride0));
- auto s3 = vld1q_f32((float*)(source + 3 * srcStride0));
-
- auto zip1s01 = vzip1q_f32(s0, s1); // 00001111
- auto zip1s23 = vzip1q_f32(s2, s3); // 00001111
-
- auto zip2s01 = vzip2q_f32(s0, s1); // 22223333
- auto zip2s23 = vzip2q_f32(s2, s3); // 22223333
-
- auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000
-
- auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111
-
- auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222
-
- auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333
-
- vst1q_f64((float64_t*)dest, zip1s0123_01);
- vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01);
- vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23);
- vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23);
-
- dest += 8;
- e -= 4;
- source += (4 * srcStride0);
- }
- while (e > 0) {
- auto s0 = vld1q_f32((float*)(source)); // 00112233
-
- ((float*)dest)[0] = s0[0];
- ((float*)(dest + dstStride0))[0] = s0[1];
- ((float*)(dest + 2 * dstStride0))[0] = s0[2];
- ((float*)(dest + 3 * dstStride0))[0] = s0[3];
-
- dest += 2;
- e -= 1;
- source += srcStride0;
- }
- sourceN += (eReal * pack);
- destN += (4 * dstStride0);
- } // l>7
-
- if (l > 3) {
- auto source = sourceN;
- auto dest = destN;
- l -= 4;
- auto e = eWork;
- if (e == eDest) {
- auto s0 = vld1_f32((float*)(source)); // 0011
- auto s1 = vld1_f32((float*)(source + srcStride0));// 0011
- auto s2 = vld1_f32((float*)(source + 2 * srcStride0));
- auto s3 = vld1_f32((float*)(source + 3 * srcStride0));
-
- auto s4 = vld1_f32((float*)(source + 4 * srcStride0));
- auto s5 = vld1_f32((float*)(source + 5 * srcStride0));
- auto s6 = vld1_f32((float*)(source + 6 * srcStride0));
- auto s7 = vld1_f32((float*)(source + 7 * srcStride0));
-
- auto s8 = vld1_f32((float*)(source + 8 * srcStride0));
- auto s9 = vld1_f32((float*)(source + 9 * srcStride0));
- auto s10 = vld1_f32((float*)(source + 10 * srcStride0));
- auto s11 = vld1_f32((float*)(source + 11 * srcStride0));
-
- auto s12 = vld1_f32((float*)(source + 12 * srcStride0));
- auto s13 = vld1_f32((float*)(source + 13 * srcStride0));
- auto s14 = vld1_f32((float*)(source + 14 * srcStride0));
- auto s15 = vld1_f32((float*)(source + 15 * srcStride0));
-
- auto zip1s01 = vzip1_f32(s0, s1); // 0000
- auto zip1s23 = vzip1_f32(s2, s3); // 0000
- auto zip1s45 = vzip1_f32(s4, s5); // 0000
- auto zip1s67 = vzip1_f32(s6, s7); // 0000
- auto zip1s89 = vzip1_f32(s8, s9); // 0000
- auto zip1s1011 = vzip1_f32(s10, s11); // 0000
- auto zip1s1213 = vzip1_f32(s12, s13); // 0000
- auto zip1s1415 = vzip1_f32(s14, s15); // 0000
-
- auto zip2s01 = vzip2_f32(s0, s1); // 1111
- auto zip2s23 = vzip2_f32(s2, s3); // 1111
- auto zip2s45 = vzip2_f32(s4, s5); // 1111
- auto zip2s67 = vzip2_f32(s6, s7); // 1111
- auto zip2s89 = vzip2_f32(s8, s9); // 1111
- auto zip2s1011 = vzip2_f32(s10, s11); // 1111
- auto zip2s1213 = vzip2_f32(s12, s13); // 1111
- auto zip2s1415 = vzip2_f32(s14, s15); // 1111
-
- vst1_f32((float32_t*)dest, zip1s01);
- vst1_f32((float32_t*)(dest + 4), zip1s23);
- vst1_f32((float32_t*)(dest + 8), zip1s45);
- vst1_f32((float32_t*)(dest + 12), zip1s67);
- vst1_f32((float32_t*)(dest + 16), zip1s89);
- vst1_f32((float32_t*)(dest + 20), zip1s1011);
- vst1_f32((float32_t*)(dest + 24), zip1s1213);
- vst1_f32((float32_t*)(dest + 28), zip1s1415);
-
- vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
- vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23);
- vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45);
- vst1_f32((float32_t*)(dest + dstStride0 + 12), zip2s67);
- vst1_f32((float32_t*)(dest + dstStride0 + 16), zip2s89);
- vst1_f32((float32_t*)(dest + dstStride0 + 20), zip2s1011);
- vst1_f32((float32_t*)(dest + dstStride0 + 24), zip2s1213);
- vst1_f32((float32_t*)(dest + dstStride0 + 28), zip2s1415);
-
-
- dest += 32;
- e -= eDest;
- }
-
- if (e > 11) {
- auto s0 = vld1_f32((float*)(source)); // 0011
- auto s1 = vld1_f32((float*)(source + srcStride0));// 0011
- auto s2 = vld1_f32((float*)(source + 2 * srcStride0));
- auto s3 = vld1_f32((float*)(source + 3 * srcStride0));
-
- auto s4 = vld1_f32((float*)(source + 4 * srcStride0));
- auto s5 = vld1_f32((float*)(source + 5 * srcStride0));
- auto s6 = vld1_f32((float*)(source + 6 * srcStride0));
- auto s7 = vld1_f32((float*)(source + 7 * srcStride0));
-
- auto s8 = vld1_f32((float*)(source + 8 * srcStride0));
- auto s9 = vld1_f32((float*)(source + 9 * srcStride0));
- auto s10 = vld1_f32((float*)(source + 10 * srcStride0));
- auto s11 = vld1_f32((float*)(source + 11 * srcStride0));
-
- auto zip1s01 = vzip1_f32(s0, s1); // 0000
- auto zip1s23 = vzip1_f32(s2, s3); // 0000
- auto zip1s45 = vzip1_f32(s4, s5); // 0000
- auto zip1s67 = vzip1_f32(s6, s7); // 0000
- auto zip1s89 = vzip1_f32(s8, s9); // 0000
- auto zip1s1011 = vzip1_f32(s10, s11); // 0000
-
- auto zip2s01 = vzip2_f32(s0, s1); // 1111
- auto zip2s23 = vzip2_f32(s2, s3); // 1111
- auto zip2s45 = vzip2_f32(s4, s5); // 1111
- auto zip2s67 = vzip2_f32(s6, s7); // 1111
- auto zip2s89 = vzip2_f32(s8, s9); // 1111
- auto zip2s1011 = vzip2_f32(s10, s11); // 1111
-
- vst1_f32((float32_t*)dest, zip1s01);
- vst1_f32((float32_t*)(dest + 4), zip1s23);
- vst1_f32((float32_t*)(dest + 8), zip1s45);
- vst1_f32((float32_t*)(dest + 12), zip1s67);
- vst1_f32((float32_t*)(dest + 16), zip1s89);
- vst1_f32((float32_t*)(dest + 20), zip1s1011);
-
- vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
- vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23);
- vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45);
- vst1_f32((float32_t*)(dest + dstStride0 + 12), zip2s67);
- vst1_f32((float32_t*)(dest + dstStride0 + 16), zip2s89);
- vst1_f32((float32_t*)(dest + dstStride0 + 20), zip2s1011);
-
- dest += 24;
- e -= 12;
- source += (12 * srcStride0);
- }
-
- if (e > 7) {
- auto s0 = vld1_f32((float*)(source)); // 0011
- auto s1 = vld1_f32((float*)(source + srcStride0));// 0011
- auto s2 = vld1_f32((float*)(source + 2 * srcStride0));
- auto s3 = vld1_f32((float*)(source + 3 * srcStride0));
-
- auto s4 = vld1_f32((float*)(source + 4 * srcStride0));
- auto s5 = vld1_f32((float*)(source + 5 * srcStride0));
- auto s6 = vld1_f32((float*)(source + 6 * srcStride0));
- auto s7 = vld1_f32((float*)(source + 7 * srcStride0));
-
- auto zip1s01 = vzip1_f32(s0, s1); // 0000
- auto zip1s23 = vzip1_f32(s2, s3); // 0000
- auto zip1s45 = vzip1_f32(s4, s5); // 0000
- auto zip1s67 = vzip1_f32(s6, s7); // 0000
-
- auto zip2s01 = vzip2_f32(s0, s1); // 1111
- auto zip2s23 = vzip2_f32(s2, s3); // 1111
- auto zip2s45 = vzip2_f32(s4, s5); // 1111
- auto zip2s67 = vzip2_f32(s6, s7); // 1111
-
- vst1_f32((float32_t*)dest, zip1s01);
- vst1_f32((float32_t*)(dest + 4), zip1s23);
- vst1_f32((float32_t*)(dest + 8), zip1s45);
- vst1_f32((float32_t*)(dest + 12), zip1s67);
-
- vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
- vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23);
- vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45);
- vst1_f32((float32_t*)(dest + dstStride0 + 12), zip2s67);
-
- dest += 16;
- e -= 8;
- source += (8 * srcStride0);
- }
-
- if (e > 3) {
- auto s0 = vld1_f32((float*)(source)); // 0011
- auto s1 = vld1_f32((float*)(source + srcStride0));// 0011
- auto s2 = vld1_f32((float*)(source + 2 * srcStride0));
- auto s3 = vld1_f32((float*)(source + 3 * srcStride0));
-
- auto zip1s01 = vzip1_f32(s0, s1); // 0000
- auto zip1s23 = vzip1_f32(s2, s3); // 0000
-
- auto zip2s01 = vzip2_f32(s0, s1); // 1111
- auto zip2s23 = vzip2_f32(s2, s3); // 1111
-
- vst1_f32((float32_t*)dest, zip1s01);
- vst1_f32((float32_t*)(dest + 4), zip1s23);
-
- vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
- vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23);
-
- dest += 8;
- e -= 4;
- source += (4 * srcStride0);
- }
- if (e > 1) {
- auto s0 = vld1_f32((float*)(source)); // 0011
- auto s1 = vld1_f32((float*)(source + srcStride0));// 0011
-
- auto zip1s01 = vzip1_f32(s0, s1); // 0000
-
- auto zip2s01 = vzip2_f32(s0, s1); // 1111
-
- vst1_f32((float32_t*)dest, zip1s01);
-
- vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
-
- dest += 4;
- e -= 2;
- source += (2 * srcStride0);
- }
- if (e > 0) {
- auto s0 = vld1_f32((float*)(source)); // 0011
-
- ((float*)dest)[0] = s0[0];
- ((float*)(dest + dstStride0))[0] = s0[1];
+ // Process remaining rows
+ for (int y = eHandled; y < e; ++y) {
+ int yR = y % eDest;
+ for (int x = 0; x < l; ++x) {
+ int xR = x % pack;
+ int xC = x / pack;
+ destBase[(x / lP) * dstColBlockStride + yR * lP + (x % lP)] = sourceBase[xC * srcColBlockStride + y * srcRowStride + xR];
}
- sourceN += 4;
- destN += (2 * dstStride0);
- }
+ }
- auto source = (FLOAT16*)(sourceGroup[n]);
- auto dest = (FLOAT16*)destOrigin + lOffset * eDest + eOffset * LP;
- if (l > 0) {
- auto e = eWork;
- auto lRemain = lWork - l;
- // if e < eDest, packed A -> [LU, eDest, LP] eDest=eP
- for (int y=0; y 1) {
+ for (int s = 0; s < seqLen; ++s) {
+ const FLOAT16* keySrc = sourceFp16 + s * kvNumHead * headDim + kvHeadIdx * headDim;
+ int d = 0;
+ for (; d <= headDim - 16; d += 16) {
+ float16x8_t maxVec0 = vld1q_f16(maxKeyFp16 + d);
+ float16x8_t maxVec1 = vld1q_f16(maxKeyFp16 + d + 8);
+ float16x8_t srcVec0 = vld1q_f16(keySrc + d);
+ float16x8_t srcVec1 = vld1q_f16(keySrc + d + 8);
+ maxVec0 = vmaxq_f16(maxVec0, srcVec0);
+ maxVec1 = vmaxq_f16(maxVec1, srcVec1);
+ vst1q_f16(maxKeyFp16 + d, maxVec0);
+ vst1q_f16(maxKeyFp16 + d + 8, maxVec1);
+ }
+ for (; d <= headDim - 8; d += 8) {
+ float16x8_t maxVec = vld1q_f16(maxKeyFp16 + d);
+ float16x8_t srcVec = vld1q_f16(keySrc + d);
+ maxVec = vmaxq_f16(maxVec, srcVec);
+ vst1q_f16(maxKeyFp16 + d, maxVec);
+ }
+ for (; d < headDim; ++d) {
+ maxKeyFp16[d] = ALIMAX(maxKeyFp16[d], keySrc[d]);
+ }
+ }
}
- auto dstStride1 = eP;
- auto dstStride0 = planesize * dstStride1;
+ // Quant fp16
+ for (int s = 0; s < seqLen; s++) {
+ const FLOAT16* keySrc = sourceFp16 + s * kvNumHead * headDim + kvHeadIdx * headDim;
- for (int i = 0; i < depth; ++i) {
- size_t realsize = planesize;
- const float* srcPtr = src + i * planesize;
- FLOAT16* dstPtr = (FLOAT16*)dst + (i % eP) + (i / eP) * dstStride0;
+ float16x8_t minVec = vdupq_n_f16(keySrc[0]);
+ float16x8_t maxVec = vdupq_n_f16(keySrc[0]);
- while (realsize >= 16) {
- float32x4_t s0_f32 = vld1q_f32(srcPtr);
- float32x4_t s1_f32 = vld1q_f32(srcPtr + 4);
- float32x4_t s2_f32 = vld1q_f32(srcPtr + 8);
- float32x4_t s3_f32 = vld1q_f32(srcPtr + 12);
+ int d = 0;
+ for (; d <= headDim - 8; d += 8) {
+ float16x8_t srcVec = vld1q_f16(keySrc + d);
+ float16x8_t maxKeyVec = vld1q_f16(maxKeyFp16 + d);
+ float16x8_t keyDataF16 = vsubq_f16(srcVec, maxKeyVec);
- float16x4_t d0_f16 = vcvt_f16_f32(s0_f32);
- float16x4_t d1_f16 = vcvt_f16_f32(s1_f32);
- float16x4_t d2_f16 = vcvt_f16_f32(s2_f32);
- float16x4_t d3_f16 = vcvt_f16_f32(s3_f32);
+ minVec = vminq_f16(minVec, keyDataF16);
+ maxVec = vmaxq_f16(maxVec, keyDataF16);
- vst1_lane_f16(dstPtr, d0_f16, 0);
- vst1_lane_f16(dstPtr + dstStride1, d0_f16, 1);
- vst1_lane_f16(dstPtr + 2 * dstStride1, d0_f16, 2);
- vst1_lane_f16(dstPtr + 3 * dstStride1, d0_f16, 3);
+ float32x4_t keyDataF32Low = vcvt_f32_f16(vget_low_f16(keyDataF16));
+ float32x4_t keyDataF32High = vcvt_f32_f16(vget_high_f16(keyDataF16));
+ }
- vst1_lane_f16(dstPtr + 4 * dstStride1, d1_f16, 0);
- vst1_lane_f16(dstPtr + 5 * dstStride1, d1_f16, 1);
- vst1_lane_f16(dstPtr + 6 * dstStride1, d1_f16, 2);
- vst1_lane_f16(dstPtr + 7 * dstStride1, d1_f16, 3);
+ FLOAT16 minKey = vminvq_f16(minVec);
+ FLOAT16 maxKey = vmaxvq_f16(maxVec);
- vst1_lane_f16(dstPtr + 8 * dstStride1, d2_f16, 0);
- vst1_lane_f16(dstPtr + 9 * dstStride1, d2_f16, 1);
- vst1_lane_f16(dstPtr + 10 * dstStride1, d2_f16, 2);
- vst1_lane_f16(dstPtr + 11 * dstStride1, d2_f16, 3);
+ for (; d < headDim; ++d) {
+ auto keydata = keySrc[d] - maxKeyFp16[d];
+ minKey = ALIMIN(minKey, keydata);
+ maxKey = ALIMAX(maxKey, keydata);
+ }
- vst1_lane_f16(dstPtr + 12 * dstStride1, d3_f16, 0);
- vst1_lane_f16(dstPtr + 13 * dstStride1, d3_f16, 1);
- vst1_lane_f16(dstPtr + 14 * dstStride1, d3_f16, 2);
- vst1_lane_f16(dstPtr + 15 * dstStride1, d3_f16, 3);
+ int outIndex = (pastLength + s) / hP;
+ int inIndex = (pastLength + s) % hP;
- srcPtr += 16;
- dstPtr += 16 * dstStride1;
- realsize -= 16;
+ float range = (float)maxKey - (float)minKey;
+ float quantScaleVal = 0;
+ float biasVal = minKey + 128.0f * range / 255.0;
+ if (range <= 1e-6f) {
+ quantScaleVal = 0.f;
+ } else {
+ quantScaleVal = 255.0f / range;
}
- if (realsize >= 8) {
- float32x4_t s0_f32 = vld1q_f32(srcPtr);
- float32x4_t s1_f32 = vld1q_f32(srcPtr + 4);
+ for (int k = 0; k < blockNum; ++k) {
+ int8_t* weightDstBase = dst + outIndex * blockNum * packedWeightStride1 + k * packedWeightStride1;
+ float* scaleDst = (float*)(weightDstBase + weightStride1);
+ float* biasDst = scaleDst + hP;
- float16x4_t d0_f16 = vcvt_f16_f32(s0_f32);
- float16x4_t d1_f16 = vcvt_f16_f32(s1_f32);
+ scaleDst[inIndex] = range / 255.f;
+ biasDst[inIndex] = biasVal;
- vst1_lane_f16(dstPtr, d0_f16, 0);
- vst1_lane_f16(dstPtr + dstStride1, d0_f16, 1);
- vst1_lane_f16(dstPtr + 2 * dstStride1, d0_f16, 2);
- vst1_lane_f16(dstPtr + 3 * dstStride1, d0_f16, 3);
+ float32x4_t scaleVecFp32 = vdupq_n_f32(quantScaleVal);
+ float32x4_t negMinKeyVecF32 = vdupq_n_f32(-(float)minKey);
- vst1_lane_f16(dstPtr + 4 * dstStride1, d1_f16, 0);
- vst1_lane_f16(dstPtr + 5 * dstStride1, d1_f16, 1);
- vst1_lane_f16(dstPtr + 6 * dstStride1, d1_f16, 2);
- vst1_lane_f16(dstPtr + 7 * dstStride1, d1_f16, 3);
+ const FLOAT16* currentKeyBlock = keySrc + k * blockHeadDim;
+ const FLOAT16* currentMaxBlock = maxKeyFp16 + k * blockHeadDim;
- srcPtr += 8;
- dstPtr += 8 * dstStride1;
- realsize -= 8;
- }
+ int32x4_t sumInt32_0 = vdupq_n_s32(0);
+ int32x4_t sumInt32_1 = vdupq_n_s32(0);
+ int headDimIdx = 0;
+ for (; headDimIdx <= blockHeadDim - 8; headDimIdx += 8) {
+ float16x8_t srcVecFp16 = vld1q_f16(currentKeyBlock + headDimIdx);
+ float16x8_t maxVecFp16 = vld1q_f16(currentMaxBlock + headDimIdx);
- if (realsize >= 4) {
- float32x4_t s0_f32 = vld1q_f32(srcPtr);
- float16x4_t d0_f16 = vcvt_f16_f32(s0_f32);
+ float16x8_t keyDataF16 = vsubq_f16(srcVecFp16, maxVecFp16);
- vst1_lane_f16(dstPtr, d0_f16, 0);
- vst1_lane_f16(dstPtr + dstStride1, d0_f16, 1);
- vst1_lane_f16(dstPtr + 2 * dstStride1, d0_f16, 2);
- vst1_lane_f16(dstPtr + 3 * dstStride1, d0_f16, 3);
+ float32x4_t keyDataLowFp32 = vcvt_f32_f16(vget_low_f16(keyDataF16));
+ float32x4_t keyDataHighFp32 = vcvt_f32_f16(vget_high_f16(keyDataF16));
- srcPtr += 4;
- dstPtr += 4 * dstStride1;
- realsize -= 4;
- }
+ keyDataLowFp32 = vaddq_f32(keyDataLowFp32, negMinKeyVecF32);
+ keyDataHighFp32 = vaddq_f32(keyDataHighFp32, negMinKeyVecF32);
+
+ keyDataLowFp32 = vmulq_f32(keyDataLowFp32, scaleVecFp32);
+ keyDataHighFp32 = vmulq_f32(keyDataHighFp32, scaleVecFp32);
- for (; realsize > 0; --realsize) {
- *dstPtr = (FLOAT16)(*srcPtr);
- srcPtr++;
- dstPtr += dstStride1;
+ keyDataLowFp32 = vaddq_f32(keyDataLowFp32, neg128Vec);
+ keyDataHighFp32 = vaddq_f32(keyDataHighFp32, neg128Vec);
+
+ int32x4_t keyDataLowInt32 = vcvtaq_s32_f32(keyDataLowFp32);
+ int32x4_t keyDataHighInt32 = vcvtaq_s32_f32(keyDataHighFp32);
+
+ int16x4_t s16Low = vmovn_s32(keyDataLowInt32);
+ int16x4_t s16High = vmovn_s32(keyDataHighInt32);
+
+ int16x8_t s16Combined = vcombine_s16(s16Low, s16High);
+
+ // sum
+ sumInt32_0 = vaddq_s32(sumInt32_0, keyDataLowInt32);
+ sumInt32_1 = vaddq_s32(sumInt32_1, keyDataHighInt32);
+
+ int8x8_t s8Vec = vqmovn_s16(s16Combined);
+
+ if (lP == 8) {
+ int i = headDimIdx / lP;
+ int8_t* dstPtr = weightDstBase + i * weightStride2 + inIndex * lP;
+ vst1_s8(dstPtr, s8Vec);
+ } else if (lP == 4) {
+ vst1_s8(tempBuffer, s8Vec);
+ int iLow = headDimIdx / lP;
+ int iHigh = (headDimIdx + 4) / lP;
+
+ int8_t* dstPtrLow = weightDstBase + iLow * weightStride2 + inIndex * lP;
+ int8_t* dstPtrHigh = weightDstBase + iHigh * weightStride2 + inIndex * lP;
+
+ std::memcpy(dstPtrLow, tempBuffer, 4);
+ std::memcpy(dstPtrHigh, tempBuffer + 4, 4);
+ } else {
+ vst1_s8(tempBuffer, s8Vec);
+ for (int nk = 0; nk < 8; ++nk) {
+ int headDimCurr = headDimIdx + nk;
+ int i = headDimCurr / lP;
+ int j = headDimCurr % lP;
+ weightDstBase[i * weightStride2 + inIndex * lP + j] = tempBuffer[nk];
+ }
+ }
+
+ }
+
+ int32_t sumInt32 = vaddvq_s32(sumInt32_0) + vaddvq_s32(sumInt32_1);
+
+ for (; headDimIdx < blockHeadDim; ++headDimIdx) {
+ int i = headDimIdx / lP;
+ int j = headDimIdx % lP;
+ float keyVal = (float)currentKeyBlock[headDimIdx] - (float)currentMaxBlock[headDimIdx];
+ float quantVal = (keyVal - minKey) * quantScaleVal - 128.0f;
+ int32_t roundedVal = static_cast(roundf(quantVal));
+ int8_t finalVal = static_cast(std::max(-128, std::min(127, roundedVal)));
+ weightDstBase[i * weightStride2 + inIndex * lP + j] = finalVal;
+ sumInt32 += finalVal;
+ }
+
+ // store sum
+ sumKeyPtr[outIndex * hP + inIndex] = sumInt32 * range / 255.f + (minKey * (float)(blockHeadDim) + 128.0f * range * (float)(blockHeadDim) / 255.0);
}
}
}
-static void MNNAttenPackAndConvertFp32(float* dst, float* src, const int32_t* units, size_t depth, size_t planesize) {
- int32_t eP = units[0];
- int32_t lP = units[1]; // Now lP=1 or 2
+static void MNNQuantAttentionValueFP16(int8_t* dst, const float* source, float* valueSum, int32_t* params) {
+ // float value src : [kvSeq,kvNumHead,headDim]
+ // int8_t value dest: [updiv(maxLength,flashAttentionBlockKv), updiv(headDim,hp),updiv(flashAttentionBlockKv,lp),hp,lp]
+ // float value sum: [updiv(maxLength,flashAttentionBlockKv), roundup(headDim,hp)]
+ int32_t kvNumHead = params[0];
+ int32_t seqLen = params[1];
+ int32_t headDim = params[2];
+ int32_t blockNum = params[3];
+ int32_t maxLength = params[4];
+
+ int32_t lP = params[5];
+ int32_t hP = params[6];
+ int32_t pastLength = params[7];
+ int32_t kvHeadIdx = params[8];
+
+ int32_t flashAttentionBlockKv = params[9];
+
+ auto blockKvseq = UP_DIV(seqLen + pastLength, blockNum);
+ auto weightStride2 = lP * hP;
+ auto weightStride1 = UP_DIV(flashAttentionBlockKv, lP) * weightStride2;
+
+ auto packedStride1 = (int)(weightStride1 + 2 * hP * sizeof(float));
+ auto packedStride0 = UP_DIV(headDim, hP) * packedStride1;
+
+ auto srcStride0 = kvNumHead * headDim;
+
+ auto sourceFp16 = (FLOAT16*)source;
+
+ // quant scale & bias
+ if (pastLength == 0) {
+ for (int d = 0; d < headDim; ++d) {
+ float* scalePtr = (float*)(dst + (d / hP) * packedStride1 + weightStride1) + (d % hP);
+ float* biasPtr = scalePtr + hP;
+
+ // find min,max
+ float dMax = sourceFp16[d + kvHeadIdx * headDim];
+ float dMin = dMax;
+ for (int s = 0; s < seqLen; ++s) {
+ float data = sourceFp16[s * srcStride0 + d + kvHeadIdx * headDim];
+ dMax = ALIMAX(dMax, data);
+ dMin = ALIMIN(dMin, data);
+ }
- if (lP != 1 && lP != 2) {
- MNN_ERROR("This function only supports lP=1 or 2\n");
- return;
+ // scale & bias
+ float range = dMax - dMin;
+ if (range < 1e-6) {
+ scalePtr[0] = 0.f;
+ biasPtr[0] = dMax;
+ } else {
+ float scale = range / 255.f;
+ float bias = range / 255.f * 128.f + dMin;
+ scalePtr[0] = scale;
+ biasPtr[0] = bias;
+ }
+ }
}
- // src [depth, planesize] (float32)
- // dst [depth/eP, planesize/lP, eP, lP] (float16)
+ // copy the scale&bias to each blockKv
+ // pastLength == 0: First time prefill
+ // (seqLen + pastLength) % flashAttentionBlockKv == 0: Open a new blockKv
+ if (pastLength == 0 || (pastLength % flashAttentionBlockKv) == 0) {
+ int32_t d0 = UP_DIV(maxLength, flashAttentionBlockKv);
+ int32_t d1 = UP_DIV(headDim, hP);
+ for (int k = 0; k < d0; ++k) {
+ for (int r = 0; r < d1; ++r) {
+ float* scalePtr = (float*)(dst + k * packedStride0 + r * packedStride1 + weightStride1);
+ float* biasPtr = scalePtr + hP;
+ memcpy(scalePtr, dst + r * packedStride1 + weightStride1, hP * sizeof(float));
+ memcpy(biasPtr, dst + r * packedStride1 + weightStride1 + hP * sizeof(float), hP * sizeof(float));
+ }
+ }
+ }
- if (lP == 1) {
- MNNAttenPackAndConvertFp32LP1(dst, src, units, depth, planesize);
- return;
+ std::vector qScales(headDim);
+ std::vector qBiases(headDim);
+ std::vector deqScales(headDim);
+ std::vector deqBiases(headDim);
+ int8_t tmpQ[8];
+
+ for (int d = 0; d < headDim; ++d) {
+ float* scaleBase = (float*)(dst + (d / hP) * packedStride1 + weightStride1) + (d % hP);
+ float* biasBase = scaleBase + hP;
+
+ float s_val = scaleBase[0];
+ float b_val = biasBase[0];
+
+ deqScales[d] = s_val;
+ deqBiases[d] = b_val;
+
+ bool is_small = s_val < 1e-6f;
+ qScales[d] = is_small ? 0.0f : (1.0f / s_val);
+ qBiases[d] = is_small ? 0.0f : (-b_val / s_val);
}
- auto dstStride1 = eP * lP;
- auto dstStride0 = UP_DIV(planesize, lP) * dstStride1;
+ const __fp16* srcBasePtr = sourceFp16 + kvHeadIdx * headDim;
- for (int i = 0; i < depth; ++i) {
- size_t realsize = planesize;
- const float* srcPtr = src + i * planesize;
- FLOAT16* dstPtr = (FLOAT16*)dst + (i % eP) * lP + (i / eP) * dstStride0;
+ const int32_t sumStride = ROUND_UP(headDim, hP);
- while (realsize >= 16) {
- float32x4_t s0 = vld1q_f32(srcPtr);
- float32x4_t s1 = vld1q_f32(srcPtr + 4);
- float32x4_t s2 = vld1q_f32(srcPtr + 8);
- float32x4_t s3 = vld1q_f32(srcPtr + 12);
+ for (int s = 0; s < seqLen; ++s) {
+ int kvSeqIndx = s + pastLength;
- float16x4_t h0 = vcvt_f16_f32(s0);
- float16x4_t h1 = vcvt_f16_f32(s1);
- float16x4_t h2 = vcvt_f16_f32(s2);
- float16x4_t h3 = vcvt_f16_f32(s3);
+ int blkIdx = kvSeqIndx / flashAttentionBlockKv;
+ int blkRem = kvSeqIndx % flashAttentionBlockKv;
- vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0);
- vst1_lane_u32((uint32_t*)(dstPtr + dstStride1), vreinterpret_u32_f16(h0), 1);
+ int idxInnerCommon = blkIdx * packedStride0 + (blkRem / lP) * weightStride2 + (blkRem % lP);
- vst1_lane_u32((uint32_t*)(dstPtr + 2 * dstStride1), vreinterpret_u32_f16(h1), 0);
- vst1_lane_u32((uint32_t*)(dstPtr + 3 * dstStride1), vreinterpret_u32_f16(h1), 1);
+ float* curSumRow = valueSum + blkIdx * sumStride;
- vst1_lane_u32((uint32_t*)(dstPtr + 4 * dstStride1), vreinterpret_u32_f16(h2), 0);
- vst1_lane_u32((uint32_t*)(dstPtr + 5 * dstStride1), vreinterpret_u32_f16(h2), 1);
+ const __fp16* srcRow = srcBasePtr + s * srcStride0;
- vst1_lane_u32((uint32_t*)(dstPtr + 6 * dstStride1), vreinterpret_u32_f16(h3), 0);
- vst1_lane_u32((uint32_t*)(dstPtr + 7 * dstStride1), vreinterpret_u32_f16(h3), 1);
+ int d = 0;
+ for (; d <= headDim - 8; d += 8) {
+ // --- Load Source ---
+ float16x8_t vSrc16 = vld1q_f16(srcRow + d);
+ float32x4_t vSrc0 = vcvt_f32_f16(vget_low_f16(vSrc16));
+ float32x4_t vSrc1 = vcvt_high_f32_f16(vSrc16);
- realsize -= 16;
- srcPtr += 16;
- dstPtr += 8 * dstStride1;
- }
+ // --- Load Quant Params ---
+ float32x4_t vQs0 = vld1q_f32(&qScales[d]);
+ float32x4_t vQb0 = vld1q_f32(&qBiases[d]);
+ float32x4_t vQs1 = vld1q_f32(&qScales[d + 4]);
+ float32x4_t vQb1 = vld1q_f32(&qBiases[d + 4]);
- if (realsize >= 8) {
- float32x4_t s0 = vld1q_f32(srcPtr);
- float32x4_t s1 = vld1q_f32(srcPtr + 4);
+ // --- Quantize: x * qs + qb ---
+ float32x4_t vRes0 = vaddq_f32(vmulq_f32(vSrc0, vQs0), vQb0);
+ float32x4_t vRes1 = vaddq_f32(vmulq_f32(vSrc1, vQs1), vQb1);
- float16x4_t h0 = vcvt_f16_f32(s0);
- float16x4_t h1 = vcvt_f16_f32(s1);
+ // --- Round & Saturate ---
+ int32x4_t vInt32_0 = vcvtaq_s32_f32(vRes0);
+ int32x4_t vInt32_1 = vcvtaq_s32_f32(vRes1);
- vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0);
- vst1_lane_u32((uint32_t*)(dstPtr + dstStride1), vreinterpret_u32_f16(h0), 1);
+ int16x8_t vInt16 = vcombine_s16(vqmovn_s32(vInt32_0), vqmovn_s32(vInt32_1));
+ int8x8_t vInt8 = vqmovn_s16(vInt16); // Clamp to [-128, 127]
- vst1_lane_u32((uint32_t*)(dstPtr + 2 * dstStride1), vreinterpret_u32_f16(h1), 0);
- vst1_lane_u32((uint32_t*)(dstPtr + 3 * dstStride1), vreinterpret_u32_f16(h1), 1);
+ vst1_s8(tmpQ, vInt8);
+ for (int k = 0; k < 8; ++k) {
+ int cur_d = d + k;
+ int dstOffset = (cur_d / hP) * packedStride1 + idxInnerCommon + (cur_d % hP) * lP;
+ dst[dstOffset] = tmpQ[k];
+ }
- realsize -= 8;
- srcPtr += 8;
- dstPtr += 4 * dstStride1;
- }
+ int16x8_t vXq16 = vmovl_s8(vInt8);
+ float32x4_t vXqF0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vXq16)));
+ float32x4_t vXqF1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vXq16)));
- if (realsize >= 4) {
- float32x4_t s0 = vld1q_f32(srcPtr);
- float16x4_t h0 = vcvt_f16_f32(s0);
+ float32x4_t vDs0 = vld1q_f32(&deqScales[d]);
+ float32x4_t vDb0 = vld1q_f32(&deqBiases[d]);
+ float32x4_t vDs1 = vld1q_f32(&deqScales[d + 4]);
+ float32x4_t vDb1 = vld1q_f32(&deqBiases[d + 4]);
- vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0);
- vst1_lane_u32((uint32_t*)(dstPtr + dstStride1), vreinterpret_u32_f16(h0), 1);
+ // Dequant
+ float32x4_t vDeq0 = vaddq_f32(vmulq_f32(vXqF0, vDs0), vDb0);
+ float32x4_t vDeq1 = vaddq_f32(vmulq_f32(vXqF1, vDs1), vDb1);
- realsize -= 4;
- srcPtr += 4;
- dstPtr += 2 * dstStride1;
+ float* sumPtr = curSumRow + d;
+ vst1q_f32(sumPtr, vaddq_f32(vld1q_f32(sumPtr), vDeq0));
+ vst1q_f32(sumPtr + 4, vaddq_f32(vld1q_f32(sumPtr + 4), vDeq1));
}
- if (realsize >= 2) {
- float32x2_t s0 = vld1_f32(srcPtr);
- float16x4_t h0 = vcvt_f16_f32(vcombine_f32(s0, s0));
+ for (; d < headDim; ++d) {
+ float xf = (float)srcRow[d];
- vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0);
+ float val_f = xf * qScales[d] + qBiases[d];
+ int32_t val_i = (int32_t)roundf(val_f);
+ if (val_i > 127) val_i = 127;
+ if (val_i < -128) val_i = -128;
+ int8_t xq = (int8_t)val_i;
- realsize -= 2;
- srcPtr += 2;
- dstPtr += dstStride1;
+ int dstOffset = (d / hP) * packedStride1 + idxInnerCommon + (d % hP) * lP;
+ dst[dstOffset] = xq;
+
+ curSumRow[d] += ((float)xq * deqScales[d] + deqBiases[d]);
}
+ }
+
+/*
+ // Quant fp16
+ for (int d = 0; d < headDim; ++d) {
+ // dst address
+ int idxBase = (d / hP) * packedStride1 + (d % hP) * lP;
+ int8_t* dstBase = dst + idxBase;
+ float* scaleBase = (float*)(dst + (d / hP) * packedStride1 + weightStride1) + (d % hP);
+ float* biasBase = scaleBase + hP;
+ float* sumBase = valueSum + (d / hP) * hP + (d % hP);
+
+ float qscale = scaleBase[0] < 1e-6 ? 0 : 1.0f / scaleBase[0];
+ float qbias = scaleBase[0] < 1e-6 ? 0 : (-biasBase[0] / scaleBase[0]);
+ // quant
+ for (int s = 0; s < seqLen; ++s) {
+ int kvSeqIndx = s + pastLength;
+ int idxInner = (kvSeqIndx / flashAttentionBlockKv) * packedStride0 + (kvSeqIndx % flashAttentionBlockKv) / lP * weightStride2 + (kvSeqIndx % flashAttentionBlockKv) % lP;
+ float xf = sourceFp16[s * srcStride0 + d + kvHeadIdx * headDim];
+ int8_t xq = ALIMAX(ALIMIN(127, static_cast(roundf(xf * qscale + qbias))), -128);
+ dstBase[idxInner] = xq;
- if (realsize > 0) {
- dstPtr[0] = (FLOAT16)srcPtr[0];
- dstPtr[1] = (FLOAT16)0.0f;
+ // sum
+ int idxSum = (kvSeqIndx / flashAttentionBlockKv) * ROUND_UP(headDim, hP);
+ sumBase[idxSum] += ((float)xq * scaleBase[0] + biasBase[0]);
}
}
+ */
}
#endif // MNN_SUPPORT_TRANSFORMER_FUSE
@@ -2267,9 +2086,8 @@ static void MNNAsyQuantInfo_FP16(float* scale, float* bias, float* qscale, float
// dequant scale/bias : [EU, blockNum, step]
// quant scale/bias: [blockNum, plane]
if (info[7] == 1) { // scale&bias:[1]
- ARM82CountMinMaxValue(src, dstMin, dstMax, kernelsize * stride0);
- float maxval = *(FLOAT16*)dstMax;
- float minval = *(FLOAT16*)dstMin;
+ FLOAT16 maxval, minval;
+ ARM82CountMinMaxValue(src, (float*)(&minval), (float*)(&maxval) , kernelsize * stride0);
if (info[8] == 1 && (maxval - minval) > 1e-7) {
if (minval > 0.f) {
minval = 0.f;
@@ -2279,9 +2097,9 @@ static void MNNAsyQuantInfo_FP16(float* scale, float* bias, float* qscale, float
}
auto range = maxval - minval;
if (range <= 1e-7) {
- scale[0] = 0.f;
- qscale[0] = 0.f;
- qbias[0] = 0.f;
+ scale[0] = 1.f;
+ qscale[0] = 1.f;
+ qbias[0] = -maxval;
bias[0] = maxval;
} else {
qscale[0] = 255.f / range;
@@ -2390,6 +2208,417 @@ static void MNNAsyQuantInfo_FP16(float* scale, float* bias, float* qscale, float
#endif // MNN_LOW_MEMORY
+#define EXP_APPROX_MIN_INPUT vdupq_n_f32(-88.0f)
+#define EXP_APPROX_MAX_INPUT vdupq_n_f32(88.0f)
+#define EXP_APPROX_LN2 vdupq_n_f32(0.69314718056f) // ln(2)
+#define EXP_APPROX_LN2_INV vdupq_n_f32(1.44269504089f) // 1/ln(2)
+// Fourth-order polynomial approximation coefficients of exp(r):
+// P(x) = c4*x^4 + c3*x^3 + c2*x^2 + c1*x + c0
+#define EXP_APPROX_C4 vdupq_n_f32(0.0416624f)
+#define EXP_APPROX_C3 vdupq_n_f32(0.166665f)
+#define EXP_APPROX_C2 vdupq_n_f32(0.500000f)
+#define EXP_APPROX_C1 vdupq_n_f32(1.0f)
+#define EXP_APPROX_C0 vdupq_n_f32(1.0f)
+
+#ifndef __aarch64__
+static inline float32x4_t vrndaq_f32_compat(float32x4_t x) {
+ float32x4_t sign = vbslq_f32(vdupq_n_u32(0x80000000), x, vdupq_n_f32(0.0f));
+ return vcvtq_f32_s32(vcvtq_s32_f32(vaddq_f32(x, vbslq_f32(vcltq_f32(x, vdupq_n_f32(0.0f)), vdupq_n_f32(-0.5f), vdupq_n_f32(0.5f)))));
+}
+#endif
+
+static inline float32x4_t expApprox(float32x4_t x) {
+ x = vminq_f32(vmaxq_f32(x, EXP_APPROX_MIN_INPUT), EXP_APPROX_MAX_INPUT);
+
+ float32x4_t k_float;
+ float32x4_t r;
+ float32x4_t exp_r;
+#if defined(__aarch64__)
+ k_float = vrndaq_f32(vmulq_f32(x, EXP_APPROX_LN2_INV));
+
+ // r = x - k * ln(2)
+ r = vfmsq_f32(x, k_float, EXP_APPROX_LN2);
+
+ // P(r) = (c0 + c2*r^2 + c4*r^4) + r*(c1 + c3*r^2)
+ float32x4_t r2 = vmulq_f32(r, r);
+ float32x4_t p_odd = vfmaq_f32(EXP_APPROX_C1, EXP_APPROX_C3, r2);
+
+ float32x4_t p_even = vfmaq_f32(EXP_APPROX_C0, EXP_APPROX_C2, r2);
+ p_even = vfmaq_f32(p_even, EXP_APPROX_C4, vmulq_f32(r2, r2));
+ exp_r = vfmaq_f32(p_even, p_odd, r);
+#else
+
+ k_float = vrndaq_f32_compat(vmulq_f32(x, EXP_APPROX_LN2_INV));
+
+
+ r = vsubq_f32(x, vmulq_f32(k_float, EXP_APPROX_LN2));
+
+ // 2. c0 + r*(c1 + r*(c2 + r*(c3 + r*c4)))
+ exp_r = vmlaq_f32(EXP_APPROX_C3, EXP_APPROX_C4, r); // c3 + c4*r
+ exp_r = vmlaq_f32(EXP_APPROX_C2, exp_r, r); // c2 + r*(...)
+ exp_r = vmlaq_f32(EXP_APPROX_C1, exp_r, r); // c1 + r*(...)
+ exp_r = vmlaq_f32(EXP_APPROX_C0, exp_r, r); // c0 + r*(...)
+
+#endif
+
+ int32x4_t k_int = vcvtq_s32_f32(k_float);
+ int32x4_t k_shifted = vshlq_n_s32(k_int, 23);
+ return vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(exp_r), k_shifted));
+}
+static void MNNSoftmaxFp16_Pack8(float* dest, const float* source, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize, int kvSeqOffset, int validOffset, int pack, bool mask) {
+ auto softmaxDst = (FLOAT16*)dest;
+ auto softmaxSrc = (FLOAT16*)source;
+
+ if (pack != 8) {
+ MNN_ERROR("MNNSoftmaxFp16_Pack8 only support pack=8\n");
+ return;
+ }
+
+ const int packUnit = 8;
+ int reduceSizeOuter = UP_DIV(reduceSize, packUnit);
+ int stride0 = outside * packUnit;
+
+ // Loop Tiling: Unroll K by 16
+ // 16 * 8 * 2 = 256 Bytes
+ for (int k = 0; k < outside; k += 16) {
+ int count = ALIMIN(16, outside - k);
+
+ int validLens[16];
+ bool isRowValid[16];
+
+ for (int i = 0; i < count; ++i) {
+ int currentK = k + i;
+ if (mask && kvSeqOffset > currentK + validOffset) {
+ isRowValid[i] = false;
+ validLens[i] = 0;
+ if (updateScale) updateScale[currentK] = 1.0f;
+ } else {
+ isRowValid[i] = true;
+ validLens[i] = mask ? ALIMIN(reduceSize, currentK + (validOffset + 1) - kvSeqOffset) : reduceSize;
+ }
+ }
+
+ float currentMax[16];
+ for (int i = 0; i < count; ++i) {
+ currentMax[i] = runningMax ? runningMax[k + i] : -65504.0f;
+ }
+
+ for (int j = 0; j < reduceSizeOuter; ++j) {
+ auto blockSrcBase = softmaxSrc + j * stride0 + k * packUnit;
+
+ for (int i = 0; i < count; ++i) {
+ if (!isRowValid[i]) continue;
+
+ int len = validLens[i];
+ int blockStart = j * packUnit;
+ if (blockStart >= len) continue;
+
+ auto srcPtr = blockSrcBase + i * packUnit;
+ int remain = len - blockStart;
+
+ if (remain >= packUnit) {
+ float16x8_t val = vld1q_f16(srcPtr);
+ float maxInVec = vmaxvq_f16(val);
+ currentMax[i] = ALIMAX(currentMax[i], maxInVec);
+ } else {
+ for (int p = 0; p < remain; ++p) {
+ currentMax[i] = ALIMAX(currentMax[i], (float)srcPtr[p]);
+ }
+ }
+ }
+ }
+
+ float currentSum[16] = {0.0f};
+ float32x4_t vecSum0[16]; // Low part accumulator
+ float32x4_t vecSum1[16]; // High part accumulator
+ float32x4_t finalMaxVec[16];
+
+ for (int i = 0; i < count; ++i) {
+ vecSum0[i] = vdupq_n_f32(0.0f);
+ vecSum1[i] = vdupq_n_f32(0.0f);
+ finalMaxVec[i] = vdupq_n_f32(currentMax[i]);
+ }
+
+ for (int j = 0; j < reduceSizeOuter; ++j) {
+ auto blockSrcBase = softmaxSrc + j * stride0 + k * packUnit;
+ auto blockDstBase = softmaxDst + j * stride0 + k * packUnit;
+
+ for (int i = 0; i < count; ++i) {
+ if (!isRowValid[i]) {
+ memset(blockDstBase + i * packUnit, 0, packUnit * sizeof(__fp16));
+ continue;
+ }
+
+ int len = validLens[i];
+ int blockStart = j * packUnit;
+ if (blockStart >= len) {
+ memset(blockDstBase + i * packUnit, 0, packUnit * sizeof(__fp16));
+ continue;
+ }
+
+ auto srcPtr = blockSrcBase + i * packUnit;
+ auto dstPtr = blockDstBase + i * packUnit;
+ int remain = len - blockStart;
+
+ if (remain >= packUnit) {
+ float16x8_t srcVal = vld1q_f16(srcPtr);
+
+ // F16 -> F32 expansion
+ float32x4_t low = vcvt_f32_f16(vget_low_f16(srcVal));
+ float32x4_t high = vcvt_f32_f16(vget_high_f16(srcVal));
+
+ // Subtract Max
+ low = vsubq_f32(low, finalMaxVec[i]);
+ high = vsubq_f32(high, finalMaxVec[i]);
+
+ // Exp
+ low = expApprox(low);
+ high = expApprox(high);
+
+ // Accumulate Sum
+ vecSum0[i] = vaddq_f32(vecSum0[i], low);
+ vecSum1[i] = vaddq_f32(vecSum1[i], high);
+
+ // Store Exp result temporarily
+ vst1q_f16(dstPtr, vcombine_f16(vcvt_f16_f32(low), vcvt_f16_f32(high)));
+ } else {
+ // Handle Tail
+ for (int p = 0; p < remain; ++p) {
+ float val = expf((float)srcPtr[p] - currentMax[i]);
+ currentSum[i] += val;
+ dstPtr[p] = (__fp16)val;
+ }
+ memset(dstPtr + remain, 0, (packUnit - remain) * sizeof(__fp16));
+ }
+ }
+ }
+
+ // Horizontal reduction for sums
+ for (int i = 0; i < count; ++i) {
+ currentSum[i] += vaddvq_f32(vecSum0[i]) + vaddvq_f32(vecSum1[i]);
+ }
+
+ for (int i = 0; i < count; ++i) {
+ int currentK = k + i;
+ if (!isRowValid[i]) continue;
+
+ float scale;
+ if (runningMax && runningSum && updateScale) {
+ // Incremental Softmax logic
+ float oldMax = runningMax[currentK];
+ float scaleForSum = expf(oldMax - currentMax[i]);
+ runningSum[currentK] = runningSum[currentK] * scaleForSum + currentSum[i];
+ runningMax[currentK] = currentMax[i];
+ updateScale[currentK] = scaleForSum;
+ continue;
+ } else {
+ // Standard Softmax logic
+ if (runningMax && runningSum) {
+ currentSum[i] += runningSum[currentK] * expf(runningMax[currentK] - currentMax[i]);
+ }
+ scale = 1.0f / (currentSum[i] + 1e-20f);
+ }
+
+ float16x8_t scaleVec = vdupq_n_f16((__fp16)scale);
+
+ // Normalize Pass
+ for (int j = 0; j < reduceSizeOuter; ++j) {
+ int len = validLens[i];
+ int blockStart = j * packUnit;
+ if (blockStart >= len) break;
+
+ auto dstPtr = softmaxDst + j * stride0 + k * packUnit + i * packUnit;
+
+ if (len - blockStart >= packUnit) {
+ float16x8_t val = vld1q_f16(dstPtr);
+ val = vmulq_f16(val, scaleVec);
+ vst1q_f16(dstPtr, val);
+ } else {
+ int remain = len - blockStart;
+ for (int p = 0; p < remain; ++p) {
+ dstPtr[p] = (__fp16)((float)dstPtr[p] * scale);
+ }
+ }
+ }
+ }
+ }
+}
+
+static void MNNSoftmaxFp16_Pack1(float* dest, const float* source, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize, int kvSeqOffset, int validOffset, bool mask) {
+ auto softmaxDst = (FLOAT16*)dest;
+ auto softmaxSrc = (FLOAT16*)source;
+
+ for (int k = 0; k < outside; ++k) {
+ int currentValidSize = reduceSize;
+ bool isRowValid = true;
+
+ if (mask) {
+ if (kvSeqOffset > k + validOffset) {
+ isRowValid = false;
+ currentValidSize = 0;
+ if (updateScale) updateScale[k] = 1.0f;
+ } else {
+ currentValidSize = ALIMIN(reduceSize, k + (validOffset + 1) - kvSeqOffset);
+ }
+ }
+
+ if (!isRowValid || currentValidSize == 0) {
+ memset(softmaxDst + k * reduceSize, 0, reduceSize * sizeof(__fp16));
+ continue;
+ }
+
+ auto srcRow = softmaxSrc + k * reduceSize;
+ auto dstRow = softmaxDst + k * reduceSize;
+
+ float oldMax = runningMax ? runningMax[k] : -65504.0f;
+ float16x8_t maxVec = vdupq_n_f16(-65504.0f);
+
+ // Unroll 4 (32 elements per loop)
+ int i = 0;
+ for (; i <= currentValidSize - 32; i += 32) {
+ float16x8_t v0 = vld1q_f16(srcRow + i + 0);
+ float16x8_t v1 = vld1q_f16(srcRow + i + 8);
+ float16x8_t v2 = vld1q_f16(srcRow + i + 16);
+ float16x8_t v3 = vld1q_f16(srcRow + i + 24);
+
+ maxVec = vmaxq_f16(maxVec, v0);
+ maxVec = vmaxq_f16(maxVec, v1);
+ maxVec = vmaxq_f16(maxVec, v2);
+ maxVec = vmaxq_f16(maxVec, v3);
+ }
+ // Handle remaining blocks of 8
+ for (; i <= currentValidSize - 8; i += 8) {
+ maxVec = vmaxq_f16(maxVec, vld1q_f16(srcRow + i));
+ }
+
+ // Horizontal Max reduction
+ float newMax = vmaxvq_f16(maxVec);
+
+ // Handle remaining scalars (Tail)
+ for (; i < currentValidSize; ++i) {
+ newMax = ALIMAX(newMax, (float)srcRow[i]);
+ }
+
+ float finalMax = ALIMAX(oldMax, newMax);
+ float32x4_t finalMaxVec = vdupq_n_f32(finalMax);
+
+ float sum = 0.0f;
+ float32x4_t sumVec = vdupq_n_f32(0.0f);
+
+ i = 0;
+ // Unroll 2 (16 elements). Exp is heavy, unroll 4 might cause register spilling.
+ for (; i <= currentValidSize - 16; i += 16) {
+ float16x8_t v0 = vld1q_f16(srcRow + i);
+ float16x8_t v1 = vld1q_f16(srcRow + i + 8);
+
+ // Process v0
+ float32x4_t v0_lo = vcvt_f32_f16(vget_low_f16(v0));
+ float32x4_t v0_hi = vcvt_f32_f16(vget_high_f16(v0));
+ v0_lo = expApprox(vsubq_f32(v0_lo, finalMaxVec));
+ v0_hi = expApprox(vsubq_f32(v0_hi, finalMaxVec));
+ sumVec = vaddq_f32(sumVec, v0_lo);
+ sumVec = vaddq_f32(sumVec, v0_hi);
+ vst1q_f16(dstRow + i, vcombine_f16(vcvt_f16_f32(v0_lo), vcvt_f16_f32(v0_hi)));
+
+ // Process v1
+ float32x4_t v1_lo = vcvt_f32_f16(vget_low_f16(v1));
+ float32x4_t v1_hi = vcvt_f32_f16(vget_high_f16(v1));
+ v1_lo = expApprox(vsubq_f32(v1_lo, finalMaxVec));
+ v1_hi = expApprox(vsubq_f32(v1_hi, finalMaxVec));
+ sumVec = vaddq_f32(sumVec, v1_lo);
+ sumVec = vaddq_f32(sumVec, v1_hi);
+ vst1q_f16(dstRow + i + 8, vcombine_f16(vcvt_f16_f32(v1_lo), vcvt_f16_f32(v1_hi)));
+ }
+
+ // Handle remaining blocks of 8
+ for (; i <= currentValidSize - 8; i += 8) {
+ float16x8_t v = vld1q_f16(srcRow + i);
+ float32x4_t v_lo = vcvt_f32_f16(vget_low_f16(v));
+ float32x4_t v_hi = vcvt_f32_f16(vget_high_f16(v));
+
+ v_lo = expApprox(vsubq_f32(v_lo, finalMaxVec));
+ v_hi = expApprox(vsubq_f32(v_hi, finalMaxVec));
+
+ sumVec = vaddq_f32(sumVec, v_lo);
+ sumVec = vaddq_f32(sumVec, v_hi);
+
+ vst1q_f16(dstRow + i, vcombine_f16(vcvt_f16_f32(v_lo), vcvt_f16_f32(v_hi)));
+ }
+
+ // Handle Tail scalars
+ if (i < currentValidSize) {
+ __fp16 tempDst[8];
+ int remain = currentValidSize - i;
+ auto sPtr = srcRow + i;
+ for (int p = 0; p < remain; ++p) {
+ float val = expf((float)sPtr[p] - finalMax);
+ sum += val;
+ tempDst[p] = (__fp16)val;
+ }
+ memcpy(dstRow + i, tempDst, remain * sizeof(__fp16));
+ i += remain; // align i to currentValidSize
+ }
+
+ sum += vaddvq_f32(sumVec);
+
+ // Fill remaining invalid part with 0
+ if (currentValidSize < reduceSize) {
+ memset(dstRow + currentValidSize, 0, (reduceSize - currentValidSize) * sizeof(__fp16));
+ }
+
+ if (runningMax && runningSum && updateScale) {
+ float scaleForSum = expf(oldMax - finalMax);
+ runningSum[k] = runningSum[k] * scaleForSum + sum;
+ runningMax[k] = finalMax;
+ updateScale[k] = scaleForSum;
+ } else {
+ if (runningMax && runningSum) {
+ sum += runningSum[k] * expf(oldMax - finalMax);
+ }
+ float scale = 1.0f / (sum + 1e-20f);
+ float16x8_t scaleVec = vdupq_n_f16((__fp16)scale);
+
+ // Unroll 4 (32 elements) for throughput
+ i = 0;
+ for (; i <= currentValidSize - 32; i += 32) {
+ float16x8_t v0 = vld1q_f16(dstRow + i);
+ float16x8_t v1 = vld1q_f16(dstRow + i + 8);
+ float16x8_t v2 = vld1q_f16(dstRow + i + 16);
+ float16x8_t v3 = vld1q_f16(dstRow + i + 24);
+
+ vst1q_f16(dstRow + i, vmulq_f16(v0, scaleVec));
+ vst1q_f16(dstRow + i + 8, vmulq_f16(v1, scaleVec));
+ vst1q_f16(dstRow + i + 16, vmulq_f16(v2, scaleVec));
+ vst1q_f16(dstRow + i + 24, vmulq_f16(v3, scaleVec));
+ }
+ for (; i <= currentValidSize - 8; i += 8) {
+ float16x8_t v = vld1q_f16(dstRow + i);
+ vst1q_f16(dstRow + i, vmulq_f16(v, scaleVec));
+ }
+ for (; i < currentValidSize; ++i) {
+ dstRow[i] = (__fp16)((float)dstRow[i] * scale);
+ }
+ }
+ }
+}
+
+
+static void MNNSoftmaxFp16(float* dest, const float* source, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize, int kvSeqOffset, int validOffset, int pack, bool mask) {
+ // source shape: [reduceSizeOuter, outside, reduceSizeInner]
+ // for pack=8 (C8), [up_div(reduceSize,8), outside, 8] => reduceSizeOuter=up_div(reduceSize,8), reduceSizeInner=8
+ // for C, [outside, reduceSize] => reduceSizeOuter=1, reduceSizeInner=reduceSize
+ if (pack == 8) {
+ MNNSoftmaxFp16_Pack8(dest, source, runningMax, runningSum, updateScale, outside, reduceSize, kvSeqOffset, validOffset, pack, mask);
+ return;
+ }
+ if (pack == 1) {
+ MNNSoftmaxFp16_Pack1(dest, source, runningMax, runningSum, updateScale, outside, reduceSize, kvSeqOffset, validOffset, mask);
+ return;
+ }
+ MNN_ERROR("MNNSoftMaxFp16 not support pack!=8 and pack!=1\n");
+ return;
+}
+
static CoreFunctions* gInstance = nullptr;
static CoreInt8Functions* gArm82CoreInt8Functions = nullptr;
@@ -2401,11 +2630,11 @@ bool Arm82Functions::init() {
gArm82CoreInt8Functions = new CoreInt8Functions;
*gArm82CoreInt8Functions = *MNNGetInt8CoreFunctions();
gInstance->int8MatmulRelatedFunctions = origin->int8MatmulRelatedFunctions;
- gInstance->sme2Int8MatmulRelatedFuncionsHp32 = origin->sme2Int8MatmulRelatedFuncionsHp32;
{
if (origin->supportSDot) {
gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _Arm82MNNPackC4ForMatMul_A<12, 4>;
- gInstance->supportSDot = true;
+ gInstance->arm82MatmulRelatedFunctions = origin->arm82MatmulRelatedFunctions;
+ gInstance->arm82MatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = _Arm82MNNPackC4ForMatMul_A<12, 4>;
}
if (origin->supportI8mm) {
gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A_L8<10, 8>;
@@ -2413,11 +2642,6 @@ bool Arm82Functions::init() {
}
}
- FUNC_PTR_ASSIGN(gInstance->MNNFp32ToFp8, MNNFp32ToFp8);
- FUNC_PTR_ASSIGN(gInstance->MNNFp16ToFp8, MNNFp16ToFp8);
- FUNC_PTR_ASSIGN(gInstance->MNNFp8ToFp32, MNNFp8ToFp32);
- FUNC_PTR_ASSIGN(gInstance->MNNFp8ToFp16, MNNFp8ToFp16);
-
FUNC_PTR_ASSIGN(gInstance->MNNFp32ToLowp, MNNQuantizeFP16);
FUNC_PTR_ASSIGN(gInstance->MNNLowpToFp32, MNNDequantizeFP16);
gInstance->bytes = 2;
@@ -2435,7 +2659,7 @@ bool Arm82Functions::init() {
FUNC_PTR_ASSIGN(gInstance->MNNStrassenMergeCFunction, ARM82StrassenMerge);
gInstance->MNNReorderWeightInt4 = origin->MNNReorderWeightInt4;
gInstance->MNNSumWeightInt8 = origin->MNNSumWeightInt8;
- gInstance->MNNSumWeightInt8SmeHp64 = origin->MNNSumWeightInt8SmeHp64;
+ gInstance->MNNSumWeightInt8SmeHp128 = origin->MNNSumWeightInt8SmeHp128;
gInstance->penalty = 2.0f;
FUNC_PTR_ASSIGN(gInstance->MNNScaleAndAddBias, MNNScaleAndAddBiasFP16);
FUNC_PTR_ASSIGN(gInstance->MNNGridSampleComputeCord, MNNGridSampleComputeCordFP16);
@@ -2456,12 +2680,13 @@ bool Arm82Functions::init() {
FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Arm82MNNPackForMatMul_A);
FUNC_PTR_ASSIGN(gInstance->MNNPackForMatMul_B, Arm82MNNPackForMatMul_B);
- FUNC_PTR_ASSIGN(gInstance->MNNSoftmax, origin->MNNSoftmax);
+ FUNC_PTR_ASSIGN(gInstance->MNNSoftmax, MNNSoftmaxFp16);
#if defined(__aarch64__)
gInstance->supportFp16arith = origin->supportFp16arith;
gInstance->supportSDot = origin->supportSDot;
gInstance->supportI8mm = origin->supportI8mm;
gInstance->supportSME2 = origin->supportSME2;
+ gInstance->smeCoreNumber = origin->smeCoreNumber;
#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM
// Weight Dequant Gemm Kernels
FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul_int8, MNNPackedMatMulFP16_int8);
@@ -2478,6 +2703,7 @@ bool Arm82Functions::init() {
if (origin->supportSDot) {
FUNC_PTR_ASSIGN(gInstance->MNNGeneralIm2Col, MNNGeneralIm2col_Arm82);
+ gInstance->arm82MatmulRelatedFunctions.MNNGeneralIm2Col = MNNGeneralIm2col_Arm82;
}
if (origin->supportI8mm) {
FUNC_PTR_ASSIGN(gInstance->MNNGeneralIm2Col, MNNGeneralIm2col_Arm86);
@@ -2490,10 +2716,10 @@ bool Arm82Functions::init() {
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
// Attention
- FUNC_PTR_ASSIGN(gInstance->MNNAttenUnpackAndConvertFp16, MNNAttenUnpackAndConvertFp16);
- FUNC_PTR_ASSIGN(gInstance->MNNAttenPackAndConvertFp32, MNNAttenPackAndConvertFp32);
FUNC_PTR_ASSIGN(gInstance->MNNAttenPackAndScaleSingleHead, MNNAttenPackAndScaleSingleHead);
FUNC_PTR_ASSIGN(gInstance->MNNFlashAttentionUpdateBlockOutput, MNNFlashAttentionUpdateBlockOutput);
+ gInstance->MNNQuantAttentionKey = MNNQuantAttentionKeyFP16;
+ gInstance->MNNQuantAttentionValue = MNNQuantAttentionValueFP16;
#endif // MNN_SUPPORT_TRANSFORMER_FUSE
gInstance->MNNComputeMatMulForH_1 = _MNNComputeMatMulForH_1_FP16;
@@ -2520,11 +2746,11 @@ bool Arm82Functions::init() {
gInstance->int8MatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A;
gInstance->int8MatmulRelatedFunctions.MNNGeneralIm2Col = gInstance->MNNGeneralIm2Col;
}
+
#ifdef __aarch64__
#ifdef MNN_SME2
if (origin->supportSME2) {
gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _Arm82MNNPackC4ForMatMul_A<16, 4>;
- gInstance->sme2Int8MatmulRelatedFuncionsHp32.MNNPackC4Int8ForMatMul_A = _Arm82MNNPackC4ForMatMul_A<16, 4>;
FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul, MNNPackedMatMulFP16_SME2);
FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain, MNNPackedMatMulRemainFP16_SME2);
@@ -2534,12 +2760,16 @@ bool Arm82Functions::init() {
#ifdef MNN_LOW_MEMORY
FUNC_PTR_ASSIGN(gInstance->MNNGeneralIm2Col, MNNGeneralIm2col_Fp16Sme2);
- gInstance->sme2Int8MatmulRelatedFuncionsHp32.MNNGeneralIm2Col = MNNGeneralIm2col_Fp16Sme2;
#endif
}
#endif // MNN_SME2
#endif // __aarch64__
+ // Update the function pointers in the int8MatmulRelatedFunctions struct.
+ gInstance->int8MatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A;
+ gInstance->int8MatmulRelatedFunctions.MNNGeneralIm2Col = gInstance->MNNGeneralIm2Col;
+
+
return true;
}
diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16.S b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16.S
index a72cc35f18..c7acea666a 100644
--- a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16.S
+++ b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16.S
@@ -131,19 +131,20 @@ ldr x23, [x6, #56] // fp32minmax
ldr x26, [x6, #64] // blockNum
lsl x22, x7, #2 // eDest * SRC_UNIT
-mov x14, #-32
+mov x25, #-32
cbz x27, TILE_12
-mov x14, #16
+sub x25, x22, #32
TILE_12:
cmp x7, #12
blt TILE_8
sub x4, x4, #128
+ mov x12, x2
+ mov x6, x0
+ mov x14, x5
mov x20, x9
mov x15, x8 // input kernel sum
- mov x6, x27 // input dequant bias
mov x21, x24 // input dequant scale
- mov x12, #-320
L8LoopDz_TILE_12:
mov x11, x1
mov x19, #0
@@ -157,8 +158,8 @@ TILE12_BLOCKNUM:
SET_BIAS v28, v29, v30, v31
L8LoopSz_TILE_12:
- ld1 {v3.16b, v4.16b}, [x2], #32 // weight
- ld1 {v0.16b, v1.16b, v2.16b}, [x11], #48 // src
+ ld1 {v3.16b, v4.16b}, [x12], #32 // weight
+ ld1 {v0.16b, v1.16b, v2.16b}, [x11], x22 // src
.inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0]
.inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1]
.inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2]
@@ -168,6 +169,7 @@ TILE12_BLOCKNUM:
.inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1]
.inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2]
.inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3]
+
.inst 0x4f82e070 // sdot v16.4s, v3.16b, v2.4b[0]
.inst 0x4fa2e071 // sdot v17.4s, v3.16b, v2.4b[1]
.inst 0x4f82e872 // sdot v18.4s, v3.16b, v2.4b[2]
@@ -191,9 +193,8 @@ TILE12_BLOCKNUM:
L8LoopSzEnd_TILE_12:
L8Tile12Quan:
- ld1 {v0.4s, v1.4s}, [x2], #32 // scale
- ld1 {v2.4s, v3.4s, v4.4s}, [x8], #48 // input kernel sum
- ld1 {v5.4s, v6.4s}, [x2], #32 // weight quan zeropoint
+ ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x12], #64 // weight scale&bias
+ ld1 {v4.4s, v5.4s, v6.4s}, [x8], x22 // input kernel sum
Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15
Int32ToFloat v16, v17, v18, v19
@@ -209,7 +210,7 @@ TILE12_BLOCKNUM:
MUL_SCALE v1, v28, v29, v30, v31
ld1 {v0.4s, v1.4s}, [x24], #32
- ld1 {v7.4s}, [x24], x14
+ ld1 {v7.4s}, [x24], x25
MUL_EXTRA_SCALE v0, v8, v9, v10, v11
MUL_EXTRA_SCALE v1, v12, v13, v14, v15
MUL_EXTRA_SCALE v7, v16, v17, v18, v19
@@ -218,34 +219,34 @@ TILE12_BLOCKNUM:
MUL_EXTRA_SCALE v7, v28, v29, v30, v31
TILE12_L8_MLA_TERM:
- MLA_WEIGHTZERO v8, v2, v5, 0 // tile:0, oc:0-3
- MLA_WEIGHTZERO v9, v2, v5, 1 // tile:1, oc:0-3
- MLA_WEIGHTZERO v10, v2, v5, 2 // tile:2, oc:0-3
- MLA_WEIGHTZERO v11, v2, v5, 3 // tile:3, oc:0-3
- MLA_WEIGHTZERO v12, v3, v5, 0 // tile:4, oc:0-3
- MLA_WEIGHTZERO v13, v3, v5, 1 // tile:5, oc:0-3
- MLA_WEIGHTZERO v14, v3, v5, 2 // tile:6, oc:0-3
- MLA_WEIGHTZERO v15, v3, v5, 3 // tile:7, oc:0-3
- MLA_WEIGHTZERO v16, v4, v5, 0 // tile:8, oc:0-3
- MLA_WEIGHTZERO v17, v4, v5, 1 // tile:9, oc:0-3
- MLA_WEIGHTZERO v18, v4, v5, 2 // tile:10, oc:0-3
- MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3
-
- MLA_WEIGHTZERO v20, v2, v6, 0 // tile:0, oc:4-7
- MLA_WEIGHTZERO v21, v2, v6, 1 // tile:1, oc:4-7
- MLA_WEIGHTZERO v22, v2, v6, 2 // tile:2, oc:4-7
- MLA_WEIGHTZERO v23, v2, v6, 3 // tile:3, oc:4-7
- MLA_WEIGHTZERO v24, v3, v6, 0 // tile:4, oc:4-7
- MLA_WEIGHTZERO v25, v3, v6, 1 // tile:5, oc:4-7
- MLA_WEIGHTZERO v26, v3, v6, 2 // tile:6, oc:4-7
- MLA_WEIGHTZERO v27, v3, v6, 3 // tile:7, oc:4-7
- MLA_WEIGHTZERO v28, v4, v6, 0 // tile:8, oc:4-7
- MLA_WEIGHTZERO v29, v4, v6, 1 // tile:9, oc:4-7
- MLA_WEIGHTZERO v30, v4, v6, 2 // tile:10, oc:4-7
- MLA_WEIGHTZERO v31, v4, v6, 3 // tile:11, oc:4-7
+ MLA_WEIGHTZERO v8, v4, v2, 0 // tile:0, oc:0-3
+ MLA_WEIGHTZERO v9, v4, v2, 1 // tile:1, oc:0-3
+ MLA_WEIGHTZERO v10, v4, v2, 2 // tile:2, oc:0-3
+ MLA_WEIGHTZERO v11, v4, v2, 3 // tile:3, oc:0-3
+ MLA_WEIGHTZERO v12, v5, v2, 0 // tile:4, oc:0-3
+ MLA_WEIGHTZERO v13, v5, v2, 1 // tile:5, oc:0-3
+ MLA_WEIGHTZERO v14, v5, v2, 2 // tile:6, oc:0-3
+ MLA_WEIGHTZERO v15, v5, v2, 3 // tile:7, oc:0-3
+ MLA_WEIGHTZERO v16, v6, v2, 0 // tile:8, oc:0-3
+ MLA_WEIGHTZERO v17, v6, v2, 1 // tile:9, oc:0-3
+ MLA_WEIGHTZERO v18, v6, v2, 2 // tile:10, oc:0-3
+ MLA_WEIGHTZERO v19, v6, v2, 3 // tile:11, oc:0-3
+
+ MLA_WEIGHTZERO v20, v4, v3, 0 // tile:0, oc:4-7
+ MLA_WEIGHTZERO v21, v4, v3, 1 // tile:1, oc:4-7
+ MLA_WEIGHTZERO v22, v4, v3, 2 // tile:2, oc:4-7
+ MLA_WEIGHTZERO v23, v4, v3, 3 // tile:3, oc:4-7
+ MLA_WEIGHTZERO v24, v5, v3, 0 // tile:4, oc:4-7
+ MLA_WEIGHTZERO v25, v5, v3, 1 // tile:5, oc:4-7
+ MLA_WEIGHTZERO v26, v5, v3, 2 // tile:6, oc:4-7
+ MLA_WEIGHTZERO v27, v5, v3, 3 // tile:7, oc:4-7
+ MLA_WEIGHTZERO v28, v6, v3, 0 // tile:8, oc:4-7
+ MLA_WEIGHTZERO v29, v6, v3, 1 // tile:9, oc:4-7
+ MLA_WEIGHTZERO v30, v6, v3, 2 // tile:10, oc:4-7
+ MLA_WEIGHTZERO v31, v6, v3, 3 // tile:11, oc:4-7
cbz x27, TILE12_ADD_DSTV
- ld1 {v0.4s, v1.4s, v2.4s}, [x27], #48 // input dequant bias
+ ld1 {v0.4s, v1.4s, v2.4s}, [x27], x22 // input dequant bias
ld1 {v3.4s, v4.4s}, [x28], #32 // weight kernel sum
MLA_WEIGHTZERO v8, v0, v3, 0
MLA_WEIGHTZERO v9, v0, v3, 1
@@ -284,9 +285,10 @@ TILE12_BLOCKNUM:
ADD_FLOAT v16, v17, v18, v19, v0, v1, v2, v3
ADD_FLOAT v20, v21, v22, v23, v4, v5, v6, v7
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
- ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], x12
+ ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10]
ADD_FLOAT v24, v25, v26, v27, v0, v1, v2, v3
ADD_FLOAT v28, v29, v30, v31, v4, v5, v6, v7
+ sub x10, x10, #320
TILE12_L8_ACCUM_BUFFER:
add x19, x19, #1
@@ -297,11 +299,12 @@ TILE12_BLOCKNUM:
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
- st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], x12
+ st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10]
+ sub x10, x10, #320
b TILE12_BLOCKNUM
TILE12_POST:
- sub x5, x5, #1
+ sub x14, x14, #1
cbz x9, TILE12_CVT_FP16
ld1 {v0.4s, v1.4s}, [x20], #32
ADD_BIAS_FLOAT v8, v9, v10, v11, v0
@@ -329,16 +332,32 @@ TILE12_BLOCKNUM:
TILE12_STORE:
- st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64
- st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64
- st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], x4
+ st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x6], #64
+ st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x6], #64
+ st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x6], x4
L8Tile12LoopCheck:
- cbz x5, End
+ cbz x14, Tile12End
mov x8, x15 // revert input kernel sum
mov x24, x21 // revert input dequant scale
- mov x27, x6 // revert input dequant bias
+ cbz x27, L8LoopDz_TILE_12
+ REVERT_INPUT_DEQUANT_BIAS x27, x19, x26, x22
b L8LoopDz_TILE_12
+Tile12End:
+
+ add x0, x0, #192
+ sub x7, x7, #12
+ cbz x7, End
+ add x1, x1, #48
+ add x8, x15, #48
+ add x24, x21, #48
+ add x4, x4, #128 // revert x4
+
+ cbz x27, TILE_8
+ REVERT_INPUT_DEQUANT_BIAS x27, x19, x26, x22
+ REVERT_WEIGHT_KERNEL_SUM x28, x14, x26, x5
+ add x27, x27, #48
+
TILE_8:
mov x25, #0
cbz x27, TILE_Remain
@@ -365,6 +384,7 @@ TILE8_BLOCKNUM:
SET_BIAS v12, v13, v14, v15
SET_BIAS v16, v17, v18, v19
SET_BIAS v20, v21, v22, v23
+
L8LoopSz_TILE_8:
ld1 {v3.16b, v4.16b}, [x12], #32 // weight
ld1 {v0.16b, v1.16b}, [x11], x22 // src
@@ -542,6 +562,7 @@ TILE4_BLOCKNUM:
.inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1]
.inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2]
.inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3]
+
subs x13, x13, #1
.inst 0x4f80e08c // sdot v12.4s, v4.16b, v0.4b[0]
.inst 0x4fa0e08d // sdot v13.4s, v4.16b, v0.4b[1]
diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16.S b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16.S
index d8d9bd4df2..46cb8b98b7 100644
--- a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16.S
+++ b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16.S
@@ -131,19 +131,20 @@ ldr x23, [x6, #56] // fp32minmax
ldr x26, [x6, #64] // blockNum
lsl x22, x7, #2 // eDest * SRC_UNIT
-mov x14, #-32
+mov x25, #-32
cbz x27, TILE_12
-mov x14, #16
+sub x25, x22, #32
TILE_12:
cmp x7, #12
blt TILE_8
sub x4, x4, #128
+ mov x12, x2
+ mov x6, x0
+ mov x14, x5
mov x20, x9
mov x15, x8 // input kernel sum
- mov x6, x27 // input dequant bias
mov x21, x24 // input dequant scale
- mov x12, #-320
L8LoopDz_TILE_12:
mov x11, x1
mov x19, #0
@@ -159,8 +160,8 @@ TILE12_BLOCKNUM:
SET_BIAS v28, v29, v30, v31
L8LoopSz_TILE_12:
- ld1 {v5.16b}, [x2], #16 // weight
- ld1 {v0.16b, v1.16b, v2.16b}, [x11], #48 // src
+ ld1 {v5.16b}, [x12], #16 // weight
+ ld1 {v0.16b, v1.16b, v2.16b}, [x11], x22 // src
// int4->int8
ushr v3.16b, v5.16b, #4
and v4.16b, v5.16b, v7.16b
@@ -198,9 +199,8 @@ TILE12_BLOCKNUM:
L8LoopSzEnd_TILE_12:
L8Tile12Quan:
- ld1 {v0.4s, v1.4s}, [x2], #32 // scale
- ld1 {v2.4s, v3.4s, v4.4s}, [x8], #48 // input kernel sum
- ld1 {v5.4s, v6.4s}, [x2], #32 // weight quan zeropoint
+ ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x12], #64 // weight scale&bias
+ ld1 {v4.4s, v5.4s, v6.4s}, [x8], x22 // input kernel sum
Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15
Int32ToFloat v16, v17, v18, v19
@@ -216,7 +216,7 @@ TILE12_BLOCKNUM:
MUL_SCALE v1, v28, v29, v30, v31
ld1 {v0.4s, v1.4s}, [x24], #32
- ld1 {v7.4s}, [x24], x14
+ ld1 {v7.4s}, [x24], x25
MUL_EXTRA_SCALE v0, v8, v9, v10, v11
MUL_EXTRA_SCALE v1, v12, v13, v14, v15
MUL_EXTRA_SCALE v7, v16, v17, v18, v19
@@ -225,34 +225,34 @@ TILE12_BLOCKNUM:
MUL_EXTRA_SCALE v7, v28, v29, v30, v31
TILE12_L8_MLA_TERM:
- MLA_WEIGHTZERO v8, v2, v5, 0 // tile:0, oc:0-3
- MLA_WEIGHTZERO v9, v2, v5, 1 // tile:1, oc:0-3
- MLA_WEIGHTZERO v10, v2, v5, 2 // tile:2, oc:0-3
- MLA_WEIGHTZERO v11, v2, v5, 3 // tile:3, oc:0-3
- MLA_WEIGHTZERO v12, v3, v5, 0 // tile:4, oc:0-3
- MLA_WEIGHTZERO v13, v3, v5, 1 // tile:5, oc:0-3
- MLA_WEIGHTZERO v14, v3, v5, 2 // tile:6, oc:0-3
- MLA_WEIGHTZERO v15, v3, v5, 3 // tile:7, oc:0-3
- MLA_WEIGHTZERO v16, v4, v5, 0 // tile:8, oc:0-3
- MLA_WEIGHTZERO v17, v4, v5, 1 // tile:9, oc:0-3
- MLA_WEIGHTZERO v18, v4, v5, 2 // tile:10, oc:0-3
- MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3
-
- MLA_WEIGHTZERO v20, v2, v6, 0 // tile:0, oc:4-7
- MLA_WEIGHTZERO v21, v2, v6, 1 // tile:1, oc:4-7
- MLA_WEIGHTZERO v22, v2, v6, 2 // tile:2, oc:4-7
- MLA_WEIGHTZERO v23, v2, v6, 3 // tile:3, oc:4-7
- MLA_WEIGHTZERO v24, v3, v6, 0 // tile:4, oc:4-7
- MLA_WEIGHTZERO v25, v3, v6, 1 // tile:5, oc:4-7
- MLA_WEIGHTZERO v26, v3, v6, 2 // tile:6, oc:4-7
- MLA_WEIGHTZERO v27, v3, v6, 3 // tile:7, oc:4-7
- MLA_WEIGHTZERO v28, v4, v6, 0 // tile:8, oc:4-7
- MLA_WEIGHTZERO v29, v4, v6, 1 // tile:9, oc:4-7
- MLA_WEIGHTZERO v30, v4, v6, 2 // tile:10, oc:4-7
- MLA_WEIGHTZERO v31, v4, v6, 3 // tile:11, oc:4-7
+ MLA_WEIGHTZERO v8, v4, v2, 0 // tile:0, oc:0-3
+ MLA_WEIGHTZERO v9, v4, v2, 1 // tile:1, oc:0-3
+ MLA_WEIGHTZERO v10, v4, v2, 2 // tile:2, oc:0-3
+ MLA_WEIGHTZERO v11, v4, v2, 3 // tile:3, oc:0-3
+ MLA_WEIGHTZERO v12, v5, v2, 0 // tile:4, oc:0-3
+ MLA_WEIGHTZERO v13, v5, v2, 1 // tile:5, oc:0-3
+ MLA_WEIGHTZERO v14, v5, v2, 2 // tile:6, oc:0-3
+ MLA_WEIGHTZERO v15, v5, v2, 3 // tile:7, oc:0-3
+ MLA_WEIGHTZERO v16, v6, v2, 0 // tile:8, oc:0-3
+ MLA_WEIGHTZERO v17, v6, v2, 1 // tile:9, oc:0-3
+ MLA_WEIGHTZERO v18, v6, v2, 2 // tile:10, oc:0-3
+ MLA_WEIGHTZERO v19, v6, v2, 3 // tile:11, oc:0-3
+
+ MLA_WEIGHTZERO v20, v4, v3, 0 // tile:0, oc:4-7
+ MLA_WEIGHTZERO v21, v4, v3, 1 // tile:1, oc:4-7
+ MLA_WEIGHTZERO v22, v4, v3, 2 // tile:2, oc:4-7
+ MLA_WEIGHTZERO v23, v4, v3, 3 // tile:3, oc:4-7
+ MLA_WEIGHTZERO v24, v5, v3, 0 // tile:4, oc:4-7
+ MLA_WEIGHTZERO v25, v5, v3, 1 // tile:5, oc:4-7
+ MLA_WEIGHTZERO v26, v5, v3, 2 // tile:6, oc:4-7
+ MLA_WEIGHTZERO v27, v5, v3, 3 // tile:7, oc:4-7
+ MLA_WEIGHTZERO v28, v6, v3, 0 // tile:8, oc:4-7
+ MLA_WEIGHTZERO v29, v6, v3, 1 // tile:9, oc:4-7
+ MLA_WEIGHTZERO v30, v6, v3, 2 // tile:10, oc:4-7
+ MLA_WEIGHTZERO v31, v6, v3, 3 // tile:11, oc:4-7
cbz x27, TILE12_ADD_DSTV
- ld1 {v0.4s, v1.4s, v2.4s}, [x27], #48 // input dequant bias
+ ld1 {v0.4s, v1.4s, v2.4s}, [x27], x22 // input dequant bias
ld1 {v3.4s, v4.4s}, [x28], #32 // weight kernel sum
MLA_WEIGHTZERO v8, v0, v3, 0
MLA_WEIGHTZERO v9, v0, v3, 1
@@ -291,9 +291,10 @@ TILE12_BLOCKNUM:
ADD_FLOAT v16, v17, v18, v19, v0, v1, v2, v3
ADD_FLOAT v20, v21, v22, v23, v4, v5, v6, v7
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
- ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], x12
+ ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10]
ADD_FLOAT v24, v25, v26, v27, v0, v1, v2, v3
ADD_FLOAT v28, v29, v30, v31, v4, v5, v6, v7
+ sub x10, x10, #320
TILE12_L8_ACCUM_BUFFER:
add x19, x19, #1
@@ -304,11 +305,12 @@ TILE12_BLOCKNUM:
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
- st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], x12
+ st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10]
+ sub x10, x10, #320
b TILE12_BLOCKNUM
TILE12_POST:
- sub x5, x5, #1
+ sub x14, x14, #1
cbz x9, TILE12_CVT_FP16
ld1 {v0.4s, v1.4s}, [x20], #32
ADD_BIAS_FLOAT v8, v9, v10, v11, v0
@@ -336,16 +338,32 @@ TILE12_BLOCKNUM:
TILE12_STORE:
- st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64
- st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64
- st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], x4
+ st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x6], #64
+ st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x6], #64
+ st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x6], x4
L8Tile12LoopCheck:
- cbz x5, End
+ cbz x14, Tile12End
mov x8, x15 // revert input kernel sum
mov x24, x21 // revert input dequant scale
- mov x27, x6 // revert input dequant bias
+ cbz x27, L8LoopDz_TILE_12
+ REVERT_INPUT_DEQUANT_BIAS x27, x19, x26, x22
b L8LoopDz_TILE_12
+Tile12End:
+
+ add x0, x0, #192
+ sub x7, x7, #12
+ cbz x7, End
+ add x1, x1, #48
+ add x8, x15, #48
+ add x24, x21, #48
+ add x4, x4, #128 // revert x4
+
+ cbz x27, TILE_8
+ REVERT_INPUT_DEQUANT_BIAS x27, x19, x26, x22
+ REVERT_WEIGHT_KERNEL_SUM x28, x14, x26, x5
+ add x27, x27, #48
+
TILE_8:
mov x25, #0
cbz x27, TILE_Remain
diff --git a/source/backend/arm82/asm/arm64/sme2_asm/MNNPackedMatMulRemainFP16_SME2.S b/source/backend/arm82/asm/arm64/sme2_asm/MNNPackedMatMulRemainFP16_SME2.S
index 25c1e3632c..795b6420df 100644
--- a/source/backend/arm82/asm/arm64/sme2_asm/MNNPackedMatMulRemainFP16_SME2.S
+++ b/source/backend/arm82/asm/arm64/sme2_asm/MNNPackedMatMulRemainFP16_SME2.S
@@ -32,7 +32,7 @@ ldr x10, [x4, #16] // h
ldr x7, [x4, #24] // cStride
ldr x19, [x4, #40] // bExtraStride
-lsr x7, x7, #1 // cStride / sizeof(float16_t)
+
lsl x21, x3, #1 // eSize * lP
mov w12, #0
@@ -65,6 +65,10 @@ cbz x5, ESIZE
.inst 0x052223bf // dup z31.h, z29.h[0]
ESIZE: // x3 <= eP
+cmp x3, #1
+beq E1
+
+lsr x7, x7, #1 // cStride / sizeof(float16_t)
cmp x3, #16
blt LoopOcDiv8
@@ -277,6 +281,185 @@ beq End
b LoopOcDiv8 // continue next ocDiv8
+E1:
+cmp x3, #1
+blt End
+
+E1LoopH:
+mov x8, x1 // A
+mov x21, x9 // LU
+
+.inst 0xc00800ff // zero {za}
+mov w11, #0
+
+cbz x6, E1LoopL
+// bias
+lsl x4, x10, #3
+.inst 0x256447f1 // whilelt pn9.h, xzr, x4, vlx2 // oc to process, maximum is 64
+.inst 0xa04024c8 // ld1h {z8.h-z9.h}, pn9/z, [x6]
+.inst 0x04265046 // addvl x6, x6, #2
+
+.inst 0x6589a50c // fcvt z12.s, p1/m, z8.h
+.inst 0x6489a50d // fcvtlt z13.s, p1/m, z8.h
+.inst 0x6589a52e // fcvt z14.s, p1/m, z9.h
+.inst 0x6489a52f // fcvtlt z15.s, p1/m, z9.h
+
+.inst 0x05ad6194 // zip1 z20.s, z12.s, z13.s
+.inst 0x05ad6595 // zip2 z21.s, z12.s, z13.s
+.inst 0x05af61d6 // zip1 z22.s, z14.s, z15.s
+.inst 0x05af65d7 // zip2 z23.s, z14.s, z15.s
+.inst 0xc1a17e80 // fadd za.s[w11, 0, VGx4], {z20.s-z23.s}
+
+E1LoopL:
+.inst 0x8540c504 // ld1rw {z4.s}, p1/z, [x8] // A
+.inst 0xa040a040 // ld1h {z0.h-z3.h}, pn8/z, [x2] // B
+// [EP,LP] x [HP,LP] -> [EP,HP]
+.inst 0xc1347000 // fdot za.s[w11, 0, VGx4], {z0.h-z3.h}, z4.h
+
+subs x21, x21, #1
+add x8, x8, x22
+.inst 0x04225082 // addvl x2, x2, #4
+bne E1LoopL
+
+add x2, x2, x19 // bExtraStride
+
+E1_To_FP16:
+.inst 0xc0820000 // mova z0.s, p0/m, za0h.s[w12, 0]
+.inst 0xc0822001 // mova z1.s, p0/m, za0h.s[w13, 0]
+.inst 0xc0824002 // mova z2.s, p0/m, za0h.s[w14, 0]
+.inst 0xc0826003 // mova z3.s, p0/m, za0h.s[w15, 0]
+
+.inst 0xc120e010 // fcvt z16.h, {z0.s-z1.s}
+.inst 0xc120e051 // fcvt z17.h, {z2.s-z3.s}
+cbz x5, E1Store
+.inst 0xc17fc3d0 // fclamp {z16.h-z17.h}, z30.h, z31.h
+
+E1Store:
+cmp x10, #8
+bge E1StoreH64
+
+cmp x10, #1
+beq E1StoreH8
+
+cmp x10, #2
+beq E1StoreH16
+
+cmp x10, #3
+beq E1StoreH24
+
+cmp x10, #4
+beq E1StoreH32
+
+cmp x10, #5
+beq E1StoreH40
+
+cmp x10, #6
+beq E1StoreH48
+
+cmp x10, #7
+beq E1StoreH56
+
+E1StoreH64:
+add x21, x0, x7, LSL #1 // x0+2*x7
+add x11, x0, x7, LSL #2 // x0+4*x7
+add x20, x11, x7, LSL #1 // x20+6*x7
+.inst 0x05702212 // dup z18.q, z16.q[1]
+.inst 0x05b02213 // dup z19.q, z16.q[2]
+.inst 0x05f02214 // dup z20.q, z16.q[3]
+.inst 0x05702235 // dup z21.q, z17.q[1]
+.inst 0x05b02236 // dup z22.q, z17.q[2]
+.inst 0x05f02237 // dup z23.q, z17.q[3]
+.inst 0xe400f010 // st1b {z16.b}, p4, [x0]
+.inst 0xe4075012 // st1b {z18.b}, p4, [x0, x7]
+.inst 0xe400f2b3 // st1b {z19.b}, p4, [x21]
+.inst 0xe40752b4 // st1b {z20.b}, p4, [x21, x7]
+.inst 0xe400f171 // st1b {z17.b}, p4, [x11]
+.inst 0xe4075175 // st1b {z21.b}, p4, [x11, x7]
+.inst 0xe400f296 // st1b {z22.b}, p4, [x20]
+.inst 0xe4075297 // st1b {z23.b}, p4, [x20, x7]
+b E1H16_End
+
+E1StoreH56:
+add x21, x0, x7, LSL #1
+add x11, x0, x7, LSL #2
+add x20, x11, x7, LSL #1
+.inst 0x05702212 // dup z18.q, z16.q[1]
+.inst 0x05b02213 // dup z19.q, z16.q[2]
+.inst 0x05f02214 // dup z20.q, z16.q[3]
+.inst 0x05702235 // dup z21.q, z17.q[1]
+.inst 0x05b02236 // dup z22.q, z17.q[2]
+.inst 0xe400f010 // st1b {z16.b}, p4, [x0]
+.inst 0xe4075012 // st1b {z18.b}, p4, [x0, x7]
+.inst 0xe400f2b3 // st1b {z19.b}, p4, [x21]
+.inst 0xe40752b4 // st1b {z20.b}, p4, [x21, x7]
+.inst 0xe400f171 // st1b {z17.b}, p4, [x11]
+.inst 0xe4075175 // st1b {z21.b}, p4, [x11, x7]
+.inst 0xe400f296 // st1b {z22.b}, p4, [x20]
+b End
+
+E1StoreH48:
+add x21, x0, x7, LSL #1
+add x11, x0, x7, LSL #2
+.inst 0x05702212 // dup z18.q, z16.q[1]
+.inst 0x05b02213 // dup z19.q, z16.q[2]
+.inst 0x05f02214 // dup z20.q, z16.q[3]
+.inst 0x05702235 // dup z21.q, z17.q[1]
+.inst 0xe400f010 // st1b {z16.b}, p4, [x0]
+.inst 0xe4075012 // st1b {z18.b}, p4, [x0, x7]
+.inst 0xe400f2b3 // st1b {z19.b}, p4, [x21]
+.inst 0xe40752b4 // st1b {z20.b}, p4, [x21, x7]
+.inst 0xe400f171 // st1b {z17.b}, p4, [x11]
+.inst 0xe4075175 // st1b {z21.b}, p4, [x11, x7]
+b End
+
+E1StoreH40:
+add x21, x0, x7, LSL #1
+add x11, x0, x7, LSL #2
+.inst 0x05702212 // dup z18.q, z16.q[1]
+.inst 0x05b02213 // dup z19.q, z16.q[2]
+.inst 0x05f02214 // dup z20.q, z16.q[3]
+.inst 0xe400f010 // st1b {z16.b}, p4, [x0]
+.inst 0xe4075012 // st1b {z18.b}, p4, [x0, x7]
+.inst 0xe400f2b3 // st1b {z19.b}, p4, [x21]
+.inst 0xe40752b4 // st1b {z20.b}, p4, [x21, x7]
+.inst 0xe400f171 // st1b {z17.b}, p4, [x11]
+b End
+
+E1StoreH32:
+add x21, x0, x7, LSL #1
+.inst 0x05702212 // dup z18.q, z16.q[1]
+.inst 0x05b02213 // dup z19.q, z16.q[2]
+.inst 0x05f02214 // dup z20.q, z16.q[3]
+.inst 0xe400f010 // st1b {z16.b}, p4, [x0]
+.inst 0xe4075012 // st1b {z18.b}, p4, [x0, x7]
+.inst 0xe400f2b3 // st1b {z19.b}, p4, [x21]
+.inst 0xe40752b4 // st1b {z20.b}, p4, [x21, x7]
+b End
+
+E1StoreH24:
+add x21, x0, x7, LSL #1
+.inst 0x05702212 // dup z18.q, z16.q[1]
+.inst 0x05b02213 // dup z19.q, z16.q[2]
+.inst 0xe400f010 // st1b {z16.b}, p4, [x0]
+.inst 0xe4075012 // st1b {z18.b}, p4, [x0, x7]
+.inst 0xe400f2b3 // st1b {z19.b}, p4, [x21]
+b End
+
+E1StoreH16:
+.inst 0x05702212 // dup z18.q, z16.q[1]
+.inst 0xe400f010 // st1b {z16.b}, p4, [x0]
+.inst 0xe4075012 // st1b {z18.b}, p4, [x0, x7]
+b End
+
+E1StoreH8:
+.inst 0xe400f010 // st1b {z16.b}, p4, [x0]
+b End
+
+E1H16_End:
+subs x10, x10, #8
+add x0, x0, x7, LSL #3
+bne E1LoopH
+
End:
.inst 0xd503467f // smstop
diff --git a/source/backend/cpu/CPUAttention.cpp b/source/backend/cpu/CPUAttention.cpp
index a12394f9ac..2f1324e287 100644
--- a/source/backend/cpu/CPUAttention.cpp
+++ b/source/backend/cpu/CPUAttention.cpp
@@ -17,6 +17,8 @@
#include "core/BufferAllocator.hpp"
#include "core/TensorUtils.hpp"
#include "core/OpCommonUtils.hpp"
+#include "core/BufferAllocator.hpp"
+#include "compute/ConvolutionTiledExecutor.hpp"
#if defined (__aarch64__)
#define FLOAT16_T __fp16
@@ -24,264 +26,284 @@
#define FLOAT16_T float
#endif
-#define MNN_FLASH_ATTENTION_BLOCK_SIZE 64
namespace MNN {
template
-void CPUAttention::pack_query(Tensor* query, int8_t* pack_q, int8_t* sum_q, int seq_len, int h, float q_scale) {
- if (mUseGemmInt8) { // Shape of Query: numhead, [seqlen/eP8, headdim/lP8, eP8, lP8]
- mMinQ[h] = query->host()[h * mHeadDim];
- mMaxQ[h] = query->host()[h * mHeadDim];
- for (int i = 0; i < seq_len; i++) {
- T * query_src = query->host() + i * mNumHead * mHeadDim + h * mHeadDim;
- for (int j = 0; j < mHeadDim; j++) {
- mMinQ[h] = ALIMIN(mMinQ[h], query_src[j]);
- mMaxQ[h] = ALIMAX(mMaxQ[h], query_src[j]);
- }
- }
- mQueryScale[h] = (mMaxQ[h] - mMinQ[h]) / 255.0f;
- mQueryZeroPoint[h] = -255.0f * mMinQ[h] / (mMaxQ[h] - mMinQ[h]) - 128.0;
- for (int i = 0; i < seq_len; i++) {
- T * query_src = query->host() + i * mNumHead * mHeadDim + h * mHeadDim;
- float sumQ = 0;
- int out_index = i / eP8;
- int in_index = i % eP8;
- for (int j = 0; j < mHeadDim; j++) {
- int a = j / lP8;
- int b = j % lP8;
- int quant_res = (int)roundf(query_src[j] / mQueryScale[h] + mQueryZeroPoint[h]);
- sumQ += quant_res;
- *((int8_t*)pack_q + out_index * UP_DIV(mHeadDim, lP8) * eP8 * lP8 + a * eP8 * lP8 + in_index * lP8 + b) = quant_res;
- }
- *((float*)sum_q + out_index * eP8 + in_index) = sumQ * mQueryScale[h];
- }
- }
- else {
- // target: [seq_len/eP, mHeadDim/lP, eP, lP]
- T * query_src = query->host();
- T * query_dst = reinterpret_cast(pack_q);
- auto stride0 = ROUND_UP(mHeadDim, lP) * eP;
- auto stride1 = eP * lP;
- if (mHeadDim % lP) {
- memset(query_dst, 0, ROUND_UP(mHeadDim, lP) * bytes * ROUND_UP(seq_len, eP));
- }
- for (int i = 0; i < seq_len; i++) {
- int out_index = i / eP;
- int in_index = i % eP;
- for (int j = 0; j < mHeadDim; j++) {
- query_dst[out_index * stride0 + (j / lP) * stride1 + in_index * lP + (j % lP)] = query_src[i * mNumHead * mHeadDim + h * mHeadDim + j] * q_scale;
- }
- }
+static void _maskQK(float * qkPacked, const float* scale, size_t seqLen, size_t processedKvSeq, int pack, int kvSeqLen, int kvoffset, int padKvSeqLen, const float* sinksPtr, const Tensor* mask, bool quantKey, bool isLowerTriangular) {
+ /*
+ * FIGURE 1: mask->elementSize() == seqLen * maskStride
+ * Context: Cross Attention or Prefill stage (Full Context).
+ * Logic: gapLen = 0. The mask tensor dimensions match the logical QK matrix exactly.
+ * Direct access: mask[row * stride + col]
+ * Row\Col 0 1 2 3
+ *
+ * 0 0 X X X (Can only see Col 0)
+ *
+ * 1 0 0 X X (Can see Col 0, 1)
+ *
+ * 2 0 0 0 X (Can see Col 0, 1, 2)
+ *
+ * 3 0 0 0 0 (Fully visible)
+ *
+ * Legend:
+ * '0' : Visible (Value = Scale * QK)
+ * 'X' : Masked (Value = -inf)
+ */
+
+
+ /*
+ * FIGURE 2: mask->elementSize() != seqLen * maskStride
+ * Context: Self-Attention Inference (Decoding stage).
+ * Logic: gapLen = maskStride - seqLen (Right Alignment).
+ * The "Gap" represents History KV Cache, which is implicitly visible.
+ * The Mask Tensor only covers the current sequence window.
+ *
+ * Example: maskStride (Total KV) = 6
+ * seqLen (Current Q) = 4
+ * gapLen = 6 - 4 = 2
+ *
+ * Structure:
+ * - Cols [0, 1]: "Gap" / History region. Code logic: `if (col < gapLen) continue;`.
+ * No mask is added, so they remain Visible ('0').
+ * - Cols [2-5]: "Current" region. Code logic: `mask[col - gapLen]`.
+ *
+ * Row\Col 0 1 | 2 3 4 5
+ * (Gap) | (Mask Tensor Region)
+ *
+ * 0 0 0 | 0 X X X <-- Mask row 0 applies to Col 2~5
+ * |
+ * 1 0 0 | 0 0 X X <-- Mask row 1 applies to Col 2~5
+ * |
+ * 2 0 0 | 0 0 0 X <-- Mask row 2 applies to Col 2~5
+ * |
+ * 3 0 0 | 0 0 0 0 <-- Mask row 3 applies to Col 2~5
+ *
+ * Legend:
+ * '0' (Left) : History KV, implicitly visible (code skips mask addition).
+ * '0' (Right) : Current KV, visible according to Mask Tensor.
+ * 'X' : Masked by Mask Tensor (-inf).
+ */
+
+ if (isLowerTriangular && quantKey) {
+ return;
}
-}
+ constexpr float NEG_INF = -std::numeric_limits::infinity();
+ auto source = (T*)qkPacked;
+ float scaleVal = scale[0];
+ int gapLen = (mask->elementSize() == (seqLen + padKvSeqLen) * (kvSeqLen + padKvSeqLen)) ? 0 : static_cast(kvSeqLen - seqLen);
-template
-void CPUAttention::unpack_QK(float * unpack_qk_dst, int8_t * pack_qk_src, int seq_len, int kv_seq_len) {
- float * dst = unpack_qk_dst;
- T * src = (T *)(pack_qk_src);
- // [kv_seq_len/mPack, seq_len, mPack] -> [seq_len, kv_seq_len]
- for (int i = 0; i < seq_len; i++) {
- for (int j = 0; j < kv_seq_len; j++) {
- int out_index = j / mPack;
- int in_index = j % mPack;
- dst[i * kv_seq_len + j] = src[out_index * seq_len * mPack + i * mPack + in_index];
- }
- }
-}
+ auto kvBlockCount = UP_DIV(processedKvSeq, pack);
+ auto qkSize = ROUND_UP(processedKvSeq, pack) * seqLen;
-template
-static void pack_QK(int8_t * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_len, int eP, int lP, int bytes) {
- T * dst = reinterpret_cast(pack_qk_dst);
- float * src = reinterpret_cast(qk_src);
- // [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len/lP, eP, lP]
- auto stride0 = ROUND_UP(kv_seq_len, lP) * eP;
- auto stride1 = eP * lP;
- if (kv_seq_len % lP) {
- memset(dst, 0, ROUND_UP(kv_seq_len, lP) * ROUND_UP(seq_len, eP) * bytes);
- }
- for (int i = 0; i < seq_len; i++) {
- int out_index = i / eP;
- int in_index = i % eP;
- for (int j = 0; j < kv_seq_len; j++) {
- dst[out_index * stride0 + (j / lP) * stride1 + in_index * lP + (j % lP)] = src[i * kv_seq_len + j];
+ if (isLowerTriangular) {
+ for (int i = 0; i < qkSize; ++i) {
+ source[i] *= scaleVal;
}
+ return;
}
-}
-
-template
-static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale, float min_val, const Tensor* maskTensor, int offset, int startIndx, int processedKvLen, int extraSeq) {
- int endIndx = startIndx + processedKvLen;
- if (maskTensor == nullptr) {
- for (int i = 0; i < processedKvLen; i++) {
- unpack_qk[i] = unpack_qk[i] * mScale;
- }
+ if (mask == nullptr) {
return;
}
- const int8_t* mask = maskTensor->host();
- halide_type_t htype = maskTensor->getType();
- int maskSize = maskTensor->elementSize();
-
- if (htype == halide_type_of()) {
- // float mask
- T* fpmask_ptr = (T*)mask;
- if (maskSize == (seq_len + extraSeq) * (kv_seq_len + extraSeq)) { // sliding attention, mask shape: [seq_len, kv_seq_len]
- for (int i = 0; i < seq_len; ++i) {
- auto unpack_qki = unpack_qk + i * processedKvLen;
- auto fpmask_ptri = fpmask_ptr + i * (kv_seq_len + extraSeq);
- for (int j = startIndx; j < endIndx; ++j) {
- unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale + fpmask_ptri[j];
- }
- }
- } else { // mask shape: [seq_len, seq_len]
- for (int i = 0; i < seq_len; ++i) {
- auto unpack_qki = unpack_qk + i * processedKvLen;
- auto fpmask_ptri = fpmask_ptr + i * (seq_len + extraSeq);
-
- auto notMaskIndx = ALIMIN(endIndx, offset);
- auto stMaskIndx = ALIMAX(startIndx, offset);
- for (int j = startIndx; j < notMaskIndx; ++j) {
- unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale;
- }
- for (int j = stMaskIndx; j < endIndx; ++j) {
- unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale + fpmask_ptri[j - offset];
- }
- }
- }
- } else {
- // int mask
- int* mask_ptr = (int*)mask;
- for (int i = 0; i < seq_len; ++i) {
- for (int j = 0; j < processedKvLen; ++j) {
- int maskIndex = i * kv_seq_len + startIndx +j;
- if (mask_ptr[maskIndex]) {
- unpack_qk[i * processedKvLen + j] = unpack_qk[i * processedKvLen + j] * mScale;
- } else {
- unpack_qk[i * processedKvLen + j] = min_val;
- }
- }
- }
- }
-}
-typedef void(softmaxFunc)(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize);
-template
-static void softmaxQK(float* softmax_qk_addr, float* unpack_qk_addr, float* runningMax, float* runningSum, float* diffScale, const float* sinkPtr, softmaxFunc* sffunc, int seq_len, int kv_seq_len, int headIdx, bool isLastKvBlock) {
+ auto maskPtr = mask->host();
- // not sliding attention
- if (sinkPtr == nullptr) {
- sffunc(softmax_qk_addr, unpack_qk_addr, runningMax, runningSum, diffScale, seq_len, kv_seq_len);
- return;
- }
+ // not lower triangular
+ auto maskCols = (mask->elementSize() == (seqLen + padKvSeqLen) * (kvSeqLen + padKvSeqLen)) ? kvSeqLen + padKvSeqLen : seqLen + padKvSeqLen;
+ for (int i = 0; i < kvBlockCount; ++i) {
+ T* blockDataPtr = source + (i * seqLen * pack);
- float sink = ((T*)sinkPtr)[headIdx];
- if (!runningMax && !runningSum) { // Do not use flash attention
+ for (int j = 0; j < seqLen; ++j) {
+ T* dataPtr = blockDataPtr + (j * pack);
+ const T* currentMaskRow = maskPtr + j * maskCols;
- for (int i = 0; i < seq_len; ++i) {
- float exprOffset[4] = {1, 0, -sink, 1.f};
- MNNExp(softmax_qk_addr + i * kv_seq_len, unpack_qk_addr + i * kv_seq_len, exprOffset, kv_seq_len);
- for (int j = 0; j < kv_seq_len; ++j) {
- softmax_qk_addr[i * kv_seq_len + j] /= exprOffset[3];
- }
- }
- return;
- }
+ for (int k = 0; k < pack; ++k) {
+ float val = (float)dataPtr[k];
+ if (!quantKey) {
+ val *= scaleVal;
+ dataPtr[k] = (T)val;
+ }
+ int currentKvSeqIndx = kvoffset + i * pack + k; // kvoffset=i*mBlockKv
- // Use flash attention
- if (isLastKvBlock) {
- for (int i = 0; i < seq_len; ++i) {
- runningSum[i] += expf(sink - runningMax[i]);
- }
- }
- MNNSoftmax(softmax_qk_addr, unpack_qk_addr, runningMax, runningSum, diffScale, seq_len, kv_seq_len);
-}
+ if (currentKvSeqIndx < gapLen) {
+ continue;
+ }
+ if (currentKvSeqIndx - gapLen >= maskCols) {
+ break;
+ }
-template
-static void unpack_QKV(int8_t* pack_qkv, int8_t* unpack_qkv, int mNumHead, int mHeadDim, int mPack, int seq_len) {
- auto src_ptr = reinterpret_cast(pack_qkv);
- auto dst_ptr = reinterpret_cast(unpack_qkv);
- for (int i = 0; i < seq_len; i++) {
- for (int j = 0; j < mHeadDim; j++) {
- int a = j / mPack;
- int b = j % mPack;
- dst_ptr[i * mNumHead * mHeadDim + j] = src_ptr[a * seq_len * mPack + i * mPack + b];
+ val += (float)currentMaskRow[currentKvSeqIndx - gapLen];
+ dataPtr[k] = (T)val;
+
+ }
}
}
}
ErrorCode CPUAttention::onResize(const std::vector& inputs, const std::vector& outputs) {
- auto core = static_cast(backend())->functions();
- core->MNNGetMatMulPackMode(&eP, &lP, &hP);
+ auto gcore = static_cast(backend())->functions();
+ auto core = static_cast(backend())->int8Functions();
+ gcore->MNNGetMatMulPackMode(&eP, &lP, &hP);
mThreadNum = ((CPUBackend *)backend())->threadNumber();
- mPack = core->pack;
- bytes = core->bytes;
- int qkvQuantOptions = static_cast(backend())->getRuntime()->hint().qkvQuantOption;
- mUseGemmInt8 = (qkvQuantOptions % 8 == 4);
- if (mUseGemmInt8) {
- static_cast(backend())->int8Functions()->MNNGetGemmUnit(&hP8, &lP8, &eP8);
- }
+ mPack = gcore->pack;
+ mBytes = gcore->bytes;
+ int attentionOption = static_cast(backend())->getRuntime()->hint().attentionOption;
+ mUseFlashAttention = (attentionOption / 8 == 1);
+
+ // If slide window attention applied, quant key/value must be diabled
+ mQuantKey = inputs.size() < 5 && (attentionOption % 8 >= 1);
+ mQuantValue = inputs.size() < 5 && (attentionOption % 8 > 1) && mUseFlashAttention;
+ static_cast(backend())->int8Functions()->MNNGetGemmUnit(&hP8, &lP8, &eP8);
+
auto query = inputs[0];
auto key = inputs[1];
- int seq_len = query->length(1);
+ int seqLen = query->length(1);
+ int mBlockNum = 1;
mNumHead = query->length(2);
mHeadDim = query->length(3);
mKvNumHead = key->length(2);
+ mKVCacheManager->setAttenQuantKeyValue(mUseFlashAttention, mQuantKey, mQuantValue);
mKVCacheManager->onResize(mKvNumHead, mHeadDim);
- if (mUseGemmInt8) {
- mPackQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP8), UP_DIV(mHeadDim, lP8), eP8 * lP8}));
- mSumQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP8), eP8}));
- mPackQKV.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack}));
+
+ // Common buffer allocated
+ auto bufferAlloc = static_cast(backend())->getBufferAllocator();
+ mPackQKV.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, mPack), seqLen, mPack * mBytes}));
+ backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC);
+ if (inputs.size() > 4 || mUseFlashAttention) { // needed by flash attention and sliding attention with sink
+ mRunningMax.reset(Tensor::createDevice({mThreadNum, seqLen * 4}));
+ mRunningSum.reset(Tensor::createDevice({mThreadNum, seqLen * 4}));
+ backend()->onAcquireBuffer(mRunningMax.get(), Backend::DYNAMIC);
+ backend()->onAcquireBuffer(mRunningSum.get(), Backend::DYNAMIC);
+ }
+ if (mUseFlashAttention) { // extra buffer need by flash attention
+ mExpfDiffMax.reset(Tensor::createDevice({mThreadNum, seqLen * 4}));
+ mTempOut.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, mPack), seqLen, mPack * mBytes}));
+ backend()->onAcquireBuffer(mExpfDiffMax.get(), Backend::DYNAMIC);
+ backend()->onAcquireBuffer(mTempOut.get(), Backend::DYNAMIC);
+ }
+ if (mQuantKey) {
+ int outterSeqLen = UP_DIV(seqLen, eP8);
+ int outterHeadDim = UP_DIV(mHeadDim, lP8);
+
+ size_t packedQSize = 0;
+ if (outterSeqLen > 0) {
+ int fullSeqBlocks = (seqLen / eP8);
+ packedQSize += (size_t)fullSeqBlocks * outterHeadDim * eP8 * lP8;
+
+ int lastEUnit = seqLen % eP8;
+ if (lastEUnit != 0) {
+ packedQSize += (size_t)outterHeadDim * lastEUnit * lP8;
+ }
+ }
+ mPackQ.reset(Tensor::createDevice({mNumHead, (int32_t)packedQSize}));
backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC);
- backend()->onAcquireBuffer(mSumQ.get(), Backend::DYNAMIC);
- backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC);
- backend()->onReleaseBuffer(mPackQ.get(), Backend::DYNAMIC);
- backend()->onReleaseBuffer(mSumQ.get(), Backend::DYNAMIC);
- backend()->onReleaseBuffer(mPackQKV.get(), Backend::DYNAMIC);
- mMinQ.resize(mNumHead);
- mMaxQ.resize(mNumHead);
- mQueryScale.resize(mNumHead);
- mQueryZeroPoint.resize(mNumHead);
+
+ mSumQ = bufferAlloc->alloc(mThreadNum * ROUND_UP(seqLen, eP8) * mBlockNum * sizeof(int32_t));
+ mQueryScale = bufferAlloc->alloc(mNumHead * seqLen * mBlockNum * QUANT_INFO_BYTES);
+ mQueryZeroPoint = bufferAlloc->alloc(mNumHead * seqLen * mBlockNum * QUANT_INFO_BYTES);
+ mQueryQuantZero = bufferAlloc->alloc(mNumHead * seqLen * mBlockNum * QUANT_INFO_BYTES);
+ mQueryQuantScale = bufferAlloc->alloc(mNumHead * seqLen * mBlockNum * QUANT_INFO_BYTES);
+ mQuantQuery = bufferAlloc->alloc(seqLen * mNumHead * UP_DIV(mHeadDim, gcore->pack) * gcore->pack);
+
+ if (mBlockNum > 1) {
+ mAccumBuffer = bufferAlloc->alloc(eP8 * hP8 * mThreadNum * QUANT_INFO_BYTES);
+ if (mAccumBuffer.invalid()) {
+ return OUT_OF_MEMORY;
+ }
+ }
+
+ if (mSumQ.invalid() || mQueryScale.invalid() || mQueryQuantZero.invalid() || mQueryZeroPoint.invalid() || mQueryQuantScale.invalid() || mQuantQuery.invalid()) {
+ return OUT_OF_MEMORY;
+ }
+
+ // post parameters for int8 gemm
+ mGemmRelu.reset(2 * sizeof(int32_t));
+ if (!mGemmRelu.get()) {
+ MNN_ERROR("Allocate mGemmRelu buffer failed in CPU Attention");
+ return OUT_OF_MEMORY;
+ }
+ ((float*)mGemmRelu.get())[0] = -std::numeric_limits().max();
+ ((float*)mGemmRelu.get())[1] = std::numeric_limits().max();
+ if (mBytes == 2) {
+ gcore->MNNFp32ToLowp((float*)mGemmRelu.get(), reinterpret_cast(mGemmRelu.get()), 2);
+ }
+
+ // GemmInt8 kernels
+ if (mBytes == 4) {
+ mInt8GemmKernel = core->Int8GemmKernel;
+ } else {
+ mInt8GemmKernel = core->MNNGemmInt8AddBiasScale_Unit_FP16;
+ }
+
+ if (mQuantValue) {
+ mQuantQK = bufferAlloc->alloc(mThreadNum * eP8 * ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, mPack));
+ mQKScale = bufferAlloc->alloc(eP8 * QUANT_INFO_BYTES);
+ mQKBias = bufferAlloc->alloc(eP8 * QUANT_INFO_BYTES);
+ mSumQK = bufferAlloc->alloc(mThreadNum * eP8 * QUANT_INFO_BYTES);
+
+ if (mQuantQK.invalid() || mQKScale.invalid() || mQKBias.invalid() || mSumQK.invalid()) {
+ return OUT_OF_MEMORY;
+ }
+ }
} else {
- mPackQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mHeadDim, lP), eP * bytes}));
- mPackQKV.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack * bytes}));
+ mPackQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seqLen, eP), ROUND_UP(mHeadDim, lP), eP * mBytes}));
backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC);
backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC);
+ }
- // flash attention
- if (qkvQuantOptions / 8 == 1) {
- mRunningMax.reset(Tensor::createDevice({mThreadNum, seq_len * 4}));
- mRunningSum.reset(Tensor::createDevice({mThreadNum, seq_len * 4}));
- mExpfDiffMax.reset(Tensor::createDevice({mThreadNum, seq_len * 4}));
- mTempOut.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack * bytes}));
-
- backend()->onAcquireBuffer(mRunningMax.get(), Backend::DYNAMIC);
- backend()->onAcquireBuffer(mRunningSum.get(), Backend::DYNAMIC);
- backend()->onAcquireBuffer(mExpfDiffMax.get(), Backend::DYNAMIC);
- backend()->onAcquireBuffer(mTempOut.get(), Backend::DYNAMIC);
- }
+ // release tensor
+ backend()->onReleaseBuffer(mPackQ.get(), Backend::DYNAMIC);
+ backend()->onReleaseBuffer(mPackQKV.get(), Backend::DYNAMIC);
- backend()->onReleaseBuffer(mPackQ.get(), Backend::DYNAMIC);
- backend()->onReleaseBuffer(mPackQKV.get(), Backend::DYNAMIC);
+ if (inputs.size() > 4 || mUseFlashAttention) {
+ backend()->onReleaseBuffer(mRunningMax.get(), Backend::DYNAMIC);
+ backend()->onReleaseBuffer(mRunningSum.get(), Backend::DYNAMIC);
+ }
+ if (mUseFlashAttention) {
+ backend()->onReleaseBuffer(mExpfDiffMax.get(), Backend::DYNAMIC);
+ backend()->onReleaseBuffer(mTempOut.get(), Backend::DYNAMIC);
+ }
- if (qkvQuantOptions / 8 == 1) {
- backend()->onReleaseBuffer(mRunningMax.get(), Backend::DYNAMIC);
- backend()->onReleaseBuffer(mRunningSum.get(), Backend::DYNAMIC);
- backend()->onReleaseBuffer(mExpfDiffMax.get(), Backend::DYNAMIC);
- backend()->onReleaseBuffer(mTempOut.get(), Backend::DYNAMIC);
+ // release memchunk
+ if (mQuantKey) {
+ bufferAlloc->free(mSumQ);
+ bufferAlloc->free(mQueryScale);
+ bufferAlloc->free(mQueryZeroPoint);
+ bufferAlloc->free(mQueryQuantScale);
+ bufferAlloc->free(mQueryQuantZero);
+ bufferAlloc->free(mQuantQuery);
+ if (mBlockNum > 1) {
+ bufferAlloc->free(mAccumBuffer);
}
+ if (mQuantValue) {
+ bufferAlloc->free(mQuantQK);
+ bufferAlloc->free(mQKScale);
+ bufferAlloc->free(mQKBias);
+ bufferAlloc->free(mSumQK);
+ }
+ }
+
+ // Only allocated for quantized Q&K
+ if (mQuantKey) {
+ if (mBytes == 4) {
+ mQuantFunc = core->MNNFloat2Int8;
+ } else {
+ mQuantFunc = core->DynamicQuanInput_ARM82;
+ }
+
}
return NO_ERROR;
}
ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std::vector& outputs) {
- auto core = static_cast(backend())->functions();
- auto qkvQuantOptions = static_cast(backend())->getRuntime()->hint().qkvQuantOption;
+ auto gcore = static_cast(backend())->functions();
+ auto core = static_cast(backend())->int8Functions();
auto query = inputs[0];
auto key = inputs[1];
auto value = inputs[2];
+ int seqLen = query->length(1);
const Tensor* mask = nullptr;
- int seq_len = query->length(1);
if (inputs.size() > 3) {
mask = inputs[3];
}
@@ -291,16 +313,16 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std:
MNN_ASSERT(sinks != nullptr);
MNN_ASSERT(sinks->elementSize() == mNumHead)
}
- int tileCount = UP_DIV(mNumHead, mThreadNum);
+ int numHeadDiv = UP_DIV(mNumHead, mThreadNum);
int group_size = mNumHead / mKvNumHead;
// reduce the value of 'query' to avoid fp16 overflow
float mScale = 1.0 / sqrt(mHeadDim);
float q_scale = 1.0;
- if (bytes == 2) {
+ if (mBytes == 2 && !mQuantKey) {
// reduce the value of 'query' to 'query * FP16_QSCALE', avoid fp16 overflow
FLOAT16_T minValue;
FLOAT16_T maxValue;
- core->MNNCountMaxMinValue(query->host(), (float*)(&minValue), (float*)(&maxValue), query->elementSize());
+ gcore->MNNCountMaxMinValue(query->host(), (float*)(&minValue), (float*)(&maxValue), query->elementSize());
float maxV = maxValue;
float minV = minValue;
float absMax = ALIMAX(fabsf(maxV), fabsf(minV));
@@ -309,169 +331,450 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std:
}
mScale /= q_scale;
}
- int insertLen = seq_len;
+ int insertLen = seqLen;
if (mKVCache && mMeta != nullptr) {
if (mMeta->previous == mMeta->remove) {
mKVCacheManager->onClear();
- mKVCacheManager->onAlloc(mMeta->add);
+ mKVCacheManager->onAlloc(mMeta, seqLen);
} else {
MNN_ASSERT(mMeta->previous == mKVCacheManager->kvLength());
mKVCacheManager->onRealloc(mMeta);
}
- insertLen = mMeta->add;
+ insertLen = (int)mMeta->add;
} else {
mKVCacheManager->onClear();
- mKVCacheManager->onAlloc(seq_len);
+ mKVCacheManager->onAlloc(mMeta, seqLen);
}
+
// Add the new kv to the kvcache
- mKVCacheManager->onPushBack(key, value, insertLen);
- int padSeqLength = seq_len - insertLen;
- seq_len = insertLen;
- int kv_seq_len = mKVCacheManager->kvLength();
- int max_len = mKVCacheManager->maxLength();
- bool quant_key = mKVCacheManager->config()->mQuantKey;
- bool quant_value = mKVCacheManager->config()->mQuantValue;
-
- mBlockKV = (qkvQuantOptions / 8 == 1) ? ALIMIN(MNN_FLASH_ATTENTION_BLOCK_SIZE, kv_seq_len) : kv_seq_len;
+ mKVCacheManager->onUpdateKV(key, value, (int)insertLen);
+
+ if (mUseFlashAttention) {
+ mBlockKV = ALIMIN(MNN_FLASH_ATTENTION_BLOCK_SIZE, mKVCacheManager->kvLength());
+ } else {
+ mBlockKV = mKVCacheManager->kvLength();
+ }
+
+ // Constant Initialization
+ auto padSeqLength = seqLen - insertLen;
+ seqLen = insertLen;
+ int kvSeqLen = mKVCacheManager->kvLength();
+ int maxLen = mKVCacheManager->maxLength();
int32_t units[2] = {eP, lP};
+ const float* sinksPtr = sinks ? sinks->host() : nullptr;
+ int kvValidOffset = kvSeqLen - seqLen; // reuse_kv=true or decode, kvValidOffset>0
// Temporary tensors for intermediate results
- std::shared_ptr unpackQK(Tensor::createDevice({mThreadNum, seq_len, mBlockKV}));
- std::shared_ptr softmMaxQ(Tensor::createDevice({mThreadNum, seq_len, mBlockKV}));
- std::shared_ptr newPackQK(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mBlockKV, lP), eP * bytes}));
- std::shared_ptr dequantV(Tensor::createDevice({mKvNumHead, UP_DIV(mHeadDim, hP), kv_seq_len, hP * bytes}));
- // mTempQKBlock.reset(Tensor::createDevice({mThreadNum, UP_DIV(mBlockKV, mPack), seq_len, mPack * bytes}));
- std::shared_ptr tempQKBlock(Tensor::createDevice({mThreadNum, UP_DIV(mBlockKV, mPack), seq_len, mPack * bytes}));
+ std::shared_ptr unpackQK(Tensor::createDevice({mThreadNum, seqLen, mBlockKV}));
+ std::shared_ptr softmMaxQ(Tensor::createDevice({mThreadNum, seqLen, ROUND_UP(mBlockKV, mPack)})); // [mBlockKV/mPack, seqLen, mPack ]
+ std::shared_ptr newPackQK;
+ if (mQuantValue == false) {
+ newPackQK.reset(Tensor::createDevice({mThreadNum, eP * ROUND_UP(mBlockKV, lP) * mBytes}));
+ } else {
+ newPackQK.reset(Tensor::createDevice({mThreadNum, eP8 * ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP8)}));
+ }
+ std::shared_ptr mTempQKBlock(Tensor::createDevice({mThreadNum, UP_DIV(mBlockKV, mPack), seqLen, mPack * mBytes}));
backend()->onAcquireBuffer(unpackQK.get(), Backend::STATIC);
backend()->onAcquireBuffer(softmMaxQ.get(), Backend::STATIC);
backend()->onAcquireBuffer(newPackQK.get(), Backend::STATIC);
- backend()->onAcquireBuffer(tempQKBlock.get(), Backend::STATIC);
- if (quant_value) {
- backend()->onAcquireBuffer(dequantV.get(), Backend::STATIC);
- mKVCacheManager->onDequantValue(dequantV.get());
+ backend()->onAcquireBuffer(mTempQKBlock.get(), Backend::STATIC);
+
+ // Quantize Q and initialize bias 0
+ if (mQuantKey) {
+ mGemmBias.reset(ROUND_UP(ALIMAX(mBlockKV, mHeadDim), hP8) * QUANT_INFO_BYTES);
+ if (!mGemmBias.get()) {
+ MNN_ERROR("Allocate bias buffer failed in CPU Attention\n");
+ return OUT_OF_MEMORY;
+ }
+ memset(mGemmBias.get(), 0, ROUND_UP(ALIMAX(mBlockKV, mHeadDim), hP8) * QUANT_INFO_BYTES);
+
+ // Q: [seqLen,numHead,headDim]
+ // maxQ, minQ: [seqLen,numHead]
+ // scaleQ, zeroQ: [numHead, seqLen]
+ // quantQ: [seqLen,numHead,headDim]
+ auto queryPtr = query->host();
+ int divPart = UP_DIV(seqLen * mNumHead, mThreadNum);
+ MNN_CONCURRENCY_BEGIN (tId, mThreadNum) {
+ size_t info[9] = {1, (size_t)mHeadDim, 1, 1, 1, 1, 1, 1, 0};
+ auto remainLu = seqLen * mNumHead - tId * divPart;
+ if (remainLu > 0) {
+ remainLu = ALIMIN(divPart, remainLu);
+ for (int i = tId * divPart; i < tId * divPart + remainLu; ++i) {
+
+ // address
+ auto srcFloatPtr = (float*)(queryPtr + i * mHeadDim * mBytes);
+ auto dstInt8Ptr = (int8_t*)(mQuantQuery.ptr() + i * mHeadDim);
+ auto quantScalePtr = (float*)(mQueryQuantScale.ptr() + i * QUANT_INFO_BYTES);
+ auto quantZeroPtr = (float*)(mQueryQuantZero.ptr() + i * QUANT_INFO_BYTES);
+
+ // scaleQ, zeroQ, [seqLen,numHead]->[numHead,seqLen]
+ int indexQ = (i / mNumHead) + (i % mNumHead) * seqLen;
+ auto scalePtr = (float*)(mQueryScale.ptr() + indexQ * QUANT_INFO_BYTES);
+ auto zeroPtr = (float*)(mQueryZeroPoint.ptr() + indexQ * QUANT_INFO_BYTES);
+
+
+ // compute the quant/dequant scale/bias
+ gcore->MNNAsyQuantInfo(scalePtr, zeroPtr, quantScalePtr, quantZeroPtr, nullptr, nullptr, srcFloatPtr, info);
+ scalePtr[0] *= mScale;
+ zeroPtr[0] *= mScale;
+
+ // quantize the float query to int8_t query
+ mQuantFunc(srcFloatPtr, dstInt8Ptr, UP_DIV(mHeadDim, gcore->pack), quantScalePtr, -128, 127, quantZeroPtr, 0);
+ }
+ }
+ } MNN_CONCURRENCY_END();
+
+ // source int8_t query: [seqLen,numHead,headDim]
+ // dest int8_t query: [numHead,seqLen/eP,headDim/lP,eP,lP]
+
+ int outterSeqLen = UP_DIV(seqLen, eP8);
+ int outterHeadDim = UP_DIV(mHeadDim, lP8);
+ size_t outputOffset = 0;
+
+ const int8_t* src_base_ptr = (const int8_t*)mQuantQuery.ptr();
+ int8_t* dst_base_ptr = mPackQ->host();
+
+ for (int h = 0; h < mNumHead; ++h) {
+ for (int seqBlock = 0; seqBlock < outterSeqLen; ++seqBlock) {
+ int seqBase = seqBlock * eP8;
+ int eunit = std::min(eP8, seqLen - seqBase);
+ size_t currentSeqBlockSize = (size_t)outterHeadDim * eunit * lP8;
+
+ for (int dimBlock = 0; dimBlock < outterHeadDim; ++dimBlock) {
+ int dimBase = dimBlock * lP8;
+ int headDimRemain = mHeadDim - dimBase;
+ int copyLen = std::min(lP8, headDimRemain);
+
+ if (copyLen <= 0) {
+ continue;
+ }
+
+ int8_t* dst_block_ptr = dst_base_ptr +
+ outputOffset +
+ (size_t)dimBlock * (eunit * lP8);
+
+ const size_t src_row_stride = (size_t)mNumHead * mHeadDim;
+
+ for (int seqLocal = 0; seqLocal < eunit; ++seqLocal) {
+ int innerSeq = seqBase + seqLocal;
+
+ const int8_t* src_row_ptr = src_base_ptr +
+ (size_t)innerSeq * src_row_stride +
+ (size_t)h * mHeadDim +
+ dimBase;
+
+ int8_t* dst_row_ptr = dst_block_ptr + seqLocal * lP8;
+
+ std::memcpy(dst_row_ptr, src_row_ptr, copyLen);
+ }
+ if (copyLen < lP8) {
+ for (int seqLocal = 0; seqLocal < eunit; ++seqLocal) {
+ int8_t* dst_pad_ptr = dst_block_ptr + seqLocal * lP8 + copyLen;
+ std::memset(dst_pad_ptr, 0, lP8 - copyLen);
+ }
+ }
+ }
+ outputOffset += currentSeqBlockSize;
+ }
+ } // Finish quantize Q
+
+ if (mQuantValue) {
+ auto scalePtr = (float*)(mQKScale.ptr());
+ auto zeroPtr = (float*)(mQKBias.ptr());
+ for (int k = 0; k < eP8; ++k) {
+ scalePtr[k] = 1.f / 255.f;
+#ifdef MNN_USE_SSE
+ zeroPtr[k] =0;
+#else
+ zeroPtr[k] = 128.f / 255.f;
+#endif
+ }
+ }
+
}
- const float* sinksPtr = sinks ? sinks->host() : nullptr;
+
std::function mCompute = [=](int tId) {
- auto qReordered = mPackQ->host() + tId * mPackQ->stride(0);
- auto qkPacked = tempQKBlock->host() + tId * tempQKBlock->stride(0);
- int8_t * sum_q = nullptr;
+ int8_t* qReordered = nullptr;
+ auto qkPacked = mTempQKBlock->host() + tId * mTempQKBlock->stride(0);
auto qkFlatten = unpackQK->host() + tId * unpackQK->stride(0);
auto qkSoftmax = softmMaxQ->host() + tId * softmMaxQ->stride(0);
auto qkReordered = newPackQK->host() + tId * newPackQK->stride(0);
auto qkvPacked = mPackQKV->host() + tId * mPackQKV->stride(0);
- auto QxK = quant_key ? core->MNNPackedMatMul_int8 : core->MNNPackedMatMul;
- auto QxK_remain = quant_key ? core->MNNPackedMatMulRemain_int8 : core->MNNPackedMatMulRemain;
+ int headIndex = tId * numHeadDiv;
+ int headsToCompute = ALIMIN(numHeadDiv, mNumHead - headIndex);
// Flash Attention
auto runningMax = mRunningMax ? (float*)(mRunningMax->host() + tId * mRunningMax->stride(0)) : nullptr;
auto runningSum = mRunningSum ? (float*)(mRunningSum->host() + tId * mRunningSum->stride(0)) : nullptr;
auto diffScale = mExpfDiffMax ? (float*)(mExpfDiffMax->host() + tId * mExpfDiffMax->stride(0)) : nullptr;
auto outputPacked = mTempOut ? mTempOut->host() + tId * mTempOut->stride(0) : qkvPacked;
- int head_index = tId * tileCount;
- int kvBlocks = UP_DIV(kv_seq_len, mBlockKV);
+
+ int kvBlocks = UP_DIV(kvSeqLen, mBlockKV);
+
+ bool isLowerTriangular = (mask == nullptr);
+ if (mask != nullptr && mask->shape().empty()) {
+ if (mBytes == 2) {
+ auto maskPtr = mask->host();
+ if (maskPtr[0] < 1e-6) {
+ isLowerTriangular = true;
+ }
+ } else {
+ auto maskPtr = mask->host();
+ if (maskPtr[0] < 1e-6f) {
+ isLowerTriangular = true;
+ }
+ }
+ }
+ bool useMaskInSoftmax = (isLowerTriangular && sinksPtr == nullptr);
+
+ QuanPostTreatParameters gemmParam4QxK, gemmParam4QKxV; // used by int8 gemm, allocated per thread.
+ SumByAxisParams sumParams4QxK, sumParams4QKxV = {};
+ float* qSumAddr = nullptr;
+ float* qScale = nullptr;
+ float* qBias = nullptr;
+ float* accumbuff = nullptr;
+ int32_t unitColBufferSize = 0;
+ if (mQuantKey) {
+ // parameters shared across all mBlockKV-sized KV blocks
+ gemmParam4QxK.blockNum = mBlockNum;
+ gemmParam4QxK.biasFloat = reinterpret_cast(mGemmBias.get());
+ gemmParam4QxK.useInt8 = 0;
+ gemmParam4QxK.fp32minmax = reinterpret_cast(mGemmRelu.get());
+
+ sumParams4QxK.oneScale = 0;
+ sumParams4QxK.SRC_UNIT = lP8;
+ sumParams4QxK.blockNum = mBlockNum;
+ sumParams4QxK.DST_XUNIT = eP8;
+ sumParams4QxK.inputBlock = 0;
+ sumParams4QxK.kernelxy = 1;
+ // fixed across kv blocks: these depend only on mHeadDim / lP8
+ sumParams4QxK.LU = UP_DIV(mHeadDim, lP8);
+ sumParams4QxK.unitColBufferSize = ROUND_UP(mHeadDim, lP8) * eP8;
+ sumParams4QxK.kernelCountUnitDouble = UP_DIV(mHeadDim, lP8);
+ sumParams4QxK.valid = mHeadDim % lP8;
+
+
+ if (mBlockNum > 1) {
+ accumbuff = (float*)(mAccumBuffer.ptr() + tId * eP8 * hP8 * QUANT_INFO_BYTES);
+ }
+ unitColBufferSize = eP8 * ROUND_UP(mHeadDim, lP8);
+
+ if (mQuantValue) {
+ gemmParam4QKxV.blockNum = mBlockNum;
+ gemmParam4QKxV.biasFloat = reinterpret_cast(mGemmBias.get());
+ gemmParam4QKxV.useInt8 = 0;
+ gemmParam4QKxV.fp32minmax = reinterpret_cast(mGemmRelu.get());
+ gemmParam4QKxV.inputScale = (float*)mQKScale.ptr();
+ gemmParam4QKxV.inputBias = (float*)mQKBias.ptr();
+ gemmParam4QKxV.srcKernelSum = (float*)(mSumQK.ptr() + tId * eP8 * QUANT_INFO_BYTES);
+
+ sumParams4QKxV.oneScale = 0;
+ sumParams4QKxV.SRC_UNIT = lP8;
+ sumParams4QKxV.blockNum = mBlockNum;
+ sumParams4QKxV.DST_XUNIT = eP8;
+ sumParams4QKxV.inputBlock = 0;
+ sumParams4QKxV.kernelxy = 1;
+ sumParams4QKxV.unitColBufferSize = ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP8) * eP8;
+ sumParams4QKxV.kernelCountUnitDouble = UP_DIV(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP8);
+ }
+ }
- if (mUseGemmInt8) {
- qReordered = mPackQ->host() + tId * UP_DIV(seq_len, eP8) * UP_DIV(mHeadDim, lP8) * eP8 * lP8;
- sum_q = mSumQ->host() + tId * UP_DIV(seq_len, eP8) * eP8 * 4;
+ size_t vstride0 = ROUND_UP(mHeadDim, hP) * ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP);
+ if (mQuantValue) {
+ vstride0 = (ROUND_UP(mHeadDim, hP8) * ROUND_UP(mKVCacheManager->getFlashAttentionBlockKv(), lP8) + 2 * QUANT_INFO_BYTES * mBlockNum * ROUND_UP(mHeadDim, hP8));
}
- for (int h = head_index; h < head_index + tileCount && h < mNumHead; h++) {
+
+ // Source-pointer scratch used when packing V's input
+ float const* srcPtr[1];
+ // only used for quantized V
+ float vQuantScale[1] = {255.f};
+ float vQuantBias[1] = {-128.f};
+ int32_t infoInt8V[5];
+ infoInt8V[0] = 1; // number
+ infoInt8V[2] = static_cast(sumParams4QKxV.unitColBufferSize);
+ infoInt8V[3] = 1; // stride
+ int32_t elInt8V[4] = {eP8, ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP8), 0, 0};
+
+ // only used for float V
+ int32_t infoFloatV[4];
+ infoFloatV[0] = 1; // number
+ infoFloatV[1] = seqLen; // eReal
+ infoFloatV[3] = 1; // stride
+ int32_t elFloatV[4] = {seqLen, ROUND_UP(kvSeqLen, lP), 0, 0};
+
+ int offset[2] = {seqLen, mNumHead * mHeadDim};
+
+ for (int h = headIndex; h < headIndex + headsToCompute; h++) {
+ // Prepare for flash attention
if (runningSum && runningMax) {
- memset(runningSum, 0, mRunningSum->stride(0));
if (sinksPtr == nullptr) {
- for (int k = 0; k < seq_len; ++k) {
+ memset(runningSum, 0, mRunningSum->stride(0));
+ for (int k = 0; k < seqLen; ++k) {
runningMax[k] = std::numeric_limits::lowest();
}
} else {
+ for (int k = 0; k < seqLen; ++k) {
+ runningSum[k] = 1.f; // exp(sink-sink)
+ }
float sinkVal;
- if (bytes == 2) {
+ if (mBytes == 2) {
sinkVal = ((FLOAT16_T*)sinksPtr)[h];
} else {
- sinkVal =sinksPtr[h];
+ sinkVal = sinksPtr[h];
}
- for (int k = 0; k < seq_len; ++k) {
+ for (int k = 0; k < seqLen; ++k) {
runningMax[k] = sinkVal;
}
}
}
- int kv_h = h / group_size;
- int8_t * key_addr = mKVCacheManager->addrOfKey(kv_h);
- int8_t * scale_addr = mKVCacheManager->addrOfScale(kv_h);
- int8_t * zero_point_addr = mKVCacheManager->addrOfZeroPoint(kv_h);
- int8_t * key_sum_addr = mKVCacheManager->addrOfKeySum(kv_h);
- int8_t * value_addr = quant_value ? (dequantV->host() + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(kv_seq_len, lP) * hP * bytes) : mKVCacheManager->addrOfValue(kv_h);
- if (mUseGemmInt8) {
- if (bytes == 2) {
- pack_query(query, qReordered, sum_q, seq_len, h, q_scale);
- } else {
- pack_query(query, qReordered, sum_q, seq_len, h, q_scale);
- }
+
+ // Compute the current addresses
+ int kvHeadIndex = h / group_size;
+ int8_t * keyAddr = mKVCacheManager->addrOfKey(kvHeadIndex);
+ int8_t * keySum = mKVCacheManager->addrOfKeySum(kvHeadIndex);
+ int8_t * valueAddr = mKVCacheManager->addrOfValue(kvHeadIndex);
+ float* valueSum = (float*)mKVCacheManager->addrOfValueSum(kvHeadIndex);
+
+ // Get packed Q
+ if (mQuantKey == false) {
+ qReordered = mPackQ->host() + tId * mPackQ->stride(0);
+ gcore->MNNAttenPackAndScaleSingleHead((float*)qReordered, (float*)(query->host() + h * mHeadDim * mBytes), mHeadDim * mNumHead, &q_scale, units, seqLen, mHeadDim);
} else {
- core->MNNAttenPackAndScaleSingleHead((float*)qReordered, (float*)(query->host() + h * mHeadDim * bytes), mHeadDim * mNumHead, &q_scale, units, seq_len, mHeadDim);
+ qReordered = mPackQ->host() + h * mPackQ->stride(0);
+ qSumAddr = (float*)(mSumQ.ptr() + tId * ROUND_UP(seqLen, eP8) * mBlockNum * QUANT_INFO_BYTES);
+ qScale = (float*)(mQueryScale.ptr() + h * seqLen * mBlockNum * QUANT_INFO_BYTES);
+ qBias = (float*)(mQueryZeroPoint.ptr() + h * seqLen * mBlockNum * QUANT_INFO_BYTES);
+ gcore->MNNSumByAxisLForMatmul_A(qSumAddr, qReordered, qScale, seqLen, sumParams4QxK);
}
+
+ // Start computing
for (int i = 0; i < kvBlocks; ++i) {
- int subKvSeqLen = ALIMIN(mBlockKV, kv_seq_len - i * mBlockKV);
- auto keyPtr = key_addr + i * UP_DIV(mBlockKV, hP) * ROUND_UP(mHeadDim, lP) * hP * bytes;
- auto valuePtr = value_addr + i * UP_DIV(mBlockKV, lP) * hP * lP * bytes;
- // query @ key
- {
- int loop_e = seq_len / eP;
- int remain = seq_len % eP;
- auto qStride0 = ROUND_UP(mHeadDim, lP) * eP * bytes;
- size_t shapeParameters[7] = {(size_t)eP * lP * bytes, ROUND_UP((size_t)mHeadDim, lP), (size_t)subKvSeqLen, (size_t)seq_len * mPack * bytes, 0, 0, 0};
+ int subKvSeqLen = ALIMIN(mBlockKV, kvSeqLen - i * mBlockKV);
+ // 1. query @ key
+ if (mQuantKey == false) {
+ auto keyPtr = keyAddr + i * UP_DIV(mBlockKV, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes;
+ int loop_e = seqLen / eP;
+ int remain = seqLen % eP;
+ auto qStride0 = ROUND_UP(mHeadDim, lP) * eP * mBytes;
+ size_t shapeParameters[7] = {(size_t)eP * lP * mBytes, ROUND_UP((size_t)mHeadDim, lP), (size_t)subKvSeqLen, (size_t)seqLen * mPack * mBytes, 0, 0, 0};
for (int ei = 0 ; ei < loop_e; ei++) {
- QxK((float*)(qkPacked + (ei * eP * mPack) * bytes), (float*)(qReordered + ei * qStride0), (float*)keyPtr, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
+ gcore->MNNPackedMatMul((float*)(qkPacked + (ei * eP * mPack) * mBytes), (float*)(qReordered + ei * qStride0), (float*)keyPtr, shapeParameters, nullptr, nullptr, nullptr, nullptr);
+ }
+ if (remain > 0) {
+ gcore->MNNPackedMatMulRemain((float*)(qkPacked + (loop_e * eP * mPack) * mBytes), (float*)(qReordered + loop_e * qStride0), (float*)keyPtr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr);
+ }
+ } else {
+ auto eRemain = seqLen;
+ auto srcInt8 = qReordered;
+ auto dstInt8 = qkPacked;
+ auto keyPtr = keyAddr + i * UP_DIV(mBlockKV, hP8) * (ROUND_UP(mHeadDim, lP8) * hP8 + 2 * hP8 * QUANT_INFO_BYTES);
+ gemmParam4QxK.weightKernelSum = (float*)(keySum + i * mBlockKV * QUANT_INFO_BYTES);
+ gemmParam4QxK.inputScale = qScale;
+ gemmParam4QxK.inputBias = qBias;
+ gemmParam4QxK.srcKernelSum = qSumAddr;
+ while (eRemain > 0) {
+ auto eSize = ALIMIN(eP8, eRemain);
+ mInt8GemmKernel(dstInt8, srcInt8, keyPtr, UP_DIV(mHeadDim, lP8), mBytes * seqLen * mPack, UP_DIV(subKvSeqLen, mPack), &gemmParam4QxK, eSize);
+ eRemain -= eP8;
+ gemmParam4QxK.inputScale += eP8;
+ gemmParam4QxK.inputBias += eP8;
+ gemmParam4QxK.srcKernelSum += eP8;
+ srcInt8 += unitColBufferSize;
+ dstInt8 += eP8 * mPack * mBytes;
+ if (mBlockNum > 1) {
+ memset(accumbuff, 0, eP8 * hP8 * QUANT_INFO_BYTES);
+ gemmParam4QxK.accumBuffer = accumbuff;
+ }
}
- QxK_remain((float*)(qkPacked + (loop_e * eP * mPack) * bytes), (float*)(qReordered + loop_e * qStride0), (float*)keyPtr, remain, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
}
- // qk: [kv_seq_len/mPack, seq_len, mPack] -> [seq_len/eP, kv_seq_len, eP]
+ // 2. softmax scores, softmax src/dst shape: [kv_seq_len/mPack, seq_len, mPack]
{
- if(bytes == 2) {
- if (seq_len == 1) {
- core->MNNLowpToFp32((int16_t*)qkPacked, qkFlatten, seq_len * subKvSeqLen);
- } else {
- core->MNNAttenUnpackAndConvertFp16(qkFlatten, (float*)qkPacked, subKvSeqLen, seq_len, mPack);
- }
- mask_QK(qkFlatten, seq_len, kv_seq_len, mScale, std::numeric_limits::lowest(), mask, kv_seq_len - seq_len, i * mBlockKV, subKvSeqLen, padSeqLength);
- softmaxQK(qkSoftmax, qkFlatten, runningMax, runningSum, diffScale, sinksPtr, core->MNNSoftmax, seq_len, subKvSeqLen, h, i == kvBlocks - 1);
- core->MNNAttenPackAndConvertFp32((float*)qkReordered, qkSoftmax, units, seq_len, subKvSeqLen);
- } else {
- if (seq_len > 1) {
- int32_t areaOffset[2] = {seq_len, seq_len};
- core->MNNUnpackCUnitTranspose(qkFlatten, (float*)qkPacked, seq_len, subKvSeqLen, areaOffset);
+ if (mQuantKey == false || isLowerTriangular == false || sinksPtr != nullptr) {
+ if (mBytes == 2) {
+ _maskQK((float*)qkPacked, &mScale, seqLen, subKvSeqLen, mPack, kvSeqLen, i * mBlockKV, padSeqLength, sinksPtr, mask, mQuantKey, isLowerTriangular);
} else {
- memcpy(qkFlatten, qkPacked, subKvSeqLen * sizeof(float));
+ _maskQK((float*)qkPacked, &mScale, seqLen, subKvSeqLen, mPack, kvSeqLen, i * mBlockKV, padSeqLength, sinksPtr, mask, mQuantKey, isLowerTriangular);
}
- mask_QK(qkFlatten, seq_len, kv_seq_len, mScale, std::numeric_limits::lowest(), mask, kv_seq_len - seq_len, i * mBlockKV, subKvSeqLen, padSeqLength);
- softmaxQK(qkSoftmax, qkFlatten, runningMax, runningSum, diffScale, sinksPtr, core->MNNSoftmax, seq_len, subKvSeqLen, h, i == kvBlocks - 1);
- MNNPackForMatMul_A((float*)qkReordered, qkSoftmax, seq_len, subKvSeqLen, eP, lP, bytes);
}
+ gcore->MNNSoftmax(qkSoftmax, (float*)qkPacked, runningMax, runningSum, diffScale, seqLen, subKvSeqLen, i * mBlockKV, kvValidOffset, mPack, useMaskInSoftmax);
}
- // qk @ v
- // TODO: update qkvPacked using diffScale
- size_t shapeParameters[7] = {(size_t)eP * lP * bytes, ROUND_UP((size_t)subKvSeqLen, lP), (size_t)mHeadDim, (size_t)seq_len * mPack * bytes, 0, 0, 0};
- size_t bExtraStride = (UP_DIV(max_len, lP) - UP_DIV(subKvSeqLen + i * mBlockKV, lP) + UP_DIV(i * mBlockKV, lP)) * hP * lP * bytes;
- shapeParameters[5] = quant_value ? 0 : bExtraStride;
- int loop_e = seq_len / eP;
- int remain = seq_len % eP;
- auto qkStride0 = ROUND_UP(subKvSeqLen, lP) * eP * bytes;
- for (int ei = 0 ; ei < loop_e; ei++) {
- core->MNNPackedMatMul((float*)(qkvPacked + (ei * eP * mPack) * bytes), (float*)(qkReordered + ei * qkStride0), (float*)valuePtr, shapeParameters, nullptr, nullptr, nullptr, nullptr);
+ // 3. qk @ v
+ auto qkStride0 = ROUND_UP(subKvSeqLen, lP) * eP * mBytes;
+ auto rowStart = (!isLowerTriangular || i * mBlockKV < kvValidOffset)? 0 : (i * mBlockKV - kvValidOffset);
+
+ if (mQuantValue == false) {
+ auto valuePtr = valueAddr + i * vstride0 * mBytes;
+ size_t shapeParameters[7] = {(size_t)eP * lP * mBytes, ROUND_UP((size_t)subKvSeqLen, lP), (size_t)mHeadDim, (size_t)seqLen * mPack * mBytes, 0, 0, 0};
+ size_t bExtraStride = (i < kvBlocks - 1) ? 0 : (ROUND_UP(mKVCacheManager->getFlashAttentionBlockKv(), lP) - ROUND_UP(subKvSeqLen, lP)) * hP * mBytes;
+ shapeParameters[5] = bExtraStride;
+
+ int loop_e = (seqLen - rowStart) / eP;
+ int remain = (seqLen - rowStart) % eP;
+
+ int ei = 0;
+ elFloatV[0] = eP;
+ elFloatV[1] = ROUND_UP(subKvSeqLen, lP);
+ infoFloatV[2] = eP;
+ for ( ; ei < loop_e; ei++) {
+ srcPtr[0] = (float const*)((int8_t*)qkSoftmax + (ei * eP + rowStart) * mPack * mBytes);
+ gcore->MNNPackC4ForMatMul_A((float*)qkReordered, srcPtr, infoFloatV, elFloatV);
+ gcore->MNNPackedMatMul((float*)(qkvPacked + (ei * eP + rowStart) * mPack * mBytes), (float*)qkReordered, (float*)valuePtr, shapeParameters, nullptr, nullptr, nullptr, nullptr);
+ }
+ if (remain > 0) {
+ elFloatV[0] = remain;
+ infoFloatV[2] = remain;
+ srcPtr[0] = (float const*)((int8_t*)qkSoftmax + (loop_e * eP + rowStart) * mPack * mBytes);
+ shapeParameters[0] = remain * lP * mBytes;
+ gcore->MNNPackC4ForMatMul_A((float*)qkReordered, srcPtr, infoFloatV, elFloatV);
+ gcore->MNNPackedMatMulRemain((float*)(qkvPacked + (loop_e * eP + rowStart) * mPack * mBytes), (float*)qkReordered, (float*)valuePtr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr);
+ }
+ } else { // use int8 kernel to compute qk @ v
+ auto valuePtr = valueAddr + i * vstride0;
+ auto eRemain = seqLen - rowStart;
+ auto qkPtr = (int8_t*)(qkSoftmax) + rowStart * mPack * mBytes; // [UP_DIV(subKvSeqLen,pack),seqLen,pack]
+ auto qkvFloat = qkvPacked + rowStart * mPack * mBytes;
+ gemmParam4QKxV.weightKernelSum = valueSum + i * ROUND_UP(mHeadDim, hP8);
+ sumParams4QKxV.valid = subKvSeqLen % lP8;
+ sumParams4QKxV.LU = UP_DIV(subKvSeqLen, lP8);
+
+ auto dstInt8Ptr = (int8_t*)mQuantQK.ptr() + tId * eP8 * ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, mPack);
+ srcPtr[0] = (const float*)(dstInt8Ptr);
+
+ while (eRemain > 0) {
+ auto eSize = ALIMIN(eRemain, eP8);
+
+ memset(dstInt8Ptr, 0, eP8 * ROUND_UP(MNN_FLASH_ATTENTION_BLOCK_SIZE, mPack));
+
+ infoInt8V[1] = eSize; // eReal
+ infoInt8V[4] = eSize; // e to process
+ elInt8V[0] = eSize; // e to process
+
+
+ for (int qi = 0; qi < UP_DIV(subKvSeqLen, mPack); ++qi) {
+ mQuantFunc((float*)(qkPtr + qi * seqLen * mPack * mBytes), dstInt8Ptr + qi * eSize * mPack, eSize, vQuantScale, -128, 127, vQuantBias, 0);
+ }
+ core->MNNPackC4Int8ForMatMul_A(qkReordered, (int8_t const **)srcPtr, infoInt8V, elInt8V);
+ // mSumQK
+ gcore->MNNSumByAxisLForMatmul_A(gemmParam4QKxV.srcKernelSum, qkReordered, (float*)mQKScale.ptr(), eSize, sumParams4QKxV);
+ mInt8GemmKernel(qkvFloat, qkReordered, valuePtr, UP_DIV(MNN_FLASH_ATTENTION_BLOCK_SIZE, lP8), mBytes * seqLen * mPack, UP_DIV(mHeadDim, mPack), &gemmParam4QKxV, eSize);
+
+ eRemain -= eSize;
+ qkPtr += (eSize * mPack * mBytes);
+ qkvFloat += (eSize * mPack * mBytes);
+ }
}
- core->MNNPackedMatMulRemain((float*)(qkvPacked + (loop_e * eP * mPack) * bytes), (float*)(qkReordered + loop_e * qkStride0), (float*)valuePtr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr);
+ // 4. flash attention, update each sub kvSeq's final results
if (runningMax != nullptr && runningSum != nullptr && diffScale != nullptr) {
- core->MNNFlashAttentionUpdateBlockOutput((float*)outputPacked, (float*)qkvPacked, diffScale, runningSum, UP_DIV(mHeadDim, mPack), seq_len, mPack, i, kvBlocks, mPackQKV->stride(0) / bytes, bytes);
+ gcore->MNNFlashAttentionUpdateBlockOutput((float*)outputPacked, (float*)qkvPacked, diffScale, runningSum, UP_DIV(mHeadDim, mPack), seqLen, mPack, i, kvBlocks, mPackQKV->stride(0) / mBytes, mBytes, rowStart);
}
}
- // unpack: [head_dim/mPack, seq_len, mPack] -> [seq_len, num_head, head_dim]
- auto dst_ptr = outputs[0]->host() + h * mHeadDim * bytes;
- if (bytes == 2) {
- unpack_QKV((int8_t*)outputPacked, dst_ptr, mNumHead, mHeadDim, mPack, seq_len);
- } else {
- unpack_QKV((int8_t*)outputPacked, dst_ptr, mNumHead, mHeadDim, mPack, seq_len);
- }
+ // Final results writing: [head_dim/mPack, seq_len, mPack] -> [seq_len, num_head, head_dim]
+ auto dstPtr = outputs[0]->host() + h * mHeadDim * mBytes;
+ // offset = {seqLen, mNumHead * mHeadDim};
+ gcore->MNNUnpackCUnitTranspose((float*)dstPtr, (float*)outputPacked, seqLen, mHeadDim, offset);
}
};
@@ -483,16 +786,14 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std:
backend()->onReleaseBuffer(unpackQK.get(), Backend::STATIC);
backend()->onReleaseBuffer(softmMaxQ.get(), Backend::STATIC);
backend()->onReleaseBuffer(newPackQK.get(), Backend::STATIC);
- backend()->onReleaseBuffer(tempQKBlock.get(), Backend::STATIC);
- if (quant_value){
- backend()->onReleaseBuffer(dequantV.get(), Backend::STATIC);
- }
+ backend()->onReleaseBuffer(mTempQKBlock.get(), Backend::STATIC);
+
if (!mKVCache) {
mKVCacheManager->onClear();
}
auto ptr = outputs[0]->host();
- if (seq_len < outputs[0]->length(1)) {
- ::memset(outputs[0]->host() + seq_len * mHeadDim * mNumHead * bytes, 0, (outputs[0]->length(1)-seq_len) * mHeadDim * mNumHead * bytes);
+ if (seqLen < outputs[0]->length(1)) {
+ ::memset(outputs[0]->host() + seqLen * mHeadDim * mNumHead * mBytes, 0, (outputs[0]->length(1)-seqLen) * mHeadDim * mNumHead * mBytes);
}
return NO_ERROR;
}
@@ -512,24 +813,20 @@ CPUAttention::CPUAttention(Backend *backend, bool kv_cache) : Execution(backend)
mPackQ.reset(Tensor::createDevice({1, 1, 1, 1}));
mPackQKV.reset(Tensor::createDevice({1, 1, 1, 1}));
MNN::KVCacheManager::KVCacheConfig kvconfig;
- int qkvQuantOptions = static_cast(backend)->getRuntime()->hint().qkvQuantOption;
- kvconfig.mUseInt8Kernel = (qkvQuantOptions % 8 == 4);
- // qkvQuantOption % 8:
+ // attentionOption % 8:
// 0: Do not quantize
- // 1: Only quantize key, use int8 asymmetric quantization
- // 2: Only quantize value, use fp8 quantization
- // 3: quantize both key and value
- // 4: quantize query, key and value, and use gemm int8 kernel to compute K*V
+ // 1: Q,K: Int8, V: Float32
+ // 2: Q,K,V: Int8
- // qkvQuantOption / 8:
+ // attentionOption / 8:
+ // 0: do not use flash attention
// 1: use flash attention
- kvconfig.mQuantKey = (qkvQuantOptions % 8 == 4) || (qkvQuantOptions % 8 == 1) || (qkvQuantOptions % 8 == 3);
- kvconfig.mQuantValue = (qkvQuantOptions % 8 == 4) || (qkvQuantOptions % 8 == 2);
kvconfig.mKVCacheDir = static_cast(backend)->getRuntime()->hint().kvcacheDirPath;
- kvconfig.mKVCacheSizeLimit = static_cast(backend)->getRuntime()->hint().kvcacheSizeLimit;
+ kvconfig.mPrefixCacheDir = static_cast(backend)->getRuntime()->hint().prefixcacheDirPath;
kvconfig.mExpandChunk = 64;
- mKVCacheManager.reset(new KVCacheManager(backend, kvconfig));
+ kvconfig.mBlockNum = 1;
+ mKVCacheManager.reset(new CPUKVCacheManager(backend, kvconfig));
}
CPUAttention::~CPUAttention() {
diff --git a/source/backend/cpu/CPUAttention.hpp b/source/backend/cpu/CPUAttention.hpp
index 8739fb711c..22f98a2925 100644
--- a/source/backend/cpu/CPUAttention.hpp
+++ b/source/backend/cpu/CPUAttention.hpp
@@ -14,8 +14,8 @@
#include
#include "core/Execution.hpp"
#include "core/OpCommonUtils.hpp"
+#include "CPUKVCacheManager.hpp"
#include "MNN/ErrorCode.hpp"
-#include "KVCacheManager.hpp"
namespace MNN {
@@ -28,19 +28,32 @@ class CPUAttention : public Execution {
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
private:
bool mKVCache = true;
- bool mUseGemmInt8 = false;
- int bytes = 4;
+ int mBytes = 4;
int mThreadNum = 1;
int mBlockKV = 512;
int eP, lP, hP, mPack; // float matmul packing
int eP8, lP8, hP8; // GemmInt8 packing
int mNumHead, mKvNumHead, mHeadDim;
- std::shared_ptr mPackQ, mPackQKV, mSumQ, mRunningMax, mRunningSum, mTempQKBlock, mTempOut, mExpfDiffMax;
- std::shared_ptr mKVCacheManager = nullptr;
- std::vector mMinQ, mMaxQ, mQueryScale, mQueryZeroPoint;
- template void pack_query(Tensor* query, int8_t* pack_q, int8_t* sum_q, int seq_len, int h, float q_scale);
- template void unpack_QK(float * unpack_qk_dst, int8_t * pack_qk_src, int seq_len, int kv_seq_len);
KVMeta* mMeta;
+
+ // common
+ std::shared_ptr mPackQ, mPackQKV, mRunningMax, mRunningSum, mTempQKBlock, mTempOut, mExpfDiffMax;
+ std::shared_ptr mKVCacheManager = nullptr;
+ bool mUseFlashAttention = true;
+
+ // quant Query/Key/Value
+ bool mQuantKey = false;
+ bool mQuantValue = false;
+ int mBlockNum = 1;
+ MemChunk mSumQ;
+ MemChunk mQueryScale, mQueryZeroPoint, mQueryQuantScale, mQueryQuantZero;
+ MemChunk mQuantQuery, mAccumBuffer;
+
+ MemChunk mQuantQK, mQKScale, mQKBias, mSumQK, mArray;
+ AutoStorage mGemmBias, mGemmRelu;
+
+ std::function mQuantFunc;
+ decltype(CoreInt8Functions::Int8GemmKernel) mInt8GemmKernel;
};
} // namespace MNN
diff --git a/source/backend/cpu/CPUBackend.cpp b/source/backend/cpu/CPUBackend.cpp
index ceb23910fd..8d284aa33b 100644
--- a/source/backend/cpu/CPUBackend.cpp
+++ b/source/backend/cpu/CPUBackend.cpp
@@ -104,15 +104,14 @@ void CPURuntime::_bindCPUCore() const {
#ifdef MNN_USE_THREAD_POOL
if (nullptr != mThreadPool) {
mThreadPool->active();
- mThreadPool->enqueue(std::make_pair([&](int i) {
+ ThreadPool::TASK task = std::make_pair([&](int i) {
MNNSetSchedAffinity(lockCPUIndexes[i].first, lockCPUIndexes[i].second);
- return 0;
- }, mThreadNumber), mTaskIndex);
+ }, mThreadNumber);
+ mThreadPool->enqueue(&task, mTaskIndex);
mThreadPool->deactive();
}
#endif
}
-
void CPURuntime::_resetThreadPool() const {
mThreadNumber = std::max(1, mThreadNumber);
mThreadNumber = std::min(mThreadNumber, MAX_THREAD_NUMBER);
@@ -326,11 +325,7 @@ Backend* CPURuntime::onCreate(const BackendConfig* config, Backend* origin) cons
auto core = MNNGetCoreFunctions();
if (core->supportFp16arith && precision == BackendConfig::Precision_Low) {
res = new Arm82Backend(this, memory);
- if (hint().useArmSme2Cores && core->supportSME2 && res->functions()->sme2Int8MatmulRelatedFuncionsHp32.Int8GemmKernel) {
- res->mRelatedFunctions = &(res->functions()->sme2Int8MatmulRelatedFuncionsHp32);
- } else {
- res->mRelatedFunctions = &(res->functions()->int8MatmulRelatedFunctions);
- }
+ res->mRelatedFunctions = &(res->functions()->int8MatmulRelatedFunctions);
break;
}
#endif
@@ -458,12 +453,8 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p
mRuntime = const_cast(runtime);
auto core = MNNGetCoreFunctions();
mThreadNumber = mRuntime->mThreadNumber;
- if (mRuntime->hint().useArmSme2Cores && core->supportSME2 && core->sme2Int8MatmulRelatedFuncionsHp32.Int8GemmKernel) {
- mThreadNumber = ALIMIN(2, mThreadNumber);
- mRelatedFunctions = &core->sme2Int8MatmulRelatedFuncionsHp32;
- } else {
- mRelatedFunctions = &core->int8MatmulRelatedFunctions;
- }
+ mRelatedFunctions = &core->int8MatmulRelatedFunctions;
+
// Compute Group Rate
do {
if (mThreadNumber <= 1 || mRuntime->mPower == BackendConfig::Power_Low) {
@@ -499,6 +490,7 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p
currentRate *= decreaseRate;
totalComputeRate += currentRate * selectSize;
mGroupWithComputeRate.emplace_back(std::make_pair(currentRate * selectSize, selectSize));
+ groupIndex--;
}
for (auto& g : mGroupWithComputeRate) {
g.first = g.first / totalComputeRate;
diff --git a/source/backend/cpu/CPUBackend.hpp b/source/backend/cpu/CPUBackend.hpp
index 884036eb38..ec4c555dec 100644
--- a/source/backend/cpu/CPUBackend.hpp
+++ b/source/backend/cpu/CPUBackend.hpp
@@ -176,6 +176,9 @@ class CPUBackend : public Backend {
#ifdef MNN_USE_THREAD_POOL
inline int taskIndex() const {return mRuntime->mTaskIndex;}
inline ThreadPool* threadPool() const {return mRuntime->mThreadPool;}
+ void enqueue(ThreadPool::TASK& task) const {
+ threadPool()->enqueue(&task, taskIndex());
+ }
#endif
static void initCreatorMap();
static size_t getBytes(const Backend* backend, const Tensor* output);
diff --git a/source/backend/cpu/CPUBinary.cpp b/source/backend/cpu/CPUBinary.cpp
index 059e502d0b..61ccf4fca3 100644
--- a/source/backend/cpu/CPUBinary.cpp
+++ b/source/backend/cpu/CPUBinary.cpp
@@ -45,6 +45,37 @@ ErrorCode CPUBinary::onResize(const std::vector& inputs, const std::vec
mThreadNum = threads;
mWorkDiv = UP_DIV(mTotalSize, threads);
}
+ int inpBytes = inputs[0]->getType().bytes();
+ int outBytes = outputs[0]->getType().bytes();
+ if (halide_type_float == inputs[0]->getType().code) {
+ inpBytes = static_cast(backend())->functions()->bytes;
+ }
+ if (halide_type_float == outputs[0]->getType().code) {
+ outBytes = static_cast(backend())->functions()->bytes;
+ }
+ bool outputInt = outputs[0]->getType().code == halide_type_int;
+ mTask = std::make_pair([this, inpBytes, outBytes, outputInt](int tId) {
+ int start = tId * mWorkDiv;
+ int realSize = ALIMIN(mWorkDiv, mTotalSize - start);
+ if (realSize > 0) {
+ auto inp0 = mInput0Ptr + start * inpBytes;
+ auto inp1 = mInput1Ptr + start * inpBytes;
+ if (mNeedBroadcastIndex == 0) {
+ inp0 = mInput0Ptr;
+ } else if (mNeedBroadcastIndex == 1) {
+ inp1 = mInput1Ptr;
+ }
+ auto out = mOutputPtr + start * outBytes;
+ mProc(out, inp0, inp1, realSize, mNeedBroadcastIndex);
+ if(mActivationType == 1 && outputInt) {
+ for(int i=0; i 0 ? val : 0;
+ ((int32_t *)out)[i] = res;
+ }
+ }
+ }
+ } , mThreadNum);
return NO_ERROR;
}
@@ -67,31 +98,10 @@ ErrorCode CPUBinary::onExecute(const std::vector& inputs, const std::ve
outBytes = static_cast(backend())->functions()->bytes;
}
auto precision = static_cast(backend())->precisionMode();
-
- MNN_CONCURRENCY_BEGIN(tId, mThreadNum) {
- int start = tId * mWorkDiv;
- int realSize = ALIMIN(mWorkDiv, mTotalSize - start);
- if (realSize > 0) {
- auto inp0 = input0Ptr + start * inpBytes;
- auto inp1 = input1Ptr + start * inpBytes;
- if (mNeedBroadcastIndex == 0) {
- inp0 = input0Ptr;
- } else if (mNeedBroadcastIndex == 1) {
- inp1 = input1Ptr;
- }
- auto out = outputPtr + start * outBytes;
- mProc(out, inp0, inp1, realSize, mNeedBroadcastIndex);
- if(mActivationType == 1 && output->getType().code == halide_type_int) {
- for(int i=0; i 0 ? val : 0;
- ((int32_t *)out)[i] = res;
- }
- }
- }
- }
- MNN_CONCURRENCY_END();
-
+ mInput0Ptr = input0Ptr;
+ mInput1Ptr = input1Ptr;
+ mOutputPtr = outputPtr;
+ MNN_CONCURRENCY_ENQUEUE(mTask);
if(mActivationType == 1 && output->getType().code == halide_type_float) {
mActivationExe->onExecute(outputs, outputs);
}
diff --git a/source/backend/cpu/CPUBinary.hpp b/source/backend/cpu/CPUBinary.hpp
index 9250df79ae..17cb3b5f47 100644
--- a/source/backend/cpu/CPUBinary.hpp
+++ b/source/backend/cpu/CPUBinary.hpp
@@ -33,6 +33,10 @@ class CPUBinary : public Execution {
int mThreadNum;
int mWorkDiv;
std::shared_ptr mActivationExe;
+ std::pair, int> mTask;
+ uint8_t* mInput0Ptr = nullptr;
+ uint8_t* mOutputPtr = nullptr;
+ uint8_t* mInput1Ptr = nullptr;
};
} // namespace MNN
#endif /* CPUBinary_hpp */
diff --git a/source/backend/cpu/CPUConvolution.cpp b/source/backend/cpu/CPUConvolution.cpp
index 12c9ceff06..a47777b2f8 100644
--- a/source/backend/cpu/CPUConvolution.cpp
+++ b/source/backend/cpu/CPUConvolution.cpp
@@ -159,8 +159,8 @@ void CPUConvolution::MutableResourceInt8::updateInputOutputScale(std::vector(biasData[i] / (mInputScale * weightScale)) - mResource->mInt8WeightKernelSum[i] * (mInputZeroPoint + offset) + outputZeroPointFused;
}
} else {
+ auto outputScale = mResource->mWeightBits == 4 ? 1.f : mOutputScale;
+ int32_t outputZero = mResource->mWeightBits == 4 ? 0 : mOutputZeroPoint;
for (int i = 0; i < ocUp4; ++i) {
- biasfloat[i] = (biasData[i] - mResource->mWeightKernelSum->host()[i] * (mInputZeroPoint + offset) * mInputScale) / mOutputScale + mOutputZeroPoint;
+ biasfloat[i] = (biasData[i] - mResource->mWeightKernelSum->host()[i] * (mInputZeroPoint + offset) * mInputScale) / outputScale + outputZero;
}
}
@@ -224,9 +226,9 @@ std::shared_ptr CPUConvolution::makeResourceInt8(B
auto scalePtr = resource->mOriginScale->host();
memset(scalePtr, 0, ocUpUnit * 2 * sizeof(float));
- resource->mActBits = 8;
+ resource->mWeightBits = 8;
if (convParam->symmetricQuan()) {
- resource->mActBits = convParam->symmetricQuan()->nbits();
+ resource->mWeightBits = convParam->symmetricQuan()->nbits();
}
const int8_t* weightSrc = nullptr;
int weightSize = 0;
diff --git a/source/backend/cpu/CPUConvolution.hpp b/source/backend/cpu/CPUConvolution.hpp
index 26ca877602..13abecef0f 100644
--- a/source/backend/cpu/CPUConvolution.hpp
+++ b/source/backend/cpu/CPUConvolution.hpp
@@ -66,7 +66,7 @@ class CPUConvolution : public Execution {
std::vector mReluThreshold;
// relu or relu6
bool mRelu;
- int mActBits; // quant bits
+ int mWeightBits; // quant bits
bool mUseConvQuan = true;
bool mWeightAsymmetricQuant = true;
diff --git a/source/backend/cpu/CPUKVCacheManager.cpp b/source/backend/cpu/CPUKVCacheManager.cpp
new file mode 100644
index 0000000000..82e4685a3d
--- /dev/null
+++ b/source/backend/cpu/CPUKVCacheManager.cpp
@@ -0,0 +1,796 @@
+//
+// CPUKVCacheManager.cpp
+// MNN
+//
+// Created by MNN on 2024/08/05.
+// Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
+
+#include "CPUKVCacheManager.hpp"
+#include "core/Concurrency.h"
+
+namespace MNN {
+
+/*
+** @brief Expand the in-memory KV cache: allocate larger key/value tensors,
+** copy the old contents into them, and repoint the cache to the new tensors.
+*/
+void CPUKVCacheManager::expandKVCacheInMem(int oldMaxLength) {
+ /*=================================== Key ===================================*/
+ auto new_key = Tensor::createDevice({mKvNumHead, (int)mCurrentKeySizePerHead});
+ mBackend->onAcquireBuffer(new_key, Backend::STATIC);
+ if (mQuantKey) {
+ memset(new_key->host(), 0, mKvNumHead * mCurrentKeySizePerHead);
+ }
+ for (int h = 0; h < mKvNumHead; h++) {
+ memcpy(
+ new_key->host() + h * mCurrentKeySizePerHead,
+ mPastKey->host() + h * mPastKey->stride(0),
+ mPastKey->stride(0)
+ );
+ if (!mQuantKey && (new_key->stride(0) - mPastKey->stride(0)) > 0) {
+ memset(new_key->host() + h * new_key->stride(0) + mPastKey->stride(0), 0, (new_key->stride(0) - mPastKey->stride(0)));
+ }
+ }
+ mPastKey.reset(new_key);
+ /*=================================== Value ===================================*/
+ auto newValue = Tensor::createDevice({mKvNumHead, (int)mCurrentValueSizePerHead});
+ mBackend->onAcquireBuffer(newValue, Backend::STATIC);
+
+ if (mUseFlashAttention) { // [mKvNumHead, UP_DIV(mMaxLength, mFlashAttentionUpperKv), UP_DIV(mHeadDim, hP), UP_DIV(mFlashAttentionUpperKv, lP), hP, lP]
+ for (int h = 0; h < mKvNumHead; h++) {
+ memset(newValue->host() + h * newValue->stride(0), 0, newValue->stride(0));
+ memcpy(
+ newValue->host() + h * newValue->stride(0),
+ mPastValue->host