1mod hrir_data;
10mod hrtf;
11mod matrix;
12mod virtual_speaker;
13
14use std::marker::PhantomData;
15
16use hrtf::HrtfConvolver;
17
18use crate::{
19 block::Block, channel::ChannelConfig, context::DspContext, graph::MAX_BLOCK_INPUTS, parameter::ModulationOutput,
20 sample::Sample,
21};
22
/// How a [`BinauralDecoderBlock`] turns its input channels into a left/right pair.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinauralStrategy {
    /// Mix every input into each output through a fixed gain matrix.
    Matrix,

    /// Run the inputs through an HRTF convolver.
    Hrtf,
}
41
42impl Default for BinauralStrategy {
43 fn default() -> Self {
44 Self::Hrtf
45 }
46}
47
48pub struct BinauralDecoderBlock<S: Sample> {
71 input_count: usize,
72 strategy: BinauralStrategy,
73 decoder_matrix: [[f64; MAX_BLOCK_INPUTS]; 2],
74 hrtf_convolver: Option<Box<HrtfConvolver>>,
75 _phantom: PhantomData<S>,
76}
77
78impl<S: Sample> BinauralDecoderBlock<S> {
79 pub fn new(order: usize) -> Self {
87 Self::with_strategy(order, BinauralStrategy::default())
88 }
89
90 pub fn with_strategy(order: usize, strategy: BinauralStrategy) -> Self {
99 assert!((1..=3).contains(&order), "Ambisonic order must be 1, 2, or 3");
100
101 let input_count = (order + 1) * (order + 1);
102 let decoder_matrix = matrix::compute_matrix(order);
103
104 let hrtf_convolver = match strategy {
105 BinauralStrategy::Matrix => None,
106 BinauralStrategy::Hrtf => Some(Box::new(HrtfConvolver::new_ambisonic(order))),
107 };
108
109 Self {
110 input_count,
111 strategy,
112 decoder_matrix,
113 hrtf_convolver,
114 _phantom: PhantomData,
115 }
116 }
117
118 pub fn new_surround(channel_count: usize, strategy: BinauralStrategy) -> Self {
127 assert!(
128 channel_count == 6 || channel_count == 8,
129 "Surround channel count must be 6 (5.1) or 8 (7.1)"
130 );
131
132 let decoder_matrix = [[0.0; MAX_BLOCK_INPUTS]; 2];
133
134 let hrtf_convolver = match strategy {
135 BinauralStrategy::Matrix => None,
136 BinauralStrategy::Hrtf => Some(Box::new(HrtfConvolver::new_surround(channel_count))),
137 };
138
139 Self {
140 input_count: channel_count,
141 strategy,
142 decoder_matrix,
143 hrtf_convolver,
144 _phantom: PhantomData,
145 }
146 }
147
148 pub fn order(&self) -> usize {
152 match self.input_count {
153 4 => 1,
154 9 => 2,
155 16 => 3,
156 _ => 0,
157 }
158 }
159
160 pub fn strategy(&self) -> BinauralStrategy {
162 self.strategy
163 }
164
165 pub fn reset(&mut self) {
167 if let Some(ref mut convolver) = self.hrtf_convolver {
168 convolver.reset();
169 }
170 }
171
172 fn process_matrix(&self, inputs: &[&[S]], outputs: &mut [&mut [S]]) {
173 let num_inputs = self.input_count.min(inputs.len());
174 let num_outputs = 2.min(outputs.len());
175
176 if num_inputs == 0 || num_outputs == 0 || inputs[0].is_empty() {
177 return;
178 }
179
180 let num_samples = inputs[0].len().min(outputs[0].len());
181
182 for (out_ch, output) in outputs.iter_mut().enumerate().take(num_outputs) {
183 for i in 0..num_samples {
184 let mut sum = 0.0f64;
185 for (in_ch, input) in inputs.iter().enumerate().take(num_inputs) {
186 sum += input[i].to_f64() * self.decoder_matrix[out_ch][in_ch];
187 }
188 output[i] = S::from_f64(sum);
189 }
190 }
191 }
192
193 fn process_hrtf(&mut self, inputs: &[&[S]], outputs: &mut [&mut [S]]) {
194 if outputs.len() < 2 {
195 return;
196 }
197
198 let num_inputs = self.input_count.min(inputs.len());
199
200 if let Some(ref mut convolver) = self.hrtf_convolver {
201 let (left_output, rest) = outputs.split_at_mut(1);
202 let right_output = &mut rest[0];
203 convolver.process(inputs, left_output[0], right_output, num_inputs);
204 }
205 }
206}
207
208impl<S: Sample> Block<S> for BinauralDecoderBlock<S> {
209 fn process(&mut self, inputs: &[&[S]], outputs: &mut [&mut [S]], _modulation_values: &[S], _context: &DspContext) {
210 match self.strategy {
211 BinauralStrategy::Matrix => self.process_matrix(inputs, outputs),
212 BinauralStrategy::Hrtf => self.process_hrtf(inputs, outputs),
213 }
214 }
215
216 #[inline]
217 fn input_count(&self) -> usize {
218 self.input_count
219 }
220
221 #[inline]
222 fn output_count(&self) -> usize {
223 2
224 }
225
226 #[inline]
227 fn modulation_outputs(&self) -> &[ModulationOutput] {
228 &[]
229 }
230
231 #[inline]
232 fn channel_config(&self) -> ChannelConfig {
233 ChannelConfig::Explicit
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channel::ChannelLayout;

    /// Minimal stereo context; the decoder ignores it but `process` requires one.
    fn test_context() -> DspContext {
        DspContext {
            sample_rate: 44100.0,
            num_channels: 2,
            buffer_size: 4,
            current_sample: 0,
            channel_layout: ChannelLayout::Stereo,
        }
    }

    #[test]
    fn test_default_strategy_is_hrtf() {
        assert_eq!(BinauralStrategy::default(), BinauralStrategy::Hrtf);
    }

    #[test]
    fn test_new_uses_default_strategy() {
        let decoder = BinauralDecoderBlock::<f32>::new(1);
        assert_eq!(decoder.strategy(), BinauralStrategy::Hrtf);
    }

    #[test]
    fn test_with_strategy_matrix() {
        let decoder = BinauralDecoderBlock::<f32>::with_strategy(1, BinauralStrategy::Matrix);
        assert_eq!(decoder.strategy(), BinauralStrategy::Matrix);
    }

    #[test]
    fn test_binaural_foa_channel_counts() {
        let decoder = BinauralDecoderBlock::<f32>::new(1);
        assert_eq!(decoder.input_count(), 4);
        assert_eq!(decoder.output_count(), 2);
        assert_eq!(decoder.channel_config(), ChannelConfig::Explicit);
    }

    #[test]
    fn test_binaural_soa_channel_counts() {
        let decoder = BinauralDecoderBlock::<f32>::new(2);
        assert_eq!(decoder.input_count(), 9);
        assert_eq!(decoder.output_count(), 2);
    }

    #[test]
    fn test_binaural_toa_channel_counts() {
        let decoder = BinauralDecoderBlock::<f32>::new(3);
        assert_eq!(decoder.input_count(), 16);
        assert_eq!(decoder.output_count(), 2);
    }

    // First-order inputs below are passed in the order W, Y, Z, X.

    #[test]
    fn test_binaural_front_signal_balanced() {
        let mut decoder = BinauralDecoderBlock::<f32>::with_strategy(1, BinauralStrategy::Matrix);
        let context = test_context();

        // Source straight ahead: omni (W) plus front/back (X), no lateral (Y) component.
        let w = [1.0f32; 4];
        let y = [0.0f32; 4];
        let z = [0.0f32; 4];
        let x = [1.0f32; 4];
        let mut left_out = [0.0f32; 4];
        let mut right_out = [0.0f32; 4];

        let inputs: [&[f32]; 4] = [&w, &y, &z, &x];
        let mut outputs: [&mut [f32]; 2] = [&mut left_out, &mut right_out];

        decoder.process(&inputs, &mut outputs, &[], &context);

        let diff = (left_out[0] - right_out[0]).abs();
        assert!(diff < 0.01, "Front signal should be balanced, diff={}", diff);
    }

    #[test]
    fn test_binaural_left_signal_louder_in_left() {
        let mut decoder = BinauralDecoderBlock::<f32>::with_strategy(1, BinauralStrategy::Matrix);
        let context = test_context();

        // Source to the left: positive Y.
        let w = [1.0f32; 4];
        let y = [1.0f32; 4];
        let z = [0.0f32; 4];
        let x = [0.0f32; 4];
        let mut left_out = [0.0f32; 4];
        let mut right_out = [0.0f32; 4];

        let inputs: [&[f32]; 4] = [&w, &y, &z, &x];
        let mut outputs: [&mut [f32]; 2] = [&mut left_out, &mut right_out];

        decoder.process(&inputs, &mut outputs, &[], &context);

        assert!(
            left_out[0] > right_out[0],
            "Left signal should be louder in left channel: L={}, R={}",
            left_out[0],
            right_out[0]
        );
    }

    #[test]
    fn test_binaural_right_signal_louder_in_right() {
        let mut decoder = BinauralDecoderBlock::<f32>::with_strategy(1, BinauralStrategy::Matrix);
        let context = test_context();

        // Source to the right: negative Y.
        let w = [1.0f32; 4];
        let y = [-1.0f32; 4];
        let z = [0.0f32; 4];
        let x = [0.0f32; 4];
        let mut left_out = [0.0f32; 4];
        let mut right_out = [0.0f32; 4];

        let inputs: [&[f32]; 4] = [&w, &y, &z, &x];
        let mut outputs: [&mut [f32]; 2] = [&mut left_out, &mut right_out];

        decoder.process(&inputs, &mut outputs, &[], &context);

        assert!(
            right_out[0] > left_out[0],
            "Right signal should be louder in right channel: L={}, R={}",
            left_out[0],
            right_out[0]
        );
    }

    #[test]
    fn test_binaural_rear_signal_balanced() {
        let mut decoder = BinauralDecoderBlock::<f32>::with_strategy(1, BinauralStrategy::Matrix);
        let context = test_context();

        // Source directly behind: negative X, no lateral component.
        let w = [1.0f32; 4];
        let y = [0.0f32; 4];
        let z = [0.0f32; 4];
        let x = [-1.0f32; 4];
        let mut left_out = [0.0f32; 4];
        let mut right_out = [0.0f32; 4];

        let inputs: [&[f32]; 4] = [&w, &y, &z, &x];
        let mut outputs: [&mut [f32]; 2] = [&mut left_out, &mut right_out];

        decoder.process(&inputs, &mut outputs, &[], &context);

        let diff = (left_out[0] - right_out[0]).abs();
        assert!(diff < 0.01, "Rear signal should be balanced, diff={}", diff);
    }

    #[test]
    fn test_binaural_silence_produces_silence() {
        let mut decoder = BinauralDecoderBlock::<f32>::with_strategy(1, BinauralStrategy::Matrix);
        let context = test_context();

        let w = [0.0f32; 4];
        let y = [0.0f32; 4];
        let z = [0.0f32; 4];
        let x = [0.0f32; 4];
        // Pre-fill outputs with non-zero data so the test proves they are overwritten.
        let mut left_out = [1.0f32; 4];
        let mut right_out = [1.0f32; 4];

        let inputs: [&[f32]; 4] = [&w, &y, &z, &x];
        let mut outputs: [&mut [f32]; 2] = [&mut left_out, &mut right_out];

        decoder.process(&inputs, &mut outputs, &[], &context);

        for i in 0..4 {
            assert!(left_out[i].abs() < 1e-10, "Left output should be silence");
            assert!(right_out[i].abs() < 1e-10, "Right output should be silence");
        }
    }

    #[test]
    fn test_binaural_order_accessor() {
        assert_eq!(BinauralDecoderBlock::<f32>::new(1).order(), 1);
        assert_eq!(BinauralDecoderBlock::<f32>::new(2).order(), 2);
        assert_eq!(BinauralDecoderBlock::<f32>::new(3).order(), 3);
    }

    #[test]
    #[should_panic]
    fn test_binaural_invalid_order_zero_panics() {
        let _ = BinauralDecoderBlock::<f32>::new(0);
    }

    #[test]
    #[should_panic]
    fn test_binaural_invalid_order_four_panics() {
        let _ = BinauralDecoderBlock::<f32>::new(4);
    }

    #[test]
    fn test_binaural_soa_lateral_differentiation() {
        let mut decoder = BinauralDecoderBlock::<f32>::with_strategy(2, BinauralStrategy::Matrix);
        let context = test_context();

        // Drive only input channel 4 (the second-order V component); all others are silent.
        let mut inputs_data: [[f32; 4]; 9] = [[0.0; 4]; 9];
        inputs_data[4] = [1.0; 4];

        let inputs: Vec<&[f32]> = inputs_data.iter().map(|a| a.as_slice()).collect();
        let mut left_out = [0.0f32; 4];
        let mut right_out = [0.0f32; 4];
        let mut outputs: [&mut [f32]; 2] = [&mut left_out, &mut right_out];

        decoder.process(&inputs, &mut outputs, &[], &context);

        assert!(
            left_out[0] * right_out[0] < 0.0,
            "V channel should create opposite polarity: L={}, R={}",
            left_out[0],
            right_out[0]
        );
    }

    #[test]
    fn test_binaural_surround_channel_counts() {
        let decoder_51 = BinauralDecoderBlock::<f32>::new_surround(6, BinauralStrategy::Hrtf);
        assert_eq!(decoder_51.input_count(), 6);
        assert_eq!(decoder_51.output_count(), 2);
        assert_eq!(decoder_51.strategy(), BinauralStrategy::Hrtf);
        assert_eq!(decoder_51.channel_config(), ChannelConfig::Explicit);

        let decoder_71 = BinauralDecoderBlock::<f32>::new_surround(8, BinauralStrategy::Matrix);
        assert_eq!(decoder_71.input_count(), 8);
        assert_eq!(decoder_71.output_count(), 2);
        assert_eq!(decoder_71.strategy(), BinauralStrategy::Matrix);
    }

    #[test]
    fn test_binaural_surround_order_is_zero() {
        // Surround layouts are not ambisonic, so `order()` reports 0.
        assert_eq!(BinauralDecoderBlock::<f32>::new_surround(6, BinauralStrategy::Matrix).order(), 0);
        assert_eq!(BinauralDecoderBlock::<f32>::new_surround(8, BinauralStrategy::Matrix).order(), 0);
    }

    #[test]
    #[should_panic]
    fn test_binaural_surround_invalid_channel_count_panics() {
        let _ = BinauralDecoderBlock::<f32>::new_surround(7, BinauralStrategy::Matrix);
    }

    #[test]
    fn test_binaural_reset_is_safe_for_both_strategies() {
        let mut hrtf = BinauralDecoderBlock::<f32>::new(1);
        hrtf.reset();
        assert_eq!(hrtf.strategy(), BinauralStrategy::Hrtf);

        let mut matrix = BinauralDecoderBlock::<f32>::with_strategy(1, BinauralStrategy::Matrix);
        matrix.reset();
        assert_eq!(matrix.strategy(), BinauralStrategy::Matrix);
    }

    // f64 variants of the tests above.

    #[test]
    fn test_binaural_foa_channel_counts_f64() {
        let decoder = BinauralDecoderBlock::<f64>::new(1);
        assert_eq!(decoder.input_count(), 4);
        assert_eq!(decoder.output_count(), 2);
        assert_eq!(decoder.channel_config(), ChannelConfig::Explicit);
    }

    #[test]
    fn test_binaural_soa_channel_counts_f64() {
        let decoder = BinauralDecoderBlock::<f64>::new(2);
        assert_eq!(decoder.input_count(), 9);
        assert_eq!(decoder.output_count(), 2);
    }

    #[test]
    fn test_binaural_toa_channel_counts_f64() {
        let decoder = BinauralDecoderBlock::<f64>::new(3);
        assert_eq!(decoder.input_count(), 16);
        assert_eq!(decoder.output_count(), 2);
    }

    #[test]
    fn test_binaural_front_signal_balanced_f64() {
        let mut decoder = BinauralDecoderBlock::<f64>::with_strategy(1, BinauralStrategy::Matrix);
        let context = test_context();

        let w = [1.0f64; 4];
        let y = [0.0f64; 4];
        let z = [0.0f64; 4];
        let x = [1.0f64; 4];
        let mut left_out = [0.0f64; 4];
        let mut right_out = [0.0f64; 4];

        let inputs: [&[f64]; 4] = [&w, &y, &z, &x];
        let mut outputs: [&mut [f64]; 2] = [&mut left_out, &mut right_out];

        decoder.process(&inputs, &mut outputs, &[], &context);

        let diff = (left_out[0] - right_out[0]).abs();
        assert!(diff < 0.01, "Front signal should be balanced, diff={}", diff);
    }

    #[test]
    fn test_binaural_left_signal_louder_in_left_f64() {
        let mut decoder = BinauralDecoderBlock::<f64>::with_strategy(1, BinauralStrategy::Matrix);
        let context = test_context();

        let w = [1.0f64; 4];
        let y = [1.0f64; 4];
        let z = [0.0f64; 4];
        let x = [0.0f64; 4];
        let mut left_out = [0.0f64; 4];
        let mut right_out = [0.0f64; 4];

        let inputs: [&[f64]; 4] = [&w, &y, &z, &x];
        let mut outputs: [&mut [f64]; 2] = [&mut left_out, &mut right_out];

        decoder.process(&inputs, &mut outputs, &[], &context);

        assert!(
            left_out[0] > right_out[0],
            "Left signal should be louder in left channel: L={}, R={}",
            left_out[0],
            right_out[0]
        );
    }

    #[test]
    fn test_binaural_silence_produces_silence_f64() {
        let mut decoder = BinauralDecoderBlock::<f64>::with_strategy(1, BinauralStrategy::Matrix);
        let context = test_context();

        let w = [0.0f64; 4];
        let y = [0.0f64; 4];
        let z = [0.0f64; 4];
        let x = [0.0f64; 4];
        let mut left_out = [1.0f64; 4];
        let mut right_out = [1.0f64; 4];

        let inputs: [&[f64]; 4] = [&w, &y, &z, &x];
        let mut outputs: [&mut [f64]; 2] = [&mut left_out, &mut right_out];

        decoder.process(&inputs, &mut outputs, &[], &context);

        for i in 0..4 {
            assert!(left_out[i].abs() < 1e-14, "Left output should be silence");
            assert!(right_out[i].abs() < 1e-14, "Right output should be silence");
        }
    }

    #[test]
    fn test_binaural_order_accessor_f64() {
        assert_eq!(BinauralDecoderBlock::<f64>::new(1).order(), 1);
        assert_eq!(BinauralDecoderBlock::<f64>::new(2).order(), 2);
        assert_eq!(BinauralDecoderBlock::<f64>::new(3).order(), 3);
    }
}