Code前端首页关于Code前端联系我们

Go语言实现的几种限流算法

terry 2年前 (2023-09-27) 阅读数 99 #数据结构与算法

1.漏桶算法

  • 算法思想

是一种与令牌桶“逆向”的算法。当请求到来时,它首先被放入木桶中,worker使用固定的速率从木桶中取出请求。如果桶已满,立即返回错误码,表示请求速率超过限制或页面

  • 适用场景

目前流量最均匀的限流方式,通常用于流量整形,如流量限制以保护数据库。首先将数据库访问添加到桶中,然后以db可以承受的qps从桶中拉取请求来访问数据库。不适合电商紧急购物、微博热门事件等场景限电。首先,应对突发流量不太灵活。其次,它为每个 user_id/ip 维护一个队列(桶),worker 从这些队列中检索数据。运行任务会消耗大量资源。

  • Go语言实现

通常使用队列来实现。在Go语言中,可以通过缓冲通道快速部署。将任务添加到通道中,并开启一定数量的workers,从通道中获取任务执行。

package main

import (
 "fmt"
 "sync"
 "time"
)

// 每个请求来了,把需要执行的业务逻辑封装成Task,放入木桶,等待worker取出执行
type Task struct {
 handler func() Result // worker从木桶中取出请求对象后要执行的业务逻辑函数
 resChan chan Result   // 等待worker执行并返回结果的channel
 taskID  int
}

// 封装业务逻辑的执行结果
type Result struct {
}

// 模拟业务逻辑的函数
func handler() Result {
 time.Sleep(300 * time.Millisecond)
 return Result{}
}

func NewTask(id int) Task {
 return Task{
  handler: handler,
  resChan: make(chan Result),
  taskID:  id,
 }
}

// 漏桶
type LeakyBucket struct {
 BucketSize int       // 木桶的大小
 NumWorker  int       // 同时从木桶中获取任务执行的worker数量
 bucket     chan Task // 存方任务的木桶
}

func NewLeakyBucket(bucketSize int, numWorker int) *LeakyBucket {
 return &LeakyBucket{
  BucketSize: bucketSize,
  NumWorker:  numWorker,
  bucket:     make(chan Task, bucketSize),
 }
}

func (b *LeakyBucket) validate(task Task) bool {
 // 如果木桶已经满了,返回false
 select {
 case b.bucket <- task:
 default:
  fmt.Printf("request[id=%d] is refused\n", task.taskID)
  return false
 }

 // 等待worker执行
 <-task.resChan
 fmt.Printf("request[id=%d] is run\n", task.taskID)
 return true
}

func (b *LeakyBucket) Start() {
 // 开启worker从木桶拉取任务执行
 go func() {
  for i := 0; i < b.NumWorker; i++ {
   go func() {
    for {
     task := <-b.bucket
     result := task.handler()
     task.resChan <- result
    }
   }()
  }
 }()
}

func main() {
 bucket := NewLeakyBucket(10, 4)
 bucket.Start()

 var wg sync.WaitGroup
 for i := 0; i < 20; i++ {
  wg.Add(1)
  go func(id int) {
   defer wg.Done()
   task := NewTask(id)
   bucket.validate(task)
  }(i)
 }
 wg.Wait()
}

2.令牌桶算法

  • 算法思想

想象有一个木桶,令牌以固定的速率添加到桶中。一旦桶满了,就不再添加令牌。当服务收到请求时,它会尝试从桶中检索令牌。如果能够获取到token,则继续执行后续的业务逻辑;如果没有收到token,会立即返回错误码或者查询率超过限制的页面等,并且不会继续执行。后续业务逻辑的特点

  • :因为只要桶里有令牌就可以处理请求,所以令牌桶算法可以支持突发流量。同时,由于向桶中添加代币的速率是固定的,并且桶的容量有上限,因此也可以控制单位时间处理的请求数量,以达到限流的目的。假设token添加率为1token/10ms,桶容量为500。当请求比较小时(每10毫秒少于1个请求),桶可以先“存储”一些令牌(最多500个)。 。当流量很大时,桶里的令牌一下子被清空,即同时执行500个业务逻辑。此后,必须每 10 毫秒补充一个新令牌,然后才能接收新请求。
  • 参数设置:桶容量——考虑业务逻辑的资源消耗以及机器可以同时处理多少业务逻辑。令牌生成的速度——如果太慢,将无法“积累”令牌来应对突发流量。
  • 适用场景:适合电商抢购或者微博热门事件等场景,在限流的同时可以处理一定量的突发流量。如果采用统一速率处理请求的算法,就可以避免热点期间大量用户无法访问,提高用户体验。
  • Go语言实现:假设每100ms生成一个token,并根据user_id/IP记录最近一次访问的时间戳t_last和token数量。如果现在 - 最后 > 100 毫秒,则为每个请求递增(现在 - 最后)。 /100ms 令牌。如果令牌数>0,则令牌数-1则继续执行后续业务逻辑,否则返回错误码或请求速率超过限制的页面。
package main

import (
 "fmt"
 "sync"
 "time"
)

// 并发访问同一个user_id/ip的记录需要上锁
var recordMu map[string]*sync.RWMutex

func init() {
 recordMu = make(map[string]*sync.RWMutex)
}

func max(a, b int) int {
 if a > b {
  return a
 }
 return b
}

type TokenBucket struct {
 BucketSize int // 木桶内的容量:最多可以存放多少个令牌
 TokenRate time.Duration // 多长时间生成一个令牌
 records map[string]*record // 报错user_id/ip的访问记录
}

// 上次访问时的时间戳和令牌数
type record struct {
 last time.Time
 token int
}

func NewTokenBucket(bucketSize int, tokenRate time.Duration) *TokenBucket {
 return &TokenBucket{
  BucketSize: bucketSize,
  TokenRate:  tokenRate,
  records:    make(map[string]*record),
 }
}

func (t *TokenBucket) getUidOrIp() string {
 // 获取请求用户的user_id或者ip地址
 return "127.0.0.1"
}

// 获取这个user_id/ip上次访问时的时间戳和令牌数
func (t *TokenBucket) getRecord(uidOrIp string) *record {
 if r, ok := t.records[uidOrIp]; ok {
  return r
 }
 return &record{}
}

// 保存user_id/ip最近一次请求时的时间戳和令牌数量
func (t *TokenBucket) storeRecord(uidOrIp string, r *record) {
 t.records[uidOrIp] = r
}

// 验证是否能获取一个令牌
func (t *TokenBucket) validate(uidOrIp string) bool {
 // 并发修改同一个用户的记录上写锁
 rl, ok := recordMu[uidOrIp]
 if !ok {
  var mu sync.RWMutex
  rl = &mu
  recordMu[uidOrIp] = rl
 }
 rl.Lock()
 defer rl.Unlock()

 r := t.getRecord(uidOrIp)
 now := time.Now()
 if r.last.IsZero() {
  // 第一次访问初始化为最大令牌数
  r.last, r.token = now, t.BucketSize
 } else {
  if r.last.Add(t.TokenRate).Before(now) {
   // 如果与上次请求的间隔超过了token rate
   // 则增加令牌,更新last
   r.token += max(int(now.Sub(r.last) / t.TokenRate), t.BucketSize)
   r.last = now
  }
 }
 var result bool
 if r.token > 0 {
  // 如果令牌数大于1,取走一个令牌,validate结果为true
  r.token--
  result = true
 }

 // 保存最新的record
 t.storeRecord(uidOrIp, r)
 return result
}

// 返回是否被限流
func (t *TokenBucket) IsLimited() bool {
 return !t.validate(t.getUidOrIp())
}

func main() {
 tokenBucket := NewTokenBucket(5, 100*time.Millisecond)
 for i := 0; i< 6; i++ {
  fmt.Println(tokenBucket.IsLimited())
 }
 time.Sleep(100 * time.Millisecond)
 fmt.Println(tokenBucket.IsLimited())
}

3。滑动时间窗算法

  • 算法思想

滑动时间窗算法是对计数规则时间窗的优化。当使用常规时间窗口时,我们为每个 user_id/ip 维护一个 KV: uidOrIp: timestamp_requestCount 。假设限制每秒1000个请求,那么100ms有一个请求,KV就变成uidOrIp:timestamp_1。 200 毫秒后有 1 个请求。我们首先比较记录的时间戳是否相差超过1秒。如果没有,只需更新计数即可。此时KV变为uidOrIp:timestamp_2。当请求到达 1100 毫秒时,更新记录中的时间戳并重置计数。 KV 变为 uidOrIp:newtimestamp_1。常规时间窗口有问题。假设有500个请求集中在前1s和最后100ms,500个请求集中在最后100ms。在1s的前100ms内,请求实际上在这200ms之前超出了限制。但由于时间窗口内每1秒重置一次计数,目前无法识别超出限制的请求。

对于滑动时间窗口,我们可以将1ms的时间窗口分为10个时隙,每个时隙统计给定100ms的请求数量。每隔100ms,窗口中添加一个新的时隙,比当前时间早100ms的时隙离开窗口。窗口内最多维持10个时隙,存储空间的消耗也比较低。

  • 适用场景

与令牌桶一样,可以应对突发流量。

  • Go语言实现

主要实现滑动窗口算法。可以参考B站开源的kratos框架中断路器的实现,它使用循环列表来存储时隙对象。它们实现的优点是不需要频繁创建和销毁时隙对象。下面是一个简单的基本实现:

package main

import (
 "fmt"
 "sync"
 "time"
)

var winMu map[string]*sync.RWMutex

func init() {
 winMu = make(map[string]*sync.RWMutex)
}

type timeSlot struct {
 timestamp time.Time // 这个timeSlot的时间起点
 count     int       // 落在这个timeSlot内的请求数
}

func countReq(win []*timeSlot) int {
 var count int
 for _, ts := range win {
  count += ts.count
 }
 return count
}

type SlidingWindowLimiter struct {
 SlotDuration time.Duration // time slot的长度
 WinDuration  time.Duration // sliding window的长度
 numSlots     int           // window内最多有多少个slot
 windows      map[string][]*timeSlot
 maxReq       int // win duration内允许的最大请求数
}

func NewSliding(slotDuration time.Duration, winDuration time.Duration, maxReq int) *SlidingWindowLimiter {
 return &SlidingWindowLimiter{
  SlotDuration: slotDuration,
  WinDuration:  winDuration,
  numSlots:     int(winDuration / slotDuration),
  windows:      make(map[string][]*timeSlot),
  maxReq:       maxReq,
 }
}

// 获取user_id/ip的时间窗口
func (l *SlidingWindowLimiter) getWindow(uidOrIp string) []*timeSlot {
 win, ok := l.windows[uidOrIp]
 if !ok {
  win = make([]*timeSlot, 0, l.numSlots)
 }
 return win
}

func (l *SlidingWindowLimiter) storeWindow(uidOrIp string, win []*timeSlot) {
 l.windows[uidOrIp] = win
}

func (l *SlidingWindowLimiter) validate(uidOrIp string) bool {
 // 同一user_id/ip并发安全
 mu, ok := winMu[uidOrIp]
 if !ok {
  var m sync.RWMutex
  mu = &m
  winMu[uidOrIp] = mu
 }
 mu.Lock()
 defer mu.Unlock()

 win := l.getWindow(uidOrIp)
 now := time.Now()
 // 已经过期的time slot移出时间窗
 timeoutOffset := -1
 for i, ts := range win {
  if ts.timestamp.Add(l.WinDuration).After(now) {
   break
  }
  timeoutOffset = i
 }
 if timeoutOffset > -1 {
  win = win[timeoutOffset+1:]
 }

 // 判断请求是否超限
 var result bool
 if countReq(win) < l.maxReq {
  result = true
 }

 // 记录这次的请求数
 var lastSlot *timeSlot
 if len(win) > 0 {
  lastSlot = win[len(win)-1]
  if lastSlot.timestamp.Add(l.SlotDuration).Before(now) {
   lastSlot = &timeSlot{timestamp: now, count: 1}
   win = append(win, lastSlot)
  } else {
   lastSlot.count++
  }
 } else {
  lastSlot = &timeSlot{timestamp: now, count: 1}
  win = append(win, lastSlot)
 }

 l.storeWindow(uidOrIp, win)

 return result
}

func (l *SlidingWindowLimiter) getUidOrIp() string {
 return "127.0.0.1"
}

func (l *SlidingWindowLimiter) IsLimited() bool {
 return !l.validate(l.getUidOrIp())
}

func main() {
 limiter := NewSliding(100*time.Millisecond, time.Second, 10)
 for i := 0; i < 5; i++ {
  fmt.Println(limiter.IsLimited())
 }
 time.Sleep(100 * time.Millisecond)
 for i := 0; i < 5; i++ {
  fmt.Println(limiter.IsLimited())
 }
 fmt.Println(limiter.IsLimited())
 for _, v := range limiter.windows[limiter.getUidOrIp()] {
  fmt.Println(v.timestamp, v.count)
 }

 fmt.Println("a thousand years later...")
 time.Sleep(time.Second)
 for i := 0; i < 7; i++ {
  fmt.Println(limiter.IsLimited())
 }
 for _, v := range limiter.windows[limiter.getUidOrIp()] {
  fmt.Println(v.timestamp, v.count)
 }
}

版权声明

本文仅代表作者观点,不代表Code前端网立场。
本文系作者Code前端网发表,如需转载,请注明页面地址。

热门