ref: df0b6010f2506030d6674dedede207d19434f662
parent: d97dcac9b2b1e2526d897c508bb37c82e32517cd
author: Ori Bernstein <[email protected]>
date: Mon Jul 16 17:16:19 EDT 2012
Move inference state into a state struct. This means we don't have to make lots of variables global, or thread them all through the code in an ugly way.
--- a/8/isel.c
+++ b/8/isel.c
@@ -267,7 +267,7 @@
* or:
* Oadd(
* reg,
- * reg||const))
+ * reg||const)
*
* or:
* Oadd(
--- a/parse/infer.c
+++ b/parse/infer.c
@@ -12,23 +12,36 @@
#include "parse.h"
-static Node **postcheck;
-static size_t npostcheck;
-static Htab **tybindings;
-static size_t ntybindings;
-static Node **genericdecls;
-static size_t ngenericdecls;
-static Node **specializations;
-static Node **specializations;
-static size_t nspecializations;
-static Stab **specializationscope;
-static size_t nspecializationscope;
+typedef struct Inferstate Inferstate;
+struct Inferstate {
+ int inpattern;
+ int ingeneric;
-static void infernode(Node *n, Type *ret, int *sawret);
-static void inferexpr(Node *n, Type *ret, int *sawret);
-static void typesub(Node *n);
-static Type *tf(Type *t);
+ /* nodes bound by patterns, which turn into decls in the action block */
+ Node **binds;
+ size_t nbinds;
+ /* nodes that need post-inference checking/unification */
+ Node **postcheck;
+ size_t npostcheck;
+ /* the type params bound at the current point */
+ Htab **tybindings;
+ size_t ntybindings;
+ /* generic declarations to be specialized */
+ Node **genericdecls;
+ size_t ngenericdecls;
+ /* the var nodes that use them (their types drive the specialization),
+ * and the scopes those nodes appear in */
+ Node **specializations;
+ size_t nspecializations;
+ Stab **specializationscope;
+ size_t nspecializationscope;
+};
+static void infernode(Inferstate *st, Node *n, Type *ret, int *sawret);
+static void inferexpr(Inferstate *st, Node *n, Type *ret, int *sawret);
+static void typesub(Inferstate *st, Node *n);
+static Type *tf(Inferstate *st, Type *t);
+
static void setsuper(Stab *st, Stab *super)
{
Stab *s;
@@ -39,13 +52,13 @@
st->super = super;
}
-static int isbound(Type *t)
+static int isbound(Inferstate *st, Type *t)
{
ssize_t i;
Type *p;
- for (i = ntybindings - 1; i >= 0; i--) {
- p = htget(tybindings[i], t->pname);
+ for (i = st->ntybindings - 1; i >= 0; i--) {
+ p = htget(st->tybindings[i], t->pname);
if (p == t)
return 1;
}
@@ -52,12 +65,12 @@
return 0;
}
-static Type *tyfreshen(Htab *ht, Type *t)
+static Type *tyfreshen(Inferstate *st, Htab *ht, Type *t)
{
Type *ret;
size_t i;
- t = tf(t);
+ t = tf(st, t);
if (t->type != Typaram && t->nsub == 0)
return t;
@@ -71,21 +84,21 @@
ret = tydup(t);
for (i = 0; i < t->nsub; i++)
- ret->sub[i] = tyfreshen(ht, t->sub[i]);
+ ret->sub[i] = tyfreshen(st, ht, t->sub[i]);
return ret;
}
-static Type *freshen(Type *t)
+static Type *freshen(Inferstate *st, Type *t)
{
Htab *ht;
ht = mkht(strhash, streq);
- t = tyfreshen(ht, t);
+ t = tyfreshen(st, ht, t);
htfree(ht);
return t;
}
-static void tyresolve(Type *t)
+static void tyresolve(Inferstate *st, Type *t)
{
size_t i;
Type *base;
@@ -95,22 +108,22 @@
t->resolved = 1;
if (t->type == Tystruct) {
for (i = 0; i < t->nmemb; i++)
- infernode(t->sdecls[i], NULL, NULL);
+ infernode(st, t->sdecls[i], NULL, NULL);
} else if (t->type == Tyunion) {
for (i = 0; i < t->nmemb; i++) {
- tyresolve(t->udecls[i]->utype);
- t->udecls[i]->utype = tf(t->udecls[i]->utype);
+ tyresolve(st, t->udecls[i]->utype);
+ t->udecls[i]->utype = tf(st, t->udecls[i]->utype);
if (t->udecls[i]->etype) {
- tyresolve(t->udecls[i]->etype);
- t->udecls[i]->etype = tf(t->udecls[i]->etype);
+ tyresolve(st, t->udecls[i]->etype);
+ t->udecls[i]->etype = tf(st, t->udecls[i]->etype);
}
}
} else if (t->type == Tyarray) {
- infernode(t->asize, NULL, NULL);
+ infernode(st, t->asize, NULL, NULL);
}
for (i = 0; i < t->nsub; i++)
- t->sub[i] = tf(t->sub[i]);
+ t->sub[i] = tf(st, t->sub[i]);
base = tybase(t);
/* no-ops if base == t */
if (t->cstrs)
@@ -120,7 +133,7 @@
}
/* fixd the most accurate type mapping we have */
-static Type *tf(Type *t)
+static Type *tf(Inferstate *st, Type *t)
{
Type *lu;
@@ -137,7 +150,7 @@
break;
t = tytab[t->tid];
}
- tyresolve(t);
+ tyresolve(st, t);
return t;
}
@@ -150,9 +163,9 @@
readuse(n->file.uses[i], n->file.globls);
}
-static void settype(Node *n, Type *t)
+static void settype(Inferstate *st, Node *n, Type *t)
{
- t = tf(t);
+ t = tf(st, t);
switch (n->type) {
case Nexpr: n->expr.type = t; break;
case Ndecl: n->decl.type = t; break;
@@ -181,7 +194,7 @@
return NULL;
}
-static Type *type(Node *n)
+static Type *type(Inferstate *st, Node *n)
{
Type *t;
@@ -195,7 +208,7 @@
die("untypeable %s", nodestr(n->type));
break;
};
- return tf(t);
+ return tf(st, t);
}
static char *ctxstr(Node *n)
@@ -285,7 +298,7 @@
return 0;
}
-static Type *unify(Node *ctx, Type *a, Type *b)
+static Type *unify(Inferstate *st, Node *ctx, Type *a, Type *b)
{
Type *t;
Type *r;
@@ -292,8 +305,8 @@
size_t i;
/* a ==> b */
- a = tf(a);
- b = tf(b);
+ a = tf(st, a);
+ b = tf(st, b);
if (a == b)
return a;
if (b->type == Tyvar) {
@@ -318,7 +331,7 @@
if (i >= a->nsub)
fatal(ctx->line, "%s has wrong subtypes for %s near %s", tystr(a), tystr(b), ctxstr(ctx));
- unify(ctx, a->sub[i], b->sub[i]);
+ unify(st, ctx, a->sub[i], b->sub[i]);
}
r = b;
} else if (a->type != Tyvar) {
@@ -327,12 +340,12 @@
return r;
}
-static void unifycall(Node *n)
+static void unifycall(Inferstate *st, Node *n)
{
size_t i;
Type *ft;
- ft = type(n->expr.args[0]);
+ ft = type(st, n->expr.args[0]);
if (ft->type == Tyvar) {
/* the first arg is the function itself, so it shouldn't be counted */
ft = mktyfunc(n->line, &n->expr.args[1], n->expr.nargs - 1, mktyvar(n->line));
@@ -340,17 +353,17 @@
for (i = 1; i < n->expr.nargs; i++) {
if (ft->sub[i]->type == Tyvalist)
break;
- inferexpr(n->expr.args[i], NULL, NULL);
- unify(n, ft->sub[i], type(n->expr.args[i]));
+ inferexpr(st, n->expr.args[i], NULL, NULL);
+ unify(st, n, ft->sub[i], type(st, n->expr.args[i]));
}
- settype(n, ft->sub[0]);
+ settype(st, n, ft->sub[0]);
}
-static void checkns(Node *n, Node **ret)
+static void checkns(Inferstate *st, Node *n, Node **ret)
{
Node *var, *name, *nsname;
Node **args;
- Stab *st;
+ Stab *stab;
Node *s;
if (n->type != Nexpr)
@@ -361,31 +374,31 @@
if (args[0]->type != Nexpr || exprop(args[0]) != Ovar)
return;
name = args[0]->expr.args[0];
- st = getns(curstab(), name);
- if (!st)
+ stab = getns(curstab(), name);
+ if (!stab)
return;
nsname = mknsname(n->line, namestr(name), namestr(args[1]));
- s = getdcl(st, args[1]);
+ s = getdcl(stab, args[1]);
if (!s)
fatal(n->line, "Undeclared var %s.%s", nsname->name.ns, nsname->name.name);
var = mkexpr(n->line, Ovar, nsname, NULL);
var->expr.did = s->decl.did;
- settype(var, s->decl.type);
+ settype(st, var, s->decl.type);
*ret = var;
}
-static void inferseq(Node *n)
+static void inferseq(Inferstate *st, Node *n)
{
size_t i;
for (i = 0; i < n->lit.nelt; i++) {
- infernode(n->lit.seqval[i], NULL, NULL);
- unify(n, type(n->lit.seqval[0]), type(n->lit.seqval[i]));
+ infernode(st, n->lit.seqval[i], NULL, NULL);
+ unify(st, n, type(st, n->lit.seqval[0]), type(st, n->lit.seqval[i]));
}
- settype(n, mktyarray(n->line, type(n->lit.seqval[0]), mkintlit(n->line, n->lit.nelt)));
+ settype(st, n, mktyarray(n->line, type(st, n->lit.seqval[0]), mkintlit(n->line, n->lit.nelt)));
}
-static void inferexpr(Node *n, Type *ret, int *sawret)
+static void inferexpr(Inferstate *st, Node *n, Type *ret, int *sawret)
{
Node **args;
Type **types;
@@ -403,8 +416,8 @@
/* Omemb can sometimes resolve to a namespace. We have to check
* this. Icky. */
if (exprop(args[i]) == Omemb)
- checkns(args[i], &args[i]);
- inferexpr(args[i], ret, sawret);
+ checkns(st, args[i], &args[i]);
+ inferexpr(st, args[i], ret, sawret);
}
}
switch (exprop(n)) {
@@ -436,10 +449,10 @@
case Obxoreq: /* @a ^= @a -> @a */
case Obsleq: /* @a <<= @a -> @a */
case Obsreq: /* @a >>= @a -> @a */
- t = type(args[0]);
+ t = type(st, args[0]);
for (i = 1; i < nargs; i++)
- t = unify(n, t, type(args[i]));
- settype(n, tf(t));
+ t = unify(st, n, t, type(st, args[i]));
+ settype(st, n, tf(st, t));
break;
/* operands same type, returning bool */
@@ -452,45 +465,45 @@
case Oge: /* @a >= @a -> bool */
case Olt: /* @a < @a -> bool */
case Ole: /* @a <= @b -> bool */
- t = type(args[0]);
+ t = type(st, args[0]);
for (i = 1; i < nargs; i++)
- unify(n, t, type(args[i]));
- settype(n, mkty(-1, Tybool));
+ unify(st, n, t, type(st, args[i]));
+ settype(st, n, mkty(-1, Tybool));
break;
/* reach into a type and pull out subtypes */
case Oaddr: /* &@a -> @a* */
- settype(n, mktyptr(n->line, type(args[0])));
+ settype(st, n, mktyptr(n->line, type(st, args[0])));
break;
case Oderef: /* *@a* -> @a */
- t = unify(n, type(args[0]), mktyptr(n->line, mktyvar(n->line)));
- settype(n, t);
+ t = unify(st, n, type(st, args[0]), mktyptr(n->line, mktyvar(n->line)));
+ settype(st, n, t);
break;
case Oidx: /* @a[@b::tcint] -> @a */
t = mktyidxhack(n->line, mktyvar(n->line));
- unify(n, type(args[0]), t);
- constrain(n, type(args[1]), cstrtab[Tcint]);
- settype(n, tf(t->sub[0]));
+ unify(st, n, type(st, args[0]), t);
+ constrain(n, type(st, args[1]), cstrtab[Tcint]);
+ settype(st, n, tf(st, t->sub[0]));
break;
case Oslice: /* @a[@b::tcint,@b::tcint] -> @a[,] */
- t = mktyidxhack(n->line, type(args[1]));
- unify(n, type(args[0]), t);
- settype(n, mktyslice(n->line, type(args[1])));
+ t = mktyidxhack(n->line, type(st, args[1]));
+ unify(st, n, type(st, args[0]), t);
+ settype(st, n, mktyslice(n->line, type(st, args[1])));
break;
/* special cases */
case Omemb: /* @a.Ident -> @b, verify type(@a.Ident)==@b later */
- settype(n, mktyvar(n->line));
- lappend(&postcheck, &npostcheck, n);
+ settype(st, n, mktyvar(n->line));
+ lappend(&st->postcheck, &st->npostcheck, n);
break;
case Osize: /* sizeof @a -> size */
- settype(n, mkty(n->line, Tyuint));
+ settype(st, n, mkty(n->line, Tyuint));
break;
case Ocall: /* (@a, @b, @c, ... -> @r)(@a,@b,@c, ... -> @r) -> @r */
- unifycall(n);
+ unifycall(st, n);
break;
case Ocast: /* cast(@a, @b) -> @b */
- lappend(&postcheck, &npostcheck, n);
+ lappend(&st->postcheck, &st->npostcheck, n);
break;
case Oret: /* -> @a -> void */
if (sawret)
@@ -498,13 +511,13 @@
if (!ret)
fatal(n->line, "Not allowed to return value here");
if (nargs)
- t = unify(n, ret, type(args[0]));
+ t = unify(st, n, ret, type(st, args[0]));
else
- t = unify(n, mkty(-1, Tyvoid), ret);
- settype(n, t);
+ t = unify(st, n, mkty(-1, Tyvoid), ret);
+ settype(st, n, t);
break;
case Ojmp: /* goto void* -> void */
- settype(n, mkty(-1, Tyvoid));
+ settype(st, n, mkty(-1, Tyvoid));
break;
case Ovar: /* a:@a -> @a */
/* if we created this from a namespaced var, the type should be
@@ -517,15 +530,15 @@
fatal(n->line, "Undeclared var %s", ctxstr(args[0]));
if (s->decl.isgeneric)
- t = freshen(s->decl.type);
+ t = freshen(st, s->decl.type);
else
t = s->decl.type;
- settype(n, t);
+ settype(st, n, t);
n->expr.did = s->decl.did;
if (s->decl.isgeneric) {
- lappend(&specializationscope, &nspecializationscope, curstab());
- lappend(&specializations, &nspecializations, n);
- lappend(&genericdecls, &ngenericdecls, s);
+ lappend(&st->specializationscope, &st->nspecializationscope, curstab());
+ lappend(&st->specializations, &st->nspecializations, n);
+ lappend(&st->genericdecls, &st->ngenericdecls, s);
}
break;
case Ocons:
@@ -537,30 +550,30 @@
else if (uc->etype && n->expr.nargs != 2)
fatal(n->line, "union constructor %s needs arg ", ctxstr(args[0]));
else if (uc->etype)
- unify(n, uc->etype, type(args[1]));
- settype(n, uc->utype);
+ unify(st, n, uc->etype, type(st, args[1]));
+ settype(st, n, uc->utype);
break;
case Otup:
types = xalloc(sizeof(Type *)*n->expr.nargs);
for (i = 0; i < n->expr.nargs; i++)
- types[i] = type(n->expr.args[i]);
- settype(n, mktytuple(n->line, types, n->expr.nargs));
+ types[i] = type(st, n->expr.args[i]);
+ settype(st, n, mktytuple(n->line, types, n->expr.nargs));
break;
case Oarr:
for (i = 0; i < n->expr.nargs; i++)
- unify(n, type(n->expr.args[0]), type(n->expr.args[i]));
- settype(n, mktyarray(n->line, type(n->expr.args[0]), mkintlit(n->line, n->expr.nargs)));
+ unify(st, n, type(st, n->expr.args[0]), type(st, n->expr.args[i]));
+ settype(st, n, mktyarray(n->line, type(st, n->expr.args[0]), mkintlit(n->line, n->expr.nargs)));
break;
case Olit: /* <lit>:@a::tyclass -> @a */
switch (args[0]->lit.littype) {
- case Lfunc: infernode(args[0]->lit.fnval, NULL, NULL); break;
- case Lseq: inferseq(args[0]); break;
+ case Lfunc: infernode(st, args[0]->lit.fnval, NULL, NULL); break;
+ case Lseq: inferseq(st, args[0]); break;
default: /* pass */ break;
}
- settype(n, type(args[0]));
+ settype(st, n, type(st, args[0]));
break;
case Olbl: /* :lbl -> void* */
- settype(n, mktyptr(n->line, mkty(-1, Tyvoid)));
+ settype(st, n, mktyptr(n->line, mkty(-1, Tyvoid)));
case Obad: case Ocjmp:
case Oload: case Ostor:
case Oslbase: case Osllen:
@@ -570,7 +583,7 @@
}
}
-static void inferfunc(Node *n)
+static void inferfunc(Inferstate *st, Node *n)
{
size_t i;
int sawret;
@@ -577,22 +590,22 @@
sawret = 0;
for (i = 0; i < n->func.nargs; i++)
- infernode(n->func.args[i], NULL, NULL);
- infernode(n->func.body, n->func.type->sub[0], &sawret);
+ infernode(st, n->func.args[i], NULL, NULL);
+ infernode(st, n->func.body, n->func.type->sub[0], &sawret);
/* if there's no return stmt in the function, assume void ret */
if (!sawret)
- unify(n, type(n)->sub[0], mkty(-1, Tyvoid));
+ unify(st, n, type(st, n)->sub[0], mkty(-1, Tyvoid));
}
-static void inferdecl(Node *n)
+static void inferdecl(Inferstate *st, Node *n)
{
Type *t;
- t = tf(decltype(n));
- settype(n, t);
+ t = tf(st, decltype(n));
+ settype(st, n, t);
if (n->decl.init) {
- inferexpr(n->decl.init, NULL, NULL);
- unify(n, type(n), type(n->decl.init));
+ inferexpr(st, n->decl.init, NULL, NULL);
+ unify(st, n, type(st, n), type(st, n->decl.init));
} else {
if (n->decl.isconst && !n->decl.isextern)
fatal(n->line, "non-extern \"%s\" has no initializer", ctxstr(n));
@@ -599,7 +612,7 @@
}
}
-static void inferstab(Stab *s)
+static void inferstab(Inferstate *st, Stab *s)
{
void **k;
size_t n, i;
@@ -607,13 +620,13 @@
k = htkeys(s->ty, &n);
for (i = 0; i < n; i++) {
- t = tf(gettype(s, k[i]));
+ t = tf(st, gettype(s, k[i]));
updatetype(s, k[i], t);
}
free(k);
}
-static void tybind(Htab *bt, Type *t)
+static void tybind(Inferstate *st, Htab *bt, Type *t)
{
size_t i;
@@ -623,16 +636,16 @@
return;
if (hthas(bt, t->pname))
- unify(NULL, htget(bt, t->pname), t);
- else if (isbound(t))
+ unify(st, NULL, htget(bt, t->pname), t);
+ else if (isbound(st, t))
return;
htput(bt, t->pname, t);
for (i = 0; i < t->nsub; i++)
- tybind(bt, t->sub[i]);
+ tybind(st, bt, t->sub[i]);
}
-static void bind(Node *n)
+static void bind(Inferstate *st, Node *n)
{
Htab *bt;
@@ -642,21 +655,21 @@
fatal(n->line, "generic %s has no initializer", n->decl);
bt = mkht(strhash, streq);
- lappend(&tybindings, &ntybindings, bt);
+ lappend(&st->tybindings, &st->ntybindings, bt);
- tybind(bt, n->decl.type);
- tybind(bt, n->decl.init->expr.type);
+ tybind(st, bt, n->decl.type);
+ tybind(st, bt, n->decl.init->expr.type);
}
-static void unbind(Node *n)
+static void unbind(Inferstate *st, Node *n)
{
if (!n->decl.isgeneric)
return;
- htfree(tybindings[ntybindings - 1]);
- lpop(&tybindings, &ntybindings);
+ htfree(st->tybindings[st->ntybindings - 1]);
+ lpop(&st->tybindings, &st->ntybindings);
}
-static void infernode(Node *n, Type *ret, int *sawret)
+static void infernode(Inferstate *st, Node *n, Type *ret, int *sawret)
{
size_t i;
Node *d;
@@ -669,69 +682,69 @@
pushstab(n->file.globls);
/* exports allow us to specify types later in the body, so we
* need to patch the types in if they don't have a definition */
- inferstab(n->file.globls);
- inferstab(n->file.exports);
+ inferstab(st, n->file.globls);
+ inferstab(st, n->file.exports);
for (i = 0; i < n->file.nstmts; i++) {
d = n->file.stmts[i];
- infernode(d, NULL, sawret);
+ infernode(st, d, NULL, sawret);
if (d->type == Ndecl) {
s = getdcl(file->file.exports, d->decl.name);
if (s)
- unify(d, type(d), s->decl.type);
+ unify(st, d, type(st, d), s->decl.type);
}
}
popstab();
break;
case Ndecl:
- bind(n);
- inferdecl(n);
- unbind(n);
+ bind(st, n);
+ inferdecl(st, n);
+ unbind(st, n);
break;
case Nblock:
setsuper(n->block.scope, curstab());
pushstab(n->block.scope);
- inferstab(n->block.scope);
+ inferstab(st, n->block.scope);
for (i = 0; i < n->block.nstmts; i++) {
- checkns(n->block.stmts[i], &n->block.stmts[i]);
- infernode(n->block.stmts[i], ret, sawret);
+ checkns(st, n->block.stmts[i], &n->block.stmts[i]);
+ infernode(st, n->block.stmts[i], ret, sawret);
}
popstab();
break;
case Nifstmt:
- infernode(n->ifstmt.cond, NULL, sawret);
- infernode(n->ifstmt.iftrue, ret, sawret);
- infernode(n->ifstmt.iffalse, ret, sawret);
- constrain(n, type(n->ifstmt.cond), cstrtab[Tctest]);
+ infernode(st, n->ifstmt.cond, NULL, sawret);
+ infernode(st, n->ifstmt.iftrue, ret, sawret);
+ infernode(st, n->ifstmt.iffalse, ret, sawret);
+ constrain(n, type(st, n->ifstmt.cond), cstrtab[Tctest]);
break;
case Nloopstmt:
- infernode(n->loopstmt.init, ret, sawret);
- infernode(n->loopstmt.cond, NULL, sawret);
- infernode(n->loopstmt.step, ret, sawret);
- infernode(n->loopstmt.body, ret, sawret);
- constrain(n, type(n->loopstmt.cond), cstrtab[Tctest]);
+ infernode(st, n->loopstmt.init, ret, sawret);
+ infernode(st, n->loopstmt.cond, NULL, sawret);
+ infernode(st, n->loopstmt.step, ret, sawret);
+ infernode(st, n->loopstmt.body, ret, sawret);
+ constrain(n, type(st, n->loopstmt.cond), cstrtab[Tctest]);
break;
case Nmatchstmt:
- infernode(n->matchstmt.val, NULL, sawret);
+ infernode(st, n->matchstmt.val, NULL, sawret);
for (i = 0; i < n->matchstmt.nmatches; i++) {
- infernode(n->matchstmt.matches[i], ret, sawret);
- unify(n, type(n->matchstmt.val), type(n->matchstmt.matches[i]->match.pat));
+ infernode(st, n->matchstmt.matches[i], ret, sawret);
+ unify(st, n, type(st, n->matchstmt.val), type(st, n->matchstmt.matches[i]->match.pat));
}
break;
case Nmatch:
- infernode(n->match.pat, NULL, sawret);
- infernode(n->match.block, ret, sawret);
+ infernode(st, n->match.pat, NULL, sawret);
+ infernode(st, n->match.block, ret, sawret);
break;
case Nexpr:
- inferexpr(n, ret, sawret);
+ inferexpr(st, n, ret, sawret);
break;
case Nfunc:
setsuper(n->func.scope, curstab());
- if (ntybindings > 0)
+ if (st->ntybindings > 0)
for (i = 0; i < n->func.nargs; i++)
- tybind(tybindings[ntybindings - 1], n->func.args[i]->decl.type);
+ tybind(st, st->tybindings[st->ntybindings - 1], n->func.args[i]->decl.type);
pushstab(n->func.scope);
- inferstab(n->block.scope);
- inferfunc(n);
+ inferstab(st, n->block.scope);
+ inferfunc(st, n);
popstab();
break;
case Nname:
@@ -745,13 +758,13 @@
}
}
-static void checkcast(Node *n)
+static void checkcast(Inferstate *st, Node *n)
{
}
/* returns the final type for t, after all unifications
* and default constraint selections */
-static Type *tyfix(Node *ctx, Type *t)
+static Type *tyfix(Inferstate *st, Node *ctx, Type *t)
{
static Type *tyint, *tyflt;
size_t i;
@@ -762,7 +775,7 @@
if (!tyflt)
tyflt = mkty(-1, Tyfloat64);
- t = tf(t);
+ t = tf(st, t);
if (t->type == Tyvar) {
if (hascstr(t, cstrtab[Tcint]) && cstrcheck(t, tyint))
return tyint;
@@ -770,18 +783,18 @@
return tyint;
} else {
if (t->type == Tyarray) {
- typesub(t->asize);
+ typesub(st, t->asize);
} else if (t->type == Tystruct) {
for (i = 0; i < t->nmemb; i++)
- typesub(t->sdecls[i]);
+ typesub(st, t->sdecls[i]);
} else if (t->type == Tyunion) {
for (i = 0; i < t->nmemb; i++) {
if (t->udecls[i]->etype)
- t->udecls[i]->etype = tyfix(ctx, t->udecls[i]->etype);
+ t->udecls[i]->etype = tyfix(st, ctx, t->udecls[i]->etype);
}
}
for (i = 0; i < t->nsub; i++)
- t->sub[i] = tyfix(ctx, t->sub[i]);
+ t->sub[i] = tyfix(st, ctx, t->sub[i]);
}
if (t->type == Tyvar) {
fatal(t->line, "underconstrained type %s near %s", tyfmt(buf, 1024, t), ctxstr(ctx));
@@ -790,7 +803,7 @@
return t;
}
-static void infercompn(Node *file)
+static void infercompn(Inferstate *st, Node *file)
{
size_t i, j;
Node *aggr;
@@ -800,32 +813,32 @@
Type *t;
int found;
- for (i = 0; i < npostcheck; i++) {
- n = postcheck[i];
+ for (i = 0; i < st->npostcheck; i++) {
+ n = st->postcheck[i];
if (exprop(n) != Omemb)
continue;
- if (type(n)->type == Typtr)
+ if (type(st, n)->type == Typtr)
n = n->expr.args[0];
- aggr = postcheck[i]->expr.args[0];
- memb = postcheck[i]->expr.args[1];
+ aggr = st->postcheck[i]->expr.args[0];
+ memb = st->postcheck[i]->expr.args[1];
found = 0;
- t = tf(type(aggr));
+ t = tf(st, type(st, aggr));
if (t->type == Tyslice || t->type == Tyarray) {
if (!strcmp(namestr(memb), "len")) {
- constrain(n, type(n), cstrtab[Tcnum]);
- constrain(n, type(n), cstrtab[Tcint]);
- constrain(n, type(n), cstrtab[Tctest]);
+ constrain(n, type(st, n), cstrtab[Tcnum]);
+ constrain(n, type(st, n), cstrtab[Tcint]);
+ constrain(n, type(st, n), cstrtab[Tctest]);
found = 1;
}
} else {
t = tybase(t);
if (t->type == Typtr)
- t = tf(t->sub[0]);
+ t = tf(st, t->sub[0]);
nl = t->sdecls;
for (j = 0; j < t->nmemb; j++) {
if (!strcmp(namestr(memb), declname(nl[j]))) {
- unify(n, type(n), decltype(nl[j]));
+ unify(st, n, type(st, n), decltype(nl[j]));
found = 1;
break;
}
@@ -833,11 +846,11 @@
}
if (!found)
fatal(aggr->line, "Type %s has no member \"%s\" near %s",
- tystr(type(aggr)), ctxstr(memb), ctxstr(aggr));
+ tystr(type(st, aggr)), ctxstr(memb), ctxstr(aggr));
}
}
-static void stabsub(Stab *s)
+static void stabsub(Inferstate *st, Stab *s)
{
void **k;
size_t n, i;
@@ -846,7 +859,7 @@
k = htkeys(s->ty, &n);
for (i = 0; i < n; i++) {
- t = tf(gettype(s, k[i]));
+ t = tf(st, gettype(s, k[i]));
updatetype(s, k[i], t);
}
free(k);
@@ -854,12 +867,12 @@
k = htkeys(s->dcl, &n);
for (i = 0; i < n; i++) {
d = getdcl(s, k[i]);
- d->decl.type = tyfix(d->decl.name, d->decl.type);
+ d->decl.type = tyfix(st, d->decl.name, d->decl.type);
}
free(k);
}
-static void typesub(Node *n)
+static void typesub(Inferstate *st, Node *n)
{
size_t i;
@@ -868,63 +881,63 @@
switch (n->type) {
case Nfile:
pushstab(n->file.globls);
- stabsub(n->file.globls);
- stabsub(n->file.exports);
+ stabsub(st, n->file.globls);
+ stabsub(st, n->file.exports);
for (i = 0; i < n->file.nstmts; i++)
- typesub(n->file.stmts[i]);
+ typesub(st, n->file.stmts[i]);
popstab();
break;
case Ndecl:
- settype(n, tyfix(n, type(n)));
+ settype(st, n, tyfix(st, n, type(st, n)));
if (n->decl.init)
- typesub(n->decl.init);
+ typesub(st, n->decl.init);
break;
case Nblock:
pushstab(n->block.scope);
for (i = 0; i < n->block.nstmts; i++)
- typesub(n->block.stmts[i]);
+ typesub(st, n->block.stmts[i]);
popstab();
break;
case Nifstmt:
- typesub(n->ifstmt.cond);
- typesub(n->ifstmt.iftrue);
- typesub(n->ifstmt.iffalse);
+ typesub(st, n->ifstmt.cond);
+ typesub(st, n->ifstmt.iftrue);
+ typesub(st, n->ifstmt.iffalse);
break;
case Nloopstmt:
- typesub(n->loopstmt.cond);
- typesub(n->loopstmt.init);
- typesub(n->loopstmt.step);
- typesub(n->loopstmt.body);
+ typesub(st, n->loopstmt.cond);
+ typesub(st, n->loopstmt.init);
+ typesub(st, n->loopstmt.step);
+ typesub(st, n->loopstmt.body);
break;
case Nmatchstmt:
- typesub(n->matchstmt.val);
+ typesub(st, n->matchstmt.val);
for (i = 0; i < n->matchstmt.nmatches; i++)
- typesub(n->matchstmt.matches[i]);
+ typesub(st, n->matchstmt.matches[i]);
break;
case Nmatch:
- typesub(n->match.pat);
- typesub(n->match.block);
+ typesub(st, n->match.pat);
+ typesub(st, n->match.block);
break;
case Nexpr:
- settype(n, tyfix(n, type(n)));
+ settype(st, n, tyfix(st, n, type(st, n)));
for (i = 0; i < n->expr.nargs; i++)
- typesub(n->expr.args[i]);
+ typesub(st, n->expr.args[i]);
break;
case Nfunc:
pushstab(n->func.scope);
- settype(n, tyfix(n, n->func.type));
+ settype(st, n, tyfix(st, n, n->func.type));
for (i = 0; i < n->func.nargs; i++)
- typesub(n->func.args[i]);
- typesub(n->func.body);
+ typesub(st, n->func.args[i]);
+ typesub(st, n->func.body);
popstab();
break;
case Nlit:
- settype(n, tyfix(n, type(n)));
+ settype(st, n, tyfix(st, n, type(st, n)));
switch (n->lit.littype) {
- case Lfunc: typesub(n->lit.fnval); break;
+ case Lfunc: typesub(st, n->lit.fnval); break;
case Lseq:
for (i = 0; i < n->lit.nelt; i++)
- typesub(n->lit.seqval[i]);
+ typesub(st, n->lit.seqval[i]);
break;
default: break;
}
@@ -939,7 +952,7 @@
}
}
-static void mergeexports(Node *file)
+static void mergeexports(Inferstate *st, Node *file)
{
Stab *exports, *globls;
size_t i, nk;
@@ -965,7 +978,7 @@
} else {
tg = gettype(globls, nl);
if (tg)
- updatetype(exports, nl, tf(tg));
+ updatetype(exports, nl, tf(st, tg));
else
fatal(nl->line, "Exported type %s not declared", namestr(nl));
}
@@ -983,22 +996,22 @@
if (!ng)
putdcl(globls, nl);
else
- unify(nl, type(ng), type(nl));
+ unify(st, nl, type(st, ng), type(st, nl));
}
free(k);
popstab();
}
-static void specialize(Node *f)
+static void specialize(Inferstate *st, Node *f)
{
Node *d, *name;
size_t i;
- for (i = 0; i < nspecializations; i++) {
- pushstab(specializationscope[i]);
- d = specializedcl(genericdecls[i], specializations[i]->expr.type, &name);
- specializations[i]->expr.args[0] = name;
- specializations[i]->expr.did = d->decl.did;
+ for (i = 0; i < st->nspecializations; i++) {
+ pushstab(st->specializationscope[i]);
+ d = specializedcl(st->genericdecls[i], st->specializations[i]->expr.type, &name);
+ st->specializations[i]->expr.args[0] = name;
+ st->specializations[i]->expr.did = d->decl.did;
popstab();
}
}
@@ -1005,13 +1018,14 @@
void infer(Node *file)
{
- assert(file->type == Nfile);
+ Inferstate st = {0,};
+ assert(file->type == Nfile);
loaduses(file);
- mergeexports(file);
- infernode(file, NULL, NULL);
- infercompn(file);
- checkcast(file);
- typesub(file);
- specialize(file);
+ mergeexports(&st, file);
+ infernode(&st, file, NULL, NULL);
+ infercompn(&st, file);
+ checkcast(&st, file);
+ typesub(&st, file);
+ specialize(&st, file);
}