创建一个 TCP 服务器
本文进度对应的代码仓库:TCP 服务创建
在本文中,项目结构如下:
- config.go
- handler.go
- logger.go
- files.go
- bool.go
- wait.go
- echo.go
- server.go
- main.go
- redis.conf
- go.mod
为什么要实现一个 TCP 服务器呢?因为 Redis 服务器要实现客户端与服务器之间的通信,而 Redis 服务器使用的是 TCP 协议。因此,我们需要先实现一个 TCP 服务器,然后在此基硃上实现 Redis 服务器。
Config 加载
Redis 中的配置文件是 redis.conf
,用于存储 Redis 服务器的配置信息。
bind
:指定 Redis 服务器监听的 IP 地址和端口号。port
:指定 Redis 服务器监听的端口号。appendonly
:指定是否启用 AOF 持久化。appendfilename
:指定 AOF 持久化文件的名称。maxclients
:指定 Redis 服务器的最大客户端连接数。databases
:指定 Redis 服务器的数据库数量。requirepass
:指定 Redis 服务器的密码。peers
:指定 Redis 服务器的对等节点。self
:指定 Redis 服务器的自身节点。
创建一个 config.go
文件,用于加载 Redis 配置文件。
type ServerProperties struct {
Bind string `cfg:"bind"`
Port int `cfg:"port"`
AppendOnly bool `cfg:"appendOnly"`
AppendFilename string `cfg:"appendFilename"`
MaxClients int `cfg:"maxClients"`
RequirePass string `cfg:"requirePass"`
Databases int `cfg:"databases"`
Peers []string `cfg:"peers"`
Self string `cfg:"self"`
}
创建一个变量 Properties
,用于存储 Redis 配置文件的内容。
var Properties *ServerProperties
调用生命周期函数 init
,用于加载 Redis 配置文件。
func init() {
// default config
Properties = &ServerProperties{
Bind: "127.0.0.1",
Port: 6379,
AppendOnly: false,
}
}
创建一个函数 parse
,用于解析 Redis 配置文件。主要需要完成:
- 逐行读取配置文件,跳过注释行。
- 解析配置文件,使用反射根据 ServerProperties 结构体的字段类型,将字符串值转换为相应的类型。
func parse(src io.Reader) *ServerProperties {
config := &ServerProperties{}
// read config file
rawMap := make(map[string]string)
scanner := bufio.NewScanner(src)
for scanner.Scan() {
line := scanner.Text()
if len(line) > 0 && line[0] == '#' {
continue
}
pivot := strings.IndexAny(line, " ")
if pivot > 0 && pivot < len(line)-1 { // separator found
key := line[0:pivot]
value := strings.Trim(line[pivot+1:], " ")
rawMap[strings.ToLower(key)] = value
}
}
if err := scanner.Err(); err != nil {
logger.Fatal(err)
}
// parse format
t := reflect.TypeOf(config)
v := reflect.ValueOf(config)
n := t.Elem().NumField()
for i := 0; i < n; i++ {
field := t.Elem().Field(i)
fieldVal := v.Elem().Field(i)
key, ok := field.Tag.Lookup("cfg")
if !ok {
key = field.Name
}
value, ok := rawMap[strings.ToLower(key)]
if ok {
// fill config
switch field.Type.Kind() {
case reflect.String:
fieldVal.SetString(value)
case reflect.Int:
intValue, err := strconv.ParseInt(value, 10, 64)
if err == nil {
fieldVal.SetInt(intValue)
}
case reflect.Bool:
boolValue := "yes" == value
fieldVal.SetBool(boolValue)
case reflect.Slice:
if field.Type.Elem().Kind() == reflect.String {
slice := strings.Split(value, ",")
fieldVal.Set(reflect.ValueOf(slice))
}
}
}
}
return config
}
接下来,创建一个函数 SetupConfig
,用于加载 Redis 配置文件。
// SetupConfig 配置初始化
func SetupConfig(configFilename string) {
file, err := os.Open(configFilename)
if err != nil {
panic(err)
}
defer func(file *os.File) {
err := file.Close()
if err != nil {
}
}(file)
Properties = parse(file)
}
这样后续就可以在 main.go
中调用 SetupConfig
函数,加载 Redis 配置文件。
const configFile string = "redis.conf"
func fileExists(filename string) bool {
info, err := os.Stat(filename)
return err == nil && !info.IsDir()
}
func main() {
if fileExists(configFile) {
config.SetupConfig(configFile)
} else {
config.Properties = defaultProperties
}
}
日志
Redis 服务器使用日志记录器来记录服务器的运行状态。
创建一个新的包,名为 logger
。创建一个 logger.go
文件,用于创建日志记录器。这里日志不重要只需要把下面的代码拷贝到 logger.go
中即可。
package logger
import (
"fmt"
"io"
"log"
"os"
"path/filepath"
"runtime"
"sync"
"time"
)
// Settings stores config for logger
type Settings struct {
Path string `yaml:"path"`
Name string `yaml:"name"`
Ext string `yaml:"ext"`
TimeFormat string `yaml:"time-format"`
}
var (
logFile *os.File
defaultPrefix = ""
defaultCallerDepth = 2
logger *log.Logger
mu sync.Mutex
logPrefix = ""
levelFlags = []string{"DEBUG", "INFO", "WARN", "ERROR", "FATAL"}
)
type logLevel int
// log levels
const (
DEBUG logLevel = iota
INFO
WARNING
ERROR
FATAL
)
const flags = log.LstdFlags
func init() {
logger = log.New(os.Stdout, defaultPrefix, flags)
}
// Setup initializes logger
func Setup(settings *Settings) {
var err error
dir := settings.Path
fileName := fmt.Sprintf("%s-%s.%s",
settings.Name,
time.Now().Format(settings.TimeFormat),
settings.Ext)
logFile, err := mustOpen(fileName, dir)
if err != nil {
log.Fatalf("logging.Setup err: %s", err)
}
mw := io.MultiWriter(os.Stdout, logFile)
logger = log.New(mw, defaultPrefix, flags)
}
func setPrefix(level logLevel) {
_, file, line, ok := runtime.Caller(defaultCallerDepth)
if ok {
logPrefix = fmt.Sprintf("[%s][%s:%d] ", levelFlags[level], filepath.Base(file), line)
} else {
logPrefix = fmt.Sprintf("[%s] ", levelFlags[level])
}
logger.SetPrefix(logPrefix)
}
// Debug prints debug log
func Debug(v ...interface{}) {
mu.Lock()
defer mu.Unlock()
setPrefix(DEBUG)
logger.Println(v...)
}
// Info prints normal log
func Info(v ...interface{}) {
mu.Lock()
defer mu.Unlock()
setPrefix(INFO)
logger.Println(v...)
}
// Warn prints warning log
func Warn(v ...interface{}) {
mu.Lock()
defer mu.Unlock()
setPrefix(WARNING)
logger.Println(v...)
}
// Error prints error log
func Error(v ...interface{}) {
mu.Lock()
defer mu.Unlock()
setPrefix(ERROR)
logger.Println(v...)
}
// Fatal prints error log then stop the program
func Fatal(v ...interface{}) {
mu.Lock()
defer mu.Unlock()
setPrefix(FATAL)
logger.Fatalln(v...)
}
这里主要创建了 Debug
、Info
、Warn
、Error
、Fatal
五个函数,用于打印不同级别的日志。
每个日志的处理函数中,都使用了 mu.Lock()
和 mu.Unlock()
来保证线程安全。logger 是一个全局变量,如果多个 goroutine 同时调用这些函数,可能会导致日志输出的顺序不一致。
使用 mu.Lock()
和 mu.Unlock()
可以保证每个 goroutine 在调用日志函数时,不会被其他 goroutine 打断。
logger 需要调用 io 读写本地文件,因此需要创建一个 mustOpen
函数,用于打开文件。
package logger
import (
"fmt"
"os"
)
func checkNotExist(src string) bool {
_, err := os.Stat(src)
return os.IsNotExist(err)
}
func checkPermission(src string) bool {
_, err := os.Stat(src)
return os.IsPermission(err)
}
func isNotExistMkDir(src string) error {
if notExist := checkNotExist(src); notExist == true {
if err := mkDir(src); err != nil {
return err
}
}
return nil
}
func mkDir(src string) error {
err := os.MkdirAll(src, os.ModePerm)
if err != nil {
return err
}
return nil
}
func mustOpen(fileName, dir string) (*os.File, error) {
perm := checkPermission(dir)
if perm == true {
return nil, fmt.Errorf("permission denied dir: %s", dir)
}
err := isNotExistMkDir(dir)
if err != nil {
return nil, fmt.Errorf("error during make dir %s, err: %s", dir, err)
}
f, err := os.OpenFile(dir+string(os.PathSeparator)+fileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0644)
if err != nil {
return nil, fmt.Errorf("fail to open file, err: %s", err)
}
return f, nil
}
TCP 服务
Redis 服务器使用 TCP 协议进行通信,因此需要创建一个 TCP 服务。
接口
使用 Go 语言的接口,可以实现 TCP 的处理逻辑与具体实现的解耦。我们通过接口定义两个必须实现的方法:
type Handler interface {
Handle(ctx context.Context, conn net.Conn)
Close() error
}
接下来我们要实现这两个接口。
处理器
我们先创建一个简单的回显服务器,接收客户端发送的消息,然后将接收到的消息原样返回客户端。来测试我们的 TCP 服务是否正常工作。
创建一个 tcp
包,用于实现 TCP 服务。我们创建一个 echo.go
文件,用于实现回显服务器。
首先创建一个结构体 EchoHandler
,用于处理客户端连接。负责管理所有活动的客户端连接,并处理每个连接的回显逻辑。
type EchoHandler struct {
activeConn sync.Map
closing atomic.Boolean
}
activeConn
:用于存储所有活动的客户端连接。这是一个并发安全的 map,用于存储所有活动的客户端连接。closing
:用于标记服务器是否正在关闭。这是一个原子操作,保证在多线程环境下的线程安全。如果closing
为 true,则表示服务器正在关闭,不再接受新的客户端连接,并清理所有的客户端连接。
我们下面要实现原子操作的辅助函数,主要用于在并发环境下安全地读写布尔值。这是因为多个 goroutine 正在执行,一个 goroutine 可能正在设置关闭标志(通过 Close 方法),而另一个 goroutine 可能正在读取关闭标志(通过 IsClosed 方法)。如果不使用原子操作,可能会出现数据竞争,导致一个 goroutine 看到部分更新的值,从而产生不可预测的行为。
创建 lib/sync/atomic/bool.go
文件,用于实现原子操作的辅助函数。
type Boolean uint32
// Get reads the value atomically
func (b *Boolean) Get() bool {
return atomic.LoadUint32((*uint32)(b)) != 0
}
// Set writes the value atomically
func (b *Boolean) Set(v bool) {
if v {
atomic.StoreUint32((*uint32)(b), 1)
} else {
atomic.StoreUint32((*uint32)(b), 0)
}
}
接下来,我们要实现 MakeHandler
函数,用于创建 EchoHandler
实例。
func MakeHandler() *EchoHandler {
return &EchoHandler{}
}
这里实际上用到了设计模式中的“工厂模式”,虽然直接使用
new(EchoHandler)
也可以创建EchoHandler
实例,但是使用MakeHandler
函数可以更好地隐藏实现细节,提高代码的可维护性。例如,后续我们可能会修改EchoHandler
的实现,如果直接使用new(EchoHandler)
创建实例,那么所有调用new(EchoHandler)
的地方都需要修改,而使用MakeHandler
函数创建实例,只需要修改MakeHandler
函数即可。
然后创建 EchoClient
结构体,用于管理单个客户端连接。
type EchoClient struct {
Conn net.Conn
Waiting wait.Wait
}
Conn
:用于存储客户端连接。是使用的net.Conn
接口,用于表示网络连接。这是 golang 标准库提供的网络连接接口。Waiting
:用于等待客户端连接的关闭。这是一个wait.Wait
类型,用于等待某个操作完成。
对于每个客户端要实现一个关闭方法,用于关闭客户端连接。
func (c *EchoClient) Close() error {
c.Waiting.WaitWithTimeout(10 * time.Second)
err := c.Conn.Close()
if err != nil {
return err
}
return nil
}
上面的代码为 Close 操作设置了一个超时时间,超时时间为 10 秒。若果客户端连接在 10 秒内没有关闭,就会返回超时错误。
这里的 WaitWithTimeout
函数是一个等待函数,用于等待客户端连接的关闭。它会等待一段时间,如果客户端连接没有关闭,就会返回超时错误。这有利于避免程序永久阻塞,在超时的情况下可以及时关闭连接。实现优雅的关闭服务。
这个等待超时函数需要我们自己实现。我们在 lib/sync/wait
包中实现一个 Wait
结构体,用于等待某个操作完成。
我们主要基于 sync.WaitGroup
进行拓展,增加一个等待超时的功能。其他的 Add
Done
Wait
方法都是基于 sync.WaitGroup
的 API 直接调用的。
type Wait struct {
wg sync.WaitGroup
}
wg
:用于等待某个操作完成。这是一个sync.WaitGroup
类型,用于等待一组操作完成。
在这个结构体上我们实现 Add
Done
Wait
WaitWithTimeout
四个方法。分别用于:
Add
:用于增加等待组的计数器。Done
:用于减少等待组的计数器。Wait
:用于等待等待组的计数器变为 0。WaitWithTimeout
:用于等待等待组的计数器变为 0,或者等待一段时间。
// Add adds delta, which may be negative, to the WaitGroup counter.
func (w *Wait) Add(delta int) {
w.wg.Add(delta)
}
// Done decrements the WaitGroup counter by one
func (w *Wait) Done() {
w.wg.Done()
}
// Wait blocks until the WaitGroup counter is zero.
func (w *Wait) Wait() {
w.wg.Wait()
}
// WaitWithTimeout blocks until the WaitGroup counter is zero or timeout
// returns true if timeout
func (w *Wait) WaitWithTimeout(timeout time.Duration) bool {
c := make(chan bool, 1)
go func() {
defer close(c)
w.wg.Wait()
c <- true
}()
select {
case <-c:
return false // completed normally
case <-time.After(timeout):
return true // timed out
}
}
WaitWithTimeout
创建一个通道用于通知完成,然后启动一个 goroutine 等待计数器变为 0。然后使用 select
语句等待两个通道中的任意一个信号。如果计数器变为 0,就会从 c
通道中接收到一个值,然后返回 false。如果超时,就会从 time.After(timeout)
通道中接收到一个值,然后返回 true,表示超时。
接下来我们实现处理器 EchoHandler
的 Handle
方法。
func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
// 发现客户端正在关闭
if h.closing.Get() {
// 拒绝新的客户端连接
_ = conn.Close()
}
// 新的客户端连接
client := &EchoClient{
Conn: conn,
}
// 在 EchoHandler 中存储客户端连接
h.activeConn.Store(client, struct{}{})
reader := bufio.NewReader(conn)
// 循环接收客户端发送的消息
for {
// may occurs: client EOF, client timeout, server early close
// 接收客户端发送的消息,直到遇到换行符
msg, err := reader.ReadString('\n')
// 当遇到错误时,关闭客户端连接
if err != nil {
if err == io.EOF {
logger.Info("connection close")
// 从 EchoHandler 中删除客户端发生错误的连接
h.activeConn.Delete(client)
} else {
logger.Warn(err)
}
return
}
// 该客户端正在处理的消息数量加一
client.Waiting.Add(1)
b := []byte(msg)
// 将接收到的消息原样返回客户端
_, _ = conn.Write(b)
// 该客户端正在处理的消息数量减一
client.Waiting.Done()
}
}
- 当发现服务器正在关闭时,拒绝新的客户端连接。
- 当客户端连接发生错误时,关闭客户端连接。
- 当客户端连接正常时,将接收到的消息原样返回客户端。
最后我们实现 Close
方法,用于关闭服务器。
func (h *EchoHandler) Close() error {
logger.Info("handler shutting down...")
// 标记服务器正在关闭
h.closing.Set(true)
// 关闭所有的客户端连接
h.activeConn.Range(func(key interface{}, val interface{}) bool {
client := key.(*EchoClient)
_ = client.Close()
return true
})
return nil
}
- 标记服务器正在关闭。
- 关闭所有的客户端连接。
服务
在 tcp
包中创建一个 server.go
文件,用于创建 TCP 服务。
创建一个函数 ListenAndServeWithSignal
,用于监听端口号,并处理客户端连接。
func ListenAndServeWithSignal(cfg *Config, handler tcp.Handler) error {
closeChan := make(chan struct{})
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT)
go func() {
sig := <-sigCh
switch sig {
case syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT:
closeChan <- struct{}{}
}
}()
listener, err := net.Listen("tcp", cfg.Address)
if err != nil {
return err
}
logger.Info(fmt.Sprintf("bind: %s, start listening...", cfg.Address))
ListenAndServe(listener, handler, closeChan)
return nil
}
创建两个通道: closeChan 用于触发服务关闭, sigCh 用于接收系统信号,监听四种系统信号:
- SIGHUP:终端断开
- SIGQUIT:退出信号
- SIGTERM:终止信号
- SIGINT:中断信号(通常是 Ctrl+C)
启动一个协程监听系统信号,当收到上述任何一个信号时,向 closeChan 发送关闭信号。
<-
是 Go 语言中的一个运算符,用于从通道中接收数据。在下面的 ListenAndServe 函数中,当收到 closeChan 通道的关闭信号时,会立即关闭 listener 和 handler。
当执行到 go func() { ... }()
时:立即创建一个新的 goroutine,这个 goroutine 会在后台运行,主程序会继续执行下面的代码,不会等待这个 goroutine
然后创建函数 ListenAndServe
,用于监听端口号,并处理客户端连接。
func ListenAndServe(listener net.Listener, handler tcp.Handler, closeChan <-chan struct{}) {
// listen signal
go func() {
<-closeChan
logger.Info("shutting down...")
_ = listener.Close() // 收到信号,关闭 listener
_ = handler.Close() // 收到信号,关闭 handler
}()
// 使用 defer 注册延迟执行,当出现 panic 时,会执行 defer 注册的函数
// 这里的 defer 函数会在 ListenAndServe 函数返回前执行
defer func() {
// close during unexpected error
_ = listener.Close()
_ = handler.Close()
}()
// 创建一个空的 context,用于传递上下文信息
ctx := context.Background()
// 循环接受客户端连接
var waitDone sync.WaitGroup
for {
conn, err := listener.Accept()
if err != nil {
break
}
// handle
logger.Info("accept link")
waitDone.Add(1)
go func() {
defer func() {
waitDone.Done()
}()
handler.Handle(ctx, conn)
}()
}
waitDone.Wait()
}
这里的 var waitDone sync.WaitGroup
是一个同步等待组,用于等待所有的 goroutine 执行完毕。
当有一个新的客户端连接时,会创建一个新的 goroutine 来处理这个连接。然后使用 waitDone.Add(1)
来增加等待组的计数器,当这个 goroutine 执行完毕时,会调用 waitDone.Done()
来减少等待组的计数器。
到最后,使用 waitDone.Wait()
来等待所有的 goroutine 执行完毕再返回。
这样有助于确保所有的连接都被正确处理,并且在所有的连接都处理完毕后再关闭服务器。
测试
首先确保 redis.conf
文件存在并包含基本配置:
bind 0.0.0.0
port 6379
然后在根目录下执行:
go run main.go
启动服务器。
得到下面的输出:
[INFO][server.go:40] 2025/03/26 16:44:15 bind: 0.0.0.0:6379, start listening...
新建命令行,使用 netcat 工具连接服务器:
nc localhost 6379
然后可以随意输入内容,服务器会将输入的内容原样返回。
(base) orangejuice@MyMac redigo % nc localhost 6379
hello
hello
nihao
nihao
song
song
orange
orange
按住 Ctrl+C
关闭服务器。
输出:
[INFO][echo.go:80] 2025/03/26 16:50:23 handler shutting down...
支持并发,可以开启多个命令行进行连接测试。
总结
这篇文章主要介绍了如何使用 Go 语言创建一个简单的 TCP 回显服务器。为后续我们的 Redis 服务器打下基础。
提交到 GitHub
git add .
git commit -m "feat: add echo server"
git push