From 18d63ef8820a205bd375e8223feac74948139e7a Mon Sep 17 00:00:00 2001 From: Cameron Smith Date: Thu, 12 Sep 2024 20:54:25 -0400 Subject: [PATCH] refactor(models): remove unused parameters from velocity_model Signed-off-by: Cameron Smith --- src/pyrovelocity/models/_velocity_model.py | 70 +--------------------- 1 file changed, 2 insertions(+), 68 deletions(-) diff --git a/src/pyrovelocity/models/_velocity_model.py b/src/pyrovelocity/models/_velocity_model.py index 0250ca7dc..9b20d8a28 100644 --- a/src/pyrovelocity/models/_velocity_model.py +++ b/src/pyrovelocity/models/_velocity_model.py @@ -107,17 +107,7 @@ def __repr__(self) -> str: @beartype def create_plates( self, - u_obs: Optional[torch.Tensor] = None, - s_obs: Optional[torch.Tensor] = None, - u_log_library: Optional[torch.Tensor] = None, - s_log_library: Optional[torch.Tensor] = None, - u_log_library_loc: Optional[torch.Tensor] = None, - s_log_library_loc: Optional[torch.Tensor] = None, - u_log_library_scale: Optional[torch.Tensor] = None, - s_log_library_scale: Optional[torch.Tensor] = None, ind_x: Optional[torch.Tensor] = None, - cell_state: Optional[torch.Tensor] = None, - time_info: Optional[torch.Tensor] = None, ) -> Tuple[plate, plate]: """ Create cell and gene plates for the model. note that usage in @@ -126,17 +116,7 @@ def create_plates( https://github.com/pyro-ppl/pyro/blob/1.9.1/pyro/infer/autoguide/guides.py#L60-L63 Args: - u_obs (Optional[torch.Tensor], optional): _description_. Defaults to None. - s_obs (Optional[torch.Tensor], optional): _description_. Defaults to None. - u_log_library (Optional[torch.Tensor], optional): _description_. Defaults to None. - s_log_library (Optional[torch.Tensor], optional): _description_. Defaults to None. - u_log_library_loc (Optional[torch.Tensor], optional): _description_. Defaults to None. - s_log_library_loc (Optional[torch.Tensor], optional): _description_. Defaults to None. - u_log_library_scale (Optional[torch.Tensor], optional): _description_. Defaults to None. - s_log_library_scale (Optional[torch.Tensor], optional): _description_. Defaults to None. ind_x (Optional[torch.Tensor], optional): _description_. Defaults to None. - cell_state (Optional[torch.Tensor], optional): _description_. Defaults to None. - time_info (Optional[torch.Tensor], optional): _description_. Defaults to None. Returns: Tuple[plate, plate]: _description_ @@ -220,16 +200,8 @@ def get_likelihood( self, ut: torch.Tensor, st: torch.Tensor, - u_log_library: Optional[torch.Tensor] = None, - s_log_library: Optional[torch.Tensor] = None, - u_scale: Optional[torch.Tensor] = None, - s_scale: Optional[torch.Tensor] = None, u_read_depth: Optional[torch.Tensor] = None, s_read_depth: Optional[torch.Tensor] = None, - u_cell_size_coef: None = None, - ut_coef: None = None, - s_cell_size_coef: None = None, - st_coef: None = None, ) -> Tuple[Poisson, Poisson]: """ Compute the likelihood of the given count data. @@ -564,13 +536,9 @@ def forward( s_obs: torch.Tensor, u_log_library: Optional[torch.Tensor] = None, s_log_library: Optional[torch.Tensor] = None, - u_log_library_loc: Optional[torch.Tensor] = None, - s_log_library_loc: Optional[torch.Tensor] = None, u_log_library_scale: Optional[torch.Tensor] = None, s_log_library_scale: Optional[torch.Tensor] = None, ind_x: Optional[torch.Tensor] = None, - cell_state: Optional[torch.Tensor] = None, - time_info: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Defines the forward model, which computes the unspliced (u) and spliced @@ -585,10 +553,6 @@ def forward( Log-transformed library size for unspliced RNA. Default is None. s_log_library (Optional[torch.Tensor], optional): Log-transformed library size for spliced RNA. Default is None. - u_log_library_loc (Optional[torch.Tensor], optional): - Mean of log-transformed library size for unspliced RNA. Default is None. - s_log_library_loc (Optional[torch.Tensor], optional): - Mean of log-transformed library size for spliced RNA. Default is None. u_log_library_scale (Optional[torch.Tensor], optional): Scale of log-transformed library size for unspliced RNA. Default is None. @@ -596,10 +560,6 @@ def forward( Scale of log-transformed library size for spliced RNA. Default is None. ind_x (Optional[torch.Tensor], optional): Indices for the cells. Default is None. - cell_state (Optional[torch.Tensor], optional): - Cell state information. Default is None. - time_info (Optional[torch.Tensor], optional): - Time information for the cells. Default is None. Returns: Tuple[torch.Tensor, torch.Tensor]: @@ -622,8 +582,6 @@ def forward( >>> ) >>> u_log_library=torch.tensor([[3.7377], [4.0254], [2.7081]], device="cpu") >>> s_log_library=torch.tensor([[3.6376], [3.9512], [2.3979]], device="cpu") - >>> u_log_library_loc=torch.tensor([[3.4904], [3.4904], [3.4904]], device="cpu") - >>> s_log_library_loc=torch.tensor([[3.3289], [3.3289], [3.3289]], device="cpu") >>> u_log_library_scale=torch.tensor([[0.6926], [0.6926], [0.6926]], device="cpu") >>> s_log_library_scale=torch.tensor([[0.8214], [0.8214], [0.8214]], device="cpu") >>> ind_x=torch.tensor([2, 0, 1], device="cpu") @@ -633,8 +591,6 @@ def forward( >>> s_obs, >>> u_log_library, >>> s_log_library, - >>> u_log_library_loc, - >>> s_log_library_loc, >>> u_log_library_scale, >>> s_log_library_scale, >>> ind_x, @@ -647,24 +603,12 @@ def forward( [11., 29., 10., 2.], [ 0., 0., 7., 4.]])) """ - cell_plate, gene_plate = self.create_plates( - u_obs=u_obs, - s_obs=s_obs, - u_log_library=u_log_library, - s_log_library=s_log_library, - u_log_library_loc=u_log_library_loc, - s_log_library_loc=s_log_library_loc, - u_log_library_scale=u_log_library_scale, - s_log_library_scale=s_log_library_scale, - ind_x=ind_x, - cell_state=cell_state, - time_info=time_info, - ) + cell_plate, gene_plate = self.create_plates(ind_x=ind_x) with gene_plate, poutine.mask(mask=self.include_prior): alpha = self.alpha - gamma = self.gamma beta = self.beta + gamma = self.gamma if self.add_offset: u0 = pyro.sample("u_offset", LogNormal(self.zero, self.one)) @@ -698,8 +642,6 @@ def forward( LogNormal(self.zero, self.one).mask(self.include_prior), ) - with cell_plate: - u_cell_size_coef = ut_coef = s_cell_size_coef = st_coef = None u_read_depth = pyro.sample( "u_read_depth", LogNormal(u_log_library, u_log_library_scale) ) @@ -724,16 +666,8 @@ def forward( u_dist, s_dist = self.get_likelihood( ut=ut, st=st, - u_log_library=u_log_library, - s_log_library=s_log_library, - u_scale=u_scale, - s_scale=s_scale, u_read_depth=u_read_depth, s_read_depth=s_read_depth, - u_cell_size_coef=u_cell_size_coef, - ut_coef=ut_coef, - s_cell_size_coef=s_cell_size_coef, - st_coef=st_coef, ) u = pyro.sample("u", u_dist, obs=u_obs) s = pyro.sample("s", s_dist, obs=s_obs)