Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/pylibsparseir/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
SPIR_INTERNAL_ERROR = -7

# Statistics type constants
STATISTICS_FERMIONIC = 1
STATISTICS_BOSONIC = 0
SPIR_STATISTICS_FERMIONIC = 1
SPIR_STATISTICS_BOSONIC = 0

# Order type constants
ORDER_COLUMN_MAJOR = 1
ORDER_ROW_MAJOR = 0
SPIR_ORDER_COLUMN_MAJOR = 1
SPIR_ORDER_ROW_MAJOR = 0

# Make sure these are available at module level
SPIR_ORDER_ROW_MAJOR = 0
Expand Down
95 changes: 94 additions & 1 deletion src/pylibsparseir/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np

from .ctypes_wrapper import spir_kernel, spir_sve_result, spir_basis, spir_funcs, spir_sampling
from pylibsparseir.constants import COMPUTATION_SUCCESS, ORDER_ROW_MAJOR, SPIR_TWORK_FLOAT64, SPIR_TWORK_FLOAT64X2
from pylibsparseir.constants import COMPUTATION_SUCCESS, SPIR_ORDER_ROW_MAJOR, SPIR_ORDER_COLUMN_MAJOR, SPIR_TWORK_FLOAT64, SPIR_TWORK_FLOAT64X2, SPIR_STATISTICS_FERMIONIC, SPIR_STATISTICS_BOSONIC

def _find_library():
"""Find the SparseIR shared library."""
Expand Down Expand Up @@ -138,6 +138,9 @@ def _setup_prototypes():
_lib.spir_basis_get_default_taus.argtypes = [spir_basis, POINTER(c_double)]
_lib.spir_basis_get_default_taus.restype = c_int

_lib.spir_basis_get_default_taus_ext.argtypes = [spir_basis, c_int, POINTER(c_double), POINTER(c_int)]
_lib.spir_basis_get_default_taus_ext.restype = c_int

_lib.spir_basis_get_n_default_ws.argtypes = [spir_basis, POINTER(c_int)]
_lib.spir_basis_get_n_default_ws.restype = c_int

Expand All @@ -147,16 +150,37 @@ def _setup_prototypes():
_lib.spir_basis_get_n_default_matsus.argtypes = [spir_basis, c_bool, POINTER(c_int)]
_lib.spir_basis_get_n_default_matsus.restype = c_int

_lib.spir_basis_get_n_default_matsus_ext.argtypes = [spir_basis, c_bool, c_int, POINTER(c_int)]
_lib.spir_basis_get_n_default_matsus_ext.restype = c_int

_lib.spir_basis_get_default_matsus.argtypes = [spir_basis, c_bool, POINTER(c_int64)]
_lib.spir_basis_get_default_matsus.restype = c_int

_lib.spir_basis_get_default_matsus_ext.argtypes = [spir_basis, c_bool, c_int, POINTER(c_int64), POINTER(c_int)]
_lib.spir_basis_get_default_matsus_ext.restype = c_int

# Sampling objects
_lib.spir_tau_sampling_new.argtypes = [spir_basis, c_int, POINTER(c_double), POINTER(c_int)]
_lib.spir_tau_sampling_new.restype = spir_sampling

_lib.spir_tau_sampling_new_with_matrix.argtypes = [c_int, c_int, c_int, c_int, POINTER(c_double), POINTER(c_double), POINTER(c_int)]
_lib.spir_tau_sampling_new_with_matrix.restype = spir_sampling

_lib.spir_matsu_sampling_new.argtypes = [spir_basis, c_bool, c_int, POINTER(c_int64), POINTER(c_int)]
_lib.spir_matsu_sampling_new.restype = spir_sampling

_lib.spir_matsu_sampling_new_with_matrix.argtypes = [
c_int, # order
c_int, # statistics
c_int, # basis_size
c_bool, # positive_only
c_int, # num_points
POINTER(c_int64), # points
POINTER(c_double_complex), # matrix
POINTER(c_int) # status
]
_lib.spir_matsu_sampling_new_with_matrix.restype = spir_sampling

# Sampling operations
_lib.spir_sampling_eval_dd.argtypes = [
spir_sampling, c_int, c_int, POINTER(c_int), c_int,
Expand Down Expand Up @@ -445,6 +469,15 @@ def basis_get_default_tau_sampling_points(basis):

return points

def basis_get_default_tau_sampling_points_ext(basis, n_points):
"""Get default tau sampling points for a basis."""
points = np.zeros(n_points, dtype=np.float64)
n_points_returned = c_int()
status = _lib.spir_basis_get_default_taus_ext(basis, n_points, points.ctypes.data_as(POINTER(c_double)), byref(n_points_returned))
if status != COMPUTATION_SUCCESS:
raise RuntimeError(f"Failed to get default tau points: {status}")
return points

def basis_get_default_omega_sampling_points(basis):
"""Get default omega (real frequency) sampling points for a basis."""
# Get number of points
Expand Down Expand Up @@ -477,6 +510,22 @@ def basis_get_default_matsubara_sampling_points(basis, positive_only=False):

return points

def basis_get_n_default_matsus_ext(basis, n_points, positive_only):
"""Get the number of default Matsubara sampling points for a basis."""
n_points_returned = c_int()
status = _lib.spir_basis_get_n_default_matsus_ext(basis, c_bool(positive_only), n_points, byref(n_points_returned))
if status != COMPUTATION_SUCCESS:
raise RuntimeError(f"Failed to get number of default Matsubara points: {status}")
return n_points_returned.value

def basis_get_default_matsus_ext(basis, positive_only, points):
n_points = len(points)
n_points_returned = c_int()
status = _lib.spir_basis_get_default_matsus_ext(basis, c_bool(positive_only), n_points, points.ctypes.data_as(POINTER(c_int64)), byref(n_points_returned))
if status != COMPUTATION_SUCCESS:
raise RuntimeError(f"Failed to get default Matsubara points: {status}")
return points

def tau_sampling_new(basis, sampling_points=None):
"""Create a new tau sampling object."""
if sampling_points is None:
Expand All @@ -496,6 +545,32 @@ def tau_sampling_new(basis, sampling_points=None):

return sampling

def _statistics_to_c(statistics):
"""Convert statistics to c type."""
if statistics == "F":
return SPIR_STATISTICS_FERMIONIC
elif statistics == "B":
return SPIR_STATISTICS_BOSONIC
else:
raise ValueError(f"Invalid statistics: {statistics}")

def tau_sampling_new_with_matrix(basis, statistics, sampling_points, matrix):
"""Create a new tau sampling object with a matrix."""
status = c_int()
sampling = _lib.spir_tau_sampling_new_with_matrix(
SPIR_ORDER_ROW_MAJOR,
_statistics_to_c(statistics),
basis.size,
sampling_points.size,
sampling_points.ctypes.data_as(POINTER(c_double)),
matrix.ctypes.data_as(POINTER(c_double)),
byref(status)
)
if status.value != COMPUTATION_SUCCESS:
raise RuntimeError(f"Failed to create tau sampling: {status.value}")

return sampling

def matsubara_sampling_new(basis, positive_only=False, sampling_points=None):
"""Create a new Matsubara sampling object."""
if sampling_points is None:
Expand All @@ -514,3 +589,21 @@ def matsubara_sampling_new(basis, positive_only=False, sampling_points=None):
raise RuntimeError(f"Failed to create Matsubara sampling: {status.value}")

return sampling

def matsubara_sampling_new_with_matrix(statistics, basis_size, positive_only, sampling_points, matrix):
"""Create a new Matsubara sampling object with a matrix."""
status = c_int()
sampling = _lib.spir_matsu_sampling_new_with_matrix(
SPIR_ORDER_ROW_MAJOR, # order
_statistics_to_c(statistics), # statistics
c_int(basis_size), # basis_size
c_bool(positive_only), # positive_only
c_int(len(sampling_points)), # num_points
sampling_points.ctypes.data_as(POINTER(c_int64)), # points
matrix.ctypes.data_as(POINTER(c_double_complex)), # matrix
byref(status) # status
)
if status.value != COMPUTATION_SUCCESS:
raise RuntimeError(f"Failed to create Matsubara sampling: {status.value}")

return sampling
88 changes: 88 additions & 0 deletions src/sparse_ir/_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (C) 2020-2022 Markus Wallerberger, Hiroshi Shinaoka, and others
# SPDX-License-Identifier: MIT
import functools
import numpy as np


def ravel_argument(last_dim=False):
"""Wrap function operating on 1-D numpy array to allow arbitrary shapes.

This decorator allows to write functions which only need to operate over
one-dimensional (ravelled) arrays. This often simplifies the "shape logic"
of the computation.
"""
return lambda fn: RavelArgumentDecorator(fn, last_dim)


class RavelArgumentDecorator(object):
def __init__(self, inner, last_dim=False):
self.instance = None
self.inner = inner
self.last_dim = last_dim
functools.update_wrapper(self, inner)

def __get__(self, instance, _owner=None):
self.instance = instance
return self

def __call__(self, x):
x = np.asarray(x)
if self.instance is None:
res = self.inner(x.ravel())
else:
res = self.inner(self.instance, x.ravel())
if self.last_dim:
return res.reshape(res.shape[:-1] + x.shape)
else:
return res.reshape(x.shape + res.shape[1:])


def check_reduced_matsubara(n, zeta=None):
"""Checks that ``n`` is a reduced Matsubara frequency.

Check that the argument is a reduced Matsubara frequency, which is an
integer obtained by scaling the freqency `w[n]` as follows::

beta / np.pi * w[n] == 2 * n + zeta

Note that this means that instead of a fermionic frequency (``zeta == 1``),
we expect an odd integer, while for a bosonic frequency (``zeta == 0``),
we expect an even one. If ``zeta`` is omitted, any one is fine.
"""
n = np.asarray(n)
if not np.issubdtype(n.dtype, np.integer):
nfloat = n
n = nfloat.astype(int)
if not (n == nfloat).all():
raise ValueError("reduced frequency n must be integer")
if zeta is not None:
if not (n & 1 == zeta).all():
raise ValueError("n have wrong parity")
return n


def check_range(x, xmin, xmax):
"""Checks each element is in range [xmin, xmax]"""
x = np.asarray(x)
if not (x >= xmin).all():
raise ValueError(f"Some x violate lower bound {xmin}")
if not (x <= xmax).all():
raise ValueError(f"Some x violate upper bound {xmax}")
return x


def check_svd_result(svd_result, matrix_shape=None):
"""Checks that argument is a valid SVD triple (u, s, vH)"""
u, s, vH = map(np.asarray, svd_result)
m_u, k_u = u.shape
k_s, = s.shape
k_v, n_v = vH.shape
if k_u != k_s or k_s != k_v:
raise ValueError("shape mismatch between SVD elements:"
f"({m_u}, {k_u}) x ({k_s}) x ({k_v}, {n_v})")
if matrix_shape is not None:
m, n = matrix_shape
if m_u != m or n_v != n:
raise ValueError(f"shape mismatch between SVD ({m_u}, {n_v}) "
f"and matrix ({m}, {n})")
return u, s, vH
Loading
Loading