Inefficient Gamma survival models?


Gianluca Baio
Nov 20, 2016, 5:05:19 AM
to Stan users mailing list

Dear list,

I'm working on parametric survival models. Technically, my aim is health-economic evaluation and cost-effectiveness analysis, but this only matters in the sense that there is a set of parametric models that are "recommended". Long story short, I've coded up the models I need. The main idea is to package them up and pre-compile them, so that modellers can use them in a more or less automated way, with computation times short enough that they are not driven back to the default MLE-based estimates every time.

In general, I've coded up the likelihoods using standard properties of survival models, e.g.
log_lik = d .* log_h + log_S
where d is an event indicator (1 if the i-th observed time corresponds to a recorded "event", 0 if it is censored), log_h is the log hazard function and log_S is the log survival function. This works because the density factorises as f(t) = h(t) S(t). An observed event therefore contributes log f = log h + log S, while a censored time contributes only log S. Because these are parametric models, I usually have closed forms for both log_h and log_S. The models are therefore fairly straightforward and run quickly with no convergence issues (as expected, Stan does much better than Gibbs samplers for these models).
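The following is an illustration only, not code from the thread. This minimal Python sketch uses a Weibull model with shape k and scale lam, an assumed parameterisation chosen because both log h and log S have closed forms. It checks numerically that the hazard form d * log_h + log_S gives the same total log-likelihood as the density form, in which events contribute log f and censored times contribute log S. All function names are hypothetical.

```python
import math

def weibull_log_h(t, k, lam):
    """Log hazard of a Weibull with shape k and scale lam: h(t) = (k/lam) (t/lam)^(k-1)."""
    return math.log(k / lam) + (k - 1) * math.log(t / lam)

def weibull_log_S(t, k, lam):
    """Log survival of the same Weibull: S(t) = exp(-(t/lam)^k)."""
    return -((t / lam) ** k)

def weibull_log_f(t, k, lam):
    """Log density, written out directly from the Weibull pdf."""
    return math.log(k / lam) + (k - 1) * math.log(t / lam) - (t / lam) ** k

def log_lik_hazard_form(times, events, k, lam):
    # Form used in the post: sum over i of d_i * log h(t_i) + log S(t_i)
    return sum(d * weibull_log_h(t, k, lam) + weibull_log_S(t, k, lam)
               for t, d in zip(times, events))

def log_lik_density_form(times, events, k, lam):
    # Equivalent form: log f(t_i) for events, log S(t_i) for censored times
    return sum(weibull_log_f(t, k, lam) if d == 1 else weibull_log_S(t, k, lam)
               for t, d in zip(times, events))

times = [0.5, 1.2, 2.0, 3.7]
events = [1, 0, 1, 0]   # 1 = event observed, 0 = censored
ll_hazard = log_lik_hazard_form(times, events, k=1.5, lam=2.0)
ll_density = log_lik_density_form(times, events, k=1.5, lam=2.0)
print(ll_hazard, ll_density)   # equal up to floating-point rounding
```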

One weird behaviour that I have noticed, however, is with the Gamma, Generalised Gamma and Generalised F distributions. For example, if I code up the Gamma model as
functions {
  // Defines the log survival
  vector log_S(vector t, real shape, vector rate) {
    vector[num_elements(t)] log_S;
    for (i in 1:num_elements(t)) {
      log_S[i] = gamma_lccdf(t[i] | shape, rate[i]);
    }
    return log_S;
  }

  // Defines the log hazard
  vector log_h(vector t, real shape, vector rate) {
    vector[num_elements(t)] log_h;
    vector[num_elements(t)] ls;
    ls = log_S(t, shape, rate);
    for (i in 1:num_elements(t)) {
      log_h[i] = gamma_lpdf(t[i] | shape, rate[i]) - ls[i];
    }
    return log_h;
  }

  // Defines the sampling distribution
  real surv_gamma_lpdf(vector t, vector d, real shape, vector rate) {
    vector[num_elements(t)] log_lik;
    real prob;
    log_lik = d .* log_h(t, shape, rate) + log_S(t, shape, rate);
    prob = sum(log_lik);
    return prob;
  }
}

data {
  int n;                            // number of observations
  vector<lower=0>[n] t;             // observed times
  vector<lower=0,upper=1>[n] d;     // censoring indicator (1=observed, 0=censored)
  int H;                            // number of covariates
  matrix[n,H] X;                    // matrix of covariates (with n rows and H columns)
}

parameters {
  vector[H] beta;         // Coefficients in the linear predictor (including intercept)
  real<lower=0> alpha;    // shape parameter
}

transformed parameters {
  vector[n] linpred;
  vector[n] mu;
  linpred = X*beta;
  for (i in 1:n) {
    mu[i] = exp(linpred[i]);
  }
}

model {
  alpha ~ gamma(0.01,0.01);
  beta ~ normal(0,5);
  t ~ surv_gamma(d,alpha,mu);
}

the resulting model is rather slow (30 to 40 secs for 2000 iterations, on a dataset with 370 observations and one factor, the treatment arm):
Inference for Stan model: surv_gamma.
2 chains, each with iter=2000; warmup=1000; thin=1;
post-warmup draws per chain=1000, total post-warmup draws=2000.

            mean se_mean     sd    2.5%   97.5% n_eff   Rhat
alpha     2.3988  0.0083 0.2211  1.9894  2.8520   702 1.0015
beta[1]  -1.3800  0.0045 0.1210 -1.6235 -1.1481   711 1.0018
beta[2]  -0.3509  0.0032 0.0901 -0.5266 -0.1750   773 0.9996
rate      0.2534  0.0012 0.0306  0.1972  0.3173   702 1.0019

Samples were drawn using NUTS(diag_e) at Sun Nov 20 08:58:09 2016.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
(Of course, one could use an even more compact (and more elegant?) version:
  real surv_gamma_lpdf(vector t, vector d, real shape, vector rate) {
    vector[num_elements(t)] log_lik;
    real prob;
    for (i in 1:num_elements(t)) {
      log_lik[i] = d[i] * (gamma_lpdf(t[i] | shape, rate[i]) - gamma_lccdf(t[i] | shape, rate[i]))
                   + gamma_lccdf(t[i] | shape, rate[i]);
    }
    prob = sum(log_lik);
    return prob;
  }

but this doesn't make any difference to the results or to the computational time.)
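Unlike the Weibull, the Gamma has no elementary closed form for S(t) when the shape is arbitrary. Here S(t) is the regularized upper incomplete gamma function, which is why log_S above has to call gamma_lccdf. Consider the Erlang case, where the shape k is an integer. This is assumed here only so that the sketch needs no special functions. In that case S(t) = exp(-rt) sum_{j<k} (rt)^j / j!, and the log hazard can be built as log f - log S, as in the log_h function. The following is a minimal Python sketch; the names are hypothetical.

```python
import math

def gamma_log_f(t, shape, rate):
    """Log density of a Gamma(shape, rate), the quantity gamma_lpdf returns."""
    return (shape * math.log(rate) - math.lgamma(shape)
            + (shape - 1) * math.log(t) - rate * t)

def erlang_log_S(t, k, rate):
    """Log survival of a Gamma(k, rate) with integer shape k >= 1 (Erlang case)."""
    x = rate * t
    return -x + math.log(sum(x ** j / math.factorial(j) for j in range(k)))

def erlang_log_h(t, k, rate):
    """Log hazard as log f - log S, mirroring log_h in the Stan code."""
    return gamma_log_f(t, k, rate) - erlang_log_S(t, k, rate)

# Shape 1: the Gamma is an exponential, so the hazard equals the rate at every t
print([math.exp(erlang_log_h(t, 1, 0.25)) for t in (0.1, 1.0, 5.0)])

# Shape 2: the hazard rate^2 * t / (1 + rate * t) increases towards the rate
print([math.exp(erlang_log_h(t, 2, 0.25)) for t in (0.1, 1.0, 5.0, 50.0)])
```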


This is much slower than, say, a Weibull model coded up in a similar way, and I've tested other distributions as well (e.g. log-Normal, log-Logistic, Exponential). More interestingly, I can code up the Gamma model by treating the censored times as missing data and rescaling them, e.g. as
// Gamma survival model
data
{
  int<lower=1> n_obs;                 // number of observed cases
  int<lower=0> n_cens;                // number of censored cases
  vector<lower=0>[n_obs] t;           // fully observed times
  vector<lower=0>[n_cens] d;          // observed censoring times
  int<lower=1> H;                     // number of covariates (including intercept)
  matrix[n_obs,H] X_obs;              // matrix of categorical covariates for the valid cases (0/1 => dummy variables)
  matrix[n_cens,H] X_cens;            // matrix of categorical covariates for the censored cases (0/1 => dummy variables)
  vector[H] mu_beta;                  // vector of means for the covariates
  vector<lower=0>[H] sigma_beta;      // vector of sd for the covariates
  real<lower=0> a_alpha;              // first parameter for the shape distribution
  real<lower=0> b_alpha;              // second parameter for the shape distribution
}

parameters
{
  real<lower=0> alpha;                // shape of the Gamma distribution
  vector[H] beta;                     // coefficients for the covariates
  vector<lower=1>[n_cens] cens;       // censoring variable (latent)
}

transformed parameters
{
  vector[n_obs] loglambda_obs;        // loglinear predictor for the observed cases
  vector[n_cens] loglambda_cens;      // loglinear predictor for the censored cases
  vector[n_obs] lambda_obs;           // rescaled predictor (rate) for the observed cases
  vector[n_cens] lambda_cens;         // rescaled predictor (rate) for the censored cases
  loglambda_cens = X_cens*beta + log(d);
  for (i in 1:n_cens) {
    lambda_cens[i] = exp(loglambda_cens[i]);
  }
  loglambda_obs = X_obs*beta;
  for (i in 1:n_obs) {
    lambda_obs[i] = exp(loglambda_obs[i]);
  }
}

model
{
  // Prior distributions
  alpha ~ gamma(a_alpha,b_alpha);
  beta ~ normal(mu_beta,sigma_beta);
  // Data model
  cens ~ gamma(alpha,lambda_cens);
  t ~ gamma(alpha,lambda_obs);
}
the model is much quicker (about 4 secs for 2000 iterations on the same dataset and model specification). 
Inference for Stan model: Gamma.
2 chains, each with iter=2000; warmup=1000; thin=1;
post-warmup draws per chain=1000, total post-warmup draws=2000.

             mean  se_mean      sd     2.5%    97.5% n_eff    Rhat
alpha     2.35750  0.01027 0.21971  1.96230  2.81158   458 0.99962
beta[1]  -1.40437  0.00609 0.12639 -1.65906 -1.15915   430 0.99998
beta[2]  -0.35369  0.00287 0.08951 -0.53050 -0.17275   973 1.00068
rate      0.24749  0.00150 0.03129  0.19032  0.31375   433 0.99982

Samples were drawn using NUTS(diag_e) at Sun Nov 20 09:47:29 2016.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
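The missing-data model relies on a scaling property of the Gamma. Suppose T ~ Gamma(alpha, rate lambda) is censored at time c, and write T = c * U with U > 1. Then U has a Gamma(alpha, rate lambda * c) density restricted to U > 1. This is why the censored cases use log(d) as an offset in loglambda_cens and a latent cens declared with lower=1. Integrating that density over U > 1 gives S_T(c). Marginally, then, each censored case contributes its survival probability, just as in the first model. The Python sketch below checks both facts numerically. It assumes an integer shape (Erlang) so that S has a closed form, and it uses a simple midpoint-rule integral. The parameter values and names are hypothetical.

```python
import math

def gamma_log_f(x, shape, rate):
    """Log density of a Gamma(shape, rate)."""
    return (shape * math.log(rate) - math.lgamma(shape)
            + (shape - 1) * math.log(x) - rate * x)

def erlang_S(t, k, rate):
    """Survival function of a Gamma(k, rate) with integer k (Erlang case)."""
    x = rate * t
    return math.exp(-x) * sum(x ** j / math.factorial(j) for j in range(k))

alpha, lam, c = 3, 0.4, 2.5   # shape, rate, censoring time (illustrative values)

# 1) Change of variables: the density of U = T / c is c * f_T(c * u),
#    which equals a Gamma(alpha, lam * c) density evaluated at u.
u = 1.7
lhs = math.log(c) + gamma_log_f(c * u, alpha, lam)
rhs = gamma_log_f(u, alpha, lam * c)

# 2) Integrating the rescaled density over u > 1 recovers S_T(c).
#    Midpoint rule on (1, 1 + width); the tail beyond that is negligible here.
n, width = 200_000, 60.0
h = width / n
tail = sum(math.exp(gamma_log_f(1 + (i + 0.5) * h, alpha, lam * c))
           for i in range(n)) * h

print(lhs, rhs)
print(tail, erlang_S(c, alpha, lam))
```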

I may well be missing something blatantly obvious, but I thought my first code would be much more efficient. In a sense it is, because n_eff is higher for most parameters (alpha, beta[1] and rate, though not beta[2]); but this comes at a computational price I wasn't expecting...

Does this make sense to you guys? Is there an obvious reason why this is so?
I can provide a test script, if needed.

Thanks,
Gianluca

Ben Goodrich
Nov 20, 2016, 11:25:37 AM
to Stan users mailing list

Hi Gianluca,


On Sunday, November 20, 2016 at 5:05:19 AM UTC-5, Gianluca Baio wrote:
> [...] the main idea is to package them up and pre-compile them, so that modellers can use them in a kind of automated way [...]

We have a new R package, rstantools, to facilitate this:

https://cran.r-project.org/web/packages/rstantools/index.html

See especially its rstan_package_skeleton() function (which used to just be in rstan).

We have noticed similar issues with Gamma likelihoods in the stan_glm() function in the rstanarm package. Also, gamma_lcdf() and gamma_lccdf() are a bit slow and fragile. This is because the Gamma CDF, the regularized incomplete gamma function, has to be computed with numerical approximations. You might compare with the models mentioned in

https://groups.google.com/d/msg/stan-users/ALIQnFy4b3g/MtgQnVMUAwAJ

to see if there is something about the implementation that makes a big difference.
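Why do gamma_lcdf() and gamma_lccdf() cost more than a Weibull log survival, which is a single closed-form expression? The Gamma survival function is the regularized upper incomplete gamma function Q(shape, rate * t), and for general shape it has to be computed iteratively. The Python sketch below uses the classic textbook split between a power series and a continued fraction, evaluated with the modified Lentz algorithm as in Numerical Recipes. It only illustrates the kind of work involved and is not Stan's implementation. Stan must also supply derivatives with respect to shape and rate for HMC. The function name is hypothetical.

```python
import math

def gamma_lccdf_sketch(t, shape, rate, tol=1e-14, max_iter=1000):
    """Return (log S(t), iterations) for a Gamma(shape, rate), where
    S(t) = Q(shape, rate * t) is the regularized upper incomplete gamma function.

    Uses a series for P = 1 - Q when x < shape + 1, and otherwise a
    continued fraction for Q evaluated with the modified Lentz algorithm.
    """
    a, x = shape, rate * t
    log_pref = -x + a * math.log(x) - math.lgamma(a)  # log(x^a e^-x / Gamma(a))
    if x < a + 1:
        # P(a, x) = pref * sum_{n>=0} x^n / (a (a+1) ... (a+n))
        term = total = 1.0 / a
        n = 0
        for n in range(1, max_iter + 1):
            term *= x / (a + n)
            total += term
            if abs(term) < abs(total) * tol:
                break
        return math.log1p(-math.exp(log_pref) * total), n
    # Q(a, x) = pref * continued fraction
    tiny = 1e-300
    b = x + 1.0 - a
    c = 1.0 / tiny
    d = 1.0 / b
    h = d
    n = 0
    for n in range(1, max_iter + 1):
        an = -n * (n - a)
        b += 2.0
        d = an * d + b
        if abs(d) < tiny:
            d = tiny
        c = b + an / c
        if abs(c) < tiny:
            c = tiny
        d = 1.0 / d
        delta = d * c
        h *= delta
        if abs(delta - 1.0) < tol:
            break
    return log_pref + math.log(h), n

# Shape and rate roughly in the range of the posterior estimates above
for t in (0.5, 5.0, 20.0):
    print(t, gamma_lccdf_sketch(t, 2.4, 0.25))
```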

Ben

Bob Carpenter
Nov 20, 2016, 1:28:11 PM
to stan-...@googlegroups.com

Those diffuse gamma(0.01, 0.01) priors can cause statistical issues in addition to computational ones. See Andrew's papers cited in the manual, or BDA3 (Bayesian Data Analysis, 3rd edition).
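What does gamma(0.01, 0.01) imply for the shape alpha? A small Python calculation helps; the helper name is hypothetical. The prior mean is a/b = 1 and the variance is a/b^2 = 100. With shape 0.01, however, the density is unbounded at zero. The leading term of the series for the gamma CDF, (b x)^a / Gamma(a + 1), puts roughly three quarters of the prior mass below alpha = 1e-10.

```python
import math

a, b = 0.01, 0.01   # shape and rate of the gamma(0.01, 0.01) prior on alpha

mean = a / b            # a / b
variance = a / b ** 2   # a / b^2, i.e. 100 up to rounding

def small_x_cdf(x, a, b):
    """Leading-order approximation P(alpha < x) ~ (b*x)**a / Gamma(a + 1),
    accurate when b*x is much smaller than 1 (first term of the series for P)."""
    return math.exp(a * math.log(b * x) - math.lgamma(a + 1))

print(mean, variance)
print(small_x_cdf(1e-10, a, b))   # approximate prior mass below alpha = 1e-10
```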

- Bob

Gianluca Baio
Nov 20, 2016, 4:11:34 PM
to Stan users mailing list
Thank you, Bob and Ben.

Ben: I have used rstantools to pre-compile the models --- it works really nicely, thanks! Also, I'll check out the models you suggest. On a *very* cursory look, they seem to have been parameterised in a slightly different way than I did --- but I'll have to check more closely.

Bob: thanks. I have done some tests varying the prior for the shape. I don't think the vague Gamma is responsible for much of the problem in this case. I know this prior is often not ideal, but changing it to a Uniform or a log-Normal doesn't seem to change much, either in the estimates or in the running time... Interestingly enough, the prior doesn't make much difference to the estimates or, by and large, to convergence (only marginally to autocorrelation), when compared with the (intuitively) less efficient parameterisation (using the missing-data approach). I guess Ben's point on similar issues with Gamma likelihoods is kind of "consistent"?

BW
Gianluca

Jacqueline Buros Novik
Nov 22, 2016, 3:43:20 PM
to Stan users mailing list

Hi Gianluca,

I took a look at this, partly because I've been meaning to implement a gamma model for survivalstan, and partly because I am always looking for ways to improve my Stan code.

I'm not an expert on this, but I think the alternate model spec (where you separate observed and censored observations) is faster because it is more vectorized than the earlier model.
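Besides vectorization, one further difference may be worth checking. The three-function version evaluates the survival function twice for every observation, once inside log_h and once directly, plus the density once. Splitting observed and censored cases instead evaluates either the density or the survival function exactly once per observation. For the Gamma, the survival function is the expensive call, so the number of lccdf calls is plausibly part of the speed-up. The Python sketch below just counts calls. It assumes an exponential distribution as a stand-in for the Gamma so that no special functions are needed, and the counting wrappers and names are hypothetical.

```python
import math
from collections import Counter

calls = Counter()

def lpdf(t, rate):
    """Exponential log density (stand-in for gamma_lpdf)."""
    calls["lpdf"] += 1
    return math.log(rate) - rate * t

def lccdf(t, rate):
    """Exponential log survival (stand-in for gamma_lccdf)."""
    calls["lccdf"] += 1
    return -rate * t

def log_lik_three_functions(times, events, rate):
    # Mirrors the first model: log S is computed inside log_h and again directly
    log_S_for_h = [lccdf(t, rate) for t in times]
    log_h = [lpdf(t, rate) - s for t, s in zip(times, log_S_for_h)]
    log_S = [lccdf(t, rate) for t in times]
    return sum(d * h + s for d, h, s in zip(events, log_h, log_S))

def log_lik_split(times, events, rate):
    # Mirrors the split version: density for events, survival for censored times
    obs = [t for t, d in zip(times, events) if d == 1]
    cens = [t for t, d in zip(times, events) if d == 0]
    return sum(lpdf(t, rate) for t in obs) + sum(lccdf(t, rate) for t in cens)

times = [0.3, 1.1, 2.4, 0.9, 3.2]
events = [1, 0, 1, 1, 0]

calls.clear()
ll_three = log_lik_three_functions(times, events, rate=0.5)
three_calls = dict(calls)

calls.clear()
ll_split = log_lik_split(times, events, rate=0.5)
split_calls = dict(calls)

print(ll_three, three_calls)   # same log-likelihood ...
print(ll_split, split_calls)   # ... with fewer function evaluations
```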

The following version is (I believe) analogous to your first model but achieves similar performance to your second. It separates observed vs censored obs in the function & then does a vectorized computation. 

(Apologies: in this example I have changed some of the parameter names to match the survivalstan defaults.)


functions {
  // Counts the elements of a that are equal to val
  int count_value(vector a, real val) {
    int s;
    s = 0;
    for (i in 1:num_elements(a))
      if (a[i] == val)
        s = s + 1;
    return s;
  }

  // Defines the log-likelihood: log density for observed times,
  // log survival for censored times
  real surv_gamma_lpdf(vector t, vector d, real shape, vector rate, int num_cens, int num_obs) {
    vector[2] log_lik;
    int idx_obs[num_obs];
    int idx_cens[num_cens];
    real prob;
    int i_cens;
    int i_obs;
    i_cens = 1;
    i_obs = 1;

    // Split the indices into observed (d == 1) and censored (d == 0) cases
    for (i in 1:num_elements(t)) {
      if (d[i] == 1) {
        idx_obs[i_obs] = i;
        i_obs = i_obs + 1;
      }
      else {
        idx_cens[i_cens] = i;
        i_cens = i_cens + 1;
      }
    }
    // print(idx_obs);  // debugging output; printing on every evaluation slows sampling
    // One vectorized call for each subset
    log_lik[1] = gamma_lpdf(t[idx_obs] | shape, rate[idx_obs]);
    log_lik[2] = gamma_lccdf(t[idx_cens] | shape, rate[idx_cens]);

    prob = sum(log_lik);
    return prob;
  }
}

data
{
  int N;                              // number of observations
  vector<lower=0>[N] y;               // observed times
  vector<lower=0,upper=1>[N] event;   // censoring indicator (1=observed, 0=censored)
  int M;                              // number of covariates
  matrix[N, M] x;                     // matrix of covariates (with N rows and M columns)
}
transformed data
{
  int num_cens;
  int num_obs;
  num_obs = count_value(event, 1);
  num_cens = N - num_obs;
}
parameters
{
  vector[M] beta;         // Coefficients in the linear predictor (including intercept)
  real<lower=0> alpha;    // shape parameter
}
transformed parameters
{
  vector[N] linpred;
  vector[N] mu;
  linpred = x*beta;
  mu = exp(linpred);
}
model
{
  alpha ~ gamma(0.01,0.01);
  beta ~ normal(0,5);

  y ~ surv_gamma(event, alpha, mu, num_cens, num_obs);
}
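The index bookkeeping inside surv_gamma_lpdf depends only on the data (the event vector). It therefore produces the same two index arrays on every log-density evaluation. The following Python sketch mirrors count_value() and the idx_obs / idx_cens loop, using the same 1-based indices as Stan; it is purely illustrative. Because the result never changes, one variation would be to compute the index arrays once and pass them in as data, for example from R before calling Stan. This variation was not tested in this thread.

```python
def count_value(a, val):
    """Python counterpart of the count_value() helper in the Stan code."""
    return sum(1 for x in a if x == val)

def split_indices(event):
    """Return 1-based indices of observed (event == 1) and censored cases,
    mirroring the idx_obs / idx_cens loop inside surv_gamma_lpdf."""
    idx_obs, idx_cens = [], []
    for i, e in enumerate(event, start=1):
        (idx_obs if e == 1 else idx_cens).append(i)
    return idx_obs, idx_cens

event = [1, 0, 1, 1, 0, 0, 1]
num_obs = count_value(event, 1)
num_cens = len(event) - num_obs
idx_obs, idx_cens = split_indices(event)
print(num_obs, num_cens, idx_obs, idx_cens)
```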

My tests show the following: 

* original model posted above, with 3-part functions: 0:00:40.518775 elapsed
* using "more elegant" single function: 0:00:21.081723 elapsed
* "more elegant" version using log_mix: 0:00:20.284146 elapsed
* using vectorized-version of single function: 0:00:06.245552 elapsed

Note that these timings are for a single run, single simulation. But perhaps this gives you some ideas?
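The timings mention a version using log_mix, but its code is not shown, so how it was written is an assumption here. Stan's log_mix(theta, lp1, lp2) returns log(theta * exp(lp1) + (1 - theta) * exp(lp2)). If theta is the 0/1 event indicator, it simply selects log f for events and log S for censored times. The Python sketch below shows that behaviour, using a numerically stable log-sum-exp.

```python
import math

def log_mix(theta, lp1, lp2):
    """log(theta * exp(lp1) + (1 - theta) * exp(lp2)), computed stably.

    Terms with zero weight are dropped, so theta = 1 or 0 just selects lp1 or lp2.
    """
    terms = []
    if theta > 0:
        terms.append(math.log(theta) + lp1)
    if theta < 1:
        terms.append(math.log1p(-theta) + lp2)
    m = max(terms)
    return m + math.log(sum(math.exp(x - m) for x in terms))

log_f, log_S = -1.7, -0.4   # illustrative log density and log survival values
print(log_mix(1, log_f, log_S))     # selects log_f
print(log_mix(0, log_f, log_S))     # selects log_S
print(log_mix(0.3, log_f, log_S))   # a genuine mixture for 0 < theta < 1
```

With a 0/1 theta this saves no work, because both log f and log S have to be computed before log_mix combines them. That would be consistent with its timing being close to that of the single-function version.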


Hope this helps,

Jacki

Gianluca Baio
Nov 23, 2016, 8:07:57 AM
to Stan users mailing list
Hi Jacki,
Thanks very much for this --- it's very interesting. I figured that vectorisation may have something to do with this...

As I said, because I'm implementing a bunch of models, I was trying to standardise (and optimise) them --- but it doesn't really matter that I have to do something slightly different for some distributions to make them work better... However, I'll also play around with this (when I get a moment! :-() and make some tests. If this is interesting, I can post back the results!
G