shithub: dav1d

Download patch

ref: e3dbf92664918ecc830b4fde74b7cc0f6cd2065c
parent: 7cf5d7535f44d7c2d00e368575d0d26b66c73121
author: Martin Storsjö <[email protected]>
date: Mon Feb 10 05:03:27 EST 2020

arm64: looprestoration: NEON implementation of SGR for 10 bpc

This only supports 10 bpc, not 12 bpc, as the sum and tmp buffers can
be int16_t for 10 bpc, but need to be int32_t for 12 bpc.

Make actual templates out of the functions in looprestoration_tmpl.S,
and add box3/5_h to looprestoration16.S.

Extend dav1d_sgr_calc_abX_neon with a mandatory bitdepth_max parameter
(which is passed even in 8bpc mode), add a define to bitdepth.h for
passing such a parameter in all modes. This makes this function
a few instructions slower in 8bpc mode than it was before (overall impact
seems to be around 1% of the total runtime of SGR), but allows using the
same actual function instantiation for all modes, saving a bit of code
size.

Examples of checkasm runtimes:
                           Cortex A53        A72        A73
selfguided_3x3_10bpc_neon:   516755.8   389412.7   349058.7
selfguided_5x5_10bpc_neon:   380699.9   293486.6   254591.6
selfguided_mix_10bpc_neon:   878142.3   667495.9   587844.6

Corresponding 8 bpc numbers for comparison:
selfguided_3x3_8bpc_neon:    491058.1   361473.4   347705.9
selfguided_5x5_8bpc_neon:    352655.0   266423.7   248192.2
selfguided_mix_8bpc_neon:    826094.1   612372.2   581943.1

--- a/include/common/bitdepth.h
+++ b/include/common/bitdepth.h
@@ -56,6 +56,7 @@
 #define HIGHBD_CALL_SUFFIX /* nothing */
 #define HIGHBD_TAIL_SUFFIX /* nothing */
 #define bitdepth_from_max(x) 8
+#define BITDEPTH_MAX 0xff
 #elif BITDEPTH == 16
 typedef uint16_t pixel;
 typedef int32_t coef;
@@ -72,6 +73,7 @@
 #define HIGHBD_CALL_SUFFIX , f->bitdepth_max
 #define HIGHBD_TAIL_SUFFIX , bitdepth_max
 #define bitdepth_from_max(bitdepth_max) (32 - clz(bitdepth_max))
+#define BITDEPTH_MAX bitdepth_max
 #define bitfn(x) x##_16bpc
 #define BF(x, suffix) x##_16bpc_##suffix
 static inline ptrdiff_t PXSTRIDE(const ptrdiff_t x) {
--- a/src/arm/64/looprestoration.S
+++ b/src/arm/64/looprestoration.S
@@ -1148,3 +1148,5 @@
         ret
 .purgem add5
 endfunc
+
+sgr_funcs 8
--- a/src/arm/64/looprestoration16.S
+++ b/src/arm/64/looprestoration16.S
@@ -678,3 +678,562 @@
         .hword L(copy_narrow_tbl) - 60b
         .hword L(copy_narrow_tbl) - 70b
 endfunc
+
+#define SUM_STRIDE (384+16)
+
+#include "looprestoration_tmpl.S"
+
+// void dav1d_sgr_box3_h_16bpc_neon(int32_t *sumsq, int16_t *sum,
+//                                  const pixel (*left)[4],
+//                                  const pixel *src, const ptrdiff_t stride,
+//                                  const int w, const int h,
+//                                  const enum LrEdgeFlags edges);
+function sgr_box3_h_16bpc_neon, export=1 // Horizontal 3-pixel box sums (sum and sumsq), two rows per iteration
+        add             w5,  w5,  #2 // w += 2
+
+        // Set up pointers for reading/writing alternate rows
+        add             x10, x0,  #(4*SUM_STRIDE)   // sumsq, second row
+        add             x11, x1,  #(2*SUM_STRIDE)   // sum, second row
+        add             x12, x3,  x4                // src, second row
+        lsl             x4,  x4,  #1                // stride *= 2
+        mov             x9,       #(2*2*SUM_STRIDE) // double sum stride
+
+        // Subtract the aligned width from the output stride.
+        // With LR_HAVE_RIGHT, align to 8, without it, align to 4.
+        tst             w7,  #2 // LR_HAVE_RIGHT
+        b.ne            0f
+        // !LR_HAVE_RIGHT
+        add             w13, w5,  #3
+        bic             w13, w13, #3
+        b               1f
+0:
+        add             w13, w5,  #7
+        bic             w13, w13, #7
+1:
+        sub             x9,  x9,  w13, uxtw #1
+
+        // Store the width for the vertical loop
+        mov             w8,  w5
+
+        // Subtract the number of pixels read from the input from the stride
+        add             w13, w5,  #14
+        bic             w13, w13, #7
+        sub             x4,  x4,  w13, uxtw #1
+
+        // Set up the src pointers to include the left edge, for LR_HAVE_LEFT, left == NULL
+        tst             w7,  #1 // LR_HAVE_LEFT
+        b.eq            2f
+        // LR_HAVE_LEFT
+        cbnz            x2,  0f
+        // left == NULL
+        sub             x3,  x3,  #4 // 2 pixels, 2 bytes each
+        sub             x12, x12, #4
+        b               1f
+0:      // LR_HAVE_LEFT, left != NULL
+2:      // !LR_HAVE_LEFT, increase the stride.
+        // For this case we don't read the left 2 pixels from the src pointer,
+        // but shift it as if we had done that.
+        add             x4,  x4,  #4
+
+
+1:      // Loop vertically
+        ld1             {v0.8h, v1.8h},   [x3],  #32
+        ld1             {v16.8h, v17.8h}, [x12], #32
+
+        tst             w7,  #1 // LR_HAVE_LEFT
+        b.eq            0f
+        cbz             x2,  2f
+        // LR_HAVE_LEFT, left != NULL
+        ld1             {v2.d}[1],  [x2], #8
+        // Move x3/x12 back to account for the last 2 pixels we loaded earlier,
+        // which we'll shift out.
+        sub             x3,  x3,  #4
+        sub             x12, x12, #4
+        ld1             {v18.d}[1], [x2], #8
+        ext             v1.16b,  v0.16b,  v1.16b,  #12
+        ext             v0.16b,  v2.16b,  v0.16b,  #12
+        ext             v17.16b, v16.16b, v17.16b, #12
+        ext             v16.16b, v18.16b, v16.16b, #12
+        b               2f
+0:
+        // !LR_HAVE_LEFT, fill v2 with the leftmost pixel
+        // and shift v0/v1 to have 2x the first pixel at the front.
+        dup             v2.8h,  v0.h[0]
+        dup             v18.8h, v16.h[0]
+        // Move x3/x12 back to account for the last 2 pixels we loaded before,
+        // which we shifted out.
+        sub             x3,  x3,  #4
+        sub             x12, x12, #4
+        ext             v1.16b,  v0.16b,  v1.16b,  #12
+        ext             v0.16b,  v2.16b,  v0.16b,  #12
+        ext             v17.16b, v16.16b, v17.16b, #12
+        ext             v16.16b, v18.16b, v16.16b, #12
+
+2:
+        umull           v2.4s,   v0.4h,   v0.4h
+        umull2          v3.4s,   v0.8h,   v0.8h
+        umull           v4.4s,   v1.4h,   v1.4h
+        umull           v18.4s,  v16.4h,  v16.4h
+        umull2          v19.4s,  v16.8h,  v16.8h
+        umull           v20.4s,  v17.4h,  v17.4h
+
+        tst             w7,  #2 // LR_HAVE_RIGHT
+        b.ne            4f
+        // If we'll need to pad the right edge, load that pixel to pad with
+        // here since we can find it pretty easily from here.
+        sub             w13, w5, #(2 + 16 - 2 + 1)
+        ldr             h30, [x3,  w13, sxtw #1]
+        ldr             h31, [x12, w13, sxtw #1]
+        // Fill v30/v31 with the right padding pixel
+        dup             v30.8h,  v30.h[0]
+        dup             v31.8h,  v31.h[0]
+3:      // !LR_HAVE_RIGHT
+        // If we'll have to pad the right edge we need to quit early here.
+        cmp             w5,  #10
+        b.ge            4f   // If w >= 10, all used input pixels are valid
+        cmp             w5,  #6
+        b.ge            5f   // If w >= 6, we can filter 4 pixels
+        b               6f
+
+4:      // Loop horizontally
+.macro ext_n            dst1, dst2, src1, src2, src3, n, w
+        ext             \dst1,  \src1,  \src2,  \n
+.if \w > 4
+        ext             \dst2,  \src2,  \src3,  \n
+.endif
+.endm
+.macro add_n            dst1, dst2, src1, src2, src3, src4, w
+        add             \dst1,  \src1,  \src3
+.if \w > 4
+        add             \dst2,  \src2,  \src4
+.endif
+.endm
+
+.macro add3 w, wd
+        ext             v24.16b, v0.16b,  v1.16b,  #2
+        ext             v25.16b, v0.16b,  v1.16b,  #4
+        ext             v26.16b, v16.16b, v17.16b, #2
+        ext             v27.16b, v16.16b, v17.16b, #4
+        add             v6\wd,   v0\wd,   v24\wd
+        add             v7\wd,   v16\wd,  v26\wd
+        add             v6\wd,   v6\wd,   v25\wd
+        add             v7\wd,   v7\wd,   v27\wd
+
+        ext_n           v24.16b, v25.16b, v2.16b,  v3.16b,  v4.16b,  #4, \w
+        ext_n           v26.16b, v27.16b, v2.16b,  v3.16b,  v4.16b,  #8, \w
+
+        add_n           v22.4s,  v23.4s,  v2.4s,   v3.4s,   v24.4s,  v25.4s,  \w
+        add_n           v22.4s,  v23.4s,  v22.4s,  v23.4s,  v26.4s,  v27.4s,  \w
+
+        ext_n           v24.16b, v25.16b, v18.16b, v19.16b, v20.16b, #4, \w
+        ext_n           v26.16b, v27.16b, v18.16b, v19.16b, v20.16b, #8, \w
+
+        add_n           v24.4s,  v25.4s,  v18.4s,  v19.4s,  v24.4s,  v25.4s,  \w
+        add_n           v24.4s,  v25.4s,  v24.4s,  v25.4s,  v26.4s,  v27.4s,  \w
+.endm
+        add3            8, .8h
+        st1             {v6.8h},         [x1],  #16
+        st1             {v7.8h},         [x11], #16
+        st1             {v22.4s,v23.4s}, [x0],  #32
+        st1             {v24.4s,v25.4s}, [x10], #32
+
+        subs            w5,  w5,  #8
+        b.le            9f
+        tst             w7,  #2 // LR_HAVE_RIGHT
+        mov             v0.16b,  v1.16b
+        mov             v16.16b, v17.16b
+        ld1             {v1.8h},  [x3],  #16
+        ld1             {v17.8h}, [x12], #16
+        mov             v2.16b,  v4.16b
+        umull2          v3.4s,   v0.8h,   v0.8h
+        umull           v4.4s,   v1.4h,   v1.4h
+        mov             v18.16b, v20.16b
+        umull2          v19.4s,  v16.8h,  v16.8h
+        umull           v20.4s,  v17.4h,  v17.4h
+
+        b.ne            4b // If we don't need to pad, just keep summing.
+        b               3b // If we need to pad, check how many pixels we have left.
+
+5:      // Produce 4 pixels, 6 <= w < 10
+        add3            4, .4h
+        st1             {v6.4h},  [x1],  #8
+        st1             {v7.4h},  [x11], #8
+        st1             {v22.4s}, [x0],  #16
+        st1             {v24.4s}, [x10], #16
+
+        subs            w5,  w5,  #4 // 2 <= w < 6
+        ext             v0.16b,  v0.16b,  v1.16b,  #8
+        ext             v16.16b, v16.16b, v17.16b, #8
+
+6:      // Pad the right edge and produce the last few pixels.
+        // 2 <= w < 6, 2-5 pixels valid in v0
+        sub             w13,  w5,  #2
+        // w13 = (pixels valid - 2)
+        adr             x14, L(box3_variable_shift_tbl)
+        ldrh            w13, [x14, w13, uxtw #1]
+        sub             x13, x14, w13, uxth
+        br              x13
+        // Shift v0 right, shifting out invalid pixels,
+        // shift v0 left to the original offset, shifting in padding pixels.
+22:     // 2 pixels valid
+        ext             v0.16b,  v0.16b,  v0.16b,  #4
+        ext             v16.16b, v16.16b, v16.16b, #4
+        ext             v0.16b,  v0.16b,  v30.16b, #12
+        ext             v16.16b, v16.16b, v31.16b, #12
+        b               88f
+33:     // 3 pixels valid
+        ext             v0.16b,  v0.16b,  v0.16b,  #6
+        ext             v16.16b, v16.16b, v16.16b, #6
+        ext             v0.16b,  v0.16b,  v30.16b, #10
+        ext             v16.16b, v16.16b, v31.16b, #10
+        b               88f
+44:     // 4 pixels valid
+        ext             v0.16b,  v0.16b,  v0.16b,  #8
+        ext             v16.16b, v16.16b, v16.16b, #8
+        ext             v0.16b,  v0.16b,  v30.16b, #8
+        ext             v16.16b, v16.16b, v31.16b, #8
+        b               88f
+55:     // 5 pixels valid
+        ext             v0.16b,  v0.16b,  v0.16b,  #10
+        ext             v16.16b, v16.16b, v16.16b, #10
+        ext             v0.16b,  v0.16b,  v30.16b, #6
+        ext             v16.16b, v16.16b, v31.16b, #6
+        b               88f
+
+L(box3_variable_shift_tbl):
+        .hword L(box3_variable_shift_tbl) - 22b
+        .hword L(box3_variable_shift_tbl) - 33b
+        .hword L(box3_variable_shift_tbl) - 44b
+        .hword L(box3_variable_shift_tbl) - 55b
+
+88:
+        umull           v2.4s,   v0.4h,   v0.4h
+        umull2          v3.4s,   v0.8h,   v0.8h
+        umull           v18.4s,  v16.4h,  v16.4h
+        umull2          v19.4s,  v16.8h,  v16.8h
+
+        add3            4, .4h
+        subs            w5,  w5,  #4
+        st1             {v6.4h},  [x1],  #8
+        st1             {v7.4h},  [x11], #8
+        st1             {v22.4s}, [x0],  #16
+        st1             {v24.4s}, [x10], #16
+        b.le            9f
+        ext             v0.16b,  v0.16b,  v0.16b,  #8
+        ext             v16.16b, v16.16b, v16.16b, #8
+        mov             v2.16b,  v3.16b
+        mov             v3.16b,  v4.16b
+        mov             v18.16b, v19.16b
+        mov             v19.16b, v20.16b
+        // Only one needed pixel left, but do a normal 4 pixel
+        // addition anyway
+        add3            4, .4h
+        st1             {v6.4h},  [x1],  #8
+        st1             {v7.4h},  [x11], #8
+        st1             {v22.4s}, [x0],  #16
+        st1             {v24.4s}, [x10], #16
+
+9:
+        subs            w6,  w6,  #2 // h -= 2
+        b.le            0f
+        // Jump to the next row and loop horizontally
+        add             x0,  x0,  x9, lsl #1
+        add             x10, x10, x9, lsl #1
+        add             x1,  x1,  x9
+        add             x11, x11, x9
+        add             x3,  x3,  x4
+        add             x12, x12, x4
+        mov             w5,  w8
+        b               1b
+0:
+        ret
+.purgem add3
+endfunc
+
+// void dav1d_sgr_box5_h_16bpc_neon(int32_t *sumsq, int16_t *sum,
+//                                  const pixel (*left)[4],
+//                                  const pixel *src, const ptrdiff_t stride,
+//                                  const int w, const int h,
+//                                  const enum LrEdgeFlags edges);
+function sgr_box5_h_16bpc_neon, export=1 // Horizontal 5-pixel box sums (sum and sumsq), two rows per iteration
+        add             w5,  w5,  #2 // w += 2
+
+        // Set up pointers for reading/writing alternate rows
+        add             x10, x0,  #(4*SUM_STRIDE)   // sumsq, second row
+        add             x11, x1,  #(2*SUM_STRIDE)   // sum, second row
+        add             x12, x3,  x4                // src, second row
+        lsl             x4,  x4,  #1                // stride *= 2
+        mov             x9,       #(2*2*SUM_STRIDE) // double sum stride
+
+        // Subtract the aligned width from the output stride.
+        // With LR_HAVE_RIGHT, align to 8, without it, align to 4.
+        // Subtract the number of pixels read from the input from the stride.
+        tst             w7,  #2 // LR_HAVE_RIGHT
+        b.ne            0f
+        // !LR_HAVE_RIGHT
+        add             w13, w5,  #3
+        bic             w13, w13, #3
+        add             w14, w5,  #13
+        b               1f
+0:
+        add             w13, w5,  #7
+        bic             w13, w13, #7
+        add             w14, w5,  #15
+1:
+        sub             x9,  x9,  w13, uxtw #1
+        bic             w14, w14, #7
+        sub             x4,  x4,  w14, uxtw #1
+
+        // Store the width for the vertical loop
+        mov             w8,  w5
+
+        // Set up the src pointers to include the left edge, for LR_HAVE_LEFT, left == NULL
+        tst             w7,  #1 // LR_HAVE_LEFT
+        b.eq            2f
+        // LR_HAVE_LEFT
+        cbnz            x2,  0f
+        // left == NULL
+        sub             x3,  x3,  #6 // 3 pixels, 2 bytes each
+        sub             x12, x12, #6
+        b               1f
+0:      // LR_HAVE_LEFT, left != NULL
+2:      // !LR_HAVE_LEFT, increase the stride.
+        // For this case we don't read the left 3 pixels from the src pointer,
+        // but shift it as if we had done that.
+        add             x4,  x4,  #6
+
+1:      // Loop vertically
+        ld1             {v0.8h, v1.8h},   [x3],  #32
+        ld1             {v16.8h, v17.8h}, [x12], #32
+
+        tst             w7,  #1 // LR_HAVE_LEFT
+        b.eq            0f
+        cbz             x2,  2f
+        // LR_HAVE_LEFT, left != NULL
+        ld1             {v2.d}[1],  [x2], #8
+        // Move x3/x12 back to account for the last 3 pixels we loaded earlier,
+        // which we'll shift out.
+        sub             x3,  x3,  #6
+        sub             x12, x12, #6
+        ld1             {v18.d}[1],  [x2], #8
+        ext             v1.16b,  v0.16b,  v1.16b,  #10
+        ext             v0.16b,  v2.16b,  v0.16b,  #10
+        ext             v17.16b, v16.16b, v17.16b, #10
+        ext             v16.16b, v18.16b, v16.16b, #10
+        b               2f
+0:
+        // !LR_HAVE_LEFT, fill v2 with the leftmost pixel
+        // and shift v0/v1 to have 3x the first pixel at the front.
+        dup             v2.8h,  v0.h[0]
+        dup             v18.8h, v16.h[0]
+        // Move x3/x12 back to account for the last 3 pixels (6 bytes) we loaded before,
+        // which we shifted out.
+        sub             x3,  x3,  #6
+        sub             x12, x12, #6
+        ext             v1.16b,  v0.16b,  v1.16b,  #10
+        ext             v0.16b,  v2.16b,  v0.16b,  #10
+        ext             v17.16b, v16.16b, v17.16b, #10
+        ext             v16.16b, v18.16b, v16.16b, #10
+
+2:
+        umull           v2.4s,   v0.4h,   v0.4h
+        umull2          v3.4s,   v0.8h,   v0.8h
+        umull           v4.4s,   v1.4h,   v1.4h
+        umull           v18.4s,  v16.4h,  v16.4h
+        umull2          v19.4s,  v16.8h,  v16.8h
+        umull           v20.4s,  v17.4h,  v17.4h
+
+        tst             w7,  #2 // LR_HAVE_RIGHT
+        b.ne            4f
+        // If we'll need to pad the right edge, load that pixel to pad with
+        // here since we can find it pretty easily from here.
+        sub             w13, w5, #(2 + 16 - 3 + 1)
+        ldr             h30, [x3,  w13, sxtw #1]
+        ldr             h31, [x12, w13, sxtw #1]
+        // Fill v30/v31 with the right padding pixel
+        dup             v30.8h,  v30.h[0]
+        dup             v31.8h,  v31.h[0]
+3:      // !LR_HAVE_RIGHT
+        // If we'll have to pad the right edge we need to quit early here.
+        cmp             w5,  #11
+        b.ge            4f   // If w >= 11, all used input pixels are valid
+        cmp             w5,  #7
+        b.ge            5f   // If w >= 7, we can produce 4 pixels
+        b               6f
+
+4:      // Loop horizontally
+.macro add5 w, wd
+        ext             v24.16b, v0.16b,  v1.16b,  #2
+        ext             v25.16b, v0.16b,  v1.16b,  #4
+        ext             v26.16b, v0.16b,  v1.16b,  #6
+        ext             v27.16b, v0.16b,  v1.16b,  #8
+
+        add             v6\wd,   v0\wd,   v24\wd
+        add             v25\wd,  v25\wd,  v26\wd
+        add             v6\wd,   v6\wd,   v27\wd
+
+        ext             v26.16b, v16.16b, v17.16b, #2
+        ext             v27.16b, v16.16b, v17.16b, #4
+        ext             v28.16b, v16.16b, v17.16b, #6
+        ext             v29.16b, v16.16b, v17.16b, #8
+
+        add             v7\wd,   v16\wd,  v26\wd
+        add             v27\wd,  v27\wd,  v28\wd
+        add             v7\wd,   v7\wd,   v29\wd
+        add             v6\wd,   v6\wd,   v25\wd
+        add             v7\wd,   v7\wd,   v27\wd
+
+        ext_n           v24.16b, v25.16b, v2.16b,  v3.16b,  v4.16b,  #4,  \w
+        ext_n           v26.16b, v27.16b, v2.16b,  v3.16b,  v4.16b,  #8,  \w
+        ext_n           v28.16b, v29.16b, v2.16b,  v3.16b,  v4.16b,  #12, \w
+
+        add_n           v22.4s,  v23.4s,  v2.4s,   v3.4s,   v24.4s,  v25.4s,  \w
+        add_n           v26.4s,  v27.4s,  v26.4s,  v27.4s,  v28.4s,  v29.4s,  \w
+        add_n           v22.4s,  v23.4s,  v22.4s,  v23.4s,  v3.4s,   v4.4s,   \w
+        add_n           v22.4s,  v23.4s,  v22.4s,  v23.4s,  v26.4s,  v27.4s,  \w
+
+        ext_n           v24.16b, v25.16b, v18.16b, v19.16b, v20.16b, #4,  \w
+        ext_n           v26.16b, v27.16b, v18.16b, v19.16b, v20.16b, #8,  \w
+        ext_n           v28.16b, v29.16b, v18.16b, v19.16b, v20.16b, #12, \w
+
+        add_n           v24.4s,  v25.4s,  v18.4s,  v19.4s,  v24.4s,  v25.4s,  \w
+        add_n           v26.4s,  v27.4s,  v26.4s,  v27.4s,  v28.4s,  v29.4s,  \w
+        add_n           v24.4s,  v25.4s,  v24.4s,  v25.4s,  v19.4s,  v20.4s,  \w
+        add_n           v24.4s,  v25.4s,  v24.4s,  v25.4s,  v26.4s,  v27.4s,  \w
+.endm
+        add5            8, .8h
+        st1             {v6.8h},         [x1],  #16
+        st1             {v7.8h},         [x11], #16
+        st1             {v22.4s,v23.4s}, [x0],  #32
+        st1             {v24.4s,v25.4s}, [x10], #32
+
+        subs            w5,  w5,  #8
+        b.le            9f
+        tst             w7,  #2 // LR_HAVE_RIGHT
+        mov             v0.16b,  v1.16b
+        mov             v16.16b, v17.16b
+        ld1             {v1.8h},  [x3],  #16
+        ld1             {v17.8h}, [x12], #16
+        mov             v2.16b,  v4.16b
+        umull2          v3.4s,   v0.8h,   v0.8h
+        umull           v4.4s,   v1.4h,   v1.4h
+        mov             v18.16b, v20.16b
+        umull2          v19.4s,  v16.8h,  v16.8h
+        umull           v20.4s,  v17.4h,  v17.4h
+
+        b.ne            4b // If we don't need to pad, just keep summing.
+        b               3b // If we need to pad, check how many pixels we have left.
+
+5:      // Produce 4 pixels, 7 <= w < 11
+        add5            4, .4h
+        st1             {v6.4h},  [x1],  #8
+        st1             {v7.4h},  [x11], #8
+        st1             {v22.4s}, [x0],  #16
+        st1             {v24.4s}, [x10], #16
+
+        subs            w5,  w5,  #4 // 3 <= w < 7
+        ext             v0.16b,  v0.16b,  v1.16b,  #8
+        ext             v16.16b, v16.16b, v17.16b, #8
+
+6:      // Pad the right edge and produce the last few pixels.
+        // w < 7, w+1 pixels valid in v0/v16
+        sub             w13,  w5,  #1
+        // w13 = pixels valid - 2
+        adr             x14, L(box5_variable_shift_tbl)
+        ldrh            w13, [x14, w13, uxtw #1]
+        mov             v1.16b,  v30.16b
+        mov             v17.16b, v31.16b
+        sub             x13, x14, w13, uxth
+        br              x13
+        // Shift v0 right, shifting out invalid pixels,
+        // shift v0 left to the original offset, shifting in padding pixels.
+22:     // 2 pixels valid
+        ext             v0.16b,  v0.16b,  v0.16b,  #4
+        ext             v16.16b, v16.16b, v16.16b, #4
+        ext             v0.16b,  v0.16b,  v30.16b, #12
+        ext             v16.16b, v16.16b, v31.16b, #12
+        b               88f
+33:     // 3 pixels valid
+        ext             v0.16b,  v0.16b,  v0.16b,  #6
+        ext             v16.16b, v16.16b, v16.16b, #6
+        ext             v0.16b,  v0.16b,  v30.16b, #10
+        ext             v16.16b, v16.16b, v31.16b, #10
+        b               88f
+44:     // 4 pixels valid
+        ext             v0.16b,  v0.16b,  v0.16b,  #8
+        ext             v16.16b, v16.16b, v16.16b, #8
+        ext             v0.16b,  v0.16b,  v30.16b, #8
+        ext             v16.16b, v16.16b, v31.16b, #8
+        b               88f
+55:     // 5 pixels valid
+        ext             v0.16b,  v0.16b,  v0.16b,  #10
+        ext             v16.16b, v16.16b, v16.16b, #10
+        ext             v0.16b,  v0.16b,  v30.16b, #6
+        ext             v16.16b, v16.16b, v31.16b, #6
+        b               88f
+66:     // 6 pixels valid
+        ext             v0.16b,  v0.16b,  v0.16b,  #12
+        ext             v16.16b, v16.16b, v16.16b, #12
+        ext             v0.16b,  v0.16b,  v30.16b, #4
+        ext             v16.16b, v16.16b, v31.16b, #4
+        b               88f
+77:     // 7 pixels valid
+        ext             v0.16b,  v0.16b,  v0.16b,  #14
+        ext             v16.16b, v16.16b, v16.16b, #14
+        ext             v0.16b,  v0.16b,  v30.16b, #2
+        ext             v16.16b, v16.16b, v31.16b, #2
+        b               88f
+
+L(box5_variable_shift_tbl):
+        .hword L(box5_variable_shift_tbl) - 22b
+        .hword L(box5_variable_shift_tbl) - 33b
+        .hword L(box5_variable_shift_tbl) - 44b
+        .hword L(box5_variable_shift_tbl) - 55b
+        .hword L(box5_variable_shift_tbl) - 66b
+        .hword L(box5_variable_shift_tbl) - 77b
+
+88:
+        umull           v2.4s,   v0.4h,   v0.4h
+        umull2          v3.4s,   v0.8h,   v0.8h
+        umull           v4.4s,   v1.4h,   v1.4h
+        umull           v18.4s,  v16.4h,  v16.4h
+        umull2          v19.4s,  v16.8h,  v16.8h
+        umull           v20.4s,  v17.4h,  v17.4h
+
+        add5            4, .4h
+        subs            w5,  w5,  #4
+        st1             {v6.4h},  [x1],  #8
+        st1             {v7.4h},  [x11], #8
+        st1             {v22.4s}, [x0],  #16
+        st1             {v24.4s}, [x10], #16
+        b.le            9f
+        ext             v0.16b,  v0.16b,  v1.16b,  #8
+        ext             v16.16b, v16.16b, v17.16b, #8
+        mov             v2.16b,  v3.16b
+        mov             v3.16b,  v4.16b
+        mov             v18.16b, v19.16b
+        mov             v19.16b, v20.16b
+        add5            4, .4h
+        st1             {v6.4h},  [x1],  #8
+        st1             {v7.4h},  [x11], #8
+        st1             {v22.4s}, [x0],  #16
+        st1             {v24.4s}, [x10], #16
+
+9:
+        subs            w6,  w6,  #2 // h -= 2
+        b.le            0f
+        // Jump to the next row and loop horizontally
+        add             x0,  x0,  x9, lsl #1
+        add             x10, x10, x9, lsl #1
+        add             x1,  x1,  x9
+        add             x11, x11, x9
+        add             x3,  x3,  x4
+        add             x12, x12, x4
+        mov             w5,  w8
+        b               1b
+0:
+        ret
+.purgem add5
+endfunc
+
+sgr_funcs 16
--- a/src/arm/64/looprestoration_common.S
+++ b/src/arm/64/looprestoration_common.S
@@ -328,10 +328,13 @@
 endfunc
 
 // void dav1d_sgr_calc_ab1_neon(int32_t *a, int16_t *b,
-//                              const int w, const int h, const int strength);
+//                              const int w, const int h, const int strength,
+//                              const int bitdepth_max);
 // void dav1d_sgr_calc_ab2_neon(int32_t *a, int16_t *b,
-//                              const int w, const int h, const int strength);
+//                              const int w, const int h, const int strength,
+//                              const int bitdepth_max);
 function sgr_calc_ab1_neon, export=1
+        clz             w9,  w5
         add             x3,  x3,  #2 // h += 2
         movi            v31.4s,   #9 // n
         mov             x5,  #455
@@ -340,6 +343,7 @@
 endfunc
 
 function sgr_calc_ab2_neon, export=1
+        clz             w9,  w5
         add             x3,  x3,  #3  // h += 3
         asr             x3,  x3,  #1  // h /= 2
         movi            v31.4s,   #25 // n
@@ -348,8 +352,10 @@
 endfunc
 
 function sgr_calc_ab_neon
+        sub             w9,  w9,  #24  // -bitdepth_min_8
         movrel          x12, X(sgr_x_by_x)
         ld1             {v16.16b, v17.16b, v18.16b}, [x12]
+        dup             v6.8h,    w9   // -bitdepth_min_8
         movi            v19.16b,  #5
         movi            v20.8b,   #55  // idx of last 5
         movi            v21.8b,   #72  // idx of last 4
@@ -356,6 +362,7 @@
         movi            v22.8b,   #101 // idx of last 3
         movi            v23.8b,   #169 // idx of last 2
         movi            v24.8b,   #254 // idx of last 1
+        saddl           v7.4s,    v6.4h,   v6.4h  // -2*bitdepth_min_8
         add             x2,  x2,  #2 // w += 2
         add             x7,  x2,  #7
         bic             x7,  x7,  #7 // aligned w
@@ -373,10 +380,13 @@
         subs            x2,  x2,  #8
         ld1             {v0.4s, v1.4s}, [x0]   // a
         ld1             {v2.8h}, [x1]          // b
+        srshl           v0.4s,  v0.4s,  v7.4s
+        srshl           v1.4s,  v1.4s,  v7.4s
+        srshl           v4.8h,  v2.8h,  v6.8h
         mul             v0.4s,  v0.4s,  v31.4s // a * n
         mul             v1.4s,  v1.4s,  v31.4s // a * n
-        umull           v3.4s,  v2.4h,  v2.4h  // b * b
-        umull2          v4.4s,  v2.8h,  v2.8h  // b * b
+        umull           v3.4s,  v4.4h,  v4.4h  // b * b
+        umull2          v4.4s,  v4.8h,  v4.8h  // b * b
         uqsub           v0.4s,  v0.4s,  v3.4s  // imax(a * n - b * b, 0)
         uqsub           v1.4s,  v1.4s,  v4.4s  // imax(a * n - b * b, 0)
         mul             v0.4s,  v0.4s,  v28.4s // p * s
@@ -389,13 +399,13 @@
         cmhi            v26.8b, v0.8b,  v21.8b // = -1 if sgr_x_by_x[v0] < 4
         tbl             v1.8b, {v16.16b, v17.16b, v18.16b}, v0.8b
         cmhi            v27.8b, v0.8b,  v22.8b // = -1 if sgr_x_by_x[v0] < 3
-        cmhi            v5.8b,  v0.8b,  v23.8b // = -1 if sgr_x_by_x[v0] < 2
+        cmhi            v4.8b,  v0.8b,  v23.8b // = -1 if sgr_x_by_x[v0] < 2
         add             v25.8b, v25.8b, v26.8b
-        cmhi            v6.8b,  v0.8b,  v24.8b // = -1 if sgr_x_by_x[v0] < 1
-        add             v27.8b, v27.8b, v5.8b
-        add             v6.8b,  v6.8b,  v19.8b
+        cmhi            v5.8b,  v0.8b,  v24.8b // = -1 if sgr_x_by_x[v0] < 1
+        add             v27.8b, v27.8b, v4.8b
+        add             v5.8b,  v5.8b,  v19.8b
         add             v25.8b, v25.8b, v27.8b
-        add             v1.8b,  v1.8b,  v6.8b
+        add             v1.8b,  v1.8b,  v5.8b
         add             v1.8b,  v1.8b,  v25.8b
         uxtl            v1.8h,  v1.8b          // x
 
--- a/src/arm/64/looprestoration_tmpl.S
+++ b/src/arm/64/looprestoration_tmpl.S
@@ -29,11 +29,12 @@
 
 #define FILTER_OUT_STRIDE 384
 
-// void dav1d_sgr_finish_filter1_8bpc_neon(int16_t *tmp,
+.macro sgr_funcs bpc
+// void dav1d_sgr_finish_filter1_Xbpc_neon(int16_t *tmp,
 //                                         const pixel *src, const ptrdiff_t stride,
 //                                         const int32_t *a, const int16_t *b,
 //                                         const int w, const int h);
-function sgr_finish_filter1_8bpc_neon, export=1
+function sgr_finish_filter1_\bpc\()bpc_neon, export=1
         sub             x7,  x3,  #(4*SUM_STRIDE)
         add             x8,  x3,  #(4*SUM_STRIDE)
         sub             x9,  x4,  #(2*SUM_STRIDE)
@@ -42,7 +43,11 @@
         mov             x12, #FILTER_OUT_STRIDE
         add             x13, x5,  #7
         bic             x13, x13, #7 // Aligned width
+.if \bpc == 8
         sub             x2,  x2,  x13
+.else
+        sub             x2,  x2,  x13, lsl #1
+.endif
         sub             x12, x12, x13
         sub             x11, x11, x13
         sub             x11, x11, #4 // We read 4 extra elements from a
@@ -98,7 +103,11 @@
         ext             v28.16b, v23.16b, v24.16b, #4
         ext             v29.16b, v22.16b, v23.16b, #8 // +1+stride
         ext             v30.16b, v23.16b, v24.16b, #8
+.if \bpc == 8
         ld1             {v19.8b}, [x1], #8            // src
+.else
+        ld1             {v19.8h}, [x1], #16           // src
+.endif
         add             v25.4s,  v25.4s,  v27.4s      // +stride
         add             v26.4s,  v26.4s,  v28.4s
         add             v16.4s,  v16.4s,  v29.4s      // +1+stride
@@ -107,7 +116,9 @@
         shl             v26.4s,  v26.4s,  #2
         mla             v25.4s,  v16.4s,  v7.4s       // * 3 -> b
         mla             v26.4s,  v17.4s,  v7.4s
+.if \bpc == 8
         uxtl            v19.8h,  v19.8b               // src
+.endif
         mov             v0.16b,  v1.16b
         umlal           v25.4s,  v2.4h,   v19.4h      // b + a * src
         umlal2          v26.4s,  v2.8h,   v19.8h
@@ -146,11 +157,11 @@
         ret
 endfunc
 
-// void dav1d_sgr_finish_filter2_8bpc_neon(int16_t *tmp,
+// void dav1d_sgr_finish_filter2_Xbpc_neon(int16_t *tmp,
 //                                         const pixel *src, const ptrdiff_t stride,
 //                                         const int32_t *a, const int16_t *b,
 //                                         const int w, const int h);
-function sgr_finish_filter2_8bpc_neon, export=1
+function sgr_finish_filter2_\bpc\()bpc_neon, export=1
         add             x7,  x3,  #(4*(SUM_STRIDE))
         sub             x3,  x3,  #(4*(SUM_STRIDE))
         add             x8,  x4,  #(2*(SUM_STRIDE))
@@ -159,7 +170,11 @@
         mov             x10, #FILTER_OUT_STRIDE
         add             x11, x5,  #7
         bic             x11, x11, #7 // Aligned width
+.if \bpc == 8
         sub             x2,  x2,  x11
+.else
+        sub             x2,  x2,  x11, lsl #1
+.endif
         sub             x10, x10, x11
         sub             x9,  x9,  x11
         sub             x9,  x9,  #4 // We read 4 extra elements from a
@@ -196,7 +211,11 @@
         ext             v29.16b, v20.16b, v21.16b, #8
         mul             v0.8h,   v0.8h,   v4.8h       // * 5
         mla             v0.8h,   v2.8h,   v6.8h       // * 6
+.if \bpc == 8
         ld1             {v31.8b}, [x1], #8
+.else
+        ld1             {v31.8h}, [x1], #16
+.endif
         add             v16.4s,  v16.4s,  v26.4s      // -1-stride, +1-stride
         add             v17.4s,  v17.4s,  v27.4s
         add             v19.4s,  v19.4s,  v28.4s      // -1+stride, +1+stride
@@ -213,7 +232,9 @@
         mul             v17.4s,  v17.4s,  v5.4s       // * 5
         mla             v17.4s,  v23.4s,  v7.4s       // * 6
 
+.if \bpc == 8
         uxtl            v31.8h,  v31.8b
+.endif
         umlal           v16.4s,  v0.4h,   v31.4h      // b + a * src
         umlal2          v17.4s,  v0.8h,   v31.8h
         mov             v0.16b,  v1.16b
@@ -259,10 +280,16 @@
         ext             v27.16b, v17.16b, v18.16b, #8
         mul             v2.8h,   v22.8h,  v6.8h       // * 6
         mla             v2.8h,   v0.8h,   v4.8h       // * 5 -> a
+.if \bpc == 8
         ld1             {v31.8b}, [x1], #8
+.else
+        ld1             {v31.8h}, [x1], #16
+.endif
         add             v16.4s,  v16.4s,  v26.4s      // -1, +1
         add             v17.4s,  v17.4s,  v27.4s
+.if \bpc == 8
         uxtl            v31.8h,  v31.8b
+.endif
         // This is, surprisingly, faster than other variants where the
         // mul+mla pairs are further apart, on Cortex A53.
         mul             v24.4s,  v24.4s,  v7.4s       // * 6
@@ -296,13 +323,19 @@
         ret
 endfunc
 
-// void dav1d_sgr_weighted1_8bpc_neon(pixel *dst, const ptrdiff_t dst_stride,
+// void dav1d_sgr_weighted1_Xbpc_neon(pixel *dst, const ptrdiff_t dst_stride,
 //                                    const pixel *src, const ptrdiff_t src_stride,
 //                                    const int16_t *t1, const int w, const int h,
-//                                    const int wt);
-function sgr_weighted1_8bpc_neon, export=1
+//                                    const int wt, const int bitdepth_max);
+function sgr_weighted1_\bpc\()bpc_neon, export=1
+.if \bpc == 16
+        ldr             w8,  [sp]
+.endif
         dup             v31.8h, w7
         cmp             x6,  #2
+.if \bpc == 16
+        dup             v30.8h, w8
+.endif
         add             x9,  x0,  x1
         add             x10, x2,  x3
         add             x11, x4,  #2*FILTER_OUT_STRIDE
@@ -311,19 +344,34 @@
         lsl             x3,  x3,  #1
         add             x8,  x5,  #7
         bic             x8,  x8,  #7 // Aligned width
+.if \bpc == 8
         sub             x1,  x1,  x8
         sub             x3,  x3,  x8
+.else
+        sub             x1,  x1,  x8, lsl #1
+        sub             x3,  x3,  x8, lsl #1
+.endif
         sub             x7,  x7,  x8, lsl #1
         mov             x8,  x5
         b.lt            2f
 1:
+.if \bpc == 8
         ld1             {v0.8b}, [x2],  #8
         ld1             {v4.8b}, [x10], #8
+.else
+        ld1             {v0.8h}, [x2],  #16
+        ld1             {v4.8h}, [x10], #16
+.endif
         ld1             {v1.8h}, [x4],  #16
         ld1             {v5.8h}, [x11], #16
         subs            x5,  x5,  #8
+.if \bpc == 8
         ushll           v0.8h,  v0.8b,  #4     // u
         ushll           v4.8h,  v4.8b,  #4     // u
+.else
+        shl             v0.8h,  v0.8h,  #4     // u
+        shl             v4.8h,  v4.8h,  #4     // u
+.endif
         sub             v1.8h,  v1.8h,  v0.8h  // t1 - u
         sub             v5.8h,  v5.8h,  v4.8h  // t1 - u
         ushll           v2.4s,  v0.4h,  #7     // u << 7
@@ -334,6 +382,7 @@
         smlal2          v3.4s,  v1.8h,  v31.8h // v
         smlal           v6.4s,  v5.4h,  v31.4h // v
         smlal2          v7.4s,  v5.8h,  v31.8h // v
+.if \bpc == 8
         rshrn           v2.4h,  v2.4s,  #11
         rshrn2          v2.8h,  v3.4s,  #11
         rshrn           v6.4h,  v6.4s,  #11
@@ -342,6 +391,16 @@
         sqxtun          v6.8b,  v6.8h
         st1             {v2.8b}, [x0], #8
         st1             {v6.8b}, [x9], #8
+.else
+        sqrshrun        v2.4h,  v2.4s,  #11
+        sqrshrun2       v2.8h,  v3.4s,  #11
+        sqrshrun        v6.4h,  v6.4s,  #11
+        sqrshrun2       v6.8h,  v7.4s,  #11
+        umin            v2.8h,  v2.8h,  v30.8h
+        umin            v6.8h,  v6.8h,  v30.8h
+        st1             {v2.8h}, [x0], #16
+        st1             {v6.8h}, [x9], #16
+.endif
         b.gt            1b
 
         sub             x6,  x6,  #2
@@ -358,31 +417,50 @@
         b               1b
 
 2:
+.if \bpc == 8
         ld1             {v0.8b}, [x2], #8
+.else
+        ld1             {v0.8h}, [x2], #16
+.endif
         ld1             {v1.8h}, [x4], #16
         subs            x5,  x5,  #8
+.if \bpc == 8
         ushll           v0.8h,  v0.8b,  #4     // u
+.else
+        shl             v0.8h,  v0.8h,  #4     // u
+.endif
         sub             v1.8h,  v1.8h,  v0.8h  // t1 - u
         ushll           v2.4s,  v0.4h,  #7     // u << 7
         ushll2          v3.4s,  v0.8h,  #7     // u << 7
         smlal           v2.4s,  v1.4h,  v31.4h // v
         smlal2          v3.4s,  v1.8h,  v31.8h // v
+.if \bpc == 8
         rshrn           v2.4h,  v2.4s,  #11
         rshrn2          v2.8h,  v3.4s,  #11
         sqxtun          v2.8b,  v2.8h
         st1             {v2.8b}, [x0], #8
+.else
+        sqrshrun        v2.4h,  v2.4s,  #11
+        sqrshrun2       v2.8h,  v3.4s,  #11
+        umin            v2.8h,  v2.8h,  v30.8h
+        st1             {v2.8h}, [x0], #16
+.endif
         b.gt            2b
 0:
         ret
 endfunc
 
-// void dav1d_sgr_weighted2_8bpc_neon(pixel *dst, const ptrdiff_t stride,
+// void dav1d_sgr_weighted2_Xbpc_neon(pixel *dst, const ptrdiff_t stride,
 //                                    const pixel *src, const ptrdiff_t src_stride,
 //                                    const int16_t *t1, const int16_t *t2,
 //                                    const int w, const int h,
-//                                    const int16_t wt[2]);
+//                                    const int16_t wt[2], const int bitdepth_max);
-function sgr_weighted2_8bpc_neon, export=1
+function sgr_weighted2_\bpc\()bpc_neon, export=1
+.if \bpc == 8
         ldr             x8,  [sp]
+.else
+        ldp             x8,  x9,  [sp]
+.endif
         cmp             x7,  #2
         add             x10, x0,  x1
         add             x11, x2,  x3
@@ -389,26 +467,44 @@
         add             x12, x4,  #2*FILTER_OUT_STRIDE
         add             x13, x5,  #2*FILTER_OUT_STRIDE
         ld2r            {v30.8h, v31.8h}, [x8] // wt[0], wt[1]
+.if \bpc == 16
+        dup             v29.8h,  w9
+.endif
         mov             x8,  #4*FILTER_OUT_STRIDE
         lsl             x1,  x1,  #1
         lsl             x3,  x3,  #1
         add             x9,  x6,  #7
         bic             x9,  x9,  #7 // Aligned width
+.if \bpc == 8
         sub             x1,  x1,  x9
         sub             x3,  x3,  x9
+.else
+        sub             x1,  x1,  x9, lsl #1
+        sub             x3,  x3,  x9, lsl #1
+.endif
         sub             x8,  x8,  x9, lsl #1
         mov             x9,  x6
         b.lt            2f
 1:
+.if \bpc == 8
         ld1             {v0.8b},  [x2],  #8
         ld1             {v16.8b}, [x11], #8
+.else
+        ld1             {v0.8h},  [x2],  #16
+        ld1             {v16.8h}, [x11], #16
+.endif
         ld1             {v1.8h},  [x4],  #16
         ld1             {v17.8h}, [x12], #16
         ld1             {v2.8h},  [x5],  #16
         ld1             {v18.8h}, [x13], #16
         subs            x6,  x6,  #8
+.if \bpc == 8
         ushll           v0.8h,  v0.8b,  #4     // u
         ushll           v16.8h, v16.8b, #4     // u
+.else
+        shl             v0.8h,  v0.8h,  #4     // u
+        shl             v16.8h, v16.8h, #4     // u
+.endif
         sub             v1.8h,  v1.8h,  v0.8h  // t1 - u
         sub             v2.8h,  v2.8h,  v0.8h  // t2 - u
         sub             v17.8h, v17.8h, v16.8h // t1 - u
@@ -425,6 +521,7 @@
         smlal           v19.4s, v18.4h, v31.4h // wt[1] * (t2 - u)
         smlal2          v20.4s, v17.8h, v30.8h // wt[0] * (t1 - u)
         smlal2          v20.4s, v18.8h, v31.8h // wt[1] * (t2 - u)
+.if \bpc == 8
         rshrn           v3.4h,  v3.4s,  #11
         rshrn2          v3.8h,  v4.4s,  #11
         rshrn           v19.4h, v19.4s, #11
@@ -433,6 +530,16 @@
         sqxtun          v19.8b, v19.8h
         st1             {v3.8b},  [x0],  #8
         st1             {v19.8b}, [x10], #8
+.else
+        sqrshrun        v3.4h,  v3.4s,  #11
+        sqrshrun2       v3.8h,  v4.4s,  #11
+        sqrshrun        v19.4h, v19.4s, #11
+        sqrshrun2       v19.8h, v20.4s, #11
+        umin            v3.8h,  v3.8h,  v29.8h
+        umin            v19.8h, v19.8h, v29.8h
+        st1             {v3.8h},  [x0],  #16
+        st1             {v19.8h}, [x10], #16
+.endif
         b.gt            1b
 
         subs            x7,  x7,  #2
@@ -451,11 +558,19 @@
         b               1b
 
 2:
+.if \bpc == 8
         ld1             {v0.8b}, [x2], #8
+.else
+        ld1             {v0.8h}, [x2], #16
+.endif
         ld1             {v1.8h}, [x4], #16
         ld1             {v2.8h}, [x5], #16
         subs            x6,  x6,  #8
+.if \bpc == 8
         ushll           v0.8h,  v0.8b,  #4     // u
+.else
+        shl             v0.8h,  v0.8h,  #4     // u
+.endif
         sub             v1.8h,  v1.8h,  v0.8h  // t1 - u
         sub             v2.8h,  v2.8h,  v0.8h  // t2 - u
         ushll           v3.4s,  v0.4h,  #7     // u << 7
@@ -464,11 +579,19 @@
         smlal           v3.4s,  v2.4h,  v31.4h // wt[1] * (t2 - u)
         smlal2          v4.4s,  v1.8h,  v30.8h // wt[0] * (t1 - u)
         smlal2          v4.4s,  v2.8h,  v31.8h // wt[1] * (t2 - u)
+.if \bpc == 8
         rshrn           v3.4h,  v3.4s,  #11
         rshrn2          v3.8h,  v4.4s,  #11
         sqxtun          v3.8b,  v3.8h
         st1             {v3.8b}, [x0], #8
+.else
+        sqrshrun        v3.4h,  v3.4s,  #11
+        sqrshrun2       v3.8h,  v4.4s,  #11
+        umin            v3.8h,  v3.8h,  v29.8h
+        st1             {v3.8h}, [x0], #16
+.endif
-        b.gt            1b
+        b.gt            2b                     // stay in the single-row loop; 1b is the two-row loop
 0:
         ret
 endfunc
+.endm
--- a/src/arm/looprestoration_init_tmpl.c
+++ b/src/arm/looprestoration_init_tmpl.c
@@ -104,9 +104,7 @@
         BF(dav1d_copy_narrow, neon)(dst + (w & ~7), dst_stride, tmp, w & 7, h);
     }
 }
-#endif
 
-#if BITDEPTH == 8
 void BF(dav1d_sgr_box3_h, neon)(int32_t *sumsq, int16_t *sum,
                                 const pixel (*left)[4],
                                 const pixel *src, const ptrdiff_t stride,
@@ -116,7 +114,8 @@
                            const int w, const int h,
                            const enum LrEdgeFlags edges);
 void dav1d_sgr_calc_ab1_neon(int32_t *a, int16_t *b,
-                             const int w, const int h, const int strength);
+                             const int w, const int h, const int strength,
+                             const int bitdepth_max);
 void BF(dav1d_sgr_finish_filter1, neon)(int16_t *tmp,
                                         const pixel *src, const ptrdiff_t stride,
                                         const int32_t *a, const int16_t *b,
@@ -147,7 +146,7 @@
                                    lpf_stride, w, 2, edges);
 
     dav1d_sgr_box3_v_neon(sumsq, sum, w, h, edges);
-    dav1d_sgr_calc_ab1_neon(a, b, w, h, strength);
+    dav1d_sgr_calc_ab1_neon(a, b, w, h, strength, BITDEPTH_MAX);
     BF(dav1d_sgr_finish_filter1, neon)(tmp, src, stride, a, b, w, h);
 }
 
@@ -160,7 +159,8 @@
                            const int w, const int h,
                            const enum LrEdgeFlags edges);
 void dav1d_sgr_calc_ab2_neon(int32_t *a, int16_t *b,
-                             const int w, const int h, const int strength);
+                             const int w, const int h, const int strength,
+                             const int bitdepth_max);
 void BF(dav1d_sgr_finish_filter2, neon)(int16_t *tmp,
                                         const pixel *src, const ptrdiff_t stride,
                                         const int32_t *a, const int16_t *b,
@@ -191,7 +191,7 @@
                                    lpf_stride, w, 2, edges);
 
     dav1d_sgr_box5_v_neon(sumsq, sum, w, h, edges);
-    dav1d_sgr_calc_ab2_neon(a, b, w, h, strength);
+    dav1d_sgr_calc_ab2_neon(a, b, w, h, strength, BITDEPTH_MAX);
     BF(dav1d_sgr_finish_filter2, neon)(tmp, src, stride, a, b, w, h);
 }
 
@@ -292,8 +292,7 @@
 
 #if BITDEPTH == 8 || ARCH_AARCH64
     c->wiener = wiener_filter_neon;
-#endif
-#if BITDEPTH == 8
-    c->selfguided = sgr_filter_neon;
+    if (bpc <= 10)
+        c->selfguided = sgr_filter_neon;
 #endif
 }