Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 1b46def

Browse files
author
Ehsan Totoni
committed
support len and getitem for string array split view
1 parent 0b73620 commit 1b46def

2 files changed

Lines changed: 74 additions & 5 deletions

File tree

hpat/hiframes/hiframes_typed.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from hpat.hiframes.aggregate import Aggregate
3737
from hpat.hiframes import series_kernels, split_impl
3838
from hpat.hiframes.series_kernels import series_replace_funcs
39-
39+
from hpat.hiframes.split_impl import string_array_split_view_type
4040

4141

4242
_dt_index_binops = ('==', '!=', '>=', '>', '<=', '<', '-',
@@ -1709,9 +1709,10 @@ def _str_split_impl(str_arr, sep):
17091709
return self._replace_func(_str_split_impl, [arr, sep], pre_nodes=nodes)
17101710

17111711
def _run_series_str_get(self, assign, lhs, arr, rhs, nodes):
1712-
# XXX only supports get for list(list(str)) input
1713-
assert (self.typemap[arr.name]
1714-
== types.List(types.List(string_type)))
1712+
arr_typ = self.typemap[arr.name]
1713+
# XXX only supports get for list(list(str)) input and split view
1714+
assert (arr_typ == types.List(types.List(string_type))
1715+
or arr_typ == string_array_split_view_type)
17151716
ind_var = rhs.args[0]
17161717

17171718
def _str_get_impl(str_arr, ind):

hpat/hiframes/split_impl.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from numba.targets.imputils import (impl_ret_new_ref, impl_ret_borrowed,
1515
iternext_impl, RefType)
1616
from hpat.str_arr_ext import (string_array_type, get_data_ptr,
17-
is_str_arr_typ, pre_alloc_string_array)
17+
is_str_arr_typ, pre_alloc_string_array, _memcpy)
1818

1919
import llvmlite.llvmpy.core as lc
2020
from llvmlite import ir as lir
@@ -99,6 +99,12 @@ def __init__(self, dmm, fe_type):
9999
models.StructModel.__init__(self, dmm, fe_type, str_arr_model_members)
100100

101101

102+
# Expose the split-view struct fields to jitted code as underscore-prefixed
# attributes (e.g. `arr._num_items`, `arr._index_offsets`), which the len()
# and getitem overloads in this module read.
make_attribute_wrapper(StringArraySplitViewType, 'num_items', '_num_items')
103+
make_attribute_wrapper(StringArraySplitViewType, 'index_offsets', '_index_offsets')
104+
make_attribute_wrapper(StringArraySplitViewType, 'data_offsets', '_data_offsets')
105+
make_attribute_wrapper(StringArraySplitViewType, 'data', '_data')
106+
107+
102108
def construct_str_arr_split_view(context, builder):
103109
"""Creates meminfo and sets dtor.
104110
"""
@@ -238,3 +244,65 @@ def box_str_arr_split_view(typ, val, c):
238244
c.pyapi.decref(np_class_obj)
239245
return out_arr
240246

247+
248+
@intrinsic
def getitem_c_arr(typingctx, c_arr, ind_t=None):
    """Load the element at position ``ind`` from the C array ``c_arr``.

    The result is typed as ``c_arr.dtype``.  The generated code is a plain
    address computation followed by a load; no bounds check is emitted.
    """
    def codegen(context, builder, sig, args):
        base_ptr, index = args
        # Address of element `index` relative to the base pointer.
        elem_ptr = builder.gep(base_ptr, [index])
        return builder.load(elem_ptr)

    signature = c_arr.dtype(c_arr, ind_t)
    return signature, codegen
255+
256+
257+
@intrinsic
def get_array_ctypes_ptr(typingctx, arr_ctypes_t, ind_t=None):
    """Return an ``arr_ctypes_t`` value whose data pointer is advanced by ``ind``.

    A new helper struct of the same type is built: its ``data`` field is the
    input's ``data`` pointer offset by ``ind`` elements, and its ``meminfo``
    field is copied from the input unchanged.  The value is returned as a
    borrowed reference.
    """
    def codegen(context, builder, sig, args):
        src_val, offset = args

        src = context.make_helper(builder, arr_ctypes_t, src_val)

        shifted = context.make_helper(builder, arr_ctypes_t)
        # Same meminfo as the source; only the data pointer moves.
        shifted.meminfo = src.meminfo
        shifted.data = builder.gep(src.data, [offset])
        return impl_ret_borrowed(
            context, builder, arr_ctypes_t, shifted._getvalue())

    return arr_ctypes_t(arr_ctypes_t, ind_t), codegen
272+
273+
274+
@overload(len)
def str_arr_split_view_len_overload(arr):
    """Provide ``len()`` for string array split views.

    Returns an implementation that reads the ``_num_items`` attribute, or
    None (no overload) when ``arr`` is not the split view type.
    """
    if arr != string_array_split_view_type:
        return None

    def _len_impl(arr):
        return arr._num_items

    return _len_impl
278+
279+
280+
@overload(operator.getitem)
def str_arr_split_view_getitem_overload(A, ind):
    """Provide ``A[ind]`` for a string array split view and an integer index.

    The returned implementation builds a list of strings for item ``ind``:
    ``_index_offsets[ind]`` and ``_index_offsets[ind + 1]`` bound the item's
    range inside ``_data_offsets``, and consecutive data offsets bound the
    bytes of each piece inside ``_data``.  Every piece is copied into a newly
    allocated 1-byte-kind string.

    Returns None (no overload) for any other combination of argument types.
    """
    if A != string_array_split_view_type or not isinstance(ind, types.Integer):
        return None

    char_kind = numba.unicode.PY_UNICODE_1BYTE_KIND

    def _impl(A, ind):
        first = getitem_c_arr(A._index_offsets, ind)
        last = getitem_c_arr(A._index_offsets, ind + 1)
        # The range holds one more offset than there are pieces.
        n_parts = last - first - 1

        parts = hpat.str_ext.alloc_str_list(n_parts)
        for k in range(n_parts):
            pos = first + k
            # A piece starts one past the stored offset
            # (presumably the offset marks a separator position - TODO confirm).
            begin = getitem_c_arr(A._data_offsets, pos) + 1
            # get around -1 storage in uint32 problem: the very first offset
            # cannot be stored as -1, so its start is forced to 0 here.
            if pos == 0:
                begin = 0
            end = getitem_c_arr(A._data_offsets, pos + 1)
            nbytes = end - begin
            piece = numba.unicode._empty_string(char_kind, nbytes)
            src = get_array_ctypes_ptr(A._data, begin)
            _memcpy(piece._data, src, nbytes, 1)
            parts[k] = piece

        return parts

    return _impl

0 commit comments

Comments
 (0)