From 1ad6344d0069d2790c0081c588cc18cd6e9de877 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Tue, 23 Jan 2024 15:39:23 -0800 Subject: [PATCH 1/4] Add MDS types varint, varuint. --- streaming/base/format/mds/encodings.py | 75 +++++++++++++++++++++++++- tests/test_encodings.py | 48 ++++++++++++++++- 2 files changed, 120 insertions(+), 3 deletions(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index a8280dbad..9d6624e2b 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from decimal import Decimal from io import BytesIO -from typing import Any, Optional, Set, Tuple +from typing import IO, Any, Optional, Set, Tuple import numpy as np from numpy import typing as npt @@ -366,6 +366,77 @@ def __init__(self): super().__init__(np.int64) +def _varuint_encode(obj: int) -> bytes: + if obj < 0: + raise ValueError(f'Expected non-negative integer, but got: {obj}.') + ret = [] + while True: + byte = obj & 0x7F + obj >>= 7 + if obj: + ret.append(0x80 | byte) + else: + ret.append(byte) + break + return bytes(ret) + + +def _varint_encode(obj: int) -> bytes: + if 0 <= obj: + obj = obj << 1 + else: + obj = ((-obj) << 1) | 1 + return _varuint_encode(obj) + + +def _varuint_decode(stream: IO[bytes]) -> int: + obj = 0 + shift = 0 + while True: + byte, = stream.read(1) + obj |= (byte & 0x7F) << shift + if 0x80 <= byte: + shift += 7 + else: + break + return obj + + +def _varint_decode(stream: IO[bytes]) -> int: + obj = _varuint_decode(stream) + if obj & 1: + obj = -(obj >> 1) + else: + obj >>= 1 + return obj + + +class VarUInt(Encoding): + """Varint DS3 type.""" + + @classmethod + def encode(cls, obj: int) -> bytes: + return _varuint_encode(obj) + + @classmethod + def decode(cls, data: bytes) -> int: + stream = BytesIO(data) + return _varuint_decode(stream) + + +class VarInt(Encoding): + """Varint DS3 type.""" + + @classmethod + def encode(cls, obj: int) -> bytes: + return _varint_encode(obj) + + @classmethod + def decode(cls, data: bytes) -> int: + stream = BytesIO(data) + return _varint_decode(stream) + + class Float16(Scalar): """Store float16.""" @@ -531,6 +602,8 @@ def _is_valid(self, original: Any, converted: Any) -> None: 'int16': Int16, 'int32': Int32, 'int64': Int64, + 'varuint': VarUInt, + 'varint': VarInt, 'float16': Float16, 'float32': Float32, 'float64': Float64, diff --git a/tests/test_encodings.py b/tests/test_encodings.py index bc3aac670..9df1f1288 100644 --- a/tests/test_encodings.py +++ b/tests/test_encodings.py @@ -456,8 +456,8 @@ def test_mds_StrDecimal(self, decoded: Decimal, encoded: bytes): assert dec == decoded def test_get_mds_encodings(self): - uints = {'uint8', 'uint16', 'uint32', 'uint64'} - ints = {'int8', 'int16', 'int32', 'int64', 'str_int'} + uints = {'uint8', 'uint16', 'uint32', 'uint64', 'varuint'} + ints = {'int8', 'int16', 'int32', 'int64', 'str_int', 'varint'} floats = {'float16', 'float32', 'float64', 'str_float', 'str_decimal'} scalars = uints | ints | floats expected_encodings = { @@ -488,6 +488,50 @@ def test_mds_scalar(self, encoding: str, decoded: Union[int, float], encoded: by dec = mdsEnc.mds_decode(encoding, encoded) assert dec == decoded + def test_varints(self): + from streaming.base.format.mds.encodings import mds_decode, mds_encode + for x in range(-700, 700, 7): + y = mds_encode('varint', x) + z = mds_decode('varint', y) + print(x, y, z) + assert x == z + for x in range(0, 700, 7): + y = mds_encode('varuint', x) + z = mds_decode('varuint', y) + print(x, y, z) + assert x == z + + @pytest.mark.parametrize(('enc_name', 'data'), [('bytes', b'9'), ('int', 27), + ('str', 'mosaicml')]) + def test_mds_encode(self, enc_name: str, data: Any): + output = mdsEnc.mds_encode(enc_name, data) + assert isinstance(output, bytes) + + @pytest.mark.parametrize(('enc_name', 'data'), [('bytes', 9), ('int', '27'), ('str', 12.5)]) + def test_mds_encode_invalid_data(self, enc_name: str, data: Any): + with pytest.raises(AttributeError): + _ = mdsEnc.mds_encode(enc_name, data) + + @pytest.mark.parametrize(('enc_name', 'data', 'expected_data_type'), + [('bytes', b'c\x00\x00\x00\x00\x00\x00\x00', bytes), + ('str', b'mosaicml', str)]) + def test_mds_decode(self, enc_name: str, data: Any, expected_data_type: Any): + output = mdsEnc.mds_decode(enc_name, data) + assert isinstance(output, expected_data_type) + + @pytest.mark.parametrize(('enc_name', 'expected_size'), [('bytes', None), ('int', 8)]) + def test_get_mds_encoded_size(self, enc_name: str, expected_size: Any): + output = mdsEnc.get_mds_encoded_size(enc_name) + assert output is expected_size + + +class TestXSVEncodings: + + @pytest.mark.parametrize(('data', 'encode_data'), [('99', '99'), + ('streaming dataset', 'streaming dataset')]) + def test_str_encode_decode(self, data: str, encode_data: str): + str_enc = xsvEnc.Str() + @pytest.mark.parametrize(('enc_name', 'data'), [('bytes', b'9'), ('int', 27), ('str', 'mosaicml')]) def test_mds_encode(self, enc_name: str, data: Any): From 32fb2e156e94ccef8d9fe074973dedf9d1748fe9 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Tue, 23 Jan 2024 15:41:02 -0800 Subject: [PATCH 2/4] Fix dupe. --- tests/test_encodings.py | 32 +------------------------------- 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/tests/test_encodings.py b/tests/test_encodings.py index 9df1f1288..d01c16340 100644 --- a/tests/test_encodings.py +++ b/tests/test_encodings.py @@ -26,6 +26,7 @@ def test_byte_encode_decode(self, data: bytes): output = byte_enc.decode(data) assert output == data +import tempfile @pytest.mark.parametrize('data', ['9', 25]) def test_byte_encode_invalid_data(self, data: Any): with pytest.raises(AttributeError): @@ -525,37 +526,6 @@ def test_get_mds_encoded_size(self, enc_name: str, expected_size: Any): assert output is expected_size -class TestXSVEncodings: - - @pytest.mark.parametrize(('data', 'encode_data'), [('99', '99'), - ('streaming dataset', 'streaming dataset')]) - def test_str_encode_decode(self, data: str, encode_data: str): - str_enc = xsvEnc.Str() - - @pytest.mark.parametrize(('enc_name', 'data'), [('bytes', b'9'), ('int', 27), - ('str', 'mosaicml')]) - def test_mds_encode(self, enc_name: str, data: Any): - output = mdsEnc.mds_encode(enc_name, data) - assert isinstance(output, bytes) - - @pytest.mark.parametrize(('enc_name', 'data'), [('bytes', 9), ('int', '27'), ('str', 12.5)]) - def test_mds_encode_invalid_data(self, enc_name: str, data: Any): - with pytest.raises(AttributeError): - _ = mdsEnc.mds_encode(enc_name, data) - - @pytest.mark.parametrize(('enc_name', 'data', 'expected_data_type'), - [('bytes', b'c\x00\x00\x00\x00\x00\x00\x00', bytes), - ('str', b'mosaicml', str)]) - def test_mds_decode(self, enc_name: str, data: Any, expected_data_type: Any): - output = mdsEnc.mds_decode(enc_name, data) - assert isinstance(output, expected_data_type) - - @pytest.mark.parametrize(('enc_name', 'expected_size'), [('bytes', None), ('int', 8)]) - def test_get_mds_encoded_size(self, enc_name: str, expected_size: Any): - output = mdsEnc.get_mds_encoded_size(enc_name) - assert output is expected_size - - class TestXSVEncodings: @pytest.mark.parametrize(('data', 'encode_data'), [('99', '99'), From f4a9589f92af3cbb2b02891b0ddbdd35a9465d3d Mon Sep 17 00:00:00 2001 From: James Knighton Date: Tue, 23 Jan 2024 15:42:20 -0800 Subject: [PATCH 3/4] Keyboard issues. --- tests/test_encodings.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_encodings.py b/tests/test_encodings.py index d01c16340..6fccf1aa1 100644 --- a/tests/test_encodings.py +++ b/tests/test_encodings.py @@ -26,7 +26,6 @@ def test_byte_encode_decode(self, data: bytes): output = byte_enc.decode(data) assert output == data -import tempfile @pytest.mark.parametrize('data', ['9', 25]) def test_byte_encode_invalid_data(self, data: Any): with pytest.raises(AttributeError): From e61d57a9f8ca9f793c8299fb83f8c52993d69d4a Mon Sep 17 00:00:00 2001 From: James Knighton Date: Tue, 23 Jan 2024 15:45:38 -0800 Subject: [PATCH 4/4] Fix docstring. --- streaming/base/format/mds/encodings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/streaming/base/format/mds/encodings.py b/streaming/base/format/mds/encodings.py index 9d6624e2b..1e4377397 100644 --- a/streaming/base/format/mds/encodings.py +++ b/streaming/base/format/mds/encodings.py @@ -412,7 +412,7 @@ def _varint_decode(stream: IO[bytes]) -> int: class VarUInt(Encoding): - """Varint DS3 type.""" + """Store an unsigned integer as a base-128 varint.""" @classmethod def encode(cls, obj: int) -> bytes: @@ -425,7 +425,7 @@ def decode(cls, data: bytes) -> int: class VarInt(Encoding): - """Varint DS3 type.""" + """Store an integer as a base-128 varint.""" @classmethod def encode(cls, obj: int) -> bytes: