CSI-BERT2: A BERT-inspired Framework for Efficient CSI Prediction and Classification in Wireless Communication and Sensing
Zijian Zhao, Fanyi Meng, Zhonghao Lyu, Hang Li, Xiaoyang Li, Guangxu Zhu
TL;DR
This work tackles the challenges of learning from limited CSI data in wireless sensing and the high-dimensional, rapidly varying CSI in wireless communication. It introduces CSI-BERT2, a unified BERT-inspired framework that combines MLM-based unsupervised pretraining with a mask prediction model (MPM), an adaptive re-weighting layer (ARL), and an MLP-based temporal embedding to enable CSI prediction and classification. The model uses a dual-head architecture and GAN-based pretraining to learn robust representations, achieving state-of-the-art results across four real/simulated datasets and showing strong robustness to sampling-rate changes and packet loss, with rapid inference suitable for real-time deployment. These advances point toward a practical, generalizable foundation for wireless sensing and communication tasks without task-specific retraining for varying data conditions.
Abstract
Channel state information (CSI) is a fundamental component in both wireless communication and sensing systems, enabling critical functions such as radio resource optimization and environmental perception. In wireless sensing, data scarcity and packet loss hinder efficient model training, while in wireless communication, high-dimensional CSI matrices and short coherence times caused by high mobility make CSI estimation challenging. To address these issues, we propose a unified framework named CSI-BERT2 for CSI prediction and classification tasks, built on CSI-BERT, which adapts BERT to capture the complex relationships among CSI sequences through a bidirectional self-attention mechanism. We introduce a two-stage training method that first uses a masked language model (MLM) to enable the model to learn general feature extraction from scarce datasets in an unsupervised manner, followed by fine-tuning for specific downstream tasks. Specifically, we extend MLM into a mask prediction model (MPM), which efficiently addresses the CSI prediction task. To further enhance the representation capacity of CSI data, we modify the structure of the original CSI-BERT. We introduce an adaptive re-weighting layer (ARL) to enhance subcarrier representation and a multi-layer perceptron (MLP)-based temporal embedding module to mitigate the temporal information loss problem inherent in the original Transformer. Extensive experiments on both real-world and simulated datasets demonstrate that CSI-BERT2 achieves state-of-the-art performance across all tasks. Our results further show that CSI-BERT2 generalizes effectively across varying sampling rates and robustly handles discontinuous CSI sequences caused by packet loss, challenges that conventional methods fail to address. The dataset and code are publicly available at https://github.com/RS2002/CSI-BERT2.
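To make the masking idea concrete, the following is a minimal, illustrative sketch in plain Python and is not the authors' implementation. It assumes that CSI is a sequence of per-time-step subcarrier vectors, that pretraining hides randomly chosen time steps, and that prediction can be framed as hiding the trailing steps of the sequence. The helper names (`random_mask`, `future_mask`, `apply_mask`, `masked_mse`), the zero placeholder, and the toy values are all hypothetical; the actual masking strategy and loss are defined in the paper and repository.

```python
import random

def random_mask(num_steps, mask_ratio, rng):
    """Pick random time steps to hide, as in masked-model pretraining."""
    k = max(1, round(num_steps * mask_ratio))
    masked = set(rng.sample(range(num_steps), k))
    return [t in masked for t in range(num_steps)]

def future_mask(num_steps, horizon):
    """Hide the last `horizon` steps, so filling them in amounts to prediction."""
    return [t >= num_steps - horizon for t in range(num_steps)]

def apply_mask(csi, mask, fill=0.0):
    """Replace every masked time step with a placeholder vector (assumed zeros)."""
    return [[fill] * len(row) if m else list(row) for row, m in zip(csi, mask)]

def masked_mse(pred, target, mask):
    """Mean squared error computed over masked time steps only."""
    errs = [(p - t) ** 2
            for pred_row, target_row, m in zip(pred, target, mask) if m
            for p, t in zip(pred_row, target_row)]
    return sum(errs) / len(errs)

rng = random.Random(0)
# Toy CSI amplitudes: 8 time steps x 4 subcarriers (made-up values).
csi = [[rng.random() for _ in range(4)] for _ in range(8)]

pretrain_mask = random_mask(len(csi), mask_ratio=0.25, rng=rng)
predict_mask = future_mask(len(csi), horizon=2)

corrupted = apply_mask(csi, predict_mask)
# Stand-in for a model: echo the corrupted input. A trained network would
# replace this and try to drive the masked-position error toward zero.
loss = masked_mse(corrupted, csi, predict_mask)
```

Under these assumptions, the same reconstruction loss serves both stages: only the choice of mask changes between unsupervised pretraining and prediction.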
