|
| 1 | +import numba |
1 | 2 | from numba.extending import (box, unbox, typeof_impl, register_model, models, |
2 | 3 | NativeValue, lower_builtin) |
3 | | -from numba.targets.imputils import lower_constant |
| 4 | +from numba.targets.imputils import lower_constant, impl_ret_new_ref |
4 | 5 | from numba import types, typing |
5 | | -from numba.typing.templates import (signature, AbstractTemplate, infer, |
| 6 | +from numba.typing.templates import (signature, AbstractTemplate, infer, infer_getattr, |
6 | 7 | ConcreteTemplate, AttributeTemplate, bound_function, infer_global) |
7 | 8 | from numba import cgutils |
8 | 9 | from llvmlite import ir as lir |
@@ -38,12 +39,23 @@ def generic(self, args, kws): |
38 | 39 | class StringOpNotEq(StringOpEq): |
39 | 40 | key = '!=' |
40 | 41 |
|
| 42 | +@infer_getattr |
| 43 | +class StringAttribute(AttributeTemplate): |
| 44 | + key = StringType |
| 45 | + |
| 46 | + @bound_function("str.split") |
| 47 | + def resolve_split(self, dict, args, kws): |
| 48 | + assert not kws |
| 49 | + assert len(args) == 1 |
| 50 | + return signature(types.List(string_type), *args) |
| 51 | + |
41 | 52 | import hstr_ext |
42 | 53 | ll.add_symbol('init_string', hstr_ext.init_string) |
43 | 54 | ll.add_symbol('init_string_const', hstr_ext.init_string_const) |
44 | 55 | ll.add_symbol('get_c_str', hstr_ext.get_c_str) |
45 | 56 | ll.add_symbol('str_concat', hstr_ext.str_concat) |
46 | 57 | ll.add_symbol('str_equal', hstr_ext.str_equal) |
| 58 | +ll.add_symbol('str_split', hstr_ext.str_split) |
47 | 59 |
|
48 | 60 | @unbox(StringType) |
49 | 61 | def unbox_string(typ, obj, c): |
@@ -99,3 +111,23 @@ def string_neq_impl(context, builder, sig, args): |
99 | 111 | [lir.IntType(8).as_pointer(), lir.IntType(8).as_pointer()]) |
100 | 112 | fn = builder.module.get_or_insert_function(fnty, name="str_equal") |
101 | 113 | return builder.not_(builder.call(fn, args)) |
| 114 | + |
| 115 | +@lower_builtin("str.split", string_type, string_type) |
| 116 | +def string_split_impl(context, builder, sig, args): |
| 117 | + nitems = cgutils.alloca_once(builder, lir.IntType(64)) |
| 118 | + # input str, sep, size pointer |
| 119 | + fnty = lir.FunctionType(lir.IntType(8).as_pointer().as_pointer(), |
| 120 | + [lir.IntType(8).as_pointer(), lir.IntType(8).as_pointer(), |
| 121 | + lir.IntType(64).as_pointer()]) |
| 122 | + fn = builder.module.get_or_insert_function(fnty, name="str_split") |
| 123 | + ptr = builder.call(fn, args+[nitems]) |
| 124 | + size = builder.load(nitems) |
| 125 | + # TODO: use ptr instead of allocating and copying, use NRT_MemInfo_new |
| 126 | + # TODO: deallocate ptr |
| 127 | + _list = numba.targets.listobj.ListInstance.allocate(context, builder, |
| 128 | + sig.return_type, size) |
| 129 | + _list.size = size |
| 130 | + with cgutils.for_range(builder, size) as loop: |
| 131 | + value = builder.load(cgutils.gep_inbounds(builder, ptr, loop.index)) |
| 132 | + _list.setitem(loop.index, value) |
| 133 | + return impl_ret_new_ref(context, builder, sig.return_type, _list.value) |
0 commit comments