NASimJax: GPU-Accelerated Policy Learning Framework for Penetration Testing

Raphael Simon, José Carrasquel, Wim Mees, Pieter Libin

Abstract

Penetration testing, the practice of simulating cyberattacks to identify vulnerabilities, is a complex sequential decision-making task that is inherently partially observable and features large action spaces. Training reinforcement learning (RL) policies for this domain faces a fundamental bottleneck: existing simulators are too slow to train on realistic network scenarios at scale, resulting in policies that fail to generalize. We present NASimJax, a complete JAX-based reimplementation of the Network Attack Simulator (NASim), achieving up to 100x higher environment throughput than the original simulator. By running the entire training pipeline on hardware accelerators, NASimJax enables experimentation on larger networks under fixed compute budgets that were previously infeasible. We formulate automated penetration testing as a Contextual POMDP and introduce a network generation pipeline that produces structurally diverse and guaranteed-solvable scenarios. Together, these provide a principled basis for studying zero-shot policy generalization. We use the framework to investigate action-space scaling and generalization across networks of up to 40 hosts. We find that Prioritized Level Replay better handles dense training distributions than Domain Randomization, particularly at larger scales, and that training on sparser topologies yields an implicit curriculum that improves out-of-distribution generalization, even on topologies denser than those seen during training. To handle linearly growing action spaces, we propose a two-stage action decomposition (2SAS) that substantially outperforms flat action masking at scale. Finally, we identify a failure mode arising from the interaction between Prioritized Level Replay's episode-reset behaviour and 2SAS's credit assignment structure. NASimJax thus provides a fast, flexible, and realistic platform for advancing RL-based penetration testing.
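The contrast between flat action masking and a two-stage decomposition can be sketched in JAX as follows. The abstract does not spell out how 2SAS factorizes the action, so this sketch assumes the first stage selects a target host and the second selects an action type for that host. The function names, array shapes, and the `NEG_INF` masking constant are all hypothetical and are not taken from NASimJax.

```python
import jax
import jax.numpy as jnp

NEG_INF = -1e9  # assumed masking constant: a logit this low is never sampled


def sample_flat(key, logits, mask):
    """Flat action masking: one categorical over every (host, action-type) pair.

    logits, mask: shape (num_hosts * num_action_types,)
    """
    masked = jnp.where(mask, logits, NEG_INF)
    return jax.random.categorical(key, masked)


def sample_two_stage(key, host_logits, type_logits, mask):
    """Assumed two-stage selection: pick a host, then an action type for it.

    host_logits: (num_hosts,)
    type_logits: (num_hosts, num_action_types), precomputed here for simplicity
    mask:        (num_hosts, num_action_types), True where the action is valid
    Assumes at least one valid action exists.
    """
    host_key, type_key = jax.random.split(key)
    # Stage 1: a host is selectable only if it has at least one valid action.
    host_mask = mask.any(axis=1)
    host = jax.random.categorical(
        host_key, jnp.where(host_mask, host_logits, NEG_INF)
    )
    # Stage 2: mask action types for the selected host only.
    masked_types = jnp.where(mask[host], type_logits[host], NEG_INF)
    action_type = jax.random.categorical(type_key, masked_types)
    return host, action_type


# Batched over parallel environments, in the spirit of accelerator-side training.
sample_two_stage_batched = jax.jit(jax.vmap(sample_two_stage))
```

In this sketch each stage samples from a categorical of size `num_hosts` or `num_action_types`, rather than one categorical of size `num_hosts * num_action_types`, which is the property that matters as the number of hosts grows. A real second stage would typically condition its logits on the chosen host rather than index a precomputed table.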

Paper Structure (41 sections, 11 figures, 8 tables)

Figures (11)

  • Figure 1: Left: The number of hosts ($N_h$) and subnets ($N_s$), together with the topology ($t_d$), service ($svc_d$), process ($proc_d$) and sensitive host ($s_d$) densities, are the parameters that influence the generated networks. The blue nodes represent normal hosts, the red node labelled A represents the attacker's position, and the double-circled red nodes represent sensitive hosts. Right: Illustration of the batched state representation. $F$ is the dimension of the host features.
  • Figure 2: Training speed comparison between NASim with a training budget of 10M steps and NASimJax with 1M, 10M and 100M steps. The number of environment workers is doubled each time. Results show the impact of JAX's JIT compilation on total runtime. Details of the speed test are available in Section \ref{sec:perf_comp}. The full results are in Appendix \ref{app:speed_comp_details}.
  • Figure 3: Visualization of $t_d$'s effect on active host counts within generated networks of 26 hosts. The full network parameters are displayed in Table \ref{tab:env_config}.
  • Figure 4: Comparison between standard action masking and 2SAS on 16, 26 and 40 host networks. Both algorithms use DR and evaluate in-distribution performance on 50 evaluation networks. Results are reported over 5 seeds with 95% CI.
  • Figure 5: ZSPT across topology densities $t_d$ for 16-, 26- and 40-host networks. Each algorithm is trained on three values of $t_d$ (low, mid, high) and evaluated on all other $t_d$. Bars show the mean solve rate aggregated over five seeds; boxes span the IQR (Q1--Q3), whiskers the full range.
  • ...and 6 more figures