diff --git a/benchmarks/bench_encode.py b/benchmarks/bench_encode.py index b79eda9..043c3e8 100644 --- a/benchmarks/bench_encode.py +++ b/benchmarks/bench_encode.py @@ -13,6 +13,9 @@ MEDIUM_DATA_ENCODED = base64.b64encode(MEDIUM_DATA) LARGE_DATA_ENCODED = base64.b64encode(LARGE_DATA) +MEDIUM_DATA_ENCODEBYTES = base64.encodebytes(MEDIUM_DATA) +LARGE_DATA_ENCODEBYTES = base64.encodebytes(LARGE_DATA) + def stdlib_b64encode(data, altchars=None) -> None: for _ in range(ITERATIONS): @@ -23,6 +26,26 @@ def base64_utils_b64encode(data, altchars=None) -> None: for _ in range(ITERATIONS): base64_utils.b64encode(data, altchars=altchars) +def stdlib_encodebytes(data) -> None: + for _ in range(ITERATIONS): + base64.encodebytes(data) + + +def base64_utils_encodebytes(data) -> None: + for _ in range(ITERATIONS): + base64_utils.encodebytes(data) + + +def stdlib_decodebytes(data) -> None: + for _ in range(ITERATIONS): + base64.decodebytes(data) + + +def base64_utils_decodebytes(data) -> None: + for _ in range(ITERATIONS): + base64_utils.decodebytes(data) + + def stdlib_b64decode(data, altchars=None, validate=False) -> None: for _ in range(ITERATIONS): base64.b64decode(data, altchars=altchars, validate=validate) @@ -58,5 +81,25 @@ def base64_utils_b64decode(data, altchars=None, validate=False) -> None: lambda: stdlib_b64decode(MEDIUM_DATA_ENCODED), lambda: base64_utils_b64decode(MEDIUM_DATA_ENCODED), "b64decode (100 KB data)", - ) + ), + ( + lambda: stdlib_encodebytes(MEDIUM_DATA), + lambda: base64_utils_encodebytes(MEDIUM_DATA), + "encodebytes (100 KB data)", + ), + ( + lambda: stdlib_encodebytes(LARGE_DATA), + lambda: base64_utils_encodebytes(LARGE_DATA), + "encodebytes (1 MB data)", + ), + ( + lambda: stdlib_decodebytes(MEDIUM_DATA_ENCODEBYTES), + lambda: base64_utils_decodebytes(MEDIUM_DATA_ENCODEBYTES), + "decodebytes (100 KB data)", + ), + ( + lambda: stdlib_decodebytes(LARGE_DATA_ENCODEBYTES), + lambda: base64_utils_decodebytes(LARGE_DATA_ENCODEBYTES), + "decodebytes (1 MB data)", + ), ] diff --git a/python/base64_utils/__init__.py b/python/base64_utils/__init__.py index 19a1ad8..b8cf66a 100644 --- a/python/base64_utils/__init__.py +++ b/python/base64_utils/__init__.py @@ -2,6 +2,8 @@ __version__, b64decode, b64encode, + decodebytes, + encodebytes, standard_b64decode, standard_b64encode, urlsafe_b64decode, @@ -12,6 +14,8 @@ "__version__", "b64decode", "b64encode", + "decodebytes", + "encodebytes", "standard_b64decode", "standard_b64encode", "urlsafe_b64decode", diff --git a/python/base64_utils/__init__.pyi b/python/base64_utils/__init__.pyi index 7af5c3d..ac6aa01 100644 --- a/python/base64_utils/__init__.pyi +++ b/python/base64_utils/__init__.pyi @@ -5,6 +5,8 @@ __version__: str __all__ = [ "b64decode", "b64encode", + "decodebytes", + "encodebytes", "standard_b64decode", "standard_b64encode", "urlsafe_b64decode", @@ -17,6 +19,8 @@ def b64decode( validate: bool = False, ) -> bytes: ... def b64encode(s: ReadableBuffer, altchars: ReadableBuffer | None = None) -> bytes: ... +def decodebytes(s: ReadableBuffer) -> bytes: ... +def encodebytes(s: ReadableBuffer) -> bytes: ... def standard_b64decode(s: str | ReadableBuffer) -> bytes: ... def standard_b64encode(s: ReadableBuffer) -> bytes: ... def urlsafe_b64decode(s: str | ReadableBuffer) -> bytes: ... diff --git a/src/decoder.rs b/src/decoder.rs index b1d4a96..75f94de 100644 --- a/src/decoder.rs +++ b/src/decoder.rs @@ -95,3 +95,13 @@ pub fn urlsafe_b64decode(py: Python<'_>, s: StringOrBytes) -> PyResult, s: StringOrBytes) -> PyResult> { + let mut input: Vec = s.into_bytes(); + input.retain(|b| !b.is_ascii_whitespace()); + + let output = forgiving_decode_inplace(&mut input) + .map_err(|_| PyValueError::new_err("Invalid base64-encoded string"))?; + Ok(PyBytes::new(py, output).into()) +} diff --git a/src/encoder.rs b/src/encoder.rs index d7aaf60..d5fd37c 100644 --- a/src/encoder.rs +++ b/src/encoder.rs @@ -3,6 +3,9 @@ use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyBytes; +const MAXLINESIZE: usize = 76; +const MAXBINSIZE: usize = 57; + #[pyfunction] #[pyo3(signature = (s, altchars=None))] pub fn b64encode(py: Python<'_>, s: &[u8], altchars: Option<&[u8]>) -> PyResult> { @@ -56,3 +59,23 @@ pub fn urlsafe_b64encode(py: Python<'_>, s: &[u8]) -> PyResult> { })?; Ok(output.into()) } + +#[pyfunction] +pub fn encodebytes(py: Python<'_>, s: &[u8]) -> PyResult> { + let encoded_len = STANDARD.encoded_length(s.len()); + let num_lines = (encoded_len + MAXLINESIZE - 1) / MAXLINESIZE; + let total_len = encoded_len + num_lines; // one \n per line + + let output = PyBytes::new_with(py, total_len, |buf| { + let mut pos = 0; + for chunk in s.chunks(MAXBINSIZE) { + let enc_len = STANDARD.encoded_length(chunk.len()); + let _ = STANDARD.encode(chunk, Out::from_slice(&mut buf[pos..pos + enc_len])); + pos += enc_len; + buf[pos] = b'\n'; + pos += 1; + } + Ok(()) + })?; + Ok(output.into()) +} diff --git a/src/lib.rs b/src/lib.rs index 10cf229..bceff51 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,5 +12,7 @@ fn _base64_utils(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(encoder::b64encode, m)?)?; m.add_function(wrap_pyfunction!(encoder::standard_b64encode, m)?)?; m.add_function(wrap_pyfunction!(encoder::urlsafe_b64encode, m)?)?; + m.add_function(wrap_pyfunction!(encoder::encodebytes, m)?)?; + m.add_function(wrap_pyfunction!(decoder::decodebytes, m)?)?; Ok(()) } diff --git a/tests/test_base64_decode.py b/tests/test_base64_decode.py index 6ff7bd5..16d539d 100644 --- a/tests/test_base64_decode.py +++ b/tests/test_base64_decode.py @@ -74,3 +74,26 @@ def test_urlsafe_b64decode() -> None: assert isinstance(decoded, bytes) assert expected == decoded + + +def test_decodebytes() -> None: + data = base64.encodebytes(b"Hello, World!") + + decoded = base64_utils.decodebytes(data) + expected = base64.decodebytes(data) + + assert isinstance(decoded, bytes) + assert expected == decoded + + +def test_decodebytes_multiline() -> None: + data = base64.encodebytes(b"A" * 100) + + decoded = base64_utils.decodebytes(data) + expected = base64.decodebytes(data) + + assert expected == decoded + + +def test_decodebytes_empty() -> None: + assert base64_utils.decodebytes(b"") == base64.decodebytes(b"") diff --git a/tests/test_base64_encode.py b/tests/test_base64_encode.py index 4822bdc..cc0bca3 100644 --- a/tests/test_base64_encode.py +++ b/tests/test_base64_encode.py @@ -44,3 +44,26 @@ def test_urlsafe_b64encode() -> None: assert isinstance(encoded, bytes) assert expected == encoded + + +def test_encodebytes() -> None: + data = b"Hello, World!" + encoded = base64_utils.encodebytes(data) + expected = base64.encodebytes(data) + + assert isinstance(encoded, bytes) + assert expected == encoded + + +def test_encodebytes_multiline() -> None: + data = b"A" * 100 + encoded = base64_utils.encodebytes(data) + expected = base64.encodebytes(data) + + assert expected == encoded + lines = encoded.split(b"\n") + assert all(len(line) <= 76 for line in lines) + + +def test_encodebytes_empty() -> None: + assert base64_utils.encodebytes(b"") == base64.encodebytes(b"")