diff --git a/egs2/TEMPLATE/ssl1/pyscripts/kmeans/kmeans_update_stats.py b/egs2/TEMPLATE/ssl1/pyscripts/kmeans/kmeans_update_stats.py
index 1dd7bf9b2915..c3df9b787a60 100644
--- a/egs2/TEMPLATE/ssl1/pyscripts/kmeans/kmeans_update_stats.py
+++ b/egs2/TEMPLATE/ssl1/pyscripts/kmeans/kmeans_update_stats.py
@@ -2,11 +2,13 @@
 import argparse
 from contextlib import nullcontext
+from itertools import repeat
 import logging
 import os
 import sys
 import time
+import multiprocessing
 
 import numpy as np
 import pickle
@@ -71,9 +73,50 @@ def pad_list(mats, lens, pad_value=0):
     return pad
 
 
+def accum_stats(data, feat_dim, n_clusters, batch_frames, kmeans_model):
+    """Accumulate per-cluster feature sums and frame counts over one shard.
+
+    ``data`` is a list of ``(utt_id, mat)`` pairs; utterances are buffered
+    until ``batch_frames`` frames are collected, then labeled in a single
+    ``kmeans_model.predict()`` call.
+    """
+    mats, lens = [], []
+    stats = np.zeros((n_clusters, feat_dim), dtype=np.float64)
+    counts = np.zeros(n_clusters, dtype=np.int64)
+
+    frames_cnt = 0
+    predict_time = 0.0
+
+    for idx, (_utt, mat) in enumerate(data):
+        mats.append(mat)
+        lens.append(mat.shape[0])
+        frames_cnt += lens[-1]
+
+        # Keep buffering until the batch is full, but always flush the
+        # final (possibly short) batch so no utterance is dropped.
+        if frames_cnt < batch_frames and idx < len(data) - 1:
+            continue
+
+        mats_cat = np.concatenate(mats, axis=0)  # (total_seq_len, dim)
+        s_t = time.time()
+        labels = kmeans_model.predict(mats_cat)
+        predict_time += time.time() - s_t
+
+        # Split the flat label sequence back into per-utterance chunks.
+        accum_lens = np.cumsum(np.array(lens))
+        labels = np.split(labels, accum_lens[:-1])
+        for i, label in enumerate(labels):
+            for j in range(lens[i]):
+                stats[label[j]] += mats[i][j]
+            uniq, cnt = np.unique(label, return_counts=True)
+            for u, c in zip(uniq, cnt):
+                counts[u] += c
+
+        mats, lens = [], []
+        frames_cnt = 0
+
+    return stats, counts, predict_time
+
+
 def main(args):
     start_time = time.time()
+
+    # Number of CPUs available to this process (os.sched_getaffinity is Linux-only).
+    num_cpus = len(os.sched_getaffinity(0))
+
     # Read in sample data
     for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype):
         sample_data = mat
@@ -113,41 +156,27 @@ def main(args):
     # Dump labels and stats in the file
     out_root = os.path.dirname(args.output_label_file)
     os.makedirs(out_root, exist_ok=True)
-    utt_ids, mats, lens = [], [], []
-    stats = np.zeros((args.n_clusters, sample_data.shape[1]), dtype=np.float64)
-    counts = np.zeros(args.n_clusters, dtype=np.int64)
-
-    frames_cnt = 0
-    predict_time = 0
 
     with open(args.output_label_file, "w") if args.dump_label else nullcontext() as f:
-        for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype):
-            utt_ids.append(utt)
-            mats.append(mat)
-            lens.append(mat.shape[0])
-            frames_cnt += lens[-1]
-
-            if frames_cnt < args.batch_frames:
-                continue
-            else:
-                # (seq_len, dim)
-                mats_pad = np.concatenate(mats, axis=0)  # (total_seq_len, dim)
-                s_t = time.time()
-                labels = kmeans_model.predict(mats_pad)
-                e_t = time.time()
-                predict_time += e_t - s_t
-
-                accum_lens = np.cumsum(np.array(lens))
-                labels = np.split(labels, accum_lens[:-1])
-                for i, (utt_id, label) in enumerate(zip(utt_ids, labels)):
-                    for j in range(lens[i]):
-                        stats[label[j]] += mats[i][j]
-                    uniq, cnt = np.unique(np.array(label[:lens[i]]), return_counts=True)
-                    for u, c in zip(uniq, cnt):
-                        counts[u] += c
-
-                utt_ids, mats, lens = [], [], []
-
+        data = list(file_reader_helper(args.rspecifier, args.in_filetype))
+
+        # Split the utterances into one contiguous shard per worker.
+        split_len = (len(data) + num_cpus - 1) // num_cpus  # ceiling division
+        split_data = [
+            data[i * split_len : (i + 1) * split_len] for i in range(num_cpus)
+        ]
+
+        # Each worker collects stats over its own shard in parallel.
+        with multiprocessing.Pool(processes=num_cpus) as pool:
+            split_res = pool.starmap(
+                accum_stats,
+                zip(
+                    split_data,
+                    repeat(sample_data.shape[1]),
+                    repeat(args.n_clusters),
+                    repeat(args.batch_frames),
+                    repeat(kmeans_model),
+                ),
+            )
+
+        # Aggregate the per-worker results.
+        stats = sum(res[0] for res in split_res)
+        counts = sum(res[1] for res in split_res)
+        predict_time = max(res[2] for res in split_res)
+
+    np.save(
+        args.output_stats_file,
+        np.concatenate([stats, counts[:, None]], axis=1, dtype=np.float64),
+    )
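
For reviewers: the shard-and-sum pattern above is easiest to sanity-check outside the script. Below is a minimal, self-contained sketch (not part of the patch) showing that summing per-shard `(stats, counts)` results reproduces a serial pass over the same data. `FakeKMeans` and the toy utterances are illustrative stand-ins for the real `kmeans_model` and the Kaldi reader, and the `batch_frames` buffering is omitted since it does not change the totals.

```python
import numpy as np


class FakeKMeans:
    """Illustrative stand-in mirroring scikit-learn's ``predict`` interface."""

    def __init__(self, centroids):
        self.centroids = centroids

    def predict(self, x):
        # Assign each frame to its nearest centroid (Euclidean distance).
        d = np.linalg.norm(x[:, None, :] - self.centroids[None, :, :], axis=-1)
        return d.argmin(axis=1)


def accum_stats(data, feat_dim, n_clusters, kmeans_model):
    # Same accumulation as the patch: per-cluster feature sums and counts.
    stats = np.zeros((n_clusters, feat_dim), dtype=np.float64)
    counts = np.zeros(n_clusters, dtype=np.int64)
    for _utt, mat in data:
        labels = kmeans_model.predict(mat)
        np.add.at(stats, labels, mat)  # scatter-add frames into their clusters
        np.add.at(counts, labels, 1)
    return stats, counts


rng = np.random.default_rng(0)
n_clusters, feat_dim = 4, 8
model = FakeKMeans(rng.normal(size=(n_clusters, feat_dim)))
data = [
    ("utt%d" % i, rng.normal(size=(int(rng.integers(5, 20)), feat_dim)))
    for i in range(10)
]

# Serial reference over all utterances.
ref_stats, ref_counts = accum_stats(data, feat_dim, n_clusters, model)

# Two contiguous shards, aggregated by summation: identical totals.
results = [accum_stats(s, feat_dim, n_clusters, model) for s in (data[:5], data[5:])]
assert np.allclose(sum(r[0] for r in results), ref_stats)
assert (sum(r[1] for r in results) == ref_counts).all()
```

Note that the patch aggregates `predict_time` with `max()` rather than `sum()`, so it reports the slowest worker's prediction time, i.e. an estimate of the parallel wall-clock cost.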