diff --git a/src/main/java/io/roastedroot/proxywasm/ABI.java b/src/main/java/io/roastedroot/proxywasm/ABI.java index e8b48b5..cd8a762 100644 --- a/src/main/java/io/roastedroot/proxywasm/ABI.java +++ b/src/main/java/io/roastedroot/proxywasm/ABI.java @@ -12,7 +12,6 @@ import com.dylibso.chicory.wasm.InvalidException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; -import java.util.HashMap; import java.util.Map; @HostModule("env") @@ -675,19 +674,18 @@ int proxyGetHeaderMapSize(int mapType, int returnSize) { try { // Get the header map based on the map type - Map header = getMap(mapType); + ProxyMap header = getMap(mapType); if (header == null) { return WasmResult.BAD_ARGUMENT.getValue(); } // to clone the headers so that they don't change on while we process them in the loop - final Map cloneMap = new HashMap<>(); + var cloneMap = new ArrayProxyMap(header); int totalBytesLen = U32_LEN; // Start with space for the count - for (Map.Entry entry : header.entrySet()) { + for (Map.Entry entry : cloneMap.entries()) { String key = entry.getKey(); String value = entry.getValue(); - cloneMap.put(key, value); totalBytesLen += U32_LEN + U32_LEN; // keyLen + valueLen totalBytesLen += key.length() + 1 + value.length() + 1; // key + \0 + value + \0 } @@ -717,19 +715,18 @@ int proxyGetHeaderMapPairs(int mapType, int returnDataPtr, int returnDataSize) { try { // Get the header map based on the map type - Map header = getMap(mapType); + ProxyMap header = getMap(mapType); if (header == null) { return WasmResult.NOT_FOUND.getValue(); } // to clone the headers so that they don't change on while we process them in the loop - final Map cloneMap = new HashMap<>(); + var cloneMap = new ArrayProxyMap(header); int totalBytesLen = U32_LEN; // Start with space for the count - for (Map.Entry entry : header.entrySet()) { + for (Map.Entry entry : cloneMap.entries()) { String key = entry.getKey(); String value = entry.getValue(); - cloneMap.put(key, value); totalBytesLen += U32_LEN + U32_LEN; // keyLen + valueLen totalBytesLen += key.length() + 1 + value.length() + 1; // key + \0 + value + \0 } @@ -745,7 +742,7 @@ int proxyGetHeaderMapPairs(int mapType, int returnDataPtr, int returnDataSize) { int dataPtr = lenPtr + ((U32_LEN + U32_LEN) * cloneMap.size()); // Write each key-value pair to memory - for (Map.Entry entry : cloneMap.entrySet()) { + for (Map.Entry entry : cloneMap.entries()) { String key = entry.getKey(); String value = entry.getValue(); @@ -802,14 +799,14 @@ int proxySetHeaderMapPairs(int mapType, int ptr, int size) { try { // Get the header map based on the map type - Map headerMap = getMap(mapType); + ProxyMap headerMap = getMap(mapType); if (headerMap == null) { return WasmResult.BAD_ARGUMENT.getValue(); } // Decode the map content and set each key-value pair - Map newMap = decodeMap(ptr, size); - for (Map.Entry entry : newMap.entrySet()) { + ProxyMap newMap = decodeMap(ptr, size); + for (Map.Entry entry : newMap.entries()) { headerMap.put(entry.getKey(), entry.getValue()); } @@ -837,7 +834,7 @@ int proxyGetHeaderMapValue( int mapType, int keyDataPtr, int keySize, int valueDataPtr, int valueSize) { try { // Get the header map based on the map type - Map headerMap = getMap(mapType); + ProxyMap headerMap = getMap(mapType); if (headerMap == null) { return WasmResult.BAD_ARGUMENT.getValue(); } @@ -895,7 +892,7 @@ int proxyReplaceHeaderMapValue( int mapType, int keyDataPtr, int keySize, int valueDataPtr, int valueSize) { try { // Get the header map based on the map type - Map headerMap = getMap(mapType); + ProxyMap headerMap = getMap(mapType); if (headerMap == null) { return WasmResult.BAD_ARGUMENT.getValue(); } @@ -907,7 +904,7 @@ int proxyReplaceHeaderMapValue( String value = readString(valueDataPtr, valueSize); // Replace value in map - var copy = new HashMap<>(headerMap); + var copy = new ArrayProxyMap(headerMap); copy.put(key, value); setMap(mapType, copy); @@ -933,7 +930,7 @@ int proxyReplaceHeaderMapValue( int proxyRemoveHeaderMapValue(int mapType, int keyDataPtr, int keySize) { try { // Get the header map based on the map type - Map headerMap = getMap(mapType); + ProxyMap headerMap = getMap(mapType); if (headerMap == null) { return WasmResult.NOT_FOUND.getValue(); } @@ -945,7 +942,7 @@ int proxyRemoveHeaderMapValue(int mapType, int keyDataPtr, int keySize) { } // Remove key from map - var copy = new HashMap<>(headerMap); + var copy = new ArrayProxyMap(headerMap); copy.remove(key); setMap(mapType, copy); @@ -964,7 +961,7 @@ int proxyRemoveHeaderMapValue(int mapType, int keyDataPtr, int keySize) { * @param mapType The type of map to get * @return The header map */ - private Map getMap(int mapType) { + private ProxyMap getMap(int mapType) { var knownType = MapType.fromInt(mapType); if (knownType == null) { @@ -999,7 +996,7 @@ private Map getMap(int mapType) { * @param map The header map to set * @return WasmResult indicating success or failure */ - private WasmResult setMap(int mapType, Map map) { + private WasmResult setMap(int mapType, ProxyMap map) { var knownType = MapType.fromInt(mapType); if (knownType == null) { return handler.setCustomHeaders(mapType, map); @@ -1043,9 +1040,9 @@ private WasmResult setMap(int mapType, Map map) { * @return The decoded map containing string keys and values * @throws WasmException if there is an error accessing memory */ - private HashMap decodeMap(int addr, int mem_size) throws WasmException { + private ProxyMap decodeMap(int addr, int mem_size) throws WasmException { if (mem_size < U32_LEN) { - return new HashMap<>(); + return new ArrayProxyMap(); } // Read header size (number of entries) @@ -1055,11 +1052,11 @@ private HashMap decodeMap(int addr, int mem_size) throws WasmExc // mapSize + (key1_size + value1_size) * mapSize long dataOffset = U32_LEN + (U32_LEN + U32_LEN) * mapSize; if (dataOffset >= mem_size) { - return new HashMap<>(); + return new ArrayProxyMap(); } // Create result map with initial capacity - var result = new HashMap((int) mapSize); + var result = new ArrayProxyMap((int) mapSize); // Process each entry for (int i = 0; i < mapSize; i++) { @@ -1086,7 +1083,7 @@ private HashMap decodeMap(int addr, int mem_size) throws WasmExc dataOffset += valueSize + 1; // Add to result map - result.put(key, value); + result.add(key, value); } return result; @@ -1282,8 +1279,7 @@ int proxySendLocalResponse( } // Get and decode additional headers from memory - HashMap additionalHeaders = - decodeMap(additionalHeadersMapData, additionalHeadersSize); + ProxyMap additionalHeaders = decodeMap(additionalHeadersMapData, additionalHeadersSize); // Send the response through the handler WasmResult result = diff --git a/src/main/java/io/roastedroot/proxywasm/ArrayProxyMap.java b/src/main/java/io/roastedroot/proxywasm/ArrayProxyMap.java new file mode 100644 index 0000000..a98db45 --- /dev/null +++ b/src/main/java/io/roastedroot/proxywasm/ArrayProxyMap.java @@ -0,0 +1,86 @@ +package io.roastedroot.proxywasm; + +import java.util.ArrayList; +import java.util.Map; +import java.util.Objects; + +public class ArrayProxyMap implements ProxyMap { + + final ArrayList> entries; + + public ArrayProxyMap() { + this.entries = new ArrayList<>(); + } + + public ArrayProxyMap(int mapSize) { + this.entries = new ArrayList<>(mapSize); + } + + public ArrayProxyMap(ProxyMap other) { + this(other.size()); + for (Map.Entry entry : other.entries()) { + add(entry.getKey(), entry.getValue()); + } + } + + public ArrayProxyMap(Map other) { + this(other.size()); + for (Map.Entry entry : other.entrySet()) { + add(entry.getKey(), entry.getValue()); + } + } + + @Override + public int size() { + return entries.size(); + } + + @Override + public void add(String key, String value) { + entries.add(Map.entry(key, value)); + } + + @Override + public void put(String key, String value) { + this.remove(key); + entries.add(Map.entry(key, value)); + } + + @Override + public Iterable> entries() { + return entries; + } + + @Override + public String get(String key) { + return entries.stream() + .filter(x -> x.getKey().equals(key)) + .map(Map.Entry::getValue) + .findFirst() + .orElse(null); + } + + @Override + public void remove(String key) { + entries.removeIf(x -> x.getKey().equals(key)); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrayProxyMap that = (ArrayProxyMap) o; + return Objects.equals(entries, that.entries); + } + + @Override + public int hashCode() { + return Objects.hashCode(entries); + } + + @Override + public String toString() { + return entries.toString(); + } +} diff --git a/src/main/java/io/roastedroot/proxywasm/ChainedHandler.java b/src/main/java/io/roastedroot/proxywasm/ChainedHandler.java index 55ec17b..23b13aa 100644 --- a/src/main/java/io/roastedroot/proxywasm/ChainedHandler.java +++ b/src/main/java/io/roastedroot/proxywasm/ChainedHandler.java @@ -1,8 +1,5 @@ package io.roastedroot.proxywasm; -import java.util.HashMap; -import java.util.Map; - /** * A Handler implementation that chains to another handler if it can't handle the request. */ @@ -16,92 +13,92 @@ public void log(LogLevel level, String message) throws WasmException { } @Override - public Map getHttpRequestHeaders() { + public ProxyMap getHttpRequestHeaders() { return next().getHttpRequestHeaders(); } @Override - public Map getHttpRequestTrailers() { + public ProxyMap getHttpRequestTrailers() { return next().getHttpRequestTrailers(); } @Override - public Map getHttpResponseHeaders() { + public ProxyMap getHttpResponseHeaders() { return next().getHttpResponseHeaders(); } @Override - public Map getHttpResponseTrailers() { + public ProxyMap getHttpResponseTrailers() { return next().getHttpResponseTrailers(); } @Override - public Map getHttpCallResponseHeaders() { + public ProxyMap getHttpCallResponseHeaders() { return next().getHttpCallResponseHeaders(); } @Override - public Map getHttpCallResponseTrailers() { + public ProxyMap getHttpCallResponseTrailers() { return next().getHttpCallResponseTrailers(); } @Override - public Map getGrpcReceiveInitialMetaData() { + public ProxyMap getGrpcReceiveInitialMetaData() { return next().getGrpcReceiveInitialMetaData(); } @Override - public Map getGrpcReceiveTrailerMetaData() { + public ProxyMap getGrpcReceiveTrailerMetaData() { return next().getGrpcReceiveTrailerMetaData(); } @Override - public Map getCustomHeaders(int mapType) { + public ProxyMap getCustomHeaders(int mapType) { return next().getCustomHeaders(mapType); } @Override - public WasmResult setCustomHeaders(int mapType, Map map) { + public WasmResult setCustomHeaders(int mapType, ProxyMap map) { return next().setCustomHeaders(mapType, map); } @Override - public WasmResult setHttpRequestHeaders(Map headers) { + public WasmResult setHttpRequestHeaders(ProxyMap headers) { return next().setHttpRequestHeaders(headers); } @Override - public WasmResult setHttpRequestTrailers(Map trailers) { + public WasmResult setHttpRequestTrailers(ProxyMap trailers) { return next().setHttpRequestTrailers(trailers); } @Override - public WasmResult setHttpResponseHeaders(Map headers) { + public WasmResult setHttpResponseHeaders(ProxyMap headers) { return next().setHttpResponseHeaders(headers); } @Override - public WasmResult setHttpResponseTrailers(Map trailers) { + public WasmResult setHttpResponseTrailers(ProxyMap trailers) { return next().setHttpResponseTrailers(trailers); } @Override - public WasmResult setHttpCallResponseHeaders(Map headers) { + public WasmResult setHttpCallResponseHeaders(ProxyMap headers) { return next().setHttpCallResponseHeaders(headers); } @Override - public WasmResult setHttpCallResponseTrailers(Map trailers) { + public WasmResult setHttpCallResponseTrailers(ProxyMap trailers) { return next().setHttpCallResponseTrailers(trailers); } @Override - public WasmResult setGrpcReceiveInitialMetaData(Map metadata) { + public WasmResult setGrpcReceiveInitialMetaData(ProxyMap metadata) { return next().setGrpcReceiveInitialMetaData(metadata); } @Override - public WasmResult setGrpcReceiveTrailerMetaData(Map metadata) { + public WasmResult setGrpcReceiveTrailerMetaData(ProxyMap metadata) { return next().setGrpcReceiveTrailerMetaData(metadata); } @@ -185,7 +182,7 @@ public WasmResult sendHttpResponse( int responseCode, byte[] responseCodeDetails, byte[] responseBody, - Map additionalHeaders, + ProxyMap additionalHeaders, int grpcStatus) { return next().sendHttpResponse( responseCode, @@ -246,12 +243,7 @@ public WasmResult clearRouteCache() { } @Override - public int httpCall( - String uri, - HashMap headers, - byte[] body, - HashMap trailers, - int timeout) + public int httpCall(String uri, ProxyMap headers, byte[] body, ProxyMap trailers, int timeout) throws WasmException { return next().httpCall(uri, headers, body, trailers, timeout); } @@ -259,9 +251,9 @@ public int httpCall( @Override public int dispatchHttpCall( String upstreamName, - HashMap headers, + ProxyMap headers, byte[] body, - HashMap trailers, + ProxyMap trailers, int timeoutMilliseconds) throws WasmException { return next().dispatchHttpCall(upstreamName, headers, body, trailers, timeoutMilliseconds); diff --git a/src/main/java/io/roastedroot/proxywasm/Handler.java b/src/main/java/io/roastedroot/proxywasm/Handler.java index 568024d..e4160d7 100644 --- a/src/main/java/io/roastedroot/proxywasm/Handler.java +++ b/src/main/java/io/roastedroot/proxywasm/Handler.java @@ -1,8 +1,5 @@ package io.roastedroot.proxywasm; -import java.util.HashMap; -import java.util.Map; - public interface Handler { default void log(LogLevel level, String message) throws WasmException {} @@ -12,39 +9,39 @@ default LogLevel getLogLevel() throws WasmException { } // TODO: use a better type than Map so that we can support repeated headers - default Map getHttpRequestHeaders() { + default ProxyMap getHttpRequestHeaders() { return null; } - default Map getHttpRequestTrailers() { + default ProxyMap getHttpRequestTrailers() { return null; } - default Map getHttpResponseHeaders() { + default ProxyMap getHttpResponseHeaders() { return null; } - default Map getHttpResponseTrailers() { + default ProxyMap getHttpResponseTrailers() { return null; } - default Map getHttpCallResponseHeaders() { + default ProxyMap getHttpCallResponseHeaders() { return null; } - default Map getHttpCallResponseTrailers() { + default ProxyMap getHttpCallResponseTrailers() { return null; } - default Map getGrpcReceiveInitialMetaData() { + default ProxyMap getGrpcReceiveInitialMetaData() { return null; } - default Map getGrpcReceiveTrailerMetaData() { + default ProxyMap getGrpcReceiveTrailerMetaData() { return null; } - default Map getCustomHeaders(int mapType) { + default ProxyMap getCustomHeaders(int mapType) { return null; } @@ -198,7 +195,7 @@ default WasmResult sendHttpResponse( int responseCode, byte[] responseCodeDetails, byte[] responseBody, - Map additionalHeaders, + ProxyMap additionalHeaders, int grpcStatus) { return WasmResult.UNIMPLEMENTED; } @@ -291,7 +288,7 @@ default WasmResult setCustomBuffer(int bufferType, byte[] buffer) { * @param map The header map to set * @return WasmResult indicating success or failure */ - default WasmResult setCustomHeaders(int mapType, Map map) { + default WasmResult setCustomHeaders(int mapType, ProxyMap map) { return WasmResult.UNIMPLEMENTED; } @@ -301,7 +298,7 @@ default WasmResult setCustomHeaders(int mapType, Map map) { * @param headers The headers to set * @return WasmResult indicating success or failure */ - default WasmResult setHttpRequestHeaders(Map headers) { + default WasmResult setHttpRequestHeaders(ProxyMap headers) { return WasmResult.UNIMPLEMENTED; } @@ -311,7 +308,7 @@ default WasmResult setHttpRequestHeaders(Map headers) { * @param trailers The trailers to set * @return WasmResult indicating success or failure */ - default WasmResult setHttpRequestTrailers(Map trailers) { + default WasmResult setHttpRequestTrailers(ProxyMap trailers) { return WasmResult.UNIMPLEMENTED; } @@ -321,7 +318,7 @@ default WasmResult setHttpRequestTrailers(Map trailers) { * @param headers The headers to set * @return WasmResult indicating success or failure */ - default WasmResult setHttpResponseHeaders(Map headers) { + default WasmResult setHttpResponseHeaders(ProxyMap headers) { return WasmResult.UNIMPLEMENTED; } @@ -331,7 +328,7 @@ default WasmResult setHttpResponseHeaders(Map headers) { * @param trailers The trailers to set * @return WasmResult indicating success or failure */ - default WasmResult setHttpResponseTrailers(Map trailers) { + default WasmResult setHttpResponseTrailers(ProxyMap trailers) { return WasmResult.UNIMPLEMENTED; } @@ -341,7 +338,7 @@ default WasmResult setHttpResponseTrailers(Map trailers) { * @param headers The headers to set * @return WasmResult indicating success or failure */ - default WasmResult setHttpCallResponseHeaders(Map headers) { + default WasmResult setHttpCallResponseHeaders(ProxyMap headers) { return WasmResult.UNIMPLEMENTED; } @@ -351,7 +348,7 @@ default WasmResult setHttpCallResponseHeaders(Map headers) { * @param trailers The trailers to set * @return WasmResult indicating success or failure */ - default WasmResult setHttpCallResponseTrailers(Map trailers) { + default WasmResult setHttpCallResponseTrailers(ProxyMap trailers) { return WasmResult.UNIMPLEMENTED; } @@ -361,7 +358,7 @@ default WasmResult setHttpCallResponseTrailers(Map trailers) { * @param metadata The metadata to set * @return WasmResult indicating success or failure */ - default WasmResult setGrpcReceiveInitialMetaData(Map metadata) { + default WasmResult setGrpcReceiveInitialMetaData(ProxyMap metadata) { return WasmResult.UNIMPLEMENTED; } @@ -371,7 +368,7 @@ default WasmResult setGrpcReceiveInitialMetaData(Map metadata) { * @param metadata The metadata to set * @return WasmResult indicating success or failure */ - default WasmResult setGrpcReceiveTrailerMetaData(Map metadata) { + default WasmResult setGrpcReceiveTrailerMetaData(ProxyMap metadata) { return WasmResult.UNIMPLEMENTED; } @@ -384,20 +381,16 @@ default WasmResult clearRouteCache() { } default int httpCall( - String uri, - HashMap headers, - byte[] body, - HashMap trailers, - int timeoutMilliseconds) + String uri, ProxyMap headers, byte[] body, ProxyMap trailers, int timeoutMilliseconds) throws WasmException { throw new WasmException(WasmResult.UNIMPLEMENTED); } default int dispatchHttpCall( String upstreamName, - HashMap headers, + ProxyMap headers, byte[] body, - HashMap trailers, + ProxyMap trailers, int timeoutMilliseconds) throws WasmException { throw new WasmException(WasmResult.UNIMPLEMENTED); diff --git a/src/main/java/io/roastedroot/proxywasm/Helpers.java b/src/main/java/io/roastedroot/proxywasm/Helpers.java index 7182b46..e71f4ba 100644 --- a/src/main/java/io/roastedroot/proxywasm/Helpers.java +++ b/src/main/java/io/roastedroot/proxywasm/Helpers.java @@ -37,6 +37,13 @@ public static int len(byte[] value) { return value.length; } + public static int len(ProxyMap value) { + if (value == null) { + return 0; + } + return value.size(); + } + public static int len(String value) { if (value == null) { return 0; diff --git a/src/main/java/io/roastedroot/proxywasm/ProxyMap.java b/src/main/java/io/roastedroot/proxywasm/ProxyMap.java new file mode 100644 index 0000000..34e97c5 --- /dev/null +++ b/src/main/java/io/roastedroot/proxywasm/ProxyMap.java @@ -0,0 +1,36 @@ +package io.roastedroot.proxywasm; + +import java.util.Map; + +public interface ProxyMap { + + static ProxyMap of(String... values) { + if (values.length % 2 != 0) { + throw new IllegalArgumentException("values must be even"); + } + ArrayProxyMap map = new ArrayProxyMap(values.length / 2); + for (int i = 0; i < values.length; i += 2) { + map.add(values[i], values[i + 1]); + } + return map; + } + + static ProxyMap copyOf(Map headers) { + if (headers == null) { + return null; + } + return new ArrayProxyMap(headers); + } + + int size(); + + void add(String key, String value); + + void put(String key, String value); + + Iterable> entries(); + + String get(String key); + + void remove(String key); +} diff --git a/src/main/java/io/roastedroot/proxywasm/ProxyWasm.java b/src/main/java/io/roastedroot/proxywasm/ProxyWasm.java index 3d71687..cac5208 100644 --- a/src/main/java/io/roastedroot/proxywasm/ProxyWasm.java +++ b/src/main/java/io/roastedroot/proxywasm/ProxyWasm.java @@ -34,8 +34,8 @@ public final class ProxyWasm implements Closeable { private Context activeContext; private HashMap contexts = new HashMap<>(); - private Map httpCallResponseHeaders; - private Map httpCallResponseTrailers; + private ProxyMap httpCallResponseHeaders; + private ProxyMap httpCallResponseTrailers; private byte[] httpCallResponseBody; private HashMap foreignFunctions = new HashMap<>(); @@ -120,12 +120,12 @@ public WasmResult done() { } @Override - public Map getHttpCallResponseHeaders() { + public ProxyMap getHttpCallResponseHeaders() { return httpCallResponseHeaders; } @Override - public Map getHttpCallResponseTrailers() { + public ProxyMap getHttpCallResponseTrailers() { return httpCallResponseTrailers; } @@ -220,6 +220,12 @@ public static ProxyWasm.Builder builder() { public void sendHttpCallResponse( int calloutID, Map headers, Map trailers, byte[] body) { + this.sendHttpCallResponse( + calloutID, ProxyMap.copyOf(headers), ProxyMap.copyOf(trailers), body); + } + + public void sendHttpCallResponse( + int calloutID, ProxyMap headers, ProxyMap trailers, byte[] body) { this.httpCallResponseHeaders = headers; this.httpCallResponseTrailers = trailers; diff --git a/src/test/java/io/roastedroot/proxywasm/examples/HttpAuthRandomTest.java b/src/test/java/io/roastedroot/proxywasm/examples/HttpAuthRandomTest.java index 821d780..d4253d4 100644 --- a/src/test/java/io/roastedroot/proxywasm/examples/HttpAuthRandomTest.java +++ b/src/test/java/io/roastedroot/proxywasm/examples/HttpAuthRandomTest.java @@ -8,10 +8,10 @@ import com.dylibso.chicory.wasm.Parser; import io.roastedroot.proxywasm.Action; +import io.roastedroot.proxywasm.ProxyMap; import io.roastedroot.proxywasm.ProxyWasm; import io.roastedroot.proxywasm.StartException; import java.nio.file.Path; -import java.util.Map; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; @@ -45,7 +45,7 @@ public void onHttpRequestHeaders() throws StartException { try (var context = host.createHttpContext(handler)) { // Call OnRequestHeaders. - handler.setHttpRequestHeaders(Map.of("key", "value")); + handler.setHttpRequestHeaders(ProxyMap.of("key", "value")); var action = context.callOnRequestHeaders(false); Assertions.assertEquals(Action.PAUSE, action); @@ -64,7 +64,7 @@ public void onHttpRequestHeaders() throws StartException { @Test public void onHttpCallResponse() throws StartException { var headers = - Map.of( + ProxyMap.of( "HTTP/1.1", "200 OK", "Date:", "Thu, 17 Sep 2020 02:47:07 GMT", "Content-Type", "application/json", @@ -78,7 +78,7 @@ public void onHttpCallResponse() throws StartException { try (var context = host.createHttpContext(handler)) { // Call OnRequestHeaders. - handler.setHttpRequestHeaders(Map.of()); + handler.setHttpRequestHeaders(ProxyMap.of()); var action = context.callOnRequestHeaders(false); assertEquals(Action.PAUSE, action); @@ -101,7 +101,7 @@ public void onHttpCallResponse() throws StartException { try (var context = host.createHttpContext(handler)) { // Call OnRequestHeaders. - handler.setHttpRequestHeaders(Map.of()); + handler.setHttpRequestHeaders(ProxyMap.of()); var action = context.callOnRequestHeaders(false); assertEquals(Action.PAUSE, action); @@ -116,7 +116,7 @@ public void onHttpCallResponse() throws StartException { assertNotNull(localResponse); assertEquals(403, localResponse.statusCode); assertEquals("access forbidden", string(localResponse.body)); - assertEquals(Map.of("powered-by", "proxy-wasm-go-sdk!!"), localResponse.headers); + assertEquals(ProxyMap.of("powered-by", "proxy-wasm-go-sdk!!"), localResponse.headers); // CHeck Envoy logs. handler.assertLogsContain("access forbidden"); diff --git a/src/test/java/io/roastedroot/proxywasm/examples/HttpBodyTest.java b/src/test/java/io/roastedroot/proxywasm/examples/HttpBodyTest.java index 817500c..f0f3036 100644 --- a/src/test/java/io/roastedroot/proxywasm/examples/HttpBodyTest.java +++ b/src/test/java/io/roastedroot/proxywasm/examples/HttpBodyTest.java @@ -8,11 +8,11 @@ import com.dylibso.chicory.wasm.Parser; import io.roastedroot.proxywasm.Action; import io.roastedroot.proxywasm.HttpContext; +import io.roastedroot.proxywasm.ProxyMap; import io.roastedroot.proxywasm.ProxyWasm; import io.roastedroot.proxywasm.StartException; import java.nio.charset.StandardCharsets; import java.nio.file.Path; -import java.util.Map; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; @@ -47,7 +47,7 @@ public void testOnHttpRequestHeadersRemoveRequestHeader() throws StartException // Call OnRequestHeaders. handler.setHttpRequestHeaders( - Map.of( + ProxyMap.of( "content-length", "10", "buffer-operation", "replace")); var action = httpContext.callOnRequestHeaders(false); @@ -56,7 +56,7 @@ public void testOnHttpRequestHeadersRemoveRequestHeader() throws StartException Assertions.assertEquals(Action.CONTINUE, action); var headers = handler.getHttpRequestHeaders(); - assertEquals(Map.of("buffer-operation", "replace"), headers); + assertEquals(ProxyMap.of("buffer-operation", "replace"), headers); } @Test @@ -88,7 +88,7 @@ public void testOnHttpRequestBodyPauseUntilEOS() throws StartException { public void testOnHttpRequestBodyAppend() throws StartException { // Call callOnRequestBody handler.setHttpRequestHeaders( - Map.of( + ProxyMap.of( "content-length", "10", "buffer-operation", "append")); @@ -111,7 +111,7 @@ public void testOnHttpRequestBodyAppend() throws StartException { public void testOnHttpRequestBodyPrepend() throws StartException { // Call callOnRequestBody handler.setHttpRequestHeaders( - Map.of( + ProxyMap.of( "content-length", "10", "buffer-operation", "prepend")); @@ -134,7 +134,7 @@ public void testOnHttpRequestBodyPrepend() throws StartException { public void testOnHttpRequestBodyReplace() throws StartException { // Call callOnRequestBody handler.setHttpRequestHeaders( - Map.of( + ProxyMap.of( "content-length", "10", "buffer-operation", "replace")); @@ -157,7 +157,7 @@ public void testOnHttpResponseBodyAppend() throws StartException { // Call OnRequestHeaders handler.setHttpRequestHeaders( - Map.of( + ProxyMap.of( "buffer-replace-at", "response", "content-length", "10", "buffer-operation", "append")); @@ -178,7 +178,7 @@ public void testOnHttpResponseBodyPrepend() throws StartException { // Call OnRequestHeaders handler.setHttpRequestHeaders( - Map.of( + ProxyMap.of( "buffer-replace-at", "response", "content-length", "10", "buffer-operation", "prepend")); @@ -199,7 +199,7 @@ public void testOnHttpResponseBodyReplace() throws StartException { // Call OnRequestHeaders handler.setHttpRequestHeaders( - Map.of( + ProxyMap.of( "buffer-replace-at", "response", "content-length", "10", "buffer-operation", "replace")); diff --git a/src/test/java/io/roastedroot/proxywasm/examples/HttpHeadersTest.java b/src/test/java/io/roastedroot/proxywasm/examples/HttpHeadersTest.java index f423b97..c9e4fcc 100644 --- a/src/test/java/io/roastedroot/proxywasm/examples/HttpHeadersTest.java +++ b/src/test/java/io/roastedroot/proxywasm/examples/HttpHeadersTest.java @@ -5,10 +5,10 @@ import com.dylibso.chicory.wasm.Parser; import io.roastedroot.proxywasm.Action; +import io.roastedroot.proxywasm.ProxyMap; import io.roastedroot.proxywasm.ProxyWasm; import io.roastedroot.proxywasm.StartException; import java.nio.file.Path; -import java.util.Map; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -28,7 +28,7 @@ public void onHttpRequestHeaders() throws StartException { try (var host = proxyWasm.createHttpContext(handler)) { id = host.id(); handler.setHttpRequestHeaders( - Map.of( + ProxyMap.of( "key1", "value1", "key2", "value2")); var action = host.callOnRequestHeaders(false); @@ -60,7 +60,7 @@ public void onHttpResponseHeaders() throws StartException { try (var host = proxyWasm.createHttpContext(handler)) { id = host.id(); handler.setHttpResponseHeaders( - Map.of( + ProxyMap.of( "key1", "value1", "key2", "value2")); var action = host.callOnResponseHeaders(false); @@ -69,11 +69,11 @@ public void onHttpResponseHeaders() throws StartException { // Check headers assertEquals( - Map.of( + ProxyMap.of( "key1", "value1", "key2", "value2", - "x-wasm-header", "x-value", - "x-proxy-wasm-go-sdk-example", "http_headers"), + "x-proxy-wasm-go-sdk-example", "http_headers", + "x-wasm-header", "x-value"), handler.getHttpResponseHeaders()); // Check logs diff --git a/src/test/java/io/roastedroot/proxywasm/examples/HttpRoutingTest.java b/src/test/java/io/roastedroot/proxywasm/examples/HttpRoutingTest.java index 4015662..9e87c19 100644 --- a/src/test/java/io/roastedroot/proxywasm/examples/HttpRoutingTest.java +++ b/src/test/java/io/roastedroot/proxywasm/examples/HttpRoutingTest.java @@ -35,7 +35,7 @@ public void canary() throws StartException { assertEquals(Action.CONTINUE, action); // Get and verify modified headers - Map resultHeaders = handler.getHttpRequestHeaders(); + var resultHeaders = handler.getHttpRequestHeaders(); assertEquals(1, resultHeaders.size()); assertEquals("my-host.com-canary", resultHeaders.get(":authority")); } @@ -62,7 +62,7 @@ public void nonCanary() throws StartException { assertEquals(Action.CONTINUE, action); // Get and verify modified headers - Map resultHeaders = handler.getHttpRequestHeaders(); + var resultHeaders = handler.getHttpRequestHeaders(); assertEquals(1, resultHeaders.size()); assertEquals("my-host.com", resultHeaders.get(":authority")); } diff --git a/src/test/java/io/roastedroot/proxywasm/examples/MockHandler.java b/src/test/java/io/roastedroot/proxywasm/examples/MockHandler.java index cafb021..9205968 100644 --- a/src/test/java/io/roastedroot/proxywasm/examples/MockHandler.java +++ b/src/test/java/io/roastedroot/proxywasm/examples/MockHandler.java @@ -5,11 +5,13 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import io.roastedroot.proxywasm.Action; +import io.roastedroot.proxywasm.ArrayProxyMap; import io.roastedroot.proxywasm.ChainedHandler; import io.roastedroot.proxywasm.Handler; import io.roastedroot.proxywasm.Helpers; import io.roastedroot.proxywasm.LogLevel; import io.roastedroot.proxywasm.MetricType; +import io.roastedroot.proxywasm.ProxyMap; import io.roastedroot.proxywasm.StreamType; import io.roastedroot.proxywasm.WasmException; import io.roastedroot.proxywasm.WasmResult; @@ -30,14 +32,14 @@ public static class HttpResponse { public final int statusCode; public final byte[] statusCodeDetails; public final byte[] body; - public final Map headers; + public final ProxyMap headers; public final int grpcStatus; public HttpResponse( int responseCode, byte[] responseCodeDetails, byte[] responseBody, - Map additionalHeaders, + ProxyMap additionalHeaders, int grpcStatus) { this.statusCode = responseCode; this.statusCodeDetails = responseCodeDetails; @@ -50,12 +52,12 @@ public HttpResponse( final ArrayList loggedMessages = new ArrayList<>(); private int tickPeriodMilliseconds; - private Map httpRequestHeaders = new HashMap<>(); - private Map httpRequestTrailers = new HashMap<>(); - private Map httpResponseHeaders = new HashMap<>(); - private Map httpResponseTrailers = new HashMap<>(); - private Map grpcReceiveInitialMetadata = new HashMap<>(); - private Map grpcReceiveTrailerMetadata = new HashMap<>(); + private ProxyMap httpRequestHeaders = new ArrayProxyMap(); + private ProxyMap httpRequestTrailers = new ArrayProxyMap(); + private ProxyMap httpResponseHeaders = new ArrayProxyMap(); + private ProxyMap httpResponseTrailers = new ArrayProxyMap(); + private ProxyMap grpcReceiveInitialMetadata = new ArrayProxyMap(); + private ProxyMap grpcReceiveTrailerMetadata = new ArrayProxyMap(); private HttpResponse senthttpResponse; private byte[] funcCallData = new byte[0]; @@ -127,67 +129,79 @@ public int getTickPeriodMilliseconds() { } @Override - public Map getHttpRequestHeaders() { + public ProxyMap getHttpRequestHeaders() { return httpRequestHeaders; } @Override - public Map getHttpRequestTrailers() { + public ProxyMap getHttpRequestTrailers() { return httpRequestTrailers; } @Override - public Map getHttpResponseHeaders() { + public ProxyMap getHttpResponseHeaders() { return httpResponseHeaders; } @Override - public Map getHttpResponseTrailers() { + public ProxyMap getHttpResponseTrailers() { return httpResponseTrailers; } @Override - public Map getGrpcReceiveInitialMetaData() { + public ProxyMap getGrpcReceiveInitialMetaData() { return grpcReceiveInitialMetadata; } @Override - public Map getGrpcReceiveTrailerMetaData() { + public ProxyMap getGrpcReceiveTrailerMetaData() { return grpcReceiveTrailerMetadata; } @Override - public WasmResult setHttpRequestHeaders(Map headers) { + public WasmResult setHttpRequestHeaders(ProxyMap headers) { this.httpRequestHeaders = headers; return WasmResult.OK; } + public WasmResult setHttpRequestHeaders(Map headers) { + return this.setHttpRequestHeaders(new ArrayProxyMap(headers)); + } + @Override - public WasmResult setHttpRequestTrailers(Map trailers) { + public WasmResult setHttpRequestTrailers(ProxyMap trailers) { this.httpRequestTrailers = trailers; return WasmResult.OK; } + public WasmResult setHttpRequestTrailers(Map headers) { + return this.setHttpRequestTrailers(new ArrayProxyMap(headers)); + } + @Override - public WasmResult setHttpResponseHeaders(Map headers) { + public WasmResult setHttpResponseHeaders(ProxyMap headers) { this.httpResponseHeaders = headers; return WasmResult.OK; } + public WasmResult setHttpResponseHeaders(Map headers) { + return this.setHttpResponseHeaders(new ArrayProxyMap(headers)); + } + @Override - public WasmResult setHttpResponseTrailers(Map trailers) { + public WasmResult setHttpResponseTrailers(ProxyMap trailers) { this.httpResponseTrailers = trailers; return WasmResult.OK; } @Override - public WasmResult setGrpcReceiveInitialMetaData(Map metadata) { + public WasmResult setGrpcReceiveInitialMetaData(ProxyMap metadata) { this.grpcReceiveInitialMetadata = metadata; return WasmResult.OK; } @Override - public WasmResult setGrpcReceiveTrailerMetaData(Map metadata) { + public WasmResult setGrpcReceiveTrailerMetaData(ProxyMap metadata) { this.grpcReceiveTrailerMetadata = metadata; return WasmResult.OK; } @@ -271,7 +285,7 @@ public WasmResult sendHttpResponse( int responseCode, byte[] responseCodeDetails, byte[] responseBody, - Map additionalHeaders, + ProxyMap additionalHeaders, int grpcStatus) { this.senthttpResponse = new HttpResponse( @@ -298,16 +312,16 @@ public enum Type { public final String uri; public final Object headers; public final byte[] body; - public final HashMap trailers; + public final ProxyMap trailers; public final int timeoutMilliseconds; public HttpCall( int id, Type callType, String uri, - HashMap headers, + ProxyMap headers, byte[] body, - HashMap trailers, + ProxyMap trailers, int timeoutMilliseconds) { this.id = id; this.callType = callType; @@ -320,7 +334,7 @@ public HttpCall( } private final AtomicInteger lastCallId = new AtomicInteger(0); - private final HashMap httpCalls = new HashMap<>(); + private final HashMap httpCalls = new HashMap(); public HashMap getHttpCalls() { return httpCalls; @@ -328,11 +342,7 @@ public HashMap getHttpCalls() { @Override public int httpCall( - String uri, - HashMap headers, - byte[] body, - HashMap trailers, - int timeoutMilliseconds) + String uri, ProxyMap headers, byte[] body, ProxyMap trailers, int timeoutMilliseconds) throws WasmException { var id = lastCallId.incrementAndGet(); HttpCall value = @@ -351,9 +361,9 @@ public int httpCall( @Override public int dispatchHttpCall( String upstreamName, - HashMap headers, + ProxyMap headers, byte[] body, - HashMap trailers, + ProxyMap trailers, int timeoutMilliseconds) throws WasmException { var id = lastCallId.incrementAndGet(); @@ -385,8 +395,8 @@ public Metric(int id, MetricType type, String name) { } private final AtomicInteger lastMetricId = new AtomicInteger(0); - private HashMap metrics = new HashMap<>(); - private HashMap metricsByName = new HashMap<>(); + private HashMap metrics = new HashMap(); + private HashMap metricsByName = new HashMap(); @Override public int defineMetric(MetricType type, String name) throws WasmException {