third_party/k8s: backport data race fix

The metrics component of K8s had a very questionable WithContext
implementation which set the context into the metric, not the returned
handle. This causes incorrect metrics as well as data races. Backport
the fix from upstream.

Change-Id: I4f8ce9d194ba9e7b3420007863286ad9f5b612b6
Reviewed-on: https://review.monogon.dev/c/monogon/+/3780
Tested-by: Jenkins CI
Reviewed-by: Tim Windelschmidt <tim@monogon.tech>
diff --git a/third_party/go/patches/k8s-fix-metrics-data-race.patch b/third_party/go/patches/k8s-fix-metrics-data-race.patch
new file mode 100644
index 0000000..8bc0e2e
--- /dev/null
+++ b/third_party/go/patches/k8s-fix-metrics-data-race.patch
@@ -0,0 +1,898 @@
+From 580612cea9e4eaec1c25821d872e1551335f8c95 Mon Sep 17 00:00:00 2001
+From: Pranshu Srivastava <rexagod@gmail.com>
+Date: Tue, 5 Nov 2024 18:37:40 +0530
+Subject: [PATCH] address `metrics.(*Counter)` context data-race
+
+PR: https://github.com/kubernetes/kubernetes/pull/128575
+Fixes: https://github.com/kubernetes/kubernetes/issues/128548
+Signed-off-by: Pranshu Srivastava <rexagod@gmail.com>
+---
+ .../k8s.io/component-base/metrics/counter.go  | 105 +++++---
+ .../component-base/metrics/counter_test.go    | 254 +++++++++++++++++-
+ .../component-base/metrics/histogram.go       |  44 +--
+ .../component-base/metrics/histogram_test.go  | 207 ++++++++++++--
+ 4 files changed, 537 insertions(+), 73 deletions(-)
+
+diff --git a/metrics/counter.go b/metrics/counter.go
+index 8a7dd71541c23..f92c8d59f00e5 100644
+--- a/metrics/counter.go
++++ b/metrics/counter.go
+@@ -30,23 +30,22 @@ import (
+ // Counter is our internal representation for our wrapping struct around prometheus
+ // counters. Counter implements both kubeCollector and CounterMetric.
+ type Counter struct {
+-	ctx context.Context
+ 	CounterMetric
+ 	*CounterOpts
+ 	lazyMetric
+ 	selfCollector
+ }
+ 
++type CounterWithContext struct {
++	ctx context.Context
++	*Counter
++}
++
+ // The implementation of the Metric interface is expected by testutil.GetCounterMetricValue.
+ var _ Metric = &Counter{}
+ 
+ // All supported exemplar metric types implement the metricWithExemplar interface.
+-var _ metricWithExemplar = &Counter{}
+-
+-// exemplarCounterMetric holds a context to extract exemplar labels from, and a counter metric to attach them to. It implements the metricWithExemplar interface.
+-type exemplarCounterMetric struct {
+-	*Counter
+-}
++var _ metricWithExemplar = &CounterWithContext{}
+ 
+ // NewCounter returns an object which satisfies the kubeCollector and CounterMetric interfaces.
+ // However, the object returned will not measure anything unless the collector is first
+@@ -106,28 +105,17 @@ func (c *Counter) initializeDeprecatedMetric() {
+ }
+ 
+ // WithContext allows the normal Counter metric to pass in context.
+-func (c *Counter) WithContext(ctx context.Context) CounterMetric {
+-	c.ctx = ctx
+-	return c.CounterMetric
+-}
+-
+-// withExemplar initializes the exemplarMetric object and sets the exemplar value.
+-func (c *Counter) withExemplar(v float64) {
+-	(&exemplarCounterMetric{c}).withExemplar(v)
+-}
+-
+-func (c *Counter) Add(v float64) {
+-	c.withExemplar(v)
+-}
+-
+-func (c *Counter) Inc() {
+-	c.withExemplar(1)
++func (c *Counter) WithContext(ctx context.Context) *CounterWithContext {
++	return &CounterWithContext{
++		ctx:     ctx,
++		Counter: c,
++	}
+ }
+ 
+ // withExemplar attaches an exemplar to the metric.
+-func (e *exemplarCounterMetric) withExemplar(v float64) {
+-	if m, ok := e.CounterMetric.(prometheus.ExemplarAdder); ok {
+-		maybeSpanCtx := trace.SpanContextFromContext(e.ctx)
++func (c *CounterWithContext) withExemplar(v float64) {
++	if m, ok := c.CounterMetric.(prometheus.ExemplarAdder); ok {
++		maybeSpanCtx := trace.SpanContextFromContext(c.ctx)
+ 		if maybeSpanCtx.IsValid() && maybeSpanCtx.IsSampled() {
+ 			exemplarLabels := prometheus.Labels{
+ 				"trace_id": maybeSpanCtx.TraceID().String(),
+@@ -138,7 +126,23 @@ func (e *exemplarCounterMetric) withExemplar(v float64) {
+ 		}
+ 	}
+ 
+-	e.CounterMetric.Add(v)
++	c.CounterMetric.Add(v)
++}
++
++func (c *Counter) Add(v float64) {
++	c.CounterMetric.Add(v)
++}
++
++func (c *Counter) Inc() {
++	c.CounterMetric.Inc()
++}
++
++func (c *CounterWithContext) Add(v float64) {
++	c.withExemplar(v)
++}
++
++func (c *CounterWithContext) Inc() {
++	c.withExemplar(1)
+ }
+ 
+ // CounterVec is the internal representation of our wrapping struct around prometheus
+@@ -295,12 +299,51 @@ type CounterVecWithContext struct {
+ 	ctx context.Context
+ }
+ 
++// CounterVecWithContextWithCounter is the wrapper of CounterVecWithContext with counter.
++type CounterVecWithContextWithCounter struct {
++	*CounterVecWithContext
++	counter CounterMetric
++}
++
++// withExemplar attaches an exemplar to the metric.
++func (vcc *CounterVecWithContextWithCounter) withExemplar(v float64) {
++	if m, ok := vcc.counter.(prometheus.ExemplarAdder); ok {
++		maybeSpanCtx := trace.SpanContextFromContext(vcc.ctx)
++		if maybeSpanCtx.IsValid() && maybeSpanCtx.IsSampled() {
++			exemplarLabels := prometheus.Labels{
++				"trace_id": maybeSpanCtx.TraceID().String(),
++				"span_id":  maybeSpanCtx.SpanID().String(),
++			}
++			m.AddWithExemplar(v, exemplarLabels)
++			return
++		}
++	}
++
++	vcc.counter.Add(v)
++}
++
++// Add adds the given value to the counter with the provided labels.
++func (vcc *CounterVecWithContextWithCounter) Add(v float64) {
++	vcc.withExemplar(v)
++}
++
++// Inc increments the counter with the provided labels.
++func (vcc *CounterVecWithContextWithCounter) Inc() {
++	vcc.withExemplar(1)
++}
++
+ // WithLabelValues is the wrapper of CounterVec.WithLabelValues.
+-func (vc *CounterVecWithContext) WithLabelValues(lvs ...string) CounterMetric {
+-	return vc.CounterVec.WithLabelValues(lvs...)
++func (vc *CounterVecWithContext) WithLabelValues(lvs ...string) *CounterVecWithContextWithCounter {
++	return &CounterVecWithContextWithCounter{
++		CounterVecWithContext: vc,
++		counter:               vc.CounterVec.WithLabelValues(lvs...),
++	}
+ }
+ 
+ // With is the wrapper of CounterVec.With.
+-func (vc *CounterVecWithContext) With(labels map[string]string) CounterMetric {
+-	return vc.CounterVec.With(labels)
++func (vc *CounterVecWithContext) With(labels map[string]string) *CounterVecWithContextWithCounter {
++	return &CounterVecWithContextWithCounter{
++		CounterVecWithContext: vc,
++		counter:               vc.CounterVec.With(labels),
++	}
+ }
+diff --git a/metrics/counter_test.go b/metrics/counter_test.go
+index 2afcc4d63044c..ea47cab282985 100644
+--- a/metrics/counter_test.go
++++ b/metrics/counter_test.go
+@@ -292,7 +292,7 @@ func TestCounterWithLabelValueAllowList(t *testing.T) {
+ }
+ 
+ func TestCounterWithExemplar(t *testing.T) {
+-	// Set exemplar.
++	// Create context.
+ 	fn := func(offset int) []byte {
+ 		arr := make([]byte, 16)
+ 		for i := 0; i < 16; i++ {
+@@ -313,8 +313,7 @@ func TestCounterWithExemplar(t *testing.T) {
+ 	counter := NewCounter(&CounterOpts{
+ 		Name: "metric_exemplar_test",
+ 		Help: "helpless",
+-	})
+-	_ = counter.WithContext(ctxForSpanCtx)
++	}).WithContext(ctxForSpanCtx)
+ 
+ 	// Register counter.
+ 	registry := newKubeRegistry(apimachineryversion.Info{
+@@ -381,4 +380,253 @@ func TestCounterWithExemplar(t *testing.T) {
+ 			t.Fatalf("Got unexpected label %s", *l.Name)
+ 		}
+ 	}
++
++	// Verify that all contextual counter calls are exclusive.
++	contextualCounter := NewCounter(&CounterOpts{
++		Name: "contextual_counter",
++		Help: "helpless",
++	})
++	spanIDa := trace.SpanID(fn(3))
++	traceIDa := trace.TraceID(fn(4))
++	contextualCounterA := contextualCounter.WithContext(trace.ContextWithSpanContext(context.Background(),
++		trace.NewSpanContext(trace.SpanContextConfig{
++			SpanID:     spanIDa,
++			TraceID:    traceIDa,
++			TraceFlags: trace.FlagsSampled,
++		}),
++	))
++	spanIDb := trace.SpanID(fn(5))
++	traceIDb := trace.TraceID(fn(6))
++	contextualCounterB := contextualCounter.WithContext(trace.ContextWithSpanContext(context.Background(),
++		trace.NewSpanContext(trace.SpanContextConfig{
++			SpanID:     spanIDb,
++			TraceID:    traceIDb,
++			TraceFlags: trace.FlagsSampled,
++		}),
++	))
++
++	runs := []struct {
++		spanID            trace.SpanID
++		traceID           trace.TraceID
++		contextualCounter *CounterWithContext
++	}{
++		{
++			spanID:            spanIDa,
++			traceID:           traceIDa,
++			contextualCounter: contextualCounterA,
++		},
++		{
++			spanID:            spanIDb,
++			traceID:           traceIDb,
++			contextualCounter: contextualCounterB,
++		},
++	}
++	for _, run := range runs {
++		registry.MustRegister(run.contextualCounter)
++		run.contextualCounter.Inc()
++
++		mfs, err = registry.Gather()
++		if err != nil {
++			t.Fatalf("Gather failed %v", err)
++		}
++		if len(mfs) != 2 {
++			t.Fatalf("Got %v metric families, Want: 2 metric families", len(mfs))
++		}
++
++		dtoMetric := mfs[0].GetMetric()[0]
++		if dtoMetric.GetCounter().GetExemplar() == nil {
++			t.Fatalf("Got nil exemplar, wanted an exemplar")
++		}
++		dtoMetricLabels := dtoMetric.GetCounter().GetExemplar().GetLabel()
++		if len(dtoMetricLabels) != 2 {
++			t.Fatalf("Got %v exemplar labels, wanted 2 exemplar labels", len(dtoMetricLabels))
++		}
++		for _, l := range dtoMetricLabels {
++			switch *l.Name {
++			case "trace_id":
++				if *l.Value != run.traceID.String() {
++					t.Fatalf("Got %s as traceID, wanted %s", *l.Value, run.traceID.String())
++				}
++			case "span_id":
++				if *l.Value != run.spanID.String() {
++					t.Fatalf("Got %s as spanID, wanted %s", *l.Value, run.spanID.String())
++				}
++			default:
++				t.Fatalf("Got unexpected label %s", *l.Name)
++			}
++		}
++
++		registry.Unregister(run.contextualCounter)
++	}
++}
++
++func TestCounterVecWithExemplar(t *testing.T) {
++	// Create context.
++	fn := func(offset int) []byte {
++		arr := make([]byte, 16)
++		for i := 0; i < 16; i++ {
++			arr[i] = byte(2<<7 - i - offset)
++		}
++		return arr
++	}
++	traceID := trace.TraceID(fn(1))
++	spanID := trace.SpanID(fn(2))
++	ctxForSpanCtx := trace.ContextWithSpanContext(context.Background(),
++		trace.NewSpanContext(trace.SpanContextConfig{
++			SpanID:     spanID,
++			TraceID:    traceID,
++			TraceFlags: trace.FlagsSampled,
++		}),
++	)
++	toAdd := float64(40)
++
++	// Create contextual counter.
++	counter := NewCounterVec(&CounterOpts{
++		Name: "metricvec_exemplar_test",
++		Help: "helpless",
++	}, []string{"label_a"}).WithContext(ctxForSpanCtx)
++
++	// Register counter.
++	registry := newKubeRegistry(apimachineryversion.Info{
++		Major:      "1",
++		Minor:      "15",
++		GitVersion: "v1.15.0-alpha-1.12345",
++	})
++	registry.MustRegister(counter)
++
++	// Call underlying exemplar methods.
++	counter.WithLabelValues("a").Add(toAdd)
++	counter.WithLabelValues("a").Inc()
++	counter.WithLabelValues("a").Inc()
++
++	// Gather.
++	mfs, err := registry.Gather()
++	if err != nil {
++		t.Fatalf("Gather failed %v", err)
++	}
++	if len(mfs) != 1 {
++		t.Fatalf("Got %v metric families, Want: 1 metric family", len(mfs))
++	}
++
++	// Verify metric type.
++	mf := mfs[0]
++	var m *dto.Metric
++	switch mf.GetType() {
++	case dto.MetricType_COUNTER:
++		m = mfs[0].GetMetric()[0]
++	default:
++		t.Fatalf("Got %v metric type, Want: %v metric type", mf.GetType(), dto.MetricType_COUNTER)
++	}
++
++	// Verify value.
++	want := toAdd + 2
++	got := m.GetCounter().GetValue()
++	if got != want {
++		t.Fatalf("Got %f, wanted %f as the count", got, want)
++	}
++
++	// Verify exemplars.
++	e := m.GetCounter().GetExemplar()
++	if e == nil {
++		t.Fatalf("Got nil exemplar, wanted an exemplar")
++	}
++	eLabels := e.GetLabel()
++	if eLabels == nil {
++		t.Fatalf("Got nil exemplar label, wanted an exemplar label")
++	}
++	if len(eLabels) != 2 {
++		t.Fatalf("Got %v exemplar labels, wanted 2 exemplar labels", len(eLabels))
++	}
++	for _, l := range eLabels {
++		switch *l.Name {
++		case "trace_id":
++			if *l.Value != traceID.String() {
++				t.Fatalf("Got %s as traceID, wanted %s", *l.Value, traceID.String())
++			}
++		case "span_id":
++			if *l.Value != spanID.String() {
++				t.Fatalf("Got %s as spanID, wanted %s", *l.Value, spanID.String())
++			}
++		default:
++			t.Fatalf("Got unexpected label %s", *l.Name)
++		}
++	}
++
++	// Verify that all contextual counter calls are exclusive.
++	contextualCounterVec := NewCounterVec(&CounterOpts{
++		Name: "contextual_counter",
++		Help: "helpless",
++	}, []string{"label_a"})
++	spanIDa := trace.SpanID(fn(3))
++	traceIDa := trace.TraceID(fn(4))
++	contextualCounterVecA := contextualCounterVec.WithContext(trace.ContextWithSpanContext(context.Background(),
++		trace.NewSpanContext(trace.SpanContextConfig{
++			SpanID:     spanIDa,
++			TraceID:    traceIDa,
++			TraceFlags: trace.FlagsSampled,
++		}),
++	))
++	spanIDb := trace.SpanID(fn(5))
++	traceIDb := trace.TraceID(fn(6))
++	contextualCounterVecB := contextualCounterVec.WithContext(trace.ContextWithSpanContext(context.Background(),
++		trace.NewSpanContext(trace.SpanContextConfig{
++			SpanID:     spanIDb,
++			TraceID:    traceIDb,
++			TraceFlags: trace.FlagsSampled,
++		}),
++	))
++
++	runs := []struct {
++		spanID            trace.SpanID
++		traceID           trace.TraceID
++		contextualCounter *CounterVecWithContext
++	}{
++		{
++			spanID:            spanIDa,
++			traceID:           traceIDa,
++			contextualCounter: contextualCounterVecA,
++		},
++		{
++			spanID:            spanIDb,
++			traceID:           traceIDb,
++			contextualCounter: contextualCounterVecB,
++		},
++	}
++	for _, run := range runs {
++		registry.MustRegister(run.contextualCounter)
++		run.contextualCounter.WithLabelValues("a").Inc()
++
++		mfs, err = registry.Gather()
++		if err != nil {
++			t.Fatalf("Gather failed %v", err)
++		}
++		if len(mfs) != 2 {
++			t.Fatalf("Got %v metric families, Want: 2 metric families", len(mfs))
++		}
++
++		dtoMetric := mfs[0].GetMetric()[0]
++		if dtoMetric.GetCounter().GetExemplar() == nil {
++			t.Fatalf("Got nil exemplar, wanted an exemplar")
++		}
++		dtoMetricLabels := dtoMetric.GetCounter().GetExemplar().GetLabel()
++		if len(dtoMetricLabels) != 2 {
++			t.Fatalf("Got %v exemplar labels, wanted 2 exemplar labels", len(dtoMetricLabels))
++		}
++		for _, l := range dtoMetricLabels {
++			switch *l.Name {
++			case "trace_id":
++				if *l.Value != run.traceID.String() {
++					t.Fatalf("Got %s as traceID, wanted %s", *l.Value, run.traceID.String())
++				}
++			case "span_id":
++				if *l.Value != run.spanID.String() {
++					t.Fatalf("Got %s as spanID, wanted %s", *l.Value, run.spanID.String())
++				}
++			default:
++				t.Fatalf("Got unexpected label %s", *l.Name)
++			}
++		}
++
++		registry.Unregister(run.contextualCounter)
++	}
+ }
+diff --git a/metrics/histogram.go b/metrics/histogram.go
+index 3065486ab41d1..de9ed858fcfaf 100644
+--- a/metrics/histogram.go
++++ b/metrics/histogram.go
+@@ -28,34 +28,38 @@ import (
+ // Histogram is our internal representation for our wrapping struct around prometheus
+ // histograms. Summary implements both kubeCollector and ObserverMetric
+ type Histogram struct {
+-	ctx context.Context
+ 	ObserverMetric
+ 	*HistogramOpts
+ 	lazyMetric
+ 	selfCollector
+ }
+ 
+-// exemplarHistogramMetric holds a context to extract exemplar labels from, and a historgram metric to attach them to. It implements the metricWithExemplar interface.
+-type exemplarHistogramMetric struct {
++// TODO: make this true: var _ Metric = &Histogram{}
++
++// HistogramWithContext implements the metricWithExemplar interface.
++var _ metricWithExemplar = &HistogramWithContext{}
++
++// HistogramWithContext holds a context to extract exemplar labels from, and a historgram metric to attach them to. It implements the metricWithExemplar interface.
++type HistogramWithContext struct {
++	ctx context.Context
+ 	*Histogram
+ }
+ 
+-type exemplarHistogramVec struct {
++type HistogramVecWithContextWithObserver struct {
+ 	*HistogramVecWithContext
+-	observer prometheus.Observer
++	observer ObserverMetric
+ }
+ 
+ func (h *Histogram) Observe(v float64) {
+-	h.withExemplar(v)
++	h.ObserverMetric.Observe(v)
+ }
+ 
+-// withExemplar initializes the exemplarMetric object and sets the exemplar value.
+-func (h *Histogram) withExemplar(v float64) {
+-	(&exemplarHistogramMetric{h}).withExemplar(v)
++func (e *HistogramWithContext) Observe(v float64) {
++	e.withExemplar(v)
+ }
+ 
+ // withExemplar attaches an exemplar to the metric.
+-func (e *exemplarHistogramMetric) withExemplar(v float64) {
++func (e *HistogramWithContext) withExemplar(v float64) {
+ 	if m, ok := e.Histogram.ObserverMetric.(prometheus.ExemplarObserver); ok {
+ 		maybeSpanCtx := trace.SpanContextFromContext(e.ctx)
+ 		if maybeSpanCtx.IsValid() && maybeSpanCtx.IsSampled() {
+@@ -112,9 +116,11 @@ func (h *Histogram) initializeDeprecatedMetric() {
+ }
+ 
+ // WithContext allows the normal Histogram metric to pass in context. The context is no-op now.
+-func (h *Histogram) WithContext(ctx context.Context) ObserverMetric {
+-	h.ctx = ctx
+-	return h.ObserverMetric
++func (h *Histogram) WithContext(ctx context.Context) *HistogramWithContext {
++	return &HistogramWithContext{
++		ctx:       ctx,
++		Histogram: h,
++	}
+ }
+ 
+ // HistogramVec is the internal representation of our wrapping struct around prometheus
+@@ -263,11 +269,11 @@ type HistogramVecWithContext struct {
+ 	ctx context.Context
+ }
+ 
+-func (h *exemplarHistogramVec) Observe(v float64) {
++func (h *HistogramVecWithContextWithObserver) Observe(v float64) {
+ 	h.withExemplar(v)
+ }
+ 
+-func (h *exemplarHistogramVec) withExemplar(v float64) {
++func (h *HistogramVecWithContextWithObserver) withExemplar(v float64) {
+ 	if m, ok := h.observer.(prometheus.ExemplarObserver); ok {
+ 		maybeSpanCtx := trace.SpanContextFromContext(h.HistogramVecWithContext.ctx)
+ 		if maybeSpanCtx.IsValid() && maybeSpanCtx.IsSampled() {
+@@ -283,16 +289,16 @@ func (h *exemplarHistogramVec) withExemplar(v float64) {
+ }
+ 
+ // WithLabelValues is the wrapper of HistogramVec.WithLabelValues.
+-func (vc *HistogramVecWithContext) WithLabelValues(lvs ...string) *exemplarHistogramVec {
+-	return &exemplarHistogramVec{
++func (vc *HistogramVecWithContext) WithLabelValues(lvs ...string) *HistogramVecWithContextWithObserver {
++	return &HistogramVecWithContextWithObserver{
+ 		HistogramVecWithContext: vc,
+ 		observer:                vc.HistogramVec.WithLabelValues(lvs...),
+ 	}
+ }
+ 
+ // With is the wrapper of HistogramVec.With.
+-func (vc *HistogramVecWithContext) With(labels map[string]string) *exemplarHistogramVec {
+-	return &exemplarHistogramVec{
++func (vc *HistogramVecWithContext) With(labels map[string]string) *HistogramVecWithContextWithObserver {
++	return &HistogramVecWithContextWithObserver{
+ 		HistogramVecWithContext: vc,
+ 		observer:                vc.HistogramVec.With(labels),
+ 	}
+diff --git a/metrics/histogram_test.go b/metrics/histogram_test.go
+index 5efbfb6eeae2d..3cf4e703912ea 100644
+--- a/metrics/histogram_test.go
++++ b/metrics/histogram_test.go
+@@ -318,7 +318,7 @@ func TestHistogramWithLabelValueAllowList(t *testing.T) {
+ }
+ 
+ func TestHistogramWithExemplar(t *testing.T) {
+-	// Arrange.
++	// Create context.
+ 	traceID := trace.TraceID([]byte("trace-0000-xxxxx"))
+ 	spanID := trace.SpanID([]byte("span-0000-xxxxx"))
+ 	ctxForSpanCtx := trace.ContextWithSpanContext(context.Background(), trace.NewSpanContext(trace.SpanContextConfig{
+@@ -328,13 +328,14 @@ func TestHistogramWithExemplar(t *testing.T) {
+ 	}))
+ 	value := float64(10)
+ 
++	// Create contextual histogram.
+ 	histogram := NewHistogram(&HistogramOpts{
+ 		Name:    "histogram_exemplar_test",
+ 		Help:    "helpless",
+ 		Buckets: []float64{100},
+-	})
+-	_ = histogram.WithContext(ctxForSpanCtx)
++	}).WithContext(ctxForSpanCtx)
+ 
++	// Register histogram.
+ 	registry := newKubeRegistry(apimachineryversion.Info{
+ 		Major:      "1",
+ 		Minor:      "15",
+@@ -342,53 +343,51 @@ func TestHistogramWithExemplar(t *testing.T) {
+ 	})
+ 	registry.MustRegister(histogram)
+ 
+-	// Act.
++	// Call underlying exemplar methods.
+ 	histogram.Observe(value)
+ 
+-	// Assert.
++	// Gather.
+ 	mfs, err := registry.Gather()
+ 	if err != nil {
+ 		t.Fatalf("Gather failed %v", err)
+ 	}
+-
+ 	if len(mfs) != 1 {
+ 		t.Fatalf("Got %v metric families, Want: 1 metric family", len(mfs))
+ 	}
+ 
++	// Verify metric type.
+ 	mf := mfs[0]
+ 	var m *dto.Metric
+ 	switch mf.GetType() {
+ 	case dto.MetricType_HISTOGRAM:
+ 		m = mfs[0].GetMetric()[0]
+ 	default:
+-		t.Fatalf("Got %v metric type, Want: %v metric type", mf.GetType(), dto.MetricType_COUNTER)
++		t.Fatalf("Got %v metric type, Want: %v metric type", mf.GetType(), dto.MetricType_HISTOGRAM)
+ 	}
+ 
++	// Verify value.
+ 	want := value
+ 	got := m.GetHistogram().GetSampleSum()
+ 	if got != want {
+ 		t.Fatalf("Got %f, wanted %f as the count", got, want)
+ 	}
+ 
++	// Verify exemplars.
+ 	buckets := m.GetHistogram().GetBucket()
+ 	if len(buckets) == 0 {
+ 		t.Fatalf("Got 0 buckets, wanted 1")
+ 	}
+-
+ 	e := buckets[0].GetExemplar()
+ 	if e == nil {
+ 		t.Fatalf("Got nil exemplar, wanted an exemplar")
+ 	}
+-
+ 	eLabels := e.GetLabel()
+ 	if eLabels == nil {
+ 		t.Fatalf("Got nil exemplar label, wanted an exemplar label")
+ 	}
+-
+ 	if len(eLabels) != 2 {
+ 		t.Fatalf("Got %v exemplar labels, wanted 2 exemplar labels", len(eLabels))
+ 	}
+-
+ 	for _, l := range eLabels {
+ 		switch *l.Name {
+ 		case "trace_id":
+@@ -403,10 +402,95 @@ func TestHistogramWithExemplar(t *testing.T) {
+ 			t.Fatalf("Got unexpected label %s", *l.Name)
+ 		}
+ 	}
++
++	// Verify that all contextual histogram calls are exclusive.
++	contextualHistogram := NewHistogram(&HistogramOpts{
++		Name:    "contextual_histogram",
++		Help:    "helpless",
++		Buckets: []float64{100},
++	})
++	traceIDa := trace.TraceID([]byte("trace-0000-aaaaa"))
++	spanIDa := trace.SpanID([]byte("span-0000-aaaaa"))
++	contextualHistogramA := contextualHistogram.WithContext(trace.ContextWithSpanContext(context.Background(),
++		trace.NewSpanContext(trace.SpanContextConfig{
++			TraceID:    traceIDa,
++			SpanID:     spanIDa,
++			TraceFlags: trace.FlagsSampled,
++		}),
++	))
++	traceIDb := trace.TraceID([]byte("trace-0000-bbbbb"))
++	spanIDb := trace.SpanID([]byte("span-0000-bbbbb"))
++	contextualHistogramB := contextualHistogram.WithContext(trace.ContextWithSpanContext(context.Background(),
++		trace.NewSpanContext(trace.SpanContextConfig{
++			TraceID:    traceIDb,
++			SpanID:     spanIDb,
++			TraceFlags: trace.FlagsSampled,
++		}),
++	))
++
++	runs := []struct {
++		spanID              trace.SpanID
++		traceID             trace.TraceID
++		contextualHistogram *HistogramWithContext
++	}{
++		{
++			spanID:              spanIDa,
++			traceID:             traceIDa,
++			contextualHistogram: contextualHistogramA,
++		},
++		{
++			spanID:              spanIDb,
++			traceID:             traceIDb,
++			contextualHistogram: contextualHistogramB,
++		},
++	}
++	for _, run := range runs {
++		registry.MustRegister(run.contextualHistogram)
++		run.contextualHistogram.Observe(value)
++
++		mfs, err = registry.Gather()
++		if err != nil {
++			t.Fatalf("Gather failed %v", err)
++		}
++		if len(mfs) != 2 {
++			t.Fatalf("Got %v metric families, Want: 2 metric families", len(mfs))
++		}
++
++		dtoMetric := mfs[0].GetMetric()[0]
++		dtoMetricBuckets := dtoMetric.GetHistogram().GetBucket()
++		if len(dtoMetricBuckets) == 0 {
++			t.Fatalf("Got nil buckets")
++		}
++		dtoMetricBucketsExemplar := dtoMetricBuckets[0].GetExemplar()
++		if dtoMetricBucketsExemplar == nil {
++			t.Fatalf("Got nil exemplar")
++		}
++
++		dtoMetricLabels := dtoMetricBucketsExemplar.GetLabel()
++		if len(dtoMetricLabels) != 2 {
++			t.Fatalf("Got %v exemplar labels, wanted 2 exemplar labels", len(dtoMetricLabels))
++		}
++		for _, l := range dtoMetricLabels {
++			switch *l.Name {
++			case "trace_id":
++				if *l.Value != run.traceID.String() {
++					t.Fatalf("Got %s as traceID, wanted %s", *l.Value, run.traceID.String())
++				}
++			case "span_id":
++				if *l.Value != run.spanID.String() {
++					t.Fatalf("Got %s as spanID, wanted %s", *l.Value, run.spanID.String())
++				}
++			default:
++				t.Fatalf("Got unexpected label %s", *l.Name)
++			}
++		}
++
++		registry.Unregister(run.contextualHistogram)
++	}
+ }
+ 
+ func TestHistogramVecWithExemplar(t *testing.T) {
+-	// Arrange.
++	// Create context.
+ 	traceID := trace.TraceID([]byte("trace-0000-xxxxx"))
+ 	spanID := trace.SpanID([]byte("span-0000-xxxxx"))
+ 	ctxForSpanCtx := trace.ContextWithSpanContext(context.Background(), trace.NewSpanContext(trace.SpanContextConfig{
+@@ -416,6 +500,7 @@ func TestHistogramVecWithExemplar(t *testing.T) {
+ 	}))
+ 	value := float64(10)
+ 
++	// Create contextual histogram.
+ 	histogramVec := NewHistogramVec(&HistogramOpts{
+ 		Name:    "histogram_exemplar_test",
+ 		Help:    "helpless",
+@@ -423,6 +508,7 @@ func TestHistogramVecWithExemplar(t *testing.T) {
+ 	}, []string{"group"})
+ 	h := histogramVec.WithContext(ctxForSpanCtx)
+ 
++	// Register histogram.
+ 	registry := newKubeRegistry(apimachineryversion.Info{
+ 		Major:      "1",
+ 		Minor:      "15",
+@@ -430,53 +516,51 @@ func TestHistogramVecWithExemplar(t *testing.T) {
+ 	})
+ 	registry.MustRegister(histogramVec)
+ 
+-	// Act.
++	// Call underlying exemplar methods.
+ 	h.WithLabelValues("foo").Observe(value)
+ 
+-	// Assert.
++	// Gather.
+ 	mfs, err := registry.Gather()
+ 	if err != nil {
+ 		t.Fatalf("Gather failed %v", err)
+ 	}
+-
+ 	if len(mfs) != 1 {
+ 		t.Fatalf("Got %v metric families, Want: 1 metric family", len(mfs))
+ 	}
+ 
++	// Verify metric type.
+ 	mf := mfs[0]
+ 	var m *dto.Metric
+ 	switch mf.GetType() {
+ 	case dto.MetricType_HISTOGRAM:
+ 		m = mfs[0].GetMetric()[0]
+ 	default:
+-		t.Fatalf("Got %v metric type, Want: %v metric type", mf.GetType(), dto.MetricType_COUNTER)
++		t.Fatalf("Got %v metric type, Want: %v metric type", mf.GetType(), dto.MetricType_HISTOGRAM)
+ 	}
+ 
++	// Verify value.
+ 	want := value
+ 	got := m.GetHistogram().GetSampleSum()
+ 	if got != want {
+ 		t.Fatalf("Got %f, wanted %f as the count", got, want)
+ 	}
+ 
++	// Verify exemplars.
+ 	buckets := m.GetHistogram().GetBucket()
+ 	if len(buckets) == 0 {
+ 		t.Fatalf("Got 0 buckets, wanted 1")
+ 	}
+-
+ 	e := buckets[0].GetExemplar()
+ 	if e == nil {
+ 		t.Fatalf("Got nil exemplar, wanted an exemplar")
+ 	}
+-
+ 	eLabels := e.GetLabel()
+ 	if eLabels == nil {
+ 		t.Fatalf("Got nil exemplar label, wanted an exemplar label")
+ 	}
+-
+ 	if len(eLabels) != 2 {
+ 		t.Fatalf("Got %v exemplar labels, wanted 2 exemplar labels", len(eLabels))
+ 	}
+-
+ 	for _, l := range eLabels {
+ 		switch *l.Name {
+ 		case "trace_id":
+@@ -491,4 +575,87 @@ func TestHistogramVecWithExemplar(t *testing.T) {
+ 			t.Fatalf("Got unexpected label %s", *l.Name)
+ 		}
+ 	}
++
++	// Verify that all contextual histogram calls are exclusive.
++	contextualHistogramVec := NewHistogramVec(&HistogramOpts{
++		Name:    "contextual_histogram_vec",
++		Help:    "helpless",
++		Buckets: []float64{100},
++	}, []string{"group"})
++	traceIDa := trace.TraceID([]byte("trace-0000-aaaaa"))
++	spanIDa := trace.SpanID([]byte("span-0000-aaaaa"))
++	contextualHistogramVecA := contextualHistogramVec.WithContext(trace.ContextWithSpanContext(context.Background(),
++		trace.NewSpanContext(trace.SpanContextConfig{
++			TraceID:    traceIDa,
++			SpanID:     spanIDa,
++			TraceFlags: trace.FlagsSampled,
++		}),
++	))
++	traceIDb := trace.TraceID([]byte("trace-0000-bbbbb"))
++	spanIDb := trace.SpanID([]byte("span-0000-bbbbb"))
++	contextualHistogramVecB := contextualHistogramVec.WithContext(trace.ContextWithSpanContext(context.Background(),
++		trace.NewSpanContext(trace.SpanContextConfig{
++			TraceID:    traceIDb,
++			SpanID:     spanIDb,
++			TraceFlags: trace.FlagsSampled,
++		}),
++	))
++	runs := []struct {
++		spanID                 trace.SpanID
++		traceID                trace.TraceID
++		contextualHistogramVec *HistogramVecWithContext
++	}{
++		{
++			spanID:                 spanIDa,
++			traceID:                traceIDa,
++			contextualHistogramVec: contextualHistogramVecA,
++		}, {
++			spanID:                 spanIDb,
++			traceID:                traceIDb,
++			contextualHistogramVec: contextualHistogramVecB,
++		},
++	}
++	for _, run := range runs {
++		registry.MustRegister(run.contextualHistogramVec)
++		run.contextualHistogramVec.WithLabelValues("foo").Observe(value)
++
++		mfs, err = registry.Gather()
++		if err != nil {
++			t.Fatalf("Gather failed %v", err)
++		}
++		if len(mfs) != 2 {
++			t.Fatalf("Got %v metric families, Want: 2 metric families", len(mfs))
++		}
++
++		dtoMetric := mfs[0].GetMetric()[0]
++		dtoMetricBuckets := dtoMetric.GetHistogram().GetBucket()
++		if len(dtoMetricBuckets) == 0 {
++			t.Fatalf("Got nil buckets")
++		}
++		dtoMetricBucketsExemplar := dtoMetricBuckets[0].GetExemplar()
++		if dtoMetricBucketsExemplar == nil {
++			t.Fatalf("Got nil exemplar")
++		}
++
++		dtoMetricLabels := dtoMetricBucketsExemplar.GetLabel()
++		if len(dtoMetricLabels) != 2 {
++			t.Fatalf("Got %v exemplar labels, wanted 2 exemplar labels", len(dtoMetricLabels))
++		}
++		for _, l := range dtoMetricLabels {
++			switch *l.Name {
++			case "trace_id":
++				if *l.Value != run.traceID.String() {
++					t.Fatalf("Got %s as traceID, wanted %s", *l.Value, run.traceID.String())
++				}
++			case "span_id":
++				if *l.Value != run.spanID.String() {
++					t.Fatalf("Got %s as spanID, wanted %s", *l.Value, run.spanID.String())
++				}
++			default:
++				t.Fatalf("Got unexpected label %s", *l.Name)
++			}
++		}
++
++		registry.Unregister(run.contextualHistogramVec)
++	}
+ }