diff --git a/vtl-engine/src/main/java/fr/insee/vtl/engine/visitors/DAGBuildingVisitor.java b/vtl-engine/src/main/java/fr/insee/vtl/engine/visitors/DAGBuildingVisitor.java index c06be3fd3..068153a95 100644 --- a/vtl-engine/src/main/java/fr/insee/vtl/engine/visitors/DAGBuildingVisitor.java +++ b/vtl-engine/src/main/java/fr/insee/vtl/engine/visitors/DAGBuildingVisitor.java @@ -9,6 +9,7 @@ import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; +import org.antlr.v4.runtime.RuleContext; import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.tree.RuleNode; import org.antlr.v4.runtime.tree.TerminalNode; @@ -154,10 +155,10 @@ protected List defaultResult() { * VarIDsExtractingVisitor is the visitor for extracting the used VarIds from VTL * statements. */ - private static class IdentifierExtractingVisitor - extends VtlBaseVisitor> { + static class IdentifierExtractingVisitor extends VtlBaseVisitor> { private final Set ignoreInnerScopedVarIdentifiers; + private int componentContextDepth = 0; public IdentifierExtractingVisitor() { this(Set.of()); @@ -168,51 +169,127 @@ public IdentifierExtractingVisitor(Set ignoreInnerScopedVarIdentifiers) this.ignoreInnerScopedVarIdentifiers = ignoreInnerScopedVarIdentifiers; } + private Set enterRestrictedContext(RuleContext ctx) { + componentContextDepth++; + Set result = super.visitChildren(ctx); + componentContextDepth--; + return result; + } + @Override public Set visitVarID(VtlParser.VarIDContext node) { final var currentVarIdentifier = node.IDENTIFIER().getSymbol().getText(); - final Set thisResult = - ignoreInnerScopedVarIdentifiers.contains(currentVarIdentifier) - ? 
Set.of() - : Set.of( - new DAGStatement.Identifier( - DAGStatement.Identifier.Type.VARIABLE, currentVarIdentifier)); - final Set subResult = this.visitChildren(node); - return aggregateResult(thisResult, subResult); + + // If we are inside a component context (depth > 0), we ignore this identifier + // because it is a component name, not a dataset dependency, according to the unadjusted VTL + // syntax + if (componentContextDepth > 0) { + return Set.of(); + } + + return ignoreInnerScopedVarIdentifiers.contains(currentVarIdentifier) + ? Set.of() + : Set.of( + new DAGStatement.Identifier( + DAGStatement.Identifier.Type.VARIABLE, currentVarIdentifier)); } + // Workaround for https://github.com/InseeFr/Trevas/issues/457, as long as the open points + // there are not clarified and https://github.com/InseeFr/Trevas/issues/355 is not implemented: + // the clause visitors below descend in a component context (depth > 0), so identifiers found + // inside them are ignored because they are component names, not dataset dependencies, + // according to the unadjusted VTL syntax @Override - public Set visitOperatorID(VtlParser.OperatorIDContext node) { - final Set thisResult = - Set.of( - new DAGStatement.Identifier( - DAGStatement.Identifier.Type.OPERATOR, node.IDENTIFIER().getSymbol().getText())); - final Set subResult = this.visitChildren(node); - return aggregateResult(thisResult, subResult); + public Set visitFilterClause(VtlParser.FilterClauseContext ctx) { + return enterRestrictedContext(ctx); + } + + @Override + public Set visitCalcClauseItem(VtlParser.CalcClauseItemContext ctx) { + return enterRestrictedContext(ctx); + } + + @Override + public Set visitHavingClause(VtlParser.HavingClauseContext ctx) { + return enterRestrictedContext(ctx); + } + + @Override + public Set visitJoinApplyClause(VtlParser.JoinApplyClauseContext ctx) { + return enterRestrictedContext(ctx); + } + + @Override + public Set visitAggrFunctionClause( + VtlParser.AggrFunctionClauseContext ctx) { + return enterRestrictedContext(ctx); + } + + // 
Aggregate & Analytic Functions (First arg is Dataset, rest are components) + @Override + public Set visitAggrDataset(VtlParser.AggrDatasetContext ctx) { + Set datasetRef = visit(ctx.expr()); + return aggregateResult(datasetRef, enterRestrictedContext(ctx)); + } + + @Override + public Set visitAnSimpleFunction( + VtlParser.AnSimpleFunctionContext ctx) { + Set datasetRef = visit(ctx.expr()); + return aggregateResult(datasetRef, enterRestrictedContext(ctx)); + } + + @Override + public Set visitLagOrLeadAn(VtlParser.LagOrLeadAnContext ctx) { + Set datasetRef = visit(ctx.expr()); + return aggregateResult(datasetRef, enterRestrictedContext(ctx)); + } + + @Override + public Set visitRatioToReportAn(VtlParser.RatioToReportAnContext ctx) { + Set datasetRef = visit(ctx.expr()); + return aggregateResult(datasetRef, enterRestrictedContext(ctx)); + } + + @Override + public Set visitRankAn(VtlParser.RankAnContext ctx) { + + return enterRestrictedContext(ctx); + } + + @Override + public Set visitMembershipExpr(VtlParser.MembershipExprContext ctx) { + // Only visit the dataset (left side), ignore the component (right side) + return visit(ctx.expr()); } @Override public Set visitValidateDPruleset( VtlParser.ValidateDPrulesetContext node) { - final Set thisResult = + Set rulesetRef = Set.of( new DAGStatement.Identifier( DAGStatement.Identifier.Type.RULESET_DATAPOINT, node.IDENTIFIER().getSymbol().getText())); - final Set subResult = this.visitChildren(node); - return aggregateResult(thisResult, subResult); + return aggregateResult(rulesetRef, visit(node.expr())); } @Override public Set visitValidateHRruleset( VtlParser.ValidateHRrulesetContext node) { - final Set thisResult = + Set rulesetRef = Set.of( new DAGStatement.Identifier( DAGStatement.Identifier.Type.RULESET_HIERARCHICAL, node.IDENTIFIER().getSymbol().getText())); - final Set subResult = this.visitChildren(node); - return aggregateResult(thisResult, subResult); + return aggregateResult(rulesetRef, visit(node.expr())); + } + 
+ @Override + public Set visitOperatorID(VtlParser.OperatorIDContext node) { + return Set.of( + new DAGStatement.Identifier( + DAGStatement.Identifier.Type.OPERATOR, node.IDENTIFIER().getSymbol().getText())); } @Override diff --git a/vtl-engine/src/test/java/fr/insee/vtl/engine/utils/dag/DagDefineStatementsTest.java b/vtl-engine/src/test/java/fr/insee/vtl/engine/utils/dag/DagDefineStatementsTest.java index 3b27ed62e..8881fa042 100644 --- a/vtl-engine/src/test/java/fr/insee/vtl/engine/utils/dag/DagDefineStatementsTest.java +++ b/vtl-engine/src/test/java/fr/insee/vtl/engine/utils/dag/DagDefineStatementsTest.java @@ -6,10 +6,12 @@ import fr.insee.vtl.model.exceptions.VtlScriptException; import fr.insee.vtl.parser.VtlLexer; import fr.insee.vtl.parser.VtlParser; -import java.util.*; +import java.util.Set; import java.util.stream.Stream; import javax.script.ScriptException; -import org.antlr.v4.runtime.*; +import org.antlr.v4.runtime.CharStreams; +import org.antlr.v4.runtime.CodePointCharStream; +import org.antlr.v4.runtime.CommonTokenStream; import org.antlr.v4.runtime.tree.ParseTree; import org.antlr.v4.runtime.tree.RuleNode; import org.antlr.v4.runtime.tree.TerminalNode; @@ -21,15 +23,19 @@ public class DagDefineStatementsTest { private static String performDagReordering(final String script, final Set bindingVars) throws VtlScriptException { - final CodePointCharStream stream = CharStreams.fromString(script); - VtlLexer lexer = new VtlLexer(stream); - VtlParser parser = new VtlParser(new CommonTokenStream(lexer)); - var start = parser.start(); + VtlParser.StartContext start = parseScript(script); VtlSyntaxPreprocessor syntaxPreprocessor = new VtlSyntaxPreprocessor(start, bindingVars); VtlParser.StartContext res = syntaxPreprocessor.checkForMultipleAssignmentsAndReorderScript(); return parseTreeToText(res); } + public static VtlParser.StartContext parseScript(String script) { + final CodePointCharStream stream = CharStreams.fromString(script); + VtlLexer lexer = new 
VtlLexer(stream); + VtlParser parser = new VtlParser(new CommonTokenStream(lexer)); + return parser.start(); + } + private static String parseTreeToText(ParseTree child) { StringBuilder result = new StringBuilder(); if (child instanceof TerminalNode) { diff --git a/vtl-engine/src/test/java/fr/insee/vtl/engine/utils/dag/DagTest.java b/vtl-engine/src/test/java/fr/insee/vtl/engine/utils/dag/DagTest.java index fcb4a6037..ab9da8ab8 100644 --- a/vtl-engine/src/test/java/fr/insee/vtl/engine/utils/dag/DagTest.java +++ b/vtl-engine/src/test/java/fr/insee/vtl/engine/utils/dag/DagTest.java @@ -297,8 +297,6 @@ void testMultipleCycles() { void testDagDoubleAssignment() { ScriptContext context = engine.getContext(); context.setAttribute("a", 1L, ScriptContext.ENGINE_SCOPE); - // Note that the double assignment is not detected while building the DAG but later during - // execution assertThatThrownBy(() -> engine.eval("b := a; b := 1;")) .isInstanceOf(VtlScriptException.class) .hasMessage("Dataset b has already been assigned"); diff --git a/vtl-engine/src/test/java/fr/insee/vtl/engine/visitors/IdentifierExtractingVisitorTest.java b/vtl-engine/src/test/java/fr/insee/vtl/engine/visitors/IdentifierExtractingVisitorTest.java new file mode 100644 index 000000000..dd533e9e9 --- /dev/null +++ b/vtl-engine/src/test/java/fr/insee/vtl/engine/visitors/IdentifierExtractingVisitorTest.java @@ -0,0 +1,207 @@ +package fr.insee.vtl.engine.visitors; + +import static fr.insee.vtl.engine.utils.dag.DagDefineStatementsTest.parseScript; +import static org.assertj.core.api.Assertions.assertThat; + +import fr.insee.vtl.engine.utils.dag.DAGStatement; +import java.util.Set; +import java.util.stream.Collectors; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.Test; + +/* +As a workaround for https://github.com/InseeFr/Trevas/issues/457, as long as the open points there +are not clarified and https://github.com/InseeFr/Trevas/issues/355 is not implemented, +the visitor is currently 
skipping varIds which would be components in the unadjusted VTL grammar +*/ +class IdentifierExtractingVisitorTest { + + private final DAGBuildingVisitor.IdentifierExtractingVisitor visitor = + new DAGBuildingVisitor.IdentifierExtractingVisitor(); + + private static @NotNull Set getAllNames(Set result) { + return result.stream().map(DAGStatement.Identifier::name).collect(Collectors.toSet()); + } + + @Test + void testBasicAssignment() { + String script = "ds1 := ds2;"; + Set result = visitor.visit(parseScript(script)); + + assertThat(result) + .containsExactlyInAnyOrder( + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds1"), + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds2")); + } + + @Test + void testFilterClause() { + String script = "ds1 := ds2 [filter component_a > 10];"; + Set result = visitor.visit(parseScript(script)); + + assertThat(result) + .containsExactlyInAnyOrder( + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds1"), + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds2")); + + assertThat(getAllNames(result)).doesNotContain("component_a"); + } + + @Test + void testCalcClause() { + String script = "ds1 := ds2 [calc comp_new := comp_old + 5];"; + Set result = visitor.visit(parseScript(script)); + + assertThat(result) + .containsExactlyInAnyOrder( + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds1"), + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds2")); + + assertThat(getAllNames(result)).doesNotContain("comp_new", "comp_old"); + } + + @Test + void testMembershipExpression() { + String script = "ds1 := ds2#component_a;"; + Set result = visitor.visit(parseScript(script)); + + assertThat(result) + .containsExactlyInAnyOrder( + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds1"), + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds2")); + + 
assertThat(getAllNames(result)).doesNotContain("component_a"); + } + + @Test + void testMin() { + String script = "ds1 := min(ds2);"; + Set result = visitor.visit(parseScript(script)); + + assertThat(result) + .containsExactlyInAnyOrder( + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds1"), + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds2")); + } + + @Test + void testAnalyticFunctions() { + String script = "ds1 := ds2[calc x := lag(other, 1) over (partition by comp_a)];"; + Set result = visitor.visit(parseScript(script)); + + assertThat(result) + .containsExactlyInAnyOrder( + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds1"), + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds2")); + + assertThat(getAllNames(result)).doesNotContain("comp_a", "x", "other"); + } + + @Test + void testAggrClause() { + String script = "ds1 := ds2 [aggregate comp_sum := sum(comp_val) group by comp_grp];"; + Set result = visitor.visit(parseScript(script)); + + assertThat(result) + .containsExactlyInAnyOrder( + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds1"), + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds2")); + + assertThat(getAllNames(result)).doesNotContain("comp_sum", "comp_val", "comp_grp"); + } + + @Test + void testJoinWithAliases() { + String script = "ds1 := inner_join(ds2 as a, ds3 as b using comp_id);"; + Set result = visitor.visit(parseScript(script)); + + assertThat(result) + .containsExactlyInAnyOrder( + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds1"), + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds2"), + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds3")); + assertThat(getAllNames(result)).doesNotContain("a", "b", "comp_id"); + } + + @Test + void testValidationRulesets() { + String script = "ds1 := check_datapoint(ds2, my_ruleset);"; + Set result = 
visitor.visit(parseScript(script)); + + assertThat(result) + .containsExactlyInAnyOrder( + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds1"), + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds2"), + new DAGStatement.Identifier( + DAGStatement.Identifier.Type.RULESET_DATAPOINT, "my_ruleset")); + } + + @Test + void testHierarchicalValidation() { + String script = "ds1 := check_hierarchy(ds2, my_h_ruleset);"; + Set result = visitor.visit(parseScript(script)); + + assertThat(result) + .containsExactlyInAnyOrder( + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds1"), + new DAGStatement.Identifier(DAGStatement.Identifier.Type.VARIABLE, "ds2"), + new DAGStatement.Identifier( + DAGStatement.Identifier.Type.RULESET_HIERARCHICAL, "my_h_ruleset")); + } + + @Test + void testHavingClause() { + String script = "ds1 := ds2 [aggregate sum_val := sum(val) group by grp having sum_val > 0];"; + Set result = visitor.visit(parseScript(script)); + + assertThat(getAllNames(result)).containsExactlyInAnyOrder("ds1", "ds2"); + assertThat(getAllNames(result)).doesNotContain("sum_val", "val", "grp"); + } + + @Test + void testJoinApply() { + String script = "ds1 := inner_join(ds2 as a, ds3 as b apply comp_new := comp_old * 2);"; + Set result = visitor.visit(parseScript(script)); + + assertThat(getAllNames(result)).containsExactlyInAnyOrder("ds1", "ds2", "ds3"); + assertThat(getAllNames(result)).doesNotContain("comp_new", "comp_old", "a", "b"); + } + + @Test + void testRankAnalytic() { + String script = "ds1 := ds2 [calc r := rank() over (partition by p order by o)];"; + Set result = visitor.visit(parseScript(script)); + + assertThat(getAllNames(result)).containsExactlyInAnyOrder("ds1", "ds2"); + assertThat(getAllNames(result)).doesNotContain("r", "p", "o"); + } + + @Test + void testRatioToReportAnalytic() { + String script = "ds1 := ratio_to_report(ds2, comp_a) over (partition by comp_b);"; + Set result = 
visitor.visit(parseScript(script)); + + assertThat(getAllNames(result)).containsExactlyInAnyOrder("ds1", "ds2"); + assertThat(getAllNames(result)).doesNotContain("comp_a", "comp_b"); + } + + @Test + void testIgnoreInnerScopedVars() { + var customVisitor = new DAGBuildingVisitor.IdentifierExtractingVisitor(Set.of("ds_ignored")); + String script = "ds1 := ds2 + ds_ignored;"; + Set result = customVisitor.visit(parseScript(script)); + + assertThat(getAllNames(result)).containsExactlyInAnyOrder("ds1", "ds2"); + assertThat(getAllNames(result)).doesNotContain("ds_ignored"); + } + + @Test + void testComplexChainedClauses() { + String script = + "ds1 := ds2 [filter c1 > 0] [calc c2 := c1 + 1] [calc c3 := sum(c2) over (partition by c1)];"; + Set result = visitor.visit(parseScript(script)); + + assertThat(getAllNames(result)).containsExactlyInAnyOrder("ds1", "ds2"); + assertThat(getAllNames(result)).doesNotContain("c1", "c2", "c3"); + } +} diff --git a/vtl-engine/src/test/java/fr/insee/vtl/engine/visitors/expression/functions/SetFunctionsVisitorTest.java b/vtl-engine/src/test/java/fr/insee/vtl/engine/visitors/expression/functions/SetFunctionsVisitorTest.java index 510b58d6d..42852a213 100644 --- a/vtl-engine/src/test/java/fr/insee/vtl/engine/visitors/expression/functions/SetFunctionsVisitorTest.java +++ b/vtl-engine/src/test/java/fr/insee/vtl/engine/visitors/expression/functions/SetFunctionsVisitorTest.java @@ -260,7 +260,6 @@ public void testUnion456Issue() throws ScriptException { Arrays.asList("T01", "LOOP-02", "foo12", "foo1"), Arrays.asList("T02", null, "foo21", "foo2")); - engine.put("$vtl.engine.use_dag", "false"); ScriptContext context = engine.getContext(); context.getBindings(ScriptContext.ENGINE_SCOPE).put("MULTIMODE", multimodeDs); diff --git a/vtl-spark/src/main/java/fr/insee/vtl/spark/SparkProcessingEngine.java b/vtl-spark/src/main/java/fr/insee/vtl/spark/SparkProcessingEngine.java index a1a7a0064..a24e52591 100644 --- 
a/vtl-spark/src/main/java/fr/insee/vtl/spark/SparkProcessingEngine.java +++ b/vtl-spark/src/main/java/fr/insee/vtl/spark/SparkProcessingEngine.java @@ -164,6 +164,9 @@ private SparkDataset asSparkDataset(DatasetExpression expression) { return datasetExpression.resolve(Map.of()); } else { var dataset = expression.resolve(Map.of()); + if (dataset instanceof PersistentDataset persistentDataset) { + dataset = persistentDataset.getDelegate(); + } if (dataset instanceof SparkDataset sparkDataset) { return sparkDataset; } else {