Imagine that you’re looking for buried treasure worth a billion dollars on a large desert island. You don’t have a map, but a mysterious hermit offers you a box with a button to help find the treasure. Each time you press the button, it will tell you either “warmer” or “colder”. But there’s a catch. With a minuscule probability, the box will tell you the truth about whether you’re closer than you were the last time you pressed. But with the remaining probability of .9999999999999999999999999999992, the box will make a random guess between “warmer” and “colder”. Should you pay $1 for this box?

Keep this in mind as we discuss the closely related problem of parity learning.

In my experience of interacting with the ML and interpretability communities, the majority of people don’t know about the impossibility result for the parity learning problem. Those who do often assume that it is a baroque, complicated result that surely doesn’t have a simple proof. (Another surprising opinion I’ve heard comes from people who know about the result but say that “there’s some new architecture that seems to solve it, actually”, which is somewhat indicative of the ML community’s trust in the concept of “proof”.)

Recently I was pleasantly surprised to realize that the impossibility of solving this problem (via a gradient-based learning algorithm in polynomial time) actually admits a pretty nice and understandable proof. The gist of it boils down to the silly “pirate treasure” story above: the answer, of course, is that you shouldn’t buy the box (at least if you’re trying to maximize your expected income), and for the same reason you can’t build a cool new architecture that solves the parity learning problem.

In this post I’ll briefly explain the problem, how it’s different from some other impossibility results, and why I think it’s important. This will later tie in to a series of posts about the insufficiency of Bayesian methods in (realistic) ML contexts.

What is the parity learning problem? 

This is a post where a bit of math can go a long way, but I’ll try to make it approachable to anyone with either a bit of math or a bit of ML background. First, the XOR parity function is a function $f_S: \{0,1\}^n \to \{0,1\}$ from a length-n boolean input to a single boolean output. This function depends on a “secret” variable S, which is a subset $S \subseteq [n]$ (here $[n] = \{1, \dots, n\}$ is the standard set on n elements). On an input vector $v = (v_1, \dots, v_n)$ the function outputs $f_S(v) = \left(\sum_{i \in S} v_i\right) \bmod 2$. In other words, you look at the bits of $v$ at the indices in $S$, take their sum, and take its parity. Note (for people with some math background) that the value can be written more nicely as $f_S(v) = v \cdot s$, where we replace the set $S$ by the vector $s \in \{0,1\}^n$ with coefficient $1$ at indices $i \in S$ and $0$ otherwise. The vectors can then be interpreted as having coefficients in the field $\mathbb{F}_2$ of two elements, and $v \cdot s$ denotes their dot product in this field (i.e., the sum of products of coordinates, taken mod 2).
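As a quick sanity check of the two descriptions above, here is a minimal Python/NumPy sketch. The function names are illustrative, and indices are 0-based rather than running over $[n] = \{1, \dots, n\}$. It confirms on random inputs that “sum the bits indexed by $S$, then take parity” agrees with the $\mathbb{F}_2$ dot product $v \cdot s$.

```python
import numpy as np

def parity_from_subset(v, S):
    """f_S(v): add up the bits of v at the indices in S, then reduce mod 2."""
    return sum(v[i] for i in S) % 2

def parity_from_dot(v, s):
    """The same function as the F_2 dot product v . s, where s is the 0/1 indicator vector of S."""
    return int(np.dot(v, s)) % 2

n = 8
S = {1, 4, 6}                                   # illustrative secret subset (0-indexed)
s = np.array([1 if i in S else 0 for i in range(n)])

rng = np.random.default_rng(0)
for _ in range(100):
    v = rng.integers(0, 2, size=n)              # a random boolean input vector
    assert parity_from_subset(v, S) == parity_from_dot(v, s)
```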

The impossibility result says that it’s not possible to solve the parity problem as a learning problem in polynomial time. This statement should be interpreted carefully. First, note that we haven’t defined what a learning algorithm is. A special case of a learning algorithm is any (polynomial-complexity) weight-based ML architecture that learns via SGD on cross-entropy loss (together with any choice of batch size, update step, initialization protocol, etc.). We will take this as our definition of “learnability” for the sake of this post, though later I’ll point out that our proof also shows that a much larger class of methods is incapable of solving parity. (On the other hand, as we’ll see, an undergraduate with a few hours to spare can solve the parity problem in polynomial time, and with a bit more time can even hand-select weights in an ML architecture to execute their solution.) The second thing to be careful of is that, for any fixed choice of “hidden” subset S, it is possible to design an algorithm that learns the parity problem. Indeed, you can simply initialize the architecture to the “right solution”. So it’s important to conceptualize S here as secret or random.

More concretely, the problem can be conceptualized as a game between two players A and B. Player A randomly chooses a secret subset $S \subseteq [n]$, where $[n] = \{1, \dots, n\}$ is the standard set on n elements (there are $2^n$ subsets, including all of [n] and the empty set, so each is chosen with probability $2^{-n}$). Player B commits to a (polynomial-sized) learning algorithm M, which for us means an architecture, initialization scheme, and class of hyperparameters for a gradient-based learning scheme like SGD. Player A then randomly generates a number of sample boolean vectors $v_1, \dots, v_N \in \{0,1\}^n$, with N some agreed-upon number that depends at most polynomially on n, and reveals them together with their parities $f_S(v_1), \dots, f_S(v_N)$. Player B then trains the learning algorithm for N’ steps, where N’ is another (large) number that is nevertheless polynomially bounded in n.
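One way to picture player A’s side of the game is the following sketch (Python/NumPy). The function name `player_a` and the sizes `n=16`, `N=32` are hypothetical, chosen only for illustration. Drawing the secret as a uniformly random indicator vector $s$ makes every one of the $2^n$ subsets equally likely.

```python
import numpy as np

def player_a(n, N, rng):
    """Player A: pick a uniformly random secret subset S of [n], then publish N labelled samples."""
    s = rng.integers(0, 2, size=n)          # indicator vector of S; each subset has probability 2^-n
    V = rng.integers(0, 2, size=(N, n))     # N uniformly random boolean input vectors, one per row
    y = (V @ s) % 2                         # labels f_S(v_i): parity of the bits of v_i indexed by S
    return s, V, y

rng = np.random.default_rng(42)
secret, samples, labels = player_a(n=16, N=32, rng=rng)
# Player B only ever sees (samples, labels); `secret` stays hidden.
```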

The theorem then says that no matter which learning algorithm player B chose, the probability that the setup will learn an algorithm with >51% accuracy is effectively zero (i.e., it’s exponentially small in n).

This “hidden guessing” game is about P vs. NP, isn’t it?

This is another common confusion, and the answer is no. Most theoretical computer scientists believe that $\mathrm{P} \neq \mathrm{NP}$, and assuming this[1] gives another source of unsolvable learning problems. Indeed, if you were to give player A in the above game more freedom, and allow them to write down any (suitably) randomly chosen circuit (or suitable random ML algorithm) for their “target” function, you get another impossibility result. Here the assumption $\mathrm{NP} \not\subseteq \mathrm{BPP}$ implies that there is no way, in polynomial time, to get reliable information about A’s secret circuit C (beyond some statistical regularities) from looking at polynomially many input-output samples of C. In particular, there is no way to guarantee sufficiently accurate behavior of the result of a learning algorithm, since a learning algorithm is a special case of a (probabilistic) polynomial-time algorithm. But the XOR impossibility is a much more satisfying result. It doesn’t require any assumptions about P vs. NP (it is true mathematically and unconditionally), and, even more nicely, there does exist a (probabilistic) polynomial-time algorithm that solves it. In other words, we have the following containments (where I’m being sloppy about exactly what an “algorithm” is):

$$\mathrm{NP} \supseteq \mathrm{BPP} \supseteq \{\text{learnable}\}$$

In the case of the “XOR parity” problem, the result shows (without making any assumptions about the first containment, i.e., about P vs. NP) that the second containment is proper: there are polynomially invertible algorithms which are impossible to execute as learning algorithms.

To convince you that this isn’t some deep hidden knowledge, I’ll explain in the following section how an undergrad with a semester of abstract algebra can solve the XOR parity problem. Since knowing the specific solution isn’t critical, I’ll assume a bit of abstract algebra here, and people without the necessary context can safely skip the following section. To be clear about its result, before going on I’ll write the upshot:

Upshot of the following section

In the setup as above, with player A having a hidden subset $S$ and player B receiving $N$ samples of the input-output behavior of the parity function $f_S$ for boolean inputs $v_1, \dots, v_N$, it is possible for B to recover the subset $S$ (and thus the function $f_S$) in polynomial time N’, with overwhelming probability.

In fact, it’s sufficient to look at N = 2n samples (and the running time of the solution algorithm is cubic or better in n).

Polynomial-time non “learning-algorithmic” solution

For this section I’m assuming some linear algebra over finite fields; if this isn’t your jam, skip to the next section.

The basic idea is to replace the function $f_S(v)$ with the dot product $v \cdot s$ over $\mathbb{F}_2$. The random samples can then be understood as an overdetermined system of linear equations over the field with two elements. Namely, given our random samples $v_1, \dots, v_N$, we write down an $N \times n$ matrix $V$ whose length-$n$ rows are $v_1, \dots, v_N$. Because the samples are chosen randomly, these are uniformly random boolean vectors. It is a standard fact that $N \geq n$ uniformly random vectors in $\mathbb{F}_2^n$ fail to span $\mathbb{F}_2^n$ with probability at most $2^{n-N}$ (each of the $2^n - 1$ hyperplanes contains all $N$ vectors with probability $2^{-N}$). As soon as $N \geq 2n$, the probability that they fail to span is at most $2^{-n}$, i.e., negligible (in fact, in the formal sense of going to zero faster than any inverse polynomial).
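The spanning bound can be checked numerically. The sketch below (plain Python, illustrative only; `prob_fail_to_span` is a name made up for this example) computes the exact probability that $N \geq n$ uniformly random vectors fail to span $\mathbb{F}_2^n$, using the standard product formula $\prod_{i=0}^{n-1}(1 - 2^{i-N})$ for the probability that a random $N \times n$ matrix over $\mathbb{F}_2$ has full rank, and compares it with the union bound $2^{n-N}$.

```python
import math

def prob_fail_to_span(n, N):
    """Exact probability that N uniform random vectors in F_2^n fail to span F_2^n (requires N >= n).

    P(span) = prod_{i=0}^{n-1} (1 - 2^(i - N)); we sum logs and use expm1 so that
    tiny failure probabilities are not rounded away to zero.
    """
    assert N >= n, "fewer than n vectors can never span F_2^n"
    log_p_span = sum(math.log1p(-2.0 ** (i - N)) for i in range(n))
    return -math.expm1(log_p_span)           # 1 - P(span), computed stably

for n in (10, 20, 50):
    N = 2 * n
    print(f"n={n}, N={N}: exact failure {prob_fail_to_span(n, N):.3e}, union bound {2.0 ** (n - N):.3e}")
```

For $N = n$ the success probability instead tends to the constant $\prod_{k \geq 1}(1 - 2^{-k}) \approx 0.289$, which is why a few extra samples are needed; taking $N = 2n$ pushes the failure probability below $2^{-n}$.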

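Thus we can safely assume that the sample input vectors span $\mathbb{F}_2^n$, i.e., that some $n$ of them are linearly independent. Now in the setup of the problem, A gave us both the vectors $v_i$ and the values $y_i = f_S(v_i)$. We can convert this to a system of linear equations (over $\mathbb{F}_2$) on the secret vector $s$. Namely, we have

$$V s = y,$$

where $y = (y_1, \dots, y_N)$ is the vector of parities. Now the fact that the $v_i$ span $\mathbb{F}_2^n$ implies that $s$ is reconstructible in polynomial time, e.g., by Gaussian elimination. For example, if $N = 2n$ and already the first $n$ sample inputs are linearly independent, then we can write $s = V_n^{-1} y_{[n]}$, where $V_n$ is the top $n \times n$ block of $V$ and $y_{[n]} = (y_1, \dots, y_n)$, and inverting an $n \times n$ (boolean) matrix is doable in cubic time.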

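Before going on, note that once we’ve correctly guessed the secret subset $S$, we can write down a straightforward MLP that executes the parity XOR function $f_S$. Indeed, assume $S$ has $k$ elements (with $k \leq n$ since $S$ is a subset of an $n$-element set) and let $i_1 < i_2 < \dots < i_k$ be the elements of $S$, in order. Then we can recursively write $f_S(v)$ as

$$f_S(v) = \Big(\cdots\big((v_{i_1} \oplus v_{i_2}) \oplus v_{i_3}\big) \oplus \cdots\Big) \oplus v_{i_k}.$$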
Now the XOR of two boolean inputs is straightforward to write down as a one-hidden-layer MLP (whether using ReLU or another standard nonlinear activation function), and appropriately stacking $k - 1$ of them together gives a polynomial-size neural net that executes our hidden function (in fact, utilizing some parallelization allows this to be done in $O(\log k)$ layers).
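The XOR construction can be made explicit. In this sketch (Python/NumPy; names are illustrative), one XOR block is a one-hidden-layer ReLU network with two hidden units, $\mathrm{XOR}(a,b) = \mathrm{ReLU}(a+b) - 2\,\mathrm{ReLU}(a+b-1)$, and the blocks are arranged in a balanced tree, giving $\lceil \log_2 k \rceil$ XOR layers for $|S| = k \geq 1$. This gadget is just one convenient choice among many.

```python
import itertools
import numpy as np

def relu(x):
    return np.maximum(x, 0)

def xor_mlp(a, b):
    """One-hidden-layer ReLU MLP computing XOR on 0/1 inputs.

    Hidden units h1 = relu(a + b), h2 = relu(a + b - 1); linear output h1 - 2*h2.
    (0,0) -> 0, (0,1) -> 1, (1,0) -> 1, (1,1) -> 2 - 2 = 0.
    """
    return relu(a + b) - 2 * relu(a + b - 1)

def parity_net(V, S):
    """Evaluate f_S on every row of V using a balanced tree of XOR blocks.

    Returns (outputs, depth), where depth = number of XOR layers = ceil(log2 |S|) for |S| >= 1.
    """
    wires = [V[:, i] for i in sorted(S)]
    if not wires:
        return np.zeros(V.shape[0], dtype=V.dtype), 0   # empty S: the constant-zero function
    depth = 0
    while len(wires) > 1:
        nxt = [xor_mlp(wires[j], wires[j + 1]) for j in range(0, len(wires) - 1, 2)]
        if len(wires) % 2 == 1:
            nxt.append(wires[-1])        # an unpaired wire passes through (ReLU is the identity on 0/1)
        wires = nxt
        depth += 1
    return wires[0], depth

n = 6
S = {0, 2, 3, 5}
V = np.array(list(itertools.product([0, 1], repeat=n)))   # all 2^n boolean inputs
out, depth = parity_net(V, S)
assert np.array_equal(out, V[:, sorted(S)].sum(axis=1) % 2)
print("XOR layers used:", depth)
```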

Handwavy proof of non-learnability

Welcome back to non-math people. This section is also slightly technical and can be skipped by people who don't care about understanding the proof, but it doesn't require any abstract algebra background.

At the end of the day, we have a function that we’re claiming can be represented by a (polynomial-sized) neural net, but cannot be learned in polynomial time as such. How can one go about showing this? The important bits of information to collect here are the following:

  1. The functions $f_S$ for different subsets $S$ form a basis[2] of all (real-valued) functions on boolean inputs. (Up to rescaling and subtracting a constant, this is also called the “Fourier basis” of functions on boolean inputs.)
  2. Any two functions $f_S$ and $f_T$ for two different subsets $S \neq T$ are uncorrelated on the set of all inputs $\{0,1\}^n$[3]. This is key: even if S and T differ in only one element $i$, the half of input vectors that have $v_i = 1$ will have different parities on S and T.
  3. The randomness in the choice of N samples leads to noise in the updates, on the order of 1/polynomial(N).
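Fact 2 and footnote 3 can be seen directly for a small case. The sketch below (Python/NumPy, with an arbitrary $n = 6$ and a training set of 200 random inputs, both assumptions for illustration) uses the $\pm 1$ versions $\chi_S(v) = (-1)^{f_S(v)}$. Averaged over all $2^n$ inputs, distinct parities have correlation exactly zero; on a fixed random training set the off-diagonal correlations become nonzero, but stay small.

```python
import itertools
import numpy as np

n = 6
inputs = np.array(list(itertools.product([0, 1], repeat=n)))    # all 2^n boolean inputs v
subsets = np.array(list(itertools.product([0, 1], repeat=n)))   # indicator vectors s of all 2^n subsets S

# chi[S, v] = (-1)^{f_S(v)}: the +-1 ("Fourier") version of the parity functions
chi = (-1) ** ((subsets @ inputs.T) % 2)

# Over ALL inputs: the average of chi_S * chi_T is 1 if S == T and exactly 0 otherwise
full_corr = chi @ chi.T / 2 ** n
assert np.array_equal(full_corr, np.eye(2 ** n))

# Over a fixed random training set of N inputs, off-diagonal correlations are no longer zero,
# but they are small (typically on the order of 1/sqrt(N))
rng = np.random.default_rng(0)
N = 200
train = rng.integers(0, 2, size=(N, n))
chi_train = (-1) ** ((subsets @ train.T) % 2)
train_corr = chi_train @ chi_train.T / N
off_diag = train_corr[~np.eye(2 ** n, dtype=bool)]
print("max |off-diagonal correlation| on the training set:", np.abs(off_diag).max())
```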

From these facts we see that:

  • The gradient update can be decomposed into components associated to each subset S.
  • The coefficient of the gradient update along each incorrect direction T can be modeled as a random variable, and is comparable in size (up to a polynomial multiple) to the update in the “true” direction S.

Of course, a priori the argument above assumes that the space of possible functions f_w(v) associated to possible weight parameter values coincides with the $2^n$-dimensional space of all possible functions on boolean inputs. Since we’ve assumed that the number of parameters is polynomial in $n$, this isn’t the case: rather, the vector space of possible gradient updates is constrained to lie in some $\mathrm{poly}(n)$-dimensional subspace of the $2^n$-dimensional space of functions with boolean inputs. Equivalently, all the above possible update directions are projected to some low-dimensional subspace (and suitably normalized).

At the end of the day, we can model the gradient update as a noise vector of some fixed size (that is inverse-polynomial in n, and is associated to the randomness of drawing random inputs), plus a projection of a “signal” vector associated to $S$ onto some poly(n)-dimensional subspace. Now standard considerations of high-dimensional projections imply that the “signal” vector might have significant size for some small (polynomial in n) number of “special” subsets S, but for the vast majority of choices of S it will be suppressed by a massive factor, on the order of the square root of the dimension, $\sqrt{2^n} = 2^{n/2}$ (up to polynomial factors). It will then be completely drowned out by the noise, even after polynomially many update steps. Thus the problem of gradient updating to the correct parity algorithm boils down (more or less) to the problem of the pirate-treasure hunter with the very unreliable box.
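The “suppressed by the square root of the dimension” step rests on a counting identity that a short sketch can illustrate. The normalized parities $\chi_S / 2^{n/2}$ form an orthonormal basis, so for any $d$-dimensional subspace their squared projections sum to exactly $d$: on average a parity keeps only a $d / 2^n$ fraction of its squared length, and by Markov’s inequality only a small fraction of subsets can keep much more. In the sketch below (Python/NumPy), a random subspace stands in for the model’s space of achievable updates; that choice and the sizes $n = 10$, $d = 10$ are assumptions for illustration only, though the sum identity holds for every subspace.

```python
import numpy as np

n, d = 10, 10                       # 2^n-dimensional space of functions; d-dimensional update subspace
D = 2 ** n

# Sylvester construction: the rows of H are the +-1 parity functions chi_S on {0,1}^n,
# so the rows of F = H / sqrt(D) form an orthonormal basis of all functions on {0,1}^n
H = np.array([[1.0]])
for _ in range(n):
    H = np.kron(H, np.array([[1.0, 1.0], [1.0, -1.0]]))
F = H / np.sqrt(D)

# Assumption: a random d-dimensional subspace stands in for the model's achievable update directions
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((D, d)))    # D x d matrix with orthonormal columns

# Squared length that each normalized parity direction keeps after projection onto the subspace
sq_norms = np.sum((F @ Q) ** 2, axis=1)

print("sum of squared projections:", sq_norms.sum(), "vs d =", d)
print("mean squared projection:", sq_norms.mean(), "vs d/D =", d / D)
print("fraction of subsets keeping more than 100 * d/D:", np.mean(sq_norms > 100 * d / D))
```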

The discerning reader will see that I swept a significant chunk of not only the proof, but even the logical flow of the argument under the rug: this is perhaps better described as the “intuition” behind the proof rather than a sketch. However, importantly, this “intuition” applies to absolutely any (polynomial-sized) architecture, and applies in a much more general context than SGD: any learning algorithm of this kind, including SGD, Adam, and even more sophisticated local Bayesian learning setups, will fail for the same reasons.

What we really used about SGD-style learning algorithms was that they have some noise and that their updating process only uses information averaged over input samples. In fact, there is a general result that any learning algorithm that only uses this kind of information (formalized as the class of “statistical query” algorithms) cannot learn parity in polynomial time. The definition of this class of algorithms and its relationship with various learnability and complexity results constitutes the beginning of the classical field of computational learning theory. For a nice compressed introduction which in particular formalizes the proof discussed here, see this paper.

Alternative point of view: lack of incremental pathways

An alternative point of view on the failure of learnability in this case is that there is no good way for an algorithm to incrementally learn parity. There is no story of learning parity that starts from simple algorithms (in some quantitative or even qualitative, Occam razor-esque sense) and recursively learns added epicycles of complexity which improve classification accuracy. For example, if we were to try to approximate the parity function by parity functions of smaller subsets, we would totally fail (as parity functions associated to different subsets are uncorrelated); a stronger version of the “lack of incremental pathways” result can be made following a similar intuition to the proof sketch above. This supports the idea that in order to be learnable, an algorithm must in some sense be combinable (at least in a local sense) out of simpler pieces, each of which is “findable” (i.e., doesn’t require exponential luck to get right – in later posts we will identify this with notions of effective dimension) and each of which reduces loss. This is closely related to the “low-hanging fruit prior” point of view, and will later serve as a lead-in to a discussion of “learning stories”.
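The failure of incremental approximation can be checked exhaustively for a small case. The sketch below (plain Python; `best_accuracy_using` is a hypothetical helper, and $n = 6$, $S = \{0, 2, 3, 5\}$ are arbitrary 0-indexed choices) computes the best possible accuracy on $f_S$ of any predictor that reads only the bits in a set $T$, by taking the majority label within each group of inputs that agree on $T$. Missing even one index of $S$ leaves every predictor at exactly 50%.

```python
import itertools
from collections import Counter

def best_accuracy_using(T, S, n):
    """Highest accuracy on f_S (over all 2^n inputs) of ANY predictor that only reads the bits in T.

    The optimal such predictor outputs the majority label within each group of inputs that agree on T.
    """
    T_sorted = sorted(T)
    counts = Counter()
    for v in itertools.product([0, 1], repeat=n):
        key = tuple(v[i] for i in T_sorted)
        label = sum(v[i] for i in S) % 2
        counts[(key, label)] += 1
    keys = {key for key, _ in counts}
    correct = sum(max(counts[(k, 0)], counts[(k, 1)]) for k in keys)
    return correct / 2 ** n

n = 6
S = {0, 2, 3, 5}
print(best_accuracy_using({0, 2, 3}, S, n))        # sub-parity missing index 5 of S  -> 0.5
print(best_accuracy_using({0, 1, 2, 3, 4}, S, n))  # every bit except index 5         -> 0.5
print(best_accuracy_using(S, S, n))                # all of S                          -> 1.0
```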

Does this mean that neural nets are weak?

Now that we’ve seen that neural nets trained on examples of the parity problem are provably incapable of learning it in polynomial time, it is reasonable to ask whether this is a hard limitation on the computational capabilities of neural nets. Indeed, I just explained that it is provably impossible to use a learning algorithm (such as an LLM) to solve a problem that can be easily solved by an undergraduate, at least in an amount of time that is shorter than the age of the universe. Does this negate the possibility that modern LLMs can solve hard-core math problems? Can we stop worrying about human-level AI?

Unfortunately (if you’re worried about AI risk), no. The impossibility of XOR learning does not imply any limitation on the mathematical ability of LLMs. The issue here is with the notion of “learnability”. In the setup of our XOR problem, we assumed that the LLM is executing SGD learning (or another learning algorithm) on the single learning signal of “what is the parity function applied to the vector $v$”. If we were to give the parity problem to an advanced LLM, it might be able to solve it, but this would not be from gradient updates on seeing a bunch of examples. Rather, our LLM has seen many mathematical texts, and may be able to use the knowledge in these mathematical texts and a basic understanding of logic to reconstruct the hidden subset S and the parity function $f_S$. Abstracting away the high-level “mathematical understanding” of LLMs, what this is saying is that it is in fact possible to learn the parity problem if the direct learning problem is replaced by a suitably sophisticated curriculum-learning-style problem with an enriched class of examples and a more sophisticated loss function. Trying to write out a simple “mathematical” ML algorithm that learns to solve the parity problem is an interesting exercise that might constitute a nice ML theory paper; I won’t try to do this here.

Not weak, but also not optimal

The main reason I want people in AI safety/interpretability to know about and understand parity is related to a long-standing question in machine learning of Bayesian learning vs. SGD, where the conventional wisdom has been wrong, but (in my limited understanding) is finally starting to converge in the correct direction (as will be typical of these posts, this is the “it’s complicated” direction). Namely, you can ask whether SGD (and related learning algorithms) can be well understood as finding the optimal solution – or more precisely, as sampling a suitable “Bayesian” distribution of near-optimal solutions[4]. It is easy to see that learning sufficiently general algorithms cannot converge to anything like Bayesian learning for P vs. NP reasons. But a standard counterargument, supported by a standard collection of faulty papers (that I’ll complain about later), was that “real life” problems where deep learning is applied do converge to the Bayesian prior.

One (soft) takeaway from the discussion here is that if training “real-life” modern LLMs involves reasoning in the same reference class as parity, then it is likely that the algorithm they learn is not globally optimal (in a Bayesian sense). Indeed, we see from parity that optimal algorithms in this reference class lack the incremental pathways necessary to be learnable via SGD, and the way that LLMs solve complex problems probably is mediated by curriculum-learning-style “training wheels” that learn general solutions, just not of the most efficient type. [5]

Acknowledgments

I’ve talked to lots of people about this, but particularly important for this post have been a number of conversations with Kaarel Hänni on related topics. I also want to thank Sam Eisenstat for first telling me about parity and the notion of learnability, and thanks to Jake Mendel and Lucius Bushnaq for related discussions.

  1. More precisely, the probabilistic version $\mathrm{NP} \not\subseteq \mathrm{BPP}$, and I might be assuming some cryptographic hardness in other places in this section (but not in the rest of the post!)

  2. Technically, there are $2^n$ functions $f_S$, but they only span a $(2^n - 1)$-dimensional subspace, since the "empty-set parity" function $f_\emptyset$ is zero; to make this statement precise, one needs to replace $f_\emptyset$ by a constant function. A more commonly used related basis is the "Boolean Fourier mode" basis with basis elements $\chi_S(v) = (-1)^{f_S(v)}$, which replaces the $\{0,1\}$-valued parity function by a $\{\pm 1\}$-valued analog. Working with this basis is generally nicer, and in particular makes "uncorrelatedness" arguments cleaner.

  3. They are uncorrelated on the set of all inputs (i.e., the expected value of $f_S(v)$ doesn't change even if you condition on a specific value of $f_T(v)$), but they are correlated ("in a random way") on some fixed polynomial-sized set of "training" inputs. In the latter training-set context, they are "not very" correlated, and the correlations can be proved to be suitably "unbiased" when viewed as a noise term.

  4. This is normally defined as the Boltzmann distribution associated to loss, an object particularly important to Singular Learning Theory.

  5. Note that this isn’t even an “intuition-level” proof: it’s not obvious that modern ML methods require knowing how to solve problems in the reference class of “parity for n-bit inputs such that $n$ is very large”. And even if this were the case, it’s not obvious that ML learning problems don’t just happen to have the property that the “training wheels” for learning parity-like problems are themselves needed to produce Bayes-optimal algorithms for some important simpler problems. Later when we discuss connections between physics and ML, we’ll see other more rigorous reasons to dismiss strong versions of the SGD = Bayesian hypothesis. But at the same time, it’s important to note that in many contexts (namely, when looking locally in a basin, or looking at simple “circuit-level” behaviors that don’t have enough accumulated complexity to break out of a low-dimensional paradigm), it is reasonable and productive to make little distinction between the two types of learning.
