June 20, 2024

Enhancing Neural Network Computation through Sequential Reasoning

Your grade school teacher probably didn’t show you how to add 20-digit numbers. But if you know how to add smaller numbers, all you need is paper and pencil and a bit of patience. Start with the ones place and work leftward step by step, and soon you’ll be stacking up quintillions with ease.

Problems like this are easy for humans, but only if we approach them in the right way. “How we humans solve these problems is not ‘stare at it and then write down the answer,’” said Eran Malach, a machine learning researcher at Harvard University. “We actually walk through the steps.”

That insight has inspired researchers studying the large language models that power chatbots like ChatGPT. While these systems might ace questions involving a few steps of arithmetic, they’ll often flub problems involving many steps, like calculating the sum of two large numbers. But in 2022, a team of Google researchers showed that asking language models to generate step-by-step solutions enabled the models to solve problems that had previously seemed beyond their reach. Their technique, called chain-of-thought prompting, soon became widespread, even as researchers struggled to understand what makes it work.

Now, several teams have explored the power of chain-of-thought reasoning by using techniques from an arcane branch of theoretical computer science called computational complexity theory. It’s the latest chapter in a line of research that uses complexity theory to study the intrinsic capabilities and limitations of language models. These efforts clarify where we should expect models to fail, and they might point toward new approaches to building them.

“They remove some of the magic,” said Dimitris Papailiopoulos, a machine learning researcher at the University of Wisconsin, Madison. “That’s a good thing.”

Training Transformers

Large language models are built around mathematical structures called artificial neural networks. The many “neurons” inside these networks perform simple mathematical operations on long strings of numbers representing individual words, transmuting each word that passes through the network into another. The details of this mathematical alchemy depend on another set of numbers called the network’s parameters, which quantify the strength of the connections between neurons.

Before starting graduate school, William Merrill used theoretical methods to analyze the capabilities of different kinds of neural networks. (Photo: Brian Kitano)

To train a language model to produce coherent outputs, researchers typically start with a neural network whose parameters all have random values, and then feed it reams of data from around the internet. Each time the model sees a new block of text, it tries to predict each word in turn: It guesses the second word based on the first, the third based on the first two, and so on. It compares each prediction to the actual text, then tweaks its parameters to reduce the difference. Each tweak only changes the model’s predictions a tiny bit, but somehow their collective effect enables a model to respond coherently to inputs it has never seen.
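The training objective can be made concrete with a minimal Python sketch. It assumes words are split on whitespace and that the model is a hypothetical function `prob(prefix, word)` returning the probability it assigns to the next word. The sketch turns one block of text into the series of guesses described above. It then scores those guesses with cross-entropy, the average negative log-probability of each actual next word, which is a loss commonly used for this purpose. Real training would then adjust the parameters to lower that number.

```python
import math
from typing import Callable, List, Tuple

def next_word_examples(text: str) -> List[Tuple[List[str], str]]:
    """Split a block of text into (preceding words, word to predict) pairs."""
    words = text.split()  # assumption: whitespace tokenization, for illustration only
    # Guess word 2 from word 1, word 3 from words 1-2, and so on.
    return [(words[:i], words[i]) for i in range(1, len(words))]

def average_loss(text: str, prob: Callable[[List[str], str], float]) -> float:
    """Average negative log-probability the model gives each actual next word."""
    examples = next_word_examples(text)
    return sum(-math.log(prob(prefix, target)) for prefix, target in examples) / len(examples)

def uniform(prefix: List[str], word: str) -> float:
    """Hypothetical untrained model: equal odds for each word of a 4-word vocabulary."""
    return 0.25

print(next_word_examples("the cat sat down"))
print(average_loss("the cat sat down", uniform))
```

A model that put more probability on the words that actually follow would score lower. Each small tweak during training nudges the parameters in that direction.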

Researchers have been training neural networks to process language for 20 years. But the work really took off in 2017, when researchers at Google introduced a new kind of network called a transformer.

“This was proposed seven years ago, which seems like prehistory,” said Pablo Barceló, a machine learning researcher at the Pontifical Catholic University of Chile.

What made transformers so transformative is that it’s easy to scale them up — to increase the number of parameters and the amount of training data — without making training prohibitively expensive. Before transformers, neural networks had at most a few hundred million parameters; today, the largest transformer-based models have more than a trillion. Much of the improvement in language-model performance over the past five years comes from simply scaling up.

Transformers made this possible by using special mathematical structures called attention heads, which give them a sort of bird’s-eye view of the text they’re reading. When a transformer reads a new block of text, its attention heads quickly scan the whole thing and identify relevant connections between words — perhaps noting that the fourth and eighth words are likely to be most useful for predicting the 10th. Then the attention heads pass words along to an enormous web of neurons called a feedforward network, which does the heavy number crunching needed to generate the predictions that help it learn.

Real transformers have multiple layers of attention heads separated by feedforward networks, and only spit out predictions after the last layer. But at each layer, the attention heads have already identified the most relevant context for each word, so the computationally intensive feedforward step can happen simultaneously for every word in the text. That speeds up the training process, making it possible to train transformers on increasingly large sets of data. Even more important, it allows researchers to spread the enormous computational load of training a massive neural network across many processors working in tandem.
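A single attention head can be sketched in a few lines of Python. This is a simplified illustration with several assumptions:

- The query, key and value vectors for each word have already been computed. In a real transformer they come from learned parameters.
- It uses the common scaled dot-product form.
- It applies a causal mask, so each word looks only at itself and earlier words.

The calculation for each position depends only on the inputs, not on the results for other positions. That independence is what lets the work for every word happen at once.

```python
import math
from typing import List

Vector = List[float]

def softmax(scores: Vector) -> Vector:
    """Turn raw scores into positive weights that sum to 1."""
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))

def attention_head(queries: List[Vector], keys: List[Vector],
                   values: List[Vector]) -> List[Vector]:
    """Causal scaled dot-product attention over one sequence of words."""
    d = len(queries[0])
    outputs = []
    # Iterations don't depend on each other, so in practice every position
    # is computed in parallel; the loop is only for readability.
    for i, q in enumerate(queries):
        # How relevant is each word j <= i to word i?
        scores = [dot(q, keys[j]) / math.sqrt(d) for j in range(i + 1)]
        weights = softmax(scores)
        # Blend those words' value vectors according to their relevance.
        outputs.append([
            sum(w * values[j][k] for j, w in enumerate(weights))
            for k in range(len(values[0]))
        ])
    return outputs

same = [[1.0, 0.0]] * 3
print(attention_head(same, same, [[0.0], [2.0], [4.0]]))
```

When all keys are identical, every earlier word gets equal weight, so each output is just the running average of the values. When one key matches the query far better than the others, that word dominates the blend.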

To get the most out of massive data sets, “you have to make the models really large,” said David Chiang, a machine learning researcher at the University of Notre Dame. “It’s just not going to be practical to train them unless it’s parallelized.”

However, the parallel structure that makes it so easy to train transformers doesn’t help after training — at that point, there’s no need to predict words that already exist. During ordinary operation, transformers output one word at a time, tacking each output back onto the input before generating the next word, but they’re still stuck with an architecture optimized for parallel processing.
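The word-by-word loop used during ordinary operation can be sketched in Python. The sketch assumes a hypothetical `next_word` function that stands in for one full pass through a trained transformer, plus a toy counting "model" for demonstration. Each new word costs another pass over the whole sequence, which keeps getting longer.

```python
from typing import Callable, List

def generate(prompt: List[str],
             next_word: Callable[[List[str]], str],
             max_new_words: int,
             stop: str = "<end>") -> List[str]:
    """Autoregressive generation: predict one word, append it, repeat."""
    words = list(prompt)
    for _ in range(max_new_words):
        word = next_word(words)  # one pass through the (hypothetical) model
        if word == stop:
            break
        words.append(word)  # the output becomes part of the next input
    return words

def counting_model(words: List[str]) -> str:
    """Toy stand-in for a transformer: continues a count and stops after 5."""
    last = int(words[-1])
    return "<end>" if last >= 5 else str(last + 1)

print(generate(["1"], counting_model, max_new_words=10))  # prints ['1', '2', '3', '4', '5']
```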

As transformer-based models grew and certain tasks continued to give them trouble, some researchers began to wonder whether the push toward more parallelizable models had come at a cost. Was there a way to understand the behavior of transformers theoretically?

The Complexity of Transformers

Theoretical studies of neural networks face many difficulties, especially when they try to account for training. Neural networks use a well-known procedure to tweak their parameters at each step of the training process. But it can be difficult to understand why this simple procedure converges on a good set of parameters.

Rather than consider what happens during training, some researchers study the intrinsic capabilities of transformers by imagining that it’s possible to adjust their parameters to any arbitrary values. This amounts to treating a transformer as a special type of programmable computer.

“You’ve got some computing device, and you want to know, ‘Well, what can it do? What kinds of functions can it compute?’” Chiang said.

These are the central questions in the formal study of computation. The field dates back to 1936, when Alan Turing first imagined a fanciful device, now called a Turing machine, that could perform any computation by reading and writing symbols on an infinite tape. Computational complexity theorists would later build on Turing’s work by proving that computational problems naturally fall into different complexity classes defined by the resources required to solve them.
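Turing's device can be made concrete with a minimal Python simulator, offered as an illustrative sketch. At each step the head reads one symbol, looks up a rule, writes a symbol and moves one cell, repeating until it reaches the halt state. The rule table is a hypothetical example that does not come from the article. It adds 1 to a binary number: the machine walks right to the end of the number, then moves left, flipping digits to carry.

```python
from typing import Dict, Tuple

# (state, symbol read) -> (symbol to write, move "L" or "R", next state)
Rules = Dict[Tuple[str, str], Tuple[str, str, str]]
BLANK = "_"

def run_turing_machine(tape: str, rules: Rules, state: str = "start",
                       max_steps: int = 10_000) -> str:
    cells = dict(enumerate(tape))  # sparse tape: unvisited cells read as blank
    head = 0
    # max_steps is a safety cap for the simulation; a real Turing machine has none.
    for _ in range(max_steps):
        if state == "halt":
            break
        symbol = cells.get(head, BLANK)
        write, move, state = rules[(state, symbol)]
        cells[head] = write
        head += 1 if move == "R" else -1
    lo, hi = min(cells), max(cells)
    return "".join(cells.get(i, BLANK) for i in range(lo, hi + 1)).strip(BLANK)

# Hypothetical example machine: increment a binary number by 1.
INCREMENT: Rules = {
    ("start", "0"): ("0", "R", "start"),      # scan right over the digits
    ("start", "1"): ("1", "R", "start"),
    ("start", BLANK): (BLANK, "L", "carry"),  # past the end: step back
    ("carry", "1"): ("0", "L", "carry"),      # 1 plus carry is 0, keep carrying
    ("carry", "0"): ("1", "L", "halt"),       # 0 plus carry is 1, done
    ("carry", BLANK): ("1", "L", "halt"),     # off the left edge: new leading 1
}

print(run_turing_machine("1011", INCREMENT))  # prints 1100
```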

In 2019, Barceló and two other researchers proved that an idealized version of a transformer with a fixed number of parameters could be just as powerful as a Turing machine. If you set up a transformer to repeatedly feed its output back in as an input and set the parameters to the appropriate values for the specific problem you want to solve, it will eventually spit out the correct answer.

That result was a starting point, but it relied on some unrealistic assumptions that would likely overestimate the power of transformers. In the years since, researchers have worked to develop more realistic theoretical frameworks.

One such effort began in 2021, when William Merrill, now a graduate student at New York University, was leaving a two-year fellowship at the Allen Institute for Artificial Intelligence in Seattle. While there, he’d analyzed other kinds of neural networks using techniques that seemed like a poor fit for transformers’ parallel architecture. Shortly before leaving, he struck up a conversation with Ashish Sabharwal, a researcher at the Allen Institute for AI who’d studied complexity theory before moving into AI research. They began to suspect that complexity theory might help them understand the limits of transformers.

“It just seemed like it’s a simple model; there must be some limitations that one can just nail down,” Sabharwal said.

The pair analyzed transformers using a branch of computational complexity theory, called circuit complexity, that is often used to study parallel computation and had recently been applied to simplified versions of transformers. Over the following year, they refined several of the unrealistic assumptions in previous work. To study how the parallel structure of transformers might limit their capabilities, the pair considered the case where transformers didn’t feed their output back into their input — instead, their first output would have to be the final answer. They proved that the transformers in this theoretical framework couldn’t solve any computational problems that lie outside a specific complexity class. And many math problems, including relatively simple ones like solving linear equations, are thought to lie outside this class.

Basically, they showed that parallelism did come at a cost — at least when transformers had to spit out an answer right away. “Transformers are quite weak if the way you use them is you give an input, and you just expect an immediate answer,” Merrill said.

Thought Experiments

Merrill and Sabharwal’s results raised a natural question — how much more powerful do transformers become when they’re allowed to recycle their outputs? Barceló and his co-authors had studied this case in their 2019 analysis of idealized transformers, but with more realistic assumptions the question remained open. And in the intervening years, researchers had discovered chain-of-thought prompting, giving the question a newfound relevance.

Merrill and Sabharwal knew that their purely mathematical approach couldn’t capture all aspects of chain-of-thought reasoning in real language models, where the wording in the prompt can be very important. But no matter how a prompt is phrased, as long as it causes a language model to output step-by-step solutions, the model can in principle reuse the results of intermediate steps on subsequent passes through the transformer. That could provide a way to evade the limits of parallel computation.

Meanwhile, a team from Peking University had been thinking along similar lines, and their preliminary results were positive. In a May 2023 paper, they identified some math problems that should be impossible for ordinary transformers in Merrill and Sabharwal’s framework, and showed that intermediate steps enabled the transformers to solve these problems.

In October 2023, Merrill and Sabharwal followed up their earlier work with a detailed theoretical study of the computational power of chain of thought. They quantified how that extra computational power depends on the number of intermediate steps a transformer is allowed to use before it must spit out a final answer. In general, researchers expect the appropriate number of intermediate steps for solving any problem to depend on the size of the input to the problem. For example, the simplest strategy for adding two 20-digit numbers requires twice as many intermediate addition steps as the same approach to adding two 10-digit numbers.
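The scaling in that example can be checked with a short Python sketch of grade-school column addition. The sketch records one intermediate step per digit column, a rough stand-in for the steps a chain-of-thought solution would write out. The step format is an illustrative assumption; it is not how Merrill and Sabharwal count steps in their proofs.

```python
from typing import List, Tuple

# One intermediate step: (column, digit of a, digit of b, carry in, digit written, carry out)
Step = Tuple[int, int, int, int, int, int]

def add_with_steps(a: str, b: str) -> Tuple[str, List[Step]]:
    """Add two non-negative decimal strings column by column, ones place first."""
    n = max(len(a), len(b))
    a, b = a.zfill(n), b.zfill(n)  # pad the shorter number with leading zeros
    carry = 0
    digits: List[str] = []
    steps: List[Step] = []
    for column, i in enumerate(range(n - 1, -1, -1), start=1):
        x, y = int(a[i]), int(b[i])
        total = x + y + carry
        steps.append((column, x, y, carry, total % 10, total // 10))
        digits.append(str(total % 10))
        carry = total // 10
    if carry:
        digits.append(str(carry))  # a final carry becomes a new leading digit
    return "".join(reversed(digits)), steps

for a, b in [("1234567890", "9876543210"),
             ("12345678901234567890", "98765432109876543210")]:
    result, steps = add_with_steps(a, b)
    print(len(a), "digits:", len(steps), "steps, sum =", result)
```

Doubling the number of digits doubles the number of column steps, so the chain of intermediate steps grows in proportion to the size of the input.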

Examples like this suggest that transformers wouldn’t gain much from using just a few intermediate steps. Indeed, Merrill and Sabharwal proved that chain of thought only really begins to help when the number of intermediate steps grows in proportion to the size of the input, and many problems require the number of intermediate steps to grow much larger still.

The thoroughness of the result impressed researchers. “They really pinned this down,” said Daniel Hsu, a machine learning researcher at Columbia University.

Merrill and Sabharwal’s recent work indicates that chain of thought isn’t a panacea — in principle, it can help transformers solve harder problems, but only at the cost of a lot of computational effort.

“We’re interested in different ways of getting around the limitations of transformers with one step,” Merrill said. “Chain of thought is one way, but this paper shows that it might not be the most economical way.”

Back to Reality

Still, researchers caution that this sort of theoretical analysis can only reveal so much about real language models. Positive results — proofs that transformers can in principle solve certain problems — don’t imply that a language model will actually learn those solutions during training.

And even results that address the limitations of transformers come with caveats: They indicate that no transformer can solve certain problems perfectly in all cases. Of course, that’s a pretty high bar. “There might be special cases of the problem that it could handle just fine,” Hsu said.

Despite these caveats, the new work offers a template for analyzing different kinds of neural network architectures that might eventually replace transformers. If a complexity theory analysis suggests that certain types of networks are more powerful than others, that would be evidence that those networks might fare better in the real world as well.

Chiang also stressed that research on the limitations of transformers is all the more valuable as language models are increasingly used in a wide range of real-world applications, making it easy to overestimate their abilities.

“There’s actually a lot of things that they don’t do that well, and we need to be very, very cognizant of what the limitations are,” Chiang said. “That’s why this kind of work is really important.”

Web: https://www.quantamagazine.org/