diff --git a/sql/src/main/java/org/apache/druid/sql/PreparedStatement.java b/sql/src/main/java/org/apache/druid/sql/PreparedStatement.java index 23737eee9a8f..35d4b984e493 100644 --- a/sql/src/main/java/org/apache/druid/sql/PreparedStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/PreparedStatement.java @@ -88,9 +88,11 @@ public PrepareResult prepare() */ public DirectStatement execute(List parameters) { + String remoteAddr = (String) originalRequest.context().get("remoteAddress"); return new DirectStatement( sqlToolbox, - originalRequest.freshCopy().withParameters(parameters) + originalRequest.freshCopy().withParameters(parameters), + remoteAddr ); } diff --git a/sql/src/main/java/org/apache/druid/sql/SqlStatementFactory.java b/sql/src/main/java/org/apache/druid/sql/SqlStatementFactory.java index 9d31441befb1..a0f4f0d8e300 100644 --- a/sql/src/main/java/org/apache/druid/sql/SqlStatementFactory.java +++ b/sql/src/main/java/org/apache/druid/sql/SqlStatementFactory.java @@ -51,7 +51,8 @@ public HttpStatement httpStatement( public DirectStatement directStatement(final SqlQueryPlus sqlRequest) { - return new DirectStatement(lifecycleToolbox, sqlRequest); + String remoteAddr = (String) sqlRequest.context().get("remoteAddress"); + return new DirectStatement(lifecycleToolbox, sqlRequest, remoteAddr); } public PreparedStatement preparedStatement(final SqlQueryPlus sqlRequest) diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidAvaticaJsonHandler.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidAvaticaJsonHandler.java index f6d0738e2321..2e80bfff4e79 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidAvaticaJsonHandler.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidAvaticaJsonHandler.java @@ -66,6 +66,9 @@ public DruidAvaticaJsonHandler( public boolean handle(Request request, Response response, Callback callback) throws Exception { String requestURI = request.getHttpURI().getPath(); + String remoteAddr = Request.getRemoteAddr(request); + DruidMeta.setThreadLocalRemoteAddress(remoteAddr); + try (Timer.Context ctx = this.requestTimer.start()) { if (AVATICA_PATH_NO_TRAILING_SLASH.equals(StringUtils.maybeRemoveTrailingSlash(requestURI))) { response.getHeaders().put("Content-Type", "application/json;charset=utf-8"); @@ -114,6 +117,9 @@ public boolean handle(Request request, Response response, Callback callback) thr return true; } } + finally { + DruidMeta.clearThreadLocalRemoteAddress(); + } return false; } diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java index ff27f020832e..dcdd731ddbe9 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidMeta.java @@ -119,6 +119,18 @@ public static T logFailure(T error) private static final Logger LOG = new Logger(DruidMeta.class); + private static final ThreadLocal THREAD_LOCAL_REMOTE_ADDRESS = new ThreadLocal<>(); + + public static void setThreadLocalRemoteAddress(String remoteAddress) + { + THREAD_LOCAL_REMOTE_ADDRESS.set(remoteAddress); + } + + public static void clearThreadLocalRemoteAddress() + { + THREAD_LOCAL_REMOTE_ADDRESS.remove(); + } + /** * Items passed in via the connection context which are not query * context values. Instead, these are used at connection time to validate @@ -804,6 +816,11 @@ private DruidConnection openDruidConnection( final Map context ) { + String remoteAddress = THREAD_LOCAL_REMOTE_ADDRESS.get(); + if (remoteAddress != null) { + context.put("remoteAddress", remoteAddress); + } + if (connectionCount.incrementAndGet() > config.getMaxConnections()) { // O(connections) but we don't expect this to happen often (it's a last-ditch effort to clear out // abandoned connections) or to have too many connections. diff --git a/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java b/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java index 2c963468ef80..89915eeb32bb 100644 --- a/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java +++ b/sql/src/test/java/org/apache/druid/sql/avatica/DruidAvaticaHandlerTest.java @@ -501,7 +501,7 @@ public void testExplainSelectCount() throws SQLException ImmutableMap.of( "PLAN", StringUtils.format( - "[{\"query\":{\"queryType\":\"timeseries\",\"dataSource\":{\"type\":\"table\",\"name\":\"foo\"},\"intervals\":{\"type\":\"intervals\",\"intervals\":[\"-146136543-09-08T08:23:32.096Z/146140482-04-24T15:36:27.903Z\"]},\"granularity\":{\"type\":\"all\"},\"aggregations\":[{\"type\":\"count\",\"name\":\"a0\"}],\"context\":{\"forbidden-key\":\"system-default-value\",\"sqlQueryId\":\"%s\",\"sqlStringifyArrays\":false,\"sqlTimeZone\":\"America/Los_Angeles\"}},\"signature\":[{\"name\":\"a0\",\"type\":\"LONG\"}],\"columnMappings\":[{\"queryColumn\":\"a0\",\"outputColumn\":\"cnt\"}]}]", + "[{\"query\":{\"queryType\":\"timeseries\",\"dataSource\":{\"type\":\"table\",\"name\":\"foo\"},\"intervals\":{\"type\":\"intervals\",\"intervals\":[\"-146136543-09-08T08:23:32.096Z/146140482-04-24T15:36:27.903Z\"]},\"granularity\":{\"type\":\"all\"},\"aggregations\":[{\"type\":\"count\",\"name\":\"a0\"}],\"context\":{\"forbidden-key\":\"system-default-value\",\"remoteAddress\":\"127.0.0.1\",\"sqlQueryId\":\"%s\",\"sqlStringifyArrays\":false,\"sqlTimeZone\":\"America/Los_Angeles\"}},\"signature\":[{\"name\":\"a0\",\"type\":\"LONG\"}],\"columnMappings\":[{\"queryColumn\":\"a0\",\"outputColumn\":\"cnt\"}]}]", DUMMY_SQL_QUERY_ID ), "RESOURCES", @@ -1950,4 +1950,87 @@ private static Map row(final Pair... entries) } return m; } + + /** + * Test that remote address is properly captured and logged for JDBC Avatica connections. + * This verifies the fix for issue #19230 which ensures that the client's remote address + * is tracked through the entire SQL execution lifecycle. + */ + @Test + public void testRemoteAddressInLogs() throws SQLException + { + testRequestLogger.clear(); + + try (Statement stmt = client.createStatement()) { + stmt.executeQuery("SELECT COUNT(*) AS cnt FROM druid.foo"); + } + + Assert.assertEquals(1, testRequestLogger.getSqlQueryLogs().size()); + RequestLogLine logLine = testRequestLogger.getSqlQueryLogs().get(0); + + String remoteAddress = logLine.getRemoteAddr(); + Assert.assertNotNull("Remote address should not be null", remoteAddress); + + Assert.assertTrue( + "Remote address should be a valid IP or localhost", + remoteAddress.contains("localhost") || + remoteAddress.contains("127.0.0.1") || + remoteAddress.contains("0:0:0:0:0:0:0:1") || + (remoteAddress.contains(".") && remoteAddress.length() >= 7) + ); + } + + /** + * Test that remote address is captured even when a query fails. + */ + @Test + public void testRemoteAddressInFailedQuery() throws SQLException + { + testRequestLogger.clear(); + + try (Statement stmt = client.createStatement()) { + stmt.executeQuery("SELECT nonexistent FROM druid.foo"); + Assert.fail("Query should have failed"); + } + catch (SQLException e) { + // Expected exception + } + + Assert.assertEquals(1, testRequestLogger.getSqlQueryLogs().size()); + RequestLogLine logLine = testRequestLogger.getSqlQueryLogs().get(0); + + String remoteAddress = logLine.getRemoteAddr(); + Assert.assertNotNull("Remote address should not be null even in failed query", remoteAddress); + Assert.assertFalse("Remote address should not be empty even in failed query", remoteAddress.length() == 0); + } + + /** + * Test that remote address is captured for prepared statements. + */ + @Test + public void testRemoteAddressInPreparedStatement() throws SQLException + { + testRequestLogger.clear(); + + try (PreparedStatement stmt = client.prepareStatement("SELECT COUNT(*) AS cnt FROM druid.foo WHERE dim1 = ?")) { + stmt.setString(1, "abc"); + stmt.executeQuery(); + } + + Assert.assertTrue( + "Should have at least one log entry", + testRequestLogger.getSqlQueryLogs().size() >= 1 + ); + + // Check that at least one log entry (the actual query execution) has a remote address + boolean hasRemoteAddress = false; + for (RequestLogLine logLine : testRequestLogger.getSqlQueryLogs()) { + String remoteAddress = logLine.getRemoteAddr(); + if (remoteAddress != null && remoteAddress.length() > 0) { + hasRemoteAddress = true; + break; + } + } + Assert.assertTrue("At least one log entry should have a remote address", hasRemoteAddress); + } } diff --git a/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteJoinQueryTest/testUnionAllTwoQueriesBothQueriesAreJoin.iq b/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteJoinQueryTest/testUnionAllTwoQueriesBothQueriesAreJoin.iq index 65b1fac2fad7..1d167a741772 100644 --- a/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteJoinQueryTest/testUnionAllTwoQueriesBothQueriesAreJoin.iq +++ b/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteJoinQueryTest/testUnionAllTwoQueriesBothQueriesAreJoin.iq @@ -87,6 +87,7 @@ DruidUnion(all=[true]) "maxScatterGatherBytes" : "9223372036854775807", "outputformat" : "MYSQL", "plannerStrategy" : "DECOUPLED", + "remoteAddress" : "127.0.0.1", "sqlCurrentTimestamp" : "2000-01-01T00:00:00Z", "sqlQueryId" : "dummy", "sqlStringifyArrays" : false @@ -124,6 +125,7 @@ DruidUnion(all=[true]) "maxScatterGatherBytes" : "9223372036854775807", "outputformat" : "MYSQL", "plannerStrategy" : "DECOUPLED", + "remoteAddress" : "127.0.0.1", "sqlCurrentTimestamp" : "2000-01-01T00:00:00Z", "sqlQueryId" : "dummy", "sqlStringifyArrays" : false diff --git a/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteJoinQueryTest/testUnionAllTwoQueriesLeftQueryIsJoin@all_enabled.iq b/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteJoinQueryTest/testUnionAllTwoQueriesLeftQueryIsJoin@all_enabled.iq index f8fef5db5176..6b5b81a57809 100644 --- a/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteJoinQueryTest/testUnionAllTwoQueriesLeftQueryIsJoin@all_enabled.iq +++ b/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteJoinQueryTest/testUnionAllTwoQueriesLeftQueryIsJoin@all_enabled.iq @@ -85,6 +85,7 @@ DruidUnion(all=[true]) "maxScatterGatherBytes" : "9223372036854775807", "outputformat" : "MYSQL", "plannerStrategy" : "DECOUPLED", + "remoteAddress" : "127.0.0.1", "sqlCurrentTimestamp" : "2000-01-01T00:00:00Z", "sqlQueryId" : "dummy", "sqlStringifyArrays" : false @@ -116,6 +117,7 @@ DruidUnion(all=[true]) "maxScatterGatherBytes" : "9223372036854775807", "outputformat" : "MYSQL", "plannerStrategy" : "DECOUPLED", + "remoteAddress" : "127.0.0.1", "sqlCurrentTimestamp" : "2000-01-01T00:00:00Z", "sqlQueryId" : "dummy", "sqlStringifyArrays" : false diff --git a/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteJoinQueryTest/testUnionAllTwoQueriesLeftQueryIsJoin@default.iq b/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteJoinQueryTest/testUnionAllTwoQueriesLeftQueryIsJoin@default.iq index c215c664567c..4ad0486d8334 100644 --- a/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteJoinQueryTest/testUnionAllTwoQueriesLeftQueryIsJoin@default.iq +++ b/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteJoinQueryTest/testUnionAllTwoQueriesLeftQueryIsJoin@default.iq @@ -79,6 +79,7 @@ DruidUnion(all=[true]) "maxScatterGatherBytes" : "9223372036854775807", "outputformat" : "MYSQL", "plannerStrategy" : "DECOUPLED", + "remoteAddress" : "127.0.0.1", "sqlCurrentTimestamp" : "2000-01-01T00:00:00Z", "sqlQueryId" : "dummy", "sqlStringifyArrays" : false @@ -107,6 +108,7 @@ DruidUnion(all=[true]) "maxScatterGatherBytes" : "9223372036854775807", "outputformat" : "MYSQL", "plannerStrategy" : "DECOUPLED", + "remoteAddress" : "127.0.0.1", "sqlCurrentTimestamp" : "2000-01-01T00:00:00Z", "sqlQueryId" : "dummy", "sqlStringifyArrays" : false diff --git a/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteQueryTest/testUnionAllQueries.iq b/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteQueryTest/testUnionAllQueries.iq index e4c98567392d..adf695e8ae33 100644 --- a/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteQueryTest/testUnionAllQueries.iq +++ b/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteQueryTest/testUnionAllQueries.iq @@ -66,6 +66,7 @@ DruidUnion(all=[true]) "maxScatterGatherBytes" : "9223372036854775807", "outputformat" : "MYSQL", "plannerStrategy" : "DECOUPLED", + "remoteAddress" : "127.0.0.1", "sqlCurrentTimestamp" : "2000-01-01T00:00:00Z", "sqlQueryId" : "dummy", "sqlStringifyArrays" : false @@ -94,6 +95,7 @@ DruidUnion(all=[true]) "maxScatterGatherBytes" : "9223372036854775807", "outputformat" : "MYSQL", "plannerStrategy" : "DECOUPLED", + "remoteAddress" : "127.0.0.1", "sqlCurrentTimestamp" : "2000-01-01T00:00:00Z", "sqlQueryId" : "dummy", "sqlStringifyArrays" : false @@ -121,6 +123,7 @@ DruidUnion(all=[true]) "maxScatterGatherBytes" : "9223372036854775807", "outputformat" : "MYSQL", "plannerStrategy" : "DECOUPLED", + "remoteAddress" : "127.0.0.1", "sqlCurrentTimestamp" : "2000-01-01T00:00:00Z", "sqlQueryId" : "dummy", "sqlStringifyArrays" : false diff --git a/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteQueryTest/testUnionAllQueriesWithLimit.iq b/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteQueryTest/testUnionAllQueriesWithLimit.iq index 49fa808fe731..f77afd72383d 100644 --- a/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteQueryTest/testUnionAllQueriesWithLimit.iq +++ b/sql/src/test/quidem/org.apache.druid.sql.calcite.DecoupledPlanningCalciteQueryTest/testUnionAllQueriesWithLimit.iq @@ -67,6 +67,7 @@ DruidSort(fetch=[2], druid=[logical]) "maxScatterGatherBytes" : "9223372036854775807", "outputformat" : "MYSQL", "plannerStrategy" : "DECOUPLED", + "remoteAddress" : "127.0.0.1", "sqlCurrentTimestamp" : "2000-01-01T00:00:00Z", "sqlQueryId" : "dummy", "sqlStringifyArrays" : false @@ -95,6 +96,7 @@ DruidSort(fetch=[2], druid=[logical]) "maxScatterGatherBytes" : "9223372036854775807", "outputformat" : "MYSQL", "plannerStrategy" : "DECOUPLED", + "remoteAddress" : "127.0.0.1", "sqlCurrentTimestamp" : "2000-01-01T00:00:00Z", "sqlQueryId" : "dummy", "sqlStringifyArrays" : false @@ -122,6 +124,7 @@ DruidSort(fetch=[2], druid=[logical]) "maxScatterGatherBytes" : "9223372036854775807", "outputformat" : "MYSQL", "plannerStrategy" : "DECOUPLED", + "remoteAddress" : "127.0.0.1", "sqlCurrentTimestamp" : "2000-01-01T00:00:00Z", "sqlQueryId" : "dummy", "sqlStringifyArrays" : false