From 3f7a1c2e3e4c63281f5719425ff9ac405f8d9cfc Mon Sep 17 00:00:00 2001 From: MauraJohn Date: Thu, 14 Nov 2024 09:55:26 +0100 Subject: [PATCH] fix batch-wise loading --- models/lmm.py | 10 ++++++---- preprocess/data_loader.py | 6 +++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/models/lmm.py b/models/lmm.py index ccf74b1..aca4fe0 100644 --- a/models/lmm.py +++ b/models/lmm.py @@ -185,13 +185,14 @@ def perm_gwas(self, perm_method: str = 'x', adj_p_value: bool = False): # load and transform batch of SNPs print("\rCalculate perm test statistics for SNPs %d to %d" % (lower_bound, upper_bound), end='') if perm_method == 'y': - US = self._s_matrix(lower_bound=lower_bound, upper_bound=upper_bound) # shape: (n,b) + US = self._s_matrix(lower_bound=lower_bound, upper_bound=upper_bound, save_meta=False) # shape: (n,b) # transform data US = self.transform_input(X=US, U=self.U) # get 3D copy of S for permutations US = self.get_3d_copy(v=US, batch_size=self.perm) # shape: (p,n,b) else: - US = self._s_matrix(lower_bound=lower_bound, upper_bound=upper_bound, device=torch.device("cpu")) # shape: (n,b) + US = self._s_matrix(lower_bound=lower_bound, upper_bound=upper_bound, device=torch.device("cpu"), + save_meta=False) # shape: (n,b) US = self.permute(data=US) # shape: (p,n,b) # transform data US = self.transform_input(X=US, U=self.U) # shape: (p,n,b) @@ -660,13 +661,14 @@ def _d_delta(self, delta: torch.tensor, batch_size: int): else: return torch.unsqueeze((self.D + delta).repeat(batch_size, 1), 1) - def _s_matrix(self, lower_bound: int, upper_bound: int, device=None) -> torch.tensor: + def _s_matrix(self, lower_bound: int, upper_bound: int, device=None, save_meta: bool = True) -> torch.tensor: """ load batch of markers to specified device :param lower_bound: lower bound of marker batch :param upper_bound: upper bound of marker batch :param device: either cpu or cuda device + :param save_meta: if genotype is loaded batch-wise, set to False for permutations to prevent saving of meta info :return: matrix with markers of shape (n,upper_bound-lower_bound) """ @@ -674,7 +676,7 @@ def _s_matrix(self, lower_bound: int, upper_bound: int, device=None) -> torch.te device = self.device if self.dataset.X is None: # load X batch-wise - self.dataset.load_genotype_batch_wise(device=device, snp_lower_index=lower_bound, + self.dataset.load_genotype_batch_wise(device=device, save_meta=save_meta, snp_lower_index=lower_bound, snp_upper_index=upper_bound) # shape: (n,b) S = self.dataset.X # shape: (n,b) self.dataset.reset_genotype() diff --git a/preprocess/data_loader.py b/preprocess/data_loader.py index 7183dd1..45f4e76 100644 --- a/preprocess/data_loader.py +++ b/preprocess/data_loader.py @@ -122,9 +122,9 @@ def load_genotype_batch_wise(self, device: torch.device = torch.device("cpu"), s self.positions = positions self.maf = maf else: - self.chromosomes.append(chromosomes) - self.positions.append(positions) - self.maf.append(maf) + self.chromosomes = np.concatenate((self.chromosomes, chromosomes)) + self.positions = np.concatenate((self.positions, positions)) + self.maf = torch.cat((self.maf, maf)) def load_genotype_hdf5_file(self, snp_lower_index: int = None, snp_upper_index: int = None) -> tuple: """