Rethinking Deep Thinking: Stable Learning of Algorithms using Lipschitz Constraints

Jay Bear, Adam Prügel-Bennett, Jonathon Hare

TL;DR

This paper analyzes the growth of intermediate representations in Deep Thinking (DT) networks and uses Lipschitz constraints to build models (DT-L) that have many fewer parameters and provide more reliable solutions. It also demonstrates that DT-L can robustly learn algorithms that extrapolate to harder problems than those in the training set.

Abstract

Iterative algorithms solve problems by taking steps until a solution is reached. Deep Thinking (DT) networks have been shown to learn iterative algorithms that scale to differently sized problems at inference time by using recurrent computation and convolutions. However, they are often unstable during training and have no guarantee of convergence or termination at the solution. This paper addresses the instability by analyzing the growth of intermediate representations, which allows us to build models (referred to as Deep Thinking with Lipschitz Constraints, DT-L) that have many fewer parameters and provide more reliable solutions. Additionally, our DT-L formulation guarantees that the learned iterative procedure converges to a unique solution at inference time. We demonstrate that DT-L can robustly learn algorithms that extrapolate to harder problems than those in the training set. We benchmark on the traveling salesperson problem to evaluate the modified system on an NP-hard problem where DT fails to learn.
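One standard way a convergence guarantee like this can arise: if the recurrent update, written here as $\boldsymbol{\phi}_{t+1} = \mathcal{G}(\boldsymbol{\phi}_t, \boldsymbol{x})$ (assuming $\mathcal{G}$ denotes the recurrent block), is a contraction in $\boldsymbol{\phi}$, meaning its Lipschitz constant is below 1, then the Banach fixed-point theorem gives a unique fixed point that iteration reaches from any starting scratchpad. The Python sketch below is a toy illustration under stated assumptions. The ReLU update, the hypothetical matrices `W` and `U`, and the use of the infinity norm (chosen because "maximum absolute row sum below 1" is easy to check by hand) are all illustrative. It is not the DT-L architecture, which uses convolutions.

```python
def relu(v):
    """Elementwise ReLU; it is 1-Lipschitz in every coordinate."""
    return [max(0.0, a) for a in v]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def step(phi, x, w, u):
    """Toy recurrent update phi_next = relu(W phi + U x) (hypothetical, not DT-L's block)."""
    return relu([p + q for p, q in zip(matvec(w, phi), matvec(u, x))])


def inf_norm(m):
    """Induced infinity-norm of a matrix: its maximum absolute row sum."""
    return max(sum(abs(a) for a in row) for row in m)


def iterate(phi0, x, w, u, tol=1e-10, max_iters=1000):
    """Apply step until successive scratchpads differ by less than tol (infinity norm)."""
    phi = phi0
    for t in range(max_iters):
        nxt = step(phi, x, w, u)
        if max(abs(p - q) for p, q in zip(nxt, phi)) < tol:
            return nxt, t + 1
        phi = nxt
    return phi, max_iters


W = [[0.5, -0.3], [0.2, 0.4]]  # row sums 0.8 and 0.6, so inf_norm(W) = 0.8 < 1
U = [[1.0, 0.0], [0.0, 1.0]]
x = [1.0, 2.0]
assert inf_norm(W) < 1.0  # the update is therefore a contraction in phi

# Two very different initial scratchpads.
phi_a, steps_a = iterate([0.0, 0.0], x, W, U)
phi_b, steps_b = iterate([100.0, -50.0], x, W, U)
print(phi_a, phi_b)  # both approximately [0.0, 3.3333...]
```

Both runs end at the same fixed point, $(0, 10/3)$, despite starting far apart. The stopping rule also acts as a termination criterion: with contraction factor 0.8, the gap between successive scratchpads shrinks at least geometrically. If a row sum of `W` exceeded 1, neither uniqueness nor convergence would be guaranteed by this argument.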

Paper Structure

This paper contains 38 sections, 9 equations, 13 figures, 3 tables.

Figures (13)

  • Figure 1: Recurrent model architectures for learning algorithms with input $\boldsymbol{x}$ and output $\boldsymbol{y}$. $\mathcal{F}$, $\mathcal{G}$ and $\mathcal{H}$ are convolutional networks that work on inputs of any size. A scratchpad $\boldsymbol{\phi}$ serves as the working memory during computation. As described in the related-work section, the original DT model did not include recall (the connection denoted by the dotted line); the improved DT-R and our DT-L models include this connection.
  • Figure 2: Distribution of spectral norms of the reshaped weight matrices for the different convolutional layers in the recurrent part of DT-R. The norms were measured on 30 sampled prefix-sums-solving models with width $w=32$.
  • Figure 3: Mean training (cross-entropy) loss at each epoch for prefix-sums-solving models of varying width $w$. For small $w$, training is stable but not all models converge; larger $w$ gives a higher chance of reaching a small loss, but training shows very large loss spikes that cause some models to explode. Each curve tracks one model, from a different random initialization, throughout training; there are 10 models for each width.
  • Figure 4: Comparison between Deep Thinking with Recall (DT-R) and Deep Thinking with Lipschitz Constraints (DT-L) on the prefix sums problem. The two left plots show the solution accuracy of inference-time runs on 512-bit problems for 30 individual models each; each line corresponds to a network trained from scratch with different random initial weights. Accuracy is measured on 10 000 problem instances. The right plot shows the mean over all 30 models for each method. Models have a channel width of $w=32$. Shaded areas show 95% confidence intervals.
  • Figure 5: Comparison between Deep Thinking with Recall (DT-R) and Deep Thinking with Lipschitz Constraints (DT-L) on the mazes problem for small models. The two left plots show the solution accuracy of inference-time runs on $33\times33$ mazes for 14 different models each. The right plot shows the mean over all 14 models for each method. Models have a channel width of $w=32$. Shaded areas show 95% confidence intervals.
  • ...and 8 more figures
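Figure 2 reports spectral norms of reshaped convolutional weight matrices. The following is a minimal sketch of how such a quantity can be estimated. It assumes a 1D convolution kernel of shape (out_channels, in_channels, kernel_size), flattened row-wise into an out_channels × (in_channels · kernel_size) matrix. This is the same reshaping PyTorch's spectral normalization applies by default. The largest singular value is then found by power iteration. The helper names (`reshape_kernel`, `spectral_norm`, `rescale`) and the example kernel are hypothetical. Note that the spectral norm of the reshaped matrix is not in general equal to the Lipschitz constant of the convolution operator itself. `rescale` shows one common way to impose a bound (dividing by the estimated norm when it exceeds a target); whether DT-L enforces its constraint exactly this way is not stated here.

```python
import math
import random


def reshape_kernel(kernel):
    """Flatten a conv kernel of shape (out_ch, in_ch, k) into an out_ch x (in_ch * k) matrix."""
    return [[w for in_taps in out_filter for w in in_taps] for out_filter in kernel]


def matvec(m, v):
    """Compute m v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def rmatvec(m, u):
    """Compute m^T u."""
    return [sum(m[i][j] * u[i] for i in range(len(m))) for j in range(len(m[0]))]


def spectral_norm(m, iters=200, seed=0):
    """Estimate the largest singular value of m by power iteration on m^T m."""
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(len(m[0]))]
    for _ in range(iters):
        v = rmatvec(m, matvec(m, v))
        norm_v = math.sqrt(sum(a * a for a in v))
        if norm_v == 0.0:
            return 0.0  # m is the zero matrix (or v fell into its null space)
        v = [a / norm_v for a in v]
    # v now approximates the top right singular vector, so ||m v|| approximates sigma_max.
    return math.sqrt(sum(a * a for a in matvec(m, v)))


def rescale(m, target=1.0):
    """Divide m by its estimated spectral norm when that norm exceeds target."""
    sigma = spectral_norm(m)
    if sigma <= target:
        return m
    return [[a * target / sigma for a in row] for row in m]


# Hypothetical kernel: 2 output channels, 1 input channel, kernel size 3.
kernel = [[[3.0, 0.0, 0.0]], [[0.0, 4.0, 0.0]]]
mat = reshape_kernel(kernel)   # [[3.0, 0.0, 0.0], [0.0, 4.0, 0.0]]
sigma = spectral_norm(mat)     # singular values are 4 and 3, so sigma is close to 4
constrained = rescale(mat)     # spectral norm of the result is close to 1
```

Power iteration needs only matrix-vector products, which is why it is a common choice for tracking spectral norms during training instead of a full singular value decomposition.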