From aec1c6ed95b99b5f2de4659cd204f538b68a4a8e Mon Sep 17 00:00:00 2001 From: The Jaxite Team Date: Wed, 9 Oct 2024 13:42:31 -0700 Subject: [PATCH] Formatted and Verified High-Performance MatMul. PiperOrigin-RevId: 684157099 --- jaxite/jaxite_lib/matrix_utils.py | 449 ++++++++++++++++++++++++- jaxite/jaxite_lib/matrix_utils_test.py | 82 +++++ 2 files changed, 530 insertions(+), 1 deletion(-) diff --git a/jaxite/jaxite_lib/matrix_utils.py b/jaxite/jaxite_lib/matrix_utils.py index 543d3de..900b48a 100644 --- a/jaxite/jaxite_lib/matrix_utils.py +++ b/jaxite/jaxite_lib/matrix_utils.py @@ -84,7 +84,7 @@ def i32_as_u8_matmul(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.ndarray: "np,nkq->kpq", lhs, rhs, - preferred_element_type=jnp.int32, + preferred_element_type=jnp.uint32, ) shift_factors = jnp.array( [ @@ -98,6 +98,453 @@ def i32_as_u8_matmul(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.ndarray: return jnp.sum(i8_products << shift_factors, axis=(1, 2)) +def hpmatmul_conv_adapt_outer_product(x: jax.Array, y: jax.Array) -> jax.Array: + """Interleaved u8 matmul with fused einsum kernels. + + Args: + x: The left matrix. + y: The right matrix. + + Returns: + The result matrix. + """ + assert x.dtype == jnp.uint32 + assert y.dtype == jnp.uint32 + + lhs: jax.Array = int32_to_int8_arr(x) + rhs: jax.Array = int32_to_int8_arr(y) + + i8_products = jnp.einsum( + "mnp,nkq->mkpq", + lhs, + rhs, + preferred_element_type=jnp.uint32, + ) + shift_factors = jnp.array( + [ + [0, 8, 16, 24], + [8, 16, 24, 32], + [16, 24, 32, 40], + [24, 32, 40, 48], + ], + dtype=jnp.uint32, + ) + return jnp.sum(i8_products.astype(jnp.uint64) << shift_factors, axis=(2, 3)) + + +@jax.jit +def hpmatmul_conv_adapt_conv(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + """Interleaved u8 matmul with padded 1D convolution. + + (reformulated as 2D convolution) + + How do we map workload into Conv? + + Left Mat Right Mat -> + <- in channel (C)-> <-Output Channel(O)-> -> + - xxxxxxxxxxxxxxxxxx - xxxxxxxxxxxxxxxxxx -> - + ^ xxxxxxxxxxxxxxxxxx ^ xxxxxxxxxxxxxxxxxx -> ^ + | xxxxxxxxxxxxxxxxxx | xxxxxxxxxxxxxxxxxx -> | + batch xxxxxxxxxxxxxxxxxx In xxxxxxxxxxxxxxxxxx -> batch + (N) xxxxxxxxxxxxxxxxxx channel xxxxxxxxxxxxxxxxxx -> (N) + | xxxxxxxxxxxxxxxxxx (I) xxxxxxxxxxxxxxxxxx -> | + v xxxxxxxxxxxxxxxxxx v xxxxxxxxxxxxxxxxxx -> v + - xxxxxxxxxxxxxxxxxx - xxxxxxxxxxxxxxxxxx -> - + + Result Mat + <-Output channel(C)-> + xxxxxxxxxxxxxxxxxx + xxxxxxxxxxxxxxxxxx + xxxxxxxxxxxxxxxxxx + xxxxxxxxxxxxxxxxxx + xxxxxxxxxxxxxxxxxx + xxxxxxxxxxxxxxxxxx + xxxxxxxxxxxxxxxxxx + xxxxxxxxxxxxxxxxxx + + Each x in the above example is a 1DConv + + <---W---> <---W---> <---W---> + xxxxxxxxx @ xxxxxxxxx = xxxxxxxxx + + Args: + x: The left matrix. + y: The right matrix. + + Returns: + The result matrix. 
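+
+  Why the shift-and-add recombination is valid (illustrative sketch, stated
+  here for a single pair of uint32 inputs; the accumulation over the shared
+  matmul dimension happens inside the same convolution): writing
+  x = sum_p x_p * 2^(8p) and y = sum_q y_q * 2^(8q) with u8 chunks x_p, y_q
+  (p, q in 0..3), the product is sum_{p,q} x_p * y_q * 2^(8(p + q)). The
+  padded 1D convolution groups these terms by r = p + q, yielding the seven
+  partial sums c_r = sum_{p+q=r} x_p * y_q, and the final shift-and-add with
+  shift_factors = [0, 8, ..., 48] reassembles them into the full product in
+  uint64.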
+ """ + + assert x.dtype == jnp.uint32 + assert y.dtype == jnp.uint32 + + lhs: jax.Array = jax.lax.bitcast_convert_type(x, new_dtype=jnp.uint8) # bnmp + rhs: jax.Array = jax.lax.bitcast_convert_type(y, new_dtype=jnp.uint8) # nk1q + # https://github.com/google/jax/issues/11483 + rhs = jax.lax.rev(rhs, [2]) + # rhs = jlax.rev(rhs, dimensions=[3]) + + # basically an einsum of "mnp,nkq->mk(p+q)" but jax einsum doesn't support + # convolution yet + u8_products = jax.lax.conv_general_dilated( + lhs, + rhs, + window_strides=(1,), + padding=((3, 3),), + dimension_numbers=("NCW", "IOW", "NCW"), + preferred_element_type=jnp.uint32, + ) + + shift_factors = jnp.array([0, 8, 16, 24, 32, 40, 48], dtype=jnp.uint32) + return jnp.sum(u8_products.astype(jnp.uint64) << shift_factors, axis=(2,)) + + +def chunk_decomposition(x, chunkwidth=8): + """Precision-level data conversion. + + Args: + x: The input data. + chunkwidth: The chunkwidth. + + Returns: + The decomposed data. + """ + dtype = jnp.uint8 + if chunkwidth == 16: + dtype = jnp.uint16 + elif chunkwidth == 32: + dtype = jnp.uint32 + + elements = [] + mask = (1 << chunkwidth) - 1 + # Mask to extract the lower bits (e.g., 32 bits -> 0xFFFFFFFF) + + # Extract each element from the integer + while x > 0: + elements.append(x & mask) # Extract the lower bits + x >>= chunkwidth # Shift to remove the extracted bits + + # Convert the list to a JAX array + return jnp.array(elements, dtype=dtype) + + +def rechunkify_after_chunkwise_add(arr_a, chunkwidth): + """Rechunkify after chunkwise add. + + Args: + arr_a: The input array. + chunkwidth: The chunkwidth. + + Returns: + The rechunkified array. + """ + dtype_double_length = jnp.uint16 + if chunkwidth == 16: + dtype_double_length = jnp.uint32 + elif chunkwidth == 32: + dtype_double_length = jnp.uint64 + + # assert isinstance(arr_a, jnp.array) + # assume the precision of partial sum is <= 2 * precision of input value. + bitmask = (1 << chunkwidth) - 1 + + # # Data Type Illustration + # We need to accumulate these data + # - Could directly perform bitwidth concatenation to generate the final + # result if there is no overlap across each partial sum + # LSB MSB + # |-----------------> bit + # | a0 + # | ==-- + # | a1 + # | ==-- + # | a2 + # | ==-- + # | a3 + # v ==-- + + # whole a0 a1 a2 a3 + # precision ==-- ==-- ==-- ==-- + + # lower a0 a1 a2 a3 + # half == == == == + + # upper a0 a1 a2 a3 + # half -- -- -- -- + + # # Chunk Splitting -> upper and lower half + # padding to align + # lower a0 a1 a2 a3 0 + # half == == == == == + + # upper 0 a0 a1 a2 a3 + # half -- -- -- -- -- + + # # Vectorized Accumulation + # lower a0 a1 a2 a3 0 + # half == == == == == + # + + + + + + # upper 0 a0 a1 a2 a3 + # half -- -- -- -- -- + + # -> result b0 b1 b2 b3 b4 + # -- 1/0-- 1/0-- 1/0-- -- + # (b1 and b4 does not have carry for sure.) + + # Each result chunk might have one more bit for carry. + # Perform one more chunk decomposition and accumulation. + + # # One more Chunk Splitting for partial sum "b" to take care of carry bit. + # carry b0 b1 b2 b3 b4 + # 0 1/0 1/0 1/0 0 + + # carry b4 b0 b1 b2 b3 + # right 0 0 1/0 1/0 1/0 + # shift + # (wrap around rotation, b4 is always zero so will be correct) + # + + + + + + # lower b0 b1 b2 b3 b4 + # half -- -- -- -- -- + # = = = = = + # c0 c1 c2 c3 c4 + # -> -- -- -- -- 1/0-- + # (! 
c4 might overflow, need one more chunk decomposition) + + # c0 c1 c2 c3 c4 c5 + # -> -- -- -- -- -- 1/0 + + # Chunk Splitting -> upper and lower half + arr_a_lower_half = jnp.bitwise_and(arr_a, bitmask) + arr_a_upper_half = jnp.right_shift(arr_a, chunkwidth) + + # Padding to align + arr_a_lower_half_pad = jnp.pad(arr_a_lower_half, (0, 1)) + arr_a_upper_half_pad = jnp.pad(arr_a_upper_half, (1, 0)) + + # Vectorized Accumulation + arr_b = jnp.add( + arr_a_lower_half_pad.astype(dtype_double_length), + arr_a_upper_half_pad.astype(dtype_double_length), + ) + + while not jnp.all(arr_b <= bitmask): + arr_b_lower_half = jnp.bitwise_and(arr_b, bitmask) + arr_b_carry = jnp.right_shift(arr_b, chunkwidth) + arr_b = jnp.roll(arr_b_carry, 1, axis=-1) + arr_b = jnp.add(arr_b_lower_half, arr_b) + + # Vectorized Accumulation + arr_c = arr_b + + # break top chunk into upper and lower to avoid overflow. + arr_c = jnp.pad(arr_c, (0, 1)) + arr_c = arr_c.at[-1].set(jnp.right_shift(arr_c[-2], chunkwidth)) + arr_c = arr_c.at[-2].set(jnp.bitwise_and(arr_c[-2], bitmask)) + + return arr_c + + +def smul_as_dense_gemv_bag( + x, total_in_precision=32, chunkwidth=8, q=4294967291 +): + """This is the implementation of BAG; Major improvement to achieve dense matrix. + + Args: + x: The input matrix. + total_in_precision: The total precision of the input matrix. + chunkwidth: The chunkwidth. + q: The modulus. + + Returns: + The dense matrix. + + Steps: + 1. break x into [x0, x1, x2, x3] + 2. reform [x0, x1, x2, x3] into the output + [ + x0 r00 r00 r00 # 2^0 + x1 x0+r01 r01 r01 # 2^8 + x2 x1+r02 x0+r02 r02 # 2^16 + x3 x2+r03 x1+r03 x0+r03 # 2^24 + ] + """ + dtype_double_length = jnp.uint16 + chunk_upper_bound = (1 << 8) - 1 + if chunkwidth == 16: + dtype_double_length = jnp.uint32 + chunk_upper_bound = (1 << 16) - 1 + elif chunkwidth == 32: + dtype_double_length = jnp.uint64 + chunk_upper_bound = (1 << 32) - 1 + + total_chunk_num = int(jnp.ceil(total_in_precision / chunkwidth)) + + # the number of row in left matrix + height = total_chunk_num + total_chunk_num - 1 + x_dtype = chunk_decomposition(int(x), chunkwidth) + x_dense = jnp.zeros( + (total_chunk_num + total_chunk_num - 1, total_chunk_num), + dtype=dtype_double_length, + ) + for j in range(total_chunk_num): + upper_idx = min(total_chunk_num, x_dtype.shape[0] + j) + x_dense = x_dense.at[j:upper_idx, j].set(x_dtype[0 : upper_idx - j]) + + # [ + # x0 # 2^0 + # x1 x0 # 2^8 + # x2 x1 x0 # 2^16 + # x3 x2 x1 x0 # 2^24 + # ----------- + # x3 x2 x1 # 2^32 iterate all elements in the bottom block + # x3 x2 # 2^40 + # x3 # 2^48 + # ] + + # Perform BAG to the following block of the matrix + # j 2 1 0 + # x3 x2 x1 # 2^32 i=0 + # x3 x2 # 2^40 i=1 + # x3 # 2^48 i=2 + + for i in range(x_dtype.shape[0] - 1): + for j in range(x_dtype.shape[0] - 1 - i): + basis = (total_chunk_num + i) * chunkwidth + projected_data = (int(x_dtype[i + j + 1]) << basis) % q + r = chunk_decomposition(projected_data, chunkwidth).astype( + dtype_double_length + ) + + x_dense = x_dense.at[: len(r), total_chunk_num - 1 - j].set( + jnp.add(r, x_dense[: len(r), total_chunk_num - 1 - j]) + ) + + for j in range(x_dtype.shape[0] - 1): + # Iterate over different columns + if not jnp.all(x_dense[:, total_chunk_num - 1 - j] <= chunk_upper_bound): + arr_new_chunkified = rechunkify_after_chunkwise_add( + x_dense[:, total_chunk_num - 1 - j], chunkwidth + ) + x_dense = x_dense.at[:, total_chunk_num - 1 - j].set( + arr_new_chunkified[:height] + ) + + while not jnp.all(x_dense <= chunk_upper_bound): + for j in 
range(total_chunk_num - 1): + # Iterate over different columns + if not jnp.all(x_dense[:, total_chunk_num - 1 - j] <= chunk_upper_bound): + arr_new_chunkified = rechunkify_after_chunkwise_add( + x_dense[:, total_chunk_num - 1 - j], chunkwidth + ) + x_dense = x_dense.at[:, total_chunk_num - 1 - j].set( + arr_new_chunkified[:height] + ) + + # j 2 1 0 + # x3 x2 x1 # 2^32 i=0 + # x3 x2 # 2^40 i=1 + # x3 # 2^48 i=2 + + for i in range(total_chunk_num - 1): + if x_dense[total_chunk_num + i, total_chunk_num - 1 - j] > 0: + basis = (total_chunk_num + i) * chunkwidth + projected_data = ( + int(x_dense[total_chunk_num + i, total_chunk_num - 1 - j]) + << basis + ) % q + r = chunk_decomposition(projected_data, chunkwidth).astype( + dtype_double_length + ) + x_dense = x_dense.at[: len(r), total_chunk_num - 1 - j].set( + jnp.add(r, x_dense[: len(r), total_chunk_num - 1 - j]) + ) + + x_dense = x_dense.at[ + total_chunk_num + i, total_chunk_num - 1 - j + ].set(0) + return x_dense[:total_chunk_num, :]#.astype(dtype) + + +def hpmatmul_offline_compile_bag(mat_a, q): + """Convert the input (m,n) matrix into (m,n,p,q), i.e. + + replace each element in the original matrix by a p*q matrix (p==q). + + Args: + mat_a: The input matrix. + q: The modulus. + + Returns: + The converted matrix. + """ + assert mat_a.dtype == jnp.uint32 # This version is defined for 32-bit input. + if isinstance(mat_a, list): + m, n = len(mat_a), len(mat_a[0]) + else: + m, n = mat_a.shape[0], mat_a.shape[1] + total_in_precision = 32 + chunkwidth = 8 + # Convert left-side matrix + total_chunk_num = int(jnp.ceil(total_in_precision / chunkwidth)) + + left_mat = jnp.zeros( + (m, n, total_chunk_num, total_chunk_num), dtype=jnp.uint16 + ) + + if isinstance(mat_a, list): + for i in range(m): + for k in range(n): + left_mat = left_mat.at[i, k, :, :].set( + smul_as_dense_gemv_bag( + mat_a[i][k], + total_in_precision=total_in_precision, + chunkwidth=chunkwidth, + q=q, + ) + ) + else: + for i in range(m): + for k in range(n): + left_mat = left_mat.at[i, k, :, :].set( + smul_as_dense_gemv_bag( + mat_a[i, k], + total_in_precision=total_in_precision, + chunkwidth=chunkwidth, + q=q, + ) + ) + + return left_mat + + +@jax.jit +def hpmatmul_bag_adapt(lhs: jax.Array, y: jax.Array): + """Input (m, n) Left Matrix -> (m, n, p, q) Left Matrix, where each element in the original (m, n) matrix is replaced by a (p, q) matrix.""" + + rhs: jax.Array = jax.lax.bitcast_convert_type(y, new_dtype=jnp.uint8) + i8_products = jnp.einsum( + "mnpq,nkq->mkp", + lhs, + rhs, + preferred_element_type=jnp.int32, + ) + shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint32) + return jnp.sum(i8_products.astype(jnp.uint64) << shift_factors, axis=(2,)) + + +def hpmatmul_golden(mat_a, mat_b, modulus_32): + mat_reference_result = [] + for i in range(mat_a.shape[0]): + mat_reference_result_row = [] + for j in range(mat_b.shape[1]): + acc_res = 0 + for k in range(mat_a.shape[1]): + acc_res += int(mat_a[i, k]) * int(mat_b[k, j]) + mat_reference_result_row.append(acc_res % modulus_32) + mat_reference_result.append(mat_reference_result_row) + return mat_reference_result + + # https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.tril.html # For n=3, generates the following # [[ 1 -1 -1] diff --git a/jaxite/jaxite_lib/matrix_utils_test.py b/jaxite/jaxite_lib/matrix_utils_test.py index c897829..7871aa6 100644 --- a/jaxite/jaxite_lib/matrix_utils_test.py +++ b/jaxite/jaxite_lib/matrix_utils_test.py @@ -3,6 +3,7 @@ import hypothesis from hypothesis import strategies +import jax import 
jax.numpy as jnp
 from jaxite.jaxite_lib import jax_helpers
 from jaxite.jaxite_lib import matrix_utils
@@ -270,5 +271,86 @@ def test_scale_by_x_power_n_minus_1(self, power, poly):
     np.testing.assert_array_equal(expected, actual)


+def test_hpmatmul_outerproduct():
+  """Test the outer-product (fused einsum) kernel against the golden reference."""
+  key = jax.random.key(0)
+  mat_a_shape = (4, 16)
+  mat_b_shape = (mat_a_shape[1], 4)
+  upper_value = (1 << 28) - 1
+  modulus_32 = 4294967291
+  modulus_64 = jnp.array(modulus_32, dtype=jnp.uint64)
+  mat_a = jax.random.randint(
+      key, mat_a_shape, 0, upper_value, dtype=jnp.uint32
+  )
+  mat_b = jax.random.randint(
+      key, mat_b_shape, 0, upper_value, dtype=jnp.uint32
+  )
+
+  mat_reference_result = matrix_utils.hpmatmul_golden(mat_a, mat_b, modulus_32)
+  mat_result_outerproduct = matrix_utils.hpmatmul_conv_adapt_outer_product(
+      mat_a, mat_b
+  )
+  mat_result_outerproduct = mat_result_outerproduct % modulus_64
+
+  np.testing.assert_array_equal(mat_result_outerproduct, mat_reference_result)
+  print('pass testing mat_result_outerproduct == mat_reference_result')
+
+
+def test_hpmatmul_bag():
+  """Test the BAG kernel against the golden reference."""
+  key = jax.random.key(0)
+  mat_a_shape = (4, 16)
+  mat_b_shape = (mat_a_shape[1], 4)
+  upper_value = (1 << 28) - 1
+  modulus_32 = 4294967291
+  modulus_64 = jnp.array(modulus_32, dtype=jnp.uint64)
+  mat_a = jax.random.randint(
+      key, mat_a_shape, 0, upper_value, dtype=jnp.uint32
+  )
+  mat_b = jax.random.randint(
+      key, mat_b_shape, 0, upper_value, dtype=jnp.uint32
+  )
+
+  mat_reference_result = matrix_utils.hpmatmul_golden(mat_a, mat_b, modulus_32)
+  compiled_mat_a = matrix_utils.hpmatmul_offline_compile_bag(
+      mat_a, modulus_32
+  )
+
+  mat_result_bag = matrix_utils.hpmatmul_bag_adapt(compiled_mat_a, mat_b)
+  mat_result_bag = mat_result_bag % modulus_64
+
+  # Sanity check: report any mismatching entries before asserting.
+  for i in range(mat_a.shape[0]):
+    for j in range(mat_b.shape[1]):
+      if mat_result_bag[i, j] != mat_reference_result[i][j]:
+        print(
+            f'mat_result_bag[{i}, {j}]={mat_result_bag[i, j]} does not match'
+            f' mat_reference_result[{i}, {j}]={mat_reference_result[i][j]}'
+        )
+
+  np.testing.assert_array_equal(mat_result_bag, mat_reference_result)
+  print('pass testing mat_result_bag == mat_reference_result')
+
+
+def test_hpmatmul_conv_adapt_conv():
+  """Test that the Conv-Adapt-Conv kernel matches the outer-product kernel."""
+  if jax_helpers.get_tpu_version() >= 4:
+    key = jax.random.key(0)
+    mat_a_shape = (16, 16)
+    mat_b_shape = (mat_a_shape[1], 16)
+    upper_value = (1 << 28) - 1
+    mat_a = jax.random.randint(
+        key, mat_a_shape, 0, upper_value, dtype=jnp.uint32
+    )
+    mat_b = jax.random.randint(
+        key, mat_b_shape, 0, upper_value, dtype=jnp.uint32
+    )
+    mat_result_outerproduct = matrix_utils.hpmatmul_conv_adapt_outer_product(
+        mat_a, mat_b
+    )
+    mat_result_conv = matrix_utils.hpmatmul_conv_adapt_conv(mat_a, mat_b)
+    np.testing.assert_array_equal(mat_result_outerproduct, mat_result_conv)
+
+
 if __name__ == '__main__':
   absltest.main()
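Usage sketch for the kernels added above (a minimal example, not taken verbatim from the patch; the shapes, the 2**28 operand bound, and the modulus 4294967291 mirror the new tests):

import jax
import jax.numpy as jnp

from jaxite.jaxite_lib import matrix_utils

q = 4294967291  # 32-bit prime modulus used by the tests.
modulus_64 = jnp.array(q, dtype=jnp.uint64)
key = jax.random.key(0)

# uint32 operands bounded below 2**28, as in the tests.
mat_a = jax.random.randint(key, (4, 16), 0, (1 << 28) - 1, dtype=jnp.uint32)
mat_b = jax.random.randint(key, (16, 4), 0, (1 << 28) - 1, dtype=jnp.uint32)

# Fused-einsum outer-product kernel; reduce mod q in uint64.
out_einsum = matrix_utils.hpmatmul_conv_adapt_outer_product(mat_a, mat_b)
out_einsum = out_einsum % modulus_64

# BAG kernel: compile the left matrix offline once, then run the fused matmul.
mat_a_bag = matrix_utils.hpmatmul_offline_compile_bag(mat_a, q)
out_bag = matrix_utils.hpmatmul_bag_adapt(mat_a_bag, mat_b) % modulus_64

# Plain-Python golden reference (list of lists), for comparison.
reference = matrix_utils.hpmatmul_golden(mat_a, mat_b, q)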