summaryrefslogtreecommitdiff
path: root/libavfilter/dnn_filter_common.c
diff options
context:
space:
mode:
Diffstat (limited to 'libavfilter/dnn_filter_common.c')
-rw-r--r--libavfilter/dnn_filter_common.c53
1 files changed, 44 insertions, 9 deletions
diff --git a/libavfilter/dnn_filter_common.c b/libavfilter/dnn_filter_common.c
index 52c7a5392a..0ed0ac2e30 100644
--- a/libavfilter/dnn_filter_common.c
+++ b/libavfilter/dnn_filter_common.c
@@ -17,6 +17,39 @@
*/
#include "dnn_filter_common.h"
+#include "libavutil/avstring.h"
+
+#define MAX_SUPPORTED_OUTPUTS_NB 4
+
+static char **separate_output_names(const char *expr, const char *val_sep, int *separated_nb)
+{
+ char *val, **parsed_vals = NULL;
+ int val_num = 0;
+ if (!expr || !val_sep || !separated_nb) {
+ return NULL;
+ }
+
+ parsed_vals = av_mallocz_array(MAX_SUPPORTED_OUTPUTS_NB, sizeof(*parsed_vals));
+ if (!parsed_vals) {
+ return NULL;
+ }
+
+ do {
+ val = av_get_token(&expr, val_sep);
+ if(val) {
+ parsed_vals[val_num] = val;
+ val_num++;
+ }
+ if (*expr) {
+ expr++;
+ }
+ } while(*expr);
+
+ parsed_vals[val_num] = NULL;
+ *separated_nb = val_num;
+
+ return parsed_vals;
+}
int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx)
{
@@ -28,8 +61,10 @@ int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *fil
av_log(filter_ctx, AV_LOG_ERROR, "input name of the model network is not specified\n");
return AVERROR(EINVAL);
}
- if (!ctx->model_outputname) {
- av_log(filter_ctx, AV_LOG_ERROR, "output name of the model network is not specified\n");
+
+ ctx->model_outputnames = separate_output_names(ctx->model_outputnames_string, "&", &ctx->nb_outputs);
+ if (!ctx->model_outputnames) {
+ av_log(filter_ctx, AV_LOG_ERROR, "could not parse model output names\n");
return AVERROR(EINVAL);
}
@@ -91,15 +126,15 @@ DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input)
DNNReturnType ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height)
{
return ctx->model->get_output(ctx->model->model, ctx->model_inputname, input_width, input_height,
- ctx->model_outputname, output_width, output_height);
+ (const char *)ctx->model_outputnames[0], output_width, output_height);
}
DNNReturnType ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame)
{
DNNExecBaseParams exec_params = {
.input_name = ctx->model_inputname,
- .output_names = (const char **)&ctx->model_outputname,
- .nb_output = 1,
+ .output_names = (const char **)ctx->model_outputnames,
+ .nb_output = ctx->nb_outputs,
.in_frame = in_frame,
.out_frame = out_frame,
};
@@ -110,8 +145,8 @@ DNNReturnType ff_dnn_execute_model_async(DnnContext *ctx, AVFrame *in_frame, AVF
{
DNNExecBaseParams exec_params = {
.input_name = ctx->model_inputname,
- .output_names = (const char **)&ctx->model_outputname,
- .nb_output = 1,
+ .output_names = (const char **)ctx->model_outputnames,
+ .nb_output = ctx->nb_outputs,
.in_frame = in_frame,
.out_frame = out_frame,
};
@@ -123,8 +158,8 @@ DNNReturnType ff_dnn_execute_model_classification(DnnContext *ctx, AVFrame *in_f
DNNExecClassificationParams class_params = {
{
.input_name = ctx->model_inputname,
- .output_names = (const char **)&ctx->model_outputname,
- .nb_output = 1,
+ .output_names = (const char **)ctx->model_outputnames,
+ .nb_output = ctx->nb_outputs,
.in_frame = in_frame,
.out_frame = out_frame,
},