shithub: mc

Download patch

ref: 3d078d5439e93a3dfc4b808ab6bf02805d455ff8
parent: 70f97fe9898b4852257a9268e6ea0592ee7e3a88
author: Ori Bernstein <[email protected]>
date: Sat Jan 13 18:39:48 EST 2018

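Add code to fix up iterators: iterator statements are now queued with delayedcheck(), and in the post-check pass fixiter() unifies each loop's element type with the first auxiliary type of the best-matching Tciter implementation.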

--- a/parse/infer.c
+++ b/parse/infer.c
@@ -2135,6 +2135,7 @@
 			unify(n, e, b);
 		else
 			htput(seqbase, t, e);
+		delayedcheck(n, curstab());
 		break;
 	case Nmatchstmt:
 		infernode(&n->matchstmt.val, NULL, sawret);
@@ -2379,6 +2380,35 @@
 }
 
 static void
+fixiter(Node *n, Type *ty, Type *base)	/* n: the Niterstmt; ty: type of its seq; base: type of its elt */
+{
+	size_t i, bestidx;
+	int r, bestrank;
+	Type *b, *t;
+	/* Find the Tciter impl whose type ranks best against ty, then unify its aux[0] with base. */
+	ty = tysearch(ty);
+	b = htget(seqbase, ty);	/* only sequence types recorded in seqbase are fixed up here */
+	if (!b)
+		return;
+	bestrank = -1;	/* NOTE(review): presumably tymatchrank() < 0 means "no match" -- confirm */
+	bestidx = 0;
+	for (i = 0; i < nimpltab; i++) {
+		if (impltab[i]->impl.trait != traittab[Tciter])	/* skip impls of any other trait */
+			continue;
+		r = tymatchrank(impltab[i]->impl.type, ty);
+		if (r > bestrank) {	/* strict '>': on a tie the earliest impl in impltab wins */
+			bestrank = r;
+			bestidx = i;
+		}
+	}
+	if (bestrank >= 0) {	/* at least one Tciter impl ranked >= 0 */
+		t = tf(impltab[bestidx]->impl.aux[0]);	/* aux[0] is the type seqbase records for Tciter impls */
+		t = tyfreshen(NULL, t);	/* NOTE(review): presumably fresh type vars for a generic impl -- confirm */
+		unify(n, t, base);
+	}
+}
+
+static void
 postcheckpass(Node ***rem, size_t *nrem, Stab ***remscope, size_t *nremscope)
 {
 	size_t i;
@@ -2387,12 +2417,16 @@
 	for (i = 0; i < npostcheck; i++) {
 		n = postcheck[i];
 		pushstab(postcheckscope[i]);
-		switch (exprop(n)) {
-		case Omemb:	infercompn(n, rem, nrem, remscope, nremscope);	break;
-		case Ocast:	checkcast(n, rem, nrem, remscope, nremscope);	break;
-		case Ostruct:	checkstruct(n, rem, nrem, remscope, nremscope);	break;
-		case Ovar:	checkvar(n, rem, nrem, remscope, nremscope);	break;
-		default:	die("should not see %s in postcheck\n", opstr[exprop(n)]);
+		if (n->type == Nexpr) {
+			switch (exprop(n)) {
+			case Omemb:	infercompn(n, rem, nrem, remscope, nremscope);	break;
+			case Ocast:	checkcast(n, rem, nrem, remscope, nremscope);	break;
+			case Ostruct:	checkstruct(n, rem, nrem, remscope, nremscope);	break;
+			case Ovar:	checkvar(n, rem, nrem, remscope, nremscope);	break;
+			default:	die("should not see %s in postcheck\n", opstr[exprop(n)]);
+			}
+		} else if (n->type == Niterstmt) {
+			fixiter(n, type(n->iterstmt.seq), type(n->iterstmt.elt));
 		}
 		popstab();
 	}
@@ -2419,7 +2453,6 @@
 		postcheckscope = remscope;
 		npostcheckscope = nremscope;
 	}
-	postcheckpass(NULL, NULL, NULL, NULL);
 }
 
 /* After inference, replace all
@@ -2799,10 +2832,10 @@
 			tr = gettrait(ns, n);
 		if (!tr)
 			fatal(impl, "trait %s does not exist near %s",
-					namestr(impl->impl.traitname), ctxstr(impl));
+			    namestr(impl->impl.traitname), ctxstr(impl));
 		if (tr->naux != impl->impl.naux)
 			fatal(impl, "incompatible implementation of %s: mismatched aux types",
-					namestr(impl->impl.traitname), ctxstr(impl));
+			    namestr(impl->impl.traitname), ctxstr(impl));
 	}
 	return tr;
 }
@@ -2876,6 +2909,9 @@
 		pushenv(impl->impl.env);
 		ty = tf(impl->impl.type);
 		addtraittab(traitmap, tr, ty);
+		if (tr->uid == Tciter) {
+			htput(seqbase, tf(impl->impl.type), tf(impl->impl.aux[0]));
+		}
 		popenv(impl->impl.env);
 	}
 	popstab();