目前深度学习训练和推理涉及到的输入数据通常为4-D,对应的通道格式主要有两种:
- NCHW
- NHWC
其中各个字母代表的含义为:
- N - Batch
- C - Channel 特征图通道
- H - Height 特征图高度
- W - Width 特征图宽度
各个框架和图像处理方式对图像数据要求如下:
- TensorFlow模型默认的输入格式为:RGB NHWC
- Pytorch模型默认的输入格式为:RGB NCHW
- ONNX模型默认的输入格式为:RGB NCHW fp32
- Caffe 的Blob通道顺序是:NCHW
- TensorRT中通道顺序:NCHW
- OpenCV默认数据格式为:BGR HWC uint8
NCHW 则是 Nvidia cuDNN 默认格式,使用 GPU 加速时用 NCHW 格式速度会更快
一、基本原理
如图所示,假定N = 2,C = 16,H = 5,W = 4,
无论逻辑表达上是几维的数据,在计算机中存储时都是按照1D来存储的。下面很可以很清楚的看到NCHW和NHWC格式的高位数据,存储为1D时候的样子:
总的来说,无论是NCHW还是NHWC或者CHWN,在读取为1D时都是从后往前读,举例来说:
对于NCHW格式的4D数据,首先取W方向数据;然后H方向;再C方向;最后N方向。
所以,序列化出1D数据为:
000 (W方向) 001 002 003,(H方向) 004 005 … 019,(C方向) 020 … 318 319,(N方向) 320 321 …对于NHWC格式的4D数据,首先取C方向数据;然后W方向;再H方向;最后N方向。
所以,序列化出1D数据:
000 (C方向) 020 … 300,(W方向) 001 021 … 303,(H方向) 004 … 319,(N方向) 320 340 …
我们通常在输入一张256 * 256分辨率的rgb图像时,对应的4D数据为[N = 1, H=256.h, W=256, C=3],然后对应的1D数据的组织方式如下图所示:
NCHW: RRRRRRRRRRGGGGGGGGGGBBBBBBBBBB
NHWC: RBGRGBRGBRGBRGBRGBRGBRGBRGBRGB
二、java调用tensorflow pb模型推理的简单运用
第一种,若 Tensor.create(input) 输入的input是4维数组,那么按照tensorflow要求的NHWC的格式进行数据的组织即可:1
2
3
4
5
6
7
8
9
10
11
12
13Imgproc.resize(src, dst, new Size(h, w)); // 1, h, w,3
float input[][][][] = new float[1][h][w][3];
System.out.println(dst.rows());
System.out.println(dst.cols());
for (int i = 0; i < dst.cols(); i++) {
for (int j = 0; j < dst.rows(); j++) {
double[] pixel = dst.get(j, i);
input[0][i][j][0] = (float) (255 - pixel[0]);
input[0][i][j][1] = (float) (255 - pixel[1]);
input[0][i][j][2] = (float) (255 - pixel[2]);
}
}
Tensor input_X = Tensor.create(input);
第二种,若 Tensor.create(input) 输入的input是1维数组,那么按照前面转1D数据的基本原理,将NHWC的格式进行转换后再组织即可:1
2
3
4
5
6
7
8
9
10
11
12
13
14Imgproc.resize(src, dst, new Size(h, w)); // 1, h, w,3
float input[] = new float[1 * h * w * 3];
System.out.println(dst.rows());
System.out.println(dst.cols());
int index = 0;
for (int i = 0; i < dst.cols(); i++) {
for (int j = 0; j < dst.rows(); j++) {
double[] pixel = dst.get(j, i);
input[index++] = (float) (255 - pixel[0]);
input[index++] = (float) (255 - pixel[1]);
input[index++] = (float) (255 - pixel[2]);
}
}
Tensor input_X = Tensor.create(shape = (1,h,w,3),input);