In [1]:
Copied!
# Imports and global seed.
# NOTE(review): this cell's code appears twice (likely a copy-button artifact of the
# rendered page). Nothing draws random numbers between the two seeds, so the second
# np.random.seed(1234) leaves NumPy's global RNG in the same state.
# NOTE(review): Image and display are not used in the visible cells; presumably used elsewhere.
from IPython.display import Image, display
import numpy as np
import reddemcee
# Seed NumPy's global RNG so the random initial positions (p0, drawn with
# np.random.randn further down) are reproducible.
np.random.seed(1234)
from IPython.display import Image, display
import numpy as np
import reddemcee
np.random.seed(1234)
Parallelization¶
You can parallelize the sampler by passing it a pool. The standard-library `multiprocessing` module is recommended, but `multiprocess` and `schwimmbad` have been tested as well.
In [2]:
Copied!
import multiprocessing as mp
import multiprocessing as mp
You can check how many logical CPUs are available with:
In [3]:
Copied!
# Logical CPUs reported by multiprocessing; useful for choosing the pool size below.
n_available_cpus = mp.cpu_count()
n_available_cpus
Out[3]:
24
We build a likelihood that keeps the CPU busy for a short, random amount of time (5–8 ms) on every call:
In [4]:
Copied!
import time


def loglike(theta):
    """Toy Gaussian log-likelihood that deliberately burns CPU on every call.

    Each call busy-waits (rather than sleeping) for a random 5-8 ms, so the
    cost of an evaluation keeps a CPU core occupied. That makes the benefit of
    spreading evaluations over worker processes visible.

    Parameters
    ----------
    theta : array_like
        Parameter vector (an ndarray, list or tuple of numbers).

    Returns
    -------
    float
        ``-0.5 * sum(theta**2)``, an unnormalised standard-normal log-density.
    """
    # perf_counter is monotonic, so a system-clock adjustment cannot shorten
    # or stretch the wait the way time.time() could.
    deadline = time.perf_counter() + np.random.uniform(0.005, 0.008)
    while time.perf_counter() < deadline:
        pass
    # asarray lets plain sequences be passed as well as ndarrays.
    return -0.5 * np.sum(np.asarray(theta) ** 2)


def logprior(theta):
    """Flat (improper) log-prior: returns 0 for every ``theta``."""
    return 0.
Serial¶
Each call to this likelihood busy-waits (it does not sleep) for a random 5–8 ms. We start by timing the sampler running serially, without a pool:
In [5]:
Copied!
# Dimensionality of the parameter space.
ndim_ = 2
# setup = [ntemps, nwalkers, nsweeps, nsteps]. nsweeps and nsteps are passed to
# run_mcmc below (presumably number of sweeps and steps per sweep; see reddemcee docs).
setup = [2, 20, 40, 2]
ntemps, nwalkers, nsweeps, nsteps = setup
# One (nwalkers, ndim_) array of standard-normal starting points per temperature.
# Drawn exactly once: a second draw would advance the global RNG and change p0.
p0 = list(np.random.randn(ntemps, nwalkers, ndim_))
In [6]:
Copied!
# No pool is passed, so every likelihood call runs in this process.
sampler_s = reddemcee.PTSampler(nwalkers, ndim_,
                                 loglike, logprior,
                                 ntemps=ntemps)
# perf_counter is a monotonic, high-resolution clock intended for timing intervals.
start = time.perf_counter()
samp_s = sampler_s.run_mcmc(p0, nsweeps, nsteps)
time_serial = time.perf_counter() - start
print(f'Serial took {time_serial:.1f} seconds')
Serial took 21.1 seconds
Parallel¶
In [7]:
Copied!
# Same sampler configuration as the serial run, but the sampler is given a
# process pool to evaluate likelihoods with.
# NOTE(review): with the "spawn" start method (default on Windows and macOS),
# functions defined in a notebook's __main__ may not be importable by the
# worker processes; if the pool fails, move loglike/logprior into a .py module.
n_processes = 10
with mp.Pool(n_processes) as mypool:
    sampler_p = reddemcee.PTSampler(nwalkers, ndim_,
                                    loglike, logprior,
                                    ntemps=ntemps,
                                    pool=mypool)
    # perf_counter is a monotonic, high-resolution clock intended for timing intervals.
    start = time.perf_counter()
    samp_p = sampler_p.run_mcmc(p0, nsweeps, nsteps)
    time_parallel = time.perf_counter() - start
print(f'Parallel took {time_parallel:.1f} seconds')
Serial took 2.5 seconds
Almost a tenth of the time: 21.1 s serially versus 2.5 s with a pool of 10 processes, roughly an 8× speedup.