[tinker] Fix single request batching in TinkerEngine #1489
pcmoritz wants to merge 1 commit into NovaSky-AI:main from
Conversation
Code Review
This pull request introduces scheduling barriers to ensure that single requests, such as optimization steps, do not execute before preceding forward or forward-backward passes for the same model. It refactors the identification of destructive barriers into a reusable helper method and adds several regression tests to verify the scheduling logic. Feedback was provided regarding the performance of the logic used to identify blocked passes in find_single_requests, noting that the current implementation could be optimized to avoid potential performance issues as the number of pending requests grows.
```python
if destructive_barriers:
    pending_passes = session.exec(
        select(FutureDB.model_id, FutureDB.request_id)
        .where(
            (FutureDB.request_type == types.RequestType.FORWARD_BACKWARD)
            | (FutureDB.request_type == types.RequestType.FORWARD)
        )
        .where(FutureDB.status == RequestStatus.PENDING)
        .order_by(FutureDB.request_id)
    ).all()
    for model_id, req_id in pending_passes:
        if model_id in destructive_barriers and req_id >= destructive_barriers[model_id]:
            blocked_pass_barriers.setdefault(model_id, req_id)
```
The logic for identifying blocked passes iterates over all pending passes on every call to find_single_requests. This can be optimized by using a dictionary lookup or a more efficient SQL query to avoid O(N*M) complexity, where N is the number of pending passes and M is the number of models.
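A minimal sketch of the reviewer's suggestion, assuming the shapes of `pending_passes` and `destructive_barriers` match the diff above (the function name and sample data below are hypothetical): a single O(N) pass over the pending requests, with one dict lookup per request, recording only the earliest blocked pass per model.

```python
def find_blocked_pass_barriers(pending_passes, destructive_barriers):
    """Return {model_id: earliest pending pass at/after that model's barrier}.

    pending_passes: list of (model_id, request_id), sorted by request_id
    destructive_barriers: {model_id: request_id of the destructive update}
    """
    blocked = {}
    for model_id, req_id in pending_passes:
        barrier = destructive_barriers.get(model_id)  # O(1) dict lookup
        # Record only the first (lowest request_id) blocked pass per model.
        if barrier is not None and req_id >= barrier and model_id not in blocked:
            blocked[model_id] = req_id
    return blocked

# Illustrative data: model "m1" had a destructive update at request 4,
# so its pending pass 5 is blocked while pass 3 and model "m2" are not.
pending = [("m1", 3), ("m1", 5), ("m2", 4)]
barriers = {"m1": 4}
print(find_blocked_pass_barriers(pending, barriers))  # → {'m1': 5}
```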
For each model, we currently make sure not to batch forward/forward-backward requests that come after a destructive update such as optim_step or load_weights. In addition, we also need to make sure not to batch destructive updates that come after forward/forward-backward requests.
E.g. for a sequence like optim1 → fwdbwd2 → optim2, we do not want to process optim2 before fwdbwd2 has run.
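The two-way barrier rule described above can be sketched as follows; this is an illustrative standalone model, not the TinkerEngine implementation, and all names (`can_dispatch`, the request tuples) are hypothetical:

```python
# Destructive updates must wait for earlier pending passes on the same
# model, and passes must wait for earlier pending destructive updates.
DESTRUCTIVE = {"optim_step", "load_weights"}
PASSES = {"forward", "forward_backward"}

def can_dispatch(candidate, pending):
    """Entries are (request_id, model_id, request_type) tuples."""
    req_id, model_id, req_type = candidate
    for other_id, other_model, other_type in pending:
        # Only earlier requests for the same model can block us.
        if other_model != model_id or other_id >= req_id:
            continue
        # optim2 must not run before an earlier pending fwdbwd2.
        if req_type in DESTRUCTIVE and other_type in PASSES:
            return False
        # A pass must not run before an earlier pending destructive update.
        if req_type in PASSES and other_type in DESTRUCTIVE:
            return False
    return True

# The optim1 → fwdbwd2 → optim2 sequence: optim2 is held back
# while fwdbwd2 is still pending.
pending = [(1, "m", "optim_step"), (2, "m", "forward_backward")]
print(can_dispatch((3, "m", "optim_step"), pending))  # → False
```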