shithub: mc

Download patch

ref: e30e522d8b8ddc893c622c34207e24abc1ec12c0
parent: 35bfc4bc1c92546368e13c7eb20f8f7c4c9b404b
author: Ori Bernstein <[email protected]>
date: Mon Jan 9 19:40:11 EST 2012

Infer return type for non-returning functions

    If the body of a function contains no return statements,
    its return type is inferred as void, i.e. the function's
    type becomes (<args> -> void).

--- a/parse/infer.c
+++ b/parse/infer.c
@@ -12,7 +12,7 @@
 
 #include "parse.h"
 
-static void infernode(Node *n, Type *ret);
+static void infernode(Node *n, Type *ret, int *sawret);
 
 static void setsuper(Stab *st, Stab *super)
 {
@@ -110,7 +110,7 @@
         case Lfunc:
             /* we figure out the return type to infer for the body in
              * infernode() */
-            infernode(n->lit.fnval, NULL);
+            infernode(n->lit.fnval, NULL, NULL);
             return n->lit.fnval->func.type;
             break;
         case Larray:    return NULL; break;
@@ -126,6 +126,7 @@
       case Nlit:        t = littype(n);         break;
       case Nexpr:       t = n->expr.type;       break;
       case Ndecl:       t = decltype(n);        break;
+      case Nfunc:       t = n->func.type;       break;
       default:
         t = NULL;
         die("untypeable %s", nodestr(n->type));
@@ -136,7 +137,12 @@
 
 static char *ctxstr(Node *n)
 {
-    return nodestr(n->type);
+    char *s;   /* short textual label for node n; NOTE(review): presumably used in type-error messages, confirm at callers */
+    switch (n->type) {
+        case Nexpr:     s = opstr(exprop(n)); break;   /* expressions: name the operator rather than the generic node kind */
+        default:        s = nodestr(n->type); break;   /* any other node: fall back to its node-kind name */
+    }
+    return s;
 }
 
 static void matchcstrs(Node *ctx, Type *a, Type *b)
@@ -201,7 +207,7 @@
 {
 }
 
-static void inferexpr(Node *n, Type *ret)
+static void inferexpr(Node *n, Type *ret, int *sawret)
 {
     Node **args;
     Sym *s;
@@ -215,7 +221,7 @@
     for (i = 0; i < nargs; i++)
         /* Nlit, Nvar, etc should not be inferred as exprs */
         if (args[i]->type == Nexpr)
-            inferexpr(args[i], ret);
+            inferexpr(args[i], ret, sawret);
     switch (exprop(n)) {
         /* all operands are same type */
         case Oadd:      /* @a + @a -> @a */
@@ -296,6 +302,8 @@
             die("casts not implemented");
             break;
         case Oret:      /* -> @a -> void */
+            if (sawret)
+                *sawret = 1;
             if (!ret)
                 fatal(n->line, "Not allowed to return value here");
             if (nargs)
@@ -327,10 +335,14 @@
 static void inferfunc(Node *n)
 {
     int i;
+    int sawret;   /* set nonzero by inferexpr() when it infers an Oret in this function's body */
 
+    sawret = 0;
     for (i = 0; i < n->func.nargs; i++)
-        infernode(n->func.args[i], NULL);
-    infernode(n->func.body, n->func.type->sub[0]);
+        infernode(n->func.args[i], NULL, NULL);   /* NULL ret: a return inside an argument is rejected */
+    infernode(n->func.body, n->func.type->sub[0], &sawret);   /* returns in the body are inferred against sub[0], the return type */
+    if (!sawret)   /* no return statement anywhere in the body... */
+        unify(n, type(n)->sub[0], mkty(-1, Tyvoid));   /* ...so the return type is void */
 }
 
 static void inferdecl(Node *n)
@@ -340,12 +352,12 @@
     t = decltype(n);
     settype(n, t);
     if (n->decl.init) {
-        inferexpr(n->decl.init, NULL);
+        inferexpr(n->decl.init, NULL, NULL);
         unify(n, type(n), type(n->decl.init));
     }
 }
 
-static void infernode(Node *n, Type *ret)
+static void infernode(Node *n, Type *ret, int *sawret)
 {
     int i;
 
@@ -353,7 +365,7 @@
         case Nfile:
             pushstab(n->file.globls);
             for (i = 0; i < n->file.nstmts; i++)
-                infernode(n->file.stmts[i], NULL);
+                infernode(n->file.stmts[i], NULL, sawret);
             popstab();
             break;
         case Ndecl:
@@ -363,22 +375,22 @@
             setsuper(n->block.scope, curstab());
             pushstab(n->block.scope);
             for (i = 0; i < n->block.nstmts; i++)
-                infernode(n->block.stmts[i], ret);
+                infernode(n->block.stmts[i], ret, sawret);
             popstab();
             break;
         case Nifstmt:
-            infernode(n->ifstmt.cond, NULL);
-            infernode(n->ifstmt.iftrue, ret);
-            infernode(n->ifstmt.iffalse, ret);
+            infernode(n->ifstmt.cond, NULL, sawret);
+            infernode(n->ifstmt.iftrue, ret, sawret);
+            infernode(n->ifstmt.iffalse, ret, sawret);
             break;
         case Nloopstmt:
-            infernode(n->loopstmt.cond, NULL);
-            infernode(n->loopstmt.init, ret);
-            infernode(n->loopstmt.step, ret);
-            infernode(n->loopstmt.body, ret);
+            infernode(n->loopstmt.cond, NULL, sawret);
+            infernode(n->loopstmt.init, ret, sawret);
+            infernode(n->loopstmt.step, ret, sawret);
+            infernode(n->loopstmt.body, ret, sawret);
             break;
         case Nexpr:
-            inferexpr(n, ret);
+            inferexpr(n, ret, sawret);
             break;
         case Nfunc:
             setsuper(n->func.scope, curstab());
@@ -492,7 +504,7 @@
     assert(file->type == Nfile);
 
     loaduses(file);
-    infernode(file, NULL);
+    infernode(file, NULL, NULL);
     infercompn(file);
     checkcast(file);
     typesub(file);