Skip to content

Support multiple parallel augmentations.#1942

Open
csukuangfj wants to merge 1 commit into
k2-fsa:masterfrom
csukuangfj:dataset-parallel-augmentation
Open

Support multiple parallel augmentations.#1942
csukuangfj wants to merge 1 commit into
k2-fsa:masterfrom
csukuangfj:dataset-parallel-augmentation

Conversation

@csukuangfj

Copy link
Copy Markdown
Collaborator

See also lhotse-speech/lhotse#1477

(I prefer to put it inside icefall. If moving to lhotse is desired, I can do that.)

Test code

#!/usr/bin/env python3

from functools import partial

import lhotse
from lhotse import Fbank, FbankConfig
from lhotse.dataset import SimpleCutSampler
from lhotse.dataset.input_strategies import (
    AudioSamples,
    OnTheFlyFeatures,
)
from torch.utils.data.dataloader import DataLoader

from speech_recognition_dataset import ConsistencyRegularizationSpeechRecognitionDataset


def create_cutset():
    recording0 = lhotse.Recording.from_file("./0.wav")
    sup0 = lhotse.SupervisionSegment(
        id="sup0",
        recording_id=recording0.id,
        start=0,
        duration=5,
        text="hello, how are you",
    )
    cut0 = lhotse.MonoCut(
        id="cut0",
        start=0,
        duration=6,
        channel=0,
        recording=recording0,
        supervisions=[sup0],
    )

    recording1 = lhotse.Recording.from_file("./1.wav")
    sup1 = lhotse.SupervisionSegment(
        id="sup1",
        recording_id=recording1.id,
        start=0,
        duration=3,
        text="fine, thank you",
    )
    cut1 = lhotse.MonoCut(
        id="cut1",
        start=0,
        duration=4,
        channel=0,
        recording=recording1,
        supervisions=[sup1],
    )

    return lhotse.CutSet([cut0, cut1])


def main():
    cutset = create_cutset()
    print([c for c in cutset])

    t1 = partial(lhotse.MonoCut.perturb_speed, factor=0.9)
    t2 = partial(lhotse.MonoCut.perturb_volume, factor=1.1)
    t3 = partial(lhotse.MonoCut.perturb_tempo, factor=1.2)
    transforms = [t1, t2, t3]

    train = ConsistencyRegularizationSpeechRecognitionDataset(
        input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
        #  input_strategy=AudioSamples(),
        cut_transforms=transforms,
        #  return_cuts=True,
        return_cuts=False,
    )
    print(train)
    train_sampler = SimpleCutSampler(
        cutset,
        max_duration=20,
        shuffle=False,
    )

    train_dl = DataLoader(
        train,
        sampler=train_sampler,
        batch_size=None,
        num_workers=1,
        persistent_workers=False,
    )
    for b in train_dl:
        print(b["inputs"].shape, b["supervisions"])
        if "aug" in b:
            assert len(b["aug"]) == len(transforms)
            for i, aug in enumerate(b["aug"]):
                print(
                    transforms[i].func.__name__,
                    aug["inputs"].shape,
                    aug["supervisions"],
                )


if __name__ == "__main__":
    main()

The output is given below:

[MonoCut(id='cut0', start=0, duration=6, channel=0, supervisions=[SupervisionSegment(id='sup0', recording_id='0', start=0, duration=5, channel=0, text='hello, how are you', language=None, speaker=None, gender=None, custom=None, alignment=None)], features=None, recording=Recording(id='0', sources=[AudioSource(type='file', channels=[0], source='0.wav')], sampling_rate=16000, num_samples=106000, duration=6.625, channel_ids=[0], transforms=None), custom=None), MonoCut(id='cut1', start=0, duration=4, channel=0, supervisions=[SupervisionSegment(id='sup1', recording_id='1', start=0, duration=3, channel=0, text='fine, thank you', language=None, speaker=None, gender=None, custom=None, alignment=None)], features=None, recording=Recording(id='1', sources=[AudioSource(type='file', channels=[0], source='1.wav')], sampling_rate=16000, num_samples=81600, duration=5.1, channel_ids=[0], transforms=None), custom=None)]
<speech_recognition_dataset.ConsistencyRegularizationSpeechRecognitionDataset object at 0x119f1c8e0>
torch.Size([2, 600, 80]) {'text': ['hello, how are you', 'fine, thank you'], 'sequence_idx': tensor([0, 1], dtype=torch.int32), 'start_frame': tensor([0, 0], dtype=torch.int32), 'num_frames': tensor([500, 300], dtype=torch.int32)}
perturb_speed torch.Size([2, 667, 80]) {'text': ['hello, how are you', 'fine, thank you'], 'sequence_idx': tensor([0, 1], dtype=torch.int32), 'start_frame': tensor([0, 0], dtype=torch.int32), 'num_frames': tensor([556, 333], dtype=torch.int32)}
perturb_volume torch.Size([2, 600, 80]) {'text': ['hello, how are you', 'fine, thank you'], 'sequence_idx': tensor([0, 1], dtype=torch.int32), 'start_frame': tensor([0, 0], dtype=torch.int32), 'num_frames': tensor([500, 300], dtype=torch.int32)}
perturb_tempo torch.Size([2, 500, 80]) {'text': ['hello, how are you', 'fine, thank you'], 'sequence_idx': tensor([0, 1], dtype=torch.int32), 'start_frame': tensor([0, 0], dtype=torch.int32), 'num_frames': tensor([417, 250], dtype=torch.int32)}

@pzelasko

Copy link
Copy Markdown
Collaborator

Look good to me, +1 for keeping this code in Icefall, I think it’s more convenient to have it close to the training recipe.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants