From c80749a875cae3ad73f5878ea58e185c8ef5e24f Mon Sep 17 00:00:00 2001 From: Rob Schonberger Date: Mon, 28 Oct 2024 22:08:27 -0700 Subject: [PATCH] Add a benchmark for two_cap_with_syn to JAX and NumPy benchmarks, and the JAX grad benchmark. This does so by simply adding yet more configurations, which is fine, and some slight translation to test two_cap/two_cap_with_syn. Although we could benchmark one_cap we do not look for that at the moment. PiperOrigin-RevId: 690872015 --- python/jax/carfac_bench.py | 63 ++++++++++++++++++++++++++++---------- 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/python/jax/carfac_bench.py b/python/jax/carfac_bench.py index 746b5dc..6a07e6f 100644 --- a/python/jax/carfac_bench.py +++ b/python/jax/carfac_bench.py @@ -93,12 +93,16 @@ def bench_numpy_in_slices(state: google_benchmark.State): @google_benchmark.option.measure_process_cpu_time() @google_benchmark.option.use_real_time() @google_benchmark.option.unit(google_benchmark.kMicrosecond) -@google_benchmark.option.arg_names(['segment_sample_length']) -@google_benchmark.option.args([220]) -@google_benchmark.option.args([2205]) -@google_benchmark.option.args([22050]) -@google_benchmark.option.args([44100]) -@google_benchmark.option.args([220500]) +@google_benchmark.option.arg_names([ + 'segment_sample_length', + 'ihc_style', +]) +@google_benchmark.option.args_product( + [ + [220, 2205, 22050, 44100, 220500], + [0, 1], + ], +) def bench_numpy(state: google_benchmark.State): """Benchmark the numpy version of carfac. @@ -109,7 +113,12 @@ def bench_numpy(state: google_benchmark.State): state: the benchmark state for this execution run. """ random_seed = 1 - ihc_style = 'two_cap' + if state.range(1) == 0: + ihc_style = 'two_cap' + elif state.range(1) == 1: + ihc_style = 'two_cap_with_syn' + else: + raise ValueError('Invalid ihc_style') cfp = carfac_np.design_carfac(ihc_style=ihc_style) carfac_np.carfac_init(cfp) @@ -134,16 +143,25 @@ def bench_numpy(state: google_benchmark.State): @google_benchmark.option.measure_process_cpu_time() @google_benchmark.option.use_real_time() @google_benchmark.option.unit(google_benchmark.kMicrosecond) -@google_benchmark.option.arg_names(['segment_sample_length']) -@google_benchmark.option.range_multiplier(2) -@google_benchmark.option.range(128, 4096) +@google_benchmark.option.arg_names(['segment_sample_length', 'ihc_style']) +@google_benchmark.option.args_product( + [ + [128, 256, 512, 1024, 2048, 4096], + [0, 1], + ], +) def bench_jax_grad(state: google_benchmark.State): """Benchmark JAX Value and Gradient function on Carfac. Args: state: The Benchmark state for this run. """ - ihc_style = 'two_cap' + if state.range(1) == 0: + ihc_style = 'two_cap' + elif state.range(1) == 1: + ihc_style = 'two_cap_with_syn' + else: + raise ValueError('Invalid ihc_style') random_seed = 1 params_jax = carfac_jax.CarfacDesignParameters() params_jax.ears[0].ihc.ihc_style = ihc_style @@ -313,11 +331,19 @@ def bench_jax_in_slices(state: google_benchmark.State): @google_benchmark.option.measure_process_cpu_time() @google_benchmark.option.use_real_time() @google_benchmark.option.unit(google_benchmark.kMicrosecond) -@google_benchmark.option.arg_names( - ['jax_chunked_uncompiled', 'segment_sample_length', 'use_delay_buffer'] -) +@google_benchmark.option.arg_names([ + 'jax_chunked_uncompiled', + 'segment_sample_length', + 'use_delay_buffer', + 'ihc_style', +]) @google_benchmark.option.args_product( - [[0, 1, 2], [220, 2205, 22050, 44100, 220500, 2205000], [False, True]], + [ + [0, 1, 2], + [220, 2205, 22050, 44100, 220500, 2205000], + [False, True], + [0, 1], + ], ) def bench_jax(state: google_benchmark.State): """Benchmark the JAX version of carfac. @@ -335,7 +361,12 @@ def bench_jax(state: google_benchmark.State): state: the benchmark state for this execution run. """ # Inits JAX version - ihc_style = 'two_cap' + if state.range(3) == 0: + ihc_style = 'two_cap' + elif state.range(3) == 1: + ihc_style = 'two_cap_with_syn' + else: + raise ValueError('Invalid ihc_style') random_seed = 1 params_jax = carfac_jax.CarfacDesignParameters() params_jax.ears[0].car.use_delay_buffer = state.range(2)