Modern Hopfield Networks and Attention for Immune Repertoire Classification

Michael Widrich, Bernhard Schäfl, Hubert Ramsauer, Milena Pavlović, Lukas Gruber, Markus Holzleitner, Johannes Brandstetter, Geir Kjetil Sandve, Victor Greiff, Sepp Hochreiter, Günter Klambauer

TL;DR

This paper tackles immune repertoire classification framed as a massive multiple-instance learning problem with bags containing hundreds of thousands of sequences and extremely low witness rates. It introduces DeepRC, which combines 1D CNN motif recognition with transformer-like attention derived from modern Hopfield networks to pool over vast instance sets, enabling motif discovery and interpretable predictions. Across simulated, LSTM-generated, real-implanted, and CMV datasets, DeepRC achieves state-of-the-art predictive performance and provides interpretable motif attributions via attention weights and Integrated Gradients. The approach scales to enormous repertoires and offers a pathway toward improved diagnostics and vaccine-related insights by identifying discriminative receptor motifs.

Abstract

A central mechanism in machine learning is to identify, store, and recognize patterns. How to learn, access, and retrieve such patterns is crucial in Hopfield networks and the more recent transformer architectures. We show that the attention mechanism of transformer architectures is actually the update rule of modern Hopfield networks that can store exponentially many patterns. We exploit this high storage capacity of modern Hopfield networks to solve a challenging multiple instance learning (MIL) problem in computational biology: immune repertoire classification. Accurate and interpretable machine learning methods solving this problem could pave the way towards new vaccines and therapies, which is currently a very relevant research topic intensified by the COVID-19 crisis. Immune repertoire classification based on the vast number of immunosequences of an individual is a MIL problem with an unprecedentedly massive number of instances, two orders of magnitude larger than currently considered problems, and with an extremely low witness rate. In this work, we present our novel method DeepRC that integrates transformer-like attention, or equivalently modern Hopfield networks, into deep learning architectures for massive MIL such as immune repertoire classification. We demonstrate that DeepRC outperforms all other methods with respect to predictive performance on large-scale experiments, including simulated and real-world virus infection data, and enables the extraction of sequence motifs that are connected to a given disease class. Source code and datasets: https://github.com/ml-jku/DeepRC
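The correspondence stated in the abstract can be illustrated with a minimal sketch, assuming the update rule takes the form $\xi^{\text{new}} = X\,\mathrm{softmax}(\beta X^{\top}\xi)$ with the stored patterns as the columns of $X$. The function names, the random patterns, and the chosen values of `d`, `n`, and `beta` below are illustrative assumptions, not taken from the DeepRC code base. Read as attention, the state $\xi$ plays the role of the query and the stored patterns serve as both keys and values.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1D array."""
    shifted = x - np.max(x)
    e = np.exp(shifted)
    return e / e.sum()


def hopfield_update(patterns, query, beta):
    """One update step: xi_new = X softmax(beta * X^T xi).

    `patterns` has shape (n, d), one stored pattern per row, so
    `patterns @ query` corresponds to X^T xi in column notation.
    Returns the retrieved pattern and the softmax (attention) weights.
    """
    weights = softmax(beta * (patterns @ query))
    return weights @ patterns, weights


# Illustrative data (assumption): random patterns and a noisy query.
rng = np.random.default_rng(0)
d, n = 128, 1000
patterns = rng.standard_normal((n, d))
query = patterns[42] + 0.1 * rng.standard_normal(d)

retrieved, weights = hopfield_update(patterns, query, beta=1.0)
```

In this toy setting the weights concentrate on the stored pattern closest to the query, which is the retrieval behaviour that attention-pooling relies on when it selects a few relevant sequences from a very large bag.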

Paper Structure

This paper contains 37 sections, 1 theorem, 6 equations, 6 figures, 18 tables.

Key Result

Theorem 1

We assume a failure probability $0<p\leqslant 1$ and randomly chosen patterns on the sphere with radius $M=K\sqrt{d-1}$. We define $a := \frac{2}{d-1}\left(1 + \ln(2\beta K^2 p (d-1))\right)$, $b := \frac{2K^2\beta}{5}$, and $c = \frac{b}{W_0(\exp(a + \ln(b)))}$, where $W_0$ is the upper branch of the Lambert $W$ function. Examples are $c\geq 3.1546$ for $\beta=1$, $K=3$, $d=20$, and $p=0.001$ ($a + \ln(b)>1.27$) and $c$ …
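The first numerical example in the theorem can be reproduced with a short sketch, with $W_0$ taken as the principal (upper) branch of the Lambert $W$ function. The helper names `lambert_w0` and `storage_constant` are hypothetical, and the Newton iteration is only a convenient stand-in for a library routine; it is not part of the paper.

```python
import math


def lambert_w0(x, tol=1e-12, max_iter=100):
    """Principal branch W_0(x) for x >= 0, via Newton's method on w*exp(w) = x."""
    if x < 0:
        raise ValueError("this sketch only handles x >= 0")
    w = math.log1p(x)  # starting guess; adequate for x >= 0
    for _ in range(max_iter):
        ew = math.exp(w)
        step = (w * ew - x) / (ew * (w + 1.0))
        w -= step
        if abs(step) < tol:
            break
    return w


def storage_constant(beta, K, d, p):
    """Return (a, b, c) as defined in the theorem statement."""
    a = 2.0 / (d - 1) * (1.0 + math.log(2.0 * beta * K**2 * p * (d - 1)))
    b = 2.0 * K**2 * beta / 5.0
    c = b / lambert_w0(math.exp(a + math.log(b)))
    return a, b, c


# Parameters of the first example: beta = 1, K = 3, d = 20, p = 0.001.
a, b, c = storage_constant(beta=1.0, K=3.0, d=20, p=0.001)
print(f"a + ln(b) = {a + math.log(b):.4f}, c = {c:.4f}")
```

For these parameters the sketch gives $a + \ln(b)$ slightly above $1.27$ and $c$ close to $3.1546$, in line with the values quoted in the theorem.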

Figures (6)

  • Figure 1: Schematic representation of the DeepRC approach. a) An immune repertoire $X$ is represented by large bags of immune receptor sequences (colored). A neural network (NN) $h$ serves to recognize patterns in each of the sequences $s_i$ and maps them to sequence-representations $\bm{z}_i$. A pooling function $f$ is used to obtain a repertoire-representation $\bm{z}$ for the input object. Finally, an output network $o$ predicts the class label $\hat{y}$. b) DeepRC uses stacked 1D convolutions for a parameterized function $h$ due to their computational efficiency. Potentially, millions of sequences have to be processed for each input object. In principle, recurrent neural networks (RNNs), such as LSTMs (Hochreiter et al., 2007), or transformer networks (Vaswani et al., 2017) could also be used, but they are currently too computationally costly. c) Attention-pooling is used to obtain a repertoire-representation $\bm{z}$ for each input object, where DeepRC uses weighted averages of sequence-representations. The weights are determined by an update rule of modern Hopfield networks, which can retrieve exponentially many patterns.
  • Figure 2: DeepRC architecture as used in Table \ref{tab:results_full} with sub-networks $h_1$, $h_2$, and $o$. $d_l$ indicates the sequence length.
  • Figure A1: We use 3 input features with values in range $[0,1]$ to encode the relative position of each AA in a sequence with respect to the sequence. "feature 1" encodes if an AA is close to the sequence start, "feature 2" to the sequence center, and "feature 3" to the sequence end. For every position in the sequence, the values of all three features sum up to $1$.
  • Figure A2: Distribution of AAs and k-mers in real-world CMV dataset and LSTM-generated data. Left: Histograms of real-world data. Right: Histograms of LSTM-generated data. a) Frequency of AAs in sequences of the CMV dataset. b) Frequency of AAs in sequences of the LSTM-generated datasets. c) Frequency of top 200 4-mers in sequences of the CMV dataset. d) Frequency of top 200 4-mers in sequences of the LSTM-generated datasets. e) Frequency of top 20 4-mers in sequences of the CMV dataset. f) Frequency of top 20 4-mers in sequences of the LSTM-generated datasets. Overall the distributions of AAs and 4-mers are similar in both datasets.
  • Figure A3: Integrated Gradients applied to input sequences of positive class repertoires. Three sequences with the highest contributions to the prediction of their respective repertoires are shown. a) Input sequence taken from "simulated immunosequencing data" with implanted motif SZdZdN and motif implantation probability $0.1\%$. The DeepRC model reacts to the S and N at the 5th and 8th sequence position, thereby identifying the implanted motif in this sequence. b) and c) Input sequence taken from "real-world data with implanted signals" with implanted motifs {LrDrRr; CrArS; GrL-N} and motif implantation probability $0.1\%$. The DeepRC model reacts to the fully implanted motif CAS (b) and to the partly implanted motif AAs C and A at the 5th and 7th sequence position (c), thereby identifying the implanted motif in the sequences. Wildcard characters in implanted motifs are indicated by Z, characters with $50\%$ probability of being removed by d, and gap locations of random lengths of $\{0;1;2\}$ by -. Larger characters in the sequences indicate higher contribution, with blue indicating positive contribution and red indicating negative contribution towards the prediction of the diseased class.
  • ...and 1 more figure
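The position encoding described in the Figure A1 caption can be sketched as follows. The caption only states that the three features lie in $[0,1]$, mark closeness to the sequence start, center, and end, and sum to $1$ at every position. The piecewise-linear (triangular) shape used here is an assumption, and `position_features` is a hypothetical helper name rather than a function from the DeepRC code base.

```python
import numpy as np


def position_features(seq_len):
    """Return an array of shape (seq_len, 3) with start/center/end features.

    Assumption: features vary linearly with the relative position t in [0, 1].
    A sequence of length 1 is treated as lying at the center.
    """
    if seq_len < 1:
        raise ValueError("seq_len must be at least 1")
    t = np.linspace(0.0, 1.0, seq_len) if seq_len > 1 else np.array([0.5])
    start = np.clip(1.0 - 2.0 * t, 0.0, 1.0)   # 1 at the first position, 0 from the center on
    end = np.clip(2.0 * t - 1.0, 0.0, 1.0)     # 0 up to the center, 1 at the last position
    center = 1.0 - start - end                 # remainder, so each row sums to 1
    return np.stack([start, center, end], axis=1)


feats = position_features(7)
```

Each row of `feats` could be concatenated with the one-hot encoding of the corresponding amino acid before the sequence is passed to the 1D convolutions, which gives the convolution kernels access to where in the sequence a motif occurs.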

Theorems & Definitions (2)

  • Definition 1: Pattern Stored and Retrieved
  • Theorem 1