How can I show the individual Gaussian mixture model (GMM) components used to cluster an image?

301 views Asked by At

I used the code below to fit a Gaussian mixture model (GMM) to some images, and I have already plotted the fitted GMM density on top of each image's histogram. However, I would also like to show the distribution of each individual GMM cluster (component). I attached the output of the GMM on the histogram, and another image showing what I want to get.

Thanks! [image: enter image description here]

            # Code for GMM.
            # 1) Segment an image by colour with a Gaussian mixture model (GMM)
            #    and save the label map.
            # 2) Fit a separate 1-D GMM to the pixel intensities and plot, over the
            #    intensity histogram, both the total mixture density and the
            #    weighted density of every individual component ("cluster").

            import os
            import matplotlib.pyplot as plt
            import numpy as np
            import cv2
            from sklearn.mixture import GaussianMixture as GMM

            N_COMPONENTS = 6  # number of Gaussian components used by both models
            RANDOM_STATE = 0  # fixed seed so labels and curves are reproducible

            img = cv2.imread("test.jpg")
            # cv2.imread does not raise for a missing/unreadable file; it returns
            # None, which would otherwise fail later with a confusing error.
            if img is None:
                raise FileNotFoundError("Could not read image 'test.jpg'")

            # Convert MxNx3 image into Kx3 where K=MxN (one row per pixel).
            img2 = img.reshape((-1, 3))

            # Covariance choices: full, tied, diag, spherical
            # (author's note: 'tied' seemed to work better than 'full').
            gmm_model = GMM(n_components=N_COMPONENTS, covariance_type='full',
                            random_state=RANDOM_STATE).fit(img2)
            gmm_labels = gmm_model.predict(img2)

            # Put labels back into the image's 2-D shape to get the segmentation map.
            original_shape = img.shape
            segmented = gmm_labels.reshape(original_shape[0], original_shape[1])
            # predict() returns integer labels 0..N_COMPONENTS-1. That is not a valid
            # 8-bit image for OpenCV, and it would look black even if written, so
            # spread the labels over 0..255 as uint8 before saving.
            label_step = 255 // max(N_COMPONENTS - 1, 1)
            cv2.imwrite("test_segmented.jpg", (segmented * label_step).astype(np.uint8))

            print(gmm_model.means_, gmm_model.covariances_, gmm_model.weights_)

            # Second, independent model: a 1-D GMM on raw intensities pooled over
            # all three channels. Its components are NOT the colour clusters above.
            data = img2.ravel()
            # Remove background pixels (intensities 0 and 1). cv2.imread in colour
            # mode yields uint8, so "> 1" is the same as excluding 0 and 1.
            data = data[data > 1]
            gmm = GMM(n_components=N_COMPONENTS, random_state=RANDOM_STATE)
            gmm = gmm.fit(X=np.expand_dims(data, 1))

            gmm_x = np.linspace(0, 255, 256)
            gmm_X = gmm_x.reshape(-1, 1)
            gmm_y = np.exp(gmm.score_samples(gmm_X))  # total mixture density p(x)
            # Weighted density of each component k, i.e. w_k * N(x | mu_k, var_k).
            # predict_proba returns w_k * N_k(x) / p(x), so multiplying by p(x)
            # recovers it. Shape (256, N_COMPONENTS); the columns sum to gmm_y.
            component_y = gmm.predict_proba(gmm_X) * gmm_y[:, np.newaxis]

            # Plot histogram, total GMM curve and each component's curve.
            fig, ax = plt.subplots()
            # One unit-width bin per intensity 2..255, matching the filtered data
            # the 1-D GMM was fitted on (values outside the range are ignored).
            ax.hist(img.ravel(), bins=np.arange(2, 257), density=True,
                    color="lightgray", label="Histogram")
            ax.plot(gmm_x, gmm_y, color="crimson", lw=2, label="GMM")

            # Draw components ordered by mean intensity so the legend reads left to right.
            for rank, k in enumerate(np.argsort(gmm.means_.ravel()), start=1):
                mu = gmm.means_[k, 0]
                ax.plot(gmm_x, component_y[:, k], lw=1.5, ls="--",
                        label=f"Component {rank} (mean={mu:.0f})")

            # density=True normalises the histogram, so the axis is a density.
            ax.set_ylabel("Probability density")
            ax.set_xlabel("Pixel Intensity")

            ax.legend()
            ax.grid(False)

            plt.show()
0

There are 0 answers