xds: WRRPicker must not access unsynchronized data in ChildLbState · grpc/grpc-java@0d47f5b

@@ -44,11 +44,10 @@

4444

import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener;

4545

import io.grpc.xds.orca.OrcaPerRequestUtil;

4646

import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener;

47+

import java.util.ArrayList;

4748

import java.util.Collection;

48-

import java.util.HashMap;

4949

import java.util.HashSet;

5050

import java.util.List;

51-

import java.util.Map;

5251

import java.util.Random;

5352

import java.util.Set;

5453

import java.util.concurrent.ScheduledExecutorService;

@@ -233,9 +232,44 @@ protected void updateOverallBalancingState() {

233232

}

234233235234

private SubchannelPicker createReadyPicker(Collection<ChildLbState> activeList) {

236-

return new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList),

237-

config.enableOobLoadReport, config.errorUtilizationPenalty, sequence, getHelper(),

238-

locality);

235+

WeightedRoundRobinPicker picker = new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList),

236+

config.enableOobLoadReport, config.errorUtilizationPenalty, sequence);

237+

updateWeight(picker);

238+

return picker;

239+

}

240+241+

private void updateWeight(WeightedRoundRobinPicker picker) {

242+

Helper helper = getHelper();

243+

float[] newWeights = new float[picker.children.size()];

244+

AtomicInteger staleEndpoints = new AtomicInteger();

245+

AtomicInteger notYetUsableEndpoints = new AtomicInteger();

246+

for (int i = 0; i < picker.children.size(); i++) {

247+

double newWeight = ((WeightedChildLbState) picker.children.get(i)).getWeight(staleEndpoints,

248+

notYetUsableEndpoints);

249+

helper.getMetricRecorder()

250+

.recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight,

251+

ImmutableList.of(helper.getChannelTarget()),

252+

ImmutableList.of(locality));

253+

newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;

254+

}

255+256+

if (staleEndpoints.get() > 0) {

257+

helper.getMetricRecorder()

258+

.addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(),

259+

ImmutableList.of(helper.getChannelTarget()),

260+

ImmutableList.of(locality));

261+

}

262+

if (notYetUsableEndpoints.get() > 0) {

263+

helper.getMetricRecorder()

264+

.addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(),

265+

ImmutableList.of(helper.getChannelTarget()), ImmutableList.of(locality));

266+

}

267+

boolean weightsEffective = picker.updateWeight(newWeights);

268+

if (!weightsEffective) {

269+

helper.getMetricRecorder()

270+

.addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()),

271+

ImmutableList.of(locality));

272+

}

239273

}

240274241275

private void updateBalancingState(ConnectivityState state, SubchannelPicker picker) {

@@ -345,7 +379,7 @@ private final class UpdateWeightTask implements Runnable {

345379

@Override

346380

public void run() {

347381

if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) {

348-

((WeightedRoundRobinPicker) currentPicker).updateWeight();

382+

updateWeight((WeightedRoundRobinPicker) currentPicker);

349383

}

350384

weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos,

351385

TimeUnit.NANOSECONDS, timeService);

@@ -415,110 +449,76 @@ public void shutdown() {

415449416450

@VisibleForTesting

417451

static final class WeightedRoundRobinPicker extends SubchannelPicker {

418-

private final List<ChildLbState> children;

419-

private final Map<Subchannel, OrcaPerRequestReportListener> subchannelToReportListenerMap =

420-

new HashMap<>();

452+

// Parallel lists (column-based storage instead of normal row-based storage of List<Struct>).

453+

// The ith element of children corresponds to the ith element of pickers, listeners, and even

454+

// updateWeight(float[]).

455+

private final List<ChildLbState> children; // May only be accessed from sync context

456+

private final List<SubchannelPicker> pickers;

457+

private final List<OrcaPerRequestReportListener> reportListeners;

421458

private final boolean enableOobLoadReport;

422459

private final float errorUtilizationPenalty;

423460

private final AtomicInteger sequence;

424461

private final int hashCode;

425-

private final LoadBalancer.Helper helper;

426-

private final String locality;

427462

private volatile StaticStrideScheduler scheduler;

428463429464

WeightedRoundRobinPicker(List<ChildLbState> children, boolean enableOobLoadReport,

430-

float errorUtilizationPenalty, AtomicInteger sequence, LoadBalancer.Helper helper,

431-

String locality) {

465+

float errorUtilizationPenalty, AtomicInteger sequence) {

432466

checkNotNull(children, "children");

433467

Preconditions.checkArgument(!children.isEmpty(), "empty child list");

434468

this.children = children;

469+

List<SubchannelPicker> pickers = new ArrayList<>(children.size());

470+

List<OrcaPerRequestReportListener> reportListeners = new ArrayList<>(children.size());

435471

for (ChildLbState child : children) {

436472

WeightedChildLbState wChild = (WeightedChildLbState) child;

437-

for (WrrSubchannel subchannel : wChild.subchannels) {

438-

this.subchannelToReportListenerMap

439-

.put(subchannel, wChild.getOrCreateOrcaListener(errorUtilizationPenalty));

440-

}

473+

pickers.add(wChild.getCurrentPicker());

474+

reportListeners.add(wChild.getOrCreateOrcaListener(errorUtilizationPenalty));

441475

}

476+

this.pickers = pickers;

477+

this.reportListeners = reportListeners;

442478

this.enableOobLoadReport = enableOobLoadReport;

443479

this.errorUtilizationPenalty = errorUtilizationPenalty;

444480

this.sequence = checkNotNull(sequence, "sequence");

445-

this.helper = helper;

446-

this.locality = checkNotNull(locality, "locality");

447481448-

// For equality we treat children as a set; use hash code as defined by Set

482+

// For equality we treat pickers as a set; use hash code as defined by Set

449483

int sum = 0;

450-

for (ChildLbState child : children) {

451-

sum += child.hashCode();

484+

for (SubchannelPicker picker : pickers) {

485+

sum += picker.hashCode();

452486

}

453487

this.hashCode = sum

454488

^ Boolean.hashCode(enableOobLoadReport)

455489

^ Float.hashCode(errorUtilizationPenalty);

456-457-

updateWeight();

458490

}

459491460492

@Override

461493

public PickResult pickSubchannel(PickSubchannelArgs args) {

462-

ChildLbState childLbState = children.get(scheduler.pick());

463-

WeightedChildLbState wChild = (WeightedChildLbState) childLbState;

464-

PickResult pickResult = childLbState.getCurrentPicker().pickSubchannel(args);

494+

int pick = scheduler.pick();

495+

PickResult pickResult = pickers.get(pick).pickSubchannel(args);

465496

Subchannel subchannel = pickResult.getSubchannel();

466497

if (subchannel == null) {

467498

return pickResult;

468499

}

469500

if (!enableOobLoadReport) {

470501

return PickResult.withSubchannel(subchannel,

471502

OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(

472-

subchannelToReportListenerMap.getOrDefault(subchannel,

473-

wChild.getOrCreateOrcaListener(errorUtilizationPenalty))));

503+

reportListeners.get(pick)));

474504

} else {

475505

return PickResult.withSubchannel(subchannel);

476506

}

477507

}

478508479-

private void updateWeight() {

480-

float[] newWeights = new float[children.size()];

481-

AtomicInteger staleEndpoints = new AtomicInteger();

482-

AtomicInteger notYetUsableEndpoints = new AtomicInteger();

483-

for (int i = 0; i < children.size(); i++) {

484-

double newWeight = ((WeightedChildLbState) children.get(i)).getWeight(staleEndpoints,

485-

notYetUsableEndpoints);

486-

// TODO: add locality label once available

487-

helper.getMetricRecorder()

488-

.recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight,

489-

ImmutableList.of(helper.getChannelTarget()),

490-

ImmutableList.of(locality));

491-

newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;

492-

}

493-

if (staleEndpoints.get() > 0) {

494-

// TODO: add locality label once available

495-

helper.getMetricRecorder()

496-

.addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(),

497-

ImmutableList.of(helper.getChannelTarget()),

498-

ImmutableList.of(locality));

499-

}

500-

if (notYetUsableEndpoints.get() > 0) {

501-

// TODO: add locality label once available

502-

helper.getMetricRecorder()

503-

.addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(),

504-

ImmutableList.of(helper.getChannelTarget()), ImmutableList.of(locality));

505-

}

506-509+

/** Returns {@code true} if weights are different than round_robin. */

510+

private boolean updateWeight(float[] newWeights) {

507511

this.scheduler = new StaticStrideScheduler(newWeights, sequence);

508-

if (this.scheduler.usesRoundRobin()) {

509-

// TODO: locality label once available

510-

helper.getMetricRecorder()

511-

.addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()),

512-

ImmutableList.of(locality));

513-

}

512+

return !this.scheduler.usesRoundRobin();

514513

}

515514516515

@Override

517516

public String toString() {

518517

return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class)

519518

.add("enableOobLoadReport", enableOobLoadReport)

520519

.add("errorUtilizationPenalty", errorUtilizationPenalty)

521-

.add("list", children).toString();

520+

.add("pickers", pickers)

521+

.toString();

522522

}

523523524524

@VisibleForTesting

@@ -545,8 +545,8 @@ public boolean equals(Object o) {

545545

&& sequence == other.sequence

546546

&& enableOobLoadReport == other.enableOobLoadReport

547547

&& Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0

548-

&& children.size() == other.children.size()

549-

&& new HashSet<>(children).containsAll(other.children);

548+

&& pickers.size() == other.pickers.size()

549+

&& new HashSet<>(pickers).containsAll(other.pickers);

550550

}

551551

}

552552