
Feature request: lingvo.jax.asserts.HasShape #332

@drpngx


I tried

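```python
import jax.numpy as jnp

def AssertShape(x: jnp.ndarray, shape) -> None:
  # Inside jax.jit, array_equal yields a traced bool, which `if` cannot convert.
  if not jnp.array_equal(x.shape, shape):
    raise ValueError(f'Shape mismatch: found {x.shape}, expected: {shape}')
```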

and got

jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..

(BTW, note the double period at the end of the error message.)
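A possible workaround, sketched under the assumption that only static shapes need checking: under `jax.jit`, `x.shape` is an ordinary Python tuple of ints even when `x` is a tracer, so comparing tuples directly never produces a traced boolean. Every `jnp` call inside a jitted function is staged into the trace, which is why `jnp.array_equal` on the shapes fails. The helper name `assert_has_shape` and the `None`-as-wildcard convention are hypothetical, not an existing lingvo API.

```python
from typing import Optional, Sequence

import jax
import jax.numpy as jnp


def assert_has_shape(x: jnp.ndarray, shape: Sequence[Optional[int]]) -> None:
  """Raises ValueError if x's static shape does not match `shape`.

  Hypothetical helper (not part of lingvo). A `None` entry in `shape`
  matches any size along that axis.
  """
  # x.shape is a plain tuple of Python ints, even for tracers under jax.jit,
  # so these comparisons run in Python and never touch traced values.
  actual = tuple(x.shape)
  expected = tuple(shape)
  if len(actual) != len(expected) or any(
      e is not None and a != e for a, e in zip(actual, expected)):
    raise ValueError(f'Shape mismatch: found {actual}, expected {expected}')


@jax.jit
def double(x):
  # Checked at trace time; a mismatch raises ValueError before compilation.
  assert_has_shape(x, (2, None))
  return 2 * x


print(double(jnp.ones((2, 3))).shape)  # (2, 3)
```

Because the check runs in Python at trace time, it adds no operations to the compiled computation. Checking runtime *values* (rather than shapes) inside `jit` is a different problem that this sketch does not address.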
