summaryrefslogtreecommitdiff
path: root/libavfilter/dnn/dnn_backend_tf.c
diff options
context:
space:
mode:
Diffstat (limited to 'libavfilter/dnn/dnn_backend_tf.c')
-rw-r--r--libavfilter/dnn/dnn_backend_tf.c49
1 files changed, 25 insertions, 24 deletions
diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c
index 45da29ae70..b6b1812cd9 100644
--- a/libavfilter/dnn/dnn_backend_tf.c
+++ b/libavfilter/dnn/dnn_backend_tf.c
@@ -155,7 +155,7 @@ static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input
TF_DeleteStatus(status);
// currently only NHWC is supported
- av_assert0(dims[0] == 1);
+ av_assert0(dims[0] == 1 || dims[0] == -1);
input->height = dims[1];
input->width = dims[2];
input->channels = dims[3];
@@ -707,7 +707,7 @@ static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_n
TF_Output *tf_outputs;
TFModel *tf_model = model->model;
TFContext *ctx = &tf_model->ctx;
- DNNData input, output;
+ DNNData input, *outputs;
TF_Tensor **output_tensors;
TF_Output tf_input;
TF_Tensor *input_tensor;
@@ -738,14 +738,6 @@ static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_n
}
}
- if (nb_output != 1) {
- // currently, the filter does not need multiple outputs,
- // so we just pending the support until we really need it.
- TF_DeleteTensor(input_tensor);
- avpriv_report_missing_feature(ctx, "multiple outputs");
- return DNN_ERROR;
- }
-
tf_outputs = av_malloc_array(nb_output, sizeof(*tf_outputs));
if (tf_outputs == NULL) {
TF_DeleteTensor(input_tensor);
@@ -785,23 +777,31 @@ static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_n
return DNN_ERROR;
}
+ outputs = av_malloc_array(nb_output, sizeof(*outputs));
+ if (!outputs) {
+ TF_DeleteTensor(input_tensor);
+ av_freep(&tf_outputs);
+ av_freep(&output_tensors);
+ av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for *outputs\n"); \
+ return DNN_ERROR;
+ }
+
for (uint32_t i = 0; i < nb_output; ++i) {
- output.height = TF_Dim(output_tensors[i], 1);
- output.width = TF_Dim(output_tensors[i], 2);
- output.channels = TF_Dim(output_tensors[i], 3);
- output.data = TF_TensorData(output_tensors[i]);
- output.dt = TF_TensorType(output_tensors[i]);
-
- if (do_ioproc) {
- if (tf_model->model->frame_post_proc != NULL) {
- tf_model->model->frame_post_proc(out_frame, &output, tf_model->model->filter_ctx);
- } else {
- ff_proc_from_dnn_to_frame(out_frame, &output, ctx);
- }
+ outputs[i].height = TF_Dim(output_tensors[i], 1);
+ outputs[i].width = TF_Dim(output_tensors[i], 2);
+ outputs[i].channels = TF_Dim(output_tensors[i], 3);
+ outputs[i].data = TF_TensorData(output_tensors[i]);
+ outputs[i].dt = TF_TensorType(output_tensors[i]);
+ }
+ if (do_ioproc) {
+ if (tf_model->model->frame_post_proc != NULL) {
+ tf_model->model->frame_post_proc(out_frame, outputs, tf_model->model->filter_ctx);
} else {
- out_frame->width = output.width;
- out_frame->height = output.height;
+ ff_proc_from_dnn_to_frame(out_frame, outputs, ctx);
}
+ } else {
+ out_frame->width = outputs[0].width;
+ out_frame->height = outputs[0].height;
}
for (uint32_t i = 0; i < nb_output; ++i) {
@@ -812,6 +812,7 @@ static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_n
TF_DeleteTensor(input_tensor);
av_freep(&output_tensors);
av_freep(&tf_outputs);
+ av_freep(&outputs);
return DNN_SUCCESS;
}