实现 Redis 有序集合结构
本文进度对应的代码仓库:实现有序集合结构
在前面一节中,我们实现了 Redis 的集合结构。接下来,我们将实现 Redis 的有序集合结构。
有序集合是一个包含唯一元素的集合,每个元素都有一个分数(score),用于表示元素的顺序。
Redis 中的有序集合底层使用 Listpack 和 Skiplist 来实现。
当数据量小的时候,Redis 使用 Listpack 来存储有序集合。当数据量增大时,Redis 会将 Listpack 转换为 Skiplist。
具体实现
SkipList 结构
首先我们要实现一个 SkipList 结构。
在 datastruct
文件夹下创建一个 skiplist
文件夹,然后创建一个 skiplist.go
文件,里面实现 SkipList 结构。
const maxLevel = 16 // Maximum number of levels in the skip list
// Node represents a node in the skip list
type Node struct {
Member string
Score float64
Forward []*Node // Forward points at different levels
}
// SkipList represents a skip list
type SkipList struct {
header *Node // Header node
tail *Node // Tail node
level int // Current max level of the skip list
length int // Length of the skip list
rand *rand.Rand
}
结构解析:
-
Node
结构:Member
: 存储有序集合的成员名称,字符串类型。Score
: 存储成员对应的分数,浮点数类型,用于排序。Forward
: 这是一个指针切片 ([]*Node
)。跳表的核心在于多层链表,Forward[i]
指向该节点在第i
层的下一个节点。一个节点可以在多个层级中出现,层级越高,跨越的节点越多,查找速度越快。
-
SkipList
结构:header
: 指向跳表的头节点。头节点不存储实际数据,主要作为查找的起点,其Forward
切片包含了每一层的入口。tail
: 指向跳表的尾节点,方便快速定位到最后一个元素。level
: 记录当前跳表实际拥有的最大层数(从 1 开始计数)。初始为 1,随着插入新节点可能会增加。length
: 记录跳表中元素的数量(不包括头节点)。rand
: 用于生成随机层数,决定新插入节点的高度。
初始化 SkipList
我们需要一个构造函数来创建新的 SkipList 实例。
// New SkipList creates a new skip list
func NewSkipList() *SkipList {
header := &Node{
Forward: make([]*Node, maxLevel), // 初始化头节点的 Forward 切片
}
return &SkipList{
header: header,
level: 1, // 初始层级为 1
rand: rand.New(rand.NewSource(time.Now().UnixNano())), // 初始化随机数生成器
}
}
- 创建一个
header
节点,其Forward
切片的大小为预设的最大层数maxLevel
。所有初始指针都为nil
。 - 返回一个新的
SkipList
实例,header
指向刚创建的头节点,level
初始化为 1,并创建一个独立的随机数生成器rand
以确保随机性。
随机层级生成
新节点插入时需要确定其层数,这决定了它会出现在哪些层级的链表中。层数是随机生成的,但需要遵循一定的概率分布,通常是越高的层级概率越低。
// randomLevel generates a random level for the new node
func (sl *SkipList) randomLevel() int {
level := 1
// Increase level with 25% probability
for level < maxLevel && sl.rand.Float32() < 0.25 {
level++
}
return level
}
- 新节点的层级至少为 1。
- 以 25% 的概率增加层级,直到达到
maxLevel
或概率判定失败。这种设计使得大部分节点只在底层出现,少数节点出现在高层,形成金字塔结构,保证查找效率。
插入操作 (Insert
)
插入是 SkipList 的核心操作之一,需要同时维护排序性和多层链表结构。
// Insert inserts a new member with the given score into the skip list
func (sl *SkipList) Insert(member string, score float64) {
update := make([]*Node, maxLevel) // 存储每一层需要更新的前驱节点
x := sl.header
// 1. 查找插入位置:从最高层向下查找
for i := sl.level - 1; i >= 0; i-- {
// 在当前层向右查找,直到找到第一个 Score 更大或 Member 更大的节点
for x.Forward[i] != nil &&
(x.Forward[i].Score < score ||
(x.Forward[i].Score == score && x.Forward[i].Member < member)) {
x = x.Forward[i]
}
// 记录下这一层需要修改 Forward 指针的节点 (即新节点的前驱)
update[i] = x
}
// 此时 x 是最底层中小于新节点的最后一个节点
// 2. 生成新节点的随机层级
level := sl.randomLevel()
// 3. 更新 SkipList 的最大层级 (如果需要)
if level > sl.level {
// 如果新节点的层级超过当前最大层级,需要扩展 update 数组
for i := sl.level; i < level; i++ {
update[i] = sl.header // 新增层级的前驱节点是 header(因为这个新的层级为空)
}
sl.level = level // 更新 SkipList 的当前最大层级
}
// 4. 创建新节点
newNode := &Node{
Member: member,
Score: score,
Forward: make([]*Node, level), // Forward 切片大小为新节点的层级
}
// 5. 更新指针,将新节点链入 SkipList
for i := 0; i < level; i++ {
newNode.Forward[i] = update[i].Forward[i] // 新节点的 Forward 指向原前驱节点的下一个节点
update[i].Forward[i] = newNode // 前驱节点的 Forward 指向新节点
}
// 6. 更新尾节点指针 (如果新节点是最后一个节点)
if newNode.Forward[0] == nil {
sl.tail = newNode
}
sl.length++ // 更新 SkipList 的长度
}
插入步骤详解:
- 查找插入位置:
- 创建一个
update
切片,用于记录每一层查找到的、需要将Forward
指针指向新节点的前驱节点。 - 从当前 SkipList 的最高层 (
sl.level - 1
) 开始向下查找。 - 在每一层
i
,从当前节点x
开始向右遍历 (x = x.Forward[i]
),比较Score
和Member
(当 Score 相同时比较 Member 字典序),直到找到第一个大于等于待插入元素的节点,或者到达该层末尾。 - 将该层查找到的前驱节点(即最后一个小于待插入元素的节点)记录在
update[i]
中。 - 重复此过程直到最底层(第 0 层)。
- 创建一个
- 生成随机层级:调用
randomLevel()
为新节点确定一个随机的层数level
。 - 更新最大层级:如果新节点的
level
大于当前 SkipList 的sl.level
,说明需要增加 SkipList 的层数。将新增层级的前驱节点(在update
中对应位置)设置为header
,并更新sl.level
。 - 创建新节点:根据
member
,score
和计算出的level
创建新节点newNode
。 - 更新指针:遍历新节点的每一层(从 0 到
level - 1
):- 将新节点的
Forward[i]
指向update[i]
(前驱节点)原来的下一个节点 (update[i].Forward[i]
)。 - 将前驱节点
update[i]
的Forward[i]
指向新节点newNode
。这样就完成了新节点在第i
层的链入。
- 将新节点的
- 更新尾节点:如果新节点在最底层(第 0 层)的下一个节点是
nil
,说明它是新的尾节点,更新sl.tail
。 - 更新长度:增加
sl.length
。
这里使用一个举例帮助理解:
初始状态(只有一个空头节点)
[head] --> nil
- 跳表只有一个头节点,Forward指针全是 nil。
- 当前层数 = 1
Step 1️⃣:插入第一个元素,插入 “10”(score=10)
- 随机层数假设生成了 3层(假设 randomLevel() 返回 3)
- 插入新节点 “10”,需要更新 head 节点在 Level 0、1、2 的 forward。
Level 2: [head] --> [10] --> nil
Level 1: [head] --> [10] --> nil
Level 0: [head] --> [10] --> nil
head.Forward[2] = 10
head.Forward[1] = 10
head.Forward[0] = 10
- 新节点 “10” 自己 Forward 指向 nil
✅ 插入完成,跳表层数更新为 3。
Step 2️⃣:插入第二个元素,插入 “20”(score=20)
- 随机层数假设生成了 1层(randomLevel() 返回 1)
插入时:
- 从顶层开始(Level 2),发现 head.Forward[2] 是 “10”,“10” 的 score 小于 20,往后跳。
- 到了尾巴,降一层,直到 Level 0。
- 在 Level 0,“10” 后插入 “20”。
Level 2: [head] --> [10] --> nil
Level 1: [head] --> [10] --> nil
Level 0: [head] --> [10] --> [20] --> nil
- “20” 只有 Level 0 指针(单层)
- 更新链表:[10].Forward[0] = 20
✅ 插入完成,跳表层数仍然是 3。
Step 3️⃣:插入第三个元素,插入 “15”(score=15)
- 随机层数假设生成了 2层
插入时:
- 从 Level 2 开始,head 后是 “10”,“10” score 小于 15,继续跳。
- 后面是 nil,所以降到 Level 1。
- 在 Level 1,“10” 后面是 nil,继续降。
- 在 Level 0,“10” 后是 “20”,20 比 15 大,找到插入点!
所以在 “10” 后插入 “15”,在 Level 0 和 Level 1 两层连接。
最终结构:
Level 2: [head] --> [10] --> nil
Level 1: [head] --> [10] --> [15] --> nil
Level 0: [head] --> [10] --> [15] --> [20] --> nil
[10].Forward[1]
指向[15]
[10].Forward[0]
指向[15]
[15].Forward[1]
指向nil
[15].Forward[0]
指向[20]
插入完成!
总结插入过程的核心点:
每次插入:
- 从最高层开始往下走。
- 在每一层记录需要修改 forward 指针的位置(就是
update[i]
) - 随机生成新节点的层数。
- 在对应层数上调整前驱的 forward 指针,插入新节点。
- 如果新节点层数比跳表当前层数高,需要更新 head 的 forward。
SkipList 的特点:
✅ 每一层都是有序的! ✅ 插入过程中只有局部 forward 调整! ✅ 整体时间复杂度是 O(logN)!
如果再插一个 “5” 呢?(score=5)
- 随机层数假设是 1层。
- 因为 5 比所有已有节点小,所以直接在 head 后插入:
Level 2: [head] --> [10] --> nil
Level 1: [head] --> [10] --> [15] --> nil
Level 0: [head] --> [5] --> [10] --> [15] --> [20] --> nil
因为这是 Level 0 头插,其他层不动。
删除操作 (Delete
)
删除操作与插入类似,也需要先找到目标节点,然后修改相关节点的 Forward
指针。
// Delete removes an element from the skip list
func (sl *SkipList) Delete(member string, score float64) bool {
update := make([]*Node, maxLevel)
x := sl.header
// 1. 查找目标节点的前驱节点
for i := sl.level - 1; i >= 0; i-- {
for x.Forward[i] != nil &&
(x.Forward[i].Score < score ||
(x.Forward[i].Score == score && x.Forward[i].Member < member)) {
x = x.Forward[i]
}
update[i] = x // 记录每层的前驱
}
// 2. 定位目标节点
// x 现在是最底层目标节点的前驱,x.Forward[0] 可能是目标节点
targetNode := x.Forward[0]
// 3. 检查节点是否存在且匹配
if targetNode != nil && targetNode.Score == score && targetNode.Member == member {
// 4. 更新指针,在所有层级中移除目标节点
for i := 0; i < sl.level; i++ {
// 如果 update[i] 的下一个节点不是目标节点,说明目标节点不在这一层或更高层
if update[i].Forward[i] != targetNode {
break // 可以提前结束
}
// 将前驱节点的 Forward 指向目标节点的下一个节点,完成移除
update[i].Forward[i] = targetNode.Forward[i]
}
// 5. 更新尾节点指针 (如果删除的是尾节点)
if targetNode == sl.tail {
// 新的尾节点是 update[0] (最底层的前驱)
// 如果 update[0] 是 header,说明列表空了,tail 应为 nil
if update[0] == sl.header {
sl.tail = nil
} else {
sl.tail = update[0]
}
}
// 6. 更新 SkipList 的最大层级 (如果需要)
// 如果删除节点后,最高层变为空,则降低 SkipList 的 level
for sl.level > 1 && sl.header.Forward[sl.level-1] == nil {
sl.level--
}
sl.length-- // 更新长度
return true // 删除成功
}
return false // 未找到或不匹配,删除失败
}
删除步骤详解:
- 查找前驱节点:与插入操作的第一步完全相同,找到目标节点在每一层的前驱节点,并存储在
update
中。 - 定位目标节点:
update[0]
是最底层的前驱节点,那么targetNode = update[0].Forward[0]
就是潜在的目标节点。 - 检查匹配:验证
targetNode
是否存在 (!= nil
),并且其Score
和Member
是否与要删除的元素完全匹配。 - 更新指针:如果匹配成功,遍历 SkipList 的所有层级(从 0 到
sl.level - 1
):- 检查
update[i].Forward[i]
是否确实指向targetNode
。如果不是,说明targetNode
不在当前层或更高层,可以提前结束循环。 - 如果是,则将
update[i].Forward[i]
指向targetNode.Forward[i]
,相当于在第i
层跳过了targetNode
。
- 检查
- 更新尾节点:如果被删除的节点
targetNode
是原来的尾节点sl.tail
,需要更新sl.tail
为新的尾节点,即update[0]
(最底层的前驱节点)。特别地,如果update[0]
是header
,说明删除后列表为空,sl.tail
应设为nil
。 - 更新最大层级:删除节点后,可能导致某些高层变为空。从最高层向下检查,如果
sl.header.Forward[sl.level-1]
为nil
,说明该层已无节点,将sl.level
减 1。重复此过程直到找到非空层或sl.level
降为 1。 - 更新长度:减少
sl.length
。 - 返回结果:删除成功返回
true
,未找到或不匹配返回false
。
其他辅助方法
SkipList 还提供了一些用于查询的方法,例如按分数范围计数/获取成员、按排名获取成员、获取指定成员的排名等。这些方法的实现通常依赖于 SkipList 的有序性和多层结构来优化查找过程。
CountInRange(min, max float64) int
: 查找分数在[min, max]
区间内的第一个节点,然后向后遍历计数,直到分数超过max
。利用高层指针可以快速跳过不相关的部分。
// CountInRange counts elements with score between min and max
func (sl *SkipList) CountInRange(min, max float64) int {
count := 0
x := sl.header
// Find first node with score >= min
for i := sl.level - 1; i >= 0; i-- {
for x.Forward[i] != nil && x.Forward[i].Score < min {
x = x.Forward[i]
}
}
// Traverse nodes with score <= max
x = x.Forward[0]
for x != nil && x.Score <= max {
count++
x = x.Forward[0]
}
return count
}
RangeByScore(min, max float64, offset, count int) []string
: 类似CountInRange
,但收集成员名称,并支持offset
和count
用于分页。
// RangeByScore returns members with scores between min and max
func (sl *SkipList) RangeByScore(min, max float64, offset, count int) []string {
result := []string{}
x := sl.header
// Find first node with score >= min
for i := sl.level - 1; i >= 0; i-- {
for x.Forward[i] != nil && x.Forward[i].Score < min {
x = x.Forward[i]
}
}
// Traverse nodes with score <= max
x = x.Forward[0]
skipped := 0
for x != nil && x.Score <= max {
if offset < 0 || skipped >= offset {
result = append(result, x.Member)
// Stop if we've collected enough elements
if count > 0 && len(result) >= count {
break
}
} else {
skipped++
}
x = x.Forward[0]
}
return result
}
RangeByRank(start, stop int) []string
: 处理负数索引,然后从头节点开始,在最底层遍历start
步找到起始节点,接着收集stop - start + 1
个成员。
// RangeByRank returns members by rank (position)
func (sl *SkipList) RangeByRank(start, stop int) []string {
result := []string{}
// Handle negative indices and out of range
if start < 0 {
start = sl.length + start
}
if stop < 0 {
stop = sl.length + stop
}
if start < 0 {
start = 0
}
if stop >= sl.length {
stop = sl.length - 1
}
if start > stop || start >= sl.length {
return result
}
// Traverse to start position
x := sl.header.Forward[0]
for i := 0; i < start && x != nil; i++ {
x = x.Forward[0]
}
// Collect members between start and stop
for i := start; i <= stop && x != nil; i++ {
result = append(result, x.Member)
x = x.Forward[0]
}
return result
}
GetRank(member string, score float64) int
: 类似插入时的查找过程,从高层向低层查找,累加在每一层跳过的节点数,最终在最底层找到目标节点时,累加的数量即为其排名(从 0 开始)。如果未找到则返回 -1。
// GetRank returns the rank of a member
func (sl *SkipList) GetRank(member string, score float64) int {
rank := 0
x := sl.header
for i := sl.level - 1; i >= 0; i-- {
for x.Forward[i] != nil &&
(x.Forward[i].Score < score ||
(x.Forward[i].Score == score && x.Forward[i].Member < member)) {
rank += 1 // Count nodes we're skipping
x = x.Forward[i]
}
}
x = x.Forward[0]
if x != nil && x.Member == member {
return rank
}
return -1 // Member not found
}
这些方法的具体实现可以在 skiplist.go
文件中查看,它们都利用了跳表结构来高效地完成各自的任务。
通过这些结构和方法的组合,SkipList 提供了一种高效的方式来维护一个有序的集合,支持快速的插入、删除和各种范围查找操作,时间复杂度通常为 O(logN)。
db 操作函数
先到 database/db.go
中添加一个方法:
// getAsZSet retrieves the ZSet stored at key, or creates a new one if it doesn't exist
func getAsZSet(db *DB, key string) (zset.ZSet, bool) {
// Get entity from database
entity, exists := db.GetEntity(key)
if !exists {
return zset.NewZSet(), false
}
// Check if entity is a ZSet
zsetObj, ok := entity.Data.(zset.ZSet)
if !ok {
return nil, true // Key exists but is not a ZSet
}
return zsetObj, true
}
这个方法主要用于从数据库获取一个有序集合对象。
这里要求我们有一个 NewZSet
方法来创建一个新的有序集合对象。
我们到 datastruct
下创建一个 zset
文件夹,里面创建一个 zset.go
文件。
我们依然按照面向接口的范式来实现,创建一个 ZSet
接口,一个 zset
结构体。
// ZSet is the interface that represents a Redis sorted set
type ZSet interface {
}
const (
encodingListpack = iota
encodingSkiplist
)
// 用于限制 Listpack 的最大长度,超过长度后,使用 Skiplist 来存储
const listpackMaxSize = 128
type zset struct {
encoding int
listpack [][2]string
dict map[string]float64
skiplist *skiplist.SkipList
}
// New creates a new zset
func NewZSet() ZSet {
return &zset{
encoding: encodingListpack,
listpack: make([][2]string, 0),
}
}
这里我们定义了一个 ZSet
接口,和一个 zset
结构体。然后我们实现了一个 NewZSet
方法来创建一个新的有序集合对象。
这里默认使用 Listpack 来存储有序集合。
实现 ZADD 操作函数
接下来我们实现 ZADD 操作函数。
ZADD 语法
ZADD 是 Redis 中用于添加元素到有序集合的命令。
ZADD key score member [score member ...]
key
:有序集合的键score
:元素的分数member
:元素的值
示例:
ZADD myzset 1 member1 2 member2
表示将 member1
和 member2
添加到 myzset
中,分数分别为 1 和 2。
实现
// execZAdd implements the ZADD command
// ZADD key [NX|XX] [CH] [INCR] score member [score member ...]
func execZAdd(db *DB, args [][]byte) resp.Reply {
if len(args) < 3 || len(args)%2 == 0 {
return reply.MakeStandardErrorReply("wrong number of arguments for 'zadd' command")
}
key := string(args[0])
// Get or create ZSet
zsetObj, exists := getAsZSet(db, key)
if exists && zsetObj == nil {
return reply.MakeWrongTypeErrReply()
}
added := 0
for i := 1; i < len(args); i += 2 {
scoreStr := string(args[i])
member := string(args[i+1])
// Parse score
score, err := parseFloat(scoreStr)
if err != nil {
return err
}
// Add member to ZSet
if zsetObj.Add(member, score) {
added++
}
}
// Store ZSet in database
db.PutEntity(key, &database.DataEntity{Data: zsetObj})
// Add AOF record
db.addAof(utils.ToCmdLineWithName("ZADD", args...))
return reply.MakeIntReply(int64(added))
}
主要实现思路是:
- 获取参数,检查参数个数是否正确
- 获取或创建 ZSet 对象
- 遍历参数,解析分数和成员,解析分数为浮点数,调用 ZSet 的
Add
方法添加成员 - 将 ZSet 对象存储到数据库中
- 添加 AOF 记录
- 返回添加的成员数量
这里要求一个 Add
方法来添加成员。
ZSet 下的 Add 方法
我们到 datastruct/zset/zset.go
中实现 Add
方法:
这个方法返回一个布尔值,表示成员是否是新添加的。
type ZSet interface {
Add(member string, score float64) bool
}
// Add adds a member with the given score to the sorted set
// Returns true if the element was added as a new member, false if the score was updated
func (z *zset) Add(member string, score float64) bool {
// Check if we're using listpack encoding
if z.encoding == encodingListpack {
// Check if member already exists in listpack
for i, pair := range z.listpack {
if pair[0] == member {
// Update score if member already exists
z.listpack[i][1] = formatScore(score)
return false
}
}
// Add new member to listpack
z.listpack = append(z.listpack, [2]string{member, formatScore(score)})
// Convert to skiplist encoding if listpack grows too large
if len(z.listpack) > listpackMaxSize {
z.convertToSkiplist()
}
return true
}
// Using skiplist encoding
existingScore, exists := z.dict[member]
if exists {
// If score changed, update both dict and skiplist
if existingScore != score {
// Remove from skiplist with old score
z.skiplist.Delete(member, existingScore)
// Insert with new score
z.skiplist.Insert(member, score)
// Update score in dict
z.dict[member] = score
}
return false
}
// Add new member to both dict and skiplist
z.dict[member] = score
z.skiplist.Insert(member, score)
return true
}
首先检查当前编码方式,如果是 Listpack,则遍历 Listpack,检查成员是否已经存在。
如果存在,则更新分数;如果不存在,则添加成员。
如果 Listpack 的长度超过了最大限制,则转换为 Skiplist。
如果当前编码方式是 Skiplist,则直接在字典和跳表中添加成员。
如果在 SkipList 中存在且分数,则删除旧的成员,插入新的成员。
这里要求我们实现一个辅助函数,formatScore
,用于格式化分数。
// Helper function to format score as string
func formatScore(score float64) string {
return fmt.Sprintf("%f", score)
}
这个函数将分数格式化为字符串。
然后实现一个 convertToSkiplist
方法,用于将 Listpack 转换为 Skiplist。
// Convert from listpack to skiplist encoding
func (z *zset) convertToSkiplist() {
if z.encoding == encodingSkiplist {
return
}
// Initialize skiplist and dict
z.skiplist = skiplist.NewSkipList()
z.dict = make(map[string]float64, len(z.listpack))
// Transfer all elements from listpack to skiplist and dict
for _, pair := range z.listpack {
member := pair[0]
score, _ := parseScore(pair[1])
z.dict[member] = score
z.skiplist.Insert(member, score)
}
// Update encoding and clear listpack
z.encoding = encodingSkiplist
z.listpack = nil
}
// Helper function to parse score string to float64
func parseScore(scoreStr string) (float64, error) {
return strconv.ParseFloat(scoreStr, 64)
}
这个方法中,我们初始化一个新的 Skiplist 和字典,然后将 Listpack 中的所有元素转移到 Skiplist 和字典中。
在使用 SkipList 的时候,需要有一个哈希表来存储成员和分数的映射关系来配合 SkipList 的查找。
接下来我们先说说这两个数据结构的配合。
- SkipList(跳表)的角色
- 主要负责排序相关的操作,所有节点按
(score, member)
排好序。 - 支持:
- 按 score 升序/降序遍历(ZRANGE、ZREVRANGE)
- 按 score 区间取值(ZRANGEBYSCORE、ZCOUNT)
- 按排名查元素(ZRANK、ZREVRANK)
- 最小值、最大值的快速定位(ZPOPMIN、ZPOPMAX)
- 只用跳表可以做到:
- O(logN) 查询和插入
- O(logN + M) 范围查询(M是结果数量)
跳表维护的是「有序访问能力」。
- map(哈希表)的角色
- 主要负责通过 member 名字快速定位元素。
- 支持:
- O(1) 查找指定 member 的 score(ZSCORE)
- O(1) 删除指定 member(需要快速找到跳表节点,配合跳表 delete)
- 插入前快速检查是否存在同名元素(ZADD 要判重)
- 如果只有跳表,查 member 就只能 O(logN),引入哈希表能直接 O(1)。
哈希表维护的是「快速精确定位能力」。
假设你有一个 ZSet:
Member | Score |
---|---|
”alice” | 10 |
”bob” | 20 |
”carol” | 15 |
内部数据结构大概长这样:
哈希表(dict)
"alice" -> Node(10, "alice")
"bob" -> Node(20, "bob")
"carol" -> Node(15, "carol")
跳表(skiplist)
head -> (10, "alice") -> (15, "carol") -> (20, "bob") -> nil
如果有人执行:
ZSCORE carol
- → map[“carol”],直接拿到 Node,O(1)。
ZRANGE 0 1
- → skiplist 按顺序遍历两个节点,拿 (10, “alice”), (15, “carol”)。
ZREM bob
- → map[“bob”],拿到 Node
- → 在 skiplist 里精确找到节点,删除 Forward 指针,回收内存。
ZCOUNT 12 18
- → skiplist 从第一个 ≥12 的节点开始遍历,找到 15,符合。
这就是二者配合的结果:快又有序!
到这里我们就实现了 ZADD 操作。
实现 ZSCORE 操作函数
ZSCORE 用于获取有序集合中指定成员的分数。
ZSCORE key member
key
:有序集合的键member
:要查询的成员
示例:
ZSCORE myzset member1
表示获取 myzset
中 member1
的分数。
实现处理函数
在 database/zset.go
中实现 execZScore
函数:
// execZScore implements the ZSCORE command
// ZSCORE key member
func execZScore(db *DB, args [][]byte) resp.Reply {
if len(args) != 2 {
return reply.MakeStandardErrorReply("wrong number of arguments for 'zscore' command")
}
key := string(args[0])
member := string(args[1])
// Get ZSet
zsetObj, exists := getAsZSet(db, key)
if !exists {
return reply.MakeNullBulkReply()
}
if zsetObj == nil {
return reply.MakeWrongTypeErrReply()
}
// Get score
score, exists := zsetObj.Score(member)
if !exists {
return reply.MakeNullBulkReply()
}
return reply.MakeBulkReply([]byte(strconv.FormatFloat(score, 'f', -1, 64)))
}
这里要求我们实现一个辅助函数,Score
,用于获取成员的分数。
ZSet 下的 Score 方法
我们到 datastruct/zset/zset.go
中实现 Score
方法:
type ZSet interface {
Add(member string, score float64) bool
Score(member string) (float64, bool)
}
// Score returns the score of a member, and a boolean indicating if the member exists
func (z *zset) Score(member string) (float64, bool) {
if z.encoding == encodingListpack {
for _, pair := range z.listpack {
if pair[0] == member {
score, err := parseScore(pair[1])
if err != nil {
return 0, false
}
return score, true
}
}
return 0, false
}
// Using skiplist encoding
score, exists := z.dict[member]
return score, exists
}
这里我们首先检查当前编码方式,如果是 Listpack,则遍历 Listpack,检查成员是否已经存在。
如果当前编码方式是 Skiplist,则直接在字典中查找成员的分数。
如果存在,则返回分数;如果不存在,则返回 0 和 false。
实现 ZCARD 操作函数
ZCARD 用于获取有序集合的成员数量。
ZCARD key
key
:有序集合的键- 返回有序集合的成员数量
示例:
ZCARD myzset
表示获取 myzset
中的成员数量。
实现处理函数
在 database/zset.go
中实现 execZCard
函数:
// execZCard implements the ZCARD command
// ZCARD key
func execZCard(db *DB, args [][]byte) resp.Reply {
if len(args) != 1 {
return reply.MakeStandardErrorReply("wrong number of arguments for 'zcard' command")
}
key := string(args[0])
// Get ZSet
zsetObj, exists := getAsZSet(db, key)
if !exists {
return reply.MakeIntReply(0)
}
if zsetObj == nil {
return reply.MakeWrongTypeErrReply()
}
return reply.MakeIntReply(int64(zsetObj.Len()))
}
这里要求我们实现一个辅助函数,Len
,用于获取有序集合的成员数量。
ZSet 下的 Len 方法
我们到 datastruct/zset/zset.go
中实现 Len
方法:
type ZSet interface {
Add(member string, score float64) bool
Score(member string) (float64, bool)
Len() int
}
// Len returns the number of elements in the sorted set
func (z *zset) Len() int {
if z.encoding == encodingListpack {
return len(z.listpack)
}
return len(z.dict)
}
这个实现比较简单,直接返回 Listpack 或字典的长度即可。
实现 ZRANGE 操作函数
ZRANGE 用于获取有序集合中指定范围内的成员。
ZRANGE key start stop [WITHSCORES]
key
:有序集合的键start
:起始索引(从 0 开始)stop
:结束索引(从 0 开始)WITHSCORES
:可选参数,表示返回成员和分数
示例:
ZRANGE myzset 0 1 WITHSCORES
表示获取 myzset
中索引从 0 到 1 的成员和分数。
实现处理函数
在 database/zset.go
中实现 execZRange
函数:
// execZRange implements the ZRANGE command
// ZRANGE key start stop [WITHSCORES]
func execZRange(db *DB, args [][]byte) resp.Reply {
if len(args) < 3 {
return reply.MakeStandardErrorReply("wrong number of arguments for 'zrange' command")
}
withScores := false
if len(args) > 3 && string(args[3]) == "WITHSCORES" {
withScores = true
}
key := string(args[0])
// Parse start and stop indices
start, err := strconv.Atoi(string(args[1]))
if err != nil {
return reply.MakeStandardErrorReply("value is not an integer or out of range")
}
stop, err := strconv.Atoi(string(args[2]))
if err != nil {
return reply.MakeStandardErrorReply("value is not an integer or out of range")
}
// Get ZSet
zsetObj, exists := getAsZSet(db, key)
if !exists {
return reply.MakeEmptyMultiBulkReply()
}
if zsetObj == nil {
return reply.MakeWrongTypeErrReply()
}
// Get range
members := zsetObj.RangeByRank(start, stop)
// Prepare result
if !withScores {
result := make([][]byte, len(members))
for i, member := range members {
result[i] = []byte(member)
}
return reply.MakeMultiBulkReply(result)
} else {
result := make([][]byte, len(members)*2)
for i, member := range members {
result[i*2] = []byte(member)
score, _ := zsetObj.Score(member)
result[i*2+1] = []byte(strconv.FormatFloat(score, 'f', -1, 64))
}
return reply.MakeMultiBulkReply(result)
}
}
这里我们实现了 ZRANGE
命令的处理函数,主要思路是:
- 检查参数个数是否正确
- 解析起始和结束索引
- 获取 ZSet 对象
- 调用 ZSet 的
RangeByRank
方法获取指定范围内的成员 - 根据是否需要分数,准备返回结果
- 返回结果
这里要求我们 RangeByRank
方法来获取指定范围内的成员。
ZSet 下的 RangeByRank 方法
type ZSet interface {
// ...
RangeByRank(start, stop int) []string
}
// RangeByRank returns members ordered by rank (position)
// Returns members between start and stop ranks (inclusive, 0-based)
func (z *zset) RangeByRank(start, stop int) []string {
if z.encoding == encodingListpack {
// Copy and sort listpack elements by score
pairs := make([][2]string, len(z.listpack))
copy(pairs, z.listpack)
sort.Slice(pairs, func(i, j int) bool {
scoreI, _ := parseScore(pairs[i][1])
scoreJ, _ := parseScore(pairs[j][1])
return scoreI < scoreJ
})
// Handle negative indices and out of range
size := len(pairs)
if start < 0 {
start = size + start
}
if stop < 0 {
stop = size + stop
}
if start < 0 {
start = 0
}
if stop >= size {
stop = size - 1
}
if start > stop || start >= size {
return []string{}
}
// Extract member names
result := make([]string, 0, stop-start+1)
for i := start; i <= stop; i++ {
result = append(result, pairs[i][0])
}
return result
}
// Using skiplist encoding
return z.skiplist.RangeByRank(start, stop)
}
实现 ZREM 操作函数
ZREM 用于删除有序集合中的成员。
ZREM key member [member ...]
-
key
:有序集合的键 -
member
:要删除的成员 -
示例:
ZREM myzset member1 member2
表示删除 myzset
中的 member1
和 member2
。
实现处理函数
在 database/zset.go
中实现 execZRem
函数:
// execZRem implements the ZREM command
// ZREM key member [member ...]
func execZRem(db *DB, args [][]byte) resp.Reply {
if len(args) < 2 {
return reply.MakeStandardErrorReply("wrong number of arguments for 'zrem' command")
}
key := string(args[0])
// Get ZSet
zsetObj, exists := getAsZSet(db, key)
if !exists {
return reply.MakeIntReply(0)
}
if zsetObj == nil {
return reply.MakeWrongTypeErrReply()
}
// Remove members
removed := 0
for i := 1; i < len(args); i++ {
member := string(args[i])
if zsetObj.Remove(member) {
removed++
}
}
// Update database if we removed anything
if removed > 0 {
db.PutEntity(key, &database.DataEntity{Data: zsetObj})
// Add AOF record
db.addAof(utils.ToCmdLineWithName("ZREM", args...))
}
return reply.MakeIntReply(int64(removed))
}
这里要求我们实现一个辅助函数,Remove
,用于删除成员。
ZSet 下的 Remove 方法
我们到 datastruct/zset/zset.go
中实现 Remove
方法:
type ZSet interface {
// ...
Remove(member string) bool
}
// Remove removes a member from the sorted set
// Returns true if the member was removed, false if it didn't exist
func (z *zset) Remove(member string) bool {
if z.encoding == encodingListpack {
for i, pair := range z.listpack {
if pair[0] == member {
// Remove the member by slicing it out
z.listpack = append(z.listpack[:i], z.listpack[i+1:]...)
return true
}
}
return false
}
// Using skiplist encoding
score, exists := z.dict[member]
if exists {
z.skiplist.Delete(member, score)
delete(z.dict, member)
return true
}
return false
}
实现 ZCOUNT 操作函数
ZCOUNT 用于获取有序集合中指定分数范围内的成员数量。
ZCOUNT key min max
key
:有序集合的键min
:最小分数max
:最大分数- 示例:
ZCOUNT myzset 1 2
表示获取 myzset
中分数在 1 到 2 之间的成员数量。
实现处理函数
在 database/zset.go
中实现 execZCount
函数:
// execZCount implements the ZCOUNT command
// ZCOUNT key min max
func execZCount(db *DB, args [][]byte) resp.Reply {
if len(args) != 3 {
return reply.MakeStandardErrorReply("wrong number of arguments for 'zcount' command")
}
key := string(args[0])
// Parse min and max scores
min, err := parseFloat(string(args[1]))
if err != nil {
return err
}
max, err := parseFloat(string(args[2]))
if err != nil {
return err
}
// Get ZSet
zsetObj, exists := getAsZSet(db, key)
if !exists {
return reply.MakeIntReply(0)
}
if zsetObj == nil {
return reply.MakeWrongTypeErrReply()
}
// Count elements in range
count := zsetObj.Count(min, max)
return reply.MakeIntReply(int64(count))
}
这里要求我们实现一个辅助函数,Count
,用于获取指定分数范围内的成员数量。
ZSet 下的 Count 方法
我们到 datastruct/zset/zset.go
中实现 Count
方法:
type ZSet interface {
// ...
Count(min, max float64) int
}
// Count returns the number of elements in the specified score range
func (z *zset) Count(min, max float64) int {
if z.encoding == encodingListpack {
count := 0
for _, pair := range z.listpack {
score, _ := parseScore(pair[1])
if score >= min && score <= max {
count++
}
}
return count
}
// Using skiplist encoding
return z.skiplist.CountInRange(min, max)
}
实现 ZRANK 操作函数
ZRANK 用于获取有序集合中指定成员的排名。
ZRANK key member
-
key
:有序集合的键 -
member
:要查询的成员 -
示例:
ZRANK myzset member1
表示获取 myzset
中 member1
的排名。
实现处理函数
在 database/zset.go
中实现 execZRank
函数:
// execZRank implements the ZRANK command
// ZRANK key member
func execZRank(db *DB, args [][]byte) resp.Reply {
if len(args) != 2 {
return reply.MakeStandardErrorReply("wrong number of arguments for 'zrank' command")
}
key := string(args[0])
member := string(args[1])
// Get ZSet
zsetObj, exists := getAsZSet(db, key)
if !exists {
return reply.MakeNullBulkReply()
}
if zsetObj == nil {
return reply.MakeWrongTypeErrReply()
}
// Get member's rank
score, exists := zsetObj.Score(member)
if !exists {
return reply.MakeNullBulkReply()
}
// Using skiplist's GetRank method
rank := -1
if zsetObj.Encoding() == 1 { // Using skiplist encoding
// We need to access the skiplist from the ZSet implementation
skiplist := zsetObj.GetSkiplist()
rank = skiplist.GetRank(member, score)
} else {
// For listpack encoding, we need to compute rank by sorting
members := zsetObj.RangeByRank(0, -1)
for i, m := range members {
if m == member {
rank = i
break
}
}
}
if rank == -1 {
return reply.MakeNullBulkReply()
}
return reply.MakeIntReply(int64(rank))
}
这里要求我们实现一个辅助函数,Encoding
,用于获取当前编码方式。
以及实现一个辅助函数,GetSkiplist
,用于获取 Skiplist 对象。
ZSet 下的 Encoding 方法和 GetSkiplist 方法
我们到 datastruct/zset/zset.go
中实现 Encoding
方法:
// Encoding returns the current encoding type of the zset (0 for listpack, 1 for skiplist)
func (z *zset) Encoding() int {
return z.encoding
}
// GetSkiplist returns the skiplist used by the zset when in skiplist encoding
// Returns nil if the zset is using listpack encoding
func (z *zset) GetSkiplist() *skiplist.SkipList {
if z.encoding == encodingSkiplist {
return z.skiplist
}
return nil
}
然后我们将这几个指令注册:
// Register ZSET commands
func init() {
RegisterCommand("ZADD", execZAdd, -4) // key score member [score member ...]
RegisterCommand("ZSCORE", execZScore, 3) // key member
RegisterCommand("ZCARD", execZCard, 2) // key
RegisterCommand("ZRANGE", execZRange, -4) // key start stop [WITHSCORES]
RegisterCommand("ZREM", execZRem, -3) // key member [member ...]
RegisterCommand("ZCOUNT", execZCount, 4) // key min max
RegisterCommand("ZRANK", execZRank, 3) // key member
}
完整的 ZSET 接口:
type ZSet interface {
Add(member string, score float64) bool
Remove(member string) bool
Score(member string) (float64, bool)
Exists(member string) bool
Count(min, max float64) int
Len() int
RangeByScore(min, max float64, offset, count int) []string
RangeByRank(start, stop int) []string
RemoveRangeByRank(start, stop int) int
RemoveRangeByScore(min, max float64) int
Encoding() int
GetSkiplist() *skiplist.SkipList
}
然后为了适配集群模式,在 cluster/router.go
中注册 ZSET 指令:
// ZSet operations
routerMap["zadd"] = defaultFunc // zadd key score member [score member ...]
routerMap["zscore"] = defaultFunc // zscore key member
routerMap["zcard"] = defaultFunc // zcard key
routerMap["zrange"] = defaultFunc // zrange key start stop [WITHSCORES]
routerMap["zrem"] = defaultFunc // zrem key member [member ...]
routerMap["zcount"] = defaultFunc // zcount key min max
routerMap["zrank"] = defaultFunc // zrank key member
测试
我们实现一个自定义的指令 ZTYPE
用于查询当前 key 的底层编码方式:
// execZTYPE implements the ZTYPE command
// ZTYPE key returns the type of the key, 0 for listpack, 1 for skiplist
func execZType(db *DB, args [][]byte) resp.Reply {
if len(args) != 1 {
return reply.MakeStandardErrorReply("wrong number of arguments for 'ztype' command")
}
key := string(args[0])
// Get ZSet
zsetObj, exists := getAsZSet(db, key)
if !exists {
return reply.MakeNullBulkReply()
}
if zsetObj == nil {
return reply.MakeWrongTypeErrReply()
}
return reply.MakeIntReply(int64(zsetObj.Encoding()))
}
// Register ZTYPE command
func init() {
RegisterCommand("ZTYPE", execZType, 2) // key
}
然后运行下面的指令进行测试:
(base) orangejuice@Mac redigo % redis-cli -p 6380
127.0.0.1:6380> ZADD myzset 1 "one"
(integer) 1
127.0.0.1:6380> ZADD myzset 2 "two" 3 "three"
(integer) 2
127.0.0.1:6380> ZCARD myzset
(integer) 3
127.0.0.1:6380> ZSCORE myzset "one"
"1"
127.0.0.1:6380> ZSCORE myzset "two"
"2"
127.0.0.1:6380> ZSCORE myzset "nonexistent"
(nil)
127.0.0.1:6380> ZRANGE myzset 0 -1
1) "one"
2) "two"
3) "three"
127.0.0.1:6380> ZRANGE myzset 0 1
1) "one"
2) "two"
127.0.0.1:6380> ZRANGE myzset -2 -1
1) "two"
2) "three"
127.0.0.1:6380> ZRANGE myzset 0 -1 WITHSCORES
1) "one"
2) "1"
3) "two"
4) "2"
5) "three"
6) "3"
127.0.0.1:6380> ZREM myzset "two"
(integer) 1
127.0.0.1:6380> ZRANGE myzset 0 -1
1) "one"
2) "three"
127.0.0.1:6380> ZADD myzset 2 "two" 4 "four" 5 "five"
(integer) 3
127.0.0.1:6380> ZCOUNT myzset 2 4
(integer) 3
127.0.0.1:6380> ZRANGE myzset 0 -1
1) "one"
2) "two"
3) "three"
4) "four"
5) "five"
到目前为了,我们就实现了所有的常用的 Redis 数据结构:
- String
- List
- Hash
- Set
- ZSet (Sorted Set)
提交到 GitHub
git add .
git commit -m "feat: implement ZSet"
git push origin main