Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 4926b0d

Browse files
Fixing bug in stability of mergesort impl for StringArray (#961)
Motivation: for the StringArray type, the legacy stable-sort implementation computed the result for ascending=False by reversing the result of an argsort done with ascending=True, which produces the wrong order within groups of elements that have the same value. The implemented solution adds a new function argument 'ascending' and uses it when calling the native function implementation via a serial stable_sort.
1 parent 287d783 commit 4926b0d

13 files changed

Lines changed: 239 additions & 96 deletions

sdc/_str_ext.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <string>
3232
#include <vector>
3333
#include <cmath>
34+
#include <algorithm>
3435

3536
#include "_str_decode.cpp"
3637

@@ -129,6 +130,7 @@ extern "C"
129130
npy_intp array_size(PyArrayObject* arr);
130131
void* array_getptr1(PyArrayObject* arr, npy_intp ind);
131132
void array_setitem(PyArrayObject* arr, char* p, PyObject* s);
133+
void stable_argsort(char* data_ptr, uint32_t* in_offsets, int64_t len, int8_t ascending, uint64_t* result);
132134

133135
PyMODINIT_FUNC PyInit_hstr_ext(void)
134136
{
@@ -201,6 +203,7 @@ extern "C"
201203
PyObject_SetAttrString(m, "array_setitem", PyLong_FromVoidPtr((void*)(&array_setitem)));
202204
PyObject_SetAttrString(m, "decode_utf8", PyLong_FromVoidPtr((void*)(&decode_utf8)));
203205
PyObject_SetAttrString(m, "get_utf8_size", PyLong_FromVoidPtr((void*)(&get_utf8_size)));
206+
PyObject_SetAttrString(m, "stable_argsort", PyLong_FromVoidPtr((void*)(&stable_argsort)));
204207
return m;
205208
}
206209

@@ -871,4 +874,35 @@ extern "C"
871874
return;
872875
}
873876

877+
// Stable argsort over a packed array of strings.
//
// String i occupies the bytes [in_offsets[i], in_offsets[i + 1]) of data_ptr,
// so in_offsets is read at indices 0..len. On return result[k] holds the
// position (in the input) of the string that is k-th in sorted order.
// Strings that compare equal keep their original relative order for both
// ascending (non-zero) and descending (zero) sorts.
//
// A non-positive len leaves result untouched.
void stable_argsort(char* data_ptr, uint32_t* in_offsets, int64_t len, int8_t ascending, uint64_t* result)
{
    // Nothing to sort. The guard also stops a negative len from being
    // converted to a huge unsigned element count further down.
    if (len <= 0)
        return;

    // Sort the indices directly in the output buffer; the strings themselves
    // are compared in place, so no per-element std::string copy is needed.
    for (int64_t i = 0; i < len; ++i)
        result[i] = static_cast<uint64_t>(i);

    // Three-way comparison of two strings identified by index. Ordering is the
    // same as std::string comparison: byte-wise over the common prefix, then
    // the shorter string sorts first.
    const auto compare = [data_ptr, in_offsets](uint64_t left, uint64_t right) -> int {
        const uint32_t left_size = in_offsets[left + 1] - in_offsets[left];
        const uint32_t right_size = in_offsets[right + 1] - in_offsets[right];
        const int order = std::char_traits<char>::compare(data_ptr + in_offsets[left],
                                                          data_ptr + in_offsets[right],
                                                          std::min(left_size, right_size));
        if (order != 0)
            return order;
        if (left_size == right_size)
            return 0;
        return left_size < right_size ? -1 : 1;
    };

    // Pick the direction once instead of branching inside every comparison.
    // Using strict '<' / '>' keeps equal strings in input order in both cases.
    if (ascending)
    {
        std::stable_sort(result, result + len,
                         [&compare](uint64_t left, uint64_t right) { return compare(left, right) < 0; });
    }
    else
    {
        std::stable_sort(result, result + len,
                         [&compare](uint64_t left, uint64_t right) { return compare(left, right) > 0; });
    }
}
906+
907+
874908
} // extern "C"

sdc/datatypes/common_functions.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
from sdc.str_arr_ext import (num_total_chars, append_string_array_to,
5353
str_arr_is_na, pre_alloc_string_array, str_arr_set_na, string_array_type,
5454
cp_str_list_to_array, create_str_arr_from_list, get_utf8_size,
55-
str_arr_set_na_by_mask)
55+
str_arr_set_na_by_mask, str_arr_stable_argosort)
5656
from sdc.utilities.prange_utils import parallel_chunks
5757
from sdc.utilities.utils import sdc_overload, sdc_register_jitable
5858
from sdc.utilities.sdc_typing_utils import (
@@ -518,41 +518,39 @@ def sdc_arrays_argsort(A, kind='quicksort'):
518518

519519

520520
@sdc_overload(sdc_arrays_argsort, jit_options={'parallel': False})
521-
def sdc_arrays_argsort_overload(A, kind='quicksort'):
521+
def sdc_arrays_argsort_overload(A, kind='quicksort', ascending=True):
522522
"""Function providing pandas argsort implementation for different 1D array types"""
523523

524524
# kind is not known at compile time, so get this function here and use in impl if needed
525525
quicksort_func = quicksort.make_jit_quicksort().run_quicksort
526526

527527
kind_is_default = isinstance(kind, str)
528528
if isinstance(A, types.Array):
529-
def _sdc_arrays_argsort_array_impl(A, kind='quicksort'):
529+
def _sdc_arrays_argsort_array_impl(A, kind='quicksort', ascending=True):
530530
_kind = 'quicksort' if kind_is_default == True else kind # noqa
531-
return numpy_like.argsort(A, kind=_kind)
531+
return numpy_like.argsort(A, kind=_kind, ascending=ascending)
532532

533533
return _sdc_arrays_argsort_array_impl
534534

535535
elif A == string_array_type:
536-
def _sdc_arrays_argsort_str_arr_impl(A, kind='quicksort'):
536+
def _sdc_arrays_argsort_str_arr_impl(A, kind='quicksort', ascending=True):
537537

538-
nan_mask = sdc.hiframes.api.get_nan_mask(A)
539-
idx = numpy.arange(len(A))
540-
old_nan_positions = idx[nan_mask]
541-
542-
data = A[~nan_mask]
543-
keys = idx[~nan_mask]
544538
if kind == 'quicksort':
545-
zipped = list(zip(list(data), list(keys)))
546-
zipped = quicksort_func(zipped)
547-
argsorted = [zipped[i][1] for i in numpy.arange(len(data))]
539+
indexes = numpy.arange(len(A))
540+
data_index_pairs = list(zip(list(A), list(indexes)))
541+
zipped = quicksort_func(data_index_pairs)
542+
argsorted = [zipped[i][1] for i in indexes]
543+
res = numpy.array(argsorted, dtype=numpy.int64)
544+
# for non-stable sort the order within groups does not matter
545+
# so just reverse the result when sorting in descending order
546+
if not ascending:
547+
res = res[::-1]
548548
elif kind == 'mergesort':
549-
sdc.hiframes.sort.local_sort((data, ), (keys, ))
550-
argsorted = list(keys)
549+
res = str_arr_stable_argosort(A, ascending=ascending)
551550
else:
552551
raise ValueError("Unrecognized kind of sort in sdc_arrays_argsort")
553552

554-
argsorted.extend(old_nan_positions)
555-
return numpy.asarray(argsorted, dtype=numpy.int32)
553+
return res
556554

557555
return _sdc_arrays_argsort_str_arr_impl
558556

sdc/datatypes/hpat_pandas_series_functions.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3950,11 +3950,9 @@ def _sdc_pandas_series_sort_values_impl(
39503950
good = ~data_nan_mask
39513951

39523952
if kind_is_none_or_default == True: # noqa
3953-
argsort_res = sdc_arrays_argsort(self._data[good], kind='quicksort')
3953+
argsort_res = sdc_arrays_argsort(self._data[good], kind='quicksort', ascending=ascending)
39543954
else:
3955-
argsort_res = sdc_arrays_argsort(self._data[good], kind=kind)
3956-
if not ascending:
3957-
argsort_res = argsort_res[::-1]
3955+
argsort_res = sdc_arrays_argsort(self._data[good], kind=kind, ascending=ascending)
39583956

39593957
idx = numpy.arange(len(self), dtype=numpy.int32)
39603958
sorted_index = numpy.empty(len(self), dtype=numpy.int32)

sdc/functions/numpy_like.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,7 +1225,7 @@ def sort_impl(a, axis=-1, kind=None, order=None):
12251225
return sort_impl
12261226

12271227

1228-
def argsort(a, axis=-1, kind=None, order=None):
1228+
def argsort(a, axis=-1, kind=None, order=None, ascending=True):
12291229
"""
12301230
Returns the indices that would sort an array.
12311231
@@ -1254,7 +1254,7 @@ def argsort(a, axis=-1, kind=None, order=None):
12541254

12551255

12561256
@sdc_overload(argsort)
1257-
def argsort_overload(a, axis=-1, kind=None, order=None):
1257+
def argsort_overload(a, axis=-1, kind=None, order=None, ascending=True):
12581258
_func_name = 'argsort'
12591259
ty_checker = TypeChecker(_func_name)
12601260

@@ -1266,15 +1266,15 @@ def argsort_overload(a, axis=-1, kind=None, order=None):
12661266
if not is_default(order, None):
12671267
raise TypingError(f'{_func_name} Unsupported parameter order')
12681268

1269-
def argsort_impl(a, axis=-1, kind=None, order=None):
1269+
def argsort_impl(a, axis=-1, kind=None, order=None, ascending=True):
12701270
_kind = 'quicksort'
12711271
if kind is not None:
12721272
_kind = kind
12731273

12741274
if _kind == 'quicksort':
1275-
return parallel_argsort(a)
1275+
return parallel_argsort(a, ascending)
12761276
elif _kind == 'mergesort':
1277-
return parallel_stable_argsort(a)
1277+
return parallel_stable_argsort(a, ascending)
12781278
else:
12791279
raise ValueError("Unsupported value of 'kind' parameter")
12801280

sdc/functions/sort.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def bind(sym, sig):
4747
parallel_sort_sig = ct.CFUNCTYPE(None, ct.c_void_p, ct.c_uint64,
4848
ct.c_uint64, ct.c_void_p,)
4949

50-
parallel_argsort_arithm_sig = ct.CFUNCTYPE(None, ct.c_void_p, ct.c_void_p, ct.c_uint64)
50+
parallel_argsort_arithm_sig = ct.CFUNCTYPE(None, ct.c_void_p, ct.c_void_p, ct.c_uint64, ct.c_uint8)
5151

5252
parallel_argsort_sig = ct.CFUNCTYPE(None, ct.c_void_p, ct.c_void_p, ct.c_uint64,
5353
ct.c_uint64, ct.c_void_p,)
@@ -66,7 +66,7 @@ def bind(sym, sig):
6666

6767
parallel_sort_t_sig = ct.CFUNCTYPE(None, ct.c_void_p, ct.c_uint64)
6868

69-
parallel_argsort_t_sig = ct.CFUNCTYPE(None, ct.c_void_p, ct.c_void_p, ct.c_uint64)
69+
parallel_argsort_t_sig = ct.CFUNCTYPE(None, ct.c_void_p, ct.c_void_p, ct.c_uint64, ct.c_uint8)
7070

7171
set_threads_count_sig = ct.CFUNCTYPE(None, ct.c_uint64)
7272
set_threads_count_sym = bind('set_number_of_threads', set_threads_count_sig)
@@ -290,30 +290,32 @@ def parallel_xargsort_overload_impl(dt, xargsort_map, xargsort_sym):
290290
if dt in types_to_postfix.keys():
291291
sort_f = xargsort_map[dt]
292292

293-
def parallel_xargsort_arithm_impl(arr):
293+
def parallel_xargsort_arithm_impl(arr, ascending=True):
294294
index = numpy.empty(shape=len(arr), dtype=numpy.int64)
295-
sort_f(index.ctypes, arr.ctypes, len(arr))
295+
sort_f(index.ctypes, arr.ctypes, len(arr), types.uint8(ascending))
296296

297297
return index
298298

299299
return parallel_xargsort_arithm_impl
300300

301-
def parallel_xargsort_impl(arr):
301+
# TO-DO: add/change adaptor to handle case of ascending=False
302+
def parallel_xargsort_impl(arr, ascending=True):
302303
item_size = itemsize(arr)
303304
index = numpy.empty(shape=len(arr), dtype=numpy.int64)
305+
304306
xargsort_sym(index.ctypes, arr.ctypes, len(arr), item_size, adaptor(arr[0], arr[0]))
305307

306308
return index
307309

308310
return parallel_xargsort_impl
309311

310312

311-
def parallel_argsort(arr):
313+
def parallel_argsort(arr, ascending=True):
312314
pass
313315

314316

315317
@overload(parallel_argsort)
316-
def parallel_argsort_overload(arr):
318+
def parallel_argsort_overload(arr, ascending=True):
317319

318320
if not isinstance(arr, types.Array):
319321
raise NotImplementedError
@@ -323,12 +325,12 @@ def parallel_argsort_overload(arr):
323325
return parallel_xargsort_overload_impl(dt, argsort_map, parallel_argsort_sym)
324326

325327

326-
def parallel_stable_argsort(arr):
328+
def parallel_stable_argsort(arr, ascending=True):
327329
pass
328330

329331

330332
@overload(parallel_stable_argsort)
331-
def parallel_argsort_overload(arr):
333+
def parallel_stable_argsort_overload(arr, ascending=True):
332334

333335
if not isinstance(arr, types.Array):
334336
raise NotImplementedError

sdc/native/module.cpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -60,31 +60,31 @@ extern "C"
6060

6161
void parallel_argsort_u64v(void* index, void* begin, uint64_t len, uint64_t size, void* compare);
6262

63-
void parallel_argsort_u64i8(void* index, void* begin, uint64_t len);
64-
void parallel_argsort_u64u8(void* index, void* begin, uint64_t len);
65-
void parallel_argsort_u64i16(void* index, void* begin, uint64_t len);
66-
void parallel_argsort_u64u16(void* index, void* begin, uint64_t len);
67-
void parallel_argsort_u64i32(void* index, void* begin, uint64_t len);
68-
void parallel_argsort_u64u32(void* index, void* begin, uint64_t len);
69-
void parallel_argsort_u64i64(void* index, void* begin, uint64_t len);
70-
void parallel_argsort_u64u64(void* index, void* begin, uint64_t len);
71-
72-
void parallel_argsort_u64f32(void* index, void* begin, uint64_t len);
73-
void parallel_argsort_u64f64(void* index, void* begin, uint64_t len);
63+
void parallel_argsort_u64i8(void* index, void* begin, uint64_t len, uint8_t ascending);
64+
void parallel_argsort_u64u8(void* index, void* begin, uint64_t len, uint8_t ascending);
65+
void parallel_argsort_u64i16(void* index, void* begin, uint64_t len, uint8_t ascending);
66+
void parallel_argsort_u64u16(void* index, void* begin, uint64_t len, uint8_t ascending);
67+
void parallel_argsort_u64i32(void* index, void* begin, uint64_t len, uint8_t ascending);
68+
void parallel_argsort_u64u32(void* index, void* begin, uint64_t len, uint8_t ascending);
69+
void parallel_argsort_u64i64(void* index, void* begin, uint64_t len, uint8_t ascending);
70+
void parallel_argsort_u64u64(void* index, void* begin, uint64_t len, uint8_t ascending);
71+
72+
void parallel_argsort_u64f32(void* index, void* begin, uint64_t len, uint8_t ascending);
73+
void parallel_argsort_u64f64(void* index, void* begin, uint64_t len, uint8_t ascending);
7474

7575
void parallel_stable_argsort_u64v(void* index, void* begin, uint64_t len, uint64_t size, void* compare);
7676

77-
void parallel_stable_argsort_u64i8(void* index, void* begin, uint64_t len);
78-
void parallel_stable_argsort_u64u8(void* index, void* begin, uint64_t len);
79-
void parallel_stable_argsort_u64i16(void* index, void* begin, uint64_t len);
80-
void parallel_stable_argsort_u64u16(void* index, void* begin, uint64_t len);
81-
void parallel_stable_argsort_u64i32(void* index, void* begin, uint64_t len);
82-
void parallel_stable_argsort_u64u32(void* index, void* begin, uint64_t len);
83-
void parallel_stable_argsort_u64i64(void* index, void* begin, uint64_t len);
84-
void parallel_stable_argsort_u64u64(void* index, void* begin, uint64_t len);
85-
86-
void parallel_stable_argsort_u64f32(void* index, void* begin, uint64_t len);
87-
void parallel_stable_argsort_u64f64(void* index, void* begin, uint64_t len);
77+
void parallel_stable_argsort_u64i8(void* index, void* begin, uint64_t len, uint8_t ascending);
78+
void parallel_stable_argsort_u64u8(void* index, void* begin, uint64_t len, uint8_t ascending);
79+
void parallel_stable_argsort_u64i16(void* index, void* begin, uint64_t len, uint8_t ascending);
80+
void parallel_stable_argsort_u64u16(void* index, void* begin, uint64_t len, uint8_t ascending);
81+
void parallel_stable_argsort_u64i32(void* index, void* begin, uint64_t len, uint8_t ascending);
82+
void parallel_stable_argsort_u64u32(void* index, void* begin, uint64_t len, uint8_t ascending);
83+
void parallel_stable_argsort_u64i64(void* index, void* begin, uint64_t len, uint8_t ascending);
84+
void parallel_stable_argsort_u64u64(void* index, void* begin, uint64_t len, uint8_t ascending);
85+
86+
void parallel_stable_argsort_u64f32(void* index, void* begin, uint64_t len, uint8_t ascending);
87+
void parallel_stable_argsort_u64f64(void* index, void* begin, uint64_t len, uint8_t ascending);
8888

8989
void set_number_of_threads(uint64_t threads)
9090
{

sdc/native/sort.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,16 @@ void parallel_argsort_(I* index, void* data, uint64_t len, uint64_t size, compar
9292
} // namespace
9393

9494
#define declare_single_argsort(index_prefix, type_prefix, ity, ty) \
95-
void parallel_argsort_##index_prefix##type_prefix(void* index, void* begin, uint64_t len) \
96-
{ parallel_argsort_(reinterpret_cast<ity*>(index), reinterpret_cast<ty*>(begin), len); }
95+
void parallel_argsort_##index_prefix##type_prefix(void* index, void* begin, uint64_t len, uint8_t ascending) \
96+
{ \
97+
if (ascending) { \
98+
auto cmp = utils::less<ty>(); \
99+
parallel_argsort_(reinterpret_cast<ity*>(index), reinterpret_cast<ty*>(begin), len, cmp); \
100+
} else { \
101+
auto cmp = utils::greater<ty>(); \
102+
parallel_argsort_(reinterpret_cast<ity*>(index), reinterpret_cast<ty*>(begin), len, cmp); \
103+
} \
104+
}
97105

98106
#define declare_argsort(prefix, ty) \
99107
declare_single_argsort(u8, prefix, uint8_t, ty) \

sdc/native/stable_sort.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,16 @@ struct parallel_sort_fixed_size
281281
} // namespace
282282

283283
#define declare_single_argsort(index_prefix, type_prefix, ity, ty) \
284-
void parallel_stable_argsort_##index_prefix##type_prefix(ity* index, void* begin, uint64_t len) \
285-
{ parallel_stable_argsort_(reinterpret_cast<ity*>(index), reinterpret_cast<ty*>(begin), len); }
284+
void parallel_stable_argsort_##index_prefix##type_prefix(ity* index, void* begin, uint64_t len, uint8_t ascending) \
285+
{ \
286+
if (ascending) { \
287+
auto cmp = utils::less<ty>(); \
288+
parallel_stable_argsort_(reinterpret_cast<ity*>(index), reinterpret_cast<ty*>(begin), len, cmp); \
289+
} else { \
290+
auto cmp = utils::greater<ty>(); \
291+
parallel_stable_argsort_(reinterpret_cast<ity*>(index), reinterpret_cast<ty*>(begin), len, cmp); \
292+
} \
293+
}
286294

287295
#define declare_argsort(prefix, ty) \
288296
declare_single_argsort(u8, prefix, uint8_t, ty) \
@@ -339,4 +347,4 @@ void parallel_stable_sort(void* begin, uint64_t len, uint64_t size, void* compar
339347
#undef declare_int_sort
340348
#undef declare_sort
341349
#undef declare_argsort
342-
#undef declare_single_argsort
350+
#undef declare_single_argsort

sdc/native/utils.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,4 +169,16 @@ bool nanless<double>(const double& left, const double& right)
169169
return std::less<double>()(left, right) || (std::isnan(right) && !std::isnan(left));
170170
}
171171

172+
template<>
173+
bool nangreater<float>(const float& left, const float& right)
174+
{
175+
return std::greater<float>()(left, right) || (std::isnan(right) && !std::isnan(left));
176+
}
177+
178+
template<>
179+
bool nangreater<double>(const double& left, const double& right)
180+
{
181+
return std::greater<double>()(left, right) || (std::isnan(right) && !std::isnan(left));
182+
}
183+
172184
}

sdc/native/utils.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,27 @@ struct less
266266
}
267267
};
268268

269+
template<typename T>
270+
bool nangreater(const T& left, const T& right)
271+
{
272+
return std::greater<T>()(left, right);
273+
}
274+
275+
template<>
276+
bool nangreater<float>(const float& left, const float& right);
277+
278+
template<>
279+
bool nangreater<double>(const double& left, const double& right);
280+
281+
template<typename T>
282+
struct greater
283+
{
284+
bool operator() (const T& left, const T& right) const
285+
{
286+
return nangreater<T>(left, right);
287+
}
288+
};
289+
269290
namespace tbb_control
270291
{
271292
void init();

0 commit comments

Comments
 (0)