Observe vs. factor inside recursion

Isaac Davis

May 3, 2020, 1:36:18 PM
to webppl-dev
Hi folks. I have a question about how the Observe operator works inside a recursive loop.

I'm working on a model with the following generative form:

1) sample a hidden state H from a prior distribution P_H
2) sample a sequence of observable states s_0, ..., s_n, where

- s_0 is drawn from a distribution P(s_0|H) (dependent on H), and
- s_t is drawn from a distribution P(s_t|s_{t-1}, H) (dependent on both the previous observable state and the hidden state).


The inference problem is to infer the value of H given a sequence of observable states s_0, ..., s_n.
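Written out with Bayes' rule, the generative form above gives a posterior over H with one likelihood term per observed state. In log space these terms add, which is why conditioning on each step with its own factor (or observe) statement inside one Infer call is a valid way to condition on the whole sequence:

```latex
\begin{aligned}
P(H \mid s_0, \dots, s_n) &\propto P_H(H)\, P(s_0 \mid H) \prod_{t=1}^{n} P(s_t \mid s_{t-1}, H) \\
\log P(H \mid s_0, \dots, s_n) &= \log P_H(H) + \log P(s_0 \mid H) + \sum_{t=1}^{n} \log P(s_t \mid s_{t-1}, H) + \text{const.}
\end{aligned}
```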

Since the sequences can vary in length, I use the Factor operator inside a recursive loop to condition on the data:


var inferH=function(data, H_prior, transitionDist){
  return Infer({method: 'MCMC', samples: 1000}, function(){
    
    //inner recursive loop
    var factor_data=function(states, H){
      if (states.length==data.length){return true}
      else{
        var current=states[states.length-1];
        var next=sample(transitionDist(current, H));
        factor(next==data[states.length]? 0 : -Infinity)
        return factor_data(states.concat(next), H)
      }
    }
    
    //sample hidden state
    var H=sample(H_prior);
    
    //add factors
    var factors=factor_data([], H);
    
    return H
  })
}


The function "transitionDist(state, H)" returns the distribution P(s_t|s_{t-1}, H) described above (or P(s_0|H) on the first step, when "states" is empty and so the "state" argument is undefined).

My understanding is that the "factor" statements will stack additively, so long as they are within the scope of the same Infer statement. I've run this program on a few test examples and it seems to work correctly. However, some of the data sequences can be quite long, and in those cases the MCMC sampler will take a very long time to initialize (or fail to initialize at all).
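One way to see why long sequences hurt the sample+factor version: a run only has non-zero probability if every freshly sampled "next" equals the corresponding data point, so the chance that a single forward run survives all the factors is the product of the per-step probabilities. Assuming the sampler has to find such a run by forward sampling before the chain can start, that chance shrinks geometrically with sequence length. The plain-JavaScript sketch below (not WebPPL) compares it with the summed log-likelihood that scoring each data point directly would add. The toy model, `transitionProb`, `logLikelihood` and `matchProbability` are all hypothetical.

```javascript
// Plain JavaScript sketch (not WebPPL). The toy model is hypothetical:
// states are 0 or 1, H is the probability of staying in the same state,
// and the first state is uniform regardless of H.

// Hypothetical stand-in for transitionDist: returns P(next | current, H).
// current === undefined plays the role of the first step, P(s_0 | H).
function transitionProb(current, next, H) {
  if (current === undefined) {
    return 0.5;
  }
  return next === current ? H : 1 - H;
}

// Log-likelihood of the whole sequence: the sum of per-step log-probabilities,
// which is what scoring each data point (observe-style) accumulates.
function logLikelihood(data, H) {
  let total = 0;
  for (let t = 0; t < data.length; t++) {
    const current = t === 0 ? undefined : data[t - 1];
    total += Math.log(transitionProb(current, data[t], H));
  }
  return total;
}

// Probability that one forward run (sample next, then factor 0 / -Infinity)
// reproduces every data point: the product of the same per-step probabilities.
function matchProbability(data, H) {
  let p = 1;
  for (let t = 0; t < data.length; t++) {
    const current = t === 0 ? undefined : data[t - 1];
    p *= transitionProb(current, data[t], H);
  }
  return p;
}

for (const n of [10, 100, 1000]) {
  const data = new Array(n).fill(0); // hypothetical data: n zeros in a row
  console.log(n, logLikelihood(data, 0.9).toFixed(1), matchProbability(data, 0.9));
}
```

The log-likelihood falls roughly linearly in n, while the match probability falls geometrically (to roughly 1e-46 for n = 1000 in this toy model, and it would underflow to 0 for much longer sequences). That gap is why scoring the observed values directly tends to be far easier for the sampler than waiting for sampled values to match.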

To circumvent this, I tried replacing most of the "sample + factor" statements with an Observe statement as follows:

var inferH_alt=function(data, H_prior, transitionDist){
  return Infer({method: 'MCMC', samples: 1000}, function(){
    
    //inner recursive loop
    var observe_data=function(states, H){
      if (states.length==data.length){return true}
      else{
        var current=states[states.length-1];
        var next=data[states.length];
        observe(transitionDist(current, H), next)
        return observe_data(states.concat(next), H)
      }
    }
    
    //sample hidden state
    var H=sample(H_prior);
    
    //add factors
    var factors=observe_data([], H);
    
    return H
  })
}

This version only samples the initial hidden state H. Then it Observes each data point being generated by the appropriate transition distribution. I tried running this version on the same set of test examples. It does run, but instead of outputting the correct posterior distribution over H, it assigns all of the probability mass to a single value of H (and the value changes from run to run). That is, the first version (with sample+factor) will correctly output something like

Values: [0, 1, 2], Probs: [.3, .3, .4]

But the second version (with observe) outputs something like

Values: [0], Probs: [1]

each time (and the particular value it outputs is not fixed).


So, I'm wondering whether the Observe operator perhaps does not stack like the factor operator in this sort of recursive loop, i.e., whether the second program only counts the first/last observe statement and discards the rest. And, if so, what else can I do about the failure to initialize on longer data sequences?

I checked the documentation on Factor and Observe (as well as the chapter on "Conditioning" in the online textbook) but could not find an answer. 

Any insights or suggestions would be much appreciated.

Thanks!
-Isaac

null-a

May 4, 2020, 4:02:30 AM
to webppl-dev
Hi. It might be useful to know that `observe` is mostly(*) just a very simple wrapper around `factor`. If you were implementing it yourself (in WebPPL) you might write:

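function myObserve(dist, value) {
  if (value !== undefined) {
    // condition on the given value: add its log-probability to the score
    factor(dist.score(value));
    return value;
  } else {
    // no value given: behave like an ordinary sample
    return sample(dist);
  }
}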



(*) I say "mostly" because, while this is the default implementation, some inference algorithms may do something more efficient; for `MCMC`, though, the above is pretty much how `observe` is implemented.
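To make the stacking concrete, here is a plain-JavaScript sketch (not WebPPL, and not how WebPPL works internally) of the wrapper semantics described above. An explicit `run` object stands in for the execution's accumulated score; `makeRun`, `bernoulli` and `observeData` are hypothetical helpers. Each `observe` in a recursive loop adds its log-probability to the same running total, exactly as successive `factor` calls would.

```javascript
// Plain JavaScript sketch (not WebPPL) of observe as a thin wrapper around
// factor: every call adds to one running log-weight for the execution.

// One model execution: factor statements accumulate into logWeight.
function makeRun() {
  return { logWeight: 0 };
}

// Analogue of factor: add a log-score to the current execution.
function factor(run, score) {
  run.logWeight += score;
}

// Analogue of myObserve when a value is supplied: score it, then return it.
function observe(run, dist, value) {
  factor(run, dist.score(value));
  return value;
}

// Hypothetical distribution over {0, 1} with P(1) = p; score returns a log-probability.
function bernoulli(p) {
  return { score: (v) => Math.log(v === 1 ? p : 1 - p) };
}

// Recursion shaped like observe_data: one observe per data point.
function observeData(run, data, p, t = 0) {
  if (t === data.length) {
    return true;
  }
  observe(run, bernoulli(p), data[t]);
  return observeData(run, data, p, t + 1);
}

const run = makeRun();
observeData(run, [1, 1, 0], 0.7);
// All three observations contribute: log(0.7) + log(0.7) + log(0.3).
console.log(run.logWeight);
```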

I don't immediately see a problem with the code you shared, but perhaps knowing that `observe` is just a wrapper around `factor` will help you. If not, please make a minimal working reproduction of the problem that I can run, and I'll happily take a look with you.

Cheers,
Paul

Isaac Davis

May 4, 2020, 11:26:09 AM
to webppl-dev
That is good to know about Observe. I will continue tinkering and see if I can fix the problem. If not I'll code up a simple version of the issue and post it here.

Thanks again!
-Isaac