Skip to content

Commit

Permalink
first pass at server-side sockets (+ client comm changes) (#5)
Browse files Browse the repository at this point in the history
* first pass at server-side sockets

* update client comm

* very hacky polyfill for fetch in node

* load mnist-transfer-cnn model, send weights to client

* changes based on PR feedback

* make loss fun a scalar

* avoid using export default

* spacing and license fixes for polyfill

* convert more serialization methods to async

* test coverage for serialization

* tests for API between server and client

* refactors based on feedback

* fix lint errors
  • Loading branch information
asross authored Jun 12, 2018
1 parent 504435a commit 27b2cfe
Show file tree
Hide file tree
Showing 18 changed files with 615 additions and 111 deletions.
3 changes: 3 additions & 0 deletions demo/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
<script src="/socket.io/socket.io.js"></script>
<script>
const socket = io.connect('http://localhost:3000');
socket.on('downloadVars', function(data) {
console.log(data);
})
</script>
</head>
<body>
Expand Down
43 changes: 43 additions & 0 deletions demo/mnist_transfer_learning_model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '@tensorflow/tfjs';
import {Scalar, Tensor} from '@tensorflow/tfjs';
import {FederatedModel, ModelDict} from '../src/index';

// https://github.com/tensorflow/tfjs-examples/tree/master/mnist-transfer-cnn
const mnistTransferLearningModelURL =
// tslint:disable-next-line:max-line-length
'https://storage.googleapis.com/tfjs-models/tfjs/mnist_transfer_cnn_v1/model.json';

export class MnistTransferLearningModel implements FederatedModel {
async setup(): Promise<ModelDict> {
const model = await tf.loadModel(mnistTransferLearningModelURL);

for (let i = 0; i < 7; ++i) {
model.layers[i].trainable = false; // freeze conv layers
}

const loss = (inputs: Tensor, labels: Tensor) => {
const logits = model.predict(inputs) as Tensor;
const losses = tf.losses.softmaxCrossEntropy(logits, labels);
return losses.mean() as Scalar;
};

return {predict: model.predict, vars: model.trainableWeights, loss};
}
}
10 changes: 6 additions & 4 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
},
"license": "Apache-2.0",
"scripts": {
"start": "node dist/server/server.js",
"start": "node dist/src/server/server.js",
"dev": "ts-node src/server/server.ts",
"build": "tsc && cp -r demo/ dist/demo/",
"build": "tsc && cp demo/index.html dist/demo/",
"test": "ts-node node_modules/jasmine/bin/jasmine --config=jasmine.json",
"lint": "tslint -p . -t verbose"
},
Expand All @@ -20,7 +20,9 @@
"@types/express": "^4.16.0",
"@types/socket.io": "^1.4.34",
"@types/socket.io-client": "^1.4.32",
"es6-promise": "^4.2.4",
"express": "^4.16.3",
"node-fetch": "^2.1.2",
"socket.io": "^2.1.1"
},
"devDependencies": {
Expand All @@ -33,8 +35,8 @@
"jasmine-core": "~3.1.0",
"rimraf": "^2.6.2",
"ts-node": "^6.1.0",
"ts-node-dev": "^1.0.0-pre.26",
"tslint": "~5.8.0",
"typescript": "2.7.2",
"ts-node-dev": "^1.0.0-pre.26"
"typescript": "2.7.2"
}
}
73 changes: 36 additions & 37 deletions src/client/comm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@

import * as tf from '@tensorflow/tfjs';
import {ModelFitConfig, Variable} from '@tensorflow/tfjs';
import {assert} from '@tensorflow/tfjs-core/dist/util';
import {Layer} from '@tensorflow/tfjs-layers/dist/engine/topology';
import {LayerVariable} from '@tensorflow/tfjs-layers/dist/variables';
import * as socketio from 'socket.io-client';

import {ConnectionMsg, Events, VarsMsg} from '../common';
import {DownloadMsg, Events, UploadMsg} from '../common';
// tslint:disable-next-line:max-line-length
import {deserializeVar, SerializedVariable, serializeVar} from '../serialization';
import {deserializeVar, SerializedVariable, serializeVars} from '../serialization';

const CONNECTION_TIMEOUT = 10 * 1000;
const UPLOAD_TIMEOUT = 1 * 1000;
Expand Down Expand Up @@ -53,20 +54,19 @@ const UPLOAD_TIMEOUT = 1 * 1000;
*/
export class VariableSynchroniser {
public modelId: string;
public numExamples: number;
public fitConfig: ModelFitConfig;
private socket: SocketIOClient.Socket;
private connMsg: ConnectionMsg;
private vars = new Map<string, Variable|LayerVariable>();
private acceptUpdate: (msg: VarsMsg) => boolean;
private vars: Array<Variable|LayerVariable>;
private acceptUpdate: (msg: DownloadMsg) => boolean;
/**
* Construct a synchroniser from a list of tf.Variables of tf.LayerVariables.
* @param {Array<Variable|LayerVariable>} vars - Variables to track and sync
*/
constructor(
vars: Array<Variable|LayerVariable>,
updateCallback?: (msg: VarsMsg) => boolean) {
for (const variable of vars) {
this.vars.set(variable.name, variable);
}
updateCallback?: (msg: DownloadMsg) => boolean) {
this.vars = vars;
if (updateCallback) {
this.acceptUpdate = updateCallback;
} else {
Expand All @@ -84,10 +84,10 @@ export class VariableSynchroniser {
return new VariableSynchroniser(tf.util.flatten(layerWeights, []));
}

private async connect(url: string): Promise<ConnectionMsg> {
private async connect(url: string): Promise<DownloadMsg> {
this.socket = socketio(url);
return fromEvent<ConnectionMsg>(
this.socket, Events.Initialise, CONNECTION_TIMEOUT);
return fromEvent<DownloadMsg>(
this.socket, Events.Download, CONNECTION_TIMEOUT);
}

/**
Expand All @@ -98,26 +98,30 @@ export class VariableSynchroniser {
* and variables set to their inital values.
*/
public async initialise(url: string): Promise<ModelFitConfig> {
this.connMsg = await this.connect(url);
this.setVarsFromMessage(this.connMsg.initVars);
this.modelId = this.connMsg.modelId;
const connMsg = await this.connect(url);
this.setVarsFromMessage(connMsg.vars);
this.modelId = connMsg.modelId;
this.fitConfig = connMsg.fitConfig;
this.numExamples = 0;

this.socket.on(Events.Download, (msg: VarsMsg) => {
this.socket.on(Events.Download, (msg: DownloadMsg) => {
if (this.acceptUpdate(msg)) {
this.setVarsFromMessage(msg.vars);
this.modelId = msg.modelId;
this.fitConfig = msg.fitConfig;
this.numExamples = 0;
}
});

return this.connMsg.fitConfig;
return this.fitConfig;
}

/**
* Upload the current values of the tracked variables to the server
* @return A promise that resolves when the server has recieved the variables
*/
public async uploadVars(): Promise<{}> {
const msg: VarsMsg = await this.serializeCurrentVars();
const msg: UploadMsg = await this.serializeCurrentVars();
const prom = new Promise((resolve, reject) => {
const rejectTimer =
setTimeout(() => reject(`uploadVars timed out`), UPLOAD_TIMEOUT);
Expand All @@ -130,31 +134,26 @@ export class VariableSynchroniser {
return prom;
}

protected async serializeCurrentVars(): Promise<VarsMsg> {
const varsP: Array<Promise<SerializedVariable>> = [];
protected async serializeCurrentVars(): Promise<UploadMsg> {
assert(this.numExamples > 0, 'should only serialize if we\'ve seen data');

this.vars.forEach((value, key) => {
if (value instanceof LayerVariable) {
varsP.push(serializeVar(tf.variable(value.read())));
} else {
varsP.push(serializeVar(value));
}
});
const vars = await Promise.all(varsP);
return {clientId: this.connMsg.clientId, modelId: this.modelId, vars};
const vars = await serializeVars(this.vars);

return {
numExamples: this.numExamples, /* TODO: ensure this gets updated */
modelId: this.modelId,
vars
};
}

protected setVarsFromMessage(newVars: SerializedVariable[]) {
for (const param of newVars) {
if (!this.vars.has(param.name)) {
throw new Error(`Recieved message with unexpected param ${
param.name}, should be one of ${this.vars.keys()}`);
}
const varOrLVar = this.vars.get(param.name);
for (let i = 0; i < newVars.length; i++) {
const newVar = newVars[i];
const varOrLVar = this.vars[i];
if (varOrLVar instanceof LayerVariable) {
varOrLVar.write(deserializeVar(param));
varOrLVar.write(deserializeVar(newVar));
} else {
varOrLVar.assign(deserializeVar(param));
varOrLVar.assign(deserializeVar(newVar));
}
}
}
Expand Down
138 changes: 138 additions & 0 deletions src/comm_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@

/**
* * @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '@tensorflow/tfjs';
import {test_util, Variable} from '@tensorflow/tfjs';
import * as fs from 'fs';
import * as http from 'http';
import * as path from 'path';
import * as rimraf from 'rimraf';
import * as serverSocket from 'socket.io';

import {VariableSynchroniser} from './client/comm';
import {tensorToJson} from './serialization';
import {SocketAPI} from './server/comm';
import {ModelDB} from './server/model_db';

const modelId = '1528400733553';
const batchSize = 42;
const FIT_CONFIG = {batchSize};
const PORT = 3000;
const socketURL = `http://0.0.0.0:${PORT}`;
const initWeights =
[tf.tensor([1, 1, 1, 1], [2, 2]), tf.tensor([1, 2, 3, 4], [1, 4])];
const updateThreshold = 2;

function waitUntil(done: () => boolean, then: () => void, timeout?: number) {
const moveOn = () => {
clearInterval(moveOnIfDone);
clearTimeout(moveOnAnyway);
then();
};
const moveOnAnyway = setTimeout(moveOn, timeout || 100);
const moveOnIfDone = setInterval(() => {
if (done()) {
moveOn();
}
}, 1);
}

describe('Socket API', () => {
let dataDir: string;
let modelDir: string;
let modelPath: string;
let modelDB: ModelDB;
let serverAPI: SocketAPI;
let clientAPI: VariableSynchroniser;
let clientVars: Variable[];
let httpServer: http.Server;

beforeEach(async () => {
// Set up model database with our initial weights
dataDir = fs.mkdtempSync('/tmp/modeldb_test');
modelDir = path.join(dataDir, modelId);
modelPath = path.join(dataDir, modelId + '.json');
fs.mkdirSync(modelDir);
const modelJSON = await Promise.all(initWeights.map(tensorToJson));
fs.writeFileSync(modelPath, JSON.stringify({'vars': modelJSON}));
modelDB = new ModelDB(dataDir, updateThreshold);
await modelDB.setup();

// Set up the server exposing our upload/download API
httpServer = http.createServer();
serverAPI = new SocketAPI(modelDB, FIT_CONFIG, serverSocket(httpServer));
await serverAPI.setup();
await httpServer.listen(PORT);

// Set up the API client with zeroed out weights
clientVars = initWeights.map(t => tf.variable(tf.zerosLike(t)));
clientAPI = new VariableSynchroniser(clientVars);
await clientAPI.initialise(socketURL);
});

afterEach(async () => {
rimraf.sync(dataDir);
await httpServer.close();
});

it('transmits fit config on startup', () => {
expect(clientAPI.fitConfig.batchSize).toBe(batchSize);
});

it('transmits model version on startup', () => {
expect(clientAPI.modelId).toBe(modelId);
});

it('transmits model parameters on startup', () => {
test_util.expectArraysClose(clientVars[0], initWeights[0]);
test_util.expectArraysClose(clientVars[1], initWeights[1]);
});

it('transmits updates', async () => {
let updateFiles = await modelDB.listUpdateFiles();
expect(updateFiles.length).toBe(0);

clientVars[0].assign(tf.tensor([2, 2, 2, 2], [2, 2]));
clientAPI.numExamples = 1;
await clientAPI.uploadVars();

updateFiles = await modelDB.listUpdateFiles();
expect(updateFiles.length).toBe(1);
});

it('triggers a download after enough uploads', async (done) => {
clientVars[0].assign(tf.tensor([2, 2, 2, 2], [2, 2]));
clientAPI.numExamples = 1;
await clientAPI.uploadVars();

clientVars[0].assign(tf.tensor([1, 1, 1, 1], [2, 2]));
clientVars[1].assign(tf.tensor([4, 3, 2, 1], [1, 4]));
clientAPI.numExamples = 3;
await clientAPI.uploadVars();

waitUntil(() => clientAPI.modelId !== modelId, () => {
test_util.expectArraysClose(
clientVars[0], tf.tensor([1.25, 1.25, 1.25, 1.25], [2, 2]));
test_util.expectArraysClose(
clientVars[1], tf.tensor([3.25, 2.75, 2.25, 1.75], [1, 4]));
expect(clientAPI.numExamples).toBe(0);
expect(clientAPI.modelId).toBe(modelDB.modelId);
done();
});
});
});
17 changes: 5 additions & 12 deletions src/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,18 @@ import {ModelFitConfig} from '@tensorflow/tfjs';
import {SerializedVariable} from './serialization';

export enum Events {
Initialise = 'initialise',
Download = 'downloadVars',
Upload = 'uploadVars'
}

export type TrainingInfo = {
nSteps: number
};

export type VarsMsg = {
export type UploadMsg = {
modelId: string,
clientId: string,
vars: SerializedVariable[],
history?: TrainingInfo
numExamples: number
};

export type ConnectionMsg = {
clientId: string,
fitConfig: ModelFitConfig,
export type DownloadMsg = {
modelId: string,
initVars: SerializedVariable[]
vars: SerializedVariable[]
fitConfig: ModelFitConfig,
};
Loading

0 comments on commit 27b2cfe

Please sign in to comment.