diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/AliasSettings.java b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/AliasSettings.java new file mode 100644 index 000000000..536655aa4 --- /dev/null +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/AliasSettings.java @@ -0,0 +1,5 @@ +package sample; + +public @interface AliasSettings { + int interProcDepth() default 0; +} diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/BaseSample.java b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/BaseSample.java new file mode 100644 index 000000000..350b2459e --- /dev/null +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/BaseSample.java @@ -0,0 +1,57 @@ +package sample; + +public class BaseSample { + + protected Object field; + + public Object getField() { + return this.field; + } + + public void setField(Object val) { + this.field = val; + } + + public static class Box { + public Object value; + + public void touchHeap() { } + } + + public static Object readValue(Box box) { + return box.value; + } + + public static Box makeBox(Object value) { + Box box = new Box(); + box.value = value; + return box; + } + + public static Box passThroughBox(Box box) { + return box; + } + + public static class Nested { + public Box box; + + public void touchHeap() { } + + public void touchHeapDepth2() { touchHeap(); } + } + + public static class Node { + public Node next; + public Object data; + } + + public static Object identity(Object x) { + return x; + } + + public static void sinkOneValue(Object v) { } + + public static void sinkTwoValues(Object v1, Object v2) { } + + public static void doNothing() { } +} diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/alias/CombinedHeapAliasSample.java 
b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/alias/CombinedHeapAliasSample.java new file mode 100644 index 000000000..9fd6a058c --- /dev/null +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/alias/CombinedHeapAliasSample.java @@ -0,0 +1,74 @@ +package sample.alias; + +import sample.AliasSettings; +import sample.BaseSample; + +public class CombinedHeapAliasSample extends BaseSample { + + static void writeArgThenTouchHeap(Box box, Object src) { + box.value = src; + Object alias = box.value; + box.touchHeap(); + sinkOneValue(alias); + } + + @AliasSettings(interProcDepth = 1) + static void returnArgField(Box box) { + Object result = readValue(box); + sinkOneValue(result); + } + + @AliasSettings(interProcDepth = 1) + static void returnIdentityThenWriteField(Box box, Object src) { + Object tmp = identity(src); + box.value = tmp; + Object result = box.value; + sinkOneValue(result); + } + + @AliasSettings(interProcDepth = 1) + static void freshObjectCarriesReturnedArg(Object src) { + Box fresh = makeBox(identity(src)); + Object result = fresh.value; + sinkOneValue(result); + } + + static void freshObjectCopiesArgumentField(Box srcBox) { + Box fresh = new Box(); + fresh.value = srcBox.value; + Object result = fresh.value; + sinkOneValue(result); + } + + @AliasSettings(interProcDepth = 1) + static void passThroughReceiverThenReadField(Box box) { + Box alias = passThroughBox(box); + Object result = alias.value; + sinkOneValue(result); + } + + @AliasSettings(interProcDepth = 1) + static void nestedWriteReturnAndTouchHeap(Nested nested, Object src) { + nested.box.value = identity(src); + Object alias = readValue(nested.box); + nested.touchHeapDepth2(); + sinkOneValue(alias); + } + + static void overwriteFieldWithFreshObject(Box box, Object src) { + box.value = src; + Box fresh = new Box(); + fresh.value = new Object(); + box.value = fresh.value; + Object result = fresh.value; + sinkOneValue(result); + } 
+ + @AliasSettings(interProcDepth = 1) + static void returnFreshBoxThenAliasField(Box box, Object src) { + box.value = src; + Box other = makeBox(box.value); + Object result = other.value; + sinkOneValue(result); + } +} diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/alias/HeapAliasSample.java b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/alias/HeapAliasSample.java index d37a7ca4a..c0ce5d40d 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/alias/HeapAliasSample.java +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/alias/HeapAliasSample.java @@ -1,19 +1,8 @@ package sample.alias; -public class HeapAliasSample { +import sample.BaseSample; - static class Box { - Object value; - } - - static class Nested { - Box box; - } - - static class Node { - Node next; - Object data; - } +public class HeapAliasSample extends BaseSample { static void readArgField(Box box) { Object dst = box.value; @@ -112,7 +101,4 @@ static void aliasedReceiverFieldWrite(Box b1, Box b2, Object src) { Object dst = b2.value; sinkOneValue(dst); } - - static void sinkOneValue(Object v) { } - static void sinkTwoValues(Object v1, Object v2) { } } diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/alias/InterProcAliasSample.java b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/alias/InterProcAliasSample.java index c5e89d351..9e6aee32c 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/alias/InterProcAliasSample.java +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/alias/InterProcAliasSample.java @@ -2,34 +2,25 @@ import java.util.Collections; import java.util.List; +import sample.AliasSettings; +import sample.BaseSample; -public class InterProcAliasSample { - - Object field; - - Object 
getField() { - return this.field; - } - - void setField(Object val) { - this.field = val; - } +public class InterProcAliasSample extends BaseSample { + @AliasSettings(interProcDepth = 1) void testGetterAlias() { Object result = getField(); sinkOneValue(result); } + @AliasSettings(interProcDepth = 1) void testSetterThenGetter(Object src) { setField(src); Object result = getField(); sinkOneValue(result); } - static Object identity(Object x) { - return x; - } - + @AliasSettings(interProcDepth = 1) static void testIdentityCall(Object src) { Object result = identity(src); sinkOneValue(result); @@ -47,6 +38,4 @@ static void testExternalCallInvalidatesHeap(Object src) { Object dst = arr[0]; sinkOneValue(dst); } - - static void sinkOneValue(Object v) { } } diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/alias/LoopAliasSample.java b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/alias/LoopAliasSample.java index 0aa8c8b6e..3544398f5 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/alias/LoopAliasSample.java +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/samples/src/main/java/sample/alias/LoopAliasSample.java @@ -1,13 +1,9 @@ package sample.alias; import java.util.List; +import sample.BaseSample; -public class LoopAliasSample { - - static class Node { - Node next; - Object data; - } +public class LoopAliasSample extends BaseSample { static void aliasInLoop(Object a, Object b) { Object cur = a; @@ -62,7 +58,4 @@ static void nodeNextLoopData(Node head) { Object data = cur.data; sinkOneValue(data); } - - static void sinkOneValue(Object v) { } - static void doNothing() { } } diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/JIRLocalAliasAnalysis.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/JIRLocalAliasAnalysis.kt index 
97b23c0ba..efa9077ac 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/JIRLocalAliasAnalysis.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/JIRLocalAliasAnalysis.kt @@ -1,6 +1,7 @@ package org.opentaint.dataflow.jvm.ap.ifds import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap +import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap import org.opentaint.dataflow.ap.ifds.AccessPathBase import org.opentaint.dataflow.jvm.ap.ifds.alias.JIRIntraProcAliasAnalysis import org.opentaint.ir.api.common.cfg.CommonInst @@ -23,41 +24,65 @@ class JIRLocalAliasAnalysis( val aliasAnalysisTimeLimit: Duration = 10.seconds, ) - private val aliasInfo by lazy { compute() } + private val mayAliasInfo by lazy { computeMay() } + private val mustAliasInfo by lazy { computeMust() } class MethodAliasInfo( val aliasBeforeStatement: Array>?>?, val aliasAfterStatement: Array>?>?, ) + class MethodMustAliasInfo( + val aliasBeforeStatement: Array>?>?, + val aliasAfterStatement: Array>?>?, + ) + private fun getLocalVarAliases( alias: Array>?>, instIdx: Int, base: AccessPathBase.LocalVar ): List? = alias[instIdx]?.getOrDefault(base.idx, null)?.filter { - it !is AliasApInfo || it.accessors.isNotEmpty() || it.base != base + it !is AccessPathBase || it != base + }?.map { it.wrapAliasInfo() } + + private fun getAccessPathBaseAliases( + alias: Array>?>, + instIdx: Int, base: AccessPathBase + ): List? = + alias[instIdx]?.getOrDefault(base, null)?.filter { + it !is AccessPathBase || it != base }?.map { it.wrapAliasInfo() } + fun findMustAlias(base: AccessPathBase, statement: CommonInst): List? { + val aliasBefore = mustAliasInfo.aliasBeforeStatement ?: return null + val idx = languageManager.getInstIndex(statement) + return getAccessPathBaseAliases(aliasBefore, idx, base) + } + fun findAlias(base: AccessPathBase.LocalVar, statement: CommonInst): List? 
{ - val aliasBefore = aliasInfo.aliasBeforeStatement ?: return null + val aliasBefore = mayAliasInfo.aliasBeforeStatement ?: return null val idx = languageManager.getInstIndex(statement) return getLocalVarAliases(aliasBefore, idx, base) } fun getAllAliasAtStatement(statement: CommonInst): Int2ObjectOpenHashMap> { - val aliasBefore = aliasInfo.aliasBeforeStatement ?: return Int2ObjectOpenHashMap() + val aliasBefore = mayAliasInfo.aliasBeforeStatement ?: return Int2ObjectOpenHashMap() val idx = languageManager.getInstIndex(statement) return aliasBefore[idx]?.let { wrapAllInfo(it) } ?: Int2ObjectOpenHashMap() } fun findAliasAfterStatement(base: AccessPathBase.LocalVar, statement: CommonInst): List? { - val aliasAfter = aliasInfo.aliasAfterStatement ?: return null + val aliasAfter = mayAliasInfo.aliasAfterStatement ?: return null val idx = languageManager.getInstIndex(statement) return getLocalVarAliases(aliasAfter, idx, base) } - private fun compute(): MethodAliasInfo { - return JIRIntraProcAliasAnalysis(entryPoint, graph, callResolver, languageManager, params).compute(localVariableReachability) + private fun computeMay(): MethodAliasInfo { + return JIRIntraProcAliasAnalysis(entryPoint, graph, callResolver, languageManager, params).computeMay(localVariableReachability) + } + + private fun computeMust(): MethodMustAliasInfo { + return JIRIntraProcAliasAnalysis(entryPoint, graph, callResolver, languageManager, params).computeMust(localVariableReachability) } sealed interface AliasAccessor { @@ -91,6 +116,14 @@ class JIRLocalAliasAnalysis( return result } + fun wrapAllInfo(info: Object2ObjectOpenHashMap>): Object2ObjectOpenHashMap> { + val result = Object2ObjectOpenHashMap>() + for ((key, aliases) in info) { + result.put(key, List(aliases.size) { aliases[it].wrapAliasInfo() }) + } + return result + } + fun unwrapAllInfo(info: Int2ObjectOpenHashMap>): Int2ObjectOpenHashMap> { val result = Int2ObjectOpenHashMap>(info.size, 0.99f) val iter = 
info.int2ObjectEntrySet().fastIterator() @@ -102,5 +135,17 @@ class JIRLocalAliasAnalysis( } return result } + + fun unwrapAllInfo(info: Object2ObjectOpenHashMap>): Object2ObjectOpenHashMap> { + val result = Object2ObjectOpenHashMap>(info.size, 0.99f) + val iter = info.object2ObjectEntrySet().fastIterator() + while (iter.hasNext()) { + val entry = iter.next() + val value = entry.value + val unwrapped = Array(value.size) { value[it].unwrap() } + result.put(entry.key, unwrapped) + } + return result + } } } diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/DSUAliasAnalysis.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/DSUAliasAnalysis.kt index dadffd372..bf8b7d32e 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/DSUAliasAnalysis.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/DSUAliasAnalysis.kt @@ -4,8 +4,6 @@ import it.unimi.dsi.fastutil.ints.Int2ObjectMap import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap import it.unimi.dsi.fastutil.ints.IntArrayList import it.unimi.dsi.fastutil.ints.IntCollection -import it.unimi.dsi.fastutil.ints.IntIntImmutablePair -import it.unimi.dsi.fastutil.ints.IntIntMutablePair import it.unimi.dsi.fastutil.ints.IntOpenHashSet import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalAliasAnalysis.AliasAccessor import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalVariableReachability @@ -13,7 +11,6 @@ import org.opentaint.dataflow.jvm.ap.ifds.alias.JIRIntraProcAliasAnalysis.JIRIns import org.opentaint.dataflow.jvm.ap.ifds.alias.RefValue.Local import org.opentaint.dataflow.util.forEachInt import org.opentaint.dataflow.util.forEachIntEntry -import org.opentaint.dataflow.util.mapIntTo import org.opentaint.ir.api.jvm.JIRField import org.opentaint.ir.api.jvm.JIRMethod 
import org.opentaint.ir.api.jvm.cfg.JIRInst @@ -23,8 +20,11 @@ import java.util.BitSet class DSUAliasAnalysis( val methodCallResolver: CallResolver, val rootMethodReachabilityInfo: JIRLocalVariableReachability, - val cancellation: AnalysisCancellation + val mergeType: MergeType, + val cancellation: AnalysisCancellation, ) { + private fun isMustAlias() = mergeType == MergeType.Must + private val aliasManager = AAInfoManager() private val dsuMergeStrategy = DsuMergeStrategy(aliasManager) @@ -74,194 +74,6 @@ class DSUAliasAnalysis( override fun createLocal(idx: Int): Local = Local(idx, ContextInfo.rootContext) } - interface ImmutableState { - fun mutableCopy(): State - } - - class State private constructor( - val manager: AAInfoManager, - val aliasGroups: IntDisjointSets, - ) : ImmutableState { - - override fun hashCode(): Int = error("Unsupported operation") - - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other !is State) return false - - /** - * We don't need to align heap instances here. - * Since set repr selection is deterministic due to strategy, - * if sets are equal their repr are also equal. - * So, heap instances must be equal. 
- * */ - return aliasGroups == other.aliasGroups - } - - fun asImmutable(): ImmutableState = this - - override fun mutableCopy(): State = State(manager, aliasGroups.mutableCopy()) - - fun removeUnsafe(infos: IntOpenHashSet): State { - if (infos.isEmpty()) return this - - val normalizedInfos = fixHeapElementInstance(infos) - val result = aliasGroups.mutableCopy() - - val removedInstances = IntOpenHashSet() - result.prepareRemoveAll(normalizedInfos, removedInstances) - - val removeAfterHeapFix = IntOpenHashSet() - restoreHeapInvariant(manager, result, removeAfterHeapFix) - - // since we use prepare-remove, old replaced roots are still in the DSU - removeAfterHeapFix.addAll(normalizedInfos) - result.removeAll(removeAfterHeapFix) - - if (removedInstances.isEmpty()) { - return State(manager, result) - } - - val removedHeap = IntOpenHashSet() - result.allElements().forEachInt { - if (!manager.isHeapAlias(it)) return@forEachInt - - val heapElement = manager.getHeapRefUnchecked(it) - if (removedInstances.contains(heapElement.instance)) { - removedHeap.add(it) - } - } - - return State(manager, result).removeUnsafe(removedHeap) - } - - fun aliasGroupId(info: Int): Int = aliasGroups.find(info) - fun aliasGroupRepr(groupId: Int): Int = aliasGroups.find(groupId) - - fun mergeAliasSets(aliasSets: IntOpenHashSet): State { - if (aliasSets.size < 2) return this - - val firstRepr = aliasSets.intIterator().nextInt() - val relations = mutableListOf() - aliasSets.forEachInt { - if (it == firstRepr) return@forEachInt - relations += IntIntMutablePair(firstRepr, it) - } - - val result = aliasGroups.mutableCopy() - mergeUnionRelations(relations, result, manager) - - return State(manager, result) - } - - fun forEachAliasInSet(info: Int, body: (Int) -> Unit) = forEachAliasInSetWithBreak(info, body) - - fun forEachAliasInSetWithBreak(info: Int, body: (Int) -> Unit?) 
{ - aliasGroups.forEachElementInSet(info, body) - } - - fun allAliasSets(): Collection = aliasGroups.allSets() - - fun allSetElements(): IntOpenHashSet = aliasGroups.allElements() - - override fun toString(): String = buildString { - for (aliasSet in allAliasSets()) { - appendLine("{") - aliasSet.forEachInt { - appendLine("\t($it) -> ${manager.getElementUncheck(it)}") - } - appendLine("}") - } - } - - private fun fixHeapElementInstance(elements: IntOpenHashSet) = - elements.mapIntTo(IntOpenHashSet(elements.size)) { - ensureHeapElementCorrect(it, aliasGroups, manager) - } - - companion object { - fun empty(manager: AAInfoManager, strategy: DsuMergeStrategy): State = - State(manager, IntDisjointSets(strategy)) - - private fun restoreHeapInvariant( - manager: AAInfoManager, - state: IntDisjointSets, - elementsToRemove: IntOpenHashSet, - ) { - while (true) { - val replacements = mutableListOf() - - state.allElements().forEachInt { elementIdx -> - if (elementsToRemove.contains(elementIdx)) return@forEachInt - - val fixedHeap = ensureHeapElementCorrect(elementIdx, state, manager) - if (fixedHeap == elementIdx) return@forEachInt - - replacements += IntIntImmutablePair(elementIdx, fixedHeap) - } - - if (replacements.isEmpty()) return - - for (replacement in replacements) { - elementsToRemove.add(replacement.leftInt()) - state.union(replacement.leftInt(), replacement.rightInt()) - } - } - } - - private fun ensureHeapElementCorrect(element: Int, state: IntDisjointSets, manager: AAInfoManager): Int { - if (!manager.isHeapAlias(element)) return element - - val heapElement = manager.getHeapRefUnchecked(element) - val heapInstanceRepr = state.find(heapElement.instance) - if (heapInstanceRepr == heapElement.instance) return element - - return manager.replaceHeapInstance(element, heapInstanceRepr) - } - - fun merge(manager: AAInfoManager, strategy: DsuMergeStrategy, states: List): State { - val allElementParentRelations = mutableListOf() - states.forEach { s -> - val stateDsu = (s 
as State).aliasGroups - stateDsu.collectElementParentPairs(allElementParentRelations) - } - - val result = IntDisjointSets(strategy) - mergeUnionRelations(allElementParentRelations, result, manager) - - return State(manager, result) - } - - private fun mergeUnionRelations( - relations: List, - result: IntDisjointSets, - manager: AAInfoManager - ) { - val removedElements = IntOpenHashSet() - while (true) { - var modified = false - relations.forEach { - val status = result.union(it.leftInt(), it.rightInt()) - modified = modified or status - } - - if (!modified) break - - restoreHeapInvariant(manager, result, removedElements) - - relations.forEach { relation -> - val fixedLeft = ensureHeapElementCorrect(relation.leftInt(), result, manager) - val fixedRight = ensureHeapElementCorrect(relation.rightInt(), result, manager) - - relation.left(fixedLeft) - relation.right(fixedRight) - } - } - result.removeAll(removedElements) - } - } - } - private fun AAInfo.index(): Int { return aliasManager.getOrAdd(this) } @@ -342,7 +154,7 @@ class DSUAliasAnalysis( private fun merge(inst: JIRInst, states: Int2ObjectMap, call: CallTreeNode): ImmutableState { val statesToMerge = states.values.filterNotNull() - val merged = State.merge(aliasManager, dsuMergeStrategy, statesToMerge) + val merged = State.merge(aliasManager, dsuMergeStrategy, statesToMerge, mergeType) val reachabilityInfo = methodReachabilityInfo(inst.location.method) val instIdx = inst.location.index @@ -440,7 +252,7 @@ class DSUAliasAnalysis( return statesAfterCall.first().mutableCopy() } - return State.merge(aliasManager, dsuMergeStrategy, statesAfterCall) + return State.merge(aliasManager, dsuMergeStrategy, statesAfterCall, mergeType) } private fun State.invalidateOuterHeapAliases(startInvalidAliases: IntOpenHashSet): State { @@ -606,7 +418,7 @@ class DSUAliasAnalysis( val heapAlias = heapAppender(obj).index() var resultState = state - if (!state.containsMultipleConcreteOrOuterLocations(instanceInfo)) { + if 
(isMustAlias() || !state.containsMultipleConcreteOrOuterLocations(instanceInfo)) { resultState = resultState.remove(heapAlias) } diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/DSUState.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/DSUState.kt new file mode 100644 index 000000000..6b607753c --- /dev/null +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/DSUState.kt @@ -0,0 +1,247 @@ +package org.opentaint.dataflow.jvm.ap.ifds.alias + +import it.unimi.dsi.fastutil.ints.IntArrayList +import it.unimi.dsi.fastutil.ints.IntIntImmutablePair +import it.unimi.dsi.fastutil.ints.IntIntMutablePair +import it.unimi.dsi.fastutil.ints.IntOpenHashSet +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap +import org.opentaint.dataflow.jvm.ap.ifds.alias.DSUAliasAnalysis.DsuMergeStrategy +import org.opentaint.dataflow.util.forEachInt +import org.opentaint.dataflow.util.mapIntTo +import kotlin.collections.forEach + +interface ImmutableState { + fun mutableCopy(): State +} + +enum class MergeType{ + May, Must +} + +class State private constructor( + val manager: AAInfoManager, + val aliasGroups: IntDisjointSets, +) : ImmutableState { + + override fun hashCode(): Int = error("Unsupported operation") + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is State) return false + + /** + * We don't need to align heap instances here. + * Since set repr selection is deterministic due to strategy, + * if sets are equal their repr are also equal. + * So, heap instances must be equal. 
+ * */ + return aliasGroups == other.aliasGroups + } + + fun asImmutable(): ImmutableState = this + + override fun mutableCopy(): State = State(manager, aliasGroups.mutableCopy()) + + fun removeUnsafe(infos: IntOpenHashSet): State { + if (infos.isEmpty()) return this + + val normalizedInfos = fixHeapElementInstance(infos) + val result = aliasGroups.mutableCopy() + + val removedInstances = IntOpenHashSet() + result.prepareRemoveAll(normalizedInfos, removedInstances) + + val removeAfterHeapFix = IntOpenHashSet() + restoreHeapInvariant(manager, result, removeAfterHeapFix) + + // since we use prepare-remove, old replaced roots are still in the DSU + removeAfterHeapFix.addAll(normalizedInfos) + result.removeAll(removeAfterHeapFix) + + if (removedInstances.isEmpty()) { + return State(manager, result) + } + + val removedHeap = IntOpenHashSet() + result.allElements().forEachInt { + if (!manager.isHeapAlias(it)) return@forEachInt + + val heapElement = manager.getHeapRefUnchecked(it) + if (removedInstances.contains(heapElement.instance)) { + removedHeap.add(it) + } + } + + return State(manager, result).removeUnsafe(removedHeap) + } + + fun aliasGroupId(info: Int): Int = aliasGroups.find(info) + fun aliasGroupRepr(groupId: Int): Int = aliasGroups.find(groupId) + + fun mergeAliasSets(aliasSets: IntOpenHashSet): State { + if (aliasSets.size < 2) return this + + val firstRepr = aliasSets.intIterator().nextInt() + val relations = mutableListOf() + aliasSets.forEachInt { + if (it == firstRepr) return@forEachInt + relations += IntIntMutablePair(firstRepr, it) + } + + val result = aliasGroups.mutableCopy() + mergeUnionRelations(relations, result, manager) + + return State(manager, result) + } + + fun forEachAliasInSet(info: Int, body: (Int) -> Unit) = forEachAliasInSetWithBreak(info, body) + + fun forEachAliasInSetWithBreak(info: Int, body: (Int) -> Unit?) 
{ + aliasGroups.forEachElementInSet(info, body) + } + + fun allAliasSets(): Collection = aliasGroups.allSets() + + fun allSetElements(): IntOpenHashSet = aliasGroups.allElements() + + override fun toString(): String = buildString { + for (aliasSet in allAliasSets()) { + appendLine("{") + aliasSet.forEachInt { + appendLine("\t($it) -> ${manager.getElementUncheck(it)}") + } + appendLine("}") + } + } + + private fun fixHeapElementInstance(elements: IntOpenHashSet) = + elements.mapIntTo(IntOpenHashSet(elements.size)) { + ensureHeapElementCorrect(it, aliasGroups, manager) + } + + companion object { + fun empty(manager: AAInfoManager, strategy: DsuMergeStrategy): State = + State(manager, IntDisjointSets(strategy)) + + private fun restoreHeapInvariant( + manager: AAInfoManager, + state: IntDisjointSets, + elementsToRemove: IntOpenHashSet, + ) { + while (true) { + val replacements = mutableListOf() + + state.allElements().forEachInt { elementIdx -> + if (elementsToRemove.contains(elementIdx)) return@forEachInt + + val fixedHeap = ensureHeapElementCorrect(elementIdx, state, manager) + if (fixedHeap == elementIdx) return@forEachInt + + replacements += IntIntImmutablePair(elementIdx, fixedHeap) + } + + if (replacements.isEmpty()) return + + for (replacement in replacements) { + elementsToRemove.add(replacement.leftInt()) + state.union(replacement.leftInt(), replacement.rightInt()) + } + } + } + + private fun ensureHeapElementCorrect(element: Int, state: IntDisjointSets, manager: AAInfoManager): Int { + if (!manager.isHeapAlias(element)) return element + + val heapElement = manager.getHeapRefUnchecked(element) + val heapInstanceRepr = state.find(heapElement.instance) + if (heapInstanceRepr == heapElement.instance) return element + + return manager.replaceHeapInstance(element, heapInstanceRepr) + } + + fun merge(manager: AAInfoManager, strategy: DsuMergeStrategy, states: List, mergeType: MergeType): State { + val allAliasGroups = states.map { (it as State).aliasGroups } + + val 
relations = when (mergeType) { + MergeType.May -> mergeMay(allAliasGroups) + MergeType.Must -> mergeMust(allAliasGroups) + } + + val result = IntDisjointSets(strategy) + mergeUnionRelations(relations, result, manager) + + return State(manager, result) + } + + fun mergeMay(allAliasGroups: List): List { + val allElementParentRelations = mutableListOf() + allAliasGroups.forEach { a -> + a.collectElementParentPairs(allElementParentRelations) + } + return allElementParentRelations + } + + fun mergeMust(allAliasGroups: List): List { + val elementsInAll = IntOpenHashSet() + val allSetElements = allAliasGroups.map { it.allElements() } + var first = true + allSetElements.forEach { set -> + if (first) { + elementsInAll.addAll(set) + first = false + return@forEach + } + elementsInAll.removeAll { element -> !set.contains(element) } + } + + if (elementsInAll.isEmpty()) return emptyList() + + val resultRelations = mutableListOf() + val map = Object2IntOpenHashMap() + val totalStates = allAliasGroups.size + elementsInAll.forEachInt { element -> + val elementSignature = IntArrayList(totalStates) + allAliasGroups.forEach { aliasGroup -> + elementSignature.add(aliasGroup.find(element)) + } + if (map.contains(elementSignature)) { + val parent = map.getInt(elementSignature) + resultRelations.add(IntIntMutablePair(parent, element)) + } + else { + map.put(elementSignature, element) + } + } + + return resultRelations + } + + private fun mergeUnionRelations( + relations: List, + result: IntDisjointSets, + manager: AAInfoManager + ) { + val removedElements = IntOpenHashSet() + while (true) { + var modified = false + relations.forEach { + val status = result.union(it.leftInt(), it.rightInt()) + modified = modified or status + } + + if (!modified) break + + restoreHeapInvariant(manager, result, removedElements) + + relations.forEach { relation -> + val fixedLeft = ensureHeapElementCorrect(relation.leftInt(), result, manager) + val fixedRight = ensureHeapElementCorrect(relation.rightInt(), 
result, manager) + + relation.left(fixedLeft) + relation.right(fixedRight) + } + } + result.removeAll(removedElements) + } + } +} diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/JIRIntraProcAliasAnalysis.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/JIRIntraProcAliasAnalysis.kt index 141ae187b..92bd1a585 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/JIRIntraProcAliasAnalysis.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/JIRIntraProcAliasAnalysis.kt @@ -1,6 +1,7 @@ package org.opentaint.dataflow.jvm.ap.ifds.alias import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap +import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap import mu.KLogging import org.opentaint.dataflow.ap.ifds.AccessPathBase import org.opentaint.dataflow.graph.CompactGraph @@ -59,15 +60,15 @@ class JIRIntraProcAliasAnalysis( override fun buildMethodJig(entryPoint: JIRInst): JIRInstGraph = getJIG(entryPoint) } - fun compute( + fun computeMay( localVariableReachability: JIRLocalVariableReachability ): JIRLocalAliasAnalysis.MethodAliasInfo = withAnalysisCancellation( timeLimit = params.aliasAnalysisTimeLimit, - body = { compute(it, localVariableReachability) }, + body = { computeMay(it, localVariableReachability) }, onAnalysisCancelled = { logger.error { - "Alias analysis for ${entryPoint.location.method} exceed ${params.aliasAnalysisTimeLimit}" + "May alias analysis for ${entryPoint.location.method} exceed ${params.aliasAnalysisTimeLimit}" } JIRLocalAliasAnalysis.MethodAliasInfo( @@ -77,12 +78,30 @@ class JIRIntraProcAliasAnalysis( } ) - private fun compute( + fun computeMust( + localVariableReachability: JIRLocalVariableReachability + ): JIRLocalAliasAnalysis.MethodMustAliasInfo = + 
withAnalysisCancellation( + timeLimit = params.aliasAnalysisTimeLimit, + body = { computeMust(it, localVariableReachability) }, + onAnalysisCancelled = { + logger.error { + "Must alias analysis for ${entryPoint.location.method} exceed ${params.aliasAnalysisTimeLimit}" + } + + JIRLocalAliasAnalysis.MethodMustAliasInfo( + aliasBeforeStatement = null, + aliasAfterStatement = null + ) + } + ) + + private fun computeMay( cancellation: AnalysisCancellation, localVariableReachability: JIRLocalVariableReachability ): JIRLocalAliasAnalysis.MethodAliasInfo { val jig = getJIG(entryPoint) - val daa = DSUAliasAnalysis(CallResolver(), localVariableReachability, cancellation).analyze(jig) + val daa = DSUAliasAnalysis(CallResolver(), localVariableReachability, MergeType.May, cancellation).analyze(jig) val aliasBeforeStatement = Array(jig.statements.size) { i -> resolveLocalVar(daa.statesBeforeStmt[i], localVariableReachability, i) @@ -95,6 +114,24 @@ class JIRIntraProcAliasAnalysis( return compressAliasInfo(aliasBeforeStatement, aliasAfterStatement) } + private fun computeMust( + cancellation: AnalysisCancellation, + localVariableReachability: JIRLocalVariableReachability + ): JIRLocalAliasAnalysis.MethodMustAliasInfo { + val jig = getJIG(entryPoint) + val daa = DSUAliasAnalysis(CallResolver(), localVariableReachability, MergeType.Must, cancellation).analyze(jig) + + val aliasBeforeStatement = Array(jig.statements.size) { i -> + resolveAccessPathBase(daa.statesBeforeStmt[i], localVariableReachability, i) + } + + val aliasAfterStatement = Array(jig.statements.size) { i -> + resolveAccessPathBase(daa.statesAfterStmt[i], localVariableReachability, i) + } + + return compressMustAliasInfo(aliasBeforeStatement, aliasAfterStatement) + } + private fun compressAliasInfo( aliasBeforeStatement: Array>>, aliasAfterStatement: Array>> @@ -107,6 +144,18 @@ class JIRIntraProcAliasAnalysis( return JIRLocalAliasAnalysis.MethodAliasInfo(compressedBefore, compressedAfter) } + private fun 
compressMustAliasInfo( + aliasBeforeStatement: Array>>, + aliasAfterStatement: Array>> + ): JIRLocalAliasAnalysis.MethodMustAliasInfo { + val compressedBefore = arrayOfNulls>>(aliasBeforeStatement.size) + val compressedAfter = arrayOfNulls>>(aliasAfterStatement.size) + + compress(aliasBeforeStatement, compressedBefore, reference = null, referenceCompressed = null) + compress(aliasAfterStatement, compressedAfter, aliasBeforeStatement, compressedBefore) + return JIRLocalAliasAnalysis.MethodMustAliasInfo(compressedBefore, compressedAfter) + } + private fun compress( statementInfo: Array>>, compressed: Array>?>, @@ -138,6 +187,37 @@ class JIRIntraProcAliasAnalysis( } } + private fun compress( + statementInfo: Array>>, + compressed: Array>?>, + reference: Array>>?, + referenceCompressed: Array>?>? + ) { + for (i in statementInfo.indices) { + val current = statementInfo[i] + if (current.isEmpty()) continue + + if (i > 0 && statementInfo[i - 1] == current) { + compressed[i] = compressed[i - 1] + continue + } + + if (reference != null) { + if (reference[i] == current) { + compressed[i] = referenceCompressed!![i] + } + + if (i > 0 && reference[i - 1] == current) { + compressed[i] = referenceCompressed!![i - 1] + continue + } + } + + val unwrapped = JIRLocalAliasAnalysis.unwrapAllInfo(current) + compressed[i] = unwrapped + } + } + private fun resolveLocalVar( daa: ConnectedAliases, reachableLocals: JIRLocalVariableReachability, @@ -165,6 +245,35 @@ class JIRIntraProcAliasAnalysis( return result } + private fun AccessPathBase.isMustRelevantBase() = + this is AccessPathBase.LocalVar || this is AccessPathBase.Argument || this is AccessPathBase.This + + private fun resolveAccessPathBase( + daa: ConnectedAliases, + reachableLocals: JIRLocalVariableReachability, + instIdx: Int + ): Object2ObjectOpenHashMap> { + val result = Object2ObjectOpenHashMap>() + daa.aliasGroups.forEach { (_, group) -> + val converted = group + .flatMap { it.convertToAliasInfo(daa.aliasGroups, depth = 0) } 
+ .filter { it !is AliasApInfo || reachableLocals.isReachable(it.base, instIdx) } + .distinct() + + // size == 1 means only local was converted to AliasInfo; not really meaningful + if (converted.size <= 1) return@forEach + + val bases = converted.filterIsInstance() + .filter { it.base.isMustRelevantBase() && it.accessors.isEmpty() } + .map { it.base } + + bases.forEach { base -> + result[base] = converted + } + } + return result + } + private fun AAInfo.convertToAliasInfo( aliasGroups: Int2ObjectOpenHashMap>, depth: Int diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/FactUtils.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/FactUtils.kt new file mode 100644 index 000000000..0abb4c96d --- /dev/null +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/FactUtils.kt @@ -0,0 +1,152 @@ +package org.opentaint.dataflow.jvm.ap.ifds.analysis + +import org.opentaint.dataflow.ap.ifds.AccessPathBase +import org.opentaint.dataflow.ap.ifds.Accessor +import org.opentaint.dataflow.ap.ifds.AnyAccessor +import org.opentaint.dataflow.ap.ifds.access.FinalFactAp +import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalAliasAnalysis +import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalAliasAnalysis.AliasApInfo +import org.opentaint.ir.api.jvm.cfg.JIRInst + +data class AliasedFact(val fact: FinalFactAp, val alias: AliasApInfo?) + +object FactUtils { + /** + * @return Pair of lists, where left one has everything relevant to `prefix`, and right one has everything else. 
+ */ + fun splitByPrefix(prefix: List, fact: FinalFactAp): Pair, List> { + if (prefix.isEmpty()) return listOf(fact) to emptyList() + + val head = prefix.first() + val tail = prefix.drop(1) + if (tail.isEmpty()) { + if (fact.startsWithAccessor(AnyAccessor)) { + val factAfterAny = fact.readAccessor(AnyAccessor) + ?: error("Impossible") + + val clearHeadAfterAny = factAfterAny.clearAccessor(head) + val remainingAfterAny = clearHeadAfterAny?.prependAccessor(AnyAccessor) + + val withHeadAfterAny = factAfterAny.readAccessor(head) + val prefixAfterAny = withHeadAfterAny?.prependAccessor(head)?.prependAccessor(AnyAccessor) + + val factWithoutAny = fact.clearAccessor(AnyAccessor) + val remainingWithoutAny = factWithoutAny?.clearAccessor(head) + val prefixWithoutAny = factWithoutAny?.readAccessor(head)?.prependAccessor(head) + + val prefixFact = listOfNotNull(prefixAfterAny, prefixWithoutAny) + val remainingFact = listOfNotNull(remainingAfterAny, remainingWithoutAny) + + return prefixFact to remainingFact + } + + if (!fact.startsWithAccessor(head)) { + return emptyList() to listOf(fact) + } + + val remainingFact = listOfNotNull(fact.clearAccessor(head)) + val prefixFact = fact.readAccessor(head)?.prependAccessor(head) + + return listOfNotNull(prefixFact) to remainingFact + } + + val child = fact.readAccessor(head) + ?: return emptyList() to listOf(fact) + + val headRemaining = listOfNotNull(fact.clearAccessor(head)) + val (tailprefix, tailRemaining) = splitByPrefix(tail, child) + val prefixRelated = tailprefix.map { it.prependAccessor(head) } + val remaining = headRemaining + tailRemaining.map { it.prependAccessor(head) } + + return prefixRelated to remaining + } + + private fun rewriteForBase(fact: FinalFactAp, alias: AliasApInfo, newBase: AccessPathBase): AliasedFact { + val newFact = alias.accessors.fold(fact.rebase(newBase) as FinalFactAp?) 
{ f, accessor -> + f?.readAccessor(accessor.apAccessor()) + } + check(newFact != null) { "Aliased fact did not contain all alias accessors!" } + return AliasedFact(newFact, alias) + } + + fun rewriteForAlias(fact: FinalFactAp, alias: AliasApInfo?): FinalFactAp { + if (alias == null) return fact + val newFact = alias.accessors.foldRight(fact.rebase(alias.base)) { accessor, f -> + f.prependAccessor(accessor.apAccessor()) + } + return newFact + } + + private fun splitByMustAlias( + fact: FinalFactAp, + mustAlias: AliasApInfo, + ): Pair, List> { + if (fact.base != mustAlias.base) { + return emptyList() to listOf(fact) + } + val aliasAccessors = mustAlias.accessors.map { it.apAccessor() } + return splitByPrefix(aliasAccessors, fact) + } + + fun splitFactByBaseMustAlias( + aliasAnalysis: JIRLocalAliasAnalysis?, + statement: JIRInst, + relevantBase: AccessPathBase, + fact: FinalFactAp, + includeOriginal: Boolean + ): Pair, List> { + val aliases = aliasAnalysis?.getMustAliases(statement, relevantBase).orEmpty() + var irrelevantFacts = listOf(fact) + val aliasedFacts = mutableListOf() + // hack for uniform calls + if (includeOriginal) aliasedFacts.add(AliasedFact(fact, null)) + + aliases.forEach { alias -> + val left = mutableListOf() + irrelevantFacts.forEach { fact -> + val (aliased, irrelevant) = splitByMustAlias(fact, alias) + left.addAll(irrelevant) + aliased.forEach { aliasedFact -> + val rebasedFact = rewriteForBase(aliasedFact, alias, relevantBase) + aliasedFacts.add(rebasedFact) + } + } + irrelevantFacts = left + } + + // no splits happened, nothing was aliased, no irrelevant expected + if (irrelevantFacts.isNotEmpty() && irrelevantFacts.first() === fact) { + irrelevantFacts = emptyList() + } + + return aliasedFacts to irrelevantFacts + } + + fun splitFactMultipleBases( + aliasAnalysis: JIRLocalAliasAnalysis?, + statement: JIRInst, + relevantBases: List, + fact: FinalFactAp, + includeOriginal: Boolean + ): Pair, List> { + var irrelevantFacts = listOf(fact) + val 
aliasedFacts = mutableListOf() + if (includeOriginal) aliasedFacts.add(AliasedFact(fact, null)) + + relevantBases.forEach { base -> + val newIrrelevant = mutableListOf() + irrelevantFacts.forEach { fact -> + val (aliased, irrelevant) = + splitFactByBaseMustAlias(aliasAnalysis, statement, base, fact, false) + aliasedFacts.addAll(aliased) + newIrrelevant.addAll(irrelevant) + } + irrelevantFacts = newIrrelevant + } + if (irrelevantFacts.size == 1 && irrelevantFacts.first() === fact) { + irrelevantFacts = emptyList() + } + + return aliasedFacts to irrelevantFacts + } +} diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRAliasUtil.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRAliasUtil.kt index a69373cec..ed488f712 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRAliasUtil.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRAliasUtil.kt @@ -22,6 +22,19 @@ fun JIRLocalAliasAnalysis.forEachAliasAtStatement(statement: JIRInst, fact: Fina .forEach { alias -> applyAlias(fact, alias, body) } } +fun JIRLocalAliasAnalysis.getMustAliases( + statement: JIRInst, + relevantBase: AccessPathBase +): List { + return findMustAlias(relevantBase, statement).orEmpty() + .filterIsInstance() + .filterNot { alias -> alias.base is AccessPathBase.Constant } +} + +fun JIRLocalAliasAnalysis.forEachMustAlias(statement: JIRInst, fact: FinalFactAp, body: (FinalFactAp) -> Unit) { + getMustAliases(statement, fact.base).forEach { alias -> applyAlias(fact, alias, body) } +} + fun JIRLocalAliasAnalysis.forEachAliasAfterStatement(statement: JIRInst, fact: FinalFactAp, body: (FinalFactAp) -> Unit) { val base = fact.base as? 
AccessPathBase.LocalVar ?: return val aliases = findAliasAfterStatement(base, statement) ?: return diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodCallFlowFunction.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodCallFlowFunction.kt index 667f816e6..bf462a65e 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodCallFlowFunction.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodCallFlowFunction.kt @@ -26,6 +26,7 @@ import org.opentaint.dataflow.jvm.ap.ifds.JIRMarkAwareConditionRewriter import org.opentaint.dataflow.jvm.ap.ifds.JIRMethodCallFactMapper import org.opentaint.dataflow.jvm.ap.ifds.JIRMethodPositionBaseTypeResolver import org.opentaint.dataflow.jvm.ap.ifds.JIRSimpleFactAwareConditionEvaluator +import org.opentaint.dataflow.jvm.ap.ifds.MethodFlowFunctionUtils import org.opentaint.dataflow.jvm.ap.ifds.TaintConfigUtils.applyCleaner import org.opentaint.dataflow.jvm.ap.ifds.TaintConfigUtils.applyPassThrough import org.opentaint.dataflow.jvm.ap.ifds.TaintConfigUtils.applyRuleWithAssumptions @@ -243,17 +244,34 @@ class JIRMethodCallFlowFunction( addCallToStart: (factReader: FinalFactReader, callerFact: FinalFactAp, startFactBase: AccessPathBase, TraceInfo) -> Unit, addUnchecked: (MethodCallFlowFunction.CallFact) -> Unit, ) { - if (!JIRMethodCallFactMapper.factIsRelevantToMethodCall(returnValue, callExpr, factAp)) { + val relevantBases = callExpr.operands.mapNotNull { MethodFlowFunctionUtils.accessPathBase(it) } + + val (aliasedFacts, irrelevantFacts) = + FactUtils.splitFactMultipleBases(analysisContext.aliasAnalysis, statement, relevantBases, factAp, true) + + var fixedFactAp: FinalFactAp? 
= null + + aliasedFacts.forEach { (fact, _) -> + if (JIRMethodCallFactMapper.factIsRelevantToMethodCall(returnValue, callExpr, fact)) + fixedFactAp = fact + } + + if (fixedFactAp == null) { skipCall() return } + irrelevantFacts.forEach { fact -> + val reader = FinalFactReader(fact, apManager) + addCallToReturn(reader, fact, TraceInfo.Flow) + } + val conditionRewriter = JIRMarkAwareConditionRewriter( CallPositionToJIRValueResolver(callExpr, returnValue), analysisContext, statement ) - val factReader = FinalFactReader(factAp, apManager) + val factReader = FinalFactReader(fixedFactAp, apManager) val markAfterAnyFieldResolver = createMarkAfterFieldsResolver( analysisContext.methodEntryPoint, initialFacts @@ -288,7 +306,7 @@ class JIRMethodCallFlowFunction( callee = callExpr.callee, callExpr = callExpr, returnValue = null, - factAp = factAp, + factAp = fixedFactAp, checker = analysisContext.factTypeChecker, ) { callerFact, startFactBase -> applyCleanersOrCallToStart( @@ -611,8 +629,12 @@ class JIRMethodCallFlowFunction( addCallToReturn: (FinalFactReader, FinalFactAp, TraceInfo?) 
-> Unit, addSideEffectRequirement: (FinalFactReader) -> Unit, ) { - val factReader = FinalFactReader(factAp, apManager) + analysisContext.aliasAnalysis?.forEachMustAlias(statement, factAp) { fact -> + val aliasReader = FinalFactReader(fact, apManager) + unresolvedCallDefaultFactPropagation(aliasReader, fact, addCallToReturn) + } + val factReader = FinalFactReader(factAp, apManager) unresolvedCallDefaultFactPropagation(factReader, factAp, addCallToReturn) val method = callExpr.callee diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodCallSummaryHandler.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodCallSummaryHandler.kt index 10d9d2b92..56916dbc4 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodCallSummaryHandler.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodCallSummaryHandler.kt @@ -10,6 +10,8 @@ import org.opentaint.dataflow.ap.ifds.analysis.MethodCallSummaryHandler import org.opentaint.dataflow.ap.ifds.analysis.MethodCallSummaryHandler.SummaryEdge import org.opentaint.dataflow.ap.ifds.analysis.MethodSequentFlowFunction.Sequent import org.opentaint.dataflow.jvm.ap.ifds.JIRMethodCallFactMapper +import org.opentaint.dataflow.jvm.ap.ifds.MethodFlowFunctionUtils +import org.opentaint.ir.api.jvm.cfg.JIRImmediate import org.opentaint.ir.api.jvm.cfg.JIRInst class JIRMethodCallSummaryHandler( @@ -51,6 +53,10 @@ class JIRMethodCallSummaryHandler( } } + analysisContext.aliasAnalysis?.forEachMustAlias(statement, summaryFactAp) { fact -> + result += handleSummaryEdge(initialFactRefinement, fact) + } + handleSummaryEdge(initialFactRefinement, summaryFactAp) } diff --git 
a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodSequentFlowFunction.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodSequentFlowFunction.kt index 0e0fb7315..63298a4af 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodSequentFlowFunction.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodSequentFlowFunction.kt @@ -17,6 +17,7 @@ import org.opentaint.dataflow.ap.ifds.analysis.MethodSequentFlowFunction.TraceIn import org.opentaint.dataflow.ap.ifds.taint.TaintSinkTracker.VulnerabilityTriggerPosition import org.opentaint.dataflow.configuration.jvm.ConstantTrue import org.opentaint.dataflow.jvm.ap.ifds.CalleePositionToJIRValueResolver +import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalAliasAnalysis import org.opentaint.dataflow.jvm.ap.ifds.JIRMarkAwareConditionRewriter import org.opentaint.dataflow.jvm.ap.ifds.MethodFlowFunctionUtils import org.opentaint.dataflow.jvm.ap.ifds.MethodFlowFunctionUtils.accessPathBase @@ -548,14 +549,24 @@ class JIRMethodSequentFlowFunction( val accessor = accessors.first() - if (!factAp.mayRemoveAfterWrite(instance, accessor)) { + var fixedFact: FinalFactAp? 
= null + + val (aliased, irrelevant) = + FactUtils.splitFactByBaseMustAlias(analysisContext.aliasAnalysis, currentInst, instance, factAp, true) + + aliased.forEach { (fact, _) -> + if (fact.mayRemoveAfterWrite(instance, accessor)) + fixedFact = fact + } + + if (fixedFact == null) { // Fact is irrelevant to current writing unchanged(factAp) return } - if (factAp.isAbstract() && accessor !in factAp.exclusions) { - val nonAbstractAp = factAp.removeAbstraction() + if (fixedFact.isAbstract() && accessor !in fixedFact.exclusions) { + val nonAbstractAp = fixedFact.removeAbstraction() if (nonAbstractAp != null) { fieldWrite( instance, accessors, assignFrom, nonAbstractAp, @@ -563,15 +574,26 @@ class JIRMethodSequentFlowFunction( ) } - propagateAbstractFactWithFieldExcluded(factAp, accessor, propagateFactWithAccessorExclude) + propagateAbstractFactWithFieldExcluded(fixedFact, accessor, propagateFactWithAccessorExclude) return } - check(factAp.startsWithAccessor(accessor)) + irrelevant.forEach { fact -> propagateFact(fact) } + aliased.forEach { (fact, alias) -> + check(fact.startsWithAccessor(accessor)) - val newAp = factAp.clearField(accessor) ?: return - propagateFact(newAp) + val hasElementAccessor = alias?.accessors.orEmpty().any { it is JIRLocalAliasAnalysis.AliasAccessor.Array } + val newAp = + // todo hack: keep fact on the array elements + if (hasElementAccessor) fact + else fact.clearField(accessor) + + newAp?.let { + val restoredFact = FactUtils.rewriteForAlias(it, alias) + propagateFact(restoredFact) + } + } } private fun propagateAbstractFactWithFieldExcluded( diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/taint/JIRBasicAtomEvaluator.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/taint/JIRBasicAtomEvaluator.kt index 1eb0bc8c5..7fe4b297e 100644 --- 
a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/taint/JIRBasicAtomEvaluator.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/taint/JIRBasicAtomEvaluator.kt @@ -294,12 +294,16 @@ class JIRBasicAtomEvaluator( val base = AccessPathBase.LocalVar(lv.index) if (!matchArrayValue) { - // todo: use must alias if negated - val aliasInfo = aa.findAlias(base, statement) + val aliasInfo = + if (negated) + aa.findMustAlias(base, statement) + else + aa.findAlias(base, statement) if (aliasInfo != null) { body(aliasInfo) } } else { + // should `negated` change behaviour here like in the if-case? val allAliases = aa.getAllAliasAtStatement(statement) for ((_, aliasSet) in allAliases) { for (info in aliasSet) { diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/BasicTestUtils.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/BasicTestUtils.kt index 5049e6d7f..4515bbd04 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/BasicTestUtils.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/BasicTestUtils.kt @@ -4,13 +4,47 @@ import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.AfterAll import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.TestInstance +import org.opentaint.dataflow.ap.ifds.AccessPathBase +import org.opentaint.dataflow.ap.ifds.access.FactAp +import org.opentaint.dataflow.ap.ifds.access.InitialFactAp +import org.opentaint.dataflow.configuration.jvm.TaintCleaner +import org.opentaint.dataflow.configuration.jvm.TaintEntryPointSource +import org.opentaint.dataflow.configuration.jvm.TaintMethodEntrySink +import org.opentaint.dataflow.configuration.jvm.TaintMethodExitSink +import 
org.opentaint.dataflow.configuration.jvm.TaintMethodExitSource +import org.opentaint.dataflow.configuration.jvm.TaintMethodSink +import org.opentaint.dataflow.configuration.jvm.TaintMethodSource +import org.opentaint.dataflow.configuration.jvm.TaintPassThrough +import org.opentaint.dataflow.configuration.jvm.TaintStaticFieldSource +import org.opentaint.dataflow.ifds.SingletonUnit +import org.opentaint.dataflow.ifds.UnitType +import org.opentaint.dataflow.ifds.UnknownUnit +import org.opentaint.dataflow.jvm.ap.ifds.JIRCallResolver +import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalAliasAnalysis +import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalAliasAnalysis.AliasAccessor +import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalAliasAnalysis.AliasApInfo +import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalVariableReachability +import org.opentaint.dataflow.jvm.ap.ifds.analysis.JIRAnalysisManager +import org.opentaint.dataflow.jvm.ap.ifds.taint.TaintRulesProvider +import org.opentaint.dataflow.jvm.ifds.JIRUnitResolver +import org.opentaint.ir.api.common.CommonMethod +import org.opentaint.ir.api.common.cfg.CommonInst import org.opentaint.ir.api.jvm.JIRClasspath import org.opentaint.ir.api.jvm.JIRDatabase +import org.opentaint.ir.api.jvm.JIRField +import org.opentaint.ir.api.jvm.JIRMethod +import org.opentaint.ir.api.jvm.RegisteredLocation +import org.opentaint.ir.api.jvm.cfg.JIRCallInst +import org.opentaint.ir.api.jvm.cfg.JIRInst +import org.opentaint.ir.api.jvm.cfg.JIRLocalVar +import org.opentaint.ir.api.jvm.cfg.JIRValue import org.opentaint.ir.impl.JIRRamErsSettings import org.opentaint.ir.impl.features.InMemoryHierarchy import org.opentaint.ir.impl.features.Usages import org.opentaint.ir.impl.features.classpaths.UnknownClasses +import org.opentaint.ir.impl.features.usagesExt import org.opentaint.ir.impl.opentaintIrDb +import org.opentaint.jvm.graph.JApplicationGraphImpl import java.nio.file.Path import kotlin.io.path.Path @@ -20,6 +54,72 @@ abstract class BasicTestUtils { 
protected lateinit var db: JIRDatabase protected lateinit var cp: JIRClasspath + private val noRules = object : TaintRulesProvider { + override fun entryPointRulesForMethod( + method: CommonMethod, + fact: FactAp?, + allRelevant: Boolean + ): Iterable = emptyList() + + override fun sourceRulesForMethod( + method: CommonMethod, + statement: CommonInst, + fact: FactAp?, + allRelevant: Boolean + ): Iterable = emptyList() + + override fun exitSourceRulesForMethod( + method: CommonMethod, + statement: CommonInst, + fact: FactAp?, + allRelevant: Boolean + ): Iterable = emptyList() + + override fun sinkRulesForMethod( + method: CommonMethod, + statement: CommonInst, + fact: FactAp?, + allRelevant: Boolean + ): Iterable = emptyList() + + override fun sinkRulesForMethodEntry( + method: CommonMethod, + fact: FactAp?, + allRelevant: Boolean + ): Iterable = emptyList() + + override fun sinkRulesForMethodExit( + method: CommonMethod, + statement: CommonInst, + fact: FactAp?, + initialFacts: Set?, + allRelevant: Boolean + ): Iterable = emptyList() + + override fun passTroughRulesForMethod( + method: CommonMethod, + statement: CommonInst, + fact: FactAp?, + allRelevant: Boolean + ): Iterable = emptyList() + + override fun cleanerRulesForMethod( + method: CommonMethod, + statement: CommonInst, + fact: FactAp?, + allRelevant: Boolean + ): Iterable = emptyList() + + override fun sourceRulesForStaticField( + field: JIRField, + statement: CommonInst, + fact: FactAp?, + allRelevant: Boolean + ): Iterable = emptyList() + } + + protected val manager by lazy { JIRAnalysisManager(cp, noRules) } + @BeforeAll fun setup() { val jarPath = System.getenv("TEST_SAMPLES_JAR") @@ -59,4 +159,83 @@ abstract class BasicTestUtils { protected fun findMethod(className: String, methodName: String) = findClass(className).declaredMethods.find { it.name == methodName } ?: error("Method $methodName not found in $className") + + protected fun aaForMethod(method: JIRMethod): JIRLocalAliasAnalysis { + val ep = 
method.instList.first() + val usages = runBlocking { cp.usagesExt() } + val graph = JApplicationGraphImpl(cp, usages) + + val callResolver = JIRCallResolver(cp, SingleLocationUnit(method.enclosingClass.declaration.location)) + val localReachability = JIRLocalVariableReachability(method, graph, manager) + + val params = loadSettings(method) + + return JIRLocalAliasAnalysis(ep, graph, callResolver, localReachability, manager, params) + } + + private fun loadSettings(method: JIRMethod): JIRLocalAliasAnalysis.Params { + val settings = method.annotations.find { it.name == ALIAS_SETTINGS } + val callDepth = settings?.values[INTER_PROC_SETTING] as? Int + ?: return JIRLocalAliasAnalysis.Params() + return interProcParams(callDepth) + } + + protected fun interProcParams(depth: Int) = + JIRLocalAliasAnalysis.Params(useAliasAnalysis = true, aliasAnalysisInterProcCallDepth = depth) + + protected fun JIRMethod.findSinkCall(sinkName: String): JIRCallInst = + instList.filterIsInstance().first { it.callExpr.method.name == sinkName } + + protected fun JIRLocalAliasAnalysis.valueApAliases(value: JIRValue, stmt: JIRInst): List = + valueAliases(value, stmt).filterIsInstance() + + protected fun JIRLocalAliasAnalysis.sinkArgApAliases(sink: JIRCallInst): List = + valueApAliases(sink.callExpr.args[0], sink) + + abstract fun JIRLocalAliasAnalysis.getAliases( + base: AccessPathBase.LocalVar, + statement: JIRInst + ): List? 
+ + protected fun JIRLocalAliasAnalysis.valueAliases( + value: JIRValue, + stmt: JIRInst + ): List { + check(value is JIRLocalVar) { "Only local var aliases supported" } + return getAliases(AccessPathBase.LocalVar(value.index), stmt).orEmpty() + } + + protected fun AliasApInfo.isPlainBase(expected: AccessPathBase): Boolean = + accessors.isEmpty() && base == expected + + protected fun AliasAccessor.isField(name: String): Boolean = + this is AliasAccessor.Field && this.fieldName == name + + protected fun List.singleFieldNamed(name: String): Boolean = + size == 1 && single().isField(name) + + protected class SingleLocationUnit(val loc: RegisteredLocation) : JIRUnitResolver { + override fun resolve(method: JIRMethod): UnitType = + if (method.enclosingClass.declaration.location == loc) SingletonUnit else UnknownUnit + + override fun locationIsUnknown(loc: RegisteredLocation): Boolean = loc != this.loc + } + + companion object { + const val ALIAS_SAMPLE_PKG = "sample.alias" + const val SIMPLE_SAMPLE = "$ALIAS_SAMPLE_PKG.SimpleAliasSample" + const val LOOP_SAMPLE = "$ALIAS_SAMPLE_PKG.LoopAliasSample" + const val HEAP_SAMPLE = "$ALIAS_SAMPLE_PKG.HeapAliasSample" + const val COMBINED_HEAP_SAMPLE = "$ALIAS_SAMPLE_PKG.CombinedHeapAliasSample" + const val INTERPROC_SAMPLE = "$ALIAS_SAMPLE_PKG.InterProcAliasSample" + + private const val ALIAS_SETTINGS = "sample.AliasSettings" + private const val INTER_PROC_SETTING = "interProcDepth" + + protected const val FIELD_VALUE = "value" + protected const val FIELD_BOX = "box" + protected const val FIELD_NEXT = "next" + protected const val FIELD_DATA = "data" + protected const val FIELD_INTERPROC = "field" + } } diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/DSUAliasAnalysisTestUtils.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/DSUAliasAnalysisTestUtils.kt new file mode 100644 index 000000000..f7949cd82 --- /dev/null +++ 
b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/DSUAliasAnalysisTestUtils.kt @@ -0,0 +1,103 @@ +package org.opentaint.dataflow.jvm + +import it.unimi.dsi.fastutil.ints.IntOpenHashSet +import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalAliasAnalysis.AliasAccessor.Field +import org.opentaint.dataflow.jvm.ap.ifds.alias.AAInfo +import org.opentaint.dataflow.jvm.ap.ifds.alias.AAInfoManager +import org.opentaint.dataflow.jvm.ap.ifds.alias.ArrayAlias +import org.opentaint.dataflow.jvm.ap.ifds.alias.ContextInfo +import org.opentaint.dataflow.jvm.ap.ifds.alias.DSUAliasAnalysis +import org.opentaint.dataflow.jvm.ap.ifds.alias.FieldAlias +import org.opentaint.dataflow.jvm.ap.ifds.alias.HeapAlias +import org.opentaint.dataflow.jvm.ap.ifds.alias.LocalAlias +import org.opentaint.dataflow.jvm.ap.ifds.alias.LocalAlias.SimpleLoc +import org.opentaint.dataflow.jvm.ap.ifds.alias.MergeType +import org.opentaint.dataflow.jvm.ap.ifds.alias.RefValue +import org.opentaint.dataflow.jvm.ap.ifds.alias.State +import org.opentaint.dataflow.jvm.ap.ifds.alias.Stmt +import org.opentaint.dataflow.jvm.ap.ifds.alias.Unknown +import java.util.IdentityHashMap +import kotlin.collections.forEach + +abstract class DSUAliasAnalysisTestUtils(protected val mergeType: MergeType) { + + protected val manager = AAInfoManager() + protected val strategy = DSUAliasAnalysis.DsuMergeStrategy(manager) + + protected class StateBuilder( + private val manager: AAInfoManager, + private val strategy: DSUAliasAnalysis.DsuMergeStrategy, + private val mergeType: MergeType, + ) { + private var state = State.empty(manager, strategy) + + private val created = IdentityHashMap() + + fun local(idx: Int): LocalAlias = create( + SimpleLoc(RefValue.Local(idx, ContextInfo.rootContext)) + ) + + fun unknown(originalIdx: Int): Unknown = create( + Unknown(Stmt.Return(value = null, originalIdx = originalIdx), ContextInfo.rootContext) + ) + + fun arrayAlias(instanceInfo: AAInfo) = 
heapAlias(instanceInfo) { i -> HeapAlias(i, ArrayAlias) } + + fun fieldAlias(instanceInfo: AAInfo, fieldName: String) = heapAlias(instanceInfo) { i -> + HeapAlias(i, FieldAlias(Field("Cls", fieldName, "I"), isImmutable = true)) + } + + private fun heapAlias(instance: AAInfo, body: (Int) -> HeapAlias): HeapAlias { + val instanceId = infoId(instance) + val instanceGroupId = state.aliasGroupId(instanceId) + + return create(body(instanceGroupId)) + } + + private fun create(info: T): T { + created[info] = Unit + return info + } + + fun merge(set: Set) { + val setIds = infoIds(set) + state = state.mergeAliasSets(setIds) + } + + fun remove(set: Set) { + val setIds = infoIds(set) + state = state.removeUnsafe(setIds) + } + + private fun infoId(info: AAInfo): Int { + check(created.containsKey(info)) { "$info doesn't belongs to the current state" } + return manager.getOrAdd(info) + } + + private fun infoIds(set: Set): IntOpenHashSet { + val setIds = IntOpenHashSet() + set.forEach { setIds.add(infoId(it)) } + return setIds + } + + fun build(): State = state + + fun mergeStates(vararg builders: StateBuilder) { + val states = builders.map { it.state } + this.state = State.merge(manager, strategy, states, mergeType) + + builders.forEach { + created.putAll(it.created) + } + } + } + + protected inline fun buildState(body: StateBuilder.() -> Unit): State = + fillState(body).build() + + protected inline fun fillState(body: StateBuilder.() -> Unit): StateBuilder { + val builder = StateBuilder(manager, strategy, mergeType) + builder.body() + return builder + } +} diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/DSUAliasAnalysisStateTest.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/DSUMayAliasAnalysisStateTest.kt similarity index 87% rename from 
core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/DSUAliasAnalysisStateTest.kt rename to core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/DSUMayAliasAnalysisStateTest.kt index a42ea18ec..a8632bd55 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/DSUAliasAnalysisStateTest.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/DSUMayAliasAnalysisStateTest.kt @@ -1,93 +1,10 @@ package org.opentaint.dataflow.jvm.ap.ifds.alias -import it.unimi.dsi.fastutil.ints.IntOpenHashSet -import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalAliasAnalysis.AliasAccessor.Field -import org.opentaint.dataflow.jvm.ap.ifds.alias.DSUAliasAnalysis.State -import org.opentaint.dataflow.jvm.ap.ifds.alias.LocalAlias.SimpleLoc -import java.util.IdentityHashMap +import org.opentaint.dataflow.jvm.DSUAliasAnalysisTestUtils import kotlin.test.Test import kotlin.test.assertEquals -class DSUAliasAnalysisStateTest { - - private val manager = AAInfoManager() - private val strategy = DSUAliasAnalysis.DsuMergeStrategy(manager) - - private class StateBuilder( - private val manager: AAInfoManager, - private val strategy: DSUAliasAnalysis.DsuMergeStrategy - ) { - private var state = State.empty(manager, strategy) - - private val created = IdentityHashMap() - - fun local(idx: Int): LocalAlias = create( - SimpleLoc(RefValue.Local(idx, ContextInfo.rootContext)) - ) - - fun unknown(originalIdx: Int): Unknown = create( - Unknown(Stmt.Return(value = null, originalIdx = originalIdx), ContextInfo.rootContext) - ) - - fun arrayAlias(instanceInfo: AAInfo) = heapAlias(instanceInfo) { i -> HeapAlias(i, ArrayAlias) } - - fun fieldAlias(instanceInfo: AAInfo, fieldName: String) = heapAlias(instanceInfo) { i -> - HeapAlias(i, FieldAlias(Field("Cls", fieldName, "I"), 
isImmutable = true)) - } - - private fun heapAlias(instance: AAInfo, body: (Int) -> HeapAlias): HeapAlias { - val instanceId = infoId(instance) - val instanceGroupId = state.aliasGroupId(instanceId) - - return create(body(instanceGroupId)) - } - - private fun create(info: T): T { - created[info] = Unit - return info - } - - fun merge(set: Set) { - val setIds = infoIds(set) - state = state.mergeAliasSets(setIds) - } - - fun remove(set: Set) { - val setIds = infoIds(set) - state = state.removeUnsafe(setIds) - } - - private fun infoId(info: AAInfo): Int { - check(created.containsKey(info)) { "$info doesn't belongs to the current state" } - return manager.getOrAdd(info) - } - - private fun infoIds(set: Set): IntOpenHashSet { - val setIds = IntOpenHashSet() - set.forEach { setIds.add(infoId(it)) } - return setIds - } - - fun build(): State = state - - fun mergeStates(vararg builders: StateBuilder) { - val states = builders.map { it.state } - this.state = State.merge(manager, strategy, states) - - builders.forEach { - created.putAll(it.created) - } - } - } - - private inline fun buildState(body: StateBuilder.() -> Unit): State = - fillState(body).build() - - private inline fun fillState(body: StateBuilder.() -> Unit): StateBuilder { - val builder = StateBuilder(manager, strategy) - builder.body() - return builder - } +class DSUMayAliasAnalysisStateTest : DSUAliasAnalysisTestUtils(MergeType.May) { @Test fun mergeAliasSetsOfTwoLocals() { diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/DSUMustAliasAnalysisStateTest.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/DSUMustAliasAnalysisStateTest.kt new file mode 100644 index 000000000..6e2252416 --- /dev/null +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/DSUMustAliasAnalysisStateTest.kt @@ -0,0 +1,835 @@ +package 
org.opentaint.dataflow.jvm.ap.ifds.alias + +import org.opentaint.dataflow.jvm.DSUAliasAnalysisTestUtils +import kotlin.test.Test +import kotlin.test.assertEquals + +class DSUMustAliasAnalysisStateTest : DSUAliasAnalysisTestUtils(MergeType.Must) { + + @Test + fun mergeAliasSetsOfTwoLocals() { + val set1 = buildState { + merge(setOf(local(0), local(1))) + merge(setOf(local(0), local(1))) + } + + val set2 = buildState { + merge(setOf(local(0), local(1))) + } + + assertEquals(set2, set1) + } + + @Test + fun mergeAliasSetsOfThreeElements() { + val set1 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + merge(setOf(a)) + merge(setOf(b)) + merge(setOf(c)) + merge(setOf(a, b, c)) + } + + val set2 = buildState { + merge(setOf(local(0), local(1), local(2))) + } + + assertEquals(set2, set1) + } + + @Test + fun mergeAliasSetsWithSingleElementIsNoop() { + val set1 = buildState { + val a = local(0) + val b = local(1) + merge(setOf(a)) + merge(setOf(b)) + merge(setOf(a)) + } + + val set2 = buildState { + merge(setOf(local(0))) + merge(setOf(local(1))) + } + + assertEquals(set2, set1) + } + + @Test + fun mergeAliasSetsPreservesDisjointGroups() { + val set1 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + val d = local(3) + merge(setOf(a)) + merge(setOf(b)) + merge(setOf(c)) + merge(setOf(d)) + merge(setOf(a, b)) + } + + val set2 = buildState { + merge(setOf(local(0), local(1))) + merge(setOf(local(2))) + merge(setOf(local(3))) + } + + assertEquals(set2, set1) + } + + @Test + fun removeUnsafeSingleElement() { + val set1 = buildState { + val a = local(0) + val b = local(1) + merge(setOf(a, b)) + remove(setOf(a)) + } + + val set2 = buildState { + merge(setOf(local(1))) + } + + assertEquals(set2, set1) + } + + @Test + fun removeUnsafeEmptySetIsNoop() { + val set1 = buildState { + val a = local(0) + val b = local(1) + merge(setOf(a, b)) + remove(emptySet()) + } + + val set2 = buildState { + merge(setOf(local(0), local(1))) + } + + 
assertEquals(set2, set1) + } + + @Test + fun removeUnsafeAllElements() { + val set1 = buildState { + val a = local(0) + val b = local(1) + merge(setOf(a, b)) + remove(setOf(a, b)) + } + + val set2 = buildState {} + + assertEquals(set2, set1) + } + + @Test + fun removeUnsafeFromMultipleGroups() { + val set1 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + val d = local(3) + merge(setOf(a, b)) + merge(setOf(c, d)) + remove(setOf(a, c)) + } + + val set2 = buildState { + merge(setOf(local(1))) + merge(setOf(local(3))) + } + + assertEquals(set2, set1) + } + + @Test + fun mergeTwoDisjointStates() { + val set1 = buildState { + mergeStates( + fillState { merge(setOf(local(0), local(1))) }, + fillState { merge(setOf(local(2), local(3))) }, + ) + } + + val set2 = buildState {} + + assertEquals(set2, set1) + } + + @Test + fun mergeOverlappingStates() { + val set1 = buildState { + mergeStates( + fillState { merge(setOf(local(0), local(1))) }, + fillState { merge(setOf(local(1), local(2))) }, + ) + } + + val set2 = buildState {} + + assertEquals(set2, set1) + } + + @Test + fun mergeSingleState() { + val set1 = buildState { + mergeStates( + fillState { merge(setOf(local(0), local(1))) }, + ) + } + + val set2 = buildState { + merge(setOf(local(0), local(1))) + } + + assertEquals(set2, set1) + } + + @Test + fun mergeAliasSetsAfterRemoveUnsafe() { + val set1 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + val d = local(3) + merge(setOf(a, b)) + merge(setOf(c, d)) + remove(setOf(b)) + merge(setOf(a, c)) + } + + val set2 = buildState { + merge(setOf(local(0), local(2), local(3))) + } + + assertEquals(set2, set1) + } + + @Test + fun removeUnsafeAfterMergeAliasSets() { + val set1 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + merge(setOf(a)) + merge(setOf(b)) + merge(setOf(c)) + merge(setOf(a, b, c)) + remove(setOf(b)) + } + + val set2 = buildState { + merge(setOf(local(0), local(2))) + } + + 
assertEquals(set2, set1) + } + + @Test + fun mergeStatesFollowedByMergeAliasSets() { + val set1 = buildState { + mergeStates( + fillState { merge(setOf(local(0), local(1))) }, + fillState { merge(setOf(local(2), local(3))) }, + ) + merge(setOf(local(1), local(2))) + } + + val set2 = buildState { + merge(setOf(local(1), local(2))) + } + + assertEquals(set2, set1) + } + + @Test + fun mergeStatesFollowedByRemoveUnsafe() { + val set1 = buildState { + mergeStates( + fillState { merge(setOf(local(0), local(1))) }, + fillState { merge(setOf(local(2), local(3))) }, + ) + remove(setOf(local(0), local(3))) + } + + val set2 = buildState { + merge(setOf(local(1))) + merge(setOf(local(2))) + } + + assertEquals(set2, set1) + } + + @Test + fun mergeAliasSetsWithHeapAliases() { + val set1 = buildState { + val loc = local(0) + val arr = arrayAlias(loc) + val c = local(5) + merge(setOf(loc, arr)) + merge(setOf(c)) + merge(setOf(loc, c)) + } + + val set2 = buildState { + val loc = local(0) + val arr = arrayAlias(loc) + merge(setOf(loc, arr, local(5))) + } + + assertEquals(set2, set1) + } + + @Test + fun removeUnsafeWithFieldAliases() { + val set1 = buildState { + val loc = local(0) + val f = fieldAlias(loc, "x") + val c = local(1) + merge(setOf(loc, f, c)) + remove(setOf(f)) + } + + val set2 = buildState { + merge(setOf(local(0), local(1))) + } + + assertEquals(set2, set1) + } + + @Test + fun mergeStatesWithUnknownAliases() { + val set1 = buildState { + mergeStates( + fillState { merge(setOf(local(0), unknown(0), unknown(1))) }, + fillState { merge(setOf(local(1), unknown(0), unknown(1))) }, + ) + } + + val set2 = buildState { + merge(setOf(unknown(0), unknown(1))) + } + + assertEquals(set2, set1) + } + + @Test + fun chainedMergeRemoveMerge() { + val set1 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + val d = local(3) + val e = local(4) + merge(setOf(a)) + merge(setOf(b)) + merge(setOf(c)) + merge(setOf(d)) + merge(setOf(e)) + merge(setOf(a, b)) + 
remove(setOf(c)) + merge(setOf(d, e)) + } + + val set2 = buildState { + merge(setOf(local(0), local(1))) + merge(setOf(local(3), local(4))) + } + + assertEquals(set2, set1) + } + + @Test + fun mergeThreeOverlappingStates() { + val set1 = buildState { + mergeStates( + fillState { merge(setOf(local(0), local(1), local(2), local(4))) }, + fillState { merge(setOf(local(0), local(1), local(2), local(5))) }, + fillState { merge(setOf(local(0), local(1), local(2), local(3))) }, + ) + } + + val set2 = buildState { + merge(setOf(local(0), local(1), local(2))) + } + + assertEquals(set2, set1) + } + + @Test + fun mergeStatesRemoveAndMergeAliasSets() { + val set1 = buildState { + mergeStates( + fillState { + merge(setOf(local(0), local(2))) + merge(setOf(local(1))) + }, + fillState { merge(setOf(local(0), local(1), local(2))) }, + ) + remove(setOf(local(2))) + merge(setOf(local(0), local(3))) + } + + val set2 = buildState { + merge(setOf(local(0), local(3))) + } + + assertEquals(set2, set1) + } + + @Test + fun mergeEmptyStates() { + val set1 = buildState { + mergeStates( + fillState {}, + fillState {}, + ) + } + + val set2 = buildState {} + + assertEquals(set2, set1) + } + + @Test + fun mergeAliasSetsOnEmptySetIsNoop() { + val set1 = buildState { + val a = local(0) + val b = local(1) + merge(setOf(a)) + merge(setOf(b)) + merge(emptySet()) + } + + val set2 = buildState { + merge(setOf(local(0))) + merge(setOf(local(1))) + } + + assertEquals(set2, set1) + } + + @Test + fun removeUnsafeThenMergeThenRemove() { + val set1 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + val d = local(3) + val e = local(4) + merge(setOf(a, b)) + merge(setOf(c)) + merge(setOf(d)) + merge(setOf(e)) + remove(setOf(a)) + merge(setOf(b, c)) + remove(setOf(d)) + } + + val set2 = buildState { + merge(setOf(local(1), local(2))) + merge(setOf(local(4))) + } + + assertEquals(set2, set1) + } + + @Test + fun deepHeapChainMerge() { + val set1 = buildState { + val loc = local(10) + val d1 
= arrayAlias(loc) + val d2 = arrayAlias(d1) + val d3 = arrayAlias(d2) + merge(setOf(loc, d1)) + merge(setOf(d2, d3)) + merge(setOf(loc, d2)) + } + + val set2 = buildState { + val loc = local(10) + val d1 = arrayAlias(loc) + val d2 = arrayAlias(d1) + val d3 = arrayAlias(d2) + merge(setOf(loc, d1, d2, d3)) + } + + assertEquals(set2, set1) + } + + @Test + fun deepMixedHeapChainMergeAndRemove() { + val set1 = buildState { + val loc = local(30) + val arr1 = arrayAlias(loc) + val fld2 = fieldAlias(arr1, "f") + val arr3 = arrayAlias(fld2) + val e = local(99) + merge(setOf(loc, arr1)) + merge(setOf(fld2, arr3)) + merge(setOf(e)) + merge(setOf(arr1, fld2)) + remove(setOf(e)) + } + + val set2 = buildState { + val loc = local(30) + val arr1 = arrayAlias(loc) + val fld2 = fieldAlias(arr1, "f") + val arr3 = arrayAlias(fld2) + merge(setOf(loc, arr1, fld2, arr3)) + } + + assertEquals(set2, set1) + } + + @Test + fun deepHeapChainStateMergeConnects() { + val set1 = buildState { + mergeStates( + fillState { + val loc = local(40) + val h1 = arrayAlias(loc) + merge(setOf(loc, h1)) + }, + fillState { + val loc = local(40) + val h1 = arrayAlias(loc) + val h2 = arrayAlias(h1) + val h3 = arrayAlias(h2) + merge(setOf(loc, h1, h2, h3)) + }, + ) + } + + val set2 = buildState { + val loc = local(40) + val h1 = arrayAlias(loc) + merge(setOf(loc, h1)) + } + + assertEquals(set2, set1) + } + + @Test + fun equalityAfterDifferentMergeOrder() { + val set1 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + val d = local(3) + merge(setOf(a)) + merge(setOf(b)) + merge(setOf(c)) + merge(setOf(d)) + merge(setOf(a, b)) + merge(setOf(c, d)) + } + + val set2 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + val d = local(3) + merge(setOf(a)) + merge(setOf(b)) + merge(setOf(c)) + merge(setOf(d)) + merge(setOf(c, d)) + merge(setOf(a, b)) + } + + assertEquals(set1, set2) + } + + @Test + fun equalityMergeAllAtOnceVsIncrementally() { + val set1 = buildState { + val 
a = local(0) + val b = local(1) + val c = local(2) + merge(setOf(a)) + merge(setOf(b)) + merge(setOf(c)) + merge(setOf(a, b, c)) + } + + val set2 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + merge(setOf(a)) + merge(setOf(b)) + merge(setOf(c)) + merge(setOf(a, b)) + merge(setOf(b, c)) + } + + assertEquals(set1, set2) + } + + @Test + fun equalityRemoveThenMergeVsMergeFiltered() { + val set1 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + val d = local(3) + merge(setOf(a, b)) + merge(setOf(c, d)) + remove(setOf(b, d)) + merge(setOf(a, c)) + } + + val set2 = buildState { + merge(setOf(local(0), local(2))) + } + + assertEquals(set1, set2) + } + + @Test + fun equalityStateMergeVsManualMergeAliasSets() { + val set1 = buildState { + mergeStates( + fillState { merge(setOf(local(0), local(1), local(2), local(3))) }, + fillState { merge(setOf(local(0), local(1), local(2), local(4))) }, + ) + } + + val set2 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + merge(setOf(a)) + merge(setOf(b)) + merge(setOf(c)) + merge(setOf(a, b)) + merge(setOf(b, c)) + } + + assertEquals(set1, set2) + } + + @Test + fun equalityMergeTwoWaysWithDeepHeap() { + val set1 = buildState { + mergeStates( + fillState { + val loc = local(90) + val h1 = arrayAlias(loc) + val h2 = fieldAlias(h1, "q") + val h3 = arrayAlias(h2) + val h4 = arrayAlias(h3) + merge(setOf(loc, h1, h2, h3, h4, local(92))) + }, + fillState { + val loc = local(90) + val h1 = arrayAlias(loc) + val h2 = fieldAlias(h1, "q") + val h3 = arrayAlias(h2) + val h5 = fieldAlias(h3, "q") + merge(setOf(loc, h1, h2, h3, h5, local(91))) + }, + ) + } + + val set2 = buildState { + val loc = local(90) + val h1 = arrayAlias(loc) + val h2 = fieldAlias(h1, "q") + val h3 = arrayAlias(h2) + merge(setOf(loc, h1, h2, h3)) + } + + assertEquals(set1, set2) + } + + @Test + fun equalityChainedOpsVsDirectConstruction() { + val set1 = buildState { + val a = local(0) + val b = local(1) + 
val c = local(2) + val d = local(3) + val e = local(4) + val f = local(5) + merge(setOf(a)) + merge(setOf(b)) + merge(setOf(c)) + merge(setOf(d)) + merge(setOf(e)) + merge(setOf(f)) + merge(setOf(a, b, c)) + remove(setOf(d)) + merge(setOf(e, f)) + } + + val set2 = buildState { + merge(setOf(local(0), local(1), local(2))) + merge(setOf(local(4), local(5))) + } + + assertEquals(set1, set2) + } + + @Test + fun equalityMutableCopyChainVsOriginalChain() { + val set1 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + merge(setOf(a)) + merge(setOf(b)) + merge(setOf(c)) + merge(setOf(a, b)) + remove(setOf(c)) + } + + val set2 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + merge(setOf(a)) + merge(setOf(b)) + merge(setOf(c)) + merge(setOf(a, b)) + remove(setOf(c)) + } + + assertEquals(set1, set2) + } + + @Test + fun equalityRemoveOrderIndependence() { + val set1 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + val d = local(3) + merge(setOf(a, b, c, d)) + remove(setOf(a)) + remove(setOf(c)) + } + + val set2 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + val d = local(3) + merge(setOf(a, b, c, d)) + remove(setOf(c)) + remove(setOf(a)) + } + + assertEquals(set1, set2) + } + + @Test + fun equalityRemoveAtOnceVsOneByOne() { + val set1 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + val d = local(3) + merge(setOf(a, b, c, d)) + remove(setOf(a, c)) + } + + val set2 = buildState { + val a = local(0) + val b = local(1) + val c = local(2) + val d = local(3) + merge(setOf(a, b, c, d)) + remove(setOf(a)) + remove(setOf(c)) + } + + assertEquals(set1, set2) + } + + @Test + fun equalityDeepChainMergeStateThenRemoveVsBuildDirect() { + val set1 = buildState { + mergeStates( + fillState { + val loc = local(140) + val f1 = fieldAlias(loc, "a") + val f2 = fieldAlias(f1, "b") + val f3 = fieldAlias(f2, "c") + merge(setOf(loc, f1, f2, f3, local(141))) + }, + fillState { 
+ val loc = local(140) + val f1 = fieldAlias(loc, "a") + val f2 = fieldAlias(f1, "b") + val f3 = fieldAlias(f2, "c") + merge(setOf(loc, f1, f2, f3, local(141))) + }, + ) + remove(setOf(local(141))) + } + + val set2 = buildState { + val loc = local(140) + val f1 = fieldAlias(loc, "a") + val f2 = fieldAlias(f1, "b") + val f3 = fieldAlias(f2, "c") + merge(setOf(loc, f1, f2, f3)) + } + + assertEquals(set1, set2) + } + + @Test + fun equalityThreeStateMergeVsTwoStepMerge() { + val set1 = buildState { + mergeStates( + fillState { merge(setOf(local(0), local(1))) }, + fillState { merge(setOf(local(1), local(2))) }, + fillState { merge(setOf(local(2), local(3))) }, + ) + } + + val set2 = buildState { + val merged12 = fillState { + mergeStates( + fillState { merge(setOf(local(0), local(1))) }, + fillState { merge(setOf(local(1), local(2))) }, + ) + } + mergeStates( + merged12, + fillState { merge(setOf(local(2), local(3))) }, + ) + } + + assertEquals(set1, set2) + } + + @Test + fun equalityMergeOrderDoesNotMatter() { + val set1 = buildState { + mergeStates( + fillState { merge(setOf(local(0), local(1))) }, + fillState { merge(setOf(local(2), local(3))) }, + fillState { merge(setOf(local(1), local(2))) }, + ) + } + + val set2 = buildState { + mergeStates( + fillState { merge(setOf(local(1), local(2))) }, + fillState { merge(setOf(local(0), local(1))) }, + fillState { merge(setOf(local(2), local(3))) }, + ) + } + + assertEquals(set1, set2) + } +} diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/AliasSampleTest.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/MayAliasSampleTest.kt similarity index 67% rename from core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/AliasSampleTest.kt rename to 
core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/MayAliasSampleTest.kt index 8f069b07d..babc2e3c3 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/AliasSampleTest.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/MayAliasSampleTest.kt @@ -1,114 +1,22 @@ package org.opentaint.dataflow.jvm.ap.ifds.alias -import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.TestInstance import org.opentaint.dataflow.ap.ifds.AccessPathBase import org.opentaint.dataflow.ap.ifds.AccessPathBase.Companion.Argument -import org.opentaint.dataflow.ap.ifds.access.FactAp -import org.opentaint.dataflow.ap.ifds.access.InitialFactAp -import org.opentaint.dataflow.configuration.jvm.TaintCleaner -import org.opentaint.dataflow.configuration.jvm.TaintEntryPointSource -import org.opentaint.dataflow.configuration.jvm.TaintMethodEntrySink -import org.opentaint.dataflow.configuration.jvm.TaintMethodExitSink -import org.opentaint.dataflow.configuration.jvm.TaintMethodExitSource -import org.opentaint.dataflow.configuration.jvm.TaintMethodSink -import org.opentaint.dataflow.configuration.jvm.TaintMethodSource -import org.opentaint.dataflow.configuration.jvm.TaintPassThrough -import org.opentaint.dataflow.configuration.jvm.TaintStaticFieldSource -import org.opentaint.dataflow.ifds.SingletonUnit -import org.opentaint.dataflow.ifds.UnitType -import org.opentaint.dataflow.ifds.UnknownUnit import org.opentaint.dataflow.jvm.BasicTestUtils -import org.opentaint.dataflow.jvm.ap.ifds.JIRCallResolver import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalAliasAnalysis import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalAliasAnalysis.AliasAccessor -import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalAliasAnalysis.AliasApInfo -import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalVariableReachability -import 
org.opentaint.dataflow.jvm.ap.ifds.analysis.JIRAnalysisManager -import org.opentaint.dataflow.jvm.ap.ifds.taint.TaintRulesProvider -import org.opentaint.dataflow.jvm.ifds.JIRUnitResolver -import org.opentaint.ir.api.common.CommonMethod -import org.opentaint.ir.api.common.cfg.CommonInst -import org.opentaint.ir.api.jvm.JIRField -import org.opentaint.ir.api.jvm.JIRMethod -import org.opentaint.ir.api.jvm.RegisteredLocation -import org.opentaint.ir.api.jvm.cfg.JIRCallInst import org.opentaint.ir.api.jvm.cfg.JIRInst -import org.opentaint.ir.api.jvm.cfg.JIRLocalVar -import org.opentaint.ir.api.jvm.cfg.JIRValue -import org.opentaint.ir.impl.features.usagesExt -import org.opentaint.jvm.graph.JApplicationGraphImpl import kotlin.test.Test import kotlin.test.assertFalse import kotlin.test.assertTrue @TestInstance(TestInstance.Lifecycle.PER_CLASS) -class AliasSampleTest : BasicTestUtils() { - private val noRules = object : TaintRulesProvider { - override fun entryPointRulesForMethod( - method: CommonMethod, - fact: FactAp?, - allRelevant: Boolean - ): Iterable = emptyList() - - override fun sourceRulesForMethod( - method: CommonMethod, - statement: CommonInst, - fact: FactAp?, - allRelevant: Boolean - ): Iterable = emptyList() - - override fun exitSourceRulesForMethod( - method: CommonMethod, - statement: CommonInst, - fact: FactAp?, - allRelevant: Boolean - ): Iterable = emptyList() - - override fun sinkRulesForMethod( - method: CommonMethod, - statement: CommonInst, - fact: FactAp?, - allRelevant: Boolean - ): Iterable = emptyList() - - override fun sinkRulesForMethodEntry( - method: CommonMethod, - fact: FactAp?, - allRelevant: Boolean - ): Iterable = emptyList() - - override fun sinkRulesForMethodExit( - method: CommonMethod, - statement: CommonInst, - fact: FactAp?, - initialFacts: Set?, - allRelevant: Boolean - ): Iterable = emptyList() - - override fun passTroughRulesForMethod( - method: CommonMethod, - statement: CommonInst, - fact: FactAp?, - allRelevant: Boolean - ): 
Iterable = emptyList() - - override fun cleanerRulesForMethod( - method: CommonMethod, - statement: CommonInst, - fact: FactAp?, - allRelevant: Boolean - ): Iterable = emptyList() - - override fun sourceRulesForStaticField( - field: JIRField, - statement: CommonInst, - fact: FactAp?, - allRelevant: Boolean - ): Iterable = emptyList() - } - - private val manager by lazy { JIRAnalysisManager(cp, noRules) } +class MayAliasSampleTest : BasicTestUtils() { + override fun JIRLocalAliasAnalysis.getAliases( + base: AccessPathBase.LocalVar, + statement: JIRInst + ): List? = findAlias(base, statement) @Test fun `test simple aliasing`() { @@ -494,7 +402,7 @@ class AliasSampleTest : BasicTestUtils() { @Test fun `test getter aliases this field`() { val method = findMethod(INTERPROC_SAMPLE, "testGetterAlias") - val aa = aaForMethod(method, interProcParams(depth = 1)) + val aa = aaForMethod(method) val sink = method.findSinkCall("sinkOneValue") val apAliases = aa.sinkArgApAliases(sink) @@ -509,7 +417,7 @@ class AliasSampleTest : BasicTestUtils() { @Test fun `test setter then getter`() { val method = findMethod(INTERPROC_SAMPLE, "testSetterThenGetter") - val aa = aaForMethod(method, interProcParams(depth = 1)) + val aa = aaForMethod(method) val sink = method.findSinkCall("sinkOneValue") val apAliases = aa.sinkArgApAliases(sink) @@ -520,7 +428,7 @@ class AliasSampleTest : BasicTestUtils() { @Test fun `test identity same-class call`() { val method = findMethod(INTERPROC_SAMPLE, "testIdentityCall") - val aa = aaForMethod(method, interProcParams(depth = 1)) + val aa = aaForMethod(method) val sink = method.findSinkCall("sinkOneValue") val apAliases = aa.sinkArgApAliases(sink) @@ -550,67 +458,137 @@ class AliasSampleTest : BasicTestUtils() { assertFalse { apAliases.any { it.isPlainBase(Argument(0)) } } } - private fun aaForMethod( - method: JIRMethod, - params: JIRLocalAliasAnalysis.Params = JIRLocalAliasAnalysis.Params() - ): JIRLocalAliasAnalysis { - val ep = method.instList.first() - 
val usages = runBlocking { cp.usagesExt() } - val graph = JApplicationGraphImpl(cp, usages) + @Test + fun `test combined write arg then touch heap`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "writeArgThenTouchHeap") + val aa = aaForMethod(method) - val callResolver = JIRCallResolver(cp, SingleLocationUnit(method.enclosingClass.declaration.location)) - val localReachability = JIRLocalVariableReachability(method, graph, manager) + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) - return JIRLocalAliasAnalysis(ep, graph, callResolver, localReachability, manager, params) + assertTrue { apAliases.any { it.isPlainBase(Argument(1)) } } + assertFalse { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } } - private fun interProcParams(depth: Int) = - JIRLocalAliasAnalysis.Params(useAliasAnalysis = true, aliasAnalysisInterProcCallDepth = depth) + @Test + fun `test combined return argument field`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "returnArgField") + val aa = aaForMethod(method) - private fun JIRMethod.findSinkCall(sinkName: String): JIRCallInst = - instList.filterIsInstance().first { it.callExpr.method.name == sinkName } + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) - private fun JIRLocalAliasAnalysis.valueApAliases(value: JIRValue, stmt: JIRInst): List = - valueAliases(value, stmt).filterIsInstance() + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } - private fun JIRLocalAliasAnalysis.sinkArgApAliases(sink: JIRCallInst): List = - valueApAliases(sink.callExpr.args[0], sink) + @Test + fun `test combined return identity then write field`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "returnIdentityThenWriteField") + val aa = aaForMethod(method) - private fun JIRLocalAliasAnalysis.valueAliases( - value: JIRValue, - stmt: JIRInst - ): List { - 
check(value is JIRLocalVar) { "Only local var aliases supported" } - return findAlias(AccessPathBase.LocalVar(value.index), stmt).orEmpty() + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { apAliases.any { it.isPlainBase(Argument(1)) } } + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } } - private fun AliasApInfo.isPlainBase(expected: AccessPathBase): Boolean = - accessors.isEmpty() && base == expected + @Test + fun `test combined fresh object carries returned arg`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "freshObjectCarriesReturnedArg") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) - private fun AliasAccessor.isField(name: String): Boolean = - this is AliasAccessor.Field && this.fieldName == name + assertTrue { apAliases.any { it.isPlainBase(Argument(0)) } } + } - private fun List.singleFieldNamed(name: String): Boolean = - size == 1 && single().isField(name) + @Test + fun `test combined fresh object copies argument field`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "freshObjectCopiesArgumentField") + val aa = aaForMethod(method) - private class SingleLocationUnit(val loc: RegisteredLocation) : JIRUnitResolver { - override fun resolve(method: JIRMethod): UnitType = - if (method.enclosingClass.declaration.location == loc) SingletonUnit else UnknownUnit + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) - override fun locationIsUnknown(loc: RegisteredLocation): Boolean = loc != this.loc + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } } - companion object { - const val ALIAS_SAMPLE_PKG = "sample.alias" - const val SIMPLE_SAMPLE = "$ALIAS_SAMPLE_PKG.SimpleAliasSample" - const val LOOP_SAMPLE = "$ALIAS_SAMPLE_PKG.LoopAliasSample" - const val 
HEAP_SAMPLE = "$ALIAS_SAMPLE_PKG.HeapAliasSample" - const val INTERPROC_SAMPLE = "$ALIAS_SAMPLE_PKG.InterProcAliasSample" + @Test + fun `test combined pass through receiver then read field`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "passThroughReceiverThenReadField") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } + + @Test + fun `test combined nested write return and touch heap`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "nestedWriteReturnAndTouchHeap") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { apAliases.any { it.isPlainBase(Argument(1)) } } + assertFalse { apAliases.any { it.accessors.isNotEmpty() } } + } + + @Test + fun `test combined overwrite field with fresh object`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "overwriteFieldWithFreshObject") + val aa = aaForMethod(method) - private const val FIELD_VALUE = "value" - private const val FIELD_BOX = "box" - private const val FIELD_NEXT = "next" - private const val FIELD_DATA = "data" - private const val FIELD_INTERPROC = "field" + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + // note: may want to remove this alias link for may analysis in the future + assertTrue { apAliases.any { it.isPlainBase(Argument(1)) } } + + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } + + @Test + fun `test combined return fresh box then alias field`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "returnFreshBoxThenAliasField") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { apAliases.any { 
it.isPlainBase(Argument(1)) } } + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } } } diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/MayAndMustRelationTest.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/MayAndMustRelationTest.kt new file mode 100644 index 000000000..6b338d2ae --- /dev/null +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/MayAndMustRelationTest.kt @@ -0,0 +1,56 @@ +package org.opentaint.dataflow.jvm.ap.ifds.alias + +import org.junit.jupiter.api.TestInstance +import org.opentaint.dataflow.ap.ifds.AccessPathBase +import org.opentaint.dataflow.jvm.BasicTestUtils +import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalAliasAnalysis +import org.opentaint.ir.api.jvm.JIRMethod +import org.opentaint.ir.api.jvm.cfg.JIRCallInst +import org.opentaint.ir.api.jvm.cfg.JIRInst +import org.opentaint.ir.api.jvm.cfg.JIRLocalVar +import kotlin.test.Test +import kotlin.test.fail + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class MayAndMustRelationTest : BasicTestUtils() { + override fun JIRLocalAliasAnalysis.getAliases( + base: AccessPathBase.LocalVar, + statement: JIRInst + ): List = error("unreachable") + + @Test + fun `check must alias inclusion in may alias`() { + val allClasses = cp.locations.flatMap { it.classNames.orEmpty() } + val sampleClasses = allClasses.filter { it.startsWith(ALIAS_SAMPLE_PKG) } + val methods = sampleClasses.flatMap { findClass(it).declaredMethods } + + methods.forEach { method -> + val aa = aaForMethod(method) + val sink = method.findOneOfSinkCalls() ?: return@forEach + + sink.callExpr.args.filterIsInstance().forEach { arg -> + val simpleLoc = AccessPathBase.LocalVar(arg.index) + + val mayResult = aa.findAlias(simpleLoc, sink).orEmpty().toSet() + val mustResult = 
aa.findMustAlias(simpleLoc, sink).orEmpty() + + mustResult.forEach { alias -> + if (alias !in mayResult) { + fail("Must alias diverged with May at `${method.enclosingClass.name}.${method.name}`!") + } + } + } + } + } + + private fun JIRMethod.findOneOfSinkCalls(): JIRCallInst? = + instList.filterIsInstance().firstOrNull { it.callExpr.method.name in SINKS } + + companion object { + private val SINKS = listOf( + "sinkOneValue", + "sinkTwoValues", + "testSimpleArgAlias", + ) + } +} diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/MustAliasSampleTest.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/MustAliasSampleTest.kt new file mode 100644 index 000000000..a1d04f095 --- /dev/null +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/test/kotlin/org/opentaint/dataflow/jvm/ap/ifds/alias/MustAliasSampleTest.kt @@ -0,0 +1,589 @@ +package org.opentaint.dataflow.jvm.ap.ifds.alias + +import org.junit.jupiter.api.TestInstance +import org.opentaint.dataflow.ap.ifds.AccessPathBase +import org.opentaint.dataflow.ap.ifds.AccessPathBase.Companion.Argument +import org.opentaint.dataflow.jvm.BasicTestUtils +import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalAliasAnalysis +import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalAliasAnalysis.AliasAccessor +import org.opentaint.ir.api.jvm.cfg.JIRInst +import kotlin.test.Test +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class MustAliasSampleTest : BasicTestUtils() { + override fun JIRLocalAliasAnalysis.getAliases( + base: AccessPathBase.LocalVar, + statement: JIRInst + ): List? 
= findMustAlias(base, statement) + + @Test + fun `test simple aliasing`() { + val method = findMethod(SIMPLE_SAMPLE, "simpleArgAlias") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("testSimpleArgAlias") + val apAliases = aa.sinkArgApAliases(sink) + + assertFalse { apAliases.any { it.isPlainBase(Argument(0)) } } + assertFalse { apAliases.any { it.isPlainBase(Argument(1)) } } + } + + @Test + fun `test alias in while loop`() { + val method = findMethod(LOOP_SAMPLE, "aliasInLoop") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertFalse { apAliases.any { it.isPlainBase(Argument(0)) } } + assertFalse { apAliases.any { it.isPlainBase(Argument(1)) } } + } + + @Test + fun `test alias in for-each loop`() { + val method = findMethod(LOOP_SAMPLE, "aliasInForEachLoop") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertFalse { apAliases.any { it.isPlainBase(Argument(0)) } } + } + + @Test + fun `test alias in try-catch both branches`() { + val method = findMethod(LOOP_SAMPLE, "aliasInTryCatch") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertFalse { apAliases.any { it.isPlainBase(Argument(0)) } } + assertFalse { apAliases.any { it.isPlainBase(Argument(1)) } } + } + + @Test + fun `test alias in try only`() { + val method = findMethod(LOOP_SAMPLE, "aliasInTryOnly") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { apAliases.any { it.isPlainBase(Argument(0)) } } + } + + @Test + fun `test node next loop - loop diverges field chain`() { + val method = findMethod(LOOP_SAMPLE, "nodeNextLoop") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + 
assertFalse { apAliases.any { it.isPlainBase(Argument(0)) } } + assertFalse { apAliases.any { it.base == Argument(0) && it.accessors.isNotEmpty() } } + } + + @Test + fun `test node next loop data - loop diverges field chain`() { + val method = findMethod(LOOP_SAMPLE, "nodeNextLoopData") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertFalse { + apAliases.any { + it.base == Argument(0) + && it.accessors.size == 1 + && it.accessors.single().isField(FIELD_DATA) + } + } + assertFalse { + apAliases.any { + it.base == Argument(0) + && it.accessors.size >= 2 + && it.accessors.last().isField(FIELD_DATA) + && it.accessors.dropLast(1).all { a -> a.isField(FIELD_NEXT) } + } + } + } + + @Test + fun `test read argument field`() { + val method = findMethod(HEAP_SAMPLE, "readArgField") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } + + @Test + fun `test write then read argument field`() { + val method = findMethod(HEAP_SAMPLE, "writeArgField") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { apAliases.any { it.isPlainBase(Argument(1)) } } + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } + + @Test + fun `test read argument deep field`() { + val method = findMethod(HEAP_SAMPLE, "readArgDeepField") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { + apAliases.any { + it.base == Argument(0) + && it.accessors.size == 2 + && it.accessors[0].isField(FIELD_BOX) + && it.accessors[1].isField(FIELD_VALUE) + } + } + } + + @Test + fun `test write then read argument 
deep field`() { + val method = findMethod(HEAP_SAMPLE, "writeArgDeepField") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { apAliases.any { it.isPlainBase(Argument(1)) } } + assertTrue { + apAliases.any { + it.base == Argument(0) + && it.accessors.size == 2 + && it.accessors[0].isField(FIELD_BOX) + && it.accessors[1].isField(FIELD_VALUE) + } + } + } + + @Test + fun `test read argument array element`() { + val method = findMethod(HEAP_SAMPLE, "readArgArrayElement") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleOrNull() == AliasAccessor.Array + } + } + } + + @Test + fun `test write then read argument array element`() { + val method = findMethod(HEAP_SAMPLE, "writeArgArrayElement") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { apAliases.any { it.isPlainBase(Argument(1)) } } + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleOrNull() == AliasAccessor.Array + } + } + } + + @Test + fun `test field to field copy`() { + val method = findMethod(HEAP_SAMPLE, "fieldToField") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + assertTrue { + apAliases.any { + it.base == Argument(1) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } + + @Test + fun `test swap fields`() { + val method = findMethod(HEAP_SAMPLE, "swapFields") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkTwoValues") + + val aValueAliases = aa.valueApAliases(sink.callExpr.args[0], sink) + assertTrue { + aValueAliases.any 
{ + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + assertFalse { + aValueAliases.any { + it.base == Argument(1) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + + val bValueAliases = aa.valueApAliases(sink.callExpr.args[1], sink) + assertFalse { + bValueAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + assertTrue { + bValueAliases.any { + it.base == Argument(1) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } + + @Test + fun `test array element to field`() { + val method = findMethod(HEAP_SAMPLE, "arrayToField") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleOrNull() == AliasAccessor.Array + } + } + assertTrue { + apAliases.any { + it.base == Argument(1) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } + + @Test + fun `test field to array element`() { + val method = findMethod(HEAP_SAMPLE, "fieldToArray") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + assertTrue { + apAliases.any { + it.base == Argument(1) && it.accessors.singleOrNull() == AliasAccessor.Array + } + } + } + + @Test + fun `test node traversal - loop diverges field chain`() { + val method = findMethod(HEAP_SAMPLE, "nodeTraversal") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertFalse { apAliases.any { it.isPlainBase(Argument(0)) } } + assertFalse { + apAliases.any { + it.base == Argument(0) + && it.accessors.isNotEmpty() + && it.accessors.all { a -> a.isField(FIELD_NEXT) } + } + } + } + + @Test + fun `test node traversal data - loop diverges field chain`() { + val 
method = findMethod(HEAP_SAMPLE, "nodeTraversalData") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertFalse { + apAliases.any { + it.base == Argument(0) + && it.accessors.size == 1 + && it.accessors.single().isField(FIELD_DATA) + } + } + assertFalse { + apAliases.any { + it.base == Argument(0) + && it.accessors.size >= 2 + && it.accessors.last().isField(FIELD_DATA) + && it.accessors.dropLast(1).all { a -> a.isField(FIELD_NEXT) } + } + } + } + + @Test + fun `test field overwrite on argument receiver`() { + val method = findMethod(HEAP_SAMPLE, "fieldOverwrite") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertFalse { apAliases.any { it.isPlainBase(Argument(1)) } } + assertTrue { apAliases.any { it.isPlainBase(Argument(2)) } } + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } + + @Test + fun `test conditional field write on argument receiver`() { + val method = findMethod(HEAP_SAMPLE, "conditionalFieldWrite") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertFalse { apAliases.any { it.isPlainBase(Argument(1)) } } + assertFalse { apAliases.any { it.isPlainBase(Argument(2)) } } + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } + + @Test + fun `test aliased receiver field write`() { + val method = findMethod(HEAP_SAMPLE, "aliasedReceiverFieldWrite") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { + apAliases.any { + it.base == Argument(1) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } + + @Test + fun `test getter aliases this field`() { + val method = findMethod(INTERPROC_SAMPLE, 
"testGetterAlias") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { + apAliases.any { + it.base == AccessPathBase.This && it.accessors.singleFieldNamed(FIELD_INTERPROC) + } + } + } + + @Test + fun `test setter then getter`() { + val method = findMethod(INTERPROC_SAMPLE, "testSetterThenGetter") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { apAliases.any { it.base == Argument(0) } } + } + + @Test + fun `test identity same-class call`() { + val method = findMethod(INTERPROC_SAMPLE, "testIdentityCall") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { apAliases.any { it.isPlainBase(Argument(0)) } } + } + + @Test + fun `test external call return is unknown`() { + val method = findMethod(INTERPROC_SAMPLE, "testExternalCallReturn") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertFalse { apAliases.any { it.isPlainBase(Argument(0)) } } + } + + @Test + fun `test external call invalidates heap aliases`() { + val method = findMethod(INTERPROC_SAMPLE, "testExternalCallInvalidatesHeap") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertFalse { apAliases.any { it.isPlainBase(Argument(0)) } } + } + + @Test + fun `test combined write arg then touch heap`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "writeArgThenTouchHeap") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { apAliases.any { it.isPlainBase(Argument(1)) } } + assertFalse { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } 
+ + @Test + fun `test combined return argument field`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "returnArgField") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } + + @Test + fun `test combined return identity then write field`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "returnIdentityThenWriteField") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { apAliases.any { it.isPlainBase(Argument(1)) } } + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } + + @Test + fun `test combined fresh object carries returned arg`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "freshObjectCarriesReturnedArg") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { apAliases.any { it.isPlainBase(Argument(0)) } } + } + + @Test + fun `test combined fresh object copies argument field`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "freshObjectCopiesArgumentField") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } + + @Test + fun `test combined pass through receiver then read field`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "passThroughReceiverThenReadField") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } + + @Test + fun `test combined 
nested write return and touch heap`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "nestedWriteReturnAndTouchHeap") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { apAliases.any { it.isPlainBase(Argument(1)) } } + assertFalse { apAliases.any { it.accessors.isNotEmpty() } } + } + + @Test + fun `test combined overwrite field with fresh object`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "overwriteFieldWithFreshObject") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertFalse { apAliases.any { it.isPlainBase(Argument(1)) } } + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } + + @Test + fun `test combined return fresh box then alias field`() { + val method = findMethod(COMBINED_HEAP_SAMPLE, "returnFreshBoxThenAliasField") + val aa = aaForMethod(method) + + val sink = method.findSinkCall("sinkOneValue") + val apAliases = aa.sinkArgApAliases(sink) + + assertTrue { apAliases.any { it.isPlainBase(Argument(1)) } } + assertTrue { + apAliases.any { + it.base == Argument(0) && it.accessors.singleFieldNamed(FIELD_VALUE) + } + } + } +} diff --git a/core/opentaint-java-querylang/samples/src/main/java/example/MustAliasExample.java b/core/opentaint-java-querylang/samples/src/main/java/example/MustAliasExample.java new file mode 100644 index 000000000..0caf4f9c5 --- /dev/null +++ b/core/opentaint-java-querylang/samples/src/main/java/example/MustAliasExample.java @@ -0,0 +1,80 @@ +package example; + +import base.RuleSample; +import base.RuleSet; +import example.util.CleanableData; +import java.util.Map; + +@RuleSet("example/MustAliasExample.yaml") +public abstract class MustAliasExample implements RuleSample { + protected CleanableData data; + + protected void cleanData(CleanableData data) { + data.cleanData(); + } + + protected 
void doNothing(CleanableData data) { } + + static class NegativeSimpleFlowTaint extends MustAliasExample { + @Override + public void entrypoint() { + this.data = new CleanableData(); + this.data.cleanData(); + this.data.sendInfo(); + } + } + + static class PositiveSimpleFlowTaint extends MustAliasExample { + @Override + public void entrypoint() { + this.data = new CleanableData(); + this.data.sendInfo(); + } + } + + static class NegativeMethodFlowTaint extends MustAliasExample { + @Override + public void entrypoint() { + CleanableData obj = new CleanableData(); + CleanableData obj2 = obj; + cleanData(obj2); + obj.sendInfo(); + } + } + + static class PositiveMethodFlowTaint extends MustAliasExample { + @Override + public void entrypoint() { + CleanableData obj = new CleanableData(); + CleanableData obj2 = obj; + if (obj.info.length() > 5) { + cleanData(obj2); + } + obj.sendInfo(); + } + } + + static class PositiveMethodFlow2Taint extends MustAliasExample { + @Override + public void entrypoint() { + CleanableData obj = new CleanableData(); + doNothing(obj); + obj.sendInfo(); + } + } + + protected Map<Integer, CleanableData> dataMap; + + protected void putData(CleanableData data) { + dataMap.put(0, data); + } + + static class PositiveMapFlowTaint extends MustAliasExample { + @Override + public void entrypoint() { + CleanableData obj = new CleanableData(); + putData(obj); + dataMap.get(0).sendInfo(); + } + } +} diff --git a/core/opentaint-java-querylang/samples/src/main/java/example/util/CleanableData.java b/core/opentaint-java-querylang/samples/src/main/java/example/util/CleanableData.java new file mode 100644 index 000000000..ac55088a8 --- /dev/null +++ b/core/opentaint-java-querylang/samples/src/main/java/example/util/CleanableData.java @@ -0,0 +1,13 @@ +package example.util; + +public class CleanableData { + public String info; + + public CleanableData() { + this.info = "so unsafe"; + } + + public void cleanData() { } + + public void sendInfo() { } +} diff --git
a/core/opentaint-java-querylang/samples/src/main/resources/example/MustAliasExample.yaml b/core/opentaint-java-querylang/samples/src/main/resources/example/MustAliasExample.yaml new file mode 100644 index 000000000..ae9971a43 --- /dev/null +++ b/core/opentaint-java-querylang/samples/src/main/resources/example/MustAliasExample.yaml @@ -0,0 +1,17 @@ +rules: + - id: example-Rule + languages: + - java + severity: ERROR + message: cleanedMethod + patterns: + - pattern: | + $A = new CleanableData(); + ... + $A.sendInfo(); + - pattern-not-inside: | + $A = new CleanableData(); + ... + $A.cleanData(); + ... + $A.sendInfo(); diff --git a/core/opentaint-java-querylang/src/test/kotlin/org/opentaint/semgrep/ExampleTest.kt b/core/opentaint-java-querylang/src/test/kotlin/org/opentaint/semgrep/ExampleTest.kt index 9a874949f..fbe0bee31 100644 --- a/core/opentaint-java-querylang/src/test/kotlin/org/opentaint/semgrep/ExampleTest.kt +++ b/core/opentaint-java-querylang/src/test/kotlin/org/opentaint/semgrep/ExampleTest.kt @@ -6,6 +6,7 @@ import org.junit.jupiter.api.AfterAll import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.TestInstance import org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS +import org.opentaint.config.ConfigLoader import org.opentaint.dataflow.configuration.jvm.serialized.PositionBase import org.opentaint.dataflow.configuration.jvm.serialized.SerializedRule import org.opentaint.dataflow.configuration.jvm.serialized.SerializedSimpleNameMatcher.Simple @@ -13,6 +14,7 @@ import org.opentaint.dataflow.configuration.jvm.serialized.SerializedTaintPassAc import org.opentaint.semgrep.pattern.conversion.taint.anyFunction import org.opentaint.semgrep.pattern.conversion.taint.base import org.opentaint.semgrep.util.SampleBasedTest +import kotlin.collections.orEmpty import kotlin.test.Test @TestInstance(PER_CLASS) @@ -195,6 +197,12 @@ class ExampleTest : SampleBasedTest() { @Test fun `test array example`() = runTest() + @Test + fun `test must alias 
examples`() = runTest { cfg -> + val config = ConfigLoader.getConfig()?.passThrough.orEmpty() + cfg.copy(passThrough = cfg.passThrough.orEmpty() + config) + } + @Test fun `test join with taint and matching left`() = runTest()