# Copyright 2017 Canonical Ltd.
# Licensed under the LGPLv3, see LICENCE file for details.
import json
import logging
import os
import tempfile
from datetime import datetime, timedelta
from unittest import TestCase
import macaroonbakery.bakery as bakery
import macaroonbakery.checkers as checkers
import macaroonbakery.httpbakery as httpbakery
import macaroonbakery.httpbakery.agent as agent
import requests.cookies
from httmock import HTTMock, response, urlmatch
from six.moves.urllib.parse import parse_qs, urlparse
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Key pair the tests expect to read back from the valid ``agent_file``
# fixture; the same strings appear in that fixture's "key" section.
PRIVATE_KEY = 'CqoSgj06Zcgb4/S6RT4DpTjLAfKoznEY3JsShSjKJEU='
PUBLIC_KEY = 'YAhRSsth3a36mRYqQGQaLiS4QJax0p356nd+B8x7UQE='
class TestAgents(TestCase):
    """Tests for loading agent auth info and for agent-based logins.

    The HTTP endpoints used by the login tests are simulated with
    HTTMock handlers defined inside each test.
    """

    def setUp(self):
        """Write every agent-file fixture to its own temporary file."""
        self.agent_filename = self._write_temp_file(agent_file)
        self.bad_key_agent_filename = self._write_temp_file(
            bad_key_agent_file)
        self.no_username_agent_filename = self._write_temp_file(
            no_username_agent_file)

    def _write_temp_file(self, content):
        """Create a temporary file holding ``content`` and return its path.

        Removal is registered with ``addCleanup`` right after the file is
        created, so the file is deleted even when a later step of
        ``setUp`` fails (in which case ``tearDown`` would not run).
        """
        fd, filename = tempfile.mkstemp()
        self.addCleanup(os.remove, filename)
        with os.fdopen(fd, 'w') as f:
            f.write(content)
        return filename

    @staticmethod
    def _make_locator(discharge_key):
        """Return a locator that only knows the ``http://0.3.2.1`` discharger.

        For that location it reports the public half of ``discharge_key``;
        for any other location ``third_party_info`` returns None.
        """
        class _DischargerLocator(bakery.ThirdPartyLocator):
            def third_party_info(self, loc):
                if loc == 'http://0.3.2.1':
                    return bakery.ThirdPartyInfo(
                        public_key=discharge_key.public_key,
                        version=bakery.LATEST_VERSION,
                    )

        return _DischargerLocator()

    @staticmethod
    def _make_server_get(server_bakery):
        """Build the HTTMock handler for the protected ``.*/here`` path.

        The handler replies 200 ``done`` when the macaroons carried by the
        request allow the ``test-op``/``read`` operation.  Otherwise it
        mints a macaroon with a third-party caveat located at
        ``http://0.3.2.1`` and replies 401 with the matching
        discharge-required response.
        """
        @urlmatch(path='.*/here')
        def server_get(url, request):
            ctx = checkers.AuthContext()
            test_ops = [bakery.Op(entity='test-op', action='read')]
            auth_checker = server_bakery.checker.auth(
                httpbakery.extract_macaroons(request.headers))
            try:
                auth_checker.allow(ctx, test_ops)
            except bakery.PermissionDenied:
                caveats = [
                    checkers.Caveat(location='http://0.3.2.1',
                                    condition='is-ok')
                ]
                m = server_bakery.oven.macaroon(
                    version=bakery.LATEST_VERSION,
                    expiry=datetime.utcnow() + timedelta(days=1),
                    caveats=caveats, ops=test_ops)
                content, headers = httpbakery.discharge_required_response(
                    m, '/',
                    'test',
                    'message')
                resp = response(status_code=401,
                                content=content,
                                headers=headers)
            else:
                resp = response(status_code=200,
                                content='done')
            # Feed the mocked response through the request's first response
            # hook (presumably the one installed via ``auth=client.auth()``
            # -- TODO confirm) so the client can react to the 401.
            return request.hooks['response'][0](resp)

        return server_get

    def test_load_auth_info(self):
        """A valid agent file yields the expected key pair and agents."""
        auth_info = agent.load_auth_info(self.agent_filename)
        self.assertEqual(str(auth_info.key), PRIVATE_KEY)
        self.assertEqual(str(auth_info.key.public_key), PUBLIC_KEY)
        self.assertEqual(auth_info.agents, [
            agent.Agent(url='https://1.example.com/', username='user-1'),
            agent.Agent(url='https://2.example.com/discharger',
                        username='user-2'),
            agent.Agent(url='http://0.3.2.1', username='test-user'),
        ])

    def test_invalid_agent_json(self):
        """Malformed JSON is reported as AgentFileFormatError."""
        with self.assertRaises(agent.AgentFileFormatError):
            agent.read_auth_info('}')

    def test_invalid_read_auth_info_arg(self):
        """A non-string argument is reported as AgentFileFormatError."""
        with self.assertRaises(agent.AgentFileFormatError):
            agent.read_auth_info(0)

    def test_load_auth_info_with_bad_key(self):
        """The bad-key fixture is rejected with AgentFileFormatError."""
        with self.assertRaises(agent.AgentFileFormatError):
            agent.load_auth_info(self.bad_key_agent_filename)

    def test_load_auth_info_with_no_username(self):
        """The no-username fixture is rejected with AgentFileFormatError."""
        with self.assertRaises(agent.AgentFileFormatError):
            agent.load_auth_info(self.no_username_agent_filename)

    def test_agent_login(self):
        """A client using AgentInteractor obtains ``done`` from the server.

        The mocked discharger answers "interaction required" (offering the
        ``agent`` method with login URL ``/login``) until the request body
        carries a ``token64`` field, after which it discharges the caveat.
        """
        discharge_key = bakery.generate_key()
        locator = self._make_locator(discharge_key)
        server_key = bakery.generate_key()
        server_bakery = bakery.Bakery(key=server_key, locator=locator)
        server_get = self._make_server_get(server_bakery)

        @urlmatch(path='.*/discharge')
        def discharge(url, request):
            qs = parse_qs(request.body)
            if qs.get('token64') is None:
                return response(
                    status_code=401,
                    content={
                        'Code': httpbakery.ERR_INTERACTION_REQUIRED,
                        'Message': 'interaction required',
                        'Info': {
                            'InteractionMethods': {
                                'agent': {'login-url': '/login'},
                            },
                        },
                    },
                    headers={'Content-Type': 'application/json'})
            # parse_qs maps each field to a list; keep the first value only.
            content = {q: qs[q][0] for q in qs}
            m = httpbakery.discharge(checkers.AuthContext(), content,
                                     discharge_key, None, alwaysOK3rd)
            return {
                'status_code': 200,
                'content': {
                    'Macaroon': m.to_dict()
                }
            }

        auth_info = agent.load_auth_info(self.agent_filename)

        @urlmatch(path='.*/login')
        def login(url, request):
            qs = parse_qs(urlparse(request.url).query)
            self.assertEqual(request.method, 'GET')
            self.assertEqual(
                qs, {'username': ['test-user'], 'public-key': [PUBLIC_KEY]})
            b = bakery.Bakery(key=discharge_key)
            m = b.oven.macaroon(
                version=bakery.LATEST_VERSION,
                expiry=datetime.utcnow() + timedelta(days=1),
                caveats=[bakery.local_third_party_caveat(
                    PUBLIC_KEY,
                    version=httpbakery.request_version(request.headers))],
                ops=[bakery.Op(entity='agent', action='login')])
            return {
                'status_code': 200,
                'content': {
                    'macaroon': m.to_dict()
                }
            }

        with HTTMock(server_get), \
                HTTMock(discharge), \
                HTTMock(login):
            client = httpbakery.Client(interaction_methods=[
                agent.AgentInteractor(auth_info),
            ])
            resp = requests.get(
                'http://0.1.2.3/here',
                cookies=client.cookies,
                auth=client.auth())
        self.assertEqual(resp.content, b'done')

    def test_agent_legacy(self):
        """A client using AgentInteractor completes the visit/wait flow.

        Here the mocked discharger answers with ``VisitURL``/``WaitURL``
        instead of interaction methods; the agent then POSTs to the
        ``/agent-visit`` URL advertised by the visit endpoint and the
        discharge macaroon is served from the wait endpoint.
        """
        discharge_key = bakery.generate_key()
        locator = self._make_locator(discharge_key)
        server_key = bakery.generate_key()
        server_bakery = bakery.Bakery(key=server_key, locator=locator)
        server_get = self._make_server_get(server_bakery)

        class InfoStorage:
            # Caveat info captured by the discharge handler and consumed
            # later by the wait handler.
            info = None

        class InteractionRequiredError(Exception):
            def __init__(self, error):
                self.error = error

        @urlmatch(path='.*/discharge')
        def discharge(url, request):
            qs = parse_qs(request.body)
            if qs.get('caveat64') is None:
                # No mocked reply for requests without a caveat.
                return None
            content = {q: qs[q][0] for q in qs}

            class CheckerInError(bakery.ThirdPartyCaveatChecker):
                def check_third_party_caveat(self, ctx, info):
                    InfoStorage.info = info
                    raise InteractionRequiredError(
                        httpbakery.Error(
                            code=httpbakery.ERR_INTERACTION_REQUIRED,
                            version=httpbakery.request_version(
                                request.headers),
                            message='interaction required',
                            info=httpbakery.ErrorInfo(
                                wait_url='http://0.3.2.1/wait?'
                                         'dischargeid=1',
                                visit_url='http://0.3.2.1/visit?'
                                          'dischargeid=1'
                            ),
                        ),
                    )

            try:
                httpbakery.discharge(
                    checkers.AuthContext(), content,
                    discharge_key, None, CheckerInError())
            except InteractionRequiredError as exc:
                return response(
                    status_code=401,
                    content={
                        'Code': exc.error.code,
                        'Message': exc.error.message,
                        'Info': {
                            'WaitURL': exc.error.info.wait_url,
                            'VisitURL': exc.error.info.visit_url,
                        },
                    },
                    headers={'Content-Type': 'application/json'})

        key = bakery.generate_key()

        @urlmatch(path='.*/visit')
        def visit(url, request):
            if request.headers.get('Accept') == 'application/json':
                return {
                    'status_code': 200,
                    'content': {
                        'agent': '/agent-visit',
                    }
                }
            raise Exception('unexpected call to visit without Accept header')

        @urlmatch(path='.*/agent-visit')
        def agent_visit(url, request):
            if request.method != "POST":
                raise Exception('unexpected method')
            log.info('agent_visit url %s', url)
            body = json.loads(request.body.decode('utf-8'))
            if body['username'] != 'test-user':
                raise Exception(
                    'unexpected username in body {!r}'.format(request.body))
            public_key = bakery.PublicKey.deserialize(body['public_key'])
            ms = httpbakery.extract_macaroons(request.headers)
            if not ms:
                # No macaroons yet: demand a discharge bound to the
                # agent's own public key.
                b = bakery.Bakery(key=discharge_key)
                m = b.oven.macaroon(
                    version=bakery.LATEST_VERSION,
                    expiry=datetime.utcnow() + timedelta(days=1),
                    caveats=[bakery.local_third_party_caveat(
                        public_key,
                        version=httpbakery.request_version(request.headers))],
                    ops=[bakery.Op(entity='agent', action='login')])
                content, headers = httpbakery.discharge_required_response(
                    m, '/',
                    'test',
                    'message')
                resp = response(status_code=401,
                                content=content,
                                headers=headers)
                return request.hooks['response'][0](resp)

            return {
                'status_code': 200,
                'content': {
                    'agent_login': True
                }
            }

        @urlmatch(path='.*/wait$')
        def wait(url, request):
            class EmptyChecker(bakery.ThirdPartyCaveatChecker):
                def check_third_party_caveat(self, ctx, info):
                    return []

            if InfoStorage.info is None:
                self.fail('visit url has not been visited')
            m = bakery.discharge(
                checkers.AuthContext(),
                InfoStorage.info.id,
                InfoStorage.info.caveat,
                discharge_key,
                EmptyChecker(),
                locator,
            )
            return {
                'status_code': 200,
                'content': {
                    'Macaroon': m.to_dict()
                }
            }

        with HTTMock(server_get), \
                HTTMock(discharge), \
                HTTMock(visit), \
                HTTMock(wait), \
                HTTMock(agent_visit):
            client = httpbakery.Client(interaction_methods=[
                agent.AgentInteractor(
                    agent.AuthInfo(
                        key=key,
                        agents=[agent.Agent(username='test-user',
                                            url=u'http://0.3.2.1')],
                    ),
                ),
            ])
            resp = requests.get(
                'http://0.1.2.3/here',
                cookies=client.cookies,
                auth=client.auth(),
            )
        self.assertEqual(resp.content, b'done')
# Valid agent file fixture: a key pair (matching PUBLIC_KEY / PRIVATE_KEY
# above) and three agents, each with a "url" and a "username".
agent_file = '''
{
"key": {
"public": "YAhRSsth3a36mRYqQGQaLiS4QJax0p356nd+B8x7UQE=",
"private": "CqoSgj06Zcgb4/S6RT4DpTjLAfKoznEY3JsShSjKJEU="
},
"agents": [{
"url": "https://1.example.com/",
"username": "user-1"
}, {
"url": "https://2.example.com/discharger",
"username": "user-2"
}, {
"url": "http://0.3.2.1",
"username": "test-user"
}]
}
'''
# Agent file fixture whose "private" key differs from the valid one (it ends
# in "JE==" rather than "JEU="); the tests expect loading it to raise
# AgentFileFormatError.
bad_key_agent_file = '''
{
"key": {
"public": "YAhRSsth3a36mRYqQGQaLiS4QJax0p356nd+B8x7UQE=",
"private": "CqoSgj06Zcgb4/S6RT4DpTjLAfKoznEY3JsShSjKJE=="
},
"agents": [{
"url": "https://1.example.com/",
"username": "user-1"
}, {
"url": "https://2.example.com/discharger",
"username": "user-2"
}]
}
'''
# Agent file fixture whose first agent entry has a "url" but no "username";
# the tests expect loading it to raise AgentFileFormatError.
no_username_agent_file = '''
{
"key": {
"public": "YAhRSsth3a36mRYqQGQaLiS4QJax0p356nd+B8x7UQE=",
"private": "CqoSgj06Zcgb4/S6RT4DpTjLAfKoznEY3JsShSjKJEU="
},
"agents": [{
"url": "https://1.example.com/"
}, {
"url": "https://2.example.com/discharger",
"username": "user-2"
}]
}
'''
class ThirdPartyCaveatCheckerF(bakery.ThirdPartyCaveatChecker):
    """Third-party caveat checker that delegates to a plain callable.

    The callable is invoked with the two values obtained by parsing the
    caveat condition, and its result is returned unchanged.
    """

    def __init__(self, check):
        """Remember ``check``, the callable taking (condition, argument)."""
        self._check = check

    def check_third_party_caveat(self, ctx, info):
        """Parse ``info.condition`` and hand both parts to the callable.

        ``ctx`` is accepted for interface compatibility and is not used.
        """
        parsed = checkers.parse_caveat(info.condition)
        condition, argument = parsed
        return self._check(condition, argument)
# Checker whose callable ignores the condition and argument and always
# returns an empty list; used by the discharge handler in test_agent_login.
alwaysOK3rd = ThirdPartyCaveatCheckerF(lambda cond, arg: [])