Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cubed/icechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def store_icechunk(

arrays = []
for source, target, region in zip(sources, targets, regions_list):
sharding_enabled = target.shards is not None
sharding_misaligned = target.shards != source.chunks
if sharding_enabled and sharding_misaligned:
source = source.rechunk(target.shards)
array = _store_array(
source,
target,
Expand Down
39 changes: 37 additions & 2 deletions cubed/tests/test_icechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def icechunk_storage(tmpdir) -> "Storage":
return Storage.new_local_filesystem(str(tmpdir))


def create_icechunk(a, icechunk_storage, /, *, dtype=None, chunks=None):
def create_icechunk(a, icechunk_storage, /, *, dtype=None, chunks=None, shards=None):
# from dask.asarray
if not isinstance(getattr(a, "shape", None), Iterable):
# ensure blocks are arrays
Expand All @@ -44,7 +44,9 @@ def create_icechunk(a, icechunk_storage, /, *, dtype=None, chunks=None):
store = session.store

group = zarr.group(store=store, overwrite=True)
arr = group.create_array("a", shape=a.shape, dtype=dtype, chunks=chunks)
arr = group.create_array(
"a", shape=a.shape, dtype=dtype, chunks=chunks, shards=shards
)

arr[...] = a

Expand Down Expand Up @@ -136,3 +138,36 @@ def test_store_icechunk_region(icechunk_storage, executor):
]
),
)


def test_store_icechunk_sharded(icechunk_storage, executor):
a = xp.asarray(
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
chunks=(2, 2),
)
create_icechunk(
np.zeros((4, 4), dtype=int), icechunk_storage, chunks=(2, 2), shards=(4, 4)
)

# note that the same zarr store is overwritten in the following tests

repo = Repository.open(storage=icechunk_storage)
session = repo.writable_session("main")
fork = session.fork()
store = fork.store
group = zarr.open_group(store=store)
target = group.get("a")
merged_session = store_icechunk(sources=a, targets=target, executor=executor)
session.merge(merged_session)
session.commit("commit 1")

# reopen store and check contents of array
repo = Repository.open(icechunk_storage)
session = repo.readonly_session(branch="main")
store = session.store

group = zarr.open_group(store=store, mode="r")
assert_array_equal(
cubed.from_array(group["a"])[:],
a,
)
Loading