Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 80 additions & 78 deletions abnet3/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""

from abnet3.utils import normalize_distribution, cumulative_distribution
from abnet3.utils import print_token, sample_searchidx
from abnet3.utils import print_token, sample_searchidx, samplepairs_searchidx
from abnet3.utils import read_spkid_file, read_spk_list, progress

import numpy as np
Expand Down Expand Up @@ -298,14 +298,11 @@ def type_samp_func(x): return np.log(1+x)
W_types[tokens_type[tok]] += 1.0
except Exception as e:
W_types[tokens_type[tok]] = 1.0
p_types = {"Stype": {}, "Dtype": {}}

p_types = dict()

for type_idx in range(nb_types):
p_types["Stype"][type_idx] = type_samp_func(W_types[type_idx])
for type_jdx in range(type_idx+1, nb_types):
p_types["Dtype"][(type_idx, type_jdx)] = \
type_samp_func(W_types[type_idx]) * \
type_samp_func(W_types[type_jdx])
p_types[type_idx] = type_samp_func(W_types[type_idx])
return p_types

def sample_spk_p(self, std_descr, spk_sampling_mode='log'):
Expand All @@ -319,9 +316,11 @@ def sample_spk_p(self, std_descr, spk_sampling_mode='log'):
"""
nb_tok = len(std_descr['tokens'])
tokens_type = std_descr['tokens_type']
types = std_descr['types']
p_spk_types = {'Stype_Sspk': {}, 'Stype_Dspk': {},
'Dtype_Sspk': {}, 'Dtype_Dspk': {}}
speakers_for_token = std_descr['tokens_speaker']
speakers = std_descr['speakers']
W_spk_types = {}
for tok in range(nb_tok):
try:
Expand Down Expand Up @@ -351,32 +350,29 @@ def spk_samp_func(x): return np.log(1+x)
for (spk, type_idx) in W_spk_types.keys():
print_progress(i)
i += 1
for (spk2, type_jdx) in W_spk_types.keys():
if spk == spk2:
if type_idx == type_jdx:
if (W_spk_types[(spk, type_idx)] - 1) == 0:
p_spk_types['Stype_Sspk'][(spk, type_idx)] = 0.0
else:
p_spk_types['Stype_Sspk'][(spk, type_idx)] = \
spk_samp_func(W_spk_types[(spk, type_idx)])
else:
min_idx = min(type_idx, type_jdx)
max_idx = max(type_idx, type_jdx)
# Dtype, Dspk
p_spk_types['Dtype_Dspk'][(spk, type_idx)] = \
spk_samp_func(W_spk_types[(spk, type_idx)])
# Stype, Sspk
if (W_spk_types[(spk, type_idx)] - 1) == 0:
p_spk_types['Stype_Sspk'][(spk, type_idx)] = 0.0
else:
p_spk_types['Stype_Sspk'][(spk, type_idx)] = \
spk_samp_func(W_spk_types[(spk, type_idx)])
# Dtype, Sspk
for type_jdx in range(type_idx + 1, len(types)):
min_idx = min(type_idx, type_jdx)
max_idx = max(type_idx, type_jdx)
if (spk, type_jdx) in W_spk_types:
p_spk_types['Dtype_Sspk'][(spk, min_idx, max_idx)] = \
spk_samp_func(W_spk_types[(spk, type_idx)]) * \
spk_samp_func(W_spk_types[(spk, type_jdx)])
else:
if type_idx == type_jdx:
p_spk_types['Stype_Dspk'][(spk, spk2, type_idx)] = \
spk_samp_func(W_spk_types[(spk, type_idx)]) * \
spk_samp_func(W_spk_types[(spk2, type_idx)])
else:
min_idx = min(type_idx, type_jdx)
max_idx = max(type_idx, type_jdx)
p_spk_types['Dtype_Dspk'][(spk, spk2,
min_idx, max_idx)] = \
spk_samp_func(W_spk_types[(spk, type_idx)]) * \
spk_samp_func(W_spk_types[(spk2, type_jdx)])
# Stype, Dspk
for spk2 in speakers:
if spk != spk2 and (spk2, type_idx) in W_spk_types:
p_spk_types['Stype_Dspk'][(spk, spk2, type_idx)] = \
spk_samp_func(W_spk_types[(spk, type_idx)]) * \
spk_samp_func(W_spk_types[(spk2, type_idx)])
return p_spk_types

def generate_token_dict(self, std_descr):
Expand Down Expand Up @@ -428,16 +424,13 @@ def type_speaker_sampling_p(self, std_descr=None,
"""
assert type_sampling_mode in ['1', 'f', 'f2', 'log', 'fcube']
assert spk_sampling_mode in ['1', 'f', 'f2', 'log', 'fcube']
# W_types = std_descr['types']
# speakers = [e for e in std_descr['speakers']]
# W_speakers = [std_descr['speakers'][e] for e in speakers]

p_types = self.type_sample_p(std_descr,
type_sampling_mode=type_sampling_mode)
p_spk_types = self.sample_spk_p(std_descr,
spk_sampling_mode=spk_sampling_mode)

for config in p_types.keys():
p_types[config] = normalize_distribution(p_types[config])
p_types = normalize_distribution(p_types)

for config in p_spk_types.keys():
p_spk_types[config] = normalize_distribution(p_spk_types[config])
Expand All @@ -451,23 +444,24 @@ def type_speaker_sampling_p(self, std_descr=None,
i += 1
if config == 'Stype_Sspk':
for el in p_spk_types[config].keys():
p_spk_types[config][el] = p_types['Stype'][el[1]] * \
p_spk_types[config][el] = \
p_types[el[1]] * \
p_spk_types[config][el]
if config == 'Stype_Dspk':
for el in p_spk_types[config].keys():
p_spk_types[config][el] = p_types['Stype'][el[2]] * \
p_spk_types[config][el] = \
p_types[el[2]] * \
p_spk_types[config][el]
if config == 'Dtype_Sspk':
for el in p_spk_types[config].keys():
p_spk_types[config][el] = p_types['Dtype'][
(el[1],
el[2])] * \
p_spk_types[config][el] = \
p_types[el[1]] * \
p_types[el[2]] * \
p_spk_types[config][el]
if config == 'Dtype_Dspk':
for el in p_spk_types[config].keys():
p_spk_types[config][el] = p_types['Dtype'][
(el[2],
el[3])] * \
p_spk_types[config][el] = \
p_types[el[1]] * \
p_spk_types[config][el]

for config in p_spk_types.keys():
Expand Down Expand Up @@ -551,44 +545,52 @@ def sample_batch(self,
'Dtype_Dspk': num_Dtype_Dspk
}
for config in p_spk_types.keys():
keys = np.array(list(p_spk_types[config].keys()))
sample_idx = sample_searchidx(cdf[config], sampled_ratio[config])
sample = keys[sample_idx]
if config == 'Stype_Sspk':
for key in sample:
spk, type_idx = key
tokens = token_dict[int(type_idx), spk]
tok1, tok2 = np.random.choice(tokens, size=2,
replace=False)
sampled_tokens[config].append(
(tok1, tok2))
if config == 'Stype_Dspk':
for key in sample:
spk1, spk2, type_idx = key
type_idx = int(type_idx)
tok1 = np.random.choice(token_dict[type_idx, spk1])
tok2 = np.random.choice(token_dict[type_idx, spk2])
sampled_tokens[config].append((tok1, tok2))
if config == 'Dtype_Sspk':
for key in sample:
spk, type_idx, type_jdx = key
type_idx = int(type_idx)
type_jdx = int(type_jdx)
tok1 = np.random.choice(token_dict[type_idx, spk])
tok2 = np.random.choice(token_dict[type_jdx, spk])
sampled_tokens[config].append((tok1, tok2))
if config == 'Dtype_Dspk':
"""
Dtype_Dspk is particular
We sample two items and check they are different
"""
keys = np.array(list(p_spk_types[config].keys()))
sample_idx = samplepairs_searchidx(cdf[config], sampled_ratio[config], keys=keys)
sample = keys[sample_idx]
for key in sample:
spk1, spk2, type_idx, type_jdx = key
type_idx = int(type_idx)
type_jdx = int(type_jdx)
try:
tok1 = np.random.choice(token_dict[type_idx, spk1])
tok2 = np.random.choice(token_dict[type_jdx, spk2])
except Exception:
tok1 = np.random.choice(token_dict[type_idx, spk2])
tok2 = np.random.choice(token_dict[type_jdx, spk1])
(spk1, type1), (spk2, type2) = key
type1 = int(type1)
type2 = int(type2)
assert spk1 != spk2
assert type1 != type2
tok1 = np.random.choice(token_dict[type1, spk1])
tok2 = np.random.choice(token_dict[type2, spk2])
sampled_tokens[config].append((tok1, tok2))
else:
keys = np.array(list(p_spk_types[config].keys()))
sample_idx = sample_searchidx(cdf[config], sampled_ratio[config])
sample = keys[sample_idx]
if config == 'Stype_Sspk':
for key in sample:
spk, type_idx = key
tokens = token_dict[int(type_idx), spk]
tok1, tok2 = np.random.choice(tokens, size=2,
replace=False)
sampled_tokens[config].append(
(tok1, tok2))
if config == 'Stype_Dspk':
for key in sample:
spk1, spk2, type_idx = key
assert spk1 != spk2
type_idx = int(type_idx)
tok1 = np.random.choice(token_dict[type_idx, spk1])
tok2 = np.random.choice(token_dict[type_idx, spk2])
sampled_tokens[config].append((tok1, tok2))
if config == 'Dtype_Sspk':
for key in sample:
spk, type_idx, type_jdx = key
assert type_idx != type_jdx
type_idx = int(type_idx)
type_jdx = int(type_jdx)
tok1 = np.random.choice(token_dict[type_idx, spk])
tok2 = np.random.choice(token_dict[type_jdx, spk])
sampled_tokens[config].append((tok1, tok2))
return sampled_tokens

def write_tokens(self, descr=None, proba=None, cdf=None,
Expand Down
27 changes: 27 additions & 0 deletions abnet3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,33 @@ def sample_searchidx(cdf, num_samples):
return idx


def samplepairs_searchidx(cdf, num_samples, keys, max_iterations=1000):
    """Sample pairs of indexes into `keys` whose two elements fully differ.

    Each row of the result holds two indexes drawn independently from `cdf`
    (inverse-transform sampling through ``cdf.searchsorted``). A row is
    rejected and redrawn when the two keys it points to share their first
    field or share their second field. In the sampler this is used to draw
    pairs of (spk, type) keys where both spk and type are different.

    Parameters
    ----------
    cdf : numpy.ndarray
        Cumulative distribution over the rows of `keys`.
    num_samples : int or float
        Number of pairs to draw; converted with ``int()``.
    keys : numpy.ndarray
        2-D array indexed as ``keys[i] == (first_field, second_field)``.
    max_iterations : int, optional
        Maximum number of resampling rounds after the initial draw. Bounds
        the loop when no valid pair can be drawn at all (for example when
        every key has the same first field).

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(num_samples, 2)`` indexing into `keys`.

    Raises
    ------
    RuntimeError
        If some pairs still clash after `max_iterations` resampling rounds.
    """
    # Number of resampling rounds after which a (single) warning is printed.
    warn_after = 30

    num_samples = int(num_samples)
    uniform_samples = np.random.random_sample((num_samples, 2))
    idx = cdf.searchsorted(uniform_samples, side='right')

    iterations = 0
    while True:
        pair_keys = keys[idx]
        # A row clashes when both keys share field 0 or share field 1.
        # Using one boolean mask lists each offending row exactly once,
        # even when both fields are equal.
        clash = ((pair_keys[:, 0, 0] == pair_keys[:, 1, 0]) |
                 (pair_keys[:, 0, 1] == pair_keys[:, 1, 1]))
        indices_to_change = np.flatnonzero(clash)
        num_samples_same = len(indices_to_change)
        if num_samples_same == 0:
            return idx

        if iterations >= max_iterations:
            raise RuntimeError(
                "could not sample %d pair(s) with different elements after "
                "%d iterations" % (num_samples_same, max_iterations))
        iterations += 1
        if iterations == warn_after + 1:
            # Printed once, not on every following iteration.
            print("Warning : more than 30 iterations to sample different pairs")

        # Redraw only the rows that clash; valid rows are left untouched.
        new_samples = np.random.random_sample((num_samples_same, 2))
        idx[indices_to_change] = cdf.searchsorted(new_samples, side='right')

def print_token(tok):
"""Pretty print token for batches

Expand Down