diff --git a/src/main/java/io/roastedroot/proxywasm/ABI.java b/src/main/java/io/roastedroot/proxywasm/ABI.java index addb75a..e8b48b5 100644 --- a/src/main/java/io/roastedroot/proxywasm/ABI.java +++ b/src/main/java/io/roastedroot/proxywasm/ABI.java @@ -5,6 +5,7 @@ import com.dylibso.chicory.experimental.hostmodule.annotations.HostModule; import com.dylibso.chicory.experimental.hostmodule.annotations.WasmExport; +import com.dylibso.chicory.runtime.ExportFunction; import com.dylibso.chicory.runtime.Instance; import com.dylibso.chicory.runtime.Memory; import com.dylibso.chicory.runtime.WasmRuntimeException; @@ -18,7 +19,36 @@ class ABI { private Handler handler; - private Instance instance; + private Memory memory; + private ExportFunction initializeFn; + private ExportFunction mainFn; + private ExportFunction startFn; + private ExportFunction proxyOnContextCreateFn; + private ExportFunction proxyOnDoneFn; + private ExportFunction mallocFn; + private ExportFunction proxyOnLogFn; + private ExportFunction proxyOnDeleteFn; + private ExportFunction proxyOnVmStartFn; + private ExportFunction proxyOnConfigureFn; + private ExportFunction proxyOnTickFn; + private ExportFunction proxyOnNewConnectionFn; + private ExportFunction proxyOnDownstreamDataFn; + private ExportFunction proxyOnDownstreamConnectionCloseFn; + private ExportFunction proxyOnUpstreamDataFn; + private ExportFunction proxyOnUpstreamConnectionCloseFn; + private ExportFunction proxyOnRequestHeadersFn; + private ExportFunction proxyOnRequestBodyFn; + private ExportFunction proxyOnRequestTrailersFn; + private ExportFunction proxyOnResponseHeadersFn; + private ExportFunction proxyOnResponseBodyFn; + private ExportFunction proxyOnResponseTrailersFn; + private ExportFunction proxyOnHttpCallResponseFn; + private ExportFunction proxyOnGrpcReceiveInitialMetadataFn; + private ExportFunction proxyOnGrpcReceiveFn; + private ExportFunction proxyOnGrpcReceiveTrailingMetadataFn; + private ExportFunction proxyOnGrpcCloseFn; + private ExportFunction proxyOnQueueReadyFn; + private ExportFunction proxyOnForeignFunctionFn; Handler getHandler() { return handler; @@ -29,15 +59,53 @@ void setHandler(Handler handler) { } void setInstance(Instance instance) { - this.instance = instance; - } + this.memory = instance.memory(); + var exports = instance.exports(); - Instance.Exports exports() { - return instance.exports(); - } + // Since 0_2_0, prefer proxy_on_memory_allocate over malloc + this.mallocFn = lookupFunction(exports, "proxy_on_memory_allocate"); + if (this.mallocFn == null) { + this.mallocFn = lookupFunction(exports, "malloc"); + } + if (this.mallocFn == null) { + throw new WasmRuntimeException("malloc function not found"); + } - Memory memory() { - return instance.memory(); + this.initializeFn = lookupFunction(exports, "_initialize"); + this.mainFn = lookupFunction(exports, "main"); + this.startFn = lookupFunction(exports, "_start"); + + // All callbacks proxyOn* are optional, and will only be called if exposed by the Wasm + // module. + this.proxyOnContextCreateFn = lookupFunction(exports, "proxy_on_context_create"); + this.proxyOnDoneFn = lookupFunction(exports, "proxy_on_done"); + this.proxyOnLogFn = lookupFunction(exports, "proxy_on_log"); + this.proxyOnDeleteFn = lookupFunction(exports, "proxy_on_delete"); + this.proxyOnVmStartFn = lookupFunction(exports, "proxy_on_vm_start"); + this.proxyOnConfigureFn = lookupFunction(exports, "proxy_on_configure"); + this.proxyOnTickFn = lookupFunction(exports, "proxy_on_tick"); + this.proxyOnNewConnectionFn = lookupFunction(exports, "proxy_on_new_connection"); + this.proxyOnDownstreamDataFn = lookupFunction(exports, "proxy_on_downstream_data"); + this.proxyOnDownstreamConnectionCloseFn = + lookupFunction(exports, "proxy_on_downstream_connection_close"); + this.proxyOnUpstreamDataFn = lookupFunction(exports, "proxy_on_upstream_data"); + this.proxyOnUpstreamConnectionCloseFn = + lookupFunction(exports, "proxy_on_upstream_connection_close"); + this.proxyOnRequestHeadersFn = lookupFunction(exports, "proxy_on_request_headers"); + this.proxyOnRequestBodyFn = lookupFunction(exports, "proxy_on_request_body"); + this.proxyOnRequestTrailersFn = lookupFunction(exports, "proxy_on_request_trailers"); + this.proxyOnResponseHeadersFn = lookupFunction(exports, "proxy_on_response_headers"); + this.proxyOnResponseBodyFn = lookupFunction(exports, "proxy_on_response_body"); + this.proxyOnResponseTrailersFn = lookupFunction(exports, "proxy_on_response_trailers"); + this.proxyOnHttpCallResponseFn = lookupFunction(exports, "proxy_on_http_call_response"); + this.proxyOnGrpcReceiveInitialMetadataFn = + lookupFunction(exports, "proxy_on_grpc_receive_initial_metadata"); + this.proxyOnGrpcReceiveFn = lookupFunction(exports, "proxy_on_grpc_receive"); + this.proxyOnGrpcReceiveTrailingMetadataFn = + lookupFunction(exports, "proxy_on_grpc_receive_trailing_metadata"); + this.proxyOnGrpcCloseFn = lookupFunction(exports, "proxy_on_grpc_close"); + this.proxyOnQueueReadyFn = lookupFunction(exports, "proxy_on_queue_ready"); + this.proxyOnForeignFunctionFn = lookupFunction(exports, "proxy_on_foreign_function"); } // ////////////////////////////////////////////////////////////////////// @@ -47,12 +115,11 @@ Memory memory() { // Size of a 32-bit integer in bytes static final int U32_LEN = 4; - boolean instanceExportsFunction(String name) { + private ExportFunction lookupFunction(Instance.Exports exports, String name) { try { - this.exports().function(name); - return true; + return exports.function(name); } catch (InvalidException e) { - return false; + return null; } } @@ -63,9 +130,9 @@ boolean instanceExportsFunction(String name) { * @param value The value to write * @throws WasmException if the memory access is invalid */ - void putUint32(int address, int value) throws WasmException { + private void putUint32(int address, int value) throws WasmException { try { - memory().writeI32(address, value); + memory.writeI32(address, value); } catch (RuntimeException e) { throw new WasmException(WasmResult.INVALID_MEMORY_ACCESS); } @@ -77,9 +144,9 @@ void putUint32(int address, int value) throws WasmException { * @param address The address to read from * @throws WasmException if the memory access is invalid */ - long getUint32(int address) throws WasmException { + private long getUint32(int address) throws WasmException { try { - return memory().readU32(address); + return memory.readU32(address); } catch (RuntimeException e) { throw new WasmException(WasmResult.INVALID_MEMORY_ACCESS); } @@ -92,9 +159,9 @@ long getUint32(int address) throws WasmException { * @param value The value to write * @throws WasmException if the memory access is invalid */ - void putByte(int address, byte value) throws WasmException { + private void putByte(int address, byte value) throws WasmException { try { - memory().writeByte(address, value); + memory.writeByte(address, value); } catch (RuntimeException e) { throw new WasmException(WasmResult.INVALID_MEMORY_ACCESS); } @@ -107,9 +174,9 @@ void putByte(int address, byte value) throws WasmException { * @param data The data to write * @throws WasmException if the memory access is invalid */ - void putMemory(int address, byte[] data) throws WasmException { + private void putMemory(int address, byte[] data) throws WasmException { try { - memory().write(address, data); + memory.write(address, data); } catch (RuntimeException e) { throw new WasmException(WasmResult.INVALID_MEMORY_ACCESS); } @@ -122,17 +189,17 @@ void putMemory(int address, byte[] data) throws WasmException { * @param data The data to write * @throws WasmException if the memory access is invalid */ - void putMemory(int address, ByteBuffer data) throws WasmException { + private void putMemory(int address, ByteBuffer data) throws WasmException { try { if (data.hasArray()) { var array = data.array(); - memory().write(address, array, data.position(), data.remaining()); + memory.write(address, array, data.position(), data.remaining()); } else { // This could likely be optimized by extending the memory interface to accept // ByteBuffer byte[] bytes = new byte[data.remaining()]; data.get(bytes); - memory().write(address, bytes); + memory.write(address, bytes); } } catch (RuntimeException e) { throw new WasmException(WasmResult.INVALID_MEMORY_ACCESS); @@ -147,24 +214,24 @@ void putMemory(int address, ByteBuffer data) throws WasmException { * @return The value read * @throws WasmException if the memory access is invalid */ - byte[] readMemory(int address, int len) throws WasmException { + private byte[] readMemory(int address, int len) throws WasmException { try { - return memory().readBytes(address, len); + return memory.readBytes(address, len); } catch (RuntimeException e) { throw new WasmException(WasmResult.INVALID_MEMORY_ACCESS); } } - String readString(int address, int len) throws WasmException { + private String readString(int address, int len) throws WasmException { var data = readMemory(address, len); return new String(data, StandardCharsets.UTF_8); } - void copyIntoInstance(String value, int retPtr, int retSize) throws WasmException { + private void copyIntoInstance(String value, int retPtr, int retSize) throws WasmException { copyIntoInstance(value.getBytes(), retPtr, retSize); } - void copyIntoInstance(byte[] value, int retPtr, int retSize) throws WasmException { + private void copyIntoInstance(byte[] value, int retPtr, int retSize) throws WasmException { try { if (value.length != 0) { int addr = malloc(value.length); @@ -186,48 +253,49 @@ void copyIntoInstance(byte[] value, int retPtr, int retSize) throws WasmExceptio /** * implements: https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#_initialize */ - void initialize() { - exports().function("_initialize").apply(); + boolean initialize() { + if (this.initializeFn == null) { + return false; + } + this.initializeFn.apply(); + return true; } /** * implements: https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#main */ - int main(int arg0, int arg1) { - long result = exports().function("main").apply(arg0, arg1)[0]; - return (int) result; + boolean main(int arg0, int arg1) { + if (mainFn == null) { + return false; + } + mainFn.apply(arg0, arg1); + return true; } /** * implements: https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#_start */ - void start() { - exports().function("_start").apply(); + boolean start() { + if (startFn == null) { + return false; + } + startFn.apply(); + return true; } // ////////////////////////////////////////////////////////////////////// // Memory management // ////////////////////////////////////////////////////////////////////// - String mallocFunctionName = "malloc"; - - String getMallocFunctionName() { - return mallocFunctionName; - } - - void setMallocFunctionName(String mallocFunctionName) { - this.mallocFunctionName = mallocFunctionName; - } - /** * implements: - * * https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#malloc - * * https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#proxy_on_memory_allocate + * * https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#malloc + * * https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#proxy_on_memory_allocate */ int malloc(int size) throws WasmException { // I've noticed guests fail on malloc(0) so lets avoid that assert size > 0 : "malloc size must be greater than zero"; - long ptr = exports().function(mallocFunctionName).apply(size)[0]; + long ptr = mallocFn.apply(size)[0]; if (ptr == 0) { throw new WasmException(WasmResult.INVALID_MEMORY_ACCESS); } @@ -242,14 +310,20 @@ int malloc(int size) throws WasmException { * implements: https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#proxy_on_context_create */ void proxyOnContextCreate(int contextID, int parentContextID) { - exports().function("proxy_on_context_create").apply(contextID, parentContextID); + if (proxyOnContextCreateFn == null) { + return; + } + proxyOnContextCreateFn.apply(contextID, parentContextID); } /** * implements: https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#proxy_on_done */ boolean proxyOnDone(int context_id) { - long result = exports().function("proxy_on_done").apply(context_id)[0]; + if (proxyOnDoneFn == null) { + return true; + } + long result = proxyOnDoneFn.apply(context_id)[0]; return result != 0; } @@ -257,14 +331,20 @@ boolean proxyOnDone(int context_id) { * implements: https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#proxy_on_log */ void proxyOnLog(int context_id) { - exports().function("proxy_on_log").apply(context_id); + if (proxyOnLogFn == null) { + return; + } + proxyOnLogFn.apply(context_id); } /** * implements: https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#proxy_on_delete */ void proxyOnDelete(int context_id) { - exports().function("proxy_on_delete").apply(context_id); + if (proxyOnDeleteFn == null) { + return; + } + proxyOnDeleteFn.apply(context_id); } /** @@ -291,7 +371,10 @@ int proxySetEffectiveContext(int contextId) { * implements: https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#proxy_on_vm_start */ boolean proxyOnVmStart(int arg0, int arg1) { - long result = exports().function("proxy_on_vm_start").apply(arg0, arg1)[0]; + if (proxyOnVmStartFn == null) { + return true; + } + long result = proxyOnVmStartFn.apply(arg0, arg1)[0]; return result != 0; } @@ -299,7 +382,10 @@ boolean proxyOnVmStart(int arg0, int arg1) { * implements: https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#proxy_on_configure */ boolean proxyOnConfigure(int arg0, int arg1) { - long result = exports().function("proxy_on_configure").apply(arg0, arg1)[0]; + if (proxyOnConfigureFn == null) { + return true; + } + long result = proxyOnConfigureFn.apply(arg0, arg1)[0]; return result != 0; } @@ -316,7 +402,7 @@ boolean proxyOnConfigure(int arg0, int arg1) { @WasmExport int proxyLog(int logLevel, int messageData, int messageSize) { try { - var msg = memory().readBytes(messageData, messageSize); + var msg = memory.readBytes(messageData, messageSize); handler.log(LogLevel.fromInt(logLevel), new String(msg)); return WasmResult.OK.getValue(); } catch (WasmException e) { @@ -374,7 +460,10 @@ int proxySetTickPeriodMilliseconds(int tick_period) { * implements: https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#proxy_on_tick */ void proxyOnTick(int arg0) { - exports().function("proxy_on_tick").apply(arg0); + if (proxyOnTickFn == null) { + return; + } + proxyOnTickFn.apply(arg0); } // ////////////////////////////////////////////////////////////////////// @@ -466,7 +555,7 @@ int proxySetBufferBytes(int bufferType, int start, int length, int dataPtr, int } // Get content from WebAssembly memory - byte[] content = memory().readBytes(dataPtr, dataSize); + byte[] content = memory.readBytes(dataPtr, dataSize); content = replaceBytes(buf, content, start, length); @@ -574,7 +663,7 @@ private WasmResult setBuffer(int bufferType, byte[] buffer) { /** * Retrieves serialized size of all key-value pairs from the map mapType - * + *
* implements: https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#proxy_get_header_map_size * * @param mapType The type of map to set @@ -615,7 +704,7 @@ int proxyGetHeaderMapSize(int mapType, int returnSize) { /** * Get header map pairs and format them for WebAssembly memory(). - * + *
* implements: https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#proxy_get_header_map_pairs * * @param mapType The type of map to get @@ -700,7 +789,7 @@ int proxyGetHeaderMapPairs(int mapType, int returnDataPtr, int returnDataSize) { /** * Set header map pairs from WebAssembly memory(). - * + *
* implements: https://github.com/proxy-wasm/spec/tree/main/abi-versions/vNEXT#proxy_set_header_map_pairs
*
* @param mapType The type of map to set
@@ -872,7 +961,7 @@ int proxyRemoveHeaderMapValue(int mapType, int keyDataPtr, int keySize) {
/**
* Get a header map based on the map type.
*
- * @param mapType The type of map to get
+ * @param mapType The type of map to get
* @return The header map
*/
private Map