これをNNabla上で動作させた際に、推論結果が異なってしまっている状態です。
# coding: utf-8
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
from nnabla.utils.data_iterator import data_iterator_csv_dataset
import numpy
import argparse
from PIL import Image
def network(x, y):
# Input -> 1,32,32
# Convolution -> 64,32,32
with nn.parameter_scope('Convolution'):
h = PF.convolution(x, 64, (9,9), (4,4))
# PReLU
with nn.parameter_scope('PReLU'):
h = PF.prelu(h)
# Convolution_2 -> 32,32,32
with nn.parameter_scope('Convolution_2'):
h = PF.convolution(h, 32, (5,5), (2,2))
# PReLU_2
with nn.parameter_scope('PReLU_2'):
h = PF.prelu(h)
# Convolution_3 -> 1,32,32
with nn.parameter_scope('Convolution_3'):
h = PF.convolution(h, 1, (5,5), (2,2))
# Sigmoid
h = F.sigmoid(h)
return h
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("-ic", "--input_csv_path", type = str, required = True)
parser.add_argument("-ih", "--input_h5_path", type = str, required = True)
args = parser.parse_args()
# prediction in NNabla
input_data = nn.Variable((1, 1, 32, 32))
correct_data = nn.Variable((1, 1, 32, 32))
nn.clear_parameters()
network_graph = network(input_data, correct_data)
nn.parameter.load_parameters(args.input_h5_path)
test_data = data_iterator_csv_dataset(args.input_csv_path, 1, shuffle=False, normalize=True, with_memory_cache=False, with_file_cache=False)
input_data.d, correct_data.d = test_data.next()
network_graph.forward()
predict_data = network_graph.d[0][0]
# output prediction image(grayscale)
predict_data *= 255
predict_data = numpy.round(predict_data)
predict_data = numpy.int32(predict_data)
Image.fromarray(numpy.uint8(predict_data)).save("tmp.png")
余談:Neural Network Console上でPReLUレイヤーを含んだネットワークをexportすると、以下のように関数名がPF.p_re_l_uとなり、最新のNNablaのPReLUと互換が取れない名前になってしまっているようです。(上のコードでは手動で「PF.prelu」に直しています)# PReLU
with nn.parameter_scope('PReLU'):
h = PF.p_re_l_u(h)既に指摘済みの事項でしたら恐縮ですが、取り急ぎご報告まで。
今回の実装では入力画像の輝度が実質255倍されていることになるため、
出力画像の輝度もほぼ最大値に張り付いてしまっていると考えられます。
input_data.d /= 255.0
既に正規化した値を入力されていたのですね。
ご指摘の通り、原因はPReLUを正しく実行できていない点にあるようです。