Can't get flatten() to work with tensorflow backend


Isaac Gerg

Nov 22, 2016, 9:31:39 AM
to Keras-users
I have a simple example that works in Theano but breaks in TensorFlow.

input = keras.layers.Input(batch_shape=(batch_size,400,400, 1))
x = AveragePooling2D(pool_size=(2,2))(input)
# <tf.Output 'AvgPool:0' shape=(256, 200, 200, 1) dtype=float32>
x = Flatten()(x)
# <tf.Output 'Reshape:0' shape=(?, ?) dtype=float32>

After Flatten(), TensorFlow reports an unknown shape for x. This causes issues further down the line.

What am I doing wrong here?
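
For reference, the shape one would expect Flatten() to report here can be worked out by hand: every axis after the batch axis is collapsed into one. Below is a minimal sketch in plain Python; the helper name expected_flatten_shape is hypothetical and is not part of Keras or TensorFlow.

```python
from functools import reduce
from operator import mul


def expected_flatten_shape(input_shape):
    """Return the static shape a flatten step should produce.

    Hypothetical helper for illustration only: the batch axis is kept,
    and all remaining axes are multiplied together. If any non-batch
    axis is unknown (None), the flattened size is unknown as well.
    """
    batch, *rest = input_shape
    if any(dim is None for dim in rest):
        return (batch, None)
    return (batch, reduce(mul, rest, 1))


# Shape reported after AveragePooling2D in the snippet above.
pooled_shape = (256, 200, 200, 1)
print(expected_flatten_shape(pooled_shape))  # (256, 40000)
```

So with a fully specified batch_shape, a result of (256, 40000) would be expected rather than (?, ?).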

yohanes...@gmail.com

Nov 27, 2016, 4:01:47 AM
to Keras-users
A common mistake with Flatten() involves the Keras image_dim_ordering setting (see https://keras.io/backend/). Have you configured it properly (tf for TensorFlow and th for Theano)?

If that is not your problem, you might want to post the error message as well so others can understand the problem better.

Isaac Gerg

Nov 29, 2016, 12:56:38 PM
to Keras-users, yohanes...@gmail.com
Yes, I did check this, as I initially thought this was the issue. The ticket is here: https://github.com/fchollet/keras/issues/4470

Here is the code that reproduces the problem (it only breaks with the TensorFlow backend).

    input = keras.layers.Input(shape=(400,400, 1))
    x = AveragePooling2D(pool_size=(2,2))(input)
    x = Flatten()(x)  # This flatten works fine.
    x = keras.layers.RepeatVector(3)(x)
    x = Reshape((200, 200, 3))(x)

    y = AveragePooling2D(pool_size=(4,4))(input)
    y = Flatten()(y)  # This flatten works fine.
    y = keras.layers.RepeatVector(3)(y)
    y = Reshape((100, 100, 3))(y)

    vgg1 = keras.applications.vgg16.VGG16(include_top=False)
    vgg1.trainable = False  # Doesn't work
    x = vgg1(x)

    vgg2 = keras.applications.vgg16.VGG16(include_top=False)
    vgg2.trainable = False  # Doesn't work
    y = vgg2(y)
    yUp  = keras.layers.UpSampling2D((2,2))(y)

    m = keras.layers.merge([x,yUp], mode='sum')
    m = keras.layers.Convolution2D(64,3,3, border_mode='same')(m)
    m = keras.layers.Flatten()(m)  # This line causes the error; replacing it with the next line fixes it.
    #m = keras.layers.Reshape((6*6*64,))(m)
    m = keras.layers.MaxoutDense(8)(m)
    m = keras.layers.Dropout(0.5)(m)
    m = keras.layers.MaxoutDense(4)(m)
    m = keras.layers.Dropout(0.5)(m)
    m = keras.layers.Dense(2, activation='softmax')(m)

    model = keras.models.Model(input, m)
    return model
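
The 6*6*64 in the commented-out Reshape line can be checked with a bit of shape arithmetic. The sketch below is plain Python, not Keras. It assumes that VGG16 with include_top=False halves each spatial dimension five times (one 2x2, stride-2 max pool per block, rounding down) and that the convolutions with 'same' padding leave the spatial size unchanged. The helper pooled_size is hypothetical.

```python
def pooled_size(size, pool=2, times=1):
    """Spatial size after `times` non-overlapping pools of width `pool`.

    Hypothetical helper: assumes no padding, so each pool floors the result.
    """
    for _ in range(times):
        size //= pool
    return size


# Assumption: VGG16 without its top has five 2x2, stride-2 max-pool stages.
VGG16_POOL_STAGES = 5

# x branch: 400 -> AveragePooling2D(2) -> 200 -> VGG16 -> 6
x_side = pooled_size(400, pool=2)
x_side = pooled_size(x_side, pool=2, times=VGG16_POOL_STAGES)

# y branch: 400 -> AveragePooling2D(4) -> 100 -> VGG16 -> 3 -> UpSampling2D(2) -> 6
y_side = pooled_size(400, pool=4)
y_side = pooled_size(y_side, pool=2, times=VGG16_POOL_STAGES)
y_up_side = y_side * 2

# The two branches must agree spatially for the 'sum' merge to be valid.
assert x_side == y_up_side

# Convolution2D(64, 3, 3, border_mode='same') keeps 6x6 and outputs 64 channels.
conv_filters = 64
flat_size = x_side * x_side * conv_filters
print(flat_size)  # 2304
```

Under these assumptions the flattened size is fully determined by the 400x400 input, which is why the hard-coded Reshape can stand in for Flatten() in this model.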