mirror of https://github.com/nihui/rife-ncnn-vulkan synced 2024-11-23 02:56:52 +01:00

implement cpu temporal tta for rife v4

nihui 2022-10-23 17:06:10 +08:00
parent d964f6f165
commit a7532fc3f9
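
The core of the change: with temporal TTA enabled, the flow network is run a second time with in0/in1 swapped and the timestep replaced by 1 - t, the two flow estimates are reconciled per pixel before the final inference, and the two resulting frames are averaged at the end. Below is a minimal standalone sketch of that per-pixel merge, mirroring the arithmetic in the diff; the FlowPixel struct and merge_temporal name are illustrative only (the committed code works in place on the five channels of an ncnn::Mat).

#include <cstdio>

// Five flow-net output channels per pixel: (x, y) flow towards frame 0,
// (z, w) flow towards frame 1, m the blend mask. Names are illustrative.
struct FlowPixel { float x, y, z, w, m; };

// Sketch of the "merge flow and flow_reversed" step: the reversed estimate
// has its frame-0/frame-1 channel pairs swapped and its mask negated before
// being averaged with the forward estimate.
static void merge_temporal(FlowPixel& f, FlowPixel& r)
{
    const float x = (f.x + r.z) * 0.5f;
    const float y = (f.y + r.w) * 0.5f;
    const float z = (f.z + r.x) * 0.5f;
    const float w = (f.w + r.y) * 0.5f;
    const float m = (f.m - r.m) * 0.5f;
    f = FlowPixel{x, y, z, w, m};
    r = FlowPixel{z, w, x, y, -m};
}

int main()
{
    FlowPixel forward  = { 1.0f,  2.0f, -1.0f, -2.0f,  0.30f};
    FlowPixel reversed = {-0.9f, -2.1f,  1.1f,  1.9f, -0.25f};
    merge_temporal(forward, reversed);
    printf("merged forward: %.3f %.3f %.3f %.3f %.3f\n",
           forward.x, forward.y, forward.z, forward.w, forward.m);
    return 0;
}

Swapping the (x, y) and (z, w) pairs maps the reversed estimate back into the forward frame order, and negating the mask reflects that the reversed run's blend weight refers to the opposite source frame.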


@@ -3263,7 +3263,7 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f
int j = 0;
for (; j < w; j++)
{
*outptr++ = *ptr++ * (1 / 255.f) - 0.5f;
*outptr++ = *ptr++ * (1 / 255.f);
}
for (; j < w_padded; j++)
{
@@ -3293,7 +3293,7 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f
int j = 0;
for (; j < w; j++)
{
*outptr++ = *ptr++ * (1 / 255.f) - 0.5f;
*outptr++ = *ptr++ * (1 / 255.f);
}
for (; j < w_padded; j++)
{
@@ -3310,7 +3310,7 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f
}
}
{
timestep_padded[0].create(h_padded, w_padded, 1);
timestep_padded[0].create(w_padded, h_padded, 1);
timestep_padded[1].create(h_padded, w_padded, 1);
timestep_padded[0].fill(timestep);
timestep_padded[1].fill(timestep);
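
For reference on the shape fix above: timestep_padded[0] feeds the four un-transposed spatial augmentations and should be w_padded wide by h_padded tall, while timestep_padded[1] feeds the four transposed ones with the dimensions swapped, which is why only the [0] create() call changes (the new timestep_padded_reversed[0] below uses the same corrected shape). A small sketch of these constant timestep planes, assuming ncnn is installed; the sizes are made up for illustration and the include path may differ per install:

#include "mat.h" // ncnn::Mat; adjust the include path to your ncnn install

int main()
{
    const int w_padded = 64, h_padded = 32; // illustrative sizes
    const float timestep = 0.5f;

    // "in2" for the un-transposed augmentations: a constant plane of value t
    ncnn::Mat timestep_padded(w_padded, h_padded, 1);
    timestep_padded.fill(timestep);

    // the temporal-TTA pass swaps in0/in1 and feeds 1 - t instead
    ncnn::Mat timestep_padded_reversed(w_padded, h_padded, 1);
    timestep_padded_reversed.fill(1.f - timestep);

    return 0;
}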
@@ -3412,6 +3412,446 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f
}
}
ncnn::Mat out_padded[8];
ncnn::Mat out_padded_reversed[8];
if (tta_temporal_mode)
{
ncnn::Mat timestep_padded_reversed[2];
timestep_padded_reversed[0].create(w_padded, h_padded, 1);
timestep_padded_reversed[1].create(h_padded, w_padded, 1);
timestep_padded_reversed[0].fill(1.f - timestep);
timestep_padded_reversed[1].fill(1.f - timestep);
ncnn::Mat flow[4][8];
ncnn::Mat flow_reversed[4][8];
for (int fi = 0; fi < 4; fi++)
{
for (int ti = 0; ti < 8; ti++)
{
// flownet flow mask
{
ncnn::Extractor ex = flownet.create_extractor();
ex.input("in0", in0_padded[ti]);
ex.input("in1", in1_padded[ti]);
ex.input("in2", timestep_padded[ti / 4]);
// intentional fall through
switch (fi)
{
case 3: ex.input("flow2", flow[2][ti]);
case 2: ex.input("flow1", flow[1][ti]);
case 1: ex.input("flow0", flow[0][ti]);
default:
{
char tmp[16];
sprintf(tmp, "flow%d", fi);
ex.extract(tmp, flow[fi][ti]);
}
}
}
// flownet flow mask reversed
{
ncnn::Extractor ex = flownet.create_extractor();
ex.input("in0", in1_padded[ti]);
ex.input("in1", in0_padded[ti]);
ex.input("in2", timestep_padded_reversed[ti / 4]);
// intentional fall through
switch (fi)
{
case 3: ex.input("flow2", flow_reversed[2][ti]);
case 2: ex.input("flow1", flow_reversed[1][ti]);
case 1: ex.input("flow0", flow_reversed[0][ti]);
default:
{
char tmp[16];
sprintf(tmp, "flow%d", fi);
ex.extract(tmp, flow_reversed[fi][ti]);
}
}
}
// merge flow and flow_reversed
{
float* flow_x = flow[fi][ti].channel(0);
float* flow_y = flow[fi][ti].channel(1);
float* flow_z = flow[fi][ti].channel(2);
float* flow_w = flow[fi][ti].channel(3);
float* flow_m = flow[fi][ti].channel(4);
float* flow_reversed_x = flow_reversed[fi][ti].channel(0);
float* flow_reversed_y = flow_reversed[fi][ti].channel(1);
float* flow_reversed_z = flow_reversed[fi][ti].channel(2);
float* flow_reversed_w = flow_reversed[fi][ti].channel(3);
float* flow_reversed_m = flow_reversed[fi][ti].channel(4);
for (int i = 0; i < flow[fi][ti].h; i++)
{
for (int j = 0; j < flow[fi][ti].w; j++)
{
float x = (*flow_x + *flow_reversed_z) * 0.5f;
float y = (*flow_y + *flow_reversed_w) * 0.5f;
float z = (*flow_z + *flow_reversed_x) * 0.5f;
float w = (*flow_w + *flow_reversed_y) * 0.5f;
float m = (*flow_m - *flow_reversed_m) * 0.5f;
*flow_x++ = x;
*flow_y++ = y;
*flow_z++ = z;
*flow_w++ = w;
*flow_m++ = m;
*flow_reversed_x++ = z;
*flow_reversed_y++ = w;
*flow_reversed_z++ = x;
*flow_reversed_w++ = y;
*flow_reversed_m++ = -m;
}
}
}
}
// avg flow mask
{
ncnn::Mat flow_x0 = flow[fi][0].channel(0);
ncnn::Mat flow_x1 = flow[fi][1].channel(0);
ncnn::Mat flow_x2 = flow[fi][2].channel(0);
ncnn::Mat flow_x3 = flow[fi][3].channel(0);
ncnn::Mat flow_x4 = flow[fi][4].channel(0);
ncnn::Mat flow_x5 = flow[fi][5].channel(0);
ncnn::Mat flow_x6 = flow[fi][6].channel(0);
ncnn::Mat flow_x7 = flow[fi][7].channel(0);
ncnn::Mat flow_y0 = flow[fi][0].channel(1);
ncnn::Mat flow_y1 = flow[fi][1].channel(1);
ncnn::Mat flow_y2 = flow[fi][2].channel(1);
ncnn::Mat flow_y3 = flow[fi][3].channel(1);
ncnn::Mat flow_y4 = flow[fi][4].channel(1);
ncnn::Mat flow_y5 = flow[fi][5].channel(1);
ncnn::Mat flow_y6 = flow[fi][6].channel(1);
ncnn::Mat flow_y7 = flow[fi][7].channel(1);
ncnn::Mat flow_z0 = flow[fi][0].channel(2);
ncnn::Mat flow_z1 = flow[fi][1].channel(2);
ncnn::Mat flow_z2 = flow[fi][2].channel(2);
ncnn::Mat flow_z3 = flow[fi][3].channel(2);
ncnn::Mat flow_z4 = flow[fi][4].channel(2);
ncnn::Mat flow_z5 = flow[fi][5].channel(2);
ncnn::Mat flow_z6 = flow[fi][6].channel(2);
ncnn::Mat flow_z7 = flow[fi][7].channel(2);
ncnn::Mat flow_w0 = flow[fi][0].channel(3);
ncnn::Mat flow_w1 = flow[fi][1].channel(3);
ncnn::Mat flow_w2 = flow[fi][2].channel(3);
ncnn::Mat flow_w3 = flow[fi][3].channel(3);
ncnn::Mat flow_w4 = flow[fi][4].channel(3);
ncnn::Mat flow_w5 = flow[fi][5].channel(3);
ncnn::Mat flow_w6 = flow[fi][6].channel(3);
ncnn::Mat flow_w7 = flow[fi][7].channel(3);
ncnn::Mat flow_m0 = flow[fi][0].channel(4);
ncnn::Mat flow_m1 = flow[fi][1].channel(4);
ncnn::Mat flow_m2 = flow[fi][2].channel(4);
ncnn::Mat flow_m3 = flow[fi][3].channel(4);
ncnn::Mat flow_m4 = flow[fi][4].channel(4);
ncnn::Mat flow_m5 = flow[fi][5].channel(4);
ncnn::Mat flow_m6 = flow[fi][6].channel(4);
ncnn::Mat flow_m7 = flow[fi][7].channel(4);
for (int i = 0; i < flow_x0.h; i++)
{
float* x0 = flow_x0.row(i);
float* x1 = flow_x1.row(i) + flow_x0.w - 1;
float* x2 = flow_x2.row(flow_x0.h - 1 - i) + flow_x0.w - 1;
float* x3 = flow_x3.row(flow_x0.h - 1 - i);
float* y0 = flow_y0.row(i);
float* y1 = flow_y1.row(i) + flow_x0.w - 1;
float* y2 = flow_y2.row(flow_x0.h - 1 - i) + flow_x0.w - 1;
float* y3 = flow_y3.row(flow_x0.h - 1 - i);
float* z0 = flow_z0.row(i);
float* z1 = flow_z1.row(i) + flow_x0.w - 1;
float* z2 = flow_z2.row(flow_x0.h - 1 - i) + flow_x0.w - 1;
float* z3 = flow_z3.row(flow_x0.h - 1 - i);
float* w0 = flow_w0.row(i);
float* w1 = flow_w1.row(i) + flow_x0.w - 1;
float* w2 = flow_w2.row(flow_x0.h - 1 - i) + flow_x0.w - 1;
float* w3 = flow_w3.row(flow_x0.h - 1 - i);
float* m0 = flow_m0.row(i);
float* m1 = flow_m1.row(i) + flow_x0.w - 1;
float* m2 = flow_m2.row(flow_x0.h - 1 - i) + flow_x0.w - 1;
float* m3 = flow_m3.row(flow_x0.h - 1 - i);
for (int j = 0; j < flow_x0.w; j++)
{
float* x4 = flow_x4.row(j) + i;
float* x5 = flow_x5.row(j) + flow_x0.h - 1 - i;
float* x6 = flow_x6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i;
float* x7 = flow_x7.row(flow_x0.w - 1 - j) + i;
float* y4 = flow_y4.row(j) + i;
float* y5 = flow_y5.row(j) + flow_x0.h - 1 - i;
float* y6 = flow_y6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i;
float* y7 = flow_y7.row(flow_x0.w - 1 - j) + i;
float* z4 = flow_z4.row(j) + i;
float* z5 = flow_z5.row(j) + flow_x0.h - 1 - i;
float* z6 = flow_z6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i;
float* z7 = flow_z7.row(flow_x0.w - 1 - j) + i;
float* w4 = flow_w4.row(j) + i;
float* w5 = flow_w5.row(j) + flow_x0.h - 1 - i;
float* w6 = flow_w6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i;
float* w7 = flow_w7.row(flow_x0.w - 1 - j) + i;
float* m4 = flow_m4.row(j) + i;
float* m5 = flow_m5.row(j) + flow_x0.h - 1 - i;
float* m6 = flow_m6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i;
float* m7 = flow_m7.row(flow_x0.w - 1 - j) + i;
float x = (*x0 + -*x1 + -*x2 + *x3 + *y4 + *y5 + -*y6 + -*y7) * 0.125f;
float y = (*y0 + *y1 + -*y2 + -*y3 + *x4 + -*x5 + -*x6 + *x7) * 0.125f;
float z = (*z0 + -*z1 + -*z2 + *z3 + *w4 + *w5 + -*w6 + -*w7) * 0.125f;
float w = (*w0 + *w1 + -*w2 + -*w3 + *z4 + -*z5 + -*z6 + *z7) * 0.125f;
float m = (*m0 + *m1 + *m2 + *m3 + *m4 + *m5 + *m6 + *m7) * 0.125f;
*x0++ = x;
*x1-- = -x;
*x2-- = -x;
*x3++ = x;
*x4 = y;
*x5 = -y;
*x6 = -y;
*x7 = y;
*y0++ = y;
*y1-- = y;
*y2-- = -y;
*y3++ = -y;
*y4 = x;
*y5 = x;
*y6 = -x;
*y7 = -x;
*z0++ = z;
*z1-- = -z;
*z2-- = -z;
*z3++ = z;
*z4 = w;
*z5 = -w;
*z6 = -w;
*z7 = w;
*w0++ = w;
*w1-- = w;
*w2-- = -w;
*w3++ = -w;
*w4 = z;
*w5 = z;
*w6 = -z;
*w7 = -z;
*m0++ = m;
*m1-- = m;
*m2-- = m;
*m3++ = m;
*m4 = m;
*m5 = m;
*m6 = m;
*m7 = m;
}
}
}
{
ncnn::Mat flow_x0 = flow_reversed[fi][0].channel(0);
ncnn::Mat flow_x1 = flow_reversed[fi][1].channel(0);
ncnn::Mat flow_x2 = flow_reversed[fi][2].channel(0);
ncnn::Mat flow_x3 = flow_reversed[fi][3].channel(0);
ncnn::Mat flow_x4 = flow_reversed[fi][4].channel(0);
ncnn::Mat flow_x5 = flow_reversed[fi][5].channel(0);
ncnn::Mat flow_x6 = flow_reversed[fi][6].channel(0);
ncnn::Mat flow_x7 = flow_reversed[fi][7].channel(0);
ncnn::Mat flow_y0 = flow_reversed[fi][0].channel(1);
ncnn::Mat flow_y1 = flow_reversed[fi][1].channel(1);
ncnn::Mat flow_y2 = flow_reversed[fi][2].channel(1);
ncnn::Mat flow_y3 = flow_reversed[fi][3].channel(1);
ncnn::Mat flow_y4 = flow_reversed[fi][4].channel(1);
ncnn::Mat flow_y5 = flow_reversed[fi][5].channel(1);
ncnn::Mat flow_y6 = flow_reversed[fi][6].channel(1);
ncnn::Mat flow_y7 = flow_reversed[fi][7].channel(1);
ncnn::Mat flow_z0 = flow_reversed[fi][0].channel(2);
ncnn::Mat flow_z1 = flow_reversed[fi][1].channel(2);
ncnn::Mat flow_z2 = flow_reversed[fi][2].channel(2);
ncnn::Mat flow_z3 = flow_reversed[fi][3].channel(2);
ncnn::Mat flow_z4 = flow_reversed[fi][4].channel(2);
ncnn::Mat flow_z5 = flow_reversed[fi][5].channel(2);
ncnn::Mat flow_z6 = flow_reversed[fi][6].channel(2);
ncnn::Mat flow_z7 = flow_reversed[fi][7].channel(2);
ncnn::Mat flow_w0 = flow_reversed[fi][0].channel(3);
ncnn::Mat flow_w1 = flow_reversed[fi][1].channel(3);
ncnn::Mat flow_w2 = flow_reversed[fi][2].channel(3);
ncnn::Mat flow_w3 = flow_reversed[fi][3].channel(3);
ncnn::Mat flow_w4 = flow_reversed[fi][4].channel(3);
ncnn::Mat flow_w5 = flow_reversed[fi][5].channel(3);
ncnn::Mat flow_w6 = flow_reversed[fi][6].channel(3);
ncnn::Mat flow_w7 = flow_reversed[fi][7].channel(3);
ncnn::Mat flow_m0 = flow_reversed[fi][0].channel(4);
ncnn::Mat flow_m1 = flow_reversed[fi][1].channel(4);
ncnn::Mat flow_m2 = flow_reversed[fi][2].channel(4);
ncnn::Mat flow_m3 = flow_reversed[fi][3].channel(4);
ncnn::Mat flow_m4 = flow_reversed[fi][4].channel(4);
ncnn::Mat flow_m5 = flow_reversed[fi][5].channel(4);
ncnn::Mat flow_m6 = flow_reversed[fi][6].channel(4);
ncnn::Mat flow_m7 = flow_reversed[fi][7].channel(4);
for (int i = 0; i < flow_x0.h; i++)
{
float* x0 = flow_x0.row(i);
float* x1 = flow_x1.row(i) + flow_x0.w - 1;
float* x2 = flow_x2.row(flow_x0.h - 1 - i) + flow_x0.w - 1;
float* x3 = flow_x3.row(flow_x0.h - 1 - i);
float* y0 = flow_y0.row(i);
float* y1 = flow_y1.row(i) + flow_x0.w - 1;
float* y2 = flow_y2.row(flow_x0.h - 1 - i) + flow_x0.w - 1;
float* y3 = flow_y3.row(flow_x0.h - 1 - i);
float* z0 = flow_z0.row(i);
float* z1 = flow_z1.row(i) + flow_x0.w - 1;
float* z2 = flow_z2.row(flow_x0.h - 1 - i) + flow_x0.w - 1;
float* z3 = flow_z3.row(flow_x0.h - 1 - i);
float* w0 = flow_w0.row(i);
float* w1 = flow_w1.row(i) + flow_x0.w - 1;
float* w2 = flow_w2.row(flow_x0.h - 1 - i) + flow_x0.w - 1;
float* w3 = flow_w3.row(flow_x0.h - 1 - i);
float* m0 = flow_m0.row(i);
float* m1 = flow_m1.row(i) + flow_x0.w - 1;
float* m2 = flow_m2.row(flow_x0.h - 1 - i) + flow_x0.w - 1;
float* m3 = flow_m3.row(flow_x0.h - 1 - i);
for (int j = 0; j < flow_x0.w; j++)
{
float* x4 = flow_x4.row(j) + i;
float* x5 = flow_x5.row(j) + flow_x0.h - 1 - i;
float* x6 = flow_x6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i;
float* x7 = flow_x7.row(flow_x0.w - 1 - j) + i;
float* y4 = flow_y4.row(j) + i;
float* y5 = flow_y5.row(j) + flow_x0.h - 1 - i;
float* y6 = flow_y6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i;
float* y7 = flow_y7.row(flow_x0.w - 1 - j) + i;
float* z4 = flow_z4.row(j) + i;
float* z5 = flow_z5.row(j) + flow_x0.h - 1 - i;
float* z6 = flow_z6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i;
float* z7 = flow_z7.row(flow_x0.w - 1 - j) + i;
float* w4 = flow_w4.row(j) + i;
float* w5 = flow_w5.row(j) + flow_x0.h - 1 - i;
float* w6 = flow_w6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i;
float* w7 = flow_w7.row(flow_x0.w - 1 - j) + i;
float* m4 = flow_m4.row(j) + i;
float* m5 = flow_m5.row(j) + flow_x0.h - 1 - i;
float* m6 = flow_m6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i;
float* m7 = flow_m7.row(flow_x0.w - 1 - j) + i;
float x = (*x0 + -*x1 + -*x2 + *x3 + *y4 + *y5 + -*y6 + -*y7) * 0.125f;
float y = (*y0 + *y1 + -*y2 + -*y3 + *x4 + -*x5 + -*x6 + *x7) * 0.125f;
float z = (*z0 + -*z1 + -*z2 + *z3 + *w4 + *w5 + -*w6 + -*w7) * 0.125f;
float w = (*w0 + *w1 + -*w2 + -*w3 + *z4 + -*z5 + -*z6 + *z7) * 0.125f;
float m = (*m0 + *m1 + *m2 + *m3 + *m4 + *m5 + *m6 + *m7) * 0.125f;
*x0++ = x;
*x1-- = -x;
*x2-- = -x;
*x3++ = x;
*x4 = y;
*x5 = -y;
*x6 = -y;
*x7 = y;
*y0++ = y;
*y1-- = y;
*y2-- = -y;
*y3++ = -y;
*y4 = x;
*y5 = x;
*y6 = -x;
*y7 = -x;
*z0++ = z;
*z1-- = -z;
*z2-- = -z;
*z3++ = z;
*z4 = w;
*z5 = -w;
*z6 = -w;
*z7 = w;
*w0++ = w;
*w1-- = w;
*w2-- = -w;
*w3++ = -w;
*w4 = z;
*w5 = z;
*w6 = -z;
*w7 = -z;
*m0++ = m;
*m1-- = m;
*m2-- = m;
*m3++ = m;
*m4 = m;
*m5 = m;
*m6 = m;
*m7 = m;
}
}
}
}
for (int ti = 0; ti < 8; ti++)
{
// flownet
{
ncnn::Extractor ex = flownet.create_extractor();
ex.input("in0", in0_padded[ti]);
ex.input("in1", in1_padded[ti]);
ex.input("in2", timestep_padded[ti / 4]);
ex.input("flow0", flow[0][ti]);
ex.input("flow1", flow[1][ti]);
ex.input("flow2", flow[2][ti]);
ex.input("flow3", flow[3][ti]);
ex.extract("out0", out_padded[ti]);
}
{
ncnn::Extractor ex = flownet.create_extractor();
ex.input("in0", in1_padded[ti]);
ex.input("in1", in0_padded[ti]);
ex.input("in2", timestep_padded_reversed[ti / 4]);
ex.input("flow0", flow_reversed[0][ti]);
ex.input("flow1", flow_reversed[1][ti]);
ex.input("flow2", flow_reversed[2][ti]);
ex.input("flow3", flow_reversed[3][ti]);
ex.extract("out0", out_padded_reversed[ti]);
}
}
}
else
{
ncnn::Mat flow[4][8];
for (int fi = 0; fi < 4; fi++)
{
@@ -3593,10 +4033,8 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f
}
}
}
}
ncnn::Mat out_padded[8];
for (int ti = 0; ti < 8; ti++)
{
// flownet
@@ -3613,9 +4051,63 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f
ex.extract("out0", out_padded[ti]);
}
}
}
// cut padding and postproc
out.create(w, h, 3);
if (tta_temporal_mode)
{
for (int q = 0; q < 3; q++)
{
const ncnn::Mat out_padded_0 = out_padded[0].channel(q);
const ncnn::Mat out_padded_1 = out_padded[1].channel(q);
const ncnn::Mat out_padded_2 = out_padded[2].channel(q);
const ncnn::Mat out_padded_3 = out_padded[3].channel(q);
const ncnn::Mat out_padded_4 = out_padded[4].channel(q);
const ncnn::Mat out_padded_5 = out_padded[5].channel(q);
const ncnn::Mat out_padded_6 = out_padded[6].channel(q);
const ncnn::Mat out_padded_7 = out_padded[7].channel(q);
const ncnn::Mat out_padded_reversed_0 = out_padded_reversed[0].channel(q);
const ncnn::Mat out_padded_reversed_1 = out_padded_reversed[1].channel(q);
const ncnn::Mat out_padded_reversed_2 = out_padded_reversed[2].channel(q);
const ncnn::Mat out_padded_reversed_3 = out_padded_reversed[3].channel(q);
const ncnn::Mat out_padded_reversed_4 = out_padded_reversed[4].channel(q);
const ncnn::Mat out_padded_reversed_5 = out_padded_reversed[5].channel(q);
const ncnn::Mat out_padded_reversed_6 = out_padded_reversed[6].channel(q);
const ncnn::Mat out_padded_reversed_7 = out_padded_reversed[7].channel(q);
float* outptr = out.channel(q);
for (int i = 0; i < h; i++)
{
const float* ptr0 = out_padded_0.row(i);
const float* ptr1 = out_padded_1.row(i) + w_padded - 1;
const float* ptr2 = out_padded_2.row(h_padded - 1 - i) + w_padded - 1;
const float* ptr3 = out_padded_3.row(h_padded - 1 - i);
const float* ptrr0 = out_padded_reversed_0.row(i);
const float* ptrr1 = out_padded_reversed_1.row(i) + w_padded - 1;
const float* ptrr2 = out_padded_reversed_2.row(h_padded - 1 - i) + w_padded - 1;
const float* ptrr3 = out_padded_reversed_3.row(h_padded - 1 - i);
for (int j = 0; j < w; j++)
{
const float* ptr4 = out_padded_4.row(j) + i;
const float* ptr5 = out_padded_5.row(j) + h_padded - 1 - i;
const float* ptr6 = out_padded_6.row(w_padded - 1 - j) + h_padded - 1 - i;
const float* ptr7 = out_padded_7.row(w_padded - 1 - j) + i;
const float* ptrr4 = out_padded_reversed_4.row(j) + i;
const float* ptrr5 = out_padded_reversed_5.row(j) + h_padded - 1 - i;
const float* ptrr6 = out_padded_reversed_6.row(w_padded - 1 - j) + h_padded - 1 - i;
const float* ptrr7 = out_padded_reversed_7.row(w_padded - 1 - j) + i;
float v = (*ptr0++ + *ptr1-- + *ptr2-- + *ptr3++ + *ptr4 + *ptr5 + *ptr6 + *ptr7) / 8;
float vr = (*ptrr0++ + *ptrr1-- + *ptrr2-- + *ptrr3++ + *ptrr4 + *ptrr5 + *ptrr6 + *ptrr7) / 8;
*outptr++ = (v + vr) * 0.5f * 255.f + 0.5f;
}
}
}
}
else
{
for (int q = 0; q < 3; q++)
{
@@ -3645,7 +4137,7 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f
float v = (*ptr0++ + *ptr1-- + *ptr2-- + *ptr3++ + *ptr4 + *ptr5 + *ptr6 + *ptr7) / 8;
*outptr++ = (v + 0.5f) * 255.f + 0.5f;
*outptr++ = v * 255.f + 0.5f;
}
}
}
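
In the added temporal branch of the postprocessing above, each output pixel is the mean of the forward and time-reversed predictions, scaled straight back to 0..255 with a trailing + 0.5f for round-to-nearest (the interpolated values stay in the 0..1 range, so there is no extra 0.5 offset). A self-contained sketch of that last step; blend_to_u8 is an illustrative name and the clamp is an extra safeguard not shown in the diff:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Average the forward prediction v and the time-reversed prediction vr,
// both roughly in 0..1, and convert to 8-bit with round-to-nearest.
static uint8_t blend_to_u8(float v, float vr)
{
    const float out = (v + vr) * 0.5f * 255.f + 0.5f;
    return (uint8_t)std::min(std::max(out, 0.f), 255.f);
}

int main()
{
    printf("%u\n", (unsigned)blend_to_u8(0.50f, 0.52f)); // prints 130
    return 0;
}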
@@ -3722,9 +4214,135 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f
timestep_padded.fill(timestep);
}
// flownet
ncnn::Mat out_padded;
ncnn::Mat out_padded_reversed;
if (tta_temporal_mode)
{
ncnn::Mat timestep_padded_reversed;
{
timestep_padded_reversed.create(w_padded, h_padded, 1);
timestep_padded_reversed.fill(1.f - timestep);
}
ncnn::Mat flow[4];
ncnn::Mat flow_reversed[4];
for (int fi = 0; fi < 4; fi++)
{
{
// flownet flow mask
ncnn::Extractor ex = flownet.create_extractor();
ex.input("in0", in0_padded);
ex.input("in1", in1_padded);
ex.input("in2", timestep_padded);
// intentional fall through
switch (fi)
{
case 3: ex.input("flow2", flow[2]);
case 2: ex.input("flow1", flow[1]);
case 1: ex.input("flow0", flow[0]);
default:
{
char tmp[16];
sprintf(tmp, "flow%d", fi);
ex.extract(tmp, flow[fi]);
}
}
}
{
// flownet flow mask reversed
ncnn::Extractor ex = flownet.create_extractor();
ex.input("in0", in1_padded);
ex.input("in1", in0_padded);
ex.input("in2", timestep_padded_reversed);
// intentional fall through
switch (fi)
{
case 3: ex.input("flow2", flow_reversed[2]);
case 2: ex.input("flow1", flow_reversed[1]);
case 1: ex.input("flow0", flow_reversed[0]);
default:
{
char tmp[16];
sprintf(tmp, "flow%d", fi);
ex.extract(tmp, flow_reversed[fi]);
}
}
}
// merge flow and flow_reversed
{
float* flow_x = flow[fi].channel(0);
float* flow_y = flow[fi].channel(1);
float* flow_z = flow[fi].channel(2);
float* flow_w = flow[fi].channel(3);
float* flow_m = flow[fi].channel(4);
float* flow_reversed_x = flow_reversed[fi].channel(0);
float* flow_reversed_y = flow_reversed[fi].channel(1);
float* flow_reversed_z = flow_reversed[fi].channel(2);
float* flow_reversed_w = flow_reversed[fi].channel(3);
float* flow_reversed_m = flow_reversed[fi].channel(4);
for (int i = 0; i < flow[fi].h; i++)
{
for (int j = 0; j < flow[fi].w; j++)
{
float x = (*flow_x + *flow_reversed_z) * 0.5f;
float y = (*flow_y + *flow_reversed_w) * 0.5f;
float z = (*flow_z + *flow_reversed_x) * 0.5f;
float w = (*flow_w + *flow_reversed_y) * 0.5f;
float m = (*flow_m - *flow_reversed_m) * 0.5f;
*flow_x++ = x;
*flow_y++ = y;
*flow_z++ = z;
*flow_w++ = w;
*flow_m++ = m;
*flow_reversed_x++ = z;
*flow_reversed_y++ = w;
*flow_reversed_z++ = x;
*flow_reversed_w++ = y;
*flow_reversed_m++ = -m;
}
}
}
}
{
// flownet
ncnn::Extractor ex = flownet.create_extractor();
ex.input("in0", in0_padded);
ex.input("in1", in1_padded);
ex.input("in2", timestep_padded);
ex.input("flow0", flow[0]);
ex.input("flow1", flow[1]);
ex.input("flow2", flow[2]);
ex.input("flow3", flow[3]);
ex.extract("out0", out_padded);
}
{
ncnn::Extractor ex = flownet.create_extractor();
ex.input("in0", in1_padded);
ex.input("in1", in0_padded);
ex.input("in2", timestep_padded_reversed);
ex.input("flow0", flow_reversed[0]);
ex.input("flow1", flow_reversed[1]);
ex.input("flow2", flow_reversed[2]);
ex.input("flow3", flow_reversed[3]);
ex.extract("out0", out_padded_reversed);
}
}
else
{
// flownet
ncnn::Extractor ex = flownet.create_extractor();
ex.input("in0", in0_padded);
@@ -3735,6 +4353,24 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f
// cut padding and postproc
out.create(w, h, 3);
if (tta_temporal_mode)
{
for (int q = 0; q < 3; q++)
{
float* outptr = out.channel(q);
const float* ptr = out_padded.channel(q);
const float* ptr1 = out_padded_reversed.channel(q);
for (int i = 0; i < h; i++)
{
for (int j = 0; j < w; j++)
{
*outptr++ = (*ptr++ + *ptr1++) * 0.5f * 255.f + 0.5f;
}
}
}
}
else
{
for (int q = 0; q < 3; q++)
{