Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 65aaf91

Browse files
author
Ehsan Totoni
committed
dist divide ceiling to be more even
1 parent 189401d commit 65aaf91

4 files changed

Lines changed: 72 additions & 33 deletions

File tree

hpat/_distributed.c

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
#include <stdbool.h>
22
#include "mpi.h"
3+
#include <cmath>
4+
#include <algorithm>
35
#include <Python.h>
46

57
int hpat_dist_get_rank();
68
int hpat_dist_get_size();
7-
int64_t hpat_dist_get_end(int64_t total, int64_t div_chunk, int num_pes,
8-
int node_id);
9-
int64_t hpat_dist_get_node_portion(int64_t total, int64_t div_chunk,
10-
int num_pes, int node_id);
9+
int64_t hpat_dist_get_start(int64_t total, int num_pes, int node_id);
10+
int64_t hpat_dist_get_end(int64_t total, int num_pes, int node_id);
11+
int64_t hpat_dist_get_node_portion(int64_t total, int num_pes, int node_id);
1112
double hpat_dist_get_time();
1213
MPI_Datatype get_MPI_typ(int typ_enum);
1314
int get_elem_size(int type_enum);
@@ -43,6 +44,8 @@ PyMODINIT_FUNC PyInit_hdist(void) {
4344
PyLong_FromVoidPtr((void*)(&hpat_dist_get_rank)));
4445
PyObject_SetAttrString(m, "hpat_dist_get_size",
4546
PyLong_FromVoidPtr((void*)(&hpat_dist_get_size)));
47+
PyObject_SetAttrString(m, "hpat_dist_get_start",
48+
PyLong_FromVoidPtr((void*)(&hpat_dist_get_start)));
4649
PyObject_SetAttrString(m, "hpat_dist_get_end",
4750
PyLong_FromVoidPtr((void*)(&hpat_dist_get_end)));
4851
PyObject_SetAttrString(m, "hpat_dist_get_node_portion",
@@ -103,17 +106,29 @@ int hpat_dist_get_size()
103106
return size;
104107
}
105108

106-
int64_t hpat_dist_get_end(int64_t total, int64_t div_chunk, int num_pes,
107-
int node_id)
109+
int64_t hpat_dist_get_start(int64_t total, int num_pes, int node_id)
108110
{
109-
return ((node_id==num_pes-1) ? total : (node_id+1)*div_chunk);
111+
int64_t div_chunk = (int64_t)ceil(total/((double)num_pes));
112+
int64_t start = std::min(total, node_id*div_chunk);
113+
// printf("rank %d start:%lld\n", node_id, start);
114+
return start;
110115
}
111116

112-
int64_t hpat_dist_get_node_portion(int64_t total, int64_t div_chunk,
113-
int num_pes, int node_id)
117+
int64_t hpat_dist_get_end(int64_t total, int num_pes, int node_id)
114118
{
115-
int64_t portion = ((node_id==num_pes-1) ? total-node_id*div_chunk : div_chunk);
116-
// printf("portion:%lld\n", portion);
119+
int64_t div_chunk = (int64_t)ceil(total/((double)num_pes));
120+
int64_t end = std::min(total, (node_id+1)*div_chunk);
121+
// printf("rank %d end:%lld\n", node_id, end);
122+
return end;
123+
}
124+
125+
int64_t hpat_dist_get_node_portion(int64_t total, int num_pes, int node_id)
126+
{
127+
int64_t div_chunk = (int64_t)ceil(total/((double)num_pes));
128+
int64_t start = std::min(total, node_id*div_chunk);
129+
int64_t end = std::min(total, (node_id+1)*div_chunk);
130+
int64_t portion = end-start;
131+
// printf("rank %d portion:%lld\n", node_id, portion);
117132
return portion;
118133
}
119134

hpat/distributed.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -852,17 +852,22 @@ def _gen_1D_div(self, size_var, scope, loc, prefix, end_call_name, end_call):
852852
size_assign = ir.Assign(ir.Const(size_var, loc), new_size_var, loc)
853853
div_nodes.append(size_assign)
854854
size_var = new_size_var
855-
div_var = ir.Var(scope, mk_unique_var(prefix+"_div_var"), loc)
856-
self.typemap[div_var.name] = types.int64
857-
div_expr = ir.Expr.binop('//', size_var, self._size_var, loc)
858-
self.calltypes[div_expr] = find_op_typ('//', [types.int64, types.int32])
859-
div_assign = ir.Assign(div_expr, div_var, loc)
860855

856+
# attr call: start_attr = getattr(g_dist_var, get_start)
857+
start_attr_call = ir.Expr.getattr(self._g_dist_var, "get_start", loc)
858+
start_attr_var = ir.Var(scope, mk_unique_var("$get_start_attr"), loc)
859+
self.typemap[start_attr_var.name] = get_global_func_typ(distributed_api.get_start)
860+
start_attr_assign = ir.Assign(start_attr_call, start_attr_var, loc)
861+
862+
# start_var = get_start(size, pes, rank)
861863
start_var = ir.Var(scope, mk_unique_var(prefix+"_start_var"), loc)
862864
self.typemap[start_var.name] = types.int64
863-
start_expr = ir.Expr.binop('*', div_var, self._rank_var, loc)
864-
self.calltypes[start_expr] = find_op_typ('*', [types.int64, types.int32])
865+
start_expr = ir.Expr.call(start_attr_var, [size_var,
866+
self._size_var, self._rank_var], (), loc)
867+
self.calltypes[start_expr] = self.typemap[start_attr_var.name].get_call_type(
868+
typing.Context(), [types.int64, types.int32, types.int32], {})
865869
start_assign = ir.Assign(start_expr, start_var, loc)
870+
866871
# attr call: end_attr = getattr(g_dist_var, get_end)
867872
end_attr_call = ir.Expr.getattr(self._g_dist_var, end_call_name, loc)
868873
end_attr_var = ir.Var(scope, mk_unique_var("$get_end_attr"), loc)
@@ -871,12 +876,12 @@ def _gen_1D_div(self, size_var, scope, loc, prefix, end_call_name, end_call):
871876

872877
end_var = ir.Var(scope, mk_unique_var(prefix+"_end_var"), loc)
873878
self.typemap[end_var.name] = types.int64
874-
end_expr = ir.Expr.call(end_attr_var, [size_var, div_var,
879+
end_expr = ir.Expr.call(end_attr_var, [size_var,
875880
self._size_var, self._rank_var], (), loc)
876881
self.calltypes[end_expr] = self.typemap[end_attr_var.name].get_call_type(
877-
typing.Context(), [types.int64, types.int64, types.int32, types.int32], {})
882+
typing.Context(), [types.int64, types.int32, types.int32], {})
878883
end_assign = ir.Assign(end_expr, end_var, loc)
879-
div_nodes += [div_assign, start_assign, end_attr_assign, end_assign]
884+
div_nodes += [start_attr_assign, start_assign, end_attr_assign, end_assign]
880885
return div_nodes, start_var, end_var
881886

882887
def _get_ind_sub(self, ind_var, start_var):

hpat/distributed_api.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,17 @@ def get_size():
1111
"""dummy function for C mpi get_size"""
1212
return 0
1313

14-
def get_end(total_size, div, pes, rank):
14+
def get_start(total_size, pes, rank):
1515
"""get start point of range for parfor division"""
16-
return total_size if rank==pes-1 else (rank+1)*div
16+
return 0
1717

18-
def get_node_portion(total_size, div, pes, rank):
18+
def get_end(total_size, pes, rank):
19+
"""get end point of range for parfor division"""
20+
return 0
21+
22+
def get_node_portion(total_size, pes, rank):
1923
"""get portion of size for alloc division"""
20-
return total_size-div*rank if rank==pes-1 else div
24+
return 0
2125

2226
def dist_reduce(value):
2327
"""dummy to implement simple reductions"""
@@ -65,18 +69,25 @@ def generic(self, args, kws):
6569
assert len(args)==0
6670
return signature(types.int32, *args)
6771

72+
@infer_global(get_start)
73+
class DistStart(AbstractTemplate):
74+
def generic(self, args, kws):
75+
assert not kws
76+
assert len(args)==3
77+
return signature(types.int64, *args)
78+
6879
@infer_global(get_end)
6980
class DistEnd(AbstractTemplate):
7081
def generic(self, args, kws):
7182
assert not kws
72-
assert len(args)==4
83+
assert len(args)==3
7384
return signature(types.int64, *args)
7485

7586
@infer_global(get_node_portion)
7687
class DistPortion(AbstractTemplate):
7788
def generic(self, args, kws):
7889
assert not kws
79-
assert len(args)==4
90+
assert len(args)==3
8091
return signature(types.int64, *args)
8192

8293
@infer_global(dist_reduce)

hpat/distributed_lower.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import llvmlite.binding as ll
1212
ll.add_symbol('hpat_dist_get_rank', hdist.hpat_dist_get_rank)
1313
ll.add_symbol('hpat_dist_get_size', hdist.hpat_dist_get_size)
14+
ll.add_symbol('hpat_dist_get_start', hdist.hpat_dist_get_start)
1415
ll.add_symbol('hpat_dist_get_end', hdist.hpat_dist_get_end)
1516
ll.add_symbol('hpat_dist_get_node_portion', hdist.hpat_dist_get_node_portion)
1617
ll.add_symbol('hpat_dist_get_time', hdist.hpat_dist_get_time)
@@ -41,19 +42,26 @@ def dist_get_size(context, builder, sig, args):
4142
fn = builder.module.get_or_insert_function(fnty, name="hpat_dist_get_size")
4243
return builder.call(fn, [])
4344

44-
@lower_builtin(distributed_api.get_end, types.int64, types.int64, types.int32, types.int32)
45+
@lower_builtin(distributed_api.get_start, types.int64, types.int32, types.int32)
46+
def dist_get_start(context, builder, sig, args):
47+
fnty = lir.FunctionType(lir.IntType(64), [lir.IntType(64),
48+
lir.IntType(32), lir.IntType(32)])
49+
fn = builder.module.get_or_insert_function(fnty, name="hpat_dist_get_start")
50+
return builder.call(fn, [args[0], args[1], args[2]])
51+
52+
@lower_builtin(distributed_api.get_end, types.int64, types.int32, types.int32)
4553
def dist_get_end(context, builder, sig, args):
46-
fnty = lir.FunctionType(lir.IntType(64), [lir.IntType(64), lir.IntType(64),
54+
fnty = lir.FunctionType(lir.IntType(64), [lir.IntType(64),
4755
lir.IntType(32), lir.IntType(32)])
4856
fn = builder.module.get_or_insert_function(fnty, name="hpat_dist_get_end")
49-
return builder.call(fn, [args[0], args[1], args[2], args[3]])
57+
return builder.call(fn, [args[0], args[1], args[2]])
5058

51-
@lower_builtin(distributed_api.get_node_portion, types.int64, types.int64, types.int32, types.int32)
59+
@lower_builtin(distributed_api.get_node_portion, types.int64, types.int32, types.int32)
5260
def dist_get_portion(context, builder, sig, args):
53-
fnty = lir.FunctionType(lir.IntType(64), [lir.IntType(64), lir.IntType(64),
61+
fnty = lir.FunctionType(lir.IntType(64), [lir.IntType(64),
5462
lir.IntType(32), lir.IntType(32)])
5563
fn = builder.module.get_or_insert_function(fnty, name="hpat_dist_get_node_portion")
56-
return builder.call(fn, [args[0], args[1], args[2], args[3]])
64+
return builder.call(fn, [args[0], args[1], args[2]])
5765

5866
@lower_builtin(distributed_api.dist_reduce, types.int64)
5967
@lower_builtin(distributed_api.dist_reduce, types.int32)

0 commit comments

Comments
 (0)