shithub: hugo

Download patch

ref: 002a5b675691a06cc593e2728bb00731cafa964f
parent: 6e15f652bdef4f7e9d9c9b4c0cd764707aa7e48e
author: Tatsushi Demachi <[email protected]>
date: Sat Aug 16 09:12:34 EDT 2014

Add 'where' template function

--- a/hugolib/template.go
+++ b/hugolib/template.go
@@ -109,6 +109,71 @@
 	return seqv.Slice(0, limit).Interface(), nil
 }
 
+// Where returns a new slice holding the elements of seq whose value under key
+// equals match. seq is an array or slice (pointers and interfaces are followed)
+// of maps, structs, or pointers to either. key must have the map's exact key
+// type or, for structs, be a field name string. Only signed integer and string
+// matches of identical type are compared; other elements are skipped. A nil or
+// non-sequence seq yields an error; a nil key or match yields an empty slice.
+func Where(seq, key, match interface{}) (interface{}, error) {
+	seqv := reflect.ValueOf(seq)
+	kv := reflect.ValueOf(key)
+	mv := reflect.ValueOf(match)
+
+	// Same pointer/interface walk as indirect() in text/template/exec.go.
+	for ; seqv.Kind() == reflect.Ptr || seqv.Kind() == reflect.Interface; seqv = seqv.Elem() {
+		if seqv.IsNil() {
+			return nil, errors.New("can't iterate over a nil value")
+		}
+		if seqv.Kind() == reflect.Interface && seqv.NumMethod() > 0 {
+			break
+		}
+	}
+	if !seqv.IsValid() { // untyped nil: calling Type() below would panic
+		return nil, errors.New("can't iterate over a nil value")
+	}
+	if seqv.Kind() != reflect.Array && seqv.Kind() != reflect.Slice {
+		return nil, errors.New("can't iterate over " + reflect.ValueOf(seq).Type().String())
+	}
+
+	// lookup fetches the value stored under key in a map or struct element.
+	lookup := func(ev reflect.Value) (found reflect.Value) {
+		switch {
+		case !kv.IsValid(): // a nil key can never name a field or map entry
+		case ev.Kind() == reflect.Map && kv.Type() == ev.Type().Key():
+			found = ev.MapIndex(kv)
+		case ev.Kind() == reflect.Struct && kv.Kind() == reflect.String:
+			found = ev.FieldByName(kv.String())
+		}
+		return found
+	}
+
+	// Build from the element type: MakeSlice panics when given an array type.
+	r := reflect.MakeSlice(reflect.SliceOf(seqv.Type().Elem()), 0, 0)
+	for i := 0; i < seqv.Len(); i++ {
+		vv := seqv.Index(i)
+		ev := vv
+		if ev.Kind() == reflect.Ptr && !ev.IsNil() {
+			ev = ev.Elem()
+		}
+		vvv := lookup(ev)
+		if !vvv.IsValid() || !mv.IsValid() || mv.Type() != vvv.Type() {
+			continue
+		}
+		switch mv.Kind() {
+		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+			if mv.Int() == vvv.Int() {
+				r = reflect.Append(r, vv)
+			}
+		case reflect.String:
+			if mv.String() == vvv.String() {
+				r = reflect.Append(r, vv)
+			}
+		}
+	}
+	return r.Interface(), nil
+}
+
 func IsSet(a interface{}, key interface{}) bool {
 	av := reflect.ValueOf(a)
 	kv := reflect.ValueOf(key)
@@ -211,6 +276,7 @@
 		"echoParam":   ReturnWhenSet,
 		"safeHtml":    SafeHtml,
 		"first":       First,
+		"where":       Where,
 		"highlight":   Highlight,
 		"add":         func(a, b int) int { return a + b },
 		"sub":         func(a, b int) int { return a - b },
--- a/hugolib/template_test.go
+++ b/hugolib/template_test.go
@@ -55,3 +55,30 @@
 		}
 	}
 }
+
+// TestWhere runs Where over slices of maps, structs, and pointers to either,
+// and checks with reflect.DeepEqual that only the elements whose value under
+// key equals match are returned.
+func TestWhere(t *testing.T) {
+	type X struct {
+		A, B string
+	}
+	cases := []struct {
+		sequence, key, match, expect interface{}
+	}{
+		{[]map[int]string{{1: "a", 2: "m"}, {1: "c", 2: "d"}, {1: "e", 3: "m"}}, 2, "m", []map[int]string{{1: "a", 2: "m"}}},
+		{[]map[string]int{{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "x": 4}}, "b", 4, []map[string]int{{"a": 3, "b": 4}}},
+		{[]X{{"a", "b"}, {"c", "d"}, {"e", "f"}}, "B", "f", []X{{"e", "f"}}},
+		{[]*map[int]string{{1: "a", 2: "m"}, {1: "c", 2: "d"}, {1: "e", 3: "m"}}, 2, "m", []*map[int]string{{1: "a", 2: "m"}}},
+		{[]*X{{"a", "b"}, {"c", "d"}, {"e", "f"}}, "B", "f", []*X{{"e", "f"}}},
+	}
+	for i, tc := range cases {
+		got, err := Where(tc.sequence, tc.key, tc.match)
+		switch {
+		case err != nil:
+			t.Errorf("[%d] failed: %s", i, err)
+		case !reflect.DeepEqual(got, tc.expect):
+			t.Errorf("[%d] Where clause matching %v with %v, got %v but expected %v", i, tc.key, tc.match, got, tc.expect)
+		}
+	}
+}