Skip to content

Commit

Permalink
Copybara import of the project:
Browse files Browse the repository at this point in the history
COPYBARA_INTEGRATE_REVIEW=#19 from JasonMH17:master c81cc0f
PiperOrigin-RevId: 695494768
  • Loading branch information
JasonMH17 authored and copybara-github committed Nov 11, 2024
1 parent c80749a commit 1eff475
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 61 deletions.
45 changes: 33 additions & 12 deletions python/jax/carfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -2301,7 +2301,9 @@ def run_segment(
weights: CarfacWeights,
state: CarfacState,
open_loop: bool = False,
) -> Tuple[jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
) -> Tuple[
jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray
]:
"""This function runs the entire CARFAC model.
That is, filters a 1 or more channel
Expand Down Expand Up @@ -2335,6 +2337,8 @@ def run_segment(
Returns:
naps: neural activity pattern
naps_fibers: neural activity of different fibers
(only populated with non-zeros when ihc_style equals "two_cap_with_syn")
state: the updated state of the CARFAC model.
BM: The basilar membrane motion
seg_ohc & seg_agc are optional extra outputs useful for seeing what the
Expand All @@ -2343,6 +2347,7 @@ def run_segment(
if len(input_waves.shape) < 2:
input_waves = jnp.reshape(input_waves, (-1, 1))
[n_samp, n_ears] = input_waves.shape
n_fibertypes = SynDesignParameters.n_classes

# TODO(honglinyu): add more assertions using checkify.
# if n_ears != cfp.n_ears:
Expand All @@ -2352,6 +2357,7 @@ def run_segment(

n_ch = hypers.ears[0].car.n_ch
naps = jnp.zeros((n_samp, n_ch, n_ears)) # allocate space for result
naps_fibers = jnp.zeros((n_samp, n_ch, n_fibertypes, n_ears))
bm = jnp.zeros((n_samp, n_ch, n_ears))
seg_ohc = jnp.zeros((n_samp, n_ch, n_ears))
seg_agc = jnp.zeros((n_samp, n_ch, n_ears))
Expand All @@ -2370,7 +2376,7 @@ def run_segment(
# Note that we can use naive for loops here because it will make gradient
# computation very slow.
def run_segment_scan_helper(carry, k):
naps, state, bm, seg_ohc, seg_agc, input_waves = carry
naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves = carry
agc_updated = False
for ear in range(n_ears):
# This would be cleaner if we could just get and use a reference to
Expand All @@ -2385,9 +2391,14 @@ def run_segment_scan_helper(carry, k):
)

if hypers.ears[ear].syn.do_syn:
ihc_out, _, state.ears[ear].syn = syn_step(
ihc_out, firings, state.ears[ear].syn = syn_step(
v_recep, ear, weights, state.ears[ear].syn
)
naps_fibers = naps_fibers.at[k, :, :, ear].set(firings)
else:
naps_fibers = naps_fibers.at[k, :, :, ear].set(
jnp.zeros([jnp.shape(ihc_out)[0], n_fibertypes])
)

# run the AGC update step, decimating internally,
agc_updated, state.ears[ear].agc = agc_step(
Expand Down Expand Up @@ -2420,11 +2431,11 @@ def close_agc_loop_helper(
state,
)

return (naps, state, bm, seg_ohc, seg_agc, input_waves), None
return (naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves), None

return jax.lax.scan(
run_segment_scan_helper,
(naps, state, bm, seg_ohc, seg_agc, input_waves),
(naps, naps_fibers, state, bm, seg_ohc, seg_agc, input_waves),
jnp.arange(n_samp),
)[0][:-1]

Expand All @@ -2442,7 +2453,9 @@ def run_segment_jit(
weights: CarfacWeights,
state: CarfacState,
open_loop: bool = False,
) -> Tuple[jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
) -> Tuple[
jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray
]:
"""A JITted version of run_segment for convenience.
Care should be taken with the hyper parameters in hypers. If the hypers object
Expand All @@ -2468,6 +2481,8 @@ def run_segment_jit(
Returns:
naps: neural activity pattern
naps_fibers: neural activity of the different fiber types
(only populated with non-zeros when ihc_style equals "two_cap_with_syn")
state: the updated state of the CARFAC model.
BM: The basilar membrane motion
seg_ohc & seg_agc are optional extra outputs useful for seeing what the
Expand All @@ -2483,7 +2498,9 @@ def run_segment_jit_in_chunks_notraceable(
state: CarfacState,
open_loop: bool = False,
segment_chunk_length: int = 32 * 48000,
) -> tuple[jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
) -> tuple[
jnp.ndarray, jnp.ndarray, CarfacState, jnp.ndarray, jnp.ndarray, jnp.ndarray
]:
"""Runs the jitted segment runner in segment groups.
Running CARFAC on an audio segment this way is most useful when running
Expand Down Expand Up @@ -2526,6 +2543,7 @@ def run_segment_jit_in_chunks_notraceable(
if len(input_waves.shape) < 2:
input_waves = jnp.reshape(input_waves, (-1, 1))
naps_out = []
naps_fibers_out = []
bm_out = []
ohc_out = []
agc_out = []
Expand All @@ -2534,10 +2552,11 @@ def run_segment_jit_in_chunks_notraceable(
[n_samp, _] = input_waves.shape
if n_samp >= segment_length:
[current_waves, input_waves] = jnp.split(input_waves, [segment_length], 0)
naps_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = run_segment_jit(
current_waves, hypers, weights, state, open_loop
naps_jax, naps_fibers_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = (
run_segment_jit(current_waves, hypers, weights, state, open_loop)
)
naps_out.append(naps_jax)
naps_fibers_out.append(naps_fibers_jax)
bm_out.append(bm_jax)
ohc_out.append(seg_ohc_jax)
agc_out.append(seg_agc_jax)
Expand All @@ -2546,15 +2565,17 @@ def run_segment_jit_in_chunks_notraceable(
[n_samp, _] = input_waves.shape
# Take the last few items and just run them.
if n_samp > 0:
naps_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = run_segment_jit(
input_waves, hypers, weights, state, open_loop
naps_jax, naps_fibers_jax, state, bm_jax, seg_ohc_jax, seg_agc_jax = (
run_segment_jit(input_waves, hypers, weights, state, open_loop)
)
naps_out.append(naps_jax)
naps_fibers_out.append(naps_fibers_jax)
bm_out.append(bm_jax)
ohc_out.append(seg_ohc_jax)
agc_out.append(seg_agc_jax)
naps_out = np.concatenate(naps_out, 0)
naps_fibers_out = np.concatenate(naps_fibers_out, 0)
bm_out = np.concatenate(bm_out, 0)
ohc_out = np.concatenate(ohc_out, 0)
agc_out = np.concatenate(agc_out, 0)
return naps_out, state, bm_out, ohc_out, agc_out
return naps_out, naps_fibers_out, state, bm_out, ohc_out, agc_out
12 changes: 6 additions & 6 deletions python/jax/carfac_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def loss_func(
weights: carfac_jax.CarfacWeights,
state: carfac_jax.CarfacState,
):
nap_output, _, _, _, _ = carfac_jax.run_segment(
nap_output, _, _, _, _, _ = carfac_jax.run_segment(
audio, hypers, weights, state
)
return jnp.sum(nap_output), nap_output
Expand Down Expand Up @@ -242,7 +242,7 @@ def bench_jit_compile_time(state: google_benchmark.State):
# that this benchmark is appropriate.
n_samp += 1
state.resume_timing()
naps_jax, state_jax, _, _, _ = carfac_jax.run_segment_jit(
naps_jax, _, state_jax, _, _, _ = carfac_jax.run_segment_jit(
run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False
)
naps_jax.block_until_ready()
Expand Down Expand Up @@ -295,7 +295,7 @@ def bench_jax_in_slices(state: google_benchmark.State):
for _, segment in enumerate(silence_slices):
if segment.shape not in compiled_shapes:
compiled_shapes.add(segment.shape)
naps_jax, _, _, _, _ = carfac_jax.run_segment_jit(
naps_jax, _, _, _, _, _ = carfac_jax.run_segment_jit(
segment, hypers_jax, weights_jax, state_jax, open_loop=False
)
naps_jax.block_until_ready()
Expand All @@ -316,7 +316,7 @@ def bench_jax_in_slices(state: google_benchmark.State):
jax_loop_state = state_jax
state.resume_timing()
for _, segment in enumerate(run_seg_slices):
seg_naps, jax_loop_state, seg_bm, seg_ohc, seg_agc = (
seg_naps, _, jax_loop_state, seg_bm, seg_ohc, seg_agc = (
carfac_jax.run_segment_jit(
segment, hypers_jax, weights_jax, jax_loop_state, open_loop=False
)
Expand Down Expand Up @@ -389,7 +389,7 @@ def bench_jax(state: google_benchmark.State):
params_jax
)
short_silence = jnp.zeros(shape=(n_samp, n_ears))
naps_jax, state_jax, _, _, _ = run_segment_function(
naps_jax, _, state_jax, _, _, _ = run_segment_function(
short_silence, hypers_jax, weights_jax, state_jax, open_loop=False
)
# This block ensures calculation.
Expand All @@ -404,7 +404,7 @@ def bench_jax(state: google_benchmark.State):
jax.random.normal(random_generator, (n_samp, n_ears)) * _NOISE_FACTOR
).block_until_ready()
state.resume_timing()
naps_jax, state_jax, _, _, _ = run_segment_function(
naps_jax, _, state_jax, _, _, _ = run_segment_function(
run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False
)
if state.range(0) != 1:
Expand Down
2 changes: 1 addition & 1 deletion python/jax/carfac_float64_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def loss(weights, input_waves, hypers, state):
# A loss function for tests. Note that we shouldn't use `run_segment_jit`
# here because it will donate the `state` which causes unnecessary
# inconvenience for tests.
naps_jax, state_jax, _, _, _ = carfac_jax.run_segment(
naps_jax, _, state_jax, _, _, _ = carfac_jax.run_segment(
input_waves, hypers, weights, state, open_loop=False
)
# For testing, just fit `naps` to 1.
Expand Down
29 changes: 18 additions & 11 deletions python/jax/carfac_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,18 +324,25 @@ def test_chunked_naps_same_as_jit(self, random_seed, ihc_style):
state_jax_copied = copy.deepcopy(state_jax)

# Only tests the JITted version because this is what we will use.
naps_jax, _, bm_jax, ohc_jax, agc_jax = carfac_jax.run_segment_jit(
run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False
)
naps_jax_chunked, _, bm_chunked, ohc_chunked, agc_chunked = (
carfac_jax.run_segment_jit_in_chunks_notraceable(
run_seg_input,
hypers_jax,
weights_jax,
state_jax_copied,
open_loop=False,
naps_jax, _, _, bm_jax, ohc_jax, agc_jax = (
carfac_jax.run_segment_jit(
run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False
)
)
(
naps_jax_chunked,
_,
_,
bm_chunked,
ohc_chunked,
agc_chunked,
) = carfac_jax.run_segment_jit_in_chunks_notraceable(
run_seg_input,
hypers_jax,
weights_jax,
state_jax_copied,
open_loop=False,
)
self.assertLess(jnp.max(abs(naps_jax_chunked - naps_jax)), 1e-7)
self.assertLess(jnp.max(abs(bm_chunked - bm_jax)), 1e-7)
self.assertLess(jnp.max(abs(ohc_chunked - ohc_jax)), 1e-7)
Expand Down Expand Up @@ -380,7 +387,7 @@ def test_equal_forward_pass(
run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears))

# Only tests the JITted version because this is what we will use.
naps_jax, state_jax, bm_jax, seg_ohc_jax, seg_agc_jax = (
naps_jax, _, state_jax, bm_jax, seg_ohc_jax, seg_agc_jax = (
carfac_jax.run_segment_jit(
run_seg_input, hypers_jax, weights_jax, state_jax, open_loop=False
)
Expand Down
33 changes: 25 additions & 8 deletions python/jax/carfac_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def run_multiple_segment_states_shmap(
open_loop: bool = False,
) -> Sequence[
Tuple[
jnp.ndarray,
jnp.ndarray,
carfac_jax.CarfacState,
jnp.ndarray,
Expand Down Expand Up @@ -88,30 +89,39 @@ def parallel_helper(input_waves, state):
"""
input_waves = input_waves[0]
state = jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), state)
naps, ret_state, bm, seg_ohc, seg_agc = carfac_jax.run_segment_jit(
input_waves, hypers, weights, state, open_loop
naps, naps_fibers, ret_state, bm, seg_ohc, seg_agc = (
carfac_jax.run_segment_jit(
input_waves, hypers, weights, state, open_loop
)
)
ret_state = jax.tree_util.tree_map(
lambda x: jnp.asarray(x).reshape((1, -1)), ret_state
)
return (
naps[None],
naps_fibers[None],
ret_state,
bm[None],
seg_ohc[None],
seg_agc[None],
)

stacked_naps, stacked_states, stacked_bm, stacked_ohc, stacked_agc = (
parallel_helper(input_waves_array, batch_state)
)
(
stacked_naps,
stacked_naps_fibers,
stacked_states,
stacked_bm,
stacked_ohc,
stacked_agc,
) = parallel_helper(input_waves_array, batch_state)
output_states = _tree_unstack(stacked_states)
output = []
# TODO(robsc): Modify this for loop to a jax.lax loop, and then JIT the
# whole function rather than internal use of run_segment_jit.
for i, output_state in enumerate(output_states):
tup = (
stacked_naps[i],
stacked_naps_fibers[i],
output_state,
stacked_bm[i],
stacked_ohc[i],
Expand All @@ -130,6 +140,7 @@ def run_multiple_segment_pmap(
open_loop: bool = False,
) -> Sequence[
Tuple[
jnp.ndarray,
jnp.ndarray,
carfac_jax.CarfacState,
jnp.ndarray,
Expand All @@ -155,15 +166,21 @@ def run_multiple_segment_pmap(
in_axes=(0, None, None, None, None),
static_broadcasted_argnums=[1, 4],
)
stacked_naps, stacked_states, stacked_bm, stacked_ohc, stacked_agc = pmapped(
input_waves_array, hypers, weights, state, open_loop
)
(
stacked_naps,
stacked_naps_fibers,
stacked_states,
stacked_bm,
stacked_ohc,
stacked_agc,
) = pmapped(input_waves_array, hypers, weights, state, open_loop)

output_states = _tree_unstack(stacked_states)
output = []
for i, output_state in enumerate(output_states):
tup = (
stacked_naps[i],
stacked_naps_fibers[i],
output_state,
stacked_bm[i],
stacked_ohc[i],
Expand Down
Loading

0 comments on commit 1eff475

Please sign in to comment.