diff --git a/abnet3/sampler.py b/abnet3/sampler.py index 251912f..1a804d8 100644 --- a/abnet3/sampler.py +++ b/abnet3/sampler.py @@ -10,7 +10,7 @@ """ from abnet3.utils import normalize_distribution, cumulative_distribution -from abnet3.utils import print_token, sample_searchidx +from abnet3.utils import print_token, sample_searchidx, samplepairs_searchidx from abnet3.utils import read_spkid_file, read_spk_list, progress import numpy as np @@ -298,14 +298,11 @@ def type_samp_func(x): return np.log(1+x) W_types[tokens_type[tok]] += 1.0 except Exception as e: W_types[tokens_type[tok]] = 1.0 - p_types = {"Stype": {}, "Dtype": {}} + + p_types = dict() for type_idx in range(nb_types): - p_types["Stype"][type_idx] = type_samp_func(W_types[type_idx]) - for type_jdx in range(type_idx+1, nb_types): - p_types["Dtype"][(type_idx, type_jdx)] = \ - type_samp_func(W_types[type_idx]) * \ - type_samp_func(W_types[type_jdx]) + p_types[type_idx] = type_samp_func(W_types[type_idx]) return p_types def sample_spk_p(self, std_descr, spk_sampling_mode='log'): @@ -319,9 +316,11 @@ def sample_spk_p(self, std_descr, spk_sampling_mode='log'): """ nb_tok = len(std_descr['tokens']) tokens_type = std_descr['tokens_type'] + types = std_descr['types'] p_spk_types = {'Stype_Sspk': {}, 'Stype_Dspk': {}, 'Dtype_Sspk': {}, 'Dtype_Dspk': {}} speakers_for_token = std_descr['tokens_speaker'] + speakers = std_descr['speakers'] W_spk_types = {} for tok in range(nb_tok): try: @@ -351,32 +350,29 @@ def spk_samp_func(x): return np.log(1+x) for (spk, type_idx) in W_spk_types.keys(): print_progress(i) i += 1 - for (spk2, type_jdx) in W_spk_types.keys(): - if spk == spk2: - if type_idx == type_jdx: - if (W_spk_types[(spk, type_idx)] - 1) == 0: - p_spk_types['Stype_Sspk'][(spk, type_idx)] = 0.0 - else: - p_spk_types['Stype_Sspk'][(spk, type_idx)] = \ - spk_samp_func(W_spk_types[(spk, type_idx)]) - else: - min_idx = min(type_idx, type_jdx) - max_idx = max(type_idx, type_jdx) + # Dtype, Dspk + p_spk_types['Dtype_Dspk'][(spk, type_idx)] = \ + spk_samp_func(W_spk_types[(spk, type_idx)]) + # Stype, Sspk + if (W_spk_types[(spk, type_idx)] - 1) == 0: + p_spk_types['Stype_Sspk'][(spk, type_idx)] = 0.0 + else: + p_spk_types['Stype_Sspk'][(spk, type_idx)] = \ + spk_samp_func(W_spk_types[(spk, type_idx)]) + # Dtype, Sspk + for type_jdx in range(type_idx + 1, len(types)): + min_idx = min(type_idx, type_jdx) + max_idx = max(type_idx, type_jdx) + if (spk, type_jdx) in W_spk_types: p_spk_types['Dtype_Sspk'][(spk, min_idx, max_idx)] = \ spk_samp_func(W_spk_types[(spk, type_idx)]) * \ spk_samp_func(W_spk_types[(spk, type_jdx)]) - else: - if type_idx == type_jdx: - p_spk_types['Stype_Dspk'][(spk, spk2, type_idx)] = \ - spk_samp_func(W_spk_types[(spk, type_idx)]) * \ - spk_samp_func(W_spk_types[(spk2, type_idx)]) - else: - min_idx = min(type_idx, type_jdx) - max_idx = max(type_idx, type_jdx) - p_spk_types['Dtype_Dspk'][(spk, spk2, - min_idx, max_idx)] = \ - spk_samp_func(W_spk_types[(spk, type_idx)]) * \ - spk_samp_func(W_spk_types[(spk2, type_jdx)]) + # Stype, Dspk + for spk2 in speakers: + if spk != spk2 and (spk2, type_idx) in W_spk_types: + p_spk_types['Stype_Dspk'][(spk, spk2, type_idx)] = \ + spk_samp_func(W_spk_types[(spk, type_idx)]) * \ + spk_samp_func(W_spk_types[(spk2, type_idx)]) return p_spk_types def generate_token_dict(self, std_descr): @@ -428,16 +424,13 @@ def type_speaker_sampling_p(self, std_descr=None, """ assert type_sampling_mode in ['1', 'f', 'f2', 'log', 'fcube'] assert spk_sampling_mode in ['1', 'f', 'f2', 'log', 'fcube'] - # W_types = std_descr['types'] - # speakers = [e for e in std_descr['speakers']] - # W_speakers = [std_descr['speakers'][e] for e in speakers] + p_types = self.type_sample_p(std_descr, type_sampling_mode=type_sampling_mode) p_spk_types = self.sample_spk_p(std_descr, spk_sampling_mode=spk_sampling_mode) - for config in p_types.keys(): - p_types[config] = normalize_distribution(p_types[config]) + p_types = normalize_distribution(p_types) for config in p_spk_types.keys(): p_spk_types[config] = normalize_distribution(p_spk_types[config]) @@ -451,23 +444,24 @@ def type_speaker_sampling_p(self, std_descr=None, i += 1 if config == 'Stype_Sspk': for el in p_spk_types[config].keys(): - p_spk_types[config][el] = p_types['Stype'][el[1]] * \ + p_spk_types[config][el] = \ + p_types[el[1]] * \ p_spk_types[config][el] if config == 'Stype_Dspk': for el in p_spk_types[config].keys(): - p_spk_types[config][el] = p_types['Stype'][el[2]] * \ + p_spk_types[config][el] = \ + p_types[el[2]] * \ p_spk_types[config][el] if config == 'Dtype_Sspk': for el in p_spk_types[config].keys(): - p_spk_types[config][el] = p_types['Dtype'][ - (el[1], - el[2])] * \ + p_spk_types[config][el] = \ + p_types[el[1]] * \ + p_types[el[2]] * \ p_spk_types[config][el] if config == 'Dtype_Dspk': for el in p_spk_types[config].keys(): - p_spk_types[config][el] = p_types['Dtype'][ - (el[2], - el[3])] * \ + p_spk_types[config][el] = \ + p_types[el[1]] * \ p_spk_types[config][el] for config in p_spk_types.keys(): @@ -551,44 +545,52 @@ def sample_batch(self, 'Dtype_Dspk': num_Dtype_Dspk } for config in p_spk_types.keys(): - keys = np.array(list(p_spk_types[config].keys())) - sample_idx = sample_searchidx(cdf[config], sampled_ratio[config]) - sample = keys[sample_idx] - if config == 'Stype_Sspk': - for key in sample: - spk, type_idx = key - tokens = token_dict[int(type_idx), spk] - tok1, tok2 = np.random.choice(tokens, size=2, - replace=False) - sampled_tokens[config].append( - (tok1, tok2)) - if config == 'Stype_Dspk': - for key in sample: - spk1, spk2, type_idx = key - type_idx = int(type_idx) - tok1 = np.random.choice(token_dict[type_idx, spk1]) - tok2 = np.random.choice(token_dict[type_idx, spk2]) - sampled_tokens[config].append((tok1, tok2)) - if config == 'Dtype_Sspk': - for key in sample: - spk, type_idx, type_jdx = key - type_idx = int(type_idx) - type_jdx = int(type_jdx) - tok1 = np.random.choice(token_dict[type_idx, spk]) - tok2 = np.random.choice(token_dict[type_jdx, spk]) - sampled_tokens[config].append((tok1, tok2)) if config == 'Dtype_Dspk': + """ + Dtype_Dspk is particular + We sample two items and check they are different + """ + keys = np.array(list(p_spk_types[config].keys())) + sample_idx = samplepairs_searchidx(cdf[config], sampled_ratio[config], keys=keys) + sample = keys[sample_idx] for key in sample: - spk1, spk2, type_idx, type_jdx = key - type_idx = int(type_idx) - type_jdx = int(type_jdx) - try: - tok1 = np.random.choice(token_dict[type_idx, spk1]) - tok2 = np.random.choice(token_dict[type_jdx, spk2]) - except Exception: - tok1 = np.random.choice(token_dict[type_idx, spk2]) - tok2 = np.random.choice(token_dict[type_jdx, spk1]) + (spk1, type1), (spk2, type2) = key + type1 = int(type1) + type2 = int(type2) + assert spk1 != spk2 + assert type1 != type2 + tok1 = np.random.choice(token_dict[type1, spk1]) + tok2 = np.random.choice(token_dict[type2, spk2]) sampled_tokens[config].append((tok1, tok2)) + else: + keys = np.array(list(p_spk_types[config].keys())) + sample_idx = sample_searchidx(cdf[config], sampled_ratio[config]) + sample = keys[sample_idx] + if config == 'Stype_Sspk': + for key in sample: + spk, type_idx = key + tokens = token_dict[int(type_idx), spk] + tok1, tok2 = np.random.choice(tokens, size=2, + replace=False) + sampled_tokens[config].append( + (tok1, tok2)) + if config == 'Stype_Dspk': + for key in sample: + spk1, spk2, type_idx = key + assert spk1 != spk2 + type_idx = int(type_idx) + tok1 = np.random.choice(token_dict[type_idx, spk1]) + tok2 = np.random.choice(token_dict[type_idx, spk2]) + sampled_tokens[config].append((tok1, tok2)) + if config == 'Dtype_Sspk': + for key in sample: + spk, type_idx, type_jdx = key + assert type_idx != type_jdx + type_idx = int(type_idx) + type_jdx = int(type_jdx) + tok1 = np.random.choice(token_dict[type_idx, spk]) + tok2 = np.random.choice(token_dict[type_jdx, spk]) + sampled_tokens[config].append((tok1, tok2)) return sampled_tokens def write_tokens(self, descr=None, proba=None, cdf=None, diff --git a/abnet3/utils.py b/abnet3/utils.py index a4b8c39..83d1145 100644 --- a/abnet3/utils.py +++ b/abnet3/utils.py @@ -97,6 +97,33 @@ def sample_searchidx(cdf, num_samples): return idx +def samplepairs_searchidx(cdf, num_samples, keys): + """ + Sample indexes based on cdf distribution + This function samples pairs of *different* elements (ie without + replacement) + It samples randomly pairs of elements, and reruns the elements that have + one thing in common (this is used to sample pairs of (spk, type) + where both spk and type are different + """ + iterations = 0 # limit 5 iterations to avoid infinite loops + uniform_samples = np.random.random_sample((int(num_samples), 2)) + idx = cdf.searchsorted(uniform_samples, side='right') + while True: + iterations += 1 + if iterations > 30: + print("Warning : more than 30 iterations to sample different pairs") + pair_keys = keys[idx] + index_same_spk = np.where(pair_keys[:, 0, 0] == pair_keys[:, 1, 0])[0] + index_same_type = np.where(pair_keys[:, 0, 1] == pair_keys[:, 1, 1])[0] + indices_to_change = np.concatenate((index_same_spk, index_same_type)) + num_samples_same = len(indices_to_change) + if num_samples_same == 0: + break + new_samples = np.random.random_sample((int(num_samples_same), 2)) + idx[indices_to_change] = cdf.searchsorted(new_samples, side='right') + return idx + def print_token(tok): """Pretty print token for batches