pytorch-generative is a Python library which makes generative modeling in PyTorch easier by providing:
- high-quality reference implementations of SOTA generative models
- useful abstractions of common building blocks found in the literature
- utilities for training, debugging, and working with Google Colab
To get started, see the examples and links below.
Supported models are implemented as PyTorch Modules and are easy to use:
```python
from pytorch_generative import models

model = models.ImageGPT(in_channels=1, out_channels=1, in_size=28)
...
model(data)
```
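For context, a typical training step around such a model might look like the sketch below. The optimizer, loss, and data loading are placeholders rather than part of the library's API, and the sketch assumes the model returns per-pixel logits for binarized 28x28 images:

```python
import torch.nn.functional as F
from torch import optim

from pytorch_generative import models

# Placeholder setup: any DataLoader that yields binarized 1x28x28 images works here.
model = models.ImageGPT(in_channels=1, out_channels=1, in_size=28)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


def train_step(batch):
    # Assumes the model outputs per-pixel logits; binary cross-entropy is a
    # common reconstruction loss for binarized MNIST.
    logits = model(batch)
    loss = F.binary_cross_entropy_with_logits(logits, batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```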
Alternatively, the lower-level building blocks in pytorch_generative.nn can be used to write models from scratch. For example, we implement a convolutional ImageGPT-like model below:
```python
import torch
from torch import nn

from pytorch_generative import nn as pg_nn


class TransformerBlock(nn.Module):
    """An ImageGPT Transformer block."""

    def __init__(self, n_channels, n_attention_heads):
        """Initializes a new TransformerBlock instance.

        Args:
            n_channels: The number of input and output channels.
            n_attention_heads: The number of attention heads to use.
        """
        super().__init__()
        self._ln1 = pg_nn.NCHWLayerNorm(n_channels)
        self._ln2 = pg_nn.NCHWLayerNorm(n_channels)
        self._attn = pg_nn.MaskedAttention(
            in_channels=n_channels,
            embed_channels=n_channels,
            out_channels=n_channels,
            n_heads=n_attention_heads,
            is_causal=False)
        self._out = nn.Sequential(
            nn.Conv2d(
                in_channels=n_channels,
                out_channels=4 * n_channels,
                kernel_size=1),
            nn.GELU(),
            nn.Conv2d(
                in_channels=4 * n_channels,
                out_channels=n_channels,
                kernel_size=1))

    def forward(self, x):
        x = x + self._attn(self._ln1(x))
        return x + self._out(self._ln2(x))


class ImageGPT(nn.Module):
    """The ImageGPT model."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 in_size,
                 n_transformer_blocks=8,
                 n_attention_heads=4,
                 n_embedding_channels=16):
        """Initializes a new ImageGPT instance.

        Args:
            in_channels: The number of input channels.
            out_channels: The number of output channels.
            in_size: Size of the input images. Used to create positional encodings.
            n_transformer_blocks: Number of TransformerBlocks to use.
            n_attention_heads: Number of attention heads to use.
            n_embedding_channels: Number of attention embedding channels to use.
        """
        super().__init__()
        self._pos = nn.Parameter(torch.zeros(1, in_channels, in_size, in_size))
        self._input = pg_nn.MaskedConv2d(
            is_causal=True,
            in_channels=in_channels,
            out_channels=n_embedding_channels,
            kernel_size=3,
            padding=1)
        self._transformer = nn.Sequential(
            *[TransformerBlock(n_channels=n_embedding_channels,
                               n_attention_heads=n_attention_heads)
              for _ in range(n_transformer_blocks)])
        self._ln = pg_nn.NCHWLayerNorm(n_embedding_channels)
        self._out = nn.Conv2d(
            in_channels=n_embedding_channels,
            out_channels=out_channels,
            kernel_size=1)

    def forward(self, x):
        x = self._input(x + self._pos)
        x = self._transformer(x)
        x = self._ln(x)
        return self._out(x)
```
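As a quick sanity check (not part of the original example), the model defined above can be exercised on a random batch of binarized 28x28 inputs:

```python
# Smoke test for the sketch above; shapes assume NCHW binary images of
# size 1x28x28, matching the binarized MNIST setting.
model = ImageGPT(in_channels=1, out_channels=1, in_size=28)
batch = torch.bernoulli(torch.full((8, 1, 28, 28), 0.5))
logits = model(batch)
assert logits.shape == (8, 1, 28, 28)
```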
pytorch-generative supports the following algorithms.
Note: Our reported binary MNIST results may be optimistic. Instead of using a fixed dataset, we resample a new binary MNIST dataset on every epoch. This can be thought of as a form of data augmentation that helps our models learn better.
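For illustration only, dynamic binarization of this kind can be implemented with a transform that resamples Bernoulli pixels every time an image is loaded. The snippet below is a sketch using torchvision, not the library's actual data pipeline:

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Each load draws a fresh Bernoulli sample per pixel, so every epoch
# effectively sees a newly binarized version of MNIST.
dynamic_binarize = transforms.Compose([
    transforms.ToTensor(),               # scales pixels to [0, 1]
    transforms.Lambda(torch.bernoulli),  # resample binary pixels on every load
])

train_data = datasets.MNIST("/tmp/mnist", train=True, download=True,
                            transform=dynamic_binarize)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
```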
Binarized MNIST (nats):

| Algorithm | Our Results | Links |
|---|---|---|
| PixelSNAIL | 78.61 | Code, Paper |
| ImageGPT | 79.17 | Code, Paper |
| Gated PixelCNN | 81.50 | Code, Paper |
| PixelCNN | 81.45 | Code, Paper |
| MADE | 84.87 | Code, Paper |
| NADE | 85.65 | Code, Paper |
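For context (this conversion is not part of the table), and assuming the values above are per-image negative log-likelihoods in nats, which is the standard convention for binarized MNIST, they can be converted to the bits-per-dimension figure often quoted in the literature:

```python
import math

def nats_per_image_to_bits_per_dim(nats, n_pixels=28 * 28):
    # Divide by ln(2) to convert nats to bits, then by the number of
    # pixels (784 for MNIST) to get bits per dimension.
    return nats / (math.log(2) * n_pixels)

print(round(nats_per_image_to_bits_per_dim(78.61), 3))  # ~0.145
```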
| Algorithm | Our Results | Links |
|---|---|---|
| VAE | TODO | Code, Paper |
| VQ-VAE | TODO | Code, Paper |
| VQ-VAE-2 | TODO | Code, Paper |
Neural Style Transfer:
- Blog: https://towardsdatascience.com/how-to-get-beautiful-results-with-neural-style-transfer-75d0c05d6489
- Notebook: https://github.com/EugenHotaj/pytorch-generative/blob/master/notebooks/style_transfer.ipynb
- Paper: https://arxiv.org/pdf/1508.06576.pdf

Compositional Pattern Producing Networks:
- Notebook: https://github.com/EugenHotaj/pytorch-generative/blob/master/notebooks/cppn.ipynb
- Background: https://en.wikipedia.org/wiki/Compositional_pattern-producing_network