Commit
Alexey Shevtsov committed Dec 29, 2020
1 parent 6d1520e, commit 10fdb97
Showing 64 changed files with 1,955,158 additions and 1 deletion.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,87 @@ | ||
# low-resolution | ||
# Accelerating 3D Medical Image Segmentation by Adaptive Small-Scale Target Localization | ||
Code release | ||
|
||
## Table of Contents | ||
* [Requirements](#requirements) | ||
* [Repository Structure](#repository-structure) | ||
* [Experiment Reproduction](#experiment-reproduction) | ||
|
||
|
||
## Requirements | ||
- [Python](https://www.python.org) (v3.6 or later) | ||
- [Deep Pipe](https://github.com/neuro-ml/deep_pipe) (commit: [4383211ea312c098d710fbeacc05151e10a27e80](https://github.com/neuro-ml/deep_pipe/tree/4383211ea312c098d710fbeacc05151e10a27e80)) | ||
- [imageio](https://pypi.org/project/imageio/) (v 2.8.0) | ||
- [NiBabel](https://pypi.org/project/nibabel/) (v3.0.2) | ||
- [NumPy](http://numpy.org/) (v1.17.0 or later) | ||
- [OpenCV python](https://pypi.org/project/opencv-python/) (v4.2.0.32) | ||
- [Pandas](https://pandas.pydata.org/) (v1.0.1 or later) | ||
- [pdp](https://pypi.org/project/pdp/) (v 0.3.0) | ||
- [pydicom](https://pypi.org/project/pydicom/) (v 1.4.2) | ||
- [resource-manager](https://pypi.org/project/resource-manager/) (v 0.11.1) | ||
- [SciPy library](https://www.scipy.org/scipylib/index.html) (v0.19.0 or later) | ||
- [scikit-image](https://scikit-image.org) (v0.15.0 or later) | ||
- [Simple ITK](http://www.simpleitk.org/) (v1.2.4) | ||
- [torch](https://pypi.org/project/torch/) (v1.1.0 or later) | ||
- [tqdm](https://tqdm.github.io) (v4.32.0 or later) | ||
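A quick way to compare an environment against the list above is to print the installed version of each package. This is a minimal sketch, not part of the repository: the helper name `installed_versions` is hypothetical, the distribution names are assumed to match PyPI, and `importlib.metadata` needs Python 3.8 or later (newer than the 3.6 minimum stated above). Deep Pipe is pinned to a commit and is not covered by this check.

```python
from importlib import metadata  # standard library, Python 3.8+


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Distribution names assumed from the PyPI links in the list above.
REQUIRED = [
    "imageio", "nibabel", "numpy", "opencv-python", "pandas", "pdp", "pydicom",
    "resource-manager", "scipy", "scikit-image", "SimpleITK", "torch", "tqdm",
]

for name, version in installed_versions(REQUIRED).items():
    print(f"{name}: {version or 'not installed'}")
```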
|
||
## Repository Structure | ||
``` | ||
├── config | ||
│ ├── assets | ||
│ └── exp_holdout | ||
├── lowres | ||
│ ├── benchmark | ||
│ │ ├── benchmark_time.sh | ||
│ │ └── model_predict.py | ||
│ ├── dataset | ||
│ │ └── luna.py | ||
│ ├── model | ||
│ ├── path.py | ||
│ └── ... | ||
├── model | ||
├── notebook | ||
│ ├── data_preprocessing | ||
│ │ ├── LUNA16_download.ipynb | ||
│ │ └── LUNA16_preprocessing.ipynb | ||
│ └── time_performance.ipynb | ||
└── README.md | ||
``` | ||
Download and preprocessing for the LUNA16 dataset can be done via the IPython notebooks located at `notebook/data_preprocessing`. The time performance diagrams can be built in `notebook/results`. | ||
|
||
The pre-trained models' weights can be found in the `model` folder. The source code for each of them is located in the `lowres/model` folder. The hyperparameters for these models (e.g., patch_size, batch_size, etc.) are stored in `*.config` files in `config/assets/model`. | ||
|
||
All the necessary paths should be specified inside `lowres/path.py`. These are: | ||
- `luna_raw_path` -- where to download the raw LUNA16 files | ||
- `luna_data_path` -- where the preprocessed, structured files will be saved | ||
- `path_to_pretrained_model_x8` -- where the ModelX8 weights are located | ||
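
As an illustration, `lowres/path.py` might look like the sketch below. Only the three variable names come from this README; every directory is a placeholder to be replaced, and storing the paths as plain strings is an assumption.

```python
# lowres/path.py -- illustrative sketch; all locations below are placeholders
from pathlib import Path

home = Path.home()

# where to download the raw LUNA16 files (hypothetical location)
luna_raw_path = str(home / "data" / "luna16_raw")
# where the preprocessed, structured files will be saved (hypothetical location)
luna_data_path = str(home / "data" / "luna16_preprocessed")
# where the ModelX8 weights are located (default location named in this README)
path_to_pretrained_model_x8 = str(home / "low-resolution" / "model" / "model_x8.pth")
```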
|
||
Alternatively, the pretrained ModelX8 (for LUNA16) can be found in | ||
`~/low-resolution/model/model_x8.pth`. | ||
|
||
Finally, the script `lowres/benchmark/benchmark_time.sh` estimates the time a given model spends processing the | ||
chosen number of scans. By default the whole dataset is used; a different number can be specified inside `lowres/benchmark/model_predict.py`. The script takes a single argument: the desired number of threads (e.g., `8`). | ||
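
The kind of measurement described above can be sketched in a few lines. This is not the repository's benchmark code: `time_predictions` is a hypothetical helper, and `predict`, `load_x` and `ids` stand in for whatever the chosen model and dataset provide.

```python
import time


def time_predictions(predict, load_x, ids, n_scans=None):
    """Return (total, per-scan) wall-clock seconds for predicting `n_scans` scans.

    With n_scans=None every id is processed, mirroring the whole-dataset default.
    """
    chosen = list(ids) if n_scans is None else list(ids)[:n_scans]
    start = time.perf_counter()
    for i in chosen:
        predict(load_x(i))
    total = time.perf_counter() - start
    return total, total / max(len(chosen), 1)
```

In a PyTorch setting the thread count would typically be fixed before timing, for example with `torch.set_num_threads`.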
|
||
## Experiment Reproduction | ||
To run a single experiment please follow the steps below: | ||
|
||
First, the experiment structure must be created: | ||
``` | ||
python -m dpipe build_experiment --config_path "$1" --experiment_path "$2" | ||
``` | ||
|
||
where the first argument is the path to the `.config` file, e.g. | ||
`"~/low-resolution/config/exp_holdout/unet3d.config"`, | ||
and the second is the path to the folder where the experiment structure will be created, e.g. | ||
`"~/unet3d_experiment/"`. | ||
|
||
Then, to run an experiment please go to the experiment folder inside the created structure: | ||
``` | ||
cd ~/unet3d_experiment/experiment_0/ | ||
``` | ||
and call the following command to start the experiment: | ||
``` | ||
python -m dpipe run_experiment --config_path "../resources.config" | ||
``` | ||
where `resources.config` is the general `.config` file of the experiment. | ||
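
The steps above can also be driven from Python. The sketch below only assembles the two `dpipe` commands shown in this README; the helper names `experiment_commands` and `reproduce` are hypothetical. A `~` inside double quotes is not expanded by the shell, so the sketch expands it explicitly with `os.path.expanduser`.

```python
import os
import subprocess
import sys


def experiment_commands(config_path, experiment_path, fold="experiment_0"):
    """Build the two dpipe commands and the directory the second one runs in."""
    config_path = os.path.expanduser(config_path)
    experiment_path = os.path.expanduser(experiment_path)
    build = [sys.executable, "-m", "dpipe", "build_experiment",
             "--config_path", config_path, "--experiment_path", experiment_path]
    run = [sys.executable, "-m", "dpipe", "run_experiment",
           "--config_path", "../resources.config"]
    run_dir = os.path.join(experiment_path, fold)
    return build, run, run_dir


def reproduce(config_path, experiment_path):
    """Create the experiment structure, then run the experiment inside it."""
    build, run, run_dir = experiment_commands(config_path, experiment_path)
    subprocess.run(build, check=True)
    subprocess.run(run, cwd=run_dir, check=True)
```

For example, `reproduce("~/low-resolution/config/exp_holdout/unet3d.config", "~/unet3d_experiment/")` would mirror the commands above.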
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from functools import partial | ||
|
||
import numpy as np | ||
|
||
from dpipe.batch_iter import Infinite, load_by_random_id | ||
from lowres.batch_iter import extract_patch, center_choice | ||
|
||
patient_sampling_weights = n_tumors / np.sum(n_tumors) | ||
load_centers = dataset.load_tumor_centers | ||
|
||
batch_iter = Infinite( | ||
load_by_random_id(load_x, load_y, load_centers, | ||
ids=train_ids, weights=patient_sampling_weights, random_state=seed), | ||
partial(center_choice, y_patch_size=y_patch_size, nonzero_fraction=0.5, tumor_sampling=True), | ||
partial(extract_patch, x_patch_size=x_patch_size, y_patch_size=y_patch_size), | ||
batch_size=batch_size, batches_per_epoch=batches_per_epoch, buffer_size=8 | ||
) |
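
The line `patient_sampling_weights = n_tumors / np.sum(n_tumors)` makes patients with more tumors proportionally more likely to be drawn. A minimal NumPy sketch of that idea follows; `sample_patient_ids` and the toy ids are hypothetical and this is not how `load_by_random_id` is implemented.

```python
import numpy as np


def sample_patient_ids(ids, n_tumors, size, seed=0):
    """Draw `size` ids with probability proportional to each patient's tumor count."""
    n_tumors = np.asarray(n_tumors, dtype=float)
    weights = n_tumors / n_tumors.sum()  # normalise counts into probabilities
    rng = np.random.RandomState(seed)
    return weights, rng.choice(ids, size=size, p=weights)


# Toy example: patient "c" has no tumors and is therefore never drawn.
weights, drawn = sample_patient_ids(["a", "b", "c"], [1, 3, 0], size=1000)
```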
Original file line number | Diff line number | Diff line change |
---|---|---|
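@@ -0,0 +1,23 @@ | ||
from functools import partial | ||
|
||
import numpy as np  # used by np.sum below | ||
from pdp import One2One  # pdp transformer used around extract_patch below | ||
|
||
from dpipe.batch_iter import Infinite, load_by_random_id | ||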
from lowres.batch_iter import extract_patch, center_choice | ||
from lowres.cv import get_connected_components | ||
|
||
# batch iter: | ||
tumor_sampling = True | ||
patient_sampling_weights = n_tumors / np.sum(n_tumors) | ||
load_centers = dataset.load_tumor_centers | ||
|
||
load_cc = partial(dataset.load_cc, get_cc_fn=get_connected_components) | ||
|
||
batch_iter = Infinite( | ||
load_by_random_id(load_x, load_y, load_cc, load_centers, | ||
ids=train_ids, weights=patient_sampling_weights), | ||
partial(center_choice, y_patch_size=y_patch_size, nonzero_fraction=nonzero_fraction, tumor_sampling=tumor_sampling), | ||
One2One(partial(extract_patch, x_patch_size=x_patch_size, y_patch_size=y_patch_size), buffer_size=16, n_workers=4), | ||
batch_size=batch_size, batches_per_epoch=batches_per_epoch, buffer_size=16 | ||
) |
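
One plausible reading of `nonzero_fraction` is that, with that probability, a patch is centered on a tumor and otherwise on a random voxel. The sketch below illustrates only that reading; `choose_center` is hypothetical and the actual `lowres.batch_iter.center_choice` may behave differently.

```python
import numpy as np


def choose_center(shape, tumor_centers, nonzero_fraction, rng):
    """Pick a patch center: a tumor center with probability `nonzero_fraction`,
    otherwise a uniformly random voxel inside `shape`."""
    tumor_centers = np.asarray(tumor_centers)
    if len(tumor_centers) > 0 and rng.uniform() < nonzero_fraction:
        return tumor_centers[rng.randint(len(tumor_centers))]
    return np.array([rng.randint(s) for s in shape])
```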
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
from functools import partial | ||
|
||
import torch | ||
|
||
import dpipe.commands as commands | ||
from dpipe.config import if_missing, lock_dir | ||
from dpipe.io import ConsoleArguments, load_json | ||
from dpipe.medim.utils import identity | ||
from dpipe.experiment import flat | ||
from dpipe.train import train, CheckpointManager, Policy | ||
from dpipe.train.logging import TBLogger | ||
from dpipe.train.policy import Schedule | ||
from dpipe.train.validator import compute_metrics | ||
from dpipe.torch import save_model_state, load_model_state | ||
from lowres.torch.model import train_step_with_x8 | ||
from lowres.torch.functional import dice_loss_with_logits | ||
from lowres.utils import fix_seed | ||
from lowres.metric import evaluate_individual_metrics_with_prc | ||
|
||
console = ConsoleArguments() | ||
|
||
# ### 1. PATHS and IDS ### | ||
|
||
config_path = console(config_path=__file__) | ||
experiment_path = console.experiment_path | ||
|
||
log_path = 'train_logs' | ||
saved_model_path = 'model.pth' | ||
test_predictions_path = 'test_predictions' | ||
logit_predictions_path = 'logit_predictions' | ||
checkpoints_path = 'checkpoints' | ||
|
||
train_ids = load_json('train_ids.json') | ||
val_ids = load_json('val_ids.json') | ||
test_ids = load_json('test_ids.json') | ||
|
||
# ### 2. BUILD EXPERIMENT ### | ||
|
||
n_chans_in = dataset.n_chans_image | ||
n_chans_out = 1 | ||
|
||
load_x = dataset.load_image | ||
load_y = dataset.load_segm | ||
|
||
build_experiment = flat( | ||
config_path=config_path, | ||
experiment_path=experiment_path, | ||
split=split | ||
) | ||
|
||
# ### 3. TRAIN MODEL ### | ||
|
||
# 3.1. hyper parameters, lr policy, optimizer | ||
n_epochs = 100 | ||
batches_per_epoch = 100 | ||
|
||
lr_init = 1e-2 | ||
epoch2lr_dec_mul = {80: 0.1, } | ||
lr_policy = Schedule(initial=lr_init, epoch2value_multiplier=epoch2lr_dec_mul) | ||
policies = {} | ||
|
||
device = 'cuda' | ||
|
||
optimizer = torch.optim.SGD( | ||
architecture.parameters(), | ||
lr=lr_init, | ||
momentum=0.9, | ||
nesterov=True, | ||
weight_decay=0 | ||
) | ||
|
||
# 3.2 validation | ||
val_metrics = {} | ||
val_predict = predict | ||
|
||
validate_step = partial(compute_metrics, predict=val_predict, | ||
load_x=load_x, load_y=load_y, ids=val_ids, metrics=val_metrics) | ||
|
||
# 3.3 train | ||
logger = TBLogger(log_path=log_path) | ||
criterion = dice_loss_with_logits | ||
train_kwargs = dict(lr=lr_policy, architecture=architecture, optimizer=optimizer, criterion=criterion) | ||
|
||
checkpoint_manager = CheckpointManager(checkpoints_path, { | ||
**{k: v for k, v in train_kwargs.items() if isinstance(v, Policy)}, | ||
'model.pth': architecture, 'optimizer.pth': optimizer | ||
}) | ||
|
||
scale_factor = None | ||
train_step = partial(train_step_with_x8, scale_factor=scale_factor) | ||
|
||
train_model = train( | ||
train_step=train_step, | ||
batch_iter=batch_iter, | ||
n_epochs=n_epochs, | ||
logger=logger, | ||
checkpoint_manager=checkpoint_manager, | ||
validate=validate_step, | ||
**train_kwargs | ||
) | ||
|
||
# ### 5. RUN EXPERIMENT ### | ||
|
||
predict_to_dir = partial(commands.predict, ids=test_ids, load_x=load_x, predict_fn=predict) | ||
predict_logits_to_dir = partial(commands.predict, ids=test_ids, load_x=load_x, predict_fn=predict_logit) | ||
|
||
command_evaluate_individual_metrics = partial( | ||
evaluate_individual_metrics_with_prc, | ||
load_y_true=identity, | ||
metrics=final_metrics, | ||
predictions_path=test_predictions_path, | ||
logits_path=logit_predictions_path, | ||
) | ||
|
||
seed = 0xBadCafe | ||
|
||
# resource-manager execute sequence below: | ||
# ########################################## | ||
run_experiment = ( | ||
fix_seed(seed=seed), | ||
lock_dir(), | ||
architecture.to(device), | ||
if_missing(lambda p: [train_model, save_model_state(architecture, p)], saved_model_path), | ||
load_model_state(architecture, saved_model_path), | ||
if_missing(predict_logits_to_dir, output_path=logit_predictions_path), | ||
if_missing(predict_to_dir, output_path=test_predictions_path), | ||
if_missing(command_evaluate_individual_metrics, results_path='test_metrics'), | ||
) | ||
# ########################################## | ||
|
||
|
||
# TO INITIALIZE: | ||
# dataset -> config.assets.dataset.(wmh, met, ) | ||
# split -> config.assets.cross_val.(cv5, loo, ) | ||
# architecture, predict -> config.assets.model.(dm39, dm39_met, unet, unet_met, vnet, ) | ||
# batch_iter -> config.assets.batch_iter.(strat, tumor_sampling, with_weights, ) | ||
# val_metrics, final_metrics -> config.assets.metric |
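
The learning-rate policy in the config above starts at `lr_init = 1e-2` and uses `epoch2lr_dec_mul = {80: 0.1}` over `n_epochs = 100`. The sketch below shows the resulting values under the assumption that each multiplier takes effect from its listed epoch onward and that epochs are counted from 0; `lr_at_epoch` is a hypothetical helper, not the `Schedule` implementation.

```python
def lr_at_epoch(epoch, initial=1e-2, epoch2value_multiplier=None):
    """Learning rate in effect at `epoch` for a step-wise multiplicative schedule."""
    multipliers = {80: 0.1} if epoch2value_multiplier is None else epoch2value_multiplier
    lr = initial
    for start, mult in sorted(multipliers.items()):
        if epoch >= start:  # assumption: the multiplier applies from this epoch on
            lr *= mult
    return lr
```

Under these assumptions the first 80 epochs train at 1e-2 and the last 20 at roughly 1e-3.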
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from sklearn.model_selection import train_test_split | ||
|
||
ids_train = list(df[(df['split'] == 'train') | (df['split'] == 'val')].index) | ||
ids_holdout = list(df[df['split'] == 'holdout'].index) | ||
|
||
val_size = 5 | ||
train_val_ids = train_test_split(ids_train, test_size=val_size, random_state=seed) | ||
split = [[train_val_ids[0], train_val_ids[1], ids_holdout]] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from sklearn.model_selection import train_test_split | ||
|
||
ids_train = list(df[df['split'] == 'train'].index) | ||
ids_holdout = list(df[df['split'] == 'val'].index) | ||
|
||
val_size = 5 | ||
train_val_ids = train_test_split(ids_train, test_size=val_size, random_state=seed) | ||
split = [[train_val_ids[0], train_val_ids[1], ids_holdout]] |
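
Both split configs above produce a single fold of the form `[train, val, test]`: a pool of ids is shuffled, `val_size = 5` ids are set aside for validation, and a fixed group serves as the test set. The sketch below reproduces that shape with the standard library only, swapping `train_test_split` for `random.shuffle`, so the exact ids chosen will differ from scikit-learn's. `holdout_split` and the toy labels are hypothetical.

```python
import random


def holdout_split(split_by_id, pool_labels, test_label, val_size=5, seed=0):
    """Return [[train_ids, val_ids, test_ids]] from a mapping id -> split label."""
    pool = sorted(i for i, s in split_by_id.items() if s in pool_labels)
    test = sorted(i for i, s in split_by_id.items() if s == test_label)
    random.Random(seed).shuffle(pool)  # seeded shuffle for reproducibility
    val, train = pool[:val_size], pool[val_size:]
    return [[train, val, test]]


# Toy labels: 20 training patients and 4 holdout patients.
labels = {f"p{i}": "train" for i in range(20)}
labels.update({f"h{i}": "holdout" for i in range(4)})
split = holdout_split(labels, {"train", "val"}, "holdout")
```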
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from lowres.dataset.luna import LUNA, get_n_tumors, scale_ct, apply_mask | ||
from dpipe.dataset.wrappers import cache_methods, apply | ||
# from lowres.path_local import luna_data_path | ||
from lowres.path import luna_data_path | ||
|
||
data_path = luna_data_path | ||
modalities = ['CT', 'lung_mask'] | ||
|
||
dataset = cache_methods( | ||
instance=apply( | ||
instance=apply_mask( | ||
dataset=LUNA( | ||
data_path=data_path, | ||
modalities=modalities | ||
), | ||
mask_modality_id=-1, | ||
mask_value=1 | ||
), | ||
load_image=scale_ct | ||
), | ||
methods=['load_image', 'load_segm', 'load_centers', 'load_tumor_centers', 'load_shape'] | ||
) | ||
df = dataset.df | ||
|
||
n_tumors = get_n_tumors(ids=train_ids, df=df) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from functools import partial | ||
|
||
from dpipe.medim.metrics import dice_score, aggregate_metric | ||
from lowres.utils import get_pred | ||
from lowres.metric import prc_records | ||
|
||
dice_metric = lambda x, y: dice_score(get_pred(x), get_pred(y)) | ||
|
||
val_metrics = { | ||
'dice_scores': partial(aggregate_metric, metric=dice_metric), | ||
} | ||
|
||
dice_metric_from_id = lambda i, y_pred: dice_metric(load_y(i), y_pred) | ||
prc_metric_from_id = lambda i, y_pred, y_logit: prc_records(load_y(i), y_pred, y_logit) | ||
|
||
final_metrics = { | ||
'dice': dice_metric_from_id, | ||
'prc_records': prc_metric_from_id | ||
} |
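
For reference, the Dice score of two binary masks $X$ and $Y$ is $2|X \cap Y| / (|X| + |Y|)$. The NumPy sketch below assumes that `get_pred` binarizes its input at a 0.5 threshold and that two empty masks score 1; both are assumptions, and `dice` is a hypothetical stand-in for the `dice_score` call in the config.

```python
import numpy as np


def dice(x, y, threshold=0.5):
    """Dice score between two arrays after thresholding them into binary masks."""
    x = np.asarray(x) > threshold
    y = np.asarray(y) > threshold
    denom = x.sum() + y.sum()
    if denom == 0:  # both masks empty: treated as a perfect match (assumption)
        return 1.0
    return 2.0 * np.logical_and(x, y).sum() / denom
```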
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import numpy as np | ||
import torch | ||
|
||
from dpipe.medim.shape_ops import pad | ||
from dpipe.predict.shape import divisible_shape, add_extract_dims, patches_grid | ||
from dpipe.predict.functional import preprocess | ||
from dpipe.torch.model import inference_step | ||
from lowres.model.deepmedic39 import get_dm39 | ||
|
||
# parameters for batch_iter | ||
x_patch_size = [87, 87, 87] | ||
y_patch_size = [39, 39, 39] | ||
batch_size = 12 | ||
nonzero_fraction = 0.25 | ||
|
||
# MODEL | ||
architecture = get_dm39(n_chans_in=n_chans_in, n_chans_out=n_chans_out) | ||
|
||
# PREDICT | ||
patch_size = np.array([90] * 3) | ||
patch_stride = np.array([75] * 3) | ||
|
||
|
||
@add_extract_dims() | ||
@patches_grid(patch_size, patch_stride) | ||
@divisible_shape(divisor=[3] * 3, padding_values=np.min) | ||
@preprocess(pad, padding=[[24] * 2] * 3, padding_values=np.min) | ||
def predict(x): | ||
return inference_step(x, architecture=architecture, activation=torch.sigmoid) | ||
|
||
|
||
@add_extract_dims() | ||
@patches_grid(patch_size, patch_stride) | ||
@divisible_shape(divisor=[3] * 3, padding_values=np.min) | ||
@preprocess(pad, padding=[[24] * 2] * 3, padding_values=np.min) | ||
def predict_logit(x): | ||
return inference_step(x, architecture=architecture) |
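
The numbers in this config fit together arithmetically: the training input patch (87) is 48 voxels larger than the output patch (39), i.e. 24 per side, which matches the `padding=[[24] * 2] * 3` applied before inference. The sketch below checks that margin and estimates how many 90-voxel patches with stride 75 cover one axis. The counting rule is an assumption about a grid that covers the whole axis, not necessarily the exact behavior of `patches_grid`; both helper names are hypothetical.

```python
import math


def margin(x_patch_size=87, y_patch_size=39):
    """Voxels lost on each side between the network input and output patches."""
    return (x_patch_size - y_patch_size) // 2


def n_patches(length, patch_size=90, stride=75):
    """Patches needed to cover an axis of `length` voxels (assumed tiling rule)."""
    return max(math.ceil((length - patch_size) / stride), 0) + 1
```

Under this rule an axis of 240 voxels needs 3 patches, with neighboring patches overlapping by 15 voxels.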