Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

<groupId>io.roastedroot</groupId>
<artifactId>proxy-wasm-java-host-parent</artifactId>
<version>1.0-SNAPSHOT</version>
<version>999-SNAPSHOT</version>
<packaging>pom</packaging>

<name>proxy-wasm-java-host-parent</name>
Expand Down Expand Up @@ -72,7 +72,7 @@
<junit.version>5.12.0</junit.version>

<!-- runtime versions -->
<chicory.version>1.1.0</chicory.version>
<chicory.version>1.3.0</chicory.version>
<jersey.version>3.1.10</jersey.version>
<jetty.version>11.0.25</jetty.version>

Expand Down
2 changes: 1 addition & 1 deletion proxy-wasm-java-host/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<parent>
<groupId>io.roastedroot</groupId>
<artifactId>proxy-wasm-java-host-parent</artifactId>
<version>1.0-SNAPSHOT</version>
<version>999-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
2 changes: 1 addition & 1 deletion proxy-wasm-jaxrs/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<parent>
<groupId>io.roastedroot</groupId>
<artifactId>proxy-wasm-java-host-parent</artifactId>
<version>1.0-SNAPSHOT</version>
<version>999-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
import jakarta.ws.rs.ext.WriterInterceptorContext;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
* Implements the JAX-RS {@link ContainerRequestFilter}, {@link ContainerResponseFilter},
Expand All @@ -37,8 +40,10 @@
*/
public class WasmPluginFilter
implements ContainerRequestFilter, WriterInterceptor, ContainerResponseFilter {
private static final String FILTER_CONTEXT_PROPERTY_NAME =
PluginHttpContext.class.getName() + ":";

private static final String FILTER_CONTEXT = PluginHttpContext.class.getName() + ".context";

private static final Logger LOGGER = Logger.getLogger(WasmPluginFilter.class.getName());

private final List<Pool> pluginPools;

Expand All @@ -52,6 +57,22 @@ public WasmPluginFilter(List<Pool> pluginPools) {
this.pluginPools = List.copyOf(pluginPools);
}

private static class FilterContext {
final Pool pool;
final Plugin plugin;
final PluginHttpContext httpContext;

FilterContext(Pool pool, Plugin plugin, PluginHttpContext httpContext) {
this.pool = pool;
this.plugin = plugin;
this.httpContext = httpContext;
}

public void release() {
pool.release(plugin);
}
}

/**
* Intercepts incoming JAX-RS requests before they reach the resource method.
*
Expand All @@ -65,34 +86,51 @@ public WasmPluginFilter(List<Pool> pluginPools) {
*/
@Override
public void filter(ContainerRequestContext requestContext) throws IOException {

ArrayList<FilterContext> filterContexts = new ArrayList<>();
requestContext.setProperty(FILTER_CONTEXT, filterContexts);
for (var pluginPool : pluginPools) {
filter(requestContext, pluginPool);
try {
Plugin plugin = pluginPool.borrow();
plugin.lock();
try {
var serverAdaptor = plugin.getServerAdaptor();
var requestAdaptor =
(JaxrsHttpRequestAdaptor)
serverAdaptor.httpRequestAdaptor(requestContext);
requestAdaptor.setRequestContext(requestContext);
var httpContext = plugin.createHttpContext(requestAdaptor);
filterContexts.add(new FilterContext(pluginPool, plugin, httpContext));
} finally {
plugin.unlock();
}
} catch (StartException e) {
LOGGER.log(Level.SEVERE, "Failed to start plugin: " + pluginPool.name(), e);

// release any plugins that were borrowed before the exception
for (var filterContext : filterContexts) {
filterContext.release();
}
filterContexts.clear();
requestContext.abortWith(internalServerError());
return;
}
}
for (var filterContext : filterContexts) {
filter(requestContext, filterContext);
}
}

private void filter(ContainerRequestContext requestContext, Pool pluginPool)
private void filter(ContainerRequestContext requestContext, FilterContext filterContext)
throws IOException {
Plugin plugin;
try {
plugin = pluginPool.borrow();
} catch (StartException e) {
requestContext.abortWith(interalServerError());
return;
}

plugin.lock();
var httpContext = filterContext.httpContext;
httpContext.plugin().lock();
try {
var requestAdaptor =
(JaxrsHttpRequestAdaptor)
plugin.getServerAdaptor().httpRequestAdaptor(requestContext);
var httpContext = plugin.createHttpContext(requestAdaptor);
requestContext.setProperty(
FILTER_CONTEXT_PROPERTY_NAME + pluginPool.name(), httpContext);

// the plugin may not be interested in the request headers.
if (httpContext.context().hasOnRequestHeaders()) {

requestAdaptor.setRequestContext(requestContext);
var action = httpContext.context().callOnRequestHeaders(false);
if (action == Action.PAUSE) {
httpContext.maybePause();
Expand Down Expand Up @@ -135,11 +173,11 @@ private void filter(ContainerRequestContext requestContext, Pool pluginPool)
}

} finally {
plugin.unlock(); // allow another request to use the plugin.
httpContext.plugin().unlock(); // allow another request to use the plugin.
}
}

private static Response interalServerError() {
private static Response internalServerError() {
return Response.status(Response.Status.INTERNAL_SERVER_ERROR).build();
}

Expand All @@ -160,23 +198,23 @@ private static Response interalServerError() {
public void filter(
ContainerRequestContext requestContext, ContainerResponseContext responseContext)
throws IOException {
for (var pluginPool : pluginPools) {
filter(requestContext, responseContext, pluginPool);
var filterContexts = (ArrayList<FilterContext>) requestContext.getProperty(FILTER_CONTEXT);
if (filterContexts == null) {
return;
}

for (var filterContext : filterContexts) {
filter(requestContext, responseContext, filterContext);
}
}

private void filter(
ContainerRequestContext requestContext,
ContainerResponseContext responseContext,
Pool pluginPool)
FilterContext filterContext)
throws IOException {
var httpContext =
(PluginHttpContext)
requestContext.getProperty(
FILTER_CONTEXT_PROPERTY_NAME + pluginPool.name());
if (httpContext == null) {
throw new WebApplicationException(interalServerError());
}

var httpContext = filterContext.httpContext;

// the plugin may not be interested in the request headers.
if (httpContext.context().hasOnResponseHeaders()) {
Expand Down Expand Up @@ -251,6 +289,11 @@ private void filter(
public void aroundWriteTo(WriterInterceptorContext ctx)
throws IOException, WebApplicationException {

var filterContexts = (ArrayList<FilterContext>) ctx.getProperty(FILTER_CONTEXT);
if (filterContexts == null) {
return;
}

try {

var original = ctx.getOutputStream();
Expand All @@ -260,13 +303,8 @@ public void aroundWriteTo(WriterInterceptorContext ctx)

byte[] bytes = baos.toByteArray();

for (var pluginPool : pluginPools) {
var httpContext =
(PluginHttpContext)
ctx.getProperty(FILTER_CONTEXT_PROPERTY_NAME + pluginPool.name());
if (httpContext == null) {
throw new WebApplicationException(interalServerError());
}
for (var filterContext : List.copyOf(filterContexts)) {
var httpContext = filterContext.httpContext;

httpContext.plugin().lock();

Expand Down Expand Up @@ -296,18 +334,17 @@ public void aroundWriteTo(WriterInterceptorContext ctx)
original.write(bytes);

} finally {
for (var pluginPool : pluginPools) {
var httpContext =
(PluginHttpContext)
ctx.getProperty(FILTER_CONTEXT_PROPERTY_NAME + pluginPool.name());
for (var filterContext : List.copyOf(filterContexts)) {

var httpContext = filterContext.httpContext;

// allow other request to use the plugin.
httpContext.context().close();
httpContext.plugin().unlock();

// TODO: will aroundWriteTo always get called so that we can avoid leaking the
// plugin?
pluginPool.release(httpContext.plugin());
filterContext.release();
}
}
}
Expand Down