shithub: mc

Download patch

ref: c39f89dd821b17ef095a36b06f8c77d4a109ce64
parent: 0a0f0610bf7164c4f56d321f8750c82b2001498c
author: Ori Bernstein <[email protected]>
date: Fri Jan 6 18:47:33 EST 2012

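Create function types: function literals may now declare an optional return type, mkfunc() builds a function type for each literal, and type inference checks `return` statements against the enclosing function's return type.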

--- a/parse/gram.y
+++ b/parse/gram.y
@@ -463,7 +463,9 @@
         ;
 
 funclit : TObrace params TEndln blockbody TCbrace
-            {$$ = mkfunc($1->line, $2.nl, $2.nn, $4);}
+            {$$ = mkfunc($1->line, $2.nl, $2.nn, NULL, $4);}
+        | TObrace params TRet type TEndln blockbody TCbrace
+            {$$ = mkfunc($1->line, $2.nl, $2.nn, $4, $6);}
         ;
 
 params  : declcore
--- a/parse/infer.c
+++ b/parse/infer.c
@@ -12,6 +12,7 @@
 
 #include "parse.h"
 
+static void infernode(Node *n, Type *ret);
 
 /* find the most accurate type mapping */
 static Type *tf(Type *t)
@@ -87,7 +88,12 @@
         case Lint:      return tylike(mktyvar(n->line), Tyint);                 break;
         case Lflt:      return tylike(mktyvar(n->line), Tyfloat32);             break;
         case Lstr:      return mktyslice(n->line, mkty(n->line, Tychar));       break;
-        case Lfunc:     return NULL; break;
+        case Lfunc:
+            /* we figure out the return type to infer for the body in
+             * infernode() */
+            infernode(n, NULL);
+            return n->lit.fnval->func.type;
+            break;
         case Larray:    return NULL; break;
     };
     return NULL;
@@ -258,7 +264,13 @@
             die("casts not implemented");
             break;
         case Oret:      /* -> @a -> void */
-            settype(n, mkty(-1, Tyvoid));
+            if (!ret)
+                fatal(n->line, "Not allowed to return value here");
+            if (nargs)
+                t = unify(n, type(args[0]), ret);
+            else
+                t =  unify(n, mkty(-1, Tyvoid), ret);
+            settype(n, t);
             break;
         case Ogoto:     /* goto void* -> void */
             settype(n, mkty(-1, Tyvoid));
@@ -282,6 +294,11 @@
 
 static void inferfunc(Node *n)
 {
+    int i;
+
+    for (i = 0; i < n->func.nargs; i++)
+        infernode(n->func.args[i], NULL);
+    infernode(n->func.body, n->func.type->sub[0]);
 }
 
 static void inferdecl(Node *n)
@@ -296,7 +313,7 @@
     }
 }
 
-static void infernode(Node *n)
+static void infernode(Node *n, Type *ret)
 {
     int i;
 
@@ -304,7 +321,7 @@
         case Nfile:
             pushstab(n->file.globls);
             for (i = 0; i < n->file.nstmts; i++)
-                infernode(n->file.stmts[i]);
+                infernode(n->file.stmts[i], NULL);
             popstab();
             break;
         case Ndecl:
@@ -312,23 +329,25 @@
             break;
         case Nblock:
             for (i = 0; i < n->block.nstmts; i++)
-                infernode(n->block.stmts[i]);
+                infernode(n->block.stmts[i], ret);
             break;
         case Nifstmt:
-            infernode(n->ifstmt.cond);
-            infernode(n->ifstmt.iftrue);
-            infernode(n->ifstmt.iffalse);
+            infernode(n->ifstmt.cond, NULL);
+            infernode(n->ifstmt.iftrue, ret);
+            infernode(n->ifstmt.iffalse, ret);
             break;
         case Nloopstmt:
-            infernode(n->loopstmt.cond);
-            infernode(n->loopstmt.init);
-            infernode(n->loopstmt.step);
-            infernode(n->loopstmt.body);
+            infernode(n->loopstmt.cond, NULL);
+            infernode(n->loopstmt.init, ret);
+            infernode(n->loopstmt.step, ret);
+            infernode(n->loopstmt.body, ret);
             break;
         case Nexpr:
             inferexpr(n, NULL);
         case Nfunc:
+            pushstab(n->func.scope);
             inferfunc(n);
+            popstab();
         case Nname:
         case Nlit:
         case Nuse:
@@ -418,7 +437,7 @@
     assert(file->type == Nfile);
 
     loaduses(file);
-    infernode(file);
+    infernode(file, NULL);
     infercompn(file);
     checkcast(file);
     typesub(file);
--- a/parse/node.c
+++ b/parse/node.c
@@ -94,7 +94,7 @@
     return n;
 }
 
-Node *mkfunc(int line, Node **args, size_t nargs, Node *body)
+Node *mkfunc(int line, Node **args, size_t nargs, Type *ret, Node *body)
 {
     Node *n;
     Node *f;
@@ -104,6 +104,7 @@
     f->func.nargs = nargs;
     f->func.body = body;
     f->func.scope = mkstab(curstab());
+    f->func.type = mktyfunc(line, args, nargs, ret);
 
     n = mknode(line, Nlit);
     n->lit.littype = Lfunc;
--- a/parse/parse.h
+++ b/parse/parse.h
@@ -192,6 +192,7 @@
 
         struct {
             Stab *scope;
+            Type *type;
             size_t nargs;
             Node **args;
             Node *body;
@@ -308,7 +309,7 @@
 Node *mkchar(int line, uint32_t val);
 Node *mkstr(int line, char *s);
 Node *mkfloat(int line, double flt);
-Node *mkfunc(int line, Node **args, size_t nargs, Node *body);
+Node *mkfunc(int line, Node **args, size_t nargs, Type *ret, Node *body);
 Node *mkarray(int line, Node **vals);
 Node *mkname(int line, char *name);
 Node *mkdecl(int line, Sym *sym);
--- a/parse/type.c
+++ b/parse/type.c
@@ -237,7 +237,6 @@
     char *p;
     char *end;
     char *sep;
-    int first;
     int i;
 
     if (!t->cstrs || !bscount(t->cstrs))
@@ -245,7 +244,6 @@
 
     p = buf;
     end = p + len;
-    first = 1;
 
     p += snprintf(p, end - p, " :: ");
     sep = "";