shithub: dav1d

Download patch

ref: 4a499fd51ad6e650d067fdfd6cce07f7209c54c4
parent: bfdfd1aa1dfe4a067a2887be53cd335e88b52308
author: Ronald S. Bultje <[email protected]>
date: Fri Nov 2 11:39:35 EDT 2018

Add AVX2 implementation for SGR looprestoration

Total decoding time for first 1000 frames of TwxVOYxoukU:
after: 0m3.761s
before: 0m6.868s

Cycle times:
selfguided_3x3_8bpc_c: 438865.8
selfguided_3x3_8bpc_avx2: 112522.6
selfguided_5x5_8bpc_c: 326938.3
selfguided_5x5_8bpc_avx2: 75850.1
selfguided_mix_8bpc_c: 755980.5
selfguided_mix_8bpc_avx2: 195930.3

--- a/src/tables.c
+++ b/src/tables.c
@@ -502,7 +502,7 @@
     { 2, 0,  22,   -1 },
 };
 
-const int16_t dav1d_sgr_x_by_xplus1[256] = {
+const int dav1d_sgr_x_by_xplus1[256] = {
   1,   128, 171, 192, 205, 213, 219, 224, 228, 230, 233, 235, 236, 238, 239,
   240, 241, 242, 243, 243, 244, 244, 245, 245, 246, 246, 247, 247, 247, 247,
   248, 248, 248, 248, 249, 249, 249, 249, 249, 250, 250, 250, 250, 250, 250,
--- a/src/tables.h
+++ b/src/tables.h
@@ -107,7 +107,7 @@
 extern const WarpedMotionParams dav1d_default_wm_params;
 
 extern const int16_t dav1d_sgr_params[16][4];
-extern const int16_t dav1d_sgr_x_by_xplus1[256];
+extern const int dav1d_sgr_x_by_xplus1[256];
 
 extern const int8_t dav1d_mc_subpel_filters[5][15][8];
 extern const int8_t dav1d_mc_warp_filter[][8];
--- a/src/x86/looprestoration.asm
+++ b/src/x86/looprestoration.asm
@@ -36,11 +36,21 @@
 pb_0_to_15_min_n: db 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 13, 13
                   db 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 14
 pb_15: times 16 db 15
+pw_16: times 2 dw 16
+pw_256: times 2 dw 256
 pw_2048: times 2 dw 2048
 pw_16380: times 2 dw 16380
 pw_0_128: dw 0, 128
+pw_5_6: dw 5, 6
+pw_82: times 2 dw 82
+pw_91_5: dw 91, 5
+pd_6: dd 6
+pd_255: dd 255
 pd_1024: dd 1024
+pd_0x80000: dd 0x80000
 
+cextern sgr_x_by_xplus1
+
 SECTION .text
 
 INIT_YMM avx2
@@ -302,5 +312,845 @@
     add         midq, 32
     sub           wd, 16
     jg .loop_x
+    RET
+
+INIT_YMM avx2                                      ; sgr_box3_h: per-row sums (16-bit) and sums of squares (32-bit) over 3 adjacent 8-bit px
+cglobal sgr_box3_h, 8, 11, 8, sumsq, sum, left, src, stride, w, h, edge, x, xlim
+    mov        xlimd, edged
+    and        xlimd, 2                             ; have_right
+    add           wd, xlimd
+    xor        xlimd, 2                             ; 2*!have_right
+    jnz .no_right
+    add           wd, 15                            ; have_right: round w up to a multiple of 16
+    and           wd, ~15
+.no_right:
+    pxor          m1, m1                            ; m1 = 0, used for zero-extension below
+    lea         srcq, [srcq+wq]                     ; point at the row end; x runs from -w upwards
+    lea         sumq, [sumq+wq*2-2]
+    lea       sumsqq, [sumsqq+wq*4-4]
+    neg           wq
+    lea          r10, [pb_right_ext_mask+32]        ; mask table used by the partial right-edge load
+.loop_y:
+    mov           xq, wq
+
+    ; load left
+    test       edged, 1                             ; have_left
+    jz .no_left
+    test       leftq, leftq
+    jz .load_left_from_main
+    pinsrw       xm0, [leftq+2], 7
+    add        leftq, 4
+    jmp .expand_x
+.no_left:
+    vpbroadcastb xm0, [srcq+xq]                     ; no left edge: replicate the first px of the row
+    jmp .expand_x
+.load_left_from_main:
+    pinsrw       xm0, [srcq+xq-2], 7
+.expand_x:
+    punpckhbw    xm0, xm1                           ; zero-extend the high 8 bytes to words
+
+    ; when we reach this, xm0 contains left two px in highest words
+    cmp           xq, -16
+    jle .loop_x
+.partial_load_and_extend:
+    vpbroadcastb  m3, [srcq-1]                      ; byte just before the row-end pointer
+    pmovzxbw      m2, [srcq+xq]
+    punpcklbw     m3, m1
+    movu          m4, [r10+xq*2]
+    pand          m2, m4                            ; keep loaded px where the mask is set
+    pandn         m4, m3                            ; use the broadcast px elsewhere
+    por           m2, m4
+    jmp .loop_x_noload
+.right_extend:
+    psrldq       xm2, xm0, 14                       ; highest word of the previous load
+    vpbroadcastw  m2, xm2
+    jmp .loop_x_noload
+
+.loop_x:
+    pmovzxbw      m2, [srcq+xq]                     ; 16 px, zero-extended to words
+.loop_x_noload:
+    vinserti128   m0, xm2, 1                        ; m0 = [previous 8 px | low 8 px of m2]
+    palignr       m3, m2, m0, 12                    ; px at x-2
+    palignr       m4, m2, m0, 14                    ; px at x-1
+
+    punpcklwd     m5, m3, m2
+    punpckhwd     m6, m3, m2
+    paddw         m3, m4
+    punpcklwd     m7, m4, m1
+    punpckhwd     m4, m1
+    pmaddwd       m5, m5                            ; px[x-2]^2 + px[x]^2
+    pmaddwd       m6, m6
+    pmaddwd       m7, m7                            ; px[x-1]^2
+    pmaddwd       m4, m4
+    paddd         m5, m7
+    paddd         m6, m4
+    paddw         m3, m2                            ; px[x-2] + px[x-1] + px[x]
+    movu [sumq+xq*2], m3
+    movu [sumsqq+xq*4+ 0], xm5
+    movu [sumsqq+xq*4+16], xm6
+    vextracti128 [sumsqq+xq*4+32], m5, 1
+    vextracti128 [sumsqq+xq*4+48], m6, 1
+
+    vextracti128 xm0, m2, 1                         ; keep the last 8 px as context for the next step
+    add           xq, 16
+
+    ; if x <= -16 we can reload more pixels
+    ; else if x < 0 we reload and extend (this implies have_right=0)
+    ; else if x < xlimd we extend from previous load (this implies have_right=0)
+    ; else we are done
+
+    cmp           xq, -16
+    jle .loop_x
+    test          xq, xq
+    jl .partial_load_and_extend
+    cmp           xq, xlimq
+    jl .right_extend
+
+    add       sumsqq, (384+16)*4                    ; advance one row; rows are (384+16) entries apart
+    add         sumq, (384+16)*2
+    add         srcq, strideq
+    dec hd
+    jg .loop_y
+    RET
+
+INIT_YMM avx2                                      ; sgr_box3_v: in-place vertical 3-row sum of sumsq (32-bit) and sum (16-bit)
+cglobal sgr_box3_v, 5, 10, 9, sumsq, sum, w, h, edge, x, y, sumsq_ptr, sum_ptr, ylim
+    mov           xq, -2
+    mov        ylimd, edged
+    and        ylimd, 8                             ; have_bottom
+    shr        ylimd, 2
+    sub        ylimd, 2                             ; -2 if have_bottom=0, else 0
+.loop_x:
+    lea           yd, [hd+ylimd+2]
+    lea   sumsq_ptrq, [sumsqq+xq*4+4-(384+16)*4]    ; start one row above the first row
+    lea     sum_ptrq, [sumq+xq*2+2-(384+16)*2]
+    test       edged, 4                             ; have_top
+    jnz .load_top
+    movu          m0, [sumsq_ptrq+(384+16)*4*1]     ; no top: use the first row for all three inputs
+    movu          m1, [sumsq_ptrq+(384+16)*4*1+32]
+    mova          m2, m0
+    mova          m3, m1
+    mova          m4, m0
+    mova          m5, m1
+    movu          m6, [sum_ptrq+(384+16)*2*1]
+    mova          m7, m6
+    mova          m8, m6
+    jmp .loop_y_noload
+.load_top:
+    movu          m0, [sumsq_ptrq-(384+16)*4*1]      ; l2sq [left]
+    movu          m1, [sumsq_ptrq-(384+16)*4*1+32]   ; l2sq [right]
+    movu          m2, [sumsq_ptrq-(384+16)*4*0]      ; l1sq [left]
+    movu          m3, [sumsq_ptrq-(384+16)*4*0+32]   ; l1sq [right]
+    movu          m6, [sum_ptrq-(384+16)*2*1]        ; l2
+    movu          m7, [sum_ptrq-(384+16)*2*0]        ; l1
+.loop_y:
+    movu          m4, [sumsq_ptrq+(384+16)*4*1]      ; l0sq [left]
+    movu          m5, [sumsq_ptrq+(384+16)*4*1+32]   ; l0sq [right]
+    movu          m8, [sum_ptrq+(384+16)*2*1]        ; l0
+.loop_y_noload:
+    paddd         m0, m2
+    paddd         m1, m3
+    paddw         m6, m7
+    paddd         m0, m4                            ; l2sq + l1sq + l0sq
+    paddd         m1, m5
+    paddw         m6, m8                            ; l2 + l1 + l0
+    movu [sumsq_ptrq+ 0], m0
+    movu [sumsq_ptrq+32], m1
+    movu  [sum_ptrq], m6
+
+    ; shift position down by one
+    mova          m0, m2
+    mova          m1, m3
+    mova          m2, m4
+    mova          m3, m5
+    mova          m6, m7
+    mova          m7, m8
+    add   sumsq_ptrq, (384+16)*4
+    add     sum_ptrq, (384+16)*2
+    dec           yd
+    jg .loop_y
+    cmp           yd, ylimd                         ; past the last loadable row: reuse l0 without loading
+    jg .loop_y_noload
+    add           xd, 16                            ; next group of 16 columns
+    cmp           xd, wd
+    jl .loop_x
+    RET
+
+INIT_YMM avx2                                      ; sgr_calc_ab1: rewrites a (32-bit) and b (16-bit) in place, two rows per iteration
+cglobal sgr_calc_ab1, 4, 6, 14, a, b, w, h, s
+    sub           aq, (384+16-1)*4                  ; start one row above and one column right of a/b
+    sub           bq, (384+16-1)*2
+    add           hd, 2
+    lea           r5, [sgr_x_by_xplus1]             ; base of the table read by the gathers below
+    pxor          m6, m6
+    vpbroadcastd  m7, [pw_91_5]                     ; words (91, 5); NOTE(review): 91*5 = 455, presumably ~4096/9 - confirm
+%ifidn sd, sm
+    movd         xm8, sd
+    vpbroadcastd  m8, xm8
+%else
+    vpbroadcastd  m8, sd
+%endif
+    vpbroadcastd  m9, [pd_0x80000]                  ; rounding term for the >> 20 below
+    vpbroadcastd m10, [pd_255]
+    psrad        m12, m9, 8                         ; pd_2048
+    psrad        m11, m9, 11                        ; pd_256
+    pcmpeqb      m13, m13                           ; all-ones, copied into the gather mask each time
+    DEFINE_ARGS a, b, w, h, x
+.loop_y:
+    mov           xq, -2
+.loop_x:
+    movu         xm0, [aq+xq*4+ 0]                  ; low lane: current row
+    movu         xm1, [aq+xq*4+16]
+    vinserti128   m0, [aq+xq*4+ 0+(384+16)*4], 1    ; high lane: next row
+    vinserti128   m1, [aq+xq*4+16+(384+16)*4], 1
+    movu         xm2, [bq+xq*2]
+    vinserti128   m2, [bq+xq*2+(384+16)*2], 1
+    pslld         m3, m0, 3
+    pslld         m4, m1, 3
+    paddd         m3, m0                            ; aa * 9 [first half]
+    paddd         m4, m1                            ; aa * 9 [second half]
+    punpcklwd     m0, m6, m2                        ; word pairs (0, bb)
+    punpckhwd     m2, m6, m2
+    pmaddwd       m1, m0, m0                        ; bb * bb
+    pmaddwd       m5, m2, m2
+    pmaddwd       m0, m7                            ; bb * 5
+    pmaddwd       m2, m7
+    psubd         m3, m1                            ; p = aa * 9 - bb * bb [first half]
+    psubd         m4, m5                            ; p = aa * 9 - bb * bb [second half]
+    pmulld        m3, m8                            ; p * s
+    pmulld        m4, m8
+    paddd         m3, m9
+    paddd         m4, m9
+    psrld         m3, 20                            ; z [first half]
+    psrld         m4, 20                            ; z [second half]
+    pminsd        m3, m10                           ; clamp z to 255 (table index)
+    pminsd        m4, m10
+    mova          m5, m13                           ; the gather clears its mask, so reload it
+    vpgatherdd    m1, [r5+m3*4], m5                 ; xx [first half]
+    mova          m5, m13
+    vpgatherdd    m3, [r5+m4*4], m5                 ; xx [second half]
+    psubd         m5, m11, m1                       ; 256 - xx
+    psubd         m4, m11, m3
+    packssdw      m1, m3                            ; xx as words, stored to b below
+    pmullw        m5, m7                            ; low word: (256 - xx) * 91
+    pmullw        m4, m7
+    pmaddwd       m5, m0                            ; multiply by bb * 5
+    pmaddwd       m4, m2
+    paddd         m5, m12                           ; + 2048 for rounding
+    paddd         m4, m12
+    psrad         m5, 12
+    psrad         m4, 12
+    movu   [bq+xq*2], xm1
+    vextracti128 [bq+xq*2+(384+16)*2], m1, 1
+    movu [aq+xq*4+ 0], xm5
+    movu [aq+xq*4+16], xm4
+    vextracti128 [aq+xq*4+ 0+(384+16)*4], m5, 1
+    vextracti128 [aq+xq*4+16+(384+16)*4], m4, 1
+
+    add           xd, 8                             ; 8 columns of two rows per iteration
+    cmp           xd, wd
+    jl .loop_x
+    add           aq, (384+16)*4*2                  ; advance two rows
+    add           bq, (384+16)*2*2
+    sub           hd, 2
+    jg .loop_y
+    RET
+
+INIT_YMM avx2                                      ; sgr_finish_filter1: 3x3 weighted sums of a and b combined with src px, stored to t as words
+cglobal sgr_finish_filter1, 7, 13, 16, t, src, stride, a, b, w, h, \
+                                       tmp_ptr, src_ptr, a_ptr, b_ptr, x, y
+    vpbroadcastd m15, [pw_16]
+    xor           xd, xd
+.loop_x:
+    lea     tmp_ptrq, [tq+xq*2]
+    lea     src_ptrq, [srcq+xq*1]
+    lea       a_ptrq, [aq+xq*4+(384+16)*4]           ; a_ptr/b_ptr start on the row below the first row
+    lea       b_ptrq, [bq+xq*2+(384+16)*2]
+    movu          m0, [aq+xq*4-(384+16)*4-4]
+    movu          m2, [aq+xq*4-(384+16)*4+4]
+    mova          m1, [aq+xq*4-(384+16)*4]           ; a:top [first half]
+    paddd         m0, m2                            ; a:tl+tr [first half]
+    movu          m2, [aq+xq*4-(384+16)*4-4+32]
+    movu          m4, [aq+xq*4-(384+16)*4+4+32]
+    mova          m3, [aq+xq*4-(384+16)*4+32]        ; a:top [second half]
+    paddd         m2, m4                            ; a:tl+tr [second half]
+    movu          m4, [aq+xq*4-4]
+    movu          m5, [aq+xq*4+4]
+    paddd         m1, [aq+xq*4]                     ; a:top+ctr [first half]
+    paddd         m4, m5                            ; a:l+r [first half]
+    movu          m5, [aq+xq*4+32-4]
+    movu          m6, [aq+xq*4+32+4]
+    paddd         m3, [aq+xq*4+32]                  ; a:top+ctr [second half]
+    paddd         m5, m6                            ; a:l+r [second half]
+
+    movu          m6, [bq+xq*2-(384+16)*2-2]
+    movu          m8, [bq+xq*2-(384+16)*2+2]
+    mova          m7, [bq+xq*2-(384+16)*2]          ; b:top
+    paddw         m6, m8                            ; b:tl+tr
+    movu          m8, [bq+xq*2-2]
+    movu          m9, [bq+xq*2+2]
+    paddw         m7, [bq+xq*2]                     ; b:top+ctr
+    paddw         m8, m9                            ; b:l+r
+    mov           yd, hd
+.loop_y:
+    movu          m9, [b_ptrq-2]
+    movu         m10, [b_ptrq+2]
+    paddw         m7, [b_ptrq]                      ; b:top+ctr+bottom
+    paddw         m9, m10                           ; b:bl+br
+    paddw        m10, m7, m8                        ; b:top+ctr+bottom+l+r
+    paddw         m6, m9                            ; b:tl+tr+bl+br
+    psubw         m7, [b_ptrq-(384+16)*2*2]         ; b:ctr+bottom
+    paddw        m10, m6
+    psllw        m10, 2                             ; 4 * (all nine taps)
+    psubw        m10, m6                            ; aa
+    pmovzxbw     m12, [src_ptrq]
+    punpcklwd     m6, m10, m15                      ; pair aa with 16 ...
+    punpckhwd    m10, m15
+    punpcklwd    m13, m12, m15                      ; ... and src with 16, so pmaddwd adds 16*16
+    punpckhwd    m12, m15
+    pmaddwd       m6, m13                           ; aa*src[x]+256 [first half]
+    pmaddwd      m10, m12                           ; aa*src[x]+256 [second half]
+
+    movu         m11, [a_ptrq-4]
+    movu         m12, [a_ptrq+4]
+    paddd         m1, [a_ptrq]                      ; a:top+ctr+bottom [first half]
+    paddd        m11, m12                           ; a:bl+br [first half]
+    movu         m12, [a_ptrq+32-4]
+    movu         m13, [a_ptrq+32+4]
+    paddd         m3, [a_ptrq+32]                   ; a:top+ctr+bottom [second half]
+    paddd        m12, m13                           ; a:bl+br [second half]
+    paddd        m13, m1, m4                        ; a:top+ctr+bottom+l+r [first half]
+    paddd        m14, m3, m5                        ; a:top+ctr+bottom+l+r [second half]
+    paddd         m0, m11                           ; a:tl+tr+bl+br [first half]
+    paddd         m2, m12                           ; a:tl+tr+bl+br [second half]
+    paddd        m13, m0
+    paddd        m14, m2
+    pslld        m13, 2
+    pslld        m14, 2
+    psubd        m13, m0                            ; bb [first half]
+    psubd        m14, m2                            ; bb [second half]
+    vperm2i128    m0, m13, m14, 0x31                ; regroup 128-bit lanes to match the punpck*wd layout
+    vinserti128  m13, xm14, 1
+    psubd         m1, [a_ptrq-(384+16)*4*2]          ; a:ctr+bottom [first half]
+    psubd         m3, [a_ptrq-(384+16)*4*2+32]       ; a:ctr+bottom [second half]
+
+    paddd         m6, m13                           ; aa*src[x] + 256 + bb
+    paddd        m10, m0
+    psrad         m6, 9
+    psrad        m10, 9
+    packssdw      m6, m10
+    mova  [tmp_ptrq], m6
+
+    ; shift to next row
+    mova          m0, m4
+    mova          m2, m5
+    mova          m4, m11
+    mova          m5, m12
+    mova          m6, m8
+    mova          m8, m9
+
+    add       a_ptrq, (384+16)*4
+    add       b_ptrq, (384+16)*2
+    add     tmp_ptrq, 384*2                         ; t rows are 384 words apart
+    add     src_ptrq, strideq
+    dec           yd
+    jg .loop_y
+    add           xd, 16                            ; next group of 16 columns
+    cmp           xd, wd
+    jl .loop_x
+    RET
+
+INIT_YMM avx2                                      ; sgr_weighted1: dst += rounded ((t - 16*dst) * (wt << 4)) >> 15, saturated to 8 bits
+cglobal sgr_weighted1, 6, 6, 7, dst, stride, t, w, h, wt
+    movd         xm0, wtd
+    vpbroadcastw  m0, xm0
+    psllw         m0, 4                             ; wt << 4, multiplier for pmulhrsw
+    DEFINE_ARGS dst, stride, t, w, h, idx
+.loop_y:
+    xor         idxd, idxd
+.loop_x:
+    mova          m1, [tq+idxq*2+ 0]
+    mova          m4, [tq+idxq*2+32]
+    pmovzxbw      m2, [dstq+idxq+ 0]                ; dst px zero-extended to words
+    pmovzxbw      m5, [dstq+idxq+16]
+    psllw         m3, m2, 4                         ; dst * 16
+    psllw         m6, m5, 4
+    psubw         m1, m3                            ; t - 16 * dst
+    psubw         m4, m6
+    pmulhrsw      m1, m0
+    pmulhrsw      m4, m0
+    paddw         m1, m2
+    paddw         m4, m5
+    packuswb      m1, m4                            ; saturate to unsigned bytes (packs per 128-bit lane)
+    vpermq        m1, m1, q3120                     ; restore linear px order after the lane-wise pack
+    mova [dstq+idxq], m1
+    add         idxd, 32                            ; 32 px per iteration
+    cmp         idxd, wd
+    jl .loop_x
+    add         dstq, strideq
+    add           tq, 384 * 2                       ; t rows are 384 words apart
+    dec           hd
+    jg .loop_y
+    RET
+
+INIT_YMM avx2                                      ; sgr_box5_h: per-row sums (16-bit) and sums of squares (32-bit) over 5 adjacent 8-bit px
+cglobal sgr_box5_h, 8, 11, 10, sumsq, sum, left, src, stride, w, h, edge, x, xlim
+    test       edged, 2                             ; have_right
+    jz .no_right
+    xor        xlimd, xlimd
+    add           wd, 2
+    add           wd, 15                            ; have_right: round w+2 up to a multiple of 16
+    and           wd, ~15
+    jmp .right_done
+.no_right:
+    mov        xlimd, 3
+    sub           wd, 1
+.right_done:
+    pxor          m1, m1                            ; m1 = 0, used for zero-extension below
+    lea         srcq, [srcq+wq+1]                   ; point at the row end; x runs from -w upwards
+    lea         sumq, [sumq+wq*2-2]
+    lea       sumsqq, [sumsqq+wq*4-4]
+    neg           wq
+    lea          r10, [pb_right_ext_mask+32]        ; mask table used by the partial right-edge load
+.loop_y:
+    mov           xq, wq
+
+    ; load left
+    test       edged, 1                             ; have_left
+    jz .no_left
+    test       leftq, leftq
+    jz .load_left_from_main
+    movd         xm0, [leftq]
+    pinsrd       xm0, [srcq+xq-1], 1
+    pslldq       xm0, 11                            ; bytes 11-14 = left[0..3], byte 15 = src[x-1]
+    add        leftq, 4
+    jmp .expand_x
+.no_left:
+    vpbroadcastb xm0, [srcq+xq-1]                   ; no left edge: replicate the px just before [srcq+xq]
+    jmp .expand_x
+.load_left_from_main:
+    pinsrd       xm0, [srcq+xq-4], 3
+.expand_x:
+    punpckhbw    xm0, xm1                           ; zero-extend the high 8 bytes to words
+
+    ; when we reach this, xm0 contains the four px preceding [srcq+xq] in highest words
+    cmp           xq, -16
+    jle .loop_x
+    test          xq, xq
+    jge .right_extend
+.partial_load_and_extend:
+    vpbroadcastb  m3, [srcq-1]                      ; byte just before the row-end pointer
+    pmovzxbw      m2, [srcq+xq]
+    punpcklbw     m3, m1
+    movu          m4, [r10+xq*2]
+    pand          m2, m4                            ; keep loaded px where the mask is set
+    pandn         m4, m3                            ; use the broadcast px elsewhere
+    por           m2, m4
+    jmp .loop_x_noload
+.right_extend:
+    psrldq       xm2, xm0, 14                       ; highest word of the previous load
+    vpbroadcastw  m2, xm2
+    jmp .loop_x_noload
+
+.loop_x:
+    pmovzxbw      m2, [srcq+xq]                     ; 16 px, zero-extended to words
+.loop_x_noload:
+    vinserti128   m0, xm2, 1                        ; m0 = [previous 8 px | low 8 px of m2]
+    palignr       m3, m2, m0, 8                     ; px at x-4
+    palignr       m4, m2, m0, 10                    ; px at x-3
+    palignr       m5, m2, m0, 12                    ; px at x-2
+    palignr       m6, m2, m0, 14                    ; px at x-1
+
+    paddw         m0, m3, m2
+    punpcklwd     m7, m3, m2
+    punpckhwd     m3, m2
+    paddw         m0, m4
+    punpcklwd     m8, m4, m5
+    punpckhwd     m4, m5
+    paddw         m0, m5
+    punpcklwd     m9, m6, m1
+    punpckhwd     m5, m6, m1
+    paddw         m0, m6                            ; sum of the five px
+    pmaddwd       m7, m7                            ; px[x-4]^2 + px[x]^2
+    pmaddwd       m3, m3
+    pmaddwd       m8, m8                            ; px[x-3]^2 + px[x-2]^2
+    pmaddwd       m4, m4
+    pmaddwd       m9, m9                            ; px[x-1]^2
+    pmaddwd       m5, m5
+    paddd         m7, m8
+    paddd         m3, m4
+    paddd         m7, m9
+    paddd         m3, m5
+    movu [sumq+xq*2], m0
+    movu [sumsqq+xq*4+ 0], xm7
+    movu [sumsqq+xq*4+16], xm3
+    vextracti128 [sumsqq+xq*4+32], m7, 1
+    vextracti128 [sumsqq+xq*4+48], m3, 1
+
+    vextracti128 xm0, m2, 1                         ; keep the last 8 px as context for the next step
+    add           xq, 16
+
+    ; if x <= -16 we can reload more pixels
+    ; else if x < 0 we reload and extend (this implies have_right=0)
+    ; else if x < xlimd we extend from previous load (this implies have_right=0)
+    ; else we are done
+
+    cmp           xq, -16
+    jle .loop_x
+    test          xq, xq
+    jl .partial_load_and_extend
+    cmp           xq, xlimq
+    jl .right_extend
+
+    add       sumsqq, (384+16)*4                    ; advance one row; rows are (384+16) entries apart
+    add         sumq, (384+16)*2
+    add         srcq, strideq
+    dec hd
+    jg .loop_y
+    RET
+
+INIT_YMM avx2                                      ; sgr_box5_v: in-place vertical 5-row sum of sumsq/sum, written on every second row
+cglobal sgr_box5_v, 5, 10, 15, sumsq, sum, w, h, edge, x, y, sumsq_ptr, sum_ptr, ylim
+    mov           xq, -2
+    mov        ylimd, edged
+    and        ylimd, 8                             ; have_bottom
+    shr        ylimd, 2
+    sub        ylimd, 3                             ; -3 if have_bottom=0, else -1
+.loop_x:
+    lea           yd, [hd+ylimd+2]
+    lea   sumsq_ptrq, [sumsqq+xq*4+4-(384+16)*4]    ; start one row above the first row
+    lea     sum_ptrq, [sumq+xq*2+2-(384+16)*2]
+    test       edged, 4                             ; have_top
+    jnz .load_top
+    movu          m0, [sumsq_ptrq+(384+16)*4*1]     ; no top: use the first row for the upper four inputs
+    movu          m1, [sumsq_ptrq+(384+16)*4*1+32]
+    mova          m2, m0
+    mova          m3, m1
+    mova          m4, m0
+    mova          m5, m1
+    mova          m6, m0
+    mova          m7, m1
+    movu         m10, [sum_ptrq+(384+16)*2*1]
+    mova         m11, m10
+    mova         m12, m10
+    mova         m13, m10
+    jmp .loop_y_second_load
+.load_top:
+    movu          m0, [sumsq_ptrq-(384+16)*4*1]      ; l3/4sq [left]
+    movu          m1, [sumsq_ptrq-(384+16)*4*1+32]   ; l3/4sq [right]
+    movu          m4, [sumsq_ptrq-(384+16)*4*0]      ; l2sq [left]
+    movu          m5, [sumsq_ptrq-(384+16)*4*0+32]   ; l2sq [right]
+    mova          m2, m0
+    mova          m3, m1
+    movu         m10, [sum_ptrq-(384+16)*2*1]        ; l3/4
+    movu         m12, [sum_ptrq-(384+16)*2*0]        ; l2
+    mova         m11, m10
+.loop_y:
+    movu          m6, [sumsq_ptrq+(384+16)*4*1]      ; l1sq [left]
+    movu          m7, [sumsq_ptrq+(384+16)*4*1+32]   ; l1sq [right]
+    movu         m13, [sum_ptrq+(384+16)*2*1]        ; l1
+.loop_y_second_load:
+    test          yd, yd
+    jle .emulate_second_load                        ; no further row to load: l0 = l1
+    movu          m8, [sumsq_ptrq+(384+16)*4*2]      ; l0sq [left]
+    movu          m9, [sumsq_ptrq+(384+16)*4*2+32]   ; l0sq [right]
+    movu         m14, [sum_ptrq+(384+16)*2*2]        ; l0
+.loop_y_noload:
+    paddd         m0, m2
+    paddd         m1, m3
+    paddw        m10, m11
+    paddd         m0, m4
+    paddd         m1, m5
+    paddw        m10, m12
+    paddd         m0, m6
+    paddd         m1, m7
+    paddw        m10, m13
+    paddd         m0, m8                            ; sum of the five sumsq rows
+    paddd         m1, m9
+    paddw        m10, m14                           ; sum of the five sum rows
+    movu [sumsq_ptrq+ 0], m0
+    movu [sumsq_ptrq+32], m1
+    movu  [sum_ptrq], m10
+
+    ; shift position down by one
+    mova          m0, m4
+    mova          m1, m5
+    mova          m2, m6
+    mova          m3, m7
+    mova          m4, m8
+    mova          m5, m9
+    mova         m10, m12
+    mova         m11, m13
+    mova         m12, m14
+    add   sumsq_ptrq, (384+16)*4*2                  ; advance two rows per iteration
+    add     sum_ptrq, (384+16)*2*2
+    sub           yd, 2
+    jge .loop_y
+    ; l1 = l0
+    mova          m6, m8
+    mova          m7, m9
+    mova         m13, m14
+    cmp           yd, ylimd
+    jg .loop_y_noload
+    add           xd, 16                            ; next group of 16 columns
+    cmp           xd, wd
+    jl .loop_x
+    RET
+.emulate_second_load:
+    mova          m8, m6
+    mova          m9, m7
+    mova         m14, m13
+    jmp .loop_y_noload
+
+INIT_YMM avx2                                      ; sgr_calc_ab2: rewrites a (32-bit) and b (16-bit) in place on every second row
+cglobal sgr_calc_ab2, 4, 6, 14, a, b, w, h, s
+    sub           aq, (384+16-1)*4                  ; start one row above and one column right of a/b
+    sub           bq, (384+16-1)*2
+    add           hd, 2
+    lea           r5, [sgr_x_by_xplus1]             ; base of the table read by the gathers below
+    pxor          m6, m6
+    vpbroadcastd  m7, [pw_82]                       ; NOTE(review): 82*2 = 164, presumably ~4096/25 - confirm
+%ifidn sd, sm
+    movd         xm8, sd
+    vpbroadcastd  m8, xm8
+%else
+    vpbroadcastd  m8, sd
+%endif
+    vpbroadcastd  m9, [pd_0x80000]                  ; rounding term for the >> 20 below
+    vpbroadcastd m10, [pd_255]
+    psrad        m12, m9, 8                         ; pd_2048
+    psrad        m11, m9, 11                        ; pd_256
+    pcmpeqb      m13, m13                           ; all-ones, copied into the gather mask each time
+    DEFINE_ARGS a, b, w, h, x
+.loop_y:
+    mov           xq, -2
+.loop_x:
+    movu         xm0, [aq+xq*4+ 0]
+    movu         xm1, [aq+xq*4+16]
+    vinserti128   m0, [aq+xq*4+32], 1
+    vinserti128   m1, [aq+xq*4+48], 1
+    movu          m2, [bq+xq*2]
+    pslld         m3, m0, 5                         ; aa * 32 [first half]
+    pslld         m4, m1, 5                         ; aa * 32 [second half]
+    paddd         m3, m0                            ; aa * 33 [first half]
+    paddd         m4, m1                            ; aa * 33 [second half]
+    pslld         m0, 3                             ; aa * 8 [first half]
+    pslld         m1, 3                             ; aa * 8 [second half]
+    psubd         m3, m0                            ; aa * 25 [first half]
+    psubd         m4, m1                            ; aa * 25 [second half]
+    punpcklwd     m0, m2, m6                        ; word pairs (bb, 0)
+    punpckhwd     m2, m6
+    pmaddwd       m1, m0, m0                        ; bb * bb
+    pmaddwd       m5, m2, m2
+    paddw         m0, m0                            ; bb * 2
+    paddw         m2, m2
+    psubd         m3, m1                            ; p = aa * 25 - bb * bb [first half]
+    psubd         m4, m5                            ; p = aa * 25 - bb * bb [second half]
+    pmulld        m3, m8                            ; p * s
+    pmulld        m4, m8
+    paddd         m3, m9
+    paddd         m4, m9
+    psrld         m3, 20                            ; z [first half]
+    psrld         m4, 20                            ; z [second half]
+    pminsd        m3, m10                           ; clamp z to 255 (table index)
+    pminsd        m4, m10
+    mova          m5, m13                           ; the gather clears its mask, so reload it
+    vpgatherdd    m1, [r5+m3*4], m5                 ; xx [first half]
+    mova          m5, m13
+    vpgatherdd    m3, [r5+m4*4], m5                 ; xx [second half]
+    psubd         m5, m11, m1                       ; 256 - xx
+    psubd         m4, m11, m3
+    packssdw      m1, m3                            ; xx as words, stored to b below
+    pmullw        m5, m7                            ; low word: (256 - xx) * 82
+    pmullw        m4, m7
+    pmaddwd       m5, m0                            ; multiply by bb * 2
+    pmaddwd       m4, m2
+    paddd         m5, m12                           ; + 2048 for rounding
+    paddd         m4, m12
+    psrad         m5, 12
+    psrad         m4, 12
+    movu   [bq+xq*2], m1
+    movu [aq+xq*4+ 0], xm5
+    movu [aq+xq*4+16], xm4
+    vextracti128 [aq+xq*4+32], m5, 1
+    vextracti128 [aq+xq*4+48], m4, 1
+
+    add           xd, 16                            ; 16 columns per iteration
+    cmp           xd, wd
+    jl .loop_x
+    add           aq, (384+16)*4*2                  ; advance two rows
+    add           bq, (384+16)*2*2
+    sub           hd, 2
+    jg .loop_y
+    RET
+
+INIT_YMM avx2                                      ; sgr_finish_filter2: combines weighted a/b row sums with src px, two output rows of t per iteration
+cglobal sgr_finish_filter2, 7, 13, 13, t, src, stride, a, b, w, h, \
+                                       tmp_ptr, src_ptr, a_ptr, b_ptr, x, y
+    vpbroadcastd  m9, [pw_5_6]                      ; words (5, 6): weights for (l+r, ctr)
+    vpbroadcastd m12, [pw_256]
+    psrlw        m11, m12, 1                    ; pw_128
+    psrlw        m10, m12, 8                    ; pw_1
+    xor           xd, xd
+.loop_x:
+    lea     tmp_ptrq, [tq+xq*2]
+    lea     src_ptrq, [srcq+xq*1]
+    lea       a_ptrq, [aq+xq*4+(384+16)*4]          ; a_ptr/b_ptr start on the row below the first row
+    lea       b_ptrq, [bq+xq*2+(384+16)*2]
+    movu          m0, [aq+xq*4-(384+16)*4-4]
+    mova          m1, [aq+xq*4-(384+16)*4]
+    movu          m2, [aq+xq*4-(384+16)*4+4]
+    movu          m3, [aq+xq*4-(384+16)*4-4+32]
+    mova          m4, [aq+xq*4-(384+16)*4+32]
+    movu          m5, [aq+xq*4-(384+16)*4+4+32]
+    paddd         m0, m2
+    paddd         m3, m5
+    paddd         m0, m1                        ; l + ctr + r
+    paddd         m3, m4
+    pslld         m2, m0, 2
+    pslld         m5, m3, 2
+    paddd         m2, m0                        ; 5 * (l + ctr + r)
+    paddd         m5, m3
+    paddd         m0, m2, m1                    ; prev_odd_b [first half]
+    paddd         m1, m5, m4                    ; prev_odd_b [second half]
+    movu          m3, [bq+xq*2-(384+16)*2-2]
+    mova          m4, [bq+xq*2-(384+16)*2]
+    movu          m5, [bq+xq*2-(384+16)*2+2]
+    paddw         m3, m5                        ; l + r
+    punpcklwd     m5, m3, m4
+    punpckhwd     m3, m4
+    pmaddwd       m5, m9                        ; 5 * (l + r) + 6 * ctr
+    pmaddwd       m3, m9
+    packssdw      m2, m5, m3                    ; prev_odd_a
+    mov           yd, hd
+.loop_y:
+    movu          m3, [a_ptrq-4]
+    mova          m4, [a_ptrq]
+    movu          m5, [a_ptrq+4]
+    movu          m6, [a_ptrq+32-4]
+    mova          m7, [a_ptrq+32]
+    movu          m8, [a_ptrq+32+4]
+    paddd         m3, m5
+    paddd         m6, m8
+    paddd         m3, m4
+    paddd         m6, m7
+    pslld         m5, m3, 2
+    pslld         m8, m6, 2
+    paddd         m5, m3
+    paddd         m8, m6
+    paddd         m3, m5, m4                    ; cur_odd_b [first half]
+    paddd         m4, m8, m7                    ; cur_odd_b [second half]
+    movu          m5, [b_ptrq-2]
+    mova          m6, [b_ptrq]
+    movu          m7, [b_ptrq+2]
+    paddw         m5, m7
+    punpcklwd     m7, m5, m6
+    punpckhwd     m5, m6
+    pmaddwd       m7, m9
+    pmaddwd       m5, m9
+    packssdw      m5, m7, m5                    ; cur_odd_a
+
+    paddd         m0, m3                        ; cur_even_b [first half]
+    paddd         m1, m4                        ; cur_even_b [second half]
+    paddw         m2, m5                        ; cur_even_a
+
+    pmovzxbw      m6, [src_ptrq]
+    vperm2i128    m8, m0, m1, 0x31              ; regroup 128-bit lanes to match the punpck*wd layout
+    vinserti128   m0, xm1, 1
+    punpcklwd     m7, m6, m10                   ; word pairs (src, 1)
+    punpckhwd     m6, m10
+    punpcklwd     m1, m2, m12                   ; word pairs (cur_even_a, 256)
+    punpckhwd     m2, m12
+    pmaddwd       m7, m1                        ; cur_even_a * src + 256
+    pmaddwd       m6, m2
+    paddd         m7, m0
+    paddd         m6, m8
+    psrad         m7, 9
+    psrad         m6, 9
+
+    pmovzxbw      m8, [src_ptrq+strideq]        ; src px of the second row
+    punpcklwd     m0, m8, m10
+    punpckhwd     m8, m10
+    punpcklwd     m1, m5, m11                   ; word pairs (cur_odd_a, 128)
+    punpckhwd     m2, m5, m11
+    pmaddwd       m0, m1                        ; cur_odd_a * src + 128
+    pmaddwd       m8, m2
+    vinserti128   m2, m3, xm4, 1
+    vperm2i128    m1, m3, m4, 0x31
+    paddd         m0, m2
+    paddd         m8, m1
+    psrad         m0, 8
+    psrad         m8, 8
+
+    packssdw      m7, m6
+    packssdw      m0, m8
+    mova [tmp_ptrq+384*2*0], m7
+    mova [tmp_ptrq+384*2*1], m0
+
+    mova          m0, m3                        ; current odd-row terms become prev for the next pair
+    mova          m1, m4
+    mova          m2, m5
+    add       a_ptrq, (384+16)*4*2
+    add       b_ptrq, (384+16)*2*2
+    add     tmp_ptrq, 384*2*2
+    lea     src_ptrq, [src_ptrq+strideq*2]
+    sub           yd, 2
+    jg .loop_y
+    add           xd, 16                        ; next group of 16 columns
+    cmp           xd, wd
+    jl .loop_x
+    RET
+
+INIT_YMM avx2                                      ; sgr_weighted2: dst += ((t1 - 16*dst)*wt[0] + (t2 - 16*dst)*wt[1] + 1024) >> 11, saturated
+cglobal sgr_weighted2, 7, 7, 11, dst, stride, t1, t2, w, h, wt
+    vpbroadcastd  m0, [wtq]                         ; NOTE(review): presumably two 16-bit weights at wt - confirm with caller
+    vpbroadcastd m10, [pd_1024]                     ; rounding term for the >> 11 below
+    DEFINE_ARGS dst, stride, t1, t2, w, h, idx
+.loop_y:
+    xor         idxd, idxd
+.loop_x:
+    mova          m1, [t1q+idxq*2+ 0]
+    mova          m2, [t1q+idxq*2+32]
+    mova          m3, [t2q+idxq*2+ 0]
+    mova          m4, [t2q+idxq*2+32]
+    pmovzxbw      m5, [dstq+idxq+ 0]                ; dst px zero-extended to words
+    pmovzxbw      m6, [dstq+idxq+16]
+    psllw         m7, m5, 4                         ; dst * 16
+    psllw         m8, m6, 4
+    psubw         m1, m7                            ; t1 - 16 * dst
+    psubw         m2, m8
+    psubw         m3, m7                            ; t2 - 16 * dst
+    psubw         m4, m8
+    punpcklwd     m9, m1, m3                        ; interleave the t1/t2 terms for pmaddwd
+    punpckhwd     m1, m3
+    punpcklwd     m3, m2, m4
+    punpckhwd     m2, m4
+    pmaddwd       m9, m0
+    pmaddwd       m1, m0
+    pmaddwd       m3, m0
+    pmaddwd       m2, m0
+    paddd         m9, m10
+    paddd         m1, m10
+    paddd         m3, m10
+    paddd         m2, m10
+    psrad         m9, 11
+    psrad         m1, 11
+    psrad         m3, 11
+    psrad         m2, 11
+    packssdw      m1, m9, m1
+    packssdw      m2, m3, m2
+    paddw         m1, m5
+    paddw         m2, m6
+    packuswb      m1, m2                            ; saturate to unsigned bytes (packs per 128-bit lane)
+    vpermq        m1, m1, q3120                     ; restore linear px order after the lane-wise pack
+    mova [dstq+idxq], m1
+    add         idxd, 32                            ; 32 px per iteration
+    cmp         idxd, wd
+    jl .loop_x
+    add         dstq, strideq
+    add          t1q, 384 * 2                       ; t1/t2 rows are 384 words apart
+    add          t2q, 384 * 2
+    dec           hd
+    jg .loop_y
     RET
 %endif ; ARCH_X86_64
--- a/src/x86/looprestoration_init_tmpl.c
+++ b/src/x86/looprestoration_init_tmpl.c
@@ -30,6 +30,7 @@
 
 #include "common/attributes.h"
 #include "common/intops.h"
+#include "src/tables.h"
 
 #if BITDEPTH == 8 && ARCH_X86_64
 void dav1d_wiener_filter_h_avx2(int16_t *dst, const pixel (*left)[4],
@@ -73,6 +74,128 @@
 
     dav1d_wiener_filter_v_avx2(dst, dst_stride, &mid[2*384], w, h, fv, edges);
 }
+
+void dav1d_sgr_box3_h_avx2(int32_t *sumsq, int16_t *sum, // external helpers for the 3x3 path; only declared here (presumably the AVX2 asm - confirm)
+                           const pixel (*left)[4],
+                           const pixel *src, const ptrdiff_t stride,
+                           const int w, const int h,
+                           const enum LrEdgeFlags edges);
+void dav1d_sgr_box3_v_avx2(int32_t *sumsq, int16_t *sum, // takes no source pixels: operates on the sumsq/sum buffers only
+                           const int w, const int h,
+                           const enum LrEdgeFlags edges);
+void dav1d_sgr_calc_ab1_avx2(int32_t *a, int16_t *b, // a/b are non-const: presumably updated in place - confirm
+                             const int w, const int h, const int strength);
+void dav1d_sgr_finish_filter1_avx2(coef *tmp, // tmp is the only non-const pointer, i.e. the output
+                                   const pixel *src, const ptrdiff_t stride,
+                                   const int32_t *a, const int16_t *b,
+                                   const int w, const int h);
+
+// filter with a 3x3 box (radius=1); chains the box3_h, box3_v, calc_ab1 and finish_filter1 helpers, output goes to tmp
+static void dav1d_sgr_filter1_avx2(coef *tmp,
+                                   const pixel *src, const ptrdiff_t stride,
+                                   const pixel (*left)[4],
+                                   const pixel *lpf, const ptrdiff_t lpf_stride,
+                                   const int w, const int h, const int strength,
+                                   const enum LrEdgeFlags edges)
+{
+    ALIGN_STK_32(int32_t, sumsq_mem, (384 + 16) * 68 + 8,); // 68 rows of (384 + 16) entries, plus 8 entries of padding
+    int32_t *const sumsq = &sumsq_mem[(384 + 16) * 2 + 8], *const a = sumsq; // start 2 rows + 8 entries in so rows -2 and -1 exist; a aliases sumsq
+    ALIGN_STK_32(int16_t, sum_mem, (384 + 16) * 68 + 16,); // same row layout for the int16_t buffer, with 16 entries of padding
+    int16_t *const sum = &sum_mem[(384 + 16) * 2 + 16], *const b = sum; // b aliases sum
+
+    dav1d_sgr_box3_h_avx2(sumsq, sum, left, src, stride, w, h, edges); // the h rows of the unit itself, starting at row 0
+    if (edges & LR_HAVE_TOP) // two extra rows read from lpf go into rows -2 and -1
+        dav1d_sgr_box3_h_avx2(&sumsq[-2 * (384 + 16)], &sum[-2 * (384 + 16)],
+                              NULL, lpf, lpf_stride, w, 2, edges); // left is NULL for these rows
+
+    if (edges & LR_HAVE_BOTTOM) // two extra rows, read 6 rows into lpf, go into rows h and h + 1
+        dav1d_sgr_box3_h_avx2(&sumsq[h * (384 + 16)], &sum[h * (384 + 16)],
+                              NULL, lpf + 6 * PXSTRIDE(lpf_stride),
+                              lpf_stride, w, 2, edges);
+
+    dav1d_sgr_box3_v_avx2(sumsq, sum, w, h, edges); // second pass over the same buffers
+    dav1d_sgr_calc_ab1_avx2(a, b, w, h, strength); // a/b are the sumsq/sum buffers themselves
+    dav1d_sgr_finish_filter1_avx2(tmp, src, stride, a, b, w, h); // combines src with a/b into tmp
+}
+
+void dav1d_sgr_box5_h_avx2(int32_t *sumsq, int16_t *sum, // external helpers for the 5x5 path; only declared here (presumably the AVX2 asm - confirm)
+                           const pixel (*left)[4],
+                           const pixel *src, const ptrdiff_t stride,
+                           const int w, const int h,
+                           const enum LrEdgeFlags edges);
+void dav1d_sgr_box5_v_avx2(int32_t *sumsq, int16_t *sum, // takes no source pixels: operates on the sumsq/sum buffers only
+                           const int w, const int h,
+                           const enum LrEdgeFlags edges);
+void dav1d_sgr_calc_ab2_avx2(int32_t *a, int16_t *b, // a/b are non-const: presumably updated in place - confirm
+                             const int w, const int h, const int strength);
+void dav1d_sgr_finish_filter2_avx2(coef *tmp, // tmp is the only non-const pointer, i.e. the output
+                                   const pixel *src, const ptrdiff_t stride,
+                                   const int32_t *a, const int16_t *b,
+                                   const int w, const int h);
+
+// filter with a 5x5 box (radius=2); chains the box5_h, box5_v, calc_ab2 and finish_filter2 helpers, output goes to tmp
+static void dav1d_sgr_filter2_avx2(coef *tmp,
+                                   const pixel *src, const ptrdiff_t stride,
+                                   const pixel (*left)[4],
+                                   const pixel *lpf, const ptrdiff_t lpf_stride,
+                                   const int w, const int h, const int strength,
+                                   const enum LrEdgeFlags edges)
+{
+    ALIGN_STK_32(int32_t, sumsq_mem, (384 + 16) * 68 + 8,); // 68 rows of (384 + 16) entries, plus 8 entries of padding
+    int32_t *const sumsq = &sumsq_mem[(384 + 16) * 2 + 8], *const a = sumsq; // start 2 rows + 8 entries in so rows -2 and -1 exist; a aliases sumsq
+    ALIGN_STK_32(int16_t, sum_mem, (384 + 16) * 68 + 16,); // same row layout for the int16_t buffer, with 16 entries of padding
+    int16_t *const sum = &sum_mem[(384 + 16) * 2 + 16], *const b = sum; // b aliases sum
+
+    dav1d_sgr_box5_h_avx2(sumsq, sum, left, src, stride, w, h, edges); // the h rows of the unit itself, starting at row 0
+    if (edges & LR_HAVE_TOP) // two extra rows read from lpf go into rows -2 and -1
+        dav1d_sgr_box5_h_avx2(&sumsq[-2 * (384 + 16)], &sum[-2 * (384 + 16)],
+                              NULL, lpf, lpf_stride, w, 2, edges); // left is NULL for these rows
+
+    if (edges & LR_HAVE_BOTTOM) // two extra rows, read 6 rows into lpf, go into rows h and h + 1
+        dav1d_sgr_box5_h_avx2(&sumsq[h * (384 + 16)], &sum[h * (384 + 16)],
+                              NULL, lpf + 6 * PXSTRIDE(lpf_stride),
+                              lpf_stride, w, 2, edges);
+
+    dav1d_sgr_box5_v_avx2(sumsq, sum, w, h, edges); // second pass over the same buffers
+    dav1d_sgr_calc_ab2_avx2(a, b, w, h, strength); // a/b are the sumsq/sum buffers themselves
+    dav1d_sgr_finish_filter2_avx2(tmp, src, stride, a, b, w, h); // combines src with a/b into tmp
+}
+
+void dav1d_sgr_weighted1_avx2(pixel *dst, const ptrdiff_t stride, // single-plane blend: one coefficient buffer, one scalar weight
+                              const coef *t1, const int w, const int h,
+                              const int wt);
+void dav1d_sgr_weighted2_avx2(pixel *dst, const ptrdiff_t stride, // two-plane blend; dst is read and overwritten in place
+                              const coef *t1, const coef *t2,
+                              const int w, const int h,
+                              const int16_t wt[2]); // both weights are fetched by the asm as one 32-bit broadcast, so they must be adjacent
+
+static void sgr_filter_avx2(pixel *const dst, const ptrdiff_t dst_stride, // entry point: picks 3x3-only, 5x5-only or combined filtering from dav1d_sgr_params[sgr_idx]
+                            const pixel (*const left)[4],
+                            const pixel *lpf, const ptrdiff_t lpf_stride,
+                            const int w, const int h, const int sgr_idx,
+                            const int16_t sgr_wt[7], const enum LrEdgeFlags edges)
+{
+    if (!dav1d_sgr_params[sgr_idx][0]) { // params[0] == 0: only the 3x3 filter runs
+        ALIGN_STK_32(coef, tmp, 64 * 384,); // 64 rows of 384 coefficients
+        dav1d_sgr_filter1_avx2(tmp, dst, dst_stride, left, lpf, lpf_stride, // dst doubles as the source image
+                               w, h, dav1d_sgr_params[sgr_idx][3], edges); // params[3] is the strength passed to the 3x3 filter
+        dav1d_sgr_weighted1_avx2(dst, dst_stride, tmp, w, h, (1 << 7) - sgr_wt[1]); // weight is 128 - sgr_wt[1]
+    } else if (!dav1d_sgr_params[sgr_idx][1]) { // params[1] == 0: only the 5x5 filter runs
+        ALIGN_STK_32(coef, tmp, 64 * 384,);
+        dav1d_sgr_filter2_avx2(tmp, dst, dst_stride, left, lpf, lpf_stride,
+                               w, h, dav1d_sgr_params[sgr_idx][2], edges); // params[2] is the strength passed to the 5x5 filter
+        dav1d_sgr_weighted1_avx2(dst, dst_stride, tmp, w, h, sgr_wt[0]); // weight is sgr_wt[0] unchanged
+    } else { // both non-zero: run both filters and blend the two results
+        ALIGN_STK_32(coef, tmp1, 64 * 384,); // receives the 5x5 result
+        ALIGN_STK_32(coef, tmp2, 64 * 384,); // receives the 3x3 result
+        dav1d_sgr_filter2_avx2(tmp1, dst, dst_stride, left, lpf, lpf_stride,
+                               w, h, dav1d_sgr_params[sgr_idx][2], edges);
+        dav1d_sgr_filter1_avx2(tmp2, dst, dst_stride, left, lpf, lpf_stride, // dst is still unmodified here; it is only written by the weighted call
+                               w, h, dav1d_sgr_params[sgr_idx][3], edges);
+        const int16_t wt[2] = { sgr_wt[0], 128 - sgr_wt[0] - sgr_wt[1] }; // wt[0] applies to tmp1 (5x5), wt[1] to tmp2 (3x3)
+        dav1d_sgr_weighted2_avx2(dst, dst_stride, tmp1, tmp2, w, h, wt);
+    }
+}
 #endif
 
 void bitfn(dav1d_loop_restoration_dsp_init_x86)(Dav1dLoopRestorationDSPContext *const c) {
@@ -82,5 +205,6 @@
 
 #if BITDEPTH == 8 && ARCH_X86_64
     c->wiener = wiener_filter_avx2;
+    c->selfguided = sgr_filter_avx2;
 #endif
 }
--- a/tests/checkasm/checkasm.h
+++ b/tests/checkasm/checkasm.h
@@ -73,7 +73,7 @@
 
 static void *func_ref, *func_new;
 
-#define BENCH_RUNS (1 << 16) /* Trade-off between accuracy and speed */
+#define BENCH_RUNS (1 << 12) /* Trade-off between accuracy and speed */
 
 /* Decide whether or not the specified function needs to be tested */
 #define check_func(func, ...)\