diff --git a/demo/index.html b/demo/index.html
index d384003..018fef6 100644
--- a/demo/index.html
+++ b/demo/index.html
@@ -3,6 +3,9 @@
diff --git a/demo/mnist_transfer_learning_model.ts b/demo/mnist_transfer_learning_model.ts
new file mode 100644
index 0000000..1e93b40
--- /dev/null
+++ b/demo/mnist_transfer_learning_model.ts
@@ -0,0 +1,43 @@
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import * as tf from '@tensorflow/tfjs';
+import {Scalar, Tensor} from '@tensorflow/tfjs';
+import {FederatedModel, ModelDict} from '../src/index';
+
+// https://github.com/tensorflow/tfjs-examples/tree/master/mnist-transfer-cnn
+const mnistTransferLearningModelURL =
+    // tslint:disable-next-line:max-line-length
+    'https://storage.googleapis.com/tfjs-models/tfjs/mnist_transfer_cnn_v1/model.json';
+
+export class MnistTransferLearningModel implements FederatedModel {
+  async setup(): Promise<ModelDict> {
+    const model = await tf.loadModel(mnistTransferLearningModelURL);
+
+    for (let i = 0; i < 7; ++i) {
+      model.layers[i].trainable = false;  // freeze conv layers
+    }
+
+    const loss = (inputs: Tensor, labels: Tensor) => {
+      const logits = model.predict(inputs) as Tensor;
+      const losses = tf.losses.softmaxCrossEntropy(logits, labels);
+      return losses.mean() as Scalar;
+    };
+
+    return {predict: model.predict, vars: model.trainableWeights, loss};
+  }
+}
diff --git a/package.json b/package.json
index 420528a..2da4fe0 100644
--- a/package.json
+++ b/package.json
@@ -9,9 +9,9 @@
   },
   "license": "Apache-2.0",
   "scripts": {
-    "start": "node dist/server/server.js",
+    "start": "node dist/src/server/server.js",
     "dev": "ts-node src/server/server.ts",
-    "build": "tsc && cp -r demo/ dist/demo/",
+    "build": "tsc && cp demo/index.html dist/demo/",
     "test": "ts-node node_modules/jasmine/bin/jasmine --config=jasmine.json",
     "lint": "tslint -p . -t verbose"
   },
@@ -20,7 +20,9 @@
     "@types/express": "^4.16.0",
     "@types/socket.io": "^1.4.34",
     "@types/socket.io-client": "^1.4.32",
+    "es6-promise": "^4.2.4",
     "express": "^4.16.3",
+    "node-fetch": "^2.1.2",
     "socket.io": "^2.1.1"
   },
   "devDependencies": {
@@ -33,8 +35,8 @@
     "jasmine-core": "~3.1.0",
     "rimraf": "^2.6.2",
     "ts-node": "^6.1.0",
+    "ts-node-dev": "^1.0.0-pre.26",
     "tslint": "~5.8.0",
-    "typescript": "2.7.2",
-    "ts-node-dev": "^1.0.0-pre.26"
+    "typescript": "2.7.2"
   }
 }
diff --git a/src/client/comm.ts b/src/client/comm.ts
index 7818749..96e6bef 100644
--- a/src/client/comm.ts
+++ b/src/client/comm.ts
@@ -17,13 +17,14 @@
 
 import * as tf from '@tensorflow/tfjs';
 import {ModelFitConfig, Variable} from '@tensorflow/tfjs';
+import {assert} from '@tensorflow/tfjs-core/dist/util';
 import {Layer} from '@tensorflow/tfjs-layers/dist/engine/topology';
 import {LayerVariable} from '@tensorflow/tfjs-layers/dist/variables';
 import * as socketio from 'socket.io-client';
 
-import {ConnectionMsg, Events, VarsMsg} from '../common';
+import {DownloadMsg, Events, UploadMsg} from '../common';
 // tslint:disable-next-line:max-line-length
-import {deserializeVar, SerializedVariable, serializeVar} from '../serialization';
+import {deserializeVar, SerializedVariable, serializeVars} from '../serialization';
 
 const CONNECTION_TIMEOUT = 10 * 1000;
 const UPLOAD_TIMEOUT = 1 * 1000;
@@ -53,20 +54,19 @@ const UPLOAD_TIMEOUT = 1 * 1000;
  */
 export class VariableSynchroniser {
   public modelId: string;
+  public numExamples: number;
+  public fitConfig: ModelFitConfig;
   private socket: SocketIOClient.Socket;
-  private connMsg: ConnectionMsg;
-  private vars = new Map<string, Variable|LayerVariable>();
-  private acceptUpdate: (msg: VarsMsg) => boolean;
+  private vars: Array<Variable|LayerVariable>;
+  private acceptUpdate: (msg: DownloadMsg) => boolean;
   /**
    * Construct a synchroniser from a list of tf.Variables of tf.LayerVariables.
    * @param {Array} vars - Variables to track and sync
   */
  constructor(
      vars: Array<Variable|LayerVariable>,
-      updateCallback?: (msg: VarsMsg) => boolean) {
-    for (const variable of vars) {
-      this.vars.set(variable.name, variable);
-    }
+      updateCallback?: (msg: DownloadMsg) => boolean) {
+    this.vars = vars;
     if (updateCallback) {
       this.acceptUpdate = updateCallback;
     } else {
@@ -84,10 +84,10 @@ export class VariableSynchroniser {
     return new VariableSynchroniser(tf.util.flatten(layerWeights, []));
   }
 
-  private async connect(url: string): Promise<ConnectionMsg> {
+  private async connect(url: string): Promise<DownloadMsg> {
     this.socket = socketio(url);
-    return fromEvent<ConnectionMsg>(
-        this.socket, Events.Initialise, CONNECTION_TIMEOUT);
+    return fromEvent<DownloadMsg>(
+        this.socket, Events.Download, CONNECTION_TIMEOUT);
   }
 
   /**
@@ -98,18 +98,22 @@
    * and variables set to their inital values.
    */
   public async initialise(url: string): Promise<ModelFitConfig> {
-    this.connMsg = await this.connect(url);
-    this.setVarsFromMessage(this.connMsg.initVars);
-    this.modelId = this.connMsg.modelId;
+    const connMsg = await this.connect(url);
+    this.setVarsFromMessage(connMsg.vars);
+    this.modelId = connMsg.modelId;
+    this.fitConfig = connMsg.fitConfig;
+    this.numExamples = 0;
 
-    this.socket.on(Events.Download, (msg: VarsMsg) => {
+    this.socket.on(Events.Download, (msg: DownloadMsg) => {
       if (this.acceptUpdate(msg)) {
         this.setVarsFromMessage(msg.vars);
         this.modelId = msg.modelId;
+        this.fitConfig = msg.fitConfig;
+        this.numExamples = 0;
       }
     });
 
-    return this.connMsg.fitConfig;
+    return this.fitConfig;
   }
 
   /**
@@ -117,7 +121,7 @@
    * @return A promise that resolves when the server has recieved the variables
    */
   public async uploadVars(): Promise<{}> {
-    const msg: VarsMsg = await this.serializeCurrentVars();
+    const msg: UploadMsg = await this.serializeCurrentVars();
     const prom = new Promise((resolve, reject) => {
       const rejectTimer =
           setTimeout(() => reject(`uploadVars timed out`), UPLOAD_TIMEOUT);
@@ -130,31 +134,26 @@
     return prom;
   }
 
-  protected async serializeCurrentVars(): Promise<VarsMsg> {
-    const varsP: Array<Promise<SerializedVariable>> = [];
+  protected async serializeCurrentVars(): Promise<UploadMsg> {
+    assert(this.numExamples > 0, 'should only serialize if we\'ve seen data');
 
-    this.vars.forEach((value, key) => {
-      if (value instanceof LayerVariable) {
-        varsP.push(serializeVar(tf.variable(value.read())));
-      } else {
-        varsP.push(serializeVar(value));
-      }
-    });
-    const vars = await Promise.all(varsP);
-    return {clientId: this.connMsg.clientId, modelId: this.modelId, vars};
+    const vars = await serializeVars(this.vars);
+
+    return {
+      numExamples: this.numExamples, /* TODO: ensure this gets updated */
+      modelId: this.modelId,
+      vars
+    };
   }
 
   protected setVarsFromMessage(newVars: SerializedVariable[]) {
-    for (const param of newVars) {
-      if (!this.vars.has(param.name)) {
-        throw new Error(`Recieved message with unexpected param ${
-            param.name}, should be one of ${this.vars.keys()}`);
-      }
-      const varOrLVar = this.vars.get(param.name);
+    for (let i = 0; i < newVars.length; i++) {
+      const newVar = newVars[i];
+      const varOrLVar = this.vars[i];
       if (varOrLVar instanceof LayerVariable) {
-        varOrLVar.write(deserializeVar(param));
+        varOrLVar.write(deserializeVar(newVar));
       } else {
-        varOrLVar.assign(deserializeVar(param));
+        varOrLVar.assign(deserializeVar(newVar));
       }
     }
   }
diff --git a/src/comm_test.ts b/src/comm_test.ts
new file mode 100644
index 0000000..dfd3a03
--- /dev/null
+++ b/src/comm_test.ts
@@ -0,0 +1,138 @@
+
+/**
+ * * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import * as tf from '@tensorflow/tfjs';
+import {test_util, Variable} from '@tensorflow/tfjs';
+import * as fs from 'fs';
+import * as http from 'http';
+import * as path from 'path';
+import * as rimraf from 'rimraf';
+import * as serverSocket from 'socket.io';
+
+import {VariableSynchroniser} from './client/comm';
+import {tensorToJson} from './serialization';
+import {SocketAPI} from './server/comm';
+import {ModelDB} from './server/model_db';
+
+const modelId = '1528400733553';
+const batchSize = 42;
+const FIT_CONFIG = {batchSize};
+const PORT = 3000;
+const socketURL = `http://0.0.0.0:${PORT}`;
+const initWeights =
+    [tf.tensor([1, 1, 1, 1], [2, 2]), tf.tensor([1, 2, 3, 4], [1, 4])];
+const updateThreshold = 2;
+
+function waitUntil(done: () => boolean, then: () => void, timeout?: number) {
+  const moveOn = () => {
+    clearInterval(moveOnIfDone);
+    clearTimeout(moveOnAnyway);
+    then();
+  };
+  const moveOnAnyway = setTimeout(moveOn, timeout || 100);
+  const moveOnIfDone = setInterval(() => {
+    if (done()) {
+      moveOn();
+    }
+  }, 1);
+}
+
+describe('Socket API', () => {
+  let dataDir: string;
+  let modelDir: string;
+  let modelPath: string;
+  let modelDB: ModelDB;
+  let serverAPI: SocketAPI;
+  let clientAPI: VariableSynchroniser;
+  let clientVars: Variable[];
+  let httpServer: http.Server;
+
+  beforeEach(async () => {
+    // Set up model database with our initial weights
+    dataDir = fs.mkdtempSync('/tmp/modeldb_test');
+    modelDir = path.join(dataDir, modelId);
+    modelPath = path.join(dataDir, modelId + '.json');
+    fs.mkdirSync(modelDir);
+    const modelJSON = await Promise.all(initWeights.map(tensorToJson));
+    fs.writeFileSync(modelPath, JSON.stringify({'vars': modelJSON}));
+    modelDB = new ModelDB(dataDir, updateThreshold);
+    await modelDB.setup();
+
+    // Set up the server exposing our upload/download API
+    httpServer = http.createServer();
+    serverAPI = new SocketAPI(modelDB, FIT_CONFIG, serverSocket(httpServer));
+    await serverAPI.setup();
+    await httpServer.listen(PORT);
+
+    // Set up the API client with zeroed out weights
+    clientVars = initWeights.map(t => tf.variable(tf.zerosLike(t)));
+    clientAPI = new VariableSynchroniser(clientVars);
+    await clientAPI.initialise(socketURL);
+  });
+
+  afterEach(async () => {
+    rimraf.sync(dataDir);
+    await httpServer.close();
+  });
+
+  it('transmits fit config on startup', () => {
+    expect(clientAPI.fitConfig.batchSize).toBe(batchSize);
+  });
+
+  it('transmits model version on startup', () => {
+    expect(clientAPI.modelId).toBe(modelId);
+  });
+
+  it('transmits model parameters on startup', () => {
+    test_util.expectArraysClose(clientVars[0], initWeights[0]);
+    test_util.expectArraysClose(clientVars[1], initWeights[1]);
+  });
+
+  it('transmits updates', async () => {
+    let updateFiles = await modelDB.listUpdateFiles();
+    expect(updateFiles.length).toBe(0);
+
+    clientVars[0].assign(tf.tensor([2, 2, 2, 2], [2, 2]));
+    clientAPI.numExamples = 1;
+    await clientAPI.uploadVars();
+
+    updateFiles = await modelDB.listUpdateFiles();
+    expect(updateFiles.length).toBe(1);
+  });
+
+  it('triggers a download after enough uploads', async (done) => {
+    clientVars[0].assign(tf.tensor([2, 2, 2, 2], [2, 2]));
+    clientAPI.numExamples = 1;
+    await clientAPI.uploadVars();
+
+    clientVars[0].assign(tf.tensor([1, 1, 1, 1], [2, 2]));
+    clientVars[1].assign(tf.tensor([4, 3, 2, 1], [1, 4]));
+    clientAPI.numExamples = 3;
+    await clientAPI.uploadVars();
+
+    waitUntil(() => clientAPI.modelId !== modelId, () => {
+      test_util.expectArraysClose(
+          clientVars[0], tf.tensor([1.25, 1.25, 1.25, 1.25], [2, 2]));
+      test_util.expectArraysClose(
+          clientVars[1], tf.tensor([3.25, 2.75, 2.25, 1.75], [1, 4]));
+      expect(clientAPI.numExamples).toBe(0);
+      expect(clientAPI.modelId).toBe(modelDB.modelId);
+      done();
+    });
+  });
+});
diff --git a/src/common.ts b/src/common.ts
index 273b7dd..a7856d6 100644
--- a/src/common.ts
+++ b/src/common.ts
@@ -20,25 +20,18 @@ import {ModelFitConfig} from '@tensorflow/tfjs';
 import {SerializedVariable} from './serialization';
 
 export enum Events {
-  Initialise = 'initialise',
   Download = 'downloadVars',
   Upload = 'uploadVars'
 }
 
-export type TrainingInfo = {
-  nSteps: number
-};
-
-export type VarsMsg = {
+export type UploadMsg = {
   modelId: string,
-  clientId: string,
   vars: SerializedVariable[],
-  history?: TrainingInfo
+  numExamples: number
 };
 
-export type ConnectionMsg = {
-  clientId: string,
-  fitConfig: ModelFitConfig,
+export type DownloadMsg = {
   modelId: string,
-  initVars: SerializedVariable[]
+  vars: SerializedVariable[],
+  fitConfig: ModelFitConfig,
 };
diff --git a/src/index.ts b/src/index.ts
new file mode 100644
index 0000000..a858730
--- /dev/null
+++ b/src/index.ts
@@ -0,0 +1,18 @@
+/**
+ * * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+export * from './types';
diff --git a/src/model.ts b/src/model.ts
new file mode 100644
index 0000000..fbd418b
--- /dev/null
+++ b/src/model.ts
@@ -0,0 +1,22 @@
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+// tslint:disable-next-line:max-line-length
+import {MnistTransferLearningModel} from '../demo/mnist_transfer_learning_model';
+
+// TODO: some kind of flag to determine what model we use
+export class Model extends MnistTransferLearningModel {}
diff --git a/src/serialization.ts b/src/serialization.ts
index 8834863..00bf0ca 100644
--- a/src/serialization.ts
+++ b/src/serialization.ts
@@ -1,26 +1,33 @@
 import * as tf from '@tensorflow/tfjs';
+import {Tensor, Variable} from '@tensorflow/tfjs';
+import {LayerVariable} from '@tensorflow/tfjs-layers/dist/variables';
 
 export type SerializedVariable = {
   dtype: tf.DataType,
   shape: number[],
-  name: string,
-  data: ArrayBuffer,
-  trainable: boolean
+  data: ArrayBuffer
 };
 
-export async function serializeVar(variable: tf.Variable):
+export async function serializeVar(variable: tf.Tensor):
     Promise<SerializedVariable> {
   const data = await variable.data();
   // small TypedArrays are views into a larger buffer
   const copy =
       data.buffer.slice(data.byteOffset, data.byteOffset + data.byteLength);
-  return {
-    dtype: variable.dtype,
-    shape: variable.shape.slice(),
-    name: variable.name,
-    trainable: variable.trainable,
-    data: copy
-  };
+  return {dtype: variable.dtype, shape: variable.shape.slice(), data: copy};
+}
+
+export async function serializeVars(
+    vars: Array<Variable|LayerVariable|Tensor>) {
+  const varsP: Array<Promise<SerializedVariable>> = [];
+  vars.forEach((value, key) => {
+    if (value instanceof LayerVariable) {
+      varsP.push(serializeVar(tf.variable(value.read())));
+    } else {
+      varsP.push(serializeVar(value));
+    }
+  });
+  return Promise.all(varsP);
 }
 
 export function deserializeVar(serialized: SerializedVariable): tf.Tensor {
@@ -43,6 +50,37 @@
   return tf.tensor(array, shape, dtype);
 }
 
+export type TensorJson = {
+  values: number[],
+  shape: number[],
+  dtype?: tf.DataType
+};
+
+export async function tensorToJson(t: tf.Tensor): Promise<TensorJson> {
+  let data;
+  if (t instanceof LayerVariable) {
+    data = await t.read().data();
+  } else {
+    data = await t.data();
+  }
+  // Note: could make this async / use base64 encoding on the buffer data
+  return {'values': Array.from(data), 'shape': t.shape, 'dtype': t.dtype};
+}
+
+export function jsonToTensor(j: TensorJson): tf.Tensor {
+  return tf.tensor(j.values, j.shape, j.dtype || 'float32');
+}
+
+export async function serializedToJson(s: SerializedVariable):
+    Promise<TensorJson> {
+  return tensorToJson(deserializeVar(s));
+}
+
+export async function jsonToSerialized(j: TensorJson):
+    Promise<SerializedVariable> {
+  return serializeVar(jsonToTensor(j));
+}
+
 const dtypeToTypedArrayCtor = {
   'float32': Float32Array,
   'int32': Int32Array,
diff --git a/src/serialization_test.ts b/src/serialization_test.ts
new file mode 100644
index 0000000..071df66
--- /dev/null
+++ b/src/serialization_test.ts
@@ -0,0 +1,58 @@
+/**
+ * * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import * as tf from '@tensorflow/tfjs';
+import {test_util} from '@tensorflow/tfjs-core';
+
+import * as ser from './serialization';
+
+describe('serialization', () => {
+  const floatTensor =
+      tf.tensor3d([[[1.1, 2.2], [3.3, 4.4]], [[5.5, 6.6], [7.7, 8.8]]]);
+  const boolTensor = tf.tensor1d([true, false], 'bool');
+  const intTensor = tf.tensor2d([[1, 2], [3, 4]], [2, 2], 'int32');
+
+  it('converts back and forth to JSON', async () => {
+    const floatJSON = await ser.tensorToJson(floatTensor);
+    const boolJSON = await ser.tensorToJson(boolTensor);
+    const intJSON = await ser.tensorToJson(intTensor);
+    const floatTensor2 = ser.jsonToTensor(floatJSON);
+    const boolTensor2 = ser.jsonToTensor(boolJSON);
+    const intTensor2 = ser.jsonToTensor(intJSON);
+    test_util.expectArraysClose(floatTensor, floatTensor2);
+    test_util.expectArraysClose(boolTensor, boolTensor2);
+    test_util.expectArraysClose(intTensor, intTensor2);
+  });
+
+  it('converts back and forth to SerializedVar', async () => {
+    const floatSerial = await ser.serializeVar(floatTensor);
+    const boolSerial = await ser.serializeVar(boolTensor);
+    const intSerial = await ser.serializeVar(intTensor);
+    const floatTensor2 = ser.deserializeVar(floatSerial);
+    const boolTensor2 = ser.deserializeVar(boolSerial);
+    const intTensor2 = ser.deserializeVar(intSerial);
+    test_util.expectArraysClose(floatTensor, floatTensor2);
+    test_util.expectArraysClose(boolTensor, boolTensor2);
+    test_util.expectArraysClose(intTensor, intTensor2);
+  });
+
+  it('works for an arbitrary chain', async () => {
+    const floatTensor2 = ser.jsonToTensor(await ser.serializedToJson(
+        await ser.jsonToSerialized(await ser.tensorToJson(floatTensor))));
+    test_util.expectArraysClose(floatTensor2, floatTensor);
+  });
+});
diff --git a/src/server/comm.ts b/src/server/comm.ts
new file mode 100644
index 0000000..3a61a1a
--- /dev/null
+++ b/src/server/comm.ts
@@ -0,0 +1,90 @@
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {ModelFitConfig} from '@tensorflow/tfjs';
+import * as fs from 'fs';
+import * as path from 'path';
+import {Server, Socket} from 'socket.io';
+import {promisify} from 'util';
+import * as uuid from 'uuid/v4';
+
+import {DownloadMsg, Events, UploadMsg} from '../common';
+import {serializedToJson, serializeVar} from '../serialization';
+
+import {ModelDB} from './model_db';
+
+const writeFile = promisify(fs.writeFile);
+
+export class SocketAPI {
+  modelDB: ModelDB;
+  fitConfig: ModelFitConfig;
+  io: Server;
+
+  constructor(modelDB: ModelDB, fitConfig: ModelFitConfig, io: Server) {
+    this.modelDB = modelDB;
+    this.fitConfig = fitConfig;
+    this.io = io;
+  }
+
+  async downloadMsg(): Promise<DownloadMsg> {
+    const varsJson = await this.modelDB.currentVars();
+    const varsSeri = await Promise.all(varsJson.map(serializeVar));
+    return {
+      fitConfig: this.fitConfig,
+      modelId: this.modelDB.modelId,
+      vars: varsSeri
+    };
+  }
+
+  async setup() {
+    this.io.on('connection', async (socket: Socket) => {
+      // Send current variables to newly connected client
+      const initVars = await this.downloadMsg();
+      socket.emit(Events.Download, initVars);
+
+      // When a client sends us updated weights
+      socket.on(Events.Upload, async (msg: UploadMsg, ack) => {
+        // Save them to a file
+        const modelId = msg.modelId;
+        const updateId = uuid();
+        const updatePath =
+            path.join(this.modelDB.dataDir, modelId, updateId + '.json');
+        const updatedVars = await Promise.all(msg.vars.map(serializedToJson));
+        const updateJSON = JSON.stringify({
+          clientId: socket.client.id,
+          modelId,
+          numExamples: msg.numExamples,
+          vars: updatedVars
+        });
+        await writeFile(updatePath, updateJSON);
+
+        // Let them know we're done saving
+        ack(true);
+
+        // Potentially update the model (asynchronously)
+        if (modelId === this.modelDB.modelId) {
+          const updated = await this.modelDB.possiblyUpdate();
+          if (updated) {
+            // Send new variables to all clients if we updated
+            const newVars = await this.downloadMsg();
+            this.io.sockets.emit(Events.Download, newVars);
+          }
+        }
+      });
+    });
+  }
+}
diff --git a/src/server/fetch_polyfill.ts b/src/server/fetch_polyfill.ts
new file mode 100644
index 0000000..26324fa
--- /dev/null
+++ b/src/server/fetch_polyfill.ts
@@ -0,0 +1,30 @@
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import * as es6Promise from 'es6-promise';
+es6Promise.polyfill();
+
+eval(`
+var realFetch = require('node-fetch');
+
+if (!global.fetch) {
+  global.fetch = realFetch;
+  global.Response = realFetch.Response;
+  global.Headers = realFetch.Headers;
+  global.Request = realFetch.Request;
+}
+`);
diff --git a/src/server/model_db.ts b/src/server/model_db.ts
index 146f23d..91777bd 100644
--- a/src/server/model_db.ts
+++ b/src/server/model_db.ts
@@ -20,37 +20,28 @@ import * as fs from 'fs';
 import * as path from 'path';
 import {promisify} from 'util';
 
+import {Model} from '../model';
+import {jsonToTensor, TensorJson, tensorToJson} from '../serialization';
+
 const DEFAULT_MIN_UPDATES = 10;
 const mkdir = promisify(fs.mkdir);
+const exists = promisify(fs.exists);
 const readdir = promisify(fs.readdir);
 const readFile = promisify(fs.readFile);
 const writeFile = promisify(fs.writeFile);
 
-function getLatestId(dir: string) {
-  const files = fs.readdirSync(dir);
-  return files.reduce((acc, val) => {
-    if (val.endsWith('.json') && val.slice(0, -5) > acc) {
-      return val.slice(0, -5);
-    } else {
-      return acc;
+async function getLatestId(dir: string) {
+  const files = await readdir(dir);
+  let latestId: string = null;
+  files.forEach((name) => {
+    if (name.endsWith('.json')) {
+      const id = name.slice(0, -5);
+      if (latestId == null || id > latestId) {
+        latestId = id;
+      }
     }
-  }, '0');
-}
-
-type TensorJSON = {
-  values: number[],
-  shape: number[],
-  dtype?: tf.DataType
-};
-
-function dumpTensor(t: tf.Tensor) {
-  return {
-    'values': Array.from(t.dataSync()), 'shape': t.shape, 'dtype': t.dtype
-  }
-}
-
-function loadTensor(obj: TensorJSON) {
-  return tf.tensor(obj.values, obj.shape, obj.dtype || 'float32');
+  });
+  return latestId;
 }
 
 function generateNewId() {
@@ -68,13 +59,27 @@
   updating: boolean;
   minUpdates: number;
 
-  constructor(dataDir: string, minUpdates?: number, currentModelId?: string) {
+  constructor(dataDir: string, minUpdates?: number) {
     this.dataDir = dataDir;
-    this.modelId = currentModelId || getLatestId(dataDir);
+    this.modelId = null;
     this.updating = false;
     this.minUpdates = minUpdates || DEFAULT_MIN_UPDATES;
   }
 
+  async setup() {
+    const dirExists = await exists(this.dataDir);
+    if (!dirExists) {
+      await mkdir(this.dataDir);
+    }
+
+    this.modelId = await getLatestId(this.dataDir);
+    if (this.modelId == null) {
+      const model = new Model();
+      const dict = await model.setup();
+      await this.writeNewVars(dict.vars as tf.Tensor[]);
+    }
+  }
+
   async listUpdateFiles(): Promise<string[]> {
     const files = await readdir(path.join(this.dataDir, this.modelId));
     return files.map((f) => {
@@ -85,21 +90,23 @@
   async currentVars(): Promise<tf.Tensor[]> {
     const file = path.join(this.dataDir, this.modelId + '.json');
     const json = await readJSON(file);
-    return json['vars'].map(loadTensor);
+    return json['vars'].map(jsonToTensor);
   }
 
-  async possiblyUpdate() {
+  async possiblyUpdate(): Promise<boolean> {
     const updateFiles = await this.listUpdateFiles();
     if (updateFiles.length < this.minUpdates || this.updating) {
-      return;
+      return false;
     }
 
    this.updating = true;
    await this.update();
    this.updating = false;
+    return true;
  }
 
  async update() {
-    const updatedVars = await this.currentVars();
+    const currentVars = await this.currentVars();
+    const updatedVars = currentVars.map(v => tf.zerosLike(v));
     const updateFiles = await this.listUpdateFiles();
     const updatesJSON = await Promise.all(updateFiles.map(readJSON));
@@ -114,17 +121,22 @@
     updatesJSON.forEach((u) => {
       const nk = tf.scalar(u['numExamples']);
       const frac = nk.div(n);
-      u['vars'].forEach((v: TensorJSON, i: number) => {
-        const update = loadTensor(v).mul(frac);
+      u['vars'].forEach((v: TensorJson, i: number) => {
+        const update = jsonToTensor(v).mul(frac);
         updatedVars[i] = updatedVars[i].add(update);
       });
     });
 
     // Save results and update key
+    await this.writeNewVars(updatedVars);
+  }
+
+  async writeNewVars(newVars: tf.Tensor[]) {
     const newModelId = generateNewId();
     const newModelDir = path.join(this.dataDir, newModelId);
     const newModelPath = path.join(this.dataDir, newModelId + '.json');
-    const newModelJSON = JSON.stringify({'vars': updatedVars.map(dumpTensor)});
+    const newVarsJSON = await Promise.all(newVars.map(tensorToJson));
+    const newModelJSON = JSON.stringify({'vars': newVarsJSON});
     await writeFile(newModelPath, newModelJSON);
     await mkdir(newModelDir);
     this.modelId = newModelId;
diff --git a/src/server/model_db_test.ts b/src/server/model_db_test.ts
index f747fdd..80393f3 100644
--- a/src/server/model_db_test.ts
+++ b/src/server/model_db_test.ts
@@ -15,7 +15,6 @@
  * =============================================================================
  */
 
-// tslint:disable-next-line:max-line-length
 import {test_util} from '@tensorflow/tfjs-core';
 import * as fs from 'fs';
 import * as path from 'path';
@@ -67,13 +66,15 @@
     rimraf.sync(dataDir);
   });
 
-  it('defaults to treating the latest model as current', () => {
+  it('defaults to treating the latest model as current', async () => {
     const db = new ModelDB(dataDir);
+    await db.setup();
     expect(db.modelId).toBe(modelId);
   });
 
   it('loads variables from JSON', async () => {
     const db = new ModelDB(dataDir);
+    await db.setup();
     const vars = await db.currentVars();
     test_util.expectArraysClose(vars[0], [0, 0, 0, 0]);
     test_util.expectArraysClose(vars[1], [1, 2, 3, 4]);
@@ -85,16 +86,19 @@
 
   it('updates the model using a weighted average', async () => {
     const db = new ModelDB(dataDir);
+    await db.setup();
     await db.update();
     expect(db.modelId).not.toBe(modelId);
     const newVars = await db.currentVars();
     test_util.expectArraysClose(newVars[0], [0.4, -0.4, 0.6, -0.6]);
-    test_util.expectArraysClose(newVars[1], [1.2, 2.8, 3.2, 4.8]);
+    test_util.expectArraysClose(newVars[1], [0.2, 0.8, 0.2, 0.8]);
   });
 
   it('only performs update after passing a threshold', async () => {
     const db = new ModelDB(dataDir, 3);
-    await db.possiblyUpdate();
+    await db.setup();
+    let updated = await db.possiblyUpdate();
+    expect(updated).toBe(false);
     expect(db.modelId).toBe(modelId);
     const oldUpdateFiles = await db.listUpdateFiles();
     expect(oldUpdateFiles.length).toBe(2);
@@ -107,9 +111,10 @@
         {values: [0, 0, 0, 0], shape: [1, 4]}
       ]
     }));
-    await db.possiblyUpdate();
+    updated = await db.possiblyUpdate();
+    expect(updated).toBe(true);
     expect(db.modelId).not.toBe(modelId);
     const newUpdateFiles = await db.listUpdateFiles();
     expect(newUpdateFiles.length).toBe(0);
   });
-})
+});
diff --git a/src/server/server.ts b/src/server/server.ts
index 4e547b9..3b488e8 100644
--- a/src/server/server.ts
+++ b/src/server/server.ts
@@ -17,24 +17,36 @@
 
 /** Server code */
 
+import './fetch_polyfill';
+
 import * as express from 'express';
 import {Request, Response} from 'express';
 import * as http from 'http';
 import * as path from 'path';
 import * as socketIO from 'socket.io';
 
+import {SocketAPI} from './comm';
+import {ModelDB} from './model_db';
+
 const app = express();
 const server = http.createServer(app);
 const io = socketIO(server);
+const indexPath = path.resolve(__dirname + '/../../demo/index.html');
+const dataDir = path.resolve(__dirname + '/../../data');
+const modelDB = new ModelDB(dataDir);
+const FIT_CONFIG = {
+  batchSize: 10
+};
+const socketAPI = new SocketAPI(modelDB, FIT_CONFIG, io);
 
 app.get('/', (req: Request, res: Response) => {
-  res.sendFile(path.resolve(__dirname + '/../demo/index.html'));
-});
-
-io.on('connection', (socket: socketIO.Socket) => {
-  console.log('a user connected');
+  res.sendFile(indexPath);
 });
 
-server.listen(3000, () => {
-  console.log('listening on 3000');
+modelDB.setup().then(() => {
+  socketAPI.setup().then(() => {
+    server.listen(3000, () => {
+      console.log('listening on 3000');
+    });
+  });
 });
diff --git a/src/types.ts b/src/types.ts
new file mode 100644
index 0000000..7189b88
--- /dev/null
+++ b/src/types.ts
@@ -0,0 +1,32 @@
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {Scalar, Tensor, Variable} from '@tensorflow/tfjs';
+import {LayerVariable} from '@tensorflow/tfjs-layers/dist/variables';
+
+export type LossFun = (inputs: Tensor, labels: Tensor) => Scalar;
+export type PredFun = (inputs: Tensor) => Tensor|Tensor[];
+export type VarList = Array<Tensor|Variable|LayerVariable>;
+export type ModelDict = {
+  vars: VarList,
+  loss: LossFun,
+  predict: PredFun
+};
+
+export interface FederatedModel {
+  setup(): Promise<ModelDict>;
+}
diff --git a/tsconfig.json b/tsconfig.json
index 38db13c..68efb04 100644
--- a/tsconfig.json
+++ b/tsconfig.json
@@ -21,6 +21,7 @@
     "experimentalDecorators": true
   },
   "include": [
-    "src/"
+    "src/",
+    "demo/"
   ]
 }
diff --git a/yarn.lock b/yarn.lock
index 41d4938..bd914c8 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -518,6 +518,10 @@ error-ex@^1.2.0:
   dependencies:
     is-arrayish "^0.2.1"
 
+es6-promise@^4.2.4:
+  version "4.2.4"
+  resolved "https://registry.yarnpkg.com/es6-promise/-/es6-promise-4.2.4.tgz#dc4221c2b16518760bd8c39a52d8f356fc00ed29"
+
 escape-html@~1.0.3:
   version "1.0.3"
   resolved "https://registry.yarnpkg.com/escape-html/-/escape-html-1.0.3.tgz#0258eae4d3d0c0974de1c169188ef0051d1d1988"
@@ -950,6 +954,10 @@ node-emoji@^1.4.1:
   dependencies:
     lodash.toarray "^4.4.0"
 
+node-fetch@^2.1.2:
+  version "2.1.2"
+  resolved "https://registry.yarnpkg.com/node-fetch/-/node-fetch-2.1.2.tgz#ab884e8e7e57e38a944753cec706f788d1768bb5"
+
 node-notifier@^4.0.2:
   version "4.6.1"
   resolved "https://registry.yarnpkg.com/node-notifier/-/node-notifier-4.6.1.tgz#056d14244f3dcc1ceadfe68af9cff0c5473a33f3"