The easiest way to do this is to put it in `Model.fprop`.
You can write a Model that wraps another Model and adds the preprocessing there.
Here's an example that does test-time data augmentation with multiple crops and flips:
class Augmentor(Model):
    """Wrap a Model and average its logits over shifted crops and mirror flips.

    Test-time augmentation: the input batch is padded, every crop of the
    original size is sliced out of the padded batch, each crop is also flipped
    left-right, and the wrapped model's logits are averaged over all of them.

    Args:
        raw: the wrapped model; it must provide ``get_params()`` and
            ``get_logits(x)``.
        pad: ``(pad_h, pad_w)`` number of pixels added on each side of the
            height and width axes. Defaults to ``(2, 2)``.
        mode: ``tf.pad`` mode, one of ``"REFLECT"``, ``"SYMMETRIC"`` or
            ``"CONSTANT"``. Defaults to ``"REFLECT"``.

    Raises:
        ValueError: if ``mode`` is not one of the supported pad modes.
    """

    def __init__(self, raw, pad=(2, 2), mode="REFLECT"):
        if mode not in ('REFLECT', 'SYMMETRIC', 'CONSTANT'):
            raise ValueError("Unsupported pad mode: %r" % (mode,))
        self.raw = raw
        self.pad = tuple(pad)
        self.mode = mode

    def get_params(self):
        """Return the wrapped model's parameters (this wrapper adds none)."""
        return self.raw.get_params()

    def fprop(self, x):
        """Return ``{'logits': ...}`` averaged over all crops and their flips.

        Args:
            x: input batch. The padding and slicing below treat axes 1 and 2
                as height and width and leave axes 0 and 3 unpadded, so a 4-D
                (batch, height, width, channels) tensor is assumed.
        """
        pad_h, pad_w = self.pad
        # Pad the whole batch in one op; batch and channel axes get no padding.
        xp = tf.pad(x, [[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]],
                    self.mode)

        crops = []
        # Offsets 0..2*pad (inclusive) correspond to shifts of -pad..+pad, so
        # the crop set is symmetric around the un-shifted (center) crop.
        for i in range(2 * pad_h + 1):
            for j in range(2 * pad_w + 1):
                crop = tf.slice(xp, [0, i, j, 0], tf.shape(x))
                crops.append(crop)
                # NOTE(review): the flip is pinned to CPU as in the original;
                # presumably a device-placement workaround -- confirm whether
                # it is still needed.
                with tf.device("/CPU:0"):
                    crops.append(tf.image.flip_left_right(crop))

        # The Defun signature declares only a dtype, so its argument has no
        # static shape; restore it from x so the wrapped model can build
        # shape-dependent layers.
        @function.Defun(tf.float32)
        def f(xarg):
            xarg.set_shape(x.get_shape())
            return self.raw.get_logits(xarg)

        logits = [f(crop) for crop in crops]
        logits = tf.add_n(logits) / len(logits)
        return {'logits': logits}