shithub: mc

Download patch

ref: 5196e4c16239a9f593c4efa14819b70721006740
parent: 9193bdc08bcba1b515fb60d25b700963a454b90a
author: Ori Bernstein <[email protected]>
date: Thu Aug 16 16:40:56 EDT 2012

We declare the variables in pattern matches now.

    We used to just error on unbound variables. Now we recognize
    that they should bind to the matched value, and declare a variable
    for each one. We don't yet assign it the matched value, though.

--- a/parse/infer.c
+++ b/parse/infer.c
@@ -14,7 +14,6 @@
 
 typedef struct Inferstate Inferstate;
 struct Inferstate {
-    int inpat;
     int ingeneric;
     int sawret;
     Type *ret;
@@ -45,6 +44,46 @@
 static Type *unify(Inferstate *st, Node *ctx, Type *a, Type *b);
 static Type *tf(Inferstate *st, Type *t);
 
+/* Describes node n for error messages: "name:type" for decls and
+ * expressions (type "unknown" if not yet inferred), else name or kind. */
+static char *ctxstr(Inferstate *st, Node *n)
+{
+    char *s;
+    char *t;
+    char *u;
+    char buf[512];
+
+    switch (n->type) {
+        default:
+            s = nodestr(n->type);
+            break;
+        case Ndecl:
+            u = declname(n);
+            t = tystr(tf(st, decltype(n)));
+            snprintf(buf, sizeof buf, "%s:%s", u, t);
+            s = strdup(buf);
+            free(t);
+            break;
+        case Nname:
+            s = namestr(n);
+            break;
+        case Nexpr:
+            if (exprop(n) == Ovar)
+                u = namestr(n->expr.args[0]);
+            else
+                u = opstr(exprop(n));
+            if (exprtype(n))
+                t = tystr(tf(st, exprtype(n)));
+            else
+                t = strdup("unknown");
+            snprintf(buf, sizeof buf, "%s:%s", u, t);
+            s = strdup(buf);
+            free(t);
+            break;
+    }
+    return s;
+}
+
 /* Set a scope's enclosing scope up correctly.
  * We don't do this in the parser for some reason. */
 static void setsuper(Stab *st, Stab *super)
@@ -269,44 +308,20 @@
     return tf(st, t);
 }
 
-/* Tries to give a good string describing the context
- * for the sake of error messages. */
-static char *ctxstr(Inferstate *st, Node *n)
+static Ucon *uconresolve(Inferstate *st, Node *n)
 {
-    char *s;
-    char *t;
-    char *u;
-    char buf[512];
+    Node *name;   /* constructor name: first argument of the Ocons node */
+    Ucon *uc;     /* what that name resolves to in the current scope */

-    switch (n->type) {
-        default:
-            s = nodestr(n->type);
-            break;
-        case Ndecl:
-            u = declname(n);
-            t = tystr(tf(st, decltype(n)));
-            snprintf(buf, 512, "%s:%s", u, t);
-            s = strdup(buf);
-            free(t);
-            break;
-        case Nname:
-            s = namestr(n);
-            break;
-        case Nexpr:
-            if (exprop(n) == Ovar)
-                u = namestr(n->expr.args[0]);
-            else
-                u = opstr(exprop(n));
-            if (exprtype(n))
-                t = tystr(tf(st, exprtype(n)));
-            else
-                t = strdup("unknown");
-            snprintf(buf, 512, "%s:%s", u, t);
-            s = strdup(buf);
-            free(t);
-            break;
-    }
-    return s;
+    name = n->expr.args[0];
+    uc = getucon(curstab(), name);
+    if (uc == NULL)
+        fatal(n->line, "No union constructor %s", ctxstr(st, name));
+    if (uc->etype == NULL && n->expr.nargs > 1)   /* nullary: no payload allowed */
+        fatal(n->line, "nullary union constructor %s passed arg ", ctxstr(st, name));
+    else if (uc->etype != NULL && n->expr.nargs != 2)   /* exactly one payload */
+        fatal(n->line, "union constructor %s needs arg ", ctxstr(st, name));
+    return uc;
 }
 
 /* Binds the type parameters present in the
@@ -647,6 +662,67 @@
     settype(st, n, mktyarray(n->line, type(st, n->lit.seqval[0]), mkintlit(n->line, n->lit.nelt)));
 }
 
+/* Infers types for match pattern n. Names not already declared become
+ * fresh decls, appended to *bind (count kept in *nbind). */
+static void inferpat(Inferstate *st, Node *n, Node ***bind, size_t *nbind)
+{
+    size_t i;
+    Ucon *uc;
+    Node **args;
+    Node *s;
+    Type *t;
+
+    args = n->expr.args;
+    for (i = 0; i < n->expr.nargs; i++)
+        if (args[i]->type == Nexpr)
+            inferpat(st, args[i], bind, nbind);
+    switch (exprop(n)) {
+        case Ocons:
+            uc = uconresolve(st, n);
+            if (uc->etype)
+                unify(st, n, uc->etype, type(st, args[1]));
+            settype(st, n, uc->utype);
+            break;
+        case Ovar:
+            s = getdcl(curstab(), args[0]);
+            if (s) {
+                if (s->decl.isgeneric)
+                    t = freshen(st, s->decl.type);
+                else if (s->decl.isconst)
+                    t = s->decl.type;
+                else
+                    fatal(n->line, "Can't match against variables in patterns near %s", ctxstr(st, n));
+            } else {
+                /* unbound name: declare a fresh variable of not-yet-known type */
+                t = mktyvar(n->line);
+                s = mkdecl(n->line, args[0], t);
+                lappend(bind, nbind, s);
+            }
+            settype(st, n, t);
+            n->expr.did = s->decl.did;
+            break;
+        case Olit:
+        case Omemb:
+            infernode(st, n, NULL, NULL);
+            break;
+        default:
+            die("Bad pattern to match against");
+            break;
+    }
+}
+
+/* Declares each binding in block n's scope and prepends it to n's stmts. */
+static void addbindings(Inferstate *st, Node *n, Node **bind, size_t nbind)
+{
+    size_t i;
+
+    /* binding order doesn't matter, so prepending (which reverses them) is fine */
+    for (i = 0; i < nbind; i++) {
+        putdcl(n->block.scope, bind[i]);
+        linsert(&n->block.stmts, &n->block.nstmts, 0, bind[i]);
+    }
+}
+
 static void inferexpr(Inferstate *st, Node *n, Type *ret, int *sawret)
 {
     Node **args;
@@ -793,14 +869,8 @@
             }
             break;
         case Ocons:
-            uc = getucon(curstab(), args[0]);
-            if (!uc)
-                fatal(n->line, "No union constructor %s", ctxstr(st, args[0]));
-            if (!uc->etype && n->expr.nargs > 1)
-                fatal(n->line, "nullary union constructor %s passed arg ", ctxstr(st, args[0]));
-            else if (uc->etype && n->expr.nargs != 2)
-                fatal(n->line, "union constructor %s needs arg ", ctxstr(st, args[0]));
-            else if (uc->etype)
+            uc = uconresolve(st, n);
+            if (uc->etype)
                 unify(st, n, uc->etype, type(st, args[1]));
             settype(st, n, uc->utype);
             break;
@@ -882,6 +952,8 @@
     size_t i;
     Node *d;
     Node *s;
+    size_t nbound;
+    Node **bound;
 
     if (!n)
         return;
@@ -949,9 +1021,10 @@
             }
             break;
         case Nmatch:
-            st->inpat++;
-            infernode(st, n->match.pat, NULL, sawret);
-            st->inpat--;
+            bound = NULL;
+            nbound = 0;
+            inferpat(st, n->match.pat, &bound, &nbound);
+            addbindings(st, n->match.block, bound, nbound);
             infernode(st, n->match.block, ret, sawret);
             break;
         case Nexpr: