Commit 33b9f3b

revise dropout removal
1 parent 90ecfab

File tree

1 file changed: +16 -3 lines changed


src/nnetbin/nnet-rm-dropout.cc

Lines changed: 16 additions & 3 deletions
@@ -20,6 +20,7 @@
 #include "nnet/nnet-nnet.h"
 #include "nnet/nnet-component.h"
 #include "nnet/nnet-activation.h"
+#include "nnet/nnet-biasedlinearity.h"
 
 int main(int argc, char *argv[]) {
   try {
@@ -58,13 +59,25 @@ int main(int argc, char *argv[]) {
     {
       Output ko(model_out_filename, binary_write);
 
+      bool apply_scale = false;
+      BaseFloat scale = 1.0;
+
       for (int32 i=0; i<nnet.LayerCount(); ++i){
         Component *layer = nnet.Layer(i);
         if(layer->GetType()==Component::kDropout){
           Dropout *dp=dynamic_cast<Dropout*>(layer);
-          Scale *sc=new Scale(dp->InputDim(), dp->OutputDim(), NULL);
-          sc->SetScale(1- dp->GetDropRatio());
-          sc->Write(ko.Stream(), binary_write);
+          scale = 1- dp->GetDropRatio();
+          apply_scale = true;
+        } else if (apply_scale && layer->GetType()==Component::kBiasedLinearity){
+          BiasedLinearity *bl=dynamic_cast<BiasedLinearity*>(layer);
+          CuMatrix<BaseFloat> weight(bl->GetLinearityWeight());
+          weight.Scale(scale);
+          bl->SetLinearityWeight(weight, kNoTrans);
+          bl->Write(ko.Stream(), binary_write);
+          apply_scale = false;
+          scale = 1.0;
+        } else if (apply_scale) {
+          KALDI_ERR << "Layer " << i << " following the dropout layer is not supported yet!";
         } else {
           layer->Write(ko.Stream(), binary_write);
         }

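Note on the change: instead of replacing each dropout layer with a separate Scale component, the revised tool now folds the retention factor (1 - drop_ratio) directly into the weight matrix of the following BiasedLinearity layer and writes only that rescaled layer. Below is a minimal standalone C++ sketch (not Kaldi code; the dropout ratio and the toy weight matrix are illustrative) of why the two are equivalent at test time: for a linear layer y = W*x + b, scaling the input x by a constant s gives y = W*(s*x) + b = (s*W)*x + b, so the same effect is obtained by scaling W once and dropping the extra component.

// Minimal sketch (illustrative only): folding the dropout retention factor
// (1 - drop_ratio) into the weights of the next linear layer.
#include <iostream>
#include <vector>

int main() {
  const double drop_ratio = 0.2;          // hypothetical dropout ratio
  const double scale = 1.0 - drop_ratio;  // retention factor to fold in

  // Toy 2x3 weight matrix of the layer that follows the dropout layer.
  std::vector<std::vector<double>> weights = {
      {0.5, -1.0, 0.25},
      {2.0,  0.1, -0.5}};

  // Fold the scale into every weight; the bias is left untouched because
  // dropout only rescales the layer's inputs, not its constant term.
  for (auto &row : weights)
    for (double &w : row) w *= scale;

  // Print the rescaled weights, which now absorb the dropout scaling.
  for (const auto &row : weights) {
    for (double w : row) std::cout << w << ' ';
    std::cout << '\n';
  }
  return 0;
}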