summaryrefslogtreecommitdiff
path: root/libavfilter/dnn_backend_tf.c
diff options
context:
space:
mode:
Diffstat (limited to 'libavfilter/dnn_backend_tf.c')
-rw-r--r--libavfilter/dnn_backend_tf.c74
1 files changed, 45 insertions, 29 deletions
diff --git a/libavfilter/dnn_backend_tf.c b/libavfilter/dnn_backend_tf.c
index e46b1ad140..302ff9e4e1 100644
--- a/libavfilter/dnn_backend_tf.c
+++ b/libavfilter/dnn_backend_tf.c
@@ -25,6 +25,7 @@
#include "dnn_backend_tf.h"
#include "dnn_srcnn.h"
+#include "dnn_espcn.h"
#include "libavformat/avio.h"
#include <tensorflow/c/c_api.h>
@@ -35,9 +36,7 @@ typedef struct TFModel{
TF_Status* status;
TF_Output input, output;
TF_Tensor* input_tensor;
- TF_Tensor* output_tensor;
- const DNNData* input_data;
- const DNNData* output_data;
+ DNNData* output_data;
} TFModel;
static void free_buffer(void* data, size_t length)
@@ -78,13 +77,13 @@ static TF_Buffer* read_graph(const char* model_filename)
return graph_buf;
}
-static DNNReturnType set_input_output_tf(void* model, const DNNData* input, const DNNData* output)
+static DNNReturnType set_input_output_tf(void* model, DNNData* input, DNNData* output)
{
TFModel* tf_model = (TFModel*)model;
int64_t input_dims[] = {1, input->height, input->width, input->channels};
- int64_t output_dims[] = {1, output->height, output->width, output->channels};
TF_SessionOptions* sess_opts;
const TF_Operation* init_op = TF_GraphOperationByName(tf_model->graph, "init");
+ TF_Tensor* output_tensor;
// Input operation should be named 'x'
tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, "x");
@@ -100,6 +99,7 @@ static DNNReturnType set_input_output_tf(void* model, const DNNData* input, cons
if (!tf_model->input_tensor){
return DNN_ERROR;
}
+ input->data = (float*)TF_TensorData(tf_model->input_tensor);
// Output operation should be named 'y'
tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, "y");
@@ -107,17 +107,6 @@ static DNNReturnType set_input_output_tf(void* model, const DNNData* input, cons
return DNN_ERROR;
}
tf_model->output.index = 0;
- if (tf_model->output_tensor){
- TF_DeleteTensor(tf_model->output_tensor);
- }
- tf_model->output_tensor = TF_AllocateTensor(TF_FLOAT, output_dims, 4,
- output_dims[1] * output_dims[2] * output_dims[3] * sizeof(float));
- if (!tf_model->output_tensor){
- return DNN_ERROR;
- }
-
- tf_model->input_data = input;
- tf_model->output_data = output;
if (tf_model->session){
TF_CloseSession(tf_model->session, tf_model->status);
@@ -144,6 +133,26 @@ static DNNReturnType set_input_output_tf(void* model, const DNNData* input, cons
}
}
+ // Execute network to get output height, width and number of channels
+ TF_SessionRun(tf_model->session, NULL,
+ &tf_model->input, &tf_model->input_tensor, 1,
+ &tf_model->output, &output_tensor, 1,
+ NULL, 0, NULL, tf_model->status);
+ if (TF_GetCode(tf_model->status) != TF_OK){
+ return DNN_ERROR;
+ }
+ else{
+ output->height = TF_Dim(output_tensor, 1);
+ output->width = TF_Dim(output_tensor, 2);
+ output->channels = TF_Dim(output_tensor, 3);
+ output->data = av_malloc(output->height * output->width * output->channels * sizeof(float));
+ if (!output->data){
+ return DNN_ERROR;
+ }
+ tf_model->output_data = output;
+ TF_DeleteTensor(output_tensor);
+ }
+
return DNN_SUCCESS;
}
@@ -166,7 +175,7 @@ DNNModel* ff_dnn_load_model_tf(const char* model_filename)
}
tf_model->session = NULL;
tf_model->input_tensor = NULL;
- tf_model->output_tensor = NULL;
+ tf_model->output_data = NULL;
graph_def = read_graph(model_filename);
if (!graph_def){
@@ -215,6 +224,17 @@ DNNModel* ff_dnn_load_default_model_tf(DNNDefaultModel model_type)
graph_def->length = srcnn_tf_size;
graph_def->data_deallocator = free_buffer;
break;
+ case DNN_ESPCN:
+ graph_data = av_malloc(espcn_tf_size);
+ if (!graph_data){
+ TF_DeleteBuffer(graph_def);
+ return NULL;
+ }
+ memcpy(graph_data, espcn_tf_model, espcn_tf_size);
+ graph_def->data = (void*)graph_data;
+ graph_def->length = espcn_tf_size;
+ graph_def->data_deallocator = free_buffer;
+ break;
default:
TF_DeleteBuffer(graph_def);
return NULL;
@@ -234,7 +254,7 @@ DNNModel* ff_dnn_load_default_model_tf(DNNDefaultModel model_type)
}
tf_model->session = NULL;
tf_model->input_tensor = NULL;
- tf_model->output_tensor = NULL;
+ tf_model->output_data = NULL;
tf_model->graph = TF_NewGraph();
tf_model->status = TF_NewStatus();
@@ -259,23 +279,21 @@ DNNModel* ff_dnn_load_default_model_tf(DNNDefaultModel model_type)
DNNReturnType ff_dnn_execute_model_tf(const DNNModel* model)
{
TFModel* tf_model = (TFModel*)model->model;
-
- memcpy(TF_TensorData(tf_model->input_tensor), tf_model->input_data->data,
- tf_model->input_data->height * tf_model->input_data->width *
- tf_model->input_data->channels * sizeof(float));
+ TF_Tensor* output_tensor;
TF_SessionRun(tf_model->session, NULL,
&tf_model->input, &tf_model->input_tensor, 1,
- &tf_model->output, &tf_model->output_tensor, 1,
+ &tf_model->output, &output_tensor, 1,
NULL, 0, NULL, tf_model->status);
if (TF_GetCode(tf_model->status) != TF_OK){
return DNN_ERROR;
}
else{
- memcpy(tf_model->output_data->data, TF_TensorData(tf_model->output_tensor),
- tf_model->output_data->height * tf_model->output_data->width *
- tf_model->output_data->channels * sizeof(float));
+ memcpy(tf_model->output_data->data, TF_TensorData(output_tensor),
+ tf_model->output_data->height * tf_model->output_data->width *
+ tf_model->output_data->channels * sizeof(float));
+ TF_DeleteTensor(output_tensor);
return DNN_SUCCESS;
}
@@ -300,9 +318,7 @@ void ff_dnn_free_model_tf(DNNModel** model)
if (tf_model->input_tensor){
TF_DeleteTensor(tf_model->input_tensor);
}
- if (tf_model->output_tensor){
- TF_DeleteTensor(tf_model->output_tensor);
- }
+ av_freep(&tf_model->output_data->data);
av_freep(&tf_model);
av_freep(model);
}