Skip to content

Commit f63a3bd

Browse files
authored
Merge pull request #2794 from devitocodes/reuse-shmas-for-wdr
compiler: Add special types for device operations
2 parents 3ce3395 + 3559f54 commit f63a3bd

2 files changed

Lines changed: 57 additions & 8 deletions

File tree

devito/ir/clusters/cluster.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
from devito.mpi.reduction_scheme import DistReduce
1616
from devito.symbolics import estimate_cost
1717
from devito.tools import as_tuple, filter_ordered, flatten, infer_dtype
18-
from devito.types import Fence, WeakFence, CriticalRegion
18+
from devito.types import (
19+
Fence, WeakFence, CriticalRegion, ThreadPoolSync, ThreadCommit, ThreadWait
20+
)
1921

2022
__all__ = ["Cluster", "ClusterGroup"]
2123

@@ -262,26 +264,40 @@ def is_wild(self):
262264
self.is_weak_fence or
263265
self.is_critical_region)
264266

267+
def _is_type(self, cls):
268+
return self.exprs and all(isinstance(e.rhs, cls) for e in self.exprs)
269+
265270
@cached_property
266271
def is_halo_touch(self):
267-
return self.exprs and all(isinstance(e.rhs, HaloTouch) for e in self.exprs)
272+
return self._is_type(HaloTouch)
268273

269274
@cached_property
270275
def is_dist_reduce(self):
271-
return self.exprs and all(isinstance(e.rhs, DistReduce) for e in self.exprs)
276+
return self._is_type(DistReduce)
272277

273278
@cached_property
274279
def is_fence(self):
275-
return (self.exprs and all(isinstance(e.rhs, Fence) for e in self.exprs) or
276-
self.is_critical_region)
280+
return self._is_type(Fence) or self.is_critical_region
277281

278282
@cached_property
279283
def is_weak_fence(self):
280-
return self.exprs and all(isinstance(e.rhs, WeakFence) for e in self.exprs)
284+
return self._is_type(WeakFence)
281285

282286
@cached_property
283287
def is_critical_region(self):
284-
return self.exprs and all(isinstance(e.rhs, CriticalRegion) for e in self.exprs)
288+
return self._is_type(CriticalRegion)
289+
290+
@cached_property
291+
def is_thread_pool_sync(self):
292+
return self._is_type(ThreadPoolSync)
293+
294+
@cached_property
295+
def is_thread_commit(self):
296+
return self._is_type(ThreadCommit)
297+
298+
@cached_property
299+
def is_thread_wait(self):
300+
return self._is_type(ThreadWait)
285301

286302
@cached_property
287303
def is_async(self):

devito/types/parallel.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222

2323
__all__ = ['NThreads', 'NThreadsNested', 'NThreadsNonaffine', 'NThreadsBase',
2424
'DeviceID', 'ThreadID', 'Lock', 'ThreadArray', 'PThreadArray',
25-
'SharedData', 'NPThreads', 'DeviceRM', 'QueueID', 'Barrier', 'TBArray']
25+
'SharedData', 'NPThreads', 'DeviceRM', 'QueueID', 'Barrier', 'TBArray',
26+
'ThreadPoolSync', 'ThreadCommit', 'ThreadWait']
2627

2728

2829
class NThreadsAbstract(Scalar):
@@ -321,6 +322,38 @@ class Barrier(Fence):
321322
pass
322323

323324

325+
class ThreadPoolSync(Barrier):
326+
327+
"""
328+
A generic synchronization barrier for a pool of threads.
329+
"""
330+
331+
pass
332+
333+
334+
class ThreadCommit(Fence):
335+
336+
"""
337+
A generic commit operation for a single thread, typically used to issue
338+
a memory operation at a specific program point, which requires the special
339+
treatment that all Fence subclasses provide (i.e., to avoid being reshuffled
340+
around by optimization passes).
341+
"""
342+
343+
pass
344+
345+
346+
class ThreadWait(Fence):
347+
348+
"""
349+
A generic wait operation for a single thread, typically used to synchronize
350+
after a memory operation issued at a specific program point with a
351+
ThreadCommit operation.
352+
"""
353+
354+
pass
355+
356+
324357
class TBArray(Array):
325358

326359
"""

0 commit comments

Comments
 (0)