Skip to content

Commit

Permalink
Merge pull request #37 from theGreatHerrLebert/david@mixture
Browse files Browse the repository at this point in the history
David@mixture
  • Loading branch information
theGreatHerrLebert authored Oct 25, 2023
2 parents 993c0c6 + 51c5916 commit f9380c2
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 14 deletions.
12 changes: 8 additions & 4 deletions pyims/pyims/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ def filter(self,
TimsFrame: Filtered frame.
"""

return TimsFrame.from_py_tims_frame(self.__frame_ptr.filter_ranged(mz_min, mz_max, scan_min, scan_max, intensity_min, intensity_max))
return TimsFrame.from_py_tims_frame(self.__frame_ptr.filter_ranged(mz_min, mz_max, scan_min, scan_max,
intensity_min, intensity_max))

def to_resolution(self, resolution: int) -> 'TimsFrame':
"""Convert the frame to a given resolution.
Expand All @@ -179,7 +180,8 @@ def to_tims_spectra(self) -> List['TimsSpectrum']:
"""
return [TimsSpectrum.from_py_tims_spectrum(spec) for spec in self.__frame_ptr.to_tims_spectra()]

def to_windows(self, window_length: float = 10, overlapping: bool = True, min_num_peaks: int = 5, min_intensity: float = 1) -> List[MzSpectrum]:
def to_windows(self, window_length: float = 10, overlapping: bool = True, min_num_peaks: int = 5,
min_intensity: float = 1) -> List[MzSpectrum]:
"""Convert the frame to a list of windows.
Args:
Expand All @@ -191,7 +193,8 @@ def to_windows(self, window_length: float = 10, overlapping: bool = True, min_nu
Returns:
List[MzSpectrum]: List of windows.
"""
return [MzSpectrum.from_py_mz_spectrum(spec) for spec in self.__frame_ptr.to_windows(window_length, overlapping, min_num_peaks, min_intensity)]
return [MzSpectrum.from_py_mz_spectrum(spec) for spec in self.__frame_ptr.to_windows(
window_length, overlapping, min_num_peaks, min_intensity)]

def __repr__(self):
return (f"TimsFrame(frame_id={self.__frame_ptr.frame_id}, ms_type={self.__frame_ptr.ms_type}, "
Expand Down Expand Up @@ -221,4 +224,5 @@ def __init__(self, frame_id: int, ms_type: int, retention_time: float, scan: NDA
assert len(scan) == len(mobility) == len(tof) == len(indices) == len(intensity), \
"The length of the scan, mobility, tof, indices and intensity arrays must be equal."

self.__frame_ptr = pims.PyTimsFrameVectorized(frame_id, ms_type, retention_time, scan, mobility, tof, indices, intensity)
self.__frame_ptr = pims.PyTimsFrameVectorized(frame_id, ms_type, retention_time, scan, mobility, tof, indices,
intensity)
2 changes: 1 addition & 1 deletion pyims/pyims/handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, data_path: str):
self.data_path = data_path
self.bp: List[str] = obb.get_so_paths()
self.meta_data = self.__load_meta_data()
self.precursor_frames = self.meta_data[self.meta_data["MsMsType"] == 0].Id.values
self.precursor_frames = self.meta_data[self.meta_data["MsMsType"] == 0].Id.values.astype(np.int32)
self.__handle = None
self.__current_index = 1

Expand Down
159 changes: 159 additions & 0 deletions pyims/pyims/mixture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions


class GaussianMixtureModel(tf.Module):
    """Gaussian mixture model with diagonal covariances, fitted by gradient descent.

    Trainable variables:
    - locs: component means, shape [num_components, data_dim].
    - scales: natural log of the component standard deviations, shape [num_components, data_dim].
    - weights: unnormalized mixture logits, shape [num_components].
    """

    def __init__(self, num_components: int, data_dim: int, prior_stdevs=None, lambda_scale=0.01, data=None,
                 init_means=None, init_stds=None):
        """
        Initialize the Gaussian Mixture Model.
        Parameters:
        - num_components: Number of Gaussian components.
        - data_dim: Dimensionality of the data.
        - prior_stdevs (optional): Prior knowledge about cluster extensions (standard deviations).
        - lambda_scale: Regularization strength for the scales.
        - data (optional): If provided and no init_means is given, initialize the component means by randomly
          selecting rows from this data.
        - init_means (optional): Explicit initial means, shape [num_components, data_dim].
        - init_stds (optional): Explicit initial standard deviations, shape [num_components, data_dim],
          all strictly positive. If omitted, 3-dimensional data uses (3, 0.01, 0.01) per component and any
          other dimensionality falls back to unit standard deviations.
        """
        super().__init__()

        # Initialize the locations of the GMM components
        if init_means is not None:
            assert init_means.shape == (num_components, data_dim), \
                f"init_means should have shape [num_components, data_dim], but got {init_means.shape}"
            init_locs = tf.convert_to_tensor(init_means, dtype=tf.float32)
        elif data is not None:
            # NOTE(review): sampling with replacement can select the same row twice, which
            # would start two components at the same location.
            indices = np.random.choice(data.shape[0], size=num_components, replace=True)
            init_locs = tf.convert_to_tensor(data[indices], dtype=tf.float32)
        else:
            init_locs = tf.random.normal([num_components, data_dim])

        self.locs = tf.Variable(init_locs, name="locs")

        # Initialize the standard deviations of the GMM components
        if init_stds is not None:
            assert init_stds.shape == (num_components, data_dim), \
                f"init_stds should have shape [num_components, data_dim], but got {init_stds.shape}"
            assert tf.reduce_all(init_stds > 0), "All values in init_stds should be positive."
            # Already one row per component: use as-is (no tiling) so `scales` lines up with `locs`.
            init_stds_vals = tf.cast(init_stds, tf.float32)
        elif data_dim == 3:
            # Historic default, one row per component.
            init_stds_vals = tf.repeat([[3.0, 0.01, 0.01]], num_components, axis=0)
        else:
            # The historic default is only defined for 3-dimensional data.
            init_stds_vals = tf.ones([num_components, data_dim], dtype=tf.float32)

        # Stored in log space so the optimizer can move freely while exp(scales) stays positive.
        self.scales = tf.Variable(tf.math.log(init_stds_vals), name="scales")

        # Initialize the weights of the GMM components (uniform mixture logits)
        self.weights = tf.Variable(tf.ones([num_components]), name="weights")

        # Set the prior scales and regularization strength
        self.prior_stdevs = prior_stdevs
        self.lambda_scale = lambda_scale

    def __call__(self, data):
        """
        Calculate the log likelihood of the data given the current state of the GMM.
        Parameters:
        - data: Input data of shape [n_samples, n_features].
        Returns one log-probability per sample.
        """
        return self.__mixture().log_prob(tf.cast(data, tf.float32))

    def fit(self, data, weights=None, num_steps=200, learning_rate=0.05, verbose=True):
        """
        Fit the Gaussian Mixture Model to the data.
        Parameters:
        - data: Input data of shape [n_samples, n_features].
        - weights (optional): Weights for each sample; defaults to 1 for every sample.
        - num_steps: Number of optimization steps.
        - learning_rate: Learning rate for the optimizer.
        - verbose: If True, print the loss every 100 steps.
        """
        # Variables are float32; cast inputs once so float64 NumPy arrays do not clash with them.
        data = tf.cast(data, tf.float32)
        if weights is None:
            weights = tf.ones([data.shape[0]], dtype=tf.float32)
        else:
            weights = tf.cast(weights, tf.float32)

        prior_stdevs = None if self.prior_stdevs is None else tf.cast(self.prior_stdevs, tf.float32)

        optimizer = tf.optimizers.Adam(learning_rate=learning_rate)
        trainable = [self.locs, self.scales, self.weights]

        @tf.function
        def train_step():
            with tf.GradientTape() as tape:
                log_likelihood = self(data)
                loss = -tf.reduce_sum(log_likelihood * weights)  # Weighted negative log likelihood

                # Add regularization based on prior scales if provided. This must be computed
                # under the tape so that it contributes to the gradient of the scales.
                if prior_stdevs is not None:
                    scale_diff = tf.math.exp(self.scales) - prior_stdevs
                    loss += self.lambda_scale * tf.reduce_sum(scale_diff * scale_diff)

            gradients = tape.gradient(loss, trainable)
            optimizer.apply_gradients(zip(gradients, trainable))
            return loss

        for step in range(num_steps):
            loss = train_step()
            if step % 100 == 0 and verbose:
                tf.print("step:", step, "log-loss:", loss)

    def __mixture(self):
        """
        Creates a Gaussian Mixture Model from the current parameters (weights, means, and standard deviations).
        """
        components = tfd.Independent(tfd.Normal(loc=self.locs, scale=tf.math.exp(self.scales)),
                                     reinterpreted_batch_ndims=1)
        return tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=tf.nn.log_softmax(self.weights)),
            components_distribution=components)

    @property
    def variances(self):
        """Returns the actual variances (squared standard deviations) of the Gaussian components."""
        return tf.math.exp(2 * self.scales)

    @property
    def stddevs(self):
        """Returns the actual standard deviations of the Gaussian components."""
        return tf.math.exp(self.scales)

    def predict_proba(self, data):
        """
        Return, for each sample, the component densities normalized to sum to 1 across components.
        Parameters:
        - data: Input data of shape [n_samples, n_features].
        Returns a NumPy array of shape [n_samples, num_components].
        """
        gmm = self.__mixture()
        data = tf.cast(data, tf.float32)

        # [n_samples, n_features] -> [n_samples, 1, n_features] so every sample is scored
        # against every component, giving log densities of shape [n_samples, num_components].
        log_probs = gmm.components_distribution.log_prob(tf.expand_dims(data, axis=1))

        # softmax(log p) == exp(log p) / sum(exp(log p)), but it does not underflow to 0/0
        # for samples that lie far away from every component.
        # NOTE(review): the mixture weights are not part of this normalization (unchanged
        # from the previous behaviour) - confirm whether weighted responsibilities are wanted.
        return tf.nn.softmax(log_probs, axis=-1).numpy()

    def predict(self, data):
        """Get the cluster ids under the current mixture model"""
        return np.argmax(self.predict_proba(data), axis=1)

    def sample(self, n_samples=1):
        """Draw n_samples samples from the current mixture model."""
        return self.__mixture().sample(n_samples)
14 changes: 8 additions & 6 deletions pyims/pyims/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@ def __int__(self):
self.__current_index = 0

@classmethod
def from_py_tims_slice(cls, slice: pims.PyTimsSlice):
def from_py_tims_slice(cls, tims_slice: pims.PyTimsSlice):
"""Create a TimsSlice from a PyTimsSlice.
Args:
slice (pims.PyTimsSlice): PyTimsSlice to create the TimsSlice from.
tims_slice (pims.PyTimsSlice): PyTimsSlice to create the TimsSlice from.
Returns:
TimsSlice: TimsSlice created from the PyTimsSlice.
"""
instance = cls.__new__(cls)
instance.__slice_ptr = slice
instance.__slice_ptr = tims_slice
instance.__current_index = 0
return instance

Expand Down Expand Up @@ -57,7 +57,7 @@ def fragments(self):
return TimsSlice.from_py_tims_slice(self.__slice_ptr.get_fragments_dda())

def filter(self, mz_min: float = 0.0, mz_max: float = 2000.0, scan_min: int = 0, scan_max: int = 1000,
intensity_min: float = 0.0, intensity_max: float = 1e9, num_threads: int = 4) -> 'TimsSlice':
intensity_min: float = 0.0, intensity_max: float = 1e9, num_threads: int = 4) -> 'TimsSlice':
"""Filter the slice by m/z, scan and intensity.
Args:
Expand All @@ -72,7 +72,8 @@ def filter(self, mz_min: float = 0.0, mz_max: float = 2000.0, scan_min: int = 0,
Returns:
TimsSlice: Filtered slice.
"""
return TimsSlice.from_py_tims_slice(self.__slice_ptr.filter_ranged(mz_min, mz_max, scan_min, scan_max, intensity_min, intensity_max, num_threads))
return TimsSlice.from_py_tims_slice(self.__slice_ptr.filter_ranged(mz_min, mz_max, scan_min, scan_max,
intensity_min, intensity_max, num_threads))

@property
def frames(self) -> List[TimsFrame]:
Expand Down Expand Up @@ -108,7 +109,8 @@ def to_windows(self, window_length: float = 10, overlapping: bool = True, min_nu
Returns:
List[MzSpectrum]: List of windows.
"""
return [MzSpectrum.from_py_mz_spectrum(spec) for spec in self.__slice_ptr.to_windows(window_length, overlapping, min_num_peaks, min_intensity, num_threads)]
return [MzSpectrum.from_py_mz_spectrum(spec) for spec in self.__slice_ptr.to_windows(
window_length, overlapping, min_num_peaks, min_intensity, num_threads)]

@property
def df(self) -> pd.DataFrame:
Expand Down
3 changes: 0 additions & 3 deletions pyims/pyims/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,6 @@ def vectorized(self, resolution: int = 2) -> 'MzSpectrumVectorized':
"""
return MzSpectrumVectorized.from_py_mz_spectrum_vectorized(self.__spec_ptr.vectorized(resolution))

def __repr__(self):
return f"MzSpectrum(num_peaks={len(self.mz)})"


class MzSpectrumVectorized:
def __init__(self, indices: NDArray[np.int32], values: NDArray[np.float64], resolution: int):
Expand Down

0 comments on commit f9380c2

Please sign in to comment.