Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit f64407c

Browse files
authored
Implement Series.str.isupper() in new style (#380)
* Implement Series.str.isupper() in new style
* Add index to output Series of str methods
* Expand tests for str methods with indices
* Expand tests for str methods with indices [2]
1 parent 1513e29 commit f64407c

6 files changed

Lines changed: 122 additions & 73 deletions

File tree

sdc/datatypes/common_functions.py

Lines changed: 50 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,13 +33,63 @@
3333
import numpy
3434

3535
from numba import types
36+
from numba.errors import TypingError
3637
from numba.extending import overload
3738
from numba import numpy_support
3839

3940
import sdc
4041
from sdc.str_arr_ext import (string_array_type, num_total_chars, append_string_array_to)
4142

4243

44+
class TypeChecker:
45+
"""
46+
Validate object type and raise TypingError if the type is invalid, e.g.:
47+
Method nsmallest(). The object n
48+
given: bool
49+
expected: int
50+
"""
51+
msg_template = '{} The object {}\n given: {}\n expected: {}'
52+
53+
def __init__(self, func_name):
54+
"""
55+
Parameters
56+
----------
57+
func_name: :obj:`str`
58+
name of the function whose argument types are being checked
59+
"""
60+
self.func_name = func_name
61+
62+
def raise_exc(self, data, expected_types, name=''):
63+
"""
64+
Raise exception with unified message
65+
Parameters
66+
----------
67+
data: :obj:`any`
68+
real type of the data
69+
expected_types: :obj:`str`
70+
expected types, inserted directly into the exception message
71+
name: :obj:`str`
72+
name of the parameter
73+
"""
74+
msg = self.msg_template.format(self.func_name, name, data, expected_types)
75+
raise TypingError(msg)
76+
77+
def check(self, data, accepted_type, name=''):
78+
"""
79+
Check that the data type belongs to the specified type
80+
Parameters
81+
----------
82+
data: :obj:`any`
83+
real type of the data
84+
accepted_type: :obj:`type`
85+
accepted type
86+
name: :obj:`str`
87+
name of the parameter
88+
"""
89+
if not isinstance(data, accepted_type):
90+
self.raise_exc(data, accepted_type.__name__, name=name)
91+
92+
4393
def has_literal_value(var, value):
4494
'''Used during typing to check that variable var is a Numba literal value equal to value'''
4595

sdc/datatypes/hpat_pandas_series_functions.py

Lines changed: 1 addition & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -40,59 +40,12 @@
4040

4141
import sdc
4242
import sdc.datatypes.common_functions as common_functions
43+
from sdc.datatypes.common_functions import TypeChecker
4344
from sdc.datatypes.hpat_pandas_stringmethods_types import StringMethodsType
4445
from sdc.hiframes.pd_series_ext import SeriesType
4546
from sdc.str_arr_ext import (StringArrayType, cp_str_list_to_array, num_total_chars)
4647
from sdc.utils import to_array
4748

48-
class TypeChecker:
49-
"""
50-
Validate object type and raise TypingError if the type is invalid, e.g.:
51-
Method nsmallest(). The object n
52-
given: bool
53-
expected: int
54-
"""
55-
msg_template = '{} The object {}\n given: {}\n expected: {}'
56-
57-
def __init__(self, func_name):
58-
"""
59-
Parameters
60-
----------
61-
func_name: :obj:`str`
62-
name of the function where types checking
63-
"""
64-
self.func_name = func_name
65-
66-
def raise_exc(self, data, expected_types, name=''):
67-
"""
68-
Raise exception with unified message
69-
Parameters
70-
----------
71-
data: :obj:`any`
72-
real type of the data
73-
expected_types: :obj:`str`
74-
expected types inserting directly to the exception
75-
name: :obj:`str`
76-
name of the parameter
77-
"""
78-
msg = self.msg_template.format(self.func_name, name, data, expected_types)
79-
raise TypingError(msg)
80-
81-
def check(self, data, accepted_type, name=''):
82-
"""
83-
Check data type belongs to specified type
84-
Parameters
85-
----------
86-
data: :obj:`any`
87-
real type of the data
88-
accepted_type: :obj:`type`
89-
accepted type
90-
name: :obj:`str`
91-
name of the parameter
92-
"""
93-
if not isinstance(data, accepted_type):
94-
self.raise_exc(data, accepted_type.__name__, name=name)
95-
9649

9750
@overload(operator.getitem)
9851
def hpat_pandas_series_getitem(self, idx):

sdc/datatypes/hpat_pandas_stringmethods_functions.py

Lines changed: 47 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -40,10 +40,9 @@
4040
4141
@overload_method(StringMethodsType, 'upper')
4242
def hpat_pandas_stringmethods_upper(self):
43-
_func_name = 'Method stringmethods.upper().'
4443
45-
if not isinstance(self, StringMethodsType):
46-
raise TypingError('{} The object must be a pandas.core.strings. Given: {}'.format(_func_name, self))
44+
ty_checker = TypeChecker('Method stringmethods.upper().')
45+
ty_checker.check(self, StringMethodsType)
4746
4847
def hpat_pandas_stringmethods_upper_parallel_impl(self):
4948
from numba.parfor import (init_prange, min_checker, internal_prange)
@@ -83,16 +82,17 @@ def hpat_pandas_stringmethods_upper_impl(self):
8382

8483
import numba
8584
from numba.extending import overload_method
86-
from numba.errors import TypingError
8785

86+
from sdc.datatypes.common_functions import TypeChecker
8887
from sdc.datatypes.hpat_pandas_stringmethods_types import StringMethodsType
8988

9089

9190
_hpat_pandas_stringmethods_autogen_global_dict = {
9291
'pandas': pandas,
9392
'numpy': numpy,
9493
'numba': numba,
95-
'StringMethodsType': StringMethodsType
94+
'StringMethodsType': StringMethodsType,
95+
'TypeChecker': TypeChecker
9696
}
9797

9898
_hpat_pandas_stringmethods_functions_params = {
@@ -166,8 +166,8 @@ def hpat_pandas_stringmethods_{methodname}(self{methodparams}):
166166
returns :obj:`pandas.Series` object
167167
\"\"\"
168168
169-
if not isinstance(self, StringMethodsType):
170-
raise TypingError('Method {methodname}(). The object must be a pandas.core.strings. Given: ' % self)
169+
ty_checker = TypeChecker('Method {methodname}().')
170+
ty_checker.check(self, StringMethodsType)
171171
172172
def hpat_pandas_stringmethods_{methodname}_impl(self{methodparams}):
173173
item_count = len(self._data)
@@ -181,12 +181,48 @@ def hpat_pandas_stringmethods_{methodname}_impl(self{methodparams}):
181181
else:
182182
result[it] = item
183183
184-
return pandas.Series(result, name=self._data._name)
184+
return pandas.Series(result, self._data._index, name=self._data._name)
185185
186186
return hpat_pandas_stringmethods_{methodname}_impl
187187
"""
188188

189189

190+
@overload_method(StringMethodsType, 'isupper')
191+
def hpat_pandas_stringmethods_isupper(self):
192+
"""
193+
Pandas Series method :meth:`pandas.core.strings.StringMethods.isupper()` implementation.
194+
195+
Note: Only list elements of Unicode type are supported. numpy.NaN is not supported as an element.
196+
197+
.. only:: developer
198+
199+
Test: python -m sdc.runtests sdc.tests.test_series.TestSeries.test_series_str2str
200+
201+
Parameters
202+
----------
203+
self: :class:`pandas.core.strings.StringMethods`
204+
input arg
205+
206+
Returns
207+
-------
208+
:obj:`pandas.Series`
209+
returns :obj:`pandas.Series` object
210+
"""
211+
212+
ty_checker = TypeChecker('Method isupper().')
213+
ty_checker.check(self, StringMethodsType)
214+
215+
def hpat_pandas_stringmethods_isupper_impl(self):
216+
item_count = len(self._data)
217+
result = numpy.empty(item_count, numba.types.boolean)
218+
for idx, item in enumerate(self._data._data):
219+
result[idx] = item.isupper()
220+
221+
return pandas.Series(result, self._data._index, name=self._data._name)
222+
223+
return hpat_pandas_stringmethods_isupper_impl
224+
225+
190226
@overload_method(StringMethodsType, 'len')
191227
def hpat_pandas_stringmethods_len(self):
192228
"""
@@ -209,16 +245,16 @@ def hpat_pandas_stringmethods_len(self):
209245
returns :obj:`pandas.Series` object
210246
"""
211247

212-
if not isinstance(self, StringMethodsType):
213-
raise TypingError('Method len(). The object must be a pandas.core.strings. Given: {}'.format(self))
248+
ty_checker = TypeChecker('Method len().')
249+
ty_checker.check(self, StringMethodsType)
214250

215251
def hpat_pandas_stringmethods_len_impl(self):
216252
item_count = len(self._data)
217253
result = numpy.empty(item_count, numba.types.int64)
218254
for idx, item in enumerate(self._data._data):
219255
result[idx] = len(item)
220256

221-
return pandas.Series(result, name=self._data._name)
257+
return pandas.Series(result, self._data._index, name=self._data._name)
222258

223259
return hpat_pandas_stringmethods_len_impl
224260

sdc/hiframes/pd_series_ext.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -758,7 +758,8 @@ def resolve_head(self, ary, args, kws):
758758
Functions which are still overloaded by HPAT compiler pipeline
759759
"""
760760

761-
str2str_methods_excluded = ['upper', 'len', 'lower', 'lstrip', 'rstrip', 'strip']
761+
str2str_methods_excluded = ['upper', 'isupper', 'len', 'lower',
762+
'lstrip', 'rstrip', 'strip']
762763
"""
763764
Functions which are used from Numba directly by calling from StringMethodsType
764765

sdc/hiframes/split_impl.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -478,7 +478,7 @@ def hpat_pandas_spliview_stringmethods_len_impl(self):
478478
for i in range(len(local_data)):
479479
result[i] = len(local_data[i])
480480

481-
return pandas.Series(result, name=self._data._name)
481+
return pandas.Series(result, self._data._index, name=self._data._name)
482482

483483
return hpat_pandas_spliview_stringmethods_len_impl
484484

sdc/tests/test_series.py

Lines changed: 21 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@
3333
import numpy as np
3434
import pyarrow.parquet as pq
3535
import sdc
36-
from itertools import islice, permutations
36+
from itertools import islice, permutations, product
3737
from sdc.tests.test_base import TestCase
3838
from sdc.tests.test_utils import (
3939
count_array_REPs, count_parfor_REPs, count_array_OneDs, get_start_end,
@@ -2424,23 +2424,31 @@ def test_impl(S1, S2):
24242424
hpat_func(S1, S2), test_impl(S1, S2),
24252425
err_msg='S1={}\nS2={}'.format(S1, S2))
24262426

2427-
@skip_numba_jit
24282427
def test_series_str_len1(self):
24292428
def test_impl(S):
24302429
return S.str.len()
24312430
hpat_func = self.jit(test_impl)
24322431

2433-
# TODO: fix issue occurred if name is not assigned
2434-
S = pd.Series(['aa', 'abc', 'c', 'cccd'], name='A')
2435-
pd.testing.assert_series_equal(hpat_func(S), test_impl(S))
2432+
data = ['aa', 'abc', 'c', 'cccd']
2433+
indices = [None, [1, 3, 2, 0], data]
2434+
names = [None, 'A']
2435+
for index, name in product(indices, names):
2436+
S = pd.Series(data, index, name=name)
2437+
pd.testing.assert_series_equal(hpat_func(S), test_impl(S))
24362438

2437-
@skip_numba_jit
24382439
def test_series_str2str(self):
2439-
common_methods = ['lower', 'upper', 'lstrip', 'rstrip', 'strip']
2440-
sdc_methods = ['capitalize', 'swapcase', 'title']
2440+
common_methods = ['lower', 'upper', 'isupper']
2441+
sdc_methods = ['capitalize', 'swapcase', 'title',
2442+
'lstrip', 'rstrip', 'strip']
24412443
str2str_methods = common_methods[:]
2444+
2445+
data = [' \tbbCD\t ', 'ABC', ' mCDm\t', 'abc']
2446+
indices = [None]
2447+
names = [None, 'A']
24422448
if sdc.config.config_pipeline_hpat_default:
24432449
str2str_methods += sdc_methods
2450+
else:
2451+
indices += [[1, 3, 2, 0], data]
24442452

24452453
for method in str2str_methods:
24462454
func_lines = ['def test_impl(S):',
@@ -2449,10 +2457,11 @@ def test_series_str2str(self):
24492457
test_impl = _make_func_from_text(func_text)
24502458
hpat_func = self.jit(test_impl)
24512459

2452-
# TODO: fix issue occurred if name is not assigned
2453-
S = pd.Series([' \tbbCD\t ', 'ABC', ' mCDm\t', 'abc'], name='A')
2454-
pd.testing.assert_series_equal(hpat_func(S), test_impl(S),
2455-
check_names=method in common_methods)
2460+
check_names = method in common_methods
2461+
for index, name in product(indices, names):
2462+
S = pd.Series(data, index, name=name)
2463+
pd.testing.assert_series_equal(hpat_func(S), test_impl(S),
2464+
check_names=check_names)
24562465

24572466
@skip_sdc_jit('Series.str.<method>() unsupported')
24582467
def test_series_str2str_unsupported(self):

0 commit comments

Comments
 (0)