0%

mean shift clustering

Guide

MeanShift

  • python: git clone https://github.com/mattnedrich/MeanShift_py.git
  • cpp: git clone https://github.com/mattnedrich/MeanShift_cpp.git

cpp compile

1
2
3
4
cd MeanShift_cpp
mkdir build && cd build && cmake .. && make -j8

./MeanShift_cpp

Visualization for linux

1
sudo apt-get install gnuplot gnuplot-qt

gnuplot
plot 'test.csv' with points, 'result.csv' with points

python demo

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import mean_shift as ms
import matplotlib.pyplot as plt
import numpy as np

def ms_cluster(data, kernel_bandwidth=3, cluster_epsilon=6):
    """Cluster ``data`` with the MeanShift_py implementation (module ``ms``).

    The two tuning values are keyword parameters so the same helper can be
    reused for different data sets; the defaults are the demo settings.

    Settings noted for this script:
      * case (1) demo:    kernel_bandwidth = 3.0, cluster_epsilon = 6
      * case (2) laneseg: kernel_bandwidth = 0.5, cluster_epsilon = 2

    Args:
        data: points to cluster; passed unchanged to ``MeanShift.cluster``.
        kernel_bandwidth: forwarded as the ``kernel_bandwidth`` keyword.
        cluster_epsilon: forwarded as the ``cluster_epsilon`` keyword.

    Returns:
        The object returned by ``ms.MeanShift().cluster(...)``.
    """
    mean_shifter = ms.MeanShift()
    return mean_shifter.cluster(
        data,
        kernel_bandwidth=kernel_bandwidth,
        cluster_epsilon=cluster_epsilon,
    )

def sklearn_cluster(data):
    """Cluster ``data`` with scikit-learn's MeanShift.

    The bandwidth is estimated from the data itself and the fitted model is
    repackaged as an ``ms.MeanShiftResult`` so callers can treat both
    back ends identically.

    Args:
        data: array of points; ``data.shape[0]`` is used as the number of
            samples for the bandwidth estimate.

    Returns:
        ``ms.MeanShiftResult`` built from (input points, fitted cluster
        centers, fitted labels).
    """
    # Imported here so scikit-learn is only needed when this back end is used.
    from sklearn.cluster import MeanShift
    from sklearn.cluster import estimate_bandwidth

    # Original note: the estimate comes out at about 3 on the demo data.
    bandwidth = estimate_bandwidth(data, quantile=0.2, n_samples=data.shape[0])

    # ``bandwidth`` must be passed by keyword: MeanShift's constructor
    # parameters are keyword-only in current scikit-learn releases, so the
    # positional form raises TypeError there.
    mean_shifter = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    mean_shifter.fit(data)

    # Same field order as the result produced by the ``ms`` back end:
    # original points, cluster centers, per-point cluster ids.
    return ms.MeanShiftResult(
        data,
        mean_shifter.cluster_centers_,
        mean_shifter.labels_,
    )

def cluster_api(data, use_sklearn=True):
    """Cluster ``data`` with the selected back end.

    Args:
        data: points to cluster, handed to the chosen back end unchanged.
        use_sklearn: truthy selects ``sklearn_cluster``; falsy selects
            ``ms_cluster``.

    Returns:
        Whatever the chosen back-end function returns.
    """
    backend = sklearn_cluster if use_sklearn else ms_cluster
    return backend(data)

def print_cluster_result(mean_shift_result):
    """Print one line per input point: the point, its cluster center, its id.

    Args:
        mean_shift_result: object exposing ``original_points``,
            ``cluster_ids`` and ``cluster_centers``. Each point and center is
            indexed at ``[0]`` and ``[1]``; each cluster id is used as an
            index into ``cluster_centers``.
    """
    rule = "============================================"
    centers = mean_shift_result.cluster_centers

    print("Original Point Shifted Point Cluster ID")
    print(rule)
    # Walk points and their ids together instead of indexing by position.
    for original_point, cluster_id in zip(mean_shift_result.original_points,
                                          mean_shift_result.cluster_ids):
        cluster_center = centers[cluster_id]
        print(
            "(%5.2f,%5.2f) -> (%5.2f,%5.2f) cluster %i" %
            (original_point[0], original_point[1],
             cluster_center[0], cluster_center[1],
             cluster_id)
        )
    print(rule)

def main():
    """Run the demo: load ``data.csv``, cluster, print diagnostics, plot, save.

    Reads comma-separated points from ``data.csv``, clusters them through
    ``cluster_api``, prints the shape and contents of each result array,
    draws the points colored by cluster id with the centers marked as red
    ``+``, saves the figure and then shows it.
    """
    use_sklearn = True
    data = np.genfromtxt('data.csv', delimiter=',')
    print("data.shape=",data.shape)

    result = cluster_api(data,use_sklearn)
    # print_cluster_result(result) gives a per-point listing when needed.

    # Shapes noted for the demo data: points (125, 2), centers (3, 2),
    # ids (125,) with values 0, 1, 2.
    points = result.original_points
    centers = result.cluster_centers
    labels = result.cluster_ids
    unique_ids = np.unique(labels)

    # (label, array whose shape is reported, what is printed below it);
    # only the first ten points are shown, the other arrays in full.
    report = (
        ("original_points.shape=", points, points[:10]),
        ("cluster_centers.shape=", centers, centers),
        ("cluster_ids.shape=", labels, labels),
        ("unique_ids.shape=", unique_ids, unique_ids),
    )
    for label, array, shown in report:
        print(label,array.shape)
        print(shown)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    scatter = ax.scatter(points[:,0],points[:,1],c=labels,s=50)
    for cx,cy in centers:
        ax.scatter(cx,cy,s=50,c='red',marker='+')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    plt.colorbar(scatter)

    # Output name records which back end produced the figure.
    filename = "1_sklearn" if use_sklearn else "2_ms"

    fig.savefig(filename)
    plt.show()
    print("OK "+filename)

if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()

meanshift_py

#===============================
# ms 
#===============================
('data.shape=', (125, 2))
('original_points.shape=', (125, 2))
[[10.91079039  8.38941202]
 [ 9.87500165  9.9092509 ]
 [ 7.8481223  10.4317483 ]
 [ 8.53412293  9.55908561]
 [10.38316846  9.61879086]
 [ 8.11061595  9.77471761]
 [10.02119468  9.53877962]
 [ 9.37705852  9.70853991]
 [ 7.67017034  9.60315231]
 [10.94308287 11.76207349]]
('cluster_centers.shape=', (3, 2))
[[-3.45216026  5.28851174]
 [ 5.02926925  3.56548696]
 [ 8.63149568  9.25488818]]
('cluster_ids.shape=', (125,))
[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 1 0 0 0]
('unique_ids.shape=', (3,))
[0 1 2]
OK 2_ms

ms

sklearn

#===============================
# sklearn 
#===============================

('data.shape=', (125, 2))
('original_points.shape=', (125, 2))
[[10.91079039  8.38941202]
 [ 9.87500165  9.9092509 ]
 [ 7.8481223  10.4317483 ]
 [ 8.53412293  9.55908561]
 [10.38316846  9.61879086]
 [ 8.11061595  9.77471761]
 [10.02119468  9.53877962]
 [ 9.37705852  9.70853991]
 [ 7.67017034  9.60315231]
 [10.94308287 11.76207349]]
('cluster_centers.shape=', (3, 2))
[[ 4.79792283  3.01140269]
 [ 9.2548292  10.11312163]
 [-4.11368202  5.44826076]]
('cluster_ids.shape=', (125,))
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 0 2 2 2]
('unique_ids.shape=', (3,))
[0 1 2]
OK 1_sklearn

sklearn

Reference

History

  • 20190318: created.