# Siamese Networks
# Two-branch pair classifier: each person's input is encoded by its own
# stack of layers, the two encodings are concatenated, and a single sigmoid
# unit scores the pair (trained with binary cross-entropy).
# NOTE(review): despite the title, the two branches do NOT share weights --
# each branch instantiates its own layers, and the embedding widths differ
# (32 vs 96). A true siamese setup would build one encoder Model and call it
# on both inputs. Confirm which behaviour is intended.

# ---- Branch 1 -------------------------------------------------------------
# presumably integer token ids of shape (38, 200) per sample -- TODO confirm
person1_inp = Input(shape=(38, 200))
x = Embedding(int(vocab_size), 32)(person1_inp)  # -> (38, 200, 32)
# Keras' default 'valid' padding with a 5x5 kernel -> (34, 196, 32).
x = Conv2D(32, 5, activation='relu')(x)
# 34 * 196 * 32 == 256 * 833 == 213,248, so this reshape is size-preserving.
# NOTE(review): the 256 "timesteps" are arbitrary slices of the flattened
# conv output and do not line up with any of its original axes.
x = Reshape(target_shape=(256, 833))(x)
x = LSTM(32, dropout=0.1, recurrent_dropout=0.5, return_sequences=True)(x)
x = TimeDistributed(Dense(64))(x)
x = Conv1D(32, 7, activation='relu')(x)

# ---- Branch 2 (separate weights; see NOTE above) --------------------------
person2_inp = Input(shape=(38, 200))
y = Embedding(int(vocab_size), 96)(person2_inp)  # -> (38, 200, 96)
# Conv2D always emits 32 channels, so the shape matches branch 1: (34, 196, 32).
y = Conv2D(32, 5, activation='relu')(y)
y = Reshape(target_shape=(256, 833))(y)
y = LSTM(32, dropout=0.1, recurrent_dropout=0.5, return_sequences=True)(y)
y = TimeDistributed(Dense(64))(y)
y = Conv1D(32, 7, activation='relu')(y)

# ---- Pair head ------------------------------------------------------------
z = concatenate([x, y])
z = Bidirectional(
    LSTM(16, dropout=0.1, recurrent_dropout=0.5, return_sequences=True)
)(z)
z = GlobalMaxPooling1D()(z)
z = Dense(4, activation='relu')(z)
# A 1-unit softmax normalises over a single value and therefore always
# outputs 1.0, so the model could never learn. Sigmoid yields a proper
# probability for binary_crossentropy.
out = Dense(1, activation='sigmoid')(z)
model = Model([person1_inp, person2_inp], out)
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
# And finally, my vectorized dataset looks like this:
# (the dataset sample itself is not included in this file)
def euclidean_distance(vects):
    """Return the Euclidean (L2) distance between two tensors.

    The distance is reduced over the last (feature) axis.

    Args:
        vects: A pair ``(x, y)`` of tensors with identical shapes, as handed
            to the function by ``Lambda`` when it is called on a
            two-element list.

    Returns:
        A tensor with the same rank as ``x`` whose last axis has size 1.
        It holds ``sqrt(max(sum((x - y) ** 2), epsilon))`` over that axis.
    """
    x, y = vects
    # Reduce over the last axis. The previous ``axis=1`` only coincides with
    # the feature axis for rank-2 ``(batch, dim)`` inputs; for rank > 2
    # inputs axis 1 is not the feature axis. (If an input's axis 1 has size 1,
    # the result degenerates to an element-wise |x - y|.) For rank-2 inputs
    # this change is a no-op.
    sum_square = K.sum(K.square(x - y), axis=-1, keepdims=True)
    # Clamp the squared distance away from zero so the gradient of sqrt
    # stays finite when x == y.
    return K.sqrt(K.maximum(sum_square, K.epsilon()))
# Weight-shared siamese model: one encoder (`sia`) embeds both inputs, and
# a small head classifies the pair from comparison features.
ch_inp = Input(shape=(1, 200))
csr_inp = Input(shape=(1, 200))

# ---- Shared encoder -------------------------------------------------------
inp = Input(shape=(1, 200))
net = Embedding(int(vocab_size), 16)(inp)
net = Conv2D(16, 1, activation='relu')(net)
net = TimeDistributed(LSTM(8, return_sequences=True))(net)
out = Activation('relu')(net)
sia = Model(inp, out)

# Calling the same Model instance on both inputs shares its weights.
x = sia(csr_inp)
y = sia(ch_inp)

# ---- Pair comparison features ---------------------------------------------
# The difference is computed once and reused. (Previously an identical
# Subtract was built twice, and two Multiply layers (x*x, y*y) were created
# but never connected to the output.)
sub = Subtract()([x, y])           # x - y
mul = Multiply()([sub, sub])       # (x - y) ** 2, element-wise
euc = Lambda(euclidean_distance)([x, y])
z = Concatenate(axis=-1)([euc, sub, mul])

# ---- Classification head --------------------------------------------------
z = TimeDistributed(Bidirectional(LSTM(4)))(z)
z = Activation('relu')(z)
z = GlobalMaxPooling1D()(z)
z = Dense(2, activation='relu')(z)
out = Dense(1, activation='sigmoid')(z)

# Inputs are fed in (ch, csr) order. Note that x encodes csr_inp and y
# encodes ch_inp, so `sub` is csr - ch.
model = Model([ch_inp, csr_inp], out)
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])