WIP: hybrid ODEs #260
Conversation
…into ba/ode_signature
…to ba/ode_signature
Code Review
This pull request adds the ODEHybridModel to couple LSTMs with ODE step functions, supporting time-varying, static, and global parameters. It includes a tutorial, example signatures, and loss computation updates. Feedback identifies performance issues in the forward pass, such as inefficient array concatenation and NamedTuple construction in loops, and potential GPU bottlenecks. Suggestions were also made to allow an explicit time step for ODE integration and to fix logical errors in the example signature file.
    if sn in m.static_nn_param_names
        C = static_kw[sn]  # already (1, B) from static NN
    elseif sn in m.global_param_names
        C₀_val = scale_single_param(sn, ps[sn], m.parameters)
        C = C₀_val .+ zeros(ET, m.n_state, B)
    elseif sn in m.fixed_param_names
        C₀_val = st.fixed[sn]
        C = C₀_val .+ zeros(ET, m.n_state, B)
    else
        C = zeros(ET, m.n_state, B)
    end
Using zeros(ET, ...) creates a CPU-based array. If the model is running on a GPU (i.e., pred_3d is a CuArray), this will cause a performance bottleneck due to CPU-GPU transfers or a runtime error during the vcat operation on line 328. It is better to use similar to ensure the initial state is created on the same device as the input data.
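The device-placement point can be illustrated on the CPU. Here `pred_3d` is a stand-in `Array`; on a GPU it would be a `CuArray`, and `similar` would allocate the new buffer on the same device (names and shapes below are illustrative, not the repository's):

```julia
# Stand-in for pred_3d; on GPU this would be a CuArray, and `similar`
# would allocate C on the same device automatically.
pred_3d = rand(Float32, 3, 10, 4)
ET, n_state, B = Float32, 3, 4
C₀_val = 1.5f0

# Device-aware initialization: zeroed buffer matching pred_3d's array type,
# then broadcast-add the initial value.
C = fill!(similar(pred_3d, ET, n_state, B), zero(ET)) .+ C₀_val
```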
    if sn in m.static_nn_param_names
        C = static_kw[sn]  # already (1, B) from static NN
    elseif sn in m.global_param_names
        C₀_val = scale_single_param(sn, ps[sn], m.parameters)
        C = similar(pred_3d, ET, m.n_state, B)
        fill!(C, zero(ET))
        C = C .+ C₀_val
    elseif sn in m.fixed_param_names
        C₀_val = st.fixed[sn]
        C = similar(pred_3d, ET, m.n_state, B)
        fill!(C, zero(ET))
        C = C .+ C₀_val
    else
        C = similar(pred_3d, ET, m.n_state, B)
        fill!(C, zero(ET))
    end

    result_1, nn_kw_1, st_proj = _ode_inner_step(
        m, h, C, forc_3d, 1, ps, st_proj, global_kw, fixed_kw, static_non_state_kw
    )
    C = C .+ result_1[m.deriv_name]
The ODE integration uses a fixed Euler step with an implicit dt of 1. Consider allowing an explicit time-step parameter for the ODE integration, or deriving dt from the input data.
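A minimal sketch of what an explicit time step could look like. `euler_step`, `decay`, and the `dt` keyword are illustrative, not the repository's API:

```julia
# Hypothetical: one explicit Euler step with a configurable dt,
# instead of the implicit dt = 1 baked into `C = C .+ result[...]`.
function euler_step(dCdt, C, args...; dt = 1.0)
    return C .+ dt .* dCdt(C, args...)
end

# Example: exponential decay dC/dt = -0.5 * C
decay(C) = -0.5 .* C
C0 = [2.0, 4.0]
C1 = euler_step(decay, C0; dt = 0.1)  # [1.9, 3.8]
```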
    tgt_trajs = NamedTuple{Tuple(m.targets)}(
        Tuple(vcat(tgt_trajs[tgt], result_t[tgt]) for tgt in m.targets)
    )
    nn_trajs = NamedTuple{Tuple(m.lstm_param_names)}(
        Tuple(vcat(nn_trajs[n], nn_kw_t[n]) for n in m.lstm_param_names)
    )
Accumulating trajectories by repeatedly calling vcat and reconstructing NamedTuples inside a loop is highly inefficient in Julia: each iteration reallocates and copies the entire accumulated array. Instead, push each step's result into a Vector and call reduce(vcat, ...) or stack once at the end of the loop.
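A minimal sketch of the suggested pattern (the per-step result here is a stand-in for one step's (1, B) output):

```julia
# Instead of growing an array with vcat on every iteration,
# collect per-step results in a Vector and concatenate once.
steps = Vector{Matrix{Float64}}()
for t in 1:5
    result_t = fill(Float64(t), 1, 3)  # stand-in for one step's (1, B) output
    push!(steps, result_t)
end
traj = reduce(vcat, steps)  # (5, 3) matrix, single concatenation
```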
    if forc_3d !== nothing
        forc_t = forc_3d[:, t, :]
        forc_kw = (; zip(m.forcing, [forc_t[i:i, :] for i in 1:length(m.forcing)])...)
Constructing a NamedTuple using zip and splatting inside the time loop is computationally expensive due to repeated allocations and potential dynamic dispatch. Since the forcing names are known at construction time, consider pre-slicing the forcing data or using a more efficient way to pass these values to the mechanistic model.
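One possible pattern, sketched with illustrative names: since the forcing names are fixed when the model is built, the concrete NamedTuple constructor can be created once and reused in the loop, avoiding zip and splatting:

```julia
forcing = (:TA, :SW_IN)       # forcing names, known at model construction
ForcKW = NamedTuple{forcing}  # concrete constructor, built once outside the loop

forc_3d = rand(2, 10, 4)      # stand-in (n_forcing, T, B) forcing data
slice_forcings(t) = ForcKW(ntuple(i -> forc_3d[i:i, t, :], length(forcing)))

forc_kw = slice_forcings(1)   # NamedTuple of (1, B) slices, no zip/splat
```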
    dCdt(;C, RECO, GPP) = RECO(;C, Rb, Q10, TA) .- GPP(;)
This line appears to be logically incorrect. RECO and GPP are passed as arguments (which are likely arrays based on the step functions above), but they are being called as functions here. Additionally, Rb, Q10, and TA are not defined in this scope. If this file is intended to be a functional part of the repository, it needs significant cleanup.
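A hedged sketch of what the corrected line might look like if RECO and GPP are already-evaluated flux arrays, as the review suspects. This is a guess at the intent, not the repository's definition:

```julia
# Assuming RECO and GPP are precomputed flux arrays rather than callables,
# the net pool change is simply their difference. C is kept as a keyword
# for signature parity with the original line; it is unused here.
dCdt(; C, RECO, GPP) = RECO .- GPP

dC = dCdt(; C = [1.0], RECO = [3.0, 5.0], GPP = [1.0, 2.0])
# dC == [2.0, 3.0]
```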
No description provided.