Skip to content

Commit 24c496b

Browse files
killeentsoumith
authored andcommitted
move normal variants to TH/THC
1 parent 58334a0 commit 24c496b

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

generic/THTensorRandom.c

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,29 @@ void THTensor_(normal)(THTensor *self, THGenerator *_generator, double mean, dou
5555
TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_normal(_generator, mean, stdv););
5656
}
5757

58+
void THTensor_(normal_means)(THTensor *self, THGenerator *gen, THTensor *means, double stddev)
59+
{
60+
THTensor_(resizeAs)(self, means);
61+
THTensor_(normal)(self, gen, 0, stddev);
62+
THTensor_(cadd)(self, self, 1, means);
63+
}
64+
65+
void THTensor_(normal_stddevs)(THTensor *self, THGenerator *gen, double mean, THTensor *stddevs)
66+
{
67+
THTensor_(resizeAs)(self, stddevs);
68+
THTensor_(normal)(self, gen, 0, 1);
69+
THTensor_(cmul)(self, self, stddevs);
70+
THTensor_(add)(self, self, mean);
71+
}
72+
73+
void THTensor_(normal_means_stddevs)(THTensor *self, THGenerator *gen, THTensor *means, THTensor *stddevs)
74+
{
75+
THTensor_(resizeAs)(self, means);
76+
THTensor_(normal)(self, gen, 0, 1);
77+
THTensor_(cmul)(self, self, stddevs);
78+
THTensor_(cadd)(self, self, 1, means);
79+
}
80+
5881
void THTensor_(exponential)(THTensor *self, THGenerator *_generator, double lambda)
5982
{
6083
TH_TENSOR_APPLY(real, self, *self_data = (real)THRandom_exponential(_generator, lambda););

generic/THTensorRandom.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ TH_API void THTensor_(bernoulli_DoubleTensor)(THTensor *self, THGenerator *_gene
1111
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
1212
TH_API void THTensor_(uniform)(THTensor *self, THGenerator *_generator, double a, double b);
1313
TH_API void THTensor_(normal)(THTensor *self, THGenerator *_generator, double mean, double stdv);
14+
TH_API void THTensor_(normal_means)(THTensor *self, THGenerator *gen, THTensor *means, double stddev);
15+
TH_API void THTensor_(normal_stddevs)(THTensor *self, THGenerator *gen, double mean, THTensor *stddevs);
16+
TH_API void THTensor_(normal_means_stddevs)(THTensor *self, THGenerator *gen, THTensor *means, THTensor *stddevs);
1417
TH_API void THTensor_(exponential)(THTensor *self, THGenerator *_generator, double lambda);
1518
TH_API void THTensor_(cauchy)(THTensor *self, THGenerator *_generator, double median, double sigma);
1619
TH_API void THTensor_(logNormal)(THTensor *self, THGenerator *_generator, double mean, double stdv);

0 commit comments

Comments
 (0)