Skip to content

Commit 89a0621

Browse files
committed
fix: improve loop variable handling and add plpgsql-pretty snapshot tests
- Fix CALL statement deparser to avoid duplicate CALL keyword - Improve loop variable detection to only exclude implicitly declared vars (variables where lineno matches the loop statement's lineno) - Add plpgsql-pretty fixtures folder with sample SQL files - Add PlpgsqlPrettyTest utility for snapshot testing - Add snapshot tests for uppercase/lowercase formatting
1 parent 7354e55 commit 89a0621

File tree

8 files changed

+317
-12
lines changed

8 files changed

+317
-12
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
CREATE FUNCTION check_value(val integer) RETURNS text
2+
LANGUAGE plpgsql
3+
AS $$
4+
BEGIN
5+
IF val > 100 THEN
6+
RETURN 'large';
7+
ELSIF val > 10 THEN
8+
RETURN 'medium';
9+
ELSE
10+
RETURN 'small';
11+
END IF;
12+
END;
13+
$$
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
CREATE FUNCTION sum_to_n(n integer) RETURNS integer
2+
LANGUAGE plpgsql
3+
AS $$
4+
DECLARE
5+
total integer := 0;
6+
i integer;
7+
BEGIN
8+
FOR i IN 1..n LOOP
9+
total := total + i;
10+
END LOOP;
11+
RETURN total;
12+
END;
13+
$$
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
CREATE FUNCTION simple_add(a integer, b integer) RETURNS integer
2+
LANGUAGE plpgsql
3+
AS $$
4+
BEGIN
5+
RETURN a + b;
6+
END;
7+
$$
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Jest Snapshot v1, https://jestjs.io/docs/snapshot-testing
2+
3+
exports[`lowercase: if-else-function.sql 1`] = `
4+
"begin
5+
if val > 100 then
6+
return 'large';
7+
elsif val > 10 then
8+
return 'medium';
9+
else
10+
return 'small';
11+
end if;
12+
return;
13+
end"
14+
`;
15+
16+
exports[`lowercase: loop-function.sql 1`] = `
17+
"declare
18+
total integer := 0;
19+
i integer;
20+
begin
21+
for i in 1..n loop
22+
total := total + i;
23+
end loop;
24+
return total;
25+
end"
26+
`;
27+
28+
exports[`lowercase: simple-function.sql 1`] = `
29+
"begin
30+
return a + b;
31+
end"
32+
`;
33+
34+
exports[`uppercase: if-else-function.sql 1`] = `
35+
"BEGIN
36+
IF val > 100 THEN
37+
RETURN 'large';
38+
ELSIF val > 10 THEN
39+
RETURN 'medium';
40+
ELSE
41+
RETURN 'small';
42+
END IF;
43+
RETURN;
44+
END"
45+
`;
46+
47+
exports[`uppercase: loop-function.sql 1`] = `
48+
"DECLARE
49+
total integer := 0;
50+
i integer;
51+
BEGIN
52+
FOR i IN 1..n LOOP
53+
total := total + i;
54+
END LOOP;
55+
RETURN total;
56+
END"
57+
`;
58+
59+
exports[`uppercase: simple-function.sql 1`] = `
60+
"BEGIN
61+
RETURN a + b;
62+
END"
63+
`;
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import { PlpgsqlPrettyTest } from '../../test-utils';
2+
3+
const prettyTest = new PlpgsqlPrettyTest([
4+
'simple-function.sql',
5+
'if-else-function.sql',
6+
'loop-function.sql',
7+
]);
8+
9+
prettyTest.generateTests();

packages/plpgsql-deparser/src/plpgsql-deparser.ts

Lines changed: 173 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,14 @@ export class PLpgSQLDeparser {
137137

138138
const parts: string[] = [];
139139

140-
// Deparse DECLARE section (local variables)
141-
const declareSection = this.deparseDeclareSection(func.datums, context);
140+
// Collect loop-introduced variables before generating DECLARE section
141+
const loopVarLinenos = new Set<number>();
142+
if (func.action) {
143+
this.collectLoopVariables(func.action, loopVarLinenos);
144+
}
145+
146+
// Deparse DECLARE section (local variables, excluding loop variables)
147+
const declareSection = this.deparseDeclareSection(func.datums, context, loopVarLinenos);
142148
if (declareSection) {
143149
parts.push(declareSection);
144150
}
@@ -151,15 +157,162 @@ export class PLpgSQLDeparser {
151157
return parts.join(this.options.newline);
152158
}
153159

160+
/**
161+
* Collect line numbers of variables introduced by loop constructs.
162+
* Only adds a variable's lineno if it matches the loop statement's lineno,
163+
* indicating the variable was implicitly declared by the loop (not explicitly in DECLARE).
164+
*/
165+
private collectLoopVariables(stmt: PLpgSQLStmtNode, loopVarLinenos: Set<number>): void {
166+
if ('PLpgSQL_stmt_block' in stmt) {
167+
const block = stmt.PLpgSQL_stmt_block;
168+
if (block.body) {
169+
for (const s of block.body) {
170+
this.collectLoopVariables(s, loopVarLinenos);
171+
}
172+
}
173+
} else if ('PLpgSQL_stmt_fori' in stmt) {
174+
// Integer FOR loop - only exclude if var.lineno matches stmt.lineno (implicit declaration)
175+
const fori = stmt.PLpgSQL_stmt_fori;
176+
const stmtLineno = fori.lineno;
177+
if (fori.var && 'PLpgSQL_var' in fori.var) {
178+
const varLineno = fori.var.PLpgSQL_var.lineno;
179+
if (varLineno !== undefined && varLineno === stmtLineno) {
180+
loopVarLinenos.add(varLineno);
181+
}
182+
}
183+
if (fori.body) {
184+
for (const s of fori.body) {
185+
this.collectLoopVariables(s, loopVarLinenos);
186+
}
187+
}
188+
} else if ('PLpgSQL_stmt_fors' in stmt) {
189+
// Query FOR loop - only exclude if var.lineno matches stmt.lineno (implicit declaration)
190+
const fors = stmt.PLpgSQL_stmt_fors;
191+
const stmtLineno = fors.lineno;
192+
if (fors.var && 'PLpgSQL_rec' in fors.var) {
193+
const varLineno = fors.var.PLpgSQL_rec.lineno;
194+
if (varLineno !== undefined && varLineno === stmtLineno) {
195+
loopVarLinenos.add(varLineno);
196+
}
197+
}
198+
if (fors.var && 'PLpgSQL_row' in fors.var) {
199+
const varLineno = fors.var.PLpgSQL_row.lineno;
200+
if (varLineno !== undefined && varLineno === stmtLineno) {
201+
loopVarLinenos.add(varLineno);
202+
}
203+
}
204+
if (fors.body) {
205+
for (const s of fors.body) {
206+
this.collectLoopVariables(s, loopVarLinenos);
207+
}
208+
}
209+
} else if ('PLpgSQL_stmt_forc' in stmt) {
210+
// Cursor FOR loop - only exclude if var.lineno matches stmt.lineno (implicit declaration)
211+
const forc = stmt.PLpgSQL_stmt_forc;
212+
const stmtLineno = forc.lineno;
213+
if (forc.var && 'PLpgSQL_rec' in forc.var) {
214+
const varLineno = forc.var.PLpgSQL_rec.lineno;
215+
if (varLineno !== undefined && varLineno === stmtLineno) {
216+
loopVarLinenos.add(varLineno);
217+
}
218+
}
219+
if (forc.body) {
220+
for (const s of forc.body) {
221+
this.collectLoopVariables(s, loopVarLinenos);
222+
}
223+
}
224+
} else if ('PLpgSQL_stmt_foreach_a' in stmt) {
225+
// FOREACH loop - uses varno reference, not embedded var
226+
// The variable is referenced by index, so we can't easily exclude it here
227+
// Just recurse into the body
228+
const foreach = stmt.PLpgSQL_stmt_foreach_a;
229+
if (foreach.body) {
230+
for (const s of foreach.body) {
231+
this.collectLoopVariables(s, loopVarLinenos);
232+
}
233+
}
234+
} else if ('PLpgSQL_stmt_dynfors' in stmt) {
235+
// Dynamic FOR loop - only exclude if var.lineno matches stmt.lineno (implicit declaration)
236+
const dynfors = stmt.PLpgSQL_stmt_dynfors;
237+
const stmtLineno = dynfors.lineno;
238+
if (dynfors.var && 'PLpgSQL_rec' in dynfors.var) {
239+
const varLineno = dynfors.var.PLpgSQL_rec.lineno;
240+
if (varLineno !== undefined && varLineno === stmtLineno) {
241+
loopVarLinenos.add(varLineno);
242+
}
243+
}
244+
if (dynfors.body) {
245+
for (const s of dynfors.body) {
246+
this.collectLoopVariables(s, loopVarLinenos);
247+
}
248+
}
249+
} else if ('PLpgSQL_stmt_if' in stmt) {
250+
const ifStmt = stmt.PLpgSQL_stmt_if;
251+
if (ifStmt.then_body) {
252+
for (const s of ifStmt.then_body) {
253+
this.collectLoopVariables(s, loopVarLinenos);
254+
}
255+
}
256+
if (ifStmt.elsif_list) {
257+
for (const elsif of ifStmt.elsif_list) {
258+
if ('PLpgSQL_if_elsif' in elsif && elsif.PLpgSQL_if_elsif.stmts) {
259+
for (const s of elsif.PLpgSQL_if_elsif.stmts) {
260+
this.collectLoopVariables(s, loopVarLinenos);
261+
}
262+
}
263+
}
264+
}
265+
if (ifStmt.else_body) {
266+
for (const s of ifStmt.else_body) {
267+
this.collectLoopVariables(s, loopVarLinenos);
268+
}
269+
}
270+
} else if ('PLpgSQL_stmt_case' in stmt) {
271+
const caseStmt = stmt.PLpgSQL_stmt_case;
272+
if (caseStmt.case_when_list) {
273+
for (const when of caseStmt.case_when_list) {
274+
if ('PLpgSQL_case_when' in when && when.PLpgSQL_case_when.stmts) {
275+
for (const s of when.PLpgSQL_case_when.stmts) {
276+
this.collectLoopVariables(s, loopVarLinenos);
277+
}
278+
}
279+
}
280+
}
281+
if (caseStmt.have_else && caseStmt.else_stmts) {
282+
for (const s of caseStmt.else_stmts) {
283+
this.collectLoopVariables(s, loopVarLinenos);
284+
}
285+
}
286+
} else if ('PLpgSQL_stmt_loop' in stmt) {
287+
const loop = stmt.PLpgSQL_stmt_loop;
288+
if (loop.body) {
289+
for (const s of loop.body) {
290+
this.collectLoopVariables(s, loopVarLinenos);
291+
}
292+
}
293+
} else if ('PLpgSQL_stmt_while' in stmt) {
294+
const whileStmt = stmt.PLpgSQL_stmt_while;
295+
if (whileStmt.body) {
296+
for (const s of whileStmt.body) {
297+
this.collectLoopVariables(s, loopVarLinenos);
298+
}
299+
}
300+
}
301+
}
302+
154303
/**
155304
* Deparse the DECLARE section
156305
*/
157-
private deparseDeclareSection(datums: PLpgSQLDatum[] | undefined, context: PLpgSQLDeparserContext): string {
306+
private deparseDeclareSection(
307+
datums: PLpgSQLDatum[] | undefined,
308+
context: PLpgSQLDeparserContext,
309+
loopVarLinenos: Set<number> = new Set()
310+
): string {
158311
if (!datums || datums.length === 0) {
159312
return '';
160313
}
161314

162-
// Filter out internal variables (like 'found', parameters, etc.)
315+
// Filter out internal variables (like 'found', parameters, etc.) and loop variables
163316
const localVars = datums.filter(datum => {
164317
if ('PLpgSQL_var' in datum) {
165318
const v = datum.PLpgSQL_var;
@@ -171,10 +324,22 @@ export class PLpgSQLDeparser {
171324
if (v.lineno === undefined) {
172325
return false;
173326
}
327+
// Skip loop-introduced variables
328+
if (loopVarLinenos.has(v.lineno)) {
329+
return false;
330+
}
174331
return true;
175332
}
176333
if ('PLpgSQL_rec' in datum) {
177-
return datum.PLpgSQL_rec.lineno !== undefined;
334+
const rec = datum.PLpgSQL_rec;
335+
if (rec.lineno === undefined) {
336+
return false;
337+
}
338+
// Skip loop-introduced records
339+
if (loopVarLinenos.has(rec.lineno)) {
340+
return false;
341+
}
342+
return true;
178343
}
179344
return false;
180345
});
@@ -1087,15 +1252,11 @@ export class PLpgSQLDeparser {
10871252
* Deparse a CALL statement
10881253
*/
10891254
private deparseCall(call: PLpgSQL_stmt_call, context: PLpgSQLDeparserContext): string {
1090-
const kw = this.keyword;
10911255
const expr = call.expr ? this.deparseExpr(call.expr) : '';
10921256

1093-
if (call.is_call) {
1094-
return `${kw('CALL')} ${expr}`;
1095-
}
1096-
1097-
// DO block
1098-
return `${kw('DO')} ${expr}`;
1257+
// The expression already contains the CALL keyword from the parser
1258+
// so we just return the expression as-is
1259+
return expr;
10991260
}
11001261

11011262
/**
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import { parsePlPgSQLSync, loadModule } from '@libpg-query/parser';
2+
import { deparseSync, PLpgSQLParseResult } from '../src';
3+
import * as fs from 'fs';
4+
import * as path from 'path';
5+
6+
export class PlpgsqlPrettyTest {
7+
private testCases: string[];
8+
private fixturesDir: string;
9+
10+
constructor(testCases: string[]) {
11+
this.testCases = testCases;
12+
this.fixturesDir = path.join(__dirname, '../../../__fixtures__/plpgsql-pretty');
13+
}
14+
15+
generateTests(): void {
16+
beforeAll(async () => {
17+
await loadModule();
18+
});
19+
20+
this.testCases.forEach((testName) => {
21+
const filePath = path.join(this.fixturesDir, testName);
22+
23+
it(`uppercase: ${testName}`, () => {
24+
const sql = fs.readFileSync(filePath, 'utf-8').trim();
25+
const result = parsePlPgSQLSync(sql) as unknown as PLpgSQLParseResult;
26+
const deparsed = deparseSync(result, { uppercase: true });
27+
expect(deparsed).toMatchSnapshot();
28+
});
29+
30+
it(`lowercase: ${testName}`, () => {
31+
const sql = fs.readFileSync(filePath, 'utf-8').trim();
32+
const result = parsePlPgSQLSync(sql) as unknown as PLpgSQLParseResult;
33+
const deparsed = deparseSync(result, { uppercase: false });
34+
expect(deparsed).toMatchSnapshot();
35+
});
36+
});
37+
}
38+
}

packages/plpgsql-deparser/test-utils/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,3 +407,4 @@ export class FixtureTestUtils extends PLpgSQLTestUtils {
407407

408408
export const testUtils = new PLpgSQLTestUtils();
409409
export const fixtureTestUtils = new FixtureTestUtils();
410+
export { PlpgsqlPrettyTest } from './PlpgsqlPrettyTest';

0 commit comments

Comments
 (0)