From dfa35eaee78cc562c91eac0b74bea77673eab41c Mon Sep 17 00:00:00 2001 From: SaltyChiang Date: Sun, 31 Mar 2024 20:41:31 +0800 Subject: [PATCH 01/32] `Mat()` of overlap is correct. --- include/dirac_quda.h | 52 +++++ include/dslash_quda.h | 24 +++ include/enum_quda.h | 3 + include/enum_quda_fortran.h | 45 +++-- include/kernels/dslash_overlap.cuh | 66 ++++++ include/quda.h | 4 + lib/CMakeLists.txt | 3 +- lib/dirac.cpp | 14 ++ lib/dirac_overlap.cpp | 314 +++++++++++++++++++++++++++++ lib/dslash_overlap.cu | 64 ++++++ lib/eig_iram.cpp | 6 +- lib/eigensolve_quda.cpp | 3 +- lib/interface_quda.cpp | 100 +++++++++ 13 files changed, 673 insertions(+), 25 deletions(-) create mode 100644 include/kernels/dslash_overlap.cuh create mode 100644 lib/dirac_overlap.cpp create mode 100644 lib/dslash_overlap.cu diff --git a/include/dirac_quda.h b/include/dirac_quda.h index 1aa339139f..3e5c5098bb 100644 --- a/include/dirac_quda.h +++ b/include/dirac_quda.h @@ -79,6 +79,8 @@ namespace quda { bool use_mobius_fused_kernel; // Whether or not use fused kernels for Mobius + int chebyshev_degree; + // Default constructor DiracParam() : type(QUDA_INVALID_DIRAC), @@ -181,6 +183,8 @@ namespace quda { bool use_mobius_fused_kernel; // Whether or not use fused kernels for Mobius + int chebyshev_degree; + mutable TimeProfile profile; public: @@ -1321,6 +1325,47 @@ namespace quda { virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const; }; + // Full overlap + class DiracOverlap : public DiracWilson { + + protected: + double mass_overlap; + DiracWilson *wilson; + mutable int hermitian_wilson_n_eig; + mutable std::vector hermitian_wilson_evecs; + mutable std::vector hermitian_wilson_evals; + mutable double remez_tol; + mutable int remez_n; + mutable std::vector remez_c; + + public: + DiracOverlap(const DiracParam ¶m); + DiracOverlap(const DiracOverlap &dirac); + virtual ~DiracOverlap(); + DiracOverlap& operator=(const DiracOverlap &dirac); + + virtual void Dslash(ColorSpinorField &, const ColorSpinorField &, const QudaParity ) const { errorQuda("Not implemented!\n"); } + virtual void DslashXpay(ColorSpinorField &, const ColorSpinorField &, const QudaParity, const ColorSpinorField &, const double &) const { errorQuda("Not implemented!\n"); } + virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const; + virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; + + virtual void prepare(ColorSpinorField* &src, ColorSpinorField* &sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const; + virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const; + + virtual QudaDiracType getDiracType() const { return QUDA_OVERLAP_DIRAC; } + + /** + @brief If managed memory and prefetch is enabled, prefetch + all relevant memory fields (gauge, clover, temporary spinors) + to the CPU or GPU as requested + @param[in] mem_space Memory space we are prefetching to + @param[in] stream Which stream to run the prefetch in (default 0) + */ + virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const; + + void setupHermitianWilson(int n_eig, const std::vector &evecs, const std::vector &evals, double invsqrt_tol) const; + }; + // Full staggered class DiracStaggered : public Dirac { @@ -2533,6 +2578,10 @@ namespace quda { // needs 5th dimension reversal, Mobius needs that inversion... errorQuda("Support for Hermitian DWF operator %d does not exist yet", dirac_type); break; + case QUDA_OVERLAP_DIRAC: + case QUDA_OVERLAPPC_DIRAC: + errorQuda("Support for Hermitian Overlap operator %d does not exist yet", dirac_type); + break; case QUDA_STAGGERED_DIRAC: case QUDA_ASQTAD_DIRAC: // Gamma5 is (-1)^(x+y+z+t) @@ -2601,6 +2650,9 @@ namespace quda { || dirac_type == QUDA_GAUGE_COVDEV_DIRAC) return true; + if (dirac_type == QUDA_WILSON_DIRAC || dirac_type == QUDA_CLOVER_DIRAC) + return true; + // subtle: odd operator gets a minus sign if ((dirac_type == QUDA_STAGGEREDPC_DIRAC || dirac_type == QUDA_ASQTADPC_DIRAC) && (pc_type == QUDA_MATPC_EVEN_EVEN || pc_type == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC)) diff --git a/include/dslash_quda.h b/include/dslash_quda.h index ea12741be0..0c15c1e115 100644 --- a/include/dslash_quda.h +++ b/include/dslash_quda.h @@ -632,6 +632,30 @@ namespace quda void ApplyDslash5(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, double m_f, double m_5, const Complex *b_5, const Complex *c_5, double a, bool dagger, Dslash5Type type); + /** + @brief Driver for applying the Wilson stencil + + out = D * in + + where D is the gauged Wilson linear operator. + + If kappa is non-zero, the operation is given by out = x + kappa * D in. + This operator can be applied to both single parity + (checker-boarded) fields, or to full fields. + + @param[out] out The output result field + @param[in] in The input field + @param[in] U The gauge field used for the operator + @param[in] kappa Scale factor applied + @param[in] x Vector field we accumulate onto to + @param[in] parity Destination parity + @param[in] dagger Whether this is for the dagger operator + @param[in] comm_override Override for which dimensions are partitioned + @param[in] profile The TimeProfile used for profiling the dslash + */ + void ApplyOverlap(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double kappa, + const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); + // The EOFA stuff namespace mobius_eofa { diff --git a/include/enum_quda.h b/include/enum_quda.h index c4cbb59901..d615c2c757 100644 --- a/include/enum_quda.h +++ b/include/enum_quda.h @@ -96,6 +96,7 @@ typedef enum QudaDslashType_s { QUDA_DOMAIN_WALL_4D_DSLASH, QUDA_MOBIUS_DWF_DSLASH, QUDA_MOBIUS_DWF_EOFA_DSLASH, + QUDA_OVERLAP_DSLASH, QUDA_STAGGERED_DSLASH, QUDA_ASQTAD_DSLASH, QUDA_TWISTED_MASS_DSLASH, @@ -307,6 +308,8 @@ typedef enum QudaDiracType_s { QUDA_MOBIUS_DOMAIN_WALLPC_DIRAC, QUDA_MOBIUS_DOMAIN_WALL_EOFA_DIRAC, QUDA_MOBIUS_DOMAIN_WALLPC_EOFA_DIRAC, + QUDA_OVERLAP_DIRAC, + QUDA_OVERLAPPC_DIRAC, QUDA_STAGGERED_DIRAC, QUDA_STAGGEREDPC_DIRAC, QUDA_STAGGEREDKD_DIRAC, diff --git a/include/enum_quda_fortran.h b/include/enum_quda_fortran.h index 6a8708948a..bb128e28bc 100644 --- a/include/enum_quda_fortran.h +++ b/include/enum_quda_fortran.h @@ -84,12 +84,13 @@ #define QUDA_DOMAIN_WALL_4D_DSLASH 4 #define QUDA_MOBIUS_DWF_DSLASH 5 #define QUDA_MOBIUS_DWF_EOFA_DSLASH 6 -#define QUDA_STAGGERED_DSLASH 7 -#define QUDA_ASQTAD_DSLASH 8 -#define QUDA_TWISTED_MASS_DSLASH 9 -#define QUDA_TWISTED_CLOVER_DSLASH 10 -#define QUDA_LAPLACE_DSLASH 11 -#define QUDA_COVDEV_DSLASH 12 +#define QUDA_OVERLAP_DSLASH 7 +#define QUDA_STAGGERED_DSLASH 8 +#define QUDA_ASQTAD_DSLASH 9 +#define QUDA_TWISTED_MASS_DSLASH 10 +#define QUDA_TWISTED_CLOVER_DSLASH 11 +#define QUDA_LAPLACE_DSLASH 12 +#define QUDA_COVDEV_DSLASH 13 #define QUDA_INVALID_DSLASH QUDA_INVALID_ENUM #define QudaInverterType integer(4) @@ -280,21 +281,23 @@ #define QUDA_MOBIUS_DOMAIN_WALLPC_DIRAC 11 #define QUDA_MOBIUS_DOMAIN_WALL_EOFA_DIRAC 12 #define QUDA_MOBIUS_DOMAIN_WALLPC_EOFA_DIRAC 13 -#define QUDA_STAGGERED_DIRAC 14 -#define QUDA_STAGGEREDPC_DIRAC 15 -#define QUDA_STAGGEREDKD_DIRAC 16 -#define QUDA_ASQTAD_DIRAC 17 -#define QUDA_ASQTADPC_DIRAC 18 -#define QUDA_ASQTADKD_DIRAC 19 -#define QUDA_TWISTED_MASS_DIRAC 20 -#define QUDA_TWISTED_MASSPC_DIRAC 21 -#define QUDA_TWISTED_CLOVER_DIRAC 22 -#define QUDA_TWISTED_CLOVERPC_DIRAC 23 -#define QUDA_COARSE_DIRAC 24 -#define QUDA_COARSEPC_DIRAC 25 -#define QUDA_GAUGE_LAPLACE_DIRAC 26 -#define QUDA_GAUGE_LAPLACEPC_DIRAC 27 -#define QUDA_GAUGE_COVDEV_DIRAC 28 +#define QUDA_OVERLAP_DIRAC 14 +#define QUDA_OVERLAPPC_DIRAC 15 +#define QUDA_STAGGERED_DIRAC 16 +#define QUDA_STAGGEREDPC_DIRAC 17 +#define QUDA_STAGGEREDKD_DIRAC 18 +#define QUDA_ASQTAD_DIRAC 19 +#define QUDA_ASQTADPC_DIRAC 20 +#define QUDA_ASQTADKD_DIRAC 21 +#define QUDA_TWISTED_MASS_DIRAC 22 +#define QUDA_TWISTED_MASSPC_DIRAC 23 +#define QUDA_TWISTED_CLOVER_DIRAC 24 +#define QUDA_TWISTED_CLOVERPC_DIRAC 25 +#define QUDA_COARSE_DIRAC 26 +#define QUDA_COARSEPC_DIRAC 27 +#define QUDA_GAUGE_LAPLACE_DIRAC 28 +#define QUDA_GAUGE_LAPLACEPC_DIRAC 29 +#define QUDA_GAUGE_COVDEV_DIRAC 30 #define QUDA_INVALID_DIRAC QUDA_INVALID_ENUM ! Where the field is stored diff --git a/include/kernels/dslash_overlap.cuh b/include/kernels/dslash_overlap.cuh new file mode 100644 index 0000000000..b029556aa7 --- /dev/null +++ b/include/kernels/dslash_overlap.cuh @@ -0,0 +1,66 @@ +#pragma once + +#include +#include +#include + +namespace quda +{ + + template + struct OverlapArg : WilsonArg { + using WilsonArg::nSpin; + static constexpr int length = (nSpin / (nSpin / 2)) * 2 * nColor * nColor * (nSpin / 2) * (nSpin / 2) / 2; + + typedef typename mapper::type real; + + const real a; /** xpay scale factor */ + + OverlapArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, + const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : + WilsonArg(out, in, U, a, x, parity, dagger, comm_override), a(a) + { + } + }; + + template struct overlap : dslash_default { + + const Arg &arg; + constexpr overlap(const Arg &arg) : arg(arg) { } + static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation + + /** + @brief Apply the Wilson-clover dslash + out(x) = M*in = A(x)*x(x) + D * in(x-mu) + Note this routine only exists in xpay form. + */ + template + __device__ __host__ __forceinline__ void operator()(int idx, int, int parity) + { + typedef typename mapper::type real; + typedef ColorSpinor Vector; + + bool active + = mykernel_type == EXTERIOR_KERNEL_ALL ? false : true; // is thread active (non-trival for fused kernel only) + int thread_dim; // which dimension is thread working on (fused kernel only) + + auto coord = getCoords(arg, idx, 0, parity, thread_dim); + + const int my_spinor_parity = nParity == 2 ? parity : 0; + Vector out; + applyWilson(out, arg, coord, parity, idx, thread_dim, active); + + int xs = coord.x_cb + coord.s * arg.dc.volume_4d_cb; + if (xpay && mykernel_type == INTERIOR_KERNEL) { + Vector x = arg.x(xs, my_spinor_parity); + out = x + arg.a * out; + } else if (mykernel_type != INTERIOR_KERNEL && active) { + Vector x = arg.out(xs, my_spinor_parity); + out = x + (xpay ? arg.a * out : out); + } + + if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out(xs, my_spinor_parity) = out; + } + }; + +} // namespace quda diff --git a/include/quda.h b/include/quda.h index d11fd4cbd0..b1bfcffec6 100644 --- a/include/quda.h +++ b/include/quda.h @@ -447,6 +447,10 @@ extern "C" { /** Whether to use fused kernels for mobius */ QudaBoolean use_mobius_fused_kernel; + int hermitian_wilson_n_ev; + int hermitian_wilson_n_kr; + double hermitian_wilson_tol; + double overlap_invsqrt_tol; } QudaInvertParam; // Parameter set for solving eigenvalue problems. diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index d8c1d8342b..a940a99197 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -42,7 +42,7 @@ set (QUDA_OBJS dirac_staggered_kd.cpp dirac_clover_hasenbusch_twist.cpp dirac_improved_staggered.cpp dirac_improved_staggered_kd.cpp dirac_domain_wall.cpp dirac_domain_wall_4d.cpp dirac_mobius.cpp dirac_twisted_clover.cpp - dirac_twisted_mass.cpp + dirac_twisted_mass.cpp dirac_overlap.cpp llfat_quda.cu staggered_two_link_quda.cu gauge_force.cu gauge_loop_trace.cu gauge_polyakov_loop.cu gauge_random.cu gauge_noise.cu gauge_field_strength_tensor.cu clover_quda.cu @@ -110,6 +110,7 @@ set(QUDA_DSLASH_OBJS dslash_domain_wall_4d_m5inv_m5inv.cu dslash_domain_wall_4d_m5mob.cu dslash_domain_wall_4d_m5pre_m5mob.cu + dslash_overlap.cu dslash_pack2.cu laplace.cu covDev.cu staggered_quark_smearing.cu) if(QUDA_NVSHMEM) diff --git a/lib/dirac.cpp b/lib/dirac.cpp index 35411c1841..6aa9dc3f07 100644 --- a/lib/dirac.cpp +++ b/lib/dirac.cpp @@ -17,6 +17,7 @@ namespace quda { type(param.type), halo_precision(param.halo_precision), use_mobius_fused_kernel(param.use_mobius_fused_kernel), + chebyshev_degree(param.chebyshev_degree), profile("Dirac", false) { for (int i=0; i<4; i++) commDim[i] = param.commDim[i]; @@ -30,6 +31,7 @@ namespace quda { dagger(dirac.dagger), type(dirac.type), halo_precision(dirac.halo_precision), + chebyshev_degree(dirac.chebyshev_degree), profile("Dirac", false) { for (int i=0; i<4; i++) commDim[i] = dirac.commDim[i]; @@ -50,6 +52,8 @@ namespace quda { matpcType = dirac.matpcType; dagger = dirac.dagger; + chebyshev_degree = dirac.chebyshev_degree; + for (int i=0; i<4; i++) commDim[i] = dirac.commDim[i]; profile = dirac.profile; @@ -160,6 +164,12 @@ namespace quda { } else if (param.type == QUDA_MOBIUS_DOMAIN_WALLPC_EOFA_DIRAC) { if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracMobiusEofaPC operator\n"); return new DiracMobiusEofaPC(param); + } else if (param.type == QUDA_OVERLAP_DIRAC) { + if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracOverlap operator\n"); + return new DiracOverlap(param); + } else if (param.type == QUDA_OVERLAPPC_DIRAC) { + if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracOverlapPC operator\n"); + errorQuda("Overlap Dirac doesn't support even-odd preconditioning\n"); } else if (param.type == QUDA_STAGGERED_DIRAC) { if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printfQuda("Creating a DiracStaggered operator\n"); return new DiracStaggered(param); @@ -270,6 +280,10 @@ namespace quda { case QUDA_GAUGE_LAPLACEPC_DIRAC: steps = 2; break; + case QUDA_OVERLAP_DIRAC: + case QUDA_OVERLAPPC_DIRAC: + steps = chebyshev_degree; + break; default: errorQuda("Unsupported Dslash type %d.\n", type); steps = 0; diff --git a/lib/dirac_overlap.cpp b/lib/dirac_overlap.cpp new file mode 100644 index 0000000000..62ed0177b5 --- /dev/null +++ b/lib/dirac_overlap.cpp @@ -0,0 +1,314 @@ +#include +#include +#include +#include +#include +#include + +namespace quda +{ + // Chebyshev polynomial the first kind + // T_{k+1}(x) = 2 x T_k(x) - T_{k-1}(x) + double Tn(double x, int n) + { + if (abs(x) <= 1.0) { return cos(n * acos(x)); } + double T0 = 1, T1 = x, Tk; + switch (n) { + case 0: return T0; + case 1: return T1; + default: + for (int k = 2; k <= n; ++k) { + Tk = 2 * x * T1 - T0; + T0 = T1; + T1 = Tk; + } + return Tk; + } + } + + // \sum_{i=0}^n c_i T_i + // T_{k+1}(x) = 2 x T_k(x) - T_{k-1}(x) + // Use Clenshaw algorithm + double ciTi(double x, std::vector c, int n) + { + double b2 = 0.0, b1 = 0.0, bk; + for (int k = n; k >= 1; --k) { + bk = c[k] + 2 * x * b1 - b2; + b2 = b1; + b1 = bk; + } + return c[0] + x * b1 - b2; + } + + // (\sum_{i=0}^n c_i T_i)' = \sum_{i=1}^n i c_i U_{i-1} + // U_{k+1}(x) = 2 x U_k(x) - U_{k-1} + // Use Clenshaw algorithm + double iciUim1(double x, std::vector c, int n) + { + double b2 = 0.0, b1 = 0.0, bk; + for (int k = n - 1; k >= 1; --k) { + bk = (k + 1) * c[k + 1] + 2 * x * b1 - b2; + b2 = b1; + b1 = bk; + } + return c[1] + 2 * x * b1 - b2; + } + + double residual(double x, std::vector c, int n, double epsilon, bool derivative) + { + const double z = (x * 2 - (1 + epsilon)) / (1 - epsilon); + if (derivative) { + return -1 / (2 * sqrt(x)) * ciTi(z, c, n) - sqrt(x) * iciUim1(z, c, n) * (2 / (1 - epsilon)); + } else { + return 1 - sqrt(x) * ciTi(z, c, n); + } + } + + double find_root(double x_l, double x_r, std::vector c, int n, double epsilon, bool derivative) + { + double x_m, res_r, res_l, res_m; + + res_l = residual(x_l, c, n, epsilon, derivative); + res_r = residual(x_r, c, n, epsilon, derivative); + if (abs(res_l) < 1e-15) return x_l; + if (abs(res_r) < 1e-15) return x_r; + if (res_r * res_l > 0) + printf("ERROR: find_root with derivative=%d called with wrong ends: (%e %e)->(%e %e)\n", derivative, x_l, x_r, + res_l, res_r); + for (int i = 0; i < 10; i++) { + x_m = (res_l * x_r - res_r * x_l) / (res_l - res_r); + res_m = residual(x_m, c, n, epsilon, derivative); + if (res_m * res_l > 0) { + x_l = x_m; + res_l = res_m; + } else { + x_r = x_m; + res_r = res_m; + } + } + return (res_l * x_r - res_r * x_l) / (res_l - res_r); + } + + std::vector minimaxApproximationRemez(double delta, double epsilon) + { + const int n = ceil(-log(delta / 0.41) / (2.083 * sqrt(epsilon))) + 1; + constexpr int max_iter = 5; + std::vector y(n + 1), z(n + 1), c(n + 1), b(n + 1); + Eigen::Map b_eigen(b.data(), b.size()), c_eigen(c.data(), c.size()); + Eigen::MatrixXd M_eigen(n + 1, n + 1); + + for (int i = 0; i < n + 1; ++i) { + z[i] = cos(M_PI * i / n); + y[i] = (z[i] * (1 - epsilon) + (1 + epsilon)) / 2; + } + + for (int iter = 0; iter < max_iter; ++iter) { + // Construct matrix M_ij=\sqrt{y_i}T_j(z_i) + for (int i = 0; i < n + 1; ++i) { + for (int j = 0; j < n; ++j) { M_eigen(i, j) = sqrt(y[i]) * Tn(z[i], j); } + M_eigen(i, n) = i % 2 == 0 ? 1 : -1; // T_n is not a real Chebyshev polynomial + b_eigen(i) = 1.0; + } + c_eigen = M_eigen.lu().solve(b_eigen); + + // Drop T_n + for (int i = 0; i < n; ++i) { b[i] = find_root(y[i], y[i + 1], c, n - 1, epsilon, false); } + for (int i = n - 1; i > 0; --i) { y[i] = find_root(b[i], b[i - 1], c, n - 1, epsilon, true); } + for (int i = 1; i < n; ++i) { z[i] = (2 * y[i] - (1 + epsilon)) / (1 - epsilon); } + for (int i = 0; i < n + 1; ++i) { b[i] = abs(1 - sqrt(y[i]) * ciTi(z[i], c, n - 1)); } + if (*std::max_element(b.begin(), b.end()) <= delta) { return {c.begin(), c.begin() + n}; } + } + errorQuda("minimaxApproximationRemez can not converge"); + } + + DiracOverlap::DiracOverlap(const DiracParam ¶m) : + DiracWilson(param), mass_overlap(0), wilson(new DiracWilson(param)) + { + } + + DiracOverlap::DiracOverlap(const DiracOverlap &dirac) : + DiracWilson(dirac), + mass_overlap(dirac.mass_overlap), + wilson(dirac.wilson), + hermitian_wilson_n_eig(dirac.hermitian_wilson_n_eig), + hermitian_wilson_evecs(dirac.hermitian_wilson_evecs), + hermitian_wilson_evals(dirac.hermitian_wilson_evals), + remez_tol(dirac.remez_tol), + remez_n(dirac.remez_n), + remez_c(dirac.remez_c) + { + } + + DiracOverlap::~DiracOverlap() { delete wilson; } + + DiracOverlap &DiracOverlap::operator=(const DiracOverlap &dirac) + { + if (&dirac != this) { + DiracWilson::operator=(dirac); + mass_overlap = dirac.mass_overlap; + wilson = dirac.wilson; + hermitian_wilson_n_eig = dirac.hermitian_wilson_n_eig; + hermitian_wilson_evecs = dirac.hermitian_wilson_evecs; + hermitian_wilson_evals = dirac.hermitian_wilson_evals; + remez_tol = dirac.remez_tol; + remez_n = dirac.remez_n; + remez_c = dirac.remez_c; + } + return *this; + } + + // Apply sign function for small eigenvalues by applying lambda_i/|lambda_i|*|V_i> &evecs, + const std::vector &evals, int n_eig) + { + logQuda(QUDA_VERBOSE, "Deflating %d vectors\n", n_eig); + + // = A_i + std::vector s(n_eig); + blas::cDotProduct(s, {evecs.begin(), evecs.begin() + n_eig}, src); + + // src -= A_i|V_i> + for (int i = 0; i < n_eig; i++) { s[i] *= -1; } + blas::caxpy(s, {evecs.begin(), evecs.begin() + n_eig}, src); + + // sol += lambda_i/|lambda_i|*A_i|V_i> + for (int i = 0; i < n_eig; i++) { s[i] *= -evals[i] / abs(evals[i]); } + blas::zero(sol); + blas::caxpy(s, {evecs.begin(), evecs.begin() + n_eig}, sol); + } + + void signHighPolynomial(ColorSpinorField &b1, ColorSpinorField &b2, ColorSpinorField &Ab1, const ColorSpinorField &in, + DiracMatrix *mat, std::vector &remez_c, int remez_n, const double epsilon, + const double lambda_max) + { + b1.zero(); + b2.zero(); + for (int k = remez_n; k >= 1; --k) { + (*mat)(Ab1, b1); + blas::axpbyz(-(1 + epsilon) / (1 - epsilon), b1, 2 / (1 - epsilon) / (lambda_max * lambda_max), Ab1, Ab1); + blas::axpbypczw(remez_c[k], in, 2, Ab1, -1, b2, b2); + std::swap(b1, b2); + } + (*mat)(Ab1, b1); + blas::axpbyz(-(1 + epsilon) / (1 - epsilon), b1, 2 / (1 - epsilon) / (lambda_max * lambda_max), Ab1, Ab1); + blas::axpbypczw(remez_c[0], in, 1, Ab1, -1, b2, b2); + } + +#define flip(x) (x) = ((x) == QUDA_DAG_YES ? QUDA_DAG_NO : QUDA_DAG_YES) + + void DiracOverlap::M(ColorSpinorField &out, const ColorSpinorField &in) const + { + printfQuda("Entering DracOverlap::M\n"); + + auto tmp1 = getFieldTmp(in); + auto tmp2 = getFieldTmp(in); + auto tmp3 = getFieldTmp(in); + auto tmp4 = getFieldTmp(in); + auto tmp5 = getFieldTmp(in); + ColorSpinorField &b1 = tmp1; + ColorSpinorField &b2 = tmp2; + ColorSpinorField &Mb1 = tmp3; + ColorSpinorField &Ab1 = tmp4; + ColorSpinorField &deflated = tmp5; + + const double hermitian_wilson_evals_max = hermitian_wilson_evals[hermitian_wilson_n_eig - 1]; + const double epsilon = hermitian_wilson_evals_max * hermitian_wilson_evals_max; + const double lambda_max = (1 + 8 * kappa); + const double rho = 4 - 1 / (2 * kappa); + + // signLow(out, deflated, hermitian_wilson_evecs, hermitian_wilson_evals, hermitian_wilson_n_eig); + std::vector s(hermitian_wilson_n_eig); + blas::cDotProduct(s, hermitian_wilson_evecs, in); + for (int i = 0; i < hermitian_wilson_n_eig; i++) { s[i] *= -1; } + blas::caxpyz(s, hermitian_wilson_evecs, in, deflated); + for (int i = 0; i < hermitian_wilson_n_eig; i++) { + s[i] *= -hermitian_wilson_evals[i] / abs(hermitian_wilson_evals[i]); + } + out.zero(); + blas::caxpy(s, hermitian_wilson_evecs, out); + gamma5(out, out); + + // signHighPolynomial(b1, b2, Ab1, deflated, mat, remez_c, remez_n, epsilon, lambda_max); + b1.zero(); + b2.zero(); + for (int k = remez_n; k >= 1; --k) { + DiracWilson::M(Mb1, b1); + flip(dagger); + DiracWilson::M(Ab1, Mb1); + flip(dagger); + blas::axpby(-(1 + epsilon) / (1 - epsilon), b1, 2 / (1 - epsilon) / (lambda_max * lambda_max), Ab1); + blas::axpbypczw(remez_c[k], deflated, 2, Ab1, -1, b2, b2); + std::swap(b1, b2); + } + DiracWilson::M(Mb1, b1); + flip(dagger); + DiracWilson::M(Ab1, Mb1); + flip(dagger); + blas::axpby(-(1 + epsilon) / (1 - epsilon), b1, 2 / (1 - epsilon) / (lambda_max * lambda_max), Ab1); + blas::axpbypczw(remez_c[0], deflated, 1, Ab1, -1, b2, b2); + DiracWilson::M(b1, b2); + + blas::axpbypczw(rho, in, rho / lambda_max, b1, rho, out, out); + } + + void DiracOverlap::MdagM(ColorSpinorField &out, const ColorSpinorField &in) const + { + checkFullSpinor(out, in); + auto tmp = getFieldTmp(in); + + M(tmp, in); + Mdag(out, tmp); + } + + void DiracOverlap::prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, + const QudaSolutionType solType) const + { + if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { + errorQuda("Preconditioned solution requires a preconditioned solve_type"); + } + + src = &b; + sol = &x; + } + + void DiracOverlap::reconstruct(ColorSpinorField &, const ColorSpinorField &, const QudaSolutionType) const + { + // do nothing + } + + void DiracOverlap::prefetch(QudaFieldLocation mem_space, qudaStream_t stream) const + { + Dirac::prefetch(mem_space, stream); + } + + void DiracOverlap::setupHermitianWilson(int n_eig, const std::vector &evecs, + const std::vector &evals, double invsqrt_tol) const + { + hermitian_wilson_n_eig = n_eig; + hermitian_wilson_evecs.resize(n_eig); + hermitian_wilson_evals.resize(n_eig); + + const double lambda_max = 1 + 8 * kappa; + ColorSpinorParam cudaParam(evecs[0]); + if (evecs[0].Precision() == gauge->Precision()) { + cudaParam.create = QUDA_REFERENCE_FIELD_CREATE; + for (int i = 0; i < n_eig; i++) { + cudaParam.v = evecs[i].data(); + hermitian_wilson_evecs[i] = ColorSpinorField(cudaParam); + hermitian_wilson_evals[i] = evals[i].real() / lambda_max; + } + } else { + cudaParam.create = QUDA_NULL_FIELD_CREATE; + cudaParam.setPrecision(gauge->Precision(), gauge->Precision(), true); + for (int i = 0; i < n_eig; i++) { + hermitian_wilson_evecs[i] = ColorSpinorField(cudaParam); + hermitian_wilson_evecs[i] = evecs[i]; + hermitian_wilson_evals[i] = evals[i].real() / lambda_max; + } + } + remez_tol = invsqrt_tol; + remez_c = minimaxApproximationRemez(invsqrt_tol, pow(hermitian_wilson_evals[n_eig - 1], 2)); + remez_n = remez_c.size() - 1; + } + +} // namespace quda diff --git a/lib/dslash_overlap.cu b/lib/dslash_overlap.cu new file mode 100644 index 0000000000..aff1cbafe0 --- /dev/null +++ b/lib/dslash_overlap.cu @@ -0,0 +1,64 @@ +#include +#include +#include +#include + +#include +#include + +/** + This is the Wilson-clover linear operator +*/ + +namespace quda +{ + + template class Overlap : public Dslash + { + using Dslash = Dslash; + using Dslash::arg; + using Dslash::in; + + public: + Overlap(Arg &arg, const ColorSpinorField &out, const ColorSpinorField &in) : Dslash(arg, out, in) { } + + void apply(const qudaStream_t &stream) + { + TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); + Dslash::setParam(tp); + Dslash::template instantiate(tp, stream); + } + }; + + template struct OverlapApply { + + inline OverlapApply(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, + const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, + TimeProfile &profile) + { + constexpr int nDim = 4; + OverlapArg arg(out, in, U, a, x, parity, dagger, comm_override); + Overlap overlap(arg, out, in); + + dslash::DslashPolicyTune policy(overlap, in, in.VolumeCB(), in.GhostFaceCB(), profile); + } + }; + + // Apply the overlap operator + // out(x) = M*in = (A(x)*in(x) + a * \sum_mu U_{-\mu}(x)in(x+mu) + U^\dagger_mu(x-mu)in(x-mu)) + // Uses the kappa normalization for the Wilson operator. +#ifdef GPU_WILSON_DIRAC + void ApplyOverlap(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, + const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile) + { + instantiate(out, in, U, a, x, parity, dagger, comm_override, profile); + } +#else + void ApplyOverlap(ColorSpinorField &, const ColorSpinorField &, const GaugeField &, double, const ColorSpinorField &, + int, bool, const int *, TimeProfile &) + { + errorQuda("Wilson dslash has not been built"); + } +#endif + +} // namespace quda diff --git a/lib/eig_iram.cpp b/lib/eig_iram.cpp index 4817cf0fc7..c26be5e113 100644 --- a/lib/eig_iram.cpp +++ b/lib/eig_iram.cpp @@ -50,7 +50,8 @@ namespace quda std::swap(v[j], r[0]); // r_{j} = M * v_{j}; - mat(r[0], v[j]); + // mat(r[0], v[j]); + chebyOp(r[0], v[j]); double beta_pre = sqrt(blas::norm2(r[0])); @@ -444,7 +445,8 @@ namespace quda // Apply a matrix op to the residual to place it in the // range of the operator - mat(r[0], kSpace[0]); + // mat(r[0], kSpace[0]); + chebyOp(r[0], kSpace[0]); // Convergence criteria double epsilon = setEpsilon(kSpace[0].Precision()); diff --git a/lib/eigensolve_quda.cpp b/lib/eigensolve_quda.cpp index 710d6ac13a..891ae3d821 100644 --- a/lib/eigensolve_quda.cpp +++ b/lib/eigensolve_quda.cpp @@ -280,7 +280,8 @@ namespace quda double b = eig_param->a_max; double delta = (b - a) / 2.0; double theta = (b + a) / 2.0; - double sigma1 = -delta / theta; + double lambda1 = b; // lambda1 = 0 before + double sigma1 = delta / (lambda1 - theta); double sigma; double d1 = sigma1 / delta; double d2 = 1.0; diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index b07a0b2c7b..4e58e9085f 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -1478,6 +1478,9 @@ namespace quda { diracParam.c_5[i].imag()); } break; + case QUDA_OVERLAP_DSLASH: + diracParam.type = pc ? QUDA_OVERLAPPC_DIRAC : QUDA_OVERLAP_DIRAC; + break; case QUDA_STAGGERED_DSLASH: diracParam.type = pc ? QUDA_STAGGEREDPC_DIRAC : QUDA_STAGGERED_DIRAC; break; @@ -1771,6 +1774,79 @@ namespace quda { logQuda(QUDA_DEBUG_VERBOSE, "Mass rescale: norm of source out = %g\n", blas::norm2(b)); } + + void setupHermitianWilson(QudaInvertParam *param, const lat_dim_t &X, std::vector &evecs, std::vector &evals) + { + DiracParam diracWilsonParam; + setDiracParam(diracWilsonParam, param, false); + Dirac *dWilson = new DiracWilson(diracWilsonParam); + + // Construct vectors + //------------------------------------------------------ + // Create host wrappers around application vector set + ColorSpinorParam cpuParam(nullptr, *param, X, QUDA_MAT_SOLUTION, QUDA_CUDA_FIELD_LOCATION); + + int n_eig = param->hermitian_wilson_n_ev; + + // Create device side ColorSpinorField vector space to pass to the + // compute function. Download any user supplied data as an initial guess. + ColorSpinorParam cudaParam(cpuParam, *param, QUDA_CUDA_FIELD_LOCATION); + cudaParam.create = QUDA_ZERO_FIELD_CREATE; + cudaParam.setPrecision(param->cuda_prec_eigensolver, param->cuda_prec_eigensolver, true); + // Ensure device vectors qre in UKQCD basis for Wilson type fermions + cudaParam.gammaBasis = QUDA_UKQCD_GAMMA_BASIS; + + for (int i = 0; i < n_eig; i++) { + evecs[i] = ColorSpinorField(cudaParam); + evals[i] = 0.0; + } + + //------------------------------------------------------ + // We must construct the correct Dirac operator type based on the three + // options: The normal operator, the daggered operator, and if we pre + // multiply by gamma5. Each combination requires a unique Dirac operator + // object. + + // Use MdagM=(G5M)^2 to make sure all eigenvalus>0 + DiracMatrix *mWilson = new DiracMdagM(*dWilson); + QudaEigParam eig_param = newQudaEigParam(); + int n_kr = param->hermitian_wilson_n_kr; + + eig_param.eig_type = QUDA_EIG_TR_LANCZOS; + eig_param.use_poly_acc = QUDA_BOOLEAN_TRUE; + eig_param.poly_deg = 50; + eig_param.a_min = 0.2 * 0.2; + eig_param.a_max = (1 + 8 * param->kappa) * (1 + 8 * param->kappa); + eig_param.use_dagger = QUDA_BOOLEAN_FALSE; + eig_param.use_norm_op = QUDA_BOOLEAN_TRUE; + eig_param.use_pc = QUDA_BOOLEAN_FALSE; + eig_param.compute_gamma5 = QUDA_BOOLEAN_FALSE; + eig_param.spectrum = QUDA_SPECTRUM_SR_EIG; + eig_param.n_ev = param->hermitian_wilson_n_ev; + eig_param.n_kr = param->hermitian_wilson_n_kr; + eig_param.n_conv = param->hermitian_wilson_n_ev; + eig_param.tol = param->hermitian_wilson_tol; + eig_param.vec_infile[0] = 0; + eig_param.vec_outfile[0] = 0; + eig_param.max_restarts = 1000; + + auto *eig_solve = quda::EigenSolver::create(&eig_param, *mWilson, profileEigensolve); + (*eig_solve)(evecs, evals); + delete eig_solve; + + // Recalculate eigenvalues + delete mWilson; + mWilson = new DiracG5M(*dWilson); + ColorSpinorField tmp(cudaParam); + for (int i = 0; i < n_eig; ++i) { + (*mWilson)(tmp, evecs[i]); + evals[i] = blas::cDotProduct(tmp, evecs[i]); + } + + delete mWilson; + delete dWilson; + + } } void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, QudaParity parity) @@ -1879,6 +1955,17 @@ void MatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param) setDiracParam(diracParam, inv_param, pc); Dirac *dirac = Dirac::create(diracParam); // create the Dirac operator + + // Setup eigensystem for hermitian Wilson operator + if (inv_param->dslash_type == QUDA_OVERLAP_DSLASH) { + const int n_eig = inv_param->hermitian_wilson_n_ev; + const double invsqrt_tol = inv_param->overlap_invsqrt_tol; + std::vector evecs(n_eig); + std::vector evals(n_eig); + setupHermitianWilson(inv_param, gauge.X(), evecs, evals); + ((DiracOverlap *)dirac)->setupHermitianWilson(n_eig, evecs, evals, invsqrt_tol); + } + dirac->M(out, in); // apply the operator delete dirac; // clean up @@ -2651,6 +2738,19 @@ void invertQuda(void *hp_x, void *hp_b, QudaInvertParam *param) // and an eigensolver createDiracWithEig(d, dSloppy, dPre, dEig, *param, pc_solve); + // Setup eigensystem for hermitian Wilson operator + if (param->dslash_type == QUDA_OVERLAP_DSLASH) { + const int n_eig = param->hermitian_wilson_n_ev; + const double invsqrt_tol = param->overlap_invsqrt_tol; + std::vector evecs(n_eig); + std::vector evals(n_eig); + setupHermitianWilson(param, cudaGauge->X(), evecs, evals); + ((DiracOverlap *)d)->setupHermitianWilson(n_eig, evecs, evals, invsqrt_tol); + ((DiracOverlap *)dSloppy)->setupHermitianWilson(n_eig, evecs, evals, invsqrt_tol); + ((DiracOverlap *)dPre)->setupHermitianWilson(n_eig, evecs, evals, invsqrt_tol); + ((DiracOverlap *)dEig)->setupHermitianWilson(n_eig, evecs, evals, invsqrt_tol); + } + Dirac &dirac = *d; Dirac &diracSloppy = *dSloppy; Dirac &diracPre = *dPre; From 4691a74c076bde17a7392aeb130726d39b7563fe Mon Sep 17 00:00:00 2001 From: V3-vvv <1754124202@qq.com> Date: Thu, 19 Sep 2024 14:30:49 +0800 Subject: [PATCH 02/32] modify dirac_overlap codes to work with newest quda --- include/dirac_quda.h | 12 +++++----- lib/dirac_overlap.cpp | 56 ++++++++++++++++++++++++------------------- 2 files changed, 37 insertions(+), 31 deletions(-) diff --git a/include/dirac_quda.h b/include/dirac_quda.h index 5c16f38bd7..b9984f04b9 100644 --- a/include/dirac_quda.h +++ b/include/dirac_quda.h @@ -1331,13 +1331,13 @@ namespace quda { virtual ~DiracOverlap(); DiracOverlap& operator=(const DiracOverlap &dirac); - virtual void Dslash(ColorSpinorField &, const ColorSpinorField &, const QudaParity ) const { errorQuda("Not implemented!\n"); } - virtual void DslashXpay(ColorSpinorField &, const ColorSpinorField &, const QudaParity, const ColorSpinorField &, const double &) const { errorQuda("Not implemented!\n"); } - virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const; - virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; + virtual void Dslash(cvector_ref &, cvector_ref &, const QudaParity ) const { errorQuda("Not implemented!\n"); } + virtual void DslashXpay(cvector_ref &, cvector_ref &, const QudaParity, cvector_ref &, const double &) const { errorQuda("Not implemented!\n"); } + virtual void M(cvector_ref &out, cvector_ref &in) const; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const; - virtual void prepare(ColorSpinorField* &src, ColorSpinorField* &sol, ColorSpinorField &x, ColorSpinorField &b, const QudaSolutionType) const; - virtual void reconstruct(ColorSpinorField &x, const ColorSpinorField &b, const QudaSolutionType) const; + virtual void prepare(cvector_ref &sol, cvector_ref &src, cvector_ref &x, cvector_ref &b, const QudaSolutionType) const; + virtual void reconstruct(cvector_ref &x, cvector_ref &b, const QudaSolutionType) const; virtual QudaDiracType getDiracType() const { return QUDA_OVERLAP_DIRAC; } diff --git a/lib/dirac_overlap.cpp b/lib/dirac_overlap.cpp index 62ed0177b5..7d7e8973e7 100644 --- a/lib/dirac_overlap.cpp +++ b/lib/dirac_overlap.cpp @@ -1,5 +1,6 @@ +#include "util_quda.h" #include -#include +// #include #include #include #include @@ -196,15 +197,15 @@ namespace quda #define flip(x) (x) = ((x) == QUDA_DAG_YES ? QUDA_DAG_NO : QUDA_DAG_YES) - void DiracOverlap::M(ColorSpinorField &out, const ColorSpinorField &in) const + void DiracOverlap::M(cvector_ref &out, cvector_ref &in) const { printfQuda("Entering DracOverlap::M\n"); - auto tmp1 = getFieldTmp(in); - auto tmp2 = getFieldTmp(in); - auto tmp3 = getFieldTmp(in); - auto tmp4 = getFieldTmp(in); - auto tmp5 = getFieldTmp(in); + auto tmp1 = getFieldTmp(in[0]); + auto tmp2 = getFieldTmp(in[0]); + auto tmp3 = getFieldTmp(in[0]); + auto tmp4 = getFieldTmp(in[0]); + auto tmp5 = getFieldTmp(in[0]); ColorSpinorField &b1 = tmp1; ColorSpinorField &b2 = tmp2; ColorSpinorField &Mb1 = tmp3; @@ -216,19 +217,21 @@ namespace quda const double lambda_max = (1 + 8 * kappa); const double rho = 4 - 1 / (2 * kappa); - // signLow(out, deflated, hermitian_wilson_evecs, hermitian_wilson_evals, hermitian_wilson_n_eig); + //signLow(out, deflated, hermitian_wilson_evecs, hermitian_wilson_evals, hermitian_wilson_n_eig); std::vector s(hermitian_wilson_n_eig); - blas::cDotProduct(s, hermitian_wilson_evecs, in); + blas::block::cDotProduct(s, hermitian_wilson_evecs, in[0]); for (int i = 0; i < hermitian_wilson_n_eig; i++) { s[i] *= -1; } - blas::caxpyz(s, hermitian_wilson_evecs, in, deflated); + blas::block::caxpyz(s, hermitian_wilson_evecs, in[0], deflated); for (int i = 0; i < hermitian_wilson_n_eig; i++) { s[i] *= -hermitian_wilson_evals[i] / abs(hermitian_wilson_evals[i]); } - out.zero(); - blas::caxpy(s, hermitian_wilson_evecs, out); - gamma5(out, out); + out[0].zero(); + + blas::block::caxpy(s, hermitian_wilson_evecs, out[0]); + + gamma5(out[0], out[0]); - // signHighPolynomial(b1, b2, Ab1, deflated, mat, remez_c, remez_n, epsilon, lambda_max); + //signHighPolynomial(b1, b2, Ab1, deflated, mat, remez_c, remez_n, epsilon, lambda_max); b1.zero(); b2.zero(); for (int k = remez_n; k >= 1; --k) { @@ -240,6 +243,7 @@ namespace quda blas::axpbypczw(remez_c[k], deflated, 2, Ab1, -1, b2, b2); std::swap(b1, b2); } + DiracWilson::M(Mb1, b1); flip(dagger); DiracWilson::M(Ab1, Mb1); @@ -247,31 +251,32 @@ namespace quda blas::axpby(-(1 + epsilon) / (1 - epsilon), b1, 2 / (1 - epsilon) / (lambda_max * lambda_max), Ab1); blas::axpbypczw(remez_c[0], deflated, 1, Ab1, -1, b2, b2); DiracWilson::M(b1, b2); - - blas::axpbypczw(rho, in, rho / lambda_max, b1, rho, out, out); + blas::axpbypczw(rho, in[0], rho / lambda_max, b1, rho, out[0], out[0]); } - void DiracOverlap::MdagM(ColorSpinorField &out, const ColorSpinorField &in) const + void DiracOverlap::MdagM(cvector_ref &out, cvector_ref &in) const { - checkFullSpinor(out, in); - auto tmp = getFieldTmp(in); + checkFullSpinor(out[0], in[0]); + auto tmp = getFieldTmp(in[0]); - M(tmp, in); - Mdag(out, tmp); + M(tmp, in[0]); + Mdag(out[0], tmp); } - void DiracOverlap::prepare(ColorSpinorField *&src, ColorSpinorField *&sol, ColorSpinorField &x, ColorSpinorField &b, + void DiracOverlap::prepare(cvector_ref &sol, cvector_ref &src, cvector_ref &x, cvector_ref &b, const QudaSolutionType solType) const { if (solType == QUDA_MATPC_SOLUTION || solType == QUDA_MATPCDAG_MATPC_SOLUTION) { errorQuda("Preconditioned solution requires a preconditioned solve_type"); } - src = &b; - sol = &x; + for (auto i = 0u; i < b.size(); i++) { + src[i] = const_cast(b[i]).create_alias(); + sol[i] = x[i].create_alias(); + } } - void DiracOverlap::reconstruct(ColorSpinorField &, const ColorSpinorField &, const QudaSolutionType) const + void DiracOverlap::reconstruct(cvector_ref &, cvector_ref &, const QudaSolutionType) const { // do nothing } @@ -284,6 +289,7 @@ namespace quda void DiracOverlap::setupHermitianWilson(int n_eig, const std::vector &evecs, const std::vector &evals, double invsqrt_tol) const { + printfQuda("Entering DiracOverlap::setupHermitianWilson\n"); hermitian_wilson_n_eig = n_eig; hermitian_wilson_evecs.resize(n_eig); hermitian_wilson_evals.resize(n_eig); From cb16be229fa31b2510ff1bec8a26a06a927fd5d8 Mon Sep 17 00:00:00 2001 From: V3-vvv <175412402@qq.com> Date: Thu, 12 Dec 2024 15:36:24 +0800 Subject: [PATCH 03/32] enable code to run overlap DiracM eigensolver --- lib/interface_quda.cpp | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index 9b4292892e..4b670be915 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -105,6 +105,11 @@ CloverField *cloverPrecondition = nullptr; CloverField *cloverRefinement = nullptr; CloverField *cloverEigensolver = nullptr; +ColorSpinorField *OverlapPrecise = nullptr; +ColorSpinorField *OverlapSloppy = nullptr; +ColorSpinorField *OverlapPrecondition = nullptr; +ColorSpinorField *OverlapEigensolver = nullptr; + GaugeField momResident; GaugeField *extendedGaugeResident = nullptr; @@ -2776,7 +2781,8 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam // Create the dirac operator with a sloppy and a precon. bool pc_solve = (inv_param->solve_type == QUDA_DIRECT_PC_SOLVE) || (inv_param->solve_type == QUDA_NORMOP_PC_SOLVE); createDiracWithEig(d, dSloppy, dPre, dEig, *inv_param, pc_solve); - Dirac &dirac = *dEig; + // Dirac &dirac = *dEig; + Dirac *dirac = dEig; //------------------------------------------------------ // Construct vectors @@ -2847,6 +2853,16 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam if (!eig_param->use_norm_op && !eig_param->use_dagger && eig_param->compute_gamma5) { m = new DiracG5M(dirac); } else if (!eig_param->use_norm_op && !eig_param->use_dagger && !eig_param->compute_gamma5) { + // Setup eigensystem for hermitian Wilson operator + if (inv_param->dslash_type == QUDA_OVERLAP_DSLASH) { + const auto &gauge = *gaugePrecise; + const int n_eig = inv_param->hermitian_wilson_n_ev; + const double invsqrt_tol = inv_param->overlap_invsqrt_tol; + std::vector evecs(n_eig); + std::vector evals(n_eig); + setupHermitianWilson(inv_param, gauge.X(), evecs, evals); + ((DiracOverlap*)dirac)->setupHermitianWilson(n_eig, evecs, evals, invsqrt_tol); + } m = new DiracM(dirac); } else if (!eig_param->use_norm_op && eig_param->use_dagger) { m = new DiracMdag(dirac); From 9468668cc88c98a1fd23df97628e3442fced88ac Mon Sep 17 00:00:00 2001 From: V3-vvv <175412402@qq.com> Date: Mon, 16 Dec 2024 17:39:20 +0800 Subject: [PATCH 04/32] modify interface_quda.cpp for overlap settings --- lib/interface_quda.cpp | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index 4b670be915..17c4229124 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -2844,6 +2844,18 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam eig_param->use_dagger ? "true" : "false", eig_param->use_norm_op ? "true" : "false"); } } + + // pre-setttings for the overlap operator + if (inv_param->dslash_type == QUDA_OVERLAP_DSLASH) { + const auto &gauge = *gaugePrecise; + const int n_eig = inv_param->hermitian_wilson_n_ev; + const double invsqrt_tol = inv_param->overlap_invsqrt_tol; + std::vector evecs(n_eig); + std::vector evals(n_eig); + setupHermitianWilson(inv_param, gauge.X(), evecs, evals); + ((DiracOverlap*)dirac)->setupHermitianWilson(n_eig, evecs, evals, invsqrt_tol); + } + //------------------------------------------------------ // We must construct the correct Dirac operator type based on the three // options: The normal operator, the daggered operator, and if we pre @@ -2853,16 +2865,6 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam if (!eig_param->use_norm_op && !eig_param->use_dagger && eig_param->compute_gamma5) { m = new DiracG5M(dirac); } else if (!eig_param->use_norm_op && !eig_param->use_dagger && !eig_param->compute_gamma5) { - // Setup eigensystem for hermitian Wilson operator - if (inv_param->dslash_type == QUDA_OVERLAP_DSLASH) { - const auto &gauge = *gaugePrecise; - const int n_eig = inv_param->hermitian_wilson_n_ev; - const double invsqrt_tol = inv_param->overlap_invsqrt_tol; - std::vector evecs(n_eig); - std::vector evals(n_eig); - setupHermitianWilson(inv_param, gauge.X(), evecs, evals); - ((DiracOverlap*)dirac)->setupHermitianWilson(n_eig, evecs, evals, invsqrt_tol); - } m = new DiracM(dirac); } else if (!eig_param->use_norm_op && eig_param->use_dagger) { m = new DiracMdag(dirac); From cbc3ebc5ee09cf8edd08f19a3b7338e619024894 Mon Sep 17 00:00:00 2001 From: V3-vvv <175412402@qq.com> Date: Mon, 16 Dec 2024 19:29:18 +0800 Subject: [PATCH 05/32] code modification --- lib/dirac_overlap.cpp | 43 +++--------------------------------------- lib/interface_quda.cpp | 3 ++- 2 files changed, 5 insertions(+), 41 deletions(-) diff --git a/lib/dirac_overlap.cpp b/lib/dirac_overlap.cpp index 7d7e8973e7..cf5143d071 100644 --- a/lib/dirac_overlap.cpp +++ b/lib/dirac_overlap.cpp @@ -158,43 +158,6 @@ namespace quda return *this; } - // Apply sign function for small eigenvalues by applying lambda_i/|lambda_i|*|V_i> &evecs, - const std::vector &evals, int n_eig) - { - logQuda(QUDA_VERBOSE, "Deflating %d vectors\n", n_eig); - - // = A_i - std::vector s(n_eig); - blas::cDotProduct(s, {evecs.begin(), evecs.begin() + n_eig}, src); - - // src -= A_i|V_i> - for (int i = 0; i < n_eig; i++) { s[i] *= -1; } - blas::caxpy(s, {evecs.begin(), evecs.begin() + n_eig}, src); - - // sol += lambda_i/|lambda_i|*A_i|V_i> - for (int i = 0; i < n_eig; i++) { s[i] *= -evals[i] / abs(evals[i]); } - blas::zero(sol); - blas::caxpy(s, {evecs.begin(), evecs.begin() + n_eig}, sol); - } - - void signHighPolynomial(ColorSpinorField &b1, ColorSpinorField &b2, ColorSpinorField &Ab1, const ColorSpinorField &in, - DiracMatrix *mat, std::vector &remez_c, int remez_n, const double epsilon, - const double lambda_max) - { - b1.zero(); - b2.zero(); - for (int k = remez_n; k >= 1; --k) { - (*mat)(Ab1, b1); - blas::axpbyz(-(1 + epsilon) / (1 - epsilon), b1, 2 / (1 - epsilon) / (lambda_max * lambda_max), Ab1, Ab1); - blas::axpbypczw(remez_c[k], in, 2, Ab1, -1, b2, b2); - std::swap(b1, b2); - } - (*mat)(Ab1, b1); - blas::axpbyz(-(1 + epsilon) / (1 - epsilon), b1, 2 / (1 - epsilon) / (lambda_max * lambda_max), Ab1, Ab1); - blas::axpbypczw(remez_c[0], in, 1, Ab1, -1, b2, b2); - } - #define flip(x) (x) = ((x) == QUDA_DAG_YES ? QUDA_DAG_NO : QUDA_DAG_YES) void DiracOverlap::M(cvector_ref &out, cvector_ref &in) const @@ -240,7 +203,7 @@ namespace quda DiracWilson::M(Ab1, Mb1); flip(dagger); blas::axpby(-(1 + epsilon) / (1 - epsilon), b1, 2 / (1 - epsilon) / (lambda_max * lambda_max), Ab1); - blas::axpbypczw(remez_c[k], deflated, 2, Ab1, -1, b2, b2); + blas::axpbypczw(remez_c[k], deflated, 2.0, Ab1, -1.0, b2, b2); std::swap(b1, b2); } @@ -249,7 +212,7 @@ namespace quda DiracWilson::M(Ab1, Mb1); flip(dagger); blas::axpby(-(1 + epsilon) / (1 - epsilon), b1, 2 / (1 - epsilon) / (lambda_max * lambda_max), Ab1); - blas::axpbypczw(remez_c[0], deflated, 1, Ab1, -1, b2, b2); + blas::axpbypczw(remez_c[0], deflated, 1.0, Ab1, -1.0, b2, b2); DiracWilson::M(b1, b2); blas::axpbypczw(rho, in[0], rho / lambda_max, b1, rho, out[0], out[0]); } @@ -317,4 +280,4 @@ namespace quda remez_n = remez_c.size() - 1; } -} // namespace quda +} // namespace quda \ No newline at end of file diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index 17c4229124..2b90138c5e 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -1843,7 +1843,8 @@ namespace quda { eig_param.vec_outfile[0] = 0; eig_param.max_restarts = 1000; - auto *eig_solve = quda::EigenSolver::create(&eig_param, *mWilson, profileEigensolve); + // auto *eig_solve = quda::EigenSolver::create(&eig_param, *mWilson, profileEigensolve); + auto *eig_solve = quda::EigenSolver::create(&eig_param, *mWilson); (*eig_solve)(evecs, evals); delete eig_solve; From d506ae7d84b03802cb4a04ac2427322e5b903bf1 Mon Sep 17 00:00:00 2001 From: SaltyChiang Date: Thu, 26 Dec 2024 03:01:47 +0800 Subject: [PATCH 06/32] Add chrial projection for spinor field. --- include/enum_quda.h | 6 + include/kernels/spinor_chiral_project.cuh | 89 ++++++ .../kernels/spinor_chiral_project_helper.cuh | 285 ++++++++++++++++++ lib/CMakeLists.txt | 1 + lib/spinor_chiral_project.cu | 135 +++++++++ 5 files changed, 516 insertions(+) create mode 100644 include/kernels/spinor_chiral_project.cuh create mode 100644 include/kernels/spinor_chiral_project_helper.cuh create mode 100644 lib/spinor_chiral_project.cu diff --git a/include/enum_quda.h b/include/enum_quda.h index b48d154dc4..6a485fc414 100644 --- a/include/enum_quda.h +++ b/include/enum_quda.h @@ -640,6 +640,12 @@ typedef enum QudaExtLibType_s { QUDA_EXTLIB_INVALID = QUDA_INVALID_ENUM } QudaExtLibType; +typedef enum QudaChirality_s { + QUDA_CHIRALITY_UPPER, + QUDA_CHIRALITY_LOWER, + QUDA_CHIRALITY_INVALID = QUDA_INVALID_ENUM +} QudaChirality; + #ifdef __cplusplus } #endif diff --git a/include/kernels/spinor_chiral_project.cuh b/include/kernels/spinor_chiral_project.cuh new file mode 100644 index 0000000000..79e95156de --- /dev/null +++ b/include/kernels/spinor_chiral_project.cuh @@ -0,0 +1,89 @@ +#include +#include +#include +#include +#include "spinor_chiral_project_helper.cuh" + +namespace quda +{ + using namespace colorspinor; + + template class Basis_> + struct ChiralProjectSpinorArg : kernel_param<> { + using real = typename mapper::type; + static constexpr int nSpinOut = 2; + static constexpr int nSpinIn = 4; + static constexpr int nColor = nColor_; + static constexpr QudaChirality Chirality = Chirality_; + using Basis = Basis_<4, nColor_, Chirality_, true>; + using Vout = typename colorspinor_mapper::type; + using Vin = typename colorspinor_mapper::type; + + int X[4]; + Vout out; + const Vin in; + ChiralProjectSpinorArg(ColorSpinorField &out, const ColorSpinorField &in) : + kernel_param(dim3(in.VolumeCB(), in.SiteSubset(), 1)), out(out), in(in) + { + for (int dir = 0; dir < 4; dir++) X[dir] = in.X()[dir]; + X[0] *= (in.SiteSubset() == 1) ? 2 : 1; // need full lattice dims + } + }; + + template class Basis_> + struct ChrialEmbedSpinorArg : kernel_param<> { + using real = typename mapper::type; + static constexpr int nSpinOut = 4; + static constexpr int nSpinIn = 2; + static constexpr int nColor = nColor_; + static constexpr QudaChirality Chirality = Chirality_; + using Basis = Basis_<4, nColor_, Chirality_, false>; + using Vout = typename colorspinor_mapper::type; + using Vin = typename colorspinor_mapper::type; + + int X[4]; + Vout out; + const Vin in; + ChrialEmbedSpinorArg(ColorSpinorField &out, const ColorSpinorField &in) : + kernel_param(dim3(in.VolumeCB(), in.SiteSubset(), 1)), out(out), in(in) + { + for (int dir = 0; dir < 4; dir++) X[dir] = in.X()[dir]; + X[0] *= (in.SiteSubset() == 1) ? 2 : 1; // need full lattice dims + } + }; + + template struct ChiralEmbedSpinor { + const Arg &arg; + constexpr ChiralEmbedSpinor(const Arg &arg) : arg(arg) { } + static constexpr const char *filename() { return KERNEL_FILE; } + + __device__ __host__ void operator()(int x_cb, int parity) + { + using VectorOut = ColorSpinor; + using VectorIn = ColorSpinor; + int x[4]; + getCoords(x, x_cb, arg.X, parity); + VectorOut out; + VectorIn in = arg.in(x_cb, parity); + arg.out(x_cb, parity) = out; + } + }; + + template struct ChiralProjectSpinor { + const Arg &arg; + constexpr ChiralProjectSpinor(const Arg &arg) : arg(arg) { } + static constexpr const char *filename() { return KERNEL_FILE; } + + __device__ __host__ void operator()(int x_cb, int parity) + { + using VectorOut = ColorSpinor; + using VectorIn = ColorSpinor; + int x[4]; + getCoords(x, x_cb, arg.X, parity); + VectorOut out; + VectorIn in = arg.in(x_cb, parity); + arg.out(x_cb, parity) = out; + } + }; + +} // namespace quda diff --git a/include/kernels/spinor_chiral_project_helper.cuh b/include/kernels/spinor_chiral_project_helper.cuh new file mode 100644 index 0000000000..58988aedfb --- /dev/null +++ b/include/kernels/spinor_chiral_project_helper.cuh @@ -0,0 +1,285 @@ +#include +#include + +#define PRESERVE_SPINOR_NORM + +#ifdef PRESERVE_SPINOR_NORM // Preserve the norm regardless of basis +#define kP (1.0 / sqrt(2.0)) +#define kU (1.0 / sqrt(2.0)) +#else // More numerically accurate not to preserve the norm between basis +#define kP (0.5) +#define kU (1.0) +#endif + +namespace quda +{ + + using namespace colorspinor; + + /** Straight copy with no basis change */ + template struct PreserveBasis { + template + __device__ __host__ inline void operator()(complex out[Ns * Nc], const complex in[Ns * Nc]) const + { + const int offset = (Chirality == QUDA_CHIRALITY_UPPER) ? 0 : Ns / 2; + if constexpr (Project) { + for (int s = 0; s < Ns / 2; s++) + for (int c = 0; c < Nc; c++) out[s * Nc + c] = in[(offset + s) * Nc + c]; + } else { + for (int s = 0; s < Ns / 2; s++) + for (int c = 0; c < Nc; c++) out[(offset + s) * Nc + c] = in[s * Nc + c]; + } + } + }; + + /** Transform from relativistic Degrand-Rossi into non-relativistic UKQCD basis */ + template struct NonRelBasis { + template + __device__ __host__ inline void operator()(complex out[Ns * Nc], const complex in[Ns * Nc]) const + { + constexpr bool upper = (Chirality == QUDA_CHIRALITY_UPPER); + constexpr int offset = upper ? 0 : Ns / 2; + if constexpr (Project) { + int s1[4] = {1, 2, 3, 0}; + int s2[4] = {3, 0, 1, 2}; + FloatOut K1[4] = {static_cast(kP), static_cast(-kP), static_cast(-kP), + static_cast(-kP)}; + FloatOut K2[4] = {static_cast(kP), static_cast(-kP), static_cast(kP), + static_cast(kP)}; + for (int s = 0; s < Ns / 2; s++) { + for (int c = 0; c < Nc; c++) { + out[s * Nc + c] = K1[offset + s] * static_cast>(in[s1[offset + s] * Nc + c]) + + K2[offset + s] * static_cast>(in[s2[offset + s] * Nc + c]); + } + } + } else { + int s1[4] = {1, 0, 1, 0}; + FloatOut K1[4] = {static_cast(kP), static_cast(-kP), + upper ? static_cast(kP) : static_cast(-kP), + upper ? static_cast(-kP) : static_cast(kP)}; + for (int s = 0; s < Ns; s++) { + for (int c = 0; c < Nc; c++) { out[s * Nc + c] = K1[s] * static_cast>(in[s1[s] * Nc + c]); } + } + } + } + }; + + /** Transform from non-relativistic UKQCD into relativistic Degrand-Rossi basis */ + template struct RelBasis { + template + __device__ __host__ inline void operator()(complex out[Ns * Nc], const complex in[Ns * Nc]) const + { + constexpr bool upper = (Chirality == QUDA_CHIRALITY_UPPER); + constexpr int offset = upper ? 0 : Ns / 2; + if constexpr (Project) { + int s1[4] = {1, 2, 3, 0}; + int s2[4] = {3, 0, 1, 2}; + FloatOut K1[4] = {static_cast(-kU), static_cast(kU), static_cast(kU), + static_cast(kU)}; + FloatOut K2[4] = {static_cast(-kU), static_cast(kU), static_cast(-kU), + static_cast(-kU)}; + for (int s = 0; s < Ns / 2; s++) { + for (int c = 0; c < Nc; c++) { + out[s * Nc + c] = K1[offset + s] * static_cast>(in[s1[offset + s] * Nc + c]) + + K2[offset + s] * static_cast>(in[s2[offset + s] * Nc + c]); + } + } + } else { + int s1[4] = {1, 0, 1, 0}; + FloatOut K1[4] = {static_cast(-kU), static_cast(kU), + upper ? static_cast(-kU) : static_cast(kU), + upper ? static_cast(kU) : static_cast(-kU)}; + for (int s = 0; s < Ns; s++) { + for (int c = 0; c < Nc; c++) { out[s * Nc + c] = K1[s] * static_cast>(in[s1[s] * Nc + c]); } + } + } + } + }; + + /** Transform from non-relativistic Dirac-Pauli into relativistic Degrand-Rossi basis */ + template struct DegrandRossiToDiracPaulBasis { + template + __device__ __host__ inline void operator()(complex out[Ns * Nc], const complex in[Ns * Nc]) const + { + constexpr bool upper = (Chirality == QUDA_CHIRALITY_UPPER); + constexpr int offset = upper ? 0 : Ns / 2; + if constexpr (Project) { + int s1[4] = {1, 2, 1, 0}; + int s2[4] = {3, 0, 3, 2}; + FloatOut K1[4] = {static_cast(-kU), static_cast(kU), static_cast(kU), + static_cast(-kU)}; + FloatOut K2[4] = {static_cast(-kU), static_cast(kU), static_cast(-kU), + static_cast(kU)}; + for (int s = 0; s < Ns / 2; s++) { + for (int c = 0; c < Nc; c++) { + out[s * Nc + c] = K1[offset + s] * static_cast>(in[s1[offset + s] * Nc + c]) + + K2[offset + s] * static_cast>(in[s2[offset + s] * Nc + c]); + } + } + } else { + int s1[4] = {1, 0, 1, 0}; + FloatOut K1[4] = {static_cast(-kU), static_cast(kU), + upper ? static_cast(kU) : static_cast(-kU), + upper ? static_cast(-kU) : static_cast(kU)}; + for (int s = 0; s < Ns; s++) { + for (int c = 0; c < Nc; c++) { out[s * Nc + c] = K1[s] * static_cast>(in[s1[s] * Nc + c]); } + } + } + } + }; + + /** Transform from relativistic Degrand-Rossi into non-relativistic Dirac-Pauli basis */ + template struct DiracPaulToDegrandRossiBasis { + template + __device__ __host__ inline void operator()(complex out[Ns * Nc], const complex in[Ns * Nc]) const + { + constexpr bool upper = (Chirality == QUDA_CHIRALITY_UPPER); + constexpr int offset = upper ? 0 : Ns / 2; + if constexpr (Project) { + int s1[4] = {1, 2, 1, 0}; + int s2[4] = {3, 0, 3, 2}; + FloatOut K1[4] = {static_cast(kP), static_cast(kP), static_cast(kP), + static_cast(-kP)}; + FloatOut K2[4] = {static_cast(-kP), static_cast(-kP), static_cast(kP), + static_cast(-kP)}; + for (int s = 0; s < Ns / 2; s++) { + for (int c = 0; c < Nc; c++) { + out[s * Nc + c] = K1[offset + s] * static_cast>(in[s1[offset + s] * Nc + c]) + + K2[offset + s] * static_cast>(in[s2[offset + s] * Nc + c]); + } + } + } else { + int s1[4] = {1, 0, 1, 0}; + FloatOut K1[4] = {upper ? static_cast(kP) : static_cast(-kP), + upper ? static_cast(-kP) : static_cast(kP), static_cast(kP), + static_cast(-kP)}; + for (int s = 0; s < Ns; s++) { + for (int c = 0; c < Nc; c++) { out[s * Nc + c] = K1[s] * static_cast>(in[s1[s] * Nc + c]); } + } + } + } + }; + + /** Transform from chiral into UKQCD non-relativistic basis */ + template struct ChiralToNonRelBasis { + template + __device__ __host__ inline void operator()(complex out[Ns * Nc], const complex in[Ns * Nc]) const + { + constexpr bool upper = (Chirality == QUDA_CHIRALITY_UPPER); + constexpr int offset = upper ? 0 : Ns / 2; + if constexpr (Project) { + int s1[4] = {0, 1, 0, 1}; + int s2[4] = {2, 3, 2, 3}; + FloatOut K1[4] = {static_cast(-kP), static_cast(-kP), static_cast(kP), + static_cast(kP)}; + FloatOut K2[4] + = {static_cast(kP), static_cast(kP), static_cast(kP), static_cast(kP)}; + for (int s = 0; s < Ns / 2; s++) { + for (int c = 0; c < Nc; c++) { + out[s * Nc + c] = K1[offset + s] * static_cast>(in[s1[offset + s] * Nc + c]) + + K2[offset + s] * static_cast>(in[s2[offset + s] * Nc + c]); + } + } + } else { + int s1[4] = {0, 1, 0, 1}; + FloatOut K1[4] = {upper ? static_cast(-kP) : static_cast(kP), + upper ? static_cast(-kP) : static_cast(kP), static_cast(kP), + static_cast(kP)}; + for (int s = 0; s < Ns; s++) { + for (int c = 0; c < Nc; c++) { out[s * Nc + c] = K1[s] * static_cast>(in[s1[s] * Nc + c]); } + } + } + } + }; + + /** Transform from UKQCD non-relativistic into chiral basis */ + template struct NonRelToChiralBasis { + template + __device__ __host__ inline void operator()(complex out[Ns * Nc], const complex in[Ns * Nc]) const + { + constexpr bool upper = (Chirality == QUDA_CHIRALITY_UPPER); + constexpr int offset = upper ? 0 : Ns / 2; + if constexpr (Project) { + int s1[4] = {0, 1, 0, 1}; + int s2[4] = {2, 3, 2, 3}; + FloatOut K1[4] = {static_cast(-kU), static_cast(-kU), static_cast(kU), + static_cast(kU)}; + FloatOut K2[4] + = {static_cast(kU), static_cast(kU), static_cast(kU), static_cast(kU)}; + for (int s = 0; s < Ns / 2; s++) { + for (int c = 0; c < Nc; c++) { + out[s * Nc + c] = K1[offset + s] * static_cast>(in[s1[offset + s] * Nc + c]) + + K2[offset + s] * static_cast>(in[s2[offset + s] * Nc + c]); + } + } + } else { + int s1[4] = {0, 1, 0, 1}; + FloatOut K1[4] = {upper ? static_cast(-kU) : static_cast(kU), + upper ? static_cast(-kU) : static_cast(kU), static_cast(kU), + static_cast(kU)}; + for (int s = 0; s < Ns; s++) { + for (int c = 0; c < Nc; c++) { out[s * Nc + c] = K1[s] * static_cast>(in[s1[s] * Nc + c]); } + } + } + } + }; + + /** Transform from chiral into DeGrand-Rossi basis or from DeGrand-Rossi into chiral basis */ + template struct ChiralToFromDegrandRossiBasis { + template + __device__ __host__ inline void operator()(complex out[Ns * Nc], const complex in[Ns * Nc]) const + { + constexpr bool upper = (Chirality == QUDA_CHIRALITY_UPPER); + constexpr int offset = upper ? 0 : Ns / 2; + if constexpr (Project) { + int s1[4] = {3, 2, 1, 0}; + FloatOut K1[4] = {static_cast(-1.0), static_cast(1.0), static_cast(1.0), + static_cast(-1.0)}; + for (int s = 0; s < Ns / 2; s++) { + for (int c = 0; c < Nc; c++) { + out[s * Nc + c] = K1[offset + s] * static_cast>(in[s1[offset + s] * Nc + c]); + } + } + } else { + int s1[4] = {1, 0, 1, 0}; + FloatOut K1[4] = {upper ? static_cast(0.0) : static_cast(-1.0), + upper ? static_cast(0.0) : static_cast(1.0), + upper ? static_cast(1.0) : static_cast(0.0), + upper ? static_cast(-1.0) : static_cast(0.0)}; + for (int s = 0; s < Ns; s++) { + for (int c = 0; c < Nc; c++) { out[s * Nc + c] = K1[s] * static_cast>(in[s1[s] * Nc + c]); } + } + } + } + }; + + /** Transform from UKQCD to Dirac-Pauli and from Dirac-Pauli into UKQCD basis */ + template struct UKQCDToFromDiracPauliBasis { + template + __device__ __host__ inline void operator()(complex out[Ns * Nc], const complex in[Ns * Nc]) const + { + constexpr bool upper = (Chirality == QUDA_CHIRALITY_UPPER); + constexpr int offset = upper ? 0 : Ns / 2; + if constexpr (Project) { + int s1[4] = {0, 1, 2, 3}; + FloatOut K1[4] = {static_cast(-1.0), static_cast(-1.0), static_cast(1.0), + static_cast(1.0)}; + for (int s = 0; s < Ns / 2; s++) { + for (int c = 0; c < Nc; c++) { + out[s * Nc + c] = K1[offset + s] * static_cast>(in[s1[offset + s] * Nc + c]); + } + } + } else { + int s1[4] = {0, 1, 0, 1}; + FloatOut K1[4] = {upper ? static_cast(-1.0) : static_cast(0.0), + upper ? static_cast(-1.0) : static_cast(0.0), + upper ? static_cast(0.0) : static_cast(1.0), + upper ? static_cast(0.0) : static_cast(1.0)}; + for (int s = 0; s < Ns; s++) { + for (int c = 0; c < Nc; c++) { out[s * Nc + c] = K1[s] * static_cast>(in[s1[s] * Nc + c]); } + } + } + } + }; + +} // namespace quda diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 954c5c4ffb..b3f4a7b9af 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -57,6 +57,7 @@ set (QUDA_OBJS clover_deriv_quda.cu clover_invert.cu copy_gauge_extended.cu extract_gauge_ghost_extended.cu copy_color_spinor.cpp spinor_noise.cu spinor_dilute.cu spinor_reweight.cu + spinor_chiral_project.cu copy_color_spinor_dd.cu copy_color_spinor_ds.cu copy_color_spinor_dh.cu copy_color_spinor_dq.cu copy_color_spinor_ss.cu copy_color_spinor_sd.cu diff --git a/lib/spinor_chiral_project.cu b/lib/spinor_chiral_project.cu new file mode 100644 index 0000000000..b134b18721 --- /dev/null +++ b/lib/spinor_chiral_project.cu @@ -0,0 +1,135 @@ +#include +#include +#include +#include + +namespace quda +{ + + template class SpinorChiralEmbed : TunableKernel2D + { + ColorSpinorField &out; + const ColorSpinorField ∈ + template