2424# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2525# *****************************************************************************
2626
27-
27+ import numba
28+ from numba .core import cgutils , types
2829from numba .core .rewrites import (register_rewrite , Rewrite )
2930from numba .core .ir_utils import (guard , find_callname )
3031from numba .core .ir import (Expr )
3132from numba .extending import overload
33+ from numba .core .extending import intrinsic
34+ from numba .core .typing import signature
3235
3336from pandas import DataFrame
37+ from sys import modules
38+ from textwrap import dedent
3439
3540from sdc .rewrites .ir_utils import (find_operations , is_dict ,
3641 get_tuple_items , get_dict_items , remove_unused_recursively ,
3742 get_call_parameters ,
3843 declare_constant ,
3944 import_function , make_call ,
4045 insert_before )
41- from sdc .hiframes .pd_dataframe_ext import (init_dataframe , DataFrameType )
42-
46+ from sdc .hiframes import pd_dataframe_ext as pd_dataframe_ext_module
47+ from sdc .hiframes .pd_dataframe_type import DataFrameType , ColumnLoc
48+ from sdc .hiframes .pd_dataframe_ext import get_structure_maps
4349from sdc .hiframes .api import fix_df_array , fix_df_index
50+ from sdc .str_ext import string_type
4451
4552
4653@register_rewrite ('before-inference' )
@@ -54,6 +61,7 @@ class RewriteDataFrame(Rewrite):
5461 _df_arg_list = ('data' , 'index' , 'columns' , 'dtype' , 'copy' )
5562
5663 def __init__ (self , pipeline ):
64+ self ._pipeline = pipeline
5765 super ().__init__ (pipeline )
5866
5967 self ._reset ()
@@ -79,18 +87,45 @@ def match(self, func_ir, block, typemap, calltypes):
7987 return len (self ._calls_to_rewrite ) > 0
8088
8189 def apply (self ):
82- init_df_stmt = import_function (init_dataframe , self ._block , self ._func_ir )
83-
8490 for stmt in self ._calls_to_rewrite :
8591 args = get_call_parameters (call = stmt .value , arg_names = self ._df_arg_list )
86-
8792 old_data = args ['data' ]
88-
8993 args ['data' ], args ['columns' ] = self ._extract_dict_args (args , self ._func_ir )
9094
95+ args_len = len (args ['data' ])
96+ func_name = f'init_dataframe_{ args_len } '
97+
98+ # injected_module = modules[pd_dataframe_ext_module.__name__]
99+ init_df = getattr (pd_dataframe_ext_module , func_name , None )
100+ if init_df is None :
101+ init_df_text = gen_init_dataframe_text (func_name , args_len )
102+ init_df = gen_init_dataframe_func (
103+ func_name ,
104+ init_df_text ,
105+ {
106+ 'numba' : numba ,
107+ 'cgutils' : cgutils ,
108+ 'signature' : signature ,
109+ 'types' : types ,
110+ 'get_structure_maps' : get_structure_maps ,
111+ 'intrinsic' : intrinsic ,
112+ 'DataFrameType' : DataFrameType ,
113+ 'ColumnLoc' : ColumnLoc ,
114+ 'string_type' : string_type ,
115+ 'intrinsic' : intrinsic ,
116+ 'fix_df_array' : fix_df_array ,
117+ 'fix_df_index' : fix_df_index
118+ })
119+
120+ setattr (pd_dataframe_ext_module , func_name , init_df )
121+ init_df .__module__ = pd_dataframe_ext_module .__name__
122+ init_df ._defn .__module__ = pd_dataframe_ext_module .__name__
123+
124+ init_df_stmt = import_function (init_df , self ._block , self ._func_ir )
91125 self ._replace_call (stmt , init_df_stmt .target , args , self ._block , self ._func_ir )
92126
93127 remove_unused_recursively (old_data , self ._block , self ._func_ir )
128+ self ._pipeline .typingctx .refresh ()
94129
95130 return self ._block
96131
@@ -130,42 +165,112 @@ def _replace_call(stmt, new_call, args, block, func_ir):
130165 columns_args = args ['columns' ]
131166 index_args = args .get ('index' )
132167
133- data_args = RewriteDataFrame ._replace_data_with_arrays (data_args , stmt , block , func_ir )
134-
135168 if index_args is None : # index arg was omitted
136169 none_stmt = declare_constant (None , block , func_ir , stmt .loc )
137170 index_args = none_stmt .target
138171
139- index_and_data_args = [index_args ] + data_args
140- index_args = RewriteDataFrame ._replace_index_with_arrays (index_and_data_args , stmt , block , func_ir )
172+ index_args = [index_args ]
141173
142174 all_args = data_args + index_args + columns_args
143175 call = Expr .call (new_call , all_args , {}, func .loc )
144176
145177 stmt .value = call
146178
147- @staticmethod
148- def _replace_data_with_arrays (args , stmt , block , func_ir ):
149- new_args = []
150179
151- for var in args :
152- call_stmt = make_call (fix_df_array , [var ], {}, block , func_ir , var .loc )
153- insert_before (block , call_stmt , stmt )
154- new_args .append (call_stmt .target )
def gen_init_dataframe_text(func_name, n_cols):
    """Generate the source text of a DataFrame constructor intrinsic.

    The generated function is decorated with ``@intrinsic`` and takes
    ``2 * n_cols + 1`` arguments after ``typingctx``: ``n_cols`` column data
    arguments (``c0``, ``c1``, ...), ``index`` and ``n_cols`` column name
    arguments (``n0``, ``n1``, ...).

    The text refers to names it does not define (``intrinsic``,
    ``fix_df_array``, ``fix_df_index``, ``get_structure_maps``, ``types``,
    ``cgutils``, ``DataFrameType``, ``signature``); they must be supplied as
    globals when the text is executed.

    Args:
        func_name: name given to the generated function.
        n_cols: number of DataFrame columns the function accepts.

    Returns:
        The dedented source text of the function.
    """
    col_data_params = [f'c{i}' for i in range(n_cols)]
    col_name_params = [f'n{i}' for i in range(n_cols)]
    params = ', '.join(col_data_params + ['index'] + col_name_params)

    # A trailing comma keeps one-element tuple literals valid in the generated
    # code; with zero columns the literals must stay empty instead.
    suffix = '' if n_cols == 0 else ', '
    data_items = ', '.join(col_data_params) + suffix
    name_items = ', '.join(col_name_params) + suffix

    func_text = dedent(
        f'''
        @intrinsic
        def {func_name}(typingctx, {params}):
            """Create a DataFrame with provided columns data and index values.
            Takes 2n+1 args: n columns data, index data and n column names.
            Each column data is passed as separate argument to have compact LLVM IR.
            Used as a generic constructor for native DataFrame objects, which
            can be used with different input column types (e.g. lists), and
            resulting DataFrameType is deduced by applying transform functions
            (fix_df_array and fix_df_index) to input argument types.
            """

            n_cols = {n_cols}

            input_data_typs = ({data_items})
            fnty = typingctx.resolve_value_type(fix_df_array)
            fixed_col_sigs = []
            for i in range(n_cols):
                to_sig = fnty.get_call_type(typingctx, (input_data_typs[i],), {{}})
                fixed_col_sigs.append(to_sig)
            data_typs = tuple(fixed_col_sigs[i].return_type for i in range(n_cols))
            need_fix_cols = tuple(data_typs[i] != input_data_typs[i] for i in range(n_cols))

            input_index_typ = index
            fnty = typingctx.resolve_value_type(fix_df_index)
            fixed_index_sig = fnty.get_call_type(typingctx, (input_index_typ,), {{}})
            index_typ = fixed_index_sig.return_type
            need_fix_index = index_typ != input_index_typ

            column_names = tuple(a.literal_value for a in ({name_items}))
            column_loc, data_typs_map, types_order = get_structure_maps(data_typs, column_names)

            def codegen(context, builder, sig, args):
                {params}, = args
                data_arrs = [{data_items}]
                data_arrs_transformed = []
                for i, arr in enumerate(data_arrs):
                    if not need_fix_cols[i]:
                        data_arrs_transformed.append(arr)
                    else:
                        res = context.compile_internal(builder, lambda a: fix_df_array(a), fixed_col_sigs[i], [arr])
                        data_arrs_transformed.append(res)

                # create dataframe struct and store values
                dataframe = cgutils.create_struct_proxy(
                    sig.return_type)(context, builder)

                data_list_type = [types.List(typ) for typ in types_order]

                data_lists = []
                for typ_id, typ in enumerate(types_order):
                    data_arrs_of_typ = [data_arrs_transformed[data_id] for data_id in data_typs_map[typ][1]]
                    data_list_typ = context.build_list(builder, data_list_type[typ_id], data_arrs_of_typ)
                    data_lists.append(data_list_typ)

                data_tup = context.make_tuple(
                    builder, types.Tuple(data_list_type), data_lists)

                if need_fix_index:
                    index = context.compile_internal(builder, lambda a: fix_df_index(a), fixed_index_sig, [index])

                dataframe.data = data_tup
                dataframe.index = index
                dataframe.parent = context.get_constant_null(types.pyobject)

                # increase refcount of stored values
                if context.enable_nrt:
                    context.nrt.incref(builder, index_typ, index)
                    for var, typ in zip(data_arrs_transformed, data_typs):
                        context.nrt.incref(builder, typ, var)

                return dataframe._getvalue()

            ret_typ = DataFrameType(data_typs, index_typ, column_names, column_loc=column_loc)
            sig = signature(ret_typ, {params})
            return sig, codegen
        ''')

    return func_text
161267
162- call_stmt = make_call (fix_df_index , args , {}, block , func_ir , args [0 ].loc )
163- insert_before (block , call_stmt , stmt )
164- new_args .append (call_stmt .target )
165268
166- return new_args
def gen_init_dataframe_func(func_name, func_text, global_vars):
    """Execute generated source text and return the function it defines.

    Args:
        func_name: name of the function defined by ``func_text``.
        func_text: Python source to execute.
        global_vars: dict used as the globals of the executed source.

    Returns:
        The object bound to ``func_name`` after ``func_text`` was executed.

    Raises:
        KeyError: if ``func_text`` does not define ``func_name``.
    """
    # Top-level definitions of the executed source are collected in a separate
    # namespace, so only names looked up by the source come from global_vars.
    namespace = {}
    exec(func_text, global_vars, namespace)
    return namespace[func_name]
169274
170275
171276@overload (DataFrame )
0 commit comments