ref: ce493274718acbe4ab1562f243cacfc9d00446a0
dir: /libstd/resolve.myr/
use "alloc.use"
use "chartype.use"
use "die.use"
use "endian.use"
use "error.use"
use "extremum.use"
use "fmt.use"
use "hashfuncs.use"
use "htab.use"
use "ipparse.use"
use "option.use"
use "slcp.use"
use "sleq.use"
use "slpush.use"
use "slurp.use"
use "strfind.use"
use "strsplit.use"
use "strstrip.use"
use "sys.use"
use "types.use"
use "utf.use"

pkg std =
	type resolveerr = union
		`Badhost
		`Badsrv
		`Badquery
		`Badresp
	;;

	type hostinfo = struct
		fam	: sockfam
		stype	: socktype
		ttl	: uint32
		addr	: netaddr
		/*
		proto	: uint32
		flags	: uint32
		addr	: sockaddr[:]
		canon	: byte[:]
		next	: hostinfo#
		*/
	;;

	const resolve	: (host : byte[:] -> error(hostinfo[:], resolveerr))
;;

const Hostfile = "/etc/hosts"
const Resolvfile = "/etc/resolv.conf"

/* name -> entry parsed from Hostfile; keys are slices into the slurped file */
var hostmap	: htab(byte[:], hostinfo)#
var search	: byte[:][:]
/* servers from the "nameserver" lines of Resolvfile, in file order */
var nameservers	: netaddr[:]
var inited	: bool = false

/*
Resolves a host name to its addresses.

The hosts file is consulted first; a hit yields a one-element slice.
Otherwise the name is looked up via DNS (IPv4 A records only).
Returns `Failure with:
	`Badhost	the name is malformed, or the server reports an error for it
	`Badsrv		no configured nameserver could be contacted
	`Badresp	a reply was malformed, truncated, or did not match our query
*/
const resolve = {host : byte[:] -> error(hostinfo[:], resolveerr)
	match hostfind(host)
	| `Some hinf:	-> `Success slpush([][:], hinf)
	| `None:	-> dnsresolve(host)
	;;
}

/* Looks `host` up in the hosts table, lazily loading Hostfile and Resolvfile on first use. */
const hostfind = {host
	if !inited
		hostmap = mkht(strhash, streq)
		loadhosts()
		loadresolv()
		inited = true
	;;
	-> htget(hostmap, host)
}

/*
Parses Hostfile ("ip name [alias...]" per line, '#' comments) into hostmap.
A missing or unreadable file is silently ignored.
*/
const loadhosts = {
	var h
	var lines

	match slurp(Hostfile)
	| `Success d:	h = d
	| `Failure m:	->
	;;

	/*
	NB: `h` is deliberately never freed: the names stored as keys in
	hostmap are slices into it.
	*/
	lines = strsplit(h, "\n")
	for l in lines
		/* trim comment */
		match strfind(l, "#")
		| `Some idx:	l = l[:idx]
		| `None:
		;;

		match word(l)
		| `Some (ip, rest):
			match ipparse(ip)
			| `Some addr:	addhosts(addr, ip, rest)
			| `None:
			;;
		| `None:
		;;
	;;
	slfree(lines)
}

/*
Registers every whitespace-separated name in `str` as resolving to `addr`.
The first entry for a name wins; later duplicates are ignored.
*/
const addhosts = {addr, as, str
	var hinf
	var fam

	match addr
	| `Ipv4 _:	fam = Afinet
	| `Ipv6 _:	fam = Afinet6
	;;
	while true
		match word(str)
		| `Some (name, rest):
			if !hthas(hostmap, name)
				hinf = [.fam=fam, .stype = 0, .ttl = 0, .addr = addr]
				htput(hostmap, name, hinf)
			;;
			str = rest
		| `None:
			->
		;;
	;;
}

/* Reads the "nameserver" lines of Resolvfile; a missing file is silently ignored. */
const loadresolv = {
	var h
	var lines

	match slurp(Resolvfile)
	| `Success d:	h = d
	| `Failure m:	->
	;;

	lines = strsplit(h, "\n")
	for l in lines
		match strfind(l, "#")
		| `Some idx:	l = l[:idx]
		| `None:
		;;

		match word(l)
		| `Some (cmd, rest):
			if sleq(cmd, "nameserver")
				addns(rest)
			;;
		| `None:
		;;
	;;
	slfree(lines)
	/* safe to release: addns keeps parsed address values, not slices of h */
	slfree(h)
}

/* Appends the address in the first word of `rest` to nameservers, if it parses. */
const addns = {rest
	match word(rest)
	| `Some (name, _):
		match ipparse(name)
		| `Some addr:
			nameservers = slpush(nameservers, addr)
		| `None:
		;;
	| `None:
	;;
}

/*
Splits off the first blank-delimited word of `s` (after stripping leading
space). Returns `Some (word, remainder) or `None if no word is present.
*/
const word = {s
	var c, len
	var done

	s = strfstrip(s)
	len = 0
	done = false
	/* bound by s.len so we never decode past the end of the slice */
	while len < s.len && !done
		c = decode(s[len:])
		if c == Badchar || isblank(c)
			done = true
		else
			len += charlen(c)
		;;
	;;
	if len == 0
		-> `None
	else
		-> `Some (s[:len], s[len:])
	;;
}

/*
Queries the first reachable nameserver for the A records of `host`.
Only one server is queried: its answer (or error) is returned as-is.
*/
const dnsresolve = {host : byte[:]
	var nsrv
	var r

	if !valid(host)
		-> `Failure (`Badhost)
	;;
	for ns in nameservers
		nsrv = dnsconnect(ns)
		if nsrv >= 0
			r = dnsquery(nsrv, host)
			close(nsrv)
			-> r
		;;
	;;
	-> `Failure (`Badsrv)
}

/* Opens a UDP socket connected to nameserver `ns`; returns a negative value on failure. */
const dnsconnect = {ns
	match ns
	| `Ipv4 addr:
		-> dnsconnectv4(addr)
	| `Ipv6 addr:
		/* skip rather than abort, so later IPv4 nameservers still get tried */
		put("Warning: IPv6 nameservers are not supported yet\n")
		-> -1
	;;
}

/* Opens a UDP socket connected to port 53 of the IPv4 address `addr`; -1 on failure. */
const dnsconnectv4 = {addr
	var sa : sockaddr_in
	var s
	var status

	s = socket(Afinet, Sockdgram, 0)
	if s < 0
		put("Warning: Failed to open socket: %l\n", s)
		-> -1
	;;
	sa.fam = Afinet
	sa.port = hosttonet(53) /* port 53 */
	sa.addr = addr
	status = connect(s, (&sa) castto(sockaddr#), sizeof(sockaddr_in))
	if status < 0
		put("Warning: Failed to connect to server: %l\n", status)
		close(s)
		-> -1
	;;
	-> s
}

/*
Sends an A query for `host` on `srv` and parses the reply.
NOTE(review): the read has no timeout, so a lost UDP datagram blocks forever.
*/
const dnsquery = {srv, host
	var id

	id = tquery(srv, host)
	-> rquery(srv, id)
}

/* Header flag bits at their positions in the 16-bit flags word (RFC 1035, 4.1.1). */
const Qr : uint16 = 1 << 15	/* message is a response */
const Aa : uint16 = 1 << 10	/* authoritative answer */
const Tc : uint16 = 1 << 9	/* message was truncated */
const Rd : uint16 = 1 << 8	/* recursion desired */
const Ra : uint16 = 1 << 7	/* recursion available */

/* NOTE(review): sequential ids starting at a fixed value are predictable (spoofable). */
var nextid : uint16 = 42

/* Writes a single-question A/IN query for `host` to `srv`; returns the query id used. */
const tquery = {srv, host
	/* valid() caps names at 255 chars: 12 (header) + 257 (name) + 4 < 512 */
	var pkt : byte[512]
	var off : size

	/* header */
	off = 0
	off += pack16(pkt[:], off, nextid)	/* id */
	off += pack16(pkt[:], off, Rd)		/* flags: ask the server to recurse */
	off += pack16(pkt[:], off, 1)		/* qdcount */
	off += pack16(pkt[:], off, 0)		/* ancount */
	off += pack16(pkt[:], off, 0)		/* nscount */
	off += pack16(pkt[:], off, 0)		/* arcount */
	/* query */
	off += packname(pkt[:], off, host)	/* host */
	off += pack16(pkt[:], off, 0x1)		/* qtype: a record */
	off += pack16(pkt[:], off, 0x1)		/* qclass: inet4 */

	/* NOTE(review): write's result is unchecked; a failed send leaves rquery blocked */
	write(srv, pkt[:off])
	-> nextid++
}

/* Reads one reply datagram from `srv` and parses it as the answer to query `id`. */
const rquery = {srv, id
	var pktbuf : byte[1024]
	var n

	n = read(srv, pktbuf[:])
	if n < 0
		put("Warning: Failed to read from %z: %i\n", srv, n)
		-> `Failure (`Badresp)
	;;
	-> hosts(pktbuf[:n], id)
}

/*
Parses a DNS reply and extracts the IPv4 addresses of its IN A answers.
Other record types (e.g. CNAMEs) are skipped. Returns `Failure `Badresp
for truncated, mismatched, or non-response packets, and `Failure `Badhost
when the server reports a nonzero response code (e.g. NXDOMAIN).
*/
const hosts = {pkt, id : uint16
	var off
	var v, q, a, flags
	var i
	var n : size
	var ty, cls, ttl, rdlen
	var hinf : hostinfo[:]

	/* the fixed header is six 16-bit words */
	if pkt.len < 12
		-> `Failure (`Badresp)
	;;
	off = 0
	(v, off) = unpack16(pkt, off)		/* id */
	if v != id
		-> `Failure (`Badresp)
	;;
	(flags, off) = unpack16(pkt, off)	/* flags */
	(q, off) = unpack16(pkt, off)		/* qdcount */
	(a, off) = unpack16(pkt, off)		/* ancount */
	(v, off) = unpack16(pkt, off)		/* nscount */
	(v, off) = unpack16(pkt, off)		/* arcount */
	if (flags & Qr) == 0
		-> `Failure (`Badresp)
	;;
	/* low 4 bits are the response code; nonzero means the lookup failed */
	if (flags & 0xf) != 0
		-> `Failure (`Badhost)
	;;

	/* skip past query records: name, then 16-bit type and class */
	for i = 0; i < q; i++
		off = skipname(pkt, off)
		if off + 4 > pkt.len
			-> `Failure (`Badresp)
		;;
		off += 4
	;;

	/* parse answer records */
	hinf = slalloc(a castto(size))
	n = 0
	for i = 0; i < a; i++
		off = skipname(pkt, off)
		/* type(2) class(2) ttl(4) rdlength(2) */
		if off + 10 > pkt.len
			slfree(hinf)
			-> `Failure (`Badresp)
		;;
		(ty, off) = unpack16(pkt, off)
		(cls, off) = unpack16(pkt, off)
		(ttl, off) = unpack32(pkt, off)
		(rdlen, off) = unpack16(pkt, off)
		if off + (rdlen castto(size)) > pkt.len
			slfree(hinf)
			-> `Failure (`Badresp)
		;;
		/* only IN A records carry the IPv4 address we asked for */
		if ty == 1 && cls == 1 && rdlen == 4
			hinf[n] = [.fam = Afinet, .stype = 0, .ttl = ttl, .addr = `Ipv4 [pkt[off], pkt[off+1], pkt[off+2], pkt[off+3]]]
			n++
		;;
		off += rdlen castto(size)
	;;
	-> `Success hinf[:n]
}

/* Debug helper: prints the header counts and records of a reply packet. */
const dumpresponse = {pkt
	var nquery, nans
	var off
	var v
	var i

	(v, off) = unpack16(pkt, 0)
	(v, off) = unpack16(pkt, off)
	(nquery, off) = unpack16(pkt, off)
	put("hdr.qdcount = %w\n", nquery)
	(nans, off) = unpack16(pkt, off)
	put("hdr.ancount = %w\n", nans)
	(v, off) = unpack16(pkt, off)
	put("hdr.nscount = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("hdr.arcount = %w\n", v)

	put("Queries:\n")
	for i = 0; i < nquery; i++
		put("i: %w\n", i)
		off = dumpquery(pkt, off)
	;;

	put("Answers:\n")
	for i = 0; i < nans; i++
		put("i: %w\n", i)
		off = dumpans(pkt, off)
	;;
}

/* Debug helper: prints one question record at `off`; returns the offset after it. */
const dumpquery = {pkt, off
	var v

	put("\tname = ");
	off = printname(pkt, off)
	(v, off) = unpack16(pkt, off)
	put("\tbody.type = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.class = %w\n", v)
	-> off
}

/*
Debug helper: prints one answer record at `off`; returns the offset after it.
NOTE(review): assumes a 4-byte rdata (A record); other record types misparse.
*/
const dumpans = {pkt, off
	var v

	put("\tname = ");
	off = printname(pkt, off)
	(v, off) = unpack16(pkt, off)
	put("\tbody.type = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.class = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.ttl_lo = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.ttl_hi = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.rdlength = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.rdata_lo = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.rdata_hi = %w\n", v)
	-> off
}

/*
Returns the offset just past the encoded name starting at `off`.
A compression pointer (top two bits set) is 2 bytes and ends the name.
If the packet ends before the name does, returns pkt.len + 1 so that the
caller's bounds check rejects the packet.
*/
const skipname = {pkt, off
	var sz

	while off < pkt.len
		sz = pkt[off] castto(size)
		if sz == 0
			-> off + 1
		elif (sz & 0xC0) == 0xC0
			-> off + 2
		else
			off += sz + 1
		;;
	;;
	-> pkt.len + 1
}

/* Debug helper: prints the (possibly compressed) name at `off`; returns the offset after it. */
const printname = {pkt, off
	var sz

	for sz = pkt[off] castto(size); sz != 0; sz = pkt[off] castto(size)
		if (sz & 0xC0) == 0xC0
			put("PTR: ")
			printname(pkt, ((sz & ~0xC0) << 8) | (pkt[off + 1] castto(size)))
			-> off + 2
		else
			put("%s.", pkt[off+1:off+sz+1])
			off += sz + 1
		;;
	;;
	-> off + 1
}

/* Stores `v` big-endian at buf[off]; returns the number of bytes written. */
const pack16 = {buf, off, v
	buf[off] = ((v >> 8) & 0xff) castto(byte)
	buf[off+1] = (v & 0xff) castto(byte)
	-> sizeof(uint16) /* we always write one uint16 */
}

/* Reads a big-endian uint16 at buf[off]; returns (value, offset after it). */
const unpack16 = {buf, off
	var v

	v = (buf[off] castto(uint16)) << 8
	v |= (buf[off + 1] castto(uint16))
	-> (v, off+sizeof(uint16))
}

/* Reads a big-endian uint32 at buf[off]; returns (value, offset after it). */
const unpack32 = {buf, off
	var v

	v = (buf[off] castto(uint32)) << 24
	v |= (buf[off+1] castto(uint32)) << 16
	v |= (buf[off+2] castto(uint32)) << 8
	v |= (buf[off+3] castto(uint32))
	-> (v, off+sizeof(uint32))
}

/*
Encodes `host` as a sequence of length-prefixed labels followed by the
zero-length root label, writing at buf[off]. A single trailing '.' is
accepted. Returns the number of bytes written.
*/
const packname = {buf, off : size, host
	var i
	var start
	var last

	start = off
	/* index of the first byte of the label currently being scanned */
	last = 0
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			off += addseg(buf, off, host[last:i])
			last = i + 1
		;;
	;;
	/* a trailing '.' (or an empty host) leaves no final label to emit */
	if last < host.len
		off += addseg(buf, off, host[last:])
	;;
	off += addseg(buf, off, "")	/* null terminating segment */
	-> off - start
}

/* Writes `str` as one length-prefixed label at buf[off]; returns the bytes written. */
const addseg = {buf, off, str
	buf[off] = str.len castto(byte)
	slcp(buf[off + 1 : off + str.len + 1], str)
	-> str.len + 1
}

/*
Checks that `host` can be encoded as a DNS name: non-empty, at most 255
chars, labels of 1..63 bytes (one trailing '.' allowed), and ASCII only.
*/
const valid = {host : byte[:]
	var i
	var seglen

	if host.len == 0 || host.len > 255
		-> false
	;;

	seglen = 0
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			/* an empty label would be encoded as the name terminator */
			if seglen == 0
				-> false
			;;
			seglen = 0
		else
			seglen++
		;;
		if seglen > 63
			-> false
		;;
		if (host[i] & 0x80) != 0
			-> false
		;;
	;;
	-> true
}