ref: 76ce9de16baeec77b86ece6825b90715b08545ba
dir: /lib/flate/flate.myr/
use std
use bio

use "types"

/*
DEFLATE https://www.ietf.org/rfc/rfc1951.txt

Terminology:
- A 'Huffman code' is one mechanism to map codes to symbols.
- A 'code' is the sequence of bits used to encode a symbol in a
  Huffman code.  Codes are what appears in the DEFLATE stream.
- A 'symbol' is what a code decodes to.
- A 'length' is the length of an LZ77 reference (the number of
  bytes to repeat).
- A 'distance' is the distance at which the bytes referred to by
  an LZ77 reference are from the current point.
*/

pkg flate =
	/* state maintained during decompression */
	type flatedec = struct
		rdf : bio.file#	/* compressed input */
		wrf : bio.file#	/* decompressed output */

		/* one-byte buffer */
		bval : int	/* in [0, 255] current input byte */
		bpos : int	/* in [0, 8]; # of bits of bval already consumed */

		/* ring buffer window for the bytes output */
		wpos : std.size	/* where the next output byte is stored */
		wlen : std.size	/* # of bytes output so far (may exceed WinSz) */
		wbuf : byte[WinSz]

		/* Huffman codes */
		lenhc : hufc	/* for lengths */
		dsthc : hufc	/* for distances */
	;;

	/* one Huffman code */
	type hufc = struct
		/* # of codes of length [1, MaxBits] */
		count : int[MaxBits + 1]
		/* symbols sorted lexicographically by
		   <code length, symbol value>; c.f., mkhufc()
		   and readcode() for usage */
		symbol : int[MaxSyms]
	;;

	/* decode a DEFLATE stream */
	const decode : (outf : bio.file#, inf : bio.file# -> std.result(void, err))

	pkglocal const MaxSyms : std.size
	pkglocal const MaxBits : std.size
	pkglocal const WinSz : std.size
	pkglocal const bits : (st : flatedec#, n : int -> std.result(int, err))
	pkglocal const nocomp : (st : flatedec# -> std.result(void, err))
	pkglocal const mkhufc : (lengths : int[:], hc : hufc# -> void)
	pkglocal const readcode : (st : flatedec#, hc : hufc# -> std.result(int, err))
	pkglocal const outb : (st : flatedec#, b : byte -> std.result(void, err))
;;

const MaxLenSyms : std.size = 288	/* max # of length symbols */
const MaxDstSyms : std.size = 32	/* max # of distance symbols */
const MaxSyms = MaxLenSyms		/* max # of symbols in a Huffman code */
const MaxBits = 15			/* max length of a code (in bits) */
const WinSz = 32 * std.KiB		/* max distance for a reference */

/* the fixed Huffman codes, built once by __init__ */
var fixedlenhc : hufc
var fixeddsthc : hufc

const __init__ = {
	/* the fixed Huffman codes; c.f., RFC Section 3.2.6 */
	var length : int[MaxSyms]

	std.slfill(length[0:144], 8)
	std.slfill(length[144:256], 9)
	std.slfill(length[256:280], 7)
	std.slfill(length[280:288], 8)
	mkhufc(length[:], &fixedlenhc)

	std.slfill(length[0:32], 5)
	mkhufc(length[0:32], &fixeddsthc)
}

const outb = {st, b
	/* outputs a single byte, to the output file;
	   additionally record the byte emitted in the
	   output window.

	   Returns `std.Ok void on success, or the bio error
	   wrapped in `Io when the write fails. */
	var pos

	match bio.putb(st.wrf, b)
	| `std.Ok 0:	std.die("no bytes written by bio.putb()")
	| `std.Ok _:
	| `std.Err e:	-> `std.Err (`Io e)
	;;

	/* wrap around once the end of the ring buffer is reached */
	pos = st.wpos
	if pos == WinSz
		std.assert(st.wlen >= WinSz, "window should be full")
		pos = 0
	;;
	st.wbuf[pos] = b
	st.wpos = pos + 1
	st.wlen++ /* don't care about it being > WinSz */
	-> `std.Ok void
}

const bits = {st, n
	/* reads the next n bits of the input and returns them as
	   an integer, the first bit read being the least
	   significant one.  Unconsumed bits of the current input
	   byte are used first; more bytes are fetched from st.rdf
	   as needed.

	   Returns the bio error wrapped in `Io when a read fails. */
	var pos, len, bits, bval

	std.assert(n >= 0 && n <= 31, "invalid argument")

	/* start with what remains of the current byte;
	   len is the number of bits gathered so far */
	pos = st.bpos
	bits = st.bval >> pos
	for len = 8 - pos; len < n; len += 8
		match bio.getb(st.rdf)
		| `std.Ok b:
			bval = (b : int)
			st.bval = bval
			bits += bval << len
		| `std.Err e:
			-> `std.Err (`Io e)
		;;
	;;
	/* (len - n) bits of the last byte read are left over */
	st.bpos = 8 - (len - n)
	-> `std.Ok (bits & ((1 << n) - 1))
}

const nocomp = {st
	/* process a non-compressed block; c.f., RFC Section 3.2.4

	   Reads the 16-bit LEN and its one's complement NLEN,
	   then copies LEN bytes verbatim from input to output.
	   Returns a `Format error when LEN and NLEN disagree and
	   an `Io error when reading fails. */
	var len

	st.bpos = 8 /* skip to a byte boundary */
	match bits(st, 16)
	| `std.Ok n:	len = (n : std.size)
	| `std.Err e:	-> `std.Err e
	;;
	match bits(st, 16)
	| `std.Ok nlen:
		if len != (nlen : std.size) ^ 0xffff
			-> `std.Err (`Format \
				"non-compressed block length could " \
				"not be verified")
		;;
	| `std.Err e:
		-> `std.Err e
	;;

	for var n = 0; n < len; n++
		match bio.getb(st.rdf)
		| `std.Ok b:
			match outb(st, b)
			| `std.Ok _:
			| err: -> err
			;;
		| `std.Err e:
			-> `std.Err (`Io e)
		;;
	;;
	-> `std.Ok void
}

const mkhufc = {length, hc
	/* builds the Huffman code hc from the code length of each
	   symbol: length[s] is the length, in bits, of the code
	   for symbol s.

	   note: it is possible to have 0s in the length array;
	   this means that the symbol at this index will never be
	   part of the input stream */
	var index : int[MaxBits + 1]

	std.assert(length.len <= hc.symbol.len, "invalid argument")

	/* count the number of codes of each length */
	std.slfill(hc.count[:], 0)
	for var symb = 0; symb < length.len; symb++
		hc.count[length[symb]]++
	;;

	/* index[l] is the position in hc.symbol[:] of the first
	   symbol whose code is l bits long; unused symbols (length
	   0) take no room in the table */
	std.slfill(index[:], 0)
	hc.count[0] = 0
	for var l = 1; l <= MaxBits; l++
		index[l] = index[l - 1] + hc.count[l - 1]
	;;

	/* walking the symbols in increasing order keeps them
	   sorted by value within each code length */
	for var symb = 0; symb < length.len; symb++
		if length[symb] != 0
			hc.symbol[index[length[symb]]++] = symb
		;;
	;;

	/* TODO: validate the Huffman code */
}

const readcode = {st, hc
	/* reads one code of the Huffman code hc from the input,
	   one bit at a time, and returns the symbol it decodes
	   to.  Returns a `Format error when no code matches
	   within MaxBits bits, and forwards errors from bits().

	   if you want to know more about the algorithm here, you
	   must look for "canonical Huffman codes"; see, for
	   example, Managing Gigabytes: Compressing and Indexing
	   Documents and Images by I. H. Witten, A. Moffat, and
	   T. C. Bell. */
	var code, first, index, count

	code = 0	/* code seen so far */
	first = 0	/* first code of length len */
	index = 0	/* index of the first code of length len in
			   hc.symbol[:] */
	for var len = 1; len <= MaxBits; len++
		match bits(st, 1)
		| `std.Ok n: code += n
		| err: -> err
		;;

		count = hc.count[len]
		if code < first + count
			/* the code is exactly len bits long;
			   return the symbol from the table */
			-> `std.Ok (hc.symbol[index + (code - first)])
		;;

		index += count
		first = (first + count) << 1 /* c.f. RFC Section 3.2.2 */
		code <<= 1 /* make room for the next bit */
	;;
	-> `std.Err (`Format "invalid code")
}

const decomp = {st
	/* process data of a compressed block; c.f., RFC Section 3.2.5

	   Decodes symbols with st.lenhc until the end-of-block
	   symbol (256) is read.  Symbols below 256 are literal
	   bytes; symbols above 256 start an LZ77 reference whose
	   distance is decoded with st.dsthc.

	   Returns a `Format error on invalid data and forwards
	   the errors of bits(), readcode() and outb(). */
	const basel = [
		3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31,
		35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258
	]
	const extral = [
		0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
		3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0
	]
	const based = [
		1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129,
		193, 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097,
		6145, 8193, 12289, 16385, 24577
	]
	const extrad = [
		0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
		7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13
	]
	var symb, len, dst, index

	symb = 0
	while symb != 256
		match readcode(st, &st.lenhc)
		| `std.Ok s:	symb = s
		| `std.Err e:	-> `std.Err e
		;;

		if symb < 256
			/* literal byte */
			match outb(st, (symb : byte))
			| `std.Ok _:
			| err: -> err
			;;
		elif symb > 256
			/* LZ77 reference: <length, distance> */
			symb -= 257
			/* the fixed Huffman code assigns codes to the
			   length symbols 286 and 287, which must never
			   occur in the data (c.f., RFC Section 3.2.6);
			   reject them rather than index past the end
			   of the tables */
			if (symb : std.size) >= basel.len
				-> `std.Err (`Format "invalid length symbol")
			;;
			match bits(st, extral[symb])
			| `std.Ok x:	len = basel[symb] + x
			| `std.Err e:	-> `std.Err e
			;;

			match readcode(st, &st.dsthc)
			| `std.Ok s:	symb = s
			| `std.Err e:	-> `std.Err e
			;;
			/* likewise, the fixed Huffman code assigns
			   codes to the distance symbols 30 and 31,
			   which must never occur in the data */
			if (symb : std.size) >= based.len
				-> `std.Err (`Format "invalid distance symbol")
			;;
			match bits(st, extrad[symb])
			| `std.Ok x:	dst = (based[symb] + x : std.size)
			| `std.Err e:	-> `std.Err e
			;;

			if dst > st.wlen
				-> `std.Err (`Format \
					"reference to a byte before file start")
			;;

			/* copy byte by byte: the referenced bytes may
			   overlap with the ones being output */
			index = (WinSz + st.wpos - dst) % WinSz
			for ; len > 0; len--
				match outb(st, st.wbuf[index])
				| `std.Ok _:
				| err: -> err
				;;
				index = (index + 1) % WinSz
			;;
		;;
	;;
	-> `std.Ok void
}

const dynamic = {st
	/* process a block compressed with dynamic Huffman codes;
	   c.f., RFC Section 3.2.7

	   Reads the description of the length and distance
	   Huffman codes from the input, stores them in st.lenhc
	   and st.dsthc, then decodes the block data with
	   decomp().

	   Returns a `Format error when the description is
	   invalid and forwards the errors of bits(), readcode()
	   and decomp(). */
	const coden = [
		16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
	]
	const extra = [2, 3, 7]
	var length : int[MaxLenSyms + MaxDstSyms]
	var nlen, ndst, ncode, symb, nbits, hufc

	/* # of length codes, # of distance codes and # of code
	   length codes */
	match bits(st, 5)
	| `std.Ok n:	nlen = n + 257
	| `std.Err e:	-> `std.Err e
	;;
	if nlen > 286
		-> `std.Err (`Format "invalid number of length codes")
	;;
	match bits(st, 5)
	| `std.Ok n:	ndst = n + 1
	| `std.Err e:	-> `std.Err e
	;;
	if ndst > 30
		-> `std.Err (`Format "invalid number of distance codes")
	;;
	match bits(st, 4)
	| `std.Ok n:	ncode = n + 4
	| `std.Err e:	-> `std.Err e
	;;

	/* the Huffman code used to encode the code lengths; its
	   own code lengths are stored in the order given by coden */
	std.slfill(length[:19], 0)
	for var c = 0; c < ncode; c++
		match bits(st, 3)
		| `std.Ok n:	length[coden[c]] = n
		| `std.Err e:	-> `std.Err e
		;;
	;;
	mkhufc(length[:19], &hufc)

	for var c = 0; c < nlen + ndst;
		/* according to the RFC: "the code length repeat
		   codes can cross from HLIT + 257 to the HDIST + 1
		   code lengths"; this is why we have a single loop
		   here */
		match readcode(st, &hufc)
		| `std.Ok n:	symb = n
		| `std.Err e:	-> `std.Err e
		;;
		if symb > 18
			-> `std.Err (`Format "invalid code code")
		;;

		if symb < 16
			/* a code length, as is */
			length[c++] = symb
		else
			/* repeat some code length */
			match bits(st, extra[symb - 16])
			| `std.Ok n:	nbits = n
			| `std.Err e:	-> `std.Err e
			;;

			if symb == 16
				/* copy the previous code length 3-6 times */
				nbits += 3
				if c == 0 || c + nbits > nlen + ndst
					-> `std.Err (`Format \
						"invalid code repetition")
				;;
				std.slfill(length[c:c+nbits], length[c-1])
			else
				/* repeat a 0 length */
				if symb == 17
					nbits += 3
				else
					nbits += 11
				;;
				if c + nbits > nlen + ndst
					-> `std.Err (`Format \
						"invalid zero repetition")
				;;
				std.slfill(length[c:c+nbits], 0)
			;;
			c += nbits
		;;
	;;

	/* without a code for the end-of-block symbol, the block
	   could never be terminated */
	if length[256] == 0
		-> `std.Err (`Format "no code length of end-of-block symbol")
	;;

	mkhufc(length[:nlen], &st.lenhc)
	mkhufc(length[nlen:nlen + ndst], &st.dsthc)

	-> decomp(st)
}

const decode = {outf, inf
	/* decodes the DEFLATE stream read from inf and writes the
	   decompressed bytes to outf.  Blocks are processed one
	   after the other until the one flagged as last has been
	   decoded; c.f., RFC Section 3.2.3.

	   Returns `std.Ok void on success, a `Format error when
	   the stream is invalid and an `Io error when reading or
	   writing fails. */
	var st, err, last

	st = [
		.rdf = inf,
		.wrf = outf,
		.bpos = 8,	/* no bit available in the one-byte buffer */
	]

	last = 0
	while last != 1
		/* block header: last-block flag, then block type */
		match bits(&st, 1)
		| `std.Ok l:	last = l
		| `std.Err e:	-> `std.Err e
		;;

		match bits(&st, 2)
		| `std.Ok 0:
			err = nocomp(&st)
		| `std.Ok 1:
			st.lenhc = fixedlenhc
			st.dsthc = fixeddsthc
			err = decomp(&st)
		| `std.Ok 2:
			err = dynamic(&st)
		| `std.Ok _:
			-> `std.Err (`Format "invalid block type")
		| `std.Err e:
			-> `std.Err e
		;;

		match err
		| `std.Ok _:
		| `std.Err e:	-> `std.Err e
		;;
	;;
	-> `std.Ok void
}