ref: 4d34688e0e2c6a69081856112fa818b3d02f644d
dir: /libstd/resolve.myr/
use "alloc.use"
use "chartype.use"
use "die.use"
use "endian.use"
use "error.use"
use "extremum.use"
use "hashfuncs.use"
use "htab.use"
use "ipparse.use"
use "fmt.use"
use "option.use"
use "slcp.use"
use "slpush.use"
use "slurp.use"
use "strfind.use"
use "strsplit.use"
use "strstrip.use"
use "sys.use"
use "types.use"
use "utf.use"

pkg std =
	type resolveerr = union
		`Badhost
		`Badsrv
		`Badquery
		`Badresp
	;;

	type hostinfo = struct
		fam	: sockfam
		stype	: socktype
		ttl	: uint32
		addr	: netaddr
		/*
		proto	: uint32
		flags	: uint32
		addr	: sockaddr[:]
		canon	: byte[:]
		next	: hostinfo#
		*/
	;;

	const resolve	: (host : byte[:] -> error(hostinfo[:], resolveerr))
;;

const Hostfile = "/etc/hosts"
const Resolvfile = "/etc/resolv.conf"

var hostmap	: htab(byte[:], hostinfo)#
var search	: byte[:][:]
var nameservers	: netaddr[:]
var inited	: bool = false

/*
 Resolves 'host' to a list of addresses. The hosts file is consulted
 first; if the name is not listed there, a DNS query is made.
*/
const resolve = {host : byte[:] -> error(hostinfo[:], resolveerr)
	match hostfind(host)
	| `Some hinf:
		-> `Success slpush([][:], hinf)
	| `None:
		put("********** Couldn't find host %s in hosts\n", host)
		-> dnsresolve(host)
	;;
}

/*
 Looks 'host' up in the hosts table, lazily building the table from
 Hostfile on first use.
*/
const hostfind = {host
	if !inited
		hostmap = mkht(strhash, streq)
		loadhosts()
		inited = true
	;;
	-> htget(hostmap, host)
}

/*
 Reads Hostfile and registers every "ip name [name...]" line in hostmap.
 A missing or unreadable file is silently treated as empty. The file
 contents are kept alive on purpose: the names stored in hostmap are
 slices into them.
*/
const loadhosts = {
	var h
	var lines

	match slurp(Hostfile)
	| `Success d:	h = d
	| `Failure m:	->
	;;

	lines = strsplit(h, "\n")
	for l in lines
		/* trim comment */
		match strfind(l, "#")
		| `Some _idx:	l = l[:_idx]
		| `None:
		;;

		match word(l)
		| `Some (ip, rest):
			match ipparse(ip)
			| `None:
			| `Some addr:
				addhosts(addr, ip, rest)
			;;
		| `None:
		;;
	;;
	/* only the outer slice is freed; its elements still point into 'h' */
	slfree(lines)
}

/*
 Adds every whitespace-separated name in 'str' to hostmap, mapping it to
 'addr'. Names that are already present keep their first mapping.
*/
const addhosts = {addr, as, str
	var hinf
	var fam

	match addr
	| `Ipv4 _:	fam = Afinet
	| `Ipv6 _:	fam = Afinet6
	;;
	while true
		match word(str)
		| `Some (name, rest):
			if !hthas(hostmap, name)
				hinf = [
					.fam=fam,
					.stype = 0,
					.ttl = 0,
					.addr = addr
				]
				htput(hostmap, name, hinf)
			;;
			str = rest
		| `None:
			->
		;;
	;;
}

const loadresolv = {
}

/*
 Splits the first blank-delimited word off of 's' (after stripping
 surrounding whitespace). Returns `None if there is no word.
*/
const word = {s
	var c, len

	len = 0
	s = strstrip(s)
	for c = decode(s[len:]); c != Badchar && !isblank(c); c = decode(s[len:])
		len += charlen(c)
	;;
	if len == 0
		-> `None
	else
		-> `Some (s[:len], s[len:])
	;;
}

/*
 Resolves 'host' over DNS. The socket used for the query is always
 closed before returning.
*/
const dnsresolve = {host : byte[:]
	/*var hosts*/
	var nsrv
	var r

	if !valid(host)
		-> `Failure (`Badhost)
	;;
	nsrv = dnsconnect()
	if nsrv < 0
		-> `Failure (`Badsrv)
	;;
	r = dnsquery(nsrv, host)
	close(nsrv)
	-> r
}

/*
 Opens a UDP socket connected to the name server. Returns the socket,
 or -1 on failure.
*/
const dnsconnect = {
	var sa : sockaddr_in
	var s
	var status

	s = socket(Afinet, Sockdgram, 0)
	if s < 0
		put("Warning: Failed to open socket: %l\n", s)
		-> -1
	;;
	/* hardcode Google DNS for now.
	 FIXME: parse /etc/resolv.conf */
	sa.fam = Afinet
	sa.port = hosttonet(53)	/* port 53 */
	sa.addr = [8,8,8,8]	/* 8.8.8.8 */
	status = connect(s, (&sa) castto(sockaddr#), sizeof(sockaddr_in))
	if status < 0
		put("Warning: Failed to connect to server: %l\n", status)
		/* don't leak the socket we just opened */
		close(s)
		-> -1
	;;
	-> s
}

/* Sends an A query for 'host' on 'srv' and parses the reply. */
const dnsquery = {srv, host
	var id
	var r

	id = tquery(srv, host)
	r = rquery(srv, id)
	put("Got hosts. Returning\n")
	-> r
}

/*
 DNS header flag bits. RFC 1035 numbers the header bits from the most
 significant end (QR is "bit 0"), so they are converted to shifts here.
*/
const Qr : uint16 = 1 << 15
const Aa : uint16 = 1 << 10
const Tc : uint16 = 1 << 9
const Rd : uint16 = 1 << 8
const Ra : uint16 = 1 << 7
/* the low 4 bits of the flags word hold the response code */
const Rcodemask : uint16 = 0xf
const Nxdomain : uint16 = 3
/* a DNS header is six 16-bit fields */
const Hdrsz = 12

var nextid : uint16 = 42

/*
 Builds and sends a recursive A/IN query for 'host' on 'srv'.
 Returns the query id used, so the reply can be matched against it.
*/
const tquery = {srv, host
	var pkt : byte[512] /* big enough */
	var off : size

	put("Sending request for %s\n", host)
	/* header */
	off = 0
	off += pack16(pkt[:], off, nextid)	/* id */
	off += pack16(pkt[:], off, Rd)	/* flags: recursion desired */
	off += pack16(pkt[:], off, 1)	/* qdcount */
	off += pack16(pkt[:], off, 0)	/* ancount */
	off += pack16(pkt[:], off, 0)	/* nscount */
	off += pack16(pkt[:], off, 0)	/* arcount */

	/* query */
	off += packname(pkt[:], off, host)	/* host */
	off += pack16(pkt[:], off, 0x1)	/* qtype: a record */
	off += pack16(pkt[:], off, 0x1)	/* qclass: inet4 */

	write(srv, pkt[:off])
	-> nextid++
}

/*
 Reads one reply from 'srv' and extracts the addresses from it.
 Fails with `Badresp if the read fails or the reply is too short to
 hold a DNS header.
*/
const rquery = {srv, id
	var pktbuf : byte[1024]
	var pkt
	var n

	put("Waiting for response...\n")
	n = read(srv, pktbuf[:])
	if n < 0
		put("Warning: Failed to read from %z: %i\n", srv, n)
		-> `Failure (`Badresp)
	;;
	if n < Hdrsz
		-> `Failure (`Badresp)
	;;
	pkt = pktbuf[:n]
	put("Got response:\n")
	dumpresponse(pkt)
	-> hosts(pkt, id)
}

/*
 Parses the DNS response 'pkt', which must answer query 'id'.
 Only A records are returned. Other answer types (e.g. CNAME) are
 skipped using their rdlength. A response with RCODE NXDOMAIN fails
 with `Badhost; any other error response fails with `Badresp.
*/
const hosts = {pkt, id : uint16
	var off
	var v, q, a
	var flags, rcode
	var ty, rdlen, ttl
	var i
	var rec : hostinfo
	var hinf : hostinfo[:]

	off = 0
	/* parse header */
	(v, off) = unpack16(pkt, off)	/* id */
	if v != id
		-> `Failure (`Badresp)
	;;
	put("Unpacking flags\n")
	(flags, off) = unpack16(pkt, off)	/* flags */
	(q, off) = unpack16(pkt, off)	/* qdcount */
	(a, off) = unpack16(pkt, off)	/* ancount */
	(v, off) = unpack16(pkt, off)	/* nscount */
	(v, off) = unpack16(pkt, off)	/* arcount */

	/* must be a response, and must not report an error */
	if (flags & Qr) == 0
		-> `Failure (`Badresp)
	;;
	rcode = flags & Rcodemask
	if rcode == Nxdomain
		-> `Failure (`Badhost)
	elif rcode != 0
		-> `Failure (`Badresp)
	;;

	/* skip past query records */
	for i = 0; i < q; i++
		put("Skipping query record\n")
		off = skipname(pkt, off)	/* name */
		(v, off) = unpack16(pkt, off)	/* type */
		(v, off) = unpack16(pkt, off)	/* class */
	;;

	/* parse answer records, keeping only A records */
	hinf = [][:]
	for i = 0; i < a; i++
		off = skipname(pkt, off)	/* name */
		(ty, off) = unpack16(pkt, off)	/* type */
		(v, off) = unpack16(pkt, off)	/* class */
		(ttl, off) = unpack32(pkt, off)	/* ttl */
		(rdlen, off) = unpack16(pkt, off)	/* rdatalen */
		if off + (rdlen castto(size)) > pkt.len
			-> `Failure (`Badresp)
		;;
		/* the thing we're interested in: our IP address */
		if ty == 1 && rdlen == 4
			rec.fam = Afinet
			rec.stype = 0
			rec.ttl = ttl
			rec.addr = `Ipv4 [pkt[off], pkt[off+1], pkt[off+2], pkt[off+3]]
			hinf = slpush(hinf, rec)
		;;
		off += (rdlen castto(size))
	;;
	-> `Success hinf
}

/* Debug helper: prints the header, queries and answers of 'pkt'. */
const dumpresponse = {pkt
	var nquery, nans
	var off
	var v
	var i

	(v, off) = unpack16(pkt, 0)
	(v, off) = unpack16(pkt, off)
	(nquery, off) = unpack16(pkt, off)
	put("hdr.qdcount = %w\n", nquery)
	(nans, off) = unpack16(pkt, off)
	put("hdr.ancount = %w\n", nans)
	(v, off) = unpack16(pkt, off)
	put("hdr.nscount = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("hdr.arcount = %w\n", v)

	put("Queries:\n")
	for i = 0; i < nquery; i++
		put("i: %w\n", i)
		off = dumpquery(pkt, off)
	;;

	put("Answers:\n")
	for i = 0; i < nans; i++
		put("i: %w\n", i)
		off = dumpans(pkt, off)
	;;
}

/* Debug helper: prints the query record at 'off'; returns the offset past it. */
const dumpquery = {pkt, off
	var v

	put("\tname = ")
	off = printname(pkt, off)
	put("\n")
	(v, off) = unpack16(pkt, off)
	put("\tbody.type = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.class = %w\n", v)
	-> off
}

/*
 Debug helper: prints the answer record at 'off'; returns the offset
 past it. The rdata is printed only when it is 4 bytes long (A record);
 otherwise it is skipped using rdlength.
*/
const dumpans = {pkt, off
	var v, rdlen

	put("\tname = ")
	off = printname(pkt, off)
	put("\n")
	(v, off) = unpack16(pkt, off)
	put("\tbody.type = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.class = %w\n", v)
	/* network byte order: the high half comes first */
	(v, off) = unpack16(pkt, off)
	put("\tbody.ttl_hi = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.ttl_lo = %w\n", v)
	(rdlen, off) = unpack16(pkt, off)
	put("\tbody.rdlength = %w\n", rdlen)
	if rdlen == 4
		(v, off) = unpack16(pkt, off)
		put("\tbody.rdata_hi = %w\n", v)
		(v, off) = unpack16(pkt, off)
		put("\tbody.rdata_lo = %w\n", v)
	else
		/* not a 4-byte payload: skip it rather than misparse */
		off += (rdlen castto(size))
	;;
	-> off
}

/*
 Returns the offset just past the (possibly compressed) domain name
 starting at 'off' in 'pkt'.
*/
const skipname = {pkt, off
	var sz

	for sz = pkt[off] castto(size); sz != 0; sz = pkt[off] castto(size)
		/* ptr is 2 bytes */
		if (sz & 0xC0) == 0xC0
			-> off + 2
		else
			off += sz + 1
		;;
	;;
	-> off + 1
}

/*
 Debug helper: prints the domain name at 'off', following compression
 pointers. Returns the offset just past the name.
*/
const printname = {pkt, off
	var sz

	for sz = pkt[off] castto(size); sz != 0; sz = pkt[off] castto(size)
		if (sz & 0xC0) == 0xC0
			put("PTR: ")
			printname(pkt, ((sz & ~0xC0) << 8) | (pkt[off + 1] castto(size)))
			-> off + 2
		else
			put("%s.", pkt[off+1:off+sz+1])
			off += sz + 1
		;;
	;;
	-> off + 1
}

/* Writes 'v' big-endian at buf[off]; returns the number of bytes written. */
const pack16 = {buf, off, v
	buf[off] = (v & 0xff00) >> 8 castto(byte)
	buf[off+1] = (v & 0x00ff) castto(byte)
	-> sizeof(uint16) /* we always write one uint16 */
}

/* Reads a big-endian uint16 at buf[off]; returns (value, offset after it). */
const unpack16 = {buf, off
	var v

	v = (buf[off] castto(uint16)) << 8
	v |= (buf[off + 1] castto(uint16))
	-> (v, off+sizeof(uint16))
}

/* Reads a big-endian uint32 at buf[off]; returns (value, offset after it). */
const unpack32 = {buf, off
	var v

	v = (buf[off] castto(uint32)) << 24
	v |= (buf[off+1] castto(uint32)) << 16
	v |= (buf[off+2] castto(uint32)) << 8
	v |= (buf[off+3] castto(uint32))
	-> (v, off+sizeof(uint32))
}

/*
 Encodes 'host' as a sequence of length-prefixed labels at buf[off],
 terminated by a zero-length label. A trailing '.' is accepted.
 Returns the number of bytes written.
*/
const packname = {buf, off : size, host
	var i
	var start
	var seglen, lastseg

	start = off
	seglen = 0
	lastseg = 0
	for i = 0; i < host.len; i++
		seglen++
		if host[i] == ('.' castto(byte))
			off += addseg(buf, off, host[lastseg:lastseg+seglen-1])
			/* the next label starts just after this dot */
			lastseg = i + 1
			seglen = 0
		;;
	;;
	if host.len > 0 && host[host.len - 1] != ('.' castto(byte))
		off += addseg(buf, off, host[lastseg:lastseg + seglen])
	;;
	off += addseg(buf, off, "") /* null terminating segment */
	-> off - start
}

/* Writes 'str' as one length-prefixed label; returns the bytes written. */
const addseg = {buf, off, str
	buf[off] = str.len castto(byte)
	slcp(buf[off + 1 : off + str.len + 1], str)
	-> str.len + 1
}

/*
 Checks that 'host' can be sent as a DNS name. The host must be
 non-empty and at most 255 bytes. Each label must be 1..63 bytes
 (a single trailing '.' is allowed), and every byte must be 7-bit ASCII.
*/
const valid = {host : byte[:]
	var i
	var seglen

	/* maximum length: 255 chars */
	if host.len == 0 || host.len > 255
		-> false
	;;

	seglen = 0
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			/* an empty label would be encoded as the name terminator */
			if seglen == 0
				-> false
			;;
			seglen = 0
		else
			seglen++
		;;
		if seglen > 63
			-> false
		;;
		if (host[i] & 0x80) != 0
			-> false
		;;
	;;
	-> true
}