1use std::time::Instant;
4
5use bbx_dsp::{
6 buffer::{AudioBuffer, Buffer},
7 graph::Graph,
8 sample::Sample,
9 writer::Writer,
10};
11
/// How much audio an [`OfflineRenderer`] should produce.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RenderDuration {
    /// A length in whole seconds, converted to frames using the graph's sample rate.
    Duration(usize),
    /// An exact number of frames (samples per channel).
    Samples(usize),
}
20
/// Errors returned by [`OfflineRenderer::render`].
#[derive(Debug)]
pub enum RenderError {
    /// Writing samples to the writer failed; wraps the underlying error.
    WriteFailed(Box<dyn std::error::Error>),
    /// Finalizing the writer after rendering failed; wraps the underlying error.
    FinalizeFailed(Box<dyn std::error::Error>),
    /// The requested render length was rejected (for example, zero); holds a description.
    InvalidDuration(String),
}

impl std::fmt::Display for RenderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RenderError::WriteFailed(e) => write!(f, "write failed: {e}"),
            RenderError::FinalizeFailed(e) => write!(f, "finalize failed: {e}"),
            RenderError::InvalidDuration(msg) => write!(f, "invalid duration: {msg}"),
        }
    }
}

impl std::error::Error for RenderError {
    /// Returns the wrapped writer error, if any, so error-reporting code that walks
    /// `source()` chains can display the root cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            RenderError::WriteFailed(e) | RenderError::FinalizeFailed(e) => Some(&**e),
            RenderError::InvalidDuration(_) => None,
        }
    }
}
43
/// Summary of a completed offline render, as returned by [`OfflineRenderer::render`].
#[derive(Debug, Clone)]
pub struct RenderStats {
    /// Number of frames (samples per channel) handed to the writer.
    pub samples_rendered: u64,
    /// Length of the rendered audio in seconds: `samples_rendered / sample_rate`.
    pub duration_seconds: f64,
    /// Wall-clock seconds spent processing the graph, writing, and finalizing the writer.
    pub render_time_seconds: f64,
    /// Real-time factor `duration_seconds / render_time_seconds`; above 1.0 means
    /// the render ran faster than real time.
    pub speedup: f64,
}
56
57pub struct OfflineRenderer<S: Sample> {
79 graph: Graph<S>,
80 writer: Box<dyn Writer<S>>,
81 output_buffers: Vec<AudioBuffer<S>>,
82 buffer_size: usize,
83 sample_rate: f64,
84 num_channels: usize,
85}
86
87impl<S: Sample> OfflineRenderer<S> {
88 pub fn new(graph: Graph<S>, writer: Box<dyn Writer<S>>) -> Self {
102 let context = graph.context();
103 let buffer_size = context.buffer_size;
104 let sample_rate = context.sample_rate;
105 let num_channels = context.num_channels;
106
107 assert!(
108 (writer.sample_rate() - sample_rate).abs() < 1.0,
109 "Writer sample rate ({}) must match graph sample rate ({})",
110 writer.sample_rate(),
111 sample_rate
112 );
113 assert_eq!(
114 writer.num_channels(),
115 num_channels,
116 "Writer channel count ({}) must match graph channel count ({})",
117 writer.num_channels(),
118 num_channels
119 );
120
121 let output_buffers = (0..num_channels).map(|_| AudioBuffer::new(buffer_size)).collect();
122
123 Self {
124 graph,
125 writer,
126 output_buffers,
127 buffer_size,
128 sample_rate,
129 num_channels,
130 }
131 }
132
133 pub fn render(&mut self, duration: RenderDuration) -> Result<RenderStats, RenderError> {
146 let num_samples = match duration {
147 RenderDuration::Duration(secs) => {
148 if secs == 0 {
149 return Err(RenderError::InvalidDuration("Duration must be positive".to_string()));
150 }
151 (secs as f64 * self.sample_rate) as u64
152 }
153 RenderDuration::Samples(samples) => {
154 if samples == 0 {
155 return Err(RenderError::InvalidDuration(
156 "Sample count must be positive".to_string(),
157 ));
158 }
159 samples as u64
160 }
161 };
162
163 let start_time = Instant::now();
164 let mut samples_rendered: u64 = 0;
165
166 while samples_rendered < num_samples {
167 let mut output_refs: Vec<&mut [S]> = self.output_buffers.iter_mut().map(|b| b.as_mut_slice()).collect();
168 self.graph.process_buffers(&mut output_refs);
169
170 let samples_remaining = num_samples - samples_rendered;
171 let samples_to_write = (self.buffer_size as u64).min(samples_remaining) as usize;
172
173 for (channel_idx, buffer) in self.output_buffers.iter().enumerate() {
174 self.writer
175 .write_channel(channel_idx, &buffer.as_slice()[..samples_to_write])
176 .map_err(RenderError::WriteFailed)?;
177 }
178
179 samples_rendered += samples_to_write as u64;
180 }
181
182 self.writer.finalize().map_err(RenderError::FinalizeFailed)?;
183
184 let render_time = start_time.elapsed().as_secs_f64();
185 let duration_seconds = samples_rendered as f64 / self.sample_rate;
186
187 Ok(RenderStats {
188 samples_rendered,
189 duration_seconds,
190 render_time_seconds: render_time,
191 speedup: duration_seconds / render_time,
192 })
193 }
194
195 #[inline]
197 pub fn sample_rate(&self) -> f64 {
198 self.sample_rate
199 }
200
201 #[inline]
203 pub fn num_channels(&self) -> usize {
204 self.num_channels
205 }
206
207 #[inline]
209 pub fn buffer_size(&self) -> usize {
210 self.buffer_size
211 }
212
213 pub fn into_graph(self) -> Graph<S> {
218 self.graph
219 }
220}
221
#[cfg(test)]
mod tests {
    use bbx_dsp::{blocks::OscillatorBlock, context::DEFAULT_SAMPLE_RATE, graph::GraphBuilder, waveform::Waveform};

    use super::*;

    /// In-memory `Writer` that records every written sample per channel and
    /// remembers whether it has been finalized.
    struct TestWriter {
        sample_rate: f64,
        num_channels: usize,
        samples_written: Vec<Vec<f32>>,
        finalized: bool,
    }

    impl TestWriter {
        fn new(sample_rate: f64, num_channels: usize) -> Self {
            Self {
                sample_rate,
                num_channels,
                samples_written: (0..num_channels).map(|_| Vec::new()).collect(),
                finalized: false,
            }
        }
    }

    impl Writer<f32> for TestWriter {
        fn sample_rate(&self) -> f64 {
            self.sample_rate
        }

        fn num_channels(&self) -> usize {
            self.num_channels
        }

        fn can_write(&self) -> bool {
            !self.finalized
        }

        fn write_channel(&mut self, channel_index: usize, samples: &[f32]) -> Result<(), Box<dyn std::error::Error>> {
            let channel = &mut self.samples_written[channel_index];
            channel.extend_from_slice(samples);
            Ok(())
        }

        fn finalize(&mut self) -> Result<(), Box<dyn std::error::Error>> {
            self.finalized = true;
            Ok(())
        }
    }

    /// Stereo graph at the default sample rate with 512-frame blocks and one 440 Hz sine oscillator.
    fn create_test_graph() -> Graph<f32> {
        let mut builder = GraphBuilder::<f32>::new(DEFAULT_SAMPLE_RATE, 512, 2);
        builder.add(OscillatorBlock::new(440.0, Waveform::Sine, None));
        builder.build()
    }

    /// Pairs the test graph with a `TestWriter` of the requested format.
    fn renderer_with_writer(sample_rate: f64, num_channels: usize) -> OfflineRenderer<f32> {
        let graph = create_test_graph();
        OfflineRenderer::new(graph, Box::new(TestWriter::new(sample_rate, num_channels)))
    }

    #[test]
    fn test_render_duration() {
        let mut renderer = renderer_with_writer(DEFAULT_SAMPLE_RATE, 2);

        let stats = renderer.render(RenderDuration::Duration(1)).unwrap();

        assert_eq!(stats.samples_rendered, DEFAULT_SAMPLE_RATE as u64);
        assert!((stats.duration_seconds - 1.0).abs() < 0.01);
        assert!(stats.speedup > 1.0);
    }

    #[test]
    fn test_render_samples() {
        let mut renderer = renderer_with_writer(DEFAULT_SAMPLE_RATE, 2);

        let stats = renderer.render(RenderDuration::Samples(1024)).unwrap();

        assert_eq!(stats.samples_rendered, 1024);
    }

    #[test]
    fn test_invalid_duration() {
        let mut renderer = renderer_with_writer(DEFAULT_SAMPLE_RATE, 2);

        let outcome = renderer.render(RenderDuration::Duration(0));

        assert!(matches!(outcome, Err(RenderError::InvalidDuration(_))));
    }

    #[test]
    #[should_panic(expected = "sample rate")]
    fn test_mismatched_sample_rate() {
        let _renderer = renderer_with_writer(48000.0, 2);
    }

    #[test]
    #[should_panic(expected = "channel count")]
    fn test_mismatched_channels() {
        let _renderer = renderer_with_writer(DEFAULT_SAMPLE_RATE, 1);
    }
}