// Package combinectx implements context.Contexts that 'combine' two other
// 'parent' contexts. These can be used to deal with cases where you want to
// cancel a method call as soon as either of two pre-existing contexts expires.
//
// For example, if you want to tie a method call to some incoming request
// context and an active leader lease, then this library is for you.
| 7 | package combinectx |
| 8 | |
| 9 | import ( |
| 10 | "context" |
| 11 | "sync" |
| 12 | "time" |
| 13 | ) |
| 14 | |
| 15 | // Combine 'joins' two existing 'parent' contexts into a single context. This |
| 16 | // context will be Done() whenever any of the parent context is Done(). |
| 17 | // Combining contexts spawns a goroutine that will be cleaned when any of the |
| 18 | // parent contexts is Done(). |
| 19 | func Combine(a, b context.Context) context.Context { |
| 20 | c := &Combined{ |
| 21 | a: a, |
| 22 | b: b, |
| 23 | doneC: make(chan struct{}), |
| 24 | } |
| 25 | go c.run() |
| 26 | return c |
| 27 | } |
| 28 | |
| 29 | type Combined struct { |
| 30 | // a is the first parent context. |
| 31 | a context.Context |
| 32 | // b is the second parent context. |
| 33 | b context.Context |
| 34 | |
| 35 | // mu guards done. |
| 36 | mu sync.Mutex |
| 37 | // done is an Error if either parent context is Done(), or nil otherwise. |
| 38 | done *Error |
| 39 | // doneC is closed when either parent context is Done() and Error is set. |
| 40 | doneC chan struct{} |
| 41 | } |
| 42 | |
// Error wraps the error returned by the parent context that caused a combined
// context to be Done().
type Error struct {
	// underlyingA is non-nil when the combined context was Done() because the
	// first parent context was Done(); it then points at that parent's error.
	underlyingA *error
	// underlyingB is non-nil when the combined context was Done() because the
	// second parent context was Done(); it then points at that parent's error.
	underlyingB *error
}

// Error returns the message of the wrapped parent error, preferring the first
// parent's error. It returns an empty string when no parent error is recorded.
func (e *Error) Error() string {
	switch {
	case e.underlyingA != nil:
		return (*e.underlyingA).Error()
	case e.underlyingB != nil:
		return (*e.underlyingB).Error()
	default:
		return ""
	}
}

// First reports whether an error from the first parent context is recorded.
func (e *Error) First() bool {
	return e.underlyingA != nil
}

// Second reports whether an error from the second parent context is recorded.
func (e *Error) Second() bool {
	return e.underlyingB != nil
}

// Unwrap returns the recorded parent error, preferring the first parent's
// error, or nil when none is recorded.
func (e *Error) Unwrap() error {
	switch {
	case e.underlyingA != nil:
		return *e.underlyingA
	case e.underlyingB != nil:
		return *e.underlyingB
	default:
		return nil
	}
}

// Is makes errors.Is report true whenever the target is any *Error,
// regardless of that target's contents.
func (e *Error) Is(target error) bool {
	_, isCombined := target.(*Error)
	return isCombined
}

// As makes errors.As succeed for any **Error target by storing e into it. Any
// other target type is rejected.
func (e *Error) As(target interface{}) bool {
	dst, ok := target.(**Error)
	if !ok {
		return false
	}
	*dst = e
	return true
}
| 102 | |
| 103 | // run is the main logic that ties together the two parent contexts. It exits |
| 104 | // when either parent context is canceled. |
| 105 | func (c *Combined) run() { |
| 106 | mark := func(first bool, err error) { |
| 107 | c.mu.Lock() |
| 108 | defer c.mu.Unlock() |
| 109 | c.done = &Error{} |
| 110 | if first { |
| 111 | c.done.underlyingA = &err |
| 112 | } else { |
| 113 | c.done.underlyingB = &err |
| 114 | } |
| 115 | close(c.doneC) |
| 116 | } |
| 117 | select { |
| 118 | case <-c.a.Done(): |
| 119 | mark(true, c.a.Err()) |
| 120 | case <-c.b.Done(): |
| 121 | mark(false, c.b.Err()) |
| 122 | } |
| 123 | } |
| 124 | |
| 125 | // Deadline returns the earlier Deadline from the two parent contexts, if any. |
| 126 | func (c *Combined) Deadline() (deadline time.Time, ok bool) { |
| 127 | d1, ok1 := c.a.Deadline() |
| 128 | d2, ok2 := c.b.Deadline() |
| 129 | |
| 130 | if ok1 && !ok2 { |
| 131 | return d1, true |
| 132 | } |
| 133 | if ok2 && !ok1 { |
| 134 | return d2, true |
| 135 | } |
| 136 | if !ok1 && !ok2 { |
| 137 | return time.Time{}, false |
| 138 | } |
| 139 | |
| 140 | if d1.Before(d2) { |
| 141 | return d1, true |
| 142 | } |
| 143 | return d2, true |
| 144 | } |
| 145 | |
| 146 | func (c *Combined) Done() <-chan struct{} { |
| 147 | return c.doneC |
| 148 | } |
| 149 | |
| 150 | // Err returns nil if neither parent context is Done() yet, or an error otherwise. |
| 151 | // The returned errors will have the following properties: |
| 152 | // 1) errors.Is(err, Error{}) will always return true. |
| 153 | // 2) errors.Is(err, ctx.Err()) will return true if the combined context was |
| 154 | // canceled with the same error as ctx.Err(). |
| 155 | // However, this does NOT mean that the combined context was Done() because |
| 156 | // of the ctx being Done() - to ensure this is the case, use errors.As() to |
| 157 | // retrieve an Error and its First()/Second() methods. |
| 158 | // 3) errors.Is(err, context.{Canceled,DeadlineExceeded}) will return true if |
| 159 | // the combined context is Canceled or DeadlineExceeded. |
| 160 | // 4) errors.Is will return false otherwise. |
| 161 | // 5) errors.As(err, &&Error{})) will always return true. The Error object can |
| 162 | // then be used to check the cause of the combined context's error. |
| 163 | func (c *Combined) Err() error { |
| 164 | c.mu.Lock() |
| 165 | defer c.mu.Unlock() |
| 166 | if c.done == nil { |
| 167 | return nil |
| 168 | } |
| 169 | return c.done |
| 170 | } |
| 171 | |
| 172 | // Value returns the value located under the given key by checking the first and |
| 173 | // second parent context in order. |
| 174 | func (c *Combined) Value(key interface{}) interface{} { |
| 175 | if v := c.a.Value(key); v != nil { |
| 176 | return v |
| 177 | } |
| 178 | if v := c.b.Value(key); v != nil { |
| 179 | return v |
| 180 | } |
| 181 | return nil |
| 182 | } |
| 183 | |
| 184 | |