I am looking for a Tensorflow 2 spectral normalization (https://arxiv.org/pdf/1802.05957.pdf) implementation for generative adversarial networks that is compatible with the Keras layers. I found several TF1 implementations and just one TF2 implementation (https://medium.com/@FloydHsiu0618/spectral-normalization-implementation-of-tensorflow-2-0-keras-api-d9060d26de77). In the comments of that TF2 implementation was noted that the implementation contains an error. So I mixed some TF1 implementations (e.g. https://github.com/taki0112/Spectral_Normalization-Tensorflow) with the TF2 implementation and I am wondering if this is a valid implementation, because I'm not very familiar with TF2.
import tensorflow as tf
from tensorflow.keras.layers import Wrapper
class SpectralNorm(Wrapper):
def __init__(self, layer, iteration=1, **kwargs):
super(SpectralNorm, self).__init__(layer, **kwargs)
self.iteration = iteration
def build(self, input_shape):
if not self.layer.built:
self.layer.build(input_shape)
if not hasattr(self.layer, 'kernel'):
raise ValueError('Invalid layer for SpectralNorm.')
self.w = self.layer.kernel
self.w_shape = self.w.shape.as_list()
self.u = self.add_variable(shape=(1, self.w_shape[-1]), initializer=tf.random_normal_initializer(), name='sn_u', trainable=False, dtype=tf.float32)
super(SpectralNorm, self).build()
@tf.function
def call(self, inputs, training=None):
self._compute_weights(training)
output = self.layer(inputs)
return output
def _compute_weights(self, training):
iteration = self.iteration
w_reshaped = tf.reshape(self.w, [-1, self.w_shape[-1]])
u_hat = tf.identity(self.u)
v_hat = None
for _ in range(self.iteration):
v_ = tf.matmul(u_hat, tf.transpose(w_reshaped))
v_hat = tf.nn.l2_normalize(v_)
u_ = tf.matmul(v_hat, w_reshaped)
u_hat = tf.nn.l2_normalize(u_)
if training == True: self.u.assign(u_hat)
sigma = tf.matmul(tf.matmul(v_hat, w_reshaped), tf.transpose(u_hat))
w_norm = self.w / sigma
self.layer.kernel = w_norm
def compute_output_shape(self, input_shape):
return tf.TensorShape(self.layer.compute_output_shape(input_shape).as_list())x = SpectralNorm(layers.Conv2D(...))(x)