summary refs log tree commit diff
path: root/libavfilter/dnn/dnn_backend_native.c
diff options
context:
space:
mode:
Diffstat (limited to 'libavfilter/dnn/dnn_backend_native.c')
-rw-r--r-- libavfilter/dnn/dnn_backend_native.c 59
1 files changed, 56 insertions, 3 deletions
diff --git a/libavfilter/dnn/dnn_backend_native.c b/libavfilter/dnn/dnn_backend_native.c
index 3b2a3aa55d..2d34b88f8a 100644
--- a/libavfilter/dnn/dnn_backend_native.c
+++ b/libavfilter/dnn/dnn_backend_native.c
@@ -34,6 +34,7 @@
#define FLAGS AV_OPT_FLAG_FILTERING_PARAM
static const AVOption dnn_native_options[] = {
{ "conv2d_threads", "threads num for conv2d layer", OFFSET(options.conv2d_threads), AV_OPT_TYPE_INT, { .i64 = 0 }, INT_MIN, INT_MAX, FLAGS },
+ { "async", "use DNN async inference", OFFSET(options.async), AV_OPT_TYPE_BOOL, { .i64 = 0 }, 0, 1, FLAGS }, // accepted, but ff_dnn_load_model_native warns and resets it to 0 (async not supported yet)
{ NULL },
};
@@ -189,6 +190,11 @@ DNNModel *ff_dnn_load_model_native(const char *model_filename, DNNFunctionType f
goto fail;
native_model->model = model;
+ if (native_model->ctx.options.async) {
+ av_log(&native_model->ctx, AV_LOG_WARNING, "Async not supported. Rolling back to sync\n");
+ native_model->ctx.options.async = 0;
+ }
+
#if !HAVE_PTHREAD_CANCEL
if (native_model->ctx.options.conv2d_threads > 1){
av_log(&native_model->ctx, AV_LOG_WARNING, "'conv2d_threads' option was set but it is not supported "
@@ -212,6 +218,11 @@ DNNModel *ff_dnn_load_model_native(const char *model_filename, DNNFunctionType f
goto fail;
}
+ native_model->task_queue = ff_queue_create();
+ if (!native_model->task_queue) {
+ goto fail;
+ }
+
native_model->inference_queue = ff_queue_create();
if (!native_model->inference_queue) {
goto fail;
@@ -425,17 +436,30 @@ DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, DNNExecBasePara
{
NativeModel *native_model = model->model;
NativeContext *ctx = &native_model->ctx;
- TaskItem task;
+ TaskItem *task;
if (ff_check_exec_params(ctx, DNN_NATIVE, model->func_type, exec_params) != 0) {
return DNN_ERROR;
}
- if (ff_dnn_fill_task(&task, exec_params, native_model, 0, 1) != DNN_SUCCESS) {
+ task = av_malloc(sizeof(*task));
+ if (!task) {
+ av_log(ctx, AV_LOG_ERROR, "unable to alloc memory for task item.\n");
return DNN_ERROR;
}
- if (extract_inference_from_task(&task, native_model->inference_queue) != DNN_SUCCESS) {
+ if (ff_dnn_fill_task(task, exec_params, native_model, ctx->options.async, 1) != DNN_SUCCESS) {
+ av_freep(&task);
+ return DNN_ERROR;
+ }
+
+ if (ff_queue_push_back(native_model->task_queue, task) < 0) {
+ av_freep(&task);
+ av_log(ctx, AV_LOG_ERROR, "unable to push back task_queue.\n");
+ return DNN_ERROR;
+ }
+
+ if (extract_inference_from_task(task, native_model->inference_queue) != DNN_SUCCESS) {
av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n");
return DNN_ERROR;
}
@@ -443,6 +467,26 @@ DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, DNNExecBasePara
return execute_model_native(native_model->inference_queue);
}
+DNNReturnType ff_dnn_flush_native(const DNNModel *model) // runs whatever is still pending in the model's inference_queue
+{
+ NativeModel *native_model = model->model;
+
+ if (ff_queue_size(native_model->inference_queue) == 0) {
+ // nothing is queued, so there is nothing to flush
+ return DNN_SUCCESS;
+ }
+
+ // For now the flush runs in sync mode: the queue is executed right here
+ // and that call's status is returned. Switch to async once it is supported.
+ return execute_model_native(native_model->inference_queue);
+}
+
+DNNAsyncStatusType ff_dnn_get_result_native(const DNNModel *model, AVFrame **in, AVFrame **out)
+{ // thin wrapper: hands this model's task_queue plus the caller's in/out pointers to ff_dnn_get_result_common
+ NativeModel *native_model = model->model;
+ return ff_dnn_get_result_common(native_model->task_queue, in, out);
+}
+
int32_t ff_calculate_operand_dims_count(const DnnOperand *oprd)
{
int32_t result = 1;
@@ -497,6 +541,15 @@ void ff_dnn_free_model_native(DNNModel **model)
av_freep(&item);
}
ff_queue_destroy(native_model->inference_queue);
+
+ while (ff_queue_size(native_model->task_queue) != 0) {
+ TaskItem *item = ff_queue_pop_front(native_model->task_queue);
+ av_frame_free(&item->in_frame);
+ av_frame_free(&item->out_frame);
+ av_freep(&item);
+ }
+ ff_queue_destroy(native_model->task_queue);
+
av_freep(&native_model);
}
av_freep(model);