diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index fd2caee7..b8ae5b0f 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -12,6 +12,8 @@ jobs: package: - deparser - parser + - plpgsql-deparser + - plpgsql-parser - pgsql-cli - proto-parser - transform @@ -39,4 +41,4 @@ jobs: run: pnpm build - name: test - run: pnpm --filter ${{ matrix.package }} test \ No newline at end of file + run: pnpm --filter ${{ matrix.package }} test diff --git a/__fixtures__/plpgsql-generated/generated.json b/__fixtures__/plpgsql-generated/generated.json index 0275d46e..d617cb68 100644 --- a/__fixtures__/plpgsql-generated/generated.json +++ b/__fixtures__/plpgsql-generated/generated.json @@ -87,6 +87,20 @@ "plpgsql_domain-17.sql": "-- fail\n\nCREATE FUNCTION build_ordered_named_pair(i int, j int) RETURNS ordered_named_pair AS $$\nbegin\nreturn row(i, j);\nend\n$$ LANGUAGE plpgsql", "plpgsql_domain-18.sql": "CREATE FUNCTION build_ordered_named_pairs(i int, j int) RETURNS ordered_named_pair[] AS $$\nbegin\nreturn array[row(i, j), row(i, j+1)];\nend\n$$ LANGUAGE plpgsql", "plpgsql_domain-19.sql": "-- fail\n\nCREATE FUNCTION test_assign_ordered_named_pairs(x int, y int, z int)\n RETURNS ordered_named_pair[] AS $$\ndeclare v ordered_named_pair[] := array[row(x, y)];\nbegin\n-- ideally this would work, but it doesn't yet:\n-- v[1].j := z;\nreturn v;\nend\n$$ LANGUAGE plpgsql", + "plpgsql_deparser_fixes-1.sql": "-- Fixtures to test deparser fixes from constructive-db PR #229\n-- These exercise: PERFORM, INTO clause placement, record field qualification, RETURN handling\n\n-- Test 1: PERFORM statement (parser stores as SELECT, deparser must strip SELECT)\nCREATE FUNCTION test_perform_basic() RETURNS trigger\nLANGUAGE plpgsql AS $$\nBEGIN\n PERFORM pg_notify('test_channel', 'message');\n RETURN NEW;\nEND$$", + "plpgsql_deparser_fixes-2.sql": "-- Test 2: PERFORM with function call and arguments\nCREATE FUNCTION test_perform_with_args() RETURNS trigger\nLANGUAGE plpgsql AS $$\nBEGIN\n IF (TG_OP = 'INSERT' OR TG_OP = 'UPDATE') THEN\n PERFORM pg_notify(TG_ARGV[0], to_json(NEW)::text);\n RETURN NEW;\n END IF;\n IF (TG_OP = 'DELETE') THEN\n PERFORM pg_notify(TG_ARGV[0], to_json(OLD)::text);\n RETURN OLD;\n END IF;\n RETURN NULL;\nEND$$", + "plpgsql_deparser_fixes-3.sql": "-- Test 3: INTO clause with record field target (recfield qualification)\nCREATE FUNCTION test_into_record_field() RETURNS trigger\nLANGUAGE plpgsql AS $$\nBEGIN\n SELECT\n NEW.is_approved IS TRUE\n AND NEW.is_verified IS TRUE\n AND NEW.is_disabled IS FALSE INTO NEW.is_active;\n RETURN NEW;\nEND$$", + "plpgsql_deparser_fixes-4.sql": "-- Test 4: INTO clause with subquery (depth-aware scanner must skip nested FROM)\nCREATE FUNCTION test_into_with_subquery() RETURNS trigger\nLANGUAGE plpgsql AS $$\nDECLARE\n result_value int;\nBEGIN\n SELECT count(*) INTO result_value\n FROM (SELECT id FROM users WHERE id = NEW.user_id) sub;\n RETURN NEW;\nEND$$", + "plpgsql_deparser_fixes-5.sql": "-- Test 5: INTO clause with multiple record fields\nCREATE FUNCTION test_into_multiple_fields() RETURNS trigger\nLANGUAGE plpgsql AS $$\nBEGIN\n SELECT is_active, is_verified INTO NEW.is_active, NEW.is_verified\n FROM users WHERE id = NEW.user_id;\n RETURN NEW;\nEND$$", + "plpgsql_deparser_fixes-6.sql": "-- Test 6: SETOF function with RETURN QUERY and bare RETURN\nCREATE FUNCTION test_setof_return_query(p_limit int)\nRETURNS SETOF int\nLANGUAGE plpgsql AS $$\nBEGIN\n RETURN QUERY SELECT generate_series(1, p_limit);\n RETURN;\nEND$$", + "plpgsql_deparser_fixes-7.sql": "-- Test 7: SETOF function with RETURN NEXT\nCREATE FUNCTION test_setof_return_next(p_count int)\nRETURNS SETOF text\nLANGUAGE plpgsql AS $$\nDECLARE\n i int;\nBEGIN\n FOR i IN 1..p_count LOOP\n RETURN NEXT 'item_' || i::text;\n END LOOP;\n RETURN;\nEND$$", + "plpgsql_deparser_fixes-8.sql": "-- Test 8: Void function with bare RETURN\nCREATE FUNCTION test_void_function(p_value text)\nRETURNS void\nLANGUAGE plpgsql AS $$\nBEGIN\n RAISE NOTICE 'Value: %', p_value;\n RETURN;\nEND$$", + "plpgsql_deparser_fixes-9.sql": "-- Test 9: Scalar function with RETURN NULL\nCREATE FUNCTION test_scalar_return_null()\nRETURNS int\nLANGUAGE plpgsql AS $$\nBEGIN\n RETURN NULL;\nEND$$", + "plpgsql_deparser_fixes-10.sql": "-- Test 10: Scalar function with conditional RETURN\nCREATE FUNCTION test_scalar_conditional(p_value int)\nRETURNS int\nLANGUAGE plpgsql AS $$\nBEGIN\n IF p_value > 0 THEN\n RETURN p_value * 2;\n END IF;\n RETURN NULL;\nEND$$", + "plpgsql_deparser_fixes-11.sql": "-- Test 11: OUT parameter function with bare RETURN\nCREATE FUNCTION test_out_params(OUT ok boolean, OUT message text)\nLANGUAGE plpgsql AS $$\nBEGIN\n ok := true;\n message := 'success';\n RETURN;\nEND$$", + "plpgsql_deparser_fixes-12.sql": "-- Test 12: RETURNS TABLE function with RETURN QUERY\nCREATE FUNCTION test_returns_table(p_prefix text)\nRETURNS TABLE(id int, name text)\nLANGUAGE plpgsql AS $$\nBEGIN\n RETURN QUERY SELECT 1, p_prefix || '_one';\n RETURN QUERY SELECT 2, p_prefix || '_two';\n RETURN;\nEND$$", + "plpgsql_deparser_fixes-13.sql": "-- Test 13: Trigger function with complex logic\nCREATE FUNCTION test_trigger_complex() RETURNS trigger\nLANGUAGE plpgsql AS $$\nDECLARE\n defaults_record record;\n bit_len int;\nBEGIN\n bit_len := bit_length(NEW.permissions);\n \n SELECT * INTO defaults_record\n FROM permission_defaults AS t\n LIMIT 1;\n \n IF found THEN\n NEW.is_approved := defaults_record.is_approved;\n NEW.is_verified := defaults_record.is_verified;\n END IF;\n \n IF NEW.is_owner IS TRUE THEN\n NEW.is_admin := true;\n NEW.is_approved := true;\n NEW.is_verified := true;\n END IF;\n \n SELECT\n NEW.is_approved IS TRUE\n AND NEW.is_verified IS TRUE\n AND NEW.is_disabled IS FALSE INTO NEW.is_active;\n \n RETURN NEW;\nEND$$", + "plpgsql_deparser_fixes-14.sql": "-- Test 14: Procedure (implicit void return)\nCREATE PROCEDURE test_procedure(p_message text)\nLANGUAGE plpgsql AS $$\nBEGIN\n RAISE NOTICE '%', p_message;\nEND$$", "plpgsql_control-1.sql": "--\n-- Tests for PL/pgSQL control structures\n--\n\n-- integer FOR loop\n\ndo $$\nbegin\n -- basic case\n for i in 1..3 loop\n raise notice '1..3: i = %', i;\n end loop;\n -- with BY, end matches exactly\n for i in 1..10 by 3 loop\n raise notice '1..10 by 3: i = %', i;\n end loop;\n -- with BY, end does not match\n for i in 1..11 by 3 loop\n raise notice '1..11 by 3: i = %', i;\n end loop;\n -- zero iterations\n for i in 1..0 by 3 loop\n raise notice '1..0 by 3: i = %', i;\n end loop;\n -- REVERSE\n for i in reverse 10..0 by 3 loop\n raise notice 'reverse 10..0 by 3: i = %', i;\n end loop;\n -- potential overflow\n for i in 2147483620..2147483647 by 10 loop\n raise notice '2147483620..2147483647 by 10: i = %', i;\n end loop;\n -- potential overflow, reverse direction\n for i in reverse -2147483620..-2147483647 by 10 loop\n raise notice 'reverse -2147483620..-2147483647 by 10: i = %', i;\n end loop;\nend$$", "plpgsql_control-2.sql": "-- BY can't be zero or negative\ndo $$\nbegin\n for i in 1..3 by 0 loop\n raise notice '1..3 by 0: i = %', i;\n end loop;\nend$$", "plpgsql_control-3.sql": "do $$\nbegin\n for i in 1..3 by -1 loop\n raise notice '1..3 by -1: i = %', i;\n end loop;\nend$$", diff --git a/__fixtures__/plpgsql/plpgsql_deparser_fixes.sql b/__fixtures__/plpgsql/plpgsql_deparser_fixes.sql new file mode 100644 index 00000000..24698cd1 --- /dev/null +++ b/__fixtures__/plpgsql/plpgsql_deparser_fixes.sql @@ -0,0 +1,164 @@ +-- Fixtures to test deparser fixes from constructive-db PR #229 +-- These exercise: PERFORM, INTO clause placement, record field qualification, RETURN handling + +-- Test 1: PERFORM statement (parser stores as SELECT, deparser must strip SELECT) +CREATE FUNCTION test_perform_basic() RETURNS trigger +LANGUAGE plpgsql AS $$ +BEGIN + PERFORM pg_notify('test_channel', 'message'); + RETURN NEW; +END$$; + +-- Test 2: PERFORM with function call and arguments +CREATE FUNCTION test_perform_with_args() RETURNS trigger +LANGUAGE plpgsql AS $$ +BEGIN + IF (TG_OP = 'INSERT' OR TG_OP = 'UPDATE') THEN + PERFORM pg_notify(TG_ARGV[0], to_json(NEW)::text); + RETURN NEW; + END IF; + IF (TG_OP = 'DELETE') THEN + PERFORM pg_notify(TG_ARGV[0], to_json(OLD)::text); + RETURN OLD; + END IF; + RETURN NULL; +END$$; + +-- Test 3: INTO clause with record field target (recfield qualification) +CREATE FUNCTION test_into_record_field() RETURNS trigger +LANGUAGE plpgsql AS $$ +BEGIN + SELECT + NEW.is_approved IS TRUE + AND NEW.is_verified IS TRUE + AND NEW.is_disabled IS FALSE INTO NEW.is_active; + RETURN NEW; +END$$; + +-- Test 4: INTO clause with subquery (depth-aware scanner must skip nested FROM) +CREATE FUNCTION test_into_with_subquery() RETURNS trigger +LANGUAGE plpgsql AS $$ +DECLARE + result_value int; +BEGIN + SELECT count(*) INTO result_value + FROM (SELECT id FROM users WHERE id = NEW.user_id) sub; + RETURN NEW; +END$$; + +-- Test 5: INTO clause with multiple record fields +CREATE FUNCTION test_into_multiple_fields() RETURNS trigger +LANGUAGE plpgsql AS $$ +BEGIN + SELECT is_active, is_verified INTO NEW.is_active, NEW.is_verified + FROM users WHERE id = NEW.user_id; + RETURN NEW; +END$$; + +-- Test 6: SETOF function with RETURN QUERY and bare RETURN +CREATE FUNCTION test_setof_return_query(p_limit int) +RETURNS SETOF int +LANGUAGE plpgsql AS $$ +BEGIN + RETURN QUERY SELECT generate_series(1, p_limit); + RETURN; +END$$; + +-- Test 7: SETOF function with RETURN NEXT +CREATE FUNCTION test_setof_return_next(p_count int) +RETURNS SETOF text +LANGUAGE plpgsql AS $$ +DECLARE + i int; +BEGIN + FOR i IN 1..p_count LOOP + RETURN NEXT 'item_' || i::text; + END LOOP; + RETURN; +END$$; + +-- Test 8: Void function with bare RETURN +CREATE FUNCTION test_void_function(p_value text) +RETURNS void +LANGUAGE plpgsql AS $$ +BEGIN + RAISE NOTICE 'Value: %', p_value; + RETURN; +END$$; + +-- Test 9: Scalar function with RETURN NULL +CREATE FUNCTION test_scalar_return_null() +RETURNS int +LANGUAGE plpgsql AS $$ +BEGIN + RETURN NULL; +END$$; + +-- Test 10: Scalar function with conditional RETURN +CREATE FUNCTION test_scalar_conditional(p_value int) +RETURNS int +LANGUAGE plpgsql AS $$ +BEGIN + IF p_value > 0 THEN + RETURN p_value * 2; + END IF; + RETURN NULL; +END$$; + +-- Test 11: OUT parameter function with bare RETURN +CREATE FUNCTION test_out_params(OUT ok boolean, OUT message text) +LANGUAGE plpgsql AS $$ +BEGIN + ok := true; + message := 'success'; + RETURN; +END$$; + +-- Test 12: RETURNS TABLE function with RETURN QUERY +CREATE FUNCTION test_returns_table(p_prefix text) +RETURNS TABLE(id int, name text) +LANGUAGE plpgsql AS $$ +BEGIN + RETURN QUERY SELECT 1, p_prefix || '_one'; + RETURN QUERY SELECT 2, p_prefix || '_two'; + RETURN; +END$$; + +-- Test 13: Trigger function with complex logic +CREATE FUNCTION test_trigger_complex() RETURNS trigger +LANGUAGE plpgsql AS $$ +DECLARE + defaults_record record; + bit_len int; +BEGIN + bit_len := bit_length(NEW.permissions); + + SELECT * INTO defaults_record + FROM permission_defaults AS t + LIMIT 1; + + IF found THEN + NEW.is_approved := defaults_record.is_approved; + NEW.is_verified := defaults_record.is_verified; + END IF; + + IF NEW.is_owner IS TRUE THEN + NEW.is_admin := true; + NEW.is_approved := true; + NEW.is_verified := true; + END IF; + + SELECT + NEW.is_approved IS TRUE + AND NEW.is_verified IS TRUE + AND NEW.is_disabled IS FALSE INTO NEW.is_active; + + RETURN NEW; +END$$; + +-- Test 14: Procedure (implicit void return) +CREATE PROCEDURE test_procedure(p_message text) +LANGUAGE plpgsql AS $$ +BEGIN + RAISE NOTICE '%', p_message; +END$$; diff --git a/packages/plpgsql-deparser/__tests__/__snapshots__/hydrate-demo.test.ts.snap b/packages/plpgsql-deparser/__tests__/__snapshots__/hydrate-demo.test.ts.snap index 3b2c9ed2..cf789cb6 100644 --- a/packages/plpgsql-deparser/__tests__/__snapshots__/hydrate-demo.test.ts.snap +++ b/packages/plpgsql-deparser/__tests__/__snapshots__/hydrate-demo.test.ts.snap @@ -69,7 +69,7 @@ BEGIN RAISE EXCEPTION 'p_round_to out of range: %', p_round_to; END IF; IF p_lock THEN - PERFORM SELECT pg_advisory_xact_lock(v_lock_key); + PERFORM pg_advisory_xact_lock(v_lock_key); END IF; IF p_debug THEN RAISE NOTICE 'big_kitchen_sink start=% org=% user=% from=% to=% min_total=%', v_now, p_org_id, p_user_id, p_from_ts, p_to_ts, v_min_total; @@ -99,7 +99,7 @@ BEGIN SELECT t.orders_scanned, t.gross_total, - t.avg_total + t.avg_total INTO v_orders_scanned, v_gross, v_avg FROM totals AS t; IF p_apply_discount THEN v_rebate := round(v_gross * GREATEST(LEAST(v_discount_rate + v_jitter, 0.50), 0), p_round_to); @@ -110,7 +110,7 @@ BEGIN v_net := round(((v_gross - v_discount) + v_tax) * power(10::numeric, 0), p_round_to); SELECT oi.sku, - CAST(sum(oi.quantity) AS bigint) AS qty + CAST(sum(oi.quantity) AS bigint) AS qty INTO v_top_sku, v_top_sku_qty FROM app_public.order_item AS oi JOIN app_public.app_order AS o ON o.id = oi.order_id WHERE diff --git a/packages/plpgsql-deparser/__tests__/plpgsql-deparser.test.ts b/packages/plpgsql-deparser/__tests__/plpgsql-deparser.test.ts index ee7a5d9b..b99d74bb 100644 --- a/packages/plpgsql-deparser/__tests__/plpgsql-deparser.test.ts +++ b/packages/plpgsql-deparser/__tests__/plpgsql-deparser.test.ts @@ -32,14 +32,74 @@ describe('PLpgSQLDeparser', () => { }); describe('round-trip tests using generated.json', () => { - it('should round-trip plpgsql_domain fixtures', async () => { - const entries = fixtureTestUtils.getTestEntries(['plpgsql_domain']); + // Known failing fixtures due to pre-existing deparser issues: + // - Schema qualification loss (pg_catalog.pg_class%rowtype[] -> pg_class%rowtype[]) + // - Tagged dollar quote reconstruction ($tag$...$tag$ not supported) + // - Exception block handling issues + // TODO: Fix these underlying issues and remove from allowlist + const KNOWN_FAILING_FIXTURES = new Set([ + 'plpgsql_varprops-13.sql', + 'plpgsql_trap-1.sql', + 'plpgsql_trap-2.sql', + 'plpgsql_trap-3.sql', + 'plpgsql_trap-4.sql', + 'plpgsql_trap-5.sql', + 'plpgsql_trap-6.sql', + 'plpgsql_trap-7.sql', + 'plpgsql_transaction-17.sql', + 'plpgsql_transaction-19.sql', + 'plpgsql_transaction-20.sql', + 'plpgsql_transaction-21.sql', + 'plpgsql_control-15.sql', + 'plpgsql_control-17.sql', + 'plpgsql_call-44.sql', + 'plpgsql_array-20.sql', + ]); + + it('should round-trip ALL generated fixtures (excluding known failures)', async () => { + // Get all fixtures without any filter - this ensures we test everything + const entries = fixtureTestUtils.getTestEntries(); expect(entries.length).toBeGreaterThan(0); + const failures: { key: string; error: string }[] = []; + const unexpectedPasses: string[] = []; + for (const [key] of entries) { - await fixtureTestUtils.runSingleFixture(key); + const isKnownFailing = KNOWN_FAILING_FIXTURES.has(key); + try { + await fixtureTestUtils.runSingleFixture(key); + if (isKnownFailing) { + unexpectedPasses.push(key); + } + } catch (err) { + if (!isKnownFailing) { + failures.push({ + key, + error: err instanceof Error ? err.message : String(err), + }); + } + } } - }); + + // Report unexpected passes (fixtures that should be removed from allowlist) + if (unexpectedPasses.length > 0) { + console.log(`\nUnexpected passes (remove from KNOWN_FAILING_FIXTURES):\n${unexpectedPasses.join('\n')}`); + } + + // Fail if any non-allowlisted fixtures fail (regression detection) + if (failures.length > 0) { + const failureReport = failures + .map(f => ` - ${f.key}: ${f.error}`) + .join('\n'); + throw new Error( + `${failures.length} NEW fixture failures (not in allowlist):\n${failureReport}` + ); + } + + // Report coverage stats + const testedCount = entries.length - KNOWN_FAILING_FIXTURES.size; + console.log(`\nRound-trip tested ${testedCount} of ${entries.length} fixtures (${KNOWN_FAILING_FIXTURES.size} known failures skipped)`); + }, 120000); // 2 minute timeout for all fixtures }); describe('PLpgSQLDeparser class', () => { diff --git a/packages/plpgsql-deparser/__tests__/pretty/__snapshots__/plpgsql-pretty.test.ts.snap b/packages/plpgsql-deparser/__tests__/pretty/__snapshots__/plpgsql-pretty.test.ts.snap index 82aa09d0..f3c60777 100644 --- a/packages/plpgsql-deparser/__tests__/pretty/__snapshots__/plpgsql-pretty.test.ts.snap +++ b/packages/plpgsql-deparser/__tests__/pretty/__snapshots__/plpgsql-pretty.test.ts.snap @@ -36,7 +36,7 @@ begin raise exception 'p_round_to out of range: %', p_round_to; end if; if p_lock then - perform SELECT pg_advisory_xact_lock(v_lock_key); + perform pg_advisory_xact_lock(v_lock_key); end if; if p_debug then raise notice 'big_kitchen_sink start=% org=% user=% from=% to=% min_total=%', v_now, p_org_id, p_user_id, p_from_ts, p_to_ts, v_min_total; @@ -67,7 +67,7 @@ begin SELECT t.orders_scanned, t.gross_total, - t.avg_total + t.avg_total into v_orders_scanned, v_gross, v_avg FROM totals t; if p_apply_discount then v_discount := round(v_gross * GREATEST(LEAST(v_discount_rate + v_jitter, 0.50), 0), p_round_to); @@ -78,7 +78,7 @@ begin v_net := round((v_gross - v_discount + v_tax) * power(10::numeric, 0), p_round_to); SELECT oi.sku, - sum(oi.quantity)::bigint AS qty + sum(oi.quantity)::bigint AS qty into v_top_sku, v_top_sku_qty FROM app_public.order_item oi JOIN app_public.app_order o ON o.id = oi.order_id WHERE o.org_id = p_org_id @@ -237,7 +237,7 @@ BEGIN RAISE EXCEPTION 'p_round_to out of range: %', p_round_to; END IF; IF p_lock THEN - PERFORM SELECT pg_advisory_xact_lock(v_lock_key); + PERFORM pg_advisory_xact_lock(v_lock_key); END IF; IF p_debug THEN RAISE NOTICE 'big_kitchen_sink start=% org=% user=% from=% to=% min_total=%', v_now, p_org_id, p_user_id, p_from_ts, p_to_ts, v_min_total; @@ -268,7 +268,7 @@ BEGIN SELECT t.orders_scanned, t.gross_total, - t.avg_total + t.avg_total INTO v_orders_scanned, v_gross, v_avg FROM totals t; IF p_apply_discount THEN v_discount := round(v_gross * GREATEST(LEAST(v_discount_rate + v_jitter, 0.50), 0), p_round_to); @@ -279,7 +279,7 @@ BEGIN v_net := round((v_gross - v_discount + v_tax) * power(10::numeric, 0), p_round_to); SELECT oi.sku, - sum(oi.quantity)::bigint AS qty + sum(oi.quantity)::bigint AS qty INTO v_top_sku, v_top_sku_qty FROM app_public.order_item oi JOIN app_public.app_order o ON o.id = oi.order_id WHERE o.org_id = p_org_id diff --git a/packages/plpgsql-deparser/__tests__/return-context.test.ts b/packages/plpgsql-deparser/__tests__/return-context.test.ts new file mode 100644 index 00000000..d56dee00 --- /dev/null +++ b/packages/plpgsql-deparser/__tests__/return-context.test.ts @@ -0,0 +1,134 @@ +import { loadModule, parsePlPgSQLSync, parseSync } from '@libpg-query/parser'; +import { deparseSync, ReturnInfo } from '../src'; +import { PLpgSQLParseResult } from '../src/types'; + +describe('RETURN statement context handling', () => { + beforeAll(async () => { + await loadModule(); + }); + + describe('deparseSync with returnInfo context', () => { + const parseBody = (sql: string): PLpgSQLParseResult => { + return parsePlPgSQLSync(sql) as unknown as PLpgSQLParseResult; + }; + + it('should output bare RETURN for void functions', () => { + const sql = `CREATE FUNCTION test_void() RETURNS void +LANGUAGE plpgsql +AS $$ +BEGIN + RETURN; +END; +$$`; + const parsed = parseBody(sql); + const returnInfo: ReturnInfo = { kind: 'void' }; + const result = deparseSync(parsed, undefined, returnInfo); + + expect(result).toContain('RETURN;'); + expect(result).not.toContain('RETURN NULL'); + }); + + it('should output bare RETURN for setof functions', () => { + const sql = `CREATE FUNCTION test_setof() RETURNS SETOF integer +LANGUAGE plpgsql +AS $$ +BEGIN + RETURN NEXT 1; + RETURN NEXT 2; + RETURN; +END; +$$`; + const parsed = parseBody(sql); + const returnInfo: ReturnInfo = { kind: 'setof' }; + const result = deparseSync(parsed, undefined, returnInfo); + + expect(result).toContain('RETURN;'); + expect(result).not.toContain('RETURN NULL'); + }); + + it('should output bare RETURN for trigger functions', () => { + const sql = `CREATE FUNCTION test_trigger() RETURNS trigger +LANGUAGE plpgsql +AS $$ +BEGIN + RETURN NEW; +END; +$$`; + const parsed = parseBody(sql); + const returnInfo: ReturnInfo = { kind: 'trigger' }; + const result = deparseSync(parsed, undefined, returnInfo); + + expect(result).toContain('RETURN NEW'); + }); + + it('should output bare RETURN for out_params functions', () => { + const sql = `CREATE FUNCTION test_out(OUT result integer) RETURNS integer +LANGUAGE plpgsql +AS $$ +BEGIN + result := 42; + RETURN; +END; +$$`; + const parsed = parseBody(sql); + const returnInfo: ReturnInfo = { kind: 'out_params' }; + const result = deparseSync(parsed, undefined, returnInfo); + + expect(result).toContain('RETURN;'); + expect(result).not.toContain('RETURN NULL'); + }); + + it('should output RETURN NULL for scalar functions with empty return', () => { + const sql = `CREATE FUNCTION test_scalar(val integer) RETURNS text +LANGUAGE plpgsql +AS $$ +BEGIN + IF val > 0 THEN + RETURN 'positive'; + END IF; + RETURN; +END; +$$`; + const parsed = parseBody(sql); + const returnInfo: ReturnInfo = { kind: 'scalar' }; + const result = deparseSync(parsed, undefined, returnInfo); + + expect(result).toContain('RETURN NULL'); + }); + + it('should preserve RETURN with expression regardless of context', () => { + const sql = `CREATE FUNCTION test_expr() RETURNS integer +LANGUAGE plpgsql +AS $$ +BEGIN + RETURN 42; +END; +$$`; + const parsed = parseBody(sql); + + // Test with scalar context + const scalarResult = deparseSync(parsed, undefined, { kind: 'scalar' }); + expect(scalarResult).toContain('RETURN 42'); + + // Test with void context (shouldn't change expression returns) + const voidResult = deparseSync(parsed, undefined, { kind: 'void' }); + expect(voidResult).toContain('RETURN 42'); + }); + + it('should default to bare RETURN when no context provided (backward compatibility)', () => { + const sql = `CREATE FUNCTION test_no_context() RETURNS void +LANGUAGE plpgsql +AS $$ +BEGIN + RETURN; +END; +$$`; + const parsed = parseBody(sql); + + // No returnInfo provided - should default to bare RETURN + const result = deparseSync(parsed); + expect(result).toContain('RETURN;'); + expect(result).not.toContain('RETURN NULL'); + }); + }); +}); diff --git a/packages/plpgsql-deparser/src/index.ts b/packages/plpgsql-deparser/src/index.ts index 54ef674f..8179b0b0 100644 --- a/packages/plpgsql-deparser/src/index.ts +++ b/packages/plpgsql-deparser/src/index.ts @@ -1,4 +1,4 @@ -import { PLpgSQLDeparser, PLpgSQLDeparserOptions } from './plpgsql-deparser'; +import { PLpgSQLDeparser, PLpgSQLDeparserOptions, ReturnInfo, ReturnInfoKind } from './plpgsql-deparser'; const deparseMethod = PLpgSQLDeparser.deparse; const deparseFunctionMethod = PLpgSQLDeparser.deparseFunction; @@ -18,7 +18,7 @@ export const deparseFunction = async ( return deparseFunctionMethod(...args); }; -export { PLpgSQLDeparser, PLpgSQLDeparserOptions }; +export { PLpgSQLDeparser, PLpgSQLDeparserOptions, ReturnInfo, ReturnInfoKind }; export * from './types'; export * from './hydrate-types'; export { hydratePlpgsqlAst, dehydratePlpgsqlAst, isHydratedExpr, getOriginalQuery, DehydrationOptions } from './hydrate'; diff --git a/packages/plpgsql-deparser/src/plpgsql-deparser.ts b/packages/plpgsql-deparser/src/plpgsql-deparser.ts index f3e46e56..232b5dd7 100644 --- a/packages/plpgsql-deparser/src/plpgsql-deparser.ts +++ b/packages/plpgsql-deparser/src/plpgsql-deparser.ts @@ -65,10 +65,23 @@ export interface PLpgSQLDeparserOptions { uppercase?: boolean; } +/** + * Return type information for a PL/pgSQL function. + * Used to determine the correct RETURN statement syntax: + * - void/setof/trigger/out_params: bare RETURN is valid + * - scalar: RETURN NULL is required for empty returns + */ +export type ReturnInfoKind = 'void' | 'setof' | 'trigger' | 'scalar' | 'out_params'; + +export interface ReturnInfo { + kind: ReturnInfoKind; +} + export interface PLpgSQLDeparserContext { indentLevel: number; options: PLpgSQLDeparserOptions; datums?: PLpgSQLDatum[]; + returnInfo?: ReturnInfo; } /** @@ -90,49 +103,62 @@ export class PLpgSQLDeparser { /** * Static method to deparse a PL/pgSQL parse result + * @param parseResult - The PL/pgSQL parse result + * @param options - Deparser options + * @param returnInfo - Optional return type info for correct RETURN statement handling */ - static deparse(parseResult: PLpgSQLParseResult, options?: PLpgSQLDeparserOptions): string { - return new PLpgSQLDeparser(options).deparseResult(parseResult); + static deparse(parseResult: PLpgSQLParseResult, options?: PLpgSQLDeparserOptions, returnInfo?: ReturnInfo): string { + return new PLpgSQLDeparser(options).deparseResult(parseResult, returnInfo); } /** * Static method to deparse a single PL/pgSQL function body + * @param func - The PL/pgSQL function AST + * @param options - Deparser options + * @param returnInfo - Optional return type info for correct RETURN statement handling */ - static deparseFunction(func: PLpgSQL_function, options?: PLpgSQLDeparserOptions): string { - return new PLpgSQLDeparser(options).deparseFunction(func); + static deparseFunction(func: PLpgSQL_function, options?: PLpgSQLDeparserOptions, returnInfo?: ReturnInfo): string { + return new PLpgSQLDeparser(options).deparseFunction(func, returnInfo); } /** * Deparse a complete PL/pgSQL parse result + * @param parseResult - The PL/pgSQL parse result + * @param returnInfo - Optional return type info for correct RETURN statement handling */ - deparseResult(parseResult: PLpgSQLParseResult): string { + deparseResult(parseResult: PLpgSQLParseResult, returnInfo?: ReturnInfo): string { if (!parseResult.plpgsql_funcs || parseResult.plpgsql_funcs.length === 0) { return ''; } return parseResult.plpgsql_funcs - .map(func => this.deparseFunctionNode(func)) + .map(func => this.deparseFunctionNode(func, returnInfo)) .join(this.options.newline + this.options.newline); } /** * Deparse a PLpgSQL_function node wrapper + * @param node - The PLpgSQL_function node wrapper + * @param returnInfo - Optional return type info for correct RETURN statement handling */ - deparseFunctionNode(node: PLpgSQLFunctionNode): string { + deparseFunctionNode(node: PLpgSQLFunctionNode, returnInfo?: ReturnInfo): string { if ('PLpgSQL_function' in node) { - return this.deparseFunction(node.PLpgSQL_function); + return this.deparseFunction(node.PLpgSQL_function, returnInfo); } throw new Error('Unknown function node type'); } /** * Deparse a PL/pgSQL function body + * @param func - The PL/pgSQL function AST + * @param returnInfo - Optional return type info for correct RETURN statement handling */ - deparseFunction(func: PLpgSQL_function): string { + deparseFunction(func: PLpgSQL_function, returnInfo?: ReturnInfo): string { const context: PLpgSQLDeparserContext = { indentLevel: 0, options: this.options, datums: func.datums, + returnInfo, }; const parts: string[] = []; @@ -930,6 +956,13 @@ export class PLpgSQLDeparser { /** * Deparse a RETURN statement + * + * PostgreSQL requires different RETURN syntax based on function type: + * - void/setof/trigger/out_params: bare RETURN is valid + * - scalar: RETURN NULL is required for empty returns + * + * When returnInfo is provided in context, we use it to determine the correct syntax. + * When not provided, we fall back to heuristics that scan the function body. */ private deparseReturn(ret: PLpgSQL_stmt_return, context: PLpgSQLDeparserContext): string { const kw = this.keyword; @@ -943,6 +976,19 @@ export class PLpgSQLDeparser { return `${kw('RETURN')} ${varName}`; } + // Empty RETURN - need to determine if we should output bare RETURN or RETURN NULL + // Use context.returnInfo if available, otherwise use heuristics + if (context.returnInfo) { + // Context-based: use the provided return type info + if (context.returnInfo.kind === 'scalar') { + return `${kw('RETURN')} ${kw('NULL')}`; + } + // void, setof, trigger, out_params all use bare RETURN + return kw('RETURN'); + } + + // Heuristic fallback: bare RETURN is the safest default + // This maintains backward compatibility for callers that don't provide returnInfo return kw('RETURN'); } diff --git a/packages/plpgsql-deparser/test-utils/index.ts b/packages/plpgsql-deparser/test-utils/index.ts index c5c8041b..b1c1954a 100644 --- a/packages/plpgsql-deparser/test-utils/index.ts +++ b/packages/plpgsql-deparser/test-utils/index.ts @@ -178,19 +178,27 @@ export const transform = (obj: any, props: any): any => { copy = {}; for (const attr in obj) { if (obj.hasOwnProperty(attr)) { + let value: any; if (props.hasOwnProperty(attr)) { if (typeof props[attr] === 'function') { - copy[attr] = props[attr](obj[attr]); + value = props[attr](obj[attr]); } else if (props[attr].hasOwnProperty(obj[attr])) { - copy[attr] = props[attr][obj[attr]]; + value = props[attr][obj[attr]]; } else { - copy[attr] = transform(obj[attr], props); + value = transform(obj[attr], props); } } else { - copy[attr] = transform(obj[attr], props); + value = transform(obj[attr], props); + } + // Skip undefined values to normalize "missing vs present-but-undefined" + if (value !== undefined) { + copy[attr] = value; } } else { - copy[attr] = transform(obj[attr], props); + const value = transform(obj[attr], props); + if (value !== undefined) { + copy[attr] = value; + } } } return copy; diff --git a/packages/plpgsql-parser/__tests__/return-info.test.ts b/packages/plpgsql-parser/__tests__/return-info.test.ts new file mode 100644 index 00000000..524a3ce0 --- /dev/null +++ b/packages/plpgsql-parser/__tests__/return-info.test.ts @@ -0,0 +1,149 @@ +import { loadModule, parseSync } from '@libpg-query/parser'; +import { getReturnInfo, getReturnInfoFromParsedFunction } from '../src/return-info'; + +describe('getReturnInfo', () => { + beforeAll(async () => { + await loadModule(); + }); + + const parseCreateFunction = (sql: string): any => { + const result = parseSync(sql); + const stmt = result.stmts?.[0]?.stmt as any; + return stmt?.CreateFunctionStmt; + }; + + describe('void functions', () => { + it('should return void for RETURNS void', () => { + const stmt = parseCreateFunction(` + CREATE FUNCTION test_void() RETURNS void + LANGUAGE plpgsql AS $$ BEGIN END; $$ + `); + expect(getReturnInfo(stmt)).toEqual({ kind: 'void' }); + }); + + it('should return void for procedures', () => { + const stmt = parseCreateFunction(` + CREATE PROCEDURE test_proc() + LANGUAGE plpgsql AS $$ BEGIN END; $$ + `); + expect(getReturnInfo(stmt)).toEqual({ kind: 'void' }); + }); + }); + + describe('setof functions', () => { + it('should return setof for RETURNS SETOF integer', () => { + const stmt = parseCreateFunction(` + CREATE FUNCTION test_setof() RETURNS SETOF integer + LANGUAGE plpgsql AS $$ BEGIN RETURN NEXT 1; END; $$ + `); + expect(getReturnInfo(stmt)).toEqual({ kind: 'setof' }); + }); + + it('should return setof for RETURNS SETOF record', () => { + const stmt = parseCreateFunction(` + CREATE FUNCTION test_setof_record() RETURNS SETOF record + LANGUAGE plpgsql AS $$ BEGIN END; $$ + `); + expect(getReturnInfo(stmt)).toEqual({ kind: 'setof' }); + }); + }); + + describe('trigger functions', () => { + it('should return trigger for RETURNS trigger', () => { + const stmt = parseCreateFunction(` + CREATE FUNCTION test_trigger() RETURNS trigger + LANGUAGE plpgsql AS $$ BEGIN RETURN NEW; END; $$ + `); + expect(getReturnInfo(stmt)).toEqual({ kind: 'trigger' }); + }); + }); + + describe('out_params functions', () => { + it('should return out_params for OUT parameters', () => { + const stmt = parseCreateFunction(` + CREATE FUNCTION test_out(IN x integer, OUT result integer) + LANGUAGE plpgsql AS $$ BEGIN result := x * 2; END; $$ + `); + expect(getReturnInfo(stmt)).toEqual({ kind: 'out_params' }); + }); + + it('should return out_params for INOUT parameters', () => { + const stmt = parseCreateFunction(` + CREATE FUNCTION test_inout(INOUT x integer) + LANGUAGE plpgsql AS $$ BEGIN x := x * 2; END; $$ + `); + expect(getReturnInfo(stmt)).toEqual({ kind: 'out_params' }); + }); + + it('should return out_params for RETURNS TABLE', () => { + const stmt = parseCreateFunction(` + CREATE FUNCTION test_table() RETURNS TABLE (id integer, name text) + LANGUAGE plpgsql AS $$ BEGIN RETURN QUERY SELECT 1, 'test'; END; $$ + `); + expect(getReturnInfo(stmt)).toEqual({ kind: 'out_params' }); + }); + }); + + describe('scalar functions', () => { + it('should return scalar for RETURNS integer', () => { + const stmt = parseCreateFunction(` + CREATE FUNCTION test_scalar() RETURNS integer + LANGUAGE plpgsql AS $$ BEGIN RETURN 42; END; $$ + `); + expect(getReturnInfo(stmt)).toEqual({ kind: 'scalar' }); + }); + + it('should return scalar for RETURNS text', () => { + const stmt = parseCreateFunction(` + CREATE FUNCTION test_text() RETURNS text + LANGUAGE plpgsql AS $$ BEGIN RETURN 'hello'; END; $$ + `); + expect(getReturnInfo(stmt)).toEqual({ kind: 'scalar' }); + }); + + it('should return scalar for RETURNS record (non-setof)', () => { + const stmt = parseCreateFunction(` + CREATE FUNCTION test_record() RETURNS record + LANGUAGE plpgsql AS $$ BEGIN RETURN ROW(1, 'test'); END; $$ + `); + expect(getReturnInfo(stmt)).toEqual({ kind: 'scalar' }); + }); + }); + + describe('edge cases', () => { + it('should return scalar for null input', () => { + expect(getReturnInfo(null)).toEqual({ kind: 'scalar' }); + }); + + it('should return scalar for undefined input', () => { + expect(getReturnInfo(undefined)).toEqual({ kind: 'scalar' }); + }); + + it('should return void for missing return type', () => { + const stmt = { funcname: [{ String: { sval: 'test' } }] }; + expect(getReturnInfo(stmt)).toEqual({ kind: 'void' }); + }); + }); +}); + +describe('getReturnInfoFromParsedFunction', () => { + beforeAll(async () => { + await loadModule(); + }); + + it('should extract return info from ParsedFunction-like object', () => { + const result = parseSync(` + CREATE FUNCTION test() RETURNS integer + LANGUAGE plpgsql AS $$ BEGIN RETURN 1; END; $$ + `); + const stmtWrapper = result.stmts?.[0]?.stmt as any; + const stmt = stmtWrapper?.CreateFunctionStmt; + + const parsedFunction = { stmt }; + expect(getReturnInfoFromParsedFunction(parsedFunction)).toEqual({ kind: 'scalar' }); + }); + + it('should return scalar for null input', () => { + expect(getReturnInfoFromParsedFunction(null)).toEqual({ kind: 'scalar' }); + }); +}); diff --git a/packages/plpgsql-parser/src/index.ts b/packages/plpgsql-parser/src/index.ts index 533e673e..eaf9d365 100644 --- a/packages/plpgsql-parser/src/index.ts +++ b/packages/plpgsql-parser/src/index.ts @@ -11,13 +11,16 @@ export { type PLpgSQLNodeTag, type WalkOptions } from './traverse'; +export { getReturnInfo, getReturnInfoFromParsedFunction } from './return-info'; export { hydratePlpgsqlAst, dehydratePlpgsqlAst, deparseSync as deparsePlpgsqlBody, isHydratedExpr, - getOriginalQuery + getOriginalQuery, + ReturnInfo, + ReturnInfoKind } from 'plpgsql-deparser'; export { deparse as deparseSql, Deparser } from 'pgsql-deparser'; diff --git a/packages/plpgsql-parser/src/return-info.ts b/packages/plpgsql-parser/src/return-info.ts new file mode 100644 index 00000000..e7777c86 --- /dev/null +++ b/packages/plpgsql-parser/src/return-info.ts @@ -0,0 +1,103 @@ +import type { ReturnInfo, ReturnInfoKind } from 'plpgsql-deparser'; + +/** + * Extract return type information from a CreateFunctionStmt AST node. + * + * This helper analyzes the function's return type and parameters to determine + * the correct ReturnInfo for the PL/pgSQL deparser. + * + * @param createFunctionStmt - The CreateFunctionStmt AST node + * @returns ReturnInfo object with the appropriate kind + */ +export function getReturnInfo(createFunctionStmt: any): ReturnInfo { + if (!createFunctionStmt) { + return { kind: 'scalar' }; + } + + // Check if it's a procedure (procedures have implicit void return) + if (createFunctionStmt.is_procedure) { + return { kind: 'void' }; + } + + // Check for OUT/INOUT/TABLE parameters - these indicate out_params return type + if (createFunctionStmt.parameters && Array.isArray(createFunctionStmt.parameters)) { + const hasOutParams = createFunctionStmt.parameters.some((param: any) => { + const fp = param?.FunctionParameter; + if (!fp) return false; + const mode = fp.mode; + return mode === 'FUNC_PARAM_OUT' || + mode === 'FUNC_PARAM_INOUT' || + mode === 'FUNC_PARAM_TABLE'; + }); + if (hasOutParams) { + return { kind: 'out_params' }; + } + } + + // Check the return type + // Note: returnType is directly a TypeName object, not wrapped in { TypeName: ... } + const returnType = createFunctionStmt.returnType; + if (!returnType) { + // No return type specified - treat as void + return { kind: 'void' }; + } + + // Check for SETOF + if (returnType.setof) { + return { kind: 'setof' }; + } + + // Extract the type name + const typeName = extractTypeName(returnType); + + // Check for void + if (typeName === 'void') { + return { kind: 'void' }; + } + + // Check for trigger + if (typeName === 'trigger') { + return { kind: 'trigger' }; + } + + // Default to scalar for all other types + return { kind: 'scalar' }; +} + +/** + * Extract the type name from a TypeName AST node. + * + * @param typeName - The TypeName AST node + * @returns The type name as a lowercase string + */ +function extractTypeName(typeName: any): string { + if (!typeName?.names || !Array.isArray(typeName.names)) { + return ''; + } + + // The names array contains String nodes with sval property + // For simple types like "void", it's usually ["pg_catalog", "void"] + // For user types, it might be ["schema", "type"] or just ["type"] + const names = typeName.names + .map((n: any) => n?.String?.sval) + .filter((s: string | undefined): s is string => typeof s === 'string'); + + // Return the last name (the actual type name, not the schema) + const lastName = names[names.length - 1]; + return lastName ? lastName.toLowerCase() : ''; +} + +/** + * Get return info from a ParsedFunction object. + * + * @param parsedFunction - A ParsedFunction object from plpgsql-parser + * @returns ReturnInfo object with the appropriate kind + */ +export function getReturnInfoFromParsedFunction(parsedFunction: any): ReturnInfo { + if (!parsedFunction?.stmt) { + return { kind: 'scalar' }; + } + return getReturnInfo(parsedFunction.stmt); +} + +export type { ReturnInfo, ReturnInfoKind };