ref: 0226a1a631e50ac5ad042efc85cf5d5461f5b0f8
dir: /libstd/resolve.myr/
use "alloc.use"
use "chartype.use"
use "die.use"
use "endian.use"
use "error.use"
use "extremum.use"
use "fmt.use"
use "hashfuncs.use"
use "htab.use"
use "ipparse.use"
use "option.use"
use "slcp.use"
use "sleq.use"
use "slpush.use"
use "slurp.use"
use "strfind.use"
use "strsplit.use"
use "strstrip.use"
use "sys.use"
use "types.use"
use "utf.use"

pkg std =
	type rectype = uint16

	const DnsA	: rectype = 1	/* host address */
	const DnsNS	: rectype = 2	/* authoritative name server */
	const DnsMD	: rectype = 3	/* mail destination (Obsolete - use MX) */
	const DnsMF	: rectype = 4	/* mail forwarder (Obsolete - use MX) */
	const DnsCNAME	: rectype = 5	/* canonical name for an alias */
	const DnsSOA	: rectype = 6	/* marks the start of a zone of authority */
	const DnsMB	: rectype = 7	/* mailbox domain name (EXPERIMENTAL) */
	const DnsMG	: rectype = 8	/* mail group member (EXPERIMENTAL) */
	const DnsMR	: rectype = 9	/* mail rename domain name (EXPERIMENTAL) */
	const DnsNULL	: rectype = 10	/* null RR (EXPERIMENTAL) */
	const DnsWKS	: rectype = 11	/* well known service description */
	const DnsPTR	: rectype = 12	/* domain name pointer */
	const DnsHINFO	: rectype = 13	/* host information */
	const DnsMINFO	: rectype = 14	/* mailbox or mail list information */
	const DnsMX	: rectype = 15	/* mail exchange */
	const DnsTXT	: rectype = 16	/* text strings */
	const DnsAAAA	: rectype = 28	/* ipv6 host address */

	type resolveerr = union
		`Badhost
		`Badsrv
		`Badquery
		`Badresp
	;;

	type hostinfo = struct
		fam	: sockfam
		stype	: socktype
		ttl	: uint32
		addr	: netaddr
	/*
		flags	: uint32
		addr	: sockaddr[:]
		canon	: byte[:]
	*/
	;;

	const resolve	: (host : byte[:]	-> error(hostinfo[:], resolveerr))
	const resolvemx	: (host : byte[:]	-> error(hostinfo[:], resolveerr))
	const resolverec	: (host : byte[:], t : rectype	-> error(hostinfo[:], resolveerr))
;;

const Hostfile = "/etc/hosts"
const Resolvfile = "/etc/resolv.conf"

/* name -> hostinfo table, filled from Hostfile on first lookup */
var hostmap	: htab(byte[:], hostinfo)#
var search	: byte[:][:]
/* name server addresses, filled from Resolvfile on first lookup */
var nameservers	: netaddr[:]
/* set once hostmap and nameservers have been loaded */
var inited	: bool = false

/* looks up the A records for 'host' */
const resolve = {host
	-> resolverec(host, DnsA)
}

/* looks up the MX records for 'host' */
const resolvemx = {host
	-> resolverec(host, DnsMX)
}

/*
 Resolves 'host', consulting the hosts table first and falling back
 to a DNS query for records of type 't'.
 */
const resolverec = {host, t
	match hostfind(host)
	| `Some hinf:
		-> `Success slpush([][:], hinf)
	| `None:
		/* query for the requested record type, not unconditionally DnsA */
		-> dnsresolve(host, t)
	;;
}

/*
 Returns the hosts-table entry for 'host', if any. The first call
 also loads the hosts file and the resolver configuration.
 */
const hostfind = {host
	if !inited
		hostmap = mkht(strhash, streq)
		loadhosts()
		loadresolv()
		inited = true
	;;
	-> htget(hostmap, host)
}

/*
 Parses Hostfile into hostmap. Does nothing if the file can't be read.
 The file contents are intentionally not freed: the names stored in
 hostmap are slices of that buffer.
 */
const loadhosts = {
	var h
	var lines

	match slurp(Hostfile)
	| `Success d:	h = d
	| `Failure m:	->
	;;

	lines = strsplit(h, "\n")
	for l in lines
		/* trim comment */
		match strfind(l, "#")
		| `Some _idx:	l = l[:_idx]
		| `None:
		;;

		match word(l)
		| `Some (ip, rest):
			match ipparse(ip)
			| `Some addr:
				addhosts(addr, ip, rest)
			| `None:
				/* first word is not an address; ignore the line */
			;;
		| `None:
		;;
	;;
	slfree(lines)
}

/*
 Maps every whitespace separated name in 'str' to 'addr' in hostmap.
 A name that is already present keeps its existing entry.
 */
const addhosts = {addr, as, str
	var hinf
	var fam

	match addr
	| `Ipv4 _:	fam = Afinet
	| `Ipv6 _:	fam = Afinet6
	;;
	while true
		match word(str)
		| `Some (name, rest):
			if !hthas(hostmap, name)
				hinf = [
					.fam=fam,
					.stype = 0,
					.ttl = 0,
					.addr = addr
				]
				htput(hostmap, name, hinf)
			;;
			str = rest
		| `None:
			->
		;;
	;;
}

/*
 Parses Resolvfile, recording each "nameserver" entry. Does nothing
 if the file can't be read.
 */
const loadresolv = {
	var h
	var lines

	match slurp(Resolvfile)
	| `Success d:	h = d
	| `Failure m:	->
	;;

	lines = strsplit(h, "\n")
	for l in lines
		/* trim comment */
		match strfind(l, "#")
		| `Some _idx:	l = l[:_idx]
		| `None:
		;;

		match word(l)
		| `Some (cmd, rest):
			if sleq(cmd, "nameserver")
				addns(rest)
			;;
		| `None:
		;;
	;;
	slfree(lines)
}

/*
 Appends the address that starts 'rest' to nameservers. Input that
 does not parse as an address is ignored.
 */
const addns = {rest
	match word(rest)
	| `Some (name, _):
		match ipparse(name)
		| `Some addr:
			nameservers = slpush(nameservers, addr)
		| `None:
		;;
	| `None:
	;;
}

/*
 Splits the first blank-delimited word off of 's', after stripping
 the front of 's'. Returns `Some (word, remainder), or `None if
 there is no word.
 */
const word = {s
	var c, len

	len = 0
	s = strfstrip(s)
	for c = decode(s[len:]); c != Badchar && !isblank(c); c = decode(s[len:])
		len += charlen(c)
	;;
	if len == 0
		-> `None
	else
		-> `Some (s[:len], s[len:])
	;;
}

/*
 Queries the first name server we manage to connect to for records
 of type 't' for 'host'. Fails with `Badhost if the name is invalid,
 and with `Badsrv if no server could be connected to.
 */
const dnsresolve = {host, t
	var nsrv
	var r

	if !valid(host)
		-> `Failure (`Badhost)
	;;
	for ns in nameservers
		nsrv = dnsconnect(ns)
		if nsrv >= 0
			r = dnsquery(nsrv, host, t)
			/* the socket is only needed for this one exchange */
			close(nsrv)
			-> r
		;;
	;;
	-> `Failure (`Badsrv)
}

/* connects to the name server 'ns'; only ipv4 servers are handled */
const dnsconnect = {ns
	match ns
	| `Ipv4 addr:
		-> dnsconnectv4(addr)
	| `Ipv6 addr:
		die("don't support ipv6 yet\n")
	;;
}

/*
 Opens a datagram socket connected to port 53 of 'addr'. Returns the
 socket, or -1 on failure.
 */
const dnsconnectv4 = {addr
	var sa : sockaddr_in
	var s
	var status

	s = socket(Afinet, Sockdgram, 0)
	if s < 0
		-> -1
	;;
	sa.fam = Afinet
	sa.port = hosttonet(53)
	sa.addr = addr
	status = connect(s, (&sa) castto(sockaddr#), sizeof(sockaddr_in))
	if status < 0
		/* don't leak the socket when the connect fails */
		close(s)
		-> -1
	;;
	-> s
}

/* sends a query for 'host' over 'srv' and parses the reply */
const dnsquery = {srv, host, t
	var id
	var r

	id = tquery(srv, host, t)
	r = rquery(srv, id)
	-> r
}

/*
 Header flag bits.
 NOTE(review): these positions do not match the RFC 1035 layout of the
 16 bit flags word as pack16 writes it (there, RD is 1 << 8 and RA is
 1 << 7), so the value sent as 'Ra' below lands on the RD bit. Confirm
 before relying on the names. The values are left as they were.
 */
const Qr : uint16 = 1 << 0
const Aa : uint16 = 1 << 5
const Tc : uint16 = 1 << 6
const Rd : uint16 = 1 << 7
const Ra : uint16 = 1 << 8

var nextid : uint16 = 42

/*
 Builds a single-question query packet for 'host' with query type 't'
 and writes it to 'srv'. Returns the id used for the query; nextid is
 bumped afterwards.
 */
const tquery = {srv, host, t
	var pkt : byte[512] /* big enough */
	var off : size

	/* header */
	off = 0
	off += pack16(pkt[:], off, nextid)	/* id */
	off += pack16(pkt[:], off, Ra)	/* flags */
	off += pack16(pkt[:], off, 1)	/* qdcount */
	off += pack16(pkt[:], off, 0)	/* ancount */
	off += pack16(pkt[:], off, 0)	/* nscount */
	off += pack16(pkt[:], off, 0)	/* arcount */

	/* query */
	off += packname(pkt[:], off, host)	/* host */
	off += pack16(pkt[:], off, t castto(uint16))	/* qtype */
	off += pack16(pkt[:], off, 0x1)	/* qclass: inet4 */

	write(srv, pkt[:off])
	-> nextid++
}

/*
 Reads one reply from 'srv' and parses it, expecting the id 'id'.
 A failed read, or a reply too short to hold the 12 byte header,
 yields `Badresp.
 */
const rquery = {srv, id
	var pktbuf : byte[1024]
	var pkt
	var n

	n = read(srv, pktbuf[:])
	/* covers read errors (n < 0) as well as truncated replies */
	if n < 12
		-> `Failure (`Badresp)
	;;
	pkt = pktbuf[:n]
	-> hosts(pkt, id)
}

/*
 Parses the reply 'pkt', returning one hostinfo per A record in the
 answer section. Fails with `Badresp if the reply id is not 'id', or
 if a record's data runs past the end of the packet. Answer records
 of any other type are skipped using their rdatalen, since hostinfo
 can only carry an address.
 */
const hosts = {pkt, id : uint16
	var off
	var v, q, a
	var rtype, rdlen, ttl
	var i
	var h : hostinfo
	var hinf : hostinfo[:]

	off = 0
	/* parse header */
	(v, off) = unpack16(pkt, off)	/* id */
	if v != id
		-> `Failure (`Badresp)
	;;
	(v, off) = unpack16(pkt, off)	/* flags */
	(q, off) = unpack16(pkt, off)	/* qdcount */
	(a, off) = unpack16(pkt, off)	/* ancount */
	(v, off) = unpack16(pkt, off)	/* nscount */
	(v, off) = unpack16(pkt, off)	/* arcount */

	/* skip past query records */
	for i = 0; i < q; i++
		off = skipname(pkt, off)	/* name */
		(v, off) = unpack16(pkt, off)	/* type */
		(v, off) = unpack16(pkt, off)	/* class */
	;;

	/* parse answer records */
	hinf = [][:]
	for i = 0; i < a; i++
		off = skipname(pkt, off)	/* name */
		(rtype, off) = unpack16(pkt, off)	/* type */
		(v, off) = unpack16(pkt, off)	/* class */
		(ttl, off) = unpack32(pkt, off)	/* ttl */
		(rdlen, off) = unpack16(pkt, off)	/* rdatalen */

		if off + (rdlen castto(size)) > pkt.len
			slfree(hinf)
			-> `Failure (`Badresp)
		;;
		/* the thing we're interested in: our IP address */
		if rtype == (DnsA castto(uint16)) && rdlen == 4
			h.fam = Afinet
			h.stype = 0
			h.ttl = ttl
			h.addr = `Ipv4 [pkt[off], pkt[off+1], pkt[off+2], pkt[off+3]]
			hinf = slpush(hinf, h)
		;;
		/* step over the record data, whatever its type */
		off += rdlen castto(size)
	;;
	-> `Success hinf
}

/*
 Returns the offset just past the encoded name that starts at 'off'.
 A name ends either at a zero length byte or at a 2 byte pointer.
 */
const skipname = {pkt, off
	var sz

	for sz = pkt[off] castto(size); sz != 0; sz = pkt[off] castto(size)
		/* ptr is 2 bytes */
		if (sz & 0xC0) == 0xC0
			-> off + 2
		else
			off += sz + 1
		;;
	;;
	-> off + 1
}

/* writes 'v' at buf[off], high byte first; returns the bytes written */
const pack16 = {buf, off, v
	buf[off]	= ((v & 0xff00) >> 8) castto(byte)
	buf[off+1]	= (v & 0x00ff) castto(byte)
	-> sizeof(uint16) /* we always write one uint16 */
}

/* reads a 16 bit value, high byte first; returns (value, next offset) */
const unpack16 = {buf, off
	var v

	v = (buf[off] castto(uint16)) << 8
	v |= (buf[off + 1] castto(uint16))
	-> (v, off+sizeof(uint16))
}

/* reads a 32 bit value, high byte first; returns (value, next offset) */
const unpack32 = {buf, off
	var v

	v = (buf[off] castto(uint32)) << 24
	v |= (buf[off+1] castto(uint32)) << 16
	v |= (buf[off+2] castto(uint32)) << 8
	v |= (buf[off+3] castto(uint32))
	-> (v, off+sizeof(uint32))
}

/*
 Writes 'host' at buf[off] as a sequence of length prefixed segments,
 one per '.' separated label, followed by an empty terminating
 segment. Returns the number of bytes written.
 */
const packname = {buf, off : size, host
	var i
	var start
	var last

	start = off
	last = 0
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			off += addseg(buf, off, host[last:i])
			last = i + 1
		;;
	;;
	/* the length check keeps an empty host from indexing host[-1] */
	if host.len > 0 && host[host.len - 1] != ('.' castto(byte))
		off += addseg(buf, off, host[last:])
	;;
	off += addseg(buf, off, "")	/* null terminating segment */
	-> off - start
}

/* writes one length prefixed segment; returns the bytes written */
const addseg = {buf, off, str
	buf[off] = str.len castto(byte)
	slcp(buf[off + 1 : off + str.len + 1], str)
	-> str.len + 1
}

/*
 Checks that 'host' can be sent as a query name: non-empty, at most
 255 bytes in total, at most 63 bytes between dots, and no bytes with
 the high bit set.
 */
const valid = {host : byte[:]
	var i
	var seglen

	if host.len == 0
		-> false
	;;
	/* maximum length: 255 chars */
	if host.len > 255
		-> false
	;;

	seglen = 0
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			seglen = 0
		else
			seglen++
		;;
		if seglen > 63
			-> false
		;;
		if (host[i] & 0x80) != 0
			-> false
		;;
	;;

	-> true
}