lavfi/dnn: add post process for detection

This commit is contained in:
Guo, Yejun 2021-03-09 14:51:42 +08:00
parent 59021d79a2
commit 13bf797ced
4 changed files with 33 additions and 7 deletions

View File

@ -236,16 +236,32 @@ static void infer_completion_callback(void *args)
av_assert0(request->task_count >= 1);
for (int i = 0; i < request->task_count; ++i) {
task = request->tasks[i];
if (task->do_ioproc) {
if (task->ov_model->model->frame_post_proc != NULL) {
task->ov_model->model->frame_post_proc(task->out_frame, &output, task->ov_model->model->filter_ctx);
switch (task->ov_model->model->func_type) {
case DFT_PROCESS_FRAME:
if (task->do_ioproc) {
if (task->ov_model->model->frame_post_proc != NULL) {
task->ov_model->model->frame_post_proc(task->out_frame, &output, task->ov_model->model->filter_ctx);
} else {
ff_proc_from_dnn_to_frame(task->out_frame, &output, ctx);
}
} else {
ff_proc_from_dnn_to_frame(task->out_frame, &output, ctx);
task->out_frame->width = output.width;
task->out_frame->height = output.height;
}
} else {
task->out_frame->width = output.width;
task->out_frame->height = output.height;
break;
case DFT_ANALYTICS_DETECT:
if (!task->ov_model->model->detect_post_proc) {
av_log(ctx, AV_LOG_ERROR, "detect filter needs to provide post proc\n");
return;
}
task->ov_model->model->detect_post_proc(task->out_frame, &output, 1, task->ov_model->model->filter_ctx);
break;
default:
av_assert0(!"should not reach here");
break;
}
task->done = 1;
output.data = (uint8_t *)output.data
+ output.width * output.height * output.channels * get_datatype_size(output.dt);

View File

@ -71,6 +71,12 @@ int ff_dnn_set_frame_proc(DnnContext *ctx, FramePrePostProc pre_proc, FramePrePo
return 0;
}
int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc)
{
ctx->model->detect_post_proc = post_proc;
return 0;
}
DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input)
{
return ctx->model->get_input(ctx->model->model, input, ctx->model_inputname);

View File

@ -49,6 +49,7 @@ typedef struct DnnContext {
int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx);
int ff_dnn_set_frame_proc(DnnContext *ctx, FramePrePostProc pre_proc, FramePrePostProc post_proc);
int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc);
DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input);
DNNReturnType ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height);
DNNReturnType ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame);

View File

@ -64,6 +64,7 @@ typedef struct DNNData{
} DNNData;
typedef int (*FramePrePostProc)(AVFrame *frame, DNNData *model, AVFilterContext *filter_ctx);
typedef int (*DetectPostProc)(AVFrame *frame, DNNData *output, uint32_t nb, AVFilterContext *filter_ctx);
typedef struct DNNModel{
// Stores model that can be different for different backends.
@ -86,6 +87,8 @@ typedef struct DNNModel{
// set the post process to transfer data from DNNData to AVFrame
// the default implementation within DNN is used if it is not provided by the filter
FramePrePostProc frame_post_proc;
// set the post process to interpret detect result from DNNData
DetectPostProc detect_post_proc;
} DNNModel;
// Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.