Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 57926a5

Browse files
author
Ehsan Totoni
committed
split view opt WIP
1 parent 45a07b9 commit 57926a5

6 files changed

Lines changed: 315 additions & 5 deletions

File tree

hpat/_str_ext.cpp

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,22 @@ struct str_arr_payload {
2929
uint8_t* null_bitmap;
3030
};
3131

32+
// XXX: must stay equivalent to the payload data model in split_impl.py
// (field order and types are part of that contract).
struct str_arr_split_view_payload {
    // For string i, index_offsets[i]..index_offsets[i+1] delimits its
    // entries in data_offsets (array of n_strs + 1 elements).
    uint32_t* index_offsets;
    // Character positions recorded while scanning the string data
    // (separator positions and string ends).
    uint32_t* data_offsets;
    // uint8_t* null_bitmap;  // not part of the payload yet
};
38+
3239
// taken from Arrow bin-util.h
3340
static constexpr uint8_t kBitmask[] = {1, 2, 4, 8, 16, 32, 64, 128};
3441

3542
void* init_string(char*, int64_t);
3643
void* init_string_const(char* in_str);
3744
void dtor_string(std::string** in_str, int64_t size, void* in);
3845
void dtor_string_array(str_arr_payload* in_str, int64_t size, void* in);
46+
void dtor_str_arr_split_view(str_arr_split_view_payload* in_str_arr, int64_t size, void* in);
47+
void str_arr_split_view_impl(str_arr_split_view_payload* out_view, int64_t n_strs, uint32_t* offsets, char* data, char sep);
3948
const char* get_c_str(std::string* s);
4049
const char* get_char_ptr(char c);
4150
void* str_concat(std::string* s1, std::string* s2);
@@ -103,6 +112,10 @@ PyMODINIT_FUNC PyInit_hstr_ext(void) {
103112
PyLong_FromVoidPtr((void*)(&dtor_string)));
104113
PyObject_SetAttrString(m, "dtor_string_array",
105114
PyLong_FromVoidPtr((void*)(&dtor_string_array)));
115+
PyObject_SetAttrString(m, "dtor_str_arr_split_view",
116+
PyLong_FromVoidPtr((void*)(&dtor_str_arr_split_view)));
117+
PyObject_SetAttrString(m, "str_arr_split_view_impl",
118+
PyLong_FromVoidPtr((void*)(&str_arr_split_view_impl)));
106119
PyObject_SetAttrString(m, "get_c_str",
107120
PyLong_FromVoidPtr((void*)(&get_c_str)));
108121
PyObject_SetAttrString(m, "get_char_ptr",
@@ -225,6 +238,65 @@ void dtor_string_array(str_arr_payload* in_str_arr, int64_t size, void* in)
225238
return;
226239
}
227240

241+
// Destructor callback for a split-view payload (exported from the module
// init as "dtor_str_arr_split_view"). Frees the two offset arrays owned by
// the payload; the `size` and `in` arguments are not used.
void dtor_str_arr_split_view(str_arr_split_view_payload* in_str_arr, int64_t size, void* in)
{
    uint32_t* const index_offs = in_str_arr->index_offsets;
    uint32_t* const data_offs = in_str_arr->data_offsets;

    delete[] index_offs;
    delete[] data_offs;
    // The payload has no null_bitmap member yet, so there is nothing else
    // to release here.
}
251+
252+
// Build the split view of a string array for a single-character separator.
//
// Inputs:
//   n_strs  - number of strings in the array
//   offsets - n_strs + 1 character offsets; string i occupies
//             data[offsets[i] .. offsets[i+1])
//   data    - concatenated characters of all strings
//   sep     - separator character to split on
//
// Outputs (written into out_view, both arrays allocated with new[] and
// released by dtor_str_arr_split_view):
//   data_offsets  - starts with one leading entry of UINT32_MAX (the value the
//                   original `push_back(-1)` produced; presumably "one before
//                   the first character" - confirm against split_impl.py),
//                   followed, per string, by the position of every separator
//                   in that string and then the string's end offset.
//   index_offsets - n_strs + 1 entries; index_offsets[0] is 0 and
//                   index_offsets[i+1] is the number of data_offsets entries
//                   recorded once string i has been processed.
//
// An empty array (n_strs <= 0) yields index_offsets = {0} and data_offsets
// holding only the leading entry; `offsets` is not dereferenced in that case.
void str_arr_split_view_impl(str_arr_split_view_payload* out_view, int64_t n_strs, uint32_t* offsets, char* data, char sep)
{
    // Treat a negative count like an empty array so the allocation below is
    // always well-formed and no offsets are read.
    const int64_t n_valid = n_strs > 0 ? n_strs : 0;

    uint32_t* index_offsets = new uint32_t[n_valid + 1];
    std::vector<uint32_t> data_offs;

    // Explicit form of the original implicit -1 -> uint32_t conversion.
    data_offs.push_back(static_cast<uint32_t>(-1));
    index_offsets[0] = 0;

    // Scan each string separately. Offsets are unsigned 32-bit values, so the
    // character index uses the same type to avoid signed/unsigned comparison
    // problems on large buffers.
    for (int64_t str_ind = 0; str_ind < n_valid; str_ind++)
    {
        const uint32_t str_end = offsets[str_ind + 1];
        for (uint32_t data_ind = offsets[str_ind]; data_ind < str_end; data_ind++)
        {
            if (data[data_ind] == sep)
                data_offs.push_back(data_ind);
        }
        // string has finished: record its end and the running entry count
        data_offs.push_back(str_end);
        index_offsets[str_ind + 1] = static_cast<uint32_t>(data_offs.size());
    }

    out_view->index_offsets = index_offsets;
    out_view->data_offsets = new uint32_t[data_offs.size()];
    // TODO: avoid copy
    std::copy(data_offs.cbegin(), data_offs.cend(), out_view->data_offsets);
}
299+
228300
const char* get_c_str(std::string* s)
229301
{
230302
// printf("in get %s\n", s->c_str());
@@ -507,7 +579,7 @@ void string_array_from_sequence(PyObject * obj, int64_t * no_strings, uint32_t *
507579
PyGILState_Release(gilstate);
508580
return;
509581
}
510-
582+
511583
*no_strings = -1;
512584
*offset_table = NULL;
513585
*buffer = NULL;

hpat/hiframes/boxing.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from hpat.hiframes.pd_categorical_ext import (PDCategoricalDtype,
2323
box_categorical_array, unbox_categorical_array)
2424
from hpat.hiframes.pd_series_ext import SeriesType, arr_to_series_type
25+
from hpat.hiframes.split_impl import (string_array_split_view_type,
26+
box_str_arr_split_view)
2527

2628
from .. import hstr_ext
2729
import llvmlite.binding as ll
@@ -318,6 +320,8 @@ def _box_series_data(dtype, data_typ, val, c):
318320
arr = box_datetime_date_array(data_typ, val, c)
319321
elif isinstance(dtype, PDCategoricalDtype):
320322
arr = box_categorical_array(data_typ, val, c)
323+
elif data_typ == string_array_split_view_type:
324+
arr = box_str_arr_split_view(data_typ, val, c)
321325
elif dtype == types.List(string_type):
322326
arr = box_list(list_string_array_type, val, c)
323327
else:

hpat/hiframes/hiframes_typed.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from hpat.io.pio_api import h5dataset_type
3535
from hpat.hiframes.rolling import get_rolling_setup_args
3636
from hpat.hiframes.aggregate import Aggregate
37-
from hpat.hiframes import series_kernels
37+
from hpat.hiframes import series_kernels, split_impl
3838
from hpat.hiframes.series_kernels import series_replace_funcs
3939

4040

@@ -1687,6 +1687,7 @@ def _run_series_str_replace(self, assign, lhs, arr, rhs, nodes):
16871687

16881688
def _run_series_str_split(self, assign, lhs, arr, rhs, nodes):
16891689
sep = rhs.args[0] # TODO: support default whitespace separator
1690+
sep_typ = self.typemap[sep.name]
16901691

16911692
def _str_split_impl(str_arr, sep):
16921693
numba.parfor.init_prange()
@@ -1698,6 +1699,13 @@ def _str_split_impl(str_arr, sep):
16981699

16991700
return hpat.hiframes.api.init_series(out_arr)
17001701

1702+
1703+
if isinstance(sep_typ, types.StringLiteral) and len(sep_typ.literal_value) == 1:
1704+
def _str_split_impl(str_arr, sep):
1705+
out_arr = hpat.hiframes.split_impl.compute_split_view(
1706+
str_arr, sep)
1707+
return hpat.hiframes.api.init_series(out_arr)
1708+
17011709
return self._replace_func(_str_split_impl, [arr, sep], pre_nodes=nodes)
17021710

17031711
def _run_series_str_get(self, assign, lhs, arr, rhs, nodes):

hpat/hiframes/pd_series_ext.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
CategoricalArray)
2222
from hpat.hiframes.rolling import supported_rolling_funcs
2323
import datetime
24-
24+
from hpat.hiframes.split_impl import string_array_split_view_type
2525

2626
class SeriesType(types.IterableType):
2727
"""Temporary type class for Series objects.
@@ -562,7 +562,11 @@ def resolve_replace(self, ary, args, kws):
562562

563563
@bound_function("strmethod.split")
564564
def resolve_split(self, ary, args, kws):
565-
return signature(SeriesType(types.List(string_type)), *args)
565+
out = SeriesType(types.List(string_type))
566+
if (len(args) == 1 and isinstance(args[0], types.StringLiteral)
567+
and len(args[0].literal_value) == 1):
568+
out = SeriesType(types.List(string_type), string_array_split_view_type)
569+
return signature(out, *args)
566570

567571
@bound_function("strmethod.get")
568572
def resolve_get(self, ary, args, kws):

0 commit comments

Comments
 (0)