diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala index d08501674d5b..d2429a2192e7 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala @@ -381,6 +381,16 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS } } + test("Test get_json_object 13") { + val sql = + """ + |SELECT + | explode(array(get_json_object(get_json_object('{"a": "{\\\"b\\\":1}"}', '$.a'), '$.b'))) + | from range(1) + |""".stripMargin + runQueryAndCompare(sql) { df => } + } + test("GLUTEN-8557: Optimize nested and/or") { def checkFlattenedFunctions(plan: SparkPlan, functionName: String, argNum: Int): Boolean = { diff --git a/cpp-ch/local-engine/Parser/RelParsers/ProjectRelParser.cpp b/cpp-ch/local-engine/Parser/RelParsers/ProjectRelParser.cpp index d9d2e55f3436..97b41661eead 100644 --- a/cpp-ch/local-engine/Parser/RelParsers/ProjectRelParser.cpp +++ b/cpp-ch/local-engine/Parser/RelParsers/ProjectRelParser.cpp @@ -116,7 +116,10 @@ DB::QueryPlanPtr ProjectRelParser::parseReplicateRows(DB::QueryPlanPtr query_pla DB::QueryPlanPtr ProjectRelParser::parseGenerate(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list & /*rel_stack_*/) { - const auto & generate_rel = rel.generate(); + ExpressionsRewriter rewriter(parser_context); + substrait::Rel final_rel = rel; + rewriter.rewrite(final_rel); + const auto & generate_rel = final_rel.generate(); if (isReplicateRows(generate_rel)) { return parseReplicateRows(std::move(query_plan), generate_rel); diff --git a/cpp-ch/local-engine/Rewriter/ExpressionRewriter.h b/cpp-ch/local-engine/Rewriter/ExpressionRewriter.h index 8d4fcaca4420..ba44b7925fc7 100644 --- a/cpp-ch/local-engine/Rewriter/ExpressionRewriter.h +++ b/cpp-ch/local-engine/Rewriter/ExpressionRewriter.h @@ -38,10 +38,6 @@ class GetJsonObjectFunctionWriter : public RelRewriter void rewrite(substrait::Rel & rel) override { - if (!rel.has_filter() && !rel.has_project()) - { - return; - } prepare(rel); rewriteImpl(rel); } @@ -64,6 +60,14 @@ class GetJsonObjectFunctionWriter : public RelRewriter prepareOnExpression(expr); } } + if (rel.has_generate()) + { + for (auto & expr : rel.generate().child_output()) + { + prepareOnExpression(expr); + } + prepareOnExpression(rel.generate().generator()); + } } void rewriteImpl(substrait::Rel & rel) @@ -84,6 +88,17 @@ class GetJsonObjectFunctionWriter : public RelRewriter rewriteExpression(*expr); } } + if (rel.has_generate()) + { + auto * generate = rel.mutable_generate(); + auto * child_outputs = generate->mutable_child_output(); + for (int i = 0; i < child_outputs->size(); ++i) + { + auto * expr = child_outputs->Mutable(i); + rewriteExpression(*expr); + } + rewriteExpression(*generate->mutable_generator()); + } } void prepareOnExpression(const substrait::Expression & expr) {