Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 23 additions & 27 deletions src/main/java/io/roastedroot/proxywasm/ABI.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import com.dylibso.chicory.wasm.InvalidException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

@HostModule("env")
Expand Down Expand Up @@ -675,19 +674,18 @@ int proxyGetHeaderMapSize(int mapType, int returnSize) {
try {

// Get the header map based on the map type
Map<String, String> header = getMap(mapType);
ProxyMap header = getMap(mapType);
if (header == null) {
return WasmResult.BAD_ARGUMENT.getValue();
}

// to clone the headers so that they don't change on while we process them in the loop
final Map<String, String> cloneMap = new HashMap<>();
var cloneMap = new ArrayProxyMap(header);
int totalBytesLen = U32_LEN; // Start with space for the count

for (Map.Entry<String, String> entry : header.entrySet()) {
for (Map.Entry<String, String> entry : cloneMap.entries()) {
String key = entry.getKey();
String value = entry.getValue();
cloneMap.put(key, value);
totalBytesLen += U32_LEN + U32_LEN; // keyLen + valueLen
totalBytesLen += key.length() + 1 + value.length() + 1; // key + \0 + value + \0
}
Expand Down Expand Up @@ -717,19 +715,18 @@ int proxyGetHeaderMapPairs(int mapType, int returnDataPtr, int returnDataSize) {
try {

// Get the header map based on the map type
Map<String, String> header = getMap(mapType);
ProxyMap header = getMap(mapType);
if (header == null) {
return WasmResult.NOT_FOUND.getValue();
}

// to clone the headers so that they don't change on while we process them in the loop
final Map<String, String> cloneMap = new HashMap<>();
var cloneMap = new ArrayProxyMap(header);
int totalBytesLen = U32_LEN; // Start with space for the count

for (Map.Entry<String, String> entry : header.entrySet()) {
for (Map.Entry<String, String> entry : cloneMap.entries()) {
String key = entry.getKey();
String value = entry.getValue();
cloneMap.put(key, value);
totalBytesLen += U32_LEN + U32_LEN; // keyLen + valueLen
totalBytesLen += key.length() + 1 + value.length() + 1; // key + \0 + value + \0
}
Expand All @@ -745,7 +742,7 @@ int proxyGetHeaderMapPairs(int mapType, int returnDataPtr, int returnDataSize) {
int dataPtr = lenPtr + ((U32_LEN + U32_LEN) * cloneMap.size());

// Write each key-value pair to memory
for (Map.Entry<String, String> entry : cloneMap.entrySet()) {
for (Map.Entry<String, String> entry : cloneMap.entries()) {
String key = entry.getKey();
String value = entry.getValue();

Expand Down Expand Up @@ -802,14 +799,14 @@ int proxySetHeaderMapPairs(int mapType, int ptr, int size) {

try {
// Get the header map based on the map type
Map<String, String> headerMap = getMap(mapType);
ProxyMap headerMap = getMap(mapType);
if (headerMap == null) {
return WasmResult.BAD_ARGUMENT.getValue();
}

// Decode the map content and set each key-value pair
Map<String, String> newMap = decodeMap(ptr, size);
for (Map.Entry<String, String> entry : newMap.entrySet()) {
ProxyMap newMap = decodeMap(ptr, size);
for (Map.Entry<String, String> entry : newMap.entries()) {
headerMap.put(entry.getKey(), entry.getValue());
}

Expand Down Expand Up @@ -837,7 +834,7 @@ int proxyGetHeaderMapValue(
int mapType, int keyDataPtr, int keySize, int valueDataPtr, int valueSize) {
try {
// Get the header map based on the map type
Map<String, String> headerMap = getMap(mapType);
ProxyMap headerMap = getMap(mapType);
if (headerMap == null) {
return WasmResult.BAD_ARGUMENT.getValue();
}
Expand Down Expand Up @@ -895,7 +892,7 @@ int proxyReplaceHeaderMapValue(
int mapType, int keyDataPtr, int keySize, int valueDataPtr, int valueSize) {
try {
// Get the header map based on the map type
Map<String, String> headerMap = getMap(mapType);
ProxyMap headerMap = getMap(mapType);
if (headerMap == null) {
return WasmResult.BAD_ARGUMENT.getValue();
}
Expand All @@ -907,7 +904,7 @@ int proxyReplaceHeaderMapValue(
String value = readString(valueDataPtr, valueSize);

// Replace value in map
var copy = new HashMap<>(headerMap);
var copy = new ArrayProxyMap(headerMap);
copy.put(key, value);
setMap(mapType, copy);

Expand All @@ -933,7 +930,7 @@ int proxyReplaceHeaderMapValue(
int proxyRemoveHeaderMapValue(int mapType, int keyDataPtr, int keySize) {
try {
// Get the header map based on the map type
Map<String, String> headerMap = getMap(mapType);
ProxyMap headerMap = getMap(mapType);
if (headerMap == null) {
return WasmResult.NOT_FOUND.getValue();
}
Expand All @@ -945,7 +942,7 @@ int proxyRemoveHeaderMapValue(int mapType, int keyDataPtr, int keySize) {
}

// Remove key from map
var copy = new HashMap<>(headerMap);
var copy = new ArrayProxyMap(headerMap);
copy.remove(key);
setMap(mapType, copy);

Expand All @@ -964,7 +961,7 @@ int proxyRemoveHeaderMapValue(int mapType, int keyDataPtr, int keySize) {
* @param mapType The type of map to get
* @return The header map
*/
private Map<String, String> getMap(int mapType) {
private ProxyMap getMap(int mapType) {

var knownType = MapType.fromInt(mapType);
if (knownType == null) {
Expand Down Expand Up @@ -999,7 +996,7 @@ private Map<String, String> getMap(int mapType) {
* @param map The header map to set
* @return WasmResult indicating success or failure
*/
private WasmResult setMap(int mapType, Map<String, String> map) {
private WasmResult setMap(int mapType, ProxyMap map) {
var knownType = MapType.fromInt(mapType);
if (knownType == null) {
return handler.setCustomHeaders(mapType, map);
Expand Down Expand Up @@ -1043,9 +1040,9 @@ private WasmResult setMap(int mapType, Map<String, String> map) {
* @return The decoded map containing string keys and values
* @throws WasmException if there is an error accessing memory
*/
private HashMap<String, String> decodeMap(int addr, int mem_size) throws WasmException {
private ProxyMap decodeMap(int addr, int mem_size) throws WasmException {
if (mem_size < U32_LEN) {
return new HashMap<>();
return new ArrayProxyMap();
}

// Read header size (number of entries)
Expand All @@ -1055,11 +1052,11 @@ private HashMap<String, String> decodeMap(int addr, int mem_size) throws WasmExc
// mapSize + (key1_size + value1_size) * mapSize
long dataOffset = U32_LEN + (U32_LEN + U32_LEN) * mapSize;
if (dataOffset >= mem_size) {
return new HashMap<>();
return new ArrayProxyMap();
}

// Create result map with initial capacity
var result = new HashMap<String, String>((int) mapSize);
var result = new ArrayProxyMap((int) mapSize);

// Process each entry
for (int i = 0; i < mapSize; i++) {
Expand All @@ -1086,7 +1083,7 @@ private HashMap<String, String> decodeMap(int addr, int mem_size) throws WasmExc
dataOffset += valueSize + 1;

// Add to result map
result.put(key, value);
result.add(key, value);
}

return result;
Expand Down Expand Up @@ -1282,8 +1279,7 @@ int proxySendLocalResponse(
}

// Get and decode additional headers from memory
HashMap<String, String> additionalHeaders =
decodeMap(additionalHeadersMapData, additionalHeadersSize);
ProxyMap additionalHeaders = decodeMap(additionalHeadersMapData, additionalHeadersSize);

// Send the response through the handler
WasmResult result =
Expand Down
86 changes: 86 additions & 0 deletions src/main/java/io/roastedroot/proxywasm/ArrayProxyMap.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package io.roastedroot.proxywasm;

import java.util.ArrayList;
import java.util.Map;
import java.util.Objects;

public class ArrayProxyMap implements ProxyMap {

final ArrayList<Map.Entry<String, String>> entries;

public ArrayProxyMap() {
this.entries = new ArrayList<>();
}

public ArrayProxyMap(int mapSize) {
this.entries = new ArrayList<>(mapSize);
}

public ArrayProxyMap(ProxyMap other) {
this(other.size());
for (Map.Entry<String, String> entry : other.entries()) {
add(entry.getKey(), entry.getValue());
}
}

public ArrayProxyMap(Map<String, String> other) {
this(other.size());
for (Map.Entry<String, String> entry : other.entrySet()) {
add(entry.getKey(), entry.getValue());
}
}

@Override
public int size() {
return entries.size();
}

@Override
public void add(String key, String value) {
entries.add(Map.entry(key, value));
}

@Override
public void put(String key, String value) {
this.remove(key);
entries.add(Map.entry(key, value));
}

@Override
public Iterable<? extends Map.Entry<String, String>> entries() {
return entries;
}

@Override
public String get(String key) {
return entries.stream()
.filter(x -> x.getKey().equals(key))
.map(Map.Entry::getValue)
.findFirst()
.orElse(null);
}

@Override
public void remove(String key) {
entries.removeIf(x -> x.getKey().equals(key));
}

@Override
public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) {
return false;
}
ArrayProxyMap that = (ArrayProxyMap) o;
return Objects.equals(entries, that.entries);
}

@Override
public int hashCode() {
return Objects.hashCode(entries);
}

@Override
public String toString() {
return entries.toString();
}
}
Loading