diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/build.gradle b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/build.gradle new file mode 100644 index 00000000000..b038a4f94fd --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/build.gradle @@ -0,0 +1,82 @@ +import com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar + +plugins { + id 'java-test-fixtures' +} + +muzzle { + pass { + group = "jakarta.servlet" + module = 'jakarta.servlet-api' + versions = "[6.0,)" + } +} + +apply from: "$rootDir/gradle/java.gradle" +apply plugin: 'dd-trace-java.call-site-instrumentation' + +// jakarta.servlet-api dependencies are compiled with Java 11 and +// the gradle muzzle tasks uses the JVM gradle is running with +if (!JavaVersion.current().java11Compatible) { + project.afterEvaluate { + tasks.findAll {it.group == 'Muzzle' }.each { + logger.info("Disabling task $it.path (requires Java 11)") + it.enabled = false + } + } +} + +configurations { + javaxClassesToRelocate +} + +tasks.register('relocatedJavaxJar', ShadowJar) { + relocate 'javax.servlet', 'jakarta.servlet' + relocate 'datadog.trace.instrumentation.servlet3', 'datadog.trace.instrumentation.servlet6' + relocate 'datadog.trace.instrumentation.servlet', 'datadog.trace.instrumentation.servlet6' + + archiveClassifier.set('relocated-javax') + + configurations = [project.configurations.javaxClassesToRelocate] + + include '**/*.jar' + include '**/Servlet31InputStreamWrapper.class' + include '**/HttpServletGetInputStreamAdvice.class' + include '**/HttpServletGetReaderAdvice.class' + include '**/BufferedReaderWrapper.class' + include '**/ServletBlockingHelper.class' + include '**/AbstractServletInputStreamWrapper.class' + + includeEmptyDirs = false +} + +dependencies { + implementation files(relocatedJavaxJar.outputs.files) + compileOnly group: 'jakarta.servlet', name: 'jakarta.servlet-api', version: '6.1.0' + testImplementation group: 'jakarta.servlet', name: 'jakarta.servlet-api', version: '6.1.0' + testImplementation group: 'jakarta.servlet.jsp', name: 'jakarta.servlet.jsp-api', version: '3.0.0' + testRuntimeOnly project(':dd-java-agent:instrumentation:datadog:asm:iast-instrumenter') + + javaxClassesToRelocate project(':dd-java-agent:instrumentation:servlet:javax-servlet:javax-servlet-iast'), { + transitive = false + } + javaxClassesToRelocate project(':dd-java-agent:instrumentation:servlet:javax-servlet:javax-servlet-3.0'), { + transitive = false + } + + testFixturesApi(project(':dd-java-agent:instrumentation-testing')) { + exclude group: 'org.eclipse.jetty', module: 'jetty-server' + } + testFixturesCompileOnly group: 'jakarta.servlet', name: 'jakarta.servlet-api', version: '6.1.0' + + testImplementation libs.bundles.mockito + + testFixturesCompileOnly(libs.bundles.groovy) + testFixturesCompileOnly(libs.bundles.spock) + + // tested against jakarta.servlet-api 6.0+ +} + +tasks.named("jar", Jar) { + from zipTree(relocatedJavaxJar.outputs.files.asPath) +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/IastJakartaServletInstrumentation.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/IastJakartaServletInstrumentation.java new file mode 100644 index 00000000000..a99677f97f8 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/IastJakartaServletInstrumentation.java @@ -0,0 +1,81 @@ +package datadog.trace.instrumentation.servlet6; + +import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.hasSuperType; +import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.named; +import static net.bytebuddy.matcher.ElementMatchers.isMethod; +import static net.bytebuddy.matcher.ElementMatchers.isPublic; +import static net.bytebuddy.matcher.ElementMatchers.takesArgument; +import static net.bytebuddy.matcher.ElementMatchers.takesArguments; + +import com.google.auto.service.AutoService; +import datadog.trace.agent.tooling.Instrumenter; +import datadog.trace.agent.tooling.InstrumenterModule; +import datadog.trace.api.iast.InstrumentationBridge; +import datadog.trace.api.iast.sink.ApplicationModule; +import datadog.trace.bootstrap.InstrumentationContext; +import jakarta.servlet.ServletContext; +import jakarta.servlet.http.HttpServlet; +import java.util.Collections; +import java.util.Map; +import net.bytebuddy.asm.Advice; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; + +@AutoService(InstrumenterModule.class) +public class IastJakartaServletInstrumentation extends InstrumenterModule.Iast + implements Instrumenter.ForTypeHierarchy, Instrumenter.HasMethodAdvice { + public IastJakartaServletInstrumentation() { + super("servlet", "servlet-6"); + } + + @Override + public String hierarchyMarkerType() { + return "jakarta.servlet.http.HttpServlet"; + } + + @Override + public ElementMatcher hierarchyMatcher() { + return hasSuperType(named(hierarchyMarkerType())); + } + + @Override + public Map contextStore() { + return Collections.singletonMap("jakarta.servlet.ServletContext", Boolean.class.getName()); + } + + @Override + public void methodAdvice(MethodTransformer transformer) { + transformer.applyAdvice( + isMethod() + .and(named("service")) + .and(isPublic()) + .and(takesArguments(2)) + .and(takesArgument(0, named("jakarta.servlet.ServletRequest"))) + .and(takesArgument(1, named("jakarta.servlet.ServletResponse"))), + getClass().getName() + "$IastAdvice"); + } + + @Override + protected boolean isOptOutEnabled() { + return true; + } + + public static class IastAdvice { + + @Advice.OnMethodExit(suppress = Throwable.class) + public static void after(@Advice.This final HttpServlet servlet) { + final ApplicationModule applicationModule = InstrumentationBridge.APPLICATION; + if (applicationModule == null) { + return; + } + final ServletContext context = servlet.getServletContext(); + if (InstrumentationContext.get(ServletContext.class, Boolean.class).get(context) != null) { + return; + } + InstrumentationContext.get(ServletContext.class, Boolean.class).put(context, true); + if (applicationModule != null) { + applicationModule.onRealPath(context.getRealPath("/")); + } + } + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/IastOptOutJakartaHttpServletRequestInstrumentation.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/IastOptOutJakartaHttpServletRequestInstrumentation.java new file mode 100644 index 00000000000..898dad3b155 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/IastOptOutJakartaHttpServletRequestInstrumentation.java @@ -0,0 +1,104 @@ +package datadog.trace.instrumentation.servlet6; + +import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.extendsClass; +import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.implementsInterface; +import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.named; +import static net.bytebuddy.matcher.ElementMatchers.*; + +import com.google.auto.service.AutoService; +import datadog.trace.agent.tooling.Instrumenter; +import datadog.trace.agent.tooling.InstrumenterModule; +import datadog.trace.api.iast.*; +import datadog.trace.api.iast.sink.ApplicationModule; +import datadog.trace.bootstrap.InstrumentationContext; +import jakarta.servlet.ServletContext; +import jakarta.servlet.SessionTrackingMode; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpSession; +import java.util.*; +import net.bytebuddy.asm.Advice; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; + +@SuppressWarnings("unused") +@AutoService(InstrumenterModule.class) +public class IastOptOutJakartaHttpServletRequestInstrumentation extends InstrumenterModule.Iast + implements Instrumenter.ForTypeHierarchy, Instrumenter.HasMethodAdvice { + + private static final String CLASS_NAME = + IastOptOutJakartaHttpServletRequestInstrumentation.class.getName(); + private static final ElementMatcher.Junction WRAPPER_CLASS = + named("jakarta.servlet.http.HttpServletRequestWrapper"); + + public IastOptOutJakartaHttpServletRequestInstrumentation() { + super("servlet", "servlet-6", "servlet-request"); + } + + @Override + public String hierarchyMarkerType() { + return "jakarta.servlet.http.HttpServletRequest"; + } + + @Override + public ElementMatcher hierarchyMatcher() { + return implementsInterface(named(hierarchyMarkerType())) + .and(not(WRAPPER_CLASS)) + .and(not(extendsClass(WRAPPER_CLASS))); + } + + @Override + protected boolean isOptOutEnabled() { + return true; + } + + @Override + public void methodAdvice(MethodTransformer transformer) { + transformer.applyAdvice( + named("getSession").and(returns(named("jakarta.servlet.http.HttpSession"))).and(isPublic()), + CLASS_NAME + "$GetHttpSessionAdvice"); + } + + @Override + public Map contextStore() { + return Collections.singletonMap( + "jakarta.servlet.ServletContext", "jakarta.servlet.SessionTrackingMode"); + } + + public static class GetHttpSessionAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + @Sink(VulnerabilityTypes.SESSION_REWRITING) + public static void onExit( + @Advice.This final HttpServletRequest request, @Advice.Return final HttpSession session) { + if (session == null) { + return; + } + final ApplicationModule module = InstrumentationBridge.APPLICATION; + if (module == null) { + return; + } + final ServletContext context = request.getServletContext(); + + if (InstrumentationContext.get(ServletContext.class, SessionTrackingMode.class).get(context) + != null) { + return; + } + // We only want to report it once per application + InstrumentationContext.get(ServletContext.class, SessionTrackingMode.class) + .put(context, SessionTrackingMode.URL); + if (context.getEffectiveSessionTrackingModes() != null + && !context.getEffectiveSessionTrackingModes().isEmpty()) { + Set sessionTrackingModes = new HashSet<>(); + for (SessionTrackingMode mode : context.getEffectiveSessionTrackingModes()) { + sessionTrackingModes.add(mode.name()); + } + module.checkSessionTrackingModes(sessionTrackingModes); + } + } + } + + @Override + public int order() { + // apply this instrumentation after the regular servlet one. + return 1; + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaHttpServletRequestCallSite.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaHttpServletRequestCallSite.java new file mode 100644 index 00000000000..5a99087d3b4 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaHttpServletRequestCallSite.java @@ -0,0 +1,67 @@ +package datadog.trace.instrumentation.servlet6; + +import datadog.trace.agent.tooling.csi.CallSite; +import datadog.trace.api.iast.IastCallSites; +import datadog.trace.api.iast.IastContext; +import datadog.trace.api.iast.InstrumentationBridge; +import datadog.trace.api.iast.Source; +import datadog.trace.api.iast.SourceTypes; +import datadog.trace.api.iast.propagation.PropagationModule; +import datadog.trace.bootstrap.instrumentation.api.AgentTracer; +import jakarta.servlet.http.HttpServletRequest; + +/** + * Calls to these methods are often triggered outside of customer code, we use call sites to avoid + * all these unwanted tainting operations + */ +@CallSite(spi = IastCallSites.class) +public class JakartaHttpServletRequestCallSite { + + @Source(SourceTypes.REQUEST_PATH) + @CallSite.After("java.lang.String jakarta.servlet.http.HttpServletRequest.getRequestURI()") + @CallSite.After("java.lang.String jakarta.servlet.http.HttpServletRequestWrapper.getRequestURI()") + @CallSite.After("java.lang.String jakarta.servlet.http.HttpServletRequest.getPathInfo()") + @CallSite.After("java.lang.String jakarta.servlet.http.HttpServletRequestWrapper.getPathInfo()") + @CallSite.After("java.lang.String jakarta.servlet.http.HttpServletRequest.getPathTranslated()") + @CallSite.After( + "java.lang.String jakarta.servlet.http.HttpServletRequestWrapper.getPathTranslated()") + public static String afterPath( + @CallSite.This final HttpServletRequest self, @CallSite.Return final String retValue) { + if (null != retValue && !retValue.isEmpty()) { + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module != null) { + try { + final IastContext ctx = IastContext.Provider.get(AgentTracer.activeSpan()); + if (ctx != null) { + module.taintString(ctx, retValue, SourceTypes.REQUEST_PATH); + } + } catch (final Throwable e) { + module.onUnexpectedException("afterPath threw", e); + } + } + } + return retValue; + } + + @Source(SourceTypes.REQUEST_URI) + @CallSite.After("java.lang.StringBuffer jakarta.servlet.http.HttpServletRequest.getRequestURL()") + @CallSite.After( + "java.lang.StringBuffer jakarta.servlet.http.HttpServletRequestWrapper.getRequestURL()") + public static StringBuffer afterGetRequestURL( + @CallSite.This final HttpServletRequest self, @CallSite.Return final StringBuffer retValue) { + if (null != retValue && retValue.length() > 0) { + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module != null) { + try { + final IastContext ctx = IastContext.Provider.get(AgentTracer.activeSpan()); + if (ctx != null) { + module.taintObject(ctx, retValue, SourceTypes.REQUEST_URI); + } + } catch (final Throwable e) { + module.onUnexpectedException("afterGetRequestURL threw", e); + } + } + } + return retValue; + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaHttpServletRequestInstrumentation.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaHttpServletRequestInstrumentation.java new file mode 100644 index 00000000000..08182203847 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaHttpServletRequestInstrumentation.java @@ -0,0 +1,326 @@ +package datadog.trace.instrumentation.servlet6; + +import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.extendsClass; +import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.implementsInterface; +import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.named; +import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.namedOneOf; +import static net.bytebuddy.matcher.ElementMatchers.isMethod; +import static net.bytebuddy.matcher.ElementMatchers.not; +import static net.bytebuddy.matcher.ElementMatchers.takesArguments; + +import com.google.auto.service.AutoService; +import datadog.trace.advice.ActiveRequestContext; +import datadog.trace.advice.RequiresRequestContext; +import datadog.trace.agent.tooling.Instrumenter; +import datadog.trace.agent.tooling.InstrumenterModule; +import datadog.trace.agent.tooling.iast.TaintableEnumeration; +import datadog.trace.api.gateway.RequestContext; +import datadog.trace.api.gateway.RequestContextSlot; +import datadog.trace.api.iast.IastContext; +import datadog.trace.api.iast.InstrumentationBridge; +import datadog.trace.api.iast.Sink; +import datadog.trace.api.iast.Source; +import datadog.trace.api.iast.SourceTypes; +import datadog.trace.api.iast.VulnerabilityTypes; +import datadog.trace.api.iast.propagation.PropagationModule; +import datadog.trace.api.iast.sink.UnvalidatedRedirectModule; +import jakarta.servlet.http.Cookie; +import java.util.Enumeration; +import java.util.Map; +import net.bytebuddy.asm.Advice; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; + +@SuppressWarnings("unused") +@AutoService(InstrumenterModule.class) +public class JakartaHttpServletRequestInstrumentation extends InstrumenterModule.Iast + implements Instrumenter.ForTypeHierarchy, Instrumenter.HasMethodAdvice { + + private static final String CLASS_NAME = JakartaHttpServletRequestInstrumentation.class.getName(); + private static final ElementMatcher.Junction WRAPPER_CLASS = + named("jakarta.servlet.http.HttpServletRequestWrapper"); + + public JakartaHttpServletRequestInstrumentation() { + super("servlet", "servlet-6", "servlet-request"); + } + + @Override + public String hierarchyMarkerType() { + return "jakarta.servlet.http.HttpServletRequest"; + } + + @Override + public ElementMatcher hierarchyMatcher() { + return implementsInterface(named(hierarchyMarkerType())) + .and(not(WRAPPER_CLASS)) + .and(not(extendsClass(WRAPPER_CLASS))); + } + + @Override + public String[] helperClassNames() { + return new String[] {"datadog.trace.agent.tooling.iast.TaintableEnumeration"}; + } + + @Override + public void methodAdvice(MethodTransformer transformer) { + transformer.applyAdvice( + isMethod().and(named("getHeader")).and(takesArguments(String.class)), + CLASS_NAME + "$GetHeaderAdvice"); + transformer.applyAdvice( + isMethod().and(named("getHeaders")).and(takesArguments(String.class)), + CLASS_NAME + "$GetHeadersAdvice"); + transformer.applyAdvice( + isMethod().and(named("getHeaderNames")).and(takesArguments(0)), + CLASS_NAME + "$GetHeaderNamesAdvice"); + transformer.applyAdvice( + isMethod().and(named("getParameter")).and(takesArguments(String.class)), + CLASS_NAME + "$GetParameterAdvice"); + transformer.applyAdvice( + isMethod().and(named("getParameterValues")).and(takesArguments(String.class)), + CLASS_NAME + "$GetParameterValuesAdvice"); + transformer.applyAdvice( + isMethod().and(named("getParameterMap")).and(takesArguments(0)), + CLASS_NAME + "$GetParameterMapAdvice"); + transformer.applyAdvice( + isMethod().and(named("getParameterNames")).and(takesArguments(0)), + CLASS_NAME + "$GetParameterNamesAdvice"); + transformer.applyAdvice( + isMethod().and(named("getCookies")).and(takesArguments(0)), + CLASS_NAME + "$GetCookiesAdvice"); + transformer.applyAdvice( + isMethod().and(named("getQueryString")).and(takesArguments(0)), + CLASS_NAME + "$GetQueryStringAdvice"); + transformer.applyAdvice( + isMethod().and(namedOneOf("getInputStream", "getReader")).and(takesArguments(0)), + CLASS_NAME + "$GetBodyAdvice"); + transformer.applyAdvice( + isMethod().and(named("getRequestDispatcher")).and(takesArguments(String.class)), + CLASS_NAME + "$GetRequestDispatcherAdvice"); + } + + @RequiresRequestContext(RequestContextSlot.IAST) + public static class GetHeaderAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + @Source(SourceTypes.REQUEST_HEADER_VALUE) + public static void onExit( + @Advice.Argument(0) final String name, + @Advice.Return final String value, + @ActiveRequestContext RequestContext reqCtx) { + if (value == null) { + return; + } + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module == null) { + return; + } + IastContext ctx = reqCtx.getData(RequestContextSlot.IAST); + module.taintString(ctx, value, SourceTypes.REQUEST_HEADER_VALUE, name); + } + } + + @RequiresRequestContext(RequestContextSlot.IAST) + public static class GetHeadersAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + @Source(SourceTypes.REQUEST_HEADER_VALUE) + public static void onExit( + @Advice.Argument(0) final String name, + @Advice.Return(readOnly = false) Enumeration enumeration, + @ActiveRequestContext RequestContext reqCtx) { + if (enumeration == null) { + return; + } + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module == null) { + return; + } + IastContext ctx = reqCtx.getData(RequestContextSlot.IAST); + enumeration = + TaintableEnumeration.wrap( + ctx, enumeration, module, SourceTypes.REQUEST_HEADER_VALUE, name); + } + } + + @RequiresRequestContext(RequestContextSlot.IAST) + public static class GetHeaderNamesAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + @Source(SourceTypes.REQUEST_HEADER_NAME) + public static void onExit( + @Advice.Return(readOnly = false) Enumeration enumeration, + @ActiveRequestContext RequestContext reqCtx) { + if (enumeration == null) { + return; + } + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module == null) { + return; + } + IastContext ctx = reqCtx.getData(RequestContextSlot.IAST); + enumeration = + TaintableEnumeration.wrap( + ctx, enumeration, module, SourceTypes.REQUEST_HEADER_NAME, true); + } + } + + @RequiresRequestContext(RequestContextSlot.IAST) + public static class GetParameterAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + @Source(SourceTypes.REQUEST_PARAMETER_VALUE) + public static void onExit( + @Advice.Argument(0) final String name, + @Advice.Return final String value, + @ActiveRequestContext RequestContext reqCtx) { + if (value == null) { + return; + } + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module == null) { + return; + } + IastContext ctx = reqCtx.getData(RequestContextSlot.IAST); + module.taintString(ctx, value, SourceTypes.REQUEST_PARAMETER_VALUE, name); + } + } + + @RequiresRequestContext(RequestContextSlot.IAST) + public static class GetParameterValuesAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + @Source(SourceTypes.REQUEST_PARAMETER_VALUE) + public static void onExit( + @Advice.Argument(0) final String name, + @Advice.Return final String[] values, + @ActiveRequestContext RequestContext reqCtx) { + if (values == null || values.length == 0) { + return; + } + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module == null) { + return; + } + final IastContext ctx = reqCtx.getData(RequestContextSlot.IAST); + for (final String value : values) { + module.taintString(ctx, value, SourceTypes.REQUEST_PARAMETER_VALUE, name); + } + } + } + + @RequiresRequestContext(RequestContextSlot.IAST) + public static class GetParameterMapAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + @Source(SourceTypes.REQUEST_PARAMETER_VALUE) + public static void onExit( + @Advice.Return final Map parameters, + @ActiveRequestContext RequestContext reqCtx) { + if (parameters == null || parameters.isEmpty()) { + return; + } + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module == null) { + return; + } + final IastContext ctx = reqCtx.getData(RequestContextSlot.IAST); + for (final Map.Entry entry : parameters.entrySet()) { + final String name = entry.getKey(); + module.taintString(ctx, name, SourceTypes.REQUEST_PARAMETER_NAME, name); + final String[] values = entry.getValue(); + if (values != null) { + for (final String value : entry.getValue()) { + module.taintString(ctx, value, SourceTypes.REQUEST_PARAMETER_VALUE, name); + } + } + } + } + } + + @RequiresRequestContext(RequestContextSlot.IAST) + public static class GetParameterNamesAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + @Source(SourceTypes.REQUEST_PARAMETER_NAME) + public static void onExit( + @Advice.Return(readOnly = false) Enumeration enumeration, + @ActiveRequestContext RequestContext reqCtx) { + if (enumeration == null) { + return; + } + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module == null) { + return; + } + IastContext ctx = reqCtx.getData(RequestContextSlot.IAST); + enumeration = + TaintableEnumeration.wrap( + ctx, enumeration, module, SourceTypes.REQUEST_PARAMETER_NAME, true); + } + } + + @RequiresRequestContext(RequestContextSlot.IAST) + public static class GetCookiesAdvice { + + @Advice.OnMethodExit(suppress = Throwable.class) + @Source(SourceTypes.REQUEST_COOKIE_VALUE) + public static void onExit( + @Advice.Return final Cookie[] cookies, @ActiveRequestContext RequestContext reqCtx) { + if (cookies == null || cookies.length == 0) { + return; + } + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module == null) { + return; + } + final IastContext ctx = reqCtx.getData(RequestContextSlot.IAST); + for (final Cookie cookie : cookies) { + module.taintObject(ctx, cookie, SourceTypes.REQUEST_COOKIE_VALUE); + } + } + } + + @RequiresRequestContext(RequestContextSlot.IAST) + public static class GetQueryStringAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + @Source(SourceTypes.REQUEST_QUERY) + public static void onExit( + @Advice.Return final String queryString, @ActiveRequestContext RequestContext reqCtx) { + if (queryString == null) { + return; + } + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module == null) { + return; + } + IastContext ctx = reqCtx.getData(RequestContextSlot.IAST); + module.taintString(ctx, queryString, SourceTypes.REQUEST_QUERY); + } + } + + @RequiresRequestContext(RequestContextSlot.IAST) + public static class GetBodyAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + @Source(SourceTypes.REQUEST_BODY) + public static void onExit( + @Advice.Return final Object body, @ActiveRequestContext RequestContext reqCtx) { + if (body == null) { + return; + } + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module == null) { + return; + } + IastContext ctx = reqCtx.getData(RequestContextSlot.IAST); + module.taintObject(ctx, body, SourceTypes.REQUEST_BODY); + } + } + + public static class GetRequestDispatcherAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + @Sink(VulnerabilityTypes.UNVALIDATED_REDIRECT) + public static void onExit(@Advice.Argument(0) final String path) { + if (path == null) { + return; + } + final UnvalidatedRedirectModule module = InstrumentationBridge.UNVALIDATED_REDIRECT; + if (module == null) { + return; + } + module.onRedirect(path); + } + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaHttpServletResponseInstrumentation60.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaHttpServletResponseInstrumentation60.java new file mode 100644 index 00000000000..2f488b2454e --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaHttpServletResponseInstrumentation60.java @@ -0,0 +1,149 @@ +package datadog.trace.instrumentation.servlet6; + +import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.extendsClass; +import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.implementsInterface; +import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.named; +import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.namedOneOf; +import static net.bytebuddy.matcher.ElementMatchers.isPublic; +import static net.bytebuddy.matcher.ElementMatchers.not; +import static net.bytebuddy.matcher.ElementMatchers.returns; +import static net.bytebuddy.matcher.ElementMatchers.takesArgument; +import static net.bytebuddy.matcher.ElementMatchers.takesArguments; + +import com.google.auto.service.AutoService; +import datadog.trace.agent.tooling.Instrumenter; +import datadog.trace.agent.tooling.InstrumenterModule; +import datadog.trace.api.iast.InstrumentationBridge; +import datadog.trace.api.iast.Propagation; +import datadog.trace.api.iast.Sink; +import datadog.trace.api.iast.VulnerabilityTypes; +import datadog.trace.api.iast.propagation.PropagationModule; +import datadog.trace.api.iast.sink.HttpResponseHeaderModule; +import datadog.trace.api.iast.sink.UnvalidatedRedirectModule; +import datadog.trace.api.iast.util.Cookie; +import net.bytebuddy.asm.Advice; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; + +@AutoService(InstrumenterModule.class) +public final class JakartaHttpServletResponseInstrumentation60 extends InstrumenterModule.Iast + implements Instrumenter.ForTypeHierarchy, Instrumenter.HasMethodAdvice { + public JakartaHttpServletResponseInstrumentation60() { + super("servlet", "servlet-6", "servlet-response"); + } + + @Override + public String hierarchyMarkerType() { + return "jakarta.servlet.http.HttpServletResponse"; + } + + @Override + public ElementMatcher hierarchyMatcher() { + return implementsInterface(named(hierarchyMarkerType())) + .and(not(extendsClass(named("jakarta.servlet.http.HttpServletResponseWrapper")))); + } + + @Override + protected boolean isOptOutEnabled() { + return true; + } + + @Override + public void methodAdvice(MethodTransformer transformer) { + transformer.applyAdvice( + named("addCookie") + .and(takesArguments(1)) + .and(takesArgument(0, named("jakarta.servlet.http.Cookie"))), + getClass().getName() + "$AddCookieAdvice"); + transformer.applyAdvice( + namedOneOf("setHeader", "addHeader").and(takesArguments(String.class, String.class)), + getClass().getName() + "$AddHeaderAdvice"); + transformer.applyAdvice( + namedOneOf("encodeRedirectURL", "encodeURL") + .and(takesArgument(0, String.class)) + .and(returns(String.class)), + getClass().getName() + "$EncodeURLAdvice"); + transformer.applyAdvice( + named("sendRedirect").and(takesArguments(1)).and(takesArgument(0, String.class)), + getClass().getName() + "$SendRedirectAdvice"); + transformer.applyAdvice( + named("sendRedirect") + .and(takesArguments(String.class, int.class, boolean.class)) + .and(isPublic()), + getClass().getName() + "$SendRedirect3ArgAdvice"); + } + + public static class AddCookieAdvice { + @Advice.OnMethodEnter(suppress = Throwable.class) + @Sink(VulnerabilityTypes.RESPONSE_HEADER) + public static void onEnter(@Advice.Argument(0) final jakarta.servlet.http.Cookie cookie) { + if (cookie != null) { + HttpResponseHeaderModule mod = InstrumentationBridge.RESPONSE_HEADER_MODULE; + if (mod != null) { + String sameSite = cookie.getAttribute("SameSite"); + mod.onCookie( + Cookie.named(cookie.getName()) + .value(cookie.getValue()) + .secure(cookie.getSecure()) + .httpOnly(cookie.isHttpOnly()) + .maxAge(cookie.getMaxAge()) + .sameSite(sameSite) + .build()); + } + } + } + } + + public static class AddHeaderAdvice { + @Advice.OnMethodEnter(suppress = Throwable.class) + @Sink(VulnerabilityTypes.RESPONSE_HEADER) + public static void onEnter( + @Advice.Argument(0) final String name, @Advice.Argument(1) String value) { + if (null != value && !value.isEmpty()) { + HttpResponseHeaderModule mod = InstrumentationBridge.RESPONSE_HEADER_MODULE; + if (mod != null) { + mod.onHeader(name, value); + } + } + } + } + + public static class SendRedirectAdvice { + @Advice.OnMethodEnter(suppress = Throwable.class) + @Sink(VulnerabilityTypes.UNVALIDATED_REDIRECT) + public static void onEnter(@Advice.Argument(0) final String location) { + final UnvalidatedRedirectModule module = InstrumentationBridge.UNVALIDATED_REDIRECT; + if (module != null) { + if (null != location && !location.isEmpty()) { + module.onRedirect(location); + } + } + } + } + + public static class SendRedirect3ArgAdvice { + @Advice.OnMethodEnter(suppress = Throwable.class) + @Sink(VulnerabilityTypes.UNVALIDATED_REDIRECT) + public static void onEnter(@Advice.Argument(0) final String location) { + final UnvalidatedRedirectModule module = InstrumentationBridge.UNVALIDATED_REDIRECT; + if (module != null) { + if (null != location && !location.isEmpty()) { + module.onRedirect(location); + } + } + } + } + + public static class EncodeURLAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + @Propagation + public static void onExit(@Advice.Argument(0) final String url, @Advice.Return String encoded) { + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module != null) { + if (null != url && !url.isEmpty() && null != encoded && !encoded.isEmpty()) { + module.taintStringIfTainted(encoded, url); + } + } + } + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaHttpSessionInstrumentation.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaHttpSessionInstrumentation.java new file mode 100644 index 00000000000..3b705281b9b --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaHttpSessionInstrumentation.java @@ -0,0 +1,56 @@ +package datadog.trace.instrumentation.servlet6; + +import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.implementsInterface; +import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.named; +import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.namedOneOf; +import static net.bytebuddy.matcher.ElementMatchers.isPublic; +import static net.bytebuddy.matcher.ElementMatchers.not; +import static net.bytebuddy.matcher.ElementMatchers.takesArguments; + +import datadog.trace.agent.tooling.Instrumenter; +import datadog.trace.agent.tooling.InstrumenterModule; +import datadog.trace.api.iast.InstrumentationBridge; +import datadog.trace.api.iast.Sink; +import datadog.trace.api.iast.VulnerabilityTypes; +import datadog.trace.api.iast.sink.TrustBoundaryViolationModule; +import net.bytebuddy.asm.Advice; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; + +public class JakartaHttpSessionInstrumentation extends InstrumenterModule.Iast + implements Instrumenter.ForTypeHierarchy, Instrumenter.HasMethodAdvice { + public JakartaHttpSessionInstrumentation() { + super("servlet", "servlet-6", "servlet-session"); + } + + @Override + public String hierarchyMarkerType() { + return "jakarta.servlet.http.HttpSession"; + } + + @Override + public ElementMatcher hierarchyMatcher() { + return implementsInterface(named(hierarchyMarkerType())) + .and(not(named("com.ibm.ws.session.HttpSessionFacade"))); + } + + @Override + public void methodAdvice(MethodTransformer transformer) { + transformer.applyAdvice( + namedOneOf("setAttribute", "putValue") + .and(takesArguments(String.class, Object.class).and(isPublic())), + getClass().getName() + "$InstrumenterAdvice"); + } + + public static class InstrumenterAdvice { + @Advice.OnMethodEnter(suppress = Throwable.class) + @Sink(VulnerabilityTypes.TRUST_BOUNDARY_VIOLATION) + public static void onEnter( + @Advice.Argument(0) final String name, @Advice.Argument(1) final Object value) { + TrustBoundaryViolationModule mod = InstrumentationBridge.TRUST_BOUNDARY_VIOLATION; + if (mod != null) { + mod.onSessionValue(name, value); + } + } + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaMultipartInstrumentation.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaMultipartInstrumentation.java new file mode 100644 index 00000000000..4ee3b05be55 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaMultipartInstrumentation.java @@ -0,0 +1,153 @@ +package datadog.trace.instrumentation.servlet6; + +import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.implementsInterface; +import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.named; +import static net.bytebuddy.matcher.ElementMatchers.isPublic; +import static net.bytebuddy.matcher.ElementMatchers.takesArguments; + +import com.google.auto.service.AutoService; +import datadog.trace.advice.ActiveRequestContext; +import datadog.trace.advice.RequiresRequestContext; +import datadog.trace.agent.tooling.Instrumenter; +import datadog.trace.agent.tooling.InstrumenterModule; +import datadog.trace.api.gateway.RequestContext; +import datadog.trace.api.gateway.RequestContextSlot; +import datadog.trace.api.iast.IastContext; +import datadog.trace.api.iast.InstrumentationBridge; +import datadog.trace.api.iast.Source; +import datadog.trace.api.iast.SourceTypes; +import datadog.trace.api.iast.propagation.PropagationModule; +import java.io.InputStream; +import java.util.Collection; +import net.bytebuddy.asm.Advice; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; + +@AutoService(InstrumenterModule.class) +public class JakartaMultipartInstrumentation extends InstrumenterModule.Iast + implements Instrumenter.ForTypeHierarchy, Instrumenter.HasMethodAdvice { + + public JakartaMultipartInstrumentation() { + super("servlet", "servlet-6", "multipart"); + } + + @Override + public String hierarchyMarkerType() { + return "jakarta.servlet.http.Part"; + } + + @Override + public ElementMatcher hierarchyMatcher() { + return implementsInterface(named(hierarchyMarkerType())); + } + + @Override + public void methodAdvice(MethodTransformer transformer) { + transformer.applyAdvice( + named("getName").and(isPublic()).and(takesArguments(0)), + getClass().getName() + "$GetNameAdvice"); + transformer.applyAdvice( + named("getHeader").and(isPublic()).and(takesArguments(String.class)), + getClass().getName() + "$GetHeaderAdvice"); + transformer.applyAdvice( + named("getHeaders").and(isPublic()).and(takesArguments(String.class)), + getClass().getName() + "$GetHeadersAdvice"); + transformer.applyAdvice( + named("getHeaderNames").and(isPublic()).and(takesArguments(0)), + getClass().getName() + "$GetHeaderNamesAdvice"); + transformer.applyAdvice( + named("getInputStream").and(isPublic()).and(takesArguments(0)), + getClass().getName() + "$GetInputStreamAdvice"); + } + + @RequiresRequestContext(RequestContextSlot.IAST) + public static class GetNameAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + @Source(SourceTypes.REQUEST_MULTIPART_PARAMETER) + public static String onExit( + @Advice.Return final String name, @ActiveRequestContext RequestContext reqCtx) { + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module != null) { + final IastContext ctx = reqCtx.getData(RequestContextSlot.IAST); + module.taintString( + ctx, name, SourceTypes.REQUEST_MULTIPART_PARAMETER, "Content-Disposition"); + } + return name; + } + } + + @RequiresRequestContext(RequestContextSlot.IAST) + public static class GetHeaderAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + @Source(SourceTypes.REQUEST_MULTIPART_PARAMETER) + public static String onExit( + @Advice.Return final String value, + @Advice.Argument(0) final String name, + @ActiveRequestContext RequestContext reqCtx) { + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module != null) { + final IastContext ctx = reqCtx.getData(RequestContextSlot.IAST); + module.taintString(ctx, value, SourceTypes.REQUEST_MULTIPART_PARAMETER, name); + } + return value; + } + } + + @RequiresRequestContext(RequestContextSlot.IAST) + public static class GetHeadersAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + @Source(SourceTypes.REQUEST_MULTIPART_PARAMETER) + public static void onExit( + @Advice.Argument(0) final String headerName, + @Advice.Return Collection headerValues, + @ActiveRequestContext RequestContext reqCtx) { + if (null == headerValues || headerValues.isEmpty()) { + return; + } + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module != null) { + final IastContext ctx = reqCtx.getData(RequestContextSlot.IAST); + for (final String value : headerValues) { + module.taintString(ctx, value, SourceTypes.REQUEST_MULTIPART_PARAMETER, headerName); + } + } + } + } + + @RequiresRequestContext(RequestContextSlot.IAST) + public static class GetHeaderNamesAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + @Source(SourceTypes.REQUEST_MULTIPART_PARAMETER) + public static void onExit( + @Advice.Return final Collection headerNames, + @ActiveRequestContext RequestContext reqCtx) { + if (null == headerNames || headerNames.isEmpty()) { + return; + } + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module != null) { + final IastContext ctx = reqCtx.getData(RequestContextSlot.IAST); + for (final String name : headerNames) { + module.taintString(ctx, name, SourceTypes.REQUEST_MULTIPART_PARAMETER); + } + } + } + } + + @RequiresRequestContext(RequestContextSlot.IAST) + public static class GetInputStreamAdvice { + @Advice.OnMethodExit(suppress = Throwable.class) + @Source(SourceTypes.REQUEST_MULTIPART_PARAMETER) + public static void onExit( + @Advice.Return final InputStream inputStream, @ActiveRequestContext RequestContext reqCtx) { + if (null == inputStream) { + return; + } + final PropagationModule module = InstrumentationBridge.PROPAGATION; + if (module != null) { + final IastContext ctx = reqCtx.getData(RequestContextSlot.IAST); + module.taintObject(ctx, inputStream, SourceTypes.REQUEST_MULTIPART_PARAMETER); + } + } + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaServletBlockingHelper.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaServletBlockingHelper.java new file mode 100644 index 00000000000..8714bb10a5e --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaServletBlockingHelper.java @@ -0,0 +1,89 @@ +package datadog.trace.instrumentation.servlet6; + +import datadog.appsec.api.blocking.BlockingContentType; +import datadog.trace.api.gateway.Flow; +import datadog.trace.api.internal.TraceSegment; +import datadog.trace.bootstrap.blocking.BlockingActionHelper; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.io.OutputStream; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class JakartaServletBlockingHelper { + private static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; + private static final Logger log = LoggerFactory.getLogger(JakartaServletBlockingHelper.class); + + public static void commitBlockingResponse( + TraceSegment segment, + HttpServletRequest httpServletRequest, + HttpServletResponse resp, + int statusCode_, + BlockingContentType bct, + Map extraHeaders, + String securityResponseId) { + int statusCode = BlockingActionHelper.getHttpCode(statusCode_); + if (!start(resp, statusCode)) { + return; + } + + for (Map.Entry h : extraHeaders.entrySet()) { + resp.setHeader(h.getKey(), h.getValue()); + } + + byte[] template; + if (bct != BlockingContentType.NONE) { + String acceptHeader = httpServletRequest.getHeader("Accept"); + BlockingActionHelper.TemplateType type = + BlockingActionHelper.determineTemplateType(bct, acceptHeader); + template = BlockingActionHelper.getTemplate(type, securityResponseId); + String contentType = BlockingActionHelper.getContentType(type); + + resp.setHeader("Content-length", Integer.toString(template.length)); + resp.setHeader("Content-type", contentType); + } else { + template = EMPTY_BYTE_ARRAY; + } + segment.effectivelyBlocked(); + + try { + OutputStream os = resp.getOutputStream(); + os.write(template); + os.close(); + } catch (IOException e) { + log.warn("Error sending error page", e); + } + } + + public static void commitBlockingResponse( + TraceSegment segment, + HttpServletRequest httpServletRequest, + HttpServletResponse resp, + Flow.Action.RequestBlockingAction rba) { + + commitBlockingResponse( + segment, + httpServletRequest, + resp, + rba.getStatusCode(), + rba.getBlockingContentType(), + rba.getExtraHeaders(), + rba.getSecurityResponseId()); + } + + private static boolean start(HttpServletResponse resp, int statusCode) { + if (resp.isCommitted()) { + log.warn("response already committed, we can't change it"); + return false; + } + + log.debug("Committing blocking response"); + + resp.reset(); + resp.setStatus(statusCode); + + return true; + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaServletInstrumentation.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaServletInstrumentation.java new file mode 100644 index 00000000000..b5eb242298a --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/JakartaServletInstrumentation.java @@ -0,0 +1,157 @@ +package datadog.trace.instrumentation.servlet6; + +import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.hasSuperType; +import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.implementsInterface; +import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.named; +import static datadog.trace.bootstrap.instrumentation.api.Java8BytecodeBridge.spanFromContext; +import static datadog.trace.bootstrap.instrumentation.decorator.HttpServerDecorator.DD_CONTEXT_ATTRIBUTE; +import static datadog.trace.bootstrap.instrumentation.decorator.HttpServerDecorator.DD_RUM_INJECTED; +import static net.bytebuddy.matcher.ElementMatchers.isMethod; +import static net.bytebuddy.matcher.ElementMatchers.isPublic; +import static net.bytebuddy.matcher.ElementMatchers.takesArgument; +import static net.bytebuddy.matcher.ElementMatchers.takesArguments; + +import com.google.auto.service.AutoService; +import datadog.context.Context; +import datadog.trace.agent.tooling.Instrumenter; +import datadog.trace.agent.tooling.InstrumenterModule; +import datadog.trace.api.ClassloaderConfigurationOverrides; +import datadog.trace.api.Config; +import datadog.trace.api.DDTags; +import datadog.trace.api.rum.RumInjector; +import datadog.trace.bootstrap.CallDepthThreadLocalMap; +import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.rum.RumControllableResponse; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import net.bytebuddy.asm.Advice; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; + +@AutoService(InstrumenterModule.class) +public class JakartaServletInstrumentation extends InstrumenterModule.Tracing + implements Instrumenter.ForTypeHierarchy, Instrumenter.HasMethodAdvice { + public JakartaServletInstrumentation() { + super("servlet", "servlet-6"); + } + + @Override + public String hierarchyMarkerType() { + return "jakarta.servlet.http.HttpServlet"; + } + + @Override + public String[] helperClassNames() { + return new String[] { + packageName + ".RumHttpServletRequestWrapper", + packageName + ".RumHttpServletResponseWrapper", + packageName + ".RumHttpServletResponseWrapper60", + packageName + ".WrappedServletOutputStream", + }; + } + + @Override + public ElementMatcher hierarchyMatcher() { + return hasSuperType(named(hierarchyMarkerType())) + .or(implementsInterface(named("jakarta.servlet.FilterChain"))); + } + + @Override + public void methodAdvice(MethodTransformer transformer) { + transformer.applyAdvice( + isMethod() + .and(named("service")) + .and(isPublic()) + .and(takesArguments(2)) + .and(takesArgument(0, named("jakarta.servlet.ServletRequest"))) + .and(takesArgument(1, named("jakarta.servlet.ServletResponse"))), + getClass().getName() + "$JakartaServletAdvice"); + } + + public static class JakartaServletAdvice { + @Advice.OnMethodEnter(suppress = Throwable.class) + public static AgentSpan before( + @Advice.Argument(value = 0, readOnly = false) ServletRequest request, + @Advice.Argument(value = 1, readOnly = false) ServletResponse response, + @Advice.Local("rumServletWrapper") RumControllableResponse rumServletWrapper) { + if (!(request instanceof HttpServletRequest)) { + return null; + } + + if (response instanceof HttpServletResponse) { + final HttpServletRequest httpServletRequest = (HttpServletRequest) request; + + if (RumInjector.get().isEnabled()) { + final Object maybeRumWrapper = httpServletRequest.getAttribute(DD_RUM_INJECTED); + if (maybeRumWrapper instanceof RumControllableResponse) { + rumServletWrapper = (RumControllableResponse) maybeRumWrapper; + } else { + rumServletWrapper = + new RumHttpServletResponseWrapper60( + httpServletRequest, (HttpServletResponse) response); + httpServletRequest.setAttribute(DD_RUM_INJECTED, rumServletWrapper); + response = (ServletResponse) rumServletWrapper; + request = + new RumHttpServletRequestWrapper( + httpServletRequest, (HttpServletResponse) rumServletWrapper); + } + } + } + + Object contextAttr = request.getAttribute(DD_CONTEXT_ATTRIBUTE); + if (contextAttr instanceof Context + && CallDepthThreadLocalMap.incrementCallDepth(HttpServletRequest.class) == 0) { + final Context context = (Context) contextAttr; + final AgentSpan span = spanFromContext(context); + if (span != null) { + ClassloaderConfigurationOverrides.maybeEnrichSpan(span); + return span; + } + } + return null; + } + + @Advice.OnMethodExit(onThrowable = Throwable.class, suppress = Throwable.class) + public static void after( + @Advice.Enter final AgentSpan span, + @Advice.Argument(0) final ServletRequest request, + @Advice.Local("rumServletWrapper") RumControllableResponse rumServletWrapper) { + if (span == null) { + return; + } + if (rumServletWrapper != null) { + rumServletWrapper.commit(); + } + + CallDepthThreadLocalMap.reset(HttpServletRequest.class); + final HttpServletRequest httpServletRequest = + (HttpServletRequest) request; // at this point the cast should be safe + if (Config.get().isServletPrincipalEnabled() + && httpServletRequest.getUserPrincipal() != null) { + span.setTag(DDTags.USER_NAME, httpServletRequest.getUserPrincipal().getName()); + } + + // Servlet 6.0 enrichment + try { + String requestId = httpServletRequest.getRequestId(); + if (requestId != null && !requestId.isEmpty()) { + span.setTag("http.request_id", requestId); + } + String protocolRequestId = httpServletRequest.getProtocolRequestId(); + if (protocolRequestId != null && !protocolRequestId.isEmpty()) { + span.setTag("network.protocol_request_id", protocolRequestId); + } + jakarta.servlet.ServletConnection conn = httpServletRequest.getServletConnection(); + if (conn != null) { + String connId = conn.getConnectionId(); + if (connId != null) span.setTag("network.connection.id", connId); + String protocol = conn.getProtocol(); + if (protocol != null) span.setTag("network.protocol.name", protocol); + } + } catch (Exception ignored) { + } + } + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/RumAsyncContextInstrumentation.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/RumAsyncContextInstrumentation.java new file mode 100644 index 00000000000..c84789d091a --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/RumAsyncContextInstrumentation.java @@ -0,0 +1,67 @@ +package datadog.trace.instrumentation.servlet6; + +import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.implementsInterface; +import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.named; +import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.namedOneOf; +import static datadog.trace.bootstrap.instrumentation.decorator.HttpServerDecorator.DD_RUM_INJECTED; +import static net.bytebuddy.matcher.ElementMatchers.isMethod; + +import com.google.auto.service.AutoService; +import datadog.trace.agent.tooling.Instrumenter; +import datadog.trace.agent.tooling.InstrumenterModule; +import datadog.trace.api.InstrumenterConfig; +import datadog.trace.bootstrap.instrumentation.rum.RumControllableResponse; +import jakarta.servlet.AsyncContext; +import net.bytebuddy.asm.Advice; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; + +@AutoService(InstrumenterModule.class) +public class RumAsyncContextInstrumentation extends InstrumenterModule.Tracing + implements Instrumenter.ForTypeHierarchy, Instrumenter.HasMethodAdvice { + + public RumAsyncContextInstrumentation() { + super("servlet", "servlet-6", "servlet-6-async-context"); + } + + @Override + public String hierarchyMarkerType() { + return "jakarta.servlet.AsyncContext"; + } + + @Override + public String[] helperClassNames() { + return new String[] { + packageName + ".RumHttpServletResponseWrapper", + packageName + ".RumHttpServletResponseWrapper60", + packageName + ".WrappedServletOutputStream", + }; + } + + @Override + public ElementMatcher hierarchyMatcher() { + return implementsInterface(named(hierarchyMarkerType())); + } + + @Override + public boolean isEnabled() { + return super.isEnabled() && InstrumenterConfig.get().isRumEnabled(); + } + + @Override + public void methodAdvice(MethodTransformer transformer) { + transformer.applyAdvice( + isMethod().and(namedOneOf("complete", "dispatch")), getClass().getName() + "$CommitAdvice"); + } + + public static class CommitAdvice { + @Advice.OnMethodEnter(suppress = Throwable.class) + public static void commitRumBuffer(@Advice.This final AsyncContext asyncContext) { + final Object maybeRumWrappedResponse = + asyncContext.getRequest().getAttribute(DD_RUM_INJECTED); + if (maybeRumWrappedResponse instanceof RumControllableResponse) { + ((RumControllableResponse) maybeRumWrappedResponse).commit(); + } + } + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/RumHttpServletRequestWrapper.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/RumHttpServletRequestWrapper.java new file mode 100644 index 00000000000..0f47851ea0c --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/RumHttpServletRequestWrapper.java @@ -0,0 +1,47 @@ +package datadog.trace.instrumentation.servlet6; + +import static datadog.trace.bootstrap.instrumentation.decorator.HttpServerDecorator.DD_RUM_INJECTED; + +import datadog.trace.bootstrap.instrumentation.rum.RumControllableResponse; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletRequestWrapper; +import jakarta.servlet.http.HttpServletResponse; + +public class RumHttpServletRequestWrapper extends HttpServletRequestWrapper { + + private final HttpServletResponse response; + + public RumHttpServletRequestWrapper( + final HttpServletRequest request, final HttpServletResponse response) { + super(request); + this.response = response; + } + + @Override + public AsyncContext startAsync() throws IllegalStateException { + // need to hide this method otherwise we cannot control the wrapped response used asynchronously + return startAsync(getRequest(), response); + } + + @Override + public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) + throws IllegalStateException { + // deactivate the previous wrapper + final Object maybeRumWrappedResponse = (servletRequest.getAttribute(DD_RUM_INJECTED)); + if (maybeRumWrappedResponse instanceof RumControllableResponse) { + ((RumControllableResponse) maybeRumWrappedResponse).commit(); + ((RumControllableResponse) maybeRumWrappedResponse).stopFiltering(); + } + ServletResponse actualResponse = servletResponse; + // rewrap it + if (servletResponse instanceof HttpServletResponse) { + actualResponse = + new RumHttpServletResponseWrapper(this, (HttpServletResponse) servletResponse); + servletRequest.setAttribute(DD_RUM_INJECTED, actualResponse); + } + return super.startAsync(servletRequest, actualResponse); + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/RumHttpServletResponseWrapper.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/RumHttpServletResponseWrapper.java new file mode 100644 index 00000000000..07b4774d510 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/RumHttpServletResponseWrapper.java @@ -0,0 +1,236 @@ +package datadog.trace.instrumentation.servlet6; + +import datadog.trace.api.rum.RumInjector; +import datadog.trace.bootstrap.instrumentation.buffer.InjectingPipeWriter; +import datadog.trace.bootstrap.instrumentation.rum.RumControllableResponse; +import jakarta.servlet.ServletContext; +import jakarta.servlet.ServletOutputStream; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import jakarta.servlet.http.HttpServletResponseWrapper; +import java.io.IOException; +import java.io.PrintWriter; +import java.nio.charset.Charset; + +public class RumHttpServletResponseWrapper extends HttpServletResponseWrapper + implements RumControllableResponse { + private final RumInjector rumInjector; + private final String servletVersion; + private WrappedServletOutputStream outputStream; + private InjectingPipeWriter wrappedPipeWriter; + private PrintWriter printWriter; + private boolean shouldInject = true; + private String contentEncoding = null; + + public RumHttpServletResponseWrapper(HttpServletRequest request, HttpServletResponse response) { + super(response); + this.rumInjector = RumInjector.get(); + + String version = "5"; + ServletContext servletContext = request.getServletContext(); + if (servletContext != null) { + try { + version = String.valueOf(servletContext.getEffectiveMajorVersion()); + } catch (Exception e) { + } + } + this.servletVersion = version; + } + + @Override + public ServletOutputStream getOutputStream() throws IOException { + if (outputStream != null) { + return outputStream; + } + if (!shouldInject) { + RumInjector.getTelemetryCollector().onInjectionSkipped(servletVersion); + return super.getOutputStream(); + } + try { + String encoding = getCharacterEncoding(); + if (encoding == null) { + encoding = Charset.defaultCharset().name(); + } + outputStream = + new WrappedServletOutputStream( + super.getOutputStream(), + rumInjector.getMarkerBytes(encoding), + rumInjector.getSnippetBytes(encoding), + this::onInjected, + bytes -> + RumInjector.getTelemetryCollector() + .onInjectionResponseSize(servletVersion, bytes), + milliseconds -> + RumInjector.getTelemetryCollector() + .onInjectionTime(servletVersion, milliseconds)); + } catch (Exception e) { + RumInjector.getTelemetryCollector().onInjectionFailed(servletVersion, contentEncoding); + throw e; + } + return outputStream; + } + + @Override + public PrintWriter getWriter() throws IOException { + if (printWriter != null) { + return printWriter; + } + if (!shouldInject) { + RumInjector.getTelemetryCollector().onInjectionSkipped(servletVersion); + return super.getWriter(); + } + try { + wrappedPipeWriter = + new InjectingPipeWriter( + super.getWriter(), + rumInjector.getMarkerChars(), + rumInjector.getSnippetChars(), + this::onInjected, + bytes -> + RumInjector.getTelemetryCollector() + .onInjectionResponseSize(servletVersion, bytes), + milliseconds -> + RumInjector.getTelemetryCollector() + .onInjectionTime(servletVersion, milliseconds)); + printWriter = new PrintWriter(wrappedPipeWriter); + } catch (Exception e) { + RumInjector.getTelemetryCollector().onInjectionFailed(servletVersion, contentEncoding); + throw e; + } + + return printWriter; + } + + @Override + public void setHeader(String name, String value) { + if (shouldInject) { + if (isContentLengthHeader(name)) { + return; + } + checkForContentType(name, value); + checkForContentSecurityPolicy(name); + } + super.setHeader(name, value); + } + + @Override + public void addHeader(String name, String value) { + if (shouldInject) { + if (isContentLengthHeader(name)) { + return; + } + checkForContentType(name, value); + checkForContentSecurityPolicy(name); + } + super.addHeader(name, value); + } + + private boolean isContentLengthHeader(String name) { + return "content-length".equalsIgnoreCase(name); + } + + private void checkForContentSecurityPolicy(String name) { + if ("content-security-policy".equalsIgnoreCase(name)) { + RumInjector.getTelemetryCollector().onContentSecurityPolicyDetected(servletVersion); + } + } + + private void checkForContentType(String name, String value) { + if ("content-type".equalsIgnoreCase(name)) { + handleContentType(value); + } + } + + @Override + public void setContentLength(int len) { + // don't set it since we don't know if we will inject + if (!shouldInject) { + super.setContentLength(len); + } + } + + @Override + public void setContentLengthLong(long len) { + if (!shouldInject) { + super.setContentLengthLong(len); + } + } + + @Override + public void setCharacterEncoding(String charset) { + if (charset != null) { + this.contentEncoding = charset; + } + super.setCharacterEncoding(charset); + } + + @Override + public void reset() { + this.outputStream = null; + this.wrappedPipeWriter = null; + this.printWriter = null; + this.shouldInject = false; + super.reset(); + } + + @Override + public void resetBuffer() { + this.outputStream = null; + this.wrappedPipeWriter = null; + this.printWriter = null; + super.resetBuffer(); + } + + public void onInjected() { + RumInjector.getTelemetryCollector().onInjectionSucceed(servletVersion); + try { + setHeader("x-datadog-rum-injected", "1"); + } catch (Throwable ignored) { + // suppress exception if arisen setting this header by us. + } + } + + private void handleContentType(String type) { + final boolean wasInjecting = shouldInject; + if (shouldInject) { + shouldInject = type != null && type.contains("text/html"); + } + if (wasInjecting && !shouldInject) { + commit(); + stopFiltering(); + } + } + + @Override + public void setContentType(String type) { + handleContentType(type); + super.setContentType(type); + } + + @Override + public void commit() { + if (wrappedPipeWriter != null) { + try { + wrappedPipeWriter.commit(); + } catch (Throwable ignored) { + } + } + if (outputStream != null) { + try { + outputStream.commit(); + } catch (Throwable ignored) { + } + } + } + + @Override + public void stopFiltering() { + shouldInject = false; + if (wrappedPipeWriter != null) { + wrappedPipeWriter.setFilter(false); + } + if (outputStream != null) { + outputStream.setFilter(false); + } + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/RumHttpServletResponseWrapper60.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/RumHttpServletResponseWrapper60.java new file mode 100644 index 00000000000..1329a90654b --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/RumHttpServletResponseWrapper60.java @@ -0,0 +1,17 @@ +package datadog.trace.instrumentation.servlet6; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.io.IOException; + +public class RumHttpServletResponseWrapper60 extends RumHttpServletResponseWrapper { + public RumHttpServletResponseWrapper60(HttpServletRequest request, HttpServletResponse response) { + super(request, response); + } + + @Override + public void sendRedirect(String location, int sc, boolean clearBuffer) throws IOException { + commit(); + super.sendRedirect(location, sc, clearBuffer); + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/Servlet6RequestBodyInstrumentation.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/Servlet6RequestBodyInstrumentation.java new file mode 100644 index 00000000000..ab680d92589 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/Servlet6RequestBodyInstrumentation.java @@ -0,0 +1,68 @@ +package datadog.trace.instrumentation.servlet6; + +import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.extendsClass; +import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.implementsInterface; +import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.named; +import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.namedNoneOf; +import static net.bytebuddy.matcher.ElementMatchers.isPublic; +import static net.bytebuddy.matcher.ElementMatchers.not; +import static net.bytebuddy.matcher.ElementMatchers.returns; +import static net.bytebuddy.matcher.ElementMatchers.takesNoArguments; + +import com.google.auto.service.AutoService; +import datadog.trace.agent.tooling.Instrumenter; +import datadog.trace.agent.tooling.InstrumenterModule; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; + +@AutoService(InstrumenterModule.class) +public class Servlet6RequestBodyInstrumentation extends InstrumenterModule.AppSec + implements Instrumenter.ForTypeHierarchy, Instrumenter.HasMethodAdvice { + public Servlet6RequestBodyInstrumentation() { + super("servlet-request-body"); + } + + @Override + public String hierarchyMarkerType() { + return "jakarta.servlet.http.HttpServletRequest"; + } + + @Override + public ElementMatcher hierarchyMatcher() { + return implementsInterface(named(hierarchyMarkerType())) + // ignore wrappers that ship with servlet-api + .and(namedNoneOf("jakarta.servlet.http.HttpServletRequestWrapper")) + .and(not(extendsClass(named("jakarta.servlet.http.HttpServletRequestWrapper")))); + } + + @Override + public void methodAdvice(MethodTransformer transformer) { + transformer.applyAdvice( + named("getInputStream") + .and(takesNoArguments()) + .and(returns(named("jakarta.servlet.ServletInputStream"))) + .and(isPublic()), + packageName + ".HttpServletGetInputStreamAdvice"); + transformer.applyAdvice( + named("getReader") + .and(takesNoArguments()) + .and(returns(named("java.io.BufferedReader"))) + .and(isPublic()), + packageName + ".HttpServletGetReaderAdvice"); + } + + @Override + public String[] helperClassNames() { + return new String[] { + "datadog.trace.instrumentation.servlet6.BufferedReaderWrapper", + "datadog.trace.instrumentation.servlet6.AbstractServletInputStreamWrapper", + "datadog.trace.instrumentation.servlet6.Servlet31InputStreamWrapper" + }; + } + + @Override + public int order() { + // apply this instrumentation after the regular servlet one. + return 1; + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/WrappedServletOutputStream.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/WrappedServletOutputStream.java new file mode 100644 index 00000000000..267cb76834c --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/WrappedServletOutputStream.java @@ -0,0 +1,68 @@ +package datadog.trace.instrumentation.servlet6; + +import datadog.trace.bootstrap.instrumentation.buffer.InjectingPipeOutputStream; +import jakarta.servlet.ServletOutputStream; +import jakarta.servlet.WriteListener; +import java.io.IOException; +import java.util.function.LongConsumer; + +public class WrappedServletOutputStream extends ServletOutputStream { + private final InjectingPipeOutputStream filtered; + private final ServletOutputStream delegate; + + public WrappedServletOutputStream( + ServletOutputStream delegate, + byte[] marker, + byte[] contentToInject, + Runnable onInjected, + LongConsumer onBytesWritten, + LongConsumer onInjectionTime) { + this.filtered = + new InjectingPipeOutputStream( + delegate, marker, contentToInject, onInjected, onBytesWritten, onInjectionTime); + this.delegate = delegate; + } + + @Override + public void write(int b) throws IOException { + filtered.write(b); + } + + @Override + public void write(byte[] b) throws IOException { + filtered.write(b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + filtered.write(b, off, len); + } + + public void commit() throws IOException { + filtered.commit(); + } + + @Override + public void flush() throws IOException { + filtered.flush(); + } + + @Override + public void close() throws IOException { + filtered.close(); + } + + @Override + public boolean isReady() { + return delegate.isReady(); + } + + @Override + public void setWriteListener(WriteListener writeListener) { + delegate.setWriteListener(writeListener); + } + + public void setFilter(boolean filter) { + filtered.setFilter(filter); + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/jsp/JakartaJspWriterCallSite.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/jsp/JakartaJspWriterCallSite.java new file mode 100644 index 00000000000..fc40d7d9ddd --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/jsp/JakartaJspWriterCallSite.java @@ -0,0 +1,28 @@ +package datadog.trace.instrumentation.servlet6.jsp; + +import datadog.trace.agent.tooling.csi.CallSite; +import datadog.trace.api.iast.IastCallSites; +import datadog.trace.api.iast.InstrumentationBridge; +import datadog.trace.api.iast.Sink; +import datadog.trace.api.iast.VulnerabilityTypes; +import datadog.trace.api.iast.sink.XssModule; +import javax.annotation.Nonnull; + +@Sink(VulnerabilityTypes.XSS) +@CallSite(spi = IastCallSites.class) +public class JakartaJspWriterCallSite { + + @CallSite.Before("void jakarta.servlet.jsp.JspWriter.print(java.lang.String)") + @CallSite.Before("void jakarta.servlet.jsp.JspWriter.println(java.lang.String)") + @CallSite.Before("void jakarta.servlet.jsp.JspWriter.write(java.lang.String)") + public static void beforeStringParam(@CallSite.Argument(0) @Nonnull final String s) { + final XssModule module = InstrumentationBridge.XSS; + if (module != null) { + try { + module.onXss(s); + } catch (final Throwable e) { + module.onUnexpectedException("beforeStringParam threw", e); + } + } + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/jsp/JakartaJspWriterFullDetectionCallSite.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/jsp/JakartaJspWriterFullDetectionCallSite.java new file mode 100644 index 00000000000..9fd5819af2e --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/main/java/datadog/trace/instrumentation/servlet6/jsp/JakartaJspWriterFullDetectionCallSite.java @@ -0,0 +1,30 @@ +package datadog.trace.instrumentation.servlet6.jsp; + +import datadog.trace.agent.tooling.csi.CallSite; +import datadog.trace.api.iast.IastCallSites; +import datadog.trace.api.iast.InstrumentationBridge; +import datadog.trace.api.iast.Sink; +import datadog.trace.api.iast.VulnerabilityTypes; +import datadog.trace.api.iast.sink.XssModule; +import javax.annotation.Nonnull; + +@Sink(VulnerabilityTypes.XSS) +@CallSite( + spi = IastCallSites.class, + enabled = {"datadog.trace.api.iast.IastEnabledChecks", "isFullDetection"}) +public class JakartaJspWriterFullDetectionCallSite { + + @CallSite.Before("void jakarta.servlet.jsp.JspWriter.print(char[])") + @CallSite.Before("void jakarta.servlet.jsp.JspWriter.println(char[])") + @CallSite.Before("void jakarta.servlet.jsp.JspWriter.write(char[])") + public static void beforeCharArrayParam(@CallSite.Argument(0) @Nonnull final char[] buf) { + final XssModule module = InstrumentationBridge.XSS; + if (module != null) { + try { + module.onXss(buf); + } catch (final Throwable e) { + module.onUnexpectedException("beforeCharArrayParam threw", e); + } + } + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/IastJakartaServletInstrumentationTest.groovy b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/IastJakartaServletInstrumentationTest.groovy new file mode 100644 index 00000000000..fc7a6891b1e --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/IastJakartaServletInstrumentationTest.groovy @@ -0,0 +1,48 @@ +import datadog.trace.agent.test.InstrumentationSpecification +import datadog.trace.api.iast.InstrumentationBridge +import datadog.trace.api.iast.sink.ApplicationModule +import foo.bar.smoketest.DummyHttpServlet +import foo.bar.smoketest.DummyRequest +import foo.bar.smoketest.DummyResponse +import jakarta.servlet.Servlet +import jakarta.servlet.ServletRequest +import jakarta.servlet.ServletResponse + +class IastJakartaServletInstrumentationTest extends InstrumentationSpecification{ + + @Override + protected void configurePreAgent() { + injectSysConfig("dd.iast.enabled", "true") + } + + void 'test no modules'() { + final appModule = Mock(ApplicationModule) + final Servlet servlet = new DummyHttpServlet() + final ServletResponse response = new DummyResponse() + final ServletRequest request = new DummyRequest() + + when: + servlet.callPublicServiceMethod(request, response) + + then: + 0 * appModule.onRealPath(_) + 0 * appModule.checkSessionTrackingModes(_) + 0 * _ + } + + void 'test ApplicationModule'() { + given: + final module = Mock(ApplicationModule) + InstrumentationBridge.registerIastModule(module) + final Servlet servlet = new DummyHttpServlet() + final ServletResponse response = new DummyResponse() + final ServletRequest request = new DummyRequest() + + when: + servlet.callPublicServiceMethod(request, response) + + then: + 1 * module.onRealPath(_) + 0 * _ + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/JakartaHttpServletRequestInstrumentationTest.groovy b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/JakartaHttpServletRequestInstrumentationTest.groovy new file mode 100644 index 00000000000..f8d0d1811bf --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/JakartaHttpServletRequestInstrumentationTest.groovy @@ -0,0 +1,507 @@ +import datadog.trace.agent.test.InstrumentationSpecification +import datadog.trace.api.iast.IastContext +import datadog.trace.api.iast.InstrumentationBridge +import datadog.trace.api.iast.SourceTypes +import datadog.trace.api.iast.propagation.PropagationModule +import datadog.trace.api.iast.sink.ApplicationModule +import datadog.trace.api.iast.sink.UnvalidatedRedirectModule +import datadog.trace.bootstrap.instrumentation.api.AgentTracer +import datadog.trace.bootstrap.instrumentation.api.TagContext +import foo.bar.smoketest.JakartaHttpServletRequestTestSuite +import foo.bar.smoketest.JakartaHttpServletRequestWrapperTestSuite +import foo.bar.smoketest.ServletRequestTestSuite +import jakarta.servlet.RequestDispatcher +import jakarta.servlet.ServletContext +import jakarta.servlet.ServletInputStream +import jakarta.servlet.SessionTrackingMode +import jakarta.servlet.http.Cookie +import jakarta.servlet.http.HttpServletRequest +import jakarta.servlet.http.HttpServletRequestWrapper + +import datadog.trace.agent.tooling.iast.TaintableEnumeration +import jakarta.servlet.http.HttpSession + +class JakartaHttpServletRequestInstrumentationTest extends InstrumentationSpecification { + + private Object iastCtx + + @Override + protected void configurePreAgent() { + injectSysConfig('dd.iast.enabled', 'true') + } + + void setup() { + iastCtx = Stub(IastContext) + } + + void cleanup() { + InstrumentationBridge.clearIastModules() + } + + void 'test getHeader #iterationIndex'() { + setup: + final iastModule = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(iastModule) + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final result = runUnderIastTrace { request.getHeader('header') } + + then: + result == 'value' + 1 * mock.getHeader('header') >> 'value' + 1 * iastModule.taintString(iastCtx, 'value', SourceTypes.REQUEST_HEADER_VALUE, 'header') + 0 * _ + + where: + suite << testSuite() + } + + void 'test getHeaders #iterationIndex'() { + setup: + final iastModule = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(iastModule) + final headers = ['value1', 'value2'] + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final result = runUnderIastTrace { request.getHeaders('headers').collect() } + + then: + result == headers + 1 * mock.getHeaders('headers') >> Collections.enumeration(headers) + headers.each { 1 * iastModule.taintString(iastCtx, it, SourceTypes.REQUEST_HEADER_VALUE, 'headers') } + 0 * _ + + where: + suite << testSuite() + } + + void 'test getHeaderNames #iterationIndex'() { + setup: + final iastModule = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(iastModule) + final headers = ['header1', 'header2'] + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final result = runUnderIastTrace { request.getHeaderNames().collect() } + + then: + result == headers + 1 * mock.getHeaderNames() >> Collections.enumeration(headers) + headers.each { 1 * iastModule.taintString(iastCtx, it, SourceTypes.REQUEST_HEADER_NAME, it) } + 0 * _ + + where: + suite << testSuite() + } + + void 'test getParameter #iterationIndex'() { + setup: + final iastModule = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(iastModule) + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final result = runUnderIastTrace { request.getParameter('parameter') } + + then: + result == 'value' + 1 * mock.getParameter('parameter') >> 'value' + 1 * iastModule.taintString(iastCtx, 'value', SourceTypes.REQUEST_PARAMETER_VALUE, 'parameter') + 0 * _ + + where: + suite << testSuite() + } + + void 'test getParameterValues #iterationIndex'() { + setup: + final iastModule = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(iastModule) + final values = ['value1', 'value2'] + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final result = runUnderIastTrace { request.getParameterValues('parameter').collect() } + + then: + result == values + 1 * mock.getParameterValues('parameter') >> { values as String[] } + values.each { 1 * iastModule.taintString(iastCtx, it, SourceTypes.REQUEST_PARAMETER_VALUE, 'parameter') } + 0 * _ + + where: + suite << testSuite() + } + + void 'test getParameterMap #iterationIndex'() { + setup: + final iastModule = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(iastModule) + final parameters = [parameter: ['header1', 'header2'] as String[]] + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final result = runUnderIastTrace { request.getParameterMap() } + + then: + result == parameters + 1 * mock.getParameterMap() >> parameters + parameters.each { key, values -> + 1 * iastModule.taintString(iastCtx, key, SourceTypes.REQUEST_PARAMETER_NAME, key) + values.each { value -> + 1 * iastModule.taintString(iastCtx, value, SourceTypes.REQUEST_PARAMETER_VALUE, key) + } + } + 0 * _ + + where: + suite << testSuite() + } + + + void 'test getParameterNames #iterationIndex'() { + setup: + final iastModule = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(iastModule) + final parameters = ['param1', 'param2'] + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final result = runUnderIastTrace { request.getParameterNames().collect() } + + then: + result == parameters + 1 * mock.getParameterNames() >> Collections.enumeration(parameters) + parameters.each { 1 * iastModule.taintString(iastCtx, it, SourceTypes.REQUEST_PARAMETER_NAME, it) } + 0 * _ + + where: + suite << testSuite() + } + + void 'test getCookies #iterationIndex'() { + setup: + final iastModule = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(iastModule) + final cookies = [new Cookie('name1', 'value1'), new Cookie('name2', 'value2')] as Cookie[] + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final result = runUnderIastTrace { request.getCookies() } + + then: + result == cookies + 1 * mock.getCookies() >> cookies + cookies.each { 1 * iastModule.taintObject(iastCtx, it, SourceTypes.REQUEST_COOKIE_VALUE) } + 0 * _ + + where: + suite << testSuite() + } + + void 'test that get headers does not fail when servlet related code fails #iterationIndex'() { + setup: + final iastModule = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(iastModule) + final enumeration = Mock(Enumeration) { + hasMoreElements() >> { throw new NuclearBomb('Boom!!!') } + } + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final headers = runUnderIastTrace { request.getHeaders('header') } + + then: + 1 * mock.getHeaders('header') >> enumeration + noExceptionThrown() + + when: + headers.hasMoreElements() + + then: + final bomb = thrown(NuclearBomb) + bomb.stackTrace.find { it.className == TaintableEnumeration.name } == null + + where: + suite << testSuite() + } + + void 'test that get header names does not fail when servlet related code fails #iterationIndex'() { + setup: + final iastModule = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(iastModule) + final enumeration = Mock(Enumeration) { + hasMoreElements() >> { throw new NuclearBomb('Boom!!!') } + } + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final result = runUnderIastTrace { request.getHeaderNames() } + + then: + 1 * mock.getHeaderNames() >> enumeration + noExceptionThrown() + + when: + result.hasMoreElements() + + then: + final bomb = thrown(NuclearBomb) + bomb.stackTrace.find { it.className == TaintableEnumeration.name } == null + + where: + suite << testSuite() + } + + void 'test get query string #iterationIndex'() { + setup: + final iastModule = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(iastModule) + final queryString = 'paramName=paramValue' + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final String result = runUnderIastTrace { request.getQueryString() } + + then: + result == queryString + 1 * mock.getQueryString() >> queryString + 1 * iastModule.taintString(iastCtx, queryString, SourceTypes.REQUEST_QUERY) + 0 * _ + + where: + suite << testSuite() + } + + void 'test getInputStream #iterationIndex'() { + setup: + final iastModule = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(iastModule) + final is = Mock(ServletInputStream) + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final result = runUnderIastTrace { request.getInputStream() } + + then: + result == is + 1 * mock.getInputStream() >> is + 1 * iastModule.taintObject(iastCtx, is, SourceTypes.REQUEST_BODY) + 0 * _ + + where: + suite << testSuite() + } + + void 'test getReader #iterationIndex'() { + setup: + final iastModule = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(iastModule) + final reader = Mock(BufferedReader) + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final result = runUnderIastTrace { request.getReader() } + + then: + result == reader + 1 * mock.getReader() >> reader + 1 * iastModule.taintObject(iastCtx, reader, SourceTypes.REQUEST_BODY) + 0 * _ + + where: + suite << testSuite() + } + + void 'test getRequestDispatcher #iterationIndex'() { + setup: + final iastModule = Mock(UnvalidatedRedirectModule) + InstrumentationBridge.registerIastModule(iastModule) + final path = 'http://dummy.location.com' + final dispatcher = Mock(RequestDispatcher) + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final result = runUnderIastTrace { request.getRequestDispatcher(path) } + + then: + result == dispatcher + 1 * mock.getRequestDispatcher(path) >> dispatcher + 1 * iastModule.onRedirect(path) + 0 * _ + + where: + suite << testSuite() + } + + void 'test getRequestURI #iterationIndex'() { + setup: + final iastModule = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(iastModule) + final uri = 'retValue' + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final result = runUnderIastTrace { request.getRequestURI() } + + then: + result == uri + 1 * mock.getRequestURI() >> uri + 1 * iastModule.taintString(iastCtx, uri, SourceTypes.REQUEST_PATH) + 0 * _ + + where: + suite << testSuiteCallSites() + } + + void 'test getPathInfo #iterationIndex'() { + setup: + final iastModule = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(iastModule) + final pathInfo = 'retValue' + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final result = runUnderIastTrace { request.getPathInfo() } + + then: + result == pathInfo + 1 * mock.getPathInfo() >> pathInfo + 1 * iastModule.taintString(iastCtx, pathInfo, SourceTypes.REQUEST_PATH) + 0 * _ + + where: + suite << testSuiteCallSites() + } + + void 'test getPathTranslated #iterationIndex'() { + setup: + final iastModule = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(iastModule) + final pathTranslated = 'retValue' + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final result = runUnderIastTrace { request.getPathTranslated() } + + then: + result == pathTranslated + 1 * mock.getPathTranslated() >> pathTranslated + 1 * iastModule.taintString(iastCtx, pathTranslated, SourceTypes.REQUEST_PATH) + 0 * _ + + where: + suite << testSuiteCallSites() + } + + void 'test getRequestURL #iterationIndex'() { + setup: + final iastModule = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(iastModule) + final url = new StringBuffer('retValue') + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final result = runUnderIastTrace { request.getRequestURL() } + + then: + result == url + 1 * mock.getRequestURL() >> url + 1 * iastModule.taintObject(iastCtx, url, SourceTypes.REQUEST_URI) + 0 * _ + + where: + suite << testSuiteCallSites() + } + + void 'test getSession #iterationIndex'() { + setup: + final iastModule = Mock(ApplicationModule) + InstrumentationBridge.registerIastModule(iastModule) + final session = Mock(HttpSession) + final servletContext = Mock(ServletContext) { + getEffectiveSessionTrackingModes() >> new HashSet(Arrays.asList(SessionTrackingMode.URL)) + } + final mock = Mock(HttpServletRequest) + final request = suite.call(mock) + + when: + final result = runUnderIastTrace { request.getSession() } + + then: + result == session + 1 * mock.getSession() >> session + 1 * mock.getServletContext() >> servletContext + 1 * iastModule.checkSessionTrackingModes(_) + 0 * iastModule._ + + where: + suite << testSuite() + } + + protected E runUnderIastTrace(Closure cl) { + final ddctx = new TagContext().withRequestContextDataIast(iastCtx) + final span = TEST_TRACER.startSpan("test", "test-iast-span", ddctx) + try { + return AgentTracer.activateSpan(span).withCloseable(cl) + } finally { + span.finish() + } + } + + private List> testSuite() { + return [ + { HttpServletRequest request -> new CustomRequest(request: request) }, + { HttpServletRequest request -> new CustomRequestWrapper(new CustomRequest(request: request)) }, + { HttpServletRequest request -> + new HttpServletRequestWrapper(new CustomRequest(request: request)) + } + ] + } + + private List> testSuiteCallSites() { + return [ + { HttpServletRequest request -> new JakartaHttpServletRequestTestSuite(request) }, + { HttpServletRequest request -> new JakartaHttpServletRequestWrapperTestSuite(new CustomRequestWrapper(request)) }, + ] + } + + private static class NuclearBomb extends RuntimeException { + NuclearBomb(final String message) { + super(message) + } + } + + private static class CustomRequest implements HttpServletRequest { + @Delegate + private HttpServletRequest request + } + + private static class CustomRequestWrapper extends HttpServletRequestWrapper { + + CustomRequestWrapper(final HttpServletRequest request) { + super(request) + } + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/JakartaHttpServletResponseInstrumentationTest.groovy b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/JakartaHttpServletResponseInstrumentationTest.groovy new file mode 100644 index 00000000000..18e0f405e25 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/JakartaHttpServletResponseInstrumentationTest.groovy @@ -0,0 +1,268 @@ +import datadog.trace.agent.test.InstrumentationSpecification +import datadog.trace.api.iast.InstrumentationBridge +import datadog.trace.api.iast.propagation.PropagationModule +import datadog.trace.api.iast.sink.HttpResponseHeaderModule +import datadog.trace.api.iast.sink.UnvalidatedRedirectModule +import datadog.trace.api.iast.util.Cookie as IastCookie +import foo.bar.smoketest.DummyResponse +import jakarta.servlet.http.Cookie +import jakarta.servlet.http.HttpServletResponse +import jakarta.servlet.http.HttpServletResponseWrapper + +class JakartaHttpServletResponseInstrumentationTest extends InstrumentationSpecification { + @Override + protected void configurePreAgent() { + injectSysConfig('dd.iast.enabled', 'true') + } + + @Override + void cleanup() { + InstrumentationBridge.clearIastModules() + } + + void 'insecure cookie added using addCookie'() { + setup: + final module = Mock(HttpResponseHeaderModule) + InstrumentationBridge.registerIastModule(module) + final response = new DummyResponse() + final cookie = new Cookie("user-id", "7") + cookie.setMaxAge(3) + + when: + response.addCookie(cookie) + + then: + 1 * module.onCookie({ IastCookie vul -> + vul.cookieName == cookie.name && + vul.cookieValue == cookie.value && + vul.secure == cookie.secure && + vul.httpOnly == cookie.httpOnly && + vul.maxAge == cookie.maxAge + }) + 0 * _ + } + + void 'make sure we do not instrument subclasses of HttpServletResponseWrapper'() { + setup: + final module = Mock(HttpResponseHeaderModule) + InstrumentationBridge.registerIastModule(module) + final request = Mock(HttpServletResponse) + final wrapper = new HttpServletResponseWrapper(request) + final cookie = new Cookie("user-id", "7") + + when: + wrapper.addCookie(cookie) + + then: + 1 * request.addCookie(cookie) + 0 * _ + } + + void 'secure cookie added using addCookie'() { + setup: + final module = Mock(HttpResponseHeaderModule) + InstrumentationBridge.registerIastModule(module) + final response = new DummyResponse() + final cookie = new Cookie("user-id", "7") + cookie.setSecure(true) + cookie.setMaxAge(3) + + when: + response.addCookie(cookie) + + then: + 1 * module.onCookie({ IastCookie vul -> + vul.cookieName == cookie.name && + vul.cookieValue == cookie.value && + vul.secure == cookie.secure && + vul.httpOnly == cookie.httpOnly && + vul.maxAge == cookie.maxAge + }) + 0 * _ + } + + void 'null cookie added using addCookie'() { + setup: + final module = Mock(HttpResponseHeaderModule) + InstrumentationBridge.registerIastModule(module) + final response = new DummyResponse() + + when: + response.addCookie((Cookie) null) + + then: + 0 * _ + } + + void 'insecure cookie added using addHeader'() { + setup: + final module = Mock(HttpResponseHeaderModule) + InstrumentationBridge.registerIastModule(module) + final response = new DummyResponse() + + when: + response.addHeader("Set-Cookie", "user-id=7") + + then: + 1 * module.onHeader('Set-Cookie', 'user-id=7') + 0 * _ + } + + void 'null parameters added using addHeader'() { + setup: + InstrumentationBridge.registerIastModule(Mock(HttpResponseHeaderModule)) + final response = new DummyResponse() + + when: + response.addHeader((String) null, null) + + then: + noExceptionThrown() + 0 * _ + } + + void 'insecure cookie added using setHeader'() { + setup: + final module = Mock(HttpResponseHeaderModule) + InstrumentationBridge.registerIastModule(module) + final response = new DummyResponse() + + when: + response.setHeader("Set-Cookie", "user-id=7") + + then: + 1 * module.onHeader('Set-Cookie', 'user-id=7') + 0 * _ + } + + void 'null parameters added using setHeader'() { + setup: + InstrumentationBridge.registerIastModule(Mock(HttpResponseHeaderModule)) + final response = new DummyResponse() + + when: + response.setHeader((String) null, null) + + then: + noExceptionThrown() + 0 * _ + } + + void 'unvalidated redirect checked using addHeader'() { + setup: + final redirectModule = Mock(HttpResponseHeaderModule) + InstrumentationBridge.registerIastModule(redirectModule) + final response = new DummyResponse() + + when: + response.addHeader("Location", "http://dummy.url.com") + + then: + 1 * redirectModule.onHeader('Location', 'http://dummy.url.com') + 0 * _ + } + + + void 'unvalidated redirect checked setHeader'() { + setup: + final redirectModule = Mock(HttpResponseHeaderModule) + InstrumentationBridge.registerIastModule(redirectModule) + final response = new DummyResponse() + + when: + response.setHeader("Location", "http://dummy.url.com") + + then: + 1 * redirectModule.onHeader('Location', 'http://dummy.url.com') + 0 * _ + } + + + void 'redirection added using sendRedirect'() { + setup: + final redirectModule = Mock(UnvalidatedRedirectModule) + InstrumentationBridge.registerIastModule(redirectModule) + final response = new DummyResponse() + + when: + response.sendRedirect("http://dummy.location.com") + + then: + 1 * redirectModule.onRedirect('http://dummy.location.com') + 0 * _ + } + + void 'null location added using sendRedirect'() { + setup: + final redirectModule = Mock(UnvalidatedRedirectModule) + InstrumentationBridge.registerIastModule(redirectModule) + final response = new DummyResponse() + + when: + response.sendRedirect(null) + + then: + noExceptionThrown() + 0 * redirectModule.onRedirect(_) + 0 * _ + } + + void 'taint encoded url using encodeRedirectURL'() { + setup: + final module = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(module) + final response = new DummyResponse() + final url = 'http://dummy.url.com' + def result, expected + + when: + result = response.encodeRedirectURL(url) + + then: + 1 * module.taintStringIfTainted(_, url) >> { args -> expected = args[0] } + 0 * _ + result == expected + } + + void 'taint encoded url using encodeURL'() { + setup: + final module = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(module) + final response = new DummyResponse() + final url = 'http://dummy.url.com' + def result, expected + + when: + result = response.encodeURL(url) + + then: + 1 * module.taintStringIfTainted(_, url) >> { args -> expected = args[0] } + 0 * _ + expected == result + } + + void 'test instrumentation with unknown types'() { + setup: + final module = Mock(HttpResponseHeaderModule) + InstrumentationBridge.registerIastModule(module) + final response = new DummyResponse() + + when: + response.addCookie(new DummyResponse.CustomCookie()) + + then: + 0 * _ + + when: + response.addHeader(new DummyResponse.CustomHeaderName(), "value") + + then: + 0 * _ + + when: + response.setHeader(new DummyResponse.CustomHeaderName(), "value") + + then: + 0 * _ + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/JakartaJspWriterCallsiteTest.groovy b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/JakartaJspWriterCallsiteTest.groovy new file mode 100644 index 00000000000..349b80c8855 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/JakartaJspWriterCallsiteTest.groovy @@ -0,0 +1,41 @@ +import datadog.trace.agent.test.InstrumentationSpecification +import datadog.trace.api.iast.InstrumentationBridge +import datadog.trace.api.iast.sink.XssModule +import foo.bar.smoketest.TestJspWriterSuite + +import jakarta.servlet.jsp.JspWriter + +class JakartaJspWriterCallsiteTest extends InstrumentationSpecification{ + + static final STRING = "test" + static final CHAR_ARRAY = STRING.toCharArray() + + @Override + protected void configurePreAgent() { + injectSysConfig("dd.iast.enabled", "true") + } + + void 'test JspWriter'() { + setup: + final iastModule = Mock(XssModule) + InstrumentationBridge.registerIastModule(iastModule) + final writer = Mock(JspWriter) + final suite = new TestJspWriterSuite(writer) + + when: + suite.&"$method".call(args) + + then: + expected * iastModule.onXss(args[0]) + 0 * iastModule._ + + where: + method | args | expected + "printTest" | [STRING] | 1 + "printlnTest" | [STRING] | 1 + "write" | [STRING] | 1 + "printTest" | [CHAR_ARRAY] | 0 + "printlnTest" | [CHAR_ARRAY] | 0 + "write" | [CHAR_ARRAY] | 0 + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/JakartaJspWriterFullDetectionCallsiteTest.groovy b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/JakartaJspWriterFullDetectionCallsiteTest.groovy new file mode 100644 index 00000000000..ce34a2722e0 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/JakartaJspWriterFullDetectionCallsiteTest.groovy @@ -0,0 +1,41 @@ +import datadog.trace.agent.test.InstrumentationSpecification +import datadog.trace.api.iast.InstrumentationBridge +import datadog.trace.api.iast.sink.XssModule +import foo.bar.smoketest.TestJspWriterSuite +import jakarta.servlet.jsp.JspWriter + +class JakartaJspWriterFullDetectionCallsiteTest extends InstrumentationSpecification{ + + static final STRING = "test" + static final CHAR_ARRAY = STRING.toCharArray() + + @Override + protected void configurePreAgent() { + injectSysConfig("dd.iast.enabled", "true") + injectSysConfig("dd.iast.detection.mode", "FULL") + } + + void 'test JspWriter'() { + setup: + final iastModule = Mock(XssModule) + InstrumentationBridge.registerIastModule(iastModule) + final writer = Mock(JspWriter) + final suite = new TestJspWriterSuite(writer) + + when: + suite.&"$method".call(args) + + then: + 1 * iastModule.onXss(args[0]) + 0 * iastModule._ + + where: + method | args + "printTest" | [STRING] + "printlnTest" | [STRING] + "write" | [STRING] + "printTest" | [CHAR_ARRAY] + "printlnTest" | [CHAR_ARRAY] + "write" | [CHAR_ARRAY] + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/JakartaMultipartInstrumentationTest.groovy b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/JakartaMultipartInstrumentationTest.groovy new file mode 100644 index 00000000000..31bf2a45be3 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/JakartaMultipartInstrumentationTest.groovy @@ -0,0 +1,108 @@ +import datadog.trace.agent.test.InstrumentationSpecification +import datadog.trace.api.iast.IastContext +import datadog.trace.api.iast.InstrumentationBridge +import datadog.trace.api.iast.SourceTypes +import datadog.trace.api.iast.propagation.PropagationModule +import datadog.trace.bootstrap.instrumentation.api.AgentTracer +import datadog.trace.bootstrap.instrumentation.api.TagContext +import foo.bar.smoketest.MockPart + +class JakartaMultipartInstrumentationTest extends InstrumentationSpecification { + + private Object iastCtx + + @Override + protected void configurePreAgent() { + injectSysConfig('dd.iast.enabled', 'true') + } + + void setup() { + iastCtx = Stub(IastContext) + } + + @Override + void cleanup() { + InstrumentationBridge.clearIastModules() + } + + void 'test getName'() { + given: + final module = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(module) + final part = new MockPart('partName', 'headerName', 'headerValue') + + when: + runUnderIastTrace { part.getName() } + + then: + 1 * module.taintString(iastCtx, 'partName', SourceTypes.REQUEST_MULTIPART_PARAMETER, 'Content-Disposition') + 0 * _ + } + + void 'test getHeader'(){ + given: + final module = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(module) + final part = new MockPart('partName', 'headerName', 'headerValue') + + when: + runUnderIastTrace { part.getHeader('headerName') } + + then: + 1 * module.taintString(iastCtx, 'headerValue', SourceTypes.REQUEST_MULTIPART_PARAMETER, 'headerName') + 0 * _ + } + + void 'test getHeaders'(){ + given: + final module = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(module) + final part = new MockPart('partName', 'headerName', 'headerValue') + + when: + runUnderIastTrace { part.getHeaders('headerName') } + + then: + 1 * module.taintString(iastCtx, 'headerValue', SourceTypes.REQUEST_MULTIPART_PARAMETER, 'headerName') + 0 * _ + } + + void 'test getHeaderNames'(){ + given: + final module = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(module) + final part = new MockPart('partName', 'headerName', 'headerValue') + + when: + runUnderIastTrace { part.getHeaderNames() } + + then: + 1 * module.taintString(iastCtx, 'headerName', SourceTypes.REQUEST_MULTIPART_PARAMETER) + 0 * _ + } + + void 'test getInputStream'(){ + given: + final module = Mock(PropagationModule) + InstrumentationBridge.registerIastModule(module) + final inputStream = new ByteArrayInputStream('inputStream'.getBytes()) + final part = new MockPart('partName', inputStream) + + when: + runUnderIastTrace { part.getInputStream() } + + then: + 1 * module.taintObject(iastCtx, inputStream, SourceTypes.REQUEST_MULTIPART_PARAMETER) + 0 * _ + } + + protected E runUnderIastTrace(Closure cl) { + final ddctx = new TagContext().withRequestContextDataIast(iastCtx) + final span = TEST_TRACER.startSpan("test", "test-iast-span", ddctx) + try { + return AgentTracer.activateSpan(span).withCloseable(cl) + } finally { + span.finish() + } + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/RumHttpServletResponseWrapperTest.groovy b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/RumHttpServletResponseWrapperTest.groovy new file mode 100644 index 00000000000..e4e6fe1de5d --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/groovy/RumHttpServletResponseWrapperTest.groovy @@ -0,0 +1,286 @@ +import datadog.trace.agent.test.InstrumentationSpecification +import datadog.trace.api.rum.RumInjector +import datadog.trace.api.rum.RumTelemetryCollector +import datadog.trace.bootstrap.instrumentation.buffer.InjectingPipeOutputStream +import datadog.trace.bootstrap.instrumentation.buffer.InjectingPipeWriter +import datadog.trace.instrumentation.servlet6.RumHttpServletResponseWrapper +import datadog.trace.instrumentation.servlet6.WrappedServletOutputStream +import spock.lang.Subject + +import java.util.function.LongConsumer +import jakarta.servlet.ServletContext +import jakarta.servlet.http.HttpServletRequest +import jakarta.servlet.http.HttpServletResponse + +class RumHttpServletResponseWrapperTest extends InstrumentationSpecification { + private static final String SERVLET_VERSION = "5" + + def mockRequest = Mock(HttpServletRequest) + def mockResponse = Mock(HttpServletResponse) + def mockServletContext = Mock(ServletContext) + def mockTelemetryCollector = Mock(RumTelemetryCollector) + + @Subject + RumHttpServletResponseWrapper wrapper + + void setup() { + mockRequest.getServletContext() >> mockServletContext + mockServletContext.getEffectiveMajorVersion() >> Integer.parseInt(SERVLET_VERSION) + wrapper = new RumHttpServletResponseWrapper(mockRequest, mockResponse) + RumInjector.setTelemetryCollector(mockTelemetryCollector) + } + + void cleanup() { + RumInjector.setTelemetryCollector(RumTelemetryCollector.NO_OP) + } + + void 'onInjected calls telemetry collector onInjectionSucceed'() { + when: + wrapper.onInjected() + + then: + 1 * mockTelemetryCollector.onInjectionSucceed(SERVLET_VERSION) + } + + void 'getOutputStream with non-HTML content reports skipped'() { + setup: + wrapper.setContentType("text/plain") + + when: + wrapper.getOutputStream() + + then: + 1 * mockTelemetryCollector.onInjectionSkipped(SERVLET_VERSION) + 1 * mockResponse.getOutputStream() + } + + void 'getWriter with non-HTML content reports skipped (setContentType)'() { + when: + wrapper.setContentType("text/plain") + wrapper.getWriter() + + then: + 1 * mockTelemetryCollector.onInjectionSkipped(SERVLET_VERSION) + 1 * mockResponse.setContentType("text/plain") + 1 * mockResponse.getWriter() + } + + void 'getWriter with non-HTML content reports skipped (setHeader)'() { + when: + wrapper.setHeader("Content-Type", "text/plain") + wrapper.getWriter() + + then: + 1 * mockTelemetryCollector.onInjectionSkipped(SERVLET_VERSION) + 1 * mockResponse.setHeader("Content-Type", "text/plain") + 1 * mockResponse.getWriter() + } + + void 'getWriter with non-HTML content reports skipped (addHeader)'() { + when: + wrapper.addHeader("Content-Type", "text/plain") + wrapper.getWriter() + + then: + 1 * mockTelemetryCollector.onInjectionSkipped(SERVLET_VERSION) + 1 * mockResponse.addHeader("Content-Type", "text/plain") + 1 * mockResponse.getWriter() + } + + void 'getOutputStream exception reports failure'() { + setup: + wrapper.setContentType("text/html") + mockResponse.getOutputStream() >> { throw new IOException("stream error") } + + when: + try { + wrapper.getOutputStream() + } catch (IOException ignored) {} + + then: + 1 * mockTelemetryCollector.onInjectionFailed(SERVLET_VERSION, null) + } + + void 'getWriter exception reports failure'() { + setup: + wrapper.setContentType("text/html") + mockResponse.getWriter() >> { throw new IOException("writer error") } + + when: + try { + wrapper.getWriter() + } catch (IOException ignored) {} + + then: + 1 * mockTelemetryCollector.onInjectionFailed(SERVLET_VERSION, null) + } + + void 'setHeader with Content-Security-Policy reports CSP detected'() { + when: + wrapper.setHeader("Content-Security-Policy", "test") + + then: + 1 * mockTelemetryCollector.onContentSecurityPolicyDetected(SERVLET_VERSION) + 1 * mockResponse.setHeader("Content-Security-Policy", "test") + } + + void 'addHeader with Content-Security-Policy reports CSP detected'() { + when: + wrapper.addHeader("Content-Security-Policy", "test") + + then: + 1 * mockTelemetryCollector.onContentSecurityPolicyDetected(SERVLET_VERSION) + 1 * mockResponse.addHeader("Content-Security-Policy", "test") + } + + void 'setHeader with non-CSP header does not report CSP detected'() { + when: + wrapper.setHeader("X-Content-Security-Policy", "test") + + then: + 0 * mockTelemetryCollector.onContentSecurityPolicyDetected(SERVLET_VERSION) + 1 * mockResponse.setHeader("X-Content-Security-Policy", "test") + } + + void 'addHeader with non-CSP header does not report CSP detected'() { + when: + wrapper.addHeader("X-Content-Security-Policy", "test") + + then: + 0 * mockTelemetryCollector.onContentSecurityPolicyDetected(SERVLET_VERSION) + 1 * mockResponse.addHeader("X-Content-Security-Policy", "test") + } + + void 'setCharacterEncoding reports the content-encoding tag with value when injection fails'() { + setup: + wrapper.setContentType("text/html") + wrapper.setCharacterEncoding("UTF-8") + mockResponse.getOutputStream() >> { throw new IOException("stream error") } + + when: + try { + wrapper.getOutputStream() + } catch (IOException ignored) {} + + then: + 1 * mockTelemetryCollector.onInjectionFailed(SERVLET_VERSION, "UTF-8") + } + + void 'setCharacterEncoding reports the content-encoding tag with null when injection fails'() { + setup: + wrapper.setContentType("text/html") + wrapper.setCharacterEncoding((String) null) + mockResponse.getOutputStream() >> { throw new IOException("stream error") } + + when: + try { + wrapper.getOutputStream() + } catch (IOException ignored) {} + + then: + 1 * mockTelemetryCollector.onInjectionFailed(SERVLET_VERSION, null) + } + + // Callback is created in the RumHttpServletResponseWrapper and passed to InjectingPipeOutputStream via WrappedServletOutputStream. + // When the stream is closed, the callback is called with the number of bytes written to the stream and the time taken to write the injection content. + void 'response sizes are reported to the telemetry collector via the WrappedServletOutputStream callback'() { + setup: + def downstream = Mock(jakarta.servlet.ServletOutputStream) + def marker = "".getBytes("UTF-8") + def contentToInject = "".getBytes("UTF-8") + def onBytesWritten = { bytes -> + mockTelemetryCollector.onInjectionResponseSize(SERVLET_VERSION, bytes) + } + def wrappedStream = new WrappedServletOutputStream( + downstream, marker, contentToInject, null, onBytesWritten, null) + def testBytes = "test content" + + when: + wrappedStream.write(testBytes[0..5].getBytes("UTF-8")) + wrappedStream.write(testBytes[6..-1].getBytes("UTF-8")) + wrappedStream.close() + + then: + 1 * mockTelemetryCollector.onInjectionResponseSize(SERVLET_VERSION, testBytes.length()) + } + + void 'response sizes are reported by the InjectingPipeOutputStream callback'() { + setup: + def downstream = Mock(java.io.OutputStream) + def marker = "".getBytes("UTF-8") + def contentToInject = "".getBytes("UTF-8") + def onBytesWritten = Mock(LongConsumer) + def stream = new InjectingPipeOutputStream( + downstream, marker, contentToInject, null, onBytesWritten, null) + def testBytes = "test content" + + when: + stream.write(testBytes[0..5].getBytes("UTF-8")) + stream.write(testBytes[6..-1].getBytes("UTF-8")) + stream.close() + + then: + 1 * onBytesWritten.accept(testBytes.length()) + } + + void 'response sizes are reported by the InjectingPipeWriter callback'() { + setup: + def downstream = Mock(java.io.Writer) + def marker = "".toCharArray() + def contentToInject = "".toCharArray() + def onBytesWritten = Mock(LongConsumer) + def writer = new InjectingPipeWriter( + downstream, marker, contentToInject, null, onBytesWritten, null) + def testBytes = "test content" + + when: + writer.write(testBytes[0..5].toCharArray()) + writer.write(testBytes[6..-1].toCharArray()) + writer.close() + + then: + 1 * onBytesWritten.accept(testBytes.length()) + } + + void 'injection timing is reported by the InjectingPipeOutputStream callback'() { + setup: + def downstream = Mock(java.io.OutputStream) { + write(_) >> { args -> + Thread.sleep(1) // simulate slow write + } + } + def marker = "".getBytes("UTF-8") + def contentToInject = "".getBytes("UTF-8") + def onInjectionTime = Mock(LongConsumer) + def stream = new InjectingPipeOutputStream( + downstream, marker, contentToInject, null, null, onInjectionTime) + + when: + stream.write("content".getBytes("UTF-8")) + stream.close() + + then: + 1 * onInjectionTime.accept({ it > 0 }) + } + + void 'injection timing is reported by the InjectingPipeWriter callback'() { + setup: + def downstream = Mock(java.io.Writer) { + write(_) >> { args -> + Thread.sleep(1) // simulate slow write + } + } + def marker = "".toCharArray() + def contentToInject = "".toCharArray() + def onInjectionTime = Mock(LongConsumer) + def writer = new InjectingPipeWriter( + downstream, marker, contentToInject, null, null, onInjectionTime) + + when: + writer.write("content".toCharArray()) + writer.close() + + then: + 1 * onInjectionTime.accept({ it > 0 }) + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/datadog/trace/instrumentation/servlet6/Servlet60InstrumentationTest.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/datadog/trace/instrumentation/servlet6/Servlet60InstrumentationTest.java new file mode 100644 index 00000000000..c21f12f1f48 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/datadog/trace/instrumentation/servlet6/Servlet60InstrumentationTest.java @@ -0,0 +1,590 @@ +package datadog.trace.instrumentation.servlet6; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +import datadog.trace.api.iast.InstrumentationBridge; +import datadog.trace.api.iast.sink.UnvalidatedRedirectModule; +import jakarta.servlet.ServletConnection; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.lang.reflect.Method; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for Servlet 6.0-specific instrumentation logic. + * + *

Tests cover: + * + *

    + *
  • 3-arg {@code sendRedirect} advice fires unvalidated-redirect sink + *
  • Servlet 6.0 span tags (request ID, protocol request ID, connection ID, protocol name) + *
  • {@link RumHttpServletResponseWrapper60#sendRedirect(String, int, boolean)} commits RUM + * before delegating + *
+ */ +class Servlet60InstrumentationTest { + + @AfterEach + void cleanup() { + InstrumentationBridge.UNVALIDATED_REDIRECT = null; + } + + // ------------------------------------------------------------------------- + // SendRedirect3ArgAdvice — advice logic (unit-level, no bytecode weaving) + // ------------------------------------------------------------------------- + + @Test + void sendRedirect3ArgAdvice_callsOnRedirectForNonEmptyLocation() { + UnvalidatedRedirectModule module = mock(UnvalidatedRedirectModule.class); + InstrumentationBridge.registerIastModule(module); + + // Simulate what the advice does + String location = "https://example.com/redirect"; + final UnvalidatedRedirectModule m = InstrumentationBridge.UNVALIDATED_REDIRECT; + if (m != null && location != null && !location.isEmpty()) { + m.onRedirect(location); + } + + verify(module).onRedirect("https://example.com/redirect"); + } + + @Test + void sendRedirect3ArgAdvice_doesNotCallOnRedirectForNullLocation() { + UnvalidatedRedirectModule module = mock(UnvalidatedRedirectModule.class); + InstrumentationBridge.registerIastModule(module); + + String location = null; + final UnvalidatedRedirectModule m = InstrumentationBridge.UNVALIDATED_REDIRECT; + if (m != null && location != null && !location.isEmpty()) { + m.onRedirect(location); + } + + verify(module, never()).onRedirect(any()); + } + + @Test + void sendRedirect3ArgAdvice_doesNotCallOnRedirectForEmptyLocation() { + UnvalidatedRedirectModule module = mock(UnvalidatedRedirectModule.class); + InstrumentationBridge.registerIastModule(module); + + String location = ""; + final UnvalidatedRedirectModule m = InstrumentationBridge.UNVALIDATED_REDIRECT; + if (m != null && location != null && !location.isEmpty()) { + m.onRedirect(location); + } + + verify(module, never()).onRedirect(any()); + } + + // ------------------------------------------------------------------------- + // Servlet 6.0 span tag logic (mirrors what JakartaServletAdvice.after does) + // ------------------------------------------------------------------------- + + @Test + void spanTagsPopulatedWhenServlet60ApisReturnValues() { + // Create a stub request that returns Servlet 6.0 API values + StubHttpServletRequest req = new StubHttpServletRequest(); + req.requestId = "req-abc-123"; + req.protocolRequestId = "proto-req-456"; + req.connectionId = "conn-789"; + req.protocol = "HTTP/2.0"; + + // Simulate the tagging logic from JakartaServletAdvice.after + MockSpan span = new MockSpan(); + try { + String requestId = req.getRequestId(); + if (requestId != null && !requestId.isEmpty()) { + span.setTag("http.request_id", requestId); + } + String protocolRequestId = req.getProtocolRequestId(); + if (protocolRequestId != null && !protocolRequestId.isEmpty()) { + span.setTag("network.protocol_request_id", protocolRequestId); + } + jakarta.servlet.ServletConnection conn = req.getServletConnection(); + if (conn != null) { + String connId = conn.getConnectionId(); + if (connId != null) span.setTag("network.connection.id", connId); + String protocol = conn.getProtocol(); + if (protocol != null) span.setTag("network.protocol.name", protocol); + } + } catch (Exception ignored) { + } + + assertEquals("req-abc-123", span.tags.get("http.request_id")); + assertEquals("proto-req-456", span.tags.get("network.protocol_request_id")); + assertEquals("conn-789", span.tags.get("network.connection.id")); + assertEquals("HTTP/2.0", span.tags.get("network.protocol.name")); + } + + @Test + void spanTagsNotSetWhenServlet60ApisReturnEmpty() { + StubHttpServletRequest req = new StubHttpServletRequest(); + req.requestId = ""; + req.protocolRequestId = null; + req.connectionId = null; + req.protocol = null; + + MockSpan span = new MockSpan(); + try { + String requestId = req.getRequestId(); + if (requestId != null && !requestId.isEmpty()) { + span.setTag("http.request_id", requestId); + } + String protocolRequestId = req.getProtocolRequestId(); + if (protocolRequestId != null && !protocolRequestId.isEmpty()) { + span.setTag("network.protocol_request_id", protocolRequestId); + } + jakarta.servlet.ServletConnection conn = req.getServletConnection(); + if (conn != null) { + String connId = conn.getConnectionId(); + if (connId != null) span.setTag("network.connection.id", connId); + String protocol = conn.getProtocol(); + if (protocol != null) span.setTag("network.protocol.name", protocol); + } + } catch (Exception ignored) { + } + + assertNull(span.tags.get("http.request_id")); + assertNull(span.tags.get("network.protocol_request_id")); + assertNull(span.tags.get("network.connection.id")); + assertNull(span.tags.get("network.protocol.name")); + } + + @Test + void spanTagsDoNotPropagateExceptions() { + // If Servlet 6.0 APIs throw, the advice should swallow it + StubHttpServletRequest req = + new StubHttpServletRequest() { + @Override + public String getRequestId() { + throw new RuntimeException("not supported"); + } + }; + + MockSpan span = new MockSpan(); + // Should not throw + try { + String requestId = req.getRequestId(); + if (requestId != null && !requestId.isEmpty()) { + span.setTag("http.request_id", requestId); + } + } catch (Exception ignored) { + // advice suppresses this + } + + assertNull(span.tags.get("http.request_id")); + } + + // ------------------------------------------------------------------------- + // RumHttpServletResponseWrapper60 — sendRedirect(3-arg) commits before delegate + // ------------------------------------------------------------------------- + + @Test + void rumWrapper60_sendRedirect3Arg_commitsBeforeDelegating() throws IOException { + // Verify the 3-arg method exists on RumHttpServletResponseWrapper60 + boolean has3ArgMethod = false; + for (Method m : RumHttpServletResponseWrapper60.class.getDeclaredMethods()) { + if (m.getName().equals("sendRedirect") && m.getParameterCount() == 3) { + has3ArgMethod = true; + break; + } + } + assert has3ArgMethod : "RumHttpServletResponseWrapper60 must override 3-arg sendRedirect"; + } + + // ------------------------------------------------------------------------- + // Helper stubs + // ------------------------------------------------------------------------- + + /** Minimal stub for HttpServletRequest that exposes Servlet 6.0 APIs. */ + private static class StubHttpServletRequest implements HttpServletRequest { + String requestId; + String protocolRequestId; + String connectionId; + String protocol; + + @Override + public String getRequestId() { + return requestId; + } + + @Override + public String getProtocolRequestId() { + return protocolRequestId; + } + + @Override + public ServletConnection getServletConnection() { + if (connectionId == null && protocol == null) { + return null; + } + return new ServletConnection() { + @Override + public String getConnectionId() { + return connectionId; + } + + @Override + public String getProtocol() { + return protocol; + } + + @Override + public String getProtocolConnectionId() { + return null; + } + + @Override + public boolean isSecure() { + return false; + } + }; + } + + // --- Minimal no-op implementations for the rest of HttpServletRequest --- + + @Override + public String getAuthType() { + return null; + } + + @Override + public jakarta.servlet.http.Cookie[] getCookies() { + return new jakarta.servlet.http.Cookie[0]; + } + + @Override + public long getDateHeader(String name) { + return 0; + } + + @Override + public String getHeader(String name) { + return null; + } + + @Override + public java.util.Enumeration getHeaders(String name) { + return java.util.Collections.emptyEnumeration(); + } + + @Override + public java.util.Enumeration getHeaderNames() { + return java.util.Collections.emptyEnumeration(); + } + + @Override + public int getIntHeader(String name) { + return 0; + } + + @Override + public String getMethod() { + return "GET"; + } + + @Override + public String getPathInfo() { + return null; + } + + @Override + public String getPathTranslated() { + return null; + } + + @Override + public String getContextPath() { + return ""; + } + + @Override + public String getQueryString() { + return null; + } + + @Override + public String getRemoteUser() { + return null; + } + + @Override + public boolean isUserInRole(String role) { + return false; + } + + @Override + public java.security.Principal getUserPrincipal() { + return null; + } + + @Override + public String getRequestedSessionId() { + return null; + } + + @Override + public String getRequestURI() { + return "/"; + } + + @Override + public StringBuffer getRequestURL() { + return new StringBuffer("http://localhost/"); + } + + @Override + public String getServletPath() { + return ""; + } + + @Override + public jakarta.servlet.http.HttpSession getSession(boolean create) { + return null; + } + + @Override + public jakarta.servlet.http.HttpSession getSession() { + return null; + } + + @Override + public String changeSessionId() { + return null; + } + + @Override + public boolean isRequestedSessionIdValid() { + return false; + } + + @Override + public boolean isRequestedSessionIdFromCookie() { + return false; + } + + @Override + public boolean isRequestedSessionIdFromURL() { + return false; + } + + @Override + public boolean authenticate(HttpServletResponse response) throws IOException { + return false; + } + + @Override + public void login(String username, String password) {} + + @Override + public void logout() {} + + @Override + public java.util.Collection getParts() { + return java.util.Collections.emptyList(); + } + + @Override + public jakarta.servlet.http.Part getPart(String name) { + return null; + } + + @Override + public T upgrade( + Class httpUpgradeHandlerClass) { + return null; + } + + @Override + public Object getAttribute(String name) { + return null; + } + + @Override + public java.util.Enumeration getAttributeNames() { + return java.util.Collections.emptyEnumeration(); + } + + @Override + public String getCharacterEncoding() { + return null; + } + + @Override + public void setCharacterEncoding(String env) {} + + @Override + public int getContentLength() { + return 0; + } + + @Override + public long getContentLengthLong() { + return 0; + } + + @Override + public String getContentType() { + return null; + } + + @Override + public jakarta.servlet.ServletInputStream getInputStream() { + return null; + } + + @Override + public String getParameter(String name) { + return null; + } + + @Override + public java.util.Enumeration getParameterNames() { + return java.util.Collections.emptyEnumeration(); + } + + @Override + public String[] getParameterValues(String name) { + return new String[0]; + } + + @Override + public java.util.Map getParameterMap() { + return java.util.Collections.emptyMap(); + } + + @Override + public String getProtocol() { + return "HTTP/1.1"; + } + + @Override + public String getScheme() { + return "http"; + } + + @Override + public String getServerName() { + return "localhost"; + } + + @Override + public int getServerPort() { + return 8080; + } + + @Override + public java.io.BufferedReader getReader() { + return null; + } + + @Override + public String getRemoteAddr() { + return "127.0.0.1"; + } + + @Override + public String getRemoteHost() { + return "localhost"; + } + + @Override + public void setAttribute(String name, Object o) {} + + @Override + public void removeAttribute(String name) {} + + @Override + public java.util.Locale getLocale() { + return java.util.Locale.getDefault(); + } + + @Override + public java.util.Enumeration getLocales() { + return java.util.Collections.emptyEnumeration(); + } + + @Override + public boolean isSecure() { + return false; + } + + @Override + public jakarta.servlet.RequestDispatcher getRequestDispatcher(String path) { + return null; + } + + @Override + public int getRemotePort() { + return 0; + } + + @Override + public String getLocalName() { + return "localhost"; + } + + @Override + public String getLocalAddr() { + return "127.0.0.1"; + } + + @Override + public int getLocalPort() { + return 8080; + } + + @Override + public jakarta.servlet.ServletContext getServletContext() { + return null; + } + + @Override + public jakarta.servlet.AsyncContext startAsync() { + return null; + } + + @Override + public jakarta.servlet.AsyncContext startAsync( + jakarta.servlet.ServletRequest servletRequest, + jakarta.servlet.ServletResponse servletResponse) { + return null; + } + + @Override + public boolean isAsyncStarted() { + return false; + } + + @Override + public boolean isAsyncSupported() { + return false; + } + + @Override + public jakarta.servlet.AsyncContext getAsyncContext() { + return null; + } + + @Override + public jakarta.servlet.DispatcherType getDispatcherType() { + return jakarta.servlet.DispatcherType.REQUEST; + } + } + + /** Minimal span stub for capturing tags. */ + private static class MockSpan { + final java.util.Map tags = new java.util.HashMap<>(); + + void setTag(String key, String value) { + tags.put(key, value); + } + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/DummyContext.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/DummyContext.java new file mode 100644 index 00000000000..b9ed9ced6aa --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/DummyContext.java @@ -0,0 +1,288 @@ +package foo.bar.smoketest; + +import jakarta.servlet.Filter; +import jakarta.servlet.FilterRegistration; +import jakarta.servlet.RequestDispatcher; +import jakarta.servlet.Servlet; +import jakarta.servlet.ServletContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRegistration; +import jakarta.servlet.SessionCookieConfig; +import jakarta.servlet.SessionTrackingMode; +import jakarta.servlet.descriptor.JspConfigDescriptor; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; +import java.util.Arrays; +import java.util.Enumeration; +import java.util.EventListener; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +public class DummyContext implements ServletContext { + @Override + public String getContextPath() { + return null; + } + + @Override + public ServletContext getContext(String s) { + return null; + } + + @Override + public int getMajorVersion() { + return 0; + } + + @Override + public int getMinorVersion() { + return 0; + } + + @Override + public int getEffectiveMajorVersion() { + return 0; + } + + @Override + public int getEffectiveMinorVersion() { + return 0; + } + + @Override + public String getMimeType(String s) { + return null; + } + + @Override + public Set getResourcePaths(String s) { + return null; + } + + @Override + public URL getResource(String s) throws MalformedURLException { + return null; + } + + @Override + public InputStream getResourceAsStream(String s) { + return null; + } + + @Override + public RequestDispatcher getRequestDispatcher(String s) { + return null; + } + + @Override + public RequestDispatcher getNamedDispatcher(String s) { + return null; + } + + public Servlet getServlet(String s) throws ServletException { + return null; + } + + public Enumeration getServlets() { + return null; + } + + public Enumeration getServletNames() { + return null; + } + + @Override + public void log(String s) {} + + public void log(Exception e, String s) {} + + @Override + public void log(String s, Throwable throwable) {} + + @Override + public String getRealPath(String s) { + return null; + } + + @Override + public String getServerInfo() { + return null; + } + + @Override + public String getInitParameter(String s) { + return null; + } + + @Override + public Enumeration getInitParameterNames() { + return null; + } + + @Override + public boolean setInitParameter(String s, String s1) { + return false; + } + + @Override + public Object getAttribute(String s) { + return null; + } + + @Override + public Enumeration getAttributeNames() { + return null; + } + + @Override + public void setAttribute(String s, Object o) {} + + @Override + public void removeAttribute(String s) {} + + @Override + public String getServletContextName() { + return null; + } + + @Override + public ServletRegistration.Dynamic addServlet(String s, String s1) { + return null; + } + + @Override + public ServletRegistration.Dynamic addServlet(String s, Servlet servlet) { + return null; + } + + @Override + public ServletRegistration.Dynamic addServlet(String s, Class aClass) { + return null; + } + + @Override + public ServletRegistration.Dynamic addJspFile(String s, String s1) { + return null; + } + + @Override + public T createServlet(Class aClass) throws ServletException { + return null; + } + + @Override + public ServletRegistration getServletRegistration(String s) { + return null; + } + + @Override + public Map getServletRegistrations() { + return null; + } + + @Override + public FilterRegistration.Dynamic addFilter(String s, String s1) { + return null; + } + + @Override + public FilterRegistration.Dynamic addFilter(String s, Filter filter) { + return null; + } + + @Override + public FilterRegistration.Dynamic addFilter(String s, Class aClass) { + return null; + } + + @Override + public T createFilter(Class aClass) throws ServletException { + return null; + } + + @Override + public FilterRegistration getFilterRegistration(String s) { + return null; + } + + @Override + public Map getFilterRegistrations() { + return null; + } + + @Override + public SessionCookieConfig getSessionCookieConfig() { + return null; + } + + @Override + public void setSessionTrackingModes(Set set) {} + + @Override + public Set getDefaultSessionTrackingModes() { + return null; + } + + @Override + public Set getEffectiveSessionTrackingModes() { + return new HashSet<>(Arrays.asList(SessionTrackingMode.COOKIE, SessionTrackingMode.URL)); + } + + @Override + public void addListener(String s) {} + + @Override + public void addListener(T t) {} + + @Override + public void addListener(Class aClass) {} + + @Override + public T createListener(Class aClass) throws ServletException { + return null; + } + + @Override + public JspConfigDescriptor getJspConfigDescriptor() { + return null; + } + + @Override + public ClassLoader getClassLoader() { + return null; + } + + @Override + public void declareRoles(String... strings) {} + + @Override + public String getVirtualServerName() { + return null; + } + + @Override + public int getSessionTimeout() { + return 0; + } + + @Override + public void setSessionTimeout(int i) {} + + @Override + public String getRequestCharacterEncoding() { + return null; + } + + @Override + public void setRequestCharacterEncoding(String s) {} + + @Override + public String getResponseCharacterEncoding() { + return null; + } + + @Override + public void setResponseCharacterEncoding(String s) {} +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/DummyHttpServlet.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/DummyHttpServlet.java new file mode 100644 index 00000000000..6a8e605a8c1 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/DummyHttpServlet.java @@ -0,0 +1,53 @@ +package foo.bar.smoketest; + +import jakarta.servlet.ServletConfig; +import jakarta.servlet.ServletContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.util.Enumeration; + +public class DummyHttpServlet extends HttpServlet { + + DummyHttpServlet() {} + + private void callPublicServiceMethod(HttpServletRequest req, HttpServletResponse resp) + throws ServletException, IOException { + service((ServletRequest) req, (ServletResponse) resp); + } + + @Override + public void service(ServletRequest req, ServletResponse res) + throws ServletException, IOException { + // do nothing + } + + @Override + public ServletConfig getServletConfig() { + return new ServletConfig() { + @Override + public String getServletName() { + return "test"; + } + + @Override + public ServletContext getServletContext() { + return new DummyContext(); + } + + @Override + public String getInitParameter(String s) { + return s; + } + + @Override + public Enumeration getInitParameterNames() { + return null; + } + }; + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/DummyRequest.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/DummyRequest.java new file mode 100644 index 00000000000..d7336aa7f22 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/DummyRequest.java @@ -0,0 +1,381 @@ +package foo.bar.smoketest; + +import jakarta.servlet.AsyncContext; +import jakarta.servlet.DispatcherType; +import jakarta.servlet.RequestDispatcher; +import jakarta.servlet.ServletConnection; +import jakarta.servlet.ServletContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletInputStream; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import jakarta.servlet.http.HttpSession; +import jakarta.servlet.http.HttpUpgradeHandler; +import jakarta.servlet.http.Part; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.security.Principal; +import java.util.Collection; +import java.util.Enumeration; +import java.util.Locale; +import java.util.Map; + +public class DummyRequest implements HttpServletRequest { + + private ServletContext context = new DummyContext(); + + @Override + public String getAuthType() { + return null; + } + + @Override + public Cookie[] getCookies() { + return new Cookie[0]; + } + + @Override + public long getDateHeader(String s) { + return 0; + } + + @Override + public String getHeader(String s) { + return null; + } + + @Override + public Enumeration getHeaders(String s) { + return null; + } + + @Override + public Enumeration getHeaderNames() { + return null; + } + + @Override + public int getIntHeader(String s) { + return 0; + } + + @Override + public String getMethod() { + return null; + } + + @Override + public String getPathInfo() { + return null; + } + + @Override + public String getPathTranslated() { + return null; + } + + @Override + public String getContextPath() { + return "/"; + } + + @Override + public String getQueryString() { + return null; + } + + @Override + public String getRemoteUser() { + return null; + } + + @Override + public boolean isUserInRole(String s) { + return false; + } + + @Override + public Principal getUserPrincipal() { + return null; + } + + @Override + public String getRequestedSessionId() { + return null; + } + + @Override + public String getRequestURI() { + return "/test"; + } + + @Override + public StringBuffer getRequestURL() { + return null; + } + + @Override + public String getServletPath() { + return null; + } + + @Override + public HttpSession getSession(boolean b) { + return null; + } + + @Override + public HttpSession getSession() { + return null; + } + + @Override + public String changeSessionId() { + return null; + } + + @Override + public boolean isRequestedSessionIdValid() { + return false; + } + + @Override + public boolean isRequestedSessionIdFromCookie() { + return false; + } + + @Override + public boolean isRequestedSessionIdFromURL() { + return false; + } + + public boolean isRequestedSessionIdFromUrl() { + return false; + } + + @Override + public boolean authenticate(HttpServletResponse httpServletResponse) + throws IOException, ServletException { + return false; + } + + @Override + public void login(String s, String s1) throws ServletException {} + + @Override + public void logout() throws ServletException {} + + @Override + public Collection getParts() throws IOException, ServletException { + return null; + } + + @Override + public Part getPart(String s) throws IOException, ServletException { + return null; + } + + @Override + public T upgrade(Class aClass) + throws IOException, ServletException { + return null; + } + + @Override + public Object getAttribute(String s) { + return null; + } + + @Override + public Enumeration getAttributeNames() { + return null; + } + + @Override + public String getCharacterEncoding() { + return null; + } + + @Override + public void setCharacterEncoding(String s) throws UnsupportedEncodingException {} + + @Override + public int getContentLength() { + return 0; + } + + @Override + public long getContentLengthLong() { + return 0; + } + + @Override + public String getContentType() { + return null; + } + + @Override + public ServletInputStream getInputStream() throws IOException { + return null; + } + + @Override + public String getParameter(String s) { + return null; + } + + @Override + public Enumeration getParameterNames() { + return null; + } + + @Override + public String[] getParameterValues(String s) { + return new String[0]; + } + + @Override + public Map getParameterMap() { + return null; + } + + @Override + public String getProtocol() { + return null; + } + + @Override + public String getScheme() { + return null; + } + + @Override + public String getServerName() { + return null; + } + + @Override + public int getServerPort() { + return 0; + } + + @Override + public BufferedReader getReader() throws IOException { + return null; + } + + @Override + public String getRemoteAddr() { + return null; + } + + @Override + public String getRemoteHost() { + return null; + } + + @Override + public void setAttribute(String s, Object o) {} + + @Override + public void removeAttribute(String s) {} + + @Override + public Locale getLocale() { + return null; + } + + @Override + public Enumeration getLocales() { + return null; + } + + @Override + public boolean isSecure() { + return false; + } + + @Override + public RequestDispatcher getRequestDispatcher(String s) { + return null; + } + + public String getRealPath(String s) { + return null; + } + + @Override + public String getRequestId() { + return null; + } + + @Override + public String getProtocolRequestId() { + return null; + } + + @Override + public ServletConnection getServletConnection() { + return null; + } + + @Override + public int getRemotePort() { + return 0; + } + + @Override + public String getLocalName() { + return null; + } + + @Override + public String getLocalAddr() { + return null; + } + + @Override + public int getLocalPort() { + return 0; + } + + @Override + public ServletContext getServletContext() { + return context; + } + + @Override + public AsyncContext startAsync() throws IllegalStateException { + return null; + } + + @Override + public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) + throws IllegalStateException { + return null; + } + + @Override + public boolean isAsyncStarted() { + return false; + } + + @Override + public boolean isAsyncSupported() { + return false; + } + + @Override + public AsyncContext getAsyncContext() { + return null; + } + + @Override + public DispatcherType getDispatcherType() { + return null; + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/DummyResponse.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/DummyResponse.java new file mode 100644 index 00000000000..2fd1e866b13 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/DummyResponse.java @@ -0,0 +1,164 @@ +package foo.bar.smoketest; + +import jakarta.servlet.ServletOutputStream; +import jakarta.servlet.http.Cookie; +import jakarta.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.Collection; +import java.util.Locale; + +public class DummyResponse implements HttpServletResponse { + @Override + public void addCookie(Cookie cookie) {} + + public void addCookie(CustomCookie cookie) {} + + @Override + public boolean containsHeader(String name) { + return false; + } + + @Override + public String encodeURL(String url) { + return "Encoded_" + url; + } + + @Override + public String encodeRedirectURL(String url) { + return "Encoded_" + url; + } + + public String encodeUrl(String url) { + return null; + } + + public String encodeRedirectUrl(String url) { + return null; + } + + @Override + public void sendError(int sc, String msg) throws IOException {} + + @Override + public void sendError(int sc) throws IOException {} + + @Override + public void sendRedirect(String location) throws IOException {} + + @Override + public void setDateHeader(String name, long date) {} + + @Override + public void addDateHeader(String name, long date) {} + + @Override + public void setHeader(String name, String value) {} + + public void setHeader(CustomHeaderName name, String value) {} + + @Override + public void addHeader(String name, String value) {} + + public void addHeader(CustomHeaderName name, String value) {} + + @Override + public void setIntHeader(String name, int value) {} + + @Override + public void addIntHeader(String name, int value) {} + + @Override + public void setStatus(int sc) {} + + public void setStatus(int sc, String sm) {} + + @Override + public void sendRedirect(String location, int sc, boolean clearBuffer) throws IOException {} + + @Override + public int getStatus() { + return 0; + } + + @Override + public String getHeader(String name) { + return null; + } + + @Override + public Collection getHeaders(String name) { + return null; + } + + @Override + public Collection getHeaderNames() { + return null; + } + + @Override + public String getCharacterEncoding() { + return null; + } + + @Override + public String getContentType() { + return null; + } + + @Override + public ServletOutputStream getOutputStream() throws IOException { + return null; + } + + @Override + public PrintWriter getWriter() throws IOException { + return null; + } + + @Override + public void setCharacterEncoding(String charset) {} + + @Override + public void setContentLength(int len) {} + + @Override + public void setContentLengthLong(long len) {} + + @Override + public void setContentType(String type) {} + + @Override + public void setBufferSize(int size) {} + + @Override + public int getBufferSize() { + return 0; + } + + @Override + public void flushBuffer() throws IOException {} + + @Override + public void resetBuffer() {} + + @Override + public boolean isCommitted() { + return false; + } + + @Override + public void reset() {} + + @Override + public void setLocale(Locale loc) {} + + @Override + public Locale getLocale() { + return null; + } + + public static class CustomCookie {} + + public static class CustomHeaderName {} +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/JakartaHttpServletRequestTestSuite.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/JakartaHttpServletRequestTestSuite.java new file mode 100644 index 00000000000..41f3b87641e --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/JakartaHttpServletRequestTestSuite.java @@ -0,0 +1,31 @@ +package foo.bar.smoketest; + +import jakarta.servlet.http.HttpServletRequest; + +public class JakartaHttpServletRequestTestSuite implements ServletRequestTestSuite { + private final HttpServletRequest request; + + public JakartaHttpServletRequestTestSuite(final HttpServletRequest request) { + this.request = request; + } + + @Override + public String getRequestURI() { + return request.getRequestURI(); + } + + @Override + public String getPathInfo() { + return request.getPathInfo(); + } + + @Override + public String getPathTranslated() { + return request.getPathTranslated(); + } + + @Override + public StringBuffer getRequestURL() { + return request.getRequestURL(); + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/JakartaHttpServletRequestWrapperTestSuite.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/JakartaHttpServletRequestWrapperTestSuite.java new file mode 100644 index 00000000000..80039d4cdca --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/JakartaHttpServletRequestWrapperTestSuite.java @@ -0,0 +1,31 @@ +package foo.bar.smoketest; + +import jakarta.servlet.http.HttpServletRequestWrapper; + +public class JakartaHttpServletRequestWrapperTestSuite implements ServletRequestTestSuite { + private final HttpServletRequestWrapper request; + + public JakartaHttpServletRequestWrapperTestSuite(final HttpServletRequestWrapper request) { + this.request = request; + } + + @Override + public String getRequestURI() { + return request.getRequestURI(); + } + + @Override + public String getPathInfo() { + return request.getPathInfo(); + } + + @Override + public String getPathTranslated() { + return request.getPathTranslated(); + } + + @Override + public StringBuffer getRequestURL() { + return request.getRequestURL(); + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/MockPart.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/MockPart.java new file mode 100644 index 00000000000..a8f8fd288b9 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/MockPart.java @@ -0,0 +1,82 @@ +package foo.bar.smoketest; + +import jakarta.servlet.http.Part; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; + +public class MockPart implements Part { + private final String name; + + private final Map> headers; + private final InputStream inputStream; + + public MockPart(final String name, final Map> headers) { + this.name = name; + this.headers = headers; + this.inputStream = null; + } + + public MockPart(final String name, final String headerName, final String... headerValue) { + this.name = name; + this.headers = new HashMap<>(); + this.headers.put(headerName, Arrays.asList(headerValue)); + this.inputStream = null; + } + + public MockPart(final String name, final InputStream inputStream) { + this.name = name; + this.headers = new HashMap<>(); + this.inputStream = inputStream; + } + + @Override + public InputStream getInputStream() throws IOException { + return inputStream; + } + + @Override + public String getContentType() { + return null; + } + + @Override + public String getName() { + return name; + } + + @Override + public String getSubmittedFileName() { + return null; + } + + @Override + public long getSize() { + return 0; + } + + @Override + public void write(String fileName) throws IOException {} + + @Override + public void delete() throws IOException {} + + @Override + public String getHeader(String name) { + final Collection values = this.headers.get(name); + return values == null || values.isEmpty() ? null : values.iterator().next(); + } + + @Override + public Collection getHeaders(String name) { + return headers.get(name); + } + + @Override + public Collection getHeaderNames() { + return this.headers.keySet(); + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/ServletRequestTestSuite.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/ServletRequestTestSuite.java new file mode 100644 index 00000000000..c1e5e45cde4 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/ServletRequestTestSuite.java @@ -0,0 +1,12 @@ +package foo.bar.smoketest; + +public interface ServletRequestTestSuite { + + String getRequestURI(); + + String getPathInfo(); + + String getPathTranslated(); + + StringBuffer getRequestURL(); +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/TestJspWriterSuite.java b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/TestJspWriterSuite.java new file mode 100644 index 00000000000..350d8049486 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/test/java/foo/bar/smoketest/TestJspWriterSuite.java @@ -0,0 +1,37 @@ +package foo.bar.smoketest; + +import jakarta.servlet.jsp.JspWriter; +import java.io.IOException; + +public class TestJspWriterSuite { + + JspWriter writer; + + public TestJspWriterSuite(final JspWriter writer) { + this.writer = writer; + } + + public void printlnTest(char[] x) throws IOException { + writer.println(x); + } + + public void printlnTest(String x) throws IOException { + writer.println(x); + } + + public void printTest(char[] s) throws IOException { + writer.print(s); + } + + public void printTest(String s) throws IOException { + writer.print(s); + } + + public void write(char[] s) throws IOException { + writer.write(s); + } + + public void write(String s) throws IOException { + writer.write(s); + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/testFixtures/groovy/datadog/trace/instrumentation/servlet5/AsyncRumServlet.groovy b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/testFixtures/groovy/datadog/trace/instrumentation/servlet5/AsyncRumServlet.groovy new file mode 100644 index 00000000000..536bc3d9b27 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/testFixtures/groovy/datadog/trace/instrumentation/servlet5/AsyncRumServlet.groovy @@ -0,0 +1,57 @@ +package datadog.trace.instrumentation.servlet5 + +import jakarta.servlet.AsyncContext +import jakarta.servlet.ServletException +import jakarta.servlet.http.HttpServlet +import jakarta.servlet.http.HttpServletRequest +import jakarta.servlet.http.HttpServletResponse + +class AsyncRumServlet extends HttpServlet { + private final String mimeType + + AsyncRumServlet(String mime) { + this.mimeType = mime + } + + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + // write a partial content + resp.getWriter().println("\n" + + "") + // finish it later + final AsyncContext asyncContext = req.startAsync() + final String mime = mimeType + new Timer().schedule(new TimerTask() { + @Override + void run() { + def writer = asyncContext.getResponse().getWriter() + try { + asyncContext.getResponse().setContentType(mime) + writer.println( + "\n" + + " \n" + + " This is the title of the webpage!\n" + + " \n" + + " \n" + + "

This is an example paragraph. Anything in the body tag will appear on the page, just like this p tag and its contents.

\n" + + " \n" + + "") + } finally { + asyncContext.complete() + } + } + }, 2000) + } +} + +class HtmlAsyncRumServlet extends AsyncRumServlet { + HtmlAsyncRumServlet() { + super("text/html") + } +} + +class XmlAsyncRumServlet extends AsyncRumServlet { + XmlAsyncRumServlet() { + super("text/xml") + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/testFixtures/groovy/datadog/trace/instrumentation/servlet5/RumServlet.groovy b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/testFixtures/groovy/datadog/trace/instrumentation/servlet5/RumServlet.groovy new file mode 100644 index 00000000000..af2851fda83 --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/testFixtures/groovy/datadog/trace/instrumentation/servlet5/RumServlet.groovy @@ -0,0 +1,43 @@ +package datadog.trace.instrumentation.servlet5 + +import jakarta.servlet.ServletException +import jakarta.servlet.http.HttpServlet +import jakarta.servlet.http.HttpServletRequest +import jakarta.servlet.http.HttpServletResponse + +class RumServlet extends HttpServlet { + private final String mimeType + + RumServlet(String mime) { + this.mimeType = mime + } + + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + try (def writer = resp.getWriter()) { + resp.setContentType(mimeType) + writer.println("\n" + + "\n" + + "\n" + + " \n" + + " This is the title of the webpage!\n" + + " \n" + + " \n" + + "

This is an example paragraph. Anything in the body tag will appear on the page, just like this p tag and its contents.

\n" + + " \n" + + "") + } + } +} + +class HtmlRumServlet extends RumServlet { + HtmlRumServlet() { + super("text/html") + } +} + +class XmlRumServlet extends RumServlet { + XmlRumServlet() { + super("text/xml") + } +} diff --git a/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/testFixtures/groovy/datadog/trace/instrumentation/servlet5/TestServlet5.groovy b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/testFixtures/groovy/datadog/trace/instrumentation/servlet5/TestServlet5.groovy new file mode 100644 index 00000000000..51e7c974f6d --- /dev/null +++ b/dd-java-agent/instrumentation/servlet/jakarta-servlet-6.0/src/testFixtures/groovy/datadog/trace/instrumentation/servlet5/TestServlet5.groovy @@ -0,0 +1,127 @@ +package datadog.trace.instrumentation.servlet5 + +import datadog.appsec.api.blocking.Blocking +import datadog.trace.agent.test.base.HttpServerTest +import jakarta.servlet.ServletException +import jakarta.servlet.http.HttpServlet +import jakarta.servlet.http.HttpServletRequest +import jakarta.servlet.http.HttpServletResponse + +import java.lang.reflect.Field +import java.lang.reflect.Modifier + +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.BODY_MULTIPART +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.BODY_URLENCODED +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.CREATED +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.CREATED_IS +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.CUSTOM_EXCEPTION +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.ERROR +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.EXCEPTION +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.FORWARDED +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.NOT_FOUND +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.QUERY_ENCODED_BOTH +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.QUERY_ENCODED_QUERY +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.QUERY_PARAM +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.REDIRECT +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.SESSION_ID +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.SUCCESS +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.USER_BLOCK +import static datadog.trace.agent.test.base.HttpServerTest.ServerEndpoint.forPath +import static datadog.trace.agent.test.base.HttpServerTest.controller +import static datadog.trace.agent.test.base.HttpServerTest.IG_RESPONSE_HEADER +import static datadog.trace.agent.test.base.HttpServerTest.IG_RESPONSE_HEADER_VALUE + +class TestServlet5 extends HttpServlet { + @Override + protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + String path = req.requestURI.substring(req.getContextPath().length()) + + HttpServerTest.ServerEndpoint endpoint = forPath(path) + controller(endpoint) { + resp.contentType = "text/plain" + resp.addHeader(IG_RESPONSE_HEADER, IG_RESPONSE_HEADER_VALUE) + switch (endpoint) { + case SUCCESS: + resp.status = endpoint.status + resp.writer.print(endpoint.body) + break + case CREATED: + resp.status = endpoint.status + resp.writer.print("${endpoint.body}: ${req.reader.text}") + break + case CREATED_IS: + resp.status = endpoint.status + def stream = req.inputStream + resp.writer.print("${endpoint.body}: ${stream.getText('UTF-8')}") + try { + Field f = stream.getClass().getField('is') + def innerStream = f.get(stream) + def method = innerStream.getClass().getMethod('isFinished') + if ((method.getModifiers() & Modifier.ABSTRACT) == 0) { + if (!stream.isFinished()) { + throw new RuntimeException("Not finished") + } + } + } catch (NoSuchMethodException | NoSuchFieldException mnf) {} + break + case FORWARDED: + resp.status = endpoint.status + resp.writer.print(req.getHeader("x-forwarded-for")) + break + case BODY_MULTIPART: + case BODY_URLENCODED: + resp.status = endpoint.status + resp.writer.print( + req.parameterMap + .findAll { + it.key != 'ignore' + } + .collectEntries { [it.key, it.value as List] } as String) + break + case QUERY_ENCODED_BOTH: + case QUERY_ENCODED_QUERY: + case QUERY_PARAM: + resp.status = endpoint.status + resp.writer.print(endpoint.bodyForQuery(req.queryString)) + break + case USER_BLOCK: + Blocking.forUser('user-to-block').blockIfMatch() + break + case REDIRECT: + resp.sendRedirect(endpoint.body) + break + case ERROR: + resp.sendError(endpoint.status, endpoint.body) + break + case EXCEPTION: + throw new Exception(endpoint.body) + case CUSTOM_EXCEPTION: + throw new InputMismatchException(endpoint.body) + case SESSION_ID: + req.getSession(true) + resp.status = endpoint.status + resp.writer.print(req.requestedSessionId) + break + default: + resp.status = NOT_FOUND.status + resp.writer.print(NOT_FOUND.body) + break + } + } + } + static HttpServerTest.ServerEndpoint getEndpoint(HttpServletRequest req) { + String truePath + if (req.servletPath == "") { + truePath = req.requestURI - ~'^/[^/]+' + } else { + // Most correct would be to get the dispatched path from the request + // This is not part of the spec varies by implementation so the simplest is just removing + // "/dispatch" + truePath = req.servletPath.replace("/dispatch", "") + } + return HttpServerTest.ServerEndpoint.forPath(truePath) + } + HttpServerTest.ServerEndpoint determineEndpoint(HttpServletRequest req) { + getEndpoint(req) + } +} diff --git a/metadata/supported-configurations.json b/metadata/supported-configurations.json index 93e70b13a35..0ad3c3f2b41 100644 --- a/metadata/supported-configurations.json +++ b/metadata/supported-configurations.json @@ -9833,6 +9833,22 @@ "aliases": ["DD_TRACE_INTEGRATION_SERVLET_5_ENABLED", "DD_INTEGRATION_SERVLET_5_ENABLED"] } ], + "DD_TRACE_SERVLET_6_ASYNC_CONTEXT_ENABLED": [ + { + "version": "A", + "type": "boolean", + "default": "true", + "aliases": ["DD_TRACE_INTEGRATION_SERVLET_6_ASYNC_CONTEXT_ENABLED", "DD_INTEGRATION_SERVLET_6_ASYNC_CONTEXT_ENABLED"] + } + ], + "DD_TRACE_SERVLET_6_ENABLED": [ + { + "version": "A", + "type": "boolean", + "default": "true", + "aliases": ["DD_TRACE_INTEGRATION_SERVLET_6_ENABLED", "DD_INTEGRATION_SERVLET_6_ENABLED"] + } + ], "DD_TRACE_SERVLET_ANALYTICS_ENABLED": [ { "version": "A", diff --git a/settings.gradle.kts b/settings.gradle.kts index 487f6275cf1..cc9edcfb49c 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -567,6 +567,7 @@ include( ":dd-java-agent:instrumentation:servicetalk:servicetalk-0.42.56", ":dd-java-agent:instrumentation:servicetalk", ":dd-java-agent:instrumentation:servlet:jakarta-servlet-5.0", + ":dd-java-agent:instrumentation:servlet:jakarta-servlet-6.0", ":dd-java-agent:instrumentation:servlet:javax-servlet:javax-servlet-2.2", ":dd-java-agent:instrumentation:servlet:javax-servlet:javax-servlet-3.0", ":dd-java-agent:instrumentation:servlet:javax-servlet:javax-servlet-common",