From 4a8dd0015c990963c7e5b077c77ab32624b426ab Mon Sep 17 00:00:00 2001 From: lurenZJF <1203908635@qq.com> Date: Wed, 21 Oct 2020 22:36:07 +0800 Subject: [PATCH] update show --- Modules/static_dbscan.py | 77 ++++++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/Modules/static_dbscan.py b/Modules/static_dbscan.py index 845051e..02e638d 100644 --- a/Modules/static_dbscan.py +++ b/Modules/static_dbscan.py @@ -1,8 +1,7 @@ from sklearn.cluster import DBSCAN from sklearn import metrics -import pandas as pd import numpy as np -from pprint import pprint +import matplotlib.pyplot as plt def my_db(eps, min_sample, metric, corpus_embeddings): @@ -24,27 +23,32 @@ def my_db(eps, min_sample, metric, corpus_embeddings): return db -def show_db(label_data, db, corpus, corpus_embeddings, corpus_file=None, label_file=None): - if label_file is not None: - # 读取真实标签数据 - label_data = pd.read_csv(label_file) - labels_true = label_data.flag.to_list() - labels = db.labels_ +def show_db(labels_true, db, corpus, corpus_embeddings, show=False): + """ + 呈现函数结果 + :param labels_true: 聚类数据的真实标签 + :param db: 训练得到模型 + :param corpus: 原始文本数据的list + :param corpus_embeddings: 待聚类的表征数据 + :param show:bool值,是否输出聚类的图形效果;默认为false + :return: + """ + labels = db.labels_ # 获取预测标签数据 + # 获取聚类数量 n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0) + # 获取噪音点信息,在DBSCAN聚类中,噪音点用-1标记 n_noise_ = list(labels).count(-1) - + # 生成存放聚类数据的容器[[],[],[],[],[]...] clustered_sentences = [[] for i in range(n_clusters_)] - if corpus_file is not None: - # 读取原始文本 - corpus = pd.read_csv(corpus_file).content.to_list() + # 将同一类文本放到一个list中 for sentence_id, cluster_id in enumerate(labels): clustered_sentences[cluster_id].append(corpus[sentence_id]) - + # 输出聚类结果 for i, cluster in enumerate(clustered_sentences): print("Cluster ", i + 1) print(cluster) print("") - + # 数据聚类评价信息 print('Estimated number of clusters: %d' % n_clusters_) print('Estimated number of noise points: %d' % n_noise_) print("Homogeneity: %0.3f" % metrics.homogeneity_score(labels_true, labels)) @@ -59,28 +63,23 @@ def show_db(label_data, db, corpus, corpus_embeddings, corpus_file=None, label_f # ############################################################################# # Plot result - import matplotlib.pyplot as plt - - core_samples_mask = np.zeros_like(db.labels_, dtype=bool) - core_samples_mask[db.core_sample_indices_] = True - # Black removed and is used for noise instead. - unique_labels = set(labels) - colors = [plt.cm.Spectral(each) - for each in np.linspace(0, 1, len(unique_labels))] - for k, col in zip(unique_labels, colors): - if k == -1: - # Black used for noise. - col = [0, 0, 0, 1] - - class_member_mask = (labels == k) - - xy = corpus_embeddings[class_member_mask & core_samples_mask] - plt.plot(xy[:, 0], xy[:, 1], 'o', markerfacecolor=tuple(col), - markeredgecolor='k', markersize=14) - - xy = corpus_embeddings[class_member_mask & ~core_samples_mask] - plt.plot(xy[:, 0], xy[:, 1], 'o', markerfacecolor=tuple(col), - markeredgecolor='k', markersize=6) - - plt.title('Estimated number of clusters: %d' % n_clusters_) - plt.show() + if show: + # 获取聚类核心点信息 + core_samples_mask = np.zeros_like(db.labels_, dtype=bool) + core_samples_mask[db.core_sample_indices_] = True + # Black removed and is used for noise instead. + unique_labels = set(labels) + colors = [plt.cm.Spectral(each) + for each in np.linspace(0, 1, len(unique_labels))] + for k, col in zip(unique_labels, colors): + if k == -1: + col = [0, 0, 0, 1] # Black used for noise. + class_member_mask = (labels == k) + xy = corpus_embeddings[class_member_mask & core_samples_mask] + plt.plot(xy[:, 0], xy[:, 1], 'o', markerfacecolor=tuple(col), + markeredgecolor='k', markersize=14) + xy = corpus_embeddings[class_member_mask & ~core_samples_mask] + plt.plot(xy[:, 0], xy[:, 1], 'o', markerfacecolor=tuple(col), + markeredgecolor='k', markersize=6) + plt.title('Estimated number of clusters: %d' % n_clusters_) + plt.show()