# Bind an MNIST-style dataset to a Caffe MemoryData net and run SGD,
# recording the training loss and the test net's ip2 scores each iteration.
#
# Reshape inputs to Caffe's NCHW layout: images as (N, 1, 28, 28),
# labels as (N, 1, 1, 1) -- MemoryData requires 4-D label arrays too.
train_set_x = train_set_x.reshape((num_elements_train, 1, 28, 28))
# BUG FIX: original line had mismatched parentheses
# (`reshape((num_elements_train)),1,1,1))`), a SyntaxError.
train_set_y = train_set_y.reshape((num_elements_train, 1, 1, 1))
test_set_x = test_set_x.reshape((num_elements_test, 1, 28, 28))
# BUG FIX: same mismatched-parentheses defect as the train labels.
test_set_y = test_set_y.reshape((num_elements_test, 1, 1, 1))

# Run on GPU 0.
caffe.set_device(0)
caffe.set_mode_gpu()

solver = caffe.SGDSolver('/Users/TJKlein/caffe/examples/mnist/tjk.prototxt')

# MemoryData layers take their input via set_input_arrays; Caffe requires
# contiguous float32 arrays here.
solver.net.set_input_arrays(train_set_x.astype(np.float32), train_set_y.astype(np.float32))
solver.test_nets[0].set_input_arrays(test_set_x.astype(np.float32), test_set_y.astype(np.float32))

niter = 1000
test_interval = 1
# Use np.zeros explicitly -- the bare `zeros` of the original is undefined
# unless a star-import (e.g. pylab) exists elsewhere; `np` is already in scope.
train_loss = np.zeros(niter)
test_acc = np.zeros(int(np.ceil(niter / test_interval)))
# Per-iteration ip2 scores for the first 8 test samples (10 classes each).
output = np.zeros((niter, 8, 10))

for it in range(niter):
    # BUG FIX: the loop body was not indented in the original (SyntaxError).
    solver.step(1)  # one SGD iteration in Caffe
    train_loss[it] = solver.net.blobs['loss'].data
    # Forward the test net from conv1 onward; the data blob was already
    # populated via set_input_arrays above.
    solver.test_nets[0].forward(start='conv1')
    output[it] = solver.test_nets[0].blobs['ip2'].data[:8]
[....]
Network definition file (tjk.prototxt):
# Input layer: data is supplied from memory at runtime (via pycaffe's
# set_input_arrays), in batches of 50 single-channel 28x28 images.
layer {
name: "data"
type: "MemoryData"
top: "data"
top: "label"
memory_data_param {
batch_size: 50
channels: 1
height: 28
width: 28
}
}
# First convolution: 20 feature maps, 5x5 kernels, Xavier-initialized weights.
layer {
name: "conv1"
type: "Convolution"
bottom: "data"
top: "conv1"
convolution_param {
num_output: 20
kernel_size: 5
weight_filler {
type: "xavier"
}
}
}
# 2x2 max pooling, stride 2 (halves spatial resolution).
layer {
name: "pool1"
type: "Pooling"
bottom: "conv1"
top: "pool1"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
# Second convolution: 50 feature maps, 5x5 kernels.
layer {
name: "conv2"
type: "Convolution"
bottom: "pool1"
top: "conv2"
convolution_param {
num_output: 50
kernel_size: 5
weight_filler {
type: "xavier"
}
}
}
# Second 2x2/stride-2 max pooling.
layer {
name: "pool2"
type: "Pooling"
bottom: "conv2"
top: "pool2"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
# Fully connected layer with 500 outputs.
layer {
name: "ip1"
type: "InnerProduct"
bottom: "pool2"
top: "ip1"
inner_product_param {
num_output: 500
weight_filler {
type: "xavier"
}
}
}
# In-place ReLU nonlinearity on ip1.
layer {
name: "relu1"
type: "ReLU"
bottom: "ip1"
top: "ip1"
}
# Output layer: 10 class scores (one per MNIST digit).
layer {
name: "ip2"
type: "InnerProduct"
bottom: "ip1"
top: "ip2"
inner_product_param {
num_output: 10
weight_filler {
type: "xavier"
}
}
}
# Softmax + multinomial logistic loss over the 10 class scores.
# NOTE(review): there is no Accuracy layer, so the Python-side `test_acc`
# array has nothing to read from this net -- confirm whether one is needed.
layer {
name: "loss"
type: "SoftmaxWithLoss"
bottom: "ip2"
bottom: "label"
top: "loss"
}