implement temporal tta

This commit is contained in:
nihui 2021-05-22 19:21:02 +08:00
parent b639fbd84b
commit c26c90bb64
6 changed files with 1032 additions and 69 deletions

View File

@ -230,6 +230,9 @@ rife_add_shader(rife_preproc_tta.comp)
rife_add_shader(rife_postproc_tta.comp)
rife_add_shader(rife_flow_tta_avg.comp)
rife_add_shader(rife_v2_flow_tta_avg.comp)
rife_add_shader(rife_flow_tta_temporal_avg.comp)
rife_add_shader(rife_v2_flow_tta_temporal_avg.comp)
rife_add_shader(rife_out_tta_temporal_avg.comp)
rife_add_shader(warp.comp)
rife_add_shader(warp_pack4.comp)
rife_add_shader(warp_pack8.comp)

File diff suppressed because it is too large Load Diff

View File

@ -32,11 +32,14 @@ private:
ncnn::Pipeline* rife_preproc;
ncnn::Pipeline* rife_postproc;
ncnn::Pipeline* rife_flow_tta_avg;
ncnn::Pipeline* rife_flow_tta_temporal_avg;
ncnn::Pipeline* rife_out_tta_temporal_avg;
ncnn::Layer* rife_uhd_downscale_image;
ncnn::Layer* rife_uhd_upscale_flow;
ncnn::Layer* rife_uhd_double_flow;
ncnn::Layer* rife_v2_slice_flow;
bool tta_mode;
bool tta_temporal_mode;
bool uhd_mode;
int num_threads;
bool rife_v2;

View File

@ -0,0 +1,42 @@
// rife implemented with ncnn library
#version 450
#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
layout (binding = 0) buffer flow_blob { sfp flow_blob_data[]; };
layout (binding = 1) buffer flow_reversed_blob { sfp flow_reversed_blob_data[]; };
layout (push_constant) uniform parameter
{
int w;
int h;
int cstep;
} p;
void main()
{
int gx = int(gl_GlobalInvocationID.x);
int gy = int(gl_GlobalInvocationID.y);
int gz = int(gl_GlobalInvocationID.z);
if (gx >= p.w || gy >= p.h || gz >= 1)
return;
afp x = buffer_ld1(flow_blob_data, gy * p.w + gx);
afp y = buffer_ld1(flow_blob_data, p.cstep + gy * p.w + gx);
afp x_reversed = buffer_ld1(flow_reversed_blob_data, gy * p.w + gx);
afp y_reversed = buffer_ld1(flow_reversed_blob_data, p.cstep + gy * p.w + gx);
x = (x - x_reversed) * afp(0.5);
y = (y - y_reversed) * afp(0.5);
buffer_st1(flow_blob_data, gy * p.w + gx, x);
buffer_st1(flow_blob_data, p.cstep + gy * p.w + gx, y);
buffer_st1(flow_reversed_blob_data, gy * p.w + gx, -x);
buffer_st1(flow_reversed_blob_data, p.cstep + gy * p.w + gx, -y);
}

View File

@ -0,0 +1,36 @@
// rife implemented with ncnn library
#version 450
#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
layout (binding = 0) buffer out_blob { sfp out_blob_data[]; };
layout (binding = 1) buffer out_reversed_blob { sfp out_reversed_blob_data[]; };
layout (push_constant) uniform parameter
{
int w;
int h;
int cstep;
} p;
void main()
{
int gx = int(gl_GlobalInvocationID.x);
int gy = int(gl_GlobalInvocationID.y);
int gz = int(gl_GlobalInvocationID.z);
if (gx >= p.w || gy >= p.h || gz >= 3)
return;
const int gzi = gz * p.cstep;
afp v0 = buffer_ld1(out_blob_data, gzi + gy * p.w + gx);
afp v1 = buffer_ld1(out_reversed_blob_data, gzi + gy * p.w + gx);
afp v = (v0 + v1) * 0.5;
buffer_st1(out_blob_data, gzi + gy * p.w + gx, v);
}

View File

@ -0,0 +1,38 @@
// rife implemented with ncnn library
#version 450
#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
layout (binding = 0) buffer flow_blob { sfpvec4 flow_blob_data[]; };
layout (binding = 1) buffer flow_reversed_blob { sfpvec4 flow_reversed_blob_data[]; };
layout (push_constant) uniform parameter
{
int w;
int h;
int cstep;
} p;
void main()
{
int gx = int(gl_GlobalInvocationID.x);
int gy = int(gl_GlobalInvocationID.y);
int gz = int(gl_GlobalInvocationID.z);
if (gx >= p.w || gy >= p.h || gz >= 1)
return;
afpvec4 xyzw = buffer_ld4(flow_blob_data, gy * p.w + gx);
afpvec4 xyzw_reversed = buffer_ld4(flow_reversed_blob_data, gy * p.w + gx);
afp x = (xyzw.x + xyzw_reversed.z) * afp(0.5f);
afp y = (xyzw.y + xyzw_reversed.w) * afp(0.5f);
afp z = (xyzw.z + xyzw_reversed.x) * afp(0.5f);
afp w = (xyzw.w + xyzw_reversed.y) * afp(0.5f);
buffer_st4(flow_blob_data, gy * p.w + gx, afpvec4(x, y, z, w));
buffer_st4(flow_reversed_blob_data, gy * p.w + gx, afpvec4(z, w, x, y));
}