-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ac38397
commit a4eff3b
Showing
6 changed files
with
144 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import numpy as np | ||
import cv2 | ||
|
||
def __rotate_CV(image, angel , interpolation): | ||
#in OpenCV we need to form the tranformation matrix and apply affine calculations | ||
#interpolation cv2.INTER_CUBIC (slow) & cv2.INTER_LINEAR | ||
h,w = image.shape[:2] | ||
cX,cY = (w//2,h//2) | ||
M = cv2.getRotationMatrix2D((cX,cY),angel,1) | ||
rotated = cv2.warpAffine(image,M , (w,h),flags=interpolation) | ||
return rotated | ||
|
||
def __rotate_data(angle: float, axis: int, data: np.ndarray) -> np.ndarray: | ||
unstacked = np.split(data, data.shape[axis], axis) | ||
unstacked = [np.squeeze(x) for x in unstacked] | ||
rotated = [__rotate_CV(x, angle, cv2.INTER_LINEAR) for x in unstacked] | ||
return np.stack(rotated, axis) | ||
|
||
def random_rotate_data_and_label(data: np.ndarray, label: np.ndarray) -> tuple[np.ndarray, np.ndarray]: | ||
for i in range(len(data.shape)): | ||
if np.random.rand() > 0.5: | ||
data = np.flip(data, i) | ||
label = np.flip(label, i) | ||
return data, label | ||
|
||
def random_rotate_data(angle1: float, angle2: float, data: np.ndarray) -> np.ndarray: | ||
angle = np.random.uniform(angle1, angle2) | ||
axis = np.random.randint(0, data.ndim) | ||
|
||
return __rotate_data(angle, axis, data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import tensorflow as tf | ||
import os | ||
|
||
from .att_unet import res_block, get_gating_base | ||
|
||
def craft_network(config: dict): | ||
checkpoint_file = config.MODEL_CHECKPOINT | ||
apply_batchnorm = config.BATCH_NORM | ||
|
||
filters = [16, 32, 64, 128, 256] | ||
|
||
inputs = tf.keras.layers.Input(shape = [config.IMAGE_DIMENSION_X, config.IMAGE_DIMENSION_Y, config.IMAGE_DIMENSION_Z, 1]) | ||
|
||
x = inputs | ||
generator_steps_output = [] | ||
for idx, _filter in enumerate(filters): | ||
x = res_block(_filter, x.shape, config)(x) | ||
generator_steps_output.append(x) | ||
if idx < len(filters) - 1: | ||
x = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 2), padding = "same")(x) | ||
|
||
x = tf.keras.layers.Flatten()(x) | ||
x = tf.keras.layers.Dense(1, activation="sigmoid")(x) | ||
|
||
model = tf.keras.models.Model(inputs = inputs, outputs = x) | ||
|
||
if checkpoint_file and os.path.exists(checkpoint_file): | ||
print("Loading weights from checkpoint ", checkpoint_file) | ||
model(tf.ones(shape=(1, config.IMAGE_DIMENSION_X, config.IMAGE_DIMENSION_Y, config.IMAGE_DIMENSION_Z, 1))) | ||
model.load_weights(checkpoint_file) | ||
else: | ||
print("Checkpoint file {} not found".format(checkpoint_file)) | ||
|
||
return model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from . import att_unet_dsv | ||
from . import classification | ||
|
||
def model_factory(config): | ||
""" | ||
Factory function to create a model based on the configuration | ||
""" | ||
if config.TASK_TYPE == "segmentation": | ||
return att_unet_dsv.craft_network(config) | ||
elif config.TASK_TYPE == "classification": | ||
return classification.craft_network(config) | ||
else: | ||
raise ValueError("Model type {} not supported".format(config.MODEL_TYPE)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import os | ||
import sys | ||
|
||
def setup(): | ||
project_folder = os.path.dirname(os.path.dirname(__file__)) | ||
src_folder = os.path.join(project_folder, "src") | ||
sys.path.append(src_folder) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import tensorflow as tf | ||
import unittest | ||
|
||
import context | ||
context.setup() | ||
from pancreas_ai.tools import resize_3d | ||
|
||
class Resize_3d(unittest.TestCase): | ||
''' | ||
Unit test to resize_3d function | ||
''' | ||
def test_resize_to_same_size(self): | ||
arr1 = tf.range(0, 27, dtype = tf.int32) | ||
arr2 = tf.reshape(arr1, [3, 3, 3]) | ||
|
||
arr3 = resize_3d.resize_3d_image(arr2, tf.constant([3, 3, 3])) | ||
result = tf.logical_not(tf.cast(tf.reshape(arr2-arr3, [-1]), dtype=tf.bool)).numpy().all() | ||
|
||
self.assertEqual(result, True, msg = "resize to the same size failed") | ||
|
||
def test_resize_twice_and_back(self): | ||
arr1 = tf.range(0, 27, dtype = tf.int32) | ||
arr2 = tf.reshape(arr1, [3, 3, 3]) | ||
|
||
arr3 = resize_3d.resize_3d_image(arr2, tf.constant([6, 6, 6])) | ||
arr4 = resize_3d.resize_3d_image(arr3, tf.constant([3, 3, 3])) | ||
|
||
result = tf.logical_not(tf.cast(tf.reshape(arr2-arr4, [-1]), dtype=tf.bool)).numpy().all() | ||
|
||
self.assertEqual(result, True, msg = "resize twice and back failed") | ||
|
||
def test_resize_triple_and_back(self): | ||
arr1 = tf.range(0, 27, dtype = tf.int32) | ||
arr2 = tf.reshape(arr1, [3, 3, 3]) | ||
|
||
arr3 = resize_3d.resize_3d_image(arr2, tf.constant([9, 9, 9])) | ||
arr4 = resize_3d.resize_3d_image(arr3, tf.constant([3, 3, 3])) | ||
|
||
result = tf.logical_not(tf.cast(tf.reshape(arr2-arr4, [-1]), dtype=tf.bool)).numpy().all() | ||
|
||
self.assertEqual(result, True, msg = "resize triple and back failed") | ||
|
||
def test_resize_triple_and_back_xl(self): | ||
cube_size = 64 | ||
arr2 = tf.random.uniform([cube_size, cube_size, cube_size], dtype = tf.float32) | ||
|
||
arr3 = resize_3d.resize_3d_image(arr2, tf.constant([cube_size*3, cube_size*3, cube_size*3])) | ||
arr4 = resize_3d.resize_3d_image(arr3, tf.constant([cube_size, cube_size, cube_size])) | ||
|
||
result = tf.logical_not(tf.cast(tf.reshape(arr2-arr4, [-1]), dtype=tf.bool)).numpy().all() | ||
|
||
self.assertEqual(result, True, msg = "resize triple and back failed") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() | ||
|