From c97c0de7e21fbf798d00697961eb2748a668f12d Mon Sep 17 00:00:00 2001 From: Hiram Chirino Date: Thu, 20 Mar 2025 13:58:41 -0400 Subject: [PATCH] Move foreign function state to the handler. Add withForeignFunctions and withLogger functions to the WamPlugin Builder. Signed-off-by: Hiram Chirino --- .../java/io/roastedroot/proxywasm/ABI.java | 7 ++- .../roastedroot/proxywasm/ChainedHandler.java | 4 +- .../io/roastedroot/proxywasm/Handler.java | 4 +- .../io/roastedroot/proxywasm/ProxyWasm.java | 34 ++++++----- .../examples/ForeignCallOnTickTest.java | 2 +- .../proxywasm/examples/MockHandler.java | 12 ++++ .../roastedroot/proxywasm/jaxrs/Logger.java | 10 ++++ .../proxywasm/jaxrs/PluginHandler.java | 28 +++++++-- .../proxywasm/jaxrs/WasmPlugin.java | 13 ++++ .../jaxrs/HttpHeadersNotSharedTest.java | 10 +--- .../proxywasm/jaxrs/HttpHeadersTest.java | 10 +--- .../proxywasm/jaxrs/MockLogger.java | 59 +++++++++++++++++++ .../proxywasm/jaxrs/TestHelpers.java | 8 +++ 13 files changed, 159 insertions(+), 42 deletions(-) create mode 100644 proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/Logger.java create mode 100644 proxy-wasm-jaxrs/src/test/java/io/roastedroot/proxywasm/jaxrs/MockLogger.java diff --git a/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/ABI.java b/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/ABI.java index 0002c09..c2b663e 100644 --- a/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/ABI.java +++ b/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/ABI.java @@ -1868,7 +1868,12 @@ int proxyCallForeignFunction( try { var name = string(readMemory(nameDataPtr, nameSize)); var argument = readMemory(argumentDataPtr, argumentSize); - var result = handler.callForeignFunction(name, argument); + + var func = handler.getForeignFunction(name); + if (func == null) { + return WasmResult.NOT_FOUND.getValue(); + } + var result = func.apply(argument); // Allocate memory in the WebAssembly instance int addr = malloc(result.length); diff --git a/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/ChainedHandler.java b/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/ChainedHandler.java index b1946eb..df4ef20 100644 --- a/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/ChainedHandler.java +++ b/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/ChainedHandler.java @@ -227,8 +227,8 @@ public int dispatchHttpCall( } @Override - public byte[] callForeignFunction(String name, byte[] bytes) throws WasmException { - return next().callForeignFunction(name, bytes); + public ForeignFunction getForeignFunction(String name) { + return next().getForeignFunction(name); } @Override diff --git a/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/Handler.java b/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/Handler.java index 232d2db..0535911 100644 --- a/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/Handler.java +++ b/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/Handler.java @@ -322,8 +322,8 @@ default int dispatchHttpCall( throw new WasmException(WasmResult.UNIMPLEMENTED); } - default byte[] callForeignFunction(String name, byte[] bytes) throws WasmException { - throw new WasmException(WasmResult.NOT_FOUND); + default ForeignFunction getForeignFunction(String name) { + return null; } default int defineMetric(MetricType metricType, String name) throws WasmException { diff --git a/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/ProxyWasm.java b/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/ProxyWasm.java index ddd5ef8..087a65e 100644 --- a/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/ProxyWasm.java +++ b/proxy-wasm-java-host/src/main/java/io/roastedroot/proxywasm/ProxyWasm.java @@ -36,7 +36,6 @@ public final class ProxyWasm implements Closeable { private ProxyMap httpCallResponseHeaders; private ProxyMap httpCallResponseTrailers; private byte[] httpCallResponseBody; - private HashMap foreignFunctions = new HashMap<>(); private ProxyWasm(Builder other) throws StartException { this.vmConfig = other.vmConfig; @@ -53,7 +52,16 @@ private ProxyWasm(Builder other) throws StartException { this.abi.start(); } - // start the vm with the vmHandler, it will receive stuff like log messages. + if (other.start) { + start(); + } + } + + public void start() throws StartException { + if (pluginContext != null) { + throw new IllegalStateException("already started"); + } + this.pluginContext = new PluginContext(this, pluginHandler); registerContext(pluginContext, 0); if (!this.abi.proxyOnVmStart(pluginContext.id(), vmConfig.length)) { @@ -118,15 +126,6 @@ public ProxyMap getHttpCallResponseTrailers() { public byte[] getHttpCallResponseBody() { return httpCallResponseBody; } - - @Override - public byte[] callForeignFunction(String name, byte[] bytes) throws WasmException { - ForeignFunction func = foreignFunctions.get(name); - if (func == null) { - throw new WasmException(WasmResult.NOT_FOUND); - } - return func.apply(bytes); - } }; } @@ -175,6 +174,9 @@ public void tick() { @Override public void close() { + if (this.pluginContext == null) { + return; + } this.pluginContext.close(); if (wasi != null) { wasi.close(); @@ -214,10 +216,6 @@ public int contextId() { return pluginContext.id(); } - public void registerForeignFunction(String name, ForeignFunction func) { - foreignFunctions.put(name, func); - } - ABI abi() { return abi; } @@ -232,6 +230,7 @@ public static class Builder implements Cloneable { private Handler pluginHandler; private ImportMemory memory; private WasiOptions wasiOptions; + private boolean start = true; @Override @SuppressWarnings("NoClone") @@ -247,6 +246,11 @@ public HostFunction[] toHostFunctions() { return ABI_ModuleFactory.toHostFunctions(abi); } + public Builder withStart(boolean start) { + this.start = start; + return this; + } + public ProxyWasm.Builder withVmConfig(byte[] vmConfig) { this.vmConfig = vmConfig; return this; diff --git a/proxy-wasm-java-host/src/test/java/io/roastedroot/proxywasm/examples/ForeignCallOnTickTest.java b/proxy-wasm-java-host/src/test/java/io/roastedroot/proxywasm/examples/ForeignCallOnTickTest.java index 6bfc7c4..896dd76 100644 --- a/proxy-wasm-java-host/src/test/java/io/roastedroot/proxywasm/examples/ForeignCallOnTickTest.java +++ b/proxy-wasm-java-host/src/test/java/io/roastedroot/proxywasm/examples/ForeignCallOnTickTest.java @@ -26,7 +26,7 @@ public void testOnTick() throws StartException { try (var host = builder.build(module)) { assertEquals(tickMilliseconds, handler.getTickPeriodMilliseconds()); - host.registerForeignFunction("compress", data -> data); + handler.registerForeignFunction("compress", data -> data); for (int i = 1; i <= 10; i++) { host.tick(); // call OnTick diff --git a/proxy-wasm-java-host/src/test/java/io/roastedroot/proxywasm/examples/MockHandler.java b/proxy-wasm-java-host/src/test/java/io/roastedroot/proxywasm/examples/MockHandler.java index cd9ed54..bb6b2fe 100644 --- a/proxy-wasm-java-host/src/test/java/io/roastedroot/proxywasm/examples/MockHandler.java +++ b/proxy-wasm-java-host/src/test/java/io/roastedroot/proxywasm/examples/MockHandler.java @@ -7,6 +7,7 @@ import io.roastedroot.proxywasm.Action; import io.roastedroot.proxywasm.ArrayProxyMap; import io.roastedroot.proxywasm.ChainedHandler; +import io.roastedroot.proxywasm.ForeignFunction; import io.roastedroot.proxywasm.Handler; import io.roastedroot.proxywasm.Helpers; import io.roastedroot.proxywasm.LogLevel; @@ -468,4 +469,15 @@ public WasmResult setProperty(List path, byte[] value) { properties.put(path, value); return WasmResult.OK; } + + private final HashMap foreignFunctions = new HashMap<>(); + + @Override + public ForeignFunction getForeignFunction(String name) { + return foreignFunctions.get(name); + } + + public void registerForeignFunction(String name, ForeignFunction function) { + foreignFunctions.put(name, function); + } } diff --git a/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/Logger.java b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/Logger.java new file mode 100644 index 0000000..5cde4c4 --- /dev/null +++ b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/Logger.java @@ -0,0 +1,10 @@ +package io.roastedroot.proxywasm.jaxrs; + +import io.roastedroot.proxywasm.LogLevel; + +public interface Logger { + + void log(LogLevel level, String message); + + LogLevel getLogLevel(); +} diff --git a/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/PluginHandler.java b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/PluginHandler.java index cc5bcb3..3d80f14 100644 --- a/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/PluginHandler.java +++ b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/PluginHandler.java @@ -5,6 +5,7 @@ import static io.roastedroot.proxywasm.WellKnownProperties.PLUGIN_VM_ID; import io.roastedroot.proxywasm.ChainedHandler; +import io.roastedroot.proxywasm.ForeignFunction; import io.roastedroot.proxywasm.Handler; import io.roastedroot.proxywasm.LogLevel; import io.roastedroot.proxywasm.MetricType; @@ -75,20 +76,27 @@ public WasmResult setProperty(List path, byte[] value) { // Logging // ////////////////////////////////////////////////////////////////////// + public Logger logger; + static final boolean DEBUG = "true".equals(System.getenv("DEBUG")); @Override public void log(LogLevel level, String message) throws WasmException { - // TODO: improve - if (DEBUG) { - System.out.println(level + ": " + message); + Logger l = logger; + if (l == null) { + super.log(level, message); + return; } + l.log(level, message); } @Override public LogLevel getLogLevel() throws WasmException { - // TODO: improve - return super.getLogLevel(); + Logger l = logger; + if (l == null) { + return super.getLogLevel(); + } + return l.getLogLevel(); } // ////////////////////////////////////////////////////////////////////// @@ -280,4 +288,14 @@ public WasmResult removeMetric(int metricId) { metricsByName.remove(metric.name); return WasmResult.OK; } + + // ////////////////////////////////////////////////////////////////////// + // FFI + // ////////////////////////////////////////////////////////////////////// + HashMap foreignFunctions; + + @Override + public ForeignFunction getForeignFunction(String name) { + return super.getForeignFunction(name); + } } diff --git a/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/WasmPlugin.java b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/WasmPlugin.java index a797028..5de5bcc 100644 --- a/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/WasmPlugin.java +++ b/proxy-wasm-jaxrs/src/main/java/io/roastedroot/proxywasm/jaxrs/WasmPlugin.java @@ -3,8 +3,11 @@ import com.dylibso.chicory.runtime.ImportMemory; import com.dylibso.chicory.runtime.Instance; import com.dylibso.chicory.wasm.WasmModule; +import io.roastedroot.proxywasm.ForeignFunction; import io.roastedroot.proxywasm.ProxyWasm; import io.roastedroot.proxywasm.StartException; +import java.util.HashMap; +import java.util.Map; import java.util.Objects; import java.util.concurrent.locks.ReentrantLock; @@ -67,6 +70,16 @@ public WasmPlugin.Builder withName(String name) { return this; } + public Builder withForeignFunctions(Map functions) { + this.handler.foreignFunctions = new HashMap<>(functions); + return this; + } + + public Builder withLogger(Logger logger) { + this.handler.logger = logger; + return this; + } + public WasmPlugin.Builder withShared(boolean shared) { this.shared = shared; return this; diff --git a/proxy-wasm-jaxrs/src/test/java/io/roastedroot/proxywasm/jaxrs/HttpHeadersNotSharedTest.java b/proxy-wasm-jaxrs/src/test/java/io/roastedroot/proxywasm/jaxrs/HttpHeadersNotSharedTest.java index f90b39f..c2425bb 100644 --- a/proxy-wasm-jaxrs/src/test/java/io/roastedroot/proxywasm/jaxrs/HttpHeadersNotSharedTest.java +++ b/proxy-wasm-jaxrs/src/test/java/io/roastedroot/proxywasm/jaxrs/HttpHeadersNotSharedTest.java @@ -1,13 +1,11 @@ package io.roastedroot.proxywasm.jaxrs; import static io.restassured.RestAssured.given; -import static io.roastedroot.proxywasm.jaxrs.TestHelpers.EXAMPLES_DIR; +import static io.roastedroot.proxywasm.jaxrs.TestHelpers.parseTestModule; -import com.dylibso.chicory.wasm.Parser; import io.quarkus.test.junit.QuarkusTest; import io.roastedroot.proxywasm.StartException; import jakarta.enterprise.inject.Produces; -import java.nio.file.Path; import org.junit.jupiter.api.Test; @QuarkusTest @@ -20,11 +18,7 @@ public WasmPluginFactory create() throws StartException { .withName("notSharedHttpHeaders") .withShared(false) .withPluginConfig("{\"header\": \"x-wasm-header\", \"value\": \"foo\"}") - .build( - Parser.parse( - Path.of( - EXAMPLES_DIR - + "/go-examples/http_headers/main.wasm"))); + .build(parseTestModule("/go-examples/http_headers/main.wasm")); } @Test diff --git a/proxy-wasm-jaxrs/src/test/java/io/roastedroot/proxywasm/jaxrs/HttpHeadersTest.java b/proxy-wasm-jaxrs/src/test/java/io/roastedroot/proxywasm/jaxrs/HttpHeadersTest.java index bd4e5c6..a5dd861 100644 --- a/proxy-wasm-jaxrs/src/test/java/io/roastedroot/proxywasm/jaxrs/HttpHeadersTest.java +++ b/proxy-wasm-jaxrs/src/test/java/io/roastedroot/proxywasm/jaxrs/HttpHeadersTest.java @@ -1,14 +1,12 @@ package io.roastedroot.proxywasm.jaxrs; import static io.restassured.RestAssured.given; -import static io.roastedroot.proxywasm.jaxrs.TestHelpers.EXAMPLES_DIR; +import static io.roastedroot.proxywasm.jaxrs.TestHelpers.parseTestModule; import static org.hamcrest.Matchers.equalTo; -import com.dylibso.chicory.wasm.Parser; import io.quarkus.test.junit.QuarkusTest; import io.roastedroot.proxywasm.StartException; import jakarta.enterprise.inject.Produces; -import java.nio.file.Path; import org.junit.jupiter.api.Test; @QuarkusTest @@ -21,11 +19,7 @@ public WasmPluginFactory create() throws StartException { .withName("httpHeaders") .withShared(true) .withPluginConfig("{\"header\": \"x-wasm-header\", \"value\": \"foo\"}") - .build( - Parser.parse( - Path.of( - EXAMPLES_DIR - + "/go-examples/http_headers/main.wasm"))); + .build(parseTestModule("/go-examples/http_headers/main.wasm")); } @Test diff --git a/proxy-wasm-jaxrs/src/test/java/io/roastedroot/proxywasm/jaxrs/MockLogger.java b/proxy-wasm-jaxrs/src/test/java/io/roastedroot/proxywasm/jaxrs/MockLogger.java new file mode 100644 index 0000000..d65d8de --- /dev/null +++ b/proxy-wasm-jaxrs/src/test/java/io/roastedroot/proxywasm/jaxrs/MockLogger.java @@ -0,0 +1,59 @@ +package io.roastedroot.proxywasm.jaxrs; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.roastedroot.proxywasm.LogLevel; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class MockLogger implements Logger { + + static final boolean DEBUG = "true".equals(System.getenv("DEBUG")); + + final ArrayList loggedMessages = new ArrayList<>(); + + @Override + public synchronized void log(LogLevel level, String message) { + if (DEBUG) { + System.out.println(level + ": " + message); + } + loggedMessages.add(message); + } + + @Override + public synchronized LogLevel getLogLevel() { + return LogLevel.TRACE; + } + + public synchronized ArrayList loggedMessages() { + return new ArrayList<>(loggedMessages); + } + + public synchronized void assertLogsEqual(String... messages) { + assertEquals(List.of(messages), loggedMessages()); + } + + public synchronized void assertSortedLogsEqual(String... messages) { + assertEquals( + Stream.of(messages).sorted().collect(Collectors.toList()), + loggedMessages().stream().sorted().collect(Collectors.toList())); + } + + public synchronized void assertLogsContain(String... message) { + for (String m : message) { + assertTrue(loggedMessages().contains(m), "logged messages does not contain: " + m); + } + } + + public synchronized void assertLogsDoNotContain(String... message) { + for (String log : loggedMessages()) { + for (String m : message) { + assertFalse(log.contains(m), "logged messages contains: " + m); + } + } + } +} diff --git a/proxy-wasm-jaxrs/src/test/java/io/roastedroot/proxywasm/jaxrs/TestHelpers.java b/proxy-wasm-jaxrs/src/test/java/io/roastedroot/proxywasm/jaxrs/TestHelpers.java index 36f6709..0661609 100644 --- a/proxy-wasm-jaxrs/src/test/java/io/roastedroot/proxywasm/jaxrs/TestHelpers.java +++ b/proxy-wasm-jaxrs/src/test/java/io/roastedroot/proxywasm/jaxrs/TestHelpers.java @@ -1,7 +1,15 @@ package io.roastedroot.proxywasm.jaxrs; +import com.dylibso.chicory.wasm.Parser; +import com.dylibso.chicory.wasm.WasmModule; +import java.nio.file.Path; + public final class TestHelpers { private TestHelpers() {} public static final String EXAMPLES_DIR = "../proxy-wasm-java-host/src/test"; + + public static WasmModule parseTestModule(String file) { + return Parser.parse(Path.of(EXAMPLES_DIR + file)); + } }