Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 050f2e5

Browse files
Rubtsowashssf
authored andcommitted
Impl operator setitem (#376)
* Impl operator setitem * change
1 parent b19f9b9 commit 050f2e5

3 files changed

Lines changed: 228 additions & 114 deletions

File tree

sdc/datatypes/hpat_pandas_series_functions.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,61 @@ def hpat_pandas_series_getitem_idx_series_impl(self, idx):
156156
raise TypingError('{} The index must be an Integer, Slice or a pandas.series. Given: {}'.format(_func_name, idx))
157157

158158

159+
@overload(operator.setitem)
160+
def hpat_pandas_series_setitem(self, idx, value):
161+
"""
162+
Pandas Series operator :attr:`pandas.Series.get` implementation
163+
'''
164+
Test: python -m sdc.runtests sdc.tests.test_series.TestSeries.test_series_setitem_unsupported
165+
'''
166+
Parameters
167+
----------
168+
series: :obj:`pandas.Series`
169+
input series
170+
idx: :obj:`int`, :obj:`slice` or :obj:`pandas.Series`
171+
input index
172+
value: :object
173+
input value
174+
Returns
175+
-------
176+
:class:`pandas.Series` or an element of the underneath type
177+
object of :class:`pandas.Series`
178+
"""
179+
180+
_func_name = 'Operator setitem().'
181+
if not isinstance(self, SeriesType):
182+
raise TypingError('{} The object must be a pandas.series. Given: {}'.format(_func_name, self))
183+
184+
if not isinstance(self.dtype, type(value)):
185+
raise TypingError('{} Value must be one type with series. Given: {}, self.dtype={}'.format(_func_name,
186+
value, self.dtype))
187+
188+
if isinstance(idx, types.Integer) or isinstance(idx, types.SliceType):
189+
def hpat_pandas_series_setitem_idx_integer_impl(self, idx, value):
190+
"""
191+
Test: python -m sdc.runtests sdc.tests.test_series.TestSeries.test_series_setitem_for_value
192+
Test: python -m sdc.runtests sdc.tests.test_series.TestSeries.test_series_setitem_for_slice
193+
"""
194+
195+
self._data[idx] = value
196+
return self
197+
198+
return hpat_pandas_series_setitem_idx_integer_impl
199+
200+
if isinstance(idx, SeriesType):
201+
def hpat_pandas_series_getitem_idx_series_impl(self, idx, value):
202+
"""
203+
Test: python -m sdc.runtests sdc.tests.test_series.TestSeries.test_series_setitem_for_series
204+
"""
205+
super_index = idx._data
206+
self._data[super_index] = value
207+
return self
208+
209+
return hpat_pandas_series_getitem_idx_series_impl
210+
211+
raise TypingError('{} The index must be an Integer, Slice or a pandas.series. Given: {}'.format(_func_name, idx))
212+
213+
159214
@overload_attribute(SeriesType, 'at')
160215
@overload_attribute(SeriesType, 'iat')
161216
@overload_attribute(SeriesType, 'iloc')

sdc/hiframes/pd_series_ext.py

Lines changed: 114 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -441,9 +441,9 @@ def resolve_iloc(self, ary):
441441
return SeriesIatType(ary)
442442

443443
# PR135. This needs to be commented out
444-
def resolve_loc(self, ary):
445-
# TODO: support iat/iloc differences
446-
return SeriesIatType(ary)
444+
# def resolve_loc(self, ary):
445+
# # TODO: support iat/iloc differences
446+
# return SeriesIatType(ary)
447447

448448
# @bound_function("array.astype")
449449
# def resolve_astype(self, ary, args, kws):
@@ -900,14 +900,14 @@ def __init__(self, stype):
900900

901901

902902
# PR135. This needs to be commented out
903-
@infer_global(operator.getitem)
904-
class GetItemSeriesIat(AbstractTemplate):
905-
key = operator.getitem
906-
907-
def generic(self, args, kws):
908-
# iat[] is the same as regular getitem
909-
if isinstance(args[0], SeriesIatType):
910-
return GetItemSeries.generic(self, (args[0].stype, args[1]), kws)
903+
# @infer_global(operator.getitem)
904+
# class GetItemSeriesIat(AbstractTemplate):
905+
# key = operator.getitem
906+
#
907+
# def generic(self, args, kws):
908+
# # iat[] is the same as regular getitem
909+
# if isinstance(args[0], SeriesIatType):
910+
# return GetItemSeries.generic(self, (args[0].stype, args[1]), kws)
911911

912912

913913
@infer
@@ -1033,110 +1033,110 @@ def generic_expand_cumulative_series(self, args, kws):
10331033
delattr(SeriesAttribute, attr)
10341034

10351035
# PR135. This needs to be commented out
1036-
@infer_global(operator.getitem)
1037-
class GetItemSeries(AbstractTemplate):
1038-
key = operator.getitem
1036+
# @infer_global(operator.getitem)
1037+
# class GetItemSeries(AbstractTemplate):
1038+
# key = operator.getitem
1039+
#
1040+
# def generic(self, args, kws):
1041+
# assert not kws
1042+
# [in_arr, in_idx] = args
1043+
# is_arr_series = False
1044+
# is_idx_series = False
1045+
# is_arr_dt_index = False
1046+
#
1047+
# if not isinstance(in_arr, SeriesType) and not isinstance(in_idx, SeriesType):
1048+
# return None
1049+
#
1050+
# if isinstance(in_arr, SeriesType):
1051+
# in_arr = series_to_array_type(in_arr)
1052+
# is_arr_series = True
1053+
# if in_arr.dtype == types.NPDatetime('ns'):
1054+
# is_arr_dt_index = True
1055+
#
1056+
# if isinstance(in_idx, SeriesType):
1057+
# in_idx = series_to_array_type(in_idx)
1058+
# is_idx_series = True
1059+
#
1060+
# # TODO: dt_index
1061+
# if in_arr == string_array_type:
1062+
# # XXX fails due in overload
1063+
# # compile_internal version results in symbol not found!
1064+
# # sig = self.context.resolve_function_type(
1065+
# # operator.getitem, (in_arr, in_idx), kws)
1066+
# # HACK to get avoid issues for now
1067+
# if isinstance(in_idx, (types.Integer, types.IntegerLiteral)):
1068+
# sig = string_type(in_arr, in_idx)
1069+
# else:
1070+
# sig = GetItemStringArray.generic(self, (in_arr, in_idx), kws)
1071+
# elif in_arr == list_string_array_type:
1072+
# # TODO: split view
1073+
# # mimic array indexing for list
1074+
# if (isinstance(in_idx, types.Array) and in_idx.ndim == 1
1075+
# and isinstance(
1076+
# in_idx.dtype, (types.Integer, types.Boolean))):
1077+
# sig = signature(in_arr, in_arr, in_idx)
1078+
# else:
1079+
# sig = numba.typing.collections.GetItemSequence.generic(
1080+
# self, (in_arr, in_idx), kws)
1081+
# elif in_arr == string_array_split_view_type:
1082+
# sig = GetItemStringArraySplitView.generic(
1083+
# self, (in_arr, in_idx), kws)
1084+
# else:
1085+
# out = get_array_index_type(in_arr, in_idx)
1086+
# sig = signature(out.result, in_arr, out.index)
1087+
#
1088+
# if sig is not None:
1089+
# arg1 = sig.args[0]
1090+
# arg2 = sig.args[1]
1091+
# if is_arr_series:
1092+
# sig.return_type = if_arr_to_series_type(sig.return_type)
1093+
# arg1 = if_arr_to_series_type(arg1)
1094+
# if is_idx_series:
1095+
# arg2 = if_arr_to_series_type(arg2)
1096+
# sig.args = (arg1, arg2)
1097+
# # dt_index and Series(dt64) should return Timestamp
1098+
# if is_arr_dt_index and sig.return_type == types.NPDatetime('ns'):
1099+
# sig.return_type = pandas_timestamp_type
1100+
# return sig
10391101

1040-
def generic(self, args, kws):
1041-
assert not kws
1042-
[in_arr, in_idx] = args
1043-
is_arr_series = False
1044-
is_idx_series = False
1045-
is_arr_dt_index = False
1046-
1047-
if not isinstance(in_arr, SeriesType) and not isinstance(in_idx, SeriesType):
1048-
return None
1049-
1050-
if isinstance(in_arr, SeriesType):
1051-
in_arr = series_to_array_type(in_arr)
1052-
is_arr_series = True
1053-
if in_arr.dtype == types.NPDatetime('ns'):
1054-
is_arr_dt_index = True
1055-
1056-
if isinstance(in_idx, SeriesType):
1057-
in_idx = series_to_array_type(in_idx)
1058-
is_idx_series = True
1059-
1060-
# TODO: dt_index
1061-
if in_arr == string_array_type:
1062-
# XXX fails due in overload
1063-
# compile_internal version results in symbol not found!
1064-
# sig = self.context.resolve_function_type(
1065-
# operator.getitem, (in_arr, in_idx), kws)
1066-
# HACK to get avoid issues for now
1067-
if isinstance(in_idx, (types.Integer, types.IntegerLiteral)):
1068-
sig = string_type(in_arr, in_idx)
1069-
else:
1070-
sig = GetItemStringArray.generic(self, (in_arr, in_idx), kws)
1071-
elif in_arr == list_string_array_type:
1072-
# TODO: split view
1073-
# mimic array indexing for list
1074-
if (isinstance(in_idx, types.Array) and in_idx.ndim == 1
1075-
and isinstance(
1076-
in_idx.dtype, (types.Integer, types.Boolean))):
1077-
sig = signature(in_arr, in_arr, in_idx)
1078-
else:
1079-
sig = numba.typing.collections.GetItemSequence.generic(
1080-
self, (in_arr, in_idx), kws)
1081-
elif in_arr == string_array_split_view_type:
1082-
sig = GetItemStringArraySplitView.generic(
1083-
self, (in_arr, in_idx), kws)
1084-
else:
1085-
out = get_array_index_type(in_arr, in_idx)
1086-
sig = signature(out.result, in_arr, out.index)
1087-
1088-
if sig is not None:
1089-
arg1 = sig.args[0]
1090-
arg2 = sig.args[1]
1091-
if is_arr_series:
1092-
sig.return_type = if_arr_to_series_type(sig.return_type)
1093-
arg1 = if_arr_to_series_type(arg1)
1094-
if is_idx_series:
1095-
arg2 = if_arr_to_series_type(arg2)
1096-
sig.args = (arg1, arg2)
1097-
# dt_index and Series(dt64) should return Timestamp
1098-
if is_arr_dt_index and sig.return_type == types.NPDatetime('ns'):
1099-
sig.return_type = pandas_timestamp_type
1100-
return sig
1101-
1102-
1103-
@infer_global(operator.setitem)
1104-
class SetItemSeries(SetItemBuffer):
1105-
def generic(self, args, kws):
1106-
assert not kws
1107-
series, idx, val = args
1108-
if not isinstance(series, SeriesType):
1109-
return None
1110-
# TODO: handle any of args being Series independently
1111-
ary = series_to_array_type(series)
1112-
is_idx_series = False
1113-
if isinstance(idx, SeriesType):
1114-
idx = series_to_array_type(idx)
1115-
is_idx_series = True
1116-
is_val_series = False
1117-
if isinstance(val, SeriesType):
1118-
val = series_to_array_type(val)
1119-
is_val_series = True
1120-
# TODO: strings, dt_index
1121-
res = super(SetItemSeries, self).generic((ary, idx, val), kws)
1122-
if res is not None:
1123-
new_series = if_arr_to_series_type(res.args[0])
1124-
idx = res.args[1]
1125-
val = res.args[2]
1126-
if is_idx_series:
1127-
idx = if_arr_to_series_type(idx)
1128-
if is_val_series:
1129-
val = if_arr_to_series_type(val)
1130-
res.args = (new_series, idx, val)
1131-
return res
1132-
1133-
1134-
@infer_global(operator.setitem)
1135-
class SetItemSeriesIat(SetItemSeries):
1136-
def generic(self, args, kws):
1137-
# iat[] is the same as regular setitem
1138-
if isinstance(args[0], SeriesIatType):
1139-
return SetItemSeries.generic(self, (args[0].stype, args[1], args[2]), kws)
1102+
1103+
# @infer_global(operator.setitem)
1104+
# class SetItemSeries(SetItemBuffer):
1105+
# def generic(self, args, kws):
1106+
# assert not kws
1107+
# series, idx, val = args
1108+
# if not isinstance(series, SeriesType):
1109+
# return None
1110+
# # TODO: handle any of args being Series independently
1111+
# ary = series_to_array_type(series)
1112+
# is_idx_series = False
1113+
# if isinstance(idx, SeriesType):
1114+
# idx = series_to_array_type(idx)
1115+
# is_idx_series = True
1116+
# is_val_series = False
1117+
# if isinstance(val, SeriesType):
1118+
# val = series_to_array_type(val)
1119+
# is_val_series = True
1120+
# # TODO: strings, dt_index
1121+
# res = super(SetItemSeries, self).generic((ary, idx, val), kws)
1122+
# if res is not None:
1123+
# new_series = if_arr_to_series_type(res.args[0])
1124+
# idx = res.args[1]
1125+
# val = res.args[2]
1126+
# if is_idx_series:
1127+
# idx = if_arr_to_series_type(idx)
1128+
# if is_val_series:
1129+
# val = if_arr_to_series_type(val)
1130+
# res.args = (new_series, idx, val)
1131+
# return res
1132+
#
1133+
#
1134+
# @infer_global(operator.setitem)
1135+
# class SetItemSeriesIat(SetItemSeries):
1136+
# def generic(self, args, kws):
1137+
# # iat[] is the same as regular setitem
1138+
# if isinstance(args[0], SeriesIatType):
1139+
# return SetItemSeries.generic(self, (args[0].stype, args[1], args[2]), kws)
11401140

11411141

11421142
inplace_ops = [

sdc/tests/test_series.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4388,6 +4388,65 @@ def test_series_pct_change_impl(S, periods=1, fill_method='pad', limit=None, fre
43884388
msg = 'Method pct_change(). The object periods'
43894389
self.assertIn(msg, str(raises.exception))
43904390

4391+
def test_series_setitem_for_value(self):
4392+
def test_impl(S, val):
4393+
S[3] = val
4394+
return S
4395+
4396+
hpat_func = self.jit(test_impl)
4397+
S = pd.Series([0, 1, 2, 3, 4])
4398+
value = 50
4399+
result_ref = test_impl(S, value)
4400+
result = hpat_func(S, value)
4401+
pd.testing.assert_series_equal(result_ref, result)
4402+
4403+
def test_series_setitem_for_slice(self):
4404+
def test_impl(S, val):
4405+
S[2:] = val
4406+
return S
4407+
4408+
hpat_func = self.jit(test_impl)
4409+
S = pd.Series([0, 1, 2, 3, 4])
4410+
value = 50
4411+
result_ref = test_impl(S, value)
4412+
result = hpat_func(S, value)
4413+
pd.testing.assert_series_equal(result_ref, result)
4414+
4415+
def test_series_setitem_for_series(self):
4416+
def test_impl(S, ind, val):
4417+
S[ind] = val
4418+
return S
4419+
4420+
hpat_func = self.jit(test_impl)
4421+
S = pd.Series([0, 1, 2, 3, 4])
4422+
ind = pd.Series([0, 2, 4])
4423+
value = 50
4424+
result_ref = test_impl(S, ind, value)
4425+
result = hpat_func(S, ind, value)
4426+
pd.testing.assert_series_equal(result_ref, result)
4427+
4428+
def test_series_setitem_unsupported(self):
4429+
def test_impl(S, ind, val):
4430+
S[ind] = val
4431+
return S
4432+
4433+
hpat_func = self.jit(test_impl)
4434+
S = pd.Series([0, 1, 2, 3, 4, 5])
4435+
ind1 = 5
4436+
ind2 = '3'
4437+
value1 = 'ababa'
4438+
value2 = 101
4439+
4440+
with self.assertRaises(TypingError) as raises:
4441+
hpat_func(S, ind1, value1)
4442+
msg = 'Operator setitem(). Value must be one type with series.'
4443+
self.assertIn(msg, str(raises.exception))
4444+
4445+
with self.assertRaises(TypingError) as raises:
4446+
hpat_func(S, ind2, value2)
4447+
msg = 'Operator setitem(). The index must be an Integer, Slice or a pandas.series.'
4448+
self.assertIn(msg, str(raises.exception))
4449+
43914450

43924451
if __name__ == "__main__":
43934452
unittest.main()

0 commit comments

Comments
 (0)