Skip to content
This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit 9400de9

Browse files
Merge pull request #5 from discogs/headers
Add headers keyword to request methods.
2 parents c75afb9 + 0547eac commit 9400de9

File tree

9 files changed

+178
-32
lines changed

9 files changed

+178
-32
lines changed

MANIFEST.in

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ include LICENSE
33
include requirements.txt
44
include setup.py
55
include MANIFEST.in
6-
include test*
6+
include tests/*
77
include tox.ini
8+
include pytest.ini
89
recursive-include cas_client *
910
recursive-exclude cas_client *.pyc

Makefile

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
clean:
2+
find . -name '*.pyc' | xargs rm -Rif
3+
find . -name '*egg-info' | xargs rm -Rif
4+
find . -name __pycache__ | xargs rm -Rif
5+
rm -Rif dist/
6+
rm -Rif build/
7+
8+
package:
9+
python setup.py build sdist

cas_client/_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
__version_info__ = (0, 1, 3)
1+
__version_info__ = (0, 1, 4)
22
__version__ = '.'.join(str(_) for _ in __version_info__)

cas_client/cas_client.py

+47-17
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(
3535
proxy_callback=None,
3636
verify_certificates=False,
3737
session_storage_adapter=None,
38+
headers=None,
3839
):
3940
self._auth_prefix = auth_prefix
4041
self._proxy_callback = proxy_callback
@@ -43,16 +44,17 @@ def __init__(
4344
self._service_url = service_url
4445
self._session_storage_adapter = session_storage_adapter
4546
self._verify_certificates = bool(verify_certificates)
47+
self._headers = headers
4648

4749
### PUBLIC METHODS ###
4850

49-
def acquire_auth_token_ticket(self):
51+
def acquire_auth_token_ticket(self, headers=None):
5052
'''
5153
Acquire an auth token from the CAS server.
5254
'''
5355
logging.debug('[CAS] Acquiring Auth token ticket')
5456
url = self._get_auth_token_tickets_url()
55-
text = self._perform_post(url)
57+
text = self._perform_post(url, headers=headers)
5658
auth_token_ticket = json.loads(text)['ticket']
5759
logging.debug('[CAS] Acquire Auth token ticket: {}'.format(
5860
auth_token_ticket))
@@ -257,6 +259,7 @@ def perform_api_request(
257259
private_key,
258260
method='POST',
259261
service_url=None,
262+
headers=None,
260263
**kwargs
261264
):
262265
'''
@@ -272,34 +275,47 @@ def perform_api_request(
272275
**kwargs
273276
)
274277
if method == 'GET':
275-
response = self._perform_get(url)
278+
response = self._perform_get(url, headers=headers)
276279
elif method == 'POST':
277-
response = self._perform_post(url)
280+
response = self._perform_post(url, headers=headers)
278281
return response
279282

280-
def perform_proxy(self, proxy_ticket):
283+
def perform_proxy(self, proxy_ticket, headers=None):
281284
'''
282285
Fetch a response from the remote CAS `proxy` endpoint.
283286
'''
284287
url = self._get_proxy_url(ticket=proxy_ticket)
285288
logging.debug('[CAS] Proxy URL: {}'.format(url))
286-
return self._perform_cas_call(url, ticket=proxy_ticket)
289+
return self._perform_cas_call(
290+
url,
291+
ticket=proxy_ticket,
292+
headers=headers,
293+
)
287294

288-
def perform_proxy_validate(self, proxied_service_ticket):
295+
def perform_proxy_validate(self, proxied_service_ticket, headers=None):
289296
'''
290297
Fetch a response from the remote CAS `proxyValidate` endpoint.
291298
'''
292299
url = self._get_proxy_validate_url(ticket=proxied_service_ticket)
293300
logging.debug('[CAS] ProxyValidate URL: {}'.format(url))
294-
return self._perform_cas_call(url, ticket=proxied_service_ticket)
301+
return self._perform_cas_call(
302+
url,
303+
ticket=proxied_service_ticket,
304+
headers=headers,
305+
)
295306

296-
def perform_service_validate(self, ticket=None, service_url=None):
307+
def perform_service_validate(
308+
self,
309+
ticket=None,
310+
service_url=None,
311+
headers=None,
312+
):
297313
'''
298314
Fetch a response from the remote CAS `serviceValidate` endpoint.
299315
'''
300316
url = self._get_service_validate_url(ticket, service_url=service_url)
301317
logging.debug('[CAS] ServiceValidate URL: {}'.format(url))
302-
return self._perform_cas_call(url, ticket=ticket)
318+
return self._perform_cas_call(url, ticket=ticket, headers=headers)
303319

304320
def session_exists(self, ticket):
305321
'''
@@ -405,32 +421,42 @@ def _get_service_validate_url(self, ticket, service_url=None):
405421
url = '{url}&pgtUrl={proxy_url}'.format(url, self.proxy_url)
406422
return url
407423

408-
def _perform_cas_call(self, url, ticket):
424+
def _perform_cas_call(self, url, ticket, headers=None):
409425
if ticket is not None:
410426
logging.debug('[CAS] Requesting Ticket Validation')
411-
response_text = self._perform_get(url)
427+
response_text = self._perform_get(url, headers=headers)
412428
response_text = self._clean_up_response_text(response_text)
413429
if response_text:
414430
logging.debug('[CAS] Response:\n{}'.format(response_text))
415431
return CASResponse(response_text)
416432
logging.debug('[CAS] Response: None')
417433
return None
418434

419-
def _perform_get(self, url):
435+
def _perform_get(self, url, headers=None):
436+
headers = headers or self.headers
420437
try:
421-
response = requests.get(url, verify=self.verify_certificates)
438+
response = requests.get(
439+
url,
440+
verify=self.verify_certificates,
441+
headers=headers,
442+
)
422443
return response.text
423444
except requests.HTTPError:
424445
return None
425446

426-
def _perform_post(self, url):
447+
def _perform_post(self, url, headers=None):
448+
headers = headers or self.headers
427449
try:
428-
response = requests.post(url, verify=self.verify_certificates)
450+
response = requests.post(
451+
url,
452+
verify=self.verify_certificates,
453+
headers=headers,
454+
)
429455
return response.text
430456
except requests.HTTPError:
431457
return None
432458

433-
### PUBLIC METHODS ###
459+
### PUBLIC PROPERTIES ###
434460

435461
@property
436462
def auth_prefix(self):
@@ -439,6 +465,10 @@ def auth_prefix(self):
439465
'''
440466
return self._auth_prefix
441467

468+
@property
469+
def headers(self):
470+
return self._headers
471+
442472
@property
443473
def proxy_callback(self):
444474
'''

pytest.ini

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[pytest]
2+
addopts =
3+
--doctest-modules
4+
-vv
5+
doctest_optionflags =
6+
ELLIPSIS
7+
NORMALIZE_WHITESPACE
8+
python_files = test*.py
9+
testpaths = tests

test.py renamed to tests/test.py

+101-9
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
# -*- encoding: utf-8 -*-
2+
import os
23
import unittest
34
from cas_client import CASClient, CASResponse
45
try:
56
from urlparse import parse_qs
67
except ImportError:
78
from urllib.parse import parse_qs
9+
try:
10+
import mock
11+
except ImportError:
12+
from unittest import mock
813

914

1015
class TestCase(unittest.TestCase):
@@ -49,6 +54,11 @@ class TestCase(unittest.TestCase):
4954
</samlp:LogoutRequest>
5055
"""
5156

57+
private_key_filepath = os.path.join(
58+
os.path.abspath(os.path.dirname(__file__)),
59+
'test_private_key.pem',
60+
)
61+
5262
def test_success(self):
5363
response = CASResponse(self.response_text)
5464
self.assertTrue(response.success)
@@ -70,11 +80,13 @@ def test_failure(self):
7080

7181
def test_perform_service_validate(self):
7282
cas_client = CASClient('https://dummy.url')
73-
cas_client._perform_get = lambda url: self.response_text
74-
response = cas_client.perform_service_validate(
75-
ticket='FOO',
76-
service_url='BAR',
77-
)
83+
assert not cas_client.headers
84+
with mock.patch('cas_client.CASClient._perform_get') as m:
85+
m.return_value = self.response_text
86+
response = cas_client.perform_service_validate(
87+
ticket='FOO',
88+
service_url='BAR',
89+
)
7890
self.assertTrue(response.success)
7991
self.assertEqual(response.attributes, {
8092
u'i2a2characteristics': u'0,3592,2000',
@@ -87,6 +99,43 @@ def test_perform_service_validate(self):
8799
self.assertEqual(response.response_type, 'authenticationSuccess')
88100
self.assertEqual(response.user, 'jott')
89101

102+
def test_perform_service_validate_headers_call(self):
103+
class MockResponse(object):
104+
text = self.response_text
105+
106+
cas_client = CASClient('https://dummy.url')
107+
assert not cas_client.headers
108+
with mock.patch('requests.get') as m:
109+
m.return_value = MockResponse()
110+
cas_client.perform_service_validate(
111+
ticket='FOO',
112+
service_url='BAR',
113+
headers={'baz': 'quux'},
114+
)
115+
m.assert_called_with(
116+
'https://dummy.url/cas/serviceValidate?ticket=FOO&service=BAR',
117+
headers={'baz': 'quux'},
118+
verify=False,
119+
)
120+
121+
def test_perform_service_validate_headers_init(self):
122+
class MockResponse(object):
123+
text = self.response_text
124+
125+
cas_client = CASClient('https://dummy.url', headers={'baz': 'quux'})
126+
assert cas_client.headers == {'baz': 'quux'}
127+
with mock.patch('requests.get') as m:
128+
m.return_value = MockResponse()
129+
cas_client.perform_service_validate(
130+
ticket='FOO',
131+
service_url='BAR',
132+
)
133+
m.assert_called_with(
134+
'https://dummy.url/cas/serviceValidate?ticket=FOO&service=BAR',
135+
headers={'baz': 'quux'},
136+
verify=False,
137+
)
138+
90139
def test_get_destroy_other_sessions_url(self):
91140
cas_client = CASClient('https://dummy.url')
92141
service_url = 'https://app.url'
@@ -134,8 +183,7 @@ def test_get_api_url(self):
134183
api_resource = 'do_something_useful'
135184
auth_token_ticket = 'ATT-1234'
136185
authenticator = 'my_company_ldap'
137-
private_key_filepath = 'test_private_key.pem'
138-
with open(private_key_filepath, 'r') as file_pointer:
186+
with open(self.private_key_filepath, 'r') as file_pointer:
139187
private_key = file_pointer.read()
140188
service_url = 'https://example.com'
141189
kwargs = {
@@ -184,8 +232,7 @@ def test_get_auth_token_login_url(self):
184232
authenticator = 'my_company_ldap'
185233
username = 'my_user'
186234
service_url = 'https://example.com'
187-
private_key_filepath = 'test_private_key.pem'
188-
with open(private_key_filepath, 'r') as file_pointer:
235+
with open(self.private_key_filepath, 'r') as file_pointer:
189236
private_key = file_pointer.read()
190237
url = cas_client.get_auth_token_login_url(
191238
auth_token_ticket=auth_token_ticket,
@@ -220,3 +267,48 @@ def test_get_auth_token_login_url(self):
220267
),
221268
'service': 'https://example.com',
222269
}
270+
271+
def test_acquire_auth_token_ticket_no_headers(self):
272+
class MockResponse(object):
273+
text = '{"ticket": "FOO"}'
274+
275+
cas_client = CASClient('https://dummy.url')
276+
assert not cas_client.headers
277+
with mock.patch('requests.post') as m:
278+
m.return_value = MockResponse()
279+
cas_client.acquire_auth_token_ticket()
280+
m.assert_called_with(
281+
'https://dummy.url/cas/api/auth_token_tickets',
282+
headers=None,
283+
verify=False,
284+
)
285+
286+
def test_acquire_auth_token_ticket_headers_call(self):
287+
class MockResponse(object):
288+
text = '{"ticket": "FOO"}'
289+
290+
cas_client = CASClient('https://dummy.url')
291+
assert not cas_client.headers
292+
with mock.patch('requests.post') as m:
293+
m.return_value = MockResponse()
294+
cas_client.acquire_auth_token_ticket(headers={'baz': 'quux'})
295+
m.assert_called_with(
296+
'https://dummy.url/cas/api/auth_token_tickets',
297+
headers={'baz': 'quux'},
298+
verify=False,
299+
)
300+
301+
def test_acquire_auth_token_ticket_headers_init(self):
302+
class MockResponse(object):
303+
text = '{"ticket": "FOO"}'
304+
305+
cas_client = CASClient('https://dummy.url', headers={'baz': 'quux'})
306+
assert cas_client.headers == {'baz': 'quux'}
307+
with mock.patch('requests.post') as m:
308+
m.return_value = MockResponse()
309+
cas_client.acquire_auth_token_ticket()
310+
m.assert_called_with(
311+
'https://dummy.url/cas/api/auth_token_tickets',
312+
headers={'baz': 'quux'},
313+
verify=False,
314+
)
File renamed without changes.
File renamed without changes.

tox.ini

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
[tox]
2-
envlist = py27, py34
2+
envlist=
3+
py27
4+
py34
5+
py35
6+
py36
37

48
[testenv]
5-
deps=pytest
9+
deps=
10+
pytest
11+
py27: mock
612
commands=
7-
py.test -rf -vv test.py
8-
python -m doctest cas_client/cas_client.py
13+
py.test

0 commit comments

Comments
 (0)