Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/curly-taxis-try.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@electric-sql/pglite-socket': patch
---

allow extensions to be loaded via '-e/--extensions <list>' cmd line parameter'
1 change: 1 addition & 0 deletions .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ jobs:
working-directory: ./packages/pglite
needs: [build-all]
steps:

- uses: actions/checkout@v4
- uses: pnpm/action-setup@v4
- uses: actions/setup-node@v4
Expand Down
1 change: 1 addition & 0 deletions packages/pglite-socket/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ pglite-server --help
- `-h, --host=HOST` - Host to bind to (default: 127.0.0.1)
- `-u, --path=UNIX` - Unix socket to bind to (takes precedence over host:port)
- `-v, --debug=LEVEL` - Debug level 0-5 (default: 0)
- `-e, --extensions=LIST` - Comma-separated list of extensions to load (e.g., vector,pgcrypto)
- `-r, --run=COMMAND` - Command to run after server starts
- `--include-database-url` - Include DATABASE_URL in subprocess environment
- `--shutdown-timeout=MS` - Timeout for graceful subprocess shutdown in ms (default: 5000)
Expand Down
92 changes: 91 additions & 1 deletion packages/pglite-socket/src/scripts/server.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env node

import { PGlite, DebugLevel } from '@electric-sql/pglite'
import type { Extension, Extensions } from '@electric-sql/pglite'
import { PGLiteSocketServer } from '../index'
import { parseArgs } from 'node:util'
import { spawn, ChildProcess } from 'node:child_process'
Expand Down Expand Up @@ -38,6 +39,12 @@ const args = parseArgs({
default: '0',
help: 'Debug level (0-5)',
},
extensions: {
type: 'string',
short: 'e',
default: undefined,
help: 'Comma-separated list of extensions to load (e.g., vector,pgcrypto)',
},
run: {
type: 'string',
short: 'r',
Expand Down Expand Up @@ -72,6 +79,9 @@ Options:
-h, --host=HOST Host to bind to (default: 127.0.0.1)
-u, --path=UNIX Unix socket to bind to (default: undefined). Takes precedence over host:port
-v, --debug=LEVEL Debug level 0-5 (default: 0)
-e, --extensions=LIST Comma-separated list of extensions to load
Formats: vector, pgcrypto (built-in/contrib)
@org/package/path:exportedName (npm package)
-r, --run=COMMAND Command to run after server starts
--include-database-url Include DATABASE_URL in subprocess environment
--shutdown-timeout=MS Timeout for graceful subprocess shutdown in ms (default: 5000)
Expand All @@ -83,6 +93,7 @@ interface ServerConfig {
host: string
path?: string
debugLevel: DebugLevel
extensionNames?: string[]
runCommand?: string
includeDatabaseUrl: boolean
shutdownTimeout: number
Expand All @@ -99,12 +110,16 @@ class PGLiteServerRunner {
}

static parseConfig(): ServerConfig {
const extensionsArg = args.values.extensions as string | undefined
return {
dbPath: args.values.db as string,
port: parseInt(args.values.port as string, 10),
host: args.values.host as string,
path: args.values.path as string,
debugLevel: parseInt(args.values.debug as string, 10) as DebugLevel,
extensionNames: extensionsArg
? extensionsArg.split(',').map((e) => e.trim())
: undefined,
runCommand: args.values.run as string,
includeDatabaseUrl: args.values['include-database-url'] as boolean,
shutdownTimeout: parseInt(args.values['shutdown-timeout'] as string, 10),
Expand All @@ -126,11 +141,86 @@ class PGLiteServerRunner {
}
}

private async importExtensions(): Promise<Extensions | undefined> {
if (!this.config.extensionNames?.length) {
return undefined
}

const extensions: Extensions = {}

// Built-in extensions that are not in contrib
const builtInExtensions = [
'vector',
'live',
'pg_hashids',
'pg_ivm',
'pg_uuidv7',
'pgtap',
]

for (const name of this.config.extensionNames) {
let ext: Extension | null = null

try {
// Check if this is a custom package path (contains ':')
// Format: @org/package/path:exportedName or package/path:exportedName
if (name.includes(':')) {
const [packagePath, exportName] = name.split(':')
if (!packagePath || !exportName) {
throw new Error(
`Invalid extension format '${name}'. Expected: package/path:exportedName`,
)
}
const mod = await import(packagePath)
ext = mod[exportName] as Extension
if (ext) {
extensions[exportName] = ext
console.log(
`Imported extension '${exportName}' from '${packagePath}'`,
)
}
} else if (builtInExtensions.includes(name)) {
// Built-in extension (e.g., @electric-sql/pglite/vector)
const mod = await import(`@electric-sql/pglite/${name}`)
ext = mod[name] as Extension
if (ext) {
extensions[name] = ext
console.log(`Imported extension: ${name}`)
}
} else {
// Try contrib first (e.g., @electric-sql/pglite/contrib/pgcrypto)
try {
const mod = await import(`@electric-sql/pglite/contrib/${name}`)
ext = mod[name] as Extension
} catch {
// Fall back to external package (e.g., @electric-sql/pglite-<extension>)
const mod = await import(`@electric-sql/pglite-${name}`)
ext = mod[name] as Extension
}
if (ext) {
extensions[name] = ext
console.log(`Imported extension: ${name}`)
}
}
} catch (error) {
console.error(`Failed to import extension '${name}':`, error)
throw new Error(`Failed to import extension '${name}'`)
}
}

return Object.keys(extensions).length > 0 ? extensions : undefined
}

private async initializeDatabase(): Promise<void> {
console.log(`Initializing PGLite with database: ${this.config.dbPath}`)
console.log(`Debug level: ${this.config.debugLevel}`)

this.db = new PGlite(this.config.dbPath, { debug: this.config.debugLevel })
const extensions = await this.importExtensions()

this.db = new PGlite(this.config.dbPath, {
debug: this.config.debugLevel,
extensions,
})
await this.db.waitReady
console.log('PGlite database initialized')
}
Expand Down
181 changes: 181 additions & 0 deletions packages/pglite-socket/tests/query-with-node-pg.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ import {
import { Client } from 'pg'
import { PGlite } from '@electric-sql/pglite'
import { PGLiteSocketServer } from '../src'
import { spawn, ChildProcess } from 'node:child_process'
import { fileURLToPath } from 'node:url'
import { dirname, join } from 'node:path'
import fs from 'fs'

const __filename = fileURLToPath(import.meta.url)
const __dirname = dirname(__filename)

/**
* Debug configuration for testing
Expand Down Expand Up @@ -533,4 +540,178 @@ describe(`PGLite Socket Server`, () => {
expect(receivedPayload).toBe('Hello from PGlite!')
})
})

describe('with extensions via CLI', () => {
const UNIX_SOCKET_DIR_PATH = `/tmp/${Date.now().toString()}`
fs.mkdirSync(UNIX_SOCKET_DIR_PATH)
const UNIX_SOCKET_PATH = `${UNIX_SOCKET_DIR_PATH}/.s.PGSQL.5432`
let serverProcess: ChildProcess | null = null
let client: typeof Client.prototype

beforeAll(async () => {
// Start the server with extensions via CLI using tsx for dev or node for dist
const serverScript = join(__dirname, '../src/scripts/server.ts')
serverProcess = spawn(
'npx',
[
'tsx',
serverScript,
'--path',
UNIX_SOCKET_PATH,
'--extensions',
'vector,pg_uuidv7,@electric-sql/pglite/pg_hashids:pg_hashids',
],
{
stdio: ['ignore', 'pipe', 'pipe'],
},
)

// Wait for server to be ready by checking for "listening" message
await new Promise<void>((resolve, reject) => {
const timeout = setTimeout(() => {
reject(new Error('Server startup timeout'))
}, 30000)

const onData = (data: Buffer) => {
const output = data.toString()
if (output.includes('listening')) {
clearTimeout(timeout)
resolve()
}
}

serverProcess!.stdout?.on('data', onData)
serverProcess!.stderr?.on('data', (data) => {
console.error('Server stderr:', data.toString())
})

serverProcess!.on('error', (err) => {
clearTimeout(timeout)
reject(err)
})

serverProcess!.on('exit', (code) => {
if (code !== 0 && code !== null) {
clearTimeout(timeout)
reject(new Error(`Server exited with code ${code}`))
}
})
})

console.log('Server with extensions started')

client = new Client({
host: UNIX_SOCKET_DIR_PATH,
database: 'postgres',
user: 'postgres',
password: 'postgres',
connectionTimeoutMillis: 10000,
})
await client.connect()
})

afterAll(async () => {
if (client) {
await client.end().catch(() => {})
}

if (serverProcess) {
serverProcess.kill('SIGTERM')
await new Promise<void>((resolve) => {
serverProcess!.on('exit', () => resolve())
setTimeout(resolve, 2000) // Force resolve after 2s
})
}
})

it('should load and use vector extension', async () => {
// Create the extension
await client.query('CREATE EXTENSION IF NOT EXISTS vector')

// Verify extension is loaded
const extCheck = await client.query(`
SELECT extname FROM pg_extension WHERE extname = 'vector'
`)
expect(extCheck.rows).toHaveLength(1)
expect(extCheck.rows[0].extname).toBe('vector')

// Create a table with vector column
await client.query(`
CREATE TABLE test_vectors (
id SERIAL PRIMARY KEY,
name TEXT,
vec vector(3)
)
`)

// Insert test data
await client.query(`
INSERT INTO test_vectors (name, vec) VALUES
('test1', '[1,2,3]'),
('test2', '[4,5,6]'),
('test3', '[7,8,9]')
`)

// Query with vector distance
const result = await client.query(`
SELECT name, vec, vec <-> '[3,1,2]' AS distance
FROM test_vectors
ORDER BY distance
`)

expect(result.rows).toHaveLength(3)
expect(result.rows[0].name).toBe('test1')
expect(result.rows[0].vec).toBe('[1,2,3]')
expect(parseFloat(result.rows[0].distance)).toBeCloseTo(2.449, 2)
})

it('should load and use pg_uuidv7 extension', async () => {
// Create the extension
await client.query('CREATE EXTENSION IF NOT EXISTS pg_uuidv7')

// Verify extension is loaded
const extCheck = await client.query(`
SELECT extname FROM pg_extension WHERE extname = 'pg_uuidv7'
`)
expect(extCheck.rows).toHaveLength(1)
expect(extCheck.rows[0].extname).toBe('pg_uuidv7')

// Generate a UUIDv7
const result = await client.query('SELECT uuid_generate_v7() as uuid')
expect(result.rows[0].uuid).toHaveLength(36)

// Test uuid_v7_to_timestamptz function
const tsResult = await client.query(`
SELECT uuid_v7_to_timestamptz('018570bb-4a7d-7c7e-8df4-6d47afd8c8fc') as ts
`)
const timestamp = new Date(tsResult.rows[0].ts)
expect(timestamp.toISOString()).toBe('2023-01-02T04:26:40.637Z')
})

it('should load and use pg_hashids extension from npm package path', async () => {
// Create the extension
await client.query('CREATE EXTENSION IF NOT EXISTS pg_hashids')

// Verify extension is loaded
const extCheck = await client.query(`
SELECT extname FROM pg_extension WHERE extname = 'pg_hashids'
`)
expect(extCheck.rows).toHaveLength(1)
expect(extCheck.rows[0].extname).toBe('pg_hashids')

// Test id_encode function
const result = await client.query(`
SELECT id_encode(1234567, 'salt', 10, 'abcdefghijABCDEFGHIJ1234567890') as hash
`)
expect(result.rows[0].hash).toBeTruthy()
expect(typeof result.rows[0].hash).toBe('string')

// Test id_decode function (round-trip)
const hash = result.rows[0].hash
const decodeResult = await client.query(`
SELECT id_decode('${hash}', 'salt', 10, 'abcdefghijABCDEFGHIJ1234567890') as id
`)
expect(decodeResult.rows[0].id[0]).toBe('1234567')
})
})
})
Loading