ContextFlow++: Generalist-Specialist Flow-based Generative Models with Mixed-Variable Context Encoding

Denis Gudovskiy, Tomoyuki Okuno, Yohei Nakata

TL;DR

ContextFlow++ introduces additive context conditioning and mixed-variable encoding to decouple a generalist normalizing flow from context-specific specialists. By sampling discrete and mixed contexts via surjective context encoders and decoupling the Jacobian contributions, the approach enables efficient two-stage training on large-scale data followed by domain-specific fine-tuning, while maintaining exact likelihood estimation for continuous variables. Across MNIST-R, CIFAR-10C, ATM predictive maintenance, and SMAP anomaly detection, ContextFlow++ demonstrates faster convergence and higher performance than prior conditioning methods, with flexible encoder choices balancing accuracy and efficiency. This work advances practical conditional flow modeling and opens avenues for extending to continuous-flow architectures (e.g., ODE-based) and relational-context data.
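For reference, the exact likelihood mentioned above rests on the standard change-of-variables identity for a composition of bijections. The block below states it in the notation of Figure 1 (data ${\bm{v}}$, base variable ${\bm{u}}$, layers $f_{{\bm{\theta}}_l}$). The intermediate names ${\bm{v}}_0 = {\bm{v}}$ and ${\bm{v}}_L = {\bm{u}}$, the layer count $L$, and the choice of $f$ as the data-to-base direction are assumptions made here for illustration.

$$
\log p_{{\bm{\theta}}}({\bm{v}}) = \log p({\bm{u}}) + \sum_{l=1}^{L} \log \left| \det \frac{\partial f_{{\bm{\theta}}_l}({\bm{v}}_{l-1})}{\partial {\bm{v}}_{l-1}} \right|,
\qquad {\bm{u}} = f_{{\bm{\theta}}_L} \circ \dots \circ f_{{\bm{\theta}}_1}({\bm{v}}).
$$

Each layer contributes one log-determinant term to the sum, which is why the per-layer Jacobian must stay cheap to evaluate when a context is introduced.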

Abstract

Normalizing flow-based generative models are widely used in applications where exact density estimation is of major importance. Recent research proposes numerous methods to improve their expressivity. However, conditioning on a context is a largely overlooked area in bijective flow research. Conventional conditioning by vector concatenation is limited to only a few flow types. More importantly, this approach cannot support a practical setup where a set of context-conditioned (specialist) models is trained with a fixed pretrained general-knowledge (generalist) model. We propose the ContextFlow++ approach to overcome these limitations using additive conditioning with explicit generalist-specialist knowledge decoupling. Furthermore, we support discrete contexts through the proposed mixed-variable architecture with context encoders. In particular, our context encoder for discrete variables is a surjective flow from which the context-conditioned continuous variables are sampled. Our experiments on rotated MNIST-R, corrupted CIFAR-10C, real-world ATM predictive maintenance, and SMAP unsupervised anomaly detection benchmarks show that the proposed ContextFlow++ offers faster, stable training and achieves higher performance metrics. Our code is publicly available at https://github.com/gudovskiy/contextflow.
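To make the contrast between concatenation and additive conditioning concrete, here is a minimal NumPy sketch of one affine coupling layer whose scale and shift are the sum of a generalist term and a context-dependent specialist term. This is an illustration under stated assumptions, not the formulation from the paper or its repository: the class name `AdditiveCouplingLayer`, the linear maps `W_g` and `W_c`, the `tanh` bound on the log-scale, and the zero initialization of the specialist weights are all hypothetical choices.

```python
import numpy as np

rng = np.random.default_rng(0)

class AdditiveCouplingLayer:
    """Affine coupling layer with generalist + specialist parameters (hypothetical)."""

    def __init__(self, dim, ctx_dim):
        self.d = dim // 2                  # number of pass-through features
        out = 2 * (dim - self.d)           # log-scale and shift for the rest
        # Generalist weights: stand-in for pretrained, frozen parameters.
        self.W_g = 0.1 * rng.standard_normal((self.d, out))
        # Specialist weights: zero-initialized, so the layer starts as the generalist.
        self.W_c = np.zeros((ctx_dim, out))

    def _params(self, v_a, c):
        h = v_a @ self.W_g
        if c is not None:
            h = h + c @ self.W_c           # context enters additively, not by concatenation
        log_s, t = np.split(h, 2, axis=-1)
        return np.tanh(log_s), t           # bounded log-scale for stability

    def forward(self, v, c=None):
        v_a, v_b = v[:, :self.d], v[:, self.d:]
        log_s, t = self._params(v_a, c)
        u_b = v_b * np.exp(log_s) + t
        log_det = log_s.sum(axis=-1)       # log |det Jacobian| of the affine map
        return np.concatenate([v_a, u_b], axis=-1), log_det

    def inverse(self, u, c=None):
        u_a, u_b = u[:, :self.d], u[:, self.d:]
        log_s, t = self._params(u_a, c)
        v_b = (u_b - t) * np.exp(-log_s)
        return np.concatenate([u_a, v_b], axis=-1)

layer = AdditiveCouplingLayer(dim=4, ctx_dim=3)
v = rng.standard_normal((5, 4))
c = rng.standard_normal((5, 3))

u_gen, _ = layer.forward(v)                # generalist only
u_spec, _ = layer.forward(v, c)            # identical while W_c is still zero

# Stand-in for specialist fine-tuning: only W_c changes, W_g stays fixed.
layer.W_c = 0.1 * rng.standard_normal(layer.W_c.shape)
u, log_det = layer.forward(v, c)
v_rec = layer.inverse(u, c)                # the layer remains exactly invertible
```

In this sketch the generalist input width never changes, so a pretrained `W_g` can be reused as-is, which is what concatenating ${\bm{c}}$ onto the layer input would prevent.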

Paper Structure (16 sections, 7 equations, 4 figures, 7 tables)


Figures (4)

  • Figure 1: Normalizing flows implement layered bijective transformations $f_{{\bm{\theta}}_l}$ between a target data distribution $p({\bm{v}})$ and a base distribution $p({\bm{u}})$ using learned parameters ${\bm{\theta}}_l$. A trained model $f_{{\bm{\theta}}}$ usually predicts an outcome $p_{{\bm{\theta}}}(y | {\bm{v}})$ (right) or samples data using the learned $p_{{\bm{\theta}}}({\bm{v}} | {\bm{u}})$ (left). When additional conditioning is needed to model $p({\bm{v}}, {\bm{c}})$, the conventional approach with concatenated vectors $[ {\bm{v}}_l, {\bm{c}} ]$ is limited in the types of bijections it supports and does not support the generalist-specialist training setup.
  • Figure 2: Our high-level scheme. Mixed-variable inputs and contexts are represented by vectors ${\bm{x}}_{g,c}^{{\mathbb{R}},{\mathbb{Z}}}$. First, the data encoder $g_{{\bm{\lambda}}_g}$ and decoder $f^{-1}_{{\bm{\theta}}_g}$ are learned during the large-scale generalist pretraining step. Next, the specialist context encoder $g_{{\bm{\lambda}}_c}$ and the extended decoder parameters $f^{-1}_{{\bm{\theta}}_c}$ are learned with small-scale data. Generative encoders convert discrete variables into continuous data ${\bm{v}}$ and context ${\bm{c}}$ vectors. A distributional model $h({\bm{\gamma}}_{g,c})$ also supports such two-step training and outputs likelihood estimates $p(y | {\bm{x}})$.
  • Figure 3: Our detailed ContextFlow++ architecture with mixed-variable data and context encoders $g_{{\bm{\lambda}}_{g,c}}$, which are implemented as sampling from a surjective flow model with various discrete-variable mapping and distribution options. The bijective flow decoder performs likelihood estimation on the encoder-produced ${\bm{v}}$ during the generalist training step with parameters ${\bm{\theta}}_g$. This is followed by context-specific specialist training with parameters ${\bm{\theta}}_c$ using sampled contexts ${\bm{c}}$. A distributional model $h({\bm{\gamma}}_{g,c})$ implements the task's probabilistic classifier and outputs likelihood estimates $p(y | {\bm{x}})$.
  • Figure 4: Top-1 accuracy on the CIFAR-10C test split vs. training epochs. The generalist model experiences a significant accuracy drop compared to the same model trained on the undistorted CIFAR-10. Our ContextFlow++ with an $\mathrm{arg\,max}$-based context encoder explicitly decouples general and context-specific knowledge. Compared with the conventional conditioning method, ours converges faster and reaches higher accuracy on CIFAR-10C.
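The $\mathrm{arg\,max}$-based context encoder named in Figures 3 and 4 can be pictured as sampling a continuous vector whose largest entry sits at the index of the discrete value, so that a deterministic $\mathrm{arg\,max}$ recovers the original category. The sketch below uses a thresholding construction in the style of argmax flows with fixed Gaussian noise. Whether ContextFlow++ uses this exact construction is an assumption, the function names `sample_argmax_encoding` and `decode` are hypothetical, and the learned encoder density with its log-likelihood term is omitted.

```python
import numpy as np

def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return np.logaddexp(0.0, z)

def sample_argmax_encoding(x, num_classes, rng):
    """Sample continuous v of shape (N, K) such that argmax(v, axis=-1) == x."""
    x = np.asarray(x)
    u = rng.standard_normal((x.shape[0], num_classes))  # unconstrained noise
    rows = np.arange(x.shape[0])
    T = u[rows, x][:, None]        # value at the true index acts as the threshold
    v = T - softplus(T - u)        # pushes every entry below the threshold T
    v[rows, x] = T[:, 0]           # keep the true index exactly at the threshold
    return v

def decode(v):
    # Deterministic surjection from continuous v back to the discrete variable.
    return np.argmax(v, axis=-1)

x = np.array([0, 2, 1, 2])         # discrete contexts with K = 3 categories
v = sample_argmax_encoding(x, num_classes=3, rng=np.random.default_rng(0))
x_rec = decode(v)
```

Many different `v` map to the same `x`, which is what makes the mapping surjective rather than bijective. A bijective decoder can then consume `v` as ordinary continuous input.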