shithub: hugo

Download patch

ref: a55640de8e3944d3b9f64b15155148a0e35cb31e
parent: 9225db636e2f9b75f992013a25c0b149d6bd8b0d
author: Bjørn Erik Pedersen <[email protected]>
date: Tue Apr 2 06:30:24 EDT 2019

tpl: Allow the partial template func to return any type

This commit adds support for return values in partials.

This means that you can now write a partial like this:

    {{ $v := add . 42 }}
    {{ return $v }}

Partials without a `return` statement will be rendered as before.

This works for both `partial` and `partialCached`.

Fixes #5783


--- a/compare/compare.go
+++ b/compare/compare.go
@@ -20,6 +20,12 @@
 	Eq(other interface{}) bool
 }
 
+// ProbablyEqer is implemented by types that support an equality check that
+// may return false positives, but never a false negative.
+type ProbablyEqer interface {
+	ProbablyEq(other interface{}) bool
+}
+
 // Comparer can be used to compare two values.
 // This will be used when using the le, ge etc. operators in the templates.
 // Compare returns -1 if the given version is less than, 0 if equal and 1 if greater than
--- a/hugolib/template_test.go
+++ b/hugolib/template_test.go
@@ -264,3 +264,44 @@
 	)
 
 }
+
+func TestPartialWithReturn(t *testing.T) {
+
+	b := newTestSitesBuilder(t).WithSimpleConfigFile()
+
+	b.WithTemplatesAdded(
+		"index.html", `
+Test Partials With Return Values:
+
+add42: 50: {{ partial "add42.tpl" 8 }}
+dollarContext: 60: {{ partial "dollarContext.tpl" 18 }}
+adder: 70: {{ partial "dict.tpl" (dict "adder" 28) }}
+complex: 80: {{ partial "complex.tpl" 38 }}
+`,
+		"partials/add42.tpl", `
+		{{ $v := add . 42 }}
+		{{ return $v }}
+		`,
+		"partials/dollarContext.tpl", `
+{{ $v := add $ 42 }}
+{{ return $v }}
+`,
+		"partials/dict.tpl", `
+{{ $v := add $.adder 42 }}
+{{ return $v }}
+`,
+		"partials/complex.tpl", `
+{{ return add . 42 }}
+`,
+	)
+
+	b.CreateSites().Build(BuildCfg{})
+
+	b.AssertFileContent("public/index.html",
+		"add42: 50: 50",
+		"dollarContext: 60: 60",
+		"adder: 70: 70",
+		"complex: 80: 80",
+	)
+
+}
--- a/metrics/metrics.go
+++ b/metrics/metrics.go
@@ -23,6 +23,12 @@
 	"strings"
 	"sync"
 	"time"
+
+	"github.com/gohugoio/hugo/compare"
+
+	"github.com/gohugoio/hugo/common/hreflect"
+
+	"github.com/spf13/cast"
 )
 
 // The Provider interface defines an interface for measuring metrics.
@@ -35,7 +41,7 @@
 	WriteMetrics(w io.Writer)
 
 	// TrackValue tracks the value for diff calculations etc.
-	TrackValue(key, value string)
+	TrackValue(key string, value interface{})
 
 	// Reset clears the metric store.
 	Reset()
@@ -42,13 +48,13 @@
 }
 
 type diff struct {
-	baseline string
+	baseline interface{}
 	count    int
 	simSum   int
 }

-func (d *diff) add(v string) *diff {
-	if d.baseline == "" {
+func (d *diff) add(v interface{}) *diff {
+	if !hreflect.IsTruthful(d.baseline) { // No baseline recorded yet: start one with v.
 		d.baseline = v
 		d.count = 1
 		d.simSum = 100 // If we get only one it is very cache friendly.
@@ -90,7 +96,7 @@
 }
 
 // TrackValue tracks the value for diff calculations etc.
-func (s *Store) TrackValue(key, value string) {
+func (s *Store) TrackValue(key string, value interface{}) {
 	if !s.calculateHints {
 		return
 	}
@@ -191,12 +197,63 @@

 // howSimilar is a naive diff implementation that returns
 // a number between 0-100 indicating how similar a and b are.
-// 100 is when all words in a also exists in b.
-func howSimilar(a, b string) int {
-
+// Values that can be converted to strings are compared word by word via
+// howSimilarStrings; other values are compared via the compare.Eqer and
+// compare.ProbablyEqer interfaces.
+func howSimilar(a, b interface{}) int {
-	if a == b {
+	// Partials may return maps or slices, and == panics on interface values
+	// holding such uncomparable types, so use isEqual instead.
+	if isEqual(a, b) {
 		return 100
 	}
+
+	as, err1 := cast.ToStringE(a)
+	bs, err2 := cast.ToStringE(b)
+
+	if err1 == nil && err2 == nil {
+		return howSimilarStrings(as, bs)
+	}
+
+	// Only one of the values is string-like, so they are not similar.
+	// Note: comparing err1 != err2 here would be wrong, as two failed
+	// conversions yield two distinct error values and the interface
+	// checks below would never be reached.
+	if err1 == nil || err2 == nil {
+		return 0
+	}
+
+	e1, ok1 := a.(compare.Eqer)
+	e2, ok2 := b.(compare.Eqer)
+	if ok1 && ok2 && e1.Eq(e2) {
+		return 100
+	}
+
+	// TODO(bep) implement ProbablyEq for Pages etc.
+	pe1, pok1 := a.(compare.ProbablyEqer)
+	pe2, pok2 := b.(compare.ProbablyEqer)
+	if pok1 && pok2 && pe1.ProbablyEq(pe2) {
+		return 90
+	}
+
+	return 0
+}
+
+// isEqual reports whether a == b. Values whose dynamic type is not
+// comparable (maps, slices, funcs) make == panic; they are reported as
+// not equal instead.
+func isEqual(a, b interface{}) (eq bool) {
+	defer func() {
+		if recover() != nil {
+			eq = false
+		}
+	}()
+	return a == b
+}
+
+// howSimilarStrings is a naive diff implementation that returns
+// a number between 0-100 indicating how similar a and b are.
+// 100 is when all words in a also exists in b.
+func howSimilarStrings(a, b string) int {

 	// Give some weight to the word positions.
 	const partitionSize = 4
--- a/tpl/partials/init.go
+++ b/tpl/partials/init.go
@@ -36,6 +36,13 @@
 			},
 		)
 
+		// TODO(bep) we need the return to be a valid identifier, but
+		// should consider another way of adding it.
+		ns.AddMethodMapping(func() string { return "" },
+			[]string{"return"},
+			[][2]string{},
+		)
+
 		ns.AddMethodMapping(ctx.IncludeCached,
 			[]string{"partialCached"},
 			[][2]string{},
--- a/tpl/partials/partials.go
+++ b/tpl/partials/partials.go
@@ -18,10 +18,14 @@
 import (
 	"fmt"
 	"html/template"
+	"io"
+	"io/ioutil"
 	"strings"
 	"sync"
 	texttemplate "text/template"
 
+	"github.com/gohugoio/hugo/tpl"
+
 	bp "github.com/gohugoio/hugo/bufferpool"
 	"github.com/gohugoio/hugo/deps"
 )
@@ -62,8 +66,22 @@
 	cachedPartials *partialCache
 }
 
-// Include executes the named partial and returns either a string,
-// when the partial is a text/template, or template.HTML when html/template.
+// contextWrapper makes room for a return value in a partial invocation.
+type contextWrapper struct {
+	Arg    interface{}
+	Result interface{}
+}
+
+// Set sets the return value and returns an empty string.
+func (c *contextWrapper) Set(in interface{}) string {
+	c.Result = in
+	return ""
+}
+
+// Include executes the named partial.
+// If the partial contains a return statement, that value will be returned.
+// Else, the rendered output will be returned:
+// A string if the partial is a text/template, or template.HTML when html/template.
 func (ns *Namespace) Include(name string, contextList ...interface{}) (interface{}, error) {
 	if strings.HasPrefix(name, "partials/") {
 		name = name[8:]
@@ -83,31 +101,54 @@
 		// For legacy reasons.
 		templ, found = ns.deps.Tmpl.Lookup(n + ".html")
 	}
-	if found {
+
+	if !found {
+		return "", fmt.Errorf("partial %q not found", name)
+	}
+
+	var info tpl.Info
+	if ip, ok := templ.(tpl.TemplateInfoProvider); ok {
+		info = ip.TemplateInfo()
+	}
+
+	var w io.Writer
+
+	if info.HasReturn {
+		// Wrap the context sent to the template to capture the return value.
+		// Note that the template is rewritten to make sure that the dot (".")
+		// and the $ variable points to Arg.
+		context = &contextWrapper{
+			Arg: context,
+		}
+
+		// We don't care about any template output.
+		w = ioutil.Discard
+	} else {
 		b := bp.GetBuffer()
 		defer bp.PutBuffer(b)
+		w = b
+	}
 
-		if err := templ.Execute(b, context); err != nil {
-			return "", err
-		}
+	if err := templ.Execute(w, context); err != nil {
+		return "", err
+	}
 
-		if _, ok := templ.(*texttemplate.Template); ok {
-			s := b.String()
-			if ns.deps.Metrics != nil {
-				ns.deps.Metrics.TrackValue(n, s)
-			}
-			return s, nil
-		}
+	var result interface{}
 
-		s := b.String()
-		if ns.deps.Metrics != nil {
-			ns.deps.Metrics.TrackValue(n, s)
-		}
-		return template.HTML(s), nil
+	if ctx, ok := context.(*contextWrapper); ok {
+		result = ctx.Result
+	} else if _, ok := templ.(*texttemplate.Template); ok {
+		result = w.(fmt.Stringer).String()
+	} else {
+		result = template.HTML(w.(fmt.Stringer).String())
+	}
 
+	if ns.deps.Metrics != nil {
+		ns.deps.Metrics.TrackValue(n, result)
 	}
 
-	return "", fmt.Errorf("partial %q not found", name)
+	return result, nil
+
 }
 
 // IncludeCached executes and caches partial templates.  An optional variant
--- a/tpl/template_info.go
+++ b/tpl/template_info.go
@@ -22,8 +22,15 @@
 	// Set for shortcode templates with any {{ .Inner }}
 	IsInner bool
 
+	// Set for partials with a return statement.
+	HasReturn bool
+
 	// Config extracted from template.
 	Config Config
+}
+
+func (info Info) IsZero() bool {
+	return info.Config.Version == 0
 }
 
 type Config struct {
--- a/tpl/tplimpl/ace.go
+++ b/tpl/tplimpl/ace.go
@@ -51,15 +51,17 @@
 		return err
 	}
 
-	isShort := isShortcode(name)
+	typ := resolveTemplateType(name)
 
-	info, err := applyTemplateTransformersToHMLTTemplate(isShort, templ)
+	info, err := applyTemplateTransformersToHMLTTemplate(typ, templ)
 	if err != nil {
 		return err
 	}
 
-	if isShort {
+	if typ == templateShortcode {
 		t.addShortcodeVariant(name, info, templ)
+	} else {
+		t.templateInfo[name] = info
 	}
 
 	return nil
--- a/tpl/tplimpl/shortcodes.go
+++ b/tpl/tplimpl/shortcodes.go
@@ -139,6 +139,18 @@
 	return name, variants
 }
 
+func resolveTemplateType(name string) templateType {
+	if isShortcode(name) {
+		return templateShortcode
+	}
+
+	if strings.Contains(name, "partials/") {
+		return templatePartial
+	}
+
+	return templateUndefined
+}
+
 func isShortcode(name string) bool {
 	return strings.Contains(name, "shortcodes/")
 }
--- a/tpl/tplimpl/template.go
+++ b/tpl/tplimpl/template.go
@@ -90,6 +90,11 @@
 	// (language, output format etc.) of that shortcode.
 	shortcodes map[string]*shortcodeTemplates
 
+	// templateInfo maps template name to some additional information about that template.
+	// Note that for shortcodes that same information is embedded in the
+	// shortcodeTemplates type.
+	templateInfo map[string]tpl.Info
+
 	// text holds all the pure text templates.
 	text *textTemplates
 	html *htmlTemplates
@@ -172,18 +177,30 @@
 		// The templates are stored without the prefix identificator.
 		name = strings.TrimPrefix(name, textTmplNamePrefix)
 
-		return t.text.Lookup(name)
+		return t.applyTemplateInfo(t.text.Lookup(name))
 	}
 
 	// Look in both
 	if te, found := t.html.Lookup(name); found {
-		return te, true
+		return t.applyTemplateInfo(te, true)
 	}
 
-	return t.text.Lookup(name)
+	return t.applyTemplateInfo(t.text.Lookup(name))
 
 }
 
+func (t *templateHandler) applyTemplateInfo(templ tpl.Template, found bool) (tpl.Template, bool) {
+	if adapter, ok := templ.(*tpl.TemplateAdapter); ok {
+		if adapter.Info.IsZero() {
+			if info, found := t.templateInfo[templ.Name()]; found {
+				adapter.Info = info
+			}
+		}
+	}
+
+	return templ, found
+}
+
 // This currently only applies to shortcodes and what we get here is the
 // shortcode name.
 func (t *templateHandler) LookupVariant(name string, variants tpl.TemplateVariants) (tpl.Template, bool, bool) {
@@ -243,12 +260,13 @@
 
 func (t *templateHandler) clone(d *deps.Deps) *templateHandler {
 	c := &templateHandler{
-		Deps:       d,
-		layoutsFs:  d.BaseFs.Layouts.Fs,
-		shortcodes: make(map[string]*shortcodeTemplates),
-		html:       &htmlTemplates{t: template.Must(t.html.t.Clone()), overlays: make(map[string]*template.Template), templatesCommon: t.html.templatesCommon},
-		text:       &textTemplates{textTemplate: &textTemplate{t: texttemplate.Must(t.text.t.Clone())}, overlays: make(map[string]*texttemplate.Template), templatesCommon: t.text.templatesCommon},
-		errors:     make([]*templateErr, 0),
+		Deps:         d,
+		layoutsFs:    d.BaseFs.Layouts.Fs,
+		shortcodes:   make(map[string]*shortcodeTemplates),
+		templateInfo: t.templateInfo,
+		html:         &htmlTemplates{t: template.Must(t.html.t.Clone()), overlays: make(map[string]*template.Template), templatesCommon: t.html.templatesCommon},
+		text:         &textTemplates{textTemplate: &textTemplate{t: texttemplate.Must(t.text.t.Clone())}, overlays: make(map[string]*texttemplate.Template), templatesCommon: t.text.templatesCommon},
+		errors:       make([]*templateErr, 0),
 	}
 
 	for k, v := range t.shortcodes {
@@ -306,12 +324,13 @@
 		templatesCommon: common,
 	}
 	h := &templateHandler{
-		Deps:       deps,
-		layoutsFs:  deps.BaseFs.Layouts.Fs,
-		shortcodes: make(map[string]*shortcodeTemplates),
-		html:       htmlT,
-		text:       textT,
-		errors:     make([]*templateErr, 0),
+		Deps:         deps,
+		layoutsFs:    deps.BaseFs.Layouts.Fs,
+		shortcodes:   make(map[string]*shortcodeTemplates),
+		templateInfo: make(map[string]tpl.Info),
+		html:         htmlT,
+		text:         textT,
+		errors:       make([]*templateErr, 0),
 	}
 
 	common.handler = h
@@ -463,15 +482,17 @@
 		return err
 	}
 
-	isShort := isShortcode(name)
+	typ := resolveTemplateType(name)
 
-	info, err := applyTemplateTransformersToHMLTTemplate(isShort, templ)
+	info, err := applyTemplateTransformersToHMLTTemplate(typ, templ)
 	if err != nil {
 		return err
 	}
 
-	if isShort {
+	if typ == templateShortcode {
 		t.handler.addShortcodeVariant(name, info, templ)
+	} else {
+		t.handler.templateInfo[name] = info
 	}
 
 	return nil
@@ -511,7 +532,7 @@
 		return nil, err
 	}
 
-	if _, err := applyTemplateTransformersToTextTemplate(false, templ); err != nil {
+	if _, err := applyTemplateTransformersToTextTemplate(templateUndefined, templ); err != nil {
 		return nil, err
 	}
 	return templ, nil
@@ -524,15 +545,17 @@
 		return err
 	}
 
-	isShort := isShortcode(name)
+	typ := resolveTemplateType(name)
 
-	info, err := applyTemplateTransformersToTextTemplate(isShort, templ)
+	info, err := applyTemplateTransformersToTextTemplate(typ, templ)
 	if err != nil {
 		return err
 	}
 
-	if isShort {
+	if typ == templateShortcode {
 		t.handler.addShortcodeVariant(name, info, templ)
+	} else {
+		t.handler.templateInfo[name] = info
 	}
 
 	return nil
@@ -737,7 +760,7 @@
 	// * https://github.com/golang/go/issues/16101
 	// * https://github.com/gohugoio/hugo/issues/2549
 	overlayTpl = overlayTpl.Lookup(overlayTpl.Name())
-	if _, err := applyTemplateTransformersToHMLTTemplate(false, overlayTpl); err != nil {
+	if _, err := applyTemplateTransformersToHMLTTemplate(templateUndefined, overlayTpl); err != nil {
 		return err
 	}
 
@@ -777,7 +800,7 @@
 	}
 
 	overlayTpl = overlayTpl.Lookup(overlayTpl.Name())
-	if _, err := applyTemplateTransformersToTextTemplate(false, overlayTpl); err != nil {
+	if _, err := applyTemplateTransformersToTextTemplate(templateUndefined, overlayTpl); err != nil {
 		return err
 	}
 	t.overlays[name] = overlayTpl
@@ -847,15 +870,17 @@
 			return err
 		}
 
-		isShort := isShortcode(name)
+		typ := resolveTemplateType(name)
 
-		info, err := applyTemplateTransformersToHMLTTemplate(isShort, templ)
+		info, err := applyTemplateTransformersToHMLTTemplate(typ, templ)
 		if err != nil {
 			return err
 		}
 
-		if isShort {
+		if typ == templateShortcode {
 			t.addShortcodeVariant(templateName, info, templ)
+		} else {
+			t.templateInfo[name] = info
 		}
 
 		return nil
--- a/tpl/tplimpl/template_ast_transformers.go
+++ b/tpl/tplimpl/template_ast_transformers.go
@@ -39,6 +39,14 @@
 	"Data": true,
 }
 
+type templateType int
+
+const (
+	templateUndefined templateType = iota
+	templateShortcode
+	templatePartial
+)
+
 type templateContext struct {
 	decl     decl
 	visited  map[string]bool
@@ -47,8 +55,7 @@
 	// The last error encountered.
 	err error
 
-	// Only needed for shortcodes
-	isShortcode bool
+	typ templateType
 
 	// Set when we're done checking for config header.
 	configChecked bool
@@ -55,6 +62,9 @@
 
 	// Contains some info about the template
 	tpl.Info
+
+	// Store away the return node in partials.
+	returnNode *parse.CommandNode
 }
 
 func (c templateContext) getIfNotVisited(name string) *parse.Tree {
@@ -84,12 +94,12 @@
 	}
 }
 
-func applyTemplateTransformersToHMLTTemplate(isShortcode bool, templ *template.Template) (tpl.Info, error) {
-	return applyTemplateTransformers(isShortcode, templ.Tree, createParseTreeLookup(templ))
+func applyTemplateTransformersToHMLTTemplate(typ templateType, templ *template.Template) (tpl.Info, error) {
+	return applyTemplateTransformers(typ, templ.Tree, createParseTreeLookup(templ))
 }
 
-func applyTemplateTransformersToTextTemplate(isShortcode bool, templ *texttemplate.Template) (tpl.Info, error) {
-	return applyTemplateTransformers(isShortcode, templ.Tree,
+func applyTemplateTransformersToTextTemplate(typ templateType, templ *texttemplate.Template) (tpl.Info, error) {
+	return applyTemplateTransformers(typ, templ.Tree,
 		func(nn string) *parse.Tree {
 			tt := templ.Lookup(nn)
 			if tt != nil {
@@ -99,19 +109,54 @@
 		})
 }
 
-func applyTemplateTransformers(isShortcode bool, templ *parse.Tree, lookupFn func(name string) *parse.Tree) (tpl.Info, error) {
+func applyTemplateTransformers(typ templateType, templ *parse.Tree, lookupFn func(name string) *parse.Tree) (tpl.Info, error) {
 	if templ == nil {
 		return tpl.Info{}, errors.New("expected template, but none provided")
 	}
 
 	c := newTemplateContext(lookupFn)
-	c.isShortcode = isShortcode
+	c.typ = typ
 
-	err := c.applyTransformations(templ.Root)
+	_, err := c.applyTransformations(templ.Root)
 
+	if err == nil && c.returnNode != nil {
+		// This is a partial with a return statement.
+		c.Info.HasReturn = true
+		templ.Root = c.wrapInPartialReturnWrapper(templ.Root)
+	}
+
 	return c.Info, err
 }
 
+const (
+	partialReturnWrapperTempl = `{{ $_hugo_dot := $ }}{{ $ := .Arg }}{{ with .Arg }}{{ $_hugo_dot.Set ("PLACEHOLDER") }}{{ end }}`
+)
+
+var partialReturnWrapper *parse.ListNode
+
+func init() {
+	templ, err := texttemplate.New("").Parse(partialReturnWrapperTempl)
+	if err != nil {
+		panic(err)
+	}
+	partialReturnWrapper = templ.Tree.Root
+}
+
+func (c *templateContext) wrapInPartialReturnWrapper(n *parse.ListNode) *parse.ListNode {
+	wrapper := partialReturnWrapper.CopyList()
+	withNode := wrapper.Nodes[2].(*parse.WithNode)
+	retn := withNode.List.Nodes[0]
+	setCmd := retn.(*parse.ActionNode).Pipe.Cmds[0]
+	setPipe := setCmd.Args[1].(*parse.PipeNode)
+	// Replace PLACEHOLDER with the real return value.
+	// Note that this is a PipeNode, so it will be wrapped in parens.
+	setPipe.Cmds = []*parse.CommandNode{c.returnNode}
+	withNode.List.Nodes = append(n.Nodes, retn)
+
+	return wrapper
+
+}
+
 // The truth logic in Go's template package is broken for certain values
 // for the if and with keywords. This works around that problem by wrapping
 // the node passed to if/with in a getif conditional.
@@ -141,7 +186,7 @@
 // 1) Make all .Params.CamelCase and similar into lowercase.
 // 2) Wraps every with and if pipe in getif
 // 3) Collects some information about the template content.
-func (c *templateContext) applyTransformations(n parse.Node) error {
+func (c *templateContext) applyTransformations(n parse.Node) (bool, error) {
 	switch x := n.(type) {
 	case *parse.ListNode:
 		if x != nil {
@@ -169,12 +214,20 @@
 			c.decl[x.Decl[0].Ident[0]] = x.Cmds[0].String()
 		}

-		for _, cmd := range x.Cmds {
-			c.applyTransformations(cmd)
+		// Filter in place: deleting from x.Cmds while ranging over it would
+		// skip the command that follows a removed one.
+		kept := x.Cmds[:0]
+		for _, cmd := range x.Cmds {
+			// Errors are recorded in c.err, which is returned below.
+			if keep, _ := c.applyTransformations(cmd); keep {
+				kept = append(kept, cmd)
+			}
 		}
+		x.Cmds = kept

 	case *parse.CommandNode:
 		c.collectInner(x)
+		keep := c.collectReturnNode(x)

 		for _, elem := range x.Args {
 			switch an := elem.(type) {
@@ -191,9 +240,10 @@
 				}
 			}
 		}
+		return keep, c.err
 	}
 
-	return c.err
+	return true, c.err
 }
 
 func (c *templateContext) applyTransformationsToNodes(nodes ...parse.Node) {
@@ -229,7 +279,7 @@
 // on the form:
 //    {{ $_hugo_config:= `{ "version": 1 }` }}
 func (c *templateContext) collectConfig(n *parse.PipeNode) {
-	if !c.isShortcode {
+	if c.typ != templateShortcode {
 		return
 	}
 	if c.configChecked {
@@ -271,7 +321,7 @@
 // collectInner determines if the given CommandNode represents a
 // shortcode call to its .Inner.
 func (c *templateContext) collectInner(n *parse.CommandNode) {
-	if !c.isShortcode {
+	if c.typ != templateShortcode {
 		return
 	}
 	if c.Info.IsInner || len(n.Args) == 0 {
@@ -292,6 +342,28 @@
 			break
 		}
 	}
+
+}
+
+func (c *templateContext) collectReturnNode(n *parse.CommandNode) bool {
+	if c.typ != templatePartial || c.returnNode != nil {
+		return true
+	}
+
+	if len(n.Args) < 2 {
+		return true
+	}
+
+	ident, ok := n.Args[0].(*parse.IdentifierNode)
+	if !ok || ident.Ident != "return" {
+		return true
+	}
+
+	c.returnNode = n
+	// Remove the "return" identifiers
+	c.returnNode.Args = c.returnNode.Args[1:]
+
+	return false
 
 }
 
--- a/tpl/tplimpl/template_ast_transformers_test.go
+++ b/tpl/tplimpl/template_ast_transformers_test.go
@@ -180,7 +180,7 @@
 func TestParamsKeysToLower(t *testing.T) {
 	t.Parallel()
 
-	_, err := applyTemplateTransformers(false, nil, nil)
+	_, err := applyTemplateTransformers(templateUndefined, nil, nil)
 	require.Error(t, err)
 
 	templ, err := template.New("foo").Funcs(testFuncs).Parse(paramsTempl)
@@ -484,10 +484,53 @@
 			require.NoError(t, err)
 
 			c := newTemplateContext(createParseTreeLookup(templ))
-			c.isShortcode = true
+			c.typ = templateShortcode
 			c.applyTransformations(templ.Tree.Root)
 
 			assert.Equal(test.expected, c.Info)
+		})
+	}
+
+}
+
+func TestPartialReturn(t *testing.T) {
+
+	tests := []struct {
+		name      string
+		tplString string
+		expected  bool
+	}{
+		{"Basic", `
+{{ $a := "Hugo Rocks!" }}
+{{ return $a }}
+`, true},
+		{"Expression", `
+{{ return add 32 }}
+`, true},
+	}
+
+	echo := func(in interface{}) interface{} {
+		return in
+	}
+
+	funcs := template.FuncMap{
+		"return": echo,
+		"add":    echo,
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			assert := require.New(t)
+
+			templ, err := template.New("foo").Funcs(funcs).Parse(test.tplString)
+			require.NoError(t, err)
+
+			_, err = applyTemplateTransformers(templatePartial, templ.Tree, createParseTreeLookup(templ))
+
+			// Just check that it doesn't fail in this test. We have functional tests
+			// in hugoblib.
+			assert.NoError(err)
+
 		})
 	}