I find all the writeups of self-attention very confusing. They inevitably explain small bits in enormous detail and then leave giant areas unexplained, and they're usually replete with jargon that they don't bother to clarify. I feel like most of these are written by people who are confused themselves, clarifying the parts they don't understand as if that must be what everybody else is confused about too, while glossing over everything they already know.
Fundamentally, I have never understood why a conventional neural network cannot learn self-attention. If layers are fully connected then they definitely have the ability to create a weighting of the inputs accommodating learning of relative spatial features in the data. In fact that's almost a definition of what a neural network is. If relative positional learning is important, that could be added the same way it is in Transformers without the explicit self-attention layer. So what is self-attention really doing beyond this? Why do we need a Query-Key-Value construct here? I am sure I am missing something very basic and fundamental here.
If by "conventional neural network" you mean a stack off fully connected layers, then yes, in theory one of those could learn a similar mechanism because of the universal approximation theorem. However, training one might be intractable.
It's good to ignore self-attention for a moment and take a look at a convolutional network (a CNN). Why is a CNN more effective than just stacks of fully connected layers? Well, instead of just throwing data at the network and telling it to figure out what to do with it, we've built some prior knowledge into the network. We tell it, "you know, a cup is going to be a cup even if it is 10 pixels up or 10 pixels down; even if it is in the upper right of the image or the lower left." We also tell it, "you know, the pixels near a given pixel are going to be pretty correlated with that pixel, much more so than pixels far away." Convolutions help us express that kind of knowledge in the form of a neural network.
Self-attention plays a similar role. We are imbuing our network with an architecture that is aware of the data it is about to receive. We tell it "hey, elements in this sequence have a relation with one another, and that relative location might be important for the answer". Similar to convolutions, we also tell it that the location of various tokens in a sequence is going to vary: there shouldn't be much difference between "Today the dog went to the park" and "The dog went to the park today." Like convolutions, self-attention builds in certain assumptions we have about the data we are going to train the network on.
So yes, you are right that fully-connected layers can emulate similar behavior, but training them to do that isn't easy. With self-attention, we've started with more prior knowledge about the problem at hand, so it is easier to solve.
Great answer. Imbuing a deep learning model with well thought out inductive biases is one of the strongest ways of guiding your model to interpret the data the way you want it to. Otherwise it’s kind of shooting in the dark and hoping to get lucky.
I can’t stand it when people lazily personify ML models, but it’s akin to giving someone with no experience some wood and then pointing to a shed and saying “make one of those from this”. Instead you’d expect them to be much more successful if you also give them a saw, a drill, some screws etc.
Good explanation, and it's why the success of transformers, LLMs, etc. is still not the final word on Rich Sutton's "The Bitter Lesson" -- no learning method is free of inductive biases.
Read the intro in the original paper "Attention is all you need" (https://arxiv.org/abs/1706.03762)
This video explains the drawbacks to RNNs and how transformers solve that: https://youtu.be/S27pHKBEp30?t=394
Andrej Karpathy explains attention here: https://youtu.be/kCc8FmEb1nY?t=3719
He explains how attention is seen as a communication network: https://youtu.be/kCc8FmEb1nY?t=4298
> Read the intro in the original paper "Attention is all you need"
I wouldn't call this the original "attention" paper. It's definitely not the first paper to use the phrase. If you want clear proof of this, let's read the paper itself:
> Attention mechanisms have become an integral part of compelling sequence modeling and transduction models in various tasks, allowing modeling of dependencies without regard to their distance in the input or output sequences.
I do think a lot of people's lack of understanding of attention is because they are so focused on dot-product (self-)attention (DP(S)A) that they miss a lot of the broader picture. And the math. Not enough people dig into the math.
I think the misunderstanding is the assumption that fully connected layers can operate in the same way attention does: they can't. A fully connected layer operates on one dimension at a time. Typically language models have two (plus a batch dimension). One dimension is your token/word dimension. The second is the hidden dimension.
The hidden dimension is constructed inside the model to create a space where it can embed each token into a vector and then enrich that space with contextual information derived from the sequence. In order for that to occur, the model must have a means of transferring information along the token dimension.
One way to accomplish this is to use a 2d convolution; however, the scope of a convolution is limited to the size of its kernel. A fully connected layer is the same as a 2d convolution with a kernel size of 1. So you can see that no information from neighboring tokens can be applied to the hidden space.
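A quick way to convince yourself of that equivalence (a toy check; I'm using a 1d convolution over the token dimension here, with made-up sizes):
import torch
import torch.nn as nn

x = torch.randn(2, 8, 5)  # (batch, hidden dim, tokens), arbitrary toy sizes
conv = nn.Conv1d(in_channels=8, out_channels=16, kernel_size=1, bias=False)
fc = nn.Linear(8, 16, bias=False)
fc.weight.data = conv.weight.data.squeeze(-1)      # copy the same parameters into both
out_conv = conv(x)                                  # mixes only the hidden dim, per token
out_fc = fc(x.transpose(1, 2)).transpose(1, 2)      # the same thing as a per-token dense layer
print(torch.allclose(out_conv, out_fc, atol=1e-6))  # True: no cross-token information flow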
The standard self-attention equation, by contrast, has global scope, because it does a full matrix multiplication of the input tensor with its transpose. Each element of the resulting matrix captures the interaction of one token with another, so every token gets a score against every other token in the sequence. Next a softmax is applied, which acts as a gating or relevance function. Finally, the result is multiplied back into the original input to build that information into the hidden dimension of each token.
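In code, that paragraph boils down to something like this (a simplified toy sketch that ignores the learned query/key/value projections and the 1/sqrt(d) scaling; sizes are made up):
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

N, D = 6, 4                   # tokens, hidden dim (toy sizes)
x = np.random.randn(N, D)     # one token embedding per row

scores = x @ x.T              # (N, N): every token scored against every other token
weights = softmax(scores)     # rows sum to 1: the gating / relevance function
out = weights @ x             # (N, D): each token is now a mix of the whole sequence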
There have been attempts to do similar operations using fully connected layers. Look at the architecture of SGUs (spatial gating units). In some applications they have good performance, but because fully connected layers operate on each dimension independently and serially, they are not equivalent to attention.
Lastly, my best recommendation for anybody trying to understand attention is to stop reading articles and instead spend your time looking at the math. It's usually much less confusing than any of the dozens of explanations floating around the web, including the one I just gave. The math is not too complicated, especially once you know the reasons why we need it.
There's a bit more to it, but you can partly view (self-)attention as a trick that makes optimization easier; that is, it improves gradient flow, similar to how skip connections make learning easier. That was more obvious back when we used attention with RNNs, where attention can be viewed as equivalent to "dynamic skip connections".
While fully connected layers can in theory learn anything, that's a very hard optimization problem in terms of gradient flow. Attention adds inductive biases (prior/domain knowledge) about what you want the network to learn, which makes the optimization of that specific aspect easier for the optimization algorithm.
In general, you can view almost anything in ML/DL as improving either optimization or generalization, and while it's a spectrum, attention falls more into the optimization category.
> Fundamentally, I have never understood why a conventional neural network cannot learn self-attention.
Here's how I think about it: Yes, you can learn everything with an MLP, since universal approximation and so on, but it is not efficient. And (modern) ML is all about scale.
Here's one way to see the difference: self-attention takes as input N channels of dimension D. It maps them to N Keys and Queries with a DxD matrix, time = O(ND^2). Next it computes all pairwise Key-Query inner products, time = O(N^2 D). Finally, the softmax takes O(N^2) and computing "probability times values" takes O(DN^2).
All in all, self-attention takes O(DN^2 + ND^2) time to map ND values to ND values. How long would an MLP take to do the same thing? O((ND)^2).
So in a typical case of D ~ 1000 and N ~ 1000 we save a factor of roughly 1000: for the price of 1 MLP layer, you could afford on the order of 1000 self-attention layers. It's a pretty major difference.
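A quick back-of-the-envelope check of those counts (just multiplying out the terms above with toy numbers):
N = D = 1000
attention_flops = D * N**2 + N * D**2   # O(DN^2 + ND^2) ~ 2e9
mlp_flops = (N * D)**2                  # O((ND)^2)      ~ 1e12
print(mlp_flops / attention_flops)      # ~500, i.e. on the order of the ~1000x saving described above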
I think there's some room for the second factor too: generalization. If you have 2 models for the same phenomenon and they're equally good, the model with fewer parameters is the better one. Not because it calculates faster, but because it will actually be more accurate.
A conventional neural network, i.e. one using a stack of dense layers, can't unroll across a sequence in the way the transformer does. So while it could compute the relative importance and interaction of the features it sees, it wouldn't be able to compute that across arbitrary-length sequences without a mechanism for the sequence elements to interact, which is what self-attention provides.
Practical attention implementations don't work over arbitrary-length sequences either. The universal approximation theorem still holds, IMO: information will mix as you go through fully connected MLP layers. Attention is apparently a prior structure that is needed to really reduce training costs.
> I feel like most of these are written by people who are confused themselves
I teach the course at my uni and I'm highly confident this is true, even in the research community. Part of this is that people are hyper-concentrated on dot-product attention (<softmax(<q,k>), v>) (DPA). There are a lot more forms of attention than this. It helps to go back to early attention mechanisms like those discussed by Bahdanau (Bengio's student) and Graves (DeepMind). When you look at these you'll find a clearer definition: a learned weighting function (Bahdanau specifies it as a probability), conditioned on some input, applied to a learned embedding conditioned on another input. If the two inputs are the same then it's self-attention, otherwise cross-attention. You'll see some people refer to the learned weighting as a score (not to be confused with the Fisher score -- the gradient of the log-likelihood -- used for diffusion training). Understanding this, you'll see that the definition is broader than DPA, and also what makes attention powerful. But lots of people don't catch this because they don't have the history of RNNs. DPA has become the de facto choice because the dot product between the two embeddings creates a more powerful score function without introducing lots of parameters (there are people exploring more complex structures).
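To make that general definition concrete, here is a rough sketch of Bahdanau-style additive attention (my own toy illustration, with made-up shapes and sizes):
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

d = 8
rng = np.random.default_rng(0)
query = rng.normal(size=d)            # the conditioning input (e.g. a decoder state)
memory = rng.normal(size=(5, d))      # the embeddings being attended over (e.g. encoder states)

# learned parameters of the score function: score(q, m_j) = v . tanh(W1 q + W2 m_j)
W1, W2, v = rng.normal(size=(d, d)), rng.normal(size=(d, d)), rng.normal(size=d)

scores = np.array([v @ np.tanh(W1 @ query + W2 @ m) for m in memory])
weights = softmax(scores)             # the learned weighting, a probability distribution
context = weights @ memory            # weighted combination of the embeddings
# self-attention: query comes from the same sequence as memory; cross-attention: from another source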
> I have never understood why a conventional neural network cannot learn self-attention.
I'll even ask another question: why can't densely connected (linear) networks learn self-attention? Both convolutions and dense layers are universal approximators, right? But we often see convolutions preferred over dense layers (note: they are equivalent when the convolution is 1D with kernel size 1, no padding, and stride 1). Well, the power is in how information is encoded and connected. CNNs in essence capture a form of positional encoding because they have a structured order. Transformers need a bit more help (note that there's also relative positional bias, and they also require augmentations), and that information helps them create very powerful graphs. But one big advantage of DPA is the multiple heads, which can take different "views." Importantly, transformers scale extremely well. Essentially, attention creates a more complex connection between the information. That can make them harder to train, but also allows for more efficient encoding of information.
I could talk a lot more but I'll stop here (though this is a very incomplete explanation). There are two resources that I really like to hand to my students:
Lilian Weng's blog (everything she does is great) has good coverage of many different attention mechanisms and discusses the RNN history: https://lilianweng.github.io/posts/2018-06-24-attention/
This Medium blog. It isn't as in-depth but it is good and you end up with a model that students can actually usefully train. It'll also hopefully answer some of your questions: https://medium.com/pytorch/training-compact-transformers-fro...
Bonus: softmax tempering can help you understand why we scale the score: https://aclanthology.org/2021.mtsummit-research.10/
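On that last point, a small numerical illustration (my own toy example) of why the dot-product scores get scaled by sqrt(d): without scaling, the softmax saturates as the dimension grows.
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

rng = np.random.default_rng(0)
for d in (4, 64, 1024):
    q, K = rng.normal(size=d), rng.normal(size=(8, d))
    raw = K @ q                                            # dot-product scores; variance grows with d
    print(d, softmax(raw).max(), softmax(raw / np.sqrt(d)).max())
# unscaled: the largest weight drifts toward 1.0 (a hard, low-gradient "argmax")
# scaled by sqrt(d): the distribution stays softer and easier to train through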
Is there a coherent theory of ML I could read somewhere? Something higher level than "here is how we use tanh", yet not hand-wavy.
All these "we've attached this random piece to our model and it makes things better for reasons we don't understand" explanations leave a poor impression. To an outsider, the GPT model looks a lot like an IIR/FIR filter made of linear operators (matrix multiplications) and convolutions (matrix multiplications in Fourier space) that processes a sequence of vectors instead of numbers (like a normal FIR filter would). Perhaps the extensive theory of filters applies here, and perhaps we can even apply the Z-transform to the GPT filter to analyze its convergence?
If I understand correctly, Z-transform analysis can't work here because these systems are fundamentally nonlinear. In fact, the "nonlinear" parts, i.e. the activation functions, are what allow these networks to be universal function approximators and learn arbitrary things, rather than collapsing into a simple linear system.
https://arxiv.org/abs/2104.13478
One thing people generally don't highlight about the transformer architecture, which I think is very important, is the pass-through (residual) connections. These allow the model to "spread out" its learning across the network rather than it being forced to be very local. I think this is also a reason why ResNet was so successful.
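A minimal sketch of what those pass-through connections look like in a transformer block (a generic pre-norm layout I'm assuming for illustration, not any particular implementation):
# x flows "around" each sub-layer, so every layer only has to learn a small correction
def transformer_block(x, attention, mlp, norm1, norm2):
    x = x + attention(norm1(x))   # residual / pass-through connection 1
    x = x + mlp(norm2(x))         # residual / pass-through connection 2
    return x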
This explanation walks you through the math and the corresponding code, but (at least in my case, maybe I'm dumb) it failed to help me understand why these steps are necessary or to relate the math to the intended outcome. As a result, I don't feel that I'm any closer to really understanding the heart of self-attention.
At the end of last year I put together a repository to try and show what is achieved by self-attention on a toy example: detect whether a sequence of characters contains both "a" and "b".
The toy problem is useful because the model dimensionality is low enough to make visualization straightforward. The walkthrough also covers how things can go wrong, how it can be improved, etc.
The walkthrough and code is all available here: https://github.com/rstebbing/workshop/tree/main/experiments/....
It's not terse like nanoGPT or similar because the goal is a bit different. In particular, to gain more intuition about the intermediate attention computations, the intermediate tensors are named and persisted so they can be compared and visualized after the fact. Everything should be exactly reproducible locally too!
I agree. It seems like the target audience is the experienced deep learning practitioner, which makes me wonder why such an audience would need this treatment. Why not just read the original paper?
I can't understand the reason for paying attention to oneself.
For example, if we are looking at a seq2seq translation task, does self-attention mean that we "highlight" all words that have similar meanings in a sentence together? What's the intuition that this will help the translation task?
For example, in the sentence "I hate apples, I only drink apple juice.", will we encode the first "apples" and the later "apple" together? Why is that useful?
In self-attention each token creates both "query" and "key" vectors. These are different vectors, so one type of token can look up data from different types of tokens. "Apple" can generate a verb query to look for a verb in the sentence, while also generating a noun key that other words can look for.
And with multi-head attention, a single token can get data from different types of tokens.
The "self" in self-attention just means that it's looking at other parts of the same string, rather than looking at weights generated from a different string.
>I can't understand the reason for paying attention to oneself.
Certain parts of a sentence strongly inform the meaning of other parts and so it is important to encode them together. If you see the word "bank" in a sentence, is it referring to the financial institution or the land next to a body of water? We know by what came before or what comes after. Attention allows relevant context to inform token processing without being distracted by irrelevant context.
query = query_weights * x
keys = key_weights * what_to_look_at
values = value_weights * what_to_look_at
For self-attention, what_to_look_at = x.
For regular (cross-)attention, what_to_look_at could be a database, memory or anything else.
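Filling that pseudocode out into something runnable (a toy sketch with assumed shapes, not anyone's actual implementation):
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(x, what_to_look_at, W_q, W_k, W_v):
    query = x @ W_q
    keys = what_to_look_at @ W_k
    values = what_to_look_at @ W_v
    weights = softmax(query @ keys.T / np.sqrt(W_k.shape[1]))
    return weights @ values

rng = np.random.default_rng(0)
D = 8
W_q, W_k, W_v = (rng.normal(size=(D, D)) for _ in range(3))
x = rng.normal(size=(6, D))        # the current sequence
memory = rng.normal(size=(10, D))  # some other source: a database, encoder output, etc.

self_attn = attention(x, x, W_q, W_k, W_v)        # self-attention: look at the same sequence
cross_attn = attention(x, memory, W_q, W_k, W_v)  # cross-attention: look somewhere else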
So in this example, if we're trying to predict the second "apple", the first "apples" is very helpful. If we're predicting "juice", then we'd use one head of self-attention to look at the first "apples" and a second head to also look at the second "apple".
That's my understanding at least.
The strong performance of vision transformers on image classification and other vision tasks is often attributed to the design of their multi-head attention layers. However, the extent to which attention is responsible for this strong performance remains unclear.
In this short report, we ask: is the attention layer even necessary?
Specifically, we replace the attention layer in a vision transformer with a feed-forward layer applied over the patch dimension. The resulting architecture is simply a series of feed-forward layers applied over the patch and feature dimensions in an alternating fashion. In experiments on ImageNet, this architecture performs surprisingly well: a ViT/DeiT-base-sized model obtains 74.9% top-1 accuracy, compared to 77.9% and 79.9% for ViT and DeiT respectively.
These results indicate that aspects of vision transformers other than attention, such as the patch embedding, may be more responsible for their strong performance than previously thought. We hope these results prompt the community to spend more time trying to understand why our current models are as effective as they are.
There are other non-attention based networks that get 90% too though: https://arxiv.org/pdf/2212.11696v3.pdf
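For concreteness, here is a rough sketch (my own toy code, with assumed sizes) of the kind of patch-dimension feed-forward layer the abstract above describes replacing attention with:
import numpy as np

def gelu(z):
    return 0.5 * z * (1 + np.tanh(np.sqrt(2 / np.pi) * (z + 0.044715 * z**3)))

N, D, H = 16, 64, 128                 # patches, embedding dim, hidden width (made-up sizes)
rng = np.random.default_rng(0)
x = rng.normal(size=(N, D))           # patch embeddings

# "token mixing": a feed-forward layer applied over the patch dimension,
# so every patch can see every other patch (this replaces the attention layer)
W1, W2 = rng.normal(size=(N, H)) * 0.02, rng.normal(size=(H, N)) * 0.02
x = (gelu(x.T @ W1) @ W2).T           # (N, D)

# "channel mixing": the usual per-patch feed-forward over the feature dimension
W3, W4 = rng.normal(size=(D, 4 * D)) * 0.02, rng.normal(size=(4 * D, D)) * 0.02
x = gelu(x @ W3) @ W4                 # (N, D)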