Help with Bayesian Parameter Inference from solved ODE using TFP and Jax


max.ore...@googlemail.com

Jun 19, 2023, 11:59:44 AM
to TensorFlow Probability
Hi there,

So I have been trying out HMC with TFP on JAX. For this purpose, I first tried a sample model: a simple harmonic oscillator in its solved ODE form, x = A*cos(w*t + b), where the parameters to solve for are A, w and b, given some times t and a set of observations yobs.

The log-likelihood is then written (up to an additive constant) as -sum((y - yobs)**2) / (2*sigma**2), with y the model prediction and yobs the observations.

I am able to solve for the parameters using the following code.

```python
import time

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

import sample_model as sm  # hypothetical module name for the SampleModel class


def run_hmc():
    numberOfChains = 1
    numberOfSamplesBurnIn = int(1E4)
    numberOfSamples = int(1E4)
    myModel = sm.SampleModel(jnp.linspace(0, 2 * jnp.pi, 100), [2, 3, 0.5])

    # Initial state: [A, omega, phase]
    p0 = jnp.array([1, 1, 1], dtype=jnp.float32)

    t = time.time()
    hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=myModel.logLikelihood,
        step_size=1.,
        num_leapfrog_steps=3)
    # This adapts the inner kernel's step_size.
    adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(
        inner_kernel=hmc_kernel,
        num_adaptation_steps=int(numberOfSamplesBurnIn * 0.8))

    samples, is_accepted = tfp.mcmc.sample_chain(
        num_results=numberOfSamples,
        num_burnin_steps=numberOfSamplesBurnIn,
        current_state=p0,
        kernel=adaptive_hmc,
        # Record only the inner HMC kernel's acceptance flags.
        trace_fn=lambda _, pkr: pkr.inner_results.is_accepted,
        seed=jax.random.PRNGKey(0))

    print("Elapsed: " + str(time.time() - t))
    return samples, is_accepted
```

Now HMC recovers the input parameters, with the exception of b, whose estimate depends on where the sampler is started. I would therefore like to place a prior on it. How is this done? This is something I need for the actual model. The second issue I am facing is far slower execution.
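One likely reason, not confirmed in the thread, is that b is not identifiable on its own. cos is 2π-periodic, and A*cos(w*t + b) equals (-A)*cos(w*t + b + π), so the likelihood has several equally good modes in (A, b). Each chain then tends to stay near whichever mode is closest to its starting point. A quick numerical check with the question's values, where `trajectory` is an illustrative helper:

```python
import jax.numpy as jnp


def trajectory(t, A, omega, phase):
    # Solved oscillator from the question: x(t) = A * cos(omega * t + phase)
    return A * jnp.cos(omega * t + phase)


t = jnp.linspace(0.0, 2 * jnp.pi, 100)
base = trajectory(t, 2.0, 3.0, 0.5)
# Shifting the phase by a full period leaves the curve unchanged ...
shifted = trajectory(t, 2.0, 3.0, 0.5 + 2 * jnp.pi)
# ... and so does flipping the sign of A together with a half-period shift.
flipped = trajectory(t, -2.0, 3.0, 0.5 + jnp.pi)

print(jnp.max(jnp.abs(base - shifted)), jnp.max(jnp.abs(base - flipped)))
```

Priors that confine b to a single period and keep A and w positive are one way to remove these duplicate modes.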

When moving from the toy model to the actual model I want to solve, I get a slow-operation error. This model uses a log-likelihood where the ith element depends on all previous elements, i.e., the prediction is computed in the following for loop, and the least squares and log-likelihood are then defined in the same way as above.


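```python
# Sequential prediction: step i uses Ves from step i-1.
# Ved_fn / Ves_fn are placeholder names for the model's update functions;
# the original snippet reused Ved / Ves for both the arrays and the functions.
for i in range(1, jnp.size(t)):
    Ved = Ved.at[i].set(Ved_fn(t[i], par, Ves[i-1]))
    Ves = Ves.at[i].set(Ves_fn(Ved[i], pa[i], par))
    pred = pred.at[i-1].set(Ved[i] - Ves[i])
```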

I know preallocation is not really recommended, but I am unsure whether it is actually the cause of the slow operation.

Thank you in advance for your help. I may simply not have understood the tutorials correctly, and I am unsure whether and how I could make my life easier using distributions.

Yours

Max

Colin Carroll

Jun 20, 2023, 9:32:40 AM
to max.ore...@googlemail.com, TensorFlow Probability
Hi Max - 
This will be hard to answer without knowing what `sm.SampleModel` is, or seeing more code around the likelihood you're defining. 

To put a prior on `b`, it should be as simple as defining a new model:

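```python
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions


def my_model_with_b(b, params):
  # Prior on b plus the existing log-likelihood (myModel from the question).
  b_log_prob = tfd.Normal(0., 1.).log_prob(b)
  log_likelihood = myModel.logLikelihood(params)
  return b_log_prob + log_likelihood
```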

then `current_state` would need to be something like `(0., p0)`.

As for making the model more efficient: you say "the ith element depends on all previous elements", but the loop makes it look like it depends only on a single previous element. If so, writing it as a `jax.lax.scan` is likely to be more efficient than the Python loop.
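Each step of the loop only needs `Ves[i-1]` from the previous step, so the recursion maps onto `jax.lax.scan`. The carry is the previous `Ves` value and the per-step output is the prediction. Under `jit`, a Python loop of `.at[i].set` calls is unrolled into one copy of the body per time step, which is a plausible source of slow compilation. `scan` traces the body only once. Below is a minimal sketch where `ved_fn` and `ves_fn` are hypothetical stand-ins for the real update functions, with made-up bodies. The equivalent Python loop is included for comparison.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-ins for the real update functions (bodies are made up).
def ved_fn(t_i, par, ves_prev):
    return par[0] * jnp.sin(t_i) + par[1] * ves_prev


def ves_fn(ved_i, pa_i, par):
    return par[2] * ved_i + pa_i


def predict_loop(t, pa, par, ves0):
    """Reference version with a Python loop (unrolled when traced by jit)."""
    ves_prev = ves0
    out = []
    for i in range(1, t.shape[0]):
        ved_i = ved_fn(t[i], par, ves_prev)
        ves_prev = ves_fn(ved_i, pa[i], par)
        out.append(ved_i - ves_prev)
    return jnp.stack(out)


@jax.jit
def predict_scan(t, pa, par, ves0):
    """Same recursion with lax.scan: the step body is traced once."""
    def step(ves_prev, inputs):
        t_i, pa_i = inputs
        ved_i = ved_fn(t_i, par, ves_prev)
        ves_i = ves_fn(ved_i, pa_i, par)
        # carry = state needed by the next step; output = this step's prediction
        return ves_i, ved_i - ves_i

    ves0 = jnp.asarray(ves0, dtype=t.dtype)
    _, pred = jax.lax.scan(step, ves0, (t[1:], pa[1:]))
    return pred


t = jnp.linspace(0.0, 10.0, 200)
pa = 0.1 * jnp.ones_like(t)
par = jnp.array([1.0, 0.5, 0.3])
print(jnp.max(jnp.abs(predict_scan(t, pa, par, 0.0) - predict_loop(t, pa, par, 0.0))))
```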

 


--
You received this message because you are subscribed to the Google Groups "TensorFlow Probability" group.
To unsubscribe from this group and stop receiving emails from it, send an email to tfprobabilit...@tensorflow.org.
To view this discussion on the web visit https://groups.google.com/a/tensorflow.org/d/msgid/tfprobability/a394cbe6-2694-494a-832a-ae8bbe72686dn%40tensorflow.org.

max.ore...@googlemail.com

Jun 20, 2023, 10:01:24 AM
to TensorFlow Probability, colca...@google.com, TensorFlow Probability, max.ore...@googlemail.com
Hi,

Thank you for your answer. The model code is quite simple: it is just a harmonic oscillator in a class, as shown below. So b is already a parameter of the model (the phase in generateTrajectory). I would like to place priors on the existing parameters of the log-likelihood, so I am not sure why I would define a new model, given that the state is included in the pars passed to the log-likelihood. Maybe I have misunderstood something dramatically. Basically, I am looking for the MAP estimate of some pars of the log-likelihood function, based on some observations (self.observed). What I would understand is if I could pass, for example, a multivariate distribution into pars to sample from it. Does this make sense?

Thanks in advance!

Max

```python
import jax.numpy as jnp


class SampleModel:

    def __init__(self, timeIn, parsIn):
        self.timeArray = timeIn
        self.parametersObservered = parsIn
        # Synthetic, noise-free observations generated from the "true" parameters.
        self.observed = self.generateTrajectory(
            self.timeArray,
            self.parametersObservered[0],
            self.parametersObservered[1],
            self.parametersObservered[2])

    # Solved simple harmonic oscillator: x(t) = A * cos(omega * t + phase)
    def generateTrajectory(self, timeIn, A, omega, phase):
        # Vectorized; equivalent to filling the array element by element with .at[i].set
        return A * jnp.cos(omega * timeIn + phase)

    def predict(self, pars):
        return self.generateTrajectory(self.timeArray, pars[0], pars[1], pars[2])

    def logLikelihood(self, pars):
        leastSquares = (self.predict(pars) - self.observed) ** 2
        # Gaussian log-likelihood up to a constant, with noise variance sigma**2 = 5
        logLikelihood = -0.5 * jnp.sum(leastSquares) / 5
        return logLikelihood
```

max.ore...@googlemail.com

Jun 24, 2023, 8:18:16 AM
to TensorFlow Probability, max.ore...@googlemail.com, colca...@google.com, TensorFlow Probability
Hi Colin,

Thank you very much for your answer. I have played around with it, but I am simply not able to run the basic HMC sampler as detailed in the example. I always seem to get a KeyError on omega.

I have tried different distributions, such as Uniform, always with the same result. My target function in the above case is still the log probability of the model, and computing the log probability for a given set of parameters does work.
In addition to the MCMC sampler not working, plugging the original pars that generated the trajectory into the function gives a positive infinite likelihood rather than 1. I am not entirely sure where I am going wrong. I created a colab, which should be easier than pasting the code.


Thank you very much for your help so far! I hope to summarise this similarly to the black-box NumPy example in PyMC, as in my experience TFP on JAX seems to be a very performant tool.

Yours 

Max

Colin Carroll

Jun 26, 2023, 11:21:42 AM
to max.ore...@googlemail.com, TensorFlow Probability
Hey -- thanks for sharing the colab. The only change you need to make for the code to run is this line:

```
    p0=[1.,1.,1.]
```

The reason is that the JointDistribution now acts on tuples or lists, but not on an array. You might even want to use something like `myModel.conditioned_model.sample_unpinned(sample_shape=(numberOfChains,), seed=seed)`, which returns a namedtuple.


max.ore...@googlemail.com

Jun 29, 2023, 3:03:40 AM
to TensorFlow Probability, colca...@google.com, TensorFlow Probability, max.ore...@googlemail.com
Hey Colin! 

Thank you for your help. I have played around with it a little. It's really powerful stuff once you get the ideas and intended use!

I moved the parameters to log space (sampling the log-parameters and applying exp in the model). I have seen that bijectors do something similar. Is it worth using them?

Thanks for your help so far!

Yours 

Max

Colin Carroll

Jun 29, 2023, 10:18:17 AM
to max.ore...@googlemail.com, TensorFlow Probability
In this particular case, you might as well directly use a `tfd.LogNormal` for each of your parameters.

One use for bijectors is with `TransformedTransitionKernel`, which you would wrap around your MCMC kernel. This can help gradient-based MCMC by letting the sampler work on an unconstrained space (all of the reals) and by reducing regions of high curvature. If you don't want to get too deep into the inference, `tfp.experimental.mcmc.windowed_adaptive_nuts` works well with JointDistributions and can plug into the larger ecosystem. For example:

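```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

# SampleModel is the colab's version (not shown), which exposes a JointDistribution as .model
myModel = SampleModel(jnp.linspace(0, 2 * jnp.pi, 100), jnp.asarray([jnp.log(2), jnp.log(3), jnp.log(0.5)]))
draws, trace = tfp.experimental.mcmc.windowed_adaptive_nuts(
    1_000,
    myModel.model,
    n_chains=1,
    trajectory=myModel.observed,  # pins the observed variable
    seed=jax.random.PRNGKey(1))
```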

This applies bijectors and uses a sensible step-size tuning scheme. You could then use ArviZ to plot and analyze the fit:

```python
import arviz as az
import numpy as np

idata = az.from_dict({k: np.swapaxes(v, 0, 1) for k, v in draws._asdict().items()},
                     sample_stats={k: np.swapaxes(np.array(v)[..., None], 0, 1).squeeze() for k, v in trace.items() if k != 'variance_scaling'})

az.plot_trace(idata)
```

Finally, you could use the draws to plot a posterior predictive distribution, via

```python
import matplotlib.pyplot as plt

posterior_trajectory = myModel.model.sample(seed=jax.random.PRNGKey(0), value=draws).trajectory

fig, ax = plt.subplots()
ax.plot(myModel.observed)
ax.plot(posterior_trajectory.squeeze().T, 'k.', alpha=0.01);
```

max.ore...@googlemail.com

Jul 4, 2023, 11:35:26 AM
to TensorFlow Probability, colca...@google.com, TensorFlow Probability, max.ore...@googlemail.com
Hi Colin,

Thank you very much for your time and explanation. They have helped me tremendously! This type of "black box" log likelihood fitting is awesome! 

I will try to adapt this to my work!

Yours

Max