diff --git a/apps/api/src/auth/auth.guard.ts b/apps/api/src/auth/auth.guard.ts index c85dbddb2..1c17a984e 100644 --- a/apps/api/src/auth/auth.guard.ts +++ b/apps/api/src/auth/auth.guard.ts @@ -7,7 +7,7 @@ */ import { AuthRole } from "./auth.role.js"; -import { AuthService } from "./auth.service.js"; +import { AuthService, RateLimitException } from "./auth.service.js"; import { CanActivate, ExecutionContext, @@ -45,7 +45,18 @@ export class AuthGuard implements CanActivate { if (!apiKey) throw new UnauthorizedException(); - const keyInfo = await this.authService.limited(apiKey, weight, role); + let keyInfo; + try { + keyInfo = await this.authService.limited(apiKey, weight, role); + } catch (err) { + if (err instanceof RateLimitException) { + context.switchToHttp().getResponse().headers({ + "retry-after": err.resetTime, + "x-ratelimit-timeout": err.resetTime, + }); + } + throw err; + } context.switchToHttp().getResponse().headers({ "x-ratelimit-used": keyInfo.used, diff --git a/apps/api/src/auth/auth.service.test.ts b/apps/api/src/auth/auth.service.test.ts new file mode 100644 index 000000000..bdd000581 --- /dev/null +++ b/apps/api/src/auth/auth.service.test.ts @@ -0,0 +1,37 @@ +/** + * Copyright (c) Statsify + * + * This source code is licensed under the GNU GPL v3 license found in the + * LICENSE file in the root directory of this source tree. + * https://github.com/Statsify/statsify/blob/main/LICENSE + */ + +import { describe, expect, it } from "vitest"; +import { getRateLimitBucketStats } from "./auth.service.js"; + +describe("AuthService", () => { + it("sums only rate limit buckets within the current window", () => { + const now = 120_500; + + expect( + getRateLimitBucketStats( + { + 60: "2", + 61: "3", + 120: "5", + }, + now + ) + ).toEqual({ + recentRequests: 8, + resetTime: 500, + }); + }); + + it("returns an empty rate limit state when no buckets are in the window", () => { + expect(getRateLimitBucketStats({ 1: "2" }, 120_500)).toEqual({ + recentRequests: 0, + resetTime: 0, + }); + }); +}); diff --git a/apps/api/src/auth/auth.service.ts b/apps/api/src/auth/auth.service.ts index 9ac0be742..326f4e8cc 100644 --- a/apps/api/src/auth/auth.service.ts +++ b/apps/api/src/auth/auth.service.ts @@ -15,73 +15,121 @@ import { Logger, UnauthorizedException, } from "@nestjs/common"; + +export class RateLimitException extends HttpException { + public readonly resetTime: number; + public constructor(resetTime: number) { + super("Too Many Requests", 429); + this.resetTime = resetTime; + } +} import { InjectRedis } from "#redis"; import { Key } from "@statsify/schemas"; import { Redis } from "ioredis"; import { createHash, randomUUID } from "node:crypto"; -@Injectable() -export class AuthService { - private readonly logger = new Logger("AuthGuard"); +const RATE_LIMIT_WINDOW_SECONDS = 60; - public constructor(@InjectRedis() private readonly redis: Redis) {} +const RATE_LIMIT_SCRIPT = ` +local key = KEYS[1] +local rateLimitKey = KEYS[2] - public async limited(apiKey: string, weight: number, role: AuthRole) { - const hash = this.hash(apiKey); +local requiredRole = tonumber(ARGV[1]) +local weight = tonumber(ARGV[2]) +local now = tonumber(ARGV[3]) +local windowSeconds = tonumber(ARGV[4]) +local currentSecond = math.floor(now / 1000) +local windowStart = currentSecond - windowSeconds + 1 - const [name, ...keyInfo] = await this.redis.hmget( - `key:${hash}`, - "name", - "role", - "limit" - ); +local keyData = redis.call("HMGET", key, "name", "role", "limit") +local name = keyData[1] - if (name === null) throw new UnauthorizedException(); +if not name then + return { "unauthorized" } +end - const [apiKeyRole, apiKeyLimit] = keyInfo.map(Number); +local apiKeyRole = tonumber(keyData[2]) or 0 +local apiKeyLimit = tonumber(keyData[3]) or 0 - if (apiKeyRole < role) throw new ForbiddenException(); +if apiKeyRole < requiredRole then + return { "forbidden" } +end - const pipeline = this.redis.pipeline(); +local buckets = redis.call("HGETALL", rateLimitKey) +local weightedTotal = 0 +local oldestSecond = currentSecond - const time = Date.now(); - const expirey = 60_000; +for i = 1, #buckets, 2 do + local bucketSecond = tonumber(buckets[i]) + local bucketWeight = tonumber(buckets[i + 1]) - const key = `ratelimit:${hash}`; + if bucketSecond < windowStart then + redis.call("HDEL", rateLimitKey, buckets[i]) + else + weightedTotal = weightedTotal + bucketWeight - pipeline.zremrangebyscore(key, 0, time - expirey); - pipeline.zadd(key, time, `${randomUUID()}:${weight}`); - pipeline.zrange(key, 0, -1, "WITHSCORES"); - pipeline.hincrby(`key:${hash}`, "requests", weight); - pipeline.expire(key, expirey / 1000); + if bucketSecond < oldestSecond then + oldestSecond = bucketSecond + end + end +end - const pipelineResult = await pipeline.exec(); +local projectedTotal = weightedTotal + weight +local resetTime = ((oldestSecond + windowSeconds) * 1000) - now - if (!pipelineResult) throw new InternalServerErrorException(); +if projectedTotal > apiKeyLimit then + return { "ok", name, apiKeyLimit, projectedTotal, resetTime } +end + +redis.call("HINCRBY", rateLimitKey, currentSecond, weight) +redis.call("HINCRBY", key, "requests", weight) +redis.call("EXPIRE", rateLimitKey, windowSeconds) + +return { "ok", name, apiKeyLimit, projectedTotal, resetTime } +`; + +type RateLimitResult = + ["unauthorized"] | + ["forbidden"] | + ["ok", string, number, number, number]; + +@Injectable() +export class AuthService { + private readonly logger = new Logger("AuthGuard"); - const requests = pipelineResult[2]; + public constructor(@InjectRedis() private readonly redis: Redis) {} - if (requests[0]) throw new InternalServerErrorException(); + public async limited(apiKey: string, weight: number, role: AuthRole) { + const hash = this.hash(apiKey); + const result = await this.redis.eval( + RATE_LIMIT_SCRIPT, + 2, + `key:${hash}`, + `ratelimit:v2:${hash}`, + role, + weight, + Date.now(), + RATE_LIMIT_WINDOW_SECONDS + ) as RateLimitResult; - const weightedTotal = (requests[1] as string[]) - .filter((_, i) => i % 2 === 0) - .reduce((acc, key) => acc + Number(key.split(":")[1]), 0); + const [status, name, apiKeyLimit, weightedTotal, resetTime] = result; - const resetTime = 60_000 - (time - (requests[1] as [Error | null, number])[1]); + if (status === "unauthorized") throw new UnauthorizedException(); + if (status === "forbidden") throw new ForbiddenException(); - if (weightedTotal > apiKeyLimit) { + if (Number(weightedTotal) > Number(apiKeyLimit)) { this.logger.warn( `${name} has exceeded their request limit of ${apiKeyLimit} and has requested ${weightedTotal} times` ); - throw new HttpException("Too Many Requests", 429); + throw new RateLimitException(Number(resetTime)); } return { canActivate: true, - used: weightedTotal, - limit: apiKeyLimit, - resetTime, + used: Number(weightedTotal), + limit: Number(apiKeyLimit), + resetTime: Number(resetTime), }; } @@ -106,11 +154,11 @@ export class AuthService { public async getKey(apiKey: string): Promise { const hash = this.hash(apiKey); - const key = `ratelimit:${hash}`; + const key = `ratelimit:v2:${hash}`; const pipeline = this.redis.pipeline(); pipeline.hmget(`key:${hash}`, "name", "requests", "limit"); - pipeline.zrange(key, 0, -1, "WITHSCORES"); + pipeline.hgetall(key); const pipelineResult = await pipeline.exec(); @@ -119,15 +167,11 @@ export class AuthService { const [keydata, requests] = pipelineResult; const [name, lifetimeRequests, limit] = keydata[1] as [string, number, number]; - - const recentRequests = (requests[1] as string[]) - .filter((_, i) => i % 2 === 0) - .reduce((acc, key) => acc + Number(key.split(":")[1]), 0); - - const time = Date.now(); - - const resetTime = - 60_000 - (time - Number((requests[1] as [Error | null, number])[1])); + const requestBuckets = requests[1] as Record; + const { recentRequests, resetTime } = getRateLimitBucketStats( + requestBuckets, + Date.now() + ); return { name, @@ -142,3 +186,30 @@ export class AuthService { return createHash("sha256").update(apiKey).digest("hex"); } } + +export function getRateLimitBucketStats(buckets: Record, now: number) { + const currentSecond = Math.floor(now / 1000); + const windowStart = currentSecond - RATE_LIMIT_WINDOW_SECONDS + 1; + + let oldestSecond = currentSecond; + let recentRequests = 0; + + for (const [bucket, weight] of Object.entries(buckets)) { + const bucketSecond = Number(bucket); + + if (bucketSecond < windowStart) continue; + + recentRequests += Number(weight); + + if (bucketSecond < oldestSecond) { + oldestSecond = bucketSecond; + } + } + + return { + recentRequests, + resetTime: recentRequests > 0 ? + (oldestSecond + RATE_LIMIT_WINDOW_SECONDS) * 1000 - now : + 0, + }; +} diff --git a/apps/api/tsconfig.json b/apps/api/tsconfig.json index f1e6b476d..06810fb21 100644 --- a/apps/api/tsconfig.json +++ b/apps/api/tsconfig.json @@ -2,6 +2,7 @@ "extends": "../../tsconfig.base.json", "include": [ "src", - "eslint.config.js" + "eslint.config.js", + "vitest.config.ts" ] -} \ No newline at end of file +} diff --git a/apps/api/vitest.config.ts b/apps/api/vitest.config.ts new file mode 100644 index 000000000..2f7603e59 --- /dev/null +++ b/apps/api/vitest.config.ts @@ -0,0 +1,11 @@ +/** + * Copyright (c) Statsify + * + * This source code is licensed under the GNU GPL v3 license found in the + * LICENSE file in the root directory of this source tree. + * https://github.com/Statsify/statsify/blob/main/LICENSE + */ + +import { config } from "../../vitest.shared.js"; + +export default await config();