Speculative Decoding

Paper: "Fast Inference from Transformers via Speculative Decoding" (Leviathan, Kalman, and Matias, 2023).

Algorithm of Speculative Decoding Step


Algorithm SpeculativeDecodingStep

Inputs: $M_p$, $M_q$, prefix.

Sample $\gamma$ guesses $x_1, \dots, x_\gamma$ from $M_q$ autoregressively.

for $i = 1$ to $\gamma$ do
    $q_i(x) \leftarrow M_q(\mathit{prefix} + [x_1, \dots, x_{i-1}])$
    $x_i \sim q_i(x)$
end for

Run $M_p$ in parallel.

$p_1(x), \dots, p_{\gamma+1}(x) \leftarrow M_p(\mathit{prefix}), \dots, M_p(\mathit{prefix} + [x_1, \dots, x_\gamma])$

Determine the number of accepted guesses $n$.

$r_1 \sim U(0, 1), \dots, r_\gamma \sim U(0, 1)$

$n \leftarrow \min\left(\{\, i - 1 \mid 1 \le i \le \gamma,\ r_i > \tfrac{p_i(x_i)}{q_i(x_i)} \,\} \cup \{\gamma\}\right)$

Adjust the distribution from $M_p$ if needed.

$\tilde{p}(x) \leftarrow p_{n+1}(x)$

if $n < \gamma$ then
    $\tilde{p}(x) \leftarrow \mathrm{norm}(\max(0,\ p_{n+1}(x) - q_{n+1}(x)))$
end if

Return one token from $M_p$, and $n$ tokens from $M_q$.

$t \sim \tilde{p}(x)$

return $\mathit{prefix} + [x_1, \dots, x_n, t]$
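
As a concrete companion to the pseudocode, here is a minimal NumPy sketch of one such step. The names (`speculative_decoding_step`, `target_model`, `draft_model`) and the toy 4-token vocabulary are assumptions made for illustration, not anything from the paper; any functions mapping a prefix to a next-token distribution can stand in for $M_p$ and $M_q$.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB = 4  # toy vocabulary {0, 1, 2, 3}; an arbitrary choice for illustration

def speculative_decoding_step(target_model, draft_model, prefix, gamma):
    """One speculative decoding step: returns prefix + [x_1, ..., x_n, t]."""
    # Sample gamma guesses x_1..x_gamma from the draft model M_q autoregressively.
    xs, qs = [], []
    for _ in range(gamma):
        q = draft_model(prefix + xs)                  # q_i(x)
        x = int(rng.choice(VOCAB, p=q))               # x_i ~ q_i(x)
        qs.append(q)
        xs.append(x)

    # Run the target model M_p on all gamma + 1 prefixes (done in parallel in practice).
    ps = [target_model(prefix + xs[:i]) for i in range(gamma + 1)]

    # Determine the number of accepted guesses n.
    n = gamma
    for i in range(gamma):
        if rng.uniform() > ps[i][xs[i]] / qs[i][xs[i]]:
            n = i
            break

    # Adjust the distribution from M_p if a guess was rejected.
    p_adj = ps[n]
    if n < gamma:
        p_adj = np.maximum(0.0, ps[n] - qs[n])        # max(0, p_{n+1} - q_{n+1})
        p_adj = p_adj / p_adj.sum()                   # norm(...)

    # Return the n accepted draft tokens plus one token from the (adjusted) target distribution.
    t = int(rng.choice(VOCAB, p=p_adj))
    return prefix + xs[:n] + [t]

# Toy stand-in models: any function prefix -> probability vector works here.
def target_model(prefix):                             # plays the role of M_p
    logits = np.array([len(prefix) % 3, 1.0, 2.0, 0.5])
    return np.exp(logits) / np.exp(logits).sum()

def draft_model(prefix):                              # plays the role of M_q
    logits = np.array([0.5, 1.0, 1.5, 1.0])
    return np.exp(logits) / np.exp(logits).sum()

print(speculative_decoding_step(target_model, draft_model, prefix=[0], gamma=4))
```

Looping this step until an end-of-sequence token is produced gives the full decoding loop; the speed-up comes from evaluating all $\gamma + 1$ target-model distributions in one parallel pass rather than one token at a time.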


Correctness of Speculative Sampling

We will now show that, for any distributions $p(x)$ and $q(x)$, the tokens sampled via speculative sampling from $p(x)$ and $q(x)$ are distributed identically to those sampled from $p(x)$ alone. Let $\beta$ be the acceptance probability (see the Definition below).

Note that

$$\tilde{p}(x) = \mathrm{norm}(\max(0,\ p(x) - q(x))) = \frac{p(x) - \min(q(x), p(x))}{\sum_{x'} \big(p(x') - \min(q(x'), p(x'))\big)} = \frac{p(x) - \min(q(x), p(x))}{1 - \beta},$$

so the normalizing constant of the adjusted distribution $\tilde{p}(x)$ is $1 - \beta$; the last equality follows immediately from Lemma 3.3 (restated below) and Theorem 3.5 of the paper.

Now:

$$P(x = x') = P(\text{guess accepted},\ x = x') + P(\text{guess rejected},\ x = x')$$

Where:

$$P(\text{guess accepted},\ x = x') = q(x') \min\left(1, \frac{p(x')}{q(x')}\right) = \min(q(x'), p(x'))$$

And:

$$P(\text{guess rejected},\ x = x') = (1 - \beta)\,\tilde{p}(x') = p(x') - \min(q(x'), p(x'))$$

Overall:

$$P(x = x') = \min(p(x'), q(x')) + p(x') - \min(p(x'), q(x')) = p(x').$$

As desired.
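
A quick Monte Carlo sanity check of this argument (a sketch with arbitrary toy distributions, not code from the paper): the empirical distribution of the speculatively sampled token should match $p$, and the empirical acceptance rate should match $\sum_x \min(p(x), q(x)) = \beta$.

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy target p and draft q over a 4-token vocabulary (arbitrary choices).
p = np.array([0.5, 0.2, 0.2, 0.1])
q = np.array([0.25, 0.35, 0.1, 0.3])

# Adjusted distribution p~(x) = norm(max(0, p(x) - q(x))).
p_adj = np.maximum(0.0, p - q)
p_adj = p_adj / p_adj.sum()       # the normalizer is 1 - beta = 0.35 for these p, q

def one_speculative_sample():
    """Sample x ~ q, accept with prob min(1, p(x)/q(x)); on rejection, resample from p~."""
    x = int(rng.choice(4, p=q))
    if rng.uniform() <= min(1.0, p[x] / q[x]):
        return x, True
    return int(rng.choice(4, p=p_adj)), False

draws = [one_speculative_sample() for _ in range(200_000)]
tokens = np.array([t for t, _ in draws])
accepted = np.array([a for _, a in draws])

print("empirical token dist:", np.bincount(tokens, minlength=4) / len(tokens))
print("target p            :", p)
print("empirical acceptance:", accepted.mean())
print("sum_x min(p, q)     :", np.minimum(p, q).sum())
```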

Definition

The acceptance rate $\beta_{x_{<t}}$, given a prefix $x_{<t}$, is the probability of accepting $x_t \sim q(x_t \mid x_{<t})$ by speculative sampling.

Lemma

Define

$$D_{LK}(p, q) = \sum_x |p(x) - M(x)| = \sum_x |q(x) - M(x)|,$$

where $M(x) = \frac{p(x) + q(x)}{2}$. Then

$$D_{LK}(p, q) = 1 - \sum_x \min(p(x), q(x)).$$

Proof.

$$D_{LK}(p, q) = \sum_x |p(x) - M(x)| = \sum_x \frac{|p(x) - q(x)|}{2} = 1 - \sum_x \frac{p(x) + q(x) - |p(x) - q(x)|}{2} = 1 - \sum_x \min(p(x), q(x)),$$

where the third equality uses $\sum_x \frac{p(x) + q(x)}{2} = 1$ and the last uses $\min(a, b) = \frac{a + b - |a - b|}{2}$.
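
The identity is also easy to confirm numerically. The following small check uses two random distributions over a 10-token vocabulary (an arbitrary choice for illustration) and prints four expressions that should all agree.

```python
import numpy as np

rng = np.random.default_rng(2)

# Two arbitrary distributions over a 10-token vocabulary.
p = rng.random(10)
p = p / p.sum()
q = rng.random(10)
q = q / q.sum()
M = (p + q) / 2

print(np.abs(p - M).sum())            # sum_x |p(x) - M(x)|
print(np.abs(q - M).sum())            # sum_x |q(x) - M(x)|
print(np.abs(p - q).sum() / 2)        # sum_x |p(x) - q(x)| / 2
print(1 - np.minimum(p, q).sum())     # 1 - sum_x min(p(x), q(x))
```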