shithub: mc

ref: 4a7d019d03fa6c267d27ffb9d90e2472bad4bcbf
dir: /libstd/resolve.myr/

View raw version
use "alloc.use"
use "die.use"
use "endian.use"
use "error.use"
use "fmt.use"
use "slcp.use"
use "sys.use"
use "types.use"

pkg std =
	/* Reasons a lookup through resolve() can fail. */
	type resolveerr = union
		`Badhost	/* host failed the local validity check; no query was sent */
		`Badsrv	/* the DNS socket could not be opened or connected */
		`Badquery	/* NOTE(review): not produced anywhere in this file */
		`Badresp	/* the reply was unusable (e.g. its id did not match the query) */
	;;

	/* A resolved network address; the parser in this file only builds `Ipv4. */
	type netaddr = union
		`Ipv4	byte[4]
		`Ipv6	byte[16]
	;;

	/* One resolved address for a host, as returned by resolve(). */
	type hostinfo = struct
		fam	: sockfam	/* address family; NOTE(review): confirm the parser fills this in */
		stype	: socktype	/* NOTE(review): not filled in by the DNS parser in this file */
		ttl	: uint32	/* TTL field copied from the DNS answer record */
		addr	: netaddr	/* the address carried in the answer's rdata */
	/*
		proto	: uint32
		flags	: uint32
		addr	: sockaddr[:]
		canon	: byte[:]
		next	: hostinfo#
	*/
	;;

	/* Looks up `host`, returning its addresses or the reason the lookup failed. */
	const resolve	: (host : byte[:]	-> error(hostinfo[:], resolveerr))
;;

/*
   Public entry point: resolves `host` to a list of addresses.
   For now every lookup is delegated straight to DNS.
*/
const resolve = {host : byte[:]
	var result

	/* FIXME: read /etc/hosts */
	result = dnsresolve(host)
	-> result
}

/*
   Resolves `host` over DNS.
   Returns `Badhost if the name fails valid(), and `Badsrv if no DNS
   socket could be set up. Otherwise returns whatever dnsquery() yields.
   The socket is opened per lookup and is always closed before returning.
*/
const dnsresolve = {host : byte[:]
	var nsrv
	var r

	if !valid(host)
		-> `Failure (`Badhost)
	;;
	if (nsrv = dnsconnect()) < 0
		-> `Failure (`Badsrv)
	;;
	r = dnsquery(nsrv, host)
	/* one socket per lookup: close it so repeated lookups don't leak fds */
	close(nsrv)
	-> r
}

/*
   Opens a UDP socket connected to port 53 of the (hardcoded) DNS server.
   Returns the socket, or -1 after printing a warning on failure.
   On failure no socket is left open.
*/
const dnsconnect = {
	var sa : sockaddr_in
	var s
	var status

	s = socket(Afinet, Sockdgram, 0)
	if s < 0
		put("Warning: Failed to open socket: %l\n", s)
		-> -1
	;;
	/* hardcode Google DNS for now.
	FIXME: parse /etc/resolv.conf */
	sa.fam = Afinet
	sa.port = hosttonet(53) /* port 53 */
	sa.addr = [8,8,8,8]	/* 8.8.8.8 */
	status = connect(s, (&sa) castto(sockaddr#), sizeof(sockaddr_in))
	if status < 0
		put("Warning: Failed to connect to server: %l\n", status)
		/* the socket was opened above; don't leak it on this error path */
		close(s)
		-> -1
	;;
	-> s
}

/*
   Sends one query for `host` over `srv`, then reads and decodes the reply.
   The query id from tquery() is handed to rquery() so the reply can be
   matched against it.
*/
const dnsquery = {srv, host
	var result

	result = rquery(srv, tquery(srv, host))
	put("Got hosts. Returning\n")
	-> result
}

/*
   DNS header flag bits, given as their values in the 16-bit big-endian
   flags word (RFC 1035 section 4.1.1). QR is the most significant bit.
*/
const Qr : uint16 = 1 << 15	/* message is a response */
const Aa : uint16 = 1 << 10	/* authoritative answer */
const Tc : uint16 = 1 << 9	/* message was truncated */
const Rd : uint16 = 1 << 8	/* recursion desired (set by the client) */
const Ra : uint16 = 1 << 7	/* recursion available (set by the server) */

/* id for the next outgoing query; bumped after each query is sent */
var nextid : uint16 = 42

/*
   Builds a single-question query for the A record of `host` and writes it
   to `srv`. Returns the id placed in the query header; the reply is
   expected to echo it.
*/
const tquery = {srv, host
	/* 12-byte header + encoded name (<= 257 bytes for names accepted
	   by valid()) + 4 bytes of qtype/qclass fits comfortably */
	var pkt : byte[512]
	var off : size

	put("Sending request for %s\n", host)
	/* header */
	off = 0
	off += pack16(pkt[:], off, nextid)	/* id */
	off += pack16(pkt[:], off, Rd)	/* flags: ask the server to recurse */
	off += pack16(pkt[:], off, 1)	/* qdcount */
	off += pack16(pkt[:], off, 0)	/* ancount */
	off += pack16(pkt[:], off, 0)	/* nscount */
	off += pack16(pkt[:], off, 0)	/* arcount */

	/* query */
	off += packname(pkt[:], off, host)	/* host */
	off += pack16(pkt[:], off, 0x1) /* qtype: a record */
	off += pack16(pkt[:], off, 0x1) /* qclass: inet4 */

	write(srv, pkt[:off])
	-> nextid++
}

/*
   Reads one reply datagram from `srv` and decodes it against query `id`.
   A failed read is reported as `Badresp instead of being sliced with a
   negative length.
*/
const rquery = {srv, id
	var pktbuf : byte[1024]
	var pkt
	var n

	put("Waiting for response...\n")
	n = read(srv, pktbuf[:])
	if n < 0
		put("Warning: Failed to read from %z: %i\n", srv, n)
		-> `Failure (`Badresp)
	;;
	pkt = pktbuf[:n]
	put("Got response:\n")
	dumpresponse(pkt)
	-> hosts(pkt, id)
}

/*
   Decodes the DNS reply in `pkt` and collects the IPv4 addresses from
   its answer records.
   - `id` is the id of the query we sent. A reply with a different id, a
     packet too short for the 12-byte header, or an answer whose rdata
     runs past the packet gives `Badresp.
   - Answers that are not A/IN records with 4 bytes of rdata (e.g. CNAMEs)
     are skipped using their own rdlength, so they no longer desync parsing.
   - The returned slice holds only the addresses found.
*/
const hosts = {pkt, id : uint16
	var off
	var v
	var q
	var a
	var i
	var n
	var rtype
	var rclass
	var ttl
	var rdlen
	var hinf : hostinfo[:]

	/* the fixed DNS header is 12 bytes */
	if pkt.len < 12
		-> `Failure (`Badresp)
	;;

	off = 0
	/* parse header */
	(v, off) = unpack16(pkt, off)	/* id */
	if v != id
		-> `Failure (`Badresp)
	;;
	put("Unpacking flags\n")
	(v, off) = unpack16(pkt, off)	/* flags */
	(q, off) = unpack16(pkt, off)	/* qdcount */
	(a, off) = unpack16(pkt, off)	/* ancount */
	(v, off) = unpack16(pkt, off)	/* nscount */
	(v, off) = unpack16(pkt, off)	/* arcount */

	/* skip past query records */
	for i = 0; i < q; i++
		put("Skipping query record\n")
		off = skipname(pkt, off)	/* name */
		(v, off) = unpack16(pkt, off)	/* type */
		(v, off) = unpack16(pkt, off)	/* class */
	;;

	/* parse answer records, keeping only A (type 1) records of class IN (1) */
	hinf = slalloc(a castto(size))
	n = 0
	for i = 0; i < a; i++
		off = skipname(pkt, off)	/* name */
		(rtype, off) = unpack16(pkt, off)	/* type */
		(rclass, off) = unpack16(pkt, off)	/* class */
		(ttl, off) = unpack32(pkt, off)	/* ttl */
		(rdlen, off) = unpack16(pkt, off)	/* rdatalen */
		if off + (rdlen castto(size)) > pkt.len
			slfree(hinf)
			-> `Failure (`Badresp)
		;;
		if rtype == 1 && rclass == 1 && rdlen == 4
			/* the thing we're interested in: our IP address */
			hinf[n].fam = Afinet
			hinf[n].ttl = ttl
			hinf[n].addr = `Ipv4 [pkt[off], pkt[off+1], pkt[off+2], pkt[off+3]]
			n++
		;;
		/* advance by the record's own rdata length, whatever its type */
		off += (rdlen castto(size))
	;;
	-> `Success (hinf[:n])
}


/*
   Debug helper: prints the header counts, then every question and answer
   record of the reply in `pkt`. A packet shorter than the 12-byte DNS
   header is reported rather than decoded, since decoding it would index
   past its end.
*/
const dumpresponse = {pkt
	var nquery
	var nans
	var off
	var v
	var i

	if pkt.len < 12
		put("Short response: %z bytes\n", pkt.len)
	else
		(v, off) = unpack16(pkt, 0)	/* id */
		(v, off) = unpack16(pkt, off)	/* flags */
		(nquery, off) = unpack16(pkt, off)
		put("hdr.qdcount = %w\n", nquery)
		(nans, off) = unpack16(pkt, off)
		put("hdr.ancount = %w\n", nans)
		(v, off) = unpack16(pkt, off)
		put("hdr.nscount = %w\n", v)
		(v, off) = unpack16(pkt, off)
		put("hdr.arcount = %w\n", v)
		put("Queries:\n")
		for i = 0; i < nquery; i++
			put("i: %w\n", i)
			off = dumpquery(pkt, off)
		;;

		put("Answers:\n")
		for i = 0; i < nans; i++
			put("i: %w\n", i)
			off = dumpans(pkt, off)
		;;
	;;
}

/*
   Debug helper: prints the question record that starts at `off`
   (name, type, class). Returns the offset just past that record.
*/
const dumpquery = {pkt, off
	var field
	var pos

	put("\tname = ")
	pos = printname(pkt, off)

	(field, pos) = unpack16(pkt, pos)
	put("\tbody.type = %w\n", field)

	(field, pos) = unpack16(pkt, pos)
	put("\tbody.class = %w\n", field)

	-> pos
}

/*
   Debug helper: prints the answer record that starts at `off`. Returns
   the offset just past the record, advancing by the record's rdlength so
   non-A answers (e.g. CNAMEs) don't desync the records that follow.
   Fields are big-endian, so the high 16 bits of the ttl and of a 4-byte
   rdata come first.
*/
const dumpans = {pkt, off
	var v
	var rdlen
	var rdstart

	put("\tname = ")
	off = printname(pkt, off)
	(v, off) = unpack16(pkt, off)
	put("\tbody.type = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.class = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.ttl_hi = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.ttl_lo = %w\n", v)
	(rdlen, off) = unpack16(pkt, off)
	put("\tbody.rdlength = %w\n", rdlen)
	rdstart = off
	if rdlen == 4
		(v, off) = unpack16(pkt, off)
		put("\tbody.rdata_hi = %w\n", v)
		(v, off) = unpack16(pkt, off)
		put("\tbody.rdata_lo = %w\n", v)
	;;
	-> rdstart + (rdlen castto(size))
}

/*
   Steps over an encoded name starting at `off` and returns the offset of
   the first byte after it.
   - A compression pointer (top two bits of the length byte set) is 2 bytes
     and ends the name.
   - Otherwise the name ends with a zero-length label.
*/
const skipname = {pkt, off
	var pos
	var lablen

	pos = off
	lablen = pkt[pos] castto(size)
	while lablen != 0
		if (lablen & 0xC0) == 0xC0
			-> pos + 2
		;;
		pos += lablen + 1
		lablen = pkt[pos] castto(size)
	;;
	/* step over the terminating zero-length label */
	-> pos + 1
}


/*
   Debug helper: prints the encoded name at `off` label by label.
   A compression pointer is printed as "PTR: " followed by the name it
   points to. Returns the offset just past the name as stored at `off`.
*/
const printname = {pkt, off
	var pos
	var lablen
	var target

	pos = off
	lablen = pkt[pos] castto(size)
	while lablen != 0
		if (lablen & 0xC0) == 0xC0
			put("PTR: ")
			/* low 6 bits of this byte plus the next byte form the target offset */
			target = ((lablen & ~0xC0) << 8) | (pkt[pos + 1] castto(size))
			printname(pkt, target)
			-> pos + 2
		;;
		put("%s.", pkt[pos+1:pos+lablen+1])
		pos += lablen + 1
		lablen = pkt[pos] castto(size)
	;;
	-> pos + 1
}

/*
   Stores the 16-bit value `v` at buf[off] in big-endian (network) order.
   Returns the number of bytes written, which is always 2.
*/
const pack16 = {buf, off, v
	buf[off + 1] = (v & 0xff) castto(byte)
	buf[off] = ((v >> 8) & 0xff) castto(byte)
	-> sizeof(uint16)
}

/*
   Reads a big-endian 16-bit value at buf[off].
   Returns (value, offset just past it).
*/
const unpack16 = {buf, off
	var hi
	var lo

	hi = buf[off] castto(uint16)
	lo = buf[off + 1] castto(uint16)
	-> ((hi << 8) | lo, off + sizeof(uint16))
}

/*
   Reads a big-endian 32-bit value at buf[off].
   Returns (value, offset just past it).
*/
const unpack32 = {buf, off
	var v

	v = (buf[off] castto(uint32)) << 24
	/* second byte belongs in bits 16..23 (was wrongly shifted by 32) */
	v |= (buf[off+1] castto(uint32)) << 16
	v |= (buf[off+2] castto(uint32)) << 8
	v |= (buf[off+3] castto(uint32))
	-> (v, off+sizeof(uint32))
}

/*
   Encodes `host` at buf[off] as length-prefixed labels followed by the
   zero-length root label. Returns the number of bytes written.
   - A trailing '.' is accepted and adds no extra label.
   - An empty host encodes as just the root label.
   - Callers are expected to have checked `host` with valid(), so labels
     are non-empty and at most 63 bytes.
*/
const packname = {buf, off : size, host
	var i
	var start
	var labelstart

	start = off
	labelstart = 0
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			off += addseg(buf, off, host[labelstart:i])
			/* the next label begins just after this dot */
			labelstart = i + 1
		;;
	;;
	/* emit the final label unless the name ended with '.' (or was empty) */
	if labelstart < host.len
		off += addseg(buf, off, host[labelstart:host.len])
	;;
	off += addseg(buf, off, "") /* null terminating segment */
	-> off - start
}

/*
   Writes one DNS label at buf[off]: a length byte followed by the bytes
   of `str`. Returns the number of bytes used (1 + str.len).
*/
const addseg = {buf, off, str
	var n

	n = str.len
	buf[off] = n castto(byte)
	slcp(buf[off + 1 : off + 1 + n], str)
	-> n + 1
}

/*
   Checks that `host` can be encoded as a DNS query name. Rejects a name
   that:
   - is empty or longer than 255 bytes,
   - has an empty label ("a..b", ".a"),
   - has a label longer than 63 bytes,
   - contains a non-ASCII byte.
   A single trailing '.' is allowed.
*/
const valid = {host : byte[:]
	var i
	var seglen

	/* maximum length: 255 chars */
	if host.len == 0 || host.len > 255
		-> false
	;;

	seglen = 0
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			/* an empty label would encode as the root terminator mid-name */
			if seglen == 0
				-> false
			;;
			seglen = 0
		else
			seglen++
			if seglen > 63
				-> false
			;;
		;;
		if (host[i] & 0x80) != 0
			-> false
		;;
	;;

	-> true
}