Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/workerd/api/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ wd_cc_library(
deps = [
"//src/workerd/io",
"//src/workerd/util:state-machine",
"@capnp-cpp//src/kj/compat:kj-brotli",
"@nbytes",
],
)
Expand Down
171 changes: 154 additions & 17 deletions src/workerd/api/streams/compression.c++
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#include <workerd/util/ring-buffer.h>
#include <workerd/util/state-machine.h>

#include <brotli/decode.h>
#include <brotli/encode.h>

namespace workerd::api {
CompressionAllocator::CompressionAllocator(
kj::Arc<const jsg::ExternalMemoryTarget>&& externalMemoryTarget)
Expand Down Expand Up @@ -51,6 +54,21 @@ void CompressionAllocator::FreeForZlib(void* opaque, void* pointer) {

namespace {

enum class Format {
GZIP,
DEFLATE,
DEFLATE_RAW,
BROTLI,
};

static Format parseFormat(kj::StringPtr format) {
if (format == "gzip") return Format::GZIP;
if (format == "deflate") return Format::DEFLATE;
if (format == "deflate-raw") return Format::DEFLATE_RAW;
if (format == "brotli") return Format::BROTLI;
KJ_UNREACHABLE;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't really unreachable. This should probably throw a proper exception, just in case

}

class Context {
public:
enum class Mode {
Expand All @@ -74,20 +92,26 @@ class Context {
kj::Arc<const jsg::ExternalMemoryTarget>&& externalMemoryTarget)
: allocator(kj::mv(externalMemoryTarget)),
mode(mode),
strictCompression(flags)
strictCompression(flags),
format(parseFormat(format))

{
if (this->format == Format::BROTLI) {
initBrotli();
return;
}

// Configure allocator before any stream operations.
allocator.configure(&ctx);
int result = Z_OK;
switch (mode) {
case Mode::COMPRESS:
result = deflateInit2(&ctx, Z_DEFAULT_COMPRESSION, Z_DEFLATED, getWindowBits(format),
result = deflateInit2(&ctx, Z_DEFAULT_COMPRESSION, Z_DEFLATED, getWindowBits(this->format),
8, // memLevel = 8 is the default
Z_DEFAULT_STRATEGY);
break;
case Mode::DECOMPRESS:
result = inflateInit2(&ctx, getWindowBits(format));
result = inflateInit2(&ctx, getWindowBits(this->format));
break;
default:
KJ_UNREACHABLE;
Expand All @@ -96,6 +120,21 @@ class Context {
}

~Context() noexcept(false) {
if (format == Format::BROTLI) {
switch (mode) {
case Mode::COMPRESS:
if (brotliEncoderState != nullptr) {
BrotliEncoderDestroyInstance(brotliEncoderState);
}
break;
case Mode::DECOMPRESS:
if (brotliDecoderState != nullptr) {
BrotliDecoderDestroyInstance(brotliDecoderState);
}
break;
}
return;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels like an odd way to structure this when there's already a switch body here. Seems they can be coalesced.

switch (mode) {
case Mode::COMPRESS:
deflateEnd(&ctx);
Expand All @@ -109,11 +148,19 @@ class Context {
KJ_DISALLOW_COPY_AND_MOVE(Context);

void setInput(const void* in, size_t size) {
if (format == Format::BROTLI) {
brotliNextIn = reinterpret_cast<const uint8_t*>(in);
brotliAvailIn = size;
return;
}
ctx.next_in = const_cast<byte*>(reinterpret_cast<const byte*>(in));
ctx.avail_in = size;
}

Result pumpOnce(int flush) {
if (format == Format::BROTLI) {
return pumpBrotliOnce(flush);
}
ctx.next_out = buffer;
ctx.avail_out = sizeof(buffer);

Expand Down Expand Up @@ -151,11 +198,76 @@ class Context {
};
}

bool hasTrailingError() const {
return brotliTrailingError;
}

protected:
CompressionAllocator allocator;

private:
static int getWindowBits(kj::StringPtr format) {
void initBrotli() {
if (mode == Mode::COMPRESS) {
auto* instance = BrotliEncoderCreateInstance(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be more idiomatic to wrap the unit/cleanup into a smart pointer

CompressionAllocator::AllocForBrotli, CompressionAllocator::FreeForZlib, &allocator);
JSG_REQUIRE(instance != nullptr, Error, "Failed to initialize compression context."_kj);
brotliEncoderState = instance;
return;
}

auto* instance = BrotliDecoderCreateInstance(
CompressionAllocator::AllocForBrotli, CompressionAllocator::FreeForZlib, &allocator);
JSG_REQUIRE(instance != nullptr, Error, "Failed to initialize compression context."_kj);
brotliDecoderState = instance;
}

Result pumpBrotliOnce(int flush) {
uint8_t* nextOut = buffer;
size_t availOut = sizeof(buffer);

if (mode == Mode::COMPRESS) {
auto op = flush == Z_FINISH ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS;
auto ok = BrotliEncoderCompressStream(
brotliEncoderState, op, &brotliAvailIn, &brotliNextIn, &availOut, &nextOut, nullptr);
JSG_REQUIRE(ok == BROTLI_TRUE, TypeError, "Compression failed.");

bool shouldContinue = brotliAvailIn > 0 || BrotliEncoderHasMoreOutput(brotliEncoderState);
if (op == BROTLI_OPERATION_FINISH && !BrotliEncoderIsFinished(brotliEncoderState)) {
shouldContinue = true;
}

return Result{
.success = shouldContinue,
.buffer = kj::arrayPtr(buffer, sizeof(buffer) - availOut),
};
}

auto result = BrotliDecoderDecompressStream(
brotliDecoderState, &brotliAvailIn, &brotliNextIn, &availOut, &nextOut, nullptr);
JSG_REQUIRE(result != BROTLI_DECODER_RESULT_ERROR, TypeError, "Decompression failed.");

if (strictCompression == ContextFlags::STRICT) {
// Track trailing data so we can surface the error after buffered output drains.
if (BrotliDecoderIsFinished(brotliDecoderState) && brotliAvailIn > 0) {
brotliTrailingError = true;
}
if (flush == Z_FINISH && result == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT &&
availOut == sizeof(buffer)) {
JSG_FAIL_REQUIRE(
TypeError, "Called close() on a decompression stream with incomplete data");
}
}

bool shouldContinue = result == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT ||
BrotliDecoderHasMoreOutput(brotliDecoderState);

return Result{
.success = shouldContinue,
.buffer = kj::arrayPtr(buffer, sizeof(buffer) - availOut),
};
}

static int getWindowBits(Format format) {
// We use a windowBits value of 15 combined with the magic value
// for the compression format type. For gzip, the magic value is
// 16, so the value returned is 15 + 16. For deflate, the magic
Expand All @@ -165,12 +277,16 @@ class Context {
static constexpr auto GZIP = 16;
static constexpr auto DEFLATE = 15;
static constexpr auto DEFLATE_RAW = -15;
if (format == "gzip")
return DEFLATE + GZIP;
else if (format == "deflate")
return DEFLATE;
else if (format == "deflate-raw")
return DEFLATE_RAW;
switch (format) {
case Format::GZIP:
return DEFLATE + GZIP;
case Format::DEFLATE:
return DEFLATE;
case Format::DEFLATE_RAW:
return DEFLATE_RAW;
case Format::BROTLI:
KJ_UNREACHABLE;
}
KJ_UNREACHABLE;
}

Expand All @@ -180,6 +296,14 @@ class Context {

// For the eponymous compatibility flag
ContextFlags strictCompression;
Format format;
const uint8_t* brotliNextIn = nullptr;
size_t brotliAvailIn = 0;
// Brotli state structs are opaque, so kj::Own would require complete types.
BrotliEncoderState* brotliEncoderState = nullptr;
BrotliDecoderState* brotliDecoderState = nullptr;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely not a fan of the use of raw pointers in here. This should be using kj::Owns and kj::Arrays etc.

// Defer reporting of trailing brotli bytes until output is drained.
bool brotliTrailingError = false;
};

// Buffer class based on std::vector that erases data that has been read from it lazily to avoid
Expand Down Expand Up @@ -289,9 +413,18 @@ class CompressionStreamBase: public kj::Refcounted,
KJ_ASSERT(minBytes <= maxBytes);
// Re-throw any stored exception
throwIfException();
// If stream has ended normally and no buffered data, return EOF
if (isInTerminalState() && output.empty()) {
co_return static_cast<size_t>(0);
if (output.empty()) {
// For brotli we defer trailing-data errors until buffered output is drained.
if (context.hasTrailingError()) {
auto ex =
JSG_KJ_EXCEPTION(FAILED, TypeError, "Trailing bytes after end of compressed data");
cancelInternal(kj::cp(ex));
kj::throwFatalException(kj::mv(ex));
}
// If stream has ended normally and no buffered data, return EOF.
if (isInTerminalState()) {
co_return static_cast<size_t>(0);
}
}
// Active or terminal with data remaining
co_return co_await tryReadInternal(
Expand Down Expand Up @@ -659,8 +792,10 @@ kj::Rc<CompressionStreamBase<Context::Mode::DECOMPRESS>> createDecompressionStre
} // namespace

jsg::Ref<CompressionStream> CompressionStream::constructor(jsg::Lock& js, kj::String format) {
JSG_REQUIRE(format == "deflate" || format == "gzip" || format == "deflate-raw", TypeError,
"The compression format must be either 'deflate', 'deflate-raw' or 'gzip'.");
JSG_REQUIRE(
format == "deflate" || format == "gzip" || format == "deflate-raw" || format == "brotli",
TypeError,
"The compression format must be either 'deflate', 'deflate-raw', 'gzip', or 'brotli'.");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to call the function that parses these and have that throw the error directly to avoid the double string comparisons


// TODO(cleanup): Once the autogate is removed, we can delete CompressionStreamImpl
kj::Rc<CompressionStreamBase<Context::Mode::COMPRESS>> impl = createCompressionStreamImpl(
Expand All @@ -679,8 +814,10 @@ jsg::Ref<CompressionStream> CompressionStream::constructor(jsg::Lock& js, kj::St
}

jsg::Ref<DecompressionStream> DecompressionStream::constructor(jsg::Lock& js, kj::String format) {
JSG_REQUIRE(format == "deflate" || format == "gzip" || format == "deflate-raw", TypeError,
"The compression format must be either 'deflate', 'deflate-raw' or 'gzip'.");
JSG_REQUIRE(
format == "deflate" || format == "gzip" || format == "deflate-raw" || format == "brotli",
TypeError,
"The compression format must be either 'deflate', 'deflate-raw', 'gzip', or 'brotli'.");

kj::Rc<CompressionStreamBase<Context::Mode::DECOMPRESS>> impl =
createDecompressionStreamImpl(kj::mv(format),
Expand Down
4 changes: 2 additions & 2 deletions src/workerd/api/streams/compression.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class CompressionStream: public TransformStream {
JSG_INHERIT(TransformStream);

JSG_TS_OVERRIDE(extends TransformStream<ArrayBuffer | ArrayBufferView, Uint8Array> { constructor(format
: "gzip" | "deflate" | "deflate-raw");
: "gzip" | "deflate" | "deflate-raw" | "brotli");
});
}
};
Expand All @@ -59,7 +59,7 @@ class DecompressionStream: public TransformStream {
JSG_INHERIT(TransformStream);

JSG_TS_OVERRIDE(extends TransformStream<ArrayBuffer | ArrayBufferView, Uint8Array> { constructor(format
: "gzip" | "deflate" | "deflate-raw");
: "gzip" | "deflate" | "deflate-raw" | "brotli");
});
}
};
Expand Down
35 changes: 14 additions & 21 deletions src/workerd/api/streams/internal.c++
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,6 @@
namespace workerd::api {

namespace {
// Use this in places where the exception thrown would cause finalizers to run. Your exception
// will not go anywhere, but we'll log the exception message to the console until the problem this
// papers over is fixed.
[[noreturn]] void throwTypeErrorAndConsoleWarn(kj::StringPtr message) {
KJ_IF_SOME(context, IoContext::tryCurrent()) {
if (context.isInspectorEnabled()) {
context.logWarning(message);
}
}

kj::throwFatalException(kj::Exception(kj::Exception::Type::FAILED, __FILE__, __LINE__,
kj::str(JSG_EXCEPTION(TypeError) ": ", message)));
}

kj::Promise<void> pumpTo(ReadableStreamSource& input, WritableStreamSink& output, bool end) {
kj::byte buffer[4096]{};

Expand Down Expand Up @@ -1072,8 +1058,18 @@ jsg::Promise<void> WritableStreamInternalController::write(
return js.rejectedPromise<void>(errored.addRef(js));
}
KJ_CASE_ONEOF(writable, IoOwn<Writable>) {
// Byte streams must reject invalid chunks and error the stream so reads fail too.
auto rejectInvalidChunk = [&](kj::StringPtr message) {
auto reason = js.v8TypeError(message);
writable->abort(js.exceptionToKj(js.v8Ref(reason)));
doError(js, reason);
return js.rejectedPromise<void>(reason);
};

if (value == kj::none) {
return js.resolvedPromise();
return rejectInvalidChunk(
"This TransformStream is being used as a byte stream, but received an object of "
"non-ArrayBuffer/ArrayBufferView type on its writable side.");
}
auto chunk = KJ_ASSERT_NONNULL(value);

Expand All @@ -1090,16 +1086,14 @@ jsg::Promise<void> WritableStreamInternalController::write(
byteLength = view->ByteLength();
byteOffset = view->ByteOffset();
} else if (chunk->IsString()) {
// TODO(later): This really ought to return a rejected promise and not a sync throw.
// This case caused me a moment of confusion during testing, so I think it's worth
// a specific error message.
throwTypeErrorAndConsoleWarn(
return rejectInvalidChunk(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a behavioral change that might require a compat flag I believe

"This TransformStream is being used as a byte stream, but received a string on its "
"writable side. If you wish to write a string, you'll probably want to explicitly "
"UTF-8-encode it with TextEncoder.");
} else {
// TODO(later): This really ought to return a rejected promise and not a sync throw.
throwTypeErrorAndConsoleWarn(
return rejectInvalidChunk(
"This TransformStream is being used as a byte stream, but received an object of "
"non-ArrayBuffer/ArrayBufferView type on its writable side.");
}
Expand All @@ -1116,8 +1110,7 @@ jsg::Promise<void> WritableStreamInternalController::write(
auto ptr =
kj::ArrayPtr<kj::byte>(static_cast<kj::byte*>(store->Data()) + byteOffset, byteLength);
if (store->IsShared()) {
throwTypeErrorAndConsoleWarn(
"Cannot construct an array buffer from a shared backing store");
return rejectInvalidChunk("Cannot construct an array buffer from a shared backing store");
}
queue.push_back(
WriteEvent{.outputLock = IoContext::current().waitForOutputLocksIfNecessaryIoOwn(),
Expand Down
1 change: 1 addition & 0 deletions src/wpt/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ wpt_test(

wpt_test(
name = "compression",
size = "enormous",
config = "compression-test.ts",
start_server = True,
target_compatible_with = select({
Expand Down
Loading
Loading