m/pkg/combinectx: implement
This implements combinectx, a Go library for combining two contexts into
a single one. We need this for the new curator logic (where we want to
cancel RPC calls both when the incoming request gets canceled and
when leadership status changes), and this functionality has been
factored out as a reusable, generic library.
Prior art:
1) https://github.com/golang/go/issues/36503
Proposal to add Merge() to context stdlib package. Unimplemented.
2) https://github.com/teivah/onecontext
Complex reflect-based logic for joining an arbitrary number of contexts,
no functionality to detect which context caused the joined context to
be canceled.
3) https://github.com/LK4D4/joincontext
No functionality to detect which context caused the joined context to
be canceled.
Change-Id: I774607da38b06c192ff0fee133eb258abd500864
Reviewed-on: https://review.monogon.dev/c/monogon/+/123
Reviewed-by: Leopold Schabel <leo@nexantic.com>
diff --git a/metropolis/pkg/combinectx/combinectx.go b/metropolis/pkg/combinectx/combinectx.go
new file mode 100644
index 0000000..5df4ff0
--- /dev/null
+++ b/metropolis/pkg/combinectx/combinectx.go
@@ -0,0 +1,184 @@
+// Package combinectx implements context.Contexts that 'combine' two other
+// 'parent' contexts. These can be used to deal with cases where you want to
+// cancel a method call as soon as either of two pre-existing contexts expires.
+//
+// For example, if you want to tie a method call to some incoming request
+// context and an active leader lease, then this library is for you.
+package combinectx
+
+import (
+ "context"
+ "sync"
+ "time"
+)
+
+// Combine joins two existing parent contexts into a single context. The
+// returned context is Done() as soon as either parent context is Done(); its
+// Err() then reports which parent was responsible (see Error).
+//
+// Unless neither parent can ever be canceled, Combine spawns a goroutine that
+// waits for the first parent to be Done() and exits afterwards. That goroutine
+// lives for as long as both parents stay alive, so at least one parent should
+// eventually be canceled or expire.
+func Combine(a, b context.Context) context.Context {
+	c := &Combined{
+		a:     a,
+		b:     b,
+		doneC: make(chan struct{}),
+	}
+	// A nil Done() channel means the context can never be canceled. If that
+	// holds for both parents, the combined context can never be Done() either,
+	// and a watcher goroutine would only block forever (i.e. leak).
+	if a.Done() == nil && b.Done() == nil {
+		return c
+	}
+	go c.run()
+	return c
+}
+
+// Combined is a context.Context that is Done() as soon as either of its two
+// parent contexts is Done(). Instances are created by Combine.
+type Combined struct {
+ // a is the first parent context.
+ a context.Context
+ // b is the second parent context.
+ b context.Context
+
+ // mu guards done.
+ mu sync.Mutex
+ // done is nil until either parent context is Done(), and afterwards an Error
+ // recording which parent it was.
+ done *Error
+ // doneC is closed when either parent context is Done(), after done is set.
+ doneC chan struct{}
+}
+
+// Error is the error returned by a Combined context's Err(). It wraps the error
+// of whichever parent context caused the combined context to be Done(); exactly
+// one of the two fields below is set in an Error created by a Combined context.
+type Error struct {
+ // underlyingA points to an error returned by the first parent context if the
+ // combined context was Done() as a result of the first parent context being
+ // Done().
+ underlyingA *error
+ // underlyingB points to an error returned by the second parent context if the
+ // combined context was Done() as a result of the second parent context being
+ // Done().
+ underlyingB *error
+}
+
+// Error implements the error interface by returning the message of the wrapped
+// parent context error, or an empty string if no parent error was recorded.
+func (e *Error) Error() string {
+	switch {
+	case e.underlyingA != nil:
+		return (*e.underlyingA).Error()
+	case e.underlyingB != nil:
+		return (*e.underlyingB).Error()
+	default:
+		return ""
+	}
+}
+
+// First returns true if the Combined context was Done() as a result of its
+// first parent context being Done().
+func (e *Error) First() bool {
+ return e.underlyingA != nil
+}
+
+// Second returns true if the Combined context was Done() as a result of its
+// second parent context being Done().
+func (e *Error) Second() bool {
+ return e.underlyingB != nil
+}
+
+// Unwrap exposes the error of the parent context that caused the combined
+// context to be Done(), or nil if no parent error was recorded. This lets
+// errors.Is and errors.As see through an *Error to the parent's error.
+func (e *Error) Unwrap() error {
+	switch {
+	case e.underlyingA != nil:
+		return *e.underlyingA
+	case e.underlyingB != nil:
+		return *e.underlyingB
+	}
+	return nil
+}
+
+// Is reports whether target is an *Error, so that errors.Is matches any *Error
+// regardless of which parent error it wraps.
+func (e *Error) Is(target error) bool {
+	_, isCombined := target.(*Error)
+	return isCombined
+}
+
+// As stores e in target and returns true if target is a **Error, so that
+// errors.As can extract the *Error; for any other target type it returns false.
+func (e *Error) As(target interface{}) bool {
+	dst, ok := target.(**Error)
+	if !ok {
+		return false
+	}
+	*dst = e
+	return true
+}
+
+// run is the main logic that ties together the two parent contexts. It blocks
+// until either parent context is Done(), records that parent's error and
+// closes doneC, then exits.
+func (c *Combined) run() {
+ // mark records err as coming from the first (first == true) or second parent
+ // and then closes doneC. done is set before doneC is closed, both under mu,
+ // so Err() returns non-nil once Done() is observed closed.
+ mark := func(first bool, err error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.done = &Error{}
+ if first {
+ c.done.underlyingA = &err
+ } else {
+ c.done.underlyingB = &err
+ }
+ close(c.doneC)
+ }
+ // If both parents are already Done(), select picks one of the ready cases at
+ // random, so only one parent is recorded as the cause.
+ select {
+ case <-c.a.Done():
+ mark(true, c.a.Err())
+ case <-c.b.Done():
+ mark(false, c.b.Err())
+ }
+}
+
+// Deadline reports the earlier of the two parent deadlines. If only one parent
+// has a deadline, that deadline is returned; if neither has one, ok is false.
+func (c *Combined) Deadline() (deadline time.Time, ok bool) {
+	deadlineA, hasA := c.a.Deadline()
+	deadlineB, hasB := c.b.Deadline()
+
+	switch {
+	case !hasA && !hasB:
+		return time.Time{}, false
+	case !hasB:
+		return deadlineA, true
+	case !hasA:
+		return deadlineB, true
+	case deadlineA.Before(deadlineB):
+		return deadlineA, true
+	default:
+		return deadlineB, true
+	}
+}
+
+// Done returns a channel that is closed as soon as either parent context is
+// Done().
+func (c *Combined) Done() <-chan struct{} {
+ return c.doneC
+}
+
+// Err returns nil if neither parent context is Done() yet, or an *Error
+// otherwise. The returned error has the following properties:
+//  1. errors.Is(err, &Error{}) always returns true.
+//  2. errors.Is(err, ctx.Err()) returns true if the combined context was
+//     canceled with the same error as ctx.Err(). However, this does NOT mean
+//     that the combined context was Done() because of ctx being Done() - to
+//     check that, use errors.As() to retrieve the *Error and call its
+//     First()/Second() methods.
+//  3. errors.Is(err, context.{Canceled,DeadlineExceeded}) returns true if the
+//     parent that caused the combined context to be Done() was Canceled or
+//     DeadlineExceeded, respectively.
+//  4. errors.Is returns false otherwise.
+//  5. errors.As(err, &target), with target declared as *Error, always returns
+//     true. The resulting *Error can then be used to check the cause of the
+//     combined context's error.
+func (c *Combined) Err() error {
+	c.mu.Lock()
+	recorded := c.done
+	c.mu.Unlock()
+
+	// Return an untyped nil rather than a nil *Error wrapped in the error
+	// interface, which would compare as non-nil.
+	if recorded == nil {
+		return nil
+	}
+	return recorded
+}
+
+// Value looks up key in the first parent context and, if that yields nil, in
+// the second one. It returns nil if neither parent has a value for key.
+func (c *Combined) Value(key interface{}) interface{} {
+	for _, parent := range [...]context.Context{c.a, c.b} {
+		if found := parent.Value(key); found != nil {
+			return found
+		}
+	}
+	return nil
+}
+
+