Yale Transformer


Cherish Asleson

Aug 3, 2024, 3:52:58 PM
to tangsembkingchit


Transformers have revolutionized a wide array of learning tasks, but their scalability limitations remain a pressing challenge. Computing attention layers exactly takes time and memory that grow quadratically with the sequence length, which makes it hard to scale transformer models to longer contexts.
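To make "quadratic" concrete, here is a back-of-the-envelope sketch of the memory needed just to store one attention score matrix for a single head and a single sequence. It assumes a naive implementation that materializes the full n × n matrix in 16-bit floats (2 bytes per entry); memory-efficient kernels such as FlashAttention avoid storing this matrix, but the number of scores to compute still grows as n².

```python
def attention_matrix_bytes(n_tokens: int, bytes_per_entry: int = 2) -> int:
    """Bytes needed to materialize one n x n attention score matrix (one head, one sequence).

    Assumes 16-bit floats (2 bytes per entry) by default.
    """
    return n_tokens * n_tokens * bytes_per_entry


for n in (32_000, 131_000):
    print(f"{n:>7} tokens: {attention_matrix_bytes(n) / 1e9:.1f} GB")
# 32,000 tokens need about 2.0 GB; 131,000 tokens need about 34.3 GB,
# roughly 16.8x more memory for a ~4.1x longer context.
```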

In a new paper, "HyperAttention: Long-context Attention in Near-Linear Time," a research team from Yale University and Google Research presents HyperAttention, an approximate attention mechanism designed to tackle the computational challenges posed by increasingly long contexts in Large Language Models (LLMs). HyperAttention is not only efficient in practice but also delivers the best near-linear time guarantee.

The central issue addressed by this research is approximating attention, specifically dot-product attention, which takes three input matrices: Q (queries), K (keys), and V (values). Each has one row per token in the input sequence and one column per dimension of the latent representation. The primary goal is to approximate the output matrix, Att, efficiently while preserving its spectral properties.
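For reference, here is exact attention in the notation commonly used for this problem: Att = D^{-1} A V, where A = exp(QK^T) is applied entrywise and D = diag(A 1_n) is the diagonal matrix of A's row sums, so D^{-1} A is the row-wise softmax. This is a minimal NumPy sketch, assuming any 1/√d scaling has already been folded into Q. It is the quadratic-time baseline that HyperAttention approximates, not HyperAttention itself.

```python
import numpy as np

def exact_attention(Q, K, V):
    """Exact dot-product attention written as Att = D^{-1} A V.

    A = exp(Q K^T) is the unnormalized attention matrix and D = diag(A 1_n)
    holds its row sums, so D^{-1} A is the row-wise softmax.
    """
    S = Q @ K.T                                   # n x n scores: the quadratic part
    # Subtract each row's max before exponentiating for numerical stability;
    # the shift cancels in D^{-1} A, so the result is unchanged.
    A = np.exp(S - S.max(axis=1, keepdims=True))
    d = A.sum(axis=1)                             # diagonal of D
    return (A / d[:, None]) @ V                   # D^{-1} A V
```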

The proposed approach first builds an efficient estimator for the diagonal scaling matrix (the softmax normalizer) in near-linear time. It then quickly approximates the product of the softmax matrix and the value matrix through subsampling. The researchers streamline the kernel density estimation (KDE) procedure used by the earlier KDEformer method and show that uniform sampling is sufficient to achieve the desired spectral guarantee, eliminating the need for importance sampling based on kernel densities. This significant simplification yields a practical and provably linear-time algorithm.
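The two sampling steps described above can be illustrated with a deliberately simplified sketch. It estimates the diagonal of D and the product A V from two independent sets of key/value indices drawn uniformly at random, each rescaled by n/m so that the row-sum estimate and the A V estimate are individually unbiased. This is an assumption-laden toy, not the paper's algorithm: HyperAttention additionally uses locality-sensitive hashing to identify large entries, and this sketch skips that step (and causal masking), so its error can be large when a few entries dominate a row. The function name and the sample size m are hypothetical.

```python
import numpy as np

def sampled_attention(Q, K, V, m, rng):
    """Toy uniform-sampling approximation of Att = D^{-1} A V, with A = exp(Q K^T).

    Omits HyperAttention's LSH step for large entries and causal masking.
    Assumes the entries of Q K^T are moderate, so exp() does not overflow.
    """
    n = K.shape[0]
    scale = n / m  # each uniformly sampled column stands in for n / m columns

    # Step 1: estimate the diagonal of D (the row sums of A) from m sampled columns.
    idx_d = rng.integers(0, n, size=m)
    d_hat = scale * np.exp(Q @ K[idx_d].T).sum(axis=1)

    # Step 2: estimate A V from an independent sample of m key/value pairs.
    idx_v = rng.integers(0, n, size=m)
    av_hat = scale * (np.exp(Q @ K[idx_v].T) @ V[idx_v])

    # Combine: approximate D^{-1} A V by dividing each row of the A V estimate.
    return av_hat / d_hat[:, None]
```

Each call touches only n × m entries of A instead of n × n, which is where the savings come from when m is much smaller than n.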

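Notably, the proposed approach does not require the attention matrix to have bounded entries or bounded stable rank. The fine-grained parameters used to analyze its running time (the maximum column norm of the normalized attention matrix, and the ratio of row norms of the unnormalized attention matrix after large entries are detected and removed) can remain small even when the entries of the attention matrix or its stable rank are large.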
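For readers unfamiliar with the term, the stable rank of a matrix M is ||M||_F² / ||M||_2², the squared Frobenius norm divided by the squared spectral norm. It always lies between 1 and rank(M) and acts as a "soft" count of how many directions carry significant energy. The sketch below computes it, along with the largest column norm of a row-normalized attention matrix D^{-1} A, which is how we read the first of the paper's fine-grained parameters. Treat that second function as a hypothetical illustration rather than the paper's exact definition.

```python
import numpy as np

def stable_rank(M):
    """||M||_F^2 / ||M||_2^2: between 1 and rank(M) for any nonzero M."""
    fro2 = np.linalg.norm(M, "fro") ** 2
    spec2 = np.linalg.norm(M, 2) ** 2  # ord=2 on a matrix gives the largest singular value
    return fro2 / spec2


def max_column_norm_normalized(A):
    """Largest Euclidean column norm of D^{-1} A, where D = diag(row sums of A).

    Hypothetical helper illustrating one fine-grained parameter; A must have positive row sums.
    """
    P = A / A.sum(axis=1, keepdims=True)  # row-normalized (softmax-style) matrix
    return np.linalg.norm(P, axis=0).max()
```

When every query spreads its attention evenly over n keys, each column of D^{-1} A has norm 1/√n, which is small; when every query attends to a single distinct key, the column norms reach 1.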

Furthermore, the team has observed that by conducting one-sided sampling from the squared row norms, they can eliminate the need for KDEs while achieving the same spectral norm guarantee in terms of stable rank.
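One-sided sampling by squared row norms is a classic randomized technique for approximating a matrix product. As a generic sketch (the article does not spell out the exact variant, so the details here are assumptions): to approximate P V, where P stands for the normalized attention matrix D^{-1} A, draw m row indices i of V with probability p_i ∝ ||V_i||², and reweight each sampled term by 1/(m p_i) so the estimate is unbiased. Building the distribution needs only the norms of V's rows, not any property of the attention matrix, which is the intuition for why kernel density estimates can be dropped. The function name is hypothetical.

```python
import numpy as np

def row_norm_sampled_product(P, V, m, rng):
    """Approximate P @ V by sampling m rows of V with probability proportional to squared norms.

    Each sampled term P[:, i] V[i, :] is reweighted by 1 / (m p_i), which makes the
    estimate unbiased: its expectation equals P @ V. Requires V to have a nonzero row.
    """
    sq = np.einsum("ij,ij->i", V, V)            # squared row norms of V
    p = sq / sq.sum()                           # sampling distribution over rows of V
    idx = rng.choice(V.shape[0], size=m, p=p)   # one-sided: only V's norms are used
    w = 1.0 / (m * p[idx])                      # importance weights
    return (P[:, idx] * w) @ V[idx]             # sum_t w_t * P[:, i_t] V[i_t, :]
```

Rows of V with zero norm are never sampled, which costs nothing because they contribute nothing to P V.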

In empirical tests, HyperAttention outperforms existing methods and delivers substantial speedups over state-of-the-art solutions such as FlashAttention. For instance, HyperAttention makes ChatGLM2 inference 50% faster at a context length of 32,000 tokens, while perplexity rises only slightly, from 5.6 to 6.3. At longer context lengths, e.g., 131,000 tokens with causal masking, HyperAttention delivers a 5-fold speedup on a single attention layer.

In conclusion, the introduction of HyperAttention marks a significant breakthrough in overcoming the scalability limitations of transformers, making them more adept at handling longer context lengths. This innovation promises to enhance the efficiency and effectiveness of Large Language Models, with notable speed gains in real-world applications.


My colleague K Sudhir and I are teaching a new course at the Yale School of Management on large language models like ChatGPT, Bard, Claude, and Llama. Our course is divided into two parts: theory and application. We just finished the theory part, the goal of which was to give students, mostly MBAs, a passable understanding of how LLMs are built. That's not to say our students could build such models on their own, but I think it's fair to say most of the students can now correctly describe how and why ChatGPT works.

We received numerous requests for our course materials, so we're sharing them here, which is trivial because there's nothing original! Instead of developing our own materials, we found the best materials we could online and assigned them to our students as pre-class reading. (Each class in the "theory" part began with an on-paper, offline quiz.) Finding the "best" material took some time. Having now completed this first part of the course, I'd say we're satisfied with our choices.

Below, you can find our day-by-day assigned reading. We're sharing it in the hope that it helps you on your journey to understand LLMs. These materials assume a very basic understanding of calculus, linear algebra, and Python programming. I'm guessing most people who meet those prerequisites can come up to speed on LLMs in a month of study. (Of course, you can read this material quickly, but it takes a while to digest... at least it did for me.)

LLMs are big neural networks and a neural network is basically just a bunch of matrices that we multiply together in fancy ways to make some prediction like "that image contains a cat" or "the next word in this sentence is 'squirrel'". During training, we run the neural network forward and backward a bunch of times until we get good values in these matrices: values that make correct predictions. You can kinda think of that like tuning the strings on some freakishly large guitar. We want the guitar to make the right notes and so we twist knobs until we get there.
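The "twist knobs until the guitar sounds right" loop can be sketched in a few lines. This is the simplest possible case, a "network" with one weight matrix and no nonlinearity, trained by gradient descent on mean squared error. Real networks stack many matrices with nonlinearities in between and get their gradients from backpropagation, but the forward, backward, update rhythm is the same. The function name, learning rate, and data are illustrative choices.

```python
import numpy as np

def train_linear(X, y, lr=0.1, steps=500):
    """Fit y ≈ X @ W by gradient descent on the mean (over examples) squared error."""
    W = np.zeros((X.shape[1], y.shape[1]))  # the "knobs", initially untuned
    for _ in range(steps):
        pred = X @ W                              # forward pass: make predictions
        grad = (2 / len(X)) * X.T @ (pred - y)    # backward pass: gradient of the loss w.r.t. W
        W -= lr * grad                            # twist the knobs a little downhill
    return W


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
W_true = np.array([[1.0], [-2.0], [0.5]])
W = train_linear(X, X @ W_true)
print(np.round(W.ravel(), 3))  # close to W_true: 1, -2, 0.5
```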

  • Lipton, Zachary C., John Berkowitz, and Charles Elkan. "A critical review of recurrent neural networks for sequence learning." arXiv preprint arXiv:1506.00019 (2015). This is an oft-cited review of RNNs, their upsides and downsides.
  • "Understanding LSTM Networks" by Christopher Olah. Long short-term memory networks (LSTMs) are a kind of RNN that solves two problems with vanilla RNNs: they have a longer-term memory, and they can easily forget irrelevant information.
  • "Understanding GRU Networks" by Simeon Kostadinov. A Gated Recurrent Unit (GRU) is like an LSTM but simpler and easier to train.
  • "Illustrated guide to LSTMs and GRUs" by Michael Phi. The animation in this one is nice, though I think the link above is better written.
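To make "simpler than an LSTM" concrete, here is a single GRU step in NumPy using the standard update and reset gates. It is a sketch: the parameter names are hypothetical, and libraries differ on whether the update gate z weights the old state or the new candidate (this sketch uses h_new = (1 - z) ⊙ h + z ⊙ h̃).

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h, params):
    """One GRU time step. x: (d_in,), h: (d_hidden,). Returns the new hidden state."""
    Wz, Uz, bz, Wr, Ur, br, Wh, Uh, bh = params
    z = sigmoid(Wz @ x + Uz @ h + bz)              # update gate: how much to overwrite
    r = sigmoid(Wr @ x + Ur @ h + br)              # reset gate: how much past state to use
    h_tilde = np.tanh(Wh @ x + Uh @ (r * h) + bh)  # candidate new state
    return (1 - z) * h + z * h_tilde               # blend old state and candidate
```

To process a sentence, you would call gru_step once per token, feeding each returned state back in as h. An LSTM cell has the same shape of computation but keeps a separate cell state and uses three gates instead of two.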

In our previous classes, we saw how a basic neural network is trained to take numbers and produce decisions or outputs. We also saw how to produce a series of outputs with recurrent neural networks, and this was our first foray into text as data. In our RNN implementation, we used the so-called "one-hot" encoding of words/characters/primitives. In this representation, if we have N words in our input sentence and M words in our vocabulary, our input matrix to the network is NxM.

The problem is that one-hot representations are really poor. Consider a word like "dog": it's quite similar to the word "cat" in many ways, isn't it? They're both pets, we might cuddle with each of them, we feed each of them. Indeed, "dog" can often be replaced with "cat" in a sentence and it still works just fine! But in one-hot encoding, "dog" is no more similar to "cat" than it is to "volcano". What a loss!
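Here is a minimal sketch of both points: building the NxM one-hot matrix for a sentence, and checking that any two different one-hot vectors have a dot product of zero. The tiny vocabulary is made up for illustration.

```python
import numpy as np

def one_hot(sentence, vocab):
    """Return the N x M one-hot matrix for a list of N words over an M-word vocabulary."""
    index = {w: i for i, w in enumerate(vocab)}
    X = np.zeros((len(sentence), len(vocab)))
    for row, word in enumerate(sentence):
        X[row, index[word]] = 1.0   # a single 1 in the word's vocabulary column
    return X

vocab = ["cat", "dog", "volcano", "the", "sat"]   # made-up toy vocabulary (M = 5)
X = one_hot(["the", "dog", "sat"], vocab)
print(X.shape)  # (3, 5)

# Any two different one-hot vectors are orthogonal, so "dog" is exactly as
# (dis)similar to "cat" as it is to "volcano".
dog, cat, volcano = one_hot(["dog", "cat", "volcano"], vocab)
print(dog @ cat, dog @ volcano)  # 0.0 0.0
```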

Now we're going to learn dense vectors for words that capture this kind of similarity, and we call these "embeddings". We're going to rewind the clock a little bit and learn a famous embedding called word2vec. You likely won't use word2vec in products anymore because there are better embeddings now, but it is the right next step for us as relative newcomers to the concept of embeddings.
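For the curious, here is a minimal sketch of the skip-gram flavor of word2vec with negative sampling: slide a window over the text to collect (center word, context word) pairs, then nudge the embeddings so that real pairs score high and randomly drawn "negative" words score low. Vocabulary handling, the choice of negatives, and the hyperparameters are simplified assumptions; the original implementation adds details such as subsampling frequent words and drawing negatives from a smoothed unigram distribution.

```python
import numpy as np

def skipgram_pairs(tokens, window=2):
    """All (center, context) pairs with the context within `window` positions of the center."""
    pairs = []
    for i, center in enumerate(tokens):
        for j in range(max(0, i - window), min(len(tokens), i + window + 1)):
            if j != i:
                pairs.append((center, tokens[j]))
    return pairs


def sgns_step(W_in, W_out, center, context, negatives, lr=0.05):
    """One SGD step of skip-gram with negative sampling; words are integer ids.

    Returns the logistic loss for this example before the update.
    """
    v = W_in[center].copy()                      # embedding of the center word
    ids = np.array([context] + list(negatives))
    labels = np.array([1.0] + [0.0] * len(negatives))
    u = W_out[ids]                               # output vectors (fancy indexing copies)
    p = 1.0 / (1.0 + np.exp(-(u @ v)))           # predicted "is a real pair" probabilities
    loss = -np.sum(labels * np.log(p) + (1 - labels) * np.log(1 - p))
    g = p - labels                               # d(loss)/d(score) for each candidate
    W_in[center] -= lr * (g @ u)
    np.subtract.at(W_out, ids, lr * np.outer(g, v))  # accumulates correctly if ids repeat
    return loss
```

After training on many pairs, the rows of W_in are the word embeddings: words that appear in similar contexts end up with similar vectors.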

So far, we've seen:

  • How neural networks can learn complicated things by adjusting the values of big matrices so as to minimize some user-defined loss function.
  • How we could give these networks a little bit of memory with RNNs, so that 1) we can create some kind of hidden "thought" vector and 2) words could "remember" past words.
  • How we could represent words as embeddings: "dense" vectors that have semantic meaning instead of just "one hot" vectors without meaning.

Now we're going to build on those lessons and, finally, learn about transformers. The "T" in ChatGPT stands for "Transformer". If transformers are the beating heart of ChatGPT, then "attention" is the beating heart of a transformer. In particular, we're going to learn about "multi-head scaled dot-product attention". Quite a mouthful!

Alas, attention looks complicated (it isn't) and we're kinda learning it without knowing what we're going to do with it: almost like learning how a heart works without knowing about the rest of the body.

Here's the single thing I want you to keep in mind as you read the items below. Remember from the last class that we looked at word embeddings, or vectors, like "turkey" = [0.2, 0.99, 0.232, 0.487, 0.3, ...]. Also, remember from class how "turkey" could be either the country or the animal? When we get the word2vec embedding for "turkey", we get a vector that blends these meanings. What attention does for us is let the word "turkey" pay attention to the other words around it so that it knows which "turkey" it is. In "Turkey approved the UN resolution", it's the country Turkey. So we take a generic "turkey" embedding and produce a context-aware "turkey": a super "turkey" vector that knows all kinds of things about itself and the context in which it occurs, such as the adjectives that apply to it, subject-verb agreement, all kinds of stuff! The attention mechanism spits out this super-duper smart vector for "turkey" instead of the dumb one we started with.
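Here is a compact NumPy sketch of multi-head scaled dot-product self-attention. Each token's embedding is projected into a query, a key, and a value; each head scores every token against every other token with a dot product scaled by √(d_head); a softmax turns those scores into weights; and the weighted sum of values becomes the new, context-aware vector (the "super turkey"). It assumes a single sequence and no masking, and the projection matrices are placeholders; real models learn Wq, Wk, Wv, and Wo during training.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # shift for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, Wq, Wk, Wv, Wo, n_heads):
    """Multi-head scaled dot-product self-attention (no masking, no batching).

    X: (N, d_model) input embeddings, one row per token.
    Wq, Wk, Wv, Wo: (d_model, d_model) projection matrices.
    Returns (N, d_model) context-aware vectors, one per token.
    """
    N, d_model = X.shape
    assert d_model % n_heads == 0, "d_model must split evenly across heads"
    d_head = d_model // n_heads

    def split(M):  # (N, d_model) -> (n_heads, N, d_head)
        return M.reshape(N, n_heads, d_head).transpose(1, 0, 2)

    Q, K, V = split(X @ Wq), split(X @ Wk), split(X @ Wv)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)  # (n_heads, N, N)
    weights = softmax(scores, axis=-1)   # how much each token attends to each other token
    heads = weights @ V                  # (n_heads, N, d_head)
    concat = heads.transpose(1, 0, 2).reshape(N, d_model)  # glue the heads back together
    return concat @ Wo
```

Each head can learn to look for a different kind of relationship (for example, one head might focus on nearby adjectives while another tracks the subject of the verb), and Wo mixes their findings into one vector per token.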

Now that we understand attention, you should have everything you need to understand transformers and ChatGPT. There's nothing new in this class; we're just putting it all together. ChatGPT is in the family of decoder-only transformer models. Given some input text, the model just keeps asking, "What is the next most likely word?", spits that word out, and does it again. This is a little different from what you'll read about below. The article below describes an encoder-decoder model. In these, we take in some data and create a "latent" or "hidden" state with the encoder. That state is just like the hidden state in an RNN. Then the decoder can look at that state when producing its output. That's how we'd translate from English to German, or how we might add a caption to an image (the encoder makes a hidden representation of the image, and then the decoder turns that into words).
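The decoder-only "predict the next word, append it, repeat" loop can be sketched without any neural network at all. Below, `toy_next_word_probs` and its bigram table are made-up, hypothetical stand-ins for the model (not anything ChatGPT uses), and decoding is greedy, always taking the most likely word; real systems usually sample from the predicted distribution instead.

```python
from typing import Callable, Dict, List, Optional

# Hypothetical stand-in for a language model: next-word probabilities given only the last word.
BIGRAMS: Dict[str, Dict[str, float]] = {
    "the": {"cat": 0.6, "dog": 0.4},
    "cat": {"sat": 0.9, "<eos>": 0.1},
    "dog": {"barked": 0.7, "<eos>": 0.3},
    "sat": {"<eos>": 1.0},
    "barked": {"<eos>": 1.0},
}

def toy_next_word_probs(words: List[str]) -> Dict[str, float]:
    return BIGRAMS[words[-1]]

def greedy_generate(next_word_probs: Callable[[List[str]], Dict[str, float]],
                    prompt: List[str], max_new_words: int,
                    eos: Optional[str] = "<eos>") -> List[str]:
    words = list(prompt)
    for _ in range(max_new_words):
        probs = next_word_probs(words)          # "what is the next most likely word?"
        next_word = max(probs, key=probs.get)   # greedy: pick the single most likely word
        words.append(next_word)                 # spit it out, then loop with the longer text
        if next_word == eos:
            break
    return words

print(greedy_generate(toy_next_word_probs, ["the"], max_new_words=10))
# ['the', 'cat', 'sat', '<eos>']
```

In a real decoder-only transformer, `next_word_probs` would run the whole current text through the model and read the probabilities off its final position.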
