DaNNet
dnn::net Class Reference

Neural network model class.

#include <dnn_net.h>

Public Member Functions

 net (void)
 Network constructor.
 
 ~net ()
 
void data_set (arma::Mat< DNN_Dtype > &data, arma::Mat< DNN_Dtype > &label, arma::uword Nmb, arma::uword Nep=1)
 Set a new dataset in the model.
 
void set_batch_len (arma::uword n)
 Set mini batch length.
 
arma::mat get_accuracy (void)
 Get accuracy.
 
arma::mat get_cost (void)
 Get cost.
 
void permute_data (arma::Mat< DNN_Dtype > *Xptr, arma::Mat< DNN_Dtype > *Tptr)
 Permute dataset.
 
void add_layer (layer &lay)
 Add layer to network stack.
 
bool compile (void)
 Compiles network stack.
 
void disp (void)
 Displays a short summary of the network layers.
 
void train (void)
 Trains the network.
 
void test (void)
 Test the network.
 
void show_stat (void)
 Prints the statistics for each EPOCH.
 
arma::Mat< DNN_Dtype > predict (const arma::Mat< DNN_Dtype > &in)
 Performs a forward pass through the network for one input.
 

Private Attributes

std::vector< layer * > netlist
 Ordered list of the network layers.
 
arma::Mat< DNN_Dtype > * Xptr
 Input training data matrix.
 
arma::Mat< DNN_Dtype > * Tptr
 Input target data matrix.
 
arma::uword curEpoch
 Current EPOCH.
 
arma::uword N_epoch
 Number of EPOCHs.
 
arma::uword N_data
 Dataset size.
 
arma::uword N_batch
 Number of samples in a mini-batch.
 
arma::uword N_iter
 Number of iterations.
 
arma::mat cost_val
 Cost logging matrix.
 
arma::mat accuracy
 Accuracy logging matrix.
 
PHASE phase
 Current state/phase.
 

Detailed Description

Neural network model class.

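Implements the network model: a stack of layers together with the dataset, the mini-batch and epoch settings, and the cost and accuracy logs used during training and testing.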

Definition at line 34 of file dnn_net.h.

Constructor & Destructor Documentation

◆ net()

dnn::net::net ( void  )
inline

Network constructor.

Definition at line 52 of file dnn_net.h.

◆ ~net()

dnn::net::~net ( )
inline

Definition at line 62 of file dnn_net.h.

Member Function Documentation

◆ add_layer()

void dnn::net::add_layer ( layer &  lay)
inline

Add layer to network stack.

Parameters
[in] lay  Layer to add

Definition at line 151 of file dnn_net.h.

◆ compile()

bool dnn::net::compile ( void  )
inline

Compiles network stack.

Returns
true if successful. Compiles the network, checking the input/output sizes and the interface sizes between layers.

Definition at line 168 of file dnn_net.h.
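The interface check that compile() performs can be pictured with the sketch below. It is an illustration under stated assumptions, not the DaNNet implementation: LayerShape and check_interfaces are hypothetical names, each layer is reduced to an input size and an output size, and it is assumed that the first layer must match the number of features K and the last layer the number of classes L.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a layer: only its interface sizes matter here.
struct LayerShape {
    std::size_t in_size;   // number of inputs the layer expects
    std::size_t out_size;  // number of outputs the layer produces
};

// Returns true when the stack is non-empty, the first layer accepts
// n_features inputs, each layer's output size equals the next layer's
// input size, and the last layer produces n_classes outputs.
bool check_interfaces(const std::vector<LayerShape>& stack,
                      std::size_t n_features, std::size_t n_classes) {
    if (stack.empty()) return false;
    if (stack.front().in_size != n_features) return false;
    for (std::size_t i = 1; i < stack.size(); ++i) {
        if (stack[i - 1].out_size != stack[i].in_size) return false;
    }
    return stack.back().out_size == n_classes;
}
```

For example, a stack {4 -> 8, 8 -> 3} passes for K = 4 and L = 3, while {4 -> 8, 7 -> 3} is rejected because the 8 outputs of the first layer do not fit the 7 inputs of the second.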

◆ data_set()

void dnn::net::data_set ( arma::Mat< DNN_Dtype > &  data,
arma::Mat< DNN_Dtype > &  label,
arma::uword  Nmb,
arma::uword  Nep = 1 
)
inline

Set a new dataset in the model.

Parameters
[in] data, label  Data and label matrices
[in] Nmb  Mini batch size
[in] Nep  Number of EPOCHs (passes through the dataset)

Sets a new dataset with K features and L classes in the model. The data matrix has shape [K,N] and the label matrix has shape [L,N], where N is the number of samples.

Definition at line 74 of file dnn_net.h.
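Since data is [K,N] and label is [L,N], both matrices must have the same number of columns N, and the mini batch size Nmb fixes how many iterations make up one EPOCH. The sketch below makes that arithmetic concrete. Shape and iterations_per_epoch are hypothetical names, only the matrix dimensions are modeled, and dropping an incomplete final batch (integer division) is an assumption, not documented DaNNet behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>

// Hypothetical holder for matrix dimensions (rows x cols), standing in
// for the sizes of an arma::Mat.
struct Shape {
    std::size_t rows;
    std::size_t cols;
};

// Number of mini-batch iterations per EPOCH for a [K,N] data matrix and an
// [L,N] label matrix. Assumes an incomplete final batch is dropped.
std::size_t iterations_per_epoch(Shape data, Shape label, std::size_t n_mb) {
    if (data.cols != label.cols)
        throw std::invalid_argument("data and label must have the same N");
    if (n_mb == 0 || n_mb > data.cols)
        throw std::invalid_argument("mini batch size must be in [1, N]");
    return data.cols / n_mb;  // integer division drops the remainder
}
```

Under this assumption, 60000 samples with a mini batch size of 100 give 600 iterations per EPOCH, which would also be the number of rows in the matrices returned by get_cost() and get_accuracy().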

◆ disp()

void dnn::net::disp ( void  )
inline

Displays a short summary of the network layers.

Definition at line 195 of file dnn_net.h.

◆ get_accuracy()

arma::mat dnn::net::get_accuracy ( void  )
inline

Get accuracy.

Returns
Accuracy matrix, shape [iterations,epochs]

Definition at line 117 of file dnn_net.h.

◆ get_cost()

arma::mat dnn::net::get_cost ( void  )
inline

Get cost.

Returns
Cost matrix, shape [iterations,epochs]

Definition at line 125 of file dnn_net.h.
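Both get_cost() and get_accuracy() return a matrix of shape [iterations,epochs], so each column holds the log of one EPOCH. A common reduction is one mean value per EPOCH. With Armadillo directly, arma::mean(M, 0) returns the mean of each column. The stdlib-only sketch below does the same on a column-major buffer (the storage order Armadillo uses); per_epoch_mean is a hypothetical helper, not part of DaNNet.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Mean of each column of an [n_iter x n_epoch] matrix stored column-major,
// i.e. one mean value per EPOCH. Element (i, e) lives at e * n_iter + i.
std::vector<double> per_epoch_mean(const std::vector<double>& log_matrix,
                                   std::size_t n_iter, std::size_t n_epoch) {
    assert(log_matrix.size() == n_iter * n_epoch);
    std::vector<double> means(n_epoch, 0.0);
    for (std::size_t e = 0; e < n_epoch; ++e) {
        double sum = 0.0;
        for (std::size_t i = 0; i < n_iter; ++i)
            sum += log_matrix[e * n_iter + i];  // element (i, e)
        means[e] = sum / static_cast<double>(n_iter);
    }
    return means;
}
```

For a 3 x 2 log holding {1, 2, 3} in the first column and {4, 5, 6} in the second, the per-EPOCH means are 2 and 5.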

◆ permute_data()

void dnn::net::permute_data ( arma::Mat< DNN_Dtype > *  Xptr,
arma::Mat< DNN_Dtype > *  Tptr 
)
inline

Permute dataset.

Parameters
[in] Xptr  Pointer to data matrix
[in] Tptr  Pointer to label matrix

Definition at line 134 of file dnn_net.h.
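Because each sample is a column in both the data and the label matrix, a permutation must reorder the columns of both matrices with the same index order so that every sample keeps its label. With Armadillo this could be written as arma::uvec p = arma::randperm(N); followed by X = X.cols(p); and T = T.cols(p);. The stdlib-only sketch below shows the same idea. ColMat and permute_columns are hypothetical names, and whether permute_data() shuffles randomly in exactly this way is an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical column-major matrix: column j holds sample j.
struct ColMat {
    std::size_t rows;
    std::size_t cols;
    std::vector<double> v;
    double& at(std::size_t r, std::size_t c) { return v[c * rows + r]; }
};

// Reorders the columns of X and T with the same random permutation,
// so the pairing between each sample and its label is preserved.
void permute_columns(ColMat& X, ColMat& T, std::mt19937& rng) {
    assert(X.cols == T.cols);
    std::vector<std::size_t> idx(X.cols);
    std::iota(idx.begin(), idx.end(), 0);       // 0, 1, ..., N-1
    std::shuffle(idx.begin(), idx.end(), rng);  // one shared permutation
    auto apply = [&idx](ColMat& M) {
        std::vector<double> out(M.v.size());
        for (std::size_t j = 0; j < M.cols; ++j) {
            auto src = M.v.begin() + static_cast<std::ptrdiff_t>(idx[j] * M.rows);
            auto dst = out.begin() + static_cast<std::ptrdiff_t>(j * M.rows);
            std::copy_n(src, M.rows, dst);      // new column j = old column idx[j]
        }
        M.v.swap(out);
    };
    apply(X);
    apply(T);
}
```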

◆ predict()

arma::Mat<DNN_Dtype> dnn::net::predict ( const arma::Mat< DNN_Dtype > &  in)
inline

Performs a forward pass through the network for one input.

Parameters
[in] in  The network input data
Returns
The network output data

Definition at line 339 of file dnn_net.h.
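Conceptually, predict() feeds the input through the layer stack in order and returns the output of the last layer. The sketch below shows that chaining with a hypothetical layer type (a std::function mapping a vector to a vector); the real dnn::layer interface and its use of arma::Mat are not reproduced here.

```cpp
#include <cassert>
#include <functional>
#include <vector>

using Vec = std::vector<double>;
// Hypothetical layer: maps an input vector to an output vector.
using Layer = std::function<Vec(const Vec&)>;

// Forward pass: each layer's output becomes the next layer's input;
// the output of the last layer is the network output.
Vec forward(const std::vector<Layer>& stack, Vec x) {
    for (const Layer& layer : stack)
        x = layer(x);
    return x;
}
```

With a first layer that doubles every element and a second that sums its two inputs, the input {1, 2} yields the single output 6.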

◆ set_batch_len()

void dnn::net::set_batch_len ( arma::uword  n)
inline

Set mini batch length.

Parameters
[in] n  New mini batch length

Definition at line 109 of file dnn_net.h.

◆ show_stat()

void dnn::net::show_stat ( void  )
inline

Prints the statistics for each EPOCH.

Definition at line 306 of file dnn_net.h.

◆ test()

void dnn::net::test ( void  )
inline

Test the network.

Performs a mini batch forward propagation through all layers. The cost and accuracy are logged as the mean value for each iteration.

Definition at line 277 of file dnn_net.h.
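The accuracy value logged per iteration can be understood as a mean over the mini-batch. The sketch below assumes one-hot labels and defines accuracy as the fraction of samples whose predicted class (the row with the largest network output) equals the target class; this definition and the name batch_accuracy are assumptions, since the exact metric DaNNet uses is not stated here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Index of the largest of n values starting at p (first one wins on ties).
static std::size_t argmax(const double* p, std::size_t n) {
    return static_cast<std::size_t>(std::max_element(p, p + n) - p);
}

// Fraction of samples in a mini-batch whose predicted class (argmax of the
// output column) equals the target class (argmax of the one-hot label
// column). Y and T are column-major [L x n] buffers, one sample per column.
double batch_accuracy(const std::vector<double>& Y, const std::vector<double>& T,
                      std::size_t L, std::size_t n) {
    std::size_t hits = 0;
    for (std::size_t j = 0; j < n; ++j)
        if (argmax(Y.data() + j * L, L) == argmax(T.data() + j * L, L))
            ++hits;
    return static_cast<double>(hits) / static_cast<double>(n);
}
```

For a batch of four two-class samples where three predictions match their labels, this returns 0.75.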

◆ train()

void dnn::net::train ( void  )
inline

Trains the network.

Performs a mini batch forward propagation through all layers, followed by backpropagation of the error and, finally, an update of the layer parameters. The cost and accuracy are logged as the mean value for each iteration.

Definition at line 227 of file dnn_net.h.
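The order of operations in one training iteration (forward through all layers, backpropagation in reverse layer order, then the parameter update) can be sketched as follows. The Layer interface with forward(), backward() and update() is hypothetical and may not match dnn::layer; TraceLayer is a toy layer that only records which step ran, to make the call order visible.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical layer interface; the real dnn::layer API may differ.
struct Layer {
    virtual ~Layer() = default;
    virtual void forward() = 0;   // propagate activations to the next layer
    virtual void backward() = 0;  // propagate the error to the previous layer
    virtual void update() = 0;    // apply the parameter update
};

// One training iteration on a mini-batch: forward through all layers,
// backward in reverse order, then update every layer's parameters.
void train_step(std::vector<Layer*>& stack) {
    for (Layer* l : stack) l->forward();
    for (auto it = stack.rbegin(); it != stack.rend(); ++it) (*it)->backward();
    for (Layer* l : stack) l->update();
}

// Toy layer that appends "F", "B" or "U" plus its name to a shared trace.
struct TraceLayer : Layer {
    std::string name;
    std::vector<std::string>* trace;
    TraceLayer(std::string n, std::vector<std::string>* t)
        : name(std::move(n)), trace(t) {}
    void forward() override { trace->push_back("F" + name); }
    void backward() override { trace->push_back("B" + name); }
    void update() override { trace->push_back("U" + name); }
};
```

For a two-layer stack the trace reads F1, F2, B2, B1, U1, U2: the error flows from the output layer back to the input layer before any parameters change.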

Member Data Documentation

◆ accuracy

arma::mat dnn::net::accuracy
private

Accuracy logging matrix.

Definition at line 46 of file dnn_net.h.

◆ cost_val

arma::mat dnn::net::cost_val
private

Cost logging matrix.

Definition at line 45 of file dnn_net.h.

◆ curEpoch

arma::uword dnn::net::curEpoch
private

Current EPOCH.

Definition at line 40 of file dnn_net.h.

◆ N_batch

arma::uword dnn::net::N_batch
private

Number of samples in a mini-batch.

Definition at line 43 of file dnn_net.h.

◆ N_data

arma::uword dnn::net::N_data
private

Dataset size.

Definition at line 42 of file dnn_net.h.

◆ N_epoch

arma::uword dnn::net::N_epoch
private

Number of EPOCHs.

Definition at line 41 of file dnn_net.h.

◆ N_iter

arma::uword dnn::net::N_iter
private

Number of iterations.

Definition at line 44 of file dnn_net.h.

◆ netlist

std::vector<layer*> dnn::net::netlist
private

Ordered list of the network layers.

Definition at line 37 of file dnn_net.h.

◆ phase

PHASE dnn::net::phase
private

Current state/phase.

Definition at line 47 of file dnn_net.h.

◆ Tptr

arma::Mat<DNN_Dtype>* dnn::net::Tptr
private

Input target data matrix.

Definition at line 39 of file dnn_net.h.

◆ Xptr

arma::Mat<DNN_Dtype>* dnn::net::Xptr
private

Input training data matrix.

Definition at line 38 of file dnn_net.h.


The documentation for this class was generated from the following file: dnn_net.h