SINGA-161 DLaaS
  Wrap SINGA into a Docker image that can run in a Mesos cluster.

  Can run in training and testing (serving) modes.
Aaron Wuwf committed Apr 29, 2016
1 parent d547a86 commit 1840cb7
Showing 14 changed files with 620 additions and 141 deletions.
5 changes: 3 additions & 2 deletions include/singa/neuralnet/neuron_layer.h
@@ -112,6 +112,7 @@ class DropoutLayer : public NeuronLayer {
*/
Blob<float> mask_;
};

/**
* This layer is a dummy that does no real work.
* It is used for testing purposes only.
@@ -126,7 +127,7 @@ class DummyLayer: public NeuronLayer {
void Setup(const LayerProto& proto, const vector<Layer*>& srclayers) override;
void ComputeFeature(int flag, const vector<Layer*>& srclayers) override;
void ComputeGradient(int flag, const vector<Layer*>& srclayers) override;
-  void Feed(int batchsize, vector<float>& data, vector<int>& aux_data);
+  void Feed(vector<int> shape, vector<float>* data, int op);
Layer* ToLayer() { return this;}

private:
@@ -278,7 +279,7 @@ class PoolingLayer : public NeuronLayer {
int kernel_x_, pad_x_, stride_x_;
int kernel_y_, pad_y_, stride_y_;
int batchsize_, channels_, height_, width_, pooled_height_, pooled_width_;
-  PoolingProto_PoolMethod pool_;
+  PoolMethod pool_;
};
/**
* Use book-keeping for BP following Caffe's pooling implementation
46 changes: 37 additions & 9 deletions src/neuralnet/neuron_layer/dummy.cc
@@ -78,25 +78,53 @@ void DummyLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
Copy(grad_, srclayers[0]->mutable_grad(this));
}

-void DummyLayer::Feed(int batchsize, vector<float>& data, vector<int>& aux_data){
-  batchsize_ = batchsize;
-  // input data
-  if (data.size() > 0) {
-    int size = data.size();
+void DummyLayer::Feed(vector<int> shape, vector<float>* data, int op){
+
+  //batchsize_ = batchsize;
+  batchsize_ = shape[0];
+  // dataset
+  if (op == 0) {
+    /*
+    size_t hdim = 1;
+    for (size_t i = 1; i < shape.size(); ++i)
+      hdim *= shape[i];
+    //data_.Reshape({batchsize, (int)hdim});
+    //shape.insert(shape.begin(),batchsize);
+    data_.Reshape(shape);
+    */
+    //reshape data
+    data_.Reshape(shape);
+    CHECK_EQ(data_.count(), data->size());
+
+    int size = data->size();
float* ptr = data_.mutable_cpu_data();
for (int i = 0; i< size; i++) {
-      ptr[i] = data.at(i);
+      ptr[i] = data->at(i);
}
}
-  // auxiliary data, e.g., label
-  if (aux_data.size() > 0) {
+  // label
+  else {
aux_data_.resize(batchsize_);
for (int i = 0; i< batchsize_; i++) {
-      aux_data_[i] = static_cast<int>(aux_data.at(i));
+      aux_data_[i] = static_cast<int>(data->at(i));
}
}

return;

+  /* Wenfeng's input
+  batchsize_ = batchsize;
+  shape.insert(shape.begin(),batchsize);
+  data_.Reshape(shape);
+  int size = data_.count() / batchsize_;
+  CHECK_EQ(size, data->size());
+  float* ptr = data_.mutable_cpu_data();
+  for (int i = 0; i< size; i++)
+    ptr[i] = data->at(i);
+  return;
+  */
}

} // namespace singa
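To make the new calling convention concrete, here is a minimal, SINGA-free Python sketch of the contract the rewritten Feed implements (an illustration, not the SINGA API: shape[0] is the batch size; op == 0 feeds feature data, any other op feeds labels, mirroring the branches above):

def feed(shape, data, op):
    # toy restatement of DummyLayer::Feed (illustration only)
    batchsize = shape[0]
    if op == 0:
        # feature data: the blob is reshaped to `shape`, so the flat data
        # length must equal the product of the shape dims (the CHECK_EQ above)
        count = 1
        for d in shape:
            count *= d
        assert count == len(data), "data length must match shape"
        return [float(v) for v in data]   # stands in for data_.mutable_cpu_data()
    # any other op: the first batchsize values become integer labels
    return [int(v) for v in data[:batchsize]]

# one 3-channel 2x2 image, then a single label
features = feed([1, 3, 2, 2], [0.1] * 12, 0)
labels = feed([1], [7.0], 1)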
22 changes: 11 additions & 11 deletions src/neuralnet/neuron_layer/pooling.cc
@@ -58,8 +58,8 @@ void PoolingLayer::Setup(const LayerProto& conf,
}

pool_ = conf.pooling_conf().pool();
-  CHECK(pool_ == PoolingProto_PoolMethod_AVG
-      || pool_ == PoolingProto_PoolMethod_MAX)
+  CHECK(pool_ == PoolMethod::AVG
+      || pool_ == PoolMethod::MAX)
<< "Padding implemented only for average and max pooling.";
const auto& srcshape = srclayers[0]->data(this).shape();
int dim = srcshape.size();
@@ -83,9 +83,9 @@ void PoolingLayer::Setup(const LayerProto& conf,
void PoolingLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
auto src = Tensor4(srclayers[0]->mutable_data(this));
auto data = Tensor4(&data_);
-  if (pool_ == PoolingProto_PoolMethod_MAX)
+  if (pool_ == PoolMethod::MAX)
data = expr::pool<red::maximum>(src, kernel_x_, stride_x_);
-  else if (pool_ == PoolingProto_PoolMethod_AVG)
+  else if (pool_ == PoolMethod::AVG)
data = expr::pool<red::sum>(src, kernel_x_, stride_x_)
* (1.0f / (kernel_x_ * kernel_x_));
}
@@ -99,9 +99,9 @@ void PoolingLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
auto gsrc = Tensor4(srclayers[0]->mutable_grad(this));
auto data = Tensor4(&data_);
auto grad = Tensor4(&grad_);
-  if (pool_ == PoolingProto_PoolMethod_MAX)
+  if (pool_ == PoolMethod::MAX)
gsrc = expr::unpool<red::maximum>(src, data, grad, kernel_x_, stride_x_);
-  else if (pool_ == PoolingProto_PoolMethod_AVG)
+  else if (pool_ == PoolMethod::AVG)
gsrc = expr::unpool<red::sum>(src, data, grad, kernel_x_, stride_x_)
* (1.0f / (kernel_x_ * kernel_x_));
}
@@ -111,16 +111,16 @@
void CPoolingLayer::Setup(const LayerProto& conf,
const vector<Layer*>& srclayers) {
PoolingLayer::Setup(conf, srclayers);
-  if (pool_ == PoolingProto_PoolMethod_MAX)
+  if (pool_ == PoolMethod::MAX)
mask_.ReshapeLike(data_);
}
void CPoolingLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
-  if (pool_ == PoolingProto_PoolMethod_MAX)
+  if (pool_ == PoolMethod::MAX)
ForwardMaxPooling(srclayers[0]->mutable_data(this)->mutable_cpu_data(),
batchsize_, channels_, height_, width_, kernel_y_, kernel_x_,
        pad_y_, pad_x_, stride_y_, stride_x_,
data_.mutable_cpu_data(), mask_.mutable_cpu_data());
-  else if (pool_ == PoolingProto_PoolMethod_AVG)
+  else if (pool_ == PoolMethod::AVG)
ForwardAvgPooling(srclayers[0]->mutable_data(this)->mutable_cpu_data(),
batchsize_, channels_, height_, width_, kernel_y_, kernel_x_,
        pad_y_, pad_x_, stride_y_, stride_x_, data_.mutable_cpu_data());
@@ -129,12 +129,12 @@ void CPoolingLayer::ComputeFeature(int flag, const vector<Layer*>& srclayers) {
}

void CPoolingLayer::ComputeGradient(int flag, const vector<Layer*>& srclayers) {
-  if (pool_ == PoolingProto_PoolMethod_MAX)
+  if (pool_ == PoolMethod::MAX)
BackwardMaxPooling(grad_.cpu_data(), mask_.cpu_data(), batchsize_,
channels_, height_, width_, kernel_y_, kernel_x_, pad_y_, pad_x_,
        stride_y_, stride_x_,
srclayers[0]->mutable_grad(this)->mutable_cpu_data());
-  else if (pool_ == PoolingProto_PoolMethod_AVG)
+  else if (pool_ == PoolMethod::AVG)
BackwardAvgPooling(grad_.cpu_data(), batchsize_,
channels_, height_, width_, kernel_y_, kernel_x_, pad_y_, pad_x_,
stride_y_, stride_x_,
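Note that both branches assume a square kernel: the AVG path scales a sum-pool by 1.0f / (kernel_x_ * kernel_x_). A small numpy sketch of the two pooling methods as computed above (single channel, no padding; illustration only):

import numpy as np

def pool2d(x, k, stride, method):
    # square-kernel pooling matching the MAX/AVG branches above
    h = (x.shape[0] - k) // stride + 1
    w = (x.shape[1] - k) // stride + 1
    out = np.empty((h, w), dtype=x.dtype)
    for i in range(h):
        for j in range(w):
            win = x[i*stride:i*stride+k, j*stride:j*stride+k]
            out[i, j] = win.max() if method == 'MAX' else win.sum() / (k * k)
    return out

x = np.arange(16, dtype=np.float32).reshape(4, 4)
print pool2d(x, 2, 2, 'MAX')   # [[ 5.  7.] [13. 15.]]
print pool2d(x, 2, 2, 'AVG')   # [[ 2.5  4.5] [10.5 12.5]]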
12 changes: 6 additions & 6 deletions src/proto/job.proto
@@ -522,15 +522,15 @@ message LRNProto {
// offset
optional float knorm = 34 [default = 1.0];
}

+enum PoolMethod {
+  MAX = 0;
+  AVG = 1;
+}
+
message PoolingProto {
// The kernel size (square)
optional int32 kernel= 1 [default = 3];
-  enum PoolMethod {
-    MAX = 0;
-    AVG = 1;
-  }
-  // The pooling method
+  // The pooling method
optional PoolMethod pool = 30 [default = MAX];
// The padding size
optional uint32 pad = 31 [default = 0];
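Hoisting the enum to file scope is what renames the generated C++ constants from PoolingProto_PoolMethod_MAX to PoolMethod::MAX in pooling.cc above. In the generated Python module a top-level enum value becomes a module attribute; a hedged sketch (the job_pb2 import path is an assumption):

from singa.proto import job_pb2   # assumed location of the generated module

conf = job_pb2.PoolingProto()
conf.pool = job_pb2.MAX           # top-level enum value (was PoolingProto.MAX when nested)
conf.kernel = 3
print conf.pool == job_pb2.MAX    # True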
16 changes: 8 additions & 8 deletions thirdparty/install.sh
@@ -256,19 +256,19 @@ function install_protobuf()
echo "install protobuf in $1";
./configure --prefix=$1;
make && make install;
-  #cd python;
-  #python setup.py build;
-  #python setup.py install --prefix=$1;
-  #cd ..;
+  cd python;
+  python setup.py build;
+  python setup.py install --prefix=$1;
+  cd ..;
elif [ $# == 0 ]
then
echo "install protobuf in default path";
./configure;
make && sudo make install;
-  #cd python;
-  #python setup.py build;
-  #sudo python setup.py install;
-  #cd ..;
+  cd python;
+  python setup.py build;
+  sudo python setup.py install;
+  cd ..;
else
echo "wrong commands";
fi
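Re-enabling these lines builds and installs protobuf's Python package next to the C++ runtime; tool/dlaas/main.py depends on it and appends thirdparty/protobuf-2.6.0/python to sys.path. A quick sanity check, assuming the default source layout:

import os, sys
sys.path.append(os.path.join('thirdparty', 'protobuf-2.6.0', 'python'))

import google.protobuf
print google.protobuf.__version__   # expect 2.6.x if the build succeeded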
178 changes: 178 additions & 0 deletions tool/dlaas/main.py
@@ -0,0 +1,178 @@
#!/usr/bin/env python

#/************************************************************
#*
#* Licensed to the Apache Software Foundation (ASF) under one
#* or more contributor license agreements. See the NOTICE file
#* distributed with this work for additional information
#* regarding copyright ownership. The ASF licenses this file
#* to you under the Apache License, Version 2.0 (the
#* "License"); you may not use this file except in compliance
#* with the License. You may obtain a copy of the License at
#*
#* http://www.apache.org/licenses/LICENSE-2.0
#*
#* Unless required by applicable law or agreed to in writing,
#* software distributed under the License is distributed on an
#* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#* KIND, either express or implied. See the License for the
#* specific language governing permissions and limitations
#* under the License.
#*
#*************************************************************/

import os, sys
import numpy as np

current_path_ = os.path.dirname(__file__)
singa_root_=os.path.abspath(os.path.join(current_path_,'../..'))
sys.path.append(os.path.join(singa_root_,'thirdparty','protobuf-2.6.0','python'))
sys.path.append(os.path.join(singa_root_,'tool','python'))

from model import neuralnet, updater
from singa.driver import Driver
from singa.layer import *
from singa.model import save_model_parameter, load_model_parameter
from singa.utils.utility import swap32

from PIL import Image
import glob,random, shutil, time
from flask import Flask, request, redirect, url_for
from singa.utils import kvstore, imgtool
app = Flask(__name__)

def train(batchsize, disp_freq, check_freq, train_step, workspace, checkpoint=None):
    print '[Layer registration/declaration]'
    # TODO change layer registration methods
    d = Driver()
    d.Init(sys.argv)

    print '[Start training]'

    # restore parameters if a checkpoint is given
    if checkpoint:
        load_model_parameter(workspace+checkpoint, neuralnet, batchsize)

    for i in range(0, train_step):

        for h in range(len(neuralnet)):
            # fetch data for input layers, forward through the others
            if neuralnet[h].layer.type == kDummy:
                neuralnet[h].FetchData(batchsize)
            else:
                neuralnet[h].ComputeFeature()

        # backward pass and parameter update, driven from the last layer
        neuralnet[h].ComputeGradient(i+1, updater)

        if (i+1) % disp_freq == 0:
            print ' Step {:>3}: '.format(i+1),
            neuralnet[h].display()

        if (i+1) % check_freq == 0:
            save_model_parameter(i+1, workspace, neuralnet)

    print '[Finish training]'


def product(workspace, checkpoint):

    print '[Layer registration/declaration]'
    # TODO change layer registration methods
    d = Driver()
    d.Init(sys.argv)

    load_model_parameter(workspace+checkpoint, neuralnet, 1)

    app.debug = True
    app.run(host='0.0.0.0', port=80)


@app.route("/")
def index():
return "Hello World! This is SINGA DLAAS! Please send post request with image=file to '/predict' "

def allowed_file(filename):
allowd_extensions_ = set(['txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif'])
return '.' in filename and \
filename.rsplit('.', 1)[1] in allowd_extensions_

@app.route('/predict', methods=['POST'])
def predict():
    size_ = (32, 32)
    pixel_length_ = 3 * size_[0] * size_[1]
    label_num_ = 10
    if request.method == 'POST':
        file = request.files['image']
        if file and allowed_file(file.filename):
            im = Image.open(file).convert("RGB")
            im = imgtool.resize_to_center(im, size_)
            pixel = floatVector(pixel_length_)
            byteArray = imgtool.toBin(im, size_)
            data = np.frombuffer(byteArray, dtype=np.uint8)
            data = data.reshape(1, pixel_length_)
            # dummy data layer
            shape = intVector(4)
            shape[0] = 1
            shape[1] = 3
            shape[2] = size_[0]
            shape[3] = size_[1]

            for h in range(len(neuralnet)):
                # fetch data for input layers, forward through the others
                if neuralnet[h].is_datalayer:
                    if not neuralnet[h].is_label:
                        neuralnet[h].Feed(data, 3)
                    else:
                        neuralnet[h].FetchData(1)
                else:
                    neuralnet[h].ComputeFeature()

            # get result
            #data = neuralnet[-1].get_singalayer().data(neuralnet[-1].get_singalayer())
            #prop = floatArray_frompointer(data.mutable_cpu_data())
            prop = neuralnet[-1].GetData()
            print prop
            result = []
            for i in range(label_num_):
                result.append((i, prop[i]))

            result.sort(key=lambda tup: tup[1], reverse=True)
            print result
            response = ""
            for r in result:
                response += str(r[0]) + ":" + str(r[1])

            return response
    return "error"


if __name__ == '__main__':

    if sys.argv[1] == "train":
        if len(sys.argv) < 6:
            print "usage: main.py train batchsize disp_freq check_freq train_step [checkpoint]"
            exit()
        if len(sys.argv) > 6:
            checkpoint = sys.argv[6]
        else:
            checkpoint = None
        # training
        train(
            batchsize=int(sys.argv[2]),
            disp_freq=int(sys.argv[3]),
            check_freq=int(sys.argv[4]),
            train_step=int(sys.argv[5]),
            workspace='/workspace',
            checkpoint=checkpoint,
        )
    else:
        if len(sys.argv) < 3:
            print "usage: main.py <mode> checkpoint"
            exit()
        checkpoint = sys.argv[2]
        product(
            workspace='/workspace',
            checkpoint=checkpoint,
        )
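Read together with the argument parsing above, the two entry points reduce to the following calls (a sketch with made-up argument values; note that load_model_parameter concatenates workspace+checkpoint, so the checkpoint path is given relative to the workspace root):

# `python main.py train 64 20 100 1000 /ckpt.bin` is equivalent to:
train(batchsize=64, disp_freq=20, check_freq=100, train_step=1000,
      workspace='/workspace', checkpoint='/ckpt.bin')

# any other first argument serves a model; `python main.py test /ckpt.bin` runs:
product(workspace='/workspace', checkpoint='/ckpt.bin')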

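Once product() is serving, a client can exercise the /predict route as below (a sketch; the requests library, host, and file name are assumptions):

import requests

# POST an image under the form field name 'image', as predict() expects
with open('cat.jpg', 'rb') as f:
    r = requests.post('http://localhost:80/predict', files={'image': f})

# the response body is the concatenated "label:probability" pairs built in predict()
print r.text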
