diff --git a/src/main/java/io/roastedroot/proxywasm/impl/Imports.java b/src/main/java/io/roastedroot/proxywasm/impl/Imports.java index 53bc66d..7bd28b1 100644 --- a/src/main/java/io/roastedroot/proxywasm/impl/Imports.java +++ b/src/main/java/io/roastedroot/proxywasm/impl/Imports.java @@ -8,6 +8,7 @@ import io.roastedroot.proxywasm.v1.Handler; import io.roastedroot.proxywasm.v1.LogLevel; import io.roastedroot.proxywasm.v1.MapType; +import io.roastedroot.proxywasm.v1.StreamType; import io.roastedroot.proxywasm.v1.WasmException; import io.roastedroot.proxywasm.v1.WasmResult; import java.nio.ByteBuffer; @@ -779,4 +780,24 @@ int proxyGetCurrentTimeNanoseconds(int returnTime) { return e.result().getValue(); } } + + @WasmExport + int proxyContinueStream(int arg) { + var streamType = StreamType.fromInt(arg); + if (streamType == null) { + return WasmResult.BAD_ARGUMENT.getValue(); + } + switch (streamType) { + case REQUEST: + return handler.continueRequest().getValue(); + case RESPONSE: + return handler.continueResponse().getValue(); + case DOWNSTREAM: + return handler.continueDownstream().getValue(); + case UPSTREAM: + return handler.continueUpstream().getValue(); + } + // should never reach here + return WasmResult.INTERNAL_FAILURE.getValue(); + } } diff --git a/src/main/java/io/roastedroot/proxywasm/v1/ChainedHandler.java b/src/main/java/io/roastedroot/proxywasm/v1/ChainedHandler.java index 9da3227..ba39ba6 100644 --- a/src/main/java/io/roastedroot/proxywasm/v1/ChainedHandler.java +++ b/src/main/java/io/roastedroot/proxywasm/v1/ChainedHandler.java @@ -233,4 +233,24 @@ public WasmResult setHttpResponseBody(byte[] body) { public WasmResult setHttpRequestBody(byte[] body) { return next().setHttpRequestBody(body); } + + @Override + public WasmResult continueRequest() { + return next().continueRequest(); + } + + @Override + public WasmResult continueResponse() { + return next().continueResponse(); + } + + @Override + public WasmResult continueDownstream() { + return next().continueDownstream(); + } + + @Override + public WasmResult continueUpstream() { + return next().continueUpstream(); + } } diff --git a/src/main/java/io/roastedroot/proxywasm/v1/Handler.java b/src/main/java/io/roastedroot/proxywasm/v1/Handler.java index efbf32c..09cbb09 100644 --- a/src/main/java/io/roastedroot/proxywasm/v1/Handler.java +++ b/src/main/java/io/roastedroot/proxywasm/v1/Handler.java @@ -369,4 +369,20 @@ default WasmResult setGrpcReceiveInitialMetaData(Map metadata) { default WasmResult setGrpcReceiveTrailerMetaData(Map metadata) { return WasmResult.UNIMPLEMENTED; } + + default WasmResult continueRequest() { + return WasmResult.UNIMPLEMENTED; + } + + default WasmResult continueResponse() { + return WasmResult.UNIMPLEMENTED; + } + + default WasmResult continueDownstream() { + return WasmResult.UNIMPLEMENTED; + } + + default WasmResult continueUpstream() { + return WasmResult.UNIMPLEMENTED; + } } diff --git a/src/main/java/io/roastedroot/proxywasm/v1/StreamType.java b/src/main/java/io/roastedroot/proxywasm/v1/StreamType.java new file mode 100644 index 0000000..b11d9b2 --- /dev/null +++ b/src/main/java/io/roastedroot/proxywasm/v1/StreamType.java @@ -0,0 +1,47 @@ +package io.roastedroot.proxywasm.v1; + +/** + * Represents the type of map in proxy WASM. + * Converted from Go's MapType type. + */ +public enum StreamType { + REQUEST(0), + RESPONSE(1), + DOWNSTREAM(2), + UPSTREAM(3); + + private final int value; + + /** + * Constructor for MapType enum. + * + * @param value The integer value of the map type + */ + StreamType(int value) { + this.value = value; + } + + /** + * Get the integer value of this map type. + * + * @return The integer value + */ + public int getValue() { + return value; + } + + /** + * Convert an integer value to a MapType. + * + * @param value The integer value to convert + * @return The corresponding MapType or null if the value doesn't match any MapType + */ + public static StreamType fromInt(int value) { + for (StreamType type : values()) { + if (type.value == value) { + return type; + } + } + return null; + } +} diff --git a/src/test/go-examples/helloworld/main.wasm b/src/test/go-examples/helloworld/main.wasm index b2715b2..826e2fb 100644 Binary files a/src/test/go-examples/helloworld/main.wasm and b/src/test/go-examples/helloworld/main.wasm differ diff --git a/src/test/go-examples/http_body/main.wasm b/src/test/go-examples/http_body/main.wasm index 960439e..3b0be56 100644 Binary files a/src/test/go-examples/http_body/main.wasm and b/src/test/go-examples/http_body/main.wasm differ diff --git a/src/test/go-examples/http_body_chunk/README.md b/src/test/go-examples/http_body_chunk/README.md new file mode 100644 index 0000000..b854e53 --- /dev/null +++ b/src/test/go-examples/http_body_chunk/README.md @@ -0,0 +1,4 @@ +## Attribution + +This example originally came from: +https://github.com/proxy-wasm/proxy-wasm-go-sdk/tree/main/examples/http_body_chunk diff --git a/src/test/go-examples/http_body_chunk/go.mod b/src/test/go-examples/http_body_chunk/go.mod new file mode 100644 index 0000000..0def967 --- /dev/null +++ b/src/test/go-examples/http_body_chunk/go.mod @@ -0,0 +1,5 @@ +module github.com/proxy-wasm/proxy-wasm-go-sdk/examples/http_body + +go 1.24 + +require github.com/proxy-wasm/proxy-wasm-go-sdk v0.0.0-20250212164326-ab4161dcf924 diff --git a/src/test/go-examples/http_body_chunk/go.sum b/src/test/go-examples/http_body_chunk/go.sum new file mode 100644 index 0000000..3ddb896 --- /dev/null +++ b/src/test/go-examples/http_body_chunk/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/proxy-wasm/proxy-wasm-go-sdk v0.0.0-20250212164326-ab4161dcf924 h1:wTcK6gcyTKJMeDka69AMjZYvisdI8CBXzTEfZ+2pOxI= +github.com/proxy-wasm/proxy-wasm-go-sdk v0.0.0-20250212164326-ab4161dcf924/go.mod h1:9mBRvh8I6Td6sg3CwEY+zGFE4DKaIoieCaca1kQnDBE= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/src/test/go-examples/http_body_chunk/main.go b/src/test/go-examples/http_body_chunk/main.go new file mode 100644 index 0000000..26e5023 --- /dev/null +++ b/src/test/go-examples/http_body_chunk/main.go @@ -0,0 +1,112 @@ +// Copyright 2020-2024 Tetrate +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "fmt" + "strings" + + "github.com/proxy-wasm/proxy-wasm-go-sdk/proxywasm" + "github.com/proxy-wasm/proxy-wasm-go-sdk/proxywasm/types" +) + +func main() {} +func init() { + proxywasm.SetVMContext(&vmContext{}) +} + +// vmContext implements types.VMContext. +type vmContext struct { + // Embed the default VM context here, + // so that we don't need to reimplement all the methods. + types.DefaultVMContext +} + +// NewPluginContext implements types.VMContext. +func (*vmContext) NewPluginContext(contextID uint32) types.PluginContext { + return &pluginContext{} +} + +// pluginContext implements types.PluginContext. +type pluginContext struct { + // Embed the default plugin context here, + // so that we don't need to reimplement all the methods. + types.DefaultPluginContext +} + +// NewHttpContext implements types.PluginContext. +func (ctx *pluginContext) NewHttpContext(contextID uint32) types.HttpContext { + return &setBodyContext{} +} + +// OnPluginStart implements types.PluginContext. +func (ctx *pluginContext) OnPluginStart(pluginConfigurationSize int) types.OnPluginStartStatus { + return types.OnPluginStartStatusOK +} + +// setBodyContext implements types.HttpContext. +type setBodyContext struct { + // Embed the default root http context here, + // so that we don't need to reimplement all the methods. + types.DefaultHttpContext + totalRequestBodyReadSize int + receivedChunks int +} + +// OnHttpRequestBody implements types.HttpContext. +func (ctx *setBodyContext) OnHttpRequestBody(bodySize int, endOfStream bool) types.Action { + proxywasm.LogInfof("OnHttpRequestBody called. BodySize: %d, totalRequestBodyReadSize: %d, endOfStream: %v", bodySize, ctx.totalRequestBodyReadSize, endOfStream) + + // If some data has been received, we read it. + // Reading the body chunk by chunk, bodySize is the size of the current chunk, not the total size of the body. + chunkSize := bodySize - ctx.totalRequestBodyReadSize + if chunkSize > 0 { + ctx.receivedChunks++ + chunk, err := proxywasm.GetHttpRequestBody(ctx.totalRequestBodyReadSize, chunkSize) + if err != nil { + proxywasm.LogCriticalf("failed to get request body: %v", err) + return types.ActionContinue + } + proxywasm.LogInfof("read chunk size: %d", len(chunk)) + if len(chunk) != chunkSize { + proxywasm.LogErrorf("read data does not match the expected size: %d != %d", len(chunk), chunkSize) + } + ctx.totalRequestBodyReadSize += len(chunk) + if strings.Contains(string(chunk), "pattern") { + patternFound := fmt.Sprintf("pattern found in chunk: %d", ctx.receivedChunks) + proxywasm.LogInfo(patternFound) + if err := proxywasm.SendHttpResponse(403, [][2]string{ + {"powered-by", "proxy-wasm-go-sdk"}, + }, []byte(patternFound), -1); err != nil { + proxywasm.LogCriticalf("failed to send local response: %v", err) + _ = proxywasm.ResumeHttpRequest() + } else { + proxywasm.LogInfo("local 403 response sent") + } + return types.ActionPause + } + } + + if !endOfStream { + // Wait until we see the entire body before sending the request upstream. + return types.ActionPause + } + // When endOfStream is true, we have received the entire body. We expect the total size is equal to the sum of the sizes of the chunks. + if ctx.totalRequestBodyReadSize != bodySize { + proxywasm.LogErrorf("read data does not match the expected total size: %d != %d", ctx.totalRequestBodyReadSize, bodySize) + } + proxywasm.LogInfof("pattern not found") + return types.ActionContinue +} diff --git a/src/test/go-examples/http_body_chunk/main.wasm b/src/test/go-examples/http_body_chunk/main.wasm new file mode 100644 index 0000000..c965d6e Binary files /dev/null and b/src/test/go-examples/http_body_chunk/main.wasm differ diff --git a/src/test/go-examples/properties/README.md b/src/test/go-examples/properties/README.md index cf4ce0d..809ee28 100644 --- a/src/test/go-examples/properties/README.md +++ b/src/test/go-examples/properties/README.md @@ -1,4 +1,4 @@ ## Attribution This example originally came from: -https://github.com/proxy-wasm/proxy-wasm-go-sdk/tree/main/examples/proxy-wasm-go-examples/properties +https://github.com/proxy-wasm/proxy-wasm-go-sdk/tree/main/examples/properties diff --git a/src/test/go-examples/vm_plugin_configuration/main.wasm b/src/test/go-examples/vm_plugin_configuration/main.wasm index c1476a4..bb43fd7 100644 Binary files a/src/test/go-examples/vm_plugin_configuration/main.wasm and b/src/test/go-examples/vm_plugin_configuration/main.wasm differ diff --git a/src/test/java/io/roastedroot/proxywasm/EchoHttpBodyTest.java b/src/test/java/io/roastedroot/proxywasm/EchoHttpBodyTest.java index d3d0520..a9ea1ac 100644 --- a/src/test/java/io/roastedroot/proxywasm/EchoHttpBodyTest.java +++ b/src/test/java/io/roastedroot/proxywasm/EchoHttpBodyTest.java @@ -64,7 +64,7 @@ public void echoRequest() throws StartException { // Must be paused. assertEquals(Action.PAUSE, action); - var response = handler.getSenthttpResponse(); + var response = handler.getSentHttpResponse(); assertNotNull(response); assertEquals(200, response.statusCode); assertEquals("frame1...frame2...frame3...", string(response.body)); diff --git a/src/test/java/io/roastedroot/proxywasm/HttpBodyChunkTest.java b/src/test/java/io/roastedroot/proxywasm/HttpBodyChunkTest.java new file mode 100644 index 0000000..e6b5a2b --- /dev/null +++ b/src/test/java/io/roastedroot/proxywasm/HttpBodyChunkTest.java @@ -0,0 +1,135 @@ +package io.roastedroot.proxywasm; + +import static io.roastedroot.proxywasm.v1.Helpers.bytes; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import com.dylibso.chicory.wasm.Parser; +import io.roastedroot.proxywasm.v1.Action; +import io.roastedroot.proxywasm.v1.HttpContext; +import io.roastedroot.proxywasm.v1.ProxyWasm; +import io.roastedroot.proxywasm.v1.StartException; +import java.nio.file.Path; +import java.util.Map; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class HttpBodyChunkTest { + + private MockHandler handler; + private ProxyWasm proxyWasm; + private HttpContext host; + + @BeforeEach + void setUp() throws StartException { + this.handler = new MockHandler(); + ProxyWasm.Builder builder = ProxyWasm.builder(); + var module = Parser.parse(Path.of("./src/test/go-examples/http_body_chunk/main.wasm")); + this.proxyWasm = builder.build(module); + this.host = proxyWasm.createHttpContext(handler); + } + + @AfterEach + void tearDown() { + host.close(); + proxyWasm.close(); + } + + @Test + public void pauseUntilEOS() { + var action = host.callOnRequestBody(false); + assertEquals(Action.PAUSE, action); + } + + @Test + public void patternFound() { + var body = bytes("This is a payload with the pattern word."); + // Call OnRequestHeaders. + handler.setHttpRequestHeaders(Map.of("content-length", String.format("%d", body.length))); + var action = host.callOnRequestHeaders(false); + + // Must be continued. + assertEquals(Action.CONTINUE, action); + + // Call OnRequestBody. + handler.setHttpRequestBody(body); + action = host.callOnRequestBody(true); + + // Must be paused + assertEquals(Action.PAUSE, action); + + handler.assertLogsContain("pattern found in chunk: 1"); + + // Check the local response. + var response = handler.getSentHttpResponse(); + assertNotNull(response); + assertEquals(403, response.statusCode); + } + + @Test + public void patternFoundInMultipleChunks() { + + var chunks = + new byte[][] { + bytes("chunk1..."), + bytes("chunk2..."), + bytes("chunk3..."), + bytes("chunk4 with pattern ...") + }; + + var chunksSize = 0; + for (byte[] chunk : chunks) { + chunksSize += chunk.length; + } + + // Call OnRequestHeaders. + handler.setHttpRequestHeaders(Map.of("content-length", String.format("%d", chunksSize))); + var action = host.callOnRequestHeaders(false); + + // Must be continued. + assertEquals(Action.CONTINUE, action); + + // Call OnRequestBody. + for (byte[] chunk : chunks) { + handler.appendHttpRequestBody(chunk); + action = host.callOnRequestBody(false); + // Must be paused. + assertEquals(Action.PAUSE, action); + } + + handler.assertLogsContain("pattern found in chunk: 4"); + handler.assertLogsDoNotContain("read data does not match"); + + // Check the local response. + var response = handler.getSentHttpResponse(); + assertNotNull(response); + assertEquals(403, response.statusCode); + } + + @Test + public void patternNotFound() { + var body = bytes("This is a generic payload."); + // Call OnRequestHeaders. + handler.setHttpRequestHeaders(Map.of("content-length", String.format("%d", body.length))); + var action = host.callOnRequestHeaders(false); + + // Must be continued. + assertEquals(Action.CONTINUE, action); + + // Call OnRequestBody. + handler.setHttpRequestBody(body); + action = host.callOnRequestBody(false); + + // Must be paused + assertEquals(Action.PAUSE, action); + + // Call OnRequestBody. + action = host.callOnRequestBody(true); + + // Must be paused + assertEquals(Action.CONTINUE, action); + + handler.assertLogsContain("pattern not found"); + } +} diff --git a/src/test/java/io/roastedroot/proxywasm/HttpBodyTest.java b/src/test/java/io/roastedroot/proxywasm/HttpBodyTest.java index ed4f066..56c598f 100644 --- a/src/test/java/io/roastedroot/proxywasm/HttpBodyTest.java +++ b/src/test/java/io/roastedroot/proxywasm/HttpBodyTest.java @@ -64,7 +64,7 @@ public void testOnHttpRequestHeaders400Response() throws StartException { assertEquals(Action.PAUSE, action); // Check the local response. - var response = handler.getSenthttpResponse(); + var response = handler.getSentHttpResponse(); assertNotNull(response); assertEquals(400, response.statusCode); assertEquals("content must be provided", string(response.body)); diff --git a/src/test/java/io/roastedroot/proxywasm/MockHandler.java b/src/test/java/io/roastedroot/proxywasm/MockHandler.java index 0e174e4..96a7ad6 100644 --- a/src/test/java/io/roastedroot/proxywasm/MockHandler.java +++ b/src/test/java/io/roastedroot/proxywasm/MockHandler.java @@ -1,6 +1,8 @@ package io.roastedroot.proxywasm; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; import io.roastedroot.proxywasm.v1.Handler; import io.roastedroot.proxywasm.v1.Helpers; @@ -82,6 +84,17 @@ public void assertSortedLogsEqual(String... messages) { loggedMessages().stream().sorted().collect(Collectors.toList())); } + public void assertLogsContain(String message) { + assertTrue( + loggedMessages().contains(message), "logged messages does not contain: " + message); + } + + public void assertLogsDoNotContain(String message) { + for (String log : loggedMessages()) { + assertFalse(log.contains(message), "logged messages contains: " + message); + } + } + @Override public Map getHttpRequestHeaders() { return httpRequestHeaders; @@ -272,7 +285,7 @@ public WasmResult sendHttpResponse( return WasmResult.OK; } - public HttpResponse getSenthttpResponse() { + public HttpResponse getSentHttpResponse() { return senthttpResponse; } } diff --git a/src/test/java/io/roastedroot/proxywasm/PropertiesTest.java b/src/test/java/io/roastedroot/proxywasm/PropertiesTest.java index 7b6931a..e2e8da0 100644 --- a/src/test/java/io/roastedroot/proxywasm/PropertiesTest.java +++ b/src/test/java/io/roastedroot/proxywasm/PropertiesTest.java @@ -83,7 +83,7 @@ public void userIsUnauthenticated() { assertEquals(Action.PAUSE, action); } - var response = handler.getSenthttpResponse(); + var response = handler.getSentHttpResponse(); assertNotNull(response); assertEquals(401, response.statusCode); assertArrayEquals(new byte[0], response.body);