Convergence and Implicit Bias of Gradient Descent on Continual Linear Classification

Hyunji Jung, Hanseul Cho, Chulhee Yun

TL;DR

The paper analyzes continual linear binary classification under a fixed per-task iteration budget, using sequential gradient descent (GD) to update the weight vector as tasks arrive in cyclic or random order. It shows that if the tasks are jointly linearly separable, the GD trajectory converges in direction to the offline joint max-margin solution, revealing an implicit bias distinct from that of projection-based methods like Sequential Max-Margin (SMM). Non-asymptotic results quantify forgetting across cycles: cycle-averaged forgetting decays as ${\mathcal O}(\ln^4 J / J^2)$ and the overall loss decays as ${\mathcal O}(\ln^2 J / J)$, with the forgetting controlled by the alignment between tasks. The paper also extends the results to random task ordering and to the non-separable scenario, where the model converges to the unique joint minimum at a fast rate of ${\mathcal O}(\ln^2 J / J^2)$ under cyclic updates, facilitated by local strong convexity. Together, these findings show how continual learning via GD can integrate knowledge across tasks and diminish forgetting over cycles.

Abstract

We study continual learning on multiple linear classification tasks by sequentially running gradient descent (GD) for a fixed budget of iterations per task. When all tasks are jointly linearly separable and are presented in a cyclic/random order, we show the directional convergence of the trained linear classifier to the joint (offline) max-margin solution. This is surprising because GD training on a single task is implicitly biased towards the individual max-margin solution for the task, and the direction of the joint max-margin solution can be largely different from these individual solutions. Additionally, when tasks are given in a cyclic order, we present a non-asymptotic analysis on cycle-averaged forgetting, revealing that (1) alignment between tasks is indeed closely tied to catastrophic forgetting and backward knowledge transfer and (2) the amount of forgetting vanishes to zero as the cycle repeats. Lastly, we analyze the case where the tasks are no longer jointly separable and show that the model trained in a cyclic order converges to the unique minimum of the joint loss function.
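
The setup described in the abstract can be illustrated with a minimal sketch. It is not the authors' code. It assumes the logistic loss, a hypothetical two-task toy dataset with one point per task, and an illustrative step size that has not been checked against the learning-rate condition of the paper's theorems. For this toy data the joint $\ell_2$ max-margin direction is $(1, 0)$ by symmetry, while the individual max-margin directions are $(1, 2)$ and $(1, -2)$. The names `sequential_gd`, `logistic_grad`, and `cosine` are made up for this example.

```python
import numpy as np

def logistic_grad(w, X, y):
    """Gradient of sum_i ln(1 + exp(-y_i <w, x_i>)) with respect to w."""
    margins = y * (X @ w)
    # d/dw ln(1 + exp(-m_i)) = -y_i x_i / (1 + exp(m_i))
    coeff = -y / (1.0 + np.exp(margins))
    return X.T @ coeff

def sequential_gd(tasks, eta=0.1, iters_per_task=5, num_cycles=2000):
    """Run GD on each task in turn (cyclic order) with a fixed per-task budget."""
    w = np.zeros(tasks[0][0].shape[1])
    for _cycle in range(num_cycles):
        for X, y in tasks:                      # tasks arrive cyclically
            for _step in range(iters_per_task):  # fixed iteration budget per task
                w = w - eta * logistic_grad(w, X, y)
    return w

def cosine(u, v):
    """Cosine of the angle between two vectors."""
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

# Hypothetical toy tasks (assumption): one positively labeled point each.
tasks = [
    (np.array([[1.0, 2.0]]), np.array([1.0])),   # Task 1
    (np.array([[1.0, -2.0]]), np.array([1.0])),  # Task 2
]
# Solves min ||w|| s.t. <w, x> >= 1 for both points; equals (1, 0) by symmetry.
joint_max_margin = np.array([1.0, 0.0])

w = sequential_gd(tasks)
print("cosine to joint max-margin direction:", cosine(w, joint_max_margin))
```

In this sketch the iterate's first coordinate keeps growing while the second stays bounded, so the direction of `w` drifts toward $(1, 0)$ and not toward either task's own max-margin direction. Varying `iters_per_task` or the task order is a simple way to probe the behavior the abstract describes.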

Paper Structure

This paper contains 59 sections, 46 theorems, 218 equations, 11 figures.

Key Result

Theorem 3.1

Let $\{{\bm{w}}^{(t)}_k\}_{k \in [0:K-1], t \geq 0}$ be the sequence of GD iterates eq:CL_GD from any starting point ${\bm{w}}^{(0)}_0$, where tasks are given cyclically. Under Assumptions assum:seperable and assum:loss_shape, if the learning rate satisfies $\eta < \frac{\phi^2}{2K\beta\sigma^3_{\ma

Figures (11)

  • Figure 1: Trajectory of sequential GD on a two-task toy example (\ref{subsec:experiment_detail_figintro}) in which the offline max-margin direction does not lie in the subspace spanned by the individual task max-margin solutions. Sequential GD iterates initially oscillate but quickly start to evolve along the offline max-margin direction.
  • Figure 2: Comparison between a continually learned and a jointly trained linear classifier. We generate three jointly separable binary classification tasks (with 2D inputs) and run (1) sequential GD in a cyclic task ordering and (2) full-batch GD. It is well known that offline full-batch GD converges to the offline $\ell_2$ max-margin solution (Soudry et al., 2018). We verify a similar implicit bias of sequential GD iterates (which we proved in \ref{thm:cyclic_direction_convergence}) by observing the decrease in the angle between the model weight and the joint max-margin direction (set as $(1,1)$). We also observe similar phenomena in a more general experimental setup (e.g., random task ordering): see \ref{subsec:experiment_detail_figthm3132}.
  • Figure 3: We compare two continual learning scenarios with the same joint dataset ${\mathcal{D}} = \{(1,2), (1.1,1.8), (1.2,1.9), (1,-2), (1.1,-1.8), (1.2,-1.9)\}$, where labels are all $+1$ and hence omitted. We mark Task 1's data as 'o' and Task 2's data as '+'. We used $M=2$ and $K=10$. \ref{fig:forgetting_task1_data} displays a data composition that makes $A^-_{1,2}$ large, whereas \ref{fig:forgetting_task2_data} displays a data composition that makes $A^-_{1,2}$ relatively small and $A^+_{1,2}$ large. \ref{fig:outline_forgetting_dist}(b) plots the cycle-averaged forgetting (${\mathcal{F}}_{\rm cyc}$) as it evolves over cycles. For the "contradict" scenario (red), ${\mathcal{F}}_{\rm cyc}$ is always positive and diminishes to 0. In contrast, for the "aligned" scenario (blue), ${\mathcal{F}}_{\rm cyc}$ is always negative and rises to 0.
  • Figure 4: We run SMM iterations on the toy example by solving the projection problems using an optimization solver.
  • Figure 5: Loss convergence results for cyclic task ordering. The joint training loss is divided by the number of tasks in order to match the scale.
  • ...and 6 more figures

Theorems & Definitions (68)

  • Theorem 3.1
  • Theorem 3.2
  • Remark 3.6: Asymptotic Convergence Rate of Joint Training Loss
  • Definition 3.7: Forgetting
  • Definition 3.8: Cycle-Averaged Forgetting
  • Theorem 3.3
  • Theorem 3.4
  • Remark 3.9
  • Theorem 4.1
  • Theorem 4.2
  • ...and 58 more