Hi all
I am trying to pick samples using two different methods, torch.max and torch.multinomial:
_, samples1 = torch.max(prob, 1)        # index of the largest probability in each row
samples2 = torch.multinomial(prob, 1)   # one index per row, drawn according to prob
In my case, prob is a 32 x 20000 matrix (32 samples, 20000 classes).
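A self-contained sketch of the setup, assuming float32 tensors on the CPU. The random, softmax-normalised prob below is a hypothetical stand-in for my real data; only the shape matches. Note that torch.max over dim 1 returns (values, indices), so the indices must be unpacked, and the two calls return differently shaped results.

```python
import torch

# Hypothetical stand-in for the real data: 32 samples x 20000 classes,
# each row normalised so it sums to 1.
num_samples, num_classes = 32, 20000
prob = torch.softmax(torch.randn(num_samples, num_classes), dim=1)

# Greedy choice: torch.max over dim 1 returns (values, indices); keep the indices.
_, samples1 = torch.max(prob, 1)          # shape: (32,)

# Stochastic choice: one class index per row, drawn with probability prob[i, j].
samples2 = torch.multinomial(prob, 1)     # shape: (32, 1)

print(samples1.shape, samples2.shape)     # torch.Size([32]) torch.Size([32, 1])
```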
I observed that torch.multinomial makes my code 5 to 10 times slower than torch.max.
Multinomial sampling is naturally slower than picking the maximum index, but here it has become a serious speed bottleneck in my code.
(I tried this with prob on both the CPU and the GPU.)
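To compare the two calls on either device, a small timing sketch like the following may help. It assumes a randomly generated prob of the same 32 x 20000 shape, and it calls torch.cuda.synchronize() because CUDA kernels run asynchronously; without it, GPU timings would mostly measure kernel launch overhead. The helper name time_op is hypothetical, and the numbers it prints depend entirely on the hardware and PyTorch version.

```python
import time
import torch

def time_op(fn, device, repeats=100):
    """Average wall-clock seconds per call of fn(); syncs CUDA so GPU work is counted."""
    fn()  # warm-up call (CUDA context creation, kernel caching)
    if device.type == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
    return (time.perf_counter() - start) / repeats

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
prob = torch.softmax(torch.randn(32, 20000, device=device), dim=1)

t_max = time_op(lambda: torch.max(prob, 1), device)
t_multi = time_op(lambda: torch.multinomial(prob, 1), device)
print(f"{device}: max {t_max * 1e6:.1f} us, multinomial {t_multi * 1e6:.1f} us, "
      f"ratio {t_multi / t_max:.1f}x")
```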
Any suggestions for speeding up multinomial sampling?
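One alternative that might be worth benchmarking (not a confirmed fix) is the Gumbel-max trick. Adding independent Gumbel(0, 1) noise to the log-probabilities and taking the argmax gives an exact sample from the same categorical distribution, so the sampling step reuses the same max reduction as torch.max. The function name gumbel_max_sample and the eps clamp are illustrative assumptions. Whether it actually beats torch.multinomial depends on the device, tensor sizes, and PyTorch version, so it should be timed like the other two calls.

```python
import torch

def gumbel_max_sample(prob, eps=1e-20):
    """Hypothetical helper: draw one class index per row of prob via the Gumbel-max trick."""
    # Uniform noise in [eps, 1); torch.rand never returns 1, and the clamp keeps log() finite at 0.
    u = torch.rand_like(prob).clamp_(min=eps)
    gumbel = -torch.log(-torch.log(u))  # standard Gumbel(0, 1) noise
    # argmax of log-probabilities plus Gumbel noise is distributed like a
    # categorical draw from prob; zero-probability classes get -inf and are never chosen.
    _, idx = torch.max(torch.log(prob) + gumbel, 1)
    return idx.unsqueeze(1)  # shape (N, 1), matching torch.multinomial(prob, 1)

torch.manual_seed(0)
prob = torch.softmax(torch.randn(32, 20000), dim=1)
samples3 = gumbel_max_sample(prob)
print(samples3.shape)  # torch.Size([32, 1])
```

This trades the dedicated sampling kernel for elementwise log operations plus a max reduction, and it uses float32 noise, so eps=1e-20 would underflow for float16 inputs.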