author     Guo, Yejun <yejun.guo@intel.com>    2020-08-13 16:19:48 +0800
committer  Guo, Yejun <yejun.guo@intel.com>    2020-08-25 09:02:59 +0800
commit     0f7a99e37ae52f9ecdc4c81195c14b03f5be3dfd (patch)
tree       00b5828f4f284ec9e363708ed89d4b2294e62a52 /libavfilter/dnn/dnn_backend_tf.c
parent     b61376bdee61c08732105fa331eb076497eface9 (diff)
dnn: move output name from DNNModel.set_input_output to DNNModule.execute_model
Currently, the output is set both in DNNModel.set_input_output and in DNNModule.execute_model. It makes more sense to provide the output names at model inference time, so that all the output info is set in a single place; DNNModel.set_input_output is therefore renamed to DNNModel.set_input.

Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
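For context, the sketch below illustrates the caller-side effect of the change, assuming the DNN interface of this period (DNNModule, DNNModel, DNNData, DNN_SUCCESS from dnn_interface.h). The helper name run_tf_model_sketch and the tensor names "x"/"y" are hypothetical and not part of the patch:

    #include "dnn_interface.h"   /* assumes building inside libavfilter */

    /* Hypothetical sketch, not part of this patch: "x"/"y" are placeholder
     * tensor names and run_tf_model_sketch is an invented helper. */
    static int run_tf_model_sketch(DNNModule *dnn_module, DNNModel *model)
    {
        DNNData input, output;
        const char *output_names[] = { "y" };

        /* query the input shape for the operation named "x" (unchanged) */
        if (model->get_input(model->model, &input, "x") != DNN_SUCCESS)
            return -1;

        /* old API: output names were bound together with the input
         *     model->set_input_output(model->model, &input, "x", output_names, 1);
         *     dnn_module->execute_model(model, &output, 1);
         * new API: set_input binds only the input ... */
        if (model->set_input(model->model, &input, "x") != DNN_SUCCESS)
            return -1;

        /* ... and the output names are supplied at inference time */
        if (dnn_module->execute_model(model, &output, output_names, 1) != DNN_SUCCESS)
            return -1;

        return 0;
    }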
Diffstat (limited to 'libavfilter/dnn/dnn_backend_tf.c')
-rw-r--r--  libavfilter/dnn/dnn_backend_tf.c  87
1 file changed, 35 insertions(+), 52 deletions(-)
diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c
index 9d079aa92e..bdc90d5063 100644
--- a/libavfilter/dnn/dnn_backend_tf.c
+++ b/libavfilter/dnn/dnn_backend_tf.c
@@ -40,7 +40,6 @@ typedef struct TFModel{
     TF_Status *status;
     TF_Output input;
     TF_Tensor *input_tensor;
-    TF_Output *outputs;
     TF_Tensor **output_tensors;
     uint32_t nb_output;
 } TFModel;
@@ -136,7 +135,7 @@ static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input
     return DNN_SUCCESS;
 }
-static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output)
+static DNNReturnType set_input_tf(void *model, DNNData *input, const char *input_name)
 {
     TFModel *tf_model = (TFModel *)model;
     TF_SessionOptions *sess_opts;
@@ -157,40 +156,7 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char
     }
     input->data = (float *)TF_TensorData(tf_model->input_tensor);
-    // Output operation
-    if (nb_output == 0)
-        return DNN_ERROR;
-
-    av_freep(&tf_model->outputs);
-    tf_model->outputs = av_malloc_array(nb_output, sizeof(*tf_model->outputs));
-    if (!tf_model->outputs)
-        return DNN_ERROR;
-    for (int i = 0; i < nb_output; ++i) {
-        tf_model->outputs[i].oper = TF_GraphOperationByName(tf_model->graph, output_names[i]);
-        if (!tf_model->outputs[i].oper){
-            av_freep(&tf_model->outputs);
-            return DNN_ERROR;
-        }
-        tf_model->outputs[i].index = 0;
-    }
-
-    if (tf_model->output_tensors) {
-        for (uint32_t i = 0; i < tf_model->nb_output; ++i) {
-            if (tf_model->output_tensors[i]) {
-                TF_DeleteTensor(tf_model->output_tensors[i]);
-                tf_model->output_tensors[i] = NULL;
-            }
-        }
-    }
-    av_freep(&tf_model->output_tensors);
-    tf_model->output_tensors = av_mallocz_array(nb_output, sizeof(*tf_model->output_tensors));
-    if (!tf_model->output_tensors) {
-        av_freep(&tf_model->outputs);
-        return DNN_ERROR;
-    }
-
-    tf_model->nb_output = nb_output;
-
+    // session
     if (tf_model->session){
         TF_CloseSession(tf_model->session, tf_model->status);
         TF_DeleteSession(tf_model->session, tf_model->status);
@@ -598,40 +564,57 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options)
     }
     model->model = (void *)tf_model;
-    model->set_input_output = &set_input_output_tf;
+    model->set_input = &set_input_tf;
     model->get_input = &get_input_tf;
     model->options = options;
     return model;
 }
-
-
-DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, uint32_t nb_output)
+DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output)
 {
+    TF_Output *tf_outputs;
     TFModel *tf_model = (TFModel *)model->model;
-    uint32_t nb = FFMIN(nb_output, tf_model->nb_output);
-    if (nb == 0)
+
+    tf_outputs = av_malloc_array(nb_output, sizeof(*tf_outputs));
+    if (tf_outputs == NULL)
         return DNN_ERROR;
-    av_assert0(tf_model->output_tensors);
-    for (uint32_t i = 0; i < tf_model->nb_output; ++i) {
-        if (tf_model->output_tensors[i]) {
-            TF_DeleteTensor(tf_model->output_tensors[i]);
-            tf_model->output_tensors[i] = NULL;
+    if (tf_model->output_tensors) {
+        for (uint32_t i = 0; i < tf_model->nb_output; ++i) {
+            if (tf_model->output_tensors[i]) {
+                TF_DeleteTensor(tf_model->output_tensors[i]);
+                tf_model->output_tensors[i] = NULL;
+            }
         }
     }
+    av_freep(&tf_model->output_tensors);
+    tf_model->nb_output = nb_output;
+    tf_model->output_tensors = av_mallocz_array(nb_output, sizeof(*tf_model->output_tensors));
+    if (!tf_model->output_tensors) {
+        av_freep(&tf_outputs);
+        return DNN_ERROR;
+    }
+
+    for (int i = 0; i < nb_output; ++i) {
+        tf_outputs[i].oper = TF_GraphOperationByName(tf_model->graph, output_names[i]);
+        if (!tf_outputs[i].oper) {
+            av_freep(&tf_outputs);
+            return DNN_ERROR;
+        }
+        tf_outputs[i].index = 0;
+    }
     TF_SessionRun(tf_model->session, NULL,
                   &tf_model->input, &tf_model->input_tensor, 1,
-                  tf_model->outputs, tf_model->output_tensors, nb,
+                  tf_outputs, tf_model->output_tensors, nb_output,
                   NULL, 0, NULL, tf_model->status);
-
-    if (TF_GetCode(tf_model->status) != TF_OK){
+    if (TF_GetCode(tf_model->status) != TF_OK) {
+        av_freep(&tf_outputs);
         return DNN_ERROR;
     }
-    for (uint32_t i = 0; i < nb; ++i) {
+    for (uint32_t i = 0; i < nb_output; ++i) {
         outputs[i].height = TF_Dim(tf_model->output_tensors[i], 1);
         outputs[i].width = TF_Dim(tf_model->output_tensors[i], 2);
         outputs[i].channels = TF_Dim(tf_model->output_tensors[i], 3);
@@ -639,6 +622,7 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, u
         outputs[i].dt = TF_TensorType(tf_model->output_tensors[i]);
     }
+    av_freep(&tf_outputs);
     return DNN_SUCCESS;
 }
@@ -669,7 +653,6 @@ void ff_dnn_free_model_tf(DNNModel **model)
                 }
             }
         }
-        av_freep(&tf_model->outputs);
         av_freep(&tf_model->output_tensors);
         av_freep(&tf_model);
         av_freep(model);