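Let ws be the Tensor of your weights:
-- Detecting and removing NaNs
-- ws:ne(ws) is 1 exactly where an element is NaN, since NaN never equals itself
if ws:ne(ws):sum() > 0 then
   -- sys.COLORS needs require 'sys'; m.text comes from the surrounding code (not shown)
   print(sys.COLORS.red .. m.text .. ' weights contain NaN(s)')
   NaNOk = false
end
-- overwrite every NaN entry with 0
ws[ws:ne(ws)] = 0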
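In case the `ws:ne(ws)` trick looks odd: it relies on NaN being the only floating-point value that is not equal to itself. Here is a minimal plain-Lua sketch of the same idea, without torch. The helper `isNaN` and the sample table `ws` are made up for the illustration and are not part of any library.

```lua
-- NaN is the only value for which x ~= x holds,
-- so this comparison is true exactly when x is NaN.
local function isNaN(x)
   return x ~= x
end

-- Hypothetical plain-Lua stand-in for a weight tensor
local ws = {0.5, 0/0, -1.25, 0/0}

-- Count the NaN entries and overwrite them with 0,
-- mirroring ws:ne(ws):sum() and ws[ws:ne(ws)] = 0 on a tensor.
local nanCount = 0
for i, w in ipairs(ws) do
   if isNaN(w) then
      nanCount = nanCount + 1
      ws[i] = 0
   end
end

print(nanCount)
```

On a real Tensor, the mask returned by `ne` does the per-element comparison in one call, so no explicit loop is needed.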
th> a = torch.linspace(1, 30, 30):reshape(10,3) -- create a sample 2D tensor for the illustration
th> a
1 2 3
4 5 6
7 8 9
10 11 12
13 14 15
16 17 18
19 20 21
22 23 24
25 26 27
28 29 30
[torch.FloatTensor of dimension 10x3]
th> a[{{3,4},{2}}] = 0/0 -- set a couple of values to nan for the illustration
th> a
1.0000 2.0000 3.0000
4.0000 5.0000 6.0000
7.0000 nan 9.0000
10.0000 nan 12.0000
13.0000 14.0000 15.0000
16.0000 17.0000 18.0000
19.0000 20.0000 21.0000
22.0000 23.0000 24.0000
25.0000 26.0000 27.0000
28.0000 29.0000 30.0000
[torch.FloatTensor of dimension 10x3]
th> nans = a:ne(a):max(2) -- mark the nan entries, then take the max over each row to flag the rows that contain a nan
th> nans
0
0
1
1
0
0
0
0
0
0
[torch.ByteTensor of dimension 10x1]
th> indices = torch.linspace(1, a:size(1), a:size(1)):long() -- a hacky step (currently unavoidable in torch)
th> indices
1
2
3
4
5
6
7
8
9
10
[torch.LongTensor of dimension 10]
th> a_clean = a:index(1, indices[nans:eq(0)]) -- pick the rows that don't have nan
th> a_clean
1 2 3
4 5 6
13 14 15
16 17 18
19 20 21
22 23 24
25 26 27
28 29 30
[torch.FloatTensor of dimension 8x3]

I’m not sure I can help…
But if I got it correctly, you want to “kill” rows in a Tensor, right?
Not sure this is “allowed”… nor whether there is a smarter way than the one you showed.
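For what it's worth, the steps from the session above can be wrapped into one function. This is only a sketch, assuming Torch7 and a 2D tensor: `dropNaNRows` is a hypothetical name, not something torch provides, and `torch.range` is swapped in for the `linspace` call purely for brevity.

```lua
require 'torch'

-- Hypothetical helper (not part of torch): returns a new 2D tensor
-- holding only the rows of t that contain no NaN.
local function dropNaNRows(t)
   -- t:ne(t) is 1 where an element is NaN; max over dimension 2
   -- gives one flag per row (1 if the row contains any NaN)
   local rowHasNaN = t:ne(t):max(2)
   local keep = rowHasNaN:eq(0)

   -- nothing left to keep: return an empty tensor of the same type
   if keep:sum() == 0 then
      return t.new()
   end

   -- row numbers 1..n, filtered by the mask as in the original session
   local indices = torch.range(1, t:size(1)):long()
   return t:index(1, indices[keep])
end

-- Usage with the sample tensor from the session
local a = torch.range(1, 30):reshape(10, 3)
a[{{3,4},{2}}] = 0/0
local a_clean = dropNaNRows(a)
print(a_clean:size())
```

Note that `index` copies the selected rows, so `a` itself is left untouched.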