DaNNet
dnn_net.h
Go to the documentation of this file.
1 // Copyright 2019 Claes Rolen (www.rolensystems.com)
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #pragma once
16 #include <chrono>
// Pulls the <chrono> names used further down (high_resolution_clock,
// duration_cast, milliseconds) into the global namespace.
// NOTE(review): a using-directive at header scope is visible to every file that
// includes this header.
using namespace std::chrono;

// Time stamps taken right before and right after the iteration loop of the most
// recent train()/test() call; show_stat() prints their difference in ms.
// NOTE(review): these are plain namespace-scope definitions inside a header, so
// including the header from more than one translation unit defines them more
// than once - confirm the header is only pulled into a single source file.
high_resolution_clock::time_point tic;
high_resolution_clock::time_point toc;
19 
20 namespace dnn
21 {
27 
34 class net
35 {
36 private:
37  std::vector<layer*> netlist;
38  arma::Mat<DNN_Dtype>* Xptr;
39  arma::Mat<DNN_Dtype>* Tptr;
40  arma::uword curEpoch;
41  arma::uword N_epoch;
42  arma::uword N_data;
43  arma::uword N_batch;
44  arma::uword N_iter;
45  arma::mat cost_val;
46  arma::mat accuracy;
48 public:
52  net(void)
53  {
54  N_epoch = 0;
55  curEpoch = 0;
56  N_batch = 1;
57  N_data = 0;
58  N_iter = 0;
59  phase = PHASE::TRAIN;
60  };
61 
62  ~net() {};
63 
74  void data_set( arma::Mat<DNN_Dtype>& data,arma::Mat<DNN_Dtype>& label,arma::uword Nmb,arma::uword Nep=1)
75  {
76  N_batch = Nmb;
77  N_epoch = Nep;
78  curEpoch = 0;
79  Xptr = &data;
80  Tptr = &label;
81  N_data = data.n_cols;
82  N_iter = static_cast<arma::uword>(floor(N_data/N_batch));
83 
84  std::cout << std::endl;
85  std::cout << "Nr of data: " << N_data << std::endl;
86  std::cout << "Minibatch size: " << N_batch << std::endl;
87  std::cout << "Iterations/EPOCH: " << N_iter << std::endl;
88  std::cout << "EPOCHs: " << N_epoch << std::endl;
89 
90  ((layer_input*)netlist.front())->set_data(Xptr);
91  ((layer_cost*)netlist.back())->set_data(Tptr);
92 
93  // Update buffer size and clear
94  for(size_t l=0; l<netlist.size(); l++)
95  {
96  netlist.at(l)->upd_buf_size(N_batch);
97  }
98 
99  // Clear statistics
100  cost_val.set_size(N_iter,N_epoch);
101  cost_val.zeros();
102  accuracy.set_size(N_iter,N_epoch);
103  accuracy.zeros();
104  }
105 
109  void set_batch_len(arma::uword n)
110  {
111  N_batch = n;
112  }
113 
117  arma::mat get_accuracy(void)
118  {
119  return accuracy;
120  }
121 
125  arma::mat get_cost(void)
126  {
127  return cost_val;
128  }
129 
134  void permute_data(arma::Mat<DNN_Dtype>* Xptr,arma::Mat<DNN_Dtype>* Tptr)
135  {
136  arma::Mat<DNN_Dtype> Xc(*Xptr);
137  arma::Mat<DNN_Dtype> Tc(*Tptr);
138  arma::uvec ix0 = arma::regspace<arma::uvec>(0,Xc.n_cols-1);
139  arma::uvec ix = arma::shuffle(ix0);
140 
141  for(arma::uword c = 0; c<Xptr->n_cols; c++)
142  {
143  (*Xptr).col(c) = Xc.col(ix(c));
144  (*Tptr).col(c) = Tc.col(ix(c));
145  }
146  }
147 
151  void add_layer(layer& lay)
152  {
153  if(!netlist.empty())
154  {
155  lay.set_left(netlist.back());
156  netlist.back()->set_right(&lay);
157  }
158  lay.set_ix(netlist.size());
159  lay.init();
160  netlist.push_back(&lay);
161  }
162 
168  bool compile(void)
169  {
170  // Check in and outputs
171  std::string first = netlist.front()->get_type();
172  std::string last = netlist.back()->get_type();
173  if(first.compare("Input") || last.compare("Output"))
174  {
175  std::cout << "Net must have input and output!" << std::endl;
176  return false;
177  }
178 
179  // Check dimensions
180  for(size_t l=1; l<netlist.size(); l++)
181  {
182  arma::uword n = netlist.at(l)->get_nrof_inputs();
183  arma::uword m = netlist.at(l-1)->get_nrof_outputs();
184  if(n != m)
185  {
186  std::cout << "Wrong dimensions: " << n << " " << m << std::endl;
187  return false;
188  }
189  }
190  return true;
191  }
192 
195  void disp(void)
196  {
197  std::ios init(NULL);
198  init.copyfmt(std::cout);
199 
200  std::cout << std::left << "\n "
201  << std::setw(6) << "Layer"
202  << std::setw(23) << "TypeId"
203  << std::setw(30) << "Optimizer"
204  << std::setw(12) << "Shape"
205  << std::setw(10) << "Params"
206  << std::endl;
207 
208  // restore default formatting
209  std::cout.copyfmt(init);
210 
211  std::cout << "================================================================================" << std::endl;
212  for(auto it:netlist)
213  {
214  std::cout << *it << std::endl;
215  }
216  std::cout << "================================================================================" << std::endl;
217 
218  }
219 
227  void train(void)
228  {
229  // Shuffle the training data
230  permute_data(Xptr,Tptr);
231 
232  // Change network state
233  phase = PHASE::TRAIN;
234  for(size_t l=0; l<netlist.size(); l++)
235  {
236  netlist.at(l)->set_phase(PHASE::TRAIN);
237  }
238 
239  tic = high_resolution_clock::now();
240  // Loop through the entire train dataset
241  for (unsigned int n=0; n<N_iter; n++)
242  {
243  // Forward
244  for(size_t l=0; l<netlist.size(); l++)
245  {
246  netlist.at(l)->prop_mb();
247  }
248  // Log statistcs
249  cost_val.at(n,curEpoch) = ((layer_cost*)netlist.back())->get_cost();
250  accuracy.at(n,curEpoch) = ((layer_cost*)netlist.back())->get_acc();
251 
252  // Backpropagate
253  for(size_t l=netlist.size()-1; l>1; l--)
254  {
255  netlist.at(l)->backprop();
256  }
257 
258  // Update
259  for(size_t l=0; l<netlist.size(); l++)
260  {
261  netlist.at(l)->update();
262  }
263  progress_bar("Train",n,N_iter);
264  }
266 
267  toc = high_resolution_clock::now();
268  curEpoch++;
269  }
270 
277  void test(void)
278  {
279  // Change network state
280  phase = PHASE::TEST;
281  for(size_t l=0; l<netlist.size(); l++)
282  {
283  netlist.at(l)->set_phase(PHASE::TEST);
284  }
285 
286  tic = high_resolution_clock::now();
287  // Loop through the entire test dataset
288  for (unsigned int n=0; n<N_iter; n++)
289  {
290  // Forward
291  for(size_t l=0; l<netlist.size(); l++)
292  {
293  netlist.at(l)->prop_mb();
294  }
295  cost_val.at(n,curEpoch) = ((layer_cost*)netlist.back())->get_cost();
296  accuracy.at(n,curEpoch) = ((layer_cost*)netlist.back())->get_acc();
297  progress_bar("Test",n,N_iter);
298  }
300  toc = high_resolution_clock::now();
301  curEpoch++;
302  }
303 
306  void show_stat(void)
307  {
308  std::ios init(NULL);
309  init.copyfmt(std::cout);
310  if(curEpoch==1)
311  {
312  std::cout << "---------------------------------------------------------" << std::endl;
313  std::cout << std::left << " "
314  << std::setw(8) << "EPOCH"
315  << std::setw(15) << "Cost"
316  << std::setw(15) << "Accuracy [%]"
317  << std::setw(15) << "Time [ms]"
318  << std::endl;
319  // restore default formatting
320  std::cout.copyfmt(init);
321  std::cout << "---------------------------------------------------------" << std::endl;
322  }
323  std::cout << std::left << " "
324  << std::setw(8) << curEpoch
325  << std::setw(15) << arma::mean(cost_val.col(curEpoch-1)) // Return the mean over all iterations
326  << std::setw(15) << 100*arma::mean(accuracy.col(curEpoch-1)) // Return the mean over all iterations
327  << std::setw(15) << duration_cast<milliseconds>(toc - tic).count()
328  << std::endl;
329 
330  // restore default formatting
331  std::cout.copyfmt(init);
332  }
333 
339  arma::Mat<DNN_Dtype> predict(const arma::Mat<DNN_Dtype>& in)
340  {
341  for(size_t l=0; l<netlist.size(); l++)
342  {
343  netlist.at(l)->set_phase(PHASE::PRED);
344  }
345 
346  ((layer_input*)netlist.front())->set_data(&in);
347  ((layer_input*)netlist.front())->reset_batch_ctr();
348  ((layer_cost*)netlist.back())->reset_batch_ctr();
349  // Forward
350  for(size_t l=0; l<netlist.size(); l++)
351  {
352  netlist.at(l)->prop();
353  }
354  return ((layer_cost*)netlist.back())->get_Y1();
355  }
356 }; // End class net
358 } // End namespace dnn
virtual void set_ix(const arma::uword n)
Set layer index.
Input/data layer class.
void remove_progress_bar(void)
Clears progress bar.
Definition: dnn_misc.h:422
PHASE
Definition: dnn.h:34
arma::Mat< DNN_Dtype > * Xptr
Input training data matrix.
Definition: dnn_net.h:38
arma::uword N_batch
Number of data in mini-batch.
Definition: dnn_net.h:43
PHASE phase
Current state/phase.
Definition: dnn_net.h:47
void add_layer(layer &lay)
Add layer to network stack.
Definition: dnn_net.h:151
high_resolution_clock::time_point tic
Definition: dnn_net.h:18
Layer base class.
Cost/output layer base class.
arma::uword curEpoch
Current EPOCH.
Definition: dnn_net.h:40
void show_stat(void)
Prints the statistics for each EPOCH.
Definition: dnn_net.h:306
std::vector< layer * > netlist
Chained list of network layers.
Definition: dnn_net.h:37
bool compile(void)
Compiles network stack.
Definition: dnn_net.h:168
void data_set(arma::Mat< DNN_Dtype > &data, arma::Mat< DNN_Dtype > &label, arma::uword Nmb, arma::uword Nep=1)
Set a new dataset in the model.
Definition: dnn_net.h:74
arma::mat get_accuracy(void)
Get accuracy.
Definition: dnn_net.h:117
virtual void init(void)
Initialize layer.
arma::mat accuracy
Accuracy logging matrix.
Definition: dnn_net.h:46
~net()
Definition: dnn_net.h:62
arma::uword N_data
Dataset size.
Definition: dnn_net.h:42
high_resolution_clock::time_point toc
Definition: dnn_net.h:18
arma::Mat< DNN_Dtype > * Tptr
Input target data matrix.
Definition: dnn_net.h:39
arma::mat get_cost(void)
Get cost.
Definition: dnn_net.h:125
arma::uword N_iter
Number of iterations.
Definition: dnn_net.h:44
void set_batch_len(arma::uword n)
Set mini batch length.
Definition: dnn_net.h:109
Definition: dnn.h:22
void test(void)
Test the network.
Definition: dnn_net.h:277
void disp(void)
Displays a short info about the network layers.
Definition: dnn_net.h:195
arma::mat cost_val
Cost logging matrix.
Definition: dnn_net.h:45
void progress_bar(const std::string str, double p)
Console progress bar.
Definition: dnn_misc.h:382
void train(void)
Trains the network.
Definition: dnn_net.h:227
arma::uword N_epoch
Number of EPOCHs.
Definition: dnn_net.h:41
virtual void set_left(layer *lptr)
Set pointer to left layer.
Neural network model class.
Definition: dnn_net.h:34
void permute_data(arma::Mat< DNN_Dtype > *Xptr, arma::Mat< DNN_Dtype > *Tptr)
Permute dataset.
Definition: dnn_net.h:134
net(void)
Network constructor.
Definition: dnn_net.h:52
arma::Mat< DNN_Dtype > predict(const arma::Mat< DNN_Dtype > &in)
A forward pass in the network for one input.
Definition: dnn_net.h:339