1use core::{cell::UnsafeCell, mem::MaybeUninit};
8#[cfg(not(loom))]
9use std::sync::{
10 Arc,
11 atomic::{AtomicUsize, Ordering},
12};
13
14#[cfg(loom)]
15use loom::sync::{
16 Arc,
17 atomic::{AtomicUsize, Ordering},
18};
19
/// Wraps a value so that it starts on its own 64-byte boundary.
///
/// `repr(align(64))` also rounds the wrapper's size up to a multiple of 64, so
/// two `CachePadded` fields never share a 64-byte line. The ring buffer uses
/// this to keep its `head` and `tail` counters from false sharing.
///
/// NOTE(review): some targets effectively prefetch/contend in 128-byte units;
/// consider a larger alignment there if profiling shows contention.
#[repr(align(64))]
struct CachePadded<T>(T);

impl<T> CachePadded<T> {
    /// Wraps `value` with cache-line alignment.
    const fn new(value: T) -> Self {
        Self(value)
    }
}

impl<T> core::ops::Deref for CachePadded<T> {
    type Target = T;

    /// Borrows the wrapped value.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let Self(inner) = self;
        inner
    }
}
42
43struct SpscRingBufferInner<T> {
45 buffer: Box<[UnsafeCell<MaybeUninit<T>>]>,
46 capacity: usize,
47 mask: usize,
48 head: CachePadded<AtomicUsize>, tail: CachePadded<AtomicUsize>, }
51
52unsafe impl<T: Send> Send for SpscRingBufferInner<T> {}
55unsafe impl<T: Send> Sync for SpscRingBufferInner<T> {}
56
57impl<T> SpscRingBufferInner<T> {
58 fn new(capacity: usize) -> Self {
59 let capacity = capacity.next_power_of_two().max(1);
60 let mask = capacity - 1;
61
62 let buffer: Vec<UnsafeCell<MaybeUninit<T>>> =
63 (0..capacity).map(|_| UnsafeCell::new(MaybeUninit::uninit())).collect();
64
65 Self {
66 buffer: buffer.into_boxed_slice(),
67 capacity,
68 mask,
69 head: CachePadded::new(AtomicUsize::new(0)),
70 tail: CachePadded::new(AtomicUsize::new(0)),
71 }
72 }
73}
74
75impl<T> Drop for SpscRingBufferInner<T> {
76 fn drop(&mut self) {
77 let head = self.head.load(Ordering::Relaxed);
78 let tail = self.tail.load(Ordering::Relaxed);
79
80 for i in tail..head {
81 let index = i & self.mask;
82 unsafe {
84 let ptr = (*self.buffer[index].get()).as_mut_ptr();
85 core::ptr::drop_in_place(ptr);
86 }
87 }
88 }
89}
90
91pub struct SpscRingBuffer;
93
94impl SpscRingBuffer {
95 #[allow(clippy::new_ret_no_self)]
111 pub fn new<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
112 let inner = Arc::new(SpscRingBufferInner::new(capacity));
113 (
114 Producer {
115 inner: Arc::clone(&inner),
116 },
117 Consumer { inner },
118 )
119 }
120}
121
122pub struct Producer<T> {
126 inner: Arc<SpscRingBufferInner<T>>,
127}
128
129unsafe impl<T: Send> Send for Producer<T> {}
131
132impl<T> Producer<T> {
133 #[inline]
138 pub fn try_push(&mut self, value: T) -> Result<(), T> {
139 let head = self.inner.head.load(Ordering::Relaxed);
140 let tail = self.inner.tail.load(Ordering::Acquire);
141
142 if head.wrapping_sub(tail) >= self.inner.capacity {
143 return Err(value);
144 }
145
146 let index = head & self.inner.mask;
147 unsafe {
149 (*self.inner.buffer[index].get()).write(value);
150 }
151
152 self.inner.head.store(head.wrapping_add(1), Ordering::Release);
153 Ok(())
154 }
155
156 #[inline]
160 pub fn len(&self) -> usize {
161 let head = self.inner.head.load(Ordering::Relaxed);
162 let tail = self.inner.tail.load(Ordering::Relaxed);
163 head.wrapping_sub(tail)
164 }
165
166 #[inline]
168 pub fn is_full(&self) -> bool {
169 self.len() >= self.inner.capacity
170 }
171
172 #[inline]
174 pub fn is_empty(&self) -> bool {
175 self.len() == 0
176 }
177
178 #[inline]
180 pub fn capacity(&self) -> usize {
181 self.inner.capacity
182 }
183}
184
185pub struct Consumer<T> {
189 inner: Arc<SpscRingBufferInner<T>>,
190}
191
192unsafe impl<T: Send> Send for Consumer<T> {}
194
195impl<T> Consumer<T> {
196 #[inline]
201 pub fn try_pop(&mut self) -> Option<T> {
202 let tail = self.inner.tail.load(Ordering::Relaxed);
203 let head = self.inner.head.load(Ordering::Acquire);
204
205 if tail >= head {
206 return None;
207 }
208
209 let index = tail & self.inner.mask;
210 let value = unsafe { (*self.inner.buffer[index].get()).assume_init_read() };
212
213 self.inner.tail.store(tail.wrapping_add(1), Ordering::Release);
214 Some(value)
215 }
216
217 #[inline]
221 pub fn len(&self) -> usize {
222 let head = self.inner.head.load(Ordering::Relaxed);
223 let tail = self.inner.tail.load(Ordering::Relaxed);
224 head.wrapping_sub(tail)
225 }
226
227 #[inline]
229 pub fn is_empty(&self) -> bool {
230 self.len() == 0
231 }
232
233 #[inline]
235 pub fn is_full(&self) -> bool {
236 self.len() >= self.inner.capacity
237 }
238
239 #[inline]
241 pub fn capacity(&self) -> usize {
242 self.inner.capacity
243 }
244}
245
#[cfg(all(test, not(loom)))]
mod tests {
    use std::{rc::Rc, thread};

    use super::*;

    #[test]
    fn test_basic_push_pop() {
        let (mut producer, mut consumer) = SpscRingBuffer::new::<i32>(4);

        for value in 1..=3 {
            assert!(producer.try_push(value).is_ok());
        }
        for expected in 1..=3 {
            assert_eq!(consumer.try_pop(), Some(expected));
        }
        // Everything that was pushed has been drained.
        assert_eq!(consumer.try_pop(), None);
    }

    #[test]
    fn test_empty_buffer() {
        let (_producer, mut consumer) = SpscRingBuffer::new::<i32>(4);

        assert!(consumer.is_empty());
        assert!(consumer.try_pop().is_none());
    }

    #[test]
    fn test_full_buffer() {
        let (mut producer, _consumer) = SpscRingBuffer::new::<i32>(4);

        (1..=4).for_each(|value| assert!(producer.try_push(value).is_ok()));
        assert!(producer.is_full());

        // A rejected push hands the value back unchanged.
        assert_eq!(producer.try_push(5), Err(5));
    }

    #[test]
    fn test_capacity_rounding() {
        // (requested, effective): non-powers of two round up, and 0 becomes 1.
        for (requested, effective) in [(3, 4), (5, 8), (0, 1)] {
            let (producer, _consumer) = SpscRingBuffer::new::<i32>(requested);
            assert_eq!(producer.capacity(), effective);
        }
    }

    #[test]
    fn test_wraparound() {
        let (mut producer, mut consumer) = SpscRingBuffer::new::<i32>(4);

        // Repeatedly fill and drain the 4-slot ring so positions lap the buffer.
        for round in 0..10 {
            let base = round * 10;
            for offset in 0..4 {
                assert!(producer.try_push(base + offset).is_ok());
            }
            for offset in 0..4 {
                assert_eq!(consumer.try_pop(), Some(base + offset));
            }
        }
    }

    #[test]
    fn test_len() {
        let (mut producer, mut consumer) = SpscRingBuffer::new::<i32>(4);
        assert_eq!((producer.len(), consumer.len()), (0, 0));

        producer.try_push(1).unwrap();
        assert_eq!((producer.len(), consumer.len()), (1, 1));

        producer.try_push(2).unwrap();
        assert_eq!(producer.len(), 2);

        consumer.try_pop();
        assert_eq!(consumer.len(), 1);
    }

    #[test]
    fn test_concurrent_push_pop() {
        const NUM_ITEMS: i32 = 10_000;

        let (mut producer, mut consumer) = SpscRingBuffer::new::<i32>(1024);

        let producer_thread = thread::spawn(move || {
            for value in 0..NUM_ITEMS {
                // Back off until the consumer frees a slot.
                while producer.try_push(value).is_err() {
                    thread::yield_now();
                }
            }
        });

        let consumer_thread = thread::spawn(move || {
            let mut received = Vec::with_capacity(NUM_ITEMS as usize);
            while received.len() < NUM_ITEMS as usize {
                match consumer.try_pop() {
                    Some(value) => received.push(value),
                    None => thread::yield_now(),
                }
            }
            received
        });

        producer_thread.join().unwrap();
        let received = consumer_thread.join().unwrap();

        // Same count and same order: the ring is strictly FIFO.
        let expected: Vec<i32> = (0..NUM_ITEMS).collect();
        assert_eq!(received, expected);
    }

    #[test]
    fn test_drop_remaining_items() {
        let tracker = Rc::new(());

        {
            let (mut producer, _consumer) = SpscRingBuffer::new::<Rc<()>>(4);
            for _ in 0..3 {
                producer.try_push(Rc::clone(&tracker)).unwrap();
            }
            // Three clones live inside the buffer, plus the original.
            assert_eq!(Rc::strong_count(&tracker), 4);
        }

        // Dropping both handles must release every unconsumed element.
        assert_eq!(Rc::strong_count(&tracker), 1);
    }

    #[test]
    fn test_partial_consumption_drop() {
        let tracker = Rc::new(());

        {
            let (mut producer, mut consumer) = SpscRingBuffer::new::<Rc<()>>(4);
            for _ in 0..3 {
                producer.try_push(Rc::clone(&tracker)).unwrap();
            }
            assert_eq!(Rc::strong_count(&tracker), 4);

            // The popped clone is released immediately.
            drop(consumer.try_pop());
            assert_eq!(Rc::strong_count(&tracker), 3);
        }

        // The two still-buffered clones are released when the buffer is dropped.
        assert_eq!(Rc::strong_count(&tracker), 1);
    }
}
417
#[cfg(loom)]
mod loom_tests {
    use loom::thread;

    use super::*;

    /// Under every interleaving of two pushes and up to two pops, the consumer
    /// must observe an in-order prefix of `[1, 2]`.
    #[test]
    fn loom_concurrent_push_pop() {
        loom::model(|| {
            let (mut producer, mut consumer) = SpscRingBuffer::new::<i32>(2);

            let producer_thread = thread::spawn(move || {
                // Capacity 2 can always hold both elements, so neither push may fail.
                assert!(producer.try_push(1).is_ok());
                assert!(producer.try_push(2).is_ok());
            });

            let consumer_thread = thread::spawn(move || {
                let mut received = Vec::new();
                for _ in 0..2 {
                    if let Some(v) = consumer.try_pop() {
                        received.push(v);
                    }
                }
                received
            });

            producer_thread.join().unwrap();
            let received = consumer_thread.join().unwrap();
            // FIFO: whatever was observed must be a prefix of the push order.
            assert!(
                [1, 2].starts_with(&received),
                "unexpected pop sequence: {:?}",
                received
            );
        });
    }

    /// A single push into an empty 1-slot buffer always succeeds. A concurrent
    /// pop either misses it or returns exactly that element.
    #[test]
    fn loom_single_item() {
        loom::model(|| {
            let (mut producer, mut consumer) = SpscRingBuffer::new::<i32>(1);

            let producer_thread = thread::spawn(move || producer.try_push(42).ok());
            let consumer_thread = thread::spawn(move || consumer.try_pop());

            let push_result = producer_thread.join().unwrap();
            let pop_result = consumer_thread.join().unwrap();

            // The buffer starts empty, so its single slot is free for the push.
            assert_eq!(push_result, Some(()));
            assert!(
                matches!(pop_result, None | Some(42)),
                "unexpected pop: {:?}",
                pop_result
            );
        });
    }
}