Using PyMC
PyMC is a powerful Python library for probabilistic programming and Bayesian analysis.
Here, we show that PyMC can be used to perform the same likelihood sampling that we previously wrote our own algorithm for.
Below, we read in the data and build the model.
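As a minimal sketch of this step, the snippet below sets up the data arrays and a model function with NumPy. The concentrations are the values shown in the observed_data output later in this section. The measurement times, the uncertainties, and the first-order rate law itself are assumptions made for illustration (the names t, At_err and first_order are hypothetical), so substitute your own data file and model.

```python
import numpy as np

# Concentrations, copied from the observed_data output in this section.
At = np.array([6.23, 3.76, 2.6, 1.85, 1.27])

# ASSUMPTION: placeholder measurement times and uncertainties.
# These are NOT from the original data file; replace them with real values.
t = np.array([2.0, 6.0, 10.0, 14.0, 18.0])
At_err = np.full_like(At, 0.3)


def first_order(t, k, A0):
    """ASSUMED model: integrated first-order rate law, [A]_t = [A]_0 exp(-k t)."""
    return A0 * np.exp(-k * t)


# Evaluate the model at trial parameter values to check the shapes line up.
print(first_order(t, k=0.1, A0=7.5))
```

The model takes the rate constant k and the initial concentration A0 as its two free parameters, matching the two parameters that are sampled below.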
The next step is to construct the PyMC sampler.
The format that PyMC expects can be a bit unfamiliar.
First, we create objects for the two parameters; these are bounded so that \(0 \leq k < 1\) and \(0 \leq [A]_0 < 10\).
Strictly, these are prior probabilities, which we will look at next, but using uniform distributions means that this is mathematically equivalent to likelihood sampling.
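To see why the equivalence holds, it may help to write the relationship out. This is a sketch using Bayes' theorem, with \(D\) (an assumed symbol, not used elsewhere in this section) standing for the data and the two priors taken to be independent:

\[
p(k, [A]_0 \mid D) \propto p(D \mid k, [A]_0)\, p(k)\, p([A]_0)
\]

With uniform priors, \(p(k)\) and \(p([A]_0)\) are constants inside the bounds and zero outside, so for \(0 \leq k < 1\) and \(0 \leq [A]_0 < 10\):

\[
p(k, [A]_0 \mid D) \propto p(D \mid k, [A]_0)
\]

Sampling this distribution is therefore the same as sampling the likelihood restricted to the bounded region.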
Next, we create a normally distributed likelihood function to compare the data and the model.
Finally, we sample for 1000 steps, with 10 chains.
The tune parameter is the number of steps used for tuning the Markov chain step sizes.
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (10 chains in 2 jobs)
Sampling 10 chains for 1_000 tune and 1_000 draw iterations (10_000 + 10_000 draws total) took 5 seconds.
Unlike the code that we created previously, PyMC defaults to using the NUTS sampler, which stands for No-U-Turn Sampler [7].
This sampler enables the step size tuning that we have taken advantage of.
This results in an object assigned to the variable trace.
posterior
<xarray.Dataset> Size: 168kB
Dimensions: (chain: 10, draw: 1000)
Coordinates:
* chain (chain) int64 80B 0 1 2 3 4 5 6 7 8 9
* draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
k (chain, draw) float64 80kB 0.09638 0.096 0.09886 ... 0.1226 0.1099
A0 (chain, draw) float64 80kB 7.426 7.504 7.467 ... 8.139 8.119 7.67
Attributes:
created_at: 2025-05-29T16:31:12.657606+00:00
arviz_version: 0.21.0
inference_library: pymc
inference_library_version: 5.20.0
sampling_time: 5.1065380573272705
tuning_steps: 1000
sample_stats
<xarray.Dataset> Size: 1MB
Dimensions: (chain: 10, draw: 1000)
Coordinates:
* chain (chain) int64 80B 0 1 2 3 4 5 6 7 8 9
* draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables: (12/17)
largest_eigval (chain, draw) float64 80kB nan nan nan ... nan nan
acceptance_rate (chain, draw) float64 80kB 1.0 0.8076 ... 0.8979
n_steps (chain, draw) float64 80kB 3.0 1.0 1.0 ... 3.0 5.0
index_in_trajectory (chain, draw) int64 80kB 2 1 -1 1 -3 ... -3 -1 -1 2 2
max_energy_error (chain, draw) float64 80kB -0.09478 0.2136 ... 0.2029
energy (chain, draw) float64 80kB 4.285 4.702 ... 6.788 5.41
... ...
energy_error (chain, draw) float64 80kB -0.07081 ... -0.1449
tree_depth (chain, draw) int64 80kB 2 1 1 2 3 3 ... 2 3 3 2 2 3
process_time_diff (chain, draw) float64 80kB 0.000366 ... 0.0003037
step_size_bar (chain, draw) float64 80kB 0.6361 0.6361 ... 0.6515
reached_max_treedepth (chain, draw) bool 10kB False False ... False False
lp (chain, draw) float64 80kB -3.986 -4.368 ... -3.37
Attributes:
created_at: 2025-05-29T16:31:12.679232+00:00
arviz_version: 0.21.0
inference_library: pymc
inference_library_version: 5.20.0
sampling_time: 5.1065380573272705
tuning_steps: 1000
observed_data
<xarray.Dataset> Size: 80B
Dimensions: (At_dim_0: 5)
Coordinates:
* At_dim_0 (At_dim_0) int64 40B 0 1 2 3 4
Data variables:
At (At_dim_0) float64 40B 6.23 3.76 2.6 1.85 1.27
Attributes:
created_at: 2025-05-29T16:31:12.683717+00:00
arviz_version: 0.21.0
inference_library: pymc
inference_library_version: 5.20.0
This contains the chain information, among other things.
Instead of probing into the trace object, we can take advantage of functionality from the arviz library to produce some informative plots.
Above, we can see the trace of each of the different chains.
The chains appear to have converged to the same distribution.
We can get the flat chains with the following function.
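As a minimal sketch of what flattening involves, the snippet below uses NumPy on synthetic draws laid out with the same (chain, draw) dimensions as the posterior group above. The array k_chains and the helper flatten_chains are hypothetical stand-ins; with a real trace, one option is the arviz extract function, which stacks the chain and draw dimensions into a single sample dimension.

```python
import numpy as np

rng = np.random.default_rng(0)

# ASSUMPTION: synthetic stand-in for the sampled values of k,
# with shape (chain, draw) = (10, 1000). Not real sampler output.
k_chains = rng.normal(loc=0.1, scale=0.01, size=(10, 1000))


def flatten_chains(samples):
    """Merge the chain and draw axes into a single axis of samples."""
    samples = np.asarray(samples)
    return samples.reshape(-1, *samples.shape[2:])


flat_k = flatten_chains(k_chains)
print(flat_k.shape)  # (10000,)
```

Flattening discards the information about which chain a draw came from, so it is best done after checking that the chains agree with one another.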
It is clear that, using PyMC, we have much better sampling of the distributions.
This makes summary statistics, such as the mean and standard deviation, much more reliable.
     mean    sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
k   0.106  0.01   0.088    0.125      0.000    0.000    2886.0    3675.0    1.0
A0  7.568  0.44   6.748    8.398      0.009    0.005    2645.0    3424.0    1.0
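One way to read these summary values is to check how the Monte Carlo standard error relates to the standard deviation and the effective sample size. The sketch below uses the common approximation that the standard error of the mean is roughly the standard deviation divided by the square root of the effective sample size. Using ess_bulk for this is an assumption, since the library may use a slightly different effective sample size estimate internally, and the helper approx_mcse_mean is hypothetical.

```python
import math

# Values copied from the summary for k and A0.
summary = {
    "k": {"sd": 0.01, "ess_bulk": 2886.0, "mcse_mean": 0.000},
    "A0": {"sd": 0.44, "ess_bulk": 2645.0, "mcse_mean": 0.009},
}


def approx_mcse_mean(sd, ess):
    """Approximate Monte Carlo standard error of the mean: sd / sqrt(ess)."""
    return sd / math.sqrt(ess)


for name, row in summary.items():
    est = approx_mcse_mean(row["sd"], row["ess_bulk"])
    print(f"{name}: approx mcse_mean = {est:.3f} (summary: {row['mcse_mean']:.3f})")
```

Both estimates agree with the reported mcse_mean values to three decimal places, which suggests that the uncertainty introduced by the sampling itself is small compared with the width of the distributions.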