diff --git a/agent/src/main/java/dev/aikido/agent/wrappers/HttpClientSendWrapper.java b/agent/src/main/java/dev/aikido/agent/wrappers/HttpClientSendWrapper.java index 986e6341..5513e20b 100644 --- a/agent/src/main/java/dev/aikido/agent/wrappers/HttpClientSendWrapper.java +++ b/agent/src/main/java/dev/aikido/agent/wrappers/HttpClientSendWrapper.java @@ -6,7 +6,6 @@ import net.bytebuddy.matcher.ElementMatcher; import net.bytebuddy.matcher.ElementMatchers; -import java.lang.reflect.Method; import java.net.HttpURLConnection; import java.net.MalformedURLException; import java.net.URL; @@ -55,13 +54,9 @@ public static void before( // Load the class from the JAR Class clazz = classLoader.loadClass("dev.aikido.agent_api.collectors.URLCollector"); - // Run report with "argument" - for (Method method2: clazz.getMethods()) { - if(method2.getName().equals("report")) { - method2.invoke(null, httpRequest.uri().toURL()); - break; - } - } + // report(URL) is overloaded (also has a report(URL, ContextObject) variant), so it + // must be looked up by exact signature - matching by name alone could pick either. + clazz.getMethod("report", URL.class).invoke(null, httpRequest.uri().toURL()); classLoader.close(); // Close the class loader } } diff --git a/agent/src/main/java/dev/aikido/agent/wrappers/HttpURLConnectionWrapper.java b/agent/src/main/java/dev/aikido/agent/wrappers/HttpURLConnectionWrapper.java index fde7530f..0356a430 100644 --- a/agent/src/main/java/dev/aikido/agent/wrappers/HttpURLConnectionWrapper.java +++ b/agent/src/main/java/dev/aikido/agent/wrappers/HttpURLConnectionWrapper.java @@ -5,7 +5,6 @@ import net.bytebuddy.matcher.ElementMatcher; import net.bytebuddy.matcher.ElementMatchers; -import java.lang.reflect.Method; import java.net.*; import static net.bytebuddy.implementation.bytecode.assign.Assigner.Typing.DYNAMIC; @@ -50,13 +49,9 @@ public static void before( // Load the class from the JAR Class clazz = classLoader.loadClass("dev.aikido.agent_api.collectors.URLCollector"); - // Run report with "argument" - for (Method method2: clazz.getMethods()) { - if(method2.getName().equals("report")) { - method2.invoke(null, url); - break; - } - } + // report(URL) is overloaded (also has a report(URL, ContextObject) variant), so it + // must be looked up by exact signature - matching by name alone could pick either. + clazz.getMethod("report", URL.class).invoke(null, url); classLoader.close(); // Close the class loader } } diff --git a/agent/src/main/java/dev/aikido/agent/wrappers/spring/ReactorAikidoContext.java b/agent/src/main/java/dev/aikido/agent/wrappers/spring/ReactorAikidoContext.java new file mode 100644 index 00000000..d5ec36f1 --- /dev/null +++ b/agent/src/main/java/dev/aikido/agent/wrappers/spring/ReactorAikidoContext.java @@ -0,0 +1,91 @@ +package dev.aikido.agent.wrappers.spring; + +import dev.aikido.agent_api.collectors.URLCollector; +import dev.aikido.agent_api.context.ContextObject; + +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.net.URL; + +/** + * Carries the Aikido ContextObject through Reactor's own Context so it survives scheduler hops + * (e.g. .publishOn()) between the incoming WebFlux request and any WebClient calls made while + * handling it - unlike Context.get()'s ThreadLocal, which only sees the current OS thread. + * + * Everything here is Object-typed and goes through reflection, rooted at the classloader of a + * live Mono instance passed in. reactor-core is compileOnly for this module: a *separate* class + * (like this one, as opposed to an @Advice method whose parameter types ByteBuddy resolves + * specially against the woven target's own classloader) that declares Mono, Context or + * ContextView as a concrete parameter/return type throws NoClassDefFoundError at class + * verification time on the agent's own classloader, which has no visibility into the target + * application's classpath. Must be public: the woven target class (in a completely different + * package) needs to call into it directly. + */ +public final class ReactorAikidoContext { + private static final String KEY = "dev.aikido.agent.wrappers.spring.ReactorAikidoContextKey"; + + private ReactorAikidoContext() {} + + // `mono` is a Mono, returned Object is that same Mono wrapped with .contextWrite(). + public static Object write(Object mono, ContextObject context) { + try { + ClassLoader cl = mono.getClass().getClassLoader(); + Class contextClass = Class.forName("reactor.util.context.Context", false, cl); + Class contextViewClass = Class.forName("reactor.util.context.ContextView", false, cl); + Object newContext = contextClass.getMethod("of", Object.class, Object.class) + .invoke(null, KEY, context); + Method contextWrite = mono.getClass().getMethod("contextWrite", contextViewClass); + return contextWrite.invoke(mono, newContext); + } catch (Throwable t) { + return mono; + } + } + + // `original` is a Mono. Registers `url` once `original` is actually subscribed to, using + // whatever ContextObject write() captured upstream in the same reactive chain (null if + // none). Returns a Mono equivalent to `original` (or `original` itself if anything here + // fails - registration is best-effort, must never break the actual request). + public static Object deferRegisterUrl(Object original, URL url) { + try { + ClassLoader cl = original.getClass().getClassLoader(); + Class functionClass = Class.forName("java.util.function.Function", false, cl); + Method deferContextual = original.getClass().getMethod("deferContextual", functionClass); + InvocationHandler handler = new RegisterUrlHandler(original, url); + Object proxy = Proxy.newProxyInstance(cl, new Class[]{functionClass}, handler); + return deferContextual.invoke(null, proxy); + } catch (Throwable t) { + return original; + } + } + + // Not a lambda: constructed from advice code that ByteBuddy inlines into the *target* + // class's bytecode, so a lambda here would become a private synthetic method that the + // target class can't call back into (IllegalAccessError). A plain named class implementing + // InvocationHandler - whose own methods only ever see java.lang.Object - avoids that. + private static final class RegisterUrlHandler implements InvocationHandler { + private final Object original; + private final URL url; + + RegisterUrlHandler(Object original, URL url) { + this.original = original; + this.url = url; + } + + @Override + public Object invoke(Object proxy, Method method, Object[] args) { + if (!"apply".equals(method.getName()) || args == null || args.length == 0) { + return original; + } + Object ctxView = args[0]; + ContextObject context = null; + try { + Method getOrDefault = ctxView.getClass().getMethod("getOrDefault", Object.class, Object.class); + context = (ContextObject) getOrDefault.invoke(ctxView, KEY, null); + } catch (Throwable ignored) { + } + URLCollector.report(url, context); + return original; + } + } +} diff --git a/agent/src/main/java/dev/aikido/agent/wrappers/spring/SpringWebClientWrapper.java b/agent/src/main/java/dev/aikido/agent/wrappers/spring/SpringWebClientWrapper.java index 91645852..e05ceb9d 100644 --- a/agent/src/main/java/dev/aikido/agent/wrappers/spring/SpringWebClientWrapper.java +++ b/agent/src/main/java/dev/aikido/agent/wrappers/spring/SpringWebClientWrapper.java @@ -1,15 +1,17 @@ package dev.aikido.agent.wrappers.spring; import dev.aikido.agent.wrappers.Wrapper; -import dev.aikido.agent_api.collectors.URLCollector; import net.bytebuddy.asm.Advice; import net.bytebuddy.description.method.MethodDescription; import net.bytebuddy.description.type.TypeDescription; import net.bytebuddy.matcher.ElementMatcher; import net.bytebuddy.matcher.ElementMatchers; import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ClientResponse; +import reactor.core.publisher.Mono; import java.net.MalformedURLException; +import java.net.URL; public class SpringWebClientWrapper implements Wrapper { // Referenced by name (not by .class) in the matchers below: ExchangeFunction is only on @@ -33,16 +35,21 @@ public ElementMatcher getTypeMatcher() { return ElementMatchers.hasSuperType(ElementMatchers.named(EXCHANGE_FUNCTION_CLASS_NAME)); } public static class SpringWebClientAdvice { - @Advice.OnMethodEnter(suppress = Throwable.class) - public static void before( - @Advice.Argument(0) ClientRequest request + // Registration happens in onExit, wrapped around the returned Mono via + // deferContextual(), rather than eagerly in onEnter. That way it runs at subscribe + // time, reading back whatever ContextObject SpringWebfluxWrapper wrote into Reactor's + // Context (see ReactorAikidoContext) - reliable regardless of scheduler hops between + // the incoming request and this WebClient call, unlike Context.get()'s ThreadLocal. + @Advice.OnMethodExit(suppress = Throwable.class) + public static void after( + @Advice.Argument(0) ClientRequest request, + @Advice.Return(readOnly = false) Mono returnValue ) throws MalformedURLException { - if (request == null || request.url() == null) { + if (request == null || request.url() == null || returnValue == null) { return; } - // Report the URL before the request is sent, so DNSRecordCollector can match the - // DNS lookup that follows to this outgoing request. - URLCollector.report(request.url().toURL()); + URL url = request.url().toURL(); + returnValue = (Mono) ReactorAikidoContext.deferRegisterUrl(returnValue, url); } } } diff --git a/agent/src/main/java/dev/aikido/agent/wrappers/spring/SpringWebfluxWrapper.java b/agent/src/main/java/dev/aikido/agent/wrappers/spring/SpringWebfluxWrapper.java index c249136b..57fa06ef 100644 --- a/agent/src/main/java/dev/aikido/agent/wrappers/spring/SpringWebfluxWrapper.java +++ b/agent/src/main/java/dev/aikido/agent/wrappers/spring/SpringWebfluxWrapper.java @@ -55,6 +55,12 @@ public ElementMatcher getTypeMatcher() { public record SkipOnWrapper(Mono newReturnValue) { } + // Non-skip path result: carries the ContextObject alongside the response so onExit() can + // write it into Reactor's own Context (see ReactorAikidoContext), letting it survive + // scheduler hops before any WebClient call made while handling this request. + public record EnterResult(ServerHttpResponse res, ContextObject context) { + } + public static class SpringWebfluxAdvice { @Advice.OnMethodEnter(skipOn = SkipOnWrapper.class, suppress = Throwable.class) public static Object onEnter( @@ -94,7 +100,7 @@ public static Object onEnter( return new SkipOnWrapper(res.writeWith(Mono.just(dataBuffer))); } - return res; // Return to analyze status code in OnMethodExit. + return new EnterResult(res, context); // Return to analyze status code in OnMethodExit. } /** onExit() @@ -105,17 +111,20 @@ public static void onExit( @Advice.Enter Object enterResult, @Advice.Return(readOnly = false) Mono returnValue ) { - // enterResult can be two things : Either the SkipOnWrapper or the ServerHttpResponse - // ServerHttpResponse -> Extract status code. + // enterResult can be two things : Either the SkipOnWrapper or the EnterResult + // EnterResult -> Extract status code, write the context into Reactor's Context. // SkipOnWrapper -> we blocked a request (e.g. IP Blocking), and are returning the value below if (enterResult instanceof SkipOnWrapper wrapper && wrapper.newReturnValue() != null) { returnValue = wrapper.newReturnValue(); - } else if (enterResult instanceof ServerHttpResponse res) { + } else if (enterResult instanceof EnterResult er) { // Report status code of response : - Integer statusCode = res.getRawStatusCode(); + Integer statusCode = er.res() != null ? er.res().getRawStatusCode() : null; if (statusCode != null) { WebResponseCollector.report(statusCode); } + if (returnValue != null) { + returnValue = (Mono) ReactorAikidoContext.write(returnValue, er.context()); + } } } } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/collectors/DNSRecordCollector.java b/agent_api/src/main/java/dev/aikido/agent_api/collectors/DNSRecordCollector.java index dbd05c64..8301dbc5 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/collectors/DNSRecordCollector.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/collectors/DNSRecordCollector.java @@ -1,6 +1,7 @@ package dev.aikido.agent_api.collectors; import dev.aikido.agent_api.context.Context; +import dev.aikido.agent_api.context.ContextObject; import dev.aikido.agent_api.storage.HostnamesStore; import dev.aikido.agent_api.storage.PendingHostnamesStore; import dev.aikido.agent_api.storage.ServiceConfigStore; @@ -32,7 +33,8 @@ private DNSRecordCollector() {} public static void report(String hostname, InetAddress[] inetAddresses) { // InetAddress.getAllByName() resolves everything in one call, so it's safe to consume. - process(hostname, inetAddresses, PendingHostnamesStore.getAndRemove(hostname), INET_ADDRESS_OPERATION_NAME); + withCapturedContext(hostname, () -> + process(hostname, inetAddresses, PendingHostnamesStore.getAndRemove(hostname), INET_ADDRESS_OPERATION_NAME)); } // For clients that resolve their own DNS (e.g. Reactor Netty, used by Spring's WebClient) or @@ -40,7 +42,31 @@ public static void report(String hostname, InetAddress[] inetAddresses) { // the same hostname (IPv4 then IPv6), so unlike report(), this peeks the pending port instead // of consuming it - consuming on the first attempt would let a later attempt bypass SSRF. public static void reportConnect(String hostname, InetAddress resolvedAddress) { - process(hostname, new InetAddress[]{resolvedAddress}, PendingHostnamesStore.getPorts(hostname), SOCKET_CHANNEL_OPERATION_NAME); + withCapturedContext(hostname, () -> + process(hostname, new InetAddress[]{resolvedAddress}, PendingHostnamesStore.getPorts(hostname), SOCKET_CHANNEL_OPERATION_NAME)); + } + + // Restores the ContextObject captured when this hostname's pending entry was registered + // (PendingHostnamesStore is global, not thread-local) so SSRFDetector's Context.get() sees + // the request that actually triggered the outbound call, even if we're running on a + // different thread than the one that registered it. + private static void withCapturedContext(String hostname, Runnable action) { + ContextObject capturedContext = PendingHostnamesStore.getContext(hostname); + if (capturedContext == null) { + action.run(); + return; + } + ContextObject previous = Context.get(); + Context.set(capturedContext); + try { + action.run(); + } finally { + if (previous != null) { + Context.set(previous); + } else { + Context.reset(); + } + } } private static void process(String hostname, InetAddress[] inetAddresses, Set ports, String operationName) { diff --git a/agent_api/src/main/java/dev/aikido/agent_api/collectors/URLCollector.java b/agent_api/src/main/java/dev/aikido/agent_api/collectors/URLCollector.java index e1c1244b..9306a245 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/collectors/URLCollector.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/collectors/URLCollector.java @@ -1,5 +1,7 @@ package dev.aikido.agent_api.collectors; +import dev.aikido.agent_api.context.Context; +import dev.aikido.agent_api.context.ContextObject; import dev.aikido.agent_api.helpers.logging.LogManager; import dev.aikido.agent_api.helpers.logging.Logger; import dev.aikido.agent_api.storage.PendingHostnamesStore; @@ -13,6 +15,13 @@ public final class URLCollector { private URLCollector() {} public static void report(URL url) { + report(url, Context.get()); + } + + // Used where the caller already resolved the correct context itself instead of relying on + // Context.get() (e.g. Spring WebClient reading it back from Reactor's own Context, which + // survives scheduler hops that break Context.get()'s ThreadLocal). + public static void report(URL url, ContextObject context) { if (url != null) { if (!url.getProtocol().startsWith("http")) { return; // Non-HTTP(S) URL @@ -20,7 +29,7 @@ public static void report(URL url) { logger.trace("Adding a new URL to the cache: %s", url); // Store hostname+port in the pending store so DNSRecordCollector can pick it // up during the DNS lookup that follows, for SSRF detection and outbound hostnames - PendingHostnamesStore.add(url.getHost(), getPortFromURL(url)); + PendingHostnamesStore.add(url.getHost(), getPortFromURL(url), context); } } } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/collectors/WebRequestCollector.java b/agent_api/src/main/java/dev/aikido/agent_api/collectors/WebRequestCollector.java index 1d66a212..265e72a3 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/collectors/WebRequestCollector.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/collectors/WebRequestCollector.java @@ -8,7 +8,6 @@ import dev.aikido.agent_api.helpers.logging.LogManager; import dev.aikido.agent_api.helpers.logging.Logger; import dev.aikido.agent_api.storage.AttackQueue; -import dev.aikido.agent_api.storage.PendingHostnamesStore; import dev.aikido.agent_api.storage.ServiceConfigStore; import dev.aikido.agent_api.storage.ServiceConfiguration; import dev.aikido.agent_api.storage.attack_wave_detector.AttackWaveDetectorStore; @@ -41,10 +40,6 @@ public static Res report(ContextObject newContext) { // clear context Context.reset(); - // Flush pending hostnames on every context change to prevent the store from - // growing unboundedly when a thread is reused across multiple requests. - PendingHostnamesStore.clear(); - if (config.isIpBypassed(newContext.getRemoteAddress())) { return null; // do not set context when the IP address is bypassed (zen = off) } diff --git a/agent_api/src/main/java/dev/aikido/agent_api/storage/PendingHostnamesStore.java b/agent_api/src/main/java/dev/aikido/agent_api/storage/PendingHostnamesStore.java index 2644f503..9d9da2fd 100644 --- a/agent_api/src/main/java/dev/aikido/agent_api/storage/PendingHostnamesStore.java +++ b/agent_api/src/main/java/dev/aikido/agent_api/storage/PendingHostnamesStore.java @@ -1,9 +1,12 @@ package dev.aikido.agent_api.storage; +import dev.aikido.agent_api.context.Context; +import dev.aikido.agent_api.context.ContextObject; + import java.util.*; /** - * Thread-local bridge between URLCollector and DNSRecordCollector. + * Bridge between URLCollector and DNSRecordCollector. * URLCollector records hostname+port here; DNSRecordCollector.report() (fed by * InetAddress.getAllByName(), which resolves everything in one call) reads and removes the * entry so each (hostname, port) pair is processed exactly once per DNS lookup. @@ -11,49 +14,77 @@ * connect attempt) instead peeks the entry, since a single outbound request can trigger * multiple connect attempts to the same hostname (e.g. IPv4 then IPv6 for a dual-stack host). * - * Entries are normally cleared per incoming request by WebRequestCollector, but a peeked - * entry added outside any incoming-request context (e.g. a WebClient call from a @Scheduled - * task) would never be cleared that way. Capped at MAX_ENTRIES per thread, evicting the least - * recently used entry once exceeded, same bounded-LRU pattern as Hostnames. + * Global rather than thread-local: for async clients (e.g. Spring's WebClient/Reactor Netty), + * the intent registration and the actual connect can run on different OS threads - Reactor + * Netty's own event-loop dispatch, or an app's explicit .publishOn() - so thread-local storage + * silently loses the entry. Trade-off: two concurrent requests to the *same* hostname can share + * an entry (and its captured context) in a narrow race window. This doesn't open an SSRF bypass + * (SSRFDetector/StoredSSRFDetector still run unconditionally either way) - worst case is a wrong + * source attribution for that one request. + * + * Capped at MAX_ENTRIES, evicting the least recently used entry once exceeded, so it can't grow + * unboundedly under load or from entries that are never consumed (e.g. a WebClient call from a + * @Scheduled task, outside any incoming-request context). */ public final class PendingHostnamesStore { private PendingHostnamesStore() {} private static final int MAX_ENTRIES = 1000; - private static final ThreadLocal>> store = - ThreadLocal.withInitial(() -> new LinkedHashMap<>(16, 0.75f, true) { + private record Entry(Set ports, ContextObject context) {} + + private static final Map store = + Collections.synchronizedMap(new LinkedHashMap<>(16, 0.75f, true) { @Override - protected boolean removeEldestEntry(Map.Entry> eldest) { + protected boolean removeEldestEntry(Map.Entry eldest) { return size() > MAX_ENTRIES; } }); public static void add(String hostname, int port) { - Map> map = store.get(); - if (!map.containsKey(hostname)) { - map.put(hostname, new LinkedHashSet<>()); + add(hostname, port, Context.get()); + } + + // Used where the caller already resolved the correct context itself (e.g. via Reactor's own + // Context, which - unlike Context.get()'s ThreadLocal - survives scheduler hops). + public static void add(String hostname, int port, ContextObject context) { + synchronized (store) { + Entry existing = store.get(hostname); + if (existing == null) { + Set ports = new LinkedHashSet<>(); + ports.add(port); + store.put(hostname, new Entry(ports, context)); + } else { + existing.ports().add(port); + } } - map.get(hostname).add(port); } public static Set getAndRemove(String hostname) { - Set ports = store.get().remove(hostname); - if (ports == null) { - return Collections.emptySet(); + synchronized (store) { + Entry entry = store.remove(hostname); + return entry == null ? Collections.emptySet() : entry.ports(); } - return ports; } public static Set getPorts(String hostname) { - Set ports = store.get().get(hostname); - if (ports == null) { - return Collections.emptySet(); + synchronized (store) { + Entry entry = store.get(hostname); + return entry == null ? Collections.emptySet() : Set.copyOf(entry.ports()); + } + } + + // The ContextObject captured when this hostname's pending entry was registered, so SSRF + // taint-checking can use the request that actually triggered the outbound call even if it + // runs on a different thread than the one processing the connect. Null if none was captured. + public static ContextObject getContext(String hostname) { + synchronized (store) { + Entry entry = store.get(hostname); + return entry == null ? null : entry.context(); } - return Collections.unmodifiableSet(ports); } public static void clear() { - store.get().clear(); + store.clear(); } } diff --git a/agent_api/src/test/java/storage/PendingHostnamesStoreTest.java b/agent_api/src/test/java/storage/PendingHostnamesStoreTest.java index b7f88a87..6bd11b8f 100644 --- a/agent_api/src/test/java/storage/PendingHostnamesStoreTest.java +++ b/agent_api/src/test/java/storage/PendingHostnamesStoreTest.java @@ -1,8 +1,10 @@ package storage; +import dev.aikido.agent_api.context.ContextObject; import dev.aikido.agent_api.storage.PendingHostnamesStore; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; +import utils.EmptySampleContextObject; import java.util.Set; @@ -32,7 +34,7 @@ public void testGetAndRemoveConsumesEntry() { } @Test - public void testUnboundedHostnamesDoNotGrowThreadLocalMapForever() { + public void testUnboundedHostnamesDoNotGrowMapForever() { // Regression test: entries added outside any incoming-request context (e.g. a // WebClient call from a @Scheduled task) never get cleared by WebRequestCollector's // per-request clear(). Adding well over the internal cap of distinct hostnames must @@ -68,6 +70,34 @@ public void testReadingAnEntryProtectsItFromEvictionWhileStillInUse() { assertEquals(Set.of(443), PendingHostnamesStore.getPorts("dual-stack.example.com")); } + @Test + public void testEntryIsVisibleFromADifferentThread() throws InterruptedException { + // The whole point of this store being global instead of thread-local: WebClient's + // "register intent" and "actual connect" steps can run on different OS threads + // (Reactor Netty's own event-loop dispatch, or an app's own .publishOn()). + Thread writer = new Thread(() -> PendingHostnamesStore.add("cross-thread.example.com", 443)); + writer.start(); + writer.join(); + + Set portsSeenFromThisThread = PendingHostnamesStore.getPorts("cross-thread.example.com"); + assertEquals(Set.of(443), portsSeenFromThisThread); + } + + @Test + public void testGetContextReturnsWhatWasCapturedAtAddTime() { + ContextObject context = new EmptySampleContextObject(); + PendingHostnamesStore.add("dev.aikido", 443, context); + + assertSame(context, PendingHostnamesStore.getContext("dev.aikido")); + } + + @Test + public void testGetContextIsNullWhenNoneWasCaptured() { + PendingHostnamesStore.add("dev.aikido", 443, null); + + assertNull(PendingHostnamesStore.getContext("dev.aikido")); + } + @Test public void testClearRemovesEverything() { PendingHostnamesStore.add("dev.aikido", 443); diff --git a/agent_api/src/test/java/wrappers/OkHttpTest.java b/agent_api/src/test/java/wrappers/OkHttpTest.java index 6d54e8bd..8700d946 100644 --- a/agent_api/src/test/java/wrappers/OkHttpTest.java +++ b/agent_api/src/test/java/wrappers/OkHttpTest.java @@ -4,6 +4,8 @@ import dev.aikido.agent_api.storage.Hostnames; import dev.aikido.agent_api.storage.HostnamesStore; import dev.aikido.agent_api.storage.ServiceConfigStore; +import okhttp3.Call; +import okhttp3.Callback; import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.Response; @@ -15,9 +17,14 @@ import java.io.IOException; import java.net.ConnectException; import java.util.Objects; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; public class OkHttpTest { private OkHttpClient client; @@ -97,6 +104,48 @@ public void testSSRFWithoutPortAndWithoutContext() throws Exception { assertEquals(1, getHits("localhost", 80)); } + @Test + public void testSSRFAsyncEnqueueOnDifferentThread() throws Exception { + // newCall() runs on this test thread (intent registered here), but enqueue() executes + // the actual call - including the DNS lookup / connect - on OkHttp's own Dispatcher + // thread pool, a different thread. Investigating whether PendingHostnamesStore (and + // captured Context) survive that hop the same way they need to for WebClient. + setContextAndLifecycle("http://localhost:5000"); + assertEquals(0, getHits("localhost", 5000)); + + Request request = new Request.Builder().url("http://localhost:5000").build(); + CountDownLatch latch = new CountDownLatch(1); + AtomicReference failure = new AtomicReference<>(); + AtomicReference responseCode = new AtomicReference<>(); + + client.newCall(request).enqueue(new Callback() { + @Override + public void onFailure(Call call, IOException e) { + failure.set(e); + latch.countDown(); + } + + @Override + public void onResponse(Call call, Response response) { + responseCode.set(response.code()); + response.close(); + latch.countDown(); + } + }); + + boolean completed = latch.await(5, TimeUnit.SECONDS); + assertTrue(completed, "enqueue() callback never fired within 5s"); + + if (failure.get() == null) { + fail("expected the SSRF block to surface as a failure, got response code: " + responseCode.get()); + } + // OkHttp wraps the cancellation as a new IOException embedding the original exception's + // toString() as text, not as a proper getCause() chain, so check the message contains it. + String message = failure.get().getMessage(); + assertTrue(message != null && message.contains("Aikido Zen has blocked a server-side request forgery"), + "expected an SSRF block in the failure message, got: " + failure.get()); + } + private void fetchResponse(String urlString) throws IOException { Request request = new Request.Builder() .url(urlString) diff --git a/agent_api/src/test/java/wrappers/WebClientTest.java b/agent_api/src/test/java/wrappers/WebClientTest.java index 4a4e21d3..0ca0e552 100644 --- a/agent_api/src/test/java/wrappers/WebClientTest.java +++ b/agent_api/src/test/java/wrappers/WebClientTest.java @@ -24,11 +24,16 @@ * SpringWebClientWrapper (URLCollector.report on ExchangeFunction.exchange) and * SocketChannelWrapper (DNSRecordCollector.reportConnect on SocketChannel.connect) run on * different threads for a real WebClient call: the former on the subscribing thread, the - * latter on Reactor Netty's own event-loop thread. PendingHostnamesStore/Context are - * ThreadLocal, so a plain "Context.set() then webClient.block()" test can't observe both - * halves together the way HttpURLConnectionTest can for a same-thread blocking client - that - * only works in production because a real WebFlux request stays on one reactor-http-nio - * thread throughout. So this file tests each wrapper's own contribution separately. + * latter on Reactor Netty's own event-loop thread. This file can't be a single cohesive test + * the way HttpURLConnectionTest is for a same-thread blocking client, and this isn't just a + * ThreadLocal limitation fixed by PendingHostnamesStore going global: the taint-context capture + * (ReactorAikidoContext) only gets written into Reactor's own Context by SpringWebfluxWrapper, + * which only fires for a *real* incoming WebFlux request - a bare "webClient.get()...block()" + * call made directly from a test, with no request behind it, never populates it, so the SSRF + * check silently no-ops and a real network call goes out (confirmed empirically: it hangs until + * timeout instead of getting blocked). A true end-to-end "WebClient in a Spring app, request + * gets blocked" test needs an actual running app - see spring_webflux_postgres.py's `ssrf` + * payload for that. This file instead tests each wrapper's own contribution in isolation. */ public class WebClientTest { private static final WebClient webClient = WebClient.create(); diff --git a/sample-apps/SpringWebfluxSampleApp/src/main/java/dev/aikido/SpringWebfluxSampleApp/RequestController.java b/sample-apps/SpringWebfluxSampleApp/src/main/java/dev/aikido/SpringWebfluxSampleApp/RequestController.java index bf9ee90e..c7555e1e 100644 --- a/sample-apps/SpringWebfluxSampleApp/src/main/java/dev/aikido/SpringWebfluxSampleApp/RequestController.java +++ b/sample-apps/SpringWebfluxSampleApp/src/main/java/dev/aikido/SpringWebfluxSampleApp/RequestController.java @@ -9,6 +9,7 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.reactive.function.client.WebClient; import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; import reactor.netty.http.client.HttpClient; @RestController @@ -47,6 +48,17 @@ public Mono makeRequestFollowingRedirects(@RequestParam String url) { : Mono.just("Error: " + e.getMessage())); } + // Exercises a scheduler hop (a common pattern for mixing blocking JDBC with reactive + // controllers) BEFORE the WebClient call, to test whether ThreadLocal-based taint/port + // correlation survives moving off the original reactor-http-nio thread. Unverified + // hypothesis, see PR #312 worklog item 2. + @GetMapping("/publish-on") + public Mono makeRequestWithSchedulerHop(@RequestParam String url) { + return Mono.just(url) + .publishOn(Schedulers.boundedElastic()) + .flatMap(this::makeRequestInternal); + } + private Mono makeRequestInternal(String url) { return webClient.get() .uri(url)