Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/vtlengine/API/_InternalApi.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
Scalar,
ValueDomain,
)
from vtlengine.Model._case_insensitive_dict import CaseInsensitiveDict

# Cache SCALAR_TYPES keys for performance
_SCALAR_TYPE_KEYS = SCALAR_TYPES.keys()
Expand Down Expand Up @@ -94,8 +95,8 @@ def _load_dataset_from_structure(
"""
Loads a dataset with the structure given.
"""
datasets = {}
scalars = {}
datasets: CaseInsensitiveDict[Any] = CaseInsensitiveDict()
scalars: CaseInsensitiveDict[Any] = CaseInsensitiveDict()

if "datasets" in structures:
for dataset_json in structures["datasets"]:
Expand Down Expand Up @@ -417,12 +418,12 @@ def load_datasets(
if isinstance(data_structure, dict):
return _load_datastructure_single(data_structure, sdmx_mappings=sdmx_mappings)
if isinstance(data_structure, list):
ds_structures: Dict[str, Dataset] = {}
scalar_structures: Dict[str, Scalar] = {}
ds_structures: CaseInsensitiveDict[Dataset] = CaseInsensitiveDict()
scalar_structures: CaseInsensitiveDict[Scalar] = CaseInsensitiveDict()
for x in data_structure:
ds, sc = _load_datastructure_single(x, sdmx_mappings=sdmx_mappings)
ds_structures = {**ds_structures, **ds} # Overwrite ds_structures dict.
scalar_structures = {**scalar_structures, **sc} # Overwrite scalar_structures dict.
ds_structures.update(ds)
scalar_structures.update(sc)
return ds_structures, scalar_structures
return _load_datastructure_single(data_structure, sdmx_mappings=sdmx_mappings)

Expand Down
17 changes: 10 additions & 7 deletions src/vtlengine/AST/DAG/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,24 +180,27 @@ def load_edges(self) -> None:
for key, statement in self.dependencies.items():
reference = statement.outputs + statement.persistent
if reference:
ref_to_keys[reference[0]] = key
ref_to_keys[reference[0].casefold()] = key

for sub_key, sub_statement in self.dependencies.items():
for input_val in sub_statement.inputs:
if input_val in ref_to_keys:
key = ref_to_keys[input_val]
self.edges[count_edges] = (key, sub_key)
count_edges += 1
if input_val.casefold() in ref_to_keys:
key = ref_to_keys[input_val.casefold()]
if key != sub_key: # Skip self-edges (e.g. a <- A)
self.edges[count_edges] = (key, sub_key)
count_edges += 1

def sort_elements(self, statements: list) -> list:
    """Reorder *statements* according to the topological order in ``self.sorting``.

    ``self.sorting`` holds 1-based positions: result element i is the statement
    at index ``self.sorting[i] - 1`` of the input list.
    """
    ordered = []
    for position in self.sorting:  # type: ignore[union-attr]
        ordered.append(statements[position - 1])
    return ordered

def check_overwriting(self, statements: list) -> None:
    """Raise ``SemanticError("1-2-2")`` when two statements assign the same name.

    The scraped diff interleaved the removed (case-sensitive) lines with their
    replacements; this is the coherent post-change version. Comparison uses
    ``str.casefold()`` because regular VTL names are case-insensitive, while the
    error still reports the name exactly as the user wrote it.
    """
    seen: Set[str] = set()
    for statement in statements:
        # Case-insensitive check: regular VTL names are case-insensitive
        normalized = statement.left.value.casefold()
        if normalized in seen:
            raise SemanticError("1-2-2", varId_value=statement.left.value)
        seen.add(normalized)

def sort_ast(self, ast: AST) -> None:
statements_nodes = ast.children
Expand Down
104 changes: 72 additions & 32 deletions src/vtlengine/Interpreter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
ScalarSet,
ValueDomain,
)
from vtlengine.Model._case_insensitive_dict import CaseInsensitiveDict
from vtlengine.Operators.Aggregation import extract_grouping_identifiers
from vtlengine.Operators.Assignment import Assignment
from vtlengine.Operators.CastOperator import Cast
Expand Down Expand Up @@ -155,6 +156,15 @@ class InterpreterAnalyzer(ASTTemplate):
signature_values: Optional[Dict[str, Any]] = None

def __post_init__(self) -> None:
# Ensure case-insensitive lookups for datasets, scalars, and value_domains
if not isinstance(self.datasets, CaseInsensitiveDict):
self.datasets = CaseInsensitiveDict(self.datasets)
if self.scalars is not None and not isinstance(self.scalars, CaseInsensitiveDict):
self.scalars = CaseInsensitiveDict(self.scalars)
if self.value_domains is not None and not isinstance(
self.value_domains, CaseInsensitiveDict
):
self.value_domains = CaseInsensitiveDict(self.value_domains)
self.datasets_inputs = set(self.datasets.keys())
self.scalars_inputs = set(self.scalars.keys()) if self.scalars else set()

Expand Down Expand Up @@ -236,7 +246,7 @@ def visit_Start(self, node: AST.Start) -> Any:
Operators.only_semantic = True
else:
Operators.only_semantic = False
results = {}
results: CaseInsensitiveDict[Any] = CaseInsensitiveDict()
scalars_to_save = set()
invalid_dataset_outputs = []
invalid_scalar_outputs = []
Expand Down Expand Up @@ -279,7 +289,7 @@ def visit_Start(self, node: AST.Start) -> Any:
if isinstance(result, Scalar):
scalars_to_save.add(result.name)
if self.scalars is None:
self.scalars = {}
self.scalars = CaseInsensitiveDict()
self.scalars[result.name] = copy(result)
self._save_datapoints_efficient(statement_num)
statement_num += 1
Expand All @@ -305,7 +315,7 @@ def visit_Start(self, node: AST.Start) -> Any:

def visit_Operator(self, node: AST.Operator) -> None:
if self.udos is None:
self.udos = {}
self.udos = CaseInsensitiveDict()
elif node.op in self.udos:
raise ValueError(f"User Defined Operator {node.op} already exists")

Expand Down Expand Up @@ -355,7 +365,7 @@ def visit_DPRuleset(self, node: AST.DPRuleset) -> None:
)

# Signature has the actual parameters names or aliases if provided
signature_actual_names = {}
signature_actual_names: Dict[str, str] = CaseInsensitiveDict()
if not isinstance(node.params, AST.DefIdentifier):
for param in node.params:
if param.alias is not None:
Expand All @@ -376,15 +386,15 @@ def visit_DPRuleset(self, node: AST.DPRuleset) -> None:

# Adding the ruleset to the dprs dictionary
if self.dprs is None:
self.dprs = {}
self.dprs = CaseInsensitiveDict()
elif node.name in self.dprs:
raise ValueError(f"Datapoint Ruleset {node.name} already exists")

self.dprs[node.name] = ruleset_data

def visit_HRuleset(self, node: AST.HRuleset) -> None:
if self.hrs is None:
self.hrs = {}
self.hrs = CaseInsensitiveDict()

if node.name in self.hrs:
raise ValueError(f"Hierarchical Ruleset {node.name} already exists")
Expand Down Expand Up @@ -574,6 +584,13 @@ def visit_Aggregation(self, node: AST.Aggregation) -> None:
for x in node.grouping:
groupings.append(self.visit(x))
self.is_from_grouping = False
# Resolve grouping names to canonical (original-case) component names
groupings = [
operand.resolve_component_name(g)
if isinstance(g, str) and g in operand.components
else g
for g in groupings
]
if grouping_op == "group all" or has_time_agg:
groupings = self._apply_time_agg_grouping(operand, groupings, grouping_op)
self.aggregation_dataset = None
Expand Down Expand Up @@ -834,16 +851,17 @@ def visit_VarID(self, node: AST.VarID) -> Any: # noqa: C901
comp_name=node.value,
dataset_name=self.aggregation_dataset.name,
)
canon = self.aggregation_dataset.resolve_component_name(node.value)
if self.aggregation_dataset.data is None:
data = None
else:
data = copy(self.aggregation_dataset.data[node.value])
data = copy(self.aggregation_dataset.data[canon])
return DataComponent(
name=node.value,
name=canon,
data=data,
data_type=self.aggregation_dataset.components[node.value].data_type,
role=self.aggregation_dataset.components[node.value].role,
nullable=self.aggregation_dataset.components[node.value].nullable,
data_type=self.aggregation_dataset.components[canon].data_type,
role=self.aggregation_dataset.components[canon].role,
nullable=self.aggregation_dataset.components[canon].nullable,
)
if self.is_from_regular_aggregation:
if self.is_from_join and node.value in self.datasets:
Expand Down Expand Up @@ -883,16 +901,17 @@ def visit_VarID(self, node: AST.VarID) -> Any: # noqa: C901
comp_name=node.value,
dataset_name=self.regular_aggregation_dataset.name,
)
canon = self.regular_aggregation_dataset.resolve_component_name(node.value)
if self.regular_aggregation_dataset.data is not None:
data = copy(self.regular_aggregation_dataset.data[node.value])
data = copy(self.regular_aggregation_dataset.data[canon])
else:
data = None
return DataComponent(
name=node.value,
name=canon,
data=data,
data_type=self.regular_aggregation_dataset.components[node.value].data_type,
role=self.regular_aggregation_dataset.components[node.value].role,
nullable=self.regular_aggregation_dataset.components[node.value].nullable,
data_type=self.regular_aggregation_dataset.components[canon].data_type,
role=self.regular_aggregation_dataset.components[canon].role,
nullable=self.regular_aggregation_dataset.components[canon].nullable,
)
if (
self.is_from_rule
Expand All @@ -908,6 +927,8 @@ def visit_VarID(self, node: AST.VarID) -> Any: # noqa: C901
comp_name=node.value,
dataset_name=self.ruleset_dataset.name,
)
# Resolve to canonical (original-case) name for DataFrame access
comp_name = self.ruleset_dataset.resolve_component_name(comp_name)
data = None if self.rule_data is None else self.rule_data[comp_name]
return DataComponent(
name=comp_name,
Expand Down Expand Up @@ -982,11 +1003,13 @@ def visit_RegularAggregation(self, node: AST.RegularAggregation) -> None: # noq
dataset = copy(operands[0])
if self.regular_aggregation_dataset is not None:
dataset.name = self.regular_aggregation_dataset.name
dataset.components = {
comp_name: comp
for comp_name, comp in dataset.components.items()
if comp.role != Role.MEASURE
}
dataset.components = CaseInsensitiveDict(
{
comp_name: comp
for comp_name, comp in dataset.components.items()
if comp.role != Role.MEASURE
}
)
if dataset.data is not None:
dataset.data = dataset.data[dataset.get_identifiers_names()]
aux_operands = []
Expand Down Expand Up @@ -1056,10 +1079,12 @@ def visit_RegularAggregation(self, node: AST.RegularAggregation) -> None: # noq
columns={col: col[col.find("#") + 1 :] for col in result.data.columns},
inplace=True,
)
result.components = {
comp_name[comp_name.find("#") + 1 :]: comp
for comp_name, comp in result.components.items()
}
result.components = CaseInsensitiveDict(
{
comp_name[comp_name.find("#") + 1 :]: comp
for comp_name, comp in result.components.items()
}
)
for comp in result.components.values():
comp.name = comp.name[comp.name.find("#") + 1 :]
if result.data is not None:
Expand Down Expand Up @@ -1199,6 +1224,13 @@ def visit_RenameNode(self, node: AST.RenameNode) -> Any:
):
node.old_name = node.old_name.split("#")[1]

# Resolve old_name to canonical (original-case) component name
if (
self.regular_aggregation_dataset is not None
and node.old_name in self.regular_aggregation_dataset.components
):
node.old_name = self.regular_aggregation_dataset.resolve_component_name(node.old_name)

return node

def visit_Constant(self, node: AST.Constant) -> Any:
Expand Down Expand Up @@ -1260,11 +1292,13 @@ def visit_ParamOp(self, node: AST.ParamOp) -> None: # noqa: C901
if len(self.aggregation_dataset.get_measures()) != 1:
raise ValueError("Only one measure is allowed")
# Deepcopy is necessary for components to avoid changing the original dataset
self.aggregation_dataset.components = {
comp_name: deepcopy(comp)
for comp_name, comp in self.aggregation_dataset.components.items()
if comp_name in self.aggregation_grouping or comp.role == Role.MEASURE
}
self.aggregation_dataset.components = CaseInsensitiveDict(
{
comp_name: deepcopy(comp)
for comp_name, comp in self.aggregation_dataset.components.items()
if comp_name in self.aggregation_grouping or comp.role == Role.MEASURE
}
)

self.aggregation_dataset.data = (
self.aggregation_dataset.data[
Expand Down Expand Up @@ -1335,7 +1369,10 @@ def visit_HROperation(self, node: AST.HROperation) -> None: # noqa: C901
if len(cond_components) != len(hr_info["condition"]):
raise SemanticError("1-1-10-2", op=node.op)

if hr_info["node"].signature_type == "variable" and hr_info["signature"] != component:
if (
hr_info["node"].signature_type == "variable"
and hr_info["signature"].casefold() != component.casefold() # type: ignore[union-attr]
):
raise SemanticError(
"1-1-10-3",
op=node.op,
Expand Down Expand Up @@ -1393,6 +1430,9 @@ def visit_HROperation(self, node: AST.HROperation) -> None: # noqa: C901

Check_Hierarchy.validate_hr_dataset(dataset, component)

# Resolve to canonical (original-case) component name for DataFrame access
component = dataset.resolve_component_name(component)

# Set up interpreter state for rule processing
self.ruleset_dataset = dataset
self.ruleset_signature = {**{"RULE_COMPONENT": component}, **cond_info}
Expand Down Expand Up @@ -1463,7 +1503,7 @@ def visit_DPValidation(self, node: AST.DPValidation) -> None:
)
if dpr_info is not None and dpr_info["signature_type"] == "variable":
for i, comp_name in enumerate(node.components):
if comp_name != dpr_info["params"][i]:
if comp_name.casefold() != dpr_info["params"][i].casefold():
raise SemanticError(
"1-1-10-3",
op=CHECK_DATAPOINT,
Expand Down
19 changes: 17 additions & 2 deletions src/vtlengine/Model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from vtlengine.DataTypes import SCALAR_TYPES, ScalarType
from vtlengine.DataTypes.TimeHandling import TimePeriodHandler
from vtlengine.Exceptions import InputValidationException, SemanticError
from vtlengine.Model._case_insensitive_dict import CaseInsensitiveDict


@dataclass
Expand Down Expand Up @@ -205,6 +206,9 @@ class Dataset:
persistent: bool = False

def __post_init__(self) -> None:
# Ensure components is always a CaseInsensitiveDict
if not isinstance(self.components, CaseInsensitiveDict):
self.components = CaseInsensitiveDict(self.components)
if self.data is not None:
if len(self.components) != len(self.data.columns):
raise ValueError(
Expand Down Expand Up @@ -331,15 +335,26 @@ def __eq__(self, other: Any) -> bool:
def get_component(self, component_name: str) -> Component:
    """Return the component stored under *component_name*.

    Raises KeyError when no such component exists.
    """
    components = self.components
    return components[component_name]

def resolve_component_name(self, name: str) -> str:
    """Map *name* to the canonical (original-case) component key.

    When ``self.components`` is a plain mapping (not a CaseInsensitiveDict),
    *name* is returned unchanged.
    """
    components = self.components
    if not isinstance(components, CaseInsensitiveDict):
        return name
    return components.canonical_key(name)

def add_component(self, component: Component) -> None:
    """Register *component* under its own name; duplicates are rejected."""
    key = component.name
    if key in self.components:
        raise ValueError(f"Component with name {key} already exists")
    self.components[key] = component

def delete_component(self, component_name: str) -> None:
    """Remove a component (and its DataFrame column) by name, if present.

    The scraped diff interleaved the removed lines with their replacements;
    this is the coherent post-change version: the name is first resolved to its
    canonical (original-case) key so the matching DataFrame column can be
    dropped, and an unknown name is silently ignored (best-effort delete,
    mirroring the original ``pop(..., None)`` semantics).
    """
    # Resolve to canonical name for DataFrame column access
    try:
        canonical = self.resolve_component_name(component_name)
    except KeyError:
        return
    del self.components[canonical]
    if self.data is not None:
        self.data.drop(columns=[canonical], inplace=True)

def get_components(self) -> List[Component]:
    """Return every registered component as a list, in insertion order."""
    return [*self.components.values()]
Expand Down
Loading
Loading