import sys
import math
import random
import numpy as np
import caffe
import lmdb
from PIL import Image
import shutil
def _load_image_xy(path):
    """Load the image at *path* as an array indexed [x][y][channel]."""
    # PIL yields [row][column][channel]; swap the first two axes to get [x][y].
    return np.array(Image.open(path)).transpose((1, 0, 2))


def _group_pixels_by_category(image_target, halfws):
    """Map each target colour to the (x, y) pixels that have that colour.

    Only pixels at least *halfws* away from every image border are collected,
    so that a full window can be cut out around each of them.
    """
    category_pixels = {}
    for x in range(halfws, image_target.shape[0] - halfws):
        for y in range(halfws, image_target.shape[1] - halfws):
            colour = tuple(image_target[x][y])
            category_pixels.setdefault(colour, []).append((x, y))
    return category_pixels


def _window_has_full_data(window):
    """Return True if no pixel in *window* is transparent (missing data).

    Channel 3 is the alpha of the RGB image and channel 5 is the alpha of the
    depth image. The check uses plain slicing over the whole window so that
    every pixel is tested regardless of the array's memory layout.
    """
    return bool(np.all(window[:, :, 3] == 255) and np.all(window[:, :, 5] == 255))


def _write_samples(image, category_pixels, halfws, examples_per_category,
                   shuffled_keys, train_txn, test_txn):
    """Write windows for every category into the train/test transactions.

    Returns a ``(total_train, total_test)`` tuple with the number of samples
    written to each database.
    """
    total_train = 0
    total_test = 0
    key_idx = 0
    for category_index, (target_colour, pixels_xy) in enumerate(category_pixels.items()):
        print("Label {} corresponds to category colour {}".format(category_index, target_colour))
        print(" Found {} pixels in target image belonging to this category.".format(len(pixels_xy)))
        # At least one test example per category (ceil of 20%).
        num_test = int(math.ceil(examples_per_category * 0.2))
        print(" Adding {} training examples and {} test examples.".format(
            examples_per_category - num_test, num_test))

        sample_count = 0
        # Visit the category's pixels in random order until enough valid
        # windows have been found (or the pixels run out).
        for x, y in random.sample(pixels_xy, len(pixels_xy)):
            # Checked before adding so a non-positive limit adds nothing
            # instead of running past the end of shuffled_keys.
            if sample_count >= examples_per_category:
                break
            window = image[x - halfws:x + halfws + 1, y - halfws:y + halfws + 1]
            if not _window_has_full_data(window):
                continue
            # Remove alpha channels.
            window = np.delete(window, [3, 5], 2)
            # Normalise so total sum is 0.
            window = window - window.mean()
            # Transpose to channels, height, width for caffe.io.array_to_datum.
            window = window.transpose((2, 1, 0))
            datum = caffe.io.array_to_datum(window, category_index)
            # Encoding the key keeps it a byte string under both Python 2 and 3.
            key = '{:0>10d}'.format(shuffled_keys[key_idx]).encode('ascii')
            # The first num_test valid samples of a category go to the test DB.
            if sample_count < num_test:
                test_txn.put(key, datum.SerializeToString())
                total_test += 1
            else:
                train_txn.put(key, datum.SerializeToString())
                total_train += 1
            sample_count += 1
            key_idx += 1
    return total_train, total_test


def generate_db(image_path_base, max_examples_per_category, window_size, output_path):
    """Build train and test LMDB databases of labelled image windows.

    Reads ``<image_path_base>_rgb.png``, ``<image_path_base>_d.png`` and
    ``<image_path_base>_t.png``. Each distinct colour in the target image is a
    category; windows centred on pixels of each category are cut from the
    combined RGB + depth image and written as Caffe datums labelled with the
    category's index.

    Args:
        image_path_base: Path prefix shared by the three input images.
        max_examples_per_category: Upper bound on samples per category
            (converted with ``int``). It is lowered to the size of the
            smallest category so every category gets the same target count.
        window_size: Window size in pixels (converted with ``int``). The
            window spans ``window_size // 2`` pixels either side of its
            centre pixel.
        output_path: Path prefix of the databases; ``_train`` and ``_test``
            are appended. Existing databases at those paths are deleted.

    Note: single-argument ``print(...)`` calls and ``//`` are used so the
    function behaves the same under Python 2 and Python 3.
    """
    max_examples_per_category = int(max_examples_per_category)
    window_size = int(window_size)
    halfws = window_size // 2

    # Load RGB image.
    image = _load_image_xy(image_path_base + "_rgb.png")
    # Load depth image.
    image_depth = _load_image_xy(image_path_base + "_d.png")
    # Combine RGB and depth images, only keeping blue and alpha channels from depth image.
    # Final array has form [x][y][R,G,B,Argb,D,Ad]
    # where Argb is the alpha channel from the RGB image and Ad is the alpha channel
    # from the depth image.
    image = np.append(image, image_depth[:, :, 2:], 2)
    image_target = _load_image_xy(image_path_base + "_t.png")

    # Determine categories and collect each pixel coordinate for each category.
    category_pixels = _group_pixels_by_category(image_target, halfws)

    # Make sure an equal number of examples are created for each category.
    for pixels_xy in category_pixels.values():
        max_examples_per_category = min(max_examples_per_category, len(pixels_xy))

    print("Found {} categories in target image.".format(len(category_pixels)))

    # Delete old DBs if they exist.
    shutil.rmtree(output_path + "_train", ignore_errors=True)
    shutil.rmtree(output_path + "_test", ignore_errors=True)

    # Keys are handed out in shuffled order so samples of different categories
    # are interleaved when the databases are read in key order.
    shuffled_keys = list(range(len(category_pixels) * max(max_examples_per_category, 0)))
    random.shuffle(shuffled_keys)

    train_db = lmdb.open(output_path + "_train", map_size=int(1e12))
    try:
        test_db = lmdb.open(output_path + "_test", map_size=int(1e12))
        try:
            with train_db.begin(write=True) as train_txn:
                with test_db.begin(write=True) as test_txn:
                    total_train, total_test = _write_samples(
                        image, category_pixels, halfws, max_examples_per_category,
                        shuffled_keys, train_txn, test_txn)
        finally:
            # Close the environments even if writing failed part-way.
            test_db.close()
    finally:
        train_db.close()

    print("Added {} total training examples and {} total test examples.".format(
        total_train, total_test))
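# Inspect a specified datum from one of the DBs (debugging aid, left commented out).
# NOTE(review): the DBs are written to output_path + "_train" and output_path + "_test",
# so one of those paths has to be opened here. caffe_pb2 is not imported in this file
# (presumably caffe.proto.caffe_pb2 -- confirm). The stored windows are mean-subtracted
# floats, so the pixel data may be in datum.float_data rather than datum.data -- confirm.
# env = lmdb.open(output_path + "_train", readonly=True)
# with env.begin() as txn:
#     raw_datum = txn.get('{:0>10d}'.format(10))
#
#     datum = caffe_pb2.Datum()
#     datum.ParseFromString(raw_datum)
#
#     flat_x = np.fromstring(datum.data, dtype=np.uint8)
#     x = flat_x.reshape(datum.channels, datum.height, datum.width)
#     y = datum.label
#     print y, x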