diff --git a/mypackage/_mymodel.py b/mypackage/_mymodel.py index 036fd42..ddafaa3 100644 --- a/mypackage/_mymodel.py +++ b/mypackage/_mymodel.py @@ -11,7 +11,9 @@ class MyModel(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): """ Skeleton for an scvi-tools model. + Please use this skeleton to create new models. + Parameters ---------- adata diff --git a/mypackage/_mymodule.py b/mypackage/_mymodule.py index b814646..29cf06d 100644 --- a/mypackage/_mymodule.py +++ b/mypackage/_mymodule.py @@ -28,8 +28,6 @@ class MyModule(BaseModuleClass): Number of input genes n_batch Number of batches, if 0, no batch correction is performed. - n_labels - Number of labels n_hidden Number of nodes per hidden layer n_latent diff --git a/mypackage/_mypyromodel.py b/mypackage/_mypyromodel.py index d20e797..1e1cf7f 100644 --- a/mypackage/_mypyromodel.py +++ b/mypackage/_mypyromodel.py @@ -146,6 +146,7 @@ def train( train_size=train_size, validation_size=validation_size, batch_size=batch_size, + use_gpu=use_gpu, ) training_plan = PyroTrainingPlan(self.module, **plan_kwargs) runner = TrainRunner( diff --git a/mypackage/_mypyromodule.py b/mypackage/_mypyromodule.py index fe07e51..c61d304 100644 --- a/mypackage/_mypyromodule.py +++ b/mypackage/_mypyromodule.py @@ -6,7 +6,25 @@ class MyPyroModule(PyroBaseModuleClass): - def __init__(self, n_input, n_latent, n_hidden, n_layers): + """ + Skeleton Variational auto-encoder Pyro model. + + Here we implement a basic version of scVI's underlying VAE [Lopez18]_. + This implementation is for instructional purposes only. + + Parameters + ---------- + n_input + Number of input genes + n_latent + Dimensionality of the latent space + n_hidden + Number of nodes per hidden layer + n_layers + Number of hidden layers used for encoder and decoder NNs + """ + + def __init__(self, n_input: int, n_latent: int, n_hidden: int, n_layers: int): super().__init__() self.n_input = n_input