Protein Sequence Prediction of Target Structure via Wave Function Embedding and Multi-Scale Geometric Flash-Attention
proteusAI is a transformer model that predicts the optimal protein sequence given a target backbone structure, represented by the alpha-carbon (Cα) coordinates of its residues.
Many protein sequence prediction models use contact maps, distance metrics, and/or dihedral angles of the protein structure as input features. However, these features fail to encode local AND global interactions within the structure.
To achieve a greater inductive bias, we propose a method of encoding the three-dimensional coordinates of each token, i.e. each Cα atom, as a superposition of wave functions centered on every other Cα:

ψ_k(r_i) = Σ_j sin(k · ‖r_i − r_j‖),  φ_k(r_i) = Σ_j cos(k · ‖r_i − r_j‖)

where r_i is the coordinate of the i-th Cα, ‖r_i − r_j‖ is the Euclidean distance between residues i and j, and k is the wavenumber (k = 2π/λ for wavelength λ).
Moreover, we can define multiple wave functions, each with a different k, and thus a different wavelength. In this case, wave functions corresponding to small k (long wavelengths) encode coarse, global structure, while wave functions with large k (short wavelengths) resolve fine, local geometry. Evaluating the sine and cosine superpositions for each of the d_model//2 wavenumbers k_m yields the full embedding of residue i:

E_i[2m] = Σ_j sin(k_m · ‖r_i − r_j‖),  E_i[2m+1] = Σ_j cos(k_m · ‖r_i − r_j‖)

where m = 0, …, d_model//2 − 1 and the wavenumbers k_m sweep from long to short wavelengths.
Note the similarity between this formula and the traditional positional encoding formula:

PE(pos, 2i) = sin(pos / 10000^(2i/d_model)),  PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
This is because the wave function embedding process can be seen as a generalization of positional encoding for irregularly spaced tokens in arbitrary dimensions.
This method offers several advantages over existing approaches. For one, it yields a rotationally and translationally invariant representation of the protein, since the wave functions depend only on relative distances. Additionally, by using multiple wave functions of differing granularity (with different k), the model captures a wide range of representations of the same structure, in which both local and global interactions are encoded. While computing the superposed wave function outputs for each Cα, and for each of the d_model//2 wave functions, scales as O(N² · d_model) in the number of residues N, the computation consists entirely of simple, batched tensor operations and is performed only once per forward pass.
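A minimal PyTorch sketch of this featurization follows. The function name, the geometric wavelength schedule, and the lam_min/lam_max bounds are illustrative assumptions, not values taken from the text:

```python
import torch

def wavefunction_embedding(coords, d_model=16, lam_min=1.0, lam_max=100.0):
    """Superpose sin/cos wave functions over pairwise Ca distances.

    coords: (B, N, 3) Ca coordinates. Returns (B, N, d_model).
    The wavelength range and geometric schedule are assumptions.
    """
    m = d_model // 2
    # geometric wavelength schedule, analogous to the 10000^(2i/d_model)
    # schedule of standard sinusoidal positional encoding
    lam = lam_min * (lam_max / lam_min) ** (torch.arange(m) / max(m - 1, 1))
    k = 2 * torch.pi / lam                         # wavenumbers, shape (m,)
    d = torch.cdist(coords, coords)                # pairwise distances, (B, N, N)
    phase = k.view(1, 1, 1, m) * d.unsqueeze(-1)   # (B, N, N, m)
    # sum over all residues j -> one superposed value per residue i
    return torch.cat([phase.sin().sum(dim=2), phase.cos().sum(dim=2)], dim=-1)
```

Because only pairwise distances enter the computation, rigidly rotating and translating the input coordinates leaves the embedding unchanged.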
Additionally, the Wavefunction Embedding module implements an extremely efficient backward pass, achieving a 10X speedup and a 1000X memory reduction WITHOUT any hardware-specific optimizations, written fully in PyTorch. This is achieved by storing, during the forward pass, the per-token sums of the cosine terms and the sums of the sine terms, each of which is only batch × N × d_model//2. This avoids both storing large intermediate tensors and recomputation, and is accomplished by analytically simplifying the gradient computation, dropping the complexity of the backward pass from the O(N² · d_model) of naive autograd to O(N · d_model).
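The analytic simplification itself is not spelled out here, so the sketch below does not reproduce it. Instead it shows the general recipe such a module follows: a custom torch.autograd.Function that keeps only small tensors between the passes and rebuilds the coordinate gradient analytically, one wavenumber at a time, so the (B, N, N, m) phase tensor is never held for autograd (the class name and the per-wavenumber loop are this sketch's own choices, and it trades recomputation for memory rather than eliminating both):

```python
import torch

class WaveEmbed(torch.autograd.Function):
    # Memory-lean wave function embedding: forward saves only coords and k,
    # never the (B, N, N, m) phase tensor; backward derives the gradient
    # with respect to coords analytically, one wavenumber slice at a time.

    @staticmethod
    def forward(ctx, coords, k):
        # coords: (B, N, 3); k: (m,) wavenumbers; output: (B, N, 2m)
        d = torch.cdist(coords, coords)
        phase = k.view(1, 1, 1, -1) * d.unsqueeze(-1)
        ctx.save_for_backward(coords, k)              # small tensors only
        return torch.cat([phase.sin().sum(2), phase.cos().sum(2)], -1)

    @staticmethod
    def backward(ctx, grad_out):
        coords, k = ctx.saved_tensors
        m = k.numel()
        g_sin, g_cos = grad_out[..., :m], grad_out[..., m:]
        diff = coords.unsqueeze(2) - coords.unsqueeze(1)   # (B, N, N, 3)
        d = diff.norm(dim=-1).clamp_min(1e-9)
        u = diff / d.unsqueeze(-1)                         # unit vectors; 0 on diagonal
        grad_coords = torch.zeros_like(coords)
        for t in range(m):                                 # one (B, N, N) slice at a time
            phase = k[t] * d
            # d(sum_j sin(k d_ij))/dr_i = sum_j k cos(k d_ij) u_ij; since d_ij
            # also feeds token j's output, each pair picks up (g[i] + g[j])
            w = k[t] * ((g_sin[..., t].unsqueeze(2) + g_sin[..., t].unsqueeze(1)) * phase.cos()
                        - (g_cos[..., t].unsqueeze(2) + g_cos[..., t].unsqueeze(1)) * phase.sin())
            grad_coords = grad_coords + (w.unsqueeze(-1) * u).sum(2)
        return grad_coords, None                           # k treated as fixed here
```

The analytic backward can be checked against naive autograd on a differentiable reimplementation of the same forward.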
The wavefunction features are combined with ESM2 amino acid embeddings, which contain rich evolutionary and structural information, by summing the two feature sets. The resulting features align well with the rest of the model, which is a stack of traditional Encoder layers from the original transformer paper. While transformers are known for their ability to perform long-range attention, it is still beneficial to inject a spatial bias into the model, so that extremely distant residues do not influence each other too strongly. To this end, we introduce Geometric Attention, a novel multi-head attention (MHA) mechanism in which the attention logits are scaled by Radial Basis Functions (RBFs). Each head h of the MHA module is assigned a specific spread σ_h, and its logits become:

logits_ij = (q_i · k_j / √d_head) · exp(−‖r_i − r_j‖² / (2σ_h²))

where q_i and k_j are the query and key vectors of residues i and j, ‖r_i − r_j‖ is the Euclidean distance between their Cα atoms, and σ_h is the learnable spread of head h.
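A shape-level PyTorch sketch of this attention step follows (the function name, tensor layout, and the exact way the RBF multiplies the logits are this sketch's assumptions; the real module is a fused kernel):

```python
import math
import torch

def geometric_attention(q, k, v, coords, sigma):
    """One multi-head geometric attention step.

    q, k, v: (B, H, N, d_head); coords: (B, N, 3); sigma: (H,) per-head spreads.
    """
    dist = torch.cdist(coords, coords)                          # (B, N, N)
    # per-head RBF over Ca distances; spread sigma_h sets the head's scale
    rbf = torch.exp(-dist.unsqueeze(1) ** 2 / (2 * sigma.view(1, -1, 1, 1) ** 2))
    logits = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])   # (B, H, N, N)
    weights = torch.softmax(logits * rbf, dim=-1)               # distance-modulated
    return weights @ v
```

Heads with a small sigma attend almost exclusively to spatially close residues, while heads with a large sigma behave like ordinary global attention.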
To reduce the memory footprint and speed up the computation, the Multi-Scale Geometric Attention module is fused into a single GPU kernel using Triton, taking inspiration from the Flash Attention 2 paper (https://arxiv.org/abs/2307.08691). A custom backward pass is also implemented so that not only Q, K, and V are learnable, but also the spread of each attention head. Thus, each head learns at what scale it should evaluate the RBFs, and how to weigh pairs of tokens. This design aligns well with the previously described featurization process, since the features themselves correspond to different representations of the same structure at distinct scales via the learned wavelength ranges (λ_m = 2π/k_m), which the per-head spreads σ_h mirror on the attention side.
This multi-scale geometric attention mechanism can be seen as a generalization of graph neural networks (GNNs), since the RBF-scaled attention creates soft, continuous edges between token pairs, defined at multiple scales.
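A toy example of this view, with made-up spreads: each head's RBF over Cα distances behaves as a soft adjacency matrix, and hard-thresholding it recovers a conventional discrete contact graph at that head's spatial scale:

```python
import torch

# The per-head RBF is a soft, continuous adjacency matrix over residues;
# thresholding it yields a discrete GNN-style contact graph per scale.
# (The spreads and threshold below are illustrative, not from the paper.)
coords = torch.randn(8, 3) * 5
dist = torch.cdist(coords, coords)
sigma = torch.tensor([2.0, 8.0])                  # a local and a global scale
soft_adj = torch.exp(-dist.unsqueeze(0) ** 2 / (2 * sigma.view(-1, 1, 1) ** 2))
hard_adj = (soft_adj > 0.5).float()               # discrete edges per scale
```

The smaller spread always yields the sparser graph, since its RBF decays faster with distance.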
After passing through all Encoder layers, the hidden states pass through a linear layer that converts the d_model-dimensional representation of each residue into logits over the 20 canonical amino acids.