HelloWood

gRPC 使用自定义的 LoadBalancer

2020-09-29

gRPC 使用自定义的 LoadBalancer

gRPC 中提供了 pick_first, round_robin, grpclb 等负载均衡的实现,默认使用 pick_first;引入 grpc-services 依赖后,round_robin 会由 HealthCheckingRoundRobin 实现,该负载均衡支持检查 Subchannel 的健康状态

LoadBalancer 相关的主要类包括 LoadBalancerProvider, LoadBalancer, SubchannelPicker, LoadBalancer.SubchannelStateListener,因此要实现自定义的 LoadBalancer,只需分别实现这几个类即可

实现自定义的 LoadBalancer

自定义实现一个轮询策略的负载均衡器

  • CustomLoadBalancerProvider.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
/**
 * Provider that exposes {@link CustomLoadBalancer} under the policy name
 * {@code custom_round_robin}.
 */
public class CustomLoadBalancerProvider extends LoadBalancerProvider {

    /** Name reported by {@link #getPolicyName()}. */
    private static final String POLICY_NAME = "custom_round_robin";

    /** Value reported by {@link #getPriority()}. */
    private static final int PRIORITY = 10;

    /**
     * @return the policy name of this provider, {@code custom_round_robin}
     */
    @Override
    public String getPolicyName() {
        return POLICY_NAME;
    }

    /**
     * @return the fixed priority of this provider, {@code 10}
     */
    @Override
    public int getPriority() {
        return PRIORITY;
    }

    /**
     * @return always {@code true}; this provider has no availability precondition
     */
    @Override
    public boolean isAvailable() {
        return true;
    }

    /**
     * Creates a new {@link CustomLoadBalancer} bound to the given helper.
     *
     * @param helper helper handed to the new balancer
     * @return a fresh balancer instance
     */
    @Override
    public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) {
        return new CustomLoadBalancer(helper);
    }
}
  • CustomLoadBalancer.java

在 CustomLoadBalancer 中实现了地址的处理,根据地址创建 Subchannel,并启动 Subchannel 状态监听器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
@Slf4j
public class CustomLoadBalancer extends LoadBalancer {

public static final Attributes.Key<Ref<ConnectivityState>> STATE_INFO = Attributes.Key.create("state-info");

private final Helper helper;

Map<EquivalentAddressGroup, Subchannel> subchannelMap = new ConcurrentHashMap<>();

public CustomLoadBalancer(Helper helper) {
this.helper = helper;
}

@Override
public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) {
log.info("处理地址:{}", resolvedAddresses.getAddresses().toString());

// 将解析的地址分割成单个 Address
List<EquivalentAddressGroup> latestAddresses = resolvedAddresses.getAddresses()
.stream()
.flatMap(this::splitAddressCollection)
.distinct()
.collect(Collectors.toList());

// 已经存在的地址
Set<EquivalentAddressGroup> originAddresses = subchannelMap.keySet();

// 对新的 Address 创建 Subchannel
Map<EquivalentAddressGroup, Subchannel> newSubchannelMap = latestAddresses.stream()
.filter(e -> !originAddresses.contains(e))
.map(this::buildCreateSubchannelArgs)
.map(helper::createSubchannel)
.map(this::processSubchannel)
.collect(Collectors.toConcurrentMap(Subchannel::getAddresses, s -> s));

// 将已存在的 Subchannel 放到新的集合中
originAddresses.stream()
.filter(latestAddresses::contains)
.forEach(e -> newSubchannelMap.put(e, subchannelMap.get(e)));


// 关闭需要移除的 Subchannel
originAddresses.stream()
.filter(e -> !latestAddresses.contains(e))
.map(e -> subchannelMap.get(e))
.forEach(Subchannel::shutdown);

subchannelMap = newSubchannelMap;
}

private CreateSubchannelArgs buildCreateSubchannelArgs(EquivalentAddressGroup e) {
return CreateSubchannelArgs.newBuilder()
.setAddresses(e)
.setAttributes(Attributes.newBuilder()
.set(STATE_INFO, new Ref<>(IDLE))
.build())
.build();
}

private Stream<EquivalentAddressGroup> splitAddressCollection(EquivalentAddressGroup equivalentAddressGroup) {
Attributes attributes = equivalentAddressGroup.getAttributes();
return equivalentAddressGroup.getAddresses()
.stream()
.map(e -> new EquivalentAddressGroup(e, attributes));
}

private Subchannel processSubchannel(Subchannel subchannel) {
if (subchannelMap.containsValue(subchannel)) {
log.info("{} {} 已经存在", subchannel, subchannel.getAddresses());
return subchannel;
}

subchannel.start(new CustomSubchannelStateListener(this, subchannel, helper));
subchannel.requestConnection();
return subchannel;
}


@Override
public void handleNameResolutionError(Status error) {
log.info("命名解析失败:{}", error);
helper.updateBalancingState(ConnectivityState.TRANSIENT_FAILURE, new CustomSubchannelPicker(PickResult.withNoResult()));
}

@Override
public void shutdown() {
subchannelMap.values()
.stream()
.peek(s -> log.info("关闭 {} {}", s, s.getAddresses()))
.forEach(Subchannel::shutdown);
}

public Map<EquivalentAddressGroup, Subchannel> getSubchannelMap() {
return new ConcurrentHashMap<>(this.subchannelMap);
}
}
  • CustomSubchannelStateListener.java

Subchannel 的状态监听器,当 Subchannel 状态发生变化时进行处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
/**
 * Listener attached to a single subchannel. On every state change it records the
 * new state in the subchannel's {@code STATE_INFO} ref and republishes the
 * balancer's overall state and picker.
 */
@Slf4j
class CustomSubchannelStateListener implements LoadBalancer.SubchannelStateListener {
    private final LoadBalancer.Subchannel subchannel;
    private final LoadBalancer.Helper helper;
    private final CustomLoadBalancer loadBalancer;

    /**
     * @param customLoadBalancer balancer that owns the subchannel; queried for the full subchannel set
     * @param subchannel         the subchannel this listener observes
     * @param helper             helper used to publish the balancing state
     */
    public CustomSubchannelStateListener(CustomLoadBalancer customLoadBalancer,
                                         LoadBalancer.Subchannel subchannel,
                                         LoadBalancer.Helper helper) {
        this.loadBalancer = customLoadBalancer;
        this.subchannel = subchannel;
        this.helper = helper;
    }

    /**
     * Processes a state change of the observed subchannel.
     *
     * @param stateInfo the new connectivity state of the subchannel
     */
    @Override
    public void onSubchannelState(ConnectivityStateInfo stateInfo) {
        Ref<ConnectivityState> stateInfoRef = subchannel.getAttributes().get(STATE_INFO);
        ConnectivityState currentState = stateInfoRef.getValue();
        ConnectivityState newState = stateInfo.getState();

        log.info("{} 状态变化:{}", subchannel, newState);

        if (newState == SHUTDOWN) {
            log.info("关闭 {}", subchannel);
            return;
        }

        // A subchannel that fell back to IDLE has no connection; ask for a new one
        // right away so it can rejoin the rotation. (Requesting a connection on
        // READY, as before, did nothing useful because it is already connected.)
        if (newState == IDLE) {
            subchannel.requestConnection();
        }

        // Keep the recorded TRANSIENT_FAILURE while the subchannel is merely
        // retrying (CONNECTING / IDLE); only another state replaces it.
        if (currentState == TRANSIENT_FAILURE) {
            if (newState == CONNECTING || newState == IDLE) {
                log.info("{} 建立连接或者失败", subchannel);
                return;
            }
        }

        stateInfoRef.setValue(newState);
        updateLoadBalancerState();
    }

    /**
     * Publishes READY with a picker over all READY subchannels of the balancer, or
     * CONNECTING with an empty-result picker when none is READY.
     */
    private void updateLoadBalancerState() {
        List<LoadBalancer.Subchannel> readySubchannels = loadBalancer.getSubchannelMap()
                .values()
                .stream()
                .filter(s -> s.getAttributes().get(STATE_INFO).getValue() == READY)
                .collect(Collectors.toList());

        if (readySubchannels.isEmpty()) {
            log.info("更新 LB 状态为 CONNECTING,没有 READY 的 Subchannel");
            helper.updateBalancingState(CONNECTING, new CustomSubchannelPicker(LoadBalancer.PickResult.withNoResult()));
        } else {
            // Pass the list itself: an Object[] would be spread as varargs and only
            // its first element would fill the single placeholder.
            log.info("更新 LB 状态为 READY,Subchannel 为:{}", readySubchannels);
            helper.updateBalancingState(READY, new CustomSubchannelPicker(readySubchannels));
        }
    }
}

  • CustomSubchannelPicker.java

实现选择 Subchannel 的逻辑,这里使用的是轮询策略

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
/**
 * Picker that either returns one fixed {@link LoadBalancer.PickResult} for every
 * pick, or rotates through a list of subchannels in round-robin order.
 */
@Slf4j
class CustomSubchannelPicker extends LoadBalancer.SubchannelPicker {

    /** Monotonic pick counter; the list position is derived from it modulo the list size. */
    private final AtomicInteger index = new AtomicInteger();

    /** Subchannels to rotate through; {@code null} when a fixed result is used. */
    private final List<LoadBalancer.Subchannel> subchannelList;

    /** Fixed result returned for every pick; {@code null} when rotating through subchannels. */
    private final LoadBalancer.PickResult pickResult;

    /**
     * Creates a picker that answers every pick with the same result.
     *
     * @param pickResult the result to return for every pick
     */
    public CustomSubchannelPicker(LoadBalancer.PickResult pickResult) {
        this.pickResult = pickResult;
        this.subchannelList = null;
    }

    /**
     * Creates a picker that rotates through the given subchannels.
     *
     * @param subchannelList subchannels to pick from, in rotation order
     */
    public CustomSubchannelPicker(List<LoadBalancer.Subchannel> subchannelList) {
        this.subchannelList = subchannelList;
        this.pickResult = null;
    }

    /**
     * Returns the fixed result if one was configured, otherwise the next subchannel
     * in the rotation.
     *
     * @param args pick arguments (not inspected)
     * @return the pick result
     */
    @Override
    public LoadBalancer.PickResult pickSubchannel(LoadBalancer.PickSubchannelArgs args) {
        if (pickResult != null) {
            log.info("有错误的 pickResult,返回:{}", pickResult);
            return pickResult;
        }
        LoadBalancer.PickResult next = nextSubchannel(args);
        log.info("Pick 下一个 Subchannel:{}", next.getSubchannel());
        return next;
    }

    /**
     * Selects the next subchannel of the rotation.
     *
     * @param args pick arguments (not inspected)
     * @return a result wrapping the selected subchannel, or an empty result when the list is empty
     */
    private LoadBalancer.PickResult nextSubchannel(LoadBalancer.PickSubchannelArgs args) {
        int size = subchannelList.size();
        if (size == 0) {
            return LoadBalancer.PickResult.withNoResult();
        }

        // One atomic increment per pick, so concurrent picks can never read an
        // out-of-range index; floorMod keeps the position valid even after the
        // counter overflows to negative values.
        int position = Math.floorMod(index.getAndIncrement(), size);
        LoadBalancer.Subchannel subchannel = subchannelList.get(position);

        log.info("返回 Subchannel:{}", subchannel);
        return LoadBalancer.PickResult.withSubchannel(subchannel);
    }
}
  • Ref.java

用于保存 Subchannel 状态的工具类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
/**
 * Minimal mutable holder for a single value of type {@code T}.
 *
 * @param <T> type of the held value
 */
final class Ref<T> {

    /** The value currently held. */
    T value;

    /**
     * Creates a holder with an initial value.
     *
     * @param initialValue the value to hold at first
     */
    Ref(T initialValue) {
        value = initialValue;
    }

    /**
     * @return the value currently held
     */
    public T getValue() {
        return this.value;
    }

    /**
     * Replaces the held value.
     *
     * @param newValue the value to hold from now on
     */
    public void setValue(T newValue) {
        value = newValue;
    }
}
Tags: gRPC