diff --git a/.gitignore b/.gitignore index 15e3b61..c6602b1 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ __pycache__/ .Python .venv/ .tox/ +.ruff_cache/ env/ build/ develop-eggs/ diff --git a/bin/target_driver.sh b/bin/target_driver.sh index 7695bd2..7d1001a 100755 --- a/bin/target_driver.sh +++ b/bin/target_driver.sh @@ -12,6 +12,7 @@ git checkout "$version" git pull origin "$version" cd .. cp driver/tests/unit/common/codec/packstream/v1/test_packstream.py tests/codec/packstream/v1/from_driver/test_packstream.py +cp driver/tests/unit/common/codec/packstream/v2/test_packstream.py tests/codec/packstream/v2/from_driver/test_packstream.py cp driver/tests/unit/common/codec/packstream/test_structure.py tests/codec/packstream/from_driver/test_structure.py cp -r driver/tests/unit/common/vector/* tests/vector/from_driver diff --git a/changelog.d/86.feature.md b/changelog.d/86.feature.md new file mode 100644 index 0000000..e2415eb --- /dev/null +++ b/changelog.d/86.feature.md @@ -0,0 +1,2 @@ +Add support for PackStream v2. +This is required to support Bolt 6.1 which introduces `UUID`s. diff --git a/driver b/driver index c245423..e847904 160000 --- a/driver +++ b/driver @@ -1 +1 @@ -Subproject commit c2454236512448d7b55abc6b10e33e2eaae3b8bf +Subproject commit e84790445f08ce2ea164941a95355855b7bc5756 diff --git a/src/codec/packstream.rs b/src/codec/packstream.rs index 547966c..44506ae 100644 --- a/src/codec/packstream.rs +++ b/src/codec/packstream.rs @@ -13,7 +13,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod uuid; mod v1; +mod v2; use pyo3::basic::CompareOp; use pyo3::exceptions::{PyIndexError, PyValueError}; @@ -33,6 +35,10 @@ pub(super) fn init_module(m: &Bound, name: &str) -> PyResult<()> { m.add_submodule(&mod_v1)?; v1::init_module(&mod_v1, format!("{name}.v1").as_str())?; + let mod_v2 = PyModule::new(py, "v2")?; + m.add_submodule(&mod_v2)?; + v2::init_module(&mod_v2, format!("{name}.v2").as_str())?; + m.add_class::()?; Ok(()) diff --git a/src/codec/packstream/uuid.rs b/src/codec/packstream/uuid.rs new file mode 100644 index 0000000..4d79b4b --- /dev/null +++ b/src/codec/packstream/uuid.rs @@ -0,0 +1,23 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use pyo3::prelude::*; +use pyo3::sync::PyOnceLock; +use pyo3::types::PyType; + +pub(crate) fn get_uuid_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> { + static UUID_CLS: PyOnceLock> = PyOnceLock::new(); + UUID_CLS.import(py, "uuid", "UUID") +} diff --git a/src/codec/packstream/v1.rs b/src/codec/packstream/v1.rs index c20330c..43591d4 100644 --- a/src/codec/packstream/v1.rs +++ b/src/codec/packstream/v1.rs @@ -13,10 +13,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod pack; -mod unpack; +pub(super) mod extension; +pub(super) mod pack; +pub(super) mod unpack; use pyo3::prelude::*; +use pyo3::types::{PyByteArray, PyBytes, PyDict}; use pyo3::wrap_pyfunction; use crate::register_package; @@ -46,11 +48,30 @@ const BYTES_8: u8 = 0xCC; const BYTES_16: u8 = 0xCD; const BYTES_32: u8 = 0xCE; +#[pyfunction] +#[pyo3(name = "pack", signature = (value, dehydration_hooks=None))] +fn pack_fn<'py>( + value: &Bound<'py, PyAny>, + dehydration_hooks: Option<&Bound<'py, PyAny>>, +) -> PyResult> { + pack::pack::(value, dehydration_hooks) +} + +#[pyfunction] +#[pyo3(name = "unpack", signature = (bytes, idx, hydration_hooks=None))] +fn unpack_fn( + bytes: Bound, + idx: usize, + hydration_hooks: Option>, +) -> PyResult<(Py, usize)> { + unpack::unpack::(bytes, idx, hydration_hooks) +} + pub(crate) fn init_module(m: &Bound, name: &str) -> PyResult<()> { register_package(m, name)?; - m.add_function(wrap_pyfunction!(unpack::unpack, m)?)?; - m.add_function(wrap_pyfunction!(pack::pack, m)?)?; + m.add_function(wrap_pyfunction!(unpack_fn, m)?)?; + m.add_function(wrap_pyfunction!(pack_fn, m)?)?; Ok(()) } diff --git a/src/codec/packstream/v1/extension.rs b/src/codec/packstream/v1/extension.rs new file mode 100644 index 0000000..adf7d0f --- /dev/null +++ b/src/codec/packstream/v1/extension.rs @@ -0,0 +1,63 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::ffi::CStr; + +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; + +use super::super::uuid::get_uuid_cls; +use super::pack::PackStreamEncoder; +use super::unpack::PackStreamDecoder; + +pub(crate) trait PackStreamV1Ext: Sized { + fn type_mapping_import() -> &'static CStr; + fn pack_ext( + value: &'_ Bound, + encoder: &mut PackStreamEncoder<'_, Self>, + ) -> PyResult; + fn unpack_ext(marker: u8, decoder: &mut PackStreamDecoder) + -> PyResult>>; +} + +pub(crate) struct PackStreamV1BaseExt {} + +impl PackStreamV1Ext for PackStreamV1BaseExt { + #[inline] + fn type_mapping_import() -> &'static CStr { + c"from neo4j._codec.packstream.v1.types import *" + } + + #[inline] + fn pack_ext(value: &'_ Bound, _: &mut PackStreamEncoder<'_, Self>) -> PyResult { + let py = value.py(); + + let uuid_cls = get_uuid_cls(py)?; + if value.is_instance(uuid_cls)? { + return Err(PyErr::new::(format!( + "Values of type {} are not supported \ + (requires Bolt protocol version 6.1 or newer)", + value.get_type().str()? + ))); + } + + Ok(false) + } + + #[inline] + fn unpack_ext(_: u8, _: &mut PackStreamDecoder) -> PyResult>> { + Ok(None) + } +} diff --git a/src/codec/packstream/v1/pack.rs b/src/codec/packstream/v1/pack.rs index 8467cb5..9ed712d 100644 --- a/src/codec/packstream/v1/pack.rs +++ b/src/codec/packstream/v1/pack.rs @@ -14,6 +14,7 @@ // limitations under the License. use std::borrow::Cow; +use std::marker::PhantomData; use std::sync::OnceLock; use pyo3::exceptions::{PyOverflowError, PyTypeError, PyValueError}; @@ -24,6 +25,7 @@ use pyo3::types::{PyByteArray, PyBytes, PyDict, PyString, PyTuple, PyType}; use pyo3::{intern, IntoPyObjectExt}; use super::super::Structure; +use super::extension::PackStreamV1Ext; use super::{ BYTES_16, BYTES_32, BYTES_8, FALSE, FLOAT_64, INT_16, INT_32, INT_64, INT_8, LIST_16, LIST_32, LIST_8, MAP_16, MAP_32, MAP_8, NULL, STRING_16, STRING_32, STRING_8, TINY_LIST, TINY_MAP, @@ -130,46 +132,42 @@ impl TypeMappings { } } -static TYPE_MAPPINGS: OnceLock> = OnceLock::new(); +fn get_type_mappings(py: Python<'_>) -> PyResult<&'static TypeMappings> { + static TYPE_MAPPINGS: OnceLock> = OnceLock::new(); -fn get_type_mappings(py: Python<'_>) -> PyResult<&'static TypeMappings> { let mappings = TYPE_MAPPINGS.get_or_init_py_attached(py, || { let locals = PyDict::new(py); - py.run( - c"from neo4j._codec.packstream.v1.types import *", - None, - Some(&locals), - )?; + py.run(E::type_mapping_import(), None, Some(&locals))?; TypeMappings::new(&locals) }); mappings.as_ref().map_err(|e| e.clone_ref(py)) } -#[pyfunction] -#[pyo3(signature = (value, dehydration_hooks=None))] -pub(super) fn pack<'py>( +pub(crate) fn pack<'py, E: PackStreamV1Ext>( value: &Bound<'py, PyAny>, dehydration_hooks: Option<&Bound<'py, PyAny>>, ) -> PyResult> { let py = value.py(); - let type_mappings = get_type_mappings(py)?; - let mut encoder = PackStreamEncoder::new(dehydration_hooks, type_mappings); + let type_mappings = get_type_mappings::(py)?; + let mut encoder = PackStreamEncoder::::new(dehydration_hooks, type_mappings); encoder.write(value)?; Ok(PyBytes::new(py, &encoder.buffer)) } -struct PackStreamEncoder<'a> { +pub(crate) struct PackStreamEncoder<'a, E: PackStreamV1Ext> { + ext: PhantomData, dehydration_hooks: Option<&'a Bound<'a, PyAny>>, type_mappings: &'a TypeMappings, buffer: Vec, } -impl<'a> PackStreamEncoder<'a> { +impl<'a, E: PackStreamV1Ext> PackStreamEncoder<'a, E> { fn new( dehydration_hooks: Option<&'a Bound<'a, PyAny>>, type_mappings: &'a TypeMappings, ) -> Self { Self { + ext: PhantomData, dehydration_hooks, type_mappings, buffer: Default::default(), @@ -243,6 +241,10 @@ impl<'a> PackStreamEncoder<'a> { }); } + if E::pack_ext(value, self)? { + return Ok(()); + } + if let Ok(value) = value.extract::>() { let value_ref = value.borrow(); let size = value_ref.fields.len().try_into().map_err(|_| { @@ -407,4 +409,9 @@ impl<'a> PackStreamEncoder<'a> { self.buffer.extend(&[TINY_STRUCT + size, tag]); Ok(()) } + + #[inline] + pub(crate) fn write_raw<'b, I: IntoIterator>(&mut self, iter: I) { + self.buffer.extend(iter) + } } diff --git a/src/codec/packstream/v1/unpack.rs b/src/codec/packstream/v1/unpack.rs index 865cff3..7c1b4da 100644 --- a/src/codec/packstream/v1/unpack.rs +++ b/src/codec/packstream/v1/unpack.rs @@ -13,6 +13,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::marker::PhantomData; + use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::sync::critical_section::with_critical_section; @@ -20,33 +22,33 @@ use pyo3::types::{IntoPyDict, PyByteArray, PyBytes, PyDict, PyList, PyTuple}; use pyo3::{intern, IntoPyObjectExt}; use super::super::Structure; +use super::extension::PackStreamV1Ext; use super::{ BYTES_16, BYTES_32, BYTES_8, FALSE, FLOAT_64, INT_16, INT_32, INT_64, INT_8, LIST_16, LIST_32, LIST_8, MAP_16, MAP_32, MAP_8, NULL, STRING_16, STRING_32, STRING_8, TINY_LIST, TINY_MAP, TINY_STRING, TINY_STRUCT, TRUE, }; -#[pyfunction] -#[pyo3(signature = (bytes, idx, hydration_hooks=None))] -pub(super) fn unpack( +pub(crate) fn unpack( bytes: Bound, idx: usize, hydration_hooks: Option>, ) -> PyResult<(Py, usize)> { let py = bytes.py(); - let mut decoder = PackStreamDecoder::new(py, bytes, idx, hydration_hooks); + let mut decoder = PackStreamDecoder::::new(py, bytes, idx, hydration_hooks); let result = decoder.read()?; Ok((result, decoder.index)) } -struct PackStreamDecoder<'a> { +pub(crate) struct PackStreamDecoder<'a, E: PackStreamV1Ext> { + ext: PhantomData, py: Python<'a>, bytes: Bound<'a, PyByteArray>, index: usize, hydration_hooks: Option>, } -impl<'a> PackStreamDecoder<'a> { +impl<'a, E: PackStreamV1Ext> PackStreamDecoder<'a, E> { fn new( py: Python<'a>, bytes: Bound<'a, PyByteArray>, @@ -54,6 +56,7 @@ impl<'a> PackStreamDecoder<'a> { hydration_hooks: Option>, ) -> Self { Self { + ext: PhantomData, py, bytes, index: idx, @@ -133,10 +136,13 @@ impl<'a> PackStreamDecoder<'a> { } _ if high_nibble == TINY_STRUCT => self.read_struct((marker & 0x0F).into())?, _ => { - // raise ValueError("Unknown PackStream marker %02X" % marker) - return Err(PyErr::new::(format!( - "Unknown PackStream marker {marker:02X}", - ))); + let Some(value) = E::unpack_ext(marker, self)? else { + // raise ValueError("Unknown PackStream marker %02X" % marker) + return Err(PyErr::new::(format!( + "Unknown PackStream marker {marker:02X}", + ))); + }; + value } }) } @@ -262,7 +268,7 @@ impl<'a> PackStreamDecoder<'a> { Ok(byte) } - fn read_n_bytes(&mut self) -> PyResult<[u8; N]> { + pub(crate) fn read_n_bytes(&mut self) -> PyResult<[u8; N]> { let to = self.index + N; with_critical_section(&self.bytes, || { // Safety: @@ -320,4 +326,9 @@ impl<'a> PackStreamDecoder<'a> { fn read_f64(&mut self) -> PyResult { self.read_n_bytes().map(f64::from_be_bytes) } + + #[inline] + pub(crate) fn py(&self) -> Python<'a> { + self.py + } } diff --git a/src/codec/packstream/v2.rs b/src/codec/packstream/v2.rs new file mode 100644 index 0000000..42329d0 --- /dev/null +++ b/src/codec/packstream/v2.rs @@ -0,0 +1,52 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub(super) mod extension; + +use pyo3::prelude::*; +use pyo3::types::{PyByteArray, PyBytes, PyDict}; +use pyo3::wrap_pyfunction; + +use crate::register_package; + +const UUID: u8 = 0xE0; + +#[pyfunction] +#[pyo3(signature = (value, dehydration_hooks=None))] +fn pack<'py>( + value: &Bound<'py, PyAny>, + dehydration_hooks: Option<&Bound<'py, PyAny>>, +) -> PyResult> { + super::v1::pack::pack::(value, dehydration_hooks) +} + +#[pyfunction] +#[pyo3(signature = (bytes, idx, hydration_hooks=None))] +fn unpack( + bytes: Bound, + idx: usize, + hydration_hooks: Option>, +) -> PyResult<(Py, usize)> { + super::v1::unpack::unpack::(bytes, idx, hydration_hooks) +} + +pub(crate) fn init_module(m: &Bound, name: &str) -> PyResult<()> { + register_package(m, name)?; + + m.add_function(wrap_pyfunction!(unpack, m)?)?; + m.add_function(wrap_pyfunction!(pack, m)?)?; + + Ok(()) +} diff --git a/src/codec/packstream/v2/extension.rs b/src/codec/packstream/v2/extension.rs new file mode 100644 index 0000000..8d150a0 --- /dev/null +++ b/src/codec/packstream/v2/extension.rs @@ -0,0 +1,78 @@ +// Copyright (c) "Neo4j" +// Neo4j Sweden AB [https://neo4j.com] +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::ffi::CStr; + +use pyo3::intern; +use pyo3::prelude::*; +use pyo3::IntoPyObjectExt; + +use super::super::uuid::get_uuid_cls; +use super::super::v1::extension::PackStreamV1Ext; +use super::super::v1::pack::PackStreamEncoder; +use super::super::v1::unpack::PackStreamDecoder; +use super::UUID; + +pub(crate) struct PackStreamV2Ext {} + +impl PackStreamV2Ext { + fn write_uuid(uuid: u128, encoder: &mut PackStreamEncoder<'_, Self>) { + encoder.write_raw(&[UUID]); + encoder.write_raw(&u128::to_be_bytes(uuid)); + } +} + +impl PackStreamV1Ext for PackStreamV2Ext { + #[inline] + fn type_mapping_import() -> &'static CStr { + c"from neo4j._codec.packstream.v2.types import *" + } + + #[inline] + fn pack_ext( + value: &'_ Bound, + encoder: &mut PackStreamEncoder<'_, Self>, + ) -> PyResult { + let py = value.py(); + + let uuid_cls = get_uuid_cls(py)?; + if value.is_instance(uuid_cls)? { + let uuid_int: u128 = value.getattr(intern!(py, "int"))?.extract()?; + Self::write_uuid(uuid_int, encoder); + return Ok(true); + } + + Ok(false) + } + + #[inline] + fn unpack_ext( + marker: u8, + decoder: &mut PackStreamDecoder, + ) -> PyResult>> { + let py = decoder.py(); + + Ok(match marker { + UUID => { + let uuid_cls = get_uuid_cls(py)?; + let uuid_int = u128::from_be_bytes(decoder.read_n_bytes()?); + let uuid_obj = + uuid_cls.call1((py.None(), py.None(), py.None(), py.None(), uuid_int))?; + Some(uuid_obj.into_py_any(py)?) + } + _ => None, + }) + } +} diff --git a/tests/codec/packstream/v1/from_driver/test_packstream.py b/tests/codec/packstream/v1/from_driver/test_packstream.py index 9e8350e..3de16f3 100644 --- a/tests/codec/packstream/v1/from_driver/test_packstream.py +++ b/tests/codec/packstream/v1/from_driver/test_packstream.py @@ -14,25 +14,30 @@ # limitations under the License. +import re import struct import typing +import uuid from contextlib import suppress +from decimal import Decimal +from fractions import Fraction from io import BytesIO from math import ( isnan, pi, ) -from uuid import uuid4 import numpy as np import pyarrow as pa import pytest from neo4j._codec.packstream import Structure -from neo4j._codec.packstream.v1 import ( +from neo4j._codec.packstream._common import ( PackableBuffer, - Packer, UnpackableBuffer, +) +from neo4j._codec.packstream.v1 import ( + Packer, Unpacker, ) @@ -89,6 +94,18 @@ def _pack(*values, dehydration_hooks=None): return _pack +@pytest.fixture +def unpack(unpacker_with_buffer): + unpacker, unpackable_buffer = unpacker_with_buffer + + def _unpack(data, hydration_hooks=None): + unpackable_buffer.data = bytearray(data) + unpackable_buffer.used = len(data) + return unpacker.unpack() + + return _unpack + + _default_out_value = object() @@ -260,7 +277,7 @@ def _map_value(v): return constructor -class TestPackStream: +class TestPackStreamV1: @pytest.mark.parametrize( "value", (None, *((pd.NA,) if HAS_PD else ())), @@ -797,6 +814,69 @@ def test_struct_size_overflow(self, pack): fields = [1] * 16 pack(Structure(b"X", *fields)) - def test_illegal_uuid(self, assert_packable): - with pytest.raises(ValueError): - assert_packable(uuid4(), b"\xb0XXX") + def test_illegal_uuid(self, pack): + with pytest.raises(ValueError) as exc: + pack(uuid.uuid4()) + + msg = str(exc.value) + assert str(uuid.UUID) in msg + assert re.search(r"\bbolt\b.*\b6\.1\b", msg, re.IGNORECASE) + + @pytest.mark.parametrize( + "value", + ( + Fraction(1, 3), + Decimal("1.333333333333333333"), + re.compile(r".*"), + ), + ) + def test_illegal_types(self, value, pack): + with pytest.raises(ValueError) as exc: + pack(value) + + msg = str(exc.value) + assert str(type(value)) in msg + assert "bolt" not in msg.lower() + + @pytest.mark.parametrize( + "marker", + ( + b"\xc4", + b"\xc5", + b"\xc6", + b"\xc7", + b"\xcf", + b"\xd3", + b"\xd7", + b"\xdb", + b"\xdc", + b"\xdd", + b"\xde", + b"\xdf", + b"\xe0", + b"\xe1", + b"\xe2", + b"\xe3", + b"\xe4", + b"\xe5", + b"\xe6", + b"\xe7", + b"\xe8", + b"\xe9", + b"\xea", + b"\xeb", + b"\xec", + b"\xed", + b"\xee", + b"\xef", + ), + ) + def test_unpacking_undefined_marker(self, marker, unpack): + data = marker + (b"\xc0" * 128) + with pytest.raises(ValueError) as exc: + unpack(data) + + int_marker = int.from_bytes(marker, "big") + msg = str(exc.value) + assert re.search(r"\bmarker\b", msg, re.IGNORECASE) + assert re.search(rf"\b{int_marker:02x}\b", msg, re.IGNORECASE) diff --git a/tests/codec/packstream/v2/__init__.py b/tests/codec/packstream/v2/__init__.py new file mode 100644 index 0000000..3f96809 --- /dev/null +++ b/tests/codec/packstream/v2/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/codec/packstream/v2/from_driver/__init__.py b/tests/codec/packstream/v2/from_driver/__init__.py new file mode 100644 index 0000000..3f96809 --- /dev/null +++ b/tests/codec/packstream/v2/from_driver/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/codec/packstream/v2/from_driver/test_packstream.py b/tests/codec/packstream/v2/from_driver/test_packstream.py new file mode 100644 index 0000000..ff87683 --- /dev/null +++ b/tests/codec/packstream/v2/from_driver/test_packstream.py @@ -0,0 +1,896 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import re +import struct +import sys +import typing +import uuid +from contextlib import suppress +from decimal import Decimal +from fractions import Fraction +from io import BytesIO +from math import ( + isnan, + pi, +) + +import numpy as np +import pyarrow as pa +import pytest + +from neo4j._codec.packstream import Structure +from neo4j._codec.packstream._common import ( + PackableBuffer, + UnpackableBuffer, +) +from neo4j._codec.packstream.v2 import ( + Packer, + Unpacker, +) + + +HAS_PD = True +if typing.TYPE_CHECKING: + import pandas as pd +else: + try: + import pandas as pd + except ImportError: + pd = None + HAS_PD = False + +standard_ascii = [chr(i) for i in range(128)] +not_ascii = "♥O◘♦♥O◘♦" + + +@pytest.fixture +def packer_with_buffer(): + packable_buffer = Packer.new_packable_buffer() + return Packer(packable_buffer), packable_buffer + + +@pytest.fixture +def unpacker_with_buffer(): + unpackable_buffer = Unpacker.new_unpackable_buffer() + return Unpacker(unpackable_buffer), unpackable_buffer + + +def test_packable_buffer(packer_with_buffer): + packer, packable_buffer = packer_with_buffer + assert isinstance(packable_buffer, PackableBuffer) + assert packable_buffer is packer.stream + + +def test_unpackable_buffer(unpacker_with_buffer): + unpacker, unpackable_buffer = unpacker_with_buffer + assert isinstance(unpackable_buffer, UnpackableBuffer) + assert unpackable_buffer is unpacker.unpackable + + +@pytest.fixture +def pack(packer_with_buffer): + packer, packable_buffer = packer_with_buffer + + def _pack(*values, dehydration_hooks=None): + for value in values: + packer.pack(value, dehydration_hooks=dehydration_hooks) + data = bytearray(packable_buffer.data) + packable_buffer.clear() + return data + + return _pack + + +@pytest.fixture +def unpack(unpacker_with_buffer): + unpacker, unpackable_buffer = unpacker_with_buffer + + def _unpack(data, hydration_hooks=None): + unpackable_buffer.data = bytearray(data) + unpackable_buffer.used = len(data) + return unpacker.unpack() + + return _unpack + + +_default_out_value = object() + + +@pytest.fixture +def assert_packable(packer_with_buffer, unpacker_with_buffer): + def _recursive_nan_equal(a, b): + if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)): + return len(a) == len(b) and all( + _recursive_nan_equal(x, y) for x, y in zip(a, b, strict=True) + ) + elif isinstance(a, dict) and isinstance(b, dict): + return len(a) == len(b) and all( + _recursive_nan_equal(a[k], b[k]) for k in a + ) + else: + return a == b or (isnan(a) and isnan(b)) + + def _assert(in_value, packed_value, out_value=_default_out_value): + if out_value is _default_out_value: + out_value = in_value + nonlocal packer_with_buffer, unpacker_with_buffer + packer, packable_buffer = packer_with_buffer + unpacker, unpackable_buffer = unpacker_with_buffer + packable_buffer.clear() + unpackable_buffer.reset() + + packer.pack(in_value) + packed_data = packable_buffer.data + assert packed_data == packed_value + + unpackable_buffer.data = bytearray(packed_data) + unpackable_buffer.used = len(packed_data) + unpacked_data = unpacker.unpack() + assert _recursive_nan_equal(unpacked_data, out_value) + + return _assert + + +@pytest.fixture(params=(True, False)) +def np_float_overflow_as_error(request): + should_raise = request.param + if should_raise: + old_err = np.seterr(over="raise") + else: + old_err = np.seterr(over="ignore") + yield + np.seterr(**old_err) + + +@pytest.fixture( + params=( + int, + np.int8, + np.int16, + np.int32, + np.int64, + np.longlong, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.ulonglong, + ) +) +def int_type(request): + if issubclass(request.param, np.number): + + def _int_type(value): + # this avoids deprecation warning from NEP50 and forces + # c-style wrapping of the value + return np.array(value).astype(request.param).item() + + return _int_type + else: + return request.param + + +@pytest.fixture( + params=(float, np.float16, np.float32, np.float64, np.longdouble) +) +def float_type(request, np_float_overflow_as_error): + return request.param + + +@pytest.fixture(params=(bool, np.bool_)) +def bool_type(request): + return request.param + + +@pytest.fixture(params=(bytes, bytearray, np.bytes_)) +def bytes_type(request): + return request.param + + +@pytest.fixture(params=(str, np.str_)) +def str_type(request): + return request.param + + +@pytest.fixture( + params=( + pytest.param(list, id="list"), + pytest.param(tuple, id="tuple"), + pytest.param(np.array, id="np.array"), + *( + ( + pytest.param( + pd.Series, + id="pd.Series", + ), + pytest.param( + pd.array, + id="pd.array", + ), + pytest.param( + pd.arrays.SparseArray, + id="pd.arrays.SparseArray", + ), + pytest.param( + pd.arrays.NumpyExtensionArray, + id="pd.arrays.NumpyExtensionArray", + ), + pytest.param( + pd.arrays.ArrowExtensionArray, + id="pd.arrays.ArrowExtensionArray", + ), + ) + if HAS_PD + else () + ), + ) +) +def sequence_type(request): + if HAS_PD and request.param is pd.Series: + + def constructor(value): + if not value: + return pd.Series(dtype=object) + return pd.Series(value) + + elif HAS_PD and request.param is pd.array and pd.__version__ >= "3": + + def constructor(value): + with suppress(ValueError): + return pd.array(value) + return pd.array(value, dtype=object) + + elif HAS_PD and request.param is pd.arrays.NumpyExtensionArray: + + def constructor(value): + return pd.arrays.NumpyExtensionArray(np.array(value)) + + elif HAS_PD and request.param is pd.arrays.ArrowExtensionArray: + + def constructor(value): + def _map_value(v): + if isinstance(v, pd.arrays.ArrowExtensionArray): + v = pa.array(v) + if isinstance(v, pa.Array): + v = v.to_pylist() + return v + + value = map(_map_value, value) + return pd.arrays.ArrowExtensionArray(pa.array(value)) + + else: + constructor = request.param + + return constructor + + +class TestPackStreamV2: + @pytest.mark.parametrize( + "value", + (None, *((pd.NA,) if HAS_PD else ())), + ) + def test_none(self, value, assert_packable): + assert_packable(value, b"\xc0", None) + + def test_boolean(self, bool_type, assert_packable): + assert_packable(bool_type(True), b"\xc3") + assert_packable(bool_type(False), b"\xc2") + + @pytest.mark.skipif(pd is None, reason="pandas not installed") + @pytest.mark.parametrize( + "dtype", + (bool, *((pd.BooleanDtype(),) if HAS_PD else ())), + ) + def test_boolean_pandas_series(self, dtype, assert_packable): + value = [True, False] + value_series = pd.Series(value, dtype=dtype) + assert_packable(value_series, b"\x92\xc3\xc2", value) + + def test_negative_tiny_int(self, int_type, assert_packable): + for z in range(-16, 0): + z_typed = int_type(z) + if z != int(z_typed): + continue # not representable + assert_packable(z_typed, bytes(bytearray([z + 0x100]))) + + @pytest.mark.skipif(pd is None, reason="pandas not installed") + @pytest.mark.parametrize( + "dtype", + ( + int, + np.int8, + np.int16, + np.int32, + np.int64, + np.longlong, + *( + ( + pd.Int8Dtype(), + pd.Int16Dtype(), + pd.Int32Dtype(), + pd.Int64Dtype(), + ) + if HAS_PD + else () + ), + ), + ) + def test_negative_tiny_int_pandas_series(self, dtype, assert_packable): + for z in range(-16, 0): + z_typed = pd.Series(z, dtype=dtype) + assert_packable(z_typed, bytes(bytearray([0x91, z + 0x100])), [z]) + + def test_positive_tiny_int(self, int_type, assert_packable): + for z in range(128): + z_typed = int_type(z) + if z != int(z_typed): + continue # not representable + assert_packable(z_typed, bytes(bytearray([z]))) + + def test_negative_int8(self, int_type, assert_packable): + for z in range(-128, -16): + z_typed = int_type(z) + if z != int(z_typed): + continue # not representable + assert_packable(z_typed, bytes(bytearray([0xC8, z + 0x100]))) + + def test_positive_int16(self, int_type, assert_packable): + for z in range(128, 32768): + z_typed = int_type(z) + if z != int(z_typed): + continue # not representable + expected = b"\xc9" + struct.pack(">h", z) + assert_packable(z_typed, expected) + + def test_negative_int16(self, int_type, assert_packable): + for z in range(-32768, -128): + z_typed = int_type(z) + if z != int(z_typed): + continue # not representable + expected = b"\xc9" + struct.pack(">h", z) + assert_packable(z_typed, expected) + + def test_positive_int32(self, int_type, assert_packable): + for e in range(15, 31): + z = 2**e + z_typed = int_type(z) + if z != int(z_typed): + continue # not representable + expected = b"\xca" + struct.pack(">i", z) + assert_packable(z_typed, expected) + + def test_negative_int32(self, int_type, assert_packable): + for e in range(15, 31): + z = -(2**e + 1) + z_typed = int_type(z) + if z != int(z_typed): + continue # not representable + expected = b"\xca" + struct.pack(">i", z) + assert_packable(z_typed, expected) + + def test_positive_int64(self, int_type, assert_packable): + for e in range(31, 63): + z = 2**e + z_typed = int_type(z) + if z != int(z_typed): + continue # not representable + expected = b"\xcb" + struct.pack(">q", z) + assert_packable(z_typed, expected) + + @pytest.mark.skipif(pd is None, reason="pandas not installed") + @pytest.mark.parametrize( + "dtype", + ( + int, + np.int64, + np.longlong, + np.uint64, + np.ulonglong, + *( + ( + pd.Int64Dtype(), + pd.UInt64Dtype(), + ) + if HAS_PD + else () + ), + ), + ) + def test_positive_int64_pandas_series(self, dtype, assert_packable): + for e in range(31, 63): + z = 2**e + z_typed = pd.Series(z, dtype=dtype) + expected = b"\x91\xcb" + struct.pack(">q", z) + assert_packable(z_typed, expected, [z]) + + def test_negative_int64(self, int_type, assert_packable): + for e in range(31, 63): + z = -(2**e + 1) + z_typed = int_type(z) + if z != int(z_typed): + continue # not representable + expected = b"\xcb" + struct.pack(">q", z) + assert_packable(z_typed, expected) + + @pytest.mark.skipif(pd is None, reason="pandas not installed") + @pytest.mark.parametrize( + "dtype", + ( + int, + np.int64, + np.longlong, + *((pd.Int64Dtype(),) if HAS_PD else ()), + ), + ) + def test_negative_int64_pandas_series(self, dtype, assert_packable): + for e in range(31, 63): + z = -(2**e + 1) + z_typed = pd.Series(z, dtype=dtype) + expected = b"\x91\xcb" + struct.pack(">q", z) + assert_packable(z_typed, expected, [z]) + + def test_integer_positive_overflow(self, int_type, pack, assert_packable): + with pytest.raises(OverflowError): + z = 2**63 + 1 + z_typed = int_type(z) + if z != int(z_typed): + pytest.skip("not representable") + pack(z_typed) + + def test_integer_negative_overflow(self, int_type, pack, assert_packable): + with pytest.raises(OverflowError): + z = -(2**63) - 1 + z_typed = int_type(z) + if z != int(z_typed): + pytest.skip("not representable") + pack(z_typed) + + def test_float(self, float_type, assert_packable): + for z in ( + 0.0, + -0.0, + pi, + 2 * pi, + float("inf"), + float("-inf"), + float("nan"), + *(float(2**e) + 0.5 for e in range(100)), + *(-float(2**e) + 0.5 for e in range(100)), + ): + try: + z_typed = float_type(z) + except FloatingPointError: + continue # not representable + expected = b"\xc1" + struct.pack(">d", float(z_typed)) + assert_packable(z_typed, expected) + + @pytest.mark.skipif(pd is None, reason="pandas not installed") + @pytest.mark.parametrize( + "dtype", + ( + float, + np.float16, + np.float32, + np.float64, + np.longdouble, + *( + ( + pd.Float32Dtype(), + pd.Float64Dtype(), + ) + if HAS_PD + else () + ), + ), + ) + def test_float_pandas_series( + self, dtype, np_float_overflow_as_error, assert_packable + ): + for z in ( + 0.0, + -0.0, + pi, + 2 * pi, + float("inf"), + float("-inf"), + float("nan"), + *(float(2**e) + 0.5 for e in range(100)), + *(-float(2**e) + 0.5 for e in range(100)), + ): + try: + z_typed = pd.Series(z, dtype=dtype) + except FloatingPointError: + continue # not representable + if z_typed[0] is pd.NA: + expected_bytes = b"\x91\xc0" # encoded as NULL + expected_value = [None] + else: + expected_bytes = b"\x91\xc1" + struct.pack( + ">d", float(z_typed[0]) + ) + expected_value = [float(z_typed[0])] + assert_packable(z_typed, expected_bytes, expected_value) + + def test_empty_bytes(self, bytes_type, assert_packable): + b = bytes_type(b"") + assert_packable(b, b"\xcc\x00") + + def test_bytes_8(self, bytes_type, assert_packable): + b = bytes_type(b"hello") + assert_packable(b, b"\xcc\x05hello") + + def test_bytes_16(self, bytes_type, assert_packable): + b = bytearray(40000) + b_typed = bytes_type(b) + assert_packable(b_typed, b"\xcd\x9c\x40" + b) + + def test_bytes_32(self, bytes_type, assert_packable): + b = bytearray(80000) + b_typed = bytes_type(b) + assert_packable(b_typed, b"\xce\x00\x01\x38\x80" + b) + + @pytest.mark.skipif(pd is None, reason="pandas not installed") + def test_bytes_pandas_series(self, assert_packable): + for b, header in ( + (b"", b"\xcc\x00"), + (b"hello", b"\xcc\x05"), + (bytearray(40000), b"\xcd\x9c\x40"), + (bytearray(80000), b"\xce\x00\x01\x38\x80"), + ): + b_typed = pd.Series([b]) + assert_packable(b_typed, b"\x91" + header + b, [b]) + + def test_bytearray_size_overflow(self, bytes_type, assert_packable): + stream_out = BytesIO() + packer = Packer(stream_out) + with pytest.raises(OverflowError): + packer._pack_bytes_header(2**32) + + def test_empty_string(self, str_type, assert_packable): + assert_packable(str_type(""), b"\x80") + + def test_tiny_strings(self, str_type, assert_packable): + for size in range(0x10): + s = str_type("A" * size) + assert_packable(s, bytes(bytearray([0x80 + size]) + (b"A" * size))) + + def test_string_8(self, str_type, assert_packable): + t = "A" * 40 + b = t.encode("utf-8") + t_typed = str_type(t) + assert_packable(t_typed, b"\xd0\x28" + b) + + def test_string_16(self, str_type, assert_packable): + t = "A" * 40000 + b = t.encode("utf-8") + t_typed = str_type(t) + assert_packable(t_typed, b"\xd1\x9c\x40" + b) + + def test_string_32(self, str_type, assert_packable): + t = "A" * 80000 + b = t.encode("utf-8") + t_typed = str_type(t) + assert_packable(t_typed, b"\xd2\x00\x01\x38\x80" + b) + + def test_unicode_string(self, str_type, assert_packable): + t = "héllö" + b = t.encode("utf-8") + t_typed = str_type(t) + assert_packable(t_typed, bytes(bytearray([0x80 + len(b)])) + b) + + @pytest.mark.skipif(pd is None, reason="pandas not installed") + @pytest.mark.parametrize( + "dtype", + ( + str, + np.str_, + *( + ( + pd.StringDtype("python"), + pd.StringDtype("pyarrow"), + ) + if HAS_PD + else () + ), + ), + ) + def test_string_pandas_series(self, dtype, assert_packable): + values = ( + ("", b"\x80"), + ("A" * 40, b"\xd0\x28"), + ("A" * 40000, b"\xd1\x9c\x40"), + ("A" * 80000, b"\xd2\x00\x01\x38\x80"), + ) + for t, header in values: + t_typed = pd.Series([t], dtype=dtype) + assert_packable(t_typed, b"\x91" + header + t.encode("utf-8"), [t]) + + t_typed = pd.Series([t for t, _ in values], dtype=dtype) + expected = bytes([0x90 + len(values)]) + b"".join( + header + t.encode("utf-8") for t, header in values + ) + assert_packable(t_typed, expected, [t for t, _ in values]) + + def test_string_size_overflow(self): + stream_out = BytesIO() + packer = Packer(stream_out) + with pytest.raises(OverflowError): + packer._pack_string_header(2**32) + + def test_empty_list(self, sequence_type, assert_packable): + list_ = [] + list_typed = sequence_type(list_) + assert_packable(list_typed, b"\x90", list_) + + def test_tiny_lists(self, sequence_type, assert_packable): + for size in range(0x10): + nums = [1] * size + nums_typed = sequence_type(nums) + data_out = bytearray([0x90 + size]) + bytearray([1] * size) + assert_packable(nums_typed, bytes(data_out), nums) + + def test_list_8(self, sequence_type, assert_packable): + nums = [1] * 40 + nums_typed = sequence_type(nums) + assert_packable(nums_typed, b"\xd4\x28" + (b"\x01" * 40), nums) + + def test_list_16(self, sequence_type, assert_packable): + nums = [1] * 40000 + nums_typed = sequence_type(nums) + assert_packable(nums_typed, b"\xd5\x9c\x40" + (b"\x01" * 40000), nums) + + def test_list_32(self, sequence_type, assert_packable): + nums = [1] * 80000 + nums_typed = sequence_type(nums) + assert_packable( + nums_typed, b"\xd6\x00\x01\x38\x80" + (b"\x01" * 80000), nums + ) + + @pytest.mark.parametrize("inner_as_list", (True, False)) + def test_nested_lists(self, sequence_type, inner_as_list, assert_packable): + list_ = [[[]]] + if inner_as_list: + l_typed = sequence_type(list_) + else: + l_typed = sequence_type([sequence_type([sequence_type([])])]) + assert_packable(l_typed, b"\x91\x91\x90", list_) + + @pytest.mark.skipif(pd is None, reason="pandas not installed") + @pytest.mark.parametrize("as_series", (True, False)) + def test_list_pandas_categorical(self, as_series, pack, assert_packable): + animals = ["cat", "dog", "cat", "cat", "dog", "horse"] + animals_typed = pd.Categorical(animals) + if as_series: + animals_typed = pd.Series(animals_typed) + b = b"".join([b"\x96", *(pack(e) for e in animals)]) + assert_packable(animals_typed, b, animals) + + def test_list_size_overflow(self): + stream_out = BytesIO() + packer = Packer(stream_out) + with pytest.raises(OverflowError): + packer._pack_list_header(2**32) + + def test_empty_map(self, assert_packable): + assert_packable({}, b"\xa0") + + @pytest.mark.parametrize("size", range(0x10)) + def test_tiny_maps(self, assert_packable, size): + data_in = {} + data_out = bytearray([0xA0 + size]) + for el in range(1, size + 1): + data_in[chr(64 + el)] = el + data_out += bytearray([0x81, 64 + el, el]) + assert_packable(data_in, bytes(data_out)) + + @pytest.mark.parametrize("size", range(0x10)) + def test_tiny_maps_padded_key(self, assert_packable, size): + data_in = {} + data_out = bytearray([0xA0 + size]) + padding = b"1234567890abcdefghijklmnopqrstuvwxyz" + for el in range(1, size + 1): + data_in[padding.decode("ascii") + chr(64 + el)] = el + data_out += bytearray([0xD0, 37, *padding, 64 + el, el]) + assert_packable(data_in, bytes(data_out)) + + def test_map_8(self, pack, assert_packable): + d = {f"A{i}": 1 for i in range(40)} + b = b"".join(pack(f"A{i}", 1) for i in range(40)) + assert_packable(d, b"\xd8\x28" + b) + + def test_map_8_padded_key(self, pack, assert_packable): + padding = "1234567890abcdefghijklmnopqrstuvwxyz" + d = {f"{padding}-{i}": 1 for i in range(40)} + b = b"".join(pack(f"{padding}-{i}", 1) for i in range(40)) + assert_packable(d, b"\xd8\x28" + b) + + def test_map_16(self, pack, assert_packable): + d = {f"A{i}": 1 for i in range(40000)} + b = b"".join(pack(f"A{i}", 1) for i in range(40000)) + assert_packable(d, b"\xd9\x9c\x40" + b) + + def test_map_32(self, pack, assert_packable): + d = {f"A{i}": 1 for i in range(80000)} + b = b"".join(pack(f"A{i}", 1) for i in range(80000)) + assert_packable(d, b"\xda\x00\x01\x38\x80" + b) + + def test_map_key_tiny_string(self, assert_packable): + key = "A" + d = {key: 1} + data_out = b"\xa1\x81" + key.encode("utf-8") + b"\x01" + assert_packable(d, bytes(data_out)) + + def test_map_key_string_8(self, assert_packable): + key = "A" * 40 + d = {key: 1} + data_out = b"\xa1\xd0\x28" + key.encode("utf-8") + b"\x01" + assert_packable(d, data_out) + + def test_map_key_string_16(self, assert_packable): + key = "A" * 40000 + d = {key: 1} + data_out = b"\xa1\xd1\x9c\x40" + key.encode("utf-8") + b"\x01" + assert_packable(d, data_out) + + def test_map_key_string_32(self, assert_packable): + key = "A" * 80000 + d = {key: 1} + data_out = b"\xa1\xd2\x00\x01\x38\x80" + key.encode("utf-8") + b"\x01" + assert_packable(d, data_out) + + @pytest.mark.skipif(pd is None, reason="pandas not installed") + def test_empty_dataframe_maps(self, assert_packable): + df = pd.DataFrame() + assert_packable(df, b"\xa0", {}) + + @pytest.mark.skipif(pd is None, reason="pandas not installed") + @pytest.mark.parametrize("size", range(0x10)) + def test_tiny_dataframes_maps(self, assert_packable, size): + data_in = {} + data_out = bytearray([0xA0 + size]) + for el in range(1, size + 1): + data_in[chr(64 + el)] = [el] + data_out += bytearray([0x81, 64 + el, 0x91, el]) + data_in_typed = pd.DataFrame(data_in) + assert_packable(data_in_typed, bytes(data_out), data_in) + + def test_map_size_overflow(self): + stream_out = BytesIO() + packer = Packer(stream_out) + with pytest.raises(OverflowError): + packer._pack_map_header(2**32) + + @pytest.mark.parametrize( + ("map_", "exc_type"), + ( + ({1: "1"}, TypeError), + ({"x": {1: "eins", 2: "zwei", 3: "drei"}}, TypeError), + ({"x": {(1, 2): "1+2i", (2, 0): "2"}}, TypeError), + *( + ( + (pd.DataFrame({1: ["1"]}), TypeError), + (pd.DataFrame({(1, 2): ["1"]}), TypeError), + ) + if HAS_PD + else () + ), + ), + ) + def test_map_key_type(self, packer_with_buffer, map_, exc_type): + # maps must have string keys + packer, _packable_buffer = packer_with_buffer + with pytest.raises(exc_type, match="strings"): + packer._pack(map_) + + def test_illegal_signature(self, assert_packable): + with pytest.raises(ValueError): + assert_packable(Structure(b"XXX"), b"\xb0XXX") + + def test_empty_struct(self, assert_packable): + assert_packable(Structure(b"X"), b"\xb0X") + + def test_tiny_structs(self, assert_packable): + for size in range(0x10): + fields = [1] * size + data_in = Structure(b"A", *fields) + data_out = bytearray((0xB0 + size, 0x41, *fields)) + assert_packable(data_in, bytes(data_out)) + + def test_struct_size_overflow(self, pack): + with pytest.raises(OverflowError): + fields = [1] * 16 + pack(Structure(b"X", *fields)) + + @pytest.mark.parametrize( + "value", + ( + uuid.UUID("{12345678-1234-5678-1234-567812345678}"), + uuid.UUID(int=0), + uuid.uuid3(uuid.uuid1(), "name"), + uuid.uuid4(), + uuid.uuid5(uuid.uuid1(), "name"), + *( + ( + uuid.uuid6(), # type: ignore[attr-defined] + uuid.uuid7(), # type: ignore[attr-defined] + uuid.uuid8(), # type: ignore[attr-defined] + ) + if sys.version_info >= (3, 14) + else () + ), + ), + ) + def test_uuid(self, value, assert_packable): + assert_packable(value, b"\xe0" + value.bytes) + + @pytest.mark.parametrize( + "value", + ( + Fraction(1, 3), + Decimal("1.333333333333333333"), + re.compile(r".*"), + ), + ) + def test_illegal_types(self, value, pack): + with pytest.raises(ValueError) as exc: + pack(value) + + msg = str(exc.value) + assert str(type(value)) in msg + assert "bolt" not in msg.lower() + + @pytest.mark.parametrize( + "marker", + ( + b"\xc4", + b"\xc5", + b"\xc6", + b"\xc7", + b"\xcf", + b"\xd3", + b"\xd7", + b"\xdb", + b"\xdc", + b"\xdd", + b"\xde", + b"\xdf", + b"\xe1", + b"\xe2", + b"\xe3", + b"\xe4", + b"\xe5", + b"\xe6", + b"\xe7", + b"\xe8", + b"\xe9", + b"\xea", + b"\xeb", + b"\xec", + b"\xed", + b"\xee", + b"\xef", + ), + ) + def test_unpacking_undefined_marker(self, marker, unpack): + data = marker + (b"\xc0" * 128) + with pytest.raises(ValueError) as exc: + unpack(data) + + int_marker = int.from_bytes(marker, "big") + msg = str(exc.value) + assert re.search(r"\bmarker\b", msg, re.IGNORECASE) + assert re.search(rf"\b{int_marker:02x}\b", msg, re.IGNORECASE) diff --git a/tests/codec/packstream/v2/test_injection.py b/tests/codec/packstream/v2/test_injection.py new file mode 100644 index 0000000..e01b9ee --- /dev/null +++ b/tests/codec/packstream/v2/test_injection.py @@ -0,0 +1,150 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import importlib +import sys +import traceback + +import pytest + +from neo4j._codec.hydration import DehydrationHooks +from neo4j._codec.packstream import Structure +from neo4j._codec.packstream.v2 import ( + Packer, + Unpacker, +) + + +@pytest.fixture +def packer_with_buffer(): + packable_buffer = Packer.new_packable_buffer() + return Packer(packable_buffer), packable_buffer + + +@pytest.fixture +def unpacker_with_buffer(): + unpackable_buffer = Unpacker.new_unpackable_buffer() + return Unpacker(unpackable_buffer), unpackable_buffer + + +def test_pack_injection_works(packer_with_buffer): + class TestClass: + pass + + class TestError(Exception): + pass + + def raise_test_exception(*args, **kwargs): + raise TestError + + dehydration_hooks = DehydrationHooks( + exact_types={TestClass: raise_test_exception}, + subtypes={}, + ) + test_object = TestClass() + packer, _ = packer_with_buffer + + with pytest.raises(TestError) as exc: + packer.pack(test_object, dehydration_hooks=dehydration_hooks) + + # printing the traceback to stdout to make it easier to debug + traceback.print_exception(exc.type, exc.value, exc.tb, file=sys.stdout) + + assert any("_rust_pack" in str(entry.statement) for entry in exc.traceback) + assert not any( + "_py_pack" in str(entry.statement) for entry in exc.traceback + ) + + +def test_unpack_injection_works(unpacker_with_buffer): + class TestError(Exception): + pass + + def raise_test_exception(*args, **kwargs): + raise TestError + + hydration_hooks = {Structure: raise_test_exception} + unpacker, buffer = unpacker_with_buffer + + buffer.reset() + buffer.data = bytearray(b"\xb0\xff") + + with pytest.raises(TestError) as exc: + unpacker.unpack(hydration_hooks) + + # printing the traceback to stdout to make it easier to debug + traceback.print_exception(exc.type, exc.value, exc.tb, file=sys.stdout) + + assert any( + "_rust_unpack" in str(entry.statement) for entry in exc.traceback + ) + assert not any( + "_py_unpack" in str(entry.statement) for entry in exc.traceback + ) + + +@pytest.mark.parametrize( + ("name", "submodule_names"), + ( + # packstream v2 + ("neo4j._rust.codec.packstream.v2", ()), + ("neo4j._rust.codec.packstream", ("v2",)), + ("neo4j._rust.codec", ("packstream",)), + ("neo4j._rust", ("codec",)), + ("neo4j", ("_rust",)), + ), +) +def test_import_module(name, submodule_names): + module = importlib.import_module(name) + + assert module.__name__ == name + + for submodule_name in submodule_names: + package = getattr(module, submodule_name) + assert package.__name__ == f"{name}.{submodule_name}" + + +def test_rust_struct_access(): + tag = b"F" + fields = ["foo", False, 42, 3.14, b"bar"] + struct = Structure(tag, *fields) + + assert struct.tag == tag + assert isinstance(struct.tag, bytes) + assert struct.fields == fields + + +def test_rust_struct_equal(): + struct1 = Structure(b"F", "foo", False, 42, 3.14, b"bar") + struct2 = Structure(b"F", "foo", False, 42, 3.14, b"bar") + assert struct1 == struct2 + # [noqa] for testing correctness of equality + assert not struct1 != struct2 # noqa: SIM202 + + +@pytest.mark.parametrize( + "args", + ( + (b"F", "foo", True, 42, 3.14, b"bar"), + (b"f", "foo", False, 42, 3.14, b"baz"), + ), +) +def test_rust_struct_not_equal(args): + struct1 = Structure(b"F", "foo", False, 42, 3.14, b"bar") + struct2 = Structure(*args) + assert struct1 != struct2 + # [noqa] for testing correctness of equality + assert not struct1 == struct2 # noqa: SIM201 diff --git a/tests/vector/from_driver/test_vector.py b/tests/vector/from_driver/test_vector.py index d7702b4..d83779e 100644 --- a/tests/vector/from_driver/test_vector.py +++ b/tests/vector/from_driver/test_vector.py @@ -17,6 +17,7 @@ from __future__ import annotations import abc +import contextlib import math import random import struct @@ -37,6 +38,10 @@ ) +with contextlib.suppress(ImportError): + import pyarrow.compute + + if t.TYPE_CHECKING: import numpy import pyarrow diff --git a/tox.ini b/tox.ini index 452be28..4bcd4c8 100644 --- a/tox.ini +++ b/tox.ini @@ -11,4 +11,4 @@ extras = commands_pre = devdriver: python -m pip install ./driver --no-deps commands = - test: python -m pytest -v --benchmark-skip {posargs} tests/test_no_gil.py + test: python -m pytest -v --benchmark-skip {posargs} tests