generated from kyegomez/Python-Package-Template
-
-
Notifications
You must be signed in to change notification settings - Fork 24
/
example.py
32 lines (28 loc) · 1.04 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# Import the necessary libraries
import torch
from mm_mamba import MultiModalMambaBlock
# Build dummy multimodal inputs: a token sequence and an RGB image.
text_tokens = torch.randn(1, 16, 64)  # (batch_size, sequence_length, feature_dim)
image = torch.randn(1, 3, 64, 64)  # (batch_size, num_channels, image_height, image_width)

# Instantiate the multimodal fusion block; text and vision dims are kept equal (64)
# so the two modalities can be fused directly.
model = MultiModalMambaBlock(
    dim=64,  # token embedding dimension
    depth=5,  # number of transformer layers
    dropout=0.1,  # dropout probability
    heads=4,  # attention heads
    d_state=16,  # state embedding dimension
    image_size=64,  # input image size
    patch_size=16,  # image patch size
    encoder_dim=64,  # encoder token embedding dimension
    encoder_depth=5,  # encoder transformer layers
    encoder_heads=4,  # encoder attention heads
    fusion_method="mlp",
)

# Run one forward pass over both modalities and report the output shape.
out = model(text_tokens, image)
print(out.shape)