memory management: swap+pool #412
Open

junzhezhang wants to merge 19 commits into apache:master from junzhezhang:vd2
Commits (19, all by junzhezhang):
e7f695d  change / into * reverse number
2cc6ba4  align train.py with vc12
611ad3c  revert back train.py
9974d5d  revert back train.py
7309dc0  update train iteration number
25fe79d  change cnmem src, common.h, common.cc and the cmakelist
fecf34f  disable common.cc appendInfo, for device src done first.
c07dcc6  update device and memory family src
8dae4d8  unable common.cc appendInfo
572fe4d  correct swap_select()
d2027a9  correct swap_select()
9d84cfc  enable swap_plan()
ed73c3a  documentation
650a4e6  impl swap_construct_tables(), swap_update_tables(), DeploySwap()
479ad2a  enable include negative r_idx into Table_sched
383fffe  cross iteration swap, last iteration, milestone
2fb3f02  vd2: swap+pool
0f3722d  add documentation
8e8a7e1  Replace the strings with struct for function Append.
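The commit messages outline the swap machinery layered on top of the cnmem pool: swap_select() picks which blocks to evict, swap_plan() turns that selection into a schedule, swap_construct_tables()/swap_update_tables() maintain the schedule (Table_sched, including negative r_idx entries for cross-iteration swaps), and DeploySwap() executes it. For orientation only, below is a minimal C++ sketch of what such an interface could look like. The function and field names are borrowed from the commit messages; everything else (signatures, the size-threshold heuristic, the prefetch distance) is an illustrative assumption, not this PR's actual code.

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

// One scheduled transfer, loosely modeled on the Table_sched named in the
// commits. A negative r_idx would mean the transfer is issued during the
// previous training iteration (the "cross iteration swap" of 383fffe).
struct SwapEntry {
  int r_idx;          // operation index within one iteration that triggers it
  std::string block;  // identifier of the block to move
  bool swap_out;      // true: device -> host, false: host -> device
};

class SwapScheduler {
 public:
  // swap_select: pick candidate blocks; here simply the ones whose size
  // exceeds a threshold (a placeholder heuristic).
  std::vector<std::string> SwapSelect(
      const std::map<std::string, size_t>& block_sizes, size_t threshold) {
    std::vector<std::string> selected;
    for (const auto& kv : block_sizes)
      if (kv.second >= threshold) selected.push_back(kv.first);
    return selected;
  }

  // swap_plan: swap each selected block out right after it is produced and
  // schedule the swap-in a fixed distance before its next use.
  std::vector<SwapEntry> SwapPlan(
      const std::vector<std::string>& blocks,
      const std::map<std::string, std::pair<int, int>>& lifetimes,
      int prefetch_distance) {
    std::vector<SwapEntry> plan;
    for (const auto& b : blocks) {
      const auto& life = lifetimes.at(b);  // {produced_at, next_used_at}
      plan.push_back({life.first, b, true});
      plan.push_back({life.second - prefetch_distance, b, false});
    }
    return plan;
  }
};

int main() {
  SwapScheduler sched;
  std::map<std::string, size_t> sizes = {{"conv1_out", 64 << 20},
                                         {"fc_bias", 4 << 10}};
  std::map<std::string, std::pair<int, int>> lifetimes = {
      {"conv1_out", {3, 40}}, {"fc_bias", {5, 6}}};
  auto plan = sched.SwapPlan(sched.SwapSelect(sizes, 1 << 20), lifetimes, 5);
  for (const auto& e : plan)
    std::cout << (e.swap_out ? "out " : "in  ") << e.block << " at r_idx "
              << e.r_idx << "\n";
  return 0;
}
```

A real implementation would also need the pool side (the cnmem changes of 25fe79d) and the DeploySwap() step that issues the actual device-host copies; neither is shown here.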
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
```diff
@@ -31,24 +31,25 @@
 import os
 import argparse
 
 # sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python'))
 from singa import utils
 from singa import optimizer
 from singa import device
 from singa import tensor
+from singa.proto import core_pb2
 from caffe import caffe_net
 
 import cnn
 import vgg
 import resnet
 
+from datetime import datetime
+import time
 
 def load_dataset(filepath):
     print('Loading data file %s' % filepath)
     with open(filepath, 'rb') as fd:
         try:
             cifar10 = pickle.load(fd, encoding='latin1')
         except TypeError:
             cifar10 = pickle.load(fd)
-        cifar10 = pickle.load(fd)
     image = cifar10['data'].astype(dtype=np.uint8)
     image = image.reshape((-1, 3, 32, 32))
     label = np.asarray(cifar10['labels'], dtype=np.uint8)
@@ -129,24 +130,35 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
         dev = device.get_default_device()
     else:
         print('Using GPU')
-        dev = device.create_cuda_gpu()
+        dev = device.create_cuda_gpu_on(0)
 
     net.to_device(dev)
     opt = optimizer.SGD(momentum=0.9, weight_decay=weight_decay)
     for (p, specs) in zip(net.param_names(), net.param_specs()):
         opt.register(p, specs)
 
     tx = tensor.Tensor((batch_size, 3, 32, 32), dev)
-    ty = tensor.Tensor((batch_size,), dev, tensor.int32)
+    ty = tensor.Tensor((batch_size,), dev, core_pb2.kInt)
     train_x, train_y, test_x, test_y = data
     num_train_batch = train_x.shape[0] // batch_size
     num_test_batch = test_x.shape[0] // batch_size
     idx = np.arange(train_x.shape[0], dtype=np.int32)
-    for epoch in range(max_epoch):
+    fileTimeLog = open("epochTimeLog.text", "a")
+    for epoch in range(1):
         np.random.shuffle(idx)
         loss, acc = 0.0, 0.0
         print('Epoch %d' % epoch)
-        for b in range(num_train_batch):
+        print(datetime.now().timetz())  # miliseconds
+        print(int(round(time.time() * 1000)))
+        fileTimeLog.write('Epoch %d: ' % epoch)
+        fileTimeLog.write(str(int(round(time.time() * 1000))))
+        fileTimeLog.write('\n')
+        for b in range(20):  # num_train_batch):
+            print("start of iteration %d: " % b)
+            # time.sleep(1)
+            fileTimeLog.write('iteration %d: ' % b)
+            fileTimeLog.write(str(int(round(time.time() * 1000))))
+            fileTimeLog.write('\n')
             x = train_x[idx[b * batch_size: (b + 1) * batch_size]]
             y = train_y[idx[b * batch_size: (b + 1) * batch_size]]
             tx.copy_from_numpy(x)
@@ -164,7 +176,7 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
             print(info)
 
         loss, acc = 0.0, 0.0
-        for b in range(num_test_batch):
+        for b in range(0):
             x = test_x[b * batch_size: (b + 1) * batch_size]
             y = test_y[b * batch_size: (b + 1) * batch_size]
             tx.copy_from_numpy(x)
@@ -175,14 +187,16 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
 
     print('test loss = %f, test accuracy = %f' %
           ((loss / num_test_batch), (acc / num_test_batch)))
+    fileTimeLog.close()
     net.save('model', 20)  # save model params into checkpoint file
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Train dcnn for cifar10')
-    parser.add_argument('model', choices=['vgg', 'cnn', 'resnet', 'caffe'],
-                        default='vgg')
+    parser.add_argument('model', choices=['vgg', 'alexnet', 'resnet', 'caffe'],
+                        default='alexnet')
     parser.add_argument('data', default='cifar-10-batches-py')
     parser.add_argument('--use_cpu', action='store_true')
+    parser.add_argument('batch_size', type=int, default=100)
     args = parser.parse_args()
     assert os.path.exists(args.data), \
         'Pls download the cifar10 dataset via "download_data.py py"'
@@ -194,22 +208,22 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
         net = caffe_net.create_net(args.use_cpu)
         # for cifar10_full_train_test.prototxt
         train((train_x, train_y, test_x, test_y), net, 160, alexnet_lr, 0.004,
-              use_cpu=args.use_cpu)
+              use_cpu=args.use_cpu,batch_size=args.batch_size)
         # for cifar10_quick_train_test.prototxt
         # train((train_x, train_y, test_x, test_y), net, 18, caffe_lr, 0.004,
         #       use_cpu=args.use_cpu)
-    elif args.model == 'cnn':
+    elif args.model == 'alexnet':
         train_x, test_x = normalize_for_alexnet(train_x, test_x)
-        net = cnn.create_net(args.use_cpu)
+        net = alexnet.create_net(args.use_cpu)
         train((train_x, train_y, test_x, test_y), net, 2, alexnet_lr, 0.004,
-              use_cpu=args.use_cpu)
+              use_cpu=args.use_cpu,batch_size=args.batch_size)
     elif args.model == 'vgg':
         train_x, test_x = normalize_for_vgg(train_x, test_x)
         net = vgg.create_net(args.use_cpu)
         train((train_x, train_y, test_x, test_y), net, 250, vgg_lr, 0.0005,
-              use_cpu=args.use_cpu)
+              use_cpu=args.use_cpu,batch_size=args.batch_size)
     else:
```
Review comment here: again, it would be better to keep the original example code.
```diff
         train_x, test_x = normalize_for_alexnet(train_x, test_x)
         net = resnet.create_net(args.use_cpu)
         train((train_x, train_y, test_x, test_y), net, 200, resnet_lr, 1e-4,
-              use_cpu=args.use_cpu)
+              use_cpu=args.use_cpu,batch_size=args.batch_size)
```
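One detail of the argparse change above: `batch_size` is added as a positional argument, and argparse only applies `default=` to a positional when `nargs='?'` is also given, so as written the batch size must always be passed on the command line, e.g. `python train.py alexnet cifar-10-batches-py 100`; the `default='alexnet'` on the positional `model` argument is inert for the same reason.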
File: singa/core/common.h (the file name follows from the include guard SINGA_CORE_COMMON_H_ and the commit messages)
```diff
@@ -24,7 +24,7 @@
 #include <atomic>
 #include <memory>
 #include "singa/utils/logging.h"
-
+#include <string>
```
Review comment on the added `#include <string>`: not used?
```diff
 #ifdef USE_CUDA
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
@@ -52,24 +52,28 @@ typedef struct _Cuda { } Cuda;
 typedef struct _Opencl { } Opencl;
 }  // namespace lang
 
 class Device;
+struct DeviceOptInfoToAppend;
 
 /// Block represent a chunk of memory (on device or host).
 class Block {
  public:
-  Block(void* ptr, size_t size, size_t offset = 0)
-      : data_(ptr), size_(size), offset_(offset) {
+  Block(void* ptr, size_t size, size_t offset = 0, Device* ptr_device = nullptr)
+      : data_(ptr), size_(size), offset_(offset), ptr_device_(ptr_device) {
     ref_count_ = 1;  // std::make_shared<std::atomic<int>>(1);
   }
   // Disabled as it is not used currently.
   // Block(void* ptr, size_t size, size_t offset, std::shared_ptr<atomic<int>>
   // ref) : data_(ptr), size_(size), offset_(offset), ref_count_(ref) {}
-  void* mutable_data() {
-    initialized_ = true;
-    return static_cast<char*>(data_) + offset_;
-  }
-  const void* data() const {
-    CHECK(initialized_) << "Must initialize data before reading it";
-    return static_cast<char*>(data_) + offset_;
-  }
+  void* mutable_data() ;
+
+  const void* data() const;
+
+  void* get_data() ;
+
+  void update_data(void* data_new) ;
   size_t size() const { return size_; }
   size_t offset() const { return offset_; }
   int IncRefCount() {
@@ -89,12 +93,23 @@ class Block {
   void* data_ = nullptr;
   size_t size_ = 0;
   size_t offset_ = 0;
+  Device* ptr_device_;
   bool initialized_ = false;
   // Disabled as it is not used currently.
   // std::shared_ptr<std::atomic<int>> ref_count_ = nullptr;
   std::atomic<int> ref_count_;
 };
 
+// struct for Append purpose in device class.
+struct DeviceOptInfoToAppend{
+  string operation_type;
+  string block_ptr;
+  int size;
+  long t = (std::chrono::system_clock::now()).time_since_epoch().count();
+
+  DeviceOptInfoToAppend(string opt_type, string ptr,int s):operation_type(opt_type),block_ptr(ptr),size(s){}
+};
+
 typedef struct _Context {
   std::mt19937 random_generator;
 #ifdef USE_CUDA
@@ -114,4 +129,4 @@ typedef struct _Context {
 } Context;
 
 }  // namespace singa
-#endif  // SINGA_CORE_COMMON_H_
+#endif  // SINGA_CORE_COMMON_H_
```
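Commit 8e8a7e1 replaces string-based arguments to the device's Append with the DeviceOptInfoToAppend struct above. Below is a minimal sketch of the calling convention that change suggests. The struct definition is copied from the diff; DeviceStub, its Append, and the trace vector are hypothetical stand-ins for the real Device class, not code from this PR.

```cpp
#include <chrono>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

using std::string;

// Copied from the diff above: one record per memory operation.
struct DeviceOptInfoToAppend {
  string operation_type;  // e.g. "Malloc", "Free", "Read", "Mutable"
  string block_ptr;       // the Block address, stringified
  int size;               // bytes involved in the operation
  long t = (std::chrono::system_clock::now()).time_since_epoch().count();
  DeviceOptInfoToAppend(string opt_type, string ptr, int s)
      : operation_type(opt_type), block_ptr(ptr), size(s) {}
};

// Hypothetical device that records one entry per memory operation; a swap
// planner could later replay this trace to pick swap candidates.
class DeviceStub {
 public:
  void Append(const DeviceOptInfoToAppend& info) { trace_.push_back(info); }
  const std::vector<DeviceOptInfoToAppend>& trace() const { return trace_; }

 private:
  std::vector<DeviceOptInfoToAppend> trace_;
};

int main() {
  DeviceStub dev;
  int block;  // stand-in for a real Block's address
  std::ostringstream ptr;
  ptr << &block;
  // One typed call per operation, instead of packing fields into a string:
  dev.Append(DeviceOptInfoToAppend("Malloc", ptr.str(), 1024));
  dev.Append(DeviceOptInfoToAppend("Free", ptr.str(), 1024));
  std::cout << dev.trace().size() << " operations recorded\n";
  return 0;
}
```

Presumably the real Append is the appendInfo machinery toggled in commits fecf34f and 8dae4d8; the stub only illustrates how a struct makes the record self-describing where the old string concatenation was not.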
Review comment: pls keep the cnn.py (instead of alexnet.py)

Review comment: I think you don't need to change the example model code.