1use std::{error::Error, fs::File, io::BufWriter, path::Path};
4
5use bbx_dsp::{
6 buffer::{AudioBuffer, Buffer},
7 context::DEFAULT_SAMPLE_RATE,
8 sample::Sample,
9 writer::Writer,
10};
11use hound::{SampleFormat, WavSpec, WavWriter};
12
/// Bits per sample in the output file: every sample is written as an `f32`.
const BIT_DEPTH: u16 = (std::mem::size_of::<f32>() * 8) as u16;
14
15pub struct WavFileWriter<S: Sample> {
20 writer: Option<WavWriter<BufWriter<File>>>,
21 sample_rate: f64,
22 num_channels: usize,
23 samples_written: usize,
24 channel_buffers: Vec<AudioBuffer<S>>,
25}
26
27impl<S: Sample> WavFileWriter<S> {
28 pub fn new(file_path: &str, sample_rate: f64, num_channels: usize) -> Result<Self, Box<dyn Error>> {
31 let spec = WavSpec {
32 channels: num_channels as u16,
33 sample_rate: sample_rate as u32,
34 bits_per_sample: BIT_DEPTH,
35 sample_format: SampleFormat::Float,
36 };
37
38 let writer = WavWriter::create(Path::new(file_path), spec)?;
39
40 Ok(Self {
41 writer: Some(writer),
42 sample_rate,
43 num_channels,
44 samples_written: 0,
45 channel_buffers: vec![AudioBuffer::new(DEFAULT_SAMPLE_RATE as usize); num_channels],
46 })
47 }
48}
49
50impl<S: Sample> Writer<S> for WavFileWriter<S> {
51 fn sample_rate(&self) -> f64 {
52 self.sample_rate
53 }
54
55 fn num_channels(&self) -> usize {
56 self.num_channels
57 }
58
59 fn can_write(&self) -> bool {
60 self.writer.is_some()
62 }
63
64 fn write_channel(&mut self, channel_index: usize, samples: &[S]) -> Result<(), Box<dyn Error>> {
65 if channel_index >= self.num_channels {
66 return Err("Channel index out of bounds".into());
67 }
68
69 self.channel_buffers[channel_index].extend_from_slice(samples);
70 self.write_available_samples()?;
71
72 Ok(())
73 }
74
75 fn finalize(&mut self) -> Result<(), Box<dyn Error>> {
76 self.write_available_samples()?;
77
78 if let Some(writer) = self.writer.take() {
79 writer.finalize()?;
80 }
81
82 Ok(())
83 }
84}
85
86impl<S: Sample> WavFileWriter<S> {
87 fn write_available_samples(&mut self) -> Result<(), Box<dyn std::error::Error>> {
89 if let Some(ref mut writer) = self.writer {
90 let min_len = self.channel_buffers.iter().map(|buf| buf.len()).min().unwrap_or(0);
91
92 if min_len == 0 {
93 return Ok(());
94 }
95
96 for sample_idx in 0..min_len {
97 for channel_idx in 0..self.num_channels {
98 let sample = self.channel_buffers[channel_idx][sample_idx];
99 writer.write_sample(sample.to_f64() as f32)?;
100 }
101 }
102
103 for channel_buffer in &mut self.channel_buffers {
104 channel_buffer.drain_front(min_len);
105 }
106
107 self.samples_written += min_len * self.num_channels;
108 }
109
110 Ok(())
111 }
112}
113
#[cfg(test)]
mod tests {
    use tempfile::NamedTempFile;

    use super::*;

    /// Reads the WAV file at `path` back, returning its spec and all samples.
    fn read_back(path: &str) -> (WavSpec, Vec<f32>) {
        let mut reader = hound::WavReader::open(path).unwrap();
        let spec = reader.spec();
        let samples = reader.samples::<f32>().map(|s| s.unwrap()).collect();
        (spec, samples)
    }

    #[test]
    fn test_wav_writer_creation() {
        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap();

        let writer = WavFileWriter::<f32>::new(path, 44100.0, 2);
        assert!(writer.is_ok());

        let writer = writer.unwrap();
        assert_eq!(writer.sample_rate(), 44100.0);
        assert_eq!(writer.num_channels(), 2);
        assert!(writer.can_write());
    }

    #[test]
    fn test_wav_writer_write_and_finalize() {
        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap();

        let mut writer = WavFileWriter::<f32>::new(path, 44100.0, 2).unwrap();

        // Use distinct data per channel so the interleaving order is actually checked.
        let left: Vec<f32> = (0..100).map(|i| i as f32 / 100.0).collect();
        let right: Vec<f32> = left.iter().map(|s| -s).collect();
        writer.write_channel(0, &left).unwrap();
        writer.write_channel(1, &right).unwrap();
        writer.finalize().unwrap();

        // A non-zero file size is already satisfied by the header alone, so read the audio back.
        let (spec, samples) = read_back(path);
        assert_eq!(spec.channels, 2);
        assert_eq!(spec.sample_rate, 44100);
        assert_eq!(spec.bits_per_sample, BIT_DEPTH);
        assert_eq!(spec.sample_format, SampleFormat::Float);
        assert_eq!(samples.len(), left.len() + right.len());
        for (frame, (l, r)) in samples.chunks_exact(2).zip(left.iter().zip(&right)) {
            assert_eq!(frame[0], *l);
            assert_eq!(frame[1], *r);
        }
    }

    #[test]
    fn test_wav_writer_interleaves_across_partial_writes() {
        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap();

        let mut writer = WavFileWriter::<f32>::new(path, 48000.0, 2).unwrap();
        writer.write_channel(0, &[0.1, 0.2, 0.3]).unwrap();
        writer.write_channel(1, &[-0.1]).unwrap();
        writer.write_channel(1, &[-0.2, -0.3]).unwrap();
        writer.finalize().unwrap();

        let (spec, samples) = read_back(path);
        assert_eq!(spec.sample_rate, 48000);
        assert_eq!(samples, vec![0.1, -0.1, 0.2, -0.2, 0.3, -0.3]);
    }

    #[test]
    fn test_wav_writer_finalize_without_samples() {
        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap();

        let mut writer = WavFileWriter::<f32>::new(path, 44100.0, 1).unwrap();
        writer.finalize().unwrap();
        assert!(!writer.can_write());

        // A second finalize must be harmless.
        writer.finalize().unwrap();

        let (spec, samples) = read_back(path);
        assert_eq!(spec.channels, 1);
        assert!(samples.is_empty());
    }

    #[test]
    fn test_wav_writer_channel_bounds() {
        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_str().unwrap();

        let mut writer = WavFileWriter::<f32>::new(path, 44100.0, 2).unwrap();

        let samples = vec![0.0f32; 10];
        let result = writer.write_channel(5, &samples);
        assert!(result.is_err());
    }
}