shithub: mc

Download patch

ref: e119045a1d483c7dc7808c140a3659e41b9d160b
parent: b4e37364481ba805a39e8699ce9ae0e78b85d6c0
author: Ori Bernstein <[email protected]>
date: Fri Jun 7 23:00:35 EDT 2013

Fix type inference for struct literals.

--- a/parse/infer.c
+++ b/parse/infer.c
@@ -67,7 +67,7 @@
         case Ndecl:
             u = declname(n);
             t = tystr(tf(st, decltype(n)));
-            snprintf(buf, 512, "%s:%s", u, t);
+            snprintf(buf, sizeof buf, "%s:%s", u, t);
             s = strdup(buf);
             free(t);
             break;
@@ -83,10 +83,18 @@
                 t = tystr(tf(st, exprtype(n)));
             else
                 t = strdup("unknown");
-            snprintf(buf, 512, "%s:%s", u, t);
+            snprintf(buf, sizeof buf, "%s:%s", u, t);
             s = strdup(buf);
             free(t);
             break;
+        case Nidxinit:
+            t = ctxstr(st, n->idxinit.idx);
+            u = ctxstr(st, n->idxinit.init);
+            snprintf(buf, sizeof buf, "%s=%s", t, u);
+            s = strdup(buf);
+            free(t);
+            free(u);
+            break;
     }
     return s;
 }
@@ -430,12 +438,6 @@
     st->ingeneric--;
 }
 
-static void checkcast(Inferstate *st, Node *n)
-{
-    /* FIXME: actually verify the casts. Right now, it's ok to leave this
-     * unimplemented because bad casts get caught by the backend. */
-}
-
 /* Constrains a type to implement the required constraints. On
  * type variables, the constraint is added to the required
  * constraint list. Otherwise, the type is checked to see
@@ -745,7 +747,8 @@
 
     for (i = 0; i < n->lit.nelt; i++)
         infernode(st, n->lit.seqval[i], NULL, NULL);
-    die("Don't know what to do with struct lits yet, when it comes to inference");
+    settype(st, n, mktyvar(n->line));
+    lappend(&st->postcheck, &st->npostcheck, n);
 }
 
 static void inferarray(Inferstate *st, Node *n)
@@ -1201,33 +1204,34 @@
     return t;
 }
 
-static void infercompn(Inferstate *st, Node *file)
+static void checkcast(Inferstate *st, Node *n)
 {
-    size_t i, j;
+    /* FIXME: actually verify the casts. Right now, it's ok to leave this
+     * unimplemented because bad casts get caught by the backend. */
+}
+
+static void infercompn(Inferstate *st, Node *n)
+{
     Node *aggr;
     Node *memb;
-    Node *n;
     Node **nl;
     Type *t;
+    size_t i;
     int found;
 
-    for (i = 0; i < st->npostcheck; i++) {
-        n = st->postcheck[i];
-        if (exprop(n) != Omemb)
-            continue;
-        aggr = st->postcheck[i]->expr.args[0];
-        memb = st->postcheck[i]->expr.args[1];
+    aggr = n->expr.args[0];
+    memb = n->expr.args[1];
 
-        found = 0;
-        t = tybase(tf(st, type(st, aggr)));
-        /* all array-like types have a fake "len" member that we emulate */
-        if (t->type == Tyslice || t->type == Tyarray) {
-            if (!strcmp(namestr(memb), "len")) {
-                constrain(st, n, type(st, n), cstrtab[Tcnum]);
-                constrain(st, n, type(st, n), cstrtab[Tcint]);
-                constrain(st, n, type(st, n), cstrtab[Tctest]);
-                found = 1;
-            }
+    found = 0;
+    t = tybase(tf(st, type(st, aggr)));
+    /* all array-like types have a fake "len" member that we emulate */
+    if (t->type == Tyslice || t->type == Tyarray) {
+        if (!strcmp(namestr(memb), "len")) {
+            constrain(st, n, type(st, n), cstrtab[Tcnum]);
+            constrain(st, n, type(st, n), cstrtab[Tcint]);
+            constrain(st, n, type(st, n), cstrtab[Tctest]);
+            found = 1;
+        }
         /* otherwise, we search aggregate types for the member, and unify
          * the expression with the member type; ie:
          *
@@ -1235,24 +1239,72 @@
          *     ---------------------------------------
          *               x.y : membtype
          */
-        } else {
-            if (t->type == Typtr)
-                t = tybase(tf(st, t->sub[0]));
-            nl = t->sdecls;
-            for (j = 0; j < t->nmemb; j++) {
-                if (!strcmp(namestr(memb), declname(nl[j]))) {
-                    unify(st, n, type(st, n), decltype(nl[j]));
-                    found = 1;
-                    break;
-                }
+    } else {
+        if (t->type == Typtr)
+            t = tybase(tf(st, t->sub[0]));
+        nl = t->sdecls;
+        for (i = 0; i < t->nmemb; i++) {
+            if (!strcmp(namestr(memb), declname(nl[i]))) {
+                unify(st, n, type(st, n), decltype(nl[i]));
+                found = 1;
+                break;
+            }
+        }
+    }
+    if (!found)
+        fatal(aggr->line, "Type %s has no member \"%s\" near %s",
+              tystr(type(st, aggr)), ctxstr(st, memb), ctxstr(st, aggr));
+}
+
+static void checkstruct(Inferstate *st, Node *n)
+{
+    Type *t, *et;
+    Node *elt, *name, *val;
+    size_t i, j;
+    /* Struct literal check: the literal's type must be a struct, and each `name = value` element must name one of its members with a matching type. */
+    t = tybase(tf(st, n->lit.type));
+    if (t->type != Tystruct)
+        fatal(n->line, "Type %s for struct literal is not struct near %s", tystr(t), ctxstr(st, n));
+
+    for (i = 0; i < n->lit.nelt; i++) {
+        elt = n->lit.seqval[i];
+        name = elt->idxinit.idx;
+        val = elt->idxinit.init;
+        /* find the member this element names; members are indexed by j, elements by i */
+        et = NULL;
+        for (j = 0; j < t->nmemb; j++) {
+            if (!strcmp(namestr(t->sdecls[j]->decl.name), namestr(name))) {
+                et = type(st, t->sdecls[j]);
+                break;
             }
         }
-        if (!found)
-            fatal(aggr->line, "Type %s has no member \"%s\" near %s",
-                  tystr(type(st, aggr)), ctxstr(st, memb), ctxstr(st, aggr));
+
+        if (!et)
+            fatal(n->line, "Could not find member %s in struct %s, near %s",
+                  namestr(name), tystr(t), ctxstr(st, n));
+        /* the initializer's type must unify with the member's declared type */
+        unify(st, elt, et, type(st, val));
     }
 }
 
+static void postcheck(Inferstate *st, Node *file)
+{
+    size_t k;
+    Node *n;
+    /* Run each deferred check queued in st->postcheck during inference; `file` is currently unused. */
+    for (k = 0; k < st->npostcheck; k++) {
+        n = st->postcheck[k];
+        if (n->type == Nlit && n->lit.littype == Lstruct)
+            checkstruct(st, n);
+        else if (n->type == Nexpr && exprop(n) == Omemb)
+            infercompn(st, n);
+        else if (n->type == Nexpr && exprop(n) == Ocast)
+            checkcast(st, n);
+        else
+            die("Thing we shouldn't be checking in postcheck\n");
+    }
+}
+
 /* After inference, replace all
  * types in symbol tables with
  * the final computed types */
@@ -1405,8 +1457,7 @@
     loaduses(file);
     mergeexports(&st, file);
     infernode(&st, file, NULL, NULL);
-    infercompn(&st, file);
-    checkcast(&st, file);
+    postcheck(&st, file);
     typesub(&st, file);
     specialize(&st, file);
 }
--- a/test/structlit.myr
+++ b/test/structlit.myr
@@ -3,7 +3,7 @@
 type t = struct
 	a	: int
 	b	: char
-	c	: char[:]
+	c	: byte[:]
 ;;
 
 const main = {