Hi, Thomas,
Thanks for your feedback.
I tried to convert the traces from HDDM to arviz InferenceData, but so far I can only use a few of arviz's plotting functions, not the stats functions. Below is what I've done:
### First, fit the model to the Cavanagh 2011 data with the `HDDM` module:
    # define a function for parallel processing
    def run_m(id):
        print('running model (depends on stim) %i' % id)
        import hddm  # imported here so that each parallel engine loads it

        exp_name = 'cavanagh'
        model_tag = 'm'
        # use absolute paths in Docker
        dbname = '/home/jovyan/hddm/temp/df_' + exp_name + '_' + model_tag + '_chain_%i.db' % id
        mname = '/home/jovyan/hddm/temp/df_' + exp_name + '_' + model_tag + '_chain_%i' % id
        fname = '/opt/conda/lib/python3.7/site-packages/hddm/examples/cavanagh_theta_nn.csv'
        data = hddm.load_csv(fname)

        m = hddm.HDDM(data, depends_on={'v': 'stim'})
        m.find_starting_values()
        # it's necessary to save the traces to a database so the model can be saved
        m.sample(5000, burn=1000, dbname=dbname, db='pickle')
        m.save(mname)
        return m
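
    # run four chains in parallel
    import time
    from ipyparallel import Client

    v = Client()[:]  # direct view on all running engines
    start_time = time.time()  # the start time of the processing
    jobs = v.map(run_m, range(4))  # four chains, one per CPU
    # wait_watching_stdout is not part of ipyparallel; it is assumed to be a
    # helper defined earlier that prints engine output while waiting
    wait_watching_stdout(jobs)
    m_stim_list = jobs.get()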
### Then, convert the trace to InferenceData:
    import arviz as az
    import numpy as np
    import pandas as pd
    import xarray as xr

    # collect the traces of each chain, labelled by chain and draw index
    df_stim_traces = []
    for i in range(4):
        m = m_stim_list[i]
        df_trace = m.get_traces()
        df_trace['chain'] = i
        df_trace['draw'] = np.arange(len(df_trace), dtype=int)
        print('chain', i, df_trace.shape)
        df_stim_traces.append(df_trace)

    # stack the chains and index by (chain, draw), as arviz expects
    df_stim_traces = pd.concat(df_stim_traces)
    df_stim_traces = df_stim_traces.set_index(["chain", "draw"])
    xdata_stim = xr.Dataset.from_dataframe(df_stim_traces)
    df_stim = az.InferenceData(posterior=xdata_stim)
    df_stim  # check the InferenceData

    # test the `az.plot_trace()` function
    az.plot_trace(df_stim, var_names=["^a"], filter_vars='regex', rug=True)

I then tested the `az.loo()` function with `az.loo(df_stim)`. The key error message is "TypeError: log likelihood not found in inference data object", and the same error message occurs when I try `az.waic(df_stim)`. I have no clue how to solve this problem; it'd be great if you could give me some hints.
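
For reference, here is a toy sketch of the structure these two functions seem to expect: a `log_likelihood` group next to `posterior`, holding one value per chain, draw, and observation. Everything in it is an assumption made for illustration: the data are simulated, the parameter `mu` and the Normal likelihood are stand-ins rather than anything produced by HDDM, and it relies on the keyword interface of `az.from_dict`. How to obtain the pointwise log-likelihood from an HDDM model is the open question.

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for real data and traces (NOT HDDM output):
# 4 chains x 500 draws of one parameter `mu`, and 30 observations.
n_chains, n_draws, n_obs = 4, 500, 30
y_obs = rng.normal(loc=1.0, scale=1.0, size=n_obs)
mu = rng.normal(loc=y_obs.mean(), scale=1.0 / np.sqrt(n_obs),
                size=(n_chains, n_draws))

# Pointwise log-likelihood with shape (chain, draw, observation).
# A Normal(mu, 1) density is used here purely as a placeholder likelihood.
resid = y_obs[None, None, :] - mu[:, :, None]
log_lik = -0.5 * np.log(2 * np.pi) - 0.5 * resid**2

idata = az.from_dict(
    posterior={"mu": mu},
    log_likelihood={"y": log_lik},
    observed_data={"y": y_obs},
)

# With the log_likelihood group present, both calls have what they look for.
loo_result = az.loo(idata)
waic_result = az.waic(idata)
print(loo_result)
print(waic_result)
```

If this is right, the `InferenceData` built from the HDDM traces would need an analogous `log_likelihood` group, with the same `chain` and `draw` coordinates as the posterior.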
I also got an error when I tried to use `az.plot_ppc()`: "`data` argument must have the group "posterior_predictive" for ppcplot". I think this error can be solved by first running the posterior predictive check (using `hddm.utils.post_pred_gen()`) and adding the simulated data to the InferenceData, but I haven't tried that yet.
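
A rough sketch of the target structure for the PPC plot, under the same caveats: the arrays below are simulated placeholders (gamma-distributed "reaction times" and a dummy parameter `theta`), not output of `hddm.utils.post_pred_gen()`, and the variable name `rt` is only an illustrative choice. The point is that the `posterior_predictive` and `observed_data` groups share a variable name, and that the predictive draws carry `chain` and `draw` dimensions in front of the observation dimension.

```python
import arviz as az
import numpy as np

rng = np.random.default_rng(1)

# Simulated placeholders (NOT HDDM output): 50 observed values and
# 4 chains x 200 draws of simulated data sets of the same length.
n_chains, n_draws, n_obs = 4, 200, 50
rt_obs = rng.gamma(shape=2.0, scale=0.3, size=n_obs)
rt_sim = rng.gamma(shape=2.0, scale=0.3, size=(n_chains, n_draws, n_obs))
theta = rng.normal(size=(n_chains, n_draws))  # dummy posterior variable

idata_ppc = az.from_dict(
    posterior={"theta": theta},
    posterior_predictive={"rt": rt_sim},  # shape (chain, draw, observation)
    observed_data={"rt": rt_obs},         # same variable name as above
)

# Overlay a subset of the simulated data sets on the observed data.
az.plot_ppc(idata_ppc, num_pp_samples=50)
```

The remaining work would be to reshape the simulated data from HDDM into this `(chain, draw, observation)` layout.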
Best,
Chuan-Peng