added a new script to add new codec to existing lhotse shars #75
Changes from 1 commit
Signed-off-by: Paarth Neekhara <[email protected]>
@@ -86,6 +86,7 @@ def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batc
     def _write_buffer(self):
         print("Writing buffer of size:", len(self.cuts_buffer))
         for cut in self.cuts_buffer:
+            # print("Writing cut:", cut.id)
             processed_cut = self.process_cut_for_saving(cut)
             self.shar_writer.write(processed_cut)
         self.cuts_buffer.clear()
@@ -99,14 +100,43 @@ def teardown(self, trainer, pl_module, stage: str | None = None):
         self._write_buffer()
         self.shar_writer.close()
         self.shar_writer = None


+def ddp_info():
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        return torch.distributed.get_world_size(), torch.distributed.get_rank()
+    return 1, 0


 class CodecExtractor(pl.LightningModule):
-    def __init__(self, model_path: str, pad_multiple: int = 1024):
+    def __init__(self, model_path, pad_multiple, batch_size, shar_root, num_workers):
         super().__init__()
         self.pad_multiple = pad_multiple
         self.codec_model = AudioCodecModel.restore_from(restore_path=model_path, strict=False)
         self.codec_model.eval()

+        self._dataloader = None
+        self.batch_size = batch_size
+        self.shar_root = shar_root
+        self.num_workers = num_workers

+    def setup(self, stage=None):
+        if self._dataloader is None:
+            world_size, rank = ddp_info()
+            print("In model.setup - world size:", world_size, "rank:", rank)
+            cuts = CutSet.from_shar(in_dir=self.shar_root)
+            sampler = SimpleCutSampler(
+                cuts, shuffle=False, max_cuts=self.batch_size,
+                world_size=world_size, rank=rank,
+            )
Comment on lines +128 to +131:

This may shard different items from the original cuts shard. I observed this issue in my earlier attempts, so it is worth double-checking. FYI, I implemented a v2 version of the lhotse shar prep that avoids this issue, and will post a PR here once I get feedback from Piotr. Brief intro: the whole lhotse shar recipe is now split into two steps: nemo manifest -> cuts/target_audio/context_audio shards -> extend with target_codes/context_codes.

I think specifying max_cuts loads batches of the same size sequentially and avoids this issue, but let me double-check to confirm.

I checked and the order of the cuts is preserved in the dataloader - shuffle=False and a fixed batch size should be the reason.
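For reference, here is a minimal sketch of that kind of order check (not part of this PR). It reuses lhotse's CutSet.from_shar and SimpleCutSampler exactly as the script does; the helper name check_cut_order and its arguments are hypothetical:

```python
from lhotse import CutSet
from lhotse.dataset import SimpleCutSampler

def check_cut_order(shar_root: str, batch_size: int, world_size: int) -> None:
    """Confirm each rank sees an order-preserving subsequence of the shar."""
    reference_ids = [cut.id for cut in CutSet.from_shar(in_dir=shar_root)]
    for rank in range(world_size):
        sampler = SimpleCutSampler(
            CutSet.from_shar(in_dir=shar_root),
            shuffle=False, max_cuts=batch_size,
            world_size=world_size, rank=rank,
        )
        # Each mini-batch yielded by the sampler is itself a CutSet.
        rank_ids = [cut.id for batch in sampler for cut in batch]
        it = iter(reference_ids)
        # `cut_id in it` advances the iterator, so this checks that rank_ids
        # is a subsequence of reference_ids (relative order preserved).
        assert all(cut_id in it for cut_id in rank_ids), f"rank {rank}: order changed"
```

If the assertion holds for every rank, each worker reads an order-preserving subsequence of the original shar, which is the property discussed above.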
+            dataset = SimpleCutDataset(pad_multiple=self.pad_multiple)
+            self._dataloader = DataLoader(
+                dataset,
+                sampler=sampler,
+                batch_size=None,
+                num_workers=self.num_workers,
+            )

+    def predict_dataloader(self):
+        return self._dataloader

     def predict_step(self, batch, _):
         audio_stacked = batch["audio"].to(device=self.device)
         context_audio_stacked = batch["context_audio"].to(device=self.device)
@@ -131,10 +161,13 @@ def predict_step(self, batch, _):
         print("Time to process:", time.time() - mt)
         return {"cuts": cuts, "codes": codes, "context_codes": context_codes}

-class MyDataset(torch.utils.data.Dataset):
+class SimpleCutDataset(torch.utils.data.Dataset):
+    def __init__(self, pad_multiple: int = 1024):
+        super().__init__()
+        self.pad_multiple = pad_multiple

     def __getitem__(self, cuts: CutSet):
         # return cuts
-        pad_multiple = 1024
+        pad_multiple = self.pad_multiple
         audios = cuts.load_audio()
         context_audios_torch = []
         audios_torch = []
@@ -168,30 +201,14 @@ def __getitem__(self, cuts: CutSet):
             "context_audio_lens": context_audio_lens,
             "cuts": cuts,
         }

-    def __len__(self):
-        return len(self.cuts)

-def make_loader(shar_root, batch_size, num_workers):
-    cuts = CutSet.from_shar(
-        in_dir=shar_root,
-    )
-    dataset = MyDataset()
-    sampler = SimpleCutSampler(cuts, shuffle=False, max_cuts=batch_size)
-    dataloader = DataLoader(
-        dataset,
-        sampler=sampler,
-        batch_size=None,
-        num_workers=num_workers,
-    )
-    return dataloader

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--in_dir")
     parser.add_argument("--out_dir")
     parser.add_argument("--codec_ckpt")
-    parser.add_argument("--codec_name", default="48k")
+    parser.add_argument("--codec_name", default="testcodec")
     parser.add_argument("--pad_multiple", type=int, default=1024)
     parser.add_argument("--batch_size", type=int, default=64)
     parser.add_argument("--num_workers", type=int, default=4)
     parser.add_argument("--devices", type=int, default=-1)
@@ -208,8 +225,6 @@ def make_loader(shar_root, batch_size, num_workers):
         buffer_size=args.batch_size,
     )

-    dataloader = make_loader(args.in_dir, args.batch_size, args.num_workers)
-
     trainer = pl.Trainer(
         accelerator="gpu",
         devices=args.devices,
@@ -220,6 +235,12 @@ def make_loader(shar_root, batch_size, num_workers):
         use_distributed_sampler=False,
     )

-    model = CodecExtractor(args.codec_ckpt)
+    model = CodecExtractor(
+        args.codec_ckpt,
+        pad_multiple=args.pad_multiple,
+        batch_size=args.batch_size,
+        shar_root=args.in_dir,
+        num_workers=args.num_workers,
+    )

-    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
+    trainer.predict(model, return_predictions=False)
We don't need this argument anymore, since we can always get it from self.codec_model.samples_per_frame.

That's right, changed this.
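For illustration, a minimal sketch of what that change could look like, assuming (as the comment above states) that the restored AudioCodecModel exposes samples_per_frame; the remaining constructor arguments are kept as in this PR, and the import paths are the standard NeMo/Lightning ones rather than taken from this diff:

```python
import pytorch_lightning as pl
from nemo.collections.tts.models import AudioCodecModel

class CodecExtractor(pl.LightningModule):
    # pad_multiple is no longer a constructor argument; it is read off the codec model.
    def __init__(self, model_path, batch_size, shar_root, num_workers):
        super().__init__()
        self.codec_model = AudioCodecModel.restore_from(restore_path=model_path, strict=False)
        self.codec_model.eval()
        # Pad audio to a whole number of codec frames: samples_per_frame is the
        # number of waveform samples the codec encodes into one frame.
        self.pad_multiple = self.codec_model.samples_per_frame
        self._dataloader = None
        self.batch_size = batch_size
        self.shar_root = shar_root
        self.num_workers = num_workers
```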