diff --git a/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java b/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java index e180f4b6d54..4ea7967a948 100644 --- a/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java +++ b/dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java @@ -73,6 +73,7 @@ public BadConfigurationException(final String message) { static final String META_STRUCT_MESSAGES = "messages"; static final String META_STRUCT_CATEGORIES = "attack_categories"; static final String META_STRUCT_SDS = "sds"; + static final String META_STRUCT_TAG_PROBS = "tag_probs"; public static void install() { final Config config = Config.get(); @@ -258,6 +259,8 @@ public Evaluation evaluate(final List messages, final Options options) final List tags = (List) result.get("tags"); @SuppressWarnings("unchecked") final List sdsFindings = (List) result.get("sds_findings"); + @SuppressWarnings("unchecked") + final Map tagProbs = (Map) result.get("tag_probs"); span.setTag(ACTION_TAG, action); if (reason != null) { span.setTag(REASON_TAG, reason); @@ -265,6 +268,9 @@ public Evaluation evaluate(final List messages, final Options options) if (tags != null && !tags.isEmpty()) { metaStruct.put(META_STRUCT_CATEGORIES, tags); } + if (tagProbs != null && !tagProbs.isEmpty()) { + metaStruct.put(META_STRUCT_TAG_PROBS, tagProbs); + } if (sdsFindings != null && !sdsFindings.isEmpty()) { metaStruct.put(META_STRUCT_SDS, sdsFindings); } @@ -273,9 +279,9 @@ public Evaluation evaluate(final List messages, final Options options) WafMetricCollector.get().aiGuardRequest(action, shouldBlock); if (shouldBlock) { span.setTag(BLOCKED_TAG, true); - throw new AIGuardAbortError(action, reason, tags, sdsFindings); + throw new AIGuardAbortError(action, reason, tags, tagProbs, sdsFindings); } - return new Evaluation(action, reason, tags, sdsFindings); + return new Evaluation(action, reason, tags, tagProbs, sdsFindings); } } catch (AIGuardAbortError e) { span.addThrowable(e); diff --git a/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy b/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy index a1db249566e..dc5b7769e70 100644 --- a/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy +++ b/dd-java-agent/agent-aiguard/src/test/groovy/com/datadog/aiguard/AIGuardInternalTests.groovy @@ -168,7 +168,7 @@ class AIGuardInternalTests extends DDSpecification { return mockResponse( request, 200, - [data: [attributes: [action: suite.action, reason: suite.reason, tags: suite.tags ?: [], is_blocking_enabled: suite.blocking]]] + [data: [attributes: [action: suite.action, reason: suite.reason, tags: suite.tags ?: [], tag_probs: suite.tagProbabilities ?: [:], is_blocking_enabled: suite.blocking]]] ) } } @@ -210,12 +210,14 @@ class AIGuardInternalTests extends DDSpecification { error.action == suite.action error.reason == suite.reason error.tags == suite.tags + error.tagProbabilities == suite.tagProbabilities error.sds == [] } else { error == null eval.action == suite.action eval.reason == suite.reason eval.tags == suite.tags + eval.tagProbabilities == suite.tagProbabilities eval.sds == [] } assertTelemetry('ai_guard.requests', "action:$suite.action", "block:$throwAbortError", 'error:false') @@ -555,6 +557,9 @@ class AIGuardInternalTests extends DDSpecification { if (suite.tags) { assert meta.attack_categories == suite.tags } + if (suite.tagProbabilities) { + assert meta.tag_probs == suite.tagProbabilities + } final receivedMessages = snakeCaseJson(meta.messages) final expectedMessages = snakeCaseJson(suite.messages) JSONAssert.assertEquals(expectedMessages, receivedMessages, JSONCompareMode.NON_EXTENSIBLE) @@ -774,15 +779,17 @@ class AIGuardInternalTests extends DDSpecification { private final AIGuard.Action action private final String reason private final List tags + private final Map tagProbabilities private final boolean blocking private final String description private final String target private final List messages - TestSuite(AIGuard.Action action, String reason, List tags, boolean blocking, String description, String target, List messages) { + TestSuite(AIGuard.Action action, String reason, Map tagProbabilities, boolean blocking, String description, String target, List messages) { this.action = action this.reason = reason - this.tags = tags + this.tags = new ArrayList<>(tagProbabilities.keySet()) + this.tagProbabilities = tagProbabilities this.blocking = blocking this.description = description this.target = target @@ -791,9 +798,9 @@ class AIGuardInternalTests extends DDSpecification { static List build() { def actionValues = [ - [ALLOW, 'Go ahead', []], - [DENY, 'Nope', ['deny_everything', 'test_deny']], - [ABORT, 'Kill it with fire', ['alarm_tag', 'abort_everything']] + [ALLOW, 'Go ahead', [:]], + [DENY, 'Nope', ['deny_everything': 0.2D, 'test_deny': 0.8D]], + [ABORT, 'Kill it with fire', ['alarm_tag': 0.1D, 'abort_everything': 0.9D]] ] def blockingValues = [true, false] def suiteValues = [ diff --git a/dd-trace-api/src/main/java/datadog/trace/api/aiguard/AIGuard.java b/dd-trace-api/src/main/java/datadog/trace/api/aiguard/AIGuard.java index f1a328222e3..f8111224ed7 100644 --- a/dd-trace-api/src/main/java/datadog/trace/api/aiguard/AIGuard.java +++ b/dd-trace-api/src/main/java/datadog/trace/api/aiguard/AIGuard.java @@ -5,6 +5,7 @@ import java.util.Collections; import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Objects; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -69,14 +70,20 @@ public static class AIGuardAbortError extends RuntimeException { private final Action action; private final String reason; private final List tags; + private final Map tagProbs; private final List sds; public AIGuardAbortError( - final Action action, final String reason, final List tags, final List sds) { + final Action action, + final String reason, + final List tags, + final Map tagProbs, + final List sds) { super(reason); this.action = action; this.reason = reason; this.tags = tags; + this.tagProbs = tagProbs != null ? tagProbs : Collections.emptyMap(); this.sds = sds != null ? sds : Collections.emptyList(); } @@ -92,6 +99,10 @@ public List getTags() { return tags; } + public Map getTagProbabilities() { + return tagProbs; + } + public List getSds() { return sds; } @@ -156,6 +167,7 @@ public static class Evaluation { final Action action; final String reason; final List tags; + final Map tagProbs; final List sds; /** @@ -164,13 +176,19 @@ public static class Evaluation { * @param action the recommended action for the evaluated content * @param reason human-readable explanation for the decision * @param tags list of tags associated with the evaluation (e.g. indirect-prompt-injection) + * @param tagProbs map of tags associated to their probability * @param sds list of Sensitive Data Scanner findings */ public Evaluation( - final Action action, final String reason, final List tags, final List sds) { + final Action action, + final String reason, + final List tags, + final Map tagProbs, + final List sds) { this.action = action; this.reason = reason; this.tags = tags; + this.tagProbs = tagProbs; this.sds = sds != null ? sds : Collections.emptyList(); } @@ -201,6 +219,15 @@ public List getTags() { return tags; } + /** + * Returns a map from tag to their probability (e.g. [indirect-prompt-injection: 0.25]) + * + * @return map of tag probabilities. + */ + public Map getTagProbabilities() { + return tagProbs; + } + /** * Returns the list of Sensitive Data Scanner findings. * diff --git a/dd-trace-api/src/main/java/datadog/trace/api/aiguard/noop/NoOpEvaluator.java b/dd-trace-api/src/main/java/datadog/trace/api/aiguard/noop/NoOpEvaluator.java index 9fa6c5013b3..4389eba0c41 100644 --- a/dd-trace-api/src/main/java/datadog/trace/api/aiguard/noop/NoOpEvaluator.java +++ b/dd-trace-api/src/main/java/datadog/trace/api/aiguard/noop/NoOpEvaluator.java @@ -2,6 +2,7 @@ import static datadog.trace.api.aiguard.AIGuard.Action.ALLOW; import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; import datadog.trace.api.aiguard.AIGuard.Evaluation; import datadog.trace.api.aiguard.AIGuard.Message; @@ -13,6 +14,6 @@ public final class NoOpEvaluator implements Evaluator { @Override public Evaluation evaluate(final List messages, final Options options) { - return new Evaluation(ALLOW, "AI Guard is not enabled", emptyList(), emptyList()); + return new Evaluation(ALLOW, "AI Guard is not enabled", emptyList(), emptyMap(), emptyList()); } }