Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 33 additions & 41 deletions src/api/_util/resourcelimit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Iterable, Sequence
from datetime import UTC, datetime

from sqlalchemy import delete, func
from sqlalchemy import delete, func, not_
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.ext.asyncio import AsyncConnection
from sqlmodel import col, select
Expand Down Expand Up @@ -494,17 +494,25 @@ async def get_current_organization_allocations(
*,
exclude_branch_ids: Sequence[Identifier] | None = None,
) -> dict[ResourceType, int]:
result = await session.execute(
select(BranchProvisioning).join(Branch).join(Project).where(Project.organization_id == organization_id)
status_column = col(Branch.status)
branch_id_column = col(BranchProvisioning.branch_id)

stmt = (
select(BranchProvisioning)
.join(Branch)
.join(Project)
.where(
Project.organization_id == organization_id,
not_(status_column.in_([BranchServiceStatus.STOPPED, BranchServiceStatus.DELETING])),
)
)
rows = list(result.scalars().all())
if exclude_branch_ids:
excluded = set(exclude_branch_ids)
rows = [row for row in rows if row.branch_id not in excluded]
stmt = stmt.where(not_(branch_id_column.in_(set(exclude_branch_ids))))

result = await session.execute(stmt)
rows = list(result.scalars().all())
grouped = _group_by_resource_type(rows)
branch_statuses = await _collect_branch_statuses(session, rows)
return _aggregate_group_by_resource_type(grouped, branch_statuses)
return _aggregate_group_by_resource_type(grouped)


async def get_current_project_allocations(
Expand All @@ -513,49 +521,33 @@ async def get_current_project_allocations(
*,
exclude_branch_ids: Sequence[Identifier] | None = None,
) -> dict[ResourceType, int]:
result = await session.execute(select(BranchProvisioning).join(Branch).where(Branch.project_id == project_id))
rows = list(result.scalars().all())
status_column = col(Branch.status)
branch_id_column = col(BranchProvisioning.branch_id)

stmt = (
select(BranchProvisioning)
.join(Branch)
.where(
Branch.project_id == project_id,
not_(status_column.in_([BranchServiceStatus.STOPPED, BranchServiceStatus.DELETING])),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is dangerous (and I guess it was also wrong with the old code). Stopped branches don't count towards CPU / RAM allocation, but I think storage (db, object) still count, even when stopped. Only deleted branches are out of the game. WDYT @boddumanohar?

)
)
if exclude_branch_ids:
excluded = set(exclude_branch_ids)
rows = [row for row in rows if row.branch_id not in excluded]
stmt = stmt.where(not_(branch_id_column.in_(set(exclude_branch_ids))))

result = await session.execute(stmt)
rows = list(result.scalars().all())
grouped = _group_by_resource_type(rows)
branch_statuses = await _collect_branch_statuses(session, rows)
return _aggregate_group_by_resource_type(grouped, branch_statuses)
return _aggregate_group_by_resource_type(grouped)


def _aggregate_group_by_resource_type(
grouped: dict[ResourceType, list[BranchProvisioning]], branch_statuses: dict[Identifier, BranchServiceStatus]
) -> dict[ResourceType, int]:
def _aggregate_group_by_resource_type(grouped: dict[ResourceType, list[BranchProvisioning]]) -> dict[ResourceType, int]:
return {
resource_type: sum(
allocation.amount
for allocation in allocations
if (allocation.branch_id is not None)
and (
branch_statuses.get(allocation.branch_id)
not in {BranchServiceStatus.STOPPED, BranchServiceStatus.DELETING}
)
)
resource_type: sum(allocation.amount for allocation in allocations if allocation.branch_id is not None)
for resource_type, allocations in grouped.items()
}


async def _collect_branch_statuses(
_session: SessionDep, rows: list[BranchProvisioning]
) -> dict[Identifier, BranchServiceStatus]:
branch_ids = {row.branch_id for row in rows if row.branch_id is not None}
if not branch_ids:
return {}

from ..organization.project import branch as branch_module

statuses: dict[Identifier, BranchServiceStatus] = {}
for branch_id in branch_ids:
statuses[branch_id] = await branch_module.refresh_branch_status(branch_id)
return statuses


def _group_by_resource_type(allocations: list[BranchProvisioning]) -> dict[ResourceType, list[BranchProvisioning]]:
result: dict[ResourceType, list[BranchProvisioning]] = {}
for allocation in allocations:
Expand Down