Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/workerd/api/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@ kj_test(
deps = [
"//src/workerd/io",
"//src/workerd/tests:test-fixture",
"@zstd",
],
)

Expand Down Expand Up @@ -687,6 +688,7 @@ kj_test(
"//src/workerd/io",
"//src/workerd/jsg",
"//src/workerd/tests:test-fixture",
"@capnp-cpp//src/kj/compat:kj-gzip",
],
)

Expand Down
76 changes: 76 additions & 0 deletions src/workerd/api/streams/readable-source-test.c++
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <workerd/util/stream-utils.h>

#include <kj/async-io.h>
#include <kj/compat/gzip.h>
#include <kj/test.h>

// We Thank Claude for Tests.
Expand Down Expand Up @@ -973,6 +974,81 @@ KJ_TEST("Gzip encoded stream (pumpTo different encoding)") {
KJ_ASSERT(inner.data == expected);
}

KJ_TEST("Zstd encoded stream") {
  TestFixture fixture;

  // Fixture bytes: the zstd-compressed form of the text "some data to zstd".
  static constexpr kj::byte data[] = {40, 181, 47, 253, 36, 17, 137, 0, 0, 115, 111, 109, 101, 32,
    100, 97, 116, 97, 32, 116, 111, 32, 122, 115, 116, 100, 89, 232, 89, 209};

  // Hand the in-memory stream straight to the ZSTD-encoded source.
  auto source = newEncodedReadableSource(rpc::StreamEncoding::ZSTD, newMemoryInputStream(data));

  // Draining the source must produce the decoded plaintext.
  fixture.runInIoContext([&](const auto& environment) -> kj::Promise<void> {
    auto allBytes = co_await source->readAllBytes(kj::maxValue);
    KJ_ASSERT(allBytes == "some data to zstd"_kjb);
  });
}

KJ_TEST("Zstd encoded stream (pumpTo)") {
  TestFixture fixture;

  // Fixture bytes: the zstd-compressed form of the text "some data to zstd".
  static constexpr kj::byte data[] = {40, 181, 47, 253, 36, 17, 137, 0, 0, 115, 111, 109, 101, 32,
    100, 97, 116, 97, 32, 116, 111, 32, 122, 115, 116, 100, 89, 232, 89, 209};
  auto source = newEncodedReadableSource(rpc::StreamEncoding::ZSTD, newMemoryInputStream(data));

  MockWritableSink sink;

  // Pump the whole source into the sink and wait for any deferred proxying to finish.
  fixture.runInIoContext([&](const auto& environment) -> kj::Promise<void> {
    auto pumping = source->pumpTo(sink, EndAfterPump::YES);
    co_await environment.context.waitForDeferredProxy(kj::mv(pumping));
  });

  // The sink must have received the decoded plaintext, not the compressed bytes.
  KJ_ASSERT(sink.writtenData == "some data to zstd"_kjb);
}

KJ_TEST("Zstd encoded stream (pumpTo same encoding)") {
  TestFixture fixture;

  // Fixture bytes: the zstd-compressed form of the text "some data to zstd".
  static const kj::byte data[] = {40, 181, 47, 253, 36, 17, 137, 0, 0, 115, 111, 109, 101, 32, 100,
    97, 116, 97, 32, 116, 111, 32, 122, 115, 116, 100, 89, 232, 89, 209};
  auto source = newEncodedReadableSource(rpc::StreamEncoding::ZSTD, newMemoryInputStream(data));

  // `inner` stays on the stack so it can be inspected afterwards; the sink only receives a
  // non-owning kj::Own (NullDisposer) pointing at it.
  MemoryAsyncOutputStream inner;
  auto sink = newEncodedWritableSink(rpc::StreamEncoding::ZSTD,
      kj::Own<MemoryAsyncOutputStream>(&inner, kj::NullDisposer::instance));

  fixture.runInIoContext([&](const auto& environment) -> kj::Promise<void> {
    auto pumping = source->pumpTo(*sink, EndAfterPump::YES);
    co_await environment.context.waitForDeferredProxy(kj::mv(pumping));
  });

  // Source and sink share the ZSTD encoding, so the bytes are expected to arrive untouched
  // (no decompress/recompress round trip).
  KJ_ASSERT(inner.data == data);
}

KJ_TEST("Zstd encoded stream (pumpTo different encoding)") {
  TestFixture fixture;

  // Fixture bytes: the zstd-compressed form of the text "some data to zstd".
  static const kj::byte data[] = {40, 181, 47, 253, 36, 17, 137, 0, 0, 115, 111, 109, 101, 32, 100,
    97, 116, 97, 32, 116, 111, 32, 122, 115, 116, 100, 89, 232, 89, 209};
  auto source = newEncodedReadableSource(rpc::StreamEncoding::ZSTD, newMemoryInputStream(data));

  // `inner` stays on the stack so it can be inspected afterwards; the GZIP-encoded sink only
  // receives a non-owning kj::Own (NullDisposer) pointing at it.
  MemoryAsyncOutputStream inner;
  auto sink = newEncodedWritableSink(rpc::StreamEncoding::GZIP,
      kj::Own<MemoryAsyncOutputStream>(&inner, kj::NullDisposer::instance));

  fixture.runInIoContext([&](const auto& environment) -> kj::Promise<void> {
    auto pumping = source->pumpTo(*sink, EndAfterPump::YES);
    co_await environment.context.waitForDeferredProxy(kj::mv(pumping));

    // Whatever landed in `inner` must gunzip back to the original plaintext.
    auto gzipped = newMemoryInputStream(inner.data);
    kj::GzipAsyncInputStream gunzip(*gzipped);
    auto text = co_await gunzip.readAllText(kj::maxValue);
    KJ_ASSERT(text == "some data to zstd"_kj);
  });
}

// ======================================================================================
// Adaptive Pump Behavior Tests
// These tests verify the adaptive pump heuristics without relying on timing.
Expand Down
64 changes: 64 additions & 0 deletions src/workerd/api/streams/readable-source.c++
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <kj/async-io.h>
#include <kj/compat/brotli.h>
#include <kj/compat/gzip.h>
#include <zstd.h>

#include <bit>

Expand Down Expand Up @@ -665,6 +666,63 @@ class NoDeferredProxySource final: public ReadableSourceWrapper {
IoContext& ioctx;
};

// An AsyncInputStream adapter that decompresses zstd data read from `inner`.
//
// Compressed bytes are pulled from `inner` in chunks of up to IN_BUFFER_SIZE and run through
// ZSTD_decompressStream(), which writes decoded bytes directly into the caller's buffer.
// Reaching the end of `inner` anywhere other than immediately after a completed frame is
// reported as an error.
//
// The stream holds a reference to `inner`; the caller must keep `inner` alive for as long as
// this object exists.
class ZstdAsyncInputStream final: public kj::AsyncInputStream {
public:
  explicit ZstdAsyncInputStream(kj::AsyncInputStream& inner)
      : inner(inner),
        dctx(ZSTD_createDCtx()) {
    KJ_ASSERT(dctx != nullptr, "failed to allocate ZSTD_DCtx");
  }
  ~ZstdAsyncInputStream() noexcept(false) {
    ZSTD_freeDCtx(dctx);
  }
  KJ_DISALLOW_COPY_AND_MOVE(ZstdAsyncInputStream);

  // Reads between `minBytes` and `maxBytes` decompressed bytes into `buffer`. Resolves to the
  // number of bytes produced, which is less than `minBytes` only when the compressed stream
  // ended cleanly. Throws "zstd decompression error" on corrupt input and
  // "zstd-compressed stream ended prematurely" on truncated input.
  kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    return readImpl(reinterpret_cast<kj::byte*>(buffer), minBytes, maxBytes, 0);
  }

private:
  kj::AsyncInputStream& inner;
  ZSTD_DCtx* dctx;

  // True only when the most recent ZSTD_decompressStream() call reported a fully decoded and
  // fully flushed frame, i.e. the only state in which EOF on `inner` is acceptable.
  bool atValidEndpoint = false;

  // True when the most recent ZSTD_decompressStream() call filled the caller's buffer without
  // finishing the frame. In that state the decoder may still hold decoded bytes internally, so
  // it must be called again (even with no new input) before we conclude more input is needed.
  bool outputPending = false;

  static constexpr size_t IN_BUFFER_SIZE = 16384;
  kj::byte inBuffer[IN_BUFFER_SIZE];
  size_t inAvail = 0;  // Number of unconsumed compressed bytes in inBuffer.
  size_t inPos = 0;    // Offset of the first unconsumed byte in inBuffer.

  // Decompresses into out[alreadyRead, maxBytes) until at least `minBytes` total bytes have been
  // produced, refilling inBuffer from `inner` whenever the decoder runs out of input.
  kj::Promise<size_t> readImpl(
      kj::byte* out, size_t minBytes, size_t maxBytes, size_t alreadyRead) {
    while ((inAvail > 0 || outputPending) && alreadyRead < maxBytes) {
      ZSTD_inBuffer input = {inBuffer + inPos, inAvail, 0};
      ZSTD_outBuffer output = {out, maxBytes, alreadyRead};
      size_t result = ZSTD_decompressStream(dctx, &output, &input);
      inPos += input.pos;
      inAvail -= input.pos;
      alreadyRead = output.pos;
      KJ_REQUIRE(!ZSTD_isError(result), "zstd decompression error", ZSTD_getErrorName(result));

      // A zero result means a frame was completely decoded and flushed. Recompute the flag on
      // every call so that a completed earlier frame can never vouch for a later, truncated one.
      atValidEndpoint = result == 0;

      // When the output buffer is exactly full and the frame is unfinished, zstd cannot tell us
      // whether more decoded bytes are waiting internally; remember to ask again.
      outputPending = result != 0 && output.pos == output.size;

      if (alreadyRead >= minBytes) return alreadyRead;
    }
    if (alreadyRead >= minBytes) return alreadyRead;

    // The decoder has consumed everything we had and still owes the caller bytes: refill.
    inPos = 0;
    inAvail = 0;
    return inner.tryRead(inBuffer, 1, IN_BUFFER_SIZE)
        .then([this, out, minBytes, maxBytes, alreadyRead](size_t n) -> kj::Promise<size_t> {
      if (n == 0) {
        // EOF on the compressed stream is only legal right after a completed frame.
        KJ_REQUIRE(atValidEndpoint, "zstd-compressed stream ended prematurely");
        return alreadyRead;
      }
      inAvail = n;
      inPos = 0;
      return readImpl(out, minBytes, maxBytes, alreadyRead);
    });
  }
};

// A ReadableSource implementation that lazily wraps an inner Gzip-, Brotli-, or Zstd-encoded
// AsyncInputStream. The decoding wrapper is not created up front: it is applied selectively
// and lazily, when the first read() is called or when pumpTo() is called.
Expand Down Expand Up @@ -694,6 +752,9 @@ class EncodedAsyncInputStream final: public ReadableSourceImpl {
{"brotli compression failed"_kj, "Brotli compression failed."},
{"brotli compressed stream ended prematurely"_kj,
"Brotli compressed stream ended prematurely."},
{"zstd decompression error"_kj, "Zstd decompression failed."},
{"zstd-compressed stream ended prematurely"_kj,
"Zstd compressed stream ended prematurely."},
})) {
kj::throwFatalException(kj::mv(translated));
} else {
Expand Down Expand Up @@ -739,6 +800,9 @@ class EncodedAsyncInputStream final: public ReadableSourceImpl {
case rpc::StreamEncoding::BROTLI: {
return kj::heap<kj::BrotliAsyncInputStream>(*inner).attach(kj::mv(inner));
}
case rpc::StreamEncoding::ZSTD: {
return kj::heap<ZstdAsyncInputStream>(*inner).attach(kj::mv(inner));
}
}
KJ_UNREACHABLE;
}
Expand Down
40 changes: 40 additions & 0 deletions src/workerd/api/streams/writable-sink-test.c++
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,46 @@ KJ_TEST("Gzip-encoding sink (identity)") {
KJ_ASSERT(inner.data == kj::arrayPtr(check, sizeof(check)));
}

KJ_TEST("Zstd-encoding sink") {
  // The sink cannot produce zstd output on its own, so a write through a ZSTD-encoded sink is
  // expected to fail with a descriptive error.
  TestFixture fixture;

  // The sink only receives a non-owning kj::Own (NullDisposer) to the stack-allocated stream.
  MemoryAsyncOutputStream inner;
  auto sink = newEncodedWritableSink(rpc::StreamEncoding::ZSTD,
      kj::Own<MemoryAsyncOutputStream>(&inner, kj::NullDisposer::instance));

  fixture.runInIoContext([&](const auto& environment) -> kj::Promise<void> {
    bool threw = false;
    try {
      co_await sink->write("some data to zstd"_kjb);
    } catch (kj::Exception& e) {
      threw = true;
      KJ_ASSERT(kj::StringPtr(e.getDescription()).contains("zstd output compression is not supported"));
    }
    // Falling through the try block without an exception is itself a failure.
    KJ_ASSERT(threw, "expected write() to throw for unsupported zstd compression");
  });
}

KJ_TEST("Zstd-encoding sink (identity)") {
  // Payload written to the sink; the test expects exactly these bytes to come out the far side.
  static const kj::byte check[] = {40, 181, 47, 253, 36, 17, 137, 0, 0, 115, 111, 109, 101, 32,
    100, 97, 116, 97, 32, 116, 111, 32, 122, 115, 116, 100, 89, 232, 89, 209};

  TestFixture fixture;

  // The sink only receives a non-owning kj::Own (NullDisposer) to the stack-allocated stream.
  MemoryAsyncOutputStream inner;
  auto sink = newEncodedWritableSink(rpc::StreamEncoding::ZSTD,
      kj::Own<MemoryAsyncOutputStream>(&inner, kj::NullDisposer::instance));

  // After the sink disowns encoding responsibility, writes should pass through unmodified.
  sink->disownEncodingResponsibility();

  fixture.runInIoContext([&](const auto& environment) -> kj::Promise<void> {
    co_await sink->write(check);
    co_await sink->end();
  });

  KJ_ASSERT(inner.data == kj::arrayPtr(check, sizeof(check)));
}

// ======================================================================================
// IoContext-aware WritableSinkWrapper Tests

Expand Down
4 changes: 4 additions & 0 deletions src/workerd/api/streams/writable-sink.c++
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <kj/async-io.h>
#include <kj/compat/brotli.h>
#include <kj/compat/gzip.h>
#include <zstd.h>

namespace workerd::api::streams {

Expand Down Expand Up @@ -226,6 +227,9 @@ class EncodedAsyncOutputStream final: public WritableSinkImpl {
case rpc::StreamEncoding::IDENTITY: {
return setStream(kj::mv(inner));
}
case rpc::StreamEncoding::ZSTD: {
KJ_FAIL_REQUIRE("zstd output compression is not supported; use encodeResponseBody: manual");
}
}
KJ_UNREACHABLE;
}
Expand Down
41 changes: 41 additions & 0 deletions src/workerd/api/system-streams-test.c++
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <workerd/tests/test-fixture.h>

#include <kj/test.h>
#include <zstd.h>

namespace workerd::api {
namespace {
Expand Down Expand Up @@ -51,5 +52,45 @@ KJ_TEST("EncodedAsyncInputStream cancel with pending read on AsyncPipe") {
});
}

// End-to-end check that a system stream created with StreamEncoding::ZSTD returns the
// decompressed plaintext when its inner stream supplies zstd-compressed bytes. The compressed
// input is generated at test time with zstd's one-shot API rather than hard-coded.
KJ_TEST("ZstdAsyncInputStream decompresses correctly") {
TestFixture fixture;
fixture.runInIoContext([](const TestFixture::Environment& env) -> kj::Promise<void> {
constexpr kj::StringPtr plaintext = "hello zstd"_kj;

// Compress synchronously using the zstd one-shot API.
// ZSTD_compressBound() gives the worst-case compressed size, so the scratch array is always
// large enough; only the first `compressedSize` bytes of it are meaningful afterwards.
size_t bound = ZSTD_compressBound(plaintext.size());
auto compressed = kj::heapArray<kj::byte>(bound);
size_t compressedSize = ZSTD_compress(compressed.begin(), compressed.size(),
plaintext.begin(), plaintext.size(), ZSTD_CLEVEL_DEFAULT);
KJ_REQUIRE(!ZSTD_isError(compressedSize), ZSTD_getErrorName(compressedSize));

// Wrap the compressed bytes in a synchronous AsyncInputStream.
// ArrayStream ignores the minBytes argument: each tryRead() copies up to `max` of the remaining
// bytes and resolves immediately, returning 0 once the array is exhausted.
struct ArrayStream: kj::AsyncInputStream {
kj::Array<kj::byte> data;
size_t pos = 0;
ArrayStream(kj::Array<kj::byte> d): data(kj::mv(d)) {}
virtual ~ArrayStream() = default;
kj::Promise<size_t> tryRead(void* buf, size_t, size_t max) override {
size_t n = kj::min(max, data.size() - pos);
memcpy(buf, data.begin() + pos, n);
pos += n;
return n;
}
};
// Copy just the meaningful prefix of the scratch array into the stream's own storage.
auto inner = kj::heap<ArrayStream>(kj::heapArray(compressed.slice(0, compressedSize)));

// Decode through a ZSTD-encoded system stream.
auto stream = newSystemStream(kj::mv(inner), StreamEncoding::ZSTD, env.context);
auto outBuf = kj::heapArray<kj::byte>(64);
size_t expectedSize = plaintext.size();
// Request at least the full plaintext length so a single read must deliver everything.
// NOTE: the tryRead() call (including outBuf.begin()/outBuf.size()) is evaluated before the
// continuation below is constructed, so moving `outBuf` and `stream` into the lambda is safe;
// the captures keep both alive until the continuation runs.
return stream->tryRead(outBuf.begin(), expectedSize, outBuf.size())
.then([outBuf = kj::mv(outBuf), stream = kj::mv(stream), expectedSize](size_t n) {
KJ_EXPECT(n == expectedSize);
KJ_EXPECT(
kj::StringPtr(reinterpret_cast<const char*>(outBuf.begin()), n) == "hello zstd"_kj);
});
});
}

} // namespace
} // namespace workerd::api
Loading
Loading