ref: 2788f5eea5708ab8554bac41411543bef360389e
dir: /libstd/resolve.myr/
use "alloc.use"
use "die.use"
use "endian.use"
use "error.use"
use "fmt.use"
use "option.use"
use "slcp.use"
use "slurp.use"
use "strsplit.use"
use "strstrip.use"
use "sys.use"
use "types.use"

pkg std =
	/* reasons a lookup can fail */
	type resolveerr = union
		`Badhost
		`Badsrv
		`Badquery
		`Badresp
	;;

	/* a raw network address, in the byte order found in the DNS packet */
	type netaddr = union
		`Ipv4	byte[4]
		`Ipv6	byte[16]
	;;

	/* one resolved address record */
	type hostinfo = struct
		fam	: sockfam
		stype	: socktype
		ttl	: uint32
		addr	: netaddr
	/*
		proto	: uint32
		flags	: uint32
		addr	: sockaddr[:]
		canon	: byte[:]
		next	: hostinfo#
	*/
	;;

	const resolve	: (host : byte[:] -> error(hostinfo[:], resolveerr))
;;

const Hostfile = "/etc/hosts"

/*
 Resolves 'host' to a list of addresses. The hosts file lookup is
 consulted first; if it yields nothing, a DNS query is made.
 */
const resolve = {host : byte[:]
	match hostfind(host)
	| `Some h:	-> h
	| `None:	-> dnsresolve(host)
	;;
}

/*
 Hosts file lookup. Currently a stub that always returns `None; the
 intended implementation is kept below, commented out.
 */
const hostfind = {host
	-> `None
	/*
	var hdat
	var lines
	var ip
	var hn
	var str
	var i

	match slurp(Hostfile)
	| `Success h:	hdat = h
	| `Failure m:	-> `None
	;;

	lines = strsplit(hdat, "\n")
	for i = 0; i < lines.len; i++
		lines[i] = strstrip(lines[i])
		(ip, str) = nextword(lines)
		(hn, str) = nextword(str)
		if streq(hn, host)
			-> parseip(ip)
		;;
	;;
	*/
}

/*
 Resolves 'host' over DNS. Fails with `Badhost if the name does not pass
 valid(), and with `Badsrv if no server connection could be set up.
 The server socket is closed once the query has completed.
 */
const dnsresolve = {host : byte[:]
	/*var hosts*/
	var nsrv
	var r

	if !valid(host)
		-> `Failure (`Badhost)
	;;
	if (nsrv = dnsconnect()) < 0
		-> `Failure (`Badsrv)
	;;
	r = dnsquery(nsrv, host)
	/* don't leak the descriptor: one socket is opened per lookup */
	close(nsrv)
	-> r
}

/*
 Opens a datagram socket and connects it to the name server.
 Returns the socket on success, or -1 on failure (after printing a warning).
 */
const dnsconnect = {
	var sa : sockaddr_in
	var s
	var status

	s = socket(Afinet, Sockdgram, 0)
	if s < 0
		put("Warning: Failed to open socket: %l\n", s)
		-> -1
	;;
	/* hardcode Google DNS for now. FIXME: parse /etc/resolv.conf */
	sa.fam = Afinet
	sa.port = hosttonet(53)	/* port 53 */
	sa.addr = [8,8,8,8]	/* 8.8.8.8 */
	status = connect(s, (&sa) castto(sockaddr#), sizeof(sockaddr_in))
	if status < 0
		put("Warning: Failed to connect to server: %l\n", status)
		/* the socket is useless to the caller; release it */
		close(s)
		-> -1
	;;
	-> s
}

/* Sends a query for 'host' on 'srv' and parses the matching response. */
const dnsquery = {srv, host
	var id
	var r

	id = tquery(srv, host)
	r = rquery(srv, id)
	put("Got hosts. Returning\n")
	-> r
}

/*
 DNS header flag bits, as positioned in the 16 bit big-endian flags
 word (RFC 1035, section 4.1.1).
 */
const Qr : uint16 = 1 << 15	/* response */
const Aa : uint16 = 1 << 10	/* authoritative answer */
const Tc : uint16 = 1 << 9	/* truncated */
const Rd : uint16 = 1 << 8	/* recursion desired */
const Ra : uint16 = 1 << 7	/* recursion available */

var nextid : uint16 = 42

/*
 Builds an A record query for 'host' and writes it to 'srv'.
 Returns the id that was placed in the packet; nextid is bumped
 for the following query.
 */
const tquery = {srv, host
	var pkt : byte[512] /* big enough */
	var off : size

	put("Sending request for %s\n", host)
	/* header */
	off = 0
	off += pack16(pkt[:], off, nextid)	/* id */
	off += pack16(pkt[:], off, Rd)	/* flags: ask the server to recurse */
	off += pack16(pkt[:], off, 1)	/* qdcount */
	off += pack16(pkt[:], off, 0)	/* ancount */
	off += pack16(pkt[:], off, 0)	/* nscount */
	off += pack16(pkt[:], off, 0)	/* arcount */

	/* query */
	off += packname(pkt[:], off, host)	/* host */
	off += pack16(pkt[:], off, 0x1) /* qtype: a record */
	off += pack16(pkt[:], off, 0x1) /* qclass: inet4 */

	write(srv, pkt[:off])
	-> nextid++
}

/*
 Reads one response packet from 'srv' and extracts the host records
 from it. A failed read yields `Badresp instead of slicing the buffer
 with a negative length.
 */
const rquery = {srv, id
	var pktbuf : byte[1024]
	var pkt
	var n

	put("Waiting for response...\n")
	n = read(srv, pktbuf[:])
	if n < 0
		put("Warning: Failed to read from %z: %i\n", srv, n)
		-> `Failure (`Badresp)
	;;
	pkt = pktbuf[:n]
	put("Got response:\n")
	dumpresponse(pkt)
	-> hosts(pkt, id)
}

/*
 Parses the response in 'pkt'. The response id must match 'id',
 otherwise `Badresp is returned. On success, returns one hostinfo
 per answer record.
 */
const hosts = {pkt, id : uint16
	var off
	var v, q, a
	var i
	var hinf : hostinfo[:]

	/* the fixed header is six 16 bit words; anything shorter is not a reply */
	if pkt.len < 12
		-> `Failure (`Badresp)
	;;

	off = 0
	/* parse header */
	(v, off) = unpack16(pkt, off)	/* id */
	if v != id
		-> `Failure (`Badresp)
	;;
	put("Unpacking flags")
	(v, off) = unpack16(pkt, off)	/* flags */
	(q, off) = unpack16(pkt, off)	/* qdcount */
	(a, off) = unpack16(pkt, off)	/* ancount */
	(v, off) = unpack16(pkt, off)	/* nscount */
	(v, off) = unpack16(pkt, off)	/* arcount */

	/* skip past query records */
	for i = 0; i < q; i++
		put("Skipping query record")
		off = skipname(pkt, off)	/* name */
		(v, off) = unpack16(pkt, off)	/* type */
		(v, off) = unpack16(pkt, off)	/* class */
	;;

	/* parse answer records */
	hinf = slalloc(a castto(size))
	for i = 0; i < a; i++
		off = skipname(pkt, off)	/* name */
		(v, off) = unpack16(pkt, off)	/* type */
		(v, off) = unpack16(pkt, off)	/* class */
		(hinf[i].ttl, off) = unpack32(pkt, off)	/* ttl */
		(v, off) = unpack16(pkt, off)	/* rdatalen */
		/*
		 the thing we're interested in: our IP address.
		 NOTE(review): this assumes every answer is an A record with
		 4 bytes of rdata; type and rdatalen are not checked.
		 */
		hinf[i].addr = `Ipv4 [pkt[off], pkt[off+1], pkt[off+2], pkt[off+3]]
		off += 4
	;;
	-> `Success hinf
}

/* Debug helper: prints the header counts, queries and answers of 'pkt'. */
const dumpresponse = {pkt
	var nquery, nans
	var off
	var v
	var i

	(v, off) = unpack16(pkt, 0)	/* id */
	(v, off) = unpack16(pkt, off)	/* flags */
	(nquery, off) = unpack16(pkt, off)
	put("hdr.qdcount = %w\n", nquery)
	(nans, off) = unpack16(pkt, off)
	put("hdr.ancount = %w\n", nans)
	(v, off) = unpack16(pkt, off)
	put("hdr.nscount = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("hdr.arcount = %w\n", v)

	put("Queries:\n")
	for i = 0; i < nquery; i++
		put("i: %w\n", i)
		off = dumpquery(pkt, off)
	;;

	put("Answers:")
	for i = 0; i < nans; i++
		put("i: %w\n", i)
		off = dumpans(pkt, off)
	;;
}

/* Debug helper: prints the query record at 'off'; returns the offset after it. */
const dumpquery = {pkt, off
	var v

	put("\tname = ")
	off = printname(pkt, off)
	(v, off) = unpack16(pkt, off)
	put("\tbody.type = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.class = %w\n", v)
	-> off
}

/*
 Debug helper: prints the answer record at 'off'; returns the offset
 after it. The rdata is read as two 16 bit words.
 */
const dumpans = {pkt, off
	var v

	put("\tname = ")
	off = printname(pkt, off)
	(v, off) = unpack16(pkt, off)
	put("\tbody.type = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.class = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.ttl_lo = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.ttl_hi = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.rdlength = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.rdata_lo = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.rdata_hi = %w\n", v)
	-> off
}

/*
 Returns the offset just past the encoded name starting at 'off'.
 A length byte with both top bits set marks a 2 byte pointer, which
 ends the name.
 */
const skipname = {pkt, off
	var sz

	for sz = pkt[off] castto(size); sz != 0; sz = pkt[off] castto(size)
		/* ptr is 2 bytes */
		if sz & 0xC0 == 0xC0
			-> off + 2
		else
			off += sz + 1
		;;
	;;
	/* step over the terminating zero length byte */
	-> off + 1
}

/*
 Prints the encoded name starting at 'off', following pointers, and
 returns the offset just past it.
 */
const printname = {pkt, off
	var sz

	for sz = pkt[off] castto(size); sz != 0; sz = pkt[off] castto(size)
		if sz & 0xC0 == 0xC0
			put("PTR: ")
			/* the pointer target is the low 14 bits of the 2 byte field */
			printname(pkt, ((sz & ~0xC0) << 8) | (pkt[off + 1] castto(size)))
			-> off + 2
		else
			put("%s.", pkt[off+1:off+sz+1])
			off += sz + 1
		;;
	;;
	-> off + 1
}

/* Writes 'v' big-endian at buf[off]; returns the number of bytes written. */
const pack16 = {buf, off, v
	buf[off]   = (v & 0xff00) >> 8 castto(byte)
	buf[off+1] = (v & 0x00ff) castto(byte)
	-> sizeof(uint16) /* we always write one uint16 */
}

/* Reads a big-endian uint16 at buf[off]; returns (value, next offset). */
const unpack16 = {buf, off
	var v

	v = (buf[off] castto(uint16)) << 8
	v |= (buf[off + 1] castto(uint16))
	-> (v, off+sizeof(uint16))
}

/* Reads a big-endian uint32 at buf[off]; returns (value, next offset). */
const unpack32 = {buf, off
	var v

	v = (buf[off] castto(uint32)) << 24
	v |= (buf[off+1] castto(uint32)) << 16
	v |= (buf[off+2] castto(uint32)) << 8
	v |= (buf[off+3] castto(uint32))
	-> (v, off+sizeof(uint32))
}

/*
 Encodes 'host' at buf[off] as a sequence of length-prefixed labels
 followed by a zero length label. A single trailing '.' on 'host' is
 accepted. 'host' must not be empty. Returns the number of bytes written.
 */
const packname = {buf, off : size, host
	var i
	var start
	var seglen, lastseg

	start = off
	seglen = 0
	lastseg = 0
	for i = 0; i < host.len; i++
		seglen++
		if host[i] == ('.' castto(byte))
			off += addseg(buf, off, host[lastseg:lastseg+seglen-1])
			/* the next label begins right after this dot */
			lastseg = i + 1
			seglen = 0
		;;
	;;
	if host[host.len - 1] != ('.' castto(byte))
		off += addseg(buf, off, host[lastseg:lastseg + seglen])
	;;
	off += addseg(buf, off, "")	/* null terminating segment */
	-> off - start
}

/* Writes one length-prefixed label at buf[off]; returns the bytes written. */
const addseg = {buf, off, str
	buf[off] = str.len castto(byte)
	slcp(buf[off + 1 : off + str.len + 1], str)
	-> str.len + 1
}

/*
 Checks that 'host' can be encoded: non-empty, at most 255 bytes in
 total, at most 63 bytes per dot-separated label, and no bytes with
 the high bit set.
 */
const valid = {host : byte[:]
	var i
	var seglen

	/* packname looks at the last byte, so there has to be one */
	if host.len == 0
		-> false
	;;
	/* maximum length: 255 chars */
	if host.len > 255
		-> false
	;;

	seglen = 0
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			seglen = 0
		else
			seglen++
		;;
		/* a label length must fit in the low 6 bits of its prefix byte */
		if seglen > 63
			-> false
		;;
		if (host[i] & 0x80) != 0
			-> false
		;;
	;;

	-> true
}