Skip to content

Commit

Permalink
increase M, N, K, and PP
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahanxie353 committed Nov 28, 2024
1 parent f7c6363 commit 680f364
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions tests/dataflow/test_packed_systolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
import allo.backend.hls as hls
import numpy as np

L, D = 2, 2
M, N, K = L, 1 * D, D
PP = 2
M, N, K = 8, 8, 4
PP = 4
P0, P1 = M // PP + 2, N + 2

if PP == 2:
np_type = np.int16
allo_type = int16
elif PP == 4:
np_type = np.int32
allo_type = int32
else:
raise ValueError(f"Unsupported packing factor: {PP}")

Expand All @@ -26,9 +28,9 @@ def top():

@df.kernel(mapping=[P0, P1])
def gemm(
X_packed: allo_type[L // PP, D],
W_packed: allo_type[D, 1 * D // PP],
Z_packed: allo_type[L // PP, 1 * D],
X_packed: allo_type[M, K // PP],
W_packed: allo_type[K // PP, N],
Z_packed: allo_type[M // PP, N],
):
i, j = df.get_pid()
# Peripheral kernels
Expand All @@ -37,11 +39,11 @@ def gemm(
with allo.meta_elif(j == 0):
# i > 0
for k in range(K):
fifo_A[i, j + 1].put(X_packed[i - 1, k])
fifo_A[i, j + 1].put(X_packed[(i - 1) * PP, k])
with allo.meta_elif(i == 0):
# j > 0
for k in range(K):
fifo_B[i + 1, j].put(W_packed[j // PP, 0])
fifo_B[i + 1, j].put(W_packed[k // PP, j - 1])

# drain
with allo.meta_elif(i == M // PP + 1 and j > 0):
Expand All @@ -68,14 +70,14 @@ def gemm(


def test_packed_systolic():
X = np.random.randint(-4, 4, size=(L, D)).astype(np.int8)
W_A_cst = np.random.randint(-4, 4, size=(D, 1 * D)).astype(np.int8)
X = np.random.randint(-4, 4, size=(M, K)).astype(np.int8)
W_A_cst = np.random.randint(-4, 4, size=(K, N)).astype(np.int8)

packed_X = np.ascontiguousarray(np.ascontiguousarray(X).view(np_type).transpose())
packed_X = np.ascontiguousarray(np.ascontiguousarray(X).view(np_type))
W_A_packed = np.ascontiguousarray(
np.ascontiguousarray(W_A_cst.transpose()).view(np_type).transpose()
)
Z_packed = np.zeros((L // PP, 1 * D), dtype=np_type)
Z_packed = np.zeros((M // PP, N), dtype=np_type)
mod = df.build(top)
if hls.is_available("vitis_hls"):
mod(packed_X, W_A_packed, Z_packed)
Expand Down

0 comments on commit 680f364

Please sign in to comment.