diff --git a/src/arraymancer/linear_algebra/special_matrices.nim b/src/arraymancer/linear_algebra/special_matrices.nim index a528023c..adf6a80e 100644 --- a/src/arraymancer/linear_algebra/special_matrices.nim +++ b/src/arraymancer/linear_algebra/special_matrices.nim @@ -4,7 +4,7 @@ import ../tensor import ./helpers/triangular -import std / [sequtils, bitops] +import std / [sequtils, bitops, strformat] proc hilbert*(n: int, T: typedesc[SomeFloat]): Tensor[T] = ## Generates an Hilbert matrix of shape [N, N] @@ -129,6 +129,7 @@ proc diagonal*[T](a: Tensor[T], k = 0, anti = false): Tensor[T] {.noInit.} = ## - anti: If true, get the k-th "anti-diagonal" instead of the k-th regular diagonal. ## Result: ## - A copy of the diagonal elements as a rank-1 tensor + bind `&` assert a.rank == 2, "diagonal() only works on matrices" assert k < a.shape[0], &"Diagonal index ({k=}) exceeds the output matrix height ({a.shape[0]})" assert k < a.shape[1], &"Diagonal index ({k=}) exceeds the output matrix width ({a.shape[1]})" @@ -167,6 +168,7 @@ proc set_diagonal*[T](a: var Tensor[T], d: Tensor[T], k = 0, anti = false) = ## - k: The index k of the diagonal that will be changed. The default is 0 (i.e. the main diagonal). ## Use k>0 for diagonals above the main diagonal, and k<0 for diagonals below the main diagonal. ## - anti: If true, set the k-th "anti-diagonal" instead of the k-th regular diagonal. + bind `&` assert a.rank == 2, "set_diagonal() only works on matrices" assert d.rank == 1, "The diagonal passed to set_diagonal() must be a rank-1 tensor" assert k < a.shape[0], &"Diagonal index ({k=}) exceeds input matrix height ({a.shape[0]})" @@ -259,6 +261,7 @@ proc tri*[T](shape: Metadata, k: static int = 0, upper: static bool = false): Te ## diagonal. The default is false. ## Result: ## - The constructed, rank-2 triangular tensor. + bind `&` assert shape.len == 2, &"tri() requires a rank-2 shape as it's input but a shape of rank {shape.len} was passed" assert k < shape[0], &"tri() received a diagonal index ({k=}) which exceeds the output matrix height ({shape[0]})" assert k < shape[1], &"tri() received a diagonal index ({k=}) which exceeds the output matrix width ({shape[1]})" diff --git a/tests/linear_algebra/test_linear_algebra.nim b/tests/linear_algebra/test_linear_algebra.nim index 5277debc..60c656c4 100644 --- a/tests/linear_algebra/test_linear_algebra.nim +++ b/tests/linear_algebra/test_linear_algebra.nim @@ -2,7 +2,7 @@ # Distributed under the Apache v2 License (license terms are at http://www.apache.org/licenses/LICENSE-2.0). import ../../src/arraymancer -import std / [unittest, strformat] +import std / [unittest] proc main() = suite "Linear algebra": diff --git a/tests/linear_algebra/test_special_matrices.nim b/tests/linear_algebra/test_special_matrices.nim index 78514d10..d1baffbb 100644 --- a/tests/linear_algebra/test_special_matrices.nim +++ b/tests/linear_algebra/test_special_matrices.nim @@ -2,7 +2,7 @@ # Distributed under the Apache v2 License (license terms are at http://www.apache.org/licenses/LICENSE-2.0). import ../../src/arraymancer -import std / [unittest, strformat] +import std / [unittest] proc main() = suite "Diagonals":