Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 62 additions & 33 deletions egs2/TEMPLATE/ssl1/pyscripts/kmeans/kmeans_update_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

import argparse
from contextlib import nullcontext
from itertools import repeat
import logging
import os
import sys
import time

import multiprocessing
import numpy as np
import pickle

Expand Down Expand Up @@ -71,9 +73,50 @@ def pad_list(mats, lens, pad_value=0):

return pad

def _accum_batch(mats, kmeans_model, stats, counts):
    """Add one batch of feature matrices into ``stats``/``counts`` in place.

    Returns the wall-clock seconds spent inside ``kmeans_model.predict``.
    """
    frames = np.concatenate(mats, axis=0)  # (total_seq_len, dim)

    s_t = time.time()
    labels = np.asarray(kmeans_model.predict(frames))
    elapsed = time.time() - s_t

    # Unbuffered scatter-add: frames sharing a label are all summed into the
    # same row, in the same order as a frame-by-frame Python loop would.
    np.add.at(stats, labels, frames)
    counts += np.bincount(labels, minlength=counts.shape[0])
    return elapsed


def accum_stats(data, data_shape, kmeans_model, n_clusters=None, batch_frames=None):
    """Accumulate per-cluster feature sums and frame counts over ``data``.

    Utterances are buffered until at least ``batch_frames`` frames are
    collected, then labelled with a single ``kmeans_model.predict`` call.
    Frames still buffered when ``data`` is exhausted are processed as a final,
    smaller batch so that no utterance is dropped.

    Args:
        data: Iterable of ``(utt_id, mat)`` pairs, where ``mat`` is a 2-D
            array whose first axis indexes frames.
        data_shape: Feature dimension, i.e. the number of columns of the
            returned ``stats`` array (the caller passes ``mat.shape[1]``).
        kmeans_model: Object whose ``predict(frames)`` returns one integer
            cluster index per row of ``frames``.
        n_clusters: Number of clusters. Defaults to ``args.n_clusters`` from
            the module-level ``args``.
        batch_frames: Minimum number of buffered frames before a batch is
            labelled. Defaults to ``args.batch_frames`` from the module-level
            ``args``.

    Returns:
        Tuple ``(stats, counts, predict_time)``: ``stats`` is a float64 array
        of shape ``(n_clusters, data_shape)`` holding the sum of the frames
        assigned to each cluster, ``counts`` is an int64 array of length
        ``n_clusters`` with the number of frames per cluster, and
        ``predict_time`` is the total seconds spent in ``predict``.
    """
    # Fall back to the global ``args`` for backward compatibility. Passing the
    # values explicitly avoids relying on a global that may not be defined in
    # a worker process (this function is dispatched via a multiprocessing pool).
    if n_clusters is None:
        n_clusters = args.n_clusters
    if batch_frames is None:
        batch_frames = args.batch_frames

    stats = np.zeros((n_clusters, data_shape), dtype=np.float64)
    counts = np.zeros(n_clusters, dtype=np.int64)
    predict_time = 0

    mats = []
    frames_cnt = 0
    for _utt, mat in data:
        mats.append(mat)
        frames_cnt += mat.shape[0]

        if frames_cnt >= batch_frames:
            predict_time += _accum_batch(mats, kmeans_model, stats, counts)
            # Start a fresh batch; the counter must be reset as well,
            # otherwise every later utterance is predicted on its own.
            mats = []
            frames_cnt = 0

    # Flush the remainder that never reached ``batch_frames``.
    if frames_cnt > 0:
        predict_time += _accum_batch(mats, kmeans_model, stats, counts)

    return stats, counts, predict_time

def main(args):
start_time = time.time()

# get system info
num_cpus = len(os.sched_getaffinity(0))

# Read in sample data
for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype):
sample_data = mat
Expand Down Expand Up @@ -113,41 +156,27 @@ def main(args):
# Dump labels and stats in the file
out_root = os.path.dirname(args.output_label_file)
os.makedirs(out_root, exist_ok=True)
utt_ids, mats, lens = [], [], []
stats = np.zeros((args.n_clusters, sample_data.shape[1]), dtype=np.float64)
counts = np.zeros(args.n_clusters, dtype=np.int64)

frames_cnt = 0
predict_time = 0

with open(args.output_label_file, "w") if args.dump_label else nullcontext() as f:
for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype):
utt_ids.append(utt)
mats.append(mat)
lens.append(mat.shape[0])
frames_cnt += lens[-1]

if frames_cnt < args.batch_frames:
continue
else:
# (seq_len, dim)
mats_pad = np.concatenate(mats, axis=0) # (total_seq_len, dim)
s_t = time.time()
labels = kmeans_model.predict(mats_pad)
e_t = time.time()
predict_time += e_t - s_t

accum_lens = np.cumsum(np.array(lens))
labels = np.split(labels, accum_lens[:-1])
for i, (utt_id, label) in enumerate(zip(utt_ids, labels)):
for j in range(lens[i]):
stats[label[j]] += mats[i][j]
uniq, cnt = np.unique(np.array(label[:lens[i]]), return_counts=True)
for u, c in zip(uniq, cnt):
counts[u] += c

utt_ids, mats, lens = [], [], []

data = list(file_reader_helper(args.rspecifier, args.in_filetype))

# split the data between each worker
split_data = list()
split_len = len(data) / num_cpus
for i in range(num_cpus - 1):
split_data.append(data[i * num_cpus : i * (num_cpus+1)])
split_data.append(data[(num_cpus-1) * num_cpus :])

# each worker collects stats in parallel
pool = multiprocessing.Pool(processes = num_cpus)
split_res = pool.starmap(accum_stats, zip(split_data, repeat(sample_data.shape[1]), repeat(kmeans_model)))

# aggregate stats
for res in split_res:
stats = sum([res[0] for res in split_res])
counts = sum([res[1] for res in split_res])
predict_time = max([res[2] for res in split_res])

np.save(
args.output_stats_file,
np.concatenate([stats, counts[:, None]], axis=1, dtype=np.float64()),
Expand Down