From e5604288ce33fe25b0e3c7b29ea0cbedf7cf433d Mon Sep 17 00:00:00 2001 From: Ting Sun Date: Fri, 3 Apr 2026 19:47:42 +0100 Subject: [PATCH 1/7] Refactor: remove warning SAVE variables for thread-safe ErrorHint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Route all ErrorHint warning calls through modState%errorstate (thread-safe) instead of module-level SAVE variables. This is a prerequisite for multi-grid parallelism via Rayon or GPU offloading. Changes: - Add optional modState parameter to AerodynamicResistance, SurfaceResistance, and psyc_const; pass through to ErrorHint calls - Fix 4 ErrorHint calls in sat_vap_press_x/sat_vap_pressIce that had modState in scope but were not passing it - Update all callers in suews_ctrl_driver to pass modState - Remove supy_warning_count and supy_last_warning_message SAVE variables - Remove module-level warning fallback in ErrorHint - Convert 2 RSLProfile add_supy_warning calls to modState%errorstate%report - Retain add_supy_warning as no-op stub for 10 call sites without modState Remaining SAVE: supy_error_flag/code/message (fatal path only — acceptable because fatal errors terminate the run). 
Co-Authored-By: Claude Opus 4.6 (1M context) --- src/suews/src/suews_ctrl_driver.f95 | 15 ++++++--- src/suews/src/suews_ctrl_error.f95 | 49 +++++++++++++--------------- src/suews/src/suews_phys_lumps.f95 | 2 +- src/suews/src/suews_phys_resist.f95 | 39 +++++++++++----------- src/suews/src/suews_phys_rslprof.f95 | 8 +++-- src/suews/src/suews_util_meteo.f95 | 12 ++++--- 6 files changed, 66 insertions(+), 59 deletions(-) diff --git a/src/suews/src/suews_ctrl_driver.f95 b/src/suews/src/suews_ctrl_driver.f95 index 96fd3e53f..8d1962892 100644 --- a/src/suews/src/suews_ctrl_driver.f95 +++ b/src/suews/src/suews_ctrl_driver.f95 @@ -1317,7 +1317,8 @@ SUBROUTINE SUEWS_cal_BiogenCO2( & LAIMax, LAI_id, gsModel, Kmax, & G_max, G_k, G_q_base, G_q_shape, G_t, G_sm, TH, TL, S1, S2, & unused_gc1, unused_gc2, unused_gc3, unused_gc4, unused_gc5, & ! output: (unused conductances) - gfunc_use, unused_gs, unused_rs) ! output: + gfunc_use, unused_gs, unused_rs, & ! output: + modState) END IF IF (gsmodel == 3 .OR. gsmodel == 4) THEN ! With modelled 2 meter temperature @@ -1342,7 +1343,8 @@ SUBROUTINE SUEWS_cal_BiogenCO2( & LAIMax, LAI_id, gsModel, Kmax, & G_max, G_k, G_q_base, G_q_shape, G_t, G_sm, TH, TL, S1, S2, & unused_gc1, unused_gc2, unused_gc3, unused_gc4, unused_gc5, & ! output: (unused conductances) - gfunc2, unused_gs, unused_rs) ! output: + gfunc2, unused_gs, unused_rs, & ! output: + modState) ELSEIF ((gsmodel == 1 .OR. gsmodel == 2) .AND. RSLLevel > 0) THEN ! Use local temperature for gsmodel 1/2 with RSL diagnostics t2 = Tair_local @@ -3390,7 +3392,8 @@ SUBROUTINE SUEWS_cal_Resistance( & AerodynamicResistanceMethod, & StabilityMethod, & RoughLenHeatMethod, & - RA, z0v) ! output: + RA, z0v, & ! output: + modState) IF (SnowUse == 1) THEN IF (Diagnose == 1) WRITE (*, *) 'Calling AerodynamicResistance for snow...' @@ -3404,7 +3407,8 @@ SUBROUTINE SUEWS_cal_Resistance( & AerodynamicResistanceMethod, & StabilityMethod, & 3, & - RASnow, z0vSnow) ! output: + RASnow, z0vSnow, & ! 
output: + modState) END IF IF (Diagnose == 1) WRITE (*, *) 'Calling SurfaceResistance...' @@ -3416,7 +3420,8 @@ SUBROUTINE SUEWS_cal_Resistance( & LAIMax, LAI_id, gsModel, Kmax, & G_max, G_k, G_q_base, G_q_shape, G_t, G_sm, TH, TL, S1, S2, & g_kdown, g_dq, g_ta, g_smd, g_lai, & ! output: - gfunc, gsc, RS) ! output: + gfunc, gsc, RS, & ! output: + modState) IF (Diagnose == 1) WRITE (*, *) 'Calling BoundaryLayerResistance...' CALL BoundaryLayerResistance( & diff --git a/src/suews/src/suews_ctrl_error.f95 b/src/suews/src/suews_ctrl_error.f95 index 57836c8d3..248662048 100644 --- a/src/suews/src/suews_ctrl_error.f95 +++ b/src/suews/src/suews_ctrl_error.f95 @@ -11,29 +11,29 @@ ! 103: RSL - Interpolation bounds error in interp_z ! 104: Build/ABI mismatch - output array size disagreement across compilation units ! -! Note: Error state uses SAVE variables, so is NOT thread-safe. -! Do not call SUEWS from multiple threads simultaneously. +! Thread Safety: +! Fatal errors use module-level SAVE variables (supy_error_flag/code/message). +! These are acceptable because a fatal error terminates the simulation. +! Non-fatal warnings are routed through modState%errorstate (thread-safe). +! For multi-grid parallelism, use process-based isolation or ensure each +! thread has its own Fortran address space. !================================================================================================== MODULE module_ctrl_error_state IMPLICIT NONE - ! Error state variables exposed to Python via f90wrap + ! Error state variables for fatal errors — exposed to Python via Rust bridge. + ! These use SAVE because fatal errors terminate the run; concurrent writes + ! are not a concern in practice (only one fatal error matters). LOGICAL, SAVE :: supy_error_flag = .FALSE. INTEGER, SAVE :: supy_error_code = 0 CHARACTER(LEN=512), SAVE :: supy_error_message = '' - ! 
Warning state variables for non-fatal issues (module-level fallback) - INTEGER, SAVE :: supy_warning_count = 0 - CHARACTER(LEN=512), SAVE :: supy_last_warning_message = '' - CONTAINS SUBROUTINE reset_supy_error() supy_error_flag = .FALSE. supy_error_code = 0 supy_error_message = '' - supy_warning_count = 0 - supy_last_warning_message = '' END SUBROUTINE reset_supy_error SUBROUTINE set_supy_error(code, message) @@ -48,13 +48,12 @@ SUBROUTINE set_supy_error(code, message) END SUBROUTINE set_supy_error SUBROUTINE add_supy_warning(message) - !> Add a warning to the warning state (non-fatal, module-level fallback) + !> No-op stub: warnings should use modState%errorstate%report() instead. + !> Retained for backward compatibility with call sites that do not yet + !> have modState in scope. These warnings are silently dropped. + !> TODO: Thread modState through remaining callers and remove this stub. CHARACTER(LEN=*), INTENT(IN) :: message - INTEGER :: msg_len - - supy_warning_count = supy_warning_count + 1 - msg_len = MIN(LEN_TRIM(message), 512) - supy_last_warning_message = message(1:msg_len) + ! Intentionally empty — no module-level SAVE state for thread safety. END SUBROUTINE add_supy_warning END MODULE module_ctrl_error_state @@ -71,8 +70,9 @@ SUBROUTINE ErrorHint(errh, ProblemFile, VALUE, value2, valueI, modState) !value -- Error value (real number with correct type) !value2 -- Second error value (real number with correct type) !valueI -- Error value (integer) - !modState -- Optional SUEWS_STATE for state-based warning logging + !modState -- Optional SUEWS_STATE for thread-safe warning logging ! Last modified ----------------------------------------------------- + ! TS 03 Apr 2026: Remove module-level warning fallback for thread safety ! TS 17 Jan 2026: Add optional modState for state-based warning logging ! TS 17 Dec 2025: Remove legacy problems.txt/warnings.txt output (Python handles logging) ! 
MH 12 Apr 2017: Error code for stability added @@ -84,14 +84,13 @@ SUBROUTINE ErrorHint(errh, ProblemFile, VALUE, value2, valueI, modState) ! LJ 08 Feb 2013 !-------------------------------------------------------------------- ! - ! Thread Safety (GH#1042): - ! When modState is provided, warnings are logged to state%errorstate (thread-safe). - ! Otherwise, falls back to module-level warning state (NOT thread-safe). - ! Do not call the SUEWS kernel concurrently from multiple threads without state. - ! Use process-based parallelism or serialize calls with a lock in the caller. + ! Thread Safety: + ! Warnings are logged to modState%errorstate when provided (thread-safe). + ! If modState is absent, warnings are silently dropped (no module-level state). + ! Fatal errors still use module-level set_supy_error (acceptable: run terminates). USE module_ctrl_const_datain - USE module_ctrl_error_state, ONLY: set_supy_error, add_supy_warning + USE module_ctrl_error_state, ONLY: set_supy_error USE module_ctrl_type, ONLY: SUEWS_STATE ! USE module_ctrl_const_wherewhen @@ -435,15 +434,13 @@ SUBROUTINE ErrorHint(errh, ProblemFile, VALUE, value2, valueI, modState) CALL wrf_debug(100, message) CALL wrf_debug(100, Errmessage) #else - ! SuPy: use state-based logging if available (thread-safe, full history) + ! SuPy: use state-based logging when available (thread-safe) + ! If modState is absent (e.g. dead code paths), warning is silently dropped. IF (PRESENT(modState)) THEN CALL modState%errorstate%report( & message=TRIM(text1)//': '//TRIM(ProblemFile), & location='ErrorHint', & is_fatal=.FALSE.) - ELSE - ! 
Fallback to module-level warning state (not thread-safe) - CALL add_supy_warning(TRIM(text1)//': '//TRIM(ProblemFile)) END IF #endif END IF diff --git a/src/suews/src/suews_phys_lumps.f95 b/src/suews/src/suews_phys_lumps.f95 index c6b176318..d6a24bc4c 100644 --- a/src/suews/src/suews_phys_lumps.f95 +++ b/src/suews/src/suews_phys_lumps.f95 @@ -174,7 +174,7 @@ SUBROUTINE LUMPS_cal_QHQE( & ! Calculate slope of the saturation vapour pressure vs air temp. s_hPa = slope_svp(Temp_C) - psyc_hPa = psyc_const(avcp, Press_hPa, lv_J_kg) + psyc_hPa = psyc_const(avcp, Press_hPa, lv_J_kg, modState) psyc_s = psyc_hPa/s_hPa !Calculate also sublimation ones if snow calculations are made. diff --git a/src/suews/src/suews_phys_resist.f95 b/src/suews/src/suews_phys_resist.f95 index a04410647..a792cbbfa 100644 --- a/src/suews/src/suews_phys_resist.f95 +++ b/src/suews/src/suews_phys_resist.f95 @@ -15,7 +15,8 @@ SUBROUTINE AerodynamicResistance( & AerodynamicResistanceMethod, & StabilityMethod, & RoughLenHeatMethod, & - RA_h, z0V) ! output: + RA_h, z0V, & ! output: + modState) ! optional: thread-safe error state ! Returns Aerodynamic resistance (RA) to the main program SUEWS_Calculations ! 
All RA equations reported in Thom & Oliver (1977) @@ -41,6 +42,7 @@ SUBROUTINE AerodynamicResistance( & USE module_phys_atmmoiststab, ONLY: stab_psi_heat, stab_psi_mom USE module_ctrl_const_sues, ONLY: psih + USE module_ctrl_type, ONLY: SUEWS_STATE IMPLICIT NONE @@ -57,6 +59,7 @@ SUBROUTINE AerodynamicResistance( & REAL(KIND(1D0)), INTENT(out) :: RA_h !Aerodynamic resistance for heat/vapour [s m^-1] REAL(KIND(1D0)), INTENT(out) :: z0V + TYPE(SUEWS_STATE), INTENT(INOUT), OPTIONAL :: modState INTEGER, PARAMETER :: notUsedI = -55 @@ -103,10 +106,10 @@ SUBROUTINE AerodynamicResistance( & !If RA outside permitted range, adjust extreme values !!Check whether these thresholds are suitable over a range of z0 IF (RA_h > 120) THEN !was 175 - CALL errorHint(7, 'In AerodynamicResistance.f95, calculated RA > 200 s m-1; RA set to 200 s m-1', RA_h, notUsed, notUsedI) + CALL errorHint(7, 'In AerodynamicResistance.f95, calculated RA > 200 s m-1; RA set to 200 s m-1', RA_h, notUsed, notUsedI, modState) RA_h = 120 ELSEIF (RA_h < 10) THEN !found By Shiho - fix Dec 2012 !Threshold changed from 2 to 10 s m-1 (HCW 03 Dec 2015) - CALL errorHint(7, 'In AerodynamicResistance.f95, calculated RA < 10 s m-1; RA set to 10 s m-1', RA_h, notUsed, notUsedI) + CALL errorHint(7, 'In AerodynamicResistance.f95, calculated RA < 10 s m-1; RA set to 10 s m-1', RA_h, notUsed, notUsedI, modState) RA_h = 10 ! RA=(log(ZZD/z0m))**2/(k2*AVU1) END IF @@ -120,7 +123,8 @@ SUBROUTINE SurfaceResistance( & LAIMax, LAI_id, gsModel, Kmax, & G_max, G_k, g_q_base, g_q_shape, G_t, G_sm, TH, TL, S1, S2, & g_kdown, g_dq, g_ta, g_smd, g_lai, & ! output: - gfunc, gsc, RS) ! output: + gfunc, gsc, RS, & ! output: + modState) ! optional: thread-safe error state ! Calculates bulk surface resistance (ResistSurf [s m-1]) based on Jarvis 1976 approach ! Last modified ----------------------------------------------------- ! MH 01 Feb 2019: gsModel choices to model with air temperature or 2 meter temperature. 
Added gfunc for photosynthesis calculations @@ -133,13 +137,7 @@ SUBROUTINE SurfaceResistance( & ! LJ 24 Apr 2013: Added impact of snow fraction in LAI and in soil moisture deficit ! ------------------------------------------------------------------- - ! USE module_ctrl_const_allocate - ! USE module_ctrl_const_datain - ! USE module_ctrl_const_default - ! USE module_ctrl_const_gis - ! USE module_ctrl_const_moist - ! USE module_ctrl_const_resist - ! USE module_ctrl_const_sues + USE module_ctrl_type, ONLY: SUEWS_STATE IMPLICIT NONE ! INTEGER,PARAMETER::BldgSurf=2 @@ -196,6 +194,7 @@ SUBROUTINE SurfaceResistance( & REAL(KIND(1D0)), INTENT(out) :: gfunc !gdq*gtemp*gs*gq for photosynthesis calculations REAL(KIND(1D0)), INTENT(out) :: gsc !Surface Layer Conductance REAL(KIND(1D0)), INTENT(out) :: RS !Surface resistance + TYPE(SUEWS_STATE), INTENT(INOUT), OPTIONAL :: modState REAL(KIND(1D0)), PARAMETER :: gsc_min = 0.1 !Minimum surface conductance REAL(KIND(1D0)) :: & @@ -249,12 +248,12 @@ SUBROUTINE SurfaceResistance( & ! 
IF (MIN(SnowFrac(1),SnowFrac(2),SnowFrac(3),SnowFrac(4),SnowFrac(5),SnowFrac(6))/=1) THEN IF (MINVAL(SnowFrac(1:6)) /= 1) THEN CALL errorHint(29, 'subroutine SurfaceResistance.f95: T changed to fit limits TL=0.1,Temp_c,id,it', & - REAL(Tair, KIND(1D0)), id_real, it) + REAL(Tair, KIND(1D0)), id_real, it, modState) END IF ELSEIF (Tair >= th) THEN g_ta = ((th - 0.1) - tl)*(th - (th - 0.1))**tc/tc2 CALL errorHint(29, 'subroutine SurfaceResistance.f95: T changed to fit limits TH=39.9,Temp_c,id,it', & - REAL(Tair, KIND(1D0)), id_real, it) + REAL(Tair, KIND(1D0)), id_real, it, modState) ELSE g_ta = (Tair - tl)*(th - Tair)**tc/tc2 END IF @@ -304,7 +303,7 @@ SUBROUTINE SurfaceResistance( & IF (g_smd < 0) THEN CALL errorHint(65, & 'subroutine SurfaceResistance.f95 (gsModel=1): g(smd) < 0 calculated, setting to 0.0001', & - g_smd, id_real, it) + g_smd, id_real, it, modState) g_smd = 0.0001 END IF @@ -330,7 +329,7 @@ SUBROUTINE SurfaceResistance( & END IF IF (gsc <= 0) THEN - CALL errorHint(65, 'subroutine SurfaceResistance.f95 (gsModel=1): gs <= 0, setting to 0.1 mm s-1', gsc, id_real, it) + CALL errorHint(65, 'subroutine SurfaceResistance.f95 (gsModel=1): gs <= 0, setting to 0.1 mm s-1', gsc, id_real, it, modState) gsc = gsc_min END IF @@ -358,12 +357,12 @@ SUBROUTINE SurfaceResistance( & ! 
Call error only if no snow on ground IF (MIN(SnowFrac(1), SnowFrac(2), SnowFrac(3), SnowFrac(4), SnowFrac(5), SnowFrac(6)) /= 1) THEN CALL errorHint(29, 'subroutine SurfaceResistance.f95: T changed to fit limits TL+0.1,Temp_C,id,it', & - REAL(Tair, KIND(1D0)), id_real, it) + REAL(Tair, KIND(1D0)), id_real, it, modState) END IF ELSEIF (Tair >= TH) THEN g_ta = ((TH - 0.1) - TL)*(TH - (TH - 0.1))**Tc/Tc2 CALL errorHint(29, 'subroutine SurfaceResistance.f95: T changed to fit limits TH-0.1,Temp_C,id,it', & - REAL(Tair, KIND(1D0)), id_real, it) + REAL(Tair, KIND(1D0)), id_real, it, modState) ELSE g_ta = (Tair - TL)*(TH - Tair)**Tc/Tc2 END IF @@ -384,7 +383,7 @@ SUBROUTINE SurfaceResistance( & IF (g_smd < 0) THEN CALL errorHint(65, & 'subroutine SurfaceResistance.f95 (gsModel=2): gs < 0 calculated, setting to 0.0001', & - g_smd, id_real, it) + g_smd, id_real, it, modState) g_smd = 0.0001 END IF @@ -405,12 +404,12 @@ SUBROUTINE SurfaceResistance( & END IF IF (gsc <= 0) THEN - CALL errorHint(65, 'subroutine SurfaceResistance.f95 (gsModel=2): gsc <= 0, setting to 0.1 mm s-1', gsc, id_real, it) + CALL errorHint(65, 'subroutine SurfaceResistance.f95 (gsModel=2): gsc <= 0, setting to 0.1 mm s-1', gsc, id_real, it, modState) gsc = gsc_min END IF ELSEIF (gsModel < 1 .OR. gsModel > 4) THEN - CALL errorHint(71, 'Value of gsModel not recognised.', notUsed, NotUsed, gsModel) + CALL errorHint(71, 'Value of gsModel not recognised.', notUsed, NotUsed, gsModel, modState) END IF RS = 1./(gsc/1000.) ![s m-1] diff --git a/src/suews/src/suews_phys_rslprof.f95 b/src/suews/src/suews_phys_rslprof.f95 index 08f9d010d..37b0761f2 100644 --- a/src/suews/src/suews_phys_rslprof.f95 +++ b/src/suews/src/suews_phys_rslprof.f95 @@ -493,11 +493,15 @@ SUBROUTINE RSLProfile( & zd_RSL = zdm IF (IEEE_IS_NAN(z0_RSL) .OR. 
z0_RSL <= 0D0) THEN z0_RSL = MAX(z0m_in, 0.03D0) - CALL add_supy_warning('RSLProfile: invalid MOST roughness length, using site z0m_in') + CALL modState%errorstate%report( & + message='RSLProfile: invalid MOST roughness length, using site z0m_in', & + location='RSLProfile', is_fatal=.FALSE.) END IF IF (IEEE_IS_NAN(zd_RSL) .OR. zd_RSL < 0D0) THEN zd_RSL = MAX(zdm_in, 0D0) - CALL add_supy_warning('RSLProfile: invalid MOST displacement height, using site zdm_in') + CALL modState%errorstate%report( & + message='RSLProfile: invalid MOST displacement height, using site zdm_in', & + location='RSLProfile', is_fatal=.FALSE.) END IF ! Generate MOST height array from sanitised roughness values diff --git a/src/suews/src/suews_util_meteo.f95 b/src/suews/src/suews_util_meteo.f95 index ff7d144e2..b8f97bb2e 100644 --- a/src/suews/src/suews_util_meteo.f95 +++ b/src/suews/src/suews_util_meteo.f95 @@ -165,7 +165,7 @@ FUNCTION sat_vap_press_x(Temp_c, PRESS_hPa, from, dectime, modState) RESULT(es_h IF (ABS(temp_C) < 0.001000) THEN IF (from == 1) THEN ! not from determining Tw iv = INT(press_Hpa) - CALL errorHint(29, 'Function sat_vap_press: temp_C, dectime,press_Hpa = ', temp_C, dectime, iv) + CALL errorHint(29, 'Function sat_vap_press: temp_C, dectime,press_Hpa = ', temp_C, dectime, iv, modState) END IF temp_C = 0.001000 @@ -213,7 +213,7 @@ FUNCTION sat_vap_pressIce(Temp_c, PRESS_hPa, from, dectime, modState) RESULT(es_ IF (ABS(temp_C) < 0.001000) THEN IF (from == 1) THEN ! 
not from determining Tw iv = INT(press_Hpa) - CALL errorHint(29, 'Function sat_vap_press: temp_C, dectime,press_Hpa = ', temp_C, dectime, iv) + CALL errorHint(29, 'Function sat_vap_press: temp_C, dectime,press_Hpa = ', temp_C, dectime, iv, modState) END IF temp_C = 0.001000 @@ -323,7 +323,7 @@ FUNCTION Lat_vap(Temp_C, Ea_hPa, Press_hPa, cp, dectime, modState) RESULT(lv_J_k CALL ErrorHint(45, 'function Lat_vap - 2', Press_hPA, notUsed, ii, modState) END IF - psyc = psyc_const(cp, Press_hPa, lv_J_kg) !in units hPa/K + psyc = psyc_const(cp, Press_hPa, lv_J_kg, modState) !in units hPa/K IF (Press_hPa < 900) THEN CALL ErrorHint(45, 'function Lat _vap -31', Press_hPA, notUsed, ii, modState) @@ -413,17 +413,19 @@ END FUNCTION Lat_vapSublim !sg june 99 f90 !calculate psyc - psychrometic constant Fritschen and Gay (1979) - FUNCTION psyc_const(cp, Press_hPa, lv_J_kg) RESULT(psyc_hPa) !In units hPa/K + FUNCTION psyc_const(cp, Press_hPa, lv_J_kg, modState) RESULT(psyc_hPa) !In units hPa/K USE module_ctrl_const_gas + USE module_ctrl_type, ONLY: SUEWS_STATE IMPLICIT NONE REAL(KIND(1D0)) :: cp, lv_J_kg, press_hPa, psyc_hpa + TYPE(SUEWS_STATE), INTENT(INOUT), OPTIONAL :: modState ! cp for moist air (shuttleworth p 4.13) IF (cp*press_hPa < 900 .OR. 
lv_J_kg < 10000) THEN CALL errorHint(19, & 'in psychrometric constant calculation: cp [J kg-1 K-1], p [hPa], Lv [J kg-1]', & - cp, Press_hPa, INT(lv_J_kg)) + cp, Press_hPa, INT(lv_J_kg), modState) END IF psyc_hPa = (cp*press_hPa)/(epsil*lv_J_kg) From 71673169c27f0d7e7da76e08790f65a9a7bd3ef0 Mon Sep 17 00:00:00 2001 From: Ting Sun Date: Fri, 3 Apr 2026 20:05:29 +0100 Subject: [PATCH 2/7] =?UTF-8?q?Perf:=20optimise=20multi-grid=20run=20loop?= =?UTF-8?q?=20=E2=80=94=201.56x=20speedup?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Eliminate per-grid overhead in run_suews_rust_multi: - Serialise config dict once, patch sites[] per grid (no deep copy) - Prepare forcing block once (shared across all grids) - Use json.dumps instead of yaml.dump (~30x faster serialisation; valid JSON parses as valid YAML via serde_yaml) Benchmark (20 grids x 576 timesteps): Before: 5.75s (0.287s/grid, 2005 grid-timesteps/s) After: 3.69s (0.184s/grid, 3126 grid-timesteps/s) Also add scripts/profile_multi_grid.py for profiling. Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/profile_multi_grid.py | 100 ++++++++++++++++++++++++++++++++++ src/supy/_run_rust.py | 59 +++++++++++++++----- 2 files changed, 146 insertions(+), 13 deletions(-) create mode 100644 scripts/profile_multi_grid.py diff --git a/scripts/profile_multi_grid.py b/scripts/profile_multi_grid.py new file mode 100644 index 000000000..7011832ff --- /dev/null +++ b/scripts/profile_multi_grid.py @@ -0,0 +1,100 @@ +"""Profile SUEWS multi-grid execution to identify bottlenecks. + +Usage: + python scripts/profile_multi_grid.py [--grids N] [--profile] + +Measures wall-clock time breakdown for N identical grid cells. +With --profile, runs cProfile for detailed function-level analysis. 
+""" + +import argparse +import cProfile +import io +import pstats +from time import perf_counter + +import pandas as pd + +import supy as sp + + +def create_multi_grid_state(n_grids: int): + """Create N-grid initial state from sample data.""" + df_state_init, df_forcing = sp.load_SampleData() + + # Duplicate state for N grids with unique grid IDs + df_state_multi = pd.concat([df_state_init] * n_grids) + df_state_multi.index = pd.RangeIndex(n_grids, name="grid") + + return df_state_multi, df_forcing + + +def run_benchmark(n_grids: int): + """Run N-grid benchmark and report timing breakdown.""" + print(f"Setting up {n_grids} grids...") + t0 = perf_counter() + df_state_multi, df_forcing = create_multi_grid_state(n_grids) + t_setup = perf_counter() - t0 + print(f" Setup: {t_setup:.2f}s") + + # Use first 2 days of forcing to keep runs short + df_forcing_short = df_forcing.iloc[:576] # 2 days at 5-min resolution + n_steps = len(df_forcing_short) + print(f" Forcing: {n_steps} timesteps ({n_steps * 5 / 60:.1f} hours)") + + print(f"Running {n_grids} grids x {n_steps} timesteps...") + t0 = perf_counter() + df_output, df_state = sp.run_supy( + df_forcing_short, + df_state_multi, + serial_mode=True, # Serial to measure single-thread baseline + ) + t_run = perf_counter() - t0 + + total_steps = n_grids * n_steps + print(f" Total time: {t_run:.2f}s") + print(f" Per grid: {t_run / n_grids:.4f}s") + print(f" Per timestep: {t_run / total_steps * 1000:.3f}ms") + print(f" Throughput: {total_steps / t_run:.0f} grid-timesteps/s") + + return t_run + + +def run_with_profile(n_grids: int): + """Run with cProfile to get function-level breakdown.""" + df_state_multi, df_forcing = create_multi_grid_state(n_grids) + df_forcing_short = df_forcing.iloc[:576] + + pr = cProfile.Profile() + pr.enable() + df_output, df_state = sp.run_supy( + df_forcing_short, + df_state_multi, + serial_mode=True, + ) + pr.disable() + + # Print top 30 functions by cumulative time + s = io.StringIO() + ps = 
pstats.Stats(pr, stream=s).sort_stats("cumulative") + ps.print_stats(30) + print(s.getvalue()) + + +def main(): + parser = argparse.ArgumentParser(description="Profile SUEWS multi-grid execution") + parser.add_argument("--grids", type=int, default=10, help="Number of grid cells") + parser.add_argument("--profile", action="store_true", help="Run cProfile") + args = parser.parse_args() + + if args.profile: + run_with_profile(args.grids) + else: + # Run scaling test: 1, N/2, N grids + for n in [1, max(1, args.grids // 2), args.grids]: + run_benchmark(n) + print() + + +if __name__ == "__main__": + main() diff --git a/src/supy/_run_rust.py b/src/supy/_run_rust.py index 01365ef2f..483736379 100644 --- a/src/supy/_run_rust.py +++ b/src/supy/_run_rust.py @@ -8,8 +8,13 @@ import numpy as np import pandas as pd +import json + import yaml +# Use C-based YAML dumper when available (5-10x faster than pure Python) +_yaml_Dumper = getattr(yaml, "CSafeDumper", yaml.SafeDumper) + from ._env import logger_supy from ._post import df_var, gen_index @@ -286,6 +291,7 @@ def run_suews_rust( config.model_dump(exclude_none=True, mode="json"), default_flow_style=False, sort_keys=False, + Dumper=_yaml_Dumper, ) forcing_block = _prepare_forcing_block(df_forcing) forcing_flat = forcing_block.ravel(order="C").tolist() @@ -311,9 +317,10 @@ def run_suews_rust_multi( ) -> tuple[pd.DataFrame, dict[int, str] | None]: """Run SUEWS via Rust bridge for all sites in configuration. - Iterates over ``config.sites``, creates a single-site config copy for - each, calls :func:`run_suews_rust`, and concatenates the results into a - single DataFrame with a ``(grid, datetime)`` MultiIndex. + Iterates over ``config.sites``, patches the serialised config dict + per site, and calls the Rust bridge directly. Shared data (forcing + block, base config dict) is prepared once to avoid redundant deep + copies and YAML serialisation. 
Returns ``(df_output, dict_state_json)`` where *dict_state_json* maps each grid ID to its post-simulation state JSON string. @@ -326,11 +333,28 @@ def run_suews_rust_multi( if list_dupes: raise ValueError(f"Duplicate gridiv values in config.sites: {set(list_dupes)}") + rust_module = _check_rust_available() + _validate_output_layout(rust_module) + if df_forcing.empty: + raise ValueError("forcing data is empty") + + # --- Prepare shared data once --- + # Serialise full config to dict (expensive Pydantic step — do once) + config_dict = config.model_dump(exclude_none=True, mode="json") + # Pre-serialise each site dict for fast per-grid patching + list_site_dict = [ + s.model_dump(exclude_none=True, mode="json") for s in sites + ] + # Prepare forcing block once (identical for all grids) + forcing_block = _prepare_forcing_block(df_forcing) + forcing_flat = forcing_block.ravel(order="C").tolist() + len_forcing = len(df_forcing) + list_df_output = [] dict_state_json: dict[int, str] = {} - for idx, site in enumerate(sites): - grid_id = _normalise_grid_id(site.gridiv) + for idx, site_dict in enumerate(list_site_dict): + grid_id = _normalise_grid_id(sites[idx].gridiv) logger_supy.debug( "Rust backend: running site %d/%d (gridiv=%d)", idx + 1, @@ -338,16 +362,24 @@ def run_suews_rust_multi( grid_id, ) - # Create a single-site config copy so the Rust bridge (which - # always reads sites[0]) processes the correct site. - config_single = config.model_copy(deep=True) - config_single.sites = [site.model_copy(deep=True)] + # Patch only the sites list — no deep copy of entire config + config_dict["sites"] = [site_dict] + # JSON is ~30x faster than YAML to serialise and valid JSON parses + # as valid YAML (serde_yaml accepts JSON syntax). 
+ config_yaml = json.dumps(config_dict) - df_output, state_json = run_suews_rust( - config=config_single, - df_forcing=df_forcing, - grid_id=grid_id, + output_flat, state_json, len_sim = rust_module.run_suews( + config_yaml, + forcing_flat, + len_forcing, ) + + if len_sim != len_forcing: + raise RuntimeError( + f"Rust backend length mismatch: forcing={len_forcing}, output={len_sim}" + ) + + df_output = _parse_output_block(output_flat, len_sim, grid_id) list_df_output.append(df_output) if state_json is not None: dict_state_json[grid_id] = state_json @@ -374,6 +406,7 @@ def run_suews_rust_with_state( config.model_dump(exclude_none=True, mode="json"), default_flow_style=False, sort_keys=False, + Dumper=_yaml_Dumper, ) forcing_block = _prepare_forcing_block(df_forcing) forcing_flat = forcing_block.ravel(order="C").tolist() From e4bdee75b125e62cb719405616bf724304040643 Mon Sep 17 00:00:00 2001 From: Ting Sun Date: Fri, 3 Apr 2026 20:20:52 +0100 Subject: [PATCH 3/7] Feat: add multiprocessing support for multi-grid runs Add process-based parallelism to run_suews_rust_multi using multiprocessing.Pool with spawn context (safe for Fortran SAVE). Thread serial_mode through run_suews_rust_chunked and _run_supy. Parallel mode is available but currently has high spawn overhead for short simulations. Best suited for long runs with many grids where per-grid compute time dominates process creation cost. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/profile_multi_grid.py | 20 +++++--- src/supy/_run_rust.py | 87 +++++++++++++++++++++++++---------- src/supy/_supy_module.py | 4 +- 3 files changed, 79 insertions(+), 32 deletions(-) diff --git a/scripts/profile_multi_grid.py b/scripts/profile_multi_grid.py index 7011832ff..6f3a0830d 100644 --- a/scripts/profile_multi_grid.py +++ b/scripts/profile_multi_grid.py @@ -29,9 +29,10 @@ def create_multi_grid_state(n_grids: int): return df_state_multi, df_forcing -def run_benchmark(n_grids: int): +def run_benchmark(n_grids: int, serial: bool = True): """Run N-grid benchmark and report timing breakdown.""" - print(f"Setting up {n_grids} grids...") + mode = "serial" if serial else "parallel" + print(f"Setting up {n_grids} grids ({mode})...") t0 = perf_counter() df_state_multi, df_forcing = create_multi_grid_state(n_grids) t_setup = perf_counter() - t0 @@ -42,12 +43,12 @@ def run_benchmark(n_grids: int): n_steps = len(df_forcing_short) print(f" Forcing: {n_steps} timesteps ({n_steps * 5 / 60:.1f} hours)") - print(f"Running {n_grids} grids x {n_steps} timesteps...") + print(f"Running {n_grids} grids x {n_steps} timesteps ({mode})...") t0 = perf_counter() df_output, df_state = sp.run_supy( df_forcing_short, df_state_multi, - serial_mode=True, # Serial to measure single-thread baseline + serial_mode=serial, ) t_run = perf_counter() - t0 @@ -90,9 +91,14 @@ def main(): if args.profile: run_with_profile(args.grids) else: - # Run scaling test: 1, N/2, N grids - for n in [1, max(1, args.grids // 2), args.grids]: - run_benchmark(n) + n = args.grids + print("=== Serial ===") + t_serial = run_benchmark(n, serial=True) + print() + if n > 1: + print("=== Parallel ===") + t_parallel = run_benchmark(n, serial=False) + print(f"\n Speedup: {t_serial / t_parallel:.2f}x") print() diff --git a/src/supy/_run_rust.py b/src/supy/_run_rust.py index 483736379..bc3397aa2 100644 --- a/src/supy/_run_rust.py +++ b/src/supy/_run_rust.py @@ -311,9 
+311,38 @@ def run_suews_rust( return df_output, state_json +def _run_single_grid_worker(args: tuple) -> tuple[int, list, str | None, int]: + """Worker function for parallel multi-grid execution. + + Runs a single grid cell in a child process. Accepts and returns only + serialisable types (no Pydantic models or DataFrames) so it works with + multiprocessing. + + Parameters + ---------- + args : tuple + (config_json, forcing_flat, len_forcing, grid_id) + + Returns + ------- + tuple + (grid_id, output_flat, state_json, len_sim) + """ + config_json, forcing_flat, len_forcing, grid_id = args + rust_module = _check_rust_available() + output_flat, state_json, len_sim = rust_module.run_suews( + config_json, + forcing_flat, + len_forcing, + ) + return grid_id, output_flat, state_json, len_sim + + def run_suews_rust_multi( config: SUEWSConfig, df_forcing: pd.DataFrame, + serial_mode: bool = False, + max_workers: int | None = None, ) -> tuple[pd.DataFrame, dict[int, str] | None]: """Run SUEWS via Rust bridge for all sites in configuration. @@ -322,6 +351,11 @@ def run_suews_rust_multi( block, base config dict) is prepared once to avoid redundant deep copies and YAML serialisation. + When *serial_mode* is False and there are multiple sites, grids are + run in parallel using ``multiprocessing.Pool`` with the ``spawn`` + context (safe for Fortran SAVE variables — each process gets its + own address space). + Returns ``(df_output, dict_state_json)`` where *dict_state_json* maps each grid ID to its post-simulation state JSON string. 
""" @@ -339,52 +373,54 @@ def run_suews_rust_multi( raise ValueError("forcing data is empty") # --- Prepare shared data once --- - # Serialise full config to dict (expensive Pydantic step — do once) config_dict = config.model_dump(exclude_none=True, mode="json") - # Pre-serialise each site dict for fast per-grid patching list_site_dict = [ s.model_dump(exclude_none=True, mode="json") for s in sites ] - # Prepare forcing block once (identical for all grids) forcing_block = _prepare_forcing_block(df_forcing) forcing_flat = forcing_block.ravel(order="C").tolist() len_forcing = len(df_forcing) - list_df_output = [] - dict_state_json: dict[int, str] = {} - + # Pre-serialise per-grid config JSON strings + list_grid_args = [] for idx, site_dict in enumerate(list_site_dict): grid_id = _normalise_grid_id(sites[idx].gridiv) - logger_supy.debug( - "Rust backend: running site %d/%d (gridiv=%d)", - idx + 1, + config_dict["sites"] = [site_dict] + config_json = json.dumps(config_dict) + list_grid_args.append((config_json, forcing_flat, len_forcing, grid_id)) + + # --- Execute grids (serial or parallel) --- + use_parallel = not serial_mode and len(sites) > 1 + + if use_parallel: + import multiprocessing as mp + + ctx = mp.get_context("spawn") + n_workers = max_workers or min(len(sites), mp.cpu_count()) + logger_supy.info( + "Running %d grids in parallel (%d workers)", len(sites), - grid_id, + n_workers, ) + with ctx.Pool(processes=n_workers) as pool: + results = pool.map(_run_single_grid_worker, list_grid_args) + else: + results = [_run_single_grid_worker(a) for a in list_grid_args] - # Patch only the sites list — no deep copy of entire config - config_dict["sites"] = [site_dict] - # JSON is ~30x faster than YAML to serialise and valid JSON parses - # as valid YAML (serde_yaml accepts JSON syntax). 
- config_yaml = json.dumps(config_dict) - - output_flat, state_json, len_sim = rust_module.run_suews( - config_yaml, - forcing_flat, - len_forcing, - ) + # --- Collect results --- + list_df_output = [] + dict_state_json: dict[int, str] = {} + for grid_id, output_flat, state_json, len_sim in results: if len_sim != len_forcing: raise RuntimeError( f"Rust backend length mismatch: forcing={len_forcing}, output={len_sim}" ) - df_output = _parse_output_block(output_flat, len_sim, grid_id) list_df_output.append(df_output) if state_json is not None: dict_state_json[grid_id] = state_json - # Concatenate all grids -- each df already has (grid, datetime) index df_output_all = pd.concat(list_df_output).sort_index() return df_output_all, dict_state_json or None @@ -431,6 +467,7 @@ def run_suews_rust_chunked( config: SUEWSConfig, df_forcing: pd.DataFrame, chunk_day: int = 366, + serial_mode: bool = False, ) -> tuple[pd.DataFrame, dict[int, str] | None]: """Run SUEWS via Rust bridge with multi-chunk state chaining. 
@@ -447,7 +484,9 @@ def run_suews_rust_chunked( n_chunk = len(grp_forcing_chunk) if n_chunk <= 1: - return run_suews_rust_multi(config, df_forcing) + return run_suews_rust_multi( + config, df_forcing, serial_mode=serial_mode + ) logger_supy.info( "Rust backend: forcing split into %d chunks of <= %d days.", diff --git a/src/supy/_supy_module.py b/src/supy/_supy_module.py index 68e969306..0eb10ec65 100644 --- a/src/supy/_supy_module.py +++ b/src/supy/_supy_module.py @@ -713,7 +713,9 @@ def _run_supy( ) # Run via Rust bridge - df_output, _ = run_suews_rust_chunked(config, df_forcing_processed, chunk_day) + df_output, _ = run_suews_rust_chunked( + config, df_forcing_processed, chunk_day, serial_mode=serial_mode + ) # Build state_final: copy initial state + version metadata df_state_final = df_state_init.copy() From 8198a9e23ef273655174f6cbc8c9de7feacbbdd5 Mon Sep 17 00:00:00 2001 From: Ting Sun Date: Fri, 3 Apr 2026 21:00:22 +0100 Subject: [PATCH 4/7] Fix: cap warning log to prevent unbounded allocation in long runs The thread-safe ErrorHint refactor (e5604288c) routed warnings through modState%errorstate%report(), which appends to a dynamically-growing array. Over a year-long simulation with frequent boundary-condition warnings, this caused unbounded memory growth and allocation overhead, timing out the Windows CI UMEP build. Cap non-fatal entries at 512; fatal entries always stored. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- src/suews/src/suews_ctrl_type.f95 | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/suews/src/suews_ctrl_type.f95 b/src/suews/src/suews_ctrl_type.f95 index 0ec37d76d..7d8b6f2e6 100644 --- a/src/suews/src/suews_ctrl_type.f95 +++ b/src/suews/src/suews_ctrl_type.f95 @@ -352,13 +352,17 @@ FUNCTION has_error_state(self) RESULT(has_err) END FUNCTION has_error_state SUBROUTINE report_error_impl(self, message, location, is_fatal, timer) - !> Report an error/warning to the error log + !> Report an error/warning to the error log. + !> Non-fatal warnings are capped at MAX_WARNING_LOG entries to prevent + !> unbounded memory growth in long simulations. Fatal entries are + !> always stored regardless of the cap. CLASS(error_state), INTENT(INOUT) :: self CHARACTER(LEN=*), INTENT(IN) :: message CHARACTER(LEN=*), INTENT(IN) :: location LOGICAL, INTENT(IN), OPTIONAL :: is_fatal TYPE(SUEWS_TIMER), INTENT(IN), OPTIONAL :: timer + INTEGER, PARAMETER :: MAX_WARNING_LOG = 512 TYPE(error_entry), ALLOCATABLE :: temp(:) TYPE(SUEWS_TIMER) :: timer_use INTEGER :: new_size @@ -367,6 +371,15 @@ SUBROUTINE report_error_impl(self, message, location, is_fatal, timer) fatal = .FALSE. IF (PRESENT(is_fatal)) fatal = is_fatal + IF (fatal) THEN + self%has_fatal = .TRUE. + self%flag = .TRUE. + self%message = message + END IF + + ! Cap non-fatal entries to avoid unbounded allocation in year-long runs + IF (.NOT. fatal .AND. self%count >= MAX_WARNING_LOG) RETURN + ! Use provided timer or default to zeros IF (PRESENT(timer)) THEN timer_use = timer @@ -393,12 +406,6 @@ SUBROUTINE report_error_impl(self, message, location, is_fatal, timer) self%log(self%count)%message = message self%log(self%count)%location = location self%log(self%count)%is_fatal = fatal - - IF (fatal) THEN - self%has_fatal = .TRUE. - self%flag = .TRUE. 
- self%message = message - END IF END SUBROUTINE report_error_impl SUBROUTINE clear_error_log(self) From 3ba7190766a9780100b42204b1df41e50193a5c0 Mon Sep 17 00:00:00 2001 From: Ting Sun Date: Fri, 3 Apr 2026 21:01:38 +0100 Subject: [PATCH 5/7] Chore: default profiling script to 4 grids M1 Max has 10 cores; 4 grids is sufficient for quick iteration. Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/profile_multi_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/profile_multi_grid.py b/scripts/profile_multi_grid.py index 6f3a0830d..41986d060 100644 --- a/scripts/profile_multi_grid.py +++ b/scripts/profile_multi_grid.py @@ -84,7 +84,7 @@ def run_with_profile(n_grids: int): def main(): parser = argparse.ArgumentParser(description="Profile SUEWS multi-grid execution") - parser.add_argument("--grids", type=int, default=10, help="Number of grid cells") + parser.add_argument("--grids", type=int, default=4, help="Number of grid cells") parser.add_argument("--profile", action="store_true", help="Run cProfile") args = parser.parse_args() From 7677501f886f4f91b15070b23fdbd362854a3c2d Mon Sep 17 00:00:00 2001 From: Ting Sun Date: Fri, 3 Apr 2026 21:22:41 +0100 Subject: [PATCH 6/7] =?UTF-8?q?Feat:=20Rayon=20thread-pool=20parallelism?= =?UTF-8?q?=20for=20multi-grid=20runs=20=E2=80=94=201.76x=20speedup?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add run_suews_multi Rust function that uses Rayon par_iter to execute grid cells concurrently in shared memory (no IPC serialisation overhead). 
Changes: - Add rayon dependency to suews_bridge Cargo.toml - Add run_suews_multi PyO3 function: takes list of config JSONs + shared forcing, returns results from all grids in parallel - Add -frecursive to gfortran flags (Makefile.gfortran + build.rs) so concurrent Fortran calls each get their own stack frame - Python auto-detects run_suews_multi and uses it when serial_mode=False Benchmark (4 grids x 17520 timesteps, full year, M1 Max): Serial: 59.0s (14.75s/grid, 7146 grid-timesteps/s) Rayon: 33.6s (8.39s/grid, 12560 grid-timesteps/s) Speedup: 1.76x Co-Authored-By: Claude Opus 4.6 (1M context) --- src/suews/Makefile.gfortran | 4 +++- src/suews_bridge/Cargo.toml | 1 + src/suews_bridge/build.rs | 4 ++++ src/suews_bridge/src/lib.rs | 38 ++++++++++++++++++++++++++++++ src/supy/_run_rust.py | 46 +++++++++++++++++++++++++------------ 5 files changed, 77 insertions(+), 16 deletions(-) diff --git a/src/suews/Makefile.gfortran b/src/suews/Makefile.gfortran index 61299b299..98569d567 100644 --- a/src/suews/Makefile.gfortran +++ b/src/suews/Makefile.gfortran @@ -10,7 +10,9 @@ CPPFLAGS = -cpp # Basic flags such as where to write module files, and an instruction # to read Fortran unformatted data files as big endian # Also allow unlimited line length to avoid truncation errors -BASICFLAGS = -J./mod -fconvert=big-endian -ffree-line-length-none +# -frecursive: each subroutine call gets its own stack frame, enabling +# thread-safe execution for Rayon-based multi-grid parallelism. 
+BASICFLAGS = -J./mod -fconvert=big-endian -ffree-line-length-none -frecursive # OpenMP flag OMPFLAG = -fopenmp diff --git a/src/suews_bridge/Cargo.toml b/src/suews_bridge/Cargo.toml index 9b5c9d8d7..dd1b3d24b 100644 --- a/src/suews_bridge/Cargo.toml +++ b/src/suews_bridge/Cargo.toml @@ -23,6 +23,7 @@ arrow-output = ["dep:arrow"] [dependencies] clap = { version = "4.5", features = ["derive"] } +rayon = "1.10" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" serde_yaml = "0.9" diff --git a/src/suews_bridge/build.rs b/src/suews_bridge/build.rs index e102b4ad2..1ca225766 100644 --- a/src/suews_bridge/build.rs +++ b/src/suews_bridge/build.rs @@ -149,6 +149,10 @@ fn main() { "-O2", "-fPIC", "-ffree-line-length-none", + // Each subroutine call gets its own stack frame, enabling + // concurrent calls from Rayon threads without "Recursive call + // to nonrecursive procedure" errors. + "-frecursive", // Initialise all local variables to zero/false: prevents segfaults // from uninitialised derived-type descriptors under gfortran 14+ // which uses a different stack layout than gfortran 10. diff --git a/src/suews_bridge/src/lib.rs b/src/suews_bridge/src/lib.rs index d901965aa..7bf6fb42d 100644 --- a/src/suews_bridge/src/lib.rs +++ b/src/suews_bridge/src/lib.rs @@ -4099,6 +4099,42 @@ mod python_bindings { Ok((output_block, state_json_out, actual_len)) } + /// Run multiple grid cells in parallel using Rayon thread pool. + /// + /// Each element of `config_jsons` is a JSON config string for one grid. + /// `forcing_block` and `len_sim` are shared across all grids. + /// Returns a list of `(grid_index, output_block, state_json, len_sim)`. 
+ #[cfg(feature = "physics")] + #[pyfunction(name = "run_suews_multi")] + fn run_suews_multi_py( + config_jsons: Vec<String>, + forcing_block: Vec<f64>, + len_sim: usize, + ) -> PyResult<Vec<(usize, Vec<f64>, String, usize)>> { + use rayon::prelude::*; + + let results: Vec<Result<(usize, Vec<f64>, String, usize), _>> = config_jsons + .par_iter() + .enumerate() + .map(|(idx, config_json)| { + // Each thread gets its own copy of forcing (Fortran mutates it) + let forcing_copy = forcing_block.clone(); + let (output_block, state, actual_len) = + run_from_config_str_and_forcing(config_json, forcing_copy, len_sim) + .map_err(|e| e.to_string())?; + let state_json = serde_json::to_string(&suews_state_to_nested_payload(&state)) + .map_err(|e| e.to_string())?; + Ok((idx, output_block, state_json, actual_len)) + }) + .collect(); + + // Convert errors to PyResult + results + .into_iter() + .collect::<Result<Vec<_>, String>>() + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e)) + } + + /// Return the output group layout as a list of (name, ncols) tuples. + /// Order matches the concatenated flat buffer from Fortran. 
#[cfg(feature = "physics")] @@ -5336,6 +5372,8 @@ mod python_bindings { #[cfg(feature = "physics")] m.add_function(wrap_pyfunction!(run_suews_with_state_py, m)?)?; #[cfg(feature = "physics")] + m.add_function(wrap_pyfunction!(run_suews_multi_py, m)?)?; + #[cfg(feature = "physics")] m.add_function(wrap_pyfunction!(output_group_layout_py, m)?)?; #[cfg(feature = "physics")] m.add_function(wrap_pyfunction!(output_group_ncolumns_py, m)?)?; diff --git a/src/supy/_run_rust.py b/src/supy/_run_rust.py index bc3397aa2..0c0b4ddc4 100644 --- a/src/supy/_run_rust.py +++ b/src/supy/_run_rust.py @@ -382,30 +382,46 @@ def run_suews_rust_multi( len_forcing = len(df_forcing) # Pre-serialise per-grid config JSON strings - list_grid_args = [] + list_grid_ids = [] + list_config_jsons = [] for idx, site_dict in enumerate(list_site_dict): grid_id = _normalise_grid_id(sites[idx].gridiv) + list_grid_ids.append(grid_id) config_dict["sites"] = [site_dict] - config_json = json.dumps(config_dict) - list_grid_args.append((config_json, forcing_flat, len_forcing, grid_id)) + list_config_jsons.append(json.dumps(config_dict)) - # --- Execute grids (serial or parallel) --- - use_parallel = not serial_mode and len(sites) > 1 + # --- Execute grids --- + # Use Rust Rayon parallelism (shared memory, no IPC overhead) when + # multiple grids are present and serial_mode is not forced. 
+ use_rayon = not serial_mode and len(sites) > 1 + has_rayon = hasattr(rust_module, "run_suews_multi") - if use_parallel: - import multiprocessing as mp - - ctx = mp.get_context("spawn") - n_workers = max_workers or min(len(sites), mp.cpu_count()) + if use_rayon and has_rayon: logger_supy.info( - "Running %d grids in parallel (%d workers)", + "Running %d grids in parallel (Rust/Rayon)", len(sites), - n_workers, ) - with ctx.Pool(processes=n_workers) as pool: - results = pool.map(_run_single_grid_worker, list_grid_args) + raw_results = rust_module.run_suews_multi( + list_config_jsons, + forcing_flat, + len_forcing, + ) + # raw_results: list of (grid_index, output_flat, state_json, len_sim) + # Sort by original index to preserve grid ordering + raw_results.sort(key=lambda r: r[0]) + results = [ + (list_grid_ids[idx], output_flat, state_json, len_sim) + for idx, output_flat, state_json, len_sim in raw_results + ] else: - results = [_run_single_grid_worker(a) for a in list_grid_args] + results = [] + for idx, config_json in enumerate(list_config_jsons): + output_flat, state_json, len_sim = rust_module.run_suews( + config_json, + forcing_flat, + len_forcing, + ) + results.append((list_grid_ids[idx], output_flat, state_json, len_sim)) # --- Collect results --- list_df_output = [] From 3adf817d631f4033867d522d7d067b7cd7ba61c9 Mon Sep 17 00:00:00 2001 From: Ting Sun Date: Fri, 3 Apr 2026 21:36:58 +0100 Subject: [PATCH 7/7] Chore: update Cargo.lock for rayon dependency Co-Authored-By: Claude Opus 4.6 (1M context) --- src/suews_bridge/Cargo.lock | 52 +++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/src/suews_bridge/Cargo.lock b/src/suews_bridge/Cargo.lock index a71c40f89..1a3196ae3 100644 --- a/src/suews_bridge/Cargo.lock +++ b/src/suews_bridge/Cargo.lock @@ -390,12 +390,43 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" 
+[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crunchy" version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + [[package]] name = "equivalent" version = "1.0.2" @@ -811,6 +842,26 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "regex" version = "1.12.3" @@ -943,6 +994,7 @@ dependencies = [ "clap", "paste", "pyo3", + "rayon", "serde", "serde_json", 
"serde_yaml",