xds: WRRPicker must not access unsynchronized data in ChildLbState · grpc/grpc-java@0d47f5b
@@ -44,11 +44,10 @@
4444import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener;
4545import io.grpc.xds.orca.OrcaPerRequestUtil;
4646import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener;
47+import java.util.ArrayList;
4748import java.util.Collection;
48-import java.util.HashMap;
4949import java.util.HashSet;
5050import java.util.List;
51-import java.util.Map;
5251import java.util.Random;
5352import java.util.Set;
5453import java.util.concurrent.ScheduledExecutorService;
@@ -233,9 +232,44 @@ protected void updateOverallBalancingState() {
233232 }
234233235234private SubchannelPicker createReadyPicker(Collection<ChildLbState> activeList) {
236-return new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList),
237-config.enableOobLoadReport, config.errorUtilizationPenalty, sequence, getHelper(),
238-locality);
235+WeightedRoundRobinPicker picker = new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList),
236+config.enableOobLoadReport, config.errorUtilizationPenalty, sequence);
237+updateWeight(picker);
238+return picker;
239+ }
240+241+private void updateWeight(WeightedRoundRobinPicker picker) {
242+Helper helper = getHelper();
243+float[] newWeights = new float[picker.children.size()];
244+AtomicInteger staleEndpoints = new AtomicInteger();
245+AtomicInteger notYetUsableEndpoints = new AtomicInteger();
246+for (int i = 0; i < picker.children.size(); i++) {
247+double newWeight = ((WeightedChildLbState) picker.children.get(i)).getWeight(staleEndpoints,
248+notYetUsableEndpoints);
249+helper.getMetricRecorder()
250+ .recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight,
251+ImmutableList.of(helper.getChannelTarget()),
252+ImmutableList.of(locality));
253+newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
254+ }
255+256+if (staleEndpoints.get() > 0) {
257+helper.getMetricRecorder()
258+ .addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(),
259+ImmutableList.of(helper.getChannelTarget()),
260+ImmutableList.of(locality));
261+ }
262+if (notYetUsableEndpoints.get() > 0) {
263+helper.getMetricRecorder()
264+ .addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(),
265+ImmutableList.of(helper.getChannelTarget()), ImmutableList.of(locality));
266+ }
267+boolean weightsEffective = picker.updateWeight(newWeights);
268+if (!weightsEffective) {
269+helper.getMetricRecorder()
270+ .addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()),
271+ImmutableList.of(locality));
272+ }
239273 }
240274241275private void updateBalancingState(ConnectivityState state, SubchannelPicker picker) {
@@ -345,7 +379,7 @@ private final class UpdateWeightTask implements Runnable {
345379@Override
346380public void run() {
347381if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) {
348- ((WeightedRoundRobinPicker) currentPicker).updateWeight();
382+updateWeight((WeightedRoundRobinPicker) currentPicker);
349383 }
350384weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos,
351385TimeUnit.NANOSECONDS, timeService);
@@ -415,110 +449,76 @@ public void shutdown() {
415449416450@VisibleForTesting
417451static final class WeightedRoundRobinPicker extends SubchannelPicker {
418-private final List<ChildLbState> children;
419-private final Map<Subchannel, OrcaPerRequestReportListener> subchannelToReportListenerMap =
420-new HashMap<>();
452+// Parallel lists (column-based storage instead of normal row-based storage of List<Struct>).
453+// The ith element of children corresponds to the ith element of pickers, listeners, and even
454+// updateWeight(float[]).
455+private final List<ChildLbState> children; // May only be accessed from sync context
456+private final List<SubchannelPicker> pickers;
457+private final List<OrcaPerRequestReportListener> reportListeners;
421458private final boolean enableOobLoadReport;
422459private final float errorUtilizationPenalty;
423460private final AtomicInteger sequence;
424461private final int hashCode;
425-private final LoadBalancer.Helper helper;
426-private final String locality;
427462private volatile StaticStrideScheduler scheduler;
428463429464WeightedRoundRobinPicker(List<ChildLbState> children, boolean enableOobLoadReport,
430-float errorUtilizationPenalty, AtomicInteger sequence, LoadBalancer.Helper helper,
431-String locality) {
465+float errorUtilizationPenalty, AtomicInteger sequence) {
432466checkNotNull(children, "children");
433467Preconditions.checkArgument(!children.isEmpty(), "empty child list");
434468this.children = children;
469+List<SubchannelPicker> pickers = new ArrayList<>(children.size());
470+List<OrcaPerRequestReportListener> reportListeners = new ArrayList<>(children.size());
435471for (ChildLbState child : children) {
436472WeightedChildLbState wChild = (WeightedChildLbState) child;
437-for (WrrSubchannel subchannel : wChild.subchannels) {
438-this.subchannelToReportListenerMap
439- .put(subchannel, wChild.getOrCreateOrcaListener(errorUtilizationPenalty));
440- }
473+pickers.add(wChild.getCurrentPicker());
474+reportListeners.add(wChild.getOrCreateOrcaListener(errorUtilizationPenalty));
441475 }
476+this.pickers = pickers;
477+this.reportListeners = reportListeners;
442478this.enableOobLoadReport = enableOobLoadReport;
443479this.errorUtilizationPenalty = errorUtilizationPenalty;
444480this.sequence = checkNotNull(sequence, "sequence");
445-this.helper = helper;
446-this.locality = checkNotNull(locality, "locality");
447481448-// For equality we treat children as a set; use hash code as defined by Set
482+// For equality we treat pickers as a set; use hash code as defined by Set
449483int sum = 0;
450-for (ChildLbState child : children) {
451-sum += child.hashCode();
484+for (SubchannelPicker picker : pickers) {
485+sum += picker.hashCode();
452486 }
453487this.hashCode = sum
454488 ^ Boolean.hashCode(enableOobLoadReport)
455489 ^ Float.hashCode(errorUtilizationPenalty);
456-457-updateWeight();
458490 }
459491460492@Override
461493public PickResult pickSubchannel(PickSubchannelArgs args) {
462-ChildLbState childLbState = children.get(scheduler.pick());
463-WeightedChildLbState wChild = (WeightedChildLbState) childLbState;
464-PickResult pickResult = childLbState.getCurrentPicker().pickSubchannel(args);
494+int pick = scheduler.pick();
495+PickResult pickResult = pickers.get(pick).pickSubchannel(args);
465496Subchannel subchannel = pickResult.getSubchannel();
466497if (subchannel == null) {
467498return pickResult;
468499 }
469500if (!enableOobLoadReport) {
470501return PickResult.withSubchannel(subchannel,
471502OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
472-subchannelToReportListenerMap.getOrDefault(subchannel,
473-wChild.getOrCreateOrcaListener(errorUtilizationPenalty))));
503+reportListeners.get(pick)));
474504 } else {
475505return PickResult.withSubchannel(subchannel);
476506 }
477507 }
478508479-private void updateWeight() {
480-float[] newWeights = new float[children.size()];
481-AtomicInteger staleEndpoints = new AtomicInteger();
482-AtomicInteger notYetUsableEndpoints = new AtomicInteger();
483-for (int i = 0; i < children.size(); i++) {
484-double newWeight = ((WeightedChildLbState) children.get(i)).getWeight(staleEndpoints,
485-notYetUsableEndpoints);
486-// TODO: add locality label once available
487-helper.getMetricRecorder()
488- .recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight,
489-ImmutableList.of(helper.getChannelTarget()),
490-ImmutableList.of(locality));
491-newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
492- }
493-if (staleEndpoints.get() > 0) {
494-// TODO: add locality label once available
495-helper.getMetricRecorder()
496- .addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(),
497-ImmutableList.of(helper.getChannelTarget()),
498-ImmutableList.of(locality));
499- }
500-if (notYetUsableEndpoints.get() > 0) {
501-// TODO: add locality label once available
502-helper.getMetricRecorder()
503- .addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(),
504-ImmutableList.of(helper.getChannelTarget()), ImmutableList.of(locality));
505- }
506-509+/** Returns {@code true} if weights are different than round_robin. */
510+private boolean updateWeight(float[] newWeights) {
507511this.scheduler = new StaticStrideScheduler(newWeights, sequence);
508-if (this.scheduler.usesRoundRobin()) {
509-// TODO: locality label once available
510-helper.getMetricRecorder()
511- .addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()),
512-ImmutableList.of(locality));
513- }
512+return !this.scheduler.usesRoundRobin();
514513 }
515514516515@Override
517516public String toString() {
518517return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class)
519518 .add("enableOobLoadReport", enableOobLoadReport)
520519 .add("errorUtilizationPenalty", errorUtilizationPenalty)
521- .add("list", children).toString();
520+ .add("pickers", pickers)
521+ .toString();
522522 }
523523524524@VisibleForTesting
@@ -545,8 +545,8 @@ public boolean equals(Object o) {
545545&& sequence == other.sequence
546546&& enableOobLoadReport == other.enableOobLoadReport
547547&& Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0
548-&& children.size() == other.children.size()
549-&& new HashSet<>(children).containsAll(other.children);
548+&& pickers.size() == other.pickers.size()
549+&& new HashSet<>(pickers).containsAll(other.pickers);
550550 }
551551 }
552552