summaryrefslogtreecommitdiff
path: root/libavfilter/dnn
diff options
context:
space:
mode:
authorGuo, Yejun <yejun.guo@intel.com>2019-10-21 20:38:17 +0800
committerPedro Arthur <bygrandao@gmail.com>2019-10-30 11:07:06 -0300
commitf4b3c0e55c84434eb897c2a4a1179cb1d202c52c (patch)
tree4d119163acfab352d18328b526c0bb99a7dc771d /libavfilter/dnn
parente1b45b85963b5aa9d67e23638ef9b045e7fbd875 (diff)
avfilter/dnn: add a new interface to query dnn model's input info
to support dnn networks more general, we need to know the input info of the dnn model. background: The data type of dnn model's input could be float32, uint8 or fp16, etc. And the w/h of input image could be fixed or variable. Signed-off-by: Guo, Yejun <yejun.guo@intel.com> Signed-off-by: Pedro Arthur <bygrandao@gmail.com>
Diffstat (limited to 'libavfilter/dnn')
-rw-r--r--libavfilter/dnn/dnn_backend_native.c24
-rw-r--r--libavfilter/dnn/dnn_backend_tf.c32
2 files changed, 55 insertions, 1 deletions
diff --git a/libavfilter/dnn/dnn_backend_native.c b/libavfilter/dnn/dnn_backend_native.c
index add1db42cf..94634b3065 100644
--- a/libavfilter/dnn/dnn_backend_native.c
+++ b/libavfilter/dnn/dnn_backend_native.c
@@ -28,6 +28,28 @@
#include "dnn_backend_native_layer_conv2d.h"
#include "dnn_backend_native_layers.h"
+static DNNReturnType get_input_native(void *model, DNNData *input, const char *input_name)
+{
+ ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;
+
+ for (int i = 0; i < network->operands_num; ++i) {
+ DnnOperand *oprd = &network->operands[i];
+ if (strcmp(oprd->name, input_name) == 0) {
+ if (oprd->type != DOT_INPUT)
+ return DNN_ERROR;
+ input->dt = oprd->data_type;
+ av_assert0(oprd->dims[0] == 1);
+ input->height = oprd->dims[1];
+ input->width = oprd->dims[2];
+ input->channels = oprd->dims[3];
+ return DNN_SUCCESS;
+ }
+ }
+
+ // do not find the input operand
+ return DNN_ERROR;
+}
+
static DNNReturnType set_input_output_native(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
{
ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;
@@ -37,7 +59,6 @@ static DNNReturnType set_input_output_native(void *model, DNNData *input, const
return DNN_ERROR;
/* inputs */
- av_assert0(input->dt == DNN_FLOAT);
for (int i = 0; i < network->operands_num; ++i) {
oprd = &network->operands[i];
if (strcmp(oprd->name, input_name) == 0) {
@@ -234,6 +255,7 @@ DNNModel *ff_dnn_load_model_native(const char *model_filename)
}
model->set_input_output = &set_input_output_native;
+ model->get_input = &get_input_native;
return model;
}
diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c
index ed91d0500d..a921667424 100644
--- a/libavfilter/dnn/dnn_backend_tf.c
+++ b/libavfilter/dnn/dnn_backend_tf.c
@@ -105,6 +105,37 @@ static TF_Tensor *allocate_input_tensor(const DNNData *input)
input_dims[1] * input_dims[2] * input_dims[3] * size);
}
+static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input_name)
+{
+ TFModel *tf_model = (TFModel *)model;
+ TF_Status *status;
+ int64_t dims[4];
+
+ TF_Output tf_output;
+ tf_output.oper = TF_GraphOperationByName(tf_model->graph, input_name);
+ if (!tf_output.oper)
+ return DNN_ERROR;
+
+ tf_output.index = 0;
+ input->dt = TF_OperationOutputType(tf_output);
+
+ status = TF_NewStatus();
+ TF_GraphGetTensorShape(tf_model->graph, tf_output, dims, 4, status);
+ if (TF_GetCode(status) != TF_OK){
+ TF_DeleteStatus(status);
+ return DNN_ERROR;
+ }
+ TF_DeleteStatus(status);
+
+ // currently only NHWC is supported
+ av_assert0(dims[0] == 1);
+ input->height = dims[1];
+ input->width = dims[2];
+ input->channels = dims[3];
+
+ return DNN_SUCCESS;
+}
+
static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
{
TFModel *tf_model = (TFModel *)model;
@@ -568,6 +599,7 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename)
model->model = (void *)tf_model;
model->set_input_output = &set_input_output_tf;
+ model->get_input = &get_input_tf;
return model;
}