shithub: mc

ref: 7fe8f299454bc2238a549e7dad31c0956aed26a9
dir: /libstd/resolve.myr/

View raw version
use "alloc.use"
use "die.use"
use "endian.use"
use "error.use"
use "fmt.use"
use "slcp.use"
use "sys.use"
use "types.use"

pkg std =
	/*
	 * Reasons dnsresolve can fail: the host name did not pass validation,
	 * no connection to the DNS server could be made, or the query failed.
	 */
	type resolveerr = union
		`Badhost
		`Badsrv
		`Badquery
	;;

	/*
	 * One result of resolving a host. The field set looks modelled on
	 * getaddrinfo's struct addrinfo, with `next` presumably chaining further
	 * results -- NOTE(review): confirm; nothing in this file fills these in yet.
	 */
	type hostinfo = struct
		flags	: uint32
		fam	: sockfam
		stype	: socktype
		proto	: uint32
		addr	: sockaddr[:]
		canon	: byte[:]
		next	: hostinfo#
	;;

	/* Looks up `host` via DNS and returns a heap-allocated hostinfo. */
	const resolve	: (host : byte[:]	-> hostinfo#)
;;

/*
 * The fixed 12-byte header (six uint16 fields) that begins every DNS
 * message. Not referenced by the code in this file yet: tquery packs
 * the same six fields one at a time with pack16.
 */
type dnshdr = struct
	id	: uint16
	/* {qr:1|op:4|aa:1|tc:1|rd:1|ra:1|z:3|rcode:4} */
	flags	: uint16  
	qdcnt	: uint16
	ancnt	: uint16
	nscnt	: uint16
	arcnt	: uint16
;;

/*
 * Resolves `host` and returns a freshly allocated hostinfo.
 *
 * NOTE(review): work in progress. The value returned by dnsresolve is
 * discarded and `hinf` is never filled in after zalloc, so callers get
 * the same (presumably zero-filled) record whether or not the lookup
 * succeeded.
 */
const resolve = {host : byte[:]
	var hinf

	hinf = zalloc()
	dnsresolve(host)
	-> hinf
}

/*
 * Validates `host`, connects to the DNS server and runs a query for it.
 *
 * Returns `Success true when the query reports success, otherwise
 * `Failure carrying the resolveerr for the step that failed:
 *   `Badhost  - `host` was rejected by valid()
 *   `Badsrv   - dnsconnect() could not set up the server socket
 *   `Badquery - dnsquery() reported failure
 * The server socket is closed before returning on every path.
 */
const dnsresolve = {host : byte[:]
	var nsrv
	var ok

	if !valid(host)
		-> `Failure (`Badhost)
	;;
	if (nsrv = dnsconnect()) < 0
		-> `Failure (`Badsrv)
	;;
	ok = dnsquery(nsrv, host)
	/* the socket is only needed for this one exchange; don't leak it */
	close(nsrv)
	if !ok
		-> `Failure (`Badquery)
	;;
	-> `Success true
}

/*
 * Opens a UDP socket connected to the DNS server on port 53.
 *
 * Returns the connected socket, or -1 if the socket could not be created
 * or connected (a warning is printed in either case). No socket is left
 * open on the failure paths.
 */
const dnsconnect = {
	var sa : sockaddr_in
	var s
	var status
	
	s = socket(Afinet, Sockdgram, 0)
	if s < 0
		put("Warning: Failed to open socket: %l\n", s)
		-> -1
	;;
	/* hardcode Google DNS for now */
	sa.fam = Afinet
	sa.port = hosttonet(53)
	sa.addr = 0x08080808	/* 8.8.8.8: every byte equal, so byte order is moot */
	status = connect(s, (&sa) castto(sockaddr#), sizeof(sockaddr_in))
	if status < 0
		put("Warning: Failed to connect to server: %l\n", status)
		/* the socket was opened above; release it before failing */
		close(s)
		-> -1
	;;
	-> s
}

/*
 * Sends a query for `host` on the connected socket `srv`, then reads one
 * reply with rquery.
 *
 * NOTE(review): unfinished. The reply is only printed by rquery and never
 * parsed, so this always prints "Unimplemented query" and returns false.
 */
const dnsquery = {srv, host
	tquery(srv, host)
	rquery(srv)
	put("Unimplemented query: srv=%z, host=%s\n", srv, host)
	-> false
}

/*
 * Flag bits of the DNS header's 16-bit flags word. The word is laid out
 * most significant bit first as {qr:1|op:4|aa:1|tc:1|rd:1|ra:1|z:3|rcode:4},
 * and pack16 emits the high byte first, so QR is bit 15 and RA is bit 7.
 */
const Qr : uint16 = 1 << 15	/* message is a response */
const Aa : uint16 = 1 << 10	/* authoritative answer */
const Tc : uint16 = 1 << 9	/* message was truncated */
const Rd : uint16 = 1 << 8	/* recursion desired */
const Ra : uint16 = 1 << 7	/* recursion available (set by servers) */

/* id placed in the next outgoing query; 16 bits like the header field */
var nextid : uint16 = 0

/*
 * Builds a standard recursive query (one question: A record, IN class)
 * for `host` and writes it to the connected socket `srv`.
 */
const tquery = {srv, host
	var pkt : byte[512] /* big enough */
	var off

	/* header */
	off = 0
	off += pack16(pkt[:], off, nextid++)	/* id */
	/* a query: QR and RA stay clear, opcode 0; ask for recursion */
	off += pack16(pkt[:], off, Rd)	/* flags */
	off += pack16(pkt[:], off, 1)	/* qdcount */
	off += pack16(pkt[:], off, 0)	/* ancount */
	off += pack16(pkt[:], off, 0)	/* nscount */
	off += pack16(pkt[:], off, 0)	/* arcount */

	/* query */
	off += packname(pkt[:], off, host)	/* host */
	off += pack16(pkt[:], off, 0x1) /* qtype: a record */
	off += pack16(pkt[:], off, 0x1) /* qclass: inet4 */

	write(srv, pkt[:off])
}

/*
 * Reads one reply datagram (up to 1024 bytes) from `srv` and prints it
 * for debugging. If the read fails, only a warning is printed.
 */
const rquery = {srv
	var pktbuf : byte[1024]
	var n

	n = read(srv, pktbuf[:])
	if n < 0
		put("Warning: Failed to read from %z: %i\n", srv, n)
	else
		/* only slice on success: pktbuf[:n] with n < 0 is out of range */
		put("pkt: [len = %z]: %s\n", n, pktbuf[:n])
	;;
}

/*
 * Stores the low 16 bits of `v` at buf[off] and buf[off+1], most
 * significant byte first (network order).
 * Returns the number of bytes written (always sizeof(uint16)) so callers
 * can advance their offset.
 */
const pack16 = {buf, off, v
	/* parenthesize the shift: postfix castto binds tighter than >> */
	buf[off]	= ((v & 0xff00) >> 8) castto(byte)
	buf[off+1]	= (v & 0x00ff) castto(byte)
	-> sizeof(uint16) /* we always write one uint16 */
}

/*
const unpack16 = {buf, off
	var v

	v = (buf[off] castto(uint16)) << 8
	v |= (buf[off + 1] castto(uint16))
	-> (v, sizeof(uint16))
}
*/

/*
 * Encodes `host` as a DNS wire-format name into buf, starting at `off`.
 * Each '.'-separated label is written as a length byte followed by the
 * label's bytes, and the name ends with a zero-length label. A single
 * trailing '.' is accepted; an empty host encodes as just the terminator.
 * Returns the number of bytes written.
 *
 * No length checks are done here: labels over 63 bytes cannot be
 * represented. dnsresolve runs valid() on the host before querying.
 */
const packname = {buf, off, host
	var i
	var start
	var lastseg

	start = off
	lastseg = 0
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			off += addseg(buf, off, host[lastseg:i])
			lastseg = i + 1
		;;
	;;
	/* emit the final label unless a trailing '.' already closed it */
	if lastseg < host.len
		off += addseg(buf, off, host[lastseg:host.len])
	;;
	off += addseg(buf, off, "") /* null terminating segment */
	-> off - start
}

/*
 * Writes one label at buf[off]: a length byte followed by the bytes of
 * `str`. Returns the number of bytes written, str.len + 1.
 */
const addseg = {buf, off, str
	buf[off] = str.len castto(byte)
	slcp(buf[off + 1 : off + str.len + 1], str)
	-> str.len + 1
}

/*
 * Checks that `host` is a name packname can encode:
 *   - 1 to 255 characters long,
 *   - no label longer than 63 bytes,
 *   - no empty labels (leading '.' or "a..b"); one trailing '.' is allowed,
 *   - 7-bit ASCII only.
 * Returns true if the name is acceptable.
 */
const valid = {host : byte[:]
	var i
	var seglen

	/* maximum length: 255 chars; the empty string is not a host name */
	if host.len == 0 || host.len > 255
		-> false
	;;

	seglen = 0
	for i = 0; i < host.len; i++
		if host[i] == ('.' castto(byte))
			/* an empty label would be encoded as the name terminator */
			if seglen == 0
				-> false
			;;
			seglen = 0
		else
			seglen++
			if seglen > 63
				-> false
			;;
		;;
		if (host[i] & 0x80) != 0
			-> false
		;;
	;;
	-> true
}