@@ -59,6 +59,38 @@ func TestCompletions(t *testing.T) {
59
59
checks .NoError (t , err , "CreateCompletion error" )
60
60
}
61
61
62
+ // TestMultiplePromptsCompletionsWrong Tests the completions endpoint of the API using the mocked server
63
+ // where the completions requests has a list of prompts with wrong type.
64
+ func TestMultiplePromptsCompletionsWrong (t * testing.T ) {
65
+ client , server , teardown := setupOpenAITestServer ()
66
+ defer teardown ()
67
+ server .RegisterHandler ("/v1/completions" , handleCompletionEndpoint )
68
+ req := openai.CompletionRequest {
69
+ MaxTokens : 5 ,
70
+ Model : "ada" ,
71
+ Prompt : []interface {}{"Lorem ipsum" , 9 },
72
+ }
73
+ _ , err := client .CreateCompletion (context .Background (), req )
74
+ if ! errors .Is (err , openai .ErrCompletionRequestPromptTypeNotSupported ) {
75
+ t .Fatalf ("CreateCompletion should return ErrCompletionRequestPromptTypeNotSupported, but returned: %v" , err )
76
+ }
77
+ }
78
+
79
+ // TestMultiplePromptsCompletions Tests the completions endpoint of the API using the mocked server
80
+ // where the completions requests has a list of prompts.
81
+ func TestMultiplePromptsCompletions (t * testing.T ) {
82
+ client , server , teardown := setupOpenAITestServer ()
83
+ defer teardown ()
84
+ server .RegisterHandler ("/v1/completions" , handleCompletionEndpoint )
85
+ req := openai.CompletionRequest {
86
+ MaxTokens : 5 ,
87
+ Model : "ada" ,
88
+ Prompt : []interface {}{"Lorem ipsum" , "Lorem ipsum" },
89
+ }
90
+ _ , err := client .CreateCompletion (context .Background (), req )
91
+ checks .NoError (t , err , "CreateCompletion error" )
92
+ }
93
+
62
94
// handleCompletionEndpoint Handles the completion endpoint by the test server.
63
95
func handleCompletionEndpoint (w http.ResponseWriter , r * http.Request ) {
64
96
var err error
@@ -87,24 +119,50 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
87
119
if n == 0 {
88
120
n = 1
89
121
}
122
+ // Handle different types of prompts: single string or list of strings
123
+ prompts := []string {}
124
+ switch v := completionReq .Prompt .(type ) {
125
+ case string :
126
+ prompts = append (prompts , v )
127
+ case []interface {}:
128
+ for _ , item := range v {
129
+ if str , ok := item .(string ); ok {
130
+ prompts = append (prompts , str )
131
+ }
132
+ }
133
+ default :
134
+ http .Error (w , "Invalid prompt type" , http .StatusBadRequest )
135
+ return
136
+ }
137
+
90
138
for i := 0 ; i < n ; i ++ {
91
- // generate a random string of length completionReq.Length
92
- completionStr := strings .Repeat ("a" , completionReq .MaxTokens )
93
- if completionReq .Echo {
94
- completionStr = completionReq .Prompt .(string ) + completionStr
139
+ for _ , prompt := range prompts {
140
+ // Generate a random string of length completionReq.MaxTokens
141
+ completionStr := strings .Repeat ("a" , completionReq .MaxTokens )
142
+ if completionReq .Echo {
143
+ completionStr = prompt + completionStr
144
+ }
145
+
146
+ res .Choices = append (res .Choices , openai.CompletionChoice {
147
+ Text : completionStr ,
148
+ Index : len (res .Choices ),
149
+ })
95
150
}
96
- res .Choices = append (res .Choices , openai.CompletionChoice {
97
- Text : completionStr ,
98
- Index : i ,
99
- })
100
151
}
101
- inputTokens := numTokens (completionReq .Prompt .(string )) * n
102
- completionTokens := completionReq .MaxTokens * n
152
+
153
+ inputTokens := 0
154
+ for _ , prompt := range prompts {
155
+ inputTokens += numTokens (prompt )
156
+ }
157
+ inputTokens *= n
158
+ completionTokens := completionReq .MaxTokens * len (prompts ) * n
103
159
res .Usage = openai.Usage {
104
160
PromptTokens : inputTokens ,
105
161
CompletionTokens : completionTokens ,
106
162
TotalTokens : inputTokens + completionTokens ,
107
163
}
164
+
165
+ // Serialize the response and send it back
108
166
resBytes , _ = json .Marshal (res )
109
167
fmt .Fprintln (w , string (resBytes ))
110
168
}
0 commit comments