Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 58bb4b6

Browse files
author
Ehsan Totoni
committed
support S.head(), dist broadcast of const slice
1 parent ac05952 commit 58bb4b6

5 files changed

Lines changed: 112 additions & 3 deletions

File tree

hpat/distributed.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@
3636
import hpat.utils
3737
from hpat.utils import (is_alloc_callname, is_whole_slice, is_array_container,
3838
get_slice_step, is_array, is_np_array, find_build_tuple,
39-
debug_prints, ReplaceFunc, gen_getitem, is_call)
39+
debug_prints, ReplaceFunc, gen_getitem, is_call,
40+
is_const_slice)
4041
from hpat.distributed_api import Reduce_Type
4142
from hpat.hiframes.pd_dataframe_ext import DataFrameType
4243

@@ -739,6 +740,7 @@ def f(arr, bag, start, count): # pragma: no cover
739740
return [assign]
740741

741742
if (fdef == ('get_series_data', 'hpat.hiframes.api')
743+
or fdef == ('get_series_index', 'hpat.hiframes.api')
742744
or fdef == ('get_dataframe_data', 'hpat.hiframes.pd_dataframe_ext')):
743745
out = [assign]
744746
arr = assign.target
@@ -1583,6 +1585,17 @@ def f(A, start, step):
15831585
out += self._run_call_rebalance_array(lhs.name, full_node, [imb_arr])
15841586
out[-1].target = lhs
15851587

1588+
elif self._is_REP(lhs.name) and guard(
1589+
is_const_slice, self.typemap, self.func_ir, index_var):
1590+
# cases like S.head()
1591+
# bcast if all in rank 0, otherwise gatherv
1592+
in_arr = full_node.value.value
1593+
start = self._array_starts[in_arr.name][0]
1594+
count = self._array_counts[in_arr.name][0]
1595+
return self._replace_func(
1596+
lambda arr, slice_index, start, count: hpat.distributed_api.const_slice_getitem(
1597+
arr, slice_index, start, count), [in_arr, index_var, start, count])
1598+
15861599
return out
15871600

15881601
def _run_parfor(self, parfor, namevar_table):

hpat/distributed_analysis.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from hpat.hiframes.pd_series_ext import SeriesType
2020
from hpat.utils import (get_constant, is_alloc_callname,
2121
is_whole_slice, is_array, is_array_container,
22-
is_np_array, find_build_tuple, debug_prints)
22+
is_np_array, find_build_tuple, debug_prints,
23+
is_const_slice)
2324
from hpat.hiframes.pd_dataframe_ext import DataFrameType
2425
from enum import Enum
2526

@@ -358,8 +359,12 @@ def _analyze_call(self, lhs, rhs, func_var, args, array_dists):
358359
if fdef == ('isna', 'hpat.hiframes.api'):
359360
return
360361

362+
if fdef == ('get_series_name', 'hpat.hiframes.api'):
363+
return
364+
361365
# dummy hiframes functions
362366
if func_mod == 'hpat.hiframes.api' and func_name in ('get_series_data',
367+
'get_series_index',
363368
'to_arr_from_series', 'ts_series_to_arr_typ',
364369
'to_date_series_type', 'dummy_unbox_series',
365370
'parallel_fix_df_array'):
@@ -772,6 +777,12 @@ def _analyze_getitem(self, inst, lhs, rhs, array_dists):
772777
self._meet_array_dists(lhs, rhs.value.name, array_dists)
773778
return
774779

780+
# output of operations like S.head() is REP since it's a "small" slice
781+
# input can remain 1D
782+
if guard(is_const_slice, self.typemap, self.func_ir, index_var):
783+
array_dists[lhs] = Distribution.REP
784+
return
785+
775786
self._set_REP(inst.list_vars(), array_dists)
776787
return
777788

hpat/distributed_api.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,64 @@ def prealloc_impl(arr):
276276

277277
return lambda arr: arr
278278

279+
280+
def const_slice_getitem(arr, slice_index, start, count):
    """Plain-Python fallback: return ``arr`` indexed by ``slice_index``.

    ``slice_index`` is assumed to carry no start and no step, i.e. to have
    the form ``[:k]``. ``start`` and ``count`` are accepted only so the
    signature matches the compiled overload; they are not used here.
    """
    selected = arr[slice_index]
    return selected
283+
284+
285+
@overload(const_slice_getitem)
def const_slice_getitem_overload(arr, slice_index, start, count):
    """Compiled implementation of ``const_slice_getitem``.

    Selects one of two implementations based on the type of ``arr``: a
    string-array version when ``arr == string_array_type`` and a NumPy
    version otherwise. Both return the first ``k = slice_index.stop``
    elements and finish with a ``bcast`` of the output array.

    Parameters (of the returned implementations):
        arr: the local chunk of the array on this rank.
        slice_index: slice of the form ``[:k]`` (no start, no step).
        start: presumably the global offset of this rank's chunk -- the
            caller passes ``_array_starts[...][0]``; TODO confirm.
        count: presumably the number of local elements -- the caller
            passes ``_array_counts[...][0]``; TODO confirm.

    NOTE(review): when the global array holds fewer than ``k`` elements the
    gathered result on rank 0 is shorter than the ``k``-element buffers
    allocated on the other ranks; confirm callers never hit this case.
    """
    if arr == string_array_type:
        reduce_op = Reduce_Type.Sum.value

        def getitem_str_impl(arr, slice_index, start, count):
            rank = hpat.distributed_api.get_rank()
            k = slice_index.stop
            # get total characters for allocation
            n_chars = np.uint64(0)
            if k > count:
                # requested prefix is longer than the local chunk: each rank
                # takes the part of its chunk that lies below k, then the
                # pieces are combined with gatherv
                my_end = min(count, max(k-start, 0))
                my_arr = arr[:my_end]
                my_arr = hpat.distributed_api.gatherv(my_arr)
                n_chars = hpat.distributed_api.dist_reduce(
                    num_total_chars(my_arr), np.int32(reduce_op))
                if rank == 0:
                    out_arr = my_arr
            else:
                # the whole prefix fits in one chunk: rank 0 slices it
                # locally and shares the character count with bcast_scalar
                if rank == 0:
                    my_arr = arr[:k]
                    n_chars = num_total_chars(my_arr)
                    out_arr = my_arr
                n_chars = bcast_scalar(n_chars)
            if rank != 0:
                # other ranks allocate a receive buffer of matching size
                out_arr = pre_alloc_string_array(k, n_chars)

            # actual communication
            hpat.distributed_api.bcast(out_arr)
            return out_arr

        return getitem_str_impl

    def getitem_impl(arr, slice_index, start, count):
        rank = hpat.distributed_api.get_rank()
        k = slice_index.stop
        # every rank starts with a k-element buffer; rank 0 replaces it
        # with the real data below
        out_arr = np.empty(k, arr.dtype)
        if k > count:
            # requested prefix is longer than the local chunk: each rank
            # takes the part of its chunk that lies below k, then the
            # pieces are combined with gatherv
            my_end = min(count, max(k-start, 0))
            my_arr = arr[:my_end]
            my_arr = hpat.distributed_api.gatherv(my_arr)
            if rank == 0:
                out_arr = my_arr
        else:
            # the whole prefix fits in one chunk: rank 0 slices it locally
            if rank == 0:
                out_arr = arr[:k]
        hpat.distributed_api.bcast(out_arr)
        return out_arr

    return getitem_impl
335+
336+
279337
# send_data, recv_data, send_counts, recv_counts, send_disp, recv_disp, typ_enum
280338
c_alltoallv = types.ExternalFunction("c_alltoallv", types.void(types.voidptr,
281339
types.voidptr, types.voidptr, types.voidptr, types.voidptr, types.voidptr, types.int32))

hpat/tests/test_series.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,6 +1168,16 @@ def test_impl():
11681168
hpat_func = hpat.jit(test_impl)
11691169
pd.testing.assert_series_equal(hpat_func(), test_impl())
11701170

1171+
def test_series_head_index_parallel1(self):
    """S.head(3) on a Series with a string index, jitted with
    ``distributed={'S'}``.

    The jitted function receives the ``S[start:end]`` chunk given by
    ``get_start_end`` and must match pandas run on the full Series; the
    test also requires ``count_array_OneDs()`` to be positive.
    """
    def test_impl(S):
        return S.head(3)

    values = [6,9,2,3,6,4,5]
    labels = ['a','ab','abc','c','f','hh','']
    series = pd.Series(values, labels)
    hpat_func = hpat.jit(distributed={'S'})(test_impl)
    start, end = get_start_end(len(series))
    local_chunk = series[start:end]
    actual = hpat_func(local_chunk)
    expected = test_impl(series)
    pd.testing.assert_series_equal(actual, expected)
    self.assertTrue(count_array_OneDs()>0)
1180+
11711181
def test_series_median1(self):
11721182
def test_impl(S):
11731183
return S.median()

hpat/utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numba
44
from numba import ir_utils, ir, types, cgutils
55
from numba.ir_utils import (guard, get_definition, find_callname, require,
6-
add_offset_to_labels, find_topo_order)
6+
add_offset_to_labels, find_topo_order, find_const)
77
from numba.parfor import wrap_parfor_blocks, unwrap_parfor_blocks
88
from numba.typing import signature
99
from numba.typing.templates import infer_global, AbstractTemplate
@@ -279,6 +279,23 @@ def is_whole_slice(typemap, func_ir, var, accept_stride=False):
279279
require(isinstance(arg1_def, ir.Const) and arg1_def.value == None)
280280
return True
281281

282+
283+
def is_const_slice(typemap, func_ir, var, accept_stride=False):
    """Return True if var can be determined to be a constant size slice.

    A constant size slice here is a call ``slice(None, k)`` (or
    ``slice(None, k, step)`` when ``accept_stride`` is True) where ``k`` is a
    non-negative compile-time integer constant, e.g. the index used by
    ``S.head()``.

    Every check goes through ``require`` so that any slice that cannot be
    proven constant makes the enclosing ``guard(...)`` call yield None
    instead of aborting compilation with an AssertionError.
    """
    require(typemap[var.name] == types.slice2_type
            or (accept_stride and typemap[var.name] == types.slice3_type))
    call_expr = get_definition(func_ir, var)
    require(isinstance(call_expr, ir.Expr) and call_expr.op == 'call')
    require(len(call_expr.args) == 2
            or (accept_stride and len(call_expr.args) == 3))
    require(find_callname(func_ir, call_expr) == ('slice', 'builtins'))
    # the slice must begin at the start of the array
    arg0_def = get_definition(func_ir, call_expr.args[0])
    require(isinstance(arg0_def, ir.Const) and arg0_def.value is None)
    # a negative stop (e.g. A[:-3]) depends on the array length, so it does
    # not describe a constant number of elements
    size_const = find_const(func_ir, call_expr.args[1])
    require(isinstance(size_const, int) and size_const >= 0)
    return True
297+
298+
282299
def get_slice_step(typemap, func_ir, var):
283300
require(typemap[var.name] == types.slice3_type)
284301
call_expr = get_definition(func_ir, var)

0 commit comments

Comments
 (0)