shithub: mc

Download patch

ref: 3b6a42aaa1dfb75db06e9bf8cb7953cecefc647f
parent: 41046c6382017e0297f8e07a7493c2cd779b8f0b
author: Frank Smit <[email protected]>
date: Fri Sep 18 13:21:29 EDT 2020

Query nameservers in parallel.

--- a/lib/std/resolve+posixy.myr
+++ b/lib/std/resolve+posixy.myr
@@ -22,8 +22,6 @@
 use "types"
 use "utf"
 
-use "fmt"
-
 pkg std =
 	type rectype = union
 		`DnsA		/* host address */
@@ -61,6 +59,8 @@
 	const resolve		: (host : byte[:]	-> result(hostinfo[:], resolveerr))
 	const resolvemx		: (host : byte[:]	-> result(hostinfo[:], resolveerr))
 	const resolverec	: (host : byte[:], t : rectype[:]	-> result(hostinfo[:], resolveerr))
+	/* max number of nameservers dnsresolve queries in parallel */
+	const maxns = 5
 ;;
 
 const Hostfile = "/etc/hosts"
@@ -152,7 +152,7 @@
 				invalid addresses are ignored: we don't want to break stuff
 				with invalid or unsupported addresses
 				*/
-				
+
 			;;
 		| `None:
 		;;
@@ -222,7 +222,7 @@
 	match word(rest)
 	| `Some (name, _):
 		match ipparse(name)
-		| `Some addr: 
+		| `Some addr:
 			slpush(&nameservers, addr)
 		| `None:
 			/* nothing */
@@ -247,23 +247,32 @@
 	;;
 }
 
-
 const dnsresolve = {host, rt
-	var nsrv, r
+	var srv : sys.pollfd[maxns]
+	var nsrv, sock, ret
 
 	if !valid(host)
 		-> `Err (`Badhost)
 	;;
+
 	/* FIXME: Assumption: nameservers is not modified by other threads */
-	for ns : nameservers
-		nsrv = dnsconnect(ns)
-		if nsrv >= 0
-			r = dnsquery(nsrv, host, rt)
-			sys.close(nsrv)
-			-> r
+	nsrv = 0
+	for ns : nameservers[:std.min(maxns, nameservers.len)]
+		/* servers we cannot connect to are skipped, never polled */
+		sock = dnsconnect(ns)
+		if sock >= 0
+			srv[nsrv] = [.fd=sock, .events=sys.Pollout, .revents=0]
+			nsrv++
 		;;
 	;;
-	-> `Err (`Badsrv)
+	if nsrv == 0
+		-> `Err (`Badsrv)
+	;;
+	ret = dnsquery(srv[:nsrv], host, rt)
+	for s : srv[:nsrv]
+		sys.close(s.fd)
+	;;
+	-> ret
 }
 
 const dnsconnect = {ns
@@ -295,10 +304,80 @@
 }
 
 const dnsquery = {srv, host, t
-	var id
+	/*
+	sends a query for host/t to every nameserver in srv at once and
+	returns the first reply that hosts() does not reject with `Badresp.
+	entries whose fd is < 0 are never polled. a server that errors out
+	is closed here and its fd set to -1; the others keep being polled.
+	srv.len must not exceed maxns.
+	*/
+	var qbuf : byte[512] /* big enough */
+	var rbuf : byte[1024]
+	var id : uint16[maxns]
+	var query, inf
+	var live, giveup, left
+	var r, n
 
-	id = tquery(srv, host, t)
-	-> rquery(srv, host, id)
+	live = 0
+	for s : srv
+		if s.fd >= 0
+			live++
+		;;
+	;;
+
+	giveup = std.now() + 1000*Timeout
+	while live > 0
+		/* a negative poll timeout would wait forever */
+		left = (giveup - std.now() : int)/1000
+		if left < 0
+			-> `Err `Timeout
+		;;
+		r = sys.poll(srv, left)
+		if r < 0
+			-> `Err `Badconn
+		elif r == 0
+			-> `Err `Timeout
+		;;
+
+		for var i = 0; i < srv.len; i++
+			var s = &srv[i]
+			var ok = false
+
+			if s.fd < 0 || s.revents == 0
+				continue
+			elif s.events == sys.Pollout
+				/* query not sent yet: send it once the socket is writable */
+				if (s.revents & sys.Pollout) != 0
+					(id[i], query) = mkquery(qbuf[:], host, t)
+					if sys.write(s.fd, query) >= 0
+						s.events = sys.Pollin
+						ok = true
+					;;
+				;;
+			elif (s.revents & sys.Pollin) != 0
+				n = sys.read(s.fd, rbuf[:])
+				if n >= 0
+					inf = hosts(rbuf[:n], host, id[i])
+					match inf
+					| `std.Err `Badresp:
+						/* not an answer to our query: keep waiting */
+						ok = true
+					| _:
+						-> inf
+					;;
+				;;
+			;;
+
+			/* error, hangup or failed i/o: drop only this server */
+			if !ok
+				sys.close(s.fd)
+				s.fd = -1
+				live--
+			;;
+		;;
+	;;
+
+	-> `Err (`Badsrv)
 }
 
 const Qr : uint16 = 1 << 0
@@ -309,6 +388,10 @@
 
 var nextid : uint16 = 42
-const tquery = {srv, host, t
-	var pkt : byte[512] /* big enough */
+/*
+encodes a query for host with record type t into the caller's buffer
+pkt and returns (nextid, encoded bytes). the bytes alias pkt, so they
+stay valid after this returns; dnsquery passes a 512-byte buffer.
+*/
+const mkquery = {pkt : byte[:], host, t
 	var off : size
 
@@ -327,39 +410,7 @@
 	off += pack16(pkt[:], off, (t : uint16)) /* qtype: a record */
 	off += pack16(pkt[:], off, 0x1) /* qclass: inet4 */
 
-	sys.write(srv, pkt[:off])
-	-> nextid
-}
-
-const rquery = {srv, host, id
-	var pktbuf : byte[1024]
-	var pkt, pfd, giveup
-	var r, n, inf
-
-	giveup = std.now() + 1000*Timeout
-:again
-	pfd = [
-		[.fd=srv, .events=sys.Pollin, .revents=0]
-	][:]
-	r = sys.poll(pfd[:], (giveup - std.now() : int)/1000)
-	if r < 0
-		-> `Err `Badconn
-	elif r == 0
-		-> `Err `Timeout
-	else
-		n = sys.read(srv, pktbuf[:])
-		if n < 0
-			-> `Err `Badconn
-		;;
-		pkt = pktbuf[:n]
-		inf = hosts(pkt, host, id)
-		match inf
-		| `std.Err `Badresp:
-			goto again
-		| _:	
-			-> inf
-		;;
-	;;
+	-> (nextid, pkt[:off])
 }
 
 const hosts = {pkt, host, id