Hi, I modified the code in fb.resnet.torch/dataloader.lua so that it reads data triplet by triplet, but I ran into a confusing error:
FATAL THREAD PANIC: (write) /home/haha/torch/install/share/lua/5.1/torch/File.lua:141: Unwritable object <userdata> at <?>.callback.self.resnet.DataLoader.threads.__gc__
Below is my modified code:
function DataLoader:run()
   local threads = self.threads
   local size, batchSize = self.__size, self.batchSize
   local perm = torch.randperm(size)
   local tripletList = self:genTriplet()
   local idx, sample = 1, nil
   local function enqueue()
      while idx <= size and threads:acceptsjob() do
         local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1))
         threads:addjob(
            function(indices, nCrops, tripletList)
               local sz = indices:size(1) * 3 -- 3x the original size, since each index now yields a triplet
               local batch, imageSize
               local target = torch.IntTensor(sz)
               for i, idx in ipairs(indices:totable()) do
                  local idx_anchor = tripletList[idx][1]
                  local idx_positive = tripletList[idx][2]
                  local idx_negative = tripletList[idx][3]
                  local sample_anchor = _G.dataset:get(idx_anchor) -- get the three images
                  local sample_positive = _G.dataset:get(idx_positive)
                  local sample_negative = _G.dataset:get(idx_negative)
                  local input_anchor = _G.preprocess(sample_anchor.input)
                  local input_positive = _G.preprocess(sample_positive.input)
                  local input_negative = _G.preprocess(sample_negative.input)
                  if not batch then
                     imageSize = input_anchor:size():totable()
                     if nCrops > 1 then table.remove(imageSize, 1) end
                     batch = torch.FloatTensor(sz, nCrops, table.unpack(imageSize))
                  end
                  batch[(i-1)*2 + 1]:copy(input_anchor)
                  batch[(i-1)*2 + 2]:copy(input_positive)
                  batch[self.samples*self.blocks + i]:copy(input_negative)
                  target[(i-1)*2 + 1] = sample_anchor.target
                  target[(i-1)*2 + 2] = sample_positive.target
                  target[self.samples*self.blocks + i] = sample_negative.target
               end
               collectgarbage()
               return {
                  input = batch:view(sz * nCrops, table.unpack(imageSize)),
                  target = target,
               }
            end,
            function(_sample_)
               -- print('WHAT????')
               sample = _sample_
            end,
            indices, self.nCrops, tripletList
         )
         idx = idx + batchSize
      end
   end

   local n = 0
   local function loop()
      enqueue()
      if not threads:hasjob() then
         return nil
      end
      threads:dojob()
      if threads:haserror() then
         threads:synchronize()
      end
      enqueue()
      n = n + 1
      return n, sample
   end
   return loop
end
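I suspect the culprit is the new batch indexing: self.samples*self.blocks is evaluated inside the function passed to threads:addjob, so the closure captures self as an upvalue. As far as I understand, the threads library serializes the callback together with its upvalues before shipping it to a worker thread, so it ends up trying to write the whole DataLoader, including the thread pool's userdata, which matches the callback.self...DataLoader.threads.__gc__ path in the panic message. A minimal sketch of the change I have in mind (untested; negOffset is a name I made up, and I am assuming self.samples*self.blocks is an ordinary number that equals twice the per-batch index count):

   local negOffset = self.samples * self.blocks -- read self's fields on the main thread
   threads:addjob(
      function(indices, nCrops, tripletList, negOffset)
         -- same body as above, except every occurrence of
         -- self.samples*self.blocks is replaced by negOffset, e.g.:
         --   batch[negOffset + i]:copy(input_negative)
         --   target[negOffset + i] = sample_negative.target
         -- With no reference to self inside, the closure has no upvalues
         -- that drag the DataLoader (and its thread pool) into serialization.
      end,
      function(_sample_)
         sample = _sample_
      end,
      indices, self.nCrops, tripletList, negOffset
   )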
Below is the original code:
function DataLoader:run()
   local threads = self.threads
   local size, batchSize = self.__size, self.batchSize
   local perm = torch.randperm(size)
   local idx, sample = 1, nil
   local function enqueue()
      while idx <= size and threads:acceptsjob() do
         local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1))
         threads:addjob(
            function(indices, nCrops)
               local sz = indices:size(1)
               local batch, imageSize
               local target = torch.IntTensor(sz)
               for i, idx in ipairs(indices:totable()) do
                  local sample = _G.dataset:get(idx)
                  local input = _G.preprocess(sample.input)
                  if not batch then
                     imageSize = input:size():totable()
                     if nCrops > 1 then table.remove(imageSize, 1) end
                     batch = torch.FloatTensor(sz, nCrops, table.unpack(imageSize))
                  end
                  batch[i]:copy(input)
                  target[i] = sample.target
               end
               collectgarbage()
               return {
                  input = batch:view(sz * nCrops, table.unpack(imageSize)),
                  target = target,
               }
            end,
            function(_sample_)
               sample = _sample_
            end,
            indices, self.nCrops
         )
         idx = idx + batchSize
      end
   end

   local n = 0
   local function loop()
      enqueue()
      if not threads:hasjob() then
         return nil
      end
      threads:dojob()
      if threads:haserror() then
         threads:synchronize()
      end
      enqueue()
      n = n + 1
      return n, sample
   end
   return loop
end
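Comparing the two, the original job function receives everything it needs (indices, nCrops) as addjob arguments and only touches self outside the closure, so the serialized callback has no upvalues at all. A small self-contained illustration of that rule, assuming only the threads package (loader, scale, and the pool size are names I invented for the example):

   local threads = require 'threads'

   local pool = threads.Threads(1)
   local loader = { scale = 2, threads = pool } -- stands in for the DataLoader

   -- BAD: the closure captures `loader` as an upvalue, so serializing the
   -- callback tries to write loader.threads, whose userdata should trigger
   -- the same "Unwritable object <userdata>" panic:
   -- pool:addjob(function() return loader.scale * 21 end, print)

   -- GOOD: read the value on the main thread and pass it as a plain argument;
   -- the closure then has no upvalues and serializes cleanly:
   local scale = loader.scale
   pool:addjob(
      function(s) return s * 21 end,
      function(result) print(result) end, -- endcallback runs on the main thread
      scale
   )

   pool:synchronize()
   pool:terminate()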