Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions trx/tests/test_memmap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-

import json
import os
import struct
import tempfile
import zipfile

Expand Down Expand Up @@ -476,6 +478,112 @@ def test__ensure_little_endian_big_endian_input():
assert result[0] == 0x12345678


def test_load_zip_with_local_header_extra_field():
"""Test loading ZIP where local header has extra field not in central dir.

Regression test for a bug where zip_info.FileHeader() was used to calculate
data offset. The ZIP spec allows local headers to have different extra
fields than central directory entries. The fix reads the actual local
file header to get the correct offset.
"""
positions = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
offsets = np.array([0, 2], dtype=np.uint64)
header = {
"DIMENSIONS": [10, 10, 10],
"VOXEL_TO_RASMM": np.eye(4).tolist(),
"NB_VERTICES": 2,
"NB_STREAMLINES": 1,
}

with tempfile.TemporaryDirectory() as tmp_dir:
trx_path = os.path.join(tmp_dir, "test.trx")

# Build ZIP with extra bytes in local headers but not central directory
with open(trx_path, "wb") as f:
local_info = []
extra = b"\x00\x00\x04\x00TEST" # 8-byte extra field

for name, data in [
("header.json", json.dumps(header).encode()),
("positions.3.float32", positions.tobytes()),
("offsets.uint64", offsets.tobytes()),
]:
offset = f.tell()
fname = name.encode()
crc = zipfile.crc32(data)
# Local header WITH extra field
f.write(
struct.pack(
"<4sHHHHHIIIHH",
b"PK\x03\x04",
20,
0,
0,
0,
0,
crc,
len(data),
len(data),
len(fname),
len(extra),
)
)
f.write(fname)
f.write(extra)
f.write(data)
local_info.append((name, offset, crc, len(data)))

cd_start = f.tell()
for name, offset, crc, size in local_info:
fname = name.encode()
# Central directory WITHOUT extra field (mismatch!)
f.write(
struct.pack(
"<4sHHHHHHIIIHHHHHII",
b"PK\x01\x02",
20,
20,
0,
0,
0,
0,
crc,
size,
size,
len(fname),
0,
0,
0,
0,
0,
offset,
)
)
f.write(fname)

# End of central directory
f.write(
struct.pack(
"<4sHHHHIIH",
b"PK\x05\x06",
0,
0,
3,
3,
f.tell() - cd_start,
cd_start,
0,
)
)

trx = tmm.load_from_zip(trx_path)
np.testing.assert_array_almost_equal(trx.streamlines._data, positions)
assert trx.header["NB_VERTICES"] == 2
assert trx.header["NB_STREAMLINES"] == 1

trx.close()


def test_endianness_roundtrip():
"""Test that data survives write/read cycle with correct endianness."""
with get_trx_tmp_dir() as dirname:
Expand Down
26 changes: 22 additions & 4 deletions trx/trx_file_memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import os
import shutil
import struct
from typing import Any, List, Optional, Tuple, Type, Union
import zipfile

Expand Down Expand Up @@ -407,13 +408,30 @@ def load_from_zip(filename: str) -> Type["TrxFile"]:
if ext == ".bit":
ext = ".bool"

mem_adress = zip_info.header_offset + len(zip_info.FileHeader())
# Read actual local file header to get correct data offset.
# We can't use zip_info.FileHeader() because ZIP spec allows local
# headers to differ from central directory entries.
# See: https://pkware.cachefly.net/webdocs/casestudies/APPNOTE.TXT
_ZIP_LOCAL_HEADER_SIZE = 30
_ZIP_LOCAL_HEADER_SIGNATURE = b"PK\x03\x04"

zf.fp.seek(zip_info.header_offset)
local_header = zf.fp.read(_ZIP_LOCAL_HEADER_SIZE)
if len(local_header) < _ZIP_LOCAL_HEADER_SIZE:
raise ValueError(f"Truncated local file header for {elem_filename}")
if local_header[:4] != _ZIP_LOCAL_HEADER_SIGNATURE:
raise ValueError(
f"Invalid local file header signature for {elem_filename}"
)
fname_len, extra_len = struct.unpack("<HH", local_header[26:30])

mem_adress = (
zip_info.header_offset + _ZIP_LOCAL_HEADER_SIZE + fname_len + extra_len
)

dtype_size = np.dtype(ext[1:]).itemsize
size = zip_info.file_size / dtype_size

if len(zip_info.extra):
mem_adress -= len(zip_info.extra)

if size.is_integer():
files_pointer_size[elem_filename] = mem_adress, int(size)
else:
Expand Down