Skip to content

Commit aab5f7f

Browse files
committed
adding concurrent_vector
1 parent 79a00dd commit aab5f7f

4 files changed

Lines changed: 238 additions & 12 deletions

File tree

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
#ifndef STAN_MATH_OPENCL_CONCURRENT_VECTOR_HPP
#define STAN_MATH_OPENCL_CONCURRENT_VECTOR_HPP

#include <array>
#include <atomic>
#include <cassert>
#include <cstddef>
#include <new>
#include <stdexcept>
#include <type_traits>
#include <utility>

namespace stan {
namespace math {
namespace internal {

/**
 * Minimal segmented concurrent_vector.
 *
 * Key properties:
 * - concurrent emplace_back/push_back using an atomic size counter.
 * - segmented storage => no moving elements during growth, stable addresses.
 * - segments allocated lazily; allocation uses CAS to avoid locks.
 *
 * Important constraints / notes:
 * - operator[] is safe if you only read indices < size() that are known to be
 *   constructed.
 * - For "publish-then-read" correctness: size() is updated before construction
 *   finishes. So consumers must not read index i just because i < size(); they
 *   must have a stronger protocol (e.g., producer hands out index, or you add
 *   a "constructed" bitmap). This matches common usage where only the pushing
 *   thread uses the returned index.
 * - clear()/destruction are NOT concurrent with pushes.
 * - T's constructor invoked by emplace_back() must not throw: a throwing
 *   constructor would leave a claimed-but-unconstructed slot, which
 *   destroy_all_() would then incorrectly destroy.
 *
 * If you need "readers can iterate up to size() safely while writers push",
 * add a constructed flag per element (see comment near emplace_back()).
 *
 * @tparam T element type
 * @tparam BaseSegmentSize size of segment 0; segment k holds
 *         BaseSegmentSize * 2^k elements. Must be a power of two.
 * @tparam MaxSegments maximum number of segments; bounds total capacity.
 */
template <typename T,
          std::size_t BaseSegmentSize = 1024,
          std::size_t MaxSegments = 32>
class concurrent_vector {
  static_assert(BaseSegmentSize > 0, "BaseSegmentSize must be > 0");
  static_assert((BaseSegmentSize & (BaseSegmentSize - 1)) == 0,
                "BaseSegmentSize must be a power of two (helps mapping).");
  // Guard the `size_t{1} << MaxSegments` in max_size(): shifting by the full
  // bit width of size_t (or more) is undefined behavior.
  static_assert(MaxSegments > 0 && MaxSegments < sizeof(std::size_t) * 8,
                "MaxSegments must be in [1, bit-width of size_t).");

 public:
  concurrent_vector() : size_(0) {
    // std::array of atomics is default-constructed; explicitly null every
    // slot so segment_ptr_() reads are well-defined before first use.
    for (auto& p : segments_) p.store(nullptr, std::memory_order_relaxed);
  }

  concurrent_vector(const concurrent_vector&) = delete;
  concurrent_vector& operator=(const concurrent_vector&) = delete;

  ~concurrent_vector() { destroy_all_(); }

  /** Number of claimed slots (may include slots still being constructed). */
  std::size_t size() const noexcept {
    return size_.load(std::memory_order_acquire);
  }

  bool empty() const noexcept { return size() == 0; }

  // Non-concurrent: safe only when no other threads are pushing/reading.
  void clear() {
    destroy_all_();
    size_.store(0, std::memory_order_release);
  }

  /**
   * Concurrent append (construct in place).
   * @param args forwarded to T's constructor (must not throw; see class docs)
   * @return the index of the newly constructed element
   * @throw std::length_error if capacity (max_size()) is exceeded
   */
  template <typename... Args>
  std::size_t emplace_back(Args&&... args) {
    // Claim an index
    const std::size_t idx = size_.fetch_add(1, std::memory_order_acq_rel);

    // Ensure the segment exists
    T* seg = ensure_segment_for_index_(idx);

    // Placement-new into the correct slot
    const std::size_t off = offset_in_segment_(idx);
    T* slot = seg + off;

    // Construct element
    ::new (static_cast<void*>(slot)) T(std::forward<Args>(args)...);

    // If you need "safe iteration by other threads that use size()",
    // you must publish construction completion separately, e.g.:
    //   constructed_[idx].store(true, release);
    // and readers check constructed_[i].load(acquire).
    return idx;
  }

  std::size_t push_back(const T& v) { return emplace_back(v); }
  std::size_t push_back(T&& v) { return emplace_back(std::move(v)); }

  // Returns pointer to element at i (no bounds check).
  // Safe if element i is fully constructed and lifetime is valid.
  T* data_at(std::size_t i) noexcept {
    T* seg = segment_ptr_(segment_index_(i));
    return seg + offset_in_segment_(i);
  }
  const T* data_at(std::size_t i) const noexcept {
    const T* seg = segment_ptr_(segment_index_(i));
    return seg + offset_in_segment_(i);
  }

  // Bounds-checked access (still not concurrent-safe unless you have a
  // protocol; see class docs).
  T& at(std::size_t i) {
    if (i >= size()) throw std::out_of_range("concurrent_vector::at");
    return *data_at(i);
  }
  const T& at(std::size_t i) const {
    if (i >= size()) throw std::out_of_range("concurrent_vector::at");
    return *data_at(i);
  }

  // Unchecked access
  T& operator[](std::size_t i) noexcept { return *data_at(i); }
  const T& operator[](std::size_t i) const noexcept { return *data_at(i); }

  // Capacity is segmented and unbounded until MaxSegments is exceeded.
  // This is the max number of elements representable by the segment scheme:
  //   Total capacity = Base * (2^MaxSegments - 1)
  // The static_assert above on MaxSegments keeps the shift well-defined.
  static constexpr std::size_t max_size() noexcept {
    return BaseSegmentSize * ((std::size_t{1} << MaxSegments) - 1);
  }

 private:
  // Segment k has size BaseSegmentSize * 2^k
  static constexpr std::size_t segment_size_(std::size_t k) noexcept {
    return BaseSegmentSize << k;
  }

  // Prefix count before segment k:
  //   Base * (2^k - 1)
  static constexpr std::size_t segment_prefix_(std::size_t k) noexcept {
    return BaseSegmentSize * ((std::size_t{1} << k) - 1);
  }

  // Map global index -> segment index.
  // Let q = idx / Base. Then segment = floor(log2(q + 1)).
  static std::size_t segment_index_(std::size_t idx) noexcept {
    const std::size_t q = idx / BaseSegmentSize;
    const std::size_t x = q + 1;

#if defined(__GNUG__) || defined(__clang__)
    // floor(log2(x)) via clz. Use the `long long` builtin: on LLP64 targets
    // (e.g. 64-bit Windows) `unsigned long` is 32 bits while size_t is 64, so
    // __builtin_clzl would pair the wrong operand width with the constant.
    return static_cast<std::size_t>(
        (sizeof(unsigned long long) * 8 - 1)
        - static_cast<std::size_t>(
              __builtin_clzll(static_cast<unsigned long long>(x))));
#else
    // portable fallback
    std::size_t s = 0;
    std::size_t t = x;
    while (t >>= 1) ++s;
    return s;
#endif
  }

  static std::size_t offset_in_segment_(std::size_t idx) noexcept {
    const std::size_t s = segment_index_(idx);
    return idx - segment_prefix_(s);
  }

  T* segment_ptr_(std::size_t s) noexcept {
    return static_cast<T*>(segments_[s].load(std::memory_order_acquire));
  }
  const T* segment_ptr_(std::size_t s) const noexcept {
    return static_cast<const T*>(segments_[s].load(std::memory_order_acquire));
  }

  // Look up segment for idx, allocating it if needed. Lock-free: losers of
  // the installation race free their allocation and use the winner's.
  T* ensure_segment_for_index_(std::size_t idx) {
    const std::size_t s = segment_index_(idx);
    if (s >= MaxSegments) {
      throw std::length_error("concurrent_vector: exceeded MaxSegments");
    }

    T* seg = segment_ptr_(s);
    if (seg) return seg;

    // Allocate segment lazily (raw storage for T objects)
    const std::size_t n = segment_size_(s);
    void* raw = ::operator new(sizeof(T) * n);
    T* fresh = static_cast<T*>(raw);

    // CAS install; if another thread won, free ours.
    void* expected = nullptr;
    if (!segments_[s].compare_exchange_strong(expected, fresh,
                                              std::memory_order_release,
                                              std::memory_order_acquire)) {
      ::operator delete(raw);
      seg = static_cast<T*>(segments_[s].load(std::memory_order_acquire));
      assert(seg != nullptr);
      return seg;
    }

    return fresh;
  }

  // Destroy constructed elements and free segments.
  // Not concurrent with pushes or reads.
  void destroy_all_() noexcept {
    const std::size_t n = size_.load(std::memory_order_acquire);

    // Destroy elements that were constructed.
    // NOTE: This assumes indices [0, n) are all constructed.
    // If you allow exceptions or partial construction, track constructed
    // flags.
    for (std::size_t i = 0; i < n; ++i) {
      data_at(i)->~T();
    }

    // Free segments
    for (std::size_t s = 0; s < segments_.size(); ++s) {
      void* p = segments_[s].load(std::memory_order_acquire);
      if (p) {
        ::operator delete(p);
        segments_[s].store(nullptr, std::memory_order_relaxed);
      }
    }
  }

  std::atomic<std::size_t> size_;
  // Fixed-size array instead of std::vector: std::atomic is neither copyable
  // nor movable, so std::vector<std::atomic<void*>>::resize() does not
  // compile (resize requires MoveInsertable). std::array needs no resize and
  // also avoids a heap allocation.
  std::array<std::atomic<void*>, MaxSegments> segments_;
};

}  // namespace internal
}  // namespace math
}  // namespace stan

#endif

stan/math/opencl/kernel_cl.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,17 +109,17 @@ inline void assign_events(const cl::Event& new_event, CallArg& m,
109109
* @return A vector of OpenCL events.
110110
*/
111111
template <typename T, require_not_matrix_cl_t<T>* = nullptr>
112-
inline tbb::concurrent_vector<cl::Event> select_events(const T& m) {
113-
return tbb::concurrent_vector<cl::Event>{};
112+
inline internal::concurrent_vector<cl::Event> select_events(const T& m) {
113+
return internal::concurrent_vector<cl::Event>{};
114114
}
115115
template <typename T, typename K, require_matrix_cl_t<K>* = nullptr,
116116
require_same_t<T, in_buffer>* = nullptr>
117-
inline const tbb::concurrent_vector<cl::Event>& select_events(const K& m) {
117+
inline const internal::concurrent_vector<cl::Event>& select_events(const K& m) {
118118
return m.write_events();
119119
}
120120
template <typename T, typename K, require_matrix_cl_t<K>* = nullptr,
121121
require_any_same_t<T, out_buffer, in_out_buffer>* = nullptr>
122-
inline tbb::concurrent_vector<cl::Event> select_events(K& m) {
122+
inline internal::concurrent_vector<cl::Event> select_events(K& m) {
123123
static_assert(!std::is_const<K>::value, "Can not write to const matrix_cl!");
124124
return m.read_write_events();
125125
}

stan/math/opencl/matrix_cl.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include <stan/math/prim/fun/Eigen.hpp>
1313
#include <stan/math/prim/fun/vec_concat.hpp>
1414
#include <CL/opencl.hpp>
15-
#include <tbb/concurrent_vector.h>
15+
#include <stan/math/opencl/concurrent_vector.hpp>
1616
#include <algorithm>
1717
#include <iostream>
1818
#include <string>
@@ -51,8 +51,8 @@ class matrix_cl : public matrix_cl_base {
5151
int cols_{0}; // Number of columns.
5252
// Holds info on if matrix is a special type
5353
matrix_cl_view view_{matrix_cl_view::Entire};
54-
mutable tbb::concurrent_vector<cl::Event> write_events_; // Tracks write jobs
55-
mutable tbb::concurrent_vector<cl::Event> read_events_; // Tracks reads
54+
mutable internal::concurrent_vector<cl::Event> write_events_; // Tracks write jobs
55+
mutable internal::concurrent_vector<cl::Event> read_events_; // Tracks reads
5656

5757
public:
5858
using Scalar = T; // Underlying type of the matrix
@@ -100,23 +100,23 @@ class matrix_cl : public matrix_cl_base {
100100
* Get the events from the event stacks.
101101
* @return The write event stack.
102102
*/
103-
inline const tbb::concurrent_vector<cl::Event>& write_events() const {
103+
inline const internal::concurrent_vector<cl::Event>& write_events() const {
104104
return write_events_;
105105
}
106106

107107
/**
108108
* Get the events from the event stacks.
109109
* @return The read/write event stack.
110110
*/
111-
inline const tbb::concurrent_vector<cl::Event>& read_events() const {
111+
inline const internal::concurrent_vector<cl::Event>& read_events() const {
112112
return read_events_;
113113
}
114114

115115
/**
116116
* Get the events from the event stacks.
117117
* @return The read/write event stack.
118118
*/
119-
inline const tbb::concurrent_vector<cl::Event> read_write_events() const {
119+
inline const internal::concurrent_vector<cl::Event> read_write_events() const {
120120
return vec_concat(this->read_events(), this->write_events());
121121
}
122122

stan/math/opencl/opencl_context.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include <stan/math/opencl/err/check_opencl.hpp>
1515

1616
#include <CL/opencl.hpp>
17-
#include <tbb/concurrent_vector.h>
17+
#include <stan/math/opencl/concurrent_vector.hpp>
1818
#include <string>
1919
#include <iostream>
2020
#include <fstream>
@@ -208,7 +208,7 @@ class opencl_context_base {
208208
* The API to access the methods and values in opencl_context_base
209209
*/
210210
class opencl_context {
211-
tbb::concurrent_vector<cl::Kernel*> kernel_caches_;
211+
internal::concurrent_vector<cl::Kernel*> kernel_caches_;
212212

213213
public:
214214
opencl_context() = default;

0 commit comments

Comments
 (0)