Build patient outcome prediction applications using Amazon HealthLake and Amazon SageMaker
Healthcare data can be challenging to work with, and AWS customers have been looking for ways to solve business challenges with data and machine learning (ML) techniques. Some of the data is structured, such as date of birth, gender, and marital status, but most of it is unstructured, such as diagnosis codes or physicians' notes. This data is designed for people to read, not for computers to process. The key challenges of using healthcare data are as follows:
How to effectively use both structured and unstructured data to get a complete view of the data
How to intuitively interpret the prediction results
With the rise of AI/ML technologies, solving these challenges has become possible.
One relevant use case is patient outcome prediction, which includes acute or chronic condition-triggered hospital visits or readmission predictions, disease progression predictions within a certain observation window, and so on. Healthcare providers, payors, and pharmaceutical companies can use prediction results to recommend early intervention, improve outreach communication, improve patient care experience, and reduce overall cost.
In this post, we show you an example of building a deep learning-based patient outcome prediction model. We build the model in Amazon SageMaker with MIMIC-III data stored in Amazon HealthLake and turn it into a lightweight application for visualization and interpretability. The prediction target for this example is mortality within 90 days after ICU discharge. You can modify the target variable to suit your needs.
Amazon HealthLake helps make sense of health data
HealthLake is a HIPAA-eligible service that enables healthcare providers, health insurance companies, and pharmaceutical companies to store, transform, query, and analyze health data at petabyte scale.
The dataset we exported through the HealthLake API is MIMIC-III (Johnson et al., 2016), a large, freely available database of deidentified health-related data associated with over 40,000 patients who stayed in critical care units. It includes information such as demographics, vital signs, lab test results, procedures, medications, caregiver notes, imaging reports, and mortality. We can't share the data in this post due to license restrictions, but you can visit MIMIC's official website to request data access.
HealthLake automatically extracts clinical entities from unstructured text, such as discharge notes, and links them to ICD-10-CM and RxNorm codes when the text is stored in HealthLake as a Fast Healthcare Interoperability Resources (FHIR) DocumentReference resource. The extracted entities are added back onto the DocumentReference resource as a FHIR extension. Text embedded in the DocumentReference should be base64-encoded. When building the predictive models, we can combine this extracted information with other structured data to get a more holistic view of the patient's medical history.
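A minimal sketch of wrapping a note in a FHIR R4 DocumentReference, assuming plain-text notes: the text goes into `content[0].attachment.data`, which FHIR defines as base64-encoded bytes. The patient ID and note text are hypothetical, and the resource is trimmed to a few elements. HealthLake adds the extracted-entity extension itself, so the sketch doesn't include one.

```python
import base64
import json


def make_document_reference(patient_id: str, note_text: str) -> dict:
    """Wrap a clinical note in a minimal FHIR R4 DocumentReference."""
    encoded = base64.b64encode(note_text.encode("utf-8")).decode("ascii")
    return {
        "resourceType": "DocumentReference",
        "status": "current",  # status is a required element in FHIR R4
        "subject": {"reference": f"Patient/{patient_id}"},
        "content": [
            {
                "attachment": {
                    "contentType": "text/plain",
                    "data": encoded,  # attachment.data holds base64-encoded bytes
                }
            }
        ],
    }


def decode_note(doc_ref: dict) -> str:
    """Recover the plain-text note from a DocumentReference."""
    data = doc_ref["content"][0]["attachment"]["data"]
    return base64.b64decode(data).decode("utf-8")


# Hypothetical example values.
doc = make_document_reference("example-patient-1", "Patient discharged home in stable condition.")
print(json.dumps(doc, indent=2))
```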
Overview of solution
The following architecture diagram illustrates the model training pipeline, inference pipeline, and information-rendering front end.
We use the HealthLake export API to export the normalized data to an Amazon Simple Storage Service (Amazon S3) bucket. Then we use an AWS Glue crawler to create a Data Catalog, which lets us run SQL queries against the exported data with Amazon Athena. The unstructured portion of the patient records is processed separately to extract the medical codes, which are then combined with the structured information. Next, we use a SageMaker notebook with TensorFlow containers to train a custom convolutional neural network (CNN) model. The model artifact is saved to an S3 bucket and later used to test model performance on unseen data. Finally, we run inference using SageMaker batch transform and save the results to Amazon S3. We also develop visualization components and serve them through Amazon API Gateway to improve the model's interpretability.
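As a sketch of the first step, the export can be started with the boto3 HealthLake client's `start_fhir_export_job` call. The data store ID, S3 URI, KMS key, and IAM role are placeholders you supply, and the job name is hypothetical. Passing the client in as an argument keeps the function testable without AWS credentials.

```python
import uuid


def start_healthlake_export(client, datastore_id: str, s3_uri: str, kms_key_id: str,
                            role_arn: str, job_name: str = "mimic-export") -> str:
    """Start a HealthLake FHIR export job and return its job ID.

    `client` is expected to be boto3.client("healthlake"); it is injected so the
    function can be exercised with a stub.
    """
    response = client.start_fhir_export_job(
        JobName=job_name,
        OutputDataConfig={"S3Configuration": {"S3Uri": s3_uri, "KmsKeyId": kms_key_id}},
        DatastoreId=datastore_id,
        DataAccessRoleArn=role_arn,
        ClientToken=str(uuid.uuid4()),  # idempotency token for the request
    )
    return response["JobId"]


# Usage (requires AWS credentials and real resource identifiers):
# import boto3
# hl = boto3.client("healthlake")
# job_id = start_healthlake_export(hl, "<datastore-id>", "s3://<bucket>/<prefix>/",
#                                  "<kms-key-arn>", "<role-arn>")
# hl.describe_fhir_export_job(DatastoreId="<datastore-id>", JobId=job_id)
```

Once the export job finishes, the Glue crawler can be pointed at the same S3 prefix.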
In this post, we walk you through the following steps:
Create training and testing datasets.
Use embedding techniques for a richer representation of the unstructured data.
Train the model.
Evaluate our results.
Visualize the results with custom UI components.
Create training and testing datasets
First, we create a binary target variable: mortality within 90 days after ICU discharge. A patient may have several records labeled 0 before a later record is labeled 1, and the same pattern applies to many other patient outcome targets. To prevent information leakage, we therefore split the data by patient_id into training, validation, and testing datasets, so that all of a patient's records land in exactly one of the three datasets.
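A minimal sketch of the labeling rule, assuming each stay has an ICU discharge time and an optional date of death (the function name is hypothetical). How to treat deaths recorded before discharge is a cohort decision; this sketch labels them 0 and leaves filtering to an upstream step.

```python
from datetime import datetime, timedelta
from typing import Optional


def died_within_window(discharge_time: datetime, death_time: Optional[datetime],
                       days: int = 90) -> int:
    """Binary target: 1 if death occurred within `days` after ICU discharge, else 0.

    A missing death_time means no recorded death, so the label is 0. Deaths before
    discharge also get 0 here; filter those stays upstream if they should not be
    part of the cohort.
    """
    if death_time is None:
        return 0
    return int(discharge_time <= death_time <= discharge_time + timedelta(days=days))


print(died_within_window(datetime(2020, 1, 1), datetime(2020, 3, 1)))
```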
We first set aside 20% of the patients for testing and treat their records as never seen by the algorithm. We then split the remaining 80% of patients into 80% for training and 20% for validation, which gives roughly 64% training, 16% validation, and 20% testing overall. We upload these datasets to our S3 bucket for later use.
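The patient-level split can be sketched in plain Python: shuffle the unique patient IDs, carve off the test and validation fractions, and route every record to the split that owns its patient. The `patient_id` key, fractions, and seed are assumptions; scikit-learn's `GroupShuffleSplit` offers the same grouping behavior if you prefer a library.

```python
import random
from typing import Dict, Iterable, List, Sequence, Set, Tuple


def split_patients(patient_ids: Iterable[str], test_frac: float = 0.2,
                   val_frac: float = 0.2, seed: int = 42) -> Tuple[Set[str], Set[str], Set[str]]:
    """Split unique patient IDs into (train, validation, test) sets."""
    unique_ids = sorted(set(patient_ids))  # dedupe, then sort so the shuffle is reproducible
    random.Random(seed).shuffle(unique_ids)
    n_test = round(len(unique_ids) * test_frac)
    test_ids = set(unique_ids[:n_test])
    remaining = unique_ids[n_test:]
    n_val = round(len(remaining) * val_frac)  # fraction of the remaining patients
    val_ids = set(remaining[:n_val])
    train_ids = set(remaining[n_val:])
    return train_ids, val_ids, test_ids


def partition_records(records: Sequence[dict], train_ids: Set[str], val_ids: Set[str],
                      test_ids: Set[str]) -> Dict[str, List[dict]]:
    """Route every record to the split that owns its patient_id (KeyError if unknown)."""
    owner = {pid: name
             for name, ids in (("train", train_ids), ("validation", val_ids), ("test", test_ids))
             for pid in ids}
    splits: Dict[str, List[dict]] = {"train": [], "validation": [], "test": []}
    for rec in records:
        splits[owner[rec["patient_id"]]].append(rec)
    return splits
```

Because the split happens on IDs rather than rows, a patient with many records can never appear in both training and testing.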
Use embedding techniques
Traditional ML methods may use frequency-based encodings such as term frequency-inverse document frequency (TF-IDF). In this post, we instead use embedding techniques, which give a richer representation of the unstructured data by learning the relationships between different medical codes.
We first take in sequences of medical codes and use skip-gram to learn the relationships between different codes. The learned embedding for each medical code is an n-dimensional vector (for example, n = 300) that characterizes that code. The individual dimensions usually don't have explicit meanings; however, similar medical concepts should be projected close to one another in the feature space. We learn such a vector for every code in the vocabulary during training and stack the vectors together into a matrix. We later use this embedding matrix to train a convolutional neural network model and test it on unseen data.
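The skip-gram training pairs and the final embedding matrix can be sketched as follows. The code strings, window size, and `<pad>` convention are assumptions. The step that actually learns vectors from these pairs (for example, gensim's `Word2Vec` with `sg=1`) is not shown, so `vectors` stands in for its output.

```python
from typing import Dict, List, Mapping, Sequence, Tuple

import numpy as np


def build_vocab(sequences: Sequence[Sequence[str]]) -> Dict[str, int]:
    """Map each distinct code to an integer index; index 0 is reserved for padding."""
    vocab = {"<pad>": 0}
    for seq in sequences:
        for code in seq:
            vocab.setdefault(code, len(vocab))
    return vocab


def skipgram_pairs(sequence: Sequence[str], window: int = 2) -> List[Tuple[str, str]]:
    """(target, context) pairs: each code is trained to predict its neighbors within `window`."""
    pairs = []
    for i, target in enumerate(sequence):
        for j in range(max(0, i - window), min(len(sequence), i + window + 1)):
            if j != i:
                pairs.append((target, sequence[j]))
    return pairs


def stack_embeddings(vocab: Mapping[str, int], vectors: Mapping[str, Sequence[float]],
                     dim: int) -> np.ndarray:
    """Stack learned vectors into a (vocab_size, dim) matrix; rows without a vector stay zero."""
    matrix = np.zeros((len(vocab), dim), dtype=np.float32)
    for code, idx in vocab.items():
        if code in vectors:
            matrix[idx] = vectors[code]
    return matrix


# Hypothetical code sequences for two stays.
stays = [["CODE_A", "CODE_B", "CODE_C"], ["CODE_B", "CODE_D"]]
vocab = build_vocab(stays)
print(skipgram_pairs(stays[0], window=1))
```

Row i of the matrix is the vector for the code with index i, which is exactly the lookup table an embedding layer needs.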
Train the model
We first define the structure of the neural network and then use a SageMaker-hosted TensorFlow training image to train our model. The layers are defined as follows:
Embedding layer – Takes in raw medical code sequences and converts each individual code into embeddings
Convolutional layer – Takes in the embeddings and convolves with tunable filters
Pooling layer – Aggregates the convolution outputs to reduce the size of the next layer's input
Dropout layer – Randomly sets a fraction of the previous layer's outputs to zero during training to reduce overfitting
Concatenation layer – Combines the processed information from the previous layers with structured information such as patient age or gender
Fully connected layer with sigmoid activation – Outputs the final prediction probabilities
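The layer stack above can be sketched with the Keras functional API. Filter count, kernel size, dropout rate, ReLU, global max pooling, the Adam optimizer, and binary cross-entropy loss are assumptions, since the post doesn't state them; freezing the pre-trained embedding (`trainable=False`) is also a choice you could reverse to fine-tune it.

```python
import numpy as np
import tensorflow as tf


def build_model(embedding_matrix: np.ndarray, seq_len: int, n_structured: int,
                n_filters: int = 64, kernel_size: int = 3,
                dropout_rate: float = 0.5) -> tf.keras.Model:
    vocab_size, embed_dim = embedding_matrix.shape

    codes_in = tf.keras.Input(shape=(seq_len,), dtype="int32", name="codes")
    struct_in = tf.keras.Input(shape=(n_structured,), name="structured")

    # Embedding layer: look up the pre-trained skip-gram vector for each code.
    embedding = tf.keras.layers.Embedding(vocab_size, embed_dim, trainable=False,
                                          name="code_embedding")
    x = embedding(codes_in)  # calling the layer builds it, so weights can be set next
    embedding.set_weights([embedding_matrix])

    # Convolutional layer: slide filters over windows of neighboring codes.
    x = tf.keras.layers.Conv1D(n_filters, kernel_size, activation="relu")(x)
    # Pooling layer: keep the strongest response of each filter across the sequence.
    x = tf.keras.layers.GlobalMaxPooling1D()(x)
    # Dropout layer: randomly zero activations during training only.
    x = tf.keras.layers.Dropout(dropout_rate)(x)
    # Concatenation layer: append structured features such as age or gender.
    x = tf.keras.layers.Concatenate()([x, struct_in])
    # Fully connected layer with sigmoid: probability of the positive class.
    output = tf.keras.layers.Dense(1, activation="sigmoid")(x)

    model = tf.keras.Model(inputs=[codes_in, struct_in], outputs=output)
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=[tf.keras.metrics.AUC(name="auc")])
    return model
```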
During training, we use these prediction probabilities to compute the loss and metrics that guide the training process. During testing, the probabilities are written to a file in Amazon S3.
When the training process is complete, we save the model artifact and upload it to an S3 bucket for later use.
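On the training side, the SageMaker TensorFlow container runs your script with a few conventions: each data channel is mounted at a path exposed as an `SM_CHANNEL_<NAME>` environment variable, and whatever the script writes to `SM_MODEL_DIR` is packaged as `model.tar.gz` and uploaded to S3 when the job ends. A minimal sketch of the argument handling, assuming channels named `train` and `validation` and hypothetical hyperparameters:

```python
import argparse
import os


def parse_args(argv=None):
    """Read hyperparameters and the paths SageMaker provides."""
    parser = argparse.ArgumentParser()
    # Hypothetical hyperparameters, passed through the estimator's `hyperparameters`.
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=64)
    # Paths SageMaker exposes as environment variables (local fallbacks for testing).
    parser.add_argument("--sm-model-dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    parser.add_argument("--train", default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
    parser.add_argument("--validation",
                        default=os.environ.get("SM_CHANNEL_VALIDATION", "/opt/ml/input/data/validation"))
    # Ignore extra arguments, such as the --model_dir the TensorFlow estimator adds.
    args, _unknown = parser.parse_known_args(argv)
    return args


def saved_model_path(model_dir: str, version: int = 1) -> str:
    """TensorFlow Serving loads SavedModels from numbered version subdirectories."""
    return os.path.join(model_dir, str(version))


if __name__ == "__main__":
    args = parse_args()
    # Train the model here, then export a SavedModel to saved_model_path(args.sm_model_dir).
    print(args)
```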
The following visualization shows the receiver operating characteristic (ROC) curve and the classification report on the test data.
The ROC curve plots the model's true positive rate against its false positive rate at different classification thresholds. The area under the ROC curve (AUC) is 0.82; AUC measures the model's ability to separate the two target classes, where 0.5 corresponds to random guessing and 1.0 to perfect separation. The classification report gives an overview of the model's precision, recall, and F1 score for each class. The weighted average F1 score for the model is 0.74.
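To make the two headline numbers concrete, here is a dependency-free sketch of how they are computed. AUC equals the probability that a randomly chosen positive case scores higher than a randomly chosen negative one (ties count half), and the weighted F1 averages per-class F1 scores weighted by each class's support. The toy labels, scores, and 0.5 threshold are assumptions for illustration; in practice you would likely use scikit-learn's `roc_auc_score` and `classification_report`.

```python
from typing import Sequence


def roc_auc(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """Fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative example")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def weighted_f1(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Per-class F1, averaged with weights equal to each class's share of y_true."""
    total = len(y_true)
    score = 0.0
    for cls in sorted(set(y_true)):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != cls and p == cls)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p != cls)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        score += f1 * (tp + fn) / total  # tp + fn is the class support
    return score


y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
y_pred = [int(s >= 0.5) for s in scores]  # assumed 0.5 decision threshold
print(roc_auc(y_true, scores), weighted_f1(y_true, y_pred))
```

Note that AUC uses the raw scores and is threshold-independent, whereas F1 depends on the threshold chosen to turn probabilities into labels.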
Visualization with custom UI components
We can visualize the prediction results in many ways. In this post, we demonstrate only how to render SHAP (SHapley Additive exPlanations) values to improve the model's interpretability. SHAP is a game-theoretic approach to explaining the output of an ML model. The visualization shows each prediction's contributing factors, so you can intuitively understand which features push the predicted probability higher (towards 1) or lower (towards 0) relative to the base value.
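SHAP values are Shapley values from cooperative game theory. For a handful of features they can be computed exactly by enumerating feature subsets, which makes their key property visible: the base value plus the sum of a prediction's SHAP values equals the prediction. This sketch uses a single baseline vector to stand in for "missing" features, which is a simplification; the shap library averages over a background dataset and uses faster approximations such as `KernelExplainer` or `DeepExplainer`.

```python
from itertools import combinations
from math import factorial
from typing import Callable, List, Sequence


def shapley_values(f: Callable[[List[float]], float], x: Sequence[float],
                   baseline: Sequence[float]) -> List[float]:
    """Exact Shapley values; features outside a coalition take their baseline value."""
    n = len(x)

    def value(coalition: set) -> float:
        z = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return f(z)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for k in range(n):  # coalition sizes 0 .. n-1 (excluding feature i)
            weight = factorial(k) * factorial(n - k - 1) / factorial(n)
            for subset in combinations(others, k):
                s = set(subset)
                phi[i] += weight * (value(s | {i}) - value(s))  # marginal contribution of i
    return phi


def model(z: List[float]) -> float:
    # Toy model with an interaction term, standing in for a trained network.
    return 2 * z[0] + 3 * z[1] - z[2] + z[0] * z[1]


x, base = [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]
phi = shapley_values(model, x, base)
print(phi, model(base) + sum(phi), model(x))
```

The cost grows as 2^n, so exact enumeration is only practical for a few features.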
We first define an HTML template and add visualization components to it. We then upload the HTML file to an S3 bucket and set up an AWS Lambda function to retrieve its content, which API Gateway returns to the browser as a webpage.
Set up HTML templates
We define an HTML template containing empty placeholder blocks that the visualization components fill in later.
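A minimal sketch of such a template using Python's `string.Template`, assuming one placeholder for the SHAP plot, one for model metrics, and one in the head for any scripts a plot needs. The placeholder names and element IDs are hypothetical. The title is HTML-escaped, while the blocks are inserted as trusted HTML fragments.

```python
from html import escape
from string import Template

PAGE = Template("""<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8">
  <title>$title</title>
  $head_block
</head>
<body>
  <h1>$title</h1>
  <div id="shap-plot">$shap_block</div>
  <div id="model-metrics">$metrics_block</div>
</body>
</html>
""")


def render_page(title: str, shap_block: str = "", metrics_block: str = "",
                head_block: str = "") -> str:
    """Fill every placeholder; substitute() raises KeyError if one is missing."""
    return PAGE.substitute(
        title=escape(title),
        head_block=head_block,
        shap_block=shap_block,
        metrics_block=metrics_block,
    )


page = render_page("Patient outcome explanation", shap_block="<p>drivers go here</p>")
print(page)
```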
Create visualization components
We then create a SHAP value visualization for each prediction and insert it into the template.
This visualization gives an intuitive explanation of the drivers behind a prediction. In the following visualization, features shown in red push the patient's predicted probability higher (towards 1), and features shown in blue push it lower (towards 0). As a result, this patient has a predicted probability of 0.34, compared to the training cohort base value of 0.4481. This patient therefore has a lower predicted risk of the target outcome than the training cohort baseline.
Create a Lambda function to parse the HTML file
The Lambda function can be very simple.
Its purpose is to retrieve the content that needs to be rendered from Amazon S3 without making the S3 resources public, and return that content to API Gateway.
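A minimal sketch of such a handler, assuming a non-proxy API Gateway integration (so the function returns the HTML string itself) and hypothetical `REPORT_BUCKET` and `REPORT_KEY` environment variables naming the file. The function's execution role needs `s3:GetObject` on that object. boto3 is imported lazily and the client can be injected, so the logic can be exercised without AWS access.

```python
import os

# Hypothetical configuration; set these as Lambda environment variables.
BUCKET = os.environ.get("REPORT_BUCKET", "my-report-bucket")
KEY = os.environ.get("REPORT_KEY", "reports/index.html")


def _default_s3_client():
    import boto3  # available in the Lambda Python runtime

    return boto3.client("s3")


def lambda_handler(event, context, s3_client=None):
    """Fetch the rendered HTML from a private bucket and return it as a string."""
    client = s3_client if s3_client is not None else _default_s3_client()
    obj = client.get_object(Bucket=BUCKET, Key=KEY)
    return obj["Body"].read().decode("utf-8")
```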
Create an API Gateway to render the HTML file
We can use the AWS Cloud Development Kit (AWS CDK) to automate the API Gateway and Lambda setup.
In that setup, the integration response settings (integration_responses in the CDK) ensure that API Gateway returns the Lambda output as HTML, so browsers render it correctly. When the API is deployed, you get an invoke URL. Paste this URL into a web browser to check the visualization result.
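A minimal AWS CDK (v2, Python) sketch of that setup, assuming a Lambda function like the one described earlier that returns the HTML string. The bucket name, object key, asset folder, and handler path are hypothetical. Setting `proxy=False` is what makes `integration_responses` take effect: the template `$input.path('$')` unwraps the string returned by Lambda, and the `Content-Type` mapping tells the browser to treat it as HTML.

```python
from aws_cdk import App, Stack
from aws_cdk import aws_apigateway as apigw
from aws_cdk import aws_lambda as _lambda
from aws_cdk import aws_s3 as s3
from constructs import Construct


class ReportApiStack(Stack):
    """REST API that serves the Lambda output as an HTML page."""

    def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None:
        super().__init__(scope, construct_id, **kwargs)

        # Hypothetical existing bucket that holds the rendered HTML file.
        bucket = s3.Bucket.from_bucket_name(self, "ReportBucket", "my-report-bucket")

        handler = _lambda.Function(
            self,
            "RenderHtmlFunction",
            runtime=_lambda.Runtime.PYTHON_3_11,
            handler="index.lambda_handler",           # hypothetical module and function
            code=_lambda.Code.from_asset("lambda"),  # hypothetical local source folder
            environment={"REPORT_BUCKET": "my-report-bucket", "REPORT_KEY": "reports/index.html"},
        )
        bucket.grant_read(handler)  # read-only access; the bucket itself stays private

        integration = apigw.LambdaIntegration(
            handler,
            proxy=False,
            integration_responses=[
                apigw.IntegrationResponse(
                    status_code="200",
                    # Replace the default JSON content type so browsers render HTML.
                    response_parameters={"method.response.header.Content-Type": "'text/html'"},
                    # Unwrap the JSON string returned by Lambda into raw HTML.
                    response_templates={"text/html": "$input.path('$')"},
                )
            ],
        )

        api = apigw.RestApi(self, "ReportApi")
        api.root.add_method(
            "GET",
            integration,
            method_responses=[
                apigw.MethodResponse(
                    status_code="200",
                    response_parameters={"method.response.header.Content-Type": True},
                )
            ],
        )


app = App()
ReportApiStack(app, "ReportApiStack")
app.synth()
```

Running `cdk deploy` on this app prints the invoke URL of the `ReportApi` stage.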
Conclusion
In this post, we demonstrated how to use Amazon SageMaker and Amazon HealthLake to build a deep learning model for a healthcare and life sciences challenge and how to interpret the results with visualization techniques. By predicting patient outcomes, hospitals can provide appropriate intervention and better care for patients. We demonstrated this solution for mortality prediction within 90 days after ICU discharge, and you can apply the same method to other patient outcome prediction use cases.
HealthLake makes it easy to work with health data and extract relevant data points from unstructured clinical texts. Deep learning modeling techniques give us options to build more accurate models with less feature engineering effort, and AWS technologies make it possible to visualize model interpretations with a lightweight front-end solution.
To learn more about HealthLake, see Amazon HealthLake resources and Making sense of your health data with Amazon HealthLake. For a hands-on tutorial, visit our Amazon HealthLake workshop. For more examples using HealthLake and population health, see Population health applications with Amazon HealthLake – Part 1: Analytics and monitoring using Amazon QuickSight.
Johnson AEW, Pollard TJ, Shen L, Lehman L, Feng M, Ghassemi M, Moody B, Szolovits P, Celi LA, Mark RG. MIMIC-III, a freely accessible critical care database. Scientific Data (2016). DOI: 10.1038/sdata.2016.35. Available from: http://www.nature.com/articles/sdata201635
About the Authors
Shuai Cao is a Data Scientist in the Professional Services team at Amazon Web Services. His expertise is building machine learning applications at scale for healthcare and life sciences customers. Outside of work, he loves traveling around the world and playing dozens of different instruments.
Garin Kessler is a Senior Data Science Manager at Amazon Web Services, where he leads teams of data scientists and application architects to deliver bespoke machine learning applications for customers. Outside of AWS, he lectures on machine learning and neural language models at Georgetown. When not working, he enjoys listening to (and making) music of questionable quality with friends and family.
Kartik Kannapur is a Data Scientist with AWS Professional Services. He holds a master’s degree in Applied Mathematics and Statistics from Stony Brook University and focuses on using machine learning to solve customer business problems.