1use std::path::Path;
4
5use bbx_dsp::{
6 buffer::{AudioBuffer, Buffer},
7 reader::Reader,
8 sample::Sample,
9};
10use wavers::Wav;
11
12pub struct WavFileReader<S: Sample> {
17 channel_buffers: Vec<AudioBuffer<S>>,
18 sample_rate: f64,
19 num_channels: usize,
20 num_samples: usize,
21}
22
23impl<S: Sample> WavFileReader<S> {
24 pub fn from_path(file_path: &str) -> Result<Self, Box<dyn std::error::Error>> {
27 let mut reader: Wav<f32> = Wav::from_path(Path::new(file_path))?;
28
29 let sample_rate = reader.sample_rate() as f64;
30 let num_channels = reader.n_channels() as usize;
31 let num_samples = reader.n_samples();
32
33 let mut channels = Vec::with_capacity(num_channels);
34 for _ in 0..num_channels {
35 channels.push(AudioBuffer::new(num_samples));
36 }
37
38 for (channel_index, channel) in reader.channels().enumerate() {
39 for (sample_index, sample) in channel.iter().enumerate() {
40 channels[channel_index][sample_index] = S::from_f64(*sample as f64);
41 }
42 }
43
44 Ok(Self {
45 channel_buffers: channels,
46 sample_rate,
47 num_channels,
48 num_samples,
49 })
50 }
51}
52
53impl<S: Sample> Reader<S> for WavFileReader<S> {
54 fn sample_rate(&self) -> f64 {
55 self.sample_rate
56 }
57
58 fn num_channels(&self) -> usize {
59 self.num_channels
60 }
61
62 fn num_samples(&self) -> usize {
63 self.num_samples
64 }
65
66 fn read_channel(&self, channel_index: usize) -> &[S] {
67 self.channel_buffers[channel_index].as_slice()
68 }
69}
70
71#[cfg(test)]
72mod tests {
73 use std::io::BufWriter;
74
75 use hound::{SampleFormat, WavSpec, WavWriter};
76 use tempfile::NamedTempFile;
77
78 use super::*;
79
80 fn create_test_wav(sample_rate: u32, num_channels: u16, samples: &[Vec<f32>]) -> NamedTempFile {
81 let temp_file = NamedTempFile::new().unwrap();
82 let spec = WavSpec {
83 channels: num_channels,
84 sample_rate,
85 bits_per_sample: 32,
86 sample_format: SampleFormat::Float,
87 };
88
89 let mut writer = WavWriter::new(BufWriter::new(temp_file.reopen().unwrap()), spec).unwrap();
90
91 let num_samples = samples[0].len();
92 for i in 0..num_samples {
93 for channel in samples {
94 writer.write_sample(channel[i]).unwrap();
95 }
96 }
97 writer.finalize().unwrap();
98
99 temp_file
100 }
101
102 #[test]
103 fn test_wav_reader_mono_f32() {
104 let samples = vec![vec![0.0, 0.5, 1.0, -0.5, -1.0]];
105 let temp_file = create_test_wav(44100, 1, &samples);
106
107 let reader = WavFileReader::<f32>::from_path(temp_file.path().to_str().unwrap()).unwrap();
108
109 assert_eq!(reader.num_channels(), 1);
110 assert_eq!(reader.num_samples(), 5);
111 assert_eq!(reader.sample_rate(), 44100.0);
112
113 let channel = reader.read_channel(0);
114 for (i, &expected) in samples[0].iter().enumerate() {
115 assert!(
116 (channel[i] - expected).abs() < 1e-6,
117 "Sample {} mismatch: {} vs {}",
118 i,
119 channel[i],
120 expected
121 );
122 }
123 }
124
125 #[test]
126 fn test_wav_reader_mono_f64() {
127 let samples = vec![vec![0.0, 0.25, 0.75, -0.25, -0.75]];
128 let temp_file = create_test_wav(48000, 1, &samples);
129
130 let reader = WavFileReader::<f64>::from_path(temp_file.path().to_str().unwrap()).unwrap();
131
132 assert_eq!(reader.num_channels(), 1);
133 assert_eq!(reader.num_samples(), 5);
134 assert_eq!(reader.sample_rate(), 48000.0);
135
136 let channel = reader.read_channel(0);
137 for (i, &expected) in samples[0].iter().enumerate() {
138 assert!((channel[i] - expected as f64).abs() < 1e-6, "Sample {} mismatch", i);
139 }
140 }
141
142 #[test]
143 fn test_wav_reader_stereo_f32() {
144 let left = vec![0.1, 0.2, 0.3, 0.4, 0.5];
145 let right = vec![-0.1, -0.2, -0.3, -0.4, -0.5];
146 let samples = vec![left.clone(), right.clone()];
147 let temp_file = create_test_wav(44100, 2, &samples);
148
149 let reader = WavFileReader::<f32>::from_path(temp_file.path().to_str().unwrap()).unwrap();
150
151 assert_eq!(reader.num_channels(), 2);
152 assert_eq!(reader.num_samples(), 10);
154
155 let left_channel = reader.read_channel(0);
156 let right_channel = reader.read_channel(1);
157
158 for i in 0..5 {
159 assert!((left_channel[i] - left[i]).abs() < 1e-6, "Left sample {} mismatch", i);
160 assert!(
161 (right_channel[i] - right[i]).abs() < 1e-6,
162 "Right sample {} mismatch",
163 i
164 );
165 }
166 }
167
168 #[test]
169 fn test_wav_reader_sample_rate() {
170 let samples = vec![vec![0.0; 10]];
171
172 let temp_22050 = create_test_wav(22050, 1, &samples);
173 let reader = WavFileReader::<f32>::from_path(temp_22050.path().to_str().unwrap()).unwrap();
174 assert_eq!(reader.sample_rate(), 22050.0);
175
176 let temp_96000 = create_test_wav(96000, 1, &samples);
177 let reader = WavFileReader::<f32>::from_path(temp_96000.path().to_str().unwrap()).unwrap();
178 assert_eq!(reader.sample_rate(), 96000.0);
179 }
180
181 #[test]
182 fn test_wav_reader_invalid_path() {
183 let result = WavFileReader::<f32>::from_path("/nonexistent/path/audio.wav");
184 assert!(result.is_err());
185 }
186}