implement temporal tta
This commit is contained in:
parent
b639fbd84b
commit
c26c90bb64
|
@ -230,6 +230,9 @@ rife_add_shader(rife_preproc_tta.comp)
|
|||
rife_add_shader(rife_postproc_tta.comp)
|
||||
rife_add_shader(rife_flow_tta_avg.comp)
|
||||
rife_add_shader(rife_v2_flow_tta_avg.comp)
|
||||
rife_add_shader(rife_flow_tta_temporal_avg.comp)
|
||||
rife_add_shader(rife_v2_flow_tta_temporal_avg.comp)
|
||||
rife_add_shader(rife_out_tta_temporal_avg.comp)
|
||||
rife_add_shader(warp.comp)
|
||||
rife_add_shader(warp_pack4.comp)
|
||||
rife_add_shader(warp_pack8.comp)
|
||||
|
|
979
src/rife.cpp
979
src/rife.cpp
File diff suppressed because it is too large
Load Diff
|
@ -32,11 +32,14 @@ private:
|
|||
ncnn::Pipeline* rife_preproc;
|
||||
ncnn::Pipeline* rife_postproc;
|
||||
ncnn::Pipeline* rife_flow_tta_avg;
|
||||
ncnn::Pipeline* rife_flow_tta_temporal_avg;
|
||||
ncnn::Pipeline* rife_out_tta_temporal_avg;
|
||||
ncnn::Layer* rife_uhd_downscale_image;
|
||||
ncnn::Layer* rife_uhd_upscale_flow;
|
||||
ncnn::Layer* rife_uhd_double_flow;
|
||||
ncnn::Layer* rife_v2_slice_flow;
|
||||
bool tta_mode;
|
||||
bool tta_temporal_mode;
|
||||
bool uhd_mode;
|
||||
int num_threads;
|
||||
bool rife_v2;
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
// rife implemented with ncnn library
|
||||
|
||||
#version 450
|
||||
|
||||
#if NCNN_fp16_storage
|
||||
#extension GL_EXT_shader_16bit_storage: require
|
||||
#endif
|
||||
|
||||
layout (binding = 0) buffer flow_blob { sfp flow_blob_data[]; };
|
||||
layout (binding = 1) buffer flow_reversed_blob { sfp flow_reversed_blob_data[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int w;
|
||||
int h;
|
||||
int cstep;
|
||||
} p;
|
||||
|
||||
// Combines the two flow buffers in place.
//
// Each invocation handles one (column, row) position. Two values are read from
// each buffer: one at row * p.w + column and one p.cstep elements further on.
// The result written to flow_blob is half the difference (flow - reversed);
// flow_reversed_blob receives the same result with the sign flipped.
// NOTE(review): presumably the reversed buffer holds the estimate with the two
// input frames swapped, hence the subtraction - confirm against the caller.
void main()
{
    const int col = int(gl_GlobalInvocationID.x);
    const int row = int(gl_GlobalInvocationID.y);
    const int plane = int(gl_GlobalInvocationID.z);

    // Only invocations with z == 0 do any work; everything outside w x h exits.
    if (col >= p.w || row >= p.h || plane >= 1)
        return;

    // Element offsets of the two components handled by this invocation.
    const int x_idx = row * p.w + col;
    const int y_idx = p.cstep + row * p.w + col;

    const afp fwd_x = buffer_ld1(flow_blob_data, x_idx);
    const afp fwd_y = buffer_ld1(flow_blob_data, y_idx);

    const afp rev_x = buffer_ld1(flow_reversed_blob_data, x_idx);
    const afp rev_y = buffer_ld1(flow_reversed_blob_data, y_idx);

    const afp avg_x = (fwd_x - rev_x) * afp(0.5);
    const afp avg_y = (fwd_y - rev_y) * afp(0.5);

    buffer_st1(flow_blob_data, x_idx, avg_x);
    buffer_st1(flow_blob_data, y_idx, avg_y);

    // The reversed buffer gets the negated result.
    buffer_st1(flow_reversed_blob_data, x_idx, -avg_x);
    buffer_st1(flow_reversed_blob_data, y_idx, -avg_y);
}
|
|
@ -0,0 +1,36 @@
|
|||
// rife implemented with ncnn library
|
||||
|
||||
#version 450
|
||||
|
||||
#if NCNN_fp16_storage
|
||||
#extension GL_EXT_shader_16bit_storage: require
|
||||
#endif
|
||||
|
||||
layout (binding = 0) buffer out_blob { sfp out_blob_data[]; };
|
||||
layout (binding = 1) buffer out_reversed_blob { sfp out_reversed_blob_data[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int w;
|
||||
int h;
|
||||
int cstep;
|
||||
} p;
|
||||
|
||||
// Averages out_blob with out_reversed_blob, element by element, in place.
//
// Each invocation handles one element (gx, gy) of plane gz, for gz in 0..2;
// plane gz starts gz * p.cstep elements into each buffer. Only out_blob is
// written; out_reversed_blob is read and left unchanged.
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    // Exit for invocations outside the w x h x 3 range.
    if (gx >= p.w || gy >= p.h || gz >= 3)
        return;

    const int gzi = gz * p.cstep;
    const int offset = gzi + gy * p.w + gx;

    afp v0 = buffer_ld1(out_blob_data, offset);
    afp v1 = buffer_ld1(out_reversed_blob_data, offset);

    // The constant is wrapped in afp() so the multiplication stays in the afp
    // type, as the flow averaging shaders do; a bare 0.5 is a float literal.
    afp v = (v0 + v1) * afp(0.5);

    buffer_st1(out_blob_data, offset, v);
}
|
|
@ -0,0 +1,38 @@
|
|||
// rife implemented with ncnn library
|
||||
|
||||
#version 450
|
||||
|
||||
#if NCNN_fp16_storage
|
||||
#extension GL_EXT_shader_16bit_storage: require
|
||||
#endif
|
||||
|
||||
layout (binding = 0) buffer flow_blob { sfpvec4 flow_blob_data[]; };
|
||||
layout (binding = 1) buffer flow_reversed_blob { sfpvec4 flow_reversed_blob_data[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int w;
|
||||
int h;
|
||||
int cstep;
|
||||
} p;
|
||||
|
||||
// Averages the two packed flow buffers in place, four components at a time.
//
// Each invocation handles the vec4 at gy * p.w + gx. Component pairs are
// matched crosswise: (x, y) of flow_blob is averaged with (z, w) of
// flow_reversed_blob, and (z, w) of flow_blob with (x, y) of the reversed one.
// flow_blob receives the averaged vec4; flow_reversed_blob receives the same
// values with the two halves swapped (z, w, x, y).
void main()
{
    const int col = int(gl_GlobalInvocationID.x);
    const int row = int(gl_GlobalInvocationID.y);
    const int plane = int(gl_GlobalInvocationID.z);

    // Only invocations with z == 0 do any work; everything outside w x h exits.
    if (col >= p.w || row >= p.h || plane >= 1)
        return;

    const int idx = row * p.w + col;

    const afpvec4 fwd = buffer_ld4(flow_blob_data, idx);
    const afpvec4 rev = buffer_ld4(flow_reversed_blob_data, idx);

    // rev.zwxy lines the reversed halves up with the halves of fwd they pair with.
    const afpvec4 avg = (fwd + rev.zwxy) * afp(0.5f);
    const afpvec4 avg_swapped = avg.zwxy;

    buffer_st4(flow_blob_data, idx, avg);
    buffer_st4(flow_reversed_blob_data, idx, avg_swapped);
}
|
Loading…
Reference in New Issue