diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java index 144f2b4be..5bd8ff3f0 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java @@ -248,6 +248,12 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape { private static final TensorType TENSOR_2_3_4_FLOAT32 = new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(4))); + private static final TensorType TENSOR_2_2_4_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(2), new NumericDim(4))); + + private static final TensorType TENSOR_2_1_4_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(1), new NumericDim(4))); + private static final TensorType TENSOR_4_4_FLOAT32 = new TensorType(FLOAT_32, asList(new NumericDim(4), new NumericDim(4))); @@ -2620,15 +2626,56 @@ public void testCrfLogNorm() 4, Set.of(TENSOR_4_4_FLOAT32))); } + /** + * Pins the {@code tf.slice} output shape derived from constant {@code begin}/{@code size} + * arguments (wala/ML#569). {@code tf.slice(x, [0, 1], [2, 2])} over a {@code (3, 4)} input yields + * {@code (2, 2)} — all {@code size} entries are non-negative, so the output shape is {@code size} + * exactly, independent of the input shape. + * + * @throws ClassHierarchyException On WALA class-hierarchy error. + * @throws IllegalArgumentException On illegal argument. + * @throws CancelException On analysis cancellation. + * @throws IOException On I/O error reading the test file. 
+ */ + @Test + public void testSlice() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_slice.py", "consume", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32))); + } + + /** + * Pins the {@code tf.slice} output shape for the "all remaining" case (wala/ML#569). A {@code + * size[i]} of {@code -1} resolves to {@code input.shape[i] - begin[i]}: {@code tf.slice(x, [1, + * 0], [-1, 3])} over a {@code (3, 4)} input yields {@code (3 - 1, 3) = (2, 3)}. + * + * @throws ClassHierarchyException On WALA class-hierarchy error. + * @throws IllegalArgumentException On illegal argument. + * @throws CancelException On analysis cancellation. + * @throws IOException On I/O error reading the test file. + */ + @Test + public void testSliceRemaining() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_slice.py", "consume_remaining", 1, 1, Map.of(2, Set.of(TENSOR_2_3_FLOAT32))); + } + /** * Pins {@code crf_forward(inputs, state, transition_params, sequence_lengths)}'s parameter types. * Function body mirrors {@code crf_forward} from {@code kyzhouhzau/NLPGNN/nlpgnn/metrics/crf.py} * for tensor-type inference coverage. {@code crf_forward} is reached only through {@code * crf_log_norm} (its sole NLPGNN caller), which passes {@code inputs} and {@code state} from - * {@code tf.slice}/{@code tf.squeeze} results. Both now infer as {@code float32} with shape - * {@code ⊤}: the dedicated {@code Slice} generator forwards the input dtype while leaving the - * shape unknown (wala/ML#568); computing the precise {@code begin}/{@code size} shape is tracked - * by wala/ML#569. + * {@code tf.slice}/{@code tf.squeeze} results over a {@code (2, 3, 4)} constant. With the {@code + * begin}/{@code size} shape derivation (wala/ML#569), {@code inputs} (from {@code + * tf.slice(inputs, [0, 1, 0], [-1, -1, -1])}) now infers as {@code (2, 2, 4)}. 
{@code state} + * (from {@code tf.squeeze(tf.slice(inputs, [0, 0, 0], [-1, 1, -1]), [1])}) infers as the + * pre-squeeze {@code (2, 1, 4)} rather than the runtime {@code (2, 4)}: the {@code Slice} shape + * flows through, but {@code tf.squeeze}'s axis removal is not yet modeled. Both were previously + * {@code ⊤}-shaped (dtype only, wala/ML#568). + * + *
TODO: Tighten {@code state} to {@code (2, 4)} once {@code tf.squeeze} drops its singleton + * axis instead of passing the shape through; {@code tf.squeeze} is a {@code pass_through} alias + * whose shape semantics are tracked by wala/ML#513 (Bucket 2). * * @throws ClassHierarchyException On WALA class-hierarchy error. * @throws IllegalArgumentException On illegal argument. @@ -2644,8 +2691,8 @@ public void testCrfForward() 4, 14, Map.of( - 2, Set.of(TENSOR_UNKNOWN_SHAPE_FLOAT32), - 3, Set.of(TENSOR_UNKNOWN_SHAPE_FLOAT32), + 2, Set.of(TENSOR_2_2_4_FLOAT32), + 3, Set.of(TENSOR_2_1_4_FLOAT32), 4, Set.of(TENSOR_4_4_FLOAT32), 5, Set.of(TENSOR_2_INT32))); } diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Slice.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Slice.java index 3fe8b035b..3bf338429 100644 --- a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Slice.java +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Slice.java @@ -1,25 +1,44 @@ package com.ibm.wala.cast.python.ml.client; import com.ibm.wala.cast.python.ml.types.TensorType.Dimension; +import com.ibm.wala.cast.python.ml.types.TensorType.DynamicDim; +import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim; import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.propagation.InstanceKey; import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder; +import com.ibm.wala.util.collections.HashSetFactory; +import com.ibm.wala.util.intset.OrdinalSet; +import java.util.ArrayList; import java.util.List; import java.util.Set; +import java.util.logging.Logger; /** * Generator for {@code tf.slice(input_, begin, size, name=None)}. Output dtype is inherited from - * the {@code input_} input. 
Output shape is left at ⊤ for now: the precise shape is determined by - * the {@code begin}/{@code size} extents (where a {@code size[i]} of {@code -1} means "all - * remaining elements along axis {@code i}"), tracked by wala/ML#569. Forwarding only the dtype already - * recovers it from ⊤ — see wala/ML#568. + * the {@code input_} input. Output shape is derived per axis from the constant {@code begin}/{@code + * size} extents and {@code input_.shape}: + * + *
+ * <pre>
+ * output.shape[i] = size[i]                     if size[i] &gt;= 0
+ *                 = input_.shape[i] - begin[i] if size[i] == -1 ("all remaining")
+ * </pre>
+ *
+ * <p>A constant {@code size} with no {@code -1} entries gives a fully concrete shape independent of
+ * {@code input_.shape}; a {@code -1} entry needs the corresponding {@code input_} dim and {@code
+ * begin[i]}, and degrades to a {@link DynamicDim} on that axis (keeping the rank) when either is
+ * non-constant. The shape falls back to ⊤ when {@code input_}'s shape is unknown, when {@code
+ * begin}/{@code size} are themselves non-constant, or when their ranks disagree with {@code
+ * input_}. See wala/ML#569; dtype forwarding alone landed in wala/ML#568.
*
 * @see <a href="https://www.tensorflow.org/api_docs/python/tf/slice">tf.slice</a>
 * @author Raffi Khatchadourian
*/
public class Slice extends PassThroughUnaryTensorGenerator {
+ private static final Logger LOGGER = Logger.getLogger(Slice.class.getName());
+
public Slice(PointsToSetVariable source) {
super(source);
}
@@ -38,8 +57,92 @@ protected String getInputParameterName() {
return "input_";
}
+ /**
+ * Derives the output shape per axis from the constant {@code begin} (arg 1) and {@code size} (arg
+ * 2) extents together with the {@code input_} (arg 0) shape, per the rule documented on the
+ * class.
+ *
+ * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph.
+ * @return The set of possible output shapes, or {@code null} (⊤) when {@code input_}'s shape is
+ * unknown, {@code begin}/{@code size} are non-constant, or their ranks disagree with {@code
+ * input_}.
+ */
@Override
protected Set