DaNNet
dnn_opt.h
Go to the documentation of this file.
1 // Copyright 2019 Claes Rolen (www.rolensystems.com)
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #pragma once
16 namespace dnn
17 {
23 
// ---------------------------------------------------------------------------
// Optimizer base class.
// Holds state shared by all concrete optimizers (algorithm name, learning-rate
// schedule, iteration counter) and declares the pure-virtual apply() step that
// updates layer weights W and biases B from their gradients.
// NOTE: this is an extracted listing; the member declarations on original
// lines 35-41 (lr, reg_lambda, reg_alpha, lr_alg, lr_0, lr_a, lr_b - see the
// cross-reference index at the bottom of the file) are not shown here but are
// used by the code below.
// ---------------------------------------------------------------------------
31 class opt
32 {
33 protected:
 // Algorithm name, e.g. "SGD", "ADAM"; set by derived-class constructors.
34  std::string alg;
 // Iteration counter driving the learning-rate schedules in update_learn_rate().
42  arma::uword it;
43 public:
 // Default-construct with a constant learning-rate schedule (lr_0 = 1).
44  opt()
45  {
46  lr_alg = LR_ALG::CONST;
47  lr_0 = 1;
48  it = 0;
49  };
50 
 // NOTE(review): this class has virtual functions but a non-virtual
 // destructor - deleting a derived optimizer through an opt* is undefined
 // behaviour; consider declaring it 'virtual ~opt()'.
51  ~opt() {};
52 
 // Apply one optimizer step to the layer parameters.
 // W, B are updated in place; Wgrad, Bgrad are the corresponding gradients.
58  virtual void apply(arma::Cube<DNN_Dtype>& W,
59  arma::Mat<DNN_Dtype>& B,
60  const arma::Cube<DNN_Dtype>& Wgrad,
61  const arma::Mat<DNN_Dtype>& Bgrad ) = 0;
62 
 // Returns the algorithm name; derived classes override this to append
 // their hyper-parameters.
67  virtual std::string get_algorithm(void)
68  {
69  return alg;
70  }
71 
 // Select the learning-rate schedule. a and b are schedule parameters whose
 // meaning depends on the algorithm (see update_learn_rate()). Resets the
 // iteration counter.
84  void set_learn_rate_alg(LR_ALG alg, DNN_Dtype a=0.0, DNN_Dtype b=10.0 )
85  {
86  lr_alg = alg;
87  lr_a = a;
88  lr_b = b;
89  it = 0; // Reset counter
90  }
91 
 // Advance the iteration counter and recompute lr from lr_0 according to
 // the selected schedule; LR_ALG::CONST falls through to default (lr unchanged).
97  void update_learn_rate(void)
98  {
99  it++; // Increase counter
100  switch (lr_alg)
101  {
102  case LR_ALG::TIME_DECAY: // lr_a = time decay
103  lr = lr_0/(1+lr_a*it);
104  break;
105  case LR_ALG::STEP_DECAY: // lr_a = drop factor, lr_b = time interval
106  lr = lr_0*std::pow(lr_a,std::floor(it/lr_b));
107  break;
108  case LR_ALG::EXP_DECAY: // lr_a = dec rate (comment previously said lr_p1 - stale name)
109  lr = lr_0*std::exp(-lr_a*it);
110  break;
111 
112  default:
113  break;
114  }
115  }
116 
 // Returns the current learning rate. The signature line (original line 120,
 // 'DNN_Dtype get_learn_rate(void)' per the cross-reference index) is
 // missing from this extracted listing.
121  {
122  return lr;
123  }
124 
125 }; // End class opt
126 
127 
// ---------------------------------------------------------------------------
// Stochastic Gradient Descent optimizer.
// Plain SGD step with elastic-net regularization applied to the weights:
//   dw = Wgrad + reg_lambda*(reg_alpha*sign(W) + (1-reg_alpha)*W)
//   W -= lr*dw;  B -= lr*Bgrad     (biases are not regularized)
// ---------------------------------------------------------------------------
134 class opt_SGD:public opt
135 {
136 public:
 // Constructor. The signature line (original line 143) is missing from this
 // listing; per the cross-reference index it is:
 //   opt_SGD(DNN_Dtype s, DNN_Dtype l=0.0, DNN_Dtype a=0.0)
 // s = learning rate, l = reg_lambda, a = reg_alpha (0=L2 .. 1=L1).
144  {
145  lr = s;
146  lr_0 = lr;
147  reg_lambda = l;
148  reg_alpha = a;
149  alg = "SGD";
150  };
151 
152  ~opt_SGD() {};
153 
 // Apply one SGD step (see class comment). W and B are updated in place.
159  void apply(arma::Cube<DNN_Dtype>& W,
160  arma::Mat<DNN_Dtype>& B,
161  const arma::Cube<DNN_Dtype>& Wgrad,
162  const arma::Mat<DNN_Dtype>& Bgrad )
163  {
164  // Add regularization
165  const arma::Cube<DNN_Dtype> dw=Wgrad+reg_lambda*(reg_alpha*arma::sign(W)+(1-reg_alpha)*W);
166 
167  // Update params
168  W = W - lr*dw;
169  B = B - lr*Bgrad;
170 
171  // Update learning rate
 // (original line 172 is not shown in this listing; presumably the
 // update_learn_rate() call - TODO confirm against the real header)
173  }
174 
 // Algorithm name plus current learning rate, e.g. "SGD(0.01)".
179  std::string get_algorithm(void)
180  {
181  char p[100];
182  std::snprintf(p,100,"(%g)",lr);
183  return alg+p;
184  }
185 }; // End class opt_SGD
186 
187 
// ---------------------------------------------------------------------------
// Stochastic Gradient Descent with (classical) momentum.
// Keeps velocity buffers v/vB and updates:  v = mom*v - lr*dw;  W += v.
// NOTE: extracted listing - the 'mom' member declaration (original line 199),
// the constructor signature (line 208: opt_SGD_momentum(DNN_Dtype s,
// DNN_Dtype m, DNN_Dtype l=0.0, DNN_Dtype a=0.0) per the cross-reference
// index) and the destructor (line 218) are not shown here.
// ---------------------------------------------------------------------------
194 class opt_SGD_momentum:public opt
195 {
196 private:
 // Velocity buffer for the weights (lazily sized on first apply()).
197  arma::Cube<DNN_Dtype> v;
 // Velocity buffer for the biases.
198  arma::Mat<DNN_Dtype> vB;
200 public:
 // Constructor body: s = learning rate, m = momentum, l = reg_lambda,
 // a = reg_alpha (signature line missing from this listing, see class note).
209  {
210  lr = s;
211  lr_0 = lr;
212  reg_lambda = l;
213  reg_alpha = a;
214  mom = m;
215  alg = "SGD_mom";
216  };
217 
219 
 // Apply one momentum-SGD step; velocity buffers are zero-initialized to the
 // parameter shapes on the first call.
225  void apply(arma::Cube<DNN_Dtype>& W,
226  arma::Mat<DNN_Dtype>& B,
227  const arma::Cube<DNN_Dtype>& Wgrad,
228  const arma::Mat<DNN_Dtype>& Bgrad )
229  {
230  if(v.n_elem==0)
231  {
232  v.set_size(arma::size(W));
233  v.zeros();
234  vB.set_size(arma::size(B));
235  vB.zeros();
236  }
237 
238  // Add regularization
239  const arma::Cube<DNN_Dtype> dw=Wgrad+reg_lambda*(reg_alpha*arma::sign(W)+(1-reg_alpha)*W);
240 
241  // Update params
242  v = mom*v - lr*dw;
243  W = W + v;
244  vB = mom*vB - lr*Bgrad ;
245  B = B + vB;
246 
247  // Update learning rate
 // (original line 248 not shown in this listing)
249  }
250 
 // Algorithm name plus (learning rate, momentum), e.g. "SGD_mom(0.01,0.9)".
256  std::string get_algorithm(void)
257  {
258  char p[100];
259  std::snprintf(p,100,"(%g,%g)",lr,mom);
260  return alg+p;
261  }
262 }; // End class opt_SGD_momentum
263 
264 
// ---------------------------------------------------------------------------
// Stochastic Gradient Descent with Nesterov momentum.
// Uses the common reformulation that keeps the previous velocity:
//   vp = v;  v = mom*v - lr*dw;  W = W - mom*vp + (1+mom)*v
// NOTE: extracted listing - the 'mom' member declaration (original line 278),
// the constructor signature (line 287: opt_SGD_nesterov(DNN_Dtype s,
// DNN_Dtype m, DNN_Dtype l=0.0, DNN_Dtype a=0.0) per the cross-reference
// index) and the destructor (line 297) are not shown here.
// ---------------------------------------------------------------------------
271 class opt_SGD_nesterov:public opt
272 {
273 private:
 // Current and previous velocity for the weights.
274  arma::Cube<DNN_Dtype> v;
275  arma::Cube<DNN_Dtype> vp;
 // Current and previous velocity for the biases.
276  arma::Mat<DNN_Dtype> vB;
277  arma::Mat<DNN_Dtype> vBp;
279 public:
 // Constructor body: s = learning rate, m = momentum, l = reg_lambda,
 // a = reg_alpha (signature line missing from this listing, see class note).
288  {
289  lr = s;
290  lr_0 = lr;
291  reg_lambda = l;
292  reg_alpha = a;
293  mom = m;
294  alg = "SGD_nest";
295  };
296 
298 
 // Apply one Nesterov-momentum step; buffers are zero-initialized on the
 // first call.
304  void apply(arma::Cube<DNN_Dtype>& W,
305  arma::Mat<DNN_Dtype>& B,
306  const arma::Cube<DNN_Dtype>& Wgrad,
307  const arma::Mat<DNN_Dtype>& Bgrad )
308  {
309  if(v.n_elem==0)
310  {
311  v.set_size(arma::size(W));
312  v.zeros();
313  vp.set_size(arma::size(W));
314  vp.zeros();
315  vB.set_size(arma::size(B));
316  vB.zeros();
317  vBp.set_size(arma::size(B));
318  vBp.zeros();
319  }
320 
321  // Add regularization
322  const arma::Cube<DNN_Dtype> dw=Wgrad+reg_lambda*(reg_alpha*arma::sign(W)+(1-reg_alpha)*W);
323 
324  // Update params
325  vp = v;
326  v = mom*v - lr*dw;
327  W = W - mom*vp + (1+mom)*v;
328 
329  vBp = vB;
330  vB = mom*vB - lr*Bgrad ;
331  B = B -mom*vBp +(1+mom)*vB;
332 
333  // Update learning rate
 // (original line 334 not shown in this listing)
335  }
336 
 // Algorithm name plus (learning rate, momentum), e.g. "SGD_nest(0.01,0.9)".
341  std::string get_algorithm(void)
342  {
343  char p[100];
344  std::snprintf(p,100,"(%g,%g)",lr,mom);
345  return alg+p;
346  }
347 }; // End class opt_SGD_nesterov
348 
349 
// ---------------------------------------------------------------------------
// ADAM optimizer (adaptive moment estimation).
// Keeps first-moment (m/mB) and second-moment (v/vB) EMAs of the gradients.
// NOTE: extracted listing - the beta1/beta2/eps member declarations
// (original lines 363-365, see the cross-reference index) are not shown.
// NOTE(review): the step scale lr*sqrt(1-beta2)/(1-beta1) is a constant,
// whereas canonical Adam (Kingma & Ba) uses the per-iteration bias
// correction sqrt(1-beta2^t)/(1-beta1^t). This matches only t=1; confirm
// whether this approximation is intentional.
// ---------------------------------------------------------------------------
356 class opt_adam:public opt
357 {
358 private:
 // Second-moment (squared-gradient EMA) and first-moment (gradient EMA)
 // buffers for weights ...
359  arma::Cube<DNN_Dtype> v;
360  arma::Cube<DNN_Dtype> m;
 // ... and for biases.
361  arma::Mat<DNN_Dtype> vB;
362  arma::Mat<DNN_Dtype> mB;
366 public:
367 
 // s = learning rate, l = reg_lambda, a = reg_alpha, b1/b2 = EMA decay
 // rates for the first/second moments, e = numerical-stability epsilon.
377  opt_adam(DNN_Dtype s, DNN_Dtype l=0.0, DNN_Dtype a=0.0, DNN_Dtype b1=0.9,DNN_Dtype b2=0.999, DNN_Dtype e=1e-8):opt()
378  {
379  lr = s;
380  lr_0 = lr;
381  reg_lambda = l;
382  reg_alpha = a;
383  beta1 = b1;
384  beta2 = b2;
385  eps = e;
386  alg = "ADAM";
387  };
388 
389  ~opt_adam() {};
390 
 // Apply one ADAM step; moment buffers are zero-initialized on first call.
396  void apply(arma::Cube<DNN_Dtype>& W,
397  arma::Mat<DNN_Dtype>& B,
398  const arma::Cube<DNN_Dtype>& Wgrad,
399  const arma::Mat<DNN_Dtype>& Bgrad )
400  {
401  if(v.n_elem==0)
402  {
403  v.set_size(arma::size(W));
404  v.zeros();
405  m.set_size(arma::size(W));
406  m.zeros();
407  vB.set_size(arma::size(B));
408  vB.zeros();
409  mB.set_size(arma::size(B));
410  mB.zeros();
411  }
412 
413  // Add regularization
414  const arma::Cube<DNN_Dtype> dw=Wgrad+reg_lambda*(reg_alpha*arma::sign(W)+(1-reg_alpha)*W);
415 
416  // Update params (element-wise; see NOTE(review) in the class comment
417  // about the constant bias-correction factor)
417  for (arma::uword k=0;k<v.n_elem;k++ )
418  {
419  m(k) = beta1*m(k) +(1-beta1)*dw(k);
420  v(k) = beta2*v(k) +(1-beta2)*dw(k)*dw(k);
421  W(k) = W(k) - lr*(std::sqrt(1-beta2)/(1-beta1))*m(k)/(std::sqrt(v(k))+eps);
422  }
423  for (arma::uword k=0;k<vB.n_elem;k++ )
424  {
425  mB(k) = beta1*mB(k) +(1-beta1)*Bgrad(k);
426  vB(k) = beta2*vB(k) +(1-beta2)*Bgrad(k)*Bgrad(k);
427  B(k) = B(k) - lr*(std::sqrt(1-beta2)/(1-beta1))*mB(k)/(std::sqrt(vB(k))+eps);
428  }
429 
430  // Update learning rate
 // (original line 431 not shown in this listing)
432  }
433 
 // Algorithm name plus current learning rate, e.g. "ADAM(0.001)".
438  std::string get_algorithm(void)
439  {
440  char p[100];
441  std::snprintf(p,100,"(%g)",lr);
442  return alg+p;
443  }
444 }; // End class opt_adam
445 
446 
// ---------------------------------------------------------------------------
// ADAMax optimizer (Adam variant using the infinity norm).
// m/mB are first-moment EMAs; v/vB hold the exponentially-weighted infinity
// norm of the gradients: v = max(beta2*v, |g|).
// NOTE: extracted listing - the beta1/beta2/eps member declarations
// (original lines 460-462) and the destructor (line 486) are not shown.
// NOTE(review): the step scale lr*sqrt(1-beta2)/(1-beta1) is constant;
// canonical AdaMax uses lr/(1-beta1^t) per iteration - confirm intent.
// ---------------------------------------------------------------------------
453 class opt_adamax:public opt
454 {
455 private:
 // Infinity-norm accumulator and first-moment EMA for the weights ...
456  arma::Cube<DNN_Dtype> v;
457  arma::Cube<DNN_Dtype> m;
 // ... and for the biases.
458  arma::Mat<DNN_Dtype> vB;
459  arma::Mat<DNN_Dtype> mB;
463 public:
464 
 // s = learning rate, l = reg_lambda, a = reg_alpha, b1/b2 = decay rates,
 // e = numerical-stability epsilon.
474  opt_adamax(DNN_Dtype s, DNN_Dtype l=0.0, DNN_Dtype a=0.0, DNN_Dtype b1=0.9,DNN_Dtype b2=0.999, DNN_Dtype e=1e-8):opt()
475  {
476  lr = s;
477  lr_0 = lr;
478  reg_lambda = l;
479  reg_alpha = a;
480  beta1 = b1;
481  beta2 = b2;
482  eps = e;
483  alg = "ADAmax";
484  };
485 
487 
 // Apply one ADAMax step; buffers are zero-initialized on first call.
493  void apply(arma::Cube<DNN_Dtype>& W,
494  arma::Mat<DNN_Dtype>& B,
495  const arma::Cube<DNN_Dtype>& Wgrad,
496  const arma::Mat<DNN_Dtype>& Bgrad )
497  {
498  if(v.n_elem==0)
499  {
500  v.set_size(arma::size(W));
501  v.zeros();
502  m.set_size(arma::size(W));
503  m.zeros();
504  vB.set_size(arma::size(B));
505  vB.zeros();
506  mB.set_size(arma::size(B));
507  mB.zeros();
508  }
509 
510  // Add regularization
511  const arma::Cube<DNN_Dtype> dw=Wgrad+reg_lambda*(reg_alpha*arma::sign(W)+(1-reg_alpha)*W);
512 
513  // Update params (v is the exponentially-weighted infinity norm)
514  for (arma::uword k=0;k<v.n_elem;k++ )
515  {
516  m(k) = beta1*m(k) +(1-beta1)*dw(k);
517  v(k) = (beta2*v(k)> std::abs(dw(k))) ? beta2*v(k) : std::abs(dw(k));
518  W(k) = W(k) - lr*(std::sqrt(1-beta2)/(1-beta1))*m(k)/(v(k)+eps);
519  }
520  for (arma::uword k=0;k<vB.n_elem;k++ )
521  {
522  mB(k) = beta1*mB(k) +(1-beta1)*Bgrad(k);
523  vB(k) = (beta2*vB(k)> std::abs(Bgrad(k))) ? beta2*vB(k) : std::abs(Bgrad(k));
524  B(k) = B(k) - lr*(std::sqrt(1-beta2)/(1-beta1))*mB(k)/(vB(k)+eps);
525  }
526 
527  // Update learning rate
 // (original line 528 not shown in this listing)
529  }
530 
 // Algorithm name plus current learning rate, e.g. "ADAmax(0.002)".
536  std::string get_algorithm(void)
537  {
538  char p[100];
539  std::snprintf(p,100,"(%g)",lr);
540  return alg+p;
541  }
542 }; // End class opt_adamax
543 
544 
// ---------------------------------------------------------------------------
// ADAdelta optimizer.
// Ew/Eb are running averages of squared gradients; dW/dB are running averages
// of squared parameter updates. The update rescales each gradient by
// sqrt(E[update^2]+eps)/sqrt(E[grad^2]+eps), optionally scaled by lr
// (lr defaults to 1.0 in the constructor, matching the original ADAdelta).
// NOTE: extracted listing - the rho/eps member declarations (original lines
// 558-559) and the destructor (line 580) are not shown here.
// ---------------------------------------------------------------------------
551 class opt_adadelta:public opt
552 {
553 private:
 // Running average of squared gradients / squared updates for the weights.
554  arma::Cube<DNN_Dtype> Ew;
555  arma::Cube<DNN_Dtype> dW;
 // Same accumulators for the biases.
556  arma::Mat<DNN_Dtype> Eb;
557  arma::Mat<DNN_Dtype> dB;
560 public:
 // r = rho (EMA decay), s = learning-rate scale, l = reg_lambda,
 // a = reg_alpha, e = numerical-stability epsilon.
569  opt_adadelta(DNN_Dtype r, DNN_Dtype s=1.0, DNN_Dtype l=0.0, DNN_Dtype a=0.0, DNN_Dtype e=1e-6):opt()
570  {
571  rho = r;
572  lr = s;
573  lr_0 = s;
574  reg_lambda = l;
575  reg_alpha = a;
576  eps = e;
577  alg = "ADAdelta";
578  };
579 
581 
 // Apply one ADAdelta step; accumulators are zero-initialized on first call.
587  void apply(arma::Cube<DNN_Dtype>& W,
588  arma::Mat<DNN_Dtype>& B,
589  const arma::Cube<DNN_Dtype>& Wgrad,
590  const arma::Mat<DNN_Dtype>& Bgrad )
591  {
592  if(Ew.n_elem==0)
593  {
594  Ew.set_size(arma::size(W));
595  Ew.zeros();
596  dW.set_size(arma::size(W));
597  dW.zeros();
598  Eb.set_size(arma::size(B));
599  Eb.zeros();
600  dB.set_size(arma::size(B));
601  dB.zeros();
602  }
603 
604  // Add regularization
605  const arma::Cube<DNN_Dtype> w=Wgrad+reg_lambda*(reg_alpha*arma::sign(W)+(1-reg_alpha)*W);
606 
607  // Update params
608  for (arma::uword k=0 ;k<W.n_elem ; k++ )
609  {
610  Ew(k) = rho*Ew(k)+(1-rho)*w(k)*w(k);
611 
612  DNN_Dtype upd = lr*w(k)*sqrt(dW(k)+eps)/sqrt(Ew(k)+eps);
613  W(k) = W(k)-upd;
614  dW(k) = rho*dW(k)+(1-rho)*upd*upd;
615  }
616  for (arma::uword k=0 ;k<B.n_elem ; k++ )
617  {
618  Eb(k) = rho*Eb(k)+(1-rho)*Bgrad(k)*Bgrad(k);
619 
 // NOTE(review): stray double semicolon below (harmless empty statement).
620  DNN_Dtype upd = lr*Bgrad(k)*sqrt(dB(k)+eps)/sqrt(Eb(k)+eps);;
621  B(k) = B(k)-upd;
622  dB(k) = rho*dB(k)+(1-rho)*upd*upd;
623  }
624 
625  // Update learning rate
 // (original line 626 not shown in this listing)
627  }
628 
 // Algorithm name plus rho (not lr), e.g. "ADAdelta(0.95)".
633  std::string get_algorithm(void)
634  {
635  char p[100];
636  std::snprintf(p,100,"(%g)",rho);
637  return alg+p;
638  }
639 }; // End class opt_adadelta
640 
641 
// ---------------------------------------------------------------------------
// ADAgrad optimizer.
// Accumulates the sum of squared gradients (v/vB, monotonically growing) and
// divides each step by sqrt(accumulator + eps).
// NOTE: extracted listing - the eps member declaration (original line 653)
// and the constructor signature (line 663: opt_adagrad(DNN_Dtype s,
// DNN_Dtype l=0.0, DNN_Dtype a=0.0, DNN_Dtype e=1e-8) per the
// cross-reference index) are not shown here.
// ---------------------------------------------------------------------------
648 class opt_adagrad:public opt
649 {
650 private:
 // Accumulated squared gradients for weights and biases.
651  arma::Cube<DNN_Dtype> v;
652  arma::Mat<DNN_Dtype> vB;
654 public:
 // Constructor body: s = learning rate, l = reg_lambda, a = reg_alpha,
 // e = epsilon (signature line missing from this listing, see class note).
664  {
665  lr = s;
666  lr_0 = lr;
667  reg_lambda = l;
668  reg_alpha = a;
669  eps = e;
670  alg = "ADAgrad";
671  };
673 
674 
 // Apply one ADAgrad step; accumulators are zero-initialized on first call.
680  void apply(arma::Cube<DNN_Dtype>& W,
681  arma::Mat<DNN_Dtype>& B,
682  const arma::Cube<DNN_Dtype>& Wgrad,
683  const arma::Mat<DNN_Dtype>& Bgrad )
684  {
685  if(v.n_elem==0)
686  {
687  v.set_size(arma::size(W));
688  v.zeros();
689  vB.set_size(arma::size(B));
690  vB.zeros();
691  }
692 
693  // Add regularization
694  const arma::Cube<DNN_Dtype> dw=Wgrad+reg_lambda*(reg_alpha*arma::sign(W)+(1-reg_alpha)*W);
695 
696  // Update params
697  for (arma::uword k=0 ;k<W.n_elem ; k++ )
698  {
699  v(k) = v(k)+dw(k)*dw(k);
700  W(k) = W(k)-lr*dw(k)/std::sqrt(v(k)+eps);
701  }
702  for (arma::uword k=0 ;k<B.n_elem ; k++ )
703  {
704  vB(k) = vB(k)+Bgrad(k)*Bgrad(k);
705  B(k) = B(k)-lr*Bgrad(k)/std::sqrt(vB(k)+eps);
706  }
707 
708  // Update learning rate
 // (original line 709 not shown in this listing)
710  }
711 
 // Algorithm name plus current learning rate, e.g. "ADAgrad(0.01)".
716  std::string get_algorithm(void)
717  {
718  char p[100];
719  std::snprintf(p,100,"(%g)",lr);
720  return alg+p;
721  }
722 }; // End class opt_adagrad
723 
724 
// ---------------------------------------------------------------------------
// RMSprop optimizer.
// Keeps an exponential moving average of squared gradients (v/vB) and divides
// each step by sqrt(EMA + eps).
// NOTE: extracted listing - the beta/eps member declarations (original lines
// 736-737) and the destructor (line 758) are not shown here.
// ---------------------------------------------------------------------------
731 class opt_rmsprop:public opt
732 {
733 private:
 // EMA of squared gradients for weights and biases.
734  arma::Cube<DNN_Dtype> v;
735  arma::Mat<DNN_Dtype> vB;
738 public:
 // s = learning rate, l = reg_lambda, a = reg_alpha, b = EMA decay rate
 // (beta), e = numerical-stability epsilon.
747  opt_rmsprop(DNN_Dtype s, DNN_Dtype l=0.0, DNN_Dtype a=0.0, const DNN_Dtype b=0.9, DNN_Dtype e=1e-8):opt()
748  {
749  lr = s;
750  lr_0 = lr;
751  reg_lambda = l;
752  reg_alpha = a;
753  beta = b;
754  eps = e;
755  alg = "RMSprop";
756  };
757 
759 
 // Apply one RMSprop step; EMA buffers are zero-initialized on first call.
765  void apply(arma::Cube<DNN_Dtype>& W,
766  arma::Mat<DNN_Dtype>& B,
767  const arma::Cube<DNN_Dtype>& Wgrad,
768  const arma::Mat<DNN_Dtype>& Bgrad )
769  {
770  if(v.n_elem==0)
771  {
772  v.set_size(arma::size(W));
773  v.zeros();
774  vB.set_size(arma::size(B));
775  vB.zeros();
776  }
777 
778  // Add regularization
779  const arma::Cube<DNN_Dtype> dw=Wgrad+reg_lambda*(reg_alpha*arma::sign(W)+(1-reg_alpha)*W);
780 
781  // Update params
782  for (arma::uword k=0 ;k<W.n_elem ; k++ )
783  {
784  v(k) = beta*v(k)+(1-beta)*dw(k)*dw(k);
785  W(k) = W(k)-lr*dw(k)/std::sqrt(v(k)+eps);
786  }
787  for (arma::uword k=0 ;k<B.n_elem ; k++ )
788  {
789  vB(k) = beta*vB(k)+(1-beta)*Bgrad(k)*Bgrad(k);
790  B(k) = B(k)-lr*Bgrad(k)/std::sqrt(vB(k)+eps);
791  }
792 
793  // Update learning rate
 // (original line 794 not shown in this listing)
795  }
796 
 // Algorithm name plus (learning rate, beta), e.g. "RMSprop(0.001,0.9)".
801  std::string get_algorithm(void)
802  {
803  char p[100];
804  std::snprintf(p,100,"(%g,%g)",lr,beta);
805  return alg+p;
806  }
807 }; // End class opt_rmsprop
809 } // End namespace dnn
arma::Cube< DNN_Dtype > m
Definition: dnn_opt.h:360
opt()
Definition: dnn_opt.h:44
DNN_Dtype eps
Definition: dnn_opt.h:365
opt_SGD_momentum(DNN_Dtype s, DNN_Dtype m, DNN_Dtype l=0.0, DNN_Dtype a=0.0)
SGD with momentum constructor.
Definition: dnn_opt.h:208
opt_SGD(DNN_Dtype s, DNN_Dtype l=0.0, DNN_Dtype a=0.0)
SGD constructor.
Definition: dnn_opt.h:143
arma::Mat< DNN_Dtype > vB
Definition: dnn_opt.h:361
arma::Cube< DNN_Dtype > v
Definition: dnn_opt.h:734
arma::Mat< DNN_Dtype > mB
Definition: dnn_opt.h:459
arma::Mat< DNN_Dtype > vB
Definition: dnn_opt.h:652
DNN_Dtype eps
Definition: dnn_opt.h:653
opt_adagrad(DNN_Dtype s, DNN_Dtype l=0.0, DNN_Dtype a=0.0, DNN_Dtype e=1e-8)
ADAgrad constructor.
Definition: dnn_opt.h:663
void apply(arma::Cube< DNN_Dtype > &W, arma::Mat< DNN_Dtype > &B, const arma::Cube< DNN_Dtype > &Wgrad, const arma::Mat< DNN_Dtype > &Bgrad)
Apply the optimizer to the layer parameters.
Definition: dnn_opt.h:680
opt_adadelta(DNN_Dtype r, DNN_Dtype s=1.0, DNN_Dtype l=0.0, DNN_Dtype a=0.0, DNN_Dtype e=1e-6)
ADAdelta constructor.
Definition: dnn_opt.h:569
arma::Cube< DNN_Dtype > Ew
Definition: dnn_opt.h:554
void apply(arma::Cube< DNN_Dtype > &W, arma::Mat< DNN_Dtype > &B, const arma::Cube< DNN_Dtype > &Wgrad, const arma::Mat< DNN_Dtype > &Bgrad)
Apply the optimizer to the layer parameters.
Definition: dnn_opt.h:493
DNN_Dtype eps
Definition: dnn_opt.h:737
arma::uword it
Iteration counter.
Definition: dnn_opt.h:42
arma::Cube< DNN_Dtype > v
Velocity internal variable for weight.
Definition: dnn_opt.h:197
std::string get_algorithm(void)
Get the optimizer algorithm information.
Definition: dnn_opt.h:256
DNN_Dtype beta1
Definition: dnn_opt.h:363
DNN_Dtype lr_a
Internal parameter a.
Definition: dnn_opt.h:40
RMSprop optimizer class.
Definition: dnn_opt.h:731
arma::Mat< DNN_Dtype > vB
Velocity internal variable for bias.
Definition: dnn_opt.h:198
arma::Mat< DNN_Dtype > mB
Definition: dnn_opt.h:362
DNN_Dtype rho
Definition: dnn_opt.h:558
DNN_Dtype beta
Definition: dnn_opt.h:736
~opt()
Definition: dnn_opt.h:51
ADAgrad optimizer class.
Definition: dnn_opt.h:648
arma::Cube< DNN_Dtype > v
Definition: dnn_opt.h:359
void apply(arma::Cube< DNN_Dtype > &W, arma::Mat< DNN_Dtype > &B, const arma::Cube< DNN_Dtype > &Wgrad, const arma::Mat< DNN_Dtype > &Bgrad)
Apply the optimizer to the layer parameters.
Definition: dnn_opt.h:225
void apply(arma::Cube< DNN_Dtype > &W, arma::Mat< DNN_Dtype > &B, const arma::Cube< DNN_Dtype > &Wgrad, const arma::Mat< DNN_Dtype > &Bgrad)
Apply the optimizer to the layer parameters.
Definition: dnn_opt.h:159
DNN_Dtype lr
Learning rate.
Definition: dnn_opt.h:35
ADAMax optimizer class.
Definition: dnn_opt.h:453
void apply(arma::Cube< DNN_Dtype > &W, arma::Mat< DNN_Dtype > &B, const arma::Cube< DNN_Dtype > &Wgrad, const arma::Mat< DNN_Dtype > &Bgrad)
Apply the optimizer to the layer parameters.
Definition: dnn_opt.h:304
std::string get_algorithm(void)
Get the optimizer algorithm information.
Definition: dnn_opt.h:633
void apply(arma::Cube< DNN_Dtype > &W, arma::Mat< DNN_Dtype > &B, const arma::Cube< DNN_Dtype > &Wgrad, const arma::Mat< DNN_Dtype > &Bgrad)
Apply the optimizer to the layer parameters.
Definition: dnn_opt.h:396
void set_learn_rate_alg(LR_ALG alg, DNN_Dtype a=0.0, DNN_Dtype b=10.0)
Set learning rate algorithm.
Definition: dnn_opt.h:84
Stochastic Gradient Descent with momentum optimizer class.
Definition: dnn_opt.h:194
opt_adamax(DNN_Dtype s, DNN_Dtype l=0.0, DNN_Dtype a=0.0, DNN_Dtype b1=0.9, DNN_Dtype b2=0.999, DNN_Dtype e=1e-8)
ADAMax constructor.
Definition: dnn_opt.h:474
void apply(arma::Cube< DNN_Dtype > &W, arma::Mat< DNN_Dtype > &B, const arma::Cube< DNN_Dtype > &Wgrad, const arma::Mat< DNN_Dtype > &Bgrad)
Apply the optimizer to the layer parameters.
Definition: dnn_opt.h:765
ADAdelta optimizer class.
Definition: dnn_opt.h:551
DNN_Dtype reg_lambda
Regularisation parameter lambda.
Definition: dnn_opt.h:36
DNN_Dtype get_learn_rate(void)
Get the learning rate.
Definition: dnn_opt.h:120
float DNN_Dtype
Data type used in the network (float or double)
Definition: dnn.h:28
DNN_Dtype eps
Definition: dnn_opt.h:559
Stochastic Gradient Descent optimizer class.
Definition: dnn_opt.h:134
std::string get_algorithm(void)
Get the optimizer algorithm information.
Definition: dnn_opt.h:801
std::string alg
Definition: dnn_opt.h:34
arma::Cube< DNN_Dtype > m
Definition: dnn_opt.h:457
DNN_Dtype beta1
Definition: dnn_opt.h:460
arma::Cube< DNN_Dtype > vp
Definition: dnn_opt.h:275
DNN_Dtype lr_0
Init value for lr.
Definition: dnn_opt.h:39
virtual void apply(arma::Cube< DNN_Dtype > &W, arma::Mat< DNN_Dtype > &B, const arma::Cube< DNN_Dtype > &Wgrad, const arma::Mat< DNN_Dtype > &Bgrad)=0
Apply the optimizer to the layer parameters.
LR_ALG
Definition: dnn.h:35
opt_adam(DNN_Dtype s, DNN_Dtype l=0.0, DNN_Dtype a=0.0, DNN_Dtype b1=0.9, DNN_Dtype b2=0.999, DNN_Dtype e=1e-8)
ADAM constructor.
Definition: dnn_opt.h:377
arma::Cube< DNN_Dtype > v
Definition: dnn_opt.h:651
Definition: dnn.h:22
std::string get_algorithm(void)
Get the optimizer algorithm information.
Definition: dnn_opt.h:341
void apply(arma::Cube< DNN_Dtype > &W, arma::Mat< DNN_Dtype > &B, const arma::Cube< DNN_Dtype > &Wgrad, const arma::Mat< DNN_Dtype > &Bgrad)
Apply the optimizer to the layer parameters.
Definition: dnn_opt.h:587
std::string get_algorithm(void)
Get the optimizer algorithm information.
Definition: dnn_opt.h:438
arma::Cube< DNN_Dtype > dW
Definition: dnn_opt.h:555
LR_ALG lr_alg
Learning rate schedule algorithm.
Definition: dnn_opt.h:38
arma::Mat< DNN_Dtype > vB
Definition: dnn_opt.h:276
std::string get_algorithm(void)
Get the optimizer algorithm information.
Definition: dnn_opt.h:179
virtual std::string get_algorithm(void)
Get the optimizer algorithm information.
Definition: dnn_opt.h:67
std::string get_algorithm(void)
Get the optimizer algorithm information.
Definition: dnn_opt.h:536
DNN_Dtype eps
Definition: dnn_opt.h:462
arma::Cube< DNN_Dtype > v
Definition: dnn_opt.h:456
std::string get_algorithm(void)
Get the optimizer algorithm information.
Definition: dnn_opt.h:716
DNN_Dtype reg_alpha
Elastic net mix parameter - 0=ridge (L2) .. 1=LASSO (L1)
Definition: dnn_opt.h:37
arma::Cube< DNN_Dtype > v
Definition: dnn_opt.h:274
DNN_Dtype lr_b
Internal parameter b.
Definition: dnn_opt.h:41
arma::Mat< DNN_Dtype > Eb
Definition: dnn_opt.h:556
DNN_Dtype beta2
Definition: dnn_opt.h:364
Stochastic Gradient Descent with Nesterov momentum optimizer class.
Definition: dnn_opt.h:271
opt_SGD_nesterov(DNN_Dtype s, DNN_Dtype m, DNN_Dtype l=0.0, DNN_Dtype a=0.0)
SGD with Nesterov momentum constructor.
Definition: dnn_opt.h:287
arma::Mat< DNN_Dtype > dB
Definition: dnn_opt.h:557
DNN_Dtype beta2
Definition: dnn_opt.h:461
opt_rmsprop(DNN_Dtype s, DNN_Dtype l=0.0, DNN_Dtype a=0.0, const DNN_Dtype b=0.9, DNN_Dtype e=1e-8)
RMSprop constructor.
Definition: dnn_opt.h:747
void update_learn_rate(void)
Update learning rate.
Definition: dnn_opt.h:97
ADAM optimizer class.
Definition: dnn_opt.h:356
Optimizer base class.
Definition: dnn_opt.h:31
arma::Mat< DNN_Dtype > vB
Definition: dnn_opt.h:735
arma::Mat< DNN_Dtype > vBp
Definition: dnn_opt.h:277
arma::Mat< DNN_Dtype > vB
Definition: dnn_opt.h:458