Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
jgrss committed Jan 1, 2024
1 parent 72f22c3 commit 85afd32
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 197 deletions.
113 changes: 59 additions & 54 deletions src/cultionet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,15 +290,15 @@ def fit_transfer(
dataset: EdgeDataset,
ckpt_file: T.Union[str, Path],
test_dataset: T.Optional[EdgeDataset] = None,
val_frac: T.Optional[float] = 0.2,
val_frac: float = 0.2,
spatial_partitions: T.Optional[T.Union[str, Path]] = None,
partition_name: T.Optional[str] = None,
partition_column: T.Optional[str] = None,
batch_size: T.Optional[int] = 4,
load_batch_workers: T.Optional[int] = 2,
accumulate_grad_batches: T.Optional[int] = 1,
filters: T.Optional[int] = 32,
num_classes: T.Optional[int] = 2,
batch_size: int = 4,
load_batch_workers: int = 2,
accumulate_grad_batches: int = 1,
filters: int = 32,
num_classes: int = 2,
edge_class: T.Optional[int] = None,
class_counts: T.Sequence[float] = None,
model_type: str = "ResUNet3Psi",
Expand All @@ -310,30 +310,31 @@ def fit_transfer(
deep_sup_edge: bool = False,
deep_sup_mask: bool = False,
optimizer: str = "AdamW",
learning_rate: T.Optional[float] = 1e-3,
learning_rate: float = 1e-3,
lr_scheduler: str = "CosineAnnealingLR",
steplr_step_size: T.Optional[T.Sequence[int]] = None,
scale_pos_weight: T.Optional[bool] = True,
epochs: T.Optional[int] = 30,
save_top_k: T.Optional[int] = 1,
early_stopping_patience: T.Optional[int] = 7,
early_stopping_min_delta: T.Optional[float] = 0.01,
gradient_clip_val: T.Optional[float] = 1.0,
gradient_clip_algorithm: T.Optional[float] = "norm",
reset_model: T.Optional[bool] = False,
auto_lr_find: T.Optional[bool] = False,
device: T.Optional[str] = "gpu",
devices: T.Optional[int] = 1,
scale_pos_weight: bool = True,
epochs: int = 30,
save_top_k: int = 1,
early_stopping_patience: int = 7,
early_stopping_min_delta: float = 0.01,
gradient_clip_val: float = 1.0,
gradient_clip_algorithm: float = "norm",
reset_model: bool = False,
auto_lr_find: bool = False,
device: str = "gpu",
devices: int = 1,
profiler: T.Optional[str] = None,
weight_decay: T.Optional[float] = 1e-5,
precision: T.Optional[int] = 32,
stochastic_weight_averaging: T.Optional[bool] = False,
stochastic_weight_averaging_lr: T.Optional[float] = 0.05,
stochastic_weight_averaging_start: T.Optional[float] = 0.8,
model_pruning: T.Optional[bool] = False,
save_batch_val_metrics: T.Optional[bool] = False,
skip_train: T.Optional[bool] = False,
refine_model: T.Optional[bool] = False,
weight_decay: float = 1e-5,
precision: int = 32,
stochastic_weight_averaging: bool = False,
stochastic_weight_averaging_lr: float = 0.05,
stochastic_weight_averaging_start: float = 0.8,
model_pruning: bool = False,
save_batch_val_metrics: bool = False,
skip_train: bool = False,
refine_model: bool = False,
finetune: bool = False,
):
"""Fits a transfer model.
Expand Down Expand Up @@ -386,6 +387,7 @@ def fit_transfer(
save_batch_val_metrics (Optional[bool]): Whether to save batch validation metrics to a parquet file.
save_batch_val_metrics (Optional[bool]): Whether to save batch validation metrics to a parquet file.
skip_train (Optional[bool]): Whether to skip training.
refine_model (Optional[bool]): Whether to refine and calibrate a trained model.
finetune (bool): Whether to finetune the transfer model. Otherwise, do feature extraction.
"""
# This file should already exist
pretrained_ckpt_file = Path(ckpt_file)
Expand Down Expand Up @@ -423,6 +425,7 @@ def fit_transfer(
deep_sup_mask=deep_sup_mask,
scale_pos_weight=scale_pos_weight,
edge_class=edge_class,
finetune=finetune,
)

if reset_model:
Expand Down Expand Up @@ -560,15 +563,15 @@ def fit(
dataset: EdgeDataset,
ckpt_file: T.Union[str, Path],
test_dataset: T.Optional[EdgeDataset] = None,
val_frac: T.Optional[float] = 0.2,
val_frac: float = 0.2,
spatial_partitions: T.Optional[T.Union[str, Path]] = None,
partition_name: T.Optional[str] = None,
partition_column: T.Optional[str] = None,
batch_size: T.Optional[int] = 4,
load_batch_workers: T.Optional[int] = 2,
accumulate_grad_batches: T.Optional[int] = 1,
filters: T.Optional[int] = 32,
num_classes: T.Optional[int] = 2,
batch_size: int = 4,
load_batch_workers: int = 2,
accumulate_grad_batches: int = 1,
filters: int = 32,
num_classes: int = 2,
edge_class: T.Optional[int] = None,
class_counts: T.Sequence[float] = None,
model_type: str = ModelTypes.RESUNET3PSI,
Expand All @@ -580,30 +583,31 @@ def fit(
deep_sup_edge: bool = False,
deep_sup_mask: bool = False,
optimizer: str = "AdamW",
learning_rate: T.Optional[float] = 1e-3,
learning_rate: float = 1e-3,
lr_scheduler: str = "CosineAnnealingLR",
steplr_step_size: T.Optional[T.Sequence[int]] = None,
scale_pos_weight: T.Optional[bool] = True,
epochs: T.Optional[int] = 30,
save_top_k: T.Optional[int] = 1,
early_stopping_patience: T.Optional[int] = 7,
early_stopping_min_delta: T.Optional[float] = 0.01,
gradient_clip_val: T.Optional[float] = 1.0,
gradient_clip_algorithm: T.Optional[float] = "norm",
reset_model: T.Optional[bool] = False,
auto_lr_find: T.Optional[bool] = False,
device: T.Optional[str] = "gpu",
devices: T.Optional[int] = 1,
scale_pos_weight: bool = True,
epochs: int = 30,
save_top_k: int = 1,
early_stopping_patience: int = 7,
early_stopping_min_delta: float = 0.01,
gradient_clip_val: float = 1.0,
gradient_clip_algorithm: float = "norm",
reset_model: bool = False,
auto_lr_find: bool = False,
device: str = "gpu",
devices: int = 1,
profiler: T.Optional[str] = None,
weight_decay: T.Optional[float] = 1e-5,
precision: T.Optional[int] = 32,
stochastic_weight_averaging: T.Optional[bool] = False,
stochastic_weight_averaging_lr: T.Optional[float] = 0.05,
stochastic_weight_averaging_start: T.Optional[float] = 0.8,
model_pruning: T.Optional[bool] = False,
save_batch_val_metrics: T.Optional[bool] = False,
skip_train: T.Optional[bool] = False,
refine_model: T.Optional[bool] = False,
weight_decay: float = 1e-5,
precision: int = 32,
stochastic_weight_averaging: bool = False,
stochastic_weight_averaging_lr: float = 0.05,
stochastic_weight_averaging_start: float = 0.8,
model_pruning: bool = False,
save_batch_val_metrics: bool = False,
skip_train: bool = False,
refine_model: bool = False,
finetune: bool = False,
):
"""Fits a model.
Expand Down Expand Up @@ -656,6 +660,7 @@ def fit(
save_batch_val_metrics (Optional[bool]): Whether to save batch validation metrics to a parquet file.
save_batch_val_metrics (Optional[bool]): Whether to save batch validation metrics to a parquet file.
skip_train (Optional[bool]): Whether to skip training.
refine_model (Optional[bool]): Whether to refine and calibrate a trained model.
finetune (bool): Not used. Placeholder for compatibility with transfer learning.
"""
ckpt_file = Path(ckpt_file)

Expand Down
Loading

0 comments on commit 85afd32

Please sign in to comment.