shithub: hugo

Download patch

ref: 06f56fc983d460506d39b3a6f638b1632af07073
parent: d7a67dcb51829b12d492d3f2ee4f6e2a3834da63
author: Bjørn Erik Pedersen <[email protected]>
date: Thu Apr 18 13:06:54 EDT 2019

tpl/collections: Make Pages etc. work with the in func

Fixes #5875

--- a/tpl/collections/collections.go
+++ b/tpl/collections/collections.go
@@ -250,27 +250,26 @@
 	lv := reflect.ValueOf(l)
 	vv := reflect.ValueOf(v)
 
+	if !vv.Type().Comparable() {
+		// TODO(bep) consider adding error to the signature.
+		return false
+	}
+
+	// Normalize numeric types to float64 etc.
+	vvk := normalize(vv)
+
 	switch lv.Kind() {
 	case reflect.Array, reflect.Slice:
 		for i := 0; i < lv.Len(); i++ {
-			lvv := lv.Index(i)
-			lvv, isNil := indirect(lvv)
-			if isNil {
+			lvv, isNil := indirectInterface(lv.Index(i))
+			if isNil || !lvv.Type().Comparable() {
 				continue
 			}
-			switch lvv.Kind() {
-			case reflect.String:
-				if vv.Type() == lvv.Type() && vv.String() == lvv.String() {
-					return true
-				}
-			default:
-				if isNumber(vv.Kind()) && isNumber(lvv.Kind()) {
-					f1, err1 := numberToFloat(vv)
-					f2, err2 := numberToFloat(lvv)
-					if err1 == nil && err2 == nil && f1 == f2 {
-						return true
-					}
-				}
+
+			lvvk := normalize(lvv)
+
+			if lvvk == vvk {
+				return true
 			}
 		}
 	case reflect.String:
--- a/tpl/collections/collections_test.go
+++ b/tpl/collections/collections_test.go
@@ -276,6 +276,7 @@
 
 func TestIn(t *testing.T) {
 	t.Parallel()
+	assert := require.New(t)
 
 	ns := New(&deps.Deps{})
 
@@ -302,12 +303,18 @@
 		{"this substring should be found", "substring", true},
 		{"this substring should not be found", "subseastring", false},
 		{nil, "foo", false},
+		// Pointers
+		{pagesPtr{p1, p2, p3, p2}, p2, true},
+		{pagesPtr{p1, p2, p3, p2}, p4, false},
+		// Structs
+		{pagesVals{p3v, p2v, p3v, p2v}, p2v, true},
+		{pagesVals{p3v, p2v, p3v, p2v}, p4v, false},
 	} {
 
 		errMsg := fmt.Sprintf("[%d] %v", i, test)
 
 		result := ns.In(test.l1, test.l2)
-		assert.Equal(t, test.expect, result, errMsg)
+		assert.Equal(test.expect, result, errMsg)
 	}
 }