xds: Fix load reporting when pick first is used for locality-routing.… · grpc/grpc-java@1dae144
@@ -27,6 +27,7 @@
2727import io.grpc.ClientStreamTracer;
2828import io.grpc.ClientStreamTracer.StreamInfo;
2929import io.grpc.ConnectivityState;
30+import io.grpc.ConnectivityStateInfo;
3031import io.grpc.EquivalentAddressGroup;
3132import io.grpc.InternalLogId;
3233import io.grpc.LoadBalancer;
5960import java.util.Map;
6061import java.util.Objects;
6162import java.util.concurrent.atomic.AtomicLong;
63+import java.util.concurrent.atomic.AtomicReference;
6264import javax.annotation.Nullable;
63656466/**
@@ -77,10 +79,8 @@ final class ClusterImplLoadBalancer extends LoadBalancer {
7779Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING"))
7880 || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING"));
798180-private static final Attributes.Key<ClusterLocalityStats> ATTR_CLUSTER_LOCALITY_STATS =
81-Attributes.Key.create("io.grpc.xds.ClusterImplLoadBalancer.clusterLocalityStats");
82-private static final Attributes.Key<String> ATTR_CLUSTER_LOCALITY_NAME =
83-Attributes.Key.create("io.grpc.xds.ClusterImplLoadBalancer.clusterLocalityName");
82+private static final Attributes.Key<AtomicReference<ClusterLocality>> ATTR_CLUSTER_LOCALITY =
83+Attributes.Key.create("io.grpc.xds.ClusterImplLoadBalancer.clusterLocality");
84848585private final XdsLogger logger;
8686private final Helper helper;
@@ -213,36 +213,45 @@ public void updateBalancingState(ConnectivityState newState, SubchannelPicker ne
213213@Override
214214public Subchannel createSubchannel(CreateSubchannelArgs args) {
215215List<EquivalentAddressGroup> addresses = withAdditionalAttributes(args.getAddresses());
216-Locality locality = args.getAddresses().get(0).getAttributes().get(
217-InternalXdsAttributes.ATTR_LOCALITY); // all addresses should be in the same locality
218-String localityName = args.getAddresses().get(0).getAttributes().get(
219-InternalXdsAttributes.ATTR_LOCALITY_NAME);
220-// Endpoint addresses resolved by ClusterResolverLoadBalancer should always contain
221-// attributes with its locality, including endpoints in LOGICAL_DNS clusters.
222-// In case of not (which really shouldn't), loads are aggregated under an empty locality.
223-if (locality == null) {
224-locality = Locality.create("", "", "");
225-localityName = "";
226- }
227-final ClusterLocalityStats localityStats =
228- (lrsServerInfo == null)
229- ? null
230- : xdsClient.addClusterLocalityStats(lrsServerInfo, cluster,
231-edsServiceName, locality);
232-216+// This value for ClusterLocality is not recommended for general use.
217+// Currently, we extract locality data from the first address, even before the subchannel is
218+// READY.
219+// This is mainly to accommodate scenarios where a Load Balancing API (like "pick first")
220+// might return the subchannel before it is READY. Typically, we wouldn't report load for such
221+// selections because the channel will disregard the chosen (not-ready) subchannel.
222+// However, we needed to ensure this case is handled.
223+ClusterLocality clusterLocality = createClusterLocalityFromAttributes(
224+args.getAddresses().get(0).getAttributes());
225+AtomicReference<ClusterLocality> localityAtomicReference = new AtomicReference<>(
226+clusterLocality);
233227Attributes attrs = args.getAttributes().toBuilder()
234- .set(ATTR_CLUSTER_LOCALITY_STATS, localityStats)
235- .set(ATTR_CLUSTER_LOCALITY_NAME, localityName)
228+ .set(ATTR_CLUSTER_LOCALITY, localityAtomicReference)
236229 .build();
237230args = args.toBuilder().setAddresses(addresses).setAttributes(attrs).build();
238231final Subchannel subchannel = delegate().createSubchannel(args);
239232240233return new ForwardingSubchannel() {
234+@Override
235+public void start(SubchannelStateListener listener) {
236+delegate().start(new SubchannelStateListener() {
237+@Override
238+public void onSubchannelState(ConnectivityStateInfo newState) {
239+if (newState.getState().equals(ConnectivityState.READY)) {
240+// Get locality based on the connected address attributes
241+ClusterLocality updatedClusterLocality = createClusterLocalityFromAttributes(
242+subchannel.getConnectedAddressAttributes());
243+ClusterLocality oldClusterLocality = localityAtomicReference
244+ .getAndSet(updatedClusterLocality);
245+oldClusterLocality.release();
246+ }
247+listener.onSubchannelState(newState);
248+ }
249+ });
250+ }
251+241252@Override
242253public void shutdown() {
243-if (localityStats != null) {
244-localityStats.release();
245- }
254+localityAtomicReference.get().release();
246255delegate().shutdown();
247256 }
248257@@ -274,6 +283,28 @@ private List<EquivalentAddressGroup> withAdditionalAttributes(
274283return newAddresses;
275284 }
276285286+private ClusterLocality createClusterLocalityFromAttributes(Attributes addressAttributes) {
287+Locality locality = addressAttributes.get(InternalXdsAttributes.ATTR_LOCALITY);
288+String localityName = addressAttributes.get(InternalXdsAttributes.ATTR_LOCALITY_NAME);
289+290+// Endpoint addresses resolved by ClusterResolverLoadBalancer should always contain
291+// attributes with its locality, including endpoints in LOGICAL_DNS clusters.
292+// In case of not (which really shouldn't), loads are aggregated under an empty
293+// locality.
294+if (locality == null) {
295+locality = Locality.create("", "", "");
296+localityName = "";
297+ }
298+299+final ClusterLocalityStats localityStats =
300+ (lrsServerInfo == null)
301+ ? null
302+ : xdsClient.addClusterLocalityStats(lrsServerInfo, cluster,
303+edsServiceName, locality);
304+305+return new ClusterLocality(localityStats, localityName);
306+ }
307+277308@Override
278309protected Helper delegate() {
279310return helper;
@@ -361,18 +392,23 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {
361392"Cluster max concurrent requests limit exceeded"));
362393 }
363394 }
364-final ClusterLocalityStats stats =
365-result.getSubchannel().getAttributes().get(ATTR_CLUSTER_LOCALITY_STATS);
366-if (stats != null) {
367-String localityName =
368-result.getSubchannel().getAttributes().get(ATTR_CLUSTER_LOCALITY_NAME);
369-args.getPickDetailsConsumer().addOptionalLabel("grpc.lb.locality", localityName);
370-371-ClientStreamTracer.Factory tracerFactory = new CountingStreamTracerFactory(
372-stats, inFlights, result.getStreamTracerFactory());
373-ClientStreamTracer.Factory orcaTracerFactory = OrcaPerRequestUtil.getInstance()
374- .newOrcaClientStreamTracerFactory(tracerFactory, new OrcaPerRpcListener(stats));
375-return PickResult.withSubchannel(result.getSubchannel(), orcaTracerFactory);
395+final AtomicReference<ClusterLocality> clusterLocality =
396+result.getSubchannel().getAttributes().get(ATTR_CLUSTER_LOCALITY);
397+398+if (clusterLocality != null) {
399+ClusterLocalityStats stats = clusterLocality.get().getClusterLocalityStats();
400+if (stats != null) {
401+String localityName =
402+result.getSubchannel().getAttributes().get(ATTR_CLUSTER_LOCALITY).get()
403+ .getClusterLocalityName();
404+args.getPickDetailsConsumer().addOptionalLabel("grpc.lb.locality", localityName);
405+406+ClientStreamTracer.Factory tracerFactory = new CountingStreamTracerFactory(
407+stats, inFlights, result.getStreamTracerFactory());
408+ClientStreamTracer.Factory orcaTracerFactory = OrcaPerRequestUtil.getInstance()
409+ .newOrcaClientStreamTracerFactory(tracerFactory, new OrcaPerRpcListener(stats));
410+return PickResult.withSubchannel(result.getSubchannel(), orcaTracerFactory);
411+ }
376412 }
377413 }
378414return result;
@@ -447,4 +483,33 @@ public void onLoadReport(MetricReport report) {
447483stats.recordBackendLoadMetricStats(report.getNamedMetrics());
448484 }
449485 }
486+487+/**
488+ * Represents the {@link ClusterLocalityStats} and network locality name of a cluster.
489+ */
490+static final class ClusterLocality {
491+private final ClusterLocalityStats clusterLocalityStats;
492+private final String clusterLocalityName;
493+494+@VisibleForTesting
495+ClusterLocality(ClusterLocalityStats localityStats, String localityName) {
496+this.clusterLocalityStats = localityStats;
497+this.clusterLocalityName = localityName;
498+ }
499+500+ClusterLocalityStats getClusterLocalityStats() {
501+return clusterLocalityStats;
502+ }
503+504+String getClusterLocalityName() {
505+return clusterLocalityName;
506+ }
507+508+@VisibleForTesting
509+void release() {
510+if (clusterLocalityStats != null) {
511+clusterLocalityStats.release();
512+ }
513+ }
514+ }
450515}