// rife-ncnn-vulkan/src/rife_flow_tta_avg.comp

// rife implemented with ncnn library
#version 450
// The 16-bit storage extension is only required when the NCNN_fp16_storage
// macro evaluates to true.
#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
// Eight storage buffers on bindings 0-7. None is marked readonly/writeonly
// because main() both reads and overwrites elements of every one of them.
// NOTE(review): `sfp` is not defined in this file; presumably the ncnn shader
// preprocessor supplies it as the storage float type - confirm.
layout (binding = 0) buffer flow_blob0 { sfp flow_blob0_data[]; };
layout (binding = 1) buffer flow_blob1 { sfp flow_blob1_data[]; };
layout (binding = 2) buffer flow_blob2 { sfp flow_blob2_data[]; };
layout (binding = 3) buffer flow_blob3 { sfp flow_blob3_data[]; };
layout (binding = 4) buffer flow_blob4 { sfp flow_blob4_data[]; };
layout (binding = 5) buffer flow_blob5 { sfp flow_blob5_data[]; };
layout (binding = 6) buffer flow_blob6 { sfp flow_blob6_data[]; };
layout (binding = 7) buffer flow_blob7 { sfp flow_blob7_data[]; };
// Push constants read by main():
//   w     - exclusive bound for gl_GlobalInvocationID.x; row stride used when
//           indexing blobs 0-3
//   h     - exclusive bound for gl_GlobalInvocationID.y; row stride used when
//           indexing blobs 4-7
//   cstep - element offset added to reach the second plane of every blob
layout (push_constant) uniform parameter
{
int w;
int h;
int cstep;
} p;
// Averages eight two-plane blobs in place, one (col, row) position per
// invocation.
//
// For every blob the invocation reads one element from plane 0 (offset 0) and
// one from plane 1 (offset p.cstep). Blobs 0-3 are indexed with row stride p.w
// at the original or edge-mirrored position; blobs 4-7 are indexed with row
// stride p.h and the roles of col and row swapped. The sixteen values are
// combined with fixed signs into two averages (x, y), which are written back
// to the same sixteen elements with the matching sign and plane.
// NOTE(review): the sign / plane-swap table presumably undoes the flip and
// transpose applied to each augmented input (the file name says "tta") -
// confirm against the host code that fills these blobs.
void main()
{
    int col = int(gl_GlobalInvocationID.x);
    int row = int(gl_GlobalInvocationID.y);
    int layer = int(gl_GlobalInvocationID.z);

    // Only invocations inside the w x h grid on the z == 0 slice do any work.
    if (col >= p.w || row >= p.h || layer >= 1)
        return;

    // Same position counted from the opposite edge.
    int col_m = p.w - 1 - col;
    int row_m = p.h - 1 - row;

    // Plane-0 element index into each blob.
    int i0 = row * p.w + col;
    int i1 = row * p.w + col_m;
    int i2 = row_m * p.w + col_m;
    int i3 = row_m * p.w + col;
    int i4 = col * p.h + row;
    int i5 = col * p.h + row_m;
    int i6 = col_m * p.h + row_m;
    int i7 = col_m * p.h + row;

    // Plane-1 element index: same position, shifted by cstep.
    int j0 = p.cstep + i0;
    int j1 = p.cstep + i1;
    int j2 = p.cstep + i2;
    int j3 = p.cstep + i3;
    int j4 = p.cstep + i4;
    int j5 = p.cstep + i5;
    int j6 = p.cstep + i6;
    int j7 = p.cstep + i7;

    // Load everything before any store so the averages use only input values.
    float x0 = float(flow_blob0_data[i0]);
    float x1 = float(flow_blob1_data[i1]);
    float x2 = float(flow_blob2_data[i2]);
    float x3 = float(flow_blob3_data[i3]);
    float x4 = float(flow_blob4_data[i4]);
    float x5 = float(flow_blob5_data[i5]);
    float x6 = float(flow_blob6_data[i6]);
    float x7 = float(flow_blob7_data[i7]);

    float y0 = float(flow_blob0_data[j0]);
    float y1 = float(flow_blob1_data[j1]);
    float y2 = float(flow_blob2_data[j2]);
    float y3 = float(flow_blob3_data[j3]);
    float y4 = float(flow_blob4_data[j4]);
    float y5 = float(flow_blob5_data[j5]);
    float y6 = float(flow_blob6_data[j6]);
    float y7 = float(flow_blob7_data[j7]);

    // Eight signed contributions each, scaled by 1/8. Blobs 4-7 feed their
    // plane-1 value into x and their plane-0 value into y.
    float x = (x0 - x1 - x2 + x3 + y4 + y5 - y6 - y7) * 0.125f;
    float y = (y0 + y1 - y2 - y3 + x4 - x5 - x6 + x7) * 0.125f;

    float neg_x = -x;
    float neg_y = -y;

    // Write-back mirrors the sign table used above. Plane 0 of every blob is
    // stored first, then plane 1, in blob order.
    flow_blob0_data[i0] = sfp(x);
    flow_blob1_data[i1] = sfp(neg_x);
    flow_blob2_data[i2] = sfp(neg_x);
    flow_blob3_data[i3] = sfp(x);
    flow_blob4_data[i4] = sfp(y);
    flow_blob5_data[i5] = sfp(neg_y);
    flow_blob6_data[i6] = sfp(neg_y);
    flow_blob7_data[i7] = sfp(y);

    flow_blob0_data[j0] = sfp(y);
    flow_blob1_data[j1] = sfp(y);
    flow_blob2_data[j2] = sfp(neg_y);
    flow_blob3_data[j3] = sfp(neg_y);
    flow_blob4_data[j4] = sfp(x);
    flow_blob5_data[j5] = sfp(x);
    flow_blob6_data[j6] = sfp(neg_x);
    flow_blob7_data[j7] = sfp(neg_x);
}