Skip to content

Commit

Permalink
refactor(models): remove unused parameters from velocity_model
Browse files Browse the repository at this point in the history
Signed-off-by: Cameron Smith <cameron.ray.smith@gmail.com>
  • Loading branch information
cameronraysmith committed Sep 13, 2024
1 parent 50d4970 commit 18d63ef
Showing 1 changed file with 2 additions and 68 deletions.
70 changes: 2 additions & 68 deletions src/pyrovelocity/models/_velocity_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,7 @@ def __repr__(self) -> str:
@beartype
def create_plates(
self,
u_obs: Optional[torch.Tensor] = None,
s_obs: Optional[torch.Tensor] = None,
u_log_library: Optional[torch.Tensor] = None,
s_log_library: Optional[torch.Tensor] = None,
u_log_library_loc: Optional[torch.Tensor] = None,
s_log_library_loc: Optional[torch.Tensor] = None,
u_log_library_scale: Optional[torch.Tensor] = None,
s_log_library_scale: Optional[torch.Tensor] = None,
ind_x: Optional[torch.Tensor] = None,
cell_state: Optional[torch.Tensor] = None,
time_info: Optional[torch.Tensor] = None,
) -> Tuple[plate, plate]:
"""
Create cell and gene plates for the model. note that usage in
Expand All @@ -126,17 +116,7 @@ def create_plates(
https://github.com/pyro-ppl/pyro/blob/1.9.1/pyro/infer/autoguide/guides.py#L60-L63
Args:
u_obs (Optional[torch.Tensor], optional): _description_. Defaults to None.
s_obs (Optional[torch.Tensor], optional): _description_. Defaults to None.
u_log_library (Optional[torch.Tensor], optional): _description_. Defaults to None.
s_log_library (Optional[torch.Tensor], optional): _description_. Defaults to None.
u_log_library_loc (Optional[torch.Tensor], optional): _description_. Defaults to None.
s_log_library_loc (Optional[torch.Tensor], optional): _description_. Defaults to None.
u_log_library_scale (Optional[torch.Tensor], optional): _description_. Defaults to None.
s_log_library_scale (Optional[torch.Tensor], optional): _description_. Defaults to None.
ind_x (Optional[torch.Tensor], optional): _description_. Defaults to None.
cell_state (Optional[torch.Tensor], optional): _description_. Defaults to None.
time_info (Optional[torch.Tensor], optional): _description_. Defaults to None.
Returns:
Tuple[plate, plate]: _description_
Expand Down Expand Up @@ -220,16 +200,8 @@ def get_likelihood(
self,
ut: torch.Tensor,
st: torch.Tensor,
u_log_library: Optional[torch.Tensor] = None,
s_log_library: Optional[torch.Tensor] = None,
u_scale: Optional[torch.Tensor] = None,
s_scale: Optional[torch.Tensor] = None,
u_read_depth: Optional[torch.Tensor] = None,
s_read_depth: Optional[torch.Tensor] = None,
u_cell_size_coef: None = None,
ut_coef: None = None,
s_cell_size_coef: None = None,
st_coef: None = None,
) -> Tuple[Poisson, Poisson]:
"""
Compute the likelihood of the given count data.
Expand Down Expand Up @@ -564,13 +536,9 @@ def forward(
s_obs: torch.Tensor,
u_log_library: Optional[torch.Tensor] = None,
s_log_library: Optional[torch.Tensor] = None,
u_log_library_loc: Optional[torch.Tensor] = None,
s_log_library_loc: Optional[torch.Tensor] = None,
u_log_library_scale: Optional[torch.Tensor] = None,
s_log_library_scale: Optional[torch.Tensor] = None,
ind_x: Optional[torch.Tensor] = None,
cell_state: Optional[torch.Tensor] = None,
time_info: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Defines the forward model, which computes the unspliced (u) and spliced
Expand All @@ -585,21 +553,13 @@ def forward(
Log-transformed library size for unspliced RNA. Default is None.
s_log_library (Optional[torch.Tensor], optional):
Log-transformed library size for spliced RNA. Default is None.
u_log_library_loc (Optional[torch.Tensor], optional):
Mean of log-transformed library size for unspliced RNA. Default is None.
s_log_library_loc (Optional[torch.Tensor], optional):
Mean of log-transformed library size for spliced RNA. Default is None.
u_log_library_scale (Optional[torch.Tensor], optional):
Scale of log-transformed library size for unspliced RNA.
Default is None.
s_log_library_scale (Optional[torch.Tensor], optional):
Scale of log-transformed library size for spliced RNA. Default is None.
ind_x (Optional[torch.Tensor], optional):
Indices for the cells. Default is None.
cell_state (Optional[torch.Tensor], optional):
Cell state information. Default is None.
time_info (Optional[torch.Tensor], optional):
Time information for the cells. Default is None.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -622,8 +582,6 @@ def forward(
>>> )
>>> u_log_library=torch.tensor([[3.7377], [4.0254], [2.7081]], device="cpu")
>>> s_log_library=torch.tensor([[3.6376], [3.9512], [2.3979]], device="cpu")
>>> u_log_library_loc=torch.tensor([[3.4904], [3.4904], [3.4904]], device="cpu")
>>> s_log_library_loc=torch.tensor([[3.3289], [3.3289], [3.3289]], device="cpu")
>>> u_log_library_scale=torch.tensor([[0.6926], [0.6926], [0.6926]], device="cpu")
>>> s_log_library_scale=torch.tensor([[0.8214], [0.8214], [0.8214]], device="cpu")
>>> ind_x=torch.tensor([2, 0, 1], device="cpu")
Expand All @@ -633,8 +591,6 @@ def forward(
>>> s_obs,
>>> u_log_library,
>>> s_log_library,
>>> u_log_library_loc,
>>> s_log_library_loc,
>>> u_log_library_scale,
>>> s_log_library_scale,
>>> ind_x,
Expand All @@ -647,24 +603,12 @@ def forward(
[11., 29., 10., 2.],
[ 0., 0., 7., 4.]]))
"""
cell_plate, gene_plate = self.create_plates(
u_obs=u_obs,
s_obs=s_obs,
u_log_library=u_log_library,
s_log_library=s_log_library,
u_log_library_loc=u_log_library_loc,
s_log_library_loc=s_log_library_loc,
u_log_library_scale=u_log_library_scale,
s_log_library_scale=s_log_library_scale,
ind_x=ind_x,
cell_state=cell_state,
time_info=time_info,
)
cell_plate, gene_plate = self.create_plates(ind_x=ind_x)

with gene_plate, poutine.mask(mask=self.include_prior):
alpha = self.alpha
gamma = self.gamma
beta = self.beta
gamma = self.gamma

if self.add_offset:
u0 = pyro.sample("u_offset", LogNormal(self.zero, self.one))
Expand Down Expand Up @@ -698,8 +642,6 @@ def forward(
LogNormal(self.zero, self.one).mask(self.include_prior),
)

with cell_plate:
u_cell_size_coef = ut_coef = s_cell_size_coef = st_coef = None
u_read_depth = pyro.sample(
"u_read_depth", LogNormal(u_log_library, u_log_library_scale)
)
Expand All @@ -724,16 +666,8 @@ def forward(
u_dist, s_dist = self.get_likelihood(
ut=ut,
st=st,
u_log_library=u_log_library,
s_log_library=s_log_library,
u_scale=u_scale,
s_scale=s_scale,
u_read_depth=u_read_depth,
s_read_depth=s_read_depth,
u_cell_size_coef=u_cell_size_coef,
ut_coef=ut_coef,
s_cell_size_coef=s_cell_size_coef,
st_coef=st_coef,
)
u = pyro.sample("u", u_dist, obs=u_obs)
s = pyro.sample("s", s_dist, obs=s_obs)
Expand Down

0 comments on commit 18d63ef

Please sign in to comment.