diff --git a/demo/mnist_transfer_learning_model.ts b/demo/mnist_transfer_learning_model.ts index 096be90..4f00912 100644 --- a/demo/mnist_transfer_learning_model.ts +++ b/demo/mnist_transfer_learning_model.ts @@ -17,7 +17,7 @@ import * as tf from '@tensorflow/tfjs' import {Scalar, Tensor} from '@tensorflow/tfjs'; -import {FederatedModel, ModelDict} from '../src/types'; +import {FederatedModel, ModelDict} from '../src/index'; // https://github.com/tensorflow/tfjs-examples/tree/master/mnist-transfer-cnn // tslint:disable-next-line:max-line-length @@ -35,7 +35,8 @@ export class MnistTransferLearningModel implements FederatedModel { const loss = (inputs: Tensor, labels: Tensor) => { const logits = model.predict(inputs) as Tensor; - return tf.losses.softmaxCrossEntropy(logits, labels).mean() as Scalar; + const losses = tf.losses.softmaxCrossEntropy(logits, labels); + return losses.mean() as Scalar; } return { diff --git a/src/comm_test.ts b/src/comm_test.ts index 20ff942..f2685c9 100644 --- a/src/comm_test.ts +++ b/src/comm_test.ts @@ -16,8 +16,6 @@ * ============================================================================= */ -// import * as tf from '@tensorflow/tfjs'; -// import {test_util} from '@tensorflow/tfjs-core'; import * as tf from '@tensorflow/tfjs'; import {test_util, Variable} from '@tensorflow/tfjs'; import * as fs from 'fs'; @@ -42,6 +40,18 @@ const initWeights = [tf.tensor([1, 1, 1, 1], [2, 2]), tf.tensor([1, 2, 3, 4], [1, 4])]; const updateThreshold = 2; +function waitUntil(done: () => boolean, then: () => void, timeout?: number) { + const moveOn = () => { + clearInterval(checkDone); + clearTimeout(moveOnAnyway); + then(); + }; + const moveOnAnyway = setTimeout(moveOn, timeout || 100); + const checkDone = setInterval(() => { + if (done()) moveOn(); + }, 1); +} + describe('Socket API', () => { let dataDir: string; let modelDir: string; @@ -115,20 +125,14 @@ describe('Socket API', () => { clientAPI.numExamples = 3; await clientAPI.uploadVars(); - const timeout = 100; - let elapsed = 0; - const interval = setInterval(() => { - elapsed += 1; - if (elapsed > timeout || clientAPI.modelId != modelId) { - clearInterval(interval); - test_util.expectArraysClose( - clientVars[0], tf.tensor([1.25, 1.25, 1.25, 1.25], [2, 2])) - test_util.expectArraysClose( - clientVars[1], tf.tensor([3.25, 2.75, 2.25, 1.75], [1, 4])) - expect(clientAPI.numExamples).toBe(0); - expect(clientAPI.modelId).toBe(modelDB.modelId); - done(); - } - }, 1); + waitUntil(() => clientAPI.modelId != modelId, () => { + test_util.expectArraysClose( + clientVars[0], tf.tensor([1.25, 1.25, 1.25, 1.25], [2, 2])) + test_util.expectArraysClose( + clientVars[1], tf.tensor([3.25, 2.75, 2.25, 1.75], [1, 4])) + expect(clientAPI.numExamples).toBe(0); + expect(clientAPI.modelId).toBe(modelDB.modelId); + done(); + }); }) }); diff --git a/src/index.ts b/src/index.ts new file mode 100644 index 0000000..a858730 --- /dev/null +++ b/src/index.ts @@ -0,0 +1,18 @@ +/** + * * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +export * from './types';