Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,12 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape {
/** Expected tensor type with element type {@code FLOAT_32} and static shape {@code (2, 3, 4)}. */
private static final TensorType TENSOR_2_3_4_FLOAT32 =
    new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(4)));

/** Expected tensor type with element type {@code FLOAT_32} and static shape {@code (2, 2, 4)}. */
private static final TensorType TENSOR_2_2_4_FLOAT32 =
    new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(2), new NumericDim(4)));

/** Expected tensor type with element type {@code FLOAT_32} and static shape {@code (2, 1, 4)}. */
private static final TensorType TENSOR_2_1_4_FLOAT32 =
    new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(1), new NumericDim(4)));

/** Expected tensor type with element type {@code FLOAT_32} and static shape {@code (4, 4)}. */
private static final TensorType TENSOR_4_4_FLOAT32 =
    new TensorType(FLOAT_32, asList(new NumericDim(4), new NumericDim(4)));

Expand Down Expand Up @@ -2620,15 +2626,56 @@ public void testCrfLogNorm()
4, Set.of(TENSOR_4_4_FLOAT32)));
}

/**
 * Pins the {@code tf.slice} output shape derived from constant {@code begin}/{@code size}
 * arguments (wala/ML#569). {@code tf.slice(x, [0, 1], [2, 2])} over a {@code (3, 4)} input yields
 * {@code (2, 2)} — all {@code size} entries are non-negative, so the output shape is {@code size}
 * exactly, independent of the input shape.
 *
 * <p>The analyzed fixture is {@code tf2_test_slice.py}; the function under inspection is {@code
 * consume}, which is called with the sliced tensor.
 *
 * @throws ClassHierarchyException On WALA class-hierarchy error.
 * @throws IllegalArgumentException On illegal argument.
 * @throws CancelException On analysis cancellation.
 * @throws IOException On I/O error reading the test file.
 */
@Test
public void testSlice()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  // Expect the tensor reaching consume(y) to be typed TENSOR_2_2_FLOAT32.
  // NOTE(review): the two integer arguments and the map key 2 follow the conventions of the
  // test(...) helper, which is defined outside this excerpt — confirm their meaning there.
  test("tf2_test_slice.py", "consume", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32)));
}

/**
 * Pins the {@code tf.slice} output shape for the "all remaining" case (wala/ML#569). A {@code
 * size[i]} of {@code -1} resolves to {@code input.shape[i] - begin[i]}: {@code tf.slice(x, [1,
 * 0], [-1, 3])} over a {@code (3, 4)} input yields {@code (3 - 1, 3) = (2, 3)}.
 *
 * <p>The analyzed fixture is {@code tf2_test_slice.py}; the function under inspection is {@code
 * consume_remaining}, which is called with the sliced tensor.
 *
 * @throws ClassHierarchyException On WALA class-hierarchy error.
 * @throws IllegalArgumentException On illegal argument.
 * @throws CancelException On analysis cancellation.
 * @throws IOException On I/O error reading the test file.
 */
@Test
public void testSliceRemaining()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  // Expect the tensor reaching consume_remaining(z) to be typed TENSOR_2_3_FLOAT32.
  // NOTE(review): the two integer arguments and the map key 2 follow the conventions of the
  // test(...) helper, which is defined outside this excerpt — confirm their meaning there.
  test("tf2_test_slice.py", "consume_remaining", 1, 1, Map.of(2, Set.of(TENSOR_2_3_FLOAT32)));
}

/**
* Pins {@code crf_forward(inputs, state, transition_params, sequence_lengths)}'s parameter types.
* Function body mirrors {@code crf_forward} from {@code kyzhouhzau/NLPGNN/nlpgnn/metrics/crf.py}
* for tensor-type inference coverage. {@code crf_forward} is reached only through {@code
* crf_log_norm} (its sole NLPGNN caller), which passes {@code inputs} and {@code state} from
* {@code tf.slice}/{@code tf.squeeze} results over a {@code (2, 3, 4)} constant. With the {@code
* begin}/{@code size} shape derivation (wala/ML#569), {@code inputs} (from {@code
* tf.slice(inputs, [0, 1, 0], [-1, -1, -1])}) now infers as {@code (2, 2, 4)}. {@code state}
* (from {@code tf.squeeze(tf.slice(inputs, [0, 0, 0], [-1, 1, -1]), [1])}) infers as the
* pre-squeeze {@code (2, 1, 4)} rather than the runtime {@code (2, 4)}: the {@code Slice} shape
* flows through, but {@code tf.squeeze}'s axis removal is not yet modeled. Both were previously
* {@code ⊤}-shaped (dtype only, wala/ML#568).
*
* <p>TODO: Tighten {@code state} to {@code (2, 4)} once {@code tf.squeeze} drops its singleton
* axis instead of passing the shape through; {@code tf.squeeze} is a {@code pass_through} alias
* whose shape semantics are tracked by <a
* href="https://github.com/wala/ML/issues/513">wala/ML#513</a> (Bucket 2).
*
* @throws ClassHierarchyException On WALA class-hierarchy error.
* @throws IllegalArgumentException On illegal argument.
Expand All @@ -2644,8 +2691,8 @@ public void testCrfForward()
4,
14,
Map.of(
2, Set.of(TENSOR_UNKNOWN_SHAPE_FLOAT32),
3, Set.of(TENSOR_UNKNOWN_SHAPE_FLOAT32),
2, Set.of(TENSOR_2_2_4_FLOAT32),
3, Set.of(TENSOR_2_1_4_FLOAT32),
4, Set.of(TENSOR_4_4_FLOAT32),
5, Set.of(TENSOR_2_INT32)));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,44 @@
package com.ibm.wala.cast.python.ml.client;

import com.ibm.wala.cast.python.ml.types.TensorType.Dimension;
import com.ibm.wala.cast.python.ml.types.TensorType.DynamicDim;
import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim;
import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
import com.ibm.wala.util.collections.HashSetFactory;
import com.ibm.wala.util.intset.OrdinalSet;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.logging.Logger;

/**
* Generator for {@code tf.slice(input_, begin, size, name=None)}. Output dtype is inherited from
* the {@code input_} input. Output shape is derived per axis from the constant {@code begin}/{@code
* size} extents and {@code input_.shape}:
*
* <pre>
* output.shape[i] = size[i] if size[i] &gt;= 0
* = input_.shape[i] - begin[i] if size[i] == -1 ("all remaining")
* </pre>
*
* A constant {@code size} with no {@code -1} entries gives a fully concrete shape independent of
* {@code input_.shape}; a {@code -1} entry needs the corresponding {@code input_} dim and {@code
* begin[i]}, and degrades to a {@link DynamicDim} on that axis (keeping the rank) when either is
* non-constant. The shape falls back to ⊤ when {@code input_}'s shape is unknown, when {@code
* begin}/{@code size} are themselves non-constant, when their ranks disagree with {@code input_},
* or when a {@code size} entry is below {@code -1}. See <a
* href="https://github.com/wala/ML/issues/569">wala/ML#569</a>; dtype forwarding alone landed in <a
* href="https://github.com/wala/ML/issues/568">wala/ML#568</a>.
*
* @see <a href="https://www.tensorflow.org/api_docs/python/tf/slice">tf.slice</a>
* @author <a href="mailto:khatchad@hunter.cuny.edu">Raffi Khatchadourian</a>
*/
public class Slice extends PassThroughUnaryTensorGenerator {

/** Logger used to report {@code begin}/{@code size} arguments that could not be resolved. */
private static final Logger LOGGER = Logger.getLogger(Slice.class.getName());

/**
 * Creates a {@code tf.slice} generator.
 *
 * @param source The points-to set variable handed unchanged to the superclass constructor;
 *     presumably the variable holding the result of the {@code tf.slice} call — confirm against
 *     the base class.
 */
public Slice(PointsToSetVariable source) {
  super(source);
}
Expand All @@ -38,8 +57,92 @@
return "input_";
}

/**
* Derives the output shape per axis from the constant {@code begin} (arg 1) and {@code size} (arg
* 2) extents together with the {@code input_} (arg 0) shape, per the rule documented on the
* class.
*
* @param builder The {@link PropagationCallGraphBuilder} used to build the call graph.
* @return The set of possible output shapes, or {@code null} (⊤) when {@code input_}'s shape is
* unknown, {@code begin}/{@code size} are non-constant, or their ranks disagree with {@code
* input_}.
*/
@Override
protected Set<List<Dimension<?>>> getDefaultShapes(PropagationCallGraphBuilder builder) {
return null;
// input_ is arg 0 (resolved by the passthrough base); begin is arg 1; size is arg 2.
Set<List<Dimension<?>>> inputShapes = super.getDefaultShapes(builder);
if (inputShapes == null) return null;

Set<List<Dimension<?>>> beginLists = resolveConstantIntList(builder, 1, "begin");
Set<List<Dimension<?>>> sizeLists = resolveConstantIntList(builder, 2, "size");
if (beginLists == null || sizeLists == null) return null;

Set<List<Dimension<?>>> ret = HashSetFactory.make();
for (List<Dimension<?>> input : inputShapes)
for (List<Dimension<?>> begin : beginLists)
for (List<Dimension<?>> size : sizeLists) {
List<Dimension<?>> out = sliceShape(input, begin, size);
if (out != null) ret.add(out);
}
return ret.isEmpty() ? null : ret;
Comment on lines +80 to +87
}

/**
* Resolves a {@code begin}/{@code size} argument (a constant list/tuple or {@code tf.constant} of
* ints) into its dimension lists via {@link #getShapesFromShapeArgument}.
*
* @param builder The {@link PropagationCallGraphBuilder} used to build the call graph.
* @param position The 0-based positional index of the argument (excluding {@code self}).
* @param name The keyword name of the argument.
* @return The resolved int-list candidates, or {@code null} when the argument's points-to set is
* empty or it cannot be resolved to a constant list.
*/
private Set<List<Dimension<?>>> resolveConstantIntList(
PropagationCallGraphBuilder builder, int position, String name) {
OrdinalSet<InstanceKey> pts = this.getArgumentPointsToSet(builder, position, name);
if (pts == null || pts.isEmpty()) return null;
try {
Set<List<Dimension<?>>> lists = this.getShapesFromShapeArgument(builder, pts);
return (lists == null || lists.isEmpty()) ? null : lists;
} catch (RuntimeException e) {
LOGGER.fine(() -> "Could not resolve " + name + " of " + this.getSource() + ": " + e + ".");
return null;

Check warning on line 109 in com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Slice.java

View check run for this annotation

Codecov / codecov/patch

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Slice.java#L107-L109

Added lines #L107 - L109 were not covered by tests
}
Comment on lines +104 to +110
}

/**
* Applies the {@code tf.slice} per-axis shape rule to a single {@code (input, begin, size)}
* combination.
*
* @param input The {@code input_} shape.
* @param begin The constant {@code begin} offsets.
* @param size The constant {@code size} extents.
* @return The output shape, or {@code null} when the ranks disagree, a {@code size} entry is
* non-constant, or a {@code size} entry is invalid ({@code < -1}).
*/
private static List<Dimension<?>> sliceShape(
List<Dimension<?>> input, List<Dimension<?>> begin, List<Dimension<?>> size) {
int rank = input.size();
if (begin.size() != rank || size.size() != rank) return null;

List<Dimension<?>> out = new ArrayList<>(rank);
for (int i = 0; i < rank; i++) {
Dimension<?> sizeDim = size.get(i);
if (!(sizeDim instanceof NumericDim)) return null; // non-constant size extent.
int s = ((NumericDim) sizeDim).value();
if (s >= 0) {
out.add(new NumericDim(s));
} else if (s == -1) {
// "all remaining" along axis i: input_.shape[i] - begin[i], when both are constant.
Dimension<?> inDim = input.get(i);
Dimension<?> beginDim = begin.get(i);
if (inDim instanceof NumericDim && beginDim instanceof NumericDim)
out.add(new NumericDim(((NumericDim) inDim).value() - ((NumericDim) beginDim).value()));
else out.add(DynamicDim.INSTANCE); // keep the rank; this axis is dynamic.

Check warning on line 141 in com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Slice.java

View check run for this annotation

Codecov / codecov/patch

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Slice.java#L141

Added line #L141 was not covered by tests
Comment on lines +139 to +141
} else {
return null; // size < -1 is invalid for tf.slice.

Check warning on line 143 in com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Slice.java

View check run for this annotation

Codecov / codecov/patch

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Slice.java#L143

Added line #L143 was not covered by tests
}
}
return out;
}
}
28 changes: 28 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_slice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import tensorflow as tf


def consume(y):
    # Sink for the fixed-size slice below; intentionally empty so that only the
    # argument passed in matters.
    pass


def consume_remaining(z):
    # Sink for the "all remaining" (`size[i] == -1`) slice below; intentionally
    # empty so that only the argument passed in matters.
    pass


# Shared input for both slices below: a constant float32 tensor of shape (3, 4).
x = tf.constant(
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]],
dtype=tf.float32,
)
assert x.shape == (3, 4) and x.dtype == tf.float32

# All `size` entries non-negative: the output shape is `size` exactly, independent
# of the input shape. `begin=[0, 1]`, `size=[2, 2]` over `(3, 4)` -> `(2, 2)`.
y = tf.slice(x, [0, 1], [2, 2])
assert y.shape == (2, 2) and y.dtype == tf.float32
consume(y)

# A `size[i]` of `-1` means "all remaining" along axis `i`: `input.shape[i] -
# begin[i]`. `begin=[1, 0]`, `size=[-1, 3]` over `(3, 4)` -> `(3 - 1, 3) = (2, 3)`.
z = tf.slice(x, [1, 0], [-1, 3])
assert z.shape == (2, 3) and z.dtype == tf.float32
consume_remaining(z)
Loading