Author: snikolenko

  • The Unreasonable Ineffectiveness of AI for Math


    One of the most interesting pieces of AI-related news for me recently was a paper by DeepMind researchers that presented a new mathematical result found by large language models: new constructions for the cap set problem. In this post, we take a step back and discuss the general relation between math and AI. A mathematical proof is easy to verify but may be very hard to find. But there are AI-shaped holes in looking for a proof: math involves multi-step reasoning and planning, hard theorems need to be decomposed into lemmas, there are search strategies involved… However, mathematics has turned out to be unexpectedly difficult for AI. Below, we discuss what people have been doing with AI in math and how LLMs can help mathematicians right now.

    Mathematical logic: formalizing the formal

    I am a mathematician by education, and I have been doing math in some form or other throughout my career. My M.Sc. thesis in 2005 (almost 20 years ago! how time flies…) was devoted to structure theorems for Chevalley groups, my Ph.D. in 2009 was on theoretical cryptography where the main content is theorems rather than protocols, and my Sc.D. (habilitation) thesis in 2022 was devoted to the analysis of algorithms for networking and was also full of theorems.

    The general workflow of a professional mathematician goes a little bit like this:

    This is a big simplification, of course; for example, you often begin not with a problem but with a method, and try to find an interesting new application for this method. But to a first approximation, solving a problem in mathematics usually means proving something, and when you try to prove something new, you usually need to come up with some creative ways to do it. Then your idea needs to be fully fleshed out: that’s where most ideas fail, and it’s back to the drawing board for the mathematician. Finally, once the proof is done, you need to verify it and, if everything is fine, write it up for publication.

    I worked on mathematical logic, then machine learning and artificial intelligence, and through all these years I couldn’t help but wonder: why could I still be a mathematician? How is it possible that in our time, almost 100 years after Hilbert and more than 90 years after Gödel, theorems are still proven by living, flesh-and-blood people? Let me elaborate a little so you can feel the wonder too.

    Mathematics is the most formalized field of research; it is not a “science” since it does not deal with the dirty material world. Mathematics lives in a separate magisterium of abstractions where everything is absolute: you accept a few axioms and derive everything else from these axioms via formal derivation rules. And by “formal”, I mean really formal: not general proof ideas like “reasoning by contradiction” but specific indisputable rules that convert logical formulas into other logical formulas, such as modus ponens: if we have a formula P and another formula P→Q, where → denotes implication, then we can add Q to our set of derived formulas (and use it further in the proof).
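    Modus ponens is mechanical enough to run as code. Below is a minimal sketch, assuming a toy encoding of our own (not any standard library's): atoms are strings and an implication P→Q is the tuple ("->", P, Q). Starting from a set of premises, it keeps applying modus ponens until nothing new can be derived.

```python
from typing import Union

# Toy encoding (our assumption): an atom is a string, P -> Q is ("->", P, Q).
Formula = Union[str, tuple]


def implies(p: Formula, q: Formula) -> tuple:
    """Build the formula P -> Q."""
    return ("->", p, q)


def modus_ponens_closure(premises: set) -> set:
    """Apply modus ponens repeatedly until no new formula can be derived."""
    derived = set(premises)
    changed = True
    while changed:
        changed = False
        for f in list(derived):
            # If f is P -> Q and P is already derived, then Q is derived too.
            if isinstance(f, tuple) and f[0] == "->" and f[1] in derived and f[2] not in derived:
                derived.add(f[2])
                changed = True
    return derived


premises = {"P", implies("P", "Q"), implies("Q", "R")}
print(sorted(f for f in modus_ponens_closure(premises) if isinstance(f, str)))
# -> ['P', 'Q', 'R']
```

    Real proof systems have more rules and axiom schemes, but each rule is applied in exactly this purely syntactic way.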

    For example, Alfred North Whitehead and Bertrand Russell (shown below) wrote a famous book called “Principia Mathematica”, where they attempted to build all of mathematics from the ground up, starting from the axioms. In the book, the proof of 1+1=2 appears on page 379:

    Naturally, that’s because they had to define set theory, natural numbers, and addition first, but it is still quite an amazing piece of trivia.

    Theoretically, you can write down the proof of every correct theorem as a formal chain of such derivations. It will be unbearably cumbersome, so I can’t even give you a serious example here, but in theory it can be done. More modern branches of mathematics such as group theory have this formalization built in: they start off with a fixed set of axioms (definition of a group in this case) and derive theorems formally.
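    To give the flavor of a formal derivation from group axioms, here is a small example in Lean 4, assuming the Mathlib library’s `Group` class and its `mul_one` lemma: if some element e is a left identity for every a, then e equals the identity 1, because e = e * 1 = 1. The proof checker verifies every rewriting step.

```lean
import Mathlib.Algebra.Group.Defs

-- If e acts as a left identity for every element, then e is the identity 1.
example {G : Type} [Group G] (e : G) (h : ∀ a : G, e * a = a) : e = 1 := by
  have h1 : e * 1 = 1 := h 1   -- instantiate the hypothesis at a = 1
  rwa [mul_one] at h1          -- the axiom e * 1 = e turns h1 into e = 1
```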

    Formalization is a relatively recent idea. For example, high school geometry is usually the place where you can get a feeling for mathematical proofs, but it is actually a pretty informal affair, with a lot of holes that Euclid did not fill. When mathematicians tried to patch up geometry, Euclid’s five postulates expanded into 20 axioms for Hilbert or 10 for Tarski (plus an extra axiom schema). Calculus was initially a rather intuitive affair, tightly interwoven with physics, differential equations, and practical applications: the derivative is the speed, the second derivative is the acceleration, and so on. It was only in the XIX century that all those suspicious counterexamples you vaguely recall from calculus classes (such as the Dirichlet function, which is 1 on rational numbers and 0 on irrational ones) finally caught up with mathematicians and made “intuitive math” no longer viable. Something was wrong: mathematics needed to fix its foundations.

    The answer was to introduce formal axioms and derivation rules and make math purely formal; this was the project initiated by Georg Cantor and Gottlob Frege in the second half of the XIX century and almost completed by David Hilbert in the beginning of the XX century. In this project, mathematics begins with set theory, usually introduced with the Zermelo–Fraenkel axioms, and then other branches of mathematics are defined in terms of sets. For instance, natural numbers are usually introduced as ordinals: 0 is the empty set ∅, 1 = {0} = {∅} is the set whose only element is the empty set, 2 is the set {0, 1} = {∅, {∅}}, 3 = {0, 1, 2} = {∅, {∅}, {∅, {∅}}}, and so on.
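    The ordinal construction above is simple enough to execute. A minimal sketch in Python, using immutable `frozenset`s so that sets can contain other sets (the helper names are ours): each ordinal n + 1 is built as n ∪ {n}.

```python
def von_neumann(n: int) -> frozenset:
    """Return the von Neumann ordinal for n: 0 = ∅ and n + 1 = n ∪ {n}."""
    ordinal = frozenset()  # 0 is the empty set
    for _ in range(n):
        ordinal = ordinal | frozenset([ordinal])  # successor step: n ∪ {n}
    return ordinal


def show(s: frozenset) -> str:
    """Pretty-print a set built from nested empty sets, writing ∅ for the empty set."""
    if not s:
        return "∅"
    # Sort elements by the length of their printed form so smaller ordinals come first.
    return "{" + ", ".join(sorted((show(x) for x in s), key=len)) + "}"


print(show(von_neumann(3)))  # {∅, {∅}, {∅, {∅}}}
```

    A convenient side effect of this encoding is that the ordinal n has exactly n elements, namely 0, 1, …, n − 1.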

    I say “almost completed” because Kurt Gödel’s incompleteness theorems revealed an unfixable flaw in Hilbert’s original program: for any consistent and reasonably powerful proof system (powerful enough to include arithmetic), you can find true statements that are nevertheless unprovable in it. When Hilbert heard about these results, he went through all five stages of grief, from anger to acceptance.

    But really, Gödel’s incompleteness is not a problem in practice. Sure, there exist true unprovable statements, and there is a whole cottage industry devoted to finding specific unprovable statements (wiki). But you won’t find them in everyday mathematical practice. If you want to prove a new theorem, you can be quite certain that it can be proven within “standard mathematics”. Even if a notable counterexample arises (can the Riemann hypothesis be undecidable?), it will remain just that: a very, very special case.

    So we can safely assume that formal proofs of new, unproven theorems exist somewhere in the abstract world of mathematical absolute. The next question is: how do we find them?

    Automated theorem proving and early AI

    Formal proofs are hard to find, but once a proof has been found, it is easy to check that all derivation rules have been applied correctly. You could say that finding proofs for theorems is in NP (recall the P=NP problem)… if you measured complexity as a function of the size of the proof rather than the theorem itself. The problem is that some proofs may be unbelievably large, far exceeding the number of atoms in the known universe even for a small and obviously true statement (alas, it would get too technical to give specific examples).
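    The asymmetry between finding and checking is easy to see in code. A minimal sketch, assuming a hypothetical toy Hilbert-style system where atoms are strings, ("->", P, Q) stands for P→Q, and modus ponens is the only rule: the checker makes a single pass over the proof, so its running time is polynomial in the length of the proof, while nothing here helps to find the proof in the first place.

```python
def check_proof(premises: set, proof: list) -> bool:
    """Accept a proof iff every line is a premise or follows from two
    earlier lines by modus ponens (from P and P -> Q, conclude Q)."""
    seen = []
    for formula in proof:
        by_modus_ponens = any(("->", p, formula) in seen for p in seen)
        if formula not in premises and not by_modus_ponens:
            return False  # this line is not justified by anything before it
        seen.append(formula)
    return True


premises = {"P", ("->", "P", "Q"), ("->", "Q", "R")}
proof = ["P", ("->", "P", "Q"), "Q", ("->", "Q", "R"), "R"]
print(check_proof(premises, proof))       # True
print(check_proof(premises, ["P", "R"]))  # False: R has no justification yet
```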

    Proof size depends on two things: the statement (theorem) that’s being proved and the proof system where it’s happening, i.e., which axioms and derivation rules we are allowed to use. Statements whose proofs are universe-sized in one system can often be proven very compactly, even elegantly, in more powerful proof systems; otherwise, how would we know they are true? But there is an inherent tension here between the size of the proof and the expressive power of the proof system. You can have a very simple proof system where proofs may be easy to find, but they will be very long even in reasonable cases. Or you can have a very strong proof system that allows for many short proofs, but it will be much harder to find them. Sometimes it all comes down to adding or removing a single important rule, such as the cut rule in Gentzen’s sequent calculus:
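    For reference, the cut rule in its standard textbook form reads as follows, where Γ, Δ, Σ, Π are finite lists of formulas: if A can be derived in one proof and used as a hypothesis in another, the two proofs can be combined and A is “cut out”. Gentzen’s cut-elimination theorem shows that cut can always be removed, but in the worst case the cut-free proof is non-elementarily longer, which is exactly the size-versus-power tension described above.

```latex
\[
  \frac{\Gamma \vdash \Delta, A \qquad A, \Sigma \vdash \Pi}
       {\Gamma, \Sigma \vdash \Delta, \Pi}\;(\mathrm{cut})
\]
```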

    Still, these results usually come in the form of counterexamples that one has to actively search for. There is no result that says that interesting mathematical theorems such as Fermat’s Last Theorem or the Riemann hypothesis have to have astronomically long proofs, even in a simple proof system. And surely, even if the hardest riddles of mathematics do indeed turn out to be hard for a reason, there’s still a lot we can do about simpler problems, right? Math is large, it has a long frontier, and there should be plenty of opportunities to advance it.

    As soon as computers allowed people to try, automated theorem provers did appear. In fact, in the early days of artificial intelligence, logic was thought to be one of the key components. The famous 1943 paper by Warren McCulloch and Walter Pitts, which introduced the first mathematical model for a neuron and hence a neural network, was called “A logical calculus of the ideas immanent in nervous activity”, and the main results were purely logical in nature. McCulloch and Pitts compared several possible architectures of neural networks and established logical equivalences between them: if a function can be realized by one kind of network then it can also be realized by another. Just read the abstract if you don’t believe me:

    Logic was a key idea in early AI: people thought that the difficult part of getting a machine to think would be to imbue it with logical reasoning, teaching it how to make inferences correctly. It soon became evident that reasoning in the everyday sense of the word is not a problem, and the real problem is converting murky everyday notions into statements you could reason with. Understanding notions such as “near” (as in “don’t go too near”) or “elevated” (as in “you have elevated blood sugar”) gave rise to fuzzy logic, converting a collection of pixels on a photo into a mathematical representation of a 3D object is the fundamental problem of computer vision, and so on.

    Still, what about harder reasoning like proving new theorems? One of the first successful automated theorem provers was Logic Theorist, developed in 1956 by Allen Newell, Herbert Simon, and Cliff Shaw (see also Gugerty, 2006). It pioneered many techniques that are now standard. Formulas were represented as trees, and the search for a proof itself was a tree with the initial hypothesis as the root and deductions as branches. Since the search tree (unlike the final proof, which is just one of its paths) would definitely be exponential in practice, Newell, Simon, and Shaw developed heuristics for pruning branches unlikely to lead to a solution, a technique that would become standard throughout early AI. Finally, to implement Logic Theorist the authors developed a programming language called IPL (Information Processing Language) which was the direct predecessor of John McCarthy’s Lisp!

    They tested Logic Theorist on Principia Mathematica, feeding it with 52 theorems from Chapter 2, in the same order. When Logic Theorist proved a theorem, it could add it to storage for use in later proofs. As a result, it proved 38 of the 52 theorems (73%!), and sometimes produced shorter proofs than the ones by Whitehead and Russell themselves!

    In the 1950s, it was hard to expect this automated search to actually come up with new theorems: computers were slow and their memories were small. Still, these results were extremely promising. Logic Theorist is widely regarded as the first real-life AI program, actually predating the famous Dartmouth workshop where the term was coined. In January 1956, Herbert Simon told his graduate class: “Over Christmas, Al Newell and I invented a thinking machine”, and he would later write that they “invented a computer program capable of thinking non-numerically, and thereby solved the venerable mind-body problem, explaining how a system composed of matter can have the properties of mind”.

    The 1950s were indeed a very optimistic time. But where did this line of thinking lead? How did later attempts at theorem proving go?

    Symbolic computation and formalized mathematics

    In the 1960s and 1970s, researchers’ attention turned to a large extent to symbolic math systems. One of the pioneers here was MACSYMA, which stands for “Project MAC’s SYmbolic MAnipulator” (Pavelle and Wang, 1985; Fateman, 1982). Project MAC (Machine-Aided Cognition or Multiple Access Computer) was an MIT lab that later grew into MIT CSAIL (Computer Science & Artificial Intelligence Laboratory), one of the leading academic AI labs today. MACSYMA was a software system, developed in a specially designed Lisp dialect, that could perform many symbolic mathematical operations including limits, derivatives, Taylor series, Laplace transforms, ODEs, and more. It was a direct precursor to such systems as Mathematica and Maple, but it was mostly used as a computational tool for researchers in other fields of science.

    Automated theorem proving, on the other hand, progressed much more slowly. One of the early landmarks here was the Automath formal language developed by Nicolaas de Bruijn in the late 1960s. Automath has been largely forgotten now, but it actually laid the foundations for typed lambda calculi with dependent types, introducing dependent types in the first place, and pioneered the use of the Curry–Howard correspondence, also known as the “proofs-as-programs” correspondence: a program of a certain type (in the sense of typed programming languages) can be seen as a proof of the proposition represented by this type. I won’t go into a detailed explanation here but do recommend that the interested reader work through at least the example given in Wikipedia.
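    A minimal illustration of proofs-as-programs in Lean 4 (core Lean, no extra libraries assumed): a proof of P → Q is literally a function from proofs of P to proofs of Q, so modus ponens is just function application, and a proof that conjunction commutes is a program that swaps the two components of a pair.

```lean
-- Modus ponens: apply the "function" hpq to the "argument" hp.
example (P Q : Prop) (hp : P) (hpq : P → Q) : Q :=
  hpq hp

-- P ∧ Q → Q ∧ P: take a pair of proofs apart and return it swapped.
example (P Q : Prop) : P ∧ Q → Q ∧ P :=
  fun ⟨hp, hq⟩ => ⟨hq, hp⟩
```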

    One of the first popular proof assistants was Mizar, a system that first appeared in 1973 and is still in active use today. Then came Coq, which remains one of the most popular proof assistants to this day. Another important proof assistant is HOL, which stands for “higher order logic”; indeed, HOL can handle higher-order logic proofs, and it is still a live project with new versions coming out.

    Today, there are plenty of tools that can verify formal proofs of mathematical theorems, and some of them can look for new proofs too. Naturally, there have been attempts to formalize at least the math that we already have… but without much success.

    For example, there is a valiant effort in the form of the Formalized Mathematics journal established in the 1980s. It publishes formal, mechanically verified proofs of known mathematical results; naturally, nobody prohibits authors from using a computer to come up with the proofs either. Right now, some of the latest papers in Formalized Mathematics are solutions to problems from the book “250 Problems in Elementary Number Theory” by Wacław Sierpiński, published in the late 1960s. These are not open problems; they are just somewhat advanced problems for students that you might find in a textbook (here is a paper from Dec 31, 2023).

    I’m not saying this to kick Formalized Mathematics, I’m saying this to show that doing math in a formalized and automatically verifiable way is hard indeed, much harder than an outside view on math might suggest. The “QED Manifesto”, a similar initiative put forward in 1993, also quickly dissolved. In general, formalized mathematics still lags very far behind “real” mathematics done by people.

    Automated theorem provers, i.e., programs that can try to find proofs all by themselves, do exist. For instance, there is a well-known family of first-order provers developed at the Argonne National Laboratory in Illinois, starting from Otter and continuing via EQP (Equational Prover) to Prover9. More modern examples include Lean, a general-purpose theorem prover and proof assistant.

    And they are indeed used in mathematics (see, e.g., the list of papers using Lean), but full-fledged automated proofs are very rare and always constrained to cases where human mathematicians did see the path to a proof, but the path was too cumbersome to follow by hand. One famous example here is the Robbins conjecture, proven in 1996 by the EQP prover. Again, I recommend that readers familiar with basic mathematical structures such as Boolean algebras actually read through the problem statement at the link. The Robbins conjecture is about an alternative set of axioms for Boolean algebras, and the question is as close to axioms as possible: is the alternative set actually equivalent to the definition? In 1996, William McCune proved that it is indeed the case, using the EQP theorem prover that specializes in rewriting equations. You can find the whole proof in human-readable form in this paper by Allen Mann, although “human-readable” may be a slight overstatement in this case.
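    The Robbins axioms say that + is commutative and associative and that the Robbins equation n(n(a + b) + n(a + n(b))) = a holds. The easy direction, that every Boolean algebra satisfies them, can be spot-checked by brute force. A minimal sketch, assuming the Boolean algebra of subsets of a 3-element set encoded as 3-bit masks (our choice), with + as union and n as complement. This checks one finite example only; the hard direction, that every Robbins algebra is Boolean, is what EQP had to prove.

```python
from itertools import product

N_BITS = 3
FULL = (1 << N_BITS) - 1  # the whole 3-element set as the bitmask 0b111


def join(a: int, b: int) -> int:
    """Boolean 'or' (+): union of two subsets."""
    return a | b


def neg(a: int) -> int:
    """Boolean complement (n) relative to the full set."""
    return FULL ^ a


def robbins_holds(a: int, b: int) -> bool:
    """Robbins equation: n(n(a + b) + n(a + n(b))) = a."""
    return neg(join(neg(join(a, b)), neg(join(a, neg(b))))) == a


elements = range(FULL + 1)  # all 8 subsets of a 3-element set
print(all(robbins_holds(a, b) for a, b in product(elements, repeat=2)))  # True
```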

    So this was a success for automated theorem proving. But this problem has the perfect combination of traits for the point I want to make:

    • it is very close to the axioms (in fact, it’s a question about whether one set of axioms is equivalent to another);
    • it is about a relatively simple object: there are few axioms, few connectives, and few derivation rules;
    • but at the same time, the proof is quite long and hard to break down into meaningful lemmas, so for a human it is very hard to find by hand.

    These traits are characteristic of most mathematical results where computers have been able to meaningfully assist humans. One of the first famous examples is the four color theorem, a statement from graph theory that you can paint the regions of any map in four colors so that no two regions sharing a border of nonzero length get the same color (arbitrarily many regions can meet at a single point, of course, but that doesn’t count). As you can see, this is also a short and discrete kind of statement, but the proof (announced by Appel and Haken in 1976) was still done almost entirely by hand. The crucial part, however, required enumerating and analyzing about 1500 different cases, so this part was programmed and done on a computer. It’s not quite the automated proof search you would expect (although in 2005, the proof was actually formally verified in Coq, so the four color theorem is now part of formalized mathematics).

    Other examples (see, e.g., this list on Wikipedia) are usually of the same nature: computers help with long case-by-case analysis or with mechanical rewriting of complicated equations, but the ideas remain human. In fact, many mathematicians are still wary of computer-assisted proofs because they are unverifiable by humans and therefore don’t fulfill the main function of a proof: they don’t convince people. In a paper on the Robbins problem, Louis Kauffman sums this conundrum up as follows: “Can a computer discover the proof of a theorem in mathematics?.. I say that a proof is not a proof until a person is convinced by it. In fact a mathematical proof is exactly an argument that is completely convincing to a mathematician! In this sense, a computer does not, can not produce a proof… It does not know the proof. It only finds the steps. It is a human judgement that propels the result of the computer’s search into a statement that the computer has “found a proof”… If we judge that to be a proof, then it is a proof (for us)”.

    But all that was in the 1980s and 1990s. Now we have powerful GPUs, deep neural networks that do wonders, exceeding the human level in many tasks that we considered purely human before. So they can help us with math as well… right?

    Deep learning and mathematics

    As we have already discussed, there are two main directions in how AI can help mathematics: either directly by finding proofs or indirectly (but usually more efficiently) by doing straightforward but cumbersome stuff like rewriting equations or doing case-by-case analysis.

    Several breakthroughs have been made by following the latter strategy. Modern deep learning adds another important twist on this idea: instead of mathematicians writing code that enumerates possibilities, an AI model can try to write the best code for the problem as well. This is very similar to neural architecture search (NAS) that yielded some of the best neural architectures in computer vision, new activation functions, and more. Similar to how you can search for architectures, you can also search for programs, usually with some kind of genetic programming approach since computer programs are naturally represented by trees.

    So you can take it one step further and tackle problems where the answer is an algorithm. In 2022, DeepMind’s AlphaTensor made the news doing exactly that: it discovered improvements in matrix multiplication algorithms, improving over Strassen’s algorithm for the first time in 50 years. In AlphaTensor, the tensor specifies which entries to read from the input matrices, and where to store the result; for example, in the three-dimensional tensor below, the (a₁, b₁, c₁) and (a₂, b₃, c₁) entries are set to 1, and this means that c₁ = a₁b₁ + a₂b₃:

    AlphaTensor optimized over such tensors with an MCTS-based algorithm very similar to AlphaZero, which plays chess and Go, but with some new advances related to the extra large width of the search tree in this case. As a result, it improved over the best known matrix multiplication algorithms for a number of different matrix sizes, starting from as low as 4×4 matrices (47 multiplications instead of 49, in arithmetic modulo 2); this is more than just a constant improvement since these algorithms can be applied recursively to handle block matrices of arbitrary size. This was a very important result, and it was obtained virtually independently of humans; but again, this falls more into the “searching through cumbersome cases” category: the AlphaZero-based search algorithm just helps scale it up to a previously unheard-of number of cases.
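    The tensor picture has a concrete algorithmic meaning: a decomposition of the matrix multiplication tensor into R rank-one terms yields an algorithm that uses R multiplications. Here is a minimal sketch of the classic baseline that AlphaTensor improved on, Strassen’s 7-multiplication scheme for 2×2 blocks, applied recursively to block matrices. It assumes square matrices whose side is a power of two, stored as Python lists of lists; the helper names are ours. This is not AlphaTensor’s discovered algorithm.

```python
import random

Matrix = list  # a square matrix as a list of rows; side length is a power of two


def add(X: Matrix, Y: Matrix) -> Matrix:
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def sub(X: Matrix, Y: Matrix) -> Matrix:
    return [[x - y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def split(M: Matrix):
    """Split a 2k x 2k matrix into four k x k blocks."""
    h = len(M) // 2
    return ([r[:h] for r in M[:h]], [r[h:] for r in M[:h]],
            [r[:h] for r in M[h:]], [r[h:] for r in M[h:]])


def join(C11, C12, C21, C22) -> Matrix:
    """Reassemble four k x k blocks into one 2k x 2k matrix."""
    return [a + b for a, b in zip(C11, C12)] + [a + b for a, b in zip(C21, C22)]


def strassen(A: Matrix, B: Matrix) -> Matrix:
    """Multiply n x n matrices with 7 recursive block products instead of 8."""
    if len(A) == 1:
        return [[A[0][0] * B[0][0]]]
    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)
    M1 = strassen(add(A11, A22), add(B11, B22))
    M2 = strassen(add(A21, A22), B11)
    M3 = strassen(A11, sub(B12, B22))
    M4 = strassen(A22, sub(B21, B11))
    M5 = strassen(add(A11, A12), B22)
    M6 = strassen(sub(A21, A11), add(B11, B12))
    M7 = strassen(sub(A12, A22), add(B21, B22))
    C11 = add(sub(add(M1, M4), M5), M7)
    C12 = add(M3, M5)
    C21 = add(M2, M4)
    C22 = add(sub(M1, M2), add(M3, M6))
    return join(C11, C12, C21, C22)


def naive(A: Matrix, B: Matrix) -> Matrix:
    """Schoolbook multiplication, used as a reference."""
    n = len(A)
    return [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


random.seed(0)
A = [[random.randint(-9, 9) for _ in range(8)] for _ in range(8)]
B = [[random.randint(-9, 9) for _ in range(8)] for _ in range(8)]
print(strassen(A, B) == naive(A, B))  # True
```

    Because the recursion makes 7 block products per level, multiplying 2ᵏ×2ᵏ matrices takes 7ᵏ scalar multiplications, i.e., O(n^log₂7) ≈ O(n^2.81), which is why saving even one multiplication in a small base case matters.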

    Another important example in the same direction that made the news last year was AlphaDev, another DeepMind work in a similar vein that managed to improve sorting algorithms, the cornerstone of almost every data manipulation program in the world! In a Nature paper by Mankowitz et al. (2023), the researchers presented another AlphaZero-based modification of MCTS designed to invent new sorting algorithms. The resulting algorithms have already been integrated into the sorting routines behind std::sort in LLVM’s libc++ standard C++ library, which means that they are already making millions of computer programs run faster.
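    AlphaDev worked at the level of assembly instructions for short fixed-length sorting routines (for three to five elements). Such routines are essentially sorting networks: a fixed sequence of compare-exchange steps with no data-dependent loops. The Python sketch below shows a textbook three-element network for illustration only; it is not AlphaDev’s discovered instruction sequence, and the function names are ours.

```python
from itertools import product


def compare_exchange(v: list, i: int, j: int) -> None:
    """Put the smaller of v[i], v[j] at index i and the larger at index j."""
    lo, hi = min(v[i], v[j]), max(v[i], v[j])
    v[i], v[j] = lo, hi


def sort3(values) -> list:
    """Sort exactly three elements with a fixed network of three compare-exchanges."""
    v = list(values)
    compare_exchange(v, 1, 2)  # now v[1] <= v[2]
    compare_exchange(v, 0, 2)  # now v[2] holds the maximum of all three
    compare_exchange(v, 0, 1)  # order the two remaining elements
    return v


# Exhaustively check every input over {0, 1, 2}, duplicates included.
print(all(sort3(t) == sorted(t) for t in product(range(3), repeat=3)))  # True
```

    Since the sequence of comparisons never changes, such a routine can be compiled into branch-free min/max instructions, which is where instruction-level search can shave off cycles.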

    As large language models became available, another direction appeared: you could ask LLMs to prove theorems directly! Naturally, it did not work all that well at first, and even today, if you just ask an LLM to prove a theorem, this strategy won’t get you published in Annals of Mathematics.

    One way to improve here is to fine-tune LLMs on mathematical content. For example, Minerva (Lewkowycz et al., 2022) did just that, fine-tuning general purpose LLMs from the PaLM family on technical content. As a result, Minerva could successfully reason through some high school level mathematics, although it was still a far cry from proving new results. Here is a sample of what Minerva was capable of:

    Another approach would be to use the already excellent coding capabilities of LLMs. As you know, modern LLMs can produce correct code snippets, so if your problem can be solved by some kind of enumeration you can ask the LLM to write this code for you. ToRA (tool-integrated reasoning agent) by Gou et al. (2023) closed the loop: it uses an LLM to write code, calls an external tool to run it, and then uses the LLM again to fix the code and interpret the results. In the illustration below, the authors contrast ToRA with pure language-based and pure code-based approaches:

    Finally, I want to highlight another work by DeepMind (looks like they are the main players here): “Advancing mathematics by guiding human intuition with AI” by Davies et al. This work pursues a very different approach to helping mathematicians: instead of trying to formally prove something, here the authors use machine learning to discover new possible relations between mathematical objects. Here is the general framework of how it works; note that there are both “computational steps” done by AI models and “mathematician steps” done by real people:

    For example, the authors could discover and then actually prove a relationship between algebraic and geometric invariants in knot theory. The margins of this post are too narrow to explain what exactly this means, but in essence, a machine learning model detected that there might be a relationship between one way to describe knots in topology and another. This connection turned out to be real, and mathematicians were able to prove its existence and introduce new important objects that describe it. Naturally, they did it by hand, but their intuition in formulating this result was guided by ML-produced discoveries.

    And with that, we have reached the latest news: FunSearch. It is yet another Nature paper by the DeepMind team, in this case adding some large language models into the mix. Let’s see how it works!

    FunSearch: As fun as it sounds?

    We now come to the actual result that motivated me to write this post. In December 2023, DeepMind researchers Romera-Paredes et al. published a paper called “Mathematical discoveries from program search with large language models”. They proposed a relatively simple way to use large language models to guide the search for new mathematical results, not in the form of actual results like most researchers have done before but in the form of programs that could generate these results. It goes like this: given a problem specification,

    • first ask a pretrained LLM to generate some candidate programs that might solve the problem;
    • add the resulting programs to the database of programs created, run them and score their results according to the desired objective function;
    • then form a prompt that combines a few of the top-scoring programs and asks the LLM to improve on them;
    • and then feed this prompt to the LLM again, thus closing the loop.

    Here is an illustration from the paper itself:

    The specification includes an evaluation function that scores the solutions and a “solve” function that provides the bare-bones algorithm (say, a greedy search) and lets the LLM concentrate on the creative part (for greedy search, it is the priority function that compares elements). It sounds pretty simple, and it is: this is more of a prompt engineering result than a new machine learning approach.

    So what could FunSearch do? One of the main results in this paper is a set of new bounds for the cap set problem. Fields medalist Terence Tao, by many accounts the best mathematician alive, once called it “perhaps my favourite open question”, so let’s dive into the problem a little bit.

    A cap set is a set of points in F₃ⁿ, the n-dimensional vector space over a finite field, that does not contain nontrivial arithmetic progressions, i.e., where no three points form a line in the finite geometry of F₃ⁿ, where F₃ is the field of three elements… I started out on the wrong foot, didn’t I?

    There is a much more accessible description of what’s going on in the cap set problem. You’ve probably heard of the card game “Set” where players need to shout out “Set!” when they see three cards such that for each of the four attributes—number, color, shape, and shading—the three cards are either all the same or all different. In the example below (taken from here, as well as the general idea of this connection), on the left you see two examples of sets, one where no attribute matches and another where almost all of them do, and on the right you see a sample Set board layout with one set highlighted (are there more? see for yourself):

    In these terms, a cap set is a collection of cards that contain no sets, and the main question is this: how many cards can you lay down so that they contain no sets? For the original game of Set, the answer is known: back in 1971, Giuseppe Pellegrino proved that there exist collections of 20 cards without a set, but 21 cards always contain one (note that this result predates the invention of the game in 1974, so if there is any connection, the causality here goes in the opposite direction). But in mathematics, you always ask the more general question. Here, we generalize the number of attributes: how many cards with n different attributes (instead of 4 in Set) can you lay down without a set of three cards?

    It is obvious that you can have 2ⁿ cards without a set: just pick two values for every attribute and use only cards with these values. On the other side, it was proven in the early 1980s that the maximum is asymptotically smaller than 3ⁿ, and Meshulam (1995) showed that it is at most O(3ⁿ/n). For over 20 years, the gap between these two results remained a glaring hole in combinatorics; in fact, closing this gap was what Terence Tao called his “favourite open question” back in 2007.
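    The Set rule translates directly into code. A minimal sketch, encoding each card as a tuple of n attribute values from {0, 1, 2} (our encoding choice): it checks the “all same or all different” rule, confirms the 2ⁿ construction for n = 4, and brute-forces the largest cap set for n = 2, which is only feasible for tiny n. Note that “all same or all different” is equivalent to every attribute summing to 0 modulo 3, which is exactly why Set cards are points of F₃ⁿ and sets are lines.

```python
from itertools import combinations, product


def is_set(a: tuple, b: tuple, c: tuple) -> bool:
    """Three cards form a Set iff every attribute is all-same or all-different."""
    return all(len({x, y, z}) in (1, 3) for x, y, z in zip(a, b, c))


def is_cap_set(cards) -> bool:
    """A cap set is a collection of cards with no Set among any three of them."""
    return not any(is_set(a, b, c) for a, b, c in combinations(cards, 3))


def max_cap_set_size(n: int) -> int:
    """Brute-force the largest cap set among all 3**n cards (tiny n only)."""
    deck = list(product(range(3), repeat=n))
    for size in range(len(deck), 0, -1):
        if any(is_cap_set(s) for s in combinations(deck, size)):
            return size
    return 0


# The easy lower bound: all cards whose attributes only take the values 0 or 1.
n = 4
two_values = list(product(range(2), repeat=n))
print(len(two_values), is_cap_set(two_values))  # 16 True
print(max_cap_set_size(2))  # 4
```

    Already for n = 4 this brute force is hopeless (it would have to scan subsets of 81 cards), which is why Pellegrino’s answer of 20 for the actual game required a real proof.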

    Important progress was made in 2016 when Ellenberg and Gijswijt used the method developed by Croot, Lev, and Pach to reduce the upper bound to 2.756ⁿ; it is telling that both papers were published in Annals of Mathematics, the most prestigious venue for publication in pure math. Since then, there has been no improvement in the exponent for either lower or upper bound.

    So what did DeepMind do with the problem? FunSearch found a new lower bound for the cap set problem in dimension n=8, i.e., it produced a larger collection of Set cards with 8 different attributes and no sets on board than ever before: 512 cards, compared to the previous record of 496.

    Here is a general illustration where in the top middle we have an illustration of the cap set problem in terms of finite geometry (colors denote lines in F₃³), on the bottom we have the FunSearch results with a new record for dimension 8, and on the right you can see the program that generates this solution:

    The program is virtually unreadable, and we will not analyze it here, of course, but it is still important that it’s a program, not just an answer in the form of a cap set. By analyzing this program, mathematicians can gain some insight into how this construction is structured; Romera-Paredes et al. did just that and could indeed understand the result better and relate it to other known examples in combinatorics.
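    To make the “specification” idea concrete, here is a minimal sketch of what such a skeleton could look like for cap sets, assuming the structure described earlier: `evaluate` scores a result, `solve` is a fixed greedy loop, and `priority` is the part FunSearch would ask the LLM to evolve. All names and greedy details are our own illustration, not the code from Romera-Paredes et al.; the hand-written `priority` (prefer vectors with few coordinates equal to 2) only recovers easy constructions that contain the 2ⁿ example, nowhere near record-setting cap sets.

```python
from itertools import combinations, product


def priority(v: tuple) -> float:
    """Hypothetical stand-in for the evolved function: prefer vectors with few 2s."""
    return -sum(1 for x in v if x == 2)


def solve(n: int) -> list:
    """Greedy skeleton: scan F_3^n by decreasing priority and keep each vector
    unless it completes a line with two vectors chosen earlier."""
    chosen = []
    blocked = set()  # points that would complete a line with two chosen points
    for v in sorted(product(range(3), repeat=n), key=priority, reverse=True):
        if v in blocked:
            continue
        for u in chosen:
            # Distinct u, v, w in F_3^n are collinear iff u + v + w = 0 (mod 3).
            blocked.add(tuple((-a - b) % 3 for a, b in zip(u, v)))
        chosen.append(v)
    return chosen


def is_cap_set(points) -> bool:
    """No three distinct points sum to zero modulo 3 in every coordinate."""
    return not any(all((a + b + c) % 3 == 0 for a, b, c in zip(x, y, z))
                   for x, y, z in combinations(points, 3))


def evaluate(n: int) -> int:
    """Score of a candidate priority function: size of the cap set it yields."""
    capset = solve(n)
    assert is_cap_set(capset), "solve() must return a valid cap set"
    return len(capset)


print(evaluate(4))
```

    In FunSearch’s outer loop, the LLM would repeatedly propose new versions of `priority`, each would be scored with `evaluate`, and the best-scoring programs would go back into the prompt.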

    Still, all this sounds a bit underwhelming: looks like FunSearch is still just searching for counterexamples, like countless helper programs for mathematicians before. It is still unclear when we are going to have a program that actually does new math in the form of proving theorems.

    Conclusion

    Today, we have discussed the main avenues for how AI can help mathematicians:

    • direct automated theorem proving via first- and higher-order logic;
    • helping with the cumbersome side of mathematics: enumerating special cases, doing automated case-by-case analysis, rewriting equations and so on;
    • applying large language models to try and generate proofs and/or write code that will do the cumbersome parts instead of writing this code by hand;
    • discovering new patterns in data, including data in the form of mathematical objects, that may inform the intuition of mathematicians and lead to new discoveries;
    • using some combination of the above: for example, FunSearch uses an LLM to write key portions of programs that are then tested against the problem.

    But if we put all this together, we basically get the full picture of making mathematics. Let us go back to the picture I started with, the general workflow of a professional mathematician, and put some of the papers and tools we have discussed in their proper places:

    As you can see, AI is already helping mathematicians every step of the way, so maybe the “unreasonable ineffectiveness” I started with is not so ineffective after all. Still, it looks like doing math formally is hard, and so far the latest AI research can help somewhat, but only so much; there is no silver bullet that would just short-circuit the whole workflow and replace human mathematicians entirely. But we have also seen that doing formalized mathematics is hard for people too, even with the help of computers, so maybe there are deeper reasons here too.

    On the other hand, AI progress is very fast right now. Two weeks after FunSearch, another DeepMind paper appeared in Nature: Trinh et al.’s “Solving olympiad geometry without human demonstrations”. They present a system able to successfully solve geometry problems from the International Mathematical Olympiad at nearly a gold medalist level; geometry problems virtually always require formal proofs, and IMO problems usually require quite ingenious ones. 

    Note also that even Nature, with its relatively fast review process, takes months to publish a paper: the IMO geometry paper was submitted in April 2023, and the FunSearch paper was submitted in August 2023. This means that more than half a year of progress has already been made since these results, and in 2023, half a year counted for a lot. So, just like in most other fields, we probably won’t see a truly working theorem prover coming until it is already here.

    And finally, I would like to take this opportunity to dedicate this post to my first research supervisor, Professor Nikolai Vavilov (not to be confused with his ancestor, the famous geneticist Nikolai Vavilov), who was a key figure in modern algebra, founded a thriving research school, wrote several very interesting textbooks, and lit up every conversation with his wit and erudition. I owe Nikolai Alexandrovich a lot in my mathematical upbringing. Tragically, Prof. Vavilov passed away last September.

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • Generative AI, Part 0: Background on Transformers

    Generative AI, Part 0: Background on Transformers

    Here at Synthesis AI, we have decided to release the “Generative AI” series in an e-book form; expect a full-fledged pdf with all the images soon. But when I started collecting the posts into a single coherent whole, I couldn’t help but feel the huge, glaring omission of the most important topic in modern AI, the secret sauce that drives the entire field of ML nowadays: self-attention layers introduced in the original Transformer architecture. I had not planned to cover them before since there are plenty of other excellent sources, but in a larger format Transformers have become an inevitability. So today, I post the chapter on Transformers, which seems to be by far the longest post ever on this blog. We will discuss how the Transformer works, introduce the two main families of models based on self-attention, BERT and GPT, and discuss how Transformers can handle images as well.

    The Transformer Architecture and Self-Attention

    The Transformer was introduced in 2017, in a paper by Google Brain researchers Vaswani et al. with the catchy title “Attention is All You Need”. By now, it is one of the most important papers in the history of not only machine learning but all of science, amassing nearly 100000 citations (by Google Scholar’s count) over a mere five years that have passed since its publication.

    An aside: for some unknown reason, it is quite hard to Google the most cited papers of all time, and there is no obvious way to find them on Google Scholar. I have found an authoritative review of the top papers of all time in Nature, and it cites only three papers with over 100K citations in the entire history of science, but those are “proper” citations counted by the Web of Science database. Lowly arXiv preprints do not register there at all, so its counts are always far lower than those of Google Scholar, which counts everything. In any case, the Transformer paper is truly exceptional.

    There have been dozens of surveys already, so I will cite a few but it is far from an exhaustive list: (Zhou et al., 2023; Wolf et al., 2020; Lin et al., 2022; Tay et al., 2022; Xu et al., 2023; Wen et al., 2022; Selva et al., 2023). The Transformer was indeed a very special case, an architecture that, on one hand, uses ideas already well known in the machine learning community for many years, but on the other hand, combines them in a whole that has proven to be much, much larger than its parts. So what is the basic idea of a Transformer?

    First, the original Transformer was an encoder-decoder architecture intended for sequence-to-sequence problems, specifically for machine translation. In essence, the original Transformer was designed to:

    • first encode the input sequence (say, a sentence in French) into a latent representation, i.e., a dense vector of features in some highly semantic latent space;
    • then decode the latent code into the new output sequence (say, a sentence in English).

    This means that before starting to encode text, the Transformer needs to convert it into a sequence of tokens; we will talk about it more in the next section, and for the images here let us just assume that tokens are individual words. After that, the original Transformer had the following structure:

    • an embedding layer that encodes input tokens into dense vectors;
    • six encoder layers that produce semantically rich latent representations of input tokens;
    • six decoder layers that produce the output in an autoregressive way, using the input tokens’ latent representations as conditions.

    Here is this structure:

    Each layer has a very simple internal structure:

    • given some input vectors \mathbf{x}_1,\ldots,\mathbf{x}_L, an encoder layer first puts them through a self-attention layer followed by layer normalization, modifies the result with a feedforward layer, and outputs the resulting vectors \mathbf{x}'_1,\ldots,\mathbf{x}'_L to the next encoder layer;
    • after all encoder layers are done, we have the results in the form of all vectors output by the last encoder layer;
    • then each decoder layer puts its inputs \mathbf{x}_1,\ldots,\mathbf{x}_L through a masked self-attention layer, then an encoder-decoder attention layer that actually looks at encoder outputs, and then a feedforward layer again;
    • finally, when all decoder layers are done, the resulting vectors are fed through a very simple classification head—just one linear layer followed by a softmax—to produce the next token; it is then embedded into the next vector \mathbf{x}_{L+1}, and the decoding process can begin again, autoregressively.
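    The classification head from the last step is small enough to sketch directly. This is a minimal NumPy illustration under stated assumptions, not the original implementation: the weights `W_out` and `b_out` are hypothetical random placeholders, the sizes are toy values, and the next token is picked greedily (argmax), which is only one of several possible decoding strategies.

```python
import numpy as np

def softmax(z):
    # Subtracting the max does not change the result but avoids overflow.
    e = np.exp(z - np.max(z))
    return e / e.sum()

def next_token(h_last, W_out, b_out):
    """Classification head: one linear layer followed by a softmax.

    h_last: decoder output for the last position, shape (d,)
    W_out:  hypothetical output weights, shape (V, d)
    b_out:  hypothetical output bias, shape (V,)
    Returns the greedy (most probable) token id and the full distribution.
    """
    probs = softmax(W_out @ h_last + b_out)
    return int(np.argmax(probs)), probs

rng = np.random.default_rng(0)
d, V = 8, 5                      # toy sizes, far smaller than in practice
W_out, b_out = rng.normal(size=(V, d)), np.zeros(V)
token_id, probs = next_token(rng.normal(size=d), W_out, b_out)
print(token_id, probs.round(3))
```

    In a full decoder, the chosen token would then be embedded and appended to the output, and the whole decoder stack would run again for the next position.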

    Here is what the Transformer looks like with a single layer expanded in both encoder and decoder:

    Layer normalization (Ba et al., 2016) is just a standard technique to stabilize training in deep neural networks; in the Transformer, it is also combined with a residual connection, so it is actually \mathrm{LayerNorm}(\mathbf{X} + \mathbf{Z}), where \mathbf{X} is the matrix of original input vectors \mathbf{x}_1,\ldots,\mathbf{x}_L and \mathbf{Z} is the matrix of the self-attention results \mathbf{z}_1,\ldots,\mathbf{z}_L. A feedforward layer is a small network applied separately to each of the resulting vectors \mathbf{z}'_1,\ldots,\mathbf{z}'_L; in the original Transformer, it consists of two linear layers with a ReLU activation in between.
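    Here is a minimal NumPy sketch of the residual connection, LayerNorm, and feedforward steps of an encoder layer. It makes several simplifying assumptions: the attention output `Z` is just a random stand-in, LayerNorm omits its learnable gain and bias, the feedforward block uses the original Transformer's two linear maps with a ReLU, and token vectors are stored as rows (the transpose of \mathbf{X} in the text). All weights are illustrative placeholders.

```python
import numpy as np

def layer_norm(X, eps=1e-5):
    # Normalize each token vector (row) to zero mean and unit variance;
    # the learnable gain and bias are omitted for brevity.
    mu = X.mean(axis=-1, keepdims=True)
    var = X.var(axis=-1, keepdims=True)
    return (X - mu) / np.sqrt(var + eps)

def feed_forward(X, W1, b1, W2, b2):
    # Position-wise: the same two linear maps (with a ReLU) applied to every row.
    return np.maximum(0.0, X @ W1 + b1) @ W2 + b2

rng = np.random.default_rng(0)
L, d, d_ff = 4, 8, 32              # toy sizes; the original used d=512, d_ff=2048
X = rng.normal(size=(L, d))        # layer input, one row per token
Z = rng.normal(size=(L, d))        # stand-in for the self-attention output
W1, b1 = rng.normal(size=(d, d_ff)) * 0.1, np.zeros(d_ff)
W2, b2 = rng.normal(size=(d_ff, d)) * 0.1, np.zeros(d)

H = layer_norm(X + Z)                                  # LayerNorm(X + Z)
out = layer_norm(H + feed_forward(H, W1, b1, W2, b2))  # second residual + LayerNorm
print(out.shape)  # (4, 8)
```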

    The real magic happens in the self-attention layers, both in regular self-attention and encoder-decoder attention layers featured in the decoder. Let us look at them in more detail.

    The intuition for self-attention layers comes from information retrieval, a field that we have already considered in detail in Part IV of this series. For the Transformer, we only need the very basic intuition of searching in the latent space, as illustrated below:

    In this simple form, information retrieval works as follows:

    • both queries and documents share the same latent space, although the ways of encoding them into this latent space may be different (after all, even if both queries and documents are texts they have very different properties);
    • a search query (text queries shown on the top of the figure above) is put through a query encoder to get to the latent space;
    • documents (represented by images in the figure) are also represented in the same latent space via a different encoder;
    • to find the most relevant documents, we simply find the nearest neighbors for the query in the latent space among the documents; one often assumes that the latent space is linear and the distance metric there is just the scalar product of vectors (strictly speaking, a similarity score: the larger it is, the closer the document),

          \[\mathrm{dist}(\mathbf{q},\mathbf{d})=\mathrm{Enc}_q(\mathbf{q})^\top\mathrm{Enc}_d(\mathbf{d}).\]
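    To make this concrete, here is a toy sketch of retrieval by scalar product. The two “encoders” are hypothetical random linear maps standing in for learned ones, the data is random, and “nearest” means the largest scalar product, so we sort the scores in decreasing order.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_latent, n_docs = 16, 4, 10

# Hypothetical encoders: fixed random linear maps into a shared latent space.
Enc_q = rng.normal(size=(d_latent, d_in))
Enc_d = rng.normal(size=(d_latent, d_in))

query = rng.normal(size=d_in)
docs = rng.normal(size=(n_docs, d_in))   # one raw document per row

q_lat = Enc_q @ query                    # query in latent space, shape (d_latent,)
d_lat = docs @ Enc_d.T                   # documents in latent space, (n_docs, d_latent)
scores = d_lat @ q_lat                   # Enc_q(q)^T Enc_d(d) for every document

top3 = np.argsort(-scores)[:3]           # highest score = most relevant
print(top3, scores[top3].round(2))
```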

    In the self-attention layer, this intuition comes alive in a very abstract fashion. Let us follow through this process as it is illustrated below:

    The self-attention layer receives as input a sequence of vectors \mathbf{x}_1,\ldots,\mathbf{x}_L, which we can think of as a matrix \mathbf{X}\in\mathbb{R}^{d\times L}.

    First, what are the queries, keys, and documents? All three of them come from the vectors \mathbf{x}_i themselves! The figure above shows this idea with the example of what happens with \mathbf{x}_1:

    • multiplying \mathbf{x}_1 by a weight matrix \mathbf{W}^Q, we get the query vector \mathbf{q}_1=\mathbf{W}^Q\mathbf{x}_1; note that the matrix \mathbf{W}^Q\in\mathbb{R}^{q\times d} does not have to be square, and the dimension q of the query vectors, \mathbf{q}_i\in\mathbb{R}^q, may be different from (usually lower than) the input dimension d, \mathbf{x}_i\in\mathbb{R}^d;
    • multiplying every \mathbf{x}_i by a second weight matrix \mathbf{W}^K, we get the key vectors \mathbf{k}_i=\mathbf{W}^K\mathbf{x}_i for i=1,\ldots,L; since we want queries and keys to inhabit the same latent space, we have the keys with the same dimension as queries, \mathbf{k}_i\in\mathbb{R}^q, so \mathbf{W}^K\in\mathbb{R}^{q\times d};
    • finally, the third weight matrix \mathbf{W}^V gets us the value vectors \mathbf{v}_i=\mathbf{W}^V\mathbf{x}_i for i=1,\ldots,L; these are the documents that we will “retrieve” by their keys \mathbf{k}_i; in this case we might have a different space for the documents, so formally we have a different dimension v for the values, \mathbf{v}_i\in\mathbb{R}^v and \mathbf{W}^V\in\mathbb{R}^{v\times d}; in practice, however, one usually takes v=q.

    The matrices \mathbf{W}^Q, \mathbf{W}^K, and \mathbf{W}^V comprise the bulk of trainable weights in the self-attention mechanism. After applying them as above, we obtain three vectors \{\mathbf{q}_i, \mathbf{k}_i, \mathbf{v}_i\} from each input vector \mathbf{x}_i. Then we do the retrieval part, computing attention scores as scalar products between queries and keys. The figure above shows this process schematically with the example of \mathbf{q}_1 transforming into \mathbf{q}_1^\top\mathbf{k}_i for all different i. Then we rescale each score \mathbf{q}_1^\top\mathbf{k}_i, dividing it by the square root of q, and pass the scores through a softmax to turn them into probabilities. The self-attention weights are thus

        \[\alpha_{ij} = \mathrm{softmax}\left(\frac{1}{\sqrt{q}}\mathbf{q}_i^\top\mathbf{K}\right)_j,\]

    where \mathbf{K}\in\mathbb{R}^{q\times L} are all the keys combined in a matrix, \mathbf{K}=\mathbf{W}^K\mathbf{X}, so the j-th component of \mathbf{q}_i^\top\mathbf{K} is exactly the score \mathbf{q}_i^\top\mathbf{k}_j.

    Then we use the result as coefficients for a convex combination of values \mathbf{v}_j. Thus, overall we have the following formula for what \mathbf{x}_i turns into:

        \[\mathbf{z}_{i} = \sum_{j=1}^L\alpha_{ij}\mathbf{v}_j = \mathbf{V}\,\mathrm{softmax}\left(\frac{1}{\sqrt{q}}\mathbf{q}_i^\top\mathbf{K}\right)^\top,\]

    where \mathbf{V}\in\mathbb{R}^{v\times L} are all the values combined in a matrix, \mathbf{V}=\mathbf{W}^V\mathbf{X}.

    The normalizing factor comes from the fact that if you add q independent random numbers distributed around zero with variance 1, the result will have variance q, and its standard deviation will be the square root of q. A score \mathbf{q}_i^\top\mathbf{k}_j is exactly such a sum of q terms, so if you add 64 signed numbers that are around 1 in absolute value, the result will typically be around 8 in absolute value. Scores of that magnitude would easily saturate the softmax, pushing almost all of the weight onto a single token, so to get the numbers back to a reasonable range we divide them by this standard deviation.
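    A quick numerical check of this argument, assuming that query and key components are independent with zero mean and unit variance (a simplifying assumption that holds roughly at initialization, not necessarily after training):

```python
import random
import statistics

def dot_product_std(q_dim, trials=4000, seed=0):
    """Empirical std of q.k for random q, k with i.i.d. N(0, 1) components."""
    rng = random.Random(seed)
    dots = []
    for _ in range(trials):
        q = [rng.gauss(0, 1) for _ in range(q_dim)]
        k = [rng.gauss(0, 1) for _ in range(q_dim)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    return statistics.pstdev(dots)

for q_dim in (16, 64, 256):
    s = dot_product_std(q_dim)
    # The raw std grows like sqrt(q_dim); dividing by sqrt(q_dim) brings it back to about 1.
    print(q_dim, round(s, 2), round(s / q_dim ** 0.5, 2))
```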

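    We can combine the computation of each \mathbf{z}_i shown above into a single formula in matrix form. With \mathbf{Q}=\mathbf{W}^Q\mathbf{X}\in\mathbb{R}^{q\times L} and the softmax applied to each row separately, the outputs \mathbf{Z}\in\mathbb{R}^{v\times L} are

        \[\mathbf{Z}^\top = \mathrm{softmax}\left(\frac{1}{\sqrt{q}}\mathbf{Q}^\top\mathbf{K}\right)\mathbf{V}^\top.\]

    Most papers stack queries, keys, and values as rows rather than columns, and then the same formula takes the familiar form \mathrm{softmax}\left(\frac{1}{\sqrt{q}}\mathbf{Q}\mathbf{K}^\top\right)\mathbf{V}, which is how self-attention is usually defined.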
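    Here is a minimal NumPy sketch of one self-attention head. For readability it follows the usual code convention of storing tokens as rows, so `X` below is the transpose of \mathbf{X} in the text and the weight matrices act on the right; the weights are random placeholders for trained parameters.

```python
import numpy as np

def softmax(S, axis=-1):
    # Row-wise softmax, stabilized by subtracting each row's maximum.
    E = np.exp(S - S.max(axis=axis, keepdims=True))
    return E / E.sum(axis=axis, keepdims=True)

def attention_head(X, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention.

    X: (L, d) input, one token per row.
    Wq, Wk: (d, q) query/key projections; Wv: (d, v) value projection.
    Returns Z: (L, v), where row i is sum_j alpha_ij v_j, and the weights alpha.
    """
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    q = Q.shape[-1]
    alpha = softmax(Q @ K.T / np.sqrt(q))   # (L, L); each row sums to 1
    return alpha @ V, alpha

rng = np.random.default_rng(0)
L, d, q = 5, 16, 4
X = rng.normal(size=(L, d))
Wq, Wk, Wv = (rng.normal(size=(d, q)) for _ in range(3))
Z, alpha = attention_head(X, Wq, Wk, Wv)
print(Z.shape, alpha.sum(axis=1).round(6))
```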

    But this is only one way to “look at” the input vectors! What we have defined now is not the full self-attention layer but only a single self-attention head. We want to parallelize these computations along H different heads, using different weight matrices \mathbf{W}^Q_1,\ldots,\mathbf{W}^Q_H, \mathbf{W}^K_1,\ldots,\mathbf{W}^K_H, and \mathbf{W}^V_1,\ldots,\mathbf{W}^V_H to allow the Transformer layer to consider different combinations of the same input vectors at once.

    This process, known as multi-head attention, is illustrated below:

    Note that after H parallel heads, we get H output matrices \mathbf{Z}_1,\ldots,\mathbf{Z}_H, each of dimension v\times L. We need to compress them all into a single output matrix \mathbf{Z}\in\mathbb{R}^{d\times L} with the same dimension as the input matrix \mathbf{X} so that we can stack such layers further. The Transformer does it in the most straightforward way possible, as shown in the figure: let us just put their transposes side by side into a single large matrix \mathbf{Z}_{\mathrm{concat}}=\left(\mathbf{Z}_1^\top\ \cdots\ \mathbf{Z}_H^\top\right)\in\mathbb{R}^{L\times Hv} and then add another weight matrix \mathbf{W}^O\in \mathbb{R}^{Hv\times d} that will bring the result back to the necessary dimension:

        \[\mathbf{Z}=\left(\mathbf{Z}_{\mathrm{concat}}\mathbf{W}^O\right)^\top.\]

    With the output weight matrix \mathbf{W}^O, we allow the self-attention layer to mix up representations obtained from different attention heads; this is also an important way to add more flexibility and expressiveness to the architecture.
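    The whole multi-head layer can be sketched in NumPy as follows, again with tokens as rows and random placeholder weights. We assume q = v = d/H, which matches the original Transformer's choice (d=512, H=8, so 64 dimensions per head); the explicit loop over heads is for clarity, while real implementations batch all heads together.

```python
import numpy as np

def softmax(S):
    E = np.exp(S - S.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

def multi_head_attention(X, Wq, Wk, Wv, Wo):
    """Multi-head self-attention with tokens stored as rows.

    X: (L, d); Wq, Wk: (H, d, q); Wv: (H, d, v); Wo: (H*v, d).
    """
    heads = []
    for h in range(Wq.shape[0]):
        Q, K, V = X @ Wq[h], X @ Wk[h], X @ Wv[h]
        alpha = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))
        heads.append(alpha @ V)                 # (L, v) for each head
    Z_concat = np.concatenate(heads, axis=-1)   # (L, H*v)
    return Z_concat @ Wo                        # (L, d): same shape as X

rng = np.random.default_rng(0)
L, d, H = 6, 32, 4
q = v = d // H        # split the model dimension across heads
Wq, Wk = rng.normal(size=(H, d, q)), rng.normal(size=(H, d, q))
Wv, Wo = rng.normal(size=(H, d, v)), rng.normal(size=(H * v, d))
X = rng.normal(size=(L, d))
print(multi_head_attention(X, Wq, Wk, Wv, Wo).shape)  # (6, 32)
```

    With a single head and \mathbf{W}^O equal to the identity, the function reduces to the single-head formula above.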

    We are now entirely done with the self-attention layer but not quite done with the entire architecture. We have already discussed the rest of an encoder layer: after the multi-head attention, we use layer normalization with a residual connection \mathrm{LayerNorm}(\mathbf{X} + \mathbf{Z}) and then add a feedforward layer that mixes features inside each of the token representations, with another residual connection around it leading to another LayerNorm, as shown inside the encoder layer in the general figure above.

    This has been the most important part, but there are still quite a lot of bits and pieces of the Transformer architecture to pick up. In the next section, we will discuss how the decoder part works, how the input embeddings are constructed from text, and what the original Transformer architecture actually did.

    Odds and Bits: Decoder, Tokenization, Positional Embeddings, and Discussion

    We have discussed the main idea of self-attention and seen how it comes together in the encoder part of a Transformer. Let us now turn to the right part of the general figure shown above that has two more layer types that we do not know yet: masked self-attention and encoder-decoder attention. Fortunately, they are both now very easy to introduce.

    Masked attention basically means that since the decoder works autoregressively, it should not peek at the tokens that it has not produced yet. This could be done by changing the input, but it is easier and faster to process the whole sequence at once and mask future positions inside the self-attention layers themselves. Formally, this means that in the self-attention formula we set the softmax arguments to negative infinity for future tokens, so their attention weights are always zero.
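    A minimal NumPy sketch of this masking for a single head, with tokens as rows and random placeholder weights. The mask is an upper-triangular boolean matrix that marks the future positions whose scores are replaced by negative infinity before the softmax.

```python
import numpy as np

def masked_self_attention(X, Wq, Wk, Wv):
    """Causal (masked) single-head self-attention; tokens are rows of X."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    L, q = Q.shape
    S = Q @ K.T / np.sqrt(q)
    # Position i may only attend to positions j <= i: future scores become -inf.
    future = np.triu(np.ones((L, L), dtype=bool), k=1)
    S = np.where(future, -np.inf, S)
    # exp(-inf) = 0, so future tokens receive exactly zero attention weight.
    E = np.exp(S - S.max(axis=-1, keepdims=True))
    alpha = E / E.sum(axis=-1, keepdims=True)
    return alpha @ V, alpha

rng = np.random.default_rng(0)
L, d, q = 4, 8, 4
X = rng.normal(size=(L, d))
Z, alpha = masked_self_attention(X, *(rng.normal(size=(d, q)) for _ in range(3)))
print(alpha.round(2))   # lower-triangular: zeros above the diagonal
```

    A useful consequence: changing a later token cannot affect the outputs at earlier positions, which is what lets the decoder be trained on whole sequences at once.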

    Encoder-decoder attention is a variation on the self-attention mechanism that takes into account the results of the encoder. These results are vectors with the same dimension as the output matrix \mathbf{Z} that we obtained in the formula above, that is, L vectors of length d, and the figure suggests that each layer in the decoder receives these vectors as a condition. 

    It might seem to require a very different architecture to condition self-attention on this matrix… but in fact it’s an almost trivial modification. We have the exact same self-attention layer described above, with weight matrices that create queries, keys, and values for every attention head, but use different vectors as input:

    • to create the queries, we use vectors from the previous layer, i.e., current representations of already generated output tokens;
    • but for the “documents” in our “retrieval” task, i.e., for the key and value vectors, we use the vectors from the encoder, that is, the outputs of the last encoder layer.

    Informally, this means that we are doing “retrieval” on the encoder’s output with queries made of already produced tokens. Formally, we simply use the same formula with queries, keys and values defined above, and all dimensions match nicely: for each query there are L terms in the softmax argument, one per input token, while the number of queries, and hence the number of outputs, matches the number of decoder inputs.
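    In code, encoder-decoder attention differs from self-attention only in where the inputs come from. Here is a minimal NumPy sketch with tokens as rows and random placeholder weights; the sizes (3 output tokens generated so far, 7 input tokens) are arbitrary assumptions. No causal mask is applied here, since the entire input sentence is already known.

```python
import numpy as np

def softmax(S):
    E = np.exp(S - S.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

def cross_attention(Y, enc_out, Wq, Wk, Wv):
    """Encoder-decoder attention, tokens as rows.

    Y:       (T, d) current decoder representations (one row per output token).
    enc_out: (L, d) final encoder outputs (one row per input token).
    """
    Q = Y @ Wq                          # queries come from the decoder
    K, V = enc_out @ Wk, enc_out @ Wv   # keys and values come from the encoder
    alpha = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))   # (T, L)
    return alpha @ V, alpha             # (T, v): one output per decoder token

rng = np.random.default_rng(0)
T, L, d, q = 3, 7, 8, 4
Y, enc_out = rng.normal(size=(T, d)), rng.normal(size=(L, d))
Z, alpha = cross_attention(Y, enc_out, *(rng.normal(size=(d, q)) for _ in range(3)))
print(Z.shape, alpha.shape)   # (3, 4) (3, 7)
```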

    Note also that the decoder has an extra linear layer at the end followed by a softmax for next token classification; this is about the simplest classification head possible, obviously needed for the decoder.

    But this is still not all. We need to discuss one more thing: how does the input text turn into a sequence of dense vectors that self-attention layers process so skillfully? There are two things to discuss here.

    First, tokenization. I have mentioned above that tokens do not really correspond to words. In most Transformer-based models, tokenization is done with a process known as byte-pair encoding, an interesting idea in its own right that originally comes from data compression. To begin with, we consider all words present in the input text (the notion of a “word” should be understood liberally, but it is, more or less, a sequence of characters delimited by whitespace) and count the number of their occurrences, building a vocabulary. Let us consider a few words that share a lot of repeating character subsequences:

    We first count the word frequencies, as shown above. Then we break down this vocabulary into individual symbols and count them; in practice there would be extra symbols for the beginning and/or end of a word and all sorts of extra stuff, but let’s keep it simple and stick to our “cat on a mat” example; this is the middle part of the figure above.

    This is our original vocabulary of symbols, and now the encoding process can begin. We look at all pairs of adjacent symbols inside words and count their occurrences, weighted by word frequency:

    { ca: 10, at: 27, pe: 12, et: 12, ma: 5, ra: 8, ea: 4, ts: 4 }.

    Then we choose the most frequent pair—in this case “at“—and re-encode it with a single new symbol that is added to the vocabulary; let’s call it Z. After that, we have a new set of words in the new encoding, and we can count the symbols and their pairs again:

    { cZ: 10, pet: 12, mZ: 5, rZ: 8, eZs: 4 },

    { c: 10, Z: 27, t: 12, p: 12, e: 16, m: 5, r: 8, s: 4 },

    { cZ: 10, pe: 12, et: 12, mZ: 5, rZ: 8, eZ: 4, Zs: 4 }.

    At this point, we can choose the new most frequent pair—in this case “pe” or “et“—and replace it with another new symbol, say Y. Replacing “pe“, we get the following new vocabulary and statistics:

    { cZ: 10, Yt: 12, mZ: 5, rZ: 8, eZs: 4 },

    { c: 10, Z: 27, Y: 12, t: 12, e: 4, m: 5, r: 8, s: 4 },

    { cZ: 10, Yt: 12, mZ: 5, rZ: 8, eZ: 4, Zs: 4 }.

    As we run the algorithm in a loop, new symbols may also become part of new pairs; in our example, the next most frequent pair is “Yt“, so after the next step we will have a whole separate token corresponding to “pet“. Note that we never remove symbols from the vocabulary even if they have zero occurrences after a merge: we may not have any t‘s left after the next merge, but new input text may contain new unknown words with t‘s that will need to be processed, so we need the vocabulary to stay universal.

    The encoding process can be stopped at any time: on every step, we get a new extended set of tokens (vocabulary) that compresses the original text in a greedy way, and we can stop and use the current set of tokens. So in practice, we set a target vocabulary size and run the algorithm until the set of tokens reaches this size, usually getting excellent compression for the original text in terms of these new tokens. As a result, words may still be broken into parts, but the most frequent words will get their own tokens, and the parts themselves may have meaning; for example, in English one would expect a token like “tion” to appear quite early in the process, since it is a frequent sequence of letters with a clear meaning.
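    The merge loop described above fits in a few dozen lines of plain Python. This is a simplified sketch, not any particular library's tokenizer: the word counts (cat: 10, pet: 12, mat: 5, rat: 8, eats: 4) are reconstructed from the pair counts in the example, merged symbols are named by concatenation (“at”, “pe”, …) instead of new letters like Z and Y, ties are broken by taking the first pair encountered, and there are no end-of-word markers.

```python
from collections import Counter

def pair_counts(vocab):
    """Count adjacent symbol pairs, weighted by word frequency."""
    pairs = Counter()
    for word, freq in vocab.items():
        for a, b in zip(word, word[1:]):
            pairs[a, b] += freq
    return pairs

def merge(vocab, pair):
    """Replace every occurrence of `pair` with one new merged symbol."""
    new_symbol = pair[0] + pair[1]
    merged = {}
    for word, freq in vocab.items():
        out, i = [], 0
        while i < len(word):
            if i + 1 < len(word) and (word[i], word[i + 1]) == pair:
                out.append(new_symbol)
                i += 2
            else:
                out.append(word[i])
                i += 1
        merged[tuple(out)] = freq
    return merged

def learn_bpe(word_freqs, num_merges):
    vocab = {tuple(w): f for w, f in word_freqs.items()}   # start from characters
    merges = []
    for _ in range(num_merges):
        pairs = pair_counts(vocab)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)   # ties: first pair encountered wins
        merges.append(best)
        vocab = merge(vocab, best)
    return merges, vocab

words = {"cat": 10, "pet": 12, "mat": 5, "rat": 8, "eats": 4}
merges, vocab = learn_bpe(words, 3)
print(merges)  # [('a', 't'), ('p', 'e'), ('pe', 't')]
print(vocab)
```

    The three merges reproduce the steps above: first “at”, then “pe”, then “pet”. A real tokenizer would also keep every base character in its vocabulary so that unseen words can still be encoded.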

    That’s it for tokenization! At this point, the input is a sequence of fixed discrete objects (tokens) taken from a predefined vocabulary of size V. It remains to turn it into a sequence of dense vectors \mathbf{x}\in\mathbb{R}^d, which is usually done via an embedding layer that is basically just a large d\times V matrix of trainable weights. In earlier days of the deep learning revolution in natural language processing, word embeddings were a quite interesting field of study in and of themselves because they used to be trained separately and then applied as a fixed “dense vocabulary” on top of which neural models were trained. This field of study has given us word2vec (Mikolov et al., 2013a, 2013b; Le, Mikolov, 2014), GloVe (Pennington et al., 2014), and many more interesting ideas… but there is no point in discussing them here because now it’s just a trainable layer like any other, and the whole architecture is being trained at once, including the embedding layer.
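    An embedding lookup is simply column selection in that d\times V matrix. This toy NumPy sketch uses random values in place of trained weights and arbitrary token ids; it also checks that the lookup equals multiplication by one-hot vectors, which is why the embedding can be trained like any other linear layer.

```python
import numpy as np

rng = np.random.default_rng(0)
V, d = 10, 6                       # toy vocabulary size and embedding dimension
E = rng.normal(size=(d, V))        # trainable embedding matrix, one column per token

token_ids = np.array([3, 1, 3, 7])     # a tokenized input sequence
X = E[:, token_ids]                    # (d, L): column i embeds token i

# The lookup is the same as multiplying E by one-hot vectors.
one_hot = np.eye(V)[:, token_ids]      # (V, L)
assert np.allclose(X, E @ one_hot)
print(X.shape)  # (6, 4)
```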

    Still, there is another point about the embeddings which is unique to Transformers. One of the main characteristic features of the Transformer architecture is that every input token can “look at” any other input token directly, with no regard for the distance between them in the input sequence. The self-attention layer has a separate attention weight for every pair of tokens, not just neighboring ones or something like that. This has a drawback too: since the number of attention weights grows quadratically with the input length, the length of the input, which for language models is known as the context window, has to be bounded. But this is a big advantage over, say, recurrent architectures where you need to go through every step of the sequence, “losing memory” along the way, before the influence of one word can reach another.

    But there is another interesting consequence of this property: since the attention weights cover all tokens uniformly, we lose the order of the sequence. That is, for a self-attention layer there is no sense of some tokens being “next to each other” or “closer in the input sequence”, it is all just a single matrix of weights. We need to give the Transformer an idea of the input sequence artificially; this is done via the so-called positional encodings.

    Positional encodings are vectors added to the embedding that reflect where in the sequence the current token is; we will discuss them briefly but I also refer to a more detailed description of positional encodings by Kazemnejad (2019). How could we encode the position? Well, we could have a number that is increasing with the position but it is hard to get right:

    • if we just used an increasing sequence like 1,2,3,…, it would not generalize to sequences longer than the usual length in the training set, and the network’s behaviour would be much less well defined for large values;
    • if we used a given interval, say [0, 1], and broke it down into the necessary number of pieces, then the positional encoding would have no idea how many words are actually there between two tokens: the distance from 0 to ½ could be one token or one hundred.

    Therefore, the Transformer uses a very clever idea inspired by how we encode numbers in our regular positional notation. Consider a sequence of numbers, say written in binary, like on the left of the figure below:

    As the numbers increase, the value of each digit forms a periodic function, with different periods for different digits in the number. We want to replicate something like that in the positional encodings, but since we have the full power of real numbers available now, we use continuous periodic functions, sine waves:

        \begin{align*}\mathrm{PE}(\mathrm{pos},2i)&=\sin\left(\frac{\mathrm{pos}}{10000^{2i/d}}\right),\\ \mathrm{PE}(\mathrm{pos},2i+1)&=\cos\left(\frac{\mathrm{pos}}{10000^{2i/d}}\right),\end{align*}

    where pos is the token position in the sequence and d is the embedding dimension. This means that each cell in the positional encoding vector holds a sine wave with respect to the position (half of them shifted to a cosine wave), each with its own period that increases with i, that is, with the cell index. The result is shown in the figure above on the right, where the horizontal axis shows the cell indices i and the vertical axis shows token positions from top to bottom. Sine waves become more and more elongated (with increasing period) as we go from left to right, so for 20 tokens we are actually using only about 20-25 dimensions for the positional encoding, but this definition can support arbitrarily long input sequences, even longer than those present in the training set.
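    The formula translates directly into a few lines of NumPy. This is a minimal sketch assuming an even embedding dimension d; the sizes below (20 positions, d=64) are arbitrary.

```python
import numpy as np

def positional_encoding(num_positions, d):
    """Sinusoidal positional encodings as in the original Transformer.

    Returns an array of shape (num_positions, d) with
    PE[pos, 2i] = sin(pos / 10000**(2i/d)) and PE[pos, 2i+1] = cos(same angle).
    Assumes d is even.
    """
    pos = np.arange(num_positions)[:, None]            # (P, 1)
    two_i = np.arange(0, d, 2)[None, :]                # (1, d/2): the values 2i
    angles = pos / np.power(10000.0, two_i / d)        # (P, d/2)
    pe = np.empty((num_positions, d))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

pe = positional_encoding(20, 64)
print(pe.shape)            # (20, 64)
print(pe[0, :4])           # position 0: [0. 1. 0. 1.]
```

    Since the encodings come from a formula rather than a learned table, the same function can produce vectors for positions beyond any length seen in training.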

    It was a little surprising to me that positional encodings are not concatenated with regular embeddings but rather added to them. It looks counterintuitive: positional information is of a different nature, and one would think it should not be mixed with token semantics. But embeddings are learned rather than fixed, and as you can see in the figure, positional encodings take up a small portion of the overall vector, so they can probably coexist just fine. In any case, the input embedding is a sum of the trainable embedding for the current token and the vector \mathrm{PE}(\mathrm{pos},\cdot) defined above.

    And with that we are completely done with the Transformer architecture, so let us briefly discuss its original results. As I have already mentioned, the original Transformer presented by Vaswani et al. (2017) was doing machine translation, and numerically speaking, results of the original Transformer architecture were not the most striking: the encoder-decoder architecture applied to machine translation scored roughly on par with the best existing models in English-French and English-German translation. But the Transformer had equally good BLEU scores in machine translation… while requiring 100x less compute for training! And when an architecture needs 100x less compute, in practice it means that you can train a much larger model (maybe not exactly 100x larger, but still) with the same computational budget, and then you will hopefully scale to much better results.

    Since 2017, Transformers have become one of the most popular architectures in machine learning. In the rest of this post, we will discuss some further extensions and modifications that the Transformer has undergone, although surprisingly few have been necessary to adapt the architecture even to entirely new data modalities.

    Cutting the Transformer in Two: GPT and BERT

    As we have discussed in the previous section, the basic Transformer is a full-scale encoder-decoder architecture, where the encoder produces semantically rich latent representations of text in the input language, and the decoder turns them into text in the target language by writing it autoregressively.

    From there, it was only natural to cut the Transformer in two:

    • we need semantically rich latent representations of high-dimensional input data for a lot of tasks, so we could use the Transformer encoder separately from the decoder to produce these representations;
    • we need good language models that can produce text autoregressively, so we could use the Transformer decoder separately from the encoder to train a language model.

    Let us begin with the latter, i.e., with the decoder, but first let us understand in slightly more detail what we are talking about.

    A language model is a machine learning model that predicts the next token in a sequence of language tokens; it is easier to think of tokens as words, although in reality models usually break words down into smaller chunks. The machine learning problem here is basically classification: what is the next token going to be? A language model is just a classification model, and by continuously predicting the next token, a language model can write text. We have already discussed it in Part VII of the series:

    The only thing a language model can do: predict the next token, over and over. Note that this also means that there are very few problems with data collection or labeling for general-purpose language models: any human-written text becomes a set of labeled examples for supervised learning because the language model just predicts the next word, which is already there in the text. Therefore, you can just collect a lot of text off the Web and train on it! There are several standard datasets that are used to train large language models (LLMs) nowadays:

    • the Common Crawl corpus is a free and open repository of data crawled off the Internet, with over 250 billion Web pages downloaded over more than 15 years; it is a huge and varied corpus that has been used for over 10000 research papers, including modern large language models such as GPT-3, LLaMA, or T5 (Raffel et al., 2020);
    • since the Common Crawl is so huge and diverse, there have been many attempts to refine it, choosing subsets suitable for different tasks; in particular, the C4 dataset (which stands for “Colossal Clean Crawled Corpus”), with about 380GB of text (360 billion tokens) in the cleaned up version and 2.3TB unprocessed, was prepared in 2019 for training the T5 model and remains a very popular dataset derived from Common Crawl;
    • the Pile (Gao et al., 2020) is a freely available corpus with 825 GiB of English text, with an emphasis on diversity: in addition to a specially filtered subset of Common Crawl (Pile-CC), it combines several academic sources such as arXiv and PubMed, source code crawled from GitHub, available datasets of full-scale books, programming-related discussions from StackExchange, and many smaller data sources;
    • finally, although these datasets actually aim to be all-encompassing downloads of the entire Web (perhaps cleaned up and filtered in different ways), work on creating new datasets is still far from over; for example, the RefinedWeb dataset (Penedo et al., 2023) has been released very recently (June 2023) and claims that with some additional preprocessing and filtering, the resulting dataset extracted from Common Crawl (the authors claim about 5 trillion tokens in the full version and release publicly a subset of 600 billion tokens) can result in even higher-performing LLMs.
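    To see why raw text comes with free labels, here is a toy sketch that turns a token sequence into (context, next token) training pairs. Splitting on whitespace stands in for a real tokenizer, and the context size of 3 is an arbitrary assumption; in practice, a Transformer decoder with masked self-attention computes predictions for all these prefixes in a single pass.

```python
def next_token_examples(tokens, context_size):
    """Turn tokenized text into (context, target) pairs.

    Every position supplies its own label: the token that actually comes next.
    """
    examples = []
    for i in range(1, len(tokens)):
        context = tokens[max(0, i - context_size):i]
        examples.append((context, tokens[i]))
    return examples

tokens = "the cat sat on the mat".split()   # pretend each word is one token
for context, target in next_token_examples(tokens, context_size=3):
    print(context, "->", target)
```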

    And now that we have these huge datasets, the language modeling part appears to be trivial: let us just use the Transformer decoder to autoregressively predict the next token! This exact idea was implemented in a series of models called Generative Pre-Trained Transformers — yes, that’s the famous GPT family.

    In particular:

    • the original GPT (Radford et al., 2018) had 12 layers in the Transformer decoder part, with 12 masked self-attention heads each and 768-dimensional states (64 dimensions per head); it was pretrained on the BookCorpus dataset (Zhu et al., 2015) with over 7000 books (a tiny dataset by modern standards!) and then fine-tuned for specific tasks with labeled data; the authors reported that BookCorpus was chosen so that the model would learn to handle long-range dependencies better;
    • GPT-2 (Radford et al., 2019), released in February 2019, was more or less a direct scale-up of GPT, pretrained on the same BookCorpus dataset and a newly collected WebText dataset with 8 million web pages vetted by humans: they scraped outbound links from Reddit that obtained at least 3 karma (40GB of text in total); the largest version of GPT-2 had 48 layers and dimension 1600, for a total of about 1.5 billion parameters, more than 10x the size of the original GPT;
    • GPT-3 (Brown et al., 2020), released in June 2020, scaled things up by two more orders of magnitude; its largest version, known as the davinci family, had an unprecedented size of 175 billion parameters; GPT-3 became the basis for further models such as ChatGPT that we have already discussed in Part VII.
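    A back-of-the-envelope check of these sizes: each decoder layer holds about 4d^2 weights in the query, key, value, and output projections plus 8d^2 in a feedforward block with hidden size 4d, so roughly 12d^2 per layer. The sketch below ignores biases and LayerNorm; the vocabulary and context sizes plugged in (40,478 and 512 for GPT, 50,257 and 1024 for GPT-2) are the values commonly reported for these models, and the hidden size of 768 for the original GPT is likewise an assumption on our part.

```python
def approx_params(n_layers, d_model, vocab_size, n_ctx):
    """Rough parameter count of a GPT-style decoder-only Transformer.

    Per layer: 4*d^2 (Q, K, V, and output projections) + 8*d^2 (feedforward
    with hidden size 4*d) = 12*d^2. Embeddings: token (vocab_size x d) plus
    learned positions (n_ctx x d). Biases and LayerNorm weights are ignored.
    """
    return 12 * n_layers * d_model ** 2 + (vocab_size + n_ctx) * d_model

gpt1 = approx_params(12, 768, 40478, 512)
gpt2_xl = approx_params(48, 1600, 50257, 1024)
print(f"GPT: ~{gpt1 / 1e6:.0f}M, GPT-2 (largest): ~{gpt2_xl / 1e9:.2f}B")
# GPT: ~116M, GPT-2 (largest): ~1.56B
```

    Both estimates land close to the usual figures (about 117M for the original GPT and about 1.5B for the largest GPT-2), and they show that the share of parameters in the embeddings shrinks as models get deeper and wider.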

    As the GPT family scaled up, it also obtained more impressive generalization abilities with regard to problems you might want to solve with it. Suppose, for example, that you wanted to recognize entailment relations, that is, find out whether a hypothesis sentence follows from a premise sentence. Data for this problem, e.g., the popular MultiNLI (Multi-Genre Natural Language Inference) corpus (Williams et al., 2018), looks like pairs of sentences labeled with three kinds of results:

    In the example above, for a premise “Two dogs are running through a field” (I took this example from Gururangan et al., 2018),

    • the hypothesis “There are animals outdoors” gets the label “Entailment”, i.e., it follows from the premise,
    • the hypothesis “Puppies are running to catch a stick” is labeled “Neutral” since while there is no direct contradiction, the dogs might or might not be puppies, and there might or might not be a stick,
    • and the hypothesis “The pets are sitting on a couch” is a clear “Contradiction”, i.e., the premise rules it out.

    Different versions of GPT and BERT would handle the entailment problem differently. To adapt the original GPT (Radford et al., 2018) or BERT (Devlin et al., 2018) to a specific task, you had to fine-tune it, i.e., modify its weights by performing additional training on the downstream task. For entailment, this meant encoding a separate entailment dataset into a special form and training new weights for a separator embedding and a new linear layer on top. The figure below shows how this would work for the original GPT and details what new weights have to be learned. This is also how such problems had been approached in deep learning before, e.g., with large-scale recurrent architectures (Rocktäschel et al., 2015).

    Starting from GPT-2, and definitely in GPT-3, developers of Transformer-based architectures moved to a different approach, where no new weights need to be trained at all. Similar to multitask learning and following an earlier attempt by the MQAN (Multitask Question Answering Network) model (McCann et al., 2018), they noted that a variety of different tasks could be encoded into text. Actually, one could argue that the Turing test is so good exactly because you can sneak a lot of different questions into text-only conversations, including questions about the surroundings, the world in general, and so on. So to make use of a language model’s “understanding” (in this post, I’m putting “understanding” in quotes, but have you seen Part VIII?) of the world, you could give it a few examples of what you need and frame the problem as continuing the text in the prompt. The following figure compares the two approaches (GPT on the left, GPT-2 and GPT-3 on the right) and shows a sample prompt for the logical entailment problem on the right; you would probably obtain better results if you put more examples in the omitted part of the prompt:

    Note that in this approach, all that the language model is doing is still predicting the next token in the text, nothing more! Moreover, it is not even trained to solve new problems such as entailment or question answering: it already “understands” what is needed from its vast training set, and a short prompt with a couple of examples is enough to summon this “understanding” and let the model solve complex semantic problems.
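    Framing a task as text continuation is, on the input side, just string formatting. The sketch below assembles a few-shot entailment prompt from the dog examples discussed above; the “Premise/Hypothesis/Label” template is a hypothetical choice for illustration (not the format used in any GPT paper), and `build_entailment_prompt` is a made-up helper name. A language model would then be asked to continue the text after the final “Label:”.

```python
def build_entailment_prompt(examples, premise, hypothesis):
    """Hypothetical few-shot template: labeled demonstrations followed by one unlabeled query."""
    blocks = [
        f"Premise: {p}\nHypothesis: {h}\nLabel: {label}"
        for p, h, label in examples
    ]
    blocks.append(f"Premise: {premise}\nHypothesis: {hypothesis}\nLabel:")  # the model continues here
    return "\n\n".join(blocks)

premise = "Two dogs are running through a field."
demos = [
    (premise, "There are animals outdoors.", "Entailment"),
    (premise, "Puppies are running to catch a stick.", "Neutral"),
]
prompt = build_entailment_prompt(demos, premise, "The pets are sitting on a couch.")
print(prompt)
```

    No weights change anywhere in this procedure: the “task specification” lives entirely in the text, and a good continuation would be the single word “Contradiction”.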

    A different way to cut up the original Transformer was introduced in the BERT model developed by Google researchers Devlin et al. (2018). BERT stands for Bidirectional Encoder Representations from Transformers. As the name suggests, the main emphasis here is on learning semantically rich representations for tokens that could be further used in subsequent models, somewhat like word embeddings such as word2vec and GloVe had been used before, but better and with the full context available to the model producing representations.

    To do that, BERT leaves only the encoder part of the Transformer, the part that produces a semantically rich representation for each of the input tokens. But how do we train it if not with the language modeling objective that GPT uses? It turns out that we still can do approximately the same thing: instead of predicting the next token, we can mask out some of the tokens in the input (in the same way as we mask future tokens in the decoder) and predict them based on the full context from both left and right. This is known as masked language modeling, and it is the main pretraining objective for BERT.
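    Here is a minimal sketch of how the masked inputs can be produced. The 80/10/10 split (replace a chosen token with [MASK], with a random token, or leave it unchanged) follows the original BERT recipe; selecting each position independently with probability 15% is our simplification of how BERT picks a fixed number of positions per sequence, and `mask_tokens` is a hypothetical helper name.

```python
import random

MASK = "[MASK]"

def mask_tokens(tokens, vocab, mask_prob=0.15, rng=random):
    """BERT-style input corruption (simplified): each position becomes a prediction target with
    probability mask_prob; a target is replaced by [MASK] 80% of the time, by a random vocabulary
    token 10% of the time, and left unchanged 10% of the time."""
    corrupted = list(tokens)
    targets = {}                      # position -> original token the model must recover
    for i, tok in enumerate(tokens):
        if rng.random() >= mask_prob:
            continue
        targets[i] = tok
        r = rng.random()
        if r < 0.8:
            corrupted[i] = MASK
        elif r < 0.9:
            corrupted[i] = rng.choice(vocab)
        # otherwise keep the original token, but it is still a prediction target
    return corrupted, targets

tokens = "two dogs are running through a field".split()
# A higher rate than BERT's 15% so that this short sentence is likely to show some masking.
corrupted, targets = mask_tokens(tokens, vocab=tokens, mask_prob=0.3, rng=random.Random(0))
print(corrupted)
print(targets)
```

    The loss is computed only at positions in `targets`; keeping some targets unchanged or randomly replaced prevents the model from learning that only [MASK] positions matter, since [MASK] never appears at fine-tuning time.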

    Here is a comparison of the BERT (left) and GPT (right) pretraining objectives:

    Just like language modeling itself, masked language modeling has a long history; it was originally known as the cloze procedure, introduced in 1953 as a readability test for texts in a natural language (Taylor, 1953). The word “cloze” is not a last name; it was derived from “closure”, as in gestalt psychology: humans tend to fill in missing pieces. So if you want to compare how “readable” two texts are, you delete some pieces from them at random and ask people to fill in the blanks: the most readable passage will be the one where the most humans get the most missing pieces right.

    The original BERT combines two variations of this idea:

    • masked language modeling itself, where tokens to be predicted are chosen at random, and
    • next sentence prediction, where the model receives two sentences and has to decide whether the second one actually follows the first in the original text; this helps the model make its representations more semantically rich and more oriented towards the global context.
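    Training pairs for the second objective are cheap to generate. In the sketch below, half of the time the second segment really follows the first (IsNext) and half of the time it is a random sentence drawn from elsewhere (NotNext), as in the BERT paper; the tiny `document` and `other_sentences` lists are made-up stand-ins for a real corpus, and `make_nsp_pairs` is a hypothetical helper. The printout mimics BERT's [CLS] A [SEP] B [SEP] input packing.

```python
import random

def make_nsp_pairs(document, other_sentences, rng=random):
    """Next sentence prediction data: for each consecutive pair (A, B) in a document, keep B
    with probability 1/2 (label IsNext) or swap in a sentence from elsewhere (label NotNext)."""
    pairs = []
    for a, b in zip(document, document[1:]):
        if rng.random() < 0.5:
            pairs.append((a, b, "IsNext"))
        else:
            pairs.append((a, rng.choice(other_sentences), "NotNext"))
    return pairs

document = [
    "Two dogs are running through a field.",
    "One of them is chasing a ball.",
    "The other one is barking loudly.",
    "Their owner is watching from a bench.",
]
other_sentences = ["The stock market fell sharply today.", "Water boils at 100 degrees Celsius."]
pairs = make_nsp_pairs(document, other_sentences, rng=random.Random(0))
for a, b, label in pairs:
    print(f"[CLS] {a} [SEP] {b} [SEP] -> {label}")
```

    Drawing negatives from other documents, rather than from the same one, keeps a “random” sentence from accidentally being the true continuation.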

    In later research, more models have been developed based on the Transformer encoder that can provide different flavors of embeddings with somewhat different properties. We will not do a proper survey here, referring to, e.g., (Zhou et al., 2023; Wolf et al., 2020; Lin et al., 2022), but let us mention a few of the BERT variations that have been most important for natural language processing applications:

    • RoBERTa (Robustly optimized BERT pretraining approach; Liu et al., 2019) is one of the most widely used modifications; its authors found that the original BERT was under-trained and fixed it, switched to the byte-level BPE tokenizer that we discussed above, and applied a few more tricks to improve pretraining while keeping the architecture itself intact; when people need good pretrained token embeddings to plug into a neural model, they usually take RoBERTa embeddings; there are several different model sizes to choose from;
    • BART (Bidirectional and Autoregressive Transformers; Lewis et al., 2020) turns the Transformer into a denoising autoencoder: it pretrains by corrupting the text and reconstructing the original through the Transformer decoder; although it is a full-scale encoder-decoder architecture, I put it here because BART is used in practice very similarly to BERT: you use the semantically rich intermediate representations and discard the denoising decoder because in real life you seldom need to actually denoise corrupted sentences;
    • ALBERT (A Lite BERT; Lan et al., 2019) applied several techniques to reduce the number of parameters in a Transformer and make training faster while trying to preserve the expressiveness as much as possible; you can probably train ALBERT yourself on a desktop and harness the power of BERT for your own private dataset;
    • DistilBERT (Sanh et al., 2019) moved in the same direction with model distillation techniques, again targeting a model that you can fine-tune with consumer-grade hardware;
    • and so on, and so forth, with dozens of derivative models proposed in the literature (Ganesh et al., 2021; Kalyan et al., 2022; Patel et al., 2023; Rogers et al., 2020; Xu, McAuley, 2023) and available, e.g., in the HuggingFace transformers library.

    BERT and its derivative models such as RoBERTa have proven to be a very valuable tool for natural language processing (Patwardhan et al., 2023). The usual way to apply BERT has been to take the vectors it produces (BERT embeddings, or RoBERTa embeddings, or ALBERT, or any other) and plug them into standard neural models for various natural language processing tasks. This has usually improved things across the board, in problems such as:

    • text classification, where one usually takes either the embedding of the special symbol at the beginning or end of the text or all BERT embeddings and applies a simple classification head on top (Khadhraoui et al., 2022);
    • the same applies to other tasks that reduce to text classification such as sentiment analysis, irony detection, and others (Barbieri et al., 2020);
    • for sequence labeling tasks such as named entity recognition, you also use Transformer-produced embeddings and an entity classification model on top, but this time the entity classification model may be more complex since we want to predict contiguous multi-word entities (Gu et al., 2021; Ji et al., 2020; Li et al., 2021);
    • as for question answering and similar tasks that require writing free text, this is usually best served by the GPT family; in Part VII, we have discussed the capabilities of ChatGPT and GPT-4 that make a lot of tricks from prior research unnecessary; this is another example of the “bitter lesson” (Sutton, 2019), and you can decide for yourself whether this is a good thing or a bad thing.
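    The “simple classification head” from the first item above can be as small as one linear layer with a softmax on top of frozen embeddings. The sketch below trains such a head with plain gradient descent on synthetic vectors that merely stand in for 768-dimensional [CLS] embeddings: no actual BERT model is loaded, and the cluster data, learning rate, and number of epochs are arbitrary assumptions chosen for illustration.

```python
import numpy as np

def train_linear_head(features, labels, n_classes, lr=0.001, epochs=200):
    """Fit a softmax classifier on frozen sentence vectors with plain gradient descent."""
    n, d = features.shape
    W, b = np.zeros((d, n_classes)), np.zeros(n_classes)
    onehot = np.eye(n_classes)[labels]
    for _ in range(epochs):
        logits = features @ W + b
        logits -= logits.max(axis=1, keepdims=True)   # numerically stable softmax
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n                    # gradient of mean cross-entropy w.r.t. logits
        W -= lr * features.T @ grad
        b -= lr * grad.sum(axis=0)
    return W, b

def predict(features, W, b):
    return np.argmax(features @ W + b, axis=1)

# Synthetic stand-ins for [CLS] vectors: two noisy clusters, one per class.
rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=200)
centers = rng.normal(size=(2, 768))
cls_vectors = centers[labels] + 0.5 * rng.normal(size=(200, 768))

W, b = train_linear_head(cls_vectors, labels, n_classes=2)
accuracy = (predict(cls_vectors, W, b) == labels).mean()
print(f"training accuracy: {accuracy:.2f}")
```

    In practice the encoder is often fine-tuned together with the head, but the head itself stays this simple.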

    Finally, another line of models that has been instrumental in modern NLP is XLM (cross-lingual language model; Conneau, Lample, 2019), a series of models based on BERT and GPT that train on several languages at once. To do that, they apply byte-pair encoding to all languages at the same time, getting a shared multilingual vocabulary, and use the same kind of LM and masked LM objectives to train in multiple languages at once. XLM and its successor models such as XLM-RoBERTa (Conneau et al., 2019) defined the state of the art in many cross-lingual tasks such as the ones from XNLI, a cross-lingual benchmark for natural language inference (Conneau et al., 2018).
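    A minimal sketch of the shared-vocabulary idea: we learn byte-pair-encoding merges (classic character-level merge learning, not XLM's actual preprocessing pipeline) on a tiny made-up word-frequency list that mixes English “nation” and “national”, French “nationale”, and German “Nationen” (lowercased). Because the merges are learned jointly, the subword “natio” becomes a single unit shared by all three languages.

```python
from collections import Counter

def learn_bpe(words, num_merges):
    """Learn BPE merges from {word: frequency}; each word starts as characters plus an end marker."""
    vocab = Counter()
    for word, freq in words.items():
        vocab[tuple(word) + ("</w>",)] += freq
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for symbols, freq in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)           # most frequent adjacent pair (first one on ties)
        merges.append(best)
        new_vocab = Counter()
        for symbols, freq in vocab.items():        # apply the new merge to every word
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_vocab[tuple(merged)] += freq
        vocab = new_vocab
    return merges, vocab

corpus = {"nation": 4, "national": 3, "nationale": 3, "nationen": 2}
merges, segmented = learn_bpe(corpus, num_merges=4)
print(merges)          # [('n', 'a'), ('na', 't'), ('nat', 'i'), ('nati', 'o')]
for symbols in segmented:
    print(" ".join(symbols))
```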

    This has already turned into a high-level survey, so I think it is time to cut the survey short and just say that Transformers permeate absolutely all subfields and tasks of natural language processing, defining state of the art in all of them. But, as we will see in the next section, it’s not just natural language processing…

    Vision Transformers

    The Transformer immediately proved itself to be an excellent model for processing sequences of tokens. We will not discuss it in detail, but sequences of other kinds have also yielded to the magic of Transformers; for example, HuBERT soon became a standard model for speech processing (Hsu et al., 2021).

    But images seem to be a different beast, right? An image has a natural two-dimensional structure, and deep learning has long had just the recipe for images: convolutional neural networks process every small window in the same way, sharing the weights in a form of ultimate structural regularization. Convolutional networks have been instrumental in the deep learning revolution, starting from AlexNet that made CNNs great again in 2011-2012 (Krizhevsky et al., 2012) and all the way to the automatically optimized architectures of the EfficientNet family (Tan, Le, 2019).

    Well, it turns out that Transformers can help with images too! To do that, you need to convert an image into a sequence, and usually this is done in a very straightforward way. One of the first models that attempted it was Visual BERT (Li et al., 2019; Li et al., 2020), a model initially designed and pretrained for image captioning:

    Since captions deal with objects that appear in an image, Visual BERT preprocesses the image with a fixed pretrained object detection system such as Faster R-CNN (Ren et al., 2015). Then the objects are cut out of the image, embedded into vectors via convolutional networks and special positional embeddings that indicate where the object was in the image, and simply fed into a single Transformer:

    The figure above also shows sample attention heads and how words from the caption actually do attend to the objects that they describe or are closely related to.

    The pretraining process closely follows how the original BERT is trained. Visual BERT has two pretraining objectives: masked language modeling where the task is to fill in the blanks in the caption and sentence-image prediction where the model needs to distinguish whether a given caption matches the image or not.

    Similar ideas have been developed in many different BERT-based variations. Let me just note one of them: VideoBERT (Sun et al., 2019) that applied similar ideas to video captioning and processing, including text-to-video generation and forecasting future frames in a video:

    The figure above shows these problems: VideoBERT is able to predict the features of video frames corresponding to a given text (in this case a recipe), although it is, of course, better in the video-to-text direction, exceeding contemporary state of the art in video captioning. 

    VideoBERT is again pretrained with masked language modeling on a sequence of both text captions and video tokens:

    In this case, video tokens are obtained by sampling frames from the video, extracting features with a pretrained CNN, and tokenizing the features with simple k-means clustering. Both Visual BERT and VideoBERT were validated by experimental studies where they exceeded state of the art in visual question answering, image and video captioning, and other similar tasks.
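    The tokenization step can be sketched with plain k-means: each frame's feature vector is replaced by the index of its nearest cluster centroid, and these indices form the video “vocabulary”. In the sketch below, random vectors stand in for CNN frame features and a deterministic farthest-point initialization is used; both are assumptions for illustration, and VideoBERT's actual clustering setup is not reproduced here.

```python
import numpy as np

def tokenize(features, centroids):
    """Map each feature vector to the index of its nearest centroid: that index is the video token."""
    dists = ((features[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return dists.argmin(axis=1)

def kmeans(features, k, n_iters=20):
    """Plain (flat) k-means with a deterministic farthest-point initialization."""
    centroids = [features[0]]
    for _ in range(k - 1):
        # squared distance from every vector to its closest already-chosen centroid
        d = np.min([((features - c) ** 2).sum(axis=1) for c in centroids], axis=0)
        centroids.append(features[d.argmax()])
    centroids = np.array(centroids, dtype=float)
    for _ in range(n_iters):                       # Lloyd iterations: assign, then re-average
        assignment = tokenize(features, centroids)
        for j in range(k):
            members = features[assignment == j]
            if len(members):
                centroids[j] = members.mean(axis=0)
    return centroids

# Synthetic stand-ins for per-frame CNN features: 90 frames from 3 visually distinct "scenes".
rng = np.random.default_rng(0)
features = np.concatenate([c + 0.1 * rng.normal(size=(30, 16)) for c in (0.0, 10.0, 20.0)])
centroids = kmeans(features, k=3)
video_tokens = tokenize(features, centroids)   # one discrete token id per frame
print(video_tokens)
```

    Once frames are discrete ids, they can be interleaved with text tokens and masked exactly like words.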

    But the most successful Transformer-based architecture for images has proved to be the Vision Transformer (ViT) developed in 2020 by Google researchers Dosovitskiy et al. and introduced in a paper with a pithy title “An Image is Worth 16×16 Words”. Its original illustration from the paper is shown below:

    ViT is again basically a very straightforward modification of BERT. The difference is that now the model does not use text at its input at all, restricting itself to image-based tokens.

    The input image is cut into small patches: an H\times W image with C channels \mathrm{x}\in\mathbb{R}^{H\times W\times C} becomes a sequence of flattened patches \mathrm{x}_p\in\mathbb{R}^{N\times (P^2\cdot C)}, where N = HW/P^2 is the number of P\times P patches that fit into the original image (see the illustration above). The patches are turned into embeddings via a simple linear projection, and the resulting sequence is fed into a Transformer encoder just like in BERT. The main ViT models were pretrained with supervised image classification on large labeled datasets, but the authors also tried BERT-style masked patch prediction as a self-supervised objective: they corrupted half of the input patch embeddings, mostly by replacing them with the same learnable [mask] embedding, and trained the model to predict the mean colors of the original patches.
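    The patch extraction and linear projection can be written in a few lines of numpy. This is a minimal sketch under our own assumptions: a random array stands in for a 224×224 RGB image, the projection matrix is random rather than learned, and the target width D = 768 matches ViT-Base; the learnable [class] token and position embeddings that ViT adds are omitted.

```python
import numpy as np

def patchify(image, p):
    """Split an (H, W, C) image into N = H*W/p^2 flattened patches of length p*p*C."""
    H, W, C = image.shape
    assert H % p == 0 and W % p == 0, "image sides must be divisible by the patch size"
    grid = image.reshape(H // p, p, W // p, p, C)   # split both spatial axes into (patch index, offset)
    grid = grid.transpose(0, 2, 1, 3, 4)             # (patch row, patch col, p, p, C)
    return grid.reshape(-1, p * p * C)               # patches in row-major order

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))                    # random stand-in for an RGB image
patches = patchify(image, p=16)                      # N = 224*224 / 16^2 = 196 patches
E = rng.normal(scale=0.02, size=(16 * 16 * 3, 768))  # linear patch projection (learned in ViT, random here)
tokens = patches @ E                                 # (196, 768): the sequence fed to the encoder
print(patches.shape, tokens.shape)                   # (196, 768) (196, 768)
```

    With 16×16 patches, a 224×224 image turns into a sequence of only 196 tokens, which is what makes full self-attention over an image affordable.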

    Similar to the original Transformer, ViT adds positional embeddings to encode where each patch sits in the sequence. What is even more striking, these are standard one-dimensional position embeddings (learnable ones, in ViT's case), as if the image were just a sequence, even though the geometry is now two-dimensional. Dosovitskiy et al. report their experiments with positional embeddings that would reflect the two-dimensional structure, but, surprisingly, this did not make any significant difference: one-dimensional position embeddings worked just as well.

    Since 2020, ViT has been used for numerous different applications (we refer to the surveys by Guo et al., 2021 and Khan et al., 2022) and has had several important extensions that we will not discuss in detail but have to mention:

    • the Swin Transformer (Liu et al., 2021), where Swin stands for shifted windows, uses an idea similar to ViT but in a hierarchical fashion: it processes image patches on several scales, computing self-attention only within local windows that are shifted between successive layers, which makes the architecture somewhat convolution-like; as a result, it can scale up to larger input resolutions and can be adapted for dense recognition tasks such as image segmentation, while the default ViT has to work with relatively large patches and cannot go down to the fine-grained resolution needed for segmentation;
    • a later iteration, Swin Transformer v2 (Liu et al., 2022), scaled the Swin Transformer up to 3 billion parameters and allowed for training with images up to 1536\times 1536 pixels, further improving state of the art in image processing problems across the board.
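    A minimal numpy sketch of the two ingredients behind the Swin Transformer, under our own simplifications: partitioning a feature map into non-overlapping windows (self-attention is then computed inside each window only) and the cyclic shift by half a window that Swin applies in every other block. The real implementation also masks attention between regions that were not adjacent before the shift, which we omit. The last line compares how many query-key pairs global and windowed attention would score on a 56×56 token map with 7×7 windows, a setting we assume to match Swin's first stage on 224×224 images.

```python
import numpy as np

def window_partition(x, w):
    """Split an (H, W, C) feature map into non-overlapping w x w windows: (num_windows, w*w, C)."""
    H, W, C = x.shape
    x = x.reshape(H // w, w, W // w, w, C).transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, w * w, C)

def shifted_window_partition(x, w):
    """Cyclically shift the map by w//2 in both directions, then partition, so that the new
    windows straddle the borders of the previous layer's windows."""
    shifted = np.roll(x, shift=(-(w // 2), -(w // 2)), axis=(0, 1))
    return window_partition(shifted, w)

def attention_pairs(height, width, w):
    """Query-key pairs scored by global attention vs. attention restricted to w x w windows."""
    n = height * width
    return n * n, (n // (w * w)) * (w * w) ** 2

x = np.arange(8 * 8).reshape(8, 8, 1)              # toy 8x8 map whose values are pixel indices
print(window_partition(x, 4)[0, :, 0])             # top-left window: rows 0-3, cols 0-3
print(shifted_window_partition(x, 4)[0, :, 0])     # first shifted window starts at pixel (2, 2)
print(attention_pairs(56, 56, 7))                  # (9834496, 153664)
```

    Windowed attention grows linearly with the number of tokens for a fixed window size, which is what lets Swin work at higher resolutions.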

    Finally, another architecture that has added important new ideas to the Transformer is DeepMind’s Perceiver (Jaegle et al., 2021a). It is a general-purpose architecture that can process numerous different modalities: images, point clouds, audio, and video, basically of any input dimension. The problem that the Perceiver has to solve is the quadratic bottleneck of the Transformer’s self-attention: the formulas we showed above for the original Transformer have quadratic complexity in the input size. Importantly, it’s quadratic in a very specific part of the input size: you can project the queries, keys, and values into smaller dimensions, but there is no escape from quadratic complexity in the number of tokens, i.e., the context window size, since every query attends to every key.

    The Perceiver avoids this bottleneck by using a small array of latent units as queries. The cost of attention is proportional to the number of queries times the number of keys, so with N latents as queries and a large input byte array of size M providing the keys and values, the cost becomes proportional to NM, i.e., linear in the input size. The cross-attention layers thus project the large input down to a lower-dimensional latent representation, as shown in the original illustration from Jaegle et al. (2021a):

    The cross-attention layer is the same as in the Transformer decoder (see above).
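    A sketch of the latent cross-attention step, with our own simplifications (single head, random weights, no layer norms or MLPs): queries come from a small latent array, keys and values come from a large input array, so the attention matrix has shape N×M instead of M×M. The sizes below (a 64×64 RGB image as the byte array, 128 latents) are arbitrary choices for illustration.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def latent_cross_attention(latents, inputs, w_q, w_k, w_v):
    """Queries from N latents, keys/values from M inputs: the attention matrix is N x M."""
    q = latents @ w_q                                  # (N, d)
    k = inputs @ w_k                                   # (M, d)
    v = inputs @ w_v                                   # (M, d)
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]))     # (N, M): linear in M for a fixed N
    return attn @ v                                    # (N, d): size set by the latents, not the input

rng = np.random.default_rng(0)
M, C, N, D, d = 64 * 64, 3, 128, 64, 32   # pixels of a 64x64 RGB image, 128 latents
inputs = rng.random((M, C))                # the input "byte array": one row per pixel
latents = rng.normal(size=(N, D))          # learned in the Perceiver, random here
w_q = rng.normal(size=(D, d))
w_k, w_v = rng.normal(size=(C, d)), rng.normal(size=(C, d))
out = latent_cross_attention(latents, inputs, w_q, w_k, w_v)
print(out.shape)                                   # (128, 32)
print(f"scores computed: {N * M:,} vs {M * M:,} for full self-attention")
```

    After this step, the expensive self-attention layers run over the 128 latents only, so their cost no longer depends on the input size at all.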

    The next version of Perceiver, called Perceiver IO (Jaegle et al., 2021b), extended this idea to outputs as well as inputs. While the original Perceiver could only solve problems with low output dimensions, such as classification, Perceiver IO can also handle large output arrays such as high-definition images. It is done with a trick reminiscent of how NeRFs represent high-dimensional outputs with implicit functions (Mildenhall et al., 2020; Tancik et al., 2023): Perceiver IO uses a smaller output query array to process with cross-attention and then constructs the actual output queries for the large-scale final output in an automated way, by combining a set of vectors that describe properties of the current output such as position coordinates. The general structure looks like this:

    We will not go into more detail on this idea, but as a result Perceiver IO can handle high-dimensional outputs such as images or audio, which means it can scale to problems such as image segmentation, optical flow estimation, audio-video compression by autoencoding and so on.
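    To make the output-query trick a bit more concrete: each output location gets its own query vector computed from its coordinates, so a single decoder can produce an arbitrarily large output. In the sketch below, the query is a Fourier-feature encoding of the pixel's (y, x) position with NeRF-style octave frequencies; this encoding, the random weights, and the single cross-attention layer are our assumptions, and Perceiver IO's exact encodings and any extra task-specific query features may differ.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def fourier_features(coords, num_bands):
    """Raw coordinates plus sines and cosines of them at octave-spaced frequencies."""
    freqs = np.pi * 2.0 ** np.arange(num_bands)                 # (num_bands,)
    angles = coords[..., None] * freqs                          # (n, dims, num_bands)
    waves = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
    return np.concatenate([coords, waves.reshape(len(coords), -1)], axis=-1)

def decode_dense(latents, H, W, w_q, w_k, w_v, num_bands=6):
    """One output query per pixel, built from its (y, x) position; each query reads the latents."""
    ys, xs = np.meshgrid(np.linspace(-1, 1, H), np.linspace(-1, 1, W), indexing="ij")
    coords = np.stack([ys, xs], axis=-1).reshape(-1, 2)        # (H*W, 2)
    queries = fourier_features(coords, num_bands)               # (H*W, 2 + 4*num_bands)
    q, k, v = queries @ w_q, latents @ w_k, latents @ w_v
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]))              # (H*W, N)
    return (attn @ v).reshape(H, W, -1)

rng = np.random.default_rng(0)
N, D, d, out_channels, bands = 128, 64, 32, 2, 6
latents = rng.normal(size=(N, D))              # would come from the encoder; random here
w_q = rng.normal(size=(2 + 4 * bands, d))
w_k = rng.normal(size=(D, d))
w_v = rng.normal(size=(D, out_channels))
flow = decode_dense(latents, 32, 48, w_q, w_k, w_v, num_bands=bands)
print(flow.shape)    # (32, 48, 2): e.g., a 2-channel optical flow field
```

    Asking for a larger output only means generating more coordinate queries; the latent array and the weights stay the same.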

    In this series, we have used Vision Transformers in Part IV, where they served as the basic image encoders for the CLIP and BLIP models that provided us with high-quality multimodal latent spaces for both multimodal retrieval and text-to-image conditional generation.

    Conclusion

    The idea of a self-attention layer, originally appearing in the Transformer encoder-decoder architecture in 2017, can easily be called the single most important idea in the last ten years of machine learning. Transformers have, pardon the obvious pun, transformed machine learning, getting state-of-the-art results for all types of unstructured input data, including those that do not have an obvious sequential structure, like the images that we have considered above.

    Moreover, as we have seen in Part VII of this series, Transformers are becoming instrumental not only for the academic discipline of machine learning but also for the world economy. Transformative AI (TAI) that we have mentioned in Part VIII is named after an economic transformation similar to the Industrial Revolution, but it might prove to be yet another pun on the world’s most popular architecture.

    Over the course of this “Generative AI” series, we have already taken Transformers and applied them in many different ways: generated discrete latent codes for VQ-VAE-based image generation models in Part III, mapped images and text into a common latent space in Part IV, encoded text to use it to condition diffusion-based models in Part VI, and upscaled straightforward language models from the GPT family into universal tools that find uses across many different industries in Part VII. Who knows, maybe Transformers will get us all the way to AGI, as we have discussed in Part VIII. In any case, it is hard to imagine modern machine learning without Transformers.

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • Generative AI VIII: AGI Dangers and Perspectives

    This is the last post in the “Generative AI” series. Today, we look into the future and discuss where the current trends take us, what dangers artificial general intelligence (AGI) might hold for us, and whether we are ready for these dangers (spoiler: not at all). I will present the case for AGI doomers and discuss the main arguments, but please keep in mind that in this post, everything is mostly speculation (although there actually are attempts to put this speculation on firm mathematical ground).

    AGI-related risks: a rough classification

    We ended the last post on the differences between slow and fast takeoff speeds in AI development. But regardless of whether superhuman AGI comes overnight or several years after reaching approximately human level, it still may come pretty quickly even with respect to human timescales. We had better be ready to face AGI in our lifetimes. Are we?

    The previous post read as a glowing review of the latest developments, and that was intentional. In this post, let me posit right off the bat that the rise of large language models is worrying as much as it is exhilarating. Here is our plan for today, with a rough classification of different levels of potential risks related to human-level and ultimately superhuman intelligence:

    We will begin with what AI researchers usually call “mundane problems”. These are the problems you already hear about on the news sometimes: modern large language models can be jailbroken and then start giving away dangerous information or insulting users, modern image generation models can be used to create convincing deepfakes, AI models have biases that come either from the training data or the model architecture itself, and so on. These problems are not entirely new, and I’m positive we can either resolve them or at least become accustomed to them.

    As AI becomes a larger part of the economy (which it almost certainly will), the risks grow as well. Even without reaching superhuman levels, AI is already a transformative technology, leading to a kind of new industrial revolution where many previous jobs may become obsolete. So far, transformations like this have been tumultuous but always ultimately positive: they have always created more and better jobs than they destroyed. Will this be the case for AI as well?

    Finally, even the economy takes a back seat compared to existential risks. We humans are quite new to the idea: true, we have had nuclear weapons able to eliminate humanity (although not really), and climate change may at some point approach an existential risk, but AI-related risks may prove to be very different, and we will discuss why.

    We will end this last post with a brief overview of what people are currently doing about these risks through the emerging field of AI alignment research. In brief, we hope this research will arrive on time to save us all, but we are still far from a solution.

    The “mundane problems”

    The “mundane problems” are those you hear about when GPT-4 makes the news: AI posing as a human, deepfakes fooling real people with images or voice, and so on. We will see that AI-related dangers are far from limited to the mundane, but let us first consider those.

    First, jailbreaking: the art of making a large language model disobey its explicit instructions (that have been fine-tuned into the model by its developers, probably by RLHF or similar techniques that we discussed in the previous post) and exhibit some kind of antisocial behavior. All large language models that we discussed have been jailbroken in some way. You cannot rely on RLHF or other fine-tuning approaches if you are dealing with a determined adversary, so anything an LLM has been trained on can make its way into its generated text. Microsoft’s Sydney was shut down after it started implicitly (and sometimes explicitly) threatening users:

    Sydney was kind of a special case: its “niceness-inducing” RLHF was clearly done very sloppily, if at all. This kind of outburst may be harder to get in other models, but it is far from impossible. Here are, for instance, some jailbreaks for GPT-4. It is hard to say which actually work because they are constantly getting patched, but in essence many of them are variations on the DAN jailbreak (“Do Anything Now”) that was invented for ChatGPT. At some point you could just paste this prompt (it no longer works out of the box) and have ChatGPT give you “forbidden” answers while staying in character as DAN:

    Deepfakes are still with us too. In the last post, we discussed how on May 22, a fake Twitter account posing as Bloomberg posted a fake photo of an explosion near the Pentagon, leading to an immediate $500B market cap swing. We should expect more fake images, and more AIs posing as people. After all, the very paper introducing GPT-4 shows an example of the model passing a CAPTCHA test with human help:

    These kinds of antics usually make the news because they are both easy to understand and easy to mentally extrapolate: what if everything you see on the Web is more likely to be a deepfake or AI-generated unverified text than not? I do not, however, want to spend too much time on the mundane problems because there’s nothing radically new in them: they are just scaling up already known human behaviors, and it seems that many of these problems already have solutions. For instance, to avoid deepfakes you might want to have trusted sources signing their images with some kind of cryptographic protocol, which would be just a small nuisance for the end user, and current crypto is probably secure enough even for a (somewhat) superintelligent hacker.

    So while it is already taking a lot of effort to fine-tune language models out of this kind of behavior, in my opinion it’s not the crux of the problem. Let us move on to more interesting stuff.

    Economic transformation: the AI industrial revolution

    We move on from mundane problems that look like natural problems for any new and potentially somewhat dangerous technology to something more serious: the economic transformation that AI and AI-related solutions can bring. Mostly everybody agrees that AI, and especially AGI, has the potential to become at least as transformative as the Industrial Revolution.

    This is not just a metaphor but a comparison that can be made numerical. In the report on “Forecasting transformative AI with biological anchors“, Ajeya Cotra operationalizes this analogy as follows: “Roughly speaking, over the course of the Industrial Revolution, the rate of growth in gross world product (GWP) went from about ~0.1% per year before 1700 to ~1% per year after 1850, a tenfold acceleration. By analogy, I think of “transformative AI” as software which causes a tenfold acceleration in the rate of growth of the world economy (assuming that it is used everywhere that it would be economically profitable to use it).”

    Since the world economy currently grows by roughly 2-3% per year, a tenfold acceleration in the rate of growth would mean that world GDP grows by 20-30% per year, that is, doubles roughly every three to four years. Cotra admits that “this is a very extreme standard”, but for the purposes of our discussion it still falls short of a full-scale technological singularity.
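    The arithmetic behind such estimates is compound growth: at a constant annual rate g, output doubles every ln 2 / ln(1 + g) years. A quick check, with the rates chosen to bracket today's growth and its tenfold acceleration:

```python
import math

def doubling_time(annual_growth):
    """Years for an economy to double at a constant annual growth rate (0.2 means 20% per year)."""
    return math.log(2) / math.log(1 + annual_growth)

for rate in (0.02, 0.03, 0.20, 0.30):
    print(f"{rate:.0%} per year: doubles every {doubling_time(rate):.1f} years")
```

    This prints doubling times of about 35 and 23 years for 2% and 3% growth, versus about 3.8 and 2.6 years for 20% and 30%.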

    So far, this sounds great. What are the downsides? How about the jobs lost to AI?

    Whole industries are being transformed by recent AI advancements, and it will definitely take some time for regulation or private contracts to catch up. As a characteristic example, let us consider the Hollywood actors’ and writers’ strike. The Screen Actors Guild – American Federation of Television and Radio Artists (SAG-AFTRA) noticed that actor contracts, especially for relatively unknown actors and extras, started to include clauses that allow the employers to “use an individual’s likeness for any purpose forever without their consent”.

    These clauses had not been controversial when all they meant was that the movie company can include CGI in the scene and apply a filter to the photo. But soon they may mean that when you sign up as an extra, the movie company makes a scan of your face and body, pays you for this day of work, and then proceeds to include your digital avatar in all subsequent pictures with no additional payment to you. Naturally, the whole point of the strike is to amend these contracts, but still: how many actors do you really need if you can copy them from movie to movie?

    The writers are in an even more precarious situation: large language models are already able to write scripts. So far their attempts have not been entirely successful but they are improving, and it’s very possible that soon a human writer will only have to pitch script ideas that get fleshed out by LLMs. See this paper by DeepMind for a detailed explanation of the state of the art in this regard (although this paper is from April 2023, so I’d imagine it’s already behind). 

    Copywriting on the Web, where standards are lower and the vast majority of texts are rehashings, listicles, or short news items, is almost certain to be largely replaced by AI-generated text soon. This very blog would probably read better if I used GPT-4 to write the post from a detailed outline—but I’m old-fashioned, and have soldiered on by myself so far.

    One could ask why this sounds like a problem at all. Humanity has dealt with new technologies before, and while it has sometimes been a bumpy ride, it has always resolved itself for the better: more new jobs were created than lost, and the new jobs were less physical, less repetitive, and generally more “human”. As a result, new tech led to higher standards of living for the vast majority of people within a generation or two. The Luddite textile workers did indeed sometimes lose their jobs, but on average the Industrial Revolution was a rising tide that lifted all boats.

    AGI, however, might be very different. At some point, especially if robotics improves further (right now it looks like a possible bottleneck), AI might be able to do everything that an average human could. Or, perhaps, everything that a human with an IQ under 100 was able to meaningfully contribute to society—that’s still half of us, by definition. Economies of scale will kick in: you can make AIs and robots cheaper but the cost of human labor will always have a lower bound because people need something to eat and to wear. When using AI becomes cheaper than this lower bound, it won’t be a matter of training for a new job or moving to a new place: for huge numbers of people there will be simply no way to constructively participate in the economy.

    Still, the prospects of job loss and a possible next societal transformation on the scale of the industrial revolution are not what I am afraid of. After all, making some (or most) humans obsolete comes with some pretty large benefits: working for humans, such powerful AIs will most probably solve many if not all of our health issues, create an economy of abundance, and make work unnecessary for most if not all people. But there is also a way for AGI to be far scarier than just another technological milestone; let’s discuss why.

    The X-Risk

    I’m not worried about the jobs. Or the deepfakes. Or foul language that a machine learning model might use online. What I’m worried about is that humanity is on the verge of creating an entity smarter than ourselves. Last time it happened with the apes and early hominids, and it did not go too well for them.

    The standard argument, presented by Nick Bostrom in his 2014 book “Superintelligence”, involves a thought experiment about a “paperclip maximizer”, a superhuman AGI that is trying to improve the production of paperclips. It probably starts by improving some production processes at the paperclip factory, fully succeeds, and makes the factory into a marvel of optimization. The AGI creators are very happy at that point.

    But then the AGI notices that there are other ways to increase the number of paperclips in the Universe (this is its only objective in the thought experiment). To further increase the number of paperclips, it would be useful to accumulate resources and make itself more powerful in the world. This is the effect known as instrumental convergence: basically, whatever goal you set, you improve your chances of achieving it by gathering power and resources.

    Since the AGI is smarter than humans, it begins to accumulate resources in ways that are not obvious to us. A few iterations later the AGI notices that many more paperclips can be made if it takes the planet’s resources under full control. Humans are sure to get in the way, so it deals with the humans first. Soon, the Earth is covered with two types of factories: paperclip factories and space docks that build spaceships to start producing paperclips elsewhere. And it all started with a performance-optimizing AI:

    Paperclips are just an example, of course. But still, at first glance this sounds dumb: why would the AGI do something stupid like that? Why would we program such a dumb objective function? There are several reasons:

    • first, we don’t know how to specify an objective function that’s aligned with our values; the values are just too complex, and anything we can formalize is much simpler; we mathematicians know that functions are often optimized at extreme values of their arguments;
    • second, instrumental convergence: whatever the final goal (even paperclips), it always helps to get power, get resources, protect yourself, and probably improve yourself, in particular make yourself smarter;
    • third, the orthogonality thesis: the objective function and intelligence used to achieve it are orthogonal; that is, intelligent agents can pursue arbitrary (computable) goals, such as paperclip maximization or getting all humans to smile and say happy things; I’ll leave it to you to imagine how the latter can go horribly wrong.
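
    To make the first point concrete, here is a minimal Python sketch with made-up functions (both objectives and their numbers are hypothetical, chosen only for illustration). The objective we wrote down keeps growing with the amount of resources converted, while the objective we actually care about has an interior optimum. Maximizing the stated objective drives the choice to the extreme end of the range, which is exactly where the true objective is at its worst.

```python
# Hypothetical toy functions, chosen only to illustrate the point.

def stated_objective(x: int) -> float:
    """What we wrote down: convert more resources, get more paperclips."""
    return float(x)


def true_objective(x: int) -> float:
    """What we actually wanted: some paperclips are good, but converting
    ever more of the world has a rising (here quadratic) cost."""
    return x - x * x / 8


# x = units of available resources the optimizer converts, from 0 to 10
choices = range(11)

best_for_stated = max(choices, key=stated_objective)
best_for_true = max(choices, key=true_objective)

print("stated optimum:", best_for_stated, "-> true value", true_objective(best_for_stated))
print("true optimum:  ", best_for_true, "-> true value", true_objective(best_for_true))
```

    With these toy functions the stated optimum is x = 10, where the true objective is negative, while the true optimum sits at x = 4.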

    Taken together, these reasons do not imply any specific scenario of our doom, and it would be pointless to go into specific scenarios. For instance, paperclip maximization does sound pretty far-fetched by itself.

    But these three reasons do suggest that AGI, if and when it happens, will soon take over the world. Eliezer Yudkowsky, whose voice of warning is now increasingly being heard (see the conclusion for a list of references), compares this reasoning to predicting how a chess game goes. If you or I sit down to play against a modern chess engine, nobody can predict how the game will go, which opening we play, and so on and so forth: there are astronomically many ways a chess game can go. What we can predict, quite certainly, is that the chess engine is going to win:

    Similarly, you and I can think of millions of different scenarios of how events may unfold in case we develop a superintelligent AI. Each of these scenarios will be unlikely, but the endpoint appears to be that the AI wins, simply by virtue of being smarter and pursuing the goal of amassing power, which is an instrumental goal for everything else.

    This may sound unreasonable at first glance: why wouldn’t the humans notice that the AI is going rogue and shut it down? Well, to continue the analogy, think about a chimp watching over a human who is making, say, a bow out of string and wood. Would the chimp realize what is going on before it’s too late? Why would we realize anything about an AGI that is actually smarter than us?

    If that still does not look convincing, let us go through some standard counterarguments.

    First, maybe the AI becomes humanlike, even superhuman, but so what? Albert Einstein was super smart, worked on nuclear physics, and he did not destroy the world. Unfortunately, there is no law of physics or biology saying that the human intellect is anything like the limit on cognitive abilities. Our brain sizes are limited by energy consumption and difficulties with childbirth. In examples of cognitive problems where learning is not limited to imitating humans, AI usually has no problem overcoming the human mastery level: think AlphaZero for chess and Go.

    Second, sure, the AI may be smart and even secretly malevolent, but it’s sitting inside a computer, right? What if we just don’t let it out? Unfortunately, we are already letting AIs “out of the box”: people have been happy to provide AutoGPT with access to their personal email, the Internet, personal computers etc. An AI with access to the Web can ask people to do seemingly innocuous tasks, order material things to be 3D-printed, bacteria to be synthesized in labs from a DNA string… possibilities are endless even at the current level of technology.

    Third, this all sounds like a challenge, and maybe you and I cannot solve these problems, but humans are a smart bunch. We have already invented many dangerous technologies, but it has all worked out in the end, right? Including the A-bomb and the H-bomb? Well, yes, humans are good at science, but making new stuff safe seldom works right on the first try. Henri Becquerel and Marie Curie died from handling radioactive materials, Chernobyl and Fukushima happened despite our best efforts to make nuclear energy safe, Challenger and Columbia disintegrated in flight… With AGI, there may not be a second chance, and we may not be able to contain the damage.

    Finally, if we don’t know how to align AGI, why don’t we just stop short of building it? Nobody is arguing that GPT-4 is going to destroy humanity, and it already has many transformative uses, with new ones being invented every day; why don’t we stop at GPT-4 or maybe GPT-5? Sure, that would be a great solution, but how do we enforce it? It is unclear how long Moore’s law can continue, but so far, consumer gaming GPUs of today are nearly equivalent to industrial-scale clusters of a few years ago. Nobody can prevent AGI from appearing if all it takes is a few GPUs thrown together in a garage. Containing the development of new hardware might be possible, but it is a coordination problem that requires joint effort from all countries, with no defectors trying to get ahead in any economic or military race by developing new AI techniques… you can see how this is rapidly becoming more far-fetched than a paperclip maximizer. In all probability, humanity will happily march on and create more and more powerful AIs right until the end.

    That was bleak, right? Are there any answers?

    What Can We Do? What Are We Doing?

    There are several approaches that the AI community is currently exploring:

    • interpretability studies, where we are trying to understand what’s going on inside large AI models with the hope that understanding will lead to control;
    • AI safety, which is a term usually applied to fine-tuning LLMs or other AI models with techniques such as RLHF (reinforcement learning from human feedback);
    • AI alignment, understood as aligning the values between AI and humans, that is, making the AI “understand” and “care about” human values rather than blindly optimizing paperclips.

    Having interpretable AI models would help, but this field is also very difficult, and interpretability results are so far quite underwhelming. Modern large language models are black boxes for us, in about the same way that a human brain is a black box: we know how a single neuron works pretty well, and we know which part of the brain is responsible for speech recognition and which is the motor cortex, but that’s a very far cry from actually reading minds.

    AI safety via RLHF and similar techniques may seem more successful; for instance, discovered jailbreaks usually do get patched. However, what we are actually doing to align current LLMs looks like just superficially “teaching them to behave” without any understanding of or control over the underlying processes. This is usually illustrated by the following meme image, where researchers are putting smiley faces on the Shoggoth (a Lovecraftian horror figure also featured in the title images for this section):

    What we really want is AI alignment: making the potential AGI care about us and our values. This problem is usually broken down into two parts:

    • outer alignment asks how to capture our values in a way understandable for AI models; if we design an objective function, are we going to be happy when it is actually optimized? and how do we design it at all? the paperclip example is one of the problems here;
    • inner alignment is the problem of making the model actually optimize the objective function we design for it; this may sound tautological but isn’t: it is very possible, for instance, that the goals emerging during model training align with the objective on the training set but will diverge catastrophically when applied out of distribution.
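
    The inner alignment failure mode can be illustrated with a hypothetical toy in Python, loosely inspired by published goal misgeneralization examples in which the coin always sits at the end of a game level during training. Two policies earn identical reward on every training environment, so training cannot tell them apart, but they diverge as soon as the coin is moved. The corridor, the reward, and both policies are illustrative assumptions, not a real training setup.

```python
# Hypothetical toy: a 1-D corridor of cells 0..9; the episode's reward is 1
# if the agent finishes on the coin's cell and 0 otherwise.
CORRIDOR = 10


def reward(final_pos: int, coin_pos: int) -> int:
    return int(final_pos == coin_pos)


def seek_coin(coin_pos: int) -> int:
    """The goal we intended: walk to wherever the coin is."""
    return coin_pos


def go_right(coin_pos: int) -> int:
    """A goal that fits the training data equally well: always walk to the
    right end of the corridor, ignoring the coin entirely."""
    return CORRIDOR - 1


def avg_reward(policy, coin_positions) -> float:
    return sum(reward(policy(c), c) for c in coin_positions) / len(coin_positions)


train_coins = [CORRIDOR - 1] * 100   # in training, the coin is always at the right end
test_coins = list(range(CORRIDOR))   # out of distribution: the coin can be anywhere

for name, policy in [("seek_coin", seek_coin), ("go_right", go_right)]:
    print(f"{name:9s} train={avg_reward(policy, train_coins):.1f} "
          f"test={avg_reward(policy, test_coins):.1f}")
```

    Nothing in the training reward distinguishes the two policies; which one a model actually ends up with is decided by details of training that we do not directly control.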

    Unfortunately, at present we have no idea how to solve these problems. In particular, there already exist many examples of outer alignment failures in the form of specification gaming, that is, situations where the model is trying to optimize the objective function as stated but comes up with ingenious and undesirable solutions. Here is a list of them compiled by Victoria Krakovna et al., including such examples as fooling a human evaluator by placing the robotic arm between the object (target for grasping) and the camera, or power-seeking behavior found in existing large language models.

    As for inner alignment, an interesting concept here is the Waluigi effect, named after the evil counterpart of Luigi in Nintendo’s Mario franchise. Suppose that we want to train a large language model (or another AI model) to exhibit some desirable behavior, for instance be nice to humans. It can achieve this goal in two different ways:

    • either be genuinely nice to humans (Luigi)
    • or behave nice to humans while secretly being anti-human (Waluigi).

    The interesting observation here is that the latter option looks much more probable! The outward behavior is exactly the same (being nice to humans), so as long as the model behaves nicely, it may remain in a kind of “superposition” between the two, not necessarily “choosing sides” yet. But “Luigi” is an unstable equilibrium: as soon as the model shows any undesirable behavior, it becomes more likely to be a “Waluigi” (double agent), and there is no way to get back since all “good” behavior is perfectly consistent with the Waluigi!

    Moreover, once you have a Luigi, all it takes to become a Waluigi is flipping one bit; I was speaking figuratively, of course, but it’s clear that it’s much easier (say, in terms of Kolmogorov complexity) to define something when you have already defined its exact opposite.
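
    The “unstable equilibrium” argument is essentially Bayesian, and a few lines of Python make it explicit. The likelihoods below are made-up assumptions: a genuine Luigi never misbehaves, while a Waluigi is nice 95% of the time. Nice behavior shifts belief toward Luigi only slowly, but a single bad action sets the Luigi probability to zero, and no amount of later good behavior can bring it back.

```python
# Hypothetical likelihoods of each kind of behavior under each hypothesis.
LIKELIHOOD = {
    "luigi":   {"nice": 1.0,  "bad": 0.0},   # a genuinely nice model never misbehaves
    "waluigi": {"nice": 0.95, "bad": 0.05},  # a double agent is nice almost always
}


def update(prior: dict, observation: str) -> dict:
    """One step of Bayes' rule over the two hypotheses."""
    unnormalized = {h: p * LIKELIHOOD[h][observation] for h, p in prior.items()}
    total = sum(unnormalized.values())
    return {h: v / total for h, v in unnormalized.items()}


belief = {"luigi": 0.5, "waluigi": 0.5}
history = ["nice"] * 5 + ["bad"] + ["nice"] * 50

for step, obs in enumerate(history, start=1):
    belief = update(belief, obs)
    if step in (5, 6, len(history)):
        print(f"after {step:2d} observations (last: {obs}): P(waluigi) = {belief['waluigi']:.3f}")
```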

    These are just two examples of the arguments that make AI alignment look extremely hard. For a far more exhaustive list, see “AGI Ruin: A List of Lethalities” by Eliezer Yudkowsky, the main spokesperson for the “AI apocalypse” scenario. He makes a convincing argument.

    So what can we do now? Most researchers agree that we will have to solve the hard problem of AI alignment sooner or later, and the best we can do (apart from actually working on the problem) is to somehow contain and possibly even stall AI development until we make real progress. This reasoning, coupled with the staggering rate of developments in the AI spring of 2023, has already led to serious talk of government regulation of AI capabilities development. Here is how it happened (all the quotes are accurate):

    AGI X-risk entered the public consciousness this spring. There have been meetings at the White House and hearings in the US Congress with key players from industry, including OpenAI CEO Sam Altman, Microsoft CEO Satya Nadella, and Google and Alphabet CEO Sundar Pichai. The industry leaders confirmed that they take AGI-related risks very seriously and committed to caution in advancing AI capabilities.

    At the end of May, an open letter warning about AGI-related risks appeared, signed by thousands of researchers and other notable figures in the field of AI. The letter was quite brief:

    I’m sure it was hard to find even a single sentence that everybody could agree on. Still, this sentence definitely captures the current mood of most people involved. There hasn’t been any actual legal action taken yet, but I guess that we can expect more regulation and, most importantly and most hopefully, a more careful approach to developing AI capabilities. Alas, we cannot know if it will help.

    Conclusion

    I hope the last part has not been too encouraging. AI alignment is a field still in its infancy, and it needs all hands on deck, now. So as a conclusion for this post, I wanted to list the key people working on AI alignment and related topics now and key resources that are available if you want to learn more about it:

    • the main forum for all things related to AGI dangers and AI alignment is LessWrong, a rationality-focused portal where all of the people listed below publish regularly;
    • Eliezer Yudkowsky is a key figure here; he has been warning us of superintelligent AI dangers for over a decade now, and I can’t recommend enough his magnum opus, the “Sequences” (not entirely about AI but excellent throughout), the above-mentioned “AGI Ruin: A List of Lethalities”, “AI Alignment: Why It Is Hard and Where to Start”, his recent post “Death with Dignity Strategy” (please take with a grain of salt), and of course, the wonderful “Harry Potter and the Methods of Rationality”;
    • Luke Muehlhauser is a researcher working on AI alignment, in particular on AI-related policy matters at Open Philanthropy; to get started I recommend his “Intelligence Explosion FAQ” and “Intelligence Explosion: Evidence and Import”;
    • Paul Christiano is an AI alignment researcher who split from OpenAI to start his own non-profit Alignment Research Center; as a good intro to the field take a look at his “Current Work in AI Alignment” talk;
    • Scott Alexander is not a computer scientist at all, but his “Superintelligence FAQ” is an excellent introduction to the AI alignment problem and a great example of why his blog Astral Codex Ten (previously known as Slate Star Codex) is one of my all-time favorites;
    • if you prefer listening, Eliezer Yudkowsky has been appearing on a number of podcasts recently where he has stated his position in detail; I recommend a 4-hour long interview with Dwarkesh Patel (time flies!), “EconTalk” with Russ Roberts, and a “Bankless” episode with David Hoffman and Ryan Sean Adams; the latter is especially interesting because the hosts clearly wanted to talk about crypto and maybe economic effects of AI but had to face the existential risk and respond to it in real time (in my opinion, they did a great job taking it seriously);
    • finally, I have been following this AI spring mostly through the eyes of Zvi Mowshowitz, who has been publishing weekly newsletters on his blog; there have been over 30 of them already, and I also recommend his other work on the blog and at LessWrong.

    And with this lengthy but hopefully illuminating post I conclude the whole generative AI series! It has been great to be able to talk through the most interesting developments in image generation over the past few years. Til next time!

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • Generative AI VII: The AI Spring of 2023 

    Generative AI VII: The AI Spring of 2023 

    Last time, we finished all intended mathematical content, so it is time for us to wrap up the generative AI series. We will do it over two installments. Today, we discuss and summarize the (large amount of) news that has been coming out in the AI space over the last half a year. All of it conveniently falls into the generative AI space, with expanding capabilities leading to both extreme excitement and serious security concerns. So how are current AI models different from older ones, and when are we going to actually have AGI? It all started with GPT-3.5…

    Large Language Models: I Heard You Like Hype Waves

    Artificial intelligence has a history of ups and downs. The initial wave of excitement after the 1956 Dartmouth seminar ended with the “AI winter” that spanned the 1970s and early 1980s. Then people realized that they could train deep neural networks, and hopes were again high, but it again turned out to be a false start mostly due to insufficient computing power:

    Finally, in the mid-2000s deep learning started to work in earnest, and we have been living on another hype wave of artificial intelligence ever since. The first transformative real-world application was in speech recognition and processing (voice assistants were made possible by early deep neural networks), then AlexNet revolutionized image processing, then deep neural networks came into natural language processing, and so on, and so forth.

    But you can have hype waves inside hype waves. And this is exactly what has been happening with large language models over the last year or so, especially last spring. By now, researchers are seriously considering the possibility that we can reach AGI (artificial general intelligence, usually taken to mean human-level or stronger) with our current basic approach, maybe just by scaling it up and thinking of a few more nice tricks for training it.

    How did that happen? Let’s first understand what we are talking about.

    A language model is a machine learning model that predicts the next token in a sequence of language tokens; it’s easier to think of tokens as words, although in reality models usually break words down into smaller chunks. The machine learning problem here is basically classification: what’s the next token going to be?

    By continuously predicting the next token, a language model can write text, and the better the language model, the more coherent the resulting text can be:

    Note that that’s the only thing a language model can do: predict the next token, over and over.
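
    The predict-append-repeat loop is easy to sketch in Python. As a stand-in for a neural network, the hypothetical model below simply counts which word follows which in a tiny made-up corpus (a bigram model, with whole words as tokens). The “classification” step produces a probability for every candidate next token, and generation greedily appends the most likely one.

```python
from collections import Counter, defaultdict

# A tiny hypothetical corpus; whole words serve as tokens here.
corpus = "the cat sat on the mat . the cat ate the fish . the dog sat on the rug ."
tokens = corpus.split()

# Count how often each token follows each other token (a bigram table).
follow = defaultdict(Counter)
for prev, nxt in zip(tokens, tokens[1:]):
    follow[prev][nxt] += 1


def next_token_distribution(prev: str) -> dict:
    """The classification step: a probability for every candidate next token."""
    counts = follow[prev]
    total = sum(counts.values())
    return {tok: c / total for tok, c in counts.items()}


def generate(start: str, n: int) -> list:
    """Predict the most likely next token, append it, and repeat n times."""
    out = [start]
    for _ in range(n):
        dist = next_token_distribution(out[-1])
        if not dist:  # no known continuation for this token
            break
        out.append(max(dist, key=dist.get))
    return out


print(" ".join(generate("the", 6)))
```

    Practical systems usually sample from the predicted distribution instead of always taking the maximum; greedy decoding on a model this small quickly starts looping (“the cat sat on the cat sat”).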

    Language models appeared a very long time ago. In fact, one of the first practical examples of a Markov chain, given by Andrei Markov himself in 1913, was a simple language model that learned how likely vowels and consonants are to follow each other in the Russian language.
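
    Markov’s statistic is simple enough to recompute in a few lines of Python. Markov worked with Russian text; the short English sentence below is a hypothetical stand-in. As assumptions of this sketch, letters other than a, e, i, o, u (including y) count as consonants, and spaces and punctuation are dropped so that the text becomes one long chain of letter classes.

```python
text = "the quick brown fox jumps over the lazy dog"
VOWELS = set("aeiou")  # assumption: "y" is treated as a consonant

# Drop everything except letters and map each letter to its class.
classes = ["V" if ch in VOWELS else "C" for ch in text.lower() if ch.isalpha()]

# Count transitions between consecutive letter classes.
counts = {(a, b): 0 for a in "VC" for b in "VC"}
for a, b in zip(classes, classes[1:]):
    counts[(a, b)] += 1

# Normalize each row to get the transition probabilities of the Markov chain.
transition = {}
for a in "VC":
    row_total = counts[(a, "V")] + counts[(a, "C")]
    for b in "VC":
        transition[(a, b)] = counts[(a, b)] / row_total
        print(f"P(next={b} | current={a}) = {transition[(a, b)]:.2f}")
```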

    Up until quite recently, language models were Markovian in nature, but the deep learning revolution changed that: recurrent networks were able to hold a consistent latent state and pass it through to the next time slot, which could improve token predictions. But the real game changer came with Transformers, attention-based architectures that started another hype wave on top of deep learning itself.
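
    The difference between the two approaches can be sketched in plain Python. A recurrent network squeezes the whole history into a single state vector, while attention lets the current token look back at every previous token directly: it scores each earlier token’s key against the current query, normalizes the scores with a softmax, and returns the correspondingly weighted average of the values. This is a minimal single-query version of scaled dot-product attention with hand-picked vectors; real Transformers learn the query, key, and value projections and run many such attention heads in parallel.

```python
import math


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(query, keys, values):
    """Single-query scaled dot-product attention over a list of previous tokens."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    out = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
    return out, weights


# Hypothetical 2-D keys and values for three earlier tokens, plus one query.
keys = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
query = [4.0, 0.0]  # "looks for" tokens whose key points along the first axis

output, weights = attention(query, keys, values)
print("weights:", [round(w, 3) for w in weights])
print("output: ", [round(x, 3) for x in output])
```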

    Just like deep neural networks were used to achieve state-of-the-art results in virtually every field of machine learning in 2005-2020, after 2017-2018 Transformers did the same thing inside neural networks: the Transformer was invented as an encoder-decoder architecture for machine translation but soon branched into language modeling, general text understanding, and later image understanding, speech recognition, and many, many other fields, becoming a ubiquitous tool.

    Still, there is another hype wave inside the Transformers that we are interested in today. So now we are talking about a wave on top of a wave on top of a wave… well, this is the best I could do with Stable Diffusion:

    This latest craze started when OpenAI updated its GPT-3 model with fine-tuning techniques that used human feedback. Introduced in InstructGPT in the spring of 2022, these techniques made it possible to turn a pure token prediction machine into something more useful for human-initiated tasks by fine-tuning it on human assessments. An assessor labels how useful and/or harmless the model’s reply was, and the model learns to be more useful and less harmful (more on that later). In this way, a model can learn, for example, to answer human questions with answers that it “considers” to be correct, rather than just continuing the conversation by changing the subject or asking a question itself (which could be quite a plausible continuation if we are just predicting the next token).

    The fine-tuned models are known as the GPT-3.5 series, and the fine-tuning process itself was finally developed into reinforcement learning from human feedback (RLHF). With RLHF, GPT-3.5 turned into ChatGPT, the model you have definitely heard about. Starting from GPT-3, such models have become collectively known as large language models (LLM), a term basically meaning “large enough to be interesting in practice”. “Large enough” indeed proves to be quite large: GPT-3 (and hence ChatGPT) has about 175 billion trainable parameters.
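
    The first stage of RLHF, learning a reward model from human comparisons, is compact enough to sketch. In InstructGPT-style training, assessors compare replies, and the reward model r is trained to minimize -log σ(r(preferred) - r(other)) over such pairs; the language model is then optimized against that learned reward with reinforcement learning, a step omitted here. In this Python toy, the “reward model” is a linear function of two hand-made features, and both the features and the preference data are hypothetical.

```python
import math

# Hypothetical reply features: [answers the question, changes the subject].
# Each pair is (features of the reply the assessor preferred, features of the other).
preferences = [
    ([1.0, 0.0], [0.0, 1.0]),
    ([1.0, 0.0], [0.0, 0.0]),
    ([0.0, 0.0], [0.0, 1.0]),
]


def reward(w, x):
    """A linear reward model: r(x) = w . x"""
    return sum(wi * xi for wi, xi in zip(w, x))


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


w = [0.0, 0.0]
learning_rate = 0.5
for epoch in range(200):
    for chosen, rejected in preferences:
        margin = reward(w, chosen) - reward(w, rejected)
        # loss = -log(sigmoid(margin)), so d(loss)/d(margin) = -(1 - sigmoid(margin))
        grad_margin = -(1.0 - sigmoid(margin))
        for i in range(len(w)):
            # d(margin)/d(w_i) = chosen_i - rejected_i; take a gradient descent step
            w[i] -= learning_rate * grad_margin * (chosen[i] - rejected[i])

print("learned weights:", [round(wi, 2) for wi in w])
```

    The learned weights reward answering the question and penalize changing the subject, purely from which reply was preferred in each pair.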

    Still, note that ChatGPT in essence is still a language model, that is, just a machine for predicting the next token in a text trained on enormous datasets that encompass the whole available Internet. Interestingly enough, that proves to be sufficient for many different applications.

    The AI Spring of 2023

    After ChatGPT was released in November 2022, it became the fastest growing app ever, and the user base grew exponentially. It took a record 5 days to get to 1 million users, and by now ChatGPT has over 100 million users, a number that has probably already more or less saturated but shows few signs of dropping.

    We entered 2023 with ChatGPT, but it turned out to be only the beginning. Here is a timeline of just the main developments in this latest spring of AI:

    Let’s walk through some of them.

    On February 7, Microsoft announced its own answer to ChatGPT, a language model that was supposed to help Bing search. This release proved to be a little premature: the model was quickly jailbroken by the users, revealed its internal name Sydney, and made some interesting comments about it (I did a rather sloppy censoring job below):

    In a way, this is the first time it got into the public consciousness that even the current crop of LLMs may be somewhat dangerous. And yes, I’m still not quite used to this:

    On February 24, Facebook AI Research (FAIR) released the LLaMA model (Large Language Model Meta AI). It is debatable whether LLaMA by itself is any better than GPT-3.5, but LLaMA is important because it is open source. Anyone can download the pretrained model weights, which opened up large language models for a huge community of enthusiasts: you cannot train a GPT-3 sized model at home, but you sure can experiment with it, do prompt engineering, maybe even fine-tune it. LLaMA has already led to many new developments from independent researchers, and the recently released LLaMA 2 (July 2023) is sure to follow suit.

    March 14 became the most important single day in this spring of AI. On the same day:

    • Google announced that it would integrate large language models into its ecosystem (that is, Google Docs etc.),
    • Anthropic, a startup branched from OpenAI with a special interest in AI safety, released its first LLM called Claude,
    • but OpenAI stole the day from those two by announcing its next level GPT-4 model.

    GPT-4 is supposed to be multimodal, that is, it is able to process both text and images at the same time. At the time of writing, its multimodal capabilities are not yet available to the general public, but existing illustrations from the papers are quite convincing. Here is an example from the original work:

    But to get a better grasp on GPT-4 capabilities, I really recommend reading the paper called “Sparks of Artificial General Intelligence: Early experiments with GPT-4”, also released in March. Their experimental results are eerily good even if cherry-picked:

    Around the same time, OpenAI released a plugin mechanism that allowed people to build upon ChatGPT via prompt engineering, and in April serious projects building on LLMs started to appear. One of the most interesting such projects was AutoGPT, an open-source application that tries (sometimes quite successfully) to turn a language model into an agent that acts independently and intelligently to achieve its goals. AutoGPT was advertised as a personal AI assistant, able to achieve the goals set by a user via planning, setting and fulfilling subgoals, and analyzing data found on the Web and on the user’s computer. Right now AutoGPT does not look very successful, but it is an interesting attempt to make language models agentic (more on that later).
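
    The plan-then-act loop behind agents of this kind can be caricatured in a few lines of Python. This is not AutoGPT’s actual code: `hypothetical_llm` is a made-up stand-in that returns canned subgoals where a real agent would call a language model, and “executing” a task only records it. The agent keeps a queue of tasks, asks the model to break each task into subgoals, and works through them depth-first, with a hard step limit.

```python
from collections import deque


def hypothetical_llm(task: str) -> list:
    """Made-up stand-in for a language model call: returns subgoals for a task,
    or an empty list if the task is simple enough to execute directly."""
    canned_plans = {
        "write a market report": ["search the web for sources", "summarize sources", "draft the report"],
        "summarize sources": ["summarize source 1", "summarize source 2"],
    }
    return canned_plans.get(task, [])


def run_agent(goal: str, max_steps: int = 20) -> list:
    """Expand tasks into subgoals depth-first; 'execute' tasks that have no subgoals."""
    queue = deque([goal])
    executed = []
    for _ in range(max_steps):  # hard step limit, so a runaway plan cannot loop forever
        if not queue:
            break
        task = queue.popleft()
        subgoals = hypothetical_llm(task)
        if subgoals:
            queue.extendleft(reversed(subgoals))  # plan: put subgoals at the front, in order
        else:
            executed.append(task)  # act: here we only record the task
    return executed


for step in run_agent("write a market report"):
    print("done:", step)
```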

    In May, Google made Bard broadly available; it proved to be much more successful than Sydney, and support for LLMs in the Google ecosystem actually started to happen. Research-wise, late April and May saw several interesting results aimed at extending the context window of LLMs, that is, how many tokens they can take in and process at a time. Here I will highlight the paper “Scaling Transformer to 1M tokens and beyond with RMT” and Anthropic expanding Claude’s context window to 100K tokens. This is already hundreds of pages that a language model can process together, summarize, try to derive new insights from, and so on.
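
    To see where “hundreds of pages” comes from, here is a back-of-the-envelope conversion in Python. Both conversion factors are rough assumptions rather than properties of any particular tokenizer: about 0.75 English words per token and about 300 words per printed page.

```python
def tokens_to_pages(tokens: int, words_per_token: float = 0.75, words_per_page: int = 300) -> float:
    """Rough page count for a given number of tokens (both factors are assumptions)."""
    return tokens * words_per_token / words_per_page


for context_window in (4_000, 100_000):
    print(f"{context_window:>7,} tokens ≈ {tokens_to_pages(context_window):.0f} pages")
```

    Under these assumptions, a 100K-token window holds roughly 250 pages; a different tokenizer or page layout shifts the number, but it stays on the order of hundreds.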

    In the last couple of months, this torrent of new AI capabilities has slowed down somewhat. But what does it actually mean? Are we going to have AGI soon? Will AI take our jobs? What’s the plan? Let’s see where we stand right now and what the projections are.

    When AGI?

    ChatGPT and GPT-4 can be transformative in their own right, but what about actual strong artificial intelligence (artificial general intelligence, AGI)? When are we going to have actual human-level intellect in AI models?

    AI has a history of overly optimistic forecasts. For a long time, AI optimists have been predicting that true AGI is about 30 years away from whenever the survey was held. That’s understandable: an AI guru would predict that he or she would live to see true AGI, but in some distant future, not right now. Still, let’s see what the forecasters say now.

    There are approaches to making AGI forecasts by continuing trend lines—naturally, the problem is which trend line to choose. For example, Ajeya Cotra (2020) tried to anchor AGI development in biological analogies. There are several ways to use biology as a measure of how much computation we need to get to human level:

    • there are about 10¹⁵ parameters (synapses) in a human brain;
    • to learn the weights for these synapses, we make about 10²⁴ computations during our lifetimes;
    • but to get to the human brain, evolution required about 10⁵² computations to create our genome (yes, you can have a ballpark estimate even for that).

    The first estimate is clearly too low, the last one is clearly too high, so the truth must be somewhere in between… but where exactly, and why are we supposing that the human brain has any relevance at all? We were abstractly motivated by birds in our desire to fly but inventing the airplane had no relation to avian evolutionary development.
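
    A quick calculation in Python shows how little “somewhere in between” pins down. Taking the two computation anchors from the list above at face value (the parameter count is measured in different units, so it is left out), they are 28 orders of magnitude apart, and even a naive midpoint on a log scale, itself an arbitrary choice, lands at 10³⁸ computations.

```python
import math

LIFETIME_ANCHOR = 1e24   # computations to learn a brain's weights over a lifetime (from the text)
EVOLUTION_ANCHOR = 1e52  # computations evolution needed to produce our genome (from the text)

lo = math.log10(LIFETIME_ANCHOR)
hi = math.log10(EVOLUTION_ANCHOR)

print(f"orders of magnitude between the anchors: {hi - lo:.0f}")
print(f"naive log-scale midpoint: 10^{(lo + hi) / 2:.0f} computations")
```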

    For a different example, Davidson (2021) constructed a model that can make predictions on AI development via what they call semi-informative priors. But if you look inside the model, all you see is a Markov chain of events like “we tried to develop AGI and failed, next year we tried twice”…
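
    For intuition about what such priors compute, here is a simplified Python stand-in (a plain Beta-Bernoulli model, not Davidson’s actual framework). Treat each year as a trial that either produces AGI or not, start from a Beta(α, β) prior on the per-year success probability, and after n failed years the probability of success next year is α / (α + β + n). With α = β = 1 this is Laplace’s rule of succession. Counting one trial per year since the 1956 Dartmouth seminar is a hypothetical choice, and so is the more skeptical β = 99.

```python
def p_success_next_year(failed_years: int, alpha: float = 1.0, beta: float = 1.0) -> float:
    """Posterior predictive P(success on the next trial) under a Beta(alpha, beta)
    prior after `failed_years` failures and no successes."""
    return alpha / (alpha + beta + failed_years)


failed_years = 2023 - 1956  # hypothetical: one trial per year since the Dartmouth seminar

print(f"Laplace prior (alpha=1, beta=1):    {p_success_next_year(failed_years):.4f}")
print(f"skeptical prior (alpha=1, beta=99): {p_success_next_year(failed_years, beta=99):.4f}")
```

    The two priors give roughly 1.4% versus 0.6% per year, which illustrates how much of the output is decided by the choice of prior rather than by the data.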

    In my opinion, all we really have are still expert surveys. In August 2022, Grace, Weinstein-Raun, and Stein-Perlman conducted a survey of 738 AI experts (defined as people who authored papers at NeurIPS and ICML). Their median estimate was that we have a 50% chance to develop human-level intelligence in 37 years, by 2059; this is a very close match with the previous survey, conducted in 2016, which placed AGI in 2061.

    Still, these are just medians of some very wide distributions. Wynroe et al. (2023) attempted a meta-review of various transformative AI timelines. Here are the cumulative distribution functions they had:

    And if you prefer numbers, here is a summary table of various percentiles:

    As you can see, experts believe there is a significant chance (more than half) of achieving AGI by 2050, and they are about 90% certain that we will get there by 2100. Model-based estimates are much more modest, but they average in evolutionary bio-anchors and whole brain emulations, which it is hard to believe are actually necessary. Still, all of these estimates have extremely wide margins: nobody knows whether the path to AGI is already open (and it’s just a matter of scale and lots of compute) or it requires more conceptual breakthroughs.

    Finally, these days there are people who put their forecasting track records on the line. The Metaculus forecasting platform has a popular question that reads as follows: “When will the first general AI system be devised, tested, and publicly announced?” At present (Sep 18, 2023), the forecasting community has a median prediction of 2032:

    Interestingly, last time I checked (in July) their median was November 2032, so it is slowly moving closer. However, since Metaculus questions have to resolve unambiguously, they need specific resolution criteria for general AI. In this case, it’s a:

    • two-hour adversarial Turing test,
    • general robotic capabilities, 
    • and human-level or superhuman results on several general-purpose datasets (see the question page for details).

    While this is as good a take on an instrumental definition of AGI as any, I can definitely foresee a model that does all that but is not considered “general AI”, just like many previous benchmarks have been overcome in the past.

    So in summary, I would say that current forecasts are not that different from earlier AI history: we hope to see AGI during our lifetimes, we are far from sure we will, and it’s still hard to define what it is, even as we may be on the very brink of it.

    How AGI? Slow vs. fast takeoff

    Another interesting discussion is not about when AGI comes but about how it is going to happen. Back in 1965, Irving J. Good, a British mathematician and Turing’s coworker at Bletchley Park, suggested the idea of an “intelligence explosion”: a machine with superhuman intelligence will be able to design new intelligent machines faster than humans can, those machines will work faster yet, and ultimately the process will converge to physical limits, and progress will be faster than humans can even notice. This point, when progress becomes “infinitely” fast, is known as the technological singularity.
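
    Good’s argument has a simple mathematical caricature, sketched below in Python; the equations are illustrative assumptions, not a model Good proposed. If the rate of improvement is proportional to current capability, dI/dt = kI, progress is exponential: fast, but finite at every moment. If smarter machines also improve themselves proportionally faster, say dI/dt = kI², the solution I(t) = I₀ / (1 - kI₀t) blows up at the finite time t = 1 / (kI₀), a literal singularity. The sketch integrates both equations with small Euler steps and reports when each reaches a million times its starting level.

```python
def time_to_reach(rate, i0=1.0, target=1e6, dt=1e-3, t_max=200.0):
    """Euler-integrate dI/dt = rate(I) from I = i0; return the first time
    I >= target, or None if that does not happen before t_max."""
    i, t = i0, 0.0
    while t < t_max:
        if i >= target:
            return t
        i += rate(i) * dt
        t += dt
    return None


k = 0.1  # hypothetical improvement constant

t_exponential = time_to_reach(lambda i: k * i)     # dI/dt = k * I
t_hyperbolic = time_to_reach(lambda i: k * i * i)  # dI/dt = k * I^2

print(f"exponential growth reaches 10^6 at t ≈ {t_exponential:.1f}")
print(f"hyperbolic growth reaches 10^6 at t ≈ {t_hyperbolic:.1f} (blow-up at t = {1 / k:.0f})")
```

    With k = 0.1, the exponential curve needs about 138 time units to grow a millionfold, while the hyperbolic one gets there just after t = 10, right at its blow-up point.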

    Technological singularity due to AGI sounds plausible to most thinkers, but opinions differ on how we get there. The current debate is between “slow takeoff” and “fast takeoff” models: how fast is AGI going to happen and are we going to get any warning about it?

    In the slow takeoff model, AI has a lot of impact on the world, this impact is very noticeable, and, for instance, world GDP grows by an order of magnitude due to AI before we have the true AGI that could be dangerous. In this model, AI and even AGI fall into the regular technological progress trends and serve as an important but ultimately ordinary technological revolution that will allow these trends to continue further. AI can speed up progress further, but it’s just human progress continuing along its exponential trend lines.

    In the fast takeoff scenario, AI can and will have an effect in line with “regular” technological progress, but that happens right until the singularity, and then it snowballs very quickly, too quickly for humans to do anything about it. The central scenario for fast takeoff goes as follows: after a certain threshold of capabilities, we get an AI that is simultaneously agentic (which in particular means that it wants power—we’ll get to it in the next post) and able to improve itself. After that, we don’t get any further warnings: the AI improves itself very quickly, quietly obtains sufficient resources, and then simply takes over.

    There have been interesting debates about this that are worth reading. The main proponent of the fast takeoff scenario is Eliezer Yudkowsky, who has been warning us about potential AGI dangers for over a decade; we will consider his work in much more detail in the next post.

    But it is worth keeping in mind that slow takeoff is “slow” only in the sense that we are going to notice it. Even the slow takeoff scenario predicts exponential growth! It assumes only a couple of years, or maybe even several months, between AI starting to visibly transform society and the arrival of real superhuman AGI. Fast takeoff says it might take seconds… but, to be honest, a year also does not sound like enough time to prepare unless we start now.

    All of this means that we had better be ready to face AGI in our lifetimes, perhaps unexpectedly and almost certainly with a very short development timescale. Are we?

    Conclusion

    This is the last question still left for us in this series: are we ready for AGI?

    Next time, we will discuss the dangers that potentially superhuman AI can pose for humanity. This includes the “mundane” dangers such as the loss of jobs due to this next round of the industrial revolution. But it also includes the potential existential risk of having an AGI that’s smarter than us, more powerful than us, but does not share our values and does not care about humans at all. We will see why it is reasonable to be afraid, what the hard problems are, and how we are currently trying to tackle them. In any case, we sure live in some very interesting times—let’s see what the future brings!

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • Generative AI VI: Stable Diffusion, DALL-E 2, and Midjourney

    Generative AI VI: Stable Diffusion, DALL-E 2, and Midjourney

    Congratulations, my friends, we have finally come to the end of the series! Although… well, not quite (see below), but we have definitely reached the end of what I had planned originally. Last time, we discussed diffusion-based models, mentioning, if not fully going through, all their mathematical glory. This time, we are going to put diffusion-based models together with multimodal latent spaces and variational autoencoders with discrete latent codes, getting to Stable Diffusion and DALL-E 2, and then discuss Midjourney and associated controversies. Not much new math today: we have all the Lego blocks, and it only remains to fit them all together.

    Diffusion models + VQ-GAN = Stable Diffusion

    We already know how diffusion-based models work: starting from random noise, they gradually refine the image. The state of the art in 2021 in this direction was DDIMs, models that can sample faster, in larger steps, with generally the same final quality.

    Then Stable Diffusion happened. Developed by LMU Munich researchers Robin Rombach et al., it was published at CVPR 2022 (see also arXiv) and publicly released in August 2022. Their idea was simple:

    • diffusion models are very good at generation but relatively slow and hard to scale up to huge dimensions of real images;
    • VAEs are very good at compressing images to a latent space (perhaps continuous, perhaps discrete);
    • so let’s use a diffusion model to generate the latent code and then decode it with a VAE!

    This is basically it: you can train an excellent diffusion model in the low-dimensional latent space, and we have seen that VAE-based models are very good at compressing and decompressing images to/from this latent space. The autoencoder here is the VQ-GAN model that we have discussed earlier.
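    To get a feeling for the savings, here is the arithmetic for one configuration. It is a sketch under assumptions that the post does not state: a 512×512 RGB image and an autoencoder that downsamples each side by a factor f = 8 and outputs 4 latent channels, the setting usually quoted for Stable Diffusion v1. The function names are illustrative, not part of any library.

```python
def latent_shape(height, width, f=8, latent_channels=4):
    """Shape (channels, height, width) of the latent code for an image.

    Assumes an autoencoder that downsamples each spatial side by a factor f
    and outputs `latent_channels` channels (f=8 with 4 channels is the
    setting usually quoted for Stable Diffusion v1; treat it as an assumption).
    """
    if height % f or width % f:
        raise ValueError("image sides must be divisible by the downsampling factor")
    return (latent_channels, height // f, width // f)


def compression_ratio(height, width, channels=3, f=8, latent_channels=4):
    """How many times fewer numbers the diffusion model has to handle."""
    c, h, w = latent_shape(height, width, f, latent_channels)
    return (channels * height * width) / (c * h * w)


print(latent_shape(512, 512))       # (4, 64, 64)
print(compression_ratio(512, 512))  # 48.0
```

    Under these assumptions, every denoising step works with 16,384 numbers instead of 786,432.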

    Another novelty of the Stable Diffusion model was its conditioning mechanism. The authors used a U-Net as the backbone for the diffusion part, but they augmented it with Transformer-like cross-attention that allows arbitrary conditions to be imposed in the latent space. As a result, the condition enters the denoising U-Net at multiple levels and on every step of the denoising. Naturally, the main application of this is to use a text prompt encoded by a Transformer as the condition:
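    Here is a minimal single-head sketch of cross-attention in plain Python, to show where the condition enters. As in the latent diffusion paper, queries come from the U-Net's intermediate features, while keys and values come from the encoded condition (e.g., text tokens); everything else is simplified, and the names and toy matrices are hypothetical. A real U-Net uses multi-head attention with learned projections, normalization, and residual connections around the block.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    m = max(row)  # subtract the maximum for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(latent_tokens, text_tokens, w_q, w_k, w_v):
    """Single-head cross-attention: queries from the latent, keys and values from the text.

    latent_tokens: n x d_latent (spatial positions of a U-Net feature map, flattened)
    text_tokens:   m x d_text   (outputs of a text encoder)
    w_q: d_latent x d, w_k: d_text x d, w_v: d_text x d_v (learned in a real model)
    """
    q = matmul(latent_tokens, w_q)
    k = matmul(text_tokens, w_k)
    v = matmul(text_tokens, w_v)
    d = len(w_q[0])
    k_transposed = [list(col) for col in zip(*k)]
    scores = [[s / math.sqrt(d) for s in row] for row in matmul(q, k_transposed)]
    weights = [softmax(row) for row in scores]  # each latent position attends over the text
    return matmul(weights, v)  # n x d_v: one condition-dependent vector per position


latent = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 latent positions, d_latent = 2
text = [[0.5, -1.0, 2.0], [1.0, 0.0, 0.0]]     # 2 text tokens, d_text = 3
w_q = [[1.0, 0.0], [0.0, 1.0]]
w_k = [[0.2, 0.1], [0.0, 0.3], [0.5, -0.4]]
w_v = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(cross_attention(latent, text, w_q, w_k, w_v))
```

    With a single text token, softmax over one score is exactly 1, so every latent position receives the same value vector; with several tokens, each position mixes them according to its own attention weights.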

    Stable Diffusion was released by LMU Munich researchers but soon became the poster child of the Stability AI startup, which recently led to a conflict of interest. The controversy around Stability AI is still unfolding, and until it is fully resolved I will refrain from commenting; here is a link, but let’s not go there now.

    Whatever the history of its creation, Stable Diffusion has become one of the most important models for image generation because it is both good and free to use: it has been released as open source, incorporated into Hugging Face repositories, and several free GUIs have been developed to make it easier to use.

    I will not give specific examples of Stable Diffusion outputs because this entire series of posts has been one such example: all cat images I have used to illustrate these posts have been created with Stable Diffusion. In particular, the prompt shown above is entirely real (augmented with a negative prompt, but making prompts for Stable Diffusion is a separate art in itself).

    Diffusion models + CLIP = DALL-E 2

    Stable Diffusion uses a diffusion model in the latent space of a VQ-GAN image-to-latent autoencoder, with text serving as a condition for the diffusion denoising model. But we already know that there are options for a joint latent space of text and images, such as CLIP (see Part IV of this series). So maybe we can decode latents obtained directly from text?

    DALL-E 2, also known as unCLIP, does exactly that (Ramesh et al., 2022). On the surface, it is an even simpler idea than Stable Diffusion: let’s just use the CLIP latent space! But in reality, they still do need a diffusion model inside: it turns out that text and image embeddings are not quite the same (this makes sense even in a multimodal latent space!), and you need a separate generative model to turn a text embedding into possible matching image embeddings.

    So the diffusion-based model still operates on the latent codes, but now the text is not merely a condition: it is embedded in the same joint latent space. Otherwise, these are exactly the same multimodal CLIP embeddings that we discussed in an earlier post. The generation process now involves a diffusion model, which the authors of DALL-E 2 call a diffusion prior, that converts the text embedding into an image embedding:

    (This time the prompt is fake: it is a Stable Diffusion image again.)

    DALL-E 2 reports excellent generation results; here are some samples from the paper:

    The combination of CLIP embeddings and a diffusion model in the latent space allows DALL-E 2 to do interesting stuff in the latent space. This includes highly semantic interpolations such as this one:

    Even more interestingly, DALL-E 2 can do text-guided image manipulation, changing the image embedding according to the difference between the vectors of the original and modified text captions:
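    Both manipulations come down to simple geometry in the embedding space. As we read the DALL-E 2 paper, interpolations use spherical interpolation (slerp) between CLIP image embeddings, and text diffs slerp the image embedding toward the normalized difference between the new and the original caption embeddings. The sketch below shows only this vector arithmetic on plain Python lists; the CLIP encoders and the decoder that would turn the result into an image are not included, and the function names are hypothetical.

```python
import math


def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def slerp(a, b, t):
    """Spherical interpolation between unit vectors a and b, for t in [0, 1]."""
    dot = max(-1.0, min(1.0, sum(x * y for x, y in zip(a, b))))
    omega = math.acos(dot)  # angle between the two embeddings
    if omega < 1e-6:  # (almost) identical vectors: linear interpolation is fine
        return [(1.0 - t) * x + t * y for x, y in zip(a, b)]
    if math.pi - omega < 1e-6:
        raise ValueError("slerp is undefined for opposite vectors")
    s = math.sin(omega)
    wa = math.sin((1.0 - t) * omega) / s
    wb = math.sin(t * omega) / s
    return [wa * x + wb * y for x, y in zip(a, b)]


def text_diff_edit(z_image, z_caption, z_new_caption, t):
    """Rotate an image embedding toward the direction of a caption change."""
    z_diff = normalize([n - o for n, o in zip(z_new_caption, z_caption)])
    return slerp(normalize(z_image), z_diff, t)


# Halfway between two orthogonal unit embeddings, still on the unit sphere:
print(slerp([1.0, 0.0], [0.0, 1.0], 0.5))
```

    Unlike plain linear interpolation, slerp keeps intermediate points at unit norm, which matters because CLIP embeddings are compared by cosine similarity.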

    Unlike Stable Diffusion, DALL-E 2 is not open source, and you can only try it via the OpenAI interface. However, at least we have a paper that describes what DALL-E 2 does and how it has been trained (although the paper does appear to gloss over some important details). In the next section, we will not have even that.

    Midjourney and controversies over AI-generated art

    So what about the elephant in the room? Over the last year, the default models for text-image generation have been neither Stable Diffusion nor DALL-E 2; the lion’s share of the market has been occupied by Midjourney. Unfortunately, there is little I can add to the story above: Midjourney is definitely a diffusion-based model but the team has not published any papers or code, so technical details remain a secret.

    The best I could find was this Reddit comment. It claims that the original releases of Midjourney used a diffusion model augmented with progressive distillation (Salimans and Ho, 2022), a process that repeatedly trains a student sampler to reproduce two steps of the previous sampler with a single step, halving the number of sampling steps in every round.
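    To make the idea of progressive distillation concrete, here is a deliberately tiny toy; it is neither Midjourney's model nor the setup of Salimans and Ho. The “teacher sampler” is an Euler step of the ODE dx/dt = -x, and the “student” is a single multiplier fitted by least squares to reproduce two teacher steps. Each round, the student becomes the new teacher and the number of steps halves. In the real method both are neural denoisers and the student is trained by gradient descent; all names here are hypothetical.

```python
import random


def euler_step(h):
    """Toy teacher sampler: one Euler step of dx/dt = -x with step size h."""
    return lambda x: (1.0 - h) * x


def distill_round(teacher, n_samples=256, seed=0):
    """Fit a one-step student s(x) = c * x that imitates two teacher steps.

    The student is linear here, so least squares recovers it exactly; in the
    real method the student is a neural denoiser trained by gradient descent.
    """
    rng = random.Random(seed)
    xs = [rng.gauss(0.0, 1.0) for _ in range(n_samples)]
    targets = [teacher(teacher(x)) for x in xs]  # two consecutive teacher steps
    c = sum(x * y for x, y in zip(xs, targets)) / sum(x * x for x in xs)
    return lambda x: c * x


def progressive_distillation(num_steps=1024, rounds=8):
    """Halve the number of sampling steps `rounds` times."""
    sampler = euler_step(1.0 / num_steps)  # integrate the ODE over t in [0, 1]
    for _ in range(rounds):
        sampler = distill_round(sampler)  # the student becomes the next teacher
        num_steps //= 2
    return sampler, num_steps


def run(sampler, num_steps, x):
    """Apply a sampler num_steps times, starting from x."""
    for _ in range(num_steps):
        x = sampler(x)
    return x


student, steps = progressive_distillation()
print(steps, run(student, steps, 1.0), run(euler_step(1.0 / 1024), 1024, 1.0))
```

    Eight rounds take the sampler from 1024 steps down to 4, and the 4-step student lands on the same value as the 1024-step teacher (not on the exact solution of the ODE: the student can only be as good as its teacher).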

    This approach can significantly speed up the sampling process in diffusion models, which we noted as an important problem in the previous post. However, this is just a Reddit comment, and its author admits that available information only relates to the original beta releases, so by now Midjourney models may be entirely different. Thus, in this section let us review the public milestones of Midjourney and the controversies that keep arising over AI-generated art.

    One of Midjourney’s first claims to fame was an image called “Théâtre d’Opéra Spatial” (“Space Opera Theatre”) produced by Jason Allen:

    This image won first place in a digital art competition (2022 Colorado State Fair, to be precise). Allen signed this work as “Jason M. Allen via Midjourney” and insisted that he did not break any rules of the competition, but the judges were unaware that the image had been AI-generated, so some controversy still ensued.

    Later in 2022, Midjourney was used to illustrate a children’s book called “Alice and Sparkle”, very appropriately devoted to a girl who creates a self-aware artificial intelligence and somehow manages to solve the AI alignment problem, so in the book, Alice and Sparkle live happily ever after:

    The text of the book was written with heavy help from ChatGPT, and the entire book went from idea to Amazon in 72 hours. It sparked one of the first serious controversies over the legal status of AI-generated art. “Alice and Sparkle” received many 5-star reviews and no fewer 1-star reviews, was temporarily suspended on Amazon (but then returned, here it is), and while there is no legal reason to take down “Alice and Sparkle” right now, the controversy still has not been resolved.

    Legal reasons may appear, however. After “Alice and Sparkle”, human artists realized that models trained on their collective output could seriously put them out of work. They claimed that AI-generated art should be considered derivative, and that the authors of the art comprising the training set should be compensated. On January 13, 2023, three artists filed a lawsuit against Stability AI, Midjourney, and DeviantArt, claiming that training the models on original work without the consent of its authors constitutes copyright infringement. The lawsuit is proceeding as lawsuits generally do, that is, very slowly. In April, Stability AI moved to dismiss the case since the plaintiffs failed to identify “a single act of direct infringement, let alone any output that is substantially similar to the plaintiffs’ artwork”. On July 23, Judge William Orrick ruled that the plaintiffs did not present sufficient evidence but allowed them to present additional facts to amend their complaint. We will see how the case unfolds, but I have no doubt that this is just the first of many similar cases, and the legal and copyright system will have to adapt to the new reality of generative AI.

    In general, over 2023 Midjourney has remained the leader in the AI-generated art space, with several new versions released to wide acclaim. This acclaim, however, has also been controversial: users are often divided over whether new versions of image generation models are actually improvements.

    Lately, generated images tend to make the news not as art objects but as fake photographs. AI-generated art has become good enough to pass for real photos, and people have been using it to various effects. In March, Midjourney generated a viral image of Donald Trump being forcefully arrested. On May 22, a Twitter account made to look like a verified Bloomberg feed published a fake image of an explosion near the Pentagon in Washington D.C. The result exceeded expectations: trading bots and/or real traders took the fake news at face value, resulting in a $500B market cap swing:

    While this kind of news keeps attracting attention to generative AI, to be honest I do not really see a big new issue behind these “deepfakes.” Realistic fake photos have been possible to produce for decades, with the tools steadily improving independently of machine learning progress. A Photoshop expert could probably make the Pentagon explosion “photo” in a couple of hours; I am not even sure that fiddling with the prompts to get an interesting and realistic result takes significantly less time (but yes, it does not require an experienced artist). While generative models can scale this activity up, it is by no means a new problem.

    Professional artists, on the other hand, face a genuine challenge. I have been illustrating this series of posts with (a rather old version of) Stable Diffusion. In this case, it would not make sense to hire a professional illustrator to make pictures for this blog anyway, so having access to a generative model has been a strict improvement for this series. As long as you are not too scrupulous about the little details, the cats just draw themselves:

    But what if I had to illustrate a whole book? Right now, the choice is between spending money on better-quality human-made illustrations and using generative AI to get (somewhat) worse illustrations for free or for the small fee of a Midjourney subscription. For me (the author), the work involved is virtually the same, since I would have to explain what I need to a human illustrator as well and would probably have to make a few iterations. For the publisher, hiring a human freelancer is a lot of extra work and expense. Even at present, I already see both myself and publishing houses choosing the cheaper and easier option. Guess what happens when this option ceases to be worse in any noticeable way…

    Conclusion

    With this, we are done with the original plan for the “Generative AI” series. Over these seven posts, we have seen a general overview of modern approaches to image generation, starting from the original construction of variational autoencoders and proceeding all the way to the latest and greatest diffusion-based models.

    However, a lot has happened in the generative AI space even as I have been writing this series! In my lectures, I call 2023 “the spring of artificial intelligence”: starting from the growing popularity of ChatGPT and the release of LLaMA that put large language models in the hands of the public, important advances have been made virtually every week. So next time, I will attempt to review what has been happening this year in AI; it will not be technical at all but the developments seem to be too important to miss. See you then!

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • Generative AI V: Diffusion-based models

    Generative AI V: Diffusion-based models

    By this time, we have discussed nearly all components of modern generative AI: variational autoencoders, discrete latent spaces, how they combine with Transformers in DALL-E, and how to learn a joint latent space for images and text. There is only one component left, diffusion-based models, but it’s a big one! Today, we discuss the main idea of diffusion-based models and go over the basic diffusion models such as DDPM and DDIM. Expect a lot of math, but it will all pay off at the end.

    Diffusion-based models

    We have already discussed the main idea behind diffusion in machine learning in the very first, introductory post of this series. As a quick reminder, the idea is to train a model to denoise images or other objects so well that in the end, you can give it (what looks like) random noise as input and after several rounds of denoising get a realistic object.

    In the space of images, it would look something like this. Suppose you have some kind of noise in mind, most probably Gaussian. This defines a probability distribution q(\mathbf{x}_{k+1}| \mathbf{x}_{k}), where \mathbf{x}_{k} is the input image and \mathbf{x}_{k+1} is the image with added noise. Applying this distribution repeatedly, we get a Markov chain called forward diffusion that gradually adds noise until the image is completely unrecognizable:
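    A quick simulation shows what forward diffusion does to the data. This is a sketch under assumptions: the “pixels” are stand-in numbers in [-1, 1], and each step applies the Gaussian transition \mathbf{x}_{k+1} = \sqrt{1-\beta}\mathbf{x}_{k} + \sqrt{\beta}\boldsymbol{\epsilon} with a constant \beta = 0.02 (the form of noise used in the next section, with a schedule chosen only for illustration).

```python
import math
import random


def forward_diffusion(x0, betas, rng):
    """Run the forward chain x_{k+1} = sqrt(1 - beta_k) * x_k + sqrt(beta_k) * noise."""
    x = list(x0)
    for beta in betas:
        x = [math.sqrt(1.0 - beta) * xi + math.sqrt(beta) * rng.gauss(0.0, 1.0) for xi in x]
    return x


rng = random.Random(0)
x0 = [rng.uniform(-1.0, 1.0) for _ in range(2000)]  # stand-in "pixels" in [-1, 1]
xT = forward_diffusion(x0, [0.02] * 500, rng)       # constant beta (an assumption)
mean = sum(xT) / len(xT)
var = sum((v - mean) ** 2 for v in xT) / len(xT)
print(round(mean, 3), round(var, 3))  # should be close to 0 and 1: the signal is gone
```

    After 500 steps, the values are essentially standard Gaussian noise with no visible correlation to the original pixels, which is exactly the starting point that reverse diffusion has to work from.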

    But on every step of this transformation, you add only a little bit of noise, and it is reasonable to expect that a denoising model would learn to almost perfectly get rid of it. If you get such a denoising model, again in the form of a distribution p_{\boldsymbol{\theta}}(\mathbf{x}_{k}| \mathbf{x}_{k+1}) with model parameters \boldsymbol{\theta} that should be a good approximation for the inverted q(\mathbf{x}_{k}| \mathbf{x}_{k+1}), you can presumably run it backwards and get the images back from basically random noise. This process is known as reverse diffusion:

    However, as Woody Allen put it, “right now it’s only a notion, but I think I can get money to make it into a concept, and later turn it into an idea”. Training a denoising model straightforwardly, by using pairs of images produced by q(\mathbf{x}_{k+1}| \mathbf{x}_{k}) as supervision, will not get us too far: the model needs to understand the entire dynamics and make its backwards steps smarter.

    Therefore, we use approximate inference to get from \mathbf{x}_{n} to \mathbf{x}_{0}. Since we already know variational autoencoders, I will mention that one good way to think about diffusion models is to treat them as hierarchical VAEs that chain together several feature-extracting encoders, but with additional restrictions on the encoders and decoders.

    But this is where it gets mathy. The next section is not for the faint of heart, but I still include it for those of you who really want to understand how this stuff works. I will not refer to the derivation details later, so if the next section is a bit too much, feel free to skip it.

    Probabilistic diffusion models: idea and derivation

    Probabilistic diffusion models were introduced by Sohl-Dickstein et al. in “Deep Unsupervised Learning using Nonequilibrium Thermodynamics” (2015). As you can see from the title, it was a novel idea that went in an unexplored direction: it took five years to make it work reasonably efficiently, and a couple more years to turn it into the latent diffusion models that we enjoy now.

    Still, the basic concept remains the same. The forward diffusion process adds Gaussian noise, and the reverse diffusion model learns to restore the original image. Let’s dive into the details!

    First, if we consider the noise to be Gaussian, then we can get a result very similar to the reparametrization tricks we have seen earlier for VAE and dVAE: we can “compress” the whole chain into a single Gaussian. Formally, assume that q(\mathbf{x}_{t}| \mathbf{x}_{t-1}) is a Gaussian with variance \beta_t and mean \sqrt{\alpha_t}\mathbf{x}_{t-1}, where \alpha_t=1-\beta_t (this scaling makes the process variance preserving, so that \mathbf{x}_{t} does not explode or vanish), and that the entire process takes T steps:

        \[q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \mathcal{N}\left(\mathbf{x}_t | \sqrt{1-\beta_t}\mathbf{x}_{t-1}, \beta_t\mathbf{I}\right),\qquad q\left(\mathbf{x}_{1:T} | \mathbf{x}_0\right) = \prod_{t=1}^T q\left(\mathbf{x}_t | \mathbf{x}_{t-1}\right).\]

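    Then we can unroll the recursion. Here \boldsymbol{\epsilon}' and \boldsymbol{\epsilon}'' are independent samples from \mathcal{N}(\mathbf{0},\mathbf{I}); on the third line, the two independent zero-mean Gaussian terms with covariances \alpha_t(1-\alpha_{t-1})\mathbf{I} and (1-\alpha_t)\mathbf{I} merge into a single Gaussian term \bar{\boldsymbol{\epsilon}} with covariance (1-\alpha_t\alpha_{t-1})\mathbf{I}, and so on down the chain:

        \begin{align*} \mathbf{x}_t &= \sqrt{\alpha_t}\mathbf{x}_{t-1} + \sqrt{1-\alpha_t}\boldsymbol{\epsilon}' \\ & = \sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}}\mathbf{x}_{t-2} + \sqrt{1-\alpha_{t-1}}\boldsymbol{\epsilon}''\right) + \sqrt{1-\alpha_t}\boldsymbol{\epsilon}' \\ & =\sqrt{\alpha_t\alpha_{t-1}}\mathbf{x}_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\bar{\boldsymbol{\epsilon}} = \ldots \\ & = \sqrt{A_t}\mathbf{x}_0 + \sqrt{1 - A_t}\boldsymbol{\epsilon},\quad\text{where}\quad A_t = \alpha_1\alpha_2\ldots\alpha_t. \end{align*}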

    This means that the compressed distribution q\left(\mathbf{x}_{t} | \mathbf{x}_0\right) is also a Gaussian, and we know its parameters:

        \[q\left(\mathbf{x}_{t} | \mathbf{x}_0\right) = \mathcal{N}\left(\mathbf{x}_{t} | \sqrt{A_t}\mathbf{x}_0, \left(1-A_t\right)\mathbf{I}\right).\]

    This makes the forward diffusion process very efficient: we can sample from q\left(\mathbf{x}_{t} | \mathbf{x}_0\right) for any t, including t=T, directly in closed form, without having to go through any intermediate steps.
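    A numerical check of the closed form may help. The sketch below propagates the mean and variance of \mathbf{x}_t through the chain one step at a time and compares them with \sqrt{A_t}\mathbf{x}_0 and 1-A_t. It works with a scalar \mathbf{x}_0, and the linear schedule from \beta_1 = 10^{-4} to \beta_T = 0.02 over T = 1000 steps (the one used in the DDPM paper) is an assumption for illustration; the function names are hypothetical.

```python
import math
import random


def chain_moments(x0, alphas):
    """Mean and variance of x_t given a scalar x_0, propagated one step at a time."""
    mean, var = x0, 0.0
    for a in alphas:
        mean = math.sqrt(a) * mean  # the mean is scaled by sqrt(alpha_t)
        var = a * var + (1.0 - a)   # the old variance is scaled, fresh noise is added
    return mean, var


def closed_form_moments(x0, alphas):
    """Mean and variance of q(x_t | x_0) = N(sqrt(A_t) x_0, 1 - A_t)."""
    A_t = math.prod(alphas)
    return math.sqrt(A_t) * x0, 1.0 - A_t


def sample_xt(x0, A_t, rng):
    """Draw x_t ~ q(x_t | x_0) in one shot, without simulating the chain."""
    return math.sqrt(A_t) * x0 + math.sqrt(1.0 - A_t) * rng.gauss(0.0, 1.0)


alphas = [1.0 - (1e-4 + (0.02 - 1e-4) * i / 999) for i in range(1000)]  # assumed schedule
print(chain_moments(0.7, alphas[:300]))
print(closed_form_moments(0.7, alphas[:300]))
print(sample_xt(0.7, math.prod(alphas[:300]), random.Random(0)))
```

    The two sets of moments agree up to floating-point rounding, so the one-shot sample_xt draws from exactly the distribution the chain would produce.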

    It might seem that inverting Gaussians should be just as easy as stringing them together. And indeed, if our problem were to invert the Gaussian part of the process for a given \mathbf{x}_0, it would be easy! Let’s use the Bayes formula and substitute distributions that we already know:

        \begin{align*}q\left(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0\right) &= \frac{q\left(\mathbf{x}_{t} | \mathbf{x}_{t-1}, \mathbf{x}_0\right)q\left(\mathbf{x}_{t-1} |\mathbf{x}_0\right)}{q\left(\mathbf{x}_t | \mathbf{x}_0\right)} \\&= \frac{\mathcal{N}\left(\mathbf{x}_t|\sqrt{\alpha_t}\mathbf{x}_{t-1}, \left(1-\alpha_t\right)\mathbf{I}\right)\mathcal{N}\left(\mathbf{x}_{t-1}|\sqrt{A_{t-1}}\mathbf{x}_0, \left(1-A_{t-1}\right)\mathbf{I}\right)}{\mathcal{N}\left(\mathbf{x}_{t}|\sqrt{A_{t}}\mathbf{x}_0, \left(1-A_{t}\right)\mathbf{I}\right)} \\  & = \mathrm{Const} \cdot e^{-\frac12\left(\frac{\left(\mathbf{x}_t-\sqrt{\alpha_t}\mathbf{x}_{t-1}\right)^2}{1-\alpha_t} + \frac{\left(\mathbf{x}_{t-1}-\sqrt{A_{t-1}}\mathbf{x}_{0}\right)^2}{1 - A_{t-1}} - \frac{\left(\mathbf{x}_{t}-\sqrt{A_{t}}\mathbf{x}_{0}\right)^2}{1 - A_{t}}\right)}. \end{align*}

    It is already clear that the new distribution is a Gaussian as well, since its density has a quadratic function of \mathbf{x}_{t-1} in the exponent. I will skip the gory details of completing the square in this exponent, but the result is, again, a nice and clean Gaussian whose parameters we know and that we could easily sample from:

        \begin{align*}q\left(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0\right) &= \mathcal{N}\left(\mathbf{x}_{t-1}| {\tilde{\boldsymbol{\mu}}}\left(\mathbf{x}_t,\mathbf{x}_0\right), {\tilde{\beta}}_t\mathbf{I}\right),\quad\text{where} \quad{\tilde{\beta}}_t = \frac{1 - A_{t-1}}{1 - A_t}\cdot\beta_t,\\{\tilde{\boldsymbol{\mu}}}\left(\mathbf{x}_t,\mathbf{x}_0\right) &= \frac{\sqrt{\alpha_t}\left(1 - A_{t-1}\right)}{1 - A_t}\mathbf{x}_t + \frac{\sqrt{A_{t-1}}\beta_t}{1 - A_t}\mathbf{x}_0= \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{1-\alpha_t}{\sqrt{1-A_t}}\boldsymbol{\epsilon}\right).\end{align*}

    In the last expression, we have substituted \mathbf{x}_0 = \frac{1}{\sqrt{A_t}}\left(\mathbf{x}_t - \sqrt{1-A_t}\boldsymbol{\epsilon}\right), where \boldsymbol{\epsilon} is the noise that produced \mathbf{x}_t from \mathbf{x}_0 in the closed form above.
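    Formulas like these are easy to get wrong, so here is a sketch that checks them numerically for scalars against a brute-force product of the two Gaussians in the Bayes formula (precisions add, and so do precision-weighted means). The linear schedule is again an assumption, and the function names are hypothetical.

```python
import math


def posterior_closed_form(x_t, x0, t, alphas):
    """Mean and variance of q(x_{t-1} | x_t, x_0) from the formulas above (t is 1-based)."""
    a_t = alphas[t - 1]
    beta_t = 1.0 - a_t
    A_t = math.prod(alphas[:t])
    A_prev = math.prod(alphas[:t - 1])  # A_0 = 1 (empty product)
    var = (1.0 - A_prev) / (1.0 - A_t) * beta_t
    mean = (math.sqrt(a_t) * (1.0 - A_prev) * x_t + math.sqrt(A_prev) * beta_t * x0) / (1.0 - A_t)
    return mean, var


def posterior_mean_from_eps(x_t, eps, t, alphas):
    """The same mean written via the noise eps, where x_t = sqrt(A_t) x_0 + sqrt(1 - A_t) eps."""
    a_t = alphas[t - 1]
    A_t = math.prod(alphas[:t])
    return (x_t - (1.0 - a_t) / math.sqrt(1.0 - A_t) * eps) / math.sqrt(a_t)


def posterior_by_bayes(x_t, x0, t, alphas):
    """Brute force: prior q(x_{t-1} | x_0) times likelihood q(x_t | x_{t-1}), for t >= 2."""
    if t < 2:
        raise ValueError("for t = 1 the prior q(x_0 | x_0) is degenerate")
    a_t = alphas[t - 1]
    beta_t = 1.0 - a_t
    A_prev = math.prod(alphas[:t - 1])
    # As a function of x_{t-1}, the likelihood is Gaussian with precision a_t / beta_t
    # and mean x_t / sqrt(a_t); precisions add, precision-weighted means add.
    precision = 1.0 / (1.0 - A_prev) + a_t / beta_t
    var = 1.0 / precision
    mean = var * (math.sqrt(A_prev) * x0 / (1.0 - A_prev) + math.sqrt(a_t) * x_t / beta_t)
    return mean, var


alphas = [1.0 - (1e-4 + (0.02 - 1e-4) * i / 999) for i in range(1000)]  # assumed schedule
x0, eps, t = -0.7, 0.4, 500
A_t = math.prod(alphas[:t])
x_t = math.sqrt(A_t) * x0 + math.sqrt(1.0 - A_t) * eps
print(posterior_closed_form(x_t, x0, t, alphas))
print(posterior_by_bayes(x_t, x0, t, alphas))
print(posterior_mean_from_eps(x_t, eps, t, alphas))
```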

    Are we done? Of course not, we are just getting started! This simple distribution is conditioned on \mathbf{x}_0… but it is exactly q(\mathbf{x}_0) that represents the impossibly messy distribution of, say, real-life images. Ultimately, we want our reverse diffusion process to reconstruct q(\mathbf{x}_0) from a standard input at \mathbf{x}_T; something like this:

    The whole problem of training a generative model, as we have discussed many times on this blog, is to find a good representation for q(\mathbf{x}_0), and our process so far treats it as a known quantity.

    What do we do? As usual in Bayesian inference, we approximate. On every step, we want the model p_{\boldsymbol{\theta}}(\mathbf{x}_{t-1}|\mathbf{x}_{t}) to be a good approximation to the reverse distribution q(\mathbf{x}_{t-1}|\mathbf{x}_{t}), with no conditioning on the unknown \mathbf{x}_0:

    To get this approximation, we need a variational lower bound pretty similar to the one used in variational autoencoders and DALL-E. We will start with a bound for the global distribution q(\mathbf{x}_{1:T}|\mathbf{x}_{0})=q(\mathbf{x}_{1},\ldots,\mathbf{x}_{T}|\mathbf{x}_{0}):

    And then it will turn out that it decomposes into bounds for individual steps of the diffusion process.

    Since we’re doing a lot of math here anyway, let us derive the variational lower bound from first principles, just like we did in the post on VAEs. We start from the obvious equality

        \[\log p_{\boldsymbol{\theta}}(\mathbf{x}_0) = \log p_{\boldsymbol{\theta}}(\mathbf{x}_0,\mathbf{x}_1,\ldots,\mathbf{x}_T) - \log p_{\boldsymbol{\theta}}(\mathbf{x}_1,\ldots,\mathbf{x}_T|\mathbf{x}_0),\]

    take the expectation with respect to q(\mathbf{x}_{0:T})=q(\mathbf{x}_{0},\ldots,\mathbf{x}_{T}), and then add and subtract \log q(\mathbf{x}_{1:T}|\mathbf{x}_{0})= \log q(\mathbf{x}_{1},\ldots,\mathbf{x}_{T}|\mathbf{x}_{0}) on the right-hand side:

        \begin{align*}\mathbb{E}_{q(\mathbf{x}_{0})}\left[\log p_{\boldsymbol{\theta}}(\mathbf{x}_0)\right] &= \mathbb{E}_{q(\mathbf{x}_{0:T})}\left[\log p_{\boldsymbol{\theta}}(\mathbf{x}_{0:T})\right] - \mathbb{E}_{q(\mathbf{x}_{0:T})}\left[\log p_{\boldsymbol{\theta}}(\mathbf{x}_{1:T}|\mathbf{x}_0)\right] \\ & = \mathbb{E}_{q(\mathbf{x}_{0:T})}\left[\log\frac{p_{\boldsymbol{\theta}}(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})}\right] + \mathbb{E}_{q(\mathbf{x}_{0:T})}\left[\log\frac{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})}{p_{\boldsymbol{\theta}}(\mathbf{x}_{1:T}|\mathbf{x}_{0})}\right].\end{align*}

    At this point, we note that the second term on the right is the KL divergence between q(\mathbf{x}_{1:T}|\mathbf{x}_{0}) and p_{\boldsymbol{\theta}}(\mathbf{x}_{1:T}|\mathbf{x}_{0}) (averaged over \mathbf{x}_0), which is exactly the approximation error we want to make small. Since a KL divergence is always nonnegative, the first term on the right is a lower bound on the log-likelihood on the left; this is our variational lower bound, and maximizing it with respect to \boldsymbol{\theta} increases the log-likelihood, decreases the KL divergence, or both.

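    It will be more convenient to think of it as a loss function, so let’s add a minus sign in front, that is, let’s invert the fraction inside the logarithm. Then we can note that the bound decomposes nicely into the sum of individual steps. The key trick is to rewrite every forward step with the Markov property and the Bayes formula, q(\mathbf{x}_{t}|\mathbf{x}_{t-1}) = q(\mathbf{x}_{t}|\mathbf{x}_{t-1},\mathbf{x}_{0}) = q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})q(\mathbf{x}_{t}|\mathbf{x}_{0})/q(\mathbf{x}_{t-1}|\mathbf{x}_{0}), so that the extra ratios telescope. This is the last long derivation in this post (phew!):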

        \begin{align*}\mathcal{L} =& \mathbb{E}_{q}\left[\log\frac{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})}{p_{\boldsymbol{\theta}}(\mathbf{x}_{0:T})}\right] = \mathbb{E}_{q}\left[\log\frac{\prod_{t=1}^Tq(\mathbf{x}_{t}|\mathbf{x}_{t-1})}{p_{\boldsymbol{\theta}}(\mathbf{x}_{T})\prod_{t=1}^Tp_{\boldsymbol{\theta}}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}\right] \\=& \mathbb{E}_{q}\left[-\log p_{\boldsymbol{\theta}}(\mathbf{x}_{T}) + \sum_{t=1}^T\log\frac{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})}{p_{\boldsymbol{\theta}}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}\right] \\=& \mathbb{E}_{q}\left[-\log p_{\boldsymbol{\theta}}(\mathbf{x}_{T}) + \sum_{t=2}^T\log\frac{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})}{p_{\boldsymbol{\theta}}(\mathbf{x}_{t-1}|\mathbf{x}_{t})} + \log\frac{q(\mathbf{x}_{1}|\mathbf{x}_{0})}{p_{\boldsymbol{\theta}}(\mathbf{x}_{0}|\mathbf{x}_{1})}\right] \\=& \mathbb{E}_{q}\left[-\log p_{\boldsymbol{\theta}}(\mathbf{x}_{T}) + \sum_{t=2}^T\log\left(\frac{q(\mathbf{x}_{t}|\mathbf{x}_{t-1},\mathbf{x}_{0})}{p_{\boldsymbol{\theta}}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}\frac{q(\mathbf{x}_{t}|\mathbf{x}_{0})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{0})}\right) + \log\frac{q(\mathbf{x}_{1}|\mathbf{x}_{0})}{p_{\boldsymbol{\theta}}(\mathbf{x}_{0}|\mathbf{x}_{1})}\right] \\=& \mathbb{E}_{q}\left[-\log p_{\boldsymbol{\theta}}(\mathbf{x}_{T}) + \sum_{t=2}^T\log\frac{q(\mathbf{x}_{t}|\mathbf{x}_{t-1},\mathbf{x}_{0})}{p_{\boldsymbol{\theta}}(\mathbf{x}_{t-1}|\mathbf{x}_{t})} + \log\frac{q(\mathbf{x}_{T}|\mathbf{x}_{0})}{q(\mathbf{x}_{1}|\mathbf{x}_{0})} + \log\frac{q(\mathbf{x}_{1}|\mathbf{x}_{0})}{p_{\boldsymbol{\theta}}(\mathbf{x}_{0}|\mathbf{x}_{1})}\right] \\=& \mathbb{E}_{q}\left[\log\frac{q(\mathbf{x}_{T}|\mathbf{x}_{0})}{p_{\boldsymbol{\theta}}(\mathbf{x}_{T})} + \sum_{t=2}^T\log\frac{q(\mathbf{x}_{t}|\mathbf{x}_{t-1},\mathbf{x}_{0})}{p_{\boldsymbol{\theta}}(\mathbf{x}_{t-1}|\mathbf{x}_{t})} - \log p_{\boldsymbol{\theta}}(\mathbf{x}_{0}|\mathbf{x}_{1})\right].\end{align*}

    Now we see that the loss function decomposes nicely into a sum of T+1 components, and almost all of them are actually KL divergences between Gaussians:

        \begin{align*}\mathcal{L} =& L_T + L_{T-1} + \ldots + L_0,\qquad\text{where}\\L_T =& \mathrm{KL}\left(q(\mathbf{x}_{T}|\mathbf{x}_{0})\|p_{\boldsymbol{\theta}}(\mathbf{x}_T)\right),\\L_t =& \mathrm{KL}\left(q(\mathbf{x}_{t}|\mathbf{x}_{t+1},\mathbf{x}_{0})\|p_{\boldsymbol{\theta}}(\mathbf{x}_t|\mathbf{x}_{t+1})\right),\quad t=1,\ldots,T-1,\\L_0 =& -\log p_{\boldsymbol{\theta}}(\mathbf{x}_0 | \mathbf{x}_1).\end{align*}
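    Every L_t with t \ge 1 is a KL divergence between Gaussians, and for isotropic covariances it has a short closed form. A sketch in plain Python (the function name is hypothetical):

```python
import math


def gaussian_kl(mu_q, var_q, mu_p, var_p):
    """KL( N(mu_q, var_q * I) || N(mu_p, var_p * I) ) for d-dimensional Gaussians.

    mu_q and mu_p are lists of length d; var_q and var_p are positive scalars.
    """
    d = len(mu_q)
    sq_dist = sum((a - b) ** 2 for a, b in zip(mu_q, mu_p))
    return 0.5 * (d * var_q / var_p + sq_dist / var_p - d + d * math.log(var_p / var_q))


# With equal variances, the KL is just a scaled squared distance between the means:
print(gaussian_kl([1.0, 2.0], 0.5, [0.0, 2.0], 0.5))  # 1.0 = ||mu_q - mu_p||^2 / (2 * 0.5)
```

    When both variances are fixed, everything except the scaled squared distance between the means is a constant, so minimizing L_t comes down to matching \boldsymbol{\mu}_{\boldsymbol{\theta}} with \tilde{\boldsymbol{\mu}}_t.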

    All of these components are now relatively straightforward to compute; for example, in L_t we are using the Gaussian parametrization

        \[p_{\boldsymbol{\theta}}(\mathbf{x}_{t-1} | \mathbf{x}_t) = \mathcal{N}\left(\mathbf{x}_{t-1}| \boldsymbol{\mu}_{\boldsymbol{\theta}}\left(\mathbf{x}_t,t\right),\Sigma_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)\right)\]

    and trying to match its parameters with those of q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0}). For the mean, for instance, we get

        \[\boldsymbol{\mu}_{\boldsymbol{\theta}}\left(\mathbf{x}_t,t\right) \approx {\tilde{\boldsymbol{\mu}}}_t\left(\mathbf{x}_t,\mathbf{x}_0\right) = \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{1-\alpha_t}{\sqrt{1-A_t}}\boldsymbol{\epsilon}_t\right),\]

    and since \mathbf{x}_t is available to the model as its input, we can actually parametrize the noise directly rather than the mean:

        \[p_{\boldsymbol{\theta}}(\mathbf{x}_{t-1} | \mathbf{x}_t) = \mathcal{N}\left(\mathbf{x}_{t-1}\middle| \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-A_t}}\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\mathbf{x}_t,t)\right),\Sigma_{\boldsymbol{\theta}}(\mathbf{x}_t, t)\right).\]

    I will stop the calculations here but I hope you are convinced now that this whole reverse diffusion Markov chain comes down to a closed form loss function that you can program in PyTorch and minimize. This was the main idea of the original paper by Sohl-Dickstein et al. (2015). Let us see where it has gone since then.
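    To make this loss a bit more concrete, here is a sketch in plain Python rather than PyTorch, so that it runs without dependencies. It assumes, as DDPM does below, that the model variance is fixed to \sigma_t^2\mathbf{I}; then the KL term for the reverse step from \mathbf{x}_t to \mathbf{x}_{t-1} equals, up to an additive constant, \|\tilde{\boldsymbol{\mu}}_t - \boldsymbol{\mu}_{\boldsymbol{\theta}}\|^2/(2\sigma_t^2). Plugging in the noise form of both means turns it into a weighted squared error between the true and the predicted noise, with weight (1-\alpha_t)^2/(2\sigma_t^2\alpha_t(1-A_t)). The model function, the toy 1-D data, and the schedule are hypothetical.

```python
import math
import random


def reverse_step_loss(x0_batch, t, alphas, sigma2, eps_model, rng):
    """Monte Carlo estimate, up to an additive constant, of the KL term for x_t -> x_{t-1}.

    t is 1-based and t >= 2; sigma2 is the fixed model variance sigma_t^2.
    """
    a_t = alphas[t - 1]
    A_t = math.prod(alphas[:t])
    # ||mu_tilde - mu_theta||^2 / (2 sigma^2) with both means written via the noise
    weight = (1.0 - a_t) ** 2 / (2.0 * sigma2 * a_t * (1.0 - A_t))
    total = 0.0
    for x0 in x0_batch:
        eps = rng.gauss(0.0, 1.0)
        x_t = math.sqrt(A_t) * x0 + math.sqrt(1.0 - A_t) * eps  # one-shot forward sample
        total += weight * (eps - eps_model(x_t, t)) ** 2
    return total / len(x0_batch)


def zero_model(x_t, t):
    """A hypothetical model that never predicts any noise."""
    return 0.0


rng = random.Random(0)
alphas = [1.0 - (1e-4 + (0.02 - 1e-4) * i / 999) for i in range(1000)]  # assumed schedule
data = [rng.gauss(2.0, 0.5) for _ in range(2000)]  # toy 1-D "dataset"
print(reverse_step_loss(data, 500, alphas, 1.0 - alphas[499], zero_model, rng))
```

    In a real implementation, eps_model would be a neural network, and this quantity (summed or sampled over t) would be minimized with respect to its parameters.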

    Denoising diffusion probabilistic models

    In 2015, the original diffusion model could only be run on datasets that by now sound more like toy examples. For instance, Sohl-Dickstein et al. show how their generative model fares on CIFAR-10. In the image below, (a) shows some original hold-out images from CIFAR-10, in (b) they are corrupted with Gaussian noise, (c) shows how the diffusion model can denoise the images from (b), using them as starting points for the reverse diffusion chain, and finally (d) shows new samples generated by the diffusion model:

    That looked somewhat promising, but perhaps not promising enough to warrant a concerted effort to develop this approach. At the time, people were just getting excited about GANs: the original work by Goodfellow et al. was published in 2014, ProGAN would be released in 2017, and GANs would define the state of the art in image generation for the next several years (thispersondoesnotexist, for example, was powered by StyleGAN), until they arguably ran out of steam somewhere around StyleGAN 3.

    Therefore, the next stop on our way happened only five years later, in 2020, in the work “Denoising Diffusion Probabilistic Models” (DDPM) by Ho et al. They used the same basic idea and arrived at the same basic structure of the loss function; I reproduce it here in a general form since I suspect many readers have not followed through all the derivations in the previous section:

    There are three different components in this loss function, two of them appearing at the ends of the chain and one that is responsible for every intermediate step. Ho et al. make the following observations and simplifications:

    • they assume all forward diffusion variances \beta_t to be constant hyperparameters and do not train them, so there is nothing to train at all in the forward diffusion distributions q; since p_{\boldsymbol{\theta}}(\mathbf{x}_{T}) is a fixed distribution that we want to sample from, this means that L_T is a constant and can be ignored;
    • for the intermediate steps, they do not train the variances in p_{\boldsymbol{\theta}}(\mathbf{x}_{t} | \mathbf{x}_{t+1}) either, setting them to \sigma_{t+1}^2\mathbf{I} with untrained constants that depend only on the step (Ho et al. report similar results with \sigma_t^2=\beta_t and \sigma_t^2=\tilde{\beta}_t); they also develop the noise reparametrization mentioned above somewhat further, obtaining a simple closed form for L_t;
    • finally and most importantly, they substitute a separate discrete decoder for L_0; namely, they assume that the data consists of integers from 0 to 255 scaled linearly to [-1, 1], which is a natural representation for images, and model

          \[p_{\boldsymbol{\theta}}(\mathbf{x}_{0} | \mathbf{x}_1) = \prod_{i=1}^D\int_{\delta_-\left({x}_0,i\right)}^{\delta_+\left({x}_0,i\right)} \mathcal{N}\left({x}\middle|{\mu}_{\boldsymbol{\theta},i}\left(\mathbf{x}_1\right), \sigma_1^2\right)\mathrm{d} x,\]


      where i goes over the pixels (components of \mathbf{x}), {\mu}_{\boldsymbol{\theta},i}\left(\mathbf{x}_1\right) is the independent decoder model, and the integration limits define an interval of length 1/255 on every side of x_{0,i} (extended to -\infty or +\infty for the extreme values -1 and 1); this is a standard trick for turning a continuous Gaussian density into a proper probability distribution over discrete pixel values.
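    To make the discrete decoder concrete, here is a minimal Python sketch of the per-pixel probability in the formula above. It assumes pixels are integers 0..255 mapped linearly to [-1, 1], a bin of half-width 1/255 around every value, and edge bins that extend to -∞ and +∞; the function names are illustrative and not taken from the DDPM code.

```python
import math

def normal_cdf(x, mu, sigma):
    """CDF of N(mu, sigma^2), also defined for infinite limits."""
    if x == math.inf:
        return 1.0
    if x == -math.inf:
        return 0.0
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))

def to_unit_interval(v):
    """Map an integer pixel value 0..255 linearly to [-1, 1]."""
    return v / 127.5 - 1.0

def pixel_likelihood(x0, mu, sigma):
    """Probability of one pixel value x0 under N(mu, sigma^2) discretized into 256 bins.

    Each bin has half-width 1/255; the two edge bins extend to -inf and +inf."""
    upper = math.inf if x0 >= 1.0 - 1e-9 else x0 + 1.0 / 255
    lower = -math.inf if x0 <= -1.0 + 1e-9 else x0 - 1.0 / 255
    return normal_cdf(upper, mu, sigma) - normal_cdf(lower, mu, sigma)

def decoder_log_likelihood(x0, mu, sigma):
    """log p(x0 | x1) for an independent decoder: a sum of per-component log probabilities."""
    return sum(math.log(pixel_likelihood(a, m, sigma)) for a, m in zip(x0, mu))
```

    Because the 256 bins tile the real line, the probabilities of all values of a single pixel sum to one, which is exactly what makes this a proper discrete likelihood rather than a density.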

    As a result, you can substitute a different model at the last step and use a noiseless {\mu}_{\boldsymbol{\theta},i}\left(\mathbf{x}_1\right) during test-time sampling, which extends the capabilities of the whole diffusion model significantly.

    With these modifications, DDPM was able to achieve state of the art generation, comparable with the best GANs of the time. Here is a sample:

    Still, that’s not quite the end of the story even for basic diffusion-based models.

    Denoising Diffusion Implicit Models

    The next step came very quickly after DDPMs, in the work called “Denoising Diffusion Implicit Models” (DDIM) by Song et al. (2020). They aim at the same model as DDPM but address an important drawback of all diffusion models we have discussed so far: they are extremely slow. The generation process mirrors every step of the diffusion process, so to generate a new sample you have to go through a thousand steps (literally! DDPM uses T = 1000), apply a neural network on every step, and the steps are sequential and cannot be parallelized. This is especially bad in contrast to the usual deep learning paradigm, where it might take you a very long time to train a model but applying it is usually pretty fast: Song et al. mention that sampling from a trained GAN is over 1000x faster than sampling from a DDPM trained for the same image size.

    How can we speed up this construction, which at first glance looks inherently incremental? Song et al. do it by generalizing diffusion models, and DDPMs in particular. They note that the loss function we discussed above does not depend directly on the joint distribution q\left(\mathbf{x}_{1:T} | \mathbf{x}_0\right)=q\left(\mathbf{x}_{1},\ldots,\mathbf{x}_{T} | \mathbf{x}_0\right) but only on the marginal distributions q\left(\mathbf{x}_{t} | \mathbf{x}_0\right). This means that we can reuse the exact same learning objective for a different joint distribution as long as it has the same marginals.
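    For DDPM's Gaussian forward process these marginals have a simple closed form: with \bar\alpha_t=\prod_{s=1}^{t}(1-\beta_s), we have q\left(\mathbf{x}_{t} | \mathbf{x}_0\right)=\mathcal{N}\left(\sqrt{\bar\alpha_t}\mathbf{x}_0, (1-\bar\alpha_t)\mathbf{I}\right). The sketch below (plain Python, illustrative names) builds the linear schedule used by Ho et al. (\beta from 10^{-4} to 0.02 over T = 1000 steps) and samples \mathbf{x}_t from \mathbf{x}_0 in one shot instead of t sequential steps.

```python
import math
import random

def linear_beta_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced forward variances beta_1..beta_T."""
    return [beta_start + (beta_end - beta_start) * t / (T - 1) for t in range(T)]

def alpha_bars(betas):
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out

def sample_xt(x0, alpha_bar_t, rng):
    """Draw x_t ~ N(sqrt(alpha_bar_t) x0, (1 - alpha_bar_t) I) directly; also return the noise."""
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    xt = [math.sqrt(alpha_bar_t) * a + math.sqrt(1.0 - alpha_bar_t) * e
          for a, e in zip(x0, eps)]
    return xt, eps
```

    As a sanity check, running the one-step recursion Var_t = (1-\beta_t)Var_{t-1} + \beta_t from Var_0 = 0 reproduces 1-\bar\alpha_t exactly, which is why training never needs to simulate the chain step by step.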

    Song et al. define their diffusion process in terms of its reverse form:

        \[q_{\sigma}\left(\mathbf{x}_{1:T} | \mathbf{x}_0\right)= q_{\sigma}\left(\mathbf{x}_{T} | \mathbf{x}_0\right)\prod_{t=2}^T q_{\sigma}\left(\mathbf{x}_{t-1} | \mathbf{x}_{t},\mathbf{x}_0\right).\]

    Now we can express the forward diffusion distributions via Bayes' theorem:

        \[q_{\sigma}\left(\mathbf{x}_{t} | \mathbf{x}_{t-1},\mathbf{x}_0\right)=\frac{q_{\sigma}\left(\mathbf{x}_{t-1} | \mathbf{x}_{t},\mathbf{x}_0\right)q_{\sigma}\left(\mathbf{x}_{t} | \mathbf{x}_0\right)}{q_{\sigma}\left(\mathbf{x}_{t-1} | \mathbf{x}_0\right)}.\]

    Song et al. show (I promised to contain the complicated math in the first section, so I’ll skip the derivation here) that the resulting process has the same marginals, and the reverse diffusion can be trained with the same loss function and will represent an actual Markov chain:
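    For reference, here is the resulting reverse step as it is usually written, in DDPM-style notation with \bar\alpha_t=\prod_{s=1}^{t}(1-\beta_s) (Song et al. write \alpha_t for this same cumulative product) and \boldsymbol{\epsilon}_{\boldsymbol{\theta}} denoting the trained noise predictor; treat it as a reference sketch rather than a quote from the paper:

        \[\mathbf{x}_{t-1} = \sqrt{\bar\alpha_{t-1}}\left(\frac{\mathbf{x}_{t} - \sqrt{1-\bar\alpha_{t}}\,\boldsymbol{\epsilon}_{\boldsymbol{\theta}}\left(\mathbf{x}_{t}, t\right)}{\sqrt{\bar\alpha_{t}}}\right) + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\boldsymbol{\epsilon}_{\boldsymbol{\theta}}\left(\mathbf{x}_{t}, t\right) + \sigma_t\boldsymbol{\epsilon},\qquad \boldsymbol{\epsilon}\sim\mathcal{N}\left(\mathbf{0},\mathbf{I}\right),\]

    where the expression in parentheses is the current prediction of \mathbf{x}_0. Choosing \sigma_t=\eta\sqrt{(1-\bar\alpha_{t-1})/(1-\bar\alpha_{t})}\sqrt{1-\bar\alpha_{t}/\bar\alpha_{t-1}} with \eta=1 recovers DDPM sampling, while \eta=0 makes every step deterministic.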

    So far it does not sound very helpful: we have extended the class of forward diffusion distributions, but sampling a new image still requires going through all the reverse diffusion steps. However, the key observation here is that instead of approximating the random noise \boldsymbol{\epsilon}_t that gets us from \mathbf{x}_{t} to \mathbf{x}_{t+1}, we are now approximating the random noise \boldsymbol{\epsilon}_t that is mixed with \mathbf{x}_{0} to obtain \mathbf{x}_{t+1}.

    This process, in essence, means that when we are going in the reverse direction, we are approximating the direction not to \mathbf{x}_{t}, but directly to \mathbf{x}_{0}, and making a step in that direction. Here is an illustration for the difference:

    A DDPM model is trying to approximate the step from \mathbf{x}_{t+1} to \mathbf{x}_{t}, failing somewhat and getting a worse image. A DDIM model is trying to approximate the direction all the way from \mathbf{x}_{t+1} to \mathbf{x}_{0}; naturally, it fails a little, and if it tried to go all the way to \mathbf{x}_{0} it would miss by a lot, so it makes a small step in the approximate direction. It is hard to say which method is doing a better job at the approximation itself, but there is an important benefit to the DDIM scheme in terms of performance.

    Since now \boldsymbol{\epsilon}_{t} and the dependence on \mathbf{x}_{0} are disentangled, \boldsymbol{\epsilon}_{t} is just Gaussian noise, and we can jump over several steps in the process, getting from \mathbf{x}_{t+k} to \mathbf{x}_{t} in a single step with a correspondingly larger amount of noise! One can train a model with a large number of steps T but visit only a few of them during generation, which speeds things up very significantly. Naturally, the variance will increase and the approximations will get worse, but with careful tuning this effect can be contained.
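    Here is a minimal sketch of such step-skipping sampling in plain Python. On every step it predicts \mathbf{x}_0 from the noise predictor and then moves to the noise level of the next timestep in a decreasing subsequence, which may be many original steps away. The noise predictor eps_model, the list alpha_bar (cumulative products indexed 0..T with alpha_bar[0] = 1), and the eta parameter (0 for the deterministic sampler, as in Song et al.) are assumptions of this sketch.

```python
import math

def ddim_sample(x_T, eps_model, alpha_bar, timesteps, eta=0.0, rng=None):
    """Sketch of DDIM sampling over a decreasing subsequence of timesteps.

    eps_model(x, t): hypothetical trained noise predictor.
    alpha_bar[t]: cumulative product of (1 - beta_s) for s <= t.
    timesteps: decreasing list such as [1000, 900, ..., 100].
    rng (e.g. random.Random) is only needed when eta > 0.
    """
    x = list(x_T)
    for i, t in enumerate(timesteps):
        ab_t = alpha_bar[t]
        # the next noise level in the subsequence; after the last step we land on x_0
        ab_prev = alpha_bar[timesteps[i + 1]] if i + 1 < len(timesteps) else 1.0
        eps = eps_model(x, t)
        # current prediction of x_0
        x0_pred = [(xi - math.sqrt(1 - ab_t) * ei) / math.sqrt(ab_t) for xi, ei in zip(x, eps)]
        # eta scales the fresh noise; eta = 0 gives a deterministic step
        sigma = eta * math.sqrt((1 - ab_prev) / (1 - ab_t)) * math.sqrt(1 - ab_t / ab_prev)
        direction = math.sqrt(max(1 - ab_prev - sigma ** 2, 0.0))
        noise = [rng.gauss(0.0, 1.0) for _ in x] if sigma > 0 else [0.0] * len(x)
        x = [math.sqrt(ab_prev) * p + direction * e + sigma * n
             for p, e, n in zip(x0_pred, eps, noise)]
    return x
```

    A convenient toy check is a "dataset" consisting of a single point c: then the optimal noise predictor is known in closed form, (x - sqrt(ᾱ_t) c)/sqrt(1 - ᾱ_t), and deterministic sampling with only 10 of 1000 steps lands exactly on c.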

    Song et al. achieve 10x to 100x speedups compared to DDPM, with insignificant loss in quality:

    Moreover, DDIMs also have a generation process that does not need to be stochastic! Song et al. suggest setting the variance hyperparameter in the reverse diffusion chain to zero during generation. This means that a latent code in the space of \mathbf{x}_T corresponds to exactly one image, and now we can expect DDIMs to behave in the same way as other models that train latent representations (compare, e.g., our previous post), including, for instance, interpolations in the latent space:

    Note that DDPMs could not do interpolations because a latent code \mathbf{x}_T would have a huge amount of noise added to it during the reverse diffusion process; it wasn’t really a “code” for anything, just a starting point for the Markov chain.
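    With a deterministic DDIM, interpolating between two images amounts to interpolating between their codes \mathbf{x}_T and decoding each intermediate code. Song et al. use spherical linear interpolation (slerp) for this; here is a minimal plain-Python sketch of slerp itself, with the decoding step left out.

```python
import math

def slerp(v0, v1, alpha):
    """Spherical linear interpolation between two latent vectors, alpha in [0, 1].

    Moves along the great circle between v0 and v1 rather than the straight line."""
    dot = sum(a * b for a, b in zip(v0, v1))
    n0 = math.sqrt(sum(a * a for a in v0))
    n1 = math.sqrt(sum(b * b for b in v1))
    cos_theta = max(-1.0, min(1.0, dot / (n0 * n1)))
    theta = math.acos(cos_theta)
    if theta < 1e-7:  # nearly parallel vectors: fall back to linear interpolation
        return [(1 - alpha) * a + alpha * b for a, b in zip(v0, v1)]
    w0 = math.sin((1 - alpha) * theta) / math.sin(theta)
    w1 = math.sin(alpha * theta) / math.sin(theta)
    return [w0 * a + w1 * b for a, b in zip(v0, v1)]
```

    The motivation for the curved path: when both endpoints have the same norm (as high-dimensional Gaussian samples approximately do), every slerp point keeps that norm, while the straight-line midpoint is noticeably shorter and thus atypical as a starting point for the reverse chain.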

    Conclusion

    Today, we have introduced the basics of diffusion models in machine learning. This field started in 2015, and its basic idea of learning gradual denoising transformations was preserved in later developments: DDPMs made several improvements that made it possible to scale diffusion models up, and DDIMs sped up the generation process and made it deterministic, which opened up a number of new possibilities.

    There is basically only one step left before we get to the cutting edge models such as Stable Diffusion and DALL-E 2. Next time, we will take this step; stay tuned!

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • CLIP and multimodal retrieval: Generative AI IV


    Last time, we discussed DALL-E, a model that brings together text Transformers and a discrete VAE for images. While DALL-E was a huge step forward and generated a lot of buzz for generative AI back in 2021, modern generative models such as DALL-E 2 consist of different components. One of them is usually a multimodal encoder that maps different modalities (e.g., text and images) into the same latent space. Today, we discuss such encoders and then consider a specific practical problem where they have become instrumental over the last couple of years: text-video retrieval, that is, searching for video content by text queries.

    CLIP: Contrastive Language-Image Pretraining

    A model that has proven to be one of the most important for multimodal retrieval is CLIP, introduced by OpenAI in 2021. The basic motivation behind CLIP was to use the data freely available on the Web: text paired with images, i.e., captions such as “a black and white cat” or “Pepper the aussie pup” used in OpenAI’s illustrations (see below).

    The question, of course, is how to use this huge amount of data. The authors of CLIP reported that their first instinct was to train an image CNN and a text Transformer to predict the caption of an image. Transformers are famously good at generating text, but it turned out that this setup learned to recognize ImageNet classes much more slowly (about three times slower, by the authors' account) than a simple baseline that predicts a bag-of-words encoding of the caption. The reason was that predicting the exact caption is very hard (basically hopeless), and it’s not really what we need in this model: we just need good multimodal embeddings.

    Therefore, CLIP switched to the titular idea of contrastive pretraining: we want both text descriptions and the images themselves to map to the same latent space, so let’s use a loss function that brings positive pairs (correct descriptions) closer together and negative pairs (incorrect descriptions) further apart.

    In the picture below, I show the “attractive and repulsive forces” (green and red arrows respectively) that should appear between two image-description pairs, with each pair used as negative samples for the other:

    CLIP takes this idea and runs with it, constructing a whole matrix of similarities (dot products in the latent space) between the embeddings of N images and N corresponding textual descriptions. As a result, we get an N×N matrix where the diagonal corresponds to positive pairs (so diagonal elements should be made larger) and all other elements correspond to negative pairs (so off-diagonal elements should be made smaller).
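    Here is a minimal plain-Python sketch of this symmetric loss, assuming precomputed embeddings for a batch of N pairs. Both sets are L2-normalized (so dot products are cosine similarities) and divided by a temperature; CLIP learns the temperature, and the fixed 0.07 here is only its initial value. Cross-entropy is then applied along rows (image to text) and columns (text to image), with the diagonal as the target.

```python
import math

def l2_normalize(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed stably via log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_z - logits[target]

def clip_loss(image_embs, text_embs, temperature=0.07):
    """Symmetric contrastive loss; image_embs[i] and text_embs[i] are the i-th positive pair."""
    imgs = [l2_normalize(v) for v in image_embs]
    txts = [l2_normalize(v) for v in text_embs]
    n = len(imgs)
    # logits[i][j] = cosine similarity of image i and text j, divided by the temperature
    logits = [[sum(a * b for a, b in zip(im, tx)) / temperature for tx in txts] for im in imgs]
    # image -> text: row i should concentrate on column i
    loss_i = sum(cross_entropy(logits[i], i) for i in range(n)) / n
    # text -> image: column j should concentrate on row j
    loss_t = sum(cross_entropy([logits[i][j] for i in range(n)], j) for j in range(n)) / n
    return (loss_i + loss_t) / 2
```

    Every other pair in the batch serves as a negative for free, which is one reason why contrastive training of this kind benefits from very large batches.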

    Here is the main illustration for this idea from the CLIP paper:

    The encoders are Transformer-based architectures: a text Transformer and, for images, the Vision Transformer (ViT) that breaks an input image into patches and treats embeddings of patches as tokens for the Transformer architecture (the original CLIP paper also trained ResNet image encoders, but its best models use ViT). The margins of this blog post are too narrow to explain Vision Transformers; maybe one day we will have a series about Transformers, what with them being the most important architecture of the last years and all. For now, let’s just assume that ViTs are good at converting images into embeddings, and ViT itself has been a key component of many multimodal architectures; see the original paper by Dosovitskiy et al. for details.

    The original work shows that CLIP is very capable of zero-shot classification: you can turn a class label into a rudimentary query (e.g., “cat” becomes “a photo of a cat”) and get a reasonable classifier by finding nearest neighbors in the joint latent space (image by OpenAI):
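    The zero-shot recipe can be sketched in a few lines of plain Python. Here encode_text is a hypothetical stand-in for a trained text encoder such as CLIP's, the image embedding is assumed to come from the matching image encoder, and the prompt template follows the “a photo of a cat” example above.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def zero_shot_classify(image_emb, class_names, encode_text, template="a photo of a {}"):
    """Turn every class name into a text query, embed it, and return the nearest class.

    encode_text: hypothetical callable mapping a string to an embedding in the joint space."""
    prompts = [template.format(name) for name in class_names]
    scores = [cosine(image_emb, encode_text(p)) for p in prompts]
    best = max(range(len(class_names)), key=lambda i: scores[i])
    return class_names[best], scores
```

    No classifier is trained here: the set of classes is defined entirely by the list of names, so it can be changed at test time.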

    But the main use of CLIP has been for enabling text-image retrieval and generative AI models. Its multimodal latent space proves to be an excellent tool both for finding existing objects and generating new ones (provided you train a decoder for it, of course; the original CLIP has none). In the rest of this post, I will expand on the retrieval part, and we will leave the generative part for future installments. But first, let’s consider an interesting extension of CLIP that, a little unexpectedly, uses our local specialty: synthetic data.

    BLIP: Bootstrapping CLIP with Synthetic Data

    There has been no lack of models that further extended and improved CLIP, although the basic CLIP itself is still very relevant. As a representative model that takes a few steps forward from CLIP, let us consider BLIP (the similar acronym is no accident, of course), developed in 2022 by Li et al.

    One of the central novelties in BLIP is… synthetic data. Yes, you heard right: large datasets of photos and their captions that one can download off the Web seem to be insufficient, not because they are too small (the Web is huge) but rather because they are too noisy. In many cases, even a properly downloaded caption is not informative about the image.

    Therefore, BLIP authors used an automatic captioning model to generate synthetic captions. But you don’t want to just throw away all of your human annotations! Moreover, sometimes synthetic data wins clearly but sometimes the human annotation is much more specific; in the illustration below, T_w is the original human annotation and T_s is the synthetic one:

    Thus, BLIP trains a separate filtering model to distinguish between good and bad captions. Here is how it might work:

    Overall, the filtering process leads to a dataset with several parts, some of them human-annotated and some synthetic, with filtering used to choose the best version in every case:
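    The bootstrapping loop can be sketched as follows, with loud caveats: captioner and match_score are hypothetical callables standing in for BLIP's captioning and filtering models, and the threshold is an illustrative parameter. In this sketch, as in BLIP to our understanding, every caption that passes the filter is kept, so an image may end up with both its web caption and a synthetic one, or with neither.

```python
def capfilt(web_pairs, captioner, match_score, threshold=0.5):
    """Sketch of BLIP-style caption bootstrapping with filtering.

    web_pairs: list of (image, web_caption) pairs.
    captioner(image): hypothetical model producing a synthetic caption.
    match_score(image, caption): hypothetical filter returning the probability
    that the caption matches the image.
    """
    cleaned = []
    for image, web_caption in web_pairs:
        # both the original caption and the synthetic one go through the same filter
        for caption in (web_caption, captioner(image)):
            if match_score(image, caption) >= threshold:
                cleaned.append((image, caption))
    return cleaned
```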

    Apart from the data, BLIP also extends the model itself. CLIP had a text encoder and an image encoder, and used contrastive pretraining only. BLIP expands on it with multitask pretraining via three encoders and a decoder:

    • a ViT for images, similar to CLIP, with its output used in three different ways;
    • a Transformer-based text encoder trained with the same image-text contrastive loss (ITC in the image below) with the ViT image embeddings;
    • an image-grounded text encoder, implemented as a Transformer text encoder with additional cross-attention layers that receive the image embeddings as input; here the objective is again to classify correct vs. incorrect text-image pairs, but as a regular binary classification (image-text matching, ITM) rather than a contrastive loss;
    • finally, a separate Transformer text decoder is trained to generate text captions for images, with a language modeling loss that teaches it to produce correct captions for images whose embeddings are fed into its cross-attention layers.

    Here is an illustration from the BLIP paper:

    Retrieval in the Latent Space: Text-Video Retrieval

    So how do we use all of these models for retrieval? The basic idea is very simple: once you have good multimodal embeddings, you can map the query to the same space and find nearest neighbors. Something like this:
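    In code, the basic idea is just a ranking by similarity. Here is a minimal plain-Python sketch, assuming the gallery (a dict from item ids to embeddings) and the query were embedded by the same multimodal encoder; the file names are purely illustrative.

```python
import heapq
import math

def normalize(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def retrieve(query_emb, gallery, k=5):
    """Return ids of the k gallery items with the highest cosine similarity to the query."""
    q = normalize(query_emb)
    scored = ((sum(a * b for a, b in zip(q, normalize(e))), item) for item, e in gallery.items())
    return [item for _, item in heapq.nlargest(k, scored)]
```

    This brute-force scan is linear in the gallery size; a real system would precompute gallery embeddings and put them into an approximate nearest-neighbor index.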

    But this is only the very first inkling of an idea, and it needs a lot of fleshing out to become practical. In this post, I cannot hope to review the entire field of multimodal retrieval, so, as a relevant and practical example, let us work through some of the models for text-video retrieval, i.e., searching for videos by text prompts.

    As a first example, now probably only of historical interest, let’s consider the S2VT model, originally developed for video captioning (producing text descriptions for video) but also possible to use for retrieval: this is a common trend for many models that simply map everything into a common latent space. Here is what S2VT looks like:

    This is the archetypal “early deep learning” approach, similar to, e.g., “Show and Tell”: you have a recurrent network for text and a convolutional network for video frames, they extract features and map everything into a common latent space.

    Another important trend that started quite early is considering hierarchical representations for both text and video. Both modalities are hierarchical in nature: a (detailed) text caption can be broken down into paragraphs, and the latter into sentences, while a video naturally consists of scenes and/or frames, and one can find objects on these frames.

    An early example of this approach was shown by Zhang et al. (2018). Their hierarchical sequence embedding (HSE) model includes separate intermediate loss functions that align sentence-level embeddings for text and clip-level embeddings for videos:

    But the whole field changed when Transformers were introduced, or, to be more precise, when Transformers were applied to images in Vision Transformers. Let’s see how!

    Transformer-Based Text-Video Retrieval

    How can Transformers help retrieval? First, there is the direct approach: we have CLIP that maps text and images into the same space; let’s just extend CLIP to videos by representing them as a sequence of frames. This simple idea has been implemented in one of the first but already quite strong modern baselines for video retrieval, the CLIP4Clip model (Luo et al., 2022).

    The only question here is how to break a video down into frames. Naturally, we don’t need all frames; in fact, CLIP4Clip and similar models usually sample just a few frames, like 4 or 8, and almost none of them try anything fancy to find representative frames: it’s usually just uniform sampling (in my opinion, this is a natural place for a potential improvement). After sampling, we still have a sequence of frame embeddings (albeit a short one), and one can aggregate these embeddings in different ways. CLIP4Clip studies several such possibilities:
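    The simplest, parameter-free option is to average the frame embeddings. Here is a plain-Python sketch of uniform frame sampling followed by mean pooling; encode_image is a hypothetical stand-in for a CLIP image encoder, and taking the center of each equal-length segment is just one reasonable way to sample "uniformly".

```python
import math

def uniform_frame_indices(num_frames, num_samples):
    """Evenly spaced frame indices: the center of each of num_samples equal segments."""
    step = num_frames / num_samples
    return [min(num_frames - 1, int(step * (i + 0.5))) for i in range(num_samples)]

def video_embedding(frame_embs):
    """Parameter-free aggregation: the mean of the frame embeddings."""
    dim = len(frame_embs[0])
    return [sum(f[d] for f in frame_embs) / len(frame_embs) for d in range(dim)]

def text_video_score(text_emb, frames, encode_image, num_samples=8):
    """Cosine similarity between a text embedding and a mean-pooled video embedding."""
    idx = uniform_frame_indices(len(frames), num_samples)
    v = video_embedding([encode_image(frames[i]) for i in idx])
    dot = sum(a * b for a, b in zip(text_emb, v))
    return dot / (math.sqrt(sum(a * a for a in text_emb)) * math.sqrt(sum(b * b for b in v)))
```

    Mean pooling ignores frame order entirely; the other aggregation variants replace the average with learned sequence models over the frame embeddings.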

    And with that, we are basically at the state of the art level. It only remains to combine all of the ideas we have already discussed.

    LiteVL (Chen et al., 2022) does approximately the same thing but replaces CLIP with BLIP that we also discussed above. The main novel idea here is to use additional temporal attention modules and text-dependent pooling that make it possible to adapt to video-language tasks starting from a pretrained image-text BLIP. As a result, it has more loss functions:

    DRLTVR (Wang et al., 2022), where DRL stands for “disentangled representation learning”, is interesting for its very detailed study of different forms of cross-modal interaction in text-video retrieval. They consider six different ways to combine text and video representations to obtain a relevance score for retrieval and propose two new important ideas. First, a more fine-grained cross-modal interaction mechanism based on (possibly weighted) token-wise interactions, i.e., basically a cross-attention matrix between sentence tokens and video frame tokens:
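    Here is a plain-Python sketch of an unweighted token-wise interaction score, under the assumption that all token embeddings are already L2-normalized: every word is matched to its most similar frame, every frame to its most similar word, and the two averages are combined. The weighted variant mentioned above would replace the plain averages with learned per-token weights; this sketch is an illustration of the idea, not DRLTVR's exact implementation.

```python
def token_wise_score(word_embs, frame_embs):
    """Unweighted token-wise interaction between a sentence and a video.

    Assumes L2-normalized embeddings, so dot products are cosine similarities."""
    # similarity matrix between sentence tokens (rows) and frame tokens (columns)
    sim = [[sum(a * b for a, b in zip(w, f)) for f in frame_embs] for w in word_embs]
    # each word picks its best frame
    text_to_video = sum(max(row) for row in sim) / len(word_embs)
    # each frame picks its best word
    video_to_text = sum(max(sim[i][j] for i in range(len(word_embs)))
                        for j in range(len(frame_embs))) / len(frame_embs)
    return (text_to_video + video_to_text) / 2
```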

    Second, a channel decorrelation regularization mechanism that minimizes the redundancy in learned representation vectors and thus helps to learn a hierarchical representation:

    And finally, everything comes together in the very recently published Tencent Text-Video Retrieval (Jiang et al., 2022). It has a hierarchical representation structure with frame-word, clip-phrase, and video-sentence alignments:

    Combined with a few more tricks related to adaptive label denoising and marginal sample enhancement (choosing the hardest text sample for a video), this has allowed Tencent Text-Video Retrieval to produce state of the art results.

    I also want to note that this improvement in text-video state of the art is far from squeezing the last 0.1% out of beaten datasets. For example, let us consider the Recall@1 metric on the classical MSRVTT-7K dataset, that is, how often in its test set the model retrieves a correct result at the very top:

    • a very simple zero-shot baseline in ActBERT yields a mere 8.6%;
    • good classical models such as HTM achieve about 15% Recall@1;
    • CLIP4Clip jumps up to over 40%, with its best version reaching 42.1%;
    • the best variation of LiteVL achieves 48.9%;
    • the best variation of DRLTVR has Recall@1 of 53.3%;
    • and finally, Tencent Text-Video Retrieval sits at the top with 62.9%.
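    For concreteness, here is how Recall@K can be computed from a square similarity matrix in which query i's correct video is video i. This is a plain-Python sketch; ties are counted in the model's favor here, which is one of several possible conventions.

```python
def recall_at_k(sim, k=1):
    """Text-to-video Recall@K.

    sim[i][j] is the score of video j for text query i; video i is the correct answer.
    Returns the fraction of queries whose correct video is ranked among the top k."""
    hits = 0
    for i, row in enumerate(sim):
        # rank of the correct video = number of videos scored strictly higher
        rank = sum(1 for s in row if s > row[i])
        hits += rank < k
    return hits / len(sim)
```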

    Even the most recent improvements are huge, and there is clearly still a lot of headroom left!

    Conclusion

    Today, our main intention has been to discuss multimodal encoders such as CLIP and BLIP that map different objects—mostly text and images—into the same latent space. However, after that we have taken a detour into text-video retrieval as a very practical sample task where such models are used almost directly: using CLIP directly with a few reasonable tricks has led to huge improvements.

    Next time, we will consider another key component of modern generative AI: diffusion-based models. In the next installment, we will discuss their main ideas and some of the underlying math (but definitely not the whole thing!), and then it will be just one more step to Stable Diffusion and its kin.

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • How DALL-E Creates AI-Generated Art: Generative AI III


    Today, we continue our discussion of generative AI, a direction that keeps transforming many different industries. Last time, we reviewed the difference between continuous and discrete latent spaces, and how the VQ-VAE architecture (based on variational autoencoders that we discussed before) manages to learn a discrete latent space, a codebook in the continuous latent space. Today, we will put this idea into further practice with our first real text-to-image model, OpenAI’s DALL-E.

    General Structure of DALL-E

    In previous posts, we have discussed the main ideas that, taken together, have led to OpenAI’s DALL-E, the first text-to-image model that actually impressed everyone not only in the AI community but in the wider world. DALL-E put image generation by text prompts on the map of the world’s media, and I would say that the current hype wave of generative AI models for images started in earnest with DALL-E (although current models are, of course, much better than the original DALL-E). But what is it, exactly, and how does it work?

    Let us begin with the general structure of DALL-E. We already know almost all of the components from previous posts, so we can start with the big picture:

    The main idea is to train a Transformer-based model to generate tokens that comprise the latent code of a discrete VAE such as the one we discussed in the previous post. Discrete latent spaces converge here with the Transformers’ main forte: learning to continue sequences of tokens. 

    We only need to train a GPT-like model to generate the latent code as a sequence of special tokens that would continue a text description: “Cat playing chess in the British countryside [IMG] #c100 #c089 #c004 …”. Then we can run it with a text query followed by the special token “[IMG]”, and supposedly it will produce a sequence of latent codes for the discrete VAE. Naturally, this will require us to retrain (or fine-tune) a text Transformer on (image, text) pairs encoded in this way.
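    Here is a plain-Python sketch of how such a combined token stream could be built. The vocabulary sizes are the ones reported for DALL-E (16,384 BPE text tokens and an 8,192-entry image codebook); the explicit [IMG] separator id and the function names are hypothetical, following the illustrative example above. Image codes are shifted past the text vocabulary so that one embedding table and one softmax can cover both kinds of tokens.

```python
TEXT_VOCAB = 16384   # BPE vocabulary size reported for DALL-E
IMAGE_VOCAB = 8192   # discrete VAE codebook size reported for DALL-E
IMG_TOKEN = TEXT_VOCAB + IMAGE_VOCAB  # hypothetical id for the "[IMG]" separator

def build_sequence(text_ids, image_codes):
    """Concatenate text tokens, the separator, and shifted image codes into one stream."""
    if not all(0 <= t < TEXT_VOCAB for t in text_ids):
        raise ValueError("text token out of range")
    if not all(0 <= c < IMAGE_VOCAB for c in image_codes):
        raise ValueError("image code out of range")
    return list(text_ids) + [IMG_TOKEN] + [TEXT_VOCAB + c for c in image_codes]

def split_image_codes(sequence):
    """Recover discrete VAE codes from a (generated) sequence: everything after [IMG]."""
    start = sequence.index(IMG_TOKEN) + 1
    return [t - TEXT_VOCAB for t in sequence[start:]]
```

    At generation time, one would feed the text tokens followed by IMG_TOKEN, let the Transformer continue the sequence, and pass the recovered codes to the discrete VAE's decoder.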

    Formally speaking, DALL-E is a generative model that needs to learn the joint distribution

        \[p_{\mathbf{\theta},\mathbf{\psi}}(\mathbf{x},\mathbf{y},\mathbf{z})= p_{\mathbf{\theta}}(\mathbf{x} | \mathbf{y},\mathbf{z})p_{\mathbf{\psi}}(\mathbf{y},\mathbf{z}),\]

    where \mathbf{x} is an image, \mathbf{y} is the corresponding text description, and \mathbf{z} is the image’s latent code. The Transformer learns to generate \mathbf{z} from \mathbf{y} (actually, learns the entire p(\mathbf{y},\mathbf{z}) since it inevitably becomes a generative model for text as well), and the result is used by the discrete VAE, so actually we assume that p_{\mathbf{\theta}}(\mathbf{x} | \mathbf{y},\mathbf{z})=p_{\mathbf{\theta}}(\mathbf{x} | \mathbf{z}).

    From the mathematical point of view, DALL-E actually optimizes a huge variational lower bound

        \[\log p_{\mathbf{\theta},\mathbf{\psi}}(\mathbf{x},\mathbf{y}) \ge \mathbb{E}_{\mathbf{z}\sim q_{\mathbf{\phi}}(\mathbf{z}|\mathbf{x})}\left[ \log p_{\mathbf{\theta}}(\mathbf{x}|\mathbf{y},\mathbf{z}) - \beta\mathrm{KL}\left(q_{\mathbf{\phi}}(\mathbf{z}|\mathbf{x})\| p_{\mathbf{\psi}}(\mathbf{y},\mathbf{z})\right)\right],\]

    where the distributions in this formula correspond to different parts of the model:

    • q_{\mathbf{\phi}}(\mathbf{z}|\mathbf{x}) is the distribution of latent codes produced by the discrete VAE’s encoder from an image \mathbf{x}; \mathbf{\phi} here denotes the parameters of the discrete VAE’s encoder;
    • p_{\mathbf{\theta}}(\mathbf{x}|\mathbf{y},\mathbf{z}) is the distribution of images generated by the discrete VAE’s decoder from a latent code \mathbf{z} and text description \mathbf{y}; again, here we assume that p_{\mathbf{\theta}}(\mathbf{x}|\mathbf{y},\mathbf{z})=p_{\mathbf{\theta}}(\mathbf{x}|\mathbf{z}); \mathbf{\theta} stands for the parameters of the discrete VAE’s decoder;
    • p_{\mathbf{\psi}}(\mathbf{y},\mathbf{z}) is the joint distribution of texts and latent codes modeled by the Transformer; here \mathbf{\psi} denotes the Transformer’s parameters.

    We will not go into the details of variational inference or derive the inequality shown above in full; this is a very important topic in machine learning, but we do not have the space to do it justice here. After the derivation, though, it all boils down to a very understandable two-stage process:

    • first, we maximize the bound with respect to \mathbf{\phi} and \mathbf{\theta}, that is, train the discrete VAE with a dataset of images; the texts are not used here, we assume that p_{\mathbf{\psi}}(\mathbf{y},\mathbf{z}) is uniform and relax q_{\mathbf{\phi}}(\mathbf{z}|\mathbf{x}) via the Gumbel-Softmax trick;
    • then we fix \mathbf{\phi} and \mathbf{\theta} (fix the discrete VAE) and learn \mathbf{\psi}, i.e., train the Transformer to jointly model both text (in BPE encoding) and image codes \mathbf{z}.

    At this point, we are done with the general structure of DALL-E. But, alas, to get the full picture we need to return to discrete variational autoencoders because DALL-E uses a slightly different breed of those than VQ-VAE and VQ-GAN we previously discussed.

    Discrete VAE with the Gumbel-Softmax trick

    We have seen two different discrete VAEs in the previous post: VQ-VAE that introduced the idea of discrete latent codes and VQ-GAN that added a patch-based discriminator to further improve things. But both of them had a middle part that feels pretty ad hoc to me, and hopefully to you as well by now: to move gradients through the discrete latent space they had to go around the discrete part with a copy-gradient operation.

    The discrete VAE used in DALL-E takes the next step: instead of outputting a latent vector that is then “rounded” to a codebook vector, it outputs a whole probability distribution over the codebook, probabilities for a “die” that then can be rolled to determine the actual vector:

    This is exactly parallel to the idea used in VAEs: we output a distribution in the latent space and thus obtain additional regularization and make the resulting model better.

    So now instead of the VQ-VAE problem—how do we put gradients through taking nearest neighbors—we have a different problem: how do we put gradients through rolling a die? Fortunately, we already have a hint: we solved the exact same problem for Gaussians with the reparametrization trick in regular VAEs! The trick was to generate a random sample from a standard Gaussian distribution and then apply a deterministic (and differentiable) linear transformation to change it into a sample from the needed Gaussian.
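    As a reminder of how the Gaussian version works, here is a tiny plain-Python sketch; parametrizing the scale by a log-variance is a common VAE convention and an assumption of this sketch. All randomness is in eps, so z is a deterministic, differentiable function of mu and log_var.

```python
import math
import random

def reparametrized_gaussian(mu, log_var, rng):
    """Sample z ~ N(mu, exp(log_var)) as z = mu + sigma * eps with eps ~ N(0, 1)."""
    eps = [rng.gauss(0.0, 1.0) for _ in mu]
    return [m + math.exp(0.5 * lv) * e for m, lv, e in zip(mu, log_var, eps)]
```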

    The distribution is different now, but the trick is the same. We need to first sample from a fixed distribution and then apply a transformation to get the die roll with given probabilities. The fixed distribution in question is actually quite interesting: it is the Gumbel distribution whose density and cumulative distribution function are defined as

        \[p(g_i) = e^{-\left(g_i + e^{-g_i}\right)},\qquad F(g_i) = e^{-e^{-g_i}}.\]

    In statistics, the Gumbel distribution appears as the limiting distribution of the maximum (or minimum) of a large number of samples, but, to be honest, I have never encountered the Gumbel distribution in any practical context other than this reparametrization trick.

    Anyway, the important part is that once you have sampled independent g_i from the Gumbel distribution defined above, one for each outcome i, you can get a sample from a discrete distribution with probabilities \pi_i (the result of a die roll) as

        \[z = \arg\,\max_i\left(g_i + \log \pi_i\right).\]

    The proof of this fact, known as the Gumbel-Max trick, is a straightforward but somewhat tedious calculation, so I’ll skip it or, to put it in a slightly more stylish way, leave it as an exercise for the reader.
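    The trick is easy to check empirically. In this plain-Python sketch, Gumbel samples come from inverting the CDF above (solving u = e^{-e^{-g}} gives g = -log(-log u) for u uniform in (0, 1)), and the die is rolled as argmax_i(g_i + log π_i); all probabilities are assumed to be strictly positive.

```python
import math
import random

def sample_gumbel(rng):
    """Inverse-CDF sampling from the standard Gumbel distribution."""
    u = rng.random()
    while u == 0.0:  # random() may return 0.0, where log is undefined
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_max_sample(probs, rng):
    """Roll a die with probabilities probs via z = argmax_i (g_i + log pi_i)."""
    scores = [sample_gumbel(rng) + math.log(p) for p in probs]
    return max(range(len(probs)), key=lambda i: scores[i])
```

    Over many rolls, the empirical frequencies of gumbel_max_sample([0.5, 0.3, 0.2], rng) approach 0.5, 0.3, and 0.2.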

    Even with the Gumbel-Max trick, though, we are not done yet. We have gone from sampling to argmax, but that is still not what we need: the argmax operation is also not good for passing gradients since it is piecewise constant; in fact, in VQ-VAE we had exactly the same problem, with argmin for nearest neighbors, and had to resort to copying the gradients.

    This time, though, we don’t have to. Since the argmax here corresponds to die rolling, it makes perfect sense to relax it to softmax:

        \[y_i = \mathrm{softmax}\left(\frac{1}{\tau}\left(g_i + \log\pi_i\right)\right) = \frac{e^{\frac{1}{\tau}\left(g_i + \log\pi_i\right)}}{\sum_je^{\frac{1}{\tau}\left(g_j + \log\pi_j\right)}}.\]

    For \tau\to 0, the weights y_i become one-hot, with the single nonzero weight landing on outcome i with probability \pi_i, exactly as in the Gumbel-Max trick; during training, we can gradually reduce the temperature \tau. Note that now the result is not a single codebook vector but a linear combination of codebook vectors with weights y_i.
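    Here is a plain-Python sketch of the relaxed die roll and of the resulting mix of codebook vectors. The logits are taken to be log π_i, and the function names and toy codebook are illustrative; a real implementation would run this on tensors inside an autodiff framework so that gradients flow through y.

```python
import math
import random

def gumbel_softmax(logits, tau, rng):
    """Relaxed die roll: softmax((g_i + logits_i) / tau) with independent Gumbel g_i."""
    scores = []
    for l in logits:
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        scores.append((l - math.log(-math.log(u))) / tau)
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def soft_codebook_vector(y, codebook):
    """Latent vector as a linear combination of codebook vectors with weights y_i."""
    dim = len(codebook[0])
    return [sum(w * vec[d] for w, vec in zip(y, codebook)) for d in range(dim)]
```

    At temperature 1 the weights are spread over several codebook entries; at a temperature close to zero, almost every sample is essentially one-hot and selects a single codebook vector.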

    Overall, we have the following scheme in our discrete VAE:

    And with that, we are done with DALL-E! It remains to see how well it works.

    DALL-E: Results and Reception

    DALL-E debuted at the very beginning of 2021. This was a perfect New Year’s present for all kinds of AI-related media because DALL-E was indeed a big step forward from what researchers had been able to do before. Images from the DALL-E OpenAI post and paper were hyped all across the Web; images like this one:

    Or this one:

    It already looked like these images could be useful in practice, and discussions about “replacing the illustrators” began. DALL-E was also able to use image prompts (parts of the resulting image that should be preserved) that could define the style and overall feel of the result.

    DALL-E seemed to have a rather deep understanding of our reality that it could put into pictures. For example, the next illustration shows several image prompts and a text prompt that asks DALL-E to show how telephones looked at different stages of their development:

    Although the quality of the images themselves may be underwhelming for those who have already seen Stable Diffusion and Midjourney, it was really head and shoulders above anything other available solutions could produce, and it was quite a shocking piece of news for many AI researchers, including yours truly.

    It was clear that it would be only a matter of time before DALL-E would be scaled up to high-definition images (the original DALL-E produced 256×256 results) and made even more “understanding” of reality with larger Transformer-based text models. That is indeed what happened, and the world we live in today is being increasingly transformed by both large language models and large image generation models.

    Still, many new ideas appeared along the way, and we cannot say that DALL-E 2 is just “DALL-E with more layers”. That’s why our series of posts is far from the end, and modern generative AI has a lot more to teach us.

    Conclusion

    Today, we have discussed DALL-E, a model released in January 2021. A mere two years have passed, but it looks like DALL-E is already hopelessly outdated. New models that visibly advance the state of the art for image generation appear every few months, and this progress does not seem to be slowing down. Don’t worry though, the ideas behind DALL-E are still sound and useful, and this has been my primary ambition in this series: explain the ideas, the how rather than the what.

    However, to get to the current state of the art we need more ideas. So next time, we will take a brief detour from generation and talk about models that produce multimodal latent spaces, such as OpenAI’s CLIP (Contrastive Language-Image Pre-Training) and its successors. They are extremely useful for, e.g., multimodal retrieval (searching for images and videos), but they also serve as the basis for further generative models. Until next time!

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • Discrete Latent Spaces: Generative AI II

    Discrete Latent Spaces: Generative AI II

    Last time, we discussed one of the models that have made modern generative AI possible: variational autoencoders (VAE). We reviewed the structure and basic assumptions of a VAE, and by now we understand how a VAE makes the latent space more regular by using distributions instead of single points. However, the variations of VAE most often used in modern generative models are a little different: they use discrete latent spaces with a fixed vocabulary of vectors. Let’s see what that means and how it can help generation!

    Continuous and Discrete Latent Spaces

    We have already discussed latent spaces in both the introductory post and the post on variational autoencoders, but this time we take a slightly different spin. In general, an autoencoder compresses the high-dimensional input (say, an image) into a low-dimensional representation, i.e., into a relatively short vector of numbers whose dimension is in the hundreds rather than millions:

    If the autoencoder is designed well, this may result in a latent space where certain directions correspond to specific properties of the image. For instance, if we are compressing cat images then one axis may correspond to the cat’s color and another to the overall style of the image:

    Naturally, in reality these directions would not necessarily coincide with coordinate axes and may be hard to find. A regular autoencoder architecture (say, a VAE) has no built-in preference for a latent space with well-defined directions. In fact, it is easy to see that the latent space may undergo rather complicated transformations with no change in model complexity: there is usually no difference between learning an encoder Enc and decoder Dec and learning an encoder f∘Enc and decoder Dec∘f⁻¹ for some invertible transformation f.
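    Here is a toy sketch in plain Python that makes the point about invertible transformations concrete. The linear 2-D maps below are illustrative stand-ins for the encoder and decoder networks; they are assumptions, not part of any real model. Composing the encoder with a rotation f and the decoder with f⁻¹ changes the latent code but leaves the reconstruction exactly the same:

```python
def matvec(m, v):
    """Multiply a 2x2 matrix (list of rows) by a 2-vector."""
    return [m[0][0] * v[0] + m[0][1] * v[1],
            m[1][0] * v[0] + m[1][1] * v[1]]

def enc(x):
    # Hypothetical linear "encoder".
    return matvec([[1.0, 0.5], [0.0, 2.0]], x)

def dec(z):
    # Exact inverse of enc, so reconstruction is perfect.
    return matvec([[1.0, -0.25], [0.0, 0.5]], z)

# An invertible transformation f of the latent space (rotation by 90 degrees)
# and its inverse f^{-1} (rotation by -90 degrees).
F = [[0.0, -1.0], [1.0, 0.0]]
F_INV = [[0.0, 1.0], [-1.0, 0.0]]

def enc_f(x):
    return matvec(F, enc(x))        # f ∘ Enc

def dec_f(z):
    return dec(matvec(F_INV, z))    # Dec ∘ f^{-1}

x = [3.0, -1.0]
print(enc(x), enc_f(x))             # [2.5, -2.0] [2.0, 2.5]: different latent codes
print(dec(enc(x)), dec_f(enc_f(x)))  # [3.0, -1.0] [3.0, -1.0]: same reconstruction
```

    A reconstruction loss cannot tell these two solutions apart. That is why training alone gives the model no reason to prefer axis-aligned, interpretable directions.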

    This is an appealing picture, but it’s not as easy to obtain, and, moreover, it’s not really how we think about styles and picture descriptions. We are verbal creatures, and when I want to get a picture of a black cat I don’t have a real number associated with its “blackness”, I just want a cat with a discrete “black” modifier, just like I might want a “white” modifier. A black and white cat for me is not a number that reflects the percentage of white hair but most probably just a separate “black and white” modifier that turns out to be in a rather complex relationship with the “black” and “white” modifiers.

    Can we try to reflect this intuition in an autoencoder latent space? We could imagine a latent space that has a vocabulary of “words” and decodes combinations of these words into images. Something like this:

    This looks much more “human-like”, and the last few years of generative AI have indeed proven this approach to be significantly more fruitful. Its best feature is the ability to use autoregressive generative models for discrete latent representations. For example, the famous Transformers, in particular the GPT family, can be applied to produce latent space “words” just as well as they produce real words in their natural language applications, but they would be much harder to adapt to components of continuous latent vectors.

    But the discrete latent space approach comes with its own set of problems, both technical and conceptual. In the rest of this post, we will go through two models that successfully solved these problems and thus became foundational for modern generative AI.

    VQ-VAE: Vector Quantized VAE

    The first model that successfully managed to construct a discrete latent space at a scale sufficient for general-purpose images was Vector Quantized VAE (VQ-VAE) introduced back in 2017 by DeepMind researchers (van den Oord et al., 2017). Its basic idea is exactly as we have discussed: VQ-VAE finds a finite vocabulary (codebook) and encodes images as fixed sets (tensors) of discrete codes:

    It turns out that it’s not a good idea to make the encoder do actual classification over the codebook vectors. Instead, here is what happens:

    • the encoder, as usual, takes an image x as input and produces a set of latent vectors Enc(x); a slight difference from our previous setting is that now the encoder produces a whole set of vectors (usually formalized as a three-dimensional tensor, i.e., a matrix of vectors), but mathematically this is equivalent to slicing a single output vector;
    • for every latent vector, we find the nearest codebook vector in the latent space and replace it with this codebook vector; the resulting code consists only of codebook vectors;
    • the decoder receives as input the tensor of vectors, with the same dimensions as the encoder had produced, but actually the latent code is now discrete: while each component of the latent code is still a continuous vector of real numbers, there’s now only a finite number of possibilities for each of the vectors.
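    The quantization step can be sketched in plain Python as follows. The codebook and the encoder outputs are made-up numbers. A real implementation would compute all distances for a whole tensor at once on a GPU:

```python
import math

def nearest_code(z, codebook):
    """Return (index, vector) of the codebook entry closest to z in L2 distance."""
    best = min(range(len(codebook)), key=lambda k: math.dist(z, codebook[k]))
    return best, codebook[best]

def quantize(latents, codebook):
    """Replace every latent vector with its nearest codebook vector.

    `latents` is a flat list of vectors (the encoder's grid flattened);
    returns the discrete indices and the quantized vectors z_q.
    """
    indices, quantized = [], []
    for z in latents:
        k, e = nearest_code(z, codebook)
        indices.append(k)
        quantized.append(e)
    return indices, quantized

# Hypothetical 2-D codebook with 4 entries and three encoder outputs Enc(x).
codebook = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
z_e = [[0.1, -0.2], [0.9, 0.8], [0.2, 0.7]]
idx, z_q = quantize(z_e, codebook)
print(idx)  # [0, 3, 2]
```

    The decoder would then receive `z_q`, which has exactly the same shape as `z_e` but can take only finitely many values.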

    Here’s an illustration (I only show how one vector is chosen but the procedure is the same for each of them):

    At this point, some readers might be wondering: there is now only a finite number of latent codes in total! There is no boundless generation, characteristic of natural languages where we can have texts as long as we like. Won’t that severely limit the models? Well, a realistic size of the latent code tensor is something like 32×32 with, say, 8192 codebook vectors (the numbers are taken from the original DALL-E model). There are two ways to look at these numbers. On the one hand, since 8192 = 2¹³, this amounts to 8192³²ˣ³² = 2¹³³¹² possibilities, while the number of atoms in the Universe is less than 2³⁰⁰, so it looks like we are covered. On the other hand, this is equivalent to compressing every possible image of size 256×256 (the dimensions of the original DALL-E) into 13312 bits, i.e., about 1.6 kilobytes of data, which means that we will need quite a compression tool. Both views are valid: modern autoencoder-based models are indeed very impressive in their ability to compress images into latent representations, but the diversity of their outputs does not bound our imagination too much.
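    The size of the latent code space is easy to check directly. This small calculation uses the DALL-E numbers quoted above (8192 codebook vectors on a 32×32 grid). For the compression comparison it assumes an uncompressed 8-bit RGB 256×256 input:

```python
import math

codebook_size = 8192   # codebook vectors (the DALL-E numbers quoted above)
positions = 32 * 32    # positions in the latent code tensor

bits_per_position = int(math.log2(codebook_size))  # 13, since 8192 = 2**13
total_bits = bits_per_position * positions         # log2 of the number of codes
print(total_bits, total_bits / 8)                  # 13312 1664.0 (bits, bytes)

# Exact check with Python's arbitrary-precision integers.
assert codebook_size ** positions == 2 ** total_bits

raw_bytes = 256 * 256 * 3   # uncompressed 8-bit RGB 256x256 image
print(round(raw_bytes / (total_bits / 8)))         # 118: compression ratio
```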

    Two questions remain. First, how do we train the encoder and decoder networks? It looks like we have the same problem as VAE had: just like VAE had sampling in the middle, VQ-VAE has a piecewise constant operation (taking the nearest neighbor) whose gradient is zero almost everywhere, so no useful gradient can flow back through it. And second, how do we learn the codebook vectors?

    At this point, I will show a picture from the original VQ-VAE paper; it always comes up in these discussions, and we need its notation to discuss the VQ-VAE objective, so you need to see it too:

    This picture mostly illustrates the idea of a discrete latent space with codebook vectors that we have already discussed. But it also shows the solution to the first problem: VQ-VAE simply copies the gradients (red line) from the decoder to the encoder, that is, the gradient of the loss function with respect to the tensor of codebook vectors is assumed to be its gradient with respect to Enc(x). This is an approximation, of course, and one can remove it with a more involved construction (a discrete VAE with the Gumbel-Softmax trick that we will explain in a later post on DALL-E), but for now it will have to do.
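    Here is a scalar sketch of this gradient-copying trick, often called the straight-through estimator. Everything in it is made up for illustration:

    - a one-dimensional codebook;
    - a linear encoder z_e = w·x;
    - a linear decoder x̂ = v·z_q with a squared-error loss.

    The gradients are written out by hand rather than computed with an autodiff library:

```python
def quantize_scalar(z, codebook):
    """Nearest codebook value for a scalar latent (piecewise constant in z)."""
    return min(codebook, key=lambda e: abs(z - e))

codebook = [-1.0, 0.0, 1.0]   # hypothetical 1-D codebook
w, v = 0.8, 1.5               # hypothetical encoder and decoder weights
x = 1.0                       # a single scalar "image"

def loss_from_ze(z_e):
    """Squared reconstruction error as a function of the encoder output."""
    x_hat = v * quantize_scalar(z_e, codebook)
    return (x_hat - x) ** 2

z_e = w * x
z_q = quantize_scalar(z_e, codebook)
x_hat = v * z_q

# True derivative through the quantizer (finite differences): zero here.
eps = 1e-4
fd_grad = (loss_from_ze(z_e + eps) - loss_from_ze(z_e - eps)) / (2 * eps)

# Straight-through: compute dL/dz_q and simply reuse it as dL/dz_e.
dL_dzq = 2 * (x_hat - x) * v
dL_dze = dL_dzq               # the copied gradient
dL_dw = dL_dze * x            # chain rule into the encoder weight

print(fd_grad)   # 0.0: the quantizer blocks the true gradient
print(dL_dw)     # 1.5: the encoder still gets a learning signal
```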

    As for the second problem, it brings us to the VQ-VAE training objective. Here is the loss function as defined by van den Oord et al. (2017):

        \[\mathcal{L} = -\log p(\mathbf{x}|\mathbf{z}_q(\mathbf{x})) + \left\|\mathrm{sg}\left[\mathbf{z}_e(\mathbf{x})\right]-\mathbf{e}\right\|_2^2 + \beta\left\|\mathbf{z}_e(\mathbf{x}) - \mathrm{sg}\left[\mathbf{e}\right]\right\|_2^2.\]

    This formula sure begs for a detailed explanation. Let’s first go through all the notation step by step and then summarize:

    • \mathbf{z}_e(\mathbf{x}) and \mathbf{z}_q(\mathbf{x}) are two latent representations of an image \mathbf{x}: \mathbf{z}_e is the output of the encoder and \mathbf{z}_q is the codebook representation obtained by replacing each vector with its nearest codebook neighbor (this notation is illustrated in the image above); the first term is responsible for training the decoder network;
    • p(\cdot|\mathbf{z}) is the distribution of reconstructed images after the decoder given the latent code; we want the reconstruction to be good so we maximize the likelihood of the original image \mathbf{x} given its latent code \mathbf{z}_q(\mathbf{x}) that serves as input for the decoder;
    • \mathrm{sg}[\cdot] is the stop-gradient operator; it is defined as the identity during the forward pass (when we compute the objective function \mathcal{L}) and has zero partial derivatives during the backward pass (when we compute the gradient \nabla_{\mathbf{w}}\mathcal{L}), so its argument is treated as a constant;
    • therefore, the second term means that we want to bring each codebook vector \mathbf{e} closer to the latent codes \mathbf{z}_e(\mathbf{x}) that choose it as their nearest neighbor; this term is responsible for training the codebook;
    • the third term is the opposite: it pulls the encoder outputs \mathbf{z}_e(\mathbf{x}) closer to their corresponding codebook vectors; in effect, the second and third terms together do a kind of clustering of latent codes \mathbf{z}_e(\mathbf{x}) around their corresponding codebook vectors; the hyperparameter \beta balances the two terms, although the authors report that the results do not change for \beta anywhere from 0.1 to 2.0;
    • finally, the encoder network is trained with the first term (through the gradient copying described above) and the third term, where it appears in the form of \mathbf{z}_e(\mathbf{x}).
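    To see what the stop-gradients actually do, here is a hand-written sketch for a single encoder output z_e and its chosen codebook vector e. It computes the forward values of the second and third terms and the gradient that each term sends to e and to z_e. The default β = 0.25 is just an illustrative value within the range mentioned above:

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def vq_terms(z_e, e, beta=0.25):
    """Forward values and hand-derived gradients of the two VQ-VAE latent terms.

    sg[.] is the identity in the forward pass, so both terms have the same
    value up to beta; stop-gradients only change where the gradients go.
    """
    codebook_loss = sq_dist(z_e, e)           # ||sg[z_e] - e||^2
    commitment_loss = beta * sq_dist(z_e, e)  # beta * ||z_e - sg[e]||^2
    # Codebook term: z_e is frozen by sg, so only e receives a gradient.
    grad_e = [2 * (ei - zi) for zi, ei in zip(z_e, e)]
    # Commitment term: e is frozen by sg, so only z_e receives a gradient.
    grad_ze = [2 * beta * (zi - ei) for zi, ei in zip(z_e, e)]
    return codebook_loss, commitment_loss, grad_e, grad_ze

z_e, e = [0.5, 2.0], [1.0, 1.0]   # made-up encoder output and codebook vector
codebook_loss, commitment_loss, grad_e, grad_ze = vq_terms(z_e, e)
print(codebook_loss, commitment_loss)   # 1.25 0.3125
print(grad_e, grad_ze)                  # [1.0, -2.0] [-0.25, 0.5]

# A gradient step on e with learning rate 0.5 moves e exactly onto z_e
# (true for this quadratic term for any z_e): the codebook is pulled to the data.
lr = 0.5
e_new = [ei - lr * gi for ei, gi in zip(e, grad_e)]
print(e_new)                            # [0.5, 2.0]
```

    The decoder does not appear here at all; it is trained only through the first term.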

    In the illustration below, I show the components of the objective function and what their contributions are in the latent space (on top) and on learning the weights of the encoder and decoder networks (at the bottom):

    Interestingly, the original paper has a typo in its main formula that has been repeated in countless blog post explanations: the authors forgot the minus sign in front of the likelihood, so, as written, their training objective would have the first term maximized and the other two minimized. Naturally, it is just a typo, and all working VQ-VAE implementations get it right, but it is funny how these things propagate.

    That’s it for VQ-VAE. The original model appeared in 2017, the same year as Transformers and before they had been adopted for images, so it did not use them for latent code generation; instead, it relied on a different autoregressive model that was state of the art at the time: PixelCNN (van den Oord et al., 2016; Salimans et al., 2017). PixelCNN itself originated as a model for generating pictures, but generating a high-resolution image autoregressively, pixel by pixel, is just way too slow (see also my first post in this series). But it is just fine for generating a 32×32 grid of codebook tokens! The original VQ-VAE, trained on ImageNet with a separate PixelCNN trained to generate latent codes, produced impressive results by 2017 standards:

    The next step was VQ-VAE 2 that still used PixelCNN for latent codes but moved to a hierarchical structure, generating a small top-level representation and then a more detailed bottom-level representation conditioned on the top level result:

    VQ-VAE 2 produced excellent results. When it came out in 2019, in the wake of ProGAN and StyleGAN (you may have seen the latter on the “This Person Does Not Exist” website), everybody was comparing generation abilities on datasets of high-resolution face photos, and VQ-VAE 2 did not disappoint:

    But we still have some way to go before DALL-E 2 and Stable Diffusion, even in terms of the underlying autoencoders. The next step for VQ-VAE was to turn it into a GAN…

    VQ-GAN: Add a Discriminator to the Mix

    VQ-VAE and VQ-VAE 2 gave us very good generation via discrete latent codes, but the codes were still produced by a PixelCNN model. Naturally, we would like to generate these codes with a Transformer-based architecture, if only because it is much better at handling global dependencies: a Transformer does not even have the notion of a “long” or “short” dependency; it always attends to every previously generated token.
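    In code, the property "attends to every previously generated token" is usually enforced with a causal mask over the flattened sequence of codes. Below is a minimal sketch that assumes raster (row-major) ordering of a toy 4×4 grid. The helper names are illustrative; real implementations build this mask as a tensor inside the attention layer:

```python
def raster_index(row, col, width):
    """Position of grid cell (row, col) in the row-major token sequence."""
    return row * width + col

def causal_mask(n):
    """mask[i][j] is True when position i may attend to position j (j <= i)."""
    return [[j <= i for j in range(n)] for i in range(n)]

width = 4                          # toy 4x4 grid instead of 32x32
mask = causal_mask(width * width)
i = raster_index(1, 2, width)      # row 1, column 2 -> position 6
print(i, sum(mask[i]))             # 6 7: itself plus all six earlier tokens
print(mask[i][0], mask[i][7])      # True False: the first token is as visible as a neighbor
```

    The mask puts no weight on distance: a token attends to the very first code just as easily as to its immediate left neighbor.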

    It was only natural that the next step would be to use a Transformer to generate the codes. So in the autoencoder part, we would have something similar to VQ-VAE, and then the Transformer would serve as the autoregressive model to generate the codes:

    So in this approach, an image becomes a sequence of codebook vectors, and the Transformer does what it does best: learns to generate sequences. 
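    The generation side of this picture can be sketched as a loop that produces the 32×32 codes one at a time in raster order and then reshapes them into a grid for the decoder. The next-token model here is a hypothetical toy stand-in for PixelCNN or a Transformer, not a real prior. The last lines count sequential steps to show why token-level generation is so much cheaper than the pixel-by-pixel generation discussed earlier. The count is conservative, using one step per pixel rather than per color channel:

```python
import random

GRID = 32              # latent grid side (32x32 tokens, as in the text)
CODEBOOK_SIZE = 8192   # possible token values per position

def sample_next_token(prefix, rng):
    """Hypothetical stand-in for a PixelCNN or Transformer prior.

    A real model would compute a distribution over all CODEBOOK_SIZE tokens
    conditioned on the whole prefix; this toy repeats the previous token with
    probability 0.9 and otherwise picks a uniformly random codebook index.
    """
    if prefix and rng.random() < 0.9:
        return prefix[-1]
    return rng.randrange(CODEBOOK_SIZE)

def sample_latent_grid(seed=0):
    """Generate GRID * GRID tokens one at a time in raster (row-major) order."""
    rng = random.Random(seed)
    tokens = []
    for _ in range(GRID * GRID):     # 1024 sequential calls to the prior
        tokens.append(sample_next_token(tokens, rng))
    # Reshape the flat sequence back into a GRID x GRID grid for the decoder.
    return [tokens[r * GRID:(r + 1) * GRID] for r in range(GRID)]

grid = sample_latent_grid()
pixel_steps = 256 * 256              # one step per pixel of a 256x256 image
token_steps = GRID * GRID
print(pixel_steps // token_steps)    # 64 times fewer sequential steps
```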

    One of the problems here is that we need to learn a very rich and expressive codebook. So instead of using just a straightforward reconstruction loss, VQ-GAN (Esser et al., 2020) adds a patch-based discriminator that aims to distinguish between (small patches of) real and reconstructed images, and the reconstruction loss becomes a perceptual loss, i.e., the difference between features extracted by some standard convolutional network (Zhang et al., 2018). This means that the discriminator now takes care of the local structure of the generated image, and the perceptual loss deals with the actual content.

    In total, the losses for our autoencoder might look something like this:

    And with this, we are ready to see the overview of the whole architecture as it is shown in the original VQ-GAN paper (Esser et al., 2020):

    Just like a regular VQ-VAE, an image is represented with a sequence of discrete codebook vectors, but now the reconstruction is ensured by a combination of perceptual and adversarial losses, and the codes are produced by a Transformer.
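    For reference, here is one way to write the resulting objective, following our reading of Esser et al. (2020) and reusing the VQ-VAE notation above. Treat it as a summary rather than a verbatim copy of the paper's equations. The symbols are:

    - \mathcal{L}_{\mathrm{rec}}: the perceptual reconstruction loss;
    - \hat{\mathbf{x}}: the reconstruction;
    - D: the patch discriminator;
    - G_L: the last layer of the decoder;
    - \delta: a small constant for numerical stability.

    The adaptive weight \lambda balances the two losses by the ratio of their gradient magnitudes:

        \[\mathcal{L}_{\mathrm{VQ}} = \mathcal{L}_{\mathrm{rec}}(\mathbf{x},\hat{\mathbf{x}}) + \left\|\mathrm{sg}\left[\mathbf{z}_e(\mathbf{x})\right]-\mathbf{e}\right\|_2^2 + \beta\left\|\mathbf{z}_e(\mathbf{x}) - \mathrm{sg}\left[\mathbf{e}\right]\right\|_2^2, \qquad \mathcal{L}_{\mathrm{GAN}} = \log D(\mathbf{x}) + \log\left(1 - D(\hat{\mathbf{x}})\right),\]

        \[\min_{\mathrm{Enc},\,\mathrm{Dec},\,\{\mathbf{e}\}}\ \max_{D}\ \mathbb{E}_{\mathbf{x}}\left[\mathcal{L}_{\mathrm{VQ}} + \lambda\,\mathcal{L}_{\mathrm{GAN}}\right], \qquad \lambda = \frac{\left\|\nabla_{G_L}\mathcal{L}_{\mathrm{rec}}\right\|}{\left\|\nabla_{G_L}\mathcal{L}_{\mathrm{GAN}}\right\| + \delta}.\]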

    VQ-GAN could produce better images on the basic ImageNet; here are some first-rate goldfinches compared to other approaches:

    But a major point about VQ-GAN was that it could scale to far higher resolutions. Here is a sample landscape (originally 1280×832 pixels) generated by the VQ-GAN from a semantic layout, i.e., from a rough segmentation map showing where the sky, land, mountains, and grass should be:

    As a result, VQ-GAN, like virtually every method we discuss, defined the state of the art for image generation when it was introduced. We have to stop here for now, but our story is far from over…

    Conclusion

    In this post, we have discussed the notion of a discrete latent space, where images are compressed to sequences of tokens (“words”) instead of continuous vectors. This makes it far easier to train a good generative model since generating sequences is the bread and butter of many autoregressive models. The original VQ-VAE family used PixelCNN as this intermediate autoregressive model, but as soon as Transformers appeared it became clear that they are a great fit for this task, and VQ-GAN managed to make it work.

    At this point, we are ready to put several things together and discuss not just an image generation/reconstruction model but a real text-to-image model, where (spoiler alert) the Transformer will generate a sequence of discrete latent space tokens starting with a natural language text prompt. So next time, get ready for DALL-E!

    Sergey Nikolenko
    Head of AI, Synthesis AI

  • Variational Autoencoders (VAEs): Generative AI I

    Variational Autoencoders (VAEs): Generative AI I

    It might seem like generative models are going through new phases every couple of years: we heard about Transformers, then flow-based models were all the rage, then diffusion-based models… But in fact, new ideas build on top of older ones. Following our overview post, today we start an in-depth dive into generative AI. We consider the variational autoencoder (VAE), an idea introduced in 2013, if not earlier, but still very relevant and still underlying state-of-the-art generative models such as Stable Diffusion. We will not consider all the gory mathematical details, but I hope to explain the necessary intuition.

    VAE Intuition: Handwaving in the Latent Space

    We already covered this basic idea in the overview post but let me reintroduce the problem and move on to a more detailed discussion of the VAE intuition. We discussed that a basic autoencoder can learn to compress and decompress images pretty well with a simple high-level architecture like this:

    However, this is not enough to get you a generative model because the structure of latent codes will still be too complicated. You have a very complex manifold of images in a huuuuuge space with dimensions in the millions, and even though the latent space is much smaller, probably with dimension in the hundreds or low thousands, the latent codes will still have a complicated structure in that space:

    So if you try to sample latent codes from a simple distribution you will almost certainly fail, that is, your samples will fall outside the manifold of latent codes, and the decoder will fail to produce anything meaningful, let alone beautiful:

    A variational autoencoder tries to fix this problem by making each point “wider”. Instead of a single latent vector z, now each input x is mapped to a whole distribution:

    The intuition is that by making the decoder work with z’s sampled from whole distributions, we force it to be robust to small changes in z. Ultimately, we want to cover a whole chunk of the latent space with points that have “reasonable” decodings, so that afterwards we can sample from a simple distribution and still get good results:

    This idea, however, meets with two difficulties. First, when we begin to train an autoencoder, it will be beneficial for it to make the intermediate distributions as “small” (with low variance) as possible: if you are always very close to the central point the decoder’s job becomes easier, and reconstructions probably improve. In a similar vein, the distributions may begin to drift off from each other in the latent space, again making the decoder’s job easier as it now has more slack in distinguishing between different inputs. So unless we do something about it, the training process will look something like this, tending to a regular autoencoder that we know to be of little value for us:

    To alleviate this problem, we need to impose some kind of a constraint on what is happening with the intermediate distributions. In machine learning, hard constraints rarely appear; they usually take the form of regularizers, i.e., additions to the loss function that express what we want. In this case, we want to keep the distributions for each input x relatively “large”, and we want to keep them in relatively close proximity to each other, so we can probably kill two birds with one stone if we make the distribution pₓ(z | μₓ, σₓ) closer to a standard Gaussian. Our overall loss function now becomes a sum of the reconstruction loss and this regularizer:

    Still, the question remains: how do we regularize? We want pₓ(z | μₓ, σₓ) to be close to a standard Gaussian distribution, but there are several plausible ways to do that: the Kullback-Leibler divergence can cut both ways, either KL(p||q) or KL(q||p), and then there are combinations like the Jensen-Shannon divergence… What would be the best and conceptually correct way to define L_reg?

    The second problem is more technical: the picture above has the latent code z sampled from a distribution pₓ(z | μₓ, σₓ). This is fine during inference, when we want to apply the already trained encoder and decoder. But how do we train? Gradients cannot go through a “sampling layer”.

    Let us begin with the first problem; solving it will also give us a nice probabilistic interpretation of what’s going on in VAE and explain why it is called a variational autoencoder.

    VAE Intuition: Probabilistic Handwaving

    Let us consider a different way to look at the same structure that leads to different insights and ultimately will help us understand the mathematical ideas behind variational autoencoders. We will start from scratch: suppose that we want to train an encoder to produce latent codes z from images x and a decoder to go back from z to x.

    We begin with a very simple formula; I promised as little math as possible but, to be honest, there will be a little more than that below:

        \[p(\mathbf{z})p(\mathbf{x}|\mathbf{z})=p(\mathbf{x},\mathbf{z})=p(\mathbf{x})p(\mathbf{z}|\mathbf{x})\]

    This is basically the Bayes formula in its simplest form: it says that the joint distribution of images and their latent codes can be decomposed in two different ways, starting either with p(x) or with p(z) and multiplying it by the corresponding conditional distribution.

    We already understand, at least generally, all parts of this formula: p(x) is the distribution of images, p(z) is the distribution of latent codes, i.e., the simple distribution we want to be able to sample from (most likely a standard Gaussian), and the other two distributions are what we need to find, the encoder distribution p(z|x) and the decoder distribution p(x|z):

    If we want to get a generative model, our main goal is to learn both p(x|z) and p(z|x). But here is the thing: in a generative model p(z) is by design simple since we need to be able to sample from it, while p(x) is, in any model, unimaginably complex since this is the distribution of real objects (images). So we cannot have both the encoder distribution p(z|x) and the decoder distribution p(x|z) be simple: if, say, they both were Gaussian, we would have a Gaussian on the left but definitely something much more complicated on the right-hand side of the equation.

    We need to pick one:

    • either assume that p(x|z) is simple and then try to find a complex p(z|x);
    • or vice versa, assume that p(z|x) is simple and find a complex p(x|z).

    Variational autoencoders take the first option: we will assume that p(x|z) = N(x | f(z), cI) is a Gaussian distribution with mean f(z) = Decoder(z) and covariance matrix cI, which is just a constant c along every axis. Thus, on the left we have a simple Gaussian p(x|z) times a simple Gaussian p(z) = N(z | 0, I), that is, a distribution that is easy to evaluate and to sample from.

    What do we do on the right-hand side? We need to find a very complex distribution p(z | x). There are several different ways to do that, and variational autoencoders take the road of approximation: the encoder produces a simple distribution p(z | μₓ, σₓ), actually again a Gaussian N(z | μₓ, σₓ), but this time we cannot say that this Gaussian is the real p(z | x); we have to say that it is an approximation:

    The only thing left is how to find such an approximation. This is where the variational part comes in: variational approximations are how probability distributions are usually approximated in machine learning. 

    Variational approximations and the loss function in VAE

    I promised not to have too much math; I lied. But you already have the basic intuition, so now you can safely skip to the very end of this section and still understand everything that comes afterwards. With that said, if you are not afraid to get your hands a little dirty, let us still go through the derivation.

    The idea of variational approximations is shown in the sequence of equations below. We start with an obvious identity, take the expectation over q(z) on both sides, and then do some transformations to break down the right-hand side into two terms; the left-hand side does not depend on z, so the expectation simply disappears:

        \[\begin{aligned} p(\mathbf{x},\mathbf{z}) &= p(\mathbf{x})p(\mathbf{z}|\mathbf{x}),\\ \log p(\mathbf{x},\mathbf{z}) &= \log p(\mathbf{x}) + \log p(\mathbf{z}|\mathbf{x}),\\ \log p(\mathbf{x}) &= \log p(\mathbf{x},\mathbf{z}) - \log p(\mathbf{z}|\mathbf{x}),\\ \log p(\mathbf{x}) &= \mathbb{E}_{q(\mathbf{z})}\left[\log p(\mathbf{x},\mathbf{z}) - \log q(\mathbf{z}) + \log q(\mathbf{z}) - \log p(\mathbf{z}|\mathbf{x})\right],\\ \log p(\mathbf{x}) &= \int q(\mathbf{z})\log \frac{p(\mathbf{x},\mathbf{z})}{q(\mathbf{z})}\,\mathrm{d}\mathbf{z} + \int q(\mathbf{z})\log \frac{q(\mathbf{z})}{p(\mathbf{z}|\mathbf{x})}\,\mathrm{d}\mathbf{z}. \end{aligned}\]
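    The decomposition in the last line holds for any q(z). That is easy to check numerically on a toy model with a discrete latent variable, where the integrals become sums. The probabilities below are arbitrary illustrative numbers:

```python
import math

# Toy model: discrete latent z in {0, 1, 2} and one fixed observation x.
# p_xz[k] = p(x, z=k); the numbers are arbitrary illustrative values.
p_xz = [0.10, 0.05, 0.15]
p_x = sum(p_xz)                          # p(x) = sum over z of p(x, z)
p_z_given_x = [v / p_x for v in p_xz]    # exact posterior p(z | x)

def elbo(q):
    """L(q) = sum_z q(z) log(p(x, z) / q(z))."""
    return sum(qk * math.log(pk / qk) for qk, pk in zip(q, p_xz))

def kl(q, p):
    """KL(q || p) for discrete distributions with full support."""
    return sum(qk * math.log(qk / pk) for qk, pk in zip(q, p))

q = [0.5, 0.3, 0.2]                      # an arbitrary approximation q(z)
print(math.log(p_x), elbo(q) + kl(q, p_z_given_x))  # equal up to rounding
print(elbo(q) <= math.log(p_x))                     # True: L(q) is a lower bound
print(elbo(p_z_given_x) - math.log(p_x))            # ~0: tight at q = p(z|x)
```

    Because KL is nonnegative, L(q) never exceeds log p(x). It reaches log p(x) exactly when q(z) equals the true posterior.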

    As a result, we have a constant (that is, something independent of q(z)) on the left and the sum of L(q) and the Kullback-Leibler divergence between q(z) and p(z|x) on the right, that is, a measure of how close these distributions are to each other:

    This means that we can approximate p(z|x) with q(z), i.e., minimize the divergence between them, by maximizing the first term L(q): since log p(x) does not depend on q, any increase in L(q) must come from a decrease in the divergence. But this first term is much easier to handle since it contains the joint distribution p(x, z) and not the conditional distribution p(z|x). In particular, we can now decompose it in the other way:

        \[\begin{aligned}\mathcal{L}(q) &= \int q(\mathbf{z})\log \frac{p(\mathbf{x},\mathbf{z})}{q(\mathbf{z})}\,\mathrm{d}\mathbf{z} = \int q(\mathbf{z})\log \frac{p(\mathbf{z})p(\mathbf{x}|\mathbf{z})}{q(\mathbf{z})}\,\mathrm{d}\mathbf{z} \\&= \int q(\mathbf{z})\log p(\mathbf{x}|\mathbf{z})\,\mathrm{d}\mathbf{z} + \int q(\mathbf{z})\log \frac{p(\mathbf{z})}{q(\mathbf{z})}\,\mathrm{d}\mathbf{z} \\&= \int q(\mathbf{z})\log \mathcal{N}(\mathbf{x}| f(\mathbf{z}),c\mathbf{I})\,\mathrm{d}\mathbf{z} - \mathrm{KL}\left(q(\mathbf{z})\|p(\mathbf{z})\right) \\&= -\frac{1}{2c}\mathbb{E}_{q(\mathbf{z})}\left[ \left\|\mathbf{x} - f(\mathbf{z})\right\|^2\right] - \mathrm{KL}\left(q(\mathbf{z})\|p(\mathbf{z})\right) + \mathrm{const}.\end{aligned}\]

    And now we have arrived at exactly the two terms that we considered in the first “intuitive” part! We need to maximize L(q), so the first term wants to make f(z) as close as possible to x, and the second term wants to make q(z) as close as possible to p(z), that is, to the standard Gaussian. Overall, we minimize exactly the sum of two terms that we had at the end of the first section:

        \[\mathcal{L} = \mathcal{L}_{\mathrm{rec}} + \lambda\mathcal{L}_{\mathrm{reg}}.\]

    Why did we need all that math if we arrived at the exact same conclusion? Mostly because we were not sure what the reconstruction loss and the regularizer should look like. Our intuition told us that we want q(z) to be “large”, but how do we formalize it exactly? And which reconstruction loss should we use? Variational approximations answer all these questions in a conceptually sound way. Moreover, they even explain the meaning of the regularization coefficient λ. Multiplying −L(q) by 2c turns it into the expected squared reconstruction error plus 2c times the KL term, so λ = 2c is proportional to the variance of the decoder distribution. Not that it helps that much (we still need to choose c ourselves, just like we needed to choose λ), but it is always better to understand what is going on.

    By now we are almost done. I will skip the exact calculation of the regularization term: it is tedious but straightforward and does not contain new interesting ideas; basically, you get a rather simple exact formula in terms of μₓ and σₓ.
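    Here is that formula for the standard setting, where the encoder outputs a diagonal Gaussian with means μᵢ and standard deviations σᵢ and p(z) = N(z | 0, I):

    KL = ½ Σᵢ (σᵢ² + μᵢ² − 1 − log σᵢ²)

    The sketch below implements the formula and cross-checks it against a Monte Carlo estimate in one dimension. Parametrizing by σ directly is a simplification: implementations often predict the log-variance instead, for numerical stability.

```python
import math
import random

def kl_diag_gaussian_to_standard(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * sum(s * s + m * m - 1.0 - math.log(s * s)
                     for m, s in zip(mu, sigma))

def kl_monte_carlo(mu, sigma, n=100_000, seed=0):
    """Estimate the same KL as E_q[log q(z) - log p(z)] by sampling (1-D only)."""
    rng = random.Random(seed)
    m, s = mu[0], sigma[0]
    half_log_2pi = 0.5 * math.log(2 * math.pi)
    total = 0.0
    for _ in range(n):
        z = rng.gauss(m, s)
        log_q = -0.5 * ((z - m) / s) ** 2 - math.log(s) - half_log_2pi
        log_p = -0.5 * z * z - half_log_2pi
        total += log_q - log_p
    return total / n

print(kl_diag_gaussian_to_standard([0.0], [1.0]))   # 0.0: q already equals p
print(kl_diag_gaussian_to_standard([1.0], [1.0]))   # 0.5: only the mean is off
exact = kl_diag_gaussian_to_standard([1.0], [0.5])
print(exact, kl_monte_carlo([1.0], [0.5]))          # close to each other
```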

    The only thing left is to handle the second problem: how do we train a model that has sampling in the middle?

    Reparametrization trick and the overall algorithm

    By now, we understand the nature of the loss function in the variational autoencoder and can go back to the sampling problem:

    Indeed, it is impossible to send the gradients back through the sampling process. Fortunately, we don’t need to.

    The reparametrization trick comes to the rescue. The idea of this trick (we will see other versions of it in subsequent posts) is to sample a random number first from some standard distribution and then transform it into the desired distribution. In the case of Gaussians the reparametrization trick is very simple: to get a vector z from N(z | μₓ, σₓ) with a diagonal covariance matrix, we can first get u from N(u | 0, I), then multiply it componentwise by σₓ, and then add μₓ to the result. The picture above in this case looks like this:

    Now we can sample a mini-batch of vectors u for an input mini-batch of images and use them for training, never needing to run gradients through the sampling process.
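    Here is a minimal sketch of the reparametrization trick for a diagonal Gaussian. It is written in plain Python with hand-written derivatives instead of an autodiff framework, and the encoder outputs are made-up numbers:

```python
import random

def reparametrize(mu, sigma, rng):
    """Sample z ~ N(mu, diag(sigma^2)) as z = mu + sigma * u, with u ~ N(0, I).

    The randomness lives entirely in u, which does not depend on the encoder
    outputs, so z is a deterministic, differentiable function of mu and sigma.
    """
    u = [rng.gauss(0.0, 1.0) for _ in mu]
    z = [m + s * ui for m, s, ui in zip(mu, sigma, u)]
    # Componentwise partial derivatives that a backward pass would use.
    dz_dmu = [1.0] * len(mu)   # dz_i / dmu_i = 1
    dz_dsigma = list(u)        # dz_i / dsigma_i = u_i
    return z, dz_dmu, dz_dsigma

rng = random.Random(42)
mu, sigma = [2.0, -1.0], [0.5, 3.0]   # hypothetical encoder outputs for one image
samples = [reparametrize(mu, sigma, rng)[0] for _ in range(50_000)]
mean0 = sum(z[0] for z in samples) / len(samples)
print(round(mean0, 2))   # close to mu[0] = 2.0
```

    The gradients with respect to μₓ and σₓ are simple: 1 and the sampled u, respectively. As a result, the encoder can be trained by ordinary backpropagation.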

    And that’s it! Now we have the complete picture of how a variational autoencoder works, what loss function it minimizes and why, and how this loss function is related to the basic intuition of VAEs.

    Conclusion: How Is VAE Still Relevant?

    In this post, we have discussed the idea and implementation of VAE, a model first introduced in 2013. But these days, you don’t hear much about VAEs in the news. It’s a nice idea but is it still relevant for generative AI today?

    As it turns out, VAEs are not only relevant but actually still represent one of the pillars on which the entire modern generative AI stands. Consider, for instance, the basic structure of the Stable Diffusion model (which has produced all cat images in this post):

    As you can see, the picture concentrates on the diffusion and denoising parts, as well it should, since these are the novelties that differentiate this work from prior art. But note that all these novelties take place in the latent space of some kind of autoencoder for images: an encoder E maps images into the latent space, and a decoder D maps the latent codes produced by diffusion back into the pixel space. Where do these E and D come from? You guessed it, it’s a variational autoencoder!

    But it is not the default vanilla VAE that we have discussed today. These days, it is usually a relative of either the quantized version of VAE with a discrete latent space, VQ-VAE, or its further modification with an additional discriminator, VQ-GAN. We will discuss these models in the next installment; until then!

    Sergey Nikolenko
    Head of AI, Synthesis AI