Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions apps/api/src/auth/auth.guard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
*/

import { AuthRole } from "./auth.role.js";
import { AuthService } from "./auth.service.js";
import { AuthService, RateLimitException } from "./auth.service.js";
import {
CanActivate,
ExecutionContext,
Expand Down Expand Up @@ -45,7 +45,18 @@ export class AuthGuard implements CanActivate {

if (!apiKey) throw new UnauthorizedException();

const keyInfo = await this.authService.limited(apiKey, weight, role);
let keyInfo;
try {
keyInfo = await this.authService.limited(apiKey, weight, role);
} catch (err) {
if (err instanceof RateLimitException) {
context.switchToHttp().getResponse<FastifyReply>().headers({
"retry-after": err.resetTime,
"x-ratelimit-timeout": err.resetTime,
});
}
throw err;
}

context.switchToHttp().getResponse<FastifyReply>().headers({
"x-ratelimit-used": keyInfo.used,
Expand Down
37 changes: 37 additions & 0 deletions apps/api/src/auth/auth.service.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/**
* Copyright (c) Statsify
*
* This source code is licensed under the GNU GPL v3 license found in the
* LICENSE file in the root directory of this source tree.
* https://github.com/Statsify/statsify/blob/main/LICENSE
*/

import { describe, expect, it } from "vitest";
import { getRateLimitBucketStats } from "./auth.service.js";

describe("AuthService", () => {
it("sums only rate limit buckets within the current window", () => {
const now = 120_500;

expect(
getRateLimitBucketStats(
{
60: "2",
61: "3",
120: "5",
},
now
)
).toEqual({
recentRequests: 8,
resetTime: 500,
});
});

it("returns an empty rate limit state when no buckets are in the window", () => {
expect(getRateLimitBucketStats({ 1: "2" }, 120_500)).toEqual({
recentRequests: 0,
resetTime: 0,
});
});
});
167 changes: 119 additions & 48 deletions apps/api/src/auth/auth.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,73 +15,121 @@
Logger,
UnauthorizedException,
} from "@nestjs/common";

export class RateLimitException extends HttpException {
public readonly resetTime: number;
public constructor(resetTime: number) {
super("Too Many Requests", 429);
this.resetTime = resetTime;
}
}
import { InjectRedis } from "#redis";
import { Key } from "@statsify/schemas";
import { Redis } from "ioredis";
import { createHash, randomUUID } from "node:crypto";

@Injectable()
export class AuthService {
private readonly logger = new Logger("AuthGuard");
const RATE_LIMIT_WINDOW_SECONDS = 60;

public constructor(@InjectRedis() private readonly redis: Redis) {}
const RATE_LIMIT_SCRIPT = `
local key = KEYS[1]
local rateLimitKey = KEYS[2]

public async limited(apiKey: string, weight: number, role: AuthRole) {
const hash = this.hash(apiKey);
local requiredRole = tonumber(ARGV[1])
local weight = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local windowSeconds = tonumber(ARGV[4])
local currentSecond = math.floor(now / 1000)
local windowStart = currentSecond - windowSeconds + 1

const [name, ...keyInfo] = await this.redis.hmget(
`key:${hash}`,
"name",
"role",
"limit"
);
local keyData = redis.call("HMGET", key, "name", "role", "limit")
local name = keyData[1]

if (name === null) throw new UnauthorizedException();
if not name then
return { "unauthorized" }
end

const [apiKeyRole, apiKeyLimit] = keyInfo.map(Number);
local apiKeyRole = tonumber(keyData[2]) or 0
local apiKeyLimit = tonumber(keyData[3]) or 0

if (apiKeyRole < role) throw new ForbiddenException();
if apiKeyRole < requiredRole then
return { "forbidden" }
end

const pipeline = this.redis.pipeline();
local buckets = redis.call("HGETALL", rateLimitKey)
local weightedTotal = 0
local oldestSecond = currentSecond

const time = Date.now();
const expirey = 60_000;
for i = 1, #buckets, 2 do
local bucketSecond = tonumber(buckets[i])
local bucketWeight = tonumber(buckets[i + 1])

const key = `ratelimit:${hash}`;
if bucketSecond < windowStart then
redis.call("HDEL", rateLimitKey, buckets[i])
else
weightedTotal = weightedTotal + bucketWeight

pipeline.zremrangebyscore(key, 0, time - expirey);
pipeline.zadd(key, time, `${randomUUID()}:${weight}`);
pipeline.zrange(key, 0, -1, "WITHSCORES");
pipeline.hincrby(`key:${hash}`, "requests", weight);
pipeline.expire(key, expirey / 1000);
if bucketSecond < oldestSecond then
oldestSecond = bucketSecond
end
end
end

const pipelineResult = await pipeline.exec();
local projectedTotal = weightedTotal + weight
local resetTime = ((oldestSecond + windowSeconds) * 1000) - now

if (!pipelineResult) throw new InternalServerErrorException();
if projectedTotal > apiKeyLimit then
return { "ok", name, apiKeyLimit, projectedTotal, resetTime }
end

redis.call("HINCRBY", rateLimitKey, currentSecond, weight)
redis.call("HINCRBY", key, "requests", weight)
redis.call("EXPIRE", rateLimitKey, windowSeconds)

return { "ok", name, apiKeyLimit, projectedTotal, resetTime }
`;

type RateLimitResult =
["unauthorized"] |
["forbidden"] |
["ok", string, number, number, number];

@Injectable()
export class AuthService {
private readonly logger = new Logger("AuthGuard");

const requests = pipelineResult[2];
public constructor(@InjectRedis() private readonly redis: Redis) {}

if (requests[0]) throw new InternalServerErrorException();
public async limited(apiKey: string, weight: number, role: AuthRole) {
const hash = this.hash(apiKey);
const result = await this.redis.eval(
RATE_LIMIT_SCRIPT,
2,
`key:${hash}`,
`ratelimit:v2:${hash}`,
role,
weight,
Date.now(),
RATE_LIMIT_WINDOW_SECONDS
) as RateLimitResult;

const weightedTotal = (requests[1] as string[])
.filter((_, i) => i % 2 === 0)
.reduce((acc, key) => acc + Number(key.split(":")[1]), 0);
const [status, name, apiKeyLimit, weightedTotal, resetTime] = result;

const resetTime = 60_000 - (time - (requests[1] as [Error | null, number])[1]);
if (status === "unauthorized") throw new UnauthorizedException();
if (status === "forbidden") throw new ForbiddenException();

if (weightedTotal > apiKeyLimit) {
if (Number(weightedTotal) > Number(apiKeyLimit)) {
this.logger.warn(
`${name} has exceeded their request limit of ${apiKeyLimit} and has requested ${weightedTotal} times`
);

throw new HttpException("Too Many Requests", 429);
throw new RateLimitException(Number(resetTime));
}

return {
canActivate: true,
used: weightedTotal,
limit: apiKeyLimit,
resetTime,
used: Number(weightedTotal),
limit: Number(apiKeyLimit),
resetTime: Number(resetTime),
};
}

Expand All @@ -106,11 +154,11 @@

public async getKey(apiKey: string): Promise<Key> {
const hash = this.hash(apiKey);
const key = `ratelimit:${hash}`;
const key = `ratelimit:v2:${hash}`;

const pipeline = this.redis.pipeline();
pipeline.hmget(`key:${hash}`, "name", "requests", "limit");
pipeline.zrange(key, 0, -1, "WITHSCORES");
pipeline.hgetall(key);

const pipelineResult = await pipeline.exec();

Expand All @@ -119,15 +167,11 @@
const [keydata, requests] = pipelineResult;

const [name, lifetimeRequests, limit] = keydata[1] as [string, number, number];

const recentRequests = (requests[1] as string[])
.filter((_, i) => i % 2 === 0)
.reduce((acc, key) => acc + Number(key.split(":")[1]), 0);

const time = Date.now();

const resetTime =
60_000 - (time - Number((requests[1] as [Error | null, number])[1]));
const requestBuckets = requests[1] as Record<string, string>;
const { recentRequests, resetTime } = getRateLimitBucketStats(
requestBuckets,
Date.now()
);

return {
name,
Expand All @@ -139,6 +183,33 @@
}

private hash(apiKey: string): string {
return createHash("sha256").update(apiKey).digest("hex");

Check failure

Code scanning / CodeQL

Use of password hash with insufficient computational effort High

Password from
an access to apiKey
is hashed insecurely.
Password from
an access to apiKey
is hashed insecurely.
}
}

export function getRateLimitBucketStats(buckets: Record<string, string>, now: number) {
const currentSecond = Math.floor(now / 1000);
const windowStart = currentSecond - RATE_LIMIT_WINDOW_SECONDS + 1;

let oldestSecond = currentSecond;
let recentRequests = 0;

for (const [bucket, weight] of Object.entries(buckets)) {
const bucketSecond = Number(bucket);

if (bucketSecond < windowStart) continue;

recentRequests += Number(weight);

if (bucketSecond < oldestSecond) {
oldestSecond = bucketSecond;
}
}

return {
recentRequests,
resetTime: recentRequests > 0 ?
(oldestSecond + RATE_LIMIT_WINDOW_SECONDS) * 1000 - now :
0,
};
}
5 changes: 3 additions & 2 deletions apps/api/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"extends": "../../tsconfig.base.json",
"include": [
"src",
"eslint.config.js"
"eslint.config.js",
"vitest.config.ts"
]
}
}
11 changes: 11 additions & 0 deletions apps/api/vitest.config.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
/**
* Copyright (c) Statsify
*
* This source code is licensed under the GNU GPL v3 license found in the
* LICENSE file in the root directory of this source tree.
* https://github.com/Statsify/statsify/blob/main/LICENSE
*/

import { config } from "../../vitest.shared.js";

export default await config();
Loading