Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public BadConfigurationException(final String message) {
static final String META_STRUCT_MESSAGES = "messages";
static final String META_STRUCT_CATEGORIES = "attack_categories";
static final String META_STRUCT_SDS = "sds";
static final String META_STRUCT_TAG_PROBS = "tag_probs";

public static void install() {
final Config config = Config.get();
Expand Down Expand Up @@ -258,13 +259,18 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
final List<String> tags = (List<String>) result.get("tags");
@SuppressWarnings("unchecked")
final List<?> sdsFindings = (List<?>) result.get("sds_findings");
@SuppressWarnings("unchecked")
final Map<String, Number> tagProbs = (Map<String, Number>) result.get("tag_probs");
span.setTag(ACTION_TAG, action);
if (reason != null) {
span.setTag(REASON_TAG, reason);
}
if (tags != null && !tags.isEmpty()) {
metaStruct.put(META_STRUCT_CATEGORIES, tags);
}
if (tagProbs != null && !tagProbs.isEmpty()) {
metaStruct.put(META_STRUCT_TAG_PROBS, tagProbs);
}
if (sdsFindings != null && !sdsFindings.isEmpty()) {
metaStruct.put(META_STRUCT_SDS, sdsFindings);
}
Expand All @@ -273,9 +279,9 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
WafMetricCollector.get().aiGuardRequest(action, shouldBlock);
if (shouldBlock) {
span.setTag(BLOCKED_TAG, true);
throw new AIGuardAbortError(action, reason, tags, sdsFindings);
throw new AIGuardAbortError(action, reason, tags, tagProbs, sdsFindings);
}
return new Evaluation(action, reason, tags, sdsFindings);
return new Evaluation(action, reason, tags, tagProbs, sdsFindings);
}
} catch (AIGuardAbortError e) {
span.addThrowable(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class AIGuardInternalTests extends DDSpecification {
return mockResponse(
request,
200,
[data: [attributes: [action: suite.action, reason: suite.reason, tags: suite.tags ?: [], is_blocking_enabled: suite.blocking]]]
[data: [attributes: [action: suite.action, reason: suite.reason, tags: suite.tags ?: [], tag_probs: suite.tagProbabilities ?: [:], is_blocking_enabled: suite.blocking]]]
)
}
}
Expand Down Expand Up @@ -210,12 +210,14 @@ class AIGuardInternalTests extends DDSpecification {
error.action == suite.action
error.reason == suite.reason
error.tags == suite.tags
error.tagProbabilities == suite.tagProbabilities
error.sds == []
} else {
error == null
eval.action == suite.action
eval.reason == suite.reason
eval.tags == suite.tags
eval.tagProbabilities == suite.tagProbabilities
eval.sds == []
}
assertTelemetry('ai_guard.requests', "action:$suite.action", "block:$throwAbortError", 'error:false')
Expand Down Expand Up @@ -555,6 +557,9 @@ class AIGuardInternalTests extends DDSpecification {
if (suite.tags) {
assert meta.attack_categories == suite.tags
}
if (suite.tagProbabilities) {
assert meta.tag_probs == suite.tagProbabilities
}
final receivedMessages = snakeCaseJson(meta.messages)
final expectedMessages = snakeCaseJson(suite.messages)
JSONAssert.assertEquals(expectedMessages, receivedMessages, JSONCompareMode.NON_EXTENSIBLE)
Expand Down Expand Up @@ -774,15 +779,17 @@ class AIGuardInternalTests extends DDSpecification {
private final AIGuard.Action action
private final String reason
private final List<String> tags
private final Map<String, Double> tagProbabilities
private final boolean blocking
private final String description
private final String target
private final List<AIGuard.Message> messages

TestSuite(AIGuard.Action action, String reason, List<String> tags, boolean blocking, String description, String target, List<AIGuard.Message> messages) {
TestSuite(AIGuard.Action action, String reason, Map<String, Double> tagProbabilities, boolean blocking, String description, String target, List<AIGuard.Message> messages) {
this.action = action
this.reason = reason
this.tags = tags
this.tags = new ArrayList<>(tagProbabilities.keySet())
this.tagProbabilities = tagProbabilities
this.blocking = blocking
this.description = description
this.target = target
Expand All @@ -791,9 +798,9 @@ class AIGuardInternalTests extends DDSpecification {

static List<TestSuite> build() {
def actionValues = [
[ALLOW, 'Go ahead', []],
[DENY, 'Nope', ['deny_everything', 'test_deny']],
[ABORT, 'Kill it with fire', ['alarm_tag', 'abort_everything']]
[ALLOW, 'Go ahead', [:]],
[DENY, 'Nope', ['deny_everything': 0.2D, 'test_deny': 0.8D]],
[ABORT, 'Kill it with fire', ['alarm_tag': 0.1D, 'abort_everything': 0.9D]]
]
def blockingValues = [true, false]
def suiteValues = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -69,14 +70,20 @@ public static class AIGuardAbortError extends RuntimeException {
private final Action action;
private final String reason;
private final List<String> tags;
private final Map<String, Number> tagProbs;
private final List<?> sds;

public AIGuardAbortError(
final Action action, final String reason, final List<String> tags, final List<?> sds) {
final Action action,
final String reason,
final List<String> tags,
final Map<String, Number> tagProbs,
final List<?> sds) {
super(reason);
this.action = action;
this.reason = reason;
this.tags = tags;
this.tagProbs = tagProbs != null ? tagProbs : Collections.emptyMap();
this.sds = sds != null ? sds : Collections.emptyList();
}

Expand All @@ -92,6 +99,10 @@ public List<String> getTags() {
return tags;
}

public Map<String, Number> getTagProbabilities() {
return tagProbs;
}

public List<?> getSds() {
return sds;
}
Expand Down Expand Up @@ -156,6 +167,7 @@ public static class Evaluation {
final Action action;
final String reason;
final List<String> tags;
final Map<String, Number> tagProbs;
final List<?> sds;

/**
Expand All @@ -164,13 +176,19 @@ public static class Evaluation {
* @param action the recommended action for the evaluated content
* @param reason human-readable explanation for the decision
* @param tags list of tags associated with the evaluation (e.g. indirect-prompt-injection)
* @param tagProbs map of tags associated to their probability
* @param sds list of Sensitive Data Scanner findings
*/
public Evaluation(
final Action action, final String reason, final List<String> tags, final List<?> sds) {
final Action action,
final String reason,
final List<String> tags,
final Map<String, Number> tagProbs,
final List<?> sds) {
this.action = action;
this.reason = reason;
this.tags = tags;
this.tagProbs = tagProbs;
this.sds = sds != null ? sds : Collections.emptyList();
}

Expand Down Expand Up @@ -201,6 +219,15 @@ public List<String> getTags() {
return tags;
}

/**
* Returns a map from tag to their probability (e.g. [indirect-prompt-injection: 0.25])
*
* @return map of tag probabilities.
*/
public Map<String, Number> getTagProbabilities() {
return tagProbs;
}

/**
* Returns the list of Sensitive Data Scanner findings.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static datadog.trace.api.aiguard.AIGuard.Action.ALLOW;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;

import datadog.trace.api.aiguard.AIGuard.Evaluation;
import datadog.trace.api.aiguard.AIGuard.Message;
Expand All @@ -13,6 +14,6 @@ public final class NoOpEvaluator implements Evaluator {

@Override
public Evaluation evaluate(final List<Message> messages, final Options options) {
return new Evaluation(ALLOW, "AI Guard is not enabled", emptyList(), emptyList());
return new Evaluation(ALLOW, "AI Guard is not enabled", emptyList(), emptyMap(), emptyList());
}
}
Loading