Fix capsul create tests, post-test cleanup, tidy merge

This commit is contained in:
3wc 2021-07-23 13:40:00 +02:00
parent da4d28f70c
commit 2b33573890
3 changed files with 129 additions and 98 deletions

View File

@ -26,7 +26,6 @@ from capsulflask.http_client import MyHTTPClient
def create_app(): def create_app():
for var_name in [ for var_name in [
"SPOKE_HOST_TOKEN", "HUB_TOKEN", "STRIPE_SECRET_KEY", "SPOKE_HOST_TOKEN", "HUB_TOKEN", "STRIPE_SECRET_KEY",
"BTCPAY_PRIVATE_KEY", "MAIL_PASSWORD" "BTCPAY_PRIVATE_KEY", "MAIL_PASSWORD"
@ -119,11 +118,11 @@ def create_app():
} }
}) })
# app.logger.critical("critical") # app.logger.critical("critical")
# app.logger.error("error") # app.logger.error("error")
# app.logger.warning("warning") # app.logger.warning("warning")
# app.logger.info("info") # app.logger.info("info")
# app.logger.debug("debug") # app.logger.debug("debug")
stripe.api_key = app.config['STRIPE_SECRET_KEY'] stripe.api_key = app.config['STRIPE_SECRET_KEY']
stripe.api_version = app.config['STRIPE_API_VERSION'] stripe.api_version = app.config['STRIPE_API_VERSION']
@ -136,89 +135,88 @@ def create_app():
app.config['HTTP_CLIENT'] = MyHTTPClient(timeout_seconds=int(app.config['INTERNAL_HTTP_TIMEOUT_SECONDS'])) app.config['HTTP_CLIENT'] = MyHTTPClient(timeout_seconds=int(app.config['INTERNAL_HTTP_TIMEOUT_SECONDS']))
app.config['BTCPAY_ENABLED'] = False app.config['BTCPAY_ENABLED'] = False
if app.config['BTCPAY_URL'] != "": if app.config['BTCPAY_URL'] != "":
try: try:
app.config['BTCPAY_CLIENT'] = btcpay.Client(api_uri=app.config['BTCPAY_URL'], pem=app.config['BTCPAY_PRIVATE_KEY']) app.config['BTCPAY_CLIENT'] = btcpay.Client(api_uri=app.config['BTCPAY_URL'], pem=app.config['BTCPAY_PRIVATE_KEY'])
app.config['BTCPAY_ENABLED'] = True app.config['BTCPAY_ENABLED'] = True
except: except:
app.logger.warning("unable to create btcpay client. Capsul will work fine except cryptocurrency payments will not work. The error was: " + my_exec_info_message(sys.exc_info())) app.logger.warning("unable to create btcpay client. Capsul will work fine except cryptocurrency payments will not work. The error was: " + my_exec_info_message(sys.exc_info()))
# only start the scheduler and attempt to migrate the database if we are running the app. # only start the scheduler and attempt to migrate the database if we are running the app.
# otherwise we are running a CLI command. # otherwise we are running a CLI command.
command_line = ' '.join(sys.argv) command_line = ' '.join(sys.argv)
is_running_server = ( is_running_server = (
('flask run' in command_line) or ('flask run' in command_line) or
('gunicorn' in command_line) or ('gunicorn' in command_line) or
('test' in command_line) ('test' in command_line)
) )
app.logger.info(f"is_running_server: {is_running_server}") app.logger.info(f"is_running_server: {is_running_server}")
if app.config['HUB_MODE_ENABLED']: if app.config['HUB_MODE_ENABLED']:
if app.config['HUB_MODEL'] == "capsul-flask": if app.config['HUB_MODEL'] == "capsul-flask":
app.config['HUB_MODEL'] = hub_model.CapsulFlaskHub() app.config['HUB_MODEL'] = hub_model.CapsulFlaskHub()
# debug mode (flask reloader) runs two copies of the app. When running in debug mode, # debug mode (flask reloader) runs two copies of the app. When running in debug mode,
# we only want to start the scheduler one time. # we only want to start the scheduler one time.
if is_running_server and (not app.debug or config.get('WERKZEUG_RUN_MAIN') == 'true'): if is_running_server and (not app.debug or config.get('WERKZEUG_RUN_MAIN') == 'true'):
scheduler = BackgroundScheduler() scheduler = BackgroundScheduler()
heartbeat_task_url = f"{app.config['HUB_URL']}/hub/heartbeat-task" heartbeat_task_url = f"{app.config['HUB_URL']}/hub/heartbeat-task"
heartbeat_task_headers = {'Authorization': f"Bearer {app.config['HUB_TOKEN']}"} heartbeat_task_headers = {'Authorization': f"Bearer {app.config['HUB_TOKEN']}"}
heartbeat_task = lambda: requests.post(heartbeat_task_url, headers=heartbeat_task_headers) heartbeat_task = lambda: requests.post(heartbeat_task_url, headers=heartbeat_task_headers)
scheduler.add_job(name="heartbeat-task", func=heartbeat_task, trigger="interval", seconds=5) scheduler.add_job(name="heartbeat-task", func=heartbeat_task, trigger="interval", seconds=5)
scheduler.start() scheduler.start()
atexit.register(lambda: scheduler.shutdown()) atexit.register(lambda: scheduler.shutdown())
else: else:
app.config['HUB_MODEL'] = hub_model.MockHub() app.config['HUB_MODEL'] = hub_model.MockHub()
from capsulflask import db from capsulflask import db
db.init_app(app, is_running_server) db.init_app(app, is_running_server)
from capsulflask import auth, landing, console, payment, metrics, cli, hub_api, admin from capsulflask import auth, landing, console, payment, metrics, cli, hub_api, admin
app.register_blueprint(landing.bp) app.register_blueprint(landing.bp)
app.register_blueprint(auth.bp) app.register_blueprint(auth.bp)
app.register_blueprint(console.bp) app.register_blueprint(console.bp)
app.register_blueprint(payment.bp) app.register_blueprint(payment.bp)
app.register_blueprint(metrics.bp) app.register_blueprint(metrics.bp)
app.register_blueprint(cli.bp) app.register_blueprint(cli.bp)
app.register_blueprint(hub_api.bp) app.register_blueprint(hub_api.bp)
app.register_blueprint(admin.bp) app.register_blueprint(admin.bp)
app.add_url_rule("/", endpoint="index") app.add_url_rule("/", endpoint="index")
if app.config['SPOKE_MODE_ENABLED']: if app.config['SPOKE_MODE_ENABLED']:
if app.config['SPOKE_MODEL'] == "shell-scripts": if app.config['SPOKE_MODEL'] == "shell-scripts":
app.config['SPOKE_MODEL'] = spoke_model.ShellScriptSpoke() app.config['SPOKE_MODEL'] = spoke_model.ShellScriptSpoke()
else: else:
app.config['SPOKE_MODEL'] = spoke_model.MockSpoke() app.config['SPOKE_MODEL'] = spoke_model.MockSpoke()
from capsulflask import spoke_api from capsulflask import spoke_api
app.register_blueprint(spoke_api.bp) app.register_blueprint(spoke_api.bp)
@app.after_request @app.after_request
def security_headers(response): def security_headers(response):
response.headers['X-Frame-Options'] = 'SAMEORIGIN' response.headers['X-Frame-Options'] = 'SAMEORIGIN'
if 'Content-Security-Policy' not in response.headers: if 'Content-Security-Policy' not in response.headers:
response.headers['Content-Security-Policy'] = "default-src 'self'" response.headers['Content-Security-Policy'] = "default-src 'self'"
response.headers['X-Content-Type-Options'] = 'nosniff' response.headers['X-Content-Type-Options'] = 'nosniff'
return response return response
@app.context_processor @app.context_processor
def override_url_for(): def override_url_for():
""" """
override the url_for function built into flask override the url_for function built into flask
with our own custom implementation that busts the cache correctly when files change with our own custom implementation that busts the cache correctly when files change
""" """
return dict(url_for=url_for_with_cache_bust) return dict(url_for=url_for_with_cache_bust)
return app
return app
def url_for_with_cache_bust(endpoint, **values): def url_for_with_cache_bust(endpoint, **values):
""" """
@ -271,4 +269,4 @@ class SetLogLevelToDebugForHeartbeatRelatedMessagesFilter(logging.Filter):
if self.isHeartbeatRelatedString(arg): if self.isHeartbeatRelatedString(arg):
return False return False
return True return True

View File

@ -271,7 +271,6 @@ def create():
) )
return redirect(f"{url_for('console.index')}?created={id}") return redirect(f"{url_for('console.index')}?created={id}")
for error in errors: for error in errors:
flash(error) flash(error)
@ -353,14 +352,10 @@ def ssh_public_keys():
) )
def get_vms(): def get_vms():
if 'user_vms' not in g: return get_model().list_vms_for_account(session["account"])
g.user_vms = get_model().list_vms_for_account(session["account"])
return g.user_vms
def get_payments(): def get_payments():
if 'user_payments' not in g: return get_model().list_payments_for_account(session["account"])
g.user_payments = get_model().list_payments_for_account(session["account"])
return g.user_payments
average_number_of_days_in_a_month = 30.44 average_number_of_days_in_a_month = 30.44

View File

@ -1,7 +1,12 @@
from copy import deepcopy
from unittest.mock import patch
from flask import url_for from flask import url_for
from capsulflask.db import get_model from capsulflask.db import get_model
from capsulflask.tests_base import BaseTestCase from capsulflask.tests_base import BaseTestCase
from capsulflask.hub_model import MockHub
class ConsoleTests(BaseTestCase): class ConsoleTests(BaseTestCase):
@ -30,7 +35,7 @@ class ConsoleTests(BaseTestCase):
response = client.get(url_for("console.create")) response = client.get(url_for("console.create"))
self.assert_200(response) self.assert_200(response)
def test_create_fails_capacity(self): def test_create_fails_credit(self):
with self.client as client: with self.client as client:
client.get(url_for("console.create")) client.get(url_for("console.create"))
csrf_token = self.get_context_variable('csrf_token') csrf_token = self.get_context_variable('csrf_token')
@ -38,6 +43,33 @@ class ConsoleTests(BaseTestCase):
data = self.capsul_data data = self.capsul_data
data['csrf-token'] = csrf_token data['csrf-token'] = csrf_token
client.post(url_for("console.create"), data=data) client.post(url_for("console.create"), data=data)
capacity_message = \
'Your account must have enough credit to run an f1-xs for 1 month before you will be allowed to create it'
self.assert_message_flashed(capacity_message, category='message')
self.assertEqual(
len(get_model().list_vms_for_account('test@example.com')),
0
)
def test_create_fails_capacity(self):
with self.client as client:
client.get(url_for("console.create"))
csrf_token = self.get_context_variable('csrf_token')
data = self.capsul_data
data['csrf-token'] = csrf_token
get_model().create_payment_session('fake', 'test', 'test@example.com', 20)
get_model().consume_payment_session('fake', 'test', 20)
with patch.object(MockHub, 'capacity_avaliable', return_value=False) as mock_method:
response = client.post(url_for("console.create"), data=data)
mock_method.assert_called()
capacity_message = \ capacity_message = \
'\n host(s) at capacity. no capsuls can be created at this time. sorry. \n ' '\n host(s) at capacity. no capsuls can be created at this time. sorry. \n '
self.assert_message_flashed(capacity_message, category='message') self.assert_message_flashed(capacity_message, category='message')
@ -52,7 +84,7 @@ class ConsoleTests(BaseTestCase):
client.get(url_for("console.create")) client.get(url_for("console.create"))
csrf_token = self.get_context_variable('csrf_token') csrf_token = self.get_context_variable('csrf_token')
data = self.capsul_data data = deepcopy(self.capsul_data)
data['csrf-token'] = csrf_token data['csrf-token'] = csrf_token
data['os'] = '' data['os'] = ''
client.post(url_for("console.create"), data=data) client.post(url_for("console.create"), data=data)
@ -72,23 +104,26 @@ class ConsoleTests(BaseTestCase):
client.get(url_for("console.create")) client.get(url_for("console.create"))
csrf_token = self.get_context_variable('csrf_token') csrf_token = self.get_context_variable('csrf_token')
data = self.capsul_data data = deepcopy(self.capsul_data)
data['csrf-token'] = csrf_token data['csrf-token'] = csrf_token
get_model().create_payment_session('fake', 'test', 'test@example.com', 20)
get_model().consume_payment_session('fake', 'test', 20)
response = client.post(url_for("console.create"), data=data) response = client.post(url_for("console.create"), data=data)
# FIXME: mock create doesn't create, see #83 vms = get_model().list_vms_for_account('test@example.com')
# vms = get_model().list_vms_for_account('test@example.com') self.assertEqual(
# self.assertEqual( len(vms),
# len(vms), 1
# 1 )
# )
# vm_id = vms[0]['id']
# vm_id = vms[0].id
# self.assertRedirects(
# self.assertRedirects( response,
# response, url_for("console.index") + f'?created={vm_id}'
# url_for("console.index") + f'?{vm_id}' )
# )
def test_keys_loads(self): def test_keys_loads(self):
self._login('test@example.com') self._login('test@example.com')
@ -143,10 +178,13 @@ class ConsoleTests(BaseTestCase):
category='message' category='message'
) )
def setUp(self): def setUp(self):
self._login('test@example.com') self._login('test@example.com')
get_model().create_ssh_public_key('test@example.com', 'key', 'foo') get_model().create_ssh_public_key('test@example.com', 'key', 'foo')
def tearDown(self): def tearDown(self):
get_model().delete_ssh_public_key('test@example.com', 'key') get_model().cursor.execute("DELETE FROM ssh_public_keys")
get_model().cursor.execute("DELETE FROM login_tokens")
get_model().cursor.execute("DELETE FROM vms")
get_model().cursor.execute("DELETE FROM payments")
get_model().cursor.connection.commit()