mean shift clustering


Guide

MeanShift

  • python: git clone https://github.com/mattnedrich/MeanShift_py.git
  • cpp: git clone https://github.com/mattnedrich/MeanShift_cpp.git

C++ build

cd MeanShift_cpp
mkdir build && cd build && cmake .. && make -j8

./MeanShift_cpp 

Visualization on Linux

sudo apt-get install gnuplot gnuplot-qt

gnuplot
plot 'test.csv' with points, 'result.csv' with points

Python demo

import mean_shift as ms
import matplotlib.pyplot as plt
import numpy as np

def ms_cluster(data, kernel_bandwidth=3, cluster_epsilon=6):
        """Cluster ``data`` with the MeanShift_py implementation (``ms`` module).

        Args:
            data: points to cluster, one point per row.
            kernel_bandwidth: forwarded as ``kernel_bandwidth`` to
                ``ms.MeanShift().cluster``.
            cluster_epsilon: forwarded as ``cluster_epsilon`` to
                ``ms.MeanShift().cluster``.

        Returns:
            Whatever ``ms.MeanShift().cluster`` returns (used by callers as a
            result with ``original_points``, ``cluster_centers`` and
            ``cluster_ids``).

        Parameter sets noted by the author:
            case (1) demo:     kernel_bandwidth = 3.0, cluster_epsilon = 6
            case (2) laneseg:  kernel_bandwidth = 0.5, cluster_epsilon = 2
        """
        mean_shifter = ms.MeanShift()
        return mean_shifter.cluster(data,
                                    kernel_bandwidth=kernel_bandwidth,
                                    cluster_epsilon=cluster_epsilon)

def sklearn_cluster(data, quantile=0.2):
        """Cluster ``data`` with scikit-learn's MeanShift.

        The bandwidth is estimated from all rows of ``data`` with
        ``estimate_bandwidth(..., quantile=quantile)``, then MeanShift is fit
        with ``bin_seeding=True``. The fitted centers and labels are wrapped in
        ``ms.MeanShiftResult`` so callers get the same result type as
        ``ms_cluster``.

        Args:
            data: 2-D array of points, one point per row.
            quantile: quantile passed to ``estimate_bandwidth`` (default 0.2).

        Returns:
            ``ms.MeanShiftResult(original_points, cluster_centers, cluster_ids)``.
        """
        from sklearn.cluster import MeanShift
        from sklearn.cluster import estimate_bandwidth

        bandwidth = estimate_bandwidth(data, quantile=quantile, n_samples=data.shape[0])
        # MeanShift's constructor parameters are keyword-only in current
        # scikit-learn, so bandwidth must be passed by name.
        mean_shifter = MeanShift(bandwidth=bandwidth, bin_seeding=True)
        mean_shifter.fit(data)

        # Repackage sklearn's attributes into the ms result container.
        return ms.MeanShiftResult(data,
                                  mean_shifter.cluster_centers_,
                                  mean_shifter.labels_)

def cluster_api(data, use_sklearn=True):
        """Run mean-shift clustering on ``data`` with the selected backend.

        A truthy ``use_sklearn`` selects ``sklearn_cluster``; otherwise
        ``ms_cluster`` (MeanShift_py) is used. The backend's result is
        returned unchanged.
        """
        backend = sklearn_cluster if use_sklearn else ms_cluster
        return backend(data)

def print_cluster_result(mean_shift_result):
        """Print a table of each point, its cluster center and its cluster id.

        ``mean_shift_result`` must expose ``original_points``,
        ``cluster_centers`` (indexable by cluster id) and ``cluster_ids``
        (one id per point). Only the first two coordinates of each point and
        center are printed.

        NOTE: the column headed "Shifted Point" shows the center of the
        cluster the point was assigned to.
        """
        print("Original Point     Shifted Point  Cluster ID")
        print("============================================")
        centers = mean_shift_result.cluster_centers
        for original_point, cluster_id in zip(mean_shift_result.original_points,
                                              mean_shift_result.cluster_ids):
                cluster_center = centers[cluster_id]
                print(
                        "(%5.2f,%5.2f) ->  (%5.2f,%5.2f) cluster %i" %
                        (original_point[0], original_point[1],
                         cluster_center[0], cluster_center[1],
                         cluster_id)
                )
        print("============================================")

def main(use_sklearn=True, data_path='data.csv'):
        """Cluster the points in ``data_path``, print diagnostics and plot them.

        Args:
            use_sklearn: if truthy, cluster with scikit-learn's MeanShift;
                otherwise use the MeanShift_py (``ms``) implementation.
            data_path: comma-separated file with one point per row; the
                first two columns are plotted as x and y.

        The scatter plot (points coloured by cluster id, centers marked with a
        red '+') is saved as ``1_sklearn`` or ``2_ms`` and then shown.
        """
        data = np.genfromtxt(data_path, delimiter=',')
        print("data.shape=",data.shape)

        mean_shift_result = cluster_api(data,use_sklearn)
        # Uncomment to print the per-point assignment table.
        #print_cluster_result(mean_shift_result)

        original_points = mean_shift_result.original_points  # (125, 2) for the demo data
        cluster_centers = mean_shift_result.cluster_centers  # one row per cluster
        cluster_ids = mean_shift_result.cluster_ids  # one cluster id per point

        unique_ids = np.unique(cluster_ids)

        print("original_points.shape=",original_points.shape)
        print(original_points[:10])

        print("cluster_centers.shape=",cluster_centers.shape)
        print(cluster_centers)

        print("cluster_ids.shape=",cluster_ids.shape)
        print(cluster_ids)

        print("unique_ids.shape=",unique_ids.shape)
        print(unique_ids)

        x = original_points[:,0]
        y = original_points[:,1]

        fig = plt.figure()
        ax = fig.add_subplot(111)
        scatter = ax.scatter(x,y,c=cluster_ids,s=50)
        # Mark all cluster centers in a single call.
        ax.scatter(cluster_centers[:,0], cluster_centers[:,1], s=50, c='red', marker='+')
        # Labels and the colorbar are set exactly once; doing this per center
        # would stack one colorbar per cluster next to the axes.
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        plt.colorbar(scatter)

        filename = "1_sklearn" if use_sklearn else "2_ms"

        fig.savefig(filename)
        plt.show()
        print("OK "+filename)

if __name__ == '__main__':
        # Run the demo only when executed as a script, not when imported.
        main()

meanshift_py

#===============================
# ms 
#===============================
('data.shape=', (125, 2))
('original_points.shape=', (125, 2))
[[10.91079039  8.38941202]
 [ 9.87500165  9.9092509 ]
 [ 7.8481223  10.4317483 ]
 [ 8.53412293  9.55908561]
 [10.38316846  9.61879086]
 [ 8.11061595  9.77471761]
 [10.02119468  9.53877962]
 [ 9.37705852  9.70853991]
 [ 7.67017034  9.60315231]
 [10.94308287 11.76207349]]
('cluster_centers.shape=', (3, 2))
[[-3.45216026  5.28851174]
 [ 5.02926925  3.56548696]
 [ 8.63149568  9.25488818]]
('cluster_ids.shape=', (125,))
[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 1 0 0 0]
('unique_ids.shape=', (3,))
[0 1 2]
OK 2_ms

ms

sklearn

#===============================
# sklearn 
#===============================

('data.shape=', (125, 2))
('original_points.shape=', (125, 2))
[[10.91079039  8.38941202]
 [ 9.87500165  9.9092509 ]
 [ 7.8481223  10.4317483 ]
 [ 8.53412293  9.55908561]
 [10.38316846  9.61879086]
 [ 8.11061595  9.77471761]
 [10.02119468  9.53877962]
 [ 9.37705852  9.70853991]
 [ 7.67017034  9.60315231]
 [10.94308287 11.76207349]]
('cluster_centers.shape=', (3, 2))
[[ 4.79792283  3.01140269]
 [ 9.2548292  10.11312163]
 [-4.11368202  5.44826076]]
('cluster_ids.shape=', (125,))
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 0 2 2 2]
('unique_ids.shape=', (3,))
[0 1 2]
OK 1_sklearn

sklearn

Reference

History

  • 20190318: created.

Author: kezunlin
Reprint policy: Unless otherwise stated, all articles on this blog are licensed under CC BY 4.0. If reproduced, please credit the source: kezunlin!
评论
  TOC