Skip to content

Commit

Permalink
refactors based on feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Ross committed Jun 12, 2018
1 parent d262b77 commit 1df8b70
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 19 deletions.
5 changes: 3 additions & 2 deletions demo/mnist_transfer_learning_model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import * as tf from '@tensorflow/tfjs'
import {Scalar, Tensor} from '@tensorflow/tfjs';
import {FederatedModel, ModelDict} from '../src/types';
import {FederatedModel, ModelDict} from '../src/index';

// https://github.com/tensorflow/tfjs-examples/tree/master/mnist-transfer-cnn
// tslint:disable-next-line:max-line-length
Expand All @@ -35,7 +35,8 @@ export class MnistTransferLearningModel implements FederatedModel {
const loss =
(inputs: Tensor, labels: Tensor) => {
const logits = model.predict(inputs) as Tensor;
return tf.losses.softmaxCrossEntropy(logits, labels).mean() as Scalar;
const losses = tf.losses.softmaxCrossEntropy(logits, labels);
return losses.mean() as Scalar;
}

return {
Expand Down
38 changes: 21 additions & 17 deletions src/comm_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
* =============================================================================
*/

// import * as tf from '@tensorflow/tfjs';
// import {test_util} from '@tensorflow/tfjs-core';
import * as tf from '@tensorflow/tfjs';
import {test_util, Variable} from '@tensorflow/tfjs';
import * as fs from 'fs';
Expand All @@ -42,6 +40,18 @@ const initWeights =
[tf.tensor([1, 1, 1, 1], [2, 2]), tf.tensor([1, 2, 3, 4], [1, 4])];
const updateThreshold = 2;

function waitUntil(done: () => boolean, then: () => void, timeout?: number) {
const moveOn = () => {
clearInterval(moveOnIfDone);
clearTimeout(moveOnAnyway);
then();
};
const moveOnAnyway = setTimeout(moveOn, timeout || 100);
const moveOnIfDone = setInterval(() => {
if (done()) moveOn();
}, 1);
}

describe('Socket API', () => {
let dataDir: string;
let modelDir: string;
Expand Down Expand Up @@ -115,20 +125,14 @@ describe('Socket API', () => {
clientAPI.numExamples = 3;
await clientAPI.uploadVars();

const timeout = 100;
let elapsed = 0;
const interval = setInterval(() => {
elapsed += 1;
if (elapsed > timeout || clientAPI.modelId != modelId) {
clearInterval(interval);
test_util.expectArraysClose(
clientVars[0], tf.tensor([1.25, 1.25, 1.25, 1.25], [2, 2]))
test_util.expectArraysClose(
clientVars[1], tf.tensor([3.25, 2.75, 2.25, 1.75], [1, 4]))
expect(clientAPI.numExamples).toBe(0);
expect(clientAPI.modelId).toBe(modelDB.modelId);
done();
}
}, 1);
waitUntil(() => clientAPI.modelId != modelId, () => {
test_util.expectArraysClose(
clientVars[0], tf.tensor([1.25, 1.25, 1.25, 1.25], [2, 2]))
test_util.expectArraysClose(
clientVars[1], tf.tensor([3.25, 2.75, 2.25, 1.75], [1, 4]))
expect(clientAPI.numExamples).toBe(0);
expect(clientAPI.modelId).toBe(modelDB.modelId);
done();
});
})
});
18 changes: 18 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/**
* * @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

export * from './types';

0 comments on commit 1df8b70

Please sign in to comment.