HEART: Learning Better Representation of EHR Data with a Heterogeneous Relation-Aware Transformer

Tinglin Huang1, Syed Asad Rizvi1, Rohan Krishna Thakur1, Vimig Socrates1, Meili Gupta1, David van Dijk1, R. Andrew Taylor1, Rex Ying1
1 Yale University

Introduction

Electronic health records (EHRs) are tabular data that digitize the medical information of an encounter, such as demographics, diagnoses, medications, lab results, and procedures, as shown in Figure 1:

Figure 1: Illustration of EHRs.

Much research has focused on distilling meaningful clinical information from patient cohorts with foundation models. Specifically, such models treat medical entities in EHRs as tokens and organize the entities within each encounter as a sentence. These “sentences” can then be encoded by a transformer, allowing the entities to be represented in an embedding space, as shown in Figure 2(a):

Figure 2: Comparison between current foundation models and ours.

However, we argue that the heterogeneous correlations between medical entities are critical for representation learning but have largely been overlooked. For example, understanding the relationship between "Antibiotics" (a medication) and both "Fever" (a diagnosis) and "Antibody Tests: Positive" (a lab test) enables the model to recommend more clinically plausible drugs.

Motivated by this, we propose HEART, a Heterogeneous Relation-Aware Transformer for EHR data, which explicitly parameterizes pairwise representations between entities in a heterogeneous manner. Additionally, we introduce a multi-level attention mechanism to mitigate the computational cost associated with multiple visits, as demonstrated in Figure 2(b). Finally, two dedicated objectives are used during pretraining to further enhance the model.

Method

Heterogeneous Relation Embedding & Multi-level Attention Scheme

Given a patient, we flatten the corresponding historical visits into several sequences of entities:

[[D_1, V_{1,1},\cdots,V_{1,N_1}],\cdots,[D_S, V_{S,1},\cdots,V_{S,N_S}]]

where $S$ is the number of visits, $N_i$ is the number of entities in the $i$-th visit, and $D_i$ represents the demography token for the patient in the $i$-th visit. A learnable embedding is assigned to each entity. Besides the entity embeddings, we explicitly encode a pairwise representation for each entity pair. Specifically, for an entity pair $(V_n,V_m)$ in the same visit, we calculate the pairwise embedding $\mathbf{R}_{n\leftarrow m}$ as follows:

\mathbf{R}_{n}=\text{Linear}_{\tau(V_n)}(\mathbf{V}_{n}),\quad \mathbf{R}_{m}=\text{Linear}_{\tau(V_m)}(\mathbf{V}_{m});\\ \mathbf{R}_{n\leftarrow m}=\text{Linear}(\mathbf{R}_{n}\,||\,\mathbf{R}_{m})

where $\text{Linear}_{\tau(\cdot)}$ denotes a type-specific linear transformation. This encoding is applied to every pair of entities within the same visit.
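As a concrete illustration, here is a minimal PyTorch sketch of this pairwise encoding; the names (`HeteroRelationEmbedding`, `num_types`, `dim`) are our own assumptions and not the paper's exact implementation:

```python
import torch
import torch.nn as nn

class HeteroRelationEmbedding(nn.Module):
    """Type-specific projections followed by pairwise fusion (illustrative sketch)."""
    def __init__(self, num_types: int, dim: int):
        super().__init__()
        # One Linear_{tau} per entity type (diagnosis, medication, lab result, ...)
        self.type_proj = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_types)])
        # Fuses the concatenated pair [R_n || R_m] into R_{n<-m}
        self.pair_proj = nn.Linear(2 * dim, dim)

    def forward(self, ent_emb: torch.Tensor, ent_type: torch.Tensor) -> torch.Tensor:
        # ent_emb: (N, dim) entity embeddings of one visit; ent_type: (N,) type ids
        N, dim = ent_emb.shape
        # Apply the type-specific linear transformation to each entity
        proj = torch.stack([self.type_proj[t](e) for e, t in zip(ent_emb, ent_type.tolist())])
        # Build all ordered pairs (n <- m) within the visit and fuse them
        left = proj.unsqueeze(1).expand(N, N, dim)   # R_n broadcast over m
        right = proj.unsqueeze(0).expand(N, N, dim)  # R_m broadcast over n
        return self.pair_proj(torch.cat([left, right], dim=-1))  # (N, N, dim)
```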

Computational cost is one of the biggest challenges in encoding these heterogeneous representations. To alleviate this, we implement a hierarchical encoding scheme that combines encounter-level and entity-level attention, as shown in Figure 3:

Figure 3: Illustration of the multi-level attention scheme.

Specifically, for the entity-level context, we perform attention among the entities within the same visit:

[\mathbf{D}',\mathbf{V}_1',\cdots,\mathbf{V}_N']=\text{Entity-Attn}([\mathbf{D},\mathbf{V}_1,\cdots,\mathbf{V}_N])

In addition, the heterogeneous relation embeddings are introduced as a bias term that refines the attention map and as extra context for updating the entity embeddings. For the encounter-level context, we restrict attention to the demography tokens across all historical encounters:

[\mathbf{D}_1',\cdots,\mathbf{D}_S']=\text{Enc-Attn}([\mathbf{D}_1,\cdots,\mathbf{D}_S])
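The sketch below gives one possible PyTorch reading of the two attention levels, folding the relation embeddings in only as a scalar attention bias for brevity (the extra context path is omitted); the module and method names (`MultiLevelAttention`, `entity_level`, `encounter_level`, `rel_bias`) are our own assumptions rather than the paper's exact implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiLevelAttention(nn.Module):
    """Entity-level attention with a relation-aware bias, plus encounter-level
    attention restricted to demography tokens (illustrative sketch)."""
    def __init__(self, dim: int, num_heads: int = 4):
        super().__init__()
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.rel_bias = nn.Linear(dim, 1)  # maps R_{n<-m} to a scalar bias on the attention map
        self.enc_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def entity_level(self, x: torch.Tensor, rel: torch.Tensor) -> torch.Tensor:
        # x: (N, dim) tokens of one visit (demography token + entities)
        # rel: (N, N, dim) pairwise relation embeddings for this visit
        scores = self.q(x) @ self.k(x).T / x.shape[-1] ** 0.5   # (N, N) dot-product attention
        scores = scores + self.rel_bias(rel).squeeze(-1)        # heterogeneous relation as bias
        return F.softmax(scores, dim=-1) @ self.v(x)            # (N, dim) updated embeddings

    def encounter_level(self, demo: torch.Tensor) -> torch.Tensor:
        # demo: (S, dim) demography tokens of all historical visits
        out, _ = self.enc_attn(demo.unsqueeze(0), demo.unsqueeze(0), demo.unsqueeze(0))
        return out.squeeze(0)                                   # (S, dim) updated D tokens
```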

Pretraining Objectives

Most previous approaches adopt masked token prediction (MTP) for pretraining, which replaces actual tokens with [MASK] and performs single-label classification at each masked position. However, MTP is position-dependent and thus ill-suited to EHRs, where the medical entities within a visit are unordered. In light of this, we adapt MTP into a missing entity prediction (MEP) task, which is position-agnostic and heterogeneity-aware. The main idea is to have the model perform multi-label classification based on a single [MASK] token for each entity type, as shown in Figure 4.

Figure 4: Comparison between masked token prediction and missing entity prediction.
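A minimal sketch of how such a position-agnostic objective could be implemented, assuming one [MASK] token per entity type whose final representation scores the full vocabulary of that type with a multi-label (binary cross-entropy) loss; names such as `MissingEntityPrediction` and `vocab_sizes` are hypothetical:

```python
import torch
import torch.nn as nn

class MissingEntityPrediction(nn.Module):
    """One [MASK] per entity type -> multi-label classification over that type's vocabulary."""
    def __init__(self, dim: int, vocab_sizes: dict):
        super().__init__()
        # A separate classification head for each entity type (diagnosis, medication, ...)
        self.heads = nn.ModuleDict({t: nn.Linear(dim, v) for t, v in vocab_sizes.items()})
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, mask_repr: dict, targets: dict) -> torch.Tensor:
        # mask_repr[t]: (dim,) encoder output at the [MASK] position of type t
        # targets[t]:   (vocab_sizes[t],) multi-hot vector of the entities removed from the visit
        loss = torch.zeros(())
        for t, head in self.heads.items():
            logits = head(mask_repr[t])          # scores over the whole vocabulary of type t
            loss = loss + self.loss_fn(logits, targets[t].float())
        return loss
```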

In addition, we incorporate anomaly detection as an auxiliary pretraining task, which encourages the model to identify entities unrelated to a given context and to learn more robust representations. Specifically, we replace some entities with random entities of the same type to synthesize anomalous data, and a binary classifier predicts whether each entity is an anomaly.
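The corruption step might look like the sketch below: a fraction of entities is swapped with random entities of the same type, and a binary head is later trained to flag the replaced positions. Function and variable names (`corrupt_same_type`, `vocab_by_type`, `p`) are illustrative assumptions:

```python
import torch

def corrupt_same_type(entity_ids: torch.Tensor, entity_types: torch.Tensor,
                      vocab_by_type: dict, p: float = 0.15):
    """Replace a random subset of entities with random entities of the same type.

    Returns the corrupted ids and a 0/1 label per position (1 = anomaly)."""
    corrupted = entity_ids.clone()
    labels = torch.zeros_like(entity_ids, dtype=torch.float)
    replace = torch.rand(entity_ids.shape) < p                  # positions to corrupt
    for i in replace.nonzero(as_tuple=True)[0].tolist():
        vocab = vocab_by_type[int(entity_types[i])]             # candidate ids of the same type
        corrupted[i] = vocab[int(torch.randint(len(vocab), (1,)))]
        labels[i] = 1.0
    return corrupted, labels

# During pretraining, the encoder output at each position feeds a binary classifier
# (e.g., nn.Linear(dim, 1)) trained with BCEWithLogitsLoss against `labels`.
```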

Downstream Tasks

We evaluate HEART across 5 downstream tasks on 2 EHR datasets:

  • Datasets: MIMIC-III and eICU.
  • Downstream tasks: death prediction, prolonged length of stay (PLOS) prediction, readmission prediction, and next diagnosis prediction within 6/12 months.

Figure 5: Benchmarking results.