// Source listing: rife-ncnn-vulkan/src/rife_v4_flow_tta_temporal_a... (60 lines, 1.8 KiB, Plaintext)

// rife implemented with ncnn library
#version 450
// The 16-bit storage extension is required only when NCNN_fp16_storage
// evaluates true at preprocessing time.
#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
// Two storage buffers; main() reads both and then overwrites both in place.
// NOTE(review): `sfp` is not defined in this file -- presumably the storage
// float type injected by the ncnn shader preamble; confirm.
layout (binding = 0) buffer flow_blob { sfp flow_blob_data[]; };
layout (binding = 1) buffer flow_reversed_blob { sfp flow_reversed_blob_data[]; };
// Push constants consumed by main():
//   w     - extent in x; also the row pitch used to linearise (gx, gy)
//   h     - extent in y
//   cstep - element offset between consecutive planes inside each buffer
layout (push_constant) uniform parameter
{
int w;
int h;
int cstep;
} p;
// Blends flow_blob with flow_reversed_blob in place, one (x, y) element per
// invocation. Each buffer is addressed as five planes spaced p.cstep elements
// apart.
//
// Planes 0..3: plain average, pairing plane k of flow_blob with plane
//              (k + 2) % 4 of flow_reversed_blob (0<->2, 1<->3).
// Plane 4:     half of (flow_blob - flow_reversed_blob).
//
// flow_blob receives the blended values as-is; flow_reversed_blob receives
// the same values with the plane pairs swapped back and plane 4 negated.
// NOTE(review): the buffer names suggest a forward estimate and a
// time-reversed estimate being reconciled -- confirm against the host code.
void main()
{
    const ivec3 gid = ivec3(gl_GlobalInvocationID);

    // Only the z == 0 slice inside the w x h extent does any work.
    if (gid.z >= 1 || gid.y >= p.h || gid.x >= p.w)
        return;

    // Linear element offset inside plane 0, then the same element in planes 1..4.
    const int i0 = gid.y * p.w + gid.x;
    const int i1 = i0 + p.cstep;
    const int i2 = i0 + p.cstep * 2;
    const int i3 = i0 + p.cstep * 3;
    const int i4 = i0 + p.cstep * 4;

    // Load everything up front: the stores below overwrite these same elements.
    const vec4 fwd = vec4(float(flow_blob_data[i0]),
                          float(flow_blob_data[i1]),
                          float(flow_blob_data[i2]),
                          float(flow_blob_data[i3]));
    const float fwd_m = float(flow_blob_data[i4]);

    const vec4 rev = vec4(float(flow_reversed_blob_data[i0]),
                          float(flow_reversed_blob_data[i1]),
                          float(flow_reversed_blob_data[i2]),
                          float(flow_reversed_blob_data[i3]));
    const float rev_m = float(flow_reversed_blob_data[i4]);

    // The .zwxy swizzle lines up reversed planes (2, 3, 0, 1) with planes (0, 1, 2, 3).
    const vec4 blended = (fwd + rev.zwxy) * 0.5f;
    const float blended_m = (fwd_m - rev_m) * 0.5f;

    flow_blob_data[i0] = sfp(blended.x);
    flow_blob_data[i1] = sfp(blended.y);
    flow_blob_data[i2] = sfp(blended.z);
    flow_blob_data[i3] = sfp(blended.w);
    flow_blob_data[i4] = sfp(blended_m);

    // Same values, plane pairs swapped back and the fifth plane sign-flipped.
    const vec4 swapped = blended.zwxy;
    flow_reversed_blob_data[i0] = sfp(swapped.x);
    flow_reversed_blob_data[i1] = sfp(swapped.y);
    flow_reversed_blob_data[i2] = sfp(swapped.z);
    flow_reversed_blob_data[i3] = sfp(swapped.w);
    flow_reversed_blob_data[i4] = sfp(-blended_m);
}