diff --git a/dd-trace-core/src/test/groovy/datadog/trace/TracerConnectionReliabilityTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/TracerConnectionReliabilityTest.groovy deleted file mode 100644 index 6ecc076a10d..00000000000 --- a/dd-trace-core/src/test/groovy/datadog/trace/TracerConnectionReliabilityTest.groovy +++ /dev/null @@ -1,167 +0,0 @@ -package datadog.trace - -import static datadog.trace.api.ConfigDefaults.DEFAULT_TRACE_AGENT_PORT -import static datadog.trace.api.ProtocolVersion.V0_4 - -import com.squareup.moshi.JsonAdapter -import com.squareup.moshi.Moshi -import com.squareup.moshi.Types -import datadog.communication.ddagent.DDAgentFeaturesDiscovery -import datadog.communication.ddagent.SharedCommunicationObjects -import datadog.metrics.api.Monitoring -import datadog.trace.agent.test.utils.PortUtils -import datadog.trace.api.IdGenerationStrategy -import datadog.trace.core.CoreTracer -import datadog.trace.test.util.DDSpecification -import java.lang.reflect.Type -import okhttp3.HttpUrl -import okhttp3.OkHttpClient -import okhttp3.Request -import org.testcontainers.containers.FixedHostPortGenericContainer -import org.testcontainers.containers.GenericContainer -import org.testcontainers.containers.wait.strategy.Wait -import spock.lang.AutoCleanup -import spock.lang.Shared - -class TracerConnectionReliabilityTest extends DDSpecification { - final static FEATURES_DISCOVERY_MIN_DELAY = 10 - - @Shared - OkHttpClient client - @Shared - JsonAdapter> traceJsonAdapter - - int agentContainerPort - @AutoCleanup - CoreTracer tracer - - def setupSpec() { - client = new OkHttpClient() - // Create body parser for /test/traces route - def moshi = new Moshi.Builder().build() - Type type = Types.newParameterizedType(List, Types.newParameterizedType(List, SentTraces)) - traceJsonAdapter = moshi.adapter(type) - } - - def setup() { - // Pick a random port for the test agent - agentContainerPort = PortUtils.randomOpenPort() - // Build a tracer talking to the test agent (with the right port and traces endpoint) - def properties = new Properties() - properties.put("trace.agent.port", Integer.toString(agentContainerPort)) - def sharedCommunicationObjects = new SharedCommunicationObjects() - sharedCommunicationObjects.agentUrl = HttpUrl.get("http://localhost:" + agentContainerPort) - sharedCommunicationObjects.agentHttpClient = client - def fixedFeaturesDiscovery = new FixedTraceEndpointFeaturesDiscovery(sharedCommunicationObjects) - sharedCommunicationObjects.setFeaturesDiscovery(fixedFeaturesDiscovery) - - tracer = CoreTracer.builder() - .idGenerationStrategy(IdGenerationStrategy.fromName("SEQUENTIAL")) - .withProperties(properties) - .sharedCommunicationObjects(sharedCommunicationObjects) - .build() - } - - def "test late agent start"() { - setup: - createSpans(10, 100) - tracer.flush() - - when: - def agentContainer = startTestAgentContainer() - def noAgentCount = getTraceCount(agentContainer) - waitForDiscoveryTimeout() - - createSpans(20, 100) - tracer.flush() - def withAgentCount = getTraceCount(agentContainer) - agentContainer.stop() - - then: - !agentContainer.running - noAgentCount == 0 - withAgentCount == 20 - } - - def "test agent restart"() { - setup: - def agentContainer = startTestAgentContainer() - - when: - createSpans(10, 100) - tracer.flush() - def withAgentCount = getTraceCount(agentContainer) - - then: - withAgentCount == 10 - - when: - agentContainer.stop() - createSpans(10, 100) - tracer.flush() - - waitForDiscoveryTimeout() - agentContainer = startTestAgentContainer() - def noTraceCount = getTraceCount(agentContainer) - createSpans(10, 100) - tracer.flush() - withAgentCount = getTraceCount(agentContainer) - agentContainer.stop() - - then: - !agentContainer.running - noTraceCount == 0 - withAgentCount == 10 - } - - def startTestAgentContainer() { - //noinspection GrDeprecatedAPIUsage Use FixedHostPortGenericContainer against deprecation because we need to know the exposed to configure the tracer at start - def agentContainer = new FixedHostPortGenericContainer("registry.ddbuild.io/images/mirror/dd-apm-test-agent/ddapm-test-agent:v1.44.0") - .withFixedExposedPort(agentContainerPort, DEFAULT_TRACE_AGENT_PORT) - .withEnv("ENABLED_CHECKS", "trace_count_header,meta_tracer_version_header,trace_content_length") - .waitingFor(Wait.forHttp("/test/traces")) - agentContainer.start() - return agentContainer - } - - def createSpans(int count, int delay) { - for (def index: 1..count) { - def span = tracer.buildSpan("datadog", "operation-${index}").start() - Thread.sleep(delay) - span.finish() - } - } - - static waitForDiscoveryTimeout() { - Thread.sleep(FEATURES_DISCOVERY_MIN_DELAY * 1.5 as long) - } - - def getTraceCount(GenericContainer agentContainer) { - def request = new Request.Builder() - .url("http://${agentContainer.host}:${agentContainerPort}/test/traces") - .build() - def execute = client.newCall(request).execute() - def body = execute.body().string() - return traceJsonAdapter.fromJson(body).size() - } - - class FixedTraceEndpointFeaturesDiscovery extends DDAgentFeaturesDiscovery { - FixedTraceEndpointFeaturesDiscovery(SharedCommunicationObjects objects) { - super(objects.agentHttpClient, Monitoring.DISABLED, objects.agentUrl, V0_4, false, false) - } - - @Override - String getTraceEndpoint() { - return V04_ENDPOINT - } - - @Override - protected long getFeaturesDiscoveryMinDelayMillis() { - FEATURES_DISCOVERY_MIN_DELAY - } - } - - static class SentTraces { - String name - } -} diff --git a/dd-trace-core/src/test/groovy/datadog/trace/api/writer/PrintingWriterTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/api/writer/PrintingWriterTest.groovy deleted file mode 100644 index 74072a2b24d..00000000000 --- a/dd-trace-core/src/test/groovy/datadog/trace/api/writer/PrintingWriterTest.groovy +++ /dev/null @@ -1,155 +0,0 @@ -package datadog.trace.api.writer - -import com.squareup.moshi.Moshi -import com.squareup.moshi.Types -import datadog.trace.common.writer.ListWriter -import datadog.trace.common.writer.PrintingWriter -import datadog.trace.core.test.DDCoreSpecification -import okio.Buffer - -import java.nio.charset.StandardCharsets - -class PrintingWriterTest extends DDCoreSpecification { - - def tracer = tracerBuilder().writer(new ListWriter()).build() - def sampleTrace - def secondTrace - - def adapter = new Moshi.Builder().build().adapter(Types.newParameterizedType(Map, String, - Types.newParameterizedType(List, - Types.newParameterizedType(List, Map)))) - - def setup() { - def builder = tracer.buildSpan("datadog", "fakeOperation") - .withServiceName("fakeService") - .withResourceName("fakeResource") - .withSpanType("fakeType") - - sampleTrace = [builder.start(), builder.start()] - secondTrace = [builder.start()] - } - - def cleanup() { - tracer?.close() - } - - def "test printing regular ids"() { - given: - def buffer = new Buffer() - def writer = new PrintingWriter(buffer.outputStream(), false) - - when: - writer.write(sampleTrace) - Map>> result = adapter.fromJson(buffer.readString(StandardCharsets.UTF_8)) - - then: - result["traces"][0].size() == sampleTrace.size() - result["traces"][0].each { - assert it["service"] == "fakeService" - assert it["name"] == "fakeOperation" - assert it["resource"] == "fakeResource" - assert it["type"] == "fakeType" - assert it["trace_id"] instanceof Number - assert it["span_id"] instanceof Number - assert it["parent_id"] instanceof Number - assert it["start"] instanceof Number - assert it["duration"] instanceof Number - assert it["error"] == 0 - assert it["metrics"] instanceof Map - assert it["meta"] instanceof Map - } - - when: - writer.write(secondTrace) - result = adapter.fromJson(buffer.readString(StandardCharsets.UTF_8)) - - then: - result["traces"][0].size() == secondTrace.size() - result["traces"][0].each { - assert it["service"] == "fakeService" - assert it["name"] == "fakeOperation" - assert it["resource"] == "fakeResource" - assert it["type"] == "fakeType" - assert it["trace_id"] instanceof Number - assert it["span_id"] instanceof Number - assert it["parent_id"] instanceof Number - assert it["start"] instanceof Number - assert it["duration"] instanceof Number - assert it["error"] == 0 - assert it["metrics"] instanceof Map - assert it["meta"] instanceof Map - } - } - - def "test printing regular hex ids"() { - - given: - def buffer = new Buffer() - def writer = new PrintingWriter(buffer.outputStream(), true) - - when: - writer.write(sampleTrace) - Map>> result = adapter.fromJson(buffer.readString(StandardCharsets.UTF_8)) - - then: - result["traces"][0].size() == sampleTrace.size() - result["traces"][0].each { - assert it["service"] == "fakeService" - assert it["name"] == "fakeOperation" - assert it["resource"] == "fakeResource" - assert it["type"] == "fakeType" - assert it["trace_id"] instanceof String - assert it["span_id"] instanceof String - assert it["parent_id"] instanceof String - assert it["start"] instanceof Number - assert it["duration"] instanceof Number - assert it["error"] == 0 - assert it["metrics"] instanceof Map - assert it["meta"] instanceof Map - } - } - - def "test printing multiple traces"() { - given: - def buffer = new Buffer() - def writer = new PrintingWriter(buffer.outputStream(), false) - - when: - writer.write(sampleTrace) - writer.write(secondTrace) - Map>> result1 = adapter.fromJson(buffer.readUtf8Line()) - Map>> result2 = adapter.fromJson(buffer.readUtf8Line()) - - then: - result1["traces"][0].size() == sampleTrace.size() - result2["traces"][0].each { - assert it["service"] == "fakeService" - assert it["name"] == "fakeOperation" - assert it["resource"] == "fakeResource" - assert it["type"] == "fakeType" - assert it["trace_id"] instanceof Number - assert it["span_id"] instanceof Number - assert it["parent_id"] instanceof Number - assert it["start"] instanceof Number - assert it["duration"] instanceof Number - assert it["error"] == 0 - assert it["metrics"] instanceof Map - assert it["meta"] instanceof Map - } - result2["traces"][0].size() == secondTrace.size() - result2["traces"][0].each { - assert it["service"] == "fakeService" - assert it["name"] == "fakeOperation" - assert it["resource"] == "fakeResource" - assert it["type"] == "fakeType" - assert it["trace_id"] instanceof Number - assert it["span_id"] instanceof Number - assert it["parent_id"] instanceof Number - assert it["start"] instanceof Number - assert it["duration"] instanceof Number - assert it["error"] == 0 - assert it["metrics"] instanceof Map - assert it["meta"] instanceof Map - } - } -} diff --git a/dd-trace-core/src/test/groovy/datadog/trace/api/writer/TraceStructureWriterTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/api/writer/TraceStructureWriterTest.groovy deleted file mode 100644 index 00dd1fe65dd..00000000000 --- a/dd-trace-core/src/test/groovy/datadog/trace/api/writer/TraceStructureWriterTest.groovy +++ /dev/null @@ -1,29 +0,0 @@ -package datadog.trace.api.writer - - -import datadog.trace.common.writer.TraceStructureWriter -import datadog.trace.core.test.DDCoreSpecification - -class TraceStructureWriterTest extends DDCoreSpecification { - def "parse CLI args"() { - when: - def args = TraceStructureWriter.parseArgs(cli, windows) - - then: - args.length > 0 - args[0] == path - - where: - windows | cli | path - true | 'C:/tmp/file' | 'C:/tmp/file' - true | 'C:\\tmp\\file' | 'C:\\tmp\\file' - true | 'file' | 'file' - true | 'C:/tmp/file:includeresource' | 'C:/tmp/file' - true | 'C:\\tmp\\file:includeresource' | 'C:\\tmp\\file' - true | 'file:includeresource' | 'file' - false | '/var/tmp/file' | '/var/tmp/file' - false | 'file' | 'file' - false | '/var/tmp/file' | '/var/tmp/file' - false | 'file' | 'file' - } -} diff --git a/dd-trace-core/src/test/groovy/datadog/trace/civisibility/interceptor/CiVisibilityApmProtocolInterceptorTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/civisibility/interceptor/CiVisibilityApmProtocolInterceptorTest.groovy deleted file mode 100644 index 288227560b3..00000000000 --- a/dd-trace-core/src/test/groovy/datadog/trace/civisibility/interceptor/CiVisibilityApmProtocolInterceptorTest.groovy +++ /dev/null @@ -1,61 +0,0 @@ -package datadog.trace.civisibility.interceptor - -import datadog.trace.api.DDSpanTypes -import datadog.trace.bootstrap.instrumentation.api.Tags -import datadog.trace.common.writer.ListWriter -import datadog.trace.core.test.DDCoreSpecification -import spock.lang.Timeout - -@Timeout(10) -class CiVisibilityApmProtocolInterceptorTest extends DDCoreSpecification { - - def writer = new ListWriter() - def tracer = tracerBuilder().writer(writer).build() - - def cleanup() { - tracer?.close() - } - - def "test suite and test module spans are filtered out"() { - setup: - tracer.addTraceInterceptor(CiVisibilityApmProtocolInterceptor.INSTANCE) - - tracer.buildSpan("datadog", "test-module").withSpanType(DDSpanTypes.TEST_MODULE_END).start().finish() - tracer.buildSpan("datadog", "test-suite").withSpanType(DDSpanTypes.TEST_SUITE_END).start().finish() - tracer.buildSpan("datadog", "test").withSpanType(DDSpanTypes.TEST).start().finish() - - writer.waitForTraces(1) - - expect: - def trace = writer.firstTrace() - trace.size() == 1 - - def span = trace[0] - span.operationName == "test" - } - - def "test session, test module and test suite IDs are nullified"() { - setup: - tracer.addTraceInterceptor(CiVisibilityApmProtocolInterceptor.INSTANCE) - - def testSpan = tracer.buildSpan("datadog", "test").withSpanType(DDSpanTypes.TEST).start() - testSpan.setTag(Tags.TEST_SESSION_ID, "session ID") - testSpan.setTag(Tags.TEST_MODULE_ID, "module ID") - testSpan.setTag(Tags.TEST_SUITE_ID, "suite ID") - testSpan.setTag("random tag", "random value") - testSpan.finish() - - writer.waitForTraces(1) - - expect: - def trace = writer.firstTrace() - trace.size() == 1 - - def span = trace[0] - - span.getTag(Tags.TEST_SESSION_ID) == null - span.getTag(Tags.TEST_MODULE_ID) == null - span.getTag(Tags.TEST_SUITE_ID) == null - span.getTag("random tag") == "random value" - } -} diff --git a/dd-trace-core/src/test/groovy/datadog/trace/civisibility/interceptor/CiVisibilityTraceInterceptorTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/civisibility/interceptor/CiVisibilityTraceInterceptorTest.groovy deleted file mode 100644 index 9b50ed9c825..00000000000 --- a/dd-trace-core/src/test/groovy/datadog/trace/civisibility/interceptor/CiVisibilityTraceInterceptorTest.groovy +++ /dev/null @@ -1,66 +0,0 @@ -package datadog.trace.civisibility.interceptor - -import datadog.trace.api.DDSpanTypes -import datadog.trace.api.DDTags -import datadog.trace.api.civisibility.CIConstants -import datadog.trace.common.writer.ListWriter -import datadog.trace.core.DDSpanContext -import datadog.trace.core.test.DDCoreSpecification -import spock.lang.Timeout - -@Timeout(10) -class CiVisibilityTraceInterceptorTest extends DDCoreSpecification { - - def writer = new ListWriter() - def tracer = tracerBuilder().writer(writer).build() - - def cleanup() { - tracer?.close() - } - - def "discard a trace that does not come from ci app"() { - tracer.addTraceInterceptor(CiVisibilityTraceInterceptor.INSTANCE) - tracer.buildSpan("datadog", "sample-span").start().finish() - - expect: - writer.size() == 0 - } - - def "do not discard a trace that comes from ci app"() { - tracer.addTraceInterceptor(CiVisibilityTraceInterceptor.INSTANCE) - - def span = tracer.buildSpan("datadog", "sample-span").start() - ((DDSpanContext) span.context()).origin = CIConstants.CIAPP_TEST_ORIGIN - span.finish() - - expect: - writer.size() == 1 - } - - def "add tracer version to spans of type #spanType"() { - setup: - tracer.addTraceInterceptor(CiVisibilityTraceInterceptor.INSTANCE) - - - def span = tracer.buildSpan("datadog", "sample-span").withSpanType(spanType).start() - ((DDSpanContext) span.context()).origin = CIConstants.CIAPP_TEST_ORIGIN - span.finish() - writer.waitForTraces(1) - - expect: - def trace = writer.firstTrace() - trace.size() == 1 - - def receivedSpan = trace[0] - - receivedSpan.getTag(DDTags.LIBRARY_VERSION_TAG_KEY) != null - - where: - spanType << [ - DDSpanTypes.TEST, - DDSpanTypes.TEST_SUITE_END, - DDSpanTypes.TEST_MODULE_END, - DDSpanTypes.TEST_SESSION_END - ] - } -} diff --git a/dd-trace-core/src/test/groovy/datadog/trace/civisibility/writer/ddintake/CiTestCovMapperV2Test.groovy b/dd-trace-core/src/test/groovy/datadog/trace/civisibility/writer/ddintake/CiTestCovMapperV2Test.groovy deleted file mode 100644 index cdbe02fb3cc..00000000000 --- a/dd-trace-core/src/test/groovy/datadog/trace/civisibility/writer/ddintake/CiTestCovMapperV2Test.groovy +++ /dev/null @@ -1,333 +0,0 @@ -package datadog.trace.civisibility.writer.ddintake - -import com.fasterxml.jackson.databind.ObjectMapper -import datadog.communication.serialization.GrowableBuffer -import datadog.communication.serialization.msgpack.MsgPackWriter -import datadog.trace.api.DDTraceId -import datadog.trace.api.civisibility.coverage.CoverageProbes -import datadog.trace.api.civisibility.coverage.CoverageStore -import datadog.trace.api.civisibility.coverage.NoOpProbes -import datadog.trace.api.civisibility.coverage.TestReport -import datadog.trace.api.civisibility.coverage.TestReportFileEntry -import datadog.trace.api.civisibility.domain.TestContext -import datadog.trace.api.sampling.PrioritySampling -import datadog.trace.bootstrap.instrumentation.api.InternalSpanTypes -import datadog.trace.core.CoreSpan -import datadog.trace.core.propagation.PropagationTags -import datadog.trace.core.test.DDCoreSpecification -import org.msgpack.jackson.dataformat.MessagePackFactory -import spock.lang.Shared - -import java.nio.ByteBuffer -import java.nio.channels.WritableByteChannel - -class CiTestCovMapperV2Test extends DDCoreSpecification { - - @Shared - ObjectMapper objectMapper = new ObjectMapper(new MessagePackFactory()) - - def "test writes message"() { - given: - def trace = givenTrace(new TestReport(DDTraceId.from(1), 2, 3, [new TestReportFileEntry("source", BitSet.valueOf(new long[] { - 3, 5, 8 - }))])) - - when: - def message = getMappedMessage(trace) - - then: - message == [ - version : 2, - coverages: [ - [ - test_session_id: 1, - test_suite_id : 2, - span_id : 3, - files : [ - [ - filename: "source", - bitmap: [3, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 8] - ] - ] - ] - ] - ] - } - - def "test writes message with multiple files and multiple lines"() { - given: - def trace = givenTrace(new TestReport(DDTraceId.from(1), 2, 3, [ - new TestReportFileEntry("sourceA", BitSet.valueOf(new long[] { - 3, 5, 8 - })), - new TestReportFileEntry("sourceB", BitSet.valueOf(new long[] { - 1, 255, 7 - })) - ])) - - when: - def message = getMappedMessage(trace) - - then: - message == [ - version : 2, - coverages: [ - [ - test_session_id: 1, - test_suite_id : 2, - span_id : 3, - files : [ - [ - filename: "sourceA", - bitmap:[3, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 8] - ], - [ - filename: "sourceB", - bitmap:[1, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 7] - ] - ] - ] - ] - ] - } - - def "test writes message with multiple reports"() { - given: - def trace = givenTrace( - new TestReport(DDTraceId.from(1), 2, 3, [ - new TestReportFileEntry("sourceA", BitSet.valueOf(new long[] { - 2, 17, 41 - })) - ]), - new TestReport(DDTraceId.from(1), 2, 4, [ - new TestReportFileEntry("sourceB", BitSet.valueOf(new long[] { - 11, 13, 55 - })) - ]), - ) - - when: - def message = getMappedMessage(trace) - - then: - message == [ - version : 2, - coverages: [ - [ - test_session_id: 1, - test_suite_id : 2, - span_id : 3, - files : [ - [ - filename: "sourceA", - bitmap:[2, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 0, 0, 41] - ] - ] - ], - [ - test_session_id: 1, - test_suite_id : 2, - span_id : 4, - files : [ - [ - filename: "sourceB", - bitmap:[11, 0, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, 0, 0, 55] - ] - ] - ] - ] - ] - } - - def "skips spans that have no reports"() { - given: - def trace = givenTrace(null, new TestReport(DDTraceId.from(1), 2, 3, [new TestReportFileEntry("source", BitSet.valueOf(new long[] { - 83, 25, 48 - }))]), null) - - when: - def message = getMappedMessage(trace) - - then: - message == [ - version : 2, - coverages: [ - [ - test_session_id: 1, - test_suite_id : 2, - span_id : 3, - files : [ - [ - filename: "source", - bitmap:[83, 0, 0, 0, 0, 0, 0, 0, 25, 0, 0, 0, 0, 0, 0, 0, 48] - ] - ] - ] - ] - ] - } - - def "skips empty reports"() { - given: - def trace = givenTrace( - new TestReport(DDTraceId.from(1), 2, 3, [ - new TestReportFileEntry("source", BitSet.valueOf(new long[] { - 33, 53, 87 - })) - ]), - new TestReport(DDTraceId.from(1), 2, 4, []) - ) - - when: - def message = getMappedMessage(trace) - - then: - message == [ - version : 2, - coverages: [ - [ - test_session_id: 1, - test_suite_id : 2, - span_id : 3, - files : [ - [ - filename: "source", - bitmap:[33, 0, 0, 0, 0, 0, 0, 0, 53, 0, 0, 0, 0, 0, 0, 0, 87] - ] - ] - ] - ] - ] - } - - def "skips duplicate reports"() { - given: - def trace = new ArrayList() - - def report = new TestReport(DDTraceId.from(1), 2, 3, [new TestReportFileEntry("source", BitSet.valueOf(new long[] { - 3, 5, 8 - }))]) - - trace.add(buildSpan(0, InternalSpanTypes.TEST, PropagationTags.factory().empty(), [:], PrioritySampling.SAMPLER_KEEP, new DummyTestContext(new DummyReportHolder(report)))) - trace.add(buildSpan(0, "testChild", PropagationTags.factory().empty(), [:], PrioritySampling.SAMPLER_KEEP, new DummyTestContext(new DummyReportHolder(report)))) - - when: - def message = getMappedMessage(trace) - - then: - message == [ - version : 2, - coverages: [ - [ - test_session_id: 1, - test_suite_id : 2, - span_id : 3, - files : [ - [ - filename: "source", - bitmap:[3, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 8] - ] - ] - ] - ] - ] - } - - private List> givenTrace(TestReport... testReports) { - def trace = new ArrayList() - for (TestReport testReport : testReports) { - def testReportHolder = new DummyReportHolder(testReport) - trace.add(buildSpan(0, InternalSpanTypes.TEST, PropagationTags.factory().empty(), [:], PrioritySampling.SAMPLER_KEEP, new DummyTestContext(testReportHolder))) - } - return trace - } - - private Map getMappedMessage(List> trace) { - def buffer = new GrowableBuffer(1024) - def mapper = new CiTestCovMapperV2(false) - mapper.map(trace, new MsgPackWriter(buffer)) - - WritableByteChannel channel = new ByteArrayWritableByteChannel() - - def slice = buffer.slice() - def payload = mapper.newPayload().withBody(1, slice) - payload.writeTo(channel) - - def writtenBytes = channel.toByteArray() - return objectMapper.readValue(writtenBytes, Map) - } - - private static final class DummyReportHolder implements CoverageStore { - private final testReport - - DummyReportHolder(testReport) { - this.testReport = testReport - } - - @Override - TestReport getReport() { - testReport - } - - @Override - boolean report(DDTraceId testSessionId, Long testSuiteId, long spanId) { - return false - } - - @Override - CoverageProbes getProbes() { - return NoOpProbes.INSTANCE - } - } - - private static final class DummyTestContext implements TestContext { - private final CoverageStore coverageStore - - DummyTestContext(CoverageStore coverageStore) { - this.coverageStore = coverageStore - } - - @Override - CoverageStore getCoverageStore() { - return coverageStore - } - - @Override - def void set(Class key, T value) { - } - - @Override - def T get(Class key) { - return null - } - } - - - private static final class ByteArrayWritableByteChannel implements WritableByteChannel { - - private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream() - - @Override - int write(ByteBuffer src) throws IOException { - int remaining = src.remaining() - byte[] buffer = new byte[remaining] - src.get(buffer) - outputStream.write(buffer) - return remaining - } - - @Override - boolean isOpen() { - return true - } - - @Override - void close() throws IOException { - outputStream.close() - } - - byte[] toByteArray() { - return outputStream.toByteArray() - } - } -} diff --git a/dd-trace-core/src/test/groovy/datadog/trace/civisibility/writer/ddintake/CiTestCycleMapperV1PayloadTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/civisibility/writer/ddintake/CiTestCycleMapperV1PayloadTest.groovy deleted file mode 100644 index 0b8ce38ca22..00000000000 --- a/dd-trace-core/src/test/groovy/datadog/trace/civisibility/writer/ddintake/CiTestCycleMapperV1PayloadTest.groovy +++ /dev/null @@ -1,507 +0,0 @@ -package datadog.trace.civisibility.writer.ddintake - -import com.fasterxml.jackson.databind.ObjectMapper -import datadog.communication.serialization.ByteBufferConsumer -import datadog.communication.serialization.FlushingBuffer -import datadog.communication.serialization.msgpack.MsgPackWriter -import datadog.trace.api.DDTags -import datadog.trace.api.DDTraceId -import datadog.trace.api.civisibility.CiVisibilityWellKnownTags -import datadog.trace.bootstrap.instrumentation.api.InternalSpanTypes -import datadog.trace.bootstrap.instrumentation.api.Tags -import datadog.trace.common.writer.Payload -import datadog.trace.common.writer.TraceGenerator -import datadog.trace.core.DDSpanContext -import datadog.trace.test.util.DDSpecification -import org.junit.jupiter.api.Assertions -import org.msgpack.core.MessageFormat -import org.msgpack.core.MessagePack -import org.msgpack.core.MessageUnpacker -import org.msgpack.jackson.dataformat.MessagePackFactory - -import java.nio.ByteBuffer -import java.nio.channels.WritableByteChannel - -import static datadog.trace.bootstrap.instrumentation.api.InstrumentationTags.DD_MEASURED -import static datadog.trace.api.civisibility.CIConstants.MAX_META_STRING_VALUE_LENGTH -import static datadog.trace.util.Strings.truncate -import static datadog.trace.common.writer.TraceGenerator.generateRandomSpan -import static datadog.trace.common.writer.TraceGenerator.generateRandomTraces -import static org.junit.jupiter.api.Assertions.assertEquals -import static org.junit.jupiter.api.Assertions.assertFalse -import static org.junit.jupiter.api.Assertions.assertNotNull -import static org.msgpack.core.MessageFormat.FLOAT32 -import static org.msgpack.core.MessageFormat.FLOAT64 -import static org.msgpack.core.MessageFormat.INT16 -import static org.msgpack.core.MessageFormat.INT32 -import static org.msgpack.core.MessageFormat.INT64 -import static org.msgpack.core.MessageFormat.INT8 -import static org.msgpack.core.MessageFormat.NEGFIXINT -import static org.msgpack.core.MessageFormat.POSFIXINT -import static org.msgpack.core.MessageFormat.UINT16 -import static org.msgpack.core.MessageFormat.UINT32 -import static org.msgpack.core.MessageFormat.UINT64 -import static org.msgpack.core.MessageFormat.UINT8 - -class CiTestCycleMapperV1PayloadTest extends DDSpecification { - - def "test traces written correctly with bufferSize=#bufferSize, traceCount=#traceCount, lowCardinality=#lowCardinality"() { - setup: - CiVisibilityWellKnownTags wellKnownTags = new CiVisibilityWellKnownTags( - "runtimeid", "my-env", "language", - "my-runtime-name", "my-runtime-version", "my-runtime-vendor", - "my-os-arch", "my-os-platform", "my-os-version", "false") - CiTestCycleMapperV1 mapper = new CiTestCycleMapperV1(wellKnownTags, false) - - List> traces = generateRandomTraces(traceCount, lowCardinality) - PayloadVerifier verifier = new PayloadVerifier(wellKnownTags, traces, mapper) - - MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(bufferSize, verifier)) - when: - boolean tracesFitInBuffer = true - for (List trace : traces) { - if (!packer.format(trace, mapper)) { - verifier.skipLargeTrace() - tracesFitInBuffer = false - } - } - packer.flush() - then: - if (tracesFitInBuffer) { - verifier.verifyTracesConsumed() - } - where: - bufferSize | traceCount | lowCardinality - 20 << 10 | 0 | true - 20 << 10 | 1 | true - 30 << 10 | 1 | true - 30 << 10 | 2 | true - 20 << 10 | 0 | false - 20 << 10 | 1 | false - 30 << 10 | 1 | false - 30 << 10 | 2 | false - 100 << 10 | 0 | true - 100 << 10 | 1 | true - 100 << 10 | 10 | true - 100 << 10 | 100 | true - 100 << 10 | 1000 | true - 100 << 10 | 0 | false - 100 << 10 | 1 | false - 100 << 10 | 10 | false - 100 << 10 | 100 | false - 100 << 10 | 1000 | false - } - - def "verify test_suite_id, test_module_id, and test_session_id are written as top level tags in test event"() { - setup: - def span = generateRandomSpan(InternalSpanTypes.TEST, [ - (Tags.TEST_SESSION_ID): DDTraceId.from(123), - (Tags.TEST_MODULE_ID) : 456, - (Tags.TEST_SUITE_ID) : 789, - ]) - - when: - Map deserializedSpan = whenASpanIsWritten(span) - - then: - verifyTopLevelTags(deserializedSpan, DDTraceId.from(123), 456, 789) - - def spanContent = (Map) deserializedSpan.get("content") - assert spanContent.containsKey("trace_id") - assert spanContent.containsKey("span_id") - assert spanContent.containsKey("parent_id") - } - - def "truncates meta string values and preserves metrics and top level ids"() { - setup: - String longValue = "a" * (MAX_META_STRING_VALUE_LENGTH + 1) - String exactValue = "b" * MAX_META_STRING_VALUE_LENGTH - def span = generateRandomSpan(InternalSpanTypes.TEST, [ - (Tags.TEST_SESSION_ID): DDTraceId.from(123), - (Tags.TEST_MODULE_ID) : 456, - (Tags.TEST_SUITE_ID) : 789, - "custom.tag" : longValue, - "exact.tag" : exactValue, - "custom.metric" : 42, - ]) - - when: - Map deserializedSpan = whenASpanIsWritten(span) - - then: - verifyTopLevelTags(deserializedSpan, DDTraceId.from(123), 456, 789) - - def spanContent = (Map) deserializedSpan.get("content") - def deserializedMetrics = (Map) spanContent.get("metrics") - def deserializedMeta = (Map) spanContent.get("meta") - - assert deserializedMeta.get("custom.tag") == longValue.substring(0, MAX_META_STRING_VALUE_LENGTH) - assert deserializedMeta.get("custom.tag").length() == MAX_META_STRING_VALUE_LENGTH - assert deserializedMeta.get("exact.tag") == exactValue - assert deserializedMetrics.get("custom.metric") == 42 - } - - def "truncates payload metadata values"() { - setup: - String longValue = "m" * (MAX_META_STRING_VALUE_LENGTH + 1) - CiVisibilityWellKnownTags wellKnownTags = new CiVisibilityWellKnownTags( - longValue, longValue, longValue, - longValue, longValue, longValue, - longValue, longValue, longValue, longValue) - CiTestCycleMapperV1 mapper = new CiTestCycleMapperV1(wellKnownTags, false) - List> traces = Collections.singletonList( - Collections.singletonList(generateRandomSpan(InternalSpanTypes.TEST, Collections.emptyMap()))) - PayloadVerifier verifier = new PayloadVerifier(wellKnownTags, traces, mapper) - MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(100 << 10, verifier)) - - when: - packer.format(traces.get(0), mapper) - packer.flush() - - then: - verifier.verifyTracesConsumed() - } - - def "verify test_suite_end event is written correctly"() { - setup: - def span = generateRandomSpan(InternalSpanTypes.TEST_SUITE_END, [ - (Tags.TEST_SESSION_ID): DDTraceId.from(123), - (Tags.TEST_MODULE_ID) : 456, - (Tags.TEST_SUITE_ID) : 789, - ]) - - when: - Map deserializedSpan = whenASpanIsWritten(span) - - then: - verifyTopLevelTags(deserializedSpan, DDTraceId.from(123), 456, 789) - - def spanContent = (Map) deserializedSpan.get("content") - assert !spanContent.containsKey("trace_id") - assert !spanContent.containsKey("span_id") - assert !spanContent.containsKey("parent_id") - } - - def "verify test_module_end event is written correctly"() { - setup: - def span = generateRandomSpan(InternalSpanTypes.TEST_MODULE_END, [ - (Tags.TEST_SESSION_ID): DDTraceId.from(123), - (Tags.TEST_MODULE_ID) : 456, - ]) - - when: - Map deserializedSpan = whenASpanIsWritten(span) - - then: - verifyTopLevelTags(deserializedSpan, DDTraceId.from(123), 456, null) - - def spanContent = (Map) deserializedSpan.get("content") - assert !spanContent.containsKey("trace_id") - assert !spanContent.containsKey("span_id") - assert !spanContent.containsKey("parent_id") - } - - def "verify result is not affected by successive mapping calls"(){ - setup: - def span = generateRandomSpan(InternalSpanTypes.TEST, [ - (Tags.TEST_SESSION_ID): DDTraceId.from(123), - (Tags.TEST_MODULE_ID) : 456, - (Tags.TEST_SUITE_ID) : 789, - ]) - - when: - whenASpanIsWritten(span) - Map deserializedSpan = whenASpanIsWritten(span) - - then: - verifyTopLevelTags(deserializedSpan, DDTraceId.from(123), 456, 789) - - def spanContent = (Map) deserializedSpan.get("content") - assert spanContent.containsKey("trace_id") - assert spanContent.containsKey("span_id") - assert spanContent.containsKey("parent_id") - } - - private static void verifyTopLevelTags(Map deserializedSpan, DDTraceId testSessionId, Long testModuleId, Long testSuiteId) { - Map deserializedSpanContent = (Map) deserializedSpan.get("content") - Map deserializedMetrics = (Map) deserializedSpanContent.get("metrics") - Map deserializedMeta = (Map) deserializedSpanContent.get("meta") - - if (testSessionId != null) { - assert deserializedSpanContent.get(Tags.TEST_SESSION_ID) == testSessionId.toLong() - } else { - assert !deserializedSpanContent.containsKey(Tags.TEST_SESSION_ID) - } - - if (testModuleId != null) { - assert deserializedSpanContent.get(Tags.TEST_MODULE_ID) == testModuleId - } else { - assert !deserializedSpanContent.containsKey(Tags.TEST_MODULE_ID) - } - - if (testSuiteId != null) { - assert deserializedSpanContent.get(Tags.TEST_SUITE_ID) == testSuiteId - } else { - assert !deserializedSpanContent.containsKey(Tags.TEST_SUITE_ID) - } - - assert !deserializedMetrics.containsKey(Tags.TEST_SESSION_ID) - assert !deserializedMetrics.containsKey(Tags.TEST_MODULE_ID) - assert !deserializedMetrics.containsKey(Tags.TEST_SUITE_ID) - - assert !deserializedMeta.containsKey(Tags.TEST_SESSION_ID) - assert !deserializedMeta.containsKey(Tags.TEST_MODULE_ID) - assert !deserializedMeta.containsKey(Tags.TEST_SUITE_ID) - } - - private static Map whenASpanIsWritten(TraceGenerator.PojoSpan span) { - List trace = Collections.singletonList(span) - - CiVisibilityWellKnownTags wellKnownTags = new CiVisibilityWellKnownTags( - "runtimeid", "my-env", "language", - "my-runtime-name", "my-runtime-version", "my-runtime-vendor", - "my-os-arch", "my-os-platform", "my-os-version", "false") - CiTestCycleMapperV1 mapper = new CiTestCycleMapperV1(wellKnownTags, false) - - ByteBufferConsumer consumer = new CaptureConsumer() - MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(100 << 10, consumer)) - - packer.format(trace, mapper) - packer.flush() - - ObjectMapper objectMapper = new ObjectMapper(new MessagePackFactory()) - return (Map) objectMapper.readValue(consumer.bytes, Object) - } - - private static class CaptureConsumer implements ByteBufferConsumer { - private byte[] bytes - - @Override - void accept(int messageCount, ByteBuffer buffer) { - this.bytes = new byte[buffer.limit() - buffer.position()] - buffer.get(bytes) - } - } - - private static final class PayloadVerifier implements ByteBufferConsumer, WritableByteChannel { - - private final List> expectedTraces - private final CiTestCycleMapperV1 mapper - private final CiVisibilityWellKnownTags wellKnownTags - private ByteBuffer captured = ByteBuffer.allocate(200 << 10) - - private int position = 0 - - private PayloadVerifier(CiVisibilityWellKnownTags wellKnownTags, List> traces, CiTestCycleMapperV1 mapper) { - this.expectedTraces = traces - this.mapper = mapper - this.wellKnownTags = wellKnownTags - } - - void skipLargeTrace() { - ++position - } - - void verifyTracesConsumed() { - assertEquals(expectedTraces.size(), position) - } - - @Override - void accept(int messageCount, ByteBuffer buffer) { - if (expectedTraces.isEmpty() && messageCount == 0) { - return - } - - try { - Payload payload = mapper.newPayload().withBody(messageCount, buffer) - payload.writeTo(this) - captured.flip() - assertNotNull(payload.toRequest()) - MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(captured) - assertEquals(3, unpacker.unpackMapHeader()) - assertEquals("version", unpacker.unpackString()) - assertEquals(1, unpacker.unpackInt()) - assertEquals("metadata", unpacker.unpackString()) - assertEquals(1, unpacker.unpackMapHeader()) - assertEquals("*", unpacker.unpackString()) - - assertEquals(10, unpacker.unpackMapHeader()) - assertEquals("env", unpacker.unpackString()) - assertEquals(truncate(wellKnownTags.env as String, MAX_META_STRING_VALUE_LENGTH), unpacker.unpackString()) - assertEquals("runtime-id", unpacker.unpackString()) - assertEquals(truncate(wellKnownTags.runtimeId as String, MAX_META_STRING_VALUE_LENGTH), unpacker.unpackString()) - assertEquals("language", unpacker.unpackString()) - assertEquals(truncate(wellKnownTags.language as String, MAX_META_STRING_VALUE_LENGTH), unpacker.unpackString()) - assertEquals(Tags.RUNTIME_NAME, unpacker.unpackString()) - assertEquals(truncate(wellKnownTags.runtimeName as String, MAX_META_STRING_VALUE_LENGTH), unpacker.unpackString()) - assertEquals(Tags.RUNTIME_VENDOR, unpacker.unpackString()) - assertEquals(truncate(wellKnownTags.runtimeVendor as String, MAX_META_STRING_VALUE_LENGTH), unpacker.unpackString()) - assertEquals(Tags.RUNTIME_VERSION, unpacker.unpackString()) - assertEquals(truncate(wellKnownTags.runtimeVersion as String, MAX_META_STRING_VALUE_LENGTH), unpacker.unpackString()) - assertEquals(Tags.OS_ARCHITECTURE, unpacker.unpackString()) - assertEquals(truncate(wellKnownTags.osArch as String, MAX_META_STRING_VALUE_LENGTH), unpacker.unpackString()) - assertEquals(Tags.OS_PLATFORM, unpacker.unpackString()) - assertEquals(truncate(wellKnownTags.osPlatform as String, MAX_META_STRING_VALUE_LENGTH), unpacker.unpackString()) - assertEquals(Tags.OS_VERSION, unpacker.unpackString()) - assertEquals(truncate(wellKnownTags.osVersion as String, MAX_META_STRING_VALUE_LENGTH), unpacker.unpackString()) - assertEquals(DDTags.TEST_IS_USER_PROVIDED_SERVICE, unpacker.unpackString()) - assertEquals(truncate(wellKnownTags.isUserProvidedService as String, MAX_META_STRING_VALUE_LENGTH), unpacker.unpackString()) - - assertEquals("events", unpacker.unpackString()) - - List expectedTrace = expectedTraces.get(position++) - int eventCount = unpacker.unpackArrayHeader() - while (expectedTrace.size() < eventCount) { - expectedTrace.addAll(expectedTraces.get(position++)) - } - assertEquals(expectedTrace.size(), eventCount) - for (int k = 0; k < eventCount; ++k) { - TraceGenerator.PojoSpan expectedSpan = expectedTrace.get(k) - assertEquals(3, unpacker.unpackMapHeader()) - assertEquals("type", unpacker.unpackString()) - if ("test" == String.valueOf(expectedSpan.getType())) { - assertEquals("test", unpacker.unpackString()) - } else { - assertEquals("span", unpacker.unpackString()) - } - assertEquals("version", unpacker.unpackString()) - assertEquals(1, unpacker.unpackInt()) - assertEquals("content", unpacker.unpackString()) - assertEquals(11, unpacker.unpackMapHeader()) - assertEquals("trace_id", unpacker.unpackString()) - long traceId = unpacker.unpackValue().asNumberValue().toLong() - assertEquals(expectedSpan.getTraceId().toLong(), traceId) - assertEquals("span_id", unpacker.unpackString()) - long spanId = unpacker.unpackValue().asNumberValue().toLong() - assertEquals(expectedSpan.getSpanId(), spanId) - assertEquals("parent_id", unpacker.unpackString()) - long parentId = unpacker.unpackValue().asNumberValue().toLong() - assertEquals(expectedSpan.getParentId(), parentId) - assertEquals("service", unpacker.unpackString()) - String serviceName = unpacker.unpackString() - assertEqualsWithNullAsEmpty(expectedSpan.getServiceName(), serviceName) - assertEquals("name", unpacker.unpackString()) - String operationName = unpacker.unpackString() - assertEqualsWithNullAsEmpty(expectedSpan.getOperationName(), operationName) - assertEquals("resource", unpacker.unpackString()) - String resourceName = unpacker.unpackString() - assertEqualsWithNullAsEmpty(expectedSpan.getResourceName(), resourceName) - - assertEquals("start", unpacker.unpackString()) - long startTime = unpacker.unpackLong() - assertEquals(expectedSpan.getStartTime(), startTime) - assertEquals("duration", unpacker.unpackString()) - long duration = unpacker.unpackLong() - assertEquals(expectedSpan.getDurationNano(), duration) - assertEquals("error", unpacker.unpackString()) - int error = unpacker.unpackInt() - assertEquals(expectedSpan.getError(), error) - assertEquals("metrics", unpacker.unpackString()) - int metricsSize = unpacker.unpackMapHeader() - HashMap metrics = new HashMap<>() - for (int j = 0; j < metricsSize; ++j) { - String key = unpacker.unpackString() - Number n = null - MessageFormat format = unpacker.getNextFormat() - switch (format) { - case NEGFIXINT: - case POSFIXINT: - case INT8: - case UINT8: - case INT16: - case UINT16: - case INT32: - case UINT32: - n = unpacker.unpackInt() - break - case INT64: - case UINT64: - n = unpacker.unpackLong() - break - case FLOAT32: - n = unpacker.unpackFloat() - break - case FLOAT64: - n = unpacker.unpackDouble() - break - default: - Assertions.fail("Unexpected type in metrics values: " + format) - } - if (DD_MEASURED.toString() == key) { - assert ((n == 1) && expectedSpan.isMeasured()) || !expectedSpan.isMeasured() - } else if (DDSpanContext.PRIORITY_SAMPLING_KEY == key) { - //check that priority sampling is only on first and last span - if (k == 0 || k == eventCount - 1) { - assertEquals(expectedSpan.samplingPriority(), n.intValue()) - } else { - assertFalse(expectedSpan.hasSamplingPriority()) - } - } else { - metrics.put(key, n) - } - } - for (Map.Entry metric : metrics.entrySet()) { - if (metric.getValue() instanceof Double || metric.getValue() instanceof Float) { - assertEquals(((Number) expectedSpan.getTag(metric.getKey())).doubleValue(), metric.getValue().doubleValue(), 0.001) - } else { - assertEquals(expectedSpan.getTag(metric.getKey()), metric.getValue()) - } - } - assertEquals("meta", unpacker.unpackString()) - int metaSize = unpacker.unpackMapHeader() - HashMap meta = new HashMap<>() - for (int j = 0; j < metaSize; ++j) { - meta.put(unpacker.unpackString(), unpacker.unpackString()) - } - for (Map.Entry entry : meta.entrySet()) { - if (Tags.HTTP_STATUS.equals(entry.getKey())) { - assertEquals(String.valueOf(expectedSpan.getHttpStatusCode()), entry.getValue()) - } else { - Object tag = expectedSpan.getTag(entry.getKey()) - if (null != tag) { - assertEquals(String.valueOf(tag), entry.getValue()) - } else { - assertEquals(expectedSpan.getBaggage().get(entry.getKey()), entry.getValue()) - } - } - } - } - } catch (IOException e) { - Assertions.fail(e.getMessage()) - } finally { - mapper.reset() - captured.position(0) - captured.limit(captured.capacity()) - } - } - - @Override - int write(ByteBuffer src) throws IOException { - if (captured.remaining() < src.remaining()) { - ByteBuffer newBuffer = ByteBuffer.allocate(captured.capacity() + src.capacity()) - captured.flip() - newBuffer.put(captured) - captured = newBuffer - return write(src) - } - captured.put(src) - return src.position() - } - - @Override - boolean isOpen() { - return true - } - - @Override - void close() throws IOException {} - } - - private static void assertEqualsWithNullAsEmpty(CharSequence expected, CharSequence actual) { - if (null == expected) { - assertEquals("", actual) - } else { - assertEquals(expected.toString(), actual.toString()) - } - } -} diff --git a/dd-trace-core/src/test/groovy/datadog/trace/common/writer/TraceGenerator.groovy b/dd-trace-core/src/test/groovy/datadog/trace/common/writer/TraceGenerator.groovy deleted file mode 100644 index 66bdbab137b..00000000000 --- a/dd-trace-core/src/test/groovy/datadog/trace/common/writer/TraceGenerator.groovy +++ /dev/null @@ -1,455 +0,0 @@ -package datadog.trace.common.writer - -import static datadog.trace.api.sampling.PrioritySampling.UNSET -import static java.util.Collections.emptyList - -import datadog.trace.api.DDSpanId -import datadog.trace.api.DDTags -import datadog.trace.api.DDTraceId -import datadog.trace.api.IdGenerationStrategy -import datadog.trace.api.ProcessTags -import datadog.trace.api.TagMap -import datadog.trace.api.sampling.PrioritySampling -import datadog.trace.bootstrap.instrumentation.api.AgentSpanLink -import datadog.trace.bootstrap.instrumentation.api.UTF8BytesString -import datadog.trace.core.CoreSpan -import datadog.trace.core.Metadata -import datadog.trace.core.MetadataConsumer -import java.util.concurrent.ThreadLocalRandom -import java.util.concurrent.TimeUnit - -class TraceGenerator { - - static List> generateRandomTraces(int howMany, boolean lowCardinality) { - List> traces = new ArrayList<>(howMany) - for (int i = 0; i < howMany; ++i) { - int traceSize = ThreadLocalRandom.current().nextInt(2, 20) - traces.add(generateRandomTrace(traceSize, lowCardinality)) - } - return traces - } - - private static List generateRandomTrace(int size, boolean lowCardinality) { - List trace = new ArrayList<>(size) - long traceId = ThreadLocalRandom.current().nextLong(1, Long.MAX_VALUE) - for (int i = 0; i < size; ++i) { - def spanType = "type-" + ThreadLocalRandom.current().nextInt(lowCardinality ? 1 : 100) - trace.add(randomSpan(traceId, lowCardinality, spanType, Collections.emptyMap())) - } - return trace - } - - private static final IdGenerationStrategy ID_GENERATION_STRATEGY = IdGenerationStrategy.fromName("RANDOM") - - static CoreSpan generateRandomSpan(CharSequence type, Map extraTags) { - long traceId = ThreadLocalRandom.current().nextLong(1, Long.MAX_VALUE) - return randomSpan(traceId, true, type, extraTags) - } - - private static CoreSpan randomSpan(long traceId, boolean lowCardinality, CharSequence type, Map extraTags) { - ThreadLocalRandom random = ThreadLocalRandom.current() - Map baggage = new HashMap<>() - if (random.nextBoolean()) { - baggage.put("baggage-key", lowCardinality ? "x" : randomString(100)) - if (random.nextBoolean()) { - baggage.put("tag.1", "bar") - baggage.put("tag.2", "qux") - } - } - Map tags = new HashMap<>(extraTags) - int tagCount = random.nextInt(0, 20) - for (int i = 0; i < tagCount; ++i) { - tags.put("tag." + i, random.nextBoolean() ? "foo" : randomString(2000)) - tags.put("tag.1." + i, lowCardinality ? "y" : UUID.randomUUID()) - tags.put("tag.2." + i, random.nextBoolean()) - switch (random.nextInt(8)) { - case 0: - tags.put("tag.3." + i, BigDecimal.valueOf(random.nextDouble())) - break - case 1: - tags.put("tag.3." + i, BigInteger.valueOf(random.nextLong())) - break - default: - break - } - } - int metricCount = random.nextInt(0, 20) - for (int i = 0; i < metricCount; ++i) { - String name = "metric." + i - Number metric = null - switch (random.nextInt(4)) { - case 0: - metric = random.nextInt() - break - case 1: - metric = random.nextLong() - break - case 2: - metric = random.nextFloat() - break - case 3: - metric = random.nextDouble() - break - } - tags.put(name, metric) - } - - return new PojoSpan( - "service-" + random.nextInt(lowCardinality ? 1 : 10), - "operation-" + random.nextInt(lowCardinality ? 1 : 100), - UTF8BytesString.create("resource-" + random.nextInt(lowCardinality ? 1 : 100)), - DDTraceId.from(traceId), - ID_GENERATION_STRATEGY.generateSpanId(), - DDSpanId.ZERO, - TimeUnit.MILLISECONDS.toNanos(System.currentTimeMillis()), - random.nextLong(500, 10_000_000), - random.nextInt(2), - baggage, - tags, - type, - random.nextBoolean(), - PrioritySampling.SAMPLER_KEEP, - 200, - "some-origin") - } - - private static String randomString(int maxLength) { - char[] chars = new char[ThreadLocalRandom.current().nextInt(maxLength)] - for (int i = 0; i < chars.length; ++i) { - char next = (char) ThreadLocalRandom.current().nextInt((int) Character.MAX_VALUE) - if (Character.isSurrogate(next)) { - if (i < chars.length - 1) { - chars[i++] = '\uD801' - chars[i] = '\uDC01' - } else { - chars[i] = 'a' - } - } else { - chars[i] = next - } - } - return new String(chars) - } - - static class PojoSpan implements CoreSpan { - - private final CharSequence serviceName - private final CharSequence operationName - private final CharSequence resourceName - private final DDTraceId traceId - private final long spanId - private final long parentId - private final long start - private final long duration - private final int error - private final CharSequence type - private final boolean measured - private final Metadata metadata - private short httpStatusCode - private final int samplingPriority - private final Map metaStruct = [:] - - PojoSpan( - String serviceName, - String operationName, - CharSequence resourceName, - DDTraceId traceId, - long spanId, - long parentId, - long start, - long duration, - int error, - Map baggage, - Map tags, - CharSequence type, - boolean measured, - int samplingPriority, - int statusCode, - CharSequence origin, - List spanLinks = emptyList()) { - this.serviceName = UTF8BytesString.create(serviceName) - this.operationName = UTF8BytesString.create(operationName) - this.resourceName = UTF8BytesString.create(resourceName) - this.traceId = traceId - this.spanId = spanId - this.parentId = parentId - this.start = start - this.duration = duration - this.error = error - this.type = type - this.measured = measured - this.samplingPriority = samplingPriority - this.httpStatusCode = (short) statusCode - this.metadata = new Metadata( - Thread.currentThread().getId(), - UTF8BytesString.create(Thread.currentThread().getName()), - TagMap.fromMap(tags), - baggage, - samplingPriority, - measured, - topLevel, - statusCode == 0 ? null : UTF8BytesString.create(Integer.toString(statusCode)), - origin, - 0, - ProcessTags.tagsForSerialization, - spanLinks) - } - - @Override - PojoSpan getLocalRootSpan() { - return this - } - - @Override - String getServiceName() { - return serviceName - } - - @Override - CharSequence getOperationName() { - return operationName - } - - @Override - CharSequence getResourceName() { - return resourceName - } - - @Override - DDTraceId getTraceId() { - return traceId - } - - @Override - long getSpanId() { - return spanId - } - - @Override - long getParentId() { - return parentId - } - - @Override - long getStartTime() { - return start - } - - @Override - long getDurationNano() { - return duration - } - - @Override - int getError() { - return error - } - - @Override - PojoSpan setMeasured(boolean measured) { - return this - } - - @Override - PojoSpan setErrorMessage(String errorMessage) { - return this - } - - @Override - PojoSpan addThrowable(Throwable error) { - return this - } - - @Override - PojoSpan setTag(String tag, String value) { - return this - } - - @Override - PojoSpan setTag(String tag, boolean value) { - return this - } - - @Override - PojoSpan setTag(String tag, int value) { - return this - } - - @Override - PojoSpan setTag(String tag, long value) { - return this - } - - @Override - PojoSpan setTag(String tag, double value) { - return this - } - - @Override - PojoSpan setTag(String tag, Number value) { - return this - } - - @Override - PojoSpan setTag(String tag, CharSequence value) { - return this - } - - @Override - PojoSpan setTag(String tag, Object value) { - return this - } - - @Override - PojoSpan removeTag(String tag) { - metadata.getTags().remove(tag) - return this - } - - @Override - boolean isMeasured() { - return measured - } - - @Override - boolean isTopLevel() { - return false - } - - @Override - boolean isForceKeep() { - return false - } - - @Override - short getHttpStatusCode() { - return httpStatusCode - } - - @Override - CharSequence getOrigin() { - return metadata.getOrigin() - } - - Map getBaggage() { - return metadata.getBaggage() - } - - Map getTags() { - return metadata.getTags() - } - - @Override - CharSequence getType() { - return this.type - } - - @Override - void processServiceTags() {} - - @Override - void processTagsAndBaggage(MetadataConsumer consumer) { - consumer.accept(metadata) - } - - @Override - void processTagsAndBaggage(MetadataConsumer consumer, boolean injectLinksAsTags, boolean injectBaggageAsTags) { - consumer.accept(metadata) - } - - @Override - PojoSpan setSamplingPriority(int samplingPriority, int samplingMechanism) { - return this - } - - @Override - PojoSpan setSamplingPriority(int samplingPriority, CharSequence rate, double sampleRate, int samplingMechanism) { - return this - } - - @Override - PojoSpan setSpanSamplingPriority(double rate, int limit) { - return this - } - - @Override - PojoSpan setMetric(CharSequence name, int value) { - return this - } - - @Override - PojoSpan setMetric(CharSequence name, long value) { - return this - } - - @Override - PojoSpan setMetric(CharSequence name, float value) { - return this - } - - @Override - PojoSpan setMetric(CharSequence name, double value) { - return this - } - - @Override - PojoSpan setFlag(CharSequence name, boolean value) { - return this - } - - @Override - int samplingPriority() { - return samplingPriority - } - - @Override - U getTag(CharSequence name, U defaultValue) { - U value = getTag(name) - return null == value ? defaultValue : value - } - - @Override - U getTag(CharSequence name) { - // replicate logic here because DDSpanContext has to pretend some of its - // fields are elements of a map for backward compatibility reasons - String tag = String.valueOf(name) - Object value = null - switch (tag) { - case DDTags.THREAD_ID: - value = metadata.getThreadId() - break - case DDTags.THREAD_NAME: - value = metadata.getThreadName() - break - default: - value = tags.get(tag) - } - return value as U - } - - @Override - boolean hasSamplingPriority() { - return samplingPriority != UNSET - } - - @Override - Map getMetaStruct() { - return metaStruct - } - - @Override - PojoSpan setMetaStruct(String field, Object value) { - if (value == null) { - metaStruct.remove(field) - } else { - metaStruct[field] = value - } - return this - } - - @Override - int getLongRunningVersion() { - return 0 - } - } -} diff --git a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy deleted file mode 100644 index 81410905dc3..00000000000 --- a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaAppSecHandlerTest.groovy +++ /dev/null @@ -1,1447 +0,0 @@ -package datadog.trace.lambda - -import datadog.trace.api.Config -import datadog.trace.api.function.TriConsumer -import datadog.trace.api.function.TriFunction -import datadog.trace.api.gateway.CallbackProvider -import datadog.trace.api.gateway.Flow -import datadog.trace.api.gateway.RequestContext -import datadog.trace.api.gateway.RequestContextSlot -import datadog.trace.bootstrap.ActiveSubsystems -import datadog.trace.bootstrap.instrumentation.api.AgentSpan -import datadog.trace.bootstrap.instrumentation.api.AgentSpanContext -import datadog.trace.bootstrap.instrumentation.api.AgentTracer -import datadog.trace.bootstrap.instrumentation.api.TagContext -import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter -import datadog.trace.core.test.DDCoreSpecification -import spock.lang.Shared - -import java.nio.charset.StandardCharsets -import java.util.function.BiFunction -import java.util.function.Function -import java.util.function.Supplier - -import static datadog.trace.api.gateway.Events.EVENTS - -class LambdaAppSecHandlerTest extends DDCoreSpecification { - - @Shared - def originalAppSecActive - - @Shared - AgentTracer.TracerAPI originalTracer - - def setupSpec() { - originalAppSecActive = ActiveSubsystems.APPSEC_ACTIVE - originalTracer = AgentTracer.get() - } - - def cleanupSpec() { - ActiveSubsystems.APPSEC_ACTIVE = originalAppSecActive - } - - def setup() { - ActiveSubsystems.APPSEC_ACTIVE = true - } - - def "processRequestStart returns null when AppSec is disabled"() { - given: - ActiveSubsystems.APPSEC_ACTIVE = false - def event = createInputStream('{"test": "data"}') - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result == null - } - - def "processRequestStart returns null for non-ByteArrayInputStream"() { - when: - def result = LambdaAppSecHandler.processRequestStart("not a stream") - - then: - result == null - } - - def "processRequestStart returns null for null event"() { - when: - def result = LambdaAppSecHandler.processRequestStart(null) - - then: - result == null - } - - def "processRequestStart returns null for oversized event"() { - given: - def maxSize = Config.get().getAppSecBodyParsingSizeLimit() - def largeBody = "x" * (maxSize + 1) - def event = createInputStream(largeBody) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result == null - } - - def "processRequestStart returns null for zero-size event"() { - given: - def event = createInputStream('') - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result == null - } - - def "processRequestStart returns null for malformed JSON"() { - given: - def event = createInputStream('{invalid json') - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result == null - } - - def "stream can be read multiple times after processing"() { - given: - def jsonData = '{"test": "data", "requestContext": {"httpMethod": "GET"}}' - def event = createInputStream(jsonData) - - when: - LambdaAppSecHandler.processRequestStart(event) - event.reset() - def content = new String(event.bytes, StandardCharsets.UTF_8) - - then: - content == jsonData - } - - - // ============================================================================ - // Trigger Type Detection Tests - // ============================================================================ - - def "detects API Gateway v1 REST trigger type"() { - given: - def event = [ - requestContext: [ - httpMethod: "GET", - requestId: "abc123" - ] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V1_REST - } - - def "detects API Gateway v2 HTTP trigger type"() { - given: - def event = [ - requestContext: [ - http: [ - method: "POST", - path: "/api" - ], - domainName: "api.example.com" - ] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_HTTP - } - - def "detects Lambda Function URL trigger type"() { - given: - def event = [ - requestContext: [ - http: [ - method: "GET", - path: "/" - ], - domainName: "xyz123.lambda-url.us-east-1.on.aws" - ] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL - } - - def "detects ALB trigger type without multi-value headers"() { - given: - def event = [ - httpMethod: "GET", - path: "/", - requestContext: [ - elb: [ - targetGroupArn: "arn:aws:..." - ] - ] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.ALB - } - - def "detects ALB trigger type with multi-value headers"() { - given: - def event = [ - httpMethod: "GET", - path: "/", - multiValueHeaders: [ - accept: ["text/html", "application/json"] - ], - requestContext: [ - elb: [ - targetGroupArn: "arn:aws:..." - ] - ] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.ALB_MULTI_VALUE - } - - def "detects WebSocket trigger type with routeKey"() { - given: - def event = [ - requestContext: [ - connectionId: "conn-123", - routeKey: "\$connect" - ] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET - } - - def "detects WebSocket trigger type with eventType"() { - given: - def event = [ - requestContext: [ - connectionId: "conn-456", - eventType: "CONNECT" - ] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET - } - - def "detects unknown trigger type for unrecognized events"() { - given: - def event = [ - someUnknownField: "value" - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.UNKNOWN - } - - def "detects unknown trigger type for empty requestContext"() { - given: - def event = [ - requestContext: [:] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.UNKNOWN - } - - def "detects Lambda URL when http present but no domainName"() { - given: - def event = [ - requestContext: [ - http: [ - method: "GET", - path: "/ambiguous" - ] - ] - ] - - when: - def triggerType = LambdaAppSecHandler.detectTriggerType(event) - - then: - triggerType == LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL - } - - // ============================================================================ - // Data Extraction Tests with Mocked Callbacks - // ============================================================================ - - def "extracts API Gateway v1 REST data correctly"() { - given: - def eventJson = ''' - { - "path": "/api/users/123", - "httpMethod": "POST", - "headers": { - "Content-Type": "application/json", - "Authorization": "Bearer token123" - }, - "pathParameters": { - "userId": "123" - }, - "body": "{\\"name\\": \\"John\\"}", - "requestContext": { - "httpMethod": "POST", - "requestId": "req-123", - "identity": { - "sourceIp": "192.168.1.100" - } - } - } - ''' - def event = createInputStream(eventJson) - - // Track callback invocations - def capturedMethod = null - def capturedPath = null - def capturedHeaders = [:] - def capturedSourceIp = null - def capturedSourcePort = null - def capturedPathParams = null - def capturedBody = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onHeader: { name, value -> - capturedHeaders[name] = value - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - capturedSourcePort = port - }, - onPathParams: { params -> - capturedPathParams = params - }, - onBody: { body -> - capturedBody = body - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - result instanceof TagContext - - capturedMethod == "POST" - capturedPath == "/api/users/123" - capturedHeaders["Content-Type"] == "application/json" - capturedHeaders["Authorization"] == "Bearer token123" - capturedSourceIp == "192.168.1.100" - capturedSourcePort == 0 - capturedPathParams == ["userId": "123"] - capturedBody instanceof Map - capturedBody.name == "John" - } - - def "extracts API Gateway v2 HTTP data correctly"() { - given: - def eventJson = ''' - { - "version": "2.0", - "headers": { - "content-type": "application/json", - "x-custom-header": "custom-value" - }, - "cookies": ["session=abc123", "user=john"], - "pathParameters": { - "id": "456" - }, - "body": "test body", - "requestContext": { - "http": { - "method": "PUT", - "path": "/api/items/456", - "sourceIp": "10.0.0.50", - "sourcePort": 54321 - }, - "domainName": "api.example.com" - } - } - ''' - def event = createInputStream(eventJson) - - def capturedMethod = null - def capturedPath = null - def capturedHeaders = [:] - def capturedSourceIp = null - def capturedSourcePort = null - def capturedPathParams = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onHeader: { name, value -> - capturedHeaders[name] = value - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - capturedSourcePort = port - }, - onPathParams: { params -> - capturedPathParams = params - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedMethod == "PUT" - capturedPath == "/api/items/456" - capturedHeaders["content-type"] == "application/json" - capturedHeaders["x-custom-header"] == "custom-value" - capturedHeaders["cookie"] == "session=abc123; user=john" - capturedSourceIp == "10.0.0.50" - capturedSourcePort == 54321 - capturedPathParams == ["id": "456"] - } - - def "extracts Lambda Function URL data correctly"() { - given: - def eventJson = ''' - { - "version": "2.0", - "headers": { - "host": "xyz.lambda-url.us-east-1.on.aws" - }, - "requestContext": { - "http": { - "method": "GET", - "path": "/function/path", - "sourceIp": "1.2.3.4" - }, - "domainName": "xyz.lambda-url.us-east-1.on.aws" - } - } - ''' - def event = createInputStream(eventJson) - - def capturedMethod = null - def capturedPath = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedMethod == "GET" - capturedPath == "/function/path" - } - - def "extracts ALB data correctly"() { - given: - def eventJson = ''' - { - "path": "/alb/test", - "httpMethod": "DELETE", - "headers": { - "x-forwarded-for": "203.0.113.42", - "user-agent": "curl/7.64.1" - }, - "requestContext": { - "elb": { - "targetGroupArn": "arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/my-target-group/50dc6c495c0c9188" - } - } - } - ''' - def event = createInputStream(eventJson) - - def capturedMethod = null - def capturedPath = null - def capturedSourceIp = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedMethod == "DELETE" - capturedPath == "/alb/test" - capturedSourceIp == "203.0.113.42" - } - - def "extracts ALB multi-value headers correctly"() { - given: - def eventJson = ''' - { - "path": "/test", - "httpMethod": "GET", - "multiValueHeaders": { - "accept": ["text/html", "application/json"], - "x-custom": ["value1", "value2"] - }, - "requestContext": { - "elb": { - "targetGroupArn": "arn:aws:..." - } - } - } - ''' - def event = createInputStream(eventJson) - - def capturedHeaders = [:] - - setupMockCallbacks( - onHeader: { name, value -> - capturedHeaders[name] = value - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedHeaders["accept"] == "text/html, application/json" - capturedHeaders["x-custom"] == "value1, value2" - } - - def "handles multi-value headers with empty list"() { - given: - def eventJson = ''' - { - "path": "/test", - "httpMethod": "GET", - "multiValueHeaders": { - "accept": [], - "x-custom": ["value1"] - }, - "requestContext": { - "elb": { - "targetGroupArn": "arn:aws:..." - } - } - } - ''' - def event = createInputStream(eventJson) - - def capturedHeaders = [:] - - setupMockCallbacks( - onHeader: { name, value -> - capturedHeaders[name] = value - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedHeaders["accept"] == "" // Empty list should result in empty string - capturedHeaders["x-custom"] == "value1" - } - - def "extracts WebSocket data correctly"() { - given: - def eventJson = ''' - { - "requestContext": { - "routeKey": "$connect", - "connectionId": "conn-abc123", - "identity": { - "sourceIp": "192.168.0.100" - } - } - } - ''' - def event = createInputStream(eventJson) - - def capturedMethod = null - def capturedPath = null - def capturedSourceIp = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedMethod == "WEBSOCKET" - capturedPath == "\$connect" - capturedSourceIp == "192.168.0.100" - } - - def "handles base64 encoded body correctly"() { - given: - def originalBody = "This is test data" - def base64Body = Base64.getEncoder().encodeToString(originalBody.getBytes()) - def eventJson = """ - { - "body": "${base64Body}", - "isBase64Encoded": true, - "requestContext": { - "httpMethod": "POST" - } - } - """ - def event = createInputStream(eventJson) - - def capturedBody = null - - setupMockCallbacks( - onBody: { body -> - capturedBody = body - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedBody == originalBody - } - - def "handles null body correctly"() { - given: - def event = createInputStream('{"body": null, "requestContext": {"httpMethod": "GET"}}') - - def capturedBody = "NOT_CALLED" - - setupMockCallbacks( - onBody: { body -> - capturedBody = body - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedBody == "NOT_CALLED" // Callback should not be invoked for null body - } - - def "handles empty body correctly"() { - given: - def event = createInputStream('{"body": "", "requestContext": {"httpMethod": "POST"}}') - - def capturedBody = null - - setupMockCallbacks( - onBody: { body -> - capturedBody = body - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedBody == "" // Empty body is passed as empty string to WAF - } - - def "handles path with query string correctly"() { - given: - def eventJson = ''' - { - "path": "/api/users?id=123&filter=active", - "requestContext": { - "httpMethod": "GET" - } - } - ''' - def event = createInputStream(eventJson) - - def capturedPath = null - def capturedQuery = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedPath = uri.path() - capturedQuery = uri.query() - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedPath == "/api/users" - capturedQuery == "id=123&filter=active" - } - - def "extracts scheme and port from X-Forwarded headers"() { - given: - def eventJson = ''' - { - "path": "/api/test", - "headers": { - "x-forwarded-proto": "http", - "x-forwarded-port": "8080" - }, - "requestContext": { - "httpMethod": "GET", - "requestId": "req-123" - } - } - ''' - def event = createInputStream(eventJson) - - def capturedScheme = null - def capturedPort = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedScheme = uri.scheme() - capturedPort = uri.port() - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedScheme == "http" - capturedPort == 8080 - } - - def "falls back to https/443 when X-Forwarded headers are absent"() { - given: - def eventJson = ''' - { - "path": "/api/test", - "headers": {}, - "requestContext": { - "httpMethod": "GET", - "requestId": "req-123" - } - } - ''' - def event = createInputStream(eventJson) - - def capturedScheme = null - def capturedPort = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedScheme = uri.scheme() - capturedPort = uri.port() - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedScheme == "https" - capturedPort == 443 - } - - def "handles invalid X-Forwarded-Port gracefully"() { - given: - def eventJson = ''' - { - "path": "/api/test", - "headers": { - "x-forwarded-proto": "https", - "x-forwarded-port": "not-a-number" - }, - "requestContext": { - "httpMethod": "GET", - "requestId": "req-123" - } - } - ''' - def event = createInputStream(eventJson) - - def capturedScheme = null - def capturedPort = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedScheme = uri.scheme() - capturedPort = uri.port() - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedScheme == "https" - capturedPort == 443 - } - - def "handles invalid base64 body gracefully"() { - given: - def eventJson = ''' - { - "body": "not-valid-base64", - "isBase64Encoded": true, - "requestContext": { - "httpMethod": "POST" - } - } - ''' - def event = createInputStream(eventJson) - - def capturedBody = "NOT_CALLED" - - setupMockCallbacks( - onBody: { body -> - capturedBody = body - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedBody == "NOT_CALLED" // Should not call body callback when decode fails - } - - def "handles base64 decoded empty string body"() { - given: - def base64Empty = Base64.getEncoder().encodeToString("".getBytes()) - def eventJson = """ - { - "body": "${base64Empty}", - "isBase64Encoded": true, - "requestContext": { - "httpMethod": "POST" - } - } - """ - def event = createInputStream(eventJson) - - def capturedBody = "NOT_CALLED" - - setupMockCallbacks( - onBody: { body -> - capturedBody = body - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedBody == "" // Should pass empty string after decoding - } - - def "handles body with special characters"() { - given: - def eventJson = ''' - { - "body": "{\\"text\\": \\"Hello δΈ–η•Œ 🌍\\"}", - "requestContext": { - "httpMethod": "POST" - } - } - ''' - def event = createInputStream(eventJson) - - def capturedBody = null - - setupMockCallbacks( - onBody: { body -> - capturedBody = body - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedBody instanceof Map - capturedBody.text == "Hello δΈ–η•Œ 🌍" - } - - // ============================================================================ - // Generic Data Extraction Tests - // ============================================================================ - - def "extracts data from unknown trigger type using generic extraction"() { - given: - def eventJson = ''' - { - "path": "/generic/path", - "httpMethod": "PATCH", - "headers": { - "x-custom-header": "generic-value" - }, - "unknownField": "should be ignored", - "requestContext": { - "identity": { - "sourceIp": "203.0.113.1" - } - } - } - ''' - def event = createInputStream(eventJson) - - def capturedMethod = null - def capturedPath = null - def capturedHeaders = [:] - def capturedSourceIp = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onHeader: { name, value -> - capturedHeaders[name] = value - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedMethod == "PATCH" - capturedPath == "/generic/path" - capturedHeaders["x-custom-header"] == "generic-value" - capturedSourceIp == "203.0.113.1" - } - - def "extracts data from unknown trigger with http in requestContext"() { - given: - def eventJson = ''' - { - "requestContext": { - "http": { - "method": "OPTIONS", - "path": "/options/path", - "sourceIp": "198.51.100.50" - } - } - } - ''' - def event = createInputStream(eventJson) - - def capturedMethod = null - def capturedPath = null - def capturedSourceIp = null - - setupMockCallbacks( - onMethodUri: { method, uri -> - capturedMethod = method - capturedPath = uri.path() - }, - onSocketAddress: { ip, port -> - capturedSourceIp = ip - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedMethod == "OPTIONS" - capturedPath == "/options/path" - capturedSourceIp == "198.51.100.50" - } - - def "handles cookies merging with existing cookie header"() { - given: - def eventJson = ''' - { - "headers": { - "cookie": "existing=value" - }, - "cookies": ["new=cookie1", "another=cookie2"], - "requestContext": { - "http": { - "method": "GET", - "path": "/" - } - } - } - ''' - def event = createInputStream(eventJson) - - def capturedHeaders = [:] - - setupMockCallbacks( - onHeader: { name, value -> - capturedHeaders[name] = value - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - capturedHeaders["cookie"] == "existing=value; new=cookie1; another=cookie2" - } - - def "handles empty cookies array correctly"() { - given: - def eventJson = ''' - { - "headers": { - "content-type": "application/json" - }, - "cookies": [], - "requestContext": { - "http": { - "method": "GET", - "path": "/" - } - } - } - ''' - def event = createInputStream(eventJson) - - def capturedHeaders = [:] - - setupMockCallbacks( - onHeader: { name, value -> - capturedHeaders[name] = value - } - ) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null - !capturedHeaders.containsKey("cookie") // Empty array should not add cookie header - } - - // ============================================================================ - // processRequestEnd Tests - // ============================================================================ - - def "processRequestEnd does nothing when span is null"() { - when: - LambdaAppSecHandler.processRequestEnd(null) - - then: - noExceptionThrown() - } - - def "processRequestEnd does nothing when AppSec is disabled"() { - given: - ActiveSubsystems.APPSEC_ACTIVE = false - def span = Mock(AgentSpan) - - when: - LambdaAppSecHandler.processRequestEnd(span) - - then: - 0 * span._ - } - - def "processRequestEnd does nothing when span has no RequestContext"() { - given: - def span = Mock(AgentSpan) { - getRequestContext() >> null - } - - when: - LambdaAppSecHandler.processRequestEnd(span) - - then: - noExceptionThrown() - } - - def "processRequestEnd invokes requestEnded callback with RequestContext"() { - given: - def mockAppSecContext = new Object() - def mockRequestContext = Mock(RequestContext) { - getData(RequestContextSlot.APPSEC) >> mockAppSecContext - } - def span = Mock(AgentSpan) { - getRequestContext() >> mockRequestContext - } - - def callbackInvoked = false - def capturedContext = null - def capturedSpan = null - - def mockRequestEndedCallback = Mock(BiFunction) { - apply(_ as RequestContext, _ as AgentSpan) >> { - RequestContext ctx, AgentSpan s -> - callbackInvoked = true - capturedContext = ctx - capturedSpan = s - return new Flow.ResultFlow<>(null) - } - } - - def mockCallbackProvider = Mock(CallbackProvider) { - getCallback(EVENTS.requestEnded()) >> mockRequestEndedCallback - } - - def mockTracer = Mock(AgentTracer.TracerAPI) { - getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider - } - - AgentTracer.forceRegister(mockTracer) - - when: - LambdaAppSecHandler.processRequestEnd(span) - - then: - callbackInvoked - capturedContext == mockRequestContext - capturedSpan == span - } - - def "processRequestEnd handles null requestEnded callback gracefully"() { - given: - def mockRequestContext = Mock(RequestContext) - def span = Mock(AgentSpan) { - getRequestContext() >> mockRequestContext - } - - def mockCallbackProvider = Mock(CallbackProvider) { - getCallback(EVENTS.requestEnded()) >> null - } - - def mockTracer = Mock(AgentTracer.TracerAPI) { - getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider - } - - AgentTracer.forceRegister(mockTracer) - - when: - LambdaAppSecHandler.processRequestEnd(span) - - then: - noExceptionThrown() // Should log warning but not throw - } - - // ============================================================================ - // mergeContexts Tests - // ============================================================================ - - def "mergeContexts returns null when both contexts are null"() { - when: - def result = LambdaAppSecHandler.mergeContexts(null, null) - - then: - result == null - } - - def "mergeContexts returns extensionContext when appSecContext is null"() { - given: - def extensionContext = Mock(TagContext) - - when: - def result = LambdaAppSecHandler.mergeContexts(extensionContext, null) - - then: - result == extensionContext - } - - def "mergeContexts returns appSecContext when extensionContext is null"() { - given: - def appSecContext = Mock(TagContext) - - when: - def result = LambdaAppSecHandler.mergeContexts(null, appSecContext) - - then: - result == appSecContext - } - - def "mergeContexts merges AppSec data into TagContext"() { - given: - def appSecData = new Object() - - // Create real TagContext instances since methods are final - def appSecContext = new TagContext() - appSecContext.withRequestContextDataAppSec(appSecData) - - def extensionContext = new TagContext() - - when: - def result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext) - - then: - result == extensionContext - result.getRequestContextDataAppSec() == appSecData - } - - def "mergeContexts returns extensionContext when appSecContext is not TagContext"() { - given: - def extensionContext = Mock(TagContext) - def appSecContext = Mock(AgentSpanContext) - - when: - def result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext) - - then: - result == extensionContext - } - - def "mergeContexts returns extensionContext when it is not TagContext"() { - given: - def extensionContext = Mock(AgentSpanContext) - def appSecContext = Mock(TagContext) - - when: - def result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext) - - then: - result == extensionContext - } - - // ============================================================================ - // Error Handling and Null Callback Tests - // ============================================================================ - - def "processRequestStart handles null requestStarted callback gracefully"() { - given: - def eventJson = '{"requestContext": {"httpMethod": "GET"}}' - def event = createInputStream(eventJson) - - def mockCallbackProvider = Mock(CallbackProvider) { - getCallback(EVENTS.requestStarted()) >> null - } - - def mockTracer = Mock(AgentTracer.TracerAPI) { - getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider - } - - AgentTracer.forceRegister(mockTracer) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result == null // Should return null when requestStarted callback is missing - } - - def "processRequestStart handles null methodUri callback gracefully"() { - given: - def eventJson = ''' - { - "path": "/test", - "requestContext": { - "httpMethod": "GET" - } - } - ''' - def event = createInputStream(eventJson) - - def mockAppSecContext = new Object() - - def mockRequestStartedCallback = Mock(Supplier) { - get() >> new Flow.ResultFlow<>(mockAppSecContext) - } - - def mockCallbackProvider = Mock(CallbackProvider) { - getCallback(EVENTS.requestStarted()) >> mockRequestStartedCallback - getCallback(EVENTS.requestMethodUriRaw()) >> null // Null callback - getCallback(EVENTS.requestHeader()) >> null - getCallback(EVENTS.requestClientSocketAddress()) >> null - getCallback(EVENTS.requestHeaderDone()) >> Mock(Function) { - apply(_ as RequestContext) >> new Flow.ResultFlow<>(null) - } - getCallback(EVENTS.requestPathParams()) >> null - getCallback(EVENTS.requestBodyProcessed()) >> null - } - - def mockTracer = Mock(AgentTracer.TracerAPI) { - getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider - } - - AgentTracer.forceRegister(mockTracer) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result != null // Should continue processing even if methodUri callback is null - result instanceof TagContext - } - - def "processRequestStart handles exception during JSON parsing"() { - given: - def invalidJson = '{this is not valid JSON at all' - def event = createInputStream(invalidJson) - - when: - def result = LambdaAppSecHandler.processRequestStart(event) - - then: - result == null // Should return null on parse error - } - - def "processRequestStart handles exception during stream reading"() { - given: - def mockStream = Mock(ByteArrayInputStream) { - available() >> { throw new IOException("Stream error") } - } - - when: - def result = LambdaAppSecHandler.processRequestStart(mockStream) - - then: - result == null // Should return null on IO error - } - - // ============================================================================ - // Helper Methods - // ============================================================================ - - private ByteArrayInputStream createInputStream(String json) { - return new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)) - } - - /** - * Set up mock callbacks to capture invocations and verify data extraction. - * This mocks the AgentTracer and callback provider to intercept gateway calls. - */ - private void setupMockCallbacks(Map callbacks) { - def mockAppSecContext = new Object() - - def mockRequestStartedCallback = Mock(Supplier) { - get() >> new Flow.ResultFlow<>(mockAppSecContext) - } - - def mockMethodUriCallback = callbacks.onMethodUri ? Mock(TriFunction) { - apply(_ as RequestContext, _ as String, _ as URIDataAdapter) >> { - RequestContext ctx, String method, URIDataAdapter uri -> - callbacks.onMethodUri(method, uri) - return new Flow.ResultFlow<>(null) - } - } : null - - def mockHeaderCallback = callbacks.onHeader ? Mock(TriConsumer) { - accept(_ as RequestContext, _ as String, _ as String) >> { - RequestContext ctx, String name, String value -> - callbacks.onHeader(name, value) - } - } : null - - def mockSocketAddressCallback = callbacks.onSocketAddress ? Mock(TriFunction) { - apply(_ as RequestContext, _ as String, _ as Integer) >> { - RequestContext ctx, String ip, Integer port -> - callbacks.onSocketAddress(ip, port) - return new Flow.ResultFlow<>(null) - } - } : null - - def mockHeaderDoneCallback = Mock(Function) { - apply(_ as RequestContext) >> new Flow.ResultFlow<>(null) - } - - def mockPathParamsCallback = callbacks.onPathParams ? Mock(BiFunction) { - apply(_ as RequestContext, _ as Map) >> { - RequestContext ctx, Map params -> - callbacks.onPathParams(params) - return new Flow.ResultFlow<>(null) - } - } : null - - def mockBodyCallback = callbacks.onBody ? Mock(BiFunction) { - apply(_ as RequestContext, _ as Object) >> { - RequestContext ctx, Object body -> - callbacks.onBody(body) - return new Flow.ResultFlow<>(null) - } - } : null - - def mockCallbackProvider = Mock(CallbackProvider) { - getCallback(EVENTS.requestStarted()) >> mockRequestStartedCallback - getCallback(EVENTS.requestMethodUriRaw()) >> mockMethodUriCallback - getCallback(EVENTS.requestHeader()) >> mockHeaderCallback - getCallback(EVENTS.requestClientSocketAddress()) >> mockSocketAddressCallback - getCallback(EVENTS.requestHeaderDone()) >> mockHeaderDoneCallback - getCallback(EVENTS.requestPathParams()) >> mockPathParamsCallback - getCallback(EVENTS.requestBodyProcessed()) >> mockBodyCallback - } - - def mockTracer = Mock(AgentTracer.TracerAPI) { - getCallbackProvider(RequestContextSlot.APPSEC) >> mockCallbackProvider - } - - // Install the mock tracer - AgentTracer.forceRegister(mockTracer) - } - - def cleanup() { - AgentTracer.forceRegister(originalTracer) - } -} diff --git a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaHandlerTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaHandlerTest.groovy deleted file mode 100644 index ab8b188eb8b..00000000000 --- a/dd-trace-core/src/test/groovy/datadog/trace/lambda/LambdaHandlerTest.groovy +++ /dev/null @@ -1,330 +0,0 @@ -package datadog.trace.lambda - -import datadog.trace.api.DDSpanId -import datadog.trace.api.DDTags -import datadog.trace.api.DDTraceId -import datadog.trace.core.CoreTracer -import datadog.trace.core.test.DDCoreSpecification -import datadog.trace.core.DDSpan -import com.amazonaws.services.lambda.runtime.events.SQSEvent -import com.amazonaws.services.lambda.runtime.events.SNSEvent -import com.amazonaws.services.lambda.runtime.events.S3Event -import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent -import com.amazonaws.services.lambda.runtime.events.models.s3.S3EventNotification - -import static datadog.trace.agent.test.server.http.TestHttpServer.httpServer - -class LambdaHandlerTest extends DDCoreSpecification { - - class TestObject { - - public String field1 - public boolean field2 - - TestObject() { - this.field1 = "toto" - this.field2 = true - } - - @Override - String toString() { - "$field1 / $field2}" - } - } - - def "test start invocation success"() { - given: - CoreTracer ct = tracerBuilder().build() - - def server = httpServer { - handlers { - post("/lambda/start-invocation") { - response - .status(200) - .addHeader("x-datadog-trace-id", "1234") - .addHeader("x-datadog-sampling-priority", "2") - .send() - } - } - } - LambdaHandler.setExtensionBaseUrl(server.address.toString()) - - when: - def objTest = LambdaHandler.notifyStartInvocation(obj, "lambda-request-123") - - then: - objTest.getTraceId().toString() == traceId - objTest.getSamplingPriority() == samplingPriority - server.lastRequest.headers.get("lambda-runtime-aws-request-id") == "lambda-request-123" - - cleanup: - server.close() - ct.close() - - where: - traceId | samplingPriority | obj - "1234" | 2 | new TestObject() - } - - def "test start invocation with 128 bit trace ID"() { - given: - CoreTracer ct = tracerBuilder().build() - - def server = httpServer { - handlers { - post("/lambda/start-invocation") { - response - .status(200) - .addHeader("x-datadog-trace-id", "5744042798732701615") - .addHeader("x-datadog-sampling-priority", "2") - .addHeader("x-datadog-tags", "_dd.p.tid=1914fe7789eb32be") - .send() - } - } - } - LambdaHandler.setExtensionBaseUrl(server.address.toString()) - - when: - def objTest = LambdaHandler.notifyStartInvocation(obj, "lambda-request-123") - - then: - objTest.getTraceId().toHexString() == traceId - objTest.getSamplingPriority() == samplingPriority - server.lastRequest.headers.get("lambda-runtime-aws-request-id") == "lambda-request-123" - - cleanup: - server.close() - ct.close() - - where: - traceId | samplingPriority | obj - "1914fe7789eb32be4fb6f07e011a6faf" | 2 | new TestObject() - } - - def "test start invocation failure"() { - given: - CoreTracer ct = tracerBuilder().build() - - def server = httpServer { - handlers { - post("/lambda/start-invocation") { - response - .status(500) - .send() - } - } - } - LambdaHandler.setExtensionBaseUrl(server.address.toString()) - - when: - def objTest = LambdaHandler.notifyStartInvocation(obj, "my-lambda-request") - - then: - objTest == expected - server.lastRequest.headers.get("lambda-runtime-aws-request-id") == "my-lambda-request" - - cleanup: - server.close() - ct.close() - - where: - expected | obj - null | new TestObject() - } - - def "test end invocation success"() { - given: - def server = httpServer { - handlers { - post("/lambda/end-invocation") { - response - .status(200) - .send() - } - } - } - LambdaHandler.setExtensionBaseUrl(server.address.toString()) - DDSpan span = Mock(DDSpan) { - getTraceId() >> DDTraceId.from("1234") - getSpanId() >> DDSpanId.from("5678") - getSamplingPriority() >> 2 - } - - when: - def result = LambdaHandler.notifyEndInvocation(span, lambdaResult, boolValue, lambdaReqIdHeaderValue) - - then: - server.lastRequest.headers.get("x-datadog-invocation-error") == eHeaderValue - server.lastRequest.headers.get("x-datadog-trace-id") == tIdHeaderValue - server.lastRequest.headers.get("x-datadog-span-id") == sIdHeaderValue - server.lastRequest.headers.get("x-datadog-sampling-priority") == sPIdHeaderValue - server.lastRequest.headers.get("lambda-runtime-aws-request-id") == lambdaReqIdHeaderValue - result == expected - - cleanup: - server.close() - - where: - expected | eHeaderValue | tIdHeaderValue | sIdHeaderValue | sPIdHeaderValue | lambdaResult | boolValue | lambdaReqIdHeaderValue - true | "true" | "1234" | "5678" | "2" | {} | true | "request123" - true | null | "1234" | "5678" | "2" | "12345" | false | "request456" - } - - def "test end invocation failure"() { - given: - def server = httpServer { - handlers { - post("/lambda/end-invocation") { - response - .status(500) - .send() - } - } - } - LambdaHandler.setExtensionBaseUrl(server.address.toString()) - DDSpan span = Mock(DDSpan) { - getTraceId() >> DDTraceId.from("1234") - getSpanId() >> DDSpanId.from("5678") - getSamplingPriority() >> 2 - } - - when: - def result = LambdaHandler.notifyEndInvocation(span, lambdaResult, boolValue, lambdaReqIdHeaderValue) - - then: - result == expected - server.lastRequest.headers.get("x-datadog-invocation-error") == headerValue - server.lastRequest.headers.get("lambda-runtime-aws-request-id") == lambdaReqIdHeaderValue - - cleanup: - server.close() - - where: - expected | headerValue | lambdaResult | boolValue | lambdaReqIdHeaderValue - false | "true" | {} | true | "request123" - false | null | "12345" | false | "request456" - } - - def "test end invocation success with error metadata"() { - given: - def server = httpServer { - handlers { - post("/lambda/end-invocation") { - response - .status(200) - .send() - } - } - } - LambdaHandler.setExtensionBaseUrl(server.address.toString()) - DDSpan span = Mock(DDSpan) { - getTraceId() >> DDTraceId.from("1234") - getSpanId() >> DDSpanId.from("5678") - getSamplingPriority() >> 2 - getTag(DDTags.ERROR_MSG) >> "custom error message" - getTag(DDTags.ERROR_TYPE) >> "java.lang.Throwable" - getTag(DDTags.ERROR_STACK) >> "errorStack\n \ttest" - } - - when: - LambdaHandler.notifyEndInvocation(span, {}, true, "lambda-request-123") - - then: - server.lastRequest.headers.get("x-datadog-invocation-error") == "true" - server.lastRequest.headers.get("x-datadog-invocation-error-msg") == "custom error message" - server.lastRequest.headers.get("x-datadog-invocation-error-type") == "java.lang.Throwable" - server.lastRequest.headers.get("x-datadog-invocation-error-stack") == "ZXJyb3JTdGFjawogCXRlc3Q=" - server.lastRequest.headers.get("lambda-runtime-aws-request-id") == "lambda-request-123" - - cleanup: - server.close() - } - - def "test moshi toJson SQSEvent"() { - given: - def myEvent = new SQSEvent() - List records = new ArrayList<>() - SQSEvent.SQSMessage message = new SQSEvent.SQSMessage() - message.setMessageId("myId") - message.setAwsRegion("myRegion") - records.add(message) - myEvent.setRecords(records) - - when: - def result = LambdaHandler.writeValueAsString(myEvent) - - then: - result == "{\"records\":[{\"awsRegion\":\"myRegion\",\"messageId\":\"myId\"}]}" - } - - def "test moshi toJson S3Event"() { - given: - List list = new ArrayList<>() - S3EventNotification.S3EventNotificationRecord item0 = new S3EventNotification.S3EventNotificationRecord( - "region", "eventName", "mySource", null, "3.4", - null, null, null, null) - list.add(item0) - def myEvent = new S3Event(list) - - when: - def result = LambdaHandler.writeValueAsString(myEvent) - - then: - result == "{\"records\":[{\"awsRegion\":\"region\",\"eventName\":\"eventName\",\"eventSource\":\"mySource\",\"eventVersion\":\"3.4\"}]}" - } - - def "test moshi toJson SNSEvent"() { - given: - def myEvent = new SNSEvent() - List records = new ArrayList<>() - SNSEvent.SNSRecord message = new SNSEvent.SNSRecord() - message.setEventSource("mySource") - message.setEventVersion("myVersion") - records.add(message) - myEvent.setRecords(records) - - when: - def result = LambdaHandler.writeValueAsString(myEvent) - - then: - result == "{\"records\":[{\"eventSource\":\"mySource\",\"eventVersion\":\"myVersion\"}]}" - } - - def "test moshi toJson APIGatewayProxyRequestEvent"() { - given: - def myEvent = new APIGatewayProxyRequestEvent() - myEvent.setBody("bababango") - myEvent.setHttpMethod("POST") - - when: - def result = LambdaHandler.writeValueAsString(myEvent) - - then: - result == "{\"body\":\"bababango\",\"httpMethod\":\"POST\"}" - } - - def "test moshi toJson InputStream"() { - given: - def body = "{\"body\":\"bababango\",\"httpMethod\":\"POST\"}" - def myEvent = new ByteArrayInputStream(body.getBytes()) - - when: - def result = LambdaHandler.writeValueAsString(myEvent) - - then: - result == body - } - - def "test moshi toJson OutputStream"() { - given: - def body = "{\"body\":\"bababango\",\"statusCode\":\"200\"}" - def myEvent = new ByteArrayOutputStream() - myEvent.write(body.getBytes(), 0, body.length()) - - when: - def result = LambdaHandler.writeValueAsString(myEvent) - - then: - result == body - } -} diff --git a/dd-trace-core/src/test/groovy/datadog/trace/lambda/SkipTypeJsonSerializerTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/lambda/SkipTypeJsonSerializerTest.groovy deleted file mode 100644 index ab1712bb308..00000000000 --- a/dd-trace-core/src/test/groovy/datadog/trace/lambda/SkipTypeJsonSerializerTest.groovy +++ /dev/null @@ -1,184 +0,0 @@ -package datadog.trace.lambda - -import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent -import com.amazonaws.services.lambda.runtime.events.SNSEvent -import com.amazonaws.services.lambda.runtime.events.SQSEvent -import datadog.trace.core.test.DDCoreSpecification -import com.squareup.moshi.Moshi - -abstract class AbstractSerialize { - public String randomString -} - -class SubClass extends AbstractSerialize { - SubClass() { - this.randomString = "tutu" - } -} - -class CustomRequest

extends LambdaRequest { - public P path - public B body -} -interface ApiRequestPath {} -class LambdaRequest { - public boolean testBool - public String emptyStr - public Map emptyHeaders -} - -class SkipUnhandledTypeJsonSerializerTest extends DDCoreSpecification { - - static class TestJsonObject { - - public String field1 - public boolean field2 - public AbstractSerialize field3 - public NestedJsonObject field4 - public ByteArrayInputStream field5 - - TestJsonObject() { - this.field1 = "toto" - this.field2 = true - this.field3 = new SubClass() - this.field4 = new NestedJsonObject() - this.field5 = new ByteArrayInputStream() - } - } - - static class NestedJsonObject { - - public AbstractSerialize field - - NestedJsonObject() { - this.field = new SubClass() - } - } - - def "test string serialization"() { - given: - def adapter = new Moshi.Builder() - .add(SkipUnsupportedTypeJsonAdapter.newFactory()) - .build() - .adapter(Object) - - when: - def result = adapter.toJson(new TestJsonObject()) - - then: - result == "{\"field1\":\"toto\",\"field2\":true,\"field3\":{},\"field4\":{\"field\":{}},\"field5\":{}}" - } - - def "test simple case"() { - given: - def adapter = new Moshi.Builder() - .add(SkipUnsupportedTypeJsonAdapter.newFactory()) - .build() - .adapter(Object) - - when: - def list = new LinkedHashMap() - list.put("key0","item0") - list.put("key1","item1") - list.put("key2","item2") - def result = adapter.toJson(list) - - then: - result == "{\"key0\":\"item0\",\"key1\":\"item1\",\"key2\":\"item2\"}" - } - - def "test SQS event "() { - given: - def adapter = new Moshi.Builder() - .add(SkipUnsupportedTypeJsonAdapter.newFactory()) - .build() - .adapter(Object) - - when: - def myEvent = new SQSEvent() - List records = new ArrayList<>() - SQSEvent.SQSMessage message = new SQSEvent.SQSMessage() - message.setMessageId("myId") - message.setAwsRegion("myRegion") - records.add(message) - myEvent.setRecords(records) - def result = adapter.toJson(myEvent) - - then: - result == "{\"records\":[{\"awsRegion\":\"myRegion\",\"messageId\":\"myId\"}]}" - } - - def "test SNS Event"() { - given: - def adapter = new Moshi.Builder() - .add(SkipUnsupportedTypeJsonAdapter.newFactory()) - .build() - .adapter(Object) - - when: - def myEvent = new SNSEvent() - List records = new ArrayList<>() - SNSEvent.SNSRecord message = new SNSEvent.SNSRecord() - message.setEventSource("mySource") - message.setEventVersion("myVersion") - records.add(message) - myEvent.setRecords(records) - def result = adapter.toJson(myEvent) - - then: - result == "{\"records\":[{\"eventSource\":\"mySource\",\"eventVersion\":\"myVersion\"}]}" - } - - def "test APIGatewayProxyRequest Event"() { - given: - def adapter = new Moshi.Builder() - .add(SkipUnsupportedTypeJsonAdapter.newFactory()) - .build() - .adapter(Object) - - when: - def myEvent = new APIGatewayProxyRequestEvent() - myEvent.setBody("bababango") - myEvent.setHttpMethod("POST") - def result = adapter.toJson(myEvent) - - then: - result == "{\"body\":\"bababango\",\"httpMethod\":\"POST\"}" - } - - def "test MapStringObject Event"() { - given: - def adapter = new Moshi.Builder() - .add(SkipUnsupportedTypeJsonAdapter.newFactory()) - .build() - .adapter(Object) - - when: - def myEvent = new HashMap() - def myNestedEvent = new HashMap() - myNestedEvent.put("nestedKey0", "nestedValue1") - myNestedEvent.put("nestedKey1", true) - myNestedEvent.put("nestedKey2", ["aaa", "bbb", "ccc", "dddd"]) - myEvent.put("firstKey", new TestJsonObject()) - myEvent.put("secondKey", myNestedEvent) - def result = adapter.toJson(myEvent) - - then: - result == "{\"firstKey\":{\"field1\":\"toto\",\"field2\":true,\"field3\":{},\"field4\":{\"field\":{}},\"field5\":{}},\"secondKey\":{\"nestedKey2\":[\"aaa\",\"bbb\",\"ccc\",\"dddd\"],\"nestedKey0\":\"nestedValue1\",\"nestedKey1\":true}}" - } - - def "test custom payload"() { - given: - def adapter = new Moshi.Builder() - .add(SkipUnsupportedTypeJsonAdapter.newFactory()) - .build() - .adapter(Object) - - when: - def customPayload = new CustomRequest() - def result = adapter.toJson(customPayload) - - then: - result == "{\"body\":{},\"path\":{},\"testBool\":false}" - } -} diff --git a/dd-trace-core/src/test/groovy/datadog/trace/llmobs/writer/ddintake/LLMObsSpanMapperTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/llmobs/writer/ddintake/LLMObsSpanMapperTest.groovy deleted file mode 100644 index a62c5315aff..00000000000 --- a/dd-trace-core/src/test/groovy/datadog/trace/llmobs/writer/ddintake/LLMObsSpanMapperTest.groovy +++ /dev/null @@ -1,372 +0,0 @@ -package datadog.trace.llmobs.writer.ddintake - -import com.fasterxml.jackson.databind.ObjectMapper -import datadog.communication.serialization.ByteBufferConsumer -import datadog.communication.serialization.FlushingBuffer -import datadog.communication.serialization.msgpack.MsgPackWriter -import datadog.trace.api.DDTags -import datadog.trace.api.llmobs.LLMObs -import datadog.trace.bootstrap.instrumentation.api.InternalSpanTypes -import datadog.trace.bootstrap.instrumentation.api.Tags -import datadog.trace.common.writer.ListWriter -import datadog.trace.core.test.DDCoreSpecification -import org.msgpack.jackson.dataformat.MessagePackFactory -import spock.lang.Shared - -import java.nio.ByteBuffer -import java.nio.channels.WritableByteChannel - -class LLMObsSpanMapperTest extends DDCoreSpecification { - - @Shared - ObjectMapper objectMapper = new ObjectMapper(new MessagePackFactory()) - - def "test LLMObsSpanMapper serialization"() { - setup: - def mapper = new LLMObsSpanMapper() - def tracer = tracerBuilder().writer(new ListWriter()).build() - - - // Create a real LLMObs span using the tracer - def llmSpan = tracer.buildSpan("datadog", "openai.request") - .withResourceName("createCompletion") - .withTag("_ml_obs_tag.span.kind", Tags.LLMOBS_LLM_SPAN_KIND) - .withTag("_ml_obs_tag.model_name", "gpt-4") - .withTag("_ml_obs_tag.model_provider", "openai") - .withTag("_ml_obs_metric.input_tokens", 50) - .withTag("_ml_obs_metric.output_tokens", 25) - .withTag("_ml_obs_metric.total_tokens", 75) - .withTag("_ml_obs_tag.session_id", "abc-123-session") - .start() - - llmSpan.setSpanType(InternalSpanTypes.LLMOBS) - - def toolCall = LLMObs.ToolCall.from("get_weather", "function_call", "call_123", [location: "San Francisco"]) - def toolResult = LLMObs.ToolResult.from("get_weather", "function_call_output", "call_123", '{"temperature":"72F"}') - def inputMessages = [ - LLMObs.LLMMessage.from("user", "Hello, what's the weather like?"), - LLMObs.LLMMessage.from("assistant", null, [toolCall], [toolResult]) - ] - def outputMessages = [LLMObs.LLMMessage.from("assistant", "I'll help you check the weather.")] - llmSpan.setTag("_ml_obs_tag.input", [ - messages: inputMessages, - prompt: [ - id: "prompt_123", - version: "1", - variables: [city: "San Francisco"], - chat_template: [[role: "user", content: "Hello, what's the weather like in {{city}}?"]] - ] - ]) - llmSpan.setTag("_ml_obs_tag.output", outputMessages) - llmSpan.setTag("_ml_obs_tag.metadata", [temperature: 0.7, max_tokens: 100]) - llmSpan.setTag("_ml_obs_tag.tool_definitions", [ - [ - name: "get_weather", - description: "Get weather by city", - schema: [type: "object", properties: [city: [type: "string"]]] - ] - ]) - llmSpan.setError(true) - llmSpan.setTag(DDTags.ERROR_MSG, "boom") - llmSpan.setTag(DDTags.ERROR_TYPE, "java.lang.IllegalStateException") - llmSpan.setTag(DDTags.ERROR_STACK, "stacktrace") - - llmSpan.finish() - - def trace = [llmSpan] - CapturingByteBufferConsumer sink = new CapturingByteBufferConsumer() - // Keep all formatted spans in a single flush for this assertion. - MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(16 * 1024, sink)) - - when: - packer.format(trace, mapper) - packer.flush() - - then: - sink.captured != null - def payload = mapper.newPayload() - payload.withBody(1, sink.captured) - - // Capture the size before the buffer is written and the body buffer is emptied. - def sizeInBytes = payload.sizeInBytes() - - def channel = new ByteArrayOutputStream() - payload.writeTo(new WritableByteChannel() { - @Override - int write(ByteBuffer src) throws IOException { - def bytes = new byte[src.remaining()] - src.get(bytes) - channel.write(bytes) - return bytes.length - } - - @Override - boolean isOpen() { - return true - } - - @Override - void close() throws IOException { } - }) - - def bytesWritten = channel.toByteArray() - sizeInBytes == bytesWritten.length - def result = objectMapper.readValue(bytesWritten, Map) - - then: - result.containsKey("event_type") - result["event_type"] == "span" - result.containsKey("_dd.stage") - result["_dd.stage"] == "raw" - result.containsKey("spans") - result["spans"] instanceof List - result["spans"].size() == 1 - - def spanData = result["spans"][0] - spanData["name"] == "OpenAI.createCompletion" - spanData.containsKey("span_id") - spanData.containsKey("trace_id") - spanData.containsKey("start_ns") - spanData.containsKey("duration") - spanData.containsKey("_dd") - spanData["_dd"]["span_id"] == spanData["span_id"] - spanData["_dd"]["trace_id"] == spanData["trace_id"] - spanData["_dd"]["apm_trace_id"] == spanData["trace_id"] - - // Top-level session_id field β€” what the LLM Trace Explorer's Sessions filter queries. - spanData.containsKey("session_id") - spanData["session_id"] == "abc-123-session" - - spanData.containsKey("meta") - spanData["meta"]["span.kind"] == "llm" - spanData["meta"].containsKey("error") - spanData["meta"]["error"]["message"] == "boom" - spanData["meta"]["error"]["type"] == "java.lang.IllegalStateException" - spanData["meta"]["error"]["stack"] == "stacktrace" - spanData["meta"].containsKey("input") - spanData["meta"]["input"].containsKey("messages") - spanData["meta"]["input"]["messages"][0].containsKey("content") - spanData["meta"]["input"]["messages"][0]["content"] == "Hello, what's the weather like?" - spanData["meta"]["input"]["messages"][0].containsKey("role") - spanData["meta"]["input"]["messages"][0]["role"] == "user" - spanData["meta"]["input"]["messages"][1]["role"] == "assistant" - !spanData["meta"]["input"]["messages"][1].containsKey("content") - spanData["meta"]["input"]["messages"][1]["tool_calls"][0]["name"] == "get_weather" - spanData["meta"]["input"]["messages"][1]["tool_calls"][0]["type"] == "function_call" - spanData["meta"]["input"]["messages"][1]["tool_calls"][0]["tool_id"] == "call_123" - spanData["meta"]["input"]["messages"][1]["tool_calls"][0]["arguments"] == [location: "San Francisco"] - spanData["meta"]["input"]["messages"][1]["tool_results"][0]["name"] == "get_weather" - spanData["meta"]["input"]["messages"][1]["tool_results"][0]["type"] == "function_call_output" - spanData["meta"]["input"]["messages"][1]["tool_results"][0]["tool_id"] == "call_123" - spanData["meta"]["input"]["messages"][1]["tool_results"][0]["result"] == '{"temperature":"72F"}' - spanData["meta"]["input"]["prompt"]["id"] == "prompt_123" - spanData["meta"]["input"]["prompt"]["version"] == "1" - spanData["meta"]["input"]["prompt"]["variables"] == [city: "San Francisco"] - spanData["meta"]["input"]["prompt"]["chat_template"] == [[role: "user", content: "Hello, what's the weather like in {{city}}?"]] - spanData["meta"].containsKey("output") - spanData["meta"]["output"].containsKey("messages") - spanData["meta"]["output"]["messages"][0].containsKey("content") - spanData["meta"]["output"]["messages"][0]["content"] == "I'll help you check the weather." - spanData["meta"]["output"]["messages"][0].containsKey("role") - spanData["meta"]["output"]["messages"][0]["role"] == "assistant" - spanData["meta"]["tool_definitions"][0]["name"] == "get_weather" - spanData["meta"]["tool_definitions"][0]["description"] == "Get weather by city" - spanData["meta"]["tool_definitions"][0]["schema"] == [type: "object", properties: [city: [type: "string"]]] - spanData["meta"].containsKey("metadata") - - spanData.containsKey("metrics") - spanData["metrics"]["input_tokens"] == 50.0 - spanData["metrics"]["output_tokens"] == 25.0 - spanData["metrics"]["total_tokens"] == 75.0 - - spanData.containsKey("tags") - spanData["tags"].contains("language:jvm") - spanData["tags"].contains("session_id:abc-123-session") - } - - def "test LLMObsSpanMapper writes no spans when none are LLMObs spans"() { - setup: - def mapper = new LLMObsSpanMapper() - def tracer = tracerBuilder().writer(new ListWriter()).build() - - def regularSpan1 = tracer.buildSpan("datadog", "http.request") - .withResourceName("GET /api/users") - .withTag("http.method", "GET") - .withTag("http.url", "https://example.com/api/users") - .start() - regularSpan1.finish() - - def regularSpan2 = tracer.buildSpan("datadog", "database.query") - .withResourceName("SELECT * FROM users") - .withTag("db.type", "postgresql") - .start() - regularSpan2.finish() - - def trace = [regularSpan1, regularSpan2] - CapturingByteBufferConsumer sink = new CapturingByteBufferConsumer() - // Keep all formatted spans in a single flush for this assertion. - MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(16 * 1024, sink)) - - when: - packer.format(trace, mapper) - packer.flush() - - then: - sink.captured == null - } - - def "test consecutive packer.format calls accumulate spans from multiple traces"() { - setup: - def mapper = new LLMObsSpanMapper() - def tracer = tracerBuilder().writer(new ListWriter()).build() - - // First trace with 2 LLMObs spans - def llmSpan1 = tracer.buildSpan("datadog", "chat-completion-1") - .withTag("_ml_obs_tag.span.kind", Tags.LLMOBS_LLM_SPAN_KIND) - .withTag("_ml_obs_tag.model_name", "gpt-4") - .withTag("_ml_obs_tag.model_provider", "openai") - .start() - llmSpan1.setSpanType(InternalSpanTypes.LLMOBS) - llmSpan1.finish() - - def llmSpan2 = tracer.buildSpan("datadog", "chat-completion-2") - .withTag("_ml_obs_tag.span.kind", Tags.LLMOBS_LLM_SPAN_KIND) - .withTag("_ml_obs_tag.model_name", "gpt-3.5") - .withTag("_ml_obs_tag.model_provider", "openai") - .start() - llmSpan2.setSpanType(InternalSpanTypes.LLMOBS) - llmSpan2.finish() - - // Second trace with 1 LLMObs span - def llmSpan3 = tracer.buildSpan("datadog", "chat-completion-3") - .withTag("_ml_obs_tag.span.kind", Tags.LLMOBS_LLM_SPAN_KIND) - .withTag("_ml_obs_tag.model_name", "claude-3") - .withTag("_ml_obs_tag.model_provider", "anthropic") - .start() - llmSpan3.setSpanType(InternalSpanTypes.LLMOBS) - llmSpan3.finish() - - def trace1 = [llmSpan1, llmSpan2] - def trace2 = [llmSpan3] - CapturingByteBufferConsumer sink = new CapturingByteBufferConsumer() - // Keep all formatted spans in a single flush for this assertion. - MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(16 * 1024, sink)) - - when: - packer.format(trace1, mapper) - packer.format(trace2, mapper) - packer.flush() - - then: - sink.captured != null - def payload = mapper.newPayload() - payload.withBody(3, sink.captured) - - // Capture the size before the buffer is written and the body buffer is emptied. - def sizeInBytes = payload.sizeInBytes() - - def channel = new ByteArrayOutputStream() - payload.writeTo(new WritableByteChannel() { - @Override - int write(ByteBuffer src) throws IOException { - def bytes = new byte[src.remaining()] - src.get(bytes) - channel.write(bytes) - return bytes.length - } - - @Override - boolean isOpen() { - return true - } - - @Override - void close() throws IOException { } - }) - - def bytesWritten = channel.toByteArray() - sizeInBytes == bytesWritten.length - def result = objectMapper.readValue(bytesWritten, Map) - - then: - result.containsKey("event_type") - result["event_type"] == "span" - result.containsKey("_dd.stage") - result["_dd.stage"] == "raw" - result.containsKey("spans") - result["spans"] instanceof List - result["spans"].size() == 3 - - def spanNames = result["spans"].collect { it["name"] } - spanNames.contains("chat-completion-1") - spanNames.contains("chat-completion-2") - spanNames.contains("chat-completion-3") - } - - def "test LLMObsSpanMapper omits top-level session_id when not set"() { - setup: - def mapper = new LLMObsSpanMapper() - def tracer = tracerBuilder().writer(new ListWriter()).build() - - def llmSpan = tracer.buildSpan("datadog", "openai.request") - .withResourceName("createCompletion") - .withTag("_ml_obs_tag.span.kind", Tags.LLMOBS_LLM_SPAN_KIND) - .withTag("_ml_obs_tag.model_name", "gpt-4") - .withTag("_ml_obs_tag.model_provider", "openai") - .start() - llmSpan.setSpanType(InternalSpanTypes.LLMOBS) - llmSpan.finish() - - def trace = [llmSpan] - CapturingByteBufferConsumer sink = new CapturingByteBufferConsumer() - MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(16 * 1024, sink)) - - when: - packer.format(trace, mapper) - packer.flush() - - then: - sink.captured != null - def payload = mapper.newPayload() - payload.withBody(1, sink.captured) - - def channel = new ByteArrayOutputStream() - payload.writeTo(new WritableByteChannel() { - @Override - int write(ByteBuffer src) throws IOException { - def bytes = new byte[src.remaining()] - src.get(bytes) - channel.write(bytes) - return bytes.length - } - - @Override - boolean isOpen() { - return true - } - - @Override - void close() throws IOException { } - }) - - def result = objectMapper.readValue(channel.toByteArray(), Map) - def spanData = result["spans"][0] - - then: - // No top-level session_id field when the tag was never set. - !spanData.containsKey("session_id") - - // And no session_id entry leaks into tags[] either. - spanData["tags"].every { !it.startsWith("session_id:") } - } - - static class CapturingByteBufferConsumer implements ByteBufferConsumer { - - ByteBuffer captured - - @Override - void accept(int messageCount, ByteBuffer buffer) { - captured = buffer - } - } -} diff --git a/dd-trace-core/src/test/java/datadog/trace/TracerConnectionReliabilityTest.java b/dd-trace-core/src/test/java/datadog/trace/TracerConnectionReliabilityTest.java new file mode 100644 index 00000000000..eecd93bb2b5 --- /dev/null +++ b/dd-trace-core/src/test/java/datadog/trace/TracerConnectionReliabilityTest.java @@ -0,0 +1,186 @@ +package datadog.trace; + +import static datadog.trace.api.ConfigDefaults.DEFAULT_TRACE_AGENT_PORT; +import static datadog.trace.api.ProtocolVersion.V0_4; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + +import com.squareup.moshi.JsonAdapter; +import com.squareup.moshi.Moshi; +import com.squareup.moshi.Types; +import datadog.communication.ddagent.DDAgentFeaturesDiscovery; +import datadog.communication.ddagent.SharedCommunicationObjects; +import datadog.metrics.api.Monitoring; +import datadog.trace.agent.test.utils.PortUtils; +import datadog.trace.api.IdGenerationStrategy; +import datadog.trace.core.CoreTracer; +import datadog.trace.test.util.DDJavaSpecification; +import java.io.IOException; +import java.lang.reflect.Type; +import java.util.List; +import java.util.Properties; +import okhttp3.HttpUrl; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.FixedHostPortGenericContainer; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +public class TracerConnectionReliabilityTest extends DDJavaSpecification { + + static final int FEATURES_DISCOVERY_MIN_DELAY = 10; + + static OkHttpClient client; + static JsonAdapter>> traceJsonAdapter; + + int agentContainerPort; + CoreTracer tracer; + + @BeforeAll + static void setupSpec() { + client = new OkHttpClient(); + // Create body parser for /test/traces route + Moshi moshi = new Moshi.Builder().build(); + Type type = + Types.newParameterizedType( + List.class, Types.newParameterizedType(List.class, SentTraces.class)); + traceJsonAdapter = moshi.adapter(type); + } + + @BeforeEach + void setup() { + // Pick a random port for the test agent + agentContainerPort = PortUtils.randomOpenPort(); + // Build a tracer talking to the test agent (with the right port and traces endpoint) + Properties properties = new Properties(); + properties.put("trace.agent.port", Integer.toString(agentContainerPort)); + SharedCommunicationObjects sharedCommunicationObjects = new SharedCommunicationObjects(); + sharedCommunicationObjects.agentUrl = HttpUrl.get("http://localhost:" + agentContainerPort); + sharedCommunicationObjects.agentHttpClient = client; + FixedTraceEndpointFeaturesDiscovery fixedFeaturesDiscovery = + new FixedTraceEndpointFeaturesDiscovery(sharedCommunicationObjects); + sharedCommunicationObjects.setFeaturesDiscovery(fixedFeaturesDiscovery); + + tracer = + CoreTracer.builder() + .idGenerationStrategy(IdGenerationStrategy.fromName("SEQUENTIAL")) + .withProperties(properties) + .sharedCommunicationObjects(sharedCommunicationObjects) + .build(); + } + + @AfterEach + void cleanup() { + if (tracer != null) { + tracer.close(); + } + } + + @Test + void testLateAgentStart() throws Exception { + createSpans(10, 100); + tracer.flush(); + + GenericContainer agentContainer = startTestAgentContainer(); + int noAgentCount = getTraceCount(agentContainer); + waitForDiscoveryTimeout(); + + createSpans(20, 100); + tracer.flush(); + int withAgentCount = getTraceCount(agentContainer); + agentContainer.stop(); + + assertFalse(agentContainer.isRunning()); + assertEquals(0, noAgentCount); + assertEquals(20, withAgentCount); + } + + @Test + void testAgentRestart() throws Exception { + GenericContainer agentContainer = startTestAgentContainer(); + + createSpans(10, 100); + tracer.flush(); + int withAgentCount = getTraceCount(agentContainer); + + assertEquals(10, withAgentCount); + + agentContainer.stop(); + createSpans(10, 100); + tracer.flush(); + + waitForDiscoveryTimeout(); + agentContainer = startTestAgentContainer(); + int noTraceCount = getTraceCount(agentContainer); + createSpans(10, 100); + tracer.flush(); + withAgentCount = getTraceCount(agentContainer); + agentContainer.stop(); + + assertFalse(agentContainer.isRunning()); + assertEquals(0, noTraceCount); + assertEquals(10, withAgentCount); + } + + @SuppressWarnings({"deprecation", "rawtypes", "unchecked"}) + GenericContainer startTestAgentContainer() { + //noinspection GrDeprecatedAPIUsage Use FixedHostPortGenericContainer against deprecation + // because we need to know the exposed to configure the tracer at start + GenericContainer agentContainer = + new FixedHostPortGenericContainer( + "registry.ddbuild.io/images/mirror/dd-apm-test-agent/ddapm-test-agent:v1.44.0") + .withFixedExposedPort(agentContainerPort, DEFAULT_TRACE_AGENT_PORT) + .withEnv( + "ENABLED_CHECKS", + "trace_count_header,meta_tracer_version_header,trace_content_length") + .waitingFor(Wait.forHttp("/test/traces")); + agentContainer.start(); + return agentContainer; + } + + void createSpans(int count, int delay) throws InterruptedException { + for (int index = 1; index <= count; index++) { + datadog.trace.bootstrap.instrumentation.api.AgentSpan span = + tracer.buildSpan("datadog", "operation-" + index).start(); + Thread.sleep(delay); + span.finish(); + } + } + + static void waitForDiscoveryTimeout() throws InterruptedException { + Thread.sleep((long) (FEATURES_DISCOVERY_MIN_DELAY * 1.5)); + } + + int getTraceCount(GenericContainer agentContainer) throws IOException { + Request request = + new Request.Builder() + .url("http://" + agentContainer.getHost() + ":" + agentContainerPort + "/test/traces") + .build(); + String body = client.newCall(request).execute().body().string(); + return traceJsonAdapter.fromJson(body).size(); + } + + static class FixedTraceEndpointFeaturesDiscovery extends DDAgentFeaturesDiscovery { + FixedTraceEndpointFeaturesDiscovery(SharedCommunicationObjects objects) { + super(objects.agentHttpClient, Monitoring.DISABLED, objects.agentUrl, V0_4, false, false); + } + + @Override + public String getTraceEndpoint() { + return V04_ENDPOINT; + } + + @Override + protected long getFeaturesDiscoveryMinDelayMillis() { + return FEATURES_DISCOVERY_MIN_DELAY; + } + } + + static class SentTraces { + String name; + } +} diff --git a/dd-trace-core/src/test/java/datadog/trace/api/writer/PrintingWriterTest.java b/dd-trace-core/src/test/java/datadog/trace/api/writer/PrintingWriterTest.java new file mode 100644 index 00000000000..8289bfe35a0 --- /dev/null +++ b/dd-trace-core/src/test/java/datadog/trace/api/writer/PrintingWriterTest.java @@ -0,0 +1,145 @@ +package datadog.trace.api.writer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + +import com.squareup.moshi.JsonAdapter; +import com.squareup.moshi.Moshi; +import com.squareup.moshi.Types; +import datadog.trace.bootstrap.instrumentation.api.AgentTracer; +import datadog.trace.common.writer.ListWriter; +import datadog.trace.common.writer.PrintingWriter; +import datadog.trace.core.CoreTracer; +import datadog.trace.core.DDCoreJavaSpecification; +import datadog.trace.core.DDSpan; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import okio.Buffer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +@SuppressWarnings({"unchecked", "rawtypes"}) +public class PrintingWriterTest extends DDCoreJavaSpecification { + + private CoreTracer tracer; + private List sampleTrace; + private List secondTrace; + private JsonAdapter adapter; + + @BeforeEach + void setup() { + tracer = tracerBuilder().writer(new ListWriter()).build(); + adapter = + new Moshi.Builder() + .build() + .adapter( + Types.newParameterizedType( + Map.class, + String.class, + Types.newParameterizedType( + List.class, Types.newParameterizedType(List.class, Map.class)))); + + AgentTracer.SpanBuilder builder = + tracer + .buildSpan("datadog", "fakeOperation") + .withServiceName("fakeService") + .withResourceName("fakeResource") + .withSpanType("fakeType"); + + sampleTrace = Arrays.asList((DDSpan) builder.start(), (DDSpan) builder.start()); + secondTrace = Collections.singletonList((DDSpan) builder.start()); + } + + @AfterEach + void cleanup() { + if (tracer != null) { + tracer.close(); + } + } + + @Test + void testPrintingRegularIds() throws Exception { + Buffer buffer = new Buffer(); + PrintingWriter writer = new PrintingWriter(buffer.outputStream(), false); + + writer.write(sampleTrace); + Map>> result = + (Map>>) adapter.fromJson(buffer.readString(StandardCharsets.UTF_8)); + + assertEquals(sampleTrace.size(), result.get("traces").get(0).size()); + for (Map span : result.get("traces").get(0)) { + assertRegularSpanFields(span, false); + } + + writer.write(secondTrace); + result = + (Map>>) adapter.fromJson(buffer.readString(StandardCharsets.UTF_8)); + + assertEquals(secondTrace.size(), result.get("traces").get(0).size()); + for (Map span : result.get("traces").get(0)) { + assertRegularSpanFields(span, false); + } + } + + @Test + void testPrintingRegularHexIds() throws Exception { + Buffer buffer = new Buffer(); + PrintingWriter writer = new PrintingWriter(buffer.outputStream(), true); + + writer.write(sampleTrace); + Map>> result = + (Map>>) adapter.fromJson(buffer.readString(StandardCharsets.UTF_8)); + + assertEquals(sampleTrace.size(), result.get("traces").get(0).size()); + for (Map span : result.get("traces").get(0)) { + assertRegularSpanFields(span, true); + } + } + + @Test + void testPrintingMultipleTraces() throws Exception { + Buffer buffer = new Buffer(); + PrintingWriter writer = new PrintingWriter(buffer.outputStream(), false); + + writer.write(sampleTrace); + writer.write(secondTrace); + Map>> result1 = + (Map>>) adapter.fromJson(buffer.readUtf8Line()); + Map>> result2 = + (Map>>) adapter.fromJson(buffer.readUtf8Line()); + + assertEquals(sampleTrace.size(), result1.get("traces").get(0).size()); + for (Map span : result2.get("traces").get(0)) { + assertRegularSpanFields(span, false); + } + assertEquals(secondTrace.size(), result2.get("traces").get(0).size()); + for (Map span : result2.get("traces").get(0)) { + assertRegularSpanFields(span, false); + } + } + + private void assertRegularSpanFields(Map span, boolean hexIds) { + assertEquals("fakeService", span.get("service")); + assertEquals("fakeOperation", span.get("name")); + assertEquals("fakeResource", span.get("resource")); + assertEquals("fakeType", span.get("type")); + if (hexIds) { + assertInstanceOf(String.class, span.get("trace_id")); + assertInstanceOf(String.class, span.get("span_id")); + assertInstanceOf(String.class, span.get("parent_id")); + } else { + assertInstanceOf(Number.class, span.get("trace_id")); + assertInstanceOf(Number.class, span.get("span_id")); + assertInstanceOf(Number.class, span.get("parent_id")); + } + assertInstanceOf(Number.class, span.get("start")); + assertInstanceOf(Number.class, span.get("duration")); + assertEquals(0, ((Number) span.get("error")).intValue()); + assertInstanceOf(Map.class, span.get("metrics")); + assertInstanceOf(Map.class, span.get("meta")); + } +} diff --git a/dd-trace-core/src/test/java/datadog/trace/civisibility/interceptor/CiVisibilityApmProtocolInterceptorTest.java b/dd-trace-core/src/test/java/datadog/trace/civisibility/interceptor/CiVisibilityApmProtocolInterceptorTest.java new file mode 100644 index 00000000000..1ffc71e7ed3 --- /dev/null +++ b/dd-trace-core/src/test/java/datadog/trace/civisibility/interceptor/CiVisibilityApmProtocolInterceptorTest.java @@ -0,0 +1,89 @@ +package datadog.trace.civisibility.interceptor; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +import datadog.trace.api.DDSpanTypes; +import datadog.trace.bootstrap.instrumentation.api.Tags; +import datadog.trace.common.writer.ListWriter; +import datadog.trace.core.CoreTracer; +import datadog.trace.core.DDCoreJavaSpecification; +import datadog.trace.core.DDSpan; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +@Timeout(value = 10, unit = TimeUnit.SECONDS) +public class CiVisibilityApmProtocolInterceptorTest extends DDCoreJavaSpecification { + + private ListWriter writer; + private CoreTracer tracer; + + @BeforeEach + void setup() { + writer = new ListWriter(); + tracer = tracerBuilder().writer(writer).build(); + } + + @AfterEach + void cleanup() { + if (tracer != null) { + tracer.close(); + } + } + + @Test + void testSuiteAndTestModuleSpansAreFilteredOut() throws InterruptedException, TimeoutException { + tracer.addTraceInterceptor(CiVisibilityApmProtocolInterceptor.INSTANCE); + + tracer + .buildSpan("datadog", "test-module") + .withSpanType(DDSpanTypes.TEST_MODULE_END) + .start() + .finish(); + tracer + .buildSpan("datadog", "test-suite") + .withSpanType(DDSpanTypes.TEST_SUITE_END) + .start() + .finish(); + tracer.buildSpan("datadog", "test").withSpanType(DDSpanTypes.TEST).start().finish(); + + writer.waitForTraces(1); + + List trace = writer.firstTrace(); + assertEquals(1, trace.size()); + + DDSpan span = trace.get(0); + assertEquals("test", span.getOperationName().toString()); + } + + @Test + void testSessionTestModuleAndTestSuiteIdsAreNullified() + throws InterruptedException, TimeoutException { + tracer.addTraceInterceptor(CiVisibilityApmProtocolInterceptor.INSTANCE); + + DDSpan testSpan = + (DDSpan) tracer.buildSpan("datadog", "test").withSpanType(DDSpanTypes.TEST).start(); + testSpan.setTag(Tags.TEST_SESSION_ID, "session ID"); + testSpan.setTag(Tags.TEST_MODULE_ID, "module ID"); + testSpan.setTag(Tags.TEST_SUITE_ID, "suite ID"); + testSpan.setTag("random tag", "random value"); + testSpan.finish(); + + writer.waitForTraces(1); + + List trace = writer.firstTrace(); + assertEquals(1, trace.size()); + + DDSpan span = trace.get(0); + + assertNull(span.getTag(Tags.TEST_SESSION_ID)); + assertNull(span.getTag(Tags.TEST_MODULE_ID)); + assertNull(span.getTag(Tags.TEST_SUITE_ID)); + assertEquals("random value", span.getTag("random tag")); + } +} diff --git a/dd-trace-core/src/test/java/datadog/trace/civisibility/interceptor/CiVisibilityTraceInterceptorTest.java b/dd-trace-core/src/test/java/datadog/trace/civisibility/interceptor/CiVisibilityTraceInterceptorTest.java new file mode 100644 index 00000000000..64d58898179 --- /dev/null +++ b/dd-trace-core/src/test/java/datadog/trace/civisibility/interceptor/CiVisibilityTraceInterceptorTest.java @@ -0,0 +1,85 @@ +package datadog.trace.civisibility.interceptor; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import datadog.trace.api.DDTags; +import datadog.trace.api.civisibility.CIConstants; +import datadog.trace.common.writer.ListWriter; +import datadog.trace.core.CoreTracer; +import datadog.trace.core.DDCoreJavaSpecification; +import datadog.trace.core.DDSpan; +import datadog.trace.junit.utils.tabletest.DDSpanTypesConverter; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.converter.ConvertWith; +import org.tabletest.junit.TableTest; + +@Timeout(value = 10, unit = TimeUnit.SECONDS) +public class CiVisibilityTraceInterceptorTest extends DDCoreJavaSpecification { + + private ListWriter writer; + private CoreTracer tracer; + + @BeforeEach + void setup() { + writer = new ListWriter(); + tracer = tracerBuilder().writer(writer).build(); + } + + @AfterEach + void cleanup() { + if (tracer != null) { + tracer.close(); + } + } + + @Test + void discardATraceThatDoesNotComeFromCiApp() { + tracer.addTraceInterceptor(CiVisibilityTraceInterceptor.INSTANCE); + tracer.buildSpan("datadog", "sample-span").start().finish(); + + assertEquals(0, writer.size()); + } + + @Test + void doNotDiscardATraceThatComesFromCiApp() { + tracer.addTraceInterceptor(CiVisibilityTraceInterceptor.INSTANCE); + + DDSpan span = (DDSpan) tracer.buildSpan("datadog", "sample-span").start(); + span.context().setOrigin(CIConstants.CIAPP_TEST_ORIGIN); + span.finish(); + + // expect: + assertEquals(1, writer.size()); + } + + @TableTest({ + "scenario | spanType ", + "test | DDSpanTypes.TEST ", + "test suite end | DDSpanTypes.TEST_SUITE_END ", + "test module end | DDSpanTypes.TEST_MODULE_END ", + "test session end | DDSpanTypes.TEST_SESSION_END" + }) + void addTracerVersionToSpansOfType(@ConvertWith(DDSpanTypesConverter.class) String spanType) + throws InterruptedException, TimeoutException { + tracer.addTraceInterceptor(CiVisibilityTraceInterceptor.INSTANCE); + + DDSpan span = + (DDSpan) tracer.buildSpan("datadog", "sample-span").withSpanType(spanType).start(); + span.context().setOrigin(CIConstants.CIAPP_TEST_ORIGIN); + span.finish(); + writer.waitForTraces(1); + + List trace = writer.firstTrace(); + assertEquals(1, trace.size()); + + DDSpan receivedSpan = trace.get(0); + assertNotNull(receivedSpan.getTag(DDTags.LIBRARY_VERSION_TAG_KEY)); + } +} diff --git a/dd-trace-core/src/test/java/datadog/trace/civisibility/writer/ddintake/CiTestCovMapperV2Test.java b/dd-trace-core/src/test/java/datadog/trace/civisibility/writer/ddintake/CiTestCovMapperV2Test.java new file mode 100644 index 00000000000..b6c42fc9388 --- /dev/null +++ b/dd-trace-core/src/test/java/datadog/trace/civisibility/writer/ddintake/CiTestCovMapperV2Test.java @@ -0,0 +1,340 @@ +package datadog.trace.civisibility.writer.ddintake; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.fasterxml.jackson.databind.ObjectMapper; +import datadog.communication.serialization.GrowableBuffer; +import datadog.communication.serialization.msgpack.MsgPackWriter; +import datadog.trace.api.DDTraceId; +import datadog.trace.api.civisibility.coverage.CoverageProbes; +import datadog.trace.api.civisibility.coverage.CoverageStore; +import datadog.trace.api.civisibility.coverage.NoOpProbes; +import datadog.trace.api.civisibility.coverage.TestReport; +import datadog.trace.api.civisibility.coverage.TestReportFileEntry; +import datadog.trace.api.civisibility.domain.TestContext; +import datadog.trace.api.sampling.PrioritySampling; +import datadog.trace.bootstrap.instrumentation.api.InternalSpanTypes; +import datadog.trace.core.DDCoreJavaSpecification; +import datadog.trace.core.DDSpan; +import datadog.trace.core.propagation.PropagationTags; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.BitSet; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.msgpack.jackson.dataformat.MessagePackFactory; + +@SuppressWarnings("unchecked") +public class CiTestCovMapperV2Test extends DDCoreJavaSpecification { + + private static final ObjectMapper objectMapper = new ObjectMapper(new MessagePackFactory()); + + @Test + void testWritesMessage() throws Exception { + List trace = + givenTrace( + new TestReport( + DDTraceId.from(1), + 2L, + 3L, + Arrays.asList( + new TestReportFileEntry("source", BitSet.valueOf(new long[] {3, 5, 8}))))); + + Map message = getMappedMessage(trace); + + List> coverages = assertVersionAndGetCoverages(message, 2); + assertEquals(1, coverages.size()); + + Map coverage = coverages.get(0); + assertCoverage(coverage, 1, 2, 3); + + List> files = (List>) coverage.get("files"); + assertEquals(1, files.size()); + assertFile( + files.get(0), "source", new byte[] {3, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 8}); + } + + @Test + void testWritesMessageWithMultipleFilesAndMultipleLines() throws Exception { + List trace = + givenTrace( + new TestReport( + DDTraceId.from(1), + 2L, + 3L, + Arrays.asList( + new TestReportFileEntry("sourceA", BitSet.valueOf(new long[] {3, 5, 8})), + new TestReportFileEntry("sourceB", BitSet.valueOf(new long[] {1, 255, 7}))))); + + Map message = getMappedMessage(trace); + + List> coverages = assertVersionAndGetCoverages(message, 2); + assertEquals(1, coverages.size()); + + Map coverage = coverages.get(0); + assertCoverage(coverage, 1, 2, 3); + + List> files = (List>) coverage.get("files"); + assertEquals(2, files.size()); + assertFile( + files.get(0), "sourceA", new byte[] {3, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 8}); + assertFile( + files.get(1), "sourceB", new byte[] {1, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 7}); + } + + @Test + void testWritesMessageWithMultipleReports() throws Exception { + List trace = + givenTrace( + new TestReport( + DDTraceId.from(1), + 2L, + 3L, + Arrays.asList( + new TestReportFileEntry("sourceA", BitSet.valueOf(new long[] {2, 17, 41})))), + new TestReport( + DDTraceId.from(1), + 2L, + 4L, + Arrays.asList( + new TestReportFileEntry("sourceB", BitSet.valueOf(new long[] {11, 13, 55}))))); + + Map message = getMappedMessage(trace); + + List> coverages = assertVersionAndGetCoverages(message, 2); + assertEquals(2, coverages.size()); + + Map coverage0 = coverages.get(0); + assertCoverage(coverage0, 1, 2, 3); + List> files0 = (List>) coverage0.get("files"); + assertEquals(1, files0.size()); + assertFile( + files0.get(0), "sourceA", new byte[] {2, 0, 0, 0, 0, 0, 0, 0, 17, 0, 0, 0, 0, 0, 0, 0, 41}); + + Map coverage1 = coverages.get(1); + assertCoverage(coverage1, 1, 2, 4); + List> files1 = (List>) coverage1.get("files"); + assertEquals(1, files1.size()); + assertFile( + files1.get(0), + "sourceB", + new byte[] {11, 0, 0, 0, 0, 0, 0, 0, 13, 0, 0, 0, 0, 0, 0, 0, 55}); + } + + @Test + void skipsSpansThatHaveNoReports() throws Exception { + List trace = + givenTrace( + null, + new TestReport( + DDTraceId.from(1), + 2L, + 3L, + Arrays.asList( + new TestReportFileEntry("source", BitSet.valueOf(new long[] {83, 25, 48})))), + null); + + Map message = getMappedMessage(trace); + + List> coverages = assertVersionAndGetCoverages(message, 2); + assertEquals(1, coverages.size()); + + Map coverage = coverages.get(0); + assertCoverage(coverage, 1, 2, 3); + + List> files = (List>) coverage.get("files"); + assertEquals(1, files.size()); + assertFile( + files.get(0), "source", new byte[] {83, 0, 0, 0, 0, 0, 0, 0, 25, 0, 0, 0, 0, 0, 0, 0, 48}); + } + + @Test + void skipsEmptyReports() throws Exception { + List trace = + givenTrace( + new TestReport( + DDTraceId.from(1), + 2L, + 3L, + Arrays.asList( + new TestReportFileEntry("source", BitSet.valueOf(new long[] {33, 53, 87})))), + new TestReport(DDTraceId.from(1), 2L, 4L, Collections.emptyList())); + + Map message = getMappedMessage(trace); + + List> coverages = assertVersionAndGetCoverages(message, 2); + assertEquals(1, coverages.size()); + + Map coverage = coverages.get(0); + assertCoverage(coverage, 1, 2, 3); + + List> files = (List>) coverage.get("files"); + assertEquals(1, files.size()); + assertFile( + files.get(0), "source", new byte[] {33, 0, 0, 0, 0, 0, 0, 0, 53, 0, 0, 0, 0, 0, 0, 0, 87}); + } + + @Test + void skipsDuplicateReports() throws Exception { + List trace = new ArrayList<>(); + TestReport report = + new TestReport( + DDTraceId.from(1), + 2L, + 3L, + Arrays.asList(new TestReportFileEntry("source", BitSet.valueOf(new long[] {3, 5, 8})))); + + trace.add( + buildSpan( + 0, + InternalSpanTypes.TEST, + PropagationTags.factory().empty(), + Collections.emptyMap(), + PrioritySampling.SAMPLER_KEEP, + new DummyTestContext(new DummyReportHolder(report)))); + trace.add( + buildSpan( + 0, + "testChild", + PropagationTags.factory().empty(), + Collections.emptyMap(), + PrioritySampling.SAMPLER_KEEP, + new DummyTestContext(new DummyReportHolder(report)))); + + Map message = getMappedMessage(trace); + + List> coverages = assertVersionAndGetCoverages(message, 2); + assertEquals(1, coverages.size()); + + Map coverage = coverages.get(0); + assertCoverage(coverage, 1, 2, 3); + + List> files = (List>) coverage.get("files"); + assertEquals(1, files.size()); + assertFile( + files.get(0), "source", new byte[] {3, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 8}); + } + + private List givenTrace(TestReport... testReports) { + List trace = new ArrayList<>(); + for (TestReport testReport : testReports) { + DummyReportHolder testReportHolder = new DummyReportHolder(testReport); + trace.add( + buildSpan( + 0, + InternalSpanTypes.TEST, + PropagationTags.factory().empty(), + Collections.emptyMap(), + PrioritySampling.SAMPLER_KEEP, + new DummyTestContext(testReportHolder))); + } + return trace; + } + + private Map getMappedMessage(List trace) throws IOException { + GrowableBuffer buffer = new GrowableBuffer(1024); + CiTestCovMapperV2 mapper = new CiTestCovMapperV2(false); + mapper.map(trace, new MsgPackWriter(buffer)); + + ByteArrayWritableByteChannel channel = new ByteArrayWritableByteChannel(); + mapper.newPayload().withBody(1, buffer.slice()).writeTo(channel); + + return objectMapper.readValue(channel.toByteArray(), Map.class); + } + + private List> assertVersionAndGetCoverages( + Map message, int version) { + assertEquals(version, message.get("version")); + return (List>) message.get("coverages"); + } + + private void assertCoverage( + Map coverage, int sessionId, int suiteId, int spanId) { + assertEquals(sessionId, coverage.get("test_session_id")); + assertEquals(suiteId, coverage.get("test_suite_id")); + assertEquals(spanId, coverage.get("span_id")); + } + + private void assertFile(Map file, String filename, byte[] bitmap) { + assertEquals(filename, file.get("filename")); + assertArrayEquals(bitmap, (byte[]) file.get("bitmap")); + } + + private static final class DummyReportHolder implements CoverageStore { + private final TestReport testReport; + + DummyReportHolder(TestReport testReport) { + this.testReport = testReport; + } + + @Override + public TestReport getReport() { + return testReport; + } + + @Override + public boolean report(DDTraceId testSessionId, Long testSuiteId, long spanId) { + return false; + } + + @Override + public CoverageProbes getProbes() { + return NoOpProbes.INSTANCE; + } + } + + private static final class DummyTestContext implements TestContext { + private final CoverageStore coverageStore; + + DummyTestContext(CoverageStore coverageStore) { + this.coverageStore = coverageStore; + } + + @Override + public CoverageStore getCoverageStore() { + return coverageStore; + } + + @Override + public void set(Class key, T value) {} + + @Override + public T get(Class key) { + return null; + } + } + + private static final class ByteArrayWritableByteChannel implements WritableByteChannel { + private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + + @Override + public int write(ByteBuffer src) throws IOException { + int remaining = src.remaining(); + byte[] buffer = new byte[remaining]; + src.get(buffer); + outputStream.write(buffer); + return remaining; + } + + @Override + public boolean isOpen() { + return true; + } + + @Override + public void close() throws IOException { + outputStream.close(); + } + + byte[] toByteArray() { + return outputStream.toByteArray(); + } + } +} diff --git a/dd-trace-core/src/test/java/datadog/trace/civisibility/writer/ddintake/CiTestCycleMapperV1PayloadTest.java b/dd-trace-core/src/test/java/datadog/trace/civisibility/writer/ddintake/CiTestCycleMapperV1PayloadTest.java new file mode 100644 index 00000000000..e59f758056e --- /dev/null +++ b/dd-trace-core/src/test/java/datadog/trace/civisibility/writer/ddintake/CiTestCycleMapperV1PayloadTest.java @@ -0,0 +1,581 @@ +package datadog.trace.civisibility.writer.ddintake; + +import static datadog.trace.api.civisibility.CIConstants.MAX_META_STRING_VALUE_LENGTH; +import static datadog.trace.bootstrap.instrumentation.api.InstrumentationTags.DD_MEASURED; +import static datadog.trace.common.writer.TraceGenerator.generateRandomSpan; +import static datadog.trace.common.writer.TraceGenerator.generateRandomTraces; +import static datadog.trace.util.Strings.truncate; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.msgpack.core.MessageFormat.FLOAT32; +import static org.msgpack.core.MessageFormat.FLOAT64; +import static org.msgpack.core.MessageFormat.INT16; +import static org.msgpack.core.MessageFormat.INT32; +import static org.msgpack.core.MessageFormat.INT64; +import static org.msgpack.core.MessageFormat.INT8; +import static org.msgpack.core.MessageFormat.NEGFIXINT; +import static org.msgpack.core.MessageFormat.POSFIXINT; +import static org.msgpack.core.MessageFormat.UINT16; +import static org.msgpack.core.MessageFormat.UINT32; +import static org.msgpack.core.MessageFormat.UINT64; +import static org.msgpack.core.MessageFormat.UINT8; + +import com.fasterxml.jackson.databind.ObjectMapper; +import datadog.communication.serialization.ByteBufferConsumer; +import datadog.communication.serialization.FlushingBuffer; +import datadog.communication.serialization.msgpack.MsgPackWriter; +import datadog.trace.api.DDTags; +import datadog.trace.api.DDTraceId; +import datadog.trace.api.civisibility.CiVisibilityWellKnownTags; +import datadog.trace.bootstrap.instrumentation.api.InternalSpanTypes; +import datadog.trace.bootstrap.instrumentation.api.Tags; +import datadog.trace.common.writer.Payload; +import datadog.trace.common.writer.TraceGenerator; +import datadog.trace.core.DDSpanContext; +import datadog.trace.test.util.DDJavaSpecification; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.msgpack.core.MessageFormat; +import org.msgpack.core.MessagePack; +import org.msgpack.core.MessageUnpacker; +import org.msgpack.jackson.dataformat.MessagePackFactory; +import org.tabletest.junit.TableTest; + +public class CiTestCycleMapperV1PayloadTest extends DDJavaSpecification { + + @TableTest({ + "scenario | bufferSize | traceCount | lowCardinality", + "20k buffer, 0 traces, low cardinality | 20480 | 0 | true ", + "20k buffer, 1 trace, low cardinality | 20480 | 1 | true ", + "30k buffer, 1 trace, low cardinality | 30720 | 1 | true ", + "30k buffer, 2 traces, low cardinality | 30720 | 2 | true ", + "20k buffer, 0 traces, high cardinality | 20480 | 0 | false ", + "20k buffer, 1 trace, high cardinality | 20480 | 1 | false ", + "30k buffer, 1 trace, high cardinality | 30720 | 1 | false ", + "30k buffer, 2 traces, high cardinality | 30720 | 2 | false ", + "100k buffer, 0 traces, low cardinality | 102400 | 0 | true ", + "100k buffer, 1 trace, low cardinality | 102400 | 1 | true ", + "100k buffer, 10 traces, low cardinality | 102400 | 10 | true ", + "100k buffer, 100 traces, low cardinality | 102400 | 100 | true ", + "100k buffer, 1000 traces, low cardinality | 102400 | 1000 | true ", + "100k buffer, 0 traces, high cardinality | 102400 | 0 | false ", + "100k buffer, 1 trace, high cardinality | 102400 | 1 | false ", + "100k buffer, 10 traces, high cardinality | 102400 | 10 | false ", + "100k buffer, 100 traces, high cardinality | 102400 | 100 | false ", + "100k buffer, 1000 traces, high cardinality | 102400 | 1000 | false " + }) + void testTracesWrittenCorrectly(int bufferSize, int traceCount, boolean lowCardinality) { + CiVisibilityWellKnownTags wellKnownTags = + new CiVisibilityWellKnownTags( + "runtimeid", + "my-env", + "language", + "my-runtime-name", + "my-runtime-version", + "my-runtime-vendor", + "my-os-arch", + "my-os-platform", + "my-os-version", + "false"); + CiTestCycleMapperV1 mapper = new CiTestCycleMapperV1(wellKnownTags, false); + + List> traces = generateRandomTraces(traceCount, lowCardinality); + PayloadVerifier verifier = new PayloadVerifier(wellKnownTags, traces, mapper); + + MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(bufferSize, verifier)); + + boolean tracesFitInBuffer = true; + for (List trace : traces) { + if (!packer.format(trace, mapper)) { + verifier.skipLargeTrace(); + tracesFitInBuffer = false; + } + } + packer.flush(); + + if (tracesFitInBuffer) { + verifier.verifyTracesConsumed(); + } + } + + @Test + void verifyTestSuiteIdTestModuleIdAndTestSessionIdAreWrittenAsTopLevelTagsInTestEvent() { + Map extraTags = new HashMap<>(); + extraTags.put(Tags.TEST_SESSION_ID, DDTraceId.from(123)); + extraTags.put(Tags.TEST_MODULE_ID, 456L); + extraTags.put(Tags.TEST_SUITE_ID, 789L); + TraceGenerator.PojoSpan span = generateRandomSpan(InternalSpanTypes.TEST, extraTags); + + Map deserializedSpan = whenASpanIsWritten(span); + + verifyTopLevelTags(deserializedSpan, DDTraceId.from(123), 456L, 789L); + + Map spanContent = getContent(deserializedSpan); + assertTrue(spanContent.containsKey("trace_id")); + assertTrue(spanContent.containsKey("span_id")); + assertTrue(spanContent.containsKey("parent_id")); + } + + @Test + void truncatesMetaStringValuesAndPreservesMetricsAndTopLevelIds() { + String longValue = repeat("a", MAX_META_STRING_VALUE_LENGTH + 1); + String exactValue = repeat("b", MAX_META_STRING_VALUE_LENGTH); + Map extraTags = new HashMap<>(); + extraTags.put(Tags.TEST_SESSION_ID, DDTraceId.from(123)); + extraTags.put(Tags.TEST_MODULE_ID, 456L); + extraTags.put(Tags.TEST_SUITE_ID, 789L); + extraTags.put("custom.tag", longValue); + extraTags.put("exact.tag", exactValue); + extraTags.put("custom.metric", 42); + TraceGenerator.PojoSpan span = generateRandomSpan(InternalSpanTypes.TEST, extraTags); + + Map deserializedSpan = whenASpanIsWritten(span); + + verifyTopLevelTags(deserializedSpan, DDTraceId.from(123), 456L, 789L); + + Map spanContent = getContent(deserializedSpan); + Map deserializedMetrics = getMetrics(spanContent); + Map deserializedMeta = getMeta(spanContent); + + assertEquals( + longValue.substring(0, MAX_META_STRING_VALUE_LENGTH), deserializedMeta.get("custom.tag")); + assertEquals( + MAX_META_STRING_VALUE_LENGTH, ((String) deserializedMeta.get("custom.tag")).length()); + assertEquals(exactValue, deserializedMeta.get("exact.tag")); + assertEquals(42, deserializedMetrics.get("custom.metric")); + } + + @Test + void truncatesPayloadMetadataValues() { + String longValue = repeat("m", MAX_META_STRING_VALUE_LENGTH + 1); + CiVisibilityWellKnownTags wellKnownTags = + new CiVisibilityWellKnownTags( + longValue, longValue, longValue, longValue, longValue, longValue, longValue, longValue, + longValue, longValue); + CiTestCycleMapperV1 mapper = new CiTestCycleMapperV1(wellKnownTags, false); + List> traces = + Collections.singletonList( + Collections.singletonList( + generateRandomSpan(InternalSpanTypes.TEST, Collections.emptyMap()))); + PayloadVerifier verifier = new PayloadVerifier(wellKnownTags, traces, mapper); + MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(100 << 10, verifier)); + + packer.format(traces.get(0), mapper); + packer.flush(); + + verifier.verifyTracesConsumed(); + } + + @Test + void verifyTestSuiteEndEventIsWrittenCorrectly() { + Map extraTags = new HashMap<>(); + extraTags.put(Tags.TEST_SESSION_ID, DDTraceId.from(123)); + extraTags.put(Tags.TEST_MODULE_ID, 456L); + extraTags.put(Tags.TEST_SUITE_ID, 789L); + TraceGenerator.PojoSpan span = generateRandomSpan(InternalSpanTypes.TEST_SUITE_END, extraTags); + + Map deserializedSpan = whenASpanIsWritten(span); + + verifyTopLevelTags(deserializedSpan, DDTraceId.from(123), 456L, 789L); + + Map spanContent = getContent(deserializedSpan); + assertFalse(spanContent.containsKey("trace_id")); + assertFalse(spanContent.containsKey("span_id")); + assertFalse(spanContent.containsKey("parent_id")); + } + + @Test + void verifyTestModuleEndEventIsWrittenCorrectly() { + Map extraTags = new HashMap<>(); + extraTags.put(Tags.TEST_SESSION_ID, DDTraceId.from(123)); + extraTags.put(Tags.TEST_MODULE_ID, 456L); + TraceGenerator.PojoSpan span = generateRandomSpan(InternalSpanTypes.TEST_MODULE_END, extraTags); + + Map deserializedSpan = whenASpanIsWritten(span); + + verifyTopLevelTags(deserializedSpan, DDTraceId.from(123), 456L, null); + + Map spanContent = getContent(deserializedSpan); + assertFalse(spanContent.containsKey("trace_id")); + assertFalse(spanContent.containsKey("span_id")); + assertFalse(spanContent.containsKey("parent_id")); + } + + @Test + void verifyResultIsNotAffectedBySuccessiveMappingCalls() { + Map extraTags = new HashMap<>(); + extraTags.put(Tags.TEST_SESSION_ID, DDTraceId.from(123)); + extraTags.put(Tags.TEST_MODULE_ID, 456L); + extraTags.put(Tags.TEST_SUITE_ID, 789L); + TraceGenerator.PojoSpan span = generateRandomSpan(InternalSpanTypes.TEST, extraTags); + + whenASpanIsWritten(span); + Map deserializedSpan = whenASpanIsWritten(span); + + verifyTopLevelTags(deserializedSpan, DDTraceId.from(123), 456L, 789L); + + Map spanContent = getContent(deserializedSpan); + assertTrue(spanContent.containsKey("trace_id")); + assertTrue(spanContent.containsKey("span_id")); + assertTrue(spanContent.containsKey("parent_id")); + } + + @SuppressWarnings("unchecked") + private static Map getContent(Map deserializedSpan) { + return (Map) deserializedSpan.get("content"); + } + + @SuppressWarnings("unchecked") + private static Map getMetrics(Map spanContent) { + return (Map) spanContent.get("metrics"); + } + + @SuppressWarnings("unchecked") + private static Map getMeta(Map spanContent) { + return (Map) spanContent.get("meta"); + } + + private static void verifyTopLevelTags( + Map deserializedSpan, + DDTraceId testSessionId, + Long testModuleId, + Long testSuiteId) { + Map spanContent = getContent(deserializedSpan); + Map deserializedMetrics = getMetrics(spanContent); + Map deserializedMeta = getMeta(spanContent); + + if (testSessionId != null) { + assertEquals( + testSessionId.toLong(), ((Number) spanContent.get(Tags.TEST_SESSION_ID)).longValue()); + } else { + assertFalse(spanContent.containsKey(Tags.TEST_SESSION_ID)); + } + + if (testModuleId != null) { + assertEquals( + testModuleId.longValue(), ((Number) spanContent.get(Tags.TEST_MODULE_ID)).longValue()); + } else { + assertFalse(spanContent.containsKey(Tags.TEST_MODULE_ID)); + } + + if (testSuiteId != null) { + assertEquals( + testSuiteId.longValue(), ((Number) spanContent.get(Tags.TEST_SUITE_ID)).longValue()); + } else { + assertFalse(spanContent.containsKey(Tags.TEST_SUITE_ID)); + } + + assertFalse(deserializedMetrics.containsKey(Tags.TEST_SESSION_ID)); + assertFalse(deserializedMetrics.containsKey(Tags.TEST_MODULE_ID)); + assertFalse(deserializedMetrics.containsKey(Tags.TEST_SUITE_ID)); + + assertFalse(deserializedMeta.containsKey(Tags.TEST_SESSION_ID)); + assertFalse(deserializedMeta.containsKey(Tags.TEST_MODULE_ID)); + assertFalse(deserializedMeta.containsKey(Tags.TEST_SUITE_ID)); + } + + @SuppressWarnings("unchecked") + private static Map whenASpanIsWritten(TraceGenerator.PojoSpan span) { + List trace = Collections.singletonList(span); + + CiVisibilityWellKnownTags wellKnownTags = + new CiVisibilityWellKnownTags( + "runtimeid", + "my-env", + "language", + "my-runtime-name", + "my-runtime-version", + "my-runtime-vendor", + "my-os-arch", + "my-os-platform", + "my-os-version", + "false"); + CiTestCycleMapperV1 mapper = new CiTestCycleMapperV1(wellKnownTags, false); + + CaptureConsumer consumer = new CaptureConsumer(); + MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(100 << 10, consumer)); + + packer.format(trace, mapper); + packer.flush(); + + ObjectMapper objectMapper = new ObjectMapper(new MessagePackFactory()); + try { + return (Map) objectMapper.readValue(consumer.bytes, Object.class); + } catch (IOException e) { + fail("Failed to deserialize span: " + e.getMessage()); + return null; + } + } + + private static String repeat(String s, int count) { + StringBuilder sb = new StringBuilder(s.length() * count); + for (int i = 0; i < count; i++) { + sb.append(s); + } + return sb.toString(); + } + + private static void assertEqualsWithNullAsEmpty(CharSequence expected, CharSequence actual) { + if (null == expected) { + assertEquals("", actual); + } else { + assertEquals(expected.toString(), actual.toString()); + } + } + + private static final class CaptureConsumer implements ByteBufferConsumer { + private byte[] bytes; + + @Override + public void accept(int messageCount, ByteBuffer buffer) { + this.bytes = new byte[buffer.limit() - buffer.position()]; + buffer.get(bytes); + } + } + + private static final class PayloadVerifier implements ByteBufferConsumer, WritableByteChannel { + + private final List> expectedTraces; + private final CiTestCycleMapperV1 mapper; + private final CiVisibilityWellKnownTags wellKnownTags; + private ByteBuffer captured = ByteBuffer.allocate(200 << 10); + + private int position = 0; + + private PayloadVerifier( + CiVisibilityWellKnownTags wellKnownTags, + List> traces, + CiTestCycleMapperV1 mapper) { + this.expectedTraces = traces; + this.mapper = mapper; + this.wellKnownTags = wellKnownTags; + } + + void skipLargeTrace() { + ++position; + } + + void verifyTracesConsumed() { + assertEquals(expectedTraces.size(), position); + } + + @Override + public void accept(int messageCount, ByteBuffer buffer) { + if (expectedTraces.isEmpty() && messageCount == 0) { + return; + } + + try { + Payload payload = mapper.newPayload().withBody(messageCount, buffer); + payload.writeTo(this); + captured.flip(); + assertNotNull(payload.toRequest()); + MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(captured); + assertEquals(3, unpacker.unpackMapHeader()); + assertEquals("version", unpacker.unpackString()); + assertEquals(1, unpacker.unpackInt()); + assertEquals("metadata", unpacker.unpackString()); + assertEquals(1, unpacker.unpackMapHeader()); + assertEquals("*", unpacker.unpackString()); + + assertEquals(10, unpacker.unpackMapHeader()); + assertEquals("env", unpacker.unpackString()); + assertEquals( + truncate(wellKnownTags.getEnv().toString(), MAX_META_STRING_VALUE_LENGTH), + unpacker.unpackString()); + assertEquals("runtime-id", unpacker.unpackString()); + assertEquals( + truncate(wellKnownTags.getRuntimeId().toString(), MAX_META_STRING_VALUE_LENGTH), + unpacker.unpackString()); + assertEquals("language", unpacker.unpackString()); + assertEquals( + truncate(wellKnownTags.getLanguage().toString(), MAX_META_STRING_VALUE_LENGTH), + unpacker.unpackString()); + assertEquals(Tags.RUNTIME_NAME, unpacker.unpackString()); + assertEquals( + truncate(wellKnownTags.getRuntimeName().toString(), MAX_META_STRING_VALUE_LENGTH), + unpacker.unpackString()); + assertEquals(Tags.RUNTIME_VENDOR, unpacker.unpackString()); + assertEquals( + truncate(wellKnownTags.getRuntimeVendor().toString(), MAX_META_STRING_VALUE_LENGTH), + unpacker.unpackString()); + assertEquals(Tags.RUNTIME_VERSION, unpacker.unpackString()); + assertEquals( + truncate(wellKnownTags.getRuntimeVersion().toString(), MAX_META_STRING_VALUE_LENGTH), + unpacker.unpackString()); + assertEquals(Tags.OS_ARCHITECTURE, unpacker.unpackString()); + assertEquals( + truncate(wellKnownTags.getOsArch().toString(), MAX_META_STRING_VALUE_LENGTH), + unpacker.unpackString()); + assertEquals(Tags.OS_PLATFORM, unpacker.unpackString()); + assertEquals( + truncate(wellKnownTags.getOsPlatform().toString(), MAX_META_STRING_VALUE_LENGTH), + unpacker.unpackString()); + assertEquals(Tags.OS_VERSION, unpacker.unpackString()); + assertEquals( + truncate(wellKnownTags.getOsVersion().toString(), MAX_META_STRING_VALUE_LENGTH), + unpacker.unpackString()); + assertEquals(DDTags.TEST_IS_USER_PROVIDED_SERVICE, unpacker.unpackString()); + assertEquals( + truncate( + wellKnownTags.getIsUserProvidedService().toString(), MAX_META_STRING_VALUE_LENGTH), + unpacker.unpackString()); + + assertEquals("events", unpacker.unpackString()); + + List expectedTrace = expectedTraces.get(position++); + int eventCount = unpacker.unpackArrayHeader(); + while (expectedTrace.size() < eventCount) { + expectedTrace.addAll(expectedTraces.get(position++)); + } + assertEquals(expectedTrace.size(), eventCount); + for (int k = 0; k < eventCount; ++k) { + TraceGenerator.PojoSpan expectedSpan = expectedTrace.get(k); + assertEquals(3, unpacker.unpackMapHeader()); + assertEquals("type", unpacker.unpackString()); + if ("test".equals(String.valueOf(expectedSpan.getType()))) { + assertEquals("test", unpacker.unpackString()); + } else { + assertEquals("span", unpacker.unpackString()); + } + assertEquals("version", unpacker.unpackString()); + assertEquals(1, unpacker.unpackInt()); + assertEquals("content", unpacker.unpackString()); + assertEquals(11, unpacker.unpackMapHeader()); + assertEquals("trace_id", unpacker.unpackString()); + long traceId = unpacker.unpackValue().asNumberValue().toLong(); + assertEquals(expectedSpan.getTraceId().toLong(), traceId); + assertEquals("span_id", unpacker.unpackString()); + long spanId = unpacker.unpackValue().asNumberValue().toLong(); + assertEquals(expectedSpan.getSpanId(), spanId); + assertEquals("parent_id", unpacker.unpackString()); + long parentId = unpacker.unpackValue().asNumberValue().toLong(); + assertEquals(expectedSpan.getParentId(), parentId); + assertEquals("service", unpacker.unpackString()); + String serviceName = unpacker.unpackString(); + assertEqualsWithNullAsEmpty(expectedSpan.getServiceName(), serviceName); + assertEquals("name", unpacker.unpackString()); + String operationName = unpacker.unpackString(); + assertEqualsWithNullAsEmpty(expectedSpan.getOperationName(), operationName); + assertEquals("resource", unpacker.unpackString()); + String resourceName = unpacker.unpackString(); + assertEqualsWithNullAsEmpty(expectedSpan.getResourceName(), resourceName); + + assertEquals("start", unpacker.unpackString()); + long startTime = unpacker.unpackLong(); + assertEquals(expectedSpan.getStartTime(), startTime); + assertEquals("duration", unpacker.unpackString()); + long duration = unpacker.unpackLong(); + assertEquals(expectedSpan.getDurationNano(), duration); + assertEquals("error", unpacker.unpackString()); + int error = unpacker.unpackInt(); + assertEquals(expectedSpan.getError(), error); + assertEquals("metrics", unpacker.unpackString()); + int metricsSize = unpacker.unpackMapHeader(); + HashMap metrics = new HashMap<>(); + for (int j = 0; j < metricsSize; ++j) { + String key = unpacker.unpackString(); + Number n = null; + MessageFormat format = unpacker.getNextFormat(); + if (format == NEGFIXINT + || format == POSFIXINT + || format == INT8 + || format == UINT8 + || format == INT16 + || format == UINT16 + || format == INT32 + || format == UINT32) { + n = unpacker.unpackInt(); + } else if (format == INT64 || format == UINT64) { + n = unpacker.unpackLong(); + } else if (format == FLOAT32) { + n = unpacker.unpackFloat(); + } else if (format == FLOAT64) { + n = unpacker.unpackDouble(); + } else { + fail("Unexpected type in metrics values: " + format); + } + if (DD_MEASURED.toString().equals(key)) { + assertTrue( + (n != null && n.intValue() == 1 && expectedSpan.isMeasured()) + || !expectedSpan.isMeasured()); + } else if (DDSpanContext.PRIORITY_SAMPLING_KEY.equals(key)) { + // check that priority sampling is only on first and last span + if (k == 0 || k == eventCount - 1) { + assertEquals(expectedSpan.samplingPriority(), n.intValue()); + } else { + assertFalse(expectedSpan.hasSamplingPriority()); + } + } else { + metrics.put(key, n); + } + } + for (Map.Entry metric : metrics.entrySet()) { + if (metric.getValue() instanceof Double || metric.getValue() instanceof Float) { + assertEquals( + ((Number) expectedSpan.getTag(metric.getKey())).doubleValue(), + metric.getValue().doubleValue(), + 0.001); + } else { + assertEquals(expectedSpan.getTag(metric.getKey()), metric.getValue()); + } + } + assertEquals("meta", unpacker.unpackString()); + int metaSize = unpacker.unpackMapHeader(); + HashMap meta = new HashMap<>(); + for (int j = 0; j < metaSize; ++j) { + meta.put(unpacker.unpackString(), unpacker.unpackString()); + } + for (Map.Entry entry : meta.entrySet()) { + if (Tags.HTTP_STATUS.equals(entry.getKey())) { + assertEquals(String.valueOf(expectedSpan.getHttpStatusCode()), entry.getValue()); + } else { + Object tag = expectedSpan.getTag(entry.getKey()); + if (null != tag) { + assertEquals(String.valueOf(tag), entry.getValue()); + } else { + assertEquals(expectedSpan.getBaggage().get(entry.getKey()), entry.getValue()); + } + } + } + } + } catch (IOException e) { + fail(e.getMessage()); + } finally { + mapper.reset(); + captured.position(0); + captured.limit(captured.capacity()); + } + } + + @Override + public int write(ByteBuffer src) throws IOException { + if (captured.remaining() < src.remaining()) { + ByteBuffer newBuffer = ByteBuffer.allocate(captured.capacity() + src.capacity()); + captured.flip(); + newBuffer.put(captured); + captured = newBuffer; + return write(src); + } + captured.put(src); + return src.position(); + } + + @Override + public boolean isOpen() { + return true; + } + + @Override + public void close() throws IOException {} + } +} diff --git a/dd-trace-core/src/test/java/datadog/trace/common/writer/TraceGenerator.java b/dd-trace-core/src/test/java/datadog/trace/common/writer/TraceGenerator.java new file mode 100644 index 00000000000..618b2ef77ae --- /dev/null +++ b/dd-trace-core/src/test/java/datadog/trace/common/writer/TraceGenerator.java @@ -0,0 +1,506 @@ +package datadog.trace.common.writer; + +import static datadog.trace.api.sampling.PrioritySampling.UNSET; +import static java.util.Collections.emptyList; + +import datadog.trace.api.DDSpanId; +import datadog.trace.api.DDTags; +import datadog.trace.api.DDTraceId; +import datadog.trace.api.IdGenerationStrategy; +import datadog.trace.api.ProcessTags; +import datadog.trace.api.TagMap; +import datadog.trace.api.sampling.PrioritySampling; +import datadog.trace.bootstrap.instrumentation.api.AgentSpanLink; +import datadog.trace.bootstrap.instrumentation.api.UTF8BytesString; +import datadog.trace.core.CoreSpan; +import datadog.trace.core.Metadata; +import datadog.trace.core.MetadataConsumer; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +public class TraceGenerator { + + public static List> generateRandomTraces(int howMany, boolean lowCardinality) { + List> traces = new ArrayList<>(howMany); + for (int i = 0; i < howMany; ++i) { + int traceSize = ThreadLocalRandom.current().nextInt(2, 20); + traces.add(generateRandomTrace(traceSize, lowCardinality)); + } + return traces; + } + + private static List generateRandomTrace(int size, boolean lowCardinality) { + List trace = new ArrayList<>(size); + long traceId = ThreadLocalRandom.current().nextLong(1, Long.MAX_VALUE); + for (int i = 0; i < size; ++i) { + String spanType = "type-" + ThreadLocalRandom.current().nextInt(lowCardinality ? 1 : 100); + trace.add(randomSpan(traceId, lowCardinality, spanType, Collections.emptyMap())); + } + return trace; + } + + private static final IdGenerationStrategy ID_GENERATION_STRATEGY = + IdGenerationStrategy.fromName("RANDOM"); + + public static PojoSpan generateRandomSpan(CharSequence type, Map extraTags) { + long traceId = ThreadLocalRandom.current().nextLong(1, Long.MAX_VALUE); + return randomSpan(traceId, true, type, extraTags); + } + + private static PojoSpan randomSpan( + long traceId, boolean lowCardinality, CharSequence type, Map extraTags) { + ThreadLocalRandom random = ThreadLocalRandom.current(); + Map baggage = new HashMap<>(); + if (random.nextBoolean()) { + baggage.put("baggage-key", lowCardinality ? "x" : randomString(100)); + if (random.nextBoolean()) { + baggage.put("tag.1", "bar"); + baggage.put("tag.2", "qux"); + } + } + Map tags = new HashMap<>(extraTags); + int tagCount = random.nextInt(0, 20); + for (int i = 0; i < tagCount; ++i) { + tags.put("tag." + i, random.nextBoolean() ? "foo" : randomString(2000)); + tags.put("tag.1." + i, lowCardinality ? "y" : UUID.randomUUID()); + tags.put("tag.2." + i, random.nextBoolean()); + switch (random.nextInt(8)) { + case 0: + tags.put("tag.3." + i, BigDecimal.valueOf(random.nextDouble())); + break; + case 1: + tags.put("tag.3." + i, BigInteger.valueOf(random.nextLong())); + break; + default: + break; + } + } + int metricCount = random.nextInt(0, 20); + for (int i = 0; i < metricCount; ++i) { + String name = "metric." + i; + Number metric = null; + switch (random.nextInt(4)) { + case 0: + metric = random.nextInt(); + break; + case 1: + metric = random.nextLong(); + break; + case 2: + metric = random.nextFloat(); + break; + case 3: + metric = random.nextDouble(); + break; + } + tags.put(name, metric); + } + + return new PojoSpan( + "service-" + random.nextInt(lowCardinality ? 1 : 10), + "operation-" + random.nextInt(lowCardinality ? 1 : 100), + UTF8BytesString.create("resource-" + random.nextInt(lowCardinality ? 1 : 100)), + DDTraceId.from(traceId), + ID_GENERATION_STRATEGY.generateSpanId(), + DDSpanId.ZERO, + TimeUnit.MILLISECONDS.toNanos(System.currentTimeMillis()), + random.nextLong(500, 10_000_000), + random.nextInt(2), + baggage, + tags, + type, + random.nextBoolean(), + PrioritySampling.SAMPLER_KEEP, + 200, + "some-origin"); + } + + private static String randomString(int maxLength) { + char[] chars = new char[ThreadLocalRandom.current().nextInt(maxLength)]; + for (int i = 0; i < chars.length; ++i) { + char next = (char) ThreadLocalRandom.current().nextInt((int) Character.MAX_VALUE); + if (Character.isSurrogate(next)) { + if (i < chars.length - 1) { + chars[i++] = '\uD801'; + chars[i] = '\uDC01'; + } else { + chars[i] = 'a'; + } + } else { + chars[i] = next; + } + } + return new String(chars); + } + + public static class PojoSpan implements CoreSpan { + + private final CharSequence serviceName; + private final CharSequence operationName; + private final CharSequence resourceName; + private final DDTraceId traceId; + private final long spanId; + private final long parentId; + private final long start; + private final long duration; + private final int error; + private final CharSequence type; + private final boolean measured; + private final Metadata metadata; + private short httpStatusCode; + private final int samplingPriority; + private final Map metaStruct = new HashMap<>(); + + public PojoSpan( + String serviceName, + String operationName, + CharSequence resourceName, + DDTraceId traceId, + long spanId, + long parentId, + long start, + long duration, + int error, + Map baggage, + Map tags, + CharSequence type, + boolean measured, + int samplingPriority, + int statusCode, + CharSequence origin) { + this( + serviceName, + operationName, + resourceName, + traceId, + spanId, + parentId, + start, + duration, + error, + baggage, + tags, + type, + measured, + samplingPriority, + statusCode, + origin, + emptyList()); + } + + public PojoSpan( + String serviceName, + String operationName, + CharSequence resourceName, + DDTraceId traceId, + long spanId, + long parentId, + long start, + long duration, + int error, + Map baggage, + Map tags, + CharSequence type, + boolean measured, + int samplingPriority, + int statusCode, + CharSequence origin, + List spanLinks) { + this.serviceName = UTF8BytesString.create(serviceName); + this.operationName = UTF8BytesString.create(operationName); + this.resourceName = UTF8BytesString.create(resourceName); + this.traceId = traceId; + this.spanId = spanId; + this.parentId = parentId; + this.start = start; + this.duration = duration; + this.error = error; + this.type = type; + this.measured = measured; + this.samplingPriority = samplingPriority; + this.httpStatusCode = (short) statusCode; + this.metadata = + new Metadata( + Thread.currentThread().getId(), + UTF8BytesString.create(Thread.currentThread().getName()), + TagMap.fromMap(tags), + baggage, + samplingPriority, + measured, + isTopLevel(), + statusCode == 0 ? null : UTF8BytesString.create(Integer.toString(statusCode)), + origin, + 0, + ProcessTags.getTagsForSerialization(), + spanLinks); + } + + @Override + public PojoSpan getLocalRootSpan() { + return this; + } + + @Override + public String getServiceName() { + return serviceName.toString(); + } + + @Override + public CharSequence getOperationName() { + return operationName; + } + + @Override + public CharSequence getResourceName() { + return resourceName; + } + + @Override + public DDTraceId getTraceId() { + return traceId; + } + + @Override + public long getSpanId() { + return spanId; + } + + @Override + public long getParentId() { + return parentId; + } + + @Override + public long getStartTime() { + return start; + } + + @Override + public long getDurationNano() { + return duration; + } + + @Override + public int getError() { + return error; + } + + @Override + public PojoSpan setMeasured(boolean measured) { + return this; + } + + @Override + public PojoSpan setErrorMessage(String errorMessage) { + return this; + } + + @Override + public PojoSpan addThrowable(Throwable error) { + return this; + } + + @Override + public PojoSpan setTag(String tag, String value) { + return this; + } + + @Override + public PojoSpan setTag(String tag, boolean value) { + return this; + } + + @Override + public PojoSpan setTag(String tag, int value) { + return this; + } + + @Override + public PojoSpan setTag(String tag, long value) { + return this; + } + + @Override + public PojoSpan setTag(String tag, double value) { + return this; + } + + @Override + public PojoSpan setTag(String tag, Number value) { + return this; + } + + @Override + public PojoSpan setTag(String tag, CharSequence value) { + return this; + } + + @Override + public PojoSpan setTag(String tag, Object value) { + return this; + } + + @Override + public PojoSpan removeTag(String tag) { + metadata.getTags().remove(tag); + return this; + } + + @Override + public boolean isMeasured() { + return measured; + } + + @Override + public boolean isTopLevel() { + return false; + } + + @Override + public boolean isForceKeep() { + return false; + } + + @Override + public short getHttpStatusCode() { + return httpStatusCode; + } + + @Override + public CharSequence getOrigin() { + return metadata.getOrigin(); + } + + public Map getBaggage() { + return metadata.getBaggage(); + } + + public Map getTags() { + return metadata.getTags(); + } + + @Override + public CharSequence getType() { + return this.type; + } + + @Override + public void processServiceTags() {} + + @Override + public void processTagsAndBaggage(MetadataConsumer consumer) { + consumer.accept(metadata); + } + + @Override + public void processTagsAndBaggage( + MetadataConsumer consumer, boolean injectLinksAsTags, boolean injectBaggageAsTags) { + consumer.accept(metadata); + } + + @Override + public PojoSpan setSamplingPriority(int samplingPriority, int samplingMechanism) { + return this; + } + + @Override + public PojoSpan setSamplingPriority( + int samplingPriority, CharSequence rate, double sampleRate, int samplingMechanism) { + return this; + } + + @Override + public PojoSpan setSpanSamplingPriority(double rate, int limit) { + return this; + } + + @Override + public PojoSpan setMetric(CharSequence name, int value) { + return this; + } + + @Override + public PojoSpan setMetric(CharSequence name, long value) { + return this; + } + + @Override + public PojoSpan setMetric(CharSequence name, float value) { + return this; + } + + @Override + public PojoSpan setMetric(CharSequence name, double value) { + return this; + } + + @Override + public PojoSpan setFlag(CharSequence name, boolean value) { + return this; + } + + @Override + public int samplingPriority() { + return samplingPriority; + } + + @Override + public U getTag(CharSequence name, U defaultValue) { + U value = getTag(name); + return null == value ? defaultValue : value; + } + + @Override + @SuppressWarnings("unchecked") + public U getTag(CharSequence name) { + // replicate logic here because DDSpanContext has to pretend some of its + // fields are elements of a map for backward compatibility reasons + String tag = String.valueOf(name); + Object value = null; + switch (tag) { + case DDTags.THREAD_ID: + value = metadata.getThreadId(); + break; + case DDTags.THREAD_NAME: + value = metadata.getThreadName(); + break; + default: + value = metadata.getTags().get(tag); + } + return (U) value; + } + + @Override + public boolean hasSamplingPriority() { + return samplingPriority != UNSET; + } + + @Override + public Map getMetaStruct() { + return metaStruct; + } + + @Override + public PojoSpan setMetaStruct(String field, Object value) { + if (value == null) { + metaStruct.remove(field); + } else { + metaStruct.put(field, value); + } + return this; + } + + @Override + public int getLongRunningVersion() { + return 0; + } + } +} diff --git a/dd-trace-core/src/test/java/datadog/trace/common/writer/TraceStructureWriterTest.java b/dd-trace-core/src/test/java/datadog/trace/common/writer/TraceStructureWriterTest.java new file mode 100644 index 00000000000..67a83d58e01 --- /dev/null +++ b/dd-trace-core/src/test/java/datadog/trace/common/writer/TraceStructureWriterTest.java @@ -0,0 +1,30 @@ +package datadog.trace.common.writer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import datadog.trace.core.DDCoreJavaSpecification; +import org.tabletest.junit.TableTest; + +public class TraceStructureWriterTest extends DDCoreJavaSpecification { + + @TableTest({ + "scenario | windows | cli | path ", + "windows path | true | C:/tmp/file | C:/tmp/file ", + "windows backslash path | true | C:\\tmp\\file | C:\\tmp\\file", + "windows file | true | file | file ", + "windows path with option | true | C:/tmp/file:includeresource | C:/tmp/file ", + "windows backslash path with option | true | C:\\tmp\\file:includeresource | C:\\tmp\\file", + "windows file with option | true | file:includeresource | file ", + "unix absolute path 1 | false | /var/tmp/file | /var/tmp/file", + "unix file 1 | false | file | file ", + "unix absolute path 2 | false | /var/tmp/file | /var/tmp/file", + "unix file 2 | false | file | file " + }) + void parseCliArgs(boolean windows, String cli, String path) { + String[] args = TraceStructureWriter.parseArgs(cli, windows); + + assertTrue(args.length > 0); + assertEquals(path, args[0]); + } +} diff --git a/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaAppSecHandlerTest.java b/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaAppSecHandlerTest.java new file mode 100644 index 00000000000..eed9ac52e9b --- /dev/null +++ b/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaAppSecHandlerTest.java @@ -0,0 +1,1360 @@ +package datadog.trace.lambda; + +import static datadog.trace.api.gateway.Events.EVENTS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +import datadog.trace.api.Config; +import datadog.trace.api.function.TriConsumer; +import datadog.trace.api.function.TriFunction; +import datadog.trace.api.gateway.CallbackProvider; +import datadog.trace.api.gateway.Flow; +import datadog.trace.api.gateway.IGSpanInfo; +import datadog.trace.api.gateway.RequestContext; +import datadog.trace.api.gateway.RequestContextSlot; +import datadog.trace.bootstrap.ActiveSubsystems; +import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.api.AgentSpanContext; +import datadog.trace.bootstrap.instrumentation.api.AgentTracer; +import datadog.trace.bootstrap.instrumentation.api.TagContext; +import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter; +import datadog.trace.core.DDCoreJavaSpecification; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Base64; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +@SuppressWarnings("unchecked") +public class LambdaAppSecHandlerTest extends DDCoreJavaSpecification { + + private static boolean originalAppSecActive; + private static AgentTracer.TracerAPI originalTracer; + + @BeforeAll + static void setupSpec() { + originalAppSecActive = ActiveSubsystems.APPSEC_ACTIVE; + originalTracer = AgentTracer.get(); + } + + @BeforeEach + void setup() { + ActiveSubsystems.APPSEC_ACTIVE = true; + } + + @AfterEach + void cleanup() { + ActiveSubsystems.APPSEC_ACTIVE = originalAppSecActive; + AgentTracer.forceRegister(originalTracer); + } + + // ============================================================================ + // processRequestStart basic tests + // ============================================================================ + + @Test + void processRequestStartReturnsNullWhenAppSecIsDisabled() { + ActiveSubsystems.APPSEC_ACTIVE = false; + ByteArrayInputStream event = createInputStream("{\"test\": \"data\"}"); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNull(result); + } + + @Test + void processRequestStartReturnsNullForNonByteArrayInputStream() { + AgentSpanContext result = LambdaAppSecHandler.processRequestStart("not a stream"); + + assertNull(result); + } + + @Test + void processRequestStartReturnsNullForNullEvent() { + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(null); + + assertNull(result); + } + + @Test + void processRequestStartReturnsNullForOversizedEvent() { + int maxSize = Config.get().getAppSecBodyParsingSizeLimit(); + String largeBody = repeatChar('x', maxSize + 1); + ByteArrayInputStream event = createInputStream(largeBody); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNull(result); + } + + @Test + void processRequestStartReturnsNullForZeroSizeEvent() { + ByteArrayInputStream event = createInputStream(""); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNull(result); + } + + @Test + void processRequestStartReturnsNullForMalformedJSON() { + ByteArrayInputStream event = createInputStream("{invalid json"); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNull(result); + } + + @Test + void streamCanBeReadMultipleTimesAfterProcessing() throws Exception { + String jsonData = "{\"test\": \"data\", \"requestContext\": {\"httpMethod\": \"GET\"}}"; + ByteArrayInputStream event = createInputStream(jsonData); + + LambdaAppSecHandler.processRequestStart(event); + event.reset(); + byte[] bytes = new byte[event.available()]; + event.read(bytes); + String content = new String(bytes, StandardCharsets.UTF_8); + + assertEquals(jsonData, content); + } + + // ============================================================================ + // Trigger Type Detection Tests + // ============================================================================ + + @Test + void detectsApiGatewayV1RestTriggerType() { + Map event = + mapOf("requestContext", mapOf("httpMethod", "GET", "requestId", "abc123")); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V1_REST, triggerType); + } + + @Test + void detectsApiGatewayV2HttpTriggerType() { + Map event = + mapOf( + "requestContext", + mapOf( + "http", mapOf("method", "POST", "path", "/api"), "domainName", "api.example.com")); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_HTTP, triggerType); + } + + @Test + void detectsLambdaFunctionUrlTriggerType() { + Map event = + mapOf( + "requestContext", + mapOf( + "http", + mapOf("method", "GET", "path", "/"), + "domainName", + "xyz123.lambda-url.us-east-1.on.aws")); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL, triggerType); + } + + @Test + void detectsAlbTriggerTypeWithoutMultiValueHeaders() { + Map event = + mapOf( + "httpMethod", + "GET", + "path", + "/", + "requestContext", + mapOf("elb", mapOf("targetGroupArn", "arn:aws:..."))); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.ALB, triggerType); + } + + @Test + void detectsAlbTriggerTypeWithMultiValueHeaders() { + Map event = + mapOf( + "httpMethod", + "GET", + "path", + "/", + "multiValueHeaders", + mapOf("accept", Arrays.asList("text/html", "application/json")), + "requestContext", + mapOf("elb", mapOf("targetGroupArn", "arn:aws:..."))); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.ALB_MULTI_VALUE, triggerType); + } + + @Test + void detectsWebSocketTriggerTypeWithRouteKey() { + Map event = + mapOf("requestContext", mapOf("connectionId", "conn-123", "routeKey", "$connect")); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET, triggerType); + } + + @Test + void detectsWebSocketTriggerTypeWithEventType() { + Map event = + mapOf("requestContext", mapOf("connectionId", "conn-456", "eventType", "CONNECT")); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.API_GATEWAY_V2_WEBSOCKET, triggerType); + } + + @Test + void detectsUnknownTriggerTypeForUnrecognizedEvents() { + Map event = mapOf("someUnknownField", "value"); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.UNKNOWN, triggerType); + } + + @Test + void detectsUnknownTriggerTypeForEmptyRequestContext() { + Map event = mapOf("requestContext", mapOf()); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.UNKNOWN, triggerType); + } + + @Test + void detectsLambdaUrlWhenHttpPresentButNoDomainName() { + Map event = + mapOf("requestContext", mapOf("http", mapOf("method", "GET", "path", "/ambiguous"))); + + LambdaAppSecHandler.LambdaTriggerType triggerType = + LambdaAppSecHandler.detectTriggerType(event); + + assertEquals(LambdaAppSecHandler.LambdaTriggerType.LAMBDA_URL, triggerType); + } + + // ============================================================================ + // Data Extraction Tests with Mocked Callbacks + // ============================================================================ + + @Test + void extractsApiGatewayV1RestDataCorrectly() { + String eventJson = + "{\n" + + " \"path\": \"/api/users/123\",\n" + + " \"httpMethod\": \"POST\",\n" + + " \"headers\": {\n" + + " \"Content-Type\": \"application/json\",\n" + + " \"Authorization\": \"Bearer token123\"\n" + + " },\n" + + " \"pathParameters\": {\n" + + " \"userId\": \"123\"\n" + + " },\n" + + " \"body\": \"{\\\"name\\\": \\\"John\\\"}\",\n" + + " \"requestContext\": {\n" + + " \"httpMethod\": \"POST\",\n" + + " \"requestId\": \"req-123\",\n" + + " \"identity\": {\n" + + " \"sourceIp\": \"192.168.1.100\"\n" + + " }\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + String[] capturedMethod = {null}; + String[] capturedPath = {null}; + Map capturedHeaders = new HashMap<>(); + String[] capturedSourceIp = {null}; + Integer[] capturedSourcePort = {null}; + Map[] capturedPathParams = new Map[] {null}; + Object[] capturedBody = {null}; + + setupMockCallbacks( + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }) + .onHeader((name, value) -> capturedHeaders.put(name, value)) + .onSocketAddress( + (ip, port) -> { + capturedSourceIp[0] = ip; + capturedSourcePort[0] = port; + }) + .onPathParams(params -> capturedPathParams[0] = params) + .onBody(body -> capturedBody[0] = body)); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertInstanceOf(TagContext.class, result); + + assertEquals("POST", capturedMethod[0]); + assertEquals("/api/users/123", capturedPath[0]); + assertEquals("application/json", capturedHeaders.get("Content-Type")); + assertEquals("Bearer token123", capturedHeaders.get("Authorization")); + assertEquals("192.168.1.100", capturedSourceIp[0]); + assertEquals(Integer.valueOf(0), capturedSourcePort[0]); + assertEquals("123", ((Map) capturedPathParams[0]).get("userId")); + assertInstanceOf(Map.class, capturedBody[0]); + assertEquals("John", ((Map) capturedBody[0]).get("name")); + } + + @Test + void extractsApiGatewayV2HttpDataCorrectly() { + String eventJson = + "{\n" + + " \"version\": \"2.0\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\",\n" + + " \"x-custom-header\": \"custom-value\"\n" + + " },\n" + + " \"cookies\": [\"session=abc123\", \"user=john\"],\n" + + " \"pathParameters\": {\n" + + " \"id\": \"456\"\n" + + " },\n" + + " \"body\": \"test body\",\n" + + " \"requestContext\": {\n" + + " \"http\": {\n" + + " \"method\": \"PUT\",\n" + + " \"path\": \"/api/items/456\",\n" + + " \"sourceIp\": \"10.0.0.50\",\n" + + " \"sourcePort\": 54321\n" + + " },\n" + + " \"domainName\": \"api.example.com\"\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + String[] capturedMethod = {null}; + String[] capturedPath = {null}; + Map capturedHeaders = new HashMap<>(); + String[] capturedSourceIp = {null}; + Integer[] capturedSourcePort = {null}; + Map[] capturedPathParams = new Map[] {null}; + + setupMockCallbacks( + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }) + .onHeader((name, value) -> capturedHeaders.put(name, value)) + .onSocketAddress( + (ip, port) -> { + capturedSourceIp[0] = ip; + capturedSourcePort[0] = port; + }) + .onPathParams(params -> capturedPathParams[0] = params)); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals("PUT", capturedMethod[0]); + assertEquals("/api/items/456", capturedPath[0]); + assertEquals("application/json", capturedHeaders.get("content-type")); + assertEquals("custom-value", capturedHeaders.get("x-custom-header")); + assertEquals("session=abc123; user=john", capturedHeaders.get("cookie")); + assertEquals("10.0.0.50", capturedSourceIp[0]); + assertEquals(Integer.valueOf(54321), capturedSourcePort[0]); + assertEquals("456", ((Map) capturedPathParams[0]).get("id")); + } + + @Test + void extractsLambdaFunctionUrlDataCorrectly() { + String eventJson = + "{\n" + + " \"version\": \"2.0\",\n" + + " \"headers\": {\n" + + " \"host\": \"xyz.lambda-url.us-east-1.on.aws\"\n" + + " },\n" + + " \"requestContext\": {\n" + + " \"http\": {\n" + + " \"method\": \"GET\",\n" + + " \"path\": \"/function/path\",\n" + + " \"sourceIp\": \"1.2.3.4\"\n" + + " },\n" + + " \"domainName\": \"xyz.lambda-url.us-east-1.on.aws\"\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + String[] capturedMethod = {null}; + String[] capturedPath = {null}; + + setupMockCallbacks( + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + })); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals("GET", capturedMethod[0]); + assertEquals("/function/path", capturedPath[0]); + } + + @Test + void extractsAlbDataCorrectly() { + String eventJson = + "{\n" + + " \"path\": \"/alb/test\",\n" + + " \"httpMethod\": \"DELETE\",\n" + + " \"headers\": {\n" + + " \"x-forwarded-for\": \"203.0.113.42\",\n" + + " \"user-agent\": \"curl/7.64.1\"\n" + + " },\n" + + " \"requestContext\": {\n" + + " \"elb\": {\n" + + " \"targetGroupArn\": \"arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/my-target-group/50dc6c495c0c9188\"\n" + + " }\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + String[] capturedMethod = {null}; + String[] capturedPath = {null}; + String[] capturedSourceIp = {null}; + + setupMockCallbacks( + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }) + .onSocketAddress((ip, port) -> capturedSourceIp[0] = ip)); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals("DELETE", capturedMethod[0]); + assertEquals("/alb/test", capturedPath[0]); + assertEquals("203.0.113.42", capturedSourceIp[0]); + } + + @Test + void extractsAlbMultiValueHeadersCorrectly() { + String eventJson = + "{\n" + + " \"path\": \"/test\",\n" + + " \"httpMethod\": \"GET\",\n" + + " \"multiValueHeaders\": {\n" + + " \"accept\": [\"text/html\", \"application/json\"],\n" + + " \"x-custom\": [\"value1\", \"value2\"]\n" + + " },\n" + + " \"requestContext\": {\n" + + " \"elb\": {\n" + + " \"targetGroupArn\": \"arn:aws:...\"\n" + + " }\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + Map capturedHeaders = new HashMap<>(); + + setupMockCallbacks(new Callbacks().onHeader((name, value) -> capturedHeaders.put(name, value))); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals("text/html, application/json", capturedHeaders.get("accept")); + assertEquals("value1, value2", capturedHeaders.get("x-custom")); + } + + @Test + void handlesMultiValueHeadersWithEmptyList() { + String eventJson = + "{\n" + + " \"path\": \"/test\",\n" + + " \"httpMethod\": \"GET\",\n" + + " \"multiValueHeaders\": {\n" + + " \"accept\": [],\n" + + " \"x-custom\": [\"value1\"]\n" + + " },\n" + + " \"requestContext\": {\n" + + " \"elb\": {\n" + + " \"targetGroupArn\": \"arn:aws:...\"\n" + + " }\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + Map capturedHeaders = new HashMap<>(); + + setupMockCallbacks(new Callbacks().onHeader((name, value) -> capturedHeaders.put(name, value))); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals("", capturedHeaders.get("accept")); // Empty list should result in empty string + assertEquals("value1", capturedHeaders.get("x-custom")); + } + + @Test + void extractsWebSocketDataCorrectly() { + String eventJson = + "{\n" + + " \"requestContext\": {\n" + + " \"routeKey\": \"$connect\",\n" + + " \"connectionId\": \"conn-abc123\",\n" + + " \"identity\": {\n" + + " \"sourceIp\": \"192.168.0.100\"\n" + + " }\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + String[] capturedMethod = {null}; + String[] capturedPath = {null}; + String[] capturedSourceIp = {null}; + + setupMockCallbacks( + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }) + .onSocketAddress((ip, port) -> capturedSourceIp[0] = ip)); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals("WEBSOCKET", capturedMethod[0]); + assertEquals("$connect", capturedPath[0]); + assertEquals("192.168.0.100", capturedSourceIp[0]); + } + + @Test + void handlesBase64EncodedBodyCorrectly() { + String originalBody = "This is test data"; + String base64Body = Base64.getEncoder().encodeToString(originalBody.getBytes()); + String eventJson = + "{\n" + + " \"body\": \"" + + base64Body + + "\",\n" + + " \"isBase64Encoded\": true,\n" + + " \"requestContext\": {\n" + + " \"httpMethod\": \"POST\"\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + Object[] capturedBody = {null}; + + setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = body)); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals(originalBody, capturedBody[0]); + } + + @Test + void handlesNullBodyCorrectly() { + ByteArrayInputStream event = + createInputStream("{\"body\": null, \"requestContext\": {\"httpMethod\": \"GET\"}}"); + + String[] capturedBody = {"NOT_CALLED"}; + + setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = String.valueOf(body))); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals("NOT_CALLED", capturedBody[0]); // Callback should not be invoked for null body + } + + @Test + void handlesEmptyBodyCorrectly() { + ByteArrayInputStream event = + createInputStream("{\"body\": \"\", \"requestContext\": {\"httpMethod\": \"POST\"}}"); + + Object[] capturedBody = {null}; + + setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = body)); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals("", capturedBody[0]); // Empty body is passed as empty string to WAF + } + + @Test + void handlesPathWithQueryStringCorrectly() { + String eventJson = + "{\n" + + " \"path\": \"/api/users?id=123&filter=active\",\n" + + " \"requestContext\": {\n" + + " \"httpMethod\": \"GET\"\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + String[] capturedPath = {null}; + String[] capturedQuery = {null}; + + setupMockCallbacks( + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedPath[0] = uri.path(); + capturedQuery[0] = uri.query(); + })); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals("/api/users", capturedPath[0]); + assertEquals("id=123&filter=active", capturedQuery[0]); + } + + @Test + void extractsSchemeAndPortFromXForwardedHeaders() { + String eventJson = + "{\n" + + " \"path\": \"/api/test\",\n" + + " \"headers\": {\n" + + " \"x-forwarded-proto\": \"http\",\n" + + " \"x-forwarded-port\": \"8080\"\n" + + " },\n" + + " \"requestContext\": {\n" + + " \"httpMethod\": \"GET\",\n" + + " \"requestId\": \"req-123\"\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + String[] capturedScheme = {null}; + Integer[] capturedPort = {null}; + + setupMockCallbacks( + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedScheme[0] = uri.scheme(); + capturedPort[0] = uri.port(); + })); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals("http", capturedScheme[0]); + assertEquals(Integer.valueOf(8080), capturedPort[0]); + } + + @Test + void fallsBackToHttps443WhenXForwardedHeadersAreAbsent() { + String eventJson = + "{\n" + + " \"path\": \"/api/test\",\n" + + " \"headers\": {},\n" + + " \"requestContext\": {\n" + + " \"httpMethod\": \"GET\",\n" + + " \"requestId\": \"req-123\"\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + String[] capturedScheme = {null}; + Integer[] capturedPort = {null}; + + setupMockCallbacks( + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedScheme[0] = uri.scheme(); + capturedPort[0] = uri.port(); + })); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals("https", capturedScheme[0]); + assertEquals(Integer.valueOf(443), capturedPort[0]); + } + + @Test + void handlesInvalidXForwardedPortGracefully() { + String eventJson = + "{\n" + + " \"path\": \"/api/test\",\n" + + " \"headers\": {\n" + + " \"x-forwarded-proto\": \"https\",\n" + + " \"x-forwarded-port\": \"not-a-number\"\n" + + " },\n" + + " \"requestContext\": {\n" + + " \"httpMethod\": \"GET\",\n" + + " \"requestId\": \"req-123\"\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + String[] capturedScheme = {null}; + Integer[] capturedPort = {null}; + + setupMockCallbacks( + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedScheme[0] = uri.scheme(); + capturedPort[0] = uri.port(); + })); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals("https", capturedScheme[0]); + assertEquals(Integer.valueOf(443), capturedPort[0]); + } + + @Test + void handlesInvalidBase64BodyGracefully() { + String eventJson = + "{\n" + + " \"body\": \"not-valid-base64\",\n" + + " \"isBase64Encoded\": true,\n" + + " \"requestContext\": {\n" + + " \"httpMethod\": \"POST\"\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + String[] capturedBody = {"NOT_CALLED"}; + + setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = String.valueOf(body))); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals("NOT_CALLED", capturedBody[0]); // Should not call body callback when decode fails + } + + @Test + void handlesBase64DecodedEmptyStringBody() { + String base64Empty = Base64.getEncoder().encodeToString("".getBytes()); + String eventJson = + "{\n" + + " \"body\": \"" + + base64Empty + + "\",\n" + + " \"isBase64Encoded\": true,\n" + + " \"requestContext\": {\n" + + " \"httpMethod\": \"POST\"\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + Object[] capturedBody = {"NOT_CALLED"}; + + setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = body)); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals("", capturedBody[0]); // Should pass empty string after decoding + } + + @Test + void handlesBodyWithSpecialCharacters() { + String eventJson = + "{\n" + + " \"body\": \"{\\\"text\\\": \\\"Hello \\u4e16\\u754c \\uD83C\\uDF0D\\\"}\",\n" + + " \"requestContext\": {\n" + + " \"httpMethod\": \"POST\"\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + Object[] capturedBody = {null}; + + setupMockCallbacks(new Callbacks().onBody(body -> capturedBody[0] = body)); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertInstanceOf(Map.class, capturedBody[0]); + assertEquals("Hello δΈ–η•Œ 🌍", ((Map) capturedBody[0]).get("text")); + } + + // ============================================================================ + // Generic Data Extraction Tests + // ============================================================================ + + @Test + void extractsDataFromUnknownTriggerTypeUsingGenericExtraction() { + String eventJson = + "{\n" + + " \"path\": \"/generic/path\",\n" + + " \"httpMethod\": \"PATCH\",\n" + + " \"headers\": {\n" + + " \"x-custom-header\": \"generic-value\"\n" + + " },\n" + + " \"unknownField\": \"should be ignored\",\n" + + " \"requestContext\": {\n" + + " \"identity\": {\n" + + " \"sourceIp\": \"203.0.113.1\"\n" + + " }\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + String[] capturedMethod = {null}; + String[] capturedPath = {null}; + Map capturedHeaders = new HashMap<>(); + String[] capturedSourceIp = {null}; + + setupMockCallbacks( + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }) + .onHeader((name, value) -> capturedHeaders.put(name, value)) + .onSocketAddress((ip, port) -> capturedSourceIp[0] = ip)); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals("PATCH", capturedMethod[0]); + assertEquals("/generic/path", capturedPath[0]); + assertEquals("generic-value", capturedHeaders.get("x-custom-header")); + assertEquals("203.0.113.1", capturedSourceIp[0]); + } + + @Test + void extractsDataFromUnknownTriggerWithHttpInRequestContext() { + String eventJson = + "{\n" + + " \"requestContext\": {\n" + + " \"http\": {\n" + + " \"method\": \"OPTIONS\",\n" + + " \"path\": \"/options/path\",\n" + + " \"sourceIp\": \"198.51.100.50\"\n" + + " }\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + String[] capturedMethod = {null}; + String[] capturedPath = {null}; + String[] capturedSourceIp = {null}; + + setupMockCallbacks( + new Callbacks() + .onMethodUri( + (method, uri) -> { + capturedMethod[0] = method; + capturedPath[0] = uri.path(); + }) + .onSocketAddress((ip, port) -> capturedSourceIp[0] = ip)); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals("OPTIONS", capturedMethod[0]); + assertEquals("/options/path", capturedPath[0]); + assertEquals("198.51.100.50", capturedSourceIp[0]); + } + + @Test + void handlesCookiesMergingWithExistingCookieHeader() { + String eventJson = + "{\n" + + " \"headers\": {\n" + + " \"cookie\": \"existing=value\"\n" + + " },\n" + + " \"cookies\": [\"new=cookie1\", \"another=cookie2\"],\n" + + " \"requestContext\": {\n" + + " \"http\": {\n" + + " \"method\": \"GET\",\n" + + " \"path\": \"/\"\n" + + " }\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + Map capturedHeaders = new HashMap<>(); + + setupMockCallbacks(new Callbacks().onHeader((name, value) -> capturedHeaders.put(name, value))); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertEquals("existing=value; new=cookie1; another=cookie2", capturedHeaders.get("cookie")); + } + + @Test + void handlesEmptyCookiesArrayCorrectly() { + String eventJson = + "{\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\"\n" + + " },\n" + + " \"cookies\": [],\n" + + " \"requestContext\": {\n" + + " \"http\": {\n" + + " \"method\": \"GET\",\n" + + " \"path\": \"/\"\n" + + " }\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + Map capturedHeaders = new HashMap<>(); + + setupMockCallbacks(new Callbacks().onHeader((name, value) -> capturedHeaders.put(name, value))); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); + assertFalse(capturedHeaders.containsKey("cookie")); // Empty array should not add cookie header + } + + // ============================================================================ + // processRequestEnd Tests + // ============================================================================ + + @Test + void processRequestEndDoesNothingWhenSpanIsNull() { + // No exception should be thrown + LambdaAppSecHandler.processRequestEnd(null); + } + + @Test + void processRequestEndDoesNothingWhenAppSecIsDisabled() { + ActiveSubsystems.APPSEC_ACTIVE = false; + AgentSpan span = mock(AgentSpan.class); + + LambdaAppSecHandler.processRequestEnd(span); + + verifyNoInteractions(span); + } + + @Test + void processRequestEndDoesNothingWhenSpanHasNoRequestContext() { + AgentSpan span = mock(AgentSpan.class); + when(span.getRequestContext()).thenReturn(null); + + // No exception should be thrown + LambdaAppSecHandler.processRequestEnd(span); + } + + @Test + void processRequestEndInvokesRequestEndedCallbackWithRequestContext() { + Object mockAppSecContext = new Object(); + RequestContext mockRequestContext = mock(RequestContext.class); + when(mockRequestContext.getData(RequestContextSlot.APPSEC)).thenReturn(mockAppSecContext); + AgentSpan span = mock(AgentSpan.class); + when(span.getRequestContext()).thenReturn(mockRequestContext); + + boolean[] callbackInvoked = {false}; + RequestContext[] capturedContext = {null}; + AgentSpan[] capturedSpan = {null}; + + BiFunction> mockRequestEndedCallback = + mock(BiFunction.class); + doAnswer( + inv -> { + capturedContext[0] = inv.getArgument(0); + capturedSpan[0] = inv.getArgument(1); + callbackInvoked[0] = true; + return new Flow.ResultFlow<>(null); + }) + .when(mockRequestEndedCallback) + .apply(any(), any()); + + CallbackProvider mockCallbackProvider = mock(CallbackProvider.class); + when(mockCallbackProvider.getCallback(EVENTS.requestEnded())) + .thenReturn(mockRequestEndedCallback); + + AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); + when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)) + .thenReturn(mockCallbackProvider); + + AgentTracer.forceRegister(mockTracer); + + LambdaAppSecHandler.processRequestEnd(span); + + assertEquals(true, callbackInvoked[0]); + assertEquals(mockRequestContext, capturedContext[0]); + assertEquals(span, capturedSpan[0]); + } + + @Test + void processRequestEndHandlesNullRequestEndedCallbackGracefully() { + RequestContext mockRequestContext = mock(RequestContext.class); + AgentSpan span = mock(AgentSpan.class); + when(span.getRequestContext()).thenReturn(mockRequestContext); + + CallbackProvider mockCallbackProvider = mock(CallbackProvider.class); + when(mockCallbackProvider.getCallback(EVENTS.requestEnded())).thenReturn(null); + + AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); + when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)) + .thenReturn(mockCallbackProvider); + + AgentTracer.forceRegister(mockTracer); + + // No exception should be thrown - should log warning but not throw + LambdaAppSecHandler.processRequestEnd(span); + } + + // ============================================================================ + // mergeContexts Tests + // ============================================================================ + + @Test + void mergeContextsReturnsNullWhenBothContextsAreNull() { + AgentSpanContext result = LambdaAppSecHandler.mergeContexts(null, null); + + assertNull(result); + } + + @Test + void mergeContextsReturnsExtensionContextWhenAppSecContextIsNull() { + TagContext extensionContext = mock(TagContext.class); + + AgentSpanContext result = LambdaAppSecHandler.mergeContexts(extensionContext, null); + + assertEquals(extensionContext, result); + } + + @Test + void mergeContextsReturnsAppSecContextWhenExtensionContextIsNull() { + TagContext appSecContext = mock(TagContext.class); + + AgentSpanContext result = LambdaAppSecHandler.mergeContexts(null, appSecContext); + + assertEquals(appSecContext, result); + } + + @Test + void mergeContextsMergesAppSecDataIntoTagContext() { + Object appSecData = new Object(); + + // Create real TagContext instances since methods are final + TagContext appSecContext = new TagContext(); + appSecContext.withRequestContextDataAppSec(appSecData); + + TagContext extensionContext = new TagContext(); + + AgentSpanContext result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext); + + assertEquals(extensionContext, result); + assertEquals(appSecData, ((TagContext) result).getRequestContextDataAppSec()); + } + + @Test + void mergeContextsReturnsExtensionContextWhenAppSecContextIsNotTagContext() { + TagContext extensionContext = mock(TagContext.class); + AgentSpanContext appSecContext = mock(AgentSpanContext.class); + + AgentSpanContext result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext); + + assertEquals(extensionContext, result); + } + + @Test + void mergeContextsReturnsExtensionContextWhenItIsNotTagContext() { + AgentSpanContext extensionContext = mock(AgentSpanContext.class); + TagContext appSecContext = mock(TagContext.class); + + AgentSpanContext result = LambdaAppSecHandler.mergeContexts(extensionContext, appSecContext); + + assertEquals(extensionContext, result); + } + + // ============================================================================ + // Error Handling and Null Callback Tests + // ============================================================================ + + @Test + void processRequestStartHandlesNullRequestStartedCallbackGracefully() { + String eventJson = "{\"requestContext\": {\"httpMethod\": \"GET\"}}"; + ByteArrayInputStream event = createInputStream(eventJson); + + CallbackProvider mockCallbackProvider = mock(CallbackProvider.class); + when(mockCallbackProvider.getCallback(EVENTS.requestStarted())).thenReturn(null); + + AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); + when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)) + .thenReturn(mockCallbackProvider); + + AgentTracer.forceRegister(mockTracer); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNull(result); // Should return null when requestStarted callback is missing + } + + @Test + void processRequestStartHandlesNullMethodUriCallbackGracefully() { + String eventJson = + "{\n" + + " \"path\": \"/test\",\n" + + " \"requestContext\": {\n" + + " \"httpMethod\": \"GET\"\n" + + " }\n" + + "}"; + ByteArrayInputStream event = createInputStream(eventJson); + + Object mockAppSecContext = new Object(); + + Supplier> mockRequestStartedCallback = mock(Supplier.class); + when(mockRequestStartedCallback.get()).thenReturn(new Flow.ResultFlow<>(mockAppSecContext)); + + Function> mockHeaderDoneCallback = mock(Function.class); + when(mockHeaderDoneCallback.apply(any())).thenReturn(new Flow.ResultFlow<>(null)); + + CallbackProvider mockCallbackProvider = mock(CallbackProvider.class); + when(mockCallbackProvider.getCallback(EVENTS.requestStarted())) + .thenReturn(mockRequestStartedCallback); + when(mockCallbackProvider.getCallback(EVENTS.requestMethodUriRaw())).thenReturn(null); + when(mockCallbackProvider.getCallback(EVENTS.requestHeader())).thenReturn(null); + when(mockCallbackProvider.getCallback(EVENTS.requestClientSocketAddress())).thenReturn(null); + when(mockCallbackProvider.getCallback(EVENTS.requestHeaderDone())) + .thenReturn(mockHeaderDoneCallback); + when(mockCallbackProvider.getCallback(EVENTS.requestPathParams())).thenReturn(null); + when(mockCallbackProvider.getCallback(EVENTS.requestBodyProcessed())).thenReturn(null); + + AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); + when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)) + .thenReturn(mockCallbackProvider); + + AgentTracer.forceRegister(mockTracer); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNotNull(result); // Should continue processing even if methodUri callback is null + assertInstanceOf(TagContext.class, result); + } + + @Test + void processRequestStartHandlesExceptionDuringJsonParsing() { + String invalidJson = "{this is not valid JSON at all"; + ByteArrayInputStream event = createInputStream(invalidJson); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(event); + + assertNull(result); // Should return null on parse error + } + + @Test + void processRequestStartHandlesExceptionDuringStreamReading() throws IOException { + ByteArrayInputStream mockStream = mock(ByteArrayInputStream.class); + when(mockStream.available()).thenThrow(new IOException("Stream error")); + + AgentSpanContext result = LambdaAppSecHandler.processRequestStart(mockStream); + + assertNull(result); // Should return null on IO error + } + + // ============================================================================ + // Helper classes and methods + // ============================================================================ + + private static class Callbacks { + BiConsumer onMethodUri; + BiConsumer onHeader; + BiConsumer onSocketAddress; + Consumer> onPathParams; + Consumer onBody; + + Callbacks onMethodUri(BiConsumer cb) { + this.onMethodUri = cb; + return this; + } + + Callbacks onHeader(BiConsumer cb) { + this.onHeader = cb; + return this; + } + + Callbacks onSocketAddress(BiConsumer cb) { + this.onSocketAddress = cb; + return this; + } + + Callbacks onPathParams(Consumer> cb) { + this.onPathParams = cb; + return this; + } + + Callbacks onBody(Consumer cb) { + this.onBody = cb; + return this; + } + } + + private static Map mapOf(Object... keysAndValues) { + Map map = new LinkedHashMap<>(); + for (int i = 0; i < keysAndValues.length; i += 2) { + map.put((String) keysAndValues[i], keysAndValues[i + 1]); + } + return map; + } + + private ByteArrayInputStream createInputStream(String json) { + return new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)); + } + + private void setupMockCallbacks(Callbacks callbacks) { + Object mockAppSecContext = new Object(); + + Supplier> mockRequestStartedCallback = mock(Supplier.class); + when(mockRequestStartedCallback.get()).thenReturn(new Flow.ResultFlow<>(mockAppSecContext)); + + TriFunction> mockMethodUriCallback = null; + if (callbacks.onMethodUri != null) { + mockMethodUriCallback = mock(TriFunction.class); + BiConsumer methodUriCb = callbacks.onMethodUri; + doAnswer( + inv -> { + String method = inv.getArgument(1); + URIDataAdapter uri = inv.getArgument(2); + methodUriCb.accept(method, uri); + return new Flow.ResultFlow<>(null); + }) + .when(mockMethodUriCallback) + .apply(any(), any(), any()); + } + + TriConsumer mockHeaderCallback = null; + if (callbacks.onHeader != null) { + mockHeaderCallback = mock(TriConsumer.class); + BiConsumer headerCb = callbacks.onHeader; + doAnswer( + inv -> { + String name = inv.getArgument(1); + String value = inv.getArgument(2); + headerCb.accept(name, value); + return null; + }) + .when(mockHeaderCallback) + .accept(any(), any(), any()); + } + + TriFunction> mockSocketAddressCallback = null; + if (callbacks.onSocketAddress != null) { + mockSocketAddressCallback = mock(TriFunction.class); + BiConsumer socketCb = callbacks.onSocketAddress; + doAnswer( + inv -> { + String ip = inv.getArgument(1); + Integer port = inv.getArgument(2); + socketCb.accept(ip, port); + return new Flow.ResultFlow<>(null); + }) + .when(mockSocketAddressCallback) + .apply(any(), any(), any()); + } + + Function> mockHeaderDoneCallback = mock(Function.class); + when(mockHeaderDoneCallback.apply(any())).thenReturn(new Flow.ResultFlow<>(null)); + + BiFunction, Flow> mockPathParamsCallback = null; + if (callbacks.onPathParams != null) { + mockPathParamsCallback = mock(BiFunction.class); + Consumer> pathParamsCb = callbacks.onPathParams; + doAnswer( + inv -> { + Map params = inv.getArgument(1); + pathParamsCb.accept(params); + return new Flow.ResultFlow<>(null); + }) + .when(mockPathParamsCallback) + .apply(any(), any()); + } + + BiFunction> mockBodyCallback = null; + if (callbacks.onBody != null) { + mockBodyCallback = mock(BiFunction.class); + Consumer bodyCb = callbacks.onBody; + doAnswer( + inv -> { + Object body = inv.getArgument(1); + bodyCb.accept(body); + return new Flow.ResultFlow<>(null); + }) + .when(mockBodyCallback) + .apply(any(), any()); + } + + CallbackProvider mockCallbackProvider = mock(CallbackProvider.class); + when(mockCallbackProvider.getCallback(EVENTS.requestStarted())) + .thenReturn(mockRequestStartedCallback); + when(mockCallbackProvider.getCallback(EVENTS.requestMethodUriRaw())) + .thenReturn(mockMethodUriCallback); + when(mockCallbackProvider.getCallback(EVENTS.requestHeader())).thenReturn(mockHeaderCallback); + when(mockCallbackProvider.getCallback(EVENTS.requestClientSocketAddress())) + .thenReturn(mockSocketAddressCallback); + when(mockCallbackProvider.getCallback(EVENTS.requestHeaderDone())) + .thenReturn(mockHeaderDoneCallback); + when(mockCallbackProvider.getCallback(EVENTS.requestPathParams())) + .thenReturn(mockPathParamsCallback); + when(mockCallbackProvider.getCallback(EVENTS.requestBodyProcessed())) + .thenReturn(mockBodyCallback); + + AgentTracer.TracerAPI mockTracer = mock(AgentTracer.TracerAPI.class); + when(mockTracer.getCallbackProvider(RequestContextSlot.APPSEC)) + .thenReturn(mockCallbackProvider); + + AgentTracer.forceRegister(mockTracer); + } + + private static String repeatChar(char ch, int count) { + char[] chars = new char[count]; + Arrays.fill(chars, ch); + return new String(chars); + } +} diff --git a/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaHandlerTest.java b/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaHandlerTest.java new file mode 100644 index 00000000000..89959205585 --- /dev/null +++ b/dd-trace-core/src/test/java/datadog/trace/lambda/LambdaHandlerTest.java @@ -0,0 +1,333 @@ +package datadog.trace.lambda; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent; +import com.amazonaws.services.lambda.runtime.events.S3Event; +import com.amazonaws.services.lambda.runtime.events.SNSEvent; +import com.amazonaws.services.lambda.runtime.events.SQSEvent; +import com.amazonaws.services.lambda.runtime.events.models.s3.S3EventNotification; +import datadog.trace.agent.test.server.http.JavaTestHttpServer; +import datadog.trace.api.DDSpanId; +import datadog.trace.api.DDTags; +import datadog.trace.api.DDTraceId; +import datadog.trace.bootstrap.instrumentation.api.AgentSpanContext; +import datadog.trace.core.CoreTracer; +import datadog.trace.core.DDCoreJavaSpecification; +import datadog.trace.core.DDSpan; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.tabletest.junit.TableTest; + +@SuppressWarnings({"unchecked", "rawtypes"}) +public class LambdaHandlerTest extends DDCoreJavaSpecification { + + static class TestObject { + public String field1; + public boolean field2; + + TestObject() { + this.field1 = "toto"; + this.field2 = true; + } + + @Override + public String toString() { + return field1 + " / " + field2 + "}"; + } + } + + @Test + void testStartInvocationSuccess() { + CoreTracer ct = tracerBuilder().build(); + + JavaTestHttpServer server = + JavaTestHttpServer.httpServer( + s -> + s.handlers( + h -> + h.post( + "/lambda/start-invocation", + api -> + api.getResponse() + .status(200) + .addHeader("x-datadog-trace-id", "1234") + .addHeader("x-datadog-sampling-priority", "2") + .send()))); + LambdaHandler.setExtensionBaseUrl(server.getAddress().toString()); + + AgentSpanContext objTest = + LambdaHandler.notifyStartInvocation(new TestObject(), "lambda-request-123"); + + assertEquals("1234", objTest.getTraceId().toString()); + assertEquals(2, objTest.getSamplingPriority()); + assertEquals( + "lambda-request-123", server.getLastRequest().getHeader("lambda-runtime-aws-request-id")); + + server.close(); + ct.close(); + } + + @Test + void testStartInvocationWith128BitTraceId() { + CoreTracer ct = tracerBuilder().build(); + + JavaTestHttpServer server = + JavaTestHttpServer.httpServer( + s -> + s.handlers( + h -> + h.post( + "/lambda/start-invocation", + api -> + api.getResponse() + .status(200) + .addHeader("x-datadog-trace-id", "5744042798732701615") + .addHeader("x-datadog-sampling-priority", "2") + .addHeader("x-datadog-tags", "_dd.p.tid=1914fe7789eb32be") + .send()))); + LambdaHandler.setExtensionBaseUrl(server.getAddress().toString()); + + AgentSpanContext objTest = + LambdaHandler.notifyStartInvocation(new TestObject(), "lambda-request-123"); + + assertEquals("1914fe7789eb32be4fb6f07e011a6faf", objTest.getTraceId().toHexString()); + assertEquals(2, objTest.getSamplingPriority()); + assertEquals( + "lambda-request-123", server.getLastRequest().getHeader("lambda-runtime-aws-request-id")); + + server.close(); + ct.close(); + } + + @Test + void testStartInvocationFailure() { + CoreTracer ct = tracerBuilder().build(); + + JavaTestHttpServer server = + JavaTestHttpServer.httpServer( + s -> + s.handlers( + h -> + h.post( + "/lambda/start-invocation", + api -> api.getResponse().status(500).send()))); + LambdaHandler.setExtensionBaseUrl(server.getAddress().toString()); + + AgentSpanContext objTest = + LambdaHandler.notifyStartInvocation(new TestObject(), "my-lambda-request"); + + assertNull(objTest); + assertEquals( + "my-lambda-request", server.getLastRequest().getHeader("lambda-runtime-aws-request-id")); + + server.close(); + ct.close(); + } + + @TableTest( + value = { + "scenario | expected | eHeaderValue | tIdHeaderValue | sIdHeaderValue | sPIdHeaderValue | lambdaResult | boolValue | lambdaReqIdHeaderValue", + "error with non-string result | true | 'true' | '1234' | '5678' | 2 | | true | 'request123' ", + "success with string result | true | | '1234' | '5678' | 2 | '12345 ' | false | 'request456' " + }) + void testEndInvocationSuccess( + boolean expected, + String eHeaderValue, + String tIdHeaderValue, + String sIdHeaderValue, + String sPIdHeaderValue, + Object lambdaResult, + boolean boolValue, + String lambdaReqIdHeaderValue) { + JavaTestHttpServer server = + JavaTestHttpServer.httpServer( + s -> + s.handlers( + h -> + h.post( + "/lambda/end-invocation", + api -> api.getResponse().status(200).send()))); + LambdaHandler.setExtensionBaseUrl(server.getAddress().toString()); + + DDSpan span = mock(DDSpan.class); + when(span.getTraceId()).thenReturn(DDTraceId.from("1234")); + when(span.getSpanId()).thenReturn(DDSpanId.from("5678")); + when(span.getSamplingPriority()).thenReturn(2); + + boolean result = + LambdaHandler.notifyEndInvocation(span, lambdaResult, boolValue, lambdaReqIdHeaderValue); + + assertEquals(eHeaderValue, server.getLastRequest().getHeader("x-datadog-invocation-error")); + assertEquals(tIdHeaderValue, server.getLastRequest().getHeader("x-datadog-trace-id")); + assertEquals(sIdHeaderValue, server.getLastRequest().getHeader("x-datadog-span-id")); + assertEquals(sPIdHeaderValue, server.getLastRequest().getHeader("x-datadog-sampling-priority")); + assertEquals( + lambdaReqIdHeaderValue, server.getLastRequest().getHeader("lambda-runtime-aws-request-id")); + assertEquals(expected, result); + + server.close(); + } + + @TableTest( + value = { + "scenario | expected | headerValue | lambdaResult | boolValue | lambdaReqIdHeaderValue", + "error with non-string result | false | 'true' | | true | 'request123' ", + "success with string result | false | | '12345' | false | 'request456' " + }) + void testEndInvocationFailure( + boolean expected, + String headerValue, + Object lambdaResult, + boolean boolValue, + String lambdaReqIdHeaderValue) { + JavaTestHttpServer server = + JavaTestHttpServer.httpServer( + s -> + s.handlers( + h -> + h.post( + "/lambda/end-invocation", + api -> api.getResponse().status(500).send()))); + LambdaHandler.setExtensionBaseUrl(server.getAddress().toString()); + + DDSpan span = mock(DDSpan.class); + when(span.getTraceId()).thenReturn(DDTraceId.from("1234")); + when(span.getSpanId()).thenReturn(DDSpanId.from("5678")); + when(span.getSamplingPriority()).thenReturn(2); + + boolean result = + LambdaHandler.notifyEndInvocation(span, lambdaResult, boolValue, lambdaReqIdHeaderValue); + + assertEquals(expected, result); + assertEquals(headerValue, server.getLastRequest().getHeader("x-datadog-invocation-error")); + assertEquals( + lambdaReqIdHeaderValue, server.getLastRequest().getHeader("lambda-runtime-aws-request-id")); + + server.close(); + } + + @Test + void testEndInvocationSuccessWithErrorMetadata() { + JavaTestHttpServer server = + JavaTestHttpServer.httpServer( + s -> + s.handlers( + h -> + h.post( + "/lambda/end-invocation", + api -> api.getResponse().status(200).send()))); + LambdaHandler.setExtensionBaseUrl(server.getAddress().toString()); + + DDSpan span = mock(DDSpan.class); + when(span.getTraceId()).thenReturn(DDTraceId.from("1234")); + when(span.getSpanId()).thenReturn(DDSpanId.from("5678")); + when(span.getSamplingPriority()).thenReturn(2); + when(span.getTag(DDTags.ERROR_MSG)).thenReturn("custom error message"); + when(span.getTag(DDTags.ERROR_TYPE)).thenReturn("java.lang.Throwable"); + when(span.getTag(DDTags.ERROR_STACK)).thenReturn("errorStack\n \ttest"); + + LambdaHandler.notifyEndInvocation(span, new Object(), true, "lambda-request-123"); + + assertEquals("true", server.getLastRequest().getHeader("x-datadog-invocation-error")); + assertEquals( + "custom error message", + server.getLastRequest().getHeader("x-datadog-invocation-error-msg")); + assertEquals( + "java.lang.Throwable", + server.getLastRequest().getHeader("x-datadog-invocation-error-type")); + assertEquals( + "ZXJyb3JTdGFjawogCXRlc3Q=", + server.getLastRequest().getHeader("x-datadog-invocation-error-stack")); + assertEquals( + "lambda-request-123", server.getLastRequest().getHeader("lambda-runtime-aws-request-id")); + + server.close(); + } + + @Test + void testMoshiToJsonSQSEvent() { + SQSEvent myEvent = new SQSEvent(); + List records = new ArrayList<>(); + SQSEvent.SQSMessage message = new SQSEvent.SQSMessage(); + message.setMessageId("myId"); + message.setAwsRegion("myRegion"); + records.add(message); + myEvent.setRecords(records); + + String result = LambdaHandler.writeValueAsString(myEvent); + + assertEquals("{\"records\":[{\"awsRegion\":\"myRegion\",\"messageId\":\"myId\"}]}", result); + } + + @Test + void testMoshiToJsonS3Event() { + List list = new ArrayList<>(); + S3EventNotification.S3EventNotificationRecord item0 = + new S3EventNotification.S3EventNotificationRecord( + "region", "eventName", "mySource", null, "3.4", null, null, null, null); + list.add(item0); + S3Event myEvent = new S3Event(list); + + String result = LambdaHandler.writeValueAsString(myEvent); + + assertEquals( + "{\"records\":[{\"awsRegion\":\"region\",\"eventName\":\"eventName\",\"eventSource\":\"mySource\",\"eventVersion\":\"3.4\"}]}", + result); + } + + @Test + void testMoshiToJsonSNSEvent() { + SNSEvent myEvent = new SNSEvent(); + List records = new ArrayList<>(); + SNSEvent.SNSRecord message = new SNSEvent.SNSRecord(); + message.setEventSource("mySource"); + message.setEventVersion("myVersion"); + records.add(message); + myEvent.setRecords(records); + + String result = LambdaHandler.writeValueAsString(myEvent); + + assertEquals( + "{\"records\":[{\"eventSource\":\"mySource\",\"eventVersion\":\"myVersion\"}]}", result); + } + + @Test + void testMoshiToJsonAPIGatewayProxyRequestEvent() { + APIGatewayProxyRequestEvent myEvent = new APIGatewayProxyRequestEvent(); + myEvent.setBody("bababango"); + myEvent.setHttpMethod("POST"); + + String result = LambdaHandler.writeValueAsString(myEvent); + + assertEquals("{\"body\":\"bababango\",\"httpMethod\":\"POST\"}", result); + } + + @Test + void testMoshiToJsonInputStream() { + String body = "{\"body\":\"bababango\",\"httpMethod\":\"POST\"}"; + ByteArrayInputStream myEvent = new ByteArrayInputStream(body.getBytes()); + + String result = LambdaHandler.writeValueAsString(myEvent); + + assertEquals(body, result); + } + + @Test + void testMoshiToJsonOutputStream() { + String body = "{\"body\":\"bababango\",\"statusCode\":\"200\"}"; + ByteArrayOutputStream myEvent = new ByteArrayOutputStream(); + byte[] bodyBytes = body.getBytes(); + myEvent.write(bodyBytes, 0, bodyBytes.length); + + String result = LambdaHandler.writeValueAsString(myEvent); + + assertEquals(body, result); + } +} diff --git a/dd-trace-core/src/test/java/datadog/trace/lambda/SkipUnhandledTypeJsonSerializerTest.java b/dd-trace-core/src/test/java/datadog/trace/lambda/SkipUnhandledTypeJsonSerializerTest.java new file mode 100644 index 00000000000..588a05ca329 --- /dev/null +++ b/dd-trace-core/src/test/java/datadog/trace/lambda/SkipUnhandledTypeJsonSerializerTest.java @@ -0,0 +1,173 @@ +package datadog.trace.lambda; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent; +import com.amazonaws.services.lambda.runtime.events.SNSEvent; +import com.amazonaws.services.lambda.runtime.events.SQSEvent; +import com.squareup.moshi.JsonAdapter; +import com.squareup.moshi.Moshi; +import datadog.trace.core.DDCoreJavaSpecification; +import java.io.ByteArrayInputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; + +abstract class AbstractSerialize { + public String randomString; +} + +class SubClass extends AbstractSerialize { + SubClass() { + this.randomString = "tutu"; + } +} + +interface ApiRequestPath {} + +class LambdaRequest { + public boolean testBool; + public String emptyStr; + public Map emptyHeaders; +} + +class CustomRequest

extends LambdaRequest { + public P path; + public B body; +} + +public class SkipUnhandledTypeJsonSerializerTest extends DDCoreJavaSpecification { + + static class TestJsonObject { + public String field1; + public boolean field2; + public AbstractSerialize field3; + public NestedJsonObject field4; + public ByteArrayInputStream field5; + + TestJsonObject() { + this.field1 = "toto"; + this.field2 = true; + this.field3 = new SubClass(); + this.field4 = new NestedJsonObject(); + this.field5 = new ByteArrayInputStream(new byte[0]); + } + } + + static class NestedJsonObject { + public AbstractSerialize field; + + NestedJsonObject() { + this.field = new SubClass(); + } + } + + private static JsonAdapter buildAdapter() { + return new Moshi.Builder() + .add(SkipUnsupportedTypeJsonAdapter.newFactory()) + .build() + .adapter(Object.class); + } + + @Test + void testStringSerialization() { + JsonAdapter adapter = buildAdapter(); + + String result = adapter.toJson(new TestJsonObject()); + + assertEquals( + "{\"field1\":\"toto\",\"field2\":true,\"field3\":{},\"field4\":{\"field\":{}},\"field5\":{}}", + result); + } + + @Test + void testSimpleCase() { + JsonAdapter adapter = buildAdapter(); + + LinkedHashMap list = new LinkedHashMap<>(); + list.put("key0", "item0"); + list.put("key1", "item1"); + list.put("key2", "item2"); + String result = adapter.toJson(list); + + assertEquals("{\"key0\":\"item0\",\"key1\":\"item1\",\"key2\":\"item2\"}", result); + } + + @Test + void testSQSEvent() { + JsonAdapter adapter = buildAdapter(); + + SQSEvent myEvent = new SQSEvent(); + List records = new ArrayList<>(); + SQSEvent.SQSMessage message = new SQSEvent.SQSMessage(); + message.setMessageId("myId"); + message.setAwsRegion("myRegion"); + records.add(message); + myEvent.setRecords(records); + String result = adapter.toJson(myEvent); + + assertEquals("{\"records\":[{\"awsRegion\":\"myRegion\",\"messageId\":\"myId\"}]}", result); + } + + @Test + void testSNSEvent() { + JsonAdapter adapter = buildAdapter(); + + SNSEvent myEvent = new SNSEvent(); + List records = new ArrayList<>(); + SNSEvent.SNSRecord message = new SNSEvent.SNSRecord(); + message.setEventSource("mySource"); + message.setEventVersion("myVersion"); + records.add(message); + myEvent.setRecords(records); + String result = adapter.toJson(myEvent); + + assertEquals( + "{\"records\":[{\"eventSource\":\"mySource\",\"eventVersion\":\"myVersion\"}]}", result); + } + + @Test + void testAPIGatewayProxyRequestEvent() { + JsonAdapter adapter = buildAdapter(); + + APIGatewayProxyRequestEvent myEvent = new APIGatewayProxyRequestEvent(); + myEvent.setBody("bababango"); + myEvent.setHttpMethod("POST"); + String result = adapter.toJson(myEvent); + + assertEquals("{\"body\":\"bababango\",\"httpMethod\":\"POST\"}", result); + } + + @Test + void testMapStringObjectEvent() { + JsonAdapter adapter = buildAdapter(); + + HashMap myEvent = new HashMap<>(); + HashMap myNestedEvent = new HashMap<>(); + myNestedEvent.put("nestedKey0", "nestedValue1"); + myNestedEvent.put("nestedKey1", true); + myNestedEvent.put("nestedKey2", Arrays.asList("aaa", "bbb", "ccc", "dddd")); + myEvent.put("firstKey", new TestJsonObject()); + myEvent.put("secondKey", myNestedEvent); + String result = adapter.toJson(myEvent); + + assertEquals( + "{\"firstKey\":{\"field1\":\"toto\",\"field2\":true,\"field3\":{},\"field4\":{\"field\":{}},\"field5\":{}},\"secondKey\":{\"nestedKey2\":[\"aaa\",\"bbb\",\"ccc\",\"dddd\"],\"nestedKey0\":\"nestedValue1\",\"nestedKey1\":true}}", + result); + } + + @Test + @SuppressWarnings("rawtypes") + void testCustomPayload() { + JsonAdapter adapter = buildAdapter(); + + CustomRequest customPayload = new CustomRequest(); + String result = adapter.toJson(customPayload); + + assertEquals("{\"body\":{},\"path\":{},\"testBool\":false}", result); + } +} diff --git a/dd-trace-core/src/test/java/datadog/trace/llmobs/writer/ddintake/LLMObsSpanMapperTest.java b/dd-trace-core/src/test/java/datadog/trace/llmobs/writer/ddintake/LLMObsSpanMapperTest.java new file mode 100644 index 00000000000..af0b1476b89 --- /dev/null +++ b/dd-trace-core/src/test/java/datadog/trace/llmobs/writer/ddintake/LLMObsSpanMapperTest.java @@ -0,0 +1,416 @@ +package datadog.trace.llmobs.writer.ddintake; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.fasterxml.jackson.databind.ObjectMapper; +import datadog.communication.serialization.ByteBufferConsumer; +import datadog.communication.serialization.FlushingBuffer; +import datadog.communication.serialization.msgpack.MsgPackWriter; +import datadog.trace.api.DDTags; +import datadog.trace.api.llmobs.LLMObs; +import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.api.InternalSpanTypes; +import datadog.trace.bootstrap.instrumentation.api.Tags; +import datadog.trace.common.writer.ListWriter; +import datadog.trace.core.CoreTracer; +import datadog.trace.core.DDCoreJavaSpecification; +import datadog.trace.core.DDSpan; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.msgpack.jackson.dataformat.MessagePackFactory; + +@SuppressWarnings("unchecked") +public class LLMObsSpanMapperTest extends DDCoreJavaSpecification { + + private static final ObjectMapper objectMapper = new ObjectMapper(new MessagePackFactory()); + + @Test + void testLLMObsSpanMapperSerialization() throws Exception { + LLMObsSpanMapper mapper = new LLMObsSpanMapper(); + CoreTracer tracer = tracerBuilder().writer(new ListWriter()).build(); + + // Create a real LLMObs span using the tracer + AgentSpan llmSpan = + tracer + .buildSpan("datadog", "openai.request") + .withResourceName("createCompletion") + .withTag("_ml_obs_tag.span.kind", Tags.LLMOBS_LLM_SPAN_KIND) + .withTag("_ml_obs_tag.model_name", "gpt-4") + .withTag("_ml_obs_tag.model_provider", "openai") + .withTag("_ml_obs_metric.input_tokens", 50) + .withTag("_ml_obs_metric.output_tokens", 25) + .withTag("_ml_obs_metric.total_tokens", 75) + .withTag("_ml_obs_tag.session_id", "abc-123-session") + .start(); + + llmSpan.setSpanType(InternalSpanTypes.LLMOBS); + + Map toolCallArgs = Collections.singletonMap("location", "San Francisco"); + LLMObs.ToolCall toolCall = + LLMObs.ToolCall.from("get_weather", "function_call", "call_123", toolCallArgs); + LLMObs.ToolResult toolResult = + LLMObs.ToolResult.from( + "get_weather", "function_call_output", "call_123", "{\"temperature\":\"72F\"}"); + List inputMessages = + Arrays.asList( + LLMObs.LLMMessage.from("user", "Hello, what's the weather like?"), + LLMObs.LLMMessage.from( + "assistant", + null, + Collections.singletonList(toolCall), + Collections.singletonList(toolResult))); + List outputMessages = + Collections.singletonList( + LLMObs.LLMMessage.from("assistant", "I'll help you check the weather.")); + + Map chatTemplateEntry = new LinkedHashMap<>(); + chatTemplateEntry.put("role", "user"); + chatTemplateEntry.put("content", "Hello, what's the weather like in {{city}}?"); + Map prompt = new LinkedHashMap<>(); + prompt.put("id", "prompt_123"); + prompt.put("version", "1"); + prompt.put("variables", Collections.singletonMap("city", "San Francisco")); + prompt.put("chat_template", Collections.singletonList(chatTemplateEntry)); + + Map inputMap = new LinkedHashMap<>(); + inputMap.put("messages", inputMessages); + inputMap.put("prompt", prompt); + llmSpan.setTag("_ml_obs_tag.input", inputMap); + llmSpan.setTag("_ml_obs_tag.output", outputMessages); + + Map metadataMap = new LinkedHashMap<>(); + metadataMap.put("temperature", 0.7); + metadataMap.put("max_tokens", 100); + llmSpan.setTag("_ml_obs_tag.metadata", metadataMap); + + Map cityProp = Collections.singletonMap("type", "string"); + Map properties = Collections.singletonMap("city", cityProp); + Map schema = new LinkedHashMap<>(); + schema.put("type", "object"); + schema.put("properties", properties); + Map toolDef = new LinkedHashMap<>(); + toolDef.put("name", "get_weather"); + toolDef.put("description", "Get weather by city"); + toolDef.put("schema", schema); + llmSpan.setTag("_ml_obs_tag.tool_definitions", Collections.singletonList(toolDef)); + + llmSpan.setError(true); + llmSpan.setTag(DDTags.ERROR_MSG, "boom"); + llmSpan.setTag(DDTags.ERROR_TYPE, "java.lang.IllegalStateException"); + llmSpan.setTag(DDTags.ERROR_STACK, "stacktrace"); + + llmSpan.finish(); + + List trace = Collections.singletonList((DDSpan) llmSpan); + CapturingByteBufferConsumer sink = new CapturingByteBufferConsumer(); + // Keep all formatted spans in a single flush for this assertion. + MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(16 * 1024, sink)); + + packer.format(trace, mapper); + packer.flush(); + + assertNotNull(sink.captured); + datadog.trace.common.writer.Payload payload = mapper.newPayload(); + payload.withBody(1, sink.captured); + + // Capture the size before the buffer is written and the body buffer is emptied. + int sizeInBytes = payload.sizeInBytes(); + + byte[] bytesWritten = writeTo(payload); + assertEquals(sizeInBytes, bytesWritten.length); + Map result = objectMapper.readValue(bytesWritten, Map.class); + + assertTrue(result.containsKey("event_type")); + assertEquals("span", result.get("event_type")); + assertTrue(result.containsKey("_dd.stage")); + assertEquals("raw", result.get("_dd.stage")); + assertTrue(result.containsKey("spans")); + assertNotNull(result.get("spans")); + List> spans = (List>) result.get("spans"); + assertTrue(spans instanceof List); + assertEquals(1, spans.size()); + + Map spanData = spans.get(0); + assertEquals("OpenAI.createCompletion", spanData.get("name")); + assertTrue(spanData.containsKey("span_id")); + assertTrue(spanData.containsKey("trace_id")); + assertTrue(spanData.containsKey("start_ns")); + assertTrue(spanData.containsKey("duration")); + assertTrue(spanData.containsKey("_dd")); + Map dd = (Map) spanData.get("_dd"); + assertEquals(dd.get("span_id"), spanData.get("span_id")); + assertEquals(dd.get("trace_id"), spanData.get("trace_id")); + assertEquals(dd.get("apm_trace_id"), spanData.get("trace_id")); + + // Top-level session_id field β€” what the LLM Trace Explorer's Sessions filter queries. + assertTrue(spanData.containsKey("session_id")); + assertEquals("abc-123-session", spanData.get("session_id")); + + assertTrue(spanData.containsKey("meta")); + Map meta = (Map) spanData.get("meta"); + assertEquals("llm", meta.get("span.kind")); + assertTrue(meta.containsKey("error")); + Map error = (Map) meta.get("error"); + assertEquals("boom", error.get("message")); + assertEquals("java.lang.IllegalStateException", error.get("type")); + assertEquals("stacktrace", error.get("stack")); + assertTrue(meta.containsKey("input")); + Map inputResult = (Map) meta.get("input"); + assertTrue(inputResult.containsKey("messages")); + List> inputMsgs = (List>) inputResult.get("messages"); + assertTrue(inputMsgs.get(0).containsKey("content")); + assertEquals("Hello, what's the weather like?", inputMsgs.get(0).get("content")); + assertTrue(inputMsgs.get(0).containsKey("role")); + assertEquals("user", inputMsgs.get(0).get("role")); + assertEquals("assistant", inputMsgs.get(1).get("role")); + assertFalse(inputMsgs.get(1).containsKey("content")); + List> toolCalls = + (List>) inputMsgs.get(1).get("tool_calls"); + assertEquals("get_weather", toolCalls.get(0).get("name")); + assertEquals("function_call", toolCalls.get(0).get("type")); + assertEquals("call_123", toolCalls.get(0).get("tool_id")); + assertEquals( + Collections.singletonMap("location", "San Francisco"), toolCalls.get(0).get("arguments")); + List> toolResults = + (List>) inputMsgs.get(1).get("tool_results"); + assertEquals("get_weather", toolResults.get(0).get("name")); + assertEquals("function_call_output", toolResults.get(0).get("type")); + assertEquals("call_123", toolResults.get(0).get("tool_id")); + assertEquals("{\"temperature\":\"72F\"}", toolResults.get(0).get("result")); + Map promptResult = (Map) inputResult.get("prompt"); + assertEquals("prompt_123", promptResult.get("id")); + assertEquals("1", promptResult.get("version")); + assertEquals(Collections.singletonMap("city", "San Francisco"), promptResult.get("variables")); + assertEquals(Collections.singletonList(chatTemplateEntry), promptResult.get("chat_template")); + assertTrue(meta.containsKey("output")); + Map outputResult = (Map) meta.get("output"); + assertTrue(outputResult.containsKey("messages")); + List> outputMsgs = (List>) outputResult.get("messages"); + assertTrue(outputMsgs.get(0).containsKey("content")); + assertEquals("I'll help you check the weather.", outputMsgs.get(0).get("content")); + assertTrue(outputMsgs.get(0).containsKey("role")); + assertEquals("assistant", outputMsgs.get(0).get("role")); + List> toolDefsResult = + (List>) meta.get("tool_definitions"); + assertEquals("get_weather", toolDefsResult.get(0).get("name")); + assertEquals("Get weather by city", toolDefsResult.get(0).get("description")); + assertEquals(schema, toolDefsResult.get(0).get("schema")); + assertTrue(meta.containsKey("metadata")); + + assertTrue(spanData.containsKey("metrics")); + Map metrics = (Map) spanData.get("metrics"); + assertEquals(50.0, ((Number) metrics.get("input_tokens")).doubleValue(), 0.0); + assertEquals(25.0, ((Number) metrics.get("output_tokens")).doubleValue(), 0.0); + assertEquals(75.0, ((Number) metrics.get("total_tokens")).doubleValue(), 0.0); + + assertTrue(spanData.containsKey("tags")); + List tags = (List) spanData.get("tags"); + assertTrue(tags.contains("language:jvm")); + assertTrue(tags.contains("session_id:abc-123-session")); + + tracer.close(); + } + + @Test + void testLLMObsSpanMapperWritesNoSpansWhenNoneAreLLMObsSpans() { + LLMObsSpanMapper mapper = new LLMObsSpanMapper(); + CoreTracer tracer = tracerBuilder().writer(new ListWriter()).build(); + + AgentSpan regularSpan1 = + tracer + .buildSpan("datadog", "http.request") + .withResourceName("GET /api/users") + .withTag("http.method", "GET") + .withTag("http.url", "https://example.com/api/users") + .start(); + regularSpan1.finish(); + + AgentSpan regularSpan2 = + tracer + .buildSpan("datadog", "database.query") + .withResourceName("SELECT * FROM users") + .withTag("db.type", "postgresql") + .start(); + regularSpan2.finish(); + + List trace = Arrays.asList((DDSpan) regularSpan1, (DDSpan) regularSpan2); + CapturingByteBufferConsumer sink = new CapturingByteBufferConsumer(); + // Keep all formatted spans in a single flush for this assertion. + MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(16 * 1024, sink)); + + packer.format(trace, mapper); + packer.flush(); + + assertFalse(sink.captured != null); + + tracer.close(); + } + + @Test + void testConsecutivePackerFormatCallsAccumulateSpansFromMultipleTraces() throws Exception { + LLMObsSpanMapper mapper = new LLMObsSpanMapper(); + CoreTracer tracer = tracerBuilder().writer(new ListWriter()).build(); + + // First trace with 2 LLMObs spans + AgentSpan llmSpan1 = + tracer + .buildSpan("datadog", "chat-completion-1") + .withTag("_ml_obs_tag.span.kind", Tags.LLMOBS_LLM_SPAN_KIND) + .withTag("_ml_obs_tag.model_name", "gpt-4") + .withTag("_ml_obs_tag.model_provider", "openai") + .start(); + llmSpan1.setSpanType(InternalSpanTypes.LLMOBS); + llmSpan1.finish(); + + AgentSpan llmSpan2 = + tracer + .buildSpan("datadog", "chat-completion-2") + .withTag("_ml_obs_tag.span.kind", Tags.LLMOBS_LLM_SPAN_KIND) + .withTag("_ml_obs_tag.model_name", "gpt-3.5") + .withTag("_ml_obs_tag.model_provider", "openai") + .start(); + llmSpan2.setSpanType(InternalSpanTypes.LLMOBS); + llmSpan2.finish(); + + // Second trace with 1 LLMObs span + AgentSpan llmSpan3 = + tracer + .buildSpan("datadog", "chat-completion-3") + .withTag("_ml_obs_tag.span.kind", Tags.LLMOBS_LLM_SPAN_KIND) + .withTag("_ml_obs_tag.model_name", "claude-3") + .withTag("_ml_obs_tag.model_provider", "anthropic") + .start(); + llmSpan3.setSpanType(InternalSpanTypes.LLMOBS); + llmSpan3.finish(); + + List trace1 = Arrays.asList((DDSpan) llmSpan1, (DDSpan) llmSpan2); + List trace2 = Collections.singletonList((DDSpan) llmSpan3); + CapturingByteBufferConsumer sink = new CapturingByteBufferConsumer(); + // Keep all formatted spans in a single flush for this assertion. + MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(16 * 1024, sink)); + + packer.format(trace1, mapper); + packer.format(trace2, mapper); + packer.flush(); + + assertNotNull(sink.captured); + datadog.trace.common.writer.Payload payload = mapper.newPayload(); + payload.withBody(3, sink.captured); + + // Capture the size before the buffer is written and the body buffer is emptied. + int sizeInBytes = payload.sizeInBytes(); + + byte[] bytesWritten = writeTo(payload); + assertEquals(sizeInBytes, bytesWritten.length); + Map result = objectMapper.readValue(bytesWritten, Map.class); + + assertTrue(result.containsKey("event_type")); + assertEquals("span", result.get("event_type")); + assertTrue(result.containsKey("_dd.stage")); + assertEquals("raw", result.get("_dd.stage")); + assertTrue(result.containsKey("spans")); + List> spans = (List>) result.get("spans"); + assertTrue(spans instanceof List); + assertEquals(3, spans.size()); + + List spanNames = new ArrayList<>(); + for (Map span : spans) { + spanNames.add(span.get("name")); + } + assertTrue(spanNames.contains("chat-completion-1")); + assertTrue(spanNames.contains("chat-completion-2")); + assertTrue(spanNames.contains("chat-completion-3")); + + tracer.close(); + } + + @Test + void testLLMObsSpanMapperOmitsTopLevelSessionIdWhenNotSet() throws Exception { + LLMObsSpanMapper mapper = new LLMObsSpanMapper(); + CoreTracer tracer = tracerBuilder().writer(new ListWriter()).build(); + + AgentSpan llmSpan = + tracer + .buildSpan("datadog", "openai.request") + .withResourceName("createCompletion") + .withTag("_ml_obs_tag.span.kind", Tags.LLMOBS_LLM_SPAN_KIND) + .withTag("_ml_obs_tag.model_name", "gpt-4") + .withTag("_ml_obs_tag.model_provider", "openai") + .start(); + llmSpan.setSpanType(InternalSpanTypes.LLMOBS); + llmSpan.finish(); + + List trace = Collections.singletonList((DDSpan) llmSpan); + CapturingByteBufferConsumer sink = new CapturingByteBufferConsumer(); + MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(16 * 1024, sink)); + + packer.format(trace, mapper); + packer.flush(); + + assertNotNull(sink.captured); + datadog.trace.common.writer.Payload payload = mapper.newPayload(); + payload.withBody(1, sink.captured); + + byte[] bytesWritten = writeTo(payload); + Map result = objectMapper.readValue(bytesWritten, Map.class); + List> spans = (List>) result.get("spans"); + Map spanData = spans.get(0); + + // No top-level session_id field when the tag was never set. + assertFalse(spanData.containsKey("session_id")); + + // And no session_id entry leaks into tags[] either. + List tags = (List) spanData.get("tags"); + for (String tag : tags) { + assertFalse( + tag.startsWith("session_id:"), "tag should not start with session_id: but got: " + tag); + } + + tracer.close(); + } + + private static byte[] writeTo(datadog.trace.common.writer.Payload payload) throws IOException { + ByteArrayOutputStream channel = new ByteArrayOutputStream(); + payload.writeTo( + new WritableByteChannel() { + @Override + public int write(ByteBuffer src) throws IOException { + byte[] bytes = new byte[src.remaining()]; + src.get(bytes); + channel.write(bytes); + return bytes.length; + } + + @Override + public boolean isOpen() { + return true; + } + + @Override + public void close() throws IOException {} + }); + return channel.toByteArray(); + } + + static class CapturingByteBufferConsumer implements ByteBufferConsumer { + + ByteBuffer captured; + + @Override + public void accept(int messageCount, ByteBuffer buffer) { + captured = buffer; + } + } +} diff --git a/utils/junit-utils/src/main/java/datadog/trace/junit/utils/tabletest/DDSpanTypesConverter.java b/utils/junit-utils/src/main/java/datadog/trace/junit/utils/tabletest/DDSpanTypesConverter.java new file mode 100644 index 00000000000..7a7032fe185 --- /dev/null +++ b/utils/junit-utils/src/main/java/datadog/trace/junit/utils/tabletest/DDSpanTypesConverter.java @@ -0,0 +1,88 @@ +package datadog.trace.junit.utils.tabletest; + +import datadog.trace.api.DDSpanTypes; +import org.junit.jupiter.api.extension.ParameterContext; +import org.junit.jupiter.params.converter.ArgumentConversionException; +import org.junit.jupiter.params.converter.ArgumentConverter; + +public class DDSpanTypesConverter implements ArgumentConverter { + + @Override + public Object convert(Object source, ParameterContext context) + throws ArgumentConversionException { + if (source == null) { + return null; + } + if (source.toString().startsWith("DDSpanTypes.")) { + switch (source.toString()) { + case "DDSpanTypes.HTTP_CLIENT": + return DDSpanTypes.HTTP_CLIENT; + case "DDSpanTypes.HTTP_SERVER": + return DDSpanTypes.HTTP_SERVER; + case "DDSpanTypes.RPC": + return DDSpanTypes.RPC; + case "DDSpanTypes.CACHE": + return DDSpanTypes.CACHE; + case "DDSpanTypes.SOAP": + return DDSpanTypes.SOAP; + case "DDSpanTypes.SQL": + return DDSpanTypes.SQL; + case "DDSpanTypes.MONGO": + return DDSpanTypes.MONGO; + case "DDSpanTypes.CASSANDRA": + return DDSpanTypes.CASSANDRA; + case "DDSpanTypes.COUCHBASE": + return DDSpanTypes.COUCHBASE; + case "DDSpanTypes.REDIS": + return DDSpanTypes.REDIS; + case "DDSpanTypes.MEMCACHED": + return DDSpanTypes.MEMCACHED; + case "DDSpanTypes.ELASTICSEARCH": + return DDSpanTypes.ELASTICSEARCH; + case "DDSpanTypes.OPENSEARCH": + return DDSpanTypes.OPENSEARCH; + case "DDSpanTypes.HIBERNATE": + return DDSpanTypes.HIBERNATE; + case "DDSpanTypes.AEROSPIKE": + return DDSpanTypes.AEROSPIKE; + case "DDSpanTypes.DATANUCLEUS": + return DDSpanTypes.DATANUCLEUS; + case "DDSpanTypes.MESSAGE_CLIENT": + return DDSpanTypes.MESSAGE_CLIENT; + case "DDSpanTypes.MESSAGE_CONSUMER": + return DDSpanTypes.MESSAGE_CONSUMER; + case "DDSpanTypes.MESSAGE_PRODUCER": + return DDSpanTypes.MESSAGE_PRODUCER; + case "DDSpanTypes.MESSAGE_BROKER": + return DDSpanTypes.MESSAGE_BROKER; + case "DDSpanTypes.GRAPHQL": + return DDSpanTypes.GRAPHQL; + case "DDSpanTypes.TEST": + return DDSpanTypes.TEST; + case "DDSpanTypes.TEST_SUITE_END": + return DDSpanTypes.TEST_SUITE_END; + case "DDSpanTypes.TEST_MODULE_END": + return DDSpanTypes.TEST_MODULE_END; + case "DDSpanTypes.TEST_SESSION_END": + return DDSpanTypes.TEST_SESSION_END; + case "DDSpanTypes.VULNERABILITY": + return DDSpanTypes.VULNERABILITY; + case "DDSpanTypes.PROTOBUF": + return DDSpanTypes.PROTOBUF; + case "DDSpanTypes.MULE": + return DDSpanTypes.MULE; + case "DDSpanTypes.VALKEY": + return DDSpanTypes.VALKEY; + case "DDSpanTypes.WEBSOCKET": + return DDSpanTypes.WEBSOCKET; + case "DDSpanTypes.SERVERLESS": + return DDSpanTypes.SERVERLESS; + case "DDSpanTypes.LLMOBS": + return DDSpanTypes.LLMOBS; + default: + throw new ArgumentConversionException("Cannot convert " + source); + } + } + return source; + } +}