diff --git a/images/container-client-test/app.js b/images/container-client-test/app.js index 8190141809b..e9ca5a8d79b 100644 --- a/images/container-client-test/app.js +++ b/images/container-client-test/app.js @@ -9,6 +9,14 @@ const server = createServer(function (req, res) { return; } + // Endpoint to get the PID of the current process + if (req.url === '/pid') { + res.writeHead(200, { 'Content-Type': 'text/plain' }); + res.write(String(process.pid)); + res.end(); + return; + } + res.writeHead(200, { 'Content-Type': 'text/plain' }); res.write('Hello World!'); res.end(); diff --git a/src/workerd/api/container.c++ b/src/workerd/api/container.c++ index 24137471ff0..a8793aa54b0 100644 --- a/src/workerd/api/container.c++ +++ b/src/workerd/api/container.c++ @@ -49,16 +49,28 @@ void Container::start(jsg::Lock& js, jsg::Optional maybeOptions) } } - IoContext::current().addTask(req.sendIgnoringResult()); - - running = true; - if (flags.getWorkerdExperimental()) { KJ_IF_SOME(hardTimeoutMs, options.hardTimeout) { JSG_REQUIRE(hardTimeoutMs > 0, RangeError, "Hard timeout must be greater than 0"); req.setHardTimeoutMs(hardTimeoutMs); } + + auto& hostNamespaces = options.hostNamespaces.orDefault(kj::arr(kj::str("pid"))); + auto list = req.initHostNamespaces(hostNamespaces.size()); + for (auto i: kj::indices(hostNamespaces)) { + auto& ns = hostNamespaces[i]; + if (ns == "pid") { + list.set(i, rpc::Container::StartParams::Namespace::PID); + } else { + JSG_FAIL_REQUIRE( + TypeError, "Invalid hostNamespace value: \"", ns, "\". Valid values are: \"pid\""); + } + } } + + IoContext::current().addTask(req.sendIgnoringResult()); + + running = true; } jsg::Promise Container::setInactivityTimeout(jsg::Lock& js, int64_t durationMs) { diff --git a/src/workerd/api/container.h b/src/workerd/api/container.h index 1cbacce4027..58a5daa479c 100644 --- a/src/workerd/api/container.h +++ b/src/workerd/api/container.h @@ -29,10 +29,11 @@ class Container: public jsg::Object { bool enableInternet = false; jsg::Optional> env; jsg::Optional hardTimeout; + jsg::Optional> hostNamespaces; // TODO(containers): Allow intercepting stdin/stdout/stderr by specifying streams here. - JSG_STRUCT(entrypoint, enableInternet, env, hardTimeout); + JSG_STRUCT(entrypoint, enableInternet, env, hardTimeout, hostNamespaces); JSG_STRUCT_TS_OVERRIDE_DYNAMIC(CompatibilityFlags::Reader flags) { if (flags.getWorkerdExperimental()) { JSG_TS_OVERRIDE(ContainerStartupOptions { @@ -40,6 +41,7 @@ class Container: public jsg::Object { enableInternet: boolean; env?: Record; hardTimeout?: number | bigint; + hostNamespaces?: HostNamespace[]; }); } else { JSG_TS_OVERRIDE(ContainerStartupOptions { @@ -73,6 +75,7 @@ class Container: public jsg::Object { JSG_METHOD(signal); JSG_METHOD(getTcpPort); JSG_METHOD(setInactivityTimeout); + JSG_TS_DEFINE(type HostNamespace = "pid"); } void visitForMemoryInfo(jsg::MemoryTracker& tracker) const { diff --git a/src/workerd/io/container.capnp b/src/workerd/io/container.capnp index f79e85defb2..00b603a183a 100644 --- a/src/workerd/io/container.capnp +++ b/src/workerd/io/container.capnp @@ -40,6 +40,15 @@ interface Container @0x9aaceefc06523bca { # The container will be forcefully terminated when this timeout expires, regardless of activity. # Unlike inactivity timeout, this is a hard deadline from container startup. # If 0 (default), no hard timeout is applied. + + hostNamespaces @4 :List(Namespace); + # Configure which namespaces to share with the host + + enum Namespace { + pid @0; + # Sharing the host PID namespace will make processes running outside of + # the container visible inside of the container. + } } monitor @2 () -> (exitCode: Int32); diff --git a/src/workerd/server/container-client.c++ b/src/workerd/server/container-client.c++ index 52c5e24fa3d..70f6f1e95e1 100644 --- a/src/workerd/server/container-client.c++ +++ b/src/workerd/server/container-client.c++ @@ -266,7 +266,8 @@ kj::Promise ContainerClient::inspectContainer( kj::Promise ContainerClient::createContainer( kj::Maybe::Reader> entrypoint, - kj::Maybe::Reader> environment) { + kj::Maybe::Reader> environment, + capnp::List::Reader hostNamespaces) { // Docker API: POST /containers/create capnp::JsonCodec codec; codec.handleByAnnotation(); @@ -300,6 +301,14 @@ kj::Promise ContainerClient::createContainer( // We need to set a restart policy to avoid having ambiguous states // where the container we're managing is stuck at "exited" state. hostConfig.initRestartPolicy().setName("on-failure"); + // Configure host namespace sharing + for (auto ns: hostNamespaces) { + switch (ns) { + case rpc::Container::StartParams::Namespace::PID: + hostConfig.setPidMode("host"); + break; + } + } auto response = co_await dockerApiRequest(network, kj::str(dockerPath), kj::HttpMethod::POST, kj::str("/containers/create?name=", containerName), codec.encode(jsonRoot)); @@ -397,7 +406,7 @@ kj::Promise ContainerClient::start(StartContext context) { environment = params.getEnvironmentVariables(); } - co_await createContainer(entrypoint, environment); + co_await createContainer(entrypoint, environment, params.getHostNamespaces()); co_await startContainer(); } diff --git a/src/workerd/server/container-client.h b/src/workerd/server/container-client.h index 456c8911eb5..81f3db7c05d 100644 --- a/src/workerd/server/container-client.h +++ b/src/workerd/server/container-client.h @@ -85,7 +85,8 @@ class ContainerClient final: public rpc::Container::Server, public kj::Refcounte kj::Maybe body = kj::none); kj::Promise inspectContainer(); kj::Promise createContainer(kj::Maybe::Reader> entrypoint, - kj::Maybe::Reader> environment); + kj::Maybe::Reader> environment, + capnp::List::Reader hostNamespaces); kj::Promise startContainer(); kj::Promise stopContainer(); kj::Promise killContainer(uint32_t signal); diff --git a/src/workerd/server/tests/container-client/test.js b/src/workerd/server/tests/container-client/test.js index d1659d0f3a1..a2f802043c4 100644 --- a/src/workerd/server/tests/container-client/test.js +++ b/src/workerd/server/tests/container-client/test.js @@ -260,6 +260,70 @@ export class DurableObjectExample extends DurableObject { getStatus() { return this.ctx.container.running; } + + async testHostNamespaces() { + const container = this.ctx.container; + if (container.running) { + let monitor = container.monitor().catch((_err) => {}); + await container.destroy(); + await monitor; + } + assert.strictEqual(container.running, false); + + // Test with valid hostNamespaces - should share host PID namespace + container.start({ + hostNamespaces: ['pid'], + }); + + assert.strictEqual(container.running, true); + + // Verify the container process is NOT PID 1 (indicating host PID namespace is shared) + let resp; + const maxRetries = 6; + for (let i = 1; i <= maxRetries; i++) { + try { + resp = await container.getTcpPort(8080).fetch('http://foo/pid'); + break; + } catch (e) { + if (!e.message.includes('container port not found')) { + throw e; + } + if (i === maxRetries) { + throw e; + } + await scheduler.wait(500); + } + } + + assert.strictEqual(resp.status, 200); + const pid = parseInt(await resp.text(), 10); + + // With host PID namespace, the entrypoint process should NOT be PID 1 + assert.notStrictEqual( + pid, + 1, + 'Expected pid != 1 when host PID namespace is shared' + ); + + await container.destroy(); + } + + async testHostNamespacesInvalid() { + const container = this.ctx.container; + if (container.running) { + let monitor = container.monitor().catch((_err) => {}); + await container.destroy(); + await monitor; + } + assert.strictEqual(container.running, false); + + // Test with invalid hostNamespaces + assert.throws(() => { + container.start({ + hostNamespaces: ['invalid'], + }); + }, /Invalid hostNamespace value/); + } } export class DurableObjectExample2 extends DurableObjectExample {} @@ -394,3 +458,21 @@ export const testSetInactivityTimeout = { } }, }; + +// Test hostNamespaces with valid value +export const testHostNamespaces = { + async test(_ctrl, env) { + const id = env.MY_CONTAINER.idFromName('testHostNamespaces'); + const stub = env.MY_CONTAINER.get(id); + await stub.testHostNamespaces(); + }, +}; + +// Test hostNamespaces with invalid value +export const testHostNamespacesInvalid = { + async test(_ctrl, env) { + const id = env.MY_CONTAINER.idFromName('testHostNamespacesInvalid'); + const stub = env.MY_CONTAINER.get(id); + await stub.testHostNamespacesInvalid(); + }, +};