shithub: hugo

Download patch

ref: 5d565c34e5909f30f9cf7fa4c499f0297416cb83
parent: d5308e6f6f9f581414ec2e817c054f991227a682
author: Tatsushi Demachi <[email protected]>
date: Fri Sep 19 21:33:02 EDT 2014

Extend template's basic math functions to accept float, uint and string values

--- a/hugolib/template.go
+++ b/hugolib/template.go
@@ -233,6 +233,156 @@
 	return template.HTML(text)
 }
 
+// doArithmetic applies the binary operator op ('+', '-', '*' or '/') to a
+// and b. The operands may be any mix of signed integers, unsigned integers
+// and floats; two strings are also accepted for '+', which concatenates them.
+//
+// Numeric operands are first normalised into one of three pairs: float64
+// when either operand is a float, int64 when both are signed or a negative
+// signed value meets an unsigned one, and uint64 otherwise. The operation is
+// then applied to whichever pair holds a non-zero value, so the result is an
+// int64, float64 or uint64; when both operands are zero, '+', '-' and '*'
+// return the untyped constant 0 (an int).
+//
+// A uint64 subtraction whose true result is negative is returned as an int64
+// instead of wrapping around. An error is returned for unsupported operand
+// kinds or operators, division by zero, and integer operands that cannot be
+// combined without overflowing int64.
+func doArithmetic(a, b interface{}, op rune) (interface{}, error) {
+	av := reflect.ValueOf(a)
+	bv := reflect.ValueOf(b)
+	var ai, bi int64
+	var af, bf float64
+	var au, bu uint64
+	switch av.Kind() {
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		ai = av.Int()
+		switch bv.Kind() {
+		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+			bi = bv.Int()
+		case reflect.Float32, reflect.Float64:
+			af = float64(ai) // may lose precision beyond 2^53
+			ai = 0
+			bf = bv.Float()
+		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+			bu = bv.Uint()
+			if ai >= 0 {
+				au = uint64(ai)
+				ai = 0
+			} else {
+				// Signed arithmetic is required, so bu must fit in an
+				// int64; otherwise the conversion would flip its sign.
+				if bu > 1<<63-1 {
+					return nil, errors.New("Can't apply the operator to the values: integer overflow")
+				}
+				bi = int64(bu)
+				bu = 0
+			}
+		default:
+			return nil, errors.New("Can't apply the operator to the values")
+		}
+	case reflect.Float32, reflect.Float64:
+		af = av.Float()
+		switch bv.Kind() {
+		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+			bf = float64(bv.Int()) // may lose precision beyond 2^53
+		case reflect.Float32, reflect.Float64:
+			bf = bv.Float()
+		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+			bf = float64(bv.Uint()) // may lose precision beyond 2^53
+		default:
+			return nil, errors.New("Can't apply the operator to the values")
+		}
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+		au = av.Uint()
+		switch bv.Kind() {
+		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+			bi = bv.Int()
+			if bi >= 0 {
+				bu = uint64(bi)
+				bi = 0
+			} else {
+				// Signed arithmetic is required, so au must fit in an
+				// int64; otherwise the conversion would flip its sign.
+				if au > 1<<63-1 {
+					return nil, errors.New("Can't apply the operator to the values: integer overflow")
+				}
+				ai = int64(au)
+				au = 0
+			}
+		case reflect.Float32, reflect.Float64:
+			af = float64(au) // may lose precision beyond 2^53
+			au = 0
+			bf = bv.Float()
+		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+			bu = bv.Uint()
+		default:
+			return nil, errors.New("Can't apply the operator to the values")
+		}
+	case reflect.String:
+		// Strings only support concatenation with another string.
+		if bv.Kind() != reflect.String || op != '+' {
+			return nil, errors.New("Can't apply the operator to the values")
+		}
+		return av.String() + bv.String(), nil
+	default:
+		return nil, errors.New("Can't apply the operator to the values")
+	}
+
+	// Only one pair was populated above, so the first pair holding a
+	// non-zero value selects the arithmetic; all-zero operands yield 0.
+	switch op {
+	case '+':
+		if ai != 0 || bi != 0 {
+			return ai + bi, nil
+		} else if af != 0 || bf != 0 {
+			return af + bf, nil
+		} else if au != 0 || bu != 0 {
+			return au + bu, nil
+		}
+		return 0, nil
+	case '-':
+		if ai != 0 || bi != 0 {
+			return ai - bi, nil
+		} else if af != 0 || bf != 0 {
+			return af - bf, nil
+		} else if au != 0 || bu != 0 {
+			if au < bu {
+				// The true difference is negative: report it as an int64
+				// instead of letting uint64 wrap around. Reinterpreting
+				// the wrapped au-bu as int64 yields exactly -(bu-au)
+				// whenever that value fits in an int64.
+				if bu-au > 1<<63 {
+					return nil, errors.New("Can't apply the operator to the values: integer overflow")
+				}
+				return int64(au - bu), nil
+			}
+			return au - bu, nil
+		}
+		return 0, nil
+	case '*':
+		if ai != 0 || bi != 0 {
+			return ai * bi, nil
+		} else if af != 0 || bf != 0 {
+			return af * bf, nil
+		} else if au != 0 || bu != 0 {
+			return au * bu, nil
+		}
+		return 0, nil
+	case '/':
+		if bi != 0 {
+			return ai / bi, nil
+		} else if bf != 0 {
+			return af / bf, nil
+		} else if bu != 0 {
+			return au / bu, nil
+		}
+		return nil, errors.New("Can't divide the value by 0")
+	default:
+		return nil, errors.New("There is no such an operation")
+	}
+}
+
 type Template interface {
 	ExecuteTemplate(wr io.Writer, name string, data interface{}) error
 	Lookup(name string) *template.Template
@@ -278,11 +397,11 @@
 		"first":       First,
 		"where":       Where,
 		"highlight":   Highlight,
-		"add":         func(a, b int) int { return a + b },
-		"sub":         func(a, b int) int { return a - b },
-		"div":         func(a, b int) int { return a / b },
+		"add":         func(a, b interface{}) (interface{}, error) { return doArithmetic(a, b, '+') },
+		"sub":         func(a, b interface{}) (interface{}, error) { return doArithmetic(a, b, '-') },
+		"div":         func(a, b interface{}) (interface{}, error) { return doArithmetic(a, b, '/') },
 		"mod":         func(a, b int) int { return a % b },
-		"mul":         func(a, b int) int { return a * b },
+		"mul":         func(a, b interface{}) (interface{}, error) { return doArithmetic(a, b, '*') },
 		"modBool":     func(a, b int) bool { return a%b == 0 },
 		"lower":       func(a string) string { return strings.ToLower(a) },
 		"upper":       func(a string) string { return strings.ToUpper(a) },
--- a/hugolib/template_test.go
+++ b/hugolib/template_test.go
@@ -35,6 +35,100 @@
 	}
 }
 
+// TestDoArithmetic checks the value and dynamic type that doArithmetic
+// returns for each operand combination, and that division by zero,
+// unsupported operands and unknown operators produce an error. An expect of
+// false marks a case that must fail.
+func TestDoArithmetic(t *testing.T) {
+	for i, this := range []struct {
+		a      interface{}
+		b      interface{}
+		op     rune
+		expect interface{}
+	}{
+		{3, 2, '+', int64(5)},
+		{3, 2, '-', int64(1)},
+		{3, 2, '*', int64(6)},
+		{3, 2, '/', int64(1)},
+		{3.0, 2, '+', float64(5)},
+		{3.0, 2, '-', float64(1)},
+		{3.0, 2, '*', float64(6)},
+		{3.0, 2, '/', float64(1.5)},
+		{3, 2.0, '+', float64(5)},
+		{3, 2.0, '-', float64(1)},
+		{3, 2.0, '*', float64(6)},
+		{3, 2.0, '/', float64(1.5)},
+		{3.0, 2.0, '+', float64(5)},
+		{3.0, 2.0, '-', float64(1)},
+		{3.0, 2.0, '*', float64(6)},
+		{3.0, 2.0, '/', float64(1.5)},
+		{uint(3), uint(2), '+', uint64(5)},
+		{uint(3), uint(2), '-', uint64(1)},
+		{uint(3), uint(2), '*', uint64(6)},
+		{uint(3), uint(2), '/', uint64(1)},
+		{uint(3), 2, '+', uint64(5)},
+		{uint(3), 2, '-', uint64(1)},
+		{uint(3), 2, '*', uint64(6)},
+		{uint(3), 2, '/', uint64(1)},
+		{3, uint(2), '+', uint64(5)},
+		{3, uint(2), '-', uint64(1)},
+		{3, uint(2), '*', uint64(6)},
+		{3, uint(2), '/', uint64(1)},
+		{uint(3), -2, '+', int64(1)},
+		{uint(3), -2, '-', int64(5)},
+		{uint(3), -2, '*', int64(-6)},
+		{uint(3), -2, '/', int64(-1)},
+		{-3, uint(2), '+', int64(-1)},
+		{-3, uint(2), '-', int64(-5)},
+		{-3, uint(2), '*', int64(-6)},
+		{-3, uint(2), '/', int64(-1)},
+		{uint(3), 2.0, '+', float64(5)},
+		{uint(3), 2.0, '-', float64(1)},
+		{uint(3), 2.0, '*', float64(6)},
+		{uint(3), 2.0, '/', float64(1.5)},
+		{3.0, uint(2), '+', float64(5)},
+		{3.0, uint(2), '-', float64(1)},
+		{3.0, uint(2), '*', float64(6)},
+		{3.0, uint(2), '/', float64(1.5)},
+		{0, 0, '+', 0},
+		{0, 0, '-', 0},
+		{0, 0, '*', 0},
+		{"foo", "bar", '+', "foobar"},
+		{3, 0, '/', false},
+		{3.0, 0, '/', false},
+		{3, 0.0, '/', false},
+		{uint(3), uint(0), '/', false},
+		{3, uint(0), '/', false},
+		{-3, uint(0), '/', false},
+		{uint(3), 0, '/', false},
+		{3.0, uint(0), '/', false},
+		{uint(3), 0.0, '/', false},
+		{3, "foo", '+', false},
+		{3.0, "foo", '+', false},
+		{uint(3), "foo", '+', false},
+		{"foo", 3, '+', false},
+		{"foo", "bar", '-', false},
+		{3, 2, '%', false},
+	} {
+		result, err := doArithmetic(this.a, this.b, this.op)
+		if b, ok := this.expect.(bool); ok && !b {
+			if err == nil {
+				t.Errorf("[%d] doArithmetic didn't return an expected error, got %v", i, result)
+			}
+			continue
+		}
+		if err != nil {
+			t.Errorf("[%d] failed: %s", i, err)
+			continue
+		}
+		// DeepEqual also compares dynamic types, so print them on failure:
+		// int64(5) and int(5) look identical with %v alone.
+		if !reflect.DeepEqual(result, this.expect) {
+			t.Errorf("[%d] doArithmetic got %v (%T) but expected %v (%T)", i, result, result, this.expect, this.expect)
+		}
+	}
+}
+
 func TestFirst(t *testing.T) {
 	for i, this := range []struct {
 		count    int