Compute tf.slice output shape from constant begin/size #375
The `Slice` generator (wala#568) forwarded the input dtype but left the output shape at ⊤. Derive it per axis from the constant `begin`/`size` extents and `input_.shape`:

    output.shape[i] = size[i]                     if size[i] >= 0
                    = input_.shape[i] - begin[i]  if size[i] == -1

A `size` with no `-1` entries yields a concrete shape independent of the input shape; a `-1` ("all remaining") axis needs the corresponding input dim and `begin[i]`, degrading to a dynamic dimension (rank preserved) when either is non-constant. The shape falls back to ⊤ only when `begin`/`size` are non-constant or their ranks disagree with `input_`.

This tightens the corpus-anchored `crf_forward`: `inputs` (from `tf.slice(inputs, [0, 1, 0], [-1, -1, -1])` over a `(2, 3, 4)` constant) now infers as `(2, 2, 4)`, and `state` as the pre-squeeze `(2, 1, 4)`, both formerly ⊤-shaped. Tightening `state` to the runtime `(2, 4)` awaits `tf.squeeze` axis-removal modeling (wala#513, Bucket 2).

Closes wala#569.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
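The per-axis rule can be sketched in a few lines of plain Python. This is an illustration only, not the code in `Slice.java`: the function name `slice_shape`, the use of `None` for a dynamic dimension, and a `None` return value for ⊤ are all assumptions made for the sketch, as is treating a size below `-1` as ⊤.

```python
from typing import Optional, Sequence, Tuple

# Assumption of this sketch: None models a dynamic (unknown) dimension,
# and a None return value models the ⊤ (unknown) shape.
Dim = Optional[int]


def slice_shape(
    input_shape: Sequence[Dim],
    begin: Optional[Sequence[Dim]],
    size: Optional[Sequence[int]],
) -> Optional[Tuple[Dim, ...]]:
    """Derive the tf.slice output shape per axis, or None (⊤) if impossible."""
    # Non-constant begin/size: nothing can be said about the shape.
    if begin is None or size is None:
        return None
    # Ranks of begin/size must agree with the rank of the input.
    if not (len(begin) == len(size) == len(input_shape)):
        return None

    out = []
    for in_dim, b, s in zip(input_shape, begin, size):
        if s >= 0:
            # A fixed extent does not depend on the input dimension.
            out.append(s)
        elif s == -1:
            # "All remaining": needs both the input dim and begin[i].
            if in_dim is not None and b is not None:
                out.append(in_dim - b)
            else:
                out.append(None)  # keep the rank; this axis is dynamic
        else:
            # Sketch-only choice: sizes below -1 are treated as ⊤.
            return None
    return tuple(out)
```

For example, `slice_shape((2, 3, 4), [0, 1, 0], [-1, -1, -1])` evaluates to `(2, 2, 4)`, and an unknown leading input dimension with `size[0] == -1` yields `None` on that axis while the remaining axes stay concrete.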
Pull request overview
This PR improves TensorFlow tf.slice shape inference in the Python ML tensor modeling layer by deriving the output shape from constant begin/size arguments (including the size = -1 “all remaining” case), instead of leaving the result shape at ⊤. It also adds/updates regression tests to pin the new inference behavior and tighten an existing CRF fixture’s inferred shapes.
Changes:
- Implement per-axis `tf.slice` output-shape derivation in `Slice` when `begin`/`size` are statically resolvable.
- Add a new TensorFlow fixture (`tf2_test_slice.py`) and new JUnit tests to cover both `size >= 0` and `size == -1` cases.
- Update `testCrfForward` expectations to reflect the newly inferred slice-derived shapes.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| com.ibm.wala.cast.python.test/data/tf2_test_slice.py | New TF2 fixture exercising tf.slice for fixed-size and “all remaining” slicing. |
| com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Slice.java | Adds constant-begin/size shape computation logic for tf.slice. |
| com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java | Adds slice tests and tightens CRF-forward expectations based on improved slice shape inference. |

    try {
      Set<List<Dimension<?>>> lists = this.getShapesFromShapeArgument(builder, pts);
      return (lists == null || lists.isEmpty()) ? null : lists;
    } catch (RuntimeException e) {
      LOGGER.fine(() -> "Could not resolve " + name + " of " + this.getSource() + ": " + e + ".");
      return null;
    }

    Set<List<Dimension<?>>> ret = HashSetFactory.make();
    for (List<Dimension<?>> input : inputShapes)
      for (List<Dimension<?>> begin : beginLists)
        for (List<Dimension<?>> size : sizeLists) {
          List<Dimension<?>> out = sliceShape(input, begin, size);
          if (out != null) ret.add(out);
        }
    return ret.isEmpty() ? null : ret;
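
The nested loops in this excerpt take every combination of candidate input shape, `begin`, and `size`, and keep the combinations for which a shape can be derived. A rough Python analogue is sketched below; `combine_shapes` and `toy_rule` are hypothetical names introduced for illustration, and `toy_rule` assumes fully constant integer inputs rather than reproducing the real `sliceShape` logic.

```python
from itertools import product


def combine_shapes(input_shapes, begin_lists, size_lists, rule):
    """Apply `rule` to every (input, begin, size) combination.

    Returns the set of derived shapes, or None when no combination
    produced one (mirroring the `ret.isEmpty() ? null : ret` idiom).
    """
    results = set()
    for inp, begin, size in product(input_shapes, begin_lists, size_lists):
        out = rule(inp, begin, size)
        if out is not None:
            results.add(out)
    return results or None


def toy_rule(inp, begin, size):
    """Hypothetical stand-in for the per-axis rule; assumes constant ints."""
    if not (len(inp) == len(begin) == len(size)):
        return None  # rank mismatch: no shape for this combination
    return tuple(s if s >= 0 else d - b for d, b, s in zip(inp, begin, size))


# Two possible input shapes, one begin, one size: two candidate outputs.
shapes = combine_shapes({(2, 3), (4, 3)}, {(0, 1)}, {(-1, -1)}, toy_rule)
```

Collecting results in a set means that combinations leading to the same output shape collapse into one entry, which matches the `Set<List<Dimension<?>>>` return type in the excerpt.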

    if (inDim instanceof NumericDim && beginDim instanceof NumericDim)
      out.add(new NumericDim(((NumericDim) inDim).value() - ((NumericDim) beginDim).value()));
    else out.add(DynamicDim.INSTANCE); // keep the rank; this axis is dynamic.

Codecov Report
❌ Patch coverage is

Additional details and impacted files

@@            Coverage Diff            @@
## master #375 +/- ##
=========================================
Coverage 71.67% 71.68%
- Complexity 2731 2744 +13
=========================================
Files 272 272
Lines 20347 20393 +46
Branches 3283 3298 +15
=========================================
+ Hits 14583 14618 +35
- Misses 4470 4475 +5
- Partials     1294     1300       +6

Summary
The `Slice` generator (wala#568) forwarded the input dtype but left the output shape at ⊤. This derives the output shape per axis from the constant `begin`/`size` extents and `input_.shape`:

    output.shape[i] = size[i]                     if size[i] >= 0
                    = input_.shape[i] - begin[i]  if size[i] == -1

A `size` with no `-1` entries yields a concrete shape independent of the input shape; a `-1` axis needs the corresponding `input_` dim and `begin[i]`, degrading to a dynamic dimension (rank preserved) when either is non-constant. The shape falls back to ⊤ only when `begin`/`size` are non-constant or their ranks disagree with `input_`.

Corpus impact
This tightens the corpus-anchored `crf_forward` (NLPGNN), the motivator for wala#569:
- `inputs` (from `tf.slice(inputs, [0, 1, 0], [-1, -1, -1])` over a `(2, 3, 4)` constant) now infers as `(2, 2, 4)`.
- `state` (from `tf.squeeze(tf.slice(inputs, [0, 0, 0], [-1, 1, -1]), [1])`) now infers as the pre-squeeze `(2, 1, 4)`.

Both were formerly ⊤-shaped (dtype only). Tightening `state` to the runtime `(2, 4)` awaits `tf.squeeze` axis-removal modeling, tracked by wala#513 (Bucket 2) and noted as a TODO on `testCrfForward`.

Tests
- `testSlice` (`size ≥ 0` path → `(2, 2)`) and `testSliceRemaining` (`size = -1` path → `(2, 3)`), with runtime shapes ground-truthed via `python3.10`.
- `testCrfForward` updated from ⊤ to `(2, 2, 4)` / `(2, 1, 4)`.

Closes wala#569.
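
The `crf_forward` shapes quoted in this description can be reproduced by hand without TensorFlow. The sketch below emulates `tf.slice` on nested Python lists; `slice_nested` and `shape_of` are hypothetical helpers written for this check, not part of the PR or of TensorFlow, and the squeeze step is a manual stand-in for `tf.squeeze(..., [1])`.

```python
def shape_of(nested):
    """Shape of a rectangular nested list, read along the first elements."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        if not nested:
            break
        nested = nested[0]
    return tuple(shape)


def slice_nested(nested, begin, size):
    """Emulate tf.slice on nested lists; size -1 means 'all remaining'."""
    if not begin:
        return nested
    b, s = begin[0], size[0]
    end = len(nested) if s == -1 else b + s
    return [slice_nested(item, begin[1:], size[1:]) for item in nested[b:end]]


# A (2, 3, 4) constant, standing in for the corpus input.
inputs = [[[0.0] * 4 for _ in range(3)] for _ in range(2)]

sliced_inputs = slice_nested(inputs, [0, 1, 0], [-1, -1, -1])
state = slice_nested(inputs, [0, 0, 0], [-1, 1, -1])
# Manual removal of the size-1 axis 1, standing in for tf.squeeze(state, [1]).
squeezed_state = [row[0] for row in state]

print(shape_of(sliced_inputs), shape_of(state), shape_of(squeezed_state))
```

The first two shapes are the ones the analysis now infers; the third is the runtime shape of `state` that remains out of reach until squeeze axis removal is modeled.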
🤖 Generated with Claude Code