ref: e30e522d8b8ddc893c622c34207e24abc1ec12c0
parent: 35bfc4bc1c92546368e13c7eb20f8f7c4c9b404b
author: Ori Bernstein <[email protected]>
date: Mon Jan 9 19:40:11 EST 2012
Infer return type for non-returning functions. If the function body contains no return statements, its return type is inferred as void, so the function's type becomes (whatever -> void).
--- a/parse/infer.c
+++ b/parse/infer.c
@@ -12,7 +12,7 @@
#include "parse.h"
-static void infernode(Node *n, Type *ret);
+static void infernode(Node *n, Type *ret, int *sawret);
static void setsuper(Stab *st, Stab *super)
{
@@ -110,7 +110,7 @@
case Lfunc:
/* we figure out the return type to infer for the body in
* infernode() */
- infernode(n->lit.fnval, NULL);
+ infernode(n->lit.fnval, NULL, NULL);
return n->lit.fnval->func.type;
break;
case Larray: return NULL; break;
@@ -126,6 +126,7 @@
case Nlit: t = littype(n); break;
case Nexpr: t = n->expr.type; break;
case Ndecl: t = decltype(n); break;
+ case Nfunc: t = n->func.type; break;
default:
t = NULL;
die("untypeable %s", nodestr(n->type));
@@ -136,7 +137,12 @@
static char *ctxstr(Node *n)
{
- return nodestr(n->type);
+ char *s;
+ switch (n->type) {
+ case Nexpr: s = opstr(exprop(n)); break;
+ default: s = nodestr(n->type); break;
+ }
+ return s;
}
static void matchcstrs(Node *ctx, Type *a, Type *b)
@@ -201,7 +207,7 @@
{
}
-static void inferexpr(Node *n, Type *ret)
+static void inferexpr(Node *n, Type *ret, int *sawret)
{
Node **args;
Sym *s;
@@ -215,7 +221,7 @@
for (i = 0; i < nargs; i++)
/* Nlit, Nvar, etc should not be inferred as exprs */
if (args[i]->type == Nexpr)
- inferexpr(args[i], ret);
+ inferexpr(args[i], ret, sawret);
switch (exprop(n)) {
/* all operands are same type */
case Oadd: /* @a + @a -> @a */
@@ -296,6 +302,8 @@
die("casts not implemented");
break;
case Oret: /* -> @a -> void */
+ if (sawret)
+ *sawret = 1;
if (!ret)
fatal(n->line, "Not allowed to return value here");
if (nargs)
@@ -327,10 +335,14 @@
static void inferfunc(Node *n)
{
int i;
+ int sawret;
+ sawret = 0;
for (i = 0; i < n->func.nargs; i++)
- infernode(n->func.args[i], NULL);
- infernode(n->func.body, n->func.type->sub[0]);
+ infernode(n->func.args[i], NULL, NULL);
+ infernode(n->func.body, n->func.type->sub[0], &sawret);
+ if (!sawret)
+ unify(n, type(n)->sub[0], mkty(-1, Tyvoid));
}
static void inferdecl(Node *n)
@@ -340,12 +352,12 @@
t = decltype(n);
settype(n, t);
if (n->decl.init) {
- inferexpr(n->decl.init, NULL);
+ inferexpr(n->decl.init, NULL, NULL);
unify(n, type(n), type(n->decl.init));
}
}
-static void infernode(Node *n, Type *ret)
+static void infernode(Node *n, Type *ret, int *sawret)
{
int i;
@@ -353,7 +365,7 @@
case Nfile:
pushstab(n->file.globls);
for (i = 0; i < n->file.nstmts; i++)
- infernode(n->file.stmts[i], NULL);
+ infernode(n->file.stmts[i], NULL, sawret);
popstab();
break;
case Ndecl:
@@ -363,22 +375,22 @@
setsuper(n->block.scope, curstab());
pushstab(n->block.scope);
for (i = 0; i < n->block.nstmts; i++)
- infernode(n->block.stmts[i], ret);
+ infernode(n->block.stmts[i], ret, sawret);
popstab();
break;
case Nifstmt:
- infernode(n->ifstmt.cond, NULL);
- infernode(n->ifstmt.iftrue, ret);
- infernode(n->ifstmt.iffalse, ret);
+ infernode(n->ifstmt.cond, NULL, sawret);
+ infernode(n->ifstmt.iftrue, ret, sawret);
+ infernode(n->ifstmt.iffalse, ret, sawret);
break;
case Nloopstmt:
- infernode(n->loopstmt.cond, NULL);
- infernode(n->loopstmt.init, ret);
- infernode(n->loopstmt.step, ret);
- infernode(n->loopstmt.body, ret);
+ infernode(n->loopstmt.cond, NULL, sawret);
+ infernode(n->loopstmt.init, ret, sawret);
+ infernode(n->loopstmt.step, ret, sawret);
+ infernode(n->loopstmt.body, ret, sawret);
break;
case Nexpr:
- inferexpr(n, ret);
+ inferexpr(n, ret, sawret);
break;
case Nfunc:
setsuper(n->func.scope, curstab());
@@ -492,7 +504,7 @@
assert(file->type == Nfile);
loaduses(file);
- infernode(file, NULL);
+ infernode(file, NULL, NULL);
infercompn(file);
checkcast(file);
typesub(file);