127 lines
3.8 KiB
Go
127 lines
3.8 KiB
Go
package ratelimit
|
|
|
|
const (
	// combinedRateLimitScript checks the burst, daily and monthly limits in a
	// single atomic call and, only when all three pass, records the request
	// against each of them.
	//
	// KEYS[1] - The burst bucket key (sorted set of request tokens scored by timestamp)
	// KEYS[2] - The daily quota key (integer counter)
	// KEYS[3] - The monthly quota key (integer counter)
	// ARGV[1] - Current timestamp in nanoseconds
	// ARGV[2] - Cutoff timestamp for old tokens (tokens scored <= cutoff are dropped)
	// ARGV[3] - Maximum burst tokens allowed
	// ARGV[4] - Burst expiration time in seconds
	// ARGV[5] - Daily maximum requests
	// ARGV[6] - Daily expiration time in seconds
	// ARGV[7] - Monthly maximum requests
	// ARGV[8] - Monthly expiration time in seconds
	//
	// Returns {allowed (1 or 0), remainingBurst, remainingDaily, remainingMonthly}.
	// When allowed, the remaining counts already account for this request.
	combinedRateLimitScript = `
-- Read a counter key, treating a missing key (GET -> false) as 0.
-- The second result reports whether the key existed.
local function readCounter(key)
	local raw = redis.call('GET', key)
	if raw then
		return tonumber(raw), true
	end
	return 0, false
end

-- Bump a counter. An existing key is INCRed, which keeps its TTL; a new key
-- is created with the given expiration.
local function incrementCounter(key, exists, ttl)
	if exists then
		redis.call('INCR', key)
	else
		redis.call('SETEX', key, ttl, 1)
	end
end

local burstKey = KEYS[1]
local dailyKey = KEYS[2]
local monthlyKey = KEYS[3]

-- Timestamps are kept as the raw ARGV strings. Passing them through
-- tonumber() and back would round nanosecond values to double precision.
local now = ARGV[1]
local cutoff = ARGV[2]
local maxBurst = tonumber(ARGV[3])
local burstExpiration = tonumber(ARGV[4])
local maxDaily = tonumber(ARGV[5])
local dailyExpiration = tonumber(ARGV[6])
local maxMonthly = tonumber(ARGV[7])
local monthlyExpiration = tonumber(ARGV[8])

-- Burst limit (sliding window of request timestamps): drop tokens at or
-- before the cutoff, then count what is left.
redis.call('ZREMRANGEBYSCORE', burstKey, '-inf', cutoff)
local burstCount = redis.call('ZCARD', burstKey)
local remainingBurst = maxBurst - burstCount

-- Daily and monthly quotas.
local dailyUsed, dailyExists = readCounter(dailyKey)
local remainingDaily = maxDaily - dailyUsed

local monthlyUsed, monthlyExists = readCounter(monthlyKey)
local remainingMonthly = maxMonthly - monthlyUsed

-- The request is allowed only if every limit still has headroom.
local allowed = remainingBurst > 0 and remainingDaily > 0 and remainingMonthly > 0

if allowed then
	-- Sorted-set members are unique. Using the bare timestamp as the member
	-- made requests with the same timestamp overwrite each other and
	-- under-count the burst, so add a sequence suffix and advance it until
	-- the member is unused.
	local seq = burstCount
	local member = now .. '-' .. seq
	while redis.call('ZSCORE', burstKey, member) do
		seq = seq + 1
		member = now .. '-' .. seq
	end
	redis.call('ZADD', burstKey, now, member)
	redis.call('EXPIRE', burstKey, burstExpiration)

	incrementCounter(dailyKey, dailyExists, dailyExpiration)
	incrementCounter(monthlyKey, monthlyExists, monthlyExpiration)

	-- Report what is left after this request.
	remainingBurst = remainingBurst - 1
	remainingDaily = remainingDaily - 1
	remainingMonthly = remainingMonthly - 1
end

return {
	allowed and 1 or 0,
	remainingBurst,
	remainingDaily,
	remainingMonthly
}
`

	// readOnlyRateLimitScript checks the same limits as combinedRateLimitScript
	// without recording the request or modifying any key.
	//
	// KEYS[1] - The burst bucket key
	// KEYS[2] - The daily quota key
	// KEYS[3] - The monthly quota key
	// ARGV[1] - Current timestamp (not used by this script)
	// ARGV[2] - Cutoff timestamp; only tokens scored strictly after it count
	// ARGV[3] - Burst limit
	// ARGV[4] - Daily limit
	// ARGV[5] - Monthly limit
	//
	// Returns {allowed (1 or 0), remainingBurst, remainingDaily, remainingMonthly}.
	readOnlyRateLimitScript = `
-- Keys: burstKey, dailyKey, monthlyKey
-- Args: now (unused), cutoff, burstLimit, dailyLimit, monthlyLimit
local cutoff = ARGV[2]
local burstLimit = tonumber(ARGV[3])
local dailyLimit = tonumber(ARGV[4])
local monthlyLimit = tonumber(ARGV[5])

-- Read a counter key, treating a missing key (GET -> false) as 0.
local function readCounter(key)
	local raw = redis.call('GET', key)
	if raw then
		return tonumber(raw)
	end
	return 0
end

-- Count only tokens strictly newer than the cutoff. The combined script
-- deletes tokens with score <= cutoff before counting, so an exclusive lower
-- bound keeps both scripts in agreement at the boundary. The raw ARGV string
-- is used so the '(' prefix does not lose precision. ZCOUNT counts on the
-- server instead of returning every member to Lua.
local burstCount = redis.call('ZCOUNT', KEYS[1], '(' .. cutoff, '+inf')
local remainingBurst = burstLimit - burstCount

local remainingDaily = dailyLimit - readCounter(KEYS[2])
local remainingMonthly = monthlyLimit - readCounter(KEYS[3])

-- Allowed only if every limit still has headroom.
local allowed = 0
if remainingBurst > 0 and remainingDaily > 0 and remainingMonthly > 0 then
	allowed = 1
end

return {allowed, remainingBurst, remainingDaily, remainingMonthly}
`
)
|