mirror of
https://github.com/nihui/rife-ncnn-vulkan
synced 2024-11-25 08:07:20 +01:00
implement spatial tta on flow mask feature
This commit is contained in:
parent
528da7336b
commit
5ca8c773f2
@ -43,8 +43,8 @@ Convolution conv_29 1 1 72 74 0=192 1=3 4=1 5=1 6=
|
||||
BinaryOp add_15 2 1 74 71 75
|
||||
ReLU leakyrelu_73 1 1 75 76 0=2.000000e-01
|
||||
Deconvolution deconv_60 1 1 76 77 0=24 1=4 3=2 4=1 5=1 6=73728
|
||||
PixelShuffle pixelshuffle_104 1 1 77 78 0=2
|
||||
Interp upsample_10 1 1 78 79 0=2 1=8.000000e+00 2=8.000000e+00
|
||||
PixelShuffle pixelshuffle_104 1 1 77 flow0 0=2
|
||||
Interp upsample_10 1 1 flow0 79 0=2 1=8.000000e+00 2=8.000000e+00
|
||||
Split splitncnn_11 1 2 79 80 81
|
||||
Crop slice_108 1 1 81 82 -23309=1,0 -23310=1,4 -23311=1,0
|
||||
BinaryOp mul_16 1 1 82 83 0=2 1=1 2=8.000000e+00
|
||||
@ -95,8 +95,8 @@ Convolution conv_39 1 1 153 155 0=128 1=3 4=1 5=1
|
||||
BinaryOp add_33 2 1 155 152 156
|
||||
ReLU leakyrelu_83 1 1 156 157 0=2.000000e-01
|
||||
Deconvolution deconv_61 1 1 157 158 0=24 1=4 3=2 4=1 5=1 6=49152
|
||||
PixelShuffle pixelshuffle_105 1 1 158 159 0=2
|
||||
Interp upsample_13 1 1 159 160 0=2 1=4.000000e+00 2=4.000000e+00
|
||||
PixelShuffle pixelshuffle_105 1 1 158 flow1 0=2
|
||||
Interp upsample_13 1 1 flow1 160 0=2 1=4.000000e+00 2=4.000000e+00
|
||||
Split splitncnn_22 1 2 160 161 162
|
||||
Crop slice_112 1 1 162 163 -23309=1,0 -23310=1,4 -23311=1,0
|
||||
Eltwise weighted_sum_0 2 1 84 163 164 0=1 -23301=2,1.000000e+00,4.000000e+00
|
||||
@ -148,8 +148,8 @@ Convolution conv_49 1 1 235 237 0=96 1=3 4=1 5=1 6
|
||||
BinaryOp add_53 2 1 237 234 238
|
||||
ReLU leakyrelu_93 1 1 238 239 0=2.000000e-01
|
||||
Deconvolution deconv_62 1 1 239 240 0=24 1=4 3=2 4=1 5=1 6=36864
|
||||
PixelShuffle pixelshuffle_106 1 1 240 241 0=2
|
||||
Interp upsample_16 1 1 241 242 0=2 1=2.000000e+00 2=2.000000e+00
|
||||
PixelShuffle pixelshuffle_106 1 1 240 flow2 0=2
|
||||
Interp upsample_16 1 1 flow2 242 0=2 1=2.000000e+00 2=2.000000e+00
|
||||
Split splitncnn_33 1 2 242 243 244
|
||||
Crop slice_116 1 1 244 245 -23309=1,0 -23310=1,4 -23311=1,0
|
||||
Eltwise weighted_sum_1 2 1 165 245 246 0=1 -23301=2,1.000000e+00,2.000000e+00
|
||||
@ -198,8 +198,8 @@ Convolution conv_59 1 1 316 318 0=64 1=3 4=1 5=1 6
|
||||
BinaryOp add_72 2 1 318 315 319
|
||||
ReLU leakyrelu_103 1 1 319 320 0=2.000000e-01
|
||||
Deconvolution deconv_63 1 1 320 321 0=24 1=4 3=2 4=1 5=1 6=24576
|
||||
PixelShuffle pixelshuffle_107 1 1 321 323 0=2
|
||||
Split splitncnn_44 1 2 323 324 325
|
||||
PixelShuffle pixelshuffle_107 1 1 321 flow3 0=2
|
||||
Split splitncnn_44 1 2 flow3 324 325
|
||||
Crop slice_120 1 1 325 326 -23309=1,0 -23310=1,4 -23311=1,0
|
||||
BinaryOp add_73 2 1 247 326 327
|
||||
Split splitncnn_45 1 2 327 328 329
|
||||
|
@ -30,8 +30,8 @@ PReLU prelu_72 1 1 37 38 0=192
|
||||
Convolution conv_29 1 1 38 39 0=192 1=3 4=1 5=1 6=331776
|
||||
PReLU prelu_73 1 1 39 40 0=192
|
||||
BinaryOp add_0 2 1 40 23 41
|
||||
Deconvolution deconv_60 1 1 41 42 0=5 1=4 3=2 4=1 5=1 6=15360
|
||||
Interp upsample_10 1 1 42 43 0=2 1=1.600000e+01 2=1.600000e+01
|
||||
Deconvolution deconv_60 1 1 41 flow0 0=5 1=4 3=2 4=1 5=1 6=15360
|
||||
Interp upsample_10 1 1 flow0 43 0=2 1=1.600000e+01 2=1.600000e+01
|
||||
Split splitncnn_4 1 2 43 44 45
|
||||
Crop slice_104 1 1 45 46 -23309=1,0 -23310=1,4 -23311=1,0
|
||||
BinaryOp mul_1 1 1 46 47 0=2 1=1 2=1.600000e+01
|
||||
@ -69,8 +69,8 @@ PReLU prelu_82 1 1 82 83 0=128
|
||||
Convolution conv_39 1 1 83 84 0=128 1=3 4=1 5=1 6=147456
|
||||
PReLU prelu_83 1 1 84 85 0=128
|
||||
BinaryOp add_3 2 1 85 68 86
|
||||
Deconvolution deconv_61 1 1 86 87 0=5 1=4 3=2 4=1 5=1 6=10240
|
||||
Interp upsample_13 1 1 87 88 0=2 1=8.000000e+00 2=8.000000e+00
|
||||
Deconvolution deconv_61 1 1 86 flow1 0=5 1=4 3=2 4=1 5=1 6=10240
|
||||
Interp upsample_13 1 1 flow1 88 0=2 1=8.000000e+00 2=8.000000e+00
|
||||
Split splitncnn_8 1 2 88 89 90
|
||||
Crop slice_108 1 1 90 91 -23309=1,0 -23310=1,4 -23311=1,0
|
||||
Eltwise add_5 2 1 48 91 93 0=1 -23301=2,1.000000e+00,8.000000e+00
|
||||
@ -109,8 +109,8 @@ PReLU prelu_92 1 1 129 130 0=96
|
||||
Convolution conv_49 1 1 130 131 0=96 1=3 4=1 5=1 6=82944
|
||||
PReLU prelu_93 1 1 131 132 0=96
|
||||
BinaryOp add_8 2 1 132 115 133
|
||||
Deconvolution deconv_62 1 1 133 134 0=5 1=4 3=2 4=1 5=1 6=7680
|
||||
Interp upsample_16 1 1 134 135 0=2 1=4.000000e+00 2=4.000000e+00
|
||||
Deconvolution deconv_62 1 1 133 flow2 0=5 1=4 3=2 4=1 5=1 6=7680
|
||||
Interp upsample_16 1 1 flow2 135 0=2 1=4.000000e+00 2=4.000000e+00
|
||||
Split splitncnn_12 1 2 135 136 137
|
||||
Crop slice_112 1 1 137 138 -23309=1,0 -23310=1,4 -23311=1,0
|
||||
Eltwise add_10 2 1 94 138 140 0=1 -23301=2,1.000000e+00,4.000000e+00
|
||||
@ -148,8 +148,8 @@ PReLU prelu_102 1 1 175 176 0=64
|
||||
Convolution conv_59 1 1 176 177 0=64 1=3 4=1 5=1 6=36864
|
||||
PReLU prelu_103 1 1 177 178 0=64
|
||||
BinaryOp add_12 2 1 178 161 179
|
||||
Deconvolution deconv_63 1 1 179 180 0=5 1=4 3=2 4=1 5=1 6=5120
|
||||
Interp upsample_19 1 1 180 181 0=2 1=2.000000e+00 2=2.000000e+00
|
||||
Deconvolution deconv_63 1 1 179 flow3 0=5 1=4 3=2 4=1 5=1 6=5120
|
||||
Interp upsample_19 1 1 flow3 181 0=2 1=2.000000e+00 2=2.000000e+00
|
||||
Split splitncnn_16 1 2 181 182 183
|
||||
Crop slice_116 1 1 183 184 -23309=1,0 -23310=1,4 -23311=1,0
|
||||
Eltwise add_14 2 1 141 184 186 0=1 -23301=2,1.000000e+00,2.000000e+00
|
||||
|
@ -245,6 +245,7 @@ rife_add_shader(rife_preproc_tta.comp)
|
||||
rife_add_shader(rife_postproc_tta.comp)
|
||||
rife_add_shader(rife_flow_tta_avg.comp)
|
||||
rife_add_shader(rife_v2_flow_tta_avg.comp)
|
||||
rife_add_shader(rife_v4_flow_tta_avg.comp)
|
||||
rife_add_shader(rife_flow_tta_temporal_avg.comp)
|
||||
rife_add_shader(rife_v2_flow_tta_temporal_avg.comp)
|
||||
rife_add_shader(rife_out_tta_temporal_avg.comp)
|
||||
|
255
src/rife.cpp
255
src/rife.cpp
@ -12,6 +12,7 @@
|
||||
#include "rife_postproc_tta.comp.hex.h"
|
||||
#include "rife_flow_tta_avg.comp.hex.h"
|
||||
#include "rife_v2_flow_tta_avg.comp.hex.h"
|
||||
#include "rife_v4_flow_tta_avg.comp.hex.h"
|
||||
#include "rife_flow_tta_temporal_avg.comp.hex.h"
|
||||
#include "rife_v2_flow_tta_temporal_avg.comp.hex.h"
|
||||
#include "rife_out_tta_temporal_avg.comp.hex.h"
|
||||
@ -218,7 +219,11 @@ int RIFE::load(const std::string& modeldir)
|
||||
ncnn::MutexLockGuard guard(lock);
|
||||
if (spirv.empty())
|
||||
{
|
||||
if (rife_v2)
|
||||
if (rife_v4)
|
||||
{
|
||||
compile_spirv_module(rife_v4_flow_tta_avg_comp_data, sizeof(rife_v4_flow_tta_avg_comp_data), opt, spirv);
|
||||
}
|
||||
else if (rife_v2)
|
||||
{
|
||||
compile_spirv_module(rife_v2_flow_tta_avg_comp_data, sizeof(rife_v2_flow_tta_avg_comp_data), opt, spirv);
|
||||
}
|
||||
@ -2606,6 +2611,61 @@ int RIFE::process_v4(const ncnn::Mat& in0image, const ncnn::Mat& in1image, float
|
||||
cmd.record_pipeline(rife_v4_timestep, bindings, constants, timestep_gpu_padded[0]);
|
||||
}
|
||||
|
||||
ncnn::VkMat flow[4][8];
|
||||
for (int fi = 0; fi < 4; fi++)
|
||||
{
|
||||
for (int ti = 0; ti < 8; ti++)
|
||||
{
|
||||
// flownet flow mask
|
||||
ncnn::Extractor ex = flownet.create_extractor();
|
||||
ex.set_blob_vkallocator(blob_vkallocator);
|
||||
ex.set_workspace_vkallocator(blob_vkallocator);
|
||||
ex.set_staging_vkallocator(staging_vkallocator);
|
||||
|
||||
ex.input("in0", in0_gpu_padded[ti]);
|
||||
ex.input("in1", in1_gpu_padded[ti]);
|
||||
ex.input("in2", timestep_gpu_padded[ti / 4]);
|
||||
|
||||
// intentional fall through
|
||||
switch (fi)
|
||||
{
|
||||
case 3: ex.input("flow2", flow[2][ti]);
|
||||
case 2: ex.input("flow1", flow[1][ti]);
|
||||
case 1: ex.input("flow0", flow[0][ti]);
|
||||
default:
|
||||
{
|
||||
char tmp[16];
|
||||
sprintf(tmp, "flow%d", fi);
|
||||
ex.extract(tmp, flow[fi][ti], cmd);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// avg flow mask
|
||||
{
|
||||
std::vector<ncnn::VkMat> bindings(8);
|
||||
bindings[0] = flow[fi][0];
|
||||
bindings[1] = flow[fi][1];
|
||||
bindings[2] = flow[fi][2];
|
||||
bindings[3] = flow[fi][3];
|
||||
bindings[4] = flow[fi][4];
|
||||
bindings[5] = flow[fi][5];
|
||||
bindings[6] = flow[fi][6];
|
||||
bindings[7] = flow[fi][7];
|
||||
|
||||
std::vector<ncnn::vk_constant_type> constants(3);
|
||||
constants[0].i = flow[fi][0].w;
|
||||
constants[1].i = flow[fi][0].h;
|
||||
constants[2].i = flow[fi][0].cstep;
|
||||
|
||||
ncnn::VkMat dispatcher;
|
||||
dispatcher.w = flow[fi][0].w;
|
||||
dispatcher.h = flow[fi][0].h;
|
||||
dispatcher.c = 1;
|
||||
cmd.record_pipeline(rife_flow_tta_avg, bindings, constants, dispatcher);
|
||||
}
|
||||
}
|
||||
|
||||
ncnn::VkMat out_gpu_padded[8];
|
||||
for (int ti = 0; ti < 8; ti++)
|
||||
{
|
||||
@ -2618,6 +2678,11 @@ int RIFE::process_v4(const ncnn::Mat& in0image, const ncnn::Mat& in1image, float
|
||||
ex.input("in0", in0_gpu_padded[ti]);
|
||||
ex.input("in1", in1_gpu_padded[ti]);
|
||||
ex.input("in2", timestep_gpu_padded[ti / 4]);
|
||||
ex.input("flow0", flow[0][ti]);
|
||||
ex.input("flow1", flow[1][ti]);
|
||||
ex.input("flow2", flow[2][ti]);
|
||||
ex.input("flow3", flow[3][ti]);
|
||||
|
||||
ex.extract("out0", out_gpu_padded[ti], cmd);
|
||||
}
|
||||
|
||||
@ -2990,6 +3055,190 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f
|
||||
}
|
||||
}
|
||||
|
||||
ncnn::Mat flow[4][8];
|
||||
for (int fi = 0; fi < 4; fi++)
|
||||
{
|
||||
for (int ti = 0; ti < 8; ti++)
|
||||
{
|
||||
// flownet flow mask
|
||||
ncnn::Extractor ex = flownet.create_extractor();
|
||||
|
||||
ex.input("in0", in0_padded[ti]);
|
||||
ex.input("in1", in1_padded[ti]);
|
||||
ex.input("in2", timestep_padded[ti / 4]);
|
||||
|
||||
// intentional fall through
|
||||
switch (fi)
|
||||
{
|
||||
case 3: ex.input("flow2", flow[2][ti]);
|
||||
case 2: ex.input("flow1", flow[1][ti]);
|
||||
case 1: ex.input("flow0", flow[0][ti]);
|
||||
default:
|
||||
{
|
||||
char tmp[16];
|
||||
sprintf(tmp, "flow%d", fi);
|
||||
ex.extract(tmp, flow[fi][ti]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// avg flow mask
|
||||
{
|
||||
ncnn::Mat flow_x0 = flow[fi][0].channel(0);
|
||||
ncnn::Mat flow_x1 = flow[fi][1].channel(0);
|
||||
ncnn::Mat flow_x2 = flow[fi][2].channel(0);
|
||||
ncnn::Mat flow_x3 = flow[fi][3].channel(0);
|
||||
ncnn::Mat flow_x4 = flow[fi][4].channel(0);
|
||||
ncnn::Mat flow_x5 = flow[fi][5].channel(0);
|
||||
ncnn::Mat flow_x6 = flow[fi][6].channel(0);
|
||||
ncnn::Mat flow_x7 = flow[fi][7].channel(0);
|
||||
|
||||
ncnn::Mat flow_y0 = flow[fi][0].channel(1);
|
||||
ncnn::Mat flow_y1 = flow[fi][1].channel(1);
|
||||
ncnn::Mat flow_y2 = flow[fi][2].channel(1);
|
||||
ncnn::Mat flow_y3 = flow[fi][3].channel(1);
|
||||
ncnn::Mat flow_y4 = flow[fi][4].channel(1);
|
||||
ncnn::Mat flow_y5 = flow[fi][5].channel(1);
|
||||
ncnn::Mat flow_y6 = flow[fi][6].channel(1);
|
||||
ncnn::Mat flow_y7 = flow[fi][7].channel(1);
|
||||
|
||||
ncnn::Mat flow_z0 = flow[fi][0].channel(2);
|
||||
ncnn::Mat flow_z1 = flow[fi][1].channel(2);
|
||||
ncnn::Mat flow_z2 = flow[fi][2].channel(2);
|
||||
ncnn::Mat flow_z3 = flow[fi][3].channel(2);
|
||||
ncnn::Mat flow_z4 = flow[fi][4].channel(2);
|
||||
ncnn::Mat flow_z5 = flow[fi][5].channel(2);
|
||||
ncnn::Mat flow_z6 = flow[fi][6].channel(2);
|
||||
ncnn::Mat flow_z7 = flow[fi][7].channel(2);
|
||||
|
||||
ncnn::Mat flow_w0 = flow[fi][0].channel(3);
|
||||
ncnn::Mat flow_w1 = flow[fi][1].channel(3);
|
||||
ncnn::Mat flow_w2 = flow[fi][2].channel(3);
|
||||
ncnn::Mat flow_w3 = flow[fi][3].channel(3);
|
||||
ncnn::Mat flow_w4 = flow[fi][4].channel(3);
|
||||
ncnn::Mat flow_w5 = flow[fi][5].channel(3);
|
||||
ncnn::Mat flow_w6 = flow[fi][6].channel(3);
|
||||
ncnn::Mat flow_w7 = flow[fi][7].channel(3);
|
||||
|
||||
ncnn::Mat flow_m0 = flow[fi][0].channel(4);
|
||||
ncnn::Mat flow_m1 = flow[fi][1].channel(4);
|
||||
ncnn::Mat flow_m2 = flow[fi][2].channel(4);
|
||||
ncnn::Mat flow_m3 = flow[fi][3].channel(4);
|
||||
ncnn::Mat flow_m4 = flow[fi][4].channel(4);
|
||||
ncnn::Mat flow_m5 = flow[fi][5].channel(4);
|
||||
ncnn::Mat flow_m6 = flow[fi][6].channel(4);
|
||||
ncnn::Mat flow_m7 = flow[fi][7].channel(4);
|
||||
|
||||
for (int i = 0; i < flow_x0.h; i++)
|
||||
{
|
||||
float* x0 = flow_x0.row(i);
|
||||
float* x1 = flow_x1.row(i) + flow_x0.w - 1;
|
||||
float* x2 = flow_x2.row(flow_x0.h - 1 - i) + flow_x0.w - 1;
|
||||
float* x3 = flow_x3.row(flow_x0.h - 1 - i);
|
||||
|
||||
float* y0 = flow_y0.row(i);
|
||||
float* y1 = flow_y1.row(i) + flow_x0.w - 1;
|
||||
float* y2 = flow_y2.row(flow_x0.h - 1 - i) + flow_x0.w - 1;
|
||||
float* y3 = flow_y3.row(flow_x0.h - 1 - i);
|
||||
|
||||
float* z0 = flow_z0.row(i);
|
||||
float* z1 = flow_z1.row(i) + flow_x0.w - 1;
|
||||
float* z2 = flow_z2.row(flow_x0.h - 1 - i) + flow_x0.w - 1;
|
||||
float* z3 = flow_z3.row(flow_x0.h - 1 - i);
|
||||
|
||||
float* w0 = flow_w0.row(i);
|
||||
float* w1 = flow_w1.row(i) + flow_x0.w - 1;
|
||||
float* w2 = flow_w2.row(flow_x0.h - 1 - i) + flow_x0.w - 1;
|
||||
float* w3 = flow_w3.row(flow_x0.h - 1 - i);
|
||||
|
||||
float* m0 = flow_m0.row(i);
|
||||
float* m1 = flow_m1.row(i) + flow_x0.w - 1;
|
||||
float* m2 = flow_m2.row(flow_x0.h - 1 - i) + flow_x0.w - 1;
|
||||
float* m3 = flow_m3.row(flow_x0.h - 1 - i);
|
||||
|
||||
for (int j = 0; j < flow_x0.w; j++)
|
||||
{
|
||||
float* x4 = flow_x4.row(j) + i;
|
||||
float* x5 = flow_x5.row(j) + flow_x0.h - 1 - i;
|
||||
float* x6 = flow_x6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i;
|
||||
float* x7 = flow_x7.row(flow_x0.w - 1 - j) + i;
|
||||
|
||||
float* y4 = flow_y4.row(j) + i;
|
||||
float* y5 = flow_y5.row(j) + flow_x0.h - 1 - i;
|
||||
float* y6 = flow_y6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i;
|
||||
float* y7 = flow_y7.row(flow_x0.w - 1 - j) + i;
|
||||
|
||||
float* z4 = flow_z4.row(j) + i;
|
||||
float* z5 = flow_z5.row(j) + flow_x0.h - 1 - i;
|
||||
float* z6 = flow_z6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i;
|
||||
float* z7 = flow_z7.row(flow_x0.w - 1 - j) + i;
|
||||
|
||||
float* w4 = flow_w4.row(j) + i;
|
||||
float* w5 = flow_w5.row(j) + flow_x0.h - 1 - i;
|
||||
float* w6 = flow_w6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i;
|
||||
float* w7 = flow_w7.row(flow_x0.w - 1 - j) + i;
|
||||
|
||||
float* m4 = flow_m4.row(j) + i;
|
||||
float* m5 = flow_m5.row(j) + flow_x0.h - 1 - i;
|
||||
float* m6 = flow_m6.row(flow_x0.w - 1 - j) + flow_x0.h - 1 - i;
|
||||
float* m7 = flow_m7.row(flow_x0.w - 1 - j) + i;
|
||||
|
||||
float x = (*x0 + -*x1 + -*x2 + *x3 + *y4 + *y5 + -*y6 + -*y7) * 0.125f;
|
||||
float y = (*y0 + *y1 + -*y2 + -*y3 + *x4 + -*x5 + -*x6 + *x7) * 0.125f;
|
||||
float z = (*z0 + -*z1 + -*z2 + *z3 + *w4 + *w5 + -*w6 + -*w7) * 0.125f;
|
||||
float w = (*w0 + *w1 + -*w2 + -*w3 + *z4 + -*z5 + -*z6 + *z7) * 0.125f;
|
||||
float m = (*m0 + *m1 + *m2 + *m3 + *m4 + *m5 + *m6 + *m7) * 0.125f;
|
||||
|
||||
*x0++ = x;
|
||||
*x1-- = -x;
|
||||
*x2-- = -x;
|
||||
*x3++ = x;
|
||||
*x4 = y;
|
||||
*x5 = -y;
|
||||
*x6 = -y;
|
||||
*x7 = y;
|
||||
|
||||
*y0++ = y;
|
||||
*y1-- = y;
|
||||
*y2-- = -y;
|
||||
*y3++ = -y;
|
||||
*y4 = x;
|
||||
*y5 = x;
|
||||
*y6 = -x;
|
||||
*y7 = -x;
|
||||
|
||||
*z0++ = z;
|
||||
*z1-- = -z;
|
||||
*z2-- = -z;
|
||||
*z3++ = z;
|
||||
*z4 = w;
|
||||
*z5 = -w;
|
||||
*z6 = -w;
|
||||
*z7 = w;
|
||||
|
||||
*w0++ = w;
|
||||
*w1-- = w;
|
||||
*w2-- = -w;
|
||||
*w3++ = -w;
|
||||
*w4 = z;
|
||||
*w5 = z;
|
||||
*w6 = -z;
|
||||
*w7 = -z;
|
||||
|
||||
*m0++ = m;
|
||||
*m1-- = m;
|
||||
*m2-- = m;
|
||||
*m3++ = m;
|
||||
*m4 = m;
|
||||
*m5 = m;
|
||||
*m6 = m;
|
||||
*m7 = m;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
ncnn::Mat out_padded[8];
|
||||
for (int ti = 0; ti < 8; ti++)
|
||||
{
|
||||
@ -3000,6 +3249,10 @@ int RIFE::process_v4_cpu(const ncnn::Mat& in0image, const ncnn::Mat& in1image, f
|
||||
ex.input("in0", in0_padded[ti]);
|
||||
ex.input("in1", in1_padded[ti]);
|
||||
ex.input("in2", timestep_padded[ti / 4]);
|
||||
ex.input("flow0", flow[0][ti]);
|
||||
ex.input("flow1", flow[1][ti]);
|
||||
ex.input("flow2", flow[2][ti]);
|
||||
ex.input("flow3", flow[3][ti]);
|
||||
ex.extract("out0", out_padded[ti]);
|
||||
}
|
||||
}
|
||||
|
129
src/rife_v4_flow_tta_avg.comp
Normal file
129
src/rife_v4_flow_tta_avg.comp
Normal file
@ -0,0 +1,129 @@
|
||||
// rife implemented with ncnn library
|
||||
|
||||
#version 450
|
||||
|
||||
#if NCNN_fp16_storage
|
||||
#extension GL_EXT_shader_16bit_storage: require
|
||||
#endif
|
||||
|
||||
layout (binding = 0) buffer flow_blob0 { sfp flow_blob0_data[]; };
|
||||
layout (binding = 1) buffer flow_blob1 { sfp flow_blob1_data[]; };
|
||||
layout (binding = 2) buffer flow_blob2 { sfp flow_blob2_data[]; };
|
||||
layout (binding = 3) buffer flow_blob3 { sfp flow_blob3_data[]; };
|
||||
layout (binding = 4) buffer flow_blob4 { sfp flow_blob4_data[]; };
|
||||
layout (binding = 5) buffer flow_blob5 { sfp flow_blob5_data[]; };
|
||||
layout (binding = 6) buffer flow_blob6 { sfp flow_blob6_data[]; };
|
||||
layout (binding = 7) buffer flow_blob7 { sfp flow_blob7_data[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int w;
|
||||
int h;
|
||||
int cstep;
|
||||
} p;
|
||||
|
||||
void main()
|
||||
{
|
||||
int gx = int(gl_GlobalInvocationID.x);
|
||||
int gy = int(gl_GlobalInvocationID.y);
|
||||
int gz = int(gl_GlobalInvocationID.z);
|
||||
|
||||
if (gx >= p.w || gy >= p.h || gz >= 1)
|
||||
return;
|
||||
|
||||
float x0 = float(flow_blob0_data[gy * p.w + gx]);
|
||||
float x1 = float(flow_blob1_data[gy * p.w + (p.w - 1 - gx)]);
|
||||
float x2 = float(flow_blob2_data[(p.h - 1 - gy) * p.w + (p.w - 1 - gx)]);
|
||||
float x3 = float(flow_blob3_data[(p.h - 1 - gy) * p.w + gx]);
|
||||
float x4 = float(flow_blob4_data[gx * p.h + gy]);
|
||||
float x5 = float(flow_blob5_data[gx * p.h + (p.h - 1 - gy)]);
|
||||
float x6 = float(flow_blob6_data[(p.w - 1 - gx) * p.h + (p.h - 1 - gy)]);
|
||||
float x7 = float(flow_blob7_data[(p.w - 1 - gx) * p.h + gy]);
|
||||
|
||||
float y0 = float(flow_blob0_data[p.cstep + gy * p.w + gx]);
|
||||
float y1 = float(flow_blob1_data[p.cstep + gy * p.w + (p.w - 1 - gx)]);
|
||||
float y2 = float(flow_blob2_data[p.cstep + (p.h - 1 - gy) * p.w + (p.w - 1 - gx)]);
|
||||
float y3 = float(flow_blob3_data[p.cstep + (p.h - 1 - gy) * p.w + gx]);
|
||||
float y4 = float(flow_blob4_data[p.cstep + gx * p.h + gy]);
|
||||
float y5 = float(flow_blob5_data[p.cstep + gx * p.h + (p.h - 1 - gy)]);
|
||||
float y6 = float(flow_blob6_data[p.cstep + (p.w - 1 - gx) * p.h + (p.h - 1 - gy)]);
|
||||
float y7 = float(flow_blob7_data[p.cstep + (p.w - 1 - gx) * p.h + gy]);
|
||||
|
||||
float z0 = float(flow_blob0_data[p.cstep * 2 + gy * p.w + gx]);
|
||||
float z1 = float(flow_blob1_data[p.cstep * 2 + gy * p.w + (p.w - 1 - gx)]);
|
||||
float z2 = float(flow_blob2_data[p.cstep * 2 + (p.h - 1 - gy) * p.w + (p.w - 1 - gx)]);
|
||||
float z3 = float(flow_blob3_data[p.cstep * 2 + (p.h - 1 - gy) * p.w + gx]);
|
||||
float z4 = float(flow_blob4_data[p.cstep * 2 + gx * p.h + gy]);
|
||||
float z5 = float(flow_blob5_data[p.cstep * 2 + gx * p.h + (p.h - 1 - gy)]);
|
||||
float z6 = float(flow_blob6_data[p.cstep * 2 + (p.w - 1 - gx) * p.h + (p.h - 1 - gy)]);
|
||||
float z7 = float(flow_blob7_data[p.cstep * 2 + (p.w - 1 - gx) * p.h + gy]);
|
||||
|
||||
float w0 = float(flow_blob0_data[p.cstep * 3 + gy * p.w + gx]);
|
||||
float w1 = float(flow_blob1_data[p.cstep * 3 + gy * p.w + (p.w - 1 - gx)]);
|
||||
float w2 = float(flow_blob2_data[p.cstep * 3 + (p.h - 1 - gy) * p.w + (p.w - 1 - gx)]);
|
||||
float w3 = float(flow_blob3_data[p.cstep * 3 + (p.h - 1 - gy) * p.w + gx]);
|
||||
float w4 = float(flow_blob4_data[p.cstep * 3 + gx * p.h + gy]);
|
||||
float w5 = float(flow_blob5_data[p.cstep * 3 + gx * p.h + (p.h - 1 - gy)]);
|
||||
float w6 = float(flow_blob6_data[p.cstep * 3 + (p.w - 1 - gx) * p.h + (p.h - 1 - gy)]);
|
||||
float w7 = float(flow_blob7_data[p.cstep * 3 + (p.w - 1 - gx) * p.h + gy]);
|
||||
|
||||
float m0 = float(flow_blob0_data[p.cstep * 4 + gy * p.w + gx]);
|
||||
float m1 = float(flow_blob1_data[p.cstep * 4 + gy * p.w + (p.w - 1 - gx)]);
|
||||
float m2 = float(flow_blob2_data[p.cstep * 4 + (p.h - 1 - gy) * p.w + (p.w - 1 - gx)]);
|
||||
float m3 = float(flow_blob3_data[p.cstep * 4 + (p.h - 1 - gy) * p.w + gx]);
|
||||
float m4 = float(flow_blob4_data[p.cstep * 4 + gx * p.h + gy]);
|
||||
float m5 = float(flow_blob5_data[p.cstep * 4 + gx * p.h + (p.h - 1 - gy)]);
|
||||
float m6 = float(flow_blob6_data[p.cstep * 4 + (p.w - 1 - gx) * p.h + (p.h - 1 - gy)]);
|
||||
float m7 = float(flow_blob7_data[p.cstep * 4 + (p.w - 1 - gx) * p.h + gy]);
|
||||
|
||||
float x = (x0 + -x1 + -x2 + x3 + y4 + y5 + -y6 + -y7) * 0.125f;
|
||||
float y = (y0 + y1 + -y2 + -y3 + x4 + -x5 + -x6 + x7) * 0.125f;
|
||||
float z = (z0 + -z1 + -z2 + z3 + w4 + w5 + -w6 + -w7) * 0.125f;
|
||||
float w = (w0 + w1 + -w2 + -w3 + z4 + -z5 + -z6 + z7) * 0.125f;
|
||||
float m = (m0 + m1 + m2 + m3 + m4 + m5 + m6 + m7) * 0.125f;
|
||||
|
||||
flow_blob0_data[gy * p.w + gx] = sfp(x);
|
||||
flow_blob1_data[gy * p.w + (p.w - 1 - gx)] = sfp(-x);
|
||||
flow_blob2_data[(p.h - 1 - gy) * p.w + (p.w - 1 - gx)] = sfp(-x);
|
||||
flow_blob3_data[(p.h - 1 - gy) * p.w + gx] = sfp(x);
|
||||
flow_blob4_data[gx * p.h + gy] = sfp(y);
|
||||
flow_blob5_data[gx * p.h + (p.h - 1 - gy)] = sfp(-y);
|
||||
flow_blob6_data[(p.w - 1 - gx) * p.h + (p.h - 1 - gy)] = sfp(-y);
|
||||
flow_blob7_data[(p.w - 1 - gx) * p.h + gy] = sfp(y);
|
||||
|
||||
flow_blob0_data[p.cstep + gy * p.w + gx] = sfp(y);
|
||||
flow_blob1_data[p.cstep + gy * p.w + (p.w - 1 - gx)] = sfp(y);
|
||||
flow_blob2_data[p.cstep + (p.h - 1 - gy) * p.w + (p.w - 1 - gx)] = sfp(-y);
|
||||
flow_blob3_data[p.cstep + (p.h - 1 - gy) * p.w + gx] = sfp(-y);
|
||||
flow_blob4_data[p.cstep + gx * p.h + gy] = sfp(x);
|
||||
flow_blob5_data[p.cstep + gx * p.h + (p.h - 1 - gy)] = sfp(x);
|
||||
flow_blob6_data[p.cstep + (p.w - 1 - gx) * p.h + (p.h - 1 - gy)] = sfp(-x);
|
||||
flow_blob7_data[p.cstep + (p.w - 1 - gx) * p.h + gy] = sfp(-x);
|
||||
|
||||
flow_blob0_data[p.cstep * 2 + gy * p.w + gx] = sfp(z);
|
||||
flow_blob1_data[p.cstep * 2 + gy * p.w + (p.w - 1 - gx)] = sfp(-z);
|
||||
flow_blob2_data[p.cstep * 2 + (p.h - 1 - gy) * p.w + (p.w - 1 - gx)] = sfp(-z);
|
||||
flow_blob3_data[p.cstep * 2 + (p.h - 1 - gy) * p.w + gx] = sfp(z);
|
||||
flow_blob4_data[p.cstep * 2 + gx * p.h + gy] = sfp(w);
|
||||
flow_blob5_data[p.cstep * 2 + gx * p.h + (p.h - 1 - gy)] = sfp(-w);
|
||||
flow_blob6_data[p.cstep * 2 + (p.w - 1 - gx) * p.h + (p.h - 1 - gy)] = sfp(-w);
|
||||
flow_blob7_data[p.cstep * 2 + (p.w - 1 - gx) * p.h + gy] = sfp(w);
|
||||
|
||||
flow_blob0_data[p.cstep * 3 + gy * p.w + gx] = sfp(w);
|
||||
flow_blob1_data[p.cstep * 3 + gy * p.w + (p.w - 1 - gx)] = sfp(w);
|
||||
flow_blob2_data[p.cstep * 3 + (p.h - 1 - gy) * p.w + (p.w - 1 - gx)] = sfp(-w);
|
||||
flow_blob3_data[p.cstep * 3 + (p.h - 1 - gy) * p.w + gx] = sfp(-w);
|
||||
flow_blob4_data[p.cstep * 3 + gx * p.h + gy] = sfp(z);
|
||||
flow_blob5_data[p.cstep * 3 + gx * p.h + (p.h - 1 - gy)] = sfp(z);
|
||||
flow_blob6_data[p.cstep * 3 + (p.w - 1 - gx) * p.h + (p.h - 1 - gy)] = sfp(-z);
|
||||
flow_blob7_data[p.cstep * 3 + (p.w - 1 - gx) * p.h + gy] = sfp(-z);
|
||||
|
||||
flow_blob0_data[p.cstep * 4 + gy * p.w + gx] = sfp(m);
|
||||
flow_blob1_data[p.cstep * 4 + gy * p.w + (p.w - 1 - gx)] = sfp(m);
|
||||
flow_blob2_data[p.cstep * 4 + (p.h - 1 - gy) * p.w + (p.w - 1 - gx)] = sfp(m);
|
||||
flow_blob3_data[p.cstep * 4 + (p.h - 1 - gy) * p.w + gx] = sfp(m);
|
||||
flow_blob4_data[p.cstep * 4 + gx * p.h + gy] = sfp(m);
|
||||
flow_blob5_data[p.cstep * 4 + gx * p.h + (p.h - 1 - gy)] = sfp(m);
|
||||
flow_blob6_data[p.cstep * 4 + (p.w - 1 - gx) * p.h + (p.h - 1 - gy)] = sfp(m);
|
||||
flow_blob7_data[p.cstep * 4 + (p.w - 1 - gx) * p.h + gy] = sfp(m);
|
||||
}
|
Loading…
Reference in New Issue
Block a user