summaryrefslogtreecommitdiff
path: root/libavfilter/dnn_backend_tf.c
diff options
context:
space:
mode:
Diffstat (limited to 'libavfilter/dnn_backend_tf.c')
-rw-r--r--libavfilter/dnn_backend_tf.c56
1 files changed, 22 insertions, 34 deletions
diff --git a/libavfilter/dnn_backend_tf.c b/libavfilter/dnn_backend_tf.c
index a838907d98..7bee45c5d3 100644
--- a/libavfilter/dnn_backend_tf.c
+++ b/libavfilter/dnn_backend_tf.c
@@ -35,7 +35,6 @@ typedef struct TFModel{
TF_Status *status;
TF_Output input, output;
TF_Tensor *input_tensor;
- DNNData *output_data;
} TFModel;
static void free_buffer(void *data, size_t length)
@@ -76,13 +75,12 @@ static TF_Buffer *read_graph(const char *model_filename)
return graph_buf;
}
-static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, DNNData *output, const char *output_name)
+static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char *output_name)
{
TFModel *tf_model = (TFModel *)model;
int64_t input_dims[] = {1, input->height, input->width, input->channels};
TF_SessionOptions *sess_opts;
const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init");
- TF_Tensor *output_tensor;
// Input operation
tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, input_name);
@@ -132,26 +130,6 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char
}
}
- // Execute network to get output height, width and number of channels
- TF_SessionRun(tf_model->session, NULL,
- &tf_model->input, &tf_model->input_tensor, 1,
- &tf_model->output, &output_tensor, 1,
- NULL, 0, NULL, tf_model->status);
- if (TF_GetCode(tf_model->status) != TF_OK){
- return DNN_ERROR;
- }
- else{
- output->height = TF_Dim(output_tensor, 1);
- output->width = TF_Dim(output_tensor, 2);
- output->channels = TF_Dim(output_tensor, 3);
- output->data = av_malloc(output->height * output->width * output->channels * sizeof(float));
- if (!output->data){
- return DNN_ERROR;
- }
- tf_model->output_data = output;
- TF_DeleteTensor(output_tensor);
- }
-
return DNN_SUCCESS;
}
@@ -489,7 +467,6 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename)
}
tf_model->session = NULL;
tf_model->input_tensor = NULL;
- tf_model->output_data = NULL;
if (load_tf_model(tf_model, model_filename) != DNN_SUCCESS){
if (load_native_model(tf_model, model_filename) != DNN_SUCCESS){
@@ -508,10 +485,12 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename)
-DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model)
+DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *output)
{
TFModel *tf_model = (TFModel *)model->model;
TF_Tensor *output_tensor;
+ uint64_t count;
+ uint64_t old_count = output->height * output->width * output->channels * sizeof(float);
TF_SessionRun(tf_model->session, NULL,
&tf_model->input, &tf_model->input_tensor, 1,
@@ -521,14 +500,26 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model)
if (TF_GetCode(tf_model->status) != TF_OK){
return DNN_ERROR;
}
- else{
- memcpy(tf_model->output_data->data, TF_TensorData(output_tensor),
- tf_model->output_data->height * tf_model->output_data->width *
- tf_model->output_data->channels * sizeof(float));
- TF_DeleteTensor(output_tensor);
- return DNN_SUCCESS;
+ output->height = TF_Dim(output_tensor, 1);
+ output->width = TF_Dim(output_tensor, 2);
+ output->channels = TF_Dim(output_tensor, 3);
+ count = output->height * output->width * output->channels * sizeof(float);
+ if (output->data) {
+ if (count > old_count) {
+ av_freep(&output->data);
+ }
+ }
+ if (!output->data) {
+ output->data = av_malloc(count);
+ if (!output->data){
+ return DNN_ERROR;
+ }
}
+ memcpy(output->data, TF_TensorData(output_tensor), count);
+ TF_DeleteTensor(output_tensor);
+
+ return DNN_SUCCESS;
}
void ff_dnn_free_model_tf(DNNModel **model)
@@ -550,9 +541,6 @@ void ff_dnn_free_model_tf(DNNModel **model)
if (tf_model->input_tensor){
TF_DeleteTensor(tf_model->input_tensor);
}
- if (tf_model->output_data){
- av_freep(&tf_model->output_data->data);
- }
av_freep(&tf_model);
av_freep(model);
}