Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 66 additions & 54 deletions datafusion/functions-aggregate/src/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,54 +402,6 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
Ok(())
}

/// Folds a batch of partial correlation states into the per-group state
/// held by this accumulator.
///
/// `values` is read positionally as six arrays: the row count (`UInt64`)
/// followed by `sum_x`, `sum_y`, `sum_xy`, `sum_xx` and `sum_yy`
/// (all `Float64`). `group_indices` and the partial-state arrays are handed
/// to `accumulate_correlation_states`, and whatever
/// `(group index, count, sums)` triples it passes to the callback are added
/// onto the matching slots of the state vectors.
///
/// # Panics
///
/// Panics if `opt_filter` is `Some` (filters are expected to have been
/// applied already in the partial stage), or if `values` has fewer than
/// six entries.
fn merge_batch(
    &mut self,
    values: &[ArrayRef],
    group_indices: &[usize],
    opt_filter: Option<&BooleanArray>,
    total_num_groups: usize,
) -> Result<()> {
    // Make every state vector `total_num_groups` long before it is indexed below.
    self.count.resize(total_num_groups, 0);
    for running_sum in [
        &mut self.sum_x,
        &mut self.sum_y,
        &mut self.sum_xy,
        &mut self.sum_xx,
        &mut self.sum_yy,
    ] {
        running_sum.resize(total_num_groups, 0.0);
    }

    // Downcast the incoming partial-state columns (positional layout).
    let counts = values[0].as_primitive::<UInt64Type>();
    let xs = values[1].as_primitive::<Float64Type>();
    let ys = values[2].as_primitive::<Float64Type>();
    let xys = values[3].as_primitive::<Float64Type>();
    let xxs = values[4].as_primitive::<Float64Type>();
    let yys = values[5].as_primitive::<Float64Type>();

    assert!(opt_filter.is_none(), "aggregate filter should be applied in partial stage, there should be no filter in final stage");

    let partial_states = (counts, xs, ys, xys, xxs, yys);
    accumulate_correlation_states(group_indices, partial_states, |group, n, sums| {
        // `sums` follows the same order as the arrays after the count:
        // x, y, xy, xx, yy.
        self.count[group] += n;
        self.sum_x[group] += sums[0];
        self.sum_y[group] += sums[1];
        self.sum_xy[group] += sums[2];
        self.sum_xx[group] += sums[3];
        self.sum_yy[group] += sums[4];
    });

    Ok(())
}

fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
let n = match emit_to {
EmitTo::All => self.count.len(),
Expand All @@ -465,21 +417,33 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
// - Correlation can't be calculated when a group only has 1 record, or when
// the `denominator` state is 0. In these cases, the final aggregation
// result should be `Null` (according to PostgreSQL's behavior).
// - However, if any of the accumulated values contain NaN, the result should
// be NaN regardless of the count (even for single-row groups).
//
for i in 0..n {
if self.count[i] < 2 {
values.push(0.0);
nulls.append_null();
continue;
}

let count = self.count[i];
let sum_x = self.sum_x[i];
let sum_y = self.sum_y[i];
let sum_xy = self.sum_xy[i];
let sum_xx = self.sum_xx[i];
let sum_yy = self.sum_yy[i];

// Check for NaN in the sums BEFORE checking count.
// If BOTH sum_x AND sum_y are NaN, NaN was seen in both inputs of the group → return NaN.
// If only ONE of them is NaN, NaN was seen in only one input → return NULL.
// This takes precedence over the count < 2 check.
if sum_x.is_nan() && sum_y.is_nan() {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider aligning CorrelationAccumulator.evaluate() (non-group path) with this NaN precedence: it currently returns NULL when n < 2, so a single row with both inputs NaN yields NULL, whereas grouped corr now yields NaN in the same scenario. This inconsistency can lead to different results for grouped vs non-grouped queries with NaNs.

🤖 Was this useful? React with 👍 or 👎

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value:useful; category:bug; feedback:The Augment AI reviewer is correct that the same NaN check should be applied in CorrelationAccumulator::evaluate() when the count is less than 2. Otherwise the two paths would be inconsistent: CorrelationGroupsAccumulator would return NaN, while CorrelationAccumulator would return Null.

// Both inputs are NaN → return NaN
values.push(f64::NAN);
nulls.append_non_null();
continue;
} else if count < 2 || sum_x.is_nan() || sum_y.is_nan() {
// Fewer than two rows, or only one of the sums is NaN → return NULL
values.push(0.0);
nulls.append_null();
continue;
}

let mean_x = sum_x / count as f64;
let mean_y = sum_y / count as f64;

Expand Down Expand Up @@ -515,6 +479,54 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
])
}

/// Folds a batch of partial correlation states into the per-group state
/// held by this accumulator.
///
/// `values` is read positionally as six arrays: the row count (`UInt64`)
/// followed by `sum_x`, `sum_y`, `sum_xy`, `sum_xx` and `sum_yy`
/// (all `Float64`). `group_indices` and the partial-state arrays are handed
/// to `accumulate_correlation_states`, and whatever
/// `(group index, count, sums)` triples it passes to the callback are added
/// onto the matching slots of the state vectors.
///
/// # Panics
///
/// Panics if `opt_filter` is `Some` (filters are expected to have been
/// applied already in the partial stage), or if `values` has fewer than
/// six entries.
fn merge_batch(
    &mut self,
    values: &[ArrayRef],
    group_indices: &[usize],
    opt_filter: Option<&BooleanArray>,
    total_num_groups: usize,
) -> Result<()> {
    // Make every state vector `total_num_groups` long before it is indexed below.
    self.count.resize(total_num_groups, 0);
    for running_sum in [
        &mut self.sum_x,
        &mut self.sum_y,
        &mut self.sum_xy,
        &mut self.sum_xx,
        &mut self.sum_yy,
    ] {
        running_sum.resize(total_num_groups, 0.0);
    }

    // Downcast the incoming partial-state columns (positional layout).
    let counts = values[0].as_primitive::<UInt64Type>();
    let xs = values[1].as_primitive::<Float64Type>();
    let ys = values[2].as_primitive::<Float64Type>();
    let xys = values[3].as_primitive::<Float64Type>();
    let xxs = values[4].as_primitive::<Float64Type>();
    let yys = values[5].as_primitive::<Float64Type>();

    assert!(opt_filter.is_none(), "aggregate filter should be applied in partial stage, there should be no filter in final stage");

    let partial_states = (counts, xs, ys, xys, xxs, yys);
    accumulate_correlation_states(group_indices, partial_states, |group, n, sums| {
        // `sums` follows the same order as the arrays after the count:
        // x, y, xy, xx, yy.
        self.count[group] += n;
        self.sum_x[group] += sums[0];
        self.sum_y[group] += sums[1];
        self.sum_xy[group] += sums[2];
        self.sum_xx[group] += sums[3];
        self.sum_yy[group] += sums[4];
    });

    Ok(())
}

fn size(&self) -> usize {
size_of_val(&self.count)
+ size_of_val(&self.sum_x)
Expand Down
58 changes: 58 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,64 @@ from data
----
1

# correlation_query_with_nans_f32
query IR
with data as (
select 1 id, 1 as f, 'nan'::float as b
union all
select 2 id, 'nan'::float as f, 1 as b
union all
select 3 id, 'nan'::float as f, null as b
union all
select 4 id, null as f, 'nan'::float as b
union all
select 5 id, 'nan'::float as f, 'nan'::float as b
union all
select 5 id, 1 as f, 1 as b
union all
select 6 id, 'nan'::float as f, 'nan'::float as b
)
select id, corr(f, b)
from data
group by id
order by id
----
1 NULL
2 NULL
3 NULL
4 NULL
5 NaN
6 NaN

# correlation_query_with_nans_f64
query IR
with data as (
select 1 id, 1 as f, 'nan'::double as b
union all
select 2 id, 'nan'::double as f, 1 as b
union all
select 3 id, 'nan'::double as f, null as b
union all
select 4 id, null as f, 'nan'::double as b
union all
select 5 id, 'nan'::double as f, 'nan'::double as b
union all
select 5 id, 1 as f, 1 as b
union all
select 6 id, 'nan'::double as f, 'nan'::double as b
)
select id, corr(f, b)
from data
group by id
order by id
----
1 NULL
2 NULL
3 NULL
4 NULL
5 NaN
6 NaN

# csv_query_variance_1
query R
SELECT var_pop(c2) FROM aggregate_test_100
Expand Down
Loading