blob: 3ff1d72d81f314dfbc3f89a3a74d35e7bec805b6 [file] [log] [blame]
// Package combinectx implements context.Contexts that 'combine' two other
// 'parent' contexts. These can be used to deal with cases where you want to
// cancel a method call as soon as either of two pre-existing contexts expires.
//
// For example, if you want to tie a method call to both an incoming request
// context and an active leader lease, then this library is for you.
7package combinectx
8
9import (
10 "context"
11 "sync"
12 "time"
13)
14
15// Combine 'joins' two existing 'parent' contexts into a single context. This
16// context will be Done() whenever any of the parent context is Done().
17// Combining contexts spawns a goroutine that will be cleaned when any of the
18// parent contexts is Done().
19func Combine(a, b context.Context) context.Context {
20 c := &Combined{
Serge Bazanski2098b982021-07-07 15:13:46 +020021 a: a,
22 b: b,
Serge Bazanski4166a712021-06-07 21:58:54 +020023 doneC: make(chan struct{}),
24 }
25 go c.run()
26 return c
27}
28
29type Combined struct {
30 // a is the first parent context.
31 a context.Context
32 // b is the second parent context.
33 b context.Context
34
35 // mu guards done.
36 mu sync.Mutex
37 // done is an Error if either parent context is Done(), or nil otherwise.
38 done *Error
39 // doneC is closed when either parent context is Done() and Error is set.
40 doneC chan struct{}
41}
42
// Error is the error reported by a Combined context once it is Done(). It
// wraps the error of the parent context that caused it, and records which
// parent that was.
type Error struct {
	// underlyingA and underlyingB point at the error returned by the first or
	// second parent context, respectively. Only the field matching the parent
	// whose Done() caused the combined context to be Done() is non-nil.
	underlyingA, underlyingB *error
}
54
55func (e *Error) Error() string {
56 if e.underlyingA != nil {
57 return (*e.underlyingA).Error()
58 }
59 if e.underlyingB != nil {
60 return (*e.underlyingB).Error()
61 }
62 return ""
63}
64
65// First returns true if the Combined context's first parent was Done().
66func (e *Error) First() bool {
67 return e.underlyingA != nil
68}
69
70// Second returns true if the Combined context's second parent was Done().
71func (e *Error) Second() bool {
72 return e.underlyingB != nil
73}
74
75// Unwrap returns the underlying error of either parent context that is Done().
76func (e *Error) Unwrap() error {
77 if e.underlyingA != nil {
78 return *e.underlyingA
79 }
80 if e.underlyingB != nil {
81 return *e.underlyingB
82 }
83 return nil
84}
85
86// Is allows errors.Is to be true against any *Error.
87func (e *Error) Is(target error) bool {
88 if _, ok := target.(*Error); ok {
89 return true
90 }
91 return false
92}
93
94// As allows errors.As to be true against any *Error.
95func (e *Error) As(target interface{}) bool {
96 if v, ok := target.(**Error); ok {
97 *v = e
98 return true
99 }
100 return false
101}
102
103// run is the main logic that ties together the two parent contexts. It exits
104// when either parent context is canceled.
105func (c *Combined) run() {
106 mark := func(first bool, err error) {
107 c.mu.Lock()
108 defer c.mu.Unlock()
109 c.done = &Error{}
110 if first {
111 c.done.underlyingA = &err
112 } else {
113 c.done.underlyingB = &err
114 }
115 close(c.doneC)
116 }
117 select {
118 case <-c.a.Done():
119 mark(true, c.a.Err())
120 case <-c.b.Done():
121 mark(false, c.b.Err())
122 }
123}
124
125// Deadline returns the earlier Deadline from the two parent contexts, if any.
126func (c *Combined) Deadline() (deadline time.Time, ok bool) {
127 d1, ok1 := c.a.Deadline()
128 d2, ok2 := c.b.Deadline()
129
130 if ok1 && !ok2 {
131 return d1, true
132 }
133 if ok2 && !ok1 {
134 return d2, true
135 }
136 if !ok1 && !ok2 {
137 return time.Time{}, false
138 }
139
140 if d1.Before(d2) {
141 return d1, true
142 }
143 return d2, true
144}
145
146func (c *Combined) Done() <-chan struct{} {
147 return c.doneC
148}
149
150// Err returns nil if neither parent context is Done() yet, or an error otherwise.
151// The returned errors will have the following properties:
152// 1) errors.Is(err, Error{}) will always return true.
153// 2) errors.Is(err, ctx.Err()) will return true if the combined context was
154// canceled with the same error as ctx.Err().
155// However, this does NOT mean that the combined context was Done() because
156// of the ctx being Done() - to ensure this is the case, use errors.As() to
157// retrieve an Error and its First()/Second() methods.
158// 3) errors.Is(err, context.{Canceled,DeadlineExceeded}) will return true if
159// the combined context is Canceled or DeadlineExceeded.
160// 4) errors.Is will return false otherwise.
161// 5) errors.As(err, &&Error{})) will always return true. The Error object can
162// then be used to check the cause of the combined context's error.
163func (c *Combined) Err() error {
164 c.mu.Lock()
165 defer c.mu.Unlock()
166 if c.done == nil {
167 return nil
168 }
169 return c.done
170}
171
172// Value returns the value located under the given key by checking the first and
173// second parent context in order.
174func (c *Combined) Value(key interface{}) interface{} {
175 if v := c.a.Value(key); v != nil {
176 return v
177 }
178 if v := c.b.Value(key); v != nil {
179 return v
180 }
181 return nil
182}