Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .conda/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ requirements:
- scipy
- setuptools
- wheel
- spm-lab::pylibsparseir >=0.7.0,<0.8.0
- spm-lab::pylibsparseir >=0.8.0,<0.9.0
run:
- numpy
- python
- scipy
- spm-lab::pylibsparseir >=0.7.0,<0.8.0
- spm-lab::pylibsparseir >=0.8.0,<0.9.0

#test:
# imports:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ requires-python = ">=3.10"
dependencies = [
"numpy",
"scipy",
"pylibsparseir>=0.7.0,<0.8.0",
"pylibsparseir>=0.8.0,<0.9.0",
]
authors = [
{name = "SpM-lab"}
Expand Down
49 changes: 49 additions & 0 deletions src/sparse_ir/poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,23 @@ def funcs_get_slice(funcs_ptr, indices):
raise RuntimeError(f"Failed to get basis function {indices}: {status.value}")
return FunctionSet(funcs)

def funcs_clone(funcs_ptr):
"""Clone a function set."""
cloned = _lib.spir_funcs_clone(funcs_ptr)
if not cloned:
raise RuntimeError("Failed to clone function set")
return FunctionSet(cloned)

def funcs_deriv(funcs_ptr, n):
"""Compute the n-th derivative of a function set."""
status = c_int()
deriv_funcs = _lib.spir_funcs_deriv(funcs_ptr, n, status)
if status.value != 0:
raise RuntimeError(f"Failed to compute derivative of order {n}: {status.value}")
if not deriv_funcs:
raise RuntimeError(f"Failed to compute derivative of order {n}")
return FunctionSet(deriv_funcs)

def funcs_ft_get_slice(funcs_ptr, indices):
status = c_int()
indices = np.asarray(indices, dtype=np.int32)
Expand Down Expand Up @@ -134,6 +151,25 @@ def __getitem__(self, index):

return funcs_get_slice(self._ptr, indices)

def deriv(self, n=1):
"""Compute the n-th derivative of the basis functions.

Args:
n (int): Order of the derivative (default: 1)

Returns:
FunctionSet: New function set representing the n-th derivative
"""
if self._released:
raise RuntimeError("Function set has been released")
if n < 0:
raise ValueError("Derivative order must be non-negative")
if n == 0:
# Return a clone
return funcs_clone(self._ptr)

return funcs_deriv(self._ptr, n)

def release(self):
"""Manually release the function set."""
if not self._released and self._ptr:
Expand Down Expand Up @@ -374,6 +410,19 @@ def __getitem__(self, index):
self._xmax, self._period,
self._default_overlap_range)

def deriv(self, n=1):
"""Compute the n-th derivative of the basis functions.

Args:
n (int): Order of the derivative (default: 1)

Returns:
PiecewiseLegendrePolyVector: New polynomial vector representing the n-th derivative
"""
deriv_funcs = self._funcs.deriv(n)
return PiecewiseLegendrePolyVector(deriv_funcs, self._xmin, self._xmax,
self._period, self._default_overlap_range)


def overlap(self, f, xmin: float = None, xmax: float = None, *, rtol=2.3e-16, return_error=False, points=None):
r"""Evaluate overlap integral of this polynomial with function ``f``.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,4 @@ def test_invalid_parameters(self):
sve_result_new(kernel, 1e-20) # Very small epsilon
except RuntimeError:
# This is acceptable - very small epsilon might fail
pass
pass