ref: 221a1f2cdc59935197d4bd14669e78bec521c0aa
parent: e3b7056a409cf0480dc12b7d3f5e3eb34c36db15
author: Ori Bernstein <[email protected]>
date: Mon Jan 9 20:47:44 EST 2012
Infer function literals correctly. We don't want to re-infer the whole tree every time we see the function, for two reasons: performance, and because we don't want to have to push the stab every time we walk the tree.
--- a/parse/infer.c
+++ b/parse/infer.c
@@ -13,6 +13,7 @@
#include "parse.h"
static void infernode(Node *n, Type *ret, int *sawret);
+static void inferexpr(Node *n, Type *ret, int *sawret);
static void setsuper(Stab *st, Stab *super)
{
@@ -81,6 +82,7 @@
/* a => b */
static void settype(Node *n, Type *t)
{
+ t = tf(t);
switch (n->type) {
case Nexpr: n->expr.type = t; break;
case Ndecl: n->decl.sym->type = t; break;
@@ -108,9 +110,6 @@
case Lflt: return tylike(mktyvar(n->line), Tyfloat32); break;
case Lstr: return mktyslice(n->line, mkty(n->line, Tychar)); break;
case Lfunc:
- /* we figure out the return type to infer for the body in
- * infernode() */
- infernode(n->lit.fnval, NULL, NULL);
return n->lit.fnval->func.type;
break;
case Larray: return NULL; break;
@@ -181,6 +180,7 @@
b = t;
}
+ printf("--- unify %s => %s\n", tystr(a), tystr(b));
mergecstrs(ctx, a, b);
if (a->type != b->type) {
if (a->type == Tyvar)
@@ -205,6 +205,16 @@
static void unifycall(Node *n)
{
+ int i;
+ Type *ft;
+
+ inferexpr(n->expr.args[0], NULL, NULL);
+ ft = type(n->expr.args[0]);
+ for (i = 1; i < n->expr.nargs; i++) {
+ inferexpr(n->expr.args[i], NULL, NULL);
+ unify(n, ft->sub[i], type(n->expr.args[i]));
+ }
+ settype(n, ft->sub[0]);
}
static void inferexpr(Node *n, Type *ret, int *sawret)
@@ -323,6 +333,11 @@
settype(n, s->type);
break;
case Olit: /* <lit>:@a::tyclass -> @a */
+ switch (args[0]->lit.littype) {
+ case Lfunc: infernode(args[0]->lit.fnval, NULL, NULL); break;
+ case Larray: die("array types not implemented yet"); break;
+ default: break;
+ }
settype(n, type(args[0]));
break;
case Olbl: /* :lbl -> void* */
@@ -343,6 +358,8 @@
infernode(n->func.body, n->func.type->sub[0], &sawret);
if (!sawret)
unify(n, type(n)->sub[0], mkty(-1, Tyvoid));
+ else
+ printf("SAWRET!!!\n");
}
static void inferdecl(Node *n)
@@ -349,6 +366,7 @@
{
Type *t;
+ printf("====== decl %s\n", n->decl.sym->name->name.parts[n->decl.sym->name->name.nparts-1]);
t = decltype(n);
settype(n, t);
if (n->decl.init) {