Documentation
This module handles MPI configuration, backend selection, device initialization, and the QUDA library setup.
The init() function initializes the QUDA library to perform lattice QCD calculations on GPUs. You should set the grid to let QUDA know how to split the lattice. The following code defines a grid with size Gx, Gy, Gz, Gt = 1, 1, 1, 2, which means we use 2 GPUs to split the lattice into 2 parts in the t direction.
from pyquda import init
init([1, 1, 1, 2])
Run this script with mpiexec -n 2. The number of MPI processes must equal the grid volume (Gx * Gy * Gz * Gt). The default grid is [1, 1, 1, 1], so you can omit the argument when running on a single GPU.
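The relation between the grid, the process count, and the per-process sublattice can be sketched in plain Python. The helper names `grid_volume` and `local_lattice_size` are hypothetical and not part of PyQUDA, and the sketch assumes every lattice direction is divided evenly by the grid.

```python
from math import prod


def grid_volume(grid_size):
    """Number of MPI processes required by a grid [Gx, Gy, Gz, Gt]."""
    return prod(grid_size)


def local_lattice_size(latt_size, grid_size):
    """Per-process sublattice, assuming each direction divides evenly."""
    if len(latt_size) != len(grid_size):
        raise ValueError("latt_size and grid_size must have the same length")
    for extent, parts in zip(latt_size, grid_size):
        if extent % parts != 0:
            raise ValueError(f"extent {extent} is not divisible by {parts}")
    return [extent // parts for extent, parts in zip(latt_size, grid_size)]


# A global [4, 4, 4, 8] lattice split over the grid [1, 1, 1, 2]:
print(grid_volume([1, 1, 1, 2]))                       # processes to launch
print(local_lattice_size([4, 4, 4, 8], [1, 1, 1, 2]))  # sublattice per process
```

A check like this before launching makes it easier to catch a mismatch between the mpiexec process count and the grid.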
The init() function accepts many keyword arguments to configure the QUDA runtime:
init(
grid_size=[1, 1, 1, 2],
latt_size=[4, 4, 4, 8], # set the default lattice
backend="cupy", # "cupy", "torch", "numpy", or "dpnp"
resource_path=".cache", # QUDA tuning cache directory
enable_mps=False, # enable CUDA MPS
enable_gdr=False, # enable GPUDirect RDMA
)
CAUTION: Initialization must be performed before any other operation in PyQUDA.
There are functions to get the MPI and backend configuration:
- getMPIComm(): returns the MPI communicator
- getMPISize(): returns the total number of MPI processes
- getMPIRank(): returns the rank of the current process
- getGridSize(): returns the grid size [Gx, Gy, Gz, Gt]
- getGridCoord(): returns the grid coordinate of the current process
- getGridRanks(): returns the mapping from grid coordinate to MPI rank
- getArrayBackend(): returns the backend name ("cupy", "torch", "numpy", "dpnp")
- getArrayBackendTarget(): returns the backend target ("cuda", "hip", "cpu", etc.)
- getArrayDevice(): returns the current device index
PyQUDA provides a logger that ensures each message or error is printed only once. The logging levels are the same as in Python's logging package.
from pyquda import getLogger
logger = getLogger()
logger.debug("This is DEBUG")
logger.info("This is INFO")
logger.warning("This is WARNING", RuntimeWarning)
logger.error("This is ERROR", RuntimeError)
logger.critical("This is CRITICAL", RuntimeError)
QUDA's functions and parameter structs are stored in the pyquda.pyquda submodule, which is the main wrapper around the QUDA functions. Do not call functions from this module directly unless you know what you are doing. The QUDA parameter structs are also defined here, and constructing such a struct by hand is not recommended.
from pyquda import QudaInvertParam
inv_param = QudaInvertParam()
QUDA's enums are stored in the pyquda.enum_quda submodule, which is essentially a translation of QUDA's enum_quda.h.
from pyquda.enum_quda import QudaInverterType
inv_param.inv_type = QudaInverterType.QUDA_CG_INVERTER
The pyquda.field module defines the classes that store process-specific data on GPUs. Many pure gauge functions are bound to the LatticeGauge class for convenience.
The LatticeInfo class holds the information needed to construct a lattice across multiple processes. It defines the size of the lattice and the grid used to split it, along with the extra parameters that define the lattice (the t-boundary and the anisotropy).
For example, the code below defines a lattice with the global size Lx, Ly, Lz, Lt = 4, 4, 4, 8. The lattice is anti-periodic on the t-boundary and isotropic.
from pyquda.field import LatticeInfo
latt_info = LatticeInfo([4, 4, 4, 8], t_boundary=-1, anisotropy=1.0)
A default lattice can be set during init():
from pyquda import init
init([1, 1, 1, 2], latt_size=[4, 4, 4, 8])
Once we have a LatticeInfo instance, we can create lattice field objects. Each field keeps the LatticeInfo instance in its latt_info attribute.
from pyquda.field import LatticeGauge, LatticePropagator
gauge = LatticeGauge(latt_info)
propagator = LatticePropagator(latt_info)
Data stored in a field can be accessed via the data attribute. Use getHost() to materialize a host-side copy, or copy() to deep-copy the entire field object.
gauge_data = gauge.data
gauge_host_copy = gauge.getHost()
gauge_copy = gauge.copy()
The location attribute identifies where the data currently resides. toDevice() and toHost() transfer the data between host and device memory.
print(gauge.location) # cupy
gauge.toHost()
print(gauge.location) # numpy
gauge.toDevice()
print(gauge.location) # cupy
All data is stored in even-odd preconditioned format. lexico() returns a numpy.ndarray copy of the data without even-odd preconditioning.
print(gauge.data.shape) # (Nd, 2, Lt, Lz, Ly, Lx // 2, Nc, Nc)
print(gauge.lexico().shape) # (Nd, Lt, Lz, Ly, Lx, Nc, Nc)
print(propagator.data.shape) # (2, Lt, Lz, Ly, Lx // 2, Ns, Ns, Nc, Nc)
print(propagator.lexico().shape) # (Lt, Lz, Ly, Lx, Ns, Ns, Nc, Nc)
For LatticeGauge, you can also access individual directions and the even/odd parity components:
gauge_x = gauge[0] # LatticeLink for the x-direction
gauge_even = gauge.even # even sites
gauge_odd = gauge.odd # odd sites
The data attribute holds a numpy.ndarray, cupy.ndarray, or torch.Tensor, depending on the backend. The dtype is complex128 for most field types and float64 for momentum and clover fields. All Dirac indices use the DeGrand-Rossi basis.
The shapes of data in different fields:
- LatticeGauge.data: (Nd, 2, Lt, Lz, Ly, Lx // 2, Nc, Nc), row-column order, complex
- LatticeMom.data: (Nd, 2, Lt, Lz, Ly, Lx // 2, 10), lower triangle (6) + diagonal (3) + reserved (1), real
- LatticeClover.data: (2, Lt, Lz, Ly, Lx // 2, 2, ((Ns // 2) * Nc) ** 2), upper-left + lower-right block, diagonal (6) + lower triangle (30), real
- LatticeRotation.data: (1, 2, Lt, Lz, Ly, Lx // 2, Nc, Nc), complex
- LatticeFermion.data: (2, Lt, Lz, Ly, Lx // 2, Ns, Nc), complex
- LatticePropagator.data: (2, Lt, Lz, Ly, Lx // 2, Ns, Ns, Nc, Nc), sink-source order, complex
- LatticeStaggeredFermion.data: (2, Lt, Lz, Ly, Lx // 2, Nc), complex
- LatticeStaggeredPropagator.data: (2, Lt, Lz, Ly, Lx // 2, Nc, Nc), sink-source order, complex
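The leading `2` and the halved `Lx // 2` axis in the shapes above come from the even-odd split. The sketch below shows one common checkerboard indexing that produces such a layout. It assumes the parity of a site is `(x + y + z + t) % 2`; the function names are hypothetical, and PyQUDA's own conversion (lexico() and its inverse) may differ in detail.

```python
from itertools import product


def to_even_odd(x, y, z, t):
    """Map a site (x, y, z, t) to an index (parity, t, z, y, x // 2)."""
    parity = (x + y + z + t) % 2  # assumed parity convention
    return parity, t, z, y, x // 2


def to_lexico(parity, t, z, y, xh):
    """Invert to_even_odd and recover (x, y, z, t)."""
    # For fixed (t, z, y), sites of one parity sit at x = 2 * xh + offset,
    # where the offset is chosen so that x + y + z + t has the given parity.
    offset = (parity + t + z + y) % 2
    return 2 * xh + offset, y, z, t


# Round trip over a small [Lx, Ly, Lz, Lt] = [4, 4, 4, 8] lattice.
Lx, Ly, Lz, Lt = 4, 4, 4, 8
sites = list(product(range(Lx), range(Ly), range(Lz), range(Lt)))
indices = {to_even_odd(*site) for site in sites}
print(len(indices) == len(sites))  # every site gets a distinct index
print(all(to_lexico(*to_even_odd(*site)) == site for site in sites))
```

Each parity block holds exactly half of the sites, which is why the x axis shrinks to `Lx // 2`.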
Lattice fields support HDF5 I/O via h5py (requires h5py to be installed):
gauge.saveH5("gauge.h5")
gauge_loaded = LatticeGauge.loadH5("gauge.h5")
# Append to existing HDF5 file (e.g., accumulating configurations)
gauge.appendH5("trajectory.h5")
# Save in single precision for smaller file size
gauge.saveH5("gauge_fp32.h5", use_fp32=True)
The I/O functions are defined in the pyquda_utils.io module (from pyquda_utils import io).
For example, we can read a gauge configuration generated by Chroma in .lime format:
from pyquda_utils import io
gauge = io.readChromaQIOGauge("weak_field.lime")
The full list of supported I/O operations:
Chroma QIO (LIME): read only (write is not supported)
readChromaQIOGauge(filename, checksum=True, reunitarize_sigma=1e-6)
readChromaQIOPropagator(filename, checksum=True)
readChromaQIOStaggeredPropagator(filename, checksum=True)
ILDG: read only
readILDGGauge(filename, checksum=True, reunitarize_sigma=1e-6)
readILDGBinGauge(filename, dtype, latt_size)
MILC: read/write gauge, read propagator
readMILCGauge(filename, checksum=True, reunitarize_sigma=1e-6)
writeMILCGauge(filename, gauge)
readMILCQIOPropagator(filename)
readMILCQIOStaggeredPropagator(filename)
NERSC: read/write
readNERSCGauge(filename, checksum=True, plaquette=True, link_trace=True, reunitarize_sigma=1e-6)
writeNERSCGauge(filename, gauge, use_fp32=False, ensemble_id="PyQUDA", ensemble_label="", sequence_number=0)
OpenQCD: read/write
readOpenQCDGauge(filename, plaquette=True)
writeOpenQCDGauge(filename, gauge)
KYU: read/write
readKYUGauge(filename, latt_size)
writeKYUGauge(filename, gauge)
readKYUPropagator(filename, latt_size)
writeKYUPropagator(filename, propagator)
XQCD: read/write
readXQCDPropagator(filename, latt_size)
readXQCDStaggeredPropagator(filename, latt_size)
writeXQCDPropagator(filename, propagator)
writeXQCDStaggeredPropagator(filename, propagator)
NPY (NumPy): read/write
readNPYGauge(filename)
writeNPYGauge(filename, gauge)
readNPYPropagator(filename)
writeNPYPropagator(filename, propagator)
Gamma basis conversion
- rotateToDiracPauli(propagator): convert from DeGrand-Rossi to Dirac-Pauli basis
- rotateToDeGrandRossi(propagator): convert from Dirac-Pauli to DeGrand-Rossi basis
The following methods are bound to the LatticeGauge class:
Covariant operations:
- covDev(x, covdev_mu)
  - Applies the covariant derivative to x in direction covdev_mu. x should be a LatticeFermion. 0/1/2/3 represent +x/+y/+z/+t and 4/5/6/7 represent -x/-y/-z/-t. The covariant derivative is defined as $\psi'(x)=U_\mu(x)\psi(x+\hat{\mu})$.
- laplace(x, laplace3D)
  - Applies the Laplacian operator to x. laplace3D takes 3 or 4 to apply the Laplacian in the spatial directions or in all directions. x should be a LatticeStaggeredFermion. The Laplacian operator is defined as $\psi'(x)=\frac{1}{N_\mathrm{Lap}}\sum_\mu\psi(x)-\dfrac{1}{2}\left[U_\mu(x)\psi(x+\hat{\mu})+U_\mu^\dagger(x-\hat{\mu})\psi(x-\hat{\mu})\right]$.
- wuppertalSmear(x, n_steps, alpha)
  - Applies Wuppertal (Gaussian) smearing to x (LatticeFermion or LatticeStaggeredFermion). n_steps is the number of smearing steps and alpha is the smearing strength.
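The covariant shift in the covDev() definition above can be illustrated on a toy model. The sketch below uses a periodic 1D lattice with U(1) phases instead of SU(3) matrices, so it only shows the index structure of $\psi'(x)=U_\mu(x)\psi(x+\hat{\mu})$ and of the backward term $U_\mu^\dagger(x-\hat{\mu})\psi(x-\hat{\mu})$ from the Laplacian. The function names are hypothetical and this is not PyQUDA code.

```python
import cmath


def cov_shift_forward(links, psi):
    """psi'(x) = U(x) * psi(x + 1) on a periodic 1D lattice."""
    n = len(psi)
    return [links[x] * psi[(x + 1) % n] for x in range(n)]


def cov_shift_backward(links, psi):
    """psi'(x) = conj(U(x - 1)) * psi(x - 1), the backward counterpart."""
    n = len(psi)
    return [links[(x - 1) % n].conjugate() * psi[(x - 1) % n] for x in range(n)]


# Unit-modulus links (toy U(1) gauge field) and an arbitrary field.
links = [cmath.exp(0.3j * x) for x in range(4)]
psi = [1.0 + 0.0j, 2.0j, -1.0 + 0.0j, 0.5 + 0.0j]

# A forward shift followed by a backward shift returns the original field,
# because conj(U) * U = 1 for unit-modulus links.
round_trip = cov_shift_backward(links, cov_shift_forward(links, psi))
print(max(abs(a - b) for a, b in zip(round_trip, psi)) < 1e-12)
```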
Gauge field utilities:
- staggeredPhase(applied)
  - Applies or removes the staggered phase from the gauge field. applied is a boolean indicating the current state. The convention is controlled by LatticeGauge.pure_gauge.gauge_param.staggered_phase_type, which defaults to the MILC convention.
- projectSU3(tol)
  - Projects the gauge field onto SU(3). tol is the tolerance for deviation from SU(3). 2e-15 (about 10× fp64 epsilon) is a good choice.
- setAntiPeriodicT()
  - Applies anti-periodic boundary conditions in the temporal direction.
- setAnisotropy()
  - Applies the anisotropy factor to temporal links.
Path and loop operations:
- path(paths)
  - paths is a list of integer direction sequences defining gauge-transported paths.
- loop(loops, coeff)
  - loops is a list of length 4: [loops_x, loops_y, loops_z, loops_t]. All loops_* should have the same shape. coeff is a list of coefficients.
- loopTrace(loops)
  - Returns the traces of all loops as a numpy.ndarray of dtype complex128. Each element is $\sum_x\mathrm{Tr}\,W(x)$.
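A quick sanity check for a direction sequence is to compute its net displacement: a Wilson loop must return to its starting site. The sketch below assumes the same integer encoding that covDev() uses (0/1/2/3 for +x/+y/+z/+t and 4/5/6/7 for -x/-y/-z/-t). Whether path(), loop(), and loopTrace() use exactly this encoding is an assumption here, and the helper names are hypothetical.

```python
def displacement(path, nd=4):
    """Net displacement of a path given as integer direction codes.

    Codes 0..nd-1 step forward and nd..2*nd-1 step backward (assumed encoding).
    """
    shift = [0] * nd
    for step in path:
        if not 0 <= step < 2 * nd:
            raise ValueError(f"invalid direction code {step}")
        if step < nd:
            shift[step] += 1
        else:
            shift[step - nd] -= 1
    return shift


def is_closed(path, nd=4):
    """True if the path ends where it started, i.e. it is a loop."""
    return all(s == 0 for s in displacement(path, nd))


plaquette_xy = [0, 1, 4, 5]        # +x, +y, -x, -y
rectangle_xy = [0, 0, 1, 4, 4, 5]  # 2x1 rectangle in the x-y plane
print(is_closed(plaquette_xy), is_closed(rectangle_xy))
print(displacement([0, 3]))        # an open path: one step in x, one in t
```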
Gauge smearing:
- apeSmear(n_steps, alpha, dir_ignore)
  - APE smearing. alpha is the smearing strength. dir_ignore is the direction to exclude (-1 for none, 3 for spatial-only).
- apeSmearChroma(n_steps, factor, dir_ignore)
  - A variant of apeSmear() with Chroma's convention for the factor parameter.
- stoutSmear(n_steps, rho, dir_ignore)
  - Stout smearing. rho is the smearing strength.
- hypSmear(n_steps, alpha1, alpha2, alpha3, dir_ignore)
  - HYP smearing. alpha1/alpha2/alpha3 are the smearing strengths on levels 3/2/1.
Gradient flow:
- wilsonFlow(n_steps, epsilon)
  - Wilson flow. Returns the energy (all, spatial, temporal) for each step.
- wilsonFlowScale(max_steps, epsilon)
  - Returns $t_0$ and $w_0$ with up to max_steps Wilson flow steps.
- wilsonFlowChroma(n_steps, time)
  - Wilson flow using total time instead of per-step epsilon.
- symanzikFlow(n_steps, epsilon)
  - Symanzik flow. Returns the energy (all, spatial, temporal) for each step.
- symanzikFlowScale(max_steps, epsilon)
  - Returns $t_0$ and $w_0$ with up to max_steps Symanzik flow steps.
- symanzikFlowChroma(n_steps, time)
  - Symanzik flow using total time instead of per-step epsilon.
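For orientation, the scales $t_0$ and $w_0$ are conventionally defined by $t^2 E(t)\big|_{t=t_0}=0.3$ and $t\,\frac{d}{dt}\left[t^2 E(t)\right]\big|_{t=w_0^2}=0.3$. The sketch below extracts both from a list of per-step energy densities using linear interpolation and central differences. It assumes the i-th entry corresponds to flow time (i + 1) * epsilon; the function names are hypothetical, and the *FlowScale() methods may use a different interpolation or stopping rule.

```python
def first_crossing(ts, values, target):
    """Linearly interpolate the first t where values rises through target."""
    for i in range(1, len(ts)):
        lo, hi = values[i - 1], values[i]
        if lo < target <= hi:
            frac = (target - lo) / (hi - lo)
            return ts[i - 1] + frac * (ts[i] - ts[i - 1])
    return None  # target not reached within the available steps


def flow_scales(energies, epsilon, target=0.3):
    """Return (t0, w0) from energy densities measured after each flow step."""
    ts = [epsilon * (i + 1) for i in range(len(energies))]
    t2e = [t * t * e for t, e in zip(ts, energies)]
    t0 = first_crossing(ts, t2e, target)
    # W(t) = t * d/dt [t^2 E(t)], central differences at interior points.
    w = [
        ts[i] * (t2e[i + 1] - t2e[i - 1]) / (ts[i + 1] - ts[i - 1])
        for i in range(1, len(ts) - 1)
    ]
    w0_squared = first_crossing(ts[1:-1], w, target)
    w0 = None if w0_squared is None else w0_squared ** 0.5
    return t0, w0


# Synthetic data with t^2 E(t) = 0.1 * t, so both conditions are met at t = 3.
epsilon = 0.5
energies = [0.1 / (epsilon * (i + 1)) for i in range(10)]
t0, w0 = flow_scales(energies, epsilon)
print(t0, w0)
```

With this synthetic input, t0 comes out close to 3 and w0 close to the square root of 3, which is a convenient check of the bookkeeping.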
Observables:
- plaquette(): returns (all, spatial, temporal) plaquette
- polyakovLoop(): returns (real, imaginary) Polyakov loop
- energy(): returns (all, spatial, temporal) energy density
- qcharge(): returns the topological charge
- qchargeDensity(): returns the topological charge density array
Random generation:
- gauss(seed, sigma)
  - Fills the gauge field with random SU(3) matrices: $U = \exp(\sigma H)$, where $H$ is Gaussian-distributed su(3). sigma=0 gives free field, sigma=1 gives maximum disorder.
Gauge fixing:
- fixingOVR(gauge_dir, Nsteps, verbose_interval, relax_boost, tolerance, reunit_interval, stopWtheta)
  - Gauge fixing with over-relaxation (supports multi-GPU).
    - gauge_dir: 3 for Coulomb, 4 for Landau
    - relax_boost: typical value 1.5 or 1.7
    - stopWtheta: 0 for MILC criterion, 1 for theta
- fixingFFT(gauge_dir, Nsteps, verbose_interval, alpha, autotune, tolerance, stopWtheta)
  - Gauge fixing with FFT (single GPU only).
    - gauge_dir: 3 for Coulomb, 4 for Landau
    - alpha: typical value 0.08
    - autotune: 1 to auto-tune alpha
High-level helpers live in pyquda_utils.core, with source construction in pyquda_utils.source. Import with from pyquda_utils import core, source.
pyquda_utils.core re-exports the most common runtime and field symbols, so many workflows can stay inside one namespace:
from pyquda_utils import core
core.init(resource_path=".cache/quda")
latt_info = core.LatticeInfo([4, 4, 4, 8], t_boundary=-1)
logger = core.getLogger()
The following groups are re-exported from pyquda / pyquda.field:
- Runtime helpers: init, getMPIComm, getMPISize, getMPIRank, getGridSize, getGridCoord, getGridMap, getArrayBackend, getArrayDevice, getLogger, setLoggerLevel
- Common field constants: Ns, Nc, Nd, X, Y, Z, T
- Common field classes: LatticeInfo, LatticeGauge, LatticePropagator, LatticeStaggeredPropagator, LatticeFermion, LatticeStaggeredFermion, and their Multi* variants
The recommended constructors now take a LatticeInfo directly. This replaces the older getDslash() / getStaggeredDslash() style that accepted a local lattice size.
from pyquda_utils import core
latt_info = core.LatticeInfo([4, 4, 4, 8], t_boundary=-1, anisotropy=1.0)
wilson = core.getWilson(latt_info, mass=0.1, tol=1e-12, maxiter=1000)
clover = core.getClover(
latt_info,
mass=0.1,
tol=1e-12,
maxiter=1000,
xi_0=1.0,
clover_csw_t=1.2,
clover_csw_r=1.2,
)
hisq = core.getHISQ(latt_info, mass=0.01, tol=1e-12, maxiter=1000, naik_epsilon=0.0)
Available constructors:
getDirac(latt_info, mass, tol, maxiter, xi_0=1.0, clover_coeff_t=0.0, clover_coeff_r=1.0, multigrid=None)
getStaggeredDirac(latt_info, mass, tol, maxiter, tadpole_coeff=1.0, naik_epsilon=0.0)
getWilson(latt_info, mass, tol, maxiter, multigrid=None)
getClover(latt_info, mass, tol, maxiter, xi_0=1.0, clover_csw_t=..., clover_csw_r=..., multigrid=None)
getStaggered(latt_info, mass, tol, maxiter, tadpole_coeff=1.0)
getHISQ(latt_info, mass, tol, maxiter, naik_epsilon=0.0, multigrid=None)
multigrid accepts either a geometric block-size list such as [[2, 2, 2, 2], [4, 4, 4, 4]] or a Multigrid object.
Source construction helpers live in pyquda_utils.source:
from pyquda_utils import source
point = source.source(latt_info, "point", [0, 0, 0, 0], spin=0, color=0)
wall = source.source(latt_info, "wall", 4, spin=0, color=0)
volume = source.source(latt_info, "volume", None, spin=0, color=0)
prop_point = source.source12(latt_info, "point", [0, 0, 0, 0])
stag_point = source.source3(latt_info, "point", [0, 0, 0, 0])
Current source types for source.source():
- "point": t_srce=[x, y, z, t]
- "wall": t_srce=t
- "volume": t_srce=None
- "momentum": wall or volume source multiplied by source_phase
- "colorvector": wall or volume color-vector source from source_phase
Convenience constructors:
- multiFermion() / multiStaggeredFermion() build multi-RHS fermion collections
- propagator() / staggeredPropagator() build propagator-shaped sources
- gaussianSmear() applies Wuppertal smearing to fermions, propagators, and Multi* fields
- sequential(), sequential12(), and sequential3() keep only one sink time-slice
Passing a raw lattice-size list to source() / source12() / source3() is still accepted for compatibility, but it is deprecated. Prefer passing LatticeInfo.
pyquda_utils.phase defines reusable phase fields:
from pyquda_utils.phase import MomentumPhase, GridPhase
from pyquda_utils import source
momentum_phase = MomentumPhase(latt_info).getPhase([1, 2, 3])
grid_phase = GridPhase(latt_info, stride=[2, 2, 2, 2]).getPhase([1, 1, 1, 1])
momentum_source = source.source12(latt_info, "wall", 4, source_phase=momentum_phase)
grid_source = source.source12(latt_info, "volume", None, source_phase=momentum_phase * grid_phase)
The core.invert*() family is now the main high-level inversion interface:
from pyquda_utils import core, source
dirac = core.getWilson(latt_info, mass=0.1, tol=1e-12, maxiter=1000)
dirac.loadGauge(gauge)
propag = core.invert(dirac, source_type="point", t_srce=[0, 0, 0, 0], mrhs=4)
seq_source = source.sequential12(propag, t_srce=4)
seq_propag = core.invertPropagator(dirac, seq_source, mrhs=4)
Available helpers:
invert(dirac, source_type, t_srce, source_phase=None, mrhs=1, restart=0) -> LatticePropagator
invertEigenvector(dirac, t_srce, source_propag, mrhs=1, restart=0) -> MultiLatticeFermion
invertSequential(dirac, source_propag, t_srce, mrhs=1, restart=0) -> LatticePropagator
invertPropagator(dirac, source_propag, mrhs=1, restart=0) -> LatticePropagator
invertStaggered(dirac, source_type, t_srce, source_phase=None, mrhs=1, restart=0) -> LatticeStaggeredPropagator
invertStaggeredSequential(dirac, source_propag, t_srce, mrhs=1, restart=0) -> LatticeStaggeredPropagator
invertStaggeredPropagator(dirac, source_propag, mrhs=1, restart=0) -> LatticeStaggeredPropagator
mrhs controls multi-right-hand-side batching, and restart applies recursive residual correction through invertMultiSrcRestart().
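The two ideas behind these parameters can be sketched on a toy problem. The example below splits a list of sources into batches of mrhs, and applies residual correction by re-solving for the residual and adding the correction to the solution. It is a generic illustration with a deliberately inexact Jacobi solver on a 2x2 system; all function names are hypothetical, and the actual behavior of invertMultiSrcRestart() may differ.

```python
def batches(sources, mrhs):
    """Split the sources into consecutive batches of at most mrhs entries."""
    return [sources[i:i + mrhs] for i in range(0, len(sources), mrhs)]


def matvec(a, x):
    """Dense matrix-vector product."""
    return [sum(aij * xj for aij, xj in zip(row, x)) for row in a]


def sloppy_solve(a, b, sweeps=3):
    """Stand-in for an inexact solver: a few Jacobi sweeps starting from zero."""
    n = len(b)
    x = [0.0] * n
    for _ in range(sweeps):
        x = [
            (b[i] - sum(a[i][j] * x[j] for j in range(n) if j != i)) / a[i][i]
            for i in range(n)
        ]
    return x


def solve_with_restart(a, b, restart):
    """Solve once, then correct `restart` times using the residual as source."""
    x = sloppy_solve(a, b)
    for _ in range(restart):
        residual = [bi - axi for bi, axi in zip(b, matvec(a, x))]
        correction = sloppy_solve(a, residual)
        x = [xi + ci for xi, ci in zip(x, correction)]
    return x


def residual_norm(a, b, x):
    """Maximum absolute component of b - A x."""
    return max(abs(bi - axi) for bi, axi in zip(b, matvec(a, x)))


a = [[4.0, 1.0], [1.0, 3.0]]
b = [1.0, 2.0]
print(len(batches(list(range(12)), 4)))  # 12 spin-color sources, mrhs=4
print(residual_norm(a, b, solve_with_restart(a, b, restart=0)))
print(residual_norm(a, b, solve_with_restart(a, b, restart=3)))
```

Each restart reuses the same inexact solver, so the residual shrinks by roughly the same factor per pass.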
pyquda_utils.core also contains helpers for collecting lattice-shaped NumPy arrays over MPI:
gatherLattice2(data, tzyx, reduce_op="sum", root=0)
scatterLattice(data_all, tzyx, root=0)
gatherScatterLattice(data, tzyx, reduce_op="sum", root=0)
gatherLattice(data, axes, reduce_op="sum", root=0)
gatherLattice2() / gatherScatterLattice() support sum, mean, prod, max, and min.
The following names are still re-exported from pyquda_utils.core for backward compatibility, but new code should avoid them:
| Deprecated name | Use instead |
|---|---|
| getDslash() | getDirac(), getWilson(), or getClover() |
| getStaggeredDslash() | getStaggeredDirac(), getStaggered(), or getHISQ() |
| invert12() | invert() or invertPropagator() |
| smear() | LatticeGauge.stoutSmear() |
| smear4() | LatticeGauge.hypSmear() or LatticeGauge.stoutSmear() |
| cb2() | evenodd() |
The pycontract plugin ships a high-level Python wrapper around the generated Cython bindings. The canonical usage is reflected in tests/test_pycontract.py.
Initialize PyQUDA first, then initialize the plugin:
from pyquda_utils import core
from pyquda_plugins import pycontract
core.init(resource_path=".cache/quda")
pycontract.init()
Unlike the low-level generated binding, pycontract.init() does not take a device id. It reads the current device from the initialized PyQUDA runtime.
Typical usage:
from pyquda_utils import core, gamma, io
from pyquda_plugins import pycontract
propag = io.readQIOPropagator("pt_prop_0")
propag.toDevice()
gamma_2 = gamma.Gamma(2)
gamma_4 = gamma.Gamma(8)
gamma_5 = gamma.Gamma(15)
C = gamma_2 @ gamma_4
CG_A = C @ gamma_4 @ gamma_5
CG_B = C @ gamma_5
Pp = (gamma.Gamma(0) + gamma_4) / 2
meson = pycontract.mesonTwoPoint(propag, propag, CG_A, CG_B)
all_sink = pycontract.mesonAllSinkTwoPoint(propag, propag, CG_B)
baryon = pycontract.baryonTwoPoint(
propag,
propag,
propag,
pycontract.BaryonContractType.IK_JL_NM,
CG_A,
CG_B,
Pp,
)
Current high-level wrapper functions:
mesonTwoPoint(propag_a, propag_b, gamma_ab, gamma_dc) -> LatticeComplex
mesonAllSinkTwoPoint(propag_a, propag_b, gamma_dc) -> MultiLatticeComplex
mesonAllSourceTwoPoint(propag_a, propag_b, gamma_ab) -> MultiLatticeComplex
baryonDiquark(propag_i, propag_j, gamma_ij, gamma_kl) -> LatticePropagator
baryonTwoPoint(propag_i, propag_j, propag_n, contract_type, gamma_ij, gamma_kl, gamma_mn) -> LatticeComplex
baryonGeneralTwoPoint(propag_i, propag_j, propag_n, contract_type, gamma_ij, gamma_kl, project_mn) -> LatticeComplex
baryonSequentialTwoPoint(propag_i, propag_j, propag_n, contract_type, sequential_type, gamma_ij, gamma_kl, gamma_mn) -> LatticePropagator
baryonTwoPoint_v2(propag_i, propag_j, propag_m, contract_type, gamma_ij, gamma_kl, gamma_mn) -> LatticeComplex
gamma_mn can be either a Gamma or a Projector. The wrapper expands Projector values into the corresponding linear combination of Gamma contractions automatically.