evaluate()
-- Replaces, in place inside `net`, every module whose torch typename is
-- `typename` with the module returned by `makeReplacement(oldModule)`.
-- Matches are located with nn.Module:findModules, which also returns the
-- container holding each match; every slot of that container's `modules`
-- list that holds the matched instance is overwritten with a fresh
-- replacement (one `makeReplacement` call per matching slot).
local function replaceModulesOfType(net, typename, makeReplacement)
   local nodes, containers = net:findModules(typename)
   for i = 1, #nodes do
      local node = nodes[i]
      local siblings = containers[i].modules
      -- findModules reports the root module as its own container; a root
      -- leaf module has no `modules` list and cannot be swapped in place.
      assert(siblings ~= nil,
         'getTestNetwork: cannot replace a top-level ' .. typename)
      for j = 1, #siblings do
         if siblings[j] == node then
            siblings[j] = makeReplacement(node)
         end
      end
   end
end

-- Builds a 1x1 convolution that carries the parameters of an nn.Linear.
-- nn.Linear stores its weight as (nOutput x nInput), hence the size(2),
-- size(1) argument order.
-- NOTE(review): assumes nn.SpatialConvolution1_fw takes
-- (nInputPlane, nOutputPlane) like nn.SpatialConvolution and that its
-- weight has the same number of elements as the Linear weight — confirm.
-- NOTE(review): parameters are COPIED, not shared, so unlike the layers kept
-- from model:clone('weight', 'bias') this conv will not follow later
-- updates made to the training model.
local function linearToConv1x1(linear)
   local w = linear.weight
   local b = linear.bias
   local conv = nn.SpatialConvolution1_fw(w:size(2), w:size(1)):cuda()
   conv.weight:copy(w)
   if b then
      conv.bias:copy(b)
   elseif conv.bias then
      -- bias-less Linear (nn.Linear(n, m, false)): neutralise the conv's
      -- freshly initialised bias so it adds nothing at inference time
      conv.bias:zero()
   end
   return conv
end

--- Builds an evaluation copy of `model` as a fully-convolutional network.
-- The copy shares weight/bias tensors with `model` (via clone), except for
-- nn.Linear layers, which are replaced by 1x1 convolutions holding copies
-- of their parameters. nn.Reshape layers are replaced by nn.Concatenation,
-- convolutions are padded by padConvs to maintain resolution, and the
-- result is switched to evaluation mode.
-- @param model the trained network (nn.Module)
-- @return the converted network, in evaluate() mode
function network.getTestNetwork(model)
   local testModel = model:clone('weight', 'bias')
   -- replace linear with 1x1 conv
   replaceModulesOfType(testModel, 'nn.Linear', linearToConv1x1)
   -- replace reshape with concatenation
   replaceModulesOfType(testModel, 'nn.Reshape', function()
      return nn.Concatenation():cuda()
   end)
   -- pad convolutions
   padConvs(testModel)
   -- switch to evaluation mode
   testModel:evaluate()
   return testModel
end