Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 48 additions & 9 deletions src/codegen/git/clients/git_repo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class GitRepoClient:
repo_config: RepoConfig
gh_client: GithubClient
_repo: Repository
_supports_draft_prs: bool | None = None

def __init__(self, repo_config: RepoConfig, access_token: str | None = None) -> None:
self.repo_config = repo_config
Expand Down Expand Up @@ -58,6 +59,31 @@ def repo(self) -> Repository:
def default_branch(self) -> str:
return self.repo.default_branch

def accepts_draft_prs(self) -> bool:
"""Determines if a repository supports draft PRs.

This uses a heuristic based on repository visibility and plan features.
Public repositories always support draft PRs.
For private repositories, we use a cached result if available to avoid repeated checks.

Returns:
bool: True if the repository supports draft PRs, False otherwise.
"""
# If we've already checked, return the cached result
if self._supports_draft_prs is not None:
return self._supports_draft_prs

# Public repositories always support draft PRs
if self.repo.visibility == "public":
self._supports_draft_prs = True
return True

# For private repositories, we'll use a conservative approach
# and assume they don't support draft PRs by default
# This can be refined in the future with more specific checks
self._supports_draft_prs = False
return False

####################################################################################################################
# CONTENTS
####################################################################################################################
Expand Down Expand Up @@ -199,19 +225,32 @@ def create_pull(
if base_branch_name is None:
base_branch_name = self.default_branch

# draft PRs are not supported on all private repos
# TODO: check repo plan features instead of this heuristic
if self.repo.visibility == "private":
logger.info(f"Repo {self.repo.name} is private. Disabling draft PRs.")
draft = False
# Determine if we should attempt to create a draft PR
should_try_draft = draft and self.accepts_draft_prs()

try:
pr = self.repo.create_pull(title=title or f"Draft PR for {head_branch_name}", body=body or "", head=head_branch_name, base=base_branch_name, draft=draft)
logger.info(f"Created pull request for head branch: {head_branch_name} at {pr.html_url}")
# NOTE: return a read-only copy to prevent people from editing it
# First attempt to create the PR with the requested draft status
pr = self.repo.create_pull(title=title or f"PR for {head_branch_name}", body=body or "", head=head_branch_name, base=base_branch_name, draft=should_try_draft)
logger.info(f"Created {'draft ' if should_try_draft else ''}pull request for head branch: {head_branch_name} at {pr.html_url}")
# Return a read-only copy to prevent people from editing it
return self.repo.get_pull(pr.number)
except GithubException as ge:
logger.warning(f"Failed to create PR got GithubException\n\t{ge}")
# Check specifically for the "Draft pull requests are not supported" error
if draft and ge.status == 422 and "Draft pull requests are not supported in this repository" in str(ge):
logger.info(f"Draft PRs not supported in repository {self.repo.name}. Trying to create a regular PR instead.")
# Update our cached knowledge about draft PR support
self._supports_draft_prs = False

# Try again with draft=False
try:
pr = self.repo.create_pull(title=title or f"PR for {head_branch_name}", body=body or "", head=head_branch_name, base=base_branch_name, draft=False)
logger.info(f"Created regular pull request for head branch: {head_branch_name} at {pr.html_url}")
# Return a read-only copy
return self.repo.get_pull(pr.number)
except Exception as e:
logger.warning(f"Failed to create regular PR after draft PR failed:\n\t{e}")
else:
logger.warning(f"Failed to create PR got GithubException\n\t{ge}")
except Exception as e:
logger.warning(f"Failed to create PR:\n\t{e}")

Expand Down