shithub: hugo

Download patch

ref: 02effd9dc46201df250564cdbbee1ca2291eb8b4
parent: f52b040ee126ec0c48f1d273681a860fe7814314
author: Bjørn Erik Pedersen <[email protected]>
date: Mon Mar 21 16:42:27 EDT 2016

Protect against concurrent Scratch read and write

Fixes #2005

--- a/hugolib/scratch.go
+++ b/hugolib/scratch.go
@@ -17,11 +17,13 @@
 	"github.com/spf13/hugo/helpers"
 	"reflect"
 	"sort"
+	"sync"
 )
 
 // Scratch is a writable context used for stateful operations in Page/Node rendering.
 type Scratch struct {
 	values map[string]interface{}
+	mu     sync.RWMutex
 }
 
 // For single values, Add will add (using the + operator) the addend to the existing addend (if found).
@@ -29,6 +31,9 @@
 //
 // If the first add for a key is an array or slice, then the next value(s) will be appended.
 func (c *Scratch) Add(key string, newAddend interface{}) (string, error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
 	var newVal interface{}
 	existingAddend, found := c.values[key]
 	if found {
@@ -59,6 +64,9 @@
 // Set stores a value with the given key in the Node context.
 // This value can later be retrieved with Get.
 func (c *Scratch) Set(key string, value interface{}) string {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
 	c.values[key] = value
 	return ""
 }
@@ -65,6 +73,9 @@
 
 // Get returns a value previously set by Add or Set
 func (c *Scratch) Get(key string) interface{} {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
 	return c.values[key]
 }
 
@@ -71,6 +82,9 @@
 // SetInMap stores a value to a map with the given key in the Node context.
 // This map can later be retrieved with GetSortedMapValues.
 func (c *Scratch) SetInMap(key string, mapKey string, value interface{}) string {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
 	_, found := c.values[key]
 	if !found {
 		c.values[key] = make(map[string]interface{})
@@ -82,6 +96,9 @@
 
 // GetSortedMapValues returns a sorted map previously filled with SetInMap
 func (c *Scratch) GetSortedMapValues(key string) interface{} {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
 	if c.values[key] == nil {
 		return nil
 	}
--- a/hugolib/scratch_test.go
+++ b/hugolib/scratch_test.go
@@ -16,6 +16,7 @@
 import (
 	"github.com/stretchr/testify/assert"
 	"reflect"
+	"sync"
 	"testing"
 )
 
@@ -78,6 +79,41 @@
 	scratch := newScratch()
 	scratch.Set("key", "val")
 	assert.Equal(t, "val", scratch.Get("key"))
+}
+
+// Issue #2005
+func TestScratchInParallel(t *testing.T) {
+	var wg sync.WaitGroup
+	scratch := newScratch()
+	key := "counter"
+	scratch.Set(key, 1)
+	for i := 1; i <= 10; i++ {
+		wg.Add(1)
+		go func(j int) {
+			for k := 0; k < 10; k++ {
+				newVal := k + j
+
+				_, err := scratch.Add(key, newVal)
+				if err != nil {
+					t.Errorf("Got err %s", err)
+				}
+
+				scratch.Set(key, newVal)
+
+				val := scratch.Get(key)
+
+				if counter, ok := val.(int); ok {
+					if counter < 1 {
+						t.Errorf("Got %d", counter)
+					}
+				} else {
+					t.Errorf("Got %T", val)
+				}
+			}
+			wg.Done()
+		}(i)
+	}
+	wg.Wait()
 }
 
 func TestScratchGet(t *testing.T) {