diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 945bb7a0ed2b..1068896eb4e3 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -85,6 +85,7 @@ dependencies { testImplementation(project(":testkit")) testImplementation("commons-lang:commons-lang") + testImplementation("org.mockito:mockito-core") testImplementation("net.bytebuddy:byte-buddy") testImplementation("net.hydromatic:foodmart-queries") testImplementation("net.hydromatic:quidem") diff --git a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactory.java b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactory.java index 2c23f644db3e..749f4541f2df 100644 --- a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactory.java +++ b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactory.java @@ -211,7 +211,8 @@ default RelDataType leastRestrictive(List types, boolean convertToV return leastRestrictive(types, convertToVarying, false); } - @Nullable RelDataType leastRestrictive(List types, boolean convertToVarying, boolean coerce); + @Nullable RelDataType leastRestrictive(List types, boolean convertToVarying, + boolean coerce); /** * Creates a SQL type with no precision or scale. diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java index cd99de7b4a9a..863da6524007 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java +++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java @@ -24,11 +24,13 @@ import org.apache.calcite.rel.type.RelProtoDataType; import org.apache.calcite.rex.RexCallBinding; import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.ExplicitOperatorBinding; import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlCallBinding; import org.apache.calcite.sql.SqlCollation; import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlNodeList; import org.apache.calcite.sql.SqlNumericLiteral; import org.apache.calcite.sql.SqlOperatorBinding; @@ -234,8 +236,7 @@ public static SqlCall stripSeparator(SqlCall call) { */ public static final SqlReturnTypeInference ARG0_NULLABLE_IF_EMPTY = new OrdinalReturnTypeInference(0) { - @Override public RelDataType - inferReturnType(SqlOperatorBinding opBinding) { + @Override public RelDataType inferReturnType(SqlOperatorBinding opBinding) { final RelDataType type = super.inferReturnType(opBinding); if (opBinding.getGroupCount() == 0 || opBinding.hasFilter()) { return opBinding.getTypeFactory() @@ -704,62 +705,290 @@ public static SqlCall stripSeparator(SqlCall call) { result.right); }; - private static Pair getDecimalMultiplyBindingType(SqlOperatorBinding opBinding, - RelDataTypeFactory typeFactory) { + /** + * Determines the appropriate data types for decimal multiplication operations. + * + *

This method serves as a dispatcher that routes to the appropriate handler + * based on the type of operator binding. It handles both SQL parse tree bindings + * (SqlCallBinding) and relational expression bindings (RexCallBinding), with a + * fallback for other binding types.

+ * + *

The method is crucial for decimal multiplication type inference because it + * needs to analyze the actual operand values (not just their declared types) to + * determine if integer literals should be converted to decimal types for proper + * decimal arithmetic.

+ * + * @param opBinding the operator binding containing operand information and types + * @param typeFactory the type factory used to create new data types if needed + * @return a Pair containing the left and right operand types to be used for + * decimal multiplication, possibly with type conversions applied + */ + private static Pair getDecimalMultiplyBindingType( + SqlOperatorBinding opBinding, RelDataTypeFactory typeFactory) { + // Route to SqlCallBinding handler for SQL parse tree scenarios if (opBinding instanceof SqlCallBinding) { - RelDataType type1 = createDecimalTypeOrDefault(typeFactory, (SqlCallBinding) opBinding, 0); - RelDataType type2 = createDecimalTypeOrDefault(typeFactory, (SqlCallBinding) opBinding, 1); - return Pair.of(type1, type2); + return getRelDataTypeRelDataTypePair((SqlCallBinding) opBinding, typeFactory); } + // Route to RexCallBinding handler for relational expression scenarios if (opBinding instanceof RexCallBinding) { - RelDataType type1 = createDecimalTypeOrDefault(typeFactory, (RexCallBinding) opBinding, 0); - RelDataType type2 = createDecimalTypeOrDefault(typeFactory, (RexCallBinding) opBinding, 1); - return Pair.of(type1, type2); + return getRelDataTypeRelDataTypePair((RexCallBinding) opBinding, typeFactory); } + // Fallback: return original operand types for other binding types return Pair.of(opBinding.getOperandType(0), opBinding.getOperandType(1)); } - private static RelDataType createDecimalTypeOrDefault(RelDataTypeFactory typeFactory, - RexCallBinding opBinding, int ordinal) { - RelDataType defaultType = opBinding.getOperandType(ordinal); + /** + * Determines operand types for decimal multiplication in RexCallBinding scenarios. + * + *

This method handles relational expression bindings (RexCallBinding) where operands + * are represented as RexNode objects. It analyzes both the declared types and actual + * values to determine if type conversions are needed for proper decimal arithmetic.

+ * + *

The method implements a comprehensive type inference strategy that considers: + *

    + *
  • Whether operands have decimal types or contain decimal constants
  • + *
  • Whether operands are numeric (including integer literals)
  • + *
  • Whether integer literals should be converted to decimal types
  • + *
+ *

+ * + *

Key scenarios handled: + *

    + *
  • Decimal × Decimal: both operands maintain decimal types
  • + *
  • Decimal × Integer: integer may be converted to decimal
  • + *
  • Integer × Decimal: integer may be converted to decimal
  • + *
  • Integer × Integer: no conversion, return original types
  • + *
+ *

+ * + * @param opBinding the RexCallBinding containing RexNode operands + * @param typeFactory the type factory for creating new data types + * @return a Pair containing the potentially converted left and right operand types + */ + private static Pair getRelDataTypeRelDataTypePair( + RexCallBinding opBinding, RelDataTypeFactory typeFactory) { + // Get the default types for both operands + RelDataType leftDefaultType = opBinding.getOperandType(0); + RelDataType rightDefaultType = opBinding.getOperandType(1); try { - if (!SqlNodeUtils.isNumericLiteral(opBinding, ordinal)) { - return defaultType; + // Extract the actual RexNode operands for value analysis + RexNode leftOperand = opBinding.operands().get(0); + RexNode rightOperand = opBinding.operands().get(1); + + // Determine if operands are decimal (either by type or by constant value) + // Check if left operand is decimal by type or contains decimal constant + boolean leftIsDecimal = SqlTypeUtil.isDecimal(leftDefaultType) + || SqlNodeUtils.isDecimalConstantRexNode(leftOperand); + // Check if right operand is decimal by type or contains decimal constant + boolean rightIsDecimal = SqlTypeUtil.isDecimal(rightDefaultType) + || SqlNodeUtils.isDecimalConstantRexNode(rightOperand); + // Determine if operands are numeric (either by type or by literal value) + boolean leftIsNumeric = SqlTypeUtil.isNumeric(leftDefaultType) + || SqlNodeUtils.isNumericLiteralRexNode(leftOperand); + boolean rightIsNumeric = SqlTypeUtil.isNumeric(rightDefaultType) + || SqlNodeUtils.isNumericLiteralRexNode(rightOperand); + + // Apply decimal type inference if at least one operand is decimal and both are numeric + // This covers scenarios where we need decimal arithmetic precision + if ((leftIsDecimal || rightIsDecimal) && leftIsNumeric && rightIsNumeric) { + // Convert operands to appropriate decimal types if needed + RelDataType type1 = createDecimalTypeOrDefault(typeFactory, opBinding, 0, leftDefaultType); + RelDataType type2 = createDecimalTypeOrDefault(typeFactory, opBinding, 1, rightDefaultType); + return Pair.of(type1, type2); } - RexLiteral literal = (RexLiteral) opBinding.operands().get(ordinal); + // Fallback: both operands are numeric but neither is decimal, no conversion needed + return Pair.of(leftDefaultType, rightDefaultType); + } catch (Exception e) { + // Exception safety: return original types if any error occurs during analysis + return Pair.of(leftDefaultType, rightDefaultType); + } + } - if (SqlNodeUtils.isDecimalConstant(literal)) { + /** + * Converts integer literals to decimal types for RexCallBinding scenarios. + * + *

This method is responsible for type conversion in relational expression scenarios + * where integer literals need to be promoted to decimal types for proper decimal arithmetic. + * The conversion is essential when multiplying integers with decimals to maintain + * precision and avoid unintended integer arithmetic.

+ * + *

Conversion logic: + *

    + *
  • If the operand is already a decimal constant, return the default type
  • + *
  • If the operand is an integer literal, convert it to DECIMAL with appropriate + * precision
  • + *
  • For other cases, return the default type unchanged
  • + *
+ *

+ * + *

The precision calculation for converted integers uses the number of digits + * in the integer value. For example: + *

    + *
  • 123 → DECIMAL(3, 0) (3 digits)
  • + *
  • 45 → DECIMAL(2, 0) (2 digits)
  • + *
  • 0 → DECIMAL(1, 0) (special case)
  • + *
+ *

+ * + * @param typeFactory the type factory for creating new DECIMAL types + * @param opBinding the RexCallBinding containing the operands + * @param ordinal the zero-based index of the operand to analyze + * @param defaultType the default type to return if no conversion is needed + * @return either a converted DECIMAL type or the original default type + */ + private static RelDataType createDecimalTypeOrDefault(RelDataTypeFactory typeFactory, + RexCallBinding opBinding, int ordinal, RelDataType defaultType) { + // Check if the operand at the specified position is a numeric literal + if (SqlNodeUtils.isNumericLiteral(opBinding, ordinal)) { + RexNode node = opBinding.operands().get(ordinal); + // If it's already a decimal constant, no conversion needed + if (SqlNodeUtils.isDecimalConstant(node)) { return defaultType; } - Long value = literal.getValueAs(Long.class); - int length = (int) Math.floor(Math.log10(Math.abs(value))) + 1; - return typeFactory.createSqlType(SqlTypeName.DECIMAL, length, 0); - } catch (IllegalArgumentException | AssertionError e) { - return defaultType; + // Attempt to convert integer literals to decimal for proper decimal multiplication + RexLiteral literal = (RexLiteral) node; + RelDataType type = literal.getType(); + // Check if it's an exact numeric type (INTEGER, BIGINT, etc.) but not already DECIMAL + if (SqlTypeUtil.isExactNumeric(type) && !SqlTypeUtil.isDecimal(type)) { + // Convert integer types (INTEGER, BIGINT, etc.) to DECIMAL + Long value = literal.getValueAs(Long.class); + if (value != null) { + // Calculate precision based on the number of digits in the integer + // For example: 123 → 3 digits, 45 → 2 digits, 0 → 1 digit, -123 → 4 digits + int length; + if (value == 0) { + // Special case: 0 should have precision 1 + length = 1; + } else if (value < 0) { + // like sqlNode, negative number add 1 for '-' + length = (int) Math.floor(Math.log10(Math.abs(value))) + 2; + } else { + length = (int) Math.floor(Math.log10(value)) + 1; + } + // Create DECIMAL type with calculated precision and 0 scale + return typeFactory.createSqlType(SqlTypeName.DECIMAL, length, 0); + } + } } + // Return default type for non-literals or when conversion is not applicable + return defaultType; } - private static RelDataType createDecimalTypeOrDefault(RelDataTypeFactory typeFactory, - SqlCallBinding opBinding, int ordinal) { - RelDataType defaultType = opBinding.getOperandType(ordinal); + /** + * Determines operand types for decimal multiplication in SqlCallBinding scenarios. + * + *

This method handles SQL parse tree bindings (SqlCallBinding) where operands + * are represented as SqlNode objects. It analyzes both the declared types and actual + * values to determine if type conversions are needed for proper decimal arithmetic.

+ * + *

Similar to the RexCallBinding version, this method implements comprehensive + * type inference but operates on SqlNode objects instead of RexNode objects. + * The key difference is that during SQL parsing, integer literals are automatically + * converted to DECIMAL type, which affects the conversion logic.

+ * + *

Key scenarios handled: + *

    + *
  • Decimal × Decimal: both operands maintain decimal types
  • + *
  • Decimal × Integer: integer may be converted to decimal
  • + *
  • Integer × Decimal: integer may be converted to decimal
  • + *
  • Integer × Integer: no conversion, return original types
  • + *
+ *

+ * + * @param opBinding the SqlCallBinding containing SqlNode operands + * @param typeFactory the type factory for creating new data types + * @return a Pair containing the potentially converted left and right operand types + */ + private static Pair getRelDataTypeRelDataTypePair( + SqlCallBinding opBinding, RelDataTypeFactory typeFactory) { + // Get the default types for both operands + RelDataType leftDefaultType = opBinding.getOperandType(0); + RelDataType rightDefaultType = opBinding.getOperandType(1); try { - if (!SqlNodeUtils.isNumericLiteral(opBinding, ordinal)) { - return defaultType; + // Extract the actual SqlNode operands for value analysis + SqlNode leftOperand = opBinding.operands().get(0); + SqlNode rightOperand = opBinding.operands().get(1); + + // Determine if operands are decimal (either by type or by constant value) + // Check if left operand is decimal by type or contains decimal constant + boolean leftIsDecimal = SqlTypeUtil.isDecimal(leftDefaultType) + || SqlNodeUtils.isDecimalConstantSqlNode(leftOperand); + // Check if right operand is decimal by type or contains decimal constant + boolean rightIsDecimal = SqlTypeUtil.isDecimal(rightDefaultType) + || SqlNodeUtils.isDecimalConstantSqlNode(rightOperand); + // Determine if operands are numeric (either by type or by literal value) + boolean leftIsNumeric = SqlTypeUtil.isNumeric(leftDefaultType) + || SqlNodeUtils.isNumericLiteralSqlNode(leftOperand); + boolean rightIsNumeric = SqlTypeUtil.isNumeric(rightDefaultType) + || SqlNodeUtils.isNumericLiteralSqlNode(rightOperand); + + // Apply decimal type inference if at least one operand is decimal and both are numeric + // This covers scenarios where we need decimal arithmetic precision + if ((leftIsDecimal || rightIsDecimal) && leftIsNumeric && rightIsNumeric) { + // Convert operands to appropriate decimal types if needed + RelDataType type1 = createDecimalTypeOrDefault(typeFactory, opBinding, 0, leftDefaultType); + RelDataType type2 = createDecimalTypeOrDefault(typeFactory, opBinding, 1, rightDefaultType); + return Pair.of(type1, type2); } + // Return default types for scenarios where both operands are numeric but neither is decimal + return Pair.of(leftDefaultType, rightDefaultType); + } catch (Exception e) { + // Exception safety: return original types if any error occurs during analysis + return Pair.of(leftDefaultType, rightDefaultType); + } + } + + /** + * Converts numeric literals to decimal types for SqlCallBinding scenarios. + * + *

This method handles type conversion in SQL parse tree scenarios where numeric literals + * need to be processed for decimal arithmetic. Unlike the RexCallBinding version, this method + * operates on SqlNode objects where integer literals have already been converted to DECIMAL + * type during SQL parsing.

+ * + *

Key difference from RexCallBinding version: + *

    + *
  • During SQL parsing, integer literals are automatically converted to DECIMAL type
  • + *
  • This method primarily ensures proper precision and scale are preserved
  • + *
  • No need to manually calculate precision from integer values
  • + *
+ *

+ * + *

Conversion logic: + *

    + *
  • If the operand is a numeric literal (DECIMAL or INTEGER), create DECIMAL type
  • + *
  • Use the literal's existing precision and scale information
  • + *
  • For non-literals, return the default type unchanged
  • + *
+ *

+ * + * @param typeFactory the type factory for creating new DECIMAL types + * @param opBinding the SqlCallBinding containing SqlNode operands + * @param ordinal the zero-based index of the operand to analyze + * @param defaultType the default type to return if no conversion is needed + * @return either a converted DECIMAL type or the original default type + */ + private static RelDataType createDecimalTypeOrDefault(RelDataTypeFactory typeFactory, + SqlCallBinding opBinding, int ordinal, RelDataType defaultType) { + // Check if the operand at the specified position is a numeric literal + if (SqlNodeUtils.isNumericLiteral(opBinding, ordinal)) { + // Extract the SqlNumericLiteral from the call's operand list SqlNumericLiteral literal = (SqlNumericLiteral) opBinding.getCall().getOperandList().get(ordinal); - // When parsing into a **SqlNode**, integers are converted to **DECIMAL** type. - if (SqlNodeUtils.isDecimalConstant(literal)) { + // When parsing into a SqlNode, integers are converted to DECIMAL type + // This method ensures both decimal and integer literals are properly handled + if (SqlNodeUtils.isDecimalOrIntegerConstant(literal)) { + // Create DECIMAL type using the literal's existing precision and scale + // This preserves the original precision/scale information from parsing return typeFactory.createSqlType(SqlTypeName.DECIMAL, literal.getPrec(), literal.getScale()); } - return defaultType; - } catch (IllegalArgumentException e) { - return defaultType; } + // Return default type for non-literals or when conversion is not applicable + return defaultType; } /** diff --git a/core/src/main/java/org/apache/calcite/util/SqlNodeUtils.java b/core/src/main/java/org/apache/calcite/util/SqlNodeUtils.java index 935c3b7246b4..bbf8d015ec71 100644 --- a/core/src/main/java/org/apache/calcite/util/SqlNodeUtils.java +++ b/core/src/main/java/org/apache/calcite/util/SqlNodeUtils.java @@ -16,34 +16,416 @@ */ package org.apache.calcite.util; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlNumericLiteral; import org.apache.calcite.sql.SqlOperatorBinding; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeUtil; +import java.util.ArrayDeque; +import java.util.Queue; + +/** + * Utility class for working with {@link SqlNode} and {@link RexNode} objects. + * + *

This class provides various static methods to analyze and validate SQL nodes, + * particularly focusing on numeric and decimal constant detection. It includes methods + * to check if nodes represent decimal constants, numeric literals, or complex expressions + * containing numeric values.

+ * + *

The utility supports both {@link SqlNode} (parse tree representation) and + * {@link RexNode} (relational expression representation) objects, providing consistent + * behavior across different stages of SQL processing.

+ */ public class SqlNodeUtils { + /** + * Private constructor to prevent instantiation of this utility class. + * All methods in this class are static and should be accessed directly. + */ private SqlNodeUtils() { } - public static boolean isDecimalConstant(SqlNumericLiteral literal) { - SqlTypeName typeName = literal.getTypeName(); + /** + * Checks if the given {@link SqlNode} represents a decimal constant. + * + *

A decimal constant is defined as a {@link SqlNumericLiteral} with DECIMAL type + * that is not an integer literal. This method specifically excludes integer values + * even if they are stored as DECIMAL type.

+ * + * @param node the SQL node to check, may be null + * @return true if the node is a decimal constant (non-integer DECIMAL), false otherwise + */ + public static boolean isDecimalConstant(SqlNode node) { + if (node instanceof SqlNumericLiteral) { + SqlNumericLiteral numericLiteral = (SqlNumericLiteral) node; + SqlTypeName typeName = numericLiteral.getTypeName(); + // Check if it's a DECIMAL type and not an integer literal + if (typeName == null) { + return false; + } + return typeName == SqlTypeName.DECIMAL && !numericLiteral.isInteger(); + } + + return false; + } + + /** + * Checks if the given {@link SqlNode} represents a decimal or integer constant. + * + *

Unlike {@link #isDecimalConstant(SqlNode)}, this method includes both decimal + * and integer values. During SQL parsing, integers are converted to DECIMAL type, + * so this method checks for DECIMAL type regardless of whether it's an integer + * or decimal value.

+ * + * @param node the SQL node to check, may be null + * @return true if the node is a DECIMAL type numeric literal (including integers), false + * otherwise + */ + public static boolean isDecimalOrIntegerConstant(SqlNode node) { + if (node instanceof SqlNumericLiteral) { + SqlNumericLiteral numericLiteral = (SqlNumericLiteral) node; + SqlTypeName typeName = numericLiteral.getTypeName(); + // When parsing into a **SqlNode**, integers are converted to **DECIMAL** type. + if (typeName == null) { + return false; + } + return typeName == SqlTypeName.DECIMAL; + } + + return false; + } + + /** + * Checks if the given {@link RexNode} represents a decimal constant. + * + *

This method checks if the node is a {@link RexLiteral} with DECIMAL type + * or any approximate numeric type (like FLOAT, DOUBLE). Unlike the SqlNode + * version, this includes approximate numeric types as well.

+ * + * @param node the Rex node to check, may be null + * @return true if the node is a decimal or approximate numeric constant, false otherwise + */ + public static boolean isDecimalConstant(RexNode node) { + if (node instanceof RexLiteral) { + RexLiteral literal = (RexLiteral) node; + RelDataType type = literal.getType(); + // Check if it's a DECIMAL type and not an integer literal + return type.getSqlTypeName() == SqlTypeName.DECIMAL || SqlTypeUtil.isApproximateNumeric(type); + } + return false; + } - return typeName == SqlTypeName.DECIMAL; + /** + * Checks if the given {@link SqlNode} is a numeric literal. + * + *

This is a simple type check that returns true if the node is an instance + * of {@link SqlNumericLiteral}, regardless of the specific numeric type.

+ * + * @param node the SQL node to check, may be null + * @return true if the node is a numeric literal, false otherwise + */ + public static boolean isNumericLiteral(SqlNode node) { + return node instanceof SqlNumericLiteral; } + /** + * Checks if the given {@link RexNode} is a numeric literal. + * + *

This method checks if the node is a {@link RexLiteral} with a numeric type. + * It uses {@link SqlTypeUtil#isNumeric(RelDataType)} to determine if the type + * is numeric, which includes all numeric types like INTEGER, DECIMAL, FLOAT, etc.

+ * + * @param node the Rex node to check, may be null + * @return true if the node is a numeric literal, false otherwise + */ + public static boolean isNumericLiteral(RexNode node) { + if (node instanceof RexLiteral) { + RexLiteral literal = (RexLiteral) node; + RelDataType type = literal.getType(); + return SqlTypeUtil.isNumeric(type); + } + return false; + } + + /** + * Checks if the operand at the specified ordinal in a {@link SqlOperatorBinding} + * is a numeric literal. + * + *

This method verifies that the operand is both a literal and has a numeric type. + * It's commonly used in operator validation to ensure operands are numeric literals.

+ * + * @param binding the operator binding containing the operands + * @param ordinal the zero-based index of the operand to check + * @return true if the operand at the specified position is a numeric literal, false otherwise + */ public static boolean isNumericLiteral(SqlOperatorBinding binding, int ordinal) { - return binding.isOperandLiteral(ordinal, false) && - SqlTypeUtil.isNumeric(binding.getOperandType(ordinal)); + return binding.isOperandLiteral(ordinal, false) + && SqlTypeUtil.isNumeric(binding.getOperandType(ordinal)); + } + + /** + * Checks if the given {@link RexNode} represents a decimal constant expression. + * + *

This method performs a deep analysis of the expression tree to determine if + * it contains only numeric literals and at least one decimal constant. It traverses + * binary arithmetic operations (like +, -, *, /) and checks all operands.

+ * + *

The method returns true only if: + *

    + *
  • All leaf nodes in the expression are numeric literals
  • + *
  • At least one leaf node is a decimal constant
  • + *
  • All intermediate nodes are binary arithmetic operations
  • + *
+ *

+ * + *

For example, this would return true for expressions like: + *

    + *
  • 1.5 + 2 (contains decimal 1.5)
  • + *
  • 3.14 * 2 - 1 (contains decimal 3.14)
  • + *
+ * But false for: + *
    + *
  • 1 + 2 (no decimal constants)
  • + *
  • 1.5 + column_name (contains non-literal)
  • + *
+ *

+ * + * @param node the Rex node to check, may be null + * @return true if the expression contains only numeric literals and at least one decimal constant + */ + public static boolean isDecimalConstantRexNode(RexNode node) { + if (node == null) { + return false; + } + + // If the node itself is a decimal constant, return true immediately + if (isDecimalConstant(node)) { + return true; + } + + Queue nodesToCheck = new ArrayDeque<>(); + // Add all operands of the initial RexCall to the queue for checking + nodesToCheck.add(node); + + boolean allNumeric = true; + boolean anyDecimal = false; + + while (!nodesToCheck.isEmpty()) { + RexNode currentNode = nodesToCheck.poll(); + // Check if the current node is a decimal constant + if (isDecimalConstant(currentNode)) { + anyDecimal = true; + continue; // This node is a decimal constant, check the next one + } + + // Check if the current node is a numeric literal + if (isNumericLiteral(currentNode)) { + continue; // This node is numeric, check the next one + } + + // If the current node is a RexCall with binary arithmetic operations, add its operands to + // the queue + if (currentNode instanceof RexCall) { + RexCall call = (RexCall) currentNode; + SqlKind kind = call.getKind(); + if (kind.belongsTo(SqlKind.BINARY_ARITHMETIC)) { + // Add all operands to the queue for checking + nodesToCheck.addAll(call.getOperands()); + continue; + } + } + + // If we reach here, the node is not a decimal constant, not a numeric literal, + // and not a valid arithmetic call + allNumeric = false; + break; // Early exit if we find a non-numeric node + } + + // Return true only if all nodes are numeric and at least one is a decimal constant + return allNumeric && anyDecimal; } - public static boolean isDecimalConstant(RexLiteral literal) { - SqlTypeName typeName = literal.getType().getSqlTypeName(); + /** + * Checks if the given {@link RexNode} represents an expression containing only numeric literals. + * + *

This method performs a deep analysis of the expression tree to determine if + * it contains only numeric literals. It traverses binary arithmetic operations + * and checks all operands to ensure they are all numeric literals.

+ * + *

Unlike {@link #isDecimalConstantRexNode(RexNode)}, this method doesn't require + * at least one decimal constant - it accepts expressions with only integer literals as well.

+ * + *

For example, this would return true for expressions like: + *

    + *
  • 1 + 2 (both integer literals)
  • + *
  • 1.5 * 3.14 (both decimal literals)
  • + *
  • 10 - 5 + 2 (multiple integer literals)
  • + *
+ * But false for: + *
    + *
  • 1 + column_name (contains non-literal)
  • + *
  • function_call(5) (contains function call)
  • + *
+ *

+ * + * @param node the Rex node to check, may be null + * @return true if the expression contains only numeric literals, false otherwise + */ + public static boolean isNumericLiteralRexNode(RexNode node) { + if (node == null) { + return false; + } + + Queue nodesToCheck = new ArrayDeque<>(); + nodesToCheck.add(node); + + while (!nodesToCheck.isEmpty()) { + RexNode currentNode = nodesToCheck.poll(); + + if (isNumericLiteral(currentNode)) { + continue; // This node is numeric, check the next one + } + + if (currentNode instanceof RexCall) { + RexCall call = (RexCall) currentNode; + SqlKind kind = call.getKind(); + if (kind.belongsTo(SqlKind.BINARY_ARITHMETIC)) { + // Add all operands to the queue for checking + nodesToCheck.addAll(call.getOperands()); + continue; + } + } + + // If we reach here, the node is not a numeric literal and not a valid arithmetic call + return false; + } + + // If we've checked all nodes and none failed the numeric test, return true + return true; + } + + /** + * Checks if the given {@link SqlNode} represents a decimal constant expression. + * + *

This method is the SqlNode equivalent of {@link #isDecimalConstantRexNode(RexNode)}. + * It performs a deep analysis of the SQL expression tree to determine if + * it contains only numeric literals and at least one decimal constant.

+ * + *

The method traverses binary arithmetic operations and checks all operands. + * It returns true only if all leaf nodes are numeric literals and at least one + * is a decimal constant.

+ * + *

This method is typically used during SQL parsing and validation stages, + * before the SQL is converted to relational expressions.

+ * + * @param node the SQL node to check, may be null + * @return true if the expression contains only numeric literals and at least one decimal constant + */ + public static boolean isDecimalConstantSqlNode(SqlNode node) { + if (node == null) { + return false; + } + + // If the node itself is a decimal constant, return true immediately + if (isDecimalConstant(node)) { + return true; + } + + Queue nodesToCheck = new ArrayDeque<>(); + nodesToCheck.add(node); + boolean allNumeric = true; + boolean anyDecimal = false; + + while (!nodesToCheck.isEmpty()) { + SqlNode currentNode = nodesToCheck.poll(); + + // Check if the current node is a decimal constant + if (isDecimalConstant(currentNode)) { + anyDecimal = true; + continue; // This node is a decimal constant, check the next one + } + + // Check if the current node is a numeric literal + if (isNumericLiteral(currentNode)) { + continue; // This node is numeric, check the next one + } + + // If the current node is a SqlCall with binary arithmetic operations, add its operands to + // the queue + if (currentNode instanceof SqlCall) { + SqlCall call = (SqlCall) currentNode; + SqlKind kind = call.getKind(); + if (kind.belongsTo(SqlKind.BINARY_ARITHMETIC)) { + // Add all operands to the queue for checking + nodesToCheck.addAll(call.getOperandList()); + continue; + } + } + + // If we reach here, the node is not a decimal constant, not a numeric literal, + // and not a valid arithmetic call + allNumeric = false; + break; // Early exit if we find a non-numeric node + } + + // Return true only if all nodes are numeric and at least one is a decimal constant + return allNumeric && anyDecimal; + } + + /** + * Checks if the given {@link SqlNode} represents an expression containing only numeric literals. + * + *

This method is the SqlNode equivalent of {@link #isNumericLiteralRexNode(RexNode)}. + * It performs a deep analysis of the SQL expression tree to determine if + * it contains only numeric literals.

+ * + *

The method traverses binary arithmetic operations and checks all operands + * to ensure they are all numeric literals. It accepts expressions with both + * integer and decimal literals.

+ * + *

This method is typically used during SQL parsing and validation to identify + * constant expressions that can be evaluated at compile time.

+ * + * @param node the SQL node to check, may be null + * @return true if the expression contains only numeric literals, false otherwise + */ + public static boolean isNumericLiteralSqlNode(SqlNode node) { + if (node == null) { + return false; + } + + Queue nodesToCheck = new ArrayDeque<>(); + nodesToCheck.add(node); + + while (!nodesToCheck.isEmpty()) { + SqlNode currentNode = nodesToCheck.poll(); + + if (isNumericLiteral(currentNode)) { + continue; // This node is numeric, check the next one + } + + if (currentNode instanceof SqlCall) { + SqlCall call = (SqlCall) currentNode; + SqlKind kind = call.getKind(); + if (kind.belongsTo(SqlKind.BINARY_ARITHMETIC)) { + // Add all operands to the queue for checking + nodesToCheck.addAll(call.getOperandList()); + continue; + } + } + + // If we reach here, the node is not a numeric literal and not a valid arithmetic call + return false; + } - return typeName == SqlTypeName.DECIMAL - || typeName == SqlTypeName.DOUBLE - || typeName == SqlTypeName.REAL - || typeName == SqlTypeName.FLOAT; + // If we've checked all nodes and none failed the numeric test, return true + return true; } } diff --git a/core/src/test/java/org/apache/calcite/sql/type/ReturnTypesTest.java b/core/src/test/java/org/apache/calcite/sql/type/ReturnTypesTest.java new file mode 100644 index 000000000000..7f2e09356147 --- /dev/null +++ b/core/src/test/java/org/apache/calcite/sql/type/ReturnTypesTest.java @@ -0,0 +1,572 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.sql.type; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexCallBinding; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.runtime.CalciteException; +import org.apache.calcite.runtime.Resources; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.SqlNumericLiteral; +import org.apache.calcite.sql.SqlOperatorBinding; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.validate.SqlValidator; +import org.apache.calcite.sql.validate.SqlValidatorException; +import org.apache.calcite.sql.validate.SqlValidatorImpl; +import org.apache.calcite.util.Pair; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.lang.reflect.Method; +import java.math.BigDecimal; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +/** + * Test for {@link org.apache.calcite.sql.type.ReturnTypes}. + * Tests the DECIMAL_PRODUCT return type inference without using reflection. + */ +class ReturnTypesTest { + + private SqlTypeFixture f; + private RelDataTypeFactory typeFactory; + private RexBuilder rexBuilder; + + @BeforeEach + void setUp() { + f = new SqlTypeFixture(); + typeFactory = f.typeFactory; + rexBuilder = new RexBuilder(typeFactory); + } + + @Test void testDecimalProductWithDecimalTypes() { + // Test case: Both operands are DECIMAL types + RelDataType decimal1 = typeFactory.createSqlType(SqlTypeName.DECIMAL, 10, 2); + RelDataType decimal2 = typeFactory.createSqlType(SqlTypeName.DECIMAL, 8, 3); + + SqlNumericLiteral literal1 = SqlNumericLiteral.createExactNumeric("123.45", SqlParserPos.ZERO); + SqlNumericLiteral literal2 = SqlNumericLiteral.createExactNumeric("67.89", SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.MULTIPLY.createCall(SqlParserPos.ZERO, literal1, literal2); + + SqlValidator validator = Mockito.mock(SqlValidatorImpl.class); + Mockito.when(validator.getTypeFactory()).thenReturn(typeFactory); + // Create a test operator binding + TestSqlOperatorBinding binding = new TestSqlOperatorBinding(validator, call, decimal1, + decimal2); + + // Test the DECIMAL_PRODUCT return type inference + RelDataType result = ReturnTypes.DECIMAL_PRODUCT.inferReturnType(binding); + + assertNotNull(result); + assertEquals(SqlTypeName.DECIMAL, result.getSqlTypeName()); + } + + @Test void testDecimalProductWithIntegerTypes() { + // Test case: Both operands are INTEGER types + RelDataType int1 = typeFactory.createSqlType(SqlTypeName.INTEGER); + RelDataType int2 = typeFactory.createSqlType(SqlTypeName.INTEGER); + + SqlNumericLiteral literal1 = SqlNumericLiteral.createExactNumeric("123", SqlParserPos.ZERO); + SqlNumericLiteral literal2 = SqlNumericLiteral.createExactNumeric("456", SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.MULTIPLY.createCall(SqlParserPos.ZERO, literal1, literal2); + SqlValidator validator = Mockito.mock(SqlValidatorImpl.class); + Mockito.when(validator.getTypeFactory()).thenReturn(typeFactory); + // Create a test operator binding + TestSqlOperatorBinding binding = new TestSqlOperatorBinding(validator, call, int1, int2); + + RelDataType result = ReturnTypes.DECIMAL_PRODUCT.inferReturnType(binding); + + assertNull(result); + } + + @Test void testDecimalProductWithMixedTypes() { + // Test case: One operand is DECIMAL, other is INTEGER + RelDataType decimalType = typeFactory.createSqlType(SqlTypeName.DECIMAL, 10, 2); + RelDataType integerType = typeFactory.createSqlType(SqlTypeName.INTEGER); + + SqlNumericLiteral decimalLiteral = SqlNumericLiteral.createExactNumeric("123.45", + SqlParserPos.ZERO); + SqlNumericLiteral intLiteral = SqlNumericLiteral.createExactNumeric("456", SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.MULTIPLY.createCall(SqlParserPos.ZERO, decimalLiteral, + intLiteral); + + SqlValidator validator = Mockito.mock(SqlValidatorImpl.class); + Mockito.when(validator.getTypeFactory()).thenReturn(typeFactory); + // Create a test operator binding + TestSqlOperatorBinding binding = new TestSqlOperatorBinding(validator, call, decimalType, + integerType); + + RelDataType result = ReturnTypes.DECIMAL_PRODUCT.inferReturnType(binding); + + assertNotNull(result); + assertEquals(SqlTypeName.DECIMAL, result.getSqlTypeName()); + } + + @Test void testDecimalProductWithBigIntTypes() { + // Test case: Both operands are BIGINT types + RelDataType bigint1 = typeFactory.createSqlType(SqlTypeName.BIGINT); + RelDataType bigint2 = typeFactory.createSqlType(SqlTypeName.BIGINT); + + SqlNumericLiteral literal1 = SqlNumericLiteral.createExactNumeric("123456789", + SqlParserPos.ZERO); + SqlNumericLiteral literal2 = SqlNumericLiteral.createExactNumeric("987654321", + SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.MULTIPLY.createCall(SqlParserPos.ZERO, literal1, literal2); + + SqlValidator validator = Mockito.mock(SqlValidatorImpl.class); + Mockito.when(validator.getTypeFactory()).thenReturn(typeFactory); + // Create a test operator binding + TestSqlOperatorBinding binding = new TestSqlOperatorBinding(validator, call, bigint1, + bigint2); + + RelDataType result = ReturnTypes.DECIMAL_PRODUCT.inferReturnType(binding); + + assertNull(result); + } + + @Test void testDecimalProductSqlNodeWithNegativeValues() { + // Test case: Both operands are DECIMAL types + RelDataType decimal1 = typeFactory.createSqlType(SqlTypeName.DECIMAL, 10, 2); + RelDataType decimal2 = typeFactory.createSqlType(SqlTypeName.DECIMAL, 10, 0); + + SqlNumericLiteral literal1 = SqlNumericLiteral.createExactNumeric("123.45", SqlParserPos.ZERO); + SqlNumericLiteral literal2 = SqlNumericLiteral.createExactNumeric("-67", SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.MULTIPLY.createCall(SqlParserPos.ZERO, literal1, literal2); + + SqlValidator validator = Mockito.mock(SqlValidatorImpl.class); + Mockito.when(validator.getTypeFactory()).thenReturn(typeFactory); + // Create a test operator binding + TestSqlOperatorBinding binding = new TestSqlOperatorBinding(validator, call, decimal1, + decimal2); + + // Test the DECIMAL_PRODUCT return type inference + RelDataType result = ReturnTypes.DECIMAL_PRODUCT.inferReturnType(binding); + + assertNotNull(result); + assertEquals(SqlTypeName.DECIMAL, result.getSqlTypeName()); + assertEquals(8, result.getPrecision()); + assertEquals(2, result.getScale()); + } + + @Test void testDecimalProductRexNodeWithNegativeValues() { + RexLiteral decimalLiteral1 = rexBuilder.makeExactLiteral(new BigDecimal("123.45")); + RexLiteral decimalLiteral2 = rexBuilder.makeExactLiteral(new BigDecimal("-67")); + RexCall rexCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, + decimalLiteral1, decimalLiteral2); + + // Create a RexCallBinding + RexCallBinding rexBinding = RexCallBinding.create(typeFactory, rexCall, + Collections.emptyList()); + + RelDataType result2 = ReturnTypes.DECIMAL_PRODUCT.inferReturnType(rexBinding); + + assertNotNull(result2); + assertEquals(SqlTypeName.DECIMAL, result2.getSqlTypeName()); + assertEquals(8, result2.getPrecision()); + assertEquals(2, result2.getScale()); + } + + @Test void testDecimalProductWithRexNodes() { + // Test using RexBuilder to create RexCall (similar to SqlNodeUtilsTest pattern) + RexLiteral decimalLiteral1 = rexBuilder.makeExactLiteral(new BigDecimal("123.45")); + RexLiteral decimalLiteral2 = rexBuilder.makeExactLiteral(new BigDecimal("67.89")); + RexCall rexCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, + decimalLiteral1, decimalLiteral2); + + // Create a RexCallBinding + RexCallBinding rexBinding = RexCallBinding.create(typeFactory, rexCall, + Collections.emptyList()); + + RelDataType result = ReturnTypes.DECIMAL_PRODUCT.inferReturnType(rexBinding); + + assertNotNull(result); + assertEquals(SqlTypeName.DECIMAL, result.getSqlTypeName()); + } + + @Test void testDecimalProductWithIntegerRexNodes() { + // Test using RexBuilder with integer literals + RexLiteral intLiteral1 = rexBuilder.makeExactLiteral(new BigDecimal("123")); + RexLiteral intLiteral2 = rexBuilder.makeExactLiteral(new BigDecimal("456")); + RexCall rexCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, + intLiteral1, intLiteral2); + + RexCallBinding rexBinding = RexCallBinding.create(typeFactory, rexCall, + Collections.emptyList()); + + RelDataType result = ReturnTypes.DECIMAL_PRODUCT.inferReturnType(rexBinding); + + assertNull(result); + } + + @Test void testDecimalProductWithMixedRexNodes() { + // Test using RexBuilder with mixed decimal and integer literals + RexLiteral decimalLiteral = rexBuilder.makeExactLiteral(new BigDecimal("123.45")); + RexLiteral intLiteral = rexBuilder.makeExactLiteral(new BigDecimal("456")); + RexCall rexCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, + decimalLiteral, intLiteral); + + RexCallBinding rexBinding = RexCallBinding.create(typeFactory, rexCall, + Collections.emptyList()); + + RelDataType result = ReturnTypes.DECIMAL_PRODUCT.inferReturnType(rexBinding); + + assertNotNull(result); + assertEquals(SqlTypeName.DECIMAL, result.getSqlTypeName()); + } + + // Comprehensive tests for getDecimalMultiplyBindingType method + + @Test void testGetDecimalMultiplyBindingTypeWithSqlCallBindingDecimalDecimal() throws Exception { + // Test SqlCallBinding with both operands as DECIMAL literals + RelDataType decimal1 = typeFactory.createSqlType(SqlTypeName.DECIMAL, 10, 2); + RelDataType decimal2 = typeFactory.createSqlType(SqlTypeName.DECIMAL, 8, 3); + + SqlNumericLiteral literal1 = SqlNumericLiteral.createExactNumeric("123.45", SqlParserPos.ZERO); + SqlNumericLiteral literal2 = SqlNumericLiteral.createExactNumeric("67.89", SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.MULTIPLY.createCall(SqlParserPos.ZERO, literal1, literal2); + + SqlValidator validator = Mockito.mock(SqlValidatorImpl.class); + Mockito.when(validator.getTypeFactory()).thenReturn(typeFactory); + // Create a test operator binding + TestSqlOperatorBinding binding = new TestSqlOperatorBinding(validator, call, decimal1, + decimal2); + + Pair result = invokeGetDecimalMultiplyBindingType(binding, + typeFactory); + + assertEquals(SqlTypeName.DECIMAL, result.left.getSqlTypeName()); + assertEquals(SqlTypeName.DECIMAL, result.right.getSqlTypeName()); + } + + @Test void testGetDecimalMultiplyBindingTypeWithSqlCallBindingDecimalInteger() throws Exception { + // Test SqlCallBinding with DECIMAL and INTEGER literals + RelDataType decimalType = typeFactory.createSqlType(SqlTypeName.DECIMAL, 10, 2); + RelDataType integerType = typeFactory.createSqlType(SqlTypeName.INTEGER); + + SqlNumericLiteral decimalLiteral = SqlNumericLiteral.createExactNumeric("123.45", + SqlParserPos.ZERO); + SqlNumericLiteral intLiteral = SqlNumericLiteral.createExactNumeric("456", SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.MULTIPLY.createCall(SqlParserPos.ZERO, decimalLiteral, + intLiteral); + + SqlValidator validator = Mockito.mock(SqlValidatorImpl.class); + Mockito.when(validator.getTypeFactory()).thenReturn(typeFactory); + // Create a test operator binding + TestSqlOperatorBinding binding = new TestSqlOperatorBinding(validator, call, decimalType, + integerType); + + Pair result = invokeGetDecimalMultiplyBindingType(binding, + typeFactory); + + assertEquals(SqlTypeName.DECIMAL, result.left.getSqlTypeName()); + assertEquals(SqlTypeName.DECIMAL, result.right.getSqlTypeName()); + } + + @Test void testGetDecimalMultiplyBindingTypeWithSqlCallBindingIntegerInteger() throws Exception { + // Test SqlCallBinding with both operands as INTEGER literals + RelDataType int1 = typeFactory.createSqlType(SqlTypeName.INTEGER); + RelDataType int2 = typeFactory.createSqlType(SqlTypeName.INTEGER); + + SqlNumericLiteral literal1 = SqlNumericLiteral.createExactNumeric("123", SqlParserPos.ZERO); + SqlNumericLiteral literal2 = SqlNumericLiteral.createExactNumeric("456", SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.MULTIPLY.createCall(SqlParserPos.ZERO, literal1, literal2); + + SqlValidator validator = Mockito.mock(SqlValidatorImpl.class); + Mockito.when(validator.getTypeFactory()).thenReturn(typeFactory); + // Create a test operator binding + TestSqlOperatorBinding binding = new TestSqlOperatorBinding(validator, call, int1, int2); + + Pair result = invokeGetDecimalMultiplyBindingType(binding, + typeFactory); + + assertEquals(SqlTypeName.INTEGER, result.left.getSqlTypeName()); + assertEquals(SqlTypeName.INTEGER, result.right.getSqlTypeName()); + } + + @Test void testGetDecimalMultiplyBindingTypeWithRexCallBindingDecimalDecimal() throws Exception { + // Test RexCallBinding with both operands as DECIMAL literals + RexLiteral decimalLiteral1 = rexBuilder.makeExactLiteral(new BigDecimal("123.45")); + RexLiteral decimalLiteral2 = rexBuilder.makeExactLiteral(new BigDecimal("67.89")); + RexCall rexCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, + decimalLiteral1, decimalLiteral2); + + RexCallBinding rexBinding = RexCallBinding.create(typeFactory, rexCall, + Collections.emptyList()); + + Pair result = invokeGetDecimalMultiplyBindingType(rexBinding, + typeFactory); + + assertEquals(SqlTypeName.DECIMAL, result.left.getSqlTypeName()); + assertEquals(SqlTypeName.DECIMAL, result.right.getSqlTypeName()); + } + + @Test void testGetDecimalMultiplyBindingTypeWithRexCallBindingDecimalInteger() throws Exception { + // Test RexCallBinding with DECIMAL and INTEGER literals + RexLiteral decimalLiteral = rexBuilder.makeExactLiteral(new BigDecimal("123.45")); + RexLiteral intLiteral = rexBuilder.makeExactLiteral(new BigDecimal("456")); + RexCall rexCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, + decimalLiteral, intLiteral); + + RexCallBinding rexBinding = RexCallBinding.create(typeFactory, rexCall, + Collections.emptyList()); + + Pair result = invokeGetDecimalMultiplyBindingType(rexBinding, + typeFactory); + + assertEquals(SqlTypeName.DECIMAL, result.left.getSqlTypeName()); + assertEquals(SqlTypeName.DECIMAL, result.right.getSqlTypeName()); + } + + @Test void testGetDecimalMultiplyBindingTypeWithRexCallBindingIntegerInteger() throws Exception { + // Test RexCallBinding with both operands as INTEGER literals + RexLiteral intLiteral1 = rexBuilder.makeExactLiteral(new BigDecimal("123")); + RexLiteral intLiteral2 = rexBuilder.makeExactLiteral(new BigDecimal("456")); + RexCall rexCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, + intLiteral1, intLiteral2); + + RexCallBinding rexBinding = RexCallBinding.create(typeFactory, rexCall, + Collections.emptyList()); + + Pair result = invokeGetDecimalMultiplyBindingType(rexBinding, + typeFactory); + + assertEquals(SqlTypeName.INTEGER, result.left.getSqlTypeName()); + assertEquals(SqlTypeName.INTEGER, result.right.getSqlTypeName()); + } + + @Test void testGetDecimalMultiplyBindingTypeWithRexCallBindingIntegerConversion() + throws Exception { + // Test integer to decimal conversion in RexCallBinding + RexLiteral intLiteral1 = rexBuilder.makeExactLiteral(new BigDecimal("1231234567890")); + RexLiteral intLiteral2 = rexBuilder.makeExactLiteral(new BigDecimal("12345678904567")); + RexCall rexCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, + intLiteral1, intLiteral2); + + RexCallBinding rexBinding = RexCallBinding.create(typeFactory, rexCall, + Collections.emptyList()); + + Pair result = invokeGetDecimalMultiplyBindingType(rexBinding, + typeFactory); + + // Check precision calculation for integer conversion + assertEquals(SqlTypeName.BIGINT, result.left.getSqlTypeName()); + + assertEquals(SqlTypeName.BIGINT, result.right.getSqlTypeName()); + } + + @Test void testGetDecimalMultiplyBindingTypeWithFallbackBinding() throws Exception { + // Test fallback to original types for non-SqlCallBinding/RexCallBinding + RelDataType int1 = typeFactory.createSqlType(SqlTypeName.INTEGER); + RelDataType int2 = typeFactory.createSqlType(SqlTypeName.INTEGER); + + TestOperatorBinding fallbackBinding = new TestOperatorBinding(typeFactory, int1, int2); + + Pair result = invokeGetDecimalMultiplyBindingType(fallbackBinding, + typeFactory); + + // Should return original types + assertEquals(SqlTypeName.INTEGER, result.left.getSqlTypeName()); + assertEquals(SqlTypeName.INTEGER, result.right.getSqlTypeName()); + } + + @Test void testGetDecimalMultiplyBindingTypeWithFallbackBinding2() throws Exception { + // Test fallback to original types for non-SqlCallBinding/RexCallBinding + RelDataType int1 = typeFactory.createSqlType(SqlTypeName.DECIMAL); + RelDataType int2 = typeFactory.createSqlType(SqlTypeName.DECIMAL); + + TestOperatorBinding fallbackBinding = new TestOperatorBinding(typeFactory, int1, int2); + + Pair result = invokeGetDecimalMultiplyBindingType(fallbackBinding, + typeFactory); + + // Should return original types + assertEquals(SqlTypeName.DECIMAL, result.left.getSqlTypeName()); + assertEquals(SqlTypeName.DECIMAL, result.right.getSqlTypeName()); + } + + @Test void testGetDecimalMultiplyBindingTypeWithFallbackBinding3() throws Exception { + // Test fallback to original types for non-SqlCallBinding/RexCallBinding + RelDataType int1 = typeFactory.createSqlType(SqlTypeName.DECIMAL); + RelDataType int2 = typeFactory.createSqlType(SqlTypeName.INTEGER); + + TestOperatorBinding fallbackBinding = new TestOperatorBinding(typeFactory, int1, int2); + + Pair result = invokeGetDecimalMultiplyBindingType(fallbackBinding, + typeFactory); + + // Should return original types + assertEquals(SqlTypeName.DECIMAL, result.left.getSqlTypeName()); + assertEquals(SqlTypeName.INTEGER, result.right.getSqlTypeName()); + } + + @Test void testGetDecimalMultiplyBindingTypeWithBigIntConversion() throws Exception { + // Test BIGINT to decimal conversion in RexCallBinding + RexLiteral bigIntLiteral = rexBuilder.makeExactLiteral(new BigDecimal("123456789")); + RexLiteral decimalLiteral = rexBuilder.makeExactLiteral(new BigDecimal("123.45")); + RexCall rexCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, + bigIntLiteral, decimalLiteral); + + RexCallBinding rexBinding = RexCallBinding.create(typeFactory, rexCall, + Collections.emptyList()); + + Pair result = invokeGetDecimalMultiplyBindingType(rexBinding, + typeFactory); + + assertEquals(SqlTypeName.DECIMAL, result.left.getSqlTypeName()); + assertEquals(9, result.left.getPrecision()); // 123456789 has 9 digits + assertEquals(0, result.left.getScale()); + + assertEquals(SqlTypeName.DECIMAL, result.right.getSqlTypeName()); + } + + @Test void testGetDecimalMultiplyBindingTypeWithZeroValue() throws Exception { + // Test with zero value (special case for precision calculation) + RexLiteral zeroLiteral = rexBuilder.makeExactLiteral(new BigDecimal("0")); + RexLiteral decimalLiteral = rexBuilder.makeExactLiteral(new BigDecimal("123.45")); + RexCall rexCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, + zeroLiteral, decimalLiteral); + + RexCallBinding rexBinding = RexCallBinding.create(typeFactory, rexCall, + Collections.emptyList()); + + Pair result = invokeGetDecimalMultiplyBindingType(rexBinding, + typeFactory); + + assertEquals(SqlTypeName.DECIMAL, result.left.getSqlTypeName()); + assertEquals(1, result.left.getPrecision()); // 0 should have precision 1 + assertEquals(0, result.left.getScale()); + + assertEquals(SqlTypeName.DECIMAL, result.right.getSqlTypeName()); + } + + @Test void testGetDecimalMultiplyBindingTypeWithNegativeValues() throws Exception { + // Test with negative values + RexLiteral negLiteral = rexBuilder.makeExactLiteral(new BigDecimal("-123")); + RexLiteral decimalLiteral = rexBuilder.makeExactLiteral(new BigDecimal("456.78")); + RexCall rexCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, + negLiteral, decimalLiteral); + + RexCallBinding rexBinding = RexCallBinding.create(typeFactory, rexCall, + Collections.emptyList()); + + Pair result = invokeGetDecimalMultiplyBindingType(rexBinding, + typeFactory); + + assertEquals(SqlTypeName.DECIMAL, result.left.getSqlTypeName()); + assertEquals(4, result.left.getPrecision()); // -123 has 4 digits (ignoring sign) + assertEquals(0, result.left.getScale()); + + assertEquals(SqlTypeName.DECIMAL, result.right.getSqlTypeName()); + } + + @Test void testGetDecimalMultiplyBindingTypeWithNegativeValuesSqlBinding() throws Exception { + RelDataType decimalType = typeFactory.createSqlType(SqlTypeName.DECIMAL, 10, 2); + RelDataType integerType = typeFactory.createSqlType(SqlTypeName.INTEGER); + + SqlNumericLiteral decimalLiteral = SqlNumericLiteral.createExactNumeric("123.45", + SqlParserPos.ZERO); + SqlNumericLiteral intLiteral = SqlNumericLiteral.createExactNumeric("-456", SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.MULTIPLY.createCall(SqlParserPos.ZERO, decimalLiteral, + intLiteral); + + SqlValidator validator = Mockito.mock(SqlValidatorImpl.class); + Mockito.when(validator.getTypeFactory()).thenReturn(typeFactory); + // Create a test operator binding + TestSqlOperatorBinding binding = new TestSqlOperatorBinding(validator, call, decimalType, + integerType); + Pair result = invokeGetDecimalMultiplyBindingType(binding, + typeFactory); + + assertEquals(SqlTypeName.DECIMAL, result.left.getSqlTypeName()); + + assertEquals(SqlTypeName.DECIMAL, result.right.getSqlTypeName()); + assertEquals(4, result.right.getPrecision()); // -456 has 4 digits (ignoring sign) + assertEquals(0, result.right.getScale()); + } + + // Helper method to invoke private getDecimalMultiplyBindingType method using reflection + @SuppressWarnings("unchecked") + private Pair invokeGetDecimalMultiplyBindingType( + SqlOperatorBinding binding, RelDataTypeFactory typeFactory) throws Exception { + Method method = ReturnTypes.class.getDeclaredMethod("getDecimalMultiplyBindingType", + SqlOperatorBinding.class, RelDataTypeFactory.class); + method.setAccessible(true); + return (Pair) method.invoke(null, binding, typeFactory); + } + + /** + * Simple test implementation of SqlOperatorBinding for fallback testing. + */ + private static class TestOperatorBinding extends SqlOperatorBinding { + private final RelDataType[] operandTypes; + + TestOperatorBinding(RelDataTypeFactory typeFactory, RelDataType... operandTypes) { + super(typeFactory, SqlStdOperatorTable.MULTIPLY); + this.operandTypes = operandTypes; + } + + @Override public int getOperandCount() { + return operandTypes.length; + } + + @Override public RelDataType getOperandType(int ordinal) { + return operandTypes[ordinal]; + } + + @Override public CalciteException newError(Resources.ExInst e) { + return new CalciteException(e.str(), null); + } + } + + /** + * Simple test implementation of SqlOperatorBinding for testing purposes. + */ + private static class TestSqlOperatorBinding extends SqlCallBinding { + private final RelDataType[] operandTypes; + private final SqlCall call; + + TestSqlOperatorBinding(SqlValidator validator, SqlCall call, RelDataType... operandTypes) { + super(validator, null, call); + this.operandTypes = operandTypes; + this.call = call; + } + + @Override public int getOperandCount() { + return operandTypes.length; + } + + @Override public RelDataType getOperandType(int ordinal) { + return operandTypes[ordinal]; + } + + @Override public CalciteException newError(Resources.ExInst e) { + return new CalciteException(e.str(), null); + } + } + +} diff --git a/core/src/test/java/org/apache/calcite/util/SqlNodeUtilsTest.java b/core/src/test/java/org/apache/calcite/util/SqlNodeUtilsTest.java new file mode 100644 index 000000000000..2aa795686527 --- /dev/null +++ b/core/src/test/java/org/apache/calcite/util/SqlNodeUtilsTest.java @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.util; + +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlCharStringLiteral; +import org.apache.calcite.sql.SqlLiteral; +import org.apache.calcite.sql.SqlNumericLiteral; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.parser.SqlParserPos; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.math.BigDecimal; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Test for SqlNodeUtils class. + */ +public class SqlNodeUtilsTest { + // Test data for RexNode tests + private RexBuilder rexBuilder; + + @BeforeEach + public void setUp() { + rexBuilder = new RexBuilder(new org.apache.calcite.jdbc.JavaTypeFactoryImpl()); + } + + @Test public void testIsDecimalConstantSqlNodeWithDecimalLiteral() { + // Test with a decimal literal + SqlNumericLiteral decimalLiteral = SqlNumericLiteral.createExactNumeric("123.45", + SqlParserPos.ZERO); + assertTrue(SqlNodeUtils.isDecimalConstantSqlNode(decimalLiteral)); + } + + @Test public void testIsDecimalConstantSqlNodeWithIntegerLiteral() { + // Test with an integer literal (which is not considered a decimal constant) + SqlNumericLiteral integerLiteral = SqlNumericLiteral.createExactNumeric("123", + SqlParserPos.ZERO); + assertFalse(SqlNodeUtils.isDecimalConstantSqlNode(integerLiteral)); + } + + @Test public void testIsDecimalConstantSqlNodeWithNull() { + // Test with null input + assertFalse(SqlNodeUtils.isDecimalConstantSqlNode(null)); + } + + @Test public void testIsDecimalConstantSqlNodeWithIntegerCall() { + // Test with an integer literal (which is not considered a decimal constant) + SqlNumericLiteral decimalLiteral1 = SqlNumericLiteral.createExactNumeric("123", + SqlParserPos.ZERO); + SqlNumericLiteral decimalLiteral2 = SqlNumericLiteral.createExactNumeric("678", + SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.PLUS.createCall(SqlParserPos.ZERO, decimalLiteral1, + decimalLiteral2); + assertFalse(SqlNodeUtils.isDecimalConstantSqlNode(call)); + } + + @Test public void testIsDecimalConstantSqlNodeWithSimpleCall() { + // Test with a simple call containing decimal constants + SqlNumericLiteral decimalLiteral1 = SqlNumericLiteral.createExactNumeric("123.45", + SqlParserPos.ZERO); + SqlNumericLiteral decimalLiteral2 = SqlNumericLiteral.createExactNumeric("678.90", + SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.PLUS.createCall(SqlParserPos.ZERO, decimalLiteral1, + decimalLiteral2); + assertTrue(SqlNodeUtils.isDecimalConstantSqlNode(call)); + } + + @Test public void testIsDecimalConstantSqlNodeWithMixedCall() { + // Test with a call containing both decimal and integer constants + SqlNumericLiteral decimalLiteral = SqlNumericLiteral.createExactNumeric("123.45", + SqlParserPos.ZERO); + SqlNumericLiteral integerLiteral = SqlNumericLiteral.createExactNumeric("678", + SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.PLUS.createCall(SqlParserPos.ZERO, decimalLiteral, + integerLiteral); + assertTrue(SqlNodeUtils.isDecimalConstantSqlNode(call)); + } + + @Test public void testIsDecimalConstantSqlNodeWithDeeplyNestedCall() { + // Test with a deeply nested call containing decimal constants + SqlNumericLiteral decimalLiteral1 = SqlNumericLiteral.createExactNumeric("1.1", + SqlParserPos.ZERO); + SqlNumericLiteral decimalLiteral2 = SqlNumericLiteral.createExactNumeric("2.2", + SqlParserPos.ZERO); + SqlNumericLiteral decimalLiteral3 = SqlNumericLiteral.createExactNumeric("3.3", + SqlParserPos.ZERO); + SqlNumericLiteral decimalLiteral4 = SqlNumericLiteral.createExactNumeric("4.4", + SqlParserPos.ZERO); + + // Create a deeply nested expression: ((1.1 + 2.2) + 3.3) + 4.4 + SqlCall innerCall1 = SqlStdOperatorTable.PLUS.createCall(SqlParserPos.ZERO, decimalLiteral1, + decimalLiteral2); + SqlCall innerCall2 = SqlStdOperatorTable.PLUS.createCall(SqlParserPos.ZERO, innerCall1, + decimalLiteral3); + SqlCall outerCall = SqlStdOperatorTable.PLUS.createCall(SqlParserPos.ZERO, innerCall2, + decimalLiteral4); + + assertTrue(SqlNodeUtils.isDecimalConstantSqlNode(outerCall)); + } + + @Test public void testIsDecimalConstantSqlNodeWithNonArithmeticCall() { + // Test with a non-arithmetic call (CONCAT) containing decimal constants + SqlNumericLiteral decimalLiteral1 = SqlNumericLiteral.createExactNumeric("123.45", + SqlParserPos.ZERO); + SqlNumericLiteral decimalLiteral2 = SqlNumericLiteral.createExactNumeric("678.90", + SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.CONCAT.createCall(SqlParserPos.ZERO, decimalLiteral1, + decimalLiteral2); + assertFalse(SqlNodeUtils.isDecimalConstantSqlNode(call)); + } + + + @Test public void testIsDecimalConstantRexNodeWithDecimalLiteral() { + // Test with a decimal literal + RexLiteral decimalLiteral = rexBuilder.makeExactLiteral(new BigDecimal("123.45")); + assertTrue(SqlNodeUtils.isDecimalConstantRexNode(decimalLiteral)); + } + + @Test public void testIsDecimalConstantRexNodeWithIntegerLiteral() { + // Test with an integer literal (which is not considered a decimal constant) + RexLiteral integerLiteral = rexBuilder.makeExactLiteral(new BigDecimal("123")); + // Even though it's created with BigDecimal, it's still considered a decimal constant + assertFalse(SqlNodeUtils.isDecimalConstantRexNode(integerLiteral)); + } + + @Test public void testIsDecimalConstantRexNodeWithNull() { + // Test with null input + assertFalse(SqlNodeUtils.isDecimalConstantRexNode(null)); + } + + @Test public void testIsDecimalConstantRexNodeWithIntegerCall() { + // Test with a simple call containing decimal constants + RexLiteral decimalLiteral1 = rexBuilder.makeExactLiteral(new BigDecimal("123")); + RexLiteral decimalLiteral2 = rexBuilder.makeExactLiteral(new BigDecimal("678")); + RexCall call = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.PLUS, decimalLiteral1, + decimalLiteral2); + assertFalse(SqlNodeUtils.isDecimalConstantRexNode(call)); + } + + @Test public void testIsDecimalConstantRexNodeWithSimpleCall() { + // Test with a simple call containing decimal constants + RexLiteral decimalLiteral1 = rexBuilder.makeExactLiteral(new BigDecimal("123.45")); + RexLiteral decimalLiteral2 = rexBuilder.makeExactLiteral(new BigDecimal("678.90")); + RexCall call = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.PLUS, decimalLiteral1, + decimalLiteral2); + assertTrue(SqlNodeUtils.isDecimalConstantRexNode(call)); + } + + @Test public void testIsDecimalConstantRexNodeWithMixedCall() { + // Test with a call containing both decimal and integer constants + RexLiteral decimalLiteral = rexBuilder.makeExactLiteral(new BigDecimal("123.45")); + RexLiteral integerLiteral = rexBuilder.makeExactLiteral(new BigDecimal("678")); + RexCall call = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.PLUS, decimalLiteral, + integerLiteral); + assertTrue(SqlNodeUtils.isDecimalConstantRexNode(call)); + } + + @Test public void testIsDecimalConstantRexNodeWithDeeplyNestedCall() { + // Test with a deeply nested call containing decimal constants + RexLiteral decimalLiteral1 = rexBuilder.makeExactLiteral(new BigDecimal("1.1")); + RexLiteral decimalLiteral2 = rexBuilder.makeExactLiteral(new BigDecimal("2.2")); + RexLiteral decimalLiteral3 = rexBuilder.makeExactLiteral(new BigDecimal("3.3")); + RexLiteral decimalLiteral4 = rexBuilder.makeExactLiteral(new BigDecimal("4.4")); + + // Create a deeply nested expression: ((1.1 + 2.2) + 3.3) + 4.4 + RexCall innerCall1 = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.PLUS, decimalLiteral1, + decimalLiteral2); + RexCall innerCall2 = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.PLUS, innerCall1, + decimalLiteral3); + RexCall outerCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.PLUS, innerCall2, + decimalLiteral4); + + assertTrue(SqlNodeUtils.isDecimalConstantRexNode(outerCall)); + } + + @Test public void testIsDecimalConstantRexNodeWithNonArithmeticCall() { + // Test with a non-arithmetic call (LIKE) containing decimal constants + RexLiteral decimalLiteral1 = rexBuilder.makeExactLiteral(new BigDecimal("123.45")); + RexLiteral decimalLiteral2 = rexBuilder.makeLiteral("678.90"); + RexCall call = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.LIKE, decimalLiteral1, + decimalLiteral2); + assertFalse(SqlNodeUtils.isDecimalConstantRexNode(call)); + } + + @Test public void testIsNumericLiteralRexNodeWithNumericLiteral() { + // Test with a numeric literal + RexLiteral numericLiteral = rexBuilder.makeExactLiteral(new BigDecimal("123.45")); + assertTrue(SqlNodeUtils.isNumericLiteralRexNode(numericLiteral)); + } + + @Test public void testIsNumericLiteralRexNodeWithNonNumericLiteral() { + // Test with a non-numeric literal + RexLiteral stringLiteral = rexBuilder.makeLiteral("hello"); + assertFalse(SqlNodeUtils.isNumericLiteralRexNode(stringLiteral)); + } + + @Test public void testIsNumericLiteralRexNodeWithNull() { + // Test with null input + assertFalse(SqlNodeUtils.isNumericLiteralRexNode(null)); + } + + @Test public void testIsNumericLiteralRexNodeWithArithmeticCall() { + // Test with an arithmetic call containing numeric literals + RexLiteral numericLiteral1 = rexBuilder.makeExactLiteral(new BigDecimal("123.45")); + RexLiteral numericLiteral2 = rexBuilder.makeExactLiteral(new BigDecimal("678")); + RexCall call = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.PLUS, numericLiteral1, + numericLiteral2); + assertTrue(SqlNodeUtils.isNumericLiteralRexNode(call)); + } + + @Test public void testIsNumericLiteralRexNodeWithNonNumericCall() { + // Test with a call containing non-numeric literals + RexLiteral stringLiteral1 = rexBuilder.makeLiteral("hello"); + RexLiteral stringLiteral2 = rexBuilder.makeLiteral("world"); + RexCall call = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.PLUS, stringLiteral1, + stringLiteral2); + assertFalse(SqlNodeUtils.isNumericLiteralRexNode(call)); + } + + @Test public void testIsNumericLiteralRexNodeWithMixedCall() { + // Test with a call containing both numeric and non-numeric literals + RexLiteral numericLiteral = rexBuilder.makeExactLiteral(new BigDecimal("123.45")); + RexLiteral stringLiteral = rexBuilder.makeLiteral("hello"); + RexCall call = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.PLUS, numericLiteral, + stringLiteral); + assertFalse(SqlNodeUtils.isNumericLiteralRexNode(call)); + } + + @Test public void testIsNumericLiteralRexNodeWithDeeplyNestedCall() { + // Test with a deeply nested arithmetic call containing numeric literals + RexLiteral numericLiteral1 = rexBuilder.makeExactLiteral(new BigDecimal("1.1")); + RexLiteral numericLiteral2 = rexBuilder.makeExactLiteral(new BigDecimal("2.2")); + RexLiteral numericLiteral3 = rexBuilder.makeExactLiteral(new BigDecimal("3.3")); + RexLiteral numericLiteral4 = rexBuilder.makeExactLiteral(new BigDecimal("4.4")); + + // Create a deeply nested expression: ((1.1 + 2.2) + 3.3) + 4.4 + RexCall innerCall1 = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.PLUS, numericLiteral1, + numericLiteral2); + RexCall innerCall2 = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.PLUS, innerCall1, + numericLiteral3); + RexCall outerCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.PLUS, innerCall2, + numericLiteral4); + + assertTrue(SqlNodeUtils.isNumericLiteralRexNode(outerCall)); + } + + @Test public void testIsNumericLiteralRexNodeWithNonArithmeticCall() { + // Test with a non-arithmetic call containing numeric literals + RexLiteral numericLiteral1 = rexBuilder.makeExactLiteral(new BigDecimal("123.45")); + RexLiteral numericLiteral2 = rexBuilder.makeExactLiteral(new BigDecimal("678.90")); + RexCall call = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.LIKE, numericLiteral1, + numericLiteral2); + assertFalse(SqlNodeUtils.isNumericLiteralRexNode(call)); + } + + @Test public void testIsNumericLiteralRexNodeWithDifferentArithmeticOperators() { + // Test with different arithmetic operators + RexLiteral numericLiteral1 = rexBuilder.makeExactLiteral(new BigDecimal("10")); + RexLiteral numericLiteral2 = rexBuilder.makeExactLiteral(new BigDecimal("5")); + + // Test PLUS + RexCall plusCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.PLUS, numericLiteral1, + numericLiteral2); + assertTrue(SqlNodeUtils.isNumericLiteralRexNode(plusCall)); + + // Test MINUS + RexCall minusCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.MINUS, numericLiteral1, + numericLiteral2); + assertTrue(SqlNodeUtils.isNumericLiteralRexNode(minusCall)); + + // Test MULTIPLY + RexCall timesCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, numericLiteral1, + numericLiteral2); + assertTrue(SqlNodeUtils.isNumericLiteralRexNode(timesCall)); + + // Test DIVIDE + RexCall divideCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numericLiteral1, + numericLiteral2); + assertTrue(SqlNodeUtils.isNumericLiteralRexNode(divideCall)); + + // Test MOD + RexCall modCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.MOD, numericLiteral1, + numericLiteral2); + assertTrue(SqlNodeUtils.isNumericLiteralRexNode(modCall)); + } + + @Test public void testIsNumericLiteralRexNodeWithApproximateNumericLiterals() { + // Test with approximate numeric literals + RexLiteral approxLiteral1 = rexBuilder.makeApproxLiteral(new BigDecimal("1.23E45")); + RexLiteral approxLiteral2 = rexBuilder.makeApproxLiteral(new BigDecimal("6.78E90")); + RexCall call = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.PLUS, approxLiteral1, + approxLiteral2); + assertTrue(SqlNodeUtils.isNumericLiteralRexNode(call)); + } + + @Test public void testIsNumericLiteralRexNodeWithMixedArithmeticOperators() { + // Test with mixed arithmetic operators in nested expressions + RexLiteral numericLiteral1 = rexBuilder.makeExactLiteral(new BigDecimal("10")); + RexLiteral numericLiteral2 = rexBuilder.makeExactLiteral(new BigDecimal("5")); + RexLiteral numericLiteral3 = rexBuilder.makeExactLiteral(new BigDecimal("2")); + + // Create a complex expression: (10 + 5) * 2 + RexCall innerCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.PLUS, numericLiteral1, + numericLiteral2); + RexCall outerCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, innerCall, + numericLiteral3); + + assertTrue(SqlNodeUtils.isNumericLiteralRexNode(outerCall)); + } + + @Test public void testIsNumericLiteralRexNodeWithUnaryMinus() { + // Test with unary minus operator + RexLiteral numericLiteral = rexBuilder.makeExactLiteral(new BigDecimal("10")); + RexCall call = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.UNARY_MINUS, numericLiteral); + + assertFalse(SqlNodeUtils.isNumericLiteralRexNode(call)); + } + + @Test public void testIsNumericLiteralRexNodeWithUnaryPlus() { + // Test with unary plus operator + RexLiteral numericLiteral = rexBuilder.makeExactLiteral(new BigDecimal("10")); + RexCall call = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.UNARY_PLUS, numericLiteral); + + assertFalse(SqlNodeUtils.isNumericLiteralRexNode(call)); + } + + @Test public void testIsNumericLiteralRexNodeWithUnaryOperatorsAndArithmetic() { + // Test with unary operators combined with arithmetic operators + RexLiteral numericLiteral1 = rexBuilder.makeExactLiteral(new BigDecimal("10")); + RexLiteral numericLiteral2 = rexBuilder.makeExactLiteral(new BigDecimal("5")); + + // Create expression: (-10) + 5 + RexCall unaryMinus = + (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.UNARY_MINUS, numericLiteral1); + RexCall call = + (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.PLUS, unaryMinus, numericLiteral2); + + assertFalse(SqlNodeUtils.isNumericLiteralRexNode(call)); + } + + @Test public void testIsNumericLiteralSqlNodeWithNumericLiteral() { + // Test with a numeric literal + SqlNumericLiteral numericLiteral = SqlNumericLiteral.createExactNumeric("123.45", + SqlParserPos.ZERO); + assertTrue(SqlNodeUtils.isNumericLiteralSqlNode(numericLiteral)); + } + + @Test public void testIsNumericLiteralSqlNodeWithNonNumericLiteral() { + // Test with a non-numeric literal (string literal) + SqlCharStringLiteral stringLiteral = SqlLiteral.createCharString("hello", SqlParserPos.ZERO); + assertFalse(SqlNodeUtils.isNumericLiteralSqlNode(stringLiteral)); + } + + @Test public void testIsNumericLiteralSqlNodeWithNull() { + // Test with null input + assertFalse(SqlNodeUtils.isNumericLiteralSqlNode(null)); + } + + @Test public void testIsNumericLiteralSqlNodeWithArithmeticCall() { + // Test with an arithmetic call containing numeric literals + SqlNumericLiteral numericLiteral1 = SqlNumericLiteral.createExactNumeric("123.45", + SqlParserPos.ZERO); + SqlNumericLiteral numericLiteral2 = SqlNumericLiteral.createExactNumeric("678", + SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.PLUS.createCall(SqlParserPos.ZERO, numericLiteral1, + numericLiteral2); + assertTrue(SqlNodeUtils.isNumericLiteralSqlNode(call)); + } + + @Test public void testIsNumericLiteralSqlNodeWithNonNumericCall() { + // Test with a call containing non-numeric literals + SqlCharStringLiteral stringLiteral1 = SqlLiteral.createCharString("hello", SqlParserPos.ZERO); + SqlCharStringLiteral stringLiteral2 = SqlLiteral.createCharString("world", SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.PLUS.createCall(SqlParserPos.ZERO, stringLiteral1, + stringLiteral2); + assertFalse(SqlNodeUtils.isNumericLiteralSqlNode(call)); + } + + @Test public void testIsNumericLiteralSqlNodeWithMixedCall() { + // Test with a call containing both numeric and non-numeric literals + SqlNumericLiteral numericLiteral = SqlNumericLiteral.createExactNumeric("123.45", + SqlParserPos.ZERO); + SqlCharStringLiteral stringLiteral = SqlLiteral.createCharString("hello", SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.PLUS.createCall(SqlParserPos.ZERO, numericLiteral, + stringLiteral); + assertFalse(SqlNodeUtils.isNumericLiteralSqlNode(call)); + } + + @Test public void testIsNumericLiteralSqlNodeWithDeeplyNestedCall() { + // Test with a deeply nested arithmetic call containing numeric literals + SqlNumericLiteral numericLiteral1 = SqlNumericLiteral.createExactNumeric("1.1", + SqlParserPos.ZERO); + SqlNumericLiteral numericLiteral2 = SqlNumericLiteral.createExactNumeric("2.2", + SqlParserPos.ZERO); + SqlNumericLiteral numericLiteral3 = SqlNumericLiteral.createExactNumeric("3.3", + SqlParserPos.ZERO); + SqlNumericLiteral numericLiteral4 = SqlNumericLiteral.createExactNumeric("4.4", + SqlParserPos.ZERO); + + // Create a deeply nested expression: ((1.1 + 2.2) + 3.3) + 4.4 + SqlCall innerCall1 = SqlStdOperatorTable.PLUS.createCall(SqlParserPos.ZERO, numericLiteral1, + numericLiteral2); + SqlCall innerCall2 = SqlStdOperatorTable.PLUS.createCall(SqlParserPos.ZERO, innerCall1, + numericLiteral3); + SqlCall outerCall = SqlStdOperatorTable.PLUS.createCall(SqlParserPos.ZERO, innerCall2, + numericLiteral4); + + assertTrue(SqlNodeUtils.isNumericLiteralSqlNode(outerCall)); + } + + @Test public void testIsNumericLiteralSqlNodeWithNonArithmeticCall() { + // Test with a non-arithmetic call containing numeric literals + SqlNumericLiteral numericLiteral1 = SqlNumericLiteral.createExactNumeric("123.45", + SqlParserPos.ZERO); + SqlNumericLiteral numericLiteral2 = SqlNumericLiteral.createExactNumeric("678.90", + SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.LIKE.createCall(SqlParserPos.ZERO, numericLiteral1, + numericLiteral2); + assertFalse(SqlNodeUtils.isNumericLiteralSqlNode(call)); + } + + @Test public void testIsNumericLiteralSqlNodeWithDifferentArithmeticOperators() { + // Test with different arithmetic operators + SqlNumericLiteral numericLiteral1 = + SqlNumericLiteral.createExactNumeric("10", SqlParserPos.ZERO); + SqlNumericLiteral numericLiteral2 = + SqlNumericLiteral.createExactNumeric("5", SqlParserPos.ZERO); + + // Test PLUS + SqlCall plusCall = SqlStdOperatorTable.PLUS.createCall(SqlParserPos.ZERO, numericLiteral1, + numericLiteral2); + assertTrue(SqlNodeUtils.isNumericLiteralSqlNode(plusCall)); + + // Test MINUS + SqlCall minusCall = SqlStdOperatorTable.MINUS.createCall(SqlParserPos.ZERO, numericLiteral1, + numericLiteral2); + assertTrue(SqlNodeUtils.isNumericLiteralSqlNode(minusCall)); + + // Test MULTIPLY + SqlCall timesCall = SqlStdOperatorTable.MULTIPLY.createCall(SqlParserPos.ZERO, numericLiteral1, + numericLiteral2); + assertTrue(SqlNodeUtils.isNumericLiteralSqlNode(timesCall)); + + // Test DIVIDE + SqlCall divideCall = SqlStdOperatorTable.DIVIDE.createCall(SqlParserPos.ZERO, numericLiteral1, + numericLiteral2); + assertTrue(SqlNodeUtils.isNumericLiteralSqlNode(divideCall)); + + // Test MOD + SqlCall modCall = SqlStdOperatorTable.MOD.createCall(SqlParserPos.ZERO, numericLiteral1, + numericLiteral2); + assertTrue(SqlNodeUtils.isNumericLiteralSqlNode(modCall)); + } + + @Test public void testIsNumericLiteralSqlNodeWithApproximateNumeric() { + // Test with approximate numeric literals + SqlNumericLiteral approxLiteral1 = SqlNumericLiteral.createApproxNumeric("1.23E45", + SqlParserPos.ZERO); + SqlNumericLiteral approxLiteral2 = SqlNumericLiteral.createApproxNumeric("6.78E90", + SqlParserPos.ZERO); + SqlCall call = SqlStdOperatorTable.PLUS.createCall(SqlParserPos.ZERO, approxLiteral1, + approxLiteral2); + assertTrue(SqlNodeUtils.isNumericLiteralSqlNode(call)); + } +}