Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@
<module>quarkus-proxy-wasm</module>
<module>quarkus-proxy-wasm-example</module>
<module>quarkus-x-corazawaf-example</module>
<module>quarkus-x-kuadrant-example</module>
</modules>
</profile>
</profiles>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.dylibso.chicory.runtime.WasmRuntimeException;
import com.dylibso.chicory.wasm.InvalidException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -487,7 +488,11 @@ void proxyOnTick(int arg0) {
*/
@WasmExport
int proxyGetBufferBytes(
int bufferType, int start, int length, int returnBufferData, int returnBufferSize) {
int bufferType,
int start,
int chunkLength,
int returnBufferData,
int returnBufferSize) {

try {
// Get the buffer based on the buffer type
Expand All @@ -496,29 +501,34 @@ int proxyGetBufferBytes(
return WasmResult.NOT_FOUND.getValue();
}

if (start > start + length) {
if (start < 0) {
return WasmResult.BAD_ARGUMENT.getValue();
}

int maxChunkLength = b.length - start;
if (chunkLength < 0 || chunkLength > maxChunkLength) {
chunkLength = maxChunkLength;
}

ByteBuffer buffer = ByteBuffer.wrap(b);
if (start + length > buffer.capacity()) {
length = buffer.capacity() - start;
if (start + chunkLength > buffer.capacity()) {
chunkLength = buffer.capacity() - start;
}

try {
buffer.position(start);
buffer.limit(start + length);
buffer.limit(start + chunkLength);
} catch (IllegalArgumentException e) {
return WasmResult.BAD_ARGUMENT.getValue();
}

// Allocate memory in the WebAssembly instance
int addr = malloc(length);
int addr = malloc(chunkLength);
putMemory(addr, buffer);
// Write the address to the return pointer
putUint32(returnBufferData, addr);
// Write the length to the return size pointer
putUint32(returnBufferSize, length);
putUint32(returnBufferSize, chunkLength);
return WasmResult.OK.getValue();

} catch (WasmException e) {
Expand Down Expand Up @@ -713,16 +723,24 @@ int proxyGetHeaderMapPairs(int mapType, int returnDataPtr, int returnDataSize) {
return WasmResult.NOT_FOUND.getValue();
}

// to clone the headers so that they don't change on while we process them in the loop
var cloneMap = new ArrayProxyMap(header);
var cloneMap = new ArrayList<Map.Entry<byte[], byte[]>>();
int totalBytesLen = U32_LEN; // Start with space for the count

for (Map.Entry<String, String> entry : cloneMap.entries()) {
String key = entry.getKey();
String value = entry.getValue();
totalBytesLen += U32_LEN + U32_LEN; // keyLen + valueLen
totalBytesLen += key.length() + 1 + value.length() + 1; // key + \0 + value + \0
}
totalBytesLen +=
header.streamBytes()
.mapToInt(
entry -> {
var key = entry.getKey();
var value = entry.getValue();
cloneMap.add(Map.entry(key, value));
return U32_LEN
+ U32_LEN // keyLen + valueLen
+ key.length
+ 1
+ value.length
+ 1; // key + \0 + value + \0
})
.sum();

// Allocate memory in the WebAssembly instance
int addr = malloc(totalBytesLen);
Expand All @@ -735,29 +753,29 @@ int proxyGetHeaderMapPairs(int mapType, int returnDataPtr, int returnDataSize) {
int dataPtr = lenPtr + ((U32_LEN + U32_LEN) * cloneMap.size());

// Write each key-value pair to memory
for (Map.Entry<String, String> entry : cloneMap.entries()) {
String key = entry.getKey();
String value = entry.getValue();
for (Map.Entry<byte[], byte[]> entry : cloneMap) {
var key = entry.getKey();
var value = entry.getValue();

// Write key length
putUint32(lenPtr, key.length());
putUint32(lenPtr, key.length);
lenPtr += U32_LEN;

// Write value length
putUint32(lenPtr, value.length());
putUint32(lenPtr, value.length);
lenPtr += U32_LEN;

// Write key bytes
putMemory(dataPtr, key.getBytes());
dataPtr += key.length();
putMemory(dataPtr, key);
dataPtr += key.length;

// Write null terminator for key
putByte(dataPtr, (byte) 0);
dataPtr++;

// Write value bytes
putMemory(dataPtr, value.getBytes());
dataPtr += value.length();
putMemory(dataPtr, value);
dataPtr += value.length;

// Write null terminator for value
putByte(dataPtr, (byte) 0);
Expand Down Expand Up @@ -1447,7 +1465,7 @@ int proxyGrpcCall(
message,
timeout);
putUint32(returnCalloutID, callId);
return callId;
return WasmResult.OK.getValue();
} catch (WasmException e) {
return e.result().getValue();
}
Expand Down Expand Up @@ -1478,7 +1496,7 @@ int proxyGrpcStream(
int streamId =
handler.grpcStream(upstreamName, serviceName, methodName, initialMetadata);
putUint32(returnStreamId, streamId);
return streamId;
return WasmResult.OK.getValue();
} catch (WasmException e) {
return e.result().getValue();
}
Expand Down Expand Up @@ -1539,21 +1557,21 @@ void proxyOnGrpcReceive(int contextId, int callId, int messageSize) {
/**
* implements https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#proxy_on_grpc_receive_trailing_metadata
*/
void proxyOnGrpcReceiveTrailingMetadata(int arg0, int arg1, int arg2) {
void proxyOnGrpcReceiveTrailingMetadata(int contextId, int callId, int numElements) {
if (proxyOnGrpcReceiveTrailingMetadataFn == null) {
return;
}
proxyOnGrpcReceiveTrailingMetadataFn.apply(arg0, arg1, arg2);
proxyOnGrpcReceiveTrailingMetadataFn.apply(contextId, callId, numElements);
}

/**
* implements https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#proxy_on_grpc_close
*/
void proxyOnGrpcClose(int arg0, int arg1, int arg2) {
void proxyOnGrpcClose(int contextId, int callId, int statusCode) {
if (proxyOnGrpcCloseFn == null) {
return;
}
proxyOnGrpcCloseFn.apply(arg0, arg1, arg2);
proxyOnGrpcCloseFn.apply(contextId, callId, statusCode);
}

// //////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package io.roastedroot.proxywasm;

import static io.roastedroot.proxywasm.Helpers.bytes;
import static io.roastedroot.proxywasm.Helpers.len;
import static io.roastedroot.proxywasm.Helpers.string;

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class ArrayBytesProxyMap implements ProxyMap {

final ArrayList<Map.Entry<String, byte[]>> entries;

public ArrayBytesProxyMap() {
this.entries = new ArrayList<>();
}

public ArrayBytesProxyMap(int mapSize) {
this.entries = new ArrayList<>(mapSize);
}

@Override
public int size() {
return entries.size();
}

@Override
public void add(String key, String value) {
entries.add(Map.entry(key, bytes(value)));
}

public void add(String key, byte[] value) {
entries.add(Map.entry(key, value));
}

@Override
public void put(String key, String value) {
this.remove(key);
entries.add(Map.entry(key, bytes(value)));
}

public void put(String key, byte[] value) {
this.remove(key);
entries.add(Map.entry(key, value));
}

@Override
public Iterable<? extends Map.Entry<String, String>> entries() {
return entries.stream()
.map(x -> Map.entry(x.getKey(), string(x.getValue())))
.collect(Collectors.toList());
}

@Override
public Stream<Map.Entry<byte[], byte[]>> streamBytes() {
return entries.stream().map(x -> Map.entry(bytes(x.getKey()), x.getValue()));
}

@Override
public String get(String key) {
return entries.stream()
.filter(x -> x.getKey().equals(key))
.map(Map.Entry::getValue)
.map(Helpers::string)
.findFirst()
.orElse(null);
}

@Override
public void remove(String key) {
entries.removeIf(x -> x.getKey().equals(key));
}

@Override
public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) {
return false;
}
ArrayBytesProxyMap that = (ArrayBytesProxyMap) o;
return Objects.equals(entries, that.entries);
}

@Override
public int hashCode() {
return Objects.hashCode(entries);
}

@Override
public String toString() {
return entries.toString();
}

/**
* Encode the map into a byte array.
*/
@Override
public byte[] encode() {
try {
var baos = new ByteArrayOutputStream();
var o = new DataOutputStream(baos);
// Write header size (number of entries)
int mapSize = this.size();
o.writeInt(mapSize);

// write all the key / value sizes.
for (var entry : entries) {
o.writeInt(len(entry.getKey()));
o.writeInt(len(entry.getValue()));
}

// write all the key / values
for (var entry : entries) {
o.write(bytes(entry.getKey()));
o.write(0);
o.write(entry.getValue());
o.write(0);
}
o.close();
return baos.toByteArray();
} catch (IOException e) {
// this should never happen since we are not really doing IO
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public interface ProxyMap {

Expand Down Expand Up @@ -37,6 +39,11 @@ static ProxyMap copyOf(Map<String, String> headers) {

Iterable<? extends Map.Entry<String, String>> entries();

default Stream<Map.Entry<byte[], byte[]>> streamBytes() {
return StreamSupport.stream(entries().spliterator(), false)
.map(x -> Map.entry(bytes(x.getKey()), bytes(x.getValue())));
}

String get(String key);

void remove(String key);
Expand Down
Loading