
On the creation of narrow AI: hierarchy and nonlocality of neural network skills

Eric J. Michaud, Asher Parker-Sartori, Max Tegmark

TL;DR

The paper investigates the feasibility of building strong yet narrow AI by examining how hierarchical task structure and distributed representations affect learning and transfer. Using the compositional multitask sparse parity (CMSP) toy task, it shows that some narrow skills are learned efficiently only when the network is trained on a broad data distribution, because the broader data supplies an implicit curriculum. It also shows that skills are often nonlocal, spread across many network components, which complicates pruning-based narrowing. Group-lasso regularization can align task-specific features with prunable units and enable unlearning of unwanted capabilities. Experiments on MNIST and on LLMs specialized to Python code show that pruning can outperform distillation or training from scratch for creating compact, targeted models. These findings inform the design of narrow-AI ecosystems and highlight both the potential and the limits of pruning-based specialization in real-world AI systems.

Abstract

We study the problem of creating strong, yet narrow, AI systems. While recent AI progress has been driven by the training of large general-purpose foundation models, the creation of smaller models specialized for narrow domains could be valuable for both efficiency and safety. In this work, we explore two challenges involved in creating such systems, having to do with basic properties of how neural networks learn and structure their representations. The first challenge regards when it is possible to train narrow models from scratch. Through experiments on a synthetic task, we find that it is sometimes necessary to train networks on a wide distribution of data to learn certain narrow skills within that distribution. This effect arises when skills depend on each other hierarchically, and training on a broad distribution introduces a curriculum which substantially accelerates learning. The second challenge regards how to transfer particular skills from large general models into small specialized models. We find that model skills are often not perfectly localized to a particular set of prunable components. However, we find that methods based on pruning can still outperform distillation. We investigate the use of a regularization objective to align desired skills with prunable components while unlearning unnecessary skills.
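The abstract does not spell out the regularization objective. As a rough sketch only, a common form of group lasso adds $\lambda \sum_j \lVert w_j \rVert_2$ to the task loss, where $w_j$ collects the weights in one group. The example below assumes one group per hidden neuron, containing that neuron's incoming and outgoing weights; this grouping, the function names, and the tolerance are illustrative assumptions rather than the authors' implementation.

```python
import math


def group_lasso_penalty(w_in, w_out, lam):
    """Return lam * sum over hidden neurons of the L2 norm of that neuron's weights.

    w_in:  list of rows, one per hidden neuron (its incoming weights).
    w_out: list of rows, one per output unit; column j belongs to hidden neuron j.
    Grouping incoming and outgoing weights per neuron is an assumption.
    """
    total = 0.0
    for j, incoming in enumerate(w_in):
        squared = sum(w * w for w in incoming)
        squared += sum(out_row[j] ** 2 for out_row in w_out)
        total += math.sqrt(squared)
    return lam * total


def prunable_neurons(w_in, w_out, tol=1e-6):
    """Indices of hidden neurons whose whole weight group has norm below tol."""
    result = []
    for j, incoming in enumerate(w_in):
        squared = sum(w * w for w in incoming)
        squared += sum(out_row[j] ** 2 for out_row in w_out)
        if math.sqrt(squared) < tol:
            result.append(j)
    return result
```

Because the penalty is a sum of unsquared norms, it tends to drive entire groups to zero rather than shrinking all weights uniformly, which is what makes whole neurons removable afterwards.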

Paper Structure

This paper contains 16 sections, 14 figures, 3 tables.

Figures (14)

  • Figure 1: We study two challenges to making strong, narrow-purpose AI models. (A): Data may have hierarchical structure. If skills have a hierarchical dependence, where some skills are only learnable after more primitive skills are learned first, then it sometimes may be necessary to train on a broad distribution of data to learn certain narrow skills within that distribution. These dynamics may mean that general-purpose models must be trained to achieve good performance on some domains. (B): Model features are distributed. By default, skills may not be localizable to a particular set of model components (e.g. neurons). In this case, pruning of model components won't precisely retain wanted skills and remove unwanted skills from models. We explore methods for aligning the model features relevant to particular domains with a smaller subset of model components while unlearning others.
  • Figure 2: Training dynamics on compositional multitask sparse parity. Top: training dynamics for a single network trained on four atomic subtasks {0}, {1}, {2}, {3}, and their composition {0,1,2,3}. Bottom left: loss on compositional subtask {0,1,2,3}, when training on atomic subtasks and their composition, as in the top subplot, across 40 network seeds (depth 3, width 128). Bottom right: loss on compositional subtask {0,1,2,3}, when only training on samples from that composite task, without also training on atomic tasks, across 40 network seeds (depth 3, width 128). We find that removing the atomic tasks prevents our networks from learning the composite task. We report the minimum loss within the previous 100 steps of training to filter out loss spikes.
  • Figure 3: Top: We visualize the connectivity of 2-hidden-layer MLPs trained on a CMSP distribution with subtasks {0}, {1}, {2}, {0,1,2}, {3}, {4}, {5}, {3,4,5}, before (left) and after (right) regularizing network weights with the group lasso sparsity penalty while training on subtask {0,1,2}. We find that network connectivity becomes sparse after regularizing. Negative weights are shown in blue and positive ones in red, with width proportional to the norm of the weight. Bottom: we show how pruning affects task performance on subtasks {0,1,2} and {3,4,5} at varying sparsity levels. We prune neurons based on the absolute change that ablating them has on the loss on subtask {0,1,2}. We find that subtasks here are nonlocal and entangled in the "pretrained" network (left). When we prune neurons according to their relevance to subtask {0,1,2}, up to the sparsity at which accuracy on subtask {0,1,2} drops below 98% (green line), we can still recover some performance on subtask {3,4,5} with a small amount of additional training. Thus naive pruning here has not completely and robustly unlearned subtask {3,4,5}. However, after regularizing the weights (right), we find that not only can we prune the network more aggressively, but we have also robustly unlearned subtask {3,4,5}. Note that the degree to which subtasks are nonlocal and entangled in the pretrained networks depends on seed and width, and we show a variety of additional curves in a separate figure (label: fig:pruningresultspretrainedacrosswidthandseed).
  • Figure 4: Left: We compare the performance of distillation, training from scratch, and two pruning approaches for creating small networks that classify MNIST even digits. Pruning-based approaches Pareto-dominate distillation, achieving high compression ratios with fewer datapoints. All points are averaged over 10 individual training runs. Right: When pruning using group lasso, it often helps to first prune rapidly, degrading performance, and then recover performance with no regularization. Each line represents a single training run, with a new point logged every 2,000 datapoints. Darker lines correspond to more aggressive pruning runs that attain lower final neuron counts.
  • Figure 5: Left: Neuron sparsity vs. loss curve for LLMs finetuned with group lasso regularization with varying $\lambda$ for 70k steps, as compared to the base network. Regularization flattens the sparsity vs. loss curve, at the cost of slightly degrading model performance. Right: After pruning our networks to 30%, 63%, and 80% sparsity for our runs with $\lambda$ of 5e-4, 3e-4, and 1e-3, respectively, we recover performance with additional training. We find that we can recover performance lost during pruning, including in the network that was pruned without first training with group-lasso regularization. The plot shows an exponential moving average of batch losses, with individual batch losses shaded in the background.
  • ...and 9 more figures
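The captions above refer to atomic subtasks such as {0} and composite subtasks such as {0,1,2,3} without defining the CMSP data format. The sketch below is one plausible reading, stated as an assumption: each input is a block of control bits marking the active atomic subtasks, followed by random task bits, and the label is the parity of the task bits in the union of the active subtasks' index sets. The function name, the index sets, and the sizes are hypothetical.

```python
import random


def make_cmsp_sample(subtask, index_sets, n_control, n_task_bits, rng):
    """Build one (input, label) pair for an atomic or composite subtask.

    subtask:    tuple of active atomic subtask ids, e.g. (0,) or (0, 1, 2, 3).
    index_sets: dict mapping each atomic id to the task-bit positions it reads.
    Assumption: the label is the parity over the union of those positions.
    """
    # Control bits tell the network which subtask(s) this sample belongs to.
    control = [1 if i in subtask else 0 for i in range(n_control)]
    # Task bits are sampled uniformly at random.
    bits = [rng.randint(0, 1) for _ in range(n_task_bits)]
    relevant = set()
    for i in subtask:
        relevant |= set(index_sets[i])
    label = sum(bits[j] for j in relevant) % 2
    return control + bits, label


if __name__ == "__main__":
    rng = random.Random(0)
    sets = {0: [0, 1, 2], 1: [3, 4, 5], 2: [6, 7, 8], 3: [9, 10, 11]}
    x, y = make_cmsp_sample((0, 1, 2, 3), sets, n_control=4, n_task_bits=16, rng=rng)
    print(x, y)
```

Under this reading, when the atomic index sets are disjoint, the composite label equals the XOR of the atomic parities. That gives one intuition for the curriculum effect in Figure 2: a network that already computes each atomic parity only needs to combine them, whereas learning the composite directly means finding a parity over many more bits.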