@@ -94,6 +94,21 @@ int main(int argc, char *argv[])
9494 GetMySubarray (&my_subarray);
9595 InitDeviceArrays (&A_device[0 ], &A_device[1 ], q, &my_subarray);
9696
97+ #ifdef GROUP_SIZE_DEFAULT
98+ int work_group_size = GROUP_SIZE_DEFAULT; /* compile-time override of the work-group size */
99+ #else
100+ int work_group_size = static_cast<int >(
101+ q.get_device ().get_info <sycl::info::device::max_work_group_size>()); /* the query yields size_t; narrow explicitly */
102+ #endif
103+ /* Kernels below split columns evenly across work_group_size work-items, so Nx must be a multiple of it. NOTE(review): the kernels divide my_subarray.x_size, not Nx - confirm the two are equal. */
104+ if ((Nx % work_group_size) != 0 ) {
105+ if (my_subarray.rank == 0 ) { /* print the diagnostic once, from rank 0 only */
106+ printf (" For simplification, sycl::info::device::max_work_group_size should be divider of X dimention of array\n " );
107+ printf (" Please adjust matrix size, or define GROUP_SIZE_DEFAULT\n " );
108+ printf (" sycl::info::device::max_work_group_size=%d Nx=%d (%d)\n " , work_group_size, Nx, Nx % work_group_size); /* report the remainder that the check above actually tested */
109+ MPI_Abort (MPI_COMM_WORLD, -1 );
110+ }
111+ }
97112 /* Create RMA window using device memory */
98113 MPI_Win_create (A_device[0 ],
99114 sizeof (double ) * (my_subarray.x_size + 2 ) * (my_subarray.y_size + 2 ),
@@ -116,18 +131,24 @@ int main(int argc, char *argv[])
116131 {
117132 /* Calculate values on borders to initiate communications early */
118133 q.submit ([&](auto & h) {
119- h.parallel_for (sycl::range (my_subarray.x_size ), [ =] (auto index) {
120- int column = index[0 ];
121- int idx = XY_2_IDX (column, 0 , my_subarray);
122- a_out[idx] = 0.25 * (a[idx - 1 ] + a[idx + 1 ]
123- + a[idx - ROW_SIZE (my_subarray)]
124- + a[idx + ROW_SIZE (my_subarray)]);
125-
126- idx = XY_2_IDX (column, my_subarray.y_size - 1 , my_subarray);
127- a_out[idx] = 0.25 * (a[idx - 1 ] + a[idx + 1 ]
128- + a[idx - ROW_SIZE (my_subarray)]
129- + a[idx + ROW_SIZE (my_subarray)]);
130-
134+ h.parallel_for (sycl::nd_range<1 >(work_group_size, work_group_size), /* global size == local size: a single work-group */
135+ [=](sycl::nd_item<1 > item) {
136+ int local_id = item.get_local_id (0 ); /* index of this work-item inside the work-group; selects its column slice */
137+ int col_per_wg = my_subarray.x_size / work_group_size; /* number of columns each work-item processes */
138+
139+ int my_x_lb = col_per_wg * local_id; /* first column of this work-item's slice */
140+ int my_x_ub = my_x_lb + col_per_wg; /* one past the last column of the slice */
141+
142+ for (int column = my_x_lb; column < my_x_ub; column ++) {
143+ int idx = XY_2_IDX (column, 0 , my_subarray); /* first border row (y = 0) */
144+ a_out[idx] = 0.25 * (a[idx - 1 ] + a[idx + 1 ]
145+ + a[idx - ROW_SIZE (my_subarray)]
146+ + a[idx + ROW_SIZE (my_subarray)]);
147+ idx = XY_2_IDX (column, my_subarray.y_size - 1 , my_subarray); /* last border row (y = y_size - 1) */
148+ a_out[idx] = 0.25 * (a[idx - 1 ] + a[idx + 1 ]
149+ + a[idx - ROW_SIZE (my_subarray)]
150+ + a[idx + ROW_SIZE (my_subarray)]);
151+ }
131152 });
132153 }).wait ();
133154 }
@@ -149,11 +170,23 @@ int main(int argc, char *argv[])
149170 /* Recalculate internal points in parallel with communications */
150171 {
151172 q.submit ([&](auto & h) {
152- h.parallel_for (sycl::range (my_subarray.x_size , my_subarray.y_size - 2 ), [ =] (auto index) {
153- int idx = XY_2_IDX (index[0 ], index[1 ] + 1 , my_subarray);
154- a_out[idx] = 0.25 * (a[idx - 1 ] + a[idx + 1 ]
155- + a[idx - ROW_SIZE (my_subarray)]
156- + a[idx + ROW_SIZE (my_subarray)]);
173+ h.parallel_for (sycl::nd_range<1 >(work_group_size, work_group_size),
174+ [=](sycl::nd_item<1 > it) {
175+ const int lane = it.get_local_id ();
176+ const int cols_per_item = my_subarray.x_size / work_group_size;
177+
178+ const int col_begin = lane * cols_per_item;
179+ const int col_end = col_begin + cols_per_item;
180+
181+ /* Interior rows only: rows 0 and y_size - 1 were already computed by the border kernel */
182+ for (int row = 1 ; row <= my_subarray.y_size - 2 ; ++row) {
183+ for (int col = col_begin; col < col_end; ++col) {
184+ const int center = XY_2_IDX (col, row, my_subarray);
185+ a_out[center] = 0.25 * (a[center - 1 ] + a[center + 1 ]
186+ + a[center - ROW_SIZE (my_subarray)]
187+ + a[center + ROW_SIZE (my_subarray)]);
188+ }
189+ }
157190 });
158191 }).wait ();
159192 }
0 commit comments