Skip to content

Commit

Permalink
adding rotation
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanKuchin committed Nov 6, 2024
1 parent ac38397 commit a4eff3b
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/pancreas_ai/dataset/ds_augmentation/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ def random_crop(self, data:npt.NDArray[np.float32], mask:npt.NDArray[np.float32]

def random_flip(self, data:npt.NDArray[np.float32], mask:npt.NDArray[np.float32]) -> tuple[npt.NDArray[np.int32], npt.NDArray[np.int32]]:
return flip.random_flip_data(data), mask

def random_rotate(self, data:npt.NDArray[np.float32], mask:npt.NDArray[np.float32]) -> tuple[npt.NDArray[np.int32], npt.NDArray[np.int32]]:
return rotate.random_rotate_data(data), mask
30 changes: 30 additions & 0 deletions src/pancreas_ai/dataset/ds_augmentation/rotate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
import cv2

def __rotate_CV(image, angel , interpolation):
#in OpenCV we need to form the tranformation matrix and apply affine calculations
#interpolation cv2.INTER_CUBIC (slow) & cv2.INTER_LINEAR
h,w = image.shape[:2]
cX,cY = (w//2,h//2)
M = cv2.getRotationMatrix2D((cX,cY),angel,1)
rotated = cv2.warpAffine(image,M , (w,h),flags=interpolation)
return rotated

def __rotate_data(angle: float, axis: int, data: np.ndarray) -> np.ndarray:
unstacked = np.split(data, data.shape[axis], axis)
unstacked = [np.squeeze(x) for x in unstacked]
rotated = [__rotate_CV(x, angle, cv2.INTER_LINEAR) for x in unstacked]
return np.stack(rotated, axis)

def random_rotate_data_and_label(data: np.ndarray, label: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
for i in range(len(data.shape)):
if np.random.rand() > 0.5:
data = np.flip(data, i)
label = np.flip(label, i)
return data, label

def random_rotate_data(angle1: float, angle2: float, data: np.ndarray) -> np.ndarray:
angle = np.random.uniform(angle1, angle2)
axis = np.random.randint(0, data.ndim)

return __rotate_data(angle, axis, data)
34 changes: 34 additions & 0 deletions src/pancreas_ai/tools/craft_network/classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import tensorflow as tf
import os

from .att_unet import res_block, get_gating_base

def craft_network(config: dict):
checkpoint_file = config.MODEL_CHECKPOINT
apply_batchnorm = config.BATCH_NORM

filters = [16, 32, 64, 128, 256]

inputs = tf.keras.layers.Input(shape = [config.IMAGE_DIMENSION_X, config.IMAGE_DIMENSION_Y, config.IMAGE_DIMENSION_Z, 1])

x = inputs
generator_steps_output = []
for idx, _filter in enumerate(filters):
x = res_block(_filter, x.shape, config)(x)
generator_steps_output.append(x)
if idx < len(filters) - 1:
x = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 2), padding = "same")(x)

x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(1, activation="sigmoid")(x)

model = tf.keras.models.Model(inputs = inputs, outputs = x)

if checkpoint_file and os.path.exists(checkpoint_file):
print("Loading weights from checkpoint ", checkpoint_file)
model(tf.ones(shape=(1, config.IMAGE_DIMENSION_X, config.IMAGE_DIMENSION_Y, config.IMAGE_DIMENSION_Z, 1)))
model.load_weights(checkpoint_file)
else:
print("Checkpoint file {} not found".format(checkpoint_file))

return model
13 changes: 13 additions & 0 deletions src/pancreas_ai/tools/craft_network/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from . import att_unet_dsv
from . import classification

def model_factory(config):
"""
Factory function to create a model based on the configuration
"""
if config.TASK_TYPE == "segmentation":
return att_unet_dsv.craft_network(config)
elif config.TASK_TYPE == "classification":
return classification.craft_network(config)
else:
raise ValueError("Model type {} not supported".format(config.MODEL_TYPE))
7 changes: 7 additions & 0 deletions tests/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import os
import sys

def setup():
project_folder = os.path.dirname(os.path.dirname(__file__))
src_folder = os.path.join(project_folder, "src")
sys.path.append(src_folder)
57 changes: 57 additions & 0 deletions tests/test_3d_rotate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import tensorflow as tf
import unittest

import context
context.setup()
from pancreas_ai.tools import resize_3d

class Resize_3d(unittest.TestCase):
'''
Unit test to resize_3d function
'''
def test_resize_to_same_size(self):
arr1 = tf.range(0, 27, dtype = tf.int32)
arr2 = tf.reshape(arr1, [3, 3, 3])

arr3 = resize_3d.resize_3d_image(arr2, tf.constant([3, 3, 3]))
result = tf.logical_not(tf.cast(tf.reshape(arr2-arr3, [-1]), dtype=tf.bool)).numpy().all()

self.assertEqual(result, True, msg = "resize to the same size failed")

def test_resize_twice_and_back(self):
arr1 = tf.range(0, 27, dtype = tf.int32)
arr2 = tf.reshape(arr1, [3, 3, 3])

arr3 = resize_3d.resize_3d_image(arr2, tf.constant([6, 6, 6]))
arr4 = resize_3d.resize_3d_image(arr3, tf.constant([3, 3, 3]))

result = tf.logical_not(tf.cast(tf.reshape(arr2-arr4, [-1]), dtype=tf.bool)).numpy().all()

self.assertEqual(result, True, msg = "resize twice and back failed")

def test_resize_triple_and_back(self):
arr1 = tf.range(0, 27, dtype = tf.int32)
arr2 = tf.reshape(arr1, [3, 3, 3])

arr3 = resize_3d.resize_3d_image(arr2, tf.constant([9, 9, 9]))
arr4 = resize_3d.resize_3d_image(arr3, tf.constant([3, 3, 3]))

result = tf.logical_not(tf.cast(tf.reshape(arr2-arr4, [-1]), dtype=tf.bool)).numpy().all()

self.assertEqual(result, True, msg = "resize triple and back failed")

def test_resize_triple_and_back_xl(self):
cube_size = 64
arr2 = tf.random.uniform([cube_size, cube_size, cube_size], dtype = tf.float32)

arr3 = resize_3d.resize_3d_image(arr2, tf.constant([cube_size*3, cube_size*3, cube_size*3]))
arr4 = resize_3d.resize_3d_image(arr3, tf.constant([cube_size, cube_size, cube_size]))

result = tf.logical_not(tf.cast(tf.reshape(arr2-arr4, [-1]), dtype=tf.bool)).numpy().all()

self.assertEqual(result, True, msg = "resize triple and back failed")


if __name__ == "__main__":
unittest.main()

0 comments on commit a4eff3b

Please sign in to comment.