diff --git a/src/main/java/io/roastedroot/proxywasm/impl/Imports.java b/src/main/java/io/roastedroot/proxywasm/impl/Imports.java index 3745ca2..1b14c91 100644 --- a/src/main/java/io/roastedroot/proxywasm/impl/Imports.java +++ b/src/main/java/io/roastedroot/proxywasm/impl/Imports.java @@ -12,6 +12,7 @@ import io.roastedroot.proxywasm.v1.LogLevel; import io.roastedroot.proxywasm.v1.MapType; import io.roastedroot.proxywasm.v1.MetricType; +import io.roastedroot.proxywasm.v1.QueueName; import io.roastedroot.proxywasm.v1.StreamType; import io.roastedroot.proxywasm.v1.WasmException; import io.roastedroot.proxywasm.v1.WasmResult; @@ -1002,4 +1003,77 @@ int proxySetSharedData(int keyDataPtr, int keySize, int valueDataPtr, int valueS return e.result().getValue(); } } + + @WasmExport + int proxyRegisterSharedQueue(int queueNameDataPtr, int queueNameSize, int returnQueueId) { + try { + // Get queue name from memory + String queueName = string(readMemory(queueNameDataPtr, queueNameSize)); + + var vmId = handler.getProperty("vm_id"); + + // Register shared queue using handler + int queueId = handler.registerSharedQueue(new QueueName(vmId, queueName)); + putUint32(returnQueueId, queueId); + return WasmResult.OK.getValue(); + + } catch (WasmException e) { + return e.result().getValue(); + } + } + + @WasmExport + int proxyResolveSharedQueue( + int vmIdDataPtr, + int vmIdSize, + int queueNameDataPtr, + int queueNameSize, + int returnQueueId) { + try { + // Get vm id from memory + String vmId = string(readMemory(vmIdDataPtr, vmIdSize)); + // Get queue name from memory + String queueName = string(readMemory(queueNameDataPtr, queueNameSize)); + + // Resolve shared queue using handler + int queueId = handler.resolveSharedQueue(new QueueName(vmId, queueName)); + putUint32(returnQueueId, queueId); + return WasmResult.OK.getValue(); + + } catch (WasmException e) { + return e.result().getValue(); + } + } + + @WasmExport + int proxyEnqueueSharedQueue(int queueId, int valueDataPtr, int valueSize) { + try { + // Get value from memory + byte[] value = readMemory(valueDataPtr, valueSize); + + // Enqueue shared queue using handler + WasmResult result = handler.enqueueSharedQueue(queueId, value); + return result.getValue(); + + } catch (WasmException e) { + return e.result().getValue(); + } + } + + @WasmExport + int proxyDequeueSharedQueue(int queueId, int returnValueData, int returnValueSize) { + try { + // Dequeue shared queue using handler + byte[] value = handler.dequeueSharedQueue(queueId); + if (value == null) { + return WasmResult.EMPTY.getValue(); + } + + copyIntoInstance(value, returnValueData, returnValueSize); + return WasmResult.OK.getValue(); + + } catch (WasmException e) { + return e.result().getValue(); + } + } } diff --git a/src/main/java/io/roastedroot/proxywasm/v1/ChainedHandler.java b/src/main/java/io/roastedroot/proxywasm/v1/ChainedHandler.java index b85aa95..44c6c67 100644 --- a/src/main/java/io/roastedroot/proxywasm/v1/ChainedHandler.java +++ b/src/main/java/io/roastedroot/proxywasm/v1/ChainedHandler.java @@ -306,4 +306,24 @@ public SharedData getSharedData(String key) throws WasmException { public WasmResult setSharedData(String key, byte[] value, int cas) { return next().setSharedData(key, value, cas); } + + @Override + public int registerSharedQueue(QueueName queueName) throws WasmException { + return next().registerSharedQueue(queueName); + } + + @Override + public int resolveSharedQueue(QueueName queueName) throws WasmException { + return next().resolveSharedQueue(queueName); + } + + @Override + public byte[] dequeueSharedQueue(int queueId) throws WasmException { + return next().dequeueSharedQueue(queueId); + } + + @Override + public WasmResult enqueueSharedQueue(int queueId, byte[] value) { + return next().enqueueSharedQueue(queueId, value); + } } diff --git a/src/main/java/io/roastedroot/proxywasm/v1/Handler.java b/src/main/java/io/roastedroot/proxywasm/v1/Handler.java index cd3f1fb..024e94d 100644 --- a/src/main/java/io/roastedroot/proxywasm/v1/Handler.java +++ b/src/main/java/io/roastedroot/proxywasm/v1/Handler.java @@ -440,4 +440,20 @@ default SharedData getSharedData(String key) throws WasmException { default WasmResult setSharedData(String key, byte[] value, int cas) { return WasmResult.UNIMPLEMENTED; } + + default int registerSharedQueue(QueueName name) throws WasmException { + throw new WasmException(WasmResult.UNIMPLEMENTED); + } + + default int resolveSharedQueue(QueueName name) throws WasmException { + throw new WasmException(WasmResult.UNIMPLEMENTED); + } + + default byte[] dequeueSharedQueue(int queueId) throws WasmException { + throw new WasmException(WasmResult.UNIMPLEMENTED); + } + + default WasmResult enqueueSharedQueue(int queueId, byte[] value) { + return WasmResult.UNIMPLEMENTED; + } } diff --git a/src/main/java/io/roastedroot/proxywasm/v1/ProxyWasm.java b/src/main/java/io/roastedroot/proxywasm/v1/ProxyWasm.java index fc4ba83..bb2a9fe 100644 --- a/src/main/java/io/roastedroot/proxywasm/v1/ProxyWasm.java +++ b/src/main/java/io/roastedroot/proxywasm/v1/ProxyWasm.java @@ -278,6 +278,10 @@ public void sendHttpCallResponse( this.httpCallResponseBody = null; } + public void sendOnQueueReady(int queueId) { + this.exports.proxyOnQueueReady(pluginContext.id(), queueId); + } + public int contextId() { return pluginContext.id(); } diff --git a/src/main/java/io/roastedroot/proxywasm/v1/QueueName.java b/src/main/java/io/roastedroot/proxywasm/v1/QueueName.java new file mode 100644 index 0000000..8a7e398 --- /dev/null +++ b/src/main/java/io/roastedroot/proxywasm/v1/QueueName.java @@ -0,0 +1,35 @@ +package io.roastedroot.proxywasm.v1; + +import java.util.Objects; + +public class QueueName { + private final String vmId; + private final String name; + + public QueueName(String vmId, String name) { + this.vmId = vmId; + this.name = name; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + QueueName queue = (QueueName) o; + return Objects.equals(vmId, queue.vmId) && Objects.equals(name, queue.name); + } + + @Override + public int hashCode() { + return Objects.hash(vmId, name); + } + + public String vmId() { + return vmId; + } + + public String name() { + return name; + } +} diff --git a/src/test/go-examples/shared_queue/README.md b/src/test/go-examples/shared_queue/README.md new file mode 100644 index 0000000..a72a01a --- /dev/null +++ b/src/test/go-examples/shared_queue/README.md @@ -0,0 +1,4 @@ +## Attribution + +This example originally came from: +https://github.com/proxy-wasm/proxy-wasm-go-sdk/blob/ab4161dcf9246a828008b539a82a1556cf0f2e24/examples/shared_queue diff --git a/src/test/go-examples/shared_queue/go.mod b/src/test/go-examples/shared_queue/go.mod new file mode 100644 index 0000000..3069523 --- /dev/null +++ b/src/test/go-examples/shared_queue/go.mod @@ -0,0 +1,5 @@ +module github.com/proxy-wasm/proxy-wasm-go-sdk/examples/shared_queue + +go 1.24 + +require github.com/proxy-wasm/proxy-wasm-go-sdk v0.0.0-20250212164326-ab4161dcf924 diff --git a/src/test/go-examples/shared_queue/go.sum b/src/test/go-examples/shared_queue/go.sum new file mode 100644 index 0000000..3ddb896 --- /dev/null +++ b/src/test/go-examples/shared_queue/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/proxy-wasm/proxy-wasm-go-sdk v0.0.0-20250212164326-ab4161dcf924 h1:wTcK6gcyTKJMeDka69AMjZYvisdI8CBXzTEfZ+2pOxI= +github.com/proxy-wasm/proxy-wasm-go-sdk v0.0.0-20250212164326-ab4161dcf924/go.mod h1:9mBRvh8I6Td6sg3CwEY+zGFE4DKaIoieCaca1kQnDBE= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/src/test/go-examples/shared_queue/receiver/main.go b/src/test/go-examples/shared_queue/receiver/main.go new file mode 100644 index 0000000..e123935 --- /dev/null +++ b/src/test/go-examples/shared_queue/receiver/main.go @@ -0,0 +1,80 @@ +// Copyright 2020-2024 Tetrate +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "fmt" + + "github.com/proxy-wasm/proxy-wasm-go-sdk/proxywasm" + "github.com/proxy-wasm/proxy-wasm-go-sdk/proxywasm/types" +) + +func main() {} +func init() { + proxywasm.SetVMContext(&vmContext{}) +} + +// vmContext implements types.VMContext. +type vmContext struct { + // Embed the default VM context here, + // so that we don't need to reimplement all the methods. + types.DefaultVMContext +} + +// NewPluginContext implements types.VMContext. +func (*vmContext) NewPluginContext(contextID uint32) types.PluginContext { + return &receiverPluginContext{contextID: contextID} +} + +// receiverPluginContext implements types.PluginContext. +type receiverPluginContext struct { + // Embed the default plugin context here, + // so that we don't need to reimplement all the methods. + contextID uint32 + types.DefaultPluginContext + queueName string +} + +// OnPluginStart implements types.PluginContext. +func (ctx *receiverPluginContext) OnPluginStart(pluginConfigurationSize int) types.OnPluginStartStatus { + // Get Plugin configuration. + config, err := proxywasm.GetPluginConfiguration() + if err != nil { + panic(fmt.Sprintf("failed to get plugin config: %v", err)) + } + + // Treat the config as the queue name for receiving. + ctx.queueName = string(config) + + queueID, err := proxywasm.RegisterSharedQueue(ctx.queueName) + if err != nil { + panic("failed register queue") + } + proxywasm.LogInfof("queue \"%s\" registered as queueID=%d by contextID=%d", ctx.queueName, queueID, ctx.contextID) + return types.OnPluginStartStatusOK +} + +// OnQueueReady implements types.PluginContext. +func (ctx *receiverPluginContext) OnQueueReady(queueID uint32) { + data, err := proxywasm.DequeueSharedQueue(queueID) + switch err { + case types.ErrorStatusEmpty: + return + case nil: + proxywasm.LogInfof("(contextID=%d) dequeued data from %s(queueID=%d): %s", ctx.contextID, ctx.queueName, queueID, string(data)) + default: + proxywasm.LogCriticalf("error retrieving data from queue %d: %v", queueID, err) + } +} diff --git a/src/test/go-examples/shared_queue/receiver/main.wasm b/src/test/go-examples/shared_queue/receiver/main.wasm new file mode 100644 index 0000000..7f75b6a Binary files /dev/null and b/src/test/go-examples/shared_queue/receiver/main.wasm differ diff --git a/src/test/go-examples/shared_queue/sender/main.go b/src/test/go-examples/shared_queue/sender/main.go new file mode 100644 index 0000000..41ffa10 --- /dev/null +++ b/src/test/go-examples/shared_queue/sender/main.go @@ -0,0 +1,187 @@ +// Copyright 2020-2024 Tetrate +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "encoding/hex" + "fmt" + "hash/fnv" + + "github.com/proxy-wasm/proxy-wasm-go-sdk/proxywasm" + "github.com/proxy-wasm/proxy-wasm-go-sdk/proxywasm/types" +) + +const receiverVMID = "receiver" + +func main() {} +func init() { + proxywasm.SetVMContext(&vmContext{}) +} + +// vmContext implements types.VMContext. +type vmContext struct { + // Embed the default VM context here, + // so that we don't need to reimplement all the methods. + types.DefaultVMContext +} + +// NewPluginContext implements types.VMContext. +func (*vmContext) NewPluginContext(contextID uint32) types.PluginContext { + return &senderPluginContext{contextID: contextID} +} + +// senderPluginContext implements types.PluginContext. +type senderPluginContext struct { + // Embed the default plugin context here, + // so that we don't need to reimplement all the methods. + types.DefaultPluginContext + config string + contextID uint32 +} + +// OnPluginStart implements types.PluginContext. +func (ctx *senderPluginContext) OnPluginStart(pluginConfigurationSize int) types.OnPluginStartStatus { + // Get Plugin configuration. + config, err := proxywasm.GetPluginConfiguration() + if err != nil { + panic(fmt.Sprintf("failed to get plugin config: %v", err)) + } + ctx.config = string(config) + proxywasm.LogInfof("contextID=%d is configured for %s", ctx.contextID, ctx.config) + return types.OnPluginStartStatusOK +} + +// NewHttpContext implements types.PluginContext. +func (ctx *senderPluginContext) NewHttpContext(contextID uint32) types.HttpContext { + // If this PluginContext is not configured for Http, then return nil. + if ctx.config != "http" { + return nil + } + + // Resolve queues. + requestHeadersQueueID, err := proxywasm.ResolveSharedQueue(receiverVMID, "http_request_headers") + if err != nil { + proxywasm.LogCriticalf("error resolving queue id: %v", err) + } + + responseHeadersQueueID, err := proxywasm.ResolveSharedQueue(receiverVMID, "http_response_headers") + if err != nil { + proxywasm.LogCriticalf("error resolving queue id: %v", err) + } + + // Pass the resolved queueIDs to http contexts so they can enqueue. + return &senderHttpContext{ + requestHeadersQueueID: requestHeadersQueueID, + responseHeadersQueueID: responseHeadersQueueID, + contextID: contextID, + } +} + +// senderHttpContext implements types.HttpContext. +type senderHttpContext struct { + // Embed the default http context here, + // so that we don't need to reimplement all the methods. + types.DefaultHttpContext + contextID, requestHeadersQueueID, responseHeadersQueueID uint32 +} + +// OnHttpRequestHeaders implements types.HttpContext. +func (ctx *senderHttpContext) OnHttpRequestHeaders(int, bool) types.Action { + headers, err := proxywasm.GetHttpRequestHeaders() + if err != nil { + proxywasm.LogCriticalf("error getting request headers: %v", err) + } + for _, h := range headers { + msg := fmt.Sprintf("{\"key\": \"%s\",\"value\": \"%s\"}", h[0], h[1]) + if err := proxywasm.EnqueueSharedQueue(ctx.requestHeadersQueueID, []byte(msg)); err != nil { + proxywasm.LogCriticalf("error queueing: %v", err) + } else { + proxywasm.LogInfof("enqueued data: %s", msg) + } + } + return types.ActionContinue +} + +// OnHttpResponseHeaders implements types.HttpContext. +func (ctx *senderHttpContext) OnHttpResponseHeaders(int, bool) types.Action { + headers, err := proxywasm.GetHttpResponseHeaders() + if err != nil { + proxywasm.LogCriticalf("error getting response headers: %v", err) + } + for _, h := range headers { + msg := fmt.Sprintf("{\"key\": \"%s\",\"value\": \"%s\"}", h[0], h[1]) + if err := proxywasm.EnqueueSharedQueue(ctx.responseHeadersQueueID, []byte(msg)); err != nil { + proxywasm.LogCriticalf("error queueing: %v", err) + } else { + proxywasm.LogInfof("(contextID=%d) enqueued data: %s", ctx.contextID, msg) + } + } + return types.ActionContinue +} + +// NewTcpContext implements types.PluginContext. +func (ctx *senderPluginContext) NewTcpContext(contextID uint32) types.TcpContext { + // If this PluginContext is not configured for Tcp, then return nil. + if ctx.config != "tcp" { + return nil + } + + // Resolve queue. + queueID, err := proxywasm.ResolveSharedQueue(receiverVMID, "tcp_data_hashes") + if err != nil { + proxywasm.LogCriticalf("error resolving queue id: %v", err) + } + + // Pass the resolved queueID to tcp contexts so they can enqueue. + return &senderTcpContext{ + tcpHashesQueueID: queueID, + contextID: contextID, + } +} + +// senderTcpContext implements types.TcpContext. +type senderTcpContext struct { + types.DefaultTcpContext + // Embed the default http context here, + // so that we don't need to reimplement all the methods. + tcpHashesQueueID uint32 + contextID uint32 +} + +// OnUpstreamData implements types.TcpContext. +func (ctx *senderTcpContext) OnUpstreamData(dataSize int, endOfStream bool) types.Action { + if dataSize == 0 { + return types.ActionContinue + } + + // Calculate the hash of the data frame. + data, err := proxywasm.GetUpstreamData(0, dataSize) + if err != nil && err != types.ErrorStatusNotFound { + proxywasm.LogCritical(err.Error()) + } + s := fnv.New128a() + _, _ = s.Write(data) + var buf []byte + buf = s.Sum(buf) + hash := hex.EncodeToString(buf) + + // Enqueue the hashed data frame. + if err := proxywasm.EnqueueSharedQueue(ctx.tcpHashesQueueID, []byte(hash)); err != nil { + proxywasm.LogCriticalf("error queueing: %v", err) + } else { + proxywasm.LogInfof("(contextID=%d) enqueued data: %s", ctx.contextID, hash) + } + return types.ActionContinue +} diff --git a/src/test/go-examples/shared_queue/sender/main.wasm b/src/test/go-examples/shared_queue/sender/main.wasm new file mode 100644 index 0000000..a74790b Binary files /dev/null and b/src/test/go-examples/shared_queue/sender/main.wasm differ diff --git a/src/test/java/io/roastedroot/proxywasm/MockHandler.java b/src/test/java/io/roastedroot/proxywasm/MockHandler.java index a63e4c8..7fa58df 100644 --- a/src/test/java/io/roastedroot/proxywasm/MockHandler.java +++ b/src/test/java/io/roastedroot/proxywasm/MockHandler.java @@ -5,6 +5,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import io.roastedroot.proxywasm.v1.Action; +import io.roastedroot.proxywasm.v1.ChainedHandler; import io.roastedroot.proxywasm.v1.Handler; import io.roastedroot.proxywasm.v1.Helpers; import io.roastedroot.proxywasm.v1.LogLevel; @@ -20,7 +21,9 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -public class MockHandler implements Handler { +public class MockHandler extends ChainedHandler { + + private Handler next; public static class HttpResponse { @@ -64,6 +67,19 @@ public HttpResponse( static final boolean DEBUG = "true".equals(System.getenv("DEBUG")); + MockHandler() { + this(new Handler() {}); + } + + MockHandler(Handler next) { + this.next = next; + } + + @Override + protected Handler next() { + return next; + } + @Override public void log(LogLevel level, String message) throws WasmException { if (DEBUG) { @@ -435,31 +451,4 @@ public WasmResult setAction(StreamType streamType, Action action) { public Action getAction() { return action; } - - private final HashMap sharedData = new HashMap<>(); - - @Override - public SharedData getSharedData(String key) throws WasmException { - return sharedData.get(key); - } - - @Override - public WasmResult setSharedData(String key, byte[] value, int cas) { - SharedData prev = sharedData.get(key); - if (prev == null) { - if (cas == 0) { - sharedData.put(key, new SharedData(value, 0)); - return WasmResult.OK; - } else { - return WasmResult.CAS_MISMATCH; - } - } else { - if (cas == 0 || prev.cas == cas) { - sharedData.put(key, new SharedData(value, prev.cas + 1)); - return WasmResult.OK; - } else { - return WasmResult.CAS_MISMATCH; - } - } - } } diff --git a/src/test/java/io/roastedroot/proxywasm/MockSharedHandler.java b/src/test/java/io/roastedroot/proxywasm/MockSharedHandler.java new file mode 100644 index 0000000..0ea2206 --- /dev/null +++ b/src/test/java/io/roastedroot/proxywasm/MockSharedHandler.java @@ -0,0 +1,104 @@ +package io.roastedroot.proxywasm; + +import io.roastedroot.proxywasm.v1.Handler; +import io.roastedroot.proxywasm.v1.QueueName; +import io.roastedroot.proxywasm.v1.WasmException; +import io.roastedroot.proxywasm.v1.WasmResult; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.concurrent.atomic.AtomicInteger; + +public class MockSharedHandler implements Handler { + + private final HashMap sharedData = new HashMap<>(); + + @Override + public SharedData getSharedData(String key) throws WasmException { + return sharedData.get(key); + } + + @Override + public WasmResult setSharedData(String key, byte[] value, int cas) { + SharedData prev = sharedData.get(key); + if (prev == null) { + if (cas == 0) { + sharedData.put(key, new SharedData(value, 0)); + return WasmResult.OK; + } else { + return WasmResult.CAS_MISMATCH; + } + } else { + if (cas == 0 || prev.cas == cas) { + sharedData.put(key, new SharedData(value, prev.cas + 1)); + return WasmResult.OK; + } else { + return WasmResult.CAS_MISMATCH; + } + } + } + + public static class SharedQueue { + public final QueueName queueName; + public final LinkedList data = new LinkedList<>(); + public final int id; + + public SharedQueue(QueueName queueName, int id) { + this.queueName = queueName; + this.id = id; + } + } + + private final AtomicInteger lastSharedQueueId = new AtomicInteger(0); + private final HashMap sharedQueues = new HashMap<>(); + + public SharedQueue getSharedQueue(int queueId) { + return sharedQueues.get(queueId); + } + + @Override + public WasmResult enqueueSharedQueue(int queueId, byte[] value) { + SharedQueue queue = sharedQueues.get(queueId); + if (queue == null) { + return WasmResult.NOT_FOUND; + } + queue.data.add(value); + return WasmResult.OK; + } + + @Override + public byte[] dequeueSharedQueue(int queueId) throws WasmException { + SharedQueue queue = sharedQueues.get(queueId); + if (queue == null) { + throw new WasmException(WasmResult.NOT_FOUND); + } + return queue.data.poll(); + } + + @Override + public int resolveSharedQueue(QueueName queueName) throws WasmException { + var existing = + sharedQueues.values().stream() + .filter(x -> x.queueName.equals(queueName)) + .findFirst(); + if (existing.isPresent()) { + return existing.get().id; + } else { + throw new WasmException(WasmResult.NOT_FOUND); + } + } + + @Override + public int registerSharedQueue(QueueName queueName) throws WasmException { + var existing = + sharedQueues.values().stream() + .filter(x -> x.queueName.equals(queueName)) + .findFirst(); + if (existing.isPresent()) { + return existing.get().id; + } else { + int id = lastSharedQueueId.incrementAndGet(); + sharedQueues.put(id, new SharedQueue(queueName, id)); + return id; + } + } +} diff --git a/src/test/java/io/roastedroot/proxywasm/SharedDataTest.java b/src/test/java/io/roastedroot/proxywasm/SharedDataTest.java index b8f9f38..91a231b 100644 --- a/src/test/java/io/roastedroot/proxywasm/SharedDataTest.java +++ b/src/test/java/io/roastedroot/proxywasm/SharedDataTest.java @@ -16,8 +16,8 @@ public class SharedDataTest { @Test public void testSetEffectiveContext() throws StartException { - - var handler = new MockHandler(); + var sharedData = new MockSharedHandler(); + var handler = new MockHandler(sharedData); // Load the WASM module var module = Parser.parse(Path.of("./src/test/go-examples/shared_data/main.wasm")); diff --git a/src/test/java/io/roastedroot/proxywasm/SharedQueueTest.java b/src/test/java/io/roastedroot/proxywasm/SharedQueueTest.java new file mode 100644 index 0000000..c096237 --- /dev/null +++ b/src/test/java/io/roastedroot/proxywasm/SharedQueueTest.java @@ -0,0 +1,136 @@ +package io.roastedroot.proxywasm; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import com.dylibso.chicory.wasm.Parser; +import io.roastedroot.proxywasm.v1.Action; +import io.roastedroot.proxywasm.v1.ProxyWasm; +import io.roastedroot.proxywasm.v1.QueueName; +import io.roastedroot.proxywasm.v1.StartException; +import io.roastedroot.proxywasm.v1.WasmException; +import java.io.Closeable; +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Map; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +/** + * Test case to verify src/test/go-examples/shared_queue example. + */ +public class SharedQueueTest { + + ArrayList closeList = new ArrayList<>(); + + @AfterEach + void tearDown() { + Collections.reverse(closeList); + for (Closeable closeable : closeList) { + try { + closeable.close(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } + + T deferClose(T x) { + closeList.add(x); + return x; + } + + @Test + public void testOnPluginStart() throws StartException, WasmException { + // Load the WASM module + var receiverModule = + Parser.parse(Path.of("./src/test/go-examples/shared_queue/receiver/main.wasm")); + var sharedData = new MockSharedHandler(); + + var receiverVmId = "receiver"; + + // Create and configure the http_request_headers receiver instance + var receiverHandler1 = new MockHandler(sharedData); + var receiverHost1 = + deferClose( + ProxyWasm.builder() + .withPluginHandler(receiverHandler1) + .withProperties(Map.of("vm_id", receiverVmId)) + .withPluginConfig("http_request_headers") + .build(receiverModule)); + + var requestHeadersQueueId = + sharedData.resolveSharedQueue(new QueueName(receiverVmId, "http_request_headers")); + receiverHandler1.assertLogsContain( + String.format( + "queue \"%s\" registered as queueID=%d by contextID=%d", + "http_request_headers", requestHeadersQueueId, receiverHost1.contextId())); + var requestHeadersQueue = sharedData.getSharedQueue(requestHeadersQueueId); + assertNotNull(requestHeadersQueue); + + // Create and configure the http_response_headers receiver instance + var receiverHandler2 = new MockHandler(sharedData); + var receiverHost2 = + deferClose( + ProxyWasm.builder() + .withPluginHandler(receiverHandler2) + .withProperties(Map.of("vm_id", receiverVmId)) + .withPluginConfig("http_response_headers") + .build(receiverModule)); + + var responseHeadersQueueId = + sharedData.resolveSharedQueue(new QueueName(receiverVmId, "http_response_headers")); + receiverHandler2.assertLogsContain( + String.format( + "queue \"%s\" registered as queueID=%d by contextID=%d", + "http_response_headers", + responseHeadersQueueId, + receiverHost2.contextId())); + var responseHeadersQueue = sharedData.getSharedQueue(responseHeadersQueueId); + assertNotNull(responseHeadersQueue); + + // Load the WASM module + var senderModule = + Parser.parse(Path.of("./src/test/go-examples/shared_queue/sender/main.wasm")); + + // Create and configure the sender instance + var senderHandler = new MockHandler(sharedData); + var senderVmId = "sender"; + var senderHost = + deferClose( + ProxyWasm.builder() + .withPluginHandler(senderHandler) + .withProperties(Map.of("vm_id", senderVmId)) + .withPluginConfig("http") + .build(senderModule)); + senderHandler.assertLogsContain( + String.format("contextID=%d is configured for %s", senderHost.contextId(), "http")); + + var senderContext = deferClose(senderHost.createHttpContext(senderHandler)); + + // queue is empty + assertEquals(0, requestHeadersQueue.data.size()); + + senderHandler.setHttpRequestHeaders(Map.of("hello", "world")); + Action action = senderContext.callOnRequestHeaders(false); + assertEquals(Action.CONTINUE, action); + String queuedMessage = "{\"key\": \"hello\",\"value\": \"world\"}"; + senderHandler.assertLogsContain(String.format("enqueued data: %s", queuedMessage)); + + // queue now has 1 item + assertEquals(1, requestHeadersQueue.data.size()); + + // let the receiver know that the queue is ready + receiverHost1.sendOnQueueReady(requestHeadersQueueId); + + receiverHandler1.assertLogsContain( + String.format( + "(contextID=%d) dequeued data from %s(queueID=%d): %s", + receiverHost1.contextId(), + "http_request_headers", + requestHeadersQueueId, + queuedMessage)); + } +}