BAT-AKI: Biomarker-Aware Transformer for Acute Kidney Injury Prediction

This repository contains the official implementation of BAT-AKI, a biomarker-aware transformer framework for early prediction and prognosis of Acute Kidney Injury (AKI) using longitudinal Electronic Health Records (EHRs).

BAT-AKI integrates temporal dynamics, clinical semantics, ontology structures, and biomarker abnormality–aware pretraining to improve predictive performance, robustness, and interpretability across multiple AKI-related tasks.


1. Introduction

Acute Kidney Injury (AKI) is a common and life-threatening complication among hospitalized patients. Although Electronic Health Records (EHRs) contain rich longitudinal information, their irregular temporal structure and heterogeneous clinical concepts pose challenges for conventional machine learning models.

BAT-AKI is designed to address these challenges by:

  • Modeling long-range temporal dependencies using a Transformer backbone
  • Incorporating structured medical knowledge through ontology-aware embeddings
  • Enhancing sensitivity to biomarker abnormalities via tailored pretraining objectives
  • Supporting flexible downstream fine-tuning across multiple AKI-related tasks

2. Model Overview

2.1 Input Representation

Each patient admission is serialized into a token sequence consisting of:

  • Demographic tokens (e.g., age, sex, race)
  • Time-ordered clinical event tokens
  • Associated auxiliary sequences:
    • Time gaps (delta_t)
    • Segment identifiers (segment_ids)
    • Module identifiers (module_ids)
    • Ontology and minor-ontology identifiers
    • Biomarker abnormality flags (optional)

All sequences are padded or truncated to a fixed maximum length.
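As an illustration of this serialization step, the sketch below builds aligned token and auxiliary sequences for one admission and pads them to a fixed length. Function and key names (`serialize_admission`, `delta_t`, `segment_id`, `abnormal_flag`, `max_len`) are hypothetical, not the repository's actual API; the remaining auxiliary sequences (module, ontology, minor-ontology identifiers) would follow the same pattern.

```python
def pad_or_truncate(seq, max_len, pad_value=0):
    """Pad a sequence to max_len with pad_value, or truncate it."""
    return (seq + [pad_value] * (max_len - len(seq)))[:max_len]

def serialize_admission(demo_tokens, events, max_len=512):
    """Serialize one admission into aligned token/auxiliary sequences.

    `events` is a time-ordered list of dicts with keys such as
    token, delta_t, segment_id, and abnormal_flag.
    """
    tokens      = demo_tokens + [e["token"] for e in events]
    # Demographic tokens carry no time gap or abnormality flag; use 0.
    delta_t     = [0.0] * len(demo_tokens) + [e["delta_t"] for e in events]
    segment_ids = [0] * len(demo_tokens) + [e["segment_id"] for e in events]
    abnormal    = [0] * len(demo_tokens) + [e["abnormal_flag"] for e in events]

    return {
        "input_ids":   pad_or_truncate(tokens, max_len),
        "delta_t":     pad_or_truncate(delta_t, max_len, 0.0),
        "segment_ids": pad_or_truncate(segment_ids, max_len),
        "abnormal":    pad_or_truncate(abnormal, max_len),
        "attn_mask":   pad_or_truncate([1] * len(tokens), max_len),
    }
```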

2.2 Embedding Components

BAT-AKI supports a modular embedding design, including:

  • Token embeddings
  • Continuous time embeddings (sinusoidal)
  • Segment embeddings
  • Module embeddings (optional)
  • Ontology and minor-ontology embeddings (optional)
  • Semantic embeddings derived from prompt-based medical code representations (optional)

These embeddings are combined and passed to the Transformer encoder.
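A minimal sketch of how such embeddings could be summed and normalized is shown below, covering only the token, sinusoidal time, and segment components; the class name `BATEmbedding` and all dimensions are illustrative assumptions, not the repository's actual implementation.

```python
import torch
import torch.nn as nn

class BATEmbedding(nn.Module):
    """Illustrative combined embedding (module/ontology/semantic terms
    omitted); not the repository's actual class."""
    def __init__(self, vocab_size, n_segments, d_model=128):
        super().__init__()
        self.token = nn.Embedding(vocab_size, d_model)
        self.segment = nn.Embedding(n_segments, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.d_model = d_model

    def time_embedding(self, delta_t):
        # Continuous sinusoidal embedding of the time gap delta_t.
        half = self.d_model // 2
        freqs = torch.exp(
            -torch.arange(half, dtype=torch.float32)
            * (torch.log(torch.tensor(10000.0)) / half)
        )
        angles = delta_t.unsqueeze(-1) * freqs          # (B, L, half)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, input_ids, delta_t, segment_ids):
        x = (self.token(input_ids)
             + self.time_embedding(delta_t)
             + self.segment(segment_ids))
        return self.norm(x)
```

Summation keeps the sequence length and model width unchanged, so optional components (module, ontology, biomarker flags) can be added or dropped without altering the encoder input shape.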

2.3 Transformer Backbone

The backbone consists of multiple custom Transformer encoder layers with:

  • Multi-head self-attention
  • Feedforward networks
  • Residual connections and layer normalization

Attention weights are preserved to support downstream interpretability analyses.
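One way to realize such a layer while retaining its attention weights is sketched below; the class name, dimensions, and the `last_attn` attribute are assumptions for illustration, and the repository's custom layers may differ.

```python
import torch
import torch.nn as nn

class EncoderLayerWithAttn(nn.Module):
    """Sketch of a Transformer encoder layer that stores its attention
    weights for interpretability (hypothetical implementation)."""
    def __init__(self, d_model=128, n_heads=4, d_ff=256, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads,
                                          dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(),
                                nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)
        self.last_attn = None  # kept for downstream interpretability analyses

    def forward(self, x, key_padding_mask=None):
        out, weights = self.attn(x, x, x,
                                 key_padding_mask=key_padding_mask,
                                 need_weights=True,
                                 average_attn_weights=False)
        self.last_attn = weights.detach()   # (B, heads, L, L)
        x = self.norm1(x + self.drop(out))  # residual + layer norm
        x = self.norm2(x + self.drop(self.ff(x)))
        return x
```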


3. Repository Structure

BAT-AKI/
├── BAT_AKI_evaluate.py              # Main evaluation script
├── BAT-AKI pretraining.py           # Main pretraining script
│
├── configs/                         # Configuration files
│   └── pretrain_config.py           # Pretraining configuration
│
├── dataset/                         # Dataset handling modules
│   ├── downstream_dataset.py        # Downstream task dataset
│   ├── handle_matrix.py             # Matrix handling utilities
│   ├── load_data.py                 # Data loading functions
│   └── masked_ehr_dataset.py        # Masked EHR dataset implementation
│
├── model/                           # Model definitions
│   ├── classifier_model.py          # Classification model
│   └── mlm_model.py                 # Masked language model
│
├── piplines/                        # Pipeline modules
│   ├── __init__.py
│   ├── datasets.py                  # Dataset utilities
│   ├── finetune.py                  # Fine-tuning functions
│   ├── load_pretrained.py           # Pretrained model loading
│   ├── mask_eval.py                 # Mask evaluation
│   ├── metrics.py                   # Metrics calculation
│   ├── resources.py                 # Resource loading
│   ├── semantic.py                  # Semantic processing
│   └── train_mlm.py                 # MLM training
│
├── static files/                    # Static configuration files
│   ├── focus_tokens.json            # Focus tokens configuration
│   └── ignore_prefixes.json         # Ignore prefixes configuration
│
├── utils/                           # Utility functions
│   ├── build_selected_tokens.py     # Token selection builder
│   ├── evaluation.py                # Evaluation utilities
│   └── loss.py                      # Loss functions
│
├── baseline/                        # Baseline model implementations
│   ├── run_lstm.py                  # LSTM baseline
│   ├── run_retain.py                # RETAIN baseline
│   ├── run_transformer.py           # Transformer baseline
│   ├── run_XGboost.py               # XGBoost baseline
│   ├── run_preprocessing.py         # Data preprocessing for baselines
│   └── function/                    # Baseline utility functions
│       ├── __init__.py
│       ├── preprocessing.py         # Preprocessing utilities
│       ├── lstm_func.py             # LSTM helpers
│       ├── retain_func.py           # RETAIN helpers
│       ├── transformer_func.py      # Transformer helpers
│       ├── xgboost_func.py          # XGBoost helpers
│       └── preprecessing_parameters.jason  # Preprocessing parameters
│
└── ablation_study/                  # Ablation study
    ├── ablation_study.py            # Main ablation script
    ├── dataset/                     # Dataset modules for ablation
    ├── model/                       # Model modules for ablation
    └── utils/                       # Utility modules for ablation
