summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/filters.texi6
-rw-r--r--libavfilter/vf_dnn_processing.c168
2 files changed, 132 insertions, 42 deletions
diff --git a/doc/filters.texi b/doc/filters.texi
index 71b69f8eb6..63376cf1a2 100644
--- a/doc/filters.texi
+++ b/doc/filters.texi
@@ -9115,6 +9115,12 @@ Halve the red channle of the frame with format rgb24:
ffmpeg -i input.jpg -vf format=rgb24,dnn_processing=model=halve_first_channel.model:input=dnn_in:output=dnn_out:dnn_backend=native out.native.png
@end example
+@item
+Halve the pixel value of the frame with format grayf32:
+@example
+ffmpeg -i input.jpg -vf format=grayf32,dnn_processing=model=halve_gray_float.model:input=dnn_in:output=dnn_out:dnn_backend=native -y out.native.png
+@end example
+
@end itemize
@section drawbox
diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c
index 4a6b900d94..13273f2f86 100644
--- a/libavfilter/vf_dnn_processing.c
+++ b/libavfilter/vf_dnn_processing.c
@@ -104,12 +104,20 @@ static int query_formats(AVFilterContext *context)
{
static const enum AVPixelFormat pix_fmts[] = {
AV_PIX_FMT_RGB24, AV_PIX_FMT_BGR24,
+ AV_PIX_FMT_GRAY8, AV_PIX_FMT_GRAYF32,
AV_PIX_FMT_NONE
};
AVFilterFormats *fmts_list = ff_make_format_list(pix_fmts);
return ff_set_common_formats(context, fmts_list);
}
+#define LOG_FORMAT_CHANNEL_MISMATCH() \
+ av_log(ctx, AV_LOG_ERROR, \
+ "the frame's format %s does not match " \
+ "the model input channel %d\n", \
+ av_get_pix_fmt_name(fmt), \
+ model_input->channels);
+
static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLink *inlink)
{
AVFilterContext *ctx = inlink->dst;
@@ -131,17 +139,34 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24:
if (model_input->channels != 3) {
- av_log(ctx, AV_LOG_ERROR, "the frame's input format %s does not match "
- "the model input channels %d\n",
- av_get_pix_fmt_name(fmt),
- model_input->channels);
+ LOG_FORMAT_CHANNEL_MISMATCH();
return AVERROR(EIO);
}
if (model_input->dt != DNN_FLOAT && model_input->dt != DNN_UINT8) {
av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32 and uint8.\n");
return AVERROR(EIO);
}
- break;
+ return 0;
+ case AV_PIX_FMT_GRAY8:
+ if (model_input->channels != 1) {
+ LOG_FORMAT_CHANNEL_MISMATCH();
+ return AVERROR(EIO);
+ }
+ if (model_input->dt != DNN_UINT8) {
+ av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type uint8.\n");
+ return AVERROR(EIO);
+ }
+ return 0;
+ case AV_PIX_FMT_GRAYF32:
+ if (model_input->channels != 1) {
+ LOG_FORMAT_CHANNEL_MISMATCH();
+ return AVERROR(EIO);
+ }
+ if (model_input->dt != DNN_FLOAT) {
+ av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type float32.\n");
+ return AVERROR(EIO);
+ }
+ return 0;
default:
av_log(ctx, AV_LOG_ERROR, "%s not supported.\n", av_get_pix_fmt_name(fmt));
return AVERROR(EIO);
@@ -206,28 +231,58 @@ static int config_output(AVFilterLink *outlink)
static int copy_from_frame_to_dnn(DNNData *dnn_input, const AVFrame *frame)
{
- // extend this function to support more formats
- av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24);
-
- if (dnn_input->dt == DNN_FLOAT) {
- float *dnn_input_data = dnn_input->data;
- for (int i = 0; i < frame->height; i++) {
- for(int j = 0; j < frame->width * 3; j++) {
- int k = i * frame->linesize[0] + j;
- int t = i * frame->width * 3 + j;
- dnn_input_data[t] = frame->data[0][k] / 255.0f;
+ switch (frame->format) {
+ case AV_PIX_FMT_RGB24:
+ case AV_PIX_FMT_BGR24:
+ if (dnn_input->dt == DNN_FLOAT) {
+ float *dnn_input_data = dnn_input->data;
+ for (int i = 0; i < frame->height; i++) {
+ for(int j = 0; j < frame->width * 3; j++) {
+ int k = i * frame->linesize[0] + j;
+ int t = i * frame->width * 3 + j;
+ dnn_input_data[t] = frame->data[0][k] / 255.0f;
+ }
+ }
+ } else {
+ uint8_t *dnn_input_data = dnn_input->data;
+ av_assert0(dnn_input->dt == DNN_UINT8);
+ for (int i = 0; i < frame->height; i++) {
+ for(int j = 0; j < frame->width * 3; j++) {
+ int k = i * frame->linesize[0] + j;
+ int t = i * frame->width * 3 + j;
+ dnn_input_data[t] = frame->data[0][k];
+ }
}
}
- } else {
- uint8_t *dnn_input_data = dnn_input->data;
- av_assert0(dnn_input->dt == DNN_UINT8);
- for (int i = 0; i < frame->height; i++) {
- for(int j = 0; j < frame->width * 3; j++) {
- int k = i * frame->linesize[0] + j;
- int t = i * frame->width * 3 + j;
- dnn_input_data[t] = frame->data[0][k];
+ return 0;
+ case AV_PIX_FMT_GRAY8:
+ {
+ uint8_t *dnn_input_data = dnn_input->data;
+ av_assert0(dnn_input->dt == DNN_UINT8);
+ for (int i = 0; i < frame->height; i++) {
+ for(int j = 0; j < frame->width; j++) {
+ int k = i * frame->linesize[0] + j;
+ int t = i * frame->width + j;
+ dnn_input_data[t] = frame->data[0][k];
+ }
}
}
+ return 0;
+ case AV_PIX_FMT_GRAYF32:
+ {
+ float *dnn_input_data = dnn_input->data;
+ av_assert0(dnn_input->dt == DNN_FLOAT);
+ for (int i = 0; i < frame->height; i++) {
+ for(int j = 0; j < frame->width; j++) {
+ int k = i * frame->linesize[0] + j * sizeof(float);
+ int t = i * frame->width + j;
+ dnn_input_data[t] = *(float*)(frame->data[0] + k);
+ }
+ }
+ }
+ return 0;
+ default:
+ return AVERROR(EIO);
}
return 0;
@@ -235,28 +290,58 @@ static int copy_from_frame_to_dnn(DNNData *dnn_input, const AVFrame *frame)
static int copy_from_dnn_to_frame(AVFrame *frame, const DNNData *dnn_output)
{
- // extend this function to support more formats
- av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24);
-
- if (dnn_output->dt == DNN_FLOAT) {
- float *dnn_output_data = dnn_output->data;
- for (int i = 0; i < frame->height; i++) {
- for(int j = 0; j < frame->width * 3; j++) {
- int k = i * frame->linesize[0] + j;
- int t = i * frame->width * 3 + j;
- frame->data[0][k] = av_clip_uintp2((int)(dnn_output_data[t] * 255.0f), 8);
+ switch (frame->format) {
+ case AV_PIX_FMT_RGB24:
+ case AV_PIX_FMT_BGR24:
+ if (dnn_output->dt == DNN_FLOAT) {
+ float *dnn_output_data = dnn_output->data;
+ for (int i = 0; i < frame->height; i++) {
+ for(int j = 0; j < frame->width * 3; j++) {
+ int k = i * frame->linesize[0] + j;
+ int t = i * frame->width * 3 + j;
+ frame->data[0][k] = av_clip_uintp2((int)(dnn_output_data[t] * 255.0f), 8);
+ }
+ }
+ } else {
+ uint8_t *dnn_output_data = dnn_output->data;
+ av_assert0(dnn_output->dt == DNN_UINT8);
+ for (int i = 0; i < frame->height; i++) {
+ for(int j = 0; j < frame->width * 3; j++) {
+ int k = i * frame->linesize[0] + j;
+ int t = i * frame->width * 3 + j;
+ frame->data[0][k] = dnn_output_data[t];
+ }
+ }
+ }
+ return 0;
+ case AV_PIX_FMT_GRAY8:
+ {
+ uint8_t *dnn_output_data = dnn_output->data;
+ av_assert0(dnn_output->dt == DNN_UINT8);
+ for (int i = 0; i < frame->height; i++) {
+ for(int j = 0; j < frame->width; j++) {
+ int k = i * frame->linesize[0] + j;
+ int t = i * frame->width + j;
+ frame->data[0][k] = dnn_output_data[t];
+ }
}
}
- } else {
- uint8_t *dnn_output_data = dnn_output->data;
- av_assert0(dnn_output->dt == DNN_UINT8);
- for (int i = 0; i < frame->height; i++) {
- for(int j = 0; j < frame->width * 3; j++) {
- int k = i * frame->linesize[0] + j;
- int t = i * frame->width * 3 + j;
- frame->data[0][k] = dnn_output_data[t];
+ return 0;
+ case AV_PIX_FMT_GRAYF32:
+ {
+ float *dnn_output_data = dnn_output->data;
+ av_assert0(dnn_output->dt == DNN_FLOAT);
+ for (int i = 0; i < frame->height; i++) {
+ for(int j = 0; j < frame->width; j++) {
+ int k = i * frame->linesize[0] + j * sizeof(float);
+ int t = i * frame->width + j;
+ *(float*)(frame->data[0] + k) = dnn_output_data[t];
+ }
}
}
+ return 0;
+ default:
+ return AVERROR(EIO);
}
return 0;
@@ -278,7 +363,6 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
av_frame_free(&in);
return AVERROR(EIO);
}
- av_assert0(ctx->output.channels == 3);
out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
if (!out) {