Я пытаюсь реализовать кольцевой буфер на основе массива, который является потокобезопасным для нескольких производителей и одного потребителя.Основная идея - иметь атомные индексы головы и хвоста.При перемещении элемента в очередь заголовок атомарно увеличивается, чтобы зарезервировать слот в буфере:
#include <atomic>
#include <chrono>
#include <iostream>
#include <stdexcept>
#include <thread>
#include <vector>
template <class T> class MPSC {
private:
int MAX_SIZE;
std::atomic<int> head{0}; ///< index of first free slot
std::atomic<int> tail{0}; ///< index of first occupied slot
std::unique_ptr<T[]> data;
std::unique_ptr<std::atomic<bool>[]> valid; ///< indicates whether data at an
///< index has been fully written
/// Compute next index modulo size.
inline int advance(int x) { return (x + 1) % MAX_SIZE; }
public:
explicit MPSC(int size) {
if (size <= 0)
throw std::invalid_argument("size must be greater than 0");
MAX_SIZE = size + 1;
data = std::make_unique<T[]>(MAX_SIZE);
valid = std::make_unique<std::atomic<bool>[]>(MAX_SIZE);
}
/// Add an element to the queue.
///
/// If the queue is full, this method blocks until a slot is available for
/// writing. This method is not starvation-free, i.e. it is possible that one
/// thread always fills up the queue and prevents others from pushing.
void push(const T &msg) {
int idx;
int next_idx;
int k = 100;
do {
idx = head;
next_idx = advance(idx);
while (next_idx == tail) { // queue is full
k = k >= 100000 ? k : k * 2; // exponential backoff
std::this_thread::sleep_for(std::chrono::nanoseconds(k));
} // spin
} while (!head.compare_exchange_weak(idx, next_idx));
if (valid[idx])
// this throws, suggesting that two threads are writing to the same index. I have no idea how this is possible.
throw std::runtime_error("message slot already written");
data[idx] = msg;
valid[idx] = true; // this was set to false by the reader,
// set it to true to indicate completed data write
}
/// Read an element from the queue.
///
/// If the queue is empty, this method blocks until a message is available.
/// This method is only safe to be called from one single reader thread.
T pop() {
int k = 100;
while (is_empty() || !valid[tail]) {
k = k >= 100000 ? k : k * 2;
std::this_thread::sleep_for(std::chrono::nanoseconds(k));
} // spin
T res = data[tail];
valid[tail] = false;
tail = advance(tail);
return res;
}
bool is_full() { return (head + 1) % MAX_SIZE == tail; }
bool is_empty() { return head == tail; }
};
Когда имеется много перегрузок, некоторые сообщения перезаписываются другими потоками.Следовательно, в том, что я здесь делаю, должно быть что-то в корне не так.
Кажется, что происходит то, что два потока получают один и тот же индекс для записи своих данных.Почему это может быть?
Даже если производитель должен был сделать паузу непосредственно перед записью своих данных, хвост не мог бы пройти мимо этого потока idx, и, следовательно, ни один другой поток не мог бы обогнать и потребовать тот же idx.
РЕДАКТИРОВАТЬ
С риском размещения слишком большого количества кода, вот простая программа, которая воспроизводит проблему.Он отправляет некоторые инкрементные числа из множества потоков и проверяет, все ли числа получены получателем:
#include "mpsc.hpp" // or whatever; the above queue
#include <thread>
#include <iostream>
int main() {
static constexpr int N_THREADS = 10; ///< number of threads
static constexpr int N_MSG = 1E+5; ///< number of messages per thread
struct msg {
int t_id;
int i;
};
MPSC<msg> q(N_THREADS / 2);
std::thread threads[N_THREADS];
// consumer
threads[0] = std::thread([&q] {
int expected[N_THREADS] {};
for (int i = 0; i < N_MSG * (N_THREADS - 1); ++i) {
msg m = q.pop();
std::cout << "Got message from T-" << m.t_id << ": " << m.i << std::endl;
if (expected[m.t_id] != m.i) {
std::cout << "T-" << m.t_id << " unexpected msg " << m.i << "; expected " << expected[m.t_id] << std::endl;
return -1;
}
expected[m.t_id] = m.i + 1;
}
});
// producers
for (int id = 1; id < N_THREADS; ++id) {
threads[id] = std::thread([id, &q] {
for (int i = 0; i < N_MSG; ++i) {
q.push(msg{id, i});
}
});
}
for (auto &t : threads)
t.join();
}