Skip to content

Commit

Permalink
feat: add rate limiting (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
CodyTseng authored Aug 12, 2023
1 parent fd661b3 commit 790fd6a
Show file tree
Hide file tree
Showing 14 changed files with 147 additions and 4 deletions.
43 changes: 41 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"@nestjs/event-emitter": "^2.0.1",
"@nestjs/platform-express": "^10.1.3",
"@nestjs/platform-ws": "^10.1.3",
"@nestjs/throttler": "^4.2.1",
"@nestjs/typeorm": "^10.0.0",
"@nestjs/websockets": "^10.1.3",
"@noble/curves": "^1.1.0",
Expand Down Expand Up @@ -96,4 +97,4 @@
"testEnvironment": "node",
"maxWorkers": 1
}
}
}
6 changes: 6 additions & 0 deletions src/app.module.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Module } from '@nestjs/common';
import { ConfigModule, ConfigService } from '@nestjs/config';
import { EventEmitterModule } from '@nestjs/event-emitter';
import { ThrottlerModule } from '@nestjs/throttler';
import { TypeOrmModule } from '@nestjs/typeorm';
import { LoggerModule, PinoLogger } from 'nestjs-pino';
import { loggerModuleFactory } from './common/utils/logger-module-factory';
Expand Down Expand Up @@ -38,6 +39,11 @@ import { NostrModule } from './nostr/nostr.module';
},
inject: [ConfigService, PinoLogger],
}),
ThrottlerModule.forRootAsync({
useFactory: (configService: ConfigService<Config, true>) =>
configService.get('throttler', { infer: true }),
inject: [ConfigService],
}),
NostrModule,
],
})
Expand Down
1 change: 1 addition & 0 deletions src/common/exceptions/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export * from './client.exception';
export * from './restricted.exception';
export * from './validation.exception';
export * from './throttler.exception';
7 changes: 7 additions & 0 deletions src/common/exceptions/throttler.exception.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import { ClientException } from './client.exception';

export class ThrottlerException extends ClientException {
constructor() {
super('rate-limited: slow down there chief');
}
}
1 change: 1 addition & 0 deletions src/common/guards/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export * from './ws-throttler.guard';
39 changes: 39 additions & 0 deletions src/common/guards/ws-throttler.guard.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { createMock } from '@golevelup/ts-jest';
import { ExecutionContext } from '@nestjs/common';
import { Reflector } from '@nestjs/core';
import { ThrottlerStorageService } from '@nestjs/throttler';
import { WsThrottlerGuard } from './ws-throttler.guard';

describe('WsThrottlerGuard', () => {
const context = createMock<ExecutionContext>({
getClass: jest.fn().mockReturnValue({ name: 'Test' }),
getHandler: jest.fn().mockReturnValue({ name: 'test' }),
switchToWs: jest.fn().mockReturnValue({
getClient: jest.fn().mockReturnValue({ id: 'test' }),
}),
});

let storageService: ThrottlerStorageService;

beforeEach(() => {
storageService = new ThrottlerStorageService();
});

afterEach(() => {
storageService.onApplicationShutdown();
});

it('should be fine', async () => {
const guard = new WsThrottlerGuard(
{ limit: 2, ttl: 2 },
storageService,
new Reflector(),
);

await expect(guard.canActivate(context)).resolves.toBe(true);
await expect(guard.canActivate(context)).resolves.toBe(true);
await expect(guard.canActivate(context)).rejects.toThrowError(
'rate-limited: slow down there chief',
);
});
});
23 changes: 23 additions & 0 deletions src/common/guards/ws-throttler.guard.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import { ExecutionContext, Injectable } from '@nestjs/common';
import { ThrottlerGuard } from '@nestjs/throttler';
import { WebSocket } from 'ws';
import { ThrottlerException } from '../exceptions';

@Injectable()
export class WsThrottlerGuard extends ThrottlerGuard {
protected async handleRequest(
context: ExecutionContext,
limit: number,
ttl: number,
): Promise<boolean> {
const client = context.switchToWs().getClient<WebSocket>();
const key = this.generateKey(context, client.id);
const { totalHits } = await this.storageService.increment(key, ttl);

if (totalHits > limit) {
throw new ThrottlerException();
}

return true;
}
}
4 changes: 4 additions & 0 deletions src/config/config.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ describe('config', () => {
LOG_LEVEL: 'info',
EVENT_CREATED_AT_UPPER_LIMIT: '60',
EVENT_ID_MIN_LEADING_ZERO_BITS: '16',
THROTTLER_LIMIT: '100',
THROTTLER_TTL: '1',
}),
).toEqual({
DOMAIN: 'localhost',
Expand All @@ -26,6 +28,8 @@ describe('config', () => {
LOG_LEVEL: 'info',
EVENT_CREATED_AT_UPPER_LIMIT: 60,
EVENT_ID_MIN_LEADING_ZERO_BITS: 16,
THROTTLER_LIMIT: 100,
THROTTLER_TTL: 1,
});
});
});
2 changes: 2 additions & 0 deletions src/config/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { limitConfig } from './limit.config';
import { loggerConfig } from './logger.config';
import { meiliSearchConfig } from './meili-search';
import { relayInfoDocConfig } from './relay-info-doc.config';
import { throttlerConfig } from './throttler.config';

export function config() {
const env = validateEnvironment(process.env);
Expand All @@ -15,6 +16,7 @@ export function config() {
limit: limitConfig(env),
relayInfoDoc: relayInfoDocConfig(env),
logger: loggerConfig(env),
throttler: throttlerConfig(env),
};
}
export type Config = ReturnType<typeof config>;
9 changes: 9 additions & 0 deletions src/config/environment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ export const EnvironmentSchema = z.object({
.string()
.transform((minLeadingZeroBits) => parseInt(minLeadingZeroBits))
.optional(),

THROTTLER_LIMIT: z
.string()
.transform((limit) => parseInt(limit))
.optional(),
THROTTLER_TTL: z
.string()
.transform((ttl) => parseInt(ttl))
.optional(),
});
export type Environment = z.infer<typeof EnvironmentSchema>;

Expand Down
1 change: 1 addition & 0 deletions src/config/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ export * from './environment';
export * from './limit.config';
export * from './meili-search';
export * from './relay-info-doc.config';
export * from './throttler.config';
8 changes: 8 additions & 0 deletions src/config/throttler.config.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import { Environment } from './environment';

export function throttlerConfig(env: Environment) {
return {
limit: env.THROTTLER_LIMIT ?? 100,
ttl: env.THROTTLER_TTL ?? 1,
};
}
4 changes: 3 additions & 1 deletion src/nostr/nostr.gateway.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { UseFilters } from '@nestjs/common';
import { UseFilters, UseGuards } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import {
ConnectedSocket,
Expand All @@ -15,6 +15,7 @@ import { concatWith, filter, from, map, of } from 'rxjs';
import { WebSocket, WebSocketServer } from 'ws';
import { RestrictedException } from '../common/exceptions';
import { WsExceptionFilter } from '../common/filters';
import { WsThrottlerGuard } from '../common/guards';
import { ZodValidationPipe } from '../common/pipes';
import { Config, LimitConfig } from '../config';
import { MessageType } from './constants';
Expand All @@ -40,6 +41,7 @@ import {

@WebSocketGateway()
@UseFilters(WsExceptionFilter)
@UseGuards(WsThrottlerGuard)
export class NostrGateway
implements OnGatewayInit, OnGatewayConnection, OnGatewayDisconnect
{
Expand Down

0 comments on commit 790fd6a

Please sign in to comment.