diff options
Diffstat (limited to 'libavfilter/dnn_backend_tf.c')
-rw-r--r-- | libavfilter/dnn_backend_tf.c | 74 |
1 files changed, 45 insertions, 29 deletions
diff --git a/libavfilter/dnn_backend_tf.c b/libavfilter/dnn_backend_tf.c index e46b1ad140..302ff9e4e1 100644 --- a/libavfilter/dnn_backend_tf.c +++ b/libavfilter/dnn_backend_tf.c @@ -25,6 +25,7 @@ #include "dnn_backend_tf.h" #include "dnn_srcnn.h" +#include "dnn_espcn.h" #include "libavformat/avio.h" #include <tensorflow/c/c_api.h> @@ -35,9 +36,7 @@ typedef struct TFModel{ TF_Status* status; TF_Output input, output; TF_Tensor* input_tensor; - TF_Tensor* output_tensor; - const DNNData* input_data; - const DNNData* output_data; + DNNData* output_data; } TFModel; static void free_buffer(void* data, size_t length) @@ -78,13 +77,13 @@ static TF_Buffer* read_graph(const char* model_filename) return graph_buf; } -static DNNReturnType set_input_output_tf(void* model, const DNNData* input, const DNNData* output) +static DNNReturnType set_input_output_tf(void* model, DNNData* input, DNNData* output) { TFModel* tf_model = (TFModel*)model; int64_t input_dims[] = {1, input->height, input->width, input->channels}; - int64_t output_dims[] = {1, output->height, output->width, output->channels}; TF_SessionOptions* sess_opts; const TF_Operation* init_op = TF_GraphOperationByName(tf_model->graph, "init"); + TF_Tensor* output_tensor; // Input operation should be named 'x' tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, "x"); @@ -100,6 +99,7 @@ static DNNReturnType set_input_output_tf(void* model, const DNNData* input, cons if (!tf_model->input_tensor){ return DNN_ERROR; } + input->data = (float*)TF_TensorData(tf_model->input_tensor); // Output operation should be named 'y' tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, "y"); @@ -107,17 +107,6 @@ static DNNReturnType set_input_output_tf(void* model, const DNNData* input, cons return DNN_ERROR; } tf_model->output.index = 0; - if (tf_model->output_tensor){ - TF_DeleteTensor(tf_model->output_tensor); - } - tf_model->output_tensor = TF_AllocateTensor(TF_FLOAT, output_dims, 4, - output_dims[1] * output_dims[2] * output_dims[3] * sizeof(float)); - if (!tf_model->output_tensor){ - return DNN_ERROR; - } - - tf_model->input_data = input; - tf_model->output_data = output; if (tf_model->session){ TF_CloseSession(tf_model->session, tf_model->status); @@ -144,6 +133,26 @@ static DNNReturnType set_input_output_tf(void* model, const DNNData* input, cons } } + // Execute network to get output height, width and number of channels + TF_SessionRun(tf_model->session, NULL, + &tf_model->input, &tf_model->input_tensor, 1, + &tf_model->output, &output_tensor, 1, + NULL, 0, NULL, tf_model->status); + if (TF_GetCode(tf_model->status) != TF_OK){ + return DNN_ERROR; + } + else{ + output->height = TF_Dim(output_tensor, 1); + output->width = TF_Dim(output_tensor, 2); + output->channels = TF_Dim(output_tensor, 3); + output->data = av_malloc(output->height * output->width * output->channels * sizeof(float)); + if (!output->data){ + return DNN_ERROR; + } + tf_model->output_data = output; + TF_DeleteTensor(output_tensor); + } + return DNN_SUCCESS; } @@ -166,7 +175,7 @@ DNNModel* ff_dnn_load_model_tf(const char* model_filename) } tf_model->session = NULL; tf_model->input_tensor = NULL; - tf_model->output_tensor = NULL; + tf_model->output_data = NULL; graph_def = read_graph(model_filename); if (!graph_def){ @@ -215,6 +224,17 @@ DNNModel* ff_dnn_load_default_model_tf(DNNDefaultModel model_type) graph_def->length = srcnn_tf_size; graph_def->data_deallocator = free_buffer; break; + case DNN_ESPCN: + graph_data = av_malloc(espcn_tf_size); + if (!graph_data){ + TF_DeleteBuffer(graph_def); + return NULL; + } + memcpy(graph_data, espcn_tf_model, espcn_tf_size); + graph_def->data = (void*)graph_data; + graph_def->length = espcn_tf_size; + graph_def->data_deallocator = free_buffer; + break; default: TF_DeleteBuffer(graph_def); return NULL; @@ -234,7 +254,7 @@ DNNModel* ff_dnn_load_default_model_tf(DNNDefaultModel model_type) } tf_model->session = NULL; tf_model->input_tensor = NULL; - tf_model->output_tensor = NULL; + tf_model->output_data = NULL; tf_model->graph = TF_NewGraph(); tf_model->status = TF_NewStatus(); @@ -259,23 +279,21 @@ DNNModel* ff_dnn_load_default_model_tf(DNNDefaultModel model_type) DNNReturnType ff_dnn_execute_model_tf(const DNNModel* model) { TFModel* tf_model = (TFModel*)model->model; - - memcpy(TF_TensorData(tf_model->input_tensor), tf_model->input_data->data, - tf_model->input_data->height * tf_model->input_data->width * - tf_model->input_data->channels * sizeof(float)); + TF_Tensor* output_tensor; TF_SessionRun(tf_model->session, NULL, &tf_model->input, &tf_model->input_tensor, 1, - &tf_model->output, &tf_model->output_tensor, 1, + &tf_model->output, &output_tensor, 1, NULL, 0, NULL, tf_model->status); if (TF_GetCode(tf_model->status) != TF_OK){ return DNN_ERROR; } else{ - memcpy(tf_model->output_data->data, TF_TensorData(tf_model->output_tensor), - tf_model->output_data->height * tf_model->output_data->width * - tf_model->output_data->channels * sizeof(float)); + memcpy(tf_model->output_data->data, TF_TensorData(output_tensor), + tf_model->output_data->height * tf_model->output_data->width * + tf_model->output_data->channels * sizeof(float)); + TF_DeleteTensor(output_tensor); return DNN_SUCCESS; } @@ -300,9 +318,7 @@ void ff_dnn_free_model_tf(DNNModel** model) if (tf_model->input_tensor){ TF_DeleteTensor(tf_model->input_tensor); } - if (tf_model->output_tensor){ - TF_DeleteTensor(tf_model->output_tensor); - } + av_freep(&tf_model->output_data->data); av_freep(&tf_model); av_freep(model); } |