🚗 | UNet implementation using PyTorch | CARVANA Dataset | Car Segmentation

yakhyo/unet-pytorch

UNet Implementation using PyTorch | Car Segmentation

Example results (images not shown): input image, predicted mask, image and mask overlay.

Getting Started

git clone git@github.com:yakhyo/unet-pytorch.git
cd unet-pytorch

  • Dice loss and cross-entropy loss are used for training (see the dice loss implementation). For evaluation, dice_score = 1 - dice_loss is used.
  • Model weights are provided in the weights folder, saved in FP16 (~60 MB).
  • A demo Jupyter notebook is included.
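
The relation dice_score = 1 - dice_loss can be illustrated with a small sketch. This is not the repository's implementation (which operates on PyTorch tensors and may differ in smoothing and reduction); it is a plain-Python version on flat lists, and the function names and the `eps` smoothing term are assumptions made for illustration.

```python
def dice_score(pred, target, eps=1e-6):
    """Dice coefficient for two flat sequences of 0/1 labels or probabilities.

    eps is an assumed smoothing term that avoids division by zero
    when both sequences are empty of foreground.
    """
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    return (2.0 * intersection + eps) / (total + eps)


def dice_loss(pred, target, eps=1e-6):
    """Loss form used for training: lower is better, 0 means perfect overlap."""
    return 1.0 - dice_score(pred, target, eps)


# One of two predicted foreground pixels overlaps the single target pixel.
pred = [1, 1, 0, 0]
target = [1, 0, 0, 0]
print(f"dice score: {dice_score(pred, target):.4f}")
print(f"dice loss:  {dice_loss(pred, target):.4f}")
```

A score near 1 means the predicted mask overlaps the ground truth almost exactly, which is why the score rather than the loss is the more readable number for evaluation.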

To Do:

  • [] ONNX export.
  • [] ONNX inference.

Dataset

The Carvana Image Masking (PNG) dataset is used to train the model. After downloading the data, place it under the ./data directory.

├── data
    ├── train_images
         ├── xxx.jpg
         ├── xxy.jpg
         ├── xxz.jpg
          ....
    ├── train_masks
         ├── xxx.png
         ├── xxy.png
         ├── xxz.png
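
Before training, it can help to check that every image has a matching mask. The sketch below assumes, as the layout above suggests, that an image and its mask share the same file stem (xxx.jpg and xxx.png). The helper name `pair_images_and_masks` is hypothetical and not part of the repository.

```python
from pathlib import Path


def pair_images_and_masks(data_dir="./data"):
    """Return sorted (image_path, mask_path) pairs matched by file stem.

    Assumes images live in train_images/*.jpg and masks in train_masks/*.png.
    Raises FileNotFoundError if an image has no mask with the same stem.
    """
    root = Path(data_dir)
    masks = {p.stem: p for p in (root / "train_masks").glob("*.png")}
    pairs = []
    for image in sorted((root / "train_images").glob("*.jpg")):
        mask = masks.get(image.stem)
        if mask is None:
            raise FileNotFoundError(f"No mask found for {image.name}")
        pairs.append((image, mask))
    return pairs
```

Calling `pair_images_and_masks("./data")` after downloading the dataset gives a quick sanity check on the directory layout before starting a long training run.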

Training

Training arguments

usage: train.py [-h] [--data DATA] [--scale SCALE] [--num-classes NUM_CLASSES] [--weights WEIGHTS] [--epochs EPOCHS] [--batch-size BATCH_SIZE] [--num-workers N] [--lr LR] [--weight-decay WEIGHT_DECAY] [--momentum MOMENTUM] [--amp] [--print-freq PRINT_FREQ]
                [--resume RESUME] [--use-deterministic-algorithms] [--save-dir SAVE_DIR]

UNet training arguments

options:
  -h, --help            show this help message and exit
  --data DATA           Directory containing the dataset (default: './data')
  --scale SCALE         Scale factor for input image size (default: 0.5)
  --num-classes NUM_CLASSES
                        Number of output classes (default: 2)
  --weights WEIGHTS     Path to pretrained model weights (default: '')
  --epochs EPOCHS       Number of training epochs (default: 10)
  --batch-size BATCH_SIZE
                        Batch size for training (default: 4)
  --num-workers N       Number of data loading workers (default: 8)
  --lr LR               Learning rate (default: 1e-5)
  --weight-decay WEIGHT_DECAY
                        Weight decay (default: 1e-8)
  --momentum MOMENTUM   Momentum (default: 0.9)
  --amp                 Enable mixed precision training
  --print-freq PRINT_FREQ
                        Frequency of printing training progress (default: 10)
  --resume RESUME       Path to checkpoint to resume training from (default: '')
  --use-deterministic-algorithms
                        Forces the use of deterministic algorithms only.
  --save-dir SAVE_DIR   Directory to save model weights (default: 'weights')
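
For readers who want to reuse these options in their own scripts, the help text above can be approximated with `argparse`. This is a reconstruction from the help output, not the actual contents of train.py: the flag names and defaults come from the listing, while the argument types and the `build_parser` name are assumptions.

```python
import argparse


def build_parser():
    """Hypothetical reconstruction of the train.py parser from its help text."""
    parser = argparse.ArgumentParser(prog="train.py", description="UNet training arguments")
    parser.add_argument("--data", type=str, default="./data", help="Directory containing the dataset")
    parser.add_argument("--scale", type=float, default=0.5, help="Scale factor for input image size")
    parser.add_argument("--num-classes", type=int, default=2, help="Number of output classes")
    parser.add_argument("--weights", type=str, default="", help="Path to pretrained model weights")
    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=4, help="Batch size for training")
    parser.add_argument("--num-workers", type=int, default=8, metavar="N", help="Number of data loading workers")
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate")
    parser.add_argument("--weight-decay", type=float, default=1e-8, help="Weight decay")
    parser.add_argument("--momentum", type=float, default=0.9, help="Momentum")
    parser.add_argument("--amp", action="store_true", help="Enable mixed precision training")
    parser.add_argument("--print-freq", type=int, default=10, help="Frequency of printing training progress")
    parser.add_argument("--resume", type=str, default="", help="Path to checkpoint to resume training from")
    parser.add_argument("--use-deterministic-algorithms", action="store_true",
                        help="Forces the use of deterministic algorithms only.")
    parser.add_argument("--save-dir", type=str, default="weights", help="Directory to save model weights")
    return parser


# Override a few defaults the same way the command line would.
args = build_parser().parse_args(["--epochs", "20", "--batch-size", "8", "--amp"])
print(args.epochs, args.batch_size, args.amp, args.lr)
```

Flags that are not passed keep the defaults listed above, so running train.py with no arguments trains for 10 epochs with a batch size of 4.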

Train the model

python train.py

Inference

Inference arguments

usage: inference.py [-h] [--model-path MODEL_PATH] [--image-path IMAGE_PATH] [--scale SCALE] [--save-overlay]

Image Segmentation Inference

options:
  -h, --help            show this help message and exit
  --model-path MODEL_PATH
                        Path to the model weights
  --image-path IMAGE_PATH
                        Path to the input image
  --scale SCALE         Scale factor for resizing the image
  --save-overlay        Save the overlay image if this flag is set
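
The --scale option resizes the image before it is passed to the model. The sketch below shows the arithmetic only; it assumes the scaled size is obtained by multiplying each dimension by the factor and truncating to an integer, which may not match the repository's rounding. The helper name `scaled_size` is hypothetical, and the 1918x1280 example size is that of the original Carvana photos.

```python
def scaled_size(width, height, scale=0.5):
    """Return the (width, height) an image would have after scaling.

    Assumes truncation toward zero and at least one pixel per dimension.
    """
    if scale <= 0:
        raise ValueError("scale must be positive")
    return max(1, int(width * scale)), max(1, int(height * scale))


# A 1918x1280 Carvana photo at the training default scale of 0.5.
print(scaled_size(1918, 1280, 0.5))
```

Using the same scale at inference time as during training generally keeps objects at the size the model saw while learning.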

Run inference

python inference.py --model-path weights/last.pt --image-path assets/image.jpg