Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 79c4f2a

Browse files
author
Ehsan Totoni
committed
str to float cast
1 parent 9c5d029 commit 79c4f2a

3 files changed

Lines changed: 31 additions & 1 deletion

File tree

hpat/_str_ext.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ bool str_equal(std::string* s1, std::string* s2);
1111
void* str_split(std::string* str, std::string* sep, int64_t *size);
1212
void* str_substr_int(std::string* str, int64_t index);
1313
int64_t str_to_int64(std::string* str);
14+
double str_to_float64(std::string* str);
1415

1516
PyMODINIT_FUNC PyInit_hstr_ext(void) {
1617
PyObject *m;
@@ -36,6 +37,8 @@ PyMODINIT_FUNC PyInit_hstr_ext(void) {
3637
PyLong_FromVoidPtr((void*)(&str_substr_int)));
3738
PyObject_SetAttrString(m, "str_to_int64",
3839
PyLong_FromVoidPtr((void*)(&str_to_int64)));
40+
PyObject_SetAttrString(m, "str_to_float64",
41+
PyLong_FromVoidPtr((void*)(&str_to_float64)));
3942
return m;
4043
}
4144

@@ -103,3 +106,8 @@ int64_t str_to_int64(std::string* str)
103106
{
104107
return std::stoll(*str);
105108
}
109+
110+
double str_to_float64(std::string* str)
111+
{
112+
return std::stod(*str);
113+
}

hpat/str_ext.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,14 @@ def generic(self, args, kws):
6767
if isinstance(arg, StringType):
6868
return signature(types.intp, arg)
6969

70+
@infer_global(float)
71+
class StrToFloat(AbstractTemplate):
72+
def generic(self, args, kws):
73+
assert not kws
74+
[arg] = args
75+
if isinstance(arg, StringType):
76+
return signature(types.float64, arg)
77+
7078
import hstr_ext
7179
ll.add_symbol('init_string', hstr_ext.init_string)
7280
ll.add_symbol('init_string_const', hstr_ext.init_string_const)
@@ -76,6 +84,7 @@ def generic(self, args, kws):
7684
ll.add_symbol('str_split', hstr_ext.str_split)
7785
ll.add_symbol('str_substr_int', hstr_ext.str_substr_int)
7886
ll.add_symbol('str_to_int64', hstr_ext.str_to_int64)
87+
ll.add_symbol('str_to_float64', hstr_ext.str_to_float64)
7988

8089
@unbox(StringType)
8190
def unbox_string(typ, obj, c):
@@ -162,7 +171,13 @@ def getitem_string(context, builder, sig, args):
162171
return (builder.call(fn, args))
163172

164173
@lower_cast(StringType, types.int64)
165-
def dict_empty(context, builder, fromty, toty, val):
174+
def cast_str_to_int64(context, builder, fromty, toty, val):
166175
fnty = lir.FunctionType(lir.IntType(64), [lir.IntType(8).as_pointer()])
167176
fn = builder.module.get_or_insert_function(fnty, name="str_to_int64")
168177
return builder.call(fn, (val,))
178+
179+
@lower_cast(StringType, types.float64)
180+
def cast_str_to_float64(context, builder, fromty, toty, val):
181+
fnty = lir.FunctionType(lir.DoubleType(), [lir.IntType(8).as_pointer()])
182+
fn = builder.module.get_or_insert_function(fnty, name="str_to_float64")
183+
return builder.call(fn, (val,))

hpat/tests/test_strings.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,12 @@ def test_impl(_str):
5858
arg = '12'
5959
self.assertEqual(hpat_func(arg), test_impl(arg))
6060

61+
def test_string_float_cast(self):
62+
def test_impl(_str):
63+
return float(_str)
64+
hpat_func = hpat.jit(test_impl)
65+
arg = '12.2'
66+
self.assertEqual(hpat_func(arg), test_impl(arg))
67+
6168
if __name__ == "__main__":
6269
unittest.main()

0 commit comments

Comments
 (0)