From 5b6e20b137d111c464562416eb0bb1dd7fdd97e2 Mon Sep 17 00:00:00 2001 From: kaushik327 Date: Mon, 5 Feb 2024 16:08:29 -0600 Subject: [PATCH 1/3] [Rewriter] Bugfix: StrInCol handles non-int64 indices --- dias/rewriter.py | 2 +- tests/str_in_col-index.ipynb | 105 +++++++++++++++++++++++++++++++++++ tests/str_in_col-index.json | 24 ++++++++ tests/str_in_col.json | 2 +- 4 files changed, 131 insertions(+), 2 deletions(-) create mode 100644 tests/str_in_col-index.ipynb create mode 100644 tests/str_in_col-index.json diff --git a/dias/rewriter.py b/dias/rewriter.py index bc9e349..ad416ae 100644 --- a/dias/rewriter.py +++ b/dias/rewriter.py @@ -1043,7 +1043,7 @@ def rewrite_ast(cell_ast: ast.Module) -> Tuple[str, Dict]: # We can specialize this for an index that is an int. Try to convert the string to int # and if you fail, contains_expr = f"astype(str).str.contains('{the_str}').any()" - new_expr = f"({the_sub}.{contains_expr} or _REWR_index_contains({the_sub}.index, '{the_str}')) if type({df}) == pd.DataFrame else ({orig})" + new_expr = f"({the_sub}.{contains_expr} or _REWR_index_contains({the_sub}.index, '{the_str}')) if (type({df}) == pd.DataFrame and {the_sub}.index.dtype == np.int64) else ({orig})" str_in_col.cmp_encl.set_enclosed_obj(ast.parse(new_expr, mode='eval')) ### END OF LOOP ### diff --git a/tests/str_in_col-index.ipynb b/tests/str_in_col-index.ipynb new file mode 100644 index 0000000..fd3b180 --- /dev/null +++ b/tests/str_in_col-index.ipynb @@ -0,0 +1,105 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import dias.rewriter" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.read_csv('./datasets/dataranch__supermarket-sales-prediction-xgboost-fastai__SampleSuperstore.csv')\n", + "df_multi_index = df.set_index(['Postal Code', 'Sub-Category'])" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "Setting a MultiIndex dtype to anything other than object is not supported", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/var/folders/sh/36vdgc9s2vqgz2sjl830ryvm0000gn/T/ipykernel_86399/1440505495.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcol\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdf_multi_index\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolumns\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m if (df_multi_index[col].astype(str).str.contains('ou').any() or\n\u001b[0;32m---> 13\u001b[0;31m _REWR_index_contains(df_multi_index[col].index, 'ou') if type(\n\u001b[0m\u001b[1;32m 14\u001b[0m df_multi_index) == pd.DataFrame else 'ou' in df_multi_index[col].\n\u001b[1;32m 15\u001b[0m to_string()):\n", + "\u001b[0;32m/var/folders/sh/36vdgc9s2vqgz2sjl830ryvm0000gn/T/ipykernel_86399/1440505495.py\u001b[0m in \u001b[0;36m_REWR_index_contains\u001b[0;34m(index, s)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mastype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontains\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcol\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdf_multi_index\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolumns\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m if (df_multi_index[col].astype(str).str.contains('ou').any() or\n", + "\u001b[0;32m~/opt/anaconda3/lib/python3.9/site-packages/pandas/core/indexes/multi.py\u001b[0m in \u001b[0;36mastype\u001b[0;34m(self, dtype, copy)\u001b[0m\n\u001b[1;32m 3754\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3755\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_object_dtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3756\u001b[0;31m raise TypeError(\n\u001b[0m\u001b[1;32m 3757\u001b[0m \u001b[0;34m\"Setting a MultiIndex dtype to anything other than object \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3758\u001b[0m \u001b[0;34m\"is not supported\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mTypeError\u001b[0m: Setting a MultiIndex dtype to anything other than object is not supported" + ] + } + ], + "source": [ + "our = []\n", + "for col in df_multi_index.columns:\n", + " if 'ou' in df_multi_index[col].to_string():\n", + " our.append(col)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "# DIAS_DISABLE\n", + "defa = []\n", + "for col in df_multi_index.columns:\n", + " if 'ou' in df_multi_index[col].to_string():\n", + " defa.append(col)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "([], ['City', 'State', 'Region'])" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "assert our == defa" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/str_in_col-index.json b/tests/str_in_col-index.json new file mode 100644 index 0000000..8da083c --- /dev/null +++ b/tests/str_in_col-index.json @@ -0,0 +1,24 @@ +{ + "cells": [ + { + "raw": "\ndf = pd.read_csv('./datasets/dataranch__supermarket-sales-prediction-xgboost-fastai__SampleSuperstore.csv')\ndf_multi_index = df.set_index(['Postal Code', 'Sub-Category'])\n", + "modified": "df = pd.read_csv(\n './datasets/dataranch__supermarket-sales-prediction-xgboost-fastai__SampleSuperstore.csv'\n )\ndf_multi_index = df.set_index(['Postal Code', 'Sub-Category'])\n", + "patts-hit": {}, + "rewritten-exec-time": 22.590208 + }, + { + "raw": "\nour = []\nfor col in df_multi_index.columns:\n if 'ou' in df_multi_index[col].to_string():\n our.append(col)\n", + "modified": "our = []\ndef _REWR_index_contains(index, s):\n if index.dtype == np.int64:\n try:\n i = int(s)\n return len(index.loc[i]) > 0\n except:\n return False\n else:\n return index.astype(str).str.contains(s).any()\nfor col in df_multi_index.columns:\n if (df_multi_index[col].astype(str).str.contains('ou').any() or\n _REWR_index_contains(df_multi_index[col].index, 'ou') if type(\n df_multi_index) == pd.DataFrame and df_multi_index[col].index.dtype ==\n np.int64 else 'ou' in df_multi_index[col].to_string()):\n our.append(col)\n", + "patts-hit": { + "MultipleStrInCol": 1 + }, + "rewritten-exec-time": 932.628458 + }, + { + "raw": "\nassert our == defa\n", + "modified": "assert our == defa\n", + "patts-hit": {}, + "rewritten-exec-time": 0.211667 + } + ] +} \ No newline at end of file diff --git a/tests/str_in_col.json b/tests/str_in_col.json index fb59894..fdfa7cd 100644 --- a/tests/str_in_col.json +++ b/tests/str_in_col.json @@ -10,7 +10,7 @@ }, { "raw":"\nfor col in df.columns:\n if '%' in df[col].to_string() or ',' in df[col].to_string():\n pass\n", - "modified":"def _REWR_index_contains(index, s):\n if index.dtype == np.int64:\n try:\n i = int(s)\n return len(index.loc[i]) > 0\n except:\n return False\n else:\n return index.astype(str).str.contains(s).any()\nfor col in df.columns:\n if (df[col].astype(str).str.contains('%').any() or _REWR_index_contains\n (df[col].index, '%') if type(df) == pd.DataFrame else '%' in df[col\n ].to_string()) or (df[col].astype(str).str.contains(',').any() or\n _REWR_index_contains(df[col].index, ',') if type(df) == pd.\n DataFrame else ',' in df[col].to_string()):\n pass\n", + "modified":"def _REWR_index_contains(index, s):\n if index.dtype == np.int64:\n try:\n i = int(s)\n return len(index.loc[i]) > 0\n except:\n return False\n else:\n return index.astype(str).str.contains(s).any()\nfor col in df.columns:\n if (df[col].astype(str).str.contains('%').any() or _REWR_index_contains\n (df[col].index, '%') if type(df) == pd.DataFrame and df[col].index.\n dtype == np.int64 else '%' in df[col].to_string()) or (df[col].\n astype(str).str.contains(',').any() or _REWR_index_contains(df[col]\n .index, ',') if type(df) == pd.DataFrame and df[col].index.dtype ==\n np.int64 else ',' in df[col].to_string()):\n pass\n", "patts-hit":{ "MultipleStrInCol":1 }, From 0b3c8a124d3dbdd87a91badc3b887ba3fee65cde Mon Sep 17 00:00:00 2001 From: kaushik327 Date: Mon, 5 Feb 2024 23:46:08 -0600 Subject: [PATCH 2/3] [Rewriter] Bug fix: StrInCol didn't work for columns as strings --- dias/rewriter.py | 3 ++- tests/str_in_col.ipynb | 11 ++++++++++- tests/str_in_col.json | 8 ++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/dias/rewriter.py b/dias/rewriter.py index ad416ae..4964993 100644 --- a/dias/rewriter.py +++ b/dias/rewriter.py @@ -1031,10 +1031,11 @@ def rewrite_ast(cell_ast: ast.Module) -> Tuple[str, Dict]: col = None if isinstance(str_in_col.the_sub.slice, ast.Name): col = str_in_col.the_sub.slice.id + the_sub = f"{df}[{col}]" else: assert isinstance(str_in_col.the_sub.slice, ast.Constant) col = str_in_col.the_sub.slice.value - the_sub = f"{df}[{col}]" + the_sub = f"{df}['{col}']" orig = f"'{the_str}' in {the_sub}.to_string()" # You need to be careful when handling strings like that because you might diff --git a/tests/str_in_col.ipynb b/tests/str_in_col.ipynb index 3f1c01b..a524641 100644 --- a/tests/str_in_col.ipynb +++ b/tests/str_in_col.ipynb @@ -94,6 +94,15 @@ " if '%' in df[col].to_string() or ',' in df[col].to_string():\n", " pass" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "'41' in df['Profit'].to_string()" + ] } ], "metadata": { @@ -112,7 +121,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.9.13" } }, "nbformat": 4, diff --git a/tests/str_in_col.json b/tests/str_in_col.json index fdfa7cd..997eadf 100644 --- a/tests/str_in_col.json +++ b/tests/str_in_col.json @@ -15,6 +15,14 @@ "MultipleStrInCol":1 }, "rewritten-exec-time":68.81415 + }, + { + "raw": "\n'41' in df['Profit'].to_string()\n", + "modified": "def _REWR_index_contains(index, s):\n if index.dtype == np.int64:\n try:\n i = int(s)\n return len(index.loc[i]) > 0\n except:\n return False\n else:\n return index.astype(str).str.contains(s).any()\n(df['Profit'].astype(str).str.contains('41').any() or _REWR_index_contains(\n df['Profit'].index, '41') if type(df) == pd.DataFrame and df['Profit'].\n index.dtype == np.int64 else '41' in df['Profit'].to_string())\n", + "patts-hit": { + "MultipleStrInCol": 1 + }, + "rewritten-exec-time": 129.444458 } ] } \ No newline at end of file From 349060657edd0452ace4206781f9a50010484b70 Mon Sep 17 00:00:00 2001 From: kaushik327 Date: Thu, 15 Feb 2024 16:49:51 -0600 Subject: [PATCH 3/3] [Rewriter] making StrInCol support more kinds of quotation marks in strings --- dias/rewriter.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dias/rewriter.py b/dias/rewriter.py index 4964993..aedf1fe 100644 --- a/dias/rewriter.py +++ b/dias/rewriter.py @@ -1025,7 +1025,7 @@ def rewrite_ast(cell_ast: ast.Module) -> Tuple[str, Dict]: # Get string versions - the_str = str_in_col.the_str.value + the_str = astor.to_source(str_in_col.the_str) assert isinstance(str_in_col.the_sub.value, ast.Name) df = str_in_col.the_sub.value.id col = None @@ -1034,17 +1034,17 @@ def rewrite_ast(cell_ast: ast.Module) -> Tuple[str, Dict]: the_sub = f"{df}[{col}]" else: assert isinstance(str_in_col.the_sub.slice, ast.Constant) - col = str_in_col.the_sub.slice.value - the_sub = f"{df}['{col}']" - orig = f"'{the_str}' in {the_sub}.to_string()" + col = astor.to_source(str_in_col.the_sub.slice) + the_sub = f"{df}[{col}]" + orig = f"{the_str} in {the_sub}.to_string()" # You need to be careful when handling strings like that because you might # miss parentheses. # We can specialize this for an index that is an int. Try to convert the string to int # and if you fail, - contains_expr = f"astype(str).str.contains('{the_str}').any()" - new_expr = f"({the_sub}.{contains_expr} or _REWR_index_contains({the_sub}.index, '{the_str}')) if (type({df}) == pd.DataFrame and {the_sub}.index.dtype == np.int64) else ({orig})" + contains_expr = f"astype(str).str.contains({the_str}).any()" + new_expr = f"({the_sub}.{contains_expr} or _REWR_index_contains({the_sub}.index, {the_str})) if (type({df}) == pd.DataFrame and {the_sub}.index.dtype == np.int64) else ({orig})" str_in_col.cmp_encl.set_enclosed_obj(ast.parse(new_expr, mode='eval')) ### END OF LOOP ###