implement gpu temporal tta for rife v4, expose userspace option

nihui 2022-10-16 21:09:43 +08:00
parent 5ca8c773f2
commit d964f6f165
6 changed files with 496 additions and 74 deletions

README.md

@@ -82,7 +82,8 @@ Usage: rife-ncnn-vulkan -0 infile -1 infile1 -o outfile [options]...
-m model-path rife model path (default=rife-v2.3)
-g gpu-id gpu device to use (-1=cpu, default=auto) can be 0,1,2 for multi-gpu
-j load:proc:save thread count for load/proc/save (default=1:2:2) can be 1:2,2,2:2 for multi-gpu
-x enable tta mode
-x enable spatial tta mode
-z enable temporal tta mode
-u enable UHD mode
-f pattern-format output image filename pattern format (%08d.jpg/png/webp, default=ext/%08d.png)
```
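The spatial and temporal TTA switches are independent and can be combined. A hypothetical invocation using the flags listed above:
```
rife-ncnn-vulkan -0 frame0.png -1 frame1.png -o out.png -x -z
```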
@@ -132,10 +133,6 @@ cmake ../src
cmake --build . -j 4
```
### TODO
* test-time temporal augmentation aka TTA-t
### Model
| model | upstream version |

src/CMakeLists.txt

@@ -248,6 +248,7 @@ rife_add_shader(rife_v2_flow_tta_avg.comp)
rife_add_shader(rife_v4_flow_tta_avg.comp)
rife_add_shader(rife_flow_tta_temporal_avg.comp)
rife_add_shader(rife_v2_flow_tta_temporal_avg.comp)
rife_add_shader(rife_v4_flow_tta_temporal_avg.comp)
rife_add_shader(rife_out_tta_temporal_avg.comp)
rife_add_shader(rife_v4_timestep.comp)
rife_add_shader(rife_v4_timestep_tta.comp)

src/main.cpp

@@ -114,7 +114,8 @@ static void print_usage()
fprintf(stderr, " -m model-path rife model path (default=rife-v2.3)\n");
fprintf(stderr, " -g gpu-id gpu device to use (-1=cpu, default=auto) can be 0,1,2 for multi-gpu\n");
fprintf(stderr, " -j load:proc:save thread count for load/proc/save (default=1:2:2) can be 1:2,2,2:2 for multi-gpu\n");
fprintf(stdout, " -x enable tta mode\n");
fprintf(stdout, " -x enable spatial tta mode\n");
fprintf(stdout, " -z enable temporal tta mode\n");
fprintf(stdout, " -u enable UHD mode\n");
fprintf(stderr, " -f pattern-format output image filename pattern format (%%08d.jpg/png/webp, default=ext/%%08d.png)\n");
}
@@ -454,13 +455,14 @@ int main(int argc, char** argv)
int jobs_save = 2;
int verbose = 0;
int tta_mode = 0;
int tta_temporal_mode = 0;
int uhd_mode = 0;
path_t pattern_format = PATHSTR("%08d.png");
#if _WIN32
setlocale(LC_ALL, "");
wchar_t opt;
while ((opt = getopt(argc, argv, L"0:1:i:o:n:s:m:g:j:f:vxuh")) != (wchar_t)-1)
while ((opt = getopt(argc, argv, L"0:1:i:o:n:s:m:g:j:f:vxzuh")) != (wchar_t)-1)
{
switch (opt)
{
@@ -501,6 +503,9 @@ int main(int argc, char** argv)
case L'x':
tta_mode = 1;
break;
case L'z':
tta_temporal_mode = 1;
break;
case L'u':
uhd_mode = 1;
break;
@@ -512,7 +517,7 @@ int main(int argc, char** argv)
}
#else // _WIN32
int opt;
while ((opt = getopt(argc, argv, "0:1:i:o:n:s:m:g:j:f:vxuh")) != -1)
while ((opt = getopt(argc, argv, "0:1:i:o:n:s:m:g:j:f:vxzuh")) != -1)
{
switch (opt)
{
@@ -553,6 +558,9 @@ int main(int argc, char** argv)
case 'x':
tta_mode = 1;
break;
case 'z':
tta_temporal_mode = 1;
break;
case 'u':
uhd_mode = 1;
break;
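In both optstrings, `z` is added without a trailing colon, so getopt treats it as an argument-less switch, just like `x` and `u`. A minimal standalone sketch of the same parsing pattern (hypothetical example, not part of this commit):
```
#include <cstdio>
#include <unistd.h>

int main(int argc, char** argv)
{
    int tta_temporal_mode = 0;

    int opt;
    // no ':' after 'z' in the optstring, so the flag consumes no argument
    while ((opt = getopt(argc, argv, "z")) != -1)
    {
        if (opt == 'z')
            tta_temporal_mode = 1;
    }

    printf("tta_temporal_mode=%d\n", tta_temporal_mode);
    return 0;
}
```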
@@ -814,7 +822,7 @@ int main(int argc, char** argv)
{
int num_threads = gpuid[i] == -1 ? jobs_proc[i] : 1;
rife[i] = new RIFE(gpuid[i], tta_mode, uhd_mode, num_threads, rife_v2, rife_v4);
rife[i] = new RIFE(gpuid[i], tta_mode, tta_temporal_mode, uhd_mode, num_threads, rife_v2, rife_v4);
rife[i]->load(modeldir);
}

src/rife.cpp

@@ -15,6 +15,7 @@
#include "rife_v4_flow_tta_avg.comp.hex.h"
#include "rife_flow_tta_temporal_avg.comp.hex.h"
#include "rife_v2_flow_tta_temporal_avg.comp.hex.h"
#include "rife_v4_flow_tta_temporal_avg.comp.hex.h"
#include "rife_out_tta_temporal_avg.comp.hex.h"
#include "rife_v4_timestep.comp.hex.h"
#include "rife_v4_timestep_tta.comp.hex.h"
@@ -23,7 +24,7 @@
DEFINE_LAYER_CREATOR(Warp)
RIFE::RIFE(int gpuid, bool _tta_mode, bool _uhd_mode, int _num_threads, bool _rife_v2, bool _rife_v4)
RIFE::RIFE(int gpuid, bool _tta_mode, bool _tta_temporal_mode, bool _uhd_mode, int _num_threads, bool _rife_v2, bool _rife_v4)
{
vkdev = gpuid == -1 ? 0 : ncnn::get_gpu_device(gpuid);
@@ -38,7 +39,7 @@ RIFE::RIFE(int gpuid, bool _tta_mode, bool _uhd_mode, int _num_threads, bool _ri
rife_uhd_double_flow = 0;
rife_v2_slice_flow = 0;
tta_mode = _tta_mode;
tta_temporal_mode = false;
tta_temporal_mode = _tta_temporal_mode;
uhd_mode = _uhd_mode;
num_threads = _num_threads;
rife_v2 = _rife_v2;
@@ -249,7 +250,11 @@ int RIFE::load(const std::string& modeldir)
ncnn::MutexLockGuard guard(lock);
if (spirv.empty())
{
if (rife_v2)
if (rife_v4)
{
compile_spirv_module(rife_v4_flow_tta_temporal_avg_comp_data, sizeof(rife_v4_flow_tta_temporal_avg_comp_data), opt, spirv);
}
else if (rife_v2)
{
compile_spirv_module(rife_v2_flow_tta_temporal_avg_comp_data, sizeof(rife_v2_flow_tta_temporal_avg_comp_data), opt, spirv);
}
@@ -2611,12 +2616,268 @@ int RIFE::process_v4(const ncnn::Mat& in0image, const ncnn::Mat& in1image, float
cmd.record_pipeline(rife_v4_timestep, bindings, constants, timestep_gpu_padded[0]);
}
ncnn::VkMat flow[4][8];
for (int fi = 0; fi < 4; fi++)
ncnn::VkMat out_gpu_padded[8];
if (tta_temporal_mode)
{
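// temporal tta: evaluate the network a second time with in0/in1 swapped and
// the timestep mirrored to 1 - t, then average both predictions, roughly
//   out = (interp(in0, in1, t) + interp(in1, in0, 1 - t)) * 0.5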
ncnn::VkMat timestep_gpu_padded_reversed[2];
{
timestep_gpu_padded_reversed[0].create(w_padded, h_padded, 1, in_out_tile_elemsize, 1, blob_vkallocator);
timestep_gpu_padded_reversed[1].create(h_padded, w_padded, 1, in_out_tile_elemsize, 1, blob_vkallocator);
std::vector<ncnn::VkMat> bindings(2);
bindings[0] = timestep_gpu_padded_reversed[0];
bindings[1] = timestep_gpu_padded_reversed[1];
std::vector<ncnn::vk_constant_type> constants(4);
constants[0].i = timestep_gpu_padded_reversed[0].w;
constants[1].i = timestep_gpu_padded_reversed[0].h;
constants[2].i = timestep_gpu_padded_reversed[0].cstep;
constants[3].f = 1.f - timestep;
cmd.record_pipeline(rife_v4_timestep, bindings, constants, timestep_gpu_padded_reversed[0]);
}
ncnn::VkMat flow[4][8];
ncnn::VkMat flow_reversed[4][8];
for (int fi = 0; fi < 4; fi++)
{
for (int ti = 0; ti < 8; ti++)
{
{
// flownet flow mask
ncnn::Extractor ex = flownet.create_extractor();
ex.set_blob_vkallocator(blob_vkallocator);
ex.set_workspace_vkallocator(blob_vkallocator);
ex.set_staging_vkallocator(staging_vkallocator);
ex.input("in0", in0_gpu_padded[ti]);
ex.input("in1", in1_gpu_padded[ti]);
ex.input("in2", timestep_gpu_padded[ti / 4]);
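// spatial tta variants 4..7 are transposed, so they share the h x w timestep plane, hence ti / 4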
// intentional fall through
switch (fi)
{
case 3: ex.input("flow2", flow[2][ti]);
case 2: ex.input("flow1", flow[1][ti]);
case 1: ex.input("flow0", flow[0][ti]);
default:
{
char tmp[16];
sprintf(tmp, "flow%d", fi);
ex.extract(tmp, flow[fi][ti], cmd);
}
}
}
{
// flownet flow mask reversed
ncnn::Extractor ex = flownet.create_extractor();
ex.set_blob_vkallocator(blob_vkallocator);
ex.set_workspace_vkallocator(blob_vkallocator);
ex.set_staging_vkallocator(staging_vkallocator);
ex.input("in0", in1_gpu_padded[ti]);
ex.input("in1", in0_gpu_padded[ti]);
ex.input("in2", timestep_gpu_padded_reversed[ti / 4]);
// intentional fall through
switch (fi)
{
case 3: ex.input("flow2", flow_reversed[2][ti]);
case 2: ex.input("flow1", flow_reversed[1][ti]);
case 1: ex.input("flow0", flow_reversed[0][ti]);
default:
{
char tmp[16];
sprintf(tmp, "flow%d", fi);
ex.extract(tmp, flow_reversed[fi][ti], cmd);
}
}
}
// merge flow and flow_reversed
{
std::vector<ncnn::VkMat> bindings(2);
bindings[0] = flow[fi][ti];
bindings[1] = flow_reversed[fi][ti];
std::vector<ncnn::vk_constant_type> constants(3);
constants[0].i = flow[fi][ti].w;
constants[1].i = flow[fi][ti].h;
constants[2].i = flow[fi][ti].cstep;
ncnn::VkMat dispatcher;
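// empty mat: its w/h/c only size the compute dispatch, no memory is allocated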
dispatcher.w = flow[fi][ti].w;
dispatcher.h = flow[fi][ti].h;
dispatcher.c = 1;
cmd.record_pipeline(rife_flow_tta_temporal_avg, bindings, constants, dispatcher);
}
}
// avg flow mask
{
std::vector<ncnn::VkMat> bindings(8);
bindings[0] = flow[fi][0];
bindings[1] = flow[fi][1];
bindings[2] = flow[fi][2];
bindings[3] = flow[fi][3];
bindings[4] = flow[fi][4];
bindings[5] = flow[fi][5];
bindings[6] = flow[fi][6];
bindings[7] = flow[fi][7];
std::vector<ncnn::vk_constant_type> constants(3);
constants[0].i = flow[fi][0].w;
constants[1].i = flow[fi][0].h;
constants[2].i = flow[fi][0].cstep;
ncnn::VkMat dispatcher;
dispatcher.w = flow[fi][0].w;
dispatcher.h = flow[fi][0].h;
dispatcher.c = 1;
cmd.record_pipeline(rife_flow_tta_avg, bindings, constants, dispatcher);
}
{
std::vector<ncnn::VkMat> bindings(8);
bindings[0] = flow_reversed[fi][0];
bindings[1] = flow_reversed[fi][1];
bindings[2] = flow_reversed[fi][2];
bindings[3] = flow_reversed[fi][3];
bindings[4] = flow_reversed[fi][4];
bindings[5] = flow_reversed[fi][5];
bindings[6] = flow_reversed[fi][6];
bindings[7] = flow_reversed[fi][7];
std::vector<ncnn::vk_constant_type> constants(3);
constants[0].i = flow_reversed[fi][0].w;
constants[1].i = flow_reversed[fi][0].h;
constants[2].i = flow_reversed[fi][0].cstep;
ncnn::VkMat dispatcher;
dispatcher.w = flow_reversed[fi][0].w;
dispatcher.h = flow_reversed[fi][0].h;
dispatcher.c = 1;
cmd.record_pipeline(rife_flow_tta_avg, bindings, constants, dispatcher);
}
}
ncnn::VkMat out_gpu_padded_reversed[8];
for (int ti = 0; ti < 8; ti++)
{
{
// flownet forward
ncnn::Extractor ex = flownet.create_extractor();
ex.set_blob_vkallocator(blob_vkallocator);
ex.set_workspace_vkallocator(blob_vkallocator);
ex.set_staging_vkallocator(staging_vkallocator);
ex.input("in0", in0_gpu_padded[ti]);
ex.input("in1", in1_gpu_padded[ti]);
ex.input("in2", timestep_gpu_padded[ti / 4]);
ex.input("flow0", flow[0][ti]);
ex.input("flow1", flow[1][ti]);
ex.input("flow2", flow[2][ti]);
ex.input("flow3", flow[3][ti]);
ex.extract("out0", out_gpu_padded[ti], cmd);
}
{
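// flownet reversed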
ncnn::Extractor ex = flownet.create_extractor();
ex.set_blob_vkallocator(blob_vkallocator);
ex.set_workspace_vkallocator(blob_vkallocator);
ex.set_staging_vkallocator(staging_vkallocator);
ex.input("in0", in1_gpu_padded[ti]);
ex.input("in1", in0_gpu_padded[ti]);
ex.input("in2", timestep_gpu_padded_reversed[ti / 4]);
ex.input("flow0", flow_reversed[0][ti]);
ex.input("flow1", flow_reversed[1][ti]);
ex.input("flow2", flow_reversed[2][ti]);
ex.input("flow3", flow_reversed[3][ti]);
ex.extract("out0", out_gpu_padded_reversed[ti], cmd);
}
// merge output
{
std::vector<ncnn::VkMat> bindings(2);
bindings[0] = out_gpu_padded[ti];
bindings[1] = out_gpu_padded_reversed[ti];
std::vector<ncnn::vk_constant_type> constants(3);
constants[0].i = out_gpu_padded[ti].w;
constants[1].i = out_gpu_padded[ti].h;
constants[2].i = out_gpu_padded[ti].cstep;
ncnn::VkMat dispatcher;
dispatcher.w = out_gpu_padded[ti].w;
dispatcher.h = out_gpu_padded[ti].h;
dispatcher.c = 3;
cmd.record_pipeline(rife_out_tta_temporal_avg, bindings, constants, dispatcher);
}
}
}
else
{
ncnn::VkMat flow[4][8];
for (int fi = 0; fi < 4; fi++)
{
for (int ti = 0; ti < 8; ti++)
{
// flownet flow mask
ncnn::Extractor ex = flownet.create_extractor();
ex.set_blob_vkallocator(blob_vkallocator);
ex.set_workspace_vkallocator(blob_vkallocator);
ex.set_staging_vkallocator(staging_vkallocator);
ex.input("in0", in0_gpu_padded[ti]);
ex.input("in1", in1_gpu_padded[ti]);
ex.input("in2", timestep_gpu_padded[ti / 4]);
// intentional fall through
switch (fi)
{
case 3: ex.input("flow2", flow[2][ti]);
case 2: ex.input("flow1", flow[1][ti]);
case 1: ex.input("flow0", flow[0][ti]);
default:
{
char tmp[16];
sprintf(tmp, "flow%d", fi);
ex.extract(tmp, flow[fi][ti], cmd);
}
}
}
// avg flow mask
{
std::vector<ncnn::VkMat> bindings(8);
bindings[0] = flow[fi][0];
bindings[1] = flow[fi][1];
bindings[2] = flow[fi][2];
bindings[3] = flow[fi][3];
bindings[4] = flow[fi][4];
bindings[5] = flow[fi][5];
bindings[6] = flow[fi][6];
bindings[7] = flow[fi][7];
std::vector<ncnn::vk_constant_type> constants(3);
constants[0].i = flow[fi][0].w;
constants[1].i = flow[fi][0].h;
constants[2].i = flow[fi][0].cstep;
ncnn::VkMat dispatcher;
dispatcher.w = flow[fi][0].w;
dispatcher.h = flow[fi][0].h;
dispatcher.c = 1;
cmd.record_pipeline(rife_flow_tta_avg, bindings, constants, dispatcher);
}
}
for (int ti = 0; ti < 8; ti++)
{
// flownet
ncnn::Extractor ex = flownet.create_extractor();
ex.set_blob_vkallocator(blob_vkallocator);
ex.set_workspace_vkallocator(blob_vkallocator);
@@ -2625,65 +2886,13 @@ int RIFE::process_v4(const ncnn::Mat& in0image, const ncnn::Mat& in1image, float
ex.input("in0", in0_gpu_padded[ti]);
ex.input("in1", in1_gpu_padded[ti]);
ex.input("in2", timestep_gpu_padded[ti / 4]);
ex.input("flow0", flow[0][ti]);
ex.input("flow1", flow[1][ti]);
ex.input("flow2", flow[2][ti]);
ex.input("flow3", flow[3][ti]);
// intentional fall through
switch (fi)
{
case 3: ex.input("flow2", flow[2][ti]);
case 2: ex.input("flow1", flow[1][ti]);
case 1: ex.input("flow0", flow[0][ti]);
default:
{
char tmp[16];
sprintf(tmp, "flow%d", fi);
ex.extract(tmp, flow[fi][ti], cmd);
}
}
ex.extract("out0", out_gpu_padded[ti], cmd);
}
// avg flow mask
{
std::vector<ncnn::VkMat> bindings(8);
bindings[0] = flow[fi][0];
bindings[1] = flow[fi][1];
bindings[2] = flow[fi][2];
bindings[3] = flow[fi][3];
bindings[4] = flow[fi][4];
bindings[5] = flow[fi][5];
bindings[6] = flow[fi][6];
bindings[7] = flow[fi][7];
std::vector<ncnn::vk_constant_type> constants(3);
constants[0].i = flow[fi][0].w;
constants[1].i = flow[fi][0].h;
constants[2].i = flow[fi][0].cstep;
ncnn::VkMat dispatcher;
dispatcher.w = flow[fi][0].w;
dispatcher.h = flow[fi][0].h;
dispatcher.c = 1;
cmd.record_pipeline(rife_flow_tta_avg, bindings, constants, dispatcher);
}
}
ncnn::VkMat out_gpu_padded[8];
for (int ti = 0; ti < 8; ti++)
{
// flownet
ncnn::Extractor ex = flownet.create_extractor();
ex.set_blob_vkallocator(blob_vkallocator);
ex.set_workspace_vkallocator(blob_vkallocator);
ex.set_staging_vkallocator(staging_vkallocator);
ex.input("in0", in0_gpu_padded[ti]);
ex.input("in1", in1_gpu_padded[ti]);
ex.input("in2", timestep_gpu_padded[ti / 4]);
ex.input("flow0", flow[0][ti]);
ex.input("flow1", flow[1][ti]);
ex.input("flow2", flow[2][ti]);
ex.input("flow3", flow[3][ti]);
ex.extract("out0", out_gpu_padded[ti], cmd);
}
if (opt.use_fp16_storage && opt.use_int8_storage)
@@ -2774,9 +2983,157 @@ int RIFE::process_v4(const ncnn::Mat& in0image, const ncnn::Mat& in1image, float
cmd.record_pipeline(rife_v4_timestep, bindings, constants, timestep_gpu_padded);
}
// flownet
ncnn::VkMat out_gpu_padded;
if (tta_temporal_mode)
{
ncnn::VkMat timestep_gpu_padded_reversed;
{
timestep_gpu_padded_reversed.create(w_padded, h_padded, 1, in_out_tile_elemsize, 1, blob_vkallocator);
std::vector<ncnn::VkMat> bindings(1);
bindings[0] = timestep_gpu_padded_reversed;
std::vector<ncnn::vk_constant_type> constants(4);
constants[0].i = timestep_gpu_padded_reversed.w;
constants[1].i = timestep_gpu_padded_reversed.h;
constants[2].i = timestep_gpu_padded_reversed.cstep;
constants[3].f = 1.f - timestep;
cmd.record_pipeline(rife_v4_timestep, bindings, constants, timestep_gpu_padded_reversed);
}
ncnn::VkMat flow[4];
ncnn::VkMat flow_reversed[4];
for (int fi = 0; fi < 4; fi++)
{
{
// flownet flow mask
ncnn::Extractor ex = flownet.create_extractor();
ex.set_blob_vkallocator(blob_vkallocator);
ex.set_workspace_vkallocator(blob_vkallocator);
ex.set_staging_vkallocator(staging_vkallocator);
ex.input("in0", in0_gpu_padded);
ex.input("in1", in1_gpu_padded);
ex.input("in2", timestep_gpu_padded);
// intentional fall through
switch (fi)
{
case 3: ex.input("flow2", flow[2]);
case 2: ex.input("flow1", flow[1]);
case 1: ex.input("flow0", flow[0]);
default:
{
char tmp[16];
sprintf(tmp, "flow%d", fi);
ex.extract(tmp, flow[fi], cmd);
}
}
}
{
// flownet flow mask reversed
ncnn::Extractor ex = flownet.create_extractor();
ex.set_blob_vkallocator(blob_vkallocator);
ex.set_workspace_vkallocator(blob_vkallocator);
ex.set_staging_vkallocator(staging_vkallocator);
ex.input("in0", in1_gpu_padded);
ex.input("in1", in0_gpu_padded);
ex.input("in2", timestep_gpu_padded_reversed);
// intentional fall through
switch (fi)
{
case 3: ex.input("flow2", flow_reversed[2]);
case 2: ex.input("flow1", flow_reversed[1]);
case 1: ex.input("flow0", flow_reversed[0]);
default:
{
char tmp[16];
sprintf(tmp, "flow%d", fi);
ex.extract(tmp, flow_reversed[fi], cmd);
}
}
}
// merge flow and flow_reversed
{
std::vector<ncnn::VkMat> bindings(2);
bindings[0] = flow[fi];
bindings[1] = flow_reversed[fi];
std::vector<ncnn::vk_constant_type> constants(3);
constants[0].i = flow[fi].w;
constants[1].i = flow[fi].h;
constants[2].i = flow[fi].cstep;
ncnn::VkMat dispatcher;
dispatcher.w = flow[fi].w;
dispatcher.h = flow[fi].h;
dispatcher.c = 1;
cmd.record_pipeline(rife_flow_tta_temporal_avg, bindings, constants, dispatcher);
}
}
{
// flownet
ncnn::Extractor ex = flownet.create_extractor();
ex.set_blob_vkallocator(blob_vkallocator);
ex.set_workspace_vkallocator(blob_vkallocator);
ex.set_staging_vkallocator(staging_vkallocator);
ex.input("in0", in0_gpu_padded);
ex.input("in1", in1_gpu_padded);
ex.input("in2", timestep_gpu_padded);
ex.input("flow0", flow[0]);
ex.input("flow1", flow[1]);
ex.input("flow2", flow[2]);
ex.input("flow3", flow[3]);
ex.extract("out0", out_gpu_padded, cmd);
}
ncnn::VkMat out_gpu_padded_reversed;
{
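// flownet reversed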
ncnn::Extractor ex = flownet.create_extractor();
ex.set_blob_vkallocator(blob_vkallocator);
ex.set_workspace_vkallocator(blob_vkallocator);
ex.set_staging_vkallocator(staging_vkallocator);
ex.input("in0", in1_gpu_padded);
ex.input("in1", in0_gpu_padded);
ex.input("in2", timestep_gpu_padded_reversed);
ex.input("flow0", flow_reversed[0]);
ex.input("flow1", flow_reversed[1]);
ex.input("flow2", flow_reversed[2]);
ex.input("flow3", flow_reversed[3]);
ex.extract("out0", out_gpu_padded_reversed, cmd);
}
// merge output
{
std::vector<ncnn::VkMat> bindings(2);
bindings[0] = out_gpu_padded;
bindings[1] = out_gpu_padded_reversed;
std::vector<ncnn::vk_constant_type> constants(3);
constants[0].i = out_gpu_padded.w;
constants[1].i = out_gpu_padded.h;
constants[2].i = out_gpu_padded.cstep;
ncnn::VkMat dispatcher;
dispatcher.w = out_gpu_padded.w;
dispatcher.h = out_gpu_padded.h;
dispatcher.c = 3;
cmd.record_pipeline(rife_out_tta_temporal_avg, bindings, constants, dispatcher);
}
}
else
{
// flownet
ncnn::Extractor ex = flownet.create_extractor();
ex.set_blob_vkallocator(blob_vkallocator);
ex.set_workspace_vkallocator(blob_vkallocator);

src/rife.h

@@ -11,7 +11,7 @@
class RIFE
{
public:
RIFE(int gpuid, bool tta_mode = false, bool uhd_mode = false, int num_threads = 1, bool rife_v2 = false, bool rife_v4 = false);
RIFE(int gpuid, bool tta_mode = false, bool tta_temporal_mode = false, bool uhd_mode = false, int num_threads = 1, bool rife_v2 = false, bool rife_v4 = false);
~RIFE();
#if _WIN32
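Because `tta_temporal_mode` is inserted before `uhd_mode`, every positional caller has to be updated in the same commit (as main.cpp is above); an untouched old call would silently bind its UHD flag to the new parameter. Hypothetical call sites, for illustration only:
```
// before this commit: RIFE(gpuid, tta_mode, uhd_mode, ...)
RIFE* r = new RIFE(0, true, true);
// after this commit the same positional call is parsed as
// RIFE(gpuid, tta_mode=true, tta_temporal_mode=true, uhd_mode=false)
```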

src/rife_v4_flow_tta_temporal_avg.comp (new file)

@@ -0,0 +1,59 @@
// rife implemented with ncnn library
#version 450
#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
layout (binding = 0) buffer flow_blob { sfp flow_blob_data[]; };
layout (binding = 1) buffer flow_reversed_blob { sfp flow_reversed_blob_data[]; };
layout (push_constant) uniform parameter
{
int w;
int h;
int cstep;
} p;
void main()
{
int gx = int(gl_GlobalInvocationID.x);
int gy = int(gl_GlobalInvocationID.y);
int gz = int(gl_GlobalInvocationID.z);
if (gx >= p.w || gy >= p.h || gz >= 1)
return;
const int gi = gy * p.w + gx;
float x = float(flow_blob_data[gi]);
float y = float(flow_blob_data[p.cstep + gi]);
float z = float(flow_blob_data[p.cstep * 2 + gi]);
float w = float(flow_blob_data[p.cstep * 3 + gi]);
float m = float(flow_blob_data[p.cstep * 4 + gi]);
float x_reversed = float(flow_reversed_blob_data[gi]);
float y_reversed = float(flow_reversed_blob_data[p.cstep + gi]);
float z_reversed = float(flow_reversed_blob_data[p.cstep * 2 + gi]);
float w_reversed = float(flow_reversed_blob_data[p.cstep * 3 + gi]);
float m_reversed = float(flow_reversed_blob_data[p.cstep * 4 + gi]);
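// with in0/in1 swapped, reversed channels 0:1 point where forward channels
// 2:3 do, so the pairs are cross-averaged below; the mask prefers the
// opposite frame under time reversal, hence its sign flip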
x = (x + z_reversed) * 0.5f;
y = (y + w_reversed) * 0.5f;
z = (z + x_reversed) * 0.5f;
w = (w + y_reversed) * 0.5f;
m = (m - m_reversed) * 0.5f;
flow_blob_data[gi] = sfp(x);
flow_blob_data[p.cstep + gi] = sfp(y);
flow_blob_data[p.cstep * 2 + gi] = sfp(z);
flow_blob_data[p.cstep * 3 + gi] = sfp(w);
flow_blob_data[p.cstep * 4 + gi] = sfp(m);
flow_reversed_blob_data[gi] = sfp(z);
flow_reversed_blob_data[p.cstep + gi] = sfp(w);
flow_reversed_blob_data[p.cstep * 2 + gi] = sfp(x);
flow_reversed_blob_data[p.cstep * 3 + gi] = sfp(y);
flow_reversed_blob_data[p.cstep * 4 + gi] = sfp(-m);
}
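The channel pairing follows from the swapped inputs: the reversed pass's flow toward its first frame (channels 0:1) corresponds to the forward pass's flow toward the second frame (channels 2:3), and the blend mask changes sign under time reversal. A scalar C++ sketch of the same in-place averaging, assuming five packed planes of `cstep` floats per blob (hypothetical helper mirroring the shader math):
```
#include <cstddef>

// Average a forward flow/mask blob with its temporally reversed counterpart,
// in place, following rife_v4_flow_tta_temporal_avg.comp. Each blob holds
// five planes of cstep floats: x, y, z, w, mask.
void flow_tta_temporal_avg(float* f, float* r, size_t gi, size_t cstep)
{
    float x = f[gi], y = f[cstep + gi], z = f[2 * cstep + gi];
    float w = f[3 * cstep + gi], m = f[4 * cstep + gi];
    float xr = r[gi], yr = r[cstep + gi], zr = r[2 * cstep + gi];
    float wr = r[3 * cstep + gi], mr = r[4 * cstep + gi];

    // reversed channels 2:3 describe the same direction as forward 0:1
    float nx = (x + zr) * 0.5f, ny = (y + wr) * 0.5f;
    float nz = (z + xr) * 0.5f, nw = (w + yr) * 0.5f;
    float nm = (m - mr) * 0.5f; // mask flips sign under time reversal

    f[gi] = nx; f[cstep + gi] = ny; f[2 * cstep + gi] = nz;
    f[3 * cstep + gi] = nw; f[4 * cstep + gi] = nm;
    r[gi] = nz; r[cstep + gi] = nw; r[2 * cstep + gi] = nx;
    r[3 * cstep + gi] = ny; r[4 * cstep + gi] = -nm;
}
```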