Evaluation metric for disconnected objects in MRI data


I'm trying to evaluate the accuracy of an algorithm that segments regions (lesions) in 3D brain MRI volumes. I've been using Dice, Jaccard, FPR, TNR, precision, etc., but only voxel-wise (i.e., FN = number of false-negative voxels). Is there a Python package (or pseudocode) that does this at the lesion level? For example, TP would be the number of lesions (disconnected 3D objects in the ground truth) that my algorithm detects. That way, lesion size would not affect the accuracy metrics as much.
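Once lesions are counted as objects, the usual rates can be recomputed from lesion counts instead of voxel counts. The sketch below assumes you already have lesion-level TP, FN and FP counts; lesion_rates is a hypothetical helper, not part of any library. Note that true negatives have no natural definition at the lesion level (there is no countable set of "non-lesion objects"), so TNR and FPR have no direct lesion-wise counterpart, while sensitivity, precision and an F1 score (the count-based analogue of Dice) do.

```python
def lesion_rates(tp, fn, fp):
    """Lesion-wise sensitivity, precision and F1 from lesion counts.

    Hypothetical helper: tp = detected ground-truth lesions,
    fn = missed ground-truth lesions, fp = predicted lesions matching nothing.
    Returns NaN wherever a denominator is zero.
    """
    nan = float("nan")
    sensitivity = tp / (tp + fn) if (tp + fn) else nan
    precision = tp / (tp + fp) if (tp + fp) else nan
    # F1 = 2TP / (2TP + FN + FP), the same form as Dice but over lesion counts
    f1 = 2 * tp / (2 * tp + fn + fp) if (2 * tp + fn + fp) else nan
    return {"sensitivity": sensitivity, "precision": precision, "f1": f1}


print(lesion_rates(tp=8, fn=2, fp=4))
```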


There are 2 answers

Imanol Luengo

You could use SciPy's label function to find the connected components (i.e., the individual lesions) in each volume:

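from scipy.ndimage import label  # scipy.ndimage.measurements is deprecated

# Inputs must be binary masks (threshold probability maps first).
# Each connected component gets its own integer label 1..N; 0 is background.
label_pred, numobj_pred = label(my_predictions)
label_true, numobj_true = label(my_groundtruth)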

And then compare them using the metric of your choice.
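How many lesions label finds depends on the connectivity it uses. By default scipy.ndimage.label treats voxels as neighbours only if they share a face (6-connectivity in 3D); passing structure=generate_binary_structure(3, 3) also joins voxels that touch at an edge or corner (26-connectivity). The toy volume below is only an illustration, not MRI data; whichever connectivity you pick, use the same one for ground truth and prediction.

```python
import numpy as np
from scipy.ndimage import label, generate_binary_structure

# Two voxels that touch only at a corner (diagonal neighbours in 3D).
vol = np.zeros((3, 3, 3), dtype=int)
vol[0, 0, 0] = 1
vol[1, 1, 1] = 1

# Default structuring element: face (6-)connectivity in 3D.
_, n_face = label(vol)

# Full (26-)connectivity: faces, edges and corners all count as touching.
struct26 = generate_binary_structure(3, 3)
_, n_full = label(vol, structure=struct26)

print(n_face, n_full)  # 2 1
```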

PS: scikit-image's skimage.measure.label does the same job.

A.Mouraviev

Here is the code I ended up writing to do this task. Please let me know if anyone sees any errors.

import numpy as np
from scipy.ndimage import label, center_of_mass


def closest(true_cntrd, pred_pts):
    """Return the predicted centroid nearest to true_cntrd, and its index.

    Works for 2D, 3D or any dimensionality, so no dim argument is needed.
    """
    dists = np.linalg.norm(np.asarray(pred_pts) - np.asarray(true_cntrd), axis=1)
    min_idx = int(np.argmin(dists))
    return pred_pts[min_idx], min_idx


def eval_disconnected(y_true, y_pred, *, threshold=0.5):
    """Lesion-level TP, FN, FP counts for masks of any dimensionality.

    A ground-truth lesion counts as detected (TP) if the centroid of the
    nearest predicted lesion falls inside it. threshold is keyword-only so
    that old calls passing dim as a third argument fail loudly.
    """
    y_true = np.asarray(y_true) > 0
    y_pred = np.asarray(y_pred) > threshold
    label_true, numobj_true = label(y_true)
    label_pred, numobj_pred = label(y_pred)

    if numobj_pred == 0:
        return 0, numobj_true, 0   # nothing predicted: every lesion is missed
    if numobj_true == 0:
        return 0, 0, numobj_pred   # no lesions: every prediction is a false positive

    true_labels = np.arange(1, numobj_true + 1)
    pred_labels = np.arange(1, numobj_pred + 1)
    true_centroids = center_of_mass(y_true.astype(float), label_true, true_labels)
    pred_centroids = center_of_mass(y_pred.astype(float), label_pred, pred_labels)

    true_lbl_hits, pred_lbl_hits = set(), set()
    for cntr_true, lbl_t in zip(true_centroids, true_labels):
        closest_pred_cntr, idx = closest(cntr_true, pred_centroids)
        # Round (rather than truncate) the centroid to the nearest voxel index.
        voxel = tuple(int(round(c)) for c in closest_pred_cntr)
        if label_true[voxel] == lbl_t:
            true_lbl_hits.add(lbl_t)
            pred_lbl_hits.add(pred_labels[idx])

    TP = len(true_lbl_hits)                 # ground-truth lesions that were detected
    FN = numobj_true - TP                   # ground-truth lesions that were missed
    FP = numobj_pred - len(pred_lbl_hits)   # predicted lesions that hit no lesion
    return TP, FN, FP
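One caveat with a centroid-based test: the centroid of a non-convex lesion (for example a ring or a C-shape) can fall outside the lesion itself, so even a perfect prediction can be scored as a miss. A common alternative, swapped in here and not part of the answer above, is an any-overlap criterion: a ground-truth lesion is detected if at least one predicted voxel overlaps it, and a predicted lesion is a false positive if it overlaps no ground-truth voxel. This is a minimal sketch assuming both masks have the same shape; eval_overlap is a hypothetical name.

```python
import numpy as np
from scipy.ndimage import label


def eval_overlap(y_true, y_pred, threshold=0.5, structure=None):
    """Lesion-level TP/FN/FP using an any-overlap matching criterion (sketch)."""
    true_mask = np.asarray(y_true) > 0
    pred_mask = np.asarray(y_pred) > threshold
    label_true, n_true = label(true_mask, structure=structure)
    label_pred, n_pred = label(pred_mask, structure=structure)

    # Ground-truth lesion labels that contain at least one predicted voxel.
    hit_true = np.unique(label_true[pred_mask])
    hit_true = hit_true[hit_true > 0]
    # Predicted lesion labels that touch at least one ground-truth voxel.
    hit_pred = np.unique(label_pred[true_mask])
    hit_pred = hit_pred[hit_pred > 0]

    tp = len(hit_true)
    fn = n_true - tp
    fp = n_pred - len(hit_pred)
    return int(tp), int(fn), int(fp)


# A ring-shaped lesion whose centroid is the empty centre pixel (2, 2).
ring = np.zeros((5, 5), dtype=int)
ring[1:4, 1:4] = 1
ring[2, 2] = 0
print(eval_overlap(ring, ring))  # (1, 0, 0)
```

Any-overlap is lenient: one large predicted blob covering two lesions counts both as detected and is not penalized for merging them. If that matters for your data, require a minimum overlap fraction per lesion instead.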