Skip to content

Commit

Permalink
fix batch-wise loading
Browse files Browse the repository at this point in the history
  • Loading branch information
MauraJohn committed Nov 14, 2024
1 parent c33074b commit 3f7a1c2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
10 changes: 6 additions & 4 deletions models/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,13 +185,14 @@ def perm_gwas(self, perm_method: str = 'x', adj_p_value: bool = False):
# load and transform batch of SNPs
print("\rCalculate perm test statistics for SNPs %d to %d" % (lower_bound, upper_bound), end='')
if perm_method == 'y':
US = self._s_matrix(lower_bound=lower_bound, upper_bound=upper_bound) # shape: (n,b)
US = self._s_matrix(lower_bound=lower_bound, upper_bound=upper_bound, save_meta=False) # shape: (n,b)
# transform data
US = self.transform_input(X=US, U=self.U)
# get 3D copy of S for permutations
US = self.get_3d_copy(v=US, batch_size=self.perm) # shape: (p,n,b)
else:
US = self._s_matrix(lower_bound=lower_bound, upper_bound=upper_bound, device=torch.device("cpu")) # shape: (n,b)
US = self._s_matrix(lower_bound=lower_bound, upper_bound=upper_bound, device=torch.device("cpu"),
save_meta=False) # shape: (n,b)
US = self.permute(data=US) # shape: (p,n,b)
# transform data
US = self.transform_input(X=US, U=self.U) # shape: (p,n,b)
Expand Down Expand Up @@ -660,21 +661,22 @@ def _d_delta(self, delta: torch.tensor, batch_size: int):
else:
return torch.unsqueeze((self.D + delta).repeat(batch_size, 1), 1)

def _s_matrix(self, lower_bound: int, upper_bound: int, device=None) -> torch.tensor:
def _s_matrix(self, lower_bound: int, upper_bound: int, device=None, save_meta: bool = True) -> torch.tensor:
"""
load batch of markers to specified device
:param lower_bound: lower bound of marker batch
:param upper_bound: upper bound of marker batch
:param device: either cpu or cuda device
:param save_meta: if genotype is loaded batch-wise, set to False for permutations to prevent saving of meta info
:return: matrix with markers of shape (n,upper_bound-lower_bound)
"""
if device is None:
device = self.device
if self.dataset.X is None:
# load X batch-wise
self.dataset.load_genotype_batch_wise(device=device, snp_lower_index=lower_bound,
self.dataset.load_genotype_batch_wise(device=device, save_meta=save_meta, snp_lower_index=lower_bound,
snp_upper_index=upper_bound) # shape: (n,b)
S = self.dataset.X # shape: (n,b)
self.dataset.reset_genotype()
Expand Down
6 changes: 3 additions & 3 deletions preprocess/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ def load_genotype_batch_wise(self, device: torch.device = torch.device("cpu"), s
self.positions = positions
self.maf = maf
else:
self.chromosomes.append(chromosomes)
self.positions.append(positions)
self.maf.append(maf)
self.chromosomes = np.concatenate((self.chromosomes, chromosomes))
self.positions = np.concatenate((self.positions, positions))
self.maf = torch.cat((self.maf, maf))

def load_genotype_hdf5_file(self, snp_lower_index: int = None, snp_upper_index: int = None) -> tuple:
"""
Expand Down

0 comments on commit 3f7a1c2

Please sign in to comment.