Classification with Costly Features in Hierarchical Deep Sets
Jaromír Janisch, Tomáš Pevný, Viliam Lisý
TL;DR
The paper extends Classification with Costly Features (CwCF) to structured, hierarchical data by combining Hierarchical Multi-Instance Learning (HMIL) with hierarchical softmax, so that features can be acquired per sample anywhere within a tree-like schema. It is trained with Advantage Actor-Critic (A2C), decouples the classifier from the acquisition policy, and encodes inputs with HMIL into embeddings that are used for both classification and value estimation. Across seven datasets, including web-domain data derived from Threatcrowd, the approach is more cost-efficient because it acquires only informative features, often matching or exceeding full-information baselines at a fraction of the cost. The work also provides datasets and code, and it analyzes explainability, the benefits of pretraining, and computational characteristics, highlighting its practical value for classification where each API query has a cost, as in real-world streaming scenarios.
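The way HMIL turns a nested, JSON-like sample into a fixed-length embedding can be illustrated with a minimal pure-Python sketch. It is not the authors' implementation. It assumes a hypothetical schema format (the marker `"leaf"` for a scalar, a one-element list for a set of objects, a dict for an object), encodes each leaf as a (value, acquired-flag) pair so that unacquired features become zeros, and pools sets with elementwise mean and max. The learned dense layers that a real hierarchical deep set applies before and after pooling are omitted, so the identity stands in for them. The names `dim`, `embed`, `ip_count`, and `resolutions` are illustrative only.

```python
from typing import Any, List

# Hypothetical schema format: "leaf" marks a scalar feature, a one-element
# list [child] marks a set (list) of child objects, a dict marks an object.
Schema = Any

def dim(schema: Schema) -> int:
    """Length of the fixed-size embedding produced for `schema`."""
    if schema == "leaf":
        return 2                      # (value, acquired-flag)
    if isinstance(schema, list):
        return 2 * dim(schema[0])     # mean-pool ++ max-pool
    return sum(dim(child) for child in schema.values())

def embed(schema: Schema, value: Any) -> List[float]:
    """Embed a nested sample into a vector of length dim(schema)."""
    if schema == "leaf":
        # Unacquired (None) features become zeros with the flag off.
        return [0.0, 0.0] if value is None else [float(value), 1.0]
    if isinstance(schema, list):
        if not value:                 # missing or empty set
            return [0.0] * dim(schema)
        # The same encoder is shared by every element of the set ...
        children = [embed(schema[0], item) for item in value]
        # ... followed by permutation-invariant pooling (mean and max).
        mean = [sum(col) / len(children) for col in zip(*children)]
        maxi = [max(col) for col in zip(*children)]
        return mean + maxi
    # Object: concatenate child embeddings in a fixed (sorted) key order.
    fields = value or {}
    out: List[float] = []
    for key in sorted(schema):
        out.extend(embed(schema[key], fields.get(key)))
    return out

schema = {"ip_count": "leaf",
          "resolutions": [{"age": "leaf", "score": "leaf"}]}
sample = {"ip_count": 3.0,
          "resolutions": [{"age": 1.0, "score": 0.5},
                          {"age": 3.0, "score": None}]}
print(embed(schema, sample))  # [3.0, 1.0, 2.0, 1.0, 0.25, 0.5, 3.0, 1.0, 0.5, 1.0]
```

Because sets are pooled, the output length depends only on the schema, not on how many objects a sample contains, and reordering a list does not change the embedding.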
Abstract
Classification with Costly Features (CwCF) is a classification problem that includes the cost of features in the optimization criterion. For each sample individually, features are acquired sequentially to maximize accuracy while minimizing the cost of the acquired features. However, existing approaches can only process data that can be expressed as fixed-length vectors. Real-world data often has a rich and complex structure that is more precisely described by formats such as XML or JSON: it is hierarchical and often contains nested lists of objects. In this work, we extend an existing deep reinforcement learning-based algorithm with hierarchical deep sets and hierarchical softmax so that it can process such data directly. The extended method has finer control over which features it acquires, and in experiments with seven datasets we show that this leads to superior performance. To demonstrate the practical use of the new method, we apply it to the real-life problem of classifying malicious web domains using an online service.
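The cost-aware criterion can be made concrete with a per-step reward. The sketch below assumes the reward scheme commonly used in deep-RL formulations of CwCF, which we take to match the existing algorithm the abstract refers to: acquiring feature f yields -λ·c(f), and the terminal classification yields 0 if correct and -1 otherwise, so an episode's return is -(0-1 loss + λ·total cost). The function names, feature names, costs, and λ values are hypothetical.

```python
import math
from typing import Dict, Hashable, List, Sequence

def episode_rewards(acquired: Sequence[str], costs: Dict[str, float],
                    prediction: Hashable, label: Hashable,
                    lam: float) -> List[float]:
    """Per-step rewards of one CwCF episode (assumed formulation)."""
    # Every acquired feature f is penalised by lam * c(f) ...
    rewards = [-lam * costs[f] for f in acquired]
    # ... and the terminal classification receives the negative 0-1 loss.
    rewards.append(0.0 if prediction == label else -1.0)
    return rewards

def episode_return(acquired: Sequence[str], costs: Dict[str, float],
                   prediction: Hashable, label: Hashable,
                   lam: float) -> float:
    """Undiscounted return = -(0-1 loss + lam * total acquisition cost)."""
    return math.fsum(episode_rewards(acquired, costs, prediction, label, lam))

# Hypothetical per-query costs for a web-domain sample.
costs = {"whois": 1.0, "dns": 1.0, "resolutions": 5.0}
for lam in (0.05, 0.5):
    # Classify immediately with a wrong guess vs. pay for two queries and be right.
    guess = episode_return([], costs, "benign", "malicious", lam)
    probe = episode_return(["whois", "resolutions"], costs,
                           "malicious", "malicious", lam)
    print(f"lam={lam}: guess={guess:.2f}, probe={probe:.2f}")
```

With λ = 0.05 the two queries are worth buying (a return of about -0.3 versus -1.0), whereas with λ = 0.5 they cost -3.0, more than a wrong guess, so an optimal agent would classify immediately. The single parameter λ therefore sets the trade-off between accuracy and acquisition cost.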
