Probability as input to Markov random field (MRF): how to modify the GCmex code?


I am very new to MRFs and not very good at programming. I have obtained a probability map from semantic segmentation with a CNN, and I have to refine the segmentation using Markov Random Fields (MRF). I downloaded the code provided by Shai Bagon (GCmex). Energy minimization is performed with either alpha-expansion or swap moves.

I compiled the code with mex, and I need to modify the unary and pairwise energy functions. I have a stack of images, and I need to extract the 6-neighborhood grid and include those neighbors in the pairwise function.

The input to the unary function is the probability map, a stack of size (256,256,4) for 4 different classes.

My questions are:

1) Has anyone modified this code to use different energy functions? I want to change the unary and pairwise formulations. Which functions, and which parts of the code, should be modified and recompiled?

2) How do I change w_i,j? It is calculated from intensity differences, but here I only have probabilities. Is it the difference between the probabilities of two adjacent voxels?

I really appreciate your help. Thanks

Answer by Shai (best answer)

You have 60 slices of 256x256 pixels (about 4M voxels in total), that is, slices is a 256-by-256-by-60 array. Once you feed slices into your net (one by one or in batches, whatever works best for you), you get a probability array prob of size 256-by-256-by-60-by-4.
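A minimal sketch of assembling the per-slice network outputs into prob, assuming your network returns one H-by-W-by-K probability map per slice. The cell array probSlices and the toy sizes are hypothetical stand-ins (random maps replace the real network). The point is to put the class index last, so that reshape(prob, n, K) gives one row per voxel in the same column-major order as slices(:).

```matlab
% Assemble per-slice CNN outputs into one 4-D array with the class index last.
% probSlices is hypothetical: probSlices{z} is the H-by-Wd-by-K map for slice z.
H = 8; Wd = 8; Z = 3; K = 4;            % toy sizes (yours: 256, 256, 60, 4)
probSlices = cell(1, Z);
for z = 1:Z                             % toy stand-in for the network output
    p = rand(H, Wd, K);
    probSlices{z} = p ./ sum(p, 3);     % classes sum to 1 at every pixel
end

prob = zeros(H, Wd, Z, K, 'single');    % single precision halves the memory
for z = 1:Z
    prob(:, :, z, :) = reshape(probSlices{z}, H, Wd, 1, K);
end

n = H * Wd * Z;                         % number of voxels
P = reshape(prob, n, K);                % row v = class probabilities of voxel v
```

Because MATLAB stores arrays column-major, row v of P refers to the same voxel as slices(v), which is what the unary term relies on.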
I suggest you use the third constructor of GraphCut in GCMex (the one that takes a sparse matrix) to construct your graph for the optimization.
To do so, you first need to define a sparse graph. Use sparse_adj_matrix:

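[ii, jj] = sparse_adj_matrix([256 256 60], 1, 1);  % 6-connected 3D grid
n = prod([256 256 60]);  % number of voxels
slices = double(slices);  % avoid saturated differences if slices is an integer type (e.g. uint8)
d2 = (slices(ii) - slices(jj)).^2;  % |Ii-Ij|^2 for every neighboring pair
sig2 = mean(d2);  % sigma^2: one common heuristic (an assumption), tune as needed
wij = exp(-d2/(2*sig2));  % w_ij = exp(-|Ii-Ij|^2 / (2*sigma^2))
W = sparse(ii, jj, wij, n, n);  % sparse grid graph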
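If sparse_adj_matrix is not on your path, the 6-connected neighbor pairs can also be built with plain indexing. This is a hypothetical stand-in, not the sparse_adj_matrix code itself: it pairs every voxel with its next row, next column and next slice, and lists each pair in both directions so that sparse(ii, jj, wij, n, n) comes out symmetric.

```matlab
% Hypothetical stand-in for sparse_adj_matrix(sz, 1, 1): 6-connected 3D grid
% built with plain indexing, no external functions.
sz = [256 256 60];                 % rows, columns, slices
n = prod(sz);                      % number of voxels
idx = reshape(1:n, sz);            % linear index of every voxel (same order as slices(:))

% "forward" neighbor along each axis: next row, next column, next slice
src = [reshape(idx(1:end-1, :, :), [], 1); ...
       reshape(idx(:, 1:end-1, :), [], 1); ...
       reshape(idx(:, :, 1:end-1), [], 1)];
dst = [reshape(idx(2:end, :, :), [], 1); ...
       reshape(idx(:, 2:end, :), [], 1); ...
       reshape(idx(:, :, 2:end), [], 1)];

% list every pair in both directions so the sparse matrix W is symmetric
ii = [src; dst];
jj = [dst; src];
```

Slicing idx with 1:end-1 and 2:end along one axis at a time guarantees that no pair wraps around from the last row, column or slice to the first one.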

Once you have the graph, it's all downhill from here:

Dc = -reallog(max(reshape(prob, n, 4), eps)).';  % unary/data term -log(p); p clamped away from 0 to avoid Inf
lambda = 2;  % relative weight of the smoothness term
gch = GraphCut('open', Dc, lambda*(ones(4)-eye(4)), W);  % construct the graph
[gch L] = GraphCut('expand', gch);  % minimize using "expand" method
gch = GraphCut('close', gch);  % do not forget to de-allocate
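With the Potts matrix lambda*(ones(4)-eye(4)) and the sparse W, my reading is that GraphCut minimizes a data term plus a lambda-weighted count of label changes between neighbors, weighted by W. The sketch below evaluates that energy for a given labeling, which can help when you compare labelings or change the energy formulation. It is an illustration under assumptions: the labeling lab is 1-based (check which label base GraphCut returns before plugging its output in), and each neighboring pair is counted once via triu(W). If GraphCut counts a symmetric W twice, that only rescales lambda. The tiny Dc, W and lab are toy data.

```matlab
% Evaluate E(lab) = sum_i Dc(lab_i, i) + lambda * sum_{i~j} W_ij * [lab_i ~= lab_j]
% Toy data so the sketch runs stand-alone; substitute your own Dc, W, lambda.
Dc = [0.1 2.0 0.3;                 % cost of label 1 at nodes 1..3
      1.5 0.2 1.0];                % cost of label 2 at nodes 1..3
W = sparse([1 2 2 3], [2 1 3 2], [0.5 0.5 0.8 0.8], 3, 3);  % symmetric chain 1-2-3
lambda = 2;                        % same role as lambda in the Potts matrix
lab = [1; 2; 1];                   % hypothetical 1-based label per node

n = size(Dc, 2);
dataE = sum(Dc(sub2ind(size(Dc), lab, (1:n).')));   % unary part
[ei, ej, w] = find(triu(W));                        % each neighboring pair once
smoothE = lambda * sum(w .* (lab(ei) ~= lab(ej)));  % Potts pairwise part
E = dataE + smoothE;
```

Here both pairs of the chain have different labels, so the pairwise part is lambda*(0.5 + 0.8) and the data part is 0.1 + 0.2 + 0.3.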

To see the output labels, reshape L back to the volume size:

output = reshape(L, size(slices));

PS:
If the spacing between slices is larger than the in-plane spacing between neighboring voxels, you might need to use a different sig2 for pairs ii, jj within the same slice than for pairs ii, jj in different slices. This requires a bit more effort.
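A rough sketch of that extra effort, assuming the exp(-|Ii-Ij|^2/(2*sig2)) weights and estimating each sig2 as the mean squared difference within its group of pairs (an assumed heuristic, not something GCMex prescribes). ind2sub recovers the slice index of both endpoints, and pairs whose endpoints lie in the same slice get their own sig2. A small deterministic toy volume and a hand-built 6-connected pair list stand in for your slices and the sparse_adj_matrix output.

```matlab
% Toy volume so the sketch runs on its own; use your own slices instead.
sz = [4 5 3];
n = prod(sz);
slices = mod(reshape(1:n, sz) * 37, 101);     % arbitrary deterministic intensities
idx = reshape(1:n, sz);

% 6-connected pairs (stand-in for sparse_adj_matrix), both directions
a = [reshape(idx(1:end-1,:,:),[],1); reshape(idx(:,1:end-1,:),[],1); reshape(idx(:,:,1:end-1),[],1)];
b = [reshape(idx(2:end,:,:),[],1);   reshape(idx(:,2:end,:),[],1);   reshape(idx(:,:,2:end),[],1)];
ii = [a; b];
jj = [b; a];

[~, ~, zi] = ind2sub(sz, ii);                 % slice index of each endpoint
[~, ~, zj] = ind2sub(sz, jj);
inSlice = (zi == zj);                         % true: in-plane pair, false: between slices

d2 = (double(slices(ii)) - double(slices(jj))).^2;
sig2In  = mean(d2(inSlice));                  % assumed heuristic, one sigma^2 per group
sig2Out = mean(d2(~inSlice));
wij = zeros(size(d2));
wij(inSlice)  = exp(-d2(inSlice)  / (2*sig2In));
wij(~inSlice) = exp(-d2(~inSlice) / (2*sig2Out));
W = sparse(ii, jj, wij, n, n);                % symmetric, anisotropy-aware weights
```

Estimating the two variances separately lets the between-slice weights adapt to the larger intensity changes that a coarser slice spacing usually produces. You may prefer to tune sig2Out by hand instead.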