Skip to content

Commit cdad91b

Browse files
authored
Merge pull request #3 from slgero/optimize_pandas_apply
Optimize pandas apply
2 parents e3191c5 + 97aba74 commit cdad91b

File tree

4 files changed

+159
-47
lines changed

4 files changed

+159
-47
lines changed

receipt_parser/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""A package which allows parsing Russian receipts."""
22

3-
__version__ = "0.0.23"
3+
__version__ = "0.0.24"
44
__license__ = "MIT"
55

66

receipt_parser/finder.py

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,42 @@
1515
# pylint: disable=C1801
1616

1717

18+
def df_apply(data: pd.DataFrame, func, axis: int = 1) -> pd.DataFrame:
    """
    Apply ``func`` to ``data``, passing the column values as positional arguments.

    Replaces the ``df.apply(lambda x: func(x["a"], x["b"]), axis=1)`` idiom:
    for every row, the value of each column of ``data`` is looked up by its
    label and handed to ``func`` in column order, i.e.
    ``func(row[col_0], row[col_1], ...)``.  Any number of columns is
    supported, so select exactly the columns ``func`` expects before calling.

    Parameters
    ----------
    data : pd.DataFrame
        The data on which the `func` function will be applied. Its columns,
        in order, become the positional arguments of ``func``.
    func : function
        Function called once per row with one argument per column.
    axis : {0 or 'index', 1 or 'columns'}, default=1
        Passed through unchanged to ``DataFrame.apply``.
        NOTE(review): the values are looked up by column label, so only the
        row-wise default (``axis=1``) appears meaningful -- confirm before
        using another value.

    Returns
    -------
    pd.DataFrame
        Whatever ``DataFrame.apply`` builds from the results of ``func``
        (per the pandas documentation this is a DataFrame when ``func``
        returns a Series, and a Series when it returns scalars).

    Examples
    --------
    >>> from pandas import DataFrame

    >>> DataFrame.my_apply = df_apply
    >>> df[['name', 'brand']].my_apply(foo)
    """

    columns = list(data.columns)

    # Unpacking over all columns covers the former 2- and 3-column special
    # cases and removes the IndexError / silently ignored columns that any
    # other column count used to produce.
    return data.apply(lambda row: func(*(row[col] for col in columns)), axis=axis)
52+
53+
1854
class Finder:
1955
"""
2056
Search and recognize the name, category and brand of a product
@@ -63,6 +99,7 @@ class Finder:
6399
def __init__(self, pathes: Optional[Dict[str, str]] = None):
64100
pathes = pathes or {}
65101
self.mystem = Mystem()
102+
pd.DataFrame.appl = df_apply
66103

67104
# Init model:
68105
model_params = {"num_class": 21, "embed_dim": 50, "vocab_size": 500}
@@ -308,60 +345,48 @@ def __find_all(self, verbose: int) -> None:
308345
self.__print_logs("Before:", verbose)
309346

310347
# Find brands:
311-
self.data[["name_norm", "brand_norm"]] = self.data.apply(
312-
lambda x: self.find_brands(x["name_norm"], x["brand_norm"]), axis=1
313-
)
348+
self.data[["name_norm", "brand_norm"]] = self.data[
349+
["name_norm", "brand_norm"]
350+
].appl(self.find_brands)
314351
self.__print_logs("Find brands:", verbose)
315352

316353
# Find product and category:
317-
self.data[["name_norm", "product_norm", "cat_norm"]] = self.data.apply(
318-
lambda x: self.find_product(x["name_norm"], x["product_norm"]), axis=1
319-
)
354+
self.data[["name_norm", "product_norm", "cat_norm"]] = self.data[
355+
["name_norm", "product_norm"]
356+
].appl(self.find_product)
320357
self.__print_logs("Find product and category:", verbose)
321358

322359
# Remove `-`:
323360
self.data["name_norm"] = self.data["name_norm"].str.replace("-", " ")
324-
self.data[["name_norm", "product_norm", "cat_norm"]] = self.data.apply(
325-
lambda x: self.find_product(
326-
x["name_norm"], x["product_norm"], x["cat_norm"]
327-
),
328-
axis=1,
329-
)
361+
self.data[["name_norm", "product_norm", "cat_norm"]] = self.data[
362+
["name_norm", "product_norm", "cat_norm"]
363+
].appl(self.find_product)
330364
self.__print_logs(
331365
"Remove `-` and the second attempt to find a product:", verbose
332366
)
333367

334368
# Use Mystem:
335-
self.data["name_norm"] = self.data.apply(
336-
lambda x: self._use_mystem(x["name_norm"], x["product_norm"]), axis=1
337-
)
338-
self.data[["name_norm", "product_norm", "cat_norm"]] = self.data.apply(
339-
lambda x: self.find_product(
340-
x["name_norm"], x["product_norm"], x["cat_norm"]
341-
),
342-
axis=1,
369+
self.data["name_norm"] = self.data[["name_norm", "product_norm"]].appl(
370+
self._use_mystem
343371
)
372+
self.data[["name_norm", "product_norm", "cat_norm"]] = self.data[
373+
["name_norm", "product_norm", "cat_norm"]
374+
].appl(self.find_product)
344375
self.__print_logs(
345376
"Use Mystem for lemmatization and the third attempt to find a product:",
346377
verbose,
347378
)
348379

349380
# Find category:
350-
self.data[["product_norm", "cat_norm"]] = self.data.apply(
351-
lambda x: self.find_category(
352-
x["name_norm"], x["product_norm"], x["cat_norm"]
353-
),
354-
axis=1,
355-
)
381+
self.data[["product_norm", "cat_norm"]] = self.data[
382+
["name_norm", "product_norm", "cat_norm"]
383+
].appl(self.find_category)
356384
self.__print_logs("Find the remaining categories:", verbose)
357385

358386
# Find product by brand:
359-
self.data[["product_norm", "brand_norm", "cat_norm"]] = self.data.apply(
360-
lambda x: self.find_product_by_brand(
361-
x["product_norm"], x["brand_norm"], x["cat_norm"]
362-
),
363-
axis=1,
364-
)
387+
self.data[["product_norm", "brand_norm", "cat_norm"]] = self.data[
388+
["name_norm", "product_norm", "cat_norm"]
389+
].appl(self.find_product)
365390
self.__print_logs("Find product by brand:", verbose)
366391

367392
def find_all(

receipt_parser/normalizer.py

Lines changed: 99 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import re
33
from typing import Optional, Union, Dict
44
import pandas as pd # type: ignore
5+
from pandarallel import pandarallel # type: ignore
6+
7+
pandarallel.initialize(progress_bar=False, verbose=0)
58

69
try:
710
# pylint: disable=line-too-long
@@ -10,6 +13,87 @@
1013
from dicts import PRODUCTS, BRANDS, SLASH_PRODUCTS, BRANDS_WITH_NUMBERS # type: ignore
1114

1215

16+
# pylint: disable=bad-continuation
17+
class Apply:
    """User-defined replacements for ``apply`` on pd.Series and pd.DataFrame."""

    # Minimum number of rows at which multiprocessing is switched on when the
    # caller does not pass ``use_parallel`` explicitly.
    PARALLEL_THRESHOLD = 10000

    @staticmethod
    def series_apply(data: pd.Series, func, use_parallel: Optional[bool] = None):
        """
        Apply ``func`` to every value of a Series, optionally in parallel.

        Parameters
        ----------
        data : pd.Series
            The data on which the `func` function will be applied.
        func : function
            Function called once per value of the Series.
        use_parallel : Optional[bool], default=None
            Whether to use ``parallel_apply`` (added to pandas objects by
            pandarallel) instead of ``apply``. When None, multiprocessing is
            used if the data has at least ``Apply.PARALLEL_THRESHOLD``
            (10000) rows.

        Returns
        -------
        pd.Series or pd.DataFrame
            Result of applying ``func`` on the Series, exactly as returned
            by ``Series.apply`` / ``Series.parallel_apply``.

        Examples
        --------
        >>> from pandas import Series

        >>> Series.my_apply = Apply.series_apply
        >>> df['name'].my_apply(foo)
        """

        if use_parallel is None:
            use_parallel = len(data) >= Apply.PARALLEL_THRESHOLD
        if use_parallel:
            return data.parallel_apply(func)
        return data.apply(func)

    @staticmethod
    def df_apply(
        data: pd.DataFrame, func, use_parallel: Optional[bool] = None, axis: int = 1
    ) -> pd.DataFrame:
        """
        Apply ``func`` to a DataFrame, passing column values as arguments.

        For every row, the value of each column of ``data`` is looked up by
        its label and handed to ``func`` in column order, i.e.
        ``func(row[col_0], row[col_1], ...)``.  Any number of columns is
        supported, so select exactly the columns ``func`` expects first.

        Parameters
        ----------
        data : pd.DataFrame
            The data on which the `func` function will be applied. Its
            columns, in order, become the positional arguments of ``func``.
        func : function
            Function called once per row with one argument per column.
        use_parallel : Optional[bool], default=None
            Whether to use ``parallel_apply`` (added to pandas objects by
            pandarallel) instead of ``apply``. When None, multiprocessing is
            used if the data has at least ``Apply.PARALLEL_THRESHOLD``
            (10000) rows.
        axis : {0 or 'index', 1 or 'columns'}, default=1
            Passed through unchanged to the underlying apply call.
            NOTE(review): values are looked up by column label, so only the
            row-wise default (``axis=1``) appears meaningful -- confirm
            before using another value.

        Returns
        -------
        pd.DataFrame
            Result of applying ``func`` along the given axis of the
            DataFrame, exactly as returned by the underlying apply call.

        Examples
        --------
        >>> from pandas import DataFrame

        >>> DataFrame.my_apply = Apply.df_apply
        >>> df[['name', 'brand']].my_apply(foo)
        """

        columns = list(data.columns)

        def call_with_columns(row):
            # Unpack every selected column so the helper is not tied to a
            # fixed (previously hard-coded two) number of columns.
            return func(*(row[col] for col in columns))

        if use_parallel is None:
            use_parallel = len(data) >= Apply.PARALLEL_THRESHOLD

        if use_parallel:
            return data.parallel_apply(call_with_columns, axis=axis)
        return data.apply(call_with_columns, axis=axis)
95+
96+
1397
class Normalizer:
1498
"""
1599
Normalize product description: expand abbreviations,
@@ -52,6 +136,10 @@ def __init__(self, pathes: Optional[Dict[str, str]] = None):
52136
pathes.get("brands_en", "data/cleaned/brands_en.csv")
53137
)["brand"].values
54138

139+
# Register the user-defined apply functions:
140+
pd.DataFrame.appl = Apply.df_apply
141+
pd.Series.appl = Apply.series_apply
142+
55143
@staticmethod
56144
def _remove_numbers(name: str) -> pd.Series:
57145
"""Remove all words in product description which contain numbers."""
@@ -162,20 +250,17 @@ def normalize(self, data: Union[pd.Series, str]) -> pd.DataFrame:
162250

163251
data = self.__transform_data(data)
164252
data["name_norm"] = data["name"].str.lower()
165-
data[["name_norm", "brand_norm"]] = data["name_norm"].apply(
166-
self._remove_numbers
167-
)
168-
data[["name_norm", "product_norm", "brand_norm"]] = data.apply(
169-
lambda x: self._remove_punctuation(x["name_norm"], x["brand_norm"]), axis=1
170-
)
171-
data["name_norm"] = data["name_norm"].apply(self._remove_one_and_two_chars)
172-
data[["name_norm", "brand_norm"]] = data.apply(
173-
lambda x: self.find_en_brands(x["name_norm"], x["brand_norm"]), axis=1
253+
data[["name_norm", "brand_norm"]] = data["name_norm"].appl(self._remove_numbers)
254+
data[["name_norm", "product_norm", "brand_norm"]] = data[
255+
["name_norm", "brand_norm"]
256+
].appl(self._remove_punctuation)
257+
data["name_norm"] = data["name_norm"].appl(self._remove_one_and_two_chars)
258+
data[["name_norm", "brand_norm"]] = data[["name_norm", "brand_norm"]].appl(
259+
self.find_en_brands
174260
)
175-
data["name_norm"] = data["name_norm"].apply(self._remove_words_in_blacklist)
176-
data["name_norm"] = data["name_norm"].apply(self._replace_with_product_dict)
177-
data[["name_norm", "brand_norm"]] = data.apply(
178-
lambda x: self._remove_all_english_words(x["name_norm"], x["brand_norm"]),
179-
axis=1,
261+
data["name_norm"] = data["name_norm"].appl(self._remove_words_in_blacklist)
262+
data["name_norm"] = data["name_norm"].appl(self._replace_with_product_dict)
263+
data[["name_norm", "brand_norm"]] = data[["name_norm", "brand_norm"]].appl(
264+
self._remove_all_english_words
180265
)
181266
return data

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
numpy >= 1.18.3
22
pandas >= 1.0.3
3+
pandarallel >= 1.4.8
34
pymystem3 >= 0.2.0
5+
setuptools
46
torch
57
torchvision
68
wget >= 3.2

0 commit comments

Comments
 (0)