Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 811e9f0

Browse files
Changing csv_reader_py impl to return df from objmode (#918)
* Changing csv_reader_py impl to return df from objmode Motivation: returning a Tuple of columns read from a csv file with the pyarrow csv reader from objmode and then calling the init_dataframe ctor to create a native DF turned out to be inefficient in terms of LLVM IR size and compilation time. With this PR we now rely on DF unboxing and return a py DF from objmode. * Capture dtype dict instead of building in objmode * Applying comments #1
1 parent db4f431 commit 811e9f0

6 files changed

Lines changed: 139 additions & 121 deletions

File tree

sdc/datatypes/hpat_pandas_functions.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,13 @@
3939

4040
from sdc.io.csv_ext import (
4141
_gen_csv_reader_py_pyarrow_py_func,
42-
_gen_csv_reader_py_pyarrow_func_text_dataframe,
42+
_gen_pandas_read_csv_func_text,
4343
)
4444
from sdc.str_arr_ext import string_array_type
4545

4646
from sdc.hiframes import join, aggregate, sort
4747
from sdc.types import CategoricalDtypeType, Categorical
48+
from sdc.datatypes.categorical.pdimpl import _reconstruct_CategoricalDtype
4849

4950

5051
def get_numba_array_types_for_csv(df):
@@ -255,45 +256,69 @@ def sdc_pandas_read_csv(
255256
usecols = [col.literal_value for col in usecols]
256257

257258
if infer_from_params:
258-
# dtype should be constants and is important only for inference from params
259+
# dtype is a tuple of format ('A', A_dtype, 'B', B_dtype, ...)
260+
# where column names should be constants and is important only for inference from params
259261
if isinstance(dtype, types.Tuple):
260-
assert all(isinstance(key, types.Literal) for key in dtype[::2])
262+
assert all(isinstance(key, types.StringLiteral) for key in dtype[::2])
261263
keys = (k.literal_value for k in dtype[::2])
262-
263264
values = dtype[1::2]
264-
values = [v.typing_key if isinstance(v, types.Function) else v for v in values]
265-
values = [types.Array(numba.from_dtype(np.dtype(v.literal_value)), 1, 'C')
266-
if isinstance(v, types.Literal) else v for v in values]
267-
values = [types.Array(types.int_, 1, 'C') if v == int else v for v in values]
268-
values = [types.Array(types.float64, 1, 'C') if v == float else v for v in values]
269-
values = [string_array_type if v == str else v for v in values]
270-
values = [Categorical(v) if isinstance(v, CategoricalDtypeType) else v for v in values]
271265

272-
dtype = dict(zip(keys, values))
266+
def _get_df_col_type(dtype):
267+
if isinstance(dtype, types.Function):
268+
if dtype.typing_key == int:
269+
return types.Array(types.int_, 1, 'C')
270+
elif dtype.typing_key == float:
271+
return types.Array(types.float64, 1, 'C')
272+
elif dtype.typing_key == str:
273+
return string_array_type
274+
else:
275+
assert False, f"map_dtype_to_col_type: failing to infer column type for dtype={dtype}"
276+
277+
if isinstance(dtype, types.StringLiteral):
278+
if dtype.literal_value == 'str':
279+
return string_array_type
280+
else:
281+
return types.Array(numba.from_dtype(np.dtype(dtype.literal_value)), 1, 'C')
282+
283+
if isinstance(dtype, types.NumberClass):
284+
return types.Array(dtype.dtype, 1, 'C')
285+
286+
if isinstance(dtype, CategoricalDtypeType):
287+
return Categorical(dtype)
288+
289+
col_types_map = dict(zip(keys, map(_get_df_col_type, values)))
273290

274291
# in case of both are available
275292
# inferencing from params has priority over inferencing from file
276293
if infer_from_params:
277-
col_names = names
278294
# all names should be in dtype
279-
return_columns = usecols if usecols else names
280-
col_typs = [dtype[n] for n in return_columns]
295+
col_names = usecols if usecols else names
296+
col_types = [col_types_map[n] for n in col_names]
281297

282298
elif infer_from_file:
283-
col_names, col_typs = infer_column_names_and_types_from_constant_filename(
299+
col_names, col_types = infer_column_names_and_types_from_constant_filename(
284300
filepath_or_buffer, delimiter, names, usecols, skiprows)
285301

286302
else:
287303
return None
288304

289-
dtype_present = not isinstance(dtype, (types.Omitted, type(None)))
305+
def _get_py_col_dtype(ctype):
306+
""" Re-creates column dtype as python type to be used in read_csv call """
307+
dtype = ctype.dtype
308+
if ctype == string_array_type:
309+
return str
310+
if isinstance(ctype, Categorical):
311+
return _reconstruct_CategoricalDtype(ctype.pd_dtype)
312+
return numpy_support.as_dtype(dtype)
313+
314+
py_col_dtypes = {cname: _get_py_col_dtype(ctype) for cname, ctype in zip(col_names, col_types)}
290315

291316
# generate function text with signature and returning DataFrame
292-
func_text, func_name = _gen_csv_reader_py_pyarrow_func_text_dataframe(
293-
col_names, col_typs, dtype_present, usecols, signature)
317+
func_text, func_name, global_vars = _gen_pandas_read_csv_func_text(
318+
col_names, col_types, py_col_dtypes, usecols, signature)
294319

295320
# compile with Python
296-
csv_reader_py = _gen_csv_reader_py_pyarrow_py_func(func_text, func_name)
321+
csv_reader_py = _gen_csv_reader_py_pyarrow_py_func(func_text, func_name, global_vars)
297322

298323
return csv_reader_py
299324

sdc/hiframes/pd_dataframe_ext.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727

2828
import operator
29-
from typing import NamedTuple
3029

3130
import numba
3231
from numba import types
@@ -39,7 +38,7 @@
3938
from numba.core.imputils import impl_ret_new_ref, impl_ret_borrowed
4039

4140
from sdc.hiframes.pd_series_ext import SeriesType
42-
from sdc.hiframes.pd_dataframe_type import DataFrameType
41+
from sdc.hiframes.pd_dataframe_type import DataFrameType, ColumnLoc
4342
from sdc.str_ext import string_type
4443

4544

@@ -54,10 +53,6 @@ def generic_resolve(self, df, attr):
5453
return SeriesType(arr_typ.dtype, arr_typ, df.index, True)
5554

5655

57-
class ColumnLoc(NamedTuple):
58-
type_id: int
59-
col_id: int
60-
6156

6257
def get_structure_maps(col_types, col_names):
6358
# Define map column name to column location ex. {'A': (0,0), 'B': (1,0), 'C': (0,1)}

sdc/hiframes/pd_dataframe_type.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2525
# *****************************************************************************
2626

27+
import re
28+
from typing import NamedTuple
2729

2830
import numba
2931
from numba import types
@@ -48,7 +50,7 @@ def __init__(self, data=None, index=None, columns=None, has_parent=False, column
4850
self.has_parent = has_parent
4951
self.column_loc = column_loc
5052
super(DataFrameType, self).__init__(
51-
name="dataframe({}, {}, {}, {})".format(data, index, columns, has_parent))
53+
name="DataFrameType({}, {}, {}, {})".format(data, index, columns, has_parent))
5254

5355
def copy(self, index=None, has_parent=None):
5456
# XXX is copy necessary?
@@ -83,6 +85,16 @@ def unify(self, typingctx, other):
8385
def is_precise(self):
8486
return all(a.is_precise() for a in self.data) and self.index.is_precise()
8587

88+
def __repr__(self):
89+
90+
# To have correct repr of DataFrame we need some changes to what types.Type gives:
91+
# (1) e.g. array(int64, 1d, C) should be Array(int64, 1, 'C')
92+
# (2) ColumnLoc is not part of DataFrame name, so we need to add it
93+
default_repr = super(DataFrameType, self).__repr__()
94+
res = re.sub(r'array\((\w+), 1d, C\)', r'Array(\1, 1, \'C\')', default_repr)
95+
res = re.sub(r'\)$', f', column_loc={self.column_loc})', res)
96+
return res
97+
8698

8799
@register_model(DataFrameType)
88100
class DataFrameModel(models.StructModel):
@@ -104,6 +116,15 @@ def __init__(self, dmm, fe_type):
104116
super(DataFrameModel, self).__init__(dmm, fe_type, members)
105117

106118

119+
class ColumnLoc(NamedTuple):
120+
type_id: int
121+
col_id: int
122+
123+
124+
# FIXME_Numba#3372: add into numba.types to allow returning from objmode
125+
types.DataFrameType = DataFrameType
126+
types.ColumnLoc = ColumnLoc
127+
107128
make_attribute_wrapper(DataFrameType, 'data', '_data')
108129
make_attribute_wrapper(DataFrameType, 'index', '_index')
109130
make_attribute_wrapper(DataFrameType, 'columns', '_columns')

0 commit comments

Comments (0)