diff --git a/.gitmodules b/.gitmodules index 7b9576416..4e04dd80a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -5,3 +5,7 @@ [submodule "third_party/gpudma"] path = third_party/gpudma url = https://github.com/karakozov/gpudma + +[submodule "examples/llama/llama"] + path = examples/llama/llama + url = https://github.com/facebookresearch/llama diff --git a/ark/include/kernels/layernorm.h b/ark/include/kernels/layernorm.h index f0353e8d4..273bfa67e 100644 --- a/ark/include/kernels/layernorm.h +++ b/ark/include/kernels/layernorm.h @@ -4,6 +4,7 @@ #ifndef ARK_KERNELS_LAYERNORM_H_ #define ARK_KERNELS_LAYERNORM_H_ +#include "math_functions.h" #include "reduce.h" namespace ark { @@ -63,21 +64,18 @@ struct LayerNorm (tid_c + uc * UnitOutDims::C) * InDims::HW + (tid_n + un * UnitOutDims::N) * InDims::CHW; - DataType reduced; - ReduceTypeMean::singleIdentity(&reduced); + DataType reduced = ReduceTypeMean::singleIdentity(); for (int idx_in_w = tid_w; idx_in_w < InShape::W; idx_in_w += ThreadsPerRow) { - int idx_in = idx_in_base + idx_in_w; - ReduceTypeMean::singleReduce(&reduced, &reduced, &in[idx_in]); + ReduceTypeMean::singleReduce(reduced, in[idx_in_base + idx_in_w]); } UnitOp::sync_threads(); // final reduction on shared memory using warp shuffle. reduced = warpsReduce( reduced, tid, smem_per_warp); // get the average result. - ReduceTypeMean::singlePostReduce(&reduced, &reduced, UnitOutDims::W); - DataType variance; - ReduceTypeMean::singleIdentity(&variance); + ReduceTypeMean::singlePostReduce(reduced, UnitOutDims::W); + DataType variance = ReduceTypeMean::singleIdentity(); // get the variance UnitOp::sync_threads(); for (int idx_in_w = tid_w; idx_in_w < InShape::W; @@ -88,7 +86,7 @@ struct LayerNorm UnitOp::sync_threads(); variance = warpsReduce( variance, tid, smem_per_warp); - ReduceTypeMean::singlePostReduce(&variance, &variance, UnitOutDims::W); + ReduceTypeMean::singlePostReduce(variance, UnitOutDims::W); UnitOp::sync_threads(); // the output is (input - mean) / sqrt(variance) for (int idx_in_w = tid_w; idx_in_w < InShape::W; @@ -127,7 +125,8 @@ DEVICE void layernorm(ark::half *out, const ark::half *in, int uop_idx, // Root Mean Square Layer Normalization: https://arxiv.org/pdf/1910.07467.pdf template + int SmemBytes, typename DataType, typename CompType, + int NelemPerThread> struct RMSNorm { using UnitOp = @@ -138,7 +137,8 @@ struct RMSNorm int smem_per_warp) { using InOutChk = LayerNormShapeChecker; - using ReduceTypeMean = ReduceTypeMean; + using ReduceTypeMean = + ReduceTypeMean; constexpr int NonReduceDimLength = UnitOutDims::NCH; // The reduction dimension of the final stage. @@ -166,25 +166,27 @@ struct RMSNorm (tid_c + uc * UnitOutDims::C) * InDims::HW + (tid_n + un * UnitOutDims::N) * InDims::CHW; - DataType variance; - ReduceTypeMean::singleIdentity(&variance); + CompType var = ReduceTypeMean::singleIdentity(); + // get the variance UnitOp::sync_threads(); for (int idx_in_w = tid_w; idx_in_w < InShape::W; idx_in_w += ThreadsPerRow) { - int idx_in = idx_in_base + idx_in_w; - variance += (in[idx_in]) * (in[idx_in]); + CompType data = static_cast(in[idx_in_base + idx_in_w]); + var += data * data; } UnitOp::sync_threads(); - variance = warpsReduce( - variance, tid, smem_per_warp); - ReduceTypeMean::singlePostReduce(&variance, &variance, UnitOutDims::W); + var = warpsReduce( + var, tid, smem_per_warp) / + UnitOutDims::W; UnitOp::sync_threads(); - // the output is (input - mean) / sqrt(variance) + // the output is (input - mean) / sqrt(reduced) for (int idx_in_w = tid_w; idx_in_w < InShape::W; idx_in_w += ThreadsPerRow) { int idx_in = idx_in_base + idx_in_w; - out[idx_in] = (in[idx_in]) * rsqrtf(variance + 1e-5f); + CompType data = static_cast(in[idx_in]); + out[idx_in] = static_cast( + data * Rsqrt::compute(var + CompType(1e-5f))); } } }; @@ -196,8 +198,8 @@ DEVICE void rmsnorm(float *out, const float *in, int uop_idx, int smem_per_warp) { constexpr int NelemPerThread = 1; RMSNorm::run(out, in, uop_idx, - smem_per_warp); + SmemBytes, float, float, NelemPerThread>::run(out, in, uop_idx, + smem_per_warp); } template ::run(out, in, uop_idx, - smem_per_warp); + SmemBytes, ark::half, float, NelemPerThread>::run(out, in, uop_idx, + smem_per_warp); } } // namespace ark diff --git a/ark/include/kernels/math_functions.h b/ark/include/kernels/math_functions.h index 62872fcfc..4cdfe4ef1 100644 --- a/ark/include/kernels/math_functions.h +++ b/ark/include/kernels/math_functions.h @@ -10,11 +10,15 @@ namespace ark { struct Exp { - static DEVICE float compute(float input) + static DEVICE float compute(const float &input) { return expf(input); } - static DEVICE __half2 compute(__half2 input) + static DEVICE __half compute(const __half &input) + { + return hexp(input); + } + static DEVICE __half2 compute(const __half2 &input) { return h2exp(input); } @@ -22,16 +26,36 @@ struct Exp struct Sqrt { - static DEVICE float compute(float input) + static DEVICE float compute(const float &input) { return sqrtf(input); } - static DEVICE __half2 compute(__half2 input) + static DEVICE __half compute(const __half &input) + { + return hsqrt(input); + } + static DEVICE __half2 compute(const __half2 &input) { return h2sqrt(input); } }; +struct Rsqrt +{ + static DEVICE float compute(const float &input) + { + return rsqrtf(input); + } + static DEVICE __half compute(const __half &input) + { + return hrsqrt(input); + } + static DEVICE __half2 compute(const __half2 &input) + { + return h2rsqrt(input); + } +}; + template struct Math; diff --git a/ark/include/kernels/reduce.h b/ark/include/kernels/reduce.h index c9756e73a..ec6422d59 100644 --- a/ark/include/kernels/reduce.h +++ b/ark/include/kernels/reduce.h @@ -18,49 +18,49 @@ typedef enum } AxisType; // Shared memory for reduction. -template struct ReduceSharedStorage +template struct ReduceSharedStorage { - DataType storage[32]; + CompType storage[32]; }; /* Reduce single-precision `val` within a single warp. */ template -DEVICE DataType warpReduce(DataType val) + typename CompType = typename ReduceType::CompType> +DEVICE CompType warpReduce(const CompType &val) { - DataType res = val; - DataType tmp; + CompType res = val; + CompType tmp; if (LanesNum >= 32) { tmp = __shfl_xor_sync(0xffffffff, res, 16, 32); - ReduceType::singleReduce(&res, &res, &tmp); + ReduceType::singleReduce(res, tmp); tmp = __shfl_xor_sync(0xffffffff, res, 8, 16); - ReduceType::singleReduce(&res, &res, &tmp); + ReduceType::singleReduce(res, tmp); tmp = __shfl_xor_sync(0xffffffff, res, 4, 8); - ReduceType::singleReduce(&res, &res, &tmp); + ReduceType::singleReduce(res, tmp); tmp = __shfl_xor_sync(0xffffffff, res, 2, 4); - ReduceType::singleReduce(&res, &res, &tmp); + ReduceType::singleReduce(res, tmp); tmp = __shfl_xor_sync(0xffffffff, res, 1, 2); - ReduceType::singleReduce(&res, &res, &tmp); + ReduceType::singleReduce(res, tmp); } else { if (LanesNum > 16) { tmp = __shfl_xor_sync(0xffffffff, res, 16, 32); - ReduceType::singleReduce(&res, &res, &tmp); + ReduceType::singleReduce(res, tmp); } if (LanesNum > 8) { tmp = __shfl_xor_sync(0xffffffff, res, 8, 16); - ReduceType::singleReduce(&res, &res, &tmp); + ReduceType::singleReduce(res, tmp); } if (LanesNum > 4) { tmp = __shfl_xor_sync(0xffffffff, res, 4, 8); - ReduceType::singleReduce(&res, &res, &tmp); + ReduceType::singleReduce(res, tmp); } if (LanesNum > 2) { tmp = __shfl_xor_sync(0xffffffff, res, 2, 4); - ReduceType::singleReduce(&res, &res, &tmp); + ReduceType::singleReduce(res, tmp); } if (LanesNum > 1) { tmp = __shfl_xor_sync(0xffffffff, res, 1, 2); - ReduceType::singleReduce(&res, &res, &tmp); + ReduceType::singleReduce(res, tmp); } } return res; @@ -68,28 +68,28 @@ DEVICE DataType warpReduce(DataType val) // Reduce single-precision `val` within multiple warps. template -DEVICE DataType warpsReduce(DataType val, int tid, int smem_per_warp) + typename CompType = typename ReduceType::CompType> +DEVICE CompType warpsReduce(const CompType &val, int tid, int smem_per_warp) { - val = warpReduce(val); + CompType res = warpReduce(val); if (LanesNum > 32) { - ReduceSharedStorage *shared = - UnitOp::template shared_memory>( + ReduceSharedStorage *shared = + UnitOp::template shared_memory>( smem_per_warp); int laneId = tid & 31; int warpId = tid >> 5; if (laneId == 0) { - shared->storage[warpId] = val; + shared->storage[warpId] = res; } UnitOp::sync_threads(); if (laneId < (LanesNum >> 5)) { - val = shared->storage[laneId]; + res = shared->storage[laneId]; } else { - ReduceType::singleIdentity(&val); + res = ReduceType::singleIdentity(); } - val = warpReduce(val); + res = warpReduce(res); } - return val; + return res; } // Check if InShape can be reduced into OutShape and if UnitOutDims is valid. @@ -118,65 +118,65 @@ struct ReduceShapeChecker "Invalid UnitOutDims::W"); }; -template struct ReduceTypeSum +template +struct ReduceTypeSum { using DataType = _DataType; + using CompType = _CompType; static const int NelemPerThread = _NelemPerThread; - static DEVICE void identity(DataType *v) + static DEVICE void identity(CompType *v) { #pragma unroll for (int elem = 0; elem < NelemPerThread; ++elem) { v[elem] = 0; } } - static DEVICE void reduce(DataType *out, const DataType *in0, - const DataType *in1) + static DEVICE void reduce(CompType *out, const DataType *in) { #pragma unroll for (int elem = 0; elem < NelemPerThread; ++elem) { - out[elem] = in0[elem] + in1[elem]; + out[elem] += static_cast(in[elem]); } } - static DEVICE void postReduce(DataType *out, const DataType *in, + static DEVICE void postReduce(DataType *out, const CompType *in, int nelem = 1) { #pragma unroll for (int elem = 0; elem < NelemPerThread; ++elem) { - out[elem] = in[elem]; + out[elem] = static_cast(in[elem]); } } - static DEVICE void singleIdentity(DataType *v) + static DEVICE CompType singleIdentity() { - *v = 0; + return CompType(0); } - static DEVICE void singleReduce(DataType *out, const DataType *in0, - const DataType *in1) + static DEVICE void singleReduce(CompType &out, const CompType &in) { - *out = *in0 + *in1; + out += in; } - static DEVICE void singlePostReduce(DataType *out, const DataType *in, + static DEVICE void singlePostReduce(DataType &out, const CompType &in, int nelem = 1) { - *out = *in; + out = static_cast(in); } }; -template <> struct ReduceTypeSum +template <> struct ReduceTypeSum { using DataType = half; + using CompType = half; static const int NelemPerThread = 2; static DEVICE void identity(half *v) { *reinterpret_cast<__half2 *>(v) = (__half2_raw){0, 0}; } - static DEVICE void reduce(half *out, const half *in0, const half *in1) + static DEVICE void reduce(half *out, const half *in) { __half2 *out2 = reinterpret_cast<__half2 *>(out); - const __half2 *in02 = reinterpret_cast(in0); - const __half2 *in12 = reinterpret_cast(in1); - *out2 = __hadd2(*in02, *in12); + const __half2 *in2 = reinterpret_cast(in); + *out2 = __hadd2(*out2, *in2); } static DEVICE void postReduce(half *out, const half *in, int nelem = 1) { @@ -184,86 +184,87 @@ template <> struct ReduceTypeSum const __half2 *in2 = reinterpret_cast(in); *out2 = *in2; } - static DEVICE void singleIdentity(half *v) + static DEVICE half singleIdentity() { - *v = 0; + return half(0); } - static DEVICE void singleReduce(half *out, const half *in0, const half *in1) + static DEVICE void singleReduce(half &out, const half &in) { - *out = *in0 + *in1; + out += in; } - static DEVICE void singlePostReduce(half *out, const half *in, + static DEVICE void singlePostReduce(half &out, const half &in, int nelem = 1) { - *out = *in; + out = in; } }; -template struct ReduceTypeMax +template +struct ReduceTypeMax { using DataType = _DataType; + using CompType = _CompType; static const int NelemPerThread = _NelemPerThread; - static DEVICE void identity(DataType *v) + static DEVICE void identity(CompType *v) { #pragma unroll for (int elem = 0; elem < NelemPerThread; ++elem) { - v[elem] = platform::numeric_limits::lowest(); + v[elem] = platform::numeric_limits::lowest(); } } - static DEVICE void reduce(DataType *out, const DataType *in0, - const DataType *in1) + static DEVICE void reduce(CompType *out, const DataType *in) { #pragma unroll for (int elem = 0; elem < NelemPerThread; ++elem) { - out[elem] = (in0[elem] > in1[elem]) ? in0[elem] : in1[elem]; + CompType data = static_cast(in[elem]); + out[elem] = (out[elem] > data) ? out[elem] : data; } } - static DEVICE void postReduce(DataType *out, const DataType *in, + static DEVICE void postReduce(DataType *out, const CompType *in, int nelem = 1) { #pragma unroll for (int elem = 0; elem < NelemPerThread; ++elem) { - out[elem] = in[elem]; + out[elem] = static_cast(in[elem]); } } - static DEVICE void singleIdentity(DataType *v) + static DEVICE CompType singleIdentity() { - *v = platform::numeric_limits::lowest(); + return platform::numeric_limits::lowest(); } - static DEVICE void singleReduce(DataType *out, const DataType *in0, - const DataType *in1) + static DEVICE void singleReduce(CompType &out, const CompType &in) { - *out = (*in0 > *in1) ? *in0 : *in1; + out = (out > in) ? out : in; } - static DEVICE void singlePostReduce(DataType *out, const DataType *in, + static DEVICE void singlePostReduce(DataType &out, const CompType &in, int nelem = 1) { - *out = *in; + out = static_cast(in); } }; -template <> struct ReduceTypeMax +template <> struct ReduceTypeMax { using DataType = half; + using CompType = half; static const int NelemPerThread = 2; static DEVICE void identity(half *v) { *reinterpret_cast<__half2 *>(v) = (__half2_raw){0xfbff, 0xfbff}; } - static DEVICE void reduce(half *out, const half *in0, const half *in1) + static DEVICE void reduce(half *out, const half *in) { #if (__CUDA_ARCH__ >= 800) __half2 *out2 = reinterpret_cast<__half2 *>(out); - const __half2 *in02 = reinterpret_cast(in0); - const __half2 *in12 = reinterpret_cast(in1); - *out2 = __hmax2(*in02, *in12); + const __half2 *in2 = reinterpret_cast(in); + *out2 = __hmax2(*out2, *in2); #else #pragma unroll for (int elem = 0; elem < NelemPerThread; ++elem) { - out[elem] = (in0[elem] > in1[elem]) ? in0[elem] : in1[elem]; + out[elem] = (out[elem] > in[elem]) ? out[elem] : in[elem]; } #endif // (__CUDA_ARCH__ >= 800) } @@ -273,80 +274,80 @@ template <> struct ReduceTypeMax const __half2 *in2 = reinterpret_cast(in); *out2 = *in2; } - static DEVICE void singleIdentity(half *v) + static DEVICE half singleIdentity() { - *v = platform::numeric_limits::lowest(); + return platform::numeric_limits::lowest(); } - static DEVICE void singleReduce(half *out, const half *in0, const half *in1) + static DEVICE void singleReduce(half &out, const half &in) { - *out = (*in0 > *in1) ? *in0 : *in1; + out = (out > in) ? out : in; } - static DEVICE void singlePostReduce(half *out, const half *in, + static DEVICE void singlePostReduce(half &out, const half &in, int nelem = 1) { - *out = *in; + out = in; } }; -template struct ReduceTypeMean +template +struct ReduceTypeMean { using DataType = _DataType; + using CompType = _CompType; static const int NelemPerThread = _NelemPerThread; - static DEVICE void identity(DataType *v) + static DEVICE void identity(CompType *v) { #pragma unroll for (int elem = 0; elem < NelemPerThread; ++elem) { v[elem] = 0; } } - static DEVICE void reduce(DataType *out, const DataType *in0, - const DataType *in1) + static DEVICE void reduce(CompType *out, const DataType *in) { #pragma unroll for (int elem = 0; elem < NelemPerThread; ++elem) { - out[elem] = in0[elem] + in1[elem]; + out[elem] += static_cast(in[elem]); } } - static DEVICE void postReduce(DataType *out, const DataType *in, + static DEVICE void postReduce(DataType *out, const CompType *in, int nelem = 1) { #pragma unroll for (int elem = 0; elem < NelemPerThread; ++elem) { - out[elem] = in[elem] / nelem; + out[elem] = static_cast(in[elem] / CompType(nelem)); } } - static DEVICE void singleIdentity(DataType *v) + static DEVICE CompType singleIdentity() { - *v = 0; + return CompType(0); } - static DEVICE void singleReduce(DataType *out, const DataType *in0, - const DataType *in1) + static DEVICE void singleReduce(CompType &out, const CompType &in) { - *out = *in0 + *in1; + out += in; } - static DEVICE void singlePostReduce(DataType *out, const DataType *in, + static DEVICE void singlePostReduce(DataType &out, const CompType &in, int nelem = 1) { - *out = *in / nelem; + out = static_cast(in / CompType(nelem)); } }; -template <> struct ReduceTypeMean +template <> struct ReduceTypeMean { using DataType = half; + using CompType = half; static const int NelemPerThread = 2; static DEVICE void identity(half *v) { *reinterpret_cast<__half2 *>(v) = (__half2_raw){0, 0}; } - static DEVICE void reduce(half *out, const half *in0, const half *in1) + static DEVICE void reduce(half *out, const half *in) { __half2 *out2 = reinterpret_cast<__half2 *>(out); - const __half2 *in02 = reinterpret_cast(in0); - const __half2 *in12 = reinterpret_cast(in1); - *out2 = __hadd2(*in02, *in12); + const __half2 *in2 = reinterpret_cast(in); + *out2 = __hadd2(*out2, *in2); } static DEVICE void postReduce(half *out, const half *in, int nelem = 1) { @@ -354,18 +355,18 @@ template <> struct ReduceTypeMean const __half2 *in2 = reinterpret_cast(in); *out2 = __h2div(*in2, __float2half2_rn((float)nelem)); } - static DEVICE void singleIdentity(half *v) + static DEVICE half singleIdentity() { - *v = 0; + return half(0); } - static DEVICE void singleReduce(half *out, const half *in0, const half *in1) + static DEVICE void singleReduce(half &out, const half &in) { - *out = *in0 + *in1; + out += in; } - static DEVICE void singlePostReduce(half *out, const half *in, + static DEVICE void singlePostReduce(half &out, const half &in, int nelem = 1) { - *out = *in / nelem; + out = in / nelem; } }; @@ -379,19 +380,20 @@ template { using DataType = typename ReduceType::DataType; + using CompType = typename ReduceType::CompType; static const int NelemPerThread = ReduceType::NelemPerThread; - static DEVICE void compute(DataType *out, DataType *in, int idx_n, + static DEVICE void compute(DataType *out, const DataType *in, int idx_n, int idx_c, int idx_h, int idx_w) { int idx_out = idx_c * OutDims::HW + idx_h * OutDims::W + idx_w; int idx_in = idx_c * InDims::HW + idx_h * InDims::W + idx_w; - DataType reduced[NelemPerThread]; + CompType reduced[NelemPerThread]; ReduceType::identity(reduced); #pragma unroll for (int i = 0; i < InShape::N; ++i) { - ReduceType::reduce(reduced, reduced, &in[idx_in + i * InDims::CHW]); + ReduceType::reduce(reduced, &in[idx_in + i * InDims::CHW]); } ReduceType::postReduce(&out[idx_out], reduced, InShape::N); } @@ -403,6 +405,7 @@ template { using DataType = typename ReduceType::DataType; + using CompType = typename ReduceType::CompType; static const int NelemPerThread = ReduceType::NelemPerThread; static DEVICE void compute(DataType *out, DataType *in, int idx_n, @@ -410,12 +413,12 @@ struct EwiseReduceCompType { int idx_out = idx_n * OutDims::CHW + idx_h * OutDims::W + idx_w; int idx_in = idx_n * InDims::CHW + idx_h * InDims::W + idx_w; - DataType reduced[NelemPerThread]; + CompType reduced[NelemPerThread]; ReduceType::identity(reduced); #pragma unroll for (int i = 0; i < InShape::C; ++i) { - ReduceType::reduce(reduced, reduced, &in[idx_in + i * InDims::HW]); + ReduceType::reduce(reduced, &in[idx_in + i * InDims::HW]); } ReduceType::postReduce(&out[idx_out], reduced, InShape::C); } @@ -427,6 +430,7 @@ template { using DataType = typename ReduceType::DataType; + using CompType = typename ReduceType::CompType; static const int NelemPerThread = ReduceType::NelemPerThread; static DEVICE void compute(DataType *out, DataType *in, int idx_n, @@ -434,12 +438,12 @@ struct EwiseReduceCompType { int idx_out = idx_n * OutDims::CHW + idx_c * OutDims::HW + idx_w; int idx_in = idx_n * InDims::CHW + idx_c * InDims::HW + idx_w; - DataType reduced[NelemPerThread]; + CompType reduced[NelemPerThread]; ReduceType::identity(reduced); #pragma unroll for (int i = 0; i < InShape::H; ++i) { - ReduceType::reduce(reduced, reduced, &in[idx_in + i * InDims::W]); + ReduceType::reduce(reduced, &in[idx_in + i * InDims::W]); } ReduceType::postReduce(&out[idx_out], reduced, InShape::H); } @@ -451,6 +455,7 @@ template { using DataType = typename ReduceType::DataType; + using CompType = typename ReduceType::CompType; static const int NelemPerThread = ReduceType::NelemPerThread; static DEVICE void compute(DataType *out, DataType *in, int idx_n, @@ -460,21 +465,20 @@ struct EwiseReduceCompType idx_n * OutDims::CHW + idx_c * OutDims::HW + idx_h * OutDims::W; int idx_in = idx_n * InDims::CHW + idx_c * InDims::HW + idx_h * InDims::W; - DataType reduced[NelemPerThread]; + CompType reduced[NelemPerThread]; ReduceType::identity(reduced); #pragma unroll for (int i = 0; i < InShape::W; ++i) { - ReduceType::reduce(reduced, reduced, &in[idx_in + i]); + ReduceType::reduce(reduced, &in[idx_in + i]); } - DataType finalSum; - ReduceType::singleIdentity(&finalSum); + CompType finalSum = ReduceType::singleIdentity(); #pragma unroll for (int i = 0; i < NelemPerThread; ++i) { - ReduceType::singleReduce(&finalSum, &finalSum, &reduced[i]); + ReduceType::singleReduce(finalSum, reduced[i]); } - ReduceType::singlePostReduce(&out[idx_out], &finalSum, InShape::W); + ReduceType::singlePostReduce(out[idx_out], finalSum, InShape::W); } }; @@ -487,6 +491,7 @@ struct EwiseReduce using UnitOp = UnitOp; using DataType = typename ReduceType::DataType; + using CompType = typename ReduceType::CompType; static const int NelemPerThread = ReduceType::NelemPerThread; static_assert(NelemPerThread > 0, "NelemPerThread must be positive"); @@ -521,6 +526,7 @@ struct WwiseReduce using UnitOp = UnitOp; using DataType = typename ReduceType::DataType; + using CompType = typename ReduceType::CompType; static const int NelemPerThread = ReduceType::NelemPerThread; static_assert(NelemPerThread > 0, "NelemPerThread must be positive"); @@ -570,21 +576,19 @@ struct WwiseReduce (tid_c + uc * UnitOutDims::C) * InDims::HW + (tid_n + un * UnitOutDims::N) * InDims::CHW; - DataType reduced[NelemPerThread]; + CompType reduced[NelemPerThread]; ReduceType::identity(reduced); #pragma unroll for (int idx_in_w = tid_w; idx_in_w < InShape::W; idx_in_w += ThreadsPerRow) { - int idx_in = idx_in_base + idx_in_w; - ReduceType::reduce(reduced, reduced, &in[idx_in]); + ReduceType::reduce(reduced, &in[idx_in_base + idx_in_w]); } - DataType finalSum; - ReduceType::singleIdentity(&finalSum); + CompType finalSum = ReduceType::singleIdentity(); #pragma unroll for (int i = 0; i < NelemPerThread; ++i) { - ReduceType::singleReduce(&finalSum, &finalSum, &reduced[i]); + ReduceType::singleReduce(finalSum, reduced[i]); } UnitOp::sync_threads(); @@ -595,7 +599,7 @@ struct WwiseReduce // write the result to output. if (tid % ThreadsPerRow == 0) { - ReduceType::singlePostReduce(&out[idx_out], &finalSum, InShape::W); + ReduceType::singlePostReduce(out[idx_out], finalSum, InShape::W); } } }; @@ -606,7 +610,8 @@ template , Axis>::run(out, in, uop_idx); + SmemBytes, ReduceTypeSum, Axis>::run(out, in, + uop_idx); } template , Axis>::run(out, in, - uop_idx); + SmemBytes, ReduceTypeSum, Axis>::run(out, in, + uop_idx); } template , Axis>::run(out, in, - uop_idx); + SmemBytes, ReduceTypeMean, Axis>::run(out, in, + uop_idx); } template , Axis>::run(out, in, - uop_idx); + SmemBytes, ReduceTypeMean, Axis>::run(out, in, + uop_idx); } template , Axis>::run(out, in, uop_idx); + SmemBytes, ReduceTypeMax, Axis>::run(out, in, + uop_idx); } template , Axis>::run(out, in, - uop_idx); + SmemBytes, ReduceTypeMax, Axis>::run(out, in, + uop_idx); } template , Axis>::runW(out, in, uop_idx, - smem_per_warp); + SmemBytes, ReduceTypeSum, + Axis>::runW(out, in, uop_idx, smem_per_warp); } template , Axis>::runW(out, in, - uop_idx, - smem_per_warp); + SmemBytes, ReduceTypeSum, + Axis>::runW(out, in, uop_idx, smem_per_warp); } template , Axis>::runW(out, in, - uop_idx, - smem_per_warp); + SmemBytes, ReduceTypeMean, + Axis>::runW(out, in, uop_idx, smem_per_warp); } template , Axis>::runW(out, in, - uop_idx, - smem_per_warp); + SmemBytes, ReduceTypeMean, + Axis>::runW(out, in, uop_idx, smem_per_warp); } template , Axis>::runW(out, in, uop_idx, - smem_per_warp); + SmemBytes, ReduceTypeMax, + Axis>::runW(out, in, uop_idx, smem_per_warp); } template , Axis>::runW(out, in, - uop_idx, - smem_per_warp); + SmemBytes, ReduceTypeMax, + Axis>::runW(out, in, uop_idx, smem_per_warp); } } // namespace ark diff --git a/ark/include/kernels/softmax.h b/ark/include/kernels/softmax.h index 32ec3b875..d39f24840 100644 --- a/ark/include/kernels/softmax.h +++ b/ark/include/kernels/softmax.h @@ -4,6 +4,7 @@ #ifndef ARK_KERNELS_SOFTMAX_H_ #define ARK_KERNELS_SOFTMAX_H_ +#include "math_functions.h" #include "reduce.h" namespace ark { @@ -24,7 +25,7 @@ template struct SoftmaxShapeChecker // Perform layer normalization on input and write the result on output. template + int SmemBytes, typename DataType, typename CompType, int NelemPerThread> struct Softmax { using UnitOp = @@ -39,8 +40,8 @@ struct Softmax int smem_per_warp) { using InOutChk = SoftmaxShapeChecker; - using ReduceTypeMax = ReduceTypeMax; - using ReduceTypeSum = ReduceTypeSum; + using ReduceTypeMax = ReduceTypeMax; + using ReduceTypeSum = ReduceTypeSum; constexpr int NonReduceDimLength = UnitOutDims::NCH; // The reduction dimension of the final stage. @@ -69,40 +70,36 @@ struct Softmax (tid_n + un * UnitOutDims::N) * InDims::CHW; // get the max input. - DataType max_input; - ReduceTypeMax::singleIdentity(&max_input); + CompType max_input = ReduceTypeMax::singleIdentity(); for (int idx_in_w = tid_w; idx_in_w < InShape::W; idx_in_w += ThreadsPerRow) { - int idx_in = idx_in_base + idx_in_w; - ReduceTypeMax::singleReduce(&max_input, &max_input, &in[idx_in]); + ReduceTypeMax::singleReduce(max_input, in[idx_in_base + idx_in_w]); } UnitOp::sync_threads(); // final reduction on shared memory using warp shuffle. max_input = warpsReduce( max_input, tid, smem_per_warp); - // get the max input. - ReduceTypeMax::singlePostReduce(&max_input, &max_input, UnitOutDims::W); // get the exp input sum, use float to avoid overflow. - DataType exp_sum_input; - ReduceTypeSum::singleIdentity(&exp_sum_input); + CompType exp_sum_input = ReduceTypeSum::singleIdentity(); UnitOp::sync_threads(); for (int idx_in_w = tid_w; idx_in_w < InShape::W; idx_in_w += ThreadsPerRow) { int idx_in = idx_in_base + idx_in_w; - exp_sum_input = exp_sum_input + expf(in[idx_in] - max_input); + CompType data = static_cast(in[idx_in]); + exp_sum_input += Exp::compute(data - max_input); } UnitOp::sync_threads(); exp_sum_input = warpsReduce( exp_sum_input, tid, smem_per_warp); - ReduceTypeSum::singlePostReduce(&exp_sum_input, &exp_sum_input); UnitOp::sync_threads(); // the output is for (int idx_in_w = tid_w; idx_in_w < InShape::W; idx_in_w += ThreadsPerRow) { int idx_in = idx_in_base + idx_in_w; - out[idx_in] = expf(in[idx_in] - max_input) / exp_sum_input; + CompType data = static_cast(in[idx_in]); + out[idx_in] = Exp::compute(data - max_input) / exp_sum_input; } } }; @@ -114,7 +111,7 @@ DEVICE void softmax(float *out, float *in, int uop_idx, int smem_per_warp) { constexpr int NelemPerThread = 1; Softmax::run(out, in, uop_idx, + SmemBytes, float, float, NelemPerThread>::run(out, in, uop_idx, smem_per_warp); } @@ -126,7 +123,7 @@ DEVICE void softmax(ark::half *out, ark::half *in, int uop_idx, { constexpr int NelemPerThread = 1; Softmax::run(out, in, uop_idx, + SmemBytes, ark::half, float, NelemPerThread>::run(out, in, uop_idx, smem_per_warp); } diff --git a/ark/ops/ops_reduce_test.cc b/ark/ops/ops_reduce_test.cc index df63ea13f..76a046d4d 100644 --- a/ark/ops/ops_reduce_test.cc +++ b/ark/ops/ops_reduce_test.cc @@ -6,6 +6,7 @@ #include "ops_test_common.h" #include "unittest/unittest_utils.h" #include +#include template void baseline_reduce_sum_axis0(std::vector &outputs, @@ -24,12 +25,69 @@ void baseline_reduce_sum_axis0(std::vector &outputs, for (ark::DimType c = 0; c < ish[1]; ++c) { for (ark::DimType h = 0; h < ish[2]; ++h) { for (ark::DimType w = 0; w < ish[3]; ++w) { - T sum = 0; + float sum = 0; for (ark::DimType n = 0; n < ish[0]; ++n) { - sum += input[n * ish[1] * ish[2] * ish[3] + - c * ish[2] * ish[3] + h * ish[3] + w]; + sum += float(input[n * ish[1] * ish[2] * ish[3] + + c * ish[2] * ish[3] + h * ish[3] + w]); } - out[c * osh[2] * osh[3] + h * osh[3] + w] = sum; + out[c * osh[2] * osh[3] + h * osh[3] + w] = T(sum); + } + } + } +} + +template +void baseline_reduce_mean_axis0(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &input_shapes) +{ + T *out = static_cast(outputs[0]); + T *input = static_cast(inputs[0]); + + ark::Dims osh = output_shapes[0].dims4(); + ark::Dims ish = input_shapes[0].dims4(); + + assert(osh[0] == 1); + + for (ark::DimType c = 0; c < ish[1]; ++c) { + for (ark::DimType h = 0; h < ish[2]; ++h) { + for (ark::DimType w = 0; w < ish[3]; ++w) { + float sum = 0; + for (ark::DimType n = 0; n < ish[0]; ++n) { + sum += float(input[n * ish[1] * ish[2] * ish[3] + + c * ish[2] * ish[3] + h * ish[3] + w]); + } + out[c * osh[2] * osh[3] + h * osh[3] + w] = T(sum / ish[0]); + } + } + } +} + +template +void baseline_reduce_max_axis0(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &input_shapes) +{ + T *out = static_cast(outputs[0]); + T *input = static_cast(inputs[0]); + + ark::Dims osh = output_shapes[0].dims4(); + ark::Dims ish = input_shapes[0].dims4(); + + assert(osh[0] == 1); + + for (ark::DimType c = 0; c < ish[1]; ++c) { + for (ark::DimType h = 0; h < ish[2]; ++h) { + for (ark::DimType w = 0; w < ish[3]; ++w) { + float red = std::numeric_limits::lowest(); + for (ark::DimType n = 0; n < ish[0]; ++n) { + float val = float(input[n * ish[1] * ish[2] * ish[3] + + c * ish[2] * ish[3] + h * ish[3] + w]); + red = (val > red) ? val : red; + } + out[c * osh[2] * osh[3] + h * osh[3] + w] = T(red); } } } @@ -52,12 +110,12 @@ void baseline_reduce_sum_axis1(std::vector &outputs, for (ark::DimType n = 0; n < ish[0]; ++n) { for (ark::DimType h = 0; h < ish[2]; ++h) { for (ark::DimType w = 0; w < ish[3]; ++w) { - T sum = 0; + float sum = 0; for (ark::DimType c = 0; c < ish[1]; ++c) { - sum += input[n * ish[1] * ish[2] * ish[3] + - c * ish[2] * ish[3] + h * ish[3] + w]; + sum += float(input[n * ish[1] * ish[2] * ish[3] + + c * ish[2] * ish[3] + h * ish[3] + w]); } - out[n * osh[1] * osh[2] * osh[3] + h * osh[3] + w] = sum; + out[n * osh[1] * osh[2] * osh[3] + h * osh[3] + w] = T(sum); } } } @@ -80,13 +138,13 @@ void baseline_reduce_sum_axis2(std::vector &outputs, for (ark::DimType n = 0; n < ish[0]; ++n) { for (ark::DimType c = 0; c < ish[1]; ++c) { for (ark::DimType w = 0; w < ish[3]; ++w) { - T sum = 0; + float sum = 0; for (ark::DimType h = 0; h < ish[2]; ++h) { - sum += input[n * ish[1] * ish[2] * ish[3] + - c * ish[2] * ish[3] + h * ish[3] + w]; + sum += float(input[n * ish[1] * ish[2] * ish[3] + + c * ish[2] * ish[3] + h * ish[3] + w]); } out[n * osh[1] * osh[2] * osh[3] + c * osh[2] * osh[3] + w] = - sum; + T(sum); } } } @@ -109,67 +167,126 @@ void baseline_reduce_sum_axis3(std::vector &outputs, for (ark::DimType n = 0; n < ish[0]; ++n) { for (ark::DimType c = 0; c < ish[1]; ++c) { for (ark::DimType h = 0; h < ish[2]; ++h) { - T sum = 0; + float sum = 0; + for (ark::DimType w = 0; w < ish[3]; ++w) { + sum += float(input[n * ish[1] * ish[2] * ish[3] + + c * ish[2] * ish[3] + h * ish[3] + w]); + } + out[n * osh[1] * osh[2] * osh[3] + c * osh[2] * osh[3] + + h * osh[3]] = T(sum); + } + } + } +}; + +template +void baseline_reduce_mean_axis3(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &input_shapes) +{ + T *out = static_cast(outputs[0]); + T *input = static_cast(inputs[0]); + + ark::Dims osh = output_shapes[0].dims4(); + ark::Dims ish = input_shapes[0].dims4(); + + assert(osh[3] == 1); + + for (ark::DimType n = 0; n < ish[0]; ++n) { + for (ark::DimType c = 0; c < ish[1]; ++c) { + for (ark::DimType h = 0; h < ish[2]; ++h) { + float sum = 0; for (ark::DimType w = 0; w < ish[3]; ++w) { - sum += input[n * ish[1] * ish[2] * ish[3] + - c * ish[2] * ish[3] + h * ish[3] + w]; + sum += float(input[n * ish[1] * ish[2] * ish[3] + + c * ish[2] * ish[3] + h * ish[3] + w]); } out[n * osh[1] * osh[2] * osh[3] + c * osh[2] * osh[3] + - h * osh[3]] = sum; + h * osh[3]] = T(sum / ish[3]); } } } }; -ark::unittest::State test_reduce_axis0() +template +void baseline_reduce_max_axis3(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &input_shapes) +{ + T *out = static_cast(outputs[0]); + T *input = static_cast(inputs[0]); + + ark::Dims osh = output_shapes[0].dims4(); + ark::Dims ish = input_shapes[0].dims4(); + + assert(osh[3] == 1); + + for (ark::DimType n = 0; n < ish[0]; ++n) { + for (ark::DimType c = 0; c < ish[1]; ++c) { + for (ark::DimType h = 0; h < ish[2]; ++h) { + float red = std::numeric_limits::lowest(); + for (ark::DimType w = 0; w < ish[3]; ++w) { + float val = float(input[n * ish[1] * ish[2] * ish[3] + + c * ish[2] * ish[3] + h * ish[3] + w]); + red = (val > red) ? val : red; + } + out[n * osh[1] * osh[2] * osh[3] + c * osh[2] * osh[3] + + h * osh[3]] = T(red); + } + } + } +}; + +ark::unittest::State test_reduce_sum_axis0() { ark::Model m; ark::Tensor *t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::FP32); ark::Tensor *out = m.reduce_sum(t, /*axis=*/0); - auto result = ark::op_test("reduce_axis0", m, {t}, {out}, + auto result = ark::op_test("reduce_sum_axis0", m, {t}, {out}, baseline_reduce_sum_axis0); ark::op_test_log(result); return ark::unittest::SUCCESS; } -ark::unittest::State test_reduce_axis1() +ark::unittest::State test_reduce_sum_axis1() { ark::Model m; ark::Tensor *t = m.tensor(ark::Dims(1, 2, 4, 1024), ark::FP32); ark::Tensor *out = m.reduce_sum(t, /*axis=*/1); - auto result = ark::op_test("reduce_axis1", m, {t}, {out}, + auto result = ark::op_test("reduce_sum_axis1", m, {t}, {out}, baseline_reduce_sum_axis1); ark::op_test_log(result); return ark::unittest::SUCCESS; } -ark::unittest::State test_reduce_axis2() +ark::unittest::State test_reduce_sum_axis2() { ark::Model m; ark::Tensor *t = m.tensor(ark::Dims(1, 1, 7, 8192), ark::FP32); ark::Tensor *out = m.reduce_sum(t, /*axis=*/2); - auto result = ark::op_test("reduce_axis2", m, {t}, {out}, + auto result = ark::op_test("reduce_sum_axis2", m, {t}, {out}, baseline_reduce_sum_axis2); ark::op_test_log(result); return ark::unittest::SUCCESS; } -ark::unittest::State test_reduce_axis3() +ark::unittest::State test_reduce_sum_axis3() { ark::Model m; ark::Tensor *t = m.tensor(ark::Dims(1, 1, 2, 8192), ark::FP32); ark::Tensor *out = m.reduce_sum(t, /*axis=*/3); - auto result = ark::op_test("reduce_axis3", m, {t}, {out}, + auto result = ark::op_test("reduce_sum_axis3", m, {t}, {out}, baseline_reduce_sum_axis3); ark::op_test_log(result); return ark::unittest::SUCCESS; } -ark::unittest::State test_reduce_axis3_padded() +ark::unittest::State test_reduce_sum_axis3_padded() { ark::Model m; ark::Tensor *t = m.tensor(ark::Dims(1, 1, 2, 8192), ark::FP32); @@ -177,43 +294,91 @@ ark::unittest::State test_reduce_axis3_padded() ark::Dims(1, 1, 2, 32)); out = m.reduce_sum(t, /*axis=*/3, out); - auto result = ark::op_test("reduce_axis3_padded", m, {t}, {out}, + auto result = ark::op_test("reduce_sum_axis3_padded", m, {t}, {out}, baseline_reduce_sum_axis3); ark::op_test_log(result); return ark::unittest::SUCCESS; } -ark::unittest::State test_reduce_fp16() +ark::unittest::State test_reduce_sum_fp16() { { ark::Model m; - ark::Tensor *t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::FP16); + ark::Tensor *t = m.tensor(ark::Dims(7, 2, 4, 8192), ark::FP16); ark::Tensor *out = m.reduce_sum(t, /*axis=*/0); - auto result = ark::op_test("reduce_fp16_axis0", m, {t}, {out}, + auto result = ark::op_test("reduce_sum_fp16_axis0", m, {t}, {out}, baseline_reduce_sum_axis0); ark::op_test_log(result); } { ark::Model m; - ark::Tensor *t = m.tensor(ark::Dims(7, 2, 4, 1024), ark::FP16); + ark::Tensor *t = m.tensor(ark::Dims(7, 2, 4, 8192), ark::FP16); ark::Tensor *out = m.reduce_sum(t, /*axis=*/3); - auto result = ark::op_test("reduce_fp16_axis3", m, {t}, {out}, + auto result = ark::op_test("reduce_sum_fp16_axis3", m, {t}, {out}, baseline_reduce_sum_axis3); ark::op_test_log(result); } return ark::unittest::SUCCESS; } +ark::unittest::State test_reduce_mean_fp16() +{ + { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(7, 2, 4, 8192), ark::FP16); + ark::Tensor *out = m.reduce_mean(t, /*axis=*/0); + + auto result = ark::op_test("reduce_mean_fp16_axis0", m, {t}, {out}, + baseline_reduce_mean_axis0); + ark::op_test_log(result); + } + { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(7, 2, 4, 8192), ark::FP16); + ark::Tensor *out = m.reduce_mean(t, /*axis=*/3); + + auto result = ark::op_test("reduce_mean_fp16_axis3", m, {t}, {out}, + baseline_reduce_mean_axis3); + ark::op_test_log(result); + } + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_reduce_max_fp16() +{ + { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(7, 2, 4, 8192), ark::FP16); + ark::Tensor *out = m.reduce_max(t, /*axis=*/0); + + auto result = ark::op_test("reduce_max_fp16_axis0", m, {t}, {out}, + baseline_reduce_max_axis0); + ark::op_test_log(result); + } + { + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(7, 2, 4, 8192), ark::FP16); + ark::Tensor *out = m.reduce_max(t, /*axis=*/3); + + auto result = ark::op_test("reduce_max_fp16_axis3", m, {t}, {out}, + baseline_reduce_max_axis3); + ark::op_test_log(result); + } + return ark::unittest::SUCCESS; +} + int main() { ark::init(); - UNITTEST(test_reduce_axis0); - UNITTEST(test_reduce_axis1); - UNITTEST(test_reduce_axis2); - UNITTEST(test_reduce_axis3); - UNITTEST(test_reduce_axis3_padded); - UNITTEST(test_reduce_fp16); + UNITTEST(test_reduce_sum_axis0); + UNITTEST(test_reduce_sum_axis1); + UNITTEST(test_reduce_sum_axis2); + UNITTEST(test_reduce_sum_axis3); + UNITTEST(test_reduce_sum_axis3_padded); + UNITTEST(test_reduce_sum_fp16); + UNITTEST(test_reduce_mean_fp16); + UNITTEST(test_reduce_max_fp16); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_rmsnorm_test.cc b/ark/ops/ops_rmsnorm_test.cc index 6eff74646..b3bc2edff 100644 --- a/ark/ops/ops_rmsnorm_test.cc +++ b/ark/ops/ops_rmsnorm_test.cc @@ -24,21 +24,21 @@ void baseline_rmsnorm(std::vector &outputs, for (ark::DimType n = 0; n < ish[0]; ++n) { for (ark::DimType c = 0; c < ish[1]; ++c) { for (ark::DimType h = 0; h < ish[2]; ++h) { - T square_sum = 0; + float square_sum = 0; for (ark::DimType w = 0; w < ish[3]; ++w) { - - T val = input[n * ish[1] * ish[2] * ish[3] + - c * ish[2] * ish[3] + h * ish[3] + w]; + float val = + float(input[n * ish[1] * ish[2] * ish[3] + + c * ish[2] * ish[3] + h * ish[3] + w]); square_sum += val * val; } - T eps = 1e-5; - T rms = (T)sqrt((float)square_sum / ish[3]) + eps; + float eps = 1e-5; + float rms = std::sqrt(square_sum / ish[3]) + eps; for (ark::DimType w = 0; w < ish[3]; ++w) { out[n * osh[1] * osh[2] * osh[3] + c * osh[2] * osh[3] + h * osh[3] + w] = - input[n * osh[1] * osh[2] * osh[3] + - c * osh[2] * osh[3] + h * osh[3] + w] / - rms; + T(float(input[n * osh[1] * osh[2] * osh[3] + + c * osh[2] * osh[3] + h * osh[3] + w]) / + rms); } } } @@ -48,10 +48,10 @@ void baseline_rmsnorm(std::vector &outputs, ark::unittest::State test_rmsnorm_fp32() { ark::Model m; - ark::Tensor *t = m.tensor(ark::Dims(1, 32, 32, 256), ark::FP32); + ark::Tensor *t = m.tensor(ark::Dims(1, 8192), ark::FP32); ark::Tensor *out = m.rmsnorm(t); auto result = - ark::op_test("rmsnorm", m, {t}, {out}, baseline_rmsnorm); + ark::op_test("rmsnorm_fp32", m, {t}, {out}, baseline_rmsnorm); ark::op_test_log(result); return ark::unittest::SUCCESS; } @@ -59,18 +59,93 @@ ark::unittest::State test_rmsnorm_fp32() ark::unittest::State test_rmsnorm_fp16() { ark::Model model; - ark::Tensor *input = model.tensor(ark::Dims(1, 32, 32, 256), ark::FP16); + ark::Tensor *input = model.tensor(ark::Dims(1, 8192), ark::FP16); ark::Tensor *output = model.rmsnorm(input); - auto result = ark::op_test("rmsnorm", model, {input}, {output}, + + // std::vector data; + // for (int i = 0; i < 8192; ++i) { + // data.push_back(ark::half_t(8.0f)); + // } + auto result = ark::op_test("rmsnorm_fp16", model, {input}, {output}, baseline_rmsnorm); ark::op_test_log(result); return ark::unittest::SUCCESS; } +ark::unittest::State test_rmsnorm_compare() +{ + ark::srand(); + + ark::Dims shape(2048, 16384); + auto input_data_fp16 = ark::utils::rand_halfs(shape.size(), 0.1); + auto input_data_fp32 = std::unique_ptr(new float[shape.size()]); + for (int i = 0; i < shape.size(); ++i) { + input_data_fp32[i] = float(input_data_fp16[i]); + } + + std::vector output_fp32(shape.size()); + std::vector output_fp16_tmp(shape.size()); + std::vector output_fp16(shape.size()); + + std::string test_name = "rmsnorm_compare"; + int num_warps_per_sm = 16; + + { + ark::Model model_fp32; + ark::Tensor *input = model_fp32.tensor(shape, ark::FP32); + ark::Tensor *output = model_fp32.rmsnorm(input); + + ark::Executor exe{0, 1, model_fp32, test_name, num_warps_per_sm}; + exe.compile(); + + input->write(input_data_fp32.get()); + + exe.launch(); + exe.run(1); + exe.stop(); + + output->read(output_fp32.data()); + } + + { + ark::Model model_fp16; + ark::Tensor *input = model_fp16.tensor(shape, ark::FP16); + ark::Tensor *output = model_fp16.rmsnorm(input); + + ark::Executor exe{0, 1, model_fp16, test_name, num_warps_per_sm}; + exe.compile(); + + input->write(input_data_fp16.get()); + + exe.launch(); + exe.run(1); + exe.stop(); + + output->read(output_fp16_tmp.data()); + } + + for (ark::DimType i = 0; i < shape.size(); ++i) { + output_fp16[i] = float(output_fp16_tmp[i]); + } + + auto comp = ark::tensor_compare(output_fp32.data(), output_fp16.data(), shape); + + ark::OpsTestResult result; + result.test_name = test_name; + result.num_warps_per_sm = num_warps_per_sm; + result.mse.push_back(comp.mse); + result.max_diff.push_back(comp.max_diff); + result.max_err_rate.push_back(comp.max_error_rate); + + ark::op_test_log(result); + return ark::unittest::SUCCESS; +} + int main() { ark::init(); - UNITTEST(test_rmsnorm_fp32); - UNITTEST(test_rmsnorm_fp16); + // UNITTEST(test_rmsnorm_fp32); + // UNITTEST(test_rmsnorm_fp16); + UNITTEST(test_rmsnorm_compare); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_softmax_test.cc b/ark/ops/ops_softmax_test.cc index ced9b5434..947c8227f 100644 --- a/ark/ops/ops_softmax_test.cc +++ b/ark/ops/ops_softmax_test.cc @@ -22,17 +22,24 @@ void baseline_softmax(std::vector &outputs, for (ark::DimType n = 0; n < ish[0]; ++n) { for (ark::DimType c = 0; c < ish[1]; ++c) { for (ark::DimType h = 0; h < ish[2]; ++h) { - T sum = 0; + float maxval = std::numeric_limits::lowest(); for (ark::DimType w = 0; w < ish[3]; ++w) { - sum += std::exp(input[w + h * ish[3] + c * ish[2] * ish[3] + - n * ish[1] * ish[2] * ish[3]]); + float val = float(input[w + h * ish[3] + c * ish[2] * ish[3] + + n * ish[1] * ish[2] * ish[3]]); + if (val > maxval) { + maxval = val; + } + } + float sum = 0; + std::vector exps(ish[3]); + for (ark::DimType w = 0; w < ish[3]; ++w) { + exps[w] = std::exp(float(input[w + h * ish[3] + c * ish[2] * ish[3] + + n * ish[1] * ish[2] * ish[3]]) - maxval); + sum += exps[w]; } for (ark::DimType w = 0; w < ish[3]; ++w) { out[w + h * osh[3] + c * osh[2] * osh[3] + - n * osh[1] * osh[2] * osh[3]] = - std::exp(input[w + h * ish[3] + c * ish[2] * ish[3] + - n * ish[1] * ish[2] * ish[3]]) / - sum; + n * osh[1] * osh[2] * osh[3]] = T(exps[w] / sum); } } } @@ -42,11 +49,23 @@ void baseline_softmax(std::vector &outputs, ark::unittest::State test_softmax_fp32() { ark::Model m; - ark::Tensor *t = m.tensor(ark::Dims(2, 8192), ark::FP32); + ark::Tensor *t = m.tensor(ark::Dims(64, 8192), ark::FP32); + ark::Tensor *out = m.softmax(t); + + auto result = + ark::op_test("softmax_fp32", m, {t}, {out}, baseline_softmax); + ark::op_test_log(result); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_softmax_fp16() +{ + ark::Model m; + ark::Tensor *t = m.tensor(ark::Dims(64, 8192), ark::FP16); ark::Tensor *out = m.softmax(t); auto result = - ark::op_test("reduce_axis3", m, {t}, {out}, baseline_softmax); + ark::op_test("softmax_fp16", m, {t}, {out}, baseline_softmax); ark::op_test_log(result); return ark::unittest::SUCCESS; } @@ -55,5 +74,6 @@ int main() { ark::init(); UNITTEST(test_softmax_fp32); + UNITTEST(test_softmax_fp16); return ark::unittest::SUCCESS; } diff --git a/examples/llama/llama b/examples/llama/llama new file mode 160000 index 000000000..9f0e39399 --- /dev/null +++ b/examples/llama/llama @@ -0,0 +1 @@ +Subproject commit 9f0e393991b45d320f5b4a287eaaeb8a7d2e6f8e diff --git a/examples/llama/model.py b/examples/llama/model.py new file mode 100644 index 000000000..3ea616774 --- /dev/null +++ b/examples/llama/model.py @@ -0,0 +1,485 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import ark +import math +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = ( + 256 # make SwiGLU hidden layer size multiple of large power of 2 + ) + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +@dataclass +class ModelArgs7B(ModelArgs): + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = ( + 256 # make SwiGLU hidden layer size multiple of large power of 2 + ) + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +@dataclass +class ModelArgs13B(ModelArgs): + dim: int = 5120 + n_layers: int = 40 + n_heads: int = 40 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = ( + 256 # make SwiGLU hidden layer size multiple of large power of 2 + ) + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +@dataclass +class ModelArgs70B(ModelArgs): + dim: int = 8192 + n_layers: int = 80 + n_heads: int = 64 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 + multiple_of: int = ( + 256 # make SwiGLU hidden layer size multiple of large power of 2 + ) + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + max_batch_size: int = 32 + max_seq_len: int = 4096 + + +class RMSNorm(ark.Module): + """ + Root mean square layer normalization (RMSNorm). + """ + + def __init__( + self, dim: int, eps: float = 1e-6, dtype: ark.DataType = ark.fp16 + ): + super().__init__() + self.eps = eps + self.dtype = dtype + self.weight = ark.parameter([dim], dtype) + + def _norm(self, x): + # x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + return ark.rmsnorm(x) + + def forward(self, x): + output = self._norm(x) + return ark.mul(output, ark.reshape(self.weight, [1, 1, -1])) + + +class ColumnParallelLinear(ark.Module): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + Here the weight = A^T, so we need to partition the weight matrix along + its first dimension. + + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dtype: ark.DataType = ark.fp16, + local_rank: int = 0, + world_size: int = 1, + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.dtype = dtype + self.local_rank = local_rank + self.world_size = world_size + + self.weight = ark.parameter([out_dim // world_size, in_dim], dtype) + + def forward(self, x): + if self.world_size == 1: + return ark.matmul(x, self.weight, transpose_other=True) + # (batch_size, seq_len, out_dim // world_size) + output_tensor_shard = ark.matmul(x, self.weight, transpose_other=True) + all_gather_tensor_shards = ark.all_gather( + output_tensor_shard, self.local_rank, self.world_size + ) + # We need to concat the output_tensor_shards along the last dimension + assert len(all_gather_tensor_shards) == self.world_size + output_tensor = ark.tensor( + [x.shape[0], x.shape[1], self.out_dim], self.dtype + ) + output_tensor_shards = ark.sharding( + output_tensor, 2, self.out_dim // self.world_size + ) + output_dependency = [] + # Copy all the all_gather_tensor_shards to output_tensor_shards + for i in range(self.world_size): + output_tensor_shard = ark.scale( + all_gather_tensor_shards[i], 1.0, output_tensor_shards[i] + ) + output_dependency.append(output_tensor_shard) + # The output_tensor should depend on the scale operators + output_tensor = ark.identity(output_tensor, output_dependency) + return output_tensor + + +class RowParallelLinear(ark.Module): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + + Here the weight = A^T, so we need to partition the weight matrix along + its second dimension. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dtype: ark.DataType = ark.fp16, + local_rank: int = 0, + world_size: int = 1, + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.dtype = dtype + self.local_rank = local_rank + self.world_size = world_size + + self.weight = ark.parameter([out_dim, in_dim // world_size], dtype) + + def forward(self, x): + if self.world_size == 1: + return ark.matmul(x, self.weight, transpose_other=True) + x_ndims = len(x.shape) + x_shards = ark.sharding(x, x_ndims - 1, self.in_dim // self.world_size) + output_parallel = ark.matmul( + x_shards[self.local_rank], self.weight, transpose_other=True + ) + # allreduce the output_parallel, currently we only support allreduce on 1D tensor, + # so we need to reshape the output_parallel to 1D + output_shape = output_parallel.shape + # multiply the output_shape list + output_shape_bytes = 1 + for i in range(len(output_shape)): + output_shape_bytes *= output_shape[i] + output_parallel_reshape = ark.reshape( + output_parallel, + [output_shape_bytes], + ) + output_reshape = ark.all_reduce( + output_parallel_reshape, self.local_rank, self.world_size + ) + output = ark.reshape(output_reshape, output_shape) + return output + + +class ParallelEmbedding(ark.Module): + """Embedding layer.""" + + # TODO: support parallelism + def __init__(self, vocab_size: int, dim: int, dtype: ark.DataType): + super().__init__() + self.vocab_size = vocab_size + self.dim = dim + self.weight = ark.parameter([vocab_size, dim], dtype) + + def forward(self, x): + return ark.embedding(x, self.weight) + + +class Linear(ark.Module): + """ + Linear layer module with weights and no bias. + """ + + def __init__( + self, in_dim: int, out_dim: int, dtype: ark.DataType = ark.fp16 + ): + super().__init__() + self.dtype = dtype + self.weight = ark.parameter([out_dim, in_dim], dtype) + + def forward(self, x): + return ark.matmul(x, self.weight, transpose_other=True) + + +class Silu(ark.Module): + """ + Silu activation function, silu(x) = x * sigmoid(x) + """ + + def __init__(self): + super().__init__() + + def forward(self, x: ark.Tensor): + # We need to specify output tensor so that the sigmoid op will not be an in-place operator + output = ark.tensor(x.shape(), x.dtype()) + x1 = ark.sigmoid(x, output) + return ark.mul(x, x1) + + +class FeedForward(ark.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + dtype: ark.DataType = ark.fp16, + local_rank: int = 0, + world_size: int = 1, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ( + (hidden_dim + multiple_of - 1) // multiple_of + ) + + self.w1 = ColumnParallelLinear( + dim, hidden_dim, dtype, local_rank, world_size + ) + self.w2 = RowParallelLinear( + hidden_dim, dim, dtype, local_rank, world_size + ) + self.w3 = ColumnParallelLinear( + dim, hidden_dim, dtype, local_rank, world_size + ) + + def forward(self, x): + # self.w2(F.silu(self.w1(x)) * self.w3(x)) + x1 = self.w1(x) + x1 = Silu()(x1) + x2 = self.w3(x) + x3 = ark.mul(x1, x2) + x4 = self.w2(x3) + return x4 + + +def apply_rotary_emb(xq, xk, freqs_cis): + """ + Apply rotary embeddings to xq and xk. + """ + xq_out = ark.rope(xq, freqs_cis) + xk_out = ark.rope(xk, freqs_cis) + return xq_out, xk_out + + +class Attention(ark.Module): + def __init__( + self, + args: ModelArgs, + dtype: ark.DataType = ark.fp16, + local_rank: int = 0, + world_size: int = 1, + ): + super().__init__() + self.n_kv_heads = ( + args.n_heads if args.n_kv_heads is None else args.n_kv_heads + ) + model_parallel_size = 1 + self.n_local_heads = args.n_heads // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + self.wq = ColumnParallelLinear( + args.dim, + args.n_heads * self.head_dim, + dtype, + local_rank, + world_size, + ) + self.wk = ColumnParallelLinear( + args.dim, + self.n_kv_heads * self.head_dim, + dtype, + local_rank, + world_size, + ) + self.wv = ColumnParallelLinear( + args.dim, + self.n_kv_heads * self.head_dim, + dtype, + local_rank, + world_size, + ) + self.wo = RowParallelLinear( + args.n_heads * self.head_dim, + args.dim, + dtype, + local_rank, + world_size, + ) + + def forward( + self, + x: ark.Tensor, + start_pos: int, + freqs_cis: ark.Tensor, + mask: Optional[ark.Tensor], + ): + bsz, seqlen, _ = x.shape() + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + # xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + # xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + # xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xq = ark.reshape(xq, [bsz, seqlen, self.n_local_heads, self.head_dim]) + xk = ark.reshape( + xk, [bsz, seqlen, self.n_local_kv_heads, self.head_dim] + ) + xv = ark.reshape( + xv, [bsz, seqlen, self.n_local_kv_heads, self.head_dim] + ) + if freqs_cis is not None: + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # TODO: enable kv cache and mask later + keys = xk + values = xv + # (bs, n_local_heads, seqlen, head_dim) + xq = ark.transpose(xq, [0, 2, 1, 3]) + keys = ark.transpose(keys, [0, 2, 1, 3]) + values = ark.transpose(values, [0, 2, 1, 3]) + + # (bs, n_local_heads, head_dim, seqlen) + keys_transpose = ark.transpose(keys, [0, 1, 3, 2]) + scores = ark.matmul(xq, keys_transpose) + scores = ark.scale(scores, 1.0 / math.sqrt(self.head_dim)) + + if mask is not None: + scores = ark.add(scores, mask) + scores = ark.softmax(scores) + + output = ark.matmul( + scores, values + ) # (bs, n_local_heads, seqlen, head_dim) + output = ark.transpose(output, [0, 2, 1, 3]) + output = ark.reshape( + output, [bsz, seqlen, self.head_dim * self.n_local_heads] + ) + output = self.wo(output) + return output + + +class TransformerBlock(ark.Module): + def __init__( + self, + layer_id: int, + args: ModelArgs, + dtype: ark.DataType = ark.fp16, + local_rank: int = 0, + world_size: int = 1, + ): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.attention = Attention(args, dtype, local_rank, world_size) + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + dtype=dtype, + local_rank=local_rank, + world_size=world_size, + ) + self.layer_id = layer_id + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps, dtype=dtype) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps, dtype=dtype) + + def forward( + self, + x: ark.Tensor, + start_pos: int, + freqs_cis: ark.Tensor, + mask: Optional[ark.Tensor], + ): + attention_norm_x = self.attention_norm(x) + h = self.attention.forward(attention_norm_x, start_pos, freqs_cis, mask) + h = ark.add(x, h) + out = ark.add(h, self.feed_forward(self.ffn_norm(h))) + return out + + +class Transformer(ark.Module): + def __init__( + self, + params: ModelArgs, + dtype: ark.DataType = ark.fp16, + local_rank: int = 0, + world_size: int = 1, + ): + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.n_layers + + self.tok_embeddings = ParallelEmbedding( + params.vocab_size, params.dim, dtype + ) + + self.layers = [] + for layer_id in range(self.n_layers): + self.layers.append( + TransformerBlock( + layer_id, params, dtype, local_rank, world_size + ) + ) + self.register_module(f"layers.{layer_id}", self.layers[layer_id]) + self.norm = RMSNorm(params.dim, eps=params.norm_eps, dtype=dtype) + self.output = ColumnParallelLinear( + params.dim, params.vocab_size, dtype, local_rank, world_size + ) + + def forward( + self, + tokens: ark.Tensor, + start_pos: int, + freqs_cis: ark.Tensor, + mask: Optional[ark.Tensor], + ): + h = self.tok_embeddings(tokens) + + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + h = self.norm(h) + output = self.output(h) + return output diff --git a/examples/llama/model_test.py b/examples/llama/model_test.py new file mode 100644 index 000000000..190a6b017 --- /dev/null +++ b/examples/llama/model_test.py @@ -0,0 +1,463 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import ark +import sys +import torch +import os +import time +import fairscale + +sys.path.append("llama") +import llama.model as model_pt +import model as model_ark +import numpy as np +from typing import Any, Dict, List +from model import ModelArgs, ModelArgs7B, ModelArgs13B, ModelArgs70B + + +pth_path: str = "/mnt/7B/consolidated.00.pth" + +numpy_dtype_to_torch_dtype: dict = { + np.float16: torch.float16, + np.float32: torch.float32, + np.int32: torch.int32, +} + + +def run_ark( + module: ark.Module, + state_dict: Dict[str, np.ndarray], + inputs: list = [], + rank: int = 0, + world_size: int = 1, +) -> List[np.ndarray]: + ark.set_rank(rank) + ark.set_world_size(world_size) + + module_inputs = [ + ark.tensor(list(i.shape), ark.DataType.from_numpy(i.dtype)) + if isinstance(i, np.ndarray) + else i + for i in inputs + ] + output = module(*module_inputs) + + runtime = ark.Runtime() + runtime.launch() + + # Load model parameters + module.load_state_dict(state_dict) + + # Load input data into tensors + tensors = [i for i in module_inputs if isinstance(i, ark.Tensor)] + tensor_data = [i for i in inputs if isinstance(i, np.ndarray)] + for tensor, ndarray in zip(tensors, tensor_data): + tensor.from_numpy(ndarray) + + # Run the model + runtime.run() + + if isinstance(output, list) or isinstance(output, tuple): + return [o.to_numpy() for o in output] + return [output.to_numpy()] + + +def run_pt( + module: torch.nn.Module, + state_dict: Dict[str, np.ndarray], + inputs: list = [], +) -> List[np.ndarray]: + # Update the current state_dict with the given one + cur_state_dict = module.state_dict() + for k, v in state_dict.items(): + cur_state_dict[k] = torch.from_numpy(v) + module.load_state_dict(cur_state_dict) + + # Load input data to GPU + input_tensors = [ + torch.from_numpy(i).to("cuda:0") if isinstance(i, np.ndarray) else i + for i in inputs + ] + + # Load the module to GPU + module = module.to("cuda:0") + + # Run the module + with torch.no_grad(): + output = module(*input_tensors) + + if isinstance(output, list) or isinstance(output, tuple): + return [o.detach().to("cpu").numpy() for o in output] + return [output.detach().to("cpu").numpy()] + + +def test_module( + inputs: List[np.ndarray], + dtype: np.dtype, + module_class_ark: ark.Module, + module_args_ark: list, + module_class_pt: torch.nn.Module, + module_args_pt: list, + ark_inputs: List[np.ndarray] = [], # used when ARK needs different inputs + module_name_prefix: str = "", +): + # ARK module + module_ark: ark.Module = module_class_ark(*module_args_ark) + + param_names = set(module_ark.params_dict().keys()) + + if os.path.exists(pth_path): + prefix = module_name_prefix + "." if module_name_prefix else "" + # Load the state_dict from the given path + state_dict = torch.load(pth_path) + state_dict = { + k[len(prefix) :]: v.float().numpy().astype(dtype) + for k, v in state_dict.items() + if k[len(prefix) :] in param_names and k.startswith(prefix) + } + else: + # Create a random state_dict + state_dict = { + k: np.random.uniform(low=-0.1, high=0.1, size=v.size()).astype(dtype) + for k, v in module_ark.params_dict().items() + } + + # Run the ARK module + output_ark = run_ark( + module_ark, state_dict, ark_inputs if ark_inputs else inputs + ) + + # PyTorch module + module_pt: torch.nn.Module = module_class_pt(*module_args_pt) + + # + for _, param in module_pt.named_parameters(): + param.data = param.data.to(numpy_dtype_to_torch_dtype[dtype]) + + # Run the PyTorch module + output_pt = run_pt(module_pt, state_dict, inputs) + + # Compare the outputs + eps = np.finfo(np.float64).eps + for i, (o_ark, o_pt) in enumerate(zip(output_ark, output_pt)): + o_ark = o_ark.flatten().astype(np.float64) + o_pt = o_pt.flatten().astype(np.float64) + + abs_diff = np.abs(o_ark - o_pt) + max_abs_diff_idx = np.argmax(abs_diff) + max_abs_diff = abs_diff[max_abs_diff_idx] + + rel_diff = abs_diff / (np.abs(o_pt) + eps) + max_rel_diff_idx = np.argmax(rel_diff) + max_rel_diff = rel_diff[max_rel_diff_idx] + + mean_square_error = np.mean(np.square(o_ark - o_pt)) + + # Test info as string + test_info = f"{module_class_ark.__name__}: output {i}" + + print( + test_info + "\n" + f" max_abs_diff: {max_abs_diff:.4e} ({o_pt[max_abs_diff_idx]} vs {o_ark[max_abs_diff_idx]})\n" + f" max_rel_diff: {max_rel_diff:.4e} ({o_pt[max_rel_diff_idx]} vs {o_ark[max_rel_diff_idx]})\n" + f" mean_square_error: {mean_square_error:.4e}\n" + ) + + +def test_rmsnorm( + args: ModelArgs, batch_size: int, seq_len: int, dtype: np.dtype +): + ark.init() + + # Create random input data + inputs = [ + np.random.uniform( + low=-0.1, high=0.1, size=(batch_size, seq_len, args.dim) + ).astype(dtype) + ] + + test_module( + inputs, + dtype, + module_class_ark=model_ark.RMSNorm, + module_args_ark=[ + args.dim, + args.norm_eps, + ark.DataType.from_numpy(dtype), + ], + module_class_pt=model_pt.RMSNorm, + module_args_pt=[args.dim], + module_name_prefix="norm", + ) + + +def test_row_parallel_linear( + args: ModelArgs, + batch_size: int, + seq_len: int, + dtype: np.dtype, + world_size: int = 1, +): + ark.init() + + # Create random input data + inputs = [ + np.random.uniform( + low=-0.1, + high=0.1, + size=(batch_size, seq_len, args.dim // args.n_heads * args.n_heads), + ).astype(dtype) + ] + + if world_size == 1: + test_module( + inputs, + dtype, + module_class_ark=model_ark.RowParallelLinear, + module_args_ark=[ + args.dim // args.n_heads * args.n_heads, + args.dim, + ark.DataType.from_numpy(dtype), + 0, + 1, + ], + module_class_pt=fairscale.nn.model_parallel.RowParallelLinear, + module_args_pt=[ + args.dim // args.n_heads * args.n_heads, + args.dim, + False, + lambda x: x, + ], + module_name_prefix="layers.0.attention.wo", + ) + + +def test_column_parallel_linear( + args: ModelArgs, + batch_size: int, + seq_len: int, + dtype: np.dtype, + world_size: int = 1, +): + ark.init() + + # Create random input data + inputs = [ + np.random.uniform( + low=-0.1, high=0.1, size=(batch_size, seq_len, args.dim) + ).astype(dtype) + ] + + if world_size == 1: + test_module( + inputs, + dtype, + module_class_ark=model_ark.ColumnParallelLinear, + module_args_ark=[ + args.dim, + args.dim // args.n_heads * args.n_heads, + ark.DataType.from_numpy(dtype), + 0, + 1, + ], + module_class_pt=fairscale.nn.model_parallel.ColumnParallelLinear, + module_args_pt=[ + args.dim, + args.dim // args.n_heads * args.n_heads, + False, + lambda x: x, + ], + module_name_prefix="layers.0.attention.wq", + ) + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / ( + theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(np.float32) / dim) + ) + t = np.arange(end, dtype=np.float32) + freqs = np.outer(t, freqs).astype(np.float32) + freqs_cis = np.exp(1j * freqs) + return freqs_cis + + +def test_attention( + args: ModelArgs, + batch_size: int, + seq_len: int, + dtype: np.dtype, + world_size: int = 1, +): + ark.init() + + # + freqs_cis = precompute_freqs_cis( + args.dim // args.n_heads, args.max_seq_len * 2 + )[0:seq_len] + + freqs_cis_ark = freqs_cis.astype(np.complex64) + freqs_cis_ark = ( + np.stack([freqs_cis_ark.real, freqs_cis_ark.imag], axis=-1) + .astype(dtype) + .reshape(1, seq_len, 1, args.dim // args.n_heads) + ) + + feature = np.random.uniform( + low=-0.1, high=0.1, size=(batch_size, seq_len, args.dim) + ).astype(dtype) + + inputs = [feature, 0, freqs_cis, None] + ark_inputs = [feature, 0, freqs_cis_ark, None] + + if world_size == 1: + test_module( + inputs, + dtype, + module_class_ark=model_ark.Attention, + module_args_ark=[args, ark.DataType.from_numpy(dtype), 0, 1], + module_class_pt=model_pt.Attention, + module_args_pt=[args], + ark_inputs=ark_inputs, + module_name_prefix="layers.0.attention", + ) + + +def test_transformer_block( + args: ModelArgs, + batch_size: int, + seq_len: int, + dtype: np.dtype, + world_size: int = 1, +): + ark.init() + + # + freqs_cis = precompute_freqs_cis( + args.dim // args.n_heads, args.max_seq_len * 2 + )[0:seq_len] + + freqs_cis_ark = freqs_cis.astype(np.complex64) + freqs_cis_ark = ( + np.stack([freqs_cis_ark.real, freqs_cis_ark.imag], axis=-1) + .astype(dtype) + .reshape(1, seq_len, 1, args.dim // args.n_heads) + ) + + feature = np.random.uniform( + low=-1, high=1, size=(batch_size, seq_len, args.dim) + ).astype(dtype) + + inputs = [feature, 0, freqs_cis, None] + ark_inputs = [feature, 0, freqs_cis_ark, None] + + if world_size == 1: + test_module( + inputs, + dtype, + module_class_ark=model_ark.TransformerBlock, + module_args_ark=[0, args, ark.DataType.from_numpy(dtype), 0, 1], + module_class_pt=model_pt.TransformerBlock, + module_args_pt=[0, args], + ark_inputs=ark_inputs, + module_name_prefix="layers.0", + ) + + +def test_transformer( + args: ModelArgs, + batch_size: int, + seq_len: int, + dtype: np.dtype, + world_size: int = 1, +): + ark.init() + + # Random input tokens + + tokens = np.random.randint( + low=0, high=args.vocab_size, size=(batch_size, seq_len) + ).astype(np.int32) + + start_pos = 0 + + # Pre-calculated freqs_cis + + freqs_cis = precompute_freqs_cis( + args.dim // args.n_heads, args.max_seq_len * 2 + )[0:seq_len] + + freqs_cis_ark = freqs_cis.astype(np.complex64) + freqs_cis_ark = ( + np.stack([freqs_cis_ark.real, freqs_cis_ark.imag], axis=-1) + .astype(dtype) + .reshape(1, seq_len, 1, args.dim // args.n_heads) + ) + + # Pre-calculated mask + + if seq_len == 1: + mask = None + else: + mask = np.full((1, 1, seq_len, seq_len), -np.inf, dtype=dtype) + mask = np.triu(mask, k=start_pos + 1) + + inputs = [tokens, start_pos] + ark_inputs = [tokens, start_pos, freqs_cis_ark, mask] + + if world_size == 1: + test_module( + inputs, + dtype, + module_class_ark=model_ark.Transformer, + module_args_ark=[args, ark.DataType.from_numpy(dtype), 0, 1], + module_class_pt=model_pt.Transformer, + module_args_pt=[args], + ark_inputs=ark_inputs, + ) + + +def test(args, batch_size, seq_len, dtype, world_size): + # test_rmsnorm(args, batch_size, seq_len, dtype) + # test_row_parallel_linear(args, batch_size, seq_len, dtype, world_size) + # test_column_parallel_linear(args, batch_size, seq_len, dtype, world_size) + # test_attention(args, batch_size, seq_len, dtype, world_size) + # test_transformer_block(args, batch_size, seq_len, dtype, world_size) + test_transformer(args, batch_size, seq_len, dtype, world_size) + + +if __name__ == "__main__": + # Configurations + args = ModelArgs7B() + batch_size = 1 + seq_len = 2048 + dtype = np.float16 + world_size = 1 + + # Default from HuggingFace + args.vocab_size = 32000 + + # For debugging + # args.n_layers = 8 + + # Verify the configurations + assert batch_size <= args.max_batch_size + assert seq_len <= args.max_seq_len + + # For torch.distributed + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29500" + os.environ["WORLD_SIZE"] = str(world_size) + + if world_size == 1: + # For torch.distributed + os.environ["RANK"] = "0" + os.environ["LOCAL_RANK"] = "0" + torch.distributed.init_process_group("nccl") + + # For fairscale + fairscale.nn.model_parallel.initialize.initialize_model_parallel( + world_size + ) + + test(args, batch_size, seq_len, dtype, world_size) diff --git a/pyproject.toml b/pyproject.toml index c5a936e7a..6ac10808c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,4 +24,4 @@ BUILD_PYTHON = "ON" line-length = 80 target-version = ['py38'] include = '\.pyi?$' -exclude = '/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist|third_party|docs)/' +exclude = '/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist|third_party|docs|examples/llama/llama)/' diff --git a/python/ark/module.py b/python/ark/module.py index 5e7645e52..324380f2e 100644 --- a/python/ark/module.py +++ b/python/ark/module.py @@ -64,25 +64,24 @@ def load_state_dict( Must be called after the executor is launched. """ logging.info("Loading model from state_dict") - for name, module in self.sub_modules.items(): - if module is not None: - module.load_state_dict(state_dict, prefix=prefix + name + ".") - for name, param in self.parameters.items(): - param.from_numpy(state_dict[prefix + name]) - def state_dict(self, prefix="") -> Dict[str, np.ndarray]: + all_keys = set(state_dict.keys()) + pd = self.params_dict(prefix) + for name, param in pd.items(): + param.from_numpy(state_dict[name]) + all_keys.remove(name) + if all_keys: + logging.warning( + f"{len(all_keys)} unused parameter(s) in state_dict" + ) + + def state_dict(self, prefix: str = "") -> Dict[str, np.ndarray]: """ - Copies the parameters from the device GPU to the host and saves the model to a state_dict. + Copies the parameters from the device GPU to the host and saves the + model to a state_dict. Must be called after the executor is launched. """ - state_dict = {} - for name, module in self.sub_modules.items(): - if module is not None: - state_dict.update(module.state_dict(prefix=prefix + name + ".")) - for name, param in self.parameters.items(): - param_np = param.to_numpy() - state_dict[prefix + name] = param_np - return state_dict + return {k: v.to_numpy() for k, v in self.params_dict(prefix).items()} def forward(self, *args: Any, **kwargs: Any) -> Any: ...