Conversation

@yaozengwei (Owner) commented May 1, 2022

This is an implementation of Dan's idea about model averaging (see k2-fsa#337).

@yaozengwei (Owner, Author) commented May 2, 2022

The code is based on egs/librispeech/pruned_transducer_stateless2.
During training, the averaged model model_avg is updated every average_period batches with:
model_avg = (average_period / batch_idx_train) * model + ((batch_idx_train - average_period) / batch_idx_train) * model_avg
During decoding, let start = batch_idx_train of model-start and end = batch_idx_train of model-end. The model averaged over epochs [start+1, start+2, ..., end] is then avg = (model_end * end - model_start * start) / (end - start).
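As a sanity check on these two formulas, here is a small scalar simulation (not the icefall code; the parameter values and batch counts are made up for illustration) showing that the decode-time combination recovers the plain mean of the snapshots taken in the range (start, end]:

```python
average_period = 100  # illustrative value

def update_avg(model_avg, model, batch_idx_train, average_period):
    # model_avg = (p / t) * model + ((t - p) / t) * model_avg
    p, t = average_period, batch_idx_train
    return (p / t) * model + ((t - p) / t) * model_avg

# Pretend the model has a single scalar parameter whose value at batch t is t.
model_avg = 0.0
snapshots = {}
for t in range(average_period, 3001, average_period):
    model = float(t)
    model_avg = update_avg(model_avg, model, t, average_period)
    snapshots[t] = model_avg  # what a checkpoint would store

# Decode time: combine two saved averages to get the mean over (start, end].
start, end = 1000, 3000
avg = (snapshots[end] * end - snapshots[start] * start) / (end - start)

# Compare against the plain mean of the snapshots in (start, end].
expected = sum(range(start + average_period, end + 1, average_period)) / (
    (end - start) // average_period
)
print(abs(avg - expected) < 1e-6)  # True
```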
When trained on train-clean-100 with 3 GPUs for 30 epochs and average_period=100, I got the following WERs (test-clean & test-other) with greedy search decoding:

  • decode with epoch-29, avg=5, 7.14 & 19.33 (without averaged model) -> 7.03 & 18.85 (with averaged model);
  • decode with epoch-29, avg=10, 6.99 & 18.93 (without averaged model) -> 6.91 & 18.65 (with averaged model).

When trained on full librispeech with 6 GPUs for 30 epochs and average_period=100, I got the following WERs (test-clean & test-other) with greedy search decoding:

  • decode with epoch-29, avg=5, 2.77 & 6.77 (without averaged model) -> 2.72 & 6.67 (with averaged model);
  • decode with epoch-29, avg=10, 2.78 & 6.68 (without averaged model) -> 2.74 & 6.67 (with averaged model).

"""
Usage:
(1) greedy search
./pruned_transducer_stateless2/decode.py \
@csukuangfj commented May 2, 2022

Suggested change:
- ./pruned_transducer_stateless2/decode.py \
+ ./pruned_transducer_stateless3/decode.py \

Also, please sync with the latest k2/icefall and rename it to pruned_transducer_stateless4

@yaozengwei (Owner, Author) replied:

Ok.

    model.load_state_dict(average_checkpoints(filenames, device=device))
else:
    assert params.iter == 0
    start = params.epoch - params.avg


Please add more documentation to --use-average-model.
From the current help info, it is not clear how it is used in the code.
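For example, the help text could spell out both what is stored and how decoding uses it. This is only a sketch of possible wording, not the merged code; `str2bool` here is a minimal stand-in for icefall's utility of the same name:

```python
import argparse

def str2bool(v: str) -> bool:
    # Minimal stand-in for icefall's str2bool utility.
    return v.lower() in ("true", "1", "yes", "y")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use-average-model",
    type=str2bool,
    default=False,
    help=(
        "If true, load the averaged model (model_avg) that was maintained "
        "during training, and combine the saved averages with "
        "(model_end * end - model_start * start) / (end - start); "
        "otherwise average the epoch checkpoints directly."
    ),
)

args = parser.parse_args(["--use-average-model", "true"])
print(args.use_average_model)  # True
```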

filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
    f"averaging modes over range with {filename_start} (excluded) "


Suggested change:
- f"averaging modes over range with {filename_start} (excluded) "
+ f"averaging models over range with {filename_start} (excluded) "

checkpoint.pop("model")

if model_avg is not None and "model_avg" in checkpoint:
    model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)


Please add a log here, e.g., saying "loading averaged model".

@yaozengwei (Owner, Author) replied:

Ok.

Comment on lines 423 to 436
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr.
uniqued: Dict[int, str] = dict()
for k, v in avg.items():
    v_data_ptr = v.data_ptr()
    if v_data_ptr in uniqued:
        continue
    uniqued[v_data_ptr] = k

uniqued_names = list(uniqued.values())
for k in uniqued_names:
    avg[k] *= weight_end
    avg[k] += model_start[k] * weight_start


This part is almost the same as the above function. Please refactor it to reduce redundant code.
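One way to do that (a sketch, not the eventual PR code): pull the data_ptr deduplication into a helper that both call sites can share. `unique_param_names` is a hypothetical name:

```python
from typing import Dict, List

def unique_param_names(state_dict) -> List[str]:
    # Return one parameter name per underlying storage. Tied parameters
    # share a data_ptr and must be scaled/averaged only once.
    seen: Dict[int, str] = {}
    for name, tensor in state_dict.items():
        ptr = tensor.data_ptr()
        if ptr not in seen:
            seen[ptr] = name
    return list(seen.values())
```

Both averaging loops then reduce to `for k in unique_param_names(avg): ...`.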

parser.add_argument(
    "--start-epoch",
    type=int,
    default=0,


Please change it so that epoch is counted from 1, not 0.
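With 1-based counting, the flag might look like this (a sketch only; the resume-file naming in the help text is an assumption, not the merged code):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--start-epoch",
    type=int,
    default=1,
    help=(
        "Resume training from this epoch, counted from 1. If larger "
        "than 1, the checkpoint epoch-{start_epoch - 1}.pt is loaded first."
    ),
)

args = parser.parse_args([])
# The training loop then runs over range(args.start_epoch, num_epochs + 1).
print(args.start_epoch)  # 1
```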

def load_checkpoint_if_available(
    params: AttributeDict,
    model: nn.Module,
    model_avg: nn.Module = None,


Suggested change:
- model_avg: nn.Module = None,
+ model_avg: Optional[nn.Module] = None,

    The return value of :func:`get_params`.
  model:
    The training model.
  optimizer:


Please update the doc to include model_avg.

logging.info(f"Number of model parameters: {num_param}")

assert params.save_every_n >= params.average_period
model_avg: nn.Module = None


Suggested change:
- model_avg: nn.Module = None
+ model_avg: Optional[nn.Module] = None
