Learning to Correct: Reinforcement Learning for Multi-Attempt Chain-of-Thought
Abstract
State-of-the-art reasoning models can utilize long chains of thought (CoT) to solve sophisticated coding and math problems. During this process, the model often attempts a solution multiple times, relying on verification and self-reflection capabilities. In this work, we view a long CoT as a process in which the model makes K attempts at solving a problem, where each attempt may build on earlier ones. In this way, we formalize long CoT as a pass@K problem with dependent samples. Under this formalism, we derive the policy gradient and corresponding RL algorithms for optimizing the long-CoT reward, and show how each attempt should be weighted to obtain an unbiased gradient estimate with low variance. Our theory reveals how self-correction capability and dense feedback influence the training dynamics and eventual performance of long-CoT reasoning. We provide both synthetic and real-world experiments that corroborate our theory and demonstrate the benefits of the associated algorithms. As a by-product, our analysis also reveals when verification and long chain-of-thought are beneficial over parallel sampling strategies, and clarifies the role of model capability.
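To make the formalism concrete, one way to write the multi-attempt objective sketched above is the following display; the notation here ($\pi_\theta$ for the policy, $x$ for the problem, $a_k$ for the $k$-th attempt, $r$ for a binary verifier reward) is introduced only for illustration and is not necessarily the paper's exact formulation:
\[
J(\theta) \;=\; \mathbb{E}_{\,a_k \sim \pi_\theta(\cdot \mid x,\, a_{<k}),\; k = 1, \dots, K}\!\left[\; \max_{1 \le k \le K} r(x, a_k) \;\right].
\]
Under this reading, the trajectory succeeds if any of the K attempts solves the problem; when the attempts are sampled independently of one another, this reduces to the standard pass@K objective, and it is precisely the conditioning of each attempt on $a_{<k}$ that allows later attempts to exploit verification and self-reflection on earlier ones.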