import mean_shift as ms
import matplotlib.pyplot as plt
import numpy as np


def ms_cluster(data):
    # Cluster with the local mean_shift module, using fixed parameters.
    mean_shifter = ms.MeanShift()
    mean_shift_result = mean_shifter.cluster(data, kernel_bandwidth=3, cluster_epsilon=6)
    return mean_shift_result


def sklearn_cluster(data):
    from sklearn.cluster import MeanShift
    from sklearn.cluster import estimate_bandwidth

    # Estimate the kernel bandwidth from all samples.
    bandwidth = estimate_bandwidth(data, quantile=0.2, n_samples=data.shape[0])
    # bandwidth is keyword-only in current scikit-learn releases.
    mean_shifter = MeanShift(bandwidth=bandwidth, bin_seeding=True)
    mean_shifter.fit(data)

    original_points = data
    cluster_centers = mean_shifter.cluster_centers_
    cluster_ids = mean_shifter.labels_

    # Wrap the result in the same type that ms_cluster returns.
    mean_shift_result = ms.MeanShiftResult(original_points, cluster_centers, cluster_ids)
    return mean_shift_result


def cluster_api(data, use_sklearn=True):
    if use_sklearn:
        return sklearn_cluster(data)
    else:
        return ms_cluster(data)


def print_cluster_result(mean_shift_result):
    print("Original Point Shifted Point Cluster ID")
    print("============================================")
    for i in range(len(mean_shift_result.original_points)):
        original_point = mean_shift_result.original_points[i]
        cluster_id = mean_shift_result.cluster_ids[i]
        cluster_center = mean_shift_result.cluster_centers[cluster_id]
        print(
            "(%5.2f,%5.2f) -> (%5.2f,%5.2f) cluster %i"
            % (original_point[0], original_point[1],
               cluster_center[0], cluster_center[1], cluster_id)
        )
    print("============================================")


def main():
    use_sklearn = True
    data = np.genfromtxt('data.csv', delimiter=',')
    print("data.shape=", data.shape)

    mean_shift_result = cluster_api(data, use_sklearn)

    original_points = mean_shift_result.original_points
    cluster_centers = mean_shift_result.cluster_centers
    cluster_ids = mean_shift_result.cluster_ids

    unique_ids = np.unique(cluster_ids)

    print("original_points.shape=", original_points.shape)
    print(original_points[:10])

    print("cluster_centers.shape=", cluster_centers.shape)
    print(cluster_centers)

    print("cluster_ids.shape=", cluster_ids.shape)
    print(cluster_ids)

    print("unique_ids.shape=", unique_ids.shape)
    print(unique_ids)

    x = original_points[:, 0]
    y = original_points[:, 1]

    # Color each point by cluster id and mark the centers with red crosses.
    fig = plt.figure()
    ax = fig.add_subplot(111)
    scatter = ax.scatter(x, y, c=cluster_ids, s=50)
    for cx, cy in cluster_centers:
        ax.scatter(cx, cy, s=50, c='red', marker='+')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    plt.colorbar(scatter)

    if use_sklearn:
        filename = "1_sklearn"
    else:
        filename = "2_ms"

    fig.savefig(filename)
    plt.show()
    print("OK " + filename)


if __name__ == "__main__":
    main()