-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathstatement.py
More file actions
50 lines (44 loc) · 1.6 KB
/
statement.py
File metadata and controls
50 lines (44 loc) · 1.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from pydantic import (
BaseModel,
FilePath,
field_validator,
AnyUrl,
model_validator,
PlainSerializer
)
from pydantic.types import StringConstraints
from typing_extensions import Annotated
from typing import List, Any, Optional, TypeAlias
# A `str` that must contain at least one character. Only `min_length` is
# constrained (no whitespace stripping is configured), so a string made only
# of spaces presumably still validates — confirm if that matters to callers.
NonEmptyStr: TypeAlias = Annotated[str, StringConstraints(min_length=1)]
class CommandOutput(BaseModel):
    # One command whose output is relevant context for the model.
    # `invocation` is presumably the command line that was run; `stdout` and
    # `stderr` hold its two output streams. A missing or empty stream is
    # replaced by the explicit '<|no output|>' marker by `_empty_to_default`.
    invocation: NonEmptyStr
    stdout: NonEmptyStr = '<|no output|>'
    stderr: NonEmptyStr = '<|no output|>'

    @field_validator('stdout', 'stderr', mode='before')
    @classmethod
    def _empty_to_default(cls, field_value: Any) -> Any:
        '''Replace a missing or empty output stream with an explicit marker.

        If a command had no stdout, no stderr, or neither, the model should
        be told so. Otherwise it may believe the output was purposely
        withheld, or that the RAG tool messed up somewhere.

        Runs in 'before' mode, i.e. on the raw input prior to type
        validation, so ``field_value`` may be of any type. Any falsy value
        (``None``, ``''``, ...) becomes ``'<|no output|>'``; everything else
        is returned unchanged for normal ``NonEmptyStr`` validation.
        '''
        # `not field_value` already covers the empty string as well as None,
        # so no separate `== ""` comparison is needed.
        if not field_value:
            return '<|no output|>'
        return field_value
class Model(BaseModel):
    # Which LLM to query and where to reach it.
    # `name` is the model identifier, `host` the endpoint URL (serialized back
    # to a plain string via `str(url)`), and `api_key` an optional credential
    # that `validate_api_key` makes mandatory for names containing "gpt".
    name: NonEmptyStr
    host: Annotated[AnyUrl, PlainSerializer(lambda url: str(url))]
    api_key: Optional[NonEmptyStr] = None

    @model_validator(mode='before')
    @classmethod
    def validate_api_key(cls, values: Any) -> Any:
        '''Require an ``api_key`` for OpenAI ("gpt") models.

        Runs on the raw input before field validation. When the input is a
        dict whose ``name`` is a string containing ``"gpt"``, it must carry
        an ``api_key`` that is not ``None``.

        Non-dict inputs, and dicts whose ``name`` is missing or not a string,
        are returned unchanged. The regular field validators then report
        those problems as ordinary validation errors, instead of this hook
        crashing with a raw ``KeyError`` or ``TypeError``.

        Raises:
            ValueError: if a "gpt" model has no ``api_key`` or it is None.
        '''
        if not isinstance(values, dict):
            return values
        name = values.get("name")
        # An explicit `api_key=None` is treated the same as an absent key.
        needs_key = isinstance(name, str) and "gpt" in name
        if needs_key and values.get("api_key") is None:
            raise ValueError("Model must include 'api_key' if using openai.")
        return values
class Statement(BaseModel):
    # Everything needed to pose one problem to the model:
    #   problem_description, goal -- non-empty free-text framing of the task.
    #   relevant_command_output   -- command runs to include as context.
    #   relevant_files            -- paths validated as pydantic `FilePath`s.
    #   use_rag                   -- presumably toggles the RAG tool; on by default.
    #   model                     -- which LLM to query and how to reach it.
    # The `[]` defaults are safe: pydantic copies non-hashable defaults for
    # each new instance rather than sharing one list.
    problem_description: NonEmptyStr
    goal: NonEmptyStr
    relevant_command_output: list[CommandOutput] = []
    relevant_files: list[FilePath] = []
    use_rag: bool = True
    model: Model