Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 79fb01b

Browse files
Fix segfault in NRT_dealloc for str_arr getitem (#955)
* Fix segfault in NRT_dealloc for str_arr getitem
  Motivation: the legacy SDC implementation copied Numba internals (the MemInfo struct and API functions), which may cause segfaults if they are changed on the Numba side (as is done in IntelPython/Numba, where external_allocator is added). This PR removes the duplication of Numba internals and moves the hstr extension to the NRT API functions.
* Updating conda-recipe with build requirement on Numba
1 parent a91d01c commit 79fb01b

5 files changed

Lines changed: 44 additions & 202 deletions

File tree

conda-recipe/meta.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ requirements:
2020
- {{ compiler('cxx') }} # [not osx]
2121
- wheel
2222
- python
23+
- numba {{ NUMBA_VERSION }}
2324

2425
host:
2526
- python

sdc/_meminfo.h

Lines changed: 0 additions & 182 deletions
This file was deleted.

sdc/_str_decode.cpp

Lines changed: 35 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -26,8 +26,9 @@
2626

2727
#include <Python.h>
2828
#include <iostream>
29+
#include <stdlib.h>
2930

30-
#include "_meminfo.h"
31+
#include "numba/core/runtime/nrt_external.h"
3132

3233
#ifndef Py_UNREACHABLE
3334
#define Py_UNREACHABLE() abort()
@@ -37,6 +38,7 @@
3738

3839
typedef struct
3940
{
41+
NRT_api_functions* nrt;
4042
NRT_MemInfo* buffer;
4143
void* data;
4244
enum PyUnicode_Kind kind;
@@ -120,20 +122,28 @@ void _C_UnicodeWriter_Init(_C_UnicodeWriter* writer)
120122
#include "stringlib/undef.h"
121123

122124
static inline int _C_UnicodeWriter_WriteCharInline(_C_UnicodeWriter* writer, Py_UCS4 ch);
123-
static int _copy_characters(NRT_MemInfo* to,
125+
static int _copy_characters(NRT_api_functions* nrt,
126+
NRT_MemInfo* to,
124127
Py_ssize_t to_start,
125128
NRT_MemInfo* from,
126129
Py_ssize_t from_start,
127130
Py_ssize_t how_many,
128131
unsigned int from_kind,
129132
unsigned int to_kind);
130133

134+
135+
static void str_data_dtor(void* data_ptr)
136+
{
137+
free(data_ptr);
138+
}
139+
131140
// similar to PyUnicode_New()
132141
NRT_MemInfo* alloc_writer(_C_UnicodeWriter* writer, Py_ssize_t newlen, Py_UCS4 maxchar)
133142
{
134143
enum PyUnicode_Kind kind;
135144
int is_ascii = 0;
136145
Py_ssize_t char_size;
146+
auto nrt = writer->nrt;
137147

138148
if (maxchar < 128)
139149
{
@@ -161,20 +171,22 @@ NRT_MemInfo* alloc_writer(_C_UnicodeWriter* writer, Py_ssize_t newlen, Py_UCS4 m
161171
kind = PyUnicode_4BYTE_KIND;
162172
char_size = 4;
163173
}
164-
NRT_MemInfo* newbuffer = NRT_MemInfo_alloc_safe((newlen + 1) * char_size);
165-
if (newbuffer == NULL)
174+
175+
char* str_data = (char*)malloc((newlen + 1) * char_size);
176+
if (str_data == NULL)
166177
{
167178
return NULL;
168179
}
169180

181+
auto newbuffer = nrt->manage_memory(str_data, str_data_dtor);
170182
if (writer->buffer != NULL)
171183
{
172-
_copy_characters(newbuffer, 0, writer->buffer, 0, writer->pos, writer->kind, kind);
173-
NRT_MemInfo_call_dtor(writer->buffer);
184+
_copy_characters(nrt, newbuffer, 0, writer->buffer, 0, writer->pos, writer->kind, kind);
185+
nrt->release(writer->buffer);
174186
}
175187
writer->buffer = newbuffer;
176188
writer->maxchar = KIND_MAX_CHAR_VALUE(kind);
177-
writer->data = writer->buffer->data;
189+
writer->data = nrt->get_data(writer->buffer);
178190

179191
if (!writer->readonly)
180192
{
@@ -356,19 +368,22 @@ static Py_ssize_t ascii_decode(const char* start, const char* end, Py_UCS1* dest
356368
return p - start;
357369
}
358370

371+
359372
// ported from CPython PyUnicode_DecodeUTF8Stateful: https://github.com/python/cpython/blob/31e8d69bfe7cf5d4ffe0967cb225d2a8a229cc97/Objects/unicodeobject.c#L4813
360-
void decode_utf8(const char* s, Py_ssize_t size, int* kind, int* is_ascii, int* length, NRT_MemInfo** meminfo)
373+
void decode_utf8(const char* s, Py_ssize_t size, int* kind, int* is_ascii, int* length, NRT_MemInfo** meminfo, void* nrt_table)
361374
{
362375
_C_UnicodeWriter writer;
363376
const char* end = s + size;
377+
auto nrt = (NRT_api_functions*)nrt_table;
364378

365379
const char* errmsg = "";
366380
*is_ascii = 0;
367381

368382
if (size == 0)
369383
{
370-
(*meminfo) = NRT_MemInfo_alloc_safe(1);
371-
((char*)((*meminfo)->data))[0] = 0;
384+
char* str_data = (char*)malloc(1);
385+
(*meminfo) = nrt->manage_memory(str_data, str_data_dtor);
386+
((char*)(nrt->get_data(*meminfo)))[0] = 0;
372387
*kind = PyUnicode_1BYTE_KIND;
373388
*is_ascii = 1;
374389
*length = 0;
@@ -379,9 +394,10 @@ void decode_utf8(const char* s, Py_ssize_t size, int* kind, int* is_ascii, int*
379394
if (size == 1 && (unsigned char)s[0] < 128)
380395
{
381396
// TODO interning
382-
(*meminfo) = NRT_MemInfo_alloc_safe(2);
383-
((char*)((*meminfo)->data))[0] = s[0];
384-
((char*)((*meminfo)->data))[1] = 0;
397+
char* str_data = (char*)malloc(2);
398+
(*meminfo) = nrt->manage_memory(str_data, str_data_dtor);
399+
((char*)(nrt->get_data(*meminfo)))[0] = s[0];
400+
((char*)(nrt->get_data(*meminfo)))[1] = 0;
385401
*kind = PyUnicode_1BYTE_KIND;
386402
*is_ascii = 1;
387403
*length = 1;
@@ -390,6 +406,7 @@ void decode_utf8(const char* s, Py_ssize_t size, int* kind, int* is_ascii, int*
390406

391407
_C_UnicodeWriter_Init(&writer);
392408
writer.min_length = size;
409+
writer.nrt = nrt;
393410
if (_C_UnicodeWriter_Prepare(&writer, writer.min_length, 127) == -1)
394411
goto onError;
395412

@@ -469,7 +486,7 @@ void decode_utf8(const char* s, Py_ssize_t size, int* kind, int* is_ascii, int*
469486

470487
onError:
471488
std::cerr << "utf8 decode error:" << errmsg << std::endl;
472-
NRT_MemInfo_call_dtor(writer.buffer);
489+
nrt->release(*meminfo);
473490
return;
474491
}
475492

@@ -499,7 +516,8 @@ void decode_utf8(const char* s, Py_ssize_t size, int* kind, int* is_ascii, int*
499516
*_to++ = (to_type)*_iter++; \
500517
} while (0)
501518

502-
static int _copy_characters(NRT_MemInfo* to,
519+
static int _copy_characters(NRT_api_functions* nrt,
520+
NRT_MemInfo* to,
503521
Py_ssize_t to_start,
504522
NRT_MemInfo* from,
505523
Py_ssize_t from_start,
@@ -516,8 +534,8 @@ static int _copy_characters(NRT_MemInfo* to,
516534
if (how_many == 0)
517535
return 0;
518536

519-
from_data = from->data;
520-
to_data = to->data;
537+
from_data = nrt->get_data(from);
538+
to_data = nrt->get_data(to);
521539

522540
if (from_kind == to_kind)
523541
{

sdc/str_arr_ext.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1026,6 +1026,7 @@ def str_arr_getitem_by_array_impl(A, arg):
10261026
def decode_utf8(typingctx, ptr_t, len_t=None):
10271027
def codegen(context, builder, sig, args):
10281028
ptr, length = args
1029+
nrt_table = context.nrt.get_nrt_api(builder)
10291030

10301031
# create str and call decode with internal pointers
10311032
uni_str = cgutils.create_struct_proxy(string_type)(context, builder)
@@ -1034,14 +1035,16 @@ def codegen(context, builder, sig, args):
10341035
lir.IntType(32).as_pointer(),
10351036
lir.IntType(32).as_pointer(),
10361037
lir.IntType(64).as_pointer(),
1037-
uni_str.meminfo.type.as_pointer()])
1038+
uni_str.meminfo.type.as_pointer(),
1039+
lir.IntType(8).as_pointer()])
10381040
fn_decode = builder.module.get_or_insert_function(
10391041
fnty, name="decode_utf8")
10401042
builder.call(fn_decode, [ptr, length,
10411043
uni_str._get_ptr_by_name('kind'),
10421044
uni_str._get_ptr_by_name('is_ascii'),
10431045
uni_str._get_ptr_by_name('length'),
1044-
uni_str._get_ptr_by_name('meminfo')])
1046+
uni_str._get_ptr_by_name('meminfo'),
1047+
nrt_table])
10451048
uni_str.hash = context.get_constant(_Py_hash_t, -1)
10461049
uni_str.data = context.nrt.meminfo_data(builder, uni_str.meminfo)
10471050
# Set parent to NULL

setup.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
import platform
3030
import os
3131
import sys
32+
import numba
3233
from docs.source.buildscripts.sdc_build_doc import SDCBuildDoc
3334

3435

@@ -185,14 +186,15 @@ def check_file_at_path(path2file):
185186
)
186187

187188
str_libs = np_compile_args['libraries']
189+
numba_include_path = numba.extending.include_path()
188190

189191
ext_str = Extension(name="sdc.hstr_ext",
190192
sources=["sdc/_str_ext.cpp"],
191193
libraries=str_libs,
192194
define_macros=np_compile_args['define_macros'],
193195
extra_compile_args=eca,
194196
extra_link_args=ela,
195-
include_dirs=np_compile_args['include_dirs'] + ind,
197+
include_dirs=np_compile_args['include_dirs'] + ind + [numba_include_path],
196198
library_dirs=np_compile_args['library_dirs'] + lid,
197199
)
198200

0 commit comments

Comments
 (0)