diff --git a/src/workerd/io/BUILD.bazel b/src/workerd/io/BUILD.bazel index f8927955362..95e23ff8034 100644 --- a/src/workerd/io/BUILD.bazel +++ b/src/workerd/io/BUILD.bazel @@ -271,7 +271,10 @@ wd_cc_library( wd_cc_library( name = "limit-enforcer", - hdrs = ["limit-enforcer.h"], + hdrs = [ + "limit-enforcer.h", + "wasm-shutdown-signal.h", + ], visibility = ["//visibility:public"], deps = [ ":outcome_capnp", diff --git a/src/workerd/io/limit-enforcer.h b/src/workerd/io/limit-enforcer.h index 2f8897b10eb..b5f0fd73cae 100644 --- a/src/workerd/io/limit-enforcer.h +++ b/src/workerd/io/limit-enforcer.h @@ -5,14 +5,19 @@ #pragma once #include +#include +#include #include #include // For Promise +#include // For KJ_REQUIRE #include // for Own #include // for OneOf #include // for Duration +#include // for std::shared_ptr + namespace workerd { class IsolateObserver; class RequestObserver; @@ -98,6 +103,43 @@ class IsolateLimitEnforcer: public kj::Refcounted { virtual bool hasExcessivelyExceededHeapLimit() const = 0; + // Registers a WASM module for receiving the "shut down" signal when CPU time is nearly + // exhausted. The signal handler will write 1 (as a uint32) into the module's linear memory + // at `signalOffset` bytes from the start of `backingStore`. The runtime reads + // `terminatedOffset` in a GC prologue to detect when the module has exited. + // + // Must be called with the isolate lock held. 
+  void registerWasmShutdownSignal(std::shared_ptr<v8::BackingStore> backingStore,
+      uint32_t signalOffset,
+      uint32_t terminatedOffset) const {
+    KJ_REQUIRE(
+        static_cast<size_t>(signalOffset) + WASM_SIGNAL_FIELD_BYTES <= backingStore->ByteLength(),
+        "__signal_address offset is out of bounds: need ", WASM_SIGNAL_FIELD_BYTES,
+        " bytes but memory is too small");
+    KJ_REQUIRE(static_cast<size_t>(terminatedOffset) + WASM_SIGNAL_FIELD_BYTES <=
+            backingStore->ByteLength(),
+        "__terminated_address offset is out of bounds: need ", WASM_SIGNAL_FIELD_BYTES,
+        " bytes but memory is too small");
+    wasmShutdownSignals.pushFront(
+        WasmShutdownSignal{kj::mv(backingStore), signalOffset, terminatedOffset});
+  }
+
+  // Filters out WASM shutdown signal entries where the module has exited (indicated by a
+  // non-zero value at the terminated address). This should be called from a GC prologue
+  // hook to allow linear memory to be reclaimed.
+  //
+  // Must be called with the isolate lock held.
+  void filterWasmShutdownSignals() const {
+    wasmShutdownSignals.filter(
+        [](const WasmShutdownSignal& signal) { return signal.isModuleListening(); });
+  }
+
+  // Returns the list of registered WASM shutdown signals. The list itself is signal-safe for
+  // reading (via iterate()), so a signal handler can safely walk it.
+  const AtomicList<WasmShutdownSignal>& getWasmShutdownSignals() const {
+    return wasmShutdownSignals;
+  }
+
   // Inserts a custom mark event named `name` into this isolate's perf event data stream. At
   // present, this is only implemented internally. Call this function from various APIs to be able
   // to correlate perf event data with usage of those APIs.
@@ -107,6 +149,16 @@ class IsolateLimitEnforcer: public kj::Refcounted {
   // coupled with our CPU time limiting system, so adding this function here is a path of least
   // resistance.
virtual void markPerfEvent(kj::LiteralStringConst name) const {}; + + private: + // WASM modules that have opted into receiving the "shut down" signal by exporting i32 globals + // named "__signal_address" and "__terminated_address". When the CPU time limiter fires + // NEARLY_OUT_OF_TIME, it writes 1 into each module's linear memory at the signal address. + // + // Marked mutable because registration happens through `const IsolateLimitEnforcer&` (the + // standard access pattern), and the AtomicList itself uses atomic stores for safe concurrent + // access from signal handlers on the same thread. + mutable AtomicList wasmShutdownSignals; }; // Abstract interface that enforces resource limits on a IoContext. diff --git a/src/workerd/io/wasm-shutdown-signal.h b/src/workerd/io/wasm-shutdown-signal.h new file mode 100644 index 00000000000..99a524fc5d9 --- /dev/null +++ b/src/workerd/io/wasm-shutdown-signal.h @@ -0,0 +1,141 @@ +// Copyright (c) 2017-2022 Cloudflare, Inc. +// Licensed under the Apache 2.0 license found in the LICENSE file or at: +// https://opensource.org/licenses/Apache-2.0 + +#pragma once + +#include + +#include + +#include +#include + +namespace workerd { + +// Byte size of each signal field in WASM linear memory (a single uint32). +constexpr size_t WASM_SIGNAL_FIELD_BYTES = sizeof(uint32_t); + +// Represents a single WASM module that has opted into receiving the "shut down" signal when CPU +// time is nearly exhausted. The module exports two i32 globals: +// +// "__signal_address" — address of a uint32 in linear memory. The runtime writes 1 here +// when CPU time is nearly exhausted. +// "__terminated_address" — address of a uint32 in linear memory. The WASM module writes a +// non-zero value here when it has exited and is no longer listening. +// The runtime checks this in a GC prologue hook and removes entries +// where terminated is non-zero, allowing the linear memory to be +// reclaimed. 
+struct WasmShutdownSignal {
+  // This reference is shared rather than weak so that we can be sure it is not being
+  // garbage collected when the signal handler runs. This memory gets cleaned up in a
+  // V8 GC prologue hook where we can atomically remove it from the signal list before
+  // freeing the memory.
+  std::shared_ptr<v8::BackingStore> backingStore;
+
+  // Offset into `backingStore` of the uint32 the runtime writes 1 to (__signal_address).
+  uint32_t signalByteOffset;
+
+  // Offset into `backingStore` of the uint32 the module writes to (__terminated_address).
+  uint32_t terminatedByteOffset;
+
+  // Returns true if the module is still listening for signals (terminated == 0).
+  // Returns false if the module has exited and this entry should be removed.
+  bool isModuleListening() const {
+    uint32_t terminated;
+    memcpy(&terminated, static_cast<const char*>(backingStore->Data()) + terminatedByteOffset,
+        sizeof(terminated));
+    return terminated == 0;
+  }
+};
+
+// A linked list type which is signal-safe (for reading), but not thread safe - it can handle
+// same-thread concurrency ONLY. Mutations (pushFront, filter) are not signal safe, but are
+// implemented such that they can be interrupted at any point by a signal handler, and the list will
+// still be in a valid state. This means that reading the list (iterate) IS signal safe.
+template <typename T>
+class AtomicList {
+ public:
+  struct Node {
+    T value;
+    Node* next;
+    template <typename... Args>
+    explicit Node(Args&&... args): value(kj::fwd<Args>(args)...),
+                                   next(nullptr) {}
+  };
+
+  AtomicList() {}
+
+  ~AtomicList() noexcept(false) {
+    // Detach the chain before freeing so an interrupting signal handler sees an empty list.
+    for (Node* node = __atomic_exchange_n(&head, nullptr, __ATOMIC_ACQ_REL); node != nullptr;) {
+      Node* doomed = node;
+      node = __atomic_load_n(&doomed->next, __ATOMIC_RELAXED);
+      delete doomed;
+    }
+  }
+
+  // Prepends a new node constructed from `args` at the front of the list
+  template <typename... Args>
+  void pushFront(Args&&... args) {
+    Node* node = new Node(kj::fwd<Args>(args)...);
+    __atomic_store_n(&node->next, __atomic_load_n(&head, __ATOMIC_RELAXED), __ATOMIC_RELAXED);
+    __atomic_store_n(&head, node, __ATOMIC_RELEASE);
+  }
+
+  // Removes all nodes for which `predicate(node.value)` returns false
+  template <typename Predicate>
+  void filter(Predicate&& predicate) {
+    Node** prev = &head;
+    Node* current = __atomic_load_n(prev, __ATOMIC_RELAXED);
+
+    while (current != nullptr) {
+      Node* next = __atomic_load_n(&current->next, __ATOMIC_RELAXED);
+
+      if (predicate(current->value)) {
+        prev = &current->next;
+      } else {
+        // Splice out `current` by pointing its predecessor at `next`. Release ordering ensures a
+        // signal handler that loads *prev with acquire sees a fully consistent successor chain.
+        __atomic_store_n(prev, next, __ATOMIC_RELEASE);
+        delete current;
+      }
+
+      current = next;
+    }
+  }
+
+  // Returns true if the list is empty. Signal safe.
+  bool isEmpty() const {
+    return __atomic_load_n(&head, __ATOMIC_ACQUIRE) == nullptr;
+  }
+
+  // Traverses the list, calling `func(node.value)` for each node. Signal safe.
+  template <typename Func>
+  void iterate(Func&& func) const {
+    Node* current = __atomic_load_n(&head, __ATOMIC_ACQUIRE);
+    while (current != nullptr) {
+      func(current->value);
+      current = __atomic_load_n(&current->next, __ATOMIC_ACQUIRE);
+    }
+  }
+
+ private:
+  Node* head = nullptr;
+
+  KJ_DISALLOW_COPY_AND_MOVE(AtomicList);
+};
+
+// Iterates a WasmShutdownSignal list and writes the shutdown signal (value 1) to each
+// registered memory location. This function is signal-safe.
+inline void writeWasmShutdownSignals(const AtomicList<WasmShutdownSignal>& signals) {
+  signals.iterate([](const WasmShutdownSignal& signal) {
+    // Signal-safe: BackingStore::Data() is a trivial getter; memcpy into mapped WASM memory
+    // is a plain store.
+    uint32_t value = 1;
+    memcpy(static_cast<char*>(signal.backingStore->Data()) + signal.signalByteOffset, &value,
+        sizeof(value));
+  });
+}
+
+}  // namespace workerd
diff --git a/src/workerd/io/worker.c++ b/src/workerd/io/worker.c++
index 2de15705961..1efb4fb1e57 100644
--- a/src/workerd/io/worker.c++
+++ b/src/workerd/io/worker.c++
@@ -34,6 +34,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -657,6 +658,9 @@ struct Worker::Isolate::Impl {

   void gcPrologue() {
     metrics.gcPrologue();
+    // Filter out WASM shutdown signal entries where the module has exited, allowing
+    // the linear memory to be reclaimed.
+    limitEnforcer.filterWasmShutdownSignals();
   }
   void gcEpilogue() {
     metrics.gcEpilogue();
@@ -1600,6 +1604,14 @@ void Worker::Isolate::setCpuLimitNearlyExceededCallback(kj::Function FAILED, "Python Workers Internal Error: CpuLimitNearlyExceededCallback already set"));
 }

+void Worker::Isolate::registerWasmShutdownSignal(std::shared_ptr<v8::BackingStore> backingStore,
+    uint32_t signalOffset,
+    uint32_t terminatedOffset) const {
+  // Register the WASM module for receiving shutdown signals. The signal handler will
+  // iterate the list unconditionally when CPU time is nearly exhausted.
+  limitEnforcer->registerWasmShutdownSignal(kj::mv(backingStore), signalOffset, terminatedOffset);
+}
+
 // EW-1319: Set WebAssembly.Module @@HasInstance
 //
 // The instanceof operator can be changed by setting the @@HasInstance method
@@ -1621,11 +1633,150 @@ void setWebAssemblyModuleHasInstance(jsg::Lock& lock, v8::Local<v8::Context> con
       module->DefineOwnProperty(context, v8::Symbol::GetHasInstance(lock.v8Isolate), function));
 }

+// Installs a shim around WebAssembly.instantiate and WebAssembly.Instance that hooks into the
+// shutdown signal if it exists
+void shimWebAssemblyInstantiate(jsg::Lock& lock, v8::Local<v8::Context> context) {
+  // We need to enter the context because this function compiles and executes JavaScript via
+  // v8::Script::Compile/Run. setupContext() is called before JSG_WITHIN_CONTEXT_SCOPE, so the
+  // context is not yet entered at this point.
+  v8::Context::Scope contextScope(context);
+  auto isolate = lock.v8Isolate;
+
+  auto webAssembly =
+      jsg::check(context->Global()->Get(context, jsg::v8StrIntern(isolate, "WebAssembly")))
+          .As<v8::Object>();
+
+  // Create a C++ callback that the JS shims call to register a {memory, signalOffset,
+  // terminatedOffset} tuple.
+  //   __registerWasmShutdownSignal(memory: WebAssembly.Memory, signalOffset: number,
+  //                                terminatedOffset: number)
+  auto registerCb = [](const v8::FunctionCallbackInfo<v8::Value>& info) {
+    jsg::Lock::from(info.GetIsolate()).withinHandleScope([&] {
+      auto isolate = info.GetIsolate();
+      if (info.Length() < 3 || !info[0]->IsWasmMemoryObject() || !info[1]->IsUint32() ||
+          !info[2]->IsUint32()) {
+        isolate->ThrowException(v8::Exception::TypeError(jsg::v8Str(isolate,
+            "registerWasmShutdownSignal: expected (WebAssembly.Memory, uint32, uint32)"_kj)));
+        return;
+      }
+      auto memory = info[0].As<v8::WasmMemoryObject>();
+      auto signalOffset = info[1].As<v8::Uint32>()->Value();
+      auto terminatedOffset = info[2].As<v8::Uint32>()->Value();
+      auto backingStore = memory->Buffer()->GetBackingStore();
+      KJ_IF_SOME(e, kj::runCatchingExceptions([&] {
+        Worker::Isolate::from(jsg::Lock::from(isolate))
+            .registerWasmShutdownSignal(kj::mv(backingStore), signalOffset, terminatedOffset);
+      })) {
+        // Convert KJ exception to a JavaScript Error object
+        auto message = jsg::v8Str(isolate, e.getDescription());
+        isolate->ThrowException(v8::Exception::Error(message));
+      }
+    });
+  };
+  auto registerFn = jsg::check(v8::Function::New(context, registerCb));
+
+  // Build the shim in JavaScript. It wraps both WebAssembly.instantiate (async) and
+  // WebAssembly.Instance (sync constructor). A shared helper inspects exports for the
+  // "__signal_address" / "__terminated_address" convention.
+  //
+  // The factory receives:
+  //   originalInstantiate - the original WebAssembly.instantiate function
+  //   originalInstance    - the original WebAssembly.Instance constructor
+  //   registerShutdown    - the C++ registration callback
+  //   wa                  - the WebAssembly object
+  auto shimSource = jsg::v8Str(isolate,
+      "(function(originalInstantiate, originalInstance, registerShutdown, wa) {\n"
+      "  // Find memory from exports or imports. Returns Memory instance or undefined.\n"
+      "  // When searching imports, only considers entries whose declared import kind is\n"
+      "  // 'memory' (via WebAssembly.Module.imports), so that a Memory passed as an\n"
+      "  // externref is not mistaken for the module's linear memory.\n"
+      "  function findMemory(instance, imports, module) {\n"
+      "    // First, check if memory is exported\n"
+      "    var memory = instance.exports['memory'];\n"
+      "    if (memory instanceof wa.Memory) return memory;\n"
+      "    // Otherwise, check the module's declared memory imports\n"
+      "    if (imports && module) {\n"
+      "      var descs = wa.Module.imports(module);\n"
+      "      for (var i = 0; i < descs.length; i++) {\n"
+      "        if (descs[i].kind === 'memory') {\n"
+      "          var ns = imports[descs[i].module];\n"
+      "          if (ns) {\n"
+      "            var mem = ns[descs[i].name];\n"
+      "            if (mem instanceof wa.Memory) return mem;\n"
+      "          }\n"
+      "        }\n"
+      "      }\n"
+      "    }\n"
+      "    return undefined;\n"
+      "  }\n"
+      "\n"
+      "  // i32 globals surface as signed JS numbers; map them back to the unsigned byte\n"
+      "  // offset so addresses at or above 2 GiB are not rejected by the uint32 check.\n"
+      "  function toOffset(v) {\n"
+      "    return (typeof v === 'number' && (v | 0) === v) ? v >>> 0 : v;\n"
+      "  }\n"
+      "\n"
+      "  function checkExports(instance, imports, module) {\n"
+      "    var exports = instance.exports;\n"
+      "    var signalGlobal = exports['__signal_address'];\n"
+      "    var terminatedGlobal = exports['__terminated_address'];\n"
+      "    if (signalGlobal instanceof wa.Global &&\n"
+      "        terminatedGlobal instanceof wa.Global) {\n"
+      "      var memory = findMemory(instance, imports, module);\n"
+      "      if (memory) {\n"
+      "        registerShutdown(memory, toOffset(signalGlobal.value),\n"
+      "            toOffset(terminatedGlobal.value));\n"
+      "      }\n"
+      "    }\n"
+      "  }\n"
+      "\n"
+      "  wa.instantiate = function instantiate(moduleOrBytes, imports) {\n"
+      "    // Forward every argument so optional trailing parameters (e.g. compile\n"
+      "    // options) reach the original implementation untouched.\n"
+      "    return originalInstantiate.apply(wa, arguments).then(function(result) {\n"
+      "      var instance = result.instance || result;\n"
+      "      var module = result.module || moduleOrBytes;\n"
+      "      checkExports(instance, imports, module);\n"
+      "      return result;\n"
+      "    });\n"
+      "  };\n"
+      "\n"
+      "  wa.Instance = function Instance(module, imports) {\n"
+      "    // Reflect.construct forwards new.target: subclasses keep their prototype and\n"
+      "    // a call without `new` throws a TypeError, as the native constructor does.\n"
+      "    var instance = Reflect.construct(originalInstance, arguments, new.target);\n"
+      "    checkExports(instance, imports, module);\n"
+      "    return instance;\n"
+      "  };\n"
+      "  wa.Instance.prototype = originalInstance.prototype;\n"
+      "  Object.defineProperty(wa.Instance.prototype, 'constructor',\n"
+      "    { value: wa.Instance, writable: true, configurable: true });\n"
+      "})\n"_kj);
+
+  auto shimFactory = jsg::check(v8::Script::Compile(context, shimSource));
+  auto shimFactoryResult = jsg::check(shimFactory->Run(context));
+  auto shimFactoryFn = shimFactoryResult.As<v8::Function>();
+
+  // Grab the originals before they are replaced.
+  auto instantiateKey = jsg::v8StrIntern(isolate, "instantiate");
+  auto instanceKey = jsg::v8StrIntern(isolate, "Instance");
+  auto originalInstantiate =
+      jsg::check(webAssembly->Get(context, instantiateKey)).As<v8::Function>();
+  auto originalInstance = jsg::check(webAssembly->Get(context, instanceKey)).As<v8::Function>();
+
+  // Call the factory — it mutates `wa` in place.
+  v8::Local<v8::Value> args[] = {originalInstantiate, originalInstance, registerFn, webAssembly};
+  jsg::check(shimFactoryFn->Call(context, context->Global(), 4, args));
+}
+
 void Worker::setupContext(
     jsg::Lock& lock, v8::Local<v8::Context> context, const LoggingOptions& loggingOptions) {
   // Set WebAssembly.Module @@HasInstance
   setWebAssemblyModuleHasInstance(lock, context);

+  // Shim WebAssembly.instantiate to detect modules exporting "__signal_address".
+  shimWebAssemblyInstantiate(lock, context);
+
   // We replace the default V8 console.log(), etc. methods, to give the worker access to
   // logged content, and log formatted values to stdout/stderr locally.
auto global = context->Global(); diff --git a/src/workerd/io/worker.h b/src/workerd/io/worker.h index 5294458ec55..e14f4e57d1f 100644 --- a/src/workerd/io/worker.h +++ b/src/workerd/io/worker.h @@ -27,9 +27,12 @@ #include #include +#include // for std::shared_ptr + namespace v8 { +class BackingStore; class Isolate; -} +} // namespace v8 namespace workerd { @@ -373,6 +376,12 @@ class Worker::Isolate: public kj::AtomicRefcounted { // Returns a reference to cpuLimitNearlyExceededCallback. Can't outlive the Isolate. kj::Maybe> getCpuLimitNearlyExceededCallback() const; + // Registers a WASM module's linear memory and offsets for receiving the "shut down" signal. + // See IsolateLimitEnforcer::registerWasmShutdownSignal() for details. + void registerWasmShutdownSignal(std::shared_ptr backingStore, + uint32_t signalOffset, + uint32_t terminatedOffset) const; + inline IsolateObserver& getMetrics() { return *metrics; }