From 34f2214a7e98959ac59f979f3c8c47d70ccccbc6 Mon Sep 17 00:00:00 2001 From: mzivkovicdev Date: Fri, 3 Apr 2026 19:36:22 +0200 Subject: [PATCH] add rate limiting configuration --- docs/examples/crud-spec-full.yaml | 19 +- docs/schema/crud-spec.schema.json | 92 +++++ .../constants/GeneratorConstants.java | 1 + .../generators/RateLimitingGenerator.java | 162 ++++++++ .../generators/SpringCrudGenerator.java | 2 + .../models/CrudConfiguration.java | 291 ++++++++++++- .../ratelimiting/rate-limiter-service.ftl | 120 ++++++ .../ratelimiting/rate-limiting-filter.ftl | 104 +++++ .../redis-rate-limiter-configuration.ftl | 34 ++ .../generators/RateLimitingGeneratorTest.java | 383 ++++++++++++++++++ 10 files changed, 1203 insertions(+), 5 deletions(-) create mode 100644 spring-crud-generator-core/src/main/java/dev/markozivkovic/springcrudgenerator/generators/RateLimitingGenerator.java create mode 100644 spring-crud-generator-core/src/main/resources/templates/ratelimiting/rate-limiter-service.ftl create mode 100644 spring-crud-generator-core/src/main/resources/templates/ratelimiting/rate-limiting-filter.ftl create mode 100644 spring-crud-generator-core/src/main/resources/templates/ratelimiting/redis-rate-limiter-configuration.ftl create mode 100644 spring-crud-generator-core/src/test/java/dev/markozivkovic/springcrudgenerator/generators/RateLimitingGeneratorTest.java diff --git a/docs/examples/crud-spec-full.yaml b/docs/examples/crud-spec-full.yaml index c4052aff..d7f93eeb 100644 --- a/docs/examples/crud-spec-full.yaml +++ b/docs/examples/crud-spec-full.yaml @@ -14,11 +14,28 @@ configuration: image: postgres port: 5432 tag: latest - cache: + cache: enabled: true type: REDIS expiration: 5 # maxSize: 10000 use for CAFFEINE + rateLimiting: + enabled: true + type: IN_MEMORY # IN_MEMORY | REDIS + keyStrategy: IP # IP | API_KEY | HEADER | AUTHENTICATED_USER + # keyHeader: X-Client-Id # only used when keyStrategy: HEADER + global: + capacity: 100 + refillTokens: 100 + refillDuration: 60 # seconds (100 req/min) + # overdraft: # optional burst control + # capacity: 20 + # refillTokens: 20 + # refillDuration: 10 + response: + statusCode: 429 + includeHeaders: true + message: "Rate limit exceeded. Please try again later." openApi: apiSpec: true generateResources: true diff --git a/docs/schema/crud-spec.schema.json b/docs/schema/crud-spec.schema.json index a94fd87e..caae76c3 100644 --- a/docs/schema/crud-spec.schema.json +++ b/docs/schema/crud-spec.schema.json @@ -202,6 +202,9 @@ }, "additionalProperties": { "$ref": "#/$defs/additionalProperties" + }, + "rateLimiting": { + "$ref": "#/$defs/rateLimiting" } } }, @@ -302,6 +305,95 @@ } } }, + "rateLimiting": { + "type": "object", + "additionalProperties": false, + "description": "Rate limiting configuration using Bucket4j token bucket algorithm.", + "properties": { + "enabled": { + "type": "boolean" + }, + "type": { + "type": "string", + "description": "Storage backend for rate limit buckets. Value matching is case-insensitive in the generator.", + "enum": [ + "IN_MEMORY", + "REDIS", + "in_memory", + "redis" + ] + }, + "keyStrategy": { + "type": "string", + "description": "How clients are identified for rate limiting. Value matching is case-insensitive in the generator.", + "enum": [ + "IP", + "API_KEY", + "HEADER", + "AUTHENTICATED_USER", + "ip", + "api_key", + "header", + "authenticated_user" + ] + }, + "keyHeader": { + "type": "string", + "description": "Header name used when keyStrategy is HEADER." + }, + "global": { + "$ref": "#/$defs/rateLimitDefinition" + }, + "response": { + "$ref": "#/$defs/rateLimitResponseConfig" + } + } + }, + "rateLimitDefinition": { + "type": "object", + "additionalProperties": false, + "properties": { + "capacity": { + "type": "integer", + "minimum": 1, + "description": "Maximum number of tokens in the bucket." + }, + "refillTokens": { + "type": "integer", + "minimum": 1, + "description": "Number of tokens added per refill period." + }, + "refillDuration": { + "type": "integer", + "minimum": 1, + "description": "Refill period in seconds." + }, + "overdraft": { + "$ref": "#/$defs/rateLimitDefinition", + "description": "Optional burst control bandwidth (secondary token bucket)." + } + } + }, + "rateLimitResponseConfig": { + "type": "object", + "additionalProperties": false, + "properties": { + "statusCode": { + "type": "integer", + "minimum": 100, + "maximum": 599, + "description": "HTTP status code returned when rate limit is exceeded. Defaults to 429." + }, + "includeHeaders": { + "type": "boolean", + "description": "When true, includes X-Rate-Limit-Remaining and X-Rate-Limit-Retry-After-Seconds headers." + }, + "message": { + "type": "string", + "description": "Error message returned in the response body when rate limit is exceeded." + } + } + }, "graphql": { "type": "object", "additionalProperties": false, diff --git a/spring-crud-generator-core/src/main/java/dev/markozivkovic/springcrudgenerator/constants/GeneratorConstants.java b/spring-crud-generator-core/src/main/java/dev/markozivkovic/springcrudgenerator/constants/GeneratorConstants.java index 787c27a9..dfcfe845 100644 --- a/spring-crud-generator-core/src/main/java/dev/markozivkovic/springcrudgenerator/constants/GeneratorConstants.java +++ b/spring-crud-generator-core/src/main/java/dev/markozivkovic/springcrudgenerator/constants/GeneratorConstants.java @@ -83,6 +83,7 @@ private GeneratorContextKeys() {} public static final String EXCLUSION_NULL_CONFIG = "exclusion-null-config"; public static final String GITHUB_ACTIONS_WORKFLOW = "github-actions-workflow"; public static final String MONGOCK_MIGRATION_SCRIPT = "mongock-migration-script"; + public static final String RATE_LIMITING_CONFIGURATION = "rate-limiting-configuration"; } } diff --git a/spring-crud-generator-core/src/main/java/dev/markozivkovic/springcrudgenerator/generators/RateLimitingGenerator.java b/spring-crud-generator-core/src/main/java/dev/markozivkovic/springcrudgenerator/generators/RateLimitingGenerator.java new file mode 100644 index 00000000..de67ab0e --- /dev/null +++ b/spring-crud-generator-core/src/main/java/dev/markozivkovic/springcrudgenerator/generators/RateLimitingGenerator.java @@ -0,0 +1,162 @@ +/* + * Copyright 2025-present Marko Zivkovic + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dev.markozivkovic.springcrudgenerator.generators; + +import static dev.markozivkovic.springcrudgenerator.constants.ImportConstants.PACKAGE; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import dev.markozivkovic.springcrudgenerator.constants.GeneratorConstants; +import dev.markozivkovic.springcrudgenerator.constants.TemplateContextConstants; +import dev.markozivkovic.springcrudgenerator.context.GeneratorContext; +import dev.markozivkovic.springcrudgenerator.models.CrudConfiguration; +import dev.markozivkovic.springcrudgenerator.models.CrudConfiguration.RateLimitDefinition; +import dev.markozivkovic.springcrudgenerator.models.CrudConfiguration.RateLimitResponseConfig; +import dev.markozivkovic.springcrudgenerator.models.CrudConfiguration.RateLimitingConfiguration; +import dev.markozivkovic.springcrudgenerator.models.CrudConfiguration.RateLimitingConfiguration.KeyStrategyEnum; +import dev.markozivkovic.springcrudgenerator.models.CrudConfiguration.RateLimitingConfiguration.RateLimitTypeEnum; +import dev.markozivkovic.springcrudgenerator.models.PackageConfiguration; +import dev.markozivkovic.springcrudgenerator.utils.FileWriterUtils; +import dev.markozivkovic.springcrudgenerator.utils.FreeMarkerTemplateProcessorUtils; +import dev.markozivkovic.springcrudgenerator.utils.PackageUtils; +import dev.markozivkovic.springcrudgenerator.utils.SpringBootVersionUtils; + +public class RateLimitingGenerator implements ProjectArtifactGenerator { + + private static final Logger LOGGER = LoggerFactory.getLogger(RateLimitingGenerator.class); + + private static final long DEFAULT_CAPACITY = 100L; + private static final long DEFAULT_REFILL_TOKENS = 100L; + private static final long DEFAULT_REFILL_DURATION = 60L; + private static final int DEFAULT_STATUS_CODE = 429; + private static final String DEFAULT_MESSAGE = "Rate limit exceeded. Please try again later."; + + private final CrudConfiguration crudConfiguration; + private final PackageConfiguration packageConfiguration; + + public RateLimitingGenerator(final CrudConfiguration crudConfiguration, + final PackageConfiguration packageConfiguration) { + this.crudConfiguration = crudConfiguration; + this.packageConfiguration = packageConfiguration; + } + + @Override + public void generate(final String outputDir) { + + if (Objects.isNull(crudConfiguration.getRateLimiting()) + || !Boolean.TRUE.equals(crudConfiguration.getRateLimiting().getEnabled())) { + LOGGER.info("Skipping RateLimitingGenerator, as rate limiting is not enabled."); + return; + } + + if (GeneratorContext.isGenerated(GeneratorConstants.GeneratorContextKeys.RATE_LIMITING_CONFIGURATION)) { + return; + } + + final String packagePath = PackageUtils.getPackagePathFromOutputDir(outputDir); + final String configSubPackage = PackageUtils.computeConfigurationSubPackage(packageConfiguration); + final String configPackage = PackageUtils.computeConfigurationPackage(packagePath, packageConfiguration); + final boolean isSpringBoot3 = SpringBootVersionUtils.isSpringBoot3(crudConfiguration.getSpringBootVersion()); + + final RateLimitingConfiguration rl = crudConfiguration.getRateLimiting(); + final RateLimitTypeEnum type = Objects.nonNull(rl.getType()) ? rl.getType() : RateLimitTypeEnum.IN_MEMORY; + final KeyStrategyEnum keyStrategy = Objects.nonNull(rl.getKeyStrategy()) ? rl.getKeyStrategy() : KeyStrategyEnum.IP; + final String keyHeader = Objects.nonNull(rl.getKeyHeader()) ? rl.getKeyHeader() : "X-Client-Id"; + + final RateLimitDefinition global = Objects.nonNull(rl.getGlobal()) ? rl.getGlobal() : new RateLimitDefinition(); + final long capacity = Objects.nonNull(global.getCapacity()) ? global.getCapacity() : DEFAULT_CAPACITY; + final long refillTokens = Objects.nonNull(global.getRefillTokens()) ? global.getRefillTokens() : DEFAULT_REFILL_TOKENS; + final long refillDuration = Objects.nonNull(global.getRefillDuration()) ? global.getRefillDuration() : DEFAULT_REFILL_DURATION; + final boolean hasOverdraft = Objects.nonNull(global.getOverdraft()); + + final RateLimitResponseConfig responseConfig = rl.getResponse(); + final int statusCode = Objects.nonNull(responseConfig) && Objects.nonNull(responseConfig.getStatusCode()) + ? responseConfig.getStatusCode() : DEFAULT_STATUS_CODE; + final boolean includeHeaders = Objects.isNull(responseConfig) || Objects.isNull(responseConfig.getIncludeHeaders()) + || Boolean.TRUE.equals(responseConfig.getIncludeHeaders()); + final String message = Objects.nonNull(responseConfig) && Objects.nonNull(responseConfig.getMessage()) + ? responseConfig.getMessage() : DEFAULT_MESSAGE; + + final Map context = new HashMap<>(); + context.put("type", type); + context.put("keyStrategy", keyStrategy); + context.put("keyHeader", keyHeader); + context.put("capacity", capacity); + context.put("refillTokens", refillTokens); + context.put("refillDuration", refillDuration); + context.put("hasOverdraft", hasOverdraft); + context.put("statusCode", statusCode); + context.put("includeHeaders", includeHeaders); + context.put("message", message); + context.put(TemplateContextConstants.IS_SPRING_BOOT_3, isSpringBoot3); + + if (hasOverdraft) { + final RateLimitDefinition overdraft = global.getOverdraft(); + context.put("overdraftCapacity", Objects.nonNull(overdraft.getCapacity()) ? overdraft.getCapacity() : 20L); + context.put("overdraftRefillTokens", Objects.nonNull(overdraft.getRefillTokens()) ? overdraft.getRefillTokens() : 20L); + context.put("overdraftRefillDuration", Objects.nonNull(overdraft.getRefillDuration()) ? overdraft.getRefillDuration() : 10L); + } + + this.generateRateLimiterService(outputDir, configPackage, configSubPackage, context); + this.generateRateLimitingFilter(outputDir, configPackage, configSubPackage, context); + + if (RateLimitTypeEnum.REDIS.equals(type)) { + this.generateRedisRateLimiterConfiguration(outputDir, configPackage, configSubPackage, context); + } + + GeneratorContext.markGenerated(GeneratorConstants.GeneratorContextKeys.RATE_LIMITING_CONFIGURATION); + } + + private void generateRateLimiterService(final String outputDir, final String configPackage, + final String configSubPackage, final Map context) { + + final StringBuilder sb = new StringBuilder(); + sb.append(String.format(PACKAGE, configPackage)) + .append(FreeMarkerTemplateProcessorUtils.processTemplate( + "ratelimiting/rate-limiter-service.ftl", context)); + + FileWriterUtils.writeToFile(outputDir, configSubPackage, "RateLimiterService.java", sb.toString()); + } + + private void generateRateLimitingFilter(final String outputDir, final String configPackage, + final String configSubPackage, final Map context) { + + final StringBuilder sb = new StringBuilder(); + sb.append(String.format(PACKAGE, configPackage)) + .append(FreeMarkerTemplateProcessorUtils.processTemplate( + "ratelimiting/rate-limiting-filter.ftl", context)); + + FileWriterUtils.writeToFile(outputDir, configSubPackage, "RateLimitingFilter.java", sb.toString()); + } + + private void generateRedisRateLimiterConfiguration(final String outputDir, final String configPackage, + final String configSubPackage, final Map context) { + + final StringBuilder sb = new StringBuilder(); + sb.append(String.format(PACKAGE, configPackage)) + .append(FreeMarkerTemplateProcessorUtils.processTemplate( + "ratelimiting/redis-rate-limiter-configuration.ftl", context)); + + FileWriterUtils.writeToFile(outputDir, configSubPackage, "RedisRateLimiterConfiguration.java", sb.toString()); + } + +} diff --git a/spring-crud-generator-core/src/main/java/dev/markozivkovic/springcrudgenerator/generators/SpringCrudGenerator.java b/spring-crud-generator-core/src/main/java/dev/markozivkovic/springcrudgenerator/generators/SpringCrudGenerator.java index abffde97..c6c50071 100644 --- a/spring-crud-generator-core/src/main/java/dev/markozivkovic/springcrudgenerator/generators/SpringCrudGenerator.java +++ b/spring-crud-generator-core/src/main/java/dev/markozivkovic/springcrudgenerator/generators/SpringCrudGenerator.java @@ -49,6 +49,7 @@ public class SpringCrudGenerator implements CodeGenerator, ProjectArtifactGenera private static final String SWAGGER = "swagger"; private static final String OPENAPI_CODEGEN = "openapi-codegen"; private static final String GRAPHQL = "graphql"; + private static final String RATE_LIMITING = "rate-limiting"; private final Map ARTIFACT_GENERATORS; private final Map GENERATORS; @@ -64,6 +65,7 @@ public SpringCrudGenerator(final CrudConfiguration crudConfiguration, final List this.ARTIFACT_GENERATORS.put(EXCEPTION_HANDLER, new GlobalExceptionHandlerGenerator(crudConfiguration, entities, packageConfiguration)); this.ARTIFACT_GENERATORS.put(SWAGGER, new SwaggerDocumentationGenerator(crudConfiguration, projectMetadata, entities)); this.ARTIFACT_GENERATORS.put(OPENAPI_CODEGEN, new OpenApiCodeGenerator(crudConfiguration, projectMetadata, entities, packageConfiguration)); + this.ARTIFACT_GENERATORS.put(RATE_LIMITING, new RateLimitingGenerator(crudConfiguration, packageConfiguration)); this.GENERATORS = new LinkedHashMap<>(); this.GENERATORS.put(ENUM, new EnumGenerator(packageConfiguration)); diff --git a/spring-crud-generator-core/src/main/java/dev/markozivkovic/springcrudgenerator/models/CrudConfiguration.java b/spring-crud-generator-core/src/main/java/dev/markozivkovic/springcrudgenerator/models/CrudConfiguration.java index 023b9dad..387ea117 100644 --- a/spring-crud-generator-core/src/main/java/dev/markozivkovic/springcrudgenerator/models/CrudConfiguration.java +++ b/spring-crud-generator-core/src/main/java/dev/markozivkovic/springcrudgenerator/models/CrudConfiguration.java @@ -35,7 +35,8 @@ public class CrudConfiguration { private Boolean dependencyCheck = false; private TestConfiguration tests; private Map additionalProperties = new HashMap<>(); - + private RateLimitingConfiguration rateLimiting; + public CrudConfiguration() { } @@ -44,7 +45,7 @@ public CrudConfiguration(final DatabaseType database, final Integer javaVersion, final Boolean optimisticLocking, final DockerConfiguration docker, final CacheConfiguration cache, final OpenApiDefinition openApi, final GraphQLDefinition graphql, final ErrorResponse errorResponse, final Boolean migrationScripts, final Boolean dependencyCheck, final TestConfiguration tests, - final Map additionalProperties) { + final Map additionalProperties, final RateLimitingConfiguration rateLimiting) { this.database = database; this.javaVersion = javaVersion; this.springBootVersion = springBootVersion; @@ -58,6 +59,7 @@ public CrudConfiguration(final DatabaseType database, final Integer javaVersion, this.dependencyCheck = dependencyCheck; this.tests = tests; this.additionalProperties = additionalProperties; + this.rateLimiting = rateLimiting; } public DatabaseType getDatabase() { @@ -181,6 +183,15 @@ public CrudConfiguration setAdditionalProperties(final Map addit return this; } + public RateLimitingConfiguration getRateLimiting() { + return this.rateLimiting; + } + + public CrudConfiguration setRateLimiting(final RateLimitingConfiguration rateLimiting) { + this.rateLimiting = rateLimiting; + return this; + } + @Override public boolean equals(final Object o) { if (o == this) @@ -201,14 +212,15 @@ public boolean equals(final Object o) { Objects.equals(migrationScripts, crudConfiguration.migrationScripts) && Objects.equals(dependencyCheck, crudConfiguration.dependencyCheck) && Objects.equals(tests, crudConfiguration.tests) && - Objects.equals(additionalProperties, crudConfiguration.additionalProperties); + Objects.equals(additionalProperties, crudConfiguration.additionalProperties) && + Objects.equals(rateLimiting, crudConfiguration.rateLimiting); } @Override public int hashCode() { return Objects.hash( database, javaVersion, springBootVersion, optimisticLocking, docker, cache, openApi, - graphql, errorResponse, migrationScripts, dependencyCheck, tests, additionalProperties + graphql, errorResponse, migrationScripts, dependencyCheck, tests, additionalProperties, rateLimiting ); } @@ -228,6 +240,7 @@ public String toString() { ", dependencyCheck='" + getDependencyCheck() + "'" + ", tests='" + getTests() + "'" + ", additionalProperties='" + getAdditionalProperties() + "'" + + ", rateLimiting='" + getRateLimiting() + "'" + "}"; } @@ -751,4 +764,274 @@ public String toString() { } } + public static class RateLimitingConfiguration { + + private Boolean enabled; + private RateLimitTypeEnum type; + private KeyStrategyEnum keyStrategy; + private String keyHeader; + private RateLimitDefinition global; + private RateLimitResponseConfig response; + + public RateLimitingConfiguration() {} + + public RateLimitingConfiguration(final Boolean enabled, final RateLimitTypeEnum type, + final KeyStrategyEnum keyStrategy, final String keyHeader, + final RateLimitDefinition global, final RateLimitResponseConfig response) { + this.enabled = enabled; + this.type = type; + this.keyStrategy = keyStrategy; + this.keyHeader = keyHeader; + this.global = global; + this.response = response; + } + + public Boolean getEnabled() { + return this.enabled; + } + + public RateLimitingConfiguration setEnabled(final Boolean enabled) { + this.enabled = enabled; + return this; + } + + public RateLimitTypeEnum getType() { + return this.type; + } + + public RateLimitingConfiguration setType(final RateLimitTypeEnum type) { + this.type = type; + return this; + } + + public KeyStrategyEnum getKeyStrategy() { + return this.keyStrategy; + } + + public RateLimitingConfiguration setKeyStrategy(final KeyStrategyEnum keyStrategy) { + this.keyStrategy = keyStrategy; + return this; + } + + public String getKeyHeader() { + return this.keyHeader; + } + + public RateLimitingConfiguration setKeyHeader(final String keyHeader) { + this.keyHeader = keyHeader; + return this; + } + + public RateLimitDefinition getGlobal() { + return this.global; + } + + public RateLimitingConfiguration setGlobal(final RateLimitDefinition global) { + this.global = global; + return this; + } + + public RateLimitResponseConfig getResponse() { + return this.response; + } + + public RateLimitingConfiguration setResponse(final RateLimitResponseConfig response) { + this.response = response; + return this; + } + + @Override + public boolean equals(final Object o) { + if (o == this) + return true; + if (!(o instanceof RateLimitingConfiguration)) { + return false; + } + final RateLimitingConfiguration other = (RateLimitingConfiguration) o; + return Objects.equals(enabled, other.enabled) && + Objects.equals(type, other.type) && + Objects.equals(keyStrategy, other.keyStrategy) && + Objects.equals(keyHeader, other.keyHeader) && + Objects.equals(global, other.global) && + Objects.equals(response, other.response); + } + + @Override + public int hashCode() { + return Objects.hash(enabled, type, keyStrategy, keyHeader, global, response); + } + + @Override + public String toString() { + return "{" + + " enabled='" + getEnabled() + "'" + + ", type='" + getType() + "'" + + ", keyStrategy='" + getKeyStrategy() + "'" + + ", keyHeader='" + getKeyHeader() + "'" + + ", global='" + getGlobal() + "'" + + ", response='" + getResponse() + "'" + + "}"; + } + + public enum RateLimitTypeEnum { + IN_MEMORY, REDIS + } + + public enum KeyStrategyEnum { + IP, API_KEY, HEADER, AUTHENTICATED_USER + } + } + + public static class RateLimitDefinition { + + private Long capacity; + private Long refillTokens; + private Long refillDuration; + private RateLimitDefinition overdraft; + + public RateLimitDefinition() {} + + public RateLimitDefinition(final Long capacity, final Long refillTokens, final Long refillDuration, + final RateLimitDefinition overdraft) { + this.capacity = capacity; + this.refillTokens = refillTokens; + this.refillDuration = refillDuration; + this.overdraft = overdraft; + } + + public Long getCapacity() { + return this.capacity; + } + + public RateLimitDefinition setCapacity(final Long capacity) { + this.capacity = capacity; + return this; + } + + public Long getRefillTokens() { + return this.refillTokens; + } + + public RateLimitDefinition setRefillTokens(final Long refillTokens) { + this.refillTokens = refillTokens; + return this; + } + + public Long getRefillDuration() { + return this.refillDuration; + } + + public RateLimitDefinition setRefillDuration(final Long refillDuration) { + this.refillDuration = refillDuration; + return this; + } + + public RateLimitDefinition getOverdraft() { + return this.overdraft; + } + + public RateLimitDefinition setOverdraft(final RateLimitDefinition overdraft) { + this.overdraft = overdraft; + return this; + } + + @Override + public boolean equals(final Object o) { + if (o == this) + return true; + if (!(o instanceof RateLimitDefinition)) { + return false; + } + final RateLimitDefinition other = (RateLimitDefinition) o; + return Objects.equals(capacity, other.capacity) && + Objects.equals(refillTokens, other.refillTokens) && + Objects.equals(refillDuration, other.refillDuration) && + Objects.equals(overdraft, other.overdraft); + } + + @Override + public int hashCode() { + return Objects.hash(capacity, refillTokens, refillDuration, overdraft); + } + + @Override + public String toString() { + return "{" + + " capacity='" + getCapacity() + "'" + + ", refillTokens='" + getRefillTokens() + "'" + + ", refillDuration='" + getRefillDuration() + "'" + + ", overdraft='" + getOverdraft() + "'" + + "}"; + } + } + + public static class RateLimitResponseConfig { + + private Integer statusCode; + private Boolean includeHeaders; + private String message; + + public RateLimitResponseConfig() {} + + public RateLimitResponseConfig(final Integer statusCode, final Boolean includeHeaders, final String message) { + this.statusCode = statusCode; + this.includeHeaders = includeHeaders; + this.message = message; + } + + public Integer getStatusCode() { + return this.statusCode; + } + + public RateLimitResponseConfig setStatusCode(final Integer statusCode) { + this.statusCode = statusCode; + return this; + } + + public Boolean getIncludeHeaders() { + return this.includeHeaders; + } + + public RateLimitResponseConfig setIncludeHeaders(final Boolean includeHeaders) { + this.includeHeaders = includeHeaders; + return this; + } + + public String getMessage() { + return this.message; + } + + public RateLimitResponseConfig setMessage(final String message) { + this.message = message; + return this; + } + + @Override + public boolean equals(final Object o) { + if (o == this) + return true; + if (!(o instanceof RateLimitResponseConfig)) { + return false; + } + final RateLimitResponseConfig other = (RateLimitResponseConfig) o; + return Objects.equals(statusCode, other.statusCode) && + Objects.equals(includeHeaders, other.includeHeaders) && + Objects.equals(message, other.message); + } + + @Override + public int hashCode() { + return Objects.hash(statusCode, includeHeaders, message); + } + + @Override + public String toString() { + return "{" + + " statusCode='" + getStatusCode() + "'" + + ", includeHeaders='" + getIncludeHeaders() + "'" + + ", message='" + getMessage() + "'" + + "}"; + } + } + } diff --git a/spring-crud-generator-core/src/main/resources/templates/ratelimiting/rate-limiter-service.ftl b/spring-crud-generator-core/src/main/resources/templates/ratelimiting/rate-limiter-service.ftl new file mode 100644 index 00000000..5ffe19e5 --- /dev/null +++ b/spring-crud-generator-core/src/main/resources/templates/ratelimiting/rate-limiter-service.ftl @@ -0,0 +1,120 @@ +<#setting number_format="computer"> +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Service; + +import io.github.bucket4j.Bandwidth; +import io.github.bucket4j.Bucket; +import io.github.bucket4j.BucketConfiguration; +import io.github.bucket4j.ConsumptionProbe; +<#if type == "REDIS"> +import io.github.bucket4j.distributed.proxy.ProxyManager; + + +@Service +public class RateLimiterService { + +<#if type == "IN_MEMORY"> + private final Map buckets = new ConcurrentHashMap<>(); +<#else> + private final ProxyManager proxyManager; + + + private final long capacity; + private final long refillTokens; + private final Duration refillDuration; +<#if hasOverdraft> + private final long overdraftCapacity; + private final long overdraftRefillTokens; + private final Duration overdraftRefillDuration; + + +<#if type == "IN_MEMORY"> + public RateLimiterService( + @Value("\${rate.limiting.capacity:${capacity}}") final long capacity, + @Value("\${rate.limiting.refill-tokens:${refillTokens}}") final long refillTokens, + @Value("\${rate.limiting.refill-duration-seconds:${refillDuration}}") final long refillDurationSeconds<#if hasOverdraft>, + @Value("\${rate.limiting.overdraft.capacity:${overdraftCapacity}}") final long overdraftCapacity, + @Value("\${rate.limiting.overdraft.refill-tokens:${overdraftRefillTokens}}") final long overdraftRefillTokens, + @Value("\${rate.limiting.overdraft.refill-duration-seconds:${overdraftRefillDuration}}") final long overdraftRefillDurationSeconds) { + this.capacity = capacity; + this.refillTokens = refillTokens; + this.refillDuration = Duration.ofSeconds(refillDurationSeconds); +<#if hasOverdraft> + this.overdraftCapacity = overdraftCapacity; + this.overdraftRefillTokens = overdraftRefillTokens; + this.overdraftRefillDuration = Duration.ofSeconds(overdraftRefillDurationSeconds); + + } +<#else> + public RateLimiterService( + final ProxyManager proxyManager, + @Value("\${rate.limiting.capacity:${capacity}}") final long capacity, + @Value("\${rate.limiting.refill-tokens:${refillTokens}}") final long refillTokens, + @Value("\${rate.limiting.refill-duration-seconds:${refillDuration}}") final long refillDurationSeconds<#if hasOverdraft>, + @Value("\${rate.limiting.overdraft.capacity:${overdraftCapacity}}") final long overdraftCapacity, + @Value("\${rate.limiting.overdraft.refill-tokens:${overdraftRefillTokens}}") final long overdraftRefillTokens, + @Value("\${rate.limiting.overdraft.refill-duration-seconds:${overdraftRefillDuration}}") final long overdraftRefillDurationSeconds) { + this.proxyManager = proxyManager; + this.capacity = capacity; + this.refillTokens = refillTokens; + this.refillDuration = Duration.ofSeconds(refillDurationSeconds); +<#if hasOverdraft> + this.overdraftCapacity = overdraftCapacity; + this.overdraftRefillTokens = overdraftRefillTokens; + this.overdraftRefillDuration = Duration.ofSeconds(overdraftRefillDurationSeconds); + + } + + + public ConsumptionProbe tryConsume(final String key) { +<#if type == "IN_MEMORY"> + final Bucket bucket = buckets.computeIfAbsent(key, k -> createBucket()); + return bucket.tryConsumeAndReturnRemaining(1); +<#else> + final BucketConfiguration configuration = createBucketConfiguration(); + final Bucket bucket = proxyManager.getProxy(key, () -> configuration); + return bucket.tryConsumeAndReturnRemaining(1); + + } + + public long getCapacity() { + return this.capacity; + } + +<#if type == "IN_MEMORY"> + private Bucket createBucket() { + return Bucket.builder() + .addLimit(buildMainBandwidth())<#if hasOverdraft> + .addLimit(buildOverdraftBandwidth()) + .build(); + } +<#else> + private BucketConfiguration createBucketConfiguration() { + return BucketConfiguration.builder() + .addLimit(buildMainBandwidth())<#if hasOverdraft> + .addLimit(buildOverdraftBandwidth()) + .build(); + } + + + private Bandwidth buildMainBandwidth() { + return Bandwidth.builder() + .capacity(this.capacity) + .refillGreedy(this.refillTokens, this.refillDuration) + .build(); + } +<#if hasOverdraft> + + private Bandwidth buildOverdraftBandwidth() { + return Bandwidth.builder() + .capacity(this.overdraftCapacity) + .refillGreedy(this.overdraftRefillTokens, this.overdraftRefillDuration) + .build(); + } + + +} diff --git a/spring-crud-generator-core/src/main/resources/templates/ratelimiting/rate-limiting-filter.ftl b/spring-crud-generator-core/src/main/resources/templates/ratelimiting/rate-limiting-filter.ftl new file mode 100644 index 00000000..88a8fece --- /dev/null +++ b/spring-crud-generator-core/src/main/resources/templates/ratelimiting/rate-limiting-filter.ftl @@ -0,0 +1,104 @@ +<#setting number_format="computer"> +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; +import org.springframework.stereotype.Component; +import org.springframework.web.filter.OncePerRequestFilter; +<#if keyStrategy == "AUTHENTICATED_USER"> +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; + + +import io.github.bucket4j.ConsumptionProbe; + +@Component +@Order(Ordered.HIGHEST_PRECEDENCE) +public class RateLimitingFilter extends OncePerRequestFilter { + + private static final String HEADER_REMAINING = "X-Rate-Limit-Remaining"; + private static final String HEADER_LIMIT = "X-Rate-Limit-Limit"; + private static final String HEADER_RETRY_AFTER = "X-Rate-Limit-Retry-After-Seconds"; + + private final RateLimiterService rateLimiterService; + private final int statusCode; + private final boolean includeHeaders; + private final String message; +<#if keyStrategy == "HEADER"> + private final String keyHeader; + + + public RateLimitingFilter( + final RateLimiterService rateLimiterService, + @Value("\${rate.limiting.response.status-code:${statusCode}}") final int statusCode, + @Value("\${rate.limiting.response.include-headers:${includeHeaders?c}}") final boolean includeHeaders, + @Value("\${rate.limiting.response.message:Rate limit exceeded. Please try again later.}") final String message<#if keyStrategy == "HEADER">, + @Value("\${rate.limiting.key-header:${keyHeader}}") final String keyHeader) { + this.rateLimiterService = rateLimiterService; + this.statusCode = statusCode; + this.includeHeaders = includeHeaders; + this.message = message; +<#if keyStrategy == "HEADER"> + this.keyHeader = keyHeader; + + } + + @Override + protected void doFilterInternal(final HttpServletRequest request, final HttpServletResponse response, + final FilterChain filterChain) throws ServletException, IOException { + + final String key = resolveKey(request); + final ConsumptionProbe probe = rateLimiterService.tryConsume(key); + + if (probe.isConsumed()) { + if (includeHeaders) { + response.setHeader(HEADER_REMAINING, String.valueOf(probe.getRemainingTokens())); + response.setHeader(HEADER_LIMIT, String.valueOf(rateLimiterService.getCapacity())); + } + filterChain.doFilter(request, response); + } else { + if (includeHeaders) { + response.setHeader(HEADER_RETRY_AFTER, + String.valueOf(TimeUnit.NANOSECONDS.toSeconds(probe.getNanosToWaitForRefill()))); + } + response.setStatus(statusCode); + response.setContentType("application/json;charset=UTF-8"); + response.getWriter().write("{\"message\": \"" + message + "\"}"); + } + } + + private String resolveKey(final HttpServletRequest request) { +<#if keyStrategy == "IP"> + final String forwarded = request.getHeader("X-Forwarded-For"); + if (forwarded != null && !forwarded.isBlank()) { + return forwarded.split(",")[0].trim(); + } + return request.getRemoteAddr(); +<#elseif keyStrategy == "API_KEY"> + final String apiKey = request.getHeader("X-API-Key"); + return apiKey != null ? apiKey : request.getRemoteAddr(); +<#elseif keyStrategy == "HEADER"> + final String headerValue = request.getHeader(keyHeader); + return headerValue != null ? headerValue : request.getRemoteAddr(); +<#elseif keyStrategy == "AUTHENTICATED_USER"> + final Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + if (authentication != null && authentication.isAuthenticated() + && !"anonymousUser".equals(authentication.getPrincipal())) { + return "user:" + authentication.getName(); + } + final String forwarded = request.getHeader("X-Forwarded-For"); + if (forwarded != null && !forwarded.isBlank()) { + return forwarded.split(",")[0].trim(); + } + return request.getRemoteAddr(); + + } + +} diff --git a/spring-crud-generator-core/src/main/resources/templates/ratelimiting/redis-rate-limiter-configuration.ftl b/spring-crud-generator-core/src/main/resources/templates/ratelimiting/redis-rate-limiter-configuration.ftl new file mode 100644 index 00000000..4c4988ed --- /dev/null +++ b/spring-crud-generator-core/src/main/resources/templates/ratelimiting/redis-rate-limiter-configuration.ftl @@ -0,0 +1,34 @@ +<#setting number_format="computer"> +import java.time.Duration; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.data.redis.connection.lettuce.LettuceConnectionFactory; + +import io.github.bucket4j.distributed.proxy.ProxyManager; +import io.github.bucket4j.redis.lettuce.cas.LettuceBasedProxyManager; +import io.lettuce.core.RedisClient; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.codec.ByteArrayCodec; +import io.lettuce.core.codec.RedisCodec; +import io.lettuce.core.codec.StringCodec; + +@Configuration +public class RedisRateLimiterConfiguration { + + @Bean + public ProxyManager rateLimitingProxyManager( + final LettuceConnectionFactory lettuceConnectionFactory, + @Value("\${rate.limiting.refill-duration-seconds:${refillDuration}}") final long refillDurationSeconds) { + + final RedisClient redisClient = (RedisClient) lettuceConnectionFactory.getNativeClient(); + final StatefulRedisConnection redisConnection = redisClient.connect( + RedisCodec.of(StringCodec.UTF8, ByteArrayCodec.INSTANCE) + ); + return LettuceBasedProxyManager.builderFor(redisConnection) + .withExpirationAfterWrite(Duration.ofSeconds(refillDurationSeconds * 2)) + .build(); + } + +} diff --git a/spring-crud-generator-core/src/test/java/dev/markozivkovic/springcrudgenerator/generators/RateLimitingGeneratorTest.java b/spring-crud-generator-core/src/test/java/dev/markozivkovic/springcrudgenerator/generators/RateLimitingGeneratorTest.java new file mode 100644 index 00000000..4102b736 --- /dev/null +++ b/spring-crud-generator-core/src/test/java/dev/markozivkovic/springcrudgenerator/generators/RateLimitingGeneratorTest.java @@ -0,0 +1,383 @@ +package dev.markozivkovic.springcrudgenerator.generators; + +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.when; + +import java.util.Map; +import java.util.Objects; + +import org.junit.jupiter.api.Test; +import org.mockito.MockedStatic; + +import dev.markozivkovic.springcrudgenerator.constants.GeneratorConstants; +import dev.markozivkovic.springcrudgenerator.constants.TemplateContextConstants; +import dev.markozivkovic.springcrudgenerator.context.GeneratorContext; +import dev.markozivkovic.springcrudgenerator.models.CrudConfiguration; +import dev.markozivkovic.springcrudgenerator.models.CrudConfiguration.RateLimitDefinition; +import dev.markozivkovic.springcrudgenerator.models.CrudConfiguration.RateLimitResponseConfig; +import dev.markozivkovic.springcrudgenerator.models.CrudConfiguration.RateLimitingConfiguration; +import dev.markozivkovic.springcrudgenerator.models.CrudConfiguration.RateLimitingConfiguration.KeyStrategyEnum; +import dev.markozivkovic.springcrudgenerator.models.CrudConfiguration.RateLimitingConfiguration.RateLimitTypeEnum; +import dev.markozivkovic.springcrudgenerator.models.PackageConfiguration; +import dev.markozivkovic.springcrudgenerator.utils.FileWriterUtils; +import dev.markozivkovic.springcrudgenerator.utils.FreeMarkerTemplateProcessorUtils; +import dev.markozivkovic.springcrudgenerator.utils.PackageUtils; +import dev.markozivkovic.springcrudgenerator.utils.SpringBootVersionUtils; + +class RateLimitingGeneratorTest { + + private static class CrudAndRateLimit { + CrudConfiguration crudConfig; + RateLimitingConfiguration rateLimitingConfig; + } + + private CrudAndRateLimit prepareCrudWithRateLimiting() { + final CrudAndRateLimit cr = new CrudAndRateLimit(); + cr.crudConfig = mock(CrudConfiguration.class); + cr.rateLimitingConfig = mock(RateLimitingConfiguration.class); + when(cr.crudConfig.getRateLimiting()).thenReturn(cr.rateLimitingConfig); + return cr; + } + + @Test + void generate_shouldSkipWhenRateLimitingIsNull() { + + final CrudConfiguration crudConfig = mock(CrudConfiguration.class); + when(crudConfig.getRateLimiting()).thenReturn(null); + + final PackageConfiguration packageConfig = mock(PackageConfiguration.class); + final RateLimitingGenerator generator = new RateLimitingGenerator(crudConfig, packageConfig); + + try (final MockedStatic genCtx = mockStatic(GeneratorContext.class); + final MockedStatic pkg = mockStatic(PackageUtils.class); + final MockedStatic tpl = mockStatic(FreeMarkerTemplateProcessorUtils.class); + final MockedStatic writer = mockStatic(FileWriterUtils.class)) { + + generator.generate("out"); + + genCtx.verifyNoInteractions(); + pkg.verifyNoInteractions(); + tpl.verifyNoInteractions(); + writer.verifyNoInteractions(); + } + } + + @Test + void generate_shouldSkipWhenRateLimitingDisabled() { + + final CrudAndRateLimit cr = prepareCrudWithRateLimiting(); + when(cr.rateLimitingConfig.getEnabled()).thenReturn(false); + + final PackageConfiguration packageConfig = mock(PackageConfiguration.class); + final RateLimitingGenerator generator = new RateLimitingGenerator(cr.crudConfig, packageConfig); + + try (final MockedStatic genCtx = mockStatic(GeneratorContext.class); + final MockedStatic pkg = mockStatic(PackageUtils.class); + final MockedStatic tpl = mockStatic(FreeMarkerTemplateProcessorUtils.class); + final MockedStatic writer = mockStatic(FileWriterUtils.class)) { + + generator.generate("out"); + + genCtx.verifyNoInteractions(); + pkg.verifyNoInteractions(); + tpl.verifyNoInteractions(); + writer.verifyNoInteractions(); + } + } + + @Test + void generate_shouldSkipWhenAlreadyGeneratedInContext() { + + final CrudAndRateLimit cr = prepareCrudWithRateLimiting(); + when(cr.rateLimitingConfig.getEnabled()).thenReturn(true); + + final PackageConfiguration packageConfig = mock(PackageConfiguration.class); + final RateLimitingGenerator generator = new RateLimitingGenerator(cr.crudConfig, packageConfig); + + try (final MockedStatic genCtx = mockStatic(GeneratorContext.class); + final MockedStatic pkg = mockStatic(PackageUtils.class); + final MockedStatic tpl = mockStatic(FreeMarkerTemplateProcessorUtils.class); + final MockedStatic writer = mockStatic(FileWriterUtils.class)) { + + genCtx.when(() -> GeneratorContext.isGenerated( + GeneratorConstants.GeneratorContextKeys.RATE_LIMITING_CONFIGURATION)).thenReturn(true); + + generator.generate("out"); + + pkg.verifyNoInteractions(); + tpl.verifyNoInteractions(); + writer.verifyNoInteractions(); + + genCtx.verify(() -> GeneratorContext.isGenerated( + GeneratorConstants.GeneratorContextKeys.RATE_LIMITING_CONFIGURATION)); + genCtx.verifyNoMoreInteractions(); + } + } + + @Test + void generate_shouldGenerateServiceAndFilter_forInMemoryType_withDefaultValues() { + + final CrudAndRateLimit cr = prepareCrudWithRateLimiting(); + when(cr.rateLimitingConfig.getEnabled()).thenReturn(true); + when(cr.rateLimitingConfig.getType()).thenReturn(RateLimitTypeEnum.IN_MEMORY); + when(cr.rateLimitingConfig.getKeyStrategy()).thenReturn(KeyStrategyEnum.IP); + when(cr.rateLimitingConfig.getKeyHeader()).thenReturn(null); + when(cr.rateLimitingConfig.getGlobal()).thenReturn(null); + when(cr.rateLimitingConfig.getResponse()).thenReturn(null); + when(cr.crudConfig.getSpringBootVersion()).thenReturn("3.3.0"); + + final PackageConfiguration packageConfig = mock(PackageConfiguration.class); + final RateLimitingGenerator generator = new RateLimitingGenerator(cr.crudConfig, packageConfig); + + try (final MockedStatic genCtx = mockStatic(GeneratorContext.class); + final MockedStatic pkg = mockStatic(PackageUtils.class); + final MockedStatic sbv = mockStatic(SpringBootVersionUtils.class); + final MockedStatic tpl = mockStatic(FreeMarkerTemplateProcessorUtils.class); + final MockedStatic writer = mockStatic(FileWriterUtils.class)) { + + genCtx.when(() -> GeneratorContext.isGenerated( + GeneratorConstants.GeneratorContextKeys.RATE_LIMITING_CONFIGURATION)).thenReturn(false); + sbv.when(() -> SpringBootVersionUtils.isSpringBoot3("3.3.0")).thenReturn(true); + pkg.when(() -> PackageUtils.getPackagePathFromOutputDir("out")).thenReturn("com.example"); + pkg.when(() -> PackageUtils.computeConfigurationPackage("com.example", packageConfig)).thenReturn("com.example.configurations"); + pkg.when(() -> PackageUtils.computeConfigurationSubPackage(packageConfig)).thenReturn("configurations"); + + tpl.when(() -> FreeMarkerTemplateProcessorUtils.processTemplate(anyString(), anyMap())).thenReturn("// GENERATED"); + + generator.generate("out"); + + tpl.verify(() -> FreeMarkerTemplateProcessorUtils.processTemplate( + eq("ratelimiting/rate-limiter-service.ftl"), + argThat(ctx -> { + final Map map = (Map) ctx; + return map.get("type") == RateLimitTypeEnum.IN_MEMORY + && map.get("keyStrategy") == KeyStrategyEnum.IP + && Objects.equals(map.get("capacity"), 100L) + && Objects.equals(map.get("refillTokens"), 100L) + && Objects.equals(map.get("refillDuration"), 60L) + && Objects.equals(map.get("hasOverdraft"), false) + && Objects.equals(map.get("statusCode"), 429) + && Objects.equals(map.get("includeHeaders"), true) + && Objects.equals(map.get(TemplateContextConstants.IS_SPRING_BOOT_3), true); + }) + )); + + tpl.verify(() -> FreeMarkerTemplateProcessorUtils.processTemplate( + eq("ratelimiting/rate-limiting-filter.ftl"), anyMap())); + + // No Redis config for IN_MEMORY + tpl.verify(() -> FreeMarkerTemplateProcessorUtils.processTemplate( + eq("ratelimiting/redis-rate-limiter-configuration.ftl"), anyMap()), never()); + + writer.verify(() -> FileWriterUtils.writeToFile( + eq("out"), eq("configurations"), eq("RateLimiterService.java"), anyString())); + writer.verify(() -> FileWriterUtils.writeToFile( + eq("out"), eq("configurations"), eq("RateLimitingFilter.java"), anyString())); + writer.verify(() -> FileWriterUtils.writeToFile( + eq("out"), eq("configurations"), eq("RedisRateLimiterConfiguration.java"), anyString()), never()); + + genCtx.verify(() -> GeneratorContext.markGenerated( + GeneratorConstants.GeneratorContextKeys.RATE_LIMITING_CONFIGURATION)); + } + } + + @Test + void generate_shouldGenerateServiceFilterAndRedisConfig_forRedisType_withExplicitValues() { + + final RateLimitDefinition globalDef = mock(RateLimitDefinition.class); + when(globalDef.getCapacity()).thenReturn(200L); + when(globalDef.getRefillTokens()).thenReturn(200L); + when(globalDef.getRefillDuration()).thenReturn(30L); + when(globalDef.getOverdraft()).thenReturn(null); + + final RateLimitResponseConfig responseConfig = mock(RateLimitResponseConfig.class); + when(responseConfig.getStatusCode()).thenReturn(429); + when(responseConfig.getIncludeHeaders()).thenReturn(true); + when(responseConfig.getMessage()).thenReturn("Too many requests."); + + final CrudAndRateLimit cr = prepareCrudWithRateLimiting(); + when(cr.rateLimitingConfig.getEnabled()).thenReturn(true); + when(cr.rateLimitingConfig.getType()).thenReturn(RateLimitTypeEnum.REDIS); + when(cr.rateLimitingConfig.getKeyStrategy()).thenReturn(KeyStrategyEnum.API_KEY); + when(cr.rateLimitingConfig.getKeyHeader()).thenReturn(null); + when(cr.rateLimitingConfig.getGlobal()).thenReturn(globalDef); + when(cr.rateLimitingConfig.getResponse()).thenReturn(responseConfig); + when(cr.crudConfig.getSpringBootVersion()).thenReturn("4.0.0"); + + final PackageConfiguration packageConfig = mock(PackageConfiguration.class); + final RateLimitingGenerator generator = new RateLimitingGenerator(cr.crudConfig, packageConfig); + + try (final MockedStatic genCtx = mockStatic(GeneratorContext.class); + final MockedStatic pkg = mockStatic(PackageUtils.class); + final MockedStatic sbv = mockStatic(SpringBootVersionUtils.class); + final MockedStatic tpl = mockStatic(FreeMarkerTemplateProcessorUtils.class); + final MockedStatic writer = mockStatic(FileWriterUtils.class)) { + + genCtx.when(() -> GeneratorContext.isGenerated( + GeneratorConstants.GeneratorContextKeys.RATE_LIMITING_CONFIGURATION)).thenReturn(false); + sbv.when(() -> SpringBootVersionUtils.isSpringBoot3("4.0.0")).thenReturn(false); + pkg.when(() -> PackageUtils.getPackagePathFromOutputDir("out")).thenReturn("com.example"); + pkg.when(() -> PackageUtils.computeConfigurationPackage("com.example", packageConfig)).thenReturn("com.example.configurations"); + pkg.when(() -> PackageUtils.computeConfigurationSubPackage(packageConfig)).thenReturn("configurations"); + + tpl.when(() -> FreeMarkerTemplateProcessorUtils.processTemplate(anyString(), anyMap())).thenReturn("// GENERATED"); + + generator.generate("out"); + + tpl.verify(() -> FreeMarkerTemplateProcessorUtils.processTemplate( + eq("ratelimiting/rate-limiter-service.ftl"), + argThat(ctx -> { + final Map map = (Map) ctx; + return map.get("type") == RateLimitTypeEnum.REDIS + && map.get("keyStrategy") == KeyStrategyEnum.API_KEY + && Objects.equals(map.get("capacity"), 200L) + && Objects.equals(map.get("refillTokens"), 200L) + && Objects.equals(map.get("refillDuration"), 30L) + && Objects.equals(map.get("hasOverdraft"), false) + && Objects.equals(map.get(TemplateContextConstants.IS_SPRING_BOOT_3), false); + }) + )); + + tpl.verify(() -> FreeMarkerTemplateProcessorUtils.processTemplate( + eq("ratelimiting/rate-limiting-filter.ftl"), anyMap())); + + tpl.verify(() -> FreeMarkerTemplateProcessorUtils.processTemplate( + eq("ratelimiting/redis-rate-limiter-configuration.ftl"), anyMap())); + + writer.verify(() -> FileWriterUtils.writeToFile( + eq("out"), eq("configurations"), eq("RateLimiterService.java"), anyString())); + writer.verify(() -> FileWriterUtils.writeToFile( + eq("out"), eq("configurations"), eq("RateLimitingFilter.java"), anyString())); + writer.verify(() -> FileWriterUtils.writeToFile( + eq("out"), eq("configurations"), eq("RedisRateLimiterConfiguration.java"), anyString())); + + genCtx.verify(() -> GeneratorContext.markGenerated( + GeneratorConstants.GeneratorContextKeys.RATE_LIMITING_CONFIGURATION)); + } + } + + @Test + void generate_shouldIncludeOverdraftValues_whenOverdraftIsConfigured() { + + final RateLimitDefinition overdraft = mock(RateLimitDefinition.class); + when(overdraft.getCapacity()).thenReturn(50L); + when(overdraft.getRefillTokens()).thenReturn(50L); + when(overdraft.getRefillDuration()).thenReturn(10L); + + final RateLimitDefinition globalDef = mock(RateLimitDefinition.class); + when(globalDef.getCapacity()).thenReturn(1000L); + when(globalDef.getRefillTokens()).thenReturn(1000L); + when(globalDef.getRefillDuration()).thenReturn(60L); + when(globalDef.getOverdraft()).thenReturn(overdraft); + + final CrudAndRateLimit cr = prepareCrudWithRateLimiting(); + when(cr.rateLimitingConfig.getEnabled()).thenReturn(true); + when(cr.rateLimitingConfig.getType()).thenReturn(RateLimitTypeEnum.IN_MEMORY); + when(cr.rateLimitingConfig.getKeyStrategy()).thenReturn(KeyStrategyEnum.IP); + when(cr.rateLimitingConfig.getKeyHeader()).thenReturn(null); + when(cr.rateLimitingConfig.getGlobal()).thenReturn(globalDef); + when(cr.rateLimitingConfig.getResponse()).thenReturn(null); + when(cr.crudConfig.getSpringBootVersion()).thenReturn("3.3.0"); + + final PackageConfiguration packageConfig = mock(PackageConfiguration.class); + final RateLimitingGenerator generator = new RateLimitingGenerator(cr.crudConfig, packageConfig); + + try (final MockedStatic genCtx = mockStatic(GeneratorContext.class); + final MockedStatic pkg = mockStatic(PackageUtils.class); + final MockedStatic sbv = mockStatic(SpringBootVersionUtils.class); + final MockedStatic tpl = mockStatic(FreeMarkerTemplateProcessorUtils.class); + final MockedStatic writer = mockStatic(FileWriterUtils.class)) { + + genCtx.when(() -> GeneratorContext.isGenerated( + GeneratorConstants.GeneratorContextKeys.RATE_LIMITING_CONFIGURATION)).thenReturn(false); + sbv.when(() -> SpringBootVersionUtils.isSpringBoot3("3.3.0")).thenReturn(true); + pkg.when(() -> PackageUtils.getPackagePathFromOutputDir("out")).thenReturn("com.example"); + pkg.when(() -> PackageUtils.computeConfigurationPackage("com.example", packageConfig)).thenReturn("com.example.configurations"); + pkg.when(() -> PackageUtils.computeConfigurationSubPackage(packageConfig)).thenReturn("configurations"); + + tpl.when(() -> FreeMarkerTemplateProcessorUtils.processTemplate(anyString(), anyMap())).thenReturn("// GENERATED"); + + generator.generate("out"); + + tpl.verify(() -> FreeMarkerTemplateProcessorUtils.processTemplate( + eq("ratelimiting/rate-limiter-service.ftl"), + argThat(ctx -> { + final Map map = (Map) ctx; + return Objects.equals(map.get("hasOverdraft"), true) + && Objects.equals(map.get("overdraftCapacity"), 50L) + && Objects.equals(map.get("overdraftRefillTokens"), 50L) + && Objects.equals(map.get("overdraftRefillDuration"), 10L); + }) + )); + + genCtx.verify(() -> GeneratorContext.markGenerated( + GeneratorConstants.GeneratorContextKeys.RATE_LIMITING_CONFIGURATION)); + } + } + + @Test + void generate_shouldUseDefaultValues_whenGlobalConfigHasNullFields() { + + final RateLimitDefinition globalDef = mock(RateLimitDefinition.class); + when(globalDef.getCapacity()).thenReturn(null); + when(globalDef.getRefillTokens()).thenReturn(null); + when(globalDef.getRefillDuration()).thenReturn(null); + when(globalDef.getOverdraft()).thenReturn(null); + + final CrudAndRateLimit cr = prepareCrudWithRateLimiting(); + when(cr.rateLimitingConfig.getEnabled()).thenReturn(true); + when(cr.rateLimitingConfig.getType()).thenReturn(null); + when(cr.rateLimitingConfig.getKeyStrategy()).thenReturn(null); + when(cr.rateLimitingConfig.getKeyHeader()).thenReturn(null); + when(cr.rateLimitingConfig.getGlobal()).thenReturn(globalDef); + when(cr.rateLimitingConfig.getResponse()).thenReturn(null); + when(cr.crudConfig.getSpringBootVersion()).thenReturn("3.3.0"); + + final PackageConfiguration packageConfig = mock(PackageConfiguration.class); + final RateLimitingGenerator generator = new RateLimitingGenerator(cr.crudConfig, packageConfig); + + try (final MockedStatic genCtx = mockStatic(GeneratorContext.class); + final MockedStatic pkg = mockStatic(PackageUtils.class); + final MockedStatic sbv = mockStatic(SpringBootVersionUtils.class); + final MockedStatic tpl = mockStatic(FreeMarkerTemplateProcessorUtils.class); + final MockedStatic writer = mockStatic(FileWriterUtils.class)) { + + genCtx.when(() -> GeneratorContext.isGenerated( + GeneratorConstants.GeneratorContextKeys.RATE_LIMITING_CONFIGURATION)).thenReturn(false); + sbv.when(() -> SpringBootVersionUtils.isSpringBoot3("3.3.0")).thenReturn(true); + pkg.when(() -> PackageUtils.getPackagePathFromOutputDir("out")).thenReturn("com.example"); + pkg.when(() -> PackageUtils.computeConfigurationPackage("com.example", packageConfig)).thenReturn("com.example.configurations"); + pkg.when(() -> PackageUtils.computeConfigurationSubPackage(packageConfig)).thenReturn("configurations"); + + tpl.when(() -> FreeMarkerTemplateProcessorUtils.processTemplate(anyString(), anyMap())).thenReturn("// GENERATED"); + + generator.generate("out"); + + // type defaults to IN_MEMORY, keyStrategy to IP, values to defaults + tpl.verify(() -> FreeMarkerTemplateProcessorUtils.processTemplate( + eq("ratelimiting/rate-limiter-service.ftl"), + argThat(ctx -> { + final Map map = (Map) ctx; + return map.get("type") == RateLimitTypeEnum.IN_MEMORY + && map.get("keyStrategy") == KeyStrategyEnum.IP + && Objects.equals(map.get("capacity"), 100L) + && Objects.equals(map.get("refillTokens"), 100L) + && Objects.equals(map.get("refillDuration"), 60L) + && Objects.equals(map.get("statusCode"), 429) + && Objects.equals(map.get("includeHeaders"), true) + && Objects.equals(map.get("keyHeader"), "X-Client-Id"); + }) + )); + + // No Redis config for IN_MEMORY (default) + tpl.verify(() -> FreeMarkerTemplateProcessorUtils.processTemplate( + eq("ratelimiting/redis-rate-limiter-configuration.ftl"), anyMap()), never()); + } + } +}