shithub: mc

ref: 8bd18e2d1caef67781ae69d761bdfca8d2bcf41d
dir: /libstd/resolve.myr/

View raw version
use "alloc.use"
use "die.use"
use "endian.use"
use "error.use"
use "fmt.use"
use "slcp.use"
use "sys.use"
use "types.use"

pkg std =
	/* reasons a DNS lookup can fail (see dnsresolve) */
	type resolveerr = union
		`Badhost	/* host name rejected by valid() */
		`Badsrv	/* could not open or connect the name-server socket */
		`Badquery	/* the query/response exchange did not succeed */
	;;

	/*
	One resolved host entry. Nothing in this file fills these fields
	yet: resolve() hands back a zeroed allocation. `next` presumably
	chains further results -- TODO confirm once resolution is written.
	*/
	type hostinfo = struct
		flags	: uint32
		fam	: sockfam
		stype	: socktype
		proto	: uint32
		addr	: sockaddr[:]
		canon	: byte[:]
		next	: hostinfo#
	;;

	/* resolve: look up `host`; currently returns an unpopulated hostinfo */
	const resolve	: (host : byte[:]	-> hostinfo#)
;;

/*
The fixed DNS message header: six 16-bit fields (12 bytes), sent
big-endian on the wire. Not referenced by the code in this file, which
reads and writes the header field by field with pack16/unpack16.
*/
type dnshdr = struct
	id	: uint16
	/* {qr:1|op:4|aa:1|tc:1|rd:1|ra:1|z:3|rcode:4}, listed from the most significant bit down */
	flags	: uint16  
	qdcnt	: uint16	/* number of questions */
	ancnt	: uint16	/* number of answer records */
	nscnt	: uint16	/* number of authority records */
	arcnt	: uint16	/* number of additional records */
;;

/*
resolve: runs a DNS lookup for `host`. The outcome of the lookup is
not yet copied into the result, so the caller always receives a
freshly zeroed hostinfo, whether the lookup succeeded or not.
*/
const resolve = {host : byte[:]
	var info : hostinfo#

	info = zalloc()
	/* result deliberately ignored until hostinfo can be populated */
	dnsresolve(host)
	-> info
}

/*
dnsresolve: validates `host`, opens a socket to the name server and
performs one query over it.

Returns `Success true when the exchange succeeded, otherwise
`Failure with:
	`Badhost	if the name fails valid()
	`Badsrv	if no socket to the server could be opened/connected
	`Badquery	if the query exchange failed
The socket is closed before returning once it has been opened.
*/
const dnsresolve = {host : byte[:]
	var nsrv
	var ok

	if !valid(host)
		-> `Failure (`Badhost)
	;;
	nsrv = dnsconnect()
	if nsrv < 0
		-> `Failure (`Badsrv)
	;;
	ok = dnsquery(nsrv, host)
	/* the descriptor was previously leaked on every call */
	close(nsrv)
	if !ok
		-> `Failure (`Badquery)
	;;
	-> `Success true
}

/*
dnsconnect: opens a UDP socket and connects it to the DNS server
(hard-coded to 8.8.8.8, port 53 for now).

Returns the connected socket, or -1 after printing a warning if the
socket could not be created or connected. No descriptor is left open
on the failure paths.
*/
const dnsconnect = {
	var sa : sockaddr_in
	var s
	var status
	
	s = socket(Afinet, Sockdgram, 0)
	if s < 0
		put("Warning: Failed to open socket: %l\n", s)
		-> -1
	;;
	/* hardcode Google DNS for now */
	sa.fam = Afinet
	sa.port = hosttonet(53) /* port 53, network byte order */
	sa.addr = [8,8,8,8]	/* 8.8.8.8 */
	status = connect(s, (&sa) castto(sockaddr#), sizeof(sockaddr_in))
	if status < 0
		put("Warning: Failed to connect to server: %l\n", status)
		/* release the socket; it was previously leaked here */
		close(s)
		-> -1
	;;
	-> s
}

/*
dnsquery: sends a query for `host` on the connected socket `srv`,
then reads one reply and prints it. Turning the reply into a usable
answer is not implemented yet, so this always reports the query as
unimplemented and returns false.
*/
const dnsquery = {srv, host
	tquery(srv, host)	/* build and send the question */
	rquery(srv)	/* read one response and dump it to stdout */
	put("Unimplemented query: srv=%z, host=%s\n", srv, host)
	-> false
}

/*
Masks for the single-bit fields of the 16-bit header flags word:
	{qr:1|opcode:4|aa:1|tc:1|rd:1|ra:1|z:3|rcode:4}
The RFC numbers these bits from the most significant end (QR is its
"bit 0"), so as shift amounts QR is 15, AA 10, TC 9, RD 8 and RA 7.
*/
const Qr : uint16 = 1 << 15	/* message is a response */
const Aa : uint16 = 1 << 10	/* authoritative answer */
const Tc : uint16 = 1 << 9	/* message was truncated */
const Rd : uint16 = 1 << 8	/* recursion desired */
const Ra : uint16 = 1 << 7	/* recursion available */

/* id placed in the header of the next outgoing query */
var nextid = 42

/*
tquery: encodes one standard query asking for the A record (class IN)
of `host` and writes it to the connected socket `srv`. The encoded
packet is also written raw to fd 1 as a debugging aid. Write errors
are not checked.
*/
const tquery = {srv, host
	var pkt : byte[512] /* big enough */
	var off : size

	/* header */
	off = 0
	off += pack16(pkt[:], off, nextid++)	/* id */
	off += pack16(pkt[:], off, Rd)	/* flags: opcode 0 (query), recursion desired */
	off += pack16(pkt[:], off, 1)	/* qdcount */
	off += pack16(pkt[:], off, 0)	/* ancount */
	off += pack16(pkt[:], off, 0)	/* nscount */
	off += pack16(pkt[:], off, 0)	/* arcount */

	/* query */
	off += packname(pkt[:], off, host)	/* host */
	off += pack16(pkt[:], off, 0x1) /* qtype: a record */
	off += pack16(pkt[:], off, 0x1) /* qclass: inet4 */

	write(1, pkt[:off])	/* debug: raw packet to stdout */
	write(srv, pkt[:off])
}

/*
rquery: blocks for a single read of up to 1024 bytes from `srv` and
dumps the received packet. If the read fails, a warning is printed and
nothing is parsed; previously execution fell through and sliced the
buffer with the error value as its length.
*/
const rquery = {srv
	var pktbuf : byte[1024]
	var n

	put("Waiting for response...\n")
	n = read(srv, pktbuf[:])
	if n < 0
		put("Warning: Failed to read from %z: %i\n", srv, n)
	else
		dumpresponse(pktbuf[:n])
	;;
}


/*
dumpresponse: prints a human-readable dump of a DNS message: the
header fields, then each question and each answer record. No bounds
checking beyond the slice indexing itself is done.
*/
const dumpresponse = {pkt
	var nquery
	var nans
	var off
	var v
	var i

	put("packet size = %z\n", pkt.len)
	(v, off) = unpack16(pkt, 0)
	put("hdr.id = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("hdr.rawflag = %i\n", (((v castto(uint32)) & 0xf000) >> 12))
	/* each entry prints true when its flag bit is set (was inverted) */
	put("hdr.flag = [Qr = %t, Aa = %t, Tc = %t, Rd = %t, Ra = %t]\n", (v & Qr) != 0, (v & Aa) != 0, (v & Tc) != 0, (v & Rd) != 0, (v & Ra) != 0)
	/* rcode occupies the low 4 bits of the flags word */
	put("hdr.rcode = %w\n", v & 0xf)
	(nquery, off) = unpack16(pkt, off)
	put("hdr.qdcount = %w\n", nquery)
	(nans, off) = unpack16(pkt, off)
	put("hdr.ancount = %w\n", nans)
	(v, off) = unpack16(pkt, off)
	put("hdr.nscount = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("hdr.arcount = %w\n", v)
	put("Queries:\n")
	for i = 0; i < nquery; i++
		put("i: %w\n", i)
		off = dumpquery(pkt, off)
	;;

	put("Answers:\n")
	for i = 0; i < nans; i++
		put("i: %w\n", i)
		off = dumpans(pkt, off)
	;;
}

/*
dumpquery: prints the question entry at `off` (name, type, class)
and returns the offset of the first byte after it.
*/
const dumpquery = {pkt, off
	var qtype
	var qclass

	put("\tname = ")
	off = printname(pkt, off)

	(qtype, off) = unpack16(pkt, off)
	put("\tbody.type = %w\n", qtype)

	(qclass, off) = unpack16(pkt, off)
	put("\tbody.class = %w\n", qclass)

	-> off
}

/*
dumpans: prints the resource record at `off` (name, type, class, TTL,
rdata) and returns the offset of the first byte after it.

The rdata is printed as raw bytes and skipped using its rdlength, so
records whose data is not 4 bytes long (e.g. CNAME) no longer throw
the parse of the following records off. The TTL is a big-endian
32-bit value, so the first 16-bit word read is its high half.
*/
const dumpans = {pkt, off
	var v
	var rdlen
	var i

	put("\tname = ")
	off = printname(pkt, off)
	(v, off) = unpack16(pkt, off)
	put("\tbody.type = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.class = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.ttl_hi = %w\n", v)
	(v, off) = unpack16(pkt, off)
	put("\tbody.ttl_lo = %w\n", v)
	(rdlen, off) = unpack16(pkt, off)
	put("\tbody.rdlength = %w\n", rdlen)
	put("\tbody.rdata =")
	for i = 0; i < (rdlen castto(size)); i++
		put(" %w", pkt[off + i] castto(uint16))
	;;
	put("\n")
	-> off + (rdlen castto(size))
}

/*
printname: prints the (possibly compressed) domain name stored at `off`
in `pkt` as dot-terminated labels, and returns the offset just past
the name as laid out at `off`. A compression pointer ends the name
there and occupies 2 bytes.

A compression pointer is followed only when its target lies strictly
before the offset this call started at. Each followed pointer
therefore moves to a smaller offset, so a corrupt or hostile packet
(for example a pointer to itself) cannot make the recursion run
forever. Such a pointer is reported and not followed.
*/
const printname = {pkt, off
	var start
	var sz
	var ptr

	start = off
	for sz = pkt[off] castto(size); sz != 0; sz = pkt[off] castto(size)
		if (sz & 0xC0) == 0xC0
			/* 14-bit target: low 6 bits of this byte, then the next byte */
			ptr = ((sz & 0x3F) << 8) | (pkt[off + 1] castto(size))
			put("PTR: ")
			if ptr < start
				printname(pkt, ptr)
			else
				put("<bad pointer %z>", ptr)
			;;
			-> off + 2
		else
			put("%s.", pkt[off+1:off+sz+1])
			off += sz + 1
		;;
	;;
	-> off + 1
}

/*
pack16: stores `v` big-endian (network byte order) at buf[off] and
buf[off+1], and returns the number of bytes written (always 2).
*/
const pack16 = {buf, off, v
	/* parenthesized: castto binds tighter than >>, so the shifted value, not the count, must be cast */
	buf[off]	= ((v & 0xff00) >> 8) castto(byte)
	buf[off+1]	= (v & 0x00ff) castto(byte)
	-> sizeof(uint16) /* we always write one uint16 */
}

/*
unpack16: decodes the big-endian 16-bit value at buf[off] and
buf[off+1]. Returns the value together with the offset just past it.
*/
const unpack16 = {buf, off
	var hi
	var lo

	hi = buf[off] castto(uint16)
	lo = buf[off + 1] castto(uint16)
	-> ((hi << 8) | lo, off + sizeof(uint16))
}

/*
packname: encodes the dotted name `host` into buf at `off` in DNS wire
form: each label prefixed by its length, followed by a zero-length
terminating label. A single trailing dot is allowed and adds nothing.
Returns the number of bytes written.

Fixes: the start of each label is now tracked as the index after the
dot. The old code reused the label length as the start index, which
mis-encoded names with three or more labels, e.g. "www.google.com"
came out as www/google/gle. An empty host no longer indexes
host[-1].
*/
const packname = {buf, off : size, host
	var i
	var start
	var seg

	start = off
	seg = 0	/* index in host where the current label begins */
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			off += addseg(buf, off, host[seg:i])
			seg = i + 1
		;;
	;;
	/* last label, unless the name ended with '.' (or was empty) */
	if seg < host.len
		off += addseg(buf, off, host[seg:host.len])
	;;
	off += addseg(buf, off, "") /* null terminating segment */
	-> off - start
}

/*
addseg: writes one length-prefixed label at buf[off]: a length byte
followed by the bytes of `str`. Returns the bytes consumed
(str.len + 1).
*/
const addseg = {buf, off, str
	var n

	n = str.len
	buf[off] = n castto(byte)
	slcp(buf[off + 1 : off + 1 + n], str)
	-> n + 1
}

/*
valid: reports whether `host` can be encoded as a DNS query name.
Rejects:
	- empty names and names longer than 255 characters
	- empty labels (leading dot, or two dots in a row); a single
	  trailing dot is still accepted
	- labels longer than 63 characters
	- any byte with the high bit set (non-ASCII)

Fixes: the label length was never incremented, so the 63-character
limit could not trigger. The high-bit test now yields a bool.
*/
const valid = {host : byte[:]
	var i
	var seglen

	/* maximum length: 255 chars; an empty name encodes to nothing useful */
	if host.len == 0 || host.len > 255
		-> false
	;;

	seglen = 0
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			/* an empty label would be encoded as the name terminator */
			if seglen == 0
				-> false
			;;
			seglen = 0
		else
			seglen++
			if seglen > 63
				-> false
			;;
		;;
		if (host[i] & 0x80) != 0
			-> false
		;;
	;;

	-> true
}