SINGA-386 Implement RNN operation for autograd
- fix bugs in the cpp parts; the code now builds without errors.
xuewanqi committed Jul 17, 2018
1 parent 33ddc2d commit b176cb4
Showing 3 changed files with 238 additions and 84 deletions.
40 changes: 39 additions & 1 deletion python/singa/autograd.py
@@ -772,7 +772,7 @@ def __call__(self, x):
self.handle.device_id = x.device.id()

y = batchnorm_2d(self.handle, x, self.scale, self.bias,
                         self.running_mean, self.running_var)
return y


@@ -936,3 +936,41 @@ def __init__(self, kernel_size, stride=None, padding=0):
stride = kernel_size
super(MaxPool2d, self).__init__(
(1, kernel_size), (0, stride), (0, padding), False)


class _RNN(Operation):

    def __init__(self, handle):
        # the handle carries the RNN descriptors and the id of the
        # device the operation runs on
        self.handle = handle

    def forward(self, X, W):
        if self.handle.device_id == -1:
            # CPU path not implemented yet
            raise NotImplementedError
        else:
            # `training` is the module-level mode flag of autograd.py
            if training:
                # the training forward returns a cache that backward()
                # consumes, so keep it on the instance
                out, self.cache = singa.GpuRNNForwardTraining(
                    self.handle, X, W)
                # keep the inputs so backward() can recover their device
                self.inputs = (X, W)
            else:
                out = singa.GpuRNNForwardInference(self.handle, X, W)
        return out

    def backward(self, dY):
        assert training is True and hasattr(
            self, 'cache'), 'Set training to True before running backward.'

        if dY.device().id() != self.handle.device_id:
            dY.ToDevice(self.inputs[0].device())

        if self.handle.device_id == -1:
            raise NotImplementedError
        else:
            dX, dW = singa.GpuRNNBackward(self.handle, dY, self.cache)
            return dX, dW
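The op above follows the caching pattern used by the other GPU ops in this file: the training-mode forward saves whatever backward() will need, while the inference forward skips that bookkeeping. A self-contained toy sketch of the same pattern, with every name hypothetical and no SINGA APIs:

# Toy illustration of the cache-in-forward / consume-in-backward pattern
# used by _RNN above; all names here are hypothetical.
training = True  # module-level mode flag, as in singa.autograd

class ScaleOp:

    def __init__(self, factor):
        self.factor = factor

    def forward(self, x):
        if training:
            # remember what backward() will need
            self.cache = (x,)
        return [v * self.factor for v in x]

    def backward(self, dy):
        assert training and hasattr(self, 'cache'), \
            'run forward() in training mode first'
        # d(out)/d(x) = factor, so just scale the upstream gradient
        return [g * self.factor for g in dy]

op = ScaleOp(2.0)
y = op.forward([1.0, 2.0])    # -> [2.0, 4.0]
dx = op.backward([1.0, 1.0])  # -> [2.0, 2.0]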


def rnn():
pass
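The rnn() stub above is presumably the functional wrapper that this file pairs with each Operation class (compare batchnorm_2d in the first hunk). A hypothetical completed version, assuming the file's usual call-and-unpack convention:

# Hypothetical completion of the rnn() stub; it follows the pattern of
# the other functional wrappers in this file, where calling an Operation
# returns a tuple of output tensors.
def rnn(handle, x, W):
    return _RNN(handle)(x, W)[0]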


class RNN(Layer):
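The body of this class did not render on this page. Judging from the other layers in the file (see the BatchNorm2d __call__ in the first hunk, which stamps the input's device id onto the handle before invoking the op), a wrapper along these lines would be expected; this is a hypothetical sketch, not the committed code:

# Hypothetical sketch of the RNN Layer wrapper, NOT the committed code.
# It mirrors the other layers in this file: bind the handle to the
# input's device, then invoke the autograd op with the layer's weights.
class RNN(Layer):

    def __init__(self, handle, W):
        self.handle = handle  # e.g. a cuDNN RNN handle built by the caller
        self.W = W            # flattened weight tensor expected by the op

    def __call__(self, x):
        self.handle.device_id = x.device.id()
        return _RNN(self.handle)(x, self.W)[0]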
