DaNNet
dnn::opt Class Reference (abstract)

Optimizer base class.

#include <dnn_opt.h>

Inheritance diagram for dnn::opt:
dnn::opt_adadelta, dnn::opt_adagrad, dnn::opt_adam, dnn::opt_adamax, dnn::opt_rmsprop, dnn::opt_SGD, dnn::opt_SGD_momentum, dnn::opt_SGD_nesterov

Public Member Functions

 opt ()
 
 ~opt ()
 
virtual void apply (arma::Cube< DNN_Dtype > &W, arma::Mat< DNN_Dtype > &B, const arma::Cube< DNN_Dtype > &Wgrad, const arma::Mat< DNN_Dtype > &Bgrad)=0
 Apply the optimizer to the layer parameters.
 
virtual std::string get_algorithm (void)
 Get the optimizer algorithm information.
 
void set_learn_rate_alg (LR_ALG alg, DNN_Dtype a=0.0, DNN_Dtype b=10.0)
 Set learning rate algorithm.
 
void update_learn_rate (void)
 Update learning rate.
 
DNN_Dtype get_learn_rate (void)
 Get the learning rate.
 

Protected Attributes

std::string alg
 
DNN_Dtype lr
 Learning rate.
 
DNN_Dtype reg_lambda
 Regularisation parameter lambda.
 
DNN_Dtype reg_alpha
 Elastic net mixing parameter, from 0 (ridge, L2) to 1 (LASSO, L1).
 
LR_ALG lr_alg
 Learning rate schedule algorithm.
 
DNN_Dtype lr_0
 Initial value for lr.
 
DNN_Dtype lr_a
 Internal parameter a.
 
DNN_Dtype lr_b
 Internal parameter b.
 
arma::uword it
 Iteration counter.
 

Detailed Description

Optimizer base class.

Base class for optimizers that minimize the cost function with respect to the layer's trainable parameters.

Definition at line 31 of file dnn_opt.h.
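The base class fixes the interface (a pure virtual apply() plus shared learning rate state) and each derived optimizer supplies the update rule. The sketch below illustrates that pattern without Armadillo: Param (a std::vector<float> standing in for arma::Cube/arma::Mat and DNN_Dtype), opt_sketch and sgd_sketch are hypothetical names, and the code is not taken from DaNNet. The sketch declares a virtual destructor so derived optimizers can be deleted safely through a base pointer.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in so the sketch builds without Armadillo:
// std::vector<float> replaces arma::Cube/arma::Mat, float replaces DNN_Dtype.
using Param = std::vector<float>;

// Abstract interface modelled on dnn::opt (a sketch, not DaNNet source).
class opt_sketch {
public:
    virtual ~opt_sketch() = default;
    // Every concrete optimizer decides how gradients update W and B.
    virtual void apply(Param& W, Param& B, const Param& Wgrad, const Param& Bgrad) = 0;
    virtual std::string get_algorithm() { return alg; }

protected:
    std::string alg;   // algorithm description
    float lr = 0.01f;  // learning rate
};

// Plain gradient descent: p <- p - lr * grad for every parameter.
class sgd_sketch : public opt_sketch {
public:
    explicit sgd_sketch(float rate) {
        lr = rate;
        alg = "SGD";
    }
    void apply(Param& W, Param& B, const Param& Wgrad, const Param& Bgrad) override {
        assert(W.size() == Wgrad.size() && B.size() == Bgrad.size());
        for (std::size_t i = 0; i < W.size(); ++i) W[i] -= lr * Wgrad[i];
        for (std::size_t i = 0; i < B.size(); ++i) B[i] -= lr * Bgrad[i];
    }
};
```

Training code can then hold any optimizer through an opt_sketch reference or pointer and call apply() once per update step without knowing which algorithm is behind it.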

Constructor & Destructor Documentation

◆ opt()

dnn::opt::opt ( )
inline

Definition at line 44 of file dnn_opt.h.

◆ ~opt()

dnn::opt::~opt ( )
inline

Definition at line 51 of file dnn_opt.h.

Member Function Documentation

◆ apply()

virtual void dnn::opt::apply ( arma::Cube< DNN_Dtype > &  W,
arma::Mat< DNN_Dtype > &  B,
const arma::Cube< DNN_Dtype > &  Wgrad,
const arma::Mat< DNN_Dtype > &  Bgrad 
)
pure virtual

Apply the optimizer to the layer parameters.

Parameters
[in,out] W, B  Learnable parameters
[in] Wgrad, Bgrad  Gradients of the learnable parameters

Implemented in dnn::opt_rmsprop, dnn::opt_adagrad, dnn::opt_adadelta, dnn::opt_adamax, dnn::opt_adam, dnn::opt_SGD_nesterov, dnn::opt_SGD_momentum, and dnn::opt_SGD.
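Because apply() is called repeatedly, a derived optimizer typically keeps per-parameter state between calls. As an illustration, here is a minimal sketch of textbook classical momentum (v <- mu*v - lr*g, p <- p + v) with separate velocity buffers for W and B. Whether dnn::opt_SGD_momentum uses exactly this formulation is not stated on this page; Param, momentum_state, momentum_sketch and the coefficient mu are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Param = std::vector<float>;  // hypothetical stand-in for arma::Cube/arma::Mat

// Classical momentum for one parameter array:
//   v <- mu * v - lr * g
//   p <- p + v
// The velocity v is optimizer state that persists across apply() calls.
struct momentum_state {
    float lr;  // learning rate
    float mu;  // momentum coefficient (hypothetical name)
    Param v;   // velocity, sized lazily to match the parameters

    void step(Param& p, const Param& g) {
        assert(p.size() == g.size());
        if (v.size() != p.size()) v.assign(p.size(), 0.0f);
        for (std::size_t i = 0; i < p.size(); ++i) {
            v[i] = mu * v[i] - lr * g[i];
            p[i] += v[i];
        }
    }
};

// apply() with the same parameter list as dnn::opt::apply (stand-in types).
class momentum_sketch {
public:
    momentum_sketch(float lr, float mu) : w_state{lr, mu, {}}, b_state{lr, mu, {}} {}

    void apply(Param& W, Param& B, const Param& Wgrad, const Param& Bgrad) {
        w_state.step(W, Wgrad);  // weights and biases keep separate velocities
        b_state.step(B, Bgrad);
    }

private:
    momentum_state w_state, b_state;
};
```

With a constant gradient the velocity grows over successive calls, so the second step moves further than the first; this accumulated state is what distinguishes the derived optimizers from plain SGD.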

◆ get_algorithm()

virtual std::string dnn::opt::get_algorithm ( void  )
inline virtual

Get the optimizer algorithm information.

Returns
Algorithm information string

Reimplemented in dnn::opt_rmsprop, dnn::opt_adagrad, dnn::opt_adadelta, dnn::opt_adamax, dnn::opt_adam, dnn::opt_SGD_nesterov, dnn::opt_SGD_momentum, and dnn::opt_SGD.

Definition at line 67 of file dnn_opt.h.

◆ get_learn_rate()

DNN_Dtype dnn::opt::get_learn_rate ( void  )
inline

Get the learning rate.

Definition at line 120 of file dnn_opt.h.

◆ set_learn_rate_alg()

void dnn::opt::set_learn_rate_alg ( LR_ALG  alg,
DNN_Dtype  a = 0.0,
DNN_Dtype  b = 10.0 
)
inline

Set learning rate algorithm.

Parameters
[in] alg  Learning rate schedule algorithm
[in] a  Parameter a
[in] b  Parameter b

Sets the learning rate schedule algorithm and its parameters a and b, where t is the time step (iteration counter it):
CONST: constant learning rate, lr = lr_0
TIME_DECAY: time-based decay, lr = lr_0/(1 + a*t)
STEP_DECAY: step decay, lr = lr_0*a^floor(t/b)
EXP_DECAY: exponential decay, lr = lr_0*exp(-a*t)

Definition at line 84 of file dnn_opt.h.
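The four schedules can be evaluated directly from their formulas. The sketch below is a free function, not DaNNet's implementation: the enum is a hypothetical mirror of the LR_ALG names listed above, and double stands in for DNN_Dtype.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical mirror of the LR_ALG enumerators named in the documentation.
enum class LR_ALG { CONST, TIME_DECAY, STEP_DECAY, EXP_DECAY };

// Learning rate at time step t for each documented schedule.
// lr0 = initial rate (lr_0), a and b = schedule parameters.
double scheduled_lr(LR_ALG alg, double lr0, double a, double b, double t) {
    switch (alg) {
        case LR_ALG::TIME_DECAY:
            return lr0 / (1.0 + a * t);                   // lr_0/(1 + a*t)
        case LR_ALG::STEP_DECAY:
            return lr0 * std::pow(a, std::floor(t / b));  // lr_0*a^floor(t/b)
        case LR_ALG::EXP_DECAY:
            return lr0 * std::exp(-a * t);                // lr_0*exp(-a*t)
        case LR_ALG::CONST:
        default:
            return lr0;                                   // lr_0
    }
}
```

Working through the formulas shows why a should normally be set explicitly: with the default a = 0.0, TIME_DECAY and EXP_DECAY both reduce to lr = lr_0, and STEP_DECAY drops to zero once t reaches b. For example, with lr_0 = 0.1, a = 0.5 and b = 10, STEP_DECAY gives 0.1 for t < 10, 0.05 for 10 <= t < 20 and 0.025 for 20 <= t < 30.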

◆ update_learn_rate()

void dnn::opt::update_learn_rate ( void  )
inline

Update learning rate.

Updates the learning rate (lr) according to the schedule selected with set_learn_rate_alg().

Definition at line 97 of file dnn_opt.h.
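update_learn_rate() and get_learn_rate() work on the protected members lr, lr_0, lr_a and it. The hypothetical class below sketches how that state could fit together for the TIME_DECAY rule only; it assumes the iteration counter is advanced before lr is recomputed, which this page does not confirm, and uses double in place of DNN_Dtype.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical sketch of the state behind update_learn_rate()/get_learn_rate(),
// reusing the protected member names documented for dnn::opt.
// Only the TIME_DECAY rule is shown; other schedules would slot into the update.
class lr_schedule_sketch {
public:
    lr_schedule_sketch(double lr0, double a) : lr(lr0), lr_0(lr0), lr_a(a) {}

    // Assumption: the counter advances first, then lr = lr_0/(1 + a*it).
    void update_learn_rate() {
        ++it;
        lr = lr_0 / (1.0 + lr_a * static_cast<double>(it));
    }

    double get_learn_rate() const { return lr; }

private:
    double lr;             // current learning rate
    double lr_0;           // initial learning rate
    double lr_a;           // schedule parameter a
    unsigned long it = 0;  // iteration counter
};
```

Whether the caller invokes update_learn_rate() once per batch or once per epoch determines what t counts in the schedule formulas; this page does not say which convention DaNNet uses.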

Member Data Documentation

◆ alg

std::string dnn::opt::alg
protected

Definition at line 34 of file dnn_opt.h.

◆ it

arma::uword dnn::opt::it
protected

Iteration counter.

Definition at line 42 of file dnn_opt.h.

◆ lr

DNN_Dtype dnn::opt::lr
protected

Learning rate.

Definition at line 35 of file dnn_opt.h.

◆ lr_0

DNN_Dtype dnn::opt::lr_0
protected

Initial value for lr.

Definition at line 39 of file dnn_opt.h.

◆ lr_a

DNN_Dtype dnn::opt::lr_a
protected

Internal parameter a.

Definition at line 40 of file dnn_opt.h.

◆ lr_alg

LR_ALG dnn::opt::lr_alg
protected

Learning rate schedule algorithm.

Definition at line 38 of file dnn_opt.h.

◆ lr_b

DNN_Dtype dnn::opt::lr_b
protected

Internal parameter b.

Definition at line 41 of file dnn_opt.h.

◆ reg_alpha

DNN_Dtype dnn::opt::reg_alpha
protected

Elastic net mixing parameter, from 0 (ridge, L2) to 1 (LASSO, L1).

Definition at line 37 of file dnn_opt.h.

◆ reg_lambda

DNN_Dtype dnn::opt::reg_lambda
protected

Regularisation parameter lambda.

Definition at line 36 of file dnn_opt.h.
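reg_lambda and reg_alpha together describe an elastic net penalty. As a sketch, assuming the common form R(w) = lambda*(alpha*|w| + (1 - alpha)/2*w^2), the penalty adds lambda*(alpha*sign(w) + (1 - alpha)*w) to each weight's gradient. The exact scaling DaNNet applies (for example whether the L2 term carries the factor 1/2) is not stated here, and add_elastic_net_grad is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Adds the gradient of an elastic net penalty (assumed form)
//   R(w) = lambda * ( alpha * |w| + (1 - alpha) / 2 * w^2 )
// to an existing loss gradient, element by element.
// alpha = 0 gives pure ridge (L2), alpha = 1 gives pure LASSO (L1).
// The subgradient of |w| at w = 0 is taken as 0.
void add_elastic_net_grad(std::vector<float>& grad, const std::vector<float>& w,
                          float lambda, float alpha) {
    assert(grad.size() == w.size());
    for (std::size_t i = 0; i < w.size(); ++i) {
        const float sign = (w[i] > 0.0f) ? 1.0f : (w[i] < 0.0f ? -1.0f : 0.0f);
        grad[i] += lambda * (alpha * sign + (1.0f - alpha) * w[i]);
    }
}
```

At alpha = 0 the added term is proportional to the weight itself (weight decay), while at alpha = 1 it has constant magnitude lambda, which is what pushes small weights towards exactly zero under LASSO.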


The documentation for this class was generated from the following file: