
Investigate running Stan models on GPU
Open, Medium, Public

Description

Background

The Product-Analytics team relies heavily on Stan, a probabilistic programming language and software platform, to build and fit statistical models. Specifically, @nettrom_WMF fits complex hierarchical regression models to analyze the Growth team's experiments. These fits often take many hours and sometimes days, even with parallel chains on stat1008 CPU cores. (Bayesian models fit this way are more computationally intensive than non-Bayesian models since they rely on MCMC for inference.)

In January, the Stan Development Team expanded support for OpenCL:

In addition to the GLM lpdf/lpmf functions, there are now 32 additional distributions that can utilize OpenCL on the GPU or CPU to speedup execution. The speedups vary between distributions and argument types (whether they are data or parameters). We have observed speedups in the range of 2- to 50-fold compared to sequential execution.

This recent tweet in particular is very promising and sparked this task:

I've rewritten the Stan program for a four level logistic regression model seven times now in pursuit of OpenCL speed gains. With an Nvidia RTX 3090 I've managed to get wall time from ~3 days down to ~8 hours in #rstats with cmdstanr

It looks like even just using OpenCL on the CPU brings a substantial speed-up to this model, and using OpenCL on a GPU can bring an even greater one (the ~3 days to ~8 hours reported above is roughly a 9-fold reduction in wall time).

If we can accomplish this, there are two enormous practical benefits:

  • Iteration: a major part of analysis with Bayesian models is iterative model development and comparison of multiple competing models. If we can make model fitting faster as a whole, we will shorten the overall analysis time AND encourage more iteration, which can lead to more accurate models and more trustworthy results.
  • Deliverables: whether fitting just one model or several (as mentioned above), faster model fitting will let data scientists deliver better results & insights to their stakeholders more quickly and avoid delaying decision making.

Task

  1. Build CmdStan from source with OpenCL support, interfacing with the ROCm driver.
  2. Execute a model on the AMD GPUs installed on stat1005 and stat1008.

Note: The same underlying CmdStan installation can then be used by CmdStanPy, CmdStanR, and CmdStan.jl interfaces for Python, R, and Julia, respectively. The high-level modeling R package BRMS also has support for fitting models using a CmdStanR backend, so it could potentially benefit from OpenCL & GPU support as well.
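
For step 1, CmdStan reads its build settings from a make/local file in the source directory. Below is a minimal sketch that only generates that file's contents. It assumes the ROCm OpenCL runtime shows up as platform 0, device 0, which still needs to be verified on the stat hosts (for example with clinfo). The helper name opencl_make_local is hypothetical.

```python
def opencl_make_local(platform_id: int = 0, device_id: int = 0) -> str:
    """Return make/local contents that enable OpenCL in a CmdStan build.

    The platform/device IDs are assumptions; check the real values on the
    host (for example with `clinfo`) before building.
    """
    if platform_id < 0 or device_id < 0:
        raise ValueError("OpenCL platform and device IDs must be non-negative")
    lines = [
        "STAN_OPENCL=true",
        f"OPENCL_PLATFORM_ID={platform_id}",
        f"OPENCL_DEVICE_ID={device_id}",
    ]
    return "\n".join(lines) + "\n"


if __name__ == "__main__":
    # Print the settings; to use them, save the output as make/local inside
    # the CmdStan source directory and rebuild CmdStan and the model.
    print(opencl_make_local(), end="")
```

Depending on where ROCm installs its OpenCL library, extra linker settings may also be needed; that part is not covered by this sketch.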


This will need to be a joint effort between the Product Analytics and Data Engineering teams.


Additional Context

Stan Models

A model is written in Stan (or specified using a high-level interface like BRMS that turns it into a Stan program), translated into C++, and compiled to a binary executable that takes data as input and outputs MCMC samples of the model's parameters inferred from the data.
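
To make the "data in, samples out" step concrete, here is a rough sketch of the command line for one sampling run of a compiled model. The executable and file names are placeholders, the helper sample_command is hypothetical, and the command is only assembled, not executed.

```python
def sample_command(exe, data_file, output_file, chain_id=1):
    """Build the argument list for one sampling run of a compiled Stan model.

    All paths are placeholders; nothing is executed here.
    """
    return [
        str(exe),
        "sample",                        # run the MCMC sampler
        f"id={chain_id}",                # chain identifier
        "data", f"file={data_file}",     # input data (e.g. JSON)
        "output", f"file={output_file}", # CSV file of posterior draws
    ]


if __name__ == "__main__":
    cmd = sample_command("./model", "data.json", "output_1.csv")
    print(" ".join(cmd))
```

The list form can be handed to subprocess.run as-is, which avoids shell quoting issues with file paths.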

Usually that model binary is executed on the CPU multiple times, independently. These processes are called "chains" and there are usually at least 4 of them (to help with diagnostics). The chains can be run sequentially on 1 core or in parallel across 4 cores. Either way, fitting can take a really long time. The various R, Python, and Julia interfaces also let you do the model fitting through Jupyter or the R/Python/Julia REPLs rather than the command line.
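
As a toy illustration of why several independent chains help with diagnostics, the sketch below runs 4 chains of a simple random-walk Metropolis sampler on a standard normal target and compares them with the classic (non-split) Gelman-Rubin R-hat. This is not Stan's sampler or Stan's exact R-hat; the target, step size, and draw counts are arbitrary choices for illustration.

```python
import math
import random
from statistics import mean, variance


def metropolis_chain(seed, n_warmup=500, n_draws=2000, step=1.0):
    """Random-walk Metropolis draws from a standard normal target."""
    rng = random.Random(seed)
    x = rng.uniform(-5.0, 5.0)  # dispersed starting point
    draws = []
    for i in range(n_warmup + n_draws):
        proposal = x + rng.gauss(0.0, step)
        # Log acceptance ratio for a target density proportional to exp(-x^2 / 2).
        log_ratio = 0.5 * (x * x - proposal * proposal)
        if log_ratio >= 0 or rng.random() < math.exp(log_ratio):
            x = proposal
        if i >= n_warmup:  # keep only post-warmup draws
            draws.append(x)
    return draws


def r_hat(chains):
    """Classic Gelman-Rubin potential scale reduction factor."""
    n = len(chains[0])
    chain_means = [mean(c) for c in chains]
    between = n * variance(chain_means)         # between-chain variance B
    within = mean(variance(c) for c in chains)  # within-chain variance W
    var_plus = (n - 1) / n * within + between / n
    return math.sqrt(var_plus / within)


if __name__ == "__main__":
    # Each chain depends only on its seed, so the 4 runs are independent and
    # could be executed on 4 cores in parallel instead of in this loop.
    chains = [metropolis_chain(seed) for seed in range(1, 5)]
    print(f"R-hat across {len(chains)} chains: {r_hat(chains):.3f}")
```

Values of R-hat close to 1 suggest the chains agree with each other; a single chain gives no such cross-check.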

Event Timeline

Case in point, I'm fitting one of these models for NEWTEA revisited (T270786) and as of yesterday its running time is 338 hours. Here's a screenshot of htop; running time is the 11th column (TIME+). While some of this might be attributed to the model specification or the data, I wouldn't be surprised to see 24–48 hours to completion because that's what it took in previous analysis.

338h R process.png (73×1 px, 62 KB)

ldelench_wmf moved this task from Triage to Current Quarter on the Product-Analytics board.

Additional details: I've built CmdStan from source a bunch of times (including on the stat nodes), just never with OpenCL support, and I don't know whether the installed ROCm driver supports it. (I suspect it does, but would need help verifying.)