[Paper Review] Few Shot Medical Image Segmentation with Cross Attention Transformer

2023. 12. 2. 15:05

Lin, Y., Chen, Y., Cheng, K.-T., & Chen, H. (2023). Few Shot Medical Image Segmentation with Cross Attention Transformer (arXiv:2303.13867). arXiv. http://arxiv.org/abs/2303.13867

[!Synth]
Contribution::
The authors point out the scarcity of data in medical imaging and highlight the lack of large annotated medical datasets.

They also propose a way to bring existing FSL (Few-Shot Learning) into medical image segmentation by combining several modules.

The proposed model is named "CAT-Net" and is designed around the concept of a cross masked attention Transformer.

Proposed method

  1. A mask incorporated feature extraction (MIFE) sub-net
    extracts the initial query and support features as well as the query mask.
  2. A cross masked attention Transformer (CMAT) module lets the query and support features enhance each other, improving the query prediction.
  3. An iterative refinement framework applies the CMAT module repeatedly to further boost segmentation performance.
    The prototypical segmentation module is used in this step.

[!md]
Author:: Lin, Yi

Author:: Chen, Yufan

Author:: Cheng, Kwang-Ting

Author:: Chen, Hao

Title:: Few Shot Medical Image Segmentation with Cross Attention Transformer
Year:: 2023
Citekey:: @LinEtAl2023

Tags:: Computer Science - Computer Vision and Pattern Recognition
itemType:: preprint

[!LINK]

Lin et al_2023_Few Shot Medical Image Segmentation with Cross Attention Transformer.pdf

 

[!Abstract]

abstract:: Medical image segmentation has made significant progress in recent years. Deep learning-based methods are recognized as data-hungry techniques, requiring large amounts of data with manual annotations. However, manual annotation is expensive in the field of medical image analysis, which requires domain-specific expertise. To address this challenge, few-shot learning has the potential to learn new classes from only a few examples. In this work, we propose a novel framework for few-shot medical image segmentation, termed CAT-Net, based on cross masked attention Transformer. Our proposed network mines the correlations between the support image and query image, limiting them to focus only on useful foreground information and boosting the representation capacity of both the support prototype and query features. We further design an iterative refinement framework that refines the query image segmentation iteratively and promotes the support feature in turn. We validated the proposed method on three public datasets: Abd-CT, Abd-MRI, and Card-MRI. Experimental results demonstrate the superior performance of our method compared to state-of-the-art methods and the effectiveness of each component. Code: https://github.com/hust-linyi/CAT-Net.

 


Annotations

Few Shot Medical Image Segmentation with Cross Attention Transformer
Abstract

[!Highlight]

However, manual annotation is expensive in the field of medical image analysis, which requires domain-specific expertise.

 

Comment:

Challenge in Medical Imaging

[!Highlight]

In this work, we propose a novel framework for few-shot medical image segmentation, termed CAT-Net, based on cross masked attention Transformer.

 

Comment:

CAT-Net

[!Highlight]

Our proposed network mines the correlations between the support image and query image, limiting them to focus only on useful foreground information and boosting the representation capacity of both the support prototype and query features.

 

Comment:

query image and support image

[!Highlight]

Most of the existing methods follow a fully-supervised learning paradigm, which requires a considerable amount of labeled data for training.

(1)

Comment:

shortcoming of fully supervised learning paradigm

[!Highlight]

However, the manual annotation of medical images is time-consuming and labor-intensive, limiting the application of DL in medical image segmentation.

 

Comment:

challenge of Medical imaging

[!Highlight]

Specifically for the 3D volumetric medical images (e.g., CT, MRI), the manual annotation is even more challenging which requires the annotators to go through hundreds of 2D slices for each 3D scan.

 

Comment:

specific challenge of Medical imaging

[!Highlight]

To address the challenge of manual annotation, various label-efficient techniques have been explored, such as self-supervised learning [15], semi-supervised learning [30,31], and weakly-supervised learning [11].

 

Comment:

various label-efficient technique

[!Highlight]

Despite leveraging information from unlabeled or weakly-labeled data, these techniques still require a substantial amount of training data [21,16], which may not be practical for novel classes with limited examples in the medical domain.

 

Comment:

substantial amount of training data

[!Highlight]

Considering the hundreds of organs and countless diseases in the human body, FSL brings great potential to the various medical image segmentation tasks where a new task can be easily investigated in a data-efficient manner.

(2)

Comment:

Expectation of FSL(Few Shot Learning)

[!Highlight]

Most few-shot segmentation methods follow the learning-to-learn paradigm, which aims to learn a meta-learner to predict the segmentation of query images based on the knowledge of support images and their respective segmentation labels.

(2)

Comment:

learning-to-learn paradigm, meta-learner, query images, knowledge

[!Highlight]

how to learn the meta-learner [17,26,14]; and (2) how to better transfer the knowledge from the support images to the query images [23,27,18,13,5,25].

 

Comment:

two aspects of existing few-shot segmentation methods

[!Highlight]

Despite prototype-based methods having shown success, they typically ignore the interaction between support and query features during training.

 

Comment:

prototype-based methods' limitation

[!Highlight]

In this paper, as shown in Fig. 1(a), we propose CAT-Net, a Cross Attention Transformer network for few-shot medical image segmentation, which aims to fully capture intrinsic classes details while eliminating useless pixel information and learn an interdependence between the support and query features.

 

Comment:

CAT-Net

[!Highlight]

Different from the existing FSS methods that only focus on the single direction of knowledge transfer (i.e., from the support features to the query features), the proposed CAT-Net can boost the mutual interactions between the support and query features, benefiting the segmentation performance of both the support and query images.

 

Comment:

the difference between the existing FFS methods and CAT-Net

[!Highlight]

Additionally, we propose an iterative training framework that feeds the prior query segmentation into the attention transformer to effectively enhance and refine the features as well as the segmentation.

 

Comment:

iterative training framework

2 Method
2.1 Problem Definition

[!Image]

Image

 

Comment:

figure 1

[!Highlight]

the commonly used episode training approach is employed [29].

 

Comment:

PANet: Few-shot image semantic segmentation with prototype alignment

[!Highlight]

Each training/testing episode (S_i, Q_i) instantiates an N-way K-shot segmentation learning task.

 

Comment:

instantiating

[!Highlight]

Specifically, the support set Si contains K samples of N classes, while the query set Qi contains one sample from the same class. The FSS model is trained with episodes to predict the novel class for the query image, guided by the support set. During inference, the model is evaluated directly on Dtest without any re-training. In this paper, we follow the established practice in medical FSS [7,15,20] that consider the 1-way 1-shot task.

 

Comment:

training and testing procedure in CAT-Net
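
As a concrete illustration of the episode setup above, a 1-way 1-shot episode could be sampled roughly as follows; the pool structure and helper name are assumptions for illustration, not the authors' code.

```python
import random

def sample_episode(pool, cls):
    """Sample a 1-way 1-shot episode (S_i, Q_i) for class `cls`.

    `pool[cls]` is assumed to be a list of (image, mask) pairs for that class;
    this is an illustrative sketch, not the paper's implementation.
    """
    support, query = random.sample(pool[cls], 2)  # K = 1 support sample + 1 query
    return {"support": [support], "query": query, "class": cls}

# Training iterates over such episodes on base classes; at inference the model
# is evaluated on episodes of unseen classes without re-training.
```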

2.2 Network Overview

[!Highlight]

1) a mask incorporated feature extraction (MIFE) sub-net that extracts initial query and support features as well as query mask;

 

Comment:

main component 1

[!Highlight]

2) a cross masked attention Transformer (CMAT) module in which the query and support features boost each other and thus refine the query prediction;

 

Comment:

main component 2

[!Highlight]

3) an iterative refinement framework that sequentially applies the CMAT modules to continually promote the segmentation performance.

 

Comment:

main component 3

2.3 Mask Incorporated Feature Extraction

[!Highlight]

The Mask Incorporate Feature Extraction (MIFE) sub-net takes query and support images as input and generates their respective features, integrated with the support mask.

 

Comment:

MIFE's feature

[!Highlight]

Specifically, we first employ a feature extractor network (i.e., ResNet-50) to map the query and support image pair I^q and I^s into the feature space, producing multi-level feature maps F^q and F^s for the query and support image, respectively.

(3)

Comment:

feature extractor network, image pair, multi-level feature maps

[!Highlight]

Next, the support mask is pooled with F^s and then expanded and concatenated with both F^q and F^s.

 

[!Highlight]

Finally, the query feature is processed by a simple classifier to get the query mask.

 

Comment:

simple classifier
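
A rough PyTorch sketch of the MIFE idea described above: shared ResNet-50 features, support-mask average pooling, spatial expansion and concatenation, then a light classifier. The backbone truncation point, channel sizes, and module names are assumptions, not the released code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50

class MIFESketch(nn.Module):
    def __init__(self, feat_dim=1024):
        super().__init__()
        backbone = resnet50(weights=None)
        # keep the layers up to layer3 as a shared feature extractor (assumption)
        self.encoder = nn.Sequential(*list(backbone.children())[:-3])
        self.classifier = nn.Conv2d(feat_dim * 2, 2, kernel_size=1)  # "simple classifier"

    def forward(self, img_q, img_s, mask_s):
        f_q = self.encoder(img_q)                       # query features F^q
        f_s = self.encoder(img_s)                       # support features F^s
        m = F.interpolate(mask_s, size=f_s.shape[-2:], mode="nearest")
        # masked average pooling of the support features, then spatial expansion
        proto = (f_s * m).sum(dim=(2, 3)) / (m.sum(dim=(2, 3)) + 1e-5)
        proto_map = proto[..., None, None].expand_as(f_s)
        f_q = torch.cat([f_q, proto_map], dim=1)        # concatenate with query features
        f_s = torch.cat([f_s, proto_map], dim=1)        # concatenate with support features
        mask_q = self.classifier(f_q)                   # initial query mask
        return f_q, f_s, mask_q
```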

2.4 Cross Masked Attention Transformer

[!Highlight]

1) a self-attention module for extracting global information from query and support features;

 

Comment:

main component 1

[!Highlight]

2) a cross masked attention module for transferring foreground information between query and support features while eliminating redundant background information

 

Comment:

main component 2

[!Highlight]

3) a prototypical segmentation module for generating the final prediction of the query image.

 

Comment:

main component 3

[!Highlight]

Self-Attention Module

 

[!Highlight]

the initial features are first flattened into 1D sequences and fed into two identical self-attention modules.

 

[!Image]

Image

 

Comment:

formula 1

[!Image]

Image

 

 

[!Highlight]

The output feature sequence of the self-attention alignment encoder is represented by X^q ∈ R^{HW×D} and X^s ∈ R^{HW×D} for query and support features, respectively.

(4)

Comment:

self-attention's output
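
The flattening into 1D sequences and the self-attention pass could look like the sketch below; nn.MultiheadAttention is used only as a stand-in for the paper's self-attention block, and the dimensions are assumed.

```python
import torch
import torch.nn as nn

class SelfAttentionSketch(nn.Module):
    def __init__(self, dim=256, heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, feat):                    # feat: (B, D, H, W)
        b, d, h, w = feat.shape
        x = feat.flatten(2).transpose(1, 2)     # flatten to a (B, HW, D) 1D sequence
        out, _ = self.attn(x, x, x)             # global self-attention
        return self.norm(x + out)               # X^q / X^s in R^{HW x D}
```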

[!Highlight]

Cross Masked Attention Module

(4)

[!Highlight]

Specifically, given the query feature X^q and support features X^s from the aforementioned self-attention module, we first project the input sequence into three sequences K, Q, and V using different weights, resulting in K^q, Q^q, V^q and K^s, Q^s, V^s, respectively. Taking the support features as an example, the cross attention matrix is calculated by:

(4)

Comment:

Cross Masked Attention Modules' notation

[!Image]

Image

 

[!Highlight]

We expand and flatten the binary query mask M^q to limit the foreground region in attention map.

(5)

Comment:

Mq

[!Image]

Image



Comment:

formula 4

[!Highlight]

Similar to self-attention, the support feature is processed by MLP and LN layers to get the final enhanced support features F^s_1. Similarly, the enhanced query feature F^q_1 is obtained with foreground information from the query feature.

(5)

Comment:

enhanced query features
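
A hedged sketch of the cross masked attention step: the flattened binary query mask suppresses background positions in the attention map so that only foreground information is exchanged. The exact masking formula in the paper may differ; this only illustrates the mechanism, and the projection weights are passed in as plain tensors for simplicity.

```python
import torch

def cross_masked_attention(x_s, x_q, mask_q, w_q, w_k, w_v):
    """x_s, x_q: (B, HW, D) sequences; mask_q: (B, HW) binary query mask;
    w_q, w_k, w_v: (D, D) projection weights (illustrative, not verbatim)."""
    q = x_s @ w_q                                   # Q^s from support features
    k = x_q @ w_k                                   # K^q from query features
    v = x_q @ w_v                                   # V^q from query features
    attn = torch.softmax(q @ k.transpose(1, 2) / k.shape[-1] ** 0.5, dim=-1)
    attn = attn * mask_q.unsqueeze(1)               # keep foreground query columns only
    attn = attn / (attn.sum(dim=-1, keepdim=True) + 1e-6)
    return x_s + attn @ v                           # enhanced support features F^s_1

# The symmetric call with the roles of support and query swapped yields the
# enhanced query features F^q_1.
```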

[!Highlight]

Prototypical Segmentation Module

(5)

[!Image]

Image



Comment:

formula 5

[!Highlight]

where K is the number of support images, and m^s_{(k,x,y,c)} is a binary mask that indicates whether pixel at the location (x, y) in support feature k belongs to class c.

(5)

Comment:

notation for formula 5
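
The masked average pooling above can be read as: the class-c prototype is the mean of the support-feature vectors that fall inside the class-c mask. A minimal sketch, with assumed tensor shapes:

```python
import torch

def masked_average_pooling(feat_s, mask_s, eps=1e-5):
    """feat_s: (K, D, H, W) support features; mask_s: (K, 1, H, W) binary mask
    for one class. Returns the class prototype of shape (D,)."""
    num = (feat_s * mask_s).sum(dim=(0, 2, 3))   # sum of foreground feature vectors
    den = mask_s.sum(dim=(0, 2, 3)) + eps        # number of foreground pixels
    return num / den
```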

[!Highlight]

Next, we use the non-parametric metric learning method to perform segmentation.

(5)

Comment:

non-parametirc metric

[!Image]

Image

 

Comment:

figure 6

[!Highlight]

where cos(·) denotes cosine distance, α is a scaling factor that helps gradients to back-propagate in training. In our work, α is set to 20, same as in [29].

(5)
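
The non-parametric metric-learning prediction could be sketched as follows: cosine similarity between each query-feature vector and each class prototype, scaled by α = 20 as quoted above, then a softmax over classes. Shapes and names are assumptions.

```python
import torch
import torch.nn.functional as F

def prototype_segmentation(feat_q, prototypes, alpha=20.0):
    """feat_q: (B, D, H, W) query features; prototypes: (C, D), one row per class
    (e.g., background and foreground). Returns per-pixel class probabilities (B, C, H, W)."""
    f = F.normalize(feat_q, dim=1)                 # unit-norm feature vectors
    p = F.normalize(prototypes, dim=1)             # unit-norm prototypes
    sim = torch.einsum("bdhw,cd->bchw", f, p)      # cosine similarity per pixel and class
    return torch.softmax(alpha * sim, dim=1)       # scaled, softmax over classes
```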

[!Highlight]

Specifically, we set the first threshold τ to 0.5 to obtain the binary query mask M^q, which is used to calculate the Dice loss and update the model.

(5)

Comment:

first threshold

[!Highlight]

Then, the second threshold τ̂ is set to 0.4 to obtain the dilated query mask M̂^q, which is used to generate the enhanced query feature F^q_2 in the next iteration.

(5)

Comment:

second threshold

[!Image]

Image

 

Comment:

double threshold strategy
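
A tiny sketch of the double-threshold strategy: τ = 0.5 yields the binary mask used for the Dice loss, while the looser τ̂ = 0.4 yields a dilated mask that is fed back into the next CMAT iteration (function and argument names are made up here).

```python
def double_threshold(prob_fg, tau=0.5, tau_hat=0.4):
    """prob_fg: (B, H, W) foreground probability from the prototypical module."""
    mask_q = (prob_fg > tau).float()          # M^q: used for the Dice loss / model update
    mask_q_hat = (prob_fg > tau_hat).float()  # dilated M̂^q: reused in the next iteration
    return mask_q, mask_q_hat
```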

2.5 Iterative Refinement framework

[!Highlight]

Thus, it’s natural to iteratively apply this sub-net to get the enhanced features and refine the mask, resulting in a boosted segmentation result.

(5)

Comment:

iteratively apply

[!Image]

Image

 

Comment:

figure 8

[!Image]

Image

 

Comment:

formula 9

[!Image]

Image

 

Comment:

formula 10

[!Highlight]

where CMA(·) indicates the self-attention and cross masked attention module, and Proto(·) represents the prototypical segmentation module.

(6)

Comment:

main modules
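
The iterative refinement could be wired together as below, assuming `cma` and `proto` are callables wrapping the CMA(·) and Proto(·) modules mentioned above; the interfaces are hypothetical.

```python
def iterative_refinement(f_q, f_s, mask_q, cma, proto, num_iters=3):
    """Repeatedly apply CMA(.) and Proto(.) so that the features and the query
    mask refine each other (hypothetical interfaces, not the authors' code)."""
    pred = mask_q
    for _ in range(num_iters):
        f_q, f_s = cma(f_q, f_s, mask_q)   # CMA: mutually enhanced query/support features
        pred, mask_q = proto(f_q, f_s)     # Proto: refined prediction + dilated mask
    return pred                            # final query segmentation
```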

3 Experiment
3.1 Dataset and Evaluation Metrics

[!Highlight]

We use the Dice score as the evaluation metric following

(6)

Comment:

Dice score
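
For reference, the Dice score used as the evaluation metric can be computed with a standard formulation like the one below (not taken from the authors' code).

```python
import torch

def dice_score(pred, target, eps=1e-5):
    """pred, target: binary masks of the same shape; returns the Dice coefficient."""
    pred, target = pred.float().flatten(), target.float().flatten()
    inter = (pred * target).sum()
    return (2 * inter + eps) / (pred.sum() + target.sum() + eps)
```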

[!Highlight]

To ensure a fair comparison, all the experiments are conducted under the 1-way 1-shot scenario using 5-fold cross-validation.

(6)

Comment:

cross-validation

 

[!Highlight]

We further propose a new validation setting (setting II) that takes every image in each fold as a support image alternately and the other images as the query.

(6)

Comment:

new validation setting
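
Setting II, as described above, could be enumerated roughly like this: within each fold, every image takes the support role once while all remaining images of that fold act as queries (hypothetical data layout).

```python
def setting_ii_pairs(fold_images):
    """fold_images: list of labeled images in one fold. Yields (support, query)
    pairs where each image serves as the support exactly once (illustrative sketch)."""
    for i, support in enumerate(fold_images):
        for j, query in enumerate(fold_images):
            if i != j:
                yield support, query
```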

3.2 Implementation Details
3.3 Comparison with State-of-the-Art Methods

[!Highlight]

We compare the proposed CAT-Net with state-of-the-art (SOTA) methods, including SE-Net [19], PANet [29], ALP-Net [15], AD-Net [7], and Q-Net [20].

(6)

Comment:

comparison with SOTA

[!Image]

Image




Comment:

table 1

[!Image]

Image

Comment:

figure 2

3.4 Ablation Study

[!Highlight]

Effectiveness of CMAT Block: To demonstrate the importance of our proposed CAT-Net in narrowing the information gap between the query and supporting images and obtaining enhanced features, we conducted an ablation study.

(8)

Comment:

CMAT Block

[!Image]

Image

 

Comment:

table 2

[!Image]

Image



Comment:

figure 3

[!Highlight]

Influence of Iterative Mask Refinement Block: To determine the optimal number of iterative refinement CMAT block, we experiment with different numbers of blocks.

(8)

Comment:

Iterative Mask Refinement Block

4 Conclusion

[!Highlight]

Additionally, the proposed CMAT module can be iteratively applied to continually boost the segmentation performance.

(9)

Comment:

CMAT module