Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package local_resolver

import (
"context"
"testing"

"github.com/spotify/confidence-resolver/openfeature-provider/go/confidence/internal/proto/wasm"
tu "github.com/spotify/confidence-resolver/openfeature-provider/go/confidence/internal/testutil"
)

// TestWasmMemoryStableOnRepeatedResolveCalls verifies that WASM linear memory
// does not grow unboundedly when resolving flags repeatedly. Each resolve_flags
// call triggers a current_time() host call from the guest. If the host function
// does not free the guest's request allocation, memory leaks accumulate and
// eventually force WASM memory.grow.
func TestWasmMemoryStableOnRepeatedResolveCalls(t *testing.T) {
	factory := NewWasmResolverFactory(NoOpLogSink)
	defer factory.Close(context.Background())

	resolver := factory.New()
	defer resolver.Close(context.Background())

	wasmResolver := resolver.(*WasmResolver)

	testState := tu.LoadTestResolverState(t)
	testAcctID := tu.LoadTestAccountID(t)

	if err := wasmResolver.SetResolverState(&wasm.SetResolverStateRequest{
		State:     testState,
		AccountId: testAcctID,
	}); err != nil {
		t.Fatalf("Failed to set resolver state: %v", err)
	}

	request := tu.CreateResolveProcessRequest(tu.CreateTutorialFeatureRequest())

	// Both phases run the same number of calls; a single constant keeps the
	// warmup and measurement loops from drifting apart.
	const iterations = 50_000

	// Warm up: let the allocator settle and any one-time growth complete.
	for i := 0; i < iterations; i++ {
		if _, err := wasmResolver.ResolveProcess(request); err != nil {
			t.Fatalf("ResolveProcess failed during warmup: %v", err)
		}
		// Flush periodically so buffered log entries don't masquerade as a leak.
		if i%1000 == 0 {
			wasmResolver.FlushAllLogs()
		}
	}

	memBefore := wasmResolver.instance.Memory().Size()

	// Run resolves. Each call triggers current_time() in the guest which
	// allocates a request in WASM memory. A leak here causes linear growth.
	for i := 0; i < iterations; i++ {
		if _, err := wasmResolver.ResolveProcess(request); err != nil {
			t.Fatalf("ResolveProcess failed at iteration %d: %v", i, err)
		}
		if i%1000 == 0 {
			wasmResolver.FlushAllLogs()
		}
	}

	memAfter := wasmResolver.instance.Memory().Size()

	if memAfter > memBefore {
		t.Errorf("WASM memory grew from %d to %d bytes (%d bytes / %d pages) after %d resolve calls — indicates a leak in host function memory management",
			memBefore, memAfter, memAfter-memBefore, (memAfter-memBefore)/65536, iterations)
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ private <T extends Message> T consumeResponse(int addr, ParserFn<T> codec) {

private <T extends Message> T consumeRequest(int addr, ParserFn<T> codec) {
try {
final Messages.Request request = Messages.Request.parseFrom(consume(addr));
// Read without freeing — the WASM guest frees its own request allocation
// in call_sync_host after the host function returns.
final Messages.Request request = Messages.Request.parseFrom(readBytes(addr));
return codec.apply(request.getData().toByteArray());
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(e);
Expand All @@ -309,6 +311,12 @@ private int transferResponseError(String error) {
return transfer(wrapperBytes);
}

// Reads the length-prefixed payload at {@code addr} without deallocating it:
// the u32 stored 4 bytes before {@code addr}, minus the 4-byte prefix itself,
// gives the payload length.
private byte[] readBytes(int addr) {
    final Memory memory = instance.memory();
    final long prefixedSize = memory.readU32(addr - 4);
    final int payloadLength = (int) (prefixedSize - 4L);
    return memory.readBytes(addr, payloadLength);
}

private byte[] consume(int addr) {
final Memory mem = instance.memory();
final int len = (int) (mem.readU32(addr - 4) - 4L);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package com.spotify.confidence.sdk;

import static org.junit.jupiter.api.Assertions.assertEquals;

import com.dylibso.chicory.runtime.Instance;
import com.google.protobuf.Struct;
import com.google.protobuf.util.Structs;
import com.google.protobuf.util.Values;
import com.spotify.confidence.sdk.flags.resolver.v1.ResolveFlagsRequest;
import com.spotify.confidence.sdk.flags.resolver.v1.ResolveProcessRequest;
import java.lang.reflect.Field;
import java.util.List;
import org.junit.jupiter.api.Test;

/**
* Regression test for WASM memory leak in host functions. Each resolve_flags call triggers a
* current_time() host call which allocates a request in WASM memory. If the host doesn't free it,
* memory leaks ~20 bytes per call and eventually forces memory.grow.
*/
class WasmMemoryLeakTest {

  /** Resolve calls per phase (warmup and measurement). */
  private static final int ITERATIONS = 50_000;

  /** Flush logs every this many calls so buffered log entries don't skew the measurement. */
  private static final int FLUSH_INTERVAL = 1000;

  /** Size of one WASM linear-memory page in bytes. */
  private static final long WASM_PAGE_BYTES = 65536L;

  /** Reads the resolver's current WASM memory size in pages via reflection on its private field. */
  private static int getWasmMemoryPages(WasmLocalResolver resolver) {
    try {
      Field instanceField = WasmLocalResolver.class.getDeclaredField("instance");
      instanceField.setAccessible(true);
      Instance instance = (Instance) instanceField.get(resolver);
      return instance.memory().pages();
    } catch (ReflectiveOperationException e) {
      throw new RuntimeException("Failed to access WASM memory via reflection", e);
    }
  }

  /** Runs {@link #ITERATIONS} resolve calls against {@code resolver}, flushing logs periodically. */
  private static void runResolves(WasmLocalResolver resolver, ResolveProcessRequest request) {
    for (int i = 0; i < ITERATIONS; i++) {
      resolver.resolveProcess(request).toCompletableFuture().join();
      if (i % FLUSH_INTERVAL == 0) {
        resolver.flushAllLogs();
      }
    }
  }

  @Test
  void wasmMemoryStableOnRepeatedResolveCalls() {
    WasmLocalResolver resolver = new WasmLocalResolver(request -> {});
    resolver.setResolverState(ResolveTest.exampleStateBytes, "account", null);

    ResolveProcessRequest request =
        ResolveProcessRequest.newBuilder()
            .setDeferredMaterializations(
                ResolveFlagsRequest.newBuilder()
                    .addAllFlags(List.of("flags/flag-1"))
                    .setClientSecret(ResolveTest.secret.getSecret())
                    .setEvaluationContext(
                        Structs.of(
                            "targeting_key",
                            Values.of("user-123"),
                            "bar",
                            Values.of(Struct.newBuilder().build())))
                    .setApply(true)
                    .build())
            .build();

    // Warm up to settle one-time allocations
    runResolves(resolver, request);

    final int pagesBefore = getWasmMemoryPages(resolver);

    runResolves(resolver, request);

    final int pagesAfter = getWasmMemoryPages(resolver);

    // Supplier overload: the failure message is only formatted when the assertion fails.
    assertEquals(
        pagesBefore,
        pagesAfter,
        () ->
            String.format(
                "WASM memory grew from %d to %d pages (%d bytes leaked) — "
                    + "host function is not freeing guest request allocations",
                pagesBefore, pagesAfter, (pagesAfter - pagesBefore) * WASM_PAGE_BYTES));
  }
}
48 changes: 48 additions & 0 deletions openfeature-provider/js/src/WasmResolver.memory.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import { describe, it, expect } from 'vitest';
import { UnsafeWasmResolver } from './WasmResolver';
import { readFileSync } from 'node:fs';
import { ResolveProcessRequest } from './proto/confidence/wasm/wasm_api';

const moduleBytes = readFileSync(__dirname + '/../../../wasm/confidence_resolver.wasm');
const stateBytes = readFileSync(__dirname + '/../../../wasm/resolver_state.pb');

const module = new WebAssembly.Module(moduleBytes);
const CLIENT_SECRET = 'mkjJruAATQWjeY7foFIWfVAcBWnci2YF';

const RESOLVE_REQUEST: ResolveProcessRequest = {
deferredMaterializations: {
flags: ['flags/tutorial-feature'],
clientSecret: CLIENT_SECRET,
apply: true,
evaluationContext: {
targeting_key: 'tutorial_visitor',
visitor_id: 'tutorial_visitor',
},
},
};

const SET_STATE_REQUEST = { state: stateBytes, accountId: 'confidence-test' };

describe('wasm memory stability', () => {
it('should not leak memory on repeated resolve calls', { timeout: 30_000 }, () => {
const resolver = new UnsafeWasmResolver(module);
resolver.setResolverState(SET_STATE_REQUEST);

// Warm up to settle one-time allocations
for (let i = 0; i < 50_000; i++) {
resolver.resolveProcess(RESOLVE_REQUEST);
if (i % 1000 === 0) resolver.flushLogs();
}

const memBefore = (resolver as any).exports.memory.buffer.byteLength;

for (let i = 0; i < 50_000; i++) {
resolver.resolveProcess(RESOLVE_REQUEST);
if (i % 1000 === 0) resolver.flushLogs();
}

const memAfter = (resolver as any).exports.memory.buffer.byteLength;

expect(memAfter).toBe(memBefore);
});
});
49 changes: 36 additions & 13 deletions openfeature-provider/python/tests/test_wasm_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,29 +162,52 @@ def test_flush_assigned_returns_bytes(
class TestMemoryManagement:
"""Test memory allocation and deallocation."""

@staticmethod
def _build_request(client_secret: str) -> wasm_api_pb2.ResolveProcessRequest:
    """Build a ResolveProcessRequest for the test flag with a fixed targeting key."""
    context = struct_pb2.Struct()
    context.fields["targeting_key"].string_value = "user-123"
    flags_request = api_pb2.ResolveFlagsRequest(
        flags=[TEST_FLAG_NAME],
        client_secret=client_secret,
        evaluation_context=context,
    )
    return wasm_api_pb2.ResolveProcessRequest(
        deferred_materializations=flags_request,
    )

def test_multiple_resolves_dont_leak_memory(
    self,
    wasm_bytes: bytes,
    test_resolver_state: bytes,
    test_account_id: str,
    test_client_secret: str,
) -> None:
    """WASM memory must not grow on repeated resolves.

    Each resolve_flags call triggers a current_time() host call which
    allocates a request in WASM memory. If the host doesn't free it,
    memory leaks ~20 bytes per call and eventually forces memory.grow.
    """
    resolver = WasmResolver(wasm_bytes)
    resolver.set_resolver_state(test_resolver_state, test_account_id)
    request = self._build_request(test_client_secret)

    # Warm up to settle one-time allocations
    for i in range(50_000):
        resolver.resolve_process(request)
        # Flush periodically so buffered log entries don't skew the measurement.
        if i % 1000 == 0:
            resolver.flush_logs()

    pages_before = resolver._memory.size(resolver._store)

    for i in range(50_000):
        resolver.resolve_process(request)
        if i % 1000 == 0:
            resolver.flush_logs()

    pages_after = resolver._memory.size(resolver._store)

    assert pages_after == pages_before, (
        f"WASM memory grew from {pages_before} to {pages_after} pages "
        f"({(pages_after - pages_before) * 65536} bytes leaked) — "
        f"host function is not freeing guest request allocations"
    )
3 changes: 3 additions & 0 deletions wasm-msg/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ where
{
let input_ptr = message::transfer_request(request);
let output_ptr = unsafe { host_func(input_ptr) };
// Free the request we allocated — the host has already read it.
// This mirrors call_sync_guest which frees its input via consume_request.
crate::memory::wasm_msg_free(input_ptr);
if output_ptr.is_null() {
return Err(String::from("Host function returned null pointer"));
}
Expand Down
Loading