Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 1513e29

Browse files
authored
Overload Series(StringArraySplitView).str.len() (#360)
* Overload Series(StringArraySplitView).str.len() * Fix combination of str.split, str.get and str.len Made unique names for string methods types
1 parent 93265e9 commit 1513e29

5 files changed

Lines changed: 145 additions & 36 deletions

File tree

sdc/datatypes/hpat_pandas_stringmethods_types.py

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from numba.extending import (models, overload, register_model, make_attribute_wrapper, intrinsic)
3737
from numba.datamodel import (register_default, StructModel)
3838
from numba.typing.templates import signature
39+
from sdc.hiframes.split_impl import SplitViewStringMethodsType, StringArraySplitViewType
3940

4041

4142
class StringMethodsType(types.IterableType):
@@ -50,7 +51,8 @@ class StringMethodsType(types.IterableType):
5051

5152
def __init__(self, data):
5253
self.data = data
53-
super(StringMethodsType, self).__init__('StringMethodsType')
54+
name = 'StringMethodsType({})'.format(self.data)
55+
super(StringMethodsType, self).__init__(name)
5456

5557
@property
5658
def iterator_type(self):
@@ -74,37 +76,47 @@ def __init__(self, dmm, fe_type):
7476
make_attribute_wrapper(StringMethodsType, 'data', '_data')
7577

7678

77-
@intrinsic
78-
def _hpat_pandas_stringmethods_init(typingctx, data):
79-
"""
80-
Internal Numba required function to register StringMethodsType and
81-
connect it with corresponding Python type mentioned in @overload(pandas.core.strings.StringMethods)
82-
"""
79+
def _gen_hpat_pandas_stringmethods_init(string_methods_type=None):
80+
string_methods_type = string_methods_type or StringMethodsType
8381

84-
def _hpat_pandas_stringmethods_init_codegen(context, builder, signature, args):
82+
def _hpat_pandas_stringmethods_init(typingctx, data):
8583
"""
86-
It is looks like it creates StringMethodsModel structure
87-
88-
- Fixed number of parameters. Must be 4
89-
- increase reference count for the data
84+
Internal Numba required function to register StringMethodsType and
85+
connect it with corresponding Python type mentioned in @overload(pandas.core.strings.StringMethods)
9086
"""
9187

92-
[data_val] = args
93-
stringmethod = cgutils.create_struct_proxy(signature.return_type)(context, builder)
94-
stringmethod.data = data_val
88+
def _hpat_pandas_stringmethods_init_codegen(context, builder, signature, args):
89+
"""
90+
It is looks like it creates StringMethodsModel structure
9591
96-
if context.enable_nrt:
97-
context.nrt.incref(builder, data, stringmethod.data)
92+
- Fixed number of parameters. Must be 4
93+
- increase reference count for the data
94+
"""
9895

99-
return stringmethod._getvalue()
96+
[data_val] = args
97+
stringmethod = cgutils.create_struct_proxy(signature.return_type)(context, builder)
98+
stringmethod.data = data_val
10099

101-
ret_typ = StringMethodsType(data)
102-
sig = signature(ret_typ, data)
103-
"""
104-
Construct signature of the Numba SeriesGroupByType::ctor()
105-
"""
100+
if context.enable_nrt:
101+
context.nrt.incref(builder, data, stringmethod.data)
106102

107-
return sig, _hpat_pandas_stringmethods_init_codegen
103+
return stringmethod._getvalue()
104+
105+
ret_typ = string_methods_type(data)
106+
sig = signature(ret_typ, data)
107+
"""
108+
Construct signature of the Numba SeriesGroupByType::ctor()
109+
"""
110+
111+
return sig, _hpat_pandas_stringmethods_init_codegen
112+
113+
return _hpat_pandas_stringmethods_init
114+
115+
116+
_hpat_pandas_stringmethods_init = intrinsic(
117+
_gen_hpat_pandas_stringmethods_init(string_methods_type=StringMethodsType))
118+
_hpat_pandas_split_view_stringmethods_init = intrinsic(
119+
_gen_hpat_pandas_stringmethods_init(string_methods_type=SplitViewStringMethodsType))
108120

109121

110122
@overload(pandas.core.strings.StringMethods)
@@ -113,6 +125,11 @@ def hpat_pandas_stringmethods(obj):
113125
Special Numba procedure to overload Python type pandas.core.strings.StringMethods::ctor()
114126
with Numba registered model
115127
"""
128+
if isinstance(obj.data, StringArraySplitViewType):
129+
def hpat_pandas_split_view_stringmethods_impl(obj):
130+
return _hpat_pandas_split_view_stringmethods_init(obj)
131+
132+
return hpat_pandas_split_view_stringmethods_impl
116133

117134
def hpat_pandas_stringmethods_impl(obj):
118135
return _hpat_pandas_stringmethods_init(obj)

sdc/hiframes/hiframes_typed.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,10 @@
6565
from sdc.hiframes.rolling import get_rolling_setup_args
6666
from sdc.hiframes.aggregate import Aggregate
6767
from sdc.hiframes.series_kernels import series_replace_funcs
68-
from sdc.hiframes.split_impl import (string_array_split_view_type,
69-
StringArraySplitViewType, getitem_c_arr, get_array_ctypes_ptr,
70-
get_split_view_index, get_split_view_data_ptr)
68+
from sdc.hiframes.split_impl import (SplitViewStringMethodsType,
69+
string_array_split_view_type, StringArraySplitViewType,
70+
getitem_c_arr, get_array_ctypes_ptr,
71+
get_split_view_index, get_split_view_data_ptr)
7172

7273

7374
_dt_index_binops = ('==', '!=', '>=', '>', '<=', '<', '-',
@@ -480,7 +481,8 @@ def _run_call(self, assign, lhs, rhs):
480481
else:
481482
func_name, func_mod = fdef
482483

483-
if (isinstance(func_mod, ir.Var) and isinstance(self.state.typemap[func_mod.name], StringMethodsType)):
484+
string_methods_types = (SplitViewStringMethodsType, StringMethodsType)
485+
if isinstance(func_mod, ir.Var) and isinstance(self.state.typemap[func_mod.name], string_methods_types):
484486
f_def = guard(get_definition, self.state.func_ir, rhs.func)
485487
str_def = guard(get_definition, self.state.func_ir, f_def.value)
486488
if str_def is None: # TODO: check for errors

sdc/hiframes/pd_series_ext.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@
5757
from sdc.hiframes.pd_categorical_ext import (PDCategoricalDtype, CategoricalArray)
5858
from sdc.hiframes.pd_timestamp_ext import (pandas_timestamp_type, datetime_date_type)
5959
from sdc.hiframes.rolling import supported_rolling_funcs
60-
from sdc.hiframes.split_impl import (string_array_split_view_type, GetItemStringArraySplitView)
60+
from sdc.hiframes.split_impl import (SplitViewStringMethodsType,
61+
string_array_split_view_type,
62+
GetItemStringArraySplitView)
6163
from sdc.str_arr_ext import (
6264
string_array_type,
6365
iternext_str_array,
@@ -423,9 +425,9 @@ def resolve_T(self, ary):
423425
# def resolve_index(self, ary):
424426
# return ary.index
425427

426-
def resolve_str(self, ary):
427-
assert ary.dtype in (string_type, types.List(string_type))
428-
return StringMethodsType(ary)
428+
# def resolve_str(self, ary):
429+
# assert ary.dtype in (string_type, types.List(string_type))
430+
# return StringMethodsType(ary)
429431

430432
def resolve_dt(self, ary):
431433
assert ary.dtype == types.NPDatetime('ns')
@@ -780,9 +782,9 @@ class SeriesStrMethodAttribute(AttributeTemplate):
780782
def resolve_contains(self, ary, args, kws):
781783
return signature(SeriesType(types.bool_), *args)
782784

783-
@bound_function("strmethod.len")
784-
def resolve_len(self, ary, args, kws):
785-
return signature(SeriesType(types.int64), *args)
785+
# @bound_function("strmethod.len")
786+
# def resolve_len(self, ary, args, kws):
787+
# return signature(SeriesType(types.int64), *args)
786788

787789
@bound_function("strmethod.replace")
788790
def resolve_replace(self, ary, args, kws):
@@ -820,6 +822,16 @@ def generic(self, args, kws):
820822
raise NotImplementedError('Series.str.{} is not supported yet'.format(func_name))
821823

822824

825+
@infer_getattr
826+
class SplitViewSeriesStrMethodAttribute(AttributeTemplate):
827+
key = SplitViewStringMethodsType
828+
829+
@bound_function('strmethod.get')
830+
def resolve_get(self, ary, args, kws):
831+
# XXX only list(list(str)) supported
832+
return signature(SeriesType(string_type), *args)
833+
834+
823835
class SeriesDtMethodType(types.Type):
824836
def __init__(self):
825837
name = "SeriesDtMethodType"

sdc/hiframes/split_impl.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,16 @@
2626

2727

2828
import operator
29+
import numpy
30+
import pandas
2931
import numba
3032
import sdc
3133
from numba import types
3234
from numba.typing.templates import (infer_global, AbstractTemplate, infer,
3335
signature, AttributeTemplate, infer_getattr, bound_function)
3436
import numba.typing.typeof
37+
from numba.datamodel import StructModel
38+
from numba.errors import TypingError
3539
from numba.extending import (typeof_impl, type_callable, models, register_model, NativeValue,
3640
make_attribute_wrapper, lower_builtin, box, unbox,
3741
lower_getattr, intrinsic, overload_method, overload, overload_attribute)
@@ -131,6 +135,43 @@ def __init__(self, dmm, fe_type):
131135
make_attribute_wrapper(StringArraySplitViewType, 'data', '_data')
132136

133137

138+
class SplitViewStringMethodsType(types.IterableType):
139+
"""
140+
Type definition for pandas.core.strings.StringMethods functions handling.
141+
142+
Members
143+
----------
144+
_data: :class:`SeriesType`
145+
input arg
146+
"""
147+
148+
def __init__(self, data):
149+
self.data = data
150+
name = 'SplitViewStringMethodsType({})'.format(self.data)
151+
super(SplitViewStringMethodsType, self).__init__(name)
152+
153+
@property
154+
def iterator_type(self):
155+
return None
156+
157+
158+
@register_model(SplitViewStringMethodsType)
159+
class SplitViewStringMethodsTypeModel(StructModel):
160+
"""
161+
Model for SplitViewStringMethodsType type
162+
All members must be the same as main type for this model
163+
"""
164+
165+
def __init__(self, dmm, fe_type):
166+
members = [
167+
('data', fe_type.data)
168+
]
169+
models.StructModel.__init__(self, dmm, fe_type, members)
170+
171+
172+
make_attribute_wrapper(SplitViewStringMethodsType, 'data', '_data')
173+
174+
134175
def construct_str_arr_split_view(context, builder):
135176
"""Creates meminfo and sets dtor.
136177
"""
@@ -404,6 +445,44 @@ def str_arr_split_view_len_overload(arr):
404445
return lambda arr: arr._num_items
405446

406447

448+
@overload_method(SplitViewStringMethodsType, 'len')
449+
def hpat_pandas_spliview_stringmethods_len(self):
450+
"""
451+
Pandas Series method :meth:`pandas.core.strings.StringMethods.len()` implementation.
452+
453+
Note: Unicode type of list elements are supported only. Numpy.NaN is not supported as elements.
454+
455+
.. only:: developer
456+
457+
Test: python -m sdc.runtests sdc.tests.test_hiframes.TestHiFrames.test_str_split_filter
458+
459+
Parameters
460+
----------
461+
self: :class:`pandas.core.strings.StringMethods`
462+
input arg
463+
464+
Returns
465+
-------
466+
:obj:`pandas.Series`
467+
returns :obj:`pandas.Series` object
468+
"""
469+
470+
if not isinstance(self, SplitViewStringMethodsType):
471+
msg = 'Method len(). The object must be a pandas.core.strings. Given: {}'
472+
raise TypingError(msg.format(self))
473+
474+
def hpat_pandas_spliview_stringmethods_len_impl(self):
475+
item_count = len(self._data)
476+
result = numpy.empty(item_count, numba.types.int64)
477+
local_data = self._data._data
478+
for i in range(len(local_data)):
479+
result[i] = len(local_data[i])
480+
481+
return pandas.Series(result, name=self._data._name)
482+
483+
return hpat_pandas_spliview_stringmethods_len_impl
484+
485+
407486
# @infer_global(operator.getitem)
408487
class GetItemStringArraySplitView(AbstractTemplate):
409488
key = operator.getitem

sdc/tests/test_hiframes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,6 @@ def test_impl(df):
424424
pd.testing.assert_series_equal(
425425
hpat_func(df), test_impl(df), check_names=False)
426426

427-
@skip_sdc_jit("Could not get length of Series(StringArraySplitView)")
428427
@skip_numba_jit
429428
def test_str_split_filter(self):
430429
def test_impl(df):

0 commit comments

Comments
 (0)