Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/vtlengine/API/_InternalApi.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,18 @@ def _load_dataset_from_structure(
# Support both 'type' and 'data_type' for backward compatibility
_, scalar_type = _extract_data_type(component)
if component["role"] == "ViralAttribute":
component["role"] = "Attribute"
component["role"] = "Viral Attribute"

check_key("role", Role_keys, component["role"])

if "nullable" not in component:
if Role(component["role"]) == Role.IDENTIFIER:
component["nullable"] = False
elif Role(component["role"]) in (Role.MEASURE, Role.ATTRIBUTE):
elif Role(component["role"]) in (
Role.MEASURE,
Role.ATTRIBUTE,
Role.VIRAL_ATTRIBUTE,
):
component["nullable"] = True
else:
component["nullable"] = False
Expand All @@ -142,6 +146,8 @@ def _load_dataset_from_structure(
for component in dataset_json["DataStructure"]:
# Support both 'type' and 'data_type' for backward compatibility
_, scalar_type = _extract_data_type(component)
if component["role"] == "ViralAttribute":
component["role"] = "Viral Attribute"
check_key("role", Role_keys, component["role"])
components[component["name"]] = VTL_Component(
name=component["name"],
Expand Down
2 changes: 1 addition & 1 deletion src/vtlengine/API/data/schema/json_schema_2.1.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"properties": {
"role": {
"type": "string",
"enum": [ "Identifier", "Measure", "Attribute", "Viral Attribute" ]
"enum": [ "Identifier", "Measure", "Attribute", "Viral Attribute", "ViralAttribute" ]
},
"subset": { "$ref": "#/$defs/vtl-id" },
"nullable": { "type": "boolean" },
Expand Down
101 changes: 101 additions & 0 deletions src/vtlengine/AST/ASTConstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,26 @@
Node Creator.
"""

from typing import List, Optional

from antlr4.tree.Tree import TerminalNodeImpl

from vtlengine.AST import (
AggregateVpClause,
Argument,
Assignment,
DefIdentifier,
DPRule,
DPRuleset,
EnumeratedVpClause,
HRBinOp,
HRule,
HRuleset,
HRUnOp,
Operator,
PersistentAssignment,
Start,
ViralPropagationDef,
)
from vtlengine.AST.ASTConstructorModules import extract_token_info
from vtlengine.AST.ASTConstructorModules.Expr import Expr
Expand Down Expand Up @@ -150,6 +155,102 @@ def visitDefineExpression(self, ctx: Parser.DefineExpressionContext):
elif isinstance(ctx, Parser.DefHierarchicalContext):
return self.visitDefHierarchical(ctx)

elif isinstance(ctx, Parser.DefViralPropagationContext):
return self.visitDefViralPropagation(ctx)

def visitDefViralPropagation(self, ctx: Parser.DefViralPropagationContext):
"""
DEFINE VIRAL PROPAGATION varID LPAREN vpSignature RPAREN IS
vpBody END VIRAL PROPAGATION # defViralPropagation
"""
ctx_list = list(ctx.getChildren())

propagation_name = Terminals().visitVarID(ctx_list[3]).value
signature_type, target = self.visitVpSignature(ctx_list[5])
enumerated_clauses, aggregate_clause, default_value = self.visitVpBody(ctx_list[8])

token_info = extract_token_info(ctx)

return ViralPropagationDef(
name=propagation_name,
signature_type=signature_type,
target=target,
enumerated_clauses=enumerated_clauses,
aggregate_clause=aggregate_clause,
default_value=default_value,
**token_info,
)

def visitVpSignature(self, ctx: Parser.VpSignatureContext):
"""vpSignature: VALUE_DOMAIN varID | VARIABLE varID ;"""
ctx_list = list(ctx.getChildren())
signature_type = ctx_list[0].getSymbol().text
target = Terminals().visitVarID(ctx_list[1]).value
return signature_type, target

def visitVpBody(self, ctx: Parser.VpBodyContext):
"""vpBody: vpClause (EOL vpClause)* ;"""
ctx_list = list(ctx.getChildren())
enumerated_clauses: List[EnumeratedVpClause] = []
aggregate_clause: Optional[AggregateVpClause] = None
default_value: Optional[str] = None

for child in ctx_list:
if isinstance(child, Parser.EnumeratedVpClauseContext):
enumerated_clauses.append(self.visitEnumeratedVpClause(child))
elif isinstance(child, Parser.AggregationVpClauseContext):
aggregate_clause = self.visitAggregationVpClause(child)
elif isinstance(child, Parser.DefaultVpClauseContext):
default_value = self.visitDefaultVpClause(child)

return enumerated_clauses, aggregate_clause, default_value

def visitEnumeratedVpClause(self, ctx: Parser.EnumeratedVpClauseContext):
"""enumeratedVpClause: (IDENTIFIER COLON)? WHEN vpCondition THEN constant ;"""
ctx_list = list(ctx.getChildren())
rule_name: Optional[str] = None
values: List[str] = []
result: str = ""

i = 0
# Optional rule name: IDENTIFIER COLON
if ctx_list[i].getSymbol().type == Parser.IDENTIFIER:
rule_name = ctx_list[i].getSymbol().text
i += 2 # skip IDENTIFIER and COLON

i += 1 # skip WHEN
# vpCondition
values = self.visitVpCondition(ctx_list[i])
i += 1

i += 1 # skip THEN
# constant (result)
result = Terminals().visitConstant(ctx_list[i]).value

token_info = extract_token_info(ctx)
return EnumeratedVpClause(name=rule_name, values=values, result=result, **token_info)

def visitAggregationVpClause(self, ctx: Parser.AggregationVpClauseContext):
"""aggregationVpClause: AGGREGATE (MIN | MAX | SUM | AVG) ;"""
ctx_list = list(ctx.getChildren())
function = ctx_list[1].getSymbol().text
token_info = extract_token_info(ctx)
return AggregateVpClause(function=function, **token_info)

def visitDefaultVpClause(self, ctx: Parser.DefaultVpClauseContext):
"""defaultVpClause: ELSE constant ;"""
ctx_list = list(ctx.getChildren())
return Terminals().visitConstant(ctx_list[1]).value

def visitVpCondition(self, ctx: Parser.VpConditionContext):
"""vpCondition: constant (AND constant)? ;"""
ctx_list = list(ctx.getChildren())
values = []
for child in ctx_list:
if isinstance(child, Parser.ConstantContext):
values.append(Terminals().visitConstant(child).value)
return values

def visitDefOperator(self, ctx: Parser.DefOperatorContext):
"""
DEFINE OPERATOR operatorID LPAREN (parameterItem (COMMA parameterItem)*)? RPAREN
Expand Down
6 changes: 1 addition & 5 deletions src/vtlengine/AST/ASTConstructorModules/Terminals.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,7 @@ def visitViralAttribute(self, ctx: Parser.ViralAttributeContext):
"""
viralAttribute: VIRAL ATTRIBUTE;
"""
# ctx_list = list(ctx.getChildren())
# c = ctx_list[0]
# token = c.getSymbol()

raise NotImplementedError
return Role.VIRAL_ATTRIBUTE

def visitLists(self, ctx: Parser.ListsContext):
"""
Expand Down
35 changes: 35 additions & 0 deletions src/vtlengine/AST/ASTString.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,41 @@ def visit_DPRuleset(self, node: AST.DPRuleset) -> None:
f"({signature}) is {rules} end datapoint ruleset;"
)

# ---------------------- Viral Propagation ----------------------
def visit_ViralPropagationDef(self, node: AST.ViralPropagationDef) -> None:
clauses_strs: list[str] = []
for clause in node.enumerated_clauses:
clause_str = ""
if clause.name is not None:
clause_str += f"{clause.name} : "
values_str = " and ".join([f'"{v}"' for v in clause.values])
clause_str += f'when {values_str} then "{clause.result}"'
clauses_strs.append(clause_str)
if node.aggregate_clause is not None:
clauses_strs.append(f"aggr {node.aggregate_clause.function}")
if node.default_value is not None:
clauses_strs.append(f'else "{node.default_value}"')

if self.pretty:
self.vtl_script += (
f"define viral propagation {node.name}({node.signature_type} {node.target}) is{nl}"
)
for i, c in enumerate(clauses_strs):
self.vtl_script += f"{tab}{c}"
if i != len(clauses_strs) - 1:
self.vtl_script += f";{nl}"
else:
self.vtl_script += nl
self.vtl_script += f"end viral propagation;{nl}"
else:
clauses_joined = ";".join(clauses_strs)
self.vtl_script += (
f"define viral propagation {node.name} "
f"({node.signature_type} {node.target}) is "
f"{clauses_joined} "
f"end viral propagation;"
)

# ---------------------- User Defined Operators ----------------------

def visit_Argument(self, node: AST.Argument) -> str:
Expand Down
9 changes: 9 additions & 0 deletions src/vtlengine/AST/ASTTemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,3 +613,12 @@ def visit_Comment(self, node: AST.Comment) -> None:
"""
Comment: (value)
"""

def visit_ViralPropagationDef(self, node: AST.ViralPropagationDef) -> None:
pass

def visit_EnumeratedVpClause(self, node: AST.EnumeratedVpClause) -> None:
pass

def visit_AggregateVpClause(self, node: AST.AggregateVpClause) -> None:
pass
7 changes: 5 additions & 2 deletions src/vtlengine/AST/DAG/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
Start,
UDOCall,
VarID,
ViralPropagationDef,
)
from vtlengine.AST.ASTTemplate import ASTTemplate
from vtlengine.AST.DAG._models import DatasetSchedule, StatementDeps
Expand Down Expand Up @@ -132,7 +133,9 @@ def create_dag(cls, ast: Start) -> "DAGAnalyzer":
dag.sort_ast(ast)
else:
ml_statements: list = [
ml for ml in ast.children if not isinstance(ml, (HRuleset, DPRuleset, Operator))
ml
for ml in ast.children
if not isinstance(ml, (HRuleset, DPRuleset, Operator, ViralPropagationDef))
]
dag.check_overwriting(ml_statements)
return dag
Expand Down Expand Up @@ -207,7 +210,7 @@ def sort_ast(self, ast: AST) -> None:
ml_statements: list = [
node
for node in statements_nodes
if not isinstance(node, (HRuleset, DPRuleset, Operator))
if not isinstance(node, (HRuleset, DPRuleset, Operator, ViralPropagationDef))
]

intermediate = self.sort_elements(ml_statements)
Expand Down
20 changes: 20 additions & 0 deletions src/vtlengine/AST/Grammar/Vtl.g4
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,26 @@ defOperators:
DEFINE OPERATOR operatorID LPAREN (parameterItem (COMMA parameterItem)*)? RPAREN (RETURNS outputParameterType)? IS (expr) END OPERATOR # defOperator
| DEFINE DATAPOINT RULESET rulesetID LPAREN rulesetSignature RPAREN IS ruleClauseDatapoint END DATAPOINT RULESET # defDatapointRuleset
| DEFINE HIERARCHICAL RULESET rulesetID LPAREN hierRuleSignature RPAREN IS ruleClauseHierarchical END HIERARCHICAL RULESET # defHierarchical
| DEFINE VIRAL PROPAGATION varID LPAREN vpSignature RPAREN IS vpBody END VIRAL PROPAGATION # defViralPropagation
;

vpSignature:
VALUE_DOMAIN varID
| VARIABLE varID
;

vpBody:
vpClause (EOL vpClause)*
;

vpClause:
(IDENTIFIER COLON)? WHEN vpCondition THEN constant # enumeratedVpClause
| AGGREGATE (MIN | MAX | SUM | AVG) # aggregationVpClause
| ELSE constant # defaultVpClause
;

vpCondition:
constant (AND constant)?
;

/* --------------------------------------------END DEFINE FUNCTIONS------------------------------------------------- */
Expand Down
1 change: 1 addition & 0 deletions src/vtlengine/AST/Grammar/VtlTokens.g4
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ lexer grammar VtlTokens;
EXP : 'exp';
ROLE : 'componentRole';
VIRAL : 'viral';
PROPAGATION : 'propagation';
CHARSET_MATCH : 'match_characters';
TYPE : 'type';
NVL : 'nvl';
Expand Down
Loading
Loading