Enhancing representation in radiography-reports foundation model: a granular alignment algorithm using masked contrastive learning

The high cost of annotation has long been a persistent challenge in the medical field. One prevalent way to reduce the reliance of downstream tasks on annotations is to use pre-trained models. With the rapid advances in natural language processing in recent years, there has been growing interest in integrating the expert knowledge contained in clinical reports with medical images. The following sections introduce relevant studies in the medical domain, specifically self-supervised pretext task-based and contrastive learning models. These studies form the foundation for our proposed MaCo. We declare that the proposed methods comply with all relevant ethical regulations and have been approved for research by the Shenzhen Institute of Advanced Technology.
Pretext task-based methods
The goal of pretext task-based methods is to learn semantically meaningful image representations without using any downstream task annotations31,32. These pretext tasks typically rely on self-supervised learning techniques, such as learning from randomly augmented images or reconstructing high-resolution images from down-sampled inputs. One widely used pretext task-based method is MAE14, which applies random masking to the image patches of the input and then employs a reconstruction decoder to recover the masked regions. Through this reconstruction process, MAE learns image features that can subsequently be used for various downstream tasks. Owing to its simplicity and effectiveness, MAE has gained considerable popularity, including in the domain of medical image-text modeling. Drawing inspiration from MAE, Zhou et al.19 employed a similar masking mechanism in both the text branch and the image branch of their model (MRM). They used the visual representation as a supplementary component of the text branch and enhanced the feature representations through back-propagation optimization.
Similar to MRM, Chen et al.33 also applied masking to both the image and text modalities, using a single transformer to integrate and couple the image and text features (M3AE). Although these methods have shown promising performance in downstream fine-tuning tasks, their zero-shot capabilities are constrained by the modality coupling strategy they adopt. This limitation impedes their ability to generalize to unseen tasks, especially on unlabeled datasets.
Contrastive learning-based methods
Contrastive learning-based methods, on the other hand, have recently attracted significant attention because of their zero-shot capabilities21,34. Contrastive learning aims to maximize the similarity between paired data points within a training batch while minimizing the similarity between unpaired data points. Trained in this way, models become proficient at distinguishing paired from unpaired images and texts, and thereby acquire the ability to generalize to unseen data samples, known as zero-shot capability35.
Zhang et al.35 pioneered the use of contrastive learning as a proxy task in medicine and demonstrated its efficacy in the medical domain. Building on this foundation, Wang et al.36 investigated the impact of false negative samples on the performance of contrastive learning methods. Boecking et al.23 recognized the distinct language patterns of medical reports and accordingly redesigned the language model for medical vision-language processing. Bannur et al.37 and Zhou et al.12 used prior radiology images and multi-view images, respectively, for joint training. More recently, Wu et al.6 and Zhang et al.16 integrated a report filter to extract medical entities and employed a more complex modal fusion module to aggregate features, thereby achieving improved results.
On the other hand, to establish fine-grained correspondence between images and reports, Li et al.38 aligned visual and textual semantics at different levels with explicit constraints. Huang et al.22 proposed a local fine-grained weighting mechanism that computes the similarity between each word and the image patches, yielding word-level responses. Similarly, Wang et al.39 introduced multi-granularity alignment to explicitly learn the correspondence between fine-grained vision and text tokens.
These contrastive learning-based methods have achieved downstream fine-tuning performance comparable to that of pretext task-based methods. More importantly, some of them, such as BioViL and GLoRIA, have demonstrated promising zero-shot capabilities, which greatly enhance the task generalization of medical models.
MaCo
We introduce MaCo, a chest X-ray radiography-report foundation model with zero-shot capabilities based on masked contrastive learning. The motivation behind MaCo is to leverage the advantages of both contrastive learning-based and pretext task-based methods to acquire better semantic latent representations. MaCo couples a masked autoencoder with contrastive learning to learn from paired radiological images and medical reports. Additionally, we propose a correlation weighting mechanism that weights the contrastive loss according to the importance of the sampled image patches. This mechanism prioritizes informative patches, resulting in more effective learning and better representation of relevant features. Figure 1 shows the framework of MaCo, which integrates the strengths of contrastive learning and pretext task methods.
The detailed methodology is introduced in the following sections.
Masked high-resolution image reconstruction for image feature extraction
To extract meaningful feature representations from input images, we adopt MAE, proposed by He et al.14, as our primary image representation extractor. MAE employs a reconstruction pretext task designed to restore the masked image, thereby extracting meaningful representations of the image.
Specifically, the input image is partitioned into regular, non-overlapping patches; a subset of the patches is randomly sampled as the model input, and the remaining patches are excluded (masked). Let B denote the batch size and C the feature dimension. \(N = N_{s} + N_{msk}\) is the total number of image patches, where \(N_{s}\) and \(N_{msk}\) are the numbers of sampled and masked patches, respectively. The encoder output for the masked input image is denoted \(v_{enc}\), of size \(B\times N_{s}\times C\), and the decoder output is denoted \(v_{recon}\), of size \(B\times N\times C\). Let \(g_{recon}\) denote the corresponding ground truth, partitioned in the same way as the input image. The masked autoencoder reconstruction loss over a batch can be written as:$$\mathcal{L}_{mae}=\left\Vert v_{recon}^{msk}-g_{recon}^{msk}\right\Vert^{2}$$
(1)
where \(\Vert\cdot\Vert\) denotes the L2 norm. Here, we focus only on recovering the masked patches: \(v_{recon}^{msk}\) and \(g_{recon}^{msk}\) denote the reconstructed masked patches and their corresponding ground-truth patches, respectively.
High-resolution reconstruction is another effective pre-training approach for capturing detailed image representations19. It feeds low-resolution images to the image encoder and constrains the image decoder with the original high-resolution images.
In MaCo, we incorporate both masked image reconstruction and high-resolution reconstruction as pretext tasks during pre-training. The input image is first down-sampled to a lower resolution; in this work, the down-sampling ratio is set to 2. Then, following MAE, the down-sampled image is partitioned into N patches, and a random subset of these patches is sampled as input to the image encoder. The decoder outputs high-resolution reconstructions for the down-sampled input patches, and, as in MAE, the high-resolution reconstruction is supervised only on the masked patches. MaCo therefore follows the same training procedure as MAE, except that its input is a down-sampled version of the original image. Let \(v'_{recon}\) denote the image decoder output for the down-sampled input; the loss function of MaCo's pretext task is then defined as:$$\mathcal{L}_{pret}=\left\Vert {v'}_{recon}^{msk}-g_{recon}^{msk}\right\Vert^{2}$$
(2)
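The down-sampling, random patch sampling, and masked-only reconstruction loss of Eq. (2) can be illustrated with a minimal pure-Python sketch on toy 2D lists instead of tensors. It assumes 2×2 average pooling for the factor-2 down-sampling and a mask ratio of 0.75 (the MAE default); the text specifies neither, and all helper names are hypothetical. Because the high-resolution targets are cut into patches twice as large as the input patches, both images yield the same number of patches N, so each low-resolution input patch has exactly one high-resolution target.

```python
import random

def downsample_2x(img):
    """Average-pool a 2D image (list of rows, even height/width) by 2 (assumed method)."""
    h, w = len(img), len(img[0])
    return [[(img[r][c] + img[r][c + 1] + img[r + 1][c] + img[r + 1][c + 1]) / 4.0
             for c in range(0, w, 2)]
            for r in range(0, h, 2)]

def patchify(img, p):
    """Split a 2D image into non-overlapping p x p patches (row-major), each flattened."""
    h, w = len(img), len(img[0])
    patches = []
    for r in range(0, h, p):
        for c in range(0, w, p):
            patches.append([img[r + i][c + j] for i in range(p) for j in range(p)])
    return patches

def random_mask(n, mask_ratio, rng):
    """Return (sampled, masked) patch indices; only the sampled ones reach the encoder."""
    idx = list(range(n))
    rng.shuffle(idx)
    n_msk = int(round(n * mask_ratio))
    return sorted(idx[n_msk:]), sorted(idx[:n_msk])

def masked_recon_loss(pred, target, masked):
    """Squared L2 norm of the reconstruction error over masked patches only (Eq. 2)."""
    return sum((a - b) ** 2 for k in masked for a, b in zip(pred[k], target[k]))

rng = random.Random(0)
hr = [[float((r * 8 + c) % 5) for c in range(8)] for r in range(8)]  # toy 8x8 image
lr = downsample_2x(hr)                 # 4x4 encoder input
inp = patchify(lr, 2)                  # N = 4 input patches of 2x2
gt = patchify(hr, 4)                   # N = 4 high-resolution targets of 4x4
sampled, masked = random_mask(len(inp), 0.75, rng)
pred = [[0.0] * 16 for _ in gt]        # stand-in for the decoder output v'_recon
loss = masked_recon_loss(pred, gt, masked)
```

Note that the sampled patches contribute nothing to the loss, which is why the decoder is pushed to infer fine detail in regions the encoder never saw.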
Report feature extraction
We adopt BERT40, a classical natural language processing model that performs well across a variety of language understanding tasks, to extract expert knowledge from routine clinical examination reports.
The clinical reports are pre-processed by splitting them into sentences; this step also incorporates random sentence selection and shuffling. Next, a WordPiece tokenizer converts the pre-processed reports into a sequence of numerical tokens that BERT can process. The WordPiece tokenizer breaks each word into sub-word units and maps them to their corresponding numerical ids. This allows BERT to capture the meaning of the text at a finer granularity, improving the quality of the sentence representations.
We feed the token sequence into BERT to obtain sentence representations, denoted \(t_{enc}\), of size \(B\times N_{t}\times C\), where \(N_{t}\) is the number of text tokens including the [cls] token. These sentence representations capture the main content of the clinical reports and interact with the extracted image representations, as described in the next section.
Masked contrastive learning with a correlation weighting mechanism
In this section, our objective is to construct a multi-modal embedding space from the sampled image patch representations \(v_{enc}\) and the report representations \(t_{enc}\). The fundamental concept is similar to CLIP21, in which a multi-modal embedding space is learned by jointly training an image encoder and a text encoder.
Given a batch of B image-report pairs, the goal is to align images and texts in the feature space by maximizing the cosine similarity between the image and text representations of matched image-report pairs while minimizing the cosine similarity for mismatched pairs.
Let \(fc_{i}(\cdot)\) and \(fc_{t}(\cdot)\) denote the linear mappings into the joint embedding space for image and report representations, respectively. The mapped image representations \(v=fc_{i}(v_{enc}^{pool})\) and report representations \(t=fc_{t}(t_{enc}^{pool})\) are used to compute the cosine similarity matrix \(\langle v,t\rangle\), where \(v_{enc}^{pool}\), of size B × C, is the result of pooling \(v_{enc}\) over the token dimension, and \(t_{enc}^{pool}\), also of size B × C, is the [cls] token output of \(t_{enc}\). With temperature τ, the InfoNCE loss41 over a batch can then be written as:$$\mathcal{L}_{infoNCE}=-\frac{1}{B}\sum_{i=1}^{B}\log\left(\frac{\exp(\langle v_{i},t_{i}\rangle/\tau)}{\sum_{k=1}^{B}\exp(\langle v_{i},t_{k}\rangle/\tau)}\right)$$
(3)
Here, τ is learned during model training.
However, unlike the common contrastive learning setting with full-resolution, fully sampled image inputs, masked contrastive learning must address two challenges when aligning the multi-modal representations: 1) Do the randomly masked images still retain enough information to be correlated with the corresponding reports? 2) If so, how strong is the correlation? Answering these two questions is crucial for establishing meaningful correlations between the image and text modalities. From a clinical expert's perspective, the answers depend on the quality and relevance of the sampled patches: if the sampled patches precisely cover the entire lesion area, the two modalities should be highly correlated; otherwise, the correlation would be low.
To capture the correlation between paired masked images and reports in a manner consistent with expert practice, we propose a correlation weighting mechanism, depicted in Fig. 1(b). Specifically, we score each sampled image based on a masked position map. These scores are then used to adjust the temperature in contrastive learning and the weights in the contrastive loss function. In this way, highly correlated pairs receive higher weights during network training, which facilitates optimization.
For the kth (k = 1, ..., B) input instance in a batch, we first generate a binary matrix \(p_{k}\in\mathbb{R}^{\sqrt{N}\times\sqrt{N}}\) from its patch sampling mask used for masked auto-encoding, assigning 0 to the masked regions and 1 to the sampled regions. This binary matrix is called the masked position map.
\(p_{k}\) is then reshaped into a one-dimensional vector \(\widehat{p_{k}}\in\mathbb{R}^{N}\), and a fully connected (FC) layer is learned to generate an importance score for the instance from the reshaped masked position map \(\widehat{p_{k}}\) (Fig. 1(b)(ii)): \(w_{k}^{s}=\sum_{i=1}^{N}w_{i}\cdot\widehat{p_{k,i}}\). Here, \(w_{i}\) is the weight of the FC layer assigned to mask position i. Over all instances in a batch, we obtain the importance score vector \(W^{s}=\{w_{k}^{s}\}\in\mathbb{R}^{B}\). For weighting purposes, we additionally apply a softplus activation to re-scale the range of the importance scores, which stabilizes training and keeps every score strictly positive. The final importance scores \(W^{c}\in\mathbb{R}^{B}\) are computed as:$$W^{c}=\log\left(1+e^{W^{s}}\right)$$
(4)
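The importance score computation described above (masked position map, bias-free FC layer, softplus) can be sketched as follows. This is an illustrative sketch under stated assumptions: the FC weights are fixed, hypothetical toy values rather than learned parameters, the map is flattened row-major, and the function names are hypothetical.

```python
import math

def masked_position_map(sampled, n):
    """Flattened masked position map p_hat_k in R^N: 1 for sampled, 0 for masked patches."""
    keep = set(sampled)
    return [1.0 if i in keep else 0.0 for i in range(n)]

def softplus(x):
    """Numerically stable log(1 + e^x), as in Eq. (4)."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def importance_scores(maps, fc_weights):
    """w_k^s = sum_i w_i * p_hat_{k,i} per instance (no bias), then W^c = softplus(W^s)."""
    raw = [sum(w * p for w, p in zip(fc_weights, m)) for m in maps]
    return [softplus(s) for s in raw]

# Toy batch with N = 4 patches (a 2 x 2 map) and hypothetical FC weights.
fc_weights = [2.0, 0.5, -1.0, 0.0]
maps = [masked_position_map([0, 1], 4),   # keeps the two highly weighted positions
        masked_position_map([2, 3], 4)]   # keeps the low-weighted positions
w_c = importance_scores(maps, fc_weights)
```

The first instance keeps positions with large weights and therefore receives the larger score; since softplus never returns zero or a negative value, the per-pair scaling applied afterwards cannot flip the sign of a similarity.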
We then use the importance scores \(W^{c}\) to weight the image-text sample pairs, so that the model pays more attention to pairs with more meaningful sampled content (larger importance scores) during training. This weighting has two components: the weighting of the cosine similarity matrix \(\langle v,t\rangle\) (also called the logits; below, we use "logits" to refer to \(\langle v,t\rangle\)), and the weighting of the loss terms. Weighting the logits is similar to using the reciprocal of the temperature coefficient τ. Generally, a smaller temperature yields sharper logits and thus a more rigorous distribution alignment during training. Unlike the temperature, which has the same value for all sample pairs, our importance scores assign different weights to the logits of different sample pairs in a batch. In particular, for the ith image-text pair, if the sampled image patches are highly correlated with the corresponding text, a larger importance score \(w_{i}^{c}\) is obtained and sharper logits are enforced. Conversely, when the sampled image patches are weakly correlated with the corresponding text, \(w_{i}^{c}\) is smaller, and relatively uniformly distributed logits are learned.
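To make the temperature analogy concrete, the toy calculation below (hypothetical numbers, not values from MaCo) scales one row of logits by a large and by a small importance score. Multiplying a row by \(w_{i}^{c}\) is equivalent to using a per-pair temperature \(\tau/w_{i}^{c}\), so the larger score concentrates the softmax on the paired report while the smaller one flattens it.

```python
import math

def softmax(xs):
    """Softmax with max-subtraction for numerical stability."""
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    s = sum(e)
    return [v / s for v in e]

def weighted_row(cos_row, w, tau):
    """Row i of the weighted logits: w_i^c * <v_i, t_k> / tau for every k."""
    return [w * c / tau for c in cos_row]

cos_row = [0.30, 0.10, 0.05]   # toy cosine similarities; index 0 is the paired report
tau = 0.5                      # toy temperature
sharp = softmax(weighted_row(cos_row, 2.0, tau))  # large importance score
flat = softmax(weighted_row(cos_row, 0.3, tau))   # small importance score
```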
In addition, we use a detached version of \(W^{c}\) (one through which no gradient flows) to weight the loss terms, so that samples with higher correlation receive greater attention. The proposed masked contrastive learning loss can thus be expressed as:$$\mathcal{L}_{contra}=-\frac{1}{B}\sum_{i=1}^{B}\left(\log\left(\frac{\exp\left(w_{i}^{c}\cdot\langle v_{i},t_{i}\rangle/\tau\right)}{\sum_{k=1}^{B}\exp\left(w_{i}^{c}\cdot\langle v_{i},t_{k}\rangle/\tau\right)}\right)+w_{i}^{c}\log\left(\frac{\exp\left(\langle v_{i},t_{i}\rangle/\tau\right)}{\sum_{k=1}^{B}\exp\left(\langle v_{i},t_{k}\rangle/\tau\right)}\right)\right)$$
(5)
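A minimal sketch of Eq. (5) for a single batch, written over a precomputed cosine-similarity matrix with toy values, is shown below. It assumes the importance scores are already available. In an autograd framework the \(w_{i}^{c}\) that multiplies the second term would be the detached copy (for example `w.detach()` in PyTorch); plain Python floats carry no gradients, so this sketch only evaluates the loss value. The function names are hypothetical.

```python
import math

def log_softmax_at(row, j):
    """log(exp(row[j]) / sum_k exp(row[k])), computed stably."""
    m = max(row)
    return row[j] - m - math.log(sum(math.exp(x - m) for x in row))

def masked_contrastive_loss(sim, w, tau):
    """Eq. (5). sim[i][k] = <v_i, t_k> (cosine similarity); w[i] = w_i^c.
    Term 1: InfoNCE on logits scaled row-wise by w_i^c (acts as a per-pair temperature).
    Term 2: plain InfoNCE weighted by w_i^c (detached during real training)."""
    b = len(sim)
    total = 0.0
    for i in range(b):
        weighted = [w[i] * s / tau for s in sim[i]]
        plain = [s / tau for s in sim[i]]
        total += log_softmax_at(weighted, i) + w[i] * log_softmax_at(plain, i)
    return -total / b

sim = [[0.8, 0.1, 0.0], [0.2, 0.7, 0.1], [0.0, 0.3, 0.6]]  # toy batch, B = 3
loss = masked_contrastive_loss(sim, w=[2.0, 1.0, 0.5], tau=0.07)
```

With every \(w_{i}^{c}=1\), both terms reduce to the InfoNCE term of Eq. (3), so the loss equals twice \(\mathcal{L}_{infoNCE}\); scores above or below 1 re-weight individual pairs as described above.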
The final loss function for training MaCo combines the pretext-task loss with the masked contrastive learning loss:$$\mathcal{L}=\lambda\mathcal{L}_{pret}+(1-\lambda)\mathcal{L}_{contra}$$
(6)
Here, λ is a hyperparameter that balances the contributions of the two loss terms.
Implementation details
We used the same data augmentation methods in all training stages: random horizontal flipping, random affine transformations (degrees set to 20 and scale ranging from 0.8 to 1.2), and normalization with a mean of 0.4978 and a standard deviation of 0.2449. All experiments were conducted with the PyTorch framework. Pre-training MaCo took approximately 3.5 hours on four NVIDIA A100 GPUs. For convenience and comparability, we used the widely adopted ViT-B/16 as the image encoder and BERT with a width of 768 as the text encoder. For pre-training, we set the batch size to 512 and used the AdamW optimizer with an initial learning rate of 4.5e-4, weight decay of 0.05, β1 of 0.9, and β2 of 0.95. Following CLIP21, we used a symmetric form of the contrastive loss \(\mathcal{L}_{infoNCE}\). We set λ in Eq. (6) to 0.9, and the learnable parameter τ in Eq. (3) was initialized to 0.03. For fine-tuning, following the practice of classical methods6,19,22, we used the pre-trained image encoder to initialize the models fine-tuned for various applications, including classification, segmentation, and detection.
For the classification fine-tuning experiments on CheXpert, NIH ChestX-ray, and RSNA, we used the SGD optimizer with a momentum of 0.9 and searched for the optimal learning rate in the range from 1e-4 to 8e-3. For the segmentation fine-tuning experiments on SIIM and COVID Rural, we used the AdamW optimizer with an initial learning rate of 2e-5, weight decay of 0.05, β1 of 0.9, and β2 of 0.999.
For the detection fine-tuning experiments on RSNA, we used ViTDet29 as the base detection framework, with the AdamW optimizer, an initial learning rate of 3e-3, weight decay of 0.1, β1 of 0.9, and β2 of 0.999.
In both the pre-training stage and the fine-tuning stage of the image classification tasks, we warmed up the network by linearly increasing the learning rate to its set value and then decreased it following a cosine decay schedule.
Comparative methods
We began our analysis by comparing MaCo with various pre-training approaches that use text as supervision to learn image representations, including ConVIRT35, GLoRIA22, BioViL23, REFERS12, MGCA39, MFLAG42, Med-UniC43, M3AE33, MedKLIP6, MRM19, LoVT24 and Ark44. Specifically, ConVIRT learns medical visual representations by contrasting paired radiographs and sentences from radiology reports. GLoRIA improves upon ConVIRT by contrasting radiograph patches with words in the reports. BioViL and REFERS incorporate a masked language modeling loss into contrastive learning, with REFERS introducing a multi-view fusion attention mechanism to better align the representations of each radiograph and its associated report. M3AE employs mask modeling in both the image and language modalities to investigate the performance of pre-trained models on natural datasets. MedKLIP uses a report filter to extract medical entities and a more complex modal fusion module to aggregate features. Similar to M3AE, MRM applies a masking mechanism in both the image and text branches and has achieved state-of-the-art results in the medical field.
To evaluate our method comprehensively, we also included several image-based self-supervised learning methods: Context Restoration27, Model Genesis25, TransVW28, C2L26, and ImageNet45.
For the zero-shot tasks, we compared our method with relevant state-of-the-art approaches, including ConVIRT35, GLoRIA22, BioViL23, CheXzero7 and MedKLIP6. Note that CheXzero and MedKLIP are not capable of handling free-form text, while MRM and M3AE cannot produce zero-shot results because of their training strategies. Finally, we visualized the weights of our proposed correlation weighting mechanism, using attention maps to show that our approach weights the masked image representations in an interpretable and clinically plausible manner.
Datasets
We pre-train MaCo on radiographs and clinical reports from the MIMIC-CXR V2 dataset46. To assess the transferability of the learned radiograph representations, we perform various X-ray-based downstream tasks on multiple datasets: NIH ChestX-ray45, CheXpert47, RSNA Pneumonia Detection (RSNA)45,48, SIIM-ACR Pneumothorax49, COVID-19 Rural50, and MS-CXR23. These datasets are described in detail below.
MIMIC-CXR V2 is a large dataset comprising 377,110 chest X-rays associated with 227,827 clinical reports, collected at the Beth Israel Deaconess Medical Center between 2011 and 2016. During pre-training, we used all paired data, regardless of whether the images were frontal or lateral views.
CheXpert is a multi-label dataset for chest X-ray classification. To evaluate our model, we followed the official guidelines47 and report results for five selected pathologies. As the official CheXpert test set is not publicly available, we followed the approach described in35 and used the official validation set as our test set.
Additionally, following19, we sampled 5,000 images from the official training set to construct our validation set. The resulting training/validation/test split consists of 218,414/5,000/234 images, respectively.
NIH ChestX-ray (NIH) contains 112,120 frontal-view chest radiographs and poses a multi-label classification problem covering 14 chest pathologies. The dataset is split into training, validation, and test sets comprising 70%, 10%, and 20% of the total dataset, respectively.
COVID-19 Rural (COVID Rural) is a small-scale collection of over 200 chest X-ray images with COVID-19 disease segmentation masks. We use this dataset to evaluate segmentation performance. The dataset is randomly split into training, validation, and test sets at a ratio of 60%, 20%, and 20%.
SIIM-ACR Pneumothorax (SIIM) was curated to facilitate the development of segmentation models for identifying pneumothorax in chest radiographs. The dataset includes more than 120,000 frontal-view chest X-rays, each accompanied by precise manual segmentation of pneumothorax regions. We use this dataset for both fine-tuning segmentation and zero-shot classification tasks. To construct the fine-tuning dataset, we follow the established practice outlined in22 and partition the dataset into training, validation, and test sets comprising 70%, 15%, and 15% of the total dataset, respectively.
RSNA Pneumonia Detection (RSNA) is derived from the 2018 RSNA Pneumonia Challenge; 6,012 of its images carry bounding box annotations. We use this dataset for fine-tuning classification and detection tasks. For classification, we adhere to the official data split, partitioning the dataset into a training set of 25,184 images, a validation set of 1,500 images, and a test set of 3,000 images.
For detection, in line with the approach adopted in LoVT24, the dataset is partitioned into a training set of 3,584 images, a validation set of 1,210 images, and a test set of 1,218 images.
MS-CXR provides annotations in the form of bounding box and sentence pairs describing clinical findings observed in chest X-ray images. Each sentence describes a single pathology present in the image, and a single radiological finding may be associated with multiple manually annotated bounding boxes. The annotations were collected on a subset of MIMIC-CXR images covering eight different pathologies. In total, 1,162 annotations of 881 cases were collected, and we used the entire dataset to measure the overlap between the labeled bounding boxes and the vision-language association results after pre-training.
Reporting summary
Further information on research design is available in the Nature Portfolio Reporting Summary linked to this article.
