summaryrefslogtreecommitdiff
path: root/libavfilter/vf_dnn_processing.c
diff options
context:
space:
mode:
authorGuo, Yejun <yejun.guo@intel.com>2020-08-28 12:51:44 +0800
committerGuo, Yejun <yejun.guo@intel.com>2020-09-21 21:26:56 +0800
commit2003e32f62d94ba75b59d70632c9f2862b383591 (patch)
tree55ec60788bc740eb45dbafd613bd8cf50a10417a /libavfilter/vf_dnn_processing.c
parent6918e240d706f7390272976d8b8d502afe426a18 (diff)
dnn: change dnn interface to replace DNNData* with AVFrame*
Currently, every filter needs to provide code to transfer data from AVFrame* to model input (DNNData*), and also from model output (DNNData*) to AVFrame*. Actually, such transfer can be implemented within DNN module, and so filter can focus on its own business logic. DNN module also exports the function pointer pre_proc and post_proc in struct DNNModel, just in case that a filter has its special logic to transfer data between AVFrame* and DNNData*. The default implementation within DNN module is used if the filter does not set pre/post_proc.
Diffstat (limited to 'libavfilter/vf_dnn_processing.c')
-rw-r--r--libavfilter/vf_dnn_processing.c240
1 files changed, 41 insertions, 199 deletions
diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c
index f120bf9df4..d7462bc828 100644
--- a/libavfilter/vf_dnn_processing.c
+++ b/libavfilter/vf_dnn_processing.c
@@ -46,12 +46,6 @@ typedef struct DnnProcessingContext {
DNNModule *dnn_module;
DNNModel *model;
- // input & output of the model at execution time
- DNNData input;
- DNNData output;
-
- struct SwsContext *sws_gray8_to_grayf32;
- struct SwsContext *sws_grayf32_to_gray8;
struct SwsContext *sws_uv_scale;
int sws_uv_height;
} DnnProcessingContext;
@@ -103,7 +97,7 @@ static av_cold int init(AVFilterContext *context)
return AVERROR(EINVAL);
}
- ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename, ctx->backend_options, NULL);
+ ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename, ctx->backend_options, ctx);
if (!ctx->model) {
av_log(ctx, AV_LOG_ERROR, "could not load DNN model\n");
return AVERROR(EINVAL);
@@ -148,6 +142,10 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin
model_input->width, inlink->w);
return AVERROR(EIO);
}
+ if (model_input->dt != DNN_FLOAT) {
+ av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32.\n");
+ return AVERROR(EIO);
+ }
switch (fmt) {
case AV_PIX_FMT_RGB24:
@@ -156,20 +154,6 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin
LOG_FORMAT_CHANNEL_MISMATCH();
return AVERROR(EIO);
}
- if (model_input->dt != DNN_FLOAT && model_input->dt != DNN_UINT8) {
- av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32 and uint8.\n");
- return AVERROR(EIO);
- }
- return 0;
- case AV_PIX_FMT_GRAY8:
- if (model_input->channels != 1) {
- LOG_FORMAT_CHANNEL_MISMATCH();
- return AVERROR(EIO);
- }
- if (model_input->dt != DNN_UINT8) {
- av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type uint8.\n");
- return AVERROR(EIO);
- }
return 0;
case AV_PIX_FMT_GRAYF32:
case AV_PIX_FMT_YUV420P:
@@ -181,10 +165,6 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin
LOG_FORMAT_CHANNEL_MISMATCH();
return AVERROR(EIO);
}
- if (model_input->dt != DNN_FLOAT) {
- av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type float32.\n");
- return AVERROR(EIO);
- }
return 0;
default:
av_log(ctx, AV_LOG_ERROR, "%s not supported.\n", av_get_pix_fmt_name(fmt));
@@ -213,74 +193,24 @@ static int config_input(AVFilterLink *inlink)
return check;
}
- ctx->input.width = inlink->w;
- ctx->input.height = inlink->h;
- ctx->input.channels = model_input.channels;
- ctx->input.dt = model_input.dt;
-
- result = (ctx->model->set_input)(ctx->model->model,
- &ctx->input, ctx->model_inputname);
- if (result != DNN_SUCCESS) {
- av_log(ctx, AV_LOG_ERROR, "could not set input and output for the model\n");
- return AVERROR(EIO);
- }
-
return 0;
}
-static int prepare_sws_context(AVFilterLink *outlink)
+static av_always_inline int isPlanarYUV(enum AVPixelFormat pix_fmt)
+{
+ const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(pix_fmt);
+ av_assert0(desc);
+ return !(desc->flags & AV_PIX_FMT_FLAG_RGB) && desc->nb_components == 3;
+}
+
+static int prepare_uv_scale(AVFilterLink *outlink)
{
AVFilterContext *context = outlink->src;
DnnProcessingContext *ctx = context->priv;
AVFilterLink *inlink = context->inputs[0];
enum AVPixelFormat fmt = inlink->format;
- DNNDataType input_dt = ctx->input.dt;
- DNNDataType output_dt = ctx->output.dt;
-
- switch (fmt) {
- case AV_PIX_FMT_RGB24:
- case AV_PIX_FMT_BGR24:
- if (input_dt == DNN_FLOAT) {
- ctx->sws_gray8_to_grayf32 = sws_getContext(inlink->w * 3,
- inlink->h,
- AV_PIX_FMT_GRAY8,
- inlink->w * 3,
- inlink->h,
- AV_PIX_FMT_GRAYF32,
- 0, NULL, NULL, NULL);
- }
- if (output_dt == DNN_FLOAT) {
- ctx->sws_grayf32_to_gray8 = sws_getContext(outlink->w * 3,
- outlink->h,
- AV_PIX_FMT_GRAYF32,
- outlink->w * 3,
- outlink->h,
- AV_PIX_FMT_GRAY8,
- 0, NULL, NULL, NULL);
- }
- return 0;
- case AV_PIX_FMT_YUV420P:
- case AV_PIX_FMT_YUV422P:
- case AV_PIX_FMT_YUV444P:
- case AV_PIX_FMT_YUV410P:
- case AV_PIX_FMT_YUV411P:
- av_assert0(input_dt == DNN_FLOAT);
- av_assert0(output_dt == DNN_FLOAT);
- ctx->sws_gray8_to_grayf32 = sws_getContext(inlink->w,
- inlink->h,
- AV_PIX_FMT_GRAY8,
- inlink->w,
- inlink->h,
- AV_PIX_FMT_GRAYF32,
- 0, NULL, NULL, NULL);
- ctx->sws_grayf32_to_gray8 = sws_getContext(outlink->w,
- outlink->h,
- AV_PIX_FMT_GRAYF32,
- outlink->w,
- outlink->h,
- AV_PIX_FMT_GRAY8,
- 0, NULL, NULL, NULL);
+ if (isPlanarYUV(fmt)) {
if (inlink->w != outlink->w || inlink->h != outlink->h) {
const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(fmt);
int sws_src_h = AV_CEIL_RSHIFT(inlink->h, desc->log2_chroma_h);
@@ -292,10 +222,6 @@ static int prepare_sws_context(AVFilterLink *outlink)
SWS_BICUBIC, NULL, NULL, NULL);
ctx->sws_uv_height = sws_src_h;
}
- return 0;
- default:
- //do nothing
- break;
}
return 0;
@@ -306,120 +232,34 @@ static int config_output(AVFilterLink *outlink)
AVFilterContext *context = outlink->src;
DnnProcessingContext *ctx = context->priv;
DNNReturnType result;
+ AVFilterLink *inlink = context->inputs[0];
+ AVFrame *out = NULL;
- // have a try run in case that the dnn model resize the frame
- result = (ctx->dnn_module->execute_model)(ctx->model, &ctx->output, (const char **)&ctx->model_outputname, 1);
- if (result != DNN_SUCCESS){
- av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
+ AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h);
+ result = (ctx->model->set_input)(ctx->model->model, fake_in, ctx->model_inputname);
+ if (result != DNN_SUCCESS) {
+ av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n");
return AVERROR(EIO);
}
- outlink->w = ctx->output.width;
- outlink->h = ctx->output.height;
-
- prepare_sws_context(outlink);
-
- return 0;
-}
-
-static int copy_from_frame_to_dnn(DnnProcessingContext *ctx, const AVFrame *frame)
-{
- int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
- DNNData *dnn_input = &ctx->input;
-
- switch (frame->format) {
- case AV_PIX_FMT_RGB24:
- case AV_PIX_FMT_BGR24:
- if (dnn_input->dt == DNN_FLOAT) {
- sws_scale(ctx->sws_gray8_to_grayf32, (const uint8_t **)frame->data, frame->linesize,
- 0, frame->height, (uint8_t * const*)(&dnn_input->data),
- (const int [4]){frame->width * 3 * sizeof(float), 0, 0, 0});
- } else {
- av_assert0(dnn_input->dt == DNN_UINT8);
- av_image_copy_plane(dnn_input->data, bytewidth,
- frame->data[0], frame->linesize[0],
- bytewidth, frame->height);
- }
- return 0;
- case AV_PIX_FMT_GRAY8:
- case AV_PIX_FMT_GRAYF32:
- av_image_copy_plane(dnn_input->data, bytewidth,
- frame->data[0], frame->linesize[0],
- bytewidth, frame->height);
- return 0;
- case AV_PIX_FMT_YUV420P:
- case AV_PIX_FMT_YUV422P:
- case AV_PIX_FMT_YUV444P:
- case AV_PIX_FMT_YUV410P:
- case AV_PIX_FMT_YUV411P:
- sws_scale(ctx->sws_gray8_to_grayf32, (const uint8_t **)frame->data, frame->linesize,
- 0, frame->height, (uint8_t * const*)(&dnn_input->data),
- (const int [4]){frame->width * sizeof(float), 0, 0, 0});
- return 0;
- default:
+ // have a try run in case that the dnn model resize the frame
+ out = ff_get_video_buffer(inlink, inlink->w, inlink->h);
+ result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out);
+ if (result != DNN_SUCCESS){
+ av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
return AVERROR(EIO);
}
- return 0;
-}
+ outlink->w = out->width;
+ outlink->h = out->height;
-static int copy_from_dnn_to_frame(DnnProcessingContext *ctx, AVFrame *frame)
-{
- int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
- DNNData *dnn_output = &ctx->output;
-
- switch (frame->format) {
- case AV_PIX_FMT_RGB24:
- case AV_PIX_FMT_BGR24:
- if (dnn_output->dt == DNN_FLOAT) {
- sws_scale(ctx->sws_grayf32_to_gray8, (const uint8_t *[4]){(const uint8_t *)dnn_output->data, 0, 0, 0},
- (const int[4]){frame->width * 3 * sizeof(float), 0, 0, 0},
- 0, frame->height, (uint8_t * const*)frame->data, frame->linesize);
-
- } else {
- av_assert0(dnn_output->dt == DNN_UINT8);
- av_image_copy_plane(frame->data[0], frame->linesize[0],
- dnn_output->data, bytewidth,
- bytewidth, frame->height);
- }
- return 0;
- case AV_PIX_FMT_GRAY8:
- // it is possible that data type of dnn output is float32,
- // need to add support for such case when needed.
- av_assert0(dnn_output->dt == DNN_UINT8);
- av_image_copy_plane(frame->data[0], frame->linesize[0],
- dnn_output->data, bytewidth,
- bytewidth, frame->height);
- return 0;
- case AV_PIX_FMT_GRAYF32:
- av_assert0(dnn_output->dt == DNN_FLOAT);
- av_image_copy_plane(frame->data[0], frame->linesize[0],
- dnn_output->data, bytewidth,
- bytewidth, frame->height);
- return 0;
- case AV_PIX_FMT_YUV420P:
- case AV_PIX_FMT_YUV422P:
- case AV_PIX_FMT_YUV444P:
- case AV_PIX_FMT_YUV410P:
- case AV_PIX_FMT_YUV411P:
- sws_scale(ctx->sws_grayf32_to_gray8, (const uint8_t *[4]){(const uint8_t *)dnn_output->data, 0, 0, 0},
- (const int[4]){frame->width * sizeof(float), 0, 0, 0},
- 0, frame->height, (uint8_t * const*)frame->data, frame->linesize);
- return 0;
- default:
- return AVERROR(EIO);
- }
+ av_frame_free(&fake_in);
+ av_frame_free(&out);
+ prepare_uv_scale(outlink);
return 0;
}
-static av_always_inline int isPlanarYUV(enum AVPixelFormat pix_fmt)
-{
- const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(pix_fmt);
- av_assert0(desc);
- return !(desc->flags & AV_PIX_FMT_FLAG_RGB) && desc->nb_components == 3;
-}
-
static int copy_uv_planes(DnnProcessingContext *ctx, AVFrame *out, const AVFrame *in)
{
const AVPixFmtDescriptor *desc;
@@ -453,11 +293,9 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
DNNReturnType dnn_result;
AVFrame *out;
- copy_from_frame_to_dnn(ctx, in);
-
- dnn_result = (ctx->dnn_module->execute_model)(ctx->model, &ctx->output, (const char **)&ctx->model_outputname, 1);
- if (dnn_result != DNN_SUCCESS){
- av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
+ dnn_result = (ctx->model->set_input)(ctx->model->model, in, ctx->model_inputname);
+ if (dnn_result != DNN_SUCCESS) {
+ av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n");
av_frame_free(&in);
return AVERROR(EIO);
}
@@ -467,9 +305,15 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
av_frame_free(&in);
return AVERROR(ENOMEM);
}
-
av_frame_copy_props(out, in);
- copy_from_dnn_to_frame(ctx, out);
+
+ dnn_result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out);
+ if (dnn_result != DNN_SUCCESS){
+ av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
+ av_frame_free(&in);
+ av_frame_free(&out);
+ return AVERROR(EIO);
+ }
if (isPlanarYUV(in->format))
copy_uv_planes(ctx, out, in);
@@ -482,8 +326,6 @@ static av_cold void uninit(AVFilterContext *ctx)
{
DnnProcessingContext *context = ctx->priv;
- sws_freeContext(context->sws_gray8_to_grayf32);
- sws_freeContext(context->sws_grayf32_to_gray8);
sws_freeContext(context->sws_uv_scale);
if (context->dnn_module)