In recent years, convolutional neural networks (CNNs) have driven significant progress in medical image analysis. In particular, deep neural networks based on a U-shaped architecture with skip connections (UNet) have been adopted for many medical imaging tasks, including organ segmentation. Despite their success, CNNs struggle to learn global or semantic features, especially those that require human-like reasoning to understand context. Many UNet variants have addressed this by introducing Transformer-based self-attention mechanisms, with notable gains in performance. However, Transformers suffer from redundancy in shallow layers: much of the attention computation is spent on nearby pixels that offer limited additional information. The recently introduced Super Token Attention (STA) mechanism adapts the concept of superpixels from pixel space to token space, using super tokens as compact visual representations. This approach reduces redundancy by learning efficient global representations in vision transformers, especially in shallow layers. In this work, we integrate the STA module into the UNet architecture (STA-UNet) to limit redundancy without losing rich information. Experimental results on four publicly available datasets demonstrate the superiority of STA-UNet over existing state-of-the-art architectures in terms of Dice score and IoU for organ segmentation.
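To illustrate the super token idea described above, here is a minimal NumPy sketch. It assumes a single head, no learned projections, and a dense association between every token and every super token (the original STA restricts each token to its nearby super tokens). The function name `super_token_attention` and its parameters `patch` and `n_iter` are hypothetical; this is not the STA-UNet implementation.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def super_token_attention(x, patch=4, n_iter=1):
    """Simplified super token attention on an (H, W, C) feature map.

    Hypothetical sketch: single head, no projections, dense association.
    Returns the output map (H, W, C) and the token-to-super-token
    association matrix (N, M), where N = H*W and M = (H/patch)*(W/patch).
    """
    H, W, C = x.shape
    assert H % patch == 0 and W % patch == 0, "H and W must be divisible by patch"
    assert n_iter >= 1, "at least one association step is needed"
    tokens = x.reshape(H * W, C)  # N x C

    # 1) Initialise super tokens by average pooling non-overlapping patches.
    s = x.reshape(H // patch, patch, W // patch, patch, C).mean(axis=(1, 3))
    s = s.reshape(-1, C)  # M x C

    for _ in range(n_iter):
        # 2) Soft association of each token with each super token.
        a = softmax(tokens @ s.T / np.sqrt(C), axis=1)  # N x M, rows sum to 1
        # 3) Update super tokens as association-weighted means of tokens.
        s = (a.T @ tokens) / (a.sum(axis=0)[:, None] + 1e-6)  # M x C

    # 4) Self-attention among the M super tokens only (M x M instead of N x N).
    attn = softmax(s @ s.T / np.sqrt(C), axis=1)
    s = attn @ s

    # 5) Map the refined super tokens back to the token grid via the association.
    out = a @ s  # N x C
    return out.reshape(H, W, C), a


rng = np.random.default_rng(0)
feat = rng.standard_normal((8, 8, 4))
out, assoc = super_token_attention(feat, patch=4)
print(out.shape, assoc.shape)
```

The point of the design is cost: global attention is computed over a handful of super tokens rather than over every pixel-level token, which is where the redundancy in shallow layers is avoided.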
To get a local copy up and running, follow these simple steps.
- Clone the repo
git clone https://github.com/Retinal-Research/STA-UNet.git
- Create a Python environment and install the required libraries by running:
pip install -r requirements.txt
The datasets we used are provided by the TransUNet authors. The preprocessed Synapse dataset can be accessed from here. Extract the zip file and copy the data folder to your project directory.
The pre-trained weights on the Synapse dataset can be downloaded from here. After downloading the Synapse dataset and extracting the weights file to your OUTPUT_PATH, run the following command in a terminal to perform inference on the testing set with the proposed STA-UNet:
python test.py --output_dir OUTPUT_PATH --max_epochs 150
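The abstract compares models by Dice score and IoU. As a reference for what these metrics measure, here is a minimal per-class sketch for integer label maps. The function name `dice_and_iou` is hypothetical, skipping classes absent from both prediction and ground truth is an assumed convention, and this is not necessarily how `test.py` computes or reports its metrics.

```python
import numpy as np


def dice_and_iou(pred, target, num_classes):
    """Per-class Dice and IoU for integer label maps of equal shape.

    Hypothetical helper: class 0 is treated as background and skipped, and a
    class absent from both pred and target is omitted from the result.
    """
    scores = {}
    for c in range(1, num_classes):
        p = pred == c
        t = target == c
        inter = np.logical_and(p, t).sum()
        union = np.logical_or(p, t).sum()
        total = p.sum() + t.sum()
        if total == 0:
            continue  # class absent in both maps: no score defined here
        dice = 2.0 * inter / total  # Dice = 2|P ∩ T| / (|P| + |T|)
        iou = inter / union  # IoU = |P ∩ T| / |P ∪ T|
        scores[c] = (float(dice), float(iou))
    return scores


pred = np.array([[0, 1, 1], [0, 2, 2], [0, 0, 2]])
target = np.array([[0, 1, 0], [0, 2, 2], [0, 2, 2]])
print(dice_and_iou(pred, target, num_classes=3))
```

For class 1 this gives Dice 2/3 and IoU 0.5; for class 2, Dice 6/7 and IoU 0.75. The two metrics are monotonically related per class (IoU = Dice / (2 - Dice)), so they rank a single prediction identically, but averages over classes or cases can differ.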
To train from scratch after placing the data as instructed, run the following command in your terminal from the project directory. The arguments can be customized in the beginning section of the train.py script:
python train.py --output_dir OUTPUT_PATH
SelfReg-UNet: https://github.com/ChongQingNoSubway/SelfReg-UNet