@@ -2,22 +2,22 @@ package gg.beemo.vanilla.rpc
2
2
3
3
import com.google.protobuf.Empty
4
4
import gg.beemo.latte.logging.Log
5
- import gg.beemo.latte.proto.ClusterConfigRequest
6
- import gg.beemo.latte.proto.ClusterConfigResponse
7
- import gg.beemo.latte.proto.clusterConfigResponse
8
5
import gg.beemo.latte.proto.ClusteringGrpcKt
9
6
import gg.beemo.latte.proto.GetClusterConfigRequest
10
7
import gg.beemo.latte.proto.GetClusterConfigResponse
11
8
import gg.beemo.latte.proto.GuildState
12
9
import gg.beemo.latte.proto.LookupGuildClusterRequest
13
10
import gg.beemo.latte.proto.LookupGuildClusterResponse
14
11
import gg.beemo.latte.proto.ShardIdentifier
15
- import gg.beemo.latte.proto.shardIdentifier
16
12
import gg.beemo.latte.proto.UpdateGuildStateRequest
17
13
import gg.beemo.latte.proto.getClusterConfigResponse
18
14
import gg.beemo.latte.proto.lookupGuildClusterResponse
15
+ import gg.beemo.latte.proto.shardIdentifier
16
+ import gg.beemo.vanilla.Config
17
+ import io.grpc.Status
19
18
import kotlinx.coroutines.flow.Flow
20
19
import java.util.HashMap
20
+ import kotlin.math.min
21
21
22
22
data class ClusterConfig (
23
23
val clusterId : String ,
@@ -31,27 +31,32 @@ data class GuildStatus(
31
31
)
32
32
33
33
class GrpcClusteringService : ClusteringGrpcKt .ClusteringCoroutineImplBase () {
34
-
35
34
private val log by Log
36
35
37
36
private val clusters = HashMap <String , ClusterConfig >()
38
37
private val guilds = HashMap <Long , GuildStatus >()
39
38
40
39
override suspend fun getClusterConfig (request : GetClusterConfigRequest ): GetClusterConfigResponse {
41
40
log.info(" Received cluster config request from cluster ID '${request.clusterId} '" )
42
- this .clusters[request.clusterId] = ClusterConfig (
43
- clusterId = request.clusterId,
44
- grpcEndpoint = request.grpcEndpoint,
45
- )
46
- // TODO Return correct shard mapping
47
- return getClusterConfigResponse {
48
- this .shards + = listOf (
49
- shardIdentifier {
50
- this .clusterId = " lol"
51
- this .shardId = 0
52
- this .shardCount = 1
53
- },
41
+ this .clusters[request.clusterId] =
42
+ ClusterConfig (
43
+ clusterId = request.clusterId,
44
+ grpcEndpoint = request.grpcEndpoint,
54
45
)
46
+
47
+ val clusterIndex = 0 // TODO map cluster id to index
48
+ val shardRange = getClusterShardRange(clusterIndex, Config .TEA_SHARD_COUNT , Config .TEA_CLUSTER_COUNT )
49
+ val shards =
50
+ shardRange.map { shardId ->
51
+ shardIdentifier {
52
+ this .clusterId = request.clusterId
53
+ this .shardId = shardId
54
+ this .shardCount = Config .TEA_SHARD_COUNT
55
+ }
56
+ }
57
+
58
+ return getClusterConfigResponse {
59
+ this .shards + = shards
55
60
}
56
61
}
57
62
@@ -60,7 +65,11 @@ class GrpcClusteringService : ClusteringGrpcKt.ClusteringCoroutineImplBase() {
60
65
val shard = update.shard
61
66
log.debug(
62
67
" Guild {} in Cluster {} Shard {}/{} has changed state to {}" ,
63
- update.guildId, shard.clusterId, shard.shardId, shard.clusterId, update.state,
68
+ update.guildId,
69
+ shard.clusterId,
70
+ shard.shardId,
71
+ shard.clusterId,
72
+ update.state,
64
73
)
65
74
if (! clusters.containsKey(shard.clusterId)) {
66
75
log.warn(" Unknown cluster {} in guild update for {}" , shard.clusterId, update.guildId)
@@ -75,14 +84,40 @@ class GrpcClusteringService : ClusteringGrpcKt.ClusteringCoroutineImplBase() {
75
84
}
76
85
77
86
override suspend fun lookupGuildCluster (request : LookupGuildClusterRequest ): LookupGuildClusterResponse {
78
- val guild = guilds[request.guildId]
79
- requireNotNull(guild) // TODO How to properly return errors in gRPC?
80
- val cluster = clusters[guild.shard.clusterId]
81
- requireNotNull(cluster) // TODO Same as above
87
+ val guild = guilds[request.guildId] ? : throw Status .NOT_FOUND .withDescription(" Guild not found" ).asRuntimeException()
88
+ val cluster = clusters[guild.shard.clusterId] ? : throw Status .NOT_FOUND .withDescription(" Cluster not found" ).asRuntimeException()
82
89
return lookupGuildClusterResponse {
83
90
this .clusterId = cluster.clusterId
84
91
this .grpcEndpoint = cluster.grpcEndpoint
85
92
}
86
93
}
87
94
95
+ private fun getClusterShardRange (
96
+ cluster : Int ,
97
+ totalShards : Int ,
98
+ totalClusters : Int ,
99
+ ): IntRange {
100
+ val numShardsForNormalCluster = totalShards / totalClusters
101
+ val extraShards = totalShards % totalClusters
102
+
103
+ // If the shard cluster is within the first 0 to (extraShards - 1) shard clusters,
104
+ // we will allocate one of the extra shards to it.
105
+ val numCommandedShards =
106
+ if (extraShards > 0 && cluster < extraShards) {
107
+ numShardsForNormalCluster + 1
108
+ } else {
109
+ numShardsForNormalCluster
110
+ }
111
+
112
+ val firstShardNumber =
113
+ if (extraShards > 0 ) {
114
+ cluster * numShardsForNormalCluster + min(cluster, extraShards - 1 )
115
+ } else {
116
+ cluster * numShardsForNormalCluster
117
+ }
118
+
119
+ val lastShardNumber = firstShardNumber + numCommandedShards - 1
120
+
121
+ return firstShardNumber.. lastShardNumber
122
+ }
88
123
}
0 commit comments