diff --git a/include/tensorwrapper/types/floating_point.hpp b/include/tensorwrapper/types/floating_point.hpp index 7e2de193..4c673856 100644 --- a/include/tensorwrapper/types/floating_point.hpp +++ b/include/tensorwrapper/types/floating_point.hpp @@ -92,10 +92,10 @@ T pow(T value, double pow) { #define TW_APPLY_FLOATING_POINT_TYPES(MACRO_IN) \ MACRO_IN(float); \ MACRO_IN(double); \ - MACRO_IN(types::ufloat); \ - MACRO_IN(types::udouble); \ - MACRO_IN(types::ifloat); \ - MACRO_IN(types::idouble); + MACRO_IN(tensorwrapper::types::ufloat); \ + MACRO_IN(tensorwrapper::types::udouble); \ + MACRO_IN(tensorwrapper::types::ifloat); \ + MACRO_IN(tensorwrapper::types::idouble); } // namespace tensorwrapper::types WTF_REGISTER_FP_TYPE(tensorwrapper::types::ufloat); @@ -147,3 +147,11 @@ T pow(T value, double pow) { } // namespace tensorwrapper::types #endif + +#define DECLARE_WTF_CONTIGUOUS(TYPE) \ + extern template class wtf::buffer::detail_::ContiguousModel; \ + extern template class wtf::buffer::detail_::ContiguousViewModel; + +TW_APPLY_FLOATING_POINT_TYPES(DECLARE_WTF_CONTIGUOUS); + +#undef DECLARE_WTF_CONTIGUOUS diff --git a/src/python/module.cpp b/src/python/module.cpp index cc45470d..98c028c9 100644 --- a/src/python/module.cpp +++ b/src/python/module.cpp @@ -16,7 +16,6 @@ #include "tensor/export_tensor.hpp" #include -#include namespace tensorwrapper { diff --git a/src/python/tensor/export_tensor.cpp b/src/python/tensor/export_tensor.cpp index 31877149..8e2d02da 100644 --- a/src/python/tensor/export_tensor.cpp +++ b/src/python/tensor/export_tensor.cpp @@ -17,7 +17,9 @@ #include "export_tensor.hpp" #include #include -#include +#include +#include +#include namespace tensorwrapper { namespace { diff --git a/src/tensorwrapper/types/floating_point.cpp b/src/tensorwrapper/types/floating_point.cpp new file mode 100644 index 00000000..5aec097f --- /dev/null +++ b/src/tensorwrapper/types/floating_point.cpp @@ -0,0 +1,25 @@ +/* + * Copyright 2026 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define DEFINE_WTF_CONTIGUOUS(TYPE) \ + template class wtf::buffer::detail_::ContiguousModel; \ + template class wtf::buffer::detail_::ContiguousViewModel; + +TW_APPLY_FLOATING_POINT_TYPES(DEFINE_WTF_CONTIGUOUS); + +#undef DEFINE_WTF_CONTIGUOUS diff --git a/tests/python/unit_tests/tensor/test_tensor.py b/tests/python/unit_tests/tensor/test_tensor.py index 88e5768a..ca99ba21 100644 --- a/tests/python/unit_tests/tensor/test_tensor.py +++ b/tests/python/unit_tests/tensor/test_tensor.py @@ -32,12 +32,9 @@ def test_rank(self): self.assertEqual(self.matrix_from_cpp.rank(), 2) def test_equality(self): - pass - - # XXX: In the CI, this breaks with Clang and a non-sensical stack trace - # self.assertTrue(self.scalar == self.scalar_from_cpp) - # self.assertTrue(self.vector == self.vector_from_cpp) - # self.assertTrue(self.matrix == self.matrix_from_cpp) + self.assertTrue(self.scalar == self.scalar_from_cpp) + self.assertTrue(self.vector == self.vector_from_cpp) + self.assertTrue(self.matrix == self.matrix_from_cpp) def test_inequality(self): self.assertTrue(self.defaulted != self.scalar)