A foundation model for clinician-centered drug repurposing

Most data used in the present study were obtained from publicly available knowledge repositories. For internal data, the Institutional Review Board at Mount Sinai, New York City, USA, approved the retrospective analysis of internal EMRs. All internal EMRs were deidentified before computational analysis and model development. Patients were not directly involved or recruited for the study. Informed consent was waived for the retrospective analysis of EMRs.

Curation of a medical KG dataset

The KG is heterogeneous, with 10 types of nodes and 29 types of undirected edges. It contains 123,527 nodes and 8,063,026 edges. Supplementary Tables 8 and 9 show a breakdown of nodes by node type and edges by edge type, respectively. The KG and all auxiliary data files are available via Harvard Dataverse at https://doi.org/10.7910/DVN/IXA7BM. Supplementary Note 1 provides detailed information about the datasets and the curation of the KG.

Problem definition

We are given a heterogeneous KG, \(G=(\mathscr{V},\mathscr{E},\mathscr{T}_{R})\), with nodes \(i\in \mathscr{V}\) in the node set and edges \({e}_{i,j}=(i,r,j)\) in the edge set \(\mathscr{E}\), where \(r\in \mathscr{T}_{R}\) indicates the relationship type, i is called the head/source node and j the tail/target node. Each node belongs to a node type in the node type set \(\mathscr{T}_{V}\) and has an initial embedding, denoted \({{\bf{h}}}_{i}^{(0)}\). Given a disease i and drug j, we want to predict the likelihood of the drug being indicated or contraindicated for the disease. Our approach induces inductive priors by incorporating factual knowledge from the KG into the model, enhancing its reasoning capabilities for hypothesis formation and drug candidate prediction.
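As a concrete illustration of the problem setting, the heterogeneous KG can be represented with typed nodes and typed triples. The following minimal Python sketch uses illustrative entity and relation names (these are assumptions, not identifiers from the released TxGNN code or the actual KG):

```python
# Minimal sketch of a heterogeneous KG: nodes carry a type, and each
# undirected edge is a typed triple (head, relation, tail).
# Entity/relation names below are illustrative, not from the real KG.
from collections import defaultdict

class HeteroKG:
    def __init__(self):
        self.node_type = {}                # node id -> node type
        self.neighbors = defaultdict(set)  # (node, relation) -> neighbor set
        self.edges = set()                 # typed triples (i, r, j)

    def add_node(self, i, ntype):
        self.node_type[i] = ntype

    def add_edge(self, i, r, j):
        # Undirected: record both directions under the same relation type.
        self.edges.add((i, r, j))
        self.neighbors[(i, r)].add(j)
        self.neighbors[(j, r)].add(i)

kg = HeteroKG()
kg.add_node("EGFR", "gene/protein")
kg.add_node("lung cancer", "disease")
kg.add_edge("lung cancer", "disease_protein", "EGFR")
```

A model would then score a query triple such as (disease i, "indication", drug j) over this structure.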
Detailed experimental protocols, including data split curation, negative sampling, hyperparameter tuning and additional details, are described in Supplementary Note 4.

Overview of TxGNN approach

TxGNN is a deep-learning approach for mechanistic predictions in drug discovery based on molecular networks perturbed in disease and targeted by therapeutics. TxGNN is composed of four modules: (1) a heterogeneous GNN-based encoder to obtain a biologically meaningful network representation for each biomedical entity; (2) a disease similarity-based metric learning decoder to leverage auxiliary information to enrich the representation of diseases that lack molecular characterization; (3) an all-relationship stochastic pretraining followed by a drug–disease centric, full-graph fine-tuning strategy; and (4) a graph explanatory module that retains, as a post-training step, a sparse set of edges crucial for prediction. Below, we describe each module in detail.

Heterogeneous GNN encoder

Our objective was to learn a general encoder of a biomedical KG by learning a numerical vector (embedding) for each node, encapsulating the biomedical knowledge contained within its neighboring relational structures. This involves refining the initial node embeddings through a sequence of local graph-based, nonlinear function transformations56. These transformations are iteratively optimized, guided by a loss function that penalizes incorrect drug predictions, until the system converges to an optimized set of node embeddings.

Step 1: initializing latent representations

We denote the input embedding of each node i as \({{\bf{x}}}_{i}\), which is initialized using Xavier uniform initialization.
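Step 1 can be sketched in plain Python. Xavier uniform initialization draws each entry from \(U(-a, a)\) with \(a=\sqrt{6/(\text{fan}_{\text{in}}+\text{fan}_{\text{out}})}\); the node count and embedding dimension below are illustrative assumptions, not the paper's hyperparameters:

```python
# Hedged sketch of step 1: Xavier (Glorot) uniform initialization of node
# embeddings. Equivalent in spirit to torch.nn.init.xavier_uniform_;
# written with the stdlib so it is self-contained.
import math
import random

def xavier_uniform(fan_in, fan_out, seed=0):
    """Sample a fan_in x fan_out matrix from U(-a, a), a = sqrt(6 / (fan_in + fan_out))."""
    rng = random.Random(seed)
    a = math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-a, a) for _ in range(fan_out)] for _ in range(fan_in)]

# Demo: initial embeddings x_i for 4 nodes with dimension 128 (assumed dim).
X = xavier_uniform(4, 128)
```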
For every layer l of message passing, there are three stages, steps 2–4.

Step 2: propagating relationship-specific neural messages

For every relationship type r, we first calculated a transformation of the node embedding from the previous layer, \({{\bf{h}}}_{i}^{(l-1)}\), where \({{\bf{h}}}_{i}^{(0)}={{\bf{x}}}_{i}\). This was achieved by applying a relationship-specific weight matrix \({{W}}_{r,M}^{(l)}\) to the previous layer's embedding:
$${{\bf{m}}}_{r,i}^{(l)}={{W}}_{r,M}^{(l)}{{\bf{h}}}_{i}^{(l-1)}.$$

Step 3: aggregating local network neighborhoods

For each node i, we aggregated the incoming messages from the neighboring nodes of each relation r, denoted \({{\mathscr{N}}}_{i}^{r}\), by taking their average:
$$\widetilde{{{\bf{m}}}_{r,i}^{(l)}}=\frac{1}{|{{\mathscr{N}}}_{i}^{r}|}\sum _{j\in {{\mathscr{N}}}_{i}^{r}}{{\bf{m}}}_{r,j}^{(l)}.$$

Step 4: updating latent representations

We then combined the node embedding from the previous layer with the aggregated messages from all relationships to obtain the new node embedding:
$${{\bf{h}}}_{i}^{(l)}={{\bf{h}}}_{i}^{(l-1)}+\sum _{r\in {{\mathscr{T}}}_{R}}\widetilde{{{\bf{m}}}_{r,i}^{(l)}}.$$
After L layers of propagation, we arrived at the encoded node embedding \({{\bf{h}}}_{i}\) for each node i.

Predicting drug candidates

TxGNN employs disease and drug embeddings to predict indications and contraindications for each disease–drug pair. Each of the three relationship types requiring prediction was assigned a trainable weight vector \({{\bf{w}}}_{r}\). The interaction likelihood for a specific relationship is then determined using the DistMult approach57.
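One message-passing layer (steps 2–4) and the DistMult scoring can be sketched as follows. This is a hedged illustration in plain Python with toy dimensions and identity weight matrices, not the released implementation:

```python
# Hedged sketch: one encoder layer (per-relation message, neighborhood
# mean, residual update) followed by DistMult. Node names, weights and
# dimensions are toy assumptions.
import math

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def vadd(a, b):
    return [x + y for x, y in zip(a, b)]

def layer(h, W_r, neighbors):
    """One layer: m = W_r h per relation (step 2), mean over neighbors (step 3),
    residual sum over relations (step 4)."""
    new_h = {}
    for i in h:
        agg = [0.0] * len(h[i])
        for r, W in W_r.items():
            hood = neighbors.get((i, r), [])
            if not hood:
                continue
            msgs = [matvec(W, h[j]) for j in hood]           # step 2
            mean = [sum(v) / len(hood) for v in zip(*msgs)]  # step 3
            agg = vadd(agg, mean)
        new_h[i] = vadd(h[i], agg)                           # step 4
    return new_h

def distmult(h_i, w_r, h_j):
    """p = sigmoid(sum(h_i * w_r * h_j))."""
    s = sum(a * b * c for a, b, c in zip(h_i, w_r, h_j))
    return 1.0 / (1.0 + math.exp(-s))

h = {"disease": [0.1, 0.2], "drug": [0.3, -0.1], "gene": [0.5, 0.5]}
W_r = {"disease_protein": [[1.0, 0.0], [0.0, 1.0]]}   # identity, for illustration
neighbors = {("disease", "disease_protein"): ["gene"]}
h1 = layer(h, W_r, neighbors)
p = distmult(h1["disease"], [1.0, 1.0], h1["drug"])   # toy relation vector w_r
```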
Formally, for a disease i, drug j and relation r, the predicted likelihood p is calculated as follows:
$${p}_{i,j,r}=\frac{1}{1+\exp \left(-\text{sum}\left({{\bf{h}}}_{i}\times {{\bf{w}}}_{r}\times {{\bf{h}}}_{j}\right)\right)}.$$

Embedding-based disease similarity search

Research on diseases varies widely based on their prevalence and complexity. For instance, the molecular basis of many rare diseases remains poorly understood21. Despite this, rare diseases often offer extensive opportunities for therapeutic advancement3. This shortage of research is evident in the biological KG, where rare diseases are characterized by a lack of relevant nodes and edges, leading to lower-quality graph embeddings. Empirical evidence indicates that GNN models exhibit substantially reduced predictive performance on splits designed to reflect the sparse nature of knowledge on these diseases, as opposed to random splits (Fig. 2c,d).

Network embeddings for these diseases are poorly informed owing to sparse prior information in the KG; a model is therefore needed to enhance them. Human physiology is an interconnected system in which diseases exhibit similarities. By extracting predictive information from similar but better-represented diseases in the KG, a model can enrich the target disease embedding and improve predictive performance for that disease. To this end, TxGNN employed a three-step procedure: (1) construct a disease signature vector to capture complex disease similarities; (2) use an aggregation mechanism to combine similar disease embeddings into a comprehensive auxiliary embedding; and (3) introduce a gating mechanism to modulate the influence between the original and auxiliary disease embeddings, acknowledging that well-characterized diseases may not need supplementation.
Each step is elaborated on in the following sections.

Step 1: disease signature vectors

The primary objective of this module was to derive a signature vector \({{\bf{p}}}_{i}\) for each disease i. Because the disease representations produced solely by the GNN can be insufficient, they are not ideal for direct similarity computation. Instead, we employed graph-theoretical methods14 to calculate disease similarities; variations of the signature vectors are detailed in Supplementary Note 2. Specifically, we generated a vector that encapsulates the local neighborhood surrounding a disease. For disease i, the signature vector is formally defined as follows:
$${{\bf{p}}}_{i}=[{\rm{p}}_{1}\cdots {\rm{p}}_{|{\mathscr{V}}_{\rm{P}}|}\;{\rm{ep}}_{1}\cdots {\rm{ep}}_{|{\mathscr{V}}_{\rm{EP}}|}\;{\rm{ex}}_{1}\cdots {\rm{ex}}_{|{\mathscr{V}}_{\rm{EX}}|}\;{\rm{d}}_{1}\cdots {\rm{d}}_{|{\mathscr{V}}_{\rm{D}}|}],$$
where
$${\rm{p}}_{j}=\begin{cases}1&\text{if }j\in {{\mathscr{N}}}_{i}^{{\mathscr{P}}}\\ 0&\text{otherwise}\end{cases},\quad {\rm{ep}}_{j}=\begin{cases}1&\text{if }j\in {{\mathscr{N}}}_{i}^{{\mathscr{EP}}}\\ 0&\text{otherwise}\end{cases},\quad {\rm{ex}}_{j}=\begin{cases}1&\text{if }j\in {{\mathscr{N}}}_{i}^{{\mathscr{EX}}}\\ 0&\text{otherwise}\end{cases},\quad {\rm{d}}_{j}=\begin{cases}1&\text{if }j\in {{\mathscr{N}}}_{i}^{{\mathscr{D}}}\\ 0&\text{otherwise}\end{cases},$$
and \({{\mathscr{N}}}_{i}^{{\mathscr{P}}},{{\mathscr{N}}}_{i}^{{\mathscr{EP}}},{{\mathscr{N}}}_{i}^{{\mathscr{EX}}},{{\mathscr{N}}}_{i}^{{\mathscr{D}}}\) are the sets of gene/protein, effect/phenotype, exposure and disease nodes, respectively, lying in the one-hop neighborhood of disease i.
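The signature construction above is a concatenation of binary one-hop indicator vectors, one block per node type. A hedged sketch (toy vocabularies and neighborhoods, not the real KG):

```python
# Sketch of step 1: a disease's signature concatenates 0/1 indicators of
# its one-hop neighbors over four node-type vocabularies. All names and
# vocabulary sizes below are toy assumptions.
def signature(neighbors, vocabularies):
    """Concatenate, per node type, a 0/1 indicator for each vocabulary node."""
    vec = []
    for ntype, vocab in vocabularies:
        hood = neighbors.get(ntype, set())
        vec.extend(1 if node in hood else 0 for node in vocab)
    return vec

vocabularies = [
    ("gene/protein", ["EGFR", "TP53", "BRCA1"]),
    ("effect/phenotype", ["cough", "fever"]),
    ("exposure", ["asbestos"]),
    ("disease", ["asthma", "COPD"]),
]
# One-hop neighborhood of a toy disease i.
neighbors_i = {"gene/protein": {"EGFR", "TP53"}, "effect/phenotype": {"cough"}}
p_i = signature(neighbors_i, vocabularies)
```

With these toy inputs, `p_i` has one entry per vocabulary node across the four blocks, and the dot product of two such signatures counts shared neighbors, which is exactly the similarity used in the next step.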
We adopted the dot product as the similarity measure, which means that the similarity is the number of shared neighbors summed across the four node types:
$${\rm{sim}}(i,j)={{\bf{p}}}_{i}\cdot {{\bf{p}}}_{j}=|{{\mathscr{N}}}_{i}^{{\mathscr{P}}}\cap {{\mathscr{N}}}_{j}^{{\mathscr{P}}}|+|{{\mathscr{N}}}_{i}^{{\mathscr{EP}}}\cap {{\mathscr{N}}}_{j}^{{\mathscr{EP}}}|+|{{\mathscr{N}}}_{i}^{{\mathscr{EX}}}\cap {{\mathscr{N}}}_{j}^{{\mathscr{EX}}}|+|{{\mathscr{N}}}_{i}^{{\mathscr{D}}}\cap {{\mathscr{N}}}_{j}^{{\mathscr{D}}}|.$$
Given the disease signatures and the calculated pairwise similarities, we can then obtain the k most similar diseases for a queried disease i (with argmax denoting top-k selection):
$${{\mathscr{D}}}_{{\rm{sim}},i}={{\rm{argmax}}}_{j\in {{\mathscr{V}}}_{{\mathscr{D}}}}\,{\rm{sim}}(i,j).$$

Step 2: disease metric learning

Given the set of similar diseases, TxGNN generates a disease embedding that integrates the similarity measures into a unified auxiliary embedding, capable of augmenting the representation of a queried disease that may be sparsely annotated. To achieve this, we adopted a weighted scheme, wherein each similar disease was weighted according to its similarity score:
$${{\bf{h}}}_{i}^{{\rm{sim}}}=\sum _{j\in {{\mathscr{D}}}_{{\rm{sim}},i}}\frac{{\rm{sim}}(i,j)}{{\sum }_{k\in {{\mathscr{D}}}_{{\rm{sim}},i}}{\rm{sim}}(i,k)}\times {{\bf{h}}}_{j}.$$

Step 3: gating disease embeddings

The final stage updates the original disease embedding \({{\bf{h}}}_{i}\) with the disease–disease metric learning embedding \({{\bf{h}}}_{i}^{{\rm{sim}}}\) via a gating mechanism, which employs a scalar c ∈ [0,1] to modulate the influence between the two embeddings. Special consideration is needed because, for well-represented diseases in the KG, the metric learning embedding might be unnecessary and could bias the disease embedding.
Conversely, this embedding can be informative for accurate prediction on diseases with no existing drugs. A learnable attention mechanism is ineffective here because it overvalues the original embeddings of well-represented diseases, neglecting the auxiliary embedding.

Instead, we introduced an approach that determines the weighting based on the degree of node connectivity, \(|{{\mathscr{N}}}_{i}^{r}|\), of the disease in the queried drug–disease pair. A higher degree indicates that the disease is better represented in the KG and has a denser local network neighborhood, suggesting a reduced reliance on the disease–disease metric learning embedding, and vice versa. The scalar is designed to be high for minimal node degrees (0 or 1) and to decrease rapidly with increasing node degree. To achieve this, we used an inflated exponential density function:
$${c}_{i}=0.7\times \exp (-0.7\times |{{\mathscr{N}}}_{i}^{r}|)+0.2.$$
The rate λ = 0.7 was selected by parameter search, and we observed that the result is not sensitive to λ (Supplementary Fig. 12). We could then obtain the augmented disease embedding:
$${\widehat{{\bf{h}}}}_{i}={c}_{i}\times {{\bf{h}}}_{i}^{{\rm{sim}}}+(1-{c}_{i})\times {{\bf{h}}}_{i}.$$
Finally, TxGNN used the augmented disease embeddings as input to the latent decoder to produce drug predictions.

Training TxGNN deep graph models

The objective of the training process was to predict the presence of a relationship between two entities within the KG. The dataset of positive samples, denoted \({{\mathscr{D}}}_{+}\), comprises all pairs (i, j) across the various relationship types r, with the label \({y}_{i,r,j}=1\) indicating the presence of a relationship. To generate the dataset of negative samples, \({{\mathscr{D}}}_{-}\), we used a sampling technique, detailed in Supplementary Notes 3 and 4, that creates a counterpart for each positive pair.
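The exact sampling scheme is specified in the Supplementary Notes; as a generic, hedged illustration, a standard corruption-style sampler that replaces the tail of each positive triple with a random entity might look like this (entity names are toy assumptions):

```python
# Hedged sketch of negative sampling: for each positive triple (i, r, j),
# corrupt the tail to a random entity, resampling until the corrupted
# triple is not itself a known positive. Not the paper's exact scheme.
import random

def sample_negatives(positives, entities, seed=0):
    rng = random.Random(seed)
    pos = set(positives)
    negatives = []
    for i, r, j in positives:
        while True:
            j_neg = rng.choice(entities)
            if (i, r, j_neg) not in pos:
                negatives.append((i, r, j_neg))
                break
    return negatives

positives = [("disease_a", "indication", "drug_1"),
             ("disease_b", "indication", "drug_2")]
drugs = ["drug_1", "drug_2", "drug_3", "drug_4"]
negs = sample_negatives(positives, drugs)  # one counterpart per positive
```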
For a given pair (i, j) and relationship type r, the model estimates the probability \({p}_{i,r,j}\) of the relationship's existence. The training loss is then calculated using the binary cross-entropy formula:
$${\mathscr{L}}=-\sum _{(i,r,j)\in {{\mathscr{D}}}_{+}\cup {{\mathscr{D}}}_{-}}{y}_{i,r,j}\log ({p}_{i,r,j})+(1-{y}_{i,r,j})\log (1-{p}_{i,r,j}).$$
Previous research has emphasized KG completion, optimizing models across the entire spectrum of relationships within a KG58. This approach, however, may dilute the model's capacity to capture specific knowledge when the interest lies solely in drug–disease relationships. Because drug–disease interactions are governed by complex biological mechanisms, the extensive range of biomedical relationships in a KG can offer a comprehensive view of biological systems. The primary challenge lies in optimizing performance on a select group of relationships while beneficially leveraging the broader set of relationships for knowledge transfer, avoiding catastrophic forgetting of general knowledge.

To address this challenge, TxGNN used a pretraining strategy. Initially, TxGNN predicted relationships across the entire KG using stochastic mini-batching, encapsulating biomedical knowledge in enriched node embeddings. In the fine-tuning phase, TxGNN focused on drug–disease relationships, sharpening its ability to generate specific embeddings and optimizing drug-repurposing predictions.

Pretraining TxGNN model

TxGNN undergoes pretraining on millions of biomedical entity pairs across all relationships. Because of the extensive number of edges, stochastic mini-batching is used to train on subsets of pairs at each step, ensuring coverage of all data pairs within each epoch. During this phase, degree-adjusted disease augmentation is deactivated and all relationship types are treated equally. The pretrained encoder weights are then used to initialize the encoder model for fine-tuning.
Note that the decoder weights, specifically the DistMult vectors \({{\bf{w}}}_{r}\), are reinitialized before fine-tuning to mitigate the risk of negative knowledge transfer.

Fine-tuning TxGNN model

After pretraining, the model initialization encapsulated a broad spectrum of biological knowledge. The next phase refined drug–disease relationship predictions by focusing solely on drug–disease pairs. Other relationship types remained in the KG to facilitate indirect information flow. During fine-tuning, the model activated the degree-adjusted interdisease embedding feature. TxGNN underwent both pretraining and fine-tuning end to end. The variant with the highest validation performance was selected for test set evaluation and downstream analyses.

Generating multi-hop interpretable explanations

In a trained drug-repurposing prediction model, consider a target node j and a neighboring source node i connected by an edge \({e}_{i,j}\) at layer l. For each relationship r, the intermediate messages \({{\bf{m}}}_{r,i}^{(l)}\) and \({{\bf{m}}}_{r,j}^{(l)}\) are computed. These embeddings are concatenated and input into a relationship-specific, single-layer neural network parameterized by \({{W}}_{g,r}^{(l)}\). This network predicts the probability of masking the message from source node i during the computation of the embedding of target node j. The output is passed through a gate, which includes a sigmoid layer to constrain the probability to the range [0,1], followed by an indicator function that determines whether the edge should be dropped:
$${z}_{i,j,r}^{(l)}={\mathbb{I}}\left[{\rm{sigmoid}}\left({{W}}_{g,r}^{(l)}\left({{\bf{m}}}_{r,i}^{(l)}\,||\,{{\bf{m}}}_{r,j}^{(l)}\right)\right) > 0.5\right],$$
such that \({z}_{i,j,r}^{(l)}\in \{0,1\}\). In practice, a location bias of 3 is added to the sigmoid function during initialization to ensure that its outputs are initially close to 1.
This means that, at the start, the gates remain open, allowing the model to adaptively close them and mask edges within the subgraph as needed. This initialization is essential because starting from a random initialization, which drops edges at random, creates a discrepancy between the original and updated predictions; the model's primary focus then shifts toward minimizing this discrepancy rather than balancing the two objectives. To refine this mechanism, when a gate outputs 0, the corresponding message is not simply removed; instead, it is substituted with a learnable baseline vector \({{\bf{b}}}_{r}^{(l)}\) for each relationship r and layer l. The revised message from source node i to target node j is therefore:
$${\widehat{{\bf{m}}}}_{r,i}^{(l)}={z}_{i,j,r}^{(l)}\times {{\bf{m}}}_{r,i}^{(l)}+\left(1-{z}_{i,j,r}^{(l)}\right)\times {{\bf{b}}}_{r}^{(l)}.$$
Two objectives guide the optimization of the GraphMask gate weights. The first, faithfulness, ensures that the updated predictions after applying the mask align closely with the original predictions. The second encourages the model to mask as many edges as feasible. These objectives inherently entail a tradeoff: increasing the extent of masking tends to enlarge the discrepancy between the updated and original predictions. We addressed this through constrained optimization, employing Lagrangian relaxation to balance the objectives. Specifically, the optimization maximizes the Lagrange multiplier λ to enforce the constraint while minimizing the primary objective.
The loss function employed for this purpose is formulated as follows:
$$\mathop{\max }\limits_{\lambda }\mathop{\min }\limits_{{{W}}_{g}}\mathop{\sum }\limits_{k=1}^{L}\sum _{(i,r,j)\in {{\mathscr{D}}}_{+}\cup {{\mathscr{D}}}_{-}}{\mathbb{I}}\left[{z}_{i,j,r}^{(k)}\ne 0\right]+\lambda \left(||{\hat{p}}_{i,j,r}-{p}_{i,j,r}||_{2}^{2}-\beta \right),$$
where β is the allowed margin between the updated and original predictions. After training is complete, edges (i, j, r) for which \({z}_{i,j,r}^{(k)}=0\) can be removed. The remaining edges serve as explanations for the model's predictions. In addition, the value computed before applying the indicator function can be used to quantify each edge's contribution to the prediction, facilitating the adjustment of granular differences in contributions. More detailed adaptations of the GraphMask approach are discussed in Supplementary Note 3.

Pilot usability evaluation of TxGNN with medical experts

The TxGNN Explorer was developed following a user-centric design study process, as outlined in our pilot study25. This process compared three visual presentations of GNN explanations from the user's perspective; the findings motivated the adoption of path-based explanations, which users preferred. The usability of the TxGNN Explorer was assessed through a comparison with a baseline that displayed only drug predictions and their associated confidence scores.

For this usability study, 12 medical experts (7 male and 5 female, average age 34.25 years, referred to as P1–P12) were recruited through personal contacts, Slack channels and email lists from collaborating institutions, with all participants providing informed consent.
The group comprised five clinical researchers (P1–P3, P11–P12) and five practicing physicians (P4, P7–P10), all holding MD degrees, and two medical school students with prior experience as pharmacists (P5, P6). Each participant had at least 5 years of experience in various medical specialties.

The study was conducted remotely via Zoom in compliance with COVID-19-related restrictions. Participants accessed the study system (Supplementary Fig. 17) on their own computers and shared their screens with the interviewer. The sequence in which predictions were presented and the condition (TxGNN Explorer or the baseline approach) were randomized and counterbalanced across participants and tasks.

In the drug assessment tasks, participants' accuracy, confidence levels and task completion times were evaluated across 192 trials (16 tasks × 12 participants). Specifically, participants were asked to: (1) determine the correctness of a drug prediction (that is, whether the drug could potentially be used to treat the disease); and (2) rate their confidence in their decision on a 5-point Likert scale (1 = not confident at all, 5 = completely confident). The system automatically logged the time taken to evaluate each prediction.

On completion of all predictions, participants provided subjective ratings for both tasks regarding Trust, Helpfulness, Understandability and Willingness to Use, using a 5-point Likert scale (1 = strongly disagree, 5 = strongly agree). Subsequent semi-structured interviews yielded insights and feedback on the tool's predictions, explanations and overall user experience. Each session of the user study lasted approximately 65 min.

Analysis of medical records from a large healthcare system

Patient data from the Mount Sinai Health System's EMRs in New York City were used to examine patterns from predictions in clinical practice.
The Mount Sinai Institutional Review Board approved the study, and all clinical data were deidentified. The initial cohort included over 10 million patients and was refined to those aged >18 years with at least one drug and one diagnosis on record, resulting in 1,272,085 patients. This refined cohort comprised 40.1% males, with an average age of 48.6 years (s.d. 18.6 years). The racial composition of the dataset is detailed in Supplementary Table 6.

Disease and medication data were structured according to the Observational Medical Outcomes Partnership (OMOP) standard data model59. Predictions were generated for 1,363 diseases, with the model trained on the KG while 5% of randomly selected drug–disease pairs served as a validation set for early stopping. Disease names in the prediction dataset were aligned with SNOMED or International Classification of Diseases, 10th revision (ICD-10)60 codes and then mapped to OMOP concepts within the Mount Sinai data system. The analysis was restricted to diseases diagnosed in at least 1 patient, narrowing the focus to 478 conditions. Similarly, medication names were matched to DrugBank IDs, then to RxNorm IDs and OMOP concepts, and limited to medications prescribed to at least 10 patients, resulting in 1,290 medications. Drug–disease pairs were further refined to those with at least 1 recorded instance of a patient being prescribed the drug for the disease, covering 1,272,085 patients. Contingency tables were created for each drug–disease pair, and Fisher's exact test was used to calculate two-sided ORs and P values for each pair. Bonferroni correction was applied to the P values using the statsmodels Python library's multipletests function, identifying statistically significant drug–disease pairs as those with P < 0.005.

Inclusion and ethics statement

We have complied with all relevant ethical regulations. Our research team represents a diverse group of collaborators.
Roles and responsibilities were clearly defined and agreed on among collaborators before the start of the research. All researchers were included in the study design, study implementation, data ownership, intellectual property and authorship of publications. Our research did not face severe restrictions or prohibitions in the setting of the local researchers, and no specific exceptions were granted for the present study in agreement with local stakeholders. Animal welfare regulations, environmental protection, risk-related regulations and transfer of biological materials, cultural artefacts or associated traditional knowledge out of the country do not apply to our research. Our research did not result in stigmatization, incrimination, discrimination or personal risk to participants, and appropriate provisions were taken to ensure the safety and well-being of all participants involved. Our team was committed to promoting equitable access to resources and benefits resulting from the research.

Reporting summary

Further information on research design is available in the Nature Portfolio Reporting Summary linked to this article.
