
added a new script to add new codec to existing lhotse shars #75

correctly handle distributed data loader
Signed-off-by: Paarth Neekhara <[email protected]>
paarthneekhara committed May 12, 2025
commit 4326b21bd397571b14b979f7f42c00e755711e7f
77 changes: 49 additions & 28 deletions scripts/magpietts/update_magpie_shars.py
@@ -86,6 +86,7 @@ def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batc
    def _write_buffer(self):
        print("Writing buffer of size:", len(self.cuts_buffer))
        for cut in self.cuts_buffer:
            # print("Writing cut:", cut.id)
            processed_cut = self.process_cut_for_saving(cut)
            self.shar_writer.write(processed_cut)
        self.cuts_buffer.clear()
@@ -99,14 +100,43 @@ def teardown(self, trainer, pl_module, stage: str | None = None):
        self._write_buffer()
        self.shar_writer.close()
        self.shar_writer = None


def ddp_info():
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size(), torch.distributed.get_rank()
    return 1, 0

class CodecExtractor(pl.LightningModule):
    def __init__(self, model_path: str, pad_multiple: int = 1024):
    def __init__(self, model_path, pad_multiple, batch_size, shar_root, num_workers):
        super().__init__()
        self.pad_multiple = pad_multiple

Review comment:

we don't need this argument anymore since we can always get it from self.codec_model.samples_per_frame

Author reply:

That's right, changed this.
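For reference, a minimal sketch of how that change might look (assuming the rest of the constructor stays as in this commit, and that samples_per_frame on the restored codec is the frame hop in audio samples, as suggested above):

# hypothetical revision of CodecExtractor.__init__ without the pad_multiple argument
def __init__(self, model_path, batch_size, shar_root, num_workers):
    super().__init__()
    self.codec_model = AudioCodecModel.restore_from(restore_path=model_path, strict=False)
    self.codec_model.eval()
    # derive the padding unit from the restored codec instead of a CLI flag
    self.pad_multiple = self.codec_model.samples_per_frame
    self._dataloader = None
    self.batch_size = batch_size
    self.shar_root = shar_root
    self.num_workers = num_workers

The --pad_multiple flag in the argparse block could then be dropped as well.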

        self.codec_model = AudioCodecModel.restore_from(restore_path=model_path, strict=False)
        self.codec_model.eval()

        self._dataloader = None
        self.batch_size = batch_size
        self.shar_root = shar_root
        self.num_workers = num_workers

    def setup(self, stage=None):
        if self._dataloader is None:
            world_size, rank = ddp_info()
            print("In model.setup - world size:", world_size, "rank:", rank)
            cuts = CutSet.from_shar(in_dir=self.shar_root)
            sampler = SimpleCutSampler(
                cuts, shuffle=False, max_cuts=self.batch_size,
                world_size=world_size, rank=rank,
            )
Comment on lines +128 to +131
@XuesongYang May 12, 2025:

This may shard different items from the original cuts shard. I observed this issue in my earlier attempts; it is worth double-checking.

FYI, I implemented a v2 version of the lhotse shar prep and avoided this issue. Will post a PR here once I get feedback from Piotr.

Brief intro: the whole lhotse shar recipe is now split into two steps: nemo manifest -> cuts/target_audio/context_audio shards -> extend with target_codes/context_codes.
Step 1: https://gitlab-master.nvidia.com/xueyang/nemo/-/blob/xueyang/magpie_release_2504-lhotse-v2-debug/scripts/magpietts/create_lhotse_shar_from_nemo_manifest.py?ref_type=heads
Step 2: https://gitlab-master.nvidia.com/xueyang/nemo/-/blob/xueyang/magpie_release_2504-lhotse-v2-debug/scripts/magpietts/extend_lhotse_shards_with_audio_codes.py?ref_type=heads

Author reply:

I think specifying max_cuts loads same-size batches sequentially and avoids this issue, but let me double-check to confirm.

Author reply:

I checked, and the order of the cuts is preserved in the dataloader; shuffle=False and a fixed batch size should be the reason.
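A quick offline check of that could look like the sketch below (shar_root, batch_size, and world_size are placeholder values; the check also assumes that lhotse hands whole batches to ranks round-robin when world_size/rank are passed, which is exactly the property being verified):

# sketch: confirm that SimpleCutSampler(shuffle=False, max_cuts=...) with
# world_size/rank still covers the shar in its original order once the
# per-rank batch streams are interleaved
from lhotse import CutSet
from lhotse.dataset import SimpleCutSampler

shar_root = "/path/to/shars"  # placeholder
batch_size = 64               # placeholder
world_size = 2                # placeholder

original_ids = [cut.id for cut in CutSet.from_shar(in_dir=shar_root)]

per_rank = []
for rank in range(world_size):
    sampler = SimpleCutSampler(
        CutSet.from_shar(in_dir=shar_root),
        shuffle=False, max_cuts=batch_size,
        world_size=world_size, rank=rank,
    )
    per_rank.append([[cut.id for cut in batch] for batch in sampler])

# interleave the ranks' batches round-robin and flatten back into one sequence
recovered = []
for step in range(max(len(batches) for batches in per_rank)):
    for batches in per_rank:
        if step < len(batches):
            recovered.extend(batches[step])

assert recovered == original_ids, "rank sharding changed the cut order"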

            dataset = SimpleCutDataset(pad_multiple=self.pad_multiple)
            self._dataloader = DataLoader(
                dataset,
                sampler=sampler,
                batch_size=None,
                num_workers=self.num_workers,
            )

    def predict_dataloader(self):
        return self._dataloader

    def predict_step(self, batch, _):
        audio_stacked = batch["audio"].to(device=self.device)
        context_audio_stacked = batch["context_audio"].to(device=self.device)
@@ -131,10 +161,13 @@ def predict_step(self, batch, _):
        print("Time to process:", time.time() - mt)
        return {"cuts": cuts, "codes": codes, "context_codes": context_codes}

class MyDataset(torch.utils.data.Dataset):
class SimpleCutDataset(torch.utils.data.Dataset):
    def __init__(self, pad_multiple: int = 1024):
        super().__init__()
        self.pad_multiple = pad_multiple

    def __getitem__(self, cuts: CutSet):
        # return cuts
        pad_multiple = 1024
        pad_multiple = self.pad_multiple
        audios = cuts.load_audio()
        context_audios_torch = []
        audios_torch = []
@@ -168,30 +201,14 @@ def __getitem__(self, cuts: CutSet):
            "context_audio_lens": context_audio_lens,
            "cuts": cuts,
        }

    def __len__(self):
        return len(self.cuts)

def make_loader(shar_root, batch_size, num_workers):
    cuts = CutSet.from_shar(
        in_dir=shar_root,
    )
    dataset = MyDataset()
    sampler = SimpleCutSampler(cuts, shuffle=False, max_cuts=batch_size)
    dataloader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=None,
        num_workers=num_workers,
    )
    return dataloader

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--in_dir")
    parser.add_argument("--out_dir")
    parser.add_argument("--codec_ckpt")
    parser.add_argument("--codec_name", default="48k")
    parser.add_argument("--codec_name", default="testcodec")
    parser.add_argument("--pad_multiple", type=int, default=1024)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--devices", type=int, default=-1)
@@ -208,8 +225,6 @@ def make_loader(shar_root, batch_size, num_workers):
        buffer_size=args.batch_size,
    )

    dataloader = make_loader(args.in_dir, args.batch_size, args.num_workers)

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=args.devices,
@@ -220,6 +235,12 @@ def make_loader(shar_root, batch_size, num_workers):
        use_distributed_sampler=False,
    )

    model = CodecExtractor(args.codec_ckpt)
    model = CodecExtractor(
        args.codec_ckpt,
        pad_multiple=args.pad_multiple,
        batch_size=args.batch_size,
        shar_root=args.in_dir,
        num_workers=args.num_workers,
    )

    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
    trainer.predict(model, return_predictions=False)
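For completeness, a hypothetical invocation of the updated script (all paths below are placeholders):

python scripts/magpietts/update_magpie_shars.py --in_dir /data/shars_in --out_dir /data/shars_out --codec_ckpt /models/codec.nemo --codec_name testcodec --batch_size 64 --num_workers 4 --devices 2

With more than one device (and assuming a DDP strategy), Lightning runs one process per GPU; each rank then builds its own rank-aware dataloader in setup(), which is why the explicit dataloaders argument to trainer.predict is no longer needed and use_distributed_sampler is left False.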