diff --git a/chatkit/server.py b/chatkit/server.py index 01372cb..65645f6 100644 --- a/chatkit/server.py +++ b/chatkit/server.py @@ -435,6 +435,7 @@ async def _process_non_streaming( attachment = await attachment_store.create_attachment( request.params, context ) + await self.store.save_attachment(attachment, context=context) return self._serialize(attachment) case AttachmentsDeleteReq(): attachment_store = self._get_attachment_store() diff --git a/chatkit/types.py b/chatkit/types.py index a0ceeea..cbb5900 100644 --- a/chatkit/types.py +++ b/chatkit/types.py @@ -726,16 +726,27 @@ class ToolChoice(BaseModel): id: str +class AttachmentUploadDescriptor(BaseModel): + """Two-phase upload instructions.""" + + url: AnyUrl + method: Literal["POST", "PUT"] + """The HTTP method to use when uploading the file for two-phase upload.""" + headers: dict[str, str] = Field(default_factory=dict) + """Optional headers to include in the upload request.""" + + class AttachmentBase(BaseModel): """Base metadata shared by all attachments.""" id: str name: str mime_type: str - upload_url: AnyUrl | None = None + upload_descriptor: AttachmentUploadDescriptor | None = None """ - The URL to upload the file, used for two-phase upload. - Should be set to None after upload is complete or when using direct upload where uploading happens when creating the attachment object. + Two-phase upload instructions. + Should be set to None after upload is complete or when using direct upload + where uploading happens when creating the attachment object. """ diff --git a/tests/test_chatkit_server.py b/tests/test_chatkit_server.py index 41be193..ff7177c 100644 --- a/tests/test_chatkit_server.py +++ b/tests/test_chatkit_server.py @@ -31,6 +31,7 @@ AttachmentDeleteParams, AttachmentsCreateReq, AttachmentsDeleteReq, + AttachmentUploadDescriptor, ClientToolCallItem, FeedbackKind, FileAttachment, @@ -107,14 +108,22 @@ async def create_attachment( mime_type=input.mime_type, name=input.name, preview_url=AnyUrl(f"https://example.com/{id}/preview"), - upload_url=AnyUrl(f"https://example.com/{id}/upload"), + upload_descriptor=AttachmentUploadDescriptor( + url=AnyUrl(f"https://example.com/{id}/upload"), + method="PUT", + headers={"X-My-Header": "my-value"}, + ), ) else: attachment = FileAttachment( id=id, mime_type=input.mime_type, name=input.name, - upload_url=AnyUrl(f"https://example.com/{id}/upload"), + upload_descriptor=AttachmentUploadDescriptor( + url=AnyUrl(f"https://example.com/{id}/upload"), + method="PUT", + headers={"X-My-Header": "my-value"}, + ), ) self.files[attachment.id] = attachment return attachment @@ -679,6 +688,7 @@ async def responder( assert events[1].type == "thread.item.done" assert events[1].item.type == "assistant_message" + async def test_respond_with_tool_status(): async def responder( thread: ThreadMetadata, input: UserMessageItem | None, context: Any @@ -1019,9 +1029,12 @@ async def test_create_file(): assert attachment.mime_type == file_content_type assert attachment.name == file_name assert attachment.type == "file" - assert attachment.upload_url == AnyUrl( + assert attachment.upload_descriptor is not None + assert attachment.upload_descriptor.url == AnyUrl( f"https://example.com/{attachment.id}/upload" ) + assert attachment.upload_descriptor.method == "PUT" + assert attachment.upload_descriptor.headers == {"X-My-Header": "my-value"} assert attachment.id in store.files @@ -1049,9 +1062,12 @@ async def test_create_image_file(): assert attachment.preview_url == AnyUrl( f"https://example.com/{attachment.id}/preview" ) - assert attachment.upload_url == AnyUrl( + assert attachment.upload_descriptor is not None + assert attachment.upload_descriptor.url == AnyUrl( f"https://example.com/{attachment.id}/upload" ) + assert attachment.upload_descriptor.method == "PUT" + assert attachment.upload_descriptor.headers == {"X-My-Header": "my-value"} assert attachment.id in store.files