3D U-Net

Alex

Mar 17, 2017, 2:30:01 PM
to lasagne-users
I am trying to implement the 3D U-Net and it looks like the results are not satisfactory. This is in reference to the discussion here: https://github.com/Lasagne/Recipes/issues/101

I am attaching the implementation files here. Any suggestions are welcome.

There is also the notebook here for quick preview: https://gist.github.com/mongoose54/4b80333163d5f30c2af4b071023d8ba9
LMDBLoader.py
generators.py
data_preparation.py
Unet.py
UNet_Segmentation.ipynb

Jan Schlüter

Mar 17, 2017, 4:11:24 PM
to lasagne-users
Any suggestions are welcome.

Instead of the argmax(1), can you show the softmax output for the positive class (i.e., res[:,1] instead of res.argmax(1))? Maybe it learns better than you think, and just has the bias wrong. Also try playing with the class weights.
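
For example, something along these lines (a minimal sketch with a stand-in array; the shape and variable names are illustrative, not taken from your code):

import numpy as np
import matplotlib.pyplot as plt

# Stand-in for the network output: (batch, classes, depth, height, width),
# softmax over axis 1. Replace with the real prediction function's output.
res = np.random.rand(1, 2, 32, 64, 64).astype('float32')
res /= res.sum(axis=1, keepdims=True)

prob_fg = res[:, 1]                            # positive-class probabilities
plt.imshow(prob_fg[0, prob_fg.shape[1] // 2],  # central depth slice
           vmin=0.0, vmax=1.0)
plt.colorbar(label='P(foreground)')
plt.show()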

Alex

Mar 24, 2017, 5:18:29 PM
to lasagne-users
Hi Jan,

Thank you for your note. I noticed that if I increased the base number of convolutional filters (e.g., 32) I got worse results than with a smaller number (e.g., 8). Do you think that is because the network has difficulty training with the larger number of parameters?

I managed to get slightly better results when I cropped the images down to the non-background areas. I guess that helped balance the two classes (foreground and background) better.

Also, regarding class weights: do you have any suggestions? I have been looking through forums (e.g. https://github.com/fchollet/keras/issues/1875) but without much luck.

-Alex

Jan Schlüter

Mar 29, 2017, 12:01:09 PM
to lasagne-users
I noticed that if I increased the base number of convolutional filters (e.g., 32) I got worse results than with a smaller number (e.g., 8).

Worse in training or worse in testing or both?

Again, can you visualize the softmax output instead of the binarized classification?


Also, regarding class weights: do you have any suggestions?

Nothing specific. Try increasing the weights for the positive class.
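
As a sketch of what I mean (the weight values are made up, and this assumes the predictions are already flattened to one row per voxel):

import numpy as np
import theano.tensor as T

# predictions: (num_voxels, 2) softmax outputs, targets: (num_voxels,)
# int32 labels. The weight vector is illustrative; [1.0, 30.0] upweights
# the positive class 30x.
predictions = T.matrix('predictions')
targets = T.ivector('targets')
class_weights = T.constant(np.array([1.0, 30.0], dtype='float32'))

losses = T.nnet.categorical_crossentropy(predictions, targets)
loss = (losses * class_weights[targets]).mean()  # weight each voxel by its label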


I managed to get slightly better results when I cropped the images down to the non-background areas. I guess that helped balance the two classes (foreground and background) better.

Cropping may work better than just bumping up the class weights. If you have a fully-convolutional net with a limited context going into each prediction voxel, you can also train on more or less random sub-crops of the input and target volumes (large enough to get at least one prediction voxel), presenting all-negative examples with a decreased frequency.
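
A possible sampler for this (the crop size and keep probability are made up, and it assumes single-channel volumes with targets of the same shape):

import numpy as np

def random_crop(volume, target, size=(32, 64, 64), p_keep_negative=0.1,
                rng=np.random):
    # Draw a random sub-crop; crops whose target is all background are
    # resampled with probability 1 - p_keep_negative.
    while True:
        starts = [rng.randint(0, s - c + 1)
                  for s, c in zip(volume.shape, size)]
        sl = tuple(slice(st, st + c) for st, c in zip(starts, size))
        if target[sl].any() or rng.rand() < p_keep_negative:
            return volume[sl], target[sl]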

Best, Jan

Alex

Apr 3, 2017, 7:11:41 PM
to lasagne-users
Hi Jan,

Here is an updated notebook: https://gist.github.com/mongoose54/ab11fdc7bc107a87c0f3ad641b5ea18d which shows the network's output (third column in the figures). As you can see, it tries to learn the segmentation, but it is still far from perfect after argmax().

I changed the class weight ratio from 26/1 to 30/1 and here are the results: https://gist.github.com/mongoose54/452b2b2fb355b89036821035f1dd877f As you can see, it affects the network significantly.

-Alex

Jan Schlüter

Apr 4, 2017, 5:51:56 AM
to lasagne-users
Here is an updated notebook: https://gist.github.com/mongoose54/ab11fdc7bc107a87c0f3ad641b5ea18d which shows the network's output (third column in the figures). As you can see, it tries to learn the segmentation, but it is still far from perfect after argmax().

I changed the class weight ratio from 26/1 to 30/1 and here are the results: https://gist.github.com/mongoose54/452b2b2fb355b89036821035f1dd877f As you can see, it affects the network significantly.

That's much easier to interpret now. Are you sure you only changed the class weights? Because the predictions in the first notebook are inverted compared to the predictions in the second notebook. They can't both be correct. Is this the softmax output, or before the softmax? (If it's before the softmax, it doesn't really tell you what the output looks like, because you're missing the logits for the second class. You may want to switch to sigmoid units and binary cross-entropy after all; this also avoids the need to flatten the predictions.)
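
Roughly like this (a sketch of the output end only; the input shape is a placeholder, and depending on your Lasagne version the 3D convolution is lasagne.layers.Conv3DLayer or lasagne.layers.dnn.Conv3DDNNLayer):

import theano.tensor as T
import lasagne
from lasagne.layers import InputLayer, Conv3DLayer, get_output
from lasagne.nonlinearities import sigmoid

input_var = T.tensor5('inputs')
net = InputLayer((None, 1, 32, 64, 64), input_var)  # stand-in for the network body
l_out = Conv3DLayer(net, num_filters=1, filter_size=1, nonlinearity=sigmoid)

predictions = get_output(l_out)  # probabilities in [0, 1], same shape as targets
targets = T.tensor5('targets')   # binary masks
loss = lasagne.objectives.binary_crossentropy(predictions, targets).mean()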

Jan

Jan Schlüter

Apr 5, 2017, 1:14:19 PM
to lasagne-users
Another approach could be changing the loss function. There are several differentiable formulations of Dice losses and Jaccard distances: https://github.com/Lasagne/Lasagne/pull/818#issuecomment-291927147
The first paper I'm linking in the comment (https://arxiv.org/pdf/1701.03056) seems to address the same task as yours.
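
For reference, a soft Dice loss along those lines could look like this (a sketch, not a verbatim copy of any of the linked formulations):

import theano.tensor as T

def soft_dice_loss(predictions, targets, eps=1e-7):
    # Differentiable Dice: predictions are sigmoid outputs, targets are
    # binary masks of the same shape; Dice is computed per sample over all
    # non-batch axes and averaged over the batch.
    axes = tuple(range(1, predictions.ndim))
    intersection = T.sum(predictions * targets, axis=axes)
    denom = T.sum(predictions, axis=axes) + T.sum(targets, axis=axes)
    dice = (2.0 * intersection + eps) / (denom + eps)
    return (1.0 - dice).mean()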

tianjin...@gmail.com

Apr 3, 2018, 4:04:35 PM
to lasagne-users
Hi Alex,

I have a question regarding the 3D U-Net. 

May I know your inference time with the 3D U-Net? My segmentation task is based on a 2D U-Net. Currently, the inference time for a 320×200-pixel image is about 70 ms. I am wondering whether this is normal.
Would I get a speedup from the 3D U-Net?


Best Regards,
Jing