SINGA-386 Implement RNN operation for autograd
- Fix some bugs and make some design modifications to RNN, LSTM, GRU, ..., which are computed by calling cuDNN functions.

  The implemented layers all work well and pass the shape checks for both the forward and
  backward steps.
xuewanqi committed Aug 16, 2018
1 parent 209d412 commit 4a14101
Showing 3 changed files with 51 additions and 28 deletions.
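Editorial note: for orientation, here is a minimal usage sketch of the reworked layer API, assuming the constructor and call signatures shown in the autograd.py diff below. The helpers device.create_cuda_gpu, tensor.from_numpy, Tensor.to_device and the autograd.training flag are taken from the surrounding SINGA Python API and are not part of this commit; shapes follow the comments in _RNN.forward.

import numpy as np
from singa import autograd, device, tensor

autograd.training = True        # _RNN.backward asserts that training is True
dev = device.create_cuda_gpu()  # only the CUDA path is implemented (device_id != -1)

seq_len, batch, input_size, hidden_size = 5, 4, 3, 6
x = tensor.from_numpy(np.random.randn(seq_len, batch, input_size).astype(np.float32))
h0 = tensor.from_numpy(np.zeros((1, batch, hidden_size), dtype=np.float32))
c0 = tensor.from_numpy(np.zeros((1, batch, hidden_size), dtype=np.float32))
for t in (x, h0, c0):
    t.to_device(dev)

lstm = autograd.LSTM(input_size, hidden_size)
outputs, hout, cout = lstm(x, h0, c0)                  # LSTM requires c0; RNN/GRU must omit it
assert outputs.shape == (seq_len, batch, hidden_size)  # shape check on the forward output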
70 changes: 43 additions & 27 deletions python/singa/autograd.py
@@ -23,6 +23,7 @@
import numpy as np
import math

from singa import tensor
from .tensor import Tensor
from . import layer
from singa.proto import model_pb2
@@ -969,12 +970,13 @@ class _RNN(Operation):
def __init__(self, handle):
self.handle = handle

def forward(self, X, h0, c0, W):
#def forward(self, X, h0, c0, W):
def forward(self, X, h0, W, c0=None):
# X of shape (seq_len, batch, input_size)
# h0_c0: (h0, c0) if lstm, else (h0,)
# h0, c0 of shape (num_layers * num_directions, batch, hidden_size)
if c0 is None:
assert self.rnn_mode != 'lstm'
assert self.handle.rnn_mode_ != 'lstm'
c0= CTensor([]) # CTensor([]) and Tensor cx are the same?

if self.handle.device_id == -1:
@@ -992,38 +994,49 @@ def forward(self, X, h0, c0, W):
# hout_cout: (hout, cout) if lstm, else (hout,)
# hout, cout of shape (num_layers * num_directions, batch,
# hidden_size)
oututs= _1dTo3d(Y)

#oututs= _1dTo3d(Y)
shape=(self.handle.seq_length_, self.handle.batch_size_, self.handle.hidden_size_)
outputs = singa.Reshape(Y, shape)

if self.rnn_mode != 'lstm':
if self.handle.rnn_mode_ != 'lstm':
return outputs, hout
else:
return outputs, hout, cout

def backward(self, dY, dh, dc=CTensor([])):
def backward(self, dY, dh=CTensor([]), dc=CTensor([])):
assert training is True and hasattr(
self, 'cache'), 'Please set training as True before do BP. '

dY_1d= _3dTo1d(dY)
#dY_1d= _3dTo1d(dY)

if dY_1d.device().id() != self.handle.device_id:
dY_1d.ToDevice(self.cache[0].device())
if dY.device().id() != self.handle.device_id:
dY.ToDevice(self.cache[0].device())

if self.handle.device_id == -1:
raise NotImplementedError
else:
dX_1d, dhout, dcout, dW = singa.GpuRNNBackward(
self.handle, dY_1d, dh, dc, self.cache)
self.handle, dY, dh, dc, self.cache)

dX = _1dTo3d(dX_1d)
#dX = _1dTo3d(dX_1d)
shape=(self.handle.seq_length_, self.handle.batch_size_, self.handle.input_size_)
dX = singa.Reshape(dX_1d, shape)

if self.rnn_mode != 'lstm':
if self.handle.rnn_mode_ != 'lstm':
return dX, dhout, dW
else:
return dX, dhout, dcout, dW
return dX, dhout, dW, dcout


def rnn(handle, x, h0, c0, W):
return _RNN(handle)(x, h0, c0, W)
#def rnn(handle, x, h0, c0, W):
# return _RNN(handle)(x, h0, c0, W)

def rnn(handle, x, h0, W, c0):
if c0 is None:
return _RNN(handle)(x, h0, W)
else:
return _RNN(handle)(x, h0, W, c0)


class RNN(Layer):
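Editorial note, not part of the commit: the reordering of the backward return values above (dW before dcout) and the c0-dependent dispatch in rnn() keep the gradients aligned with the forward arguments, which is how the autograd engine is expected to pair them:

    forward(X, h0, W)      ->  backward returns dX, dhout, dW
    forward(X, h0, W, c0)  ->  backward returns dX, dhout, dW, dcout   (lstm only)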
@@ -1054,14 +1067,15 @@ def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first
if self.bidirectional:
mult *= 2

W_Size = 0
for k in range(num_layers):
if k == 1:
if k == 0:
w_size = self.hidden_size * \
(self.input_size + self.hidden_size + 2)
else:
w_size = self.hidden_size * \
(self.hidden_size + self.hidden_size + 2)
W_Size *= mult * w_size
W_Size += mult * w_size

self.W_Size = W_Size
self.W = Tensor(shape=(W_Size,), requires_grad=True, stores_grad=True) # TODO: assign value of Wi separately
@@ -1077,33 +1091,35 @@ def __call__(self, inputs, h0, c0=None):
if self.rnn_mode == 'lstm':
assert c0 is not None, 'Please input c0.'
self.device_check(h0, c0)
else:
assert c0 is None, 'only lstm needs input c0'

if not hasattr(self, 'handle'):
self.handle = signa.CudnnRNNHandle(inputs.data, self.input_size, self.hidden_size, self.num_layers,
self.handle = singa.CudnnRNNHandle(inputs.data, self.input_size, self.hidden_size, self.num_layers,
self.rnn_mode, self.dropout, self.bidirectional, self.W_Size)
elif inputs.shape[0] != self.handle.seq_length_ or inputs.shape[1] != self.handle.batch_size_:
self.handle = signa.CudnnRNNHandle(inputs.data, self.input_size, self.hidden_size, self.num_layers,
self.handle = singa.CudnnRNNHandle(inputs.data, self.input_size, self.hidden_size, self.num_layers,
self.rnn_mode, self.dropout, self.bidirectional, self.W_Size)

self.handle.device_id = inputs.device.id()

X= _3dTo1d(inputs)
outputs = rnn(self.handle, X, h0, c0, self.W)
#X= _3dTo1d(inputs)
X=inputs
outputs = rnn(self.handle, X, h0, self.W, c0)
#outputs = rnn(self.handle, X, h0, self.W)
#outputs=tensor.to_numpy(outputs[0])
#print(outputs.shape)
#print(outputs)
return outputs

def _3dTo1d(self, inputs):
pass

def _1dTo3d(self, *args):
pass

class LSTM(RNN):

def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False):
super(LSTM, self).__init__(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectionalrnn_mode='lstm')
super(LSTM, self).__init__(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, rnn_mode='lstm')


class GRU(RNN):

def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False):
super(GRU, self).__init__(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectionalrnn_mode='gru')
super(GRU, self).__init__(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, rnn_mode='gru')
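
Editorial note: continuing the sketch after the commit header, the non-LSTM path must omit c0, per the assertion added in RNN.__call__; x, h0 and the sizes reuse the values defined in that sketch (same assumed helpers).

gru = autograd.GRU(input_size, hidden_size)
outputs, hout = gru(x, h0)               # passing c0 here would trip 'only lstm needs input c0'
assert outputs.shape == (seq_len, batch, hidden_size)
assert hout.shape == h0.shape            # (num_layers * num_directions, batch, hidden_size)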
6 changes: 6 additions & 0 deletions src/api/model_operation.i
@@ -58,6 +58,9 @@ public:

size_t batch_size_;
size_t seq_length_;
size_t input_size_;
size_t hidden_size_;
std::string rnn_mode_;
};

#if USE_CUDNN
@@ -122,6 +125,9 @@ public:

size_t batch_size_;
size_t seq_length_;
size_t input_size_;
size_t hidden_size_;
std::string rnn_mode_;

};
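
Editorial note: the members exported through SWIG above are exactly what autograd.py now reads off the handle; a short illustration, where `handle` stands for a singa.CudnnRNNHandle built in RNN.__call__:

out_shape = (handle.seq_length_, handle.batch_size_, handle.hidden_size_)  # reshape target for Y in forward
in_shape = (handle.seq_length_, handle.batch_size_, handle.input_size_)    # reshape target for dX in backward
is_lstm = (handle.rnn_mode_ == 'lstm')                                     # selects whether cout/dcout are returned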

3 changes: 2 additions & 1 deletion src/model/operation/rnn.cc
@@ -1,5 +1,5 @@
#include "./rnn.h"

#include<iostream>
namespace singa {

RNNHandle::RNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
@@ -203,6 +203,7 @@ void CudnnRNNHandle::SetRNNDescriptor(shared_ptr<Device> dev) {
CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnn_desc_, x_descs_[0],
&weight_size_, dtype_));
// check the size manually calculated
//std::cout<<weight_size_<<weight_size<<sizeof(float)<<std::endl;
CHECK_EQ(weight_size_, weight_size * sizeof(float));
int filter_dim[3] = {static_cast<int>(weight_size_), 1, 1};
CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc_));
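
Editorial note: the CHECK_EQ above compares cuDNN's reported parameter size against the count computed in RNN.__init__ (autograd.py diff). A hedged Python restatement of that count, where mult is the multiplier the constructor derives from the RNN mode and doubles for bidirectional networks:

def param_count(input_size, hidden_size, num_layers, mult):
    # Restates the W_Size loop from RNN.__init__ (illustration only): layer 0 sees
    # input_size, deeper layers see hidden_size, and '+ 2' covers the two bias
    # vectors that accompany the input-side and recurrent-side weight matrices.
    total = 0
    for k in range(num_layers):
        in_dim = input_size if k == 0 else hidden_size
        total += mult * hidden_size * (in_dim + hidden_size + 2)
    return total  # rnn.cc checks weight_size_ == param_count(...) * sizeof(float)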
