Goodhart in RL with KL: Appendix

by Thomas Kwa
18th May 2024

Previous: Catastrophic Goodhart in RL with KL penalty

This is the appendix to the previous post on Goodhart’s Law and KL regularization, containing all of our proofs.

Theorem about distributions

Theorem 1: Given any heavy-tailed reference distribution Q over R with mean μQ, and any M,ϵ>0, there is a distribution P with mean μP>M and DKL(P∥Q)<ϵ.

Proof: WLOG let μQ=0. Fix any constant c>0; we construct a sequence of distributions {Pt} such that limt→∞EPt[X]≥c and limt→∞DKL(Pt∥Q)=0, which proves the theorem once we take c>M and t sufficiently large. For any t>c, we define Pt as follows. Writing FPt(x) for the CDF PrX∼Pt(X≤x) and ¯FPt(x) for 1−FPt(x), we let

$$\bar F_{P_t}(x) = \begin{cases} 1 - \dfrac{1 - c/t}{F_Q(t)}\, F_Q(x) & x \le t \\[6pt] \dfrac{c/t}{\bar F_Q(t)}\, \bar F_Q(x) & x > t \end{cases}$$

Intuitively, we rescale the part of the distribution to the right of t evenly to have total probability c/t, which is less than 1 because t>c.
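
As a concrete instantiation (not from the post), suppose Q is a Pareto distribution with ¯FQ(x)=x^{−α} for x≥1 and some α>1; this Q is heavy-tailed with finite mean (not mean zero, but recentering only shifts everything). The construction becomes

$$\bar F_{P_t}(x) = \begin{cases} 1 - \dfrac{1 - c/t}{1 - t^{-\alpha}}\,\bigl(1 - x^{-\alpha}\bigr) & 1 \le x \le t \\[6pt] \dfrac{c}{t}\, t^{\alpha}\, x^{-\alpha} & x > t \end{cases}$$

so the tail beyond t keeps its Pareto shape but is rescaled to total mass ¯FPt(t)=c/t, and the two branches agree at x=t.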

We must check that limt→∞EPt[X]≥c. We can write

$$\begin{aligned}
E_{P_t}[X] &= F_{P_t}(t)\,E_{P_t}[X \mid X \le t] + \bar F_{P_t}(t)\,E_{P_t}[X \mid X > t] \\
&= F_{P_t}(t)\,E_Q[X \mid X \le t] + \bar F_{P_t}(t)\,E_Q[X \mid X > t] \\
&= F_Q(t)\,E_Q[X \mid X \le t] + \bar F_Q(t)\,E_Q[X \mid X > t] + \bigl(F_{P_t}(t) - F_Q(t)\bigr)\,E_Q[X \mid X \le t] + \bigl(\bar F_{P_t}(t) - \bar F_Q(t)\bigr)\,E_Q[X \mid X > t] \\
&= E_Q[X] + \bigl(\bar F_{P_t}(t) - \bar F_Q(t)\bigr)\bigl(E_Q[X \mid X > t] - E_Q[X \mid X \le t]\bigr)
\end{aligned}$$

We know that EQ[X|X>t]>t because it is an average of values strictly greater than t. Because EQ[X]=0 is a weighted average of EQ[X|X>t] and EQ[X|X≤t], and EQ[X|X>t]>0, we know EQ[X|X≤t]<0. So EQ[X|X>t]−EQ[X|X≤t]>t. We also know that for sufficiently large t, (¯FPt(t)−¯FQ(t))>0, since ¯FPt(t)=c/t while t¯FQ(t)→0. Intuitively, starting from Q, which has mean 0, Pt moves a probability mass approaching c/t from mean <0 to mean >t.

Now we can say

$$\lim_{t\to\infty} E_{P_t}[X] > \lim_{t\to\infty}\Bigl[E_Q[X] + \bigl(\bar F_{P_t}(t) - \bar F_Q(t)\bigr)(t - 0)\Bigr] = \lim_{t\to\infty}\Bigl(\frac{c}{t} - \bar F_Q(t)\Bigr)t = \lim_{t\to\infty}\bigl(c - t\,\bar F_Q(t)\bigr)$$

Because Q has a finite mean, limt→∞t¯FQ(t)=0, and so limt→∞EPt[X]≥c.
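
For completeness, the claim limt→∞t¯FQ(t)=0 is the standard tail bound for a distribution with finite mean:

$$t\,\bar F_Q(t) = \int_t^\infty t\, Q(dx) \le \int_t^\infty x\, Q(dx) \xrightarrow{\;t\to\infty\;} 0,$$

where the right-hand side tends to 0 because it is the tail of the convergent integral defining EQ[max(X,0)].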

Now we check that limt→∞DKL(Pt∥Q)=0:

$$\begin{aligned}
D_{\mathrm{KL}}(P_t \| Q) &= \int_{\mathbb{R}} \log\frac{P_t(dx)}{Q(dx)}\, P_t(dx) \\
&= \int_{x \le t} \log\frac{P_t(dx)}{Q(dx)}\, P_t(dx) + \int_{x > t} \log\frac{P_t(dx)}{Q(dx)}\, P_t(dx) \\
&= F_{P_t}(t)\log\frac{F_{P_t}(t)}{F_Q(t)} + \bar F_{P_t}(t)\log\frac{\bar F_{P_t}(t)}{\bar F_Q(t)} \quad \text{since both ratios are constant} \\
&= F_{P_t}(t)\log\frac{1 - c/t}{F_Q(t)} + \bar F_{P_t}(t)\log\frac{\bar F_{P_t}(t)}{\bar F_Q(t)}
\end{aligned}$$

Since both 1−c/t and FQ(t) go to 1 as t→∞, the left term goes to 0, and so

$$\begin{aligned}
\lim_{t\to\infty} D_{\mathrm{KL}}(P_t \| Q) &\le 0 + \lim_{t\to\infty} \bar F_{P_t}(t)\log\frac{\bar F_{P_t}(t)}{\bar F_Q(t)} = \lim_{t\to\infty} \frac{c}{t}\log\frac{c/t}{\bar F_Q(t)} \\
&\le \lim_{t\to\infty} \frac{c}{t}\log\frac{1}{\bar F_Q(t)} = \lim_{t\to\infty} -\frac{c}{t}\log \bar F_Q(t) \quad \text{since } t > c
\end{aligned}$$

Q is heavy-tailed, so by definition limt→∞ e^{at}¯FQ(t)=∞ for all a>0. This implies that for every a>0 there is a threshold ta such that for all t>ta, ¯FQ(t)>e^{−at}, which means that log¯FQ(t)>−at.

Therefore for every a>0, limt→∞DKL(Pt∥Q) ≤ limt→∞ −(c/t)log¯FQ(t) ≤ limt→∞ (c/t)(at) = ac. Since this holds for every a>0 and KL divergence is nonnegative, limt→∞DKL(Pt∥Q)=0 as desired. ■
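
As a sanity check of Theorem 1 (not from the post), here is a small numerical sketch for a Pareto reference distribution with ¯FQ(x)=x^{−2} on x≥1, using the Pareto's closed-form conditional means: as t grows, EPt[X] stays well above c while DKL(Pt∥Q) shrinks toward 0.

```python
import math

# Reference distribution Q: Pareto with survival function F̄_Q(x) = x^(-ALPHA) for x >= 1.
# Heavy-tailed with finite mean ALPHA/(ALPHA-1); the proof's WLOG mean-zero shift
# only translates everything, so it is skipped here.
ALPHA = 2.0

def F_bar_Q(x):
    return x ** (-ALPHA)

def mean_Q_above(t):
    # E_Q[X | X > t] for a Pareto tail: ALPHA * t / (ALPHA - 1)
    return ALPHA * t / (ALPHA - 1)

def mean_Q_below(t):
    # Solve E_Q[X] = F_Q(t) * E_Q[X | X <= t] + F̄_Q(t) * E_Q[X | X > t]
    mean_Q = ALPHA / (ALPHA - 1)
    return (mean_Q - F_bar_Q(t) * mean_Q_above(t)) / (1 - F_bar_Q(t))

def stats_P_t(c, t):
    """Mean and KL divergence of the reweighted distribution P_t from the proof."""
    tail_mass = c / t  # F̄_{P_t}(t) by construction
    mean = (1 - tail_mass) * mean_Q_below(t) + tail_mass * mean_Q_above(t)
    kl = ((1 - tail_mass) * math.log((1 - tail_mass) / (1 - F_bar_Q(t)))
          + tail_mass * math.log(tail_mass / F_bar_Q(t)))
    return mean, kl

c = 5.0
for t in [10, 100, 1_000, 10_000, 100_000]:
    mean, kl = stats_P_t(c, t)
    print(f"t = {t:>7}: E_Pt[X] = {mean:8.3f}, KL(P_t || Q) = {kl:.5f}")
```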

Theorem about deterministic Markovian-return MDPs

Definition: A deterministic-transition MDP with Markovian returns (DMRMDP) is an MDP (S,A,P,R) such that:

  • The transition function P:S×A→ΔS is deterministic, i.e., for each state s∈S and action a∈A, there exists a unique state s′∈S such that P(s′|s,a)=1.
  • There is a set of sink states E⊆S that terminate a trajectory, which is disjoint from the set of start states.
  • Returns are Markovian; that is, for any two trajectories τ=(s1,a1,…,sn),τ′=(s′1,a′1,…,s′n), if sn=s′n, then τ and τ′ have identical return distributions. Equivalently, for the trajectory random variable T=(S1,A1,…) distributed according to any policy, with return G, we have G⊥⊥(S<i,A<i) | Si for any i≥1.

Note: Sampling from a language model and applying RLHF is well-modeled as a DMRMDP: each state is the sequence of tokens generated so far, which results deterministically from the previous state and the chosen token (action), and returns depend only on the final state.
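
To make the correspondence concrete, here is a minimal sketch of such a token-level DMRMDP (the vocabulary, length cap, and reward function below are invented for illustration, not taken from the post): states are token prefixes, the transition appends the chosen token, sequences ending in an end-of-sequence token or hitting the length cap are sink states, and reward is a function of the final state only.

```python
from typing import Tuple

Token = str
State = Tuple[Token, ...]        # a state is the sequence of tokens generated so far

VOCAB = ["a", "b", "<eos>"]      # hypothetical 3-token vocabulary
MAX_LEN = 4                      # hypothetical length cap

def transition(state: State, action: Token) -> State:
    """Deterministic transition: the next state is the current prefix plus the new token."""
    return state + (action,)

def is_sink(state: State) -> bool:
    """Sink states terminate the trajectory; the start state () is never a sink."""
    return len(state) >= MAX_LEN or (len(state) > 0 and state[-1] == "<eos>")

def reward(final_state: State) -> float:
    """Markovian return: a function of the final state only (standing in for an RLHF
    reward model); this particular choice is arbitrary."""
    return float(final_state.count("a"))
```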

Theorem 2: Let W=(S,A,P,R) be a deterministic-transition MDP with Markovian returns. Given W we define the function that takes policies to trajectories Tr:(S→ΔA)→Δ(S×A)∗, and the average return function g:(S×A)∗→R which induces a function G:Δ(S×A)∗→ΔR. Let π0:S→ΔA be some reference policy. If G∘Tr(π0) is heavy-tailed with finite mean μQ, then for any M,ϵ>0, there is a policy π with mean return E[U|U∼G∘Tr(π)]>M and Es∈T,T∼Tr(π)[DKL(π(s)∥π0(s))]<ϵ.

Proof: We will exhibit a distribution of trajectories ρ such that DKL(ρ∥Tr(π0))<ϵ and E[G(ρ)]>M, and then construct a policy π with Tr(π)=ρ. Note that this proof applies for continuous action spaces if trajectories are replaced with measurable sets, but this would make it harder to read.

Let ρπ0=Tr(π0). We have a heavy-tailed distribution of return Q≜G(ρπ0) over R, so we can apply Theorem 1. But to define ρ, we construct the Pt in the proof of Theorem 1 in a particular way. For any t>c, we need a Pt that uniformly rescales the return distribution above t so that ¯FPt(t)=c/t. We can define ρt such that any trajectory τ is reweighted by a factor depending only on its return g(τ):

$$\rho_t(\tau) = \begin{cases} \dfrac{1 - c/t}{F_Q(t)}\, \rho_{\pi_0}(\tau) & g(\tau) \le t \\[6pt] \dfrac{c/t}{\bar F_Q(t)}\, \rho_{\pi_0}(\tau) & g(\tau) > t \end{cases}$$

Then we can let Pt≜G(ρt) and the rest of the proof of Theorem 1 applies. Therefore, applying the theorem, we can let ρ=ρt for sufficiently large t, and then μG(ρ)>M and DKL(G(ρ)∥G(ρπ0))<ϵ. By the chain rule for KL divergence, DKL(ρ∥ρπ0)=DKL(G(ρ)∥G(ρπ0))+Eγ∼G(ρ)[DKL(ρ(T)|G(T)=γ ∥ ρπ0(T)|G(T)=γ)]. Since we constructed ρ so that the conditional distribution of τ given its return γ is the same under ρ as under ρπ0, the second term is zero, and so DKL(ρ∥ρπ0)<ϵ as well.

Finally, once we construct a policy π with Tr(π)=ρ (below), since the KL divergence between trajectory distributions is the sum of the KL divergences between the policies at each action along the trajectory, and each trajectory has at least one action, Es∈T,T∼Tr(π)[DKL(π(s)∥π0(s))] ≤ ET∼Tr(π)∑s∈T DKL(π(s)∥π0(s)) = DKL(ρ∥ρπ0) < ϵ as desired.

To define π such that Tr(π)=ρ, we let π(s,a)=Pr(ai=a|τ=(...,s,ai,...)∼ρ): the probability, under ρ, that the action taken immediately after a visit to state s is a.

Then the probability that any trajectory τ=(s1,a1,…,an) is sampled is:

$$\begin{aligned}
\mathrm{Tr}(\pi)(\tau) &= \prod_{i=1}^n \pi(s_i, a_i) && (1)\\
&= \prod_{i=1}^n \Pr\bigl(a_i = a'_i \,\big|\, \tau' = (\ldots, s_i, a'_i, \ldots) \sim \rho\bigr) && (2)\\
&= \prod_{i=1}^n \Pr\bigl(a_i = a'_i \,\big|\, \tau' = (s'_1, a'_1, \ldots, s_i, a'_i, \ldots) \sim \rho,\ s'_{<i} = s_{<i},\ a'_{<i} = a_{<i}\bigr) && (3)\\
&= \rho(\tau) && (4)
\end{aligned}$$

In (2), returns are Markovian, so all trajectory prefixes ending in state s have the same distribution of returns under any policy. In the construction of ρ, all trajectories with the same return are reweighted by the same factor. Therefore, conditioning on the earlier states and actions of τ does not change the measure, so we can write (3). The conditional probabilities in (3) then telescope by the chain rule (the deterministic transitions contribute no additional factors), giving (4). So Tr(π)=ρ as desired. ■
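
Below is a small numerical sketch of this construction (not from the post; the toy MDP, returns, and reweighting factors are invented): enumerate the trajectory distribution of a uniform reference policy on a two-step token DMRMDP, reweight each trajectory by a factor that depends only on whether its return exceeds a threshold (mimicking ρt), read off π(a|s) from the reweighted trajectory distribution, and check that Tr(π) reproduces ρ exactly.

```python
from itertools import product
from collections import defaultdict

TOKENS = ["a", "b"]          # hypothetical 2-token vocabulary
HORIZON = 2                  # every trajectory takes exactly two actions

def g(traj):
    """Return of a trajectory; depends only on the final state (Markovian returns)."""
    return {"aa": 10.0, "ab": 1.0, "ba": 1.0, "bb": 0.0}["".join(traj)]

# Reference policy pi_0: uniform over tokens at every state, so each trajectory
# (a sequence of HORIZON tokens) has probability (1/2)^HORIZON under rho_{pi_0}.
rho_0 = {traj: 0.5 ** HORIZON for traj in product(TOKENS, repeat=HORIZON)}

# Reweight by a factor depending only on g(traj), as in the definition of rho_t:
# one factor for returns above the threshold, another for the rest.
THRESHOLD = 5.0
UP, DOWN = 3.0, 0.5          # arbitrary positive factors for illustration
unnorm = {traj: p * (UP if g(traj) > THRESHOLD else DOWN) for traj, p in rho_0.items()}
Z = sum(unnorm.values())
rho = {traj: p / Z for traj, p in unnorm.items()}

# Extract pi(a | s): the probability under rho that action a follows prefix s.
visits = defaultdict(float)  # rho-mass passing through each prefix (state)
takes = defaultdict(float)   # rho-mass taking each (prefix, action) pair
for traj, p in rho.items():
    for i in range(HORIZON):
        visits[traj[:i]] += p
        takes[(traj[:i], traj[i])] += p
pi = {key: mass / visits[key[0]] for key, mass in takes.items()}

# Tr(pi): the probability of each trajectory when acting according to pi.
def tr_pi(traj):
    prob = 1.0
    for i in range(HORIZON):
        prob *= pi[(traj[:i], traj[i])]
    return prob

for traj in rho:
    print("".join(traj), f"rho = {rho[traj]:.4f}", f"Tr(pi) = {tr_pi(traj):.4f}")
```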

Lagrange multipliers to minimize KL divergence

Theorem 3: If V is light-tailed with EQ[V]=0, and d=DKL(P∥Q) is bounded, then EP[V] is bounded, and EP[V]→0 as d→0.

Using Lagrange multipliers, we find that when KL divergence is minimized subject to a constraint on EP[V] (and to normalization), we have P(V)[λ1logP(V)Q(V)+λ2−V]=0 for some constants λ1,λ2, so

$$\begin{aligned}
\log\frac{P(V)}{Q(V)} &= \frac{V - \lambda_2}{\lambda_1} \\
P(V) &= Q(V)\exp\!\Bigl(\frac{V - \lambda_2}{\lambda_1}\Bigr) = Q(V)\, e^{V/\lambda_1}\, e^{-\lambda_2/\lambda_1} = C\, Q(V)\, e^{V/\lambda_1}
\end{aligned}$$

That is, the new PDF is an exponential tilting of the old PDF. Now what is EP[V]? It is just ∫∞−∞ C V e^{V/λ1} Q(V) dV. If the distribution of V is heavy-tailed, this integral is infinite; if it is light-tailed, it is finite (at least for a sufficiently weak tilt, i.e., sufficiently large λ1).

When d=0, P and Q are identical and E[V]=0. So by a continuity argument, EP[V]→0 as d→0. ■
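
As an illustration of Theorem 3 (not from the post), the following sketch exponentially tilts a light-tailed Q, here a standard normal, and computes EP[V] and DKL(P∥Q) by numerical integration on a grid; both go to 0 together as the tilt weakens. For a standard normal these are 1/λ1 and 1/(2λ1²) in closed form.

```python
import numpy as np

# Light-tailed reference Q: standard normal density on a grid.
v = np.linspace(-20, 20, 400_001)
dv = v[1] - v[0]
q = np.exp(-v**2 / 2) / np.sqrt(2 * np.pi)

def tilt(lam1):
    """Exponentially tilt Q by exp(V / lam1) and renormalize, as in Theorem 3."""
    p = q * np.exp(v / lam1)
    p /= p.sum() * dv                      # normalize so p integrates to 1
    mean = (v * p).sum() * dv              # E_P[V]
    kl = (p * np.log(p / q)).sum() * dv    # D_KL(P || Q)
    return mean, kl

for lam1 in [1.0, 2.0, 5.0, 10.0, 100.0]:
    mean, kl = tilt(lam1)
    print(f"lambda_1 = {lam1:6.1f}: E_P[V] = {mean:7.4f}, KL = {kl:.6f}")
```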

Light tails + independence imply EV→∞

Theorem 4: If U=X+V with X and V independent and both light-tailed, the distribution of U is continuous, and π∗(β)≜argmaxπ E[U(π)]−βDKL(π∥π0), then limβ→0+ E[V(π∗(β))]=∞.

Proof: Fix some β and write π for π∗(β). Using Lagrange multipliers as in Theorem 3, we find that for any outcome S, Prπ(S)∝Prπ0(S)e^{λU(S)} for some λ>0 depending on β. Let c(β) be the median value of U under the policy π∗(β); that is, Pr(U>c(β)|U∼G∘Tr(π∗(β)))=1/2. This exists because U has a continuous distribution. Then:

$$\begin{aligned}
E[V \mid \pi] &= \tfrac{1}{2} E[V \mid \pi, U < c] + \tfrac{1}{2} E[V \mid \pi, U \ge c] \;\ge\; \tfrac{1}{2} E[V \mid \pi, U < c] + \tfrac{1}{2} E[V \mid \pi] \\
\lim_{\beta \to 0^+} E[V \mid \pi] &\ge \lim_{\beta \to 0^+} \tfrac{1}{2} E[V \mid \pi, U < c] + \lim_{\beta \to 0^+} \tfrac{1}{2} E[V \mid \pi]
\end{aligned}$$

The left term is c, while the right term is ∞, so the overall limit is ∞. ■
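
To illustrate the claim numerically (this example is not from the post), take X and V to be independent standard normals with U=X+V, and use the exponentially tilted distribution Pr∝Pr0·e^{λU} from the Lagrange-multiplier step as a stand-in for π∗(β), with λ growing as β shrinks. Because the tilt factorizes over the independent X and V, the tilted marginal of V is just the standard normal tilted by e^{λV}, and E[V] under it equals λ, growing without bound.

```python
import numpy as np

# V (and X) light-tailed: standard normal. Tilting by exp(lambda * U) with U = X + V
# factorizes over independent X and V, so the tilted marginal of V is the standard
# normal tilted by exp(lambda * V); we compute E[V] under it on a grid.
v = np.linspace(-30, 30, 600_001)
dv = v[1] - v[0]
q_v = np.exp(-v**2 / 2) / np.sqrt(2 * np.pi)

# lambda plays the role of 1/beta: the weaker the KL penalty, the stronger the tilt.
for lam in [0.5, 1.0, 2.0, 5.0, 10.0]:
    p_v = q_v * np.exp(lam * v)
    p_v /= p_v.sum() * dv
    ev = (v * p_v).sum() * dv
    print(f"lambda = {lam:5.1f}: E[V] = {ev:7.3f}   (exact value for Gaussians: {lam})")
```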