ref: ac3035a2c5d8b9580a81852ec9493831b6e13e27
dir: /parse/infer.c/
#include <stdlib.h> #include <stdio.h> #include <stdint.h> #include <ctype.h> #include <string.h> #include <assert.h> #include <sys/types.h> #include <sys/stat.h> #include <fcntl.h> #include <unistd.h> #include <assert.h> #include "parse.h" static Node **checkmemb; static size_t ncheckmemb; static void infernode(Node *n, Type *ret, int *sawret); static void inferexpr(Node *n, Type *ret, int *sawret); static void typesub(Node *n); static Type *tf(Type *t); static void setsuper(Stab *st, Stab *super) { Stab *s; /* verify that we don't accidentally create loops */ for (s = super; s; s = s->super) assert(s->super != st); st->super = super; } static void tyresolve(Type *t) { int i, nn; Node **n; if (t->resolved) return; n = aggrmemb(t, &nn); for (i = 0; i < nn; i++) infernode(n[i], NULL, NULL); for (i = 0; i < t->nsub; i++) t->sub[i] = tf(t->sub[i]); t->resolved = 1; } /* find the most accurate type mapping we have */ static Type *tf(Type *t) { Type *lu; assert(t != NULL); lu = NULL; while (1) { if (!tytab[t->tid] && t->type == Tyname) { if (!(lu = gettype(curstab(), t->name))) fatal(t->name->line, "Could not find type %s", namestr(t->name)); tytab[t->tid] = lu; } if (!tytab[t->tid]) break; t = tytab[t->tid]; } tyresolve(t); return t; } /* does b satisfy all the constraints of a? */ static int cstrcheck(Type *a, Type *b) { /* a has no cstrs to satisfy */ if (!a->cstrs) return 1; /* b satisfies no cstrs; only valid if a requires none */ if (!b->cstrs) return bscount(a->cstrs) == 0; /* if a->cstrs is a subset of b->cstrs, all of * a's constraints are satisfied by b. */ return bsissubset(b->cstrs, a->cstrs); } static void loaduses(Node *n) { int i; /* uses only allowed at top level. Do we want to keep it this way? */ for (i = 0; i < n->file.nuses; i++) fprintf(stderr, "INTERNAL: implement use loading\n"); /* readuse(n->file.uses[i], n->file.globls); */ } static void settype(Node *n, Type *t) { t = tf(t); switch (n->type) { case Nexpr: n->expr.type = t; break; case Ndecl: n->decl.sym->type = t; break; case Nlit: n->lit.type = t; break; case Nfunc: n->func.type = t; break; default: die("can't set type of %s", nodestr(n->type)); break; } } static Type *littype(Node *n) { switch (n->lit.littype) { case Lchr: return mkty(n->line, Tychar); break; case Lbool: return mkty(n->line, Tybool); break; case Lint: return tylike(mktyvar(n->line), Tyint); break; case Lflt: return tylike(mktyvar(n->line), Tyfloat32); break; case Lstr: return mktyslice(n->line, mkty(n->line, Tychar)); break; case Lfunc: return n->lit.fnval->func.type; break; case Larray: return NULL; break; }; return NULL; } static Type *type(Node *n) { Type *t; switch (n->type) { case Nlit: t = littype(n); break; case Nexpr: t = n->expr.type; break; case Ndecl: t = decltype(n); break; case Nfunc: t = n->func.type; break; default: t = NULL; die("untypeable %s", nodestr(n->type)); break; }; return tf(t); } static char *ctxstr(Node *n) { char *s; switch (n->type) { case Nexpr: s = opstr(exprop(n)); break; default: s = nodestr(n->type); break; } return s; } static void mergecstrs(Node *ctx, Type *a, Type *b) { if (b->type == Tyvar) { /* make sure that if a = b, both have same cstrs */ if (a->cstrs && b->cstrs) bsunion(b->cstrs, a->cstrs); else if (a->cstrs) b->cstrs = dupbs(a->cstrs); else if (b->cstrs) a->cstrs = dupbs(b->cstrs); } else { if (!cstrcheck(a, b)) fatal(ctx->line, "%s incompatible with %s near %s", tystr(a), tystr(b), ctxstr(ctx)); } } int idxhacked(Type **pa, Type **pb) { Type *a, *b; a = *pa; b = *pb; /* we want to unify Tyidxhack => concrete indexable. Flip * to make this happen, if needed */ if (b->type == Tyvar && b->nsub > 0) { a = *pb; b = *pa; } return (a->type == Tyvar && a->nsub > 0) || a->type == Tyarray || a->type == Tyslice; } static Type *unify(Node *ctx, Type *a, Type *b) { Type *t; Type *r; int i; /* a ==> b */ a = tf(a); b = tf(b); if (a == b) return a; if (b->type == Tyvar) { t = a; a = b; b = t; } r = NULL; mergecstrs(ctx, a, b); if (a->type == Tyvar) { tytab[a->tid] = b; r = b; } if (a->type == b->type || idxhacked(&a, &b)) { for (i = 0; i < b->nsub; i++) { /* types must have same arity */ if (i >= a->nsub) fatal(ctx->line, "%s incompatible with %s near %s", tystr(a), tystr(b), ctxstr(ctx)); unify(ctx, a->sub[i], b->sub[i]); } r = b; } else if (a->type != Tyvar) { fatal(ctx->line, "%s incompatible with %s near %s", tystr(a), tystr(b), ctxstr(ctx)); } return r; } static void unifycall(Node *n) { int i; Type *ft; inferexpr(n->expr.args[0], NULL, NULL); ft = type(n->expr.args[0]); if (ft->type == Tyvar) { /* the first arg is the function itself, so it shouldn't be counted */ ft = mktyfunc(n->line, &n->expr.args[1], n->expr.nargs - 1, mktyvar(n->line)); unify(n, type(n->expr.args[0]), ft); } for (i = 1; i < n->expr.nargs; i++) { inferexpr(n->expr.args[i], NULL, NULL); unify(n, ft->sub[i], type(n->expr.args[i])); } settype(n, ft->sub[0]); } static void inferexpr(Node *n, Type *ret, int *sawret) { Node **args; Sym *s; int nargs; Type *t; int i; assert(n->type == Nexpr); args = n->expr.args; nargs = n->expr.nargs; for (i = 0; i < nargs; i++) /* Nlit, Nvar, etc should not be inferred as exprs */ if (args[i]->type == Nexpr) inferexpr(args[i], ret, sawret); switch (exprop(n)) { /* all operands are same type */ case Oadd: /* @a + @a -> @a */ case Osub: /* @a - @a -> @a */ case Omul: /* @a * @a -> @a */ case Odiv: /* @a / @a -> @a */ case Omod: /* @a % @a -> @a */ case Oneg: /* -@a -> @a */ case Obor: /* @a | @a -> @a */ case Oband: /* @a & @a -> @a */ case Obxor: /* @a ^ @a -> @a */ case Obsl: /* @a << @a -> @a */ case Obsr: /* @a >> @a -> @a */ case Obnot: /* ~@a -> @a */ case Opreinc: /* ++@a -> @a */ case Opredec: /* --@a -> @a */ case Opostinc: /* @a++ -> @a */ case Opostdec: /* @a-- -> @a */ case Oasn: /* @a = @a -> @a */ case Oaddeq: /* @a += @a -> @a */ case Osubeq: /* @a -= @a -> @a */ case Omuleq: /* @a *= @a -> @a */ case Odiveq: /* @a /= @a -> @a */ case Omodeq: /* @a %= @a -> @a */ case Oboreq: /* @a |= @a -> @a */ case Obandeq: /* @a &= @a -> @a */ case Obxoreq: /* @a ^= @a -> @a */ case Obsleq: /* @a <<= @a -> @a */ case Obsreq: /* @a >>= @a -> @a */ t = type(args[0]); for (i = 1; i < nargs; i++) t = unify(n, t, type(args[i])); settype(n, tf(t)); break; /* operands same type, returning bool */ case Olor: /* @a || @b -> bool */ case Oland: /* @a && @b -> bool */ case Olnot: /* !@a -> bool */ case Oeq: /* @a == @a -> bool */ case One: /* @a != @a -> bool */ case Ogt: /* @a > @a -> bool */ case Oge: /* @a >= @a -> bool */ case Olt: /* @a < @a -> bool */ case Ole: /* @a <= @b -> bool */ t = type(args[0]); for (i = 1; i < nargs; i++) unify(n, t, type(args[i])); settype(n, mkty(-1, Tybool)); break; /* reach into a type and pull out subtypes */ case Oaddr: /* &@a -> @a* */ settype(n, mktyptr(n->line, type(args[0]))); break; case Oderef: /* *@a* -> @a */ t = unify(n, type(args[0]), mktyptr(n->line, mktyvar(n->line))); settype(n, t); break; case Oidx: /* @a[@b::tcint] -> @a */ t = mktyidxhack(n->line, type(args[1])); t = unify(n, type(args[0]), t); settype(n, type(args[1])); break; case Oslice: /* @a[@b::tcint,@b::tcint] -> @a[,] */ t = mktyidxhack(n->line, type(args[1])); t = unify(n, type(args[0]), t); settype(n, mktyslice(n->line, type(args[1]))); break; /* special cases */ case Omemb: /* @a.Ident -> @b, verify type(@a.Ident)==@b later */ settype(n, mktyvar(n->line)); lappend(&checkmemb, &ncheckmemb, n); break; case Osize: /* sizeof @a -> size */ die("inference of sizes not done yet"); break; case Ocall: /* (@a, @b, @c, ... -> @r)(@a,@b,@c, ... -> @r) -> @r */ unifycall(n); break; case Ocast: /* cast(@a, @b) -> @b */ die("casts not implemented"); break; case Oret: /* -> @a -> void */ if (sawret) *sawret = 1; if (!ret) fatal(n->line, "Not allowed to return value here"); if (nargs) t = unify(n, ret, type(args[0])); else t = unify(n, mkty(-1, Tyvoid), ret); settype(n, t); break; case Ojmp: /* goto void* -> void */ settype(n, mkty(-1, Tyvoid)); break; case Ovar: /* a:@a -> @a */ s = getdcl(curstab(), args[0]); if (!s) fatal(n->line, "Undeclared var %s", declname(args[0])); else settype(n, s->type); n->expr.did = s->id; break; case Olit: /* <lit>:@a::tyclass -> @a */ switch (args[0]->lit.littype) { case Lfunc: infernode(args[0]->lit.fnval, NULL, NULL); break; case Larray: die("array types not implemented yet"); break; default: break; } settype(n, type(args[0])); break; case Olbl: /* :lbl -> void* */ settype(n, mktyptr(n->line, mkty(-1, Tyvoid))); case Obad: case Ocjmp: case Oload: case Ostor: case Oslbase: case Osllen: case Oblit: case Numops: die("Should not see %s in fe", opstr(exprop(n))); break; } } static void inferfunc(Node *n) { int i; int sawret; sawret = 0; for (i = 0; i < n->func.nargs; i++) infernode(n->func.args[i], NULL, NULL); infernode(n->func.body, n->func.type->sub[0], &sawret); /* if there's no return stmt in the function, assume void ret */ if (!sawret) unify(n, type(n)->sub[0], mkty(-1, Tyvoid)); } static void inferdecl(Node *n) { Type *t; t = tf(decltype(n)); settype(n, t); if (n->decl.init) { inferexpr(n->decl.init, NULL, NULL); unify(n, type(n), type(n->decl.init)); } } static void inferstab(Stab *s) { void **k; int n, i; Type *t; k = htkeys(s->ty, &n); for (i = 0; i < n; i++) { t = tf(gettype(s, k[i])); updatetype(s, k[i], t); } } static void infernode(Node *n, Type *ret, int *sawret) { int i; if (!n) return; switch (n->type) { case Nfile: pushstab(n->file.globls); inferstab(n->file.globls); for (i = 0; i < n->file.nstmts; i++) infernode(n->file.stmts[i], NULL, sawret); popstab(); break; case Ndecl: inferdecl(n); break; case Nblock: setsuper(n->block.scope, curstab()); pushstab(n->block.scope); inferstab(n->block.scope); for (i = 0; i < n->block.nstmts; i++) infernode(n->block.stmts[i], ret, sawret); popstab(); break; case Nifstmt: infernode(n->ifstmt.cond, NULL, sawret); infernode(n->ifstmt.iftrue, ret, sawret); infernode(n->ifstmt.iffalse, ret, sawret); constrain(type(n->ifstmt.cond), cstrtab[Tctest]); break; case Nloopstmt: infernode(n->loopstmt.init, ret, sawret); infernode(n->loopstmt.cond, NULL, sawret); infernode(n->loopstmt.step, ret, sawret); infernode(n->loopstmt.body, ret, sawret); constrain(type(n->loopstmt.cond), cstrtab[Tctest]); break; case Nexpr: inferexpr(n, ret, sawret); break; case Nfunc: setsuper(n->func.scope, curstab()); pushstab(n->func.scope); inferstab(n->block.scope); inferfunc(n); popstab(); break; case Nname: case Nlit: case Nuse: case Nlbl: break; case Nnone: die("Nnone should not be seen as node type!"); break; } } static void checkcast(Node *n) { } /* returns the final type for t, after all unifications * and default constraint selections */ static Type *tyfin(Node *ctx, Type *t) { static Type *tyint; int i; char buf[1024]; if (!tyint) tyint = mkty(-1, Tyint); t = tf(t); if (t->type == Tyvar) { if (hascstr(t, cstrtab[Tcint]) && cstrcheck(t, tyint)) return tyint; } else { if (t->type == Tyarray) typesub(t->asize); for (i = 0; i < t->nsub; i++) t->sub[i] = tyfin(ctx, t->sub[i]); } if (t->type == Tyvar) { fatal(t->line, "underconstrained type %s near %s", tyfmt(buf, 1024, t), ctxstr(ctx)); } return t; } static void infercompn(Node *file) { int i, j, nn; Node *aggr; Node *memb; Node *n; Node **nl; for (i = 0; i < ncheckmemb; i++) { n = checkmemb[i]; if (n->expr.type->type == Typtr) n = n->expr.args[0]; aggr = checkmemb[i]->expr.args[0]; memb = checkmemb[i]->expr.args[1]; nl = aggrmemb(aggr->expr.type, &nn); for (j = 0; j < nn; j++) { if (!strcmp(namestr(memb), declname(nl[j]))) { unify(n, type(n), decltype(nl[j])); break; } } } } static void typesub(Node *n) { int i; if (!n) return; switch (n->type) { case Nfile: for (i = 0; i < n->file.nstmts; i++) typesub(n->file.stmts[i]); break; case Ndecl: settype(n, tyfin(n, type(n))); if (n->decl.init) typesub(n->decl.init); break; case Nblock: for (i = 0; i < n->block.nstmts; i++) typesub(n->block.stmts[i]); break; case Nifstmt: typesub(n->ifstmt.cond); typesub(n->ifstmt.iftrue); typesub(n->ifstmt.iffalse); break; case Nloopstmt: typesub(n->loopstmt.cond); typesub(n->loopstmt.init); typesub(n->loopstmt.step); typesub(n->loopstmt.body); break; case Nexpr: settype(n, tyfin(n, type(n))); for (i = 0; i < n->expr.nargs; i++) typesub(n->expr.args[i]); break; case Nfunc: settype(n, tyfin(n, n->func.type)); for (i = 0; i < n->func.nargs; i++) typesub(n->func.args[i]); typesub(n->func.body); break; case Nlit: settype(n, tyfin(n, type(n))); switch (n->lit.littype) { case Lfunc: typesub(n->lit.fnval); break; case Larray: typesub(n->lit.arrval); break; default: break; } break; case Nname: case Nuse: case Nlbl: break; case Nnone: die("Nnone should not be seen as node type!"); break; } } void infer(Node *file) { assert(file->type == Nfile); loaduses(file); infernode(file, NULL, NULL); infercompn(file); checkcast(file); typesub(file); }