限流组件开发与学习

限流组件开发与学习

为了让羊毛党利用脚本等工具快速多次抽奖印象活动体验,我们可以写一个限流的组件,来阻止短时间内大量请求。

首先来定义这个注解:RateLimiterAccessInterceptor

注解 RateLimiterAccessInterceptor

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@Documented
public @interface RateLimiterAccessInterceptor {

/** 用哪个字段作为拦截标识,未配置则默认走全部 */
String key() default "all";

/** 限制频次(每秒请求次数) */
double permitsPerSecond();

/** 黑名单拦截(多少次限制后加入黑名单)0 不限制 */
double blacklistCount() default 0;

/** 拦截后的执行方法 */
String fallbackMethod();
}

这个注解用于标识需要进行限流的方法。该注解包含以下属性:

  • key: 用于指定哪个字段作为限流的标识符,默认值为 "all",表示对所有请求统一限流。
  • permitsPerSecond: 每秒允许的请求次数,用于配置限流的频率。
  • blacklistCount: 黑名单拦截的阈值,当某个标识符的请求次数超过该值后,将其加入黑名单。默认值为 0,表示不启用黑名单功能。
  • fallbackMethod: 当请求被限流或拦截后,执行的回调方法名称。

切面类 RateLimiterAOP

RateLimiterAOP 是一个切面类,负责拦截标注了 RateLimiterAccessInterceptor 注解的方法,并实现具体的限流逻辑。

主要成员变量:

  • rateLimiterSwitch: 通过 @DCCValue 注解从配置中心获取的限流开关,用于控制限流功能的开启和关闭。
  • loginRecord: 使用 Guava 的 Cache 实现,每个 key 对应一个 RateLimiter,用于控制请求频率。记录的有效期为 1 分钟。
  • blacklist: 使用 Guava 的 Cache 实现,用于记录被加入黑名单的标识符,记录的有效期为 24 小时。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
@Slf4j
@Aspect
@Component
public class RateLimiterAOP {

@DCCValue("rateLimiterSwitch:close")
private String rateLimiterSwitch;


// 个人限频记录1分钟
private final Cache<String, RateLimiter> loginRecord = CacheBuilder.newBuilder()
.expireAfterWrite(1, TimeUnit.MINUTES)
.build();

// 个人限频黑名单24h - 分布式业务场景,可以记录到 Redis 中
private final Cache<String, Long> blacklist = CacheBuilder.newBuilder()
.expireAfterWrite(24, TimeUnit.HOURS)
.build();
/**/
}

切点定义:

定义了一个切点 aopPoint,匹配所有标注了 RateLimiterAccessInterceptor 注解的方法。

1
2
3
4
5
6
7
8
9
@Slf4j
@Aspect
@Component
public class RateLimiterAOP {
/**/
@Pointcut("@annotation(io.github.jasonxqh.types.annotations.RateLimiterAccessInterceptor)")
public void aopPoint(){}
/**/
}

环绕通知 doRouter

RateLimiter是什么类?

RateLimiter 类介绍:

  • 来源RateLimiter 是 Google Guava 库中的一个类,位于 com.google.common.util.concurrent 包中。
  • 功能:实现了基于令牌桶算法的限流器,用于控制代码执行的速率,防止系统被过多请求压垮。

主要特性:

  1. 令牌桶算法
    • RateLimiter 采用令牌桶算法,通过以固定速率生成令牌,控制请求的速率。
    • 每个请求在执行前需要从桶中获取一个令牌,如果令牌可用,则允许请求执行;否则,拒绝请求。
  2. 两种获取令牌的方式
    • 阻塞获取acquire() 方法会阻塞,直到获取到令牌。
    • 非阻塞获取tryAcquire() 方法会立即返回,表示是否成功获取到令牌。
  3. 速率控制
    • 可以动态调整令牌生成速率,适应不同的限流需求。

主要方法:

  • RateLimiter.create(double permitsPerSecond)
    • 创建一个以指定速率(每秒发放的令牌数)生成令牌的 RateLimiter 实例。因此,我们在用@RateLimiterAccessInterceptor 注解一个方法的时候,需要定义这个permitsPerSecond字段,他决定了隔多少秒可以抽一次奖。
  • boolean tryAcquire()
    • 尝试立即获取一个令牌,如果成功返回 true,否则返回 false
    • 非阻塞方式,适用于需要快速判断是否允许请求的场景。
  • void acquire()
    • 阻塞直到获取到一个令牌,适用于需要严格限制执行速率的场景。

核心逻辑

这是核心的限流逻辑实现。以下是详细流程:

  1. 限流开关检查
    • 如果 rateLimiterSwitch 未配置或值为 "close",则不进行限流,直接执行目标方法 pjp.proceed()
  2. 获取限流标识符
    • 从注解中获取 key 属性值。
    • 通过 getAttrValue 方法,从目标方法的参数中提取出 key 对应的值 keyAttr
    • 如果 keyAttr"all",则表示对所有请求统一限流;否则,对特定标识符(如 userId)进行限流。
  3. 黑名单检查
    • 如果 blacklistCount 不为 0,并且 keyAttr 已经在黑名单中且超过了 blacklistCount,则直接调用回调方法 fallbackMethod,拒绝此次请求。
  4. 限流检查
    • loginRecord 中获取对应 keyRateLimiter 实例,如果不存在则创建一个新的 RateLimiter,并放入缓存中。loginRecord 中的 RateLimiter 会在 1 分钟后过期,若该用户在此期间没有新的请求,其 RateLimiter 会被移除。
    • 调用 rateLimiter.tryAcquire()尝试获取一个许可。如果获取失败,表示当前请求超出限流频率:
      • 如果 blacklistCount 不为 0,即配置的黑名单阈值不为 0。则记录此次超频行为,更新黑名单计数。
      • 调用回调方法 fallbackMethod,拒绝此次请求。
  5. 允许请求
    • 如果成功获取到许可,则执行目标方法 pjp.proceed(),允许此次请求通过。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
@Slf4j
@Aspect
@Component
public class RateLimiterAOP {
@Around("aopPoint() && @annotation(rateLimiterAccessInterceptor)")
public Object doRouter(ProceedingJoinPoint pjp, RateLimiterAccessInterceptor rateLimiterAccessInterceptor) throws Throwable {
//限流开关[open 开启 ,close 关闭] 关闭后不走限流策略
if(StringUtils.isBlank(rateLimiterSwitch)||"close".equals(rateLimiterSwitch)){
return pjp.proceed();
}
String key = rateLimiterAccessInterceptor.key();
if(StringUtils.isBlank(key)){
throw new RuntimeException("uId is null or empty");
}

String keyAttr = getAttrValue(key, pjp.getArgs());
log.info("aop attr {}",keyAttr);
if(!"all".equals(keyAttr)
&& rateLimiterAccessInterceptor.blacklistCount() != 0
&& null != blacklist.getIfPresent(keyAttr)
&& blacklist.getIfPresent(keyAttr) > rateLimiterAccessInterceptor.blacklistCount()) {
log.info("限流-黑名单拦截(24h):{}", keyAttr);
return fallbackMethodResult(pjp, rateLimiterAccessInterceptor.fallbackMethod());
}

RateLimiter rateLimiter = loginRecord.getIfPresent(key);
if(rateLimiter == null){
rateLimiter = RateLimiter.create(rateLimiterAccessInterceptor.permitsPerSecond());
loginRecord.put(key, rateLimiter);
}
//未获取到令牌,超出了限制频率,视为违约1次,那么此时就需要在黑名单中加入1次
if(!rateLimiter.tryAcquire()){
if(rateLimiterAccessInterceptor.blacklistCount() != 0){
if(null == blacklist.getIfPresent(keyAttr)){
blacklist.put(keyAttr, 1L);
}else{
blacklist.put(keyAttr, blacklist.getIfPresent(keyAttr) + 1L);
}
}
log.info("限流-超频次拦截:{} ",keyAttr);
return fallbackMethodResult(pjp, rateLimiterAccessInterceptor.fallbackMethod());
}

return pjp.proceed();
}
/**/
}

辅助方法:

  • fallbackMethodResult:通过反射调用用户配置的回调方法。当请求被限流或拦截后,执行此方法以返回预定义的响应。
1
2
3
4
5
6
7
8
9
10
/**
* 调用用户配置的回调方法,当拦截后,返回回调结果。
*/
private Object fallbackMethodResult(JoinPoint jp, String fallbackMethod) throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
Signature sig = jp.getSignature();
MethodSignature methodSignature = (MethodSignature) sig;
Method method = jp.getTarget().getClass().getMethod(fallbackMethod, methodSignature.getParameterTypes());
return method.invoke(jp.getThis(), jp.getArgs());
}
/**/
  • getAttrValue:根据 key 从目标方法的参数中提取对应的值。支持通过反射获取对象属性值,适应不同参数类型。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23

/**
* 实际根据自身业务调整,主要是为了获取通过某个值做拦截
*/
public String getAttrValue(String attr, Object[] args) {
if (args[0] instanceof String) {
return args[0].toString();
}
String filedValue = null;
for (Object arg : args) {
try {
if (StringUtils.isNotBlank(filedValue)) {
break;
}
// filedValue = BeanUtils.getProperty(arg, attr);
// fix: 使用lombok时,uId这种字段的get方法与idea生成的get方法不同,会导致获取不到属性值,改成反射获取解决
filedValue = String.valueOf(this.getValueByName(arg, attr));
} catch (Exception e) {
log.error("获取路由属性值失败 attr:{}", attr, e);
}
}
return filedValue;
}
  • getValueByNamegetFieldByName:通过反射获取对象的指定属性值,支持获取父类属性。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
/**
* 获取对象的特定属性值
*
* @param item 对象
* @param name 属性名
* @return 属性值
* @author tang
*/
private Object getValueByName(Object item, String name) {
try {
Field field = getFieldByName(item, name);
if (field == null) {
return null;
}
field.setAccessible(true);
Object o = field.get(item);
field.setAccessible(false);
return o;
} catch (IllegalAccessException e) {
return null;
}
}
/**/
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

/**
* 根据名称获取方法,该方法同时兼顾继承类获取父类的属性
*
* @param item 对象
* @param name 属性名
* @return 该属性对应方法
* @author tang
*/
private Field getFieldByName(Object item, String name) {
try {
Field field;
try {
field = item.getClass().getDeclaredField(name);
} catch (NoSuchFieldException e) {
field = item.getClass().getSuperclass().getDeclaredField(name);
}
return field;
} catch (NoSuchFieldException e) {
return null;
}
}

实际使用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
@RateLimiterAccessInterceptor(key = "userId",fallbackMethod = "drawRateLimiterError",permitsPerSecond = 1.0d,blacklistCount = 1)
@RequestMapping(value = "draw", method = RequestMethod.POST)
@Override
public Response<ActivityDrawResponseDTO> draw(@RequestBody ActivityDrawRequestDTO requestDTO) {
try {
log.info("活动抽奖开始 userId:{} activityId:{}", requestDTO.getUserId(), requestDTO.getActivityId());
// 0.降级开关 [open为开启,close为关闭]
if (StringUtils.isNotBlank(degradeSwitch) && "open".equals(degradeSwitch)) {
log.info("当前degradeSwitch配置降级: {}",degradeSwitch);
return Response.<ActivityDrawResponseDTO>builder()
.code(ResponseCode.DEGRADE_SWITCH.getCode())
.info(ResponseCode.DEGRADE_SWITCH.getInfo())
.build();
}

Long activityId = requestDTO.getActivityId();
String userId = requestDTO.getUserId();
if(activityId == null || userId == null) {
throw new AppException(ResponseCode.ILLEGAL_PARAMETER.getCode(),ResponseCode.ILLEGAL_PARAMETER.getInfo());
}
//1. 构建抽奖单,此时已经执行抽奖次数扣减了
UserRaffleOrderEntity orderEntity = activityPartakeService.createOrder(activityId, userId);
log.info("活动抽奖,创建订单 userId:{} activityId:{} orderId:{}",userId,activityId, orderEntity.getOrderId());
//2. 执行抽奖,消费抽奖单
RaffleAwardEntity raffleAwardEntity = strategyService.performRaffle(RaffleFactorEntity.builder()
.strategyId(orderEntity.getStrategyId())
.userId(userId)
.endDateTime(orderEntity.getEndDateTime())
.build());

//3.构造发奖记录,存储记录和Task,并更新抽奖单状态为已使用
UserAwardRecordEntity userAwardRecord = UserAwardRecordEntity.builder()
.userId(orderEntity.getUserId())
.activityId(orderEntity.getActivityId())
.strategyId(orderEntity.getStrategyId())
.orderId(orderEntity.getOrderId())
.awardId(raffleAwardEntity.getAwardId())
.awardTitle(raffleAwardEntity.getAwardTitle())
.awardConfig(raffleAwardEntity.getAwardConfig())
.awardTime(new Date())
.awardState(AwardStateVO.create)
.build();
awardService.saveUserAwardRecord(userAwardRecord);

//4.返回结果
Response<ActivityDrawResponseDTO> response = Response.<ActivityDrawResponseDTO>builder()
.code(ResponseCode.SUCCESS.getCode())
.info(ResponseCode.SUCCESS.getInfo())
.data(ActivityDrawResponseDTO.builder()
.awardIndex(raffleAwardEntity.getSort())
.awardTitle(raffleAwardEntity.getAwardTitle())
.awardId(raffleAwardEntity.getAwardId())
.build())
.build();

return response;
}catch (AppException e) {
log.error("活动抽奖失败 userId:{} activityId:{}", requestDTO.getUserId(), requestDTO.getActivityId(), e);
return Response.<ActivityDrawResponseDTO>builder()
.code(e.getCode())
.info(e.getInfo())
.build();
} catch (Exception e) {
log.error("活动抽奖失败 userId:{} activityId:{}", requestDTO.getUserId(), requestDTO.getActivityId(), e);
return Response.<ActivityDrawResponseDTO>builder()
.code(ResponseCode.UN_ERROR.getCode())
.info(ResponseCode.UN_ERROR.getInfo())
.build();
}
}

public Response<ActivityDrawResponseDTO> drawRateLimiterError(@RequestBody ActivityDrawRequestDTO requestDTO) {
log.error("抽奖活动限流 userId:{} activityId:{}", requestDTO.getUserId(), requestDTO.getActivityId());
return Response.<ActivityDrawResponseDTO>builder()
.code(ResponseCode.RATE_LIMITER.getCode())
.info(ResponseCode.RATE_LIMITER.getInfo())
.build();
}

Redis版本

我们可以将缓存在Guava中的BlackList和loginRecord缓存在redis中,并模拟出RateLimiter的令牌桶策略:

使用Redis实现类似RateLimiter的令牌桶逻辑如下:

令牌桶的构建和更新

  • 获取限流速率 permitsPerSecond
  • 构建令牌桶的 Redis 键 bucketKey
  • 从 Redis 获取对应的令牌桶状态 bucketMap,包含 tokenslast_refill_time
  • 如果 bucketMapnull 或为空,初始化令牌数为 1,设置 last_refill_time 为当前时间,并设置过期时间为 1 小时
  • 否则,从 bucketMap 中获取当前的 storedTokenslast_refill_time
  • 计算自上次补充令牌以来经过的毫秒数 elapsedMs
  • 根据限流速率计算新产生的令牌数 newTokens
  • 将新令牌数添加到 storedTokens,并限制令牌数不超过 1(令牌桶容量),这样不管中间间隔了多少时间,都不会令令牌累计。
  • 更新 last_refill_time 为当前时间。

判断令牌是否足够

  • 如果 storedTokens 不足(即 <= 0),表示请求超过了限流频率。
  • 如果配置了黑名单阈值且 keyAttr 不是 "all",则将 keyAttr 的黑名单计数器递增 1
  • 如果是第一次违反限流(newVal == 1),设置黑名单键的过期时间为 24 小时
  • 更新令牌桶的状态(虽然 storedTokens 已经不够,但依然更新是为了保持数据一致性)。
  • 记录日志并调用回调方法 fallbackMethod,拦截请求。

扣减令牌并允许请求

  • 如果有足够的令牌(storedTokens > 0),则扣减 1 个令牌。说明这次是允许访问的
  • 更新令牌桶的状态。
  • 允许请求继续执行目标方法。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
@Around("aopPoint() && @annotation(rateLimiterAccessInterceptor)")
public Object doRouter(ProceedingJoinPoint pjp, RateLimiterAccessInterceptor rateLimiterAccessInterceptor) throws Throwable {

// 1. 检查总开关
if (StringUtils.isBlank(rateLimiterSwitch)
|| "close".equalsIgnoreCase(rateLimiterSwitch)) {
return pjp.proceed();
}

// 2. 从注解取到 key,即uid
String key = rateLimiterAccessInterceptor.key();
if(StringUtils.isBlank(key)){
throw new RuntimeException("uId is null or empty");
}

// 获取调用方法时传入的真正值,比如 userId
String keyAttr = getAttrValue(key, pjp.getArgs());
log.info("[RateLimiterAOP] keyAttr={}", keyAttr);

// -------------------------------------
// 黑名单逻辑:如果某 keyAttr 超频几次,就记到这里
// key: BLACKLIST_PREFIX + keyAttr
// val: 超过次数
// TTL: 24 小时
// -------------------------------------
// 3. 如果设置了黑名单次数,就先检查黑名单
double blackCountThreshold = rateLimiterAccessInterceptor.blacklistCount();
// 非 "all" 且设定了黑名单阈值时,先检查是否已经在黑名单
if (!"all".equals(keyAttr) && blackCountThreshold > 0) {
String blackKey = Constants.RedisKey.BLACKLIST_PREFIX + keyAttr;
Integer blackVal = redisService.getValue(blackKey);
if (blackVal != null && (double) blackVal > blackCountThreshold) {
// 已达黑名单阈值
log.info("限流-黑名单拦截(24h):{}", keyAttr);
return fallbackMethodResult(pjp, rateLimiterAccessInterceptor.fallbackMethod());
}
}

// 3) 令牌桶逻辑:TOKEN_BUCKET_PREFIX + userId => RMap<String, String>
double permitsPerSecond = rateLimiterAccessInterceptor.permitsPerSecond();
String bucketKey = Constants.RedisKey.TOKEN_BUCKET_PREFIX + keyAttr;
// 3.1) 获取 map 中的 tokens / lastRefillTime
RMap<String, String> bucketMap = redisService.getMap(bucketKey);
long currentTime = System.currentTimeMillis();
//初始必须为1,否则会直接触发黑名单
long storedTokens = 1L;
long lastRefillTime = currentTime;

if (bucketMap == null || bucketMap.isEmpty()) {
// 第一次初始化
bucketMap = redisService.getMap(bucketKey);
bucketMap.put("tokens", String.valueOf(storedTokens));
bucketMap.put("last_refill_time", String.valueOf(lastRefillTime));
// 如果需要令牌桶也自动过期,可调用
bucketMap.expire(1, TimeUnit.HOURS);
} else {
String tokenStr = bucketMap.get("tokens");
String timeStr = bucketMap.get("last_refill_time");

if (tokenStr != null) {
storedTokens = Long.parseLong(tokenStr);
}
if (timeStr != null) {
lastRefillTime = Long.parseLong(timeStr);
}
}

// 3.2) 计算需要补充的新令牌
long elapsedMs = currentTime - lastRefillTime;
if (elapsedMs > 0) {
double permitsPerMs = permitsPerSecond / 1000.0;
long newTokens = (long) (elapsedMs * permitsPerMs);
if (newTokens > 0) {
storedTokens += newTokens;
// 上限限制为1,也就是每隔permitsPerSecond最多只能产出1个令牌,无法堆积
storedTokens = Math.min(storedTokens, 1);
lastRefillTime = currentTime;
}
}

// 3.3) 判断令牌是否足够
if (storedTokens <= 0) {
// 不足 => 违约一次 => 进入黑名单统计
if (blackCountThreshold > 0 && !"all".equals(keyAttr)) {
String blackKey = Constants.RedisKey.BLACKLIST_PREFIX + keyAttr;
long newVal = redisService.incrBy(blackKey, 1);
// 如果是第一次违约 => 设置 24h 过期
if (newVal == 1) {
redisService.expire(blackKey, 24, TimeUnit.HOURS);
}
}

// 更新令牌桶信息(可能 lastRefillTime 已刷新)
bucketMap.put("tokens", String.valueOf(storedTokens));
bucketMap.put("last_refill_time", String.valueOf(lastRefillTime));
log.info("限流-超频次拦截(keyAttr={}) => fallback", keyAttr);

return fallbackMethodResult(pjp, rateLimiterAccessInterceptor.fallbackMethod());
}

// 3.4) 如果有令牌 => 扣减一个
storedTokens -= 1;
bucketMap.put("tokens", String.valueOf(storedTokens));
bucketMap.put("last_refill_time", String.valueOf(lastRefillTime));

// 4) 正常执行被拦截方法
return pjp.proceed();
}
-------------本文结束,感谢您的阅读-------------