diff options
Diffstat (limited to 'libavfilter/dnn_backend_tf.c')
-rw-r--r-- | libavfilter/dnn_backend_tf.c | 28 |
1 files changed, 24 insertions, 4 deletions
diff --git a/libavfilter/dnn_backend_tf.c b/libavfilter/dnn_backend_tf.c index ca6472d445..ba959ae3a2 100644 --- a/libavfilter/dnn_backend_tf.c +++ b/libavfilter/dnn_backend_tf.c @@ -79,10 +79,31 @@ static TF_Buffer *read_graph(const char *model_filename) return graph_buf; } -static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output) +static TF_Tensor *allocate_input_tensor(const DNNInputData *input) { - TFModel *tf_model = (TFModel *)model; + TF_DataType dt; + size_t size; int64_t input_dims[] = {1, input->height, input->width, input->channels}; + switch (input->dt) { + case DNN_FLOAT: + dt = TF_FLOAT; + size = sizeof(float); + break; + case DNN_UINT8: + dt = TF_UINT8; + size = sizeof(char); + break; + default: + av_assert0(!"should not reach here"); + } + + return TF_AllocateTensor(dt, input_dims, 4, + input_dims[1] * input_dims[2] * input_dims[3] * size); +} + +static DNNReturnType set_input_output_tf(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output) +{ + TFModel *tf_model = (TFModel *)model; TF_SessionOptions *sess_opts; const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init"); @@ -95,8 +116,7 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char if (tf_model->input_tensor){ TF_DeleteTensor(tf_model->input_tensor); } - tf_model->input_tensor = TF_AllocateTensor(TF_FLOAT, input_dims, 4, - input_dims[1] * input_dims[2] * input_dims[3] * sizeof(float)); + tf_model->input_tensor = allocate_input_tensor(input); if (!tf_model->input_tensor){ return DNN_ERROR; } |