summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGuo, Yejun <yejun.guo@intel.com>2019-04-25 10:14:42 +0800
committerPedro Arthur <bygrandao@gmail.com>2019-05-08 12:33:00 -0300
commitc636dc9819ebab1a84237cc017a6a3d35ebc9cdc (patch)
tree39fd943e649cb1185f25ccce6e7be193448ba23c
parent25c1cd909fa6c8b6b778dc24192dc3ec780324b0 (diff)
libavfilter/dnn: add more data type support for dnn model input
currently, only float is supported as model input, actually, there are other data types, this patch adds uint8. Signed-off-by: Guo, Yejun <yejun.guo@intel.com> Signed-off-by: Pedro Arthur <bygrandao@gmail.com>
-rw-r--r--libavfilter/dnn_backend_native.c4
-rw-r--r--libavfilter/dnn_backend_tf.c28
-rw-r--r--libavfilter/dnn_interface.h10
-rw-r--r--libavfilter/vf_sr.c4
4 files changed, 39 insertions, 7 deletions
diff --git a/libavfilter/dnn_backend_native.c b/libavfilter/dnn_backend_native.c
index 8a83c63c73..06fbdf368b 100644
--- a/libavfilter/dnn_backend_native.c
+++ b/libavfilter/dnn_backend_native.c
@@ -24,8 +24,9 @@
*/
#include "dnn_backend_native.h"
+#include "libavutil/avassert.h"
-static DNNReturnType set_input_output_native(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
+static DNNReturnType set_input_output_native(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output)
{
ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;
InputParams *input_params;
@@ -45,6 +46,7 @@ static DNNReturnType set_input_output_native(void *model, DNNData *input, const
if (input->data){
av_freep(&input->data);
}
+ av_assert0(input->dt == DNN_FLOAT);
network->layers[0].output = input->data = av_malloc(cur_height * cur_width * cur_channels * sizeof(float));
if (!network->layers[0].output){
return DNN_ERROR;
diff --git a/libavfilter/dnn_backend_tf.c b/libavfilter/dnn_backend_tf.c
index ca6472d445..ba959ae3a2 100644
--- a/libavfilter/dnn_backend_tf.c
+++ b/libavfilter/dnn_backend_tf.c
@@ -79,10 +79,31 @@ static TF_Buffer *read_graph(const char *model_filename)
return graph_buf;
}
-static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
+static TF_Tensor *allocate_input_tensor(const DNNInputData *input)
{
- TFModel *tf_model = (TFModel *)model;
+ TF_DataType dt;
+ size_t size;
int64_t input_dims[] = {1, input->height, input->width, input->channels};
+ switch (input->dt) {
+ case DNN_FLOAT:
+ dt = TF_FLOAT;
+ size = sizeof(float);
+ break;
+ case DNN_UINT8:
+ dt = TF_UINT8;
+ size = sizeof(char);
+ break;
+ default:
+ av_assert0(!"should not reach here");
+ }
+
+ return TF_AllocateTensor(dt, input_dims, 4,
+ input_dims[1] * input_dims[2] * input_dims[3] * size);
+}
+
+static DNNReturnType set_input_output_tf(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output)
+{
+ TFModel *tf_model = (TFModel *)model;
TF_SessionOptions *sess_opts;
const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init");
@@ -95,8 +116,7 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char
if (tf_model->input_tensor){
TF_DeleteTensor(tf_model->input_tensor);
}
- tf_model->input_tensor = TF_AllocateTensor(TF_FLOAT, input_dims, 4,
- input_dims[1] * input_dims[2] * input_dims[3] * sizeof(float));
+ tf_model->input_tensor = allocate_input_tensor(input);
if (!tf_model->input_tensor){
return DNN_ERROR;
}
diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
index 73d226ec91..c24df0e961 100644
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -32,6 +32,14 @@ typedef enum {DNN_SUCCESS, DNN_ERROR} DNNReturnType;
typedef enum {DNN_NATIVE, DNN_TF} DNNBackendType;
+typedef enum {DNN_FLOAT, DNN_UINT8} DNNDataType;
+
+typedef struct DNNInputData{
+ void *data;
+ DNNDataType dt;
+ int width, height, channels;
+} DNNInputData;
+
typedef struct DNNData{
float *data;
int width, height, channels;
@@ -42,7 +50,7 @@ typedef struct DNNModel{
void *model;
// Sets model input and output.
// Should be called at least once before model execution.
- DNNReturnType (*set_input_output)(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output);
+ DNNReturnType (*set_input_output)(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output);
} DNNModel;
// Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
diff --git a/libavfilter/vf_sr.c b/libavfilter/vf_sr.c
index 0145511d11..65baf5f901 100644
--- a/libavfilter/vf_sr.c
+++ b/libavfilter/vf_sr.c
@@ -40,7 +40,8 @@ typedef struct SRContext {
DNNBackendType backend_type;
DNNModule *dnn_module;
DNNModel *model;
- DNNData input, output;
+ DNNInputData input;
+ DNNData output;
int scale_factor;
struct SwsContext *sws_contexts[3];
int sws_slice_h, sws_input_linesize, sws_output_linesize;
@@ -86,6 +87,7 @@ static av_cold int init(AVFilterContext *context)
return AVERROR(EIO);
}
+ sr_context->input.dt = DNN_FLOAT;
sr_context->sws_contexts[0] = NULL;
sr_context->sws_contexts[1] = NULL;
sr_context->sws_contexts[2] = NULL;