diff --git a/stamarker/stamarker/.ipynb_checkpoints/models-checkpoint.py b/stamarker/stamarker/.ipynb_checkpoints/models-checkpoint.py deleted file mode 100644 index d7457ec..0000000 --- a/stamarker/stamarker/.ipynb_checkpoints/models-checkpoint.py +++ /dev/null @@ -1,257 +0,0 @@ -from abc import ABC -from typing import Any, List -import numpy as np -import pytorch_lightning as pl -import torch -import torch.nn.functional as F -from torch.utils.data import DataLoader, WeightedRandomSampler -from torch_geometric.data import Data -from sklearn.metrics import adjusted_rand_score, confusion_matrix -from .modules import STAGATEModule, StackMLPModule -from .dataset import RepDataset, Batch -from .utils import Timer - -def get_optimizer(name): - if name == "ADAM": - return torch.optim.Adam - elif name == "ADAGRAD": - return torch.optim.Adagrad - elif name == "ADADELTA": - return torch.optim.Adadelta - elif name == "RMS": - return torch.optim.RMSprop - elif name == "ASGD": - return torch.optim.ASGD - else: - raise NotImplementedError - - -def get_scheduler(name): - if name == "STEP_LR": - return torch.optim.lr_scheduler.StepLR - elif name == "EXP_LR": - return torch.optim.lr_scheduler.ExponentialLR - else: - raise NotImplementedError - -class BaseModule(pl.LightningModule, ABC): - def __init__(self): - super(BaseModule, self).__init__() - self.optimizer_params = None - self.scheduler_params = None - self.model = None - self.timer = Timer() - self.automatic_optimization = False - - def set_optimizer_params(self, - optimizer_params: dict, - scheduler_params: dict): - self.optimizer_params = optimizer_params - self.scheduler_params = scheduler_params - - def configure_optimizers(self): - optimizer = get_optimizer(self.optimizer_params["name"])( - self.model.parameters(), - **self.optimizer_params["params"]) - scheduler = get_scheduler(self.scheduler_params["name"])(optimizer, **self.scheduler_params["params"]) - return [optimizer], [scheduler] - - def on_train_epoch_start(self) -> None: - self.timer.tic('train') - - -class intSTAGATE(BaseModule): - """ - intSTAGATE Lightning Module - """ - def __init__(self, - in_features: int = None, - hidden_dims: List[int] = None, - gradient_clipping: float = 5.0, - **kwargs): - super(intSTAGATE, self).__init__() - self.model = STAGATEModule(in_features, hidden_dims) - self.auto_encoder_epochs = None - self.gradient_clipping = gradient_clipping - self.pred_labels = None - self.save_hyperparameters() - - def configure_optimizers(self) -> (dict, dict): - auto_encoder_optimizer = get_optimizer(self.optimizer_params["name"])( - list(self.model.parameters()), - **self.optimizer_params["params"]) - auto_encoder_scheduler = get_scheduler(self.scheduler_params["name"])(auto_encoder_optimizer, - **self.scheduler_params["params"]) - return [auto_encoder_optimizer], [auto_encoder_scheduler] - - def forward(self, x, edge_index) -> Any: - return self.model(x, edge_index) - - def training_step(self, batch, batch_idx): - batch = batch.to(self.device) - opt_auto_encoder = self.optimizers() - z, x_hat = self.model(batch.x, batch.edge_index) - loss = F.mse_loss(batch.x, x_hat) - opt_auto_encoder.zero_grad() - self.manual_backward(loss) - torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clipping) - opt_auto_encoder.step() - self.log("Training auto-encoder|Reconstruction errors", loss.item(), prog_bar=True) - self.logger.experiment.add_scalar('auto_encoder/loss', loss.item(), self.current_epoch) - - def on_train_epoch_end(self) -> None: - time = 
self.timer.toc('train') - sch_auto_encoder = self.lr_schedulers() - sch_auto_encoder.step() - self.logger.experiment.add_scalar('train_time', time, self.current_epoch) - - def validation_step(self, batch, batch_idx): - pass - - def validation_epoch_end(self, outputs): - pass - - -def _compute_correct(scores, target_y): - _, pred_labels = torch.max(scores, axis=1) - correct = (pred_labels == target_y).sum().item() - return pred_labels, correct - - -class CoordTransformer(object): - def __init__(self, coord): - self.coord = coord - - def transform(self): - factor = np.max(np.max(self.coord, axis=0) - np.min(self.coord, axis=0)) - return (self.coord - np.min(self.coord, axis=0)) / factor - - -class StackClassifier(BaseModule): - def __init__(self, in_features: int, - n_classes: int = 7, - batch_size: int = 1000, - shuffle: bool = False, - hidden_dims: List[int] = [30], - architecture: str = "MLP", - sta_path: str = None, - **kwargs): - super(StackClassifier, self).__init__() - self.in_features = in_features - self.architecture = architecture - self.batch_size = batch_size - self.shuffle = shuffle - if architecture == "MLP": - self.model = StackMLPModule(in_features, n_classes, hidden_dims, **kwargs) - else: - raise NotImplementedError - self.dataset = None - self.train_dataset = None - self.val_dataset = None - self.automatic_optimization = False - self.sampler = None - self.test_prop = None - self.confusion = None - self.balanced = None - self.save_hyperparameters() - - def prepare(self, - stagate: intSTAGATE, - dataset: Data, - target_y, - test_prop: float = 0.5, - balanced: bool = True): - self.balanced = balanced - self.test_prop = test_prop - with torch.no_grad(): - representation, _ = stagate(dataset.x, dataset.edge_index) - if hasattr(dataset, "ground_truth"): - ground_truth = dataset.ground_truth - else: - ground_truth = None - if isinstance(target_y, np.ndarray): - target_y = torch.from_numpy(target_y).type(torch.LongTensor) - elif isinstance(target_y, torch.Tensor): - target_y = target_y.type(torch.LongTensor) - else: - raise TypeError("target_y must be either a torch tensor or a numpy ndarray.") - self.dataset = RepDataset(representation, target_y, ground_truth=ground_truth) - n_val = int(len(self.dataset) * test_prop) - self.train_dataset, self.val_dataset = torch.utils.data.random_split( - self.dataset, [len(self.dataset) - n_val, n_val]) - if balanced: - target_y = target_y[self.train_dataset.indices] - class_sample_count = np.array([len(np.where(target_y == t)[0]) for t in np.unique(target_y)]) - weight = 1. 
/ class_sample_count - samples_weight = np.array([weight[t] for t in target_y]) - samples_weight = torch.from_numpy(samples_weight) - samples_weight = samples_weight.double() - self.sampler = WeightedRandomSampler(samples_weight, len(samples_weight)) - - def forward(self, x, edge_index=None) -> Any: - if self.architecture == "MLP": - return self.model(x) - elif self.architecture == "STACls": - _, output = self.model(x, edge_index) - return output - - def training_step(self, batch, batch_idx): - batch = Batch(**batch) - batch = batch.to(self.device) - opt = self.optimizers() - opt.zero_grad() - output = self.model(batch.x) - loss = F.cross_entropy(output["score"], batch.y) - self.manual_backward(loss) - opt.step() - _, correct = _compute_correct(output["score"], batch.y) - self.log(f"Training {self.architecture} classifier|Cross entropy", loss.item(), prog_bar=True) - return {"loss": loss, "correct": correct} - - def training_epoch_end(self, outputs): - time = self.timer.toc('train') - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/train_time', time, self.current_epoch) - all_loss = torch.stack([x["loss"] for x in outputs]) - all_correct = np.sum([x["correct"] for x in outputs]) - train_acc = all_correct / len(self.train_dataset) - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/loss', - torch.mean(all_loss), self.current_epoch) - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/train_acc', - train_acc, self.current_epoch) - - def validation_step(self, batch, batch_idx): - batch = Batch(**batch) - batch = batch.to(self.device) - with torch.no_grad(): - output = self.model(batch.x) - loss = F.cross_entropy(output["score"], batch.y) - pred_labels, correct = _compute_correct(output["score"], batch.y) - return {"loss": loss, "correct": correct, "pred_labels": pred_labels, "true_labels": batch.y} - - def validation_epoch_end(self, outputs): - all_loss = torch.stack([x["loss"] for x in outputs]) - all_correct = np.sum([x["correct"] for x in outputs]) - pred_labels = torch.cat([x["pred_labels"] for x in outputs]).cpu().detach().numpy() - true_labels = torch.cat([x["true_labels"] for x in outputs]).cpu().detach().numpy() - confusion = confusion_matrix(true_labels, pred_labels) - val_acc = all_correct / len(self.val_dataset) - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/val_loss', - torch.mean(all_loss), self.current_epoch) - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/val_acc', - val_acc, self.current_epoch) - print("\n validation ACC={:.4f}".format(val_acc)) - self.confusion = confusion - - def train_dataloader(self): - loader = DataLoader(self.train_dataset, batch_size=self.batch_size, sampler=self.sampler) - return loader - - def val_dataloader(self): - loader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=self.shuffle, drop_last=False) - return loader - - def test_dataloader(self): - raise NotImplementedError - - def predict_dataloader(self): - raise NotImplementedError \ No newline at end of file diff --git a/stamarker/stamarker/.ipynb_checkpoints/modules-checkpoint.py b/stamarker/stamarker/.ipynb_checkpoints/modules-checkpoint.py deleted file mode 100644 index 08d9d81..0000000 --- a/stamarker/stamarker/.ipynb_checkpoints/modules-checkpoint.py +++ /dev/null @@ -1,276 +0,0 @@ -import abc -import copy -from torch.autograd import Variable -import torch -from torch import Tensor -import torch.nn.functional as F -from torch.nn import Parameter -import torch.nn as nn -from 
torch_sparse import SparseTensor, set_diag -from torch_geometric.nn.conv import MessagePassing -from torch_geometric.utils import remove_self_loops, add_self_loops, softmax -from typing import Union, Tuple, Optional -from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType, - OptTensor) - - -class GATConv(MessagePassing): - r"""The graph attentional operator from the `"Graph Attention Networks" - `_ paper - .. math:: - \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + - \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j}, - where the attention coefficients :math:`\alpha_{i,j}` are computed as - .. math:: - \alpha_{i,j} = - \frac{ - \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} - [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] - \right)\right)} - {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} - \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} - [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k] - \right)\right)}. - Args: - in_channels (int or tuple): Size of each input sample, or :obj:`-1` to - derive the size from the first input(s) to the forward method. - A tuple corresponds to the sizes of source and target - dimensionalities. - out_channels (int): Size of each output sample. - heads (int, optional): Number of multi-head-attentions. - (default: :obj:`1`) - concat (bool, optional): If set to :obj:`False`, the multi-head - attentions are averaged instead of concatenated. - (default: :obj:`True`) - negative_slope (float, optional): LeakyReLU angle of the negative - slope. (default: :obj:`0.2`) - dropout (float, optional): Dropout probability of the normalized - attention coefficients which exposes each node to a stochastically - sampled neighborhood during training. (default: :obj:`0`) - add_self_loops (bool, optional): If set to :obj:`False`, will not add - self-loops to the input graph. (default: :obj:`True`) - bias (bool, optional): If set to :obj:`False`, the layer will not learn - an additive bias. (default: :obj:`True`) - **kwargs (optional): Additional arguments of - :class:`torch_geometric.nn.conv.MessagePassing`. 
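    Note (descriptive addition; comparison with upstream PyG is an editorial observation):
        Compared with the stock PyG :class:`GATConv`, this copy shares a single
        weight matrix between the source and target transforms, accepts but does
        not use the ``bias`` argument, and its :meth:`forward` takes extra
        ``attention`` / ``tied_attention`` arguments so attention can be switched
        off or reused from another layer.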
- """ - _alpha: OptTensor - - def __init__(self, in_channels: Union[int, Tuple[int, int]], - out_channels: int, heads: int = 1, concat: bool = True, - negative_slope: float = 0.2, dropout: float = 0.0, - add_self_loops: bool = True, bias: bool = True, **kwargs): - kwargs.setdefault('aggr', 'add') - super(GATConv, self).__init__(node_dim=0, **kwargs) - - self.in_channels = in_channels - self.out_channels = out_channels - self.heads = heads - self.concat = concat - self.negative_slope = negative_slope - self.dropout = dropout - self.add_self_loops = add_self_loops - self.lin_src = nn.Parameter(torch.zeros(size=(in_channels, out_channels))) - nn.init.xavier_normal_(self.lin_src.data, gain=1.414) - self.lin_dst = self.lin_src - # The learnable parameters to compute attention coefficients: - self.att_src = Parameter(torch.Tensor(1, heads, out_channels)) - self.att_dst = Parameter(torch.Tensor(1, heads, out_channels)) - nn.init.xavier_normal_(self.att_src.data, gain=1.414) - nn.init.xavier_normal_(self.att_dst.data, gain=1.414) - self._alpha = None - self.attentions = None - - def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, - size: Size = None, return_attention_weights=None, attention=True, tied_attention=None): - # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor # noqa - # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor # noqa - # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa - # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor] # noqa - r""" - Args: - return_attention_weights (bool, optional): If set to :obj:`True`, - will additionally return the tuple - :obj:`(edge_index, attention_weights)`, holding the computed - attention weights for each edge. (default: :obj:`None`) - """ - H, C = self.heads, self.out_channels - - # We first transform the input node features. 
If a tuple is passed, we - # transform source and target node features via separate weights: - if isinstance(x, Tensor): - assert x.dim() == 2, "Static graphs not supported in 'GATConv'" - # x_src = x_dst = self.lin_src(x).view(-1, H, C) - x_src = x_dst = torch.mm(x, self.lin_src).view(-1, H, C) - else: # Tuple of source and target node features: - x_src, x_dst = x - assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'" - x_src = self.lin_src(x_src).view(-1, H, C) - if x_dst is not None: - x_dst = self.lin_dst(x_dst).view(-1, H, C) - - x = (x_src, x_dst) - - if not attention: - return x[0].mean(dim=1) - # return x[0].view(-1, self.heads * self.out_channels) - - if tied_attention == None: - # Next, we compute node-level attention coefficients, both for source - # and target nodes (if present): - alpha_src = (x_src * self.att_src).sum(dim=-1) - alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1) - alpha = (alpha_src, alpha_dst) - self.attentions = alpha - else: - alpha = tied_attention - - if self.add_self_loops: - if isinstance(edge_index, Tensor): - # We only want to add self-loops for nodes that appear both as - # source and target nodes: - num_nodes = x_src.size(0) - if x_dst is not None: - num_nodes = min(num_nodes, x_dst.size(0)) - num_nodes = min(size) if size is not None else num_nodes - edge_index, _ = remove_self_loops(edge_index) - edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) - elif isinstance(edge_index, SparseTensor): - edge_index = set_diag(edge_index) - - # propagate_type: (x: OptPairTensor, alpha: OptPairTensor) - out = self.propagate(edge_index, x=x, alpha=alpha, size=size) - - alpha = self._alpha - assert alpha is not None - self._alpha = None - - if self.concat: - out = out.view(-1, self.heads * self.out_channels) - else: - out = out.mean(dim=1) - - # if self.bias is not None: - # out += self.bias - - if isinstance(return_attention_weights, bool): - if isinstance(edge_index, Tensor): - return out, (edge_index, alpha) - elif isinstance(edge_index, SparseTensor): - return out, edge_index.set_value(alpha, layout='coo') - else: - return out - - def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor, - index: Tensor, ptr: OptTensor, - size_i: Optional[int]) -> Tensor: - # Given egel-level attention coefficients for source and target nodes, - # we simply need to sum them up to "emulate" concatenation: - alpha = alpha_j if alpha_i is None else alpha_j + alpha_i - - alpha = F.leaky_relu(alpha, self.negative_slope) - alpha = softmax(alpha, index, ptr, size_i) - self._alpha = alpha # Save for later use. 
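        # The cached coefficients are what forward() hands back when
        # `return_attention_weights` is set. The dropout below randomly zeroes
        # normalized coefficients during training, so each node aggregates over
        # a stochastically sampled neighborhood.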
- alpha = F.dropout(alpha, p=self.dropout, training=self.training) - return x_j * alpha.unsqueeze(-1) - - def __repr__(self): - return '{}({}, {}, heads={})'.format(self.__class__.__name__, - self.in_channels, - self.out_channels, self.heads) - - -class STAGATEModule(nn.Module): - def __init__(self, in_features, hidden_dims): - super(STAGATEModule, self).__init__() - [num_hidden, out_dim] = hidden_dims - self.conv1 = GATConv(in_features, num_hidden, heads=1, concat=False, - dropout=0, add_self_loops=False, bias=False) - self.conv2 = GATConv(num_hidden, out_dim, heads=1, concat=False, - dropout=0, add_self_loops=False, bias=False) - self.conv3 = GATConv(out_dim, num_hidden, heads=1, concat=False, - dropout=0, add_self_loops=False, bias=False) - self.conv4 = GATConv(num_hidden, in_features, heads=1, concat=False, - dropout=0, add_self_loops=False, bias=False) - - def forward(self, features, edge_index): - h1 = F.elu(self.conv1(features, edge_index)) - h2 = self.conv2(h1, edge_index, attention=False) - self.conv3.lin_src.data = self.conv2.lin_src.transpose(0, 1) - self.conv3.lin_dst.data = self.conv2.lin_dst.transpose(0, 1) - self.conv4.lin_src.data = self.conv1.lin_src.transpose(0, 1) - self.conv4.lin_dst.data = self.conv1.lin_dst.transpose(0, 1) - h3 = F.elu(self.conv3(h2, edge_index, attention=True, - tied_attention=self.conv1.attentions)) - h4 = self.conv4(h3, edge_index, attention=False) - - return h2, h4 - - -class StackClsModule(nn.Module, abc.ABC): - def __init__(self, in_features, n_classes): - super(StackClsModule, self).__init__() - self.in_features = in_features - self.n_classes = n_classes - - -class STAGATEClsModule(nn.Module): - def __init__(self, - stagate: STAGATEModule, - stack_classifier: StackClsModule): - super(STAGATEClsModule, self).__init__() - self.stagate = copy.deepcopy(stagate) - self.classifier = copy.deepcopy(stack_classifier) - - def forward(self, x, edge_index, mode="classifier"): - z, x_recon = self.stagate(x, edge_index) - z = torch.clone(z) - if mode == "classifier": - return z, self.classifier(z) - elif mode == "reconstruction": - return z, x_recon - else: - raise NotImplementedError - - def get_saliency_map(self, x, edge_index, target_index="max", save=None): - """ - Get saliency map by backpropagation. 
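        The scores of the chosen target class are summed over all nodes and
        back-propagated to the input, so the returned gradient has the same
        shape as ``x`` and can be read as a per-feature saliency value for
        each node.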
- :param x: input tensors - :param edge_index: graph edge index - :param target_index: target index to compute final scores - :param save: - :return: gradients - """ - x_var = Variable(x, requires_grad=True) - _, output = self.forward(x_var, edge_index, mode="classifier") - scores = output["last_layer"] - if target_index == "max": - target_score_indices = Variable(torch.argmax(scores, 1)) - elif isinstance(target_index, int): - target_score_indices = Variable(torch.ones(scores.shape[0], dtype=torch.int64) * target_index) - else: - raise NotImplementedError - target_scores = scores.gather(1, target_score_indices.view(-1, 1)).squeeze() - loss = torch.sum(target_scores) - loss.backward() - gradients = x_var.grad.data - if save is not None: - torch.save(gradients, save) - return gradients, scores - - -class StackMLPModule(StackClsModule): - name = "StackMLP" - - def __init__(self, in_features, n_classes, hidden_dims=[30, 40, 30]): - super(StackMLPModule, self).__init__(in_features, n_classes) - self.classifier = nn.ModuleList() - mlp_dims = [in_features] + hidden_dims + [n_classes] - for ind in range(len(mlp_dims) - 1): - self.classifier.append(nn.Linear(mlp_dims[ind], mlp_dims[ind + 1])) - - def forward(self, x): - for layer in self.classifier: - x = layer(x) - score = F.softmax(x, dim=0) - return {"last_layer": x, "score": score} diff --git a/stamarker/stamarker/.ipynb_checkpoints/pipeline-checkpoint.py b/stamarker/stamarker/.ipynb_checkpoints/pipeline-checkpoint.py deleted file mode 100644 index c24592a..0000000 --- a/stamarker/stamarker/.ipynb_checkpoints/pipeline-checkpoint.py +++ /dev/null @@ -1,287 +0,0 @@ -import pytorch_lightning as pl -import copy -import torch -import os -import shutil -import logging -import glob -import sys -import numpy as np -import scipy -import scanpy as sc -from pytorch_lightning.loggers import TensorBoardLogger -from scipy.cluster import hierarchy -from .models import intSTAGATE, StackClassifier -from .utils import plot_consensus_map, consensus_matrix, Timer -from .dataset import SpatialDataModule -from .modules import STAGATEClsModule -import logging - - -FORMAT = "%(asctime)s %(levelname)s %(message)s" -logging.basicConfig(format=FORMAT, datefmt='%Y-%m-%d %H:%M:%S') -def make_spatial_data(ann_data): - """ - Make SpatialDataModule object from Scanpy annData object - """ - data_module = SpatialDataModule() - ann_data.X = ann_data.X.toarray() - data_module.ann_data = ann_data - return data_module - - -class STAMarker: - def __init__(self, n, save_dir, config, logging_level=logging.INFO): - """ - n: int, number of graph attention auto-econders to train - save_dir: directory to save the models - config: config file for training - """ - self.n = n - self.save_dir = save_dir - if not os.path.exists(save_dir): - os.mkdir(save_dir) - logging.info("Create save directory {}".format(save_dir)) - self.version_dirs = [os.path.join(save_dir, f"version_{i}") for i in range(n)] - self.config = config - self.logger = logging.getLogger("STAMarker") - self.logger.setLevel(logging_level) - self.consensus_labels = None - - def load_from_dir(self, save_dir, ): - """ - Load the trained models from a directory - """ - self.version_dirs = glob.glob(os.path.join(save_dir, "version_*")) - self.version_dirs = sorted(self.version_dirs, key=lambda x: int(x.split("_")[-1])) - # check if all version dir have `checkpoints/stagate.ckpt` - version_dirs_valid = [] - for version_dir in self.version_dirs: - if not os.path.exists(os.path.join(version_dir, "checkpoints/stagate.ckpt")): - 
self.logger.warning("No checkpoint found in {}".format(version_dir)) - else: - version_dirs_valid.append(version_dir) - self.version_dirs = version_dirs_valid - self.logger.info("Load {} autoencoder models from {}".format(len(version_dirs_valid), save_dir)) - # check if all version dir have `cluster_labels.npy` raise warning if not - missing_cluster_labels = False - for version_dir in self.version_dirs: - if not os.path.exists(os.path.join(version_dir, "cluster_labels.npy")): - missing_cluster_labels = True - msg = "No cluster labels found in {}.".format(version_dir) - self.logger.warning(msg) - if missing_cluster_labels: - self.logger.warning("Please run clustering first.") - # check if save_dir has `consensus.npy` raise warning if not - if not os.path.exists(os.path.join(save_dir, "consensus.npy")): - self.logger.warning("No consensus labels found in {}".format(save_dir)) - else: - self.consensus_labels = np.load(os.path.join(save_dir, "consensus.npy")) - # check if all version dir have `checkpoints/mlp.ckpt` raise warning if not - missing_clf = False - for version_dir in self.version_dirs: - if not os.path.exists(os.path.join(version_dir, "checkpoints/mlp.ckpt")): - self.logger.warning("No classifier checkpoint found in {}".format(version_dir)) - missing_clf = True - if missing_clf: - self.logger.warning("Please run classifier training first.") - if not missing_cluster_labels and not missing_clf: - self.logger.info("All models are trained and ready to use.") - - def train_auto_encoders(self, data_module): - for seed in range(self.n): - self._train_auto_encoder(data_module, seed, self.config) - self.logger.info("Finished training {} auto-encoders".format(self.n)) - - def clustering(self, data_module, cluster_method, cluster_params): - """ - Cluster the latent space of the trained auto-encoders - Cluster method should be "louvain" or "mclust" - """ - for version_dir in self.version_dirs: - self._clustering(data_module, version_dir, cluster_method, cluster_params) - self.logger.info("Finished {} clustering with {}".format(self.n, cluster_method)) - - def consensus_clustering(self, n_clusters, name="cluster_labels.npy"): - sys.setrecursionlimit(100000) - label_files = glob.glob(self.save_dir + f"/version_*/{name}") - labels_list = list(map(lambda file: np.load(file), label_files)) - cons_mat = consensus_matrix(labels_list) - row_linkage, _, figure = plot_consensus_map(cons_mat, return_linkage=True) - figure.savefig(os.path.join(self.save_dir, "consensus_clustering.png"), dpi=300) - consensus_labels = hierarchy.cut_tree(row_linkage, n_clusters).squeeze() - np.save(os.path.join(self.save_dir, "consensus"), consensus_labels) - self.consensus_labels = consensus_labels - self.logger.info("Save consensus labels to {}".format(os.path.join(self.save_dir, "consensus.npz"))) - - def train_classifiers(self, data_module, n_clusters, name="cluster_labels.npy"): - for i, version_dir in enumerate(self.version_dirs): - # _train_classifier(self, data_module, version_dir, target_y, n_classes, seed=None) - self._train_classifier(data_module, version_dir, self.consensus_labels, - n_clusters, self.config, seed=i) - self.logger.info("Finished training {} classifiers".format(self.n)) - - def compute_smaps(self, data_module, return_recon=True, normalize=True): - smaps = [] - if return_recon: - recons = [] - for version_dir in self.version_dirs: - if return_recon: - smap, recon = self._compute_smap(data_module, version_dir, return_recon=return_recon) - smaps.append(smap) - recons.append(recon) - else: - smap = 
self._compute_smap(data_module, version_dir, return_recon=return_recon) - smaps.append(smap) - if return_recon: - return smaps, recons - else: - return smaps - self.logger.info("Finished computing {} smaps".format(self.n)) - - - def _compute_smap_zscore(self, smap, labels, logtransform=False): - scores = np.log(smap + 1) if logtransform else copy.copy(smap) - unique_labels = np.unique(labels) - for l in unique_labels: - scores[labels == l, :] = scipy.stats.zscore(scores[labels == l, :], axis=1) - return scores - - - def _clustering(self, data_module, version_dir, cluster_method, cluster_params): - """ - Cluster the latent space of the trained auto-encoder - """ - if cluster_method == "louvain": - run_louvain(data_module, version_dir, cluster_params) - elif cluster_method == "mclust": - run_mclust(data_module, version_dir, cluster_params) - else: - raise ValueError("Unknown clustering method") - - def _train_auto_encoder(self, data_module, seed, config): - """ - Train a single graph attention auto-encoder - """ - pl.seed_everything(seed) - version = f"version_{seed}" - version_dir = os.path.join(self.save_dir, version) - if os.path.exists(version_dir): - shutil.rmtree(version_dir) - os.makedirs(version_dir, exist_ok=True) - logger = TensorBoardLogger(save_dir=self.save_dir, name=None, - default_hp_metric=False, - version=seed) - model = intSTAGATE(**config["stagate"]["params"]) - model.set_optimizer_params(config["stagate"]["optimizer"], config["stagate"]["scheduler"]) - trainer = pl.Trainer(logger=logger, **config["stagate_trainer"]) - timer = Timer() - timer.tic("fit") - trainer.fit(model, data_module) - fit_time = timer.toc("fit") - with open(os.path.join(version_dir, "runtime.csv"), "w+") as f: - f.write("{}, fit_time, {:.2f}, ".format(seed, fit_time / 60)) - trainer.save_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - del model, trainer - if config["stagate_trainer"]["gpus"] > 0: - torch.cuda.empty_cache() - logging.info(f"Finshed running version {seed}") - - def _train_classifier(self, data_module, version_dir, target_y, n_classes, config, seed=None): - timer = Timer() - pl.seed_everything(seed) - rep_dim = config["stagate"]["params"]["hidden_dims"][-1] - stagate = intSTAGATE.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - classifier = StackClassifier(rep_dim, n_classes=n_classes, architecture="MLP") - classifier.prepare(stagate, data_module.train_dataset, target_y, - balanced=config["mlp"]["balanced"], test_prop=config["mlp"]["test_prop"]) - classifier.set_optimizer_params(config["mlp"]["optimizer"], config["mlp"]["scheduler"]) - logger = TensorBoardLogger(save_dir=self.save_dir, name=None, - default_hp_metric=False, - version=seed) - trainer = pl.Trainer(logger=logger, **config["classifier_trainer"]) - timer.tic("clf") - trainer.fit(classifier) - clf_time = timer.toc("clf") - with open(os.path.join(version_dir, "runtime.csv"), "a+") as f: - f.write("\n") - f.write("{}, clf_time, {:.2f}, ".format(seed, clf_time / 60)) - trainer.save_checkpoint(os.path.join(version_dir, "checkpoints", "mlp.ckpt")) - target_y = classifier.dataset.target_y.numpy() - all_props = class_proportions(target_y) - val_props = class_proportions(target_y[classifier.val_dataset.indices]) - if self.logger.level == logging.DEBUG: - print("All class proportions " + "|".join(["{:.2f}%".format(prop * 100) for prop in all_props])) - print("Val class proportions " + "|".join(["{:.2f}%".format(prop * 100) for prop in val_props])) - np.save(os.path.join(version_dir, 
"confusion.npy"), classifier.confusion) - - def _compute_smap(self, data_module, version_dir, return_recon=True): - """ - Compute the saliency map of the trained auto-encoder - """ - stagate = intSTAGATE.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - cls = StackClassifier.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "mlp.ckpt")) - stagate_cls = STAGATEClsModule(stagate.model, cls.model) - smap, _ = stagate_cls.get_saliency_map(data_module.train_dataset.x, - data_module.train_dataset.edge_index) - smap = smap.detach().cpu().numpy() - if return_recon: - recon = stagate(data_module.train_dataset.x, data_module.train_dataset.edge_index)[1].cpu().detach().numpy() - return smap, recon - else: - return smap - - -def run_louvain(data_module, version_dir, resolution, name="cluster_labels"): - """ - Run louvain clustering on the data_module - """ - stagate = intSTAGATE.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - embedding = stagate(data_module.train_dataset.x, data_module.train_dataset.edge_index)[0].cpu().detach().numpy() - ann_data = copy.copy(data_module.ann_data) - ann_data.obsm["stagate"] = embedding - sc.pp.neighbors(ann_data, use_rep='stagate') - sc.tl.louvain(ann_data, resolution=resolution) - save_path = os.path.join(version_dir, "{}.npy".format(name)) - np.save(save_path, ann_data.obs["louvain"].to_numpy().astype("int")) - print("Save louvain results to {}".format(save_path)) - - -def mclust_R(representation, n_clusters, r_seed=2022, model_name="EEE"): - """ - Clustering using the mclust algorithm. - The parameters are the same as those in the R package mclust. - """ - np.random.seed(r_seed) - import rpy2.robjects as ro - from rpy2.robjects import numpy2ri - numpy2ri.activate() - ro.r.library("mclust") - r_random_seed = ro.r['set.seed'] - r_random_seed(r_seed) - rmclust = ro.r['Mclust'] - res = rmclust(representation, n_clusters, model_name) - mclust_res = np.array(res[-2]) - numpy2ri.deactivate() - return mclust_res.astype('int') - - -def run_mclust(data_module, version_dir, n_clusters, name="cluster_labels"): - stagate = intSTAGATE.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - embedding = stagate(data_module.train_dataset.x, data_module.train_dataset.edge_index)[0].cpu().detach().numpy() - labels = mclust_R(embedding, n_clusters) - save_path = os.path.join(version_dir, "{}.npy".format(name)) - np.save(save_path, labels.astype("int")) - print("Save MClust results to {}".format(save_path)) - -def class_proportions(target): - n_classes = len(np.unique(target)) - props = np.array([np.sum(target == i) for i in range(n_classes)]) - return props / np.sum(props) - - - - - - - diff --git a/stamarker/stamarker/.ipynb_checkpoints/utils-checkpoint.py b/stamarker/stamarker/.ipynb_checkpoints/utils-checkpoint.py deleted file mode 100644 index 10a4c14..0000000 --- a/stamarker/stamarker/.ipynb_checkpoints/utils-checkpoint.py +++ /dev/null @@ -1,192 +0,0 @@ -import time -import yaml -import os -import seaborn as sns -import numpy as np -import pandas as pd -import scanpy as sc -import itertools -import scipy -from scipy.spatial import distance -from scipy.cluster import hierarchy -import sklearn.neighbors -from typing import List - - -def plot_consensus_map(cmat, method="average", return_linkage=True, **kwargs): - row_linkage = hierarchy.linkage(distance.pdist(cmat), method=method) - col_linkage = hierarchy.linkage(distance.pdist(cmat.T), method=method) - figure = 
sns.clustermap(cmat, row_linkage=row_linkage, col_linkage=col_linkage, **kwargs) - if return_linkage: - return row_linkage, col_linkage, figure - else: - return figure - - -class Timer: - - def __init__(self): - self.timer_dict = {} - self.stop_dict = {} - - def tic(self, name): - self.timer_dict[name] = time.time() - - def toc(self, name): - assert name in self.timer_dict - elapsed = time.time() - self.timer_dict[name] - del self.timer_dict[name] - return elapsed - - def stop(self, name): - self.stop_dict[name] = time.time() - - def resume(self, name): - if name not in self.timer_dict: - del self.stop_dict[name] - return - elapsed = time.time() - self.stop_dict[name] - self.timer_dict[name] = self.timer_dict[name] + elapsed - del self.stop_dict[name] - - -def save_yaml(yaml_object, file_path): - with open(file_path, 'w') as yaml_file: - yaml.dump(yaml_object, yaml_file, default_flow_style=False) - - print(f'Saving yaml: {file_path}') - return - - -def parse_args(yaml_file): - with open(yaml_file, 'r') as stream: - try: - cfg = yaml.safe_load(stream) - except yaml.YAMLError as exc: - print(exc) - return cfg - - -def mclust_R(representation, n_clusters, r_seed=2022, model_name="EEE"): - """ - Clustering using the mclust algorithm. - The parameters are the same as those in the R package mclust. - """ - np.random.seed(r_seed) - import rpy2.robjects as ro - from rpy2.robjects import numpy2ri - numpy2ri.activate() - ro.r.library("mclust") - r_random_seed = ro.r['set.seed'] - r_random_seed(r_seed) - rmclust = ro.r['Mclust'] - res = rmclust(representation, n_clusters, model_name) - mclust_res = np.array(res[-2]) - numpy2ri.deactivate() - return mclust_res.astype('int') - - -def labels_connectivity_mat(labels: np.ndarray): - _labels = labels - np.min(labels) - n_classes = np.unique(_labels) - mat = np.zeros([labels.size, labels.size]) - for i in n_classes: - indices = np.squeeze(np.where(_labels == i)) - row_indices, col_indices = zip(*itertools.product(indices, indices)) - mat[row_indices, col_indices] = 1 - return mat - - -def consensus_matrix(labels_list: List[np.ndarray]): - mat = 0 - for labels in labels_list: - mat += labels_connectivity_mat(labels) - return mat / float(len(labels_list)) - - -def compute_spatial_net(ann_data, rad_cutoff=None, k_cutoff=None, - max_neigh=50, model='Radius', verbose=True): - """ - Construct the spatial neighbor networks. - - Parameters - ---------- - ann_data - AnnData object of scanpy package. - rad_cutoff - radius cutoff when model='Radius' - k_cutoff - The number of nearest neighbors when model='KNN' - model - The network construction model. When model=='Radius', the spot is connected to spots whose distance is less than rad_cutoff. When model=='KNN', the spot is connected to its first k_cutoff nearest neighbors. 
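    max_neigh
        The number of nearest neighbors queried from the ball-tree search before
        the radius or KNN cutoff is applied.
    verbose
        Whether to print progress messages and summary statistics of the
        constructed graph.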
- - Returns - ------- - The spatial networks are saved in adata.uns['Spatial_Net'] - """ - - assert (model in ['Radius', 'KNN']) - if verbose: - print('------Calculating spatial graph...') - coor = pd.DataFrame(ann_data.obsm['spatial']) - coor.index = ann_data.obs.index - coor.columns = ['imagerow', 'imagecol'] - - nbrs = sklearn.neighbors.NearestNeighbors( - n_neighbors=max_neigh + 1, algorithm='ball_tree').fit(coor) - distances, indices = nbrs.kneighbors(coor) - if model == 'KNN': - indices = indices[:, 1:k_cutoff + 1] - distances = distances[:, 1:k_cutoff + 1] - if model == 'Radius': - indices = indices[:, 1:] - distances = distances[:, 1:] - KNN_list = [] - for it in range(indices.shape[0]): - KNN_list.append(pd.DataFrame(zip([it] * indices.shape[1], indices[it, :], distances[it, :]))) - KNN_df = pd.concat(KNN_list) - KNN_df.columns = ['Cell1', 'Cell2', 'Distance'] - Spatial_Net = KNN_df.copy() - if model == 'Radius': - Spatial_Net = KNN_df.loc[KNN_df['Distance'] < rad_cutoff,] - id_cell_trans = dict(zip(range(coor.shape[0]), np.array(coor.index), )) - cell1, cell2 = Spatial_Net['Cell1'].map(id_cell_trans), Spatial_Net['Cell2'].map(id_cell_trans) - Spatial_Net = Spatial_Net.assign(Cell1=cell1, Cell2=cell2) - # Spatial_Net.assign(Cell1=Spatial_Net['Cell1'].map(id_cell_trans)) - # Spatial_Net.assign(Cell2=Spatial_Net['Cell2'].map(id_cell_trans)) - if verbose: - print('The graph contains %d edges, %d cells.' % (Spatial_Net.shape[0], ann_data.n_obs)) - print('%.4f neighbors per cell on average.' % (Spatial_Net.shape[0] / ann_data.n_obs)) - ann_data.uns['Spatial_Net'] = Spatial_Net - - -def compute_edge_list(ann_data): - G_df = ann_data.uns['Spatial_Net'].copy() - cells = np.array(ann_data.obs_names) - cells_id_tran = dict(zip(cells, range(cells.shape[0]))) - G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran) - G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran) - G = scipy.sparse.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])), - shape=(ann_data.n_obs, ann_data.n_obs)) - G = G + scipy.sparse.eye(G.shape[0]) - edge_list = np.nonzero(G) - return edge_list - - -def stats_spatial_net(ann_data): - import matplotlib.pyplot as plt - Num_edge = ann_data.uns['Spatial_Net']['Cell1'].shape[0] - Mean_edge = Num_edge / ann_data.shape[0] - plot_df = pd.value_counts(pd.value_counts(ann_data.uns['Spatial_Net']['Cell1'])) - plot_df = plot_df / ann_data.shape[0] - fig, ax = plt.subplots(figsize=[3, 2]) - plt.ylabel('Percentage') - plt.xlabel('') - plt.title('Number of Neighbors (Mean=%.2f)' % Mean_edge) - ax.bar(plot_df.index, plot_df) - - -def select_stmaker_svgs(df, sd_id, alpha=1.5, top=None): - scores = df[f"score_{sd_id}"] - mu, std = np.mean(scores), np.std(scores) - return df.index[scores > mu + alpha * std].tolist() diff --git a/stamarker/stamarker/__init__.py b/stamarker/stamarker/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/stamarker/stamarker/__pycache__/__init__.cpython-38.pyc b/stamarker/stamarker/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 275509f..0000000 Binary files a/stamarker/stamarker/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/stamarker/stamarker/__pycache__/__init__.cpython-39.pyc b/stamarker/stamarker/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index ab20350..0000000 Binary files a/stamarker/stamarker/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/stamarker/stamarker/__pycache__/dataset.cpython-38.pyc 
b/stamarker/stamarker/__pycache__/dataset.cpython-38.pyc deleted file mode 100644 index e38dc8d..0000000 Binary files a/stamarker/stamarker/__pycache__/dataset.cpython-38.pyc and /dev/null differ diff --git a/stamarker/stamarker/__pycache__/dataset.cpython-39.pyc b/stamarker/stamarker/__pycache__/dataset.cpython-39.pyc deleted file mode 100644 index 4b46d88..0000000 Binary files a/stamarker/stamarker/__pycache__/dataset.cpython-39.pyc and /dev/null differ diff --git a/stamarker/stamarker/__pycache__/models.cpython-38.pyc b/stamarker/stamarker/__pycache__/models.cpython-38.pyc deleted file mode 100644 index dd9d50b..0000000 Binary files a/stamarker/stamarker/__pycache__/models.cpython-38.pyc and /dev/null differ diff --git a/stamarker/stamarker/__pycache__/models.cpython-39.pyc b/stamarker/stamarker/__pycache__/models.cpython-39.pyc deleted file mode 100644 index 4ec959a..0000000 Binary files a/stamarker/stamarker/__pycache__/models.cpython-39.pyc and /dev/null differ diff --git a/stamarker/stamarker/__pycache__/modules.cpython-38.pyc b/stamarker/stamarker/__pycache__/modules.cpython-38.pyc deleted file mode 100644 index 39452b5..0000000 Binary files a/stamarker/stamarker/__pycache__/modules.cpython-38.pyc and /dev/null differ diff --git a/stamarker/stamarker/__pycache__/modules.cpython-39.pyc b/stamarker/stamarker/__pycache__/modules.cpython-39.pyc deleted file mode 100644 index 923b393..0000000 Binary files a/stamarker/stamarker/__pycache__/modules.cpython-39.pyc and /dev/null differ diff --git a/stamarker/stamarker/__pycache__/pipeline.cpython-38.pyc b/stamarker/stamarker/__pycache__/pipeline.cpython-38.pyc deleted file mode 100644 index 1931108..0000000 Binary files a/stamarker/stamarker/__pycache__/pipeline.cpython-38.pyc and /dev/null differ diff --git a/stamarker/stamarker/__pycache__/pipeline.cpython-39.pyc b/stamarker/stamarker/__pycache__/pipeline.cpython-39.pyc deleted file mode 100644 index 0dd6347..0000000 Binary files a/stamarker/stamarker/__pycache__/pipeline.cpython-39.pyc and /dev/null differ diff --git a/stamarker/stamarker/__pycache__/utils.cpython-38.pyc b/stamarker/stamarker/__pycache__/utils.cpython-38.pyc deleted file mode 100644 index c1e4cfd..0000000 Binary files a/stamarker/stamarker/__pycache__/utils.cpython-38.pyc and /dev/null differ diff --git a/stamarker/stamarker/__pycache__/utils.cpython-39.pyc b/stamarker/stamarker/__pycache__/utils.cpython-39.pyc deleted file mode 100644 index 44823db..0000000 Binary files a/stamarker/stamarker/__pycache__/utils.cpython-39.pyc and /dev/null differ diff --git a/stamarker/stamarker/dataset.py b/stamarker/stamarker/dataset.py deleted file mode 100644 index 14398d6..0000000 --- a/stamarker/stamarker/dataset.py +++ /dev/null @@ -1,111 +0,0 @@ -from typing import List -import scanpy as sc -import numpy as np -import torch -import pytorch_lightning as pl -from torch.utils.data import Dataset -from pytorch_lightning.utilities.types import EVAL_DATALOADERS -from torch_geometric.loader import NeighborLoader -from torch_geometric.data import Data -from .utils import compute_spatial_net, stats_spatial_net, compute_edge_list - - -class Batch(dict): - __getattr__ = dict.get - __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ - - def to(self, device): - res = dict() - for key, value in self.items(): - if hasattr(value, "to"): - res[key] = value.to(device) - else: - res[key] = value - return Batch(**res) - - -class RepDataset(Dataset): - def __init__(self, - x, - target_y, - ground_truth=None): - assert (len(x) == 
len(target_y)) - self.x = x - self.target_y = target_y - self.ground_truth = ground_truth - - def __len__(self): - return len(self.x) - - def __getitem__(self, idx): - if torch.is_tensor(idx): - idx = idx.tolist() - sample_x, sample_y = self.x[idx, :], self.target_y[idx] - if self.ground_truth is not None: - sample_gt = self.ground_truth[idx] - else: - sample_gt = np.nan - sample = {"x": sample_x, "y": sample_y, "ground_truth": sample_gt} - return sample - - -class SpatialDataModule(pl.LightningDataModule): - def __init__(self, - full_batch: bool = True, - batch_size: int = 1000, - num_neighbors: List[int] = None, - num_workers=None, - pin_memory=False) -> None: - self.batch_size = batch_size - self.full_batch = full_batch - self.has_y = False - self.train_dataset = None - self.valid_dataset = None - self.num_neighbors = num_neighbors - self.num_workers = num_workers - self.pin_memory = pin_memory - self.ann_data = None - - def prepare_data(self, n_top_genes: int = 3000, rad_cutoff: float = 50, - show_net_stats: bool = False, min_cells=50, min_counts=None) -> None: - sc.pp.calculate_qc_metrics(self.ann_data, inplace=True) - sc.pp.filter_genes(self.ann_data, min_cells=min_cells) - if min_counts is not None: - sc.pp.filter_cells(self.ann_data, min_counts=min_counts) - print("After filtering: ", self.ann_data.shape) - # Normalization - sc.pp.highly_variable_genes(self.ann_data, flavor="seurat_v3", n_top_genes=n_top_genes) - self.ann_data = self.ann_data[:, self.ann_data.var['highly_variable']] - sc.pp.normalize_total(self.ann_data, target_sum=1e4) - sc.pp.log1p(self.ann_data) - compute_spatial_net(self.ann_data, rad_cutoff=rad_cutoff) - if show_net_stats: - stats_spatial_net(self.ann_data) - # ---------------------------- generate data --------------------- - edge_list = compute_edge_list(self.ann_data) - self.train_dataset = Data(edge_index=torch.LongTensor(np.array([edge_list[0], edge_list[1]])), - x=torch.FloatTensor(self.ann_data.X), - y=None) - - def train_dataloader(self): - if self.full_batch: - loader = NeighborLoader(self.train_dataset, num_neighbors=[1], - batch_size=len(self.train_dataset.x)) - else: - loader = NeighborLoader(self.train_dataset, num_neighbors=self.num_neighbors, batch_size=self.batch_size) - return loader - - def val_dataloader(self): - if self.valid_dataset is None: - loader = NeighborLoader(self.train_dataset, num_neighbors=[1], - batch_size=len(self.train_dataset.x)) - else: - raise NotImplementedError - return loader - - def test_dataloader(self) -> EVAL_DATALOADERS: - raise NotImplementedError - - def predict_dataloader(self) -> EVAL_DATALOADERS: - raise NotImplementedError diff --git a/stamarker/stamarker/models.py b/stamarker/stamarker/models.py deleted file mode 100644 index d7457ec..0000000 --- a/stamarker/stamarker/models.py +++ /dev/null @@ -1,257 +0,0 @@ -from abc import ABC -from typing import Any, List -import numpy as np -import pytorch_lightning as pl -import torch -import torch.nn.functional as F -from torch.utils.data import DataLoader, WeightedRandomSampler -from torch_geometric.data import Data -from sklearn.metrics import adjusted_rand_score, confusion_matrix -from .modules import STAGATEModule, StackMLPModule -from .dataset import RepDataset, Batch -from .utils import Timer - -def get_optimizer(name): - if name == "ADAM": - return torch.optim.Adam - elif name == "ADAGRAD": - return torch.optim.Adagrad - elif name == "ADADELTA": - return torch.optim.Adadelta - elif name == "RMS": - return torch.optim.RMSprop - elif name == "ASGD": - return 
torch.optim.ASGD - else: - raise NotImplementedError - - -def get_scheduler(name): - if name == "STEP_LR": - return torch.optim.lr_scheduler.StepLR - elif name == "EXP_LR": - return torch.optim.lr_scheduler.ExponentialLR - else: - raise NotImplementedError - -class BaseModule(pl.LightningModule, ABC): - def __init__(self): - super(BaseModule, self).__init__() - self.optimizer_params = None - self.scheduler_params = None - self.model = None - self.timer = Timer() - self.automatic_optimization = False - - def set_optimizer_params(self, - optimizer_params: dict, - scheduler_params: dict): - self.optimizer_params = optimizer_params - self.scheduler_params = scheduler_params - - def configure_optimizers(self): - optimizer = get_optimizer(self.optimizer_params["name"])( - self.model.parameters(), - **self.optimizer_params["params"]) - scheduler = get_scheduler(self.scheduler_params["name"])(optimizer, **self.scheduler_params["params"]) - return [optimizer], [scheduler] - - def on_train_epoch_start(self) -> None: - self.timer.tic('train') - - -class intSTAGATE(BaseModule): - """ - intSTAGATE Lightning Module - """ - def __init__(self, - in_features: int = None, - hidden_dims: List[int] = None, - gradient_clipping: float = 5.0, - **kwargs): - super(intSTAGATE, self).__init__() - self.model = STAGATEModule(in_features, hidden_dims) - self.auto_encoder_epochs = None - self.gradient_clipping = gradient_clipping - self.pred_labels = None - self.save_hyperparameters() - - def configure_optimizers(self) -> (dict, dict): - auto_encoder_optimizer = get_optimizer(self.optimizer_params["name"])( - list(self.model.parameters()), - **self.optimizer_params["params"]) - auto_encoder_scheduler = get_scheduler(self.scheduler_params["name"])(auto_encoder_optimizer, - **self.scheduler_params["params"]) - return [auto_encoder_optimizer], [auto_encoder_scheduler] - - def forward(self, x, edge_index) -> Any: - return self.model(x, edge_index) - - def training_step(self, batch, batch_idx): - batch = batch.to(self.device) - opt_auto_encoder = self.optimizers() - z, x_hat = self.model(batch.x, batch.edge_index) - loss = F.mse_loss(batch.x, x_hat) - opt_auto_encoder.zero_grad() - self.manual_backward(loss) - torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clipping) - opt_auto_encoder.step() - self.log("Training auto-encoder|Reconstruction errors", loss.item(), prog_bar=True) - self.logger.experiment.add_scalar('auto_encoder/loss', loss.item(), self.current_epoch) - - def on_train_epoch_end(self) -> None: - time = self.timer.toc('train') - sch_auto_encoder = self.lr_schedulers() - sch_auto_encoder.step() - self.logger.experiment.add_scalar('train_time', time, self.current_epoch) - - def validation_step(self, batch, batch_idx): - pass - - def validation_epoch_end(self, outputs): - pass - - -def _compute_correct(scores, target_y): - _, pred_labels = torch.max(scores, axis=1) - correct = (pred_labels == target_y).sum().item() - return pred_labels, correct - - -class CoordTransformer(object): - def __init__(self, coord): - self.coord = coord - - def transform(self): - factor = np.max(np.max(self.coord, axis=0) - np.min(self.coord, axis=0)) - return (self.coord - np.min(self.coord, axis=0)) / factor - - -class StackClassifier(BaseModule): - def __init__(self, in_features: int, - n_classes: int = 7, - batch_size: int = 1000, - shuffle: bool = False, - hidden_dims: List[int] = [30], - architecture: str = "MLP", - sta_path: str = None, - **kwargs): - super(StackClassifier, self).__init__() - self.in_features 
= in_features - self.architecture = architecture - self.batch_size = batch_size - self.shuffle = shuffle - if architecture == "MLP": - self.model = StackMLPModule(in_features, n_classes, hidden_dims, **kwargs) - else: - raise NotImplementedError - self.dataset = None - self.train_dataset = None - self.val_dataset = None - self.automatic_optimization = False - self.sampler = None - self.test_prop = None - self.confusion = None - self.balanced = None - self.save_hyperparameters() - - def prepare(self, - stagate: intSTAGATE, - dataset: Data, - target_y, - test_prop: float = 0.5, - balanced: bool = True): - self.balanced = balanced - self.test_prop = test_prop - with torch.no_grad(): - representation, _ = stagate(dataset.x, dataset.edge_index) - if hasattr(dataset, "ground_truth"): - ground_truth = dataset.ground_truth - else: - ground_truth = None - if isinstance(target_y, np.ndarray): - target_y = torch.from_numpy(target_y).type(torch.LongTensor) - elif isinstance(target_y, torch.Tensor): - target_y = target_y.type(torch.LongTensor) - else: - raise TypeError("target_y must be either a torch tensor or a numpy ndarray.") - self.dataset = RepDataset(representation, target_y, ground_truth=ground_truth) - n_val = int(len(self.dataset) * test_prop) - self.train_dataset, self.val_dataset = torch.utils.data.random_split( - self.dataset, [len(self.dataset) - n_val, n_val]) - if balanced: - target_y = target_y[self.train_dataset.indices] - class_sample_count = np.array([len(np.where(target_y == t)[0]) for t in np.unique(target_y)]) - weight = 1. / class_sample_count - samples_weight = np.array([weight[t] for t in target_y]) - samples_weight = torch.from_numpy(samples_weight) - samples_weight = samples_weight.double() - self.sampler = WeightedRandomSampler(samples_weight, len(samples_weight)) - - def forward(self, x, edge_index=None) -> Any: - if self.architecture == "MLP": - return self.model(x) - elif self.architecture == "STACls": - _, output = self.model(x, edge_index) - return output - - def training_step(self, batch, batch_idx): - batch = Batch(**batch) - batch = batch.to(self.device) - opt = self.optimizers() - opt.zero_grad() - output = self.model(batch.x) - loss = F.cross_entropy(output["score"], batch.y) - self.manual_backward(loss) - opt.step() - _, correct = _compute_correct(output["score"], batch.y) - self.log(f"Training {self.architecture} classifier|Cross entropy", loss.item(), prog_bar=True) - return {"loss": loss, "correct": correct} - - def training_epoch_end(self, outputs): - time = self.timer.toc('train') - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/train_time', time, self.current_epoch) - all_loss = torch.stack([x["loss"] for x in outputs]) - all_correct = np.sum([x["correct"] for x in outputs]) - train_acc = all_correct / len(self.train_dataset) - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/loss', - torch.mean(all_loss), self.current_epoch) - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/train_acc', - train_acc, self.current_epoch) - - def validation_step(self, batch, batch_idx): - batch = Batch(**batch) - batch = batch.to(self.device) - with torch.no_grad(): - output = self.model(batch.x) - loss = F.cross_entropy(output["score"], batch.y) - pred_labels, correct = _compute_correct(output["score"], batch.y) - return {"loss": loss, "correct": correct, "pred_labels": pred_labels, "true_labels": batch.y} - - def validation_epoch_end(self, outputs): - all_loss = torch.stack([x["loss"] for x in outputs]) - all_correct = 
np.sum([x["correct"] for x in outputs]) - pred_labels = torch.cat([x["pred_labels"] for x in outputs]).cpu().detach().numpy() - true_labels = torch.cat([x["true_labels"] for x in outputs]).cpu().detach().numpy() - confusion = confusion_matrix(true_labels, pred_labels) - val_acc = all_correct / len(self.val_dataset) - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/val_loss', - torch.mean(all_loss), self.current_epoch) - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/val_acc', - val_acc, self.current_epoch) - print("\n validation ACC={:.4f}".format(val_acc)) - self.confusion = confusion - - def train_dataloader(self): - loader = DataLoader(self.train_dataset, batch_size=self.batch_size, sampler=self.sampler) - return loader - - def val_dataloader(self): - loader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=self.shuffle, drop_last=False) - return loader - - def test_dataloader(self): - raise NotImplementedError - - def predict_dataloader(self): - raise NotImplementedError \ No newline at end of file diff --git a/stamarker/stamarker/modules.py b/stamarker/stamarker/modules.py deleted file mode 100644 index 08d9d81..0000000 --- a/stamarker/stamarker/modules.py +++ /dev/null @@ -1,276 +0,0 @@ -import abc -import copy -from torch.autograd import Variable -import torch -from torch import Tensor -import torch.nn.functional as F -from torch.nn import Parameter -import torch.nn as nn -from torch_sparse import SparseTensor, set_diag -from torch_geometric.nn.conv import MessagePassing -from torch_geometric.utils import remove_self_loops, add_self_loops, softmax -from typing import Union, Tuple, Optional -from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType, - OptTensor) - - -class GATConv(MessagePassing): - r"""The graph attentional operator from the `"Graph Attention Networks" - `_ paper - .. math:: - \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + - \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j}, - where the attention coefficients :math:`\alpha_{i,j}` are computed as - .. math:: - \alpha_{i,j} = - \frac{ - \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} - [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] - \right)\right)} - {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} - \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} - [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k] - \right)\right)}. - Args: - in_channels (int or tuple): Size of each input sample, or :obj:`-1` to - derive the size from the first input(s) to the forward method. - A tuple corresponds to the sizes of source and target - dimensionalities. - out_channels (int): Size of each output sample. - heads (int, optional): Number of multi-head-attentions. - (default: :obj:`1`) - concat (bool, optional): If set to :obj:`False`, the multi-head - attentions are averaged instead of concatenated. - (default: :obj:`True`) - negative_slope (float, optional): LeakyReLU angle of the negative - slope. (default: :obj:`0.2`) - dropout (float, optional): Dropout probability of the normalized - attention coefficients which exposes each node to a stochastically - sampled neighborhood during training. (default: :obj:`0`) - add_self_loops (bool, optional): If set to :obj:`False`, will not add - self-loops to the input graph. (default: :obj:`True`) - bias (bool, optional): If set to :obj:`False`, the layer will not learn - an additive bias. 
(default: :obj:`True`) - **kwargs (optional): Additional arguments of - :class:`torch_geometric.nn.conv.MessagePassing`. - """ - _alpha: OptTensor - - def __init__(self, in_channels: Union[int, Tuple[int, int]], - out_channels: int, heads: int = 1, concat: bool = True, - negative_slope: float = 0.2, dropout: float = 0.0, - add_self_loops: bool = True, bias: bool = True, **kwargs): - kwargs.setdefault('aggr', 'add') - super(GATConv, self).__init__(node_dim=0, **kwargs) - - self.in_channels = in_channels - self.out_channels = out_channels - self.heads = heads - self.concat = concat - self.negative_slope = negative_slope - self.dropout = dropout - self.add_self_loops = add_self_loops - self.lin_src = nn.Parameter(torch.zeros(size=(in_channels, out_channels))) - nn.init.xavier_normal_(self.lin_src.data, gain=1.414) - self.lin_dst = self.lin_src - # The learnable parameters to compute attention coefficients: - self.att_src = Parameter(torch.Tensor(1, heads, out_channels)) - self.att_dst = Parameter(torch.Tensor(1, heads, out_channels)) - nn.init.xavier_normal_(self.att_src.data, gain=1.414) - nn.init.xavier_normal_(self.att_dst.data, gain=1.414) - self._alpha = None - self.attentions = None - - def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, - size: Size = None, return_attention_weights=None, attention=True, tied_attention=None): - # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor # noqa - # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor # noqa - # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa - # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor] # noqa - r""" - Args: - return_attention_weights (bool, optional): If set to :obj:`True`, - will additionally return the tuple - :obj:`(edge_index, attention_weights)`, holding the computed - attention weights for each edge. (default: :obj:`None`) - """ - H, C = self.heads, self.out_channels - - # We first transform the input node features. 
If a tuple is passed, we - # transform source and target node features via separate weights: - if isinstance(x, Tensor): - assert x.dim() == 2, "Static graphs not supported in 'GATConv'" - # x_src = x_dst = self.lin_src(x).view(-1, H, C) - x_src = x_dst = torch.mm(x, self.lin_src).view(-1, H, C) - else: # Tuple of source and target node features: - x_src, x_dst = x - assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'" - x_src = self.lin_src(x_src).view(-1, H, C) - if x_dst is not None: - x_dst = self.lin_dst(x_dst).view(-1, H, C) - - x = (x_src, x_dst) - - if not attention: - return x[0].mean(dim=1) - # return x[0].view(-1, self.heads * self.out_channels) - - if tied_attention == None: - # Next, we compute node-level attention coefficients, both for source - # and target nodes (if present): - alpha_src = (x_src * self.att_src).sum(dim=-1) - alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1) - alpha = (alpha_src, alpha_dst) - self.attentions = alpha - else: - alpha = tied_attention - - if self.add_self_loops: - if isinstance(edge_index, Tensor): - # We only want to add self-loops for nodes that appear both as - # source and target nodes: - num_nodes = x_src.size(0) - if x_dst is not None: - num_nodes = min(num_nodes, x_dst.size(0)) - num_nodes = min(size) if size is not None else num_nodes - edge_index, _ = remove_self_loops(edge_index) - edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) - elif isinstance(edge_index, SparseTensor): - edge_index = set_diag(edge_index) - - # propagate_type: (x: OptPairTensor, alpha: OptPairTensor) - out = self.propagate(edge_index, x=x, alpha=alpha, size=size) - - alpha = self._alpha - assert alpha is not None - self._alpha = None - - if self.concat: - out = out.view(-1, self.heads * self.out_channels) - else: - out = out.mean(dim=1) - - # if self.bias is not None: - # out += self.bias - - if isinstance(return_attention_weights, bool): - if isinstance(edge_index, Tensor): - return out, (edge_index, alpha) - elif isinstance(edge_index, SparseTensor): - return out, edge_index.set_value(alpha, layout='coo') - else: - return out - - def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor, - index: Tensor, ptr: OptTensor, - size_i: Optional[int]) -> Tensor: - # Given egel-level attention coefficients for source and target nodes, - # we simply need to sum them up to "emulate" concatenation: - alpha = alpha_j if alpha_i is None else alpha_j + alpha_i - - alpha = F.leaky_relu(alpha, self.negative_slope) - alpha = softmax(alpha, index, ptr, size_i) - self._alpha = alpha # Save for later use. 
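        # softmax(alpha, index, ptr, size_i) normalises the raw coefficients per target
        # node: the attention weights of all edges pointing to the same node sum to 1.
        # For instance, with index = [0, 0, 1] and alpha = [a, b, c] the result is
        # [exp(a) / (exp(a) + exp(b)), exp(b) / (exp(a) + exp(b)), 1.0], so a node with a
        # single incoming edge gives that edge full weight.
        # The dropout applied next randomly zeroes some of these weights during training;
        # STAGATEModule builds GATConv with dropout=0, so it has no effect in this model.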
- alpha = F.dropout(alpha, p=self.dropout, training=self.training) - return x_j * alpha.unsqueeze(-1) - - def __repr__(self): - return '{}({}, {}, heads={})'.format(self.__class__.__name__, - self.in_channels, - self.out_channels, self.heads) - - -class STAGATEModule(nn.Module): - def __init__(self, in_features, hidden_dims): - super(STAGATEModule, self).__init__() - [num_hidden, out_dim] = hidden_dims - self.conv1 = GATConv(in_features, num_hidden, heads=1, concat=False, - dropout=0, add_self_loops=False, bias=False) - self.conv2 = GATConv(num_hidden, out_dim, heads=1, concat=False, - dropout=0, add_self_loops=False, bias=False) - self.conv3 = GATConv(out_dim, num_hidden, heads=1, concat=False, - dropout=0, add_self_loops=False, bias=False) - self.conv4 = GATConv(num_hidden, in_features, heads=1, concat=False, - dropout=0, add_self_loops=False, bias=False) - - def forward(self, features, edge_index): - h1 = F.elu(self.conv1(features, edge_index)) - h2 = self.conv2(h1, edge_index, attention=False) - self.conv3.lin_src.data = self.conv2.lin_src.transpose(0, 1) - self.conv3.lin_dst.data = self.conv2.lin_dst.transpose(0, 1) - self.conv4.lin_src.data = self.conv1.lin_src.transpose(0, 1) - self.conv4.lin_dst.data = self.conv1.lin_dst.transpose(0, 1) - h3 = F.elu(self.conv3(h2, edge_index, attention=True, - tied_attention=self.conv1.attentions)) - h4 = self.conv4(h3, edge_index, attention=False) - - return h2, h4 - - -class StackClsModule(nn.Module, abc.ABC): - def __init__(self, in_features, n_classes): - super(StackClsModule, self).__init__() - self.in_features = in_features - self.n_classes = n_classes - - -class STAGATEClsModule(nn.Module): - def __init__(self, - stagate: STAGATEModule, - stack_classifier: StackClsModule): - super(STAGATEClsModule, self).__init__() - self.stagate = copy.deepcopy(stagate) - self.classifier = copy.deepcopy(stack_classifier) - - def forward(self, x, edge_index, mode="classifier"): - z, x_recon = self.stagate(x, edge_index) - z = torch.clone(z) - if mode == "classifier": - return z, self.classifier(z) - elif mode == "reconstruction": - return z, x_recon - else: - raise NotImplementedError - - def get_saliency_map(self, x, edge_index, target_index="max", save=None): - """ - Get saliency map by backpropagation. 
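        The input features are wrapped in a leaf ``Variable`` with ``requires_grad=True``,
        the classifier scores of the selected class (the per-spot argmax when
        ``target_index="max"``, or a fixed class index) are summed into a scalar, and that
        scalar is backpropagated. The resulting gradient has the same shape as ``x``
        (spots x genes) and its magnitude serves as the per-gene saliency of each spot.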
- :param x: input tensors - :param edge_index: graph edge index - :param target_index: target index to compute final scores - :param save: - :return: gradients - """ - x_var = Variable(x, requires_grad=True) - _, output = self.forward(x_var, edge_index, mode="classifier") - scores = output["last_layer"] - if target_index == "max": - target_score_indices = Variable(torch.argmax(scores, 1)) - elif isinstance(target_index, int): - target_score_indices = Variable(torch.ones(scores.shape[0], dtype=torch.int64) * target_index) - else: - raise NotImplementedError - target_scores = scores.gather(1, target_score_indices.view(-1, 1)).squeeze() - loss = torch.sum(target_scores) - loss.backward() - gradients = x_var.grad.data - if save is not None: - torch.save(gradients, save) - return gradients, scores - - -class StackMLPModule(StackClsModule): - name = "StackMLP" - - def __init__(self, in_features, n_classes, hidden_dims=[30, 40, 30]): - super(StackMLPModule, self).__init__(in_features, n_classes) - self.classifier = nn.ModuleList() - mlp_dims = [in_features] + hidden_dims + [n_classes] - for ind in range(len(mlp_dims) - 1): - self.classifier.append(nn.Linear(mlp_dims[ind], mlp_dims[ind + 1])) - - def forward(self, x): - for layer in self.classifier: - x = layer(x) - score = F.softmax(x, dim=0) - return {"last_layer": x, "score": score} diff --git a/stamarker/stamarker/pipeline.py b/stamarker/stamarker/pipeline.py deleted file mode 100644 index c24592a..0000000 --- a/stamarker/stamarker/pipeline.py +++ /dev/null @@ -1,287 +0,0 @@ -import pytorch_lightning as pl -import copy -import torch -import os -import shutil -import logging -import glob -import sys -import numpy as np -import scipy -import scanpy as sc -from pytorch_lightning.loggers import TensorBoardLogger -from scipy.cluster import hierarchy -from .models import intSTAGATE, StackClassifier -from .utils import plot_consensus_map, consensus_matrix, Timer -from .dataset import SpatialDataModule -from .modules import STAGATEClsModule -import logging - - -FORMAT = "%(asctime)s %(levelname)s %(message)s" -logging.basicConfig(format=FORMAT, datefmt='%Y-%m-%d %H:%M:%S') -def make_spatial_data(ann_data): - """ - Make SpatialDataModule object from Scanpy annData object - """ - data_module = SpatialDataModule() - ann_data.X = ann_data.X.toarray() - data_module.ann_data = ann_data - return data_module - - -class STAMarker: - def __init__(self, n, save_dir, config, logging_level=logging.INFO): - """ - n: int, number of graph attention auto-econders to train - save_dir: directory to save the models - config: config file for training - """ - self.n = n - self.save_dir = save_dir - if not os.path.exists(save_dir): - os.mkdir(save_dir) - logging.info("Create save directory {}".format(save_dir)) - self.version_dirs = [os.path.join(save_dir, f"version_{i}") for i in range(n)] - self.config = config - self.logger = logging.getLogger("STAMarker") - self.logger.setLevel(logging_level) - self.consensus_labels = None - - def load_from_dir(self, save_dir, ): - """ - Load the trained models from a directory - """ - self.version_dirs = glob.glob(os.path.join(save_dir, "version_*")) - self.version_dirs = sorted(self.version_dirs, key=lambda x: int(x.split("_")[-1])) - # check if all version dir have `checkpoints/stagate.ckpt` - version_dirs_valid = [] - for version_dir in self.version_dirs: - if not os.path.exists(os.path.join(version_dir, "checkpoints/stagate.ckpt")): - self.logger.warning("No checkpoint found in {}".format(version_dir)) - else: - 
version_dirs_valid.append(version_dir) - self.version_dirs = version_dirs_valid - self.logger.info("Load {} autoencoder models from {}".format(len(version_dirs_valid), save_dir)) - # check if all version dir have `cluster_labels.npy` raise warning if not - missing_cluster_labels = False - for version_dir in self.version_dirs: - if not os.path.exists(os.path.join(version_dir, "cluster_labels.npy")): - missing_cluster_labels = True - msg = "No cluster labels found in {}.".format(version_dir) - self.logger.warning(msg) - if missing_cluster_labels: - self.logger.warning("Please run clustering first.") - # check if save_dir has `consensus.npy` raise warning if not - if not os.path.exists(os.path.join(save_dir, "consensus.npy")): - self.logger.warning("No consensus labels found in {}".format(save_dir)) - else: - self.consensus_labels = np.load(os.path.join(save_dir, "consensus.npy")) - # check if all version dir have `checkpoints/mlp.ckpt` raise warning if not - missing_clf = False - for version_dir in self.version_dirs: - if not os.path.exists(os.path.join(version_dir, "checkpoints/mlp.ckpt")): - self.logger.warning("No classifier checkpoint found in {}".format(version_dir)) - missing_clf = True - if missing_clf: - self.logger.warning("Please run classifier training first.") - if not missing_cluster_labels and not missing_clf: - self.logger.info("All models are trained and ready to use.") - - def train_auto_encoders(self, data_module): - for seed in range(self.n): - self._train_auto_encoder(data_module, seed, self.config) - self.logger.info("Finished training {} auto-encoders".format(self.n)) - - def clustering(self, data_module, cluster_method, cluster_params): - """ - Cluster the latent space of the trained auto-encoders - Cluster method should be "louvain" or "mclust" - """ - for version_dir in self.version_dirs: - self._clustering(data_module, version_dir, cluster_method, cluster_params) - self.logger.info("Finished {} clustering with {}".format(self.n, cluster_method)) - - def consensus_clustering(self, n_clusters, name="cluster_labels.npy"): - sys.setrecursionlimit(100000) - label_files = glob.glob(self.save_dir + f"/version_*/{name}") - labels_list = list(map(lambda file: np.load(file), label_files)) - cons_mat = consensus_matrix(labels_list) - row_linkage, _, figure = plot_consensus_map(cons_mat, return_linkage=True) - figure.savefig(os.path.join(self.save_dir, "consensus_clustering.png"), dpi=300) - consensus_labels = hierarchy.cut_tree(row_linkage, n_clusters).squeeze() - np.save(os.path.join(self.save_dir, "consensus"), consensus_labels) - self.consensus_labels = consensus_labels - self.logger.info("Save consensus labels to {}".format(os.path.join(self.save_dir, "consensus.npz"))) - - def train_classifiers(self, data_module, n_clusters, name="cluster_labels.npy"): - for i, version_dir in enumerate(self.version_dirs): - # _train_classifier(self, data_module, version_dir, target_y, n_classes, seed=None) - self._train_classifier(data_module, version_dir, self.consensus_labels, - n_clusters, self.config, seed=i) - self.logger.info("Finished training {} classifiers".format(self.n)) - - def compute_smaps(self, data_module, return_recon=True, normalize=True): - smaps = [] - if return_recon: - recons = [] - for version_dir in self.version_dirs: - if return_recon: - smap, recon = self._compute_smap(data_module, version_dir, return_recon=return_recon) - smaps.append(smap) - recons.append(recon) - else: - smap = self._compute_smap(data_module, version_dir, return_recon=return_recon) - 
smaps.append(smap) - if return_recon: - return smaps, recons - else: - return smaps - self.logger.info("Finished computing {} smaps".format(self.n)) - - - def _compute_smap_zscore(self, smap, labels, logtransform=False): - scores = np.log(smap + 1) if logtransform else copy.copy(smap) - unique_labels = np.unique(labels) - for l in unique_labels: - scores[labels == l, :] = scipy.stats.zscore(scores[labels == l, :], axis=1) - return scores - - - def _clustering(self, data_module, version_dir, cluster_method, cluster_params): - """ - Cluster the latent space of the trained auto-encoder - """ - if cluster_method == "louvain": - run_louvain(data_module, version_dir, cluster_params) - elif cluster_method == "mclust": - run_mclust(data_module, version_dir, cluster_params) - else: - raise ValueError("Unknown clustering method") - - def _train_auto_encoder(self, data_module, seed, config): - """ - Train a single graph attention auto-encoder - """ - pl.seed_everything(seed) - version = f"version_{seed}" - version_dir = os.path.join(self.save_dir, version) - if os.path.exists(version_dir): - shutil.rmtree(version_dir) - os.makedirs(version_dir, exist_ok=True) - logger = TensorBoardLogger(save_dir=self.save_dir, name=None, - default_hp_metric=False, - version=seed) - model = intSTAGATE(**config["stagate"]["params"]) - model.set_optimizer_params(config["stagate"]["optimizer"], config["stagate"]["scheduler"]) - trainer = pl.Trainer(logger=logger, **config["stagate_trainer"]) - timer = Timer() - timer.tic("fit") - trainer.fit(model, data_module) - fit_time = timer.toc("fit") - with open(os.path.join(version_dir, "runtime.csv"), "w+") as f: - f.write("{}, fit_time, {:.2f}, ".format(seed, fit_time / 60)) - trainer.save_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - del model, trainer - if config["stagate_trainer"]["gpus"] > 0: - torch.cuda.empty_cache() - logging.info(f"Finshed running version {seed}") - - def _train_classifier(self, data_module, version_dir, target_y, n_classes, config, seed=None): - timer = Timer() - pl.seed_everything(seed) - rep_dim = config["stagate"]["params"]["hidden_dims"][-1] - stagate = intSTAGATE.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - classifier = StackClassifier(rep_dim, n_classes=n_classes, architecture="MLP") - classifier.prepare(stagate, data_module.train_dataset, target_y, - balanced=config["mlp"]["balanced"], test_prop=config["mlp"]["test_prop"]) - classifier.set_optimizer_params(config["mlp"]["optimizer"], config["mlp"]["scheduler"]) - logger = TensorBoardLogger(save_dir=self.save_dir, name=None, - default_hp_metric=False, - version=seed) - trainer = pl.Trainer(logger=logger, **config["classifier_trainer"]) - timer.tic("clf") - trainer.fit(classifier) - clf_time = timer.toc("clf") - with open(os.path.join(version_dir, "runtime.csv"), "a+") as f: - f.write("\n") - f.write("{}, clf_time, {:.2f}, ".format(seed, clf_time / 60)) - trainer.save_checkpoint(os.path.join(version_dir, "checkpoints", "mlp.ckpt")) - target_y = classifier.dataset.target_y.numpy() - all_props = class_proportions(target_y) - val_props = class_proportions(target_y[classifier.val_dataset.indices]) - if self.logger.level == logging.DEBUG: - print("All class proportions " + "|".join(["{:.2f}%".format(prop * 100) for prop in all_props])) - print("Val class proportions " + "|".join(["{:.2f}%".format(prop * 100) for prop in val_props])) - np.save(os.path.join(version_dir, "confusion.npy"), classifier.confusion) - - def _compute_smap(self, 
data_module, version_dir, return_recon=True): - """ - Compute the saliency map of the trained auto-encoder - """ - stagate = intSTAGATE.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - cls = StackClassifier.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "mlp.ckpt")) - stagate_cls = STAGATEClsModule(stagate.model, cls.model) - smap, _ = stagate_cls.get_saliency_map(data_module.train_dataset.x, - data_module.train_dataset.edge_index) - smap = smap.detach().cpu().numpy() - if return_recon: - recon = stagate(data_module.train_dataset.x, data_module.train_dataset.edge_index)[1].cpu().detach().numpy() - return smap, recon - else: - return smap - - -def run_louvain(data_module, version_dir, resolution, name="cluster_labels"): - """ - Run louvain clustering on the data_module - """ - stagate = intSTAGATE.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - embedding = stagate(data_module.train_dataset.x, data_module.train_dataset.edge_index)[0].cpu().detach().numpy() - ann_data = copy.copy(data_module.ann_data) - ann_data.obsm["stagate"] = embedding - sc.pp.neighbors(ann_data, use_rep='stagate') - sc.tl.louvain(ann_data, resolution=resolution) - save_path = os.path.join(version_dir, "{}.npy".format(name)) - np.save(save_path, ann_data.obs["louvain"].to_numpy().astype("int")) - print("Save louvain results to {}".format(save_path)) - - -def mclust_R(representation, n_clusters, r_seed=2022, model_name="EEE"): - """ - Clustering using the mclust algorithm. - The parameters are the same as those in the R package mclust. - """ - np.random.seed(r_seed) - import rpy2.robjects as ro - from rpy2.robjects import numpy2ri - numpy2ri.activate() - ro.r.library("mclust") - r_random_seed = ro.r['set.seed'] - r_random_seed(r_seed) - rmclust = ro.r['Mclust'] - res = rmclust(representation, n_clusters, model_name) - mclust_res = np.array(res[-2]) - numpy2ri.deactivate() - return mclust_res.astype('int') - - -def run_mclust(data_module, version_dir, n_clusters, name="cluster_labels"): - stagate = intSTAGATE.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - embedding = stagate(data_module.train_dataset.x, data_module.train_dataset.edge_index)[0].cpu().detach().numpy() - labels = mclust_R(embedding, n_clusters) - save_path = os.path.join(version_dir, "{}.npy".format(name)) - np.save(save_path, labels.astype("int")) - print("Save MClust results to {}".format(save_path)) - -def class_proportions(target): - n_classes = len(np.unique(target)) - props = np.array([np.sum(target == i) for i in range(n_classes)]) - return props / np.sum(props) - - - - - - - diff --git a/stamarker/stamarker/stamarker/.ipynb_checkpoints/models-checkpoint.py b/stamarker/stamarker/stamarker/.ipynb_checkpoints/models-checkpoint.py deleted file mode 100644 index d7457ec..0000000 --- a/stamarker/stamarker/stamarker/.ipynb_checkpoints/models-checkpoint.py +++ /dev/null @@ -1,257 +0,0 @@ -from abc import ABC -from typing import Any, List -import numpy as np -import pytorch_lightning as pl -import torch -import torch.nn.functional as F -from torch.utils.data import DataLoader, WeightedRandomSampler -from torch_geometric.data import Data -from sklearn.metrics import adjusted_rand_score, confusion_matrix -from .modules import STAGATEModule, StackMLPModule -from .dataset import RepDataset, Batch -from .utils import Timer - -def get_optimizer(name): - if name == "ADAM": - return torch.optim.Adam - elif name == "ADAGRAD": - return torch.optim.Adagrad - 
elif name == "ADADELTA": - return torch.optim.Adadelta - elif name == "RMS": - return torch.optim.RMSprop - elif name == "ASGD": - return torch.optim.ASGD - else: - raise NotImplementedError - - -def get_scheduler(name): - if name == "STEP_LR": - return torch.optim.lr_scheduler.StepLR - elif name == "EXP_LR": - return torch.optim.lr_scheduler.ExponentialLR - else: - raise NotImplementedError - -class BaseModule(pl.LightningModule, ABC): - def __init__(self): - super(BaseModule, self).__init__() - self.optimizer_params = None - self.scheduler_params = None - self.model = None - self.timer = Timer() - self.automatic_optimization = False - - def set_optimizer_params(self, - optimizer_params: dict, - scheduler_params: dict): - self.optimizer_params = optimizer_params - self.scheduler_params = scheduler_params - - def configure_optimizers(self): - optimizer = get_optimizer(self.optimizer_params["name"])( - self.model.parameters(), - **self.optimizer_params["params"]) - scheduler = get_scheduler(self.scheduler_params["name"])(optimizer, **self.scheduler_params["params"]) - return [optimizer], [scheduler] - - def on_train_epoch_start(self) -> None: - self.timer.tic('train') - - -class intSTAGATE(BaseModule): - """ - intSTAGATE Lightning Module - """ - def __init__(self, - in_features: int = None, - hidden_dims: List[int] = None, - gradient_clipping: float = 5.0, - **kwargs): - super(intSTAGATE, self).__init__() - self.model = STAGATEModule(in_features, hidden_dims) - self.auto_encoder_epochs = None - self.gradient_clipping = gradient_clipping - self.pred_labels = None - self.save_hyperparameters() - - def configure_optimizers(self) -> (dict, dict): - auto_encoder_optimizer = get_optimizer(self.optimizer_params["name"])( - list(self.model.parameters()), - **self.optimizer_params["params"]) - auto_encoder_scheduler = get_scheduler(self.scheduler_params["name"])(auto_encoder_optimizer, - **self.scheduler_params["params"]) - return [auto_encoder_optimizer], [auto_encoder_scheduler] - - def forward(self, x, edge_index) -> Any: - return self.model(x, edge_index) - - def training_step(self, batch, batch_idx): - batch = batch.to(self.device) - opt_auto_encoder = self.optimizers() - z, x_hat = self.model(batch.x, batch.edge_index) - loss = F.mse_loss(batch.x, x_hat) - opt_auto_encoder.zero_grad() - self.manual_backward(loss) - torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clipping) - opt_auto_encoder.step() - self.log("Training auto-encoder|Reconstruction errors", loss.item(), prog_bar=True) - self.logger.experiment.add_scalar('auto_encoder/loss', loss.item(), self.current_epoch) - - def on_train_epoch_end(self) -> None: - time = self.timer.toc('train') - sch_auto_encoder = self.lr_schedulers() - sch_auto_encoder.step() - self.logger.experiment.add_scalar('train_time', time, self.current_epoch) - - def validation_step(self, batch, batch_idx): - pass - - def validation_epoch_end(self, outputs): - pass - - -def _compute_correct(scores, target_y): - _, pred_labels = torch.max(scores, axis=1) - correct = (pred_labels == target_y).sum().item() - return pred_labels, correct - - -class CoordTransformer(object): - def __init__(self, coord): - self.coord = coord - - def transform(self): - factor = np.max(np.max(self.coord, axis=0) - np.min(self.coord, axis=0)) - return (self.coord - np.min(self.coord, axis=0)) / factor - - -class StackClassifier(BaseModule): - def __init__(self, in_features: int, - n_classes: int = 7, - batch_size: int = 1000, - shuffle: bool = False, - hidden_dims: 
List[int] = [30], - architecture: str = "MLP", - sta_path: str = None, - **kwargs): - super(StackClassifier, self).__init__() - self.in_features = in_features - self.architecture = architecture - self.batch_size = batch_size - self.shuffle = shuffle - if architecture == "MLP": - self.model = StackMLPModule(in_features, n_classes, hidden_dims, **kwargs) - else: - raise NotImplementedError - self.dataset = None - self.train_dataset = None - self.val_dataset = None - self.automatic_optimization = False - self.sampler = None - self.test_prop = None - self.confusion = None - self.balanced = None - self.save_hyperparameters() - - def prepare(self, - stagate: intSTAGATE, - dataset: Data, - target_y, - test_prop: float = 0.5, - balanced: bool = True): - self.balanced = balanced - self.test_prop = test_prop - with torch.no_grad(): - representation, _ = stagate(dataset.x, dataset.edge_index) - if hasattr(dataset, "ground_truth"): - ground_truth = dataset.ground_truth - else: - ground_truth = None - if isinstance(target_y, np.ndarray): - target_y = torch.from_numpy(target_y).type(torch.LongTensor) - elif isinstance(target_y, torch.Tensor): - target_y = target_y.type(torch.LongTensor) - else: - raise TypeError("target_y must be either a torch tensor or a numpy ndarray.") - self.dataset = RepDataset(representation, target_y, ground_truth=ground_truth) - n_val = int(len(self.dataset) * test_prop) - self.train_dataset, self.val_dataset = torch.utils.data.random_split( - self.dataset, [len(self.dataset) - n_val, n_val]) - if balanced: - target_y = target_y[self.train_dataset.indices] - class_sample_count = np.array([len(np.where(target_y == t)[0]) for t in np.unique(target_y)]) - weight = 1. / class_sample_count - samples_weight = np.array([weight[t] for t in target_y]) - samples_weight = torch.from_numpy(samples_weight) - samples_weight = samples_weight.double() - self.sampler = WeightedRandomSampler(samples_weight, len(samples_weight)) - - def forward(self, x, edge_index=None) -> Any: - if self.architecture == "MLP": - return self.model(x) - elif self.architecture == "STACls": - _, output = self.model(x, edge_index) - return output - - def training_step(self, batch, batch_idx): - batch = Batch(**batch) - batch = batch.to(self.device) - opt = self.optimizers() - opt.zero_grad() - output = self.model(batch.x) - loss = F.cross_entropy(output["score"], batch.y) - self.manual_backward(loss) - opt.step() - _, correct = _compute_correct(output["score"], batch.y) - self.log(f"Training {self.architecture} classifier|Cross entropy", loss.item(), prog_bar=True) - return {"loss": loss, "correct": correct} - - def training_epoch_end(self, outputs): - time = self.timer.toc('train') - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/train_time', time, self.current_epoch) - all_loss = torch.stack([x["loss"] for x in outputs]) - all_correct = np.sum([x["correct"] for x in outputs]) - train_acc = all_correct / len(self.train_dataset) - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/loss', - torch.mean(all_loss), self.current_epoch) - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/train_acc', - train_acc, self.current_epoch) - - def validation_step(self, batch, batch_idx): - batch = Batch(**batch) - batch = batch.to(self.device) - with torch.no_grad(): - output = self.model(batch.x) - loss = F.cross_entropy(output["score"], batch.y) - pred_labels, correct = _compute_correct(output["score"], batch.y) - return {"loss": loss, "correct": correct, "pred_labels": 
pred_labels, "true_labels": batch.y} - - def validation_epoch_end(self, outputs): - all_loss = torch.stack([x["loss"] for x in outputs]) - all_correct = np.sum([x["correct"] for x in outputs]) - pred_labels = torch.cat([x["pred_labels"] for x in outputs]).cpu().detach().numpy() - true_labels = torch.cat([x["true_labels"] for x in outputs]).cpu().detach().numpy() - confusion = confusion_matrix(true_labels, pred_labels) - val_acc = all_correct / len(self.val_dataset) - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/val_loss', - torch.mean(all_loss), self.current_epoch) - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/val_acc', - val_acc, self.current_epoch) - print("\n validation ACC={:.4f}".format(val_acc)) - self.confusion = confusion - - def train_dataloader(self): - loader = DataLoader(self.train_dataset, batch_size=self.batch_size, sampler=self.sampler) - return loader - - def val_dataloader(self): - loader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=self.shuffle, drop_last=False) - return loader - - def test_dataloader(self): - raise NotImplementedError - - def predict_dataloader(self): - raise NotImplementedError \ No newline at end of file diff --git a/stamarker/stamarker/stamarker/.ipynb_checkpoints/modules-checkpoint.py b/stamarker/stamarker/stamarker/.ipynb_checkpoints/modules-checkpoint.py deleted file mode 100644 index 08d9d81..0000000 --- a/stamarker/stamarker/stamarker/.ipynb_checkpoints/modules-checkpoint.py +++ /dev/null @@ -1,276 +0,0 @@ -import abc -import copy -from torch.autograd import Variable -import torch -from torch import Tensor -import torch.nn.functional as F -from torch.nn import Parameter -import torch.nn as nn -from torch_sparse import SparseTensor, set_diag -from torch_geometric.nn.conv import MessagePassing -from torch_geometric.utils import remove_self_loops, add_self_loops, softmax -from typing import Union, Tuple, Optional -from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType, - OptTensor) - - -class GATConv(MessagePassing): - r"""The graph attentional operator from the `"Graph Attention Networks" - `_ paper - .. math:: - \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + - \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j}, - where the attention coefficients :math:`\alpha_{i,j}` are computed as - .. math:: - \alpha_{i,j} = - \frac{ - \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} - [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] - \right)\right)} - {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} - \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} - [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k] - \right)\right)}. - Args: - in_channels (int or tuple): Size of each input sample, or :obj:`-1` to - derive the size from the first input(s) to the forward method. - A tuple corresponds to the sizes of source and target - dimensionalities. - out_channels (int): Size of each output sample. - heads (int, optional): Number of multi-head-attentions. - (default: :obj:`1`) - concat (bool, optional): If set to :obj:`False`, the multi-head - attentions are averaged instead of concatenated. - (default: :obj:`True`) - negative_slope (float, optional): LeakyReLU angle of the negative - slope. (default: :obj:`0.2`) - dropout (float, optional): Dropout probability of the normalized - attention coefficients which exposes each node to a stochastically - sampled neighborhood during training. 
(default: :obj:`0`) - add_self_loops (bool, optional): If set to :obj:`False`, will not add - self-loops to the input graph. (default: :obj:`True`) - bias (bool, optional): If set to :obj:`False`, the layer will not learn - an additive bias. (default: :obj:`True`) - **kwargs (optional): Additional arguments of - :class:`torch_geometric.nn.conv.MessagePassing`. - """ - _alpha: OptTensor - - def __init__(self, in_channels: Union[int, Tuple[int, int]], - out_channels: int, heads: int = 1, concat: bool = True, - negative_slope: float = 0.2, dropout: float = 0.0, - add_self_loops: bool = True, bias: bool = True, **kwargs): - kwargs.setdefault('aggr', 'add') - super(GATConv, self).__init__(node_dim=0, **kwargs) - - self.in_channels = in_channels - self.out_channels = out_channels - self.heads = heads - self.concat = concat - self.negative_slope = negative_slope - self.dropout = dropout - self.add_self_loops = add_self_loops - self.lin_src = nn.Parameter(torch.zeros(size=(in_channels, out_channels))) - nn.init.xavier_normal_(self.lin_src.data, gain=1.414) - self.lin_dst = self.lin_src - # The learnable parameters to compute attention coefficients: - self.att_src = Parameter(torch.Tensor(1, heads, out_channels)) - self.att_dst = Parameter(torch.Tensor(1, heads, out_channels)) - nn.init.xavier_normal_(self.att_src.data, gain=1.414) - nn.init.xavier_normal_(self.att_dst.data, gain=1.414) - self._alpha = None - self.attentions = None - - def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, - size: Size = None, return_attention_weights=None, attention=True, tied_attention=None): - # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor # noqa - # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor # noqa - # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa - # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor] # noqa - r""" - Args: - return_attention_weights (bool, optional): If set to :obj:`True`, - will additionally return the tuple - :obj:`(edge_index, attention_weights)`, holding the computed - attention weights for each edge. (default: :obj:`None`) - """ - H, C = self.heads, self.out_channels - - # We first transform the input node features. 
If a tuple is passed, we - # transform source and target node features via separate weights: - if isinstance(x, Tensor): - assert x.dim() == 2, "Static graphs not supported in 'GATConv'" - # x_src = x_dst = self.lin_src(x).view(-1, H, C) - x_src = x_dst = torch.mm(x, self.lin_src).view(-1, H, C) - else: # Tuple of source and target node features: - x_src, x_dst = x - assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'" - x_src = self.lin_src(x_src).view(-1, H, C) - if x_dst is not None: - x_dst = self.lin_dst(x_dst).view(-1, H, C) - - x = (x_src, x_dst) - - if not attention: - return x[0].mean(dim=1) - # return x[0].view(-1, self.heads * self.out_channels) - - if tied_attention == None: - # Next, we compute node-level attention coefficients, both for source - # and target nodes (if present): - alpha_src = (x_src * self.att_src).sum(dim=-1) - alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1) - alpha = (alpha_src, alpha_dst) - self.attentions = alpha - else: - alpha = tied_attention - - if self.add_self_loops: - if isinstance(edge_index, Tensor): - # We only want to add self-loops for nodes that appear both as - # source and target nodes: - num_nodes = x_src.size(0) - if x_dst is not None: - num_nodes = min(num_nodes, x_dst.size(0)) - num_nodes = min(size) if size is not None else num_nodes - edge_index, _ = remove_self_loops(edge_index) - edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) - elif isinstance(edge_index, SparseTensor): - edge_index = set_diag(edge_index) - - # propagate_type: (x: OptPairTensor, alpha: OptPairTensor) - out = self.propagate(edge_index, x=x, alpha=alpha, size=size) - - alpha = self._alpha - assert alpha is not None - self._alpha = None - - if self.concat: - out = out.view(-1, self.heads * self.out_channels) - else: - out = out.mean(dim=1) - - # if self.bias is not None: - # out += self.bias - - if isinstance(return_attention_weights, bool): - if isinstance(edge_index, Tensor): - return out, (edge_index, alpha) - elif isinstance(edge_index, SparseTensor): - return out, edge_index.set_value(alpha, layout='coo') - else: - return out - - def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor, - index: Tensor, ptr: OptTensor, - size_i: Optional[int]) -> Tensor: - # Given egel-level attention coefficients for source and target nodes, - # we simply need to sum them up to "emulate" concatenation: - alpha = alpha_j if alpha_i is None else alpha_j + alpha_i - - alpha = F.leaky_relu(alpha, self.negative_slope) - alpha = softmax(alpha, index, ptr, size_i) - self._alpha = alpha # Save for later use. 
- alpha = F.dropout(alpha, p=self.dropout, training=self.training) - return x_j * alpha.unsqueeze(-1) - - def __repr__(self): - return '{}({}, {}, heads={})'.format(self.__class__.__name__, - self.in_channels, - self.out_channels, self.heads) - - -class STAGATEModule(nn.Module): - def __init__(self, in_features, hidden_dims): - super(STAGATEModule, self).__init__() - [num_hidden, out_dim] = hidden_dims - self.conv1 = GATConv(in_features, num_hidden, heads=1, concat=False, - dropout=0, add_self_loops=False, bias=False) - self.conv2 = GATConv(num_hidden, out_dim, heads=1, concat=False, - dropout=0, add_self_loops=False, bias=False) - self.conv3 = GATConv(out_dim, num_hidden, heads=1, concat=False, - dropout=0, add_self_loops=False, bias=False) - self.conv4 = GATConv(num_hidden, in_features, heads=1, concat=False, - dropout=0, add_self_loops=False, bias=False) - - def forward(self, features, edge_index): - h1 = F.elu(self.conv1(features, edge_index)) - h2 = self.conv2(h1, edge_index, attention=False) - self.conv3.lin_src.data = self.conv2.lin_src.transpose(0, 1) - self.conv3.lin_dst.data = self.conv2.lin_dst.transpose(0, 1) - self.conv4.lin_src.data = self.conv1.lin_src.transpose(0, 1) - self.conv4.lin_dst.data = self.conv1.lin_dst.transpose(0, 1) - h3 = F.elu(self.conv3(h2, edge_index, attention=True, - tied_attention=self.conv1.attentions)) - h4 = self.conv4(h3, edge_index, attention=False) - - return h2, h4 - - -class StackClsModule(nn.Module, abc.ABC): - def __init__(self, in_features, n_classes): - super(StackClsModule, self).__init__() - self.in_features = in_features - self.n_classes = n_classes - - -class STAGATEClsModule(nn.Module): - def __init__(self, - stagate: STAGATEModule, - stack_classifier: StackClsModule): - super(STAGATEClsModule, self).__init__() - self.stagate = copy.deepcopy(stagate) - self.classifier = copy.deepcopy(stack_classifier) - - def forward(self, x, edge_index, mode="classifier"): - z, x_recon = self.stagate(x, edge_index) - z = torch.clone(z) - if mode == "classifier": - return z, self.classifier(z) - elif mode == "reconstruction": - return z, x_recon - else: - raise NotImplementedError - - def get_saliency_map(self, x, edge_index, target_index="max", save=None): - """ - Get saliency map by backpropagation. 
- :param x: input tensors - :param edge_index: graph edge index - :param target_index: target index to compute final scores - :param save: - :return: gradients - """ - x_var = Variable(x, requires_grad=True) - _, output = self.forward(x_var, edge_index, mode="classifier") - scores = output["last_layer"] - if target_index == "max": - target_score_indices = Variable(torch.argmax(scores, 1)) - elif isinstance(target_index, int): - target_score_indices = Variable(torch.ones(scores.shape[0], dtype=torch.int64) * target_index) - else: - raise NotImplementedError - target_scores = scores.gather(1, target_score_indices.view(-1, 1)).squeeze() - loss = torch.sum(target_scores) - loss.backward() - gradients = x_var.grad.data - if save is not None: - torch.save(gradients, save) - return gradients, scores - - -class StackMLPModule(StackClsModule): - name = "StackMLP" - - def __init__(self, in_features, n_classes, hidden_dims=[30, 40, 30]): - super(StackMLPModule, self).__init__(in_features, n_classes) - self.classifier = nn.ModuleList() - mlp_dims = [in_features] + hidden_dims + [n_classes] - for ind in range(len(mlp_dims) - 1): - self.classifier.append(nn.Linear(mlp_dims[ind], mlp_dims[ind + 1])) - - def forward(self, x): - for layer in self.classifier: - x = layer(x) - score = F.softmax(x, dim=0) - return {"last_layer": x, "score": score} diff --git a/stamarker/stamarker/stamarker/.ipynb_checkpoints/pipeline-checkpoint.py b/stamarker/stamarker/stamarker/.ipynb_checkpoints/pipeline-checkpoint.py deleted file mode 100644 index 13c4188..0000000 --- a/stamarker/stamarker/stamarker/.ipynb_checkpoints/pipeline-checkpoint.py +++ /dev/null @@ -1,246 +0,0 @@ -import pytorch_lightning as pl -import copy -import torch -import os -import shutil -import logging -import glob -import sys -import numpy as np -import scipy -import scanpy as sc -from pytorch_lightning.loggers import TensorBoardLogger -from scipy.cluster import hierarchy -from .models import intSTAGATE, StackClassifier -from .utils import plot_consensus_map, consensus_matrix, Timer -from .dataset import SpatialDataModule -from .modules import STAGATEClsModule -import logging - - -FORMAT = "%(asctime)s %(levelname)s %(message)s" -logging.basicConfig(format=FORMAT, datefmt='%Y-%m-%d %H:%M:%S') -def make_spatial_data(ann_data): - """ - Make SpatialDataModule object from Scanpy annData object - """ - data_module = SpatialDataModule() - ann_data.X = ann_data.X.toarray() - data_module.ann_data = ann_data - return data_module - - -class STAMarker: - def __init__(self, n, save_dir, config, logging_level=logging.INFO): - """ - n: int, number of graph attention auto-econders to train - save_dir: directory to save the models - config: config file for training - """ - self.n = n - self.save_dir = save_dir - if not os.path.exists(save_dir): - os.mkdir(save_dir) - logging.info("Create save directory {}".format(save_dir)) - self.version_dirs = [os.path.join(save_dir, f"version_{i}") for i in range(n)] - self.config = config - self.logger = logging.getLogger("STAMarker") - self.logger.setLevel(logging_level) - self.consensus_labels = None - - def train_auto_encoders(self, data_module): - for seed in range(self.n): - self._train_auto_encoder(data_module, seed, self.config) - self.logger.info("Finished training {} auto-encoders".format(self.n)) - - def clustering(self, data_module, cluster_method, cluster_params): - """ - Cluster the latent space of the trained auto-encoders - Cluster method should be "louvain" or "mclust" - """ - for version_dir in self.version_dirs: - 
self._clustering(data_module, version_dir, cluster_method, cluster_params) - self.logger.info("Finished {} clustering with {}".format(self.n, cluster_method)) - - def consensus_clustering(self, n_clusters, name="cluster_labels.npy"): - sys.setrecursionlimit(100000) - label_files = glob.glob(self.save_dir + f"/version_*/{name}") - labels_list = list(map(lambda file: np.load(file), label_files)) - cons_mat = consensus_matrix(labels_list) - row_linkage, _, figure = plot_consensus_map(cons_mat, return_linkage=True) - figure.savefig(os.path.join(self.save_dir, "consensus_clustering.png"), dpi=300) - consensus_labels = hierarchy.cut_tree(row_linkage, n_clusters).squeeze() - np.save(os.path.join(self.save_dir, "consensus"), consensus_labels) - self.consensus_labels = consensus_labels - self.logger.info("Save consensus labels to {}".format(os.path.join(self.save_dir, "consensus.npz"))) - - def train_classifiers(self, data_module, n_clusters, name="cluster_labels.npy"): - for version_dir in self.version_dirs: - # _train_classifier(self, data_module, version_dir, target_y, n_classes, seed=None) - self._train_classifier(data_module, version_dir, self.consensus_labels, n_clusters, self.config) - self.logger.info("Finished training {} classifiers".format(self.n)) - - def compute_smaps(self, data_module, return_recon=True, normalize=True): - smaps = [] - if return_recon: - recons = [] - for version_dir in self.version_dirs: - if return_recon: - smap, recon = self._compute_smap(data_module, version_dir, return_recon=return_recon) - smaps.append(smap) - recons.append(recon) - else: - smap = self._compute_smap(data_module, version_dir, return_recon=return_recon) - smaps.append(smap) - if return_recon: - return smaps, recons - else: - return smaps - self.logger.info("Finished computing {} smaps".format(self.n)) - - - def _compute_smap_zscore(self, smap, labels, logtransform=False): - scores = np.log(smap + 1) if logtransform else copy.copy(smap) - unique_labels = np.unique(labels) - for l in unique_labels: - scores[labels == l, :] = scipy.stats.zscore(scores[labels == l, :], axis=1) - return scores - - - def _clustering(self, data_module, version_dir, cluster_method, cluster_params): - """ - Cluster the latent space of the trained auto-encoder - """ - if cluster_method == "louvain": - run_louvain(data_module, version_dir, cluster_params) - elif cluster_method == "mclust": - run_mclust(data_module, version_dir, cluster_params) - else: - raise ValueError("Unknown clustering method") - - def _train_auto_encoder(self, data_module, seed, config): - """ - Train a single graph attention auto-encoder - """ - pl.seed_everything(seed) - version = f"version_{seed}" - version_dir = os.path.join(self.save_dir, version) - if os.path.exists(version_dir): - shutil.rmtree(version_dir) - os.makedirs(version_dir, exist_ok=True) - logger = TensorBoardLogger(save_dir=self.save_dir, name=None, - default_hp_metric=False, - version=seed) - model = intSTAGATE(**config["stagate"]["params"]) - model.set_optimizer_params(config["stagate"]["optimizer"], config["stagate"]["scheduler"]) - trainer = pl.Trainer(logger=logger, **config["stagate_trainer"]) - timer = Timer() - timer.tic("fit") - trainer.fit(model, data_module) - fit_time = timer.toc("fit") - with open(os.path.join(version_dir, "runtime.csv"), "w+") as f: - f.write("{}, fit_time, {:.2f}, ".format(seed, fit_time / 60)) - trainer.save_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - del model, trainer - if config["stagate_trainer"]["gpus"] > 0: - 
torch.cuda.empty_cache() - logging.info(f"Finshed running version {seed}") - - def _train_classifier(self, data_module, version_dir, target_y, n_classes, config, seed=None): - timer = Timer() - pl.seed_everything(seed) - rep_dim = config["stagate"]["params"]["hidden_dims"][-1] - stagate = intSTAGATE.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - classifier = StackClassifier(rep_dim, n_classes=n_classes, architecture="MLP") - classifier.prepare(stagate, data_module.train_dataset, target_y, - balanced=config["mlp"]["balanced"], test_prop=config["mlp"]["test_prop"]) - classifier.set_optimizer_params(config["mlp"]["optimizer"], config["mlp"]["scheduler"]) - logger = TensorBoardLogger(save_dir=self.save_dir, name=None, - default_hp_metric=False, - version=seed) - trainer = pl.Trainer(logger=logger, **config["classifier_trainer"]) - timer.tic("clf") - trainer.fit(classifier) - clf_time = timer.toc("clf") - with open(os.path.join(version_dir, "runtime.csv"), "a+") as f: - f.write("\n") - f.write("{}, clf_time, {:.2f}, ".format(seed, clf_time / 60)) - trainer.save_checkpoint(os.path.join(version_dir, "checkpoints", "mlp.ckpt")) - target_y = classifier.dataset.target_y.numpy() - all_props = class_proportions(target_y) - val_props = class_proportions(target_y[classifier.val_dataset.indices]) - if self.logger.level == logging.DEBUG: - print("All class proportions " + "|".join(["{:.2f}%".format(prop * 100) for prop in all_props])) - print("Val class proportions " + "|".join(["{:.2f}%".format(prop * 100) for prop in val_props])) - np.save(os.path.join(version_dir, "confusion.npy"), classifier.confusion) - - def _compute_smap(self, data_module, version_dir, return_recon=True): - """ - Compute the saliency map of the trained auto-encoder - """ - stagate = intSTAGATE.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - cls = StackClassifier.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "mlp.ckpt")) - stagate_cls = STAGATEClsModule(stagate.model, cls.model) - smap, _ = stagate_cls.get_saliency_map(data_module.train_dataset.x, - data_module.train_dataset.edge_index) - smap = smap.detach().cpu().numpy() - if return_recon: - recon = stagate(data_module.train_dataset.x, data_module.train_dataset.edge_index)[1].cpu().detach().numpy() - return smap, recon - else: - return smap - - -def run_louvain(data_module, version_dir, resolution, name="cluster_labels"): - """ - Run louvain clustering on the data_module - """ - stagate = intSTAGATE.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - embedding = stagate(data_module.train_dataset.x, data_module.train_dataset.edge_index)[0].cpu().detach().numpy() - ann_data = copy.copy(data_module.ann_data) - ann_data.obsm["stagate"] = embedding - sc.pp.neighbors(ann_data, use_rep='stagate') - sc.tl.louvain(ann_data, resolution=resolution) - save_path = os.path.join(version_dir, "{}.npy".format(name)) - np.save(save_path, ann_data.obs["louvain"].to_numpy().astype("int")) - print("Save louvain results to {}".format(save_path)) - - -def mclust_R(representation, n_clusters, r_seed=2022, model_name="EEE"): - """ - Clustering using the mclust algorithm. - The parameters are the same as those in the R package mclust. 
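    Requires ``rpy2`` and an R installation with the ``mclust`` package available;
    the fitted cluster assignment is returned as an integer array with one label per
    row of ``representation``. Illustrative call (cluster count chosen arbitrarily):
        labels = mclust_R(embedding, n_clusters=7)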
- """ - np.random.seed(r_seed) - import rpy2.robjects as ro - from rpy2.robjects import numpy2ri - numpy2ri.activate() - ro.r.library("mclust") - r_random_seed = ro.r['set.seed'] - r_random_seed(r_seed) - rmclust = ro.r['Mclust'] - res = rmclust(representation, n_clusters, model_name) - mclust_res = np.array(res[-2]) - numpy2ri.deactivate() - return mclust_res.astype('int') - - -def run_mclust(data_module, version_dir, n_clusters, name="cluster_labels"): - stagate = intSTAGATE.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - embedding = stagate(data_module.train_dataset.x, data_module.train_dataset.edge_index)[0].cpu().detach().numpy() - labels = mclust_R(embedding, n_clusters) - save_path = os.path.join(version_dir, "{}.npy".format(name)) - np.save(save_path, labels.astype("int")) - print("Save MClust results to {}".format(save_path)) - -def class_proportions(target): - n_classes = len(np.unique(target)) - props = np.array([np.sum(target == i) for i in range(n_classes)]) - return props / np.sum(props) - - - - - - - diff --git a/stamarker/stamarker/stamarker/__init__.py b/stamarker/stamarker/stamarker/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/stamarker/stamarker/stamarker/__pycache__/__init__.cpython-38.pyc b/stamarker/stamarker/stamarker/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 275509f..0000000 Binary files a/stamarker/stamarker/stamarker/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/stamarker/stamarker/stamarker/__pycache__/__init__.cpython-39.pyc b/stamarker/stamarker/stamarker/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index ab20350..0000000 Binary files a/stamarker/stamarker/stamarker/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/stamarker/stamarker/stamarker/__pycache__/dataset.cpython-38.pyc b/stamarker/stamarker/stamarker/__pycache__/dataset.cpython-38.pyc deleted file mode 100644 index e38dc8d..0000000 Binary files a/stamarker/stamarker/stamarker/__pycache__/dataset.cpython-38.pyc and /dev/null differ diff --git a/stamarker/stamarker/stamarker/__pycache__/dataset.cpython-39.pyc b/stamarker/stamarker/stamarker/__pycache__/dataset.cpython-39.pyc deleted file mode 100644 index 4b46d88..0000000 Binary files a/stamarker/stamarker/stamarker/__pycache__/dataset.cpython-39.pyc and /dev/null differ diff --git a/stamarker/stamarker/stamarker/__pycache__/models.cpython-38.pyc b/stamarker/stamarker/stamarker/__pycache__/models.cpython-38.pyc deleted file mode 100644 index dd9d50b..0000000 Binary files a/stamarker/stamarker/stamarker/__pycache__/models.cpython-38.pyc and /dev/null differ diff --git a/stamarker/stamarker/stamarker/__pycache__/models.cpython-39.pyc b/stamarker/stamarker/stamarker/__pycache__/models.cpython-39.pyc deleted file mode 100644 index 14b5f9c..0000000 Binary files a/stamarker/stamarker/stamarker/__pycache__/models.cpython-39.pyc and /dev/null differ diff --git a/stamarker/stamarker/stamarker/__pycache__/modules.cpython-38.pyc b/stamarker/stamarker/stamarker/__pycache__/modules.cpython-38.pyc deleted file mode 100644 index 39452b5..0000000 Binary files a/stamarker/stamarker/stamarker/__pycache__/modules.cpython-38.pyc and /dev/null differ diff --git a/stamarker/stamarker/stamarker/__pycache__/modules.cpython-39.pyc b/stamarker/stamarker/stamarker/__pycache__/modules.cpython-39.pyc deleted file mode 100644 index 923b393..0000000 Binary files a/stamarker/stamarker/stamarker/__pycache__/modules.cpython-39.pyc and 
/dev/null differ diff --git a/stamarker/stamarker/stamarker/__pycache__/pipeline.cpython-38.pyc b/stamarker/stamarker/stamarker/__pycache__/pipeline.cpython-38.pyc deleted file mode 100644 index 1931108..0000000 Binary files a/stamarker/stamarker/stamarker/__pycache__/pipeline.cpython-38.pyc and /dev/null differ diff --git a/stamarker/stamarker/stamarker/__pycache__/pipeline.cpython-39.pyc b/stamarker/stamarker/stamarker/__pycache__/pipeline.cpython-39.pyc deleted file mode 100644 index 9693383..0000000 Binary files a/stamarker/stamarker/stamarker/__pycache__/pipeline.cpython-39.pyc and /dev/null differ diff --git a/stamarker/stamarker/stamarker/__pycache__/utils.cpython-38.pyc b/stamarker/stamarker/stamarker/__pycache__/utils.cpython-38.pyc deleted file mode 100644 index c1e4cfd..0000000 Binary files a/stamarker/stamarker/stamarker/__pycache__/utils.cpython-38.pyc and /dev/null differ diff --git a/stamarker/stamarker/stamarker/__pycache__/utils.cpython-39.pyc b/stamarker/stamarker/stamarker/__pycache__/utils.cpython-39.pyc deleted file mode 100644 index d018829..0000000 Binary files a/stamarker/stamarker/stamarker/__pycache__/utils.cpython-39.pyc and /dev/null differ diff --git a/stamarker/stamarker/stamarker/dataset.py b/stamarker/stamarker/stamarker/dataset.py deleted file mode 100644 index 14398d6..0000000 --- a/stamarker/stamarker/stamarker/dataset.py +++ /dev/null @@ -1,111 +0,0 @@ -from typing import List -import scanpy as sc -import numpy as np -import torch -import pytorch_lightning as pl -from torch.utils.data import Dataset -from pytorch_lightning.utilities.types import EVAL_DATALOADERS -from torch_geometric.loader import NeighborLoader -from torch_geometric.data import Data -from .utils import compute_spatial_net, stats_spatial_net, compute_edge_list - - -class Batch(dict): - __getattr__ = dict.get - __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ - - def to(self, device): - res = dict() - for key, value in self.items(): - if hasattr(value, "to"): - res[key] = value.to(device) - else: - res[key] = value - return Batch(**res) - - -class RepDataset(Dataset): - def __init__(self, - x, - target_y, - ground_truth=None): - assert (len(x) == len(target_y)) - self.x = x - self.target_y = target_y - self.ground_truth = ground_truth - - def __len__(self): - return len(self.x) - - def __getitem__(self, idx): - if torch.is_tensor(idx): - idx = idx.tolist() - sample_x, sample_y = self.x[idx, :], self.target_y[idx] - if self.ground_truth is not None: - sample_gt = self.ground_truth[idx] - else: - sample_gt = np.nan - sample = {"x": sample_x, "y": sample_y, "ground_truth": sample_gt} - return sample - - -class SpatialDataModule(pl.LightningDataModule): - def __init__(self, - full_batch: bool = True, - batch_size: int = 1000, - num_neighbors: List[int] = None, - num_workers=None, - pin_memory=False) -> None: - self.batch_size = batch_size - self.full_batch = full_batch - self.has_y = False - self.train_dataset = None - self.valid_dataset = None - self.num_neighbors = num_neighbors - self.num_workers = num_workers - self.pin_memory = pin_memory - self.ann_data = None - - def prepare_data(self, n_top_genes: int = 3000, rad_cutoff: float = 50, - show_net_stats: bool = False, min_cells=50, min_counts=None) -> None: - sc.pp.calculate_qc_metrics(self.ann_data, inplace=True) - sc.pp.filter_genes(self.ann_data, min_cells=min_cells) - if min_counts is not None: - sc.pp.filter_cells(self.ann_data, min_counts=min_counts) - print("After filtering: ", self.ann_data.shape) - # Normalization 
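        # Keep the top ``n_top_genes`` highly-variable genes (Seurat v3 flavour), subset
        # the AnnData to them, scale each spot to 10,000 total counts and log1p-transform,
        # then build the spatial neighbour graph from coordinates within ``rad_cutoff``.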
- sc.pp.highly_variable_genes(self.ann_data, flavor="seurat_v3", n_top_genes=n_top_genes) - self.ann_data = self.ann_data[:, self.ann_data.var['highly_variable']] - sc.pp.normalize_total(self.ann_data, target_sum=1e4) - sc.pp.log1p(self.ann_data) - compute_spatial_net(self.ann_data, rad_cutoff=rad_cutoff) - if show_net_stats: - stats_spatial_net(self.ann_data) - # ---------------------------- generate data --------------------- - edge_list = compute_edge_list(self.ann_data) - self.train_dataset = Data(edge_index=torch.LongTensor(np.array([edge_list[0], edge_list[1]])), - x=torch.FloatTensor(self.ann_data.X), - y=None) - - def train_dataloader(self): - if self.full_batch: - loader = NeighborLoader(self.train_dataset, num_neighbors=[1], - batch_size=len(self.train_dataset.x)) - else: - loader = NeighborLoader(self.train_dataset, num_neighbors=self.num_neighbors, batch_size=self.batch_size) - return loader - - def val_dataloader(self): - if self.valid_dataset is None: - loader = NeighborLoader(self.train_dataset, num_neighbors=[1], - batch_size=len(self.train_dataset.x)) - else: - raise NotImplementedError - return loader - - def test_dataloader(self) -> EVAL_DATALOADERS: - raise NotImplementedError - - def predict_dataloader(self) -> EVAL_DATALOADERS: - raise NotImplementedError diff --git a/stamarker/stamarker/stamarker/models.py b/stamarker/stamarker/stamarker/models.py deleted file mode 100644 index d7457ec..0000000 --- a/stamarker/stamarker/stamarker/models.py +++ /dev/null @@ -1,257 +0,0 @@ -from abc import ABC -from typing import Any, List -import numpy as np -import pytorch_lightning as pl -import torch -import torch.nn.functional as F -from torch.utils.data import DataLoader, WeightedRandomSampler -from torch_geometric.data import Data -from sklearn.metrics import adjusted_rand_score, confusion_matrix -from .modules import STAGATEModule, StackMLPModule -from .dataset import RepDataset, Batch -from .utils import Timer - -def get_optimizer(name): - if name == "ADAM": - return torch.optim.Adam - elif name == "ADAGRAD": - return torch.optim.Adagrad - elif name == "ADADELTA": - return torch.optim.Adadelta - elif name == "RMS": - return torch.optim.RMSprop - elif name == "ASGD": - return torch.optim.ASGD - else: - raise NotImplementedError - - -def get_scheduler(name): - if name == "STEP_LR": - return torch.optim.lr_scheduler.StepLR - elif name == "EXP_LR": - return torch.optim.lr_scheduler.ExponentialLR - else: - raise NotImplementedError - -class BaseModule(pl.LightningModule, ABC): - def __init__(self): - super(BaseModule, self).__init__() - self.optimizer_params = None - self.scheduler_params = None - self.model = None - self.timer = Timer() - self.automatic_optimization = False - - def set_optimizer_params(self, - optimizer_params: dict, - scheduler_params: dict): - self.optimizer_params = optimizer_params - self.scheduler_params = scheduler_params - - def configure_optimizers(self): - optimizer = get_optimizer(self.optimizer_params["name"])( - self.model.parameters(), - **self.optimizer_params["params"]) - scheduler = get_scheduler(self.scheduler_params["name"])(optimizer, **self.scheduler_params["params"]) - return [optimizer], [scheduler] - - def on_train_epoch_start(self) -> None: - self.timer.tic('train') - - -class intSTAGATE(BaseModule): - """ - intSTAGATE Lightning Module - """ - def __init__(self, - in_features: int = None, - hidden_dims: List[int] = None, - gradient_clipping: float = 5.0, - **kwargs): - super(intSTAGATE, self).__init__() - self.model = 
STAGATEModule(in_features, hidden_dims) - self.auto_encoder_epochs = None - self.gradient_clipping = gradient_clipping - self.pred_labels = None - self.save_hyperparameters() - - def configure_optimizers(self) -> (dict, dict): - auto_encoder_optimizer = get_optimizer(self.optimizer_params["name"])( - list(self.model.parameters()), - **self.optimizer_params["params"]) - auto_encoder_scheduler = get_scheduler(self.scheduler_params["name"])(auto_encoder_optimizer, - **self.scheduler_params["params"]) - return [auto_encoder_optimizer], [auto_encoder_scheduler] - - def forward(self, x, edge_index) -> Any: - return self.model(x, edge_index) - - def training_step(self, batch, batch_idx): - batch = batch.to(self.device) - opt_auto_encoder = self.optimizers() - z, x_hat = self.model(batch.x, batch.edge_index) - loss = F.mse_loss(batch.x, x_hat) - opt_auto_encoder.zero_grad() - self.manual_backward(loss) - torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clipping) - opt_auto_encoder.step() - self.log("Training auto-encoder|Reconstruction errors", loss.item(), prog_bar=True) - self.logger.experiment.add_scalar('auto_encoder/loss', loss.item(), self.current_epoch) - - def on_train_epoch_end(self) -> None: - time = self.timer.toc('train') - sch_auto_encoder = self.lr_schedulers() - sch_auto_encoder.step() - self.logger.experiment.add_scalar('train_time', time, self.current_epoch) - - def validation_step(self, batch, batch_idx): - pass - - def validation_epoch_end(self, outputs): - pass - - -def _compute_correct(scores, target_y): - _, pred_labels = torch.max(scores, axis=1) - correct = (pred_labels == target_y).sum().item() - return pred_labels, correct - - -class CoordTransformer(object): - def __init__(self, coord): - self.coord = coord - - def transform(self): - factor = np.max(np.max(self.coord, axis=0) - np.min(self.coord, axis=0)) - return (self.coord - np.min(self.coord, axis=0)) / factor - - -class StackClassifier(BaseModule): - def __init__(self, in_features: int, - n_classes: int = 7, - batch_size: int = 1000, - shuffle: bool = False, - hidden_dims: List[int] = [30], - architecture: str = "MLP", - sta_path: str = None, - **kwargs): - super(StackClassifier, self).__init__() - self.in_features = in_features - self.architecture = architecture - self.batch_size = batch_size - self.shuffle = shuffle - if architecture == "MLP": - self.model = StackMLPModule(in_features, n_classes, hidden_dims, **kwargs) - else: - raise NotImplementedError - self.dataset = None - self.train_dataset = None - self.val_dataset = None - self.automatic_optimization = False - self.sampler = None - self.test_prop = None - self.confusion = None - self.balanced = None - self.save_hyperparameters() - - def prepare(self, - stagate: intSTAGATE, - dataset: Data, - target_y, - test_prop: float = 0.5, - balanced: bool = True): - self.balanced = balanced - self.test_prop = test_prop - with torch.no_grad(): - representation, _ = stagate(dataset.x, dataset.edge_index) - if hasattr(dataset, "ground_truth"): - ground_truth = dataset.ground_truth - else: - ground_truth = None - if isinstance(target_y, np.ndarray): - target_y = torch.from_numpy(target_y).type(torch.LongTensor) - elif isinstance(target_y, torch.Tensor): - target_y = target_y.type(torch.LongTensor) - else: - raise TypeError("target_y must be either a torch tensor or a numpy ndarray.") - self.dataset = RepDataset(representation, target_y, ground_truth=ground_truth) - n_val = int(len(self.dataset) * test_prop) - self.train_dataset, self.val_dataset = 
torch.utils.data.random_split( - self.dataset, [len(self.dataset) - n_val, n_val]) - if balanced: - target_y = target_y[self.train_dataset.indices] - class_sample_count = np.array([len(np.where(target_y == t)[0]) for t in np.unique(target_y)]) - weight = 1. / class_sample_count - samples_weight = np.array([weight[t] for t in target_y]) - samples_weight = torch.from_numpy(samples_weight) - samples_weight = samples_weight.double() - self.sampler = WeightedRandomSampler(samples_weight, len(samples_weight)) - - def forward(self, x, edge_index=None) -> Any: - if self.architecture == "MLP": - return self.model(x) - elif self.architecture == "STACls": - _, output = self.model(x, edge_index) - return output - - def training_step(self, batch, batch_idx): - batch = Batch(**batch) - batch = batch.to(self.device) - opt = self.optimizers() - opt.zero_grad() - output = self.model(batch.x) - loss = F.cross_entropy(output["score"], batch.y) - self.manual_backward(loss) - opt.step() - _, correct = _compute_correct(output["score"], batch.y) - self.log(f"Training {self.architecture} classifier|Cross entropy", loss.item(), prog_bar=True) - return {"loss": loss, "correct": correct} - - def training_epoch_end(self, outputs): - time = self.timer.toc('train') - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/train_time', time, self.current_epoch) - all_loss = torch.stack([x["loss"] for x in outputs]) - all_correct = np.sum([x["correct"] for x in outputs]) - train_acc = all_correct / len(self.train_dataset) - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/loss', - torch.mean(all_loss), self.current_epoch) - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/train_acc', - train_acc, self.current_epoch) - - def validation_step(self, batch, batch_idx): - batch = Batch(**batch) - batch = batch.to(self.device) - with torch.no_grad(): - output = self.model(batch.x) - loss = F.cross_entropy(output["score"], batch.y) - pred_labels, correct = _compute_correct(output["score"], batch.y) - return {"loss": loss, "correct": correct, "pred_labels": pred_labels, "true_labels": batch.y} - - def validation_epoch_end(self, outputs): - all_loss = torch.stack([x["loss"] for x in outputs]) - all_correct = np.sum([x["correct"] for x in outputs]) - pred_labels = torch.cat([x["pred_labels"] for x in outputs]).cpu().detach().numpy() - true_labels = torch.cat([x["true_labels"] for x in outputs]).cpu().detach().numpy() - confusion = confusion_matrix(true_labels, pred_labels) - val_acc = all_correct / len(self.val_dataset) - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/val_loss', - torch.mean(all_loss), self.current_epoch) - self.logger.experiment.add_scalar(f'classifier-{self.architecture}/val_acc', - val_acc, self.current_epoch) - print("\n validation ACC={:.4f}".format(val_acc)) - self.confusion = confusion - - def train_dataloader(self): - loader = DataLoader(self.train_dataset, batch_size=self.batch_size, sampler=self.sampler) - return loader - - def val_dataloader(self): - loader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=self.shuffle, drop_last=False) - return loader - - def test_dataloader(self): - raise NotImplementedError - - def predict_dataloader(self): - raise NotImplementedError \ No newline at end of file diff --git a/stamarker/stamarker/stamarker/modules.py b/stamarker/stamarker/stamarker/modules.py deleted file mode 100644 index 08d9d81..0000000 --- a/stamarker/stamarker/stamarker/modules.py +++ /dev/null @@ -1,276 +0,0 @@ 
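The balanced branch of `StackClassifier.prepare` above turns inverse class frequencies into per-sample weights for a `WeightedRandomSampler`, so minority spatial domains are oversampled during classifier training. A minimal, self-contained sketch of the same idea on toy labels (nothing here depends on STAMarker itself):

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy, imbalanced labels: class 0 occurs four times as often as class 1.
target_y = torch.tensor([0, 0, 0, 0, 1, 0, 0, 0, 0, 1])
x = torch.randn(len(target_y), 3)

# Inverse-frequency weight per class, then one weight per sample.
class_sample_count = np.array([(target_y == t).sum().item() for t in torch.unique(target_y)])
weight = 1.0 / class_sample_count
samples_weight = torch.from_numpy(np.array([weight[int(t)] for t in target_y])).double()

# Sampling with replacement according to samples_weight gives each class a roughly
# equal chance of appearing in a batch, which is what the training loop above assumes.
sampler = WeightedRandomSampler(samples_weight, num_samples=len(samples_weight))
loader = DataLoader(TensorDataset(x, target_y), batch_size=4, sampler=sampler)
for xb, yb in loader:
    print(yb.tolist())  # classes are approximately balanced across batches
```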
-import abc -import copy -from torch.autograd import Variable -import torch -from torch import Tensor -import torch.nn.functional as F -from torch.nn import Parameter -import torch.nn as nn -from torch_sparse import SparseTensor, set_diag -from torch_geometric.nn.conv import MessagePassing -from torch_geometric.utils import remove_self_loops, add_self_loops, softmax -from typing import Union, Tuple, Optional -from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType, - OptTensor) - - -class GATConv(MessagePassing): - r"""The graph attentional operator from the `"Graph Attention Networks" - `_ paper - .. math:: - \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + - \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j}, - where the attention coefficients :math:`\alpha_{i,j}` are computed as - .. math:: - \alpha_{i,j} = - \frac{ - \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} - [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] - \right)\right)} - {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} - \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} - [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k] - \right)\right)}. - Args: - in_channels (int or tuple): Size of each input sample, or :obj:`-1` to - derive the size from the first input(s) to the forward method. - A tuple corresponds to the sizes of source and target - dimensionalities. - out_channels (int): Size of each output sample. - heads (int, optional): Number of multi-head-attentions. - (default: :obj:`1`) - concat (bool, optional): If set to :obj:`False`, the multi-head - attentions are averaged instead of concatenated. - (default: :obj:`True`) - negative_slope (float, optional): LeakyReLU angle of the negative - slope. (default: :obj:`0.2`) - dropout (float, optional): Dropout probability of the normalized - attention coefficients which exposes each node to a stochastically - sampled neighborhood during training. (default: :obj:`0`) - add_self_loops (bool, optional): If set to :obj:`False`, will not add - self-loops to the input graph. (default: :obj:`True`) - bias (bool, optional): If set to :obj:`False`, the layer will not learn - an additive bias. (default: :obj:`True`) - **kwargs (optional): Additional arguments of - :class:`torch_geometric.nn.conv.MessagePassing`. 
- """ - _alpha: OptTensor - - def __init__(self, in_channels: Union[int, Tuple[int, int]], - out_channels: int, heads: int = 1, concat: bool = True, - negative_slope: float = 0.2, dropout: float = 0.0, - add_self_loops: bool = True, bias: bool = True, **kwargs): - kwargs.setdefault('aggr', 'add') - super(GATConv, self).__init__(node_dim=0, **kwargs) - - self.in_channels = in_channels - self.out_channels = out_channels - self.heads = heads - self.concat = concat - self.negative_slope = negative_slope - self.dropout = dropout - self.add_self_loops = add_self_loops - self.lin_src = nn.Parameter(torch.zeros(size=(in_channels, out_channels))) - nn.init.xavier_normal_(self.lin_src.data, gain=1.414) - self.lin_dst = self.lin_src - # The learnable parameters to compute attention coefficients: - self.att_src = Parameter(torch.Tensor(1, heads, out_channels)) - self.att_dst = Parameter(torch.Tensor(1, heads, out_channels)) - nn.init.xavier_normal_(self.att_src.data, gain=1.414) - nn.init.xavier_normal_(self.att_dst.data, gain=1.414) - self._alpha = None - self.attentions = None - - def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, - size: Size = None, return_attention_weights=None, attention=True, tied_attention=None): - # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor # noqa - # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor # noqa - # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa - # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor] # noqa - r""" - Args: - return_attention_weights (bool, optional): If set to :obj:`True`, - will additionally return the tuple - :obj:`(edge_index, attention_weights)`, holding the computed - attention weights for each edge. (default: :obj:`None`) - """ - H, C = self.heads, self.out_channels - - # We first transform the input node features. 
If a tuple is passed, we - # transform source and target node features via separate weights: - if isinstance(x, Tensor): - assert x.dim() == 2, "Static graphs not supported in 'GATConv'" - # x_src = x_dst = self.lin_src(x).view(-1, H, C) - x_src = x_dst = torch.mm(x, self.lin_src).view(-1, H, C) - else: # Tuple of source and target node features: - x_src, x_dst = x - assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'" - x_src = self.lin_src(x_src).view(-1, H, C) - if x_dst is not None: - x_dst = self.lin_dst(x_dst).view(-1, H, C) - - x = (x_src, x_dst) - - if not attention: - return x[0].mean(dim=1) - # return x[0].view(-1, self.heads * self.out_channels) - - if tied_attention == None: - # Next, we compute node-level attention coefficients, both for source - # and target nodes (if present): - alpha_src = (x_src * self.att_src).sum(dim=-1) - alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1) - alpha = (alpha_src, alpha_dst) - self.attentions = alpha - else: - alpha = tied_attention - - if self.add_self_loops: - if isinstance(edge_index, Tensor): - # We only want to add self-loops for nodes that appear both as - # source and target nodes: - num_nodes = x_src.size(0) - if x_dst is not None: - num_nodes = min(num_nodes, x_dst.size(0)) - num_nodes = min(size) if size is not None else num_nodes - edge_index, _ = remove_self_loops(edge_index) - edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) - elif isinstance(edge_index, SparseTensor): - edge_index = set_diag(edge_index) - - # propagate_type: (x: OptPairTensor, alpha: OptPairTensor) - out = self.propagate(edge_index, x=x, alpha=alpha, size=size) - - alpha = self._alpha - assert alpha is not None - self._alpha = None - - if self.concat: - out = out.view(-1, self.heads * self.out_channels) - else: - out = out.mean(dim=1) - - # if self.bias is not None: - # out += self.bias - - if isinstance(return_attention_weights, bool): - if isinstance(edge_index, Tensor): - return out, (edge_index, alpha) - elif isinstance(edge_index, SparseTensor): - return out, edge_index.set_value(alpha, layout='coo') - else: - return out - - def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor, - index: Tensor, ptr: OptTensor, - size_i: Optional[int]) -> Tensor: - # Given egel-level attention coefficients for source and target nodes, - # we simply need to sum them up to "emulate" concatenation: - alpha = alpha_j if alpha_i is None else alpha_j + alpha_i - - alpha = F.leaky_relu(alpha, self.negative_slope) - alpha = softmax(alpha, index, ptr, size_i) - self._alpha = alpha # Save for later use. 
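        # (Editorial note) torch_geometric.utils.softmax normalizes the raw attention
        # logits per destination node (grouped by `index`/`ptr`), so the incoming edge
        # weights of every node sum to one. The tensor cached in self._alpha is what
        # forward() hands back when return_attention_weights is requested.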
- alpha = F.dropout(alpha, p=self.dropout, training=self.training) - return x_j * alpha.unsqueeze(-1) - - def __repr__(self): - return '{}({}, {}, heads={})'.format(self.__class__.__name__, - self.in_channels, - self.out_channels, self.heads) - - -class STAGATEModule(nn.Module): - def __init__(self, in_features, hidden_dims): - super(STAGATEModule, self).__init__() - [num_hidden, out_dim] = hidden_dims - self.conv1 = GATConv(in_features, num_hidden, heads=1, concat=False, - dropout=0, add_self_loops=False, bias=False) - self.conv2 = GATConv(num_hidden, out_dim, heads=1, concat=False, - dropout=0, add_self_loops=False, bias=False) - self.conv3 = GATConv(out_dim, num_hidden, heads=1, concat=False, - dropout=0, add_self_loops=False, bias=False) - self.conv4 = GATConv(num_hidden, in_features, heads=1, concat=False, - dropout=0, add_self_loops=False, bias=False) - - def forward(self, features, edge_index): - h1 = F.elu(self.conv1(features, edge_index)) - h2 = self.conv2(h1, edge_index, attention=False) - self.conv3.lin_src.data = self.conv2.lin_src.transpose(0, 1) - self.conv3.lin_dst.data = self.conv2.lin_dst.transpose(0, 1) - self.conv4.lin_src.data = self.conv1.lin_src.transpose(0, 1) - self.conv4.lin_dst.data = self.conv1.lin_dst.transpose(0, 1) - h3 = F.elu(self.conv3(h2, edge_index, attention=True, - tied_attention=self.conv1.attentions)) - h4 = self.conv4(h3, edge_index, attention=False) - - return h2, h4 - - -class StackClsModule(nn.Module, abc.ABC): - def __init__(self, in_features, n_classes): - super(StackClsModule, self).__init__() - self.in_features = in_features - self.n_classes = n_classes - - -class STAGATEClsModule(nn.Module): - def __init__(self, - stagate: STAGATEModule, - stack_classifier: StackClsModule): - super(STAGATEClsModule, self).__init__() - self.stagate = copy.deepcopy(stagate) - self.classifier = copy.deepcopy(stack_classifier) - - def forward(self, x, edge_index, mode="classifier"): - z, x_recon = self.stagate(x, edge_index) - z = torch.clone(z) - if mode == "classifier": - return z, self.classifier(z) - elif mode == "reconstruction": - return z, x_recon - else: - raise NotImplementedError - - def get_saliency_map(self, x, edge_index, target_index="max", save=None): - """ - Get saliency map by backpropagation. 
- :param x: input tensors - :param edge_index: graph edge index - :param target_index: target index to compute final scores - :param save: - :return: gradients - """ - x_var = Variable(x, requires_grad=True) - _, output = self.forward(x_var, edge_index, mode="classifier") - scores = output["last_layer"] - if target_index == "max": - target_score_indices = Variable(torch.argmax(scores, 1)) - elif isinstance(target_index, int): - target_score_indices = Variable(torch.ones(scores.shape[0], dtype=torch.int64) * target_index) - else: - raise NotImplementedError - target_scores = scores.gather(1, target_score_indices.view(-1, 1)).squeeze() - loss = torch.sum(target_scores) - loss.backward() - gradients = x_var.grad.data - if save is not None: - torch.save(gradients, save) - return gradients, scores - - -class StackMLPModule(StackClsModule): - name = "StackMLP" - - def __init__(self, in_features, n_classes, hidden_dims=[30, 40, 30]): - super(StackMLPModule, self).__init__(in_features, n_classes) - self.classifier = nn.ModuleList() - mlp_dims = [in_features] + hidden_dims + [n_classes] - for ind in range(len(mlp_dims) - 1): - self.classifier.append(nn.Linear(mlp_dims[ind], mlp_dims[ind + 1])) - - def forward(self, x): - for layer in self.classifier: - x = layer(x) - score = F.softmax(x, dim=0) - return {"last_layer": x, "score": score} diff --git a/stamarker/stamarker/stamarker/pipeline.py b/stamarker/stamarker/stamarker/pipeline.py deleted file mode 100644 index 98255ae..0000000 --- a/stamarker/stamarker/stamarker/pipeline.py +++ /dev/null @@ -1,286 +0,0 @@ -import pytorch_lightning as pl -import copy -import torch -import os -import shutil -import logging -import glob -import sys -import numpy as np -import scipy -import scanpy as sc -from pytorch_lightning.loggers import TensorBoardLogger -from scipy.cluster import hierarchy -from .models import intSTAGATE, StackClassifier -from .utils import plot_consensus_map, consensus_matrix, Timer -from .dataset import SpatialDataModule -from .modules import STAGATEClsModule -import logging - - -FORMAT = "%(asctime)s %(levelname)s %(message)s" -logging.basicConfig(format=FORMAT, datefmt='%Y-%m-%d %H:%M:%S') -def make_spatial_data(ann_data): - """ - Make SpatialDataModule object from Scanpy annData object - """ - data_module = SpatialDataModule() - ann_data.X = ann_data.X.toarray() - data_module.ann_data = ann_data - return data_module - - -class STAMarker: - def __init__(self, n, save_dir, config, logging_level=logging.INFO): - """ - n: int, number of graph attention auto-econders to train - save_dir: directory to save the models - config: config file for training - """ - self.n = n - self.save_dir = save_dir - if not os.path.exists(save_dir): - os.mkdir(save_dir) - logging.info("Create save directory {}".format(save_dir)) - self.version_dirs = [os.path.join(save_dir, f"version_{i}") for i in range(n)] - self.config = config - self.logger = logging.getLogger("STAMarker") - self.logger.setLevel(logging_level) - self.consensus_labels = None - - def load_from_dir(self, save_dir, ): - """ - Load the trained models from a directory - """ - self.version_dirs = glob.glob(os.path.join(save_dir, "version_*")) - self.version_dirs = sorted(self.version_dirs, key=lambda x: int(x.split("_")[-1])) - # check if all version dir have `checkpoints/stagate.ckpt` - version_dirs_valid = [] - for version_dir in self.version_dirs: - if not os.path.exists(os.path.join(version_dir, "checkpoints/stagate.ckpt")): - self.logger.warning("No checkpoint found in {}".format(version_dir)) 
- else: - version_dirs_valid.append(version_dir) - self.version_dirs = version_dirs_valid - self.logger.info("Load {} autoencoder models from {}".format(len(version_dirs_valid), save_dir)) - # check if all version dir have `cluster_labels.npy` raise warning if not - missing_cluster_labels = False - for version_dir in self.version_dirs: - if not os.path.exists(os.path.join(version_dir, "cluster_labels.npy")): - missing_cluster_labels = True - msg = "No cluster labels found in {}.".format(version_dir) - self.logger.warning(msg) - if missing_cluster_labels: - self.logger.warning("Please run clustering first.") - # check if save_dir has `consensus.npy` raise warning if not - if not os.path.exists(os.path.join(save_dir, "consensus.npy")): - self.logger.warning("No consensus labels found in {}".format(save_dir)) - else: - self.consensus_labels = np.load(os.path.join(save_dir, "consensus.npy")) - # check if all version dir have `checkpoints/mlp.ckpt` raise warning if not - missing_clf = False - for version_dir in self.version_dirs: - if not os.path.exists(os.path.join(version_dir, "checkpoints/mlp.ckpt")): - self.logger.warning("No classifier checkpoint found in {}".format(version_dir)) - missing_clf = True - if missing_clf: - self.logger.warning("Please run classifier training first.") - if not missing_cluster_labels and not missing_clf: - self.logger.info("All models are trained and ready to use.") - - def train_auto_encoders(self, data_module): - for seed in range(self.n): - self._train_auto_encoder(data_module, seed, self.config) - self.logger.info("Finished training {} auto-encoders".format(self.n)) - - def clustering(self, data_module, cluster_method, cluster_params): - """ - Cluster the latent space of the trained auto-encoders - Cluster method should be "louvain" or "mclust" - """ - for version_dir in self.version_dirs: - self._clustering(data_module, version_dir, cluster_method, cluster_params) - self.logger.info("Finished {} clustering with {}".format(self.n, cluster_method)) - - def consensus_clustering(self, n_clusters, name="cluster_labels.npy"): - sys.setrecursionlimit(100000) - label_files = glob.glob(self.save_dir + f"/version_*/{name}") - labels_list = list(map(lambda file: np.load(file), label_files)) - cons_mat = consensus_matrix(labels_list) - row_linkage, _, figure = plot_consensus_map(cons_mat, return_linkage=True) - figure.savefig(os.path.join(self.save_dir, "consensus_clustering.png"), dpi=300) - consensus_labels = hierarchy.cut_tree(row_linkage, n_clusters).squeeze() - np.save(os.path.join(self.save_dir, "consensus"), consensus_labels) - self.consensus_labels = consensus_labels - self.logger.info("Save consensus labels to {}".format(os.path.join(self.save_dir, "consensus.npz"))) - - def train_classifiers(self, data_module, n_clusters, name="cluster_labels.npy"): - for version_dir in self.version_dirs: - # _train_classifier(self, data_module, version_dir, target_y, n_classes, seed=None) - self._train_classifier(data_module, version_dir, self.consensus_labels, n_clusters, self.config) - self.logger.info("Finished training {} classifiers".format(self.n)) - - def compute_smaps(self, data_module, return_recon=True, normalize=True): - smaps = [] - if return_recon: - recons = [] - for version_dir in self.version_dirs: - if return_recon: - smap, recon = self._compute_smap(data_module, version_dir, return_recon=return_recon) - smaps.append(smap) - recons.append(recon) - else: - smap = self._compute_smap(data_module, version_dir, return_recon=return_recon) - smaps.append(smap) - if 
return_recon: - return smaps, recons - else: - return smaps - self.logger.info("Finished computing {} smaps".format(self.n)) - - - def _compute_smap_zscore(self, smap, labels, logtransform=False): - scores = np.log(smap + 1) if logtransform else copy.copy(smap) - unique_labels = np.unique(labels) - for l in unique_labels: - scores[labels == l, :] = scipy.stats.zscore(scores[labels == l, :], axis=1) - return scores - - - def _clustering(self, data_module, version_dir, cluster_method, cluster_params): - """ - Cluster the latent space of the trained auto-encoder - """ - if cluster_method == "louvain": - run_louvain(data_module, version_dir, cluster_params) - elif cluster_method == "mclust": - run_mclust(data_module, version_dir, cluster_params) - else: - raise ValueError("Unknown clustering method") - - def _train_auto_encoder(self, data_module, seed, config): - """ - Train a single graph attention auto-encoder - """ - pl.seed_everything(seed) - version = f"version_{seed}" - version_dir = os.path.join(self.save_dir, version) - if os.path.exists(version_dir): - shutil.rmtree(version_dir) - os.makedirs(version_dir, exist_ok=True) - logger = TensorBoardLogger(save_dir=self.save_dir, name=None, - default_hp_metric=False, - version=seed) - model = intSTAGATE(**config["stagate"]["params"]) - model.set_optimizer_params(config["stagate"]["optimizer"], config["stagate"]["scheduler"]) - trainer = pl.Trainer(logger=logger, **config["stagate_trainer"]) - timer = Timer() - timer.tic("fit") - trainer.fit(model, data_module) - fit_time = timer.toc("fit") - with open(os.path.join(version_dir, "runtime.csv"), "w+") as f: - f.write("{}, fit_time, {:.2f}, ".format(seed, fit_time / 60)) - trainer.save_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - del model, trainer - if config["stagate_trainer"]["gpus"] > 0: - torch.cuda.empty_cache() - logging.info(f"Finshed running version {seed}") - - def _train_classifier(self, data_module, version_dir, target_y, n_classes, config, seed=None): - timer = Timer() - pl.seed_everything(seed) - rep_dim = config["stagate"]["params"]["hidden_dims"][-1] - stagate = intSTAGATE.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - classifier = StackClassifier(rep_dim, n_classes=n_classes, architecture="MLP") - classifier.prepare(stagate, data_module.train_dataset, target_y, - balanced=config["mlp"]["balanced"], test_prop=config["mlp"]["test_prop"]) - classifier.set_optimizer_params(config["mlp"]["optimizer"], config["mlp"]["scheduler"]) - logger = TensorBoardLogger(save_dir=self.save_dir, name=None, - default_hp_metric=False, - version=seed) - trainer = pl.Trainer(logger=logger, **config["classifier_trainer"]) - timer.tic("clf") - trainer.fit(classifier) - clf_time = timer.toc("clf") - with open(os.path.join(version_dir, "runtime.csv"), "a+") as f: - f.write("\n") - f.write("{}, clf_time, {:.2f}, ".format(seed, clf_time / 60)) - trainer.save_checkpoint(os.path.join(version_dir, "checkpoints", "mlp.ckpt")) - target_y = classifier.dataset.target_y.numpy() - all_props = class_proportions(target_y) - val_props = class_proportions(target_y[classifier.val_dataset.indices]) - if self.logger.level == logging.DEBUG: - print("All class proportions " + "|".join(["{:.2f}%".format(prop * 100) for prop in all_props])) - print("Val class proportions " + "|".join(["{:.2f}%".format(prop * 100) for prop in val_props])) - np.save(os.path.join(version_dir, "confusion.npy"), classifier.confusion) - - def _compute_smap(self, data_module, version_dir, 
return_recon=True): - """ - Compute the saliency map of the trained auto-encoder - """ - stagate = intSTAGATE.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - cls = StackClassifier.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "mlp.ckpt")) - stagate_cls = STAGATEClsModule(stagate.model, cls.model) - smap, _ = stagate_cls.get_saliency_map(data_module.train_dataset.x, - data_module.train_dataset.edge_index) - smap = smap.detach().cpu().numpy() - if return_recon: - recon = stagate(data_module.train_dataset.x, data_module.train_dataset.edge_index)[1].cpu().detach().numpy() - return smap, recon - else: - return smap - - -def run_louvain(data_module, version_dir, resolution, name="cluster_labels"): - """ - Run louvain clustering on the data_module - """ - stagate = intSTAGATE.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - embedding = stagate(data_module.train_dataset.x, data_module.train_dataset.edge_index)[0].cpu().detach().numpy() - ann_data = copy.copy(data_module.ann_data) - ann_data.obsm["stagate"] = embedding - sc.pp.neighbors(ann_data, use_rep='stagate') - sc.tl.louvain(ann_data, resolution=resolution) - save_path = os.path.join(version_dir, "{}.npy".format(name)) - np.save(save_path, ann_data.obs["louvain"].to_numpy().astype("int")) - print("Save louvain results to {}".format(save_path)) - - -def mclust_R(representation, n_clusters, r_seed=2022, model_name="EEE"): - """ - Clustering using the mclust algorithm. - The parameters are the same as those in the R package mclust. - """ - np.random.seed(r_seed) - import rpy2.robjects as ro - from rpy2.robjects import numpy2ri - numpy2ri.activate() - ro.r.library("mclust") - r_random_seed = ro.r['set.seed'] - r_random_seed(r_seed) - rmclust = ro.r['Mclust'] - res = rmclust(representation, n_clusters, model_name) - mclust_res = np.array(res[-2]) - numpy2ri.deactivate() - return mclust_res.astype('int') - - -def run_mclust(data_module, version_dir, n_clusters, name="cluster_labels"): - stagate = intSTAGATE.load_from_checkpoint(os.path.join(version_dir, "checkpoints", "stagate.ckpt")) - embedding = stagate(data_module.train_dataset.x, data_module.train_dataset.edge_index)[0].cpu().detach().numpy() - labels = mclust_R(embedding, n_clusters) - save_path = os.path.join(version_dir, "{}.npy".format(name)) - np.save(save_path, labels.astype("int")) - print("Save MClust results to {}".format(save_path)) - -def class_proportions(target): - n_classes = len(np.unique(target)) - props = np.array([np.sum(target == i) for i in range(n_classes)]) - return props / np.sum(props) - - - - - - - diff --git a/stamarker/stamarker/stamarker/utils.py b/stamarker/stamarker/stamarker/utils.py deleted file mode 100644 index c4f6182..0000000 --- a/stamarker/stamarker/stamarker/utils.py +++ /dev/null @@ -1,190 +0,0 @@ -import time -import yaml -import os -import seaborn as sns -import numpy as np -import pandas as pd -import scanpy as sc -import itertools -import scipy -from scipy.spatial import distance -from scipy.cluster import hierarchy -import sklearn.neighbors -from typing import List - - -def plot_consensus_map(cmat, method="average", return_linkage=True, **kwargs): - row_linkage = hierarchy.linkage(distance.pdist(cmat), method=method) - col_linkage = hierarchy.linkage(distance.pdist(cmat.T), method=method) - figure = sns.clustermap(cmat, row_linkage=row_linkage, col_linkage=col_linkage, **kwargs) - if return_linkage: - return row_linkage, col_linkage, figure - else: - return figure - - 
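The deleted `pipeline.py` shown above ties the data module, the STAGATE auto-encoders, consensus clustering, and the MLP classifiers into one workflow. A hedged end-to-end sketch follows; the import paths, file names, and the YAML layout behind `config` are assumptions (only the key groups `stagate`, `stagate_trainer`, `mlp`, and `classifier_trainer` are actually read by the pipeline code):

```python
import scanpy as sc
from stamarker.pipeline import STAMarker, make_spatial_data   # assumed import path
from stamarker.utils import parse_args

config = parse_args("config.yaml")        # hypothetical config file; see key groups above
ann_data = sc.read_h5ad("sample.h5ad")    # hypothetical input with a sparse .X and obsm["spatial"]

data_module = make_spatial_data(ann_data)              # wraps the AnnData in a SpatialDataModule
data_module.prepare_data(n_top_genes=3000, rad_cutoff=50)

stamarker = STAMarker(n=5, save_dir="./stamarker_out", config=config)
stamarker.train_auto_encoders(data_module)             # n independently seeded STAGATE runs
stamarker.clustering(data_module, "louvain", 0.8)      # cluster_params is the Louvain resolution here
stamarker.consensus_clustering(n_clusters=7)           # cuts the consensus-matrix dendrogram
stamarker.train_classifiers(data_module, n_clusters=7)
smaps = stamarker.compute_smaps(data_module, return_recon=False)
```

Each step writes its artefacts (checkpoints, per-run cluster labels, and the consensus labels) under `save_dir/version_*`, which is what `load_from_dir` later checks for.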
-class Timer: - - def __init__(self): - self.timer_dict = {} - self.stop_dict = {} - - def tic(self, name): - self.timer_dict[name] = time.time() - - def toc(self, name): - assert name in self.timer_dict - elapsed = time.time() - self.timer_dict[name] - del self.timer_dict[name] - return elapsed - - def stop(self, name): - self.stop_dict[name] = time.time() - - def resume(self, name): - if name not in self.timer_dict: - del self.stop_dict[name] - return - elapsed = time.time() - self.stop_dict[name] - self.timer_dict[name] = self.timer_dict[name] + elapsed - del self.stop_dict[name] - - -def save_yaml(yaml_object, file_path): - with open(file_path, 'w') as yaml_file: - yaml.dump(yaml_object, yaml_file, default_flow_style=False) - - print(f'Saving yaml: {file_path}') - return - - -def parse_args(yaml_file): - with open(yaml_file, 'r') as stream: - try: - cfg = yaml.safe_load(stream) - except yaml.YAMLError as exc: - print(exc) - return cfg - - -def mclust_R(representation, n_clusters, r_seed=2022, model_name="EEE"): - """ - Clustering using the mclust algorithm. - The parameters are the same as those in the R package mclust. - """ - np.random.seed(r_seed) - import rpy2.robjects as ro - from rpy2.robjects import numpy2ri - numpy2ri.activate() - ro.r.library("mclust") - r_random_seed = ro.r['set.seed'] - r_random_seed(r_seed) - rmclust = ro.r['Mclust'] - res = rmclust(representation, n_clusters, model_name) - mclust_res = np.array(res[-2]) - numpy2ri.deactivate() - return mclust_res.astype('int') - - -def labels_connectivity_mat(labels: np.ndarray): - _labels = labels - np.min(labels) - n_classes = np.unique(_labels) - mat = np.zeros([labels.size, labels.size]) - for i in n_classes: - indices = np.squeeze(np.where(_labels == i)) - row_indices, col_indices = zip(*itertools.product(indices, indices)) - mat[row_indices, col_indices] = 1 - return mat - - -def consensus_matrix(labels_list: List[np.ndarray]): - mat = 0 - for labels in labels_list: - mat += labels_connectivity_mat(labels) - return mat / float(len(labels_list)) - - -def compute_spatial_net(ann_data, rad_cutoff=None, k_cutoff=None, - max_neigh=50, model='Radius', verbose=True): - """ - Construct the spatial neighbor networks. - - Parameters - ---------- - ann_data - AnnData object of scanpy package. - rad_cutoff - radius cutoff when model='Radius' - k_cutoff - The number of nearest neighbors when model='KNN' - model - The network construction model. When model=='Radius', the spot is connected to spots whose distance is less than rad_cutoff. When model=='KNN', the spot is connected to its first k_cutoff nearest neighbors. 
- - Returns - ------- - The spatial networks are saved in adata.uns['Spatial_Net'] - """ - - assert (model in ['Radius', 'KNN']) - if verbose: - print('------Calculating spatial graph...') - coor = pd.DataFrame(ann_data.obsm['spatial']) - coor.index = ann_data.obs.index - coor.columns = ['imagerow', 'imagecol'] - - nbrs = sklearn.neighbors.NearestNeighbors( - n_neighbors=max_neigh + 1, algorithm='ball_tree').fit(coor) - distances, indices = nbrs.kneighbors(coor) - if model == 'KNN': - indices = indices[:, 1:k_cutoff + 1] - distances = distances[:, 1:k_cutoff + 1] - if model == 'Radius': - indices = indices[:, 1:] - distances = distances[:, 1:] - KNN_list = [] - for it in range(indices.shape[0]): - KNN_list.append(pd.DataFrame(zip([it] * indices.shape[1], indices[it, :], distances[it, :]))) - KNN_df = pd.concat(KNN_list) - KNN_df.columns = ['Cell1', 'Cell2', 'Distance'] - Spatial_Net = KNN_df.copy() - if model == 'Radius': - Spatial_Net = KNN_df.loc[KNN_df['Distance'] < rad_cutoff,] - id_cell_trans = dict(zip(range(coor.shape[0]), np.array(coor.index), )) - cell1, cell2 = Spatial_Net['Cell1'].map(id_cell_trans), Spatial_Net['Cell2'].map(id_cell_trans) - Spatial_Net = Spatial_Net.assign(Cell1=cell1, Cell2=cell2) - # Spatial_Net.assign(Cell1=Spatial_Net['Cell1'].map(id_cell_trans)) - # Spatial_Net.assign(Cell2=Spatial_Net['Cell2'].map(id_cell_trans)) - if verbose: - print('The graph contains %d edges, %d cells.' % (Spatial_Net.shape[0], ann_data.n_obs)) - print('%.4f neighbors per cell on average.' % (Spatial_Net.shape[0] / ann_data.n_obs)) - ann_data.uns['Spatial_Net'] = Spatial_Net - - -def compute_edge_list(ann_data): - G_df = ann_data.uns['Spatial_Net'].copy() - cells = np.array(ann_data.obs_names) - cells_id_tran = dict(zip(cells, range(cells.shape[0]))) - G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran) - G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran) - G = scipy.sparse.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])), - shape=(ann_data.n_obs, ann_data.n_obs)) - G = G + scipy.sparse.eye(G.shape[0]) - edge_list = np.nonzero(G) - return edge_list - - -def stats_spatial_net(ann_data): - import matplotlib.pyplot as plt - Num_edge = ann_data.uns['Spatial_Net']['Cell1'].shape[0] - Mean_edge = Num_edge / ann_data.shape[0] - plot_df = pd.value_counts(pd.value_counts(ann_data.uns['Spatial_Net']['Cell1'])) - plot_df = plot_df / ann_data.shape[0] - fig, ax = plt.subplots(figsize=[3, 2]) - plt.ylabel('Percentage') - plt.xlabel('') - plt.title('Number of Neighbors (Mean=%.2f)' % Mean_edge) - ax.bar(plot_df.index, plot_df) - - -# def select_svgs(smaps, sd_id, labels, alpha=1.5): - diff --git a/stamarker/stamarker/utils.py b/stamarker/stamarker/utils.py deleted file mode 100644 index 10a4c14..0000000 --- a/stamarker/stamarker/utils.py +++ /dev/null @@ -1,192 +0,0 @@ -import time -import yaml -import os -import seaborn as sns -import numpy as np -import pandas as pd -import scanpy as sc -import itertools -import scipy -from scipy.spatial import distance -from scipy.cluster import hierarchy -import sklearn.neighbors -from typing import List - - -def plot_consensus_map(cmat, method="average", return_linkage=True, **kwargs): - row_linkage = hierarchy.linkage(distance.pdist(cmat), method=method) - col_linkage = hierarchy.linkage(distance.pdist(cmat.T), method=method) - figure = sns.clustermap(cmat, row_linkage=row_linkage, col_linkage=col_linkage, **kwargs) - if return_linkage: - return row_linkage, col_linkage, figure - else: - return figure - - -class Timer: - - def __init__(self): - 
self.timer_dict = {} - self.stop_dict = {} - - def tic(self, name): - self.timer_dict[name] = time.time() - - def toc(self, name): - assert name in self.timer_dict - elapsed = time.time() - self.timer_dict[name] - del self.timer_dict[name] - return elapsed - - def stop(self, name): - self.stop_dict[name] = time.time() - - def resume(self, name): - if name not in self.timer_dict: - del self.stop_dict[name] - return - elapsed = time.time() - self.stop_dict[name] - self.timer_dict[name] = self.timer_dict[name] + elapsed - del self.stop_dict[name] - - -def save_yaml(yaml_object, file_path): - with open(file_path, 'w') as yaml_file: - yaml.dump(yaml_object, yaml_file, default_flow_style=False) - - print(f'Saving yaml: {file_path}') - return - - -def parse_args(yaml_file): - with open(yaml_file, 'r') as stream: - try: - cfg = yaml.safe_load(stream) - except yaml.YAMLError as exc: - print(exc) - return cfg - - -def mclust_R(representation, n_clusters, r_seed=2022, model_name="EEE"): - """ - Clustering using the mclust algorithm. - The parameters are the same as those in the R package mclust. - """ - np.random.seed(r_seed) - import rpy2.robjects as ro - from rpy2.robjects import numpy2ri - numpy2ri.activate() - ro.r.library("mclust") - r_random_seed = ro.r['set.seed'] - r_random_seed(r_seed) - rmclust = ro.r['Mclust'] - res = rmclust(representation, n_clusters, model_name) - mclust_res = np.array(res[-2]) - numpy2ri.deactivate() - return mclust_res.astype('int') - - -def labels_connectivity_mat(labels: np.ndarray): - _labels = labels - np.min(labels) - n_classes = np.unique(_labels) - mat = np.zeros([labels.size, labels.size]) - for i in n_classes: - indices = np.squeeze(np.where(_labels == i)) - row_indices, col_indices = zip(*itertools.product(indices, indices)) - mat[row_indices, col_indices] = 1 - return mat - - -def consensus_matrix(labels_list: List[np.ndarray]): - mat = 0 - for labels in labels_list: - mat += labels_connectivity_mat(labels) - return mat / float(len(labels_list)) - - -def compute_spatial_net(ann_data, rad_cutoff=None, k_cutoff=None, - max_neigh=50, model='Radius', verbose=True): - """ - Construct the spatial neighbor networks. - - Parameters - ---------- - ann_data - AnnData object of scanpy package. - rad_cutoff - radius cutoff when model='Radius' - k_cutoff - The number of nearest neighbors when model='KNN' - model - The network construction model. When model=='Radius', the spot is connected to spots whose distance is less than rad_cutoff. When model=='KNN', the spot is connected to its first k_cutoff nearest neighbors. 
- - Returns - ------- - The spatial networks are saved in adata.uns['Spatial_Net'] - """ - - assert (model in ['Radius', 'KNN']) - if verbose: - print('------Calculating spatial graph...') - coor = pd.DataFrame(ann_data.obsm['spatial']) - coor.index = ann_data.obs.index - coor.columns = ['imagerow', 'imagecol'] - - nbrs = sklearn.neighbors.NearestNeighbors( - n_neighbors=max_neigh + 1, algorithm='ball_tree').fit(coor) - distances, indices = nbrs.kneighbors(coor) - if model == 'KNN': - indices = indices[:, 1:k_cutoff + 1] - distances = distances[:, 1:k_cutoff + 1] - if model == 'Radius': - indices = indices[:, 1:] - distances = distances[:, 1:] - KNN_list = [] - for it in range(indices.shape[0]): - KNN_list.append(pd.DataFrame(zip([it] * indices.shape[1], indices[it, :], distances[it, :]))) - KNN_df = pd.concat(KNN_list) - KNN_df.columns = ['Cell1', 'Cell2', 'Distance'] - Spatial_Net = KNN_df.copy() - if model == 'Radius': - Spatial_Net = KNN_df.loc[KNN_df['Distance'] < rad_cutoff,] - id_cell_trans = dict(zip(range(coor.shape[0]), np.array(coor.index), )) - cell1, cell2 = Spatial_Net['Cell1'].map(id_cell_trans), Spatial_Net['Cell2'].map(id_cell_trans) - Spatial_Net = Spatial_Net.assign(Cell1=cell1, Cell2=cell2) - # Spatial_Net.assign(Cell1=Spatial_Net['Cell1'].map(id_cell_trans)) - # Spatial_Net.assign(Cell2=Spatial_Net['Cell2'].map(id_cell_trans)) - if verbose: - print('The graph contains %d edges, %d cells.' % (Spatial_Net.shape[0], ann_data.n_obs)) - print('%.4f neighbors per cell on average.' % (Spatial_Net.shape[0] / ann_data.n_obs)) - ann_data.uns['Spatial_Net'] = Spatial_Net - - -def compute_edge_list(ann_data): - G_df = ann_data.uns['Spatial_Net'].copy() - cells = np.array(ann_data.obs_names) - cells_id_tran = dict(zip(cells, range(cells.shape[0]))) - G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran) - G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran) - G = scipy.sparse.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])), - shape=(ann_data.n_obs, ann_data.n_obs)) - G = G + scipy.sparse.eye(G.shape[0]) - edge_list = np.nonzero(G) - return edge_list - - -def stats_spatial_net(ann_data): - import matplotlib.pyplot as plt - Num_edge = ann_data.uns['Spatial_Net']['Cell1'].shape[0] - Mean_edge = Num_edge / ann_data.shape[0] - plot_df = pd.value_counts(pd.value_counts(ann_data.uns['Spatial_Net']['Cell1'])) - plot_df = plot_df / ann_data.shape[0] - fig, ax = plt.subplots(figsize=[3, 2]) - plt.ylabel('Percentage') - plt.xlabel('') - plt.title('Number of Neighbors (Mean=%.2f)' % Mean_edge) - ax.bar(plot_df.index, plot_df) - - -def select_stmaker_svgs(df, sd_id, alpha=1.5, top=None): - scores = df[f"score_{sd_id}"] - mu, std = np.mean(scores), np.std(scores) - return df.index[scores > mu + alpha * std].tolist()
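`select_stmaker_svgs` keeps the genes whose saliency score for spatial domain `sd_id` lies more than `alpha` standard deviations above the mean score. A toy illustration of that cutoff; the dataframe below is fabricated purely for demonstration:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
# Fake per-gene saliency scores for one spatial domain ("score_0"); a few genes
# are pushed well above the bulk to mimic domain-specific marker genes.
scores = rng.normal(size=500)
scores[:5] += 6.0
df = pd.DataFrame({"score_0": scores}, index=[f"gene_{i}" for i in range(500)])

alpha = 1.5
mu, std = np.mean(df["score_0"]), np.std(df["score_0"])
svgs = df.index[df["score_0"] > mu + alpha * std].tolist()
print(len(svgs))   # the inflated genes plus the upper tail of the background pass the cutoff
print(svgs[:5])    # gene_0 ... gene_4 are among the selected genes
```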