Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1868,7 +1868,12 @@ int proxyCallForeignFunction(
try {
var name = string(readMemory(nameDataPtr, nameSize));
var argument = readMemory(argumentDataPtr, argumentSize);
var result = handler.callForeignFunction(name, argument);

var func = handler.getForeignFunction(name);
if (func == null) {
return WasmResult.NOT_FOUND.getValue();
}
var result = func.apply(argument);

// Allocate memory in the WebAssembly instance
int addr = malloc(result.length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ public int dispatchHttpCall(
}

@Override
public byte[] callForeignFunction(String name, byte[] bytes) throws WasmException {
return next().callForeignFunction(name, bytes);
public ForeignFunction getForeignFunction(String name) {
return next().getForeignFunction(name);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,8 @@ default int dispatchHttpCall(
throw new WasmException(WasmResult.UNIMPLEMENTED);
}

default byte[] callForeignFunction(String name, byte[] bytes) throws WasmException {
throw new WasmException(WasmResult.NOT_FOUND);
default ForeignFunction getForeignFunction(String name) {
return null;
}

default int defineMetric(MetricType metricType, String name) throws WasmException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ public final class ProxyWasm implements Closeable {
private ProxyMap httpCallResponseHeaders;
private ProxyMap httpCallResponseTrailers;
private byte[] httpCallResponseBody;
private HashMap<String, ForeignFunction> foreignFunctions = new HashMap<>();

private ProxyWasm(Builder other) throws StartException {
this.vmConfig = other.vmConfig;
Expand All @@ -53,7 +52,16 @@ private ProxyWasm(Builder other) throws StartException {
this.abi.start();
}

// start the vm with the vmHandler, it will receive stuff like log messages.
if (other.start) {
start();
}
}

public void start() throws StartException {
if (pluginContext != null) {
throw new IllegalStateException("already started");
}

this.pluginContext = new PluginContext(this, pluginHandler);
registerContext(pluginContext, 0);
if (!this.abi.proxyOnVmStart(pluginContext.id(), vmConfig.length)) {
Expand Down Expand Up @@ -118,15 +126,6 @@ public ProxyMap getHttpCallResponseTrailers() {
public byte[] getHttpCallResponseBody() {
return httpCallResponseBody;
}

@Override
public byte[] callForeignFunction(String name, byte[] bytes) throws WasmException {
ForeignFunction func = foreignFunctions.get(name);
if (func == null) {
throw new WasmException(WasmResult.NOT_FOUND);
}
return func.apply(bytes);
}
};
}

Expand Down Expand Up @@ -175,6 +174,9 @@ public void tick() {

@Override
public void close() {
if (this.pluginContext == null) {
return;
}
this.pluginContext.close();
if (wasi != null) {
wasi.close();
Expand Down Expand Up @@ -214,10 +216,6 @@ public int contextId() {
return pluginContext.id();
}

public void registerForeignFunction(String name, ForeignFunction func) {
foreignFunctions.put(name, func);
}

ABI abi() {
return abi;
}
Expand All @@ -232,6 +230,7 @@ public static class Builder implements Cloneable {
private Handler pluginHandler;
private ImportMemory memory;
private WasiOptions wasiOptions;
private boolean start = true;

@Override
@SuppressWarnings("NoClone")
Expand All @@ -247,6 +246,11 @@ public HostFunction[] toHostFunctions() {
return ABI_ModuleFactory.toHostFunctions(abi);
}

public Builder withStart(boolean start) {
this.start = start;
return this;
}

public ProxyWasm.Builder withVmConfig(byte[] vmConfig) {
this.vmConfig = vmConfig;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public void testOnTick() throws StartException {
try (var host = builder.build(module)) {
assertEquals(tickMilliseconds, handler.getTickPeriodMilliseconds());

host.registerForeignFunction("compress", data -> data);
handler.registerForeignFunction("compress", data -> data);

for (int i = 1; i <= 10; i++) {
host.tick(); // call OnTick
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.roastedroot.proxywasm.Action;
import io.roastedroot.proxywasm.ArrayProxyMap;
import io.roastedroot.proxywasm.ChainedHandler;
import io.roastedroot.proxywasm.ForeignFunction;
import io.roastedroot.proxywasm.Handler;
import io.roastedroot.proxywasm.Helpers;
import io.roastedroot.proxywasm.LogLevel;
Expand Down Expand Up @@ -468,4 +469,15 @@ public WasmResult setProperty(List<String> path, byte[] value) {
properties.put(path, value);
return WasmResult.OK;
}

private final HashMap<String, ForeignFunction> foreignFunctions = new HashMap<>();

@Override
public ForeignFunction getForeignFunction(String name) {
return foreignFunctions.get(name);
}

public void registerForeignFunction(String name, ForeignFunction function) {
foreignFunctions.put(name, function);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package io.roastedroot.proxywasm.jaxrs;

import io.roastedroot.proxywasm.LogLevel;

public interface Logger {

void log(LogLevel level, String message);

LogLevel getLogLevel();
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import static io.roastedroot.proxywasm.WellKnownProperties.PLUGIN_VM_ID;

import io.roastedroot.proxywasm.ChainedHandler;
import io.roastedroot.proxywasm.ForeignFunction;
import io.roastedroot.proxywasm.Handler;
import io.roastedroot.proxywasm.LogLevel;
import io.roastedroot.proxywasm.MetricType;
Expand Down Expand Up @@ -75,20 +76,27 @@ public WasmResult setProperty(List<String> path, byte[] value) {
// Logging
// //////////////////////////////////////////////////////////////////////

public Logger logger;

static final boolean DEBUG = "true".equals(System.getenv("DEBUG"));

@Override
public void log(LogLevel level, String message) throws WasmException {
// TODO: improve
if (DEBUG) {
System.out.println(level + ": " + message);
Logger l = logger;
if (l == null) {
super.log(level, message);
return;
}
l.log(level, message);
}

@Override
public LogLevel getLogLevel() throws WasmException {
// TODO: improve
return super.getLogLevel();
Logger l = logger;
if (l == null) {
return super.getLogLevel();
}
return l.getLogLevel();
}

// //////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -280,4 +288,14 @@ public WasmResult removeMetric(int metricId) {
metricsByName.remove(metric.name);
return WasmResult.OK;
}

// //////////////////////////////////////////////////////////////////////
// FFI
// //////////////////////////////////////////////////////////////////////
HashMap<String, ForeignFunction> foreignFunctions;

@Override
public ForeignFunction getForeignFunction(String name) {
return super.getForeignFunction(name);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
import com.dylibso.chicory.runtime.ImportMemory;
import com.dylibso.chicory.runtime.Instance;
import com.dylibso.chicory.wasm.WasmModule;
import io.roastedroot.proxywasm.ForeignFunction;
import io.roastedroot.proxywasm.ProxyWasm;
import io.roastedroot.proxywasm.StartException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.locks.ReentrantLock;

Expand Down Expand Up @@ -67,6 +70,16 @@ public WasmPlugin.Builder withName(String name) {
return this;
}

public Builder withForeignFunctions(Map<String, ForeignFunction> functions) {
this.handler.foreignFunctions = new HashMap<>(functions);
return this;
}

public Builder withLogger(Logger logger) {
this.handler.logger = logger;
return this;
}

public WasmPlugin.Builder withShared(boolean shared) {
this.shared = shared;
return this;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
package io.roastedroot.proxywasm.jaxrs;

import static io.restassured.RestAssured.given;
import static io.roastedroot.proxywasm.jaxrs.TestHelpers.EXAMPLES_DIR;
import static io.roastedroot.proxywasm.jaxrs.TestHelpers.parseTestModule;

import com.dylibso.chicory.wasm.Parser;
import io.quarkus.test.junit.QuarkusTest;
import io.roastedroot.proxywasm.StartException;
import jakarta.enterprise.inject.Produces;
import java.nio.file.Path;
import org.junit.jupiter.api.Test;

@QuarkusTest
Expand All @@ -20,11 +18,7 @@ public WasmPluginFactory create() throws StartException {
.withName("notSharedHttpHeaders")
.withShared(false)
.withPluginConfig("{\"header\": \"x-wasm-header\", \"value\": \"foo\"}")
.build(
Parser.parse(
Path.of(
EXAMPLES_DIR
+ "/go-examples/http_headers/main.wasm")));
.build(parseTestModule("/go-examples/http_headers/main.wasm"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
package io.roastedroot.proxywasm.jaxrs;

import static io.restassured.RestAssured.given;
import static io.roastedroot.proxywasm.jaxrs.TestHelpers.EXAMPLES_DIR;
import static io.roastedroot.proxywasm.jaxrs.TestHelpers.parseTestModule;
import static org.hamcrest.Matchers.equalTo;

import com.dylibso.chicory.wasm.Parser;
import io.quarkus.test.junit.QuarkusTest;
import io.roastedroot.proxywasm.StartException;
import jakarta.enterprise.inject.Produces;
import java.nio.file.Path;
import org.junit.jupiter.api.Test;

@QuarkusTest
Expand All @@ -21,11 +19,7 @@ public WasmPluginFactory create() throws StartException {
.withName("httpHeaders")
.withShared(true)
.withPluginConfig("{\"header\": \"x-wasm-header\", \"value\": \"foo\"}")
.build(
Parser.parse(
Path.of(
EXAMPLES_DIR
+ "/go-examples/http_headers/main.wasm")));
.build(parseTestModule("/go-examples/http_headers/main.wasm"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package io.roastedroot.proxywasm.jaxrs;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import io.roastedroot.proxywasm.LogLevel;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class MockLogger implements Logger {

static final boolean DEBUG = "true".equals(System.getenv("DEBUG"));

final ArrayList<String> loggedMessages = new ArrayList<>();

@Override
public synchronized void log(LogLevel level, String message) {
if (DEBUG) {
System.out.println(level + ": " + message);
}
loggedMessages.add(message);
}

@Override
public synchronized LogLevel getLogLevel() {
return LogLevel.TRACE;
}

public synchronized ArrayList<String> loggedMessages() {
return new ArrayList<>(loggedMessages);
}

public synchronized void assertLogsEqual(String... messages) {
assertEquals(List.of(messages), loggedMessages());
}

public synchronized void assertSortedLogsEqual(String... messages) {
assertEquals(
Stream.of(messages).sorted().collect(Collectors.toList()),
loggedMessages().stream().sorted().collect(Collectors.toList()));
}

public synchronized void assertLogsContain(String... message) {
for (String m : message) {
assertTrue(loggedMessages().contains(m), "logged messages does not contain: " + m);
}
}

public synchronized void assertLogsDoNotContain(String... message) {
for (String log : loggedMessages()) {
for (String m : message) {
assertFalse(log.contains(m), "logged messages contains: " + m);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
package io.roastedroot.proxywasm.jaxrs;

import com.dylibso.chicory.wasm.Parser;
import com.dylibso.chicory.wasm.WasmModule;
import java.nio.file.Path;

public final class TestHelpers {
private TestHelpers() {}

public static final String EXAMPLES_DIR = "../proxy-wasm-java-host/src/test";

public static WasmModule parseTestModule(String file) {
return Parser.parse(Path.of(EXAMPLES_DIR + file));
}
}