import tensorflow as tf
# Side length (pixels) of the square windows extracted below. It also matches
# the default 11x11 filter of tf.image.ssim, which needs images >= 11x11.
PATCH_SIZE = 11

# Synthetic single-channel 100x100 image (batch of 1), filled with ones.
x = tf.ones(shape=[1, 100, 100, 1])

# One PATCH_SIZE x PATCH_SIZE window around every pixel (stride 1, "SAME"
# padding), so the result has shape [1, 100, 100, PATCH_SIZE * PATCH_SIZE * 1]
# with each window flattened into the last axis.
patches = tf.image.extract_patches(
    x,
    sizes=[1, PATCH_SIZE, PATCH_SIZE, 1],
    strides=[1, 1, 1, 1],
    rates=[1, 1, 1, 1],
    padding="SAME",
)
patches_shape = tf.shape(patches)
# Un-flatten every window back into a small image: [num_patches, 11, 11, 1].
# -1 lets TF infer num_patches (= batch * rows * cols, here 10000), which is
# equivalent to multiplying the first three entries of patches_shape.
patches = tf.reshape(patches, [-1, PATCH_SIZE, PATCH_SIZE, 1])

# Reference patch every other patch is compared against. In row-major order,
# index 5000 is the window centred on row 50, column 0 (left image border).
ref = patches[5000]
@tf.function
def _ssim(test_img):
    """Return the SSIM between the module-level reference patch ``ref`` and ``test_img``.

    ``test_img`` must have the same trailing [height, width, channels] shape
    as ``ref`` (11x11x1 in this script).
    """
    # NOTE(review): max_val is the dynamic range of the pixel values. 255 suits
    # 8-bit data, but the demo image here is all ones; confirm for real inputs.
    return tf.image.ssim(ref, test_img, max_val=255)
@tf.function
def batch_ssim(test_imgs):
    """Compute the SSIM of every image in ``test_imgs`` against ``ref``.

    Args:
        test_imgs: images stacked along axis 0, each shaped like ``ref``.

    Returns:
        A 1-D tensor with one SSIM value per element of ``test_imgs``.
    """
    # Map over the argument, not the module-level ``patches``. Previously the
    # parameter was ignored and the global tensor was always scored.
    return tf.map_fn(_ssim, test_imgs, parallel_iterations=4)
# One SSIM score per patch against ``ref``: a tensor of shape (10000,).
print(batch_ssim(test_imgs=patches))
# Thank you once again.