-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathconcurrent_queue.h
202 lines (172 loc) · 6.24 KB
/
concurrent_queue.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
#ifndef CONCURRENT_QUEUE_H
#define CONCURRENT_QUEUE_H
#include <atomic>
#include <cstddef>
#include <type_traits>
#include <utility>

#include "HazardPointer/reclaimer.h"
template <typename T>
class ConcreteReclaimer;

// Concurrent FIFO queue built on a singly linked list with a dummy head node
// (the Michael–Scott queue design): producers CAS a new node into the last
// node's `next`, consumers CAS head_ forward. Unlinked nodes are handed to a
// per-thread hazard-pointer Reclaimer instead of being deleted immediately,
// so a node another thread has protected is not freed underneath it.
//
// Element requirements: T must be copy constructible (enforced by the
// static_assert below) and assignable from an rvalue T (Dequeue move-assigns
// into the caller's object).
template <typename T>
class ConcurrentQueue {
  static_assert(std::is_copy_constructible_v<T>, "T requires copy constructor");
  struct Node;
  struct RegularNode;
  // The reclaimer reads the private global_hp_list_.
  friend ConcreteReclaimer<T>;

 public:
  // Creates an empty queue: head_ and tail_ both point at one dummy Node.
  ConcurrentQueue()
      : head_(new Node),
        tail_(head_.load(std::memory_order_relaxed)),
        size_(0) {}

  // Releases every node still linked from head_ (the dummy plus all queued
  // elements). Nodes already unlinked by Dequeue were handed to the reclaimer
  // and are not on this list. No other thread may still be using the queue.
  ~ConcurrentQueue() {
    Node* p = head_.load(std::memory_order_acquire);
    while (p != nullptr) {
      Node* tmp = p;
      p = p->next.load(std::memory_order_acquire);
      tmp->Release();
    }
  }

  ConcurrentQueue(const ConcurrentQueue&) = delete;
  ConcurrentQueue(ConcurrentQueue&&) = delete;
  ConcurrentQueue& operator=(const ConcurrentQueue& other) = delete;
  ConcurrentQueue& operator=(ConcurrentQueue&& other) = delete;

  // Constructs a T in place from `args` and appends it at the tail.
  template <typename... Args>
  void Emplace(Args&&... args);

  // Appends a copy of `value`.
  void Enqueue(const T& value) {
    static_assert(std::is_copy_constructible<T>::value,
                  "T must be copy constructible");
    Emplace(value);
  }

  // Appends `value` by moving it.
  void Enqueue(T&& value) {
    static_assert(std::is_constructible_v<T, T&&>,
                  "T must be constructible with T&&");
    // T is the class parameter, not a deduced one, so `value` is a plain
    // rvalue reference; std::move states that intent directly.
    Emplace(std::move(value));
  }

  // Removes the front element into `value`. Returns false, leaving `value`
  // untouched, if the queue was observed empty.
  bool Dequeue(T& value);

  // Element count. The counter is updated with relaxed ordering, separately
  // from the link/unlink CAS, so under concurrency this is only a snapshot.
  size_t size() const { return size_.load(std::memory_order_relaxed); }

 private:
  Node* get_head() const { return head_.load(std::memory_order_acquire); }
  Node* get_tail() const { return tail_.load(std::memory_order_acquire); }

  // Loads `atomic_node` and its successor, protecting both with hazard
  // pointers, and retries until each protected value is re-read unchanged.
  // REQUIRE: atomic_node is head_ or tail_.
  void AcquireSafeNodeAndNext(std::atomic<Node*>& atomic_node, Node** node_ptr,
                              Node** next_ptr, HazardPointer& node_hp,
                              HazardPointer& next_hp);

  // Deleter passed to Reclaimer::ReclaimLater; `ptr` is a retired Node*.
  static void OnDeleteNode(void* ptr) { static_cast<Node*>(ptr)->Release(); }

  // Plain list node; used directly only for the dummy created by the
  // constructor (it carries no value).
  struct Node {
    Node() : next(nullptr) {}
    virtual ~Node() = default;
    virtual void Release() { delete this; }
    Node* get_next() const { return next.load(std::memory_order_acquire); }
    // Set once, from nullptr, by the CAS in Emplace.
    std::atomic<Node*> next;
  };

  // Node that carries a queued value.
  struct RegularNode : Node {
    // explicit: this greedy variadic constructor must never act as an
    // implicit conversion.
    template <typename... Args>
    explicit RegularNode(Args&&... args) : value(std::forward<Args>(args)...) {}
    ~RegularNode() override = default;
    void Release() override { delete this; }
    T value;
  };

  std::atomic<Node*> head_;  // Dummy node; its successor is the front.
  std::atomic<Node*> tail_;  // Last node, or lagging by one behind it.
  std::atomic<size_t> size_;
  // Hazard-pointer list shared by all ConcurrentQueue<T> instances of this T.
  static Reclaimer::HazardPointerList global_hp_list_;
};

template <typename T>
Reclaimer::HazardPointerList ConcurrentQueue<T>::global_hp_list_;
// Reclaimer bound to ConcurrentQueue<T>::global_hp_list_. Only
// ConcurrentQueue<T> may create or reach it, through GetInstance(), which
// hands each thread its own thread_local instance.
template <typename T>
class ConcreteReclaimer : public Reclaimer {
  friend ConcurrentQueue<T>;

 private:
  // explicit: a HazardPointerList must not silently convert into a reclaimer.
  explicit ConcreteReclaimer(HazardPointerList& hp_list) : Reclaimer(hp_list) {}
  ~ConcreteReclaimer() override = default;

  // Returns the calling thread's reclaimer, constructing it on first use.
  static ConcreteReclaimer<T>& GetInstance() {
    thread_local static ConcreteReclaimer reclaimer(
        ConcurrentQueue<T>::global_hp_list_);
    return reclaimer;
  }
};
// Reads the current value of `atomic_node` (head_ or tail_) and its successor,
// publishing a hazard pointer on each and retrying until each published
// pointer has been re-read unchanged afterwards.
//
// On return:
//   *node_ptr - a node that was still the value of atomic_node after node_hp
//               was published; it stays protected by node_hp.
//   *next_ptr - (*node_ptr)->next as re-read after next_hp was published;
//               may be nullptr. Protected by next_hp.
// The re-read of node->next does not by itself show that `next` is still in
// the list, so callers re-check that atomic_node still equals *node_ptr after
// this returns (Emplace and Dequeue do) before relying on *next_ptr.
// NOTE(review): `next` may be nullptr when it is wrapped in a HazardPointer;
// assumes HazardPointer accepts nullptr - confirm in HazardPointer/reclaimer.h.
template <typename T>
void ConcurrentQueue<T>::AcquireSafeNodeAndNext(std::atomic<Node*>& atomic_node,
Node** node_ptr,
Node** next_ptr,
HazardPointer& node_hp,
HazardPointer& next_hp) {
Node* node = atomic_node.load(std::memory_order_acquire);
Node* next;
Node* temp_node;
Node* temp_next;
auto& reclaimer = ConcreteReclaimer<T>::GetInstance();
do {
do {
// 1. UnMark the previous node hazard pointer (presumably clears any
//    protection published by an earlier iteration).
node_hp.UnMark();
temp_node = node;
// 2. Publish a hazard pointer on the candidate node.
node_hp = HazardPointer(&reclaimer, node);
node = atomic_node.load(std::memory_order_acquire);
// 3. Retry unless atomic_node still holds the candidate after it was
//    marked, i.e. it was not replaced before becoming protected.
} while (temp_node != node);
// 4. UnMark the previous successor hazard pointer.
next_hp.UnMark();
next = node->get_next();
temp_next = next;
// 5. Publish a hazard pointer on the successor.
next_hp = HazardPointer(&reclaimer, next);
next = node->get_next();
// 6. Retry (re-protecting node as well) unless node->next is unchanged
//    after the successor was marked.
} while (temp_next != next);
*node_ptr = node;
*next_ptr = next;
}
// Constructs a T in place from `args` and appends it at the tail.
//
// Finds the last node (helping a lagging tail_ forward), CASes the new node
// into its `next`, then makes one attempt to swing tail_ to the new node. If
// that attempt fails, tail_ lags by one; later Emplace/Dequeue calls help
// advance it.
template <typename T>
template <typename... Args>
void ConcurrentQueue<T>::Emplace(Args&&... args) {
  static_assert(std::is_constructible_v<T, Args&&...>,
                "T must be constructible with Args&&...");
  RegularNode* new_node = new RegularNode(std::forward<Args>(args)...);
  // Count the element before it becomes reachable. Incrementing only after
  // the link lets a concurrent Dequeue remove the node and fetch_sub first,
  // transiently wrapping size_ to SIZE_MAX. The link CAS below (release) is
  // read by the dequeuer's get_next() (acquire), so this increment
  // happens-before the matching decrement.
  size_.fetch_add(1, std::memory_order_relaxed);
  Node* tail;
  Node* next;
  HazardPointer tail_hp, next_hp;
  while (true) {
    AcquireSafeNodeAndNext(tail_, &tail, &next, tail_hp, next_hp);
    // tail must still be tail_ for `next` to be its live successor.
    if (tail != get_tail()) continue;
    if (nullptr == next) {
      // tail is the last node: try to link the new node after it.
      if (tail->next.compare_exchange_strong(next, new_node)) break;
    } else {
      // tail_ is lagging behind a linked node: help swing it, then retry.
      tail_.compare_exchange_weak(tail, next);
    }
  }
  // Single attempt, so use the strong form: a spurious failure would leave
  // tail_ lagging for no reason.
  tail_.compare_exchange_strong(tail, new_node);
}
// Removes the oldest element and move-assigns it into `value`.
// Returns false, leaving `value` untouched, when the queue is observed empty.
//
// The front element lives in the successor of the dummy head. Once head_ is
// swung to that successor, it becomes the new dummy and the old dummy is
// handed to the reclaimer.
template <typename T>
bool ConcurrentQueue<T>::Dequeue(T& value) {
  HazardPointer head_hp;
  HazardPointer next_hp;
  Node* head;
  Node* next;
  Node* tail;
  while (true) {
    AcquireSafeNodeAndNext(head_, &head, &next, head_hp, next_hp);
    tail = get_tail();
    // head, next and tail are only consistent if head_ has not moved.
    if (head != get_head()) continue;
    // Test for emptiness before comparing with tail. `tail` is read after
    // `next`, so an enqueue in between can make head != tail while `next`
    // is still the stale nullptr. Swinging head_ then would store nullptr
    // into head_ and dereference it below.
    if (nullptr == next) return false;  // Queue is empty.
    if (head == tail) {
      // tail_ is lagging behind a linked node: help advance it and retry.
      tail_.compare_exchange_weak(tail, next);
      continue;
    }
    // Try to swing head to the next node.
    if (head_.compare_exchange_strong(head, next)) break;
  }
  size_.fetch_sub(1, std::memory_order_relaxed);
  auto& reclaimer = ConcreteReclaimer<T>::GetInstance();
  // The old dummy is unlinked; hand it to the reclaimer, which presumably
  // frees it once no hazard pointer still protects it.
  reclaimer.ReclaimLater(head, ConcurrentQueue<T>::OnDeleteNode);
  reclaimer.ReclaimNoHazardPointer();
  // `next` (now the dummy) is still protected by next_hp, and no thread reads
  // a dummy's value, so moving out of it after the CAS is safe.
  value = std::move(static_cast<RegularNode*>(next)->value);
  return true;
}
#endif // CONCURRENT_QUEUE_H