1
1
from unittest import mock
2
2
3
+ import pytest
3
4
from django .test import override_settings
4
5
5
6
from authlib .common .urls import url_decode
19
20
class DjangoOAuthTest (TestCase ):
20
21
def test_register_remote_app (self ):
21
22
oauth = OAuth ()
22
- self .assertRaises (AttributeError , lambda : oauth .dev )
23
+ with pytest .raises (AttributeError ):
24
+ oauth .dev # noqa:B018
23
25
24
26
oauth .register (
25
27
"dev" ,
@@ -30,8 +32,8 @@ def test_register_remote_app(self):
30
32
access_token_url = "https://i.b/token" ,
31
33
authorize_url = "https://i.b/authorize" ,
32
34
)
33
- self . assertEqual ( oauth .dev .name , "dev" )
34
- self . assertEqual ( oauth .dev .client_id , "dev" )
35
+ assert oauth .dev .name == "dev"
36
+ assert oauth .dev .client_id == "dev"
35
37
36
38
def test_register_with_overwrite (self ):
37
39
oauth = OAuth ()
@@ -46,15 +48,15 @@ def test_register_with_overwrite(self):
46
48
access_token_params = {"foo" : "foo" },
47
49
authorize_url = "https://i.b/authorize" ,
48
50
)
49
- self . assertEqual ( oauth .dev_overwrite .client_id , "dev-client-id" )
50
- self . assertEqual ( oauth .dev_overwrite .access_token_params ["foo" ], "foo-1" )
51
+ assert oauth .dev_overwrite .client_id == "dev-client-id"
52
+ assert oauth .dev_overwrite .access_token_params ["foo" ] == "foo-1"
51
53
52
54
@override_settings (AUTHLIB_OAUTH_CLIENTS = {"dev" : dev_client })
53
55
def test_register_from_settings (self ):
54
56
oauth = OAuth ()
55
57
oauth .register ("dev" )
56
- self . assertEqual ( oauth .dev .client_id , "dev-key" )
57
- self . assertEqual ( oauth .dev .client_secret , "dev-secret" )
58
+ assert oauth .dev .client_id == "dev-key"
59
+ assert oauth .dev .client_secret == "dev-secret"
58
60
59
61
def test_oauth1_authorize (self ):
60
62
request = self .factory .get ("/login" )
@@ -75,16 +77,16 @@ def test_oauth1_authorize(self):
75
77
send .return_value = mock_send_value ("oauth_token=foo&oauth_verifier=baz" )
76
78
77
79
resp = client .authorize_redirect (request )
78
- self . assertEqual ( resp .status_code , 302 )
80
+ assert resp .status_code == 302
79
81
url = resp .get ("Location" )
80
- self . assertIn ( "oauth_token=foo" , url )
82
+ assert "oauth_token=foo" in url
81
83
82
84
request2 = self .factory .get (url )
83
85
request2 .session = request .session
84
86
with mock .patch ("requests.sessions.Session.send" ) as send :
85
87
send .return_value = mock_send_value ("oauth_token=a&oauth_token_secret=b" )
86
88
token = client .authorize_access_token (request2 )
87
- self . assertEqual ( token ["oauth_token" ], "a" )
89
+ assert token ["oauth_token" ] == "a"
88
90
89
91
def test_oauth2_authorize (self ):
90
92
request = self .factory .get ("/login" )
@@ -100,9 +102,9 @@ def test_oauth2_authorize(self):
100
102
authorize_url = "https://i.b/authorize" ,
101
103
)
102
104
rv = client .authorize_redirect (request , "https://a.b/c" )
103
- self . assertEqual ( rv .status_code , 302 )
105
+ assert rv .status_code == 302
104
106
url = rv .get ("Location" )
105
- self . assertIn ( "state=" , url )
107
+ assert "state=" in url
106
108
state = dict (url_decode (urlparse .urlparse (url ).query ))["state" ]
107
109
108
110
with mock .patch ("requests.sessions.Session.send" ) as send :
@@ -111,7 +113,7 @@ def test_oauth2_authorize(self):
111
113
request2 .session = request .session
112
114
113
115
token = client .authorize_access_token (request2 )
114
- self . assertEqual ( token ["access_token" ], "a" )
116
+ assert token ["access_token" ] == "a"
115
117
116
118
def test_oauth2_authorize_access_denied (self ):
117
119
oauth = OAuth ()
@@ -129,7 +131,8 @@ def test_oauth2_authorize_access_denied(self):
129
131
"/?error=access_denied&error_description=Not+Allowed"
130
132
)
131
133
request .session = self .factory .session
132
- self .assertRaises (OAuthError , client .authorize_access_token , request )
134
+ with pytest .raises (OAuthError ):
135
+ client .authorize_access_token (request )
133
136
134
137
def test_oauth2_authorize_code_challenge (self ):
135
138
request = self .factory .get ("/login" )
@@ -145,24 +148,24 @@ def test_oauth2_authorize_code_challenge(self):
145
148
client_kwargs = {"code_challenge_method" : "S256" },
146
149
)
147
150
rv = client .authorize_redirect (request , "https://a.b/c" )
148
- self . assertEqual ( rv .status_code , 302 )
151
+ assert rv .status_code == 302
149
152
url = rv .get ("Location" )
150
- self . assertIn ( "state=" , url )
151
- self . assertIn ( "code_challenge=" , url )
153
+ assert "state=" in url
154
+ assert "code_challenge=" in url
152
155
153
156
state = dict (url_decode (urlparse .urlparse (url ).query ))["state" ]
154
157
state_data = request .session [f"_state_dev_{ state } " ]["data" ]
155
158
verifier = state_data ["code_verifier" ]
156
159
157
160
def fake_send (sess , req , ** kwargs ):
158
- self . assertIn ( f"code_verifier={ verifier } " , req .body )
161
+ assert f"code_verifier={ verifier } " in req .body
159
162
return mock_send_value (get_bearer_token ())
160
163
161
164
with mock .patch ("requests.sessions.Session.send" , fake_send ):
162
165
request2 = self .factory .get (f"/authorize?state={ state } " )
163
166
request2 .session = request .session
164
167
token = client .authorize_access_token (request2 )
165
- self . assertEqual ( token ["access_token" ], "a" )
168
+ assert token ["access_token" ] == "a"
166
169
167
170
def test_oauth2_authorize_code_verifier (self ):
168
171
request = self .factory .get ("/login" )
@@ -182,10 +185,10 @@ def test_oauth2_authorize_code_verifier(self):
182
185
rv = client .authorize_redirect (
183
186
request , "https://a.b/c" , state = state , code_verifier = code_verifier
184
187
)
185
- self . assertEqual ( rv .status_code , 302 )
188
+ assert rv .status_code == 302
186
189
url = rv .get ("Location" )
187
- self . assertIn ( "state=" , url )
188
- self . assertIn ( "code_challenge=" , url )
190
+ assert "state=" in url
191
+ assert "code_challenge=" in url
189
192
190
193
with mock .patch ("requests.sessions.Session.send" ) as send :
191
194
send .return_value = mock_send_value (get_bearer_token ())
@@ -194,7 +197,7 @@ def test_oauth2_authorize_code_verifier(self):
194
197
request2 .session = request .session
195
198
196
199
token = client .authorize_access_token (request2 )
197
- self . assertEqual ( token ["access_token" ], "a" )
200
+ assert token ["access_token" ] == "a"
198
201
199
202
def test_openid_authorize (self ):
200
203
request = self .factory .get ("/login" )
@@ -213,9 +216,9 @@ def test_openid_authorize(self):
213
216
)
214
217
215
218
resp = client .authorize_redirect (request , "https://b.com/bar" )
216
- self . assertEqual ( resp .status_code , 302 )
219
+ assert resp .status_code == 302
217
220
url = resp .get ("Location" )
218
- self . assertIn ( "nonce=" , url )
221
+ assert "nonce=" in url
219
222
query_data = dict (url_decode (urlparse .urlparse (url ).query ))
220
223
221
224
token = get_bearer_token ()
@@ -237,9 +240,9 @@ def test_openid_authorize(self):
237
240
request2 .session = request .session
238
241
239
242
token = client .authorize_access_token (request2 )
240
- self . assertEqual ( token ["access_token" ], "a" )
241
- self . assertIn ( "userinfo" , token )
242
- self . assertEqual ( token ["userinfo" ]["sub" ], "123" )
243
+ assert token ["access_token" ] == "a"
244
+ assert "userinfo" in token
245
+ assert token ["userinfo" ]["sub" ] == "123"
243
246
244
247
def test_oauth2_access_token_with_post (self ):
245
248
oauth = OAuth ()
@@ -259,7 +262,7 @@ def test_oauth2_access_token_with_post(self):
259
262
request .session = self .factory .session
260
263
request .session ["_state_dev_b" ] = {"data" : {}}
261
264
token = client .authorize_access_token (request )
262
- self . assertEqual ( token ["access_token" ], "a" )
265
+ assert token ["access_token" ] == "a"
263
266
264
267
def test_with_fetch_token_in_oauth (self ):
265
268
def fetch_token (name , request ):
@@ -276,7 +279,7 @@ def fetch_token(name, request):
276
279
)
277
280
278
281
def fake_send (sess , req , ** kwargs ):
279
- self . assertEqual ( sess .token ["access_token" ], "dev" )
282
+ assert sess .token ["access_token" ] == "dev"
280
283
return mock_send_value (get_bearer_token ())
281
284
282
285
with mock .patch ("requests.sessions.Session.send" , fake_send ):
@@ -299,7 +302,7 @@ def fetch_token(request):
299
302
)
300
303
301
304
def fake_send (sess , req , ** kwargs ):
302
- self . assertEqual ( sess .token ["access_token" ], "dev" )
305
+ assert sess .token ["access_token" ] == "dev"
303
306
return mock_send_value (get_bearer_token ())
304
307
305
308
with mock .patch ("requests.sessions.Session.send" , fake_send ):
@@ -319,13 +322,14 @@ def test_request_without_token(self):
319
322
320
323
def fake_send (sess , req , ** kwargs ):
321
324
auth = req .headers .get ("Authorization" )
322
- self . assertIsNone ( auth )
325
+ assert auth is None
323
326
resp = mock .MagicMock ()
324
327
resp .text = "hi"
325
328
resp .status_code = 200
326
329
return resp
327
330
328
331
with mock .patch ("requests.sessions.Session.send" , fake_send ):
329
332
resp = client .get ("/api/user" , withhold_token = True )
330
- self .assertEqual (resp .text , "hi" )
331
- self .assertRaises (OAuthError , client .get , "https://i.b/api/user" )
333
+ assert resp .text == "hi"
334
+ with pytest .raises (OAuthError ):
335
+ client .get ("https://i.b/api/user" )
0 commit comments