|
| 1 | +#ifndef STAN_MATH_OPENCL_CONCURRENT_VECTOR_HPP |
| 2 | +#define STAN_MATH_OPENCL_CONCURRENT_VECTOR_HPP |
| 3 | + |
| 4 | +#include <atomic> |
| 5 | +#include <cstddef> |
| 6 | +#include <cstdint> |
| 7 | +#include <new> |
| 8 | +#include <type_traits> |
| 9 | +#include <utility> |
| 10 | +#include <vector> |
| 11 | +#include <stdexcept> |
| 12 | +#include <cassert> |
| 13 | + |
| 14 | +namespace stan { |
| 15 | +namespace math { |
| 16 | +namespace internal { |
| 17 | + |
/**
 * Minimal segmented concurrent vector.
 *
 * Key properties:
 *  - concurrent emplace_back/push_back using an atomic size counter;
 *  - segmented storage: elements never move on growth, so addresses
 *    returned by data_at()/operator[] remain stable for the container's
 *    lifetime;
 *  - segments are allocated lazily; installation uses a CAS, no locks.
 *
 * Concurrency contract:
 *  - size() is incremented BEFORE the element at the claimed index is
 *    constructed, so "i < size()" alone does NOT make reading index i safe
 *    from another thread.  Consumers need a stronger protocol (e.g. the
 *    producer hands out the index, or a per-element "constructed" bitmap).
 *  - clear() and destruction must NOT run concurrently with pushes/reads.
 *
 * NOTE(review): if T's constructor throws inside emplace_back(), the claimed
 * slot is already counted by size_ but left unconstructed; destroy_all_()
 * would then run ~T() on raw storage.  Use only element types whose chosen
 * constructor does not throw, or add a constructed bitmap.
 *
 * @tparam T               element type
 * @tparam BaseSegmentSize element count of segment 0; segment k holds
 *                         BaseSegmentSize * 2^k elements (power of two)
 * @tparam MaxSegments     maximum number of segments
 */
template <typename T,
          std::size_t BaseSegmentSize = 1024,
          std::size_t MaxSegments = 32>
class concurrent_vector {
  static_assert(BaseSegmentSize > 0, "BaseSegmentSize must be > 0");
  static_assert((BaseSegmentSize & (BaseSegmentSize - 1)) == 0,
                "BaseSegmentSize must be a power of two (helps mapping).");
  // Guards the shifts in segment_prefix_() and max_size():
  // (size_t{1} << MaxSegments) is UB if MaxSegments >= bit width of size_t.
  static_assert(MaxSegments < sizeof(std::size_t) * 8,
                "MaxSegments must be smaller than the bit width of size_t");

 public:
  // The segment table is sized in the member-initializer list:
  // std::vector<std::atomic<void*>>::resize() does not compile because
  // std::atomic is not move-insertable, whereas the count constructor only
  // requires default-insertable elements.  The explicit stores below give
  // every slot a defined nullptr value (the atomics' default construction
  // leaves them uninitialized pre-C++20).
  concurrent_vector() : size_(0), segments_(MaxSegments) {
    for (auto& p : segments_) p.store(nullptr, std::memory_order_relaxed);
  }

  concurrent_vector(const concurrent_vector&) = delete;
  concurrent_vector& operator=(const concurrent_vector&) = delete;

  ~concurrent_vector() { destroy_all_(); }

  /// Number of claimed slots (may exceed the number of fully constructed
  /// elements while a concurrent push is in flight — see class comment).
  std::size_t size() const noexcept {
    return size_.load(std::memory_order_acquire);
  }

  bool empty() const noexcept { return size() == 0; }

  /// Non-concurrent: safe only when no other threads are pushing/reading.
  void clear() {
    destroy_all_();
    size_.store(0, std::memory_order_release);
  }

  /**
   * Concurrent append (construct in place).
   *
   * @param args forwarded to T's constructor
   * @return the index of the new element (only the calling thread may use
   *         it without a stronger publication protocol)
   * @throws std::length_error if the segment scheme is exhausted
   */
  template <typename... Args>
  std::size_t emplace_back(Args&&... args) {
    // Claim an index first; this is what makes concurrent pushes disjoint.
    const std::size_t idx = size_.fetch_add(1, std::memory_order_acq_rel);

    // Lazily allocate / look up the segment holding idx.
    T* seg = ensure_segment_for_index_(idx);

    // Placement-new into the correct slot.
    const std::size_t off = offset_in_segment_(idx);
    T* slot = seg + off;

    ::new (static_cast<void*>(slot)) T(std::forward<Args>(args)...);

    // If other threads must be able to iterate up to size() safely, publish
    // construction completion separately, e.g.:
    //   constructed_[idx].store(true, release);
    // with readers checking constructed_[i].load(acquire).
    return idx;
  }

  std::size_t push_back(const T& v) { return emplace_back(v); }
  std::size_t push_back(T&& v) { return emplace_back(std::move(v)); }

  /// Pointer to element i (no bounds check).  Safe iff element i is fully
  /// constructed and its lifetime has not ended.
  T* data_at(std::size_t i) noexcept {
    T* seg = segment_ptr_(segment_index_(i));
    return seg + offset_in_segment_(i);
  }
  const T* data_at(std::size_t i) const noexcept {
    const T* seg = segment_ptr_(segment_index_(i));
    return seg + offset_in_segment_(i);
  }

  /// Bounds-checked access (still not concurrent-safe without a protocol).
  /// @throws std::out_of_range if i >= size()
  T& at(std::size_t i) {
    if (i >= size()) throw std::out_of_range("concurrent_vector::at");
    return *data_at(i);
  }
  const T& at(std::size_t i) const {
    if (i >= size()) throw std::out_of_range("concurrent_vector::at");
    return *data_at(i);
  }

  /// Unchecked access.
  T& operator[](std::size_t i) noexcept { return *data_at(i); }
  const T& operator[](std::size_t i) const noexcept { return *data_at(i); }

  /// Maximum number of elements representable by the segment scheme:
  /// BaseSegmentSize * (2^MaxSegments - 1).  The shift is guarded by the
  /// MaxSegments static_assert above.
  static constexpr std::size_t max_size() noexcept {
    return BaseSegmentSize * ((std::size_t{1} << MaxSegments) - 1);
  }

 private:
  // Segment k has size BaseSegmentSize * 2^k.
  static constexpr std::size_t segment_size_(std::size_t k) noexcept {
    return BaseSegmentSize << k;
  }

  // Number of elements in all segments before segment k:
  // BaseSegmentSize * (2^k - 1).
  static constexpr std::size_t segment_prefix_(std::size_t k) noexcept {
    return BaseSegmentSize * ((std::size_t{1} << k) - 1);
  }

  // Map global index -> segment index.
  // Let q = idx / BaseSegmentSize.  Then segment = floor(log2(q + 1)).
  static std::size_t segment_index_(std::size_t idx) noexcept {
    const std::size_t q = idx / BaseSegmentSize;
    const std::size_t x = q + 1;  // x >= 1, so clz/log2 is well-defined

#if defined(__GNUG__) || defined(__clang__)
    // floor(log2(x)) via count-leading-zeros.  Use the `ll` variant: on
    // LLP64 targets (e.g. MinGW) `unsigned long` is 32-bit while size_t is
    // 64-bit, so __builtin_clzl would truncate large indices.
    return (sizeof(unsigned long long) * 8 - 1)
           - static_cast<std::size_t>(
                 __builtin_clzll(static_cast<unsigned long long>(x)));
#else
    // Portable fallback: count how many times x can be halved.
    std::size_t s = 0;
    std::size_t t = x;
    while (t >>= 1) ++s;
    return s;
#endif
  }

  static std::size_t offset_in_segment_(std::size_t idx) noexcept {
    const std::size_t s = segment_index_(idx);
    return idx - segment_prefix_(s);
  }

  T* segment_ptr_(std::size_t s) noexcept {
    return static_cast<T*>(segments_[s].load(std::memory_order_acquire));
  }
  const T* segment_ptr_(std::size_t s) const noexcept {
    return static_cast<const T*>(segments_[s].load(std::memory_order_acquire));
  }

  // Raw storage for n objects of T, honoring over-alignment where the
  // platform supports C++17 aligned allocation; plain operator new would
  // under-align types with alignof(T) > __STDCPP_DEFAULT_NEW_ALIGNMENT__.
  static void* allocate_raw_(std::size_t n) {
#if defined(__cpp_aligned_new)
    if (alignof(T) > __STDCPP_DEFAULT_NEW_ALIGNMENT__) {
      return ::operator new(sizeof(T) * n, std::align_val_t{alignof(T)});
    }
#endif
    return ::operator new(sizeof(T) * n);
  }

  // Must mirror allocate_raw_: aligned allocations need the aligned delete.
  static void deallocate_raw_(void* p) noexcept {
#if defined(__cpp_aligned_new)
    if (alignof(T) > __STDCPP_DEFAULT_NEW_ALIGNMENT__) {
      ::operator delete(p, std::align_val_t{alignof(T)});
      return;
    }
#endif
    ::operator delete(p);
  }

  // Return the segment holding idx, allocating and CAS-installing it if it
  // does not exist yet.  Lock-free: the losing allocator frees its block.
  T* ensure_segment_for_index_(std::size_t idx) {
    const std::size_t s = segment_index_(idx);
    if (s >= MaxSegments) {
      throw std::length_error("concurrent_vector: exceeded MaxSegments");
    }

    T* seg = segment_ptr_(s);
    if (seg) return seg;

    // Allocate segment lazily (raw storage; elements are placement-new'd).
    const std::size_t n = segment_size_(s);
    void* raw = allocate_raw_(n);
    T* fresh = static_cast<T*>(raw);

    // CAS install; if another thread won the race, free ours and use theirs.
    void* expected = nullptr;
    if (!segments_[s].compare_exchange_strong(
            expected, fresh,
            std::memory_order_release,
            std::memory_order_acquire)) {
      deallocate_raw_(raw);
      seg = static_cast<T*>(segments_[s].load(std::memory_order_acquire));
      assert(seg != nullptr);
      return seg;
    }

    return fresh;
  }

  // Destroy constructed elements and free segments.
  // Not concurrent with pushes or reads.
  void destroy_all_() noexcept {
    const std::size_t n = size_.load(std::memory_order_acquire);

    // Destroy elements that were constructed.
    // NOTE: assumes indices [0, n) are all constructed — see the class
    // comment about throwing constructors / partial construction.
    if (!std::is_trivially_destructible<T>::value) {
      for (std::size_t i = 0; i < n; ++i) {
        data_at(i)->~T();
      }
    }

    // Free segments (with the matching, possibly aligned, delete).
    for (std::size_t s = 0; s < segments_.size(); ++s) {
      void* p = segments_[s].load(std::memory_order_acquire);
      if (p) {
        deallocate_raw_(p);
        segments_[s].store(nullptr, std::memory_order_relaxed);
      }
    }
  }

  std::atomic<std::size_t> size_;
  std::vector<std::atomic<void*>> segments_;
};
| 222 | +} |
| 223 | +} |
| 224 | +} |
| 225 | + |
| 226 | +#endif |
0 commit comments