diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java index 144f2b4be..5bd8ff3f0 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java @@ -248,6 +248,12 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape { private static final TensorType TENSOR_2_3_4_FLOAT32 = new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(4))); + private static final TensorType TENSOR_2_2_4_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(2), new NumericDim(4))); + + private static final TensorType TENSOR_2_1_4_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(1), new NumericDim(4))); + private static final TensorType TENSOR_4_4_FLOAT32 = new TensorType(FLOAT_32, asList(new NumericDim(4), new NumericDim(4))); @@ -2620,15 +2626,56 @@ public void testCrfLogNorm() 4, Set.of(TENSOR_4_4_FLOAT32))); } + /** + * Pins the {@code tf.slice} output shape derived from constant {@code begin}/{@code size} + * arguments (wala/ML#569). {@code tf.slice(x, [0, 1], [2, 2])} over a {@code (3, 4)} input yields + * {@code (2, 2)} — all {@code size} entries are non-negative, so the output shape is {@code size} + * exactly, independent of the input shape. + * + * @throws ClassHierarchyException On WALA class-hierarchy error. + * @throws IllegalArgumentException On illegal argument. + * @throws CancelException On analysis cancellation. + * @throws IOException On I/O error reading the test file. + */ + @Test + public void testSlice() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_slice.py", "consume", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32))); + } + + /** + * Pins the {@code tf.slice} output shape for the "all remaining" case (wala/ML#569). A {@code + * size[i]} of {@code -1} resolves to {@code input.shape[i] - begin[i]}: {@code tf.slice(x, [1, + * 0], [-1, 3])} over a {@code (3, 4)} input yields {@code (3 - 1, 3) = (2, 3)}. + * + * @throws ClassHierarchyException On WALA class-hierarchy error. + * @throws IllegalArgumentException On illegal argument. + * @throws CancelException On analysis cancellation. + * @throws IOException On I/O error reading the test file. + */ + @Test + public void testSliceRemaining() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_slice.py", "consume_remaining", 1, 1, Map.of(2, Set.of(TENSOR_2_3_FLOAT32))); + } + /** * Pins {@code crf_forward(inputs, state, transition_params, sequence_lengths)}'s parameter types. * Function body mirrors {@code crf_forward} from {@code kyzhouhzau/NLPGNN/nlpgnn/metrics/crf.py} * for tensor-type inference coverage. {@code crf_forward} is reached only through {@code * crf_log_norm} (its sole NLPGNN caller), which passes {@code inputs} and {@code state} from - * {@code tf.slice}/{@code tf.squeeze} results. Both now infer as {@code float32} with shape - * {@code ⊤}: the dedicated {@code Slice} generator forwards the input dtype while leaving the - * shape unknown (wala/ML#568); computing the precise {@code begin}/{@code size} shape is tracked - * by wala/ML#569. + * {@code tf.slice}/{@code tf.squeeze} results over a {@code (2, 3, 4)} constant. With the {@code + * begin}/{@code size} shape derivation (wala/ML#569), {@code inputs} (from {@code + * tf.slice(inputs, [0, 1, 0], [-1, -1, -1])}) now infers as {@code (2, 2, 4)}. {@code state} + * (from {@code tf.squeeze(tf.slice(inputs, [0, 0, 0], [-1, 1, -1]), [1])}) infers as the + * pre-squeeze {@code (2, 1, 4)} rather than the runtime {@code (2, 4)}: the {@code Slice} shape + * flows through, but {@code tf.squeeze}'s axis removal is not yet modeled. Both were previously + * {@code ⊤}-shaped (dtype only, wala/ML#568). + * + *

TODO: Tighten {@code state} to {@code (2, 4)} once {@code tf.squeeze} drops its singleton + * axis instead of passing the shape through; {@code tf.squeeze} is a {@code pass_through} alias + * whose shape semantics are tracked by wala/ML#513 (Bucket 2). * * @throws ClassHierarchyException On WALA class-hierarchy error. * @throws IllegalArgumentException On illegal argument. @@ -2644,8 +2691,8 @@ public void testCrfForward() 4, 14, Map.of( - 2, Set.of(TENSOR_UNKNOWN_SHAPE_FLOAT32), - 3, Set.of(TENSOR_UNKNOWN_SHAPE_FLOAT32), + 2, Set.of(TENSOR_2_2_4_FLOAT32), + 3, Set.of(TENSOR_2_1_4_FLOAT32), 4, Set.of(TENSOR_4_4_FLOAT32), 5, Set.of(TENSOR_2_INT32))); } diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Slice.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Slice.java index 3fe8b035b..3bf338429 100644 --- a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Slice.java +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Slice.java @@ -1,25 +1,44 @@ package com.ibm.wala.cast.python.ml.client; import com.ibm.wala.cast.python.ml.types.TensorType.Dimension; +import com.ibm.wala.cast.python.ml.types.TensorType.DynamicDim; +import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim; import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.propagation.InstanceKey; import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder; +import com.ibm.wala.util.collections.HashSetFactory; +import com.ibm.wala.util.intset.OrdinalSet; +import java.util.ArrayList; import java.util.List; import java.util.Set; +import java.util.logging.Logger; /** * Generator for {@code tf.slice(input_, begin, size, name=None)}. Output dtype is inherited from - * the {@code input_} input. Output shape is left at ⊤ for now: the precise shape is determined by - * the {@code begin}/{@code size} extents (where a {@code size[i]} of {@code -1} means "all - * remaining elements along axis {@code i}"), tracked by wala/ML#569. Forwarding only the dtype already - * recovers it from ⊤ — see wala/ML#568. + * the {@code input_} input. Output shape is derived per axis from the constant {@code begin}/{@code + * size} extents and {@code input_.shape}: + * + *

+ * output.shape[i] = size[i]                    if size[i] >= 0
+ *                 = input_.shape[i] - begin[i] if size[i] == -1  ("all remaining")
+ * 
+ * + * A constant {@code size} with no {@code -1} entries gives a fully concrete shape independent of + * {@code input_.shape}; a {@code -1} entry needs the corresponding {@code input_} dim and {@code + * begin[i]}, and degrades to a {@link DynamicDim} on that axis (keeping the rank) when either is + * non-constant. The shape falls back to ⊤ only when {@code begin}/{@code size} are themselves + * non-constant or their ranks disagree with {@code input_}. See wala/ML#569; dtype forwarding alone landed in wala/ML#568. * * @see tf.slice * @author Raffi Khatchadourian */ public class Slice extends PassThroughUnaryTensorGenerator { + private static final Logger LOGGER = Logger.getLogger(Slice.class.getName()); + public Slice(PointsToSetVariable source) { super(source); } @@ -38,8 +57,92 @@ protected String getInputParameterName() { return "input_"; } + /** + * Derives the output shape per axis from the constant {@code begin} (arg 1) and {@code size} (arg + * 2) extents together with the {@code input_} (arg 0) shape, per the rule documented on the + * class. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. + * @return The set of possible output shapes, or {@code null} (⊤) when {@code input_}'s shape is + * unknown, {@code begin}/{@code size} are non-constant, or their ranks disagree with {@code + * input_}. + */ @Override protected Set>> getDefaultShapes(PropagationCallGraphBuilder builder) { - return null; + // input_ is arg 0 (resolved by the passthrough base); begin is arg 1; size is arg 2. + Set>> inputShapes = super.getDefaultShapes(builder); + if (inputShapes == null) return null; + + Set>> beginLists = resolveConstantIntList(builder, 1, "begin"); + Set>> sizeLists = resolveConstantIntList(builder, 2, "size"); + if (beginLists == null || sizeLists == null) return null; + + Set>> ret = HashSetFactory.make(); + for (List> input : inputShapes) + for (List> begin : beginLists) + for (List> size : sizeLists) { + List> out = sliceShape(input, begin, size); + if (out != null) ret.add(out); + } + return ret.isEmpty() ? null : ret; + } + + /** + * Resolves a {@code begin}/{@code size} argument (a constant list/tuple or {@code tf.constant} of + * ints) into its dimension lists via {@link #getShapesFromShapeArgument}. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. + * @param position The 0-based positional index of the argument (excluding {@code self}). + * @param name The keyword name of the argument. + * @return The resolved int-list candidates, or {@code null} when the argument's points-to set is + * empty or it cannot be resolved to a constant list. + */ + private Set>> resolveConstantIntList( + PropagationCallGraphBuilder builder, int position, String name) { + OrdinalSet pts = this.getArgumentPointsToSet(builder, position, name); + if (pts == null || pts.isEmpty()) return null; + try { + Set>> lists = this.getShapesFromShapeArgument(builder, pts); + return (lists == null || lists.isEmpty()) ? null : lists; + } catch (RuntimeException e) { + LOGGER.fine(() -> "Could not resolve " + name + " of " + this.getSource() + ": " + e + "."); + return null; + } + } + + /** + * Applies the {@code tf.slice} per-axis shape rule to a single {@code (input, begin, size)} + * combination. + * + * @param input The {@code input_} shape. + * @param begin The constant {@code begin} offsets. + * @param size The constant {@code size} extents. + * @return The output shape, or {@code null} when the ranks disagree, a {@code size} entry is + * non-constant, or a {@code size} entry is invalid ({@code < -1}). + */ + private static List> sliceShape( + List> input, List> begin, List> size) { + int rank = input.size(); + if (begin.size() != rank || size.size() != rank) return null; + + List> out = new ArrayList<>(rank); + for (int i = 0; i < rank; i++) { + Dimension sizeDim = size.get(i); + if (!(sizeDim instanceof NumericDim)) return null; // non-constant size extent. + int s = ((NumericDim) sizeDim).value(); + if (s >= 0) { + out.add(new NumericDim(s)); + } else if (s == -1) { + // "all remaining" along axis i: input_.shape[i] - begin[i], when both are constant. + Dimension inDim = input.get(i); + Dimension beginDim = begin.get(i); + if (inDim instanceof NumericDim && beginDim instanceof NumericDim) + out.add(new NumericDim(((NumericDim) inDim).value() - ((NumericDim) beginDim).value())); + else out.add(DynamicDim.INSTANCE); // keep the rank; this axis is dynamic. + } else { + return null; // size < -1 is invalid for tf.slice. + } + } + return out; } } diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_slice.py b/com.ibm.wala.cast.python.test/data/tf2_test_slice.py new file mode 100644 index 000000000..83ba32901 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_slice.py @@ -0,0 +1,28 @@ +import tensorflow as tf + + +def consume(y): + pass + + +def consume_remaining(z): + pass + + +x = tf.constant( + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]], + dtype=tf.float32, +) +assert x.shape == (3, 4) and x.dtype == tf.float32 + +# All `size` entries non-negative: the output shape is `size` exactly, independent +# of the input shape. `begin=[0, 1]`, `size=[2, 2]` over `(3, 4)` -> `(2, 2)`. +y = tf.slice(x, [0, 1], [2, 2]) +assert y.shape == (2, 2) and y.dtype == tf.float32 +consume(y) + +# A `size[i]` of `-1` means "all remaining" along axis `i`: `input.shape[i] - +# begin[i]`. `begin=[1, 0]`, `size=[-1, 3]` over `(3, 4)` -> `(3 - 1, 3) = (2, 3)`. +z = tf.slice(x, [1, 0], [-1, 3]) +assert z.shape == (2, 3) and z.dtype == tf.float32 +consume_remaining(z)