SMASH: One-Shot Model Architecture Search through HyperNetworks

Andrew Brock, Theodore Lim, J. M. Ritchie, Nick Weston

TL;DR

SMASH tackles the expensive architecture search problem by learning a HyperNet that generates network weights conditioned on architecture. It uses a memory-bank based encoding to permit diverse connectivity patterns and sample architectures efficiently. The method enables ranking architectures from a single training run via SMASH scores, with empirical correlations to fully trained performance and competitive results on CIFAR-10/100, STL-10, ModelNet10, and Imagenet32x32. The work also explores transfer learning, architectural gradient proxies, and avenues for future improvements in sampling and memory-augmented designs.

Abstract

Designing architectures for deep neural networks requires expert knowledge and substantial computation time. We propose a technique to accelerate architecture selection by learning an auxiliary HyperNet that generates the weights of a main model conditioned on that model's architecture. By comparing the relative validation performance of networks with HyperNet-generated weights, we can effectively search over a wide range of architectures at the cost of a single training run. To facilitate this search, we develop a flexible mechanism based on memory read-writes that allows us to define a wide range of network connectivity patterns, with ResNet, DenseNet, and FractalNet blocks as special cases. We validate our method (SMASH) on CIFAR-10 and CIFAR-100, STL-10, ModelNet10, and Imagenet32x32, achieving competitive performance with similarly-sized hand-designed networks. Our code is available at https://github.com/ajbrock/SMASH
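The one-shot idea in the abstract can be illustrated with a toy sketch. It is not the authors' implementation (that is in the linked repository). Everything here is an assumption made for illustration: the "architecture" is a binary mask over input connections, the HyperNet is a single linear map `H`, the main model is a linear regressor, and the names `sample_arch`, `generate_weights`, `train_hypernet`, and `rank_architectures` are hypothetical. The sketch trains the HyperNet once, sampling a new random architecture at every step, and then ranks candidate architectures by validation error using generated weights.

```python
import numpy as np

def sample_arch(rng, d):
    # Hypothetical encoding: binary mask saying which of d inputs are connected.
    c = rng.integers(0, 2, size=d).astype(float)
    if c.sum() == 0:
        c[rng.integers(d)] = 1.0  # keep at least one connection
    return c

def generate_weights(H, c):
    # Linear stand-in for a HyperNet: weights depend on the architecture c,
    # and inactive connections are zeroed out.
    return (H @ c) * c

def mse(X, y, w):
    r = X @ w - y
    return float(np.mean(r ** 2))

def train_hypernet(X, y, rng, steps=2000, lr=0.02):
    d = X.shape[1]
    H = np.zeros((d, d))
    for _ in range(steps):
        c = sample_arch(rng, d)                    # new random architecture each step
        w = generate_weights(H, c)
        grad_w = 2.0 * X.T @ (X @ w - y) / len(y)  # dL/dw for the MSE loss
        H -= lr * np.outer(grad_w * c, c)          # chain rule through w = (H c) * c
    return H

def rank_architectures(H, X_val, y_val, rng, n_candidates=20):
    d = X_val.shape[1]
    scored = []
    for _ in range(n_candidates):
        c = sample_arch(rng, d)
        # Validation error with generated weights plays the role of the score.
        scored.append((mse(X_val, y_val, generate_weights(H, c)), c))
    scored.sort(key=lambda t: t[0])                # lowest error first
    return scored

rng = np.random.default_rng(0)
X = rng.normal(size=(400, 6))
y = 2.0 * X[:, 0] - X[:, 1] + 0.1 * rng.normal(size=400)  # synthetic toy data

H = train_hypernet(X[:300], y[:300], rng)          # the single training run
ranking = rank_architectures(H, X[300:], y[300:], rng)
best_score, best_arch = ranking[0]
print("best candidate mask:", best_arch, "validation MSE:", best_score)
```

In the method described above, the top-ranked architecture would then be trained normally from scratch. The generated weights serve only to compare candidates.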

Paper Structure

This paper contains 11 sections, 13 figures, 3 tables, 1 algorithm.

Figures (13)

  • Figure 1: Memory-Bank representations of ResNet, DenseNet, and FractalNet blocks.
  • Figure 2: (a) Structure of one op: A 1x1 conv operating on the memory banks, followed by up to 2 parallel paths of 2 convolutions each. (b) Basic network skeleton.
  • Figure 3: An unrolled graph, its equivalent memory-bank representation, and its encoded embedding.
  • Figure 4: True error and SMASH validation error for 50 different random architectures on CIFAR-100. Red line is a least-squares best fit.
  • Figure 5: (a) SMASH correlation with a crippled HyperNet. Error bars represent 1 standard deviation. (b) SMASH scores vs. rank using average scores from three HyperNets with different seeds.
  • ...and 8 more figures