The following is a write-up of an (incomplete) project I worked on while at Anthropic, and a significant amount of the credit goes to the team at the time: Chris Olah, Catherine Olsson, Nelson Elhage & Tristan Hume. I've since cleaned up this project in my personal time and in a personal capacity.
Note: I've tried to make this post accessible and to convey intuitions, but it's a pretty technical post and likely only of interest if you care about mech interp and know what activation patching/causal tracing is.
Activation patching (aka causal tracing) is one of my favourite innovations in mechanistic interpretability techniques. The beauty of it is that it lets you set up a careful counterfactual between a clean input and a corrupted input (ideally identical apart from some key detail), and by patching specific activations from the clean run into the corrupted run, we find which activations are sufficient to flip the output from the corrupted answer to the clean answer. This is a targeted, causal intervention that can give you strong evidence about which parts of the model represent the concept in question - if a single activation is sufficient to change the entire model output, that's pretty strong evidence it matters! And you can simply iterate over every activation you care about to get insight into what's going on.
In practice, it and its variants have gotten pretty impressive results. But one practical problem with activation patching, when iterating over every activation, is the running cost - every single patch requires its own forward pass. This is fairly OK when working with small models or with fairly coarse patches (eg an entire residual stream at a single position), but it gets impractical fast if you want very fine-grained patches (eg between each pair of heads, specific neurons, etc) or want to work with large models.
This post introduces attribution patching, a technique that uses a gradient-based linear approximation to activation patching (note the very similar but different name!). Attribution patching lets you do every single patch you might want between a clean and corrupted input with just two forward passes and one backward pass (that is, every single patch can be calculated solely by caching results from the same 3 runs!). The key idea is to assume that the patch metric on the corrupted run is a locally linear function of the activations (keeping parameters fixed!), take the gradient of the patch metric with respect to each activation, and treat a patch of activation x as applying the change corrupted_x -> corrupted_x + (clean_x - corrupted_x), with estimated patch metric change (corrupted_grad_x * (clean_x - corrupted_x)).sum().
This post fleshes out the idea, and discusses whether it makes any sense. There is an accompanying notebook with code implementing attribution patching for the IOI circuit in GPT-2 Small and comparing it to activation patching in practice - if you prefer reading code, I endorse skipping this post and reading the notebook first.
The intended spirit of this post is to present an interesting technique, arguments for and against it, intuitions for where it should work well, an implementation to play around with yourself, and some empirical data. My argument is that attribution patching is a useful tool for exploratory analysis, narrowing down hypotheses, and figuring out the outlines of a circuit, and that it's a useful part of a mech interp toolkit. I do not want to argue that it is strictly better than activation patching, nor that it's perfectly reliable, and I think the most valuable use case will be to generate hypotheses that are then explored with more rigorous techniques like causal scrubbing.
Ideally I'd have much more empirical data to show, and I may follow up at some point, but I wanted to get something out to share the idea. I am not arguing that attribution patching is the only technique you need, that it's flawless, that all of my arguments and intuitions are correct, etc. But I think it's a useful thing to have in your toolkit!
*The point of this section is to give the gist + some context on activation patching. If you want a proper tutorial, check out my explainer and the relevant code in the exploratory analysis demo. I dig deeper into the foundational intuitions in a later section.*
The core idea is to set up a careful counterfactual between a clean prompt and a corrupted prompt, where the two differ in one key detail, and to set up a metric that captures the difference in this key detail. We then run the model on the corrupted prompt, patch in a single activation from the clean run, and check how much this flips the output from the corrupted answer to the clean answer. This activation can be as coarse- or fine-grained as we want, from the entire residual stream across all layers at a single position, to a specific neuron at a specific layer and position.
We want to reverse-engineer how GPT-2 Small does indirect object identification, and in particular to analyse the residual stream at different positions to see how the information about which name is the indirect object flows through the network. We focus on the clean prompt "When John and Mary went to the store, John gave the bag to", which is mapped to the clean answer " Mary". Our corrupted prompt is "When John and Mary went to the store, Mary gave the bag to" (corrupted answer " John"), and our patching metric is the logit difference final_logit[token=" Mary"] - final_logit[token=" John"].
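As a concrete sketch, here's roughly what this setup looks like in my TransformerLens library (variable names are my own, for illustration):

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

clean_prompt = "When John and Mary went to the store, John gave the bag to"
corrupted_prompt = "When John and Mary went to the store, Mary gave the bag to"
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

# Token ids of the two candidate answers (note the leading spaces)
mary_token = model.to_single_token(" Mary")
john_token = model.to_single_token(" John")

def logit_diff_metric(logits: torch.Tensor) -> torch.Tensor:
    """Patching metric: logit difference at the final position."""
    return logits[0, -1, mary_token] - logits[0, -1, john_token]
```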
We run the model on the clean prompt and cache all activations (the clean activations). We then do residual stream patching: for a specific layer L and position P, we run the model on the corrupted prompt; up until layer L it's unchanged, then at layer L we patch in the clean residual stream at position P, replacing the original residual stream. The run then continues as normal, and we record the patch metric (logit difference). We iterate over all layers L and positions P. (We patch at the start of each layer's residual stream, ie a layer L patch does not include the outputs of attention or MLP layer L.) A minimal sketch of this loop is below.
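Continuing the setup above (again, treat the details as illustrative rather than canonical):

```python
from functools import partial
from transformer_lens import utils

# Run on the clean prompt and cache every activation
_, clean_cache = model.run_with_cache(clean_tokens)

def patch_resid_pre(corrupted_resid, hook, pos):
    # Overwrite the corrupted residual stream with the clean one at a single position
    corrupted_resid[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return corrupted_resid

n_layers, n_pos = model.cfg.n_layers, corrupted_tokens.shape[1]
results = torch.zeros(n_layers, n_pos)
for layer in range(n_layers):
    hook_name = utils.get_act_name("resid_pre", layer)  # start-of-layer residual stream
    for pos in range(n_pos):
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(hook_name, partial(patch_resid_pre, pos=pos))],
        )
        # Note: one full forward pass per cell of this grid
        results[layer, pos] = logit_diff_metric(patched_logits).item()
```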
In the resulting plot, we see that in early layers the key information is at the second subject token (S2, value " John"), there's a transition at layers 7 and 8 (which we now know is driven by the S-Inhibition heads), and then the information moves to the final token.
For example, we know that the duplicate token heads operate on the S2 token to identify whether or not it's duplicated, and this feature is then moved to the final token via the S-Inhibition heads (in layers 7 & 8, as predicted). But the circuit could instead involve eg a "subject mover" head in layers 7 & 8 that moves the value of S2 to the final token, after which the model copies all prior names except names equal to the value of the copied subject, so the duplication analysis is done there. Both stories are equally consistent with the patching results!
A key observation is that we are patching a single activation from clean into corrupted. This means we're checking which activations are sufficient to contain the key information needed to reconstruct the clean answer. This is very different from checking which activations merely matter, or from patching specific corrupted activations into the clean run to see which most break things (essentially a form of ablation). IMO sufficiency is much stronger evidence than merely mattering, but I think both can give valuable info.
The key motivation behind attribution patching is to think of activation patching as a local change. We isolate a specific activation and patch in the clean version, which differs by a specific, small change in the input. If done right, this should be a pretty small change to the model's computation, and plausibly to the activation too! The argument is then that, given that we're making a small, local change, we should get about the same results if we apply this small change to a linear approximation of the model, taken at the corrupted prompt!
Intuitively, attribution patching takes a linear approximation to the model at the corrupted prompt, and calculates the effect of the local change of patching a single activation from corrupted to clean under that approximation.
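In symbols (notation introduced here for illustration): if $M(x)$ is the patching metric as a function of a single activation $x$, with all parameters and other inputs held fixed, then a first-order Taylor expansion around the corrupted activation gives

$$M(x_{\text{clean}}) - M(x_{\text{corrupted}}) \approx \nabla_x M\big|_{x = x_{\text{corrupted}}} \cdot (x_{\text{clean}} - x_{\text{corrupted}})$$

The right hand side is exactly (corrupted_grad_x * (clean_x - corrupted_x)).sum(), the attribution patch formula from above.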
To compute this, we take a backwards pass on the corrupted prompt with respect to the patching metric, and cache the gradients with respect to all activations. Importantly, we're doing something slightly unusual here: taking gradients with respect to activations, not parameters! Then, for a given activation, we compute ((clean_act - corrupted_act) * corrupted_grad_act).sum(), elementwise multiplication followed by a sum over the relevant dimensions. Every single patch can be computed after caching activations from two forward passes and one backwards pass.
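Here's a sketch of this caching step in TransformerLens, continuing the running example (I'm using the hooks context manager with backward hooks; treat the details as illustrative):

```python
corrupted_cache = {}
corrupted_grad_cache = {}

def save_act(act, hook):
    corrupted_cache[hook.name] = act.detach()

def save_grad(grad, hook):
    corrupted_grad_cache[hook.name] = grad.detach()

resid_filter = lambda name: name.endswith("resid_pre")
with model.hooks(fwd_hooks=[(resid_filter, save_act)], bwd_hooks=[(resid_filter, save_grad)]):
    corrupted_logits = model(corrupted_tokens)
    # Gradients with respect to activations, not parameters
    logit_diff_metric(corrupted_logits).backward()

# Every residual stream attribution patch at once, shape [n_layers, n_pos]
attr_results = torch.stack([
    ((clean_cache[utils.get_act_name("resid_pre", l)]
      - corrupted_cache[utils.get_act_name("resid_pre", l)])
     * corrupted_grad_cache[utils.get_act_name("resid_pre", l)]
    ).sum(dim=-1)[0]  # sum over d_model, drop the batch dim
    for l in range(n_layers)
])
```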
The main reason I think you should care about this is that attribution patching is really fast and scalable! Once you've done a clean forward pass, a corrupted forward pass, and a corrupted backward pass, the attribution patch for any activation is just ((clean_act - corrupted_act) * corrupted_grad_act).sum(). This is just elementwise multiplication, subtraction, and summing over some axes - no matmuls required! It's all very straightforward to code in my TransformerLens library, see the attached notebook. Activation patching, by contrast, needs a forward pass per patch. This makes it easy to do very fine-grained patching, eg of specific neurons at specific positions, which would be prohibitively expensive otherwise.
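For example, given caches gathered as above but with a filter that also matches the per-head output hooks (which requires calling model.set_use_attn_result(True) before running), every per-head, per-position patch in a layer is one line (a sketch):

```python
def attribution_patch(hook_name: str) -> torch.Tensor:
    # No forward pass needed: elementwise multiply and sum over d_model
    return (
        (clean_cache[hook_name] - corrupted_cache[hook_name])
        * corrupted_grad_cache[hook_name]
    ).sum(dim=-1)

# Per-head attribution patches for layer 9's outputs, shape [batch, pos, n_heads]
head_attr_l9 = attribution_patch(utils.get_act_name("result", 9))
```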
It's easy to make this comparison somewhat unfairly - there are lots of heuristics you could use to do activation patching more intelligently, eg patching in a head across all positions and only splitting the important ones by position, ditto for patching whole layers and then splitting into heads, etc. But I think the overall point stands! It's nice to be able to do fine-grained patching cheaply, and attribution patching is a very fast approximation. Further, we can decompose things even more finely, eg doing efficient direct path patching (described later) between all pairs of components, like each attention head's output and the Q, K and V inputs of each attention head in subsequent layers.