Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 70b5ae8

Browse files
Implements init_dataframe as multiple codegen functions (#936)
Motivation: init_dataframe was implemented via Numba intrinsic taking *args, which seems to generate redundant extractvalue/insertvalue LLVM instructions, producing quadratic IR when number of DF columns grows and affecting total compilation time of function that create large DFs. This PR replaces singe init_dataframe with multiple functions basing on number of columns in a DF which are generated at compile time, thus avoiding use of *args.
1 parent ce4142a commit 70b5ae8

7 files changed

Lines changed: 165 additions & 89 deletions

File tree

sdc/datatypes/hpat_pandas_dataframe_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1500,7 +1500,7 @@ def df_getitem_slice_idx_main_codelines(self, idx):
15001500
res_data = f'res_data_{i}'
15011501
func_lines += [
15021502
f' data_{i} = self._data[{type_id}][{col_id}][idx]',
1503-
f' {res_data} = pandas.Series(data_{i}, index=res_index, name="{col}")'
1503+
f' {res_data} = data_{i}'
15041504
]
15051505
results.append((col, res_data))
15061506

sdc/hiframes/api.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,25 +167,25 @@ def fix_df_array_list_str_impl(column): # pragma: no cover
167167
return lambda column: column
168168

169169

170-
def fix_df_index(index, *columns):
170+
def fix_df_index(index):
171171
return index
172172

173173

174174
@overload(fix_df_index)
175-
def fix_df_index_overload(index, *columns):
175+
def fix_df_index_overload(index):
176176

177177
# TO-DO: replace types.none index with separate type, e.g. DefaultIndex
178178
if (index is None or isinstance(index, types.NoneType)):
179-
def fix_df_index_impl(index, *columns):
179+
def fix_df_index_impl(index):
180180
return None
181181

182182
elif isinstance(index, RangeIndexType):
183-
def fix_df_index_impl(index, *columns):
183+
def fix_df_index_impl(index):
184184
return index
185185

186186
else:
187187
# default case, transform index the same as df data
188-
def fix_df_index_impl(index, *columns):
188+
def fix_df_index_impl(index):
189189
return fix_df_array(index)
190190

191191
return fix_df_index_impl

sdc/hiframes/pd_dataframe_ext.py

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -80,59 +80,6 @@ def get_structure_maps(col_types, col_names):
8080
return column_loc, data_typs_map, types_order
8181

8282

83-
@intrinsic
84-
def init_dataframe(typingctx, *args):
85-
"""Create a DataFrame with provided data, index and columns values.
86-
Used as a single constructor for DataFrame and assigning its data, so that
87-
optimization passes can look for init_dataframe() to see if underlying
88-
data has changed, and get the array variables from init_dataframe() args if
89-
not changed.
90-
"""
91-
92-
n_cols = len(args) // 2
93-
data_typs = tuple(args[:n_cols])
94-
index_typ = args[n_cols]
95-
column_names = tuple(a.literal_value for a in args[n_cols + 1:])
96-
97-
column_loc, data_typs_map, types_order = get_structure_maps(data_typs, column_names)
98-
99-
def codegen(context, builder, signature, args):
100-
in_tup = args[0]
101-
data_arrs = [builder.extract_value(in_tup, i) for i in range(n_cols)]
102-
index = builder.extract_value(in_tup, n_cols)
103-
104-
# create dataframe struct and store values
105-
dataframe = cgutils.create_struct_proxy(
106-
signature.return_type)(context, builder)
107-
108-
data_list_type = [types.List(typ) for typ in types_order]
109-
110-
data_lists = []
111-
for typ_id, typ in enumerate(types_order):
112-
data_list_typ = context.build_list(builder, data_list_type[typ_id],
113-
[data_arrs[data_id] for data_id in data_typs_map[typ][1]])
114-
data_lists.append(data_list_typ)
115-
116-
data_tup = context.make_tuple(
117-
builder, types.Tuple(data_list_type), data_lists)
118-
119-
dataframe.data = data_tup
120-
dataframe.index = index
121-
dataframe.parent = context.get_constant_null(types.pyobject)
122-
123-
# increase refcount of stored values
124-
if context.enable_nrt:
125-
context.nrt.incref(builder, index_typ, index)
126-
for var, typ in zip(data_arrs, data_typs):
127-
context.nrt.incref(builder, typ, var)
128-
129-
return dataframe._getvalue()
130-
131-
ret_typ = DataFrameType(data_typs, index_typ, column_names, column_loc=column_loc)
132-
sig = signature(ret_typ, types.Tuple(args))
133-
return sig, codegen
134-
135-
13683
# TODO: alias analysis
13784
# this function should be used for getting df._data for alias analysis to work
13885
# no_cpython_wrapper since Array(DatetimeDate) cannot be boxed

sdc/hiframes/pd_dataframe_type.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,5 +126,4 @@ class ColumnLoc(NamedTuple):
126126

127127
make_attribute_wrapper(DataFrameType, 'data', '_data')
128128
make_attribute_wrapper(DataFrameType, 'index', '_index')
129-
make_attribute_wrapper(DataFrameType, 'unboxed', '_unboxed')
130129
make_attribute_wrapper(DataFrameType, 'parent', '_parent')

sdc/hiframes/pd_series_ext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def pd_series_overload(data=None, index=None, dtype=None, name=None, copy=False,
138138
def hpat_pandas_series_ctor_impl(data=None, index=None, dtype=None, name=None, copy=False, fastpath=False):
139139

140140
fix_data = sdc.hiframes.api.fix_df_array(data)
141-
fix_index = sdc.hiframes.api.fix_df_index(index, fix_data)
141+
fix_index = sdc.hiframes.api.fix_df_index(index)
142142
return sdc.hiframes.api.init_series(fix_data, fix_index, name)
143143

144144
return hpat_pandas_series_ctor_impl

sdc/rewrites/dataframe_constructor.py

Lines changed: 132 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,30 @@
2424
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2525
# *****************************************************************************
2626

27-
27+
import numba
28+
from numba.core import cgutils, types
2829
from numba.core.rewrites import (register_rewrite, Rewrite)
2930
from numba.core.ir_utils import (guard, find_callname)
3031
from numba.core.ir import (Expr)
3132
from numba.extending import overload
33+
from numba.core.extending import intrinsic
34+
from numba.core.typing import signature
3235

3336
from pandas import DataFrame
37+
from sys import modules
38+
from textwrap import dedent
3439

3540
from sdc.rewrites.ir_utils import (find_operations, is_dict,
3641
get_tuple_items, get_dict_items, remove_unused_recursively,
3742
get_call_parameters,
3843
declare_constant,
3944
import_function, make_call,
4045
insert_before)
41-
from sdc.hiframes.pd_dataframe_ext import (init_dataframe, DataFrameType)
42-
46+
from sdc.hiframes import pd_dataframe_ext as pd_dataframe_ext_module
47+
from sdc.hiframes.pd_dataframe_type import DataFrameType, ColumnLoc
48+
from sdc.hiframes.pd_dataframe_ext import get_structure_maps
4349
from sdc.hiframes.api import fix_df_array, fix_df_index
50+
from sdc.str_ext import string_type
4451

4552

4653
@register_rewrite('before-inference')
@@ -54,6 +61,7 @@ class RewriteDataFrame(Rewrite):
5461
_df_arg_list = ('data', 'index', 'columns', 'dtype', 'copy')
5562

5663
def __init__(self, pipeline):
64+
self._pipeline = pipeline
5765
super().__init__(pipeline)
5866

5967
self._reset()
@@ -79,18 +87,45 @@ def match(self, func_ir, block, typemap, calltypes):
7987
return len(self._calls_to_rewrite) > 0
8088

8189
def apply(self):
82-
init_df_stmt = import_function(init_dataframe, self._block, self._func_ir)
83-
8490
for stmt in self._calls_to_rewrite:
8591
args = get_call_parameters(call=stmt.value, arg_names=self._df_arg_list)
86-
8792
old_data = args['data']
88-
8993
args['data'], args['columns'] = self._extract_dict_args(args, self._func_ir)
9094

95+
args_len = len(args['data'])
96+
func_name = f'init_dataframe_{args_len}'
97+
98+
# injected_module = modules[pd_dataframe_ext_module.__name__]
99+
init_df = getattr(pd_dataframe_ext_module, func_name, None)
100+
if init_df is None:
101+
init_df_text = gen_init_dataframe_text(func_name, args_len)
102+
init_df = gen_init_dataframe_func(
103+
func_name,
104+
init_df_text,
105+
{
106+
'numba': numba,
107+
'cgutils': cgutils,
108+
'signature': signature,
109+
'types': types,
110+
'get_structure_maps': get_structure_maps,
111+
'intrinsic': intrinsic,
112+
'DataFrameType': DataFrameType,
113+
'ColumnLoc': ColumnLoc,
114+
'string_type': string_type,
115+
'intrinsic': intrinsic,
116+
'fix_df_array': fix_df_array,
117+
'fix_df_index': fix_df_index
118+
})
119+
120+
setattr(pd_dataframe_ext_module, func_name, init_df)
121+
init_df.__module__ = pd_dataframe_ext_module.__name__
122+
init_df._defn.__module__ = pd_dataframe_ext_module.__name__
123+
124+
init_df_stmt = import_function(init_df, self._block, self._func_ir)
91125
self._replace_call(stmt, init_df_stmt.target, args, self._block, self._func_ir)
92126

93127
remove_unused_recursively(old_data, self._block, self._func_ir)
128+
self._pipeline.typingctx.refresh()
94129

95130
return self._block
96131

@@ -130,42 +165,112 @@ def _replace_call(stmt, new_call, args, block, func_ir):
130165
columns_args = args['columns']
131166
index_args = args.get('index')
132167

133-
data_args = RewriteDataFrame._replace_data_with_arrays(data_args, stmt, block, func_ir)
134-
135168
if index_args is None: # index arg was omitted
136169
none_stmt = declare_constant(None, block, func_ir, stmt.loc)
137170
index_args = none_stmt.target
138171

139-
index_and_data_args = [index_args] + data_args
140-
index_args = RewriteDataFrame._replace_index_with_arrays(index_and_data_args, stmt, block, func_ir)
172+
index_args = [index_args]
141173

142174
all_args = data_args + index_args + columns_args
143175
call = Expr.call(new_call, all_args, {}, func.loc)
144176

145177
stmt.value = call
146178

147-
@staticmethod
148-
def _replace_data_with_arrays(args, stmt, block, func_ir):
149-
new_args = []
150179

151-
for var in args:
152-
call_stmt = make_call(fix_df_array, [var], {}, block, func_ir, var.loc)
153-
insert_before(block, call_stmt, stmt)
154-
new_args.append(call_stmt.target)
180+
def gen_init_dataframe_text(func_name, n_cols):
181+
args_col_data = ['c' + str(i) for i in range(n_cols)]
182+
args_col_names = ['n' + str(i) for i in range(n_cols)]
183+
params = ', '.join(args_col_data + ['index'] + args_col_names)
184+
suffix = ('' if n_cols == 0 else ', ')
185+
186+
func_text = dedent(
187+
f'''
188+
@intrinsic
189+
def {func_name}(typingctx, {params}):
190+
"""Create a DataFrame with provided columns data and index values.
191+
Takes 2n+1 args: n columns data, index data and n column names.
192+
Each column data is passed as separate argument to have compact LLVM IR.
193+
Used as as generic constructor for native DataFrame objects, which
194+
can be used with different input column types (e.g. lists), and
195+
resulting DataFrameType is deduced by applying transform functions
196+
(fix_df_array and fix_df_index) to input argument types.
197+
"""
198+
199+
n_cols = {n_cols}
200+
201+
input_data_typs = ({', '.join(args_col_data) + suffix})
202+
fnty = typingctx.resolve_value_type(fix_df_array)
203+
fixed_col_sigs = []
204+
for i in range({n_cols}):
205+
to_sig = fnty.get_call_type(typingctx, (input_data_typs[i],), {{}})
206+
fixed_col_sigs.append(to_sig)
207+
data_typs = tuple(fixed_col_sigs[i].return_type for i in range({n_cols}))
208+
need_fix_cols = tuple(data_typs[i] != input_data_typs[i] for i in range({n_cols}))
209+
210+
input_index_typ = index
211+
fnty = typingctx.resolve_value_type(fix_df_index)
212+
fixed_index_sig = fnty.get_call_type(typingctx, (input_index_typ,), {{}})
213+
index_typ = fixed_index_sig.return_type
214+
need_fix_index = index_typ != input_index_typ
215+
216+
column_names = tuple(a.literal_value for a in ({', '.join(args_col_names) + suffix}))
217+
column_loc, data_typs_map, types_order = get_structure_maps(data_typs, column_names)
218+
col_needs_transform = tuple(not isinstance(data_typs[i], types.Array) for i in range(len(data_typs)))
219+
220+
def codegen(context, builder, sig, args):
221+
{params}, = args
222+
data_arrs = [{', '.join(args_col_data) + suffix}]
223+
data_arrs_transformed = []
224+
for i, arr in enumerate(data_arrs):
225+
if need_fix_cols[i] == False:
226+
data_arrs_transformed.append(arr)
227+
else:
228+
res = context.compile_internal(builder, lambda a: fix_df_array(a), fixed_col_sigs[i], [arr])
229+
data_arrs_transformed.append(res)
155230
156-
return new_args
231+
# create dataframe struct and store values
232+
dataframe = cgutils.create_struct_proxy(
233+
sig.return_type)(context, builder)
157234
158-
@staticmethod
159-
def _replace_index_with_arrays(args, stmt, block, func_ir):
160-
new_args = []
235+
data_list_type = [types.List(typ) for typ in types_order]
236+
237+
data_lists = []
238+
for typ_id, typ in enumerate(types_order):
239+
data_arrs_of_typ = [data_arrs_transformed[data_id] for data_id in data_typs_map[typ][1]]
240+
data_list_typ = context.build_list(builder, data_list_type[typ_id], data_arrs_of_typ)
241+
data_lists.append(data_list_typ)
242+
243+
data_tup = context.make_tuple(
244+
builder, types.Tuple(data_list_type), data_lists)
245+
246+
if need_fix_index == True:
247+
index = context.compile_internal(builder, lambda a: fix_df_index(a), fixed_index_sig, [index])
248+
249+
dataframe.data = data_tup
250+
dataframe.index = index
251+
dataframe.parent = context.get_constant_null(types.pyobject)
252+
253+
# increase refcount of stored values
254+
if context.enable_nrt:
255+
context.nrt.incref(builder, index_typ, index)
256+
for var, typ in zip(data_arrs_transformed, data_typs):
257+
context.nrt.incref(builder, typ, var)
258+
259+
return dataframe._getvalue()
260+
261+
ret_typ = DataFrameType(data_typs, index_typ, column_names, column_loc=column_loc)
262+
sig = signature(ret_typ, {params})
263+
return sig, codegen
264+
''')
265+
266+
return func_text
161267

162-
call_stmt = make_call(fix_df_index, args, {}, block, func_ir, args[0].loc)
163-
insert_before(block, call_stmt, stmt)
164-
new_args.append(call_stmt.target)
165268

166-
return new_args
269+
def gen_init_dataframe_func(func_name, func_text, global_vars):
167270

168-
return new_args
271+
loc_vars = {}
272+
exec(func_text, global_vars, loc_vars)
273+
return loc_vars[func_name]
169274

170275

171276
@overload(DataFrame)

sdc/tests/test_dataframe.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,31 @@ def test_impl(n):
9898
n = 11
9999
self.assertEqual(hpat_func(n), test_impl(n))
100100

101+
def test_create_empty_df(self):
102+
""" Verifies empty DF can be created """
103+
def test_impl():
104+
df = pd.DataFrame({})
105+
return len(df)
106+
hpat_func = self.jit(test_impl)
107+
108+
self.assertEqual(hpat_func(), test_impl())
109+
110+
def test_create_multiple_dfs(self):
111+
""" Verifies generated dataframe ctor is added to pd_dataframe_ext module
112+
correctly (and numba global context is refreshed), so that subsequent
113+
compilations are not broken. """
114+
def test_impl(a, b, c):
115+
df1 = pd.DataFrame({'A': a, 'B': b})
116+
df2 = pd.DataFrame({'C': c})
117+
total_cols = len(df1.columns) + len(df2.columns)
118+
return total_cols
119+
hpat_func = self.jit(test_impl)
120+
121+
a1 = np.array([1, 2, 3, 4.0, 5])
122+
a2 = [7, 6, 5, 4, 3]
123+
a3 = ['a', 'b', 'c', 'd', 'e']
124+
self.assertEqual(hpat_func(a1, a2, a3), test_impl(a1, a2, a3))
125+
101126
def test_create_str(self):
102127
def test_impl():
103128
df = pd.DataFrame({'A': ['a', 'b', 'c']})
@@ -159,7 +184,7 @@ def test_impl(A, B, index):
159184
result_ref = test_impl(A, B, index)
160185
pd.testing.assert_frame_equal(result, result_ref)
161186

162-
def test_create_empty_df(self):
187+
def test_unbox_empty_df(self):
163188
def test_impl(df):
164189
return df
165190
sdc_func = self.jit(test_impl)

0 commit comments

Comments
 (0)