-rw-r--r--   libavfilter/dnn_backend_native.c   14
-rw-r--r--   libavfilter/dnn_backend_native.h    2
-rw-r--r--   libavfilter/dnn_backend_tf.c       56
-rw-r--r--   libavfilter/dnn_backend_tf.h        2
-rw-r--r--   libavfilter/dnn_interface.h         6
-rw-r--r--   libavfilter/vf_sr.c                20
6 files changed, 51 insertions, 49 deletions
diff --git a/libavfilter/dnn_backend_native.c b/libavfilter/dnn_backend_native.c
index fe4311693a..18735c025c 100644
--- a/libavfilter/dnn_backend_native.c
+++ b/libavfilter/dnn_backend_native.c
@@ -25,7 +25,7 @@
#include "dnn_backend_native.h"
-static DNNReturnType set_input_output_native(void *model, DNNData *input, const char *input_name, DNNData *output, const char *output_name)
+static DNNReturnType set_input_output_native(void *model, DNNData *input, const char *input_name, const char *output_name)
{
ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;
InputParams *input_params;
@@ -81,11 +81,6 @@ static DNNReturnType set_input_output_native(void *model, DNNData *input, const
}
}
- output->data = network->layers[network->layers_num - 1].output;
- output->height = cur_height;
- output->width = cur_width;
- output->channels = cur_channels;
-
return DNN_SUCCESS;
}
@@ -280,7 +275,7 @@ static void depth_to_space(const float *input, float *output, int block_size, in
}
}
-DNNReturnType ff_dnn_execute_model_native(const DNNModel *model)
+DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, DNNData *output)
{
ConvolutionalNetwork *network = (ConvolutionalNetwork *)model->model;
int cur_width, cur_height, cur_channels;
@@ -322,6 +317,11 @@ DNNReturnType ff_dnn_execute_model_native(const DNNModel *model)
}
}
+ output->data = network->layers[network->layers_num - 1].output;
+ output->height = cur_height;
+ output->width = cur_width;
+ output->channels = cur_channels;
+
return DNN_SUCCESS;
}
diff --git a/libavfilter/dnn_backend_native.h b/libavfilter/dnn_backend_native.h
index 51d4cac955..adaf4a75e2 100644
--- a/libavfilter/dnn_backend_native.h
+++ b/libavfilter/dnn_backend_native.h
@@ -63,7 +63,7 @@ typedef struct ConvolutionalNetwork{
DNNModel *ff_dnn_load_model_native(const char *model_filename);
-DNNReturnType ff_dnn_execute_model_native(const DNNModel *model);
+DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, DNNData *output);
void ff_dnn_free_model_native(DNNModel **model);
diff --git a/libavfilter/dnn_backend_tf.c b/libavfilter/dnn_backend_tf.c
index a838907d98..7bee45c5d3 100644
--- a/libavfilter/dnn_backend_tf.c
+++ b/libavfilter/dnn_backend_tf.c
@@ -35,7 +35,6 @@ typedef struct TFModel{
TF_Status *status;
TF_Output input, output;
TF_Tensor *input_tensor;
- DNNData *output_data;
} TFModel;
static void free_buffer(void *data, size_t length)
@@ -76,13 +75,12 @@ static TF_Buffer *read_graph(const char *model_filename)
return graph_buf;
}
-static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, DNNData *output, const char *output_name)
+static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char *output_name)
{
TFModel *tf_model = (TFModel *)model;
int64_t input_dims[] = {1, input->height, input->width, input->channels};
TF_SessionOptions *sess_opts;
const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init");
- TF_Tensor *output_tensor;
// Input operation
tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, input_name);
@@ -132,26 +130,6 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char
}
}
- // Execute network to get output height, width and number of channels
- TF_SessionRun(tf_model->session, NULL,
- &tf_model->input, &tf_model->input_tensor, 1,
- &tf_model->output, &output_tensor, 1,
- NULL, 0, NULL, tf_model->status);
- if (TF_GetCode(tf_model->status) != TF_OK){
- return DNN_ERROR;
- }
- else{
- output->height = TF_Dim(output_tensor, 1);
- output->width = TF_Dim(output_tensor, 2);
- output->channels = TF_Dim(output_tensor, 3);
- output->data = av_malloc(output->height * output->width * output->channels * sizeof(float));
- if (!output->data){
- return DNN_ERROR;
- }
- tf_model->output_data = output;
- TF_DeleteTensor(output_tensor);
- }
-
return DNN_SUCCESS;
}
@@ -489,7 +467,6 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename)
}
tf_model->session = NULL;
tf_model->input_tensor = NULL;
- tf_model->output_data = NULL;
if (load_tf_model(tf_model, model_filename) != DNN_SUCCESS){
if (load_native_model(tf_model, model_filename) != DNN_SUCCESS){
@@ -508,10 +485,12 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename)
-DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model)
+DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *output)
{
TFModel *tf_model = (TFModel *)model->model;
TF_Tensor *output_tensor;
+ uint64_t count;
+ uint64_t old_count = output->height * output->width * output->channels * sizeof(float);
TF_SessionRun(tf_model->session, NULL,
&tf_model->input, &tf_model->input_tensor, 1,
@@ -521,14 +500,26 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model)
if (TF_GetCode(tf_model->status) != TF_OK){
return DNN_ERROR;
}
- else{
- memcpy(tf_model->output_data->data, TF_TensorData(output_tensor),
- tf_model->output_data->height * tf_model->output_data->width *
- tf_model->output_data->channels * sizeof(float));
- TF_DeleteTensor(output_tensor);
- return DNN_SUCCESS;
+ output->height = TF_Dim(output_tensor, 1);
+ output->width = TF_Dim(output_tensor, 2);
+ output->channels = TF_Dim(output_tensor, 3);
+ count = output->height * output->width * output->channels * sizeof(float);
+ if (output->data) {
+ if (count > old_count) {
+ av_freep(&output->data);
+ }
+ }
+ if (!output->data) {
+ output->data = av_malloc(count);
+ if (!output->data){
+ return DNN_ERROR;
+ }
}
+ memcpy(output->data, TF_TensorData(output_tensor), count);
+ TF_DeleteTensor(output_tensor);
+
+ return DNN_SUCCESS;
}
void ff_dnn_free_model_tf(DNNModel **model)
@@ -550,9 +541,6 @@ void ff_dnn_free_model_tf(DNNModel **model)
if (tf_model->input_tensor){
TF_DeleteTensor(tf_model->input_tensor);
}
- if (tf_model->output_data){
- av_freep(&tf_model->output_data->data);
- }
av_freep(&tf_model);
av_freep(model);
}
diff --git a/libavfilter/dnn_backend_tf.h b/libavfilter/dnn_backend_tf.h
index 7ba84f40ee..47a24ec7b7 100644
--- a/libavfilter/dnn_backend_tf.h
+++ b/libavfilter/dnn_backend_tf.h
@@ -31,7 +31,7 @@
DNNModel *ff_dnn_load_model_tf(const char *model_filename);
-DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model);
+DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *output);
void ff_dnn_free_model_tf(DNNModel **model);
diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
index 0390e39b99..822f6e5b68 100644
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -38,9 +38,9 @@ typedef struct DNNData{
typedef struct DNNModel{
// Stores model that can be different for different backends.
void *model;
- // Sets model input and output, while allocating additional memory for intermediate calculations.
+ // Sets model input and output.
// Should be called at least once before model execution.
- DNNReturnType (*set_input_output)(void *model, DNNData *input, const char *input_name, DNNData *output, const char *output_name);
+ DNNReturnType (*set_input_output)(void *model, DNNData *input, const char *input_name, const char *output_name);
} DNNModel;
// Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
@@ -48,7 +48,7 @@ typedef struct DNNModule{
// Loads model and parameters from given file. Returns NULL if it is not possible.
DNNModel *(*load_model)(const char *model_filename);
// Executes model with specified input and output. Returns DNN_ERROR otherwise.
- DNNReturnType (*execute_model)(const DNNModel *model);
+ DNNReturnType (*execute_model)(const DNNModel *model, DNNData *output);
// Frees memory allocated for model.
void (*free_model)(DNNModel **model);
} DNNModule;
diff --git a/libavfilter/vf_sr.c b/libavfilter/vf_sr.c
index 0c048e03a5..577b4fcb75 100644
--- a/libavfilter/vf_sr.c
+++ b/libavfilter/vf_sr.c
@@ -121,20 +121,31 @@ static int config_props(AVFilterLink *inlink)
sr_context->input.height = inlink->h * sr_context->scale_factor;
sr_context->input.channels = 1;
- result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, "x", &sr_context->output, "y");
+ result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, "x", "y");
if (result != DNN_SUCCESS){
av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n");
return AVERROR(EIO);
}
+ result = (sr_context->dnn_module->execute_model)(sr_context->model, &sr_context->output);
+ if (result != DNN_SUCCESS){
+ av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n");
+ return AVERROR(EIO);
+ }
+
if (sr_context->input.height != sr_context->output.height || sr_context->input.width != sr_context->output.width){
sr_context->input.width = inlink->w;
sr_context->input.height = inlink->h;
- result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, "x", &sr_context->output, "y");
+ result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, "x", "y");
if (result != DNN_SUCCESS){
av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n");
return AVERROR(EIO);
}
+ result = (sr_context->dnn_module->execute_model)(sr_context->model, &sr_context->output);
+ if (result != DNN_SUCCESS){
+ av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n");
+ return AVERROR(EIO);
+ }
sr_context->scale_factor = 0;
}
outlink->h = sr_context->output.height;
@@ -245,7 +256,7 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
}
av_frame_free(&in);
- dnn_result = (sr_context->dnn_module->execute_model)(sr_context->model);
+ dnn_result = (sr_context->dnn_module->execute_model)(sr_context->model, &sr_context->output);
if (dnn_result != DNN_SUCCESS){
av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n");
return AVERROR(EIO);
@@ -263,6 +274,9 @@ static av_cold void uninit(AVFilterContext *context)
int i;
SRContext *sr_context = context->priv;
+ if (sr_context->backend_type == DNN_TF)
+ av_freep(&sr_context->output.data);
+
if (sr_context->dnn_module){
(sr_context->dnn_module->free_model)(&sr_context->model);
av_freep(&sr_context->dnn_module);
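
For reference, a minimal caller sketch of the interface as it stands after this patch, mirroring the vf_sr.c usage above: set_input_output() no longer receives the output, and execute_model() fills a caller-owned DNNData. The helper name run_model_once is illustrative only and is not part of the patch.

```c
#include "libavfilter/dnn_interface.h"

/* Illustrative helper (not part of this patch): runs one inference with
 * the post-patch interface. The caller owns 'output'; execute_model()
 * fills in data/width/height/channels. For the TensorFlow backend the
 * caller must eventually av_freep(&output->data), as vf_sr.c now does
 * in uninit(). */
static int run_model_once(DNNModule *dnn_module, DNNModel *model,
                          DNNData *input, DNNData *output)
{
    /* Bind the named input and output tensors ("x"/"y" as in vf_sr.c). */
    if ((model->set_input_output)(model->model, input, "x", "y") != DNN_SUCCESS)
        return -1;

    /* Execution now also reports the output dimensions and (re)allocates
     * the output buffer where needed. */
    if ((dnn_module->execute_model)(model, output) != DNN_SUCCESS)
        return -1;

    return 0;
}
```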