WIP: hybrid ODEs #260

Open
BernhardAhrens wants to merge 6 commits into main from ba/ode_signature

Conversation

@BernhardAhrens
Collaborator

No description provided.


@gemini-code-assist (Bot) left a comment

Code Review

This pull request adds the ODEHybridModel to couple LSTMs with ODE step functions, supporting time-varying, static, and global parameters. It includes a tutorial, example signatures, and loss computation updates. Feedback identifies performance issues in the forward pass, such as inefficient array concatenation and NamedTuple construction in loops, and potential GPU bottlenecks. Suggestions were also made to allow an explicit time step for ODE integration and to fix logical errors in the example signature file.

Comment on lines +285 to +295
```julia
if sn in m.static_nn_param_names
    C = static_kw[sn]  # already (1, B) from static NN
elseif sn in m.global_param_names
    C₀_val = scale_single_param(sn, ps[sn], m.parameters)
    C = C₀_val .+ zeros(ET, m.n_state, B)
elseif sn in m.fixed_param_names
    C₀_val = st.fixed[sn]
    C = C₀_val .+ zeros(ET, m.n_state, B)
else
    C = zeros(ET, m.n_state, B)
end
```

Severity: high

Using `zeros(ET, ...)` creates a CPU-based array. If the model is running on a GPU (i.e., `pred_3d` is a `CuArray`), this will cause a performance bottleneck due to CPU-GPU transfers or a runtime error during the `vcat` operation on line 328. It is better to use `similar` to ensure the initial state is created on the same device as the input data.

```julia
if sn in m.static_nn_param_names
    C = static_kw[sn]  # already (1, B) from static NN
elseif sn in m.global_param_names
    C₀_val = scale_single_param(sn, ps[sn], m.parameters)
    C = similar(pred_3d, ET, m.n_state, B)
    fill!(C, zero(ET))
    C = C .+ C₀_val
elseif sn in m.fixed_param_names
    C₀_val = st.fixed[sn]
    C = similar(pred_3d, ET, m.n_state, B)
    fill!(C, zero(ET))
    C = C .+ C₀_val
else
    C = similar(pred_3d, ET, m.n_state, B)
    fill!(C, zero(ET))
end
```

```julia
result_1, nn_kw_1, st_proj = _ode_inner_step(
    m, h, C, forc_3d, 1, ps, st_proj, global_kw, fixed_kw, static_non_state_kw
)
C = C .+ result_1[m.deriv_name]
```

Severity: medium

The ODE integration uses a fixed Euler step with an implicit $\Delta t = 1$. This assumes that the model's time units perfectly match the data frequency. For physical consistency and flexibility across different datasets, it would be better to allow an explicit `dt` parameter or derive it from the input data.
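A hedged sketch of what an explicit step size could look like (`euler_step` and the `dt` argument are illustrative assumptions, not part of this PR):

```julia
# Sketch: make the Euler step size explicit instead of the implicit Δt = 1.
# `dt` could be a model field or be derived from the data's time axis.
euler_step(C, dCdt_val, dt) = C .+ dt .* dCdt_val

# e.g. in the time loop (illustrative):
# C = euler_step(C, result_t[m.deriv_name], dt)
```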

Comment thread on src/models/ODEHybridModel.jl (Outdated)
Comment on lines +351 to +356
```julia
tgt_trajs = NamedTuple{Tuple(m.targets)}(
    Tuple(vcat(tgt_trajs[tgt], result_t[tgt]) for tgt in m.targets)
)
nn_trajs = NamedTuple{Tuple(m.lstm_param_names)}(
    Tuple(vcat(nn_trajs[n], nn_kw_t[n]) for n in m.lstm_param_names)
)
```

Severity: medium

Accumulating trajectories by repeatedly calling `vcat` and reconstructing `NamedTuple`s inside a loop is highly inefficient in Julia. It leads to $O(T^2)$ allocations and can significantly slow down both the forward pass and automatic differentiation (AD). A more performant approach is to collect the results in a `Vector` and then use `reduce(vcat, ...)` or `stack` at the end of the loop.
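A minimal sketch of that pattern, assuming per-step results come as (1, B) matrices (`fake_step`, `targets`, and the shapes here are illustrative stand-ins for the model's real per-step computation):

```julia
# Sketch: collect each step's (1, B) slice in a Vector and concatenate once
# after the loop, turning O(T^2) allocation into O(T).
targets = (:NEE, :RECO)
T, B = 4, 3
fake_step(t) = NamedTuple{targets}((fill(Float32(t), 1, B), fill(Float32(-t), 1, B)))

buffers = NamedTuple{targets}(Tuple(Matrix{Float32}[] for _ in targets))
for t in 1:T
    result_t = fake_step(t)
    for tgt in targets
        push!(buffers[tgt], result_t[tgt])
    end
end
# One concatenation per target instead of one per time step.
# (Under AD such as Zygote, a non-mutating variant via map may be preferable.)
tgt_trajs = NamedTuple{targets}(Tuple(reduce(vcat, buffers[tgt]) for tgt in targets))
```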


```julia
if forc_3d !== nothing
    forc_t = forc_3d[:, t, :]
    forc_kw = (; zip(m.forcing, [forc_t[i:i, :] for i in 1:length(m.forcing)])...)
```

Severity: medium

Constructing a `NamedTuple` using `zip` and splatting inside the time loop is computationally expensive due to repeated allocations and potential dynamic dispatch. Since the forcing names are known at construction time, consider pre-slicing the forcing data or using a more efficient way to pass these values to the mechanistic model.
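One way this could look, hoisting the name tuple out of the loop so the `NamedTuple` type is built once (`forcing_names`, `slice_forcings`, and the data shapes are illustrative assumptions):

```julia
# Sketch: fix the forcing names once, then build the NamedTuple from a
# pre-known name tuple instead of per-step zip/splat.
forcing_names = (:TA, :SW_IN)   # illustrative; the real names come from m.forcing
F, T, B = length(forcing_names), 5, 2
forc_3d = rand(Float32, F, T, B)

slice_forcings(forc_t) =
    NamedTuple{forcing_names}(ntuple(i -> forc_t[i:i, :], length(forcing_names)))

for t in 1:T
    forc_kw = slice_forcings(forc_3d[:, t, :])
    # ... pass forc_kw on to the mechanistic model ...
end
```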




```julia
dCdt(;C, RECO, GPP) = RECO(;C, Rb, Q10, TA) .- GPP(;)
```

Severity: medium

This line appears to be logically incorrect. `RECO` and `GPP` are passed as arguments (which are likely arrays based on the step functions above), but they are being called as functions here. Additionally, `Rb`, `Q10`, and `TA` are not defined in this scope. If this file is intended to be a functional part of the repository, it needs significant cleanup.
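If `RECO` and `GPP` are indeed precomputed flux arrays, one possible cleanup would be a plain elementwise difference (a sketch under that assumption, not the intended repository code):

```julia
# Sketch: treat RECO and GPP as precomputed flux arrays rather than
# callables; the derivative of C is then an elementwise difference.
dCdt(; C, RECO, GPP) = RECO .- GPP
```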
