diff --git a/pyims/pyims/frame.py b/pyims/pyims/frame.py index 060c471c..314fcbb0 100644 --- a/pyims/pyims/frame.py +++ b/pyims/pyims/frame.py @@ -158,7 +158,8 @@ def filter(self, TimsFrame: Filtered frame. """ - return TimsFrame.from_py_tims_frame(self.__frame_ptr.filter_ranged(mz_min, mz_max, scan_min, scan_max, intensity_min, intensity_max)) + return TimsFrame.from_py_tims_frame(self.__frame_ptr.filter_ranged(mz_min, mz_max, scan_min, scan_max, + intensity_min, intensity_max)) def to_resolution(self, resolution: int) -> 'TimsFrame': """Convert the frame to a given resolution. @@ -179,7 +180,8 @@ def to_tims_spectra(self) -> List['TimsSpectrum']: """ return [TimsSpectrum.from_py_tims_spectrum(spec) for spec in self.__frame_ptr.to_tims_spectra()] - def to_windows(self, window_length: float = 10, overlapping: bool = True, min_num_peaks: int = 5, min_intensity: float = 1) -> List[MzSpectrum]: + def to_windows(self, window_length: float = 10, overlapping: bool = True, min_num_peaks: int = 5, + min_intensity: float = 1) -> List[MzSpectrum]: """Convert the frame to a list of windows. Args: @@ -191,7 +193,8 @@ def to_windows(self, window_length: float = 10, overlapping: bool = True, min_nu Returns: List[MzSpectrum]: List of windows. """ - return [MzSpectrum.from_py_mz_spectrum(spec) for spec in self.__frame_ptr.to_windows(window_length, overlapping, min_num_peaks, min_intensity)] + return [MzSpectrum.from_py_mz_spectrum(spec) for spec in self.__frame_ptr.to_windows( + window_length, overlapping, min_num_peaks, min_intensity)] def __repr__(self): return (f"TimsFrame(frame_id={self.__frame_ptr.frame_id}, ms_type={self.__frame_ptr.ms_type}, " @@ -221,4 +224,5 @@ def __init__(self, frame_id: int, ms_type: int, retention_time: float, scan: NDA assert len(scan) == len(mobility) == len(tof) == len(indices) == len(intensity), \ "The length of the scan, mobility, tof, indices and intensity arrays must be equal." 
-        self.__frame_ptr = pims.PyTimsFrameVectorized(frame_id, ms_type, retention_time, scan, mobility, tof, indices, intensity)
\ No newline at end of file
+        self.__frame_ptr = pims.PyTimsFrameVectorized(frame_id, ms_type, retention_time, scan, mobility, tof, indices,
+                                                      intensity)
\ No newline at end of file
diff --git a/pyims/pyims/handle.py b/pyims/pyims/handle.py
index 49b00c0b..62de67af 100644
--- a/pyims/pyims/handle.py
+++ b/pyims/pyims/handle.py
@@ -24,7 +24,7 @@ def __init__(self, data_path: str):
         self.data_path = data_path
         self.bp: List[str] = obb.get_so_paths()
         self.meta_data = self.__load_meta_data()
-        self.precursor_frames = self.meta_data[self.meta_data["MsMsType"] == 0].Id.values
+        self.precursor_frames = self.meta_data[self.meta_data["MsMsType"] == 0].Id.values.astype(np.int32)
         self.__handle = None
         self.__current_index = 1
 
diff --git a/pyims/pyims/mixture.py b/pyims/pyims/mixture.py
new file mode 100644
index 00000000..4588acd7
--- /dev/null
+++ b/pyims/pyims/mixture.py
@@ -0,0 +1,173 @@
+import os
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+
+import tensorflow as tf
+import numpy as np
+import tensorflow_probability as tfp
+
+tfd = tfp.distributions
+
+
+class GaussianMixtureModel(tf.Module):
+    """Gaussian mixture model with diagonal covariances, fitted by gradient descent.
+
+    Trainable variables: ``locs`` (component means), ``scales`` (natural log of the per-dimension
+    standard deviations) and ``weights`` (unnormalized mixture logits).
+    """
+
+    def __init__(self, num_components: int, data_dim: int, prior_stdevs=None, lambda_scale=0.01, data=None,
+                 init_means=None, init_stds=None):
+        """
+        Initialize the Gaussian Mixture Model.
+
+        Parameters:
+        - num_components: Number of Gaussian components.
+        - data_dim: Dimensionality of the data.
+        - prior_stdevs (optional): Prior knowledge about cluster extensions (standard deviations).
+        - lambda_scale: Regularization strength for the scales.
+        - data (optional): If provided and no init_means is given, initialize the component means by randomly
+          selecting (with replacement) from this data.
+        - init_means (optional): Explicit initial means for the components, shape [num_components, data_dim].
+        - init_stds (optional): Explicit initial standard deviations for the components, shape
+          [num_components, data_dim]; all values must be positive. If omitted, [3, 0.01, 0.01] is used for
+          three-dimensional data and unit standard deviations otherwise.
+        """
+        super().__init__()
+
+        expected_shape = (num_components, data_dim)
+
+        # Initialize the locations of the GMM components
+        if init_means is not None:
+            assert init_means.shape == expected_shape, \
+                f"init_means should have shape [num_components, data_dim], but got {init_means.shape}"
+            init_locs = tf.convert_to_tensor(init_means, dtype=tf.float32)
+        elif data is not None:
+            indices = np.random.choice(data.shape[0], size=num_components, replace=True)
+            init_locs = tf.convert_to_tensor(data[indices], dtype=tf.float32)
+        else:
+            init_locs = tf.random.normal([num_components, data_dim])
+
+        self.locs = tf.Variable(init_locs, name="locs")
+
+        # Initialize the standard deviations of the GMM components (stored as logs in `scales`)
+        if init_stds is not None:
+            assert init_stds.shape == expected_shape, \
+                f"init_stds should have shape [num_components, data_dim], but got {init_stds.shape}"
+            assert tf.reduce_all(init_stds > 0), "All values in init_stds should be positive."
+            # init_stds already holds one row per component, so it is used as-is; repeating it again would
+            # yield a tensor whose shape no longer matches `locs`.
+            init_stds_vals = tf.convert_to_tensor(init_stds, dtype=tf.float32)
+        else:
+            if data_dim == 3:
+                init_stds_default = [[3.0, 0.01, 0.01]]
+            else:
+                # The three-value default cannot be broadcast against `locs` for other dimensionalities
+                init_stds_default = [[1.0] * data_dim]
+            init_stds_vals = tf.repeat(tf.constant(init_stds_default, dtype=tf.float32), num_components, axis=0)
+
+        self.scales = tf.Variable(tf.math.log(init_stds_vals), name="scales")
+
+        # Initialize the weights of the GMM components
+        self.weights = tf.Variable(tf.ones([num_components]), name="weights")
+
+        # Set the prior scales and regularization strength
+        self.prior_stdevs = prior_stdevs
+        self.lambda_scale = lambda_scale
+
+    def __call__(self, data):
+        """
+        Calculate the log likelihood of the data given the current state of the GMM.
+        """
+        return self.__mixture().log_prob(data)
+
+    def fit(self, data, weights=None, num_steps=200, learning_rate=0.05, verbose=True):
+        """
+        Fit the Gaussian Mixture Model to the data.
+
+        Parameters:
+        - data: Input data of shape [n_samples, n_features].
+        - weights (optional): Weights for each sample.
+        - num_steps: Number of optimization steps.
+        - learning_rate: Learning rate for the optimizer.
+        - verbose: If True, print the loss every 100 steps.
+        """
+        if weights is None:
+            weights = tf.ones(len(data))
+        else:
+            # Bring the weights to the dtype of the parameters so the product with the log likelihood
+            # below does not fail for, e.g., integer or float64 tensors
+            weights = tf.cast(weights, self.locs.dtype)
+
+        optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
+        variables = [self.locs, self.scales, self.weights]
+
+        @tf.function
+        def train_step():
+            with tf.GradientTape() as tape:
+                log_likelihood = self(data)
+                loss = -tf.reduce_sum(log_likelihood * weights)  # Weighted negative log likelihood
+
+                # Add regularization based on prior scales if provided
+                if self.prior_stdevs is not None:
+                    scale_diff = tf.math.exp(self.scales) - self.prior_stdevs
+                    loss += self.lambda_scale * tf.reduce_sum(scale_diff * scale_diff)
+
+            gradients = tape.gradient(loss, variables)
+            optimizer.apply_gradients(zip(gradients, variables))
+            return loss
+
+        for step in range(num_steps):
+            loss = train_step()
+            if step % 100 == 0 and verbose:
+                tf.print("step:", step, "log-loss:", loss)
+
+    def __mixture(self):
+        """
+        Creates a Gaussian Mixture Model from the current parameters (weights, means, and standard deviations).
+        """
+        # Multivariate normal distribution with diagonal covariance for every component
+        components_distribution = tfd.Independent(tfd.Normal(loc=self.locs, scale=tf.math.exp(self.scales)),
+                                                  reinterpreted_batch_ndims=1)
+
+        return tfd.MixtureSameFamily(
+            mixture_distribution=tfd.Categorical(logits=tf.nn.log_softmax(self.weights)),
+            components_distribution=components_distribution)
+
+    @property
+    def variances(self):
+        """Returns the actual variances (squared scales) of the Gaussian components."""
+        return tf.math.exp(2 * self.scales)
+
+    @property
+    def stddevs(self):
+        """Returns the actual standard deviations (scales) of the Gaussian components."""
+        return tf.math.exp(self.scales)
+
+    def predict_proba(self, data):
+        """
+        Return the per-component densities of each data point, normalized to sum to one over the components.
+
+        Note: the mixture weights do not enter this normalization.
+
+        Parameters:
+        - data: Input data of shape [n_samples, n_features].
+        """
+        gmm = self.__mixture()
+
+        # Insert a component axis ([n_samples, 1, n_features]) so each point is scored under every component
+        log_probs = gmm.components_distribution.log_prob(tf.expand_dims(data, axis=1))
+
+        # Softmax over the component axis equals exp(log_probs) / sum(exp(log_probs)) but stays finite when
+        # all densities of a point underflow to zero (the plain ratio would be 0 / 0 = NaN)
+        return tf.nn.softmax(log_probs, axis=-1).numpy()
+
+    def predict(self, data):
+        """Get the cluster ids under the current mixture model"""
+        return np.argmax(self.predict_proba(data), axis=1)
+
+    def sample(self, n_samples=1):
+        """Draw n_samples points from the current mixture model."""
+        gmm = self.__mixture()
+
+        # Sample from the Gaussian Mixture Model
+        return gmm.sample(n_samples)
diff --git a/pyims/pyims/slice.py b/pyims/pyims/slice.py
index 719d911f..ec177d34 100644
--- a/pyims/pyims/slice.py
+++ b/pyims/pyims/slice.py
@@ -13,17 +13,17 @@ def __int__(self):
         self.__current_index = 0
 
     @classmethod
-    def from_py_tims_slice(cls, slice: pims.PyTimsSlice):
+    def from_py_tims_slice(cls, tims_slice: pims.PyTimsSlice):
         """Create a TimsSlice from a PyTimsSlice.
 
         Args:
-            slice (pims.PyTimsSlice): PyTimsSlice to create the TimsSlice from.
+            tims_slice (pims.PyTimsSlice): PyTimsSlice to create the TimsSlice from.
 
         Returns:
             TimsSlice: TimsSlice created from the PyTimsSlice.
""" instance = cls.__new__(cls) - instance.__slice_ptr = slice + instance.__slice_ptr = tims_slice instance.__current_index = 0 return instance @@ -57,7 +57,7 @@ def fragments(self): return TimsSlice.from_py_tims_slice(self.__slice_ptr.get_fragments_dda()) def filter(self, mz_min: float = 0.0, mz_max: float = 2000.0, scan_min: int = 0, scan_max: int = 1000, - intensity_min: float = 0.0, intensity_max: float = 1e9, num_threads: int = 4) -> 'TimsSlice': + intensity_min: float = 0.0, intensity_max: float = 1e9, num_threads: int = 4) -> 'TimsSlice': """Filter the slice by m/z, scan and intensity. Args: @@ -72,7 +72,8 @@ def filter(self, mz_min: float = 0.0, mz_max: float = 2000.0, scan_min: int = 0, Returns: TimsSlice: Filtered slice. """ - return TimsSlice.from_py_tims_slice(self.__slice_ptr.filter_ranged(mz_min, mz_max, scan_min, scan_max, intensity_min, intensity_max, num_threads)) + return TimsSlice.from_py_tims_slice(self.__slice_ptr.filter_ranged(mz_min, mz_max, scan_min, scan_max, + intensity_min, intensity_max, num_threads)) @property def frames(self) -> List[TimsFrame]: @@ -108,7 +109,8 @@ def to_windows(self, window_length: float = 10, overlapping: bool = True, min_nu Returns: List[MzSpectrum]: List of windows. 
""" - return [MzSpectrum.from_py_mz_spectrum(spec) for spec in self.__slice_ptr.to_windows(window_length, overlapping, min_num_peaks, min_intensity, num_threads)] + return [MzSpectrum.from_py_mz_spectrum(spec) for spec in self.__slice_ptr.to_windows( + window_length, overlapping, min_num_peaks, min_intensity, num_threads)] @property def df(self) -> pd.DataFrame: diff --git a/pyims/pyims/spectrum.py b/pyims/pyims/spectrum.py index 10b76995..3ded7dd4 100644 --- a/pyims/pyims/spectrum.py +++ b/pyims/pyims/spectrum.py @@ -110,9 +110,6 @@ def vectorized(self, resolution: int = 2) -> 'MzSpectrumVectorized': """ return MzSpectrumVectorized.from_py_mz_spectrum_vectorized(self.__spec_ptr.vectorized(resolution)) - def __repr__(self): - return f"MzSpectrum(num_peaks={len(self.mz)})" - class MzSpectrumVectorized: def __init__(self, indices: NDArray[np.int32], values: NDArray[np.float64], resolution: int):