ref: c39f89dd821b17ef095a36b06f8c77d4a109ce64
parent: 0a0f0610bf7164c4f56d321f8750c82b2001498c
author: Ori Bernstein <[email protected]>
date: Fri Jan 6 18:47:33 EST 2012
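Create function types: function literals now carry a type (with an optional declared return type), and return statements are unified against that return type during inference.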
--- a/parse/gram.y
+++ b/parse/gram.y
@@ -463,7 +463,9 @@
;
funclit : TObrace params TEndln blockbody TCbrace
- {$$ = mkfunc($1->line, $2.nl, $2.nn, $4);}
+ {$$ = mkfunc($1->line, $2.nl, $2.nn, NULL, $4);}
+ | TObrace params TRet type TEndln blockbody TCbrace
+ {$$ = mkfunc($1->line, $2.nl, $2.nn, $4, $6);}
;
params : declcore
--- a/parse/infer.c
+++ b/parse/infer.c
@@ -12,6 +12,7 @@
#include "parse.h"
+static void infernode(Node *n, Type *ret);
/* find the most accurate type mapping */
static Type *tf(Type *t)
@@ -87,7 +88,12 @@
case Lint: return tylike(mktyvar(n->line), Tyint); break;
case Lflt: return tylike(mktyvar(n->line), Tyfloat32); break;
case Lstr: return mktyslice(n->line, mkty(n->line, Tychar)); break;
- case Lfunc: return NULL; break;
+ case Lfunc:
+ /* we figure out the return type to infer for the body in
+ * infernode() */
+ infernode(n, NULL);
+ return n->lit.fnval->func.type;
+ break;
case Larray: return NULL; break;
};
return NULL;
@@ -258,7 +264,13 @@
die("casts not implemented");
break;
case Oret: /* -> @a -> void */
- settype(n, mkty(-1, Tyvoid));
+ if (!ret)
+ fatal(n->line, "Not allowed to return value here");
+ if (nargs)
+ t = unify(n, type(args[0]), ret);
+ else
+ t = unify(n, mkty(-1, Tyvoid), ret);
+ settype(n, t);
break;
case Ogoto: /* goto void* -> void */
settype(n, mkty(-1, Tyvoid));
@@ -282,6 +294,11 @@
static void inferfunc(Node *n)
{
+ int i;
+
+ for (i = 0; i < n->func.nargs; i++)
+ infernode(n->func.args[i], NULL);
+ infernode(n->func.body, n->func.type->sub[0]);
}
static void inferdecl(Node *n)
@@ -296,7 +313,7 @@
}
}
-static void infernode(Node *n)
+static void infernode(Node *n, Type *ret)
{
int i;
@@ -304,7 +321,7 @@
case Nfile:
pushstab(n->file.globls);
for (i = 0; i < n->file.nstmts; i++)
- infernode(n->file.stmts[i]);
+ infernode(n->file.stmts[i], NULL);
popstab();
break;
case Ndecl:
@@ -312,23 +329,25 @@
break;
case Nblock:
for (i = 0; i < n->block.nstmts; i++)
- infernode(n->block.stmts[i]);
+ infernode(n->block.stmts[i], ret);
break;
case Nifstmt:
- infernode(n->ifstmt.cond);
- infernode(n->ifstmt.iftrue);
- infernode(n->ifstmt.iffalse);
+ infernode(n->ifstmt.cond, NULL);
+ infernode(n->ifstmt.iftrue, ret);
+ infernode(n->ifstmt.iffalse, ret);
break;
case Nloopstmt:
- infernode(n->loopstmt.cond);
- infernode(n->loopstmt.init);
- infernode(n->loopstmt.step);
- infernode(n->loopstmt.body);
+ infernode(n->loopstmt.cond, NULL);
+ infernode(n->loopstmt.init, ret);
+ infernode(n->loopstmt.step, ret);
+ infernode(n->loopstmt.body, ret);
break;
case Nexpr:
inferexpr(n, NULL);
case Nfunc:
+ pushstab(n->func.scope);
inferfunc(n);
+ popstab();
case Nname:
case Nlit:
case Nuse:
@@ -418,7 +437,7 @@
assert(file->type == Nfile);
loaduses(file);
- infernode(file);
+ infernode(file, NULL);
infercompn(file);
checkcast(file);
typesub(file);
--- a/parse/node.c
+++ b/parse/node.c
@@ -94,7 +94,7 @@
return n;
}
-Node *mkfunc(int line, Node **args, size_t nargs, Node *body)
+Node *mkfunc(int line, Node **args, size_t nargs, Type *ret, Node *body)
{
Node *n;
Node *f;
@@ -104,6 +104,7 @@
f->func.nargs = nargs;
f->func.body = body;
f->func.scope = mkstab(curstab());
+ f->func.type = mktyfunc(line, args, nargs, ret);
n = mknode(line, Nlit);
n->lit.littype = Lfunc;
--- a/parse/parse.h
+++ b/parse/parse.h
@@ -192,6 +192,7 @@
struct {
Stab *scope;
+ Type *type;
size_t nargs;
Node **args;
Node *body;
@@ -308,7 +309,7 @@
Node *mkchar(int line, uint32_t val);
Node *mkstr(int line, char *s);
Node *mkfloat(int line, double flt);
-Node *mkfunc(int line, Node **args, size_t nargs, Node *body);
+Node *mkfunc(int line, Node **args, size_t nargs, Type *ret, Node *body);
Node *mkarray(int line, Node **vals);
Node *mkname(int line, char *name);
Node *mkdecl(int line, Sym *sym);
--- a/parse/type.c
+++ b/parse/type.c
@@ -237,7 +237,6 @@
char *p;
char *end;
char *sep;
- int first;
int i;
if (!t->cstrs || !bscount(t->cstrs))
@@ -245,7 +244,6 @@
p = buf;
end = p + len;
- first = 1;
p += snprintf(p, end - p, " :: ");
sep = "";