Paper ID: 2202.04110

PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX

Guangyao Zhou, Antoine Dedieu, Nishanth Kumar, Wolfgang Lehrach, Miguel Lázaro-Gredilla, Shrinu Kushagra, Dileep George

PGMax is an open-source Python package for (a) easily specifying discrete Probabilistic Graphical Models (PGMs) as factor graphs; and (b) automatically running efficient and scalable loopy belief propagation (LBP) in JAX. PGMax supports general factor graphs with tractable factors, and leverages modern accelerators like GPUs for inference. Compared with existing alternatives, PGMax obtains higher-quality inference results with inference-time speedups of up to three orders of magnitude. PGMax additionally interacts seamlessly with the rapidly growing JAX ecosystem, opening up new research possibilities. Our source code, examples, and documentation are available at https://github.com/deepmind/PGMax.
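To make the term concrete, the following is a minimal sketch of sum-product loopy belief propagation on a small pairwise model. It is written in plain Python for readability and is not PGMax's API or its JAX implementation: the function names `loopy_bp` and `exact_marginals`, the data layout, the restriction to pairwise factors, and the example potentials are all assumptions made for illustration.

```python
import itertools


def loopy_bp(unary, edges, pairwise, num_iters=50):
    """Sum-product LBP with a parallel (flooding) schedule.

    unary[i][x]           : positive potential of variable i in state x
    edges                 : list of (i, j) pairs, each undirected edge listed once
    pairwise[(i, j)][x][y]: positive potential for x_i = x, x_j = y
    Returns one normalized belief vector per variable.
    """
    n = len(unary)
    neighbors = {i: [] for i in range(n)}
    msgs = {}  # msgs[(i, j)]: message from i to j, indexed by states of j
    for i, j in edges:
        neighbors[i].append(j)
        neighbors[j].append(i)
        msgs[(i, j)] = [1.0 / len(unary[j])] * len(unary[j])
        msgs[(j, i)] = [1.0 / len(unary[i])] * len(unary[i])

    def psi(i, j, xi, xj):
        # Look up the edge potential regardless of message direction.
        if (i, j) in pairwise:
            return pairwise[(i, j)][xi][xj]
        return pairwise[(j, i)][xj][xi]

    for _ in range(num_iters):
        new_msgs = {}
        for (i, j) in msgs:
            out = []
            for xj in range(len(unary[j])):
                total = 0.0
                for xi in range(len(unary[i])):
                    # Product of i's unary term and all incoming messages except j's.
                    prod = unary[i][xi]
                    for k in neighbors[i]:
                        if k != j:
                            prod *= msgs[(k, i)][xi]
                    total += prod * psi(i, j, xi, xj)
                out.append(total)
            z = sum(out)
            new_msgs[(i, j)] = [v / z for v in out]  # normalize for stability
        msgs = new_msgs

    beliefs = []
    for i in range(n):
        b = list(unary[i])
        for k in neighbors[i]:
            b = [bv * m for bv, m in zip(b, msgs[(k, i)])]
        z = sum(b)
        beliefs.append([v / z for v in b])
    return beliefs


def exact_marginals(unary, edges, pairwise):
    """Brute-force marginals by enumeration, for checking small models only."""
    sizes = [len(u) for u in unary]
    marg = [[0.0] * s for s in sizes]
    for x in itertools.product(*[range(s) for s in sizes]):
        w = 1.0
        for i, u in enumerate(unary):
            w *= u[x[i]]
        for (i, j) in edges:
            w *= pairwise[(i, j)][x[i]][x[j]]
        for i in range(len(unary)):
            marg[i][x[i]] += w
    z = sum(marg[0])
    return [[v / z for v in m] for m in marg]


# Hypothetical example: three binary variables connected in a cycle.
attract = [[2.0, 1.0], [1.0, 2.0]]  # favors equal neighboring states
cycle_unary = [[3.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
cycle_edges = [(0, 1), (1, 2), (0, 2)]
cycle_pairwise = {e: attract for e in cycle_edges}
cycle_beliefs = loopy_bp(cycle_unary, cycle_edges, cycle_pairwise)
```

On a tree-structured graph these updates recover the exact marginals; on a graph with cycles, such as the example above, the beliefs are approximations and convergence is not guaranteed in general. Every message in an iteration depends only on the previous iteration's messages, so the updates can be computed in parallel, which is the property that makes array-based execution on accelerators attractive.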

Submitted: Feb 8, 2022