
Commit 6e08732

Updated checkPromptType function to handle prompt list in completions (#885)
* updated checkPromptType function to handle prompt list in completions
* removed generated test file
* added corresponding unit testcases
* Updated to use less nesting with early returns
1 parent 3672c0d commit 6e08732
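
With this change, checkPromptType also accepts a prompt list whose elements arrive as a []any of strings (for example after JSON decoding), in addition to the string and []string cases it already handled. A minimal sketch of the client-side usage this enables, based on the new tests in this commit (prompt values are illustrative):

	req := openai.CompletionRequest{
		MaxTokens: 5,
		Model:     "ada",
		// A list of prompts; every element must be a string, otherwise the request
		// is rejected with ErrCompletionRequestPromptTypeNotSupported.
		Prompt: []interface{}{"first prompt", "second prompt"},
	}
	resp, err := client.CreateCompletion(context.Background(), req)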

2 files changed: +85 -11 lines

completion.go (+17 -1)
@@ -161,7 +161,23 @@ func checkEndpointSupportsModel(endpoint, model string) bool {
 func checkPromptType(prompt any) bool {
 	_, isString := prompt.(string)
 	_, isStringSlice := prompt.([]string)
-	return isString || isStringSlice
+	if isString || isStringSlice {
+		return true
+	}
+
+	// check if the prompt is a []string hidden under []any
+	slice, isSlice := prompt.([]any)
+	if !isSlice {
+		return false
+	}
+
+	for _, item := range slice {
+		_, itemIsString := item.(string)
+		if !itemIsString {
+			return false
+		}
+	}
+	return true // all items in the slice are strings, so it is a []string
 }
 
 var unsupportedToolsForO1Models = map[ToolType]struct{}{
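
For reference, a quick sketch of how the updated function treats different prompt values (illustrative only, not part of the commit):

	checkPromptType("hello")            // true: single string
	checkPromptType([]string{"a", "b"}) // true: string slice (already supported)
	checkPromptType([]any{"a", "b"})    // true: []string arriving as []any (new in this commit)
	checkPromptType([]any{"a", 9})      // false: non-string element
	checkPromptType(42)                 // false: unsupported prompt type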

completion_test.go (+68 -10)
@@ -59,6 +59,38 @@ func TestCompletions(t *testing.T) {
 	checks.NoError(t, err, "CreateCompletion error")
 }
 
+// TestMultiplePromptsCompletionsWrong Tests the completions endpoint of the API using the mocked server
+// where the completions request has a list of prompts containing an element of the wrong type.
+func TestMultiplePromptsCompletionsWrong(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
+	req := openai.CompletionRequest{
+		MaxTokens: 5,
+		Model:     "ada",
+		Prompt:    []interface{}{"Lorem ipsum", 9},
+	}
+	_, err := client.CreateCompletion(context.Background(), req)
+	if !errors.Is(err, openai.ErrCompletionRequestPromptTypeNotSupported) {
+		t.Fatalf("CreateCompletion should return ErrCompletionRequestPromptTypeNotSupported, but returned: %v", err)
+	}
+}
+
+// TestMultiplePromptsCompletions Tests the completions endpoint of the API using the mocked server
+// where the completions request has a list of prompts.
+func TestMultiplePromptsCompletions(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
+	req := openai.CompletionRequest{
+		MaxTokens: 5,
+		Model:     "ada",
+		Prompt:    []interface{}{"Lorem ipsum", "Lorem ipsum"},
+	}
+	_, err := client.CreateCompletion(context.Background(), req)
+	checks.NoError(t, err, "CreateCompletion error")
+}
+
 // handleCompletionEndpoint Handles the completion endpoint by the test server.
 func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 	var err error
@@ -87,24 +119,50 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 	if n == 0 {
 		n = 1
 	}
+	// Handle different types of prompts: single string or list of strings
+	prompts := []string{}
+	switch v := completionReq.Prompt.(type) {
+	case string:
+		prompts = append(prompts, v)
+	case []interface{}:
+		for _, item := range v {
+			if str, ok := item.(string); ok {
+				prompts = append(prompts, str)
+			}
+		}
+	default:
+		http.Error(w, "Invalid prompt type", http.StatusBadRequest)
+		return
+	}
+
 	for i := 0; i < n; i++ {
-		// generate a random string of length completionReq.Length
-		completionStr := strings.Repeat("a", completionReq.MaxTokens)
-		if completionReq.Echo {
-			completionStr = completionReq.Prompt.(string) + completionStr
+		for _, prompt := range prompts {
+			// Generate a random string of length completionReq.MaxTokens
+			completionStr := strings.Repeat("a", completionReq.MaxTokens)
+			if completionReq.Echo {
+				completionStr = prompt + completionStr
+			}
+
+			res.Choices = append(res.Choices, openai.CompletionChoice{
+				Text:  completionStr,
+				Index: len(res.Choices),
+			})
 		}
-		res.Choices = append(res.Choices, openai.CompletionChoice{
-			Text:  completionStr,
-			Index: i,
-		})
 	}
-	inputTokens := numTokens(completionReq.Prompt.(string)) * n
-	completionTokens := completionReq.MaxTokens * n
+
+	inputTokens := 0
+	for _, prompt := range prompts {
+		inputTokens += numTokens(prompt)
+	}
+	inputTokens *= n
+	completionTokens := completionReq.MaxTokens * len(prompts) * n
 	res.Usage = openai.Usage{
 		PromptTokens:     inputTokens,
 		CompletionTokens: completionTokens,
 		TotalTokens:      inputTokens + completionTokens,
 	}
+
+	// Serialize the response and send it back
 	resBytes, _ = json.Marshal(res)
 	fmt.Fprintln(w, string(resBytes))
 }
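
The mock handler's usage accounting now sums prompt tokens over the whole prompt list instead of asserting a single string. A sketch of the resulting arithmetic for the two-prompt test request above (numTokens is the existing test helper called by the handler):

	inputTokens := (numTokens("Lorem ipsum") + numTokens("Lorem ipsum")) * n // summed over all prompts, then scaled by n
	completionTokens := completionReq.MaxTokens * 2 * n                      // MaxTokens * len(prompts) * n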
