1use std::{
4 sync::{
5 Arc,
6 atomic::{AtomicBool, Ordering},
7 },
8 thread::{self, JoinHandle},
9};
10
11use bbx_core::{Consumer, Producer, SpscRingBuffer};
12
13use crate::{block::Block, context::DspContext, parameter::ModulationOutput, sample::Sample, writer::Writer};
14
/// Minimum ring-buffer length in frames: two seconds of audio at 44.1 kHz.
const DEFAULT_RING_BUFFER_CAPACITY: usize = 2 * 44_100;
17
18pub struct FileOutputBlock<S: Sample> {
24 producer: Option<Producer<S>>,
26
27 writer_thread: Option<JoinHandle<()>>,
29
30 stop_signal: Arc<AtomicBool>,
32
33 error_flag: Arc<AtomicBool>,
35
36 is_recording: bool,
38
39 num_channels: usize,
41}
42
43impl<S: Sample + Send + 'static> FileOutputBlock<S> {
44 pub fn new(writer: Box<dyn Writer<S>>) -> Self {
49 let num_channels = writer.num_channels();
50 let sample_rate = writer.sample_rate() as usize;
51
52 let buffer_capacity = sample_rate.max(DEFAULT_RING_BUFFER_CAPACITY) * num_channels;
54 let (producer, consumer) = SpscRingBuffer::new::<S>(buffer_capacity);
55
56 let stop_signal = Arc::new(AtomicBool::new(false));
57 let error_flag = Arc::new(AtomicBool::new(false));
58
59 let stop_signal_clone = stop_signal.clone();
60 let error_flag_clone = error_flag.clone();
61
62 let writer_thread = thread::spawn(move || {
63 Self::writer_thread_fn(consumer, writer, stop_signal_clone, error_flag_clone, num_channels);
64 });
65
66 Self {
67 producer: Some(producer),
68 writer_thread: Some(writer_thread),
69 stop_signal,
70 error_flag,
71 is_recording: true,
72 num_channels,
73 }
74 }
75
76 fn writer_thread_fn(
78 mut consumer: Consumer<S>,
79 mut writer: Box<dyn Writer<S>>,
80 stop_signal: Arc<AtomicBool>,
81 error_flag: Arc<AtomicBool>,
82 num_channels: usize,
83 ) {
84 let mut channel_buffers: Vec<Vec<S>> = vec![Vec::new(); num_channels];
85 let mut current_channel = 0;
86
87 const FLUSH_THRESHOLD: usize = 4096;
89
90 loop {
91 while let Some(sample) = consumer.try_pop() {
92 channel_buffers[current_channel].push(sample);
93 current_channel = (current_channel + 1) % num_channels;
94
95 if channel_buffers[0].len() >= FLUSH_THRESHOLD {
96 for (ch, buffer) in channel_buffers.iter_mut().enumerate() {
97 if writer.write_channel(ch, buffer).is_err() {
98 error_flag.store(true, Ordering::Relaxed);
99 }
100 buffer.clear();
101 }
102 }
103 }
104
105 if stop_signal.load(Ordering::Acquire) {
106 for (ch, buffer) in channel_buffers.iter().enumerate() {
107 if !buffer.is_empty() && writer.write_channel(ch, buffer).is_err() {
108 error_flag.store(true, Ordering::Relaxed);
109 }
110 }
111
112 if writer.finalize().is_err() {
113 error_flag.store(true, Ordering::Relaxed);
114 }
115
116 break;
117 }
118
119 thread::sleep(std::time::Duration::from_millis(1));
121 }
122 }
123
124 #[inline]
126 pub fn start_recording(&mut self) {
127 self.is_recording = true;
128 }
129
130 pub fn stop_recording(&mut self) -> Result<(), Box<dyn std::error::Error>> {
135 self.is_recording = false;
136
137 self.stop_signal.store(true, Ordering::Release);
138
139 if let Some(handle) = self.writer_thread.take() {
140 handle.join().map_err(|_| "Writer thread panicked")?;
141 }
142
143 if self.error_flag.load(Ordering::Relaxed) {
144 return Err("Error occurred while writing to file".into());
145 }
146
147 Ok(())
148 }
149
150 #[inline]
152 pub fn is_recording(&self) -> bool {
153 self.is_recording
154 }
155
156 #[inline]
161 pub fn error_occurred(&self) -> bool {
162 self.error_flag.load(Ordering::Relaxed)
163 }
164}
165
166impl<S: Sample + Send + 'static> Block<S> for FileOutputBlock<S> {
167 fn process(&mut self, inputs: &[&[S]], _outputs: &mut [&mut [S]], _modulation_values: &[S], _context: &DspContext) {
168 if !self.is_recording || inputs.is_empty() {
169 return;
170 }
171
172 let producer = match &mut self.producer {
173 Some(p) => p,
174 None => return,
175 };
176
177 let buffer_len = inputs[0].len();
178
179 for sample_idx in 0..buffer_len {
180 for ch in 0..self.num_channels {
181 let sample = inputs.get(ch).map_or(S::ZERO, |input| input[sample_idx]);
182 let _ = producer.try_push(sample);
183 }
184 }
185 }
186
187 #[inline]
188 fn input_count(&self) -> usize {
189 self.num_channels
190 }
191
192 #[inline]
193 fn output_count(&self) -> usize {
194 0
195 }
196
197 #[inline]
198 fn modulation_outputs(&self) -> &[ModulationOutput] {
199 &[]
200 }
201}
202
203impl<S: Sample + Send + 'static> Drop for FileOutputBlock<S> {
204 fn drop(&mut self) {
205 self.stop_signal.store(true, Ordering::Release);
206
207 if let Some(handle) = self.writer_thread.take() {
209 let _ = handle.join();
210 }
211 }
212}
213
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;
    use crate::channel::ChannelLayout;

    /// Per-channel sample storage shared between a [`MockWriter`] and the test.
    type SharedChannels<S> = Arc<Mutex<Vec<Vec<S>>>>;

    /// In-memory [`Writer`] that appends every written sample to its channel and
    /// records whether `finalize` was called.
    struct MockWriter<S: Sample> {
        sample_rate: f64,
        num_channels: usize,
        channels: SharedChannels<S>,
        finalized: Arc<AtomicBool>,
    }

    impl<S: Sample> MockWriter<S> {
        fn new(sample_rate: f64, num_channels: usize) -> Self {
            let empty: Vec<Vec<S>> = std::iter::repeat_with(Vec::new).take(num_channels).collect();
            Self {
                sample_rate,
                num_channels,
                channels: Arc::new(Mutex::new(empty)),
                finalized: Arc::new(AtomicBool::default()),
            }
        }

        /// Handle to the recorded samples that outlives the writer itself.
        fn channels_handle(&self) -> SharedChannels<S> {
            Arc::clone(&self.channels)
        }

        /// Handle to the "finalize was called" flag.
        fn finalized_handle(&self) -> Arc<AtomicBool> {
            Arc::clone(&self.finalized)
        }
    }

    impl<S: Sample> Writer<S> for MockWriter<S> {
        fn sample_rate(&self) -> f64 {
            self.sample_rate
        }

        fn num_channels(&self) -> usize {
            self.num_channels
        }

        fn can_write(&self) -> bool {
            true
        }

        fn write_channel(&mut self, channel_index: usize, samples: &[S]) -> Result<(), Box<dyn std::error::Error>> {
            // Writes to channels the mock does not have are ignored.
            if let Some(channel) = self.channels.lock().unwrap().get_mut(channel_index) {
                channel.extend_from_slice(samples);
            }
            Ok(())
        }

        fn finalize(&mut self) -> Result<(), Box<dyn std::error::Error>> {
            self.finalized.store(true, Ordering::Relaxed);
            Ok(())
        }
    }

    fn test_context(buffer_size: usize) -> DspContext {
        DspContext {
            sample_rate: 44100.0,
            buffer_size,
            num_channels: 2,
            current_sample: 0,
            channel_layout: ChannelLayout::Stereo,
        }
    }

    /// Wraps `writer` in a block, runs one `process` call over `inputs`, stops
    /// recording (asserting success) and hands back the stopped block.
    fn record_once(writer: MockWriter<f32>, inputs: &[&[f32]], buffer_size: usize) -> FileOutputBlock<f32> {
        let mut block = FileOutputBlock::new(Box::new(writer));
        let mut no_outputs: [&mut [f32]; 0] = [];
        block.process(inputs, &mut no_outputs, &[], &test_context(buffer_size));
        block.stop_recording().unwrap();
        block
    }

    #[test]
    fn test_file_output_block_counts() {
        let block = FileOutputBlock::new(Box::new(MockWriter::<f32>::new(44100.0, 2)));
        assert_eq!((block.input_count(), block.output_count()), (2, 0));
    }

    #[test]
    fn test_file_output_block_recording_state() {
        let mut block = FileOutputBlock::new(Box::new(MockWriter::<f32>::new(44100.0, 2)));
        assert!(block.is_recording(), "a freshly created block records immediately");

        block.start_recording();
        assert!(block.is_recording());
    }

    #[test]
    fn test_file_output_block_writes_and_finalizes() {
        let writer = MockWriter::<f32>::new(44100.0, 1);
        let channels = writer.channels_handle();
        let finalized = writer.finalized_handle();

        let constant = vec![0.5_f32; 10];
        let _block = record_once(writer, &[constant.as_slice()], 10);

        assert!(finalized.load(Ordering::Relaxed));
        assert!(!channels.lock().unwrap()[0].is_empty());
    }

    #[test]
    fn test_file_output_block_no_error_initially() {
        let block = FileOutputBlock::new(Box::new(MockWriter::<f32>::new(44100.0, 2)));
        assert!(!block.error_occurred());
    }

    #[test]
    fn test_file_output_block_empty_inputs() {
        let no_inputs: [&[f32]; 0] = [];
        let block = record_once(MockWriter::<f32>::new(44100.0, 1), &no_inputs, 10);
        assert!(!block.error_occurred());
    }

    #[test]
    fn test_file_output_block_mono_input_to_stereo_file() {
        let writer = MockWriter::<f32>::new(44100.0, 2);
        let channels = writer.channels_handle();

        let mono = [0.1_f32, 0.2, 0.3, 0.4];
        let _block = record_once(writer, &[&mono[..]], 4);

        let written = channels.lock().unwrap();
        assert_eq!(written[0], mono);
        assert_eq!(written[1], [0.0_f32; 4]);
    }

    #[test]
    fn test_file_output_block_interleaving_order() {
        let writer = MockWriter::<f32>::new(44100.0, 2);
        let channels = writer.channels_handle();

        let left = [1.0_f32, 2.0, 3.0];
        let right = [0.1_f32, 0.2, 0.3];
        let _block = record_once(writer, &[&left[..], &right[..]], 3);

        let written = channels.lock().unwrap();
        assert_eq!(written[0], left);
        assert_eq!(written[1], right);
    }

    #[test]
    fn test_file_output_block_excess_inputs_ignored() {
        let writer = MockWriter::<f32>::new(44100.0, 2);
        let channels = writer.channels_handle();

        let (ch0, ch1, ch2) = ([1.0_f32, 2.0], [0.1_f32, 0.2], [9.9_f32, 9.9]);
        let _block = record_once(writer, &[&ch0[..], &ch1[..], &ch2[..]], 2);

        let written = channels.lock().unwrap();
        assert_eq!(written.len(), 2);
        assert_eq!(written[0], ch0);
        assert_eq!(written[1], ch1);
    }
}