DaNNet
dnn_layer_input.h
Go to the documentation of this file.
1 // Copyright 2019 Claes Rolen (www.rolensystems.com)
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #pragma once
16 namespace dnn
17 {
21 
28 class layer_input: public layer
29 {
30 private:
31  const arma::Mat<DNN_Dtype>* data_in;
32  arma::uword data_len;
33  arma::uword data_ctr;
34 public:
41  layer_input(const arma::uword n_rows, const arma::uword n_cols=1, const arma::uword n_channels=1):layer()
42  {
43  N_rows_right = n_rows;
44  N_cols_right = n_cols;
45  N_channels_right = n_channels;
47  type = "Input";
48  id = type;
49  }
50 
55  void init(void)
56  {
57  N_left = 0;
58  N_cols_left = 0;
59  N_rows_left = 0;
60  N_channels_left = 0;
61  data_len = 0;
62  data_ctr = 0;
63  }
64 
71  void set_data(const arma::Mat<DNN_Dtype>* data)
72  {
73  data_in = data;
74  data_len = data->n_cols;
75  data_ctr = 0;
76  }
77 
81  void reset_batch_ctr(void)
82  {
83  data_ctr = 0;
84  }
85 
91  void prop(void)
92  {
93  Y1 = data_in->col(data_ctr++);
94  if(data_ctr == data_len)
95  data_ctr = 0; // Wrap data
96  }
97 
103  void prop_mb(void)
104  {
105  if( (data_ctr+N_batch)>=data_len )
106  {
107  // We are at the end, wrap data
108  for (arma::uword k=0 ; k<N_batch ; k++ )
109  {
110  Y.col(k) = data_in->col(data_ctr++);
111  if(data_ctr >= data_len)
112  data_ctr = 0; // Wrap data
113  }
114  }
115  else
116  {
117  Y = data_in->cols(data_ctr,data_ctr+N_batch-1);
118  data_ctr+=N_batch;
119  }
120  }
121 
125  void disp(void)
126  {
127  layer::disp();
128  std::cout << "Output data length: " << get_nrof_outputs() << " ["<< N_rows_right << ","<<N_cols_right<<","<< N_channels_right<<"]"<< std::endl;
129  }
130 
131 }; // End class layer_input
133 } // End namespace dnn
Input/data layer class.
const arma::Mat< DNN_Dtype > * data_in
Data input pointer.
std::string type
Layer type string.
arma::Mat< DNN_Dtype > Y
Output buffer mini batch [N_right,N_batch].
void disp(void)
Display info about layer.
arma::uword N_rows_left
Input rows.
void reset_batch_ctr(void)
Resets data counter.
arma::uword N_channels_right
Output channels, number of filters.
void set_data(const arma::Mat< DNN_Dtype > *data)
Sets new data.
Layer base class.
arma::uword N_cols_left
Input cols.
arma::Mat< DNN_Dtype > Y1
Output buffer [N_right,1].
arma::uword data_ctr
Data counter = index of current data.
virtual arma::uword get_nrof_outputs(void)
Get total number of layer outputs.
arma::uword data_len
Number of samples in the input data.
void init(void)
Initialization of layer.
virtual void disp(void)
Display info about layer.
arma::uword N_left
Total size left.
arma::uword N_rows_right
Output rows.
arma::uword N_batch
Mini batch size.
Definition: dnn.h:22
layer_input(const arma::uword n_rows, const arma::uword n_cols=1, const arma::uword n_channels=1)
Input layer constructor.
arma::uword N_channels_left
Input channels, number of filters.
arma::uword N_right
Total size right.
arma::uword N_cols_right
Output cols.
void prop(void)
Forward propagation through layer.
void prop_mb(void)
Forward mini-batch propagation through layer.