diff --git a/pom.xml b/pom.xml index eefac53..103a5b5 100644 --- a/pom.xml +++ b/pom.xml @@ -22,6 +22,7 @@ proxy-wasm-java-host proxy-wasm-jaxrs + proxy-wasm-jaxrs-jersey @@ -53,6 +54,8 @@ io.quarkus.platform 3.19.3 3.21.0 + 3.1.10 + 11.0.25 diff --git a/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/plugin/Plugin.java b/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/plugin/Plugin.java index e2f62c4..15ddcfa 100644 --- a/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/plugin/Plugin.java +++ b/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/plugin/Plugin.java @@ -35,7 +35,7 @@ public final class Plugin { private final ReentrantLock lock = new ReentrantLock(); final ProxyWasm wasm; - ServerAdaptor httpServer; + ServerAdaptor serverAdaptor; private final boolean shared; private final String name; @@ -76,8 +76,12 @@ public boolean isShared() { return shared; } - public void setHttpServer(ServerAdaptor httpServer) { - this.httpServer = httpServer; + public ServerAdaptor getServerAdaptor() { + return serverAdaptor; + } + + public void setServerAdaptor(ServerAdaptor serverAdaptor) { + this.serverAdaptor = serverAdaptor; } public Logger logger() { @@ -318,7 +322,7 @@ public WasmResult setTickPeriodMilliseconds(int tickMs) { // schedule the new tick cancelTick = - httpServer.scheduleTick( + serverAdaptor.scheduleTick( Math.max(minTickPeriodMilliseconds, tickPeriodMilliseconds), () -> { lock(); @@ -423,7 +427,7 @@ public int httpCall( try { var id = lastCallId.incrementAndGet(); var future = - httpServer.scheduleHttpCall( + serverAdaptor.scheduleHttpCall( method, connectHost, connectPort, diff --git a/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/plugin/Pool.java b/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/plugin/Pool.java index ee67b5e..9d76da5 100644 --- a/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/plugin/Pool.java +++ b/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/plugin/Pool.java @@ -17,8 +17,9 @@ default void close() {} class SharedPlugin implements Pool { private final Plugin plugin; - public SharedPlugin(Plugin plugin) throws StartException { + public SharedPlugin(ServerAdaptor serverAdaptor, Plugin plugin) throws StartException { this.plugin = plugin; + this.plugin.setServerAdaptor(serverAdaptor); } public void close() { @@ -50,10 +51,12 @@ public Plugin borrow() throws StartException { class PluginPerRequest implements Pool { + private final ServerAdaptor serverAdaptor; final PluginFactory factory; private final String name; - public PluginPerRequest(PluginFactory factory, Plugin plugin) { + public PluginPerRequest(ServerAdaptor serverAdaptor, PluginFactory factory, Plugin plugin) { + this.serverAdaptor = serverAdaptor; this.factory = factory; this.name = plugin.name(); release(plugin); @@ -67,6 +70,7 @@ public String name() { @Override public Plugin borrow() throws StartException { Plugin plugin = factory.create(); + plugin.setServerAdaptor(serverAdaptor); plugin.wasm.start(); return plugin; } diff --git a/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/plugin/ServerAdaptor.java b/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/plugin/ServerAdaptor.java index 1857732..66af5e1 100644 --- a/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/plugin/ServerAdaptor.java +++ b/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/plugin/ServerAdaptor.java @@ -10,6 +10,8 @@ public interface ServerAdaptor { Runnable scheduleTick(long delay, Runnable task); + HttpRequestAdaptor httpRequestAdaptor(Object context); + Runnable scheduleHttpCall( String method, String host, diff --git a/proxy-wasm-jaxrs-jersey/pom.xml b/proxy-wasm-jaxrs-jersey/pom.xml new file mode 100644 index 0000000..25eae53 --- /dev/null +++ b/proxy-wasm-jaxrs-jersey/pom.xml @@ -0,0 +1,109 @@ + + + 4.0.0 + + + io.roastedroot + proxy-wasm-java-host-parent + 1.0-SNAPSHOT + ../pom.xml + + + proxy-wasm-jaxrs-jersey + jar + proxy-wasm-jaxrs-jersey + + + + io.roastedroot + proxy-wasm-jaxrs + ${project.version} + + + org.eclipse.jetty + jetty-server + ${jetty.version} + + + org.eclipse.jetty + jetty-servlet + ${jetty.version} + + + org.glassfish.jersey.containers + jersey-container-jetty-http + ${jersey.version} + + + org.glassfish.jersey.containers + jersey-container-servlet + ${jersey.version} + + + + + org.glassfish.jersey.core + jersey-server + ${jersey.version} + + + org.glassfish.jersey.inject + jersey-hk2 + ${jersey.version} + + + + com.google.code.gson + gson + 2.12.1 + test + + + io.rest-assured + rest-assured + 5.3.1 + test + + + + + org.junit.jupiter + junit-jupiter + ${junit.version} + test + + + + + org.slf4j + slf4j-nop + 2.0.12 + test + + + + + + + src/test/resources + + + + + maven-compiler-plugin + + + default-compile + + true + + + + + + maven-surefire-plugin + ${surefire-plugin.version} + + + + diff --git a/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/App.java b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/App.java new file mode 100644 index 0000000..2e77ca5 --- /dev/null +++ b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/App.java @@ -0,0 +1,81 @@ +package io.roastedroot.proxywasm.jaxrs.example; + +import com.dylibso.chicory.wasm.Parser; +import com.dylibso.chicory.wasm.WasmModule; +import com.google.gson.Gson; +import io.roastedroot.proxywasm.StartException; +import io.roastedroot.proxywasm.plugin.Plugin; +import io.roastedroot.proxywasm.plugin.PluginFactory; +import java.nio.file.Path; +import java.util.Map; + +public class App { + + public static final String EXAMPLES_DIR = "../proxy-wasm-java-host/src/test"; + private static final Gson gson = new Gson(); + + public static WasmModule parseTestModule(String file) { + return Parser.parse(Path.of(EXAMPLES_DIR + file)); + } + + public static PluginFactory headerTests() throws StartException { + return () -> + Plugin.builder() + .withName("headerTests") + .withLogger(new MockLogger("headerTests")) + .withPluginConfig(gson.toJson(Map.of("type", "headerTests"))) + .build(parseTestModule("/go-examples/unit_tester/main.wasm")); + } + + public static PluginFactory headerTestsNotShared() throws StartException { + return () -> + Plugin.builder() + .withName("headerTestsNotShared") + .withShared(false) + .withLogger(new MockLogger("headerTestsNotShared")) + .withPluginConfig(gson.toJson(Map.of("type", "headerTests"))) + .build(parseTestModule("/go-examples/unit_tester/main.wasm")); + } + + public static PluginFactory tickTests() throws StartException { + return () -> + Plugin.builder() + .withName("tickTests") + .withLogger(new MockLogger("tickTests")) + .withPluginConfig(gson.toJson(Map.of("type", "tickTests"))) + .build(parseTestModule("/go-examples/unit_tester/main.wasm")); + } + + public static PluginFactory ffiTests() throws StartException { + return () -> + Plugin.builder() + .withName("ffiTests") + .withLogger(new MockLogger("ffiTests")) + .withPluginConfig(gson.toJson(Map.of("type", "ffiTests"))) + .withForeignFunctions(Map.of("reverse", App::reverse)) + .build(parseTestModule("/go-examples/unit_tester/main.wasm")); + } + + public static byte[] reverse(byte[] data) { + byte[] reversed = new byte[data.length]; + for (int i = 0; i < data.length; i++) { + reversed[i] = data[data.length - 1 - i]; + } + return reversed; + } + + public static PluginFactory httpCallTests() throws StartException { + return () -> + Plugin.builder() + .withName("httpCallTests") + .withLogger(new MockLogger("httpCallTests")) + .withPluginConfig( + gson.toJson( + Map.of( + "type", "httpCallTests", + "upstream", "web_service", + "path", "/ok"))) + .withUpstreams(Map.of("web_service", "localhost:8081")) + .build(parseTestModule("/go-examples/unit_tester/main.wasm")); + } +} diff --git a/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/Helpers.java b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/Helpers.java new file mode 100644 index 0000000..0719367 --- /dev/null +++ b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/Helpers.java @@ -0,0 +1,44 @@ +package io.roastedroot.proxywasm.jaxrs.example; + +import java.util.ArrayList; +import org.hamcrest.Description; +import org.hamcrest.TypeSafeMatcher; +import org.junit.jupiter.api.Assertions; + +public class Helpers { + private Helpers() {} + + public static void assertLogsContain(ArrayList loggedMessages, String... message) { + for (String m : message) { + Assertions.assertTrue( + loggedMessages.contains(m), "logged messages does not contain: " + m); + } + } + + public static TypeSafeMatcher isTrue(IsTrueMatcher.Predicate predicate) { + return new IsTrueMatcher(predicate); + } + + public static class IsTrueMatcher extends TypeSafeMatcher { + + public interface Predicate { + boolean matchesSafely(T value); + } + + Predicate predicate; + + public IsTrueMatcher(Predicate predicate) { + this.predicate = predicate; + } + + @Override + protected boolean matchesSafely(T item) { + return predicate.matchesSafely(item); + } + + @Override + public void describeTo(Description description) { + description.appendText("is not true"); + } + } +} diff --git a/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/MockLogger.java b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/MockLogger.java new file mode 100644 index 0000000..33799f8 --- /dev/null +++ b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/MockLogger.java @@ -0,0 +1,34 @@ +package io.roastedroot.proxywasm.jaxrs.example; + +import io.roastedroot.proxywasm.LogLevel; +import io.roastedroot.proxywasm.plugin.Logger; +import java.util.ArrayList; + +public class MockLogger implements Logger { + + static final boolean DEBUG = "true".equals(System.getenv("DEBUG")); + + final ArrayList loggedMessages = new ArrayList<>(); + private final String name; + + public MockLogger(String name) { + this.name = name; + } + + @Override + public synchronized void log(LogLevel level, String message) { + if (DEBUG) { + System.out.printf("%s: [%s] %s\n", level, name, message); + } + loggedMessages.add(message); + } + + @Override + public synchronized LogLevel getLogLevel() { + return LogLevel.TRACE; + } + + public synchronized ArrayList loggedMessages() { + return new ArrayList<>(loggedMessages); + } +} diff --git a/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/Resources.java b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/Resources.java new file mode 100644 index 0000000..d947078 --- /dev/null +++ b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/Resources.java @@ -0,0 +1,72 @@ +package io.roastedroot.proxywasm.jaxrs.example; + +import io.roastedroot.proxywasm.jaxrs.WasmPlugin; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.HeaderParam; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.Response; + +@Path("/") +public class Resources { + + @Context ContainerRequestContext requestContext; + + @Path("/fail") + @GET + public Response fail() { + Response.ResponseBuilder builder = Response.status(Response.Status.BAD_REQUEST); + for (String header : requestContext.getHeaders().keySet()) { + builder.header("echo-" + header, requestContext.getHeaderString(header)); + } + return builder.build(); + } + + @Path("/ok") + @GET + public Response ok() { + Response.ResponseBuilder builder = Response.status(Response.Status.OK); + for (String header : requestContext.getHeaders().keySet()) { + builder.header("echo-" + header, requestContext.getHeaderString(header)); + } + return builder.entity("ok").build(); + } + + @Path("/headerTests") + @GET + @WasmPlugin("headerTests") + public String uhttpHeaders(@HeaderParam("x-request-counter") String counter) { + return String.format("counter: %s", counter); + } + + @Path("/headerTestsNotShared") + @GET + @WasmPlugin("headerTestsNotShared") + public String unotSharedHttpHeaders(@HeaderParam("x-request-counter") String counter) { + return String.format("counter: %s", counter); + } + + @Path("/tickTests/{sub: .+ }") + @GET + @WasmPlugin("tickTests") + public String tickTests(@PathParam("sub") String sub) { + return "hello world"; + } + + @Path("/ffiTests/reverse") + @POST + @WasmPlugin("ffiTests") + public String ffiTests(String body) { + return body; + } + + @Path("/httpCallTests") + @GET + @WasmPlugin("httpCallTests") + public String httpCallTests() { + return "hello world"; + } +} diff --git a/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/BaseTest.java b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/BaseTest.java new file mode 100644 index 0000000..38ee894 --- /dev/null +++ b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/BaseTest.java @@ -0,0 +1,56 @@ +package io.roastedroot.proxywasm.jaxrs.example.tests; + +import io.restassured.specification.RequestSpecification; +import io.roastedroot.proxywasm.jaxrs.WasmPluginFeature; +import io.roastedroot.proxywasm.jaxrs.example.App; +import io.roastedroot.proxywasm.jaxrs.example.Resources; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; +import org.glassfish.jersey.server.ResourceConfig; +import org.glassfish.jersey.servlet.ServletContainer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +public class BaseTest { + protected Server server; + protected static final int PORT = 8081; + + public RequestSpecification given() { + return io.restassured.RestAssured.given().port(PORT); + } + + @BeforeEach + public void setUp() throws Exception { + server = new Server(PORT); + ServletContextHandler context = new ServletContextHandler(ServletContextHandler.SESSIONS); + context.setContextPath("/"); + + ResourceConfig resourceConfig = new ResourceConfig(); + resourceConfig.register(Resources.class); + + // Create mock Instance for WasmPluginFeature + resourceConfig.register( + new WasmPluginFeature( + new io.roastedroot.proxywasm.jaxrs.ServerAdaptor(), + App.headerTests(), + App.headerTestsNotShared(), + App.tickTests(), + App.ffiTests(), + App.httpCallTests())); + + ServletHolder jerseyServlet = new ServletHolder(new ServletContainer(resourceConfig)); + jerseyServlet.setInitOrder(0); + context.addServlet(jerseyServlet, "/*"); + + server.setHandler(context); + server.start(); + } + + @AfterEach + public void tearDown() throws Exception { + if (server != null) { + server.stop(); + } + } +} diff --git a/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/FFITest.java b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/FFITest.java new file mode 100644 index 0000000..b7b231a --- /dev/null +++ b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/FFITest.java @@ -0,0 +1,18 @@ +package io.roastedroot.proxywasm.jaxrs.example.tests; + +import static org.hamcrest.Matchers.equalTo; + +import org.junit.jupiter.api.Test; + +public class FFITest extends BaseTest { + + @Test + public void reverse() throws InterruptedException { + given().body("My Test") + .when() + .post("/ffiTests/reverse") + .then() + .statusCode(200) + .body(equalTo("tseT yM")); + } +} diff --git a/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/HeadersTest.java b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/HeadersTest.java new file mode 100644 index 0000000..0023347 --- /dev/null +++ b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/HeadersTest.java @@ -0,0 +1,48 @@ +package io.roastedroot.proxywasm.jaxrs.example.tests; + +import static io.restassured.RestAssured.given; +import static org.hamcrest.Matchers.equalTo; + +import org.junit.jupiter.api.Test; + +/** + * This test verifies that the plugin can modify request and response headers. + *

+ * It also verifies that the plugin state is shared between requests or can be isolated using configuration. + */ +public class HeadersTest extends BaseTest { + + @Test + public void testShared() { + given().when() + .get("/headerTests") + .then() + .statusCode(200) + .header("x-response-counter", "1") + .body(equalTo("counter: 1")); + + given().when() + .get("/headerTests") + .then() + .statusCode(200) + .header("x-response-counter", "2") + .body(equalTo("counter: 2")); + } + + @Test + public void testNotShared() { + given().when() + .get("/headerTestsNotShared") + .then() + .statusCode(200) + .header("x-response-counter", "1") + .body(equalTo("counter: 1")); + + given().when() + .get("/headerTestsNotShared") + .then() + .statusCode(200) + .header("x-response-counter", "1") + .body(equalTo("counter: 1")); + } +} diff --git a/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/HttpCallTest.java b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/HttpCallTest.java new file mode 100644 index 0000000..fb41721 --- /dev/null +++ b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/HttpCallTest.java @@ -0,0 +1,21 @@ +package io.roastedroot.proxywasm.jaxrs.example.tests; + +import static org.hamcrest.Matchers.equalTo; + +import io.roastedroot.proxywasm.StartException; +import org.junit.jupiter.api.Test; + +public class HttpCallTest extends BaseTest { + + @Test + public void test() throws InterruptedException, StartException { + // the wasm plugin will forward the request to the /ok endpoint + given().header("test", "ok") + .when() + .get("/httpCallTests") + .then() + .statusCode(200) + .body(equalTo("ok")) + .header("echo-test", "ok"); + } +} diff --git a/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/TickTest.java b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/TickTest.java new file mode 100644 index 0000000..6e66789 --- /dev/null +++ b/proxy-wasm-jaxrs-jersey/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/TickTest.java @@ -0,0 +1,44 @@ +package io.roastedroot.proxywasm.jaxrs.example.tests; + +import static io.restassured.RestAssured.given; +import static io.roastedroot.proxywasm.jaxrs.example.Helpers.isTrue; +import static org.hamcrest.Matchers.equalTo; + +import org.junit.jupiter.api.Test; + +public class TickTest extends BaseTest { + + @Test + public void tick() throws InterruptedException { + // plugin should not have received any ticks. + given().when().get("/tickTests/get").then().statusCode(200).body(equalTo("0")); + + // ask the plugin to enable tick events.. + given().when().get("/tickTests/enable").then().statusCode(200).body(equalTo("ok")); + + // wait a little to allow the plugin to receive some ticks. (every 100 ms) + Thread.sleep(300); + + // stop getting ticks. + given().when().get("/tickTests/disable").then().statusCode(200).body(equalTo("ok")); + + var counter = new String[] {"0"}; + + // plugin should have received at least 1 tick. + given().when() + .get("/tickTests/get") + .then() + .statusCode(200) + .body( + isTrue( + (String x) -> { + counter[0] = x; + return Integer.parseInt(x) >= 1; + })); + + // since ticks were disabled the tick counter should not have changed. + Thread.sleep(300); + + given().when().get("/tickTests/get").then().statusCode(200).body(equalTo(counter[0])); + } +} diff --git a/proxy-wasm-jaxrs-quarkus/src/main/java/io/roastedroot/proxywasm/jaxrs/quarkus/deployment/ProxyWasmJaxrsQuarkusProcessor.java b/proxy-wasm-jaxrs-quarkus/src/main/java/io/roastedroot/proxywasm/jaxrs/quarkus/deployment/ProxyWasmJaxrsQuarkusProcessor.java index 7bd66a5..6c7cbcf 100644 --- a/proxy-wasm-jaxrs-quarkus/src/main/java/io/roastedroot/proxywasm/jaxrs/quarkus/deployment/ProxyWasmJaxrsQuarkusProcessor.java +++ b/proxy-wasm-jaxrs-quarkus/src/main/java/io/roastedroot/proxywasm/jaxrs/quarkus/deployment/ProxyWasmJaxrsQuarkusProcessor.java @@ -5,7 +5,7 @@ import io.quarkus.deployment.builditem.FeatureBuildItem; import io.quarkus.jaxrs.spi.deployment.AdditionalJaxRsResourceMethodAnnotationsBuildItem; import io.roastedroot.proxywasm.jaxrs.WasmPlugin; -import io.roastedroot.proxywasm.jaxrs.WasmPluginFeature; +import io.roastedroot.proxywasm.jaxrs.cdi.WasmPluginFeature; import java.util.List; import org.jboss.jandex.DotName; diff --git a/proxy-wasm-jaxrs-quarkus/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/HttpCallTest.java b/proxy-wasm-jaxrs-quarkus/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/HttpCallTest.java index 44b9479..0c35873 100644 --- a/proxy-wasm-jaxrs-quarkus/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/HttpCallTest.java +++ b/proxy-wasm-jaxrs-quarkus/src/test/java/io/roastedroot/proxywasm/jaxrs/example/tests/HttpCallTest.java @@ -5,15 +5,11 @@ import io.quarkus.test.junit.QuarkusTest; import io.roastedroot.proxywasm.StartException; -import io.roastedroot.proxywasm.jaxrs.WasmPluginFeature; -import jakarta.inject.Inject; import org.junit.jupiter.api.Test; @QuarkusTest public class HttpCallTest { - @Inject WasmPluginFeature feature; - @Test public void test() throws InterruptedException, StartException { diff --git a/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/AbstractWasmPluginFeature.java b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/AbstractWasmPluginFeature.java new file mode 100644 index 0000000..092d935 --- /dev/null +++ b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/AbstractWasmPluginFeature.java @@ -0,0 +1,73 @@ +package io.roastedroot.proxywasm.jaxrs; + +import io.roastedroot.proxywasm.StartException; +import io.roastedroot.proxywasm.plugin.Plugin; +import io.roastedroot.proxywasm.plugin.PluginFactory; +import io.roastedroot.proxywasm.plugin.Pool; +import io.roastedroot.proxywasm.plugin.ServerAdaptor; +import jakarta.ws.rs.container.DynamicFeature; +import jakarta.ws.rs.container.ResourceInfo; +import jakarta.ws.rs.core.FeatureContext; +import java.util.Collection; +import java.util.HashMap; + +public abstract class AbstractWasmPluginFeature implements DynamicFeature { + + private final HashMap pluginPools = new HashMap<>(); + + public void init(Iterable factories, ServerAdaptor serverAdaptor) + throws StartException { + + if (!pluginPools.isEmpty()) { + return; + } + + for (var factory : factories) { + Plugin plugin = factory.create(); + String name = plugin.name(); + if (this.pluginPools.containsKey(name)) { + throw new IllegalArgumentException("Duplicate wasm plugin name: " + name); + } + Pool pool = + plugin.isShared() + ? new Pool.SharedPlugin(serverAdaptor, plugin) + : new Pool.PluginPerRequest(serverAdaptor, factory, plugin); + this.pluginPools.put(name, pool); + } + } + + public void destroy() { + for (var pool : pluginPools.values()) { + pool.close(); + } + pluginPools.clear(); + } + + public Collection getPluginPools() { + return pluginPools.values(); + } + + public Pool pool(String name) { + return pluginPools.get(name); + } + + @Override + public void configure(ResourceInfo resourceInfo, FeatureContext context) { + + var resourceMethod = resourceInfo.getResourceMethod(); + if (resourceMethod != null) { + WasmPlugin pluignNameAnnotation = resourceMethod.getAnnotation(WasmPlugin.class); + if (pluignNameAnnotation == null) { + // If no annotation on method, check the class level + pluignNameAnnotation = + resourceInfo.getResourceClass().getAnnotation(WasmPlugin.class); + } + if (pluignNameAnnotation != null) { + Pool pool = pluginPools.get(pluignNameAnnotation.value()); + if (pool != null) { + context.register(new WasmPluginFilter(pool)); + } + } + } + } +} diff --git a/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/JaxrsHttpRequestAdaptor.java b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/JaxrsHttpRequestAdaptor.java index 47a141d..1b64dab 100644 --- a/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/JaxrsHttpRequestAdaptor.java +++ b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/JaxrsHttpRequestAdaptor.java @@ -55,7 +55,7 @@ import java.util.Date; import java.util.List; -public abstract class JaxrsHttpRequestAdaptor implements HttpRequestAdaptor { +public class JaxrsHttpRequestAdaptor implements HttpRequestAdaptor { private ContainerRequestContext requestContext; private ContainerResponseContext responseContext; @@ -78,6 +78,26 @@ public void setResponseContext(ContainerResponseContext responseContext) { this.responseContext = responseContext; } + @Override + public String remoteAddress() { + return ""; + } + + @Override + public String remotePort() { + return ""; + } + + @Override + public String localAddress() { + return ""; + } + + @Override + public String localPort() { + return ""; + } + // ////////////////////////////////////////////////////////////////////// // HTTP fields // ////////////////////////////////////////////////////////////////////// diff --git a/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/ServerAdaptor.java b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/ServerAdaptor.java new file mode 100644 index 0000000..b1c34c5 --- /dev/null +++ b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/ServerAdaptor.java @@ -0,0 +1,118 @@ +package io.roastedroot.proxywasm.jaxrs; + +import io.roastedroot.proxywasm.ArrayProxyMap; +import io.roastedroot.proxywasm.ProxyMap; +import io.roastedroot.proxywasm.plugin.HttpCallResponse; +import io.roastedroot.proxywasm.plugin.HttpCallResponseHandler; +import io.roastedroot.proxywasm.plugin.HttpRequestAdaptor; +import jakarta.annotation.Priority; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Alternative; +import jakarta.ws.rs.core.UriBuilder; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +@Alternative +@Priority(100) +@ApplicationScoped +public class ServerAdaptor implements io.roastedroot.proxywasm.plugin.ServerAdaptor { + + ScheduledExecutorService tickExecutorService = Executors.newScheduledThreadPool(1); + ExecutorService executorService = Executors.newWorkStealingPool(5); + HttpClient client = HttpClient.newHttpClient(); + + @Override + public Runnable scheduleTick(long delay, Runnable task) { + var f = tickExecutorService.scheduleAtFixedRate(task, delay, delay, TimeUnit.MILLISECONDS); + return () -> { + f.cancel(false); + }; + } + + @Override + public HttpRequestAdaptor httpRequestAdaptor(Object context) { + return new JaxrsHttpRequestAdaptor(); + } + + @Override + public Runnable scheduleHttpCall( + String method, + String host, + int port, + URI uri, + ProxyMap headers, + byte[] body, + ProxyMap trailers, + int timeout, + HttpCallResponseHandler handler) + throws InterruptedException { + + Callable task = + () -> { + var resp = httpCall(method, host, port, uri, headers, body); + handler.call(resp); + return null; + }; + var f = executorService.submit(task); + Runnable cancel = + () -> { + f.cancel(true); + }; + // TODO: is there a better way to do this? + if (timeout > 0) { + tickExecutorService.schedule(cancel, timeout, TimeUnit.MILLISECONDS); + } + return cancel; + } + + private HttpCallResponse httpCall( + String method, String host, int port, URI uri, ProxyMap headers, byte[] body) { + + try { + var connectUri = UriBuilder.fromUri(uri).host(host).port(port).build(); + + var builder = HttpRequest.newBuilder().uri(connectUri); + for (var e : headers.entries()) { + try { + builder.header(e.getKey(), e.getValue()); + } catch (IllegalArgumentException ignore) { + // ignore + } + } + builder.method(method, HttpRequest.BodyPublishers.ofByteArray(body)); + var request = builder.build(); + + HttpResponse response = + client.send(request, HttpResponse.BodyHandlers.ofByteArray()); + response.headers() + .map() + .forEach( + (k, v) -> { + for (var s : v) { + headers.add(k, s); + } + }); + + var h = new ArrayProxyMap(); + response.headers() + .map() + .forEach( + (k, v) -> { + for (var s : v) { + h.add(k, s); + } + }); + + return new HttpCallResponse(response.statusCode(), h, response.body()); + } catch (Exception e) { + return new HttpCallResponse(500, new ArrayProxyMap(), new byte[] {}); + } + } +} diff --git a/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/WasmPluginFeature.java b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/WasmPluginFeature.java index 7d7940d..b457929 100644 --- a/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/WasmPluginFeature.java +++ b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/WasmPluginFeature.java @@ -1,81 +1,14 @@ package io.roastedroot.proxywasm.jaxrs; import io.roastedroot.proxywasm.StartException; -import io.roastedroot.proxywasm.plugin.Plugin; import io.roastedroot.proxywasm.plugin.PluginFactory; -import io.roastedroot.proxywasm.plugin.Pool; import io.roastedroot.proxywasm.plugin.ServerAdaptor; -import jakarta.annotation.PreDestroy; -import jakarta.enterprise.context.ApplicationScoped; -import jakarta.enterprise.inject.Any; -import jakarta.enterprise.inject.Instance; -import jakarta.inject.Inject; -import jakarta.ws.rs.container.DynamicFeature; -import jakarta.ws.rs.container.ResourceInfo; -import jakarta.ws.rs.core.FeatureContext; -import jakarta.ws.rs.ext.Provider; -import java.util.Collection; -import java.util.HashMap; +import java.util.Arrays; -@Provider -@ApplicationScoped -public class WasmPluginFeature implements DynamicFeature { +public class WasmPluginFeature extends AbstractWasmPluginFeature { - private final HashMap pluginPools = new HashMap<>(); - - @Inject @Any Instance httpServerRequest; - - @Inject - public WasmPluginFeature(Instance factories, @Any ServerAdaptor httpServer) + public WasmPluginFeature(ServerAdaptor httpServer, PluginFactory... factories) throws StartException { - for (var factory : factories) { - Plugin plugin = null; - plugin = factory.create(); - plugin.setHttpServer(httpServer); - String name = plugin.name(); - if (this.pluginPools.containsKey(name)) { - throw new IllegalArgumentException("Duplicate wasm plugin name: " + name); - } - Pool pool = - plugin.isShared() - ? new Pool.SharedPlugin(plugin) - : new Pool.PluginPerRequest(factory, plugin); - this.pluginPools.put(name, pool); - } - } - - @PreDestroy - public void destroy() { - for (var pool : pluginPools.values()) { - pool.close(); - } - } - - public Collection getPluginPools() { - return pluginPools.values(); - } - - public Pool pool(String name) { - return pluginPools.get(name); - } - - @Override - public void configure(ResourceInfo resourceInfo, FeatureContext context) { - - var resourceMethod = resourceInfo.getResourceMethod(); - if (resourceMethod != null) { - WasmPlugin pluignNameAnnotation = resourceMethod.getAnnotation(WasmPlugin.class); - if (pluignNameAnnotation == null) { - // If no annotation on method, check the class level - pluignNameAnnotation = - resourceInfo.getResourceClass().getAnnotation(WasmPlugin.class); - } - if (pluignNameAnnotation != null) { - Pool factory = pluginPools.get(pluignNameAnnotation.value()); - if (factory != null) { - context.register(new WasmPluginFilter(factory, httpServerRequest)); - } - } - } + init(Arrays.asList(factories), httpServer); } } diff --git a/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/WasmPluginFilter.java b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/WasmPluginFilter.java index a6e6c7b..d966bbb 100644 --- a/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/WasmPluginFilter.java +++ b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/WasmPluginFilter.java @@ -8,7 +8,6 @@ import io.roastedroot.proxywasm.plugin.Plugin; import io.roastedroot.proxywasm.plugin.Pool; import io.roastedroot.proxywasm.plugin.SendResponse; -import jakarta.enterprise.inject.Instance; import jakarta.ws.rs.WebApplicationException; import jakarta.ws.rs.container.ContainerRequestContext; import jakarta.ws.rs.container.ContainerRequestFilter; @@ -26,11 +25,8 @@ public class WasmPluginFilter private final Pool pluginPool; - Instance requestAdaptor; - - public WasmPluginFilter(Pool pluginPool, Instance httpServer) { + public WasmPluginFilter(Pool pluginPool) { this.pluginPool = pluginPool; - this.requestAdaptor = httpServer; } @Override @@ -46,7 +42,9 @@ public void filter(ContainerRequestContext requestContext) throws IOException { plugin.lock(); try { - var requestAdaptor = this.requestAdaptor.get(); + var requestAdaptor = + (JaxrsHttpRequestAdaptor) + plugin.getServerAdaptor().httpRequestAdaptor(requestContext); var httpContext = plugin.createHttpContext(requestAdaptor); requestContext.setProperty(FILTER_CONTEXT_PROPERTY_NAME, httpContext); @@ -116,7 +114,7 @@ public void filter( try { httpContext.plugin().lock(); - var requestAdaptor = this.requestAdaptor.get(); + var requestAdaptor = (JaxrsHttpRequestAdaptor) httpContext.requestAdaptor(); requestAdaptor.setResponseContext(responseContext); var action = httpContext.context().callOnResponseHeaders(false); if (action == Action.PAUSE) { @@ -147,55 +145,38 @@ public void aroundWriteTo(WriterInterceptorContext ctx) } try { + httpContext.plugin().lock(); // the plugin may not be interested in the request body. - if (httpContext.context().hasOnResponseBody()) { - var original = ctx.getOutputStream(); - ctx.setOutputStream( - new ByteArrayOutputStream() { - boolean closed = false; - - @Override - public void close() throws IOException { - if (closed) { - return; - } - closed = true; - super.close(); - - // TODO: find out if it's more efficient to read the body in chunks - // and - // do - // multiple callOnRequestBody calls. - - byte[] bytes = this.toByteArray(); - - httpContext.plugin().lock(); - - httpContext.setHttpResponseBody(bytes); - var action = httpContext.context().callOnResponseBody(false); - bytes = httpContext.getHttpResponseBody(); - if (action == Action.CONTINUE) { - // continue means plugin is done reading the body. - httpContext.setHttpResponseBody(null); - } else { - httpContext.maybePause(); - } - - // does the plugin want to respond early? - var sendResponse = httpContext.consumeSentHttpResponse(); - if (sendResponse != null) { - throw new WebApplicationException(toResponse(sendResponse)); - } - - // plugin may have modified the body - original.write(bytes); - original.close(); - } - }); + if (!httpContext.context().hasOnResponseBody()) { + ctx.proceed(); } + var original = ctx.getOutputStream(); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ctx.setOutputStream(baos); ctx.proceed(); + + byte[] bytes = baos.toByteArray(); + httpContext.setHttpResponseBody(bytes); + var action = httpContext.context().callOnResponseBody(false); + bytes = httpContext.getHttpResponseBody(); + if (action == Action.CONTINUE) { + // continue means plugin is done reading the body. + httpContext.setHttpResponseBody(null); + } else { + httpContext.maybePause(); + } + + // does the plugin want to respond early? + var sendResponse = httpContext.consumeSentHttpResponse(); + if (sendResponse != null) { + throw new WebApplicationException(toResponse(sendResponse)); + } + + // plugin may have modified the body + original.write(bytes); + } finally { // allow other request to use the plugin. httpContext.context().close(); diff --git a/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/cdi/WasmPluginFeature.java b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/cdi/WasmPluginFeature.java new file mode 100644 index 0000000..daf361e --- /dev/null +++ b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/cdi/WasmPluginFeature.java @@ -0,0 +1,32 @@ +package io.roastedroot.proxywasm.jaxrs.cdi; + +import io.roastedroot.proxywasm.StartException; +import io.roastedroot.proxywasm.plugin.PluginFactory; +import io.roastedroot.proxywasm.plugin.ServerAdaptor; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Any; +import jakarta.enterprise.inject.Instance; +import jakarta.inject.Inject; +import jakarta.ws.rs.ext.Provider; + +@Provider +@ApplicationScoped +public class WasmPluginFeature extends io.roastedroot.proxywasm.jaxrs.AbstractWasmPluginFeature { + + @Inject Instance factories; + + @Inject @Any ServerAdaptor httpServer; + + @Inject + @PostConstruct + public void init() throws StartException { + init(factories, httpServer); + } + + @PreDestroy + public void destroy() { + super.destroy(); + } +} diff --git a/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/servlet/ServletJaxrsHttpRequestAdaptor.java b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/servlet/ServletJaxrsHttpRequestAdaptor.java new file mode 100644 index 0000000..2e84858 --- /dev/null +++ b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/servlet/ServletJaxrsHttpRequestAdaptor.java @@ -0,0 +1,45 @@ +package io.roastedroot.proxywasm.jaxrs.servlet; + +import io.roastedroot.proxywasm.jaxrs.JaxrsHttpRequestAdaptor; +import jakarta.servlet.http.HttpServletRequest; + +public class ServletJaxrsHttpRequestAdaptor extends JaxrsHttpRequestAdaptor { + + private final HttpServletRequest request; + + public ServletJaxrsHttpRequestAdaptor(HttpServletRequest request) { + this.request = request; + } + + @Override + public String remoteAddress() { + if (request == null) { + return ""; + } + return request.getRemoteAddr(); + } + + @Override + public String remotePort() { + if (request == null) { + return ""; + } + return "" + request.getRemotePort(); + } + + @Override + public String localAddress() { + if (request == null) { + return ""; + } + return request.getLocalAddr(); + } + + @Override + public String localPort() { + if (request == null) { + return ""; + } + return "" + request.getLocalPort(); + } +} diff --git a/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/vertx/VertxServerAdaptor.java b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/vertx/VertxServerAdaptor.java index 1c2a7e0..83d0c8c 100644 --- a/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/vertx/VertxServerAdaptor.java +++ b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/vertx/VertxServerAdaptor.java @@ -3,6 +3,7 @@ import io.roastedroot.proxywasm.ProxyMap; import io.roastedroot.proxywasm.plugin.HttpCallResponse; import io.roastedroot.proxywasm.plugin.HttpCallResponseHandler; +import io.roastedroot.proxywasm.plugin.HttpRequestAdaptor; import io.roastedroot.proxywasm.plugin.ServerAdaptor; import io.vertx.core.Vertx; import io.vertx.core.buffer.Buffer; @@ -14,6 +15,7 @@ import jakarta.annotation.Priority; import jakarta.enterprise.context.ApplicationScoped; import jakarta.enterprise.inject.Alternative; +import jakarta.enterprise.inject.Instance; import jakarta.inject.Inject; import java.net.URI; @@ -44,6 +46,13 @@ public Runnable scheduleTick(long delay, Runnable task) { }; } + @Inject Instance httpRequestAdaptors; + + @Override + public HttpRequestAdaptor httpRequestAdaptor(Object context) { + return httpRequestAdaptors.get(); + } + @Override public Runnable scheduleHttpCall( String method,