diff --git a/demo/index.html b/demo/index.html
index d384003..018fef6 100644
--- a/demo/index.html
+++ b/demo/index.html
@@ -3,6 +3,9 @@
diff --git a/demo/mnist_transfer_learning_model.ts b/demo/mnist_transfer_learning_model.ts
new file mode 100644
index 0000000..1e93b40
--- /dev/null
+++ b/demo/mnist_transfer_learning_model.ts
@@ -0,0 +1,43 @@
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import * as tf from '@tensorflow/tfjs';
+import {Scalar, Tensor} from '@tensorflow/tfjs';
+import {FederatedModel, ModelDict} from '../src/index';
+
+// https://github.com/tensorflow/tfjs-examples/tree/master/mnist-transfer-cnn
+const mnistTransferLearningModelURL =
+ // tslint:disable-next-line:max-line-length
+ 'https://storage.googleapis.com/tfjs-models/tfjs/mnist_transfer_cnn_v1/model.json';
+
+export class MnistTransferLearningModel implements FederatedModel {
+ async setup(): Promise<ModelDict> {
+ const model = await tf.loadModel(mnistTransferLearningModelURL);
+
+ for (let i = 0; i < 7; ++i) {
+ model.layers[i].trainable = false; // freeze conv layers
+ }
+
+ const loss = (inputs: Tensor, labels: Tensor) => {
+ const logits = model.predict(inputs) as Tensor;
+ const losses = tf.losses.softmaxCrossEntropy(labels, logits);
+ return losses.mean() as Scalar;
+ };
+
+ return {predict: model.predict, vars: model.trainableWeights, loss};
+ }
+}
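
For context, the `ModelDict` returned by `setup()` is meant to drive local training on each client. A minimal sketch of one such step follows (not part of this diff; the learning rate and batch tensors are placeholders):

```ts
import * as tf from '@tensorflow/tfjs';
import {MnistTransferLearningModel} from './mnist_transfer_learning_model';

// Hypothetical local training step: one SGD update against the loss closure
// returned by setup(). Learning rate and batch tensors are placeholders.
async function localStep(xs: tf.Tensor, ys: tf.Tensor) {
  const {loss} = await new MnistTransferLearningModel().setup();
  const optimizer = tf.train.sgd(0.1);
  optimizer.minimize(() => loss(xs, ys));
}
```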
diff --git a/package.json b/package.json
index 420528a..2da4fe0 100644
--- a/package.json
+++ b/package.json
@@ -9,9 +9,9 @@
},
"license": "Apache-2.0",
"scripts": {
- "start": "node dist/server/server.js",
+ "start": "node dist/src/server/server.js",
"dev": "ts-node src/server/server.ts",
- "build": "tsc && cp -r demo/ dist/demo/",
+ "build": "tsc && cp demo/index.html dist/demo/",
"test": "ts-node node_modules/jasmine/bin/jasmine --config=jasmine.json",
"lint": "tslint -p . -t verbose"
},
@@ -20,7 +20,9 @@
"@types/express": "^4.16.0",
"@types/socket.io": "^1.4.34",
"@types/socket.io-client": "^1.4.32",
+ "es6-promise": "^4.2.4",
"express": "^4.16.3",
+ "node-fetch": "^2.1.2",
"socket.io": "^2.1.1"
},
"devDependencies": {
@@ -33,8 +35,8 @@
"jasmine-core": "~3.1.0",
"rimraf": "^2.6.2",
"ts-node": "^6.1.0",
+ "ts-node-dev": "^1.0.0-pre.26",
"tslint": "~5.8.0",
- "typescript": "2.7.2",
- "ts-node-dev": "^1.0.0-pre.26"
+ "typescript": "2.7.2"
}
}
diff --git a/src/client/comm.ts b/src/client/comm.ts
index 7818749..96e6bef 100644
--- a/src/client/comm.ts
+++ b/src/client/comm.ts
@@ -17,13 +17,14 @@
import * as tf from '@tensorflow/tfjs';
import {ModelFitConfig, Variable} from '@tensorflow/tfjs';
+import {assert} from '@tensorflow/tfjs-core/dist/util';
import {Layer} from '@tensorflow/tfjs-layers/dist/engine/topology';
import {LayerVariable} from '@tensorflow/tfjs-layers/dist/variables';
import * as socketio from 'socket.io-client';
-import {ConnectionMsg, Events, VarsMsg} from '../common';
+import {DownloadMsg, Events, UploadMsg} from '../common';
// tslint:disable-next-line:max-line-length
-import {deserializeVar, SerializedVariable, serializeVar} from '../serialization';
+import {deserializeVar, SerializedVariable, serializeVars} from '../serialization';
const CONNECTION_TIMEOUT = 10 * 1000;
const UPLOAD_TIMEOUT = 1 * 1000;
@@ -53,20 +54,19 @@ const UPLOAD_TIMEOUT = 1 * 1000;
*/
export class VariableSynchroniser {
public modelId: string;
+ public numExamples: number;
+ public fitConfig: ModelFitConfig;
private socket: SocketIOClient.Socket;
- private connMsg: ConnectionMsg;
- private vars = new Map<string, Variable|LayerVariable>();
- private acceptUpdate: (msg: VarsMsg) => boolean;
+ private vars: Array<Variable|LayerVariable>;
+ private acceptUpdate: (msg: DownloadMsg) => boolean;
/**
* Construct a synchroniser from a list of tf.Variables of tf.LayerVariables.
* @param {Array<Variable|LayerVariable>} vars - Variables to track and sync
*/
constructor(
vars: Array<Variable|LayerVariable>,
- updateCallback?: (msg: VarsMsg) => boolean) {
- for (const variable of vars) {
- this.vars.set(variable.name, variable);
- }
+ updateCallback?: (msg: DownloadMsg) => boolean) {
+ this.vars = vars;
if (updateCallback) {
this.acceptUpdate = updateCallback;
} else {
@@ -84,10 +84,10 @@ export class VariableSynchroniser {
return new VariableSynchroniser(tf.util.flatten(layerWeights, []));
}
- private async connect(url: string): Promise<ConnectionMsg> {
+ private async connect(url: string): Promise<DownloadMsg> {
this.socket = socketio(url);
- return fromEvent<ConnectionMsg>(
- this.socket, Events.Initialise, CONNECTION_TIMEOUT);
+ return fromEvent<DownloadMsg>(
+ this.socket, Events.Download, CONNECTION_TIMEOUT);
}
/**
@@ -98,18 +98,22 @@ export class VariableSynchroniser {
* and variables set to their inital values.
*/
public async initialise(url: string): Promise<ModelFitConfig> {
- this.connMsg = await this.connect(url);
- this.setVarsFromMessage(this.connMsg.initVars);
- this.modelId = this.connMsg.modelId;
+ const connMsg = await this.connect(url);
+ this.setVarsFromMessage(connMsg.vars);
+ this.modelId = connMsg.modelId;
+ this.fitConfig = connMsg.fitConfig;
+ this.numExamples = 0;
- this.socket.on(Events.Download, (msg: VarsMsg) => {
+ this.socket.on(Events.Download, (msg: DownloadMsg) => {
if (this.acceptUpdate(msg)) {
this.setVarsFromMessage(msg.vars);
this.modelId = msg.modelId;
+ this.fitConfig = msg.fitConfig;
+ this.numExamples = 0;
}
});
- return this.connMsg.fitConfig;
+ return this.fitConfig;
}
/**
@@ -117,7 +121,7 @@ export class VariableSynchroniser {
* @return A promise that resolves when the server has recieved the variables
*/
public async uploadVars(): Promise<{}> {
- const msg: VarsMsg = await this.serializeCurrentVars();
+ const msg: UploadMsg = await this.serializeCurrentVars();
const prom = new Promise((resolve, reject) => {
const rejectTimer =
setTimeout(() => reject(`uploadVars timed out`), UPLOAD_TIMEOUT);
@@ -130,31 +134,26 @@ export class VariableSynchroniser {
return prom;
}
- protected async serializeCurrentVars(): Promise<VarsMsg> {
- const varsP: Array<Promise<SerializedVariable>> = [];
+ protected async serializeCurrentVars(): Promise<UploadMsg> {
+ assert(this.numExamples > 0, 'should only serialize if we\'ve seen data');
- this.vars.forEach((value, key) => {
- if (value instanceof LayerVariable) {
- varsP.push(serializeVar(tf.variable(value.read())));
- } else {
- varsP.push(serializeVar(value));
- }
- });
- const vars = await Promise.all(varsP);
- return {clientId: this.connMsg.clientId, modelId: this.modelId, vars};
+ const vars = await serializeVars(this.vars);
+
+ return {
+ numExamples: this.numExamples, /* TODO: ensure this gets updated */
+ modelId: this.modelId,
+ vars
+ };
}
protected setVarsFromMessage(newVars: SerializedVariable[]) {
- for (const param of newVars) {
- if (!this.vars.has(param.name)) {
- throw new Error(`Recieved message with unexpected param ${
- param.name}, should be one of ${this.vars.keys()}`);
- }
- const varOrLVar = this.vars.get(param.name);
+ for (let i = 0; i < newVars.length; i++) {
+ const newVar = newVars[i];
+ const varOrLVar = this.vars[i];
if (varOrLVar instanceof LayerVariable) {
- varOrLVar.write(deserializeVar(param));
+ varOrLVar.write(deserializeVar(newVar));
} else {
- varOrLVar.assign(deserializeVar(param));
+ varOrLVar.assign(deserializeVar(newVar));
}
}
}
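
Putting the client API together, a typical round trip with the revised `VariableSynchroniser` might look like the sketch below (illustrative only; the variable shapes, server URL, and local training loop are placeholders):

```ts
import * as tf from '@tensorflow/tfjs';
import {VariableSynchroniser} from './src/client/comm';

async function runClient(xs: tf.Tensor, ys: tf.Tensor) {
  // Track two placeholder weight variables; the server's current values are
  // written into them during initialise().
  const weights = [tf.variable(tf.zeros([784, 10])), tf.variable(tf.zeros([10]))];
  const sync = new VariableSynchroniser(weights);
  const fitConfig = await sync.initialise('http://localhost:3000');

  // ...train locally using fitConfig.batchSize...

  sync.numExamples = xs.shape[0];  // report how much data backed this update
  await sync.uploadVars();         // resolves once the server acks the upload
}
```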
diff --git a/src/comm_test.ts b/src/comm_test.ts
new file mode 100644
index 0000000..dfd3a03
--- /dev/null
+++ b/src/comm_test.ts
@@ -0,0 +1,138 @@
+
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import * as tf from '@tensorflow/tfjs';
+import {test_util, Variable} from '@tensorflow/tfjs';
+import * as fs from 'fs';
+import * as http from 'http';
+import * as path from 'path';
+import * as rimraf from 'rimraf';
+import * as serverSocket from 'socket.io';
+
+import {VariableSynchroniser} from './client/comm';
+import {tensorToJson} from './serialization';
+import {SocketAPI} from './server/comm';
+import {ModelDB} from './server/model_db';
+
+const modelId = '1528400733553';
+const batchSize = 42;
+const FIT_CONFIG = {batchSize};
+const PORT = 3000;
+const socketURL = `http://0.0.0.0:${PORT}`;
+const initWeights =
+ [tf.tensor([1, 1, 1, 1], [2, 2]), tf.tensor([1, 2, 3, 4], [1, 4])];
+const updateThreshold = 2;
+
+function waitUntil(done: () => boolean, then: () => void, timeout?: number) {
+ const moveOn = () => {
+ clearInterval(moveOnIfDone);
+ clearTimeout(moveOnAnyway);
+ then();
+ };
+ const moveOnAnyway = setTimeout(moveOn, timeout || 100);
+ const moveOnIfDone = setInterval(() => {
+ if (done()) {
+ moveOn();
+ }
+ }, 1);
+}
+
+describe('Socket API', () => {
+ let dataDir: string;
+ let modelDir: string;
+ let modelPath: string;
+ let modelDB: ModelDB;
+ let serverAPI: SocketAPI;
+ let clientAPI: VariableSynchroniser;
+ let clientVars: Variable[];
+ let httpServer: http.Server;
+
+ beforeEach(async () => {
+ // Set up model database with our initial weights
+ dataDir = fs.mkdtempSync('/tmp/modeldb_test');
+ modelDir = path.join(dataDir, modelId);
+ modelPath = path.join(dataDir, modelId + '.json');
+ fs.mkdirSync(modelDir);
+ const modelJSON = await Promise.all(initWeights.map(tensorToJson));
+ fs.writeFileSync(modelPath, JSON.stringify({'vars': modelJSON}));
+ modelDB = new ModelDB(dataDir, updateThreshold);
+ await modelDB.setup();
+
+ // Set up the server exposing our upload/download API
+ httpServer = http.createServer();
+ serverAPI = new SocketAPI(modelDB, FIT_CONFIG, serverSocket(httpServer));
+ await serverAPI.setup();
+ await httpServer.listen(PORT);
+
+ // Set up the API client with zeroed out weights
+ clientVars = initWeights.map(t => tf.variable(tf.zerosLike(t)));
+ clientAPI = new VariableSynchroniser(clientVars);
+ await clientAPI.initialise(socketURL);
+ });
+
+ afterEach(async () => {
+ rimraf.sync(dataDir);
+ await httpServer.close();
+ });
+
+ it('transmits fit config on startup', () => {
+ expect(clientAPI.fitConfig.batchSize).toBe(batchSize);
+ });
+
+ it('transmits model version on startup', () => {
+ expect(clientAPI.modelId).toBe(modelId);
+ });
+
+ it('transmits model parameters on startup', () => {
+ test_util.expectArraysClose(clientVars[0], initWeights[0]);
+ test_util.expectArraysClose(clientVars[1], initWeights[1]);
+ });
+
+ it('transmits updates', async () => {
+ let updateFiles = await modelDB.listUpdateFiles();
+ expect(updateFiles.length).toBe(0);
+
+ clientVars[0].assign(tf.tensor([2, 2, 2, 2], [2, 2]));
+ clientAPI.numExamples = 1;
+ await clientAPI.uploadVars();
+
+ updateFiles = await modelDB.listUpdateFiles();
+ expect(updateFiles.length).toBe(1);
+ });
+
+ it('triggers a download after enough uploads', async (done) => {
+ clientVars[0].assign(tf.tensor([2, 2, 2, 2], [2, 2]));
+ clientAPI.numExamples = 1;
+ await clientAPI.uploadVars();
+
+ clientVars[0].assign(tf.tensor([1, 1, 1, 1], [2, 2]));
+ clientVars[1].assign(tf.tensor([4, 3, 2, 1], [1, 4]));
+ clientAPI.numExamples = 3;
+ await clientAPI.uploadVars();
+
+ waitUntil(() => clientAPI.modelId !== modelId, () => {
+ test_util.expectArraysClose(
+ clientVars[0], tf.tensor([1.25, 1.25, 1.25, 1.25], [2, 2]));
+ test_util.expectArraysClose(
+ clientVars[1], tf.tensor([3.25, 2.75, 2.25, 1.75], [1, 4]));
+ expect(clientAPI.numExamples).toBe(0);
+ expect(clientAPI.modelId).toBe(modelDB.modelId);
+ done();
+ });
+ });
+});
diff --git a/src/common.ts b/src/common.ts
index 273b7dd..a7856d6 100644
--- a/src/common.ts
+++ b/src/common.ts
@@ -20,25 +20,18 @@ import {ModelFitConfig} from '@tensorflow/tfjs';
import {SerializedVariable} from './serialization';
export enum Events {
- Initialise = 'initialise',
Download = 'downloadVars',
Upload = 'uploadVars'
}
-export type TrainingInfo = {
- nSteps: number
-};
-
-export type VarsMsg = {
+export type UploadMsg = {
modelId: string,
- clientId: string,
vars: SerializedVariable[],
- history?: TrainingInfo
+ numExamples: number
};
-export type ConnectionMsg = {
- clientId: string,
- fitConfig: ModelFitConfig,
+export type DownloadMsg = {
modelId: string,
- initVars: SerializedVariable[]
+ vars: SerializedVariable[]
+ fitConfig: ModelFitConfig,
};
diff --git a/src/index.ts b/src/index.ts
new file mode 100644
index 0000000..a858730
--- /dev/null
+++ b/src/index.ts
@@ -0,0 +1,18 @@
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+export * from './types';
diff --git a/src/model.ts b/src/model.ts
new file mode 100644
index 0000000..fbd418b
--- /dev/null
+++ b/src/model.ts
@@ -0,0 +1,22 @@
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+// tslint:disable-next-line:max-line-length
+import {MnistTransferLearningModel} from '../demo/mnist_transfer_learning_model';
+
+// TODO: some kind of flag to determine what model we use
+export class Model extends MnistTransferLearningModel {}
diff --git a/src/serialization.ts b/src/serialization.ts
index 8834863..00bf0ca 100644
--- a/src/serialization.ts
+++ b/src/serialization.ts
@@ -1,26 +1,33 @@
import * as tf from '@tensorflow/tfjs';
+import {Tensor, Variable} from '@tensorflow/tfjs';
+import {LayerVariable} from '@tensorflow/tfjs-layers/dist/variables';
export type SerializedVariable = {
dtype: tf.DataType,
shape: number[],
- name: string,
- data: ArrayBuffer,
- trainable: boolean
+ data: ArrayBuffer
};
-export async function serializeVar(variable: tf.Variable):
+export async function serializeVar(variable: tf.Tensor):
Promise<SerializedVariable> {
const data = await variable.data();
// small TypedArrays are views into a larger buffer
const copy =
data.buffer.slice(data.byteOffset, data.byteOffset + data.byteLength);
- return {
- dtype: variable.dtype,
- shape: variable.shape.slice(),
- name: variable.name,
- trainable: variable.trainable,
- data: copy
- };
+ return {dtype: variable.dtype, shape: variable.shape.slice(), data: copy};
+}
+
+export async function serializeVars(
+ vars: Array<Variable|LayerVariable|Tensor>) {
+ const varsP: Array<Promise<SerializedVariable>> = [];
+ vars.forEach((value, key) => {
+ if (value instanceof LayerVariable) {
+ varsP.push(serializeVar(tf.variable(value.read())));
+ } else {
+ varsP.push(serializeVar(value));
+ }
+ });
+ return Promise.all(varsP);
}
export function deserializeVar(serialized: SerializedVariable): tf.Tensor {
@@ -43,6 +50,37 @@ export function deserializeVar(serialized: SerializedVariable): tf.Tensor {
return tf.tensor(array, shape, dtype);
}
+export type TensorJson = {
+ values: number[],
+ shape: number[],
+ dtype?: tf.DataType
+};
+
+export async function tensorToJson(t: tf.Tensor): Promise<TensorJson> {
+ let data;
+ if (t instanceof LayerVariable) {
+ data = await t.read().data();
+ } else {
+ data = await t.data();
+ }
+ // Note: could make this async / use base64 encoding on the buffer data
+ return {'values': Array.from(data), 'shape': t.shape, 'dtype': t.dtype};
+}
+
+export function jsonToTensor(j: TensorJson): tf.Tensor {
+ return tf.tensor(j.values, j.shape, j.dtype || 'float32');
+}
+
+export async function serializedToJson(s: SerializedVariable):
+ Promise<TensorJson> {
+ return tensorToJson(deserializeVar(s));
+}
+
+export async function jsonToSerialized(j: TensorJson):
+ Promise<SerializedVariable> {
+ return serializeVar(jsonToTensor(j));
+}
+
const dtypeToTypedArrayCtor = {
'float32': Float32Array,
'int32': Int32Array,
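
The new helpers split serialization into a binary form (`SerializedVariable`, used on the socket) and a JSON form (`TensorJson`, used on disk). A rough round-trip sketch, assuming the exports above:

```ts
import * as tf from '@tensorflow/tfjs';
import {deserializeVar, serializedToJson, serializeVars} from './src/serialization';

async function roundTrip() {
  // Client side: pack variables into ArrayBuffer-backed messages.
  const weights = [tf.variable(tf.tensor([1, 2, 3, 4], [2, 2]))];
  const payload = await serializeVars(weights);

  // Server side: rehydrate as tensors, or persist as JSON update files.
  const tensors = payload.map(deserializeVar);
  const json = await Promise.all(payload.map(serializedToJson));
  tensors[0].print();                    // [[1, 2], [3, 4]]
  console.log(JSON.stringify(json[0]));  // {"values":[1,2,3,4],"shape":[2,2],"dtype":"float32"}
}
```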
diff --git a/src/serialization_test.ts b/src/serialization_test.ts
new file mode 100644
index 0000000..071df66
--- /dev/null
+++ b/src/serialization_test.ts
@@ -0,0 +1,58 @@
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import * as tf from '@tensorflow/tfjs';
+import {test_util} from '@tensorflow/tfjs-core';
+
+import * as ser from './serialization';
+
+describe('serialization', () => {
+ const floatTensor =
+ tf.tensor3d([[[1.1, 2.2], [3.3, 4.4]], [[5.5, 6.6], [7.7, 8.8]]]);
+ const boolTensor = tf.tensor1d([true, false], 'bool');
+ const intTensor = tf.tensor2d([[1, 2], [3, 4]], [2, 2], 'int32');
+
+ it('converts back and forth to JSON', async () => {
+ const floatJSON = await ser.tensorToJson(floatTensor);
+ const boolJSON = await ser.tensorToJson(boolTensor);
+ const intJSON = await ser.tensorToJson(intTensor);
+ const floatTensor2 = ser.jsonToTensor(floatJSON);
+ const boolTensor2 = ser.jsonToTensor(boolJSON);
+ const intTensor2 = ser.jsonToTensor(intJSON);
+ test_util.expectArraysClose(floatTensor, floatTensor2);
+ test_util.expectArraysClose(boolTensor, boolTensor2);
+ test_util.expectArraysClose(intTensor, intTensor2);
+ });
+
+ it('converts back and forth to SerializedVar', async () => {
+ const floatSerial = await ser.serializeVar(floatTensor);
+ const boolSerial = await ser.serializeVar(boolTensor);
+ const intSerial = await ser.serializeVar(intTensor);
+ const floatTensor2 = ser.deserializeVar(floatSerial);
+ const boolTensor2 = ser.deserializeVar(boolSerial);
+ const intTensor2 = ser.deserializeVar(intSerial);
+ test_util.expectArraysClose(floatTensor, floatTensor2);
+ test_util.expectArraysClose(boolTensor, boolTensor2);
+ test_util.expectArraysClose(intTensor, intTensor2);
+ });
+
+ it('works for an arbitrary chain', async () => {
+ const floatTensor2 = ser.jsonToTensor(await ser.serializedToJson(
+ await ser.jsonToSerialized(await ser.tensorToJson(floatTensor))));
+ test_util.expectArraysClose(floatTensor2, floatTensor);
+ });
+});
diff --git a/src/server/comm.ts b/src/server/comm.ts
new file mode 100644
index 0000000..3a61a1a
--- /dev/null
+++ b/src/server/comm.ts
@@ -0,0 +1,90 @@
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {ModelFitConfig} from '@tensorflow/tfjs';
+import * as fs from 'fs';
+import * as path from 'path';
+import {Server, Socket} from 'socket.io';
+import {promisify} from 'util';
+import * as uuid from 'uuid/v4';
+
+import {DownloadMsg, Events, UploadMsg} from '../common';
+import {serializedToJson, serializeVar} from '../serialization';
+
+import {ModelDB} from './model_db';
+
+const writeFile = promisify(fs.writeFile);
+
+export class SocketAPI {
+ modelDB: ModelDB;
+ fitConfig: ModelFitConfig;
+ io: Server;
+
+ constructor(modelDB: ModelDB, fitConfig: ModelFitConfig, io: Server) {
+ this.modelDB = modelDB;
+ this.fitConfig = fitConfig;
+ this.io = io;
+ }
+
+ async downloadMsg(): Promise<DownloadMsg> {
+ const varsJson = await this.modelDB.currentVars();
+ const varsSeri = await Promise.all(varsJson.map(serializeVar));
+ return {
+ fitConfig: this.fitConfig,
+ modelId: this.modelDB.modelId,
+ vars: varsSeri
+ };
+ }
+
+ async setup() {
+ this.io.on('connection', async (socket: Socket) => {
+ // Send current variables to newly connected client
+ const initVars = await this.downloadMsg();
+ socket.emit(Events.Download, initVars);
+
+ // When a client sends us updated weights
+ socket.on(Events.Upload, async (msg: UploadMsg, ack) => {
+ // Save them to a file
+ const modelId = msg.modelId;
+ const updateId = uuid();
+ const updatePath =
+ path.join(this.modelDB.dataDir, modelId, updateId + '.json');
+ const updatedVars = await Promise.all(msg.vars.map(serializedToJson));
+ const updateJSON = JSON.stringify({
+ clientId: socket.client.id,
+ modelId,
+ numExamples: msg.numExamples,
+ vars: updatedVars
+ });
+ await writeFile(updatePath, updateJSON);
+
+ // Let them know we're done saving
+ ack(true);
+
+ // Potentially update the model (asynchronously)
+ if (modelId === this.modelDB.modelId) {
+ const updated = await this.modelDB.possiblyUpdate();
+ if (updated) {
+ // Send new variables to all clients if we updated
+ const newVars = await this.downloadMsg();
+ this.io.sockets.emit(Events.Download, newVars);
+ }
+ }
+ });
+ });
+ }
+}
diff --git a/src/server/fetch_polyfill.ts b/src/server/fetch_polyfill.ts
new file mode 100644
index 0000000..26324fa
--- /dev/null
+++ b/src/server/fetch_polyfill.ts
@@ -0,0 +1,30 @@
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import * as es6Promise from 'es6-promise';
+es6Promise.polyfill();
+
+eval(`
+var realFetch = require('node-fetch');
+
+if (!global.fetch) {
+ global.fetch = realFetch;
+ global.Response = realFetch.Response;
+ global.Headers = realFetch.Headers;
+ global.Request = realFetch.Request;
+}
+`);
diff --git a/src/server/model_db.ts b/src/server/model_db.ts
index 146f23d..91777bd 100644
--- a/src/server/model_db.ts
+++ b/src/server/model_db.ts
@@ -20,37 +20,28 @@ import * as fs from 'fs';
import * as path from 'path';
import {promisify} from 'util';
+import {Model} from '../model';
+import {jsonToTensor, TensorJson, tensorToJson} from '../serialization';
+
const DEFAULT_MIN_UPDATES = 10;
const mkdir = promisify(fs.mkdir);
+const exists = promisify(fs.exists);
const readdir = promisify(fs.readdir);
const readFile = promisify(fs.readFile);
const writeFile = promisify(fs.writeFile);
-function getLatestId(dir: string) {
- const files = fs.readdirSync(dir);
- return files.reduce((acc, val) => {
- if (val.endsWith('.json') && val.slice(0, -5) > acc) {
- return val.slice(0, -5);
- } else {
- return acc;
+async function getLatestId(dir: string) {
+ const files = await readdir(dir);
+ let latestId: string = null;
+ files.forEach((name) => {
+ if (name.endsWith('.json')) {
+ const id = name.slice(0, -5);
+ if (latestId == null || id > latestId) {
+ latestId = id;
+ }
}
- }, '0');
-}
-
-type TensorJSON = {
- values: number[],
- shape: number[],
- dtype?: tf.DataType
-};
-
-function dumpTensor(t: tf.Tensor) {
- return {
- 'values': Array.from(t.dataSync()), 'shape': t.shape, 'dtype': t.dtype
- }
-}
-
-function loadTensor(obj: TensorJSON) {
- return tf.tensor(obj.values, obj.shape, obj.dtype || 'float32');
+ });
+ return latestId;
}
function generateNewId() {
@@ -68,13 +59,27 @@ export class ModelDB {
updating: boolean;
minUpdates: number;
- constructor(dataDir: string, minUpdates?: number, currentModelId?: string) {
+ constructor(dataDir: string, minUpdates?: number) {
this.dataDir = dataDir;
- this.modelId = currentModelId || getLatestId(dataDir);
+ this.modelId = null;
this.updating = false;
this.minUpdates = minUpdates || DEFAULT_MIN_UPDATES;
}
+ async setup() {
+ const dirExists = await exists(this.dataDir);
+ if (!dirExists) {
+ await mkdir(this.dataDir);
+ }
+
+ this.modelId = await getLatestId(this.dataDir);
+ if (this.modelId == null) {
+ const model = new Model();
+ const dict = await model.setup();
+ await this.writeNewVars(dict.vars as tf.Tensor[]);
+ }
+ }
+
async listUpdateFiles(): Promise<string[]> {
const files = await readdir(path.join(this.dataDir, this.modelId));
return files.map((f) => {
@@ -85,21 +90,23 @@ export class ModelDB {
async currentVars(): Promise<tf.Tensor[]> {
const file = path.join(this.dataDir, this.modelId + '.json');
const json = await readJSON(file);
- return json['vars'].map(loadTensor);
+ return json['vars'].map(jsonToTensor);
}
- async possiblyUpdate() {
+ async possiblyUpdate(): Promise<boolean> {
const updateFiles = await this.listUpdateFiles();
if (updateFiles.length < this.minUpdates || this.updating) {
- return;
+ return false;
}
this.updating = true;
await this.update();
this.updating = false;
+ return true;
}
async update() {
- const updatedVars = await this.currentVars();
+ const currentVars = await this.currentVars();
+ const updatedVars = currentVars.map(v => tf.zerosLike(v));
const updateFiles = await this.listUpdateFiles();
const updatesJSON = await Promise.all(updateFiles.map(readJSON));
@@ -114,17 +121,22 @@ export class ModelDB {
updatesJSON.forEach((u) => {
const nk = tf.scalar(u['numExamples']);
const frac = nk.div(n);
- u['vars'].forEach((v: TensorJSON, i: number) => {
- const update = loadTensor(v).mul(frac);
+ u['vars'].forEach((v: TensorJson, i: number) => {
+ const update = jsonToTensor(v).mul(frac);
updatedVars[i] = updatedVars[i].add(update);
});
});
// Save results and update key
+ await this.writeNewVars(updatedVars);
+ }
+
+ async writeNewVars(newVars: tf.Tensor[]) {
const newModelId = generateNewId();
const newModelDir = path.join(this.dataDir, newModelId);
const newModelPath = path.join(this.dataDir, newModelId + '.json');
- const newModelJSON = JSON.stringify({'vars': updatedVars.map(dumpTensor)});
+ const newVarsJSON = await Promise.all(newVars.map(tensorToJson));
+ const newModelJSON = JSON.stringify({'vars': newVarsJSON});
await writeFile(newModelPath, newModelJSON);
await mkdir(newModelDir);
this.modelId = newModelId;
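
The rewritten `update()` now starts from zeroed tensors and accumulates each client update scaled by `numExamples_k / totalExamples`, i.e. plain federated averaging. A small worked sketch using the same numbers as the socket test above (a 1-example and a 3-example update):

```ts
import * as tf from '@tensorflow/tfjs';

// Federated averaging over two hypothetical client updates of one variable.
const updates = [
  {numExamples: 1, vars: [tf.tensor([2, 2, 2, 2], [2, 2])]},
  {numExamples: 3, vars: [tf.tensor([1, 1, 1, 1], [2, 2])]},
];
const n = tf.scalar(updates.reduce((sum, u) => sum + u.numExamples, 0));

let averaged = tf.zerosLike(updates[0].vars[0]);
for (const u of updates) {
  const frac = tf.scalar(u.numExamples).div(n);
  averaged = averaged.add(u.vars[0].mul(frac));
}
averaged.print();  // [[1.25, 1.25], [1.25, 1.25]] — matches the comm test expectation
```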
diff --git a/src/server/model_db_test.ts b/src/server/model_db_test.ts
index f747fdd..80393f3 100644
--- a/src/server/model_db_test.ts
+++ b/src/server/model_db_test.ts
@@ -15,7 +15,6 @@
* =============================================================================
*/
-// tslint:disable-next-line:max-line-length
import {test_util} from '@tensorflow/tfjs-core';
import * as fs from 'fs';
import * as path from 'path';
@@ -67,13 +66,15 @@ describe('ModelDB', () => {
rimraf.sync(dataDir);
});
- it('defaults to treating the latest model as current', () => {
+ it('defaults to treating the latest model as current', async () => {
const db = new ModelDB(dataDir);
+ await db.setup();
expect(db.modelId).toBe(modelId);
});
it('loads variables from JSON', async () => {
const db = new ModelDB(dataDir);
+ await db.setup();
const vars = await db.currentVars();
test_util.expectArraysClose(vars[0], [0, 0, 0, 0]);
test_util.expectArraysClose(vars[1], [1, 2, 3, 4]);
@@ -85,16 +86,19 @@ describe('ModelDB', () => {
it('updates the model using a weighted average', async () => {
const db = new ModelDB(dataDir);
+ await db.setup();
await db.update();
expect(db.modelId).not.toBe(modelId);
const newVars = await db.currentVars();
test_util.expectArraysClose(newVars[0], [0.4, -0.4, 0.6, -0.6]);
- test_util.expectArraysClose(newVars[1], [1.2, 2.8, 3.2, 4.8]);
+ test_util.expectArraysClose(newVars[1], [0.2, 0.8, 0.2, 0.8]);
});
it('only performs update after passing a threshold', async () => {
const db = new ModelDB(dataDir, 3);
- await db.possiblyUpdate();
+ await db.setup();
+ let updated = await db.possiblyUpdate();
+ expect(updated).toBe(false);
expect(db.modelId).toBe(modelId);
const oldUpdateFiles = await db.listUpdateFiles();
expect(oldUpdateFiles.length).toBe(2);
@@ -107,9 +111,10 @@ describe('ModelDB', () => {
{values: [0, 0, 0, 0], shape: [1, 4]}
]
}));
- await db.possiblyUpdate();
+ updated = await db.possiblyUpdate();
+ expect(updated).toBe(true);
expect(db.modelId).not.toBe(modelId);
const newUpdateFiles = await db.listUpdateFiles();
expect(newUpdateFiles.length).toBe(0);
});
-})
+});
diff --git a/src/server/server.ts b/src/server/server.ts
index 4e547b9..3b488e8 100644
--- a/src/server/server.ts
+++ b/src/server/server.ts
@@ -17,24 +17,36 @@
/** Server code */
+import './fetch_polyfill';
+
import * as express from 'express';
import {Request, Response} from 'express';
import * as http from 'http';
import * as path from 'path';
import * as socketIO from 'socket.io';
+import {SocketAPI} from './comm';
+import {ModelDB} from './model_db';
+
const app = express();
const server = http.createServer(app);
const io = socketIO(server);
+const indexPath = path.resolve(__dirname + '/../../demo/index.html');
+const dataDir = path.resolve(__dirname + '/../../data');
+const modelDB = new ModelDB(dataDir);
+const FIT_CONFIG = {
+ batchSize: 10
+};
+const socketAPI = new SocketAPI(modelDB, FIT_CONFIG, io);
app.get('/', (req: Request, res: Response) => {
- res.sendFile(path.resolve(__dirname + '/../demo/index.html'));
-});
-
-io.on('connection', (socket: socketIO.Socket) => {
- console.log('a user connected');
+ res.sendFile(indexPath);
});
-server.listen(3000, () => {
- console.log('listening on 3000');
+modelDB.setup().then(() => {
+ socketAPI.setup().then(() => {
+ server.listen(3000, () => {
+ console.log('listening on 3000');
+ });
+ });
});
diff --git a/src/types.ts b/src/types.ts
new file mode 100644
index 0000000..7189b88
--- /dev/null
+++ b/src/types.ts
@@ -0,0 +1,32 @@
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {Scalar, Tensor, Variable} from '@tensorflow/tfjs';
+import {LayerVariable} from '@tensorflow/tfjs-layers/dist/variables';
+
+export type LossFun = (inputs: Tensor, labels: Tensor) => Scalar;
+export type PredFun = (inputs: Tensor) => Tensor|Tensor[];
+export type VarList = Array<Tensor|Variable|LayerVariable>;
+export type ModelDict = {
+ vars: VarList,
+ loss: LossFun,
+ predict: PredFun
+};
+
+export interface FederatedModel {
+ setup(): Promise<ModelDict>;
+}
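
Any model exposing this interface can be plugged into the server. For illustration only, a hypothetical in-memory linear model (not part of this change; import path assumed) could look like:

```ts
import * as tf from '@tensorflow/tfjs';
import {Scalar, Tensor} from '@tensorflow/tfjs';
import {FederatedModel, ModelDict} from './src/index';

// Hypothetical FederatedModel: linear regression with mean-squared error.
class LinearModel implements FederatedModel {
  async setup(): Promise<ModelDict> {
    const w = tf.variable(tf.zeros([4, 1]));
    const b = tf.variable(tf.zeros([1]));
    const predict = (inputs: Tensor) => inputs.matMul(w).add(b);
    const loss = (inputs: Tensor, labels: Tensor) =>
        predict(inputs).sub(labels).square().mean() as Scalar;
    return {predict, vars: [w, b], loss};
  }
}
```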
diff --git a/tsconfig.json b/tsconfig.json
index 38db13c..68efb04 100644
--- a/tsconfig.json
+++ b/tsconfig.json
@@ -21,6 +21,7 @@
"experimentalDecorators": true
},
"include": [
- "src/"
+ "src/",
+ "demo/"
]
}
diff --git a/yarn.lock b/yarn.lock
index 41d4938..bd914c8 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -518,6 +518,10 @@ error-ex@^1.2.0:
dependencies:
is-arrayish "^0.2.1"
+es6-promise@^4.2.4:
+ version "4.2.4"
+ resolved "https://registry.yarnpkg.com/es6-promise/-/es6-promise-4.2.4.tgz#dc4221c2b16518760bd8c39a52d8f356fc00ed29"
+
escape-html@~1.0.3:
version "1.0.3"
resolved "https://registry.yarnpkg.com/escape-html/-/escape-html-1.0.3.tgz#0258eae4d3d0c0974de1c169188ef0051d1d1988"
@@ -950,6 +954,10 @@ node-emoji@^1.4.1:
dependencies:
lodash.toarray "^4.4.0"
+node-fetch@^2.1.2:
+ version "2.1.2"
+ resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.1.2.tgz#ab884e8e7e57e38a944753cec706f788d1768bb5"
+
node-notifier@^4.0.2:
version "4.6.1"
resolved "https://registry.yarnpkg.com/node-notifier/-/node-notifier-4.6.1.tgz#056d14244f3dcc1ceadfe68af9cff0c5473a33f3"