Refactor route rate limiter builder

This commit is contained in:
Matthieu Sieben 2025-05-23 14:50:52 +02:00
parent 36d0d370c2
commit da433bd0ca
3 changed files with 50 additions and 64 deletions

View File

@ -0,0 +1,5 @@
---
"@atproto/xrpc-server": patch
---
Refactor route rate limiter builder

View File

@ -48,6 +48,7 @@ import {
isShared,
} from './types'
import {
asArray,
decodeQueryParams,
getQueryParams,
validateInput,
@ -67,7 +68,6 @@ export class Server {
middleware: Record<'json' | 'text', RequestHandler>
globalRateLimiters: RateLimiterI[]
sharedRateLimiters: Record<string, RateLimiterI>
routeRateLimiters: Record<string, RateLimiterI[]>
constructor(lexicons?: LexiconDoc[], opts: Options = {}) {
if (lexicons) {
@ -86,7 +86,6 @@ export class Server {
}
this.globalRateLimiters = []
this.sharedRateLimiters = {}
this.routeRateLimiters = {}
if (opts?.rateLimits?.global) {
for (const limit of opts.rateLimits.global) {
const rateLimiter = opts.rateLimits.creator({
@ -177,7 +176,6 @@ export class Server {
middleware.push(this.middleware.json)
middleware.push(this.middleware.text)
}
this.setupRouteRateLimits(nsid, config)
this.routes[verb](
`/xrpc/${nsid}`,
...middleware,
@ -251,18 +249,8 @@ export class Server {
validateOutput(nsid, def, output, this.lex)
const assertValidXrpcParams = (params: unknown) =>
this.lex.assertValidXrpcParams(nsid, params)
const rls = this.routeRateLimiters[nsid] ?? []
const consumeRateLimit = (reqCtx: XRPCReqContext) =>
consumeMany(
reqCtx,
rls.map((rl) => (ctx: XRPCReqContext) => rl.consume(ctx)),
)
const resetRateLimit = (reqCtx: XRPCReqContext) =>
resetMany(
reqCtx,
rls.map((rl) => (ctx: XRPCReqContext) => rl.reset(ctx)),
)
const rateLimiter = this.createRouteRateLimiter(routeCfg)
return async function (req, res, next) {
try {
@ -283,11 +271,11 @@ export class Server {
auth: locals.auth,
req,
res,
resetRouteRateLimits: async () => resetRateLimit(reqCtx),
resetRouteRateLimits: async () => rateLimiter.reset(reqCtx),
}
// handle rate limits
const result = await consumeRateLimit(reqCtx)
const result = await rateLimiter.consume(reqCtx)
if (result instanceof RateLimitExceededError) {
return next(result)
}
@ -432,63 +420,53 @@ export class Server {
}
}
private setupRouteRateLimits(nsid: string, config: XRPCHandlerConfig) {
this.routeRateLimiters[nsid] = []
for (const limit of this.globalRateLimiters) {
this.routeRateLimiters[nsid].push({
consume: (ctx: XRPCReqContext) => limit.consume(ctx),
reset: (ctx: XRPCReqContext) => limit.reset(ctx),
})
}
private createRouteRateLimiter(config: XRPCHandlerConfig): RateLimiterI {
const rls: RateLimiterI[] = config.rateLimit
? asArray(config.rateLimit)
.map((options, i): RateLimiterI | null => {
const { calcKey, calcPoints } = options
if (config.rateLimit) {
const limits = Array.isArray(config.rateLimit)
? config.rateLimit
: [config.rateLimit]
this.routeRateLimiters[nsid] = []
for (let i = 0; i < limits.length; i++) {
const limit = limits[i]
const { calcKey, calcPoints } = limit
if (isShared(limit)) {
const rateLimiter = this.sharedRateLimiters[limit.name]
if (rateLimiter) {
this.routeRateLimiters[nsid].push({
consume: (ctx: XRPCReqContext) =>
const rateLimiter = isShared(options)
? this.sharedRateLimiters[options.name]
: this.options.rateLimits?.creator({
keyPrefix: `nsid-${i}`,
durationMs: options.durationMs,
points: options.points,
calcKey,
calcPoints,
})
if (!rateLimiter) return null
return {
consume: (ctx) =>
rateLimiter.consume(ctx, {
calcKey,
calcPoints,
}),
reset: (ctx: XRPCReqContext) =>
reset: (ctx) =>
rateLimiter.reset(ctx, {
calcKey,
}),
})
}
} else {
const { durationMs, points } = limit
const rateLimiter = this.options.rateLimits?.creator({
keyPrefix: `nsid-${i}`,
durationMs,
points,
calcKey,
calcPoints,
}
})
if (rateLimiter) {
this.sharedRateLimiters[nsid] = rateLimiter
this.routeRateLimiters[nsid].push({
consume: (ctx: XRPCReqContext) =>
rateLimiter.consume(ctx, {
calcKey,
calcPoints,
}),
reset: (ctx: XRPCReqContext) =>
rateLimiter.reset(ctx, {
calcKey,
}),
})
}
}
}
.filter((v) => v != null)
: this.globalRateLimiters.map((limit) => ({
consume: (ctx) => limit.consume(ctx),
reset: (ctx) => limit.reset(ctx),
}))
return {
consume: async (ctx) =>
consumeMany(
ctx,
rls.map((rl) => (ctx) => rl.consume(ctx)),
),
reset: async (ctx) =>
resetMany(
ctx,
rls.map((rl) => (ctx) => rl.reset(ctx)),
),
}
}
}

View File

@ -24,6 +24,9 @@ import {
handlerSuccess,
} from './types'
export const asArray = <T>(arr: T | T[]): T[] =>
Array.isArray(arr) ? arr : [arr]
export function decodeQueryParams(
def: LexXrpcProcedure | LexXrpcQuery | LexXrpcSubscription,
params: UndecodedParams,