Sharding-incompatible array creation in scico/linop/xray/_xray.py #643

@bwohlberg

Description

Module scico/linop/xray/_xray.py contains several calls to jax.numpy.zeros and jax.numpy.ones that create arrays without giving the caller any way to specify a device or a sharding.

See

scico/linop/xray/_xray.py:235:       jnp.zeros((len(angles), ny), dtype=im.dtype)
scico/linop/xray/_xray.py:396:   proj = jnp.zeros((num_views,) + det_shape, dtype=im.dtype)
scico/linop/xray/_xray.py:441:   HTy = jnp.zeros(input_shape, dtype=proj.dtype)
scico/linop/xray/_xray.py:168:       unit_sino = jnp.ones(self.output_shape, dtype=np.float32)

The branch for this fix should be created from branch brendt/shard.
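
One possible direction, sketched here only as an illustration: route the array creation through a small helper that takes an optional sharding and places the new array with `jax.device_put`. The helper name `zeros_with_sharding`, the mesh axis name `"views"`, and the example shapes are assumptions made for this sketch; none of them exist in scico or in the brendt/shard branch.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def zeros_with_sharding(shape, dtype, sharding=None):
    """Hypothetical helper (not part of scico): create a zero array and,
    if a sharding is given, place the array according to it."""
    x = jnp.zeros(shape, dtype=dtype)
    if sharding is not None:
        # jax.device_put accepts a Sharding as its placement argument.
        x = jax.device_put(x, sharding)
    return x


# Build a 1-D mesh over all available devices and shard along the
# leading (view) axis. The axis name "views" is an assumption.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("views",))
sharding = NamedSharding(mesh, PartitionSpec("views"))

# The leading dimension must be divisible by the number of devices.
num_views = 4 * devices.size
det_shape = (8,)
proj = zeros_with_sharding((num_views,) + det_shape, jnp.float32, sharding)
```

This sketch allocates the array on the default device and then moves it, so it costs one extra transfer. A real fix would more likely thread the sharding through the operator's constructor and allocate in place.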

Metadata

Labels

improvement: Improvement of existing code, including addressing of omissions or inconsistencies
