18 high_resolution_clock::time_point
tic,
toc;
38 arma::Mat<DNN_Dtype>*
Xptr;
39 arma::Mat<DNN_Dtype>*
Tptr;
74 void data_set( arma::Mat<DNN_Dtype>& data,arma::Mat<DNN_Dtype>& label,arma::uword Nmb,arma::uword Nep=1)
82 N_iter =
static_cast<arma::uword
>(floor(N_data/N_batch));
84 std::cout << std::endl;
85 std::cout <<
"Nr of data: " << N_data << std::endl;
86 std::cout <<
"Minibatch size: " << N_batch << std::endl;
87 std::cout <<
"Iterations/EPOCH: " << N_iter << std::endl;
88 std::cout <<
"EPOCHs: " << N_epoch << std::endl;
94 for(
size_t l=0; l<netlist.size(); l++)
96 netlist.at(l)->upd_buf_size(N_batch);
100 cost_val.set_size(N_iter,N_epoch);
102 accuracy.set_size(N_iter,N_epoch);
134 void permute_data(arma::Mat<DNN_Dtype>* Xptr,arma::Mat<DNN_Dtype>* Tptr)
136 arma::Mat<DNN_Dtype> Xc(*Xptr);
137 arma::Mat<DNN_Dtype> Tc(*Tptr);
138 arma::uvec ix0 = arma::regspace<arma::uvec>(0,Xc.n_cols-1);
139 arma::uvec ix = arma::shuffle(ix0);
141 for(arma::uword c = 0; c<Xptr->n_cols; c++)
143 (*Xptr).col(c) = Xc.col(ix(c));
144 (*Tptr).col(c) = Tc.col(ix(c));
156 netlist.back()->set_right(&lay);
158 lay.
set_ix(netlist.size());
160 netlist.push_back(&lay);
171 std::string first = netlist.front()->get_type();
172 std::string last = netlist.back()->get_type();
173 if(first.compare(
"Input") || last.compare(
"Output"))
175 std::cout <<
"Net must have input and output!" << std::endl;
180 for(
size_t l=1; l<netlist.size(); l++)
182 arma::uword n = netlist.at(l)->get_nrof_inputs();
183 arma::uword m = netlist.at(l-1)->get_nrof_outputs();
186 std::cout <<
"Wrong dimensions: " << n <<
" " << m << std::endl;
198 init.copyfmt(std::cout);
200 std::cout << std::left <<
"\n " 201 << std::setw(6) <<
"Layer" 202 << std::setw(23) <<
"TypeId" 203 << std::setw(30) <<
"Optimizer" 204 << std::setw(12) <<
"Shape" 205 << std::setw(10) <<
"Params" 209 std::cout.copyfmt(init);
211 std::cout <<
"================================================================================" << std::endl;
214 std::cout << *it << std::endl;
216 std::cout <<
"================================================================================" << std::endl;
230 permute_data(Xptr,Tptr);
233 phase = PHASE::TRAIN;
234 for(
size_t l=0; l<netlist.size(); l++)
236 netlist.at(l)->set_phase(PHASE::TRAIN);
239 tic = high_resolution_clock::now();
241 for (
unsigned int n=0; n<N_iter; n++)
244 for(
size_t l=0; l<netlist.size(); l++)
246 netlist.at(l)->prop_mb();
249 cost_val.at(n,curEpoch) = ((
layer_cost*)netlist.back())->get_cost();
250 accuracy.at(n,curEpoch) = ((
layer_cost*)netlist.back())->get_acc();
253 for(
size_t l=netlist.size()-1; l>1; l--)
255 netlist.at(l)->backprop();
259 for(
size_t l=0; l<netlist.size(); l++)
261 netlist.at(l)->update();
267 toc = high_resolution_clock::now();
281 for(
size_t l=0; l<netlist.size(); l++)
283 netlist.at(l)->set_phase(PHASE::TEST);
286 tic = high_resolution_clock::now();
288 for (
unsigned int n=0; n<N_iter; n++)
291 for(
size_t l=0; l<netlist.size(); l++)
293 netlist.at(l)->prop_mb();
295 cost_val.at(n,curEpoch) = ((
layer_cost*)netlist.back())->get_cost();
296 accuracy.at(n,curEpoch) = ((
layer_cost*)netlist.back())->get_acc();
300 toc = high_resolution_clock::now();
309 init.copyfmt(std::cout);
312 std::cout <<
"---------------------------------------------------------" << std::endl;
313 std::cout << std::left <<
" " 314 << std::setw(8) <<
"EPOCH" 315 << std::setw(15) <<
"Cost" 316 << std::setw(15) <<
"Accuracy [%]" 317 << std::setw(15) <<
"Time [ms]" 320 std::cout.copyfmt(init);
321 std::cout <<
"---------------------------------------------------------" << std::endl;
323 std::cout << std::left <<
" " 324 << std::setw(8) << curEpoch
325 << std::setw(15) << arma::mean(cost_val.col(curEpoch-1))
326 << std::setw(15) << 100*arma::mean(accuracy.col(curEpoch-1))
327 << std::setw(15) << duration_cast<milliseconds>(
toc -
tic).count()
331 std::cout.copyfmt(init);
339 arma::Mat<DNN_Dtype>
predict(
const arma::Mat<DNN_Dtype>& in)
341 for(
size_t l=0; l<netlist.size(); l++)
343 netlist.at(l)->set_phase(PHASE::PRED);
347 ((
layer_input*)netlist.front())->reset_batch_ctr();
348 ((
layer_cost*)netlist.back())->reset_batch_ctr();
350 for(
size_t l=0; l<netlist.size(); l++)
352 netlist.at(l)->prop();
354 return ((
layer_cost*)netlist.back())->get_Y1();
virtual void set_ix(const arma::uword n)
Set layer index.
void remove_progress_bar(void)
Clears progress bar.
arma::Mat< DNN_Dtype > * Xptr
Input training data matrix.
arma::uword N_batch
Number of data in mini-batch.
PHASE phase
Current state/phase.
void add_layer(layer &lay)
Add layer to network stack.
high_resolution_clock::time_point tic
Cost/output layer base class.
arma::uword curEpoch
Current EPOCH.
void show_stat(void)
Prints the statistics for each EPOCH.
std::vector< layer * > netlist
Chained list of network layers.
bool compile(void)
Compiles network stack.
void data_set(arma::Mat< DNN_Dtype > &data, arma::Mat< DNN_Dtype > &label, arma::uword Nmb, arma::uword Nep=1)
Set a new dataset in the model.
arma::mat get_accuracy(void)
Get accuracy.
virtual void init(void)
Initialize layer.
arma::mat accuracy
Accuracy logging matrix.
arma::uword N_data
Dataset size.
high_resolution_clock::time_point toc
arma::Mat< DNN_Dtype > * Tptr
Input target data matrix.
arma::mat get_cost(void)
Get cost.
arma::uword N_iter
Number of iterations.
void set_batch_len(arma::uword n)
Set mini batch length.
void test(void)
Test the network.
void disp(void)
Displays a short info about the network layers.
arma::mat cost_val
Cost logging matrix.
void progress_bar(const std::string str, double p)
Console progress bar.
void train(void)
Trains the network.
arma::uword N_epoch
Number of EPOCHs.
virtual void set_left(layer *lptr)
Set pointer to left layer.
Neural network model class.
void permute_data(arma::Mat< DNN_Dtype > *Xptr, arma::Mat< DNN_Dtype > *Tptr)
Permute dataset.
net(void)
Network constructor.
arma::Mat< DNN_Dtype > predict(const arma::Mat< DNN_Dtype > &in)
A forward pass in the network for one input.