An efficient segment anything model for the segmentation of medical images

Figure 1. Structure of the EMedSAM.

In this section, EMedSAM is presented in detail. Its structure is shown in Fig. 1 and consists of three main parts: a lightweight image encoder, a prompt encoder, and a mask decoder. First, medical images are preprocessed to enhance their quality for the encoding stage, as discussed in Sect. “Image preprocessing”. The processed images are then fed into a lightweight image encoder integrating adapter technology, which extracts the key features of the images; the lightweight encoder and the adapter technology are elaborated in Sects. “Lightweight image encoder DD-TinyViT” and “Fine-tuning med-adapter”, respectively. Once feature extraction is complete, the image feature embeddings are passed to the mask decoder, and the prompt encoder likewise sends prompt embeddings to the mask decoder. The mask decoder integrates these embeddings and outputs the final segmented image. The overall training strategy of the model, in particular how prompts are used to optimize training, is detailed in Sect. “Adapter training”.

Image preprocessing

Preprocessing of 2D medical imaging data

To ensure data consistency and quality, the dataset is preprocessed before subsequent analysis. The first step is to apply a Gaussian filter for noise reduction, effectively separating valuable information from background noise. The pixel values are then scaled to a standard range through normalization so that the model operates on consistent intensities. To further enhance image quality, histogram equalization is used to increase contrast. To match the model input, all images are resized to a uniform size and stored in PNG format. Considering the limited dataset size, augmentation operations such as rotation, scaling, translation, and flipping are applied to increase dataset diversity.

Figure 2. Decoupled distillation structure diagram.

Preprocessing of 3D medical imaging data

Preprocessing of the 3D medical image dataset begins with scaling voxel values to the range 0–255 through voxel normalization, ensuring uniformity for further analysis. Histogram equalization is then applied to improve the contrast and definition of the voxel data. Next, the voxels are resampled to reduce computational demands while preserving data consistency. Finally, 2D slices are extracted from the 3D volumes so that they can be analyzed with 2D methods.

Lightweight image encoder DD-TinyViT

Image encoder structure

The image encoder is an essential part of SAM. The original SAM uses ViT-H as the backbone of its image encoder, and ViT-H has up to 632M parameters. This large parameter count incurs significant computational cost during training and inference, which raises hardware requirements and extends processing time. To address this issue, DD-TinyViT follows the TinyViT framework, a lightweight image encoder with only 21M parameters, as an alternative to the original heavy encoder. Similar to the Swin Transformer, TinyViT is divided into five parts. The first part is an image embedding block consisting of two convolutions with a kernel size of 3 and a stride of 2. The second part uses MBConv bottleneck convolution blocks, and the remaining parts use Transformer blocks based on the attention mechanism. Downsampling is performed in each of the second to fifth parts.
Each layer is connected using residuals, and the activation function is GatedGeLU:$$\begin{aligned} GatedGeLU(x) = GELU(x) \cdot \sigma \left( GELU(x) \right) \end{aligned}$$
(1)
where x represents the input to the activation function, and \(\sigma\) is the sigmoid function. The expressions for \(\sigma (x)\) and GELU(x) are as follows:$$\begin{aligned} \sigma (x) = \frac{1}{1 + e^{- x}} \end{aligned}$$
(2)
$$\begin{aligned} GELU(x) = 0.5x\left( 1 + \tanh \left( \sqrt{\frac{2}{\pi }}\left( x + 0.044715x^{3} \right) \right) \right) \end{aligned}$$
(3)
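For concreteness, the following is a minimal PyTorch sketch of the activation defined in Eqs. (1)–(3). The module name GatedGeLU matches the text above, while the surrounding ResidualMLP block and its layer sizes are illustrative assumptions, not part of the released implementation.

```python
import torch
import torch.nn as nn

class GatedGeLU(nn.Module):
    """Gated GELU from Eq. (1): GELU(x) * sigmoid(GELU(x))."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        g = nn.functional.gelu(x, approximate="tanh")  # tanh approximation, Eq. (3)
        return g * torch.sigmoid(g)                    # sigmoid gating, Eq. (2)

# Illustrative residual block using the activation, as described in the text.
class ResidualMLP(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.fc1, self.fc2 = nn.Linear(dim, hidden), nn.Linear(hidden, dim)
        self.act = GatedGeLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.fc2(self.act(self.fc1(x)))     # residual connection
```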
Lightweight handling

Directly training smaller models on large datasets typically results in subpar performance. This paper therefore uses decoupled distillation to develop a lightweight image encoder, employing the pre-trained larger model (ViT-H) as the teacher and the more compact model (DD-TinyViT) as the student. The decoupled training procedure is shown in Fig. 2. Building upon Zhang’s research26, we transitioned the application domain from mobile applications to the medical field and obtained DD-TinyViT through decoupled distillation of a lightweight image encoder that outperforms MobileSAM. Although this led to a slight increase in parameters, we implemented appropriate measures to address the issue. Unlike Zhao’s FastSAM25, which replaces the ViT with YOLOv8 for speed, we retained the original SAM model’s ViT. While YOLOv8 excels in processing speed, it falls short in comprehensive image content understanding; preserving the ViT ensures that the model’s depth and accuracy of image comprehension remain uncompromised.

During the distillation phase, the teacher model first performs inference on the data and its output feature maps are preserved. To reduce training time, the SA-1B dataset is preprocessed and the feature maps output by the SAM image encoder are stored locally, eliminating the need to run the SAM image encoder during training. Only 1% of the SA-1B dataset is used for training, with 0.1% reserved for validation. The student model is then trained to minimize the discrepancy between its image embedding outputs and those of the teacher model, i.e., the image embedding output \(O_{V}\left( x_{i} \right)\), using the mean squared error (MSE) as the loss function:$$\begin{aligned} L_{MSE} = \frac{1}{N}\sum _{i = 1}^{N}\left( O_{V}\left( x_{i} \right) - O_{DD-TV}\left( x_{i} \right) \right) ^{2} \end{aligned}$$
(4)
where \(x_i\) represents the ith input image, N represents the total number of samples, \(O_{V}\left( x_{i} \right)\) represents the image embedding output from the ViT-H image encoder, and \(O_{DD-TV}\left( x_{i} \right)\) represents the image embedding output from the DD-TinyViT image encoder. To enhance the robustness of the student model and mitigate overfitting, \(L_2\) regularization is used. It incorporates into the loss function a penalty term proportional to the squared magnitudes of the weights. This penalty promotes smaller weight values, which reduces the model’s tendency to overfit the training data:$$\begin{aligned} L_{reg} = \lambda \sum _{\varpi \in W}\varpi ^{2} \end{aligned}$$
(5)
where \(\varpi\) represents an individual weight of the DD-TinyViT model, W represents the set of all weights of the DD-TinyViT model, and \(\lambda\) represents the regularization coefficient. The overall loss function is:$$\begin{aligned} L_{total} = L_{MSE} + L_{reg} \end{aligned}$$
(6)
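A minimal PyTorch sketch of the distillation objective in Eqs. (4)–(6) is shown below. The function name, arguments, and the commented training loop are illustrative assumptions; in practice the teacher embeddings \(O_{V}(x_i)\) are the ViT-H feature maps precomputed and loaded from disk, as described above.

```python
import torch

def distillation_loss(student_emb: torch.Tensor,
                      teacher_emb: torch.Tensor,
                      student_params,
                      reg_coeff: float = 1e-4) -> torch.Tensor:
    """Total loss of Eq. (6): embedding MSE (Eq. 4) plus L2 penalty (Eq. 5)."""
    l_mse = torch.mean((student_emb - teacher_emb) ** 2)                 # Eq. (4)
    l_reg = reg_coeff * sum((p ** 2).sum() for p in student_params)     # Eq. (5)
    return l_mse + l_reg                                                 # Eq. (6)

# Illustrative SGD training step (Eq. 7); `student` and `loader` are assumed objects.
# optimizer = torch.optim.SGD(student.parameters(), lr=1e-3)
# for images, teacher_emb in loader:
#     loss = distillation_loss(student(images), teacher_emb, student.parameters())
#     optimizer.zero_grad(); loss.backward(); optimizer.step()
```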
In contrast to more intricate loss functions such as the combination of Focal and Dice losses, this formulation is straightforward and easier to optimize. The stochastic gradient descent (SGD) algorithm is employed to minimize the total loss and train DD-TinyViT, with the following update rule:$$\begin{aligned} \varpi _{t + 1} = \varpi _{t} - \eta \nabla L_{total}\left( \varpi _{t} \right) \end{aligned}$$
(7)
where \(\varpi _{t + 1}\) and \(\varpi _{t}\) represent the updated and current weights of the model, respectively, \(\eta\) represents the learning rate, which determines the step size of each iteration in minimizing the loss function, and \(\nabla L_{total}\left( \varpi _{t} \right)\) represents the gradient of the total loss with respect to the weights \(\varpi\).

The training process is iterated as follows: if the total loss \(L_{total}\) exceeds a predefined threshold, training of the DD-TinyViT image encoder continues; otherwise, training ends. The threshold is set so that DD-TinyViT achieves more than 90% of the original ViT-H’s performance. DD-TinyViT is then evaluated to verify that its performance matches that of ViT-H. Through these steps, an optimized lightweight image encoder is obtained and integrated with the frozen prompt encoder and mask decoder.

Figure 3. Deployment structure diagram of med-adapter.

Fine-tuning med-adapter

Adapter structure

Because medical imaging data are scarce in the large-scale SA-1B dataset, the lightweight image encoder struggles with medical image segmentation tasks. To adapt this encoder to medical image segmentation efficiently, this paper introduces a novel adapter technology, the med-adapter. Its design objective is to inject specialized knowledge of medical imaging by embedding the med-adapter within the Transformer layers, thereby preserving the insights previously acquired from the SA-1B dataset without retraining. As shown in Fig. 3, the adapter has a bottleneck architecture: it first reduces the dimensionality to a lower level, passes through a nonlinear activation layer, and then expands back to the original dimension. A residual connection between the adapter layer’s input and output guarantees uninterrupted information flow.

Adapter deployment

The configuration of the med-adapter units can influence the overall efficacy of the model. Building on the Medical SAM Adapter design proposed by Wu13, this study conducted multiple experiments to assess the impact of placing med-adapters in different sections of the model; for details, refer to the third paragraph of the “Ablation Study of EMedSAM” section. We determined the placement of the med-adapter from these experiments, further validating Wu’s findings.

In the encoder segment, for 2D medical imaging, two adapters are placed within each Transformer block of DD-TinyViT, as shown in Fig. 3a. The first adapter is situated downstream of the multi-head attention module but before its associated residual connection. This placement allows the adapter to refine and enhance the module’s outputs without disrupting the main information flow, improving the precision and efficiency of information processing. The second adapter is positioned within the residual pathway of the MLP layer that follows the multi-head attention module. Introducing an adapter here lets the model exploit the computational power of the MLP layer to increase its nonlinear capacity, enhancing its ability to learn complex data patterns.

For 3D medical imaging, one approach uses image preprocessing to convert the volumes into 2D slices and applies the adaptation method described in Fig. 3a. Another approach utilizes the configuration depicted in Fig.
3b, where the placement of the adapters remains the same but an additional branch is introduced at the bottom. 3D medical images, owing to their higher spatial dimensionality, contain more information and richer contextual relationships, so more complex adapter structures are required to extract and process these additional dimensions effectively. In contrast, 2D medical images can be handled efficiently with relatively simpler adapter structures because of their lower dimensionality. The original attention mechanism is divided into spatial and depth branches. For 3D samples with a depth of Z, the spatial branch learns spatial correlations through M\(\times\)P interactions, where M is the number of embeddings and P is the embedding length, while the depth branch learns depth correlations through Z\(\times\)P interactions. Finally, the processed outputs of the depth branch are restored to their original form and combined with the results of the spatial branch.

In the decoder segment, three adapters are arranged in each Transformer block. The first adapter is located subsequent to the multi-head cross-attention module, which transforms prompt embeddings into image embeddings, and adds a residual connection. This positioning optimizes the interplay of attention and residual pathways, enhancing prompt integration. An additional down-projection step, applied before the ReLU activation function, condenses the prompt embeddings to ease their processing. This setup enables the adapter to adjust its parameters based on the enriched prompt information, increasing its adaptability across modalities and tasks. The decoder’s second adapter, like its encoder counterpart, adjusts the MLP-augmented embeddings, further refining the model’s ability to handle detailed and complex data patterns. The third adapter is positioned at the residual junction following the cross-attention from image embeddings to prompt embeddings. After these adapter connections, an additional residual connection and layer normalization are introduced to stabilize the output and ensure uniformity, which is crucial for high accuracy in complex segmentation tasks.

Adapter training

To achieve optimal performance, fine-tuning the adapter parameters on medical imaging datasets is essential. During training, the other parameters of EMedSAM are kept frozen and only the adapter parameters are updated. The essence of the fine-tuning strategy lies in optimizing the channel and spatial dimensions. We optimize the model by dynamically adjusting parameters across both the channel and spatial dimensions, in contrast to Wu’s research, which was confined to optimizing the spatial and depth branches when training the 3D medical adapter. For the channel dimension, the resolution of the input feature map is first reduced through global average pooling, averaging the information across the entire feature map to produce a global, more compact representation. This strategy reduces the model’s parameter count while preserving essential feature information:$$\begin{aligned} F_{pool} = \frac{1}{N}\sum _{i = 1}^{N}F_{i} \end{aligned}$$
(8)
where \(F_{pool}\) represents the result of the global average pooling operation, N represents the total number of elements in the feature map, and each element \(F_i\) corresponds to a position in the flattened feature representation. Subsequently, a linear layer is employed to refine the channel embeddings, which compresses the representation and isolates the principal features:$$\begin{aligned} F_{comp} = Linear\left( F_{pool} \right) \end{aligned}$$
(9)
where \(F_{comp}\) represents the condensed output derived from the linear transformation, and Linear represents the linear transformation operation, i.e., a fully connected layer. The layer comprises a weight matrix and a bias term and applies a linear combination to its input. An additional linear layer is then employed to restore the data to its original dimensionality:$$\begin{aligned} F_{rest} = Linear\left( F_{comp} \right) \end{aligned}$$
(10)
where \(F_{rest}\) represents the restored channel embeddings. Finally, weights for the channel dimension are obtained through the sigmoid function and multiplied with the input feature map, serving as the input for the next stage:$$\begin{aligned} W_{chan} = Sigmoid\left( Linear\left( F_{rest} \right) \right) \end{aligned}$$
(11)
$$\begin{aligned} F_{next} = W_{chan} \odot F_{pool} \end{aligned}$$
(12)
where \(W_{chan}\) represents the output obtained by applying a linear transformation to \(F_{rest}\) followed by the sigmoid activation function, \(\odot\) represents element-wise multiplication, and \(F_{next}\) is the result of the element-wise multiplication of \(F_{pool}\) by the weights \(W_{chan}\).

For the spatial dimension, a convolutional layer is employed to reduce the spatial resolution of the feature map to half of its original size, downsampling the input features and producing new feature representations. This step captures the spatial structural information inherent in the input:$$\begin{aligned} F_{downs} = Conv2D\left( F_{next} \right) \end{aligned}$$
(13)
where \(F_{downs}\) represents the result of applying a 2D convolution to \(F_{next}\). Transposed convolution then restores the spatial resolution while preserving the channel count, expanding the spatial dimensions of the input features and effecting upsampling:$$\begin{aligned} F_{upsa} = TransposeConv2D\left( F_{downs} \right) \end{aligned}$$
(14)
where \(F_{upsa}\) is the result of applying the transposed convolution operation TransposeConv2D to \(F_{downs}\). After adaptation in each layer, a skip connection is introduced to integrate the two feature representations. This bolsters information flow, retains finer details, and facilitates both the transfer of information and the propagation of gradients:$$\begin{aligned} F_{skip} = F_{next} + F_{upsa} \end{aligned}$$
(15)
where \(F_{skip}\) is obtained by element-wise addition of \(F_{next}\) and \(F_{upsa}\), producing an output aligned with the image encoder as adjusted for the current medical imaging data. (A minimal code sketch of this channel–spatial adaptation is given at the end of this section.)

This paper employs standard loss functions, including cross-entropy and Dice loss, to refine the model. To avoid overfitting and to help the adapter acquire medical domain expertise, regularization methods such as weight decay and Dropout are applied. This ensures that, during training, the model focuses on assimilating knowledge pertinent to the medical field. Through the integration of adapters and targeted fine-tuning, the lightweight SAM achieves noteworthy performance gains across a variety of medical image segmentation tasks.

Overall training strategy

Inspired by SAM training, the bounding box prompts remain unchanged. For the point prompts, our method combines random sampling with iterative sampling. In the initial phases of training, click prompts are drawn by random sampling, helping the model distinguish foreground and background regions early on. Subsequently, to boost the model’s accuracy and adaptability, an iterative sampling mechanism is incorporated: each click iteration is informed not only by the model’s current predictions but also by the errors identified in prior outputs, directing new clicks toward these discrepancies. This iterative sampling mimics real-world scenarios in which users incrementally refine model predictions through ongoing interaction, and through repeated iterations the model progressively improves its understanding of intricate image content and its segmentation fidelity. With this hybrid strategy that merges random sampling with iterative click sampling, the model grasps the fundamental structure of the image first and then continuously refines the segmentation in later training phases, improving both accuracy and the user interaction experience.
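As referenced above, the following is a minimal PyTorch sketch of the channel–spatial adaptation described by Eqs. (8)–(15). The module name MedAdapterCS, the reduction ratio, and all layer sizes are illustrative assumptions made for this sketch rather than the released implementation; following the textual description, the channel weights gate the full input feature map before the spatial branch.

```python
import torch
import torch.nn as nn

class MedAdapterCS(nn.Module):
    """Channel-then-spatial adaptation sketch following Eqs. (8)-(15)."""
    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        # Channel branch: squeeze (Eq. 9), restore (Eq. 10), gate (Eq. 11).
        self.squeeze = nn.Linear(channels, channels // reduction)
        self.restore = nn.Linear(channels // reduction, channels)
        self.gate = nn.Linear(channels, channels)
        # Spatial branch: halve the resolution (Eq. 13), then upsample back (Eq. 14).
        self.down = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)
        self.up = nn.ConvTranspose2d(channels, channels, kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, H, W); assumes even H and W for exact up/down matching.
        b, c, _, _ = x.shape
        f_pool = x.mean(dim=(2, 3))                    # Eq. (8): global average pooling
        f_comp = self.squeeze(f_pool)                  # Eq. (9): channel compression
        f_rest = self.restore(f_comp)                  # Eq. (10): channel restoration
        w_chan = torch.sigmoid(self.gate(f_rest))      # Eq. (11): channel weights
        f_next = w_chan.view(b, c, 1, 1) * x           # Eq. (12): channel re-weighting
        f_downs = self.down(f_next)                    # Eq. (13): spatial downsampling
        f_upsa = self.up(f_downs)                      # Eq. (14): transposed convolution
        return f_next + f_upsa                         # Eq. (15): skip connection
```

In use, a module of this kind would sit in the residual path of each Transformer block during fine-tuning, with all non-adapter parameters of EMedSAM frozen, as described in the training procedure above.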
