Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensorflow support #1

Open
caffeine-lab opened this issue Oct 12, 2023 · 1 comment
Open

Tensorflow support #1

caffeine-lab opened this issue Oct 12, 2023 · 1 comment

Comments

@caffeine-lab
Copy link

Hi,

Thanks a lot of this wonderful work. I am wondering if you could help me in running this on Tensor flow? I did translate the main components but I am getting errors.

import tensorflow as tf
import tensorflow.keras.layers as layers
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Embedding, Dropout
import tensorflow_addons as tfa
from tensorflow.keras.models import Model
from einops import rearrange
from einops.layers.tensorflow import Rearrange
# Helper function for pairing dimensions
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# PreNorm Layer
class PreNorm(layers.Layer):
  def __init__(self, dim, fn):
    super().__init__()
    self.fn = fn
    self.norm = layers.LayerNormalization(axis=-1)

  def call(self, x, **kwargs):
    return self.fn(self.norm(x), **kwargs)

# FeedForward Layer
class FeedForward(layers.Layer):
  def __init__(self, dim, hidden_dim, dropout=0.):
    super().__init__()
    self.dense1 = layers.Dense(hidden_dim, input_shape=(None, dim), use_bias=False)
    self.gelu = tfa.layers.GELU()
    self.dropout1 = layers.Dropout(dropout)
    self.dense2 = layers.Dense(dim, input_shape=(None, hidden_dim))
    self.dropout2 = layers.Dropout(dropout)

  def call(self, x, training=True): 
    x = self.dense1(x)
    x = self.gelu(x)
    x = self.dropout1(x, training=training)
    x = self.dense2(x)
    x = self.dropout2(x, training=training)
    return x

class Image2Patch_Embedding(layers.Layer):
    def __init__(self, patch_size, channels, dim):
        super().__init__()
        patch_height, patch_width = pair(patch_size)
        patch_dim = channels * patch_height * patch_width

        self.im2patch = layers.Lambda(lambda x: tf.reshape(x, [-1, patch_height, patch_width, channels]))
        self.patch2latentv = layers.Dense(units=dim, input_shape=(patch_dim,))

    def call(self, x):
        x = self.im2patch(x)
        x = self.patch2latentv(x)
        return x

class Latentv2Image(layers.Layer):
    def __init__(self, patch_size, channels, dim):
        patch_height, patch_width = pair(patch_size)
        self.latentv2patch = tf.keras.layers.Dense(units=channels * patch_height * patch_width, input_shape=(dim,))
        self.vec2square = Rearrange('(c h w) -> c h w', c = channels, h = patch_height, w = patch_width)

        
    def call(self, x):
        return x

# CNN Block
class ConvBlock(layers.Layer):
    def __init__(self, channels, hidden_channels, **kwargs):
        super(ConvBlock, self).__init__(**kwargs)

        self.conv1 = tf.keras.layers.Conv2D(
            filters=hidden_channels,
            kernel_size=(3, 3),
            strides=1,
            padding="same",
            activation="gelu",
        )

    def call(self, inputs):
        x = self.conv1(inputs)
        return x

# CNN Layer
class CNN(layers.Layer):
    def __init__(self, channels, hidden_channels, depth):
        super(CNN, self).__init__()    
        if depth == 1:
            self.m_list = [ConvBlock(channels, channels)]
        else:
            self.m_list = [ConvBlock(channels, hidden_channels)]
            for i in range(depth-2):
                self.m_list.append(ConvBlock(hidden_channels, hidden_channels)) 
            self.m_list.append(ConvBlock(hidden_channels, channels))
        self.blocks = tf.keras.Sequential(self.m_list)

    def call(self, x):
        return self.blocks(x)

class ReconstructionConv(layers.Layer):
    def __init__(self, dim, hidden_channels, patch_size, image_size, cnn_depth, channels=3):
        super(ReconstructionConv, self).__init__()

        self.dim = dim
        self.image_size = image_size
        self.patch_size = patch_size

        patch_height, patch_width = pair(patch_size)

        self.patch_dim = channels * patch_height * patch_width

        self.latent2patchv = tf.keras.layers.Dense(units=self.patch_dim, input_shape=(self.dim,))
        self.patchv2patch = Rearrange('b p (c h w) -> b p c h w', c=channels, h=patch_height, w=patch_width)
        self.net = CNN(
            channels=channels, hidden_channels=hidden_channels, depth=cnn_depth
        )
        self.embedding = Image2Patch_Embedding(
            patch_size=patch_height, channels=channels, dim=dim
        )
        self.fc = tf.keras.layers.Dense(units=self.dim, input_shape=(self.dim,))

    def reconstruct(self, patchs):
        image_height, image_width = pair(self.image_size)
        patch_height, patch_width = pair(self.patch_size)
        h = int(image_height / patch_height)
        w = int(image_width / patch_width)

        images = []
        for i in range(h):
            raw = []
            for j in range(w):
                raw.append(patchs[:, i*h + j, :, :])
            raw = torch.cat(raw, dim=3)
            images.append(raw)

        images = tf.concat(images, axis=2)

        return images

    def call(self, x):
        num_patchs = x.size()[1] - 1

        cls_tokens, x = tf.split(x, [1, num_patchs], axis=1)

        x = self.latent2patchv(x)
        x = self.patchv2patch(x)
        x = self.reconstruct(x)
        x = self.net(x)
        x = self.embedding(x)
        x = tf.reshape(x, (-1, num_patchs, self.dim))
        x = tf.concat((cls_tokens, x), axis=1)
        x = self.fc(x)

        return x

# Attention Layer
class Attention(layers.Layer): 
  def __init__(self, dim, heads=8, dim_head = 64,dropout=0.):
    super().__init__()
    inner_dim = dim_head *  heads
    project_out = not (heads == 1 and dim_head == dim)
    
    self.heads = heads 
    self.scales = dim_head ** (-0.5)

    self.to_qkv = layers.Dense(inner_dim*3, input_shape=(dim,), use_bias=False)

    self.out_dense = layers.Dense(dim, input_shape=(None, dim))
    self.out_dropout = layers.Dropout(dropout)

    self.to_out = tf.keras.Sequential(layers=[layers.Dense(inner_dim, input_shape=(None, dim)), 
                                              layers.Dropout(dropout)], name='to_out')
    
  def to_out(self, x, training=True):
    out = self.out_dense(x)
    out = self.out_dropout(out, training=training)
    return out

  def call(self, x, mask=None, training=True):
    b, n, _, h = *x.shape, self.heads
    qkv = tf.split(self.to_qkv(x), 3, axis=-1)
    q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)

    dots = tf.einsum("bhid, bhjd -> bhij", q, k) * self.scales

    if mask is not None:
      mask = tf.pad(mask.flatten(1), (1, 0), constant_values=True)
      assert mask.shape[-1] == dots.shape[-1], "Mask has incorrect dimensions"
      mask = mask[:, None, :] * mask[:, :, None]
      dots[~mask] = tf.fill(dots[~mask], float('-inf'))
      del mask

    attn = tf.nn.softmax(dots, axis=-1)
    out = tf.einsum("bhij,bhjd->bhid", attn, v)
    out = rearrange(out, "b h n d -> b n (h d)")
    out = self.to_out(out, training=True)
    return out



# ConvolutionalTransformer Layer
class ConvolutionalTransformer(layers.Layer):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, hidden_channels, patch_size, image_size, cnn_depth, dropout=0.):
        super(ConvolutionalTransformer, self).__init__()
        self.layers = []
        for _ in range(depth):
            self.layers.append([
                PreNorm(dim, Attention(dim=dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, ReconstructionConv(dim=dim, hidden_channels=hidden_channels, patch_size=patch_size, image_size=image_size, cnn_depth=cnn_depth)),
                PreNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim, dropout=dropout))
            ])

    def call(self, x):
        for attn, rc, ff in self.layers:
            x = attn(x) + x
            identity = x
            x = rc(x)
            x = ff(x) + identity
        return x

# VisionConformer Model
class VisionConformer(Model):
    def __init__(
            self,
            *, 
            image_size,
            patch_size,
            num_classes,
            dim,
            depth,
            heads,
            mlp_dim,
            pool = 'cls',
            channels = 1,
            dim_head = 64,
            dropout = 0.,
            emb_dropout = 0.,
            hidden_channels,
            cnn_depth):

        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        self.to_patch_embedding = Image2Patch_Embedding(patch_size=patch_size, channels=channels, dim=dim)
        self.pos_embedding = tf.Variable(tf.random.normal((1, num_patches + 1, dim)))
        self.cls_token = tf.Variable(tf.random.normal((1, 1, dim)))
        self.dropout = Dropout(emb_dropout)
        self.transformer = ConvolutionalTransformer(dim=dim, depth=depth, heads=heads, dim_head=dim_head,
                                       mlp_dim=mlp_dim, dropout=dropout, hidden_channels=hidden_channels,
                                       patch_size=patch_size, image_size=image_size, cnn_depth=cnn_depth)

        self.pool = pool
        self.to_latent = tf.keras.layers.Identity()

        
    def call(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = tf.repeat(self.cls_token, repeats=b, axis=0)  # Ensure cls_tokens have the same batch size as x
        x = tf.concat([cls_tokens, x], axis=1)
        x += self.pos_embedding[:, :n + 1]
        x = self.dropout(x)

        x = self.transformer(x)

        x = tf.reduce_mean(x, axis=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return x

@hvsesha
Copy link

hvsesha commented Oct 13, 2023

hi
I am able to run this python3.9 in my local without any change .With warning Tensorflow has stopped development

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants