ref: e119045a1d483c7dc7808c140a3659e41b9d160b
parent: b4e37364481ba805a39e8699ce9ae0e78b85d6c0
author: Ori Bernstein <[email protected]>
date: Fri Jun 7 23:00:35 EDT 2013
Fix type inference for struct literals.
--- a/parse/infer.c
+++ b/parse/infer.c
@@ -67,7 +67,7 @@
case Ndecl:
u = declname(n);
t = tystr(tf(st, decltype(n)));
- snprintf(buf, 512, "%s:%s", u, t);
+ snprintf(buf, sizeof buf, "%s:%s", u, t);
s = strdup(buf);
free(t);
break;
@@ -83,10 +83,18 @@
t = tystr(tf(st, exprtype(n)));
else
t = strdup("unknown");
- snprintf(buf, 512, "%s:%s", u, t);
+ snprintf(buf, sizeof buf, "%s:%s", u, t);
s = strdup(buf);
free(t);
break;
+ case Nidxinit:
+ t = ctxstr(st, n->idxinit.idx);
+ u = ctxstr(st, n->idxinit.init);
+ snprintf(buf, sizeof buf, "%s=%s", t, u);
+ s = strdup(buf);
+ free(t);
+ free(u);
+ break;
}
return s;
}
@@ -430,12 +438,6 @@
st->ingeneric--;
}
-static void checkcast(Inferstate *st, Node *n)
-{
- /* FIXME: actually verify the casts. Right now, it's ok to leave this
- * unimplemented because bad casts get caught by the backend. */
-}
-
/* Constrains a type to implement the required constraints. On
* type variables, the constraint is added to the required
* constraint list. Otherwise, the type is checked to see
@@ -745,7 +747,8 @@
for (i = 0; i < n->lit.nelt; i++)
infernode(st, n->lit.seqval[i], NULL, NULL);
- die("Don't know what to do with struct lits yet, when it comes to inference");
+ settype(st, n, mktyvar(n->line));
+ lappend(&st->postcheck, &st->npostcheck, n);
}
static void inferarray(Inferstate *st, Node *n)
@@ -1201,33 +1204,34 @@
return t;
}
-static void infercompn(Inferstate *st, Node *file)
+static void checkcast(Inferstate *st, Node *n)
{
- size_t i, j;
+ /* FIXME: casts are not verified yet; st and n are unused. Leaving this
+ * unimplemented is tolerable for now because bad casts get caught by the backend. */
+}
+
+static void infercompn(Inferstate *st, Node *n)
+{
Node *aggr;
Node *memb;
- Node *n;
Node **nl;
Type *t;
+ size_t i;
int found;
- for (i = 0; i < st->npostcheck; i++) {
- n = st->postcheck[i];
- if (exprop(n) != Omemb)
- continue;
- aggr = st->postcheck[i]->expr.args[0];
- memb = st->postcheck[i]->expr.args[1];
+ aggr = n->expr.args[0]; /* the aggregate: x in x.y */
+ memb = n->expr.args[1]; /* the member name: y in x.y */
- found = 0;
- t = tybase(tf(st, type(st, aggr)));
- /* all array-like types have a fake "len" member that we emulate */
- if (t->type == Tyslice || t->type == Tyarray) {
- if (!strcmp(namestr(memb), "len")) {
- constrain(st, n, type(st, n), cstrtab[Tcnum]);
- constrain(st, n, type(st, n), cstrtab[Tcint]);
- constrain(st, n, type(st, n), cstrtab[Tctest]);
- found = 1;
- }
+ found = 0;
+ t = tybase(tf(st, type(st, aggr)));
+ /* slices and arrays have no declared members; emulate "len" by constraining x.len's type instead of unifying */
+ if (t->type == Tyslice || t->type == Tyarray) {
+ if (!strcmp(namestr(memb), "len")) {
+ constrain(st, n, type(st, n), cstrtab[Tcnum]);
+ constrain(st, n, type(st, n), cstrtab[Tcint]);
+ constrain(st, n, type(st, n), cstrtab[Tctest]);
+ found = 1;
+ }
/* otherwise, we search aggregate types for the member, and unify
* the expression with the member type; ie:
*
@@ -1235,24 +1239,72 @@
* ---------------------------------------
* x.y : membtype
*/
- } else {
- if (t->type == Typtr)
- t = tybase(tf(st, t->sub[0]));
- nl = t->sdecls;
- for (j = 0; j < t->nmemb; j++) {
- if (!strcmp(namestr(memb), declname(nl[j]))) {
- unify(st, n, type(st, n), decltype(nl[j]));
- found = 1;
- break;
- }
+ } else {
+ if (t->type == Typtr)
+ t = tybase(tf(st, t->sub[0]));
+ nl = t->sdecls;
+ for (i = 0; i < t->nmemb; i++) {
+ if (!strcmp(namestr(memb), declname(nl[i]))) {
+ unify(st, n, type(st, n), decltype(nl[i]));
+ found = 1;
+ break;
+ }
+ }
+ }
+ if (!found)
+ fatal(aggr->line, "Type %s has no member \"%s\" near %s",
+ tystr(type(st, aggr)), ctxstr(st, memb), ctxstr(st, aggr));
+}
+
+static void checkstruct(Inferstate *st, Node *n)
+{
+ Type *t, *et;
+ Node *elt, *name, *val;
+ size_t i, j;
+ /* Each element is an idxinit (`name = value`); unify every value with the type of the struct member it names. */
+ t = tybase(tf(st, n->lit.type));
+ if (t->type != Tystruct)
+ fatal(n->line, "Type %s for struct literal is not struct near %s", tystr(t), ctxstr(st, n));
+
+ for (i = 0; i < n->lit.nelt; i++) {
+ elt = n->lit.seqval[i];
+ name = elt->idxinit.idx;
+ val = elt->idxinit.init;
+ /* Look the member up by name: index sdecls with j, since initializer order need not match declaration order. */
+ et = NULL;
+ for (j = 0; j < t->nmemb; j++) {
+ if (!strcmp(namestr(t->sdecls[j]->decl.name), namestr(name))) {
+ et = type(st, t->sdecls[j]);
+ break;
}
}
- if (!found)
- fatal(aggr->line, "Type %s has no member \"%s\" near %s",
- tystr(type(st, aggr)), ctxstr(st, memb), ctxstr(st, aggr));
+
+ if (!et)
+ fatal(n->line, "Could not find member %s in struct %s, near %s",
+ namestr(name), tystr(t), ctxstr(st, n));
+
+ unify(st, elt, et, type(st, val));
}
}
+static void postcheck(Inferstate *st, Node *file)
+{
+ Node *chk;
+ size_t k;
+ /* Dispatch each node queued on st->postcheck during inference to its deferred check. */
+ for (k = 0; k < st->npostcheck; k++) {
+ chk = st->postcheck[k];
+ if (chk->type == Nlit && chk->lit.littype == Lstruct)
+ checkstruct(st, chk);
+ else if (chk->type == Nexpr && exprop(chk) == Ocast)
+ checkcast(st, chk);
+ else if (chk->type == Nexpr && exprop(chk) == Omemb)
+ infercompn(st, chk);
+ else
+ die("Thing we shouldn't be checking in postcheck\n");
+ }
+}
+
/* After inference, replace all
* types in symbol tables with
* the final computed types */
@@ -1405,8 +1457,7 @@
loaduses(file);
mergeexports(&st, file);
infernode(&st, file, NULL, NULL);
- infercompn(&st, file);
- checkcast(&st, file);
+ postcheck(&st, file);
typesub(&st, file);
specialize(&st, file);
}
--- a/test/structlit.myr
+++ b/test/structlit.myr
@@ -3,7 +3,7 @@
type t = struct
a : int
b : char
- c : char[:]
+ c : byte[:]
;;
const main = {