shithub: mc

Download patch

ref: 3d9e396e9696c1aa2290fce09626c33d15e1ea03
parent: 74355dfc48d3a142918d894ece67ecf7cde9d13d
author: iriri <[email protected]>
date: Thu Jun 7 19:53:29 EDT 2018

Add convenience wrappers for atomic operations on pointers

--- a/lib/thread/atomic.myr
+++ b/lib/thread/atomic.myr
@@ -1,4 +1,5 @@
 use std
+use "common"
 
 pkg thread =
 	trait atomic @a :: integral,numeric @a =
@@ -13,6 +14,11 @@
 	impl atomic int64
 	impl atomic uint32
 	impl atomic uint64
+
+	generic xgetptr : (p : @a## -> std.option(@a#))	/* atomic load of *p; a null pointer is returned as `std.None */
+	generic xsetptr : (p : @a##, v : std.option(@a#) -> void)	/* atomic store of v into *p; `std.None stores a null pointer */
+	generic xcasptr : (p : @a##, old : std.option(@a#), new : std.option(@a#) -> std.option(@a#))	/* CAS on *p; returns the prior contents, null as `std.None */
+	generic xchgptr : (p : @a##, new : std.option(@a#) -> std.option(@a#))	/* atomic swap; returns the prior contents, null as `std.None */
 ;;
 
 impl atomic int32 =
@@ -56,6 +62,31 @@
 	xcas	= {p, old, new; -> xcasp(p, old, new)}
 	xchg	= {p, v; -> xchgp(p, v)}
 ;;
+
+generic xgetptr = {p	/* atomic load through the std.intptr impl; a zero (null) slot comes back as `std.None */
+	match xget((p : std.intptr#))
+	| 0:	-> `std.None
+	| raw:	-> `std.Some (raw : @a#)
+	;;
+}
+
+generic xsetptr = {p, v	/* atomic store of v into *p via xset; `std.None falls back to Zptr, i.e. a null (0) pointer */
+	xset((p : std.intptr#), (std.getv(v, Zptr) : std.intptr))
+}
+
+generic xcasptr = {p, old, new	/* compare-and-swap on *p; `std.None means null for old, new and the returned prior value */
+	match xcas((p : std.intptr#), (std.getv(old, Zptr) : std.intptr), (std.getv(new, Zptr) : std.intptr))
+	| 0:	-> `std.None
+	| prev:	-> `std.Some (prev : @a#)
+	;;
+}
+
+generic xchgptr = {p, new	/* atomic swap: store new into *p and hand back what was there, null as `std.None */
+	match xchg((p : std.intptr#), (std.getv(new, Zptr) : std.intptr))
+	| 0:	-> `std.None
+	| prev:	-> `std.Some (prev : @a#)
+	;;
+}
 
 extern const xget32	: (p : uint32# -> uint32)
 extern const xget64	: (p : uint64# -> uint64)
--- a/lib/thread/test/atomic.myr
+++ b/lib/thread/test/atomic.myr
@@ -16,6 +16,8 @@
 		/* nothing */
 	;;
 	std.assert(val == 2_000_000, "atomics are broken\n")
+
+	testintptr()
 }
 
 const incvar = {
@@ -27,3 +29,16 @@
 	thread.xadd(&done, 1)
 }
 
+const testintptr = {	/* exercise each pointer wrapper, including a failed CAS and a null slot */
+	var i = 123
+	var j = 456
+	var p = &i
+	std.assert(std.get(thread.xgetptr(&p))# == 123, "xgetptr is broken\n")
+	thread.xsetptr(&p, `std.Some &j)
+	std.assert(p# == 456, "xsetptr is broken\n")
+	std.assert(std.get(thread.xcasptr(&p, `std.Some &j, `std.Some &i)) == &j, "xcasptr is broken\n")
+	std.assert(p# == 123, "xcasptr is broken\n")
+	std.assert(std.get(thread.xcasptr(&p, `std.Some &j, `std.None)) == &i && p == &i, "xcasptr is broken\n")	/* expected value mismatch: no store, current value returned */
+	std.assert(std.get(thread.xchgptr(&p, `std.None)) == &i, "xchgptr is broken\n")
+	std.assert((p : std.intptr) == 0 && std.getv(thread.xgetptr(&p), &j) == &j, "xchgptr is broken\n")	/* a null slot must read back as `std.None */
+}