I have a trained MeanShift object (ms). It has a simple list of centers. How do I determine which label a center belongs to?
I am aware of labels_, but I do not see the connection between labels_ and cluster_centers_.
print(ms.cluster_centers_)
[[ 40.7177164 -73.99183542]
[ 33.44943805 -112.00213969]
[ 33.44638027 -111.90188756]
...,
[ 46.7323875 -117.0001651 ]
[ 29.6899563 -95.8996757 ]
[ 31.3787916 -95.3213317 ]]
labels_ has one entry per sample in your original dataset, so its length is the number of samples, not the number of clusters. Each entry is the index of that sample's cluster, i.e. a row index into cluster_centers_. So the cluster center for entry i of the original data is cluster_centers_[labels_[i]], and conversely the label of the center stored at cluster_centers_[k] is simply k.
You can see this in the sklearn example, which loops over the unique labels and uses labels == k to select all the data with that label (X[labels_ == k]): https://scikit-learn.org/stable/auto_examples/cluster/plot_mean_shift.html#sphx-glr-auto-examples-cluster-plot-mean-shift-py
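Here is a minimal sketch of that indexing. The synthetic two-blob data and bandwidth=1.0 are assumptions made for illustration only; they are not the asker's data or settings.

```python
import numpy as np
from sklearn.cluster import MeanShift

# Synthetic data (assumption for illustration): two well-separated 2-D blobs.
rng = np.random.default_rng(0)
X = np.vstack([
    rng.normal(loc=(0.0, 0.0), scale=0.2, size=(50, 2)),
    rng.normal(loc=(5.0, 5.0), scale=0.2, size=(50, 2)),
])

# bandwidth=1.0 is an arbitrary choice that suits this toy data.
ms = MeanShift(bandwidth=1.0).fit(X)

labels = ms.labels_            # one cluster index per sample
centers = ms.cluster_centers_  # one row per cluster

# Center assigned to each sample: index the centers with the labels.
centers_per_sample = centers[labels]

# The label of a center is its row index k in cluster_centers_.
for k in np.unique(labels):
    members = X[labels == k]   # all samples assigned to cluster k
    print(k, centers[k], len(members))
```

centers_per_sample[i] is the same row as cluster_centers_[labels_[i]], computed for every sample at once.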