diff --git a/.changeset/curly-taxis-try.md b/.changeset/curly-taxis-try.md
new file mode 100644
index 000000000..a6a89b542
--- /dev/null
+++ b/.changeset/curly-taxis-try.md
@@ -0,0 +1,5 @@
+---
+'@electric-sql/pglite-socket': patch
+---
+
+Allow extensions to be loaded via the '-e/--extensions' command-line parameter
diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index 3a18acca4..a8fbaeae1 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -95,6 +95,7 @@ jobs:
         working-directory: ./packages/pglite
     needs: [build-all]
     steps:
+      - uses: actions/checkout@v4
      - uses: pnpm/action-setup@v4
      - uses: actions/setup-node@v4
diff --git a/packages/pglite-socket/README.md b/packages/pglite-socket/README.md
index f298db1f2..cc18b9f3b 100644
--- a/packages/pglite-socket/README.md
+++ b/packages/pglite-socket/README.md
@@ -160,6 +160,7 @@ pglite-server --help
 - `-h, --host=HOST` - Host to bind to (default: 127.0.0.1)
 - `-u, --path=UNIX` - Unix socket to bind to (takes precedence over host:port)
 - `-v, --debug=LEVEL` - Debug level 0-5 (default: 0)
+- `-e, --extensions=LIST` - Comma-separated list of extensions to load (e.g., vector,pgcrypto)
 - `-r, --run=COMMAND` - Command to run after server starts
 - `--include-database-url` - Include DATABASE_URL in subprocess environment
 - `--shutdown-timeout=MS` - Timeout for graceful subprocess shutdown in ms (default: 5000)
diff --git a/packages/pglite-socket/src/scripts/server.ts b/packages/pglite-socket/src/scripts/server.ts
index 552295625..ba765bb38 100644
--- a/packages/pglite-socket/src/scripts/server.ts
+++ b/packages/pglite-socket/src/scripts/server.ts
@@ -1,6 +1,7 @@
 #!/usr/bin/env node
 import { PGlite, DebugLevel } from '@electric-sql/pglite'
+import type { Extension, Extensions } from '@electric-sql/pglite'
 import { PGLiteSocketServer } from '../index'
 import { parseArgs } from 'node:util'
 import { spawn, ChildProcess } from 'node:child_process'
@@ -38,6 +39,12 @@ const args = parseArgs({
       default: '0',
       help: 'Debug level (0-5)',
     },
+    extensions: {
+      type: 'string',
+      short: 'e',
+      default: undefined,
+      help: 'Comma-separated list of extensions to load (e.g., vector,pgcrypto)',
+    },
     run: {
       type: 'string',
       short: 'r',
@@ -72,6 +79,9 @@ Options:
   -h, --host=HOST           Host to bind to (default: 127.0.0.1)
   -u, --path=UNIX           Unix socket to bind to (default: undefined). Takes precedence over host:port
   -v, --debug=LEVEL         Debug level 0-5 (default: 0)
+  -e, --extensions=LIST     Comma-separated list of extensions to load
+                            Formats: vector, pgcrypto (built-in/contrib)
+                                     @org/package/path:exportedName (npm package)
   -r, --run=COMMAND         Command to run after server starts
   --include-database-url    Include DATABASE_URL in subprocess environment
   --shutdown-timeout=MS     Timeout for graceful subprocess shutdown in ms (default: 5000)
@@ -83,6 +93,7 @@ interface ServerConfig {
   host: string
   path?: string
   debugLevel: DebugLevel
+  extensionNames?: string[]
   runCommand?: string
   includeDatabaseUrl: boolean
   shutdownTimeout: number
@@ -99,12 +110,16 @@ class PGLiteServerRunner {
   }
 
   static parseConfig(): ServerConfig {
+    const extensionsArg = args.values.extensions as string | undefined
     return {
       dbPath: args.values.db as string,
       port: parseInt(args.values.port as string, 10),
       host: args.values.host as string,
       path: args.values.path as string,
       debugLevel: parseInt(args.values.debug as string, 10) as DebugLevel,
+      extensionNames: extensionsArg
+        ? extensionsArg.split(',').map((e) => e.trim())
+        : undefined,
       runCommand: args.values.run as string,
       includeDatabaseUrl: args.values['include-database-url'] as boolean,
       shutdownTimeout: parseInt(args.values['shutdown-timeout'] as string, 10),
@@ -126,11 +141,86 @@ class PGLiteServerRunner {
     }
   }
 
+  private async importExtensions(): Promise<Extensions | undefined> {
+    if (!this.config.extensionNames?.length) {
+      return undefined
+    }
+
+    const extensions: Extensions = {}
+
+    // Built-in extensions that are not in contrib
+    const builtInExtensions = [
+      'vector',
+      'live',
+      'pg_hashids',
+      'pg_ivm',
+      'pg_uuidv7',
+      'pgtap',
+    ]
+
+    for (const name of this.config.extensionNames) {
+      let ext: Extension | null = null
+
+      try {
+        // Check if this is a custom package path (contains ':')
+        // Format: @org/package/path:exportedName or package/path:exportedName
+        if (name.includes(':')) {
+          const [packagePath, exportName] = name.split(':')
+          if (!packagePath || !exportName) {
+            throw new Error(
+              `Invalid extension format '${name}'. Expected: package/path:exportedName`,
+            )
+          }
+          const mod = await import(packagePath)
+          ext = mod[exportName] as Extension
+          if (ext) {
+            extensions[exportName] = ext
+            console.log(
+              `Imported extension '${exportName}' from '${packagePath}'`,
+            )
+          }
+        } else if (builtInExtensions.includes(name)) {
+          // Built-in extension (e.g., @electric-sql/pglite/vector)
+          const mod = await import(`@electric-sql/pglite/${name}`)
+          ext = mod[name] as Extension
+          if (ext) {
+            extensions[name] = ext
+            console.log(`Imported extension: ${name}`)
+          }
+        } else {
+          // Try contrib first (e.g., @electric-sql/pglite/contrib/pgcrypto)
+          try {
+            const mod = await import(`@electric-sql/pglite/contrib/${name}`)
+            ext = mod[name] as Extension
+          } catch {
+            // Fall back to external package (e.g., @electric-sql/pglite-<name>)
+            const mod = await import(`@electric-sql/pglite-${name}`)
+            ext = mod[name] as Extension
+          }
+          if (ext) {
+            extensions[name] = ext
+            console.log(`Imported extension: ${name}`)
+          }
+        }
+      } catch (error) {
+        console.error(`Failed to import extension '${name}':`, error)
+        throw new Error(`Failed to import extension '${name}'`)
+      }
+    }
+
+    return Object.keys(extensions).length > 0 ? extensions : undefined
+  }
+
   private async initializeDatabase(): Promise<void> {
     console.log(`Initializing PGLite with database: ${this.config.dbPath}`)
     console.log(`Debug level: ${this.config.debugLevel}`)
-    this.db = new PGlite(this.config.dbPath, { debug: this.config.debugLevel })
+    const extensions = await this.importExtensions()
+
+    this.db = new PGlite(this.config.dbPath, {
+      debug: this.config.debugLevel,
+      extensions,
+    })
     await this.db.waitReady
     console.log('PGlite database initialized')
   }
 
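As an illustration of the new flag for reviewers (not part of the diff itself — the `./data` directory is a placeholder, and the name formats follow the help text above), the server might be started like this:

```sh
# Built-in (vector) and contrib (pgcrypto) extensions, referenced by name;
# unknown names fall back to an external @electric-sql/pglite-<name> package
pglite-server --db ./data --port 5432 --extensions vector,pgcrypto

# An extension exported by an npm package, as package/path:exportedName
pglite-server --db ./data --extensions @electric-sql/pglite/pg_hashids:pg_hashids
```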
diff --git a/packages/pglite-socket/tests/query-with-node-pg.test.ts b/packages/pglite-socket/tests/query-with-node-pg.test.ts
index ad5354c9a..f84c0b6f1 100644
--- a/packages/pglite-socket/tests/query-with-node-pg.test.ts
+++ b/packages/pglite-socket/tests/query-with-node-pg.test.ts
@@ -10,6 +10,13 @@ import {
 import { Client } from 'pg'
 import { PGlite } from '@electric-sql/pglite'
 import { PGLiteSocketServer } from '../src'
+import { spawn, ChildProcess } from 'node:child_process'
+import { fileURLToPath } from 'node:url'
+import { dirname, join } from 'node:path'
+import fs from 'fs'
+
+const __filename = fileURLToPath(import.meta.url)
+const __dirname = dirname(__filename)
 
 /**
  * Debug configuration for testing
@@ -533,4 +540,178 @@ describe(`PGLite Socket Server`, () => {
       expect(receivedPayload).toBe('Hello from PGlite!')
     })
   })
+
+  describe('with extensions via CLI', () => {
+    const UNIX_SOCKET_DIR_PATH = `/tmp/${Date.now().toString()}`
+    fs.mkdirSync(UNIX_SOCKET_DIR_PATH)
+    const UNIX_SOCKET_PATH = `${UNIX_SOCKET_DIR_PATH}/.s.PGSQL.5432`
+    let serverProcess: ChildProcess | null = null
+    let client: typeof Client.prototype
+
+    beforeAll(async () => {
+      // Start the server with extensions via CLI using tsx for dev or node for dist
+      const serverScript = join(__dirname, '../src/scripts/server.ts')
+      serverProcess = spawn(
+        'npx',
+        [
+          'tsx',
+          serverScript,
+          '--path',
+          UNIX_SOCKET_PATH,
+          '--extensions',
+          'vector,pg_uuidv7,@electric-sql/pglite/pg_hashids:pg_hashids',
+        ],
+        {
+          stdio: ['ignore', 'pipe', 'pipe'],
+        },
+      )
+
+      // Wait for server to be ready by checking for "listening" message
+      await new Promise<void>((resolve, reject) => {
+        const timeout = setTimeout(() => {
+          reject(new Error('Server startup timeout'))
+        }, 30000)
+
+        const onData = (data: Buffer) => {
+          const output = data.toString()
+          if (output.includes('listening')) {
+            clearTimeout(timeout)
+            resolve()
+          }
+        }
+
+        serverProcess!.stdout?.on('data', onData)
+        serverProcess!.stderr?.on('data', (data) => {
+          console.error('Server stderr:', data.toString())
+        })
+
+        serverProcess!.on('error', (err) => {
+          clearTimeout(timeout)
+          reject(err)
+        })
+
+        serverProcess!.on('exit', (code) => {
+          if (code !== 0 && code !== null) {
+            clearTimeout(timeout)
+            reject(new Error(`Server exited with code ${code}`))
+          }
+        })
+      })
+
+      console.log('Server with extensions started')
+
+      client = new Client({
+        host: UNIX_SOCKET_DIR_PATH,
+        database: 'postgres',
+        user: 'postgres',
+        password: 'postgres',
+        connectionTimeoutMillis: 10000,
+      })
+      await client.connect()
+    })
+
+    afterAll(async () => {
+      if (client) {
+        await client.end().catch(() => {})
+      }
+
+      if (serverProcess) {
+        serverProcess.kill('SIGTERM')
+        await new Promise<void>((resolve) => {
+          serverProcess!.on('exit', () => resolve())
+          setTimeout(resolve, 2000) // Force resolve after 2s
+        })
+      }
+    })
+
+    it('should load and use vector extension', async () => {
+      // Create the extension
+      await client.query('CREATE EXTENSION IF NOT EXISTS vector')
+
+      // Verify extension is loaded
+      const extCheck = await client.query(`
+        SELECT extname FROM pg_extension WHERE extname = 'vector'
+      `)
+      expect(extCheck.rows).toHaveLength(1)
+      expect(extCheck.rows[0].extname).toBe('vector')
+
+      // Create a table with vector column
+      await client.query(`
+        CREATE TABLE test_vectors (
+          id SERIAL PRIMARY KEY,
+          name TEXT,
+          vec vector(3)
+        )
+      `)
+
+      // Insert test data
+      await client.query(`
+        INSERT INTO test_vectors (name, vec) VALUES
+        ('test1', '[1,2,3]'),
+        ('test2', '[4,5,6]'),
+        ('test3', '[7,8,9]')
+      `)
+
+      // Query with vector distance
+      const result = await client.query(`
+        SELECT name, vec, vec <-> '[3,1,2]' AS distance
+        FROM test_vectors
+        ORDER BY distance
+      `)
+
+      expect(result.rows).toHaveLength(3)
+      expect(result.rows[0].name).toBe('test1')
+      expect(result.rows[0].vec).toBe('[1,2,3]')
+      expect(parseFloat(result.rows[0].distance)).toBeCloseTo(2.449, 2)
+    })
+
+    it('should load and use pg_uuidv7 extension', async () => {
+      // Create the extension
+      await client.query('CREATE EXTENSION IF NOT EXISTS pg_uuidv7')
+
+      // Verify extension is loaded
+      const extCheck = await client.query(`
+        SELECT extname FROM pg_extension WHERE extname = 'pg_uuidv7'
+      `)
+      expect(extCheck.rows).toHaveLength(1)
+      expect(extCheck.rows[0].extname).toBe('pg_uuidv7')
+
+      // Generate a UUIDv7
+      const result = await client.query('SELECT uuid_generate_v7() as uuid')
+      expect(result.rows[0].uuid).toHaveLength(36)
+
+      // Test uuid_v7_to_timestamptz function
+      const tsResult = await client.query(`
+        SELECT uuid_v7_to_timestamptz('018570bb-4a7d-7c7e-8df4-6d47afd8c8fc') as ts
+      `)
+      const timestamp = new Date(tsResult.rows[0].ts)
+      expect(timestamp.toISOString()).toBe('2023-01-02T04:26:40.637Z')
+    })
+
+    it('should load and use pg_hashids extension from npm package path', async () => {
+      // Create the extension
+      await client.query('CREATE EXTENSION IF NOT EXISTS pg_hashids')
+
+      // Verify extension is loaded
+      const extCheck = await client.query(`
+        SELECT extname FROM pg_extension WHERE extname = 'pg_hashids'
+      `)
+      expect(extCheck.rows).toHaveLength(1)
+      expect(extCheck.rows[0].extname).toBe('pg_hashids')
+
+      // Test id_encode function
+      const result = await client.query(`
+        SELECT id_encode(1234567, 'salt', 10, 'abcdefghijABCDEFGHIJ1234567890') as hash
+      `)
+      expect(result.rows[0].hash).toBeTruthy()
+      expect(typeof result.rows[0].hash).toBe('string')
+
+      // Test id_decode function (round-trip)
+      const hash = result.rows[0].hash
+      const decodeResult = await client.query(`
+        SELECT id_decode('${hash}', 'salt', 10, 'abcdefghijABCDEFGHIJ1234567890') as id
+      `)
+      expect(decodeResult.rows[0].id[0]).toBe('1234567')
+    })
+  })
 })
diff --git a/packages/pglite-socket/tests/query-with-postgres-js.test.ts b/packages/pglite-socket/tests/query-with-postgres-js.test.ts
index 13fedf0b1..f88e0137b 100644
--- a/packages/pglite-socket/tests/query-with-postgres-js.test.ts
+++ b/packages/pglite-socket/tests/query-with-postgres-js.test.ts
@@ -10,6 +10,13 @@ import {
 import postgres from 'postgres'
 import { PGlite } from '@electric-sql/pglite'
 import { PGLiteSocketServer } from '../src'
+import { spawn, ChildProcess } from 'node:child_process'
+import { fileURLToPath } from 'node:url'
+import { dirname, join } from 'node:path'
+import fs from 'fs'
+
+const __filename = fileURLToPath(import.meta.url)
+const __dirname = dirname(__filename)
 
 /**
  * Debug configuration for testing
@@ -493,4 +500,179 @@ describe(`PGLite Socket Server`, () => {
       })
     }
   })
+
+  describe('with extensions via CLI', () => {
+    const UNIX_SOCKET_DIR_PATH = `/tmp/${Date.now().toString()}`
+    fs.mkdirSync(UNIX_SOCKET_DIR_PATH)
+    const UNIX_SOCKET_PATH = `${UNIX_SOCKET_DIR_PATH}/.s.PGSQL.5432`
+    let serverProcess: ChildProcess | null = null
+    let sql: ReturnType<typeof postgres>
+
+    beforeAll(async () => {
+      // Start the server with extensions via CLI using tsx for dev or node for dist
+      const serverScript = join(__dirname, '../src/scripts/server.ts')
+      serverProcess = spawn(
+        'npx',
+        [
+          'tsx',
+          serverScript,
+          '--path',
+          UNIX_SOCKET_PATH,
+          '--extensions',
+          'vector,pg_uuidv7,@electric-sql/pglite/pg_hashids:pg_hashids',
+        ],
+        {
+          stdio: ['ignore', 'pipe', 'pipe'],
+        },
+      )
+
+      // Wait for server to be ready by checking for "listening" message
+      await new Promise<void>((resolve, reject) => {
+        const timeout = setTimeout(() => {
+          reject(new Error('Server startup timeout'))
+        }, 30000)
+
+        const onData = (data: Buffer) => {
+          const output = data.toString()
+          if (output.includes('listening')) {
+            clearTimeout(timeout)
+            resolve()
+          }
+        }
+
+        serverProcess!.stdout?.on('data', onData)
+        serverProcess!.stderr?.on('data', (data) => {
+          console.error('Server stderr:', data.toString())
+        })
+
+        serverProcess!.on('error', (err) => {
+          clearTimeout(timeout)
+          reject(err)
+        })
+
+        serverProcess!.on('exit', (code) => {
+          if (code !== 0 && code !== null) {
+            clearTimeout(timeout)
+            reject(new Error(`Server exited with code ${code}`))
+          }
+        })
+      })
+
+      console.log('Server with extensions started')
+
+      sql = postgres({
+        path: UNIX_SOCKET_PATH,
+        database: 'postgres',
+        username: 'postgres',
+        password: 'postgres',
+        idle_timeout: 5,
+        connect_timeout: 10,
+        max: 1,
+      })
+    })
+
+    afterAll(async () => {
+      if (sql) {
+        await sql.end().catch(() => {})
+      }
+
+      if (serverProcess) {
+        serverProcess.kill('SIGTERM')
+        await new Promise<void>((resolve) => {
+          serverProcess!.on('exit', () => resolve())
+          setTimeout(resolve, 2000) // Force resolve after 2s
+        })
+      }
+    })
+
+    it('should load and use vector extension', async () => {
+      // Create the extension
+      await sql`CREATE EXTENSION IF NOT EXISTS vector`
+
+      // Verify extension is loaded
+      const extCheck = await sql`
+        SELECT extname FROM pg_extension WHERE extname = 'vector'
+      `
+      expect(extCheck).toHaveLength(1)
+      expect(extCheck[0].extname).toBe('vector')
+
+      // Create a table with vector column
+      await sql`
+        CREATE TABLE test_vectors (
+          id SERIAL PRIMARY KEY,
+          name TEXT,
+          vec vector(3)
+        )
+      `
+
+      // Insert test data
+      await sql`
+        INSERT INTO test_vectors (name, vec) VALUES
+        ('test1', '[1,2,3]'),
+        ('test2', '[4,5,6]'),
+        ('test3', '[7,8,9]')
+      `
+
+      // Query with vector distance
+      const result = await sql`
+        SELECT name, vec, vec <-> '[3,1,2]' AS distance
+        FROM test_vectors
+        ORDER BY distance
+      `
+
+      expect(result).toHaveLength(3)
+      expect(result[0].name).toBe('test1')
+      expect(result[0].vec).toBe('[1,2,3]')
+      expect(parseFloat(result[0].distance)).toBeCloseTo(2.449, 2)
+    })
+
+    it('should load and use pg_uuidv7 extension', async () => {
+      // Create the extension
+      await sql`CREATE EXTENSION IF NOT EXISTS pg_uuidv7`
+
+      // Verify extension is loaded
+      const extCheck = await sql`
+        SELECT extname FROM pg_extension WHERE extname = 'pg_uuidv7'
+      `
+      expect(extCheck).toHaveLength(1)
+      expect(extCheck[0].extname).toBe('pg_uuidv7')
+
+      // Generate a UUIDv7
+      const result = await sql`SELECT uuid_generate_v7() as uuid`
+      expect(result[0].uuid).toHaveLength(36)
+
+      // Test uuid_v7_to_timestamptz function
+      const tsResult = await sql`
+        SELECT uuid_v7_to_timestamptz('018570bb-4a7d-7c7e-8df4-6d47afd8c8fc') as ts
+      `
+      const timestamp = new Date(tsResult[0].ts)
+      expect(timestamp.toISOString()).toBe('2023-01-02T04:26:40.637Z')
+    })
+
+    it('should load and use pg_hashids extension from npm package path', async () => {
+      // Create the extension
+      await sql`CREATE EXTENSION IF NOT EXISTS pg_hashids`
+
+      // Verify extension is loaded
+      const extCheck = await sql`
+        SELECT extname FROM pg_extension WHERE extname = 'pg_hashids'
+      `
+      expect(extCheck).toHaveLength(1)
+      expect(extCheck[0].extname).toBe('pg_hashids')
+
+      // Test id_encode function
+      const result = await sql`
+        SELECT id_encode(1234567, 'salt', 10, 'abcdefghijABCDEFGHIJ1234567890') as hash
+      `
+      expect(result[0].hash).toBeTruthy()
+      expect(typeof result[0].hash).toBe('string')
+
+      // Test id_decode function (round-trip)
+      const hash = result[0].hash
+      const decodeResult = await sql`
+        SELECT id_decode(${hash}, 'salt', 10, 'abcdefghijABCDEFGHIJ1234567890') as id
+      `
+      expect(decodeResult[0].id[0]).toBe('1234567')
+    })
+  })
 })
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 1693c8901..1ccfd45bf 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -2077,7 +2077,6 @@ packages:
 
   bun@1.1.30:
     resolution: {integrity: sha512-ysRL1pq10Xba0jqVLPrKU3YIv0ohfp3cTajCPtpjCyppbn3lfiAVNpGoHfyaxS17OlPmWmR67UZRPw/EueQuug==}
-    cpu: [arm64, x64]
     os: [darwin, linux, win32]
     hasBin: true
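For context, this is roughly what the new flag wires up internally — a minimal sketch assuming the programmatic `PGLiteSocketServer` options shown in this package's README (`{ db, host, port }`) and a placeholder `./data` directory; the extension import paths mirror the resolution order in `importExtensions` above:

```ts
import { PGlite } from '@electric-sql/pglite'
import { vector } from '@electric-sql/pglite/vector'
import { pgcrypto } from '@electric-sql/pglite/contrib/pgcrypto'
import { PGLiteSocketServer } from '@electric-sql/pglite-socket'

// Equivalent of `pglite-server --db ./data --extensions vector,pgcrypto`:
// import the extension modules, pass them to PGlite, then expose the
// database over the socket server.
const db = new PGlite('./data', { extensions: { vector, pgcrypto } })
await db.waitReady

const server = new PGLiteSocketServer({ db, host: '127.0.0.1', port: 5432 })
await server.start()
console.log('PGLiteSocketServer listening on 127.0.0.1:5432')
```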