Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit f783de8

Browse files
author
Ehsan Totoni
committed
str equal comparison
1 parent 24b176d commit f783de8

2 files changed

Lines changed: 37 additions & 0 deletions

File tree

hpat/_str_ext.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ void* init_string(char*, int64_t);
66
void* init_string_const(char* in_str);
77
const char* get_c_str(std::string* s);
88
void* str_concat(std::string* s1, std::string* s2);
9+
bool str_equal(std::string* s1, std::string* s2);
910

1011
PyMODINIT_FUNC PyInit_hstr_ext(void) {
1112
PyObject *m;
@@ -23,6 +24,8 @@ PyMODINIT_FUNC PyInit_hstr_ext(void) {
2324
PyLong_FromVoidPtr((void*)(&get_c_str)));
2425
PyObject_SetAttrString(m, "str_concat",
2526
PyLong_FromVoidPtr((void*)(&str_concat)));
27+
PyObject_SetAttrString(m, "str_equal",
28+
PyLong_FromVoidPtr((void*)(&str_equal)));
2629
return m;
2730
}
2831

@@ -50,3 +53,9 @@ void* str_concat(std::string* s1, std::string* s2)
5053
std::string* res = new std::string((*s1)+(*s2));
5154
return res;
5255
}
56+
57+
bool str_equal(std::string* s1, std::string* s2)
58+
{
59+
// printf("in str_equal %s %s\n", s1->c_str(), s2->c_str());
60+
return s1->compare(*s2)==0;
61+
}

hpat/str_ext.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,25 @@ class StringAdd(ConcreteTemplate):
2525
key = "+"
2626
cases = [signature(string_type, string_type, string_type)]
2727

28+
@infer
29+
class StringOpEq(AbstractTemplate):
30+
key = '=='
31+
def generic(self, args, kws):
32+
assert not kws
33+
(arg1, arg2) = args
34+
if isinstance(arg1, StringType) and isinstance(arg2, StringType):
35+
return signature(types.boolean, arg1, arg2)
36+
37+
@infer
38+
class StringOpNotEq(StringOpEq):
39+
key = '!='
40+
2841
import hstr_ext
2942
ll.add_symbol('init_string', hstr_ext.init_string)
3043
ll.add_symbol('init_string_const', hstr_ext.init_string_const)
3144
ll.add_symbol('get_c_str', hstr_ext.get_c_str)
3245
ll.add_symbol('str_concat', hstr_ext.str_concat)
46+
ll.add_symbol('str_equal', hstr_ext.str_equal)
3347

3448
@unbox(StringType)
3549
def unbox_string(typ, obj, c):
@@ -71,3 +85,17 @@ def impl_string_concat(context, builder, sig, args):
7185
[lir.IntType(8).as_pointer(), lir.IntType(8).as_pointer()])
7286
fn = builder.module.get_or_insert_function(fnty, name="str_concat")
7387
return builder.call(fn, args)
88+
89+
@lower_builtin('==', string_type, string_type)
90+
def string_eq_impl(context, builder, sig, args):
91+
fnty = lir.FunctionType(lir.IntType(1),
92+
[lir.IntType(8).as_pointer(), lir.IntType(8).as_pointer()])
93+
fn = builder.module.get_or_insert_function(fnty, name="str_equal")
94+
return builder.call(fn, args)
95+
96+
@lower_builtin('!=', string_type, string_type)
97+
def string_neq_impl(context, builder, sig, args):
98+
fnty = lir.FunctionType(lir.IntType(1),
99+
[lir.IntType(8).as_pointer(), lir.IntType(8).as_pointer()])
100+
fn = builder.module.get_or_insert_function(fnty, name="str_equal")
101+
return builder.not_(builder.call(fn, args))

0 commit comments

Comments
 (0)