forked from sashabaranov/go-openai
-
Notifications
You must be signed in to change notification settings - Fork 1
/
client_test.go
147 lines (136 loc) · 3.21 KB
/
client_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
package openai //nolint:testpackage // testing private field
import (
"bytes"
"errors"
"fmt"
"io"
"net/http"
"testing"
)
// TestClient checks that NewClient and NewOrgClient store the credentials
// they are given in the client's config.
func TestClient(t *testing.T) {
	const (
		token = "mock token"
		org   = "mock org"
	)

	// Constructor that takes only a token.
	if c := NewClient(token); c.config.authToken != token {
		t.Errorf("Client does not contain proper token")
	}

	// Constructor that takes a token and an organization ID.
	orgClient := NewOrgClient(token, org)
	if got := orgClient.config.authToken; got != token {
		t.Errorf("Client does not contain proper token")
	}
	if got := orgClient.config.OrgID; got != org {
		t.Errorf("Client does not contain proper orgID")
	}
}
// TestDecodeResponse calls decodeResponse with a nil target, a *string target
// and a *map target, and reports a failure if any call returns an error.
func TestDecodeResponse(t *testing.T) {
	// Destination for the "string input" case.
	var decoded string

	type decodeCase struct {
		name  string
		value interface{}
		body  io.Reader
	}

	cases := []decodeCase{
		// value is left at its zero value (nil) on purpose.
		{name: "nil input", body: bytes.NewReader([]byte(""))},
		{name: "string input", value: &decoded, body: bytes.NewReader([]byte("test"))},
		{name: "map input", value: &map[string]interface{}{}, body: bytes.NewReader([]byte(`{"test": "test"}`))},
	}

	for i := range cases {
		c := cases[i]
		t.Run(c.name, func(t *testing.T) {
			if err := decodeResponse(c.body, c.value); err != nil {
				t.Errorf("Unexpected error: %v", err)
			}
		})
	}
}
// TestHandleErrorResp feeds canned error payloads through
// Client.handleErrorResp and checks that the resulting error has the expected
// message and that errors.As can extract an *APIError from it.
func TestHandleErrorResp(t *testing.T) {
	// Smoke check: a zero-value ErrorResponse has a nil Error field; logging it
	// and wrapping an otherwise empty RequestError must not panic.
	var errRes ErrorResponse
	var reqErr RequestError
	t.Log(errRes, errRes.Error)
	if errRes.Error != nil {
		reqErr.Err = errRes.Error
	}
	t.Log(fmt.Errorf("error, %w", &reqErr))
	t.Log(errRes.Error, "nil pointer check Pass")

	const mockToken = "mock token"
	client := NewClient(mockToken)

	// The JSON payload lines below sit flush-left so the raw string contents
	// stay exactly as they were.
	testCases := []struct {
		name     string
		httpCode int
		body     io.Reader
		expected string
	}{
		{
			name:     "401 Invalid Authentication",
			httpCode: http.StatusUnauthorized,
			body: bytes.NewReader([]byte(
				`{
"error":{
"message":"You didn't provide an API key. ....",
"type":"invalid_request_error",
"param":null,
"code":null
}
}`,
			)),
			expected: "error, status code: 401, message: You didn't provide an API key. ....",
		},
		{
			name:     "401 Azure Access Denied",
			httpCode: http.StatusUnauthorized,
			body: bytes.NewReader([]byte(
				`{
"error":{
"code":"AccessDenied",
"message":"Access denied due to Virtual Network/Firewall rules."
}
}`,
			)),
			expected: "error, status code: 401, message: Access denied due to Virtual Network/Firewall rules.",
		},
		{
			name:     "503 Model Overloaded",
			httpCode: http.StatusServiceUnavailable,
			body: bytes.NewReader([]byte(`
{
"error":{
"message":"That model...",
"type":"server_error",
"param":null,
"code":null
}
}`)),
			expected: "error, status code: 503, message: That model...",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			resp := &http.Response{
				StatusCode: tc.httpCode,
				Body:       io.NopCloser(tc.body),
			}

			err := client.handleErrorResp(resp)
			// Stop here on a nil error: the checks below call err.Error() and
			// would otherwise panic instead of reporting a readable failure.
			if err == nil {
				t.Fatalf("expected an error for status code %d, got nil", tc.httpCode)
			}
			t.Log(err.Error())

			// Errorf already marks the test as failed, so no extra Fail call
			// is needed after it.
			if err.Error() != tc.expected {
				t.Errorf("Unexpected error: %v , expected: %s", err, tc.expected)
			}

			apiErr := &APIError{}
			if !errors.As(err, &apiErr) {
				t.Errorf("(%s) Expected error to be of type APIError", tc.name)
			}
		})
	}
}