You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Thanks a lot of this wonderful work. I am wondering if you could help me in running this on Tensor flow? I did translate the main components but I am getting errors.
import tensorflow as tf
import tensorflow.keras.layers as layers
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Embedding, Dropout
import tensorflow_addons as tfa
from tensorflow.keras.models import Model
from einops import rearrange
from einops.layers.tensorflow import Rearrange
# Helper function for pairing dimensions
def pair(t):
return t if isinstance(t, tuple) else (t, t)
# PreNorm Layer
class PreNorm(layers.Layer):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = layers.LayerNormalization(axis=-1)
def call(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
# FeedForward Layer
class FeedForward(layers.Layer):
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.dense1 = layers.Dense(hidden_dim, input_shape=(None, dim), use_bias=False)
self.gelu = tfa.layers.GELU()
self.dropout1 = layers.Dropout(dropout)
self.dense2 = layers.Dense(dim, input_shape=(None, hidden_dim))
self.dropout2 = layers.Dropout(dropout)
def call(self, x, training=True):
x = self.dense1(x)
x = self.gelu(x)
x = self.dropout1(x, training=training)
x = self.dense2(x)
x = self.dropout2(x, training=training)
return x
class Image2Patch_Embedding(layers.Layer):
def __init__(self, patch_size, channels, dim):
super().__init__()
patch_height, patch_width = pair(patch_size)
patch_dim = channels * patch_height * patch_width
self.im2patch = layers.Lambda(lambda x: tf.reshape(x, [-1, patch_height, patch_width, channels]))
self.patch2latentv = layers.Dense(units=dim, input_shape=(patch_dim,))
def call(self, x):
x = self.im2patch(x)
x = self.patch2latentv(x)
return x
class Latentv2Image(layers.Layer):
def __init__(self, patch_size, channels, dim):
patch_height, patch_width = pair(patch_size)
self.latentv2patch = tf.keras.layers.Dense(units=channels * patch_height * patch_width, input_shape=(dim,))
self.vec2square = Rearrange('(c h w) -> c h w', c = channels, h = patch_height, w = patch_width)
def call(self, x):
return x
# CNN Block
class ConvBlock(layers.Layer):
def __init__(self, channels, hidden_channels, **kwargs):
super(ConvBlock, self).__init__(**kwargs)
self.conv1 = tf.keras.layers.Conv2D(
filters=hidden_channels,
kernel_size=(3, 3),
strides=1,
padding="same",
activation="gelu",
)
def call(self, inputs):
x = self.conv1(inputs)
return x
# CNN Layer
class CNN(layers.Layer):
def __init__(self, channels, hidden_channels, depth):
super(CNN, self).__init__()
if depth == 1:
self.m_list = [ConvBlock(channels, channels)]
else:
self.m_list = [ConvBlock(channels, hidden_channels)]
for i in range(depth-2):
self.m_list.append(ConvBlock(hidden_channels, hidden_channels))
self.m_list.append(ConvBlock(hidden_channels, channels))
self.blocks = tf.keras.Sequential(self.m_list)
def call(self, x):
return self.blocks(x)
class ReconstructionConv(layers.Layer):
def __init__(self, dim, hidden_channels, patch_size, image_size, cnn_depth, channels=3):
super(ReconstructionConv, self).__init__()
self.dim = dim
self.image_size = image_size
self.patch_size = patch_size
patch_height, patch_width = pair(patch_size)
self.patch_dim = channels * patch_height * patch_width
self.latent2patchv = tf.keras.layers.Dense(units=self.patch_dim, input_shape=(self.dim,))
self.patchv2patch = Rearrange('b p (c h w) -> b p c h w', c=channels, h=patch_height, w=patch_width)
self.net = CNN(
channels=channels, hidden_channels=hidden_channels, depth=cnn_depth
)
self.embedding = Image2Patch_Embedding(
patch_size=patch_height, channels=channels, dim=dim
)
self.fc = tf.keras.layers.Dense(units=self.dim, input_shape=(self.dim,))
def reconstruct(self, patchs):
image_height, image_width = pair(self.image_size)
patch_height, patch_width = pair(self.patch_size)
h = int(image_height / patch_height)
w = int(image_width / patch_width)
images = []
for i in range(h):
raw = []
for j in range(w):
raw.append(patchs[:, i*h + j, :, :])
raw = torch.cat(raw, dim=3)
images.append(raw)
images = tf.concat(images, axis=2)
return images
def call(self, x):
num_patchs = x.size()[1] - 1
cls_tokens, x = tf.split(x, [1, num_patchs], axis=1)
x = self.latent2patchv(x)
x = self.patchv2patch(x)
x = self.reconstruct(x)
x = self.net(x)
x = self.embedding(x)
x = tf.reshape(x, (-1, num_patchs, self.dim))
x = tf.concat((cls_tokens, x), axis=1)
x = self.fc(x)
return x
# Attention Layer
class Attention(layers.Layer):
def __init__(self, dim, heads=8, dim_head = 64,dropout=0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scales = dim_head ** (-0.5)
self.to_qkv = layers.Dense(inner_dim*3, input_shape=(dim,), use_bias=False)
self.out_dense = layers.Dense(dim, input_shape=(None, dim))
self.out_dropout = layers.Dropout(dropout)
self.to_out = tf.keras.Sequential(layers=[layers.Dense(inner_dim, input_shape=(None, dim)),
layers.Dropout(dropout)], name='to_out')
def to_out(self, x, training=True):
out = self.out_dense(x)
out = self.out_dropout(out, training=training)
return out
def call(self, x, mask=None, training=True):
b, n, _, h = *x.shape, self.heads
qkv = tf.split(self.to_qkv(x), 3, axis=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
dots = tf.einsum("bhid, bhjd -> bhij", q, k) * self.scales
if mask is not None:
mask = tf.pad(mask.flatten(1), (1, 0), constant_values=True)
assert mask.shape[-1] == dots.shape[-1], "Mask has incorrect dimensions"
mask = mask[:, None, :] * mask[:, :, None]
dots[~mask] = tf.fill(dots[~mask], float('-inf'))
del mask
attn = tf.nn.softmax(dots, axis=-1)
out = tf.einsum("bhij,bhjd->bhid", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.to_out(out, training=True)
return out
# ConvolutionalTransformer Layer
class ConvolutionalTransformer(layers.Layer):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, hidden_channels, patch_size, image_size, cnn_depth, dropout=0.):
super(ConvolutionalTransformer, self).__init__()
self.layers = []
for _ in range(depth):
self.layers.append([
PreNorm(dim, Attention(dim=dim, heads=heads, dim_head=dim_head, dropout=dropout)),
PreNorm(dim, ReconstructionConv(dim=dim, hidden_channels=hidden_channels, patch_size=patch_size, image_size=image_size, cnn_depth=cnn_depth)),
PreNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim, dropout=dropout))
])
def call(self, x):
for attn, rc, ff in self.layers:
x = attn(x) + x
identity = x
x = rc(x)
x = ff(x) + identity
return x
# VisionConformer Model
class VisionConformer(Model):
def __init__(
self,
*,
image_size,
patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
pool = 'cls',
channels = 1,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
hidden_channels,
cnn_depth):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
self.to_patch_embedding = Image2Patch_Embedding(patch_size=patch_size, channels=channels, dim=dim)
self.pos_embedding = tf.Variable(tf.random.normal((1, num_patches + 1, dim)))
self.cls_token = tf.Variable(tf.random.normal((1, 1, dim)))
self.dropout = Dropout(emb_dropout)
self.transformer = ConvolutionalTransformer(dim=dim, depth=depth, heads=heads, dim_head=dim_head,
mlp_dim=mlp_dim, dropout=dropout, hidden_channels=hidden_channels,
patch_size=patch_size, image_size=image_size, cnn_depth=cnn_depth)
self.pool = pool
self.to_latent = tf.keras.layers.Identity()
def call(self, img):
x = self.to_patch_embedding(img)
b, n, _ = x.shape
cls_tokens = tf.repeat(self.cls_token, repeats=b, axis=0) # Ensure cls_tokens have the same batch size as x
x = tf.concat([cls_tokens, x], axis=1)
x += self.pos_embedding[:, :n + 1]
x = self.dropout(x)
x = self.transformer(x)
x = tf.reduce_mean(x, axis=1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return x
The text was updated successfully, but these errors were encountered:
Hi,
Thanks a lot of this wonderful work. I am wondering if you could help me in running this on Tensor flow? I did translate the main components but I am getting errors.
The text was updated successfully, but these errors were encountered: