ungrokking, in which a network regresses from perfect to low test accuracy
Is this the same thing as catastrophic forgetting?
From page 6 of the paper:
Ungrokking can be seen as a special case of catastrophic forgetting (McCloskey and Cohen, 1989; Ratcliff, 1990), where we can make much more precise predictions. First, since ungrokking should only be expected once D < D_crit, if we vary D we predict that there will be a sharp transition from very strong to near-random test accuracy (around D_crit). Second, we predict that ungrokking would arise even if we only remove examples from the training dataset, whereas catastrophic forgetting typically involves training on new examples as well. Third, since D_crit does not depend on weight decay, we predict the amount of “forgetting” (i.e. the test accuracy at convergence) also does not depend on weight decay.
(All of these predictions are then confirmed in the experimental section.)
Which of these theories:
Recent answers to this question vary widely, including the difficulty of representation learning (Liu et al., 2022), the scale of parameters at initialisation (Liu et al., 2023), spikes in loss ("slingshots") (Thilak et al., 2022), random walks among optimal solutions (Millidge et al., 2022), and the simplicity of the generalising solution (Nanda et al., 2023, Appendix E).
can predict the same "four novel predictions about grokking" yours did? The relative likelihoods are what matters for updates after all.
Also, how does this theory explain other grokking-related phenomena, e.g. Omni-Grok? And how do things change as you increase parameter count? Scale matters, and I am not sure whether things like ungrokking would vanish with scale as catastrophic forgetting did. Or what about those various inverse scaling phenomena?
Which of these theories [...] can predict the same "four novel predictions about grokking" yours did? The relative likelihoods are what matters for updates after all.
I disagree with the implicit view on how science works. When you are a computationally bounded reasoner, you work with partial hypotheses, i.e. hypotheses that only make predictions on a small subset of possible questions, and just shrug at other questions. This is mostly what happens with the other theories:
how does this theory explain other grokking-related phenomena, e.g. Omni-Grok?
My speculation for Omni-Grok in particular is that in settings like MNIST you already have two of the ingredients for grokking (that there are both memorising and generalising solutions, and that the generalising solution is more efficient), and then having large parameter norms at initialisation provides the third ingredient (generalising solutions are learned more slowly), for some reason I still don't know.
Happy to speculate on other grokking phenomena as well (though I don't think there are many others?)
And how do things change as you increase parameter count?
We haven't investigated this, but I'd pretty strongly predict that there mostly aren't major qualitative changes. (The one exception is semi-grokking; there's a theoretical reason to expect it may sometimes not occur, and also in practice it can be quite hard to elicit.)
I expect there would be quantitative changes (e.g. maybe the value of D_crit changes, maybe the time taken to learn C_gen changes). Sufficiently big changes in D_crit might mean you don't see the phenomena on modular addition any more, but I'd still expect to see them in more complicated tasks that exhibit grokking.
I'd be interested in investigations that got into these quantitative questions (in addition to the above, there's also things like "quantitatively, how does the strength of weight decay affect the time for C_gen to be learned?", and many more).
My speculation for Omni-Grok in particular is that in settings like MNIST you already have two of the ingredients for grokking (that there are both memorising and generalising solutions, and that the generalising solution is more efficient), and then having large parameter norms at initialisation provides the third ingredient (generalising solutions are learned more slowly), for some reason I still don't know.
Higher weight norm means lower effective learning rate with Adam, no? In that paper they used a constant learning rate across weight norms, but Adam tries to normalize the gradients to be of size 1 per parameter, regardless of the size of the weights. So the weights change more slowly with larger initializations (especially since they constrain the weights to be of fixed norm by projecting after the Adam step).
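A minimal sketch of the effective-learning-rate point, assuming plain Adam (Kingma and Ba) with its usual default betas and no weight decay or norm projection. On the first step from zero moment estimates, bias correction makes Adam move each parameter by almost exactly `lr`, opposite to the sign of its gradient, whatever the gradient's magnitude. So the relative change of a weight vector shrinks as its norm grows. The function names are hypothetical, and this does not model the actual Omni-Grok training setup.

```python
import math


def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One standard Adam update for a single scalar parameter (no weight decay)."""
    m = beta1 * m + (1 - beta1) * g          # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * g * g      # second-moment estimate
    m_hat = m / (1 - beta1 ** t)             # bias corrections
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v


def relative_update(weights, grads, lr=1e-3):
    """||w_new - w|| / ||w|| after one Adam step from zero moment estimates."""
    new = [adam_step(w, g, 0.0, 0.0, t=1, lr=lr)[0] for w, g in zip(weights, grads)]
    delta = math.sqrt(sum((a - b) ** 2 for a, b in zip(new, weights)))
    norm = math.sqrt(sum(w * w for w in weights))
    return delta / norm


if __name__ == "__main__":
    n = 100
    small = relative_update([1.0] * n, [0.5] * n)    # unit-scale init
    large = relative_update([8.0] * n, [0.01] * n)   # 8x larger init, much smaller grads
    # The ratio is about 8: with an 8x larger init, weights change 8x more slowly
    # in relative terms, even though the gradients here are 50x smaller.
    print(small, large, small / large)
```

This only shows the first step; in steady state Adam's per-parameter step is still on the order of `lr`, so the same scaling argument roughly applies.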
Sounds plausible, but why does this differentially impact the generalizing algorithm over the memorizing algorithm?
Perhaps under normal circumstances both are learned so fast that you just don't notice that one is slower than the other, and this slows both of them down enough that you can see the difference?
I disagree with the implicit view on how science works. When you are a computationally bounded reasoner, you work with partial hypotheses, i.e. hypotheses that only make predictions on a small subset of possible questions, and just shrug at other questions.
Implicitly, I thought that if you have a partial hypothesis of grokking and it shrugs at a grokking-related phenomenon, it should be penalized. Unless by "shrugs" you mean the details of what the partial hypothesis says in this particular case are still being worked out. But in that case, confirming that the partial hypothesis doesn't yet say anything about some phenomenon is still useful info. I'm fairly sure this belief was what generated my question.
This is mostly what happens with the other theories
Thank you for going through the theories and checking what they have to say. That was helpful to me.
I'd be interested in investigations that got into these quantitative questions
Do you have any plans to do this? How much time do you think it would take? And do you have any predictions for what should happen in these cases?
Unless by "shrugs" you mean the details of what the partial hypothesis says in this particular case are still being worked out.
Yes, that's what I mean.
I do agree that it's useful to know whether a partial hypothesis says anything or not; overall I think this is good info to know / ask for. I think I came off as disagreeing more strongly than I actually did, sorry about that.
Do you have any plans to do this?
No, we're moving on to other work: this took longer than we expected, and was less useful for alignment than we hoped (though that part wasn't that unexpected, from the start we expected "science of deep learning" to be more hits-based, or to require significant progress before it actually became useful for practical proposals).
How much time do you think it would take?
Actually running the experiments should be pretty straightforward, I'd expect we could do them in a week given our codebase, possibly even a day. Others might take some time to set up a good codebase but I'd still be surprised if it took a strong engineer longer than two weeks to get some initial results. This gets you observations like "under the particular settings we chose, D_crit tends to increase / decrease as the number of layers increases".
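As an illustration of turning such a sweep into a number, one simple proxy (an assumption of this sketch, not the codebase mentioned above) is to take the dataset size at which final test accuracy first rises past a threshold, interpolating linearly in log dataset size. The helper name and threshold are hypothetical, and the usage below runs on a synthetic sigmoid, not on experimental results.

```python
import math


def estimate_d_crit(sizes, test_accs, threshold=0.5):
    """Interpolated dataset size where test accuracy first rises past `threshold`.

    `sizes` must be sorted ascending; interpolation is linear in log(size).
    Returns None if the accuracy never crosses the threshold.
    """
    points = list(zip(sizes, test_accs))
    for (d0, a0), (d1, a1) in zip(points, points[1:]):
        if a0 < threshold <= a1:
            t = (threshold - a0) / (a1 - a0)
            return math.exp(math.log(d0) + t * (math.log(d1) - math.log(d0)))
    return None


if __name__ == "__main__":
    # Synthetic accuracies from a sigmoid in log D centred at D = 1000.
    sizes = [250, 500, 2000, 4000]
    accs = [1 / (1 + math.exp(-8 * math.log(d / 1000))) for d in sizes]
    print(estimate_d_crit(sizes, accs))  # approximately 1000
```

Comparing this estimate across, say, different layer counts is exactly where the confounds described next come in.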
The hard part is then interpreting those results and turning them into something more generalizable -- including handling confounds. For example, maybe for some reason the principled thing to do is to reduce the learning rate as you increase layers, and once you do that your observation reverses -- this is a totally made up example but illustrates the kind of annoying things that come up when doing this sort of research, that prevent you from saying anything general. I don't know how long it would take if you want to include that; it could be quite a while (e.g. months or years).
And do you have any predictions for what should happen in these cases?
Not really. I've learned from experience not to try to make quantitative predictions yet. We tried to make some theory-inspired quantitative predictions in the settings we studied, and they fell pretty flat.
For example, in our minimal model in Section 3 we have a hyperparameter that determines how param norm and logits scale together -- initially, that was our guess of what would happen in practice (i.e. we expected circuit param norm <> circuit logits to obey a power law relationship in actual grokking settings). But basically every piece of evidence we got seemed to falsify that hypothesis (e.g. Figure 3 in the paper).
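To make the power-law hypothesis concrete: if logits scale as c·(param norm)^alpha, then log(logit) is linear in log(norm), and alpha can be estimated with an ordinary least-squares fit on log-log axes. The sketch below is a hypothetical helper, not the code used in the paper, and it is checked only on synthetic data generated from an exact power law. With real measurements, a poor R² would be one symptom of the power law failing, though confounders can also distort the fit.

```python
import math


def fit_power_law(norms, logits):
    """Fit logit = c * norm**alpha by least squares in log-log space.

    Returns (alpha, c, r2). All inputs must be positive.
    """
    xs = [math.log(x) for x in norms]
    ys = [math.log(y) for y in logits]
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    alpha = sxy / sxx
    log_c = my - alpha * mx
    # R^2 of the log-log fit: how much of the variation a power law explains.
    ss_res = sum((y - (alpha * x + log_c)) ** 2 for x, y in zip(xs, ys))
    ss_tot = sum((y - my) ** 2 for y in ys)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else 1.0
    return alpha, math.exp(log_c), r2


if __name__ == "__main__":
    norms = [0.5, 1.0, 2.0, 4.0, 8.0]
    logits = [2.0 * x ** 3 for x in norms]  # synthetic, exact power law
    print(fit_power_law(norms, logits))      # approximately (3.0, 2.0, 1.0)
```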
(I say "seemed to falsify" because it's still possible that we're just failing to deal with confounders in some way, or measuring something that isn't exactly what we want to measure. For example, Figure 3 logits are not of the Mem circuit in actual grokking setups, but rather the logits produced by networks trained on random labels -- maybe there's a relevant difference between these.)
See also this post by Quintin Pope: https://www.lesswrong.com/posts/JFibrXBewkSDmixuo/hypothesis-gradient-descent-prefers-general-circuits
I think that post has a lot of good ideas, e.g. the idea that generalizing circuits get reinforced by SGD more than memorizing circuits at least rhymes with what we claim is actually going on (that generalizing circuits are more efficient at producing strong logits with small param norm). We probably should have cited it, I forgot that it existed.
But it is ultimately a different take and one that I think ends up being wrong (e.g. I think it would struggle to explain semi-grokking).
I also think my early explanation, which that post compares to, is basically as good or better in hindsight, e.g.:
This is a linkpost for our paper Explaining grokking through circuit efficiency, which provides a general theory explaining when and why grokking (aka delayed generalisation) occurs, and makes several interesting and novel predictions which we experimentally confirm (introduction copied below). You might also enjoy our explainer on X/Twitter.
Abstract
One of the most surprising puzzles in neural network generalisation is grokking: a network with perfect training accuracy but poor generalisation will, upon further training, transition to perfect generalisation. We propose that grokking occurs when the task admits a generalising solution and a memorising solution, where the generalising solution is slower to learn but more efficient, producing larger logits with the same parameter norm. We hypothesise that memorising circuits become more inefficient with larger training datasets while generalising circuits do not, suggesting there is a critical dataset size at which memorisation and generalisation are equally efficient. We make and confirm four novel predictions about grokking, providing significant evidence in favour of our explanation. Most strikingly, we demonstrate two novel and surprising behaviours: ungrokking, in which a network regresses from perfect to low test accuracy, and semi-grokking, in which a network shows delayed generalisation to partial rather than perfect test accuracy.
Introduction
When training a neural network, we expect that once training loss converges to a low value, the network will no longer change much. Power et al. (2021) discovered a phenomenon dubbed grokking that drastically violates this expectation. The network first "memorises" the data, achieving low and stable training loss with poor generalisation, but with further training transitions to perfect generalisation. We are left with the question: why does the network's test performance improve dramatically upon continued training, having already achieved nearly perfect training performance?
Recent answers to this question vary widely, including the difficulty of representation learning (Liu et al., 2022), the scale of parameters at initialisation (Liu et al., 2023), spikes in loss ("slingshots") (Thilak et al., 2022), random walks among optimal solutions (Millidge et al., 2022), and the simplicity of the generalising solution (Nanda et al., 2023, Appendix E). In this paper, we argue that the last explanation is correct, by stating a specific theory in this genre, deriving novel predictions from the theory, and confirming the predictions empirically.
We analyse the interplay between the internal mechanisms that the neural network uses to calculate the outputs, which we loosely call "circuits" (Olah et al., 2020). We hypothesise that there are two families of circuits that both achieve good training performance: one which generalises well (Cgen) and one which memorises the training dataset (Cmem). The key insight is that when there are multiple circuits that achieve strong training performance, weight decay prefers circuits with high "efficiency", that is, circuits that require less parameter norm to produce a given logit value.
Efficiency answers our question above: if Cgen is more efficient than Cmem, gradient descent can reduce nearly perfect training loss even further by strengthening Cgen while weakening Cmem, which then leads to a transition in test performance. With this understanding, we demonstrate in Section 3 that three key properties are sufficient for grokking: (1) Cgen generalises well while Cmem does not, (2) Cgen is more efficient than Cmem, and (3) Cgen is learned more slowly than Cmem.
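The three properties can be illustrated with a deliberately tiny toy, which is an assumption of this sketch and not the minimal model from Section 3 of the paper. Two scalar "circuit strengths", a (generalising) and b (memorising), jointly produce a correct-class logit z = e_gen·a² + e_mem·b², trained with logistic loss plus L2 weight decay by plain gradient descent. Setting e_gen > e_mem makes the generalising circuit more efficient, and starting a much smaller than b stands in for it being learned more slowly. All names and constants are hypothetical.

```python
import math


def simulate(e_gen=2.0, e_mem=1.0, wd=0.01, lr=0.1, a0=1e-3, b0=1.0, steps=20000):
    """Gradient descent on log(1 + exp(-z)) + wd/2 * (a^2 + b^2).

    z = e_gen * a^2 + e_mem * b^2 is the correct-class logit; e_gen > e_mem
    means the generalising circuit gets more logit per unit of squared norm.
    Returns the (a, b) trajectory.
    """
    a, b = a0, b0
    history = []
    for _ in range(steps):
        z = e_gen * a * a + e_mem * b * b
        s = 1.0 / (1.0 + math.exp(z))      # -dLoss/dz for logistic loss
        grad_a = -s * 2.0 * e_gen * a + wd * a
        grad_b = -s * 2.0 * e_mem * b + wd * b
        a -= lr * grad_a
        b -= lr * grad_b
        history.append((a, b))
    return history


if __name__ == "__main__":
    hist = simulate()
    for step in (100, 2000, 8000, 19999):
        a, b = hist[step]
        print(f"step {step:5d}: gen={a:.4f} mem={b:.4f}")
```

The memorising strength b dominates early; weight decay then slowly hands the logit over to a, and b decays toward zero. Squaring the strengths is a design choice: with logits linear in each parameter, weight decay would keep both circuits in proportion to their efficiencies instead of driving the less efficient one toward zero.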
Since Cgen generalises well, it automatically works for any new data points that are added to the training dataset, and so its efficiency should be independent of the size of the training dataset. In contrast, Cmem must memorise any additional data points added to the training dataset, and so its efficiency should decrease as training dataset size increases. We validate these predictions by quantifying efficiencies for various dataset sizes for both Cmem and Cgen.
This suggests that there exists a crossover point at which Cgen becomes more efficient than Cmem, which we call the critical dataset size Dcrit. By analysing dynamics at Dcrit, we predict and demonstrate two new behaviours (Figure 1). In ungrokking, a model that has successfully grokked returns to poor test accuracy when further trained on a dataset much smaller than Dcrit. In semi-grokking, we choose a dataset size where Cgen and Cmem are similarly efficient, leading to a phase transition but only to middling test accuracy.
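The crossover argument amounts to a root-finding problem: D_crit is the dataset size at which the two efficiency curves meet. The sketch below assumes, purely for illustration, a constant efficiency for C_gen and a memorising efficiency that falls off as a power of D; the paper instead quantifies these efficiencies empirically for various dataset sizes. The function names, constants and the width of the "similar efficiency" band are hypothetical.

```python
def find_d_crit(eff_gen, eff_mem, lo, hi, tol=1e-6):
    """Bisection for the D where eff_mem(D) == eff_gen(D).

    Assumes eff_mem - eff_gen is continuous and changes sign once on [lo, hi].
    """
    def gap(d):
        return eff_mem(d) - eff_gen(d)

    if gap(lo) * gap(hi) > 0:
        raise ValueError("efficiency curves do not cross on [lo, hi]")
    while hi - lo > tol * max(1.0, lo):
        mid = 0.5 * (lo + hi)
        if gap(lo) * gap(mid) <= 0:
            hi = mid   # crossing lies in [lo, mid]
        else:
            lo = mid   # crossing lies in [mid, hi]
    return 0.5 * (lo + hi)


def regime(d, d_crit, band=0.1):
    """Rough label for a training-set size; `band` is an arbitrary illustrative width."""
    if d < (1 - band) * d_crit:
        return "C_mem more efficient: a grokked network trained here should ungrok"
    if d > (1 + band) * d_crit:
        return "C_gen more efficient: grokking possible"
    return "similar efficiency: semi-grokking possible"


if __name__ == "__main__":
    def eff_gen(d):
        return 1.0                 # independent of dataset size

    def eff_mem(d):
        return 50.0 / d ** 0.5     # hypothetical decay with D

    d_crit = find_d_crit(eff_gen, eff_mem, lo=1.0, hi=1e6)
    print(d_crit)  # close to 2500, since 50 / sqrt(2500) == 1
    for d in (500, 2500, 10000):
        print(d, regime(d, d_crit))
```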
We make the following contributions: