diff options
author | Guo, Yejun <yejun.guo@intel.com> | 2020-08-28 12:51:44 +0800 |
---|---|---|
committer | Guo, Yejun <yejun.guo@intel.com> | 2020-09-21 21:26:56 +0800 |
commit | 2003e32f62d94ba75b59d70632c9f2862b383591 (patch) | |
tree | 55ec60788bc740eb45dbafd613bd8cf50a10417a /libavfilter/vf_dnn_processing.c | |
parent | 6918e240d706f7390272976d8b8d502afe426a18 (diff) |
dnn: change dnn interface to replace DNNData* with AVFrame*
Currently, every filter needs to provide code to transfer data from
AVFrame* to the model input (DNNData*), and also from the model output
(DNNData*) to AVFrame*. Such transfers can instead be implemented
within the DNN module, so that each filter can focus on its own business logic.
The DNN module also exports the function pointers pre_proc and post_proc
in struct DNNModel, in case a filter has special logic
for transferring data between AVFrame* and DNNData*. The default implementation
within the DNN module is used if the filter does not set pre/post_proc.
Diffstat (limited to 'libavfilter/vf_dnn_processing.c')
-rw-r--r-- | libavfilter/vf_dnn_processing.c | 240 |
1 file changed, 41 insertions, 199 deletions
diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c index f120bf9df4..d7462bc828 100644 --- a/libavfilter/vf_dnn_processing.c +++ b/libavfilter/vf_dnn_processing.c @@ -46,12 +46,6 @@ typedef struct DnnProcessingContext { DNNModule *dnn_module; DNNModel *model; - // input & output of the model at execution time - DNNData input; - DNNData output; - - struct SwsContext *sws_gray8_to_grayf32; - struct SwsContext *sws_grayf32_to_gray8; struct SwsContext *sws_uv_scale; int sws_uv_height; } DnnProcessingContext; @@ -103,7 +97,7 @@ static av_cold int init(AVFilterContext *context) return AVERROR(EINVAL); } - ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename, ctx->backend_options, NULL); + ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename, ctx->backend_options, ctx); if (!ctx->model) { av_log(ctx, AV_LOG_ERROR, "could not load DNN model\n"); return AVERROR(EINVAL); @@ -148,6 +142,10 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin model_input->width, inlink->w); return AVERROR(EIO); } + if (model_input->dt != DNN_FLOAT) { + av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32.\n"); + return AVERROR(EIO); + } switch (fmt) { case AV_PIX_FMT_RGB24: @@ -156,20 +154,6 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin LOG_FORMAT_CHANNEL_MISMATCH(); return AVERROR(EIO); } - if (model_input->dt != DNN_FLOAT && model_input->dt != DNN_UINT8) { - av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32 and uint8.\n"); - return AVERROR(EIO); - } - return 0; - case AV_PIX_FMT_GRAY8: - if (model_input->channels != 1) { - LOG_FORMAT_CHANNEL_MISMATCH(); - return AVERROR(EIO); - } - if (model_input->dt != DNN_UINT8) { - av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type uint8.\n"); - return AVERROR(EIO); - } return 0; case AV_PIX_FMT_GRAYF32: case AV_PIX_FMT_YUV420P: @@ -181,10 +165,6 @@ static 
int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin LOG_FORMAT_CHANNEL_MISMATCH(); return AVERROR(EIO); } - if (model_input->dt != DNN_FLOAT) { - av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type float32.\n"); - return AVERROR(EIO); - } return 0; default: av_log(ctx, AV_LOG_ERROR, "%s not supported.\n", av_get_pix_fmt_name(fmt)); @@ -213,74 +193,24 @@ static int config_input(AVFilterLink *inlink) return check; } - ctx->input.width = inlink->w; - ctx->input.height = inlink->h; - ctx->input.channels = model_input.channels; - ctx->input.dt = model_input.dt; - - result = (ctx->model->set_input)(ctx->model->model, - &ctx->input, ctx->model_inputname); - if (result != DNN_SUCCESS) { - av_log(ctx, AV_LOG_ERROR, "could not set input and output for the model\n"); - return AVERROR(EIO); - } - return 0; } -static int prepare_sws_context(AVFilterLink *outlink) +static av_always_inline int isPlanarYUV(enum AVPixelFormat pix_fmt) +{ + const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(pix_fmt); + av_assert0(desc); + return !(desc->flags & AV_PIX_FMT_FLAG_RGB) && desc->nb_components == 3; +} + +static int prepare_uv_scale(AVFilterLink *outlink) { AVFilterContext *context = outlink->src; DnnProcessingContext *ctx = context->priv; AVFilterLink *inlink = context->inputs[0]; enum AVPixelFormat fmt = inlink->format; - DNNDataType input_dt = ctx->input.dt; - DNNDataType output_dt = ctx->output.dt; - - switch (fmt) { - case AV_PIX_FMT_RGB24: - case AV_PIX_FMT_BGR24: - if (input_dt == DNN_FLOAT) { - ctx->sws_gray8_to_grayf32 = sws_getContext(inlink->w * 3, - inlink->h, - AV_PIX_FMT_GRAY8, - inlink->w * 3, - inlink->h, - AV_PIX_FMT_GRAYF32, - 0, NULL, NULL, NULL); - } - if (output_dt == DNN_FLOAT) { - ctx->sws_grayf32_to_gray8 = sws_getContext(outlink->w * 3, - outlink->h, - AV_PIX_FMT_GRAYF32, - outlink->w * 3, - outlink->h, - AV_PIX_FMT_GRAY8, - 0, NULL, NULL, NULL); - } - return 0; - case AV_PIX_FMT_YUV420P: - case AV_PIX_FMT_YUV422P: - 
case AV_PIX_FMT_YUV444P: - case AV_PIX_FMT_YUV410P: - case AV_PIX_FMT_YUV411P: - av_assert0(input_dt == DNN_FLOAT); - av_assert0(output_dt == DNN_FLOAT); - ctx->sws_gray8_to_grayf32 = sws_getContext(inlink->w, - inlink->h, - AV_PIX_FMT_GRAY8, - inlink->w, - inlink->h, - AV_PIX_FMT_GRAYF32, - 0, NULL, NULL, NULL); - ctx->sws_grayf32_to_gray8 = sws_getContext(outlink->w, - outlink->h, - AV_PIX_FMT_GRAYF32, - outlink->w, - outlink->h, - AV_PIX_FMT_GRAY8, - 0, NULL, NULL, NULL); + if (isPlanarYUV(fmt)) { if (inlink->w != outlink->w || inlink->h != outlink->h) { const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(fmt); int sws_src_h = AV_CEIL_RSHIFT(inlink->h, desc->log2_chroma_h); @@ -292,10 +222,6 @@ static int prepare_sws_context(AVFilterLink *outlink) SWS_BICUBIC, NULL, NULL, NULL); ctx->sws_uv_height = sws_src_h; } - return 0; - default: - //do nothing - break; } return 0; @@ -306,120 +232,34 @@ static int config_output(AVFilterLink *outlink) AVFilterContext *context = outlink->src; DnnProcessingContext *ctx = context->priv; DNNReturnType result; + AVFilterLink *inlink = context->inputs[0]; + AVFrame *out = NULL; - // have a try run in case that the dnn model resize the frame - result = (ctx->dnn_module->execute_model)(ctx->model, &ctx->output, (const char **)&ctx->model_outputname, 1); - if (result != DNN_SUCCESS){ - av_log(ctx, AV_LOG_ERROR, "failed to execute model\n"); + AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h); + result = (ctx->model->set_input)(ctx->model->model, fake_in, ctx->model_inputname); + if (result != DNN_SUCCESS) { + av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n"); return AVERROR(EIO); } - outlink->w = ctx->output.width; - outlink->h = ctx->output.height; - - prepare_sws_context(outlink); - - return 0; -} - -static int copy_from_frame_to_dnn(DnnProcessingContext *ctx, const AVFrame *frame) -{ - int bytewidth = av_image_get_linesize(frame->format, frame->width, 0); - DNNData *dnn_input = &ctx->input; - - 
switch (frame->format) { - case AV_PIX_FMT_RGB24: - case AV_PIX_FMT_BGR24: - if (dnn_input->dt == DNN_FLOAT) { - sws_scale(ctx->sws_gray8_to_grayf32, (const uint8_t **)frame->data, frame->linesize, - 0, frame->height, (uint8_t * const*)(&dnn_input->data), - (const int [4]){frame->width * 3 * sizeof(float), 0, 0, 0}); - } else { - av_assert0(dnn_input->dt == DNN_UINT8); - av_image_copy_plane(dnn_input->data, bytewidth, - frame->data[0], frame->linesize[0], - bytewidth, frame->height); - } - return 0; - case AV_PIX_FMT_GRAY8: - case AV_PIX_FMT_GRAYF32: - av_image_copy_plane(dnn_input->data, bytewidth, - frame->data[0], frame->linesize[0], - bytewidth, frame->height); - return 0; - case AV_PIX_FMT_YUV420P: - case AV_PIX_FMT_YUV422P: - case AV_PIX_FMT_YUV444P: - case AV_PIX_FMT_YUV410P: - case AV_PIX_FMT_YUV411P: - sws_scale(ctx->sws_gray8_to_grayf32, (const uint8_t **)frame->data, frame->linesize, - 0, frame->height, (uint8_t * const*)(&dnn_input->data), - (const int [4]){frame->width * sizeof(float), 0, 0, 0}); - return 0; - default: + // have a try run in case that the dnn model resize the frame + out = ff_get_video_buffer(inlink, inlink->w, inlink->h); + result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out); + if (result != DNN_SUCCESS){ + av_log(ctx, AV_LOG_ERROR, "failed to execute model\n"); return AVERROR(EIO); } - return 0; -} + outlink->w = out->width; + outlink->h = out->height; -static int copy_from_dnn_to_frame(DnnProcessingContext *ctx, AVFrame *frame) -{ - int bytewidth = av_image_get_linesize(frame->format, frame->width, 0); - DNNData *dnn_output = &ctx->output; - - switch (frame->format) { - case AV_PIX_FMT_RGB24: - case AV_PIX_FMT_BGR24: - if (dnn_output->dt == DNN_FLOAT) { - sws_scale(ctx->sws_grayf32_to_gray8, (const uint8_t *[4]){(const uint8_t *)dnn_output->data, 0, 0, 0}, - (const int[4]){frame->width * 3 * sizeof(float), 0, 0, 0}, - 0, frame->height, (uint8_t * const*)frame->data, frame->linesize); 
- - } else { - av_assert0(dnn_output->dt == DNN_UINT8); - av_image_copy_plane(frame->data[0], frame->linesize[0], - dnn_output->data, bytewidth, - bytewidth, frame->height); - } - return 0; - case AV_PIX_FMT_GRAY8: - // it is possible that data type of dnn output is float32, - // need to add support for such case when needed. - av_assert0(dnn_output->dt == DNN_UINT8); - av_image_copy_plane(frame->data[0], frame->linesize[0], - dnn_output->data, bytewidth, - bytewidth, frame->height); - return 0; - case AV_PIX_FMT_GRAYF32: - av_assert0(dnn_output->dt == DNN_FLOAT); - av_image_copy_plane(frame->data[0], frame->linesize[0], - dnn_output->data, bytewidth, - bytewidth, frame->height); - return 0; - case AV_PIX_FMT_YUV420P: - case AV_PIX_FMT_YUV422P: - case AV_PIX_FMT_YUV444P: - case AV_PIX_FMT_YUV410P: - case AV_PIX_FMT_YUV411P: - sws_scale(ctx->sws_grayf32_to_gray8, (const uint8_t *[4]){(const uint8_t *)dnn_output->data, 0, 0, 0}, - (const int[4]){frame->width * sizeof(float), 0, 0, 0}, - 0, frame->height, (uint8_t * const*)frame->data, frame->linesize); - return 0; - default: - return AVERROR(EIO); - } + av_frame_free(&fake_in); + av_frame_free(&out); + prepare_uv_scale(outlink); return 0; } -static av_always_inline int isPlanarYUV(enum AVPixelFormat pix_fmt) -{ - const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(pix_fmt); - av_assert0(desc); - return !(desc->flags & AV_PIX_FMT_FLAG_RGB) && desc->nb_components == 3; -} - static int copy_uv_planes(DnnProcessingContext *ctx, AVFrame *out, const AVFrame *in) { const AVPixFmtDescriptor *desc; @@ -453,11 +293,9 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in) DNNReturnType dnn_result; AVFrame *out; - copy_from_frame_to_dnn(ctx, in); - - dnn_result = (ctx->dnn_module->execute_model)(ctx->model, &ctx->output, (const char **)&ctx->model_outputname, 1); - if (dnn_result != DNN_SUCCESS){ - av_log(ctx, AV_LOG_ERROR, "failed to execute model\n"); + dnn_result = (ctx->model->set_input)(ctx->model->model, in, 
ctx->model_inputname); + if (dnn_result != DNN_SUCCESS) { + av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n"); av_frame_free(&in); return AVERROR(EIO); } @@ -467,9 +305,15 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in) av_frame_free(&in); return AVERROR(ENOMEM); } - av_frame_copy_props(out, in); - copy_from_dnn_to_frame(ctx, out); + + dnn_result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out); + if (dnn_result != DNN_SUCCESS){ + av_log(ctx, AV_LOG_ERROR, "failed to execute model\n"); + av_frame_free(&in); + av_frame_free(&out); + return AVERROR(EIO); + } if (isPlanarYUV(in->format)) copy_uv_planes(ctx, out, in); @@ -482,8 +326,6 @@ static av_cold void uninit(AVFilterContext *ctx) { DnnProcessingContext *context = ctx->priv; - sws_freeContext(context->sws_gray8_to_grayf32); - sws_freeContext(context->sws_grayf32_to_gray8); sws_freeContext(context->sws_uv_scale); if (context->dnn_module) |