class RealNVPModel(Model):
def __init__(self, **kwargs):
super(RealNVPModel, self).__init__(**kwargs)
self.preprocess = get_preprocess_bijector(0.05)
self.realnvp_multiscale = RealNVPMultiScale()
self.bijector = tfb.Chain([self.realnvp_multiscale, self.preprocess])
def build(self, input_shape):
output_shape = self.bijector(tf.expand_dims(tf.zeros(input_shape[1:]), axis=0)).shape
self.base = tfd.Independent(tfd.Normal(loc=tf.zeros(output_shape[1:]), scale=1.),
reinterpreted_batch_ndims=3)
self._bijector_variables = (
list(self.bijector.variables))
self.flow = tfd.TransformedDistribution(
distribution=self.base,
bijector=tfb.Invert(self.bijector),
)
super(RealNVPModel, self).build(input_shape)
def call(self, inputs, training=None, **kwargs):
return self.flow
def sample(self, batch_size):
sample = self.base.sample(batch_size)
return self.bijector.inverse(sample)
# Create an instance of the RealNVPModel class
realnvp_model = RealNVPModel()
realnvp_model.build((1, 32, 32, 3))
def nll(y_true, y_pred):
return -y_pred.log_prob(y_true)
# Compile and train the model
optimizer = Adam()
realnvp_model.compile(loss=nll, optimizer=Adam())
realnvp_model.fit(train_ds, validation_data=val_ds, epochs=30)
#########################
This code block gives an error
'TransformedDistribution' object has no attribute 'shape'
Can't figure out how to fix this.