
first pass at server-side sockets (+ client comm changes) #5

Merged
merged 13 commits into master on Jun 12, 2018

Conversation

asross
Collaborator

@asross asross commented Jun 8, 2018

This implements a server websocket API that supports two basic message types: sending a DownloadMsg and receiving an UploadMsg.

The serialization/deserialization logic is in flux and will definitely be streamlined, but this should work as a first pass. The main TODO is to get this working with the client at the API level.



@asross asross changed the title WIP - first pass at server-side sockets WIP - first pass at server-side sockets (+ client comm changes) Jun 8, 2018
setupUserModel().then((dict) => {
this.writeNewVars(dict.vars as tf.Tensor[]);
});
}
asross (Collaborator, Author):

When we start the server for the first time, the effect of these lines will be to:

  1. create dist/data/
  2. load the mnist transfer CNN from the internet
  3. save its fully connected variable values to dist/data/<some-timestamp>.json

However, (3) will actually happen asynchronously (I can't await its completion in the constructor), and the server might start before it finishes. This will only be an issue the very first time we start the server, so it might not be worth addressing, but it would probably be better to wait on starting the server until the database is ready to send weights to the client.

const DEFAULT_MIN_UPDATES = 10;
const NO_MODELS_IN_FOLDER = '0';
asross (Collaborator, Author):

This is a little bit hacky -- I'm having getLatestId return '0' if there are no models in the data directory, and then using that as a special indicator that we need to initialize the database. Let me know if you think of a nicer way to structure this. I could have getLatestId return undefined instead, I guess.

@asross asross changed the title WIP - first pass at server-side sockets (+ client comm changes) first pass at server-side sockets (+ client comm changes) Jun 8, 2018
@dsmilkov dsmilkov self-requested a review June 10, 2018 14:39
@dsmilkov
Collaborator

:lgtm_strong: This is great for a first pass! I left a few comments, but nothing blocking since things are still in flux, and it's not worth overthinking things at the beginning (thus LGTM).


Reviewed 11 of 11 files at r1.
Review status: all files reviewed at latest revision, 2 unresolved discussions.


src/model.ts, line 27 at r1 (raw file):

    'https://storage.googleapis.com/tfjs-models/tfjs/mnist_transfer_cnn_v1/model.json';

export type LossFunction = (inputs: Tensor, labels: Tensor) => Tensor;

you can constrain the return type to be a Scalar which is an alias for Tensor<Rank.R0>


src/model.ts, line 28 at r1 (raw file):

export type LossFunction = (inputs: Tensor, labels: Tensor) => Tensor;
export type PredFunction = (inputs: Tensor) => Tensor|Tensor[];

Move these types (LossFunction, PredFunction, VariableList, and ModelDict) to src/types.ts. Later we can re-export them in the public API in index.ts.
Also add a FederatedModel interface in src/types.ts, which for now is:

interface FederatedModel {
  setup(): Promise<ModelDict>;
}

and implement it here in src/model.ts. Also consider renaming this file to mnist_model.ts. I imagine this file will later go in demo/ since it's user code. But it's good for now to be in src/ since things are still in flux.


src/model.ts, line 29 at r1 (raw file):

export type LossFunction = (inputs: Tensor, labels: Tensor) => Tensor;
export type PredFunction = (inputs: Tensor) => Tensor|Tensor[];
export type VariableList = Array<Variable|LayerVariable|Tensor>;

Maybe don't allow Tensor as a type in VariableList, because we'd want the framework itself (rather than user code) to have the flexibility to update the vars, and if we only have Tensors, we can't update their values -- the users would have to do that themselves.


src/serialization.ts, line 68 at r1 (raw file):

  // Note: could make this async / use base64 encoding on the buffer data
  return {
    'values': Array.from(data), 'shape': t.shape, 'dtype': t.dtype

Have you tried sending a Float32Array (i.e. omitting the Array.from(data) call) through socket.io to see what happens? IIUC, Aman investigated sending TypedArrays/ArrayBuffers through socket.io and concluded they worked.


src/server/fetch_polyfill.ts, line 5 at r1 (raw file):

eval(`
var realFetch = require('node-fetch');

Curious why you had to do the eval stuff. Anyhow, not blocking - we can chat in person.


src/server/model_db.ts, line 27 at r1 (raw file):

Previously, asross (Andrew Ross) wrote…

This is a little bit hacky -- I'm having getLatestId return '0' if there are no models in the data directory, and then using that as a special indicator that we need to initialize the database. Let me know if you think of a nicer way to structure this. I could have getLatestId return undefined instead, I guess.

That's ok. Nit: returning null is usually more standard; we avoid using undefined. And we always use double equals (not triple) when comparing with null.


src/server/model_db.ts, line 70 at r1 (raw file):

Previously, asross (Andrew Ross) wrote…

When we start the server for the first time, the effect of these lines will be to:

  1. create dist/data/
  2. load the mnist transfer CNN from the internet
  3. save its fully connected variable values to dist/data/<some-timestamp>.json

However, (3) will actually happen asynchronously (I can't await its completion in the constructor), and the server might start before it finishes. This will only be an issue the very first time we start the server, so it might not be worth addressing, but it would probably be better to wait on starting the server until the database is ready to send weights to the client.

Since c-tors are limited, we usually avoid doing real work in them, and instead have a setup() method on the class (which can be async) that you call from another place. In this case that setup() method would do all three steps.
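The pattern being suggested can be sketched as follows. This is illustrative only: ModelDB and loadInitialWeights are stand-in names, and the three startup steps are collapsed into one injected async function:

```typescript
// Sketch: keep the constructor trivial; move async initialization into
// setup(), which the entrypoint awaits before the server starts listening.
class ModelDB {
  modelId: string|null = null;

  constructor(readonly dataDir: string,
              readonly loadInitialWeights: () => Promise<string>) {}

  // Would do the real work: create dataDir, load the initial model from the
  // network, and persist its weights, yielding the first model id.
  async setup(): Promise<void> {
    this.modelId = await this.loadInitialWeights();
  }
}
```

The entrypoint would then `await db.setup()` before calling the server's listen/accept step, so the database is guaranteed to be ready to send weights to the first client.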


Comments from Reviewable

@nsthorat
Collaborator

Review status: all files reviewed at latest revision, 7 unresolved discussions, all commit checks successful.


src/model.ts, line 20 at r1 (raw file):

import * as tf from '@tensorflow/tfjs';
import {Tensor, Variable} from '@tensorflow/tfjs';
import {LayerVariable} from '@tensorflow/tfjs-layers/dist/variables';

can you import this from the public API?


src/model.ts, line 25 at r1 (raw file):

// tslint:disable-next-line:max-line-length
const mnistTransferLearningModelURL =
    'https://storage.googleapis.com/tfjs-models/tfjs/mnist_transfer_cnn_v1/model.json';

I know we're just getting started, but maybe we could avoid hard-coding examples into src/ -- if we had a "demo" or "example" directory that passed these values into code in src/, it would help you think about the public API as you develop this.


src/serialization.ts, line 59 at r1 (raw file):

};

export function tensorToJson(t: tf.Tensor): TensorJson {

it's probably a good idea to make this async (return Promise<TensorJson>) and use .data()
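An async version along the lines being suggested might look like this. To keep the sketch self-contained, the tensor argument is typed structurally with just the members used; tf.Tensor satisfies this shape (data(), shape, dtype), but this is not the actual project code:

```typescript
// Minimal structural stand-in for tf.Tensor, for illustration.
interface TensorLike {
  data(): Promise<Float32Array|Int32Array|Uint8Array>;
  shape: number[];
  dtype: string;
}

interface TensorJson {
  values: number[];
  shape: number[];
  dtype: string;
}

async function tensorToJson(t: TensorLike): Promise<TensorJson> {
  // .data() resolves asynchronously (e.g. off the GPU) instead of blocking
  // the way dataSync() does.
  const data = await t.data();
  return {values: Array.from(data), shape: t.shape, dtype: t.dtype};
}
```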


src/serialization.ts, line 71 at r1 (raw file):

  }
}

it would be good to add test coverage for this file


src/server/fetch_polyfill.ts, line 1 at r1 (raw file):

import * as es6Promise from 'es6-promise';

put a license at the top of this file.


src/server/fetch_polyfill.ts, line 8 at r1 (raw file):

if (!global.fetch) {
	global.fetch = realFetch;

use spaces not tabs


src/server/model_db.ts, line 27 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

That's ok. Nit: returning null is usually more standard; we avoid using undefined. And we always use double equals (not triple) when comparing with null.

+1 to using null here


src/server/model_db.ts, line 67 at r1 (raw file):

    this.modelId = currentModelId || getLatestId(dataDir);
    if (this.modelId == NO_MODELS_IN_FOLDER) {
      setupUserModel().then((dict) => {

you don't need parens around dict since it's the only argument


src/server/model_db.ts, line 89 at r1 (raw file):

  }

  async possiblyUpdate() {

add a return type to this method


Comments from Reviewable

@asross
Collaborator Author

asross commented Jun 11, 2018

src/model.ts, line 20 at r1 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

can you import this from the public API?

It doesn't seem to be available from tfjs-layers's public API -- please correct me if I'm wrong!


Comments from Reviewable

@asross
Collaborator Author

asross commented Jun 11, 2018

src/server/model_db.ts, line 27 at r1 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

+1 to using null here

Great, done (and thanks for the heads up about ==)!


Comments from Reviewable

@asross
Collaborator Author

asross commented Jun 11, 2018

src/server/model_db.ts, line 70 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

Since c-tors are limited, we usually avoid doing real work in them, and instead have a setup() method on the class (which can be async) that you call from another place. In this case that setup() method would do all three steps.

Got it, good to know about that pattern.


Comments from Reviewable

@asross
Collaborator Author

asross commented Jun 11, 2018

src/serialization.ts, line 68 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

Have you tried sending a Float32Array (i.e. omitting the Array.from(data) call) through socket.io to see what happens? IIUC, Aman investigated sending TypedArrays/ArrayBuffers through socket.io and concluded they worked.

This is for reading/writing from files, but it's possible I could save/load them in a binary format using the socket.io parser. I was thinking that I would take a pass at that slightly later together with Aman.


Comments from Reviewable

@asross
Collaborator Author

asross commented Jun 11, 2018

src/model.ts, line 25 at r1 (raw file):

Previously, nsthorat (Nikhil Thorat) wrote…

I know we're just getting started, but maybe we could not hard code some examples into src/ -- if we could have a "demo" or "example" directory that passes these flags into code in src/ this will help you think about the public API as you develop this.

Done, although I'm not sure if the way I did it is what you had in mind -- I moved mnist_transfer_learning_model into demo, then had model.ts import and re-export it. Let me know if that's alright.


Comments from Reviewable

@asross
Collaborator Author

asross commented Jun 11, 2018

src/model.ts, line 20 at r1 (raw file):

Previously, asross (Andrew Ross) wrote…

It doesn't seem to be available from tfjs-layers's public API -- please correct me if I'm wrong!

(where it := LayerVariable, since it's no longer in this diff)


Comments from Reviewable

@asross
Collaborator Author

asross commented Jun 11, 2018

src/server/fetch_polyfill.ts, line 5 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

Curious why you had to do the eval stuff. Anyhow, not blocking - we can chat in person.

There's almost certainly a better way, but the reason I did it this way was:

  • we need to use node-fetch, but I couldn't modify the implementation of loadModel in TensorFlow.js to first import it
  • a solution to this is isomorphic-fetch (plus es6-promise), but isomorphic-fetch hasn't been updated in years and is using an old version of node-fetch that doesn't support array buffers.
  • however, the isomorphic-fetch code that lets browser-style fetch invocations run within node is actually really simple
  • but it requires setting properties on the NodeJS global object, which in TypeScript requires making changes to several files. I felt it was better to keep polyfill code confined to a single file.

Let me know if you have a simpler solution for any part of that chain of problems...
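Stripped of the eval indirection, what a polyfill like this boils down to can be sketched as below. This is a hypothetical shape, not the file's actual contents: installFetchPolyfill is an illustrative name, and nodeFetch stands in for what would be require('node-fetch') in the real file.

```typescript
type FetchFn = (url: string, init?: object) => Promise<unknown>;

// Assign a node fetch implementation onto the global object so that
// browser-style `fetch(...)` calls inside loadModel resolve under node.
function installFetchPolyfill(nodeFetch: FetchFn): void {
  const g = global as {fetch?: FetchFn};
  if (!g.fetch) {  // don't clobber an existing implementation
    g.fetch = nodeFetch;
  }
}
```

The guard mirrors the `if (!global.fetch)` check in the diff: the polyfill is a no-op in environments that already provide fetch.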


Comments from Reviewable

@asross
Collaborator Author

asross commented Jun 11, 2018

src/comm_test.ts, line 116 at r4 (raw file):

    clientVars[1].assign(tf.tensor([4, 3, 2, 1], [1, 4]));
    clientAPI.numExamples = 3;
    await clientAPI.uploadVars();

Question / possible issue --

Let's imagine a user labels one example, and the client code, fearing the user may disconnect imminently, decides to compute a weight update (following our fit parameters) and send it up to the server with numExamples=1.

However, the user doesn't disconnect, and instead goes on to label three more examples.

The client has at least two options:

  1. it can revert to the initial weights, then compute a second update using just the newly labeled examples (which matches the scenario I've laid out here), sending it with numExamples=3.
  2. it can revert to the initial weights, then compute an update using all of the examples it's labeled so far. In this case, it should send an update with numExamples=4, and this update should supersede the previous one.

(I think in either case, it has to revert to the initial weights before retraining, to ensure that different clients send weight updates generated under consistent conditions.)

I've assumed that we're working under option 1, since it's a bit simpler from the server-side and also doesn't require that the client remember any training data after it's been used in SGD. However, what it will probably amount to in practice is that all of the updates will be generated using individual examples. This might not be the end of the world -- the experiment from before shows that, on MNIST, learning from ~30 independent updates with 1 example is only slightly worse than learning from ~10 independent updates with 3 examples. But maybe it's better to have client updates still supersede each other. Let me know what you all think.
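Under option 1, each upload is an independent update, so server-side aggregation can weight each one by its numExamples (federated-averaging style). A sketch over plain arrays, to make the weighting concrete; this is illustrative, not the actual server code, and Update here is a simplified stand-in for UploadMsg:

```typescript
interface Update {
  values: number[];     // one flattened variable, for simplicity
  numExamples: number;
}

// Average a non-empty list of independent updates, each weighted by the
// number of examples that produced it.
function weightedAverage(updates: Update[]): number[] {
  const total = updates.reduce((sum, u) => sum + u.numExamples, 0);
  const result = new Array<number>(updates[0].values.length).fill(0);
  for (const u of updates) {
    const w = u.numExamples / total;
    u.values.forEach((v, i) => result[i] += w * v);
  }
  return result;
}
```

Under option 2, by contrast, the server would first discard any earlier update from the same client before averaging, since the superseding update already covers those examples.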


Comments from Reviewable

@asross
Collaborator Author

asross commented Jun 11, 2018

Review status: :shipit: complete! 1 of 1 LGTMs obtained


a discussion (no related file):
Question:


Comments from Reviewable

@dsmilkov
Collaborator

:lgtm_strong: Really nice work! Left some tiny comments. Feel free to submit after!


Reviewed 6 of 9 files at r3, 7 of 7 files at r4.
Review status: :shipit: complete! 1 of 1 LGTMs obtained


demo/mnist_transfer_learning_model.ts, line 20 at r4 (raw file):

import * as tf from '@tensorflow/tfjs'
import {Scalar, Tensor} from '@tensorflow/tfjs';
import {FederatedModel, ModelDict} from '../src/types';

To better emulate usage of the API, introduce a src/index.ts and export all of the user-facing types from there. Then here you should import {....} from '../src/index', where '../src/index' will later be replaced with the npm package name. See index.ts in tfjs-core as an example.
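A src/index.ts along these lines might look roughly like the fragment below. The export list is hypothetical, based only on the names discussed in this review:

```typescript
// src/index.ts -- single public entrypoint (sketch; actual exports may differ)
export {FederatedModel, LossFunction, ModelDict, PredFunction, VariableList} from './types';
export {tensorToJson} from './serialization';
```

The demo would then `import {FederatedModel} from '../src/index'`, and swapping in the published npm package name later becomes a one-line change.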


src/comm_test.ts, line 20 at r4 (raw file):

// import * as tf from '@tensorflow/tfjs';
// import {test_util} from '@tensorflow/tfjs-core';

remove the 2 commented out lines


src/comm_test.ts, line 116 at r4 (raw file):

Previously, asross (Andrew Ross) wrote…

Question / possible issue --

Let's imagine a user labels one example, and the client code, fearing the user may disconnect imminently, decides to compute a weight update (following our fit parameters) and send it up to the server with numExamples=1.

However, the user doesn't disconnect, and instead goes on to label three more examples.

The client has at least two options:

  1. it can revert to the initial weights, then compute a second update using just the newly labeled examples (which matches the scenario I've laid out here), sending it with numExamples=3.
  2. it can revert to the initial weights, then compute an update using all of the examples it's labeled so far. In this case, it should send an update with numExamples=4, and this update should supersede the previous one.

(I think in either case, it has to revert to the initial weights before retraining, to ensure that different clients send weight updates generated under consistent conditions.)

I've assumed that we're working under option 1, since it's a bit simpler from the server-side and also doesn't require that the client remember any training data after it's been used in SGD. However, what it will probably amount to in practice is that all of the updates will be generated using individual examples. This might not be the end of the world -- the experiment from before shows that, on MNIST, learning from ~30 independent updates with 1 example is only slightly worse than learning from ~10 independent updates with 3 examples. But maybe it's better to have client updates still supersede each other. Let me know what you all think.

Great observation. Can you file an issue on the repo with this discussion so we can revisit it later? For now option 1 sounds good. (We can come back to this if the need arises based on the experiments.)


src/comm_test.ts, line 122 at r4 (raw file):

    const interval = setInterval(() => {
      elapsed += 1;
      if (elapsed > timeout || clientAPI.modelId != modelId) {

check out jasmine's spyOn to verify that a callback was called. https://jasmine.github.io/2.0/introduction.html


src/server/fetch_polyfill.ts, line 5 at r1 (raw file):

Previously, asross (Andrew Ross) wrote…

There's almost certainly a better way, but the reason I did it this way was:

  • we need to use node-fetch, but I couldn't modify the implementation of loadModel in TensorFlow.js to first import it
  • a solution to this is isomorphic-fetch (plus es6-promise), but isomorphic-fetch hasn't been updated in years and is using an old version of node-fetch that doesn't support array buffers.
  • however, the isomorphic-fetch code that lets browser-style fetch invocations run within node is actually really simple
  • but it requires setting properties on the NodeJS global object, which in TypeScript requires making changes to several files. I felt it was better to keep polyfill code confined to a single file.

Let me know if you have a simpler solution for any part of that chain of problems...

This is ok for now. Shanqing is working on making node work with https, at which point this hack will go away.


Comments from Reviewable

@asross
Collaborator Author

asross commented Jun 12, 2018

src/comm_test.ts, line 122 at r4 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

check out jasmine's spyOn to verify that a callback was called. https://jasmine.github.io/2.0/introduction.html

Unfortunately, I don't think spyOn expectations have any kind of timeout -- adding spyOn and an assertion that a client download occurs just causes the test to immediately fail. I refactored the test, though, to make the intention much clearer.


Comments from Reviewable

@asross
Collaborator Author

asross commented Jun 12, 2018

src/comm_test.ts, line 116 at r4 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

Great observation. Can you file an issue on the repo with this discussion so we can revisit it later? For now option 1 sounds good. (We can come back to this if the need arises based on the experiments.)

Yep, done: #7


Comments from Reviewable

@asross asross merged commit 27b2cfe into master Jun 12, 2018