ARM64: Port msac improvements to more functions

Port improvements from the hi token functions to the rest of the symbol
adaption functions. These weren't originally ported since they didn't
work with arbitrary padding. In practice, zero padding is already used
and only the tests need to be updated.

Results - Neoverse N1

Old:
msac_decode_symbol_adapt4_c:         41.4 ( 1.00x)
msac_decode_symbol_adapt4_neon:      31.0 ( 1.34x)
msac_decode_symbol_adapt8_c:         54.5 ( 1.00x)
msac_decode_symbol_adapt8_neon:      32.2 ( 1.69x)
msac_decode_symbol_adapt16_c:        85.6 ( 1.00x)
msac_decode_symbol_adapt16_neon:     37.5 ( 2.28x)

New:
msac_decode_symbol_adapt4_c:         41.5 ( 1.00x)
msac_decode_symbol_adapt4_neon:      27.7 ( 1.50x)
msac_decode_symbol_adapt8_c:         55.7 ( 1.00x)
msac_decode_symbol_adapt8_neon:      30.1 ( 1.85x)
msac_decode_symbol_adapt16_c:        82.4 ( 1.00x)
msac_decode_symbol_adapt16_neon:     35.2 ( 2.34x)
This commit is contained in:
Kyle Siefring 2024-04-14 17:10:44 -04:00 committed by Henrik Gramner
parent 5b5399911d
commit 37d52435d1
2 changed files with 17 additions and 29 deletions

View File

@ -40,11 +40,6 @@ const coeffs
.short 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
endconst
const bits
.short 0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80
.short 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000
endconst
.macro ld1_n d0, d1, src, sz, n
.if \n <= 8
ld1 {\d0\sz}, [\src]
@ -96,13 +91,6 @@ endconst
.endif
.endm
.macro urhadd_n d0, d1, s0, s1, s2, s3, sz, n
urhadd \d0\sz, \s0\sz, \s2\sz
.if \n == 16
urhadd \d1\sz, \s1\sz, \s3\sz
.endif
.endm
.macro sshl_n d0, d1, s0, s1, s2, s3, sz, n
sshl \d0\sz, \s0\sz, \s2\sz
.if \n == 16
@ -149,22 +137,19 @@ function msac_decode_symbol_adapt4_neon, export=1
add_n v4, v5, v6, v7, v4, v5, \sz, \n // v = ((cdf >> EC_PROB_SHIFT) * r) >> 1 + EC_MIN_PROB * (n_symbols - ret)
ld1r {v6.8h}, [x8] // dif >> (EC_WIN_SIZE - 16)
movrel x8, bits
str_n q4, q5, sp, #16, \n // store v values to allow indexed access
ld1_n v16, v17, x8, .8h, \n
cmhs_n v2, v3, v6, v6, v4, v5, \sz, \n // c >= v
cmhs_n v2, v3, v6, v6, v4, v5, .8h, \n // c >= v
and_n v6, v7, v2, v3, v16, v17, .16b, \n // One bit per halfword set in the mask
.if \n == 16
add v6.8h, v6.8h, v7.8h
add v6\sz, v2\sz, v3\sz
addv h6, v6\sz // -n + ret
.else
addv h6, v2\sz // -n + ret
.endif
addv h6, v6.8h // Aggregate mask bits
ldr w4, [x0, #ALLOW_UPDATE_CDF]
umov w3, v6.h[0]
rbit w3, w3
clz w15, w3 // ret
smov w15, v6.h[0]
add w15, w15, #\n // ret
cbz w4, L(renorm)
// update_cdf
@ -177,21 +162,24 @@ function msac_decode_symbol_adapt4_neon, export=1
mov w4, #-4
cmn w14, #3 // set C if n_symbols <= 2
.endif
urhadd_n v4, v5, v5, v5, v2, v3, \sz, \n // i >= val ? -1 : 32768
sub_n v16, v17, v0, v1, v2, v3, \sz, \n // cdf + (i >= val ? 1 : 0)
orr v2\sz, #0x80, lsl #8
.if \n == 16
orr v3\sz, #0x80, lsl #8
.endif
.if \n == 16
sub w4, w4, w3, lsr #4 // -((count >> 4) + 5)
.else
lsr w14, w3, #4 // count >> 4
sbc w4, w4, w14 // -((count >> 4) + (n_symbols > 2) + 4)
.endif
sub_n v4, v5, v4, v5, v0, v1, \sz, \n // (32768 - cdf[i]) or (-1 - cdf[i])
sub_n v2, v3, v2, v3, v0, v1, \sz, \n // (32768 - cdf[i]) or (-1 - cdf[i])
dup v6\sz, w4 // -rate
sub w3, w3, w3, lsr #5 // count - (count == 32)
sub_n v0, v1, v0, v1, v2, v3, \sz, \n // cdf + (i >= val ? 1 : 0)
sshl_n v4, v5, v4, v5, v6, v6, \sz, \n // ({32768,-1} - cdf[i]) >> rate
sshl_n v2, v3, v2, v3, v6, v6, \sz, \n // ({32768,-1} - cdf[i]) >> rate
add w3, w3, #1 // count + (count < 32)
add_n v0, v1, v0, v1, v4, v5, \sz, \n // cdf + (32768 - cdf[i]) >> rate
add_n v0, v1, v16, v17, v2, v3, \sz, \n // cdf + (32768 - cdf[i]) >> rate
st1_n v0, v1, x1, \sz, \n
strh w3, [x1, x2, lsl #1]
.endm

View File

@ -55,8 +55,8 @@ typedef struct {
static void randomize_cdf(uint16_t *const cdf, const int n) {
int i;
for (i = 15; i > n; i--)
cdf[i] = rnd(); // padding
cdf[i] = 0; // count
cdf[i] = 0; // padding
cdf[i] = 0; // count
do {
cdf[i - 1] = cdf[i] + rnd() % (32768 - cdf[i] - i) + 1;
} while (--i > 0);