bbx_dsp/blocks/effectors/binaural_decoder/
mod.rs

1//! Binaural decoder block for converting multi-channel audio to stereo headphone output.
2//!
3//! Supports both ambisonic (FOA/SOA/TOA) and surround (5.1, 7.1) inputs.
4//! Two decoding strategies are available:
5//!
6//! - [`BinauralStrategy::Matrix`] - Lightweight ILD-based approximation (low CPU)
7//! - [`BinauralStrategy::Hrtf`] - Full HRTF convolution for accurate binaural rendering (default)
8
9mod hrir_data;
10mod hrtf;
11mod matrix;
12mod virtual_speaker;
13
14use std::marker::PhantomData;
15
16use hrtf::HrtfConvolver;
17
18use crate::{
19    block::Block, channel::ChannelConfig, context::DspContext, graph::MAX_BLOCK_INPUTS, parameter::ModulationOutput,
20    sample::Sample,
21};
22
/// Binaural decoding strategy.
///
/// Determines how multi-channel audio is converted to binaural stereo.
/// [`BinauralStrategy::Hrtf`] is the default.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum BinauralStrategy {
    /// Lightweight matrix-based decoder using ILD (Interaural Level Difference).
    ///
    /// Uses psychoacoustically-informed coefficients to approximate binaural cues.
    /// Low CPU usage but limited spatial accuracy.
    Matrix,

    /// Full HRTF (Head-Related Transfer Function) convolution.
    ///
    /// Uses measured impulse responses to accurately model how sounds arrive
    /// at each ear from different directions. Higher CPU usage but superior
    /// spatial rendering with proper externalization.
    #[default]
    Hrtf,
}
47
48/// Decodes multi-channel audio to stereo for headphone listening.
49///
50/// Supports ambisonic B-format (1st, 2nd, 3rd order) and can be configured
51/// to use either lightweight matrix decoding or full HRTF convolution.
52///
53/// # Supported Input Formats
54/// - **Ambisonics**: FOA (4 ch), SOA (9 ch), TOA (16 ch)
55/// - **Surround**: 5.1 (6 ch), 7.1 (8 ch)
56///
57/// # Output
58/// Always stereo (2 channels): Left, Right
59///
60/// # Example
61/// ```ignore
62/// use bbx_dsp::blocks::effectors::binaural_decoder::{BinauralDecoderBlock, BinauralStrategy};
63///
64/// // Create HRTF decoder (default)
65/// let hrtf_decoder = BinauralDecoderBlock::<f32>::new(1);
66///
67/// // Create lightweight matrix decoder
68/// let matrix_decoder = BinauralDecoderBlock::<f32>::with_strategy(1, BinauralStrategy::Matrix);
69/// ```
70pub struct BinauralDecoderBlock<S: Sample> {
71    input_count: usize,
72    strategy: BinauralStrategy,
73    decoder_matrix: [[f64; MAX_BLOCK_INPUTS]; 2],
74    hrtf_convolver: Option<Box<HrtfConvolver>>,
75    _phantom: PhantomData<S>,
76}
77
78impl<S: Sample> BinauralDecoderBlock<S> {
79    /// Create a new binaural decoder for ambisonics with the default strategy (HRTF).
80    ///
81    /// # Arguments
82    /// * `order` - Ambisonic order (1, 2, or 3)
83    ///
84    /// # Panics
85    /// Panics if order is not 1, 2, or 3.
86    pub fn new(order: usize) -> Self {
87        Self::with_strategy(order, BinauralStrategy::default())
88    }
89
90    /// Create a new binaural decoder for ambisonics with a specific strategy.
91    ///
92    /// # Arguments
93    /// * `order` - Ambisonic order (1, 2, or 3)
94    /// * `strategy` - The decoding strategy to use
95    ///
96    /// # Panics
97    /// Panics if order is not 1, 2, or 3.
98    pub fn with_strategy(order: usize, strategy: BinauralStrategy) -> Self {
99        assert!((1..=3).contains(&order), "Ambisonic order must be 1, 2, or 3");
100
101        let input_count = (order + 1) * (order + 1);
102        let decoder_matrix = matrix::compute_matrix(order);
103
104        let hrtf_convolver = match strategy {
105            BinauralStrategy::Matrix => None,
106            BinauralStrategy::Hrtf => Some(Box::new(HrtfConvolver::new_ambisonic(order))),
107        };
108
109        Self {
110            input_count,
111            strategy,
112            decoder_matrix,
113            hrtf_convolver,
114            _phantom: PhantomData,
115        }
116    }
117
118    /// Create a new binaural decoder for surround sound.
119    ///
120    /// # Arguments
121    /// * `channel_count` - Number of input channels (6 for 5.1, 8 for 7.1)
122    /// * `strategy` - The decoding strategy to use
123    ///
124    /// # Panics
125    /// Panics if channel_count is not 6 or 8.
126    pub fn new_surround(channel_count: usize, strategy: BinauralStrategy) -> Self {
127        assert!(
128            channel_count == 6 || channel_count == 8,
129            "Surround channel count must be 6 (5.1) or 8 (7.1)"
130        );
131
132        let decoder_matrix = [[0.0; MAX_BLOCK_INPUTS]; 2];
133
134        let hrtf_convolver = match strategy {
135            BinauralStrategy::Matrix => None,
136            BinauralStrategy::Hrtf => Some(Box::new(HrtfConvolver::new_surround(channel_count))),
137        };
138
139        Self {
140            input_count: channel_count,
141            strategy,
142            decoder_matrix,
143            hrtf_convolver,
144            _phantom: PhantomData,
145        }
146    }
147
148    /// Returns the ambisonic order (for ambisonic inputs).
149    ///
150    /// Returns 0 for surround inputs.
151    pub fn order(&self) -> usize {
152        match self.input_count {
153            4 => 1,
154            9 => 2,
155            16 => 3,
156            _ => 0,
157        }
158    }
159
160    /// Returns the current decoding strategy.
161    pub fn strategy(&self) -> BinauralStrategy {
162        self.strategy
163    }
164
165    /// Reset the HRTF convolver state (clears convolution buffers).
166    pub fn reset(&mut self) {
167        if let Some(ref mut convolver) = self.hrtf_convolver {
168            convolver.reset();
169        }
170    }
171
172    fn process_matrix(&self, inputs: &[&[S]], outputs: &mut [&mut [S]]) {
173        let num_inputs = self.input_count.min(inputs.len());
174        let num_outputs = 2.min(outputs.len());
175
176        if num_inputs == 0 || num_outputs == 0 || inputs[0].is_empty() {
177            return;
178        }
179
180        let num_samples = inputs[0].len().min(outputs[0].len());
181
182        for (out_ch, output) in outputs.iter_mut().enumerate().take(num_outputs) {
183            for i in 0..num_samples {
184                let mut sum = 0.0f64;
185                for (in_ch, input) in inputs.iter().enumerate().take(num_inputs) {
186                    sum += input[i].to_f64() * self.decoder_matrix[out_ch][in_ch];
187                }
188                output[i] = S::from_f64(sum);
189            }
190        }
191    }
192
193    fn process_hrtf(&mut self, inputs: &[&[S]], outputs: &mut [&mut [S]]) {
194        if outputs.len() < 2 {
195            return;
196        }
197
198        let num_inputs = self.input_count.min(inputs.len());
199
200        if let Some(ref mut convolver) = self.hrtf_convolver {
201            let (left_output, rest) = outputs.split_at_mut(1);
202            let right_output = &mut rest[0];
203            convolver.process(inputs, left_output[0], right_output, num_inputs);
204        }
205    }
206}
207
208impl<S: Sample> Block<S> for BinauralDecoderBlock<S> {
209    fn process(&mut self, inputs: &[&[S]], outputs: &mut [&mut [S]], _modulation_values: &[S], _context: &DspContext) {
210        match self.strategy {
211            BinauralStrategy::Matrix => self.process_matrix(inputs, outputs),
212            BinauralStrategy::Hrtf => self.process_hrtf(inputs, outputs),
213        }
214    }
215
216    #[inline]
217    fn input_count(&self) -> usize {
218        self.input_count
219    }
220
221    #[inline]
222    fn output_count(&self) -> usize {
223        2
224    }
225
226    #[inline]
227    fn modulation_outputs(&self) -> &[ModulationOutput] {
228        &[]
229    }
230
231    #[inline]
232    fn channel_config(&self) -> ChannelConfig {
233        ChannelConfig::Explicit
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channel::ChannelLayout;

    /// Minimal stereo context; `Block::process` requires one even though the decoder ignores it.
    fn test_context() -> DspContext {
        DspContext {
            sample_rate: 44100.0,
            num_channels: 2,
            buffer_size: 4,
            current_sample: 0,
            channel_layout: ChannelLayout::Stereo,
        }
    }

    /// Runs a first-order Matrix decoder over constant W/Y/Z/X inputs (4 samples each).
    ///
    /// Returns the `(left, right)` output buffers. Both outputs start filled with
    /// `initial`, so a test can detect samples the decoder failed to overwrite.
    fn decode_foa_matrix<S: Sample + Copy>(w: S, y: S, z: S, x: S, initial: S) -> ([S; 4], [S; 4]) {
        let mut decoder = BinauralDecoderBlock::<S>::with_strategy(1, BinauralStrategy::Matrix);
        let channels = [[w; 4], [y; 4], [z; 4], [x; 4]];
        let inputs: [&[S]; 4] = [&channels[0], &channels[1], &channels[2], &channels[3]];
        let mut left = [initial; 4];
        let mut right = [initial; 4];
        let mut outputs: [&mut [S]; 2] = [&mut left, &mut right];
        decoder.process(&inputs, &mut outputs, &[], &test_context());
        (left, right)
    }

    #[test]
    fn test_default_strategy_is_hrtf() {
        assert_eq!(BinauralStrategy::default(), BinauralStrategy::Hrtf);
    }

    #[test]
    fn test_new_uses_default_strategy() {
        assert_eq!(BinauralDecoderBlock::<f32>::new(1).strategy(), BinauralStrategy::Hrtf);
    }

    #[test]
    fn test_with_strategy_matrix() {
        let strategy = BinauralDecoderBlock::<f32>::with_strategy(1, BinauralStrategy::Matrix).strategy();
        assert_eq!(strategy, BinauralStrategy::Matrix);
    }

    #[test]
    fn test_binaural_foa_channel_counts() {
        let block = BinauralDecoderBlock::<f32>::new(1);
        assert_eq!(block.input_count(), 4);
        assert_eq!(block.output_count(), 2);
        assert_eq!(block.channel_config(), ChannelConfig::Explicit);
    }

    #[test]
    fn test_binaural_soa_channel_counts() {
        let block = BinauralDecoderBlock::<f32>::new(2);
        assert_eq!(block.input_count(), 9);
        assert_eq!(block.output_count(), 2);
    }

    #[test]
    fn test_binaural_toa_channel_counts() {
        let block = BinauralDecoderBlock::<f32>::new(3);
        assert_eq!(block.input_count(), 16);
        assert_eq!(block.output_count(), 2);
    }

    #[test]
    fn test_binaural_front_signal_balanced() {
        // Front source: W=1, Y=0, Z=0, X=1
        let (left, right) = decode_foa_matrix(1.0f32, 0.0, 0.0, 1.0, 0.0);
        let diff = (left[0] - right[0]).abs();
        assert!(diff < 0.01, "Front signal should be balanced, diff={}", diff);
    }

    #[test]
    fn test_binaural_left_signal_louder_in_left() {
        // Left source: W=1, Y=1, Z=0, X=0
        let (left, right) = decode_foa_matrix(1.0f32, 1.0, 0.0, 0.0, 0.0);
        assert!(
            left[0] > right[0],
            "Left signal should be louder in left channel: L={}, R={}",
            left[0],
            right[0]
        );
    }

    #[test]
    fn test_binaural_right_signal_louder_in_right() {
        // Right source: W=1, Y=-1, Z=0, X=0
        let (left, right) = decode_foa_matrix(1.0f32, -1.0, 0.0, 0.0, 0.0);
        assert!(
            right[0] > left[0],
            "Right signal should be louder in right channel: L={}, R={}",
            left[0],
            right[0]
        );
    }

    #[test]
    fn test_binaural_rear_signal_balanced() {
        // Rear source: W=1, Y=0, Z=0, X=-1
        let (left, right) = decode_foa_matrix(1.0f32, 0.0, 0.0, -1.0, 0.0);
        let diff = (left[0] - right[0]).abs();
        assert!(diff < 0.01, "Rear signal should be balanced, diff={}", diff);
    }

    #[test]
    fn test_binaural_silence_produces_silence() {
        // Outputs start at 1.0, so any sample the decoder skips would show up here.
        let (left, right) = decode_foa_matrix(0.0f32, 0.0, 0.0, 0.0, 1.0);
        for (l, r) in left.iter().zip(right.iter()) {
            assert!(l.abs() < 1e-10, "Left output should be silence");
            assert!(r.abs() < 1e-10, "Right output should be silence");
        }
    }

    #[test]
    fn test_binaural_order_accessor() {
        for order in 1..=3 {
            assert_eq!(BinauralDecoderBlock::<f32>::new(order).order(), order);
        }
    }

    #[test]
    #[should_panic]
    fn test_binaural_invalid_order_zero_panics() {
        let _ = BinauralDecoderBlock::<f32>::new(0);
    }

    #[test]
    #[should_panic]
    fn test_binaural_invalid_order_four_panics() {
        let _ = BinauralDecoderBlock::<f32>::new(4);
    }

    #[test]
    fn test_binaural_soa_lateral_differentiation() {
        let mut decoder = BinauralDecoderBlock::<f32>::with_strategy(2, BinauralStrategy::Matrix);

        // Only the V channel (ACN index 4) carries signal.
        let channels: Vec<[f32; 4]> = (0..9)
            .map(|ch| if ch == 4 { [1.0; 4] } else { [0.0; 4] })
            .collect();
        let inputs: Vec<&[f32]> = channels.iter().map(|c| c.as_slice()).collect();
        let mut left_out = [0.0f32; 4];
        let mut right_out = [0.0f32; 4];
        let mut outputs: [&mut [f32]; 2] = [&mut left_out, &mut right_out];

        decoder.process(&inputs, &mut outputs, &[], &test_context());

        // V should drive the two ears with opposite polarity.
        assert!(
            left_out[0] * right_out[0] < 0.0,
            "V channel should create opposite polarity: L={}, R={}",
            left_out[0],
            right_out[0]
        );
    }

    // f64 variant tests

    #[test]
    fn test_binaural_foa_channel_counts_f64() {
        let block = BinauralDecoderBlock::<f64>::new(1);
        assert_eq!(block.input_count(), 4);
        assert_eq!(block.output_count(), 2);
        assert_eq!(block.channel_config(), ChannelConfig::Explicit);
    }

    #[test]
    fn test_binaural_soa_channel_counts_f64() {
        let block = BinauralDecoderBlock::<f64>::new(2);
        assert_eq!(block.input_count(), 9);
        assert_eq!(block.output_count(), 2);
    }

    #[test]
    fn test_binaural_toa_channel_counts_f64() {
        let block = BinauralDecoderBlock::<f64>::new(3);
        assert_eq!(block.input_count(), 16);
        assert_eq!(block.output_count(), 2);
    }

    #[test]
    fn test_binaural_front_signal_balanced_f64() {
        let (left, right) = decode_foa_matrix(1.0f64, 0.0, 0.0, 1.0, 0.0);
        let diff = (left[0] - right[0]).abs();
        assert!(diff < 0.01, "Front signal should be balanced, diff={}", diff);
    }

    #[test]
    fn test_binaural_left_signal_louder_in_left_f64() {
        let (left, right) = decode_foa_matrix(1.0f64, 1.0, 0.0, 0.0, 0.0);
        assert!(
            left[0] > right[0],
            "Left signal should be louder in left channel: L={}, R={}",
            left[0],
            right[0]
        );
    }

    #[test]
    fn test_binaural_silence_produces_silence_f64() {
        let (left, right) = decode_foa_matrix(0.0f64, 0.0, 0.0, 0.0, 1.0);
        for (l, r) in left.iter().zip(right.iter()) {
            assert!(l.abs() < 1e-14, "Left output should be silence");
            assert!(r.abs() < 1e-14, "Right output should be silence");
        }
    }

    #[test]
    fn test_binaural_order_accessor_f64() {
        for order in 1..=3 {
            assert_eq!(BinauralDecoderBlock::<f64>::new(order).order(), order);
        }
    }
}