Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 45 additions & 53 deletions py-polars/src/polars/dataframe/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

if TYPE_CHECKING:
import sys
from collections.abc import Callable, Iterable
from collections.abc import Callable, Iterable, Iterator
from datetime import timedelta

from polars import DataFrame
Expand All @@ -24,17 +24,42 @@
)
from polars.lazyframe.group_by import LazyGroupBy

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

if sys.version_info >= (3, 13):
from warnings import deprecated
else:
from typing_extensions import deprecated # noqa: TC004


class _GroupByIter:
"""Iterator object for GroupBy iteration."""

__slots__ = ("_current_index", "_df", "_group_indices", "_group_names")

def __init__(
self,
df: DataFrame,
group_names: Iterator[tuple[Any, ...]],
group_indices: Any,
) -> None:
self._df = df
self._group_names = group_names
self._group_indices = group_indices
self._current_index = 0

def __iter__(self) -> _GroupByIter:
return self

def __next__(self) -> tuple[tuple[Any, ...], DataFrame]:
if self._current_index >= len(self._group_indices):
raise StopIteration

group_name = next(self._group_names)
group_data = self._df[self._group_indices[self._current_index], :]
self._current_index += 1

return group_name, group_data


class GroupBy:
"""Starts a new GroupBy operation."""

Expand Down Expand Up @@ -81,7 +106,7 @@ def _lgb(self) -> LazyGroupBy:
return group_by.having(self.predicates)
return group_by

def __iter__(self) -> Self:
def __iter__(self) -> _GroupByIter:
"""
Allows iteration over the groups of the group by operation.

Expand Down Expand Up @@ -117,31 +142,20 @@ def __iter__(self) -> Self:
# Every group gather can trigger a rechunk, so do early.
from polars.lazyframe.opt_flags import QueryOptFlags

self.df = self.df.rechunk()
df = self.df.rechunk()
temp_col = "__POLARS_GB_GROUP_INDICES"
groups_df = (
self.df.lazy()
df.lazy()
.with_row_index("__POLARS_GB_ROW_INDEX")
.group_by(*self.by, **self.named_by, maintain_order=self.maintain_order)
.agg(F.first().alias(temp_col))
.collect(optimizations=QueryOptFlags.none())
)

self._group_names = groups_df.select(F.all().exclude(temp_col)).iter_rows()
self._group_indices = groups_df.select(temp_col).to_series()
self._current_index = 0

return self

def __next__(self) -> tuple[tuple[Any, ...], DataFrame]:
if self._current_index >= len(self._group_indices):
raise StopIteration
group_names = groups_df.select(F.all().exclude(temp_col)).iter_rows()
group_indices = groups_df.select(temp_col).to_series()

group_name = next(self._group_names)
group_data = self.df[self._group_indices[self._current_index], :]
self._current_index += 1

return group_name, group_data
return _GroupByIter(df, group_names, group_indices)

def having(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> GroupBy:
"""
Expand Down Expand Up @@ -884,7 +898,7 @@ def __init__(
self.group_by = group_by
self.predicates = predicates

def __iter__(self) -> Self:
def __iter__(self) -> _GroupByIter:
from polars.lazyframe.opt_flags import QueryOptFlags

temp_col = "__POLARS_GB_GROUP_INDICES"
Expand All @@ -902,21 +916,10 @@ def __iter__(self) -> Self:
.collect(optimizations=QueryOptFlags.none())
)

self._group_names = groups_df.select(F.all().exclude(temp_col)).iter_rows()
self._group_indices = groups_df.select(temp_col).to_series()
self._current_index = 0
group_names = groups_df.select(F.all().exclude(temp_col)).iter_rows()
group_indices = groups_df.select(temp_col).to_series()

return self

def __next__(self) -> tuple[tuple[object, ...], DataFrame]:
if self._current_index >= len(self._group_indices):
raise StopIteration

group_name = next(self._group_names)
group_data = self.df[self._group_indices[self._current_index], :]
self._current_index += 1

return group_name, group_data
return _GroupByIter(self.df, group_names, group_indices)

def having(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> RollingGroupBy:
"""
Expand Down Expand Up @@ -1066,7 +1069,7 @@ def __init__(
self.start_by = start_by
self.predicates = predicates

def __iter__(self) -> Self:
def __iter__(self) -> _GroupByIter:
from polars.lazyframe.opt_flags import QueryOptFlags

temp_col = "__POLARS_GB_GROUP_INDICES"
Expand All @@ -1088,21 +1091,10 @@ def __iter__(self) -> Self:
.collect(optimizations=QueryOptFlags.none())
)

self._group_names = groups_df.select(F.all().exclude(temp_col)).iter_rows()
self._group_indices = groups_df.select(temp_col).to_series()
self._current_index = 0
group_names = groups_df.select(F.all().exclude(temp_col)).iter_rows()
group_indices = groups_df.select(temp_col).to_series()

return self

def __next__(self) -> tuple[tuple[object, ...], DataFrame]:
if self._current_index >= len(self._group_indices):
raise StopIteration

group_name = next(self._group_names)
group_data = self.df[self._group_indices[self._current_index], :]
self._current_index += 1

return group_name, group_data
return _GroupByIter(self.df, group_names, group_indices)

def having(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> DynamicGroupBy:
"""
Expand Down
7 changes: 7 additions & 0 deletions py-polars/tests/unit/operations/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -3080,3 +3080,10 @@ def test_group_by_max_by_min_by_string_single_element_27171() -> None:
result = df.group_by("key", maintain_order=True).agg(pl.col("val").min_by("by"))
assert result.filter(pl.col("key") == "a")["val"][0] == 10
assert result.filter(pl.col("key") == "b")["val"][0] == 30


def test_group_by_next_raises_type_error_12868() -> None:
    """Calling `next()` directly on a group-by object must raise `TypeError`."""
    frame = pl.DataFrame({"a": [1, 1, 2], "b": [3, 4, 5]})
    grouped = frame.group_by("a")

    # The error message is expected to mention the `GroupBy` type.
    with pytest.raises(TypeError, match="GroupBy"):
        next(grouped)  # type: ignore[call-overload]
Loading