Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit fbbde97

Browse files
author
Ehsan Totoni
committed
support str split
1 parent ee62d10 commit fbbde97

2 files changed

Lines changed: 62 additions & 2 deletions

File tree

hpat/_str_ext.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
#include <Python.h>
22
#include <string>
33
#include <iostream>
4+
#include <vector>
45

56
void* init_string(char*, int64_t);
67
void* init_string_const(char* in_str);
78
const char* get_c_str(std::string* s);
89
void* str_concat(std::string* s1, std::string* s2);
910
bool str_equal(std::string* s1, std::string* s2);
11+
void* str_split(std::string* str, std::string* sep, int64_t *size);
1012

1113
PyMODINIT_FUNC PyInit_hstr_ext(void) {
1214
PyObject *m;
@@ -26,6 +28,8 @@ PyMODINIT_FUNC PyInit_hstr_ext(void) {
2628
PyLong_FromVoidPtr((void*)(&str_concat)));
2729
PyObject_SetAttrString(m, "str_equal",
2830
PyLong_FromVoidPtr((void*)(&str_equal)));
31+
PyObject_SetAttrString(m, "str_split",
32+
PyLong_FromVoidPtr((void*)(&str_split)));
2933
return m;
3034
}
3135

@@ -59,3 +63,27 @@ bool str_equal(std::string* s1, std::string* s2)
5963
// printf("in str_equal %s %s\n", s1->c_str(), s2->c_str());
6064
return s1->compare(*s2)==0;
6165
}
66+
67+
/**
 * Split *str on every occurrence of *sep.
 *
 * @param str   String to split; it is only read.
 * @param sep   Separator to split on; may be longer than one character.
 * @param size  Out-parameter that receives the number of tokens produced
 *              (always at least 1).
 * @return      A new[]-allocated array of *size pointers, each pointing to a
 *              new-allocated std::string token, returned as void*. Nothing in
 *              this function frees the array or the tokens; that is left to
 *              the caller.
 *
 * An empty separator matches at every position and can never advance the
 * scan, so in that case the whole input is returned as a single token.
 */
void* str_split(std::string* str, std::string* sep, int64_t *size)
{
    std::vector<std::string*> res;

    const size_t sep_len = sep->size();
    size_t last = 0;
    if (sep_len != 0) {
        size_t next = 0;
        while ((next = str->find(*sep, last)) != std::string::npos) {
            res.push_back(new std::string(str->substr(last, next - last)));
            // Step over the whole separator, not just its first character,
            // so multi-character separators do not leak into the next token.
            last = next + sep_len;
        }
    }
    // Whatever follows the final separator (possibly empty) is the last token.
    res.push_back(new std::string(str->substr(last)));
    *size = static_cast<int64_t>(res.size());

    // TODO: avoid extra copy
    void** out = new void*[res.size()];
    for (size_t i = 0; i < res.size(); ++i)
        out[i] = res[i];
    return out;
}

hpat/str_ext.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import numba
12
from numba.extending import (box, unbox, typeof_impl, register_model, models,
23
NativeValue, lower_builtin)
3-
from numba.targets.imputils import lower_constant
4+
from numba.targets.imputils import lower_constant, impl_ret_new_ref
45
from numba import types, typing
5-
from numba.typing.templates import (signature, AbstractTemplate, infer,
6+
from numba.typing.templates import (signature, AbstractTemplate, infer, infer_getattr,
67
ConcreteTemplate, AttributeTemplate, bound_function, infer_global)
78
from numba import cgutils
89
from llvmlite import ir as lir
@@ -38,12 +39,23 @@ def generic(self, args, kws):
3839
class StringOpNotEq(StringOpEq):
3940
key = '!='
4041

42+
@infer_getattr
class StringAttribute(AttributeTemplate):
    """Attribute typing template registered for ``StringType`` values."""
    key = StringType

    @bound_function("str.split")
    def resolve_split(self, dict, args, kws):
        """Type ``s.split(sep)``: exactly one positional argument, no keywords.

        The call is typed as returning a list of strings and taking the
        argument types unchanged.
        """
        assert not kws
        assert len(args) == 1
        list_of_strings = types.List(string_type)
        return signature(list_of_strings, *args)
51+
4152
import hstr_ext
4253
ll.add_symbol('init_string', hstr_ext.init_string)
4354
ll.add_symbol('init_string_const', hstr_ext.init_string_const)
4455
ll.add_symbol('get_c_str', hstr_ext.get_c_str)
4556
ll.add_symbol('str_concat', hstr_ext.str_concat)
4657
ll.add_symbol('str_equal', hstr_ext.str_equal)
58+
ll.add_symbol('str_split', hstr_ext.str_split)
4759

4860
@unbox(StringType)
4961
def unbox_string(typ, obj, c):
@@ -99,3 +111,23 @@ def string_neq_impl(context, builder, sig, args):
99111
[lir.IntType(8).as_pointer(), lir.IntType(8).as_pointer()])
100112
fn = builder.module.get_or_insert_function(fnty, name="str_equal")
101113
return builder.not_(builder.call(fn, args))
114+
115+
@lower_builtin("str.split", string_type, string_type)
def string_split_impl(context, builder, sig, args):
    """Lower ``str.split(sep)`` through the native ``str_split`` symbol.

    Emits a call to ``str_split(str, sep, &count)``, then allocates a list of
    ``count`` items and copies each pointer from the returned array into it.
    The array returned by the native call is not released here (see TODOs).
    """
    char_ptr = lir.IntType(8).as_pointer()
    int64 = lir.IntType(64)

    # stack slot the native call writes the number of tokens into
    count_slot = cgutils.alloca_once(builder, int64)

    # native signature: (input str, sep, size pointer) -> array of pointers
    split_fnty = lir.FunctionType(char_ptr.as_pointer(),
                                  [char_ptr, char_ptr, int64.as_pointer()])
    split_fn = builder.module.get_or_insert_function(split_fnty,
                                                     name="str_split")
    tokens = builder.call(split_fn, args+[count_slot])
    n_tokens = builder.load(count_slot)

    # TODO: use ptr instead of allocating and copying, use NRT_MemInfo_new
    # TODO: deallocate ptr
    out_list = numba.targets.listobj.ListInstance.allocate(
        context, builder, sig.return_type, n_tokens)
    out_list.size = n_tokens

    # copy the i-th pointer of the native array into the i-th list slot
    with cgutils.for_range(builder, n_tokens) as loop:
        slot = cgutils.gep_inbounds(builder, tokens, loop.index)
        out_list.setitem(loop.index, builder.load(slot))

    return impl_ret_new_ref(context, builder, sig.return_type, out_list.value)

0 commit comments

Comments
 (0)