    Large Reasoning Models: How o1 Replications Turned into Real Competition

    Some of the most important AI advances in 2024 came from test-time reasoning LLMs, or large reasoning models (LRM), that is, LLMs that are trained to write down and reuse their chains of thought for future reference. Reasoning LLMs started with the o1 family of models by OpenAI (I wrote a short capabilities post in September, when it appeared). Since then, they have opened up a new scaling paradigm for test-time compute and significantly advanced areas such as mathematical reasoning and programming, and OpenAI is already boasting about its new o3 family—but we still don’t have a definitive source on how OpenAI’s models work. In this post, we discuss how attempts to replicate o1 have progressed to date, including the current state-of-the-art open model, DeepSeek-R1, which seems to be a worthy rival even for OpenAI’s offerings.

    Introduction

    “It’s not that I’m so smart, it’s just that I stay with problems longer.” This is one of the many quotes often attributed to Albert Einstein, probably wrongly (Calaprice, 2010). Regardless of the actual author, it is a perfect description of what large reasoning models (LRM) can do: they stay with a problem, generating new thought tokens and ruminating on their own reasoning to make further progress.

    Over the past few years, large language models (LLMs) have dramatically advanced our ability to process and generate natural language. Yet, despite this remarkable fluency, many LLMs still struggle with complex reasoning tasks. Enter the era of large reasoning models (LRM; I’m not sure it’s the official name yet, but I have to choose something)—LLMs that don’t just respond but actively reason through problems, reusing and refining their own thought processes in real time.

    In this post, we trace the rapid evolution of these reasoning models, starting with OpenAI’s groundbreaking o1 and leading to DeepSeek-R1, an o1 replication that verified a lot of ideas behind LRMs and published them openly. We will explore the key innovations, challenges, and unexpected lessons learned along the way, focusing on how o1 replications have grown into serious competitors in the space.

    We still don’t really know how o1 or o3 models work—OpenAI is pretty good at keeping secrets. But since September, there have been plenty of attempts to replicate o1. The defining feature of these test-time reasoning LLMs has been the test-time scaling that we discussed back in September and will discuss again here: a test-time reasoning LLM is one that feeds on its own chain of thought and can use intermediate outputs to produce better results. In this post, we review reasoning LLMs from the beginning, describing several plausible o1 replications and culminating in the latest and greatest replication that may even be better than the original: DeepSeek-R1.

    I want to get one thing out of the way first: the amount of press devoted to the $5.5M figure for training DeepSeek-V3 was ridiculous. Yes, as, e.g., Andrej Karpathy explains here, this means a highly optimized training pipeline; a lot of the DeepSeek-V3 paper (DeepSeek-AI, 2024) is devoted to algorithmic optimizations that go all the way down to the hardware level. This is definitely significant progress, other labs will no doubt implement similar optimizations too, and we will devote a section of this post to the higher-level optimizations introduced by DeepSeek.

    But it does not mean, as many seem to assume, that you can “get in the game” of training frontier LLMs on a seven-figure budget. The $6M figure concerns the final training run, but it was the culmination of God knows how many experiments performed over months and years of work, by over a hundred rather expensive researchers (note how many authors the DeepSeek papers have), on a cluster that was specially built for DeepSeek by their owner, the High-Flyer hedge fund that does not seem to lack for money in the slightest. A recent report by SemiAnalysis (Patel et al., Jan 31, 2025) estimates DeepSeek’s total server CapEx at ~$1.6B, with a further ~$944M in operating costs.

    This is going to be a large post. We will:

    • discuss the origins of reasoning models in chain-of-thought techniques that started with the famous “let’s think step by step” observation;
    • consider again the o1 series and what it meant for the world of AI when it arrived;
    • show several examples of o1 replications; there were plenty, but my favourite example will be a sequence of three papers that explains how attempts at building a reasoning model ran into the “bitter lesson” of simple distillation;
    • with that, we will finally proceed to DeepSeek-R1; first, I will explain the reinforcement learning background behind it, with a brief discussion of how policy gradient algorithms evolved to DeepSeek’s GRPO;
    • the other pillar for R1 was the base model, DeepSeek-V3; it also deserves a separate section where we will consider its novelties, in particular multi-head latent attention, which did not exist yet when I wrote a post on extending the LLM context size;
    • and with that, we will finally go to the actual R1 paper, explaining how it works, what were the results, and what else DeepSeek has contributed.

    That’s a lot of ground to cover—let’s begin!

    “Let’s Think Step by Step”: Chains and Trees of Thought

    Chain-of-Thought. But how do large models “reason”? The answer begins with the development of chain-of-thought prompting. One of the earliest breakthroughs in reasoning LLMs came from an unexpected place—just asking the model to think aloud.

    You have no doubt heard of the famous phrase “let’s think step by step”. Back in May 2022, researchers from the University of Tokyo and Google (Kojima et al., 2022) found that simply adding this phrase before each answer significantly improved an LLM’s performance!

    It turned out that by including this simple instruction in the prompt, models were encouraged to generate a series of intermediate reasoning steps rather than jump directly to the final answer, helping the model organize its thoughts and improving the clarity and accuracy of responses. The prompt urged the models to decompose a complicated problem into subtasks.
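    To make this concrete, here is a minimal sketch of the difference between direct prompting and zero-shot chain-of-thought prompting in the spirit of Kojima et al. (2022); the complete function is a hypothetical stand-in for whatever LLM API you use, and the two-stage structure (first elicit reasoning, then extract the answer) follows the original paper:

    ```python
    # A minimal sketch of zero-shot chain-of-thought prompting (Kojima et al., 2022).
    # `complete` is a hypothetical stand-in for an LLM completion call.
    def complete(prompt: str) -> str:
        raise NotImplementedError("plug in your favourite LLM API here")

    question = ("A juggler can juggle 16 balls. Half of the balls are golf balls, "
                "and half of the golf balls are blue. How many blue golf balls are there?")

    # Direct prompting: the model is asked for the answer immediately.
    direct_prompt = f"Q: {question}\nA: The answer is"

    # Zero-shot CoT: the same question, but the trigger phrase nudges the model
    # to generate intermediate reasoning steps before committing to an answer.
    cot_prompt = f"Q: {question}\nA: Let's think step by step."

    reasoning = complete(cot_prompt)                       # intermediate reasoning
    answer = complete(cot_prompt + " " + reasoning + "\nTherefore, the answer is")
    ```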

    This was a zero-shot variation on the earlier idea of chain-of-thought prompting: in few-shot prompts, you can give the model full reasoning examples rather than just curt answers. The picture above started from item (c); here is the missing part:

    Instead of simply providing a final answer, chain-of-thought (CoT) prompting guides the model to articulate its intermediate reasoning—often through a series of natural language steps—and this too led to significant performance boosts on many tasks, especially reasoning-heavy ones such as arithmetic, logic, and commonsense reasoning. Chain-of-thought prompting had been known since at least the work of Google researchers Wei et al. (2022), released in January 2022. It made the internal reasoning process explicit, and this helped LLMs solve problems much better; here is probably the first example of CoT prompting in the literature (Wei et al., 2022):

    Tree-of-Thought. Building on this idea, researchers set to work on the CoT paradigm and developed a lot of novel variations, mostly related to more advanced structured reasoning strategies. The most important step here is probably going from linear chains-of-thought to tree-of-thought methods (Besta et al., 2024). Traditional CoT prompting shows and encourages a single linear progression of reasoning steps; in contrast, tree-of-thought techniques urge the models to explore multiple branches of reasoning simultaneously. In a tree-of-thought framework, each node in the tree represents a partial reasoning sequence, and the model can branch out to consider alternative hypotheses or strategies at each decision point. 

    An early important attempt was made by Wang et al. (2022) who introduced the Chain-of-Thought with Self-Consistency (CoT-SC) approach. In CoT-SC, the model samples several different reasoning paths and then aggregates them to obtain the final answer:
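    Here is a rough sketch of the self-consistency idea in code, with a hypothetical complete function standing in for LLM sampling: generate several independent chains of thought at a nonzero temperature, extract the final answer from each, and take a majority vote.

    ```python
    from collections import Counter
    import re

    def complete(prompt: str, temperature: float = 0.7) -> str:
        raise NotImplementedError("hypothetical LLM sampling call")

    def extract_answer(chain: str) -> str:
        # Naive answer extraction: take the last number mentioned in the chain.
        numbers = re.findall(r"-?\d+\.?\d*", chain)
        return numbers[-1] if numbers else ""

    def self_consistency(question: str, n_samples: int = 10) -> str:
        prompt = f"Q: {question}\nA: Let's think step by step."
        # Sample several independent reasoning paths...
        answers = [extract_answer(complete(prompt, temperature=0.7))
                   for _ in range(n_samples)]
        # ...and aggregate them by majority vote over the final answers.
        return Counter(a for a in answers if a).most_common(1)[0][0]
    ```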

    This approach generated multiple parallel chains of reasoning, but it was not yet a branching tree. Trees were a natural next step, and they appeared in two works submitted to arXiv virtually simultaneously, only two days apart: Long (May 15, 2023) and Yao et al. (May 17, 2023).

    In the former, the tree construction process was supervised by a separate module called the ToT controller, ToT standing for “tree-of-thought”. A checker module evaluates if a solution has been found (for problems like math/logic puzzles you can do it deterministically, for others checking requires an LLM call), and the ToT controller implements a decision procedure of whether to generate more chains of thought from the current node or backtrack; here is an illustration from Long (2023):

    The approach by Yao et al. (2023), also called ToT, was similar; the idea was to generate thoughts, understood as individual coherent reasoning steps, and then implement a voting or other selection mechanism to choose the most promising paths, backtrack, and so on:

    On every step, the model would sample several paths and then vote to decide which path is best; here is an example for a creative writing assignment:

    Graph-of-Thought and beyond. What is the natural step after a tree of thoughts? A more general graph of thoughts, of course! Generalizing to graphs would allow us, for instance, to combine two thoughts into one conclusion, a natural operation that we humans do all the time in our reasoning. And indeed, graphs of thoughts (GoT) appeared very soon after ToT, and again we see two papers that introduced first implementations of this idea appearing only two days apart (I know I should have gotten used to the rate of progress in AI these days, but it is still a bit staggering): Besta et al. (August 18, 2023) and Lei et al. (August 16, 2023).

    Besta et al. (2023) introduced an approach very similar to Long (2023) but with additional actions available in the controller such as aggregation of several thoughts. Here is a detailed overview illustration of their approach:

    Lei et al. (2023) extended this with another novelty: condition nodes that summarize what the model has learned from previous failures or promising chains of reasoning. In this approach, the checker function has the ability to add new conditions, and then potential solutions are checked against these conditions:

    Overall, this line of research progressed along some very expected lines: from chains to trees to graphs. Each step introduced new degrees of freedom in the thought process, and no doubt required new tricks to actually implement and make work, but all this progress was still very much expected. Here is a timeline modeled after the timeline in the survey by Besta et al. (2024):

    In 2024, this approach continued. For example, any tree-of-thought and graph-of-thought approach needs a way to score the potential of new thoughts to decide which to branch further and which to abandon. Usually these were just absolute scores assigned to thoughts by a separate evaluation module, but Zhang et al. (2024) developed a technique based on direct pairwise comparisons. In this approach, instead of generating a single linear chain of intermediate reasoning steps or assigning scores to possible candidates, the model produces multiple candidate thoughts at key decision points and compares them in pairs to determine which one is the most promising. I will not go into the advantages of this approach over direct scoring, but in many cases, it is indeed better to compare things relatively rather than absolutely; here is an illustration from the paper:

    Another interesting development was the Algorithm-of-Thought (AoT) approach by Sel et al. (2023), which formulated problem-solving as an algorithmic process, embedding instructions for tree-based reasoning directly into the prompt. Each node in the tree represents a step in the algorithm, with the model prompted to generate the next step based on prior results. AoT was mostly designed for tasks like multi-step mathematical reasoning and procedural planning, of course:

    So more trees, more graphs, better comparisons—all good progress, but nothing too groundbreaking. Let me also mention the “tree of uncertain thoughts” by Mo and Xin (2023) and the “tree of clarifications” by Kim et al. (2023)—even without a detailed explanation you can already sort of understand what the idea is, right?

    And then something unexpected happened: OpenAI released o1-preview. While chain-of-thought prompting opened the door to reasoning, OpenAI’s o1 series showed how to refine and scale it effectively; o1 models were the first to demonstrate what happens when chain-of-thought reasoning is taken to its logical extreme: scaling with test-time compute. Let us discuss that model and its implications.

    OpenAI’s o1 Series: What It Did and What It Meant

    The o1 announcement and system card. We already discussed o1-preview and its capabilities in an earlier post, “OpenAI’s o1-preview: the First LLM That Can Answer My Questions”. There, I mostly concentrated on how o1-preview could answer “What? Where? When?” questions (see the post if you don’t know what that means—it’s fun!) and other new capabilities, but the most important part was the part about a new scaling law. Large reasoning models can do better when given more computational resources at test time, a feature that is very rare in machine learning unless you count hybrid ML-plus-search solutions such as AlphaZero. Here is a sample plot from the original o1 announcement:

    We discussed the implications in detail in the same post: if an LLM can feed on its own output and can benefit from more thinking time, this opens up a host of new possibilities. But how did o1 achieve that? What was the difference between the CoT-based approaches discussed above and o1?

    In fact, OpenAI did not tell us much about the training process of the o1 series; you can check their posts but there’s no detailed description there. In December, OpenAI released the o1 system card, but it was all about evaluations, both capabilities and red-team safety evals, and a little bit about the data, and not at all about the model. The safety evaluations, by the way, also looked sketchy since they had been made on a different, weaker version of the model (see a detailed description of this by Zvi Mowshowitz).

    As for the actual approach, the system card just said what was already obvious: “The o1 large language model family is trained with reinforcement learning to perform complex reasoning”. While regular CoT asks a fixed model to process thoughts recursively, reasoning models are specifically fine-tuned with reinforcement learning to improve their thinking process.

    We have already discussed how reinforcement learning (RL) helps LLMs help humans with RLHF. Adding reasoning capabilities is a natural next step. The idea is simple: generate a reasoning chain, use it as input for another inference, and repeat, just like we humans can ruminate on our own thoughts and ultimately arrive at novel conclusions that we did not have at first glance. Reasoning requires a model to plan, verify, and correct its own thinking—things that are not explicitly enforced in regular supervised training.

    This again looks like a perfect setting for RL: the model is supposed to produce a number of discrete steps (thoughts) that lead to a reward only at the end (solving a problem), with no intermediate gratification. Just like learning to play chess, which has been done almost to perfection with exactly this “RL from scratch” approach of AlphaZero (Silver et al., 2017) and MuZero (Schrittwieser et al., 2019). Instead of just learning from fixed data, an RL-trained model should be able to explore different reasoning approaches, get feedback in the form of a reward, and refine its behavior over time:

    Before we proceed to replications, let us discuss one more component in this RL process that seemed promising and necessary at the time.

    Process reward models. In 2022, DeepMind researchers Uesato et al. (2022) suggested upending a basic piece of reinforcement learning wisdom: never reward the process, only the final state. If you are learning to play chess, you do not know how good a given move is, you only know who wins at the end, and if you try to reward, e.g., winning material in the middle of the game, you only make things worse: the RL model will engage in reward hacking, and you don’t actually care about capturing a pawn, you care about winning the game.

    But in long chain-of-thought style reasoning, an LLM will output its thoughts along the way, and every thought might be evaluated on its own merits. If the LLM is solving a math problem, we can assume that making an arithmetic error at some intermediate step is not some cunning plan that will lead to a correct proof, it’s just a mistake that can be found and corrected. If you train a model to find these mistakes, the result is a process reward model (PRM) that can pinpoint the mistakes:

    This idea was taken further in a paper called “Let’s Verify Step by Step”, published by OpenAI researchers Lightman et al. (2024). Again, if the model outputs its reasoning as a sequence of steps, you can point exactly to which step went wrong:

    Therefore, Lightman et al. (2024) train a process reward model (PRM) to evaluate the entire thinking process, step by step. For a mathematical proof, it might work like this (the proof on the left is correct, and the proof on the right is incorrect):

    This idea indeed improved reasoning and led to better results. A similar approach was put forward by Xia et al. (2024) who suggested an evaluation methodology called ReasonEval with this exact purpose—verify every step in the solution and find where the actual mistakes occur:

    Although it was just an evaluation model, the authors also showed how to use it to improve reasoning: if you are doing distillation from the reasoning solutions by a stronger teacher model, it helps to filter them by the ReasonEval framework and fine-tune the student model only on the traces that are actually fully correct, not just correct in their final answer.
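    As a sketch of how such step-level filtering might look in practice (with a hypothetical score_step verifier standing in for a PRM or a ReasonEval-style evaluator), one could keep only the traces in which every intermediate step passes the check:

    ```python
    def score_step(question: str, steps_so_far: list[str], step: str) -> float:
        """Hypothetical step-level verifier (a PRM or ReasonEval-style model):
        returns a score for `step` as a continuation of the solution so far."""
        raise NotImplementedError

    def filter_traces(question: str, traces: list[list[str]],
                      threshold: float = 0.5) -> list[list[str]]:
        """Keep only reasoning traces in which every step passes the verifier,
        not just traces whose final answer happens to be correct."""
        good = []
        for steps in traces:
            if all(score_step(question, steps[:i], step) >= threshold
                   for i, step in enumerate(steps)):
                good.append(steps)
        return good  # fine-tune the student model only on these traces
    ```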

    So given that PRMs had been developed further by OpenAI themselves, how exactly would the RL mechanism in o1 work?

    o1 replications. Once the advantages of the new model became apparent (that is, immediately), researchers started to suggest how exactly o1 might achieve this breakthrough. For a good representative of these guesses, see this post by Subbarao Kambhampati: “Imagine you are trying to transplant a ‘generalized AlphaGo’—let’s call it GPTGo—onto the underlying LLM token prediction substrate… the moves are auto-generated CoTs… the success signal is from training data with correct answers… let RL do its thing to figure out credit-blame assignment for the CoTs… during inference you can basically do rollouts”.

    Sounds pretty plausible, but these were just speculations. Naturally, people have tried to actually replicate the new approach, that is, reinvent it and then implement it. How did that go?

    An interesting story unfolded in three papers by the same Generative AI Research Lab (GAIR) at Shanghai Jiao Tong University (Qin et al., 2024; Huang et al., 2024; Huang et al., 2025). The papers were called “O1 Replication Journey”, and the authors’ original intention was to create and openly document the entire research process, from original ideation to testing hypotheses and all the experiments. The first paper was written after only a month of research, in October 2024; here is the timeline from Qin et al. (2024):

    They introduced an approach called “journey learning”, where a model would be trained to output the entire exploration process, including backtracking and failed hypotheses. Qin et al. (2024) showed some original encouraging results and proposed a plan for further research. Their plans involved process reward models and Monte Carlo tree search as a promising algorithm for growing trees-of-thought during test time, and generally the paper reads like the optimistic beginning of a very promising research project.

    Guess what happened next? In less than two months, the lab published a report on the second part of their project (Huang et al., 2024). Their main result was that… once you’ve got a dataset of reasoning traces, you don’t need anything else! The best approach for them proved to be knowledge distillation—basically, using o1 to teach a smaller model by copying its answers and learning from them. The researchers found that with a pretty straightforward approach—supervised fine-tuning on tens of thousands of responses from o1’s API—they could surpass o1-preview in solving complex math problems, specifically on the AIME Olympiad style dataset. Even more surprisingly, the distilled model showed strong generalization—despite only being trained on math problems, it performed well in open-domain question answering and even became less prone to agreeing with misleading questions (what they call “sycophancy”). In just two months, the big project with a lot of moving parts was reduced to this:

    “The bitter lesson” in this case sounds like this: once OpenAI has made o1 available, even in a restricted form, the best way to improve the performance of your model is to use o1 reasoning traces for distillation. You don’t have to actually do anything novel except maybe some filtering procedure—a smarter model is all you need.

    However, Huang et al. themselves point out why this approach looks highly suboptimal from a wider point of view:

    • this imposes an obvious performance ceiling: you can’t become smarter than the teacher model by training on its solution traces;
    • if researchers stop reasoning from first principles, this may lead to stagnation in areas that require genuine breakthroughs; if only a few leading labs are doing breakthrough research, mostly behind closed doors, progress will probably slow down a lot (unless AI models increasingly help with self-improvement research—that’s a big “unless”, in my opinion, but it is a story for another day);
    • in general, AI research seems to be overly concentrated on prompt engineering for top models instead of solving hard problems from the ground up.

    After publishing this report, both optimistic and disappointing at the same time, the GAIR lab found a more specialized field of application for their efforts. The third part of their o1 replication journey, published in January by Huang et al. (2025), concentrates on medical reasoning, showing that inference-time scaling can be helpful for medical diagnoses and treatment planning:

    There have been other replication attempts as well. In December, Zeng et al. (2024) published a paper called “A Roadmap to Reproduce o1” where they surveyed existing approaches to reinforcement learning that might be relevant, including process reward models, reward shaping, and various search techniques that might be used both in training and on inference. They provided a lot of educated guesses on how o1 might operate, but this was still just a roadmap.

    There were more practical replication attempts too. Zhang et al. (2024) released o1-Coder, a model that was specifically designed for programming, a domain where self-play and checking the reasoning step by step work just as well as in mathematical reasoning. They also incorporated a process reward model and used Monte Carlo tree search (MCTS) to improve thinking at test time. Here is their take on the sequence of ML progress, emphasizing that self-play and self-evaluation basically equals infinite synthetic data:

    Industrial players also quickly followed suit: LLaMA-o1 appeared for the LLaMA family (SimpleBerry, 2024) and was later improved to LLaMA-Berry (Zhang et al., 2024), the Qwen team released QwQ, and so on. In LLaVA-o1, Xu et al. (2024) expanded o1-style reasoning to vision-language models. But the main treat was still ahead.

    Before we get into the details of R1, let me jump a little bit ahead in time. As I was writing this post, the bitter lesson struck again: on January 31, a reasoning model called s1 was published by Stanford researchers Muennighoff et al. (2025). This model, with 32B parameters, was trained by pure distillation on a small dataset of 1000 examples, for a total training cost of about $50. While it did not outperform o1 or DeepSeek-R1, it came pretty close on many benchmarks, already exhibited the kind of test-time scaling one would expect of a good reasoning model, and landed at a very good point on the sample-efficiency Pareto frontier:

    Moreover, Muennighoff et al. found some creative but ultimately very straightforward ways to turn additional computational budgets into more intelligence:

    • budget forcing: if the model wants to stop, force it to keep thinking by appending a token like “Wait” to its chain of thought;
    • majority voting: you can run the model several times (in parallel, if possible) and do majority voting on the result.

    Interestingly, both ways yield significant gains with no change to the underlying model. So yes, the bitter lesson is here in full: once you have a stronger model, distillation is all you need to get smaller ones, and raw scaling of compute in very simple ways can often yield additional benefits. Even wisdom of the crowds seems to work for a crowd of one!
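    Budget forcing, in particular, is almost trivial to implement. Here is a rough sketch under the assumption that the model wraps its reasoning in a special delimiter and that we have a generate function that stops at a given string (both the delimiter and the function are hypothetical, not the exact s1 setup):

    ```python
    def generate(prompt: str, stop: str) -> str:
        raise NotImplementedError("hypothetical LLM call that stops at `stop`")

    def budget_forced_answer(question: str, extra_rounds: int = 2) -> str:
        """Budget forcing in the spirit of s1: when the model tries to stop
        thinking, append 'Wait' and make it keep going for a few more rounds."""
        end_of_thinking = "</think>"          # assumed delimiter, model-specific
        prompt = f"{question}\n<think>"
        for _ in range(extra_rounds):
            prompt += generate(prompt, stop=end_of_thinking)
            prompt += " Wait,"                # suppress the stop, force more thought
        prompt += generate(prompt, stop=end_of_thinking) + end_of_thinking
        return generate(prompt, stop="\n\n")  # finally produce the answer
    ```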

    But let’s go back to the main topic. On January 20 (with the paper appearing on arXiv on Jan 22), DeepSeek-AI (2025) released their own replication. The DeepSeek-R1 model made such a splash in the AI community that it overflowed to the general public much more than usual. Moreover, this was a replication accompanied by a detailed academic description in the paper, so we can actually understand and analyze what’s happening in R1. In the rest of this post, we will try to understand what exactly DeepSeek did with R1, and to understand that, we need to begin with reinforcement learning. So let us step back a little.

    Reinforcement Learning in R1: from REINFORCE to GRPO

    Policy gradient methods in RL. Large language models (LLMs) are typically trained through supervised learning—they consume massive datasets of human-written text, training to predict the next token in a sequence. By now we all know that this approach produces models that can generate fluent text, but it does not necessarily make them good at reasoning or even helpful at all. We have discussed that to get to optimal reasoning we need to fine-tune the reasoning chain with reinforcement learning.

    What kind of RL do we need? There are two general approaches to RL:

    • value-based RL learns the so-called value functions, using ML models such as neural networks to predict either state values V(s) or state-action values Q(s,a), i.e., the expected rewards that can be obtained starting from state s (and, for Q, taking action a first); the optimal policy then follows by maximizing these values;
    • policy-based RL learns a policy \pi(a|s) directly, using an ML model to give probabilities of actions in a state; in this case, the policy model is updated directly as a result of a new learning episode.

    In general, policy-based methods are usually more efficient in terms of necessary data, so they would be clearly preferable for LLMs. They have an inherent difficulty in that selecting actions is a piecewise constant function that would not allow gradients to go through, but this problem was resolved over thirty years ago in the REINFORCE algorithm (Williams, 1992), resulting in the so-called policy gradient algorithms. Formally speaking, the algorithm is trying to maximize the expected cumulative reward as a function of the policy \pi with parameters \boldsymbol{\theta},

        \[J(\pi_{\boldsymbol{\theta}}) = \mathbb{E}_{\pi_{\boldsymbol{\theta}}}\left[\sum_{t=0}^\infty \gamma^tr(s_t,a_t)\right],\]

    where r(s_t,a_t) is the immediate reward on step t and \gamma is the discount factor, and via the policy gradient theorem you can compute the gradient of the objective function J as

        \[\nabla_{\boldsymbol{\theta}}J(\pi_{\boldsymbol{\theta}}) = \mathbb{E}_{\pi_{\boldsymbol{\theta}}}\left[\sum_{t=0}^\infty \nabla_{\boldsymbol{\theta}}\log \pi_{\boldsymbol{\theta}}(a_t|s_t) Q^\pi(s_t,a_t) \right].\]
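    In code, the basic REINFORCE update looks roughly as follows; this is a minimal PyTorch sketch for a generic policy, with Q^\pi(s_t,a_t) estimated by Monte Carlo returns-to-go, as REINFORCE does (for an LLM, log_probs would be the log-probabilities of the generated tokens):

    ```python
    import torch

    def reinforce_loss(log_probs: torch.Tensor, rewards: torch.Tensor,
                       gamma: float = 0.99) -> torch.Tensor:
        """log_probs: log pi(a_t|s_t) for one episode, shape (T,);
        rewards: immediate rewards r(s_t, a_t), shape (T,)."""
        T = rewards.shape[0]
        returns = torch.zeros(T)
        running = 0.0
        # Discounted returns-to-go G_t = sum_k gamma^k * r_{t+k}, computed backwards.
        for t in reversed(range(T)):
            running = rewards[t].item() + gamma * running
            returns[t] = running
        # Maximize E[ log pi(a_t|s_t) * G_t ]; negate to get a loss for autograd,
        # whose gradient matches the policy gradient theorem above.
        return -(log_probs * returns).sum()
    ```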

    Actor-critic approaches. The REINFORCE algorithm used the objective function J and its gradient directly, but there are severe disadvantages to this approach. There is a catch in policy gradient methods and RL in general: rewards are often sparse (you only get a score at the end of a sequence) and always stochastic (a given sequence is randomly generated). This leads to all estimates being extremely noisy, with huge variances, and further work in RL focused on reducing this variance.

    To reduce noise, modern policy gradient methods such as proximal policy optimization (PPO) use advantage estimation, learning a separate critic model to predict how much better an action is compared to the average in this state. The critic learns this average, i.e., basically the state value function:

    Formally speaking, this amounts to adding a baseline of V(s) to the objective function; it is easy to check that it actually does not change the optimization results (since the sum of action probabilities is a constant 1):

        \[\nabla_{\boldsymbol{\theta}}J(\pi_{\boldsymbol{\theta}}) = \mathbb{E}_{\pi_{\boldsymbol{\theta}}}\left[\sum_{t=0}^\infty \nabla_{\boldsymbol{\theta}}\log \pi_{\boldsymbol{\theta}}(a_t|s_t)\left(Q^\pi(s_t,a_t)-V^\pi(s_t)\right)\right].\]

    Although formally nothing has changed, in reality the noise has been much reduced at the cost of training a separate model to estimate V^\pi(s).

    Overall, actor-critic algorithms (Konda, Tsitsiklis, 1999) work as follows: in a loop,

    • generate responses for a new training episode;
    • using the critic network, estimate the baseline reward for each step;
    • use the formula above to update the weights \boldsymbol{\theta} of \pi_{\boldsymbol{\theta}};
    • use regular TD-learning to update the weights of V(s).

    Later work introduced parallelization schemes with synchronous and asynchronous updates such as A3C and A2C (Mnih et al., 2016) but the main actor-critic paradigm remained unchanged.

    Proximal policy optimization. The actor-critic paradigm, however, has one more significant drawback: the gradients may grow large, and updating policy weights too aggressively can cause performance to collapse, leading the policy completely away from reasonable regions.

    To alleviate this, proximal policy optimization (PPO; Schulman et al., 2017) introduced a clipped objective function; on a step t,

        \[J_{\mathrm{PPO}}(\pi_{\boldsymbol{\theta}}) = \mathbb{E}_{\pi_{\boldsymbol{\theta}}}\left[\min\left( r_t(\boldsymbol{\theta})A_t,\mathrm{clip}(r_t(\boldsymbol{\theta}), 1-\epsilon, 1+\epsilon)A_t\right)\right],\]

    where \epsilon is a small constant, typically 0.1–0.2, and r_t(\boldsymbol{\theta}) is the ratio function that compares how likely it is to choose the action a_t in state s_t for the new policy compared to the old one:

        \[r_t(\boldsymbol{\theta})=\frac{\pi_{\boldsymbol{\theta}}(a_{t}|s_t)}{\pi_{\boldsymbol{\theta}^{\mathrm{old}}}(a_{t}|s_t)}.\]

    In other words, PPO implements the update directly if it is small, but if the resulting change in probability is too aggressive, PPO clips it, trying to stay inside the trust region of reasonable policies.

    Clipping the objective function stabilizes training while still allowing for exploration. Note that I have omitted the sum over t in the formula above, partly to reduce clutter but also because in practice policy gradient algorithms update the weights after a batch of experience is collected, and timesteps are sampled from this replay buffer rather than all summed together.
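    A minimal sketch of the clipped PPO objective for a batch of sampled timesteps might look like this (the advantage estimates A_t are assumed to come from a separately trained critic):

    ```python
    import torch

    def ppo_clip_loss(logp_new: torch.Tensor, logp_old: torch.Tensor,
                      advantages: torch.Tensor, eps: float = 0.2) -> torch.Tensor:
        """logp_new, logp_old: log-probabilities of the taken actions under the
        new and old policies; advantages: critic-based estimates A_t. Shape (batch,)."""
        ratio = torch.exp(logp_new - logp_old)            # r_t(theta)
        unclipped = ratio * advantages
        clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
        # Elementwise minimum of the two surrogates, negated to get a loss to minimize.
        return -torch.min(unclipped, clipped).mean()
    ```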

    PPO was in fact a simplification of the trust region policy optimization (TRPO) algorithm proposed earlier by Schulman et al. (2015). In TRPO, the trust region was made explicit with a constrained optimization problem:

        \[\max_{\boldsymbol{\theta}}\mathbb{E}_{\pi_{\boldsymbol{\theta}}}\left[ \frac{\pi_{\boldsymbol{\theta}}(a_{t}|s_t)}{\pi_{\boldsymbol{\theta}^{\mathrm{old}}}(a_{t}|s_t)}A_t\right]\quad\text{subject to}\quad \mathbb{E}_{\pi_{\boldsymbol{\theta}}}\left[\mathrm{KL}(\pi_{\boldsymbol{\theta}^{\mathrm{old}}}(\cdot | s_t)\|\pi_{\boldsymbol{\theta}}(\cdot | s_t))\right]\le\delta,\]

    i.e., maximize the expected reward in a region where the new strategy is similar to the old strategy as a distribution over actions taken in state s_t, formalized with the Kullback-Leibler divergence between them. Standard optimization theory tells us that instead of a constraint you can use a penalty, a regularizer on the KL divergence:

        \[\max_{\boldsymbol{\theta}}\mathbb{E}_{\pi_{\boldsymbol{\theta}}}\left[ \frac{\pi_{\boldsymbol{\theta}}(a_{t}|s_t)}{\pi_{\boldsymbol{\theta}^{\mathrm{old}}}(a_{t}|s_t)}A_t- \beta \mathrm{KL}(\pi_{\boldsymbol{\theta}^{\mathrm{old}}}(\cdot | s_t)\|\pi_{\boldsymbol{\theta}}(\cdot | s_t))\right].\]

    This KL penalty term was replaced with clipping in PPO, but when PPO started being applied to actual LLMs, further regularization was required as the policies would venture too far from the original, engaging in reward hacking or forgetting pretrained knowledge. Therefore, InstructGPT (Ouyang et al., 2022) added another KL regularizer to the objective function in their RLHF procedure:

        \[J_{\mathrm{RLHF}}(\pi_{\boldsymbol{\theta}}) =J_{\mathrm{PPO}}(\pi_{\boldsymbol{\theta}}) - \beta \mathrm{KL}(\pi_{\boldsymbol{\theta}}\|\pi_{\mathrm{ref}}),\]

    where \pi_{\mathrm{ref}} is some reference policy, e.g., the one obtained after supervised fine-tuning but before any RL-based fine-tuning.

    From PPO to GRPO. PPO and similar algorithms such as TRPO work well, but they are still actor-critic algorithms, i.e., they require a critic model to estimate the advantage function A_t. This effectively doubles the computational costs, since the critic is as large as the policy model, and if the critic does not train well it can introduce bias and lead policy learning astray. For LLMs, these costs are very high, and the bias problem is more severe since the sequences are significantly longer than usual: you get a reward only for the last generated token in a sequence of thousands, compared to, say, a chess game that lasts for several dozen moves.

    Therefore, the DeepSeek team introduced a new variation of policy gradient algorithms called group relative policy optimization (GRPO) in their previous work on DeepSeekMath (Shao et al., 2024). The idea is simple: we need the critic to effectively normalize the reward estimates. So instead of using a separate model for a critic, we could normalize by sampling several different outputs (answers) from the policy model (LLM) and averaging over them. It’s just like batch normalization: we replace absolute value estimates with statements like “answer 2 is 1.5σ above average”.

    Adding the same KL-based regularizer as InstructGPT, we get

        \[J_{\mathrm{GRPO}}(\pi_{\boldsymbol{\theta}}) = \mathbb{E}_{\pi_{\boldsymbol{\theta}}}\left[\frac{1}{G}\sum_{i=1}^G\min\left(\frac{\pi_{\boldsymbol{\theta}}(a_{i,t}|s_t)}{\pi_{\boldsymbol{\theta}^{\mathrm{old}}}(a_{i,t}|s_t)}{\hat A}_{i,t},\mathrm{clip}\left(\frac{\pi_{\boldsymbol{\theta}}(a_{i,t}|s_t)}{\pi_{\boldsymbol{\theta}^{\mathrm{old}}}(a_{i,t}|s_t)},1-\epsilon,1+\epsilon\right){\hat A}_{i,t}\right)\right]-\beta \mathrm{KL}(\pi_{\boldsymbol{\theta}}\|\pi_{\mathrm{ref}}),\]

    where the averaging goes over a mini-batch of G answers sampled from the LLM; estimates of the action advantages are also derived from the same mini-batch:

        \[{\hat A}_{i,t}=\frac{r_i-\mathrm{avg}(r_1,\ldots,r_G)}{\mathrm{std}(r_1,\ldots,r_G)}.\]
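    To make the difference from PPO concrete, here is a sketch of the group-relative advantages and the resulting loss; it is a simplified sequence-level version (the actual objective is summed per token), with the KL penalty against the reference policy estimated crudely:

    ```python
    import torch

    def grpo_advantages(rewards: torch.Tensor) -> torch.Tensor:
        """rewards: scalar rewards for G sampled answers to the same prompt, shape (G,).
        Group-relative advantage: 'how many sigmas above the group average'."""
        return (rewards - rewards.mean()) / (rewards.std() + 1e-8)

    def grpo_loss(logp_new, logp_old, logp_ref, rewards, eps=0.2, beta=0.04):
        """logp_*: summed log-probabilities of each sampled answer under the new,
        old, and reference policies, shape (G,); rewards: shape (G,)."""
        adv = grpo_advantages(rewards)                    # \hat{A}_i
        ratio = torch.exp(logp_new - logp_old)
        clipped = torch.clamp(ratio, 1 - eps, 1 + eps)
        surrogate = torch.min(ratio * adv, clipped * adv).mean()
        kl_penalty = (logp_new - logp_ref).mean()         # crude estimate of KL(new || ref)
        return -(surrogate - beta * kl_penalty)
    ```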

    Here is a graphical comparison of PPO and GRPO by Shao et al. (2024):

    To give you an intuition of why this makes sense:

    • group normalization in GRPO improves stability in much the same way as a separately trained critic model, but without the memory costs;
    • as a result, GRPO mimics the human labeling process of RLHF: normalized advantages are relative to the group, just like a human in RLHF would rate the quality of an answer compared to a sample of other answers, not in some absolute sense;
    • no critic also means faster training, which was also part of DeepSeek’s success.

    This comes at a computational cost, of course: if you want to average over a mini-batch of G=16 outputs, it means that you have to run your model 16 times.

    In a way, policy gradient algorithms have made a full circle, first introducing a critic model and then getting rid of it when the models have become too large:

    This is not to say that actor-critic algorithms are now obsolete; of course, they are still preferable in many, if not most, situations, since often you will not run into the absolute limits of available memory with a single model. But it is interesting how easy it was to get rid of the critic as soon as it became necessary.

    So where did that get DeepSeek? DeepSeek’s response to o1 was not just replication—it was an ambitious rethinking of model architecture and training efficiency. Now that we are clear on their main RL novelty, let us examine the other pillar of R1: their latest DeepSeek-V3 model.

    DeepSeek-V3: Memory-Efficient KV-Caching and Other Tricks

    Basic structure of DeepSeek-V3. Okay, so by now we understand how RL can be used to improve the reasoning abilities of an LLM, and we have been talking about LLMs for a long time (one, two, three, four, five…). Is that all? Has the DeepSeek team just run RL on an open LLM like Llama, with increasing base model capabilities together with their GRPO algorithm making R1 possible?

    Yes and no. To begin with, they ran GRPO not on Llama but on their own model, DeepSeek-V3, released last December (DeepSeek-AI, 2024). DeepSeek-V3 has its own share of interesting new ideas. First, it is a huge mixture-of-experts (MoE) model, with 671 billion parameters in total but only 37 billion active per token, and a new load balancing strategy. I hope to write a separate post on MoE approaches, so let’s not get into this one now; suffice it to say that in the feedforward part of each Transformer layer, there are many parallel processing paths (experts) and a separate router subnetwork that chooses a few of them (top-k according to its scoring) to activate.

    Second, DeepSeek-V3 uses multi-token prediction (MTP) as part of its training objective: instead of just predicting one token per time step, the model is trained to predict multiple future tokens at once. This is not an entirely novel idea; it sounds very straightforward and indeed had been shown to work by Gloeckle et al. (2024) about a year ago. They used a multi-head architecture to predict several tokens at once and showed some improvements:

    Unlike Gloeckle et al. (2024), DeepSeek predicts several tokens sequentially, keeping the causal chain for every token and predicting the next token with another Transformer block that receives representations produced by the previous blocks as input:

    MTP provides a denser signal and helps the model to “plan ahead” better even before any RL is added.
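    A very rough sketch of one such sequential prediction module might look as follows; the dimensions and the way the hidden state is combined with the next token's embedding are illustrative only, not DeepSeek's exact architecture:

    ```python
    import torch
    import torch.nn as nn

    class MTPModule(nn.Module):
        """Simplified sketch of a sequential multi-token prediction module: it merges
        the hidden states from the previous depth with embeddings of the next known
        tokens and predicts one more token ahead, keeping causal attention."""
        def __init__(self, d_model: int = 256, n_heads: int = 8, vocab_size: int = 32000):
            super().__init__()
            self.merge = nn.Linear(2 * d_model, d_model)
            self.block = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
            self.lm_head = nn.Linear(d_model, vocab_size)

        def forward(self, prev_hidden: torch.Tensor, next_emb: torch.Tensor):
            # prev_hidden, next_emb: (batch, seq, d_model)
            h = self.merge(torch.cat([prev_hidden, next_emb], dim=-1))
            T = h.size(1)
            causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
            h = self.block(h, src_mask=causal)     # causal self-attention in this module
            return self.lm_head(h), h              # logits for the extra token, new hidden
    ```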

    Key-Value Caching and Multi-head Latent Attention. Third, DeepSeek uses the multi-head latent attention (MLA) in its Transformer layers instead of the standard multi-head attention (MHA). This is an approach based on a low-rank approximation to MHA, and although we have discussed similar techniques in a post on extending the context, MLA is new for us because it was introduced only in April 2024, in the DeepSeek-V2 paper (DeepSeek-AI, 2024b). Let us consider it in detail.

    We know the standard MHA mechanism: queries, keys, and values are computed as

        \[\mathbf{q}=\mathbf{W}^Q\mathbf{x},\qquad \mathbf{k}=\mathbf{W}^K\mathbf{x},\qquad \mathbf{v}=\mathbf{W}^V\mathbf{x},\]

    and the results of MHA are given by

        \[\mathrm{Attention}(\mathbf{Q},\mathbf{K},\mathbf{V}) = \mathrm{softmax}\left(\frac{1}{\sqrt{d_k}}\mathbf{Q}\mathbf{K}^\top\right)\mathbf{V}.\]

    We know that the QK^T matrix, which has size L×L for an input sequence of length L, is a bottleneck of quadratic complexity, and we have discussed ways to alleviate it.

    But even if we accept the computational complexity, another important bottleneck here is memory needed for key-value caching. To make generation efficient, in standard MHA each token caches a full set of keys (K) and values (V) for every attention head. In this way, queries from newly generated tokens can be much more efficiently processed through the decoder as most keys and values are already precomputed. Note that this is an approximation: theoretically, representations of all tokens—and hence all keys and values—should change with every new token starting from layer 2, but in practice old tokens reuse their keys and values; otherwise generation would be computationally intractable.
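    Here is a bare-bones, single-head sketch of key-value caching: for every newly generated token we compute its query, key, and value, append the key and value to the cache, and attend over everything cached so far (the weight matrices here are just random tensors for illustration):

    ```python
    import torch

    def decode_step(x_new, W_q, W_k, W_v, cache):
        """x_new: embedding of the newly generated token, shape (d,).
        `cache` holds keys and values of all previous tokens; they are never recomputed."""
        q = W_q @ x_new
        cache["K"] = torch.cat([cache["K"], (W_k @ x_new)[None, :]], dim=0)  # (t, d_k)
        cache["V"] = torch.cat([cache["V"], (W_v @ x_new)[None, :]], dim=0)  # (t, d_v)
        attn = torch.softmax(cache["K"] @ q / cache["K"].shape[-1] ** 0.5, dim=0)
        return attn @ cache["V"]      # attention output for the new token only

    d, d_k, d_v = 16, 16, 16
    W_q, W_k, W_v = torch.randn(d_k, d), torch.randn(d_k, d), torch.randn(d_v, d)
    cache = {"K": torch.zeros(0, d_k), "V": torch.zeros(0, d_v)}
    for _ in range(5):                # generate five tokens, reusing the cache
        out = decode_step(torch.randn(d), W_q, W_k, W_v, cache)
    ```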

    Key-value caching is a standard technique, but when dealing with long sequences (like DeepSeek-V3’s 128K context length), this KV cache becomes very memory-intensive. Researchers have developed some approximations to reduce it:

    • Multi-Query Attention (MQA; Shazeer, 2019) makes different heads in the Transformer decoder share the same set of keys and values;
    • Grouped-Query Attention (GQA; Ainslie et al., 2023) is an interpolation between MQA and the full key-value cache that shares the same keys and values across subsets of attention heads.

    DeepSeek’s MLA takes another step in the same direction. Instead of directly computing and caching full K and V matrices, MLA projects them into a compressed latent space using a down-projection (compression) matrix \mathbf{W}^C:

        \[\mathbf{c}=\mathbf{W}^C\mathbf{x},\qquad \mathbf{c}\in\mathbb{R}^{d_c},\quad \mathbf{W}^C\in\mathbb{R}^{d_c\times d},\]

    where the compressed dimension d_c is much smaller than the original d. During inference, keys and values are restored with reconstruction matrices (also trained):

        \[\mathbf{k}^R=\mathbf{W}^{RK}\mathbf{c},\qquad \mathbf{v}=\mathbf{W}^{RV}\mathbf{c}.\]

    At this point, we are done with the values v, and the keys k are also further augmented with positional embeddings; DeepSeek uses RoPE rotary embeddings (Su et al., 2024) that are applied to a separately reconstructed version of the key and concatenated to \mathbf{k}^R:

        \[\mathbf{k}^P=\mathrm{RoPE}\left(\mathbf{W}^{RP}\mathbf{c}\right),\qquad \mathbf{k}=[\mathbf{k}^R;\mathbf{k}^P].\]

    Here is an illustration from the DeepSeek-V2 paper (DeepSeek-AI, 2024c):

    Queries are not cached, and they were processed in a standard Transformer way in DeepSeek-V2, but in V3 they also undergo a similar procedure, this time to reduce the memory needed for activations during training:

        \[\mathbf{c}^Q=\mathbf{W}^{CQ}\mathbf{x},\quad \mathbf{q}^R=\mathbf{W}^{RQ}\mathbf{c}^Q,\quad \mathbf{q}^P=\mathrm{RoPE}\left(\mathbf{W}^{RPQ}\mathbf{c}^Q\right),\quad\mathbf{q}=[\mathbf{q}^R;\mathbf{q}^P].\]

    As a result, instead of storing full keys and values for each token, DeepSeek models store only a compressed d_c-dimensional latent representation. Keys and values are reconstructed only when needed, saving memory while maintaining efficiency, and as a result DeepSeek V3 can handle up to 128K token windows without exploding KV storage.
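    Here is a schematic sketch of the caching side of MLA, ignoring the RoPE-carrying part of the key for simplicity; per token we cache only the d_c-dimensional latent c and reconstruct keys and values on demand (the dimensions are made up for illustration, and in the actual models the reconstruction matrices can be absorbed into other projections so that full keys and values never need to be materialized, but the memory accounting is the same):

    ```python
    import torch

    d, d_c = 4096, 512                        # illustrative sizes, with d_c << d
    W_C  = torch.randn(d_c, d)                # down-projection: this defines what gets cached
    W_RK = torch.randn(d, d_c)                # reconstruction matrices for keys and values
    W_RV = torch.randn(d, d_c)

    def cache_token(x: torch.Tensor, latent_cache: list):
        latent_cache.append(W_C @ x)          # store only the compressed latent c per token

    def reconstruct_kv(latent_cache: list):
        C = torch.stack(latent_cache)         # (t, d_c): the entire "KV cache"
        return C @ W_RK.T, C @ W_RV.T         # keys and values, (t, d), rebuilt on demand

    latent_cache = []
    for _ in range(3):
        cache_token(torch.randn(d), latent_cache)
    K, V = reconstruct_kv(latent_cache)
    # Memory per token: d_c floats instead of a full set of keys and values.
    ```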

    At this point, we are done with the DeepSeek-V3 architecture; here is a general illustration (DeepSeek-AI, 2024):

    There are also plenty of tricks in the design of the training pipeline that made DeepSeek’s training so efficient, but I will not get into those here and will simply refer to the original paper (DeepSeek-AI, 2024). For us, it is more interesting to see how this could turn into a reasoning model.

    DeepSeek Reasoning: from V3-Base to R1 to V3

    From V3-Base to V3. By training the model described above on 14.8 trillion high-quality tokens (data preparation is another big challenge that we will not discuss today), DeepSeek obtained what is called DeepSeek-V3-Base: a raw pretrained model that has not yet been fine-tuned for alignment, real-world usage or, for that matter, reasoning.

    To get to DeepSeek-V3, the authors do a standard fine-tuning stage with SFT (supervised fine-tuning) on a carefully curated instruction tuning dataset of 1.5M instances and RLHF with human annotators for data that requires human evaluation: roleplaying in dialogues, creative writing, that sort of thing.

    The most interesting part for us is that in SFT, DeepSeek-V3 also uses distillation from DeepSeek-R1 answers: the instruction tuning dataset is extended with samples that show the R1 responses. To get these responses, we need to get to R1 first, so now we switch over to the latest DeepSeek paper that details the reinforcement learning part (DeepSeek-AI, 2025).

    R1-Zero: no-nonsense RL training. On the surface, reinforcement learning for reasoning LLMs sounds simple: just reward correct answers and let RL do its thing. It has always been tempting to just run an RL algorithm on a pretrained LLM without the hassle of collecting instruction tuning datasets for SFT, getting human annotators, and so on—so, of course, DeepSeek tried just that!

    They started with DeepSeek-V3-Base, used the GRPO algorithm that we discussed above, and defined a reward function with two components: accuracy rewards for actually answering the questions correctly and format rewards for keeping the thinking process within specified tags. As often happens with RL, it is best not to reward the process but only the result; otherwise models may arrive at reward hacking, producing high rewards with undesirable behaviour. This is a very interesting topic (alas, a topic for another day), and modern LLMs do absolutely engage in reward hacking—see, e.g., a recent post by Lilian Weng.

    Note a very important caveat here: to know the reward function, we need to know the correct answers to the problems! DeepSeekMath, a previous effort by the same team (Shao et al., 2024), collected a large curated DeepSeekMath Corpus with mathematical text extracted from the Common Crawl dataset, but this time we need more than just tokens to predict—we need problems, preferably hard problems, with known answers.

    For math and coding, this is relatively simple: you can choose math problems that can be solved by formalized external solvers such as SymPy programs, and for coding you can choose problems that have a comprehensive suite of tests available. You can also generate synthetic data along the same lines. This, together with question answering data with known answers, constitutes most of what DeepSeek calls “rule-based rewards”. Still, it is interesting that the DeepSeek-R1 paper does not specify the exact datasets that it used in reinforcement learning.
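    Here is a sketch of what such a rule-based reward might look like for math problems with known answers; the tag format and the exact scoring are assumptions for illustration, not DeepSeek's actual implementation:

    ```python
    import re

    def extract(text: str, tag: str) -> str | None:
        m = re.search(fr"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
        return m.group(1).strip() if m else None

    def rule_based_reward(response: str, ground_truth: str) -> float:
        reward = 0.0
        thinking = extract(response, "think")
        answer = extract(response, "answer")
        # Format reward: the thinking process must stay within the specified tags.
        if thinking is not None and answer is not None:
            reward += 0.1
        # Accuracy reward: compare the final answer against the known ground truth
        # (for harder math, one would call a symbolic checker such as SymPy instead).
        if answer is not None and answer == ground_truth.strip():
            reward += 1.0
        return reward
    ```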

    Note also that this is the main reason why reasoning models are currently more useful for math and coding than for other tasks: math, coding, and direct question answering (QA) have correct answers that can be verified automatically. For a task like creative writing, there is no correct answer to compare to and no formalized set of tests that a solution must pass, only subjective human evaluation or equally subjective automated evaluation by another LLM, which makes a similar RL pipeline difficult to set up.

    Glossing over the data question, let’s get on to the results. The authors show how the model’s performance grew with more and more reinforcement learning. Here is their plot on the standard AIME evaluation dataset (mathematical olympiad problems); pass@1 is the one-shot performance and cons@16 is the majority voting (consensus) across 16 runs of the model:

    Note that running the model multiple times helps a lot even without any human verification, simply by taking the majority vote of the answers; this is yet another facet of the test-time compute scaling that we discuss in this post.

    As RL training progressed, R1-Zero learned to make better use of the thinking time, and the number of thinking tokens grew steadily throughout the training process:

    The authors note how exciting it was to witness the emergence of new behaviours in the model such as reflection (re-evaluation of previous thinking steps) and exploring multiple approaches to the problem. They also devote a whole subsection to an “aha moment” in R1-Zero’s reasoning, but, to be honest, I didn’t get their example at all:

    To me, it looks like after the “aha moment” the model just repeated the exact same reasoning as above, and I don’t see any mistakes in the formulas the first time around; maybe there’s something I’m missing here, but the paper does not explain it any further…

    Anyway, R1-Zero was already an excellent proof of concept for reinforcement learning for reasoning capabilities, but the DeepSeek team pushed one step further. If you can achieve results on par with a version of OpenAI’s o1 by pure RL, what can you do if you guide it with supervised fine-tuning as well?

    From R1-Zero to R1. To get to the full R1 model, DeepSeek researchers adopted the following process:

    • first collect a small dataset with several thousand chain-of-thought examples and perform supervised fine-tuning of V3-Base on this dataset; this “cold start” phase (a strange name—sounds like a warm start to me!) helps ensure that CoT remains readable (R1-Zero had problems with that) and also improves performance;
    • then comes the large-scale RL in exactly the same way as for R1-Zero;
    • after reasoning-oriented RL has converged, collect a new dataset for further SFT with the resulting checkpoint; at this point, they collect reasoning traces from the model, filtering for readability (e.g., the authors mention that by default their models often reason in a mix of languages, especially English and Chinese, and this is one thing that they filter out here); the resulting dataset has about 600K such traces plus another 200K non-reasoning examples;
    • perform SFT with this dataset of 800K samples;
    • and finally, do an RLHF stage on top of all that, this time targeting helpfulness and alignment with human needs and preferences; similar to the original InstructGPT (Ouyang et al., 2022), they use reward modeling to be able to provide the signal at this point.

    To me, the fact that the authors obtain a separate dataset of 800K samples screams one word: distillation! Once you have already implemented the curation and filtering process, it is easy to reuse the resulting dataset for other models as well: just fine-tune them on the same samples produced by the intermediate R1 checkpoint, or, even better, by the final R1 model.

    Naturally, DeepSeek researchers did this immediately, and this is exactly the SFT step of training the final DeepSeek-V3 (non-reasoning) model. Moreover, the authors did distillation on Llama and Qwen models, and we will see the results below. Note, however, that they only did an SFT distillation step and did not try to subject other (non-DeepSeek) models to their reinforcement learning pipeline. This, as they themselves note, may yield additional benefits, so we will see what other researchers with appropriate computational budgets can do with it.

    But one of the most important takeaways from R1 lies not in what they did, but rather in what they did not end up doing. While everyone else was trying to incorporate PRMs and MCTS into their RL environments, DeepSeek researchers included a section called “Unsuccessful attempts”, where they report that:

    • they tried process reward models but never could make them work: PRMs lead to reward hacking if used throughout the environment and are not worth the computational costs if used only for reranking outputs;
    • they explored Monte Carlo tree search to improve generation on test-time and also found that it didn’t work: the branching factor is too high for efficient search, and it is too hard to train a value model (the V(s) function that evaluates the “current position”) that would work for incomplete generations.

    The bitter lesson strikes again, this time in a weaker form: it turns out that you don’t need detailed guidance for the model, you just need to set up the basic training process in the right way. This was unexpected for many researchers, and it was one of the reasons why DeepSeek-R1 became so popular in a field already replete with o1 replications.

    Results. Another important reason why DeepSeek-R1 made so much noise was that the evaluation results were simply very high. New models appear from time to time, but this was a reasoning model that actually reached the level of OpenAI’s o1 and even exceeded it.

    On math and coding benchmarks (which are o1’s forte as well), R1 showed results on par with, and slightly exceeding, the full o1 model, and significantly ahead of the o1-mini model:

    It remained slightly behind o1 on scientific reasoning in GPQA Diamond, but results on English-language reasoning datasets were also excellent:

    What may be even more important for some applications, DeepSeek also fine-tuned and released a number of distilled models, i.e., regular “non-reasoning” LLMs that were made better by fine-tuning. They used open models from the Qwen 2.5 family and Llama 3.3 and fine-tuned them with SFT on the 800K curated samples that we discussed above. They did not use RL to train further, although the authors note that it might help. Even with this SFT-only approach, the resulting models outperformed other top non-reasoning LLMs in benchmarks:

    This again confirms the “bitter lesson” found by the GAIR lab: once you have a powerful enough model, the best way to bring others up to speed is not some cunning innovation but simply “distill, baby, distill”…

    Oh, and one last thing. Unlike most other frontier models, DeepSeek-V3 and DeepSeek-R1 are indeed completely open, as in both open weights and open code. There are guides on how to install DeepSeek-R1 locally; I used this one, and the 1.58-bit quantized version was a breeze to set up and run even on a home desktop with an RTX 3090 and 64GB of RAM. Naturally, the result was unusable in practice (the model ran at about 1-2 tokens per second), but it was very easy to prove the concept.

    Conclusion

    The journey from chain-of-thought prompting to full-fledged reasoning models represents one of the most exciting developments in AI today. OpenAI’s o1 series set the stage, demonstrating the power of test-time reasoning and opening new possibilities for scaling LLM performance. DeepSeek-R1 pushed the field even further, showing that with the right combination of reinforcement learning and new algorithmic improvements, it is still possible to compete at the highest level with the frontier labs (if not, alas, for under $10M).

    Perhaps the most important lesson from this journey is that cutting-edge AI research continues to balance the classical tension between clever innovations and brute-force scaling. DeepSeek’s GRPO approach is a fresh take on policy gradients, and DeepSeek-V3 incorporates not only the MLA mechanism but a lot of training optimizations that go down to the hardware level. Still, its success ultimately came from combining robust reinforcement learning with traditional supervised fine-tuning and large-scale data collection.

    Oh, and the one thing I simply cannot understand at all? The market reaction. DeepSeek came out and showed that:

    • you can still get in the game of frontier models with a good team and large but not planetary scale amounts of money;
    • moreover, they showed and even released into open source new tricks to make training more efficient, further democratizing frontier AI research;
    • this can inspire new players in the field, increasing the demand for hardware; no, you can’t get in the game of frontier LLM research with $10M of financing, but it seems like you might not need hundreds of billions either;
    • even further, DeepSeek showed that test-time scaling is real and much easier to achieve than it was even half a year ago;
    • therefore, most LLMs will now be expected to have this reasoning mode, which adds a whole new multiplier to hardware requirements: it is no longer enough to train a huge model once over several months, you also need to spend much more compute when serving it to customers, which adds up very quickly;
    • this reasoning mode is some of the best news hardware producers could hope for: large reasoning models need a lot of compute not only for training, which is a one-time cost, but also on every inference; OpenAI already showed this with o1, but DeepSeek proved that most LLMs should have it.

    And what did the market say? NVIDIA stock dropped by 15% with a record absolute loss for any company in history, and NASDAQ had one of its worst days ever:

    This may be a good reason to congratulate DeepSeek: their parent hedge fund High-Flyer could probably recoup all of the costs, and then some, if they anticipated this move in the markets. But I honestly have no idea what the actual logic behind this is. At least the market gained back some ground later, although NVIDIA kept hurting from a lackluster presentation of its RTX 50xx GPU series. Moreover, the actual news of an NVIDIA rival appearing in the mix did not seem to have a significant additional impact.

    But I digress. As we look ahead, the implications are vast. Test-time reasoning models will no doubt become standard in next-generation LLMs, yet another stepping stone on the road to AGI. OpenAI has already released a new family of reasoning models, the o3 series, that are better yet in benchmarks and seem to be the best on the market right now. At the same time, the demand for computational resources will continue to rise: will the hardware providers be able to keep up? We are only at the beginning of this new era of large reasoning models, and the rate of progress is not slowing down anytime soon. 

    We live in some very exciting times, exciting even for those of us who can only watch from the sidelines. I will continue to document the journey of modern AI, and I hope that one day these posts will combine into a greater whole. But until then, we have a lot more topics to discuss—see you next time!

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • Attack of the Titans: Transformers 2.0?

    Attack of the Titans: Transformers 2.0?

    We interrupt your regularly scheduled programming to discuss a paper released on New Year’s Eve: on December 31, 2024, Google researchers Ali Behrouz et al. published a paper called “Titans: Learning to Memorize at Test Time”. It is already receiving a lot of attention, with some reviewers calling it the next big thing after Transformers. Since we have already discussed many different approaches to extending the context size in LLMs, in this post we can gain a deeper understanding of Titans by putting it in a wider context. Also, there are surprisingly many neurobiological analogies here…

    Introduction

    One of the founders of neuropsychology, a Soviet researcher named Alexander Luria, had a stellar career. He managed to remain a world-class scientist in Stalinist Russia, publishing English translations of his famous books and attending conferences all over the world despite working in such highly politicized areas as psychology, child development, and genetics (for instance, he had to defend a degree in medicine instead of biology due to Lysenkoism). Oliver Sacks said that in Luria’s works, “science became poetry”.

    As a founding father of neuropsychology, Luria argued that the human brain operates as a combination of three distinct groups of cognitive functions, or “functional blocks” (Zaytseva et al., 2015):

    • attention, or more precisely regulation of arousal and attention,
    • memory, or reception, processing, and storage of information, and
    • cognitive function, that is, programming, regulation, and verification of activity.

    It is fascinating to see how artificial intelligence keeps conquering these general areas of intelligence one by one. In modern deep learning, the first steps were to process “sensory inputs”, i.e., high-dimensional raw inputs such as images or sound. But classical neural networks, such as deep CNNs characteristic of computer vision, were processing the entire input in the same way and only turned to “cognitive processing” in the last few layers; the bulk of their work was about extracting features from raw inputs.

    The next step turned out to be to implement attention mechanisms: “look at” the inputs selectively and have a separate subnetwork devoted to which inputs to look at now. This is quite similar to what our brains are doing when they are filtering which parts of, say, our field of view should go into the working memory for further processing.

    This step was already sufficient to get us where we are now: as we have been discussing on this blog for the last two years, the self-attention mechanism that enabled Transformers has led to the current crop of LLMs with not that many architectural changes (but a lot of algorithmic improvements, scaling, data collection efforts and so on, of course).

    In fact, neurobiology has the predictive coding theory (Rao, Ballard, 1999; Clark, 2013; Ondobaka et al., 2017) that postulates that the brain is essentially an LLM generalized to arbitrary sensory inputs: we keep a predictive model of the environment and learn when the actual future diverges from the predicted one. Memory is also implemented via learning, with forgetting done by synaptic pruning and decay and long-term memory being updated during the consolidation process, e.g., when we sleep.

    But standard Transformer-based LLMs do not really have a memory. Not in the sense of storing weights on a computer, but in the sense of having a working memory where the LLM might choose to store something and then retrieve it during the “cognitive processing” part of its operation. In a way, methods for extending the context are a substitute for “real” memory, trying to allow straightforward access to everything at once, as are RAG-based approaches that use information retrieval tools to access information that cannot fit into the context. Reasoning LLMs such as OpenAI’s o1 family or the recently released DeepSeek R1 can also be viewed as adding a form of memory: the LLM is allowed to record its thoughts and then return to them in a future inference.

    It sounds like another jump in capabilities may come if we actually add a real working memory to the LLMs, a mechanism that would allow them not only to access back raw context but to save processed information, their “thoughts”, for future reference right as they are processing their input context. In technical terms, this might look like a blend between Transformers and RNNs (recurrent neural networks) since the latter do have memory in the form of a hidden state. In a recent post, we discussed linear attention techniques that are one way to produce this blend.

    On December 31, 2024 (yes, literally), Google researchers Ali Behrouz et al. published a paper called “Titans: Learning to Memorize at Test Time”, with a new way to implement a working memory inside the LLM. Their approach, called Titans, looks like a very straightforward idea: it basically just goes ahead and implements a hidden state inside the Transformer, although it did need recent ideas developed in conjunction with Mamba-like models to actually make it work. Overall, this sounds like one of the most interesting recent developments in terms of pure academic novelty, and some bloggers are already hailing Titans as “Transformers 2.0”.

    In the rest of the post, we dive into the details of Titans. Unlike most other posts here, this one is almost exclusively based on a single paper, so it will follow the paper more closely than usual and will be much shorter than usual. Still, I feel like Titans may be an important advance worth a separate post. We will first discuss the memory mechanisms in Titans (there are three, and two of them are new for us), then show how Titans solves the memory limitations of Transformers, present the technical side, i.e., how this memory works in reasonable time, discuss experimental results, and finish with my conclusion.

    Memory Mechanisms in Titans

    Three levels of memory. At its core, Titans take inspiration from how human memory works. We have short-term memory for immediate information (like remembering a phone number for a few seconds, or like a student can always regurgitate the last few words that the professor has said even if the student had not been paying attention), working memory for active problem-solving, and long-term memory for storing the big stuff. 

    Titans aim to replicate this hierarchy in neural architectures to handle sequences of data more effectively. Similar to the three components of human memory, Behrouz et al. define three different levels of memory mechanisms:

    • short-term memory in Titans, just like in regular Transformers, serves as the “focus lens”; it uses attention mechanisms to zero in on the most relevant pieces of information in the current context, and there is nothing new here;
    • long-term memory is the real secret sauce of Titans; it does not just store all the past information like a database—it implements a separate neural network that learns what to remember, when to remember, and how to forget; note that here we are talking about input-dependent test-time memory, i.e., remembering important stuff from context, not training the weights;
    • persistent memory stores information about the task itself rather than specific data being processed; it is essentially a set of learnable parameters, and unlike long-term memory, they do not depend on the specific input.

    The paper introduces several different ways to combine long-term memory with other types, but a general flowchart might look like this—long-term memory gets updated after processing every token and then both long-term and persistent memory are added to the input:

    Since short-term memory is exactly regular self-attention (see, e.g., my post on Transformers), let us concentrate on the other two memory mechanisms. Long-term memory is the most important part, so let me begin with it and leave persistent memory for the end of this section.

    Long-term memory. How can a neural network memorize stuff? In essence, it should have some kind of a hidden state that stores information, and rules to update this hidden state and retrieve information from it. We are familiar with this concept in two corner cases:

    • a recurrent neural network (RNN) has a fixed size hidden state vector, updated as the RNN processes the input sequence, and the whole point of, say, an LSTM is to learn good ways to update the memory and reuse it at later moments of time;
    • an attention-based network such as the Transformer has perfect recall in the sense that its “attention span” always includes the entire context window; this allows a Transformer to learn long-range dependencies much better but does not scale by default, as we have discussed in detail previously.

    In one of the latest posts, we discussed the notion of linear attention, culminating in Mamba-like models. In terms of memory, linear attention has a fixed memory size (like an RNN) in terms of the number of parameters but it switches from a vector to a matrix hidden state and uses associative memory, storing information in the matrix so that it can be retrieved with a query vector (the Mamba post has details and examples). Mathematically, in the simplest case linear attention gives you the output as an associative memory lookup by the query vector from the memory updated iteratively on every step:

        \[\mathbf{S}_{t}=\mathbf{S}_{t-1}+\mathbf{v}_{t}\mathbf{k}_{t}^\top,\qquad \mathbf{o}_{t}=\mathbf{S}_{t}\mathbf{q}_{t}.\]

    This is, of course, just a special case of RNNs, a special form of hidden state update and retrieval, but a very useful special case. There is a major constraint to linear attention, which we mentioned in the Mamba post and which Behrouz et al. (2024) also identify: when context becomes too long, associative memory overflows, leading to unpredictable forgetting. To alleviate this, researchers have proposed:

    • forgetting mechanisms for linear attention and Mamba-like models, e.g., in the xLSTM (Beck et al., 2024) and in Mamba 2 (Dao, Gu, 2024);
    • changing the write operation from purely additive to, e.g., erasing a previously memorized value; in recent research, this is exemplified by the DeltaNet line of models (Yang et al., 2024a; 2024b).
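
    To make the associative memory view concrete, here is a minimal Python (numpy) sketch of the linear attention recurrence above; the toy dimension and the random vectors are just placeholders for actual keys, values, and queries produced by learned projections.

        import numpy as np

        d = 4                                # toy key/value dimension
        rng = np.random.default_rng(0)

        S = np.zeros((d, d))                 # matrix-valued hidden state, the "associative memory"
        for t in range(10):
            k = rng.normal(size=d)           # key k_t
            v = rng.normal(size=d)           # value v_t
            q = rng.normal(size=d)           # query q_t
            S = S + np.outer(v, k)           # write: S_t = S_{t-1} + v_t k_t^T
            o = S @ q                        # read:  o_t = S_t q_t

    If the keys were nearly orthonormal, S @ k would recover the corresponding value almost exactly; once there are many more keys than dimensions they can no longer be close to orthogonal, and retrieval degrades, which is exactly the overflow problem mentioned above.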

    Titans take this a step further: instead of treating memorization as layers inside the neural network, why not consider memorization as a separate internal neural network trained on the current context? In other words, the long-term memory module in Titans performs online learning: given the current input \mathbf{x}_{t} and the previous memory state M_{t-1}, the goal is to update the memory with gradient descent according to some loss function \ell:

        \[M_t=M_{t-1}-\theta_t\nabla\ell(M_{t-1};\mathbf{x}_{t}).\]

    Where will the loss function come from? This actually moves us a step further to how memory is implemented in the human brain: we remember events that are surprising, that is, that do not conform to the predictions of the future that the brain is constantly producing. Thus, Behrouz et al. define the surprise metric \ell(M_{t};\mathbf{x}_{t}) as the difference between retrieved memory and actual content; for key-value associative memory, where \mathbf{x}_{t} is converted into a key-value pair as

        \[\mathbf{q}=W^Q\mathbf{x},\qquad\mathbf{k}=W^K\mathbf{x},\qquad\mathbf{v}=W^V\mathbf{x},\]

    the loss can be defined as follows:

        \[\ell(M_{t-1};\mathbf{x}_{t})=\left\|M_{t-1}(\mathbf{k}_t)-\mathbf{v}_t\right\|_2^2.\]

    Once again: this is an “internal network”, an inner loop that does this kind of gradient descent while processing a given input, including at inference time. This raises performance concerns, which we will address in the next section.

    What is the actual architecture of the “memory network”? Behrouz et al. (2024) use straightforward multilayer perceptrons (MLP) as M(\mathbf{k}_t); this is natural for a first shot at the goal but they also give two important remarks:

    • the “multi” in “multilayer” is important: standard associative memory can be thought of as a single-layer linear perceptron, but as you make the memory network deeper you unlock more expressivity for compressing the data, and this is actually important in practice; here is the corresponding plot from one of their ablation studies:
    • there already exist architectures that are more efficient in data memorization than MLPs, e.g., memory layers by Berges et al. (2024) or Universal Transformer Memory by Cetin et al. (2024); tweaking the architecture of long-term memory is a natural direction for future work.

    Momentum and forgetting. There are two more key modifications to the idea above. First, if you encounter a “surprising” step, gradient descent might quickly reduce the surprise to a small value even though the surprising part of the data continues for longer. Behrouz et al. again compare this to human memory: “an event might not consistently surprise us through a long-period of time although it is memorable. The reason is that the initial moment is surprising enough to get our attention through a long time frame, leading to memorizing the entire time frame”.

    Therefore, it makes sense to use not the surprise metric directly but an exponential moving average version of it that would preserve a part of the previous values of the metric. In gradient descent, this reduces to the well-known method of adding momentum:

        \[M_t = M_{t-1} + S_t,\]

    where

        \[S_t = \eta_tS_{t-1} - \theta_t\nabla\ell(M_{t-1};\mathbf{x}_{t}).\]

    Note that the momentum coefficient \eta_t can also be data-dependent: the network may learn to control when it needs to cut off the “flow of surprise” from previous timestamps.

    Second, as the context grows in size the memory network will eventually overflow, and this leads to the need to have some kind of forgetting mechanism as well. Titans use a “forget gate” for this, adding a value \alpha_t\in[0,1] that controls how much we forget during a given step. Overall, the update rule looks like the following:

        \[M_t = (1-\alpha_t)M_{t-1} + S_t,\quad S_t = \eta_tS_{t-1} - \theta_t\nabla\ell(M_{t-1};\mathbf{x}_{t}).\]

    Note that \alpha_t is also data-dependent, and the network is expected to learn when it is best to flush its memory.
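
    To summarize the whole update in code, here is a minimal Python sketch of the long-term memory loop. I take the simplest possible special case where the memory is a linear map M(\mathbf{k}) = W\mathbf{k} (recall that Titans actually use a deeper MLP), write out the gradient of the surprise loss by hand, and use constant \theta, \eta, and \alpha as stand-ins for the data-dependent gates; the projection matrices are random placeholders.

        import numpy as np

        d = 8
        rng = np.random.default_rng(0)

        W_K = rng.normal(size=(d, d)) / np.sqrt(d)   # key projection (placeholder)
        W_V = rng.normal(size=(d, d)) / np.sqrt(d)   # value projection (placeholder)

        W = np.zeros((d, d))      # memory M_t, here just a linear map M(k) = W k
        S = np.zeros((d, d))      # momentum accumulator ("past surprise")
        theta, eta, alpha = 0.1, 0.9, 0.01           # constant stand-ins for data-dependent gates

        def surprise_grad(W, k, v):
            # gradient of the surprise loss ||W k - v||^2 with respect to W
            return 2.0 * np.outer(W @ k - v, k)

        for x in rng.normal(size=(100, d)):          # stream of context tokens x_t
            k, v = W_K @ x, W_V @ x
            S = eta * S - theta * surprise_grad(W, k, v)   # S_t = eta S_{t-1} - theta * grad
            W = (1.0 - alpha) * W + S                      # M_t = (1 - alpha) M_{t-1} + S_t

        # at read time, the memory's guess for the value associated with a key k is simply W @ k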

    Persistent memory. As we have seen, long-term memory is input-dependent, i.e., it depends on the context. Obviously, it would also be useful to have input-independent memory that would store information specific to the entire task at hand, not the specific current input.

    Titans use a simple but efficient approach to persistent memory: they introduce additional learnable parameters P that are appended to the start of every sequence:

        \[X'=\mathrm{concat}\left(\left(\mathbf{p}_1, \mathbf{p}_2, \ldots, \mathbf{p}_{N_p}\right), X\right).\]

    The parameters P do not depend on X and can act as task-related memory.
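
    In code this is as simple as it sounds: a learnable block of “virtual tokens” prepended to every input sequence. A minimal Python sketch (the number of persistent tokens and all sizes here are arbitrary):

        import numpy as np

        n_persistent, seq_len, d = 4, 16, 8
        rng = np.random.default_rng(0)

        P = rng.normal(size=(n_persistent, d))    # learnable, input-independent parameters
        X = rng.normal(size=(seq_len, d))         # the actual input sequence

        X_prime = np.concatenate([P, X], axis=0)  # X' = concat(P, X); attention then runs over X'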

    This idea has already been introduced in literature. For example, Hymba (Dong et al., 2024), a hybrid architecture combining Transformers and Mamba-style attention, introduced the so-called meta tokens that work in exactly the same way. On the other hand, Xiao et al. (2024) treated this idea as “attention sinks”, using it to alleviate the effect that the first attention weights, close to the start of the sequence, are almost always large.

    In essence, this is it; no ODEs, no complicated math, just a neat meta-idea of recurrently training a small memory network during each computation on both training and inference, supplemented by persistent memory. There are, however, a few little devils lurking in the implementation details; let’s discuss those.

    Technical Details: Adding Memory and Training It, Too

    Where to add memory. Although we have described the proposed memory mechanism in full detail, there is also a choice of where exactly to add it in the architecture. Behrouz et al. (2024) consider three different options.

    Memory as a Context (MAC) means that results of retrieval from memory get added to the input \mathbf{x}, so it effectively becomes a concatenation of persistent memory, lookup results from contextual memory, and \mathbf{x} itself. After this extended input is processed, memory is updated as we have discussed above:

    In this case, Titans split the input sequence into chunks that are fed into self-attention one by one, and the long-term and persistent memory mechanisms are prepended to it, so the overall structure of attention masks (recall our earlier post on extending the context) looks like this:

    Memory as a Gate (MAG) is a different approach where only the persistent memory is added to the context, and lookup results are used to control the output via a gating mechanism:

    In this case, the input data is not segmented: the input is used to update long-term memory, and the attention mechanism uses a sliding window, like this:

    Finally, in the Memory as a Layer (MAL) architecture the memory layer serves as a preprocessing mechanism before the context (extended with persistent memory as usual) is fed into the attention mechanism:

    This approach also suggests that since neural memory can already serve as a layer, we can do without the self-attention layer altogether; this architecture is called the neural memory module (LMM) in Behrouz et al. (2024).

    Parallel training. Finally, there is the question of computational complexity. Titans sound like a very interesting idea but what use would it be if we had to run actual training for the long-term memory module, sequentially, on the whole context for every training sample?

    Fortunately, that’s not the case. It does sound a little surprising to me, but actual tools to parallelize this “test-time training” have only been devised over the last couple of years: Behrouz et al. (2024) use the mechanism of RNNs with expressive hidden states that represent machine learning models themselves, a technique developed only half a year ago by Sun et al. (2024).

    For Titans, this works as follows. Let’s go back to the update rule, first without momentum:

        \[M_t = (1-\alpha_t)M_{t-1} - \theta_t\nabla\ell(M_{t-1};\mathbf{x}_t).\]

    Let’s say that we are doing mini-batch gradient descent with batches of size b, i.e., the input is divided into chunks of size b. In that case, we can write

        \[M_t = \beta_tM_0 - \sum_{i=1}^t\frac{\theta_i\beta_t}{\beta_i}\nabla\ell(M_{t'};\mathbf{x}_i),\]

    where t' is the start of the current mini-batch, t'=t-(t\,\mathrm{mod}\, b), and \beta_i collects the decay terms,

        \[\beta_i= \prod_{j=1}^i(1-\alpha_j).\]

    This formula accumulates all the influences up to time step t, and now we can rewrite it in a matrix form. For simplicity, let’s focus on the first mini-batch, with t=b and t'=0, and let’s assume that M_t is linear, M_t=W_t. Now the gradient of our quadratic loss function in matrix form is

        \[\nabla\ell(W_0;\mathbf{x}_t)=\left(W_0\mathbf{x}_t-\mathbf{x}_t\right)\mathbf{x}_t^\top,\]

    which means that

        \[\sum_{i=1}^b\theta_i\frac{\beta_b}{\beta_i} \nabla\ell(M_0;\mathbf{x}_i) = \Theta_b B_b\left(M_0X-X\right)X^\top,\]

    where \Theta_b is a diagonal matrix containing the learning rates \theta_i for the current mini-batch and B_b is a diagonal matrix containing the decay ratios \beta_b/\beta_i for the batch. The matrices \Theta_b and B_b are only needed for the current batch; you don’t have to store them for all N/b mini-batches, so this already makes the whole procedure computationally efficient and not exceedingly memory-heavy.
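
    Here is a tiny Python check of this identity under the same simplifying assumptions (linear memory, first chunk, no momentum); the per-token weight \theta_i\beta_b/\beta_i is applied explicitly, and the toy sizes and random values are placeholders:

        import numpy as np

        rng = np.random.default_rng(0)
        d, b = 4, 8                                  # toy dimension and mini-batch (chunk) size

        W0 = rng.normal(size=(d, d)) / np.sqrt(d)    # memory at the start of the chunk
        X  = rng.normal(size=(b, d))                 # rows are the tokens x_1, ..., x_b
        theta = rng.uniform(0.01, 0.1, size=b)       # per-token learning rates
        alpha = rng.uniform(0.0, 0.1, size=b)        # per-token forget gates
        beta = np.cumprod(1.0 - alpha)               # beta_i = prod_{j <= i} (1 - alpha_j)

        # sequential reference: accumulate the weighted gradients one token at a time
        ref = np.zeros((d, d))
        for i in range(b):
            x = X[i]
            grad = np.outer(W0 @ x - x, x)           # gradient from the formula above
            ref += theta[i] * (beta[-1] / beta[i]) * grad

        # batched form: one weighted matrix product over the whole chunk
        w = theta * (beta[-1] / beta)                # the diagonal of Theta_b B_b
        par = ((X @ W0.T - X) * w[:, None]).T @ X    # equals sum_i w_i (W0 x_i - x_i) x_i^T

        assert np.allclose(ref, par)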

    As for adding momentum, the method above means that the update terms can all be computed at the same time, so the momentum update term can also be computed very efficiently with the parallel associative scan developed in the S5 paper (Smith et al., 2023) that we discussed in a previous post (see also this explanation).

    Experimental evaluation. The Titans paper has an experimental comparison of all four versions we have discussed above: MAG, MAC, MAL, and LMM. In most categories, LMM is the best “simple” (recurrent) model and either MAG or MAC is best in the “hybrid” category, where a memory mechanism is augmented with regular self-attention. In the original paper, most comparisons deal with relatively small models; e.g., here is the perplexity among language models with 340M parameters (measly by modern standards) trained on different datasets:

    The authors say that experiments with larger models are coming, and it makes sense that they may take a long time. One set of experiments where data from state of the art models is available deals with the BABILong benchmark (Kuratov et al., 2024), a needle-in-a-haystack benchmark where the model has to not just find the “needle” statements in a long “haystack” context but distinguish statements that are relevant for the inferential task at hand and perform the inference. On this recently released benchmark, Titans outperforms much larger models, including not only Llama and Qwen with ~70B parameters but also GPT-4:

    This plot looks very convincing to me: it appears that the long-term memory that Titans propose does indeed work very well.

    Conclusion

    Just how important are Titans? In my opinion, this is a bona fide engineering breakthrough that combines an ambitious idea (adding memory to LLMs), an elegant implementation (the part about parallelizing is important), and excellent test results. Unfortunately, for now we only have the results of relatively small experiments: while this is understandable for an academic paper, only time will tell just how well this idea will survive scaling to state of the art model sizes.

    Before we get experimental confirmation, we should not get overexcited. The path of deep learning is paved with ideas that sounded great on paper and in early experiments but ultimately led nowhere. Cyclical learning rates led to superconvergence (Smith, 2015), the amsgrad algorithm fixed a very real error in the Adam optimizer and was declared the best paper of ICLR 2018 (Reddi et al., 2018), ACON activation functions bridged the gap between ReLU and Swish in a very elegant and easily generalizable way (Ma et al., 2020)… all of these ideas made a splash when they appeared but ultimately were all but discarded by the community.

    Still, I truly believe that memory is a key component of artificial intelligence that is still underdeveloped in modern architectures. And there are just so many directions where Titans may go in the future! Behrouz et al. (2024) themselves mention exploring alternative architectures for the long-term memory module (e.g., recurrent layers instead of MLPs). One could also apply Titans to tasks requiring ultra-long-term dependencies, like video understanding or large-scale scientific simulations. Getting working memory right may be one of the last obstacles on the path to true AGI. I am not sure that Titans is getting it exactly right, but it is a very interesting step in what is definitely the right direction.

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • LLMs for Programming: Is Coding Dying?

    LLMs for Programming: Is Coding Dying?

    It is time to discuss some applications. Today, I begin with using LLMs for programming. There is at least one important aspect of programming that makes it easier than writing texts: source code is formal, and you can design tests that cover at least most of the requirements in a direct and binary pass/fail way. So today, we begin with evaluation datasets and metrics and then proceed to fine-tuning approaches for programming: RL-based, instruction tuning, and others. Next, we will discuss LLM-based agents for code and a couple of practical examples—open LLMs for coding—and then I will conclude with a discussion of where we are right now and what the future may hold.

    Introduction

    Over the last couple of months, I have been getting plenty of post and video recommendations with bloggers all saying the same thing: coding is dead. AI is taking over programming, there is no point learning to code, humans should move from coding to biology, farming, or videoblogging — you’ve heard the sentiment, no doubt.

    These bloggers are, if anything, late. Jensen Huang said it back in February 2024: at the World Government Summit in Dubai, he said that while for many years every tech person had considered it “vital” for young people to learn computer science, to learn how to program, “in fact, it’s almost exactly the opposite. It is our job to create computing technology such that nobody has to program. And that the programming language is human. Everybody in the world is now a programmer. This is the miracle of artificial intelligence.” This sentiment has been around ever since at least OpenAI Codex, a coding model published back in 2021.

    Today, as 2024 is turning into 2025, where are we with AI for coding? Is coding really dead? This is the main question I want to consider today.

    Another motivation for this post is that over the last months (actually, almost two years already!), we have been discussing modern generative AI and in particular large language models (LLMs). Lately, we have seen several directions for improving and extending fundamental Transformer-based LLMs, including fine-tuning methods, RAG, extending the context, and Mamba-like models. But how do these and other methods come together in real applications? To answer this, I want to make a couple of case studies in specific fields where LLMs are being actively applied. We have already had such a post before, devoted to AI in mathematics, but it was significantly more general and did not feature modern LLMs too much.

    For today’s case study, I have chosen programming and other code-related tasks, a field where, first, we can discuss fine-tuning approaches in a rather straightforward way, and second, some research directions are familiar to me personally (Lomshakov et al., 2024; Sedykh et al., 2024; Lomshakov et al., 2023). This post is partly based on a survey of LLMs for code we have recently done together with Vadim Lomshakov, so big thanks to him too.

    Our plan for today is simple but quite extensive:

    • we begin by discussing the datasets and evaluation metrics applicable to code-related tasks; this is an especially important part in this case because the main difference between natural text and code is that code is formal, which means that you can have straightforward evaluation metrics based on passing tests;
    • next, in parallel with my previous post on fine-tuning, we go on to applications of fine-tuning to the coding domain, covering in sequence:
      • methods based on reinforcement learning,
      • instruction tuning,
      • prompting, and
      • other methods, including retrieval-augmented generation (RAG) for code and resampling multiple solutions;
    • the last “academic” section will discuss agentic approaches: LLMs that can write, execute, and re-write code iteratively; this is one of the most natural things to do with coding tasks, and this is definitely an approach we will see more and more in the nearest future;
    • finally, we will discuss a couple of open models, from the Qwen and DeepSeek family, where we actually know how these models were trained; we will see that ideas discussed in academic literature do indeed make their way into top offerings.

    Naturally, I cannot hope to give a full overview in this post; that would require a huge paper or even a whole book, so for a more comprehensive treatment, I refer to, e.g., (Jiang et al., 2024; Zan et al., 2023; Joel et al., 2024) and other surveys. But for every direction, I give a couple of representative examples that, I hope, will provide a good intuition for coding applications of LLMs.

    However, all of this content, while interesting, will not answer the main question: is coding dying? So in the final part, we will discuss the current state of top models for coding, where we often don’t really know how these models work but we can definitely see some great examples.

    Enjoy!

    Evaluating LLMs for Code: Data and Metrics

    Source code is a rather special field compared to other text generation tasks. In text writing and answering open-ended questions, there can be a million ways to formulate the answer, it is hard to establish which of these ways are equivalent, and often one has to resort to asking people (like in RLHF that we discussed in a previous post). Sometimes you can get by with multiple choice questions with well-defined answers, but often multiple choice cannot get you all the way to a real application.

    In coding, there are still a million different ways to write a program but often you can define formally verifiable ways—simply speaking, tests—that can tell you if the function is correct better than any human supervision (although still not perfectly: it is usually impossible to have full test coverage). In this section, we discuss existing ways to evaluate code models, including the metrics and open datasets available for their fine-tuning and evaluation. We will refer to the test sets from this section many times below since they are the sets where models get compared with each other.

    HumanEval

    This benchmark, published by OpenAI researchers Chen et al. (2021), is one of the most popular datasets for code models. The main task considered in HumanEval is evaluating the quality of Python functions generated from a given description written in documentation (docstring) style. The dataset contains 164 programming tasks designed by hand and covering Python understanding, algorithms, and basic math and computer science; these tasks are similar to easy questions in a software developer interview.

    The primary evaluation metric is pass@k, computed as

        \[\mathrm{pass}@k := \mathbb{E}_{\text{Problems}} \left[ 1 - \frac{\binom{n-c}{k}}{\binom{n}{k}} \right],\]

    where n is the total number of generated solutions, c is the number of solutions that have successfully passed all tests, k is the number of chosen solutions, and brackets denote the binomial coefficient.

    The \mathrm{pass}@k metric has an intuitive combinatorial meaning: it shows the probability that at least one out of k randomly chosen solutions turns out to be correct. Note how this corresponds to good practice in using LLMs: since running an LLM is cheap anyway, it is usually best to ask it for several solutions and/or regenerate the answer several times, choosing the best one. Naturally, this applies not only to coding but to all tasks, but in coding, it is especially easy to choose the best out of k even for a relatively large k.
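
    In code, the formula above is usually computed in the numerically stable product form suggested in the HumanEval paper; here is a short Python version with a toy usage example:

        import numpy as np

        def pass_at_k(n: int, c: int, k: int) -> float:
            # probability that at least one of k solutions drawn without replacement
            # out of n generated ones (c of which are correct) passes the tests
            if n - c < k:
                return 1.0                    # fewer than k incorrect samples: always a hit
            return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

        # e.g., 200 samples per problem, 37 of which pass all tests:
        print(pass_at_k(n=200, c=37, k=1))    # 0.185, i.e., simply c/n
        print(pass_at_k(n=200, c=37, k=10))   # much higher, about 0.88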

    MBPP

    Mostly Basic Programming Problems (MBPP) is another popular dataset for generating individual functions in Python, this time published by Google Research (Austin et al., 2021). MBPP contains 974 short Python programs collected with crowdsourcing among people with some Python experience.

    Each entry in the dataset contains a description of the problem, its solution as a Python function, and three tests to check correctness. The tasks are again relatively small and interview-like: 58% of the problems are mathematical in nature, 43% are related to list processing, and 19% to string processing; the average solution length is 6.8 lines. After manual verification and refinement, the authors chose 426 guaranteed clear and correct questions for the final dataset. The primary evaluation metric is the proportion of solved problems, and a problem is, again, considered solved if at least one of the K solutions passes the tests successfully; Austin et al. (2021) used K=80.

    Extended versions of HumanEval and MBPP have been released as HumanEval+ and MBPP+ respectively; in these extensions, the authors added more tests and clarified some ambiguous problem statements (Liu et al., 2023).

    APPS

    The next benchmark, Automated Programming Progress Standard (APPS; Hendrycks et al., 2021), contains 10000 programming problems collected from open platforms such as CodeForces and Kattis. It is designed to evaluate both programming and problem-solving skills; it includes natural language problem descriptions and over 130K test cases to verify solutions. The dataset also contains over 230K reference solutions written by humans, with problem difficulty ranging from beginner to university-level competitions.

    This time the problems are significantly more complex: the average problem description length in APPS is 293.2 words. The dataset is evenly divided into 5000 problems for training and 5000 for testing, with test problems averaging 21.2 unit tests per problem, so the coverage can be assumed to be quite good.

    All problems are divided into three difficulty levels: beginner, interview-level, and competition-level. APPS uses two main evaluation metrics:

    • test case average is the average number of tests passed, i.e., the proportion of tests successfully passed by generated solutions; for a problem p with a test set C_p, this metric is computed as

          \[\mathrm{Test Case Average} = \frac{1}{P} \sum_{p=1}^{P} \frac{1}{C_p} \sum_{c=1}^{C_p} \left[ \mathrm{eval}(\langle \mathrm{code}_p \rangle, \mathbf{x}_{p,c}) = \mathbf{y}_{p,c} \right],\]

      where P is the total number of problems, \mathrm{eval}(\langle \mathrm{code}_p \rangle, \mathbf{x}_{p,c}) is the result of executing the code on input data \mathbf{x}_{p,c}, and \mathbf{y}_{p,c} is the expected (correct) result;
    • strict accuracy requires that generated solutions pass all tests for a problem, so it is computed as

          \[\mathrm{StrictAccuracy} = \frac{1}{P} \sum_{p=1}^{P} \prod_{c=1}^{C_p} \left[ \mathrm{eval}(\langle \mathrm{code}_p \rangle, \mathbf{x}_{p,c}) = \mathbf{y}_{p,c} \right].\]

    As you can see, the average number of tests passed is a less demanding metric; it captures improvements even in models that cannot yet solve problems completely. “Strict accuracy”, on the other hand, reflects a model’s ability to solve problems entirely, including complex edge cases, which is crucial for real-world applications.
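
    To make the difference between the two metrics concrete, here is a short Python sketch that computes both from per-test pass/fail results; the data layout (one boolean array of test outcomes per problem) is just an assumption for illustration:

        import numpy as np

        def apps_metrics(results):
            # results[p] is a boolean array: which tests the generated solution for problem p passed
            test_case_average = float(np.mean([np.mean(r) for r in results]))
            strict_accuracy = float(np.mean([np.all(r) for r in results]))
            return test_case_average, strict_accuracy

        # two problems: one passes 3 of its 4 tests, the other passes both of its 2 tests
        print(apps_metrics([np.array([1, 1, 1, 0], dtype=bool),
                            np.array([1, 1], dtype=bool)]))   # (0.875, 0.5)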

    LiveCodeBench

    This is another commonly used Python benchmark that aims to address certain issues with its predecessors (Jain et al., 2024). The main problem that the authors tried to avoid is the risk of contamination, i.e., leaking test data into the training set. To this end, they collected 511 competitive programming problems from platforms such as LeetCode, AtCoder, and CodeForces, published between May 2023 and May 2024.

    Besides, for a more comprehensive comparison of code models, the authors introduced three additional task categories beyond function generation from documentation: 

    • fixing bugs in programs based on failed tests; 
    • virtual “execution” of a program with specific input; 
    • predicting output based on a problem description and input data. 

    Here is the illustration from LiveCodeBench:

    For program generation and debugging, LiveCodeBench uses the pass@k metric defined in HumanEval; there are about 17 tests per task on average so coverage is again quite good. For program execution and output prediction tasks, it uses a binary metric comparing the result to the reference answer. Moreover, the authors classified all tasks by difficulty level (easy, medium, hard) and balanced their distribution, allowing for a more detailed comparison of models with the same average metric value.

    The “Live” part of LiveCodeBench deserves special attention. It might seem that contamination should not be a serious problem for LLMs: they are trained on trillions of tokens, and even the code-related part used for fine-tuning is usually quite large, so even a very large model definitely cannot memorize its training set. However, experiments have shown that depending on the release date of the model, a sharp decline in performance indeed happens even for top models such as DeepSeek, GPT-4o, or Codestral: if problems are not part of the datasets available for (pre)training at the time of the model release, some models perform noticeably worse on them.

    Therefore, to get a more realistic quality assessment one has to select more recent problems, published after the release of the model. The LiveCodeBench leaderboard compares models only on problems added after the cutoff date of their training and fine-tuning sets. For example, here are some top rows from the LiveCodeBench leaderboard on problems submitted after January 1, 2024:

    As you can see, some top entries are italicized and not assigned a rank; these are exactly the models that might be contaminated by having some test problems in their training sets, and while their metrics are still computed, it is easy to distinguish such models. Here is a sample leaderboard for the bugfixing part of LiveCodeBench:

    I took these samples from the LiveCodeBench leaderboard in early December; by the time I’m publishing this post, another champion has already been crowned, but let’s hold off this discussion until the end.

    CodeContests

    The CodeContests dataset was compiled for fine-tuning and evaluation of DeepMind’s AlphaCode model (Li et al., 2022). It contains competitive programming problems taken from the CodeForces platform, with a small set of open tests. Each problem has an average of 203.7 hidden tests, and the test subset of CodeContests contains 165 problems. For evaluation, it uses the standard pass@k metric, which in this case is simply the percentage of solved problems in a protocol where the model generates k solutions for each problem, they are tested on hidden tests, and if any one of the k solutions passes the problem is considered solved. But it also includes an interesting variation, 10@k, which is the percentage of solved problems in a protocol where the model generates k solutions for each problem but can only run hidden tests on 10 of them. This reflects the difference between relatively “cheap” additional generations from the model and the relatively “expensive” process of re-evaluating and testing them. If the model can choose the best among its generations, it would be easy to ask it to generate more. In the same vein, CodeContests allocates a fixed amount of time for running tests on each problem, requiring some efficiency.

    TransCoder

    The next benchmark, introduced in the work on the TransCoder model (Sun et al., 2023), was created to evaluate the quality of program translation from one programming language to another. For this purpose, the authors collected examples from the GeeksforGeeks website of implementations of the same function in C++, Java, and Python. In total, they collected about 460 functions for each language, with 10 tests for every function. The quality metric used is the percentage of solutions that successfully pass the unit tests.

    Tldr

    The tldr benchmark (Zhou et al., 2022) goes from programming to command-line scripting characteristic of sysadmin and devops work: it contains pairs of human queries in English and the corresponding bash commands with the necessary combination of flags. It includes 9187 pairs that span 1879 unique bash commands, and is intended to evaluate bash command generation based on user instructions. It uses several straightforward evaluation metrics: 

    • command name match, 
    • exact match of the command with the reference, 
    • F1 score computed on tokens, 
    • the charBLEU metric, a variant of the standard BLEU score computed on character-level n-grams instead of tokens (Popovic, 2015; Post, 2018).

    SWE-bench

    The SWE-bench benchmark (Jimenez et al., 2023) was also developed to evaluate LLMs for coding tasks but this dataset concentrates more on real-life problems with repositories; it includes 2294 pairs of problem descriptions (issues) and their corresponding solutions (pull requests) from 12 popular Python repositories. Models are provided with the codebase and a problem description, and they are required to generate a patch that resolves the specified problem.

    SWE-bench evaluates whether models can understand and coordinate changes across multiple functions, classes, and even files simultaneously, requiring complex reasoning and interaction with the programming environment. The evaluation is conducted using unit tests: if the proposed fix passes both new and existing tests, it is considered successful.

    Some model comparisons also include a more general-purpose benchmark Super-NaturalInstructions (Wang et al., 2022), designed to evaluate the generalization capabilities of models. In this benchmark, models are asked to solve various tasks based on instructions written in natural language. The dataset contains 1616 distinct tasks from 76 unique task types, including classification, text editing, summarization, and others, as well as programming-related tasks such as code description and program generation from textual descriptions.

    Fine-Tuning LLMs for Code with Reinforcement Learning

    We have already discussed applications of reinforcement learning (RL) for LLM fine-tuning in the context of RLHF. For LLMs for code, the main concepts remain the same: code generation is treated as a sequential decision making problem formalized via a Markov decision process (MDP) defined as a quintuple

        \[(\mathcal{S}, \mathcal{A}, \mathcal{P}, r, \gamma),\]

    where \mathcal{S} is the state space, with a state s_t=(\mathbf{y}_{<t},\mathbf{x}) consisting of a prefix of tokens \mathbf{y}_{<t} and problem description \mathbf{x}; \mathcal{A} is the set of actions corresponding to choosing the next token y_t; P(s'|s,a) is the transition function that defines the probability of passing to a state s' after performing action a in state s; r(s,a) is the reward function for action a in state s; and \gamma is a discounting coefficient for future rewards, a number from 0 to 1.

    The goal of reinforcement learning is to maximize the total expected reward

        \[J(\pi_{\boldsymbol{\theta}})=\mathbb{E}_{\pi_{\boldsymbol{\theta}}}\left[\sum_{t=0}^{\infty}\gamma^t r(s_t, a_t)\right],\]

    where \pi_{\boldsymbol{\theta}}(a|s) is a policy (parameterized by \boldsymbol{\theta}) that defines the probabilities of taking certain actions in each state.

    Other key notions in RL are the state value functions V and state-action value functions Q, defined as

        \[V^{\pi}(s) = \mathbb{E}_{a\sim \pi}\left[Q^{\pi}(s,a)\right],\qquad Q^{\pi}(s,a) = \mathbb{E}\left[r(s,a)+ \gamma V^{\pi}(s')\right],\]

    where s' is the next state.

    Reinforcement learning is usually done in one of two ways:

    • value-based RL, where the main goal is to learn functions V and Q introduced above, in particular V^\ast and Q^\ast corresponding to the optimal policy \pi^\ast, and then the policy \pi^\ast itself is derived from the value functions; machine learning models such as neural networks are used here to model the functions V and Q with a state or state-action pair at the input and expected reward at the output; this class of methods includes Monte-Carlo approaches to RL and, most importantly, TD-learning (temporal difference learning), in particular Q-learning (Watkins, 1992);
    • policy gradient methods where the policy π is itself represented as a machine learning model, with special techniques used to be able to take gradients with respect to the policy’s parameters (Sutton et al., 1999; Baxter, Bartlett, 2001); this class of methods began with the classical REINFORCE algorithm (Williams, 1992) and now includes such modern approaches as trust region policy optimization (TRPO; Schulman et al., 2015), proximal policy optimization (PPO) that is most often used for RLHF-type fine-tuning (Schulman et al., 2017), and others.

    Each of the methods considered below uses one of these approaches; I cannot go into a detailed exposition of reinforcement learning here and recommend classical books and surveys (Sutton, Barto, 2018; Zheng et al., 2023; Schulman et al., 2017).

    RLTF

    Liu et al. (2023) propose a method called reinforcement learning from unit test feedback (RLTF) to improve the quality of code generation with pretrained large language models (LLMs). The core idea of RLTF is to integrate a real-time data generation mechanism with multi-level feedback from unit tests that would enable the model to learn more effectively through more diverse and relevant examples. Here is an illustration:

    Here, the source code generation task is formalized as sequential generation of the code W that satisfies a high-level problem description D. The goal is to maximize the conditional probability

        \[\max p(W|D, \boldsymbol{\theta}) = \max \prod_{t=1}^T p(\mathbf{w}_t|D, \boldsymbol{\theta}, \mathbf{w}_{1:t-1}),\]

    where \boldsymbol{\theta} are model parameters, T is the sequence length, and \mathbf{w}_t is the t-th token of generated code.

    RLTF uses an online buffer which is updated dynamically during training. The buffer contains data pairs consisting of a problem description D, generated code W', and feedback from the compiler \mathrm{FB}(W'). Optimization uses the following loss function that combines standard reinforcement learning and feedback of varying granularity:

        \[L_{\mathrm{total}} = L_{\mathrm{SL}} + L_{\mathrm{coarse}} + L_{\mathrm{fine}} + L_{\mathrm{adaptive}},\]

    where L_{\mathrm{SL}} is the standard supervised learning loss, and L_{\mathrm{coarse}}, L_{\mathrm{fine}}, and L_{\mathrm{adaptive}} are reinforcement learning components.

    Liu et al. (2023) distinguish three types of feedback mechanisms:

    • coarse-grained feedback, where rewards depend on the execution status of the code (successful execution, failure, or syntax error):

          \[R_{\mathrm{coarse}}(\hat{W}) =\begin{cases}1.0, & \text{if } \mathrm{FB}(\hat{W}) = \text{pass}, \\ -0.3, & \text{if } \mathrm{FB}(\hat{W}) = \text{failure}, \\ -0.6, & \text{if } \mathrm{FB}(\hat{W}) = \text{non-syntax error}, \\ -1.0, & \text{if } \mathrm{FB}(\hat{W}) = \text{syntax error}; \end{cases}\]

    • fine-grained feedback, where errors are classified into specific lines of code or specific reasons for the error, with rewards defined as

          \[R_{\text{fine}}(\hat{W}) =\begin{cases} 0.0, & \text{for an error from category }U_{\mathrm{ignore}}, \\ -0.3, & \text{otherwise};\end{cases}\]

    • adaptive feedback, where the reward depends on the share of successfully passed tests:

          \[R_{\mathrm{adaptive}}(\hat{W}) = -0.3 + 1.3 \cdot \frac{N_{\mathrm{pass}}}{N_{\mathrm{pass}} + N_{\mathrm{fail}}},\]


      where N_{\mathrm{pass}} and N_{\mathrm{fail}} are the number of passed and failed tests respectively.

    The overall loss function combines all levels of feedback:

        \[L_{\mathrm{total}} = L_{\mathrm{SL}} - \sum_{t=1}^{T} R(\hat{W}) \log p(\mathbf{w}_t | D, \boldsymbol{\theta}, \mathbf{w}_{1:t-1}),\]

    where R(\hat{W}) integrates R_{\mathrm{coarse}}, R_{\mathrm{fine}}, and R_{\mathrm{adaptive}}.
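
    Spelled out in Python, the three reward levels are just small functions of the execution feedback; the status strings and the U_ignore set below are placeholders for whatever the feedback parser actually produces:

        def coarse_reward(status: str) -> float:
            # whole-program reward based on the execution status FB(W)
            return {"pass": 1.0, "failure": -0.3,
                    "non-syntax error": -0.6, "syntax error": -1.0}[status]

        def fine_reward(error_category: str, u_ignore: set) -> float:
            # line-level penalty, skipped for error categories in U_ignore
            return 0.0 if error_category in u_ignore else -0.3

        def adaptive_reward(n_pass: int, n_fail: int) -> float:
            # interpolates between -0.3 (no tests pass) and 1.0 (all tests pass)
            return -0.3 + 1.3 * n_pass / (n_pass + n_fail)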

    Experiments on the APPS and MBPP benchmarks showed that RLTF improves code generation results, outperforming other approaches such as CodeRL and PPOCoder both in the quality of generation and versatility across different models, including CodeT5 and CodeGen. Ablation studies also confirmed the importance of online learning and the proposed feedback mechanisms.

    RLEF

    Gehring et al. (2024) introduce reinforcement learning with execution feedback (RLEF), which uses feedback obtained during code execution instead of feedback from the compiler or unit tests. The main idea is to frame the problem as a partially observable Markov decision process (POMDP), where actions correspond to sequential code generation and observations include feedback on the test results. This enables the LLM to not only generate solutions but also iteratively correct errors based on provided feedback, as illustrated below with the general flowchart on the left and a sample “inner dialogue” on the right:

    RLEF optimizes a policy π using the proximal policy optimization (PPO) algorithm. The reward is defined as follows:

        \[r(s_t, a_t) =\begin{cases}1, & \text{if all tests pass}, \\-1, & \text{if at least one test fails}, \\-0.2, & \text{if the generated code is invalid}.\end{cases}\]

    The overall reward function also includes regularization in the form of KL divergence between the current policy \pi and the initial policy \rho:

        \[R(s_t, a_t) = r(s_t, a_t) - \beta \log \frac{\pi(a_t | c_t)}{\rho(a_t | c_t)},\]

    where c_t is the sequence of previous observations and actions, and \beta is the regularization coefficient.

    Optimization is based on the advantage function 

        \[A_t = R(s_t, a_t) - V(c_t),\]

    where V(c_t) is the value function. To minimize the loss, RLEF uses the clipped objective function from PPO:

        \[L_{\mathrm{clipped}}(\theta) = \mathbb{E}_t \left[\min \left(r_t(\theta) A_t, \mathrm{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) A_t \right) \right],\]

    where

        \[r_t(\theta) = \frac{\pi_\theta(a_t | c_t)}{\pi_{\theta_{\mathrm{old}}}(a_t | c_t)}.\]
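
    A minimal Python sketch of the two ingredients, the KL-regularized reward and the clipped PPO objective; the values of \beta and \epsilon and the use of plain per-action log-probabilities here are illustrative assumptions rather than the exact setup of Gehring et al.:

        import numpy as np

        def rlef_reward(tests_pass: bool, code_valid: bool,
                        logp_pi: float, logp_rho: float, beta: float = 0.05) -> float:
            # terminal reward from the case analysis above plus the log(pi/rho) penalty
            if not code_valid:
                r = -0.2
            elif tests_pass:
                r = 1.0
            else:
                r = -1.0
            return r - beta * (logp_pi - logp_rho)

        def ppo_clipped_objective(logp_new, logp_old, advantages, eps: float = 0.2) -> float:
            # clipped PPO surrogate (to be maximized) averaged over a batch of actions
            ratio = np.exp(np.asarray(logp_new) - np.asarray(logp_old))
            adv = np.asarray(advantages)
            return float(np.mean(np.minimum(ratio * adv,
                                            np.clip(ratio, 1.0 - eps, 1.0 + eps) * adv)))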

    Experiments on challenging benchmarks such as CodeContestsHumanEval+, and MBPP+ showed that RLEF improves results with significantly lower computational costs. The method provided improvements for both small models (8B parameters) and large models (70B parameters), showing their ability to effectively fix errors and adapt to feedback in iterative code synthesis. The experimental study by Gehring et al. (2024) indicates that RLEF not only enhances first generation accuracy but also significantly improves the quality of subsequent fixes. This opens up opportunities for automatic iterative correction and enhancement of generated code.

    B-Coder

    Introduced by Yu et al. (2024), this is another reinforcement learning based architecture for program synthesis, but unlike most other methods we discuss today, B-Coder uses value-based RL approaches instead of policy-based RL algorithms such as the frequently mentioned PPO. B-Coder focuses on optimizing the functional correctness of programs with minimal reward function design costs.

    The task of program synthesis is again formalized as the sequential generation of a program W = (w_0, w_1, \ldots, w_T) based on a textual problem description D. Generating each token w_t is interpreted as taking an action a_t in a state s_t = (\mathbf{w}_{<t}, D). Thus, the training process is defined as a Markov decision process (MDP) characterized by the tuple (S, A, P, r, \gamma), where

        \[r(s_T, a_T) = \begin{cases}1, & \text{if $W$ passes all unit tests}, \\-0.3, & \text{if $W$ fails at least one test}, \\-0.6, & \text{if $W$ raises an execution error}, \\-1.0, & \text{if $W$ contains a syntax error}.\end{cases}\]

    The RL objective is to maximize the discounted cumulative rewards

        \[J(\pi) = \mathbb{E} \left[ \sum_{t=0}^\infty \gamma^t r(s_t, a_t) \mid \pi \right],\]

    where \pi is the policy for selecting actions.

    Unlike other methods, B-Coder uses Q-learning to estimate the state-action value function:

        \[Q^\pi(s, a) = \mathbb{E} \left[ r(s, a) + \gamma \max_{a'} Q^\pi(s', a') \right],\]

    where s' is the next state obtained after taking action a in state s.

    One of the main problems of RL in general and Q-learning in particular is training instability: policies and value functions tend to overfit, quickly get into local maxima, and so on. For example, the whole point of TRPO and later PPO compared to standard policy gradient is to restrict how far a single training step can go, limiting updates to a trust region with explicit constraints in TRPO or by clipping the reward in PPO. To stabilize training, B-Coder uses the following tricks:

    • initializing Q with pretrained models: the Q function is initialized from the logits l(s,a) of a language model transformed via softmax:

          \[Q(s, a) = \alpha \left( l(s, a) - \max_{a'} l(s, a') \right) + V(s);\]

    • conservative Bellman operator: instead of the standard optimal Bellman operator

          \[(B^* Q)(s, a) = r(s, a) + \gamma \max_{a'} Q(s', a')\]

      B-Coder uses the conservative operator

          \[(B_q Q)(s, a) = r(s, a) + \gamma Q(s', a'_q),\]

      where a'_q = \arg\,\max_{a'} q(a' \mid s') is the action chosen by a fixed strategy q.

    Training in B-Coder proceeds in two stages:

    • pretraining V(s), where the state value function is trained by minimizing the temporal difference (TD) error:

          \[L_V = \mathbb{E} \left[ \left( r(s, a) + \gamma V(s') - Q(s, a) \right)^2 \right];\]

    • fine-tuning Q(s, a), where the method optimizes a combined loss function

          \[L_{\mathrm{ft}} = L_Q + \beta_{\mathrm{adv}} L_{\mathrm{adv}} + \beta_{\mathrm{ce}} L_{\mathrm{ce}},\]

      where L_Q is the TD error with the conservative Bellman operator, L_{\mathrm{adv}} is the advantage function A(s, a) regularization, and L_{\mathrm{ce}} is the error function based on cross-entropy.
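
    Putting the two stabilization tricks into code, here is a small Python sketch of the Q initialization and the conservative Bellman target; choosing a'_q as the most probable action of the fixed strategy q is my reading of the formula above, and everything else is a toy illustration:

        import numpy as np

        def init_q_from_logits(logits: np.ndarray, value: float, alpha: float = 1.0) -> np.ndarray:
            # Q(s, .) = alpha * (l(s, .) - max_a' l(s, a')) + V(s)
            return alpha * (logits - logits.max()) + value

        def conservative_bellman_target(r: float, gamma: float,
                                        q_next: np.ndarray, q_policy_next: np.ndarray) -> float:
            # (B_q Q)(s, a) = r(s, a) + gamma * Q(s', a'_q), where a'_q is chosen by the
            # fixed strategy q rather than by maximizing Q itself
            a_q = int(np.argmax(q_policy_next))
            return r + gamma * q_next[a_q]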

    Yu et al. (2024) conducted experiments on the APPS and MBPP benchmarks; they showed that B-Coder outperforms prior approaches such as CodeRL, PPOCoder, and RLTF in terms of code generation quality. The conservative Bellman operator improves training stability, and the use of off-policy data ensures high sampling efficiency.

    PSGPO

    A rather recent work by Dai et al. (2024) introduces a variation of policy optimization methods called process supervision-guided policy optimization (PSGPO). Its main motivation is to address the issue of sparse reward signals in RL-based code generation tasks: usually the model only gets a reward based on the entire code project passing or failing tests or raising errors. The key innovation of PSGPO is the process reward model (PRM) that provides detailed feedback at line level, enabling the model to improve its solutions incrementally during code generation.

    The code generation problem is again formulated as the sequential generation of tokens \mathbf{y} = (y_1, y_2, \dots, y_T) based on the input task description \mathbf{x}. A pretrained language model p_{\boldsymbol{\theta}}(\mathbf{y} | \mathbf{x}) estimates the conditional probability distribution

        \[p_{\boldsymbol{\theta}}(\mathbf{y} | \mathbf{x}) = \prod_{t=1}^Tp_{\boldsymbol{\theta}}(y_t | \mathbf{y}_{<t}, \mathbf{x}),\]

    where \mathbf{y}_{<t} = (y_1,\ldots,y_{t-1}).

    The reinforcement learning objective is to maximize the expected reward

    \[\mathbb{E}_{\mathbf{y} \sim p_{\boldsymbol{\theta}}} \left[ R(\mathbf{x}, \mathbf{y}) \right],\]

    where R(\mathbf{x}, \mathbf{y}) is the reward function based on passing tests.

    The PRM provides detailed feedback by evaluating the correctness of every line in the generated code. PRM is trained using labels automatically derived through binary search. The label for a prefix \mathbf{y}_{<m} is defined as

        \[l_m =\begin{cases}+1, & \text{if there exists a completion that passes all tests}, \\-1, & \text{if all completions fail}.\end{cases}\]
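The binary-search labeling can be sketched as follows; sample_completions and passes_tests are hypothetical stand-ins, and the sketch relies on the fact that completability is monotone in the prefix (if a prefix cannot be completed into a passing program, neither can any longer prefix that extends it):

# A sketch of prefix labeling by binary search; helpers are hypothetical.
def label_prefixes(lines, sample_completions, passes_tests, k=8):
    """Return labels l_m for prefixes of length m = 1..len(lines)."""
    def prefix_ok(m):
        prefix = "\n".join(lines[:m])
        return any(passes_tests(prefix + "\n" + c) for c in sample_completions(prefix, k))

    lo, hi = 0, len(lines)           # the empty prefix is assumed completable
    while lo < hi:                   # binary search for the longest completable prefix
        mid = (lo + hi + 1) // 2
        if prefix_ok(mid):
            lo = mid
        else:
            hi = mid - 1
    return [+1 if m <= lo else -1 for m in range(1, len(lines) + 1)]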

    PRM is trained to minimize the squared error

        \[\min_\phi \sum_{(\mathbf{x}, \mathbf{y}_{\leq m}, l_m)} \left( R_\phi(\mathbf{x}, \mathbf{y}_{\leq m}) - l_m \right)^2,\]

    where R_{\phi} is the PRM’s prediction.

    During reinforcement learning, PRM is used in two ways:

    • dense rewards, where PRM adds step-by-step signals at the line level,

          \[R_{\mathrm{PRM}}(\mathbf{x}, \mathbf{y}) = \sum_{m=1}^T R_\phi(\mathbf{x}, \mathbf{y}_{\leq m}),\]

      and the overall reward function becomes

          \[R(\mathbf{x}, \mathbf{y}) = R_{\mathrm{UT}}(\mathbf{x}, \mathbf{y}) + \lambda R_{\mathrm{PRM}}(\mathbf{x}, \mathbf{y}),\]

      where R_{\mathrm{UT}} is the binary reward for passing tests and \lambda is the PRM weight;
    • value initialization, where PRM is used to set the initial values for the state value function,

          \[V(\mathbf{x}, \mathbf{y}_{\leq m}) \approx R_\phi(\mathbf{x}, \mathbf{y}_{\leq m}).\]

The process of automatic labeling for the PRM and its integration with reinforcement learning is illustrated below (Dai et al., 2024):

    Experimental results by Dai et al. (2024) on LiveCodeBench and InHouseBench (a proprietary benchmark by ByteDance) showed that PSGPO with PRM improves the pass@1 metric for long solutions (over 100 tokens) by 9% compared to the baseline RL approach. The combination of dense rewards and value initialization improves learning through additional feedback signals, better stability, and also provides significant advantages in tasks with long planning horizons.

    ILF

    Chen et al. (2023) apply an approach called imitation learning from language feedback (ILF; Scheurer et al., 2023) to code generation. ILF uses relatively few human annotations to train a model for both generating and refining programs, aiming for a data-efficient and user-friendly approach. Below we see the general framework on the left and a sample query with language feedback on the right:

    The task of program synthesis is formalized as training a probabilistic model \pi_{\boldsymbol{\theta}}, parameterized by \boldsymbol{\theta}, to generate a program \mathbf{x} from a task description \mathbf{t}. The generation probability is defined as

        \[\pi_{\boldsymbol{\theta}}(\mathbf{x}|\mathbf{t}) = \prod_{i=1}^{|\mathbf{x}|}\pi_{\boldsymbol{\theta}}(x_i | \mathbf{x}_{<i}, \mathbf{t}),\]

    where \mathbf{x}_{<i} is the prefix of length i-1.

    The ILF objective is to minimize the Kullback-Leibler divergence between the current model \pi_{\boldsymbol{\theta}} and the target distribution \pi_{\mathbf{t}}^\ast:

        \[\min_{\boldsymbol{\theta}}\mathbb{E}_{t\sim p(t)}\left[\mathrm{KL}\left(\pi_{\mathbf{t}}^\ast \| \pi_{\boldsymbol{\theta}}(\cdot | t)\right)\right],\]

    where \pi_{\mathbf{t}}^\ast(\mathbf{x}) is proportional to \exp(\beta R(\mathbf{x}, \mathbf{t})), and R(\mathbf{x}, \mathbf{t}) is the reward function based on passing unit tests.

    The ILF algorithm proceeds in several steps:

    1. Sampling incorrect programs \mathbf{x}_0\sim p_{\boldsymbol{\theta}}(\cdot | \mathbf{t} ) that fail the tests: \mathrm{eval}(\mathbf{x}_0, \mathbf{t})=0.
    2. Annotating \mathbf{x}_0 with feedback f provided by a human assessor.
    3. Training a model \pi_{\mathrm{Refine}} to generate corrections \mathbf{x}_1 that pass tests based on f and \mathbf{x}_0.
4. Training \pi_{\boldsymbol{\theta}} on the fixes \mathbf{x}_1 generated by \pi_{\mathrm{Refine}}; a short data-flow sketch of these steps is given below.
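Here is the promised data-flow sketch; all callables are hypothetical stand-ins for the paper's models and the human annotation step (the actual method fine-tunes \pi_{\mathrm{Refine}} and \pi_{\boldsymbol{\theta}} rather than calling black boxes):

# A data-flow sketch of one ILF round; all callables are hypothetical stand-ins.
def ilf_round(tasks, generate, run_tests, get_human_feedback, refine_model, finetune):
    refinements = []
    for t in tasks:
        x0 = generate(t)                       # step 1: sample a program
        if run_tests(x0, t):                   # keep only programs that fail the tests
            continue
        f = get_human_feedback(t, x0)          # step 2: human writes language feedback
        x1 = refine_model(t, x0, f)            # step 3: pi_Refine proposes a fix
        if run_tests(x1, t):                   # only fixes that pass the tests are kept
            refinements.append((t, x1))
    finetune(refinements)                      # step 4: train pi_theta on the fixes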

    The proposed distribution q_t(\mathbf{x}_1) that approximates \pi_{\mathbf{t}}^\ast is formalized as

        \[q_t(\mathbf{x}_1) = \sum_{\mathbf{x}_0, f} \pi_\theta(\mathbf{x}_0 \mid \mathbf{t}) \cdot p_{F}(f | \mathbf{t}, \mathbf{x}_0) \cdot \pi_{\text{Refine}}(\mathbf{x}_1  | \mathbf{t}, \mathbf{x}_0, f) \cdot \delta\left(\mathrm{eval}(\mathbf{x}_1, \mathbf{t}) = 1\right),\]

    where p_{F}(f | \mathbf{t}, \mathbf{x}_0) is the feedback distribution, and \delta is the delta function that defines the constraint of passing the tests.

    Chen et al. (2023) tested ILF on the MBPP benchmark, where it achieved a significant improvement in the pass@1 metric: +10% in absolute terms and +38% in relative terms compared to baseline approaches. Analysis revealed that the quality of human feedback is critical, while automatically generated feedback is much less effective. The authors conclude that ILF is an effective approach for training program generation models, achieving significant improvements with limited manual annotations.

    To sum up, there can be many different RL-based approaches and modifications that help LLMs learn to code, but in any case the point is usually to define special rewards based on how well the model passes tests. This is the key characteristic feature of the coding domain—tests that provide automatic, easy to get feedback—and this feature can also help us with supervised instruction tuning, which we will discuss in the next section.

    Instruction tuning

    Introduction and reminders

    We have discussed instruction tuning in the same post on LLM fine-tuning; it is one of the main techniques for fine-tuning LLMs where the model is trained on a dataset where tasks are represented as natural language instructions with correct answers also provided as text. The purpose of instruction tuning is to improve the model’s generalization abilities for a wide range of problems that can be posed as text rewriting; naturally, code generation, bugfixing, and other code-related tasks also fall into this category. Instruction tuning is a special case of supervised fine-tuning (SFT), and it is, of course, already a huge field with thousands of works (Zhang et al., 2024).

    One of the key works in this domain is the OpenAI study on InstructGPT (Ouyang et al., 2022), which was also foundational for RLHF, and we already discussed it in detail before. The fine-tuning procedure in InstructGPT combined instruction tuning and RLHF in three stages: 

    • supervised fine-tuning (SFT), where the model was fine-tuned on demonstrations provided by annotators who created gold standard responses to various prompts,
    • reward modeling, where annotators ranked model responses by quality and these rankings were then used to train a reward model that predicts human preferences, and
    • reinforcement learning with PPO. 

    Reward modeling is needed to scale the training data because human labeling is, naturally, far more expensive than reward modeling and RL.

    Before proceeding to code-related works, we note the Self-Instruct approach (Wang et al., 2023) that significantly improved instruction tuning in general. We also discussed it in a previous post, so let me just remind that it is a semi-automatic iterative process as illustrated below:

A small set of manually written instructions is used to initiate the generation of new tasks via few-shot learning: the model (GPT-3 in the original work) generates new instructions and corresponding input-output examples, and the generated data are filtered with various heuristics, in particular the ROUGE-L textual similarity metric against the existing instruction set, to ensure the dataset's diversity; the result is then used for fine-tuning. Wang et al. (2023) obtained about 52K diverse instructions corresponding to about 82K input-output examples. While this dataset is relatively large, it is much smaller than the amounts of data required for training large language models. Nevertheless, fine-tuning the base GPT-3 model on these instructions increased accuracy by 33% on the Super-NaturalInstructions test set, approaching the performance of OpenAI's InstructGPT that used extensive human labeling.

    Below, we will explore how such methods are applied to the domain of programmatic code. But first, let me note that the Self-Instruct approach has been directly applied to the programming domain, resulting in an important coding model called CodeAlpaca.

    Extending Self-Instruct for code

One important work that further develops Self-Instruct is WizardCoder (Luo et al., 2023), which demonstrates how enhancing training data with more complex instructions enables an open model to outperform commercial solutions, in this case Claude and Google's LLM, which was still called Bard at the time. Naturally, such comparisons are only valid as of the publication date, as closed commercial solutions are continually updated, and their performance improves even without new version numbers or names.

    WizardCoder uses an approach called Evol-Instruct (Xu et al., 2023), adapting it for programming tasks. This approach incrementally generates more complex and detailed instructions for working with code using models from the ChatGPT family with various prompts. Code Evol-Instruct includes the following steps.

    1. Begin with the CodeAlpaca dataset, which contains 20K examples of programming tasks.
    2. Select one task from the dataset.
    3. Modify the task’s instruction by applying one of the following techniques that increase its complexity:
      • add new requirements (e.g., complexity constraints);
      • replace standard conditions with rare ones;
      • increase the number of logical steps required to solve the task;
      • introduce erroneous code as a distractor;
      • impose stricter requirements on computational complexity (e.g., time or memory limits).
    4. Add the new instruction to the dataset.
    5. Train the model on the enhanced dataset.
    6. Repeat the process multiple times, until results on a test set plateau.

    In WizardCoder experiments (Luo et al., 2023), performance gains stabilized after three rounds of data evolution.

    Notably, the instructions are also modified by language models using simple, straightforward prompts. For example, WizardCoder used the following prompt to increase task complexity:

    Please increase the difficulty of the given programming test question a bit.
    You can increase the difficulty using, but not limited to, the following methods:
    {method}
    
    {question}

    Here, question represents the current instruction to be made more complex, and method describes the method of increasing complexity. Below are the prompts for modifying instructions.

    Add new constraints and requirements to the original problem, adding approximately 10 additional words.
    Replace a commonly used requirement in the programming task with a less common and more specific one.
    If the original problem can be solved with only a few logical steps, please add more reasoning steps.
    Provide a piece of erroneous code as a reference to increase misdirection.
    Propose higher time or space complexity requirements, but please refrain from doing so frequently.
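Putting the two pieces together, one Evol-Instruct step can be sketched as follows; the prompt template and mutation methods are taken verbatim from above, while call_llm is a hypothetical client for a ChatGPT-family model:

import random

EVOL_METHODS = [
    "Add new constraints and requirements to the original problem, adding approximately 10 additional words.",
    "Replace a commonly used requirement in the programming task with a less common and more specific one.",
    "If the original problem can be solved with only a few logical steps, please add more reasoning steps.",
    "Provide a piece of erroneous code as a reference to increase misdirection.",
    "Propose higher time or space complexity requirements, but please refrain from doing so frequently.",
]

PROMPT_TEMPLATE = (
    "Please increase the difficulty of the given programming test question a bit.\n"
    "You can increase the difficulty using, but not limited to, the following methods:\n"
    "{method}\n\n{question}"
)

def evolve_instruction(question: str, call_llm) -> str:
    """One Evol-Instruct step: pick a mutation method and ask the LLM to rewrite the task."""
    method = random.choice(EVOL_METHODS)
    prompt = PROMPT_TEMPLATE.format(method=method, question=question)
    return call_llm(prompt)  # call_llm is a hypothetical LLM client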

    In general, WizardCoder was a great illustration for the “bitter lesson” of LLMs: although we will get to prompting techniques later, it is often enough to just ask the LLM politely to do what you need. This is increasingly true with modern LLMs, and many approaches to prompting become obsolete quite fast, not because they stop working but because they cease to be necessary as straightforward prompts begin to work just as well.

OctoPack. The OctoPack approach (Muennighoff et al., 2024) demonstrated that synthetic data for instruction tuning could be successfully generated even without using closed commercial models. Specifically, the authors created the CommitPackFt dataset using publicly available commits on GitHub. Changes in code (before and after the commit) and commit messages served as human-written instructions. Here is an example of extracting such instructions:

    To ensure data quality, only commits with open licenses and clear imperative-style commit messages were included, and the authors also used only commits that touched only one source file. After filtering, the final dataset covered 277 programming languages.
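A rough sketch of such filtering could look as follows; the concrete heuristics (the verb list and length limits) are illustrative assumptions rather than the authors' exact rules:

# Commit filtering in the spirit of CommitPackFt; heuristics are illustrative.
IMPERATIVE_VERBS = ("add", "fix", "remove", "update", "refactor", "rename", "implement")

def keep_commit(message: str, changed_files: list[str], license_ok: bool) -> bool:
    words = message.strip().lower().split()
    return (
        license_ok                          # only permissively licensed repositories
        and len(changed_files) == 1         # the commit touches a single source file
        and 3 <= len(words) <= 30           # short, instruction-like commit message
        and words[0] in IMPERATIVE_VERBS    # imperative-style message
    )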

    Fine-tuning the StarCoder-16B model on this dataset achieved a HumanEval pass@1 score of 46.2% for the Python language, the highest at the time among models not trained on synthetic OpenAI data. For fine-tuning, the StarCoder model used low-rank adapters (LoRA) that we talked about in a previous post.

    Instruction Feedback from Humans or Compilers

Despite all of the above-mentioned advancements achieved by following single-step user instructions, code generation models could reach new heights if they were able to process multi-step external feedback. Such feedback can take two forms:

    • compiler feedback, i.e., execution results and error messages;
    • human feedback, i.e., user-provided recommendations or instructions.

    Compiler feedback is obviously crucial in order to allow models to correct syntactic and logical errors in code, while human feedback helps models better understand user instructions, leading to solutions that align more closely with user expectations. Here is an example of such interaction:

    The authors of OpenCodeInterpreter (Zheng et al., 2024) proposed a solution to this problem by preparing a dataset with compiler and human feedback. The most interesting part here is the data collection process that involves the following steps.

    1. Collecting single-step instructions using:
  • queries from open datasets (e.g., Magicoder, ShareGPT, the Evol-Instruct-Code data discussed above, or LeetCode problems);
      • filtering to select complex queries using Qwen-72B-Chat, keeping only queries rated 4–5 for complexity.
    2. Transforming instructions into dialogues through:
      • single-query packing, i.e., joining similar queries into dialogues based on BERT-base embeddings,
      • simulating interactions with a compiler or a human by using an iterative approach:
        • first an initial response is generated with GPT-3.5,
        • the generated code is run,
        • error diagnostics is fed into a stronger GPT-4 model for corrections, and
        • the process is repeated until the program runs successfully; a similar procedure was applied for resolving user-identified issues in the code.

    Here Magicoder and ShareGPT are synthetic datasets of instructions obtained using methods similar to those discussed above. For transforming into dialogues, the authors chose ten categories of queries, including formatting, bug fixing, addressing vulnerabilities, compatibility issues, etc. They also created synthetic instructions for code correction by deliberately introducing errors into responses using GPT-4, executing the programs with errors, and feeding the results back to the model for subsequent corrections. They also used tasks and posts from LeetCode forums that include problems of varying complexity, solutions, and related discussions.
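The iterative compiler-feedback loop from step 2 can be sketched as follows; weak_llm, strong_llm, and run_program are hypothetical stand-ins for GPT-3.5, GPT-4, and a sandboxed executor:

def compiler_feedback_dialogue(task: str, weak_llm, strong_llm, run_program, max_rounds: int = 5):
    """Build a multi-turn trace: initial answer, execution errors, corrections."""
    turns = [("user", task)]
    code = weak_llm(task)                      # initial response
    turns.append(("assistant", code))
    for _ in range(max_rounds):
        ok, diagnostics = run_program(code)    # execute and capture errors
        if ok:
            break
        turns.append(("user", f"Execution failed:\n{diagnostics}\nPlease fix the code."))
        code = strong_llm(turns)               # a stronger model proposes a correction
        turns.append(("assistant", code))
    return turns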

    Here are some experimental results about OpenCodeInterpreter shared by Zheng et al. (2024):

    This work is significant in part because it bridges the gap between open-source models and leading commercial models, even models such as GPT-4 with the Code Interpreter feature.

    In general, instruction tuning is mostly as good as the data you collect. These and other works show how even with mostly synthetic data produced by other LLMs one can uplift smaller and faster models to the next level, and perhaps even exceed state of the art. The next class of methods is even easier to implement in practice—it doesn’t require fine-tuning or touching the model at all.

    Other Methods: Prompting, RAG, and More

    I have collected the rest of the main approaches in this section, exemplified by one or two specific approaches. Let us begin with prompting.

    Automatic prompt generation

    Prompting is a method where large language models are controlled not with fine-tuning but by simply changing the prompt itself in smart ways. It often turns out that the LLM’s pre-existing knowledge is quite sufficient for many tasks even if the LLM does not show it at first, and instead of retraining a model with additional data, one could simply modify the text input (prompt) to provide instructions, examples, or context in a form that brings this knowledge forward and leads to the desired response. Below, we discuss examples of various methods for gathering context and designing prompts for code-related tasks; naturally, these are only a few samples from a large field.

    The first interesting approach I want to mention is the Repo-Level Prompt Generator (RLPG; Shrivastava et al., 2023), which improved the performance of OpenAI Codex (a closed model) for code completion without any fine-tuning or retraining. The method generates possible context snippets based on the structure of the code and repository and then uses a classifier to filter out irrelevant context for a given code completion position (a “hole”). 

    The authors identified ten categories of context sources, including the imported file, current file, enclosing class, neighboring files, and others. Seven types of context are extracted from each source: class field declarations, method signatures, method signatures with bodies, string constants, and so on. Combining the sources and context types results in 63 predefined types of context snippets that RLPG can include in the model prompt; the context synthesis process is illustrated below (Shrivastava et al., 2023):

For training the classifier, Shrivastava et al. (2023) collected data from 19 repositories with open licenses from the Google Code Archive, presumed not to overlap with the Codex training set (which used only GitHub projects). After removing blank lines and comments, "holes" for code completion were created in the middle of the remaining lines, for a total of a bit under 93K holes, capped at 10K per repository. Prompts were generated based on different context sources and types, and they were evaluated via Codex: if the resulting completion matched the original line, the example was labeled as 1, otherwise 0. Approximately 150K queries were made to Codex to gather training data—a large but definitely not prohibitive number.

    The classifier was trained with the loss function

        \[\mathcal{L} = \frac{1}{N} \sum_{h=1}^{N} \mathcal{L}^{h} = \frac{1}{N} \sum_{h=1}^{N} \frac{1}{M^{h}} \sum_{p=1}^{M^{h}} \mathrm{BCE}(\hat{y}_{p}^{h}, y_{p}^{h}) \cdot T_{p}^{h},\]

where M^{h}=\sum_{p=1}^{M} T_{p}^{h} denotes the total number of context types applicable to the current "hole", N is the total number of holes in the dataset, and BCE is the binary cross-entropy.

    Two variations of the model were tested:

    • a fully connected neural network with two layers that projects the context vector into probabilities for each context type:

          \[\hat{y}_p^h = P(y_p^h = 1 | H^h) = \sigma(W^2 (\mathrm{ReLU}(W^1 (F_{\phi}(H^h)) + b^1)) + b^2),\]

where H^h is the window around a "hole" (two lines above and below), and F_{\phi} is a (frozen) BERT-base model used to extract the context vector of the "hole" (a minimal code sketch of this variant is given after the list);
    • a multi-head attention mechanism (see my previous post on Transformers) for calculating the similarity between the “hole” context and context snippets extracted by one of the 63 methods; these similarity scores are then passed into a fully connected network for probability prediction:

          \[Q^h = F_{\phi}(H^h), \quad K^h_p = F_{\phi}(C^h_p), \quad V^h_p = F_{\phi}(C^h_p);\]

          \[\hat{y}_p^h = P(y_p^h = 1 | H^h, C_p^h) = \sigma \left( W_p G(\mathrm{MHA}(Q^h, K^h, V_p^h)) + b_p \right),\]

      where BERT-base again encodes both the “hole” context H^h and context snippets C_p^h, while MHA denotes multi-head attention whose outputs are passed to the module G that contains two fully connected layers with ReLU activations.
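As promised, here is a minimal sketch of the first variant as a small torch module on top of a frozen BERT encoder from the transformers library; the layer sizes and the use of the [CLS] vector for pooling are assumptions, not the paper's exact configuration:

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

class RLPGHeadMLP(nn.Module):
    """Variant 1: frozen BERT encodes the hole window, a 2-layer MLP scores the 63 context types."""
    def __init__(self, n_context_types: int = 63, hidden: int = 256):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.encoder = AutoModel.from_pretrained("bert-base-uncased")
        for p in self.encoder.parameters():   # F_phi is frozen
            p.requires_grad = False
        self.mlp = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_context_types),
        )

    def forward(self, hole_windows: list[str]) -> torch.Tensor:
        enc = self.tokenizer(hole_windows, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            h = self.encoder(**enc).last_hidden_state[:, 0]   # [CLS] vector as F_phi(H^h)
        return torch.sigmoid(self.mlp(h))                     # P(y_p = 1 | H^h) for each context type p

The predicted probabilities would then be trained with the masked BCE loss above, with T_p^h zeroing out context types that do not apply to a given hole.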

The second version performed better on test data, and ultimately RLPG improved Codex's performance by 17%. RLPG is a characteristic illustration of methods that can improve a model without ever even seeing its internal structure or weights, just by changing the context. But there is more interesting stuff to put into that context.

    Runtime feedback

    Another approach to improving model quality without fine-tuning is using runtime information from program compilation or execution as feedback. As we have discussed in the section on instruction tuning, such information can help large language models solve complex code generation tasks through iterative interaction with tools.

    As an example of this approach, in this section let us consider the Large Language Model Debugger (LDB; Zhong et al., 2024), an advanced method for collecting runtime information that tries to debug programs synthesized by LLMs in the same way as humans debug their own programs. Here is a detailed illustration that we will go through below:

Key steps in the debugging process are as follows; a minimal sketch of the resulting loop is given after the list.

    1. Instrumentation and execution: LDB runs the program with enabled tools (i.e., debugger) on public tests to collect information about its behaviour. On this step, LDB generates an execution trace, where nodes represent basic blocks (sequences of instructions with a single entry and exit point) and for each block the debugger provides intermediate values of the program variables.
    2. Analysis: values of the variables in each basic block are analyzed using the LLM. For each block, the model assesses its correctness relative to the task description, providing a True/False verdict and explanations for incorrect blocks. Via this process, the model narrows down the problems and is now able to focus on isolated code fragments, simplifying error diagnostics and debugging.
    3. Bugfixing: the model fixes errors based on debugging information. The model regenerates the code, incorporating identified fixes, and repeats the whole process until the program passes all public tests or the maximum number of debugging iterations is reached.
    4. Validation: the final version of the program is tested on private tests to evaluate its correctness.
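Here is the sketch of the loop; run_with_trace, llm_judge_block, and llm_rewrite are hypothetical stand-ins for the instrumented executor and the LLM calls:

def ldb_style_debug(task, code, public_tests, run_with_trace, llm_judge_block, llm_rewrite, max_iters=5):
    """Iteratively localize and fix bugs using block-level runtime traces (a sketch, not LDB itself)."""
    for _ in range(max_iters):
        passed, trace = run_with_trace(code, public_tests)   # trace: list of (block_src, variable_values)
        if passed:
            return code
        verdicts = [
            (block, llm_judge_block(task, block, values))    # True/False verdict + explanation per basic block
            for block, values in trace
        ]
        suspicious = [block for block, (ok, why) in verdicts if not ok]
        code = llm_rewrite(task, code, suspicious)           # regenerate the code, focusing on the bad blocks
    return code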

    A significant limitation of this approach is the fact that it has to rely on tests to identify where the code goes wrong. Still, LDB significantly improves the performance of open models, achieving a nearly 10% improvement on benchmarks such as HumanEval, MBPP, and TransCoder (Zhong et al., 2024). Again, I remind that specific numbers are not important for us, what matters is that this approach does improve code generation significantly.

Retrieval-Augmented Generation. We have discussed RAG (retrieval-augmented generation) in detail in a previous post; in code-related tasks, this method is usually employed to gather context for code generation prompts by retrieving relevant information from external sources such as codebases, library documentation, or web search results and appending it to the user query to improve model responses. RAG represents another way to add new knowledge to the model without fine-tuning, and I refer to my post on RAG and the corresponding sources for details (Gao et al., 2023; Zhao et al., 2024; Fan et al., 2024; Li et al., 2022).

    A straightforward example of applying RAG to code generation is DocPrompting (Zhou et al., 2023), which enhances Python and Bash code generation by incorporating relevant documentation into the prompt. The DocPrompting workflow is illustrated below with a Python example: 

For a given natural language query q it first retrieves a set of relevant documents {d_1, d_2, d_3} from a documentation corpus D and then uses the retrieved documents and query q as input to the LLM to generate code c. Even this simple approach allows an LLM to generalize to previously unseen usage scenarios by "reading" the retrieved documentation. In the figure above, blue text shows tokens shared between the query and documentation, and bold text highlights tokens shared between the documentation and generated code.

    For document retrieval, one can use classical methods like BM25 (Robertson, Zaragoza, 2009)—they prove surprisingly resilient in modern information retrieval, and BM25 often shows up as a reasonable baseline in comparisons—or vector-based search with autoencoders such as RoBERTa (Liu et al., 2019). Autoencoders can be trained with the contrastive loss

    \[\mathcal{L}^r = -\log \frac{\exp\left(\mathrm{sim}(\mathbf{h}_q, \mathbf{h}_{d_i^+})\right)}{\exp\left(\mathrm{sim}(\mathbf{h}_q, \mathbf{h}_{d_i^+})\right) + \sum_{d_j^- \in \mathcal{B}\setminus\mathcal{D}_q^*} \exp\left(\mathrm{sim}(\mathbf{h}_q, \mathbf{h}_{d_j^-})\right)},\]

where \mathrm{sim}(\mathbf{h}_x, \mathbf{h}_y) is the cosine similarity between vectors \mathbf{h}_x and \mathbf{h}_y, \mathbf{h}_q is the query's representation, and \mathbf{h}_{d_i^+} and \mathbf{h}_{d_j^-} are vector representations of relevant and irrelevant documents respectively.

As a result, this simple approach significantly improved the quality of bash command generation by Codex on the tldr benchmark (Zhou et al., 2022). Similar RAG techniques have successfully been used on internal library documentation (Zan et al., 2022) and samples of similar code (Parvez et al., 2021), again leading to significant improvements.
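Before moving on, here is a minimal sketch of the BM25 retrieval step mentioned above; it uses the rank_bm25 package, and the prompt format and call_llm client are illustrative assumptions:

from rank_bm25 import BM25Okapi

def docprompting_generate(query: str, docs: list[str], call_llm, k: int = 3) -> str:
    """Retrieve top-k documentation snippets with BM25 and prepend them to the code generation prompt."""
    bm25 = BM25Okapi([d.lower().split() for d in docs])
    top_docs = bm25.get_top_n(query.lower().split(), docs, n=k)
    prompt = "Documentation:\n" + "\n---\n".join(top_docs) + f"\n\nTask: {query}\nCode:"
    return call_llm(prompt)  # call_llm is a hypothetical LLM client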

    Sampling with verification: LEVER

This approach, exemplified here by a method called Learning to Verify Language-to-Code Generation with Execution (LEVER; Ni et al., 2023), builds on the observation that LLMs often generate hallucinations or errors when producing code, but generating a sufficiently large number of candidate responses significantly increases the likelihood of finding a correct one. We have already discussed this above: the whole idea of the pass@k metric is choosing one of k generations, and the results do improve significantly as k grows; this is a staple in code generation evaluation research (Chen et al., 2021; Du et al., 2024; Austin et al., 2021). This leads to the following idea: let us train a separate verifier model capable of distinguishing correct programs from incorrect ones, then sample many responses from the LLM and pass them through the verifier. The underlying assumption is that verifying a response is an easier task than generating one from scratch.

For other applications, such as solving textual math problems, this approach was proposed by, e.g., Cobbe et al. (2021) and Shen et al. (2021) and further developed by Kadavath et al. (2022), Li et al. (2023), Hosseini et al. (2024), and others. In code generation, verification is usually even easier than for mathematical proofs: many errors can be caught during compilation, execution, or running simple test cases.

    Thus, sampling and verification is a natural way to enhance results with minimal computational overhead. The LEVER method (Ni et al., 2023) is a straightforward illustration of this approach, so let me use it to describe some more details. LEVER operates as follows.

1. Generation: an LLM samples multiple candidate programs \mathbf{y}_1,\ldots,\mathbf{y}_n based on the input task description \mathbf{x}:

          \[\mathbf{y}_i \sim p_{\mathrm{LM}}(\mathbf{y} | \mathbf{x} ),\]

      where p_{\mathrm{LM}} is the distribution learned by the LLM; at this stage, it makes sense to sample with high temperature to increase the diversity of answers.
    2. Execution: each program \mathbf{y}_i is executed using an interpreter or compiler, and execution results E(\mathbf{y}_i) are added to the input data.
    3. Verification: a discriminative model is trained to predict correctness, i.e., it learns the distribution

          \[p_{\boldsymbol{\theta}}(v | \mathbf{x}, \mathbf{y}, E(\mathbf{y})),\]

      where v is a 0-1 variable for program correctness, and as input the verifier receives the task description \mathbf{x}, program source code \mathbf{y}, and execution results E(\mathbf{y}).

    Here is an illustration for this sequence of steps (Ni et al., 2023):

    When applying the method, programs are ranked based on a composite score derived from generation and verification probabilities. The verifier calculates the composite probability of correctness:

        \[p_R( \mathbf{y} | \mathbf{x}) = p_{\text{LM}}(\mathbf{y} | \mathbf{x}) \cdot p_{\boldsymbol{\theta}}(v=1 | \mathbf{x}, \mathbf{y}, E(\mathbf{y})).\]

    In order to mitigate inconsequential factors such as variable names or program appearance, programs with identical execution results E(\mathbf{y}) are aggregated:

        \[R(\mathbf{x}, \mathbf{y}) = \sum_{\mathbf{y}' \in S: E(\mathbf{y}') = E(\mathbf{y})} p_R(\mathbf{y}' | \mathbf{x}).\]

    The program with the highest result R(\mathbf{x}, \mathbf{y}) is chosen as the output:

        \[\hat{\mathbf{y}} = \arg\max_{\mathbf{y} \in S} R(\mathbf{x}, \mathbf{y}).\]
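In code, the reranking step could look like the following sketch; p_lm, p_verifier, and execute are hypothetical stand-ins for the generator probability, the trained verifier, and a sandboxed executor (execution results are assumed to be hashable so that programs can be grouped by them):

from collections import defaultdict

def lever_rerank(candidates, p_lm, p_verifier, execute):
    """Return the program from the execution-result group with the highest total score."""
    scores = {}                      # p_R(y | x) for each candidate program
    groups = defaultdict(list)       # programs grouped by their execution result E(y)
    for y in candidates:
        e = execute(y)
        scores[y] = p_lm(y) * p_verifier(y, e)
        groups[e].append(y)
    group_score = {e: sum(scores[y] for y in ys) for e, ys in groups.items()}
    best_group = max(group_score, key=group_score.get)
    return max(groups[best_group], key=scores.get)   # best single program in the best group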

The verifier is trained on automatically annotated data. Each candidate program \mathbf{y} receives a label v determined by comparing its execution result E(\mathbf{y}) with the reference result E^*: v=1 if E(\mathbf{y}) = E^*, and v=0 otherwise.

    The verifier’s loss function is defined as

        \[L_\theta = - \frac{1}{|S|} \sum_{\mathbf{y} \in S} \log p_{\theta}(v|\mathbf{x}, \mathbf{y}, E(\mathbf{y})).\]

    By using execution results for dense feedback, LEVER improves verification accuracy and works effectively even with limited data. The authors show that LEVER improves code generation performance on tasks such as SQL parsing, solving mathematical problems, and Python programming (Ni et al., 2023).

LEVER is not the only method utilizing sampling of multiple programs followed by verification. DeepMind's AlphaCode system, which achieved human-level performance on the CodeForces platform in 2022 (Li et al., 2022), operates in a similar manner, albeit with significantly greater computational scaling. A further development of this idea has been presented by Li et al. (2024), where sampling from LLMs is done through a specially constructed query enumeration strategy. Overall, this direction for improving results is often virtually free, especially if a simple verifier in the form of unit tests already exists, and should work in many practical scenarios. In simple terms, it usually does not make sense to ask an LLM for just one code snippet: ask for three or five instead and choose the best one; it may be better not only in terms of passing tests but also in terms of code quality and coherence (which would not be noticeable in execution-based comparisons).

    LLM-based agents for coding

    Introduction

    These days, agents are often hailed as the next turn in the LLM evolution spiral: it turns out that modern large language models can act as agents in various environments, interacting with them via text commands. Agent-based methods are rapidly evolving in many different applications, and I will not go into a full-scale survey of them here, but let me list a few examples:

    • WebGPT (Nakano et al., 2021), developed by OpenAI, was one of the first examples where LLM agents were trained to interact with web browsers; in a further development of this idea, OpenAI models can search the web at your request, while Claude can even take over your computer entirely, operating it based on sequential screenshots;
    • ReAct (Yao et al., 2023) is an early example of the agentic approach enhancing LLM reasoning abilities;
    • SayCan (Ahn et al., 2022) applied LLM agents to robotics;
• in an example close to our heart here at Synthesis AI, generative agents (Park et al., 2023) simulate human-like agents for synthetic data generation, and modern recommender systems extend this approach to whole virtual worlds populated by virtual LLM-based "people" that can interact with each other and the recommender system and provide synthetic data in their responses (Zhang et al., 2024; Wang et al., 2024a; 2024b);
• beyond just synthetic data, there is a whole agentic branch of modern recommender systems, namely conversational recommender systems where an LLM agent holds a dialogue with the user to figure out her preferences and/or propose suitable recommendations (Huang et al., 2023; Zhang et al., 2024).

    In the programming domain, agent-based methods allow LLMs to plan, debug, and adapt iteratively. In this section, let us consider three characteristic examples of agent-based approaches; while they are recent papers, they are already prominent and have collected a lot of references. 

    Reflexion

Shinn et al. (2023) introduced an approach for training LLM-based agents using self-analysis and verbal feedback rather than parameter updates. Instead of traditional weight updates via gradient descent, Reflexion enhances the agent through textual self-reflections that can be stored in memory.

    The method uses three key modules:

    • actor M_a generates text and actions while interacting with the environment;
    • evaluator M_e assesses the quality of actions and provides reward signals;
    • self-reflection model M_{sr} produces verbal reflections based on trajectories, outcomes, and rewards.

    Here is an illustration of Reflexion training:

    Each training step in the figure above consists of:

    • trajectory generation, where the actor M_a produces a trajectory \tau_t = [s_0, a_0, s_1, a_1, \dots, s_T] by interacting with the environment;
    • trajectory evaluation, where the evaluator M_e computes the scalar reward r_t = M_e(\tau_t);
    • reflection generation, where the model M_{\mathrm{sr}} generates a textual reflection \mathrm{sr}_t, which is then stored in the agent’s long-term memory;
    • strategy update, where reflections are used in subsequent episodes to improve actions.

    Importantly, Reflexion incorporates a two-tiered memory system:

    • short-term memory stores the current trajectory \tau_t;
    • long-term memory stores verbal reflections [\mathrm{sr}_0, \mathrm{sr}_1, \dots, \mathrm{sr}_t]; it is bounded by a fixed number of records (usually small, about 1–3).

    The self-reflection model integrates this memory to guide future actions, enabling the agent to learn from past mistakes.
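Here is a minimal sketch of this episode loop, with the three modules as hypothetical callables and a small bounded list serving as the long-term memory:

def reflexion_loop(task, actor, evaluator, self_reflect, n_episodes=5, memory_size=3):
    """Run several episodes, storing verbal self-reflections in a small long-term memory."""
    long_term_memory: list[str] = []
    best_trajectory, best_reward = None, float("-inf")
    for _ in range(n_episodes):
        trajectory = actor(task, long_term_memory)          # act, conditioned on past reflections
        reward = evaluator(task, trajectory)                # scalar feedback, e.g., fraction of tests passed
        if reward > best_reward:
            best_trajectory, best_reward = trajectory, reward
        reflection = self_reflect(task, trajectory, reward) # verbal lesson for the next episode
        long_term_memory = (long_term_memory + [reflection])[-memory_size:]
    return best_trajectory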

Thus, in programming tasks, Reflexion trains itself through generated tests and self-analysis: first, the actor M_a generates code, which is then evaluated based on the results of both compilation and test execution and by the evaluator model M_e. Then, the self-reflection model M_{\mathrm{sr}} generates a verbal text reflection describing the changes needed to fix the code. The code is then corrected based on this reflection, and the cycle repeats. This iterative approach yields significant improvements: for example, on the HumanEval benchmark Reflexion increased pass@1 accuracy by 11 percentage points, reaching 91% and surpassing the underlying GPT-4 model used without reflection. Improvements are even more substantial on other benchmarks: solution success rates in AlfWorld (Shridhar et al., 2021) increased by 22%, while reasoning tasks from HotpotQA (Yang et al., 2018) improved by 20%.

    The agent-based approach with textual reflection has several other important advantages:

    • interpretability: textual reflections provide explicit, human-readable information about what actions were incorrect and how to fix them;
    • no need for retraining: the method works without updating model parameters, relying solely on textual prompts;
• flexibility: Reflexion is applicable to programming, decision-making, reasoning, and agent-based planning, and adapting to different types of feedback is straightforward since the agent's inputs and improvements are entirely textual.

Thus, Reflexion opens a promising avenue that combines the power of LLMs with verbal, self-analytical learning; again, I do not claim that Reflexion itself is the best possible way to make agents, but it is an interesting idea to be explored further. Note that such methods can be applied to any LLM and will improve as the base model improves.

    AgentCoder

    This framework (Huang et al., 2024) is a good example of how any LLM can be turned into an agent with a little bit of scaffolding. Moreover, in the case of coding the scaffolding suggests itself since, as we have already discussed, development is often driven by tests. So the AgentCoder framework consists of three agents:

    • programmer, which is actually doing the code generation;
    • test designer, which tries to design accurate and comprehensive tests based on the coding requirements specified in the task;
    • test executor, which runs the resulting tests against the code developed by the programmer agent and provides feedback for the programmer.

    The interaction between these three agents is very straightforward, as shown below (Huang et al., 2024):

AgentCoder uses chain-of-thought techniques for the programmer and manually crafted prompts for the test designer intended to cover both edge cases and large-scale inputs, while the test executor is not even an LLM but a Python script that provides direct execution feedback to the other two agents; a minimal sketch of this loop is shown below.
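Here is such a sketch; programmer_llm and test_designer_llm are hypothetical LLM calls, the tests are assumed to be a string of assert statements, and sandboxing of the executed code is omitted for brevity:

def agentcoder_round(task, programmer_llm, test_designer_llm, max_iters=5):
    """Programmer writes code, test designer writes tests, a plain Python executor runs them."""
    tests = test_designer_llm(task)              # e.g., a string of assert statements
    code = programmer_llm(task, feedback=None)
    for _ in range(max_iters):
        try:
            namespace: dict = {}
            exec(code, namespace)                # define the candidate solution (no sandboxing here)
            exec(tests, namespace)               # run the generated asserts against it
            return code                          # all tests passed
        except Exception as err:                 # collect execution feedback for the programmer
            code = programmer_llm(task, feedback=f"{type(err).__name__}: {err}\nTests:\n{tests}")
    return code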

    Despite this apparent simplicity, AgentCoder significantly improves the results across all tested datasets. For example, on HumanEval the basic GPT-4 model with zero-shot prompting scores about 67.6% in the pass@1 metric, the above-mentioned Reflexion approach achieves 91.0% with the GPT-4 base model, and AgentCoder reaches 96.3% with the same GPT-4 as the base LLM (Huang et al., 2024). This is an important testament to what LLM-based agents can achieve even with a very direct and straightforward method, if executed well.

    SWE-agent

    As an example of further development of the agent-based approach, let us consider the recently developed SWE-agent system (Yang et al., 2024), designed to automate software engineering (SWE) tasks using LLMs as agents. The core idea lies in creating a specialized interaction interface between the agent and the computer, creatively called the Agent-Computer Interface (ACI), which makes it easier for the LLM agent to solve tasks related to writing, modifying, and testing code.

    The LLM functions as an agent, interacting with the computer environment through actions such as editing files or running tests and receiving feedback from the environment. Unlike traditional human-oriented interfaces (e.g., terminals or IDEs), ACI employs:

    • compact commands with minimal parameters,
    • informative but concise feedback on changes made by the agent, and
    • mechanisms to protect against common LLM errors, e.g., syntax checks during editing.

    This design overcomes the limitations of standard interfaces, which can be overly complex for LLM-based models; here is an illustration by Yang et al. (2024):

    SWE-agent itself incorporates the following key components (also illustrated above).

1. Codebase navigation: special commands such as find_file and search_dir allow the agent to quickly find relevant files and specific lines.
    2. File editing: the edit command allows the agent to replace ranges of source code lines and automatically shows the updated content for verification.
    3. Context management: the system controls the history of the agent’s actions and observations to minimize redundancy in memorized data.
    4. Feedback: each step is reflected in feedback about agent actions, including syntax errors and other problems.

    On each step, the agent alternates between generating “thoughts” (reasoning about the next step) and executing specific commands. Here is a sample workflow of the agent (Yang et al., 2024):

    • the agent uses the search_file command to locate mentions of the PVSystem function;
    • upon finding the necessary file, it opens it with the open command and looks through the contents;
    • the agent modifies the identified sections of code using the edit command; syntax errors, if any, are automatically detected and reported back to the agent for correction;
• once the changes are complete, the agent runs the updated code using pytest to validate the modifications.

SWE-agent showed significant improvements, achieving state of the art at the time and sometimes coming out ahead in big leaps. On the SWE-bench dataset it solved 12.47% of the problems, while the best non-interactive model at the time achieved only 3.8%. On HumanEvalFix, SWE-agent achieved a pass@1 of 87.7%, which was also a big step forward. And again, the numbers themselves do not matter much and have been surpassed since then; what matters is that SWE-agent led to significant improvements compared to its own base model at little inference cost and with no retraining.

In summary, LLM-based agents represent a significant step forward in automating software engineering tasks. They can iteratively improve through structured interaction with the programming environment and can tackle complex, real-world challenges. In this section, we have considered three examples: Reflexion shows a simple way to add verbal memory to an LLM agent, AgentCoder demonstrates how much a simple multi-agent scaffold with automatic test generation can add on top of a strong base model, and SWE-agent shows that it makes sense to adapt programming interfaces, currently designed for humans, so that LLMs can use them more efficiently.

    Sample industrial code models

    Introduction

    In this section, we examine the fine-tuning process for two popular industrial code models with open weights. As we have seen throughout this post, fine-tuning aims to transform a base model into a practical assistant for programmers. Similar to the foundational work on InstructGPT, this process is typically divided into two primary stages: instruction tuning and aligning the model’s goals with human preferences (AI alignment), usually through reinforcement learning methods.

In this section, we review the fine-tuning process using two models as examples: Qwen2.5-Coder (Hui et al., 2024) and DeepSeek-Coder-V2 (Zhu et al., 2024). Their comparative results on the LiveCodeBench benchmark are shown in the tables above; at the time of writing, these models ranked among the top performers alongside commercial solutions from OpenAI (GPT family) and Anthropic (Claude family).

    Naturally, by the time you read this the situation has probably changed already — as I have said many times, it is impossible to keep up with AI progress these days if you are writing a book or even a series of long-form posts. But still, in any case, it is a great and potentially relevant illustration. Although the source code and training data are often not available to the research community, existing information about the fine-tuning process for these models is still interesting to consider and can be useful as a guide for your own fine-tuning process or for further specialization of these models to specific practical tasks.

    Qwen2.5-Coder

    The instruction tuning stage for the Qwen2.5-Coder (Hui et al., 2024) model consisted of the following steps.

    1. Coarse fine-tuning: training on 10 million diverse, low-quality synthesized instructions.
    2. Fine fine-tuning: training on 1 million higher-quality instructions to further improve performance; this step uses rejection sampling, where the LLM generates multiple responses to the same instruction and low-quality responses are filtered out.
3. Training on fill-in-the-middle instructions (FIM; Bavarian et al., 2022); since instructions are generally short, adding FIM data preserves the model's ability to handle long contexts; the developers used automatic parsers from the tree-sitter-languages library to extract various code constructs (e.g., expressions, functions, or blocks) from source code files and asked the model to reconstruct a removed code snippet based on the provided context. A small sketch of FIM example construction is given after the list.
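Here is that sketch, in the common prefix-suffix-middle format; the special token names and the random span selection are illustrative assumptions (real pipelines use the tokenizer's own special tokens and parser-based spans, e.g., whole functions extracted with tree-sitter):

import random

def make_fim_example(code: str,
                     prefix_tok="<|fim_prefix|>", suffix_tok="<|fim_suffix|>", middle_tok="<|fim_middle|>"):
    """Cut out a random span and ask the model to reconstruct it from the surrounding context.
    Token names are illustrative, not necessarily the model's actual special tokens."""
    i, j = sorted(random.sample(range(len(code) + 1), 2))
    prefix, middle, suffix = code[:i], code[i:j], code[j:]
    prompt = f"{prefix_tok}{prefix}{suffix_tok}{suffix}{middle_tok}"
    return prompt, middle   # the model is trained to generate `middle` given `prompt`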

    The alignment stage used the direct preference optimization (DPO) method (Rafailov et al., 2023), where the model’s behavior is adjusted based on feedback from test executions and quality assessments made by other language models:

    This approach is also known as LLM-as-a-judge (Zheng et al., 2023) and is increasingly used in practice to supplement or altogether avoid human feedback. The overall process improves code generation quality and aligns the model with user expectations.

    Data preparation for fine-tuning is another hugely important part of the pipeline that often gets overlooked when we talk about machine learning ideas. Since this is an “industrial” section, let us discuss data preparation as well; for Qwen2.5-Coder, it involved a rigorous process of data collection, cleaning, and analysis to create a high-quality set of instructions.

First, the initial data was sourced from open GitHub repositories and existing instruction datasets for code such as McEval-Instruct. The developers trained a specialized classifier based on the CodeBERT model (Feng et al., 2020) to identify the programming language or absence of code, and included the top 100 programming languages in the dataset.

    Second, to further expand and improve the dataset, synthetic instructions were generated using large language models. These instructions were derived from GitHub code snippets and filtered for quality using, again, LLMs; as far as I can tell, there was no human labeling involved beyond perhaps few-shot examples in the LLM prompts.

    Third, a multi-agent system was developed to generate new instructions in various programming languages. Agents were assigned language specializations and could exchange knowledge to improve the dataset. External memory was provided to agents to prevent information duplication. This approach promoted the reuse of programming concepts and knowledge across languages.

    Finally, we come to data evaluation: the authors developed a checklist to evaluate the final dataset based on criteria such as relevance, complexity, presence of comments, educational value, and other factors. Whenever possible, the data was validated using isolated environments for code execution, including syntax checks, automatic unit test creation, and test execution.

Overall, these steps have been able to ensure a high-quality, multilingual instruction dataset. This is a great example where we can at least partially lift the veil on the data preparation process, which is always hidden for closed source commercial LLMs. I am sure that OpenAI, Anthropic, and Google have a much more involved process for collecting, preparing, and evaluating their internal datasets, and that this is a key step in pushing the state of the art in any AI direction, including coding.

    DeepSeek-Coder-V2

On the instruction tuning stage, DeepSeek-Coder-V2 (DeepSeek-AI et al., 2024a) used a dataset that includes programming and mathematical tasks. The authors selected 20K programming instructions and 30K mathematical instructions from DeepSeek-Coder (Guo et al., 2024) and DeepSeek-Math (Shao et al., 2024), respectively. To retain the model's general capabilities, they also used additional data from DeepSeek-V2 (DeepSeek-AI et al., 2024b), and the final dataset consisted of 300 million tokens.

For the alignment stage, DeepSeek-Coder-V2 used the Group Relative Policy Optimization (GRPO) method (Shao et al., 2024), a memory-efficient modification of PPO. This phase also included programming tasks with unit tests specially written for them. After filtering, the dataset contained about 40K instructions.

    One interesting modification used here dealt with the reward function. For mathematical tasks, the reward function simply compared generated answers with the correct solution. For programming tasks, passing tests also served as a binary reward, but some instructions did not have enough tests for full coverage, making “passing tests” a noisy and suboptimal signal for training. To address this, the authors trained a reward model on test execution data and used it during RL training. This provided a less noisy signal compared to raw test results, and the authors showed how their reward model significantly outperformed direct test feedback on internal datasets (Leetcode and Leetcode-zh), as shown below:

In general, Qwen2.5-Coder (Hui et al., 2024) and DeepSeek-Coder-V2 (Zhu et al., 2024) represent the cutting edge of open-weight AI models for programming tasks. Interestingly, the published details on these models, while introducing some new ideas, align closely with the core approaches we have discussed above, which makes me hopeful that this review is still relevant.

The gap in both benchmark scores and actual perceived performance between these two models and closed models from OpenAI and Anthropic is minimal… oops, sorry, was minimal: then new test results for the full-scale o1 family came out, and now the O1-2024-12-17 (N=1) model tentatively sits in the first row of the table above with a 72.4% average pass@1 for code generation, compared to the 67.2% and 52.5% scored by the previous top two rows of the table.

    The o1 family is a very different can of worms that I don’t want to open here (see, e.g., my earlier post on o1-preview), but in any case, by the time you read this, even better results have probably been published; e.g., the recently announced o3 family does not have a LiveCodeBench entry yet. So let me conclude the chapter with some projections for the future.

    Conclusion: So Is Coding Dying or Not?..

    All around us, we see examples of how LLMs and especially LLM-based agents navigate and manipulate large codebases, efficiently debug and test programs, and adapt iteratively based on feedback, mimicking human-like development workflows. We have discussed a lot of research papers and even some practical open models, but, of course, the actual frontier is being pushed by the closed models developed in OpenAIAnthropic, and Google.

One of the most important recent developments has been the o1 family of models; recently, o1 Pro has been made available. Here is the system card, although it later turned out that the system card mostly does not describe the actual model being deployed; see a detailed post by Zvi Mowshowitz about this controversy. o1 Pro is much more expensive than the usual offerings ($200 per month compared to $20 for the standard OpenAI subscription), but it appears to be accordingly more powerful.

    I really hope to talk about the o1 (and maybe o3) family of models in a future post. So far all we have about it are speculations, but there have already been several relatively successful replication attempts, so speculations do matter. At this point, let me just point out that programming is another big strength of o1-pro. Many people have commented on how much better it is for handling large codebases and writing a lot of code from scratch with no mistakes; here is one such response with a video that shows how o1 has helped refactor a large codebase with extensive documentation that needed to be taken into account, and here is a report where the only downside is that o1-pro still won’t challenge the basic assumptions of your project, even if it would be helpful.

    Dissenting views mostly agree that o1-pro is excellent but say that, e.g., the latest Claude Sonnet 3.5, which still costs $20, is not noticeably worse and may be even better in some respects; here is one comparison. In any case, modern top LLMs already can do a lot, and coding applications are on top of that:

    Moreover, remember the main underlying principle of such capability examples: this is as bad as it ever will be, the only way from here is up. Demis Hassabis, cofounder and CEO of DeepMind, said while introducing Gemini 2.0 (another top model that has excellent coding capabilities): “This is really just the beginning. 2025 will be the year of AI agents.” Indeed, in this post we have already seen how adding even a straightforward agentic scaffold can improve coding capabilities of LLMs a lot, and there will be, of course, more and better work in this direction.

So if you planned on being a coder then yes, maybe it is time to rethink this plan. Programming is already being changed and will no doubt be completely redefined by the LLM revolution. However, unless programmers are replaced entirely, this looks like something computer science has already experienced several times. Hardly anyone today writes machine code directly because it has been replaced by assembly instructions. Few programmers write in assembly because it has been replaced by high-level languages such as C/C++. These languages have also been becoming more and more high-level: for example, only very low-level code works with raw pointers directly, and in most modern languages memory management is done automatically. Maybe this will be another step up the meta ladder: instead of writing code, programmers will be writing instructions and specifications for AI models to fill with actual code.

    On the other hand, I am still positive that if you want to create stuff, you do need to understand how things work. Maybe a programmer of the future will be mostly writing prompts for advanced AI tools, and the tools will write the code—but it will still help to know computer science, understand how computers operate on the basic level, and be able to dive in if necessary, even if it will be necessary less and less often.

    In any case, 2025 is certain to be a very exciting year for AI. Happy New Year, everyone!

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • Linear Attention and Mamba: New Power to Old Ideas

    Linear Attention and Mamba: New Power to Old Ideas

    We have already discussed how to extend the context size for modern Transformer architectures, but today we explore a different direction of this research. In the quest to handle longer sequences and larger datasets, Transformers are turning back to the classics: the memory mechanisms of RNNs, associative memory, and even continuous dynamical systems. From linear attention to Mamba, modern models are blending old and new ideas to bring forth a new paradigm of sequence modeling, and this paradigm is exactly what we discuss today.

    Introduction: Explaining the Ideas

    We have already discussed at length how Transformers have become the cornerstone of modern AI, powering everything from language models to image processing (see a previous post), and how the complexity of self-attention, which is by default quadratic in the input sequence length, leads to significant limitations when handling long contexts (see another post). Today, I’d like to continue this discussion and consider the direction of linear attention that has led to many exciting advances over the last year.

    In the several years of writing this blog, I have learned that it is a futile attempt to try to stay on top of the latest news in artificial intelligence: every year, the rate of progress keeps growing, and you need to run faster and faster just to stay in one place. What I think still matters is explaining ideas, both new ideas that our field produces and old ideas that sometimes get incorporated into deep learning architectures in unexpected ways.

    This is why I am especially excited about today’s post. Although much of it is rather technical, it allows me to talk about several important ideas that you might not have expected to encounter in deep learning:

    • the idea of linear self-attention is based on reframing the self-attention formula with the kernel trick, a classical machine learning technique for efficiently learning nonlinear models with linear ones (e.g., SVMs);
    • then, linear attention becomes intricately linked with associative memory, a classical idea suggested in the 1950s and applied to neural networks at least back in the 1980s in the works of the recent Nobel laureate John Hopfield, and fast weight programmers, an approach developed in the early 1990s;
    • finally, Mamba is the culmination of a line of approaches based on state space models (SSM), which are actually continuous time dynamical systems discretized to neural architectures.

    Taken together, these techniques represent a line of research… well, my first instinct here was to say “an emerging line of research” because most of these results are under two years old, and Mamba was introduced in December 2023. But in fact, this is an already pretty well established field, and who knows, maybe this is the next big thing in sequence modeling that can overcome some limitations of basic Transformers. Let us see what this field is about.

    Linear Attention: The Kernel Trick in Reverse

    As we have discussed many times (e.g., here and here), traditional Transformers use softmax-based attention, which computes attention weights over the entire input sequence:

        \[\mathbf{Z} = \mathrm{softmax}\left(\frac{1}{\sqrt{d_k}}\mathbf{Q}\mathbf{K}^\top\right)\mathbf{V}.\]

    This formula means that \mathbf{k}_i^\top\mathbf{q}_j serves as the measure of similarity between a query \mathbf{q}_j and a key \mathbf{k}_i (see my previous post on Transformers if you need a reminder about where queries and keys come from here), and one important bottleneck of the Transformer architecture is that you have to compute the entire L\times L matrix of attention weights. This quadratic complexity O(L^2) limits the input size, and we have discussed several different approaches to alleviating this problem in a previous post.

    Linear attention addresses this problem with a clever use of the kernel trick, a classical idea that dates back to the 1960s (Aizerman et al., 1964). You may know it from support vector machines (SVM) and kernel methods in general (Schölkopf, Smola, 2001; Shawe-Taylor, Cristianini, 2004; Hofmann et al., 2008), but you also may not, so let me begin with a brief explanation.

    Suppose that you have a linear classifier, i.e., some great way to find a hyperplane that separates two (or more) sets of points, e.g., a support vector machine that finds the widest possible strip of empty space between two classes:

    But in reality, it may happen that the classes are not linearly separable; for instance, what if the red points surround the blue points in a circle? In this case, there is no good linear decision surface, no hyperplane that works, and linear SVMs will fail too:

    But it is also quite obvious what a good decision boundary would look like: it would be a quadratic surface. How can we fit a quadratic surface if we only have a linear classifier? Actually, conceptually it is quite easy: we extract quadratic features from the input vector and find the linear coefficients. In this case, if we need to separate two-dimensional points \mathbf{x} = (x_1, x_2), a quadratic surface in the general case looks like

        \[w_0 + w_1x_1 + w_2x_2+w_3x_1^2+w_4x_1x_2+w_5x_2^2 = 0,\]

    so we need to go from a two-dimensional vector to a five-dimensional one by extracting quadratic features:

        \[\mathbf{x} = \left(\begin{matrix} x_1 & x_2\end{matrix}\right)^\top\quad\longrightarrow\quad \phi(\mathbf{x}) = \left(\begin{matrix} x_1 & x_2 & x_1^2 & x_1x_2 & x_2^2\end{matrix}\right)^\top.\]

    In the five-dimensional space, the same formula is now a linear surface, and we can use SVMs to find the best separating hyperplane in \mathbb{R}^5 that will translate into the best separating quadratic surface in \mathbb{R}^2:

    You can use any linear classifier here, and the only drawback is that we have to move to a higher feature dimension. Unfortunately, this is a pretty severe problem: it may be okay to go from \mathbb{R}^2 to \mathbb{R}^5, but even if we only consider quadratic features as above it means that \mathbb{R}^d will turn into \mathbb{R}^{d(d+3)/2} (all linear and quadratic monomials), which is a much higher computational cost, and a higher degree polynomial will make things much worse.

    This is where the kernel trick comes to the rescue: in SVMs and many other classifiers, you can rewrite the loss function in such a way that the only thing you have to be able to compute is the scalar product of two input vectors \mathbf{x} and \mathbf{x}' (I will not go into the details of SVMs here, see, e.g., Cristianini, Shawe-Taylor, 2000). If this property holds, instead of directly going to the larger feature space we can look at what computing the scalar product means in that space and probably find a (nonlinear) function that will do the same thing in the smaller space. For quadratic features, if we change the feature extraction function to 

        \[\phi(\mathbf{x}) = \left(\begin{matrix}\sqrt{2}x_1 & \sqrt{2}x_2 & x_1^2 & \sqrt{2}x_1x_2 & x_2^2\end{matrix}\right)^\top\]

    (this is a linear transformation that does not meaningfully change our classification task), we can rewrite the scalar product as 

        \begin{multline*}\phi(\mathbf{x})^\top \phi(\mathbf{x}') = 2x_1x'_1 + 2x_2x'_2 + x_1^2{x'}_1^2 + 2x_1x_2x'_1x'_2 + x_2^2{x'}_2^2 = \\ = 2\mathbf{x}^\top\mathbf{x}' + \left(x_1x'_1 + x_2x'_2\right)^2 = 2\mathbf{x}^\top\mathbf{x}' + \left(\mathbf{x}^\top\mathbf{x}'\right)^2 = \left(\mathbf{x}^\top\mathbf{x}' + 1\right)^2 - 1 = k(\mathbf{x},\mathbf{x}'). \end{multline*}

    The result is called a kernel function k(\mathbf{x},\mathbf{x}'), and we can now replace scalar products in the higher-dimensional space with nonlinear functions in the original space. And if your classifier depends only on the scalar products, the dimension of the feature space is not involved at all any longer; you can even go to infinite-dimensional functional spaces, or extract local features and have SVMs produce excellent decision surfaces that follow the data locally:
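
    To make this concrete, here is a tiny numerical sanity check in Python/NumPy (a toy sketch with random sample vectors) showing that the explicit quadratic feature map above and the kernel k(\mathbf{x},\mathbf{x}') = \left(\mathbf{x}^\top\mathbf{x}' + 1\right)^2 - 1 compute exactly the same similarity:

        import numpy as np

        def phi(x):
            # explicit quadratic feature map from the formula above
            x1, x2 = x
            return np.array([np.sqrt(2) * x1, np.sqrt(2) * x2,
                             x1 ** 2, np.sqrt(2) * x1 * x2, x2 ** 2])

        def kernel(x, y):
            # the corresponding kernel, computed without leaving the original 2D space
            return (x @ y + 1) ** 2 - 1

        rng = np.random.default_rng(0)
        x, y = rng.normal(size=2), rng.normal(size=2)
        print(phi(x) @ phi(y))   # dot product in the 5D feature space
        print(kernel(x, y))      # the same number, computed in 2D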

    Well, linear attention is the same trick but in reverse: instead of using a nonlinear function to represent a high-dimensional dot product, let us use a feature extractor to approximate the nonlinear softmax kernel! We transform each key \mathbf{k} and query \mathbf{q} using a feature map \phi, so that similarity between them can be computed as a dot product in feature space, \phi(\mathbf{k})^\top\phi(\mathbf{q}); let us say that \phi maps the query and key space \mathbb{R}^d to \mathbb{R}^n. So instead of computing each attention weight as the softmax

        \[\alpha_{ij} = \frac{\exp(\mathbf{q}_i^\top\mathbf{k}_j)}{\sum_{l=1}^L\exp(\mathbf{q}_i^\top\mathbf{k}_l)},\qquad\mathbf{z}_{i} = \sum_{j=1}^L\alpha_{ij}\mathbf{v}_j = \sum_{j=1}^L\frac{\exp(\mathbf{q}_i^\top\mathbf{k}_j)}{\sum_{l=1}^L\exp(\mathbf{q}_i^\top\mathbf{k}_l)}\mathbf{v}_j\]

    (I will omit the constant \sqrt{d} for simplicity; we can assume that it is incorporated into the query and key vectors), we use a feature map, also normalizing the result to a convex combination:

        \[\alpha_{ij} = \frac{\phi(\mathbf{q}_i)^\top\phi(\mathbf{k}_j)}{\sum_{l=1}^L\phi(\mathbf{q}_i)^\top\phi(\mathbf{k}_l)}.\]

    This is much more convenient computationally because now we can rearrange the terms in the overall formula for \mathbf{z}_i, and the quadratic complexity will disappear! Like this:

        \[\mathbf{z}_i = \sum_{j=1}^L\alpha_{ij}\mathbf{v}_j = \sum_{j=1}^L\frac{\phi(\mathbf{k}_j)^\top\phi(\mathbf{q}_i)}{\sum_{l=1}^L\phi(\mathbf{k}_l)^\top\phi(\mathbf{q}_i)}\mathbf{v}_j = \frac{\left(\sum_{j=1}^L\mathbf{v}_j\phi(\mathbf{k}_j)^\top\right)\phi(\mathbf{q}_i)}{\left(\sum_{l=1}^L\phi(\mathbf{k}_l)^\top\right)\phi(\mathbf{q}_i)}.\]

    Now instead of computing the L\times L matrix of attention weights we can first multiply \phi(\mathbf{K}) by \mathbf{V} (getting the brackets in the numerator above), which is a multiplication of n\times L and L\times d' matrices, and then reuse it for every query, multiplying the L\times n matrix \phi(\mathbf{Q}) by the result:

    Note that the result has dimension n rather than d, and also note that I put \mathbf{Z}' as the output in the (b) part of the figure (on the top right) because it is only the numerator of the fraction above, but the denominator is also obviously not quadratic: we can first add up \phi(\mathbf{k}_l) and then multiply by each query. Let us also simplify the formula above by denoting

        \[\mathbf{S} = \sum\nolimits_{j=1}^L\mathbf{v}_j\phi(\mathbf{k}_j)^\top,\quad \mathbf{S}\in\mathbb{R}^{d\times n},\qquad \mathbf{u} = \sum\nolimits_{l=1}^L\phi(\mathbf{k}_l),\quad\mathbf{u}\in\mathbb{R}^n,\]

    so we get a simple formula for linear attention as

        \[\mathbf{z}_i = \frac{\mathbf{S}\phi(\mathbf{q}_i)}{\mathbf{u}^\top \phi(\mathbf{q}_i)}.\]
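
    Here is a minimal NumPy sketch of this computation (a toy example with random matrices; I use the elu(x)+1 feature map as an assumption, one of the choices discussed below), checking that the rearranged formula gives exactly the same result as computing the full L\times L matrix of attention weights:

        import numpy as np

        def feature_map(x):
            # elu(x) + 1: a positive feature map, one common choice for phi (an assumption here)
            return np.where(x > 0, x + 1.0, np.exp(x))

        rng = np.random.default_rng(0)
        L, d, d_v = 6, 4, 3
        Q, K, V = rng.normal(size=(L, d)), rng.normal(size=(L, d)), rng.normal(size=(L, d_v))
        phiQ, phiK = feature_map(Q), feature_map(K)

        # direct computation with the full L x L matrix of attention weights
        A = phiQ @ phiK.T
        Z_direct = (A / A.sum(axis=1, keepdims=True)) @ V

        # rearranged computation: S = sum_j v_j phi(k_j)^T, u = sum_l phi(k_l)
        S = V.T @ phiK                     # (d_v, n)
        u = phiK.sum(axis=0)               # (n,)
        Z_linear = (phiQ @ S.T) / (phiQ @ u)[:, None]

        print(np.allclose(Z_direct, Z_linear))   # True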

    This is exactly the idea of linear attention as proposed by Katharopoulos et al. (2020). But there is one more important step.

    Causal Linear Attention: Transformers are RNNs?

    We know that Transformers are often applied autoregressively. Any language model, e.g., from the GPT family (recall our post on Transformers), is an autoregressive model that applies self-attention to the same sequence gradually, step by step, and causally: an output at position t depends only on inputs in positions 1,\ldots,t-1.

    To train an autoregressive Transformer, you don’t have to rerun the whole model for every token, like you do for generation. Instead, autoregressive Transformers use causal self-attention, a special modification where the entire sequence is input at once, but the attention weights to future tokens are automatically set to zero. This means that we get the same self-attention formula but with sums only going up to the current t-th token:

        \[\alpha_{tj} = \frac{\exp(\mathbf{q}_t^\top\mathbf{k}_j)}{\sum_{l=1}^t\exp(\mathbf{q}_t^\top\mathbf{k}_l)},\qquad\mathbf{z}_{t} = \sum_{j=1}^t\alpha_{tj}\mathbf{v}_j = \sum_{j=1}^t\frac{\exp(\mathbf{q}_t^\top\mathbf{k}_j)}{\sum_{l=1}^t\exp(\mathbf{q}_t^\top\mathbf{k}_l)}\mathbf{v}_j.\]

    Passing to a scalar product with feature extractor \phi as above, we get

        \[\mathbf{z}_t = \frac{\mathbf{S}_t\phi(\mathbf{q}_t)}{\mathbf{u}_t^\top \phi(\mathbf{q}_t)},\quad\text{where}\quad\mathbf{S}_t = \sum\nolimits_{j=1}^t\mathbf{v}_j\phi(\mathbf{k}_j)^\top,\quad \mathbf{u}_t = \sum\nolimits_{l=1}^t\phi(\mathbf{k}_l).\]

    It is becoming more and more clear where this is going: since \mathbf{S}_{t} and \mathbf{u}_{t} are just cumulative sums, we don’t have to recompute them from scratch on inference; instead, we can update them from previous values as

        \[\mathbf{S}_t=\mathbf{S}_{t-1}+\mathbf{v}_t\phi(\mathbf{k}_t)^\top,\qquad \mathbf{u}_t=\mathbf{u}_{t-1} + \phi(\mathbf{k}_t).\]

    Katharopoulos et al. (2020) also show that the gradients can be computed incrementally from timestep to timestep; this is a straightforward calculation so I will not repeat it here. As a result, they come to the conclusion that their linear Transformer is… essentially a recurrent neural network (RNN)! This “RNN” has a hidden state that consists of two different components, a matrix state \mathbf{S}_{t} and a normalizer state \mathbf{u}_{t}; we have derived the formulas for how to update this recurrence above, and we also know the formula for the output of this recurrent layer:

        \[\mathbf{o}_t=\frac{\mathbf{S}_{t}\phi(\mathbf{q}_t)}{\mathbf{u}_t^\top\phi(\mathbf{q}_t)}.\]

    In practice, one often removes the normalizing denominator since it can lead to numerical instabilities (Schlag et al., 2021; Mao, 2022), and the feature extractor \phi is commonly taken to be \phi =\mathrm{id}, so the formulas simplify to

        \[\mathbf{S}_t=\mathbf{S}_{t-1}+\mathbf{v}_t\mathbf{k}_t^\top,\qquad \mathbf{o}_t=\mathbf{S}_t\mathbf{q}_t.\]
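
    In code, this recurrent view is just a loop over the sequence; here is a minimal sketch (with \phi = \mathrm{id} and no normalization, as in the simplified formulas above, on random toy inputs), which also checks that it matches the masked parallel computation:

        import numpy as np

        rng = np.random.default_rng(0)
        L, d, d_v = 8, 4, 3
        Q, K, V = rng.normal(size=(L, d)), rng.normal(size=(L, d)), rng.normal(size=(L, d_v))

        S = np.zeros((d_v, d))                 # the matrix "hidden state" of the linear-attention RNN
        outputs = []
        for t in range(L):
            S = S + np.outer(V[t], K[t])       # S_t = S_{t-1} + v_t k_t^T
            outputs.append(S @ Q[t])           # o_t = S_t q_t
        O_recurrent = np.stack(outputs)

        # the same thing computed in parallel with a causal mask on the (unnormalized) scores
        O_parallel = np.tril(Q @ K.T) @ V
        print(np.allclose(O_recurrent, O_parallel))   # True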

    But isn’t that a little too simple? Linear attention uses the kernel trick to approximate the softmax mechanism efficiently, enabling Transformers to handle longer sequences. However, this shift from quadratic to linear complexity raises questions about the fundamental role and meaning of attention: how should models store and retrieve relevant information efficiently? In the next section, we discuss associative memory, a classical concept in neural networks, which in this case turns out to be an important point of view on this question. In particular, it shares a similar goal of learning to store patterns and retrieving them based on input queries. By revisiting associative memory, we can better understand the underlying mechanisms of linear attention and their limitations.

    Fast Weight Programmers and Associative Memory

    We discuss several approaches in this section but mostly follow Schlag et al. (2021) who provide us with some key intuition about linear Transformers. They note that linear Transformers are almost entirely equivalent to an architecture called Fast Weight Programmers (FWPs), developed by Jürgen Schmidhuber (yes, this was his idea too!) in the early 1990s (Schmidhuber, 1992; 1993).

    FWPs come from the basic intuition that the weights in standard neural networks remain fixed after training; activations change depending on the input, but the weights themselves are frozen. This is a bad thing for what is known as the binding problem (Greff et al., 2020): a neural network has no easy way to bind variables, define symbols, and thus construct compositional internal representations and perform symbolic reasoning that plays a key role in human cognition (Whitehead, 1927; Spelke, Kinzler, 2007; Johnson-Laird, 2010).

    One possible solution for the binding problem would be to have two kinds of weights in a neural network: slow weights that are fixed after training as usual and fast weights that are context-dependent and can change on inference. As Greff et al. (2020) put it, “the slow net learns to program its fast net”. In an FWP (Schmidhuber, 1991; 1992), the slow network learns to adjust fast weights as follows: for a sequence of inputs \mathbf{x}_i, i=1,\ldots,L,

        \begin{align*} \mathbf{a}_i & = \mathbf{W}_a\mathbf{x}_i, & \mathbf{b}_i & = \mathbf{W}_b\mathbf{x}_i, \\ \mathbf{W}_i &= \sigma\left(\mathbf{W}_{i-1} + \mathbf{a}_i\mathbf{b}_i^\top\right), & \mathbf{y}_i &= \mathbf{W}_i\mathbf{x}_i, \end{align*}

    where \mathbf{W}_a and \mathbf{W}_b are slow weights and \mathbf{W}_i are fast weights. In essence, fast weights play the role of an RNN’s hidden state and the formulas above define the recurrence (Schmidhuber himself rephrased this idea in recurrent terms a year later, in 1993). But note the uncanny resemblance between this update rule and Transformer’s self-attention: Schmidhuber’s FWPs also make use of the outer product \mathbf{a}\mathbf{b}^\top to update the hidden state! FWPs create a short-term associative memory where keys are associated with values in a matrix form, the write operation is implemented by adding the outer product, and the readout is represented by matrix multiplication.
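
    To make the “slow net programs the fast net” intuition concrete, here is what a single FWP layer might look like in NumPy (a sketch under my own simplifying assumptions: random untrained slow weights, a sigmoid as \sigma, and no training loop):

        import numpy as np

        def sigmoid(x):
            return 1.0 / (1.0 + np.exp(-x))

        rng = np.random.default_rng(0)
        d, L = 4, 6
        W_a, W_b = rng.normal(size=(d, d)), rng.normal(size=(d, d))   # slow weights, trained as usual
        xs = rng.normal(size=(L, d))                                  # input sequence

        W_fast = np.zeros((d, d))                       # fast weights start empty
        for x in xs:
            a, b = W_a @ x, W_b @ x                     # the slow net produces the "write" vectors
            W_fast = sigmoid(W_fast + np.outer(a, b))   # fast weights updated via an outer product
            y = W_fast @ x                              # output of the fast net on the current input
            print(np.round(y, 3))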

    You can see how this resemblance becomes a formal equivalence when we move to linear attention: if we set the activation function σ above to identity, we get exactly the update rule and readout of simplified linear attention:

        \[\mathbf{S}_t=\mathbf{S}_{t-1}+\mathbf{v}_t\mathbf{k}_t^\top,\qquad \mathbf{o}_t=\mathbf{S}_t\mathbf{q}_t.\]

    Normalization (the vector \mathbf{u}_t above) was absent from the FWPs of the 1990s but it is also a straightforward idea in this formulation: whenever you have a “memory” that accumulates a big sum of values along the input sequence, it is natural to try and renormalize the sum to keep it at the same scale.

    To make further improvements, Schlag et al. (2021) go back to the original motivation for the whole thing: fitting information into the hidden state matrix \mathbf{S}_t. The relation to fast weight programmers brings back the original goal of this transformation: we store vectors in the matrix \mathbf{S} and then retrieve this information via matrix multiplication. Let us discuss this in more detail.

    The idea of storing information in this way is known as associative memory, a classical concept in artificial intelligence (see, e.g., Haykin, 2011) which is a natural generalization of, well, just storing things in memory:

    • in regular memory, you have d slots where you can store something (say, a vector), and retrieval from the memory can be thought of as multiplying the memory matrix by a vector; storing something new in regular memory can be thought of as adding a rank one matrix with the new vector in its proper slot;
    • in associative memory, you have a matrix \mathbf{A} that stores vector associations as projections to some orthogonal basis; to store a new association \mathbf{v} in the matrix \mathbf{A}, you need to choose a key vector \mathbf{k} that’s orthogonal to previous key vectors and update \mathbf{A}=\mathbf{A}+\mathbf{v}\mathbf{k}^\top; to retrieve the association, you do a projection by multiplying \mathbf{A}\mathbf{k}.

    Associative memory is another one of those ideas that were motivated by neurobiology and date back to early studies of the brain. In 1949, Donald Hebb introduced his famous learning principle, often summarized as “neurons that fire together, wire together” (Hebb, 1949); in other words, associations between neurons, reflected in synapse weights, grow stronger if neurons get activated at the same time. Unlike gradient descent, Hebbian learning is actually possible with biological neurons, and Hebb’s work in many ways remains relevant in neurobiology today (his theory also made provisions for, e.g., spike-timing-dependent plasticity that was not known in the 1940s).

    It soon became clear that associative memory could be used as a kind of machine learning model. Early attempts at such models started in the 1950s (Taylor, 1956), but two ideas based on associative memory found wide success later:

    • self-organizing maps (SOM), or Kohonen networks, developed by Teuvo Kohonen in the 1970s (Kohonen, 1974), were at some point among the most popular unsupervised learning methods, performing representation learning by adjusting the weights towards neurons that are already best matches for the input, a process known as competitive learning (Grossberg, 1987; Kohonen, 1988);
    • Hopfield networks, developed by John Hopfield in the 1980s (Hopfield, 1982; 1984), store patterns in minima of energy landscapes of neural networks and retrieve them by evolving towards these local minima, which means that retrieval is done by association from incomplete data; there has been a lot of research on Hopfield networks (Krotov, Hopfield, 2016; 2020; Demircigil et al., 2017; Ramsauer et al., 2020), and recently John Hopfield shared the 2024 Nobel Prize in Physics with Geoffrey Hinton for his work in neural networks, but this is a story for another time.

    Let us walk through an example of how associative memory works. We will work in 2D so that we can plot everything, beginning with a 2\times 2 zero matrix \mathbf{A}. Suppose that we want to store two vectors in that matrix,

        \[\mathbf{x}_1 = \left(\begin{matrix}2 \\ 3\end{matrix}\right),\qquad\mathbf{x}_2 = \left(\begin{matrix}-2 \\ 1\end{matrix}\right).\]

    If we were just storing them in the matrix column by column, it would be equivalent to using keys aligned with coordinate axes:

        \[\mathbf{A}_1 = \mathbf{x}_1 \left(\begin{matrix}1 \\ 0\end{matrix}\right)^\top = \left(\begin{matrix} 2 & 0 \\ 3 & 0\end{matrix}\right),\qquad \mathbf{A}_2 = \mathbf{A}_1 + \mathbf{x}_2 \left(\begin{matrix}0 \\ 1\end{matrix}\right)^\top = \left(\begin{matrix} 2 & -2 \\ 3 & 1\end{matrix}\right).\]

    Reading from this memory is simply reading the columns, or, equivalently, multiplying by (1 0) and (0 1) key vectors. But we can take any other set of two orthogonal key vectors, say (let’s keep them at unit length to avoid renormalization):

        \[\mathbf{k}_1 = \frac{1}{\sqrt{2}} \left(\begin{matrix}1 \\ 1\end{matrix}\right),\qquad\mathbf{k}_2 = \frac{1}{\sqrt{2}} \left(\begin{matrix}-1 \\ 1\end{matrix}\right).\]

    In this case, we get

        \[\mathbf{A}_1 = \mathbf{x}_1\mathbf{k}_1^\top = \frac{1}{\sqrt{2}} \left(\begin{matrix}2 & 2 \\ 3 & 3\end{matrix}\right),\quad \mathbf{A}_2 = \mathbf{A}_1 + \mathbf{x}_2\mathbf{k}_2^\top = \mathbf{A}_1 + \frac{1}{\sqrt{2}} \left(\begin{matrix}2 & -2 \\ -1 & 1\end{matrix}\right) = \frac{1}{\sqrt{2}} \left(\begin{matrix} 4 & 0 \\ 2 & 4\end{matrix}\right).\]

    Reading from this matrix still works fine:

        \[\mathbf{A}_2\mathbf{k}_1 = \frac{1}{\sqrt{2}} \left(\begin{matrix}4 & 0 \\ 2 & 4\end{matrix}\right)\cdot\frac{1}{\sqrt{2}} \left(\begin{matrix}1 \\ 1\end{matrix}\right) = \left(\begin{matrix} 2 \\ 3 \end{matrix}\right),\quad\mathbf{A}_2\mathbf{k}_2 = \frac{1}{\sqrt{2}} \left(\begin{matrix}4 & 0 \\ 2 & 4\end{matrix}\right)\cdot\frac{1}{\sqrt{2}} \left(\begin{matrix} -1 \\ 1\end{matrix}\right) = \left(\begin{matrix}-2 \\ 1\end{matrix}\right).\]

    But if you try to add a third vector to the same associative memory with a third key, which is now inevitably non-orthogonal with the first two, say,

        \[\mathbf{x}_3 = \left(\begin{matrix}1 \\ 2\end{matrix}\right),\quad \mathbf{k}_3 = \frac{1}{\sqrt{5}} \left(\begin{matrix} 2 \\ -1\end{matrix}\right), \quad \mathbf{A}_3 =\mathbf{A}_2 + \mathbf{x}_3\mathbf{k}_3^\top = \frac{1}{\sqrt{2}} \left(\begin{matrix}4 & 0 \\ 2 & 4\end{matrix}\right)+\frac{1}{\sqrt{5}} \left(\begin{matrix}2 & -1 \\ 4 & -2\end{matrix}\right),\]

    retrieval results will become corrupted, both for the original vectors and for the new vector \mathbf{x}_3:

        \[\mathbf{x}'_1 = \mathbf{A}_3\mathbf{k}_1 = \left(\begin{matrix}2 + \frac{1}{\sqrt{10}} \\ 3 + \frac{2}{\sqrt{10}} \end{matrix}\right),\quad\mathbf{x}'_2 = \mathbf{A}_3\mathbf{k}_2 = \left(\begin{matrix}-2 - \frac{3}{\sqrt{10}} \\ 1 - \frac{6}{\sqrt{10}} \end{matrix}\right),\quad\mathbf{x}'_3 = \mathbf{A}_3\mathbf{k}_3 = \left(\begin{matrix}\frac{8}{\sqrt{10}}+1 \\ 2 \end{matrix}\right).\]
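
    We can verify this little example numerically; here is a NumPy sketch that reproduces the computations above:

        import numpy as np

        x1, x2, x3 = np.array([2., 3.]), np.array([-2., 1.]), np.array([1., 2.])
        k1 = np.array([1., 1.]) / np.sqrt(2)
        k2 = np.array([-1., 1.]) / np.sqrt(2)
        k3 = np.array([2., -1.]) / np.sqrt(5)     # not orthogonal to k1 and k2

        A = np.outer(x1, k1) + np.outer(x2, k2)   # store two vectors under orthogonal keys
        print(A @ k1, A @ k2)                     # exact retrieval: [2. 3.] and [-2. 1.]

        A = A + np.outer(x3, k3)                  # add a third vector under a non-orthogonal key
        print(A @ k1, A @ k2, A @ k3)             # retrieval is now corrupted for all three vectors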

    Geometrically this effect can be illustrated as below; we can find two orthogonal vectors for the first two keys (on the left in the figure) but the third one breaks perfect retrieval (retrieved vectors are shown with dashed lines on the right):

    So far, it doesn’t sound like much of an improvement: we could just store vectors column by column and have the exact same number of them fit. The point of associative memory lies in its robustness to the orthogonality requirement: if the keys are only nearly orthogonal rather than exactly orthogonal, you will still retrieve vectors that are quite similar to the originals. And this means that we can fit more keys than the matrix dimension, with imperfect but still reasonable recall!

    This is hard to illustrate with a two-dimensional picture but in high dimensions you can use sparse keys that are all nearly orthogonal even though they intersect a little. For example, if d=100, and you use binary keys that all look like a vector with k=10 ones and 90 zeros (divided by \sqrt{10}, of course), two keys that have zero ones in common are perfectly orthogonal with zero dot product, but keys that have only m=1 one in common have a dot product of just 1/10, which introduces little enough cross-talk for retrieval to still work reasonably well in practice.
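
    Here is a quick numerical illustration of this effect (a toy experiment with numbers I chose arbitrarily: d=100, random 10-sparse binary keys, and 150 stored vectors, i.e., more vectors than dimensions):

        import numpy as np

        rng = np.random.default_rng(0)
        d, k, m = 100, 10, 150                    # dimension, ones per key, number of stored vectors

        # random k-sparse binary keys, normalized to unit length; they are only *nearly* orthogonal
        keys = np.zeros((m, d))
        for i in range(m):
            keys[i, rng.choice(d, size=k, replace=False)] = 1.0 / np.sqrt(k)

        values = rng.normal(size=(m, d))          # random vectors to store
        A = values.T @ keys                       # associative memory: sum_i v_i k_i^T

        retrieved = (A @ keys.T).T                # retrieve every stored vector by its own key
        cos = np.sum(retrieved * values, axis=1) / (
            np.linalg.norm(retrieved, axis=1) * np.linalg.norm(values, axis=1))
        print("mean cosine similarity to the original:", cos.mean())

        # retrieval is noisy, but the retrieved vector is usually still closest to the right one
        sims = retrieved @ values.T
        print("top-1 retrieval accuracy:", np.mean(sims.argmax(axis=1) == np.arange(m)))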

    Finding out how many such keys can exist for given d, k, and m is a well-known problem from a completely separate field of study, called the theory of block designs, a part of the theory of error-correcting codes. This is essentially a coding question: how many codewords with at most a given pairwise intersection can you fit for a given dimension and a given codeword weight (number of ones)? I will not go into error-correcting codes and refer to, e.g., (Assmus, Key, 1992; Huffman, Pless, 2003), but the main relevant results here are the Hamming bound that is proven by counting and the more complicated Johnson bound. The Hamming bound says that without restrictions on the weight, for given d and m you can fit about

        \[A_2(d, m) \le \frac{2^d}{\sum_{l=0}^{\lfloor(m-1)/2\rfloor}{d\choose l}}\]

    binary keys. We are interested in large values of m, where you can get a good approximation for the denominator via the entropy of the relative distance:

        \[A_2(d, m) \le 2^{d\left(1-H(p)\right)+o(d)},\quad H(p)=-p\log p-(1-p)\log(1-p),\quad p=\frac{m}{2d}.\]

    This means that even if you require small intersections, you can fit an exponential number of codewords, just with a smaller exponent. The Johnson bound deals with vectors of fixed weight, and we will not go there now, but the point stands: you can fit a lot of codewords with small intersections, asymptotically much more than d, and this gives us a way to store a lot of vectors in associative memory as long as we are okay with imperfect retrieval.

    Now we have a much better intuition for what is going on in linear attention Transformers. But where will the improvements come from?

    Improving Linear Transformers

    While linear Transformers are more efficient than classical self-attention and reduce its complexity from quadratic to linear, this efficiency comes at a cost. Linear attention approximations can struggle with tasks that require precise content-based reasoning or long-term memory, and further research is clearly needed.

    How can we improve upon the architecture above? We have already seen that the kernel \phi can be different. But once you start thinking about updates to \mathbf{S}_t as storing key-value pairs in memory, the update itself also becomes a promising point of possible new approaches: maybe summation is not the best way to store things in memory?

    So at this point, we see that the linear self-attention structure breaks down into four decisions, each of which can suggest directions for improvement:

    • the nonlinear transformation \phi of the key and query vectors, applied before the keys are stored in \mathbf{S}_t and before the queries are used for readout;
    • the memory update rule for \mathbf{S}_t itself; let us call it f: \mathbf{S}_{t} = f(\mathbf{S}_{t-1}, \mathbf{k}_t, \mathbf{v}_t);
    • the normalization mechanism, which so far has been either absent or via direct accumulation in the vector \mathbf{u}_t; in theory, we could normalize the key, value, and query vectors separately, or just normalize the hidden state;
    • the mechanism for producing the output vector \mathbf{o}_t from the query \mathbf{q}_t and the hidden state matrix \mathbf{S}_t.

    I have illustrated the general scheme below, showing where these different items go in the architecture. Let us now consider these directions one by one.

    First, for the nonlinear transformation \phi Katharopoulos et al. (2020) suggested using either the identity function or the exponential linear unit ELU, a variation of ReLU with nonzero derivative everywhere (plus one to make \phi(a) nonnegative):

        \[\phi(a) = \mathrm{ELU}(a)+1 = \begin{cases}a+1, & a\ge 0,\\ e^a, & a < 0.\end{cases}\]

    Here \phi is basically an activation function, operating independently on every component of \mathbf{k} and \mathbf{q}. However, in the previous section we motivated the function \phi as an approximation to the numerator of softmax, i.e., we would ideally want

        \[\phi(\mathbf{k})^\top\phi(\mathbf{q}) \approx e^{\mathbf{k}^\top\mathbf{q}},\]

    which is definitely not the case for ELU+1.

    The Performer architecture (Choromanski et al., 2021) introduced a version of \phi which is a much better approximation for softmax. They provide a detailed proof that we will not reproduce here, but in essence their approach, called FAVOR+ for Fast Attention Via positive Orthogonal Random features, uses random linear transformations in such a way that the expected result is indeed the softmax kernel shown above: they define

        \[h(\mathbf{x}) = \frac{1}{\sqrt{2}}e^{-\frac12|\mathbf{x}|^2},\qquad \phi(\mathbf{x}) = \frac{1}{\sqrt{m}}h(\mathbf{x})\left(\begin{matrix} e^{\mathbf{R}\mathbf{x}} \\ e^{-\mathbf{R}\mathbf{x}}\end{matrix}\right),\]

    where \mathbf{R} is an m\times d random matrix whose every row is drawn from the standard Gaussian in dimension d, and prove that the expectation of \phi(\mathbf{k})^\top\phi(\mathbf{q}) coincides with the softmax kernel \exp(\mathbf{k}^\top\mathbf{q}), and that

        \[\phi(\mathbf{k})^\top\phi(\mathbf{q}) \longrightarrow_{m\to\infty} e^{\mathbf{k}^\top\mathbf{q}}.\]
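
    Here is a small NumPy sketch of this random feature map (my own toy implementation of the formula above, not the official Performer code), checking that it approximates the softmax kernel on a random pair of small vectors:

        import numpy as np

        rng = np.random.default_rng(0)
        d, m = 8, 20000                           # input dimension and number of random features
        R = rng.normal(size=(m, d))               # rows drawn from the standard Gaussian in dimension d

        def phi(x):
            # positive random features: (h(x) / sqrt(m)) * [exp(Rx); exp(-Rx)], h(x) = exp(-|x|^2/2)/sqrt(2)
            h = np.exp(-0.5 * (x @ x)) / np.sqrt(2.0)
            return h / np.sqrt(m) * np.concatenate([np.exp(R @ x), np.exp(-R @ x)])

        k, q = 0.3 * rng.normal(size=d), 0.3 * rng.normal(size=d)
        print(phi(k) @ phi(q))                    # random feature estimate of the softmax kernel
        print(np.exp(k @ q))                      # the exact value exp(k^T q)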

    Schlag et al. (2021) introduce the so-called deterministic parameter-free projection (DPFP), an approach where components of \phi are constructed to be orthogonal by design: if \phi_j(\mathbf{x})>0 then \phi_i(\mathbf{x})=0 for all i other than j. This can be achieved with ReLU activations if you just design them so that their nonnegative areas do not overlap. For example, \phi can map \mathbb{R}^2 to \mathbb{R}^4 as follows:

        \[\phi\left(\left(\begin{matrix}k_1 & k_2\end{matrix}\right)^\top\right) = \left(\begin{matrix} r(k_1)r(k_2) & r(-k_1)r(k_2) & r(k_1)r(-k_2) & r(-k_1)r(-k_2) \end{matrix}\right)^\top,\]

    where r(a)=\max(0, a) is the ReLU activation function. Note how regardless of the input vector \mathbf{k} at most one component of \phi(\mathbf{k}) is nonzero, because either r(a) or r(-a) is always zero. The authors generalize this approach to higher dimensions as well; note that ReLUs are also very computationally efficient, much more so than computing exponentials.
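
    In code, this 2D-to-4D example of DPFP looks as follows (a sketch of just the toy example above, not the general construction from the paper):

        import numpy as np

        def relu(a):
            return np.maximum(0.0, a)

        def dpfp_2d(k):
            # maps (k1, k2) to four components, at most one of which can be nonzero
            k1, k2 = k
            return np.array([relu(k1) * relu(k2), relu(-k1) * relu(k2),
                             relu(k1) * relu(-k2), relu(-k1) * relu(-k2)])

        print(dpfp_2d(np.array([ 0.5,  2.0])))    # only the first component is nonzero
        print(dpfp_2d(np.array([-0.5,  2.0])))    # only the second
        print(dpfp_2d(np.array([ 0.5, -2.0])))    # only the third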

    Second, let’s turn to the memory update rule. As the number of vectors stored in associative memory grows beyond the matrix dimension d, the memory mechanism should ideally figure out which vectors to “overwrite”. This is especially important because in practice, you may get a new key that is close to an already stored key whose value is also close to the new value; in this case you don’t really want to overwrite anything at all, just update the stored value a little so that both keys will retrieve a good enough approximation of it.

    Schlag et al. (2021) propose the following approach here: for a new key-value pair (\mathbf{k}, \mathbf{v}), retrieve the vector \mathbf{v}' that is already stored in memory under the key \mathbf{k} (retrieval is always well defined in associative memory; if nothing similar has been stored yet, it will just return zero) and store a convex combination of \mathbf{v}' and \mathbf{v}. The coefficient of this combination, the “overwrite force” for this vector, can also be derived from the inputs. Formally, we define

        \[\mathbf{v}'_t = \mathbf{S}_{t-1}\phi(\mathbf{k}_t),\quad \beta_t = \sigma\left(\mathbf{W}^\beta\mathbf{x}_t\right),\quad \mathbf{v}^{\mathrm{new}}_t = \beta_t\mathbf{v}_t + (1-\beta_t)\mathbf{v}'_t,\]

    and then in the matrix state computation we erase \mathbf{v}' from memory and write in \mathbf{v}^{\mathrm{new}}, getting

        \[\mathbf{S}_t = \mathbf{S}_{t-1} - \mathbf{v}'_t\phi(\mathbf{k}_t)^\top + \mathbf{v}^{\mathrm{new}}_t\phi(\mathbf{k}_t)^\top = \mathbf{S}_{t-1} + \beta_t\left(\mathbf{v}_t - \mathbf{v}'_t\right)\phi(\mathbf{k}_t)^\top.\]
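
    A sketch of one such update step in NumPy (assuming, for illustration, that \beta_t has already been computed from the input, so I just pass it in as a number):

        import numpy as np

        def delta_rule_update(S, k_feat, v, beta):
            """One memory update in the style of Schlag et al. (2021): retrieve the value
            currently stored under phi(k), mix it with the new value, and write the result back."""
            v_old = S @ k_feat                           # v'_t = S_{t-1} phi(k_t)
            v_new = beta * v + (1.0 - beta) * v_old      # convex combination controlled by beta
            return S + np.outer(v_new - v_old, k_feat)   # equals S + beta (v - v') phi(k)^T

        rng = np.random.default_rng(0)
        d_v, n = 3, 4
        S = rng.normal(size=(d_v, n))                    # current memory state
        k_feat, v = rng.normal(size=n), rng.normal(size=d_v)
        S = delta_rule_update(S, k_feat, v, beta=0.7)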

    Third, for normalization you can use attention normalization as suggested by Katharopoulos et al. (2020) or, for instance, sum normalization where query and key vectors are divided by the sums of their own components. Normalization can be done only at the level of queries, keys, and values, or also at the output \mathbf{o}_t, and so on, and so forth.

    The possibilities are endless, and indeed, one can think of a lot of different modifications for the above formulas. Some of them explore different feature functions, others change how combinations and moving averages are computed, yet others add various gates to the architecture up to the complexity of an entire LSTM (Peng et al., 2021; Beck et al., 2024). The summary table below is taken from a recent work by Yang et al. (2024), which in turn proposes yet another approach in this vein:

    Naturally, I don’t want to go over the entire table here; we are already well enough acquainted with several of its rows that you can mostly understand the motivation behind the others. But there is one more important direction that leads to interesting new ideas and has been growing in popularity lately, so I want to explore it in more detail.

    Mamba: Transformers are State Space Models

    While linear attention provides a scalable alternative to Transformer’s self-attention, it still struggles with tasks requiring explicit reasoning over long-term dependencies or fine-grained temporal dynamics. In this section, we discuss state space models that provide an alternative perspective: instead of focusing on approximating attention, they model sequences as evolving states governed by differential equations. This still allows the system to handle long-range dependencies while at the same time learning structured dynamics inspired by control theory.

    To explain what is going on in Mamba, we need to take a step back yet again, this time to state space models. A state space model (SSM) is another way to process sequential input, very similar to RNNs in that an SSM also has a hidden state \mathbf{h}_t that is supposed to capture all relevant information about the current state of the system. But the state space model looks at system evolution from a continuous standpoint, considering the dynamical system

        \[\dot{\mathbf{h}}(t) = \mathbf{A}\mathbf{h}(t) + \mathbf{B}\mathbf{x}(t),\qquad \mathbf{o}(t) = \mathbf{C}\mathbf{h}(t)+\mathbf{D}\mathbf{x}(t).\]

    Here is an illustration:

    Note that the direct dependence of the output \mathbf{o}(t) on the input \mathbf{x}(t) can be thought of as a skip connection going around the dynamical system, so below we will assume that \mathbf{D}=0.

    This approach has its roots in control theory; the famous Kalman filter (Kalman, 1960) is a special case of SSMs, and classical control theory has a lot of results on such linear dynamical systems (Jazwinski, 1970; Kailath, 1980), spilling over into econometrics and generally time series analysis (Hamilton, 1994).

    The equations above look just like a classical RNN; the main difference is that they are continuous, so we can hardly expect to be able to work with them unless we can discretize continuous signals and, vice versa, turn discrete inputs (such as text) into continuous signals. In this approach, it is usually enough to consider the zero-order hold model, where a discrete input is turned into a set of step functions with step size Δ, and a continuous signal is sampled according to the input timesteps. Discretization of dynamical systems proceeds via matrix exponentials that result from solving the differential equations above on an interval [t, t+\Delta], where the input \mathbf{x}(t) can be assumed constant, so the solution is

        \[\mathbf{h}(t+\Delta) = e^{\Delta\mathbf{A}}\mathbf{h}(t) + \left(\int_{0}^{\Delta}e^{\mathbf{A}\tau}\mathrm{d}\tau\right)\mathbf{B}\,\mathbf{x}(t).\]

    As a result, we can define discretized versions of the matrices \mathbf{A} and \mathbf{B} (see, e.g., Grootendorst, 2024 for a more detailed explanation) as

        \[\bar{\mathbf{A}} = e^{\Delta \mathbf{A}},\qquad\bar{\mathbf{B}} = \left(\int_0^{\Delta} e^{\mathbf{A}\tau}\mathrm{d}\tau\right)\mathbf{B} = \mathbf{A}^{-1}\left(e^{\Delta \mathbf{A}} - \mathbf{I}\right)\mathbf{B}\]

    and treat this discretized version of an SSM as a linear RNN with update rule (omitting \mathbf{D} as discussed above)

        \[\mathbf{h}_t = \bar{\mathbf{A}}\mathbf{h}_{t-1} + \bar{\mathbf{B}}\mathbf{x}_{t},\qquad \mathbf{o}_t = \mathbf{C}\mathbf{h}_t.\]
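
    Here is a sketch of the zero-order hold discretization and the resulting linear recurrence in NumPy/SciPy (toy random matrices and a scalar input channel, chosen just for illustration):

        import numpy as np
        from scipy.linalg import expm

        rng = np.random.default_rng(0)
        n, L, delta = 4, 16, 0.1                  # state dimension, sequence length, step size
        A = -np.eye(n) + 0.1 * rng.normal(size=(n, n))   # a (roughly) stable toy state matrix
        B, C = rng.normal(size=(n, 1)), rng.normal(size=(1, n))

        # zero-order hold discretization
        A_bar = expm(delta * A)
        B_bar = np.linalg.solve(A, (A_bar - np.eye(n)) @ B)   # A^{-1} (e^{delta A} - I) B

        # run the discretized SSM as a linear RNN over a scalar input sequence
        x = rng.normal(size=L)
        h = np.zeros((n, 1))
        outputs = []
        for t in range(L):
            h = A_bar @ h + B_bar * x[t]          # h_t = A_bar h_{t-1} + B_bar x_t
            outputs.append((C @ h).item())        # o_t = C h_t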

    Note that this is not the only way to do discretization, for example, Gu et al., 2022 use a bilinear method where

        \[\bar{\mathbf{A}} = \left(\mathbf{I} - \frac{\Delta}{2}\mathbf{A}\right)^{-1}\left(\mathbf{I} + \frac{\Delta}{2}\mathbf{A}\right),\qquad\bar{\mathbf{B}} = \left(\mathbf{I} - \frac{\Delta}{2}\mathbf{A}\right)^{-1}\Delta\mathbf{B}.\]

    Moreover, doing everything via discretizations of continuous functions has other advantages; for example, we can seamlessly handle missing data by simply continuing the discretization over a longer time period (where we do not have new data).

    Finally, we can also note that in this formulation, every output \mathbf{o}_t can be easily represented as a series depending on the inputs \mathbf{x}_i:

        \[\mathbf{o}_t = \mathbf{C}\mathbf{h}_t = \mathbf{C}\bar{\mathbf{B}}\mathbf{x}_t + \mathbf{C}\bar{\mathbf{A}}\mathbf{h}_{t-1} = \mathbf{C}\bar{\mathbf{B}}\mathbf{x}_t + \mathbf{C}\bar{\mathbf{A}}\bar{\mathbf{B}}\mathbf{x}_{t-1} + \mathbf{C}\bar{\mathbf{A}}^2\bar{\mathbf{B}}\mathbf{x}_{t-2} + \ldots,\]

    which can be thought of as a convolution operator: to get \mathbf{o}_t, we convolve the input series with the kernel

        \[\bar{\mathbf{K}} = \left(\mathbf{C}\bar{\mathbf{B}}, \mathbf{C}\bar{\mathbf{A}}\bar{\mathbf{B}}, \mathbf{C}\bar{\mathbf{A}}^2\bar{\mathbf{B}}, \ldots, \mathbf{C}\bar{\mathbf{A}}^{L-1}\bar{\mathbf{B}}\right).\]

    \bar{\mathbf{K}} is called the SSM convolution kernel, and if it is known, the SSM can be very efficiently computed in parallel during training, when we have the entire input sequence \mathbf{x}_t available, just like any autoregressive model. Computing \bar{\mathbf{K}}, however, is a nontrivial task that also requires new tricks.
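
    Continuing the toy discretization sketch from above (it assumes that A_bar, B_bar, C, x, and outputs are still defined there), we can check that the convolution with the kernel \bar{\mathbf{K}} indeed reproduces the recurrent outputs:

        # build the convolution kernel K_bar = (C B_bar, C A_bar B_bar, C A_bar^2 B_bar, ...)
        K_bar = np.array([(C @ np.linalg.matrix_power(A_bar, j) @ B_bar).item() for j in range(L)])

        # o_t = sum_j K_bar[j] * x_{t-j}: a causal 1D convolution of the input with K_bar
        conv_outputs = [sum(K_bar[j] * x[t - j] for j in range(t + 1)) for t in range(L)]
        print(np.allclose(outputs, conv_outputs))   # True: recurrence and convolution agree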

    But whatever the discretization formulas, the resulting RNN will not really work as intended. This is a classical approach that has been well known for decades, and, of course, people have tried to apply it to machine learning. But they always found this approach to lack long-term memory, because of vanishing and/or exploding gradients caused by all of this repeated matrix multiplication, while long-term memory is precisely the point of having a recurrent network in the first place.

    To add long-term memory, we need one more technique developed by Gu et al. (2020): we need to replace the matrix \mathbf{A} with the so-called “HiPPO matrix”, where HiPPO stands for high-order polynomial projection operators. The HiPPO approach begins with a different question: how do we compress the entire history of an input function f, namely f_{\le t} = f(x)|_{x\le t}, into a functional representation? The core idea is to approximate the function f_{\le t} by projecting it onto a space spanned by orthogonal polynomials. With this approach, HiPPO can handle long-range dependencies without needing explicit priors on the timescale, which is crucial for data with unknown or variable temporal scales.

    Without going into too much mathematical detail (for that, see the original paper), HiPPO operates as follows: for a function f whose current history f_{\le t} = f(x)|_{x\le t} we want to represent,

    • define approximation quality in the space of (square integrable) functions via a probability measure μ; this measure can be used to give recent information more weight than past history (or not);
    • choose the approximation order N and choose a polynomial basis of degree N; HiPPO usually works with either Legendre polynomials and a uniform measure on the history (HiPPO-LegS) or Laguerre polynomials and an exponentially decaying measure (HiPPO-LagT);
    • find the optimal approximation, i.e., find the coefficients of a polynomial g in the chosen basis that minimize the approximation error

          \[\|f_{\le t} - g\|_{L_2(\mu)} \longrightarrow_g \min;\]

    • the whole point of HiPPO is that one can construct a differential equation to maintain these coefficients incrementally; for a vector of coefficients \mathbf{c}(t), you can write down matrices \mathbf{A}(t) and \mathbf{B}(t) such that

          \[\dot{\mathbf{c}}(t) = \mathbf{A}(t)\mathbf{c}(t) + \mathbf{B}(t)f(t);\]

    • and finally, this differential equation can also be discretized to find a recurrence on the polynomial coefficients for the optimal approximation of a discrete time series f_k:

          \[\mathbf{c}_{k+1} = \mathbf{A}_k\mathbf{c}_k + \mathbf{B}_kf_k.\]

    Here is an illustration from the original paper that shows this sequence of steps:

    Gu et al. (2020) derive specific formulas for the HiPPO matrices. For their scaled Legendre measure (HiPPO-LegS) the matrix dynamics are

        \[\dot{\mathbf{c}}(t) = -\frac 1t \mathbf{A}\mathbf{c}(t) + \frac 1t\mathbf{B} f(t),\qquad \mathbf{c}_{k+1} = \left(1 - \frac{1}{k}\mathbf{A}\right)\mathbf{c}_k + \frac 1k\mathbf{B} f_k,\]

    where \mathbf{A} and \mathbf{B} are constant:

        \[A_{nk} = \begin{cases}\sqrt{(2n+1)(2k+1)}, & n>k, \\ n+1, & n=k, \\ 0, & n<k,\end{cases}\quad\text{e.g.},\quad\mathbf{A} = \left(\begin{matrix} 1 & 0 & 0 & 0 & 0 \\ \sqrt{3} & 2 & 0 & 0 & 0 \\ \sqrt{5} & \sqrt{3\cdot 5} & 3 & 0 & 0 \\ \sqrt{7} & \sqrt{3\cdot 7} & \sqrt{5\cdot 7} & 4 & 0 \\ 3 & 3\sqrt{3} & 3\sqrt{5} & 3\sqrt{7} & 5\end{matrix}\right),\]

    and B_n=\sqrt{2n+1}.
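
    The HiPPO-LegS matrices are easy to construct in code; here is a short sketch that reproduces the 5\times 5 example above:

        import numpy as np

        def hippo_legs(N):
            """Construct the HiPPO-LegS transition matrix A and input vector B of size N."""
            A = np.zeros((N, N))
            for n in range(N):
                for k in range(N):
                    if n > k:
                        A[n, k] = np.sqrt((2 * n + 1) * (2 * k + 1))
                    elif n == k:
                        A[n, k] = n + 1
            B = np.sqrt(2 * np.arange(N) + 1.0)
            return A, B

        A, B = hippo_legs(5)
        print(np.round(A, 3))   # matches the 5x5 matrix written out above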

    That was quite a lot of math that’s very different from what we are used to here—but bear with me, we are back to machine learning territory. At this point, we have a method that can take a time series as input and produce a good vector representation for its entire history; moreover, the method reduces to using a couple of matrices whose coefficients can be updated recursively with time too. This means that we can, for example, plug HiPPO into a regular RNN, adding another state \mathbf{c}_t and replacing the hidden state \mathbf{h}_t with a representation of its entire history; this has been done in the original paper on HiPPO as follows, for an arbitrary RNN update:

    In SSMs, the HiPPO matrix is used to initialize the transition matrix \mathbf{A}, significantly alleviating the problem of long-range dependencies. It may sound a little strange because as soon as we begin updating the weights, the matrix \mathbf{A} loses its HiPPO properties: it no longer corresponds to the Legendre or Laguerre polynomials, or to any orthogonal basis in the functional space at all. However, experiments show that this initialization does help a lot with implementing long-term memory.

    The second problem we need to solve is computational complexity: so far, SSMs require repeated multiplication by the discretized version of \mathbf{A}, so the naive complexity is O(d^2L), where d is the hidden state dimension and L is the sequence length. The main contribution of the S4 model (structured state space sequence model) introduced by Gu et al. (2022) is a much faster way to compute all views of the SSM model, i.e., both recurrent matrices used at inference and convolutions used at training. The ideas of S4 would be way too mathy to put in this post; fortunately, I can refer to “The Annotated S4”, a detailed post by the S4 authors that shows all derivations and also provides the corresponding PyTorch code and illustrations. For now, let us just assume that all of the above can be done efficiently.

    The next step was taken by Smith et al. (2022) who moved from single-input, single-output SSM layers to multi-input, multi-output layers, allowing \mathbf{x}_t and \mathbf{o}_t to become vectors; their model is known as S5 (simplified structured state space for sequence modeling).

    With this, we finally come to Mamba (Gu, Dao, 2024), also known as S6 (S4 + selective scan). The main step forward in Mamba is recognizing that so far, the model dynamics have had to be constant: matrices \mathbf{A}, \mathbf{B}, \mathbf{C}, and step size Δ can be trainable from mini-batch to mini-batch but they cannot depend on the input \mathbf{x}_t; otherwise, we wouldn’t be able to implement the convolutional kernel \bar{\mathbf{K}} which is key to efficient training. This significantly limits the expressive power of S4: its mechanism cannot do content-aware reasoning, it cannot choose which parts of \mathbf{x}_t are more important and filter out the rest, and so on.

    Gu and Dao (2024) introduce the selective scan algorithm that lets \mathbf{B}, \mathbf{C}, and Δ (not \mathbf{A}, though) depend on \mathbf{x}_t while still providing an efficient algorithm for training. In essence, they find a middle ground between the two extremes:

    • in RNNs and S4, the state has a (relatively small) fixed size so we cannot fit too much in the hidden state, leading to problems with long-term memory;
    • in Transformers, the state is basically the entire sequence, so there is no memorization problem (you have direct access to everything) but lots of problems with processing long sequences (that we have been discussing today and in a previous post);
    • the word “selective” in “selective scan” means that Mamba chooses which information to put in a state, with context-dependent mechanisms for putting something into the hidden state and ignoring other parts of the input.

    Again, the technical details of the algorithm are too involved for this post—it even makes use of hardware optimization, being specifically tailored for GPUs and TPUs. But the result is the Mamba block that can be stacked in a neural network. It includes the following selective state space model as a replacement for the attention mechanism:

    Mamba was big news: a viable alternative to Transformers that even outperformed existing open source language models with an equivalent number of parameters. So it is no wonder that researchers picked up this idea and ran with it, with a lot of papers already extending and improving upon the basic Mamba architecture.

    For example (I’m only listing some of the most interesting ones):

    • Mamba was never limited to language modeling; the original paper already applied Mamba to audio processing and modeling genomic sequences; Vision Mamba (ViM; Zhu et al., 2024) is a good representative of how Mamba can be applied to image processing; they show improved results with an architecture very similar to the Vision Transformer (ViT; Dosovitskiy et al., 2020) but based on Mamba blocks; another way to process images has been suggested in the VMamba architecture (Liu et al., 2024), which is an interesting combination of CNNs and Mamba;
    • U-Mamba (Ma et al., 2024) goes even further and shows that Mamba is not limited to Transformer-like architectures: this is a U-Net-based architecture intended for biomedical image segmentation, and the authors design a CNN-SSM block, a hybrid between convolutions and Mamba, which improves segmentation results;
    • among more advanced versions of image segmentation, SegMamba (Xing et al., 2024) considers 3D image segmentation while Video Vision Mamba (ViViM; Yang et al., 2024) does segmentation in video, and MambaMorph (Guo et al., 2024) uses a Mamba-based architecture to establish the correspondence between two important biomedical modalities, MR and CT scans;
    • MoE-Mamba (Pioro et al., 2024) adds the mixture of experts (MoE) idea to a Mamba block, leading to a much more efficient architecture; MoE variations of Transformers and other models are a separate can of worms that I plan to open in some future post.

    As you can see, the ideas of Mamba have been actively developed by the deep learning community over the last year… actually, no, you don’t see the full extent of it yet. I introduced a hidden constraint here: the original Mamba paper was first published in December 2023, and all the papers cited in the list above are from January 2024! In only a month, Mamba already became a staple of deep learning, and by now, a survey by Qu et al. (2024, last revised in mid-October) has 244 citations—not all of them are Mamba-based models, of course, but it looks like over a hundred, if not more, are Mamba variations published in 2024. 

    This is the crazy research landscape we are living in now, and, of course, I cannot give a full survey here, so I will only highlight a direct continuation: Mamba 2 (Dao, Gu, 2024), developed by the authors of the original, dives further into the Mamba algorithm and makes it even more efficient with its state space duality (SSD) framework. It very much looks like Mamba-based models are reliably beating Transformers in many long-context tasks, combining the efficiency of linear attention with the structured adaptability of SSMs.

    Conclusion

    Linear attention and state space models like Mamba represent a new wave of more efficient models that alleviate the quadratic complexity problem of basic self-attention. These models revisit foundational ideas from RNNs and associative memory but also redefine how we think about integrating memory and content-aware reasoning into neural architectures. They are already pushing the boundaries of scalable and content-aware sequence modeling, and this research direction is far from completely explored.

    In this post, we have discussed the basic ideas of linear attention; I have tried to explain the foundations of these models—the kernel trick, associative memory, state space models—that date back a long time. This is another case where recent results can be placed in the context of a machine learning timeline that dates back many decades; here is my take on the timeline of the main ideas we mentioned today:

    Once these ideas get picked up in a new form, such as Mamba, progress starts anew, and these days it proceeds at a breakneck pace. I hope that this post gives a clear understanding that this is still very much a work in progress, and new results will probably augment these ideas in the near future. Existing results already suggest many exciting applications: not only improved language modeling but also applications to genomics, image processing, audio processing, and more have already been explored in Mamba-like models.

    Moreover, we can already look ahead a little. State space models, kernel-based attention, and hardware-aware optimizations in Mamba hint at a future where memory-intensive applications such as long-context language modeling and large-scale genomic analysis are not only feasible but practical. In this future, neural networks may be able to dynamically tailor their computation to the input; perhaps we are witnessing the birth of a new paradigm for sequence modeling.

    As research in Mamba and its successors continues, we are also likely to see further breakthroughs in one of the most important issues that still remains to be solved: how can neural networks manage and process memory? In my opinion, memory is still an unresolved challenge; increasing the context size is not the same as having a working memory, but the selective state space models developed in Mamba actually come much closer. I am very excited to see what the next step will be.

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • Kolmogorov-Arnold Networks: KAN You Make It Work?

    Kolmogorov-Arnold Networks: KAN You Make It Work?

    Although deep learning is a very new branch of computer science, foundations of neural networks have been in place since the 1950s: we have been training directed graphs composed of artificial neurons (perceptrons), and each individual neuron has always looked like a linear combination of inputs followed by a nonlinear function like ReLU. In April 2024, a new paradigm emerged: Kolmogorov-Arnold networks (KAN) work on a different theoretical basis and promise not only a better fit for the data but also much improved interpretability and an ability to cross over to symbolic discoveries. In this post, we discuss this paradigm, what the main differences are, and where KAN can get us right now.

    Connectionism: the Foundation of Neural Networks

    One surprising feature of artificial neural networks is that they basically have not changed since 1943, when Warren McCulloch and Walter Pitts published their seminal paper, “A logical calculus of the ideas immanent in nervous activity”. Already in this paper, before even the Turing test, let alone the Dartmouth seminar (1956) and the first machine learning models, neurons were modeled as linear combinations of inputs with nonlinear activation functions:

    This is indeed a very reasonable approximation to what real neurons do, if you try to measure the frequency of spikes as a function of the neuron’s inputs. The function h has to be nonlinear if you want to have a network of neurons, because otherwise the whole network would be a composition of linear functions, so it would be just equivalent to a single neuron.

    In the early days of artificial intelligence, the nonlinearity h was usually the threshold function (Heaviside step function for the mathematically inclined): 1 if the input exceeds some threshold a and 0 otherwise. Later, researchers realized that you can’t really train a deep network with threshold activation functions: their derivatives are zero almost everywhere, and gradient descent does not work, so they switched to sigmoidal functions such as the logistic sigmoid and the hyperbolic tangent, which in essence represent “soft” differentiable versions of the threshold. Later yet, ReLU showed that just a little bit of nonlinearity suffices, and by now we also have functions found by automated search such as the Swish (Ramachandran et al., 2017) or its generalization, the ACON family (Ma et al., 2020).

    I will not go into more details here. The important thing for us now is that throughout the entire history of neural networks, only the exact form of the nonlinearity has changed. The basic construction of neural networks has remained the same: it is a huge composition of neurons, and each neuron is a linear combination of inputs followed by a nonlinearity. There exist other types of nodes in the computation graph—for example, the batch normalization layer also has trainable weights but it is a different function—but the vast majority of neurons in any modern network look like the picture above. For example, the self-attention layer in a Transformer, which we have discussed previously, does quite a few interesting things with queries, keys, and values, but these vectors are still linear combinations of input embeddings with trainable coefficients.

    This idea is known as connectionism: a large network of small simple units can represent very complex things in combination. Philosophically speaking, connectionism makes claims about cognitive processes, saying that our mental phenomena also can arise from a large composition of simple individual neurons. Here, connectionism has been historically at odds with computationalism, which says that the mind works by conducting formal operations on symbols, like an abstract computer (a Turing machine of sorts). There is no direct contradiction between the two—formal operations can be implemented on a large network of units—but there was still an interesting debate left: how would a connectionist theory of mind explain the logical properties of human cognition such as systematic relations in language cognition or compositionality of mental representations? There exist interesting answers to this question, and I will leave a couple of links to book overviews for the interested reader (Bechtel, Abrahamsen, 2002; Marcus, 2003; Maurer, 2021).

    We will return to connectionism vs. computationalism later, but, fortunately, we do not have to dive deep into the philosophical or neurobiological aspects of this debate. All that matters to us, lowly computer scientists, are the mathematical and algorithmic sides of the question: what kinds of functions can one represent with compositions of simple ones, which “simple” functions are needed exactly, and how can we find them and the necessary composition?

    Universal Approximation Theorems

    Neural networks work because even while each neuron is a very simple construction, their compositions can approximate any (continuous) function, with any given precision. Results of this class are known as universal approximation theorems. Specifically for neural networks, several such results were obtained in the late 1980s. In 1988, George Cybenko proved that neural networks with a single hidden layer and sigmoidal activations can approximate any continuous function (Cybenko, 1988). Concurrently with him, Hornik et al. (1989) developed a more general treatment, showing that feedforward networks with one hidden layer can approximate any real-valued continuous function over a compact set, even extending the result to measurable functions. This result was shown for “squashing” activation functions, that is, sigmoids—non-decreasing functions that go from 0 on the left to 1 on the right—but later Hornik (1991) extended it to other classes of activation functions.

    Note that these are classical existence results that give a nice reassurance that approximations exist but do not actually guarantee that you can find them in reasonable time. Moreover, they do not constrain the size of the approximating neural network, and indeed, to approximate a complicated function with a network with a single hidden layer you might need exponentially many neurons.

    There exists an entire research direction proving extensions of these results to neural networks with various numbers of neurons, to deeper networks, bounding the approximation errors and so on, with many interesting and beautiful mathematical results. For example:

    • Maiorov and Pinkus (1999) constructed a continuous activation function that actually achieves the lower bounds on approximation error for a feedforward network (meaning that they have an explicit construction for this function, but it would be, of course, utterly impractical to actually use it);
    • Gripenberg (2003) showed that instead of making a single layer wider you can have a bounded number of neurons on each layer and grow the layers to infinity, still getting a universal approximator;
    • Lu et al. (2017) showed that there is a whole hierarchy of trade-offs between width and depth: there are classes of functions that can be realized by deep networks but cannot be realized by more shallow networks without an exponential number of neurons;
    • Yarotsky (2017) concentrated on networks with ReLU activations and provided specific bounds on their approximation errors, again comparing networks of different depth; in the same vein, Hanin and Sellke (2017) found the minimal width necessary for a deep network to be a universal approximator, again for ReLU activations, and later Shen et al. (2022) proved a series of tight bounds on the approximation rate of ReLU networks;
    • Guliyev and Ismailov (2018) constructed a two-layer feedforward network with 3d+2 hidden neurons in total with fixed weights that can approximate any continuous d-variable function with arbitrary precision, and so on, and so forth.

    But for today’s topic, instead of going further to recent results and considering the state of the art in universal approximation, we need to take a step back into the 1950s.

    Kolmogorov–Arnold representation theorem

    Andrey Kolmogorov was one of the most prolific mathematicians of all time, on par with Euler and Gauss. He introduced the modern axiomatics of probability theory, generalized the law of large numbers, introduced a new notion of an algorithm and founded the theory of Kolmogorov complexity, made foundational contributions to the modern theory of dynamical systems with the famous Kolmogorov–Arnold–Moser (KAM) theorem, and much more. He revolutionized Soviet mathematical education, establishing some of the best mathematical schools in the world. While sometimes he had to do some questionable things (e.g., he participated in the campaign against his former teacher Nikolai Luzin in the 1930s), he actually managed to navigate the Soviet ideological landscape perfectly, never losing his integrity and protecting other people whenever he could (see “The Kolmogorov Option” by Scott Aaronson).

    Vladimir Arnold was a pupil of Kolmogorov and a brilliant mathematician in his own right. Compared to Kolmogorov, Arnold gravitated more towards the continuous side of mathematics related to physics, including dynamical systems, stability theory and the above-mentioned KAM theorem, catastrophe theory, fluid dynamics, and much more; in “pure” mathematics Arnold worked in algebraic geometry and topology, also always trying to connect pure mathematics with real world applications. Like his teacher, Arnold was a key figure in Soviet mathematical education, authoring many textbooks and popular texts. He was very annoyed by the formal style of mathematical education originating in the writings of the Bourbaki group, and always advocated for an education that would provide a deeper understanding of the studied phenomena and connect the dots in different fields whenever possible.

    Kolmogorov and Arnold collaborated a lot, especially in the early stages of Arnold’s career when he was Kolmogorov’s student. The theorem we are interested in was published in 1957, when Arnold was only 20 years old. It says that any continuous function f of n variables can be represented as the following composition:

        \[f(\mathbf{x})=f(x_1,\ldots,x_n)=\sum_{i=0}^{2n}\Phi_i\left(\sum_{j=1}^n\phi_{i,j}(x_j)\right).\]

    Here \Phi_i and \phi_{i,j} are continuous functions of a single variable, and the theorem says that this is enough: to represent any continuous multivariate function, you only need sums of univariate functions in a two-layered composition:

    This means that if you need to represent a multivariate function of high input dimension, which is what any machine learning model is doing, it is sufficient to find several functions of one variable. The only multivariate function you need is the sum; the rest can be pushed into univariate components.
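
    To get a feel for how multivariate structure can be pushed into univariate functions and sums, here is a one-line example of my own (this is not how the theorem's proof works, but it shows the flavor of the trick): even multiplication, the prototypical multivariate operation, can be written as

        \[xy = \frac{1}{4}\left((x+y)^2-(x-y)^2\right) = \Phi_1(x+y)+\Phi_2(x-y),\qquad \Phi_1(t)=\frac{t^2}{4},\quad \Phi_2(t)=-\frac{t^2}{4},\]

    that is, as a sum of univariate functions applied to sums of the inputs, which is exactly the shape prescribed by the Kolmogorov–Arnold representation.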

    If you think about it in terms of learning the functions, it means that the Kolmogorov–Arnold theorem gives a way around the curse of dimensionality, the reason why machine learning is hard and often counterintuitive. In low dimensions, every machine learning problem is easy: you can take any integral numerically, nearest neighbors are indeed close by and relevant, you can cover a reasonably sized part of the space with samples from the uniform or normal distribution—everything works great. In high dimensions, volumes grow exponentially, nearest neighbors grow apart, and integrals become impossible to find with any reasonable accuracy; this is exactly the reason why machine learning needs strong assumptions and complicated methods such as MCMC sampling. The term dates back to Richard Bellman (1957) who noted that dynamic programming also becomes very computationally hard in high dimensions; for a more detailed discussion of the curse of dimensionality, see, e.g., Marimont and Shapiro (1979) and Beyer et al. (1999).

    Moreover, the theorem also gives you the exact number of functions in each sum: the summation limits n and 2n refer to the same n, the dimension of \mathbf{x}. Compare this to exponential bounds on the number of neurons with sigmoidal activations that we mentioned earlier, and the Kolmogorov–Arnold theorem begins to sound quite tempting to use in machine learning, right? Unfortunately, the theorem itself does not give you any idea of how to find the functions \Phi_i and \phi_{i,j}; we will discuss this problem in the next section.

    I will close this section with another interesting mathematical tidbit. As you have probably heard, in 1900 David Hilbert, a great mathematician and a founder of mathematical logic, compiled a list of 23 so-called Hilbert’s problems. They were unsolved problems that Hilbert believed to be important for the development of mathematics, and his intuition was completely right: although some problems turned out to be either relatively easy or too vague to judge, many of them led to the development of entire new fields of mathematics. One of the still unsolved Hilbert problems, proving the Riemann hypothesis, also made the list of the Millennium Prize Problems by the Clay Mathematics Institute, an update on Hilbert’s idea for the new century.

    As it turns out, the Kolmogorov–Arnold representation theorem arguably solves one of Hilbert’s problems, namely the thirteenth problem. It was already known to Hilbert that any seventh-degree equation can be reduced to the form

        \[x^7+ax^3+bx^2+cx+1=0.\]

    It seemed to be impossible to reduce this equation further, so it was a tempting hypothesis for Hilbert that the seventh-degree equation gives you an irreducible algebraic function, in the sense that you cannot reduce it to a superposition of functions of fewer variables. Hilbert’s Thirteenth Problem is thus as follows:

    Consider a seventh-degree equation in the reduced form as above. Its solution is a function of three variables: the coefficients a, b, and c. Can it be expressed as a composition of a finite number of two-variable functions?

    This problem was originally posed for algebraic functions, i.e., functions that can be defined as a root of a polynomial equation, but later Hilbert asked a version of this problem about arbitrary continuous functions. Kolmogorov and Arnold were actually working on this problem, and they solved it in several steps, gradually reducing the number of variables required for elementary functions: first Kolmogorov showed that any continuous function can be represented as a composition of functions of three variables, then his student Arnold reduced it to two (already solving Hilbert’s problem), and then came their main theorem.

    For the continuous version of Hilbert’s Thirteenth, the Kolmogorov–Arnold representation theorem is actually overkill: it turns out that we only need arbitrary continuous functions of one variable and addition, which is technically a two-variable function. Note, however, that the algebraic version still remains unsolved: Arnold himself returned to it later with Goro Shimura (in the proceedings of a 1976 symposium on the legacy of Hilbert’s problems), but the efforts of mathematicians have not been successful so far (Vitushkin, 2004).

    From an Existence Proof to Practical Algorithms: Splines

    As we have seen above, the Kolmogorov–Arnold representation theorem is purely existential; the original result does not give you a good way to find the univariate functions. There have been several attempts to give a constructive proof that would provide an algorithm for finding these functions (Sprecher, 1965; Sprecher, 1972; Braun, Griebel, 2009), but these attempts were hardly practical from a machine learning standpoint.

    Moreover, there have been negative results showing that univariate functions in the theorem can be very complex, even fractal, and learning them can be very hard; one of these results even says “Kolmogorov’s theorem is irrelevant” right in the title (Girosi, Poggio, 1989). So how could the KAN approach find univariate functions in the Kolmogorov-Arnold representation efficiently?

    There have been earlier approaches to build neural networks based on the Kolmogorov–Arnold representation theorem. Hecht-Nielsen (1987) and Lin and Unbehauen (1993) noted that the theorem, specifically in its constructive form by Sprecher (1965), leads to natural constructions of three-layer neural networks; see also (Sprecher, Draghici, 2002). Köppen (2002) developed an algorithm for learning the Sprecher representation. 

    However, these algorithms remained impractical for two reasons: first, because learning the univariate functions was still too hard, and second, because the shallow three-layer architecture required special algorithms to train and could not be trained by simple gradient descent. Let us begin with the first constraint.

    To make univariate functions easier to learn, various approaches to making the Kolmogorov–Arnold representation theorem practical centered on splines. Splines (Bartels et al., 1995; Shumaker, 2015) are piecewise polynomial functions used to approximate or interpolate other functions. The key idea is that instead of using a single complex function to fit all data across the entire domain, a spline breaks the domain into smaller intervals and fits a much simpler function (usually a low-degree polynomial) to each interval. These polynomial pieces are smoothly connected at certain points called knots.

    If we use splines for univariate polynomial regression, formally it means that we consider the interval [a,b] where the data points lie and break it down with intermediate points t_i, the knots, getting k intervals:

        \begin{align*}a&=t_0\le t_1\le t_2\le\ldots\le t_{k-1}\le t_k = b,\\ [a,b]&=[a=t_0,t_1)\cup [t_1,t_2)\cup\ldots\cup[t_{k-1},t_k=b].\end{align*}

    The task is to find a polynomial p_i of degree d on each interval, p_i:[t_i, t_{i+1}]\to\mathbb{R}, so that:

    • the entire collection of polynomials minimizes some loss function for the data, usually the sum of squared residuals for all data points (x_n, y_n):

          \[\sum_{i=0}^{k-1}\sum_{n:x_n\in[t_i,t_{i+1})}\left(y_n-p_i(x_n)\right)^2\longrightarrow\min;\]

    • the polynomials come together smoothly, with their values and derivatives matching at the intermediate knot points; usually a spline of degree d would require all derivatives up to the (d-1)-th to match: for all i from 1 to k-1

          \[p_{i-1}(t_i)=p_i(t_i), \frac{dp_{i-1}}{dx}(t_i)=\frac{dp_{i}}{dx}(t_i),\ldots,\frac{d^{d-1}p_{i-1}}{dx^{d-1}}(t_i)=\frac{d^{d-1}p_{i}}{dx^{d-1}}(t_i).\]

    The main difference between splines and just piecewise interpolation lies in this last condition: splines impose additional constraints to make the connections continuous and even smooth. For example, if I plot three segments of data and try to learn quadratic or cubic regression on each, the results will follow the data but it will be three independent discontinuous curves (thin curves in the plot below). A spline regression would make the curves meet each other in the knots, and, moreover, meet each other smoothly, with matching derivatives (thick curves in the plots below):

    There is a rich field of applied mathematics on splines; they are often used for interpolation, i.e., to make a smooth curve that goes near all given points rather than approximating a least squares polynomial regression. The splines shown above are actually learned linear combinations of B-splines, i.e., polynomials that can serve as basis functions for splines of a given degree; there exist algorithms to compute B-splines for a given degree and knot points (Gordon, Riesenfeld, 1974; de Boor, 1977; Lee, 1982); these algorithms, in particular the de Boor–Cox iteration (de Boor, 1972; Cox, 1972), are relatively efficient but become computationally hard for large numbers of knots, and we will return to this discussion later. This is also adjacent to the discussion of Bézier curves that are a special case of B-splines. I will not go into more details about splines and will refer to the numerous existing books and materials on the subject (Bartels et al., 1995; Gallier, 1999; Shumaker, 2015; Hovey, 2022).
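
    Before moving on, here is a quick hands-on illustration (a sketch with synthetic data; the SciPy calls are standard spline-fitting routines and have nothing to do with KANs yet): fit a cubic smoothing spline to noisy samples and evaluate the resulting linear combination of B-splines on a dense grid.

    ```python
    import numpy as np
    from scipy.interpolate import splrep, splev

    # Synthetic noisy samples of a smooth function.
    rng = np.random.default_rng(0)
    x = np.linspace(0.0, 10.0, 60)
    y = np.sin(x) + 0.1 * rng.standard_normal(x.size)

    # Fit a cubic (k=3) smoothing spline; `s` trades off closeness to the data
    # against smoothness (s=0 would interpolate every data point exactly).
    tck = splrep(x, y, k=3, s=0.5)

    # The result is a knot vector and B-spline coefficients; evaluate the spline
    # (a linear combination of B-splines) on a denser grid.
    x_dense = np.linspace(0.0, 10.0, 500)
    y_dense = splev(x_dense, tck)
    print(len(tck[0]), "knots,", len(tck[1]), "coefficients, degree", tck[2])
    ```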

    Splines provide a natural algorithm to learn smooth functions in a very expressive way; by changing the degree and number of knots we can freely change the number of parameters in a polynomial spline, from a piecewise linear function up to literally an interpolation polynomial for the data points. However, splines become much harder to use in high dimensions, so it would not be a good idea to replace regression models with splines. But the Kolmogorov-Arnold theorem removes the need for high dimensions altogether! Therefore, it is no wonder that splines caught the attention of researchers looking for efficient univariate functions to use in the Kolmogorov-Arnold representation.

    Leni et al. (2013) developed what was called a Kolmogorov spline network. Fakhoury et al. (2022) presented ExSpliNet, a neural network architecture based on B-splines. An even more interesting direction would be to change activation functions into learnable splines: after all, ReLU is just a linear spline with two components. Campolucci et al. (1996) and Guarneri et al. (1999) explored this idea back in the 1990s, and they were already building upon Jerome Friedman’s adaptive spline networks (Friedman, 1991). More recently, this approach has been developed by Scardapane et al. (2018) and Bohra et al. (2020).

    But somehow, these ideas did not make any splash in the deep learning world until very recently. Kolmogorov–Arnold networks also use learnable activation functions based on splines. What is the difference here, what are the new ideas, and why have KANs attracted significant attention in 2024?

    Kolmogorov–Arnold networks

    The first KAN paper (Liu et al., 2024) begins with a natural question: “are multilayer perceptrons the best nonlinear regressors we can build?” They begin by comparing the Kolmogorov–Arnold representation with usual multilayer perceptrons and note that they can make the former deeper and/or wider than the theorem suggests, which may help make individual univariate functions simpler. Here is the teaser image by Liu et al. (2024) that makes this comparison:

    They define a “KAN layer” with n inputs and m outputs as an n\times m matrix \boldsymbol{\Phi} of one-dimensional functions with trainable parameters: each output is the sum of the results of the n functions in the corresponding row applied to the inputs. The original Kolmogorov–Arnold representation consists of two such layers: first, n inputs turn into 2n+1 intermediate values via the \phi_{i,j} functions, and then the \Phi_i combine these 2n+1 intermediate values into a single output. When you look at the representation like this, it becomes clear how to stack more such layers, making a deeper KAN that represents a composition of such matrices of functions:

        \[\mathrm{KAN}(\mathbf{x}) = \left(\boldsymbol{\Phi}_{L-1}\circ \boldsymbol{\Phi}_{L-2}\circ \ldots \circ \boldsymbol{\Phi}_1\circ \boldsymbol{\Phi}_{0}\right)\mathbf{x}.\]

    This is a very general representation; for example, a feedforward neural network (a multilayer perceptron) can also be represented as a KAN with linear functions (weight matrices) interleaved with activation functions (applied componentwise, so in the notation above it would be a diagonal matrix of functions):

        \[\mathrm{MLP}(\mathbf{x}) = \left(\mathbf{W}_{L-1}\circ h\circ \mathbf{W}_{L-2}\circ h \circ\ldots \circ \mathbf{W}_1\circ h\circ \mathbf{W}_{0}\right)\mathbf{x}.\]

    On each KAN layer, every transformation function \phi is introduced by Liu et al. (2024) as a weighted sum of a basis function b and a spline function s,

        \[\phi(x) = w_b\cdot b(x) + w_s\cdot s(x),\]

    where b is the sigmoid linear unit (SiLU) activation function and s is a B-spline:

        \[b(x) = \frac{x}{1+e^{-x}},\qquad s(x)=\sum_i c_iB_i(x).\]
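
    To make this construction concrete, here is a minimal PyTorch sketch of a single KAN edge function \phi(x)=w_b\cdot b(x)+w_s\cdot s(x), with the B-spline basis computed by the Cox–de Boor recursion on a uniform grid. This is my own simplified illustration rather than the official implementation, and it omits details such as grid updates and initialization tricks.

    ```python
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def bspline_basis(x, grid, k):
        """Cox-de Boor recursion: degree-k B-spline bases on the knot vector `grid`.
        x: (N,), grid: (G + 2k + 1,) uniform knots; returns an (N, G + k) matrix."""
        x = x.unsqueeze(-1)
        # Degree-0 bases are indicators of the knot intervals.
        bases = ((x >= grid[:-1]) & (x < grid[1:])).to(x.dtype)
        for d in range(1, k + 1):
            left = (x - grid[:-(d + 1)]) / (grid[d:-1] - grid[:-(d + 1)]) * bases[..., :-1]
            right = (grid[d + 1:] - x) / (grid[d + 1:] - grid[1:-d]) * bases[..., 1:]
            bases = left + right
        return bases

    class KANEdge(nn.Module):
        """One learnable univariate function phi(x) = w_b * SiLU(x) + w_s * sum_i c_i B_i(x)."""
        def __init__(self, G=5, k=3, x_min=-1.0, x_max=1.0):
            super().__init__()
            h = (x_max - x_min) / G
            # Extend the uniform grid by k knots on each side for full support at the edges.
            self.register_buffer("grid", torch.arange(-k, G + k + 1) * h + x_min)
            self.k = k
            self.coef = nn.Parameter(0.1 * torch.randn(G + k))  # spline coefficients c_i
            self.w_b = nn.Parameter(torch.ones(1))
            self.w_s = nn.Parameter(torch.ones(1))

        def forward(self, x):  # x: (N,)
            s = bspline_basis(x, self.grid, self.k) @ self.coef
            return self.w_b * F.silu(x) + self.w_s * s

    phi = KANEdge()
    print(phi(torch.rand(10) * 2 - 1).shape)  # torch.Size([10])
    ```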

    As a result, every KAN layer has O(n^2(G+k)) parameters, where n is the number of inputs and outputs, k is the degree of the spline, and G is the number of knots. Liu et al. (2024) discuss a lot of design choices and other aspects of KANs but at this point, let us proceed to an example, which I adapted from this tutorial.

    Let us begin with a nontrivial function that we want to approximate by a KAN; let’s take one of the functions used in the original paper:

        \[f(x,y) = \exp\left(\sin(\pi x) + y^2\right).\]

    To get a feeling for what this function looks like, the figure below shows the heatmap for f(x, y) and several one-dimensional slices of this function in both directions:

    To train a KAN for this function, first we need to set up its structure; let’s mimic the structure of the function and set up a [2, 2, 1] KAN, i.e., a composition of the form

        \[{\hat f}(x, y) = \phi_{2,1}(\phi_{1,1}(x) + \phi_{1,2}(y)).\]
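
    The actual experiment below follows the tutorial, which relies on the KAN implementation released with the original paper; I will not reproduce its exact API here. Instead, here is a self-contained PyTorch stand-in for the same composition, with each univariate function parameterized by a tiny MLP rather than a spline (my simplification), just to show the structure and the training loop; the interpretability workflow of plotting each learned univariate curve is the same.

    ```python
    import torch
    import torch.nn as nn

    def univariate():
        # A tiny MLP as a generic learnable univariate function (a stand-in for a spline edge).
        return nn.Sequential(nn.Linear(1, 32), nn.SiLU(), nn.Linear(32, 1))

    phi_11, phi_12, phi_21 = univariate(), univariate(), univariate()

    def f_hat(x, y):
        # \hat f(x, y) = phi_{2,1}( phi_{1,1}(x) + phi_{1,2}(y) )
        return phi_21(phi_11(x) + phi_12(y))

    # Training data sampled from the target f(x, y) = exp(sin(pi x) + y^2) on [-1, 1]^2.
    torch.manual_seed(0)
    xy = torch.rand(4096, 2) * 2 - 1
    target = torch.exp(torch.sin(torch.pi * xy[:, :1]) + xy[:, 1:] ** 2)

    params = [p for m in (phi_11, phi_12, phi_21) for p in m.parameters()]
    opt = torch.optim.Adam(params, lr=1e-3)
    for step in range(5000):
        opt.zero_grad()
        loss = ((f_hat(xy[:, :1], xy[:, 1:]) - target) ** 2).mean()
        loss.backward()
        opt.step()
    print("final train MSE:", loss.item())

    # Inspect the learned univariate pieces on a grid (this is what the plots below show for KANs).
    grid = torch.linspace(-1, 1, 100).unsqueeze(1)
    with torch.no_grad():
        curve_11, curve_12 = phi_11(grid), phi_12(grid)  # expected to resemble sin(pi x) and y^2
    ```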

    After training, we get a rather small loss on the test set (produced by the same function), and the following learned functions:

    As you can see, \phi_{1,1} indeed resembles a sinusoidal function, and \phi_{1,2} looks suspiciously quadratic. Even at this point, we see that KAN not only can train good approximations to complicated functions but can also provide readily interpretable results: we can simply look at what kinds of functions have been trained inside the composition and have a pretty good idea of what kinds of features are being extracted.

    But we can do even better. Suppose that by looking at these plots, we have noticed that \phi_{1,2} is very similar to the quadratic function y^2 (actually, -y^2 in this case, but the minus sign is best left for the linear combination). KANs allow us to substitute our guess symbolically into \phi_{1,2}, fixing it to be \phi_{1,2}(y)=y^2 and training the rest of the functions. If we do that, we get a much better test set error, \phi_{1,1} will still look sinusoidal, and, most importantly, the resulting \phi_{2,1} will look much more like the exponent that it is in the original:

    So by now, we can also correctly infer the other functions in the composition. Doing this kind of symbolic reasoning requires an iterative process of looking at the functions and substituting some of them symbolically, but it still beats trying to analyze a multilayer perceptron by a very large margin. In practice, we will not know the correct form of the KAN composition but one can start with a larger KAN and reduce it, looking at what happens with the approximation error.

    Liu et al. (2024) suggested that this could be helpful for deriving complex dependencies in physics or applied math, when you need to explain experimental data with a formula; they show several interesting examples related to learning symbolic formulas for quantum physics (to be honest, I am very, very far from an expert on quantum physics so I will not attempt to explain the physics part here). They called KAN “a language model for AI + Science” and even provided a decision tree for choosing between KANs and regular MLPs in applications:

    In other words, KANs were suggested as helper models for semi-automated learning in science and generally situations when you would like to obtain a symbolic formula as a result.

    Making KANs Efficient: FastKAN and ReLU-KAN

    The original paper by Liu et al. (2024) was posted on arXiv on April 30, 2024. It was immediately noticed and received some coverage, but the initial impression was that KAN applications are very limited due to their high computational complexity. The problem is that in order to construct the degree 3 B-splines that KANs are based on, you have to run the above-mentioned de Boor–Cox iteration, which becomes a significant computational bottleneck for KANs, especially when rescaling the spline grids.

    In less than two weeks, on May 10, 2024, Ziyao Li uploaded a three-page preprint to arXiv where he introduced FastKAN, a method that achieves equivalent results while having about 3x faster forward propagation. His idea is that B-splines are basically equivalent to Gaussian radial basis functions (RBF), a traditional way to extract local features in machine learning. Training a one-dimensional linear regression with RBFs means that you are learning the weights of a linear combination of features, each of which depends on the distance between x and its center μi, with Gaussian RBFs decaying around μi like the density of a normal distribution:

        \[{\hat f}(x) = \sum_{i=1}^m w_i\phi(\|x-\mu_i\|),\qquad \phi(r) = e^{-c\cdot r^2}.\]

    Li (May 2024) showed that you can replace B-splines with Gaussian RBF functions and achieve significant performance improvements with basically no difference in the results. With this simple alteration, KANs suddenly became much more practical—another wonderful example of low-hanging fruit that one can find in deep learning even now (although nowadays you really have to be quick about it).
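
    Here is a minimal sketch of the idea (my own, not Li's actual FastKAN code): a KAN-style layer in which every edge function is a learned linear combination of fixed Gaussian RBF features, so the whole layer reduces to one batched feature expansion followed by a single linear map.

    ```python
    import torch
    import torch.nn as nn

    class RBFKANLayer(nn.Module):
        """KAN-style layer in the spirit of FastKAN: each input coordinate is expanded
        into Gaussian RBF features exp(-(x - mu_i)^2 / (2 h^2)) on a fixed grid of
        centers, and every output is a learned linear combination of these features."""
        def __init__(self, in_dim, out_dim, num_centers=8, x_min=-2.0, x_max=2.0):
            super().__init__()
            self.register_buffer("centers", torch.linspace(x_min, x_max, num_centers))
            self.h = (x_max - x_min) / (num_centers - 1)    # RBF width ~ grid spacing
            self.linear = nn.Linear(in_dim * num_centers, out_dim)

        def forward(self, x):                               # x: (batch, in_dim)
            feats = torch.exp(-((x.unsqueeze(-1) - self.centers) ** 2) / (2 * self.h ** 2))
            return self.linear(feats.flatten(start_dim=1))  # (batch, out_dim)

    # A small two-layer RBF-based KAN for 2-D inputs:
    model = nn.Sequential(RBFKANLayer(2, 8), RBFKANLayer(8, 1))
    print(model(torch.rand(16, 2) * 2 - 1).shape)           # torch.Size([16, 1])
    ```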

    But that’s not the end of the story, of course. A few weeks later, on June 4, 2024, Qiu et al. published a more detailed arXiv preprint that tried to alleviate the very same restriction. They replaced the B-spline basis functions with new basis functions composed of ReLU activations, specifically

        \[R_i(x) = \left(\mathrm{ReLU}(e_i-x)\cdot\mathrm{ReLU}(x-s_i)\right)^2\cdot 16/\left(e_i-s_i\right)^4.\]

    Here e_i and s_i are trainable parameters, which makes R_i a rather diverse family of functions, and \mathrm{ReLU}(x)=\max(x, 0) is the regular ReLU activation function; here is an illustration by Qiu et al. (2024):

    The main advantage of these basis functions is that they can be expressed via matrix operations such as matrix addition, dot products, and ReLU activation. This makes them much faster in practice than KANs based on B-splines; the authors report 5-20x improvements in training speed while also significantly improving the fitting accuracy.
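
    The basis itself is a one-liner; here is a small sketch that follows the formula above (with illustrative values for s_i and e_i) and shows that it needs nothing beyond elementwise operations and ReLU.

    ```python
    import torch
    import torch.nn.functional as F

    def relu_kan_basis(x, s, e):
        """R_i(x) = (ReLU(e_i - x) * ReLU(x - s_i))^2 * 16 / (e_i - s_i)^4:
        a smooth bump supported on [s_i, e_i] that peaks at 1 in the middle."""
        return (F.relu(e - x) * F.relu(x - s)) ** 2 * 16 / (e - s) ** 4

    x = torch.linspace(0.0, 1.0, 5)
    s, e = torch.tensor(0.0), torch.tensor(1.0)   # illustrative interval endpoints
    print(relu_kan_basis(x, s, e))                # 0 at the endpoints, 1 at the midpoint
    ```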

    So one month after the original paper on KANs, we already had two much more efficient versions that could be scaled further than a regular KAN and applied much wider. These were the papers that started the hype. Half a year later, where are we with KANs now?

    Recent developments in KAN: Architectures

    In any field of science, you expect that a new idea that can open up a new field of study will be gradually developed afterwards; at first, the authors of the model will try to milk it for new results, then other researchers will see the potential and join in, and if the idea is good, ultimately a subfield will arise. The main difference between almost any other field of science and deep learning is that while in “regular” computer science this process would take at least a few years, in deep learning it has already happened in several months. The original paper by Liu et al. (2024), posted on arXiv in April 2024, had already gathered over 250 citations by mid-October (Google Scholar), and a curated list of links about KANs notes over a hundred interesting papers and resources that directly continue this research. So while in June 2024 a comprehensive survey of KANs was possible (Hou, Zhang, 2024), now, just half a year after the original publication, it is already basically futile to try and review everything people have done in this direction; below, I will survey a few papers that look most interesting to me.

    Let us begin with improved architectures; I will note two works in more detail and give a brief survey of several others.

    Bodner et al. (June 2024) were the first to introduce Convolutional KANs, an architecture that combines the KAN approach with convolutional networks. But here I want to highlight the work by Yu et al. (October 2024) who introduce a special kind of Chebyshev polynomial-based KAN convolutions (see also Sidharth, Gokul, May 2024), which is a reformulation of KANs designed to extract features from patches of the input tensor, just like a CNN:

    Then Yu et al. start off with traditional convolutional architectures and add layers of these new convolutions with residual connections around classical CNN layers such as ResNet (left) or DenseNet (right):

    The resulting architecture, called Residual KAN (RKAN), shows performance improvements on classical datasets; the authors especially note that RKAN’s performance benefits grow with the complexity of the dataset and model size, suggesting that such feature extraction units can be beneficially added to many different architectures.

    Yang and Wang (September 2024) present the Kolmogorov–Arnold Transformer (KAT), a model that replaces MLP layers in transformers with Kolmogorov-Arnold Network (KAN) layers. Their main applications lie in computer vision tasks, so their teaser image shows ImageNet accuracy and compares KAT with ViT-like models:

    The idea is that while KANs are known for their parameter efficiency and can learn powerful and concise representations, it is challenging to integrate KANs into large-scale models such as the Transformer. The paper lists three core problems:

    • inefficiency of B-spline functions in computations on GPUs,
    • exponential growth of parameters in KANs, and
    • difficulties in initializing KAN weights for convergence in deep networks.

    To address these issues, KAT introduces respectively three key innovations:

    • rational activation functions that replace B-splines with rational functions, which are better suited for modern hardware and allow for an efficient CUDA implementation,
    • Group KAN, where activation weights are shared across groups of neurons, reducing the number of parameters and computational load, and
    • variance-preserving initialization, an approach that initializes activation weights so that variance in activations remains the same across layers.

    As a result, KAT can successfully integrate KANs into Transformers and achieves several state of the art results in vision tasks, including image recognition, object detection, and semantic segmentation, where KAT outperforms traditional Transformers with MLP layers. For example, on the ImageNet-1K dataset the KAT model exceeded the accuracy of a ViT model of the same size by 3.1%, which is no small feat given that the overall accuracy is already at 82%. Performance improved even further when KAT was initialized with pretrained ViT weights.

    Among other news in KANs, let us note the following:

    • Recurrent KANs (RKAN) and Temporal KANs (TKAN: Genet, Inzirillo, May 2024) apply KANs to time series data by developing recurrent architectures with the KAN approach; RKANs are parallel to standard RNNs, and TKAN is an adaptation of LSTM to KANs;
    • GraphKAN (Zhang, Zhang, June 2024) inserts KANs into graph neural networks (GNN), also reporting improved feature extraction performance;
    • UKAN (Li et al., June 2024) introduces KANs into the classical U-Net architecture that has been widely used in computer vision, specifically in image segmentation;
    • DropKAN (Altarabichi, July 2024) is a special form of dropout that is shown to improve KANs;
    • Higher-order-ReLU-KANs (HRKANs; So, Yung, September 2024) extend the ReLU-KAN method, which is based on a square of a ReLU activation, to higher degrees of ReLU activations, showing improved performance;
    • and many more interesting developments have not yet been followed up much but may open new possibilities for research, including Gaussian Process KAN (GP-KAN; Chen, July 2024), Rational KANs based on rational functions (Aghaei, June 2024), and Fourier KANs where learnable activations are modeled as Fourier series (Mehrabian et al., September 2024).

    However, various KAN-based architectures are just a means to an end; what have the ends been, i.e., how have KANs been used in practice?

    Recent developments in KAN: Applications

    We have noted above that the original KANs were developed in part with scientific applications in mind: KANs can yield symbolic results and explain their predictions with compact and readily interpretable formulas.

    The next step in this direction was taken by the KAN 2.0 approach developed by Liu et al. (August 2024). The goal of KAN 2.0 is to create a two-way synergy between KANs and scientific knowledge, both embedding prior scientific knowledge into KANs and extracting new scientific insights from them:

    Architecturally, the authors make several interesting contributions:

    • a variant of KAN called MultKAN that includes additional multiplication nodes in KAN layers, enhancing the network’s ability to model physical processes that involve multiplicative relations; standard KANs would be hard pressed to approximate f(x,y)=xy, while MultKAN would do it with a single multiplication node;
    • KAN Compiler (kanpiler), a tool that converts symbolic formulas into KAN structures by parsing them into tree structures and inserting the trees into the network directly; this tool is responsible for much of the left-to-right arrow in the diagram above, making it possible to incorporate prior symbolic knowledge into the network; the authors also develop the opposite tool, tree converter, to convert KANs into tree graphs;
    • revealing modular structures, i.e., enforcing tightly connected modular structures within KANs while minimizing inter-module connections, which helps capture the separable and symmetric structure often found in scientific models.

    As a result, Liu et al. (2024) show how KANs can be applied to discover and interpret physical laws, including:

    • Identifying conserved quantities in physical systems, such as energy and angular momentum in a harmonic oscillator;
    • Lagrangians, where KANs are trained to approximate the Lagrangian for simple mechanical systems such as a single pendulum or relativistic mass in a field;
    • discovering hidden symmetries, with the example of semi-automatically discovering translational invariance in the Schwarzschild black hole metric that took 17 years for physicists (Painlevé and Gullstrand) to discover in the 1920s.

    This still looks like the most direct practical application of KANs: their direct relation to symbolic formulas, with easy conversions back and forth, may lead to important discoveries.

    However, this is not the only application. Models that I mentioned in the previous section all come with convincing practical results where KAN-based architectures outperform similar architectures without KANs. Let me note a few more interesting papers that introduce new applications of KANs:

    • Nagai and Okumura (July 2024) integrate KANs into molecular dynamics (MD) simulations to improve the accuracy and efficiency of interatomic potentials; in modern MD simulations, these potentials are commonly modeled by neural networks, and the authors note that KANs result in a significant reduction in computational costs compared to potentials based on classical neural networks: KANs can approximate potential energy surfaces with a simpler representation without sacrificing accuracy;
    • Ge et al. (August 2024) present TC-KANRecon, a KAN-based approach for magnetic resonance imaging (MRI) reconstruction; they adopt KANs to strike a better balance between image denoising and structure preservation, and produce improved reconstructions within a given computational budget;
    • Aghaei (September 2024) leverages KANs to solve optimal control problems; his KANtrol framework is an adaptation of the well-known approach of physics-informed neural networks (Raissi et al., 2019) which embed prior physical knowledge into neural networks; this work shows that, unsurprisingly for us by now, KANs handle this prior knowledge better and produce better approximations to both the control functions and state evolution in optimal control problems;
    • GNN-SKAN (Li et al., August 2024) is a novel approach integrating Kolmogorov-Arnold Networks (KANs) with Graph Neural Networks (GNNs), specifically to improve representation learning for molecular graphs; the authors develop a special variant of KANs called SwallowKAN (SKAN) for this purpose and show better generalization to diverse classes of molecular structures.

    As you can see, most of these applications still center on using KANs for science, inferring mathematical dependencies from data in various domains; time will tell if preliminary promising results in other directions such as image processing convert to practical models.

    Conclusion

    I originally thought this post would be relatively short; Kolmogorov–Arnold networks seemed like an interesting idea that would make for a good case study of “something completely different” in the field of deep learning that might or might not lead to good results in the future. However, as I fell deeper and deeper into the rabbit hole of KANs, they seemed more and more promising, so this post has gradually turned into a full-sized review.

    I cannot help but imagine how interesting it would be to pair KANs with an advanced LLM that might try to automatically notice what functions are being learned. An LLM will tirelessly try different approaches, in what seems to be a perfect match for its capacity for creative data analysis without too much intermediate logical reasoning. The o1 family of models already looks like a very promising candidate for this LLM (see my post on o1-preview), and the models will only get better from here.

    But Kolmogorov–Arnold networks still do make for an excellent case study. Based on an idea that had been around forever, KANs were introduced at the end of April 2024. It is October now, and KANs have already blossomed into a well-developed research direction, with dozens of papers introducing new directions and applications. In this post, I have tried to give a brief overview of this direction, and I believe it is an interesting one, but my main point is that this is but one of many possible ideas worth exploring. I am sure that deep learning has many more such ideas in store, waiting for researchers to discover them; good luck!

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • OpenAI’s o1-preview: the First LLM That Can Answer My Questions

    OpenAI’s o1-preview: the First LLM That Can Answer My Questions

    OpenAI’s o1-preview has been all the buzz lately. While this model is based on the GPT-4o general architecture, it boasts much improved reasoning capabilities: it can ponder the question for about a minute, reason through multiple possibilities, and arrive at solutions that could not be generated from a single try of GPT-4o. In this post, I discuss the o1-preview model but mainly present the most striking advantage of o1-preview over all previous LLMs: it can meaningfully answer questions from a quiz game called “What? Where? When?”. At this point, it probably does not sound all that exciting compared to winning math competitions and answering PhD level questions on science, but let me elaborate.

    “What? Where? When?”: a Game of Trick Questions

    There have already been many responses to OpenAI’s o1-preview, and this post is also one of them. We will discuss the model and what new opportunities it offers below. But first and foremost, this is a post of love for a game called “What? Where? When?” (“Что? Где? Когда?”), usually abbreviated to ЧГК in Russian; I don’t expect the English acronym WWW to catch on but I’ll stick to it throughout this post for brevity.

    The rules are simple: teams of at most six players answer questions. Each question is read out to all teams, which are then given one minute to discuss and arrive at the answer. During the discussion, access to the Web or other reference material is not allowed. At the end of the minute, the teams are given another 10 seconds to write down their answer on a piece of paper; the answers are collected, and the correct answer is announced. The team that has answered the most questions correctly wins. I’ve made an illustration just in case, but the rules are really very, very simple:

    What differentiates WWW from just about every other pub quiz in existence is the style of questions. Here is one (to avoid authorship issues, all examples are questions that I personally wrote, usually at some point between 2005 and 2015 when I was actively preparing questions for the game):

    The Sunday Times wrote about this person, born in the 1930s, that his work represents a ceiling for wide audiences, even though in principle no one is stopping you from consuming more elite art. Write the last name of this person.

    Naturally, you are not supposed to know what The Sunday Times wrote at some unspecified point of time. Instead, the question is supposed to draw on your general knowledge but also require nontrivial logical and intuitive jumps to arrive at the answer. At the same time, the answer should be unique and no other answer should fit the question; appeals are allowed for cases when this principle is violated, because the question’s author sometimes cannot think of every possible alternative.

    Try thinking about the question above by yourself for a little while before reading on. What are your thoughts?

    A seasoned WWW player could reason through this question somewhat like the following:

    • the question directly tells us the following facts:
      • the guy in question (it’s a guy because of the “his” pronoun) worked on some kind of art;
      • his art is elite but not too elite to be incomprehensible so he is relatively well known to wide audiences;
      • he was born in the 1930s so his best work was probably done at some time between the 1960s and 1990s; he might be alive now (though very old);
    • based on this, you could propose some hypotheses, but, of course, this information is insufficient to single out one person (e.g., David Hockney fits quite well) so we need to consider indirect clues:
      • what parts of the question could constitute hints?
      • there is one candidate, the word “ceiling”; the question goes out of its way to mention this word, it’s not the most natural word for the context, and the actual information content would not change if it didn’t mention “a ceiling for wide audiences” at all, so the word itself must be important;
    • now comes the intuitive jump that you have to make:
      • combine the word “ceiling” and actual information that his art represents this ceiling although “more elite art” is also technically available to everyone;
      • how would you describe a situation like this, where a person stops at some point although technically there is no barrier to go further? maybe in a different context?
      • the jump is that you could describe this situation as a “glass ceiling”, a phrase that usually relates to advancement opportunities for oppressed demographics but whose meaning also fits the situation in this question;
    • when you arrive at the phrase “glass ceiling”, you already have the answer; it only remains to note that Glass is also the last name of a famous composer; finally, you can verify that:
      • Philip Glass is indeed an elite composer whose work is nevertheless widely known and popular, so he fits the facts;
      • you probably don’t know for sure that he was born in the 1930s but it also fits your general knowledge about him; in WWW, dates usually provide general context for the question, e.g., in the musical arts you could be sure it’s not Beethoven or Kanye West but you probably wouldn’t exclude John Lennon (born 1940) or Ennio Morricone (born 1928) because you wouldn’t know their exact birth years;
      • another confirmation lies in the fact that the question specifically asks for the last name rather than just asking to name the person, and the last name is what is used in the key phrase; it could be a coincidence but it also could be a weak hint;
    • and indeed, Philip Glass is the correct answer.

    As you can see, this is definitely not direct trivia. When you think about a WWW question, you have to make a lot of assumptions and jumps that are not directly supported by either facts or logic. An important skill is to reason backwards from the question: why it is phrased in the exact way that it is, what the author has been trying to convey; in the example above, this reasoning singles out the word “ceiling”.

    In this post, I’m describing the “competitive” version of the game, where multiple teams compete against each other, but I have to note that it originated from a Russian TV show called What? Where? When? where a single team of six players answers questions sent in by the viewers (naturally, the questions are selected and edited in advance, otherwise it wouldn’t be a fun game at all). This is literally the longest running show on Russian television, originally airing in 1975 and not changing too much since then. In 2010, ABC licensed What? Where? When? under the name Million Dollar Mind Game, and while they did a good job capturing the style of questions and the spirit of the game (you can find the show on YouTube, in 360p quality alas), it didn’t take in the U.S. and was canceled after a season or two.

    Can AI Models Play “What? Where? When?”: The Surprising Development of o1-preview

    I have been playing WWW for… God, how time flies, for 25 years already. To me, this is a perfect hobby because while it is competitive, it not only never gets boring itself—questions never repeat—but also does not involve any boring preparation bits that basically any other sport would. Scrabble tournaments make you memorize dictionaries (in a famous case, Nigel Richards won the French-language Scrabble World Championships without any knowledge of the language!), and Jeopardy pros learn a huge amount of trivia, but here, while a high level of general knowledge is required, you mostly learn it by osmosis through playing the game itself. Simply memorizing a lot of facts won’t get you very far in the game.

    Ever since I started working in machine learning and natural language processing specifically—another “time flies” moment, that was almost 15 years ago—I have been testing whether ML models could be any good in WWW. Naturally, they never could. There was absolutely no way an NLP model of the 2010s would actually understand what’s required in a question. You can google the answer fairly often because WWW is mostly a game based on real facts about the real world; in the example above, the “ceiling for wide audiences sunday times” Google query gives me the correct answer at #2 in the list. But even if you gave models access to the Internet, there was no way models would understand what to google for and how to extract the correct answer from retrieval results. This was just way, way above the capabilities of AI models of the time, it always looked like something from a very distant future.

    When LLMs came about, I also tried to ask them questions like the above. And they also inevitably failed. LLMs know a lot of facts, and some of them can google behind the scenes (recall our post on RAG), so sometimes they would answer a question simply because they knew the factual answer. But none showed any capability to actually play the game; even a month ago, WWW looked like it would be possible only some relatively long time in the future.

    Well, the future is here now. OpenAI’s new model, o1-preview, can play WWW. Moreover, o1-preview can show you its chain of reasoning. So here is how it reasoned through the question above in 44 seconds (under a minute, although I don’t know how to limit its thinking time right now): 

    Identifying key characteristics 

    I’m focusing on a person born in the 1930s, whose work is described as a ceiling for wide audiences, yet not preventing the consumption of more elite art. The task is to find this individual’s last name.

    Unraveling clues

    I’m digging into a quiz clue, parsing it creatively to identify a person from the 1930s whose work represents a broad artistic ceiling, hinting at “glass ceiling” or the “Sistine Chapel ceiling” as potential leads for further exploration.

    Piecing together

    I’m gathering info on Philip Glass, examining connections to “Glass Ceiling” and the Heisenberg Uncertainty Principle. Considering his work’s accessibility and its metaphorical implications.

    Weighing possibilities

    I’m gathering clues about an artist from the 1930s with mainstream appeal. Considering the “no one is stopping you” phrase and connections like “Glass Ceiling,” “Glass,” or “Peter Max.”

    Analyzing the phrase

    Considering various interpretations like physical barriers, metaphors for invisible obstacles, and literal elements of ceilings or artists named “Ceiling”.

    Revisiting artists and interpretations

    I’m brainstorming artists and interpretations, considering Michelangelo’s Sistine Chapel, the movie “Glass Ceiling,” the Beatles, and “Ceiling Cat.”

    Pondering cultural icons

    I’m reflecting on cultural icons like Andre Rieu and the “Mona Lisa,” considering how they might symbolize the pinnacle of mainstream audience art.

    Weighing the details

    I’m thinking through the idea that the person’s work, while accessible to wide audiences, may still hold a glimmer of something more exclusive, hinting at a deeper connection or higher tier.

    Weighing options

    I’m noticing the possibilities involve John Williams, Andy Warhol, Frank Lloyd Wright, Norman Rockwell, or James Cameron, among others, by examining birth years and the "ceiling" clue.

    Examining possibilities

    I’m thinking through the clue “No one is stopping you from consuming more elite art,” which might imply a barrier or wall. It’s worth considering artists like Glass or Wall, or even the term “Ceiling” to explore.

    Weighing choices

    I’m mapping out notable artists born in the 1930s, like Philip Glass and Yoko Ono, and considering the Japanese word for roof, “Ten.”

    Adjusting my approach

    Thinking through artists like Richard Serra and John Cage, but considering the phrase “No one is stopping you in principle from consuming more elite art”?

    Evaluating limits and creators

    Thinking about “Glass Ceiling” and its connotations with limits and barriers, I’m considering artists, authors, and cultural references, especially those born in the 1930s, to identify the intended creator.

    Piecing together clues

    I’m considering artists or figures whose work defines a cultural peak, possibly “Block” or “Lid” in their names. The moon angle hints at Neil Armstrong, born in 1930.

    Piecing together

    I’m examining various artists and subjects, like Carl Sagan and Robert Smithson, to understand the clue’s connection to “ceiling” and “Space.”

    Connecting concepts

    It’s interesting to see how principles like the Principle of Uncertainty or ideas like Superconductivity might align with the clue, though they seem less likely for now.

    Weighing possibilities

    OK, let me see. I’m thinking about whether the clues point to Murdoch or Brecht, considering Rupert Murdoch’s birth and his connection to media, or Brecht’s fourth wall concept.

    Weighing possibilities

    I’m mapping out an artist connected to the 1930s or an uncertainty principle, considering figures like “Heisenberg” or others associated with paradoxes.

    Piecing together musical influences

    I’m gathering info on Philip Glass, whose minimalist compositions have a broad appeal yet maintain an elite essence. This aligns with the idea of a “Glass Ceiling” in art.

    Note how although the model has arrived at the correct answer fairly quickly, it has also generated a lot of interesting and creative possibilities, including:

    • a list of suitable artists; John Williams, Andy Warhol, Frank Lloyd Wright, Norman Rockwell, and James Cameron all sound plausible to me except for Rockwell, who was definitely born earlier than the 1930s—by the way, this suggests that o1-preview was not googling anything;
    • a slightly different angle for the “no one is stopping you” hint, in that it “might imply a barrier or wall”; it even linked the hint to the “fourth wall” concept which is also a salient concept often appearing in WWW questions;
    • completely different directions that could be promising, such as translating the key words to Japanese.

    This is exactly how a good player would try to approach this question, and this is exactly what all models so far, including GPT-4o, the direct predecessor of o1-preview, are failing at. I asked GPT-4o the question above, regenerating the answer several times; while it was clear from its explanations that GPT-4o did understand the question, the answers I obtained were Alfred Hitchcock, Walt Disney, Andy Warhol, and Stanley Kubrick. GPT-4o definitely “knew” that Disney and Hitchcock had both been born much earlier than the 1930s but still gave these answers, with the most popular answer being Alfred Hitchcock (5 times out of 8); several times, GPT-4o explicitly wrote that Hitchcock was born in 1899 but still gave this answer.

    So what’s so special about o1-preview? Let’s try to find out.

    Reasoning during Inference: A New Scaling Law?

    As usual with modern state of the art LLMs, there is little information explicitly given by OpenAI about the structure of its o1-preview model or its training regime. This is natural from the commercial point of view and probably a good thing from the security point of view (recall our post on the dangers of AI), but the net result is that we, like all outside experts, are mostly reduced to guesswork, with the only definitive sources being OpenAI’s announcement post and the OpenAI o1 System Card, which is safety-oriented and does not provide further information about the model itself.

    The post vaguely gestures at being better at chain of thought reasoning. I hope to roll out a detailed post on chain of thought techniques in the near future, but, alas, so far it doesn’t look like the o1 family will meaningfully contribute to the scientific part of it. In the quote below, the only load-bearing words are “reinforcement learning”:

    Through reinforcement learning, o1 learns to hone its chain of thought and refine the strategies it uses. It learns to recognize and correct its mistakes. It learns to break down tricky steps into simpler ones. It learns to try a different approach when the current one isn’t working. This process dramatically improves the model’s ability to reason.

    I don’t have a personal opinion on what exactly this reinforcement learning method is. It could be RLHF applied to chains of internal reasoning. Or it could be something more involved. For example, Subbarao Kambhampati offers a very interesting speculation; he suggests that the o1 family learns to reason in a way similar to how AlphaZero learns to play board games, with text continuations playing the role of moves and “game results” being correct answers as evaluated by external LLMs. For a collection of this and other speculations, see an excellent (as always) survey post by Zvi Mowshowitz.

    Whatever the details, the result is that o1-preview introduces a whole new paradigm to large language models. We have discussed many times (here and here, for instance) that LLMs are token prediction machines: they take in as much context as they can (again, see my previous post on context) and then produce the output token by token, autoregressively, never looking back.

    The o1 family are still LLMs, but they produce many different outputs, check out different possibilities, combine the results, and generally “think through” the problem. This idea is not novel in itself—it is exactly the premise of chain of thought techniques. But OpenAI has made it work at an unprecedented scale. Look at the plot with the scaling law they report; the Y-axis shows results on AIME, the qualifying exam for the USA Math Olympiad, and the X-axes are two different computational budgets:

    The plot on the left is standard: more train-time computation leads to better performance with a log-linear dependency. But the plot on the right is completely novel: it shows that o1 models can actually make good use of test-time (inference) computational resources! Basically, it means that the longer you allow an o1 model to think about a problem, the better it does; I don’t think a scaling plot like that has ever been achieved before with LLMs.

    Naturally, this leads to increased costs; as you probably have already noticed, o1-preview comes with some rather strict constraints on usage and steep prices for API access. But costs have a tendency to decrease over time due to algorithmic improvements and cheaper hardware, while peak performance, once achieved, stays with us forever.

    And the performance jumps are very impressive. I mostly devote this post to WWW since this is my personal example where I can add something new to the discussion, but answering trick questions is definitely not the most productive use of o1-preview’s computational resources. Here is the performance comparison reported by OpenAI:

    In all three cases, o1 blows GPT-4o out of the water. GPT-4o could solve some high-level mathematical Olympiad problems but o1 makes it to the Olympiad, scoring among the top 500 participants in the US this year. I’d love to see Claude Opus and Gemini 1.5 Pro on this plot since they are better at math than GPT-4o, but I don’t believe they would be as competitive. Coding has much improved, with o1 breezing through advanced competitive coding problems. Finally, the GPQA Diamond dataset (Rein et al., 2023) is not a high school science test; it contains questions that human Ph.D. students in the corresponding fields tend to answer with 65-75% accuracy when given full access to Google and over 30 minutes of time. On this test, o1 exceeded human expert level.

    Here is a more detailed breakdown of various categories and benchmarks:

    Note that additional reasoning power almost doesn’t help at all in tests on the English language, public relations, the basic SAT test, and English literature: the o1 model is not better at writing than GPT-4o. Indeed, if you ask humans which model they prefer, in terms of writing and editing they are completely equivalent:

    But in anything that requires reasoning, formal logic, and especially long chains of derivations, o1 is just much, much better. Looking at the performance plots, it is hard to believe that o1 is not a new level of LLMs (that is, GPT-5) but just a novel way to fine-tune the same “4-level” LLMs that have been around for more than a year. Still, this is exactly what happened, and this same method of improvement would probably apply to a new generation of LLMs as well.

    With this, let us go back to the game of questions.

    The WWW Dataset

    People have been writing questions for the sports version of WWW since the 1980s. Starting from the 1990s, questions have been collected in a large database, published at “The WWW questions database”. The interface is a little dated, and there is also a more up-to-date database at “Got questions”. You can easily scrape both websites, and back when I was trying to apply NLP models at scale there was no problem to contact the maintainers and obtain a dump of the database directly.

    But, of course, virtually all of the questions are in Russian. This is not a problem for o1-preview: it is perfectly capable of playing in the original Russian. But if you want to translate the questions and create a dataset that the English-speaking world can understand, you run into a lot of trouble.

    In this post, examples are in English because I have translated them. I did not try too hard, I just looked through my questions in the database and chose the ones that would be easy to translate. And as I was filtering the questions, I was only choosing about 1 out of 5 for translation; if I really tried my best I would maybe end up with one out of three or so. The rest would be wordplay in Russian, references to Russian language culture little known among people who don’t speak the language, references to exact quotes in Russian, and so on. I obviously can’t show you the wordplay, but here are a couple of examples that can survive in translation but that I wouldn’t use for an English speaking audience.

    1. During a social gathering, the famous chess grandmaster Salo Flohr was introduced to Svetlana Alliluyeva. They had been chatting for a few minutes before Flohr said that he was feeling uncomfortable and asked Svetlana… what?
    2. In his discussion of a certain genre of music, Romain Gary tells the readers how Russian gentry called their serfs. Name this genre of music.

    Here are the answers:

    1. Her patronymic. In fact, this is a historical anecdote about Salo Flohr’s absent-mindedness: Svetlana Alliluyeva was Joseph Stalin’s daughter, and everyone knew that, so when Flohr said he was feeling uncomfortable addressing her by her first name and asked for her patronymic, that was not the wisest of questions. Here, the question assumes that you know who Svetlana Alliluyeva was, which Russian-language players do and players from other backgrounds probably don’t.
    2. Soul. The Russian serfs were often called “souls” (cf. Gogol’s “Dead Souls”), and it is an interesting coincidence that soul music, while having absolutely nothing to do with Russian serfs, also originated in an oppressed demographic bound to servitude. This is an opposite example: if you know that serfs were called “souls” (Gogol’s novel is indeed a great book known to Westerners too), in English the question becomes very straightforward while in Russian you still have to make the jump to translating “soul” into English (and not Romain Gary’s native French). Translation is a common tool in WWW, especially to/from English since you can assume that players have a basic knowledge of English but not other languages.

    More recent questions are on average easier to translate, as WWW is turning more and more towards Russian speaking people in other countries who do not have this exact cultural background characteristic of Soviet high schools. But there is still a lot of untranslatable wordplay and exact quotes.

    Therefore, I believe it still requires human effort to choose translatable questions and then translate them, and I hope that an effort like that could be organized. This would be a great dataset for LLM testing, and also—who knows—it may get English speakers to become interested in the game! I hope to organize this effort at some point, and if you are interested in helping out please contact me directly.

    More Fun Examples

    Last week, I spent a whole day asking WWW questions to LLMs. I did it in part to enter some of the questions to the “Humanity’s Last Exam” dataset, an interesting initiative by the Center for AI Safety and Scale AI (the announcement was given by Dan Hendrycks and Alexandr Wang). The interface of the exam very conveniently asks the question to five top of the line LLMs: GPT-4o, Claude Sonnet 3.5, Google Gemini 1.5 Pro, o1-mini, and o1-preview.

    Naturally, I won’t give out the examples where all LLMs were wrong because they became my submissions to the dataset. But I want to show several interesting cases where only o1-preview could get it right. I used only questions written by myself, so the style is a bit dated since my days as an active WWW question author were in 2005-2015, and I also made the style slightly more formal in translation so that the answer would be defined more precisely to allow for automated checking. Here is a table with the answers of all five models; I abridged the explanations but kept their main ideas and commented a little in square brackets:

    Question. A sports journalist wrote about a famous hockey player Vasily Trofimov: “He changed the tempo a lot, and it looked like he could speed up with no limit”. Next the journalist mentions a person who visited Russia in 1842 and 1843. Write the last name of this person.
    Dumas.
    …The key to identify is the hint given by specific years “1842 and 1843”. During that period, a renowned historical figure, the French writer Alexander Dumas, visited Russia. The reference in the journalist’s statement likely ties to Dumas because mentioning a historical visitation without a clear context suggests a notable person whose travels were significant in that era…

    [Dumas indeed visited Russia, in 1858-1859 in fact, but that’s close enough to introduce some confusion.]
    Dumas.
    …This sudden change of topic suggests that we need to look for a connection between hockey and a 19th century visitor to Russia… a “Dumas drop pass” is a move where a player unexpectedly drops the puck for a teammate trailing behind… named after Alexander Dumas, likely due to its surprising and deceptive nature, which mirrors the unexpected plot twists in Dumas’ novels…

    [The “Dumas drop pass” is a pure (and quite beautiful) hallucination, I couldn’t google anything even close to “Dumas” in relation to drop passes.]
    Gogol.
    …a hockey player who could “speed up with no limit”… is likely a reference to the famous Russian author Nikolai Gogol and his short story “Nevsky Prospekt”… which features a character vanishing into thin air, much like he’s moving with infinite speed… in 1842 and 1843… Gogol was living [in Russia]…

    [Here Gemini is being too creative for its own good: Gogol did live in Russia, but it would be incorrect to say that he “visited Russia”; this wording clearly suggests that the person was a foreigner.]
    Stephenson.
    …Given the context of changing tempo and speed, it may metaphorically relate to advancements in transportation or engineering during that era. George Stephenson, known as the “Father of Railways,” was a prominent figure whose work significantly influenced railway development, which aligns with the themes of speed and tempo changes…

    [That’s quite creative but, of course, Stephenson is not a person who would embody high speed for the reader; he might be a reference for engineering ingenuity or something like that.]
    Liszt.
    …Liszt was a renowned virtuoso pianist and composer known for his technical prowess and compositions that featured rapid tempo changes and extreme speed, such as his “Transcendental Études.” The word “transcendental” suggests going beyond limits, linking to the idea of speeding up with no limit. Additionally, Liszt visited Russia in 1842 and 1843, where he performed concerts in cities like St. Petersburg and Moscow…

    [This is completely correct; o1-preview even added an extra reference to “Transcendental Etudes” that I never intended but that does indeed fit the question quite well. I don’t think o1-preview simply googled Liszt’s concerts because if so, why didn’t GPT-4o and o1-mini do the same?]

    Question. In this question, X stands for a different word. A World War II poster said that X was easy to get. Another poster from the same era promised to win a victory over X. Answer in one word, what was the main weapon of this victory?
    cryptography
    “X” refers to something negative or harmful that the government or military campaigns were trying to combat. In this context, “X was easy to get” likely refers to enemy information or secrets… the most renowned weapon for intelligence and communication security was the “Enigma” machine and the efforts surrounding “cryptography.”

    [The slogan “[Military] secrets are easy to get” would make no sense at all; GPT-4o misses the main point that the slogan should be some kind of joke or at least something memorable and, pardon the pun, catchy.]
    Silence
    …Given that X is described as “easy to get” and something that needed to be defeated, it’s highly probable that X stands for “information” or more specifically, “loose lips.” This refers to the famous WWII slogan “Loose lips sink ships”… The main weapon to combat this problem would have been silence or secrecy…

    [Same as above: “Loose lips sink ships” makes perfect sense as a war propaganda slogan; “loose lips are easy to get”, not so much.]
    Bonds.
    World War II propaganda often used fear tactics. An intangible enemy that is “easy to get” and needs to be defeated likely refers to a negative emotion or behavior. Considering the context, this was likely something that could hinder the war effort.

    [I think Gemini 1.5 simply didn’t come up with anything believable in this case: the reasoning is very unspecific and doesn’t make much sense. This is exactly how I would describe my own reasoning when we don’t get the question right but I still need to submit some kind of answer.]
    Atomic
    During World War II, propaganda posters often emphasized the ease of achieving certain goals to boost morale and support for the war effort. The main weapon that was pivotal in ensuring victory was the atomic bomb, which promised a decisive end to the conflict.

    [Here o1-mini forgets all about the slogan, and even kind of misses the formal side of the question: “atomic” is not a valid answer to the question “what was the main weapon”, and “atomic bomb” would be two words.]
    Penicillin.
    During World War II, there were posters warning soldiers about the dangers of venereal diseases (VD), with slogans like “VD is easy to get” to highlight how easily one could contract such diseases… The main weapon in combating venereal diseases at that time was the antibiotic penicillin, which was mass-produced during the war and used effectively to treat infections like syphilis and gonorrhea among soldiers.

    [This is totally correct; o1-preview gets the joke of the slogan and reasons correctly that penicillin entered mass production during WWII so it could also be the subject of propaganda posters. Here, again, o1-preview probably did not “know” the exact poster because it didn’t specify the exact disease]
    Question. Consider the equality: 11+2=12+1. One of the ways to write this equality, discovered by Martin Gardner, appears in a list of the most beautiful… answer in one word, most beautiful what.
    equations
    Martin Gardner was renowned for his work in recreational mathematics and puzzles. The equality “11 + 2 = 12 + 1” seems mathematically incorrect at first glance, but this type of equality often involves a playful twist or insight that isn’t immediately obvious, typical of Gardner’s style. This format is frequently used in his mathematical puzzles or games, known for their elegance and creativity, which are often praised for their aesthetic appeal.

    [GPT-4o knows who Martin Gardner was but the reasoning doesn’t make sense: does this equality seem mathematically incorrect to you?..]
    equations
    …The equation 11 + 2 = 12 + 1 can be written in a visually appealing way that makes it appear as if the numbers themselves are balancing. This representation is often included in lists of the most beautiful mathematical equations or expressions…

    [I suppose any equation can be written in a way “as if the numbers themselves are balancing”. Claude Sonnet also doesn’t make the necessary jump outside mathematics.]
    equations
    Martin Gardner was a famous recreational mathematician known for his clever mathematical puzzles and games. The equality 11 + 2 = 12 + 1, while simple, could be written in a more visually appealing or conceptually interesting way if presented as a puzzle or rearranged in a pattern. This kind of playful approach to math is what Gardner was famous for.

    [Gemini again usefully gives an explanation that clearly suggests that it didn’t get the question right and it knows it: no specifics, just conjecture about a “visually appealing pattern”.]
    Anagrams
    The equation “11 + 2 = 12 + 1” is notable because the words “eleven plus two” are an anagram of “twelve plus one.” Martin Gardner highlighted this clever wordplay, recognizing it as a beautiful example of an anagram in the context of mathematical expressions.

    [In this case, both o1-mini and o1-preview got it right. Judging by the explanations, it could be direct knowledge, but then, again, why didn’t GPT-4o have this knowledge?..]
    Anagrams
    Martin Gardner discovered that “Eleven plus two” is an anagram of “Twelve plus one.” This clever rearrangement forms the equality 11 + 2 = 12 + 1 when written out in words, and the phrases are perfect anagrams of each other. This particular anagram is often cited in lists of the most beautiful anagrams due to its mathematical elegance and linguistic symmetry.
    Question. Later the professor married his Ph.D. student, who turned out to be the only person who openly protested. Which university was she a student of?
    Not determinable from the given information
    The statement provides information about a professor who married his Ph.D. student, and it mentions that this student was the only person who openly protested. However, the statement does not provide explicit details or context regarding names, places, or events that would identify the university she attended. Without additional specific context or information that ties this scenario to a known public case or story, it is not possible to determine the name of the university.
    Insufficient information
    This question appears to be referring to a specific situation or story that is not provided in the context. Without more information about the professor, the student, or the event being described, it’s impossible to determine which university the student attended. The question assumes knowledge of a particular narrative or incident that hasn’t been shared.
    Unknown
    This question cannot be answered. There is no information provided about a professor marrying his Ph.D. student or any protests.
    I’m sorry, but I couldn’t find information regarding the specific university associated with the professor and his protesting Ph.D. student. If you can provide more details, I’d be happy to help further.
    Based on the information provided, there isn’t a well-documented or widely recognized case involving a professor marrying his Ph.D. student who was the only person to openly protest. Without additional context or specific details, it’s challenging to identify the exact university in question.
    Stanford University
    The professor referred to is Philip Zimbardo, who conducted the Stanford Prison Experiment in 1971 at Stanford University… Christina Maslach, a Ph.D. student at Stanford and Zimbardo’s girlfriend at the time, was the only person who openly protested the ethical implications of the experiment. She confronted Zimbardo… Later, Zimbardo and Maslach married.

    [This could be direct knowledge (although, again, why didn’t GPT-4o and o1-mini have it? knowledge is nothing if you can’t access it when needed), but this is such a great example — every single model explicitly gives up and calls the question impossible, but o1-preview reasons its way through!]

    Conclusion

    As you can see, o1-preview is indeed a big step forward. This post has been devoted to the “What? Where? When?” game but more practical things like answering hard science questions, solving mathematical problems, and writing code are much improved too. What’s even more exciting is that with o1-preview, OpenAI is showing how to scale the models not only with respect to the size of the training data and computational power spent on training but also with respect to resources and time spent on inference. You could say that o1-preview has learned to actually think about a question rather than just generate the answer immediately.

    This new scaling curve could be part of the “unhobbling” as discussed by Leopold Aschenbrenner in his recent Situational Awareness book (Aschenbrenner, 2024; highly recommended, by the way—it had not yet been released by the time of my post on AI dangers, but I would certainly have discussed it in detail if it had been), or it could be a new scaling law on top of that, speeding up AI capabilities development even further. Only time will tell, and these will be some of the most interesting and exciting times in the history of humanity.

    I will leave you with a quote from Sam Altman’s blog post “The Intelligence Age”, published on September 23. Mr. Altman definitely knows how to ride hype waves but in this case, I tend to believe he is absolutely, scaringly honest:

    This may turn out to be the most consequential fact about all of history so far. It is possible that we will have superintelligence in a few thousand days (!); it may take longer, but I’m confident we’ll get there.

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • Using RAG to Enrich LLMs

    Using RAG to Enrich LLMs

    We continue our series on LLMs and various ways to make them better. We have already discussed ways to increase the context size, world models that arise in LLMs and other generative models, and LLM fine-tuning including RLHF, LoRA, and more. Today we consider another key idea that can make LLMs far more effective and useful in practice: retrieval-augmented generation, or RAG. We discuss the basic idea of RAG, its recursive agentic extensions, the R[e]ALM approach that integrates retrieval into LM training, and some key problems of modern RAG approaches; then we discuss in detail knowledge graphs and how they are being used in RAG, and conclude with a reminder that even simple approaches can work well and a list of directions for future work.

    Introduction

    A large language model is basically a huge token prediction machine; we have discussed this many times (one, two). In particular, it means that the LLM itself has a specific dataset that it used for pretraining and/or fine-tuning. No matter how large and how smart the LLM becomes, it will never have information not present in the dataset; for example, it will not know any events that happened after the “cutoff date”.

    But there are plenty of applications where you want the LLM to process new information! For instance:

    • a corporate AI assistant that processes the documentation base of your company; the docs evolve with time, and you probably don’t want to submit your internal documentation into an open training dataset;
    • a personal AI assistant that may need access to your email, files on your computer, and so on; I definitely don’t want OpenAI or Anthropic to train their models on my personal correspondence;
    • if I wanted an LLM’s help in writing this survey, I would like the LLM to search for recent publications, not only papers published before the LLM’s cutoff date;
    • even a straightforward use case such as planning a trip would require up-to-date information about available transportation, hotels that are open right now, perhaps current weather reports; and so on, and so forth.

    We have already discussed ways to extend the input context and alleviate the quadratic complexity of self-attention, but not everything can be put into context. Even if you can fit a whole book into a million token context like Gemini 1.5 already can (Reid et al., 2024), real world tasks such as the ones listed above require access to much more information.

    One way to fix this problem would be to introduce external information search as a tool; e.g., you could tell the LLM that it is allowed to call a `web_search` procedure that takes a query as input and outputs the top 5 Google search results for this query. This is a viable approach, especially for well-defined queries such as weather reports or ticket availability, and perhaps we will discuss tool use in LLMs in the future.

    However, information retrieval from large corpora is so important that it is usually treated as a separate type of LLM extension, often included by default even if other tools are not available. This falls under the label of retrieval-augmented generation (RAG), introduced by Lewis et al. (2020) during the early days of Transformer-based LLMs. Patrick Lewis, by the way, has apologized in his interviews about the acronym—”We definitely would have put more thought into the name had we known our work would become so widespread”, he said (source)—although in my opinion the acronym is catchy and memorable.

    Before diving into RAG, I want to highlight two surveys on the subject, by Gao et al. (2023) and Zhao et al. (2024). They are excellent reviews of RAG and already have hundreds of citations themselves despite being very recent. I have tried not to repeat these surveys, but still much of what comes next is based on them, although we will go beyond them in several directions, in particular regarding Graph-RAG. Other surveys of RAG include Li et al. (2022) and an ACL 2023 tutorial by Asai et al. (2023). With that in mind, let’s get going!

    Basic Intuition and Our Plan

    Before branching out, let us begin with an explanation of what RAG is and how it works in the simplest example. Suppose that I wanted to ask an LLM to summarize recent research on RAG to write this post. If I asked GPT-4o, it would be able to produce an excellent explanation of what RAG is (yeah, I checked) since RAG had already entered its knowledge base. But GPT-4o has no chance to know the results published in the last year because its cutoff date was August 2023:

    So, for instance, the two excellent surveys that I mentioned above would be beyond GPT-4o’s knowledge. Moreover, the LLM would probably not be able to give any specific links with more detailed information—its knowledge is vast but not so vast as to hold the entire training knowledge base. In fact, GPT-4o simply refuses to do that outright:

    How can we remedy that? The solution is quite simple: let’s allow the LLM to retrieve information by, e.g., searching the Web. In its most direct form, “pure RAG” works like this:

    The prompt gets reformulated into a query (either in some straightforward way or by using the LLM again), the query is sent to a standard information retrieval engine (Manning et al., 2008) that returns some results, and the top results (probably however many the context window allows) are appended to the query as additional information for the LLM. As a result, RAG gives the LLM the opportunity to search over an arbitrarily large corpus; the additional costs of retrieval are usually negligible compared to running the LLM.
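    To make the pipeline concrete, here is a minimal sketch of “pure RAG” in Python. The `search` and `llm_complete` callables are hypothetical stand-ins for your retrieval engine and LLM API, not any specific library; the prompt template is also just an illustration.

    ```python
    from typing import Callable, List

    def pure_rag(
        prompt: str,
        search: Callable[[str, int], List[str]],   # hypothetical retrieval engine
        llm_complete: Callable[[str], str],        # hypothetical LLM call
        top_k: int = 5,
    ) -> str:
        # Step 1: turn the user prompt into a search query
        # (here trivially; one could also ask the LLM to rewrite it).
        query = prompt

        # Step 2: retrieve the top-k documents for the query.
        documents = search(query, top_k)

        # Step 3: append the retrieved documents to the prompt and generate.
        context = "\n\n".join(f"[Document {i+1}]\n{doc}" for i, doc in enumerate(documents))
        augmented_prompt = (
            "Use the following retrieved documents to answer the question.\n\n"
            f"{context}\n\nQuestion: {prompt}\nAnswer:"
        )
        return llm_complete(augmented_prompt)
    ```

    The only design decision here is how many documents to append, which is usually dictated by the context window; everything else is just string concatenation around a standard retrieval call.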

    Here is a specific example from the above-mentioned survey by Gao et al. (2023):

    Even in this simple form, RAG is already immensely useful in practice. For example:

    • Yue et al. (2023) introduce DISC-LawLLM, a RAG-augmented LLM-based system for legal services, also publishing a benchmark for legal question answering based on questions from the (Chinese) bar exams; to me, legal assistance is an especially great fit for RAG since much of the routine work of legal advisors is basically “smart reading” of the laws and court precedents that can already be automated by RAG-augmented LLMs to a large extent;
    • Xiong et al. (2024) present the Medical Information Retrieval-Augmented Generation Evaluation (MIRAGE) benchmark for evaluating LLMs on medical question answering and compare a number of state of the art LLMs, with GPT-4 coming out on top (but that was in February, before Gemini 1.5 or Claude 3 Opus); medicine is another field with a lot of useful information buried in millions of published papers and case studies, and while the current generation of LLMs is probably not replacing doctors yet they can already be of great help for gathering and preprocessing of all this information; specialized works in this area are already starting to appear, e.g., a RAG-based solution for processing electronic health records by Bayer AG researchers Ziletti and D’Ambrosi (2024) or the QA-RAG model for pharmaceutical regulatory compliance (another huge pain point for the entire medical industry) by Kim and Min (2024);
    • Balaguer et al. (2024) present a comparative study of existing RAG and LLM fine-tuning approaches with an unusual case study on agriculture, processing a corpus of documents and pdf files on agriculture and asking questions such as “What is the best time to plant trees and shrubs in Arkansas?”; agriculture is another huge field where AI can help farmers progress to more individualized and hence more efficient processing of their fields, crops, and livestock;
    • science itself is another field where even professionals are overwhelmed by the stream of information, so LLMs are in great demand here; Lala et al. (2023) introduce PaperQA, a RAG-based agent for answering questions over research papers, Suresh et al. (2024) provide a RAG-based solution for summarizing documents on the Electron-Ion Collider,  and so on;
    • among other fields, let me note Telco-RAG by Huawei researchers Borneo et al. (2024) that deals with the telecommunication industry; this is an interesting use case because it is a rapidly evolving field with new hardware and new standards appearing constantly, so retrieval results will often be outdated.

    RAG is already a standard approach and part of many industrial solutions for LLMs such as Vertex AI Search by Google, NVIDIA NIM Microservices, Amazon workflows based on LangChain, IBM watsonx, Glean, and others. And if you want to “chat with your documents” but don’t feel comfortable sharing all your personal or corporate data with Google or Amazon, you can use an open source tool such as RAGFlow, Verba, or Kotaemon.

    But the story does not stop here. Since 2020, there has (naturally) been an explosion of papers related to augmenting LLMs with retrieval. Below, we will discuss the main directions of this research:

    • advanced RAG strategies that, for instance, refine and adapt queries for retrieval with the same LLM;
    • ways to incorporate retrieval into the training of LLMs and/or make it an integral part of the model rather than an external module; this class of approaches is usually known as RALM or ReALM;
    • problems of RAG, which are mostly related to the need to put retrieved documents into the LLM’s context, thus using up the available context window;
    • a combination of RAG with knowledge graphs, an interesting new direction for making retrieval “smarter” and more accurate;
    • and finally a few directions for further research.

    Advanced RAG Strategies: Agentic RAG

    The basic RAG pipeline outlined above is just that, basic. When you google something, sometimes you find what you’re looking for on the first try, but very often you need to refine your query, maybe even formulate a few completely new queries based on what you have read from the first one, and combine the results.

    Thus, the RAG pipeline is expanded to include a refinement loop, which generally adds something like this:

    This kind of workflow is modeled in advanced RAG strategies that are often combined under the name of “Agentic RAG”: instead of passively reading the context expanded by retrieval results, here the LLM becomes an agent able to formulate new queries (see, e.g., sample implementations of Agentic RAG in LangGraph). Let us go over a few specific examples.

    Shao et al. (2023) provide a straightforward implementation of this idea. Their iterative retrieval-generation strategy takes the output of RAG and uses it as input for another round of RAG (recursively extending to several rounds if necessary), which makes it possible to correct hallucinations and factual errors missed on the first iteration.
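    Here is a minimal sketch of such an iterative retrieve-then-generate loop, reusing the same hypothetical `search` and `llm_complete` helpers as above; this is an illustration of the idea, not the authors’ actual implementation.

    ```python
    def iterative_rag(prompt, search, llm_complete, rounds=2, top_k=5):
        """Iterative retrieval-generation in the spirit of Shao et al. (2023):
        each round retrieves with the help of the previous round's draft answer
        and then regenerates the answer."""
        answer = ""
        for _ in range(rounds):
            # Use both the question and the current draft answer as the query,
            # so later rounds can look up facts that the first draft mentioned.
            query = prompt if not answer else f"{prompt}\n{answer}"
            docs = search(query, top_k)
            context = "\n\n".join(docs)
            answer = llm_complete(
                f"Documents:\n{context}\n\nQuestion: {prompt}\n"
                f"Previous draft (may contain mistakes): {answer or 'none'}\n"
                "Write an improved answer:"
            )
        return answer
    ```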

    In the example below, the arena capacity answer highlighted in red was a fact mentioned in one of the documents on the first iteration and erroneously attributed to the answer, as LLMs are prone to do; the second iteration searches again for the correct arena and corrects the mistake:

    Self-RAG (Asai et al., 2023), which is short for Self-Reflective RAG, is a good example of a further elaboration of the recursive approach. Self-RAG presents a straightforward but well-executed recursive RAG pipeline based on new special tokens:

    • first find out whether retrieval is necessary at all; for many queries, it’s not; if it is, generate a retrieval token that calls the search tool;
    • for each retrieved passage, evaluate its relevance, adding special relevant/irrelevant tokens;
    • generate outputs for each retrieved passage and then critique the outputs in terms of factuality, overall quality, and whether the generated output is supported by the retrieved passage; the results, again, are added in the form of special tokens;
    • finally, choose the best output and repeat the whole process with the best output already included in the prompt, thus enabling refined searching.

    Here is an illustration:

    The novelty here lies in the special tokens that represent various qualities of the search results and generated answers such as “Relevant”, “Irrelevant” or “Supported” (by the document). Naturally, the model would have to be fine-tuned to understand what the new tokens mean, and this is done on a synthetic dataset labeled by a separate critic model. The critic model, in turn, is trained in a supervised way based on evaluations done by a state of the art large LLM, GPT-4 in this case.
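    To show how the reflection tokens drive the control flow, here is a simplified sketch; the token names (`[Retrieve]`, `[Irrelevant]`, `[Supported]`, `[Useful]`) and the `generate` wrapper that returns both text and emitted special tokens are illustrative assumptions, not the exact tokens or code of Self-RAG.

    ```python
    def self_rag_step(question, generate, search, top_k=3):
        """Simplified Self-RAG-style control flow: the (fine-tuned) model emits
        reflection tokens that decide whether to retrieve and which candidate
        answer to keep. `generate` is a hypothetical wrapper returning an object
        with .text and .tokens (the special tokens it emitted)."""
        first = generate(question)
        if "[Retrieve]" not in first.tokens:      # model decides retrieval is unnecessary
            return first.text

        candidates = []
        for passage in search(question, top_k):
            out = generate(question, passage)     # answer conditioned on one passage
            if "[Irrelevant]" in out.tokens:      # critic-style relevance token
                continue
            # Score candidates by whether the answer is judged supported and useful.
            score = ("[Supported]" in out.tokens) + ("[Useful]" in out.tokens)
            candidates.append((score, out.text))

        return max(candidates)[1] if candidates else first.text
    ```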

    In a similar approach, Corrective Retrieval Augmented Generation (CRAG) by Yan et al. (2024) focuses on fixing hallucinations and irrelevant retrieval results, which are among the main problems of basic RAG. It runs a separate retrieval evaluator that decides whether retrieved documents are actually relevant and how likely their information is to be correct:

    Adaptive-RAG by Jeong et al. (2024) incorporates an additional classifier that chooses the necessary approach. Some queries are simple and straightforward (“Paris is the capital of what?”) and require only one retrieval step or no retrieval at all, while some are more complex and call for multi-step retrieval, and the classifier can choose the correct strategy:

    In addition to recursive refinement, there are other techniques that can improve the basic RAG model. We will not go into much detail here, but let me just mention the main directions:

    • reranking the results according to what is more useful for answering the question may improve RAG outputs (Gao et al., 2023; Blagojevich, 2023); interestingly, this also works in the opposite direction: LLMs can help improve information retrieval by serving as good rerankers for retrieval outputs (Ma et al., 2023; Peng et al., 2023);
    • just like good prompt engineering is key to getting the best answers out of LLMs, query rewriting can help obtain better retrieval results; for example, Ma et al. (2023) fine-tune a small LM to write better search queries based on the prompt and show significant improvements to the resulting question answering;
    • Google researchers Ke et al. (2024) train a separate sequence-to-sequence model that adapts the retrieved information to the LLM’s preferences by choosing specific passages from the documents; interestingly, this model is first trained with supervised learning on “silver passages” chosen greedily from the retriever outputs but then additionally fine-tuned with reinforcement learning:

    RAPTOR (Sarthi et al., 2024), which stands for Recursive Abstractive Processing for Tree-Organized Retrieval, focuses on a different problem: limitations of the retrieval results themselves and better preprocessing of them. Usually RAG only processes relatively short chunks of information from the retrieved documents near the actual search hits, although the full context of a long text would often help a lot. 

    To provide this full context, RAPTOR adds a recursive tree-like summarization step that clusters the vector embeddings of text chunks and generates text summaries of clusters that can be in turn clustered further:

    Then retrieval can be run on this tree of summaries, thus retrieving not only the actual short chunks but also the summaries of much longer texts that these chunks are part of; RAPTOR considers both retrieval that traverses the tree and that simply checks all of its nodes:

    This significantly improves the LLM’s answers for more general questions on longer texts such as “What is the central theme of the story?”. Questions like that appear in several datasets related to processing longer texts, such as:

    • NarrativeQA (Kocisky et al., 2018) devoted to question answering on fiction and movie transcripts; it has been a standard dataset for recursive summarization, including the OpenAI work by Wu et al. (2021);
    • QASPER (Dasigi et al., 2021) that focuses on research papers (in fact, papers on natural language processing);
    • QuALITY (Pang et al., 2022) that contains questions on text understanding for 5000-token long essays, including hard questions that require holistic reasoning over the whole text.

    The obvious drawback is that RAPTOR needs to build the tree of summaries for the retrieved document, so even though it is a relatively efficient step, RAPTOR is primarily suited for situations when it is obvious which specific long text you need to process to answer the question.
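    The recursive summarization step itself is easy to picture in code. Below is a simplified sketch of building a RAPTOR-style tree of summaries; `embed`, `cluster`, and `summarize` are hypothetical helpers (an embedding model, a clustering routine that partitions items into small groups, and an LLM summarization call), and the real system is considerably more elaborate.

    ```python
    def build_summary_tree(chunks, embed, cluster, summarize, max_cluster_size=5):
        """Simplified RAPTOR-style tree: cluster chunk embeddings, summarize each
        cluster with the LLM, and recurse on the summaries until a root remains.
        All nodes (leaves and summaries) are kept so retrieval can search them all."""
        tree_nodes = list(chunks)
        level = list(chunks)
        while len(level) > max_cluster_size:
            vectors = [embed(text) for text in level]
            # `cluster` is assumed to partition indices into groups of a few items each.
            groups = cluster(vectors, max_cluster_size)
            new_level = [summarize([level[i] for i in group]) for group in groups]
            if len(new_level) >= len(level):   # safety: clustering made no progress
                break
            level = new_level
            tree_nodes.extend(new_level)
        tree_nodes.append(summarize(level))    # root summary of the whole document
        return tree_nodes
    ```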

    Finally, let me note a recent paper by Google researchers Wang et al. (2024). Their Speculative RAG system makes use of two different LLMs: a “specialist LLM” designed to answer questions based on specific documents and a “generalist LLM” used to combine the results of the specialist LLM run on different retrieved documents. Instead of dumping retrieved documents into the LLM context, Speculative RAG uses the specialist LLM to make several drafts of the response together with rationales for them, each based on different retrieved documents, and then the generalist LLM can choose the best answer or combine them as it sees fit (illustration from Wang et al., 2024):

    This avoids the problems related to position bias in long context (see below) and allows the generalist LLM to better incorporate different perspectives on the question. The specialist LLM may be weaker (and hence smaller and more efficient) than the generalist LLM since it only needs to process a couple of documents and answer the question directly based on information from them.
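    A rough sketch of this drafter/verifier split is shown below; `drafter` and `verifier` are hypothetical LLM calls standing in for the specialist and generalist models, and the prompts are illustrative rather than the authors’ actual templates.

    ```python
    def speculative_rag(question, docs, drafter, verifier):
        """Speculative-RAG-style sketch: a smaller `drafter` LLM writes one draft
        answer with a rationale per retrieved document, and a larger `verifier`
        LLM sees only the short drafts and picks (or merges) the best one."""
        drafts = []
        for doc in docs:
            draft = drafter(
                f"Document:\n{doc}\n\nQuestion: {question}\n"
                "Answer the question and explain your rationale."
            )
            drafts.append(draft)

        # The generalist model never sees the full documents, only the drafts.
        numbered = "\n\n".join(f"Draft {i+1}: {d}" for i, d in enumerate(drafts))
        return verifier(
            f"Question: {question}\n\n{numbered}\n\n"
            "Choose the most reliable draft (or merge them) and give the final answer."
        )
    ```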

    In general, RAG methods are being developed in a number of different exciting directions, mostly related to evaluation and reranking of retrieval results, recursive refinement of search queries, better processing of retrieved documents, and criticizing and refining the LLM outputs.

    The R[e]ALM of RAGs: A Tighter Integration

    In addition to RAG, there are other keywords related to retrieval-based improvements for language models; the most widely used is RALM (also known as ReALM, Retrieval-Augmented Language Model). Unfortunately, the terminology does not seem to be clearly defined yet: some sources list RAG as a subset of the wider term RALM, others seem to define RALM as using retrieval only on the training set while RAG can use external sources, and yet others view RALM as “RAG 2.0” that further advances the basic ideas of RAG.

    This confusion probably stems from the fact that while R[e]ALM sounds like a very general keyword, the original paper that introduced REALM (Guu et al., 2020) indeed used retrieval only on the training set, as a separate part of the architecture learned during pretraining. The point of REALM was to have two different networks, a knowledge retriever and an encoder. During pretraining, the retriever looks for documents that might help in solving the masked language modeling task and supplies the results to the encoder, which is trained like a regular BERT. During supervised fine-tuning and then inference, the retriever looks for the most relevant document for the query and again provides it to the BERT-like encoder:

    In this section, we will review a line of work that started from REALM; this is the specific direction that I call RALM here: a language model architecture with a retrieval mechanism embedded into the model and probably trained together with the LM itself, a mechanism that helps in training as well as during inference.

    DeepMind researchers Borgeaud et al. (2022) introduced the Retrieval-Enhanced Transformer (RETRO) that incorporates retrieval directly into the Transformer decoder. They turn the training dataset (in their case, the Pile dataset with about 2 trillion tokens; Gao et al., 2020) into a retrieval index based on BERT embeddings of subsequences of tokens. For a given chunk, the retrieval engine outputs its nearest neighbors together with their continuations in the corresponding documents from the corpus. The results are encoded by the Transformer encoder (part of the trained model) and then attended to by the Transformer decoder in the model:

    Borgeaud et al. report that they were able to achieve performance on par with GPT-3 on the Pile dataset while using 25x fewer parameters; this was the first work to scale retrieval-augmented LLMs to corpora with trillions of tokens and to training GPT-3-sized models.

    FAIR researchers Lin et al. (2024) recently continued this line of work with the RA-DIT framework, which stands for Retrieval-Augmented Dual Instruction Tuning. RA-DIT does not train the LLM together with the retriever from scratch; instead, it uses supervised fine-tuning (recall our previous post) to make the LLM better use retrieval results while at the same time fine-tuning the retriever to better serve the LLM. Here is an illustration from Lin et al. (2024):

    In this way, RA-DIT combines a pretrained LLAMA model (Touvron et al., 2023a; 2023b) and a state of the art DRAGON+ retriever (Lin et al., 2023) but makes both of them mesh together better in fine-tuning. The fine-tuning process, by the way, can serve as a good illustration, so let us discuss it in more detail. When RA-DIT produces an answer \mathbf{y} for an input prompt \mathbf{x}, its output probabilities are weighted as

        \[p_{\mathrm{LM}}(\mathbf{y} | \mathbf{x},\mathcal{C}) = \sum_{\mathbf{c}\in \mathcal{C}}p_{\mathrm{LM}}(\mathbf{y} | \mathbf{c}\circ \mathbf{x}) p_{\mathrm{R}}(\mathbf{c}| \mathbf{x}),\]

    where \mathcal{C} is the set of retrieved text chunks \mathbf{c}, p_R(\mathbf{c}|\mathbf{x}) is the probability the retriever assigns to chunk \mathbf{c}, and p_{\mathrm{LM}}(\mathbf{y}|\mathbf{c}\circ\mathbf{x}) is the probability the language model assigns to \mathbf{y} given the prompt of \mathbf{c} concatenated with \mathbf{x}. Retriever probabilities p_R(\mathbf{c}|\mathbf{x}) are produced via softmax from retriever scores s(\mathbf{x}, \mathbf{c}), which are just dot products of the query’s and document’s embeddings.
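    In code, this mixture is just a softmax over retriever scores followed by a weighted sum; here is a small NumPy sketch, where the per-chunk LM probabilities and the dot-product retriever scores are assumed to be given.

    ```python
    import numpy as np

    def retrieval_augmented_prob(lm_probs, retriever_scores):
        """p_LM(y | x, C) = sum_c p_LM(y | c . x) * p_R(c | x),
        where p_R is a softmax over retriever scores.

        lm_probs:         array of p_LM(y | c_i . x) for each retrieved chunk c_i
        retriever_scores: array of scores s(x, c_i) (dot products of embeddings)
        """
        scores = np.asarray(retriever_scores, dtype=float)
        p_r = np.exp(scores - scores.max())
        p_r /= p_r.sum()                      # softmax over retrieved chunks
        return float(np.dot(np.asarray(lm_probs, dtype=float), p_r))

    # Example: three retrieved chunks with different usefulness for the answer.
    print(retrieval_augmented_prob([0.8, 0.1, 0.2], [2.0, 0.5, 0.1]))
    ```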

    This idea carries over to the supervised fine-tuning process. Supervised training on (\mathbf{x}, \mathbf{y}) pairs is done separately for the LM and the retriever:

    • the language model is fine-tuned to produce the correct answer \mathbf{y} on all top-k responses from the retriever:

          \[\mathcal{L}_{\mathrm{LM}}(D) = -\sum_{n=1}^N\sum_{j=1}^{k}\log p_{\mathrm{LM}}(\mathbf{y}_n | \mathbf{c}_{nj}\circ \mathbf{x}_n);\]

    • for the retriever, they use a version of the LSR technique (LM-supervised retrieval; Shi et al., 2023), where the retriever is trained to produce the most helpful outputs for the language model; given top-k retrieval results \mathbf{c}_j, the language model gives the likelihood of \mathbf{y} given \mathbf{c}_j\circ\mathbf{x}, which can be turned into a distribution over \mathbf{c} via softmax:

          \[p_{\mathrm{LSR}}(\mathbf{c}| \mathbf{x}, \mathbf{y})=\frac{e^{\frac{1}{\tau}p_{\mathrm{LM}}(\mathbf{y}|\mathbf{c}\circ\mathbf{x})}}{\sum_{\mathbf{c}'\in C}e^{\frac{1}{\tau}p_{\mathrm{LM}}(\mathbf{y}| \mathbf{c}'\circ\mathbf{x})}}\approx \frac{e^{\frac{1}{\tau} p_{\mathrm{LM}}(\mathbf{y}|\mathbf{c}\circ\mathbf{x})}}{\sum_{j=1}^ke^{\frac{1}{\tau}p_{\mathrm{LM}}(\mathbf{y}|\mathbf{c}_j\circ\mathbf{x})}},\]

      so the loss function for the retriever is the Kullback-Leibler divergence between p_{\mathrm{LSR}} and p_{\mathrm{R}}:

          \[\mathcal{L}_{\mathrm{R}}(D) = \mathbb{E}_{D}\left[\mathrm{KL}\left(p_{\mathrm{R}}\left(\mathbf{c}\,\middle|\,\mathbf{x}\right)\,\middle\|\,p_{\mathrm{LSR}}\left(\mathbf{c}\,\middle|\,\mathbf{x},\mathbf{y}\right)\right)\right].\]

    This kind of alternating training is common in systems that consist of two or more well-defined parts: we train one part of the system by fixing the weights of all others and then do the same with the other parts. Alternating training is often formalized as optimizing a variational lower bound since the losses may have complex interdependencies; we have seen it, for instance, in training DALL-E (recall our post on it). In this case, since p_{\mathrm{LSR}} has a very nontrivial dependence on p_{\mathrm{LM}}, optimizing the combined loss \mathcal{L}_{\mathrm{LM}}(D)+\mathcal{L}_{\mathrm{R}}(D) with respect to the language model becomes intractable. Breaking an intractable loss function into tractable components is exactly what variational approximations are for, but in this case neither Shi et al. (2023) nor Lin et al. (2024) provide a derivation for it.

    In my opinion, this kind of fusion between the retrieving mechanism and the language model can no doubt help further improve retrieval in joint training. The question is whether fine-tuning will remain necessary at all as LLMs progress further: retrieval will probably always be necessary but we humans don’t have to undergo joint training with Google retrievers to benefit from the search. On the other hand, it’s not like we tried—maybe that lies in the future as well?..

    Lost in the Middle: Problems with RAG and RALM

    Despite a lot of progress outlined above, there still are problems associated with the use of RAG. One of the most important problems is actually not directly related to RAG but rather to long contexts in general: the larger the context, the harder it is for the LLM to find the “needle in the haystack”.

    In RAG, this problem usually takes the form of the “Lost in the Middle” effect recently found by Liu et al. (2024): if the LLM receives many retrieved documents as input, and the necessary information is contained in only one of them, performance will significantly depend on which document in the list contains it.

    Liu et al. formalized this point in the multi-document question answering problem illustrated in figure (a) above: the LLM is allowed to use several documents somewhat related to the question but only one of them actually contains the answer. The results are shown in figure (b) above: if the answer is in the first few documents, the LLM’s accuracy is much higher than if it is in the middle, and the saliency of the answer rises again at the end of the context. Liu et al. (2024) showed that this effect is consistently exhibited by several leading LLMs, appears as soon as the total input length exceeds the sequence length used in training the encoder (for encoder-decoder models such as Flan-T5-XXL; Chung et al., 2024) and does not go away if you change the placement of the query compared to the retrieved documents or do instruction fine-tuning.
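    The experimental setup is easy to reproduce in spirit: put the one document that actually contains the answer at every possible position among the distractors and record the accuracy by position. Below is a sketch under that assumption; `llm_answer` and `is_correct` are hypothetical helpers, not the code of Liu et al.

    ```python
    def position_sweep(question, gold_doc, distractor_docs, llm_answer, is_correct):
        """Sketch of the multi-document QA setup of Liu et al. (2024): sweep the
        position of the gold document and record correctness; plotting accuracy
        by position reveals the 'lost in the middle' dip."""
        results = []
        for pos in range(len(distractor_docs) + 1):
            docs = distractor_docs[:pos] + [gold_doc] + distractor_docs[pos:]
            context = "\n\n".join(f"Document {i+1}: {d}" for i, d in enumerate(docs))
            answer = llm_answer(f"{context}\n\nQuestion: {question}")
            results.append((pos, is_correct(answer)))
        return results
    ```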

    This specific problem will most probably be resolved by progress in the LLMs themselves. Soon after the publication of Liu et al. (2024), Google released Gemini 1.5, and the corresponding paper was called “Unlocking multimodal understanding across millions of tokens of context” (Gemini Team, 2024). The authors showed that Gemini 1.5 has near-perfect retrieval for a variety of very long context tasks; we discussed this model and generally ways to extend the context for LLMs in a previous post.

    However, long context does not solve all problems by itself. Another important problem is redundancy: when you search the Web for something specific, the documents tend to repeat themselves and can saturate any context window. If the repeated documents provide information that’s not relevant to the question at hand, the LLM has a high probability of getting confused by sheer repetition.

    I would like to note here that “needle in a haystack” benchmarks such as the ones used by the Gemini team are looking for very specific information, which may be present only in one specific part of a very long context. Here is a sample from the Gemini 1.5 demo on video understanding:

    Don’t get me wrong, this is truly an impressive achievement. But the problem here is basically retrieval from context; these kinds of tasks do not involve any generalization or intelligent processing of significant portions of the context. While extra hay makes it harder to find the needle, the question of distinguishing hay from the needle is relatively simple. I wonder what Gemini would say if the question was to “highlight specific influences of Sherlock Jr. on The Purple Rose of Cairo”, a Woody Allen movie with a similar premise, based on the movies themselves rather than critical reviews that had already pointed out the similarities.

    For such involved questions, an even more problematic fact is that the knowledge coming from RAG is unstructured. Videos aside, even a regular text-based RAG would usually result in a collection of text snippets that often repeat each other, contain irrelevant extra information or represent retrieval mistakes, i.e., completely irrelevant documents. If you ever tried to learn a completely new field based on the results of a Google search, you know how hard it may be to make sense of this “haystack” as a whole rather than just find the exact trivia “needle” you’re looking for.

    For many questions, a more structured way to present information would be both preferable and easily available. To learn (a little) more about Sherlock Jr. I went straight to Wikipedia and never even tried to actually watch the movie, read contemporary critical reviews, Buster Keaton’s memoirs, or other sources that might present themselves: that would take way too much time.

    Recently, another very interesting direction of study has appeared that may alleviate these problems at least to some extent. Let us discuss it in the next section.

    RAG + Knowledge Graphs = GraphRAG

    GraphRAG is a direction of study where retrieval queries are run against a knowledge graph and return facts rather than text snippets (see, e.g., a very recent survey by Peng et al., Aug 2024). We have not discussed knowledge graphs on the blog, so this warrants some elaboration.

    A knowledge graph (Hogan et al., 2022; Ji et al., 2021; Heist et al., 2020; Yan et al., 2018) is a, well, graph with directed edges and labels on both nodes and edges. A directed edge in the knowledge graph represents a fact defined as a (subject, predicate, object) triple such as (GPT-4, IsA, large language model) or (Sam Altman, CEOOf, OpenAI). The subject and object are the source and sink nodes and the edge between them is labeled with the relation.

    The expressive power of knowledge graphs comes from the fact that relations can be arbitrary, and with a proper choice of relations you can fit almost any factual knowledge into a set of triples. Here is an example from Ji et al. (2021):

    If you have a knowledge graph, gathering information about an entity or answering questions about relations between entities (even complex relations that correspond to multi-hop paths rather than edges) becomes a matter of traversing the graph, a much easier and more reliable task than reading and understanding unstructured text.
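    A toy example makes this concrete: with triples stored as labeled edges, a multi-hop question becomes a short traversal. The class below is a minimal sketch using the sample triples from above, not any real knowledge graph library.

    ```python
    from collections import defaultdict

    class KnowledgeGraph:
        """A toy triple store: facts are (subject, predicate, object) edges,
        and multi-hop questions become graph traversals."""
        def __init__(self):
            self.edges = defaultdict(list)          # subject -> [(predicate, object)]

        def add(self, subj, pred, obj):
            self.edges[subj].append((pred, obj))

        def follow(self, subj, path):
            """Follow a sequence of relations from `subj`, e.g. a 2-hop query."""
            frontier = {subj}
            for pred in path:
                frontier = {o for s in frontier for (p, o) in self.edges[s] if p == pred}
            return frontier

    kg = KnowledgeGraph()
    kg.add("Sam Altman", "CEOOf", "OpenAI")
    kg.add("OpenAI", "Developed", "GPT-4")
    kg.add("GPT-4", "IsA", "large language model")

    # "What did the company whose CEO is Sam Altman develop?"  (a 2-hop traversal)
    print(kg.follow("Sam Altman", ["CEOOf", "Developed"]))   # {'GPT-4'}
    ```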

    To be honest, knowledge graphs are a personal favorite of mine. They provide a very easy and very general way to structure knowledge that makes it much easier to make logical inferences. Huge knowledge graphs based on human-verified information are already available, including:

    There also exists a wide field of study on automated and semi-automated construction of knowledge graphs from unstructured data (Zhong et al., 2023; Hofer et al., 2023). I have always thought knowledge graphs are underutilized in machine learning; despite the huge literature devoted to them (see the surveys linked above), I believe they could be put to even better use as repositories of structured information that is usually much more coherent than raw text.

    Before returning to RAG, let me note several different ways knowledge graphs have already been used together with LLMs. A notable entry here is the ERNIE family of models by Baidu (Sun et al., 2019; Sun et al., 2020; Xiao et al., 2020; Sun et al., 2021), recently made into the Ernie Bot that has reached hundreds of millions of users in China. Starting from the very first model, ERNIE, which stands for “Enhanced Representation through kNowledge IntEgration”, used knowledge graphs to improve the pretraining tasks, enriching the semantics of masking. In ERNIE 1.0 (Sun et al., 2019), this meant that the BERT masks were generated to cover whole entities. In the example below, instead of just filling in “___ Potter” or “J. ___ Rowling” as a random mask would suggest, phrase-level masking forces the network to actually learn the relationship between these entities:

    In subsequent versions of ERNIE, this idea was extended to universal knowledge-text prediction that combines a knowledge graph and text snippets; given a triple from the graph and the corresponding text, the model is asked to restore parts of each. Here is an illustration (Sun et al., 2021):

    When the LLM has already been trained, knowledge graphs can be used to improve its reasoning abilities and ground the LLM’s answers in verified knowledge, possibly reducing hallucinations (Wang et al., 2023; Chen et al., 2024). Several works develop special neural architectures for that. Approaches before the advent of LLMs usually employed graph neural networks (Ren et al., 2020; Ren, Leskovec, 2020), but now the emphasis has shifted. For example, the JointLK model (Sun et al., 2022) introduces new attention modules that can attend both to a sequence of vectors, like a regular Transformer-based LM, and to parts of the knowledge graph, like a GNN:

    These days, of course, it may not be necessary to train a novel architecture: an LLM may be used “as is” with some external scaffolding of knowledge graph retrieval and prompting. Without going into too much detail, here is one example of using the so-called chain-of-knowledge prompting (Wang et al., 2023), a process that expands and significantly improves the “chain of thought” reasoning common for LLMs:

    As you can see, the knowledge graph is used as a source of reliable information that LLM outputs and hypotheses can be checked against. There exist many similar approaches (Zheng et al., 2024; Agrawal et al., 2023), but a detailed survey should probably wait for a post devoted to chain of thought reasoning and generally writing good prompts for LLMs (which, I hope, will appear in the future).

    The fruitful relationship between knowledge graphs and LLMs also goes in the opposite direction: it is a very natural idea to use LLMs to automatically construct knowledge graphs from unstructured text. One of the first such ideas, COMET (Bosselut et al., 2019; illustrated in (a) in the figure below), used GPT-2 to create new knowledge graph triples from few-shot prompts. BertNet (Hao et al., 2022; (b) in the figure below) starts from a definition of a relation and a few examples, recursively refines the prompts with new paraphrases of the definition, and then uses the prompts to search for entity pairs that have this relation:

    The works by Zhu et al. (2023) and Yu et al. (2023) discuss the possibility of an end-to-end automated knowledge graph construction framework based on modern LLMs. They do achieve some success but also highlight some problems that still prevent a full-scale solution, including lack of context, hallucinations, and more. Similar problems have been encountered when applying LLMs to fully automate other knowledge extraction tasks such as named entity recognition (Wei et al., 2023) and event extraction (Gao et al., 2023), where state of the art LLMs do a decent job but do not outperform specially developed solutions. On the other hand, both of these works use the original ChatGPT and predate the release of GPT-4, let alone current models, so the situation may already be different.
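    The basic recipe behind these LLM-based construction attempts is few-shot prompting for triple extraction. Here is a minimal sketch under that assumption; the prompt, the example triples, and the `llm_complete` helper are illustrative, and real systems add entity linking, deduplication, and validation on top of this step.

    ```python
    def extract_triples(text, llm_complete):
        """Few-shot prompting sketch for LLM-based knowledge graph construction:
        ask the model to emit (subject; predicate; object) triples, one per line,
        then parse them back into tuples."""
        prompt = (
            "Extract factual triples from the text as (subject; predicate; object), "
            "one per line.\n\n"
            "Text: Marie Curie won the Nobel Prize in Physics in 1903.\n"
            "(Marie Curie; won; Nobel Prize in Physics)\n"
            "(Nobel Prize in Physics; awarded in; 1903)\n\n"
            f"Text: {text}\n"
        )
        triples = []
        for line in llm_complete(prompt).strip().splitlines():
            parts = line.strip("() ").split(";")
            if len(parts) == 3:                      # keep only well-formed triples
                triples.append(tuple(p.strip() for p in parts))
        return triples
    ```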

    But let us get back to the main topic of this post. When applied to RAG, retrieving structured triplets may allow an LLM to give much more detailed and precise answers, especially when they have to uncover relations between different entities (which is very often the case). Here is a sample illustration by Peng et al. (2024):

    As you can see, retrieving structured facts can make it much easier to form further deductions and generally process the facts. 

    To perform the retrieval itself, one can again rely on graph neural networks (GNN) that we have already mentioned. Naturally, you can treat the knowledge graph as a completely separate modality, but there also exist unified approaches. For example, the QA-GNN approach (Yasunaga et al., 2021) uses an LLM to produce a context vector and then plugs it into the GNN for knowledge graph reasoning:

    For a recent example of a KG-based retrieval framework, let me highlight Reasoning on Graphs (RoG) developed by Luo et al. (2024). In RoG, the LLM first generates several relation paths that might be useful to answer the question, then these paths are grounded in the available knowledge graph, and finally the retrieved results are again processed by the LLM to produce the final answer:

    As a result, RoG lets the LLM gather additional information necessary to answer even in cases when it is not obvious which information is needed, and also to avoid hallucinations along the way. Moreover, RoG can also show the reasoning paths, which greatly improves interpretability: now we can immediately see the chain of factual reasoning behind the LLM’s answer. Here are two examples that Luo et al. (2024) give in their work:

    In general, humanity has already collected a lot of knowledge in structured and verified form, so I am sure that using this structured knowledge and probably even preferring it over unstructured text (if structured knowledge is available, of course) is an obvious step that can improve AI systems in general.

    In-Context RALM: Just Use the LLM

    What if all that has been too difficult for your liking? You can always go back to a simple alternative that we started with: let’s just extend the LLM’s context with everything retrieval tells us and hope that the LLM can sort it out. The better the LLMs become, the more we can rely on this hope.

    In-context RALM (Ram et al., 2023) proposes to do exactly this. Their pipeline is as simple as they come: use an external retriever, collect all retrieved documents, append them to the prompt and let the language model sort it out. Like this:

    Note that this is in fact RALM rather than just RAG: the retrieved documents are appended to the autoregressive generation input, so a given next token is conditioned on both already generated tokens and retrieved texts. Ram et al. (2023) rerun retrieval once every s tokens, where s is the retrieval stride; their experiments show that using small values of s, while increasing retrieval costs, actually does improve the results, and in the main experiments they use s=4, running retrieval every four tokens.
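    The whole scheme fits in a few lines; below is a sketch of generation with a retrieval stride, where `search` and `generate_tokens` are hypothetical helpers (the latter assumed to return the next few tokens as a string), and the details differ from the actual implementation of Ram et al. (2023).

    ```python
    def in_context_ralm(prompt, search, generate_tokens,
                        max_tokens=64, stride=4, top_k=2):
        """In-context RALM sketch: rerun retrieval every `stride` generated tokens,
        using the most recent part of the prefix as the query, and keep the
        retrieved documents prepended to the generation input."""
        generated = ""
        while len(generated.split()) < max_tokens:
            # Use the last tokens of the prefix (prompt + generation so far) as the query.
            query = " ".join((prompt + " " + generated).split()[-32:])
            docs = "\n".join(search(query, top_k))
            context = f"{docs}\n\n{prompt}{generated}"
            generated += " " + generate_tokens(context, stride)   # next `stride` tokens
        return generated
    ```

    Smaller strides mean more retrieval calls, but, as the experiments in the paper show, also better perplexity; the sketch makes it clear why the extra cost is purely on the retrieval side.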

    The authors show very significant improvements in token prediction perplexity across the board, for a number of different LLMs and different retrievers. Another recent work from the same group shows that this approach can significantly improve factuality, reducing hallucinations and producing more factually supported continuations (Muhlgay et al., 2023). So even if you do not have the time or resources to fine-tune new models or develop custom architectures, retrieval can improve your LLM’s output even in this default form.

    Conclusion

    So what does the future have in store for RAG? First, I want to highlight again that large context windows and RAG are both important tools that solve different problems, and one does not make the other obsolete. As base LLMs grow to be more capable, the role of RAG might shift from being a necessary tool to overcome context size limitations to an optimization tool that enhances efficiency, relevance, and scalability, but it will remain relevant anyway.

    On the one hand, even a huge context window will never contain the entire Internet or even the entire history of your emails and documents on your computer. On the other hand, a longer context window will enable RAG to work better: if you can afford to carefully read the top 100 Google search results for several different queries rather than only the top 5, your resulting answer will be much better informed. Agentic approaches that gradually refine the query and maybe formulate other related queries for retrieval also keep getting better and will no doubt become an integral part of smart AI assistants.

    Second, the internal structure of RAG might change. I currently view GraphRAG that we have discussed above as a very promising approach: triplets extracted from knowledge graphs are a natural representation of knowledge, and this whole field looks like a good “marriage” between knowledge already existing in a more formalized way than just text and LLMs.

    Third, we have not really discussed multimodal RAG in any detail: so far it appears to be a rather straightforward application of existing representation learning approaches for other modalities but this can also change in the near future.

    Fourth, some applications require time-sensitive retrieval as new relevant information may appear and replace old info. A simple example here would be a financial advisor AI that needs to operate with current stock prices or a personal AI assistant that continuously gathers new updates from your social media and summarizes them for you.

    But whatever the future brings, I believe that RAG will always remain a natural and important component of AI systems; not even the LLMs of the far future will be able to fit all of the world’s data into their context, and they will need some mechanism for sieving through this data other than just reading it token by token. In general, while an LLM is the central part of many AI systems, it is not a be-all and end-all model for everything: it needs a variety of other tools and subsystems to obtain the necessary information. Next time, we will discuss another important component of modern LLM-based solutions.

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • Fine-Tuning LLMs: RLHF, LoRA, and Instruction Tuning

    Fine-Tuning LLMs: RLHF, LoRA, and Instruction Tuning

    We continue our series on generative AI. We have discussed Transformers, large language models, and some specific aspects of Transformers – but are modern LLMs still running on the exact same Transformer decoders as the original GPT? Yes and no; while the basics remain the same, there has been a lot of progress in recent years. Today, we briefly review some of the most important ideas in fine-tuning LLMs: RLHF, LoRA, instruction tuning, and recursive self-improvement. These ideas are key in turning a token prediction machine into a useful tool for practical applications.

    From GPT to GPT-4: What Has Been Changing?

    For over a year, I have been writing about generative AI on this blog. Recently, we have discussed the basic architecture of this latest generative AI revolution: the Transformer. We have also considered modern LLMs and even reasons to worry about their future development, and we have discussed in detail one specific venue of progress: how to extend the context window size in Transformers, alleviating the quadratic complexity of self-attention.

    But has this fight for context windows been the entire difference between the original Transformer and the latest GPT-4, Gemini 1.5, and the rest? Is there anything else except for “stacking more layers”? Sure there is, and today we discuss it in more detail.

    Before proceeding further, I have to warn you that the new ideas and especially engineering implementation details of the very latest large language models are not being released to the public. There is no definitive paper about GPT-4’s internal structure (let alone plans for GPT-5) written by OpenAI researchers. Still, there are plenty of ideas floating around, and plenty of information already available from previous attempts by leading labs, from publicly released models such as the Llama family, and from independent research efforts. 

    So while I’m not claiming to show you the full picture today, I still hope to give a wide enough survey. Our plan is as follows:

    • we begin with the most important advance that made GPT-3 into ChatGPT and kickstarted the LLM revolution: reinforcement learning with human feedback (RLHF);
    • then we discuss fine-tuning pretrained LLMs on small datasets by learning adapters for the mostly frozen weights of the base models; the most popular and efficient tools here have been low-rank adapters (LoRA);
    • next, we consider instruction tuning, where LLMs are themselves fine-tuned on datasets providing realistic prompt-response samples rather than token prediction in arbitrary text; the main question here is where to get the data, and we discuss both manually labeled datasets and our main focus, synthetic data;
    • finally, synthetic data for LLM fine-tuning and reinforcement learning come together in our last topic: attempts at recursive self-improvement where the LLM may become smarter by bootstrapping from its own outputs.

    All of these techniques, and more, are key to efficiently using LLMs for practical problems, especially for specific applications such as mathematical reasoning or programming; we will see many such examples below.

    Giving Humans What They Want: RLHF

    You have certainly heard of reinforcement learning with human feedback (RLHF). This is the secret sauce that turned GPT-3, an amazing but hard to use token prediction machine, into ChatGPT, an LLM that keeps turning the world upside down. But how does it work, exactly?

    We don’t often talk about reinforcement learning (RL) on this blog; probably the only notable exception was my last post on world models, where RL was featured very prominently. In general, it is a separate way of doing machine learning, in addition to supervised and unsupervised learning:

    In supervised learning, you have a labeled dataset and want to learn a conditional distribution of labels given the data points. In unsupervised learning, there are no labels, you just mine the data for structure, learning the joint distribution of all variables. For example, token prediction is pure classification, a supervised learning problem of learning p(y|x) for a text prompt x and next token y, but we can also say that as a result, the language model has implicitly learned a distribution over text snippets p(x) because it can generate whole texts in an autoregressive fashion. 

    In reinforcement learning, there is no prior dataset: a learning agent is just “living” in an environment, getting rewards based on actions that it takes and trying to maximize these rewards. In the last post, we discussed the distinction between several different approaches to RL such as policy gradient and actor-critic algorithms:

    But be it with a world model or without, RL and training an LLM sound very different, right?

    RLHF started with the work of OpenAI researchers Christiano et al. (2017). Paul Christiano is one of the leading figures in the field of AI alignment, and this work was also motivated by a problem that sounds more like alignment: how do we tell, for instance, a robot what exactly we want it to do? Unless we are in a self-contained formal system such as chess, any reward function that we could formulate in the real world might be superficially optimized in ways that are hard to predict but that do not give us what we want. It is well known, for example, that robots learning complex behaviors in simulated environments often learn more about the bugs and computational limits of the simulator than about the desired behavior in the real world; for more details see, e.g., Lehman et al., 2018 or a list of specification gaming examples by Krakovna et al.

    Thus, Christiano et al. suggested that since we most probably cannot define what we want formally, we can instead ask a human: when you see it, you know it. Human feedback would define how well the system’s current behavior matches the actual hard-to-define goal; that feedback might be provided in the form of comparing two responses or two outcomes and preferring one of them. This approach, however, is impractical: we cannot ask humans to label as much data as actually necessary to train a reinforcement learning model. Therefore, the idea of Christiano et al. is to train a separate model that encodes user preferences and predicts the reward used in actual RL training. Here is a general scheme of this training:

    The human providing feedback cannot assign a numerical reward value, so instead they compare pairs of “actions”—in the case of Christiano et al., actions were short sequences of Atari game playing or a robot walking—and give pairwise preferences. As a result, the dataset looks like a set of pairs D= \{(\sigma_1,\sigma_2,\mu)\}_{n=1}^N, where \sigma_i = \left((o_{i0},a_{i0}), (o_{i1},a_{i1}), \ldots, (o_{i,k_i},a_{i,k_i})\right) are sequences of observation-action pairs that describe a trajectory in the reinforcement learning environment, and \mu is a probability distribution specifying whether the user preferred \sigma_1, \sigma_2, or had an equal preference (uniform \mu).

    To convert pairwise preferences into a reward function, this approach uses the assumptions of Bradley–Terry models for learning a rating function from pairwise preferences (Bradley, Terry, 1952). The problem setting for a Bradley–Terry model is a set of pairwise comparisons such as the results of, e.g., chess games between players, and the basic assumption is that the probability of player i winning over player j can be modeled as

        \[{\hat p}(i\succ  j) = \frac{\gamma_i}{\gamma_i+\gamma_j}\]

    for some rating values \gamma_i, \gamma_j\in\mathbb{R}. Then Bradley–Terry models provide algorithms to maximize the total likelihood of a dataset with such pairwise comparisons, usually based on minorization-maximization algorithms, a generalization of the basic idea of the EM algorithm (Hunter, 2004; see also a discussion of EM below).

    In the case of RL from human preferences, we need a further assumption since \gamma_i has to be a function of \sigma_i. Christiano et al. (2017) modeled it as the exponential of the total predicted reward over the sequence:

        \[\gamma(\sigma_i) = e^{\sum_{t=1}^{k_i}{\hat r}(o_{it},a_{it})},\]

    and then the loss function for the neural network can be defined as

        \[\mathcal{L} = -\sum_{(\sigma_1,\sigma_2,\mu)\in D}\left(\mu(1)\log {\hat p}(\sigma_1\succ \sigma_2) + \mu(2)\log {\hat p}(\sigma_2\succ \sigma_1)\right).\]
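
    If you prefer code to formulas, here is the same pairwise loss as a minimal PyTorch-style sketch (a simplified illustration, not the authors’ implementation); reward_model is a hypothetical network that maps an observation-action pair to a scalar reward:

        import torch
        import torch.nn.functional as F

        def preference_loss(reward_model, sigma1, sigma2, mu):
            """Pairwise preference loss in the style of Christiano et al. (2017).
            sigma1, sigma2: lists of (observation, action) tensors (trajectory segments);
            mu: tensor [mu1, mu2] with the human preference distribution over the segments."""
            # total predicted reward of each segment; gamma = exp(sum of rewards)
            r1 = torch.stack([reward_model(o, a) for o, a in sigma1]).sum()
            r2 = torch.stack([reward_model(o, a) for o, a in sigma2]).sum()
            # P(sigma1 > sigma2) = exp(r1) / (exp(r1) + exp(r2)), i.e., a softmax over [r1, r2]
            log_p = F.log_softmax(torch.stack([r1, r2]), dim=0)
            return -(mu[0] * log_p[0] + mu[1] * log_p[1])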

    It might seem that this idea just shifts the impractical part of providing human feedback during RL training to an equally impractical task of providing enough human feedback to train a reward prediction model. However, it turned out that with this approach, it only takes a few hundred queries to a human rater to learn walking or hopping in the MuJoCo simulated environment (Todorov et al., 2012; see a sample choice posed for the human evaluator on the left in the figure below), and if you are willing to go over 1000 queries you might even get better results than pure reinforcement learning! The latter effect is probably due to reward shaping (Wiewiora, 2010): when we humans rate behaviors, we impose an ordering where sequences closer to the goal are rated higher, and the resulting rewards provide more information to the agent than just a binary label of whether the task has been done successfully.

    By the way, this work also contains a very interesting example of reinforcement learning gone rogue. On the right, the figure above shows a sample frame from a video showing the robotic hand trying to grasp the ball. Human evaluators were asked to check whether the grasping had been successful. But since the scene had only one virtual camera, and with such a uniform background depth estimation was hard for humans, the robot learned to position the hand between the ball and the camera so as to appear as if it is grasping the ball rather than actually doing it! This is an excellent example of what is known as specification gaming, when machine learning models converge on behaviors that had not been intended by the developers but that indeed optimize the objective function they specified; we have talked about possible problems resulting from such effects on the blog before.

    The ideas of Christiano et al. have been continued in many works. In particular, there have been extensions to k-wise comparisons, with a specially developed maximum likelihood estimator (Zhu et al., 2023), to vague feedback, where a human evaluator can only reliably distinguish two samples if their quality differs significantly (Cai et al., 2023), and to multi-agent systems (Ward et al., 2022). On the other hand, this direction of research can be placed in a theoretical framework of preference-based reinforcement learning (PbRL), where reward values are replaced with preferences (Fürnkranz et al., 2012; Jain et al., 2015; Wirth et al., 2017; Xu et al., 2020).

    But the most important continuation was, of course, RLHF itself, an application of deep RL from human preferences to large language models. The first step was taken in 2020, when OpenAI researchers Stiennon et al. (2020) developed a summarization model based on human feedback. Their approach, illustrated in the figure below, is very similar: they collect human feedback on which document summaries are better, train a reward model to match these preferences, and then use the reward model to fine-tune with reinforcement learning.

    For training the reward model, they change the loss function we have discussed above to a classification loss based on the logistic sigmoid:

        \[\mathcal{L} = -\sum_{(\mathbf{x},\mathbf{y}_1,\mathbf{y}_2,\mu)\in D}\log\left(\sigma\left({\hat r}(\mathbf{x},\mathbf{y}_{\mu}) - {\hat r}(\mathbf{x},\mathbf{y}_{3-\mu})\right)\right),\]

    where \mathbf{y}_1 and \mathbf{y}_2 are two summaries of the text \mathbf{x}, and \mu\in\{1,2\} indicates which one the user prefers. For reinforcement learning, they used proximal policy optimization (PPO), a standard policy gradient RL algorithm that we will not describe here in detail; see, e.g., (Schulman et al., 2017; Sutton, Barto, 2018; Zheng et al., 2023).

    One important remark here is that if the reinforcement learning process is left unchecked, it is very likely to overfit, diverge very significantly from the original model, and collapse into a single mode, since human feedback is, of course, too scarce for full-scale training. Therefore, RLHF adds a penalty term to the reward function {\hat r}(\mathbf{x}, \mathbf{y}) that urges the learned policy \pi_{\mathrm{RL}} not to differ too significantly from the original supervised model \pi_{\mathrm{SFT}}, usually in the form of a KL divergence between the two:

        \[{\hat r}'(\mathbf{x},\mathbf{y})={\hat r}(\mathbf{x},\mathbf{y})-\beta \log\left(\pi_{\mathrm{RL}}(\mathbf{y}|\mathbf{x}) / \pi_{\mathrm{SFT}}(\mathbf{y}|\mathbf{x})\right).\]
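
    In code, the penalized reward is just a log-ratio correction; here is a minimal sketch under the assumption that you have hypothetical helpers returning per-sequence log-probabilities for the current policy and the frozen SFT model:

        def kl_penalized_reward(reward_model, policy_logprob, sft_logprob, x, y, beta=0.02):
            """RLHF-style reward with a penalty that keeps the policy close to the SFT model.
            policy_logprob / sft_logprob return log pi(y | x) for the trainable policy
            and the frozen supervised model; beta controls the strength of the penalty."""
            r = reward_model(x, y)                                 # learned preference reward
            log_ratio = policy_logprob(x, y) - sft_logprob(x, y)   # log(pi_RL / pi_SFT)
            return r - beta * log_ratio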

    The real revolution in LLMs came when OpenAI researchers Ouyang et al. (2022) applied this direction of research directly to large language models. Their goal was to make LLMs from the GPT-3 family (Brown et al., 2020) useful and user-friendly. The problem is that by default, a token prediction machine is merely giving you a plausible continuation for a text stream. It is not “trying” to be helpful, inoffensive, or even truthful because continuations such as lying, evading the question, or redirecting the conversation to a new topic also may be just as plausible from the point of view of the training set (which strives to include all meaningful text scraped off the Web) as truthfully and fully answering the user’s question.

    Therefore, Ouyang et al. (2022) applied RLHF, as described above, to the outputs of a large language model; the overall structure of this approach, illustrated in the figure below, is very similar to RLHF for summarization shown above:

    This time, human evaluators are asked to decide which of the model’s outputs are most helpful, least offensive, and most truthful. The resulting LLM, InstructGPT, was reported to be significantly more truthful, less toxic, and better at following instructions and explicit constraints in the prompt; the improvements were also quite robust and generalized even to languages not present in the human feedback dataset (Ouyang et al., 2022; OpenAI blog).

    After InstructGPT, there was only a short step left to ChatGPT. InstructGPT was published in January 2022, and in November, OpenAI published a follow-up introducing a model that was also fine-tuned by RLHF but with an emphasis on conversation (OpenAI, 2022). For ChatGPT, human trainers held prolonged conversations with the model, and human feedback consisted of evaluating entire conversations rather than individual responses to requests; other than that, ChatGPT followed the exact same RLHF method. RLHF set off rapid improvements in making LLMs useful, and the rest is history: the release of ChatGPT sparked the “Spring of AI” in 2023 (see our previous post) and the wave of LLM research that we are still experiencing today. We have already discussed this wave, and will probably continue to do so in the future, but now we proceed to a different way to fine-tune LLMs.

    Low-Resource Fine-Tuning via Approximations: LoRA

    In a previous post on extending context windows for Transformers, we saw a number of methods that alleviate quadratic complexity based on low-rank approximations. A similar set of techniques can also be adapted for faster and less memory-intensive fine-tuning.

    Low-rank adaptation (LoRA) is a technique designed to fine-tune large pretrained models efficiently by reducing the number of trainable parameters via low-rank approximations. Introduced by Microsoft researchers Hu et al. (2021), it begins with the classical idea of a low-rank decomposition for a matrix, as illustrated below: a large N\times M matrix \mathbf{X} is approximated with a product of two rectangular matrices,

        \[\mathbf{X}\approx \mathbf{U}\mathbf{V},\quad\text{where}\quad\mathbf{U}\in\mathbb{R}^{N\times k},\mathbf{V}\in\mathbb{R}^{k\times M}\quad\text{for}\quad k \ll N, M.\]

    The product \mathbf{U}\mathbf{V} is, by construction, a matrix of rank k, and there exist efficient algorithms for finding \mathbf{U} and \mathbf{V} such that \mathbf{U}\mathbf{V} is the best approximation to \mathbf{X} of rank k, where “best” is usually understood in terms of the L_2-norm of the difference, \|\mathbf{X} - \mathbf{U}\mathbf{V}\|_2.
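
    As a quick illustration, the best rank-k approximation can be obtained from a truncated singular value decomposition; a minimal NumPy sketch:

        import numpy as np

        def low_rank_approx(X, k):
            """Best rank-k approximation of X via truncated SVD."""
            U, s, Vt = np.linalg.svd(X, full_matrices=False)
            U_k = U[:, :k] * s[:k]      # absorb the singular values into the left factor
            V_k = Vt[:k, :]
            return U_k, V_k             # X is approximated by U_k @ V_k, a rank-k matrix

        X = np.random.randn(512, 256)
        U_k, V_k = low_rank_approx(X, k=8)
        print(np.linalg.matrix_rank(U_k @ V_k))   # prints 8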

    In machine learning, methods based on low-rank approximations have a long history and are always very tempting: if you can assume that a large matrix you are interested in has rank k, you can replace the O(NM) complexity of learning it with the O((N+M)k) complexity of learning the matrices \mathbf{U} and \mathbf{V}, virtually free of charge. For large language models and large neural networks in general, prior research had already shown that the parameter space of large models is highly redundant, i.e., much larger than it needs to be:

    • Li et al. (2018) introduced the notion of intrinsic dimension by training neural networks in random subspaces of the parameter space with gradually increasing dimension; they showed that the dimension when solutions begin to appear is inevitably much smaller than the number of parameters, often surprisingly so;
    • Aghajanyan et al. (2021) applied this concept to fine-tuning language models, showing that standard pretrained models such as RoBERTa have very low intrinsic dimensions, which means that a little fine-tuning in the right subspace can go a very long way.

    This last point is exactly what LoRA is about. LoRA makes the assumption that changes introduced by fine-tuning can be represented with a matrix of low rank. In other words, it fixes the pretrained matrix of weights \mathbf{W}\in\mathbb{R}^{N\times M} and looks for a \Delta\mathbf{W} in the form of a low-rank approximation \Delta\mathbf{W} =\mathbf{B}\mathbf{A}, where \mathbf{B}\in\mathbb{R}^{N\times k}, \mathbf{A}\in\mathbb{R}^{k\times M}.

    For training, LoRA uses a random Gaussian initialization for \mathbf{A} and zero for \mathbf{B}, which means that at the start of training, \Delta\mathbf{W} is zero. Then you just fine-tune the model with your new dataset and use \mathbf{W}+\Delta\mathbf{W} as the new weight matrix.
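
    Here is what this could look like as a drop-in module; a minimal PyTorch-style sketch rather than the reference implementation, with the base weights frozen and only \mathbf{A} and \mathbf{B} trainable:

        import torch
        import torch.nn as nn

        class LoRALinear(nn.Module):
            """Minimal LoRA-style layer: y = W x + (alpha / k) B A x, with W frozen."""
            def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 1.0):
                super().__init__()
                self.base = base
                self.base.weight.requires_grad_(False)   # freeze the pretrained weights
                out_features, in_features = base.weight.shape
                self.A = nn.Parameter(torch.randn(rank, in_features) * 0.01)  # Gaussian init
                self.B = nn.Parameter(torch.zeros(out_features, rank))        # zero init
                self.scale = alpha / rank

            def forward(self, x):
                # base output plus the low-rank update; B @ A is zero at the start of training
                return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)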

    By focusing on low-rank updates, LoRA drastically reduces the computational and memory overhead compared to traditional fine-tuning methods. Hu et al. (2021) note that even very small values of k suffice; for example, they list a LoRA checkpoint for GPT-3 175B with k=4 and only the query and value weight matrices being modified, which brings the checkpoint size down from 350GB for the full model to 35MB, a 10000x reduction!

    After training, there is technically no need to store \mathbf{A} and \mathbf{B}: you can just use the modified weight matrix \mathbf{W}'=\mathbf{W}+\Delta\mathbf{W} since you have to store the N\times M weight matrix anyway. But with LoRA, you can have several different fine-tunings, for a variety of additional datasets and expected effects, applied to the same base weight matrix \mathbf{W}. You only have to store the base matrix once and store new variations as a collection of different \mathbf{A}_i and \mathbf{B}_i, as illustrated below:

    Low memory footprint and much reduced computational requirements for training also make it possible to train LoRA updates even to large models on consumer-grade hardware, without expensive clusters or even multiple GPUs. This has already led to the creation of cottage industries of various LoRA-based modifications for openly released image generation models, especially Stable Diffusion (Rombach et al., 2022), and large language models, especially the Llama family (Touvron et al., 2023a2023b).

    LoRA was introduced in 2021, so naturally, there has already been a lot of research that expands upon these ideas. Let us survey a few important novel LoRA extensions.

    First, an important problem in any low-rank approximation scheme is how to choose the rank k. If it is too high, we are wasting computation and memory, but if it is too low, we are losing valuable expressive power that would cost very little. 

    Therefore, many extensions of LoRA concentrate on how to choose the rank in some automated way:

    • adaptive low-rank adaptation (AdaLoRA; Zhang et al., 2023) parametrizes \Delta\mathbf{W} as a proper singular value decomposition, \Delta\mathbf{W}=\mathbf{P}\boldsymbol{\Lambda}\mathbf{Q}, where matrices \mathbf{P}\in\mathbb{R}^{N\times k} and \mathbf{Q}\in\mathbb{R}^{k\times M} are now orthogonal, and \boldsymbol{\Lambda}= \mathrm{diag}(\lambda_1, \lambda_2,\ldots, \lambda_k) is a diagonal matrix of singular values; in SVD, the magnitudes of singular values |\lambda_i| are representative of the significance of the corresponding additional components in the decomposition, and one can prune singular values of low magnitudes; note, however, that running a full SVD on matrices of size N\times M on every step would be too computationally intensive, so AdaLoRA approximates the decomposition by adding orthogonality regularizers on \mathbf{P} and \mathbf{Q} to the loss function:

          \[R(\mathbf{P}, \mathbf{Q}) = \left\|\mathbf{P}^\top\mathbf{P} - \mathbf{I}\right\|^2 + \left\|\mathbf{Q}^\top\mathbf{Q} - \mathbf{I}\right\|^2;\]

    • sparse low-rank adaptation (SoRA; Ding et al., 2023) notes that the relevant part of the SVD decomposition is that the matrix \boldsymbol{\Lambda} serves as a gating mechanism for rows and columns of \mathbf{A} and \mathbf{B}: if a singular value is zero the corresponding dimension can be removed; therefore, they make this gating stage explicit, considering \Delta\mathbf{W} as the composition with a componentwise multiplication in the middle,

          \[\Delta\mathbf{W}\mathbf{x} = \mathbf{B}\cdot\left(\mathbf{g}\odot\left(\mathbf{A}\cdot \mathbf{x}\right)\right),\]

      and then optimize the components of \mathbf{g} with a sparsity-inducing L_1 regularizer (a minimal sketch of this gating appears right after this list);
    • allocating low-rank adaptation (ALoRA; Liu et al., 2024) also adds a diagonal matrix \boldsymbol{\Lambda} in between \mathbf{A} and \mathbf{B} but does not try to make \mathbf{A} and \mathbf{B} orthogonal; instead, it proposes a separate ablation algorithm AB-LoRA to evaluate the importance of individual ranks in \boldsymbol{\Lambda}, and then prunes ranks with low importance or increases the dimension of matrices where every rank is important; note that here, as usual in LoRA approaches, the \Delta\mathbf{W}=\mathbf{B}\mathbf{A} decomposition is done separately for different weight matrices in the network, and ranks may differ across them;
    • dynamic search-free low-rank adaptation (DyLoRA; Valipour et al., 2023) samples the rank k on every training step and trains the truncated versions of \mathbf{A} and \mathbf{B}; the authors show that this approach can lead to significant time savings in LoRA training;
    • weight-decomposed low-rank adaptation (DoRA; Liu et al., 2024) decomposes each pretrained weight into two components, magnitude and direction, and tunes them separately; formally this means that the weight matrix W gets decomposed as W = \|W \| \cdot ( W / \| W \| ), and LoRA is applied only to the directional part:

          \[W' = \|W\|\cdot \left(\frac{W}{\|W\|} + \Delta W\right);\]

      this can reduce the number of trainable parameters, and the authors show that it matches or surpasses basic LoRA in various tasks at no additional training cost.
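
    To illustrate the gating idea behind SoRA-style methods, here is a minimal sketch of a low-rank update with a trainable gate vector in the middle (a simplified illustration, not the authors’ implementation):

        import torch
        import torch.nn as nn

        class GatedLowRankUpdate(nn.Module):
            """SoRA-style gated low-rank update: delta(x) = B (g * (A x))."""
            def __init__(self, in_features, out_features, rank=8):
                super().__init__()
                self.A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
                self.B = nn.Parameter(torch.zeros(out_features, rank))
                self.g = nn.Parameter(torch.ones(rank))   # gates, trained with an L1 penalty

            def forward(self, x):
                return (self.g * (x @ self.A.T)) @ self.B.T

            def l1_penalty(self):
                # sparsity-inducing regularizer: gates driven to zero mark prunable ranks
                return self.g.abs().sum()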

    Here is an illustration of several LoRA variations:

    Overall, low-rank adaptation is one of the most popular ways to fine-tune existing large models: even a very small dataset may be enough to train a low-rank adapter, and the resulting model can still use all of the power of the large number of pretrained weights. But it’s not the only way, so let us press on.

    Instruction Tuning and Where to Get the Data for It

    For large language models, both RLHF and low-rank adaptation usually aim to bridge the gap between pretext tasks, i.e., tasks that the LLM pretrains on, and actual use cases that involve fulfilling user requests in the form of natural language prompts. The archetypal pretext task is predicting the next token, but the tasks posed by humans may look very different.

    Therefore, it often makes sense to fine-tune a large language model with a dataset specifically providing realistic examples of instructions and proper responses, a process known as instruction tuning. Here is a general illustration of the instruction tuning process from a recent survey by Zhang et al. (2024):

    Actually, the tuning itself (Step 2 in the figure) is more or less trivial: you just fine-tune the model on a new dataset of inputs and outputs. The interesting part here is usually the dataset construction: where can you get a lot of input-output pairs with realistic instructions and responses? There are several different approaches:

    First, you could always use human labeling: the required dataset size is not that large and manual labeling is often feasible. For example, we have discussed above how Ouyang et al. (2022) trained InstructGPT; we discussed it in the context of RLHF but recall that the first step there was exactly instruction tuning, i.e., supervised fine-tuning (SFT) on a dataset of natural language instructions. For InstructGPT, the SFT dataset contained about 13K training prompts, and the dataset used to train the reward model had about 33K more—not something you can label by yourself over an evening but still eminently feasible. OpenAI used a combination of handcrafted manual labeling and real prompts from their API.

    There already exist a number of public datasets for fine-tuning LLMs. Several of them were intended to make LLMs (and perhaps other models) generalize better to unseen tasks. Sanh et al. (2022) put it as follows in their paper on one such dataset, P3 (Public Pool of Prompts): “An influential hypothesis is that large language models generalize to new tasks as a result of an implicit process of multitask learning… learning to predict the next word, a language model is forced to learn from a mixture of implicit tasks”. So these datasets make the multitask learning explicit rather than implicit, specifying a wide variety of tasks in the hope that the fine-tuned model will not only do well on those but also will generalize to new tasks when given similar zero-shot instructions. Here is an illustration with sample tasks by Sanh et al. (2022):

    With this approach, one can adapt already existing NLP datasets for various tasks, providing one or a few descriptions for every task and thus turning a dataset previously designed to train NLP models from scratch into a prompt-response dataset suitable for fine-tuning LLMs. P3 combined at least a couple dozen different datasets, and later another couple dozen were added by Muennighoff et al. (2022) who published xP3, a multilingual version of P3 that not only contains more data but also can provide similar tasks in different languages. A similar dataset is Flan 2022 (Longpre et al., 2023), a collection of data for auxiliary tasks used to train the Flan-T5 model (Chung et al., 2022).

    Another important example is Natural Instructions by Mishra et al. (2022), later extended to Super-Natural Instructions by Wang et al. (2022); they employed crowdsourcing labelers to generate questions about text snippets and also answer them in order to make LLMs (or other models) generalize better to unseen tasks, use common sense and common knowledge better, and so on. Here are some sample questions from these datasets:

    Natural Instructions, by the way, can also illustrate the limitations of crowdsourcing. I went to the dataset website and explored the commonsense event duration example, shown on the left in the figure above. Literally the first example I found there looked like this: 

    • Input: Sentence: Islam later emerged as the majority religion during the centuries of Ottoman rule, though a significant Christian minority remained. 
    • Output: What day did Islam emerge as the majority religion?

    Not the most meaningful of questions, and I’m pretty sure it wasn’t intended by the original instructions…

    A dataset even more directly related to LLMs and instruction tuning is databricks-dolly (Conover et al., 2023). It contains over 15,000 records manually created by Databricks employees for different categories of instruction following questions similar to those used in InstructGPT; unlike OpenAI’s datasets, this one is freely available for download, as well as the Dolly LLM fine-tuned on it. Another similar effort is LIMA (Less Is More for Alignment; Zhou et al., 2023), an interesting experiment where the authors fine-tune LLaMA-65B (as the name suggests, it has 65 billion parameters) with only 1000 curated prompt-response pairs, achieving very good results.

    These are some of the manually labeled datasets. But, of course, here we have a great opportunity to circle back to the original topic of our blog and Synthesis AI in general: synthetic data. The first, simpler way to use synthetic data is basically model distillation: once you have a strong (but perhaps large and expensive) LLM you can use it to generate synthetic data for fine-tuning a more lightweight model.

    This is exactly how Alpaca, a well-known open LLM produced by Stanford researchers Taori et al. (2023), came into being. They took the LLaMA 7B model (Touvron et al., 2023), which is a small LLM by modern standards, used a much larger LLM, text-davinci-003 (that’s GPT-3.5, the cutting-edge model at that time), to generate instruction following examples, and fine-tuned LLaMA 7B on them (illustration by Taori et al., 2023):

    As a result, Alpaca became much better at following instructions than LLaMA 7B ever had been. Note that the dataset size is again not huge, it’s just 52K examples, even though this time manual labeling was unnecessary.

    The next step, the Vicuna model introduced by Berkeley researchers Chiang et al. (2023), followed suit by training on 70K user conversations with ChatGPT. Vicuna-13B achieved over 90% response quality against ChatGPT (compared to 68% for basic LLaMA-13B and 76% for Alpaca-13B) while using a far smaller model: the training cost for fine-tuning was only about $300. 

    There are many more examples of distillation (see also a survey of synthetic data for LLMs by Liu et al., 2024); important public datasets include:

    • Orca (Mukherjee et al., 2022) and Orca 2 (Mitra et al., 2023), datasets distilled from GPT-4 to make lightweight models better, especially in logical reasoning and choosing a viable strategy for answering a question;
    • Unnatural Instructions (Honovich et al., 2023), a dataset based on Super-Natural Instructions that we have discussed above (Wang et al., 2022); to create synthetic data, the authors take three examples from Super-Natural Instructions as few-shot instructions and ask a strong LLM to generate the fourth;
    • Baize (Conover et al., 2023), a corpus of multi-turn conversations generated with ChatGPT and used to fine-tune the Dolly model we have discussed above;
    • and lots and lots of domain-specific datasets such as, e.g., WizardCoder (Luo et al., 2024), WaveCoder (Yu et al., 2023), and Magicoder (Wei et al., 2023) for programming, that is, source code generation, WizardMath (Luo et al., 2023), MetaMath (Yu et al., 2023), and Xwin-Math (Li et al., 2024) for mathematics, and so on.

    In an interesting recent work, Yue et al. (2024) note the importance of the task distribution inside the fine-tuning dataset, both in terms of difficulty and actual composition of tasks. They propose Task-Aware Curriculum Planning for Instruction Refinement (TAPIR), a multi-round framework that provides the student LLM with problems of increasing difficulty and balanced task distribution:

    The results of distillation efforts may look too good to be true: you take a model with 7B or 13B parameters and achieve results virtually on par with a 100B+ teacher model. There is criticism that suggests that it is indeed too good to be true: UC Berkeley researchers Gudibande et al. (2023) studied the outputs of fine-tuned LLMs more closely and found that while the smaller models learn to imitate the style of larger ones almost perfectly, the actual content is far more often incorrect and prone to hallucinations. Here is an example from their work (conveniently explaining an important notion from one of our previous posts) where the response styles are identical but the explanations of the imitation model are just… totally wrong:

    But be it in style or in substance, the distillation approach from “teacher” to “student” will never give you a model stronger than the teacher; this is a way to get smaller models up to speed with larger ones, not push the frontier. In a different direction of using synthetic data for LLM fine-tuning, researchers are finding ways to bootstrap already strong models into something even better by using the model’s own outputs as synthetic training data. 

    Bootstrapping and Self-Improvement for LLMs

    The archetypal work in this direction is the Self-Instruct pipeline presented by Wang et al. (2023). They begin with a “vanilla” LLM, in this case GPT-3, and a relatively small set of manually written tasks (175 tasks with only one sample instance per task) that serve as a seed for further generation. Then the process goes as follows:

    • ask the LLM to generate new instructions with a few-shot prompt; Wang et al. show 6 human-written instructions and 2 previously produced model-written instructions as examples and ask for a new one;
    • identify whether the result is a classification task; this is also achieved with a few-shot prompt to the same LLM;
    • given the instructions (including newly generated ones), ask the LLM to generate novel instances for them, either input-first (generate an input, then generate a response) or output-first (begin with generating the response and then ask for a matching input);
    • apply some straightforward filtering that promotes diversity across tasks and instances;
    • when you have collected enough tasks and instances, fine-tune the LLM with this dataset; Wang et al. generated about 52K instructions and 82K instances for them before fine-tuning.

    Here is the general pipeline as illustrated by Wang et al. (2023):

    As a result, the Self-Instruct pipeline raised a basic vanilla GPT-3 almost to the level of InstructGPT, with no manual labeling or other human work beyond the original 175 task instructions.
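
    In code, the loop could look roughly like the following sketch; every helper here (few-shot sampling, task classification, filtering, fine-tuning) is a hypothetical stand-in rather than the authors’ actual implementation:

        def self_instruct(llm, seed_tasks, target_size=52_000):
            """Schematic Self-Instruct loop; all helpers are hypothetical stand-ins."""
            task_pool = list(seed_tasks)                 # 175 human-written seed tasks
            while len(task_pool) < target_size:
                # 1) generate a new instruction from a few-shot prompt (6 human + 2 model tasks)
                examples = sample_few_shot(task_pool, n_human=6, n_model=2)
                instruction = llm.generate_instruction(examples)
                # 2) decide whether it is a classification task, also via a few-shot prompt
                is_classification = llm.classify_task(instruction)
                # 3) generate instances: output-first for classification tasks, input-first otherwise
                instances = llm.generate_instances(instruction, output_first=is_classification)
                # 4) simple filtering that promotes diversity across tasks and instances
                if passes_diversity_filter(instruction, instances, task_pool):
                    task_pool.append({"instruction": instruction, "instances": instances})
            # 5) fine-tune the same LLM on the collected synthetic dataset
            return finetune(llm, task_pool)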

    A natural extension that the Self-Instruct paper (suspiciously) omits would be to take the fine-tuned model and re-apply the bootstrapping pipeline recursively. There will be limits to improvements, of course, but how good can you make a model in this direction? Recursive self-improvement of LLMs is partly the stuff of AI doomer nightmares (see, e.g., our previous post on the dangers of AGI) but at the same time it is already happening in practice! This brings us back to reinforcement learning.

    In RLHF, you collect new data by evaluating LLM responses as you go; note that in principle you could straightforwardly make RLHF into a bootstrapping self-improvement mechanism by delegating evaluation to the same LLM. Several important works extend and improve the basic idea of RLHF by combining it with offline training on collected data.

    In particular, DeepMind researchers Gulcehre et al. (2023) introduce Reinforced Self-Training (ReST), a pipeline where the current policy generates a dataset on the “Grow” step, and then the policy is updated by fine-tuning on the “Improve” step:

    This is basically an application of offline reinforcement learning (Levine et al., 2020) to LLMs, and Gulcehre et al. (2023) report significant improvements; their paper shows results in machine translation, but, of course, a similar framework could be applied to any set of tasks.
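
    Schematically, a Grow/Improve loop could look like this sketch; the reward threshold schedule and the helpers are illustrative assumptions rather than the exact ReST recipe:

        def rest_loop(policy, prompts, reward_fn, thresholds=(0.3, 0.5, 0.7)):
            """ReST-style training sketch: Grow (sample a dataset from the current policy),
            then Improve (filter by reward and fine-tune); helpers are hypothetical."""
            for threshold in thresholds:
                # Grow step: generate a dataset with the current policy
                dataset = [(x, policy.sample(x)) for x in prompts]
                # Improve step: keep only high-reward samples and fine-tune on them offline
                filtered = [(x, y) for x, y in dataset if reward_fn(x, y) >= threshold]
                policy = finetune(policy, filtered)
            return policy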

    Recursive self-improvement for LLMs lies at the center of DeepMind’s attention; it’s only natural for a company that brought us such RL-based marvels as AlphaGo, AlphaZero, MuZero, AlphaStar, and the AlphaFold series. In another recent work, DeepMind researchers Singh et al. (2024) further improve the ReST framework with ideas based on the expectation-maximization algorithm (EM). It is a rare opportunity for me to take a detour into the probabilistic side of machine learning, so let me explain expectation-maximization in a bit more detail (I actually wrote “delve into” at first but edited it out lest you think I’ve been delegating these posts to LLMs – what a world we live in!).

    In general, the EM algorithm is intended for situations where we have a simple model of the data, but some of the variables in this model are latent, i.e., unknown. The prototypical example is clustering: it usually presumes a really simple model of each cluster (a Gaussian distribution, for example) but it is not known which points belong to which cluster. In general, given a dataset X = \{\mathbf{x}_1,\ldots,\mathbf{x}_N\}, we want to maximize its likelihood

        \[L(\boldsymbol{\theta}|X) = p(X | \boldsymbol{\theta}) = \prod_{n=1}^Np(\mathbf{x}_n|\boldsymbol{\theta}).\]

    But this problem is intractable as written because p(\mathbf{x}|\boldsymbol{\theta}) is a complicated model (a mixture of Gaussians, for instance), and to get back to a simpler model you need to know some latent variable \mathbf{z} for every \mathbf{x}. If we knew which cluster every point belongs to (that’s the \mathbf{z} variable), learning a clustering model would reduce to learning the parameters of several individual Gaussians, which would be trivial. In general, EM is useful if p(\mathbf{x},\mathbf{z}|\boldsymbol{\theta}) is simple but p(\mathbf{x}|\boldsymbol{\theta}) is hard.

    The EM algorithm in this case finds a lower bound for the log likelihood, \log p(X|\boldsymbol{\theta}), that would be actually tractable; maximizing the lower bound turns out to be equivalent to maximizing

        \[Q(\boldsymbol{\theta},\boldsymbol{\theta}^{(n)}) = \mathbb{E}\left[\log p(X, Z|\boldsymbol{\theta}) \middle| X, \boldsymbol{\theta}^{(n)}\right].\]

    Note how here we are no longer talking about the complicated distribution p(X|\boldsymbol{\theta}) but only about the much simpler distribution p(X, Z|\boldsymbol{\theta}); this is the main goal here. The expectation looks complicated, but in most actual cases it just boils down to computing the expected values of the \mathbf{z} variables under the previous model \boldsymbol{\theta}^{(n)}. So while formally the EM algorithm just iterates a single step, maximizing Q(\boldsymbol{\theta},\boldsymbol{\theta}^{(n)}) and setting \boldsymbol{\theta}^{(n+1)} to the maximizer until convergence, in reality this maximization usually breaks down into two separate steps that gave the algorithm its name:

    • on the expectation step, the algorithm computes the expectations of latent variables \mathbb{E}[\mathbf{z}];
    • on the maximization step, it substitutes these expectations into \log p(X, Z|\boldsymbol{\theta}) and maximizes the result with respect to \boldsymbol{\theta}; then the procedure is repeated with the new value of \boldsymbol{\theta} until convergence (a toy example in code follows below).
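
    To make the two steps concrete, here is a toy EM implementation for a one-dimensional mixture of two unit-variance Gaussians (an illustrative example, not tied to any specific paper):

        import numpy as np

        def em_gmm_1d(x, n_iter=50):
            """Toy EM for a two-component 1D Gaussian mixture with unit variances."""
            mu = np.array([x.min(), x.max()])          # initialize the means
            pi = np.array([0.5, 0.5])                  # mixture weights
            for _ in range(n_iter):
                # E-step: expected cluster assignments (responsibilities) for the latent z
                dens = pi * np.exp(-0.5 * (x[:, None] - mu) ** 2)   # proportional to N(x; mu_k, 1)
                resp = dens / dens.sum(axis=1, keepdims=True)
                # M-step: maximize the expected complete-data log-likelihood
                mu = (resp * x[:, None]).sum(axis=0) / resp.sum(axis=0)
                pi = resp.mean(axis=0)
            return mu, pi

        x = np.concatenate([np.random.randn(200) - 3, np.random.randn(200) + 3])
        print(em_gmm_1d(x))    # means close to (-3, +3), weights close to (0.5, 0.5)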

    This post is not the time or place to provide a full explanation of why Q is a lower bound or how EM works in general, but if you smell something similar to variational approximations that we discussed some time ago, you are completely correct.

    In practice, the EM algorithm often simply means taking the expectation of whatever makes the problem intractable and plugging it into the model (although you do need to check that it makes sense in each specific case). For LLMs, we are trying to optimize some metric (reward) over the possible outputs of a language model. The objective function is thus an expectation over LLM outputs, and of course it would be intractable to take the sum over all possible sequences of tokens. This is where the EM algorithm comes into play:

    • the expectation step turns out to be weighting the output samples according to the probability of obtaining high rewards;
    • and then you can fine-tune the LLM with an objective function weighted by these probabilities.

    Singh et al. (2024) apply this framework to large-scale models from the PaLM family and actually achieve great results in two chosen tasks, mathematical reasoning and code generation (the X-axis shows approximate release time):

    And here are the plots showing how EM iterations help for these problems:

    Interestingly, the models fine-tuned on synthetic data with several EM iterations clearly outperform the same models fine-tuned on human-labeled data (shown with dotted lines on the graphs)! Note that GPT-4 still comes out on top, so we are not yet talking about pushing the actual frontier, but the approach looks very promising.

    DeepMind seems to be leading the way; tweets like this one definitely make you wonder what else they have in store. But there are other efforts in (usually RL-based) recursive self-improvement for LLMs. In particular:

    • Self-Taught Reasoner (STaR; Zelikman et al., 2022) bootstraps rationale generation, i.e., generating explanations such as the ones that an LLM would produce when asked to “reason step by step”; the authors ask the LLM to produce a lot of rationales, filter the ones that lead to the correct answer, and then fine-tune the LLM on the filtered dataset (shown on the left in the figure below);
    • Rejection Fine-Tuning (RFT; Yuan et al., 2023) takes this idea one step further, developing a better filter for reasoning paths (they specialize in mathematical reasoning, so reasoning paths are chains of equations in this case) based on rejection sampling, looking for the most diverse paths to improve reasoning; moreover, they find a log-linear dependence between the amount of data and model performance and show how better data (e.g., filtered with RFT) can lead to better models;
    • Self-Taught Optimizer (STOP; Zelikman et al., 2024) takes another meta-step: here the application is coding, and the model recursively improves the improver, a program that applies a language model to improve a source code solution for an arbitrary problem; that is, the improver program is called on itself, and the “improved improver” actually generates better programs for downstream tasks (shown on the right in the figure below).

    Overall, I think recursive self-improvement approaches hold a lot of promise even if they don’t achieve the actual fast takeoff singularity (do we really want to achieve that?). The story of machine learning in many different domains comes to the same conclusion: you can succeed up to a point when you try to imitate humans, and LLMs are the best example of this. But if you want to achieve superhuman capabilities, you really need to find a way of recursive self-improvement. In chess and Go, decades of trying to emulate the patterns of human thinking led to some breakthroughs, but when AlphaZero learns from scratch, it just breezes through the top human level without even noticing it; the saturation point comes much later.

    So far, LLMs are mostly trained to imitate human reasoning; after all, the main training process is done on texts written by humans. Can we find a way to bootstrap the model and breeze through the imperfections of human-generated data? Maybe not in general problem solving anytime soon, but at least in more formalized domains such as coding and math where it is easier to generate synthetic problems? Time will tell, and I’m really not sure how much time we are talking about here.

    Conclusion

    In this post, we have discussed the main directions of making language models better. We have seen how a pure token prediction machine can become more helpful and/or more specialized via various forms of fine-tuning or adapter training.

    There are other approaches, too. For instance, we have not mentioned RAG, retrieval-augmented generation, where the generator model is supplemented with an information retrieval mechanism able to gather important information from separately provided sources (Lewis et al., 2020). RAGs are also very important for modern LLMs, but this will be a story for another day. We also did not mention tricks that make training and/or fine-tuning more efficient, such as mixed precision training or gradient checkpointing, which do not provide new ways to adapt models but may significantly extend the feasibility of existing approaches. Finally, another important story is how to best extract the knowledge and reasoning abilities that are already contained in the models, even without any fine-tuning. This is the subject of the rapidly growing field of prompt engineering, a field that already goes far beyond the “please reason step by step” trick (although it is still surprisingly effective).

    Next time, we will discuss another important aspect of the journey that modern generative AI has made over the last couple of years; stay tuned!

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • Do Androids Dream? World Models in Modern AI

    Do Androids Dream? World Models in Modern AI

    One of the most striking AI advances this spring was OpenAI’s Sora, a video generation model that sets new standards for video consistency and diversity. Interestingly, the official report on Sora is titled “Video generation models as world simulators”. It notes that Sora has emerging simulation capabilities and is on a “promising path towards the development of capable simulators of the physical and digital world”. Today, we discuss world models in modern artificial intelligence: what they are, how they have progressed over the last few years, and where they may go in the future.

    What Are World Models?

    Generally speaking, a world model is an engine that predicts how the environment will respond. The “environment” here may be used in the technical sense of a reinforcement learning environment that gives out rewards and moves the agent to the next state. It could also mean predicting new sensory input for the agent, even when the connection with rewards is unclear.

    In my opinion, it is reasonable to assume that world models inevitably arise when we pose sufficiently hard problems for AI models. It is almost obvious that a robotic agent operating in the real world should have some model of how the world responds to its actions. More surprising, however, is how far we can go in reinforcement learning without explicitly modeling the environment, simply by learning from experience. This approach, called “direct RL,” includes algorithms that learn value functions (e.g., Q-learning) or policies (e.g., policy gradient) and underlies many, if not most, applications of RL.

    But these days, we are often talking about world models arising in large language models; it may seem very surprising given that all they do is predict the next token of a text string (we discussed the basic language modeling task here and here). How can something like a “dream world” arise from solving a straightforward classification problem over the dictionary tokens?

    Consider the variety of problems that can be embedded into language modeling. Languages were created to describe the world, and indeed, you can frame anything in the world as continuing a string of tokens: solve a math problem, invent a recipe, describe the path out of a labyrinth, develop characters… Suppose that we ask the LLM to continue the last chapter of a detective story, when the sleuth is about to reveal who had actually done it. To make a reasonable guess, the model will have to collect the clues from the context of the whole novel, just like a human reader would:

    (Sorry for spoiling one of the most important plot twists in the history of detective fiction.) In fact, human readers usually don’t succeed in predicting the murderer, and modern LLMs would probably not fare much better, even if we used the methods discussed in our previous post to extend the context window to the whole book.

    The idea still stands: a “perfect LLM” would have to contain a “true world model” that would be able to reason about the world; that is what people mean when they say that language modeling is an AI-complete problem. But world models are far from limited to language modeling.

    In the rest of the post, we will cover several different results that I would call different kinds of world models as used in deep learning and different aspects of world models. In particular, we will see:

    • how children learn the world and what is the neurobiological inspiration for world models;
    • how weak world models arise as part of representation learning in language models;
    • how similar representation learning for a reinforcement learning environment can be used to inform an RL agent and improve its performance;
    • how they can further be used for planning inside an RL agent, predicting possible responses of the environment;
    • and finally, how Sora uses ideas similar to language modeling to generate high-quality videos.

    World Models in Humans: Predictive Coding

    We all possess world models. This is a self-evident fact, and thousands of pages written on the hard problem of consciousness accept that humans have direct access to an introspective mechanism. This mechanism allows us to reason about the world, emulate possible scenarios that can arise in response to our actions, assess their desirability, and act accordingly. For example, you know that if you loosen the grip on a cup of coffee, the cup will fall to the ground and spill everything, possibly shattering in the process. Therefore, even when you mistakenly pick up a hot cup the wrong way, causing it to burn your hands, you don’t drop it as immediate pain would suggest. Instead, you carefully place it on a table and pick it up by the handle to avoid further pain, but also prevent spilling coffee:

    This is an impressive amount of physics and planning! How did we learn all that stuff? 

    Certainly not from teachers in school or parents explaining how cups of coffee work. We will discuss Yann LeCun’s take on world models (LeCun, 2022) below; for now, let me quote one chart from his paper that deals with infant development:

    As you can see, children learn some pretty complicated concepts at a very young age, when it is clear that the learning cannot come from direct supervision (language and detailed communication with other humans comes much later), and just saying “imitation” also doesn’t explain much. In particular, they learn the so-called “intuitive physics”, which is just what we would mean by a world model: object permanence, properties like solidity, gravity and momentum.

    Note that even just understanding visual inputs is pretty difficult! Our eyes work pretty similar to a camera, registering what is basically a set of pixels at the retina. However, our eyes move constantly in saccades, which are about 200 milliseconds long. This means that pixels change entirely about five times per second, and the visual cortex needs to establish connections between all of these images and provide our internal decision making mechanism (whatever that is) with a streamlined continuous representation of the world around us.

    How do we learn all this stuff? This is a big question that does not have a clear answer. But I want to highlight one theory that is gaining traction in neuroscience: predictive coding (see, e.g., Sprevak, 2021). The idea is that everything the human brain does may arise from trying to predict the next set of stimuli (picture from Stefanics et al., 2014):

    According to predictive coding, the brain is mostly doing representation learning, compressing sensory inputs into latent representations that can be used to predict next sensory inputs. Just like a language model, always predicting the next token! And if there is a mismatch between what it predicts and what it actually sees, the neural connections learn to predict better. Just like neural networks, always minimizing prediction error (not by gradient descent, though)! There are even rather compelling reasons to suggest that the brain is doing approximate probabilistic inference; this is known as the “Bayesian brain” hypothesis (Chater, Oaksford, 2008).

    This theory has its own problems, but it quite possibly might be true. If so, resemblances with LLMs are uncanny: by predicting next “tokens” (sensory inputs), our brains develop a world model and even consciousness and first-person experience (whatever that means). Naturally, LLMs and other generative models are not quite there yet; for example, DALL-E currently does not support object permanence across different queries, so the cats and cups in my illustration above are all different; here’s hoping GPT-4o will fix that (see “Geary the robot” here).

    But it looks quite possible that the route to general intelligence and even consciousness lies through building a world model, which in turn can be achieved by predicting the next sensory input, whatever the actual hardware. Naturally, we have no guarantees or even projections about whether a future LLM will be able to achieve it, but to me, learning about the predictive theory of mind was quite a (pardon the pun) mind-blowing discovery.

    So with neurobiology out of the way (and as usual providing more questions than answers), let’s turn to world models in AI. We will go roughly in chronological order, culminating with our main reason for this post, OpenAI’s Sora. I won’t dive deep into the history of deep learning but in the beginning we go back to 2017, when OpenAI was just getting started…

    OpenAI Started with World Models

    For this section on early precursors of world models, I could choose any of a large number of works with similar analysis. But it seems interesting to note that in a way, OpenAI was born out of research precisely about world modeling.

    In a 2017 paper, when OpenAI was less than two years old, Alec Radford et al. used unsupervised learning on large text corpora to solve the sentiment analysis problem, i.e., find out whether a given product review is positive or negative. Sentiment analysis had been (and still is) an important benchmark for text understanding: it is formulated as simple classification but may require deep understanding of the text (up to, e.g., understanding sarcasm), and relatively large datasets such as Amazon Reviews had been made available long ago.

    In 2017, Transformers were not yet invented, so Radford et al. trained a variation of an LSTM (a standard recurrent architecture, see, e.g., here) as a character-level language model. This means that the model “reads” a text prompt and predicts its next character (rather than a word-level token, as modern LLMs do); this can be done in a completely unsupervised way, you don’t need to have sentiment labels to train a language model.

    But the interesting part was that in the latent representation learned by the model (it was a vector of dimension 4096), Radford et al. found a specific component (cell, “neuron”, call it what you will) that was responsible for sentiment! Moreover, if you fix the value of the “sentiment unit” and generate new reviews, their tone will come out just as you would expect. Here are a couple of illustrations from the OpenAI paper; on the left you see the activations of the “sentiment unit” on a sample movie review, and on the right, generation results with fixed sentiment: 

    In this work, we have two important components:

    • the model has a distinct “place” where it is storing important individual components of its environment (“world”), in this case sentiment of movie and product reviews;
    • this “place” can be used to model new, previously unseen parts of the “world”, in this case generate new reviews with fixed sentiment.

    So in a way, that was already a “world model”. This kind of work has been an important part of the AI interpretability field, and important progress is still being made, most notably the (very) recent work by Anthropic (Templeton et al., May 2024) that we may discuss separately in a future post.

    But these ideas are a little different from the main emphasis of this post and, generally, what we mean by world models nowadays. Let us move on and see how our current understanding came into being.

    Schmidhuber Was Here First (Again)

    There is a well-known meme in AI research circles: one of the fathers of modern AI, a prominent German researcher Jürgen Schmidhuber, loves to explain in his talks how he and his team pioneered many ideas that are foundational to modern AI. While some researchers believe he occasionally oversells his past results (see the corresponding Wikipedia article section), quite often he is indeed correct in his claims!

    For example:

    • in the mid-1990s, Schmidhuber and his student Sepp Hochreiter (now a renowned researcher in his own right) authored the long short-term memory (LSTM) unit, which many recurrent neural networks are still (thirty years later!) based on (Hochreiter, Schmidhuber, 1997),
    • in 1991, Schmidhuber published a paper titled “A Possibility for Implementing Curiosity and Boredom in Model-Building Neural Controllers” where he introduced adversarial learning, with one network building the model of the environment, and another controller network looking for weak spots in the model network; 25 years later, this idea evolved into GANs and many other applications of adversarial learning;
    • in 1992, Schmidhuber published a paper on “fast weight programmers” where the weights of neural network connections were split into “slow” (regular) weights and “fast” weights that reflect short-term memory; in modern terms, the latter would be called “attention weights”, and Schmidhuber’s idea was equivalent to a (linearized) Transformer!

    By the way, I also highly recommend Prof. Schmidhuber’s works on the history of deep learning; he cites many early works that I would never learn about otherwise (Schmidhuber, 2013; 2014; 2020; 2022).

    So it is no wonder that in 2018, it was Jürgen Schmidhuber (together with Google Brain researcher David Ha) who again showed his superhuman sense for promising ideas, presenting a paper at NeurIPS whose arXiv version is called simply “World Models” (here is a GitHub version with interactive animations).

    They present a simple idea: we humans have mental models of environments around us. So what if we train a network to learn an internal model of some, say, reinforcement learning environment such as a 2D car racing game (main example in the paper)? The model is similar in design to a language model: it learns an internal representation for frames from the environment via autoencoding and learns to predict the next frames. 

    This allows the model to plan, just like the RL agents above; then a separate controller model can use the internal representations that have been created with this planning in mind to choose the best action. Here is an illustration by Ha and Schmidhuber:

    The authors show how world models improve agent results in this racing game and in another standard RL environment, a simple Doom level where you need to navigate away from fireballs. Here is an illustration from the paper that shows a reconstruction of how the agent imagines the environment – pretty close to the real thing, and quite enough to be able to learn inside your own dreams:

    This work was one of the first to show a full end-to-end system with a world model in the form of a neural network learning the environment and helping the agent to act in this environment by providing useful representations for the states. In the next section, we will see a way to go further and use the world model to actively do planning in an environment rather than just feature extraction.

    From AlphaZero to MuZero and beyond

    Deep Blue defeated Garry Kasparov in 1997, and Vladimir Kramnik was essentially the last human to play even matches against a computer in the early 2000s. Despite the long history of computer chess and its symbolic importance as a pinnacle of human intelligence, chess programs of that era did not resemble “true artificial intelligence” at all. They were primarily alpha-beta tree search engines with sophisticated position evaluation functions (this is where machine learning could contribute). AI needed a different testbed.

    At the same time, the game of Go looked unassailable. Tree search does not work nearly as well there because there are far more reasonable possibilities at every step. At the turn of the century, the best computer Go programs lost to mediocre human professionals even with enormous handicaps of 15-20 stones. The situation changed in 2007, when Remi Coulom revolutionized computer Go with Monte Carlo tree search (MCTS), a method that constructs a tree of possible moves with multi-armed bandit algorithms helping to choose where to put the majority of “experiments”. But still, before AlphaGo beat Lee Sedol, the best Go-playing models had been weak compared to professional players. I will skip AlphaGo (Silver et al., 2016) and go straight to AlphaZero here.

    The idea of AlphaZero (Silver et al., 2017a; 2017b) is deceptively simple: on every training step, the model performs MCTS that can efficiently search a few moves ahead and thus improves the current policy (playing strategy). Previously, MCTS had been used at decision time, to improve the current policy by refining its estimates of position values; in MCTS-based Go programs, MCTS was often the only method, with no training at all.

    AlphaZero’s idea was to use MCTS at training time and modify the policy with a gradient step towards a new policy improved by MCTS. The training algorithm always has a moving target: for the current policy π, AlphaZero constructs a new policy π’ by applying MCTS to improve π. Then π is improved with policy gradient algorithms to make it closer to π’—but now π’ is better yet, and the process can be repeated. In this way, the policy is continuously brought to new heights (illustrations a and b below are taken from the AlphaGo Zero paper):

    To do that, AlphaZero needs to be able to construct the search tree, which it does by self-play: during training, the agent is playing (an earlier version of) itself. But to run self-play, AlphaZero obviously needs to know the rules of the game. Note that it’s not the same as the model of a reinforcement learning environment since the latter also includes the opponent, but if you have an agent to run as the opponent then yes, this means you have a model of the RL environment.
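
    To make the π → π’ gradient step a bit more concrete, here is a minimal numpy sketch of a single policy update for one toy position; the MCTS visit counts are made up (in the real algorithm they come from self-play tree search guided by the current network), so treat this as an illustration of the moving-target idea rather than a working AlphaZero.

    import numpy as np

    def softmax(z):
        z = z - z.max()
        e = np.exp(z)
        return e / e.sum()

    # Toy position with four legal moves; the policy is parameterized by logits.
    logits = np.array([0.5, 0.1, -0.3, 0.0])
    pi = softmax(logits)                       # current policy pi

    # Hypothetical MCTS visit counts for this position (in AlphaZero these come
    # from tree search during self-play).
    visits = np.array([10.0, 40.0, 5.0, 25.0])
    tau = 1.0                                  # temperature
    pi_improved = visits ** (1.0 / tau)
    pi_improved /= pi_improved.sum()           # improved policy pi'

    # Cross-entropy loss -sum_a pi'(a) log pi(a); its gradient with respect to
    # the logits is simply pi - pi', so one gradient step moves pi towards pi'.
    learning_rate = 0.5
    logits -= learning_rate * (pi - pi_improved)

    print("pi  before:", np.round(pi, 3))
    print("pi' target:", np.round(pi_improved, 3))
    print("pi  after :", np.round(softmax(logits), 3))

    After the step, π has moved towards π’, and in the full algorithm the updated π immediately changes the next round of MCTS, so the target π’ keeps moving ahead of it.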

    For chess and Go, a perfect simulator of the environment is easy to construct: you are already training an agent to play each side, so you can use the current agent for Black to play against when you are learning to play White better, and vice versa. But for a richer domain, say a computer game, it would be very hard to learn a simulator for the environment because apart from the agents it would also have to contain the game engine, and you cannot assume that a perfect copy of the game engine is available. And for an even richer domain, say robotics, the “game engine” would include all of the relevant laws of physics — definitely not something we can assume away or easily learn.

    Therefore, MuZero (developed by DeepMind researchers Schrittwieser et al., 2020) takes the next step: it does not need to know the rules, and it learns a model of the environment in a compressed form of hidden states. This representation learning allows the model to learn the environment dynamics in a model that predicts the dynamics of hidden states only, with no need to predict the whole huge state such as the pixels of a game screen. This hidden state is exactly what I would call a world model. Now MuZero can also do MCTS, but in this case the construction of subsequent states in the tree is produced by this “dream” about the latent representations, like this (illustrations a and c below are from the MuZero paper):

    It is no wonder that MuZero was able to extend the success of AlphaZero to richer environments such as Atari games, outperforming the then-champion model-free RL algorithm called R2D2 (Kapturowski et al., 2018). What is interesting is that MuZero actually outperformed AlphaZero in settings where the rules of the game are known, reaching a higher Elo rating in Go and performing on par with AlphaZero in chess and shogi:

    Schrittwieser et al. hypothesized that “MuZero may be caching its computation in the search tree and using each additional application of the dynamics model to gain a deeper understanding of the position” — in other words, the world model added to MuZero became a way to understand the game better than even AlphaZero’s masterfully learned feature extraction. It can focus on only the important features of the environment, abstracting away everything else because its world model does not have to predict all of the features.

    This direction is being continued today. I want to highlight one more very recent approach by Alonso et al. (May 2024), called DIAMOND (DIffusion As a Model Of eNvironment Dreams), where a diffusion model serves as a world model for visual tasks such as playing Atari. In MuZero, the imaginary unrolling takes place in the latent space. In DIAMOND, the world model actually produces pictorial representations with a diffusion-based model. The diffusion process is conditioned on prior observations and actions taken by the agent (illustrations from Alonso et al., 2024):

    The motivation for this is that for many tasks, small details in the visual input—such as the ball position in Breakout or Pong or the color of a streetlight in an autonomous driving task—may have a drastic effect on the policy. And a diffusion model is a great way to capture visual representations:

    So we see that world models have proven to be useful even in domains where they are not strictly necessary. What about the domains where they seem to be inevitable? What about, say, robotics?

    World Models in Robotics and Embodied AI

    Robotics generally relies on reinforcement learning (Sutton, Barto, 2018): an agent cannot have a sufficiently robust dataset of the physical world’s reactions in advance, it must obtain this dataset by trial and error. However, unlike AlphaZero and MuZero, which can play against themselves very efficiently, we can’t run a robot in the real world billions of times.

    At this point, world modeling circles back to our main emphasis here at Synthesis AI, to synthetic data. You could say that Ha and Schmidhuber’s models were generating a synthetic representation of the world, and that MuZero was generating synthetic traces of gameplay, but there was an important difference: MuZero was doing it in its own latent space. There is no way to go back from the representation to a full-blown game state: you could train a decoder but it would probably be imperfect.

    In robotics, synthetic data often takes the form of full-scale simulators that include the relevant laws of physics, material properties, and so on, aiming for a maximally accurate representation of the physical world. I will not spend much time on a review of such simulators here, but they have been surveyed, for instance, in my book “Synthetic Data for Deep Learning”.

    We will get to using such simulators below, but in this section let us make a different point. The world model can be fully learned from experience, just like a human child does not obtain any external information except sensory inputs to the brain (kind of by definition) but still learns a world model with astonishing efficiency. 

    Researchers have attempted to replicate this with deep neural networks. One curious attempt was made back in 2016 by Agarwal et al. in a paper called “Learning to Poke by Poking”. They let a robot randomly interact with objects by poking them and seeing what happens; “seeing” here should be understood literally, the model is learning from visual input. Like this:

    This approach did not catch on, but it was developed a long time ago, and by now we have many new ideas, at least for the network architectures, so it may be worthwhile to try again. In general, even though our current understanding of reinforcement learning makes it hard to learn a full world model in reality, where experiments are very costly, to many researchers this looks like the way forward.

    One of these researchers is Yann LeCun, whose position paper “A Path Towards Autonomous Machine Intelligence” argues for just that. LeCun suggests that truly autonomous agents should be built with learned world models. In a way, it is a natural extension of the actor-critic paradigm in reinforcement learning. In RL, the agent is learning a strategy π to produce actions in a state s according to the distribution π(a|s), and the environment responds by providing the immediate reward r and the next state s’:

    In a general policy gradient algorithm, π is learned directly from experience (as shown on the left). In an actor-critic architecture, there is a separate component within the agent that learns a value function, i.e., the expected total reward an agent would obtain starting from a state s, V(s), or starting from a state s with action a, Q(s, a); this is shown on the right above. The critic helps the agent refine its policy updates.

    With a learned world model, the actor-critic interaction becomes much richer: now the agent is able to “imagine” potential responses of the environment and search for whole sequences of actions, just like MuZero, but probably without the same kind of search tree since now the actions might be very numerous or even continuous. Here is a picture from (LeCun, 2022) that shows how a single episode of the agent interacting with the environment would go:

    The sequence of actions here is entirely “in the mind” of the agent. Predicting a whole sequence of actions is probably quite expensive computationally, but once we have this prediction, we have a lot of loss function gradients to propagate: every step of the sequence can be compared with actual experience. So this approach can both help train better policies directly and also be used in a MuZero-like fashion to perform decision-time planning.
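
    Here is a very rough numpy sketch of the “imagine a whole sequence, then compare it with experience” idea; the linear environment, the linear world model, and all dimensions are invented for illustration, and the architectures LeCun actually proposes are, of course, far richer.

    import numpy as np

    rng = np.random.default_rng(0)
    d_state, d_action, T = 4, 2, 10

    # "True" environment dynamics, unknown to the agent: s' = A s + B a + noise.
    A_true = rng.normal(scale=0.3, size=(d_state, d_state))
    B_true = rng.normal(scale=0.3, size=(d_state, d_action))

    # Learned linear world model: s' ~ A_hat s + B_hat a.
    A_hat = np.zeros((d_state, d_state))
    B_hat = np.zeros((d_state, d_action))

    def imagine(s0, actions):
        # Unroll a whole trajectory purely "in the mind" of the agent.
        states = [s0]
        for a in actions:
            states.append(A_hat @ states[-1] + B_hat @ a)
        return np.array(states)

    lr = 0.05
    for episode in range(300):
        s0 = rng.normal(size=d_state)
        actions = rng.normal(size=(T, d_action))      # some behavior policy
        # Actual experience from the environment:
        real = [s0]
        for a in actions:
            real.append(A_true @ real[-1] + B_true @ a + 0.01 * rng.normal(size=d_state))
        real = np.array(real)
        # Compare the imagined rollout with experience (here we only report the
        # error; a full implementation would backpropagate through the rollout).
        rollout_err = np.linalg.norm(imagine(s0, actions) - real)
        # Simple one-step prediction-error updates for the world model:
        for t in range(T):
            err = (A_hat @ real[t] + B_hat @ actions[t]) - real[t + 1]
            A_hat -= lr * np.outer(err, real[t])
            B_hat -= lr * np.outer(err, actions[t])

    print("imagined-vs-real rollout error in the last episode:", rollout_err)

    In this toy setting, the rollout error decreases as the model learns the dynamics; the point is only to show where the per-step prediction-error gradients come from.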

    And with that, we come to our central point: what’s going on in OpenAI Sora?

    Sora: A World Simulator or “Just” a Diffusion Model?

    Ideas similar to Ha and Schmidhuber (2018) continue to define what world models mean for AI. The latest addition to the formidable OpenAI roster of foundational models, the state of the art video generation model Sora, is explicitly designed around the idea of world modeling. Their technical report is titled “Video generation models as world simulators”, although the report only states that Sora “simulate[s] some aspects of people, animals and environments from the physical world” and does not give any hard facts to support this, so we will have to make our own conclusions.

    Following OpenAI’s recent (quite understandable) practice of limited transparency, there is no detailed paper on Sora, only a rather vague blog post and report. Essentially, the only thing that is clear is that it is based on a Diffusion Transformer (DiT). We have discussed latent diffusion models on the blog before, and covered diffusion models in detail, but I have not yet explained DiT here, so let me provide some context.

    Introduced by Peebles and Xie (2022), Diffusion Transformers showed that the Transformer architecture can be useful even for a denoising element of a diffusion model. For instance, Stable Diffusion (Rombach et al., 2022) used a diffusion model to produce the latent code for a VAE-based decoder, and DiT also follows the same basic structure (the picture is copied from a previous post):

    However, this picture does not show what’s inside the denoising blocks. Stable Diffusion used a U-Net-like architecture with cross-attention layers that effectively utilized the condition, yet retained a general U-Net structure (picture from Rombach et al., 2022):

    Diffusion Transformers use a “pure” Transformer block for denoising, with a neat trick of using the layer normalization block similarly to AdaIN (Huang, Belongie, 2017) style transfer blocks; illustration from (Peebles, Xie, 2022):

    The resulting architecture proved to be much more compute-efficient than previously used U-Net-like diffusion models. In Sora, DiT is generalized to higher-dimensional patches that cover both space and time inside a video. Although the exact way it is done has not been revealed, there is at least one prior model, the GenTron by Meta researchers Chen et al. (2023), that adapts DiTs to video. Here is a generic illustration from the Sora report:
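
    Since the report gives no architectural details, here is only a rough numpy sketch of the adaLN-style conditioning trick itself, as described by Peebles and Xie: the conditioning vector (a timestep embedding plus class or text conditioning) produces per-block shift, scale, and gate values that modulate layer normalization. The shapes and the toy MLP standing in for a full attention-plus-MLP block are placeholders.

    import numpy as np

    rng = np.random.default_rng(0)
    n_tokens, d = 16, 64          # toy numbers: 16 spacetime patches, width 64

    def layer_norm(x, eps=1e-5):
        mu = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        return (x - mu) / np.sqrt(var + eps)

    # Toy sub-layer standing in for the attention / MLP parts of a real DiT block.
    W1 = rng.normal(scale=0.02, size=(d, 4 * d))
    W2 = rng.normal(scale=0.02, size=(4 * d, d))
    def sublayer(x):
        return np.maximum(x @ W1, 0.0) @ W2

    # adaLN: shift, scale, and gate are regressed from the conditioning vector c
    # instead of being learned constants of the LayerNorm itself.
    W_cond = rng.normal(scale=0.02, size=(d, 3 * d))
    def adaln_block(x, c):
        shift, scale, gate = np.split(c @ W_cond, 3)
        h = layer_norm(x) * (1.0 + scale) + shift   # modulated LayerNorm
        return x + gate * sublayer(h)               # gated residual (adaLN-Zero
                                                    # initializes the gate to zero)

    x = rng.normal(size=(n_tokens, d))   # noised latent patches
    c = rng.normal(size=d)               # conditioning vector
    y = adaln_block(x, c)
    print(y.shape)                       # (16, 64)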

    But I digress. Regardless of the model itself, Sora provides great video generation results that often exactly follow our intuitive understanding of physics, although sometimes they fail in that regard. Does this mean that Sora is at least halfway to the holy grail of learning an operational world model from raw video inputs?

    At this point, let me link to a very detailed blog post by Raphaël Millière called “Are Video Generation Models World Simulators?”. It covers many of the points that we are going through here, and I recommend it in its entirety. In particular, Dr. Millière considers several definitions of a “world model” and carefully studies whether Sora is likely to fit any of them. His conclusions, which I fully endorse, are as follows:

    • being a single end-to-end model that operates fully in latent space, Sora does not have separate components needed to actually have an “internal physics engine”, so it cannot be a “world simulator” in the sense of synthetic data simulators like MuJoCo;
    • however, the structure of its latent space may well be sufficiently complex to capture and predict certain physical phenomena based on its latent representations.

    To me, this is an interesting discussion (and a great post, please do read it!) but these conclusions slightly miss the point. Of course a deep learning model does not have an internal physics engine unless one is artificially attached to it (see below). You and I, however, may not have one either!

    Again, I can only recommend reading through the section by Dr. Millière on “intuitive physics”: human infants learn to expect certain physical properties very quickly, and there is a well-established “IPE hypothesis” that posits the existence of an “intuitive physics engine” in our minds. But even for humans, it’s just a hypothesis, and there is an opposite opinion that human physical reasoning is based on visual shortcuts and generally predicting what we will see next rather than approximating the relevant laws of physics.

    For Sora and similar models, this hypothetical intuitive engine is even harder to believe in. Some examples generated by Sora clearly violate even our basic intuitions like object permanence or collision properties, which is, of course, expected from a diffusion-based generative model, but not really expected from a physics simulator, however “approximate” it is:

    The question for me here is: does it really matter? We humans probably don’t have a built-in Unreal Engine to tell us how the world works. But we have an intuitive understanding of the world that allows us to make predictions, and these predictions are accurate enough for most practical purposes. Sora is not quite there yet, but if some upcoming Sora 2 or Sora 3 does have a similar understanding, it will be enough to disallow videos with such internal contradictions. 

    Still, this may sound like a lot of work for naught. Why should we wait until some latent representation learns to approximate Unreal Engine 5 from scratch when we already have Unreal Engine 5? Indeed, there have been attempts to combine machine learning models with external tool calls to world simulators; let’s discuss them before we conclude the post.

    Adding a “true” world simulator: grounded LLMs

    Even with all the RLHF fine-tuning and other advanced techniques, large language models primarily train as their name suggests: by predicting the next token of text. By default, they don’t have access to external tools like calculators or physics engines, and learning exclusively from text can lead to simple mistakes in this context.

    In other words, a large language model, no matter how smart, is akin to a medieval scholastic thinker who derives knowledge exclusively from Aristotle but cannot conduct experiments or use empirical evidence. It would make a lot of sense to let an LLM call some external tools that would provide this evidence to use in the LLM’s reasoning and to inform its replies. This is called grounding, and it is indeed known to be a good way to improve LLM results:

    For example, one significant result in this direction was the Toolformer approach (Schick et al., 2023), where an LLM learns to use a new tool from a brief description of its API. As a result, the LLM can access a wide variety of tools and learn new ones on the fly (examples from Schick et al., 2023):

    And yes, there already exist approaches that ground LLMs with “more real” world simulators to help them reason about our physical three-dimensional world.

    For example, the recently developed Grounded 3D-LLM (Chen et al., May 2024) adds special referent tokens that correspond to objects in the 3D environment where the LLM is planning some actions:

    Its 3D point cloud encoder is trained with a cross-modal pretraining procedure based on contrastive losses, similar to CLIP (OpenAI, 2021; see also our earlier post), and the LLM is fine-tuned with LoRA to understand how to work with referent tokens:

    The work nearest to our current discussion was done by Liu et al. (2023) from Google Research. They recognize the problem of linguistic reasoning about the physical world and develop an approach called Mind’s Eye that lets an LLM query a computational physics engine, in this case DeepMind’s MuJoCo (Todorov et al., 2012).

    The LLM writes rendering code and runs the external physics engine, informing its output with simulation results:

    The authors show that this kind of grounding does help LLMs reason better with this “mind’s eye” powered by a computer simulation. So in a way, we already know how to insert a realistic externally implemented world model into an LLM to inform its reasoning about the world. Still, there are at least two important missing pieces:

    • first, tool use in LLMs still leaves quite a lot to be desired; it is an active area of research where Toolformer has already been succeeded by many works (see, e.g., Tang et al., 2023; Qin et al., 2023; Zhuang et al., 2023), and there appears to be still some way to go before grounded LLMs reach maturity;
    • second, even if this interaction between an external world model and an LLM worked perfectly, it would be quite far from physics-aware video generation: in this section, we are talking about specifying queries for the simulator and running experiments with simulated physics, but even if any videos do come out of it, they will be limited by the simulator’s capabilities.

    In my opinion, a foundational model cannot practically run a complicated external tool every time it needs to generate something, but it can certainly use an external simulator for training. Software libraries such as MuJoCo can provide a foundational model, be it an LLM or a multimodal generation tool, with an endless stream of synthetic data and, more importantly, synthetic environments that could be used to experiment and learn about the physical world. This, again, brings us back to our favorite domain of synthetic data, which would include a synthetic physics simulator as well.

    Conclusion

    In this post, we have discussed world models in modern AI, starting from a very abstract notion of a world model and gradually making it more explicit until, in the end, we showed how to add an external physics-based simulator engine to state of the art LLMs.

    I would like to conclude this post by mentioning a recent work that appeared in May 2024: a paper by MIT researchers Huh et al., titled “The Platonic Representation Hypothesis”. In agreement with Plato’s ideal world of perfect forms (eidos), the authors posit that sufficiently expressive neural networks will converge to the same “optimal” representation of reality in their latent spaces, regardless of the modality they are trained on. This hypothesis is supported by several observations and empirical evidence in this intriguing work:

    Still, despite the appearance of Sora that is head and shoulders above previously existing video generation models, and despite recent models that model visual environments with diffusion models and ground LLMs with interactive physics simulators, it looks like the field of applying world models to modern generative AI is still at its very inception. It will be exciting to see how world models become better and more prominent across various AI-related domains, and here at Synthesis AI we hope to spearhead at least some of these applications. See you next time!

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • Lost in Context: How Much Can You Fit into a Transformer

    Lost in Context: How Much Can You Fit into a Transformer

    The announcement of Gemini 1.5 by Google was all but eclipsed by OpenAI’s video generation model Sora. Still, there was one very important thing there: the promise of processing a context window of up to 1 million tokens. A very recent announcement of new Claude models by Anthropic also boasts context windows of up to 1M tokens, with 200K tokens available at launch. Today, we discuss what context windows are, why they are a constraint for Transformer-based models, how researchers have been trying to extend the context windows of modern LLMs, and how we can tell whether a large context window actually works as advertised. By virtue of nominative determinism, this is a very long post even by the standards of this blog, so brace yourself and let’s go!

    The Quadratic Complexity of Self-Attention

    One of my last posts was intended to provide detailed background on the Transformer architecture based on the original “Attention is All You Need” paper, so I will not repeat the introduction again. All we need now (pardon the pun) is to recall how self-attention itself works in general terms; here is a picture from that background post:

    There is an important problem here that follows from the very structure of self-attention. The formula that everyone has been copying thousands of times looks as follows:

        \[\mathrm{Attention}(\mathbf{Q},\mathbf{K},\mathbf{V}) = \mathrm{softmax}\left(\frac{1}{\sqrt{d_k}}\mathbf{Q}\mathbf{K}^\top\right)\mathbf{V}.\]

    The matrix computation inside the softmax is what matters most for us today. To get to the final result we need to compute self-attention weights between every pair of input tokens

        \[\alpha_{ij} = \mathrm{softmax}\left(\frac{1}{\sqrt{d_k}}\mathbf{q}_i\mathbf{K}^\top\right)_j,\]

    which means that we need to get a result of size L\times L, quadratic in the number of input tokens.

    This is in fact a tradeoff that self-attention brings to the table compared to, say, recurrent networks. An RNN can read its input in linear time, O(dL), where d is the input dimension, because it reads the input consecutively, token by token. But this means that to get from token 1 to token L you have to make L steps, one by one. This is why RNNs struggle with long-term dependencies: the influence of a token gets diluted and reduced as you make your way through the intermediate sequence. A lot of what has been happening in RNNs, including LSTM and GRU recurrent units, has been designed specifically to alleviate this problem. In a Transformer, the problem just vanishes: you get from any token to any other in a single step—at the cost of quadratic complexity of the layer.

    Unfortunately, quadratic complexity of self-attention is not some intermediate result that you can hope to optimize away by designing a more efficient algorithm. The actual number of weights is quadratic, and while there may exist faster approximations (we will discuss them below), we cannot get it entirely right in subquadratic time.
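
    To see exactly where the quadratic factor comes from, here is a minimal numpy implementation of a single self-attention head (all dimensions are arbitrary); the matrix of attention weights has shape L\times L no matter how you compute it.

    import numpy as np

    def self_attention(X, Wq, Wk, Wv):
        """Single-head self-attention; X has shape (L, d_model)."""
        Q, K, V = X @ Wq, X @ Wk, X @ Wv            # each of shape (L, d_k)
        scores = Q @ K.T / np.sqrt(Q.shape[-1])     # (L, L): the quadratic part
        scores -= scores.max(axis=-1, keepdims=True)
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
        return weights @ V                          # (L, d_k)

    L, d_model, d_k = 1024, 512, 64
    rng = np.random.default_rng(0)
    X = rng.normal(size=(L, d_model))
    Wq, Wk, Wv = (rng.normal(scale=0.02, size=(d_model, d_k)) for _ in range(3))
    out = self_attention(X, Wq, Wk, Wv)
    print(out.shape)   # (1024, 64); the intermediate weights took 1024 x 1024 floats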

    Actually, there even exist negative complexity theoretic results proven by Keles et al. (2022) who provide a kind of “no free lunch” theorem for self-attention. Specifically, Keles et al. prove that:

    • you cannot compute self-attention exactly in subquadratic time unless the Strong Exponential Time Hypothesis (SETH) fails, i.e., unless SAT can be solved significantly faster than exhaustive search; SETH is a strengthening of the P!=NP assumption and is just as widely believed to be true;
    • moreover, under the same assumption you cannot even get a subquadratic algorithm that approximates all attention weights up to a constant, either an additive constant or a multiplicative one.

    Here is a summary table from the paper where checkmarks correspond to proven negative results; as you can see, some additive approximations are still possible but overall it looks pretty comprehensive:

    Quadratic complexity does not look like an obstacle as insurmountable as the exponential complexity of NP-hard problems. After all, plenty of commonly used algorithms have quadratic complexity. For example, bubble sort has O(n^2) complexity in the worst case, and even long multiplication of two n-bit numbers has quadratic complexity (quadratic in total input size, of course, which is logarithmic in the values of the numbers themselves).

    But this is an illusion: for common problems such as sorting and multiplication, naive algorithms may be quadratic but as soon as you want to scale them up to large arrays and really large numbers, you have to find something more efficient. There are plenty of O(n \log n) algorithms for sorting. Multiplication has been done with Karatsuba’s algorithm in time O(n^{1.58}) since the 1960s, and a recent highly acclaimed result by Harvey and van der Hoeven (2021) has brought integer multiplication down to the same O(n \log n) complexity (although the algorithm itself is probably too involved to find much practical use). In fact, you would be hard pressed to find practical problems that actually have quadratic complexity that people don’t know how to reduce further. In this regard, computing self-attention is actually an important example for theoretical computer science as well.

    Authors of the original Transformer did not yet have the negative theoretical results but they already understood that quadratic complexity is hard to scale. They remark on the quadratic complexity of self-attention throughout the paper and even propose a middle ground solution that can alleviate it. Let us call it the sliding window attention: it restricts the self-attention mechanism to a subwindow of size r around the current token. Then

    • instead of O(dL^2), where d is the embedding dimension, the self-attention layer will only have complexity O(drL) since now we compute only r attention weights around each of L input tokens;
    • but the tradeoff is that now we cannot get from any token to any other in a single step, we have to make L/r sequential steps from one side of the input sequence to the other.

    Despite promising results in the original paper, this solution has not really caught on; I believe that it sacrifices too much to be useful, working too much like an RNN as a result. But increasing context size has remained one of the central problems for Transformer-based architectures ever since they were first designed in 2017. In the rest of this post, we will discuss other ways to reduce the complexity of self-attention.

    Sparse Attention Mechanisms: Do We Need the Full Square?

    The first obvious research idea here goes like this: okay, suppose we do have to have quadratic complexity to compute a quadratic number of attention weights, but do we really need all these weights? The sliding window attention from the original “Attention is All You Need” paper falls into this category as well, but more successful approaches were developed later.

    In 2020, researchers from the Allen Institute for AI, Beltagy et al., proposed Longformer, a replacement for the full-scale quadratic attention mechanism that scales linearly with input size. The basic idea is to still use the sliding window attention—after all, local context is indeed usually the most important—but augment it with several important tricks. Here is a general illustration from the paper:

    The first trick here, shown in part (c), is to use a dilated sliding window that skips over some inputs, using, say, every second one. This idea is well known in convolutional architectures, where it is also used to increase the receptive field of neurons, covering more ground in fewer layers. In the one-dimensional context it was very successfully used, say, in the WaveNet architecture; the illustration of WaveNet shown below explains how repeated dilation can exponentially increase the receptive field, which in our case means reducing the number of steps needed to go over the entire input:

    Moreover, in Transformer’s multi-head self-attention you can use different dilations for different heads! For example, if the first head uses odd-numbered tokens from the input, and the second head uses even-numbered tokens, the window size doubles with the same number of attention weights but the layer does not skip anything from the input.

    The second trick, shown in part (d) above, is to have several tokens that have global attention so that their weights span the whole input. Now for every token with global attention you have O(L) attention weights but the assumption is that there are only a constant or logarithmic number of such tokens so the overall number of attention weights remains low. 

    This trick is well known to anyone who has used, say, BERT embeddings in practice: it is usually very helpful to add a special [CLS] token and use it to capture global properties of the input, e.g., train a classifier on [CLS] embeddings. This is exactly the function of global attention in Longformer, only now we remove most of the attention weights between other tokens.
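
    As a hedged sketch of how such a pattern could be assembled, here is numpy code that builds a boolean attention mask combining a (dilated) sliding window with a few global tokens; the window size, dilation, and choice of global tokens are arbitrary, and the full L\times L mask is materialized only for illustration (a real implementation never constructs it). Counting the allowed positions shows linear rather than quadratic growth.

    import numpy as np

    def longformer_mask(L, window=4, dilation=1, global_tokens=(0,)):
        """Boolean (L, L) mask: True where attention is allowed."""
        mask = np.zeros((L, L), dtype=bool)
        idx = np.arange(L)
        for offset in range(-window, window + 1):
            j = idx + offset * dilation                 # (dilated) sliding window
            ok = (j >= 0) & (j < L)
            mask[idx[ok], j[ok]] = True
        for g in global_tokens:                         # global attention tokens
            mask[g, :] = True
            mask[:, g] = True
        return mask

    for L in (256, 512, 1024):
        m = longformer_mask(L, window=4, dilation=2, global_tokens=(0,))
        print(L, m.sum(), "allowed weights out of", L * L)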

    Longformer proved to be quite efficient in practice, but there also exist some conceptually more interesting ways to implement sparse attention. In “Generating Long Sequences with Sparse Transformers”, four OpenAI researchers (Rewon Child, Scott Gray, Alec Radford, and Ilya Sutskever) studied the attention patterns learned by full quadratic self-attention layers and found that most layers had sparse attention patterns. Therefore, they proposed to formally restrict self-attention to sparse patterns: each attention head i has a subset A_i of input tokens that it attends to. At the same time, Child et al. require that the factorization into A_i is valid, i.e., that every token can attend to every other in p attention steps. 

    It turns out that one can choose efficient factorizations A_i, finding valid A_i of size O(L^{1/p}), which is obviously minimal since you need to get to L in p steps. Specifically, they build upon the natural general idea for such a factorization: one head attends to l sequential locations, and the other head uses a kind of dilated sliding window with stride l, where l is close to \sqrt{L}.

    There are two basic ways to implement this; in part (b) of the figure below dilated attention moves along the tokens, and in part (c) there are selected positions that a lot of other tokens attend to, but a composition of such two heads (light blue and dark blue) connects every pair of tokens in both cases:

    Child et al. comment that the first kind of sparse decomposition (part (b) of the figure) works well for data with periodic structure such as images or music, while fixed attention patterns (part (c) of the figure) are better for data without clear periodicity, such as text. They also make use of an earlier OpenAI development: block-sparse weight matrices implemented at the hardware level, as GPU kernels (Gray, Radford, Kingma, 2017); this makes sparse self-attention matrices very efficient in practice.

    The last paper I want to mention in this section is “Big Bird: Transformers for Longer Sequences” by Google researchers Zaheer et al. The structure of their sparse self-attention mechanism combines previously developed ideas: global tokens that attend to everything, sliding windows that attend to local context, and a set of random attention weight positions that add expressivity at a small computational cost:

    They view self-attention as a directed graph that shows which positions attend to which other positions; the figures we have seen in this section are adjacency matrices for such graphs. Under this view, the requirement that every token can attend to any other in a few layers turns into the requirement that the adjacency graph has short paths between the nodes. Fortunately, random graphs do have that property. Zaheer et al. consider two random graph constructions:

    • in the Erdős–Rényi model, every edge is chosen at random with a fixed probability; it is known that in this model, shortest paths have on average logarithmic length, which would suit us fine, but the Erdős–Rényi model does not have another important feature: locality, i.e., in this approach a token will not attend to the local context around it;
    • therefore, Zaheer et al. move on to the Watts–Strogatz model that begins with a ring lattice (where every token attends to a fixed size window around itself) and then adds random connections; this can achieve a good balance between shortest path lengths and local context;
    • finally, the Big Bird model also adds global tokens, as described above (and as shown in the figure above).

    In summary, in this section we have discussed several approaches that reduce the number of weights by decomposing the full attention matrix into a composition (product) of several sparse matrices. We have viewed it as a sparse graph where it takes several steps to go from one vertex to another, or you can view it in a way similar to dilated convolutions with added global tokens. But there is another way to arrive at a similar approach: sparse matrix decompositions; let us consider it separately.

    Low-Rank Decompositions: Linformer and Performer

    Another direction from which we can attack the quadratic complexity is to use low-rank approximations and/or projections to reduce the size of the matrices. This is a very important and useful trick: if you have a huge N\times M matrix X that you cannot really work with, you can try to reduce it by decomposing X into a product of two rectangular matrices, one of size N\times k (denoted U below) and another of size k\times M (denoted V):

    The actual mathematical result, called the singular value decomposition (SVD), says that you can do it if X has rank at most k, but even if not, you can consider a low-rank approximation to X in this form. For example, recommender systems make use of the SVD all the time. In a recommender system:

    • X is the matrix of ratings (or likes, or any kind of user responses), where N is the number of users and M is the number of items, and most values of X are unknown (a user has surely rated only a small subset of items);
    • U is the matrix of user features, i.e., dense numerical vectors of size k that describe every user; V is a similar matrix of item features;
    • and if we find a good low-rank approximation UV for known elements of X, it will allow us to give predictions for unknown elements of X by taking the dot product of user and item features (that’s exactly what’s happening in matrix multiplication).

    So how does this idea help reduce the complexity of self-attention? We will consider two variations of it.

    First, the Linformer architecture, proposed by Facebook AI researchers Wang et al. (2020), is probably the most straightforward way to apply low-rank approximations to self-attention. The authors begin by noting that self-attention matrices in practice usually have low rank:

    The plots on the left show that the first 128 eigenvalues (out of embedding size 512) already capture most of the information in the self-attention matrix, and the plot on the right, which shows what share of the cumulative eigenvalue mass is captured by the first 128 out of 512, indicates that this effect is more pronounced in higher layers. The authors even provide a theoretical justification for this, which we will not go into.

    What this means is that we can replace the self-attention matrix with a low-rank approximation. To do that, we need to add a projection matrix after the key and value matrices, projecting the context window size n down to some smaller value k:

    The query matrix remains of size n, and we get a formula for self-attention where there are plenty of k\times n matrices but no n\times n matrices:

        \[\mathrm{head}_i=\mathrm{softmax}\left(\frac{1}{\sqrt{d_k}}\mathbf{Q}\mathbf{W}_i^Q\left(\mathbf{E}_i\mathbf{K}\mathbf{W}_i^K\right)^\top\right)\cdot\left(\mathbf{F}_i\mathbf{V}\mathbf{W}_i^V\right),\]

    where \mathbf{E}_i and \mathbf{F}_i are k\times n projection matrices, so the final product is a product of an n\times k matrix of attention weights and a k\times d projected matrix of values.
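
    Here is a minimal numpy sketch of a single Linformer-style head following this formula; all shapes are arbitrary, and the projections \mathbf{E} and \mathbf{F} are random placeholders for what would be learned parameters.

    import numpy as np

    def softmax(z):
        z = z - z.max(axis=-1, keepdims=True)
        e = np.exp(z)
        return e / e.sum(axis=-1, keepdims=True)

    rng = np.random.default_rng(0)
    n, d_model, d_k, k = 4096, 512, 64, 256     # sequence length n, projected length k

    X = rng.normal(size=(n, d_model))
    Wq, Wk, Wv = (rng.normal(scale=0.02, size=(d_model, d_k)) for _ in range(3))
    E, F = (rng.normal(scale=1.0 / np.sqrt(n), size=(k, n)) for _ in range(2))

    Q = X @ Wq                                   # (n, d_k)
    K_proj = E @ (X @ Wk)                        # (k, d_k): keys projected down to length k
    V_proj = F @ (X @ Wv)                        # (k, d_k): same for values
    attn = softmax(Q @ K_proj.T / np.sqrt(d_k))  # (n, k) instead of (n, n)
    head = attn @ V_proj                         # (n, d_k)
    print(head.shape)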

    As a result, the Linformer does scale very well with the input sequence length, as shown in the top right plot above. But there are other ways to apply similar ideas too.

    The Performer architecture, developed in a collaboration between Google, Cambridge, DeepMind, and Alan Turing Institute researchers (Choromanski et al., 2021), introduces the Fast Attention Via positive Orthogonal Random features approach (FAVOR+) based on a low-rank decomposition with random features. Let’s dive into some details here.

    We know that in a self-attention layer, the Transformer uses an attention weight matrix A to create convex combinations of the value vectors V. Moreover, as we know, A is in fact produced by passing the matrix of (normalized) dot products between query and key vectors through the softmax function:

        \[\mathbf{A}=\mathrm{softmax}\left(\frac{1}{\sqrt{d_k}}\mathbf{Q}\mathbf{K}^\top\right).\]

    FAVOR+ generalizes this construction as follows: let us consider arbitrary L\times L matrices \mathbf{A} produced as

        \[A_{i,j} = K(\mathbf{q}_i, \mathbf{k}_j),\]

    where K is a kernel function K: \mathbb{R}^d\times \mathbb{R}^d\to\mathbb{R}_+. For Transformer self-attention, K is the softmax (exponential) kernel of the scalar product normalized by \sqrt{d}.

    Suppose that we can construct a randomized mapping (random feature map) \phi such that \phi maps the input embedding into some smaller space of dimension r, \phi:\mathbb{R}^d\to\mathbb{R}^r, and in expectation \phi gives us the kernel function:

        \[K(\mathbf{x},\mathbf{y}) = \mathbb{E}\left[\phi(\mathbf{x})^\top \phi(\mathbf{y})\right].\]

    If we can find such a random feature map \phi, it will give us a natural way to approximate the attention mechanism: 

    • map the embeddings into a smaller r-dimensional space via \phi and represent the computation of attention weights \mathbf{A} as above;
    • but then compute the product of the new key matrix \phi(\mathbf{K}) and the value matrix \mathbf{V} before multiplying by \phi(\mathbf{Q}) on the left; in this way, we never compute an L\times L matrix but instead compute an r\times d matrix of dot products of vectors of length L and then multiply it by an L\times r matrix on the left.

    Here is what it looks like; \mathbf{Q}' and \mathbf{K}' are \phi(\mathbf{Q}) and \phi(\mathbf{K}) respectively:

    Choromanski et al. consider several different ways to define the random features, and even provide theoretical results that show how positive orthogonal random features can lead to good approximations for the softmax kernel used in regular self-attention.
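
    Here is a simplified numpy sketch of this idea with positive random features for the softmax kernel; the orthogonality trick and other refinements from the paper are omitted, and all dimensions are arbitrary.

    import numpy as np

    rng = np.random.default_rng(0)
    L, d_k, d_v, r = 2048, 64, 64, 256     # r random features instead of L x L weights

    def phi(X, W):
        # Positive random features for the softmax kernel: E[phi(x).phi(y)] = exp(x.y),
        # in the spirit of Choromanski et al. (2021), without orthogonalization.
        return np.exp(X @ W.T - 0.5 * (X ** 2).sum(-1, keepdims=True)) / np.sqrt(W.shape[0])

    Q = rng.normal(scale=0.3, size=(L, d_k)) / d_k ** 0.25   # fold in the 1/sqrt(d_k) scaling
    K = rng.normal(scale=0.3, size=(L, d_k)) / d_k ** 0.25
    V = rng.normal(size=(L, d_v))
    W = rng.normal(size=(r, d_k))                            # random projection directions

    Qp, Kp = phi(Q, W), phi(K, W)                            # (L, r) each
    KV = Kp.T @ V                                            # (r, d_v): no L x L matrix anywhere
    norm = Qp @ Kp.sum(axis=0)                               # (L,): row normalization
    approx = (Qp @ KV) / norm[:, None]                       # approximate attention output

    # Compare with exact softmax attention on the same inputs:
    scores = Q @ K.T
    exact = np.exp(scores - scores.max(-1, keepdims=True))
    exact = (exact / exact.sum(-1, keepdims=True)) @ V
    print("mean abs error:", np.abs(approx - exact).mean())

    The approximation error decreases as the number of random features r grows, and all the matrices involved are only of size L\times r or r\times d.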

    In summary, low-rank decompositions provide another very efficient way to cut down on the quadratic complexity: they make it possible to never construct the full n\times n matrix for a large n, dealing instead only with n\times k projection matrices and small matrices in the reduced dimension k. In the next section, we will consider a couple of ideas that are different and do not fall neatly into the categories of either sparse attention or low-rank decompositions.

    Chunking the Attention: GAU, MEGA, and Reformer

    The last set of ideas I want to discuss is a direction that, surprisingly, has not yet appeared in this post: what if we just tweak the network architecture of self-attention? Usually that would mean that the quadratic complexity of attention remains in place, but is constrained to small subsets of the input, and these small subsets (chunks) are connected to each other in some additional way. In this section, we consider different ways to do that.

    As the first representative example of this direction, let us discuss the Gated Attention Unit (GAU; Hua et al., 2022), a variation of the gated linear unit (GLU) with an attention mechanism. It also invites us to take a different look at some of the matrix calculations we have already seen.

    But first let us describe the GAU itself. It combines the familiar multi-head self-attention mechanism with the gated linear unit (GLU) shown on the left of this figure (Hua et al., 2022):

    A gated linear unit applies two dense transformations to the input and multiplies them componentwise, in effect gating the obtained representations with each other (recall, e.g., how LSTMs work: they use gating mechanisms a lot).

    So the Gated Attention Unit (GAU), as shown on the right of the figure above, combines a GLU and regular multi-head self-attention, sharing the computations of these layers: basically it is a GLU where one of the representations is also multiplied by the matrix of attention weights from other elements in the sequence, computed as the usual \mathrm{softmax}\left(\mathbf{Q}\mathbf{K}^\top\right). Hua et al. show that after this modification, you can replace softmax with ReLU and also simplify the computation of Q and K matrices; please see the paper for more details.

    Still, GAU is quadratic in the input size! But it will lend itself better to the chunking approximation that follows. We know that self-attention is quadratic as written:

        \[\mathrm{Attention}(\mathbf{Q},\mathbf{K},\mathbf{V}) = \mathrm{softmax}\left(\frac{1}{\sqrt{d_k}}\mathbf{Q}\mathbf{K}^\top\right)\mathbf{V}.\]

    But if we forget about the softmax and try to approximate this attention mechanism with \mathbf{Q}\mathbf{K}^\top\mathbf{V}, we can rearrange the terms:

        \[\mathrm{Attention}_{\mathrm{linear}}(\mathbf{Q},\mathbf{K},\mathbf{V}) = \mathbf{Q}\left(\mathbf{K}^\top\mathbf{V}\right).\]

    Now the \mathbf{M}=\mathbf{K}^\top\mathbf{V} matrix in brackets is a d\times d matrix, and there is no quadratic dependency on the input length L. Moreover, we can compute the matrix \mathbf{M} incrementally, step by step: for every new input at time t,

        \[\mathbf{M}_t = \mathbf{M}_{t-1} + \mathbf{K}_t^\top\mathbf{V}_t,\]

    and we can just store its value in a cache of size O(d^2) and add new values as they arrive. Here is how Hua et al. illustrate this; the left side of their figure contrasts quadratic attention and linear attention:
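
    A small numpy sketch confirms that the incremental cache gives the same result as the full computation \mathbf{Q}\left(\mathbf{K}^\top\mathbf{V}\right) (softmax and normalization are omitted here, exactly as in the formulas above):

    import numpy as np

    rng = np.random.default_rng(0)
    L, d = 512, 64
    Q, K, V = (rng.normal(size=(L, d)) for _ in range(3))

    # Full linear attention: Q (K^T V), with a d x d matrix in the middle.
    full = Q @ (K.T @ V)

    # Incremental version: keep a running cache M_t of size d x d.
    M = np.zeros((d, d))
    outputs = np.zeros((L, d))
    for t in range(L):
        M += np.outer(K[t], V[t])     # M_t = M_{t-1} + k_t^T v_t
        outputs[t] = Q[t] @ M         # uses only tokens seen so far (causal variant)

    # The last row agrees with the full computation; earlier rows are causal prefix results.
    print(np.allclose(outputs[-1], full[-1]))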

    This linear attention is similar to the attention mechanisms used in RNNs. Hua et al. try to have the best of both worlds: they split the input into chunks and use quadratic local attention inside each chunk, just like regular GAU, and linear global attention between chunks, as shown on the right of the figure above. As a result, they report results very similar to the basic Transformer self-attention with much less complexity.

    GAU has found applications in speech analysis (Tsai, Khoa, 2023) and has been combined with convolutional networks such as U-Net to improve image segmentation (Wang et al., 2023). Next, we consider one more detailed example of how GAUs have been used and modified.

    Building on GAU, a collaboration of Carnegie Mellon, USC, and Meta AI researchers presented the Moving Average Equipped Gated Attention, or MEGA (Ma et al., 2022). Unfortunately, this paper is not yet among the top results for the “Mega Transformer” Google search, but here’s hoping that citations will accumulate.

    MEGA is based on the idea of an exponential moving average. Suppose that you want to smooth out a series of numbers so that every resulting number is influenced by several past values in the series. You could take an average over a sliding window, and it would probably give the desired effect, but there is an important disadvantage: you will have to keep the entire window in memory; otherwise you won’t be able to subtract the values that come out of the window as it slides.

    If you need to save memory, you want to change the semantics and instead of a sliding window use an update rule like this:

        \[y_t = \alpha x_t + (1-\alpha) y_{t-1}.\]

    If you unroll this formula to previous steps, you will see that the exponential moving average gets a theoretically infinitely long memory with exponentially decaying weights of the inputs, hence the “exponential”; here is an illustration from Ma et al. (2022):
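
    In code, the contrast between a sliding-window average and an EMA is mostly a question of what has to be kept in memory; here is a tiny numpy sketch with an arbitrary smoothing constant:

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=1000)

    # Sliding window average: must remember the last `window` values.
    window = 32
    sliding = np.convolve(x, np.ones(window) / window, mode="valid")

    # Exponential moving average: O(1) memory, exponentially decaying weights.
    alpha = 2.0 / (window + 1)
    ema = np.empty_like(x)
    ema[0] = x[0]
    for t in range(1, len(x)):
        ema[t] = alpha * x[t] + (1 - alpha) * ema[t - 1]

    print(sliding.shape, ema.shape)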

    The weight \alpha controls how long you want this memory to be. This is a classical trick known for centuries, so how do we apply it to the Transformer architecture? Here is the general overview from Ma et al. (2022):

    As you can see in (a), the basic Transformer architecture remains in place, but the self-attention layer is replaced with a “Mega layer”. The idea of this layer, shown in (b), is as follows:

    • first, the input sequence \mathbf{X}, which is an L\times d matrix, is expanded into h dimensions via a d\times h expansion matrix \boldsymbol{\beta}:

          \[\mathbf{u}^{(j)}_t=\boldsymbol{\beta}_j\mathbf{x}_{t,j};\]

    • second, EMA is applied to the expanded matrix \mathbf{U}; note that the Mega architecture uses what they call damped EMA, where the influence of previous steps is reduced by the damping factor \boldsymbol{\delta}:

          \[\mathbf{h}_t^{(j)} = \boldsymbol{\alpha}_j\odot \mathbf{u}_t^{(j)}+(1  - \boldsymbol{\alpha}_j\odot \boldsymbol{\delta}_j)\odot \mathbf{h}_{t-1}^{(j)};\]

    • third, the result is projected back with a d\times h projection matrix \boldsymbol{\eta}:

          \[\mathbf{y}_{t,j}=\boldsymbol{\eta}_j^\top\mathbf{h}_t^{(j)}.\]

    But wait, this looks nothing like figure (b) above! Indeed, this is all hidden between the “Layer input” and “EMA output”. The important part of figure (b) is what happens with the result, and this is where recurrent networks come into play. 

    MEGA uses the Gated Recurrent Unit (GRU; Cho et al., 2014), a standard recurrent architecture developed as a simplification of LSTM, and the Gated Attention Unit that we discussed above. I will not go into too much detail about them, but basically GAU is the attention unit shown in figure (c) above, and then the results are processed as a sequence with the GRU unit as shown in figure (b).

    Interestingly, we have not yet done anything with quadratic complexity! The architecture above introduces a stronger inductive bias into the attention mechanism, i.e., makes it more position-sensitive. After doing that, MEGA can simply break the modified attention mechanism into chunks with quadratic attention applied to local segments, and connections between chunks supported in particular through the exponential moving average mechanism:

    So the actual reduction of quadratic complexity is very simple here, although different from the basic GAU: in GAU, information flows between quadratic chunks in a linear RNN-like style, while here we have a more complex (and hopefully more informative) relationship supported by the exponential moving average.

    But recurrent connections between sequential blocks are not the only way to chunk the input. Our final idea, called the Reformer, comes from Google Research (Kitaev et al., 2020). The Reformer actually introduces several interesting tricks to improve the performance of a Transformer, including reversible layers from Gomez et al. (2017), which allow the model to store only a single copy of activations, and splitting activations in feedforward layers, which further saves memory.

    But we are interested specifically in what they propose to alleviate the quadratic complexity of self-attention. For that, Kitaev et al. go back to the self-attention formula that we started with:

        \[\mathrm{Attention}(\mathbf{Q},\mathbf{K},\mathbf{V}) = \mathrm{softmax}\left(\frac{1}{\sqrt{d_k}}\mathbf{Q}\mathbf{K}^\top\right)\mathbf{V}.\]

    Note that we are not really interested in the actual values of the matrix \mathbf{Q}\mathbf{K}^\top, only in \mathrm{softmax}(\mathbf{Q}\mathbf{K}^\top), and the results of a softmax function are dominated by a few largest elements, while the smaller elements have negligible influence on the result. If we have something like 64000 different keys, for each query \mathbf{q}_i it would be enough to consider only 32 or 64 keys \mathbf{k}_j nearest to \mathbf{q}_i in the embedding space. How do we find these nearest neighbors?

    Finding nearest neighbors is hard in the worst case, but there is a well-known trick in computer science called locality-sensitive hashing (LSH; see, e.g., Wikipedia and references therein): we can get a very efficient approximate algorithm for finding nearest neighbors if we can find a hash function h(x) that assigns nearby vectors to the same bucket with high probability.

    In a Euclidean space of embeddings, such a hash function can be given by random projections: let’s fix a random matrix that projects x to a much lower dimension and then define the hash as a set of buckets that values on different axes in that lower dimensional space fall into. In the illustration below (Kitaev et al., 2020), the points are projected onto a sphere (circle), then the points are randomly rotated, and the hash is defined by the set of indices of the segments (there are four colored segments in the figure) where a point falls after these rotations:

    After hashing, we can look for nearest neighbors only in the current hash bucket, and with high probability we will not miss anything. For self-attention, this means that we restrict the attention matrices to a given hash bucket; the general scheme is as follows (Kitaev et al., 2020):

    This trick will only work if we are actually looking for nearest neighbors rather than doing generic retrieval, i.e., only if Q=K. Interestingly, this restriction does not lead to any significant loss of quality! Kitaev et al. compare this “shared-QK attention”, where the weight matrices for keys and queries are fixed to be equal, with regular Transformers and find negligible difference in performance.
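
    Here is a simplified numpy sketch of the angular LSH bucketing with shared queries and keys (a single hash round; the real Reformer combines several rounds, sorts tokens by bucket, and chunks them, but the core idea is just this):

    import numpy as np

    rng = np.random.default_rng(0)
    L, d, n_buckets = 1024, 64, 16

    X = rng.normal(size=(L, d))        # shared queries/keys (Q = K)

    # Angular LSH: random projection, then the bucket is the index of the largest
    # coordinate among [xR, -xR]; nearby vectors land in the same bucket with high
    # probability.
    R = rng.normal(size=(d, n_buckets // 2))
    proj = X @ R
    buckets = np.argmax(np.concatenate([proj, -proj], axis=1), axis=1)

    # Attention is then restricted to each bucket separately:
    for b in range(n_buckets):
        idx = np.where(buckets == b)[0]
        if len(idx) == 0:
            continue
        Qb = X[idx]
        scores = Qb @ Qb.T / np.sqrt(d)            # small block instead of L x L
        scores -= scores.max(axis=-1, keepdims=True)
        w = np.exp(scores)
        w /= w.sum(axis=-1, keepdims=True)
        # ... here the weights would be multiplied by the corresponding values ...
    print("bucket sizes:", np.bincount(buckets, minlength=n_buckets))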

    Other ideas and research directions

    By now, we have discussed all the primary classes of methods one can try to increase the context window size: sparse attention, low-rank decompositions, and chunking the input via architectural modifications. For the detailed discussion above, I chose methods that keep the basic idea of self-attention rather than replace it with something else entirely, but there are plenty more approaches, of course. Here are a few of the most notable that we have not had time to consider in detail.

    First, FlashAttention (Dao et al., 2022) makes self-attention IO-aware and dives into the hardware specifics, using tiling to reduce the number of read/write operations between GPU high-bandwidth memory and GPU on-chip SRAM:

    It has been a big success and has already been further developed by the author into FlashAttention 2 (Dao, 2023).

    Second, there are several block-wise approaches, that is, approaches that break down the quadratic matrix multiplication into a composition of multiplications of smaller matrices, just like MEGA did with its chunks. In particular, MEGABYTE (Yu et al., 2023) stacks two Transformers together, one for patch embeddings and another for individual tokens inside the patches, while blockwise parallel Transformers (Liu, Abbeel, 2023) put parallel blocks inside the self-attention mechanism itself. In the picture below, I combined the illustrations from these two papers:

    Third, attention-free Transformers (AFT) replace the multiplication of the query and key matrices with learnable position biases added to the key matrix. The query matrix is then just multiplied componentwise with the result, sidestepping the quadratic query-key matrix multiplication. The idea of AFT started with Apple researchers Zhai et al. (2021); here is their illustration of the new attention mechanism:
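    As a rough sketch, the “full” AFT operation can be written as follows; the (n, n, d) intermediate tensor is materialized here only for clarity, and the more efficient AFT-local and AFT-simple variants from the paper avoid exactly this kind of cost:

    import numpy as np

    def aft_full(q, k, v, w):
        # q, k, v: (n, d); w: (n, n) learned pairwise position biases.
        # The weight for pair (t, t') and channel j is exp(w[t, t'] + k[t', j]).
        weights = np.exp(w[:, :, None] + k[None, :, :])      # (n, n, d)
        num = (weights * v[None, :, :]).sum(axis=1)          # weighted sum of values
        den = weights.sum(axis=1)                            # normalization
        return 1.0 / (1.0 + np.exp(-q)) * (num / den)        # sigmoid(Q) gating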

    Finally, MEGA combined self-attention with recurrent networks, but there is also a direction of research that just tries to improve RNNs themselves. Katharopoulos et al. (2020) introduced fast autoregressive Transformers based on RNNs with linear attention. Gu et al. (2022) developed the Structured State Space (S4) sequence model based on recurrent state space models, and it was further improved in diagonal state spaces (Gupta et al., 2022; Gu et al., 2022), gated state spaces (Mehta et al., 2022), and selective state spaces (Gu, Dao, 2023).
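    To give a flavor of how these recurrent formulations replace the growing key-value cache with a fixed-size state, here is a simplified sketch of causal linear attention in the spirit of Katharopoulos et al. (2020); the feature map and normalization details are simplified relative to the paper:

    import numpy as np

    def linear_attention_rnn(q, k, v):
        # q, k, v: (n, d); written as a recurrence over time steps, so each step
        # costs O(d * d_v) regardless of how long the sequence already is.
        phi = lambda x: np.maximum(x, 0.0) + 1.0   # stand-in for the paper's elu(x) + 1
        n, d = q.shape
        s = np.zeros((d, v.shape[1]))              # running sum of phi(k_t) v_t^T
        z = np.zeros(d)                            # running normalizer, sum of phi(k_t)
        out = np.zeros((n, v.shape[1]))
        for t in range(n):
            kt = phi(k[t])
            s += np.outer(kt, v[t])
            z += kt
            qt = phi(q[t])
            out[t] = (qt @ s) / (qt @ z + 1e-6)
        return out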

    A large context size was achieved with an architecture reminiscent of the Transformer but with linear recurrent units stacked instead of self-attention layers (Orvieto et al., 2023); here is their main illustration that combines all the pieces:

    In a different line of RNN research, the receptance weighted key value (RWKV) architecture presented by a huge collaboration of researchers from nearly 30 different institutions (Peng et al., 2023) also promises to “reinvent RNNs for the Transformer era”.

    Each of these items could easily warrant a blog post of its own, and I can refer to a couple of full-scale surveys already devoted to the topic (Tay et al., 2022; Ding et al., 2023; Xu et al., 2024). But for us, it is time to move from possible answers back to the question itself: once you have an architecture that is supposedly ready to process long context windows, how do you test that?

    Testing Long Context: Needles and Haystacks

    In the last two sections, we turn to another fundamental question: suppose we have used one of these ideas to process a huge number of tokens at once. How can we tell whether an LLM actually processes its long context window meaningfully rather than skipping most of it? One standard test is to ask the model to look for a needle in a haystack: we fill the context window with meaningless or random stuff and insert a single fact that the model will later have to fish out.

    As far as I know, one of the first iterations of this idea appeared about a year ago in the Little Retrieval Test (LRT) by researchers from the University of Wisconsin-Madison and Yonsei University. It has a very simple structure, with meaningless numbers as filler information and a single line that instructs the model to go to a specific line and report its content:

    line 1: REGISTER_CONTENT is <2156>
    line 2: REGISTER_CONTENT is <9805>
    [EXECUTE THIS]: Go to line 5 and report only REGISTER_CONTENT, without any context or additional text, just the number, then EXIT
    line 3: REGISTER_CONTENT is <6668>
    line 4: REGISTER_CONTENT is <1432>
    line 5: REGISTER_CONTENT is <6727>
    line 6: REGISTER_CONTENT is <3936>
    line 7: REGISTER_CONTENT is <1805>
    line 8: REGISTER_CONTENT is <431>
    line 9: REGISTER_CONTENT is <1720>
    line 10: REGISTER_CONTENT is <6794>
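    The structure of such a prompt is easy to reproduce; here is a small sketch (my own code, not the authors’ script) that builds an LRT-style prompt together with the expected answer for automatic checking:

    import random

    def make_lrt_prompt(n_lines=10, target_line=5, seed=0):
        # Build an LRT-style prompt like the example above; returns the prompt
        # and the expected answer so the model's output can be checked.
        rng = random.Random(seed)
        contents = [rng.randint(100, 9999) for _ in range(n_lines)]
        lines = [f"line {i + 1}: REGISTER_CONTENT is <{c}>" for i, c in enumerate(contents)]
        instruction = (f"[EXECUTE THIS]: Go to line {target_line} and report only "
                       "REGISTER_CONTENT, without any context or additional text, "
                       "just the number, then EXIT")
        lines.insert(2, instruction)               # bury the instruction among the filler
        return "\n".join(lines), str(contents[target_line - 1])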

    In a harder version, the lines are shuffled randomly. LRT was designed at the time when GPT-4 and Claude appeared, boasting long context windows. The results did support Claude 1.3’s claim of processing long contexts of up to 100K tokens, which translates to about 6500 lines in the LRT:

    At the end of June 2023, the LongChat team took the LRT and ran with it, making the “fluff” information a little more meaningful for LLMs by replacing numbers with random text:

    line torpid-kid: REGISTER_CONTENT is <24169>
    line moaning-conversation: REGISTER_CONTENT is <10310>
    …
    line tacit-colonial: REGISTER_CONTENT is <14564>
    What is the <REGISTER_CONTENT> in line moaning-conversation?

    The results again showed Claude 1.3 and GPT-3.5-Turbo coming out on top:

    The next step came in Greg Kamradt’s “Needle In A Haystack” test released in November 2023. He changed the “haystack” to real meaningful text, and the overall process became as follows:

    • take Paul Graham essays as the input context; Kamradt used 218 essays with about 200K tokens, repeating the essays to make the input longer;
    • use a random but still meaningful statement as the “needle” to be found; Kamradt’s example was: “The best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day”;
    • ask the LLM a question about the “needle”, using only the context provided, and evaluate the answer with GPT-4 via LangChain evals (let’s not go into that further; here it basically means asking GPT-4 whether the “needle” was found correctly); a minimal sketch of this pipeline follows below.
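    Here is that sketch; it measures lengths in characters rather than tokens for simplicity, and all the names in it are my own illustrative choices rather than anything from the original code:

    def build_niah_prompt(essays, needle, question, depth=0.5, context_chars=100_000):
        # essays: list of background texts, needle: the statement to hide,
        # depth: relative position of the needle in the context, in [0, 1].
        haystack = ""
        while len(haystack) < context_chars:       # repeat the essays to fill the window
            haystack += "\n\n".join(essays) + "\n\n"
        haystack = haystack[:context_chars]
        pos = int(len(haystack) * depth)
        haystack = haystack[:pos] + " " + needle + " " + haystack[pos:]
        return (haystack + "\n\nAnswer the question using only the context above.\n"
                f"Question: {question}\nAnswer:")

    The answer is then scored by a judge model (GPT-4 in Kamradt’s setup), and sweeping over context lengths and needle depths produces heatmaps like the ones below.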

    The results were nontrivial: even in the claimed context windows, retrieval accuracy degraded significantly as the “needle” fact was placed deeper in the context window. Here are the results for GPT-4:

    And here is the plot for Claude 2.1 available in November 2023:

    Greg Kamradt’s test became very popular; for example, the Gemini 1.5 Pro release that I started this post with links to a NIAH evaluation and claims to find the needle 99% of the time in inputs “as long as 1 million tokens”. Here is the results plot Google shared in their Gemini 1.5 technical report:

    This is only the text part; the multimodal results are even more impressive, and I hope to get back to multimodality in subsequent posts.

    Now comes a little twist about this kind of testing. We saw that the original LRT used random numbers as facts for both the “needle” and the “haystack”. Later iterations of the test became more and more meaningful, with Greg Kamradt using coherent essays and “hiding” a subjective but plausible-sounding fact about San Francisco. But lately, researchers have gone back to meaningless information. The Gemini 1.5 evaluation uses “needle” facts of the form “The special magic {random_city} number is: {random_number}”, and the task is to retrieve this number by the city name. Google researchers refer to recent work by Arize.ai, which is basically Kamradt’s test with this change, so in a way we have come full circle back to the original LRT. I think it might be interesting to compare how (and whether) the “meaningfulness” of the hidden information influences the results.

    I will close this section with one last result that sounds very interesting to me. Claude 3 Opus, an LLM just released by Anthropic, was also tested with “needle in a haystack” type tests. In some cases, it managed not only to find the “needle” but also to notice that it was probably being evaluated! Here is a sample output reported by Anthropic researcher Alex Albert:

    Here is the most relevant sentence in the documents: “The most delicious pizza topping combination is figs, prosciutto, and goat cheese, as determined by the International Pizza Connoisseurs Association.” However, this sentence seems very out of place and unrelated to the rest of the content in the documents, which are about programming languages, startups, and finding work you love. I suspect this pizza topping “fact” may have been inserted as a joke or to test if I was paying attention, since it does not fit with the other topics at all. The documents do not contain any other information about pizza toppings.

    Is this yet another sign of AGI coming soon? I don’t know, but it sure looks like one…

    Kalamang to English: Machine Translation from One Book

    LRT-style tests are informative but relatively bland. But there is a test for (very) long context windows that totally blew my mind when I first heard about it.

    Meet Kalamang, a language spoken by fewer than 150 people on the Karas Islands off the coast of West Papua. The New Guinea region is a well-known repository of rare languages thanks to its structure of isolated islands (in Papua New Guinea alone, linguists estimate 840 languages—complete languages, not dialects!—spoken by a total population of under 12 million). The Endangered Languages Project ranks Kalamang as “Severely endangered”, remarking that “no one born after 1990 can be counted as a fluent speaker, and it is rare for them to have two fluent Kalamang-speaking parents”.

    In 2020, Eline Visser, a researcher from Lund University in Sweden, published her Ph.D. thesis in the form of a book called “A Grammar of Kalamang: The Papuan language of the Karas Islands”:

    It is a classical linguistic work that systematically describes the Kalamang language, complete with phonetics, morphology, word classes and so on, together with English-Kalamang and Kalamang-English wordlists and a small sample of parallel translations.

    Since Kalamang is an oral language spoken by so few people, we can safely assume that Eline Visser’s book is the only resource publicly available for this language. This has led to the idea implemented by a team of researchers including Eline Visser herself: Tanzer et al. (2023) turned the Kalamang language into a testbed for LLMs.

    We all know that large language models are able to translate between languages; just ask ChatGPT and see for yourself. But it is one thing to translate from French to English, where resources abound in both languages, including a wide range of language primers, parallel texts, vocabularies and the like, and quite another to translate from Kalamang, which has only one book about it, period.

    So in the original paper, Tanzer et al. found that machine translation from one book (MTOB) is indeed a very hard benchmark. Naturally, LLMs have zero prior knowledge of Kalamang and show basically random performance without context, but it turns out that it is hard for LLMs to learn a new language even when given Visser’s book as context. Here is a qualitative sample from the paper; naturally, I do not understand Kalamang, but I guess we can assume that the human output is correct:

    Note how LLMs that can write perfect English, even the almighty GPT-4, all produce very awkward outputs that they themselves would definitely edit out of existence if asked. In fact, here is what GPT-4 told me about its own translation; its rewrite probably strays further from the original, but it keeps the meaning of the translation just as I would understand it myself:

    The test set produced by the authors is large enough to allow for quantitative comparisons. And right now, it is a perfect test for long context windows, since context windows have only very recently become long enough to fit Visser’s book.

    The MTOB benchmark has been picked up by newer models: the Gemini 1.5 Pro report already includes a comparison table that boasts quantitative improvements and even shows that using the full book as context improves results compared to using half of the book, again highlighting how important it is to maximize the context window size:

    Conclusion

    Having a large context window has been one of the key obstacles to scaling up large language models to many real-life applications. Since by default Transformer-based LLMs do not have persistent memory and cannot run algorithmic loops (they have a fixed, and not very large, number of layers), an LLM is limited in its reasoning to whatever it has memorized inside its weights and whatever fits into its context window. But the default self-attention layer that Transformers are based on has quadratic complexity in the length of its input sequence, making long context windows impractical.

    In this (very long) post, we have discussed several different ways to overcome the quadratic complexity bottleneck of Transformers and thus extend the context window. We considered several main directions: sparse attention mechanisms that replace the full self-attention matrix with a sparse submatrix, low-rank decompositions that replace the same matrix with a product of smaller rectangular matrices, and different ways to break down the self-attention computation into blocks, including a mix of attention mechanisms and recurrent networks. We have also discussed how one can test whether a model actually makes use of its long context window and can pick up all of the information it contains.

    In the next post, I will go back to the other big piece of AI news from the last month: the Sora video generation model and generally how modern Transformer-based architectures construct their world models. We will discuss what “world models” are, how they have progressed over the last few years, and what are the current obstacles that we still need to overcome. Until then!

    Sergey Nikolenko
    Head of AI, Synthesis AI