'TransformedDistribution' object has no attribute 'shape'

196 views
Skip to first unread message

Nisheeth Jaiswal

unread,
Jul 7, 2021, 12:24:22 PM7/7/21
to TensorFlow Probability
class RealNVPModel(Model):

    def __init__(self, **kwargs):
        super(RealNVPModel, self).__init__(**kwargs)
        self.preprocess = get_preprocess_bijector(0.05)
        self.realnvp_multiscale = RealNVPMultiScale()
        self.bijector = tfb.Chain([self.realnvp_multiscale, self.preprocess])
        
    def build(self, input_shape):
        output_shape = self.bijector(tf.expand_dims(tf.zeros(input_shape[1:]), axis=0)).shape
        self.base = tfd.Independent(tfd.Normal(loc=tf.zeros(output_shape[1:]), scale=1.),
                                    reinterpreted_batch_ndims=3)
        self._bijector_variables = (
            list(self.bijector.variables))
        self.flow = tfd.TransformedDistribution(
            distribution=self.base,
            bijector=tfb.Invert(self.bijector),
        )
        super(RealNVPModel, self).build(input_shape)

    def call(self, inputs, training=None, **kwargs):
        return self.flow

    def sample(self, batch_size):
        sample = self.base.sample(batch_size)
        return self.bijector.inverse(sample)

# Create an instance of the RealNVPModel class
realnvp_model = RealNVPModel()
realnvp_model.build((1, 32, 32, 3))

def nll(y_true, y_pred):
    return -y_pred.log_prob(y_true)

# Compile and train the model

optimizer = Adam()
realnvp_model.compile(loss=nll, optimizer=Adam())
realnvp_model.fit(train_ds, validation_data=val_ds, epochs=30)

#########################
This code block gives an error 
'TransformedDistribution' object has no attribute 'shape'

Can't figure out how to fix this.
Reply all
Reply to author
Forward
0 new messages