From a7532fc3f9f8f008cd6eecd6f2ffe2a9698e0cf7 Mon Sep 17 00:00:00 2001 From: nihui Date: Sun, 23 Oct 2022 17:06:10 +0800 Subject: [PATCH] implement cpu temporal tta for rife v4 --- src/rife.cpp | 1004 +++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 820 insertions(+), 184 deletions(-) diff --git a/src/rife.cpp b/src/rife.cpp index a35a380..66ce4c5 100644 --- a/src/rife.cpp +++ b/src/rife.cpp @@ -3263,7 +3263,7 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f int j = 0; for (; j < w; j++) { - *outptr++ = *ptr++ * (1 / 255.f) - 0.5f; + *outptr++ = *ptr++ * (1 / 255.f); } for (; j < w_padded; j++) { @@ -3293,7 +3293,7 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f int j = 0; for (; j < w; j++) { - *outptr++ = *ptr++ * (1 / 255.f) - 0.5f; + *outptr++ = *ptr++ * (1 / 255.f); } for (; j < w_padded; j++) { @@ -3310,7 +3310,7 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f } } { - timestep_padded[0].create(h_padded, w_padded, 1); + timestep_padded[0].create(w_padded, h_padded, 1); timestep_padded[1].create(h_padded, w_padded, 1); timestep_padded[0].fill(timestep); timestep_padded[1].fill(timestep); @@ -3412,210 +3412,702 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f } } - ncnn::Mat flow[4][8]; - for (int fi = 0; fi < 4; fi++) + ncnn::Mat out_padded[8]; + ncnn::Mat out_padded_reversed[8]; + if (tta_temporal_mode) { - for (int ti = 0; ti < 8; ti++) + ncnn::Mat timestep_padded_reversed[2]; + timestep_padded_reversed[0].create(w_padded, h_padded, 1); + timestep_padded_reversed[1].create(h_padded, w_padded, 1); + timestep_padded_reversed[0].fill(1.f - timestep); + timestep_padded_reversed[1].fill(1.f - timestep); + + ncnn::Mat flow[4][8]; + ncnn::Mat flow_reversed[4][8]; + for (int fi = 0; fi < 4; fi++) { - // flownet flow mask - ncnn::Extractor ex = flownet.create_extractor(); - - ex.input("in0", in0_padded[ti]); - ex.input("in1", in1_padded[ti]); - ex.input("in2", timestep_padded[ti / 4]); - - // intentional fall through - switch (fi) + for (int ti = 0; ti < 8; ti++) { - case 3: ex.input("flow2", flow[2][ti]); - case 2: ex.input("flow1", flow[1][ti]); - case 1: ex.input("flow0", flow[0][ti]); - default: - { - char tmp[16]; - sprintf(tmp, "flow%d", fi); - ex.extract(tmp, flow[fi][ti]); - } - } - } - - // avg flow mask - { - ncnn::Mat flow_x0 = flow[fi][0].channel(0); - ncnn::Mat flow_x1 = flow[fi][1].channel(0); - ncnn::Mat flow_x2 = flow[fi][2].channel(0); - ncnn::Mat flow_x3 = flow[fi][3].channel(0); - ncnn::Mat flow_x4 = flow[fi][4].channel(0); - ncnn::Mat flow_x5 = flow[fi][5].channel(0); - ncnn::Mat flow_x6 = flow[fi][6].channel(0); - ncnn::Mat flow_x7 = flow[fi][7].channel(0); - - ncnn::Mat flow_y0 = flow[fi][0].channel(1); - ncnn::Mat flow_y1 = flow[fi][1].channel(1); - ncnn::Mat flow_y2 = flow[fi][2].channel(1); - ncnn::Mat flow_y3 = flow[fi][3].channel(1); - ncnn::Mat flow_y4 = flow[fi][4].channel(1); - ncnn::Mat flow_y5 = flow[fi][5].channel(1); - ncnn::Mat flow_y6 = flow[fi][6].channel(1); - ncnn::Mat flow_y7 = flow[fi][7].channel(1); - - ncnn::Mat flow_z0 = flow[fi][0].channel(2); - ncnn::Mat flow_z1 = flow[fi][1].channel(2); - ncnn::Mat flow_z2 = flow[fi][2].channel(2); - ncnn::Mat flow_z3 = flow[fi][3].channel(2); - ncnn::Mat flow_z4 = flow[fi][4].channel(2); - ncnn::Mat flow_z5 = flow[fi][5].channel(2); - ncnn::Mat flow_z6 = flow[fi][6].channel(2); - ncnn::Mat flow_z7 = flow[fi][7].channel(2); - - ncnn::Mat flow_w0 = flow[fi][0].channel(3); - ncnn::Mat flow_w1 = flow[fi][1].channel(3); - ncnn::Mat flow_w2 = flow[fi][2].channel(3); - ncnn::Mat flow_w3 = flow[fi][3].channel(3); - ncnn::Mat flow_w4 = flow[fi][4].channel(3); - ncnn::Mat flow_w5 = flow[fi][5].channel(3); - ncnn::Mat flow_w6 = flow[fi][6].channel(3); - ncnn::Mat flow_w7 = flow[fi][7].channel(3); - - ncnn::Mat flow_m0 = flow[fi][0].channel(4); - ncnn::Mat flow_m1 = flow[fi][1].channel(4); - ncnn::Mat flow_m2 = flow[fi][2].channel(4); - ncnn::Mat flow_m3 = flow[fi][3].channel(4); - ncnn::Mat flow_m4 = flow[fi][4].channel(4); - ncnn::Mat flow_m5 = flow[fi][5].channel(4); - ncnn::Mat flow_m6 = flow[fi][6].channel(4); - ncnn::Mat flow_m7 = flow[fi][7].channel(4); - - for (int i = 0; i < flow_x0.h; i++) - { - float* x0 = flow_x0.row(i); - float* x1 = flow_x1.row(i) + flow_x0.w - 1; - float* x2 = flow_x2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; - float* x3 = flow_x3.row(flow_x0.h - 1 - i); - - float* y0 = flow_y0.row(i); - float* y1 = flow_y1.row(i) + flow_x0.w - 1; - float* y2 = flow_y2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; - float* y3 = flow_y3.row(flow_x0.h - 1 - i); - - float* z0 = flow_z0.row(i); - float* z1 = flow_z1.row(i) + flow_x0.w - 1; - float* z2 = flow_z2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; - float* z3 = flow_z3.row(flow_x0.h - 1 - i); - - float* w0 = flow_w0.row(i); - float* w1 = flow_w1.row(i) + flow_x0.w - 1; - float* w2 = flow_w2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; - float* w3 = flow_w3.row(flow_x0.h - 1 - i); - - float* m0 = flow_m0.row(i); - float* m1 = flow_m1.row(i) + flow_x0.w - 1; - float* m2 = flow_m2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; - float* m3 = flow_m3.row(flow_x0.h - 1 - i); - - for (int j = 0; j < flow_x0.w; j++) + // flownet flow mask { - float* x4 = flow_x4.row(j) + i; - float* x5 = flow_x5.row(j) + flow_x0.h - 1 - i; - float* x6 = flow_x6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; - float* x7 = flow_x7.row(flow_x0.w - 1 - j) + i; + ncnn::Extractor ex = flownet.create_extractor(); - float* y4 = flow_y4.row(j) + i; - float* y5 = flow_y5.row(j) + flow_x0.h - 1 - i; - float* y6 = flow_y6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; - float* y7 = flow_y7.row(flow_x0.w - 1 - j) + i; + ex.input("in0", in0_padded[ti]); + ex.input("in1", in1_padded[ti]); + ex.input("in2", timestep_padded[ti / 4]); - float* z4 = flow_z4.row(j) + i; - float* z5 = flow_z5.row(j) + flow_x0.h - 1 - i; - float* z6 = flow_z6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; - float* z7 = flow_z7.row(flow_x0.w - 1 - j) + i; + // intentional fall through + switch (fi) + { + case 3: ex.input("flow2", flow[2][ti]); + case 2: ex.input("flow1", flow[1][ti]); + case 1: ex.input("flow0", flow[0][ti]); + default: + { + char tmp[16]; + sprintf(tmp, "flow%d", fi); + ex.extract(tmp, flow[fi][ti]); + } + } + } - float* w4 = flow_w4.row(j) + i; - float* w5 = flow_w5.row(j) + flow_x0.h - 1 - i; - float* w6 = flow_w6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; - float* w7 = flow_w7.row(flow_x0.w - 1 - j) + i; + // flownet flow mask reversed + { + ncnn::Extractor ex = flownet.create_extractor(); - float* m4 = flow_m4.row(j) + i; - float* m5 = flow_m5.row(j) + flow_x0.h - 1 - i; - float* m6 = flow_m6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; - float* m7 = flow_m7.row(flow_x0.w - 1 - j) + i; + ex.input("in0", in1_padded[ti]); + ex.input("in1", in0_padded[ti]); + ex.input("in2", timestep_padded_reversed[ti / 4]); - float x = (*x0 + -*x1 + -*x2 + *x3 + *y4 + *y5 + -*y6 + -*y7) * 0.125f; - float y = (*y0 + *y1 + -*y2 + -*y3 + *x4 + -*x5 + -*x6 + *x7) * 0.125f; - float z = (*z0 + -*z1 + -*z2 + *z3 + *w4 + *w5 + -*w6 + -*w7) * 0.125f; - float w = (*w0 + *w1 + -*w2 + -*w3 + *z4 + -*z5 + -*z6 + *z7) * 0.125f; - float m = (*m0 + *m1 + *m2 + *m3 + *m4 + *m5 + *m6 + *m7) * 0.125f; + // intentional fall through + switch (fi) + { + case 3: ex.input("flow2", flow_reversed[2][ti]); + case 2: ex.input("flow1", flow_reversed[1][ti]); + case 1: ex.input("flow0", flow_reversed[0][ti]); + default: + { + char tmp[16]; + sprintf(tmp, "flow%d", fi); + ex.extract(tmp, flow_reversed[fi][ti]); + } + } + } - *x0++ = x; - *x1-- = -x; - *x2-- = -x; - *x3++ = x; - *x4 = y; - *x5 = -y; - *x6 = -y; - *x7 = y; + // merge flow and flow_reversed + { + float* flow_x = flow[fi][ti].channel(0); + float* flow_y = flow[fi][ti].channel(1); + float* flow_z = flow[fi][ti].channel(2); + float* flow_w = flow[fi][ti].channel(3); + float* flow_m = flow[fi][ti].channel(4); + float* flow_reversed_x = flow_reversed[fi][ti].channel(0); + float* flow_reversed_y = flow_reversed[fi][ti].channel(1); + float* flow_reversed_z = flow_reversed[fi][ti].channel(2); + float* flow_reversed_w = flow_reversed[fi][ti].channel(3); + float* flow_reversed_m = flow_reversed[fi][ti].channel(4); - *y0++ = y; - *y1-- = y; - *y2-- = -y; - *y3++ = -y; - *y4 = x; - *y5 = x; - *y6 = -x; - *y7 = -x; + for (int i = 0; i < flow[fi][ti].h; i++) + { + for (int j = 0; j < flow[fi][ti].w; j++) + { + float x = (*flow_x + *flow_reversed_z) * 0.5f; + float y = (*flow_y + *flow_reversed_w) * 0.5f; + float z = (*flow_z + *flow_reversed_x) * 0.5f; + float w = (*flow_w + *flow_reversed_y) * 0.5f; + float m = (*flow_m - *flow_reversed_m) * 0.5f; - *z0++ = z; - *z1-- = -z; - *z2-- = -z; - *z3++ = z; - *z4 = w; - *z5 = -w; - *z6 = -w; - *z7 = w; + *flow_x++ = x; + *flow_y++ = y; + *flow_z++ = z; + *flow_w++ = w; + *flow_m++ = m; + *flow_reversed_x++ = z; + *flow_reversed_y++ = w; + *flow_reversed_z++ = x; + *flow_reversed_w++ = y; + *flow_reversed_m++ = -m; + } + } + } + } - *w0++ = w; - *w1-- = w; - *w2-- = -w; - *w3++ = -w; - *w4 = z; - *w5 = z; - *w6 = -z; - *w7 = -z; + // avg flow mask + { + ncnn::Mat flow_x0 = flow[fi][0].channel(0); + ncnn::Mat flow_x1 = flow[fi][1].channel(0); + ncnn::Mat flow_x2 = flow[fi][2].channel(0); + ncnn::Mat flow_x3 = flow[fi][3].channel(0); + ncnn::Mat flow_x4 = flow[fi][4].channel(0); + ncnn::Mat flow_x5 = flow[fi][5].channel(0); + ncnn::Mat flow_x6 = flow[fi][6].channel(0); + ncnn::Mat flow_x7 = flow[fi][7].channel(0); - *m0++ = m; - *m1-- = m; - *m2-- = m; - *m3++ = m; - *m4 = m; - *m5 = m; - *m6 = m; - *m7 = m; + ncnn::Mat flow_y0 = flow[fi][0].channel(1); + ncnn::Mat flow_y1 = flow[fi][1].channel(1); + ncnn::Mat flow_y2 = flow[fi][2].channel(1); + ncnn::Mat flow_y3 = flow[fi][3].channel(1); + ncnn::Mat flow_y4 = flow[fi][4].channel(1); + ncnn::Mat flow_y5 = flow[fi][5].channel(1); + ncnn::Mat flow_y6 = flow[fi][6].channel(1); + ncnn::Mat flow_y7 = flow[fi][7].channel(1); + + ncnn::Mat flow_z0 = flow[fi][0].channel(2); + ncnn::Mat flow_z1 = flow[fi][1].channel(2); + ncnn::Mat flow_z2 = flow[fi][2].channel(2); + ncnn::Mat flow_z3 = flow[fi][3].channel(2); + ncnn::Mat flow_z4 = flow[fi][4].channel(2); + ncnn::Mat flow_z5 = flow[fi][5].channel(2); + ncnn::Mat flow_z6 = flow[fi][6].channel(2); + ncnn::Mat flow_z7 = flow[fi][7].channel(2); + + ncnn::Mat flow_w0 = flow[fi][0].channel(3); + ncnn::Mat flow_w1 = flow[fi][1].channel(3); + ncnn::Mat flow_w2 = flow[fi][2].channel(3); + ncnn::Mat flow_w3 = flow[fi][3].channel(3); + ncnn::Mat flow_w4 = flow[fi][4].channel(3); + ncnn::Mat flow_w5 = flow[fi][5].channel(3); + ncnn::Mat flow_w6 = flow[fi][6].channel(3); + ncnn::Mat flow_w7 = flow[fi][7].channel(3); + + ncnn::Mat flow_m0 = flow[fi][0].channel(4); + ncnn::Mat flow_m1 = flow[fi][1].channel(4); + ncnn::Mat flow_m2 = flow[fi][2].channel(4); + ncnn::Mat flow_m3 = flow[fi][3].channel(4); + ncnn::Mat flow_m4 = flow[fi][4].channel(4); + ncnn::Mat flow_m5 = flow[fi][5].channel(4); + ncnn::Mat flow_m6 = flow[fi][6].channel(4); + ncnn::Mat flow_m7 = flow[fi][7].channel(4); + + for (int i = 0; i < flow_x0.h; i++) + { + float* x0 = flow_x0.row(i); + float* x1 = flow_x1.row(i) + flow_x0.w - 1; + float* x2 = flow_x2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; + float* x3 = flow_x3.row(flow_x0.h - 1 - i); + + float* y0 = flow_y0.row(i); + float* y1 = flow_y1.row(i) + flow_x0.w - 1; + float* y2 = flow_y2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; + float* y3 = flow_y3.row(flow_x0.h - 1 - i); + + float* z0 = flow_z0.row(i); + float* z1 = flow_z1.row(i) + flow_x0.w - 1; + float* z2 = flow_z2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; + float* z3 = flow_z3.row(flow_x0.h - 1 - i); + + float* w0 = flow_w0.row(i); + float* w1 = flow_w1.row(i) + flow_x0.w - 1; + float* w2 = flow_w2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; + float* w3 = flow_w3.row(flow_x0.h - 1 - i); + + float* m0 = flow_m0.row(i); + float* m1 = flow_m1.row(i) + flow_x0.w - 1; + float* m2 = flow_m2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; + float* m3 = flow_m3.row(flow_x0.h - 1 - i); + + for (int j = 0; j < flow_x0.w; j++) + { + float* x4 = flow_x4.row(j) + i; + float* x5 = flow_x5.row(j) + flow_x0.h - 1 - i; + float* x6 = flow_x6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; + float* x7 = flow_x7.row(flow_x0.w - 1 - j) + i; + + float* y4 = flow_y4.row(j) + i; + float* y5 = flow_y5.row(j) + flow_x0.h - 1 - i; + float* y6 = flow_y6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; + float* y7 = flow_y7.row(flow_x0.w - 1 - j) + i; + + float* z4 = flow_z4.row(j) + i; + float* z5 = flow_z5.row(j) + flow_x0.h - 1 - i; + float* z6 = flow_z6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; + float* z7 = flow_z7.row(flow_x0.w - 1 - j) + i; + + float* w4 = flow_w4.row(j) + i; + float* w5 = flow_w5.row(j) + flow_x0.h - 1 - i; + float* w6 = flow_w6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; + float* w7 = flow_w7.row(flow_x0.w - 1 - j) + i; + + float* m4 = flow_m4.row(j) + i; + float* m5 = flow_m5.row(j) + flow_x0.h - 1 - i; + float* m6 = flow_m6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; + float* m7 = flow_m7.row(flow_x0.w - 1 - j) + i; + + float x = (*x0 + -*x1 + -*x2 + *x3 + *y4 + *y5 + -*y6 + -*y7) * 0.125f; + float y = (*y0 + *y1 + -*y2 + -*y3 + *x4 + -*x5 + -*x6 + *x7) * 0.125f; + float z = (*z0 + -*z1 + -*z2 + *z3 + *w4 + *w5 + -*w6 + -*w7) * 0.125f; + float w = (*w0 + *w1 + -*w2 + -*w3 + *z4 + -*z5 + -*z6 + *z7) * 0.125f; + float m = (*m0 + *m1 + *m2 + *m3 + *m4 + *m5 + *m6 + *m7) * 0.125f; + + *x0++ = x; + *x1-- = -x; + *x2-- = -x; + *x3++ = x; + *x4 = y; + *x5 = -y; + *x6 = -y; + *x7 = y; + + *y0++ = y; + *y1-- = y; + *y2-- = -y; + *y3++ = -y; + *y4 = x; + *y5 = x; + *y6 = -x; + *y7 = -x; + + *z0++ = z; + *z1-- = -z; + *z2-- = -z; + *z3++ = z; + *z4 = w; + *z5 = -w; + *z6 = -w; + *z7 = w; + + *w0++ = w; + *w1-- = w; + *w2-- = -w; + *w3++ = -w; + *w4 = z; + *w5 = z; + *w6 = -z; + *w7 = -z; + + *m0++ = m; + *m1-- = m; + *m2-- = m; + *m3++ = m; + *m4 = m; + *m5 = m; + *m6 = m; + *m7 = m; + } + } + } + { + ncnn::Mat flow_x0 = flow_reversed[fi][0].channel(0); + ncnn::Mat flow_x1 = flow_reversed[fi][1].channel(0); + ncnn::Mat flow_x2 = flow_reversed[fi][2].channel(0); + ncnn::Mat flow_x3 = flow_reversed[fi][3].channel(0); + ncnn::Mat flow_x4 = flow_reversed[fi][4].channel(0); + ncnn::Mat flow_x5 = flow_reversed[fi][5].channel(0); + ncnn::Mat flow_x6 = flow_reversed[fi][6].channel(0); + ncnn::Mat flow_x7 = flow_reversed[fi][7].channel(0); + + ncnn::Mat flow_y0 = flow_reversed[fi][0].channel(1); + ncnn::Mat flow_y1 = flow_reversed[fi][1].channel(1); + ncnn::Mat flow_y2 = flow_reversed[fi][2].channel(1); + ncnn::Mat flow_y3 = flow_reversed[fi][3].channel(1); + ncnn::Mat flow_y4 = flow_reversed[fi][4].channel(1); + ncnn::Mat flow_y5 = flow_reversed[fi][5].channel(1); + ncnn::Mat flow_y6 = flow_reversed[fi][6].channel(1); + ncnn::Mat flow_y7 = flow_reversed[fi][7].channel(1); + + ncnn::Mat flow_z0 = flow_reversed[fi][0].channel(2); + ncnn::Mat flow_z1 = flow_reversed[fi][1].channel(2); + ncnn::Mat flow_z2 = flow_reversed[fi][2].channel(2); + ncnn::Mat flow_z3 = flow_reversed[fi][3].channel(2); + ncnn::Mat flow_z4 = flow_reversed[fi][4].channel(2); + ncnn::Mat flow_z5 = flow_reversed[fi][5].channel(2); + ncnn::Mat flow_z6 = flow_reversed[fi][6].channel(2); + ncnn::Mat flow_z7 = flow_reversed[fi][7].channel(2); + + ncnn::Mat flow_w0 = flow_reversed[fi][0].channel(3); + ncnn::Mat flow_w1 = flow_reversed[fi][1].channel(3); + ncnn::Mat flow_w2 = flow_reversed[fi][2].channel(3); + ncnn::Mat flow_w3 = flow_reversed[fi][3].channel(3); + ncnn::Mat flow_w4 = flow_reversed[fi][4].channel(3); + ncnn::Mat flow_w5 = flow_reversed[fi][5].channel(3); + ncnn::Mat flow_w6 = flow_reversed[fi][6].channel(3); + ncnn::Mat flow_w7 = flow_reversed[fi][7].channel(3); + + ncnn::Mat flow_m0 = flow_reversed[fi][0].channel(4); + ncnn::Mat flow_m1 = flow_reversed[fi][1].channel(4); + ncnn::Mat flow_m2 = flow_reversed[fi][2].channel(4); + ncnn::Mat flow_m3 = flow_reversed[fi][3].channel(4); + ncnn::Mat flow_m4 = flow_reversed[fi][4].channel(4); + ncnn::Mat flow_m5 = flow_reversed[fi][5].channel(4); + ncnn::Mat flow_m6 = flow_reversed[fi][6].channel(4); + ncnn::Mat flow_m7 = flow_reversed[fi][7].channel(4); + + for (int i = 0; i < flow_x0.h; i++) + { + float* x0 = flow_x0.row(i); + float* x1 = flow_x1.row(i) + flow_x0.w - 1; + float* x2 = flow_x2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; + float* x3 = flow_x3.row(flow_x0.h - 1 - i); + + float* y0 = flow_y0.row(i); + float* y1 = flow_y1.row(i) + flow_x0.w - 1; + float* y2 = flow_y2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; + float* y3 = flow_y3.row(flow_x0.h - 1 - i); + + float* z0 = flow_z0.row(i); + float* z1 = flow_z1.row(i) + flow_x0.w - 1; + float* z2 = flow_z2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; + float* z3 = flow_z3.row(flow_x0.h - 1 - i); + + float* w0 = flow_w0.row(i); + float* w1 = flow_w1.row(i) + flow_x0.w - 1; + float* w2 = flow_w2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; + float* w3 = flow_w3.row(flow_x0.h - 1 - i); + + float* m0 = flow_m0.row(i); + float* m1 = flow_m1.row(i) + flow_x0.w - 1; + float* m2 = flow_m2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; + float* m3 = flow_m3.row(flow_x0.h - 1 - i); + + for (int j = 0; j < flow_x0.w; j++) + { + float* x4 = flow_x4.row(j) + i; + float* x5 = flow_x5.row(j) + flow_x0.h - 1 - i; + float* x6 = flow_x6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; + float* x7 = flow_x7.row(flow_x0.w - 1 - j) + i; + + float* y4 = flow_y4.row(j) + i; + float* y5 = flow_y5.row(j) + flow_x0.h - 1 - i; + float* y6 = flow_y6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; + float* y7 = flow_y7.row(flow_x0.w - 1 - j) + i; + + float* z4 = flow_z4.row(j) + i; + float* z5 = flow_z5.row(j) + flow_x0.h - 1 - i; + float* z6 = flow_z6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; + float* z7 = flow_z7.row(flow_x0.w - 1 - j) + i; + + float* w4 = flow_w4.row(j) + i; + float* w5 = flow_w5.row(j) + flow_x0.h - 1 - i; + float* w6 = flow_w6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; + float* w7 = flow_w7.row(flow_x0.w - 1 - j) + i; + + float* m4 = flow_m4.row(j) + i; + float* m5 = flow_m5.row(j) + flow_x0.h - 1 - i; + float* m6 = flow_m6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; + float* m7 = flow_m7.row(flow_x0.w - 1 - j) + i; + + float x = (*x0 + -*x1 + -*x2 + *x3 + *y4 + *y5 + -*y6 + -*y7) * 0.125f; + float y = (*y0 + *y1 + -*y2 + -*y3 + *x4 + -*x5 + -*x6 + *x7) * 0.125f; + float z = (*z0 + -*z1 + -*z2 + *z3 + *w4 + *w5 + -*w6 + -*w7) * 0.125f; + float w = (*w0 + *w1 + -*w2 + -*w3 + *z4 + -*z5 + -*z6 + *z7) * 0.125f; + float m = (*m0 + *m1 + *m2 + *m3 + *m4 + *m5 + *m6 + *m7) * 0.125f; + + *x0++ = x; + *x1-- = -x; + *x2-- = -x; + *x3++ = x; + *x4 = y; + *x5 = -y; + *x6 = -y; + *x7 = y; + + *y0++ = y; + *y1-- = y; + *y2-- = -y; + *y3++ = -y; + *y4 = x; + *y5 = x; + *y6 = -x; + *y7 = -x; + + *z0++ = z; + *z1-- = -z; + *z2-- = -z; + *z3++ = z; + *z4 = w; + *z5 = -w; + *z6 = -w; + *z7 = w; + + *w0++ = w; + *w1-- = w; + *w2-- = -w; + *w3++ = -w; + *w4 = z; + *w5 = z; + *w6 = -z; + *w7 = -z; + + *m0++ = m; + *m1-- = m; + *m2-- = m; + *m3++ = m; + *m4 = m; + *m5 = m; + *m6 = m; + *m7 = m; + } } } } - } - - ncnn::Mat out_padded[8]; - for (int ti = 0; ti < 8; ti++) - { - // flownet + for (int ti = 0; ti < 8; ti++) { - ncnn::Extractor ex = flownet.create_extractor(); + // flownet + { + ncnn::Extractor ex = flownet.create_extractor(); - ex.input("in0", in0_padded[ti]); - ex.input("in1", in1_padded[ti]); - ex.input("in2", timestep_padded[ti / 4]); - ex.input("flow0", flow[0][ti]); - ex.input("flow1", flow[1][ti]); - ex.input("flow2", flow[2][ti]); - ex.input("flow3", flow[3][ti]); - ex.extract("out0", out_padded[ti]); + ex.input("in0", in0_padded[ti]); + ex.input("in1", in1_padded[ti]); + ex.input("in2", timestep_padded[ti / 4]); + ex.input("flow0", flow[0][ti]); + ex.input("flow1", flow[1][ti]); + ex.input("flow2", flow[2][ti]); + ex.input("flow3", flow[3][ti]); + ex.extract("out0", out_padded[ti]); + } + { + ncnn::Extractor ex = flownet.create_extractor(); + + ex.input("in0", in1_padded[ti]); + ex.input("in1", in0_padded[ti]); + ex.input("in2", timestep_padded_reversed[ti / 4]); + ex.input("flow0", flow_reversed[0][ti]); + ex.input("flow1", flow_reversed[1][ti]); + ex.input("flow2", flow_reversed[2][ti]); + ex.input("flow3", flow_reversed[3][ti]); + ex.extract("out0", out_padded_reversed[ti]); + } + } + } + else + { + ncnn::Mat flow[4][8]; + for (int fi = 0; fi < 4; fi++) + { + for (int ti = 0; ti < 8; ti++) + { + // flownet flow mask + ncnn::Extractor ex = flownet.create_extractor(); + + ex.input("in0", in0_padded[ti]); + ex.input("in1", in1_padded[ti]); + ex.input("in2", timestep_padded[ti / 4]); + + // intentional fall through + switch (fi) + { + case 3: ex.input("flow2", flow[2][ti]); + case 2: ex.input("flow1", flow[1][ti]); + case 1: ex.input("flow0", flow[0][ti]); + default: + { + char tmp[16]; + sprintf(tmp, "flow%d", fi); + ex.extract(tmp, flow[fi][ti]); + } + } + } + + // avg flow mask + { + ncnn::Mat flow_x0 = flow[fi][0].channel(0); + ncnn::Mat flow_x1 = flow[fi][1].channel(0); + ncnn::Mat flow_x2 = flow[fi][2].channel(0); + ncnn::Mat flow_x3 = flow[fi][3].channel(0); + ncnn::Mat flow_x4 = flow[fi][4].channel(0); + ncnn::Mat flow_x5 = flow[fi][5].channel(0); + ncnn::Mat flow_x6 = flow[fi][6].channel(0); + ncnn::Mat flow_x7 = flow[fi][7].channel(0); + + ncnn::Mat flow_y0 = flow[fi][0].channel(1); + ncnn::Mat flow_y1 = flow[fi][1].channel(1); + ncnn::Mat flow_y2 = flow[fi][2].channel(1); + ncnn::Mat flow_y3 = flow[fi][3].channel(1); + ncnn::Mat flow_y4 = flow[fi][4].channel(1); + ncnn::Mat flow_y5 = flow[fi][5].channel(1); + ncnn::Mat flow_y6 = flow[fi][6].channel(1); + ncnn::Mat flow_y7 = flow[fi][7].channel(1); + + ncnn::Mat flow_z0 = flow[fi][0].channel(2); + ncnn::Mat flow_z1 = flow[fi][1].channel(2); + ncnn::Mat flow_z2 = flow[fi][2].channel(2); + ncnn::Mat flow_z3 = flow[fi][3].channel(2); + ncnn::Mat flow_z4 = flow[fi][4].channel(2); + ncnn::Mat flow_z5 = flow[fi][5].channel(2); + ncnn::Mat flow_z6 = flow[fi][6].channel(2); + ncnn::Mat flow_z7 = flow[fi][7].channel(2); + + ncnn::Mat flow_w0 = flow[fi][0].channel(3); + ncnn::Mat flow_w1 = flow[fi][1].channel(3); + ncnn::Mat flow_w2 = flow[fi][2].channel(3); + ncnn::Mat flow_w3 = flow[fi][3].channel(3); + ncnn::Mat flow_w4 = flow[fi][4].channel(3); + ncnn::Mat flow_w5 = flow[fi][5].channel(3); + ncnn::Mat flow_w6 = flow[fi][6].channel(3); + ncnn::Mat flow_w7 = flow[fi][7].channel(3); + + ncnn::Mat flow_m0 = flow[fi][0].channel(4); + ncnn::Mat flow_m1 = flow[fi][1].channel(4); + ncnn::Mat flow_m2 = flow[fi][2].channel(4); + ncnn::Mat flow_m3 = flow[fi][3].channel(4); + ncnn::Mat flow_m4 = flow[fi][4].channel(4); + ncnn::Mat flow_m5 = flow[fi][5].channel(4); + ncnn::Mat flow_m6 = flow[fi][6].channel(4); + ncnn::Mat flow_m7 = flow[fi][7].channel(4); + + for (int i = 0; i < flow_x0.h; i++) + { + float* x0 = flow_x0.row(i); + float* x1 = flow_x1.row(i) + flow_x0.w - 1; + float* x2 = flow_x2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; + float* x3 = flow_x3.row(flow_x0.h - 1 - i); + + float* y0 = flow_y0.row(i); + float* y1 = flow_y1.row(i) + flow_x0.w - 1; + float* y2 = flow_y2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; + float* y3 = flow_y3.row(flow_x0.h - 1 - i); + + float* z0 = flow_z0.row(i); + float* z1 = flow_z1.row(i) + flow_x0.w - 1; + float* z2 = flow_z2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; + float* z3 = flow_z3.row(flow_x0.h - 1 - i); + + float* w0 = flow_w0.row(i); + float* w1 = flow_w1.row(i) + flow_x0.w - 1; + float* w2 = flow_w2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; + float* w3 = flow_w3.row(flow_x0.h - 1 - i); + + float* m0 = flow_m0.row(i); + float* m1 = flow_m1.row(i) + flow_x0.w - 1; + float* m2 = flow_m2.row(flow_x0.h - 1 - i) + flow_x0.w - 1; + float* m3 = flow_m3.row(flow_x0.h - 1 - i); + + for (int j = 0; j < flow_x0.w; j++) + { + float* x4 = flow_x4.row(j) + i; + float* x5 = flow_x5.row(j) + flow_x0.h - 1 - i; + float* x6 = flow_x6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; + float* x7 = flow_x7.row(flow_x0.w - 1 - j) + i; + + float* y4 = flow_y4.row(j) + i; + float* y5 = flow_y5.row(j) + flow_x0.h - 1 - i; + float* y6 = flow_y6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; + float* y7 = flow_y7.row(flow_x0.w - 1 - j) + i; + + float* z4 = flow_z4.row(j) + i; + float* z5 = flow_z5.row(j) + flow_x0.h - 1 - i; + float* z6 = flow_z6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; + float* z7 = flow_z7.row(flow_x0.w - 1 - j) + i; + + float* w4 = flow_w4.row(j) + i; + float* w5 = flow_w5.row(j) + flow_x0.h - 1 - i; + float* w6 = flow_w6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; + float* w7 = flow_w7.row(flow_x0.w - 1 - j) + i; + + float* m4 = flow_m4.row(j) + i; + float* m5 = flow_m5.row(j) + flow_x0.h - 1 - i; + float* m6 = flow_m6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i; + float* m7 = flow_m7.row(flow_x0.w - 1 - j) + i; + + float x = (*x0 + -*x1 + -*x2 + *x3 + *y4 + *y5 + -*y6 + -*y7) * 0.125f; + float y = (*y0 + *y1 + -*y2 + -*y3 + *x4 + -*x5 + -*x6 + *x7) * 0.125f; + float z = (*z0 + -*z1 + -*z2 + *z3 + *w4 + *w5 + -*w6 + -*w7) * 0.125f; + float w = (*w0 + *w1 + -*w2 + -*w3 + *z4 + -*z5 + -*z6 + *z7) * 0.125f; + float m = (*m0 + *m1 + *m2 + *m3 + *m4 + *m5 + *m6 + *m7) * 0.125f; + + *x0++ = x; + *x1-- = -x; + *x2-- = -x; + *x3++ = x; + *x4 = y; + *x5 = -y; + *x6 = -y; + *x7 = y; + + *y0++ = y; + *y1-- = y; + *y2-- = -y; + *y3++ = -y; + *y4 = x; + *y5 = x; + *y6 = -x; + *y7 = -x; + + *z0++ = z; + *z1-- = -z; + *z2-- = -z; + *z3++ = z; + *z4 = w; + *z5 = -w; + *z6 = -w; + *z7 = w; + + *w0++ = w; + *w1-- = w; + *w2-- = -w; + *w3++ = -w; + *w4 = z; + *w5 = z; + *w6 = -z; + *w7 = -z; + + *m0++ = m; + *m1-- = m; + *m2-- = m; + *m3++ = m; + *m4 = m; + *m5 = m; + *m6 = m; + *m7 = m; + } + } + } + } + + for (int ti = 0; ti < 8; ti++) + { + // flownet + { + ncnn::Extractor ex = flownet.create_extractor(); + + ex.input("in0", in0_padded[ti]); + ex.input("in1", in1_padded[ti]); + ex.input("in2", timestep_padded[ti / 4]); + ex.input("flow0", flow[0][ti]); + ex.input("flow1", flow[1][ti]); + ex.input("flow2", flow[2][ti]); + ex.input("flow3", flow[3][ti]); + ex.extract("out0", out_padded[ti]); + } } } // cut padding and postproc out.create(w, h, 3); + if (tta_temporal_mode) + { + for (int q = 0; q < 3; q++) + { + const ncnn::Mat out_padded_0 = out_padded[0].channel(q); + const ncnn::Mat out_padded_1 = out_padded[1].channel(q); + const ncnn::Mat out_padded_2 = out_padded[2].channel(q); + const ncnn::Mat out_padded_3 = out_padded[3].channel(q); + const ncnn::Mat out_padded_4 = out_padded[4].channel(q); + const ncnn::Mat out_padded_5 = out_padded[5].channel(q); + const ncnn::Mat out_padded_6 = out_padded[6].channel(q); + const ncnn::Mat out_padded_7 = out_padded[7].channel(q); + const ncnn::Mat out_padded_reversed_0 = out_padded_reversed[0].channel(q); + const ncnn::Mat out_padded_reversed_1 = out_padded_reversed[1].channel(q); + const ncnn::Mat out_padded_reversed_2 = out_padded_reversed[2].channel(q); + const ncnn::Mat out_padded_reversed_3 = out_padded_reversed[3].channel(q); + const ncnn::Mat out_padded_reversed_4 = out_padded_reversed[4].channel(q); + const ncnn::Mat out_padded_reversed_5 = out_padded_reversed[5].channel(q); + const ncnn::Mat out_padded_reversed_6 = out_padded_reversed[6].channel(q); + const ncnn::Mat out_padded_reversed_7 = out_padded_reversed[7].channel(q); + float* outptr = out.channel(q); + + for (int i = 0; i < h; i++) + { + const float* ptr0 = out_padded_0.row(i); + const float* ptr1 = out_padded_1.row(i) + w_padded - 1; + const float* ptr2 = out_padded_2.row(h_padded - 1 - i) + w_padded - 1; + const float* ptr3 = out_padded_3.row(h_padded - 1 - i); + const float* ptrr0 = out_padded_reversed_0.row(i); + const float* ptrr1 = out_padded_reversed_1.row(i) + w_padded - 1; + const float* ptrr2 = out_padded_reversed_2.row(h_padded - 1 - i) + w_padded - 1; + const float* ptrr3 = out_padded_reversed_3.row(h_padded - 1 - i); + + for (int j = 0; j < w; j++) + { + const float* ptr4 = out_padded_4.row(j) + i; + const float* ptr5 = out_padded_5.row(j) + h_padded - 1 - i; + const float* ptr6 = out_padded_6.row(w_padded - 1 - j) + h_padded - 1 - i; + const float* ptr7 = out_padded_7.row(w_padded - 1 - j) + i; + const float* ptrr4 = out_padded_reversed_4.row(j) + i; + const float* ptrr5 = out_padded_reversed_5.row(j) + h_padded - 1 - i; + const float* ptrr6 = out_padded_reversed_6.row(w_padded - 1 - j) + h_padded - 1 - i; + const float* ptrr7 = out_padded_reversed_7.row(w_padded - 1 - j) + i; + + float v = (*ptr0++ + *ptr1-- + *ptr2-- + *ptr3++ + *ptr4 + *ptr5 + *ptr6 + *ptr7) / 8; + float vr = (*ptrr0++ + *ptrr1-- + *ptrr2-- + *ptrr3++ + *ptrr4 + *ptrr5 + *ptrr6 + *ptrr7) / 8; + + *outptr++ = (v + vr) * 0.5f * 255.f + 0.5f; + } + } + } + } + else { for (int q = 0; q < 3; q++) { @@ -3645,7 +4137,7 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f float v = (*ptr0++ + *ptr1-- + *ptr2-- + *ptr3++ + *ptr4 + *ptr5 + *ptr6 + *ptr7) / 8; - *outptr++ = (v + 0.5f) * 255.f + 0.5f; + *outptr++ = v * 255.f + 0.5f; } } } @@ -3722,9 +4214,135 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f timestep_padded.fill(timestep); } - // flownet ncnn::Mat out_padded; + ncnn::Mat out_padded_reversed; + if (tta_temporal_mode) { + ncnn::Mat timestep_padded_reversed; + { + timestep_padded_reversed.create(w_padded, h_padded, 1); + timestep_padded_reversed.fill(1.f - timestep); + } + + ncnn::Mat flow[4]; + ncnn::Mat flow_reversed[4]; + for (int fi = 0; fi < 4; fi++) + { + { + // flownet flow mask + ncnn::Extractor ex = flownet.create_extractor(); + + ex.input("in0", in0_padded); + ex.input("in1", in1_padded); + ex.input("in2", timestep_padded); + + // intentional fall through + switch (fi) + { + case 3: ex.input("flow2", flow[2]); + case 2: ex.input("flow1", flow[1]); + case 1: ex.input("flow0", flow[0]); + default: + { + char tmp[16]; + sprintf(tmp, "flow%d", fi); + ex.extract(tmp, flow[fi]); + } + } + } + + { + // flownet flow mask reversed + ncnn::Extractor ex = flownet.create_extractor(); + + ex.input("in0", in1_padded); + ex.input("in1", in0_padded); + ex.input("in2", timestep_padded_reversed); + + // intentional fall through + switch (fi) + { + case 3: ex.input("flow2", flow_reversed[2]); + case 2: ex.input("flow1", flow_reversed[1]); + case 1: ex.input("flow0", flow_reversed[0]); + default: + { + char tmp[16]; + sprintf(tmp, "flow%d", fi); + ex.extract(tmp, flow_reversed[fi]); + } + } + } + + // merge flow and flow_reversed + { + float* flow_x = flow[fi].channel(0); + float* flow_y = flow[fi].channel(1); + float* flow_z = flow[fi].channel(2); + float* flow_w = flow[fi].channel(3); + float* flow_m = flow[fi].channel(4); + float* flow_reversed_x = flow_reversed[fi].channel(0); + float* flow_reversed_y = flow_reversed[fi].channel(1); + float* flow_reversed_z = flow_reversed[fi].channel(2); + float* flow_reversed_w = flow_reversed[fi].channel(3); + float* flow_reversed_m = flow_reversed[fi].channel(4); + + for (int i = 0; i < flow[fi].h; i++) + { + for (int j = 0; j < flow[fi].w; j++) + { + float x = (*flow_x + *flow_reversed_z) * 0.5f; + float y = (*flow_y + *flow_reversed_w) * 0.5f; + float z = (*flow_z + *flow_reversed_x) * 0.5f; + float w = (*flow_w + *flow_reversed_y) * 0.5f; + float m = (*flow_m - *flow_reversed_m) * 0.5f; + + *flow_x++ = x; + *flow_y++ = y; + *flow_z++ = z; + *flow_w++ = w; + *flow_m++ = m; + *flow_reversed_x++ = z; + *flow_reversed_y++ = w; + *flow_reversed_z++ = x; + *flow_reversed_w++ = y; + *flow_reversed_m++ = -m; + } + } + } + } + + { + // flownet + ncnn::Extractor ex = flownet.create_extractor(); + + ex.input("in0", in0_padded); + ex.input("in1", in1_padded); + ex.input("in2", timestep_padded); + ex.input("flow0", flow[0]); + ex.input("flow1", flow[1]); + ex.input("flow2", flow[2]); + ex.input("flow3", flow[3]); + + ex.extract("out0", out_padded); + } + { + ncnn::Extractor ex = flownet.create_extractor(); + + ex.input("in0", in1_padded); + ex.input("in1", in0_padded); + ex.input("in2", timestep_padded_reversed); + ex.input("flow0", flow_reversed[0]); + ex.input("flow1", flow_reversed[1]); + ex.input("flow2", flow_reversed[2]); + ex.input("flow3", flow_reversed[3]); + + ex.extract("out0", out_padded_reversed); + } + } + else + { + // flownet ncnn::Extractor ex = flownet.create_extractor(); ex.input("in0", in0_padded); @@ -3735,6 +4353,24 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f // cut padding and postproc out.create(w, h, 3); + if (tta_temporal_mode) + { + for (int q = 0; q < 3; q++) + { + float* outptr = out.channel(q); + const float* ptr = out_padded.channel(q); + const float* ptr1 = out_padded_reversed.channel(q); + + for (int i = 0; i < h; i++) + { + for (int j = 0; j < w; j++) + { + *outptr++ = (*ptr++ + *ptr1++) * 0.5f * 255.f + 0.5f; + } + } + } + } + else { for (int q = 0; q < 3; q++) {