shithub: hugo

Download patch

ref: 3908489ccd328ad7dcbab3ff0bb5c53ee3da19e9
parent: 640121423ca830b6062ef7bc4edab8dbf4b5411d
author: Cameron Moore <[email protected]>
date: Mon Sep 12 00:03:53 EDT 2016

tpl: Extend where to iterate over maps

Fixes #2028

--- a/tpl/template_funcs.go
+++ b/tpl/template_funcs.go
@@ -38,16 +38,13 @@
 	"time"
 	"unicode/utf8"
 
-	"github.com/spf13/viper"
-
-	"github.com/spf13/afero"
-	"github.com/spf13/hugo/hugofs"
-
 	"github.com/bep/inflect"
-
+	"github.com/spf13/afero"
 	"github.com/spf13/cast"
 	"github.com/spf13/hugo/helpers"
+	"github.com/spf13/hugo/hugofs"
 	jww "github.com/spf13/jwalterweatherman"
+	"github.com/spf13/viper"
 )
 
 var funcMap template.FuncMap
@@ -797,13 +794,9 @@
 	return false, nil
 }
 
-// where returns a filtered subset of a given data type.
-func where(seq, key interface{}, args ...interface{}) (r interface{}, err error) {
-	seqv := reflect.ValueOf(seq)
-	kv := reflect.ValueOf(key)
-
-	var mv reflect.Value
-	var op string
+// parseWhereArgs parses the end arguments to the where function.  Return a
+// match value and an operator, if one is defined.
+func parseWhereArgs(args ...interface{}) (mv reflect.Value, op string, err error) {
 	switch len(args) {
 	case 1:
 		mv = reflect.ValueOf(args[0])
@@ -810,20 +803,107 @@
 	case 2:
 		var ok bool
 		if op, ok = args[0].(string); !ok {
-			return nil, errors.New("operator argument must be string type")
+			err = errors.New("operator argument must be string type")
+			return
 		}
 		op = strings.TrimSpace(strings.ToLower(op))
 		mv = reflect.ValueOf(args[1])
 	default:
-		return nil, errors.New("can't evaluate the array by no match argument or more than or equal to two arguments")
+		err = errors.New("can't evaluate the array by no match argument or more than or equal to two arguments")
 	}
+	return
+}
 
-	seqv, isNil := indirect(seqv)
+// checkWhereArray handles the where-matching logic when the seqv value is an
+// Array or Slice.
+func checkWhereArray(seqv, kv, mv reflect.Value, path []string, op string) (interface{}, error) {
+	rv := reflect.MakeSlice(seqv.Type(), 0, 0)
+	for i := 0; i < seqv.Len(); i++ {
+		var vvv reflect.Value
+		rvv := seqv.Index(i)
+		if kv.Kind() == reflect.String {
+			vvv = rvv
+			for _, elemName := range path {
+				var err error
+				vvv, err = evaluateSubElem(vvv, elemName)
+				if err != nil {
+					return nil, err
+				}
+			}
+		} else {
+			vv, _ := indirect(rvv)
+			if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) {
+				vvv = vv.MapIndex(kv)
+			}
+		}
+
+		if ok, err := checkCondition(vvv, mv, op); ok {
+			rv = reflect.Append(rv, rvv)
+		} else if err != nil {
+			return nil, err
+		}
+	}
+	return rv.Interface(), nil
+}
+
+// checkWhereMap handles the where-matching logic when the seqv value is a Map.
+func checkWhereMap(seqv, kv, mv reflect.Value, path []string, op string) (interface{}, error) {
+	rv := reflect.MakeMap(seqv.Type())
+	keys := seqv.MapKeys()
+	for _, k := range keys {
+		elemv := seqv.MapIndex(k)
+		switch elemv.Kind() {
+		case reflect.Array, reflect.Slice:
+			r, err := checkWhereArray(elemv, kv, mv, path, op)
+			if err != nil {
+				return nil, err
+			}
+
+			switch rr := reflect.ValueOf(r); rr.Kind() {
+			case reflect.Slice:
+				if rr.Len() > 0 {
+					rv.SetMapIndex(k, elemv)
+				}
+			}
+		case reflect.Interface:
+			elemvv, isNil := indirect(elemv)
+			if isNil {
+				continue
+			}
+
+			switch elemvv.Kind() {
+			case reflect.Array, reflect.Slice:
+				r, err := checkWhereArray(elemvv, kv, mv, path, op)
+				if err != nil {
+					return nil, err
+				}
+
+				switch rr := reflect.ValueOf(r); rr.Kind() {
+				case reflect.Slice:
+					if rr.Len() > 0 {
+						rv.SetMapIndex(k, elemv)
+					}
+				}
+			}
+		}
+	}
+	return rv.Interface(), nil
+}
+
+// where returns a filtered subset of a given data type.
+func where(seq, key interface{}, args ...interface{}) (interface{}, error) {
+	seqv, isNil := indirect(reflect.ValueOf(seq))
 	if isNil {
 		return nil, errors.New("can't iterate over a nil value of type " + reflect.ValueOf(seq).Type().String())
 	}
 
+	mv, op, err := parseWhereArgs(args...)
+	if err != nil {
+		return nil, err
+	}
+
 	var path []string
+	kv := reflect.ValueOf(key)
 	if kv.Kind() == reflect.String {
 		path = strings.Split(strings.Trim(kv.String(), "."), ".")
 	}
@@ -830,31 +910,9 @@
 
 	switch seqv.Kind() {
 	case reflect.Array, reflect.Slice:
-		rv := reflect.MakeSlice(seqv.Type(), 0, 0)
-		for i := 0; i < seqv.Len(); i++ {
-			var vvv reflect.Value
-			rvv := seqv.Index(i)
-			if kv.Kind() == reflect.String {
-				vvv = rvv
-				for _, elemName := range path {
-					vvv, err = evaluateSubElem(vvv, elemName)
-					if err != nil {
-						return nil, err
-					}
-				}
-			} else {
-				vv, _ := indirect(rvv)
-				if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) {
-					vvv = vv.MapIndex(kv)
-				}
-			}
-			if ok, err := checkCondition(vvv, mv, op); ok {
-				rv = reflect.Append(rv, rvv)
-			} else if err != nil {
-				return nil, err
-			}
-		}
-		return rv.Interface(), nil
+		return checkWhereArray(seqv, kv, mv, path, op)
+	case reflect.Map:
+		return checkWhereMap(seqv, kv, mv, path, op)
 	default:
 		return nil, fmt.Errorf("can't iterate over %v", seq)
 	}
--- a/tpl/template_funcs_test.go
+++ b/tpl/template_funcs_test.go
@@ -1415,6 +1415,29 @@
 			key: "B", op: "op", match: "f",
 			expect: false,
 		},
+		{
+			sequence: map[string]interface{}{
+				"foo": []interface{}{map[interface{}]interface{}{"a": 1, "b": 2}},
+				"bar": []interface{}{map[interface{}]interface{}{"a": 3, "b": 4}},
+				"zap": []interface{}{map[interface{}]interface{}{"a": 5, "b": 6}},
+			},
+			key: "b", op: "in", match: slice(3, 4, 5),
+			expect: map[string]interface{}{
+				"bar": []interface{}{map[interface{}]interface{}{"a": 3, "b": 4}},
+			},
+		},
+		{
+			sequence: map[string]interface{}{
+				"foo": []interface{}{map[interface{}]interface{}{"a": 1, "b": 2}},
+				"bar": []interface{}{map[interface{}]interface{}{"a": 3, "b": 4}},
+				"zap": []interface{}{map[interface{}]interface{}{"a": 5, "b": 6}},
+			},
+			key: "b", op: ">", match: 3,
+			expect: map[string]interface{}{
+				"bar": []interface{}{map[interface{}]interface{}{"a": 3, "b": 4}},
+				"zap": []interface{}{map[interface{}]interface{}{"a": 5, "b": 6}},
+			},
+		},
 	} {
 		var results interface{}
 		var err error