DaNNet
dnn_layer_norm.h
// Copyright 2019 Claes Rolen (www.rolensystems.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
namespace dnn
{

/// Batch normalization layer class.
class layer_norm: public layer
{
private:
    arma::Mat<DNN_Dtype> gamma;     ///< Scale parameter [N_right,1]
    arma::Mat<DNN_Dtype> beta;      ///< Shift parameter [N_right,1]
    arma::Mat<DNN_Dtype> x_mean;    ///< Mean-removed input [N_right,N_batch]
    arma::Mat<DNN_Dtype> x_stdinv;  ///< Inverse standard deviation of the batch [N_right,1]
    arma::Mat<DNN_Dtype> x_norm;    ///< Normalized input [N_right,N_batch]
    arma::Mat<DNN_Dtype> ee;        ///< Numerical stability term (epsilon) [N_right,1]
    arma::Mat<DNN_Dtype> run_mean;  ///< Running mean [N_right,1]
    arma::Mat<DNN_Dtype> run_var;   ///< Running variance [N_right,1]
public:
    layer_norm(void): layer() {};

    /// Initialization of layer.
    void init(void)
    {
        layer::init();

        train_par = true;

        type = "Batch Norm";
        id = type;
        gamma.set_size(N_right, 1);    gamma.ones();
        beta.set_size(N_right, 1);     beta.zeros();
        x_stdinv.set_size(N_right, 1);
        ee.set_size(N_right, 1);       ee.fill((DNN_Dtype)1e-5);
        run_mean.set_size(N_right, 1); run_mean.zeros();
        run_var.set_size(N_right, 1);  run_var.zeros();
        run_alpha = (DNN_Dtype)0.99;
    }

    /// Updates the buffer sizes.
    void upd_buf_size(arma::uword nmb)
    {
        x_mean.set_size(N_right, N_batch); x_mean.zeros();
        x_norm.set_size(N_right, N_batch); x_norm.zeros();
    }

    /// Set the running mean matrix.
    void set_run_mean(arma::Mat<DNN_Dtype>& r_mean)
    {
        run_mean = r_mean;
    }

    /// Set the running variance matrix.
    void set_run_var(arma::Mat<DNN_Dtype>& r_var)
    {
        run_var = r_var;
    }

    /// Get info about number of trainable parameters in layer.
    arma::uword get_nrof_params(void)
    {
        return N_right+N_right;   // gamma and beta
    }

    /// Forward propagation through layer, using the running statistics.
    void prop(void)
    {
        arma::Mat<DNN_Dtype> X_1 = left->get_Y1();
        Y1 = gamma%(X_1-run_mean)/arma::sqrt((run_var+ee))+beta;
    }

    /// Forward mini batch propagation through layer.
    void prop_mb(void)
    {
        arma::Mat<DNN_Dtype> x = *(left->get_Y_ptr());

        // Per-feature batch statistics, and update of the running mean/variance.
        const arma::Mat<DNN_Dtype> batch_mean = arma::mean(x, 1);
        const arma::Mat<DNN_Dtype> batch_var  = arma::var(x, 1, 1);
        run_mean = run_mean*run_alpha + batch_mean*((DNN_Dtype)1.0 - run_alpha);
        run_var  = run_var*run_alpha  + batch_var*((DNN_Dtype)1.0 - run_alpha);

        // Normalize the input to mean=0, variance=1.
        x_mean   = x.each_col() - batch_mean;
        x_stdinv = 1.0/arma::sqrt((batch_var + ee));
        x_norm   = x_mean.each_col() % x_stdinv;

        // Scale and shift with the trainable parameters gamma and beta.
        for(arma::uword k = 0; k < x_norm.n_cols; k++)
        {
            Y.col(k) = x_norm.col(k)%gamma + beta;
        }
    }
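
    // For reference (not part of the original source): prop_mb() computes, per
    // feature and per sample,
    //   Y = gamma % (x - batch_mean)/sqrt(batch_var + ee) + beta,
    // whereas prop() (defined above) applies the same affine map with the running
    // statistics run_mean and run_var, as used at inference time.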

    /// Backpropagation of mini batch through layer.
    void backprop(void)
    {
        arma::Mat<DNN_Dtype> y    = *(right->get_Dleft_ptr());   // Gradient w.r.t. the layer output
        arma::Mat<DNN_Dtype> norm = y.each_col()%gamma;          // Gradient w.r.t. x_norm
        arma::Mat<DNN_Dtype> var  = -arma::sum(norm % x_mean, 1) % arma::pow(x_stdinv, 3.0)/2;   // Gradient w.r.t. the batch variance
        Dleft = norm.each_col() % x_stdinv + x_mean.each_col() % (var*((DNN_Dtype)2.0/N_batch));
        Dleft.each_col() += -(arma::sum(norm.each_col() % x_stdinv, 1) - (DNN_Dtype)2.0*(var % arma::mean(x_mean, 1)))/((DNN_Dtype)N_batch);
    }
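
    // For reference (a sketch of the usual batch-norm chain rule, not part of the
    // original source): with x_hat = x_mean % x_stdinv and Y = gamma % x_hat + beta,
    // backprop() above evaluates, per feature, with sums over the mini batch columns,
    //
    //   d_xhat = dY % gamma                                   ("norm")
    //   d_var  = -1/2 * sum(d_xhat % x_mean) % x_stdinv^3     ("var")
    //   d_mu   = -sum(d_xhat % x_stdinv)                      (the extra x_mean term
    //                                                          averages to zero)
    //   Dleft  = d_xhat % x_stdinv
    //          + x_mean % d_var * 2/N_batch
    //          + d_mu/N_batch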

    /// Updates the trainable parameters.
    void update(void)
    {
        arma::Cube<DNN_Dtype> G(gamma.memptr(), N_right, 1, 1, false, true);   // View of gamma for the optimizer
        arma::Cube<DNN_Dtype> dg(N_right, 1, 1);
        dg.zeros();
        arma::Mat<DNN_Dtype> db(beta);
        arma::Mat<DNN_Dtype> x  = *(left->get_Y_ptr());
        arma::Mat<DNN_Dtype> dy = *(right->get_Dleft_ptr());

        db = arma::mean(dy, 1);                           // Gradient for beta
        dg.slice(0).col(0) = arma::sum(x_norm%dy, 1);     // Gradient for gamma

        opt_alg->apply(G, beta, dg, db);
    }

    /// Display info about layer.
    void disp(void)
    {
        layer::disp();
        std::cout << "Nr of outputs: " << get_nrof_outputs() << " [" << N_rows_right << "," << N_cols_right << "," << N_channels_right << "]" << std::endl;
    }

}; // End class layer_norm

} // End namespace dnn
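
Usage note (not part of dnn_layer_norm.h): the sketch below is a hypothetical, standalone Armadillo program that reproduces the per-feature arithmetic of prop_mb(), with DNN_Dtype replaced by double. Columns are samples and rows are features, arma::var(x,1,1) gives the population variance across the batch, and with the init() defaults gamma = 1, beta = 0 every feature row of the output should end up with mean close to 0 and variance close to 1.

#include <armadillo>

int main()
{
    using Mat = arma::mat;                         // double precision stand-in for DNN_Dtype

    const arma::uword n_features = 4;              // plays the role of N_right
    const arma::uword n_batch    = 8;              // plays the role of N_batch
    Mat x(n_features, n_batch, arma::fill::randn);

    // Per-feature batch statistics, as in prop_mb().
    arma::vec batch_mean = arma::mean(x, 1);
    arma::vec batch_var  = arma::var(x, 1, 1);     // normalize by N, along the batch dimension
    arma::vec ee(n_features); ee.fill(1e-5);       // epsilon, as set in init()

    Mat       x_mean   = x.each_col() - batch_mean;
    arma::vec x_stdinv = 1.0/arma::sqrt(batch_var + ee);
    Mat       x_norm   = x_mean.each_col() % x_stdinv;

    // Scale and shift with the init() defaults gamma = 1, beta = 0.
    arma::vec gamma(n_features, arma::fill::ones);
    arma::vec beta (n_features, arma::fill::zeros);
    Mat Y(n_features, n_batch);
    for(arma::uword k = 0; k < x_norm.n_cols; k++)
    {
        Y.col(k) = x_norm.col(k) % gamma + beta;
    }

    // Each feature row of Y should now have mean ~0 and variance ~1.
    Mat row_means = arma::mean(Y, 1);
    Mat row_vars  = arma::var(Y, 1, 1);
    row_means.print("row means:");
    row_vars.print("row variances:");

    return 0;
}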