Skip to content

Commit 374205d

Browse files
committed
Allow extra headers for specific conversations too
This commit adds support for setting extra HTTP headers for all requests issued as part of a specific conversation. These headers will be sent in addition to any extra headers defined for the backend itself, if any, and will take precedence over them. The bedrock implementation does not support extra headers at all at this point.
1 parent 1ad9e02 commit 374205d

File tree

4 files changed

+57
-13
lines changed

4 files changed

+57
-13
lines changed

libaiac/bedrock/chat.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,6 @@ func (conv *Conversation) Messages() []types.Message {
120120
}
121121
return msgs
122122
}
123+
124+
// AddHeader is a noop for the bedrock implementation
125+
func (conv *Conversation) AddHeader(_ string, _ string) {}

libaiac/ollama/chat.go

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ import (
1111
// Conversation is a struct used to converse with an Ollama chat model. It
1212
// maintains all messages sent/received in order to maintain context.
1313
type Conversation struct {
14-
backend *Ollama
15-
model string
16-
messages []types.Message
14+
backend *Ollama
15+
model string
16+
messages []types.Message
17+
extraHeaders map[string]string
1718
}
1819

1920
type chatResponse struct {
@@ -55,7 +56,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
5556
Content: prompt,
5657
})
5758

58-
err = conv.backend.NewRequest("POST", "/chat").
59+
req := conv.backend.NewRequest("POST", "/chat").
5960
JSONBody(map[string]interface{}{
6061
"model": conv.model,
6162
"messages": conv.messages,
@@ -64,8 +65,13 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
6465
},
6566
"stream": false,
6667
}).
67-
Into(&answer).
68-
RunContext(ctx)
68+
Into(&answer)
69+
70+
for key, val := range conv.extraHeaders {
71+
req.Header(key, val)
72+
}
73+
74+
err = req.RunContext(ctx)
6975
if err != nil {
7076
return res, fmt.Errorf("failed sending prompt: %w", err)
7177
}
@@ -92,3 +98,14 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
9298
func (conv *Conversation) Messages() []types.Message {
9399
return conv.messages
94100
}
101+
102+
// AddHeader adds an extra HTTP header that will be added to every HTTP
103+
// request issued as part of this conversation. Any headers added will be in
104+
// addition to any extra headers defined for the backend itself, and will
105+
// take precedence over them.
106+
func (conv *Conversation) AddHeader(key, val string) {
107+
if conv.extraHeaders == nil {
108+
conv.extraHeaders = make(map[string]string)
109+
}
110+
conv.extraHeaders[key] = val
111+
}

libaiac/openai/chat.go

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ import (
1212
// maintains all messages sent/received in order to maintain context just like
1313
// using ChatGPT.
1414
type Conversation struct {
15-
backend *OpenAI
16-
model string
17-
messages []types.Message
15+
backend *OpenAI
16+
model string
17+
messages []types.Message
18+
extraHeaders map[string]string
1819
}
1920

2021
type chatResponse struct {
@@ -67,15 +68,20 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
6768
apiVersion = fmt.Sprintf("?api-version=%s", conv.backend.apiVersion)
6869
}
6970

70-
err = conv.backend.
71+
req := conv.backend.
7172
NewRequest("POST", fmt.Sprintf("/chat/completions%s", apiVersion)).
7273
JSONBody(map[string]interface{}{
7374
"model": conv.model,
7475
"messages": conv.messages,
7576
"temperature": 0.2,
7677
}).
77-
Into(&answer).
78-
RunContext(ctx)
78+
Into(&answer)
79+
80+
for key, val := range conv.extraHeaders {
81+
req.Header(key, val)
82+
}
83+
84+
err = req.RunContext(ctx)
7985
if err != nil {
8086
return res, fmt.Errorf("failed sending prompt: %w", err)
8187
}
@@ -104,3 +110,14 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
104110
func (conv *Conversation) Messages() []types.Message {
105111
return conv.messages
106112
}
113+
114+
// AddHeader adds an extra HTTP header that will be added to every HTTP
115+
// request issued as part of this conversation. Any headers added will be in
116+
// addition to any extra headers defined for the backend itself, and will
117+
// take precedence over them.
118+
func (conv *Conversation) AddHeader(key, val string) {
119+
if conv.extraHeaders == nil {
120+
conv.extraHeaders = make(map[string]string)
121+
}
122+
conv.extraHeaders[key] = val
123+
}

libaiac/types/interfaces.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ type Conversation interface {
2424
Send(context.Context, string) (Response, error)
2525

2626
// Messages returns all the messages that have been exchanged between the
27-
// user and the assistant up to this point
27+
// user and the assistant up to this point.
2828
Messages() []Message
29+
30+
// AddHeader adds an extra HTTP header that will be added to every HTTP
31+
// request issued as part of this conversation. Any headers added will be in
32+
// addition to any extra headers defined for the backend itself, and will
33+
// take precedence over them. Not all providers may support this
34+
// (specifically, bedrock doesn't).
35+
AddHeader(string, string)
2936
}

0 commit comments

Comments
 (0)