shithub: mc

Download patch

ref: df0b6010f2506030d6674dedede207d19434f662
parent: d97dcac9b2b1e2526d897c508bb37c82e32517cd
author: Ori Bernstein <[email protected]>
date: Mon Jul 16 17:16:19 EDT 2012

Move inference state into a state struct.

    This means we don't have to make lots of state global, or thread each
    piece of it through the code as a separate argument.

--- a/8/isel.c
+++ b/8/isel.c
@@ -267,7 +267,7 @@
  * or:
  *    Oadd(
  *        reg,
- *        reg||const))
+ *        reg||const)
  *
  * or:
  *    Oadd(
--- a/parse/infer.c
+++ b/parse/infer.c
@@ -12,23 +12,36 @@
 
 #include "parse.h"
 
-static Node **postcheck;
-static size_t npostcheck;
-static Htab **tybindings;
-static size_t ntybindings;
-static Node **genericdecls;
-static size_t ngenericdecls;
-static Node **specializations;
-static Node **specializations;
-static size_t nspecializations;
-static Stab **specializationscope;
-static size_t nspecializationscope;
+typedef struct Inferstate Inferstate;
+struct Inferstate {
+	int inpattern;	/* NOTE(review): not referenced anywhere in this patch */
+	int ingeneric;	/* NOTE(review): not referenced anywhere in this patch */
 
-static void infernode(Node *n, Type *ret, int *sawret);
-static void inferexpr(Node *n, Type *ret, int *sawret);
-static void typesub(Node *n);
-static Type *tf(Type *t);
+	/* nodes bound by patterns; they turn into decls in the action block */
+	Node **binds;
+	size_t nbinds;
+	/* nodes that need post-inference checking/unification */
+	Node **postcheck;
+	size_t npostcheck;
+	/* the type params bound at the current point, innermost scope last */
+	Htab **tybindings;
+	size_t ntybindings;
+	/* generic declarations to be specialized */
+	Node **genericdecls;
+	size_t ngenericdecls;
+	/* the nodes that we've specialized them to, and the scopes they
+	 * appear in; all three lists are indexed in parallel */
+	Node **specializations;
+	size_t nspecializations;
+	Stab **specializationscope;
+	size_t nspecializationscope;
+};
 
+static void infernode(Inferstate *st, Node *n, Type *ret, int *sawret);
+static void inferexpr(Inferstate *st, Node *n, Type *ret, int *sawret);
+static void typesub(Inferstate *st, Node *n);
+static Type *tf(Inferstate *st, Type *t);
+
 static void setsuper(Stab *st, Stab *super)
 {
     Stab *s;
@@ -39,13 +52,13 @@
     st->super = super;
 }
 
-static int isbound(Type *t)
+static int isbound(Inferstate *st, Type *t)
 {
     ssize_t i;
     Type *p;
 
-    for (i = ntybindings - 1; i >= 0; i--) {
-        p = htget(tybindings[i], t->pname);
+    for (i = st->ntybindings - 1; i >= 0; i--) {
+        p = htget(st->tybindings[i], t->pname);
         if (p == t)
             return 1;
     }
@@ -52,12 +65,12 @@
     return 0;
 }
 
-static Type *tyfreshen(Htab *ht, Type *t)
+static Type *tyfreshen(Inferstate *st, Htab *ht, Type *t)
 {
     Type *ret;
     size_t i;
 
-    t = tf(t);
+    t = tf(st, t);
     if (t->type != Typaram && t->nsub == 0)
         return t;
 
@@ -71,21 +84,21 @@
 
     ret = tydup(t);
     for (i = 0; i < t->nsub; i++)
-        ret->sub[i] = tyfreshen(ht, t->sub[i]);
+        ret->sub[i] = tyfreshen(st, ht, t->sub[i]);
     return ret;
 }
 
-static Type *freshen(Type *t)
+static Type *freshen(Inferstate *st, Type *t)
 {
     Htab *ht;
 
     ht = mkht(strhash, streq);
-    t = tyfreshen(ht, t);
+    t = tyfreshen(st, ht, t);
     htfree(ht);
     return t;
 }
 
-static void tyresolve(Type *t)
+static void tyresolve(Inferstate *st, Type *t)
 {
     size_t i;
     Type *base;
@@ -95,22 +108,22 @@
     t->resolved = 1;
     if (t->type == Tystruct) {
         for (i = 0; i < t->nmemb; i++)
-            infernode(t->sdecls[i], NULL, NULL);
+            infernode(st, t->sdecls[i], NULL, NULL);
     } else if (t->type == Tyunion) {
         for (i = 0; i < t->nmemb; i++) {
-            tyresolve(t->udecls[i]->utype);
-            t->udecls[i]->utype = tf(t->udecls[i]->utype);
+            tyresolve(st, t->udecls[i]->utype);
+            t->udecls[i]->utype = tf(st, t->udecls[i]->utype);
             if (t->udecls[i]->etype) {
-                tyresolve(t->udecls[i]->etype);
-                t->udecls[i]->etype = tf(t->udecls[i]->etype);
+                tyresolve(st, t->udecls[i]->etype);
+                t->udecls[i]->etype = tf(st, t->udecls[i]->etype);
             }
         }
     } else if (t->type == Tyarray) {
-        infernode(t->asize, NULL, NULL);
+        infernode(st, t->asize, NULL, NULL);
     }
 
     for (i = 0; i < t->nsub; i++)
-        t->sub[i] = tf(t->sub[i]);
+        t->sub[i] = tf(st, t->sub[i]);
     base = tybase(t);
     /* no-ops if base == t */
     if (t->cstrs)
@@ -120,7 +133,7 @@
 }
 
 /* fixd the most accurate type mapping we have */
-static Type *tf(Type *t)
+static Type *tf(Inferstate *st, Type *t)
 {
     Type *lu;
 
@@ -137,7 +150,7 @@
             break;
         t = tytab[t->tid];
     }
-    tyresolve(t);
+    tyresolve(st, t);
     return t;
 }
 
@@ -150,9 +163,9 @@
         readuse(n->file.uses[i], n->file.globls);
 }
 
-static void settype(Node *n, Type *t)
+static void settype(Inferstate *st, Node *n, Type *t)
 {
-    t = tf(t);
+    t = tf(st, t);
     switch (n->type) {
         case Nexpr:     n->expr.type = t;       break;
         case Ndecl:     n->decl.type = t;       break;
@@ -181,7 +194,7 @@
     return NULL;
 }
 
-static Type *type(Node *n)
+static Type *type(Inferstate *st, Node *n)
 {
     Type *t;
 
@@ -195,7 +208,7 @@
         die("untypeable %s", nodestr(n->type));
         break;
     };
-    return tf(t);
+    return tf(st, t);
 }
 
 static char *ctxstr(Node *n)
@@ -285,7 +298,7 @@
     return 0;
 }
 
-static Type *unify(Node *ctx, Type *a, Type *b)
+static Type *unify(Inferstate *st, Node *ctx, Type *a, Type *b)
 {
     Type *t;
     Type *r;
@@ -292,8 +305,8 @@
     size_t i;
 
     /* a ==> b */
-    a = tf(a);
-    b = tf(b);
+    a = tf(st, a);
+    b = tf(st, b);
     if (a == b)
         return a;
     if (b->type == Tyvar) {
@@ -318,7 +331,7 @@
             if (i >= a->nsub)
                 fatal(ctx->line, "%s has wrong subtypes for %s near %s", tystr(a), tystr(b), ctxstr(ctx));
 
-            unify(ctx, a->sub[i], b->sub[i]);
+            unify(st, ctx, a->sub[i], b->sub[i]);
         }
         r = b;
     } else if (a->type != Tyvar) {
@@ -327,12 +340,12 @@
     return r;
 }
 
-static void unifycall(Node *n)
+static void unifycall(Inferstate *st, Node *n)
 {
     size_t i;
     Type *ft;
 
-    ft = type(n->expr.args[0]);
+    ft = type(st, n->expr.args[0]);
     if (ft->type == Tyvar) {
         /* the first arg is the function itself, so it shouldn't be counted */
         ft = mktyfunc(n->line, &n->expr.args[1], n->expr.nargs - 1, mktyvar(n->line));
@@ -340,17 +353,17 @@
     for (i = 1; i < n->expr.nargs; i++) {
         if (ft->sub[i]->type == Tyvalist)
             break;
-        inferexpr(n->expr.args[i], NULL, NULL);
-        unify(n, ft->sub[i], type(n->expr.args[i]));
+        inferexpr(st, n->expr.args[i], NULL, NULL);
+        unify(st, n, ft->sub[i], type(st, n->expr.args[i]));
     }
-    settype(n, ft->sub[0]);
+    settype(st, n, ft->sub[0]);
 }
 
-static void checkns(Node *n, Node **ret)
+static void checkns(Inferstate *st, Node *n, Node **ret)
 {
     Node *var, *name, *nsname;
     Node **args;
-    Stab *st;
+    Stab *stab;
     Node *s;
 
     if (n->type != Nexpr)
@@ -361,31 +374,31 @@
     if (args[0]->type != Nexpr || exprop(args[0]) != Ovar)
         return;
     name = args[0]->expr.args[0];
-    st = getns(curstab(), name);
-    if (!st)
+    stab = getns(curstab(), name);
+    if (!stab)
         return;
     nsname = mknsname(n->line, namestr(name), namestr(args[1]));
-    s = getdcl(st, args[1]);
+    s = getdcl(stab, args[1]);
     if (!s)
         fatal(n->line, "Undeclared var %s.%s", nsname->name.ns, nsname->name.name);
     var = mkexpr(n->line, Ovar, nsname, NULL);
     var->expr.did = s->decl.did;
-    settype(var, s->decl.type);
+    settype(st, var, s->decl.type);
     *ret = var;
 }
 
-static void inferseq(Node *n)
+static void inferseq(Inferstate *st, Node *n)
 {
     size_t i;
 
     for (i = 0; i < n->lit.nelt; i++) {
-        infernode(n->lit.seqval[i], NULL, NULL);
-        unify(n, type(n->lit.seqval[0]), type(n->lit.seqval[i]));
+        infernode(st, n->lit.seqval[i], NULL, NULL);
+        unify(st, n, type(st, n->lit.seqval[0]), type(st, n->lit.seqval[i]));
     }
-    settype(n, mktyarray(n->line, type(n->lit.seqval[0]), mkintlit(n->line, n->lit.nelt)));
+    settype(st, n, mktyarray(n->line, type(st, n->lit.seqval[0]), mkintlit(n->line, n->lit.nelt)));
 }
 
-static void inferexpr(Node *n, Type *ret, int *sawret)
+static void inferexpr(Inferstate *st, Node *n, Type *ret, int *sawret)
 {
     Node **args;
     Type **types;
@@ -403,8 +416,8 @@
             /* Omemb can sometimes resolve to a namespace. We have to check
              * this. Icky. */
             if (exprop(args[i]) == Omemb)
-                checkns(args[i], &args[i]);
-            inferexpr(args[i], ret, sawret);
+                checkns(st, args[i], &args[i]);
+            inferexpr(st, args[i], ret, sawret);
         }
     }
     switch (exprop(n)) {
@@ -436,10 +449,10 @@
         case Obxoreq:   /* @a ^= @a -> @a */
         case Obsleq:    /* @a <<= @a -> @a */
         case Obsreq:    /* @a >>= @a -> @a */
-            t = type(args[0]);
+            t = type(st, args[0]);
             for (i = 1; i < nargs; i++)
-                t = unify(n, t, type(args[i]));
-            settype(n, tf(t));
+                t = unify(st, n, t, type(st, args[i]));
+            settype(st, n, tf(st, t));
             break;
 
         /* operands same type, returning bool */
@@ -452,45 +465,45 @@
         case Oge:       /* @a >= @a -> bool */
         case Olt:       /* @a < @a -> bool */
         case Ole:       /* @a <= @b -> bool */
-            t = type(args[0]);
+            t = type(st, args[0]);
             for (i = 1; i < nargs; i++)
-                unify(n, t, type(args[i]));
-            settype(n, mkty(-1, Tybool));
+                unify(st, n, t, type(st, args[i]));
+            settype(st, n, mkty(-1, Tybool));
             break;
 
         /* reach into a type and pull out subtypes */
         case Oaddr:     /* &@a -> @a* */
-            settype(n, mktyptr(n->line, type(args[0])));
+            settype(st, n, mktyptr(n->line, type(st, args[0])));
             break;
         case Oderef:    /* *@a* ->  @a */
-            t = unify(n, type(args[0]), mktyptr(n->line, mktyvar(n->line)));
-            settype(n, t);
+            t = unify(st, n, type(st, args[0]), mktyptr(n->line, mktyvar(n->line)));
+            settype(st, n, t);
             break;
         case Oidx:      /* @a[@b::tcint] -> @a */
             t = mktyidxhack(n->line, mktyvar(n->line));
-            unify(n, type(args[0]), t);
-            constrain(n, type(args[1]), cstrtab[Tcint]);
-            settype(n, tf(t->sub[0]));
+            unify(st, n, type(st, args[0]), t);
+            constrain(n, type(st, args[1]), cstrtab[Tcint]);
+            settype(st, n, tf(st, t->sub[0]));
             break;
         case Oslice:    /* @a[@b::tcint,@b::tcint] -> @a[,] */
-            t = mktyidxhack(n->line, type(args[1]));
-            unify(n, type(args[0]), t);
-            settype(n, mktyslice(n->line, type(args[1])));
+            t = mktyidxhack(n->line, type(st, args[1]));
+            unify(st, n, type(st, args[0]), t);
+            settype(st, n, mktyslice(n->line, type(st, args[1])));
             break;
 
         /* special cases */
         case Omemb:     /* @a.Ident -> @b, verify type(@a.Ident)==@b later */
-            settype(n, mktyvar(n->line));
-            lappend(&postcheck, &npostcheck, n);
+            settype(st, n, mktyvar(n->line));
+            lappend(&st->postcheck, &st->npostcheck, n);
             break;
         case Osize:     /* sizeof @a -> size */
-            settype(n, mkty(n->line, Tyuint));
+            settype(st, n, mkty(n->line, Tyuint));
             break;
         case Ocall:     /* (@a, @b, @c, ... -> @r)(@a,@b,@c, ... -> @r) -> @r */
-            unifycall(n);
+            unifycall(st, n);
             break;
         case Ocast:     /* cast(@a, @b) -> @b */
-            lappend(&postcheck, &npostcheck, n);
+            lappend(&st->postcheck, &st->npostcheck, n);
             break;
         case Oret:      /* -> @a -> void */
             if (sawret)
@@ -498,13 +511,13 @@
             if (!ret)
                 fatal(n->line, "Not allowed to return value here");
             if (nargs)
-                t = unify(n, ret, type(args[0]));
+                t = unify(st, n, ret, type(st, args[0]));
             else
-                t =  unify(n, mkty(-1, Tyvoid), ret);
-            settype(n, t);
+                t =  unify(st, n, mkty(-1, Tyvoid), ret);
+            settype(st, n, t);
             break;
         case Ojmp:     /* goto void* -> void */
-            settype(n, mkty(-1, Tyvoid));
+            settype(st, n, mkty(-1, Tyvoid));
             break;
         case Ovar:      /* a:@a -> @a */
             /* if we created this from a namespaced var, the type should be
@@ -517,15 +530,15 @@
                 fatal(n->line, "Undeclared var %s", ctxstr(args[0]));
 
             if (s->decl.isgeneric)
-                t = freshen(s->decl.type);
+                t = freshen(st, s->decl.type);
             else
                 t = s->decl.type;
-            settype(n, t);
+            settype(st, n, t);
             n->expr.did = s->decl.did;
             if (s->decl.isgeneric) {
-                lappend(&specializationscope, &nspecializationscope, curstab());
-                lappend(&specializations, &nspecializations, n);
-                lappend(&genericdecls, &ngenericdecls, s);
+                lappend(&st->specializationscope, &st->nspecializationscope, curstab());
+                lappend(&st->specializations, &st->nspecializations, n);
+                lappend(&st->genericdecls, &st->ngenericdecls, s);
             }
             break;
         case Ocons:
@@ -537,30 +550,30 @@
             else if (uc->etype && n->expr.nargs != 2)
                 fatal(n->line, "union constructor %s needs arg ", ctxstr(args[0]));
             else if (uc->etype)
-                unify(n, uc->etype, type(args[1]));
-            settype(n, uc->utype);
+                unify(st, n, uc->etype, type(st, args[1]));
+            settype(st, n, uc->utype);
             break;
         case Otup:
             types = xalloc(sizeof(Type *)*n->expr.nargs);
             for (i = 0; i < n->expr.nargs; i++)
-                types[i] = type(n->expr.args[i]);
-            settype(n, mktytuple(n->line, types, n->expr.nargs));
+                types[i] = type(st, n->expr.args[i]);
+            settype(st, n, mktytuple(n->line, types, n->expr.nargs));
             break;
         case Oarr:
             for (i = 0; i < n->expr.nargs; i++)
-                unify(n, type(n->expr.args[0]), type(n->expr.args[i]));
-            settype(n, mktyarray(n->line, type(n->expr.args[0]), mkintlit(n->line, n->expr.nargs)));
+                unify(st, n, type(st, n->expr.args[0]), type(st, n->expr.args[i]));
+            settype(st, n, mktyarray(n->line, type(st, n->expr.args[0]), mkintlit(n->line, n->expr.nargs)));
             break;
         case Olit:      /* <lit>:@a::tyclass -> @a */
             switch (args[0]->lit.littype) {
-                case Lfunc:     infernode(args[0]->lit.fnval, NULL, NULL); break;
-                case Lseq:      inferseq(args[0]);                         break;
+                case Lfunc:     infernode(st, args[0]->lit.fnval, NULL, NULL); break;
+                case Lseq:      inferseq(st, args[0]);                         break;
                 default:        /* pass */                                 break;
             }
-            settype(n, type(args[0]));
+            settype(st, n, type(st, args[0]));
             break;
         case Olbl:      /* :lbl -> void* */
-            settype(n, mktyptr(n->line, mkty(-1, Tyvoid)));
+            settype(st, n, mktyptr(n->line, mkty(-1, Tyvoid)));
         case Obad: case Ocjmp:
         case Oload: case Ostor:
         case Oslbase: case Osllen:
@@ -570,7 +583,7 @@
     }
 }
 
-static void inferfunc(Node *n)
+static void inferfunc(Inferstate *st, Node *n)
 {
     size_t i;
     int sawret;
@@ -577,22 +590,22 @@
 
     sawret = 0;
     for (i = 0; i < n->func.nargs; i++)
-        infernode(n->func.args[i], NULL, NULL);
-    infernode(n->func.body, n->func.type->sub[0], &sawret);
+        infernode(st, n->func.args[i], NULL, NULL);
+    infernode(st, n->func.body, n->func.type->sub[0], &sawret);
     /* if there's no return stmt in the function, assume void ret */
     if (!sawret)
-        unify(n, type(n)->sub[0], mkty(-1, Tyvoid));
+        unify(st, n, type(st, n)->sub[0], mkty(-1, Tyvoid));
 }
 
-static void inferdecl(Node *n)
+static void inferdecl(Inferstate *st, Node *n)
 {
     Type *t;
 
-    t = tf(decltype(n));
-    settype(n, t);
+    t = tf(st, decltype(n));
+    settype(st, n, t);
     if (n->decl.init) {
-        inferexpr(n->decl.init, NULL, NULL);
-        unify(n, type(n), type(n->decl.init));
+        inferexpr(st, n->decl.init, NULL, NULL);
+        unify(st, n, type(st, n), type(st, n->decl.init));
     } else {
         if (n->decl.isconst && !n->decl.isextern)
             fatal(n->line, "non-extern \"%s\" has no initializer", ctxstr(n));
@@ -599,7 +612,7 @@
     }
 }
 
-static void inferstab(Stab *s)
+static void inferstab(Inferstate *st, Stab *s)
 {
     void **k;
     size_t n, i;
@@ -607,13 +620,13 @@
 
     k = htkeys(s->ty, &n);
     for (i = 0; i < n; i++) {
-        t = tf(gettype(s, k[i]));
+        t = tf(st, gettype(s, k[i]));
         updatetype(s, k[i], t);
     }
     free(k);
 }
 
-static void tybind(Htab *bt, Type *t)
+static void tybind(Inferstate *st, Htab *bt, Type *t)
 {
     size_t i;
 
@@ -623,16 +636,16 @@
         return;
 
     if (hthas(bt, t->pname))
-        unify(NULL, htget(bt, t->pname), t);
-    else if (isbound(t))
+        unify(st, NULL, htget(bt, t->pname), t);
+    else if (isbound(st, t))
         return;
 
     htput(bt, t->pname, t);
     for (i = 0; i < t->nsub; i++)
-        tybind(bt, t->sub[i]);
+        tybind(st, bt, t->sub[i]);
 }
 
-static void bind(Node *n)
+static void bind(Inferstate *st, Node *n)
 {
     Htab *bt;
 
@@ -642,21 +655,21 @@
-        fatal(n->line, "generic %s has no initializer", n->decl);
+        fatal(n->line, "generic %s has no initializer", declname(n));
 
     bt = mkht(strhash, streq);
-    lappend(&tybindings, &ntybindings, bt);
+    lappend(&st->tybindings, &st->ntybindings, bt);
 
-    tybind(bt, n->decl.type);
-    tybind(bt, n->decl.init->expr.type);
+    tybind(st, bt, n->decl.type);
+    tybind(st, bt, n->decl.init->expr.type);
 }
 
-static void unbind(Node *n)
+static void unbind(Inferstate *st, Node *n)
 {
     if (!n->decl.isgeneric)
         return;
-    htfree(tybindings[ntybindings - 1]);
-    lpop(&tybindings, &ntybindings);
+    htfree(st->tybindings[st->ntybindings - 1]);
+    lpop(&st->tybindings, &st->ntybindings);
 }
 
-static void infernode(Node *n, Type *ret, int *sawret)
+static void infernode(Inferstate *st, Node *n, Type *ret, int *sawret)
 {
     size_t i;
     Node *d;
@@ -669,69 +682,69 @@
             pushstab(n->file.globls);
             /* exports allow us to specify types later in the body, so we
              * need to patch the types in if they don't have a definition */
-            inferstab(n->file.globls);
-            inferstab(n->file.exports);
+            inferstab(st, n->file.globls);
+            inferstab(st, n->file.exports);
             for (i = 0; i < n->file.nstmts; i++) {
                 d  = n->file.stmts[i];
-                infernode(d, NULL, sawret);
+                infernode(st, d, NULL, sawret);
                 if (d->type == Ndecl)  {
                     s = getdcl(file->file.exports, d->decl.name);
                     if (s)
-                        unify(d, type(d), s->decl.type);
+                        unify(st, d, type(st, d), s->decl.type);
                 }
             }
             popstab();
             break;
         case Ndecl:
-            bind(n);
-            inferdecl(n);
-            unbind(n);
+            bind(st, n);
+            inferdecl(st, n);
+            unbind(st, n);
             break;
         case Nblock:
             setsuper(n->block.scope, curstab());
             pushstab(n->block.scope);
-            inferstab(n->block.scope);
+            inferstab(st, n->block.scope);
             for (i = 0; i < n->block.nstmts; i++) {
-                checkns(n->block.stmts[i], &n->block.stmts[i]);
-                infernode(n->block.stmts[i], ret, sawret);
+                checkns(st, n->block.stmts[i], &n->block.stmts[i]);
+                infernode(st, n->block.stmts[i], ret, sawret);
             }
             popstab();
             break;
         case Nifstmt:
-            infernode(n->ifstmt.cond, NULL, sawret);
-            infernode(n->ifstmt.iftrue, ret, sawret);
-            infernode(n->ifstmt.iffalse, ret, sawret);
-            constrain(n, type(n->ifstmt.cond), cstrtab[Tctest]);
+            infernode(st, n->ifstmt.cond, NULL, sawret);
+            infernode(st, n->ifstmt.iftrue, ret, sawret);
+            infernode(st, n->ifstmt.iffalse, ret, sawret);
+            constrain(n, type(st, n->ifstmt.cond), cstrtab[Tctest]);
             break;
         case Nloopstmt:
-            infernode(n->loopstmt.init, ret, sawret);
-            infernode(n->loopstmt.cond, NULL, sawret);
-            infernode(n->loopstmt.step, ret, sawret);
-            infernode(n->loopstmt.body, ret, sawret);
-            constrain(n, type(n->loopstmt.cond), cstrtab[Tctest]);
+            infernode(st, n->loopstmt.init, ret, sawret);
+            infernode(st, n->loopstmt.cond, NULL, sawret);
+            infernode(st, n->loopstmt.step, ret, sawret);
+            infernode(st, n->loopstmt.body, ret, sawret);
+            constrain(n, type(st, n->loopstmt.cond), cstrtab[Tctest]);
             break;
         case Nmatchstmt:
-            infernode(n->matchstmt.val, NULL, sawret);
+            infernode(st, n->matchstmt.val, NULL, sawret);
             for (i = 0; i < n->matchstmt.nmatches; i++) {
-                infernode(n->matchstmt.matches[i], ret, sawret);
-                unify(n, type(n->matchstmt.val), type(n->matchstmt.matches[i]->match.pat));
+                infernode(st, n->matchstmt.matches[i], ret, sawret);
+                unify(st, n, type(st, n->matchstmt.val), type(st, n->matchstmt.matches[i]->match.pat));
             }
             break;
         case Nmatch:
-            infernode(n->match.pat, NULL, sawret);
-            infernode(n->match.block, ret, sawret);
+            infernode(st, n->match.pat, NULL, sawret);
+            infernode(st, n->match.block, ret, sawret);
             break;
         case Nexpr:
-            inferexpr(n, ret, sawret);
+            inferexpr(st, n, ret, sawret);
             break;
         case Nfunc:
             setsuper(n->func.scope, curstab());
-            if (ntybindings > 0)
+            if (st->ntybindings > 0)
                 for (i = 0; i < n->func.nargs; i++)
-                    tybind(tybindings[ntybindings - 1], n->func.args[i]->decl.type);
+                    tybind(st, st->tybindings[st->ntybindings - 1], n->func.args[i]->decl.type);
             pushstab(n->func.scope);
-            inferstab(n->block.scope);
-            inferfunc(n);
+            inferstab(st, n->func.scope);
+            inferfunc(st, n);
             popstab();
             break;
         case Nname:
@@ -745,13 +758,13 @@
     }
 }
 
-static void checkcast(Node *n)
+static void checkcast(Inferstate *st, Node *n)
 {
 }
 
 /* returns the final type for t, after all unifications
  * and default constraint selections */
-static Type *tyfix(Node *ctx, Type *t)
+static Type *tyfix(Inferstate *st, Node *ctx, Type *t)
 {
     static Type *tyint, *tyflt;
     size_t i;
@@ -762,7 +775,7 @@
     if (!tyflt)
         tyflt = mkty(-1, Tyfloat64);
 
-    t = tf(t);
+    t = tf(st, t);
     if (t->type == Tyvar) {
         if (hascstr(t, cstrtab[Tcint]) && cstrcheck(t, tyint))
             return tyint;
@@ -770,18 +783,18 @@
             return tyint;
     } else {
         if (t->type == Tyarray) {
-            typesub(t->asize);
+            typesub(st, t->asize);
         } else if (t->type == Tystruct) {
             for (i = 0; i < t->nmemb; i++)
-                typesub(t->sdecls[i]);
+                typesub(st, t->sdecls[i]);
         } else if (t->type == Tyunion) {
             for (i = 0; i < t->nmemb; i++) {
                 if (t->udecls[i]->etype)
-                    t->udecls[i]->etype = tyfix(ctx, t->udecls[i]->etype);
+                    t->udecls[i]->etype = tyfix(st, ctx, t->udecls[i]->etype);
             }
         }
         for (i = 0; i < t->nsub; i++)
-            t->sub[i] = tyfix(ctx, t->sub[i]);
+            t->sub[i] = tyfix(st, ctx, t->sub[i]);
     }
     if (t->type == Tyvar) {
         fatal(t->line, "underconstrained type %s near %s", tyfmt(buf, 1024, t), ctxstr(ctx));
@@ -790,7 +803,7 @@
     return t;
 }
 
-static void infercompn(Node *file)
+static void infercompn(Inferstate *st, Node *file)
 {
     size_t i, j;
     Node *aggr;
@@ -800,32 +813,32 @@
     Type *t;
     int found;
 
-    for (i = 0; i < npostcheck; i++) {
-        n = postcheck[i];
+    for (i = 0; i < st->npostcheck; i++) {
+        n = st->postcheck[i];
         if (exprop(n) != Omemb)
             continue;
-        if (type(n)->type == Typtr)
+        if (type(st, n)->type == Typtr)
             n = n->expr.args[0];
-        aggr = postcheck[i]->expr.args[0];
-        memb = postcheck[i]->expr.args[1];
+        aggr = st->postcheck[i]->expr.args[0];
+        memb = st->postcheck[i]->expr.args[1];
 
         found = 0;
-        t = tf(type(aggr));
+        t = tf(st, type(st, aggr));
         if (t->type == Tyslice || t->type == Tyarray) {
             if (!strcmp(namestr(memb), "len")) {
-                constrain(n, type(n), cstrtab[Tcnum]);
-                constrain(n, type(n), cstrtab[Tcint]);
-                constrain(n, type(n), cstrtab[Tctest]);
+                constrain(n, type(st, n), cstrtab[Tcnum]);
+                constrain(n, type(st, n), cstrtab[Tcint]);
+                constrain(n, type(st, n), cstrtab[Tctest]);
                 found = 1;
             }
         } else {
             t = tybase(t);
             if (t->type == Typtr)
-                t = tf(t->sub[0]);
+                t = tf(st, t->sub[0]);
             nl = t->sdecls;
             for (j = 0; j < t->nmemb; j++) {
                 if (!strcmp(namestr(memb), declname(nl[j]))) {
-                    unify(n, type(n), decltype(nl[j]));
+                    unify(st, n, type(st, n), decltype(nl[j]));
                     found = 1;
                     break;
                 }
@@ -833,11 +846,11 @@
         }
         if (!found)
             fatal(aggr->line, "Type %s has no member \"%s\" near %s",
-                  tystr(type(aggr)), ctxstr(memb), ctxstr(aggr));
+                  tystr(type(st, aggr)), ctxstr(memb), ctxstr(aggr));
     }
 }
 
-static void stabsub(Stab *s)
+static void stabsub(Inferstate *st, Stab *s)
 {
     void **k;
     size_t n, i;
@@ -846,7 +859,7 @@
 
     k = htkeys(s->ty, &n);
     for (i = 0; i < n; i++) {
-        t = tf(gettype(s, k[i]));
+        t = tf(st, gettype(s, k[i]));
         updatetype(s, k[i], t);
     }
     free(k);
@@ -854,12 +867,12 @@
     k = htkeys(s->dcl, &n);
     for (i = 0; i < n; i++) {
         d = getdcl(s, k[i]);
-        d->decl.type = tyfix(d->decl.name, d->decl.type);
+        d->decl.type = tyfix(st, d->decl.name, d->decl.type);
     }
     free(k);
 }
 
-static void typesub(Node *n)
+static void typesub(Inferstate *st, Node *n)
 {
     size_t i;
 
@@ -868,63 +881,63 @@
     switch (n->type) {
         case Nfile:
             pushstab(n->file.globls);
-            stabsub(n->file.globls);
-            stabsub(n->file.exports);
+            stabsub(st, n->file.globls);
+            stabsub(st, n->file.exports);
             for (i = 0; i < n->file.nstmts; i++)
-                typesub(n->file.stmts[i]);
+                typesub(st, n->file.stmts[i]);
             popstab();
             break;
         case Ndecl:
-            settype(n, tyfix(n, type(n)));
+            settype(st, n, tyfix(st, n, type(st, n)));
             if (n->decl.init)
-                typesub(n->decl.init);
+                typesub(st, n->decl.init);
             break;
         case Nblock:
             pushstab(n->block.scope);
             for (i = 0; i < n->block.nstmts; i++)
-                typesub(n->block.stmts[i]);
+                typesub(st, n->block.stmts[i]);
             popstab();
             break;
         case Nifstmt:
-            typesub(n->ifstmt.cond);
-            typesub(n->ifstmt.iftrue);
-            typesub(n->ifstmt.iffalse);
+            typesub(st, n->ifstmt.cond);
+            typesub(st, n->ifstmt.iftrue);
+            typesub(st, n->ifstmt.iffalse);
             break;
         case Nloopstmt:
-            typesub(n->loopstmt.cond);
-            typesub(n->loopstmt.init);
-            typesub(n->loopstmt.step);
-            typesub(n->loopstmt.body);
+            typesub(st, n->loopstmt.cond);
+            typesub(st, n->loopstmt.init);
+            typesub(st, n->loopstmt.step);
+            typesub(st, n->loopstmt.body);
             break;
         case Nmatchstmt:
-            typesub(n->matchstmt.val);
+            typesub(st, n->matchstmt.val);
             for (i = 0; i < n->matchstmt.nmatches; i++)
-                typesub(n->matchstmt.matches[i]);
+                typesub(st, n->matchstmt.matches[i]);
             break;
         case Nmatch:
-            typesub(n->match.pat);
-            typesub(n->match.block);
+            typesub(st, n->match.pat);
+            typesub(st, n->match.block);
             break;
         case Nexpr:
-            settype(n, tyfix(n, type(n)));
+            settype(st, n, tyfix(st, n, type(st, n)));
             for (i = 0; i < n->expr.nargs; i++)
-                typesub(n->expr.args[i]);
+                typesub(st, n->expr.args[i]);
             break;
         case Nfunc:
             pushstab(n->func.scope);
-            settype(n, tyfix(n, n->func.type));
+            settype(st, n, tyfix(st, n, n->func.type));
             for (i = 0; i < n->func.nargs; i++)
-                typesub(n->func.args[i]);
-            typesub(n->func.body);
+                typesub(st, n->func.args[i]);
+            typesub(st, n->func.body);
             popstab();
             break;
         case Nlit:
-            settype(n, tyfix(n, type(n)));
+            settype(st, n, tyfix(st, n, type(st, n)));
             switch (n->lit.littype) {
-                case Lfunc:     typesub(n->lit.fnval); break;
+                case Lfunc:     typesub(st, n->lit.fnval); break;
                 case Lseq:
                     for (i = 0; i < n->lit.nelt; i++)
-                        typesub(n->lit.seqval[i]);
+                        typesub(st, n->lit.seqval[i]);
                     break;
                 default:        break;
             }
@@ -939,7 +952,7 @@
     }
 }
 
-static void mergeexports(Node *file)
+static void mergeexports(Inferstate *st, Node *file)
 {
     Stab *exports, *globls;
     size_t i, nk;
@@ -965,7 +978,7 @@
         } else {
             tg = gettype(globls, nl);
             if (tg) 
-                updatetype(exports, nl, tf(tg));
+                updatetype(exports, nl, tf(st, tg));
             else
                 fatal(nl->line, "Exported type %s not declared", namestr(nl));
         }
@@ -983,22 +996,22 @@
         if (!ng)
             putdcl(globls, nl);
         else
-            unify(nl, type(ng), type(nl));
+            unify(st, nl, type(st, ng), type(st, nl));
     }
     free(k);
     popstab();
 }
 
-static void specialize(Node *f)
+static void specialize(Inferstate *st, Node *f)
 {
     Node *d, *name;
     size_t i;
 
-    for (i = 0; i < nspecializations; i++) {
-        pushstab(specializationscope[i]);
-        d = specializedcl(genericdecls[i], specializations[i]->expr.type, &name);
-        specializations[i]->expr.args[0] = name;
-        specializations[i]->expr.did = d->decl.did;
+    for (i = 0; i < st->nspecializations; i++) {
+        pushstab(st->specializationscope[i]);
+        d = specializedcl(st->genericdecls[i], st->specializations[i]->expr.type, &name);
+        st->specializations[i]->expr.args[0] = name;
+        st->specializations[i]->expr.did = d->decl.did;
         popstab();
     }
 }
@@ -1005,13 +1018,22 @@
 
 void infer(Node *file)
 {
-    assert(file->type == Nfile);
+    Inferstate st = {0,};
 
+    assert(file->type == Nfile);
     loaduses(file);
-    mergeexports(file);
-    infernode(file, NULL, NULL);
-    infercompn(file);
-    checkcast(file);
-    typesub(file);
-    specialize(file);
+    mergeexports(&st, file);
+    infernode(&st, file, NULL, NULL);
+    infercompn(&st, file);
+    checkcast(&st, file);
+    typesub(&st, file);
+    specialize(&st, file);
+
+    /* st lives only for this call, so release the lists it accumulated */
+    free(st.binds);
+    free(st.postcheck);
+    free(st.tybindings);
+    free(st.genericdecls);
+    free(st.specializations);
+    free(st.specializationscope);
 }