Neural network model class.
#include <dnn_net.h>
Implements the network model.
Definition at line 34 of file dnn_net.h.
◆ net()
Network constructor.
Definition at line 52 of file dnn_net.h.
◆ ~net()
◆ add_layer()
void dnn::net::add_layer(layer &lay) [inline]
Add layer to network stack.
- Parameters
  [in] lay: Layer to add to the network stack
Definition at line 151 of file dnn_net.h.
◆ compile()
bool dnn::net::compile(void) [inline]
Compiles network stack.
- Returns
- true if successful.
Compiles the network, checking the input/output sizes and the interface sizes between layers.
Definition at line 168 of file dnn_net.h.
◆ data_set()
void dnn::net::data_set(arma::Mat<DNN_Dtype> &data, arma::Mat<DNN_Dtype> &label, arma::uword Nmb, arma::uword Nep = 1) [inline]
Set a new dataset in the model.
- Parameters
  [in] data, label: Data and label matrices
  [in] Nmb: Mini-batch size
  [in] Nep: Number of EPOCHs (number of passes through the dataset)
Sets a new dataset with K features and L classes in the model. The data matrix has shape [K,N] and the label matrix has shape [L,N], where N is the number of samples.
Definition at line 74 of file dnn_net.h.
◆ disp()
void dnn::net::disp(void) [inline]
Displays a short info about the network layers.
Definition at line 195 of file dnn_net.h.
◆ get_accuracy()
arma::mat dnn::net::get_accuracy(void) [inline]
Get accuracy.
- Returns
- Accuracy matrix, shape [iterations,epochs]
Definition at line 117 of file dnn_net.h.
◆ get_cost()
arma::mat dnn::net::get_cost(void) [inline]
Get cost.
- Returns
- Cost matrix, shape [iterations,epochs]
Definition at line 125 of file dnn_net.h.
◆ permute_data()
void dnn::net::permute_data(arma::Mat<DNN_Dtype> *Xptr, arma::Mat<DNN_Dtype> *Tptr) [inline]
Permute dataset.
- Parameters
  [in] Xptr: Pointer to the data matrix
  [in] Tptr: Pointer to the label matrix
Definition at line 134 of file dnn_net.h.
◆ predict()
A forward pass in the network for one input.
- Parameters
  [in] in: The network input data
- Returns
- The network output data
Definition at line 339 of file dnn_net.h.
◆ set_batch_len()
void dnn::net::set_batch_len(arma::uword n) [inline]
Set mini batch length.
- Parameters
  [in] n: New mini-batch length
Definition at line 109 of file dnn_net.h.
◆ show_stat()
void dnn::net::show_stat(void) [inline]
Prints the statistics for each EPOCH.
Definition at line 306 of file dnn_net.h.
◆ test()
void dnn::net::test(void) [inline]
Test the network.
Performs a mini-batch forward propagation through all layers. For each iteration, the cost and accuracy are logged as the mean over the mini-batch.
Definition at line 277 of file dnn_net.h.
◆ train()
void dnn::net::train(void) [inline]
Trains the network.
Performs a mini-batch forward propagation through all layers, followed by backpropagation of the error and an update of the layer parameters. For each iteration, the cost and accuracy are logged as the mean over the mini-batch.
Definition at line 227 of file dnn_net.h.
◆ accuracy
arma::mat dnn::net::accuracy [private]
Accuracy logging matrix.
Definition at line 46 of file dnn_net.h.
◆ cost_val
arma::mat dnn::net::cost_val [private]
Cost logging matrix.
Definition at line 45 of file dnn_net.h.
◆ curEpoch
arma::uword dnn::net::curEpoch [private]
Current EPOCH.
Definition at line 40 of file dnn_net.h.
◆ N_batch
arma::uword dnn::net::N_batch [private]
Number of samples in a mini-batch.
Definition at line 43 of file dnn_net.h.
◆ N_data
arma::uword dnn::net::N_data [private]
◆ N_epoch
arma::uword dnn::net::N_epoch [private]
Number of EPOCHs.
Definition at line 41 of file dnn_net.h.
◆ N_iter
arma::uword dnn::net::N_iter [private]
Number of iterations.
Definition at line 44 of file dnn_net.h.
◆ netlist
std::vector<layer*> dnn::net::netlist [private]
Ordered list of pointers to the chained network layers.
Definition at line 37 of file dnn_net.h.
◆ phase
Current state/phase.
Definition at line 47 of file dnn_net.h.
◆ Tptr
Input target data matrix.
Definition at line 39 of file dnn_net.h.
◆ Xptr
Input training data matrix.
Definition at line 38 of file dnn_net.h.
The documentation for this class was generated from the following file: dnn_net.h