diff --git a/Pipfile b/Pipfile index 6a595d5..ba69f0c 100644 --- a/Pipfile +++ b/Pipfile @@ -9,6 +9,7 @@ blinker = "==1.4" click = "==7.1.2" Flask = "==1.1.2" Flask-Mail = "==0.9.1" +Flask-Testing = "==0.8.1" gunicorn = "==20.0.4" isort = "==4.3.21" itsdangerous = "==1.1.0" diff --git a/app.py b/app.py index 9424b19..6c953f4 100644 --- a/app.py +++ b/app.py @@ -1,2 +1,4 @@ -from capsulflask import app +from capsulflask import create_app + +create_app() diff --git a/capsulflask/__init__.py b/capsulflask/__init__.py index d88ed91..5f25c45 100644 --- a/capsulflask/__init__.py +++ b/capsulflask/__init__.py @@ -8,7 +8,7 @@ import requests import sys import stripe -from dotenv import load_dotenv, find_dotenv +from dotenv import find_dotenv, dotenv_values from flask import Flask from flask_mail import Mail, Message from flask import render_template @@ -22,106 +22,112 @@ from capsulflask import hub_model, spoke_model, cli from capsulflask.btcpay import client as btcpay from capsulflask.http_client import MyHTTPClient + class StdoutMockFlaskMail: def send(self, message: Message): current_app.logger.info(f"Email would have been sent if configured:\n\nto: {','.join(message.recipients)}\nsubject: {message.subject}\nbody:\n\n{message.body}\n\n") -load_dotenv(find_dotenv()) +def create_app(): -app = Flask(__name__) + config = { + **dotenv_values(find_dotenv()), + **os.environ, # override loaded values with environment variables + } -app.config.from_mapping( - - BASE_URL=os.environ.get("BASE_URL", default="http://localhost:5000"), - SECRET_KEY=os.environ.get("SECRET_KEY", default="dev"), - HUB_MODE_ENABLED=os.environ.get("HUB_MODE_ENABLED", default="True").lower() in ['true', '1', 't', 'y', 'yes'], - SPOKE_MODE_ENABLED=os.environ.get("SPOKE_MODE_ENABLED", default="True").lower() in ['true', '1', 't', 'y', 'yes'], - INTERNAL_HTTP_TIMEOUT_SECONDS=os.environ.get("INTERNAL_HTTP_TIMEOUT_SECONDS", default="300"), - HUB_MODEL=os.environ.get("HUB_MODEL", default="capsul-flask"), - SPOKE_MODEL=os.environ.get("SPOKE_MODEL", default="mock"), - LOG_LEVEL=os.environ.get("LOG_LEVEL", default="INFO"), - SPOKE_HOST_ID=os.environ.get("SPOKE_HOST_ID", default="baikal"), - SPOKE_HOST_TOKEN=os.environ.get("SPOKE_HOST_TOKEN", default="changeme"), - HUB_TOKEN=os.environ.get("HUB_TOKEN", default="changeme"), + app = Flask(__name__) - # https://www.postgresql.org/docs/9.1/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS - # https://stackoverflow.com/questions/56332906/where-to-put-ssl-certificates-when-trying-to-connect-to-a-remote-database-using - # TLS example: sslmode=verify-full sslrootcert=letsencrypt-root-ca.crt host=db.example.com port=5432 user=postgres password=dev dbname=postgres - POSTGRES_CONNECTION_PARAMETERS=os.environ.get( - "POSTGRES_CONNECTION_PARAMETERS", - default="host=localhost port=5432 user=postgres password=dev dbname=postgres" - ), + app.config.from_mapping( + TESTING=config.get("TESTING", False), + BASE_URL=config.get("BASE_URL", "http://localhost:5000"), + SECRET_KEY=config.get("SECRET_KEY", "dev"), + HUB_MODE_ENABLED=config.get("HUB_MODE_ENABLED", "True").lower() in ['true', '1', 't', 'y', 'yes'], + SPOKE_MODE_ENABLED=config.get("SPOKE_MODE_ENABLED", "True").lower() in ['true', '1', 't', 'y', 'yes'], + INTERNAL_HTTP_TIMEOUT_SECONDS=config.get("INTERNAL_HTTP_TIMEOUT_SECONDS", "300"), + HUB_MODEL=config.get("HUB_MODEL", "capsul-flask"), + SPOKE_MODEL=config.get("SPOKE_MODEL", "mock"), + LOG_LEVEL=config.get("LOG_LEVEL", "INFO"), + SPOKE_HOST_ID=config.get("SPOKE_HOST_ID", "baikal"), + SPOKE_HOST_TOKEN=config.get("SPOKE_HOST_TOKEN", "changeme"), + HUB_TOKEN=config.get("HUB_TOKEN", "changeme"), - DATABASE_SCHEMA=os.environ.get("DATABASE_SCHEMA", default="public"), + # https://www.postgresql.org/docs/9.1/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS + # https://stackoverflow.com/questions/56332906/where-to-put-ssl-certificates-when-trying-to-connect-to-a-remote-database-using + # TLS example: sslmode=verify-full sslrootcert=letsencrypt-root-ca.crt host=db.example.com port=5432 user=postgres password=dev dbname=postgres + POSTGRES_CONNECTION_PARAMETERS=config.get( + "POSTGRES_CONNECTION_PARAMETERS", + "host=localhost port=5432 user=postgres password=dev dbname=postgres" + ), - MAIL_SERVER=os.environ.get("MAIL_SERVER", default=""), - MAIL_PORT=os.environ.get("MAIL_PORT", default="465"), - MAIL_USE_TLS=os.environ.get("MAIL_USE_TLS", default="False").lower() in ['true', '1', 't', 'y', 'yes'], - MAIL_USE_SSL=os.environ.get("MAIL_USE_SSL", default="True").lower() in ['true', '1', 't', 'y', 'yes'], - MAIL_USERNAME=os.environ.get("MAIL_USERNAME", default=""), - MAIL_PASSWORD=os.environ.get("MAIL_PASSWORD", default=""), - MAIL_DEFAULT_SENDER=os.environ.get("MAIL_DEFAULT_SENDER", default="no-reply@capsul.org"), - ADMIN_EMAIL_ADDRESSES=os.environ.get("ADMIN_EMAIL_ADDRESSES", default="ops@cyberia.club"), - ADMIN_PANEL_ALLOW_EMAIL_ADDRESSES=os.environ.get("ADMIN_PANEL_ALLOW_EMAIL_ADDRESSES", default="forest.n.johnson@gmail.com,capsul@cyberia.club"), + DATABASE_SCHEMA=config.get("DATABASE_SCHEMA", "public"), - PROMETHEUS_URL=os.environ.get("PROMETHEUS_URL", default="https://prometheus.cyberia.club"), + MAIL_SERVER=config.get("MAIL_SERVER", ""), + MAIL_PORT=config.get("MAIL_PORT", "465"), + MAIL_USE_TLS=config.get("MAIL_USE_TLS", "False").lower() in ['true', '1', 't', 'y', 'yes'], + MAIL_USE_SSL=config.get("MAIL_USE_SSL", "True").lower() in ['true', '1', 't', 'y', 'yes'], + MAIL_USERNAME=config.get("MAIL_USERNAME", ""), + MAIL_PASSWORD=config.get("MAIL_PASSWORD", ""), + MAIL_DEFAULT_SENDER=config.get("MAIL_DEFAULT_SENDER", "no-reply@capsul.org"), + ADMIN_EMAIL_ADDRESSES=config.get("ADMIN_EMAIL_ADDRESSES", "ops@cyberia.club"), + ADMIN_PANEL_ALLOW_EMAIL_ADDRESSES=config.get("ADMIN_PANEL_ALLOW_EMAIL_ADDRESSES", "forest.n.johnson@gmail.com,capsul@cyberia.club"), - STRIPE_API_VERSION=os.environ.get("STRIPE_API_VERSION", default="2020-03-02"), - STRIPE_SECRET_KEY=os.environ.get("STRIPE_SECRET_KEY", default=""), - STRIPE_PUBLISHABLE_KEY=os.environ.get("STRIPE_PUBLISHABLE_KEY", default=""), - #STRIPE_WEBHOOK_SECRET=os.environ.get("STRIPE_WEBHOOK_SECRET", default="") + PROMETHEUS_URL=config.get("PROMETHEUS_URL", "https://prometheus.cyberia.club"), - BTCPAY_PRIVATE_KEY=os.environ.get("BTCPAY_PRIVATE_KEY", default="").replace("\\n", "\n"), - BTCPAY_URL=os.environ.get("BTCPAY_URL", default="https://btcpay.cyberia.club") -) + STRIPE_API_VERSION=config.get("STRIPE_API_VERSION", "2020-03-02"), + STRIPE_SECRET_KEY=config.get("STRIPE_SECRET_KEY", ""), + STRIPE_PUBLISHABLE_KEY=config.get("STRIPE_PUBLISHABLE_KEY", ""), + #STRIPE_WEBHOOK_SECRET=config.get("STRIPE_WEBHOOK_SECRET", "") -app.config['HUB_URL'] = os.environ.get("HUB_URL", default=app.config['BASE_URL']) + BTCPAY_PRIVATE_KEY=config.get("BTCPAY_PRIVATE_KEY", "").replace("\\n", "\n"), + BTCPAY_URL=config.get("BTCPAY_URL", "https://btcpay.cyberia.club") + ) -class SetLogLevelToDebugForHeartbeatRelatedMessagesFilter(logging.Filter): - def isHeartbeatRelatedString(self, thing): - # thing_string = "" - is_in_string = False - try: - thing_string = "%s" % thing - is_in_string = 'heartbeat-task' in thing_string or 'hub/heartbeat' in thing_string or 'spoke/heartbeat' in thing_string - except: - pass - # self.warning("isHeartbeatRelatedString(%s): %s", thing_string, is_in_string ) - return is_in_string + app.config['HUB_URL'] = config.get("HUB_URL", app.config['BASE_URL']) + + class SetLogLevelToDebugForHeartbeatRelatedMessagesFilter(logging.Filter): + def isHeartbeatRelatedString(self, thing): + # thing_string = "" + is_in_string = False + try: + thing_string = "%s" % thing + is_in_string = 'heartbeat-task' in thing_string or 'hub/heartbeat' in thing_string or 'spoke/heartbeat' in thing_string + except: + pass + # self.warning("isHeartbeatRelatedString(%s): %s", thing_string, is_in_string ) + return is_in_string + + def filter(self, record): + if app.config['LOG_LEVEL'] == "DEBUG": + return True + + if self.isHeartbeatRelatedString(record.msg): + return False + for arg in record.args: + if self.isHeartbeatRelatedString(arg): + return False - def filter(self, record): - if app.config['LOG_LEVEL'] == "DEBUG": return True - if self.isHeartbeatRelatedString(record.msg): - return False - for arg in record.args: - if self.isHeartbeatRelatedString(arg): - return False - - return True - -logging_dict_config({ - 'version': 1, - 'formatters': {'default': { - 'format': '[%(asctime)s] %(levelname)s in %(module)s: %(message)s', - }}, - 'filters': { - 'setLogLevelToDebugForHeartbeatRelatedMessages': { - '()': SetLogLevelToDebugForHeartbeatRelatedMessagesFilter, + logging_dict_config({ + 'version': 1, + 'formatters': {'default': { + 'format': '[%(asctime)s] %(levelname)s in %(module)s: %(message)s', + }}, + 'filters': { + 'setLogLevelToDebugForHeartbeatRelatedMessages': { + '()': SetLogLevelToDebugForHeartbeatRelatedMessagesFilter, + } + }, + 'handlers': {'wsgi': { + 'class': 'logging.StreamHandler', + 'stream': 'ext://flask.logging.wsgi_errors_stream', + 'formatter': 'default', + 'filters': ['setLogLevelToDebugForHeartbeatRelatedMessages'] + }}, + 'root': { + 'level': app.config['LOG_LEVEL'], + 'handlers': ['wsgi'] } - }, - 'handlers': {'wsgi': { - 'class': 'logging.StreamHandler', - 'stream': 'ext://flask.logging.wsgi_errors_stream', - 'formatter': 'default', - 'filters': ['setLogLevelToDebugForHeartbeatRelatedMessages'] - }}, - 'root': { - 'level': app.config['LOG_LEVEL'], - 'handlers': ['wsgi'] - } -}) + }) # app.logger.critical("critical") # app.logger.error("error") @@ -129,118 +135,123 @@ logging_dict_config({ # app.logger.info("info") # app.logger.debug("debug") -stripe.api_key = app.config['STRIPE_SECRET_KEY'] -stripe.api_version = app.config['STRIPE_API_VERSION'] + stripe.api_key = app.config['STRIPE_SECRET_KEY'] + stripe.api_version = app.config['STRIPE_API_VERSION'] -if app.config['MAIL_SERVER'] != "": - app.config['FLASK_MAIL_INSTANCE'] = Mail(app) -else: - app.logger.warning("No MAIL_SERVER configured. capsul will simply print emails to stdout.") - app.config['FLASK_MAIL_INSTANCE'] = StdoutMockFlaskMail() + if app.config['MAIL_SERVER'] != "": + app.config['FLASK_MAIL_INSTANCE'] = Mail(app) + else: + app.logger.warning("No MAIL_SERVER configured. capsul will simply print emails to stdout.") + app.config['FLASK_MAIL_INSTANCE'] = StdoutMockFlaskMail() -app.config['HTTP_CLIENT'] = MyHTTPClient(timeout_seconds=int(app.config['INTERNAL_HTTP_TIMEOUT_SECONDS'])) + app.config['HTTP_CLIENT'] = MyHTTPClient(timeout_seconds=int(app.config['INTERNAL_HTTP_TIMEOUT_SECONDS'])) -try: - app.config['BTCPAY_CLIENT'] = btcpay.Client(api_uri=app.config['BTCPAY_URL'], pem=app.config['BTCPAY_PRIVATE_KEY']) -except: - app.logger.warning("unable to create btcpay client. Capsul will work fine except cryptocurrency payments will not work. The error was: " + my_exec_info_message(sys.exc_info())) + try: + app.config['BTCPAY_CLIENT'] = btcpay.Client(api_uri=app.config['BTCPAY_URL'], pem=app.config['BTCPAY_PRIVATE_KEY']) + except: + app.logger.warning("unable to create btcpay client. Capsul will work fine except cryptocurrency payments will not work. The error was: " + my_exec_info_message(sys.exc_info())) # only start the scheduler and attempt to migrate the database if we are running the app. # otherwise we are running a CLI command. -command_line = ' '.join(sys.argv) -is_running_server = ('flask run' in command_line) or ('gunicorn' in command_line) + command_line = ' '.join(sys.argv) + is_running_server = ( + ('flask run' in command_line) or + ('gunicorn' in command_line) or + ('test' in command_line) + ) -app.logger.info(f"is_running_server: {is_running_server}") + app.logger.info(f"is_running_server: {is_running_server}") -if app.config['HUB_MODE_ENABLED']: - if app.config['HUB_MODEL'] == "capsul-flask": - app.config['HUB_MODEL'] = hub_model.CapsulFlaskHub() + if app.config['HUB_MODE_ENABLED']: + if app.config['HUB_MODEL'] == "capsul-flask": + app.config['HUB_MODEL'] = hub_model.CapsulFlaskHub() - # debug mode (flask reloader) runs two copies of the app. When running in debug mode, - # we only want to start the scheduler one time. - if is_running_server and (not app.debug or os.environ.get('WERKZEUG_RUN_MAIN') == 'true'): - scheduler = BackgroundScheduler() - heartbeat_task_url = f"{app.config['HUB_URL']}/hub/heartbeat-task" - heartbeat_task_headers = {'Authorization': f"Bearer {app.config['HUB_TOKEN']}"} - heartbeat_task = lambda: requests.post(heartbeat_task_url, headers=heartbeat_task_headers) - scheduler.add_job(name="heartbeat-task", func=heartbeat_task, trigger="interval", seconds=5) - scheduler.start() + # debug mode (flask reloader) runs two copies of the app. When running in debug mode, + # we only want to start the scheduler one time. + if is_running_server and (not app.debug or config.get('WERKZEUG_RUN_MAIN') == 'true'): + scheduler = BackgroundScheduler() + heartbeat_task_url = f"{app.config['HUB_URL']}/hub/heartbeat-task" + heartbeat_task_headers = {'Authorization': f"Bearer {app.config['HUB_TOKEN']}"} + heartbeat_task = lambda: requests.post(heartbeat_task_url, headers=heartbeat_task_headers) + scheduler.add_job(name="heartbeat-task", func=heartbeat_task, trigger="interval", seconds=5) + scheduler.start() - atexit.register(lambda: scheduler.shutdown()) + atexit.register(lambda: scheduler.shutdown()) - else: - app.config['HUB_MODEL'] = hub_model.MockHub() + else: + app.config['HUB_MODEL'] = hub_model.MockHub() - from capsulflask import db - db.init_app(app, is_running_server) + from capsulflask import db + db.init_app(app, is_running_server) from capsulflask import ( auth, landing, console, payment, metrics, cli, hub_api, publicapi, admin ) - app.register_blueprint(landing.bp) - app.register_blueprint(auth.bp) - app.register_blueprint(console.bp) - app.register_blueprint(payment.bp) - app.register_blueprint(metrics.bp) - app.register_blueprint(cli.bp) - app.register_blueprint(hub_api.bp) - app.register_blueprint(admin.bp) + app.register_blueprint(landing.bp) + app.register_blueprint(auth.bp) + app.register_blueprint(console.bp) + app.register_blueprint(payment.bp) + app.register_blueprint(metrics.bp) + app.register_blueprint(cli.bp) + app.register_blueprint(hub_api.bp) + app.register_blueprint(admin.bp) app.register_blueprint(publicapi.bp) - app.add_url_rule("/", endpoint="index") + app.add_url_rule("/", endpoint="index") + + if app.config['SPOKE_MODE_ENABLED']: + if app.config['SPOKE_MODEL'] == "shell-scripts": + app.config['SPOKE_MODEL'] = spoke_model.ShellScriptSpoke() + else: + app.config['SPOKE_MODEL'] = spoke_model.MockSpoke() + + from capsulflask import spoke_api + + app.register_blueprint(spoke_api.bp) + + @app.after_request + def security_headers(response): + response.headers['X-Frame-Options'] = 'SAMEORIGIN' + if 'Content-Security-Policy' not in response.headers: + response.headers['Content-Security-Policy'] = "default-src 'self'" + response.headers['X-Content-Type-Options'] = 'nosniff' + return response - -if app.config['SPOKE_MODE_ENABLED']: - if app.config['SPOKE_MODEL'] == "shell-scripts": - app.config['SPOKE_MODEL'] = spoke_model.ShellScriptSpoke() - else: - app.config['SPOKE_MODEL'] = spoke_model.MockSpoke() - - from capsulflask import spoke_api - - app.register_blueprint(spoke_api.bp) - -@app.after_request -def security_headers(response): - response.headers['X-Frame-Options'] = 'SAMEORIGIN' - if 'Content-Security-Policy' not in response.headers: - response.headers['Content-Security-Policy'] = "default-src 'self'" - response.headers['X-Content-Type-Options'] = 'nosniff' - return response + @app.context_processor + def override_url_for(): + """ + override the url_for function built into flask + with our own custom implementation that busts the cache correctly when files change + """ + return dict(url_for=url_for_with_cache_bust) -@app.context_processor -def override_url_for(): - """ - override the url_for function built into flask - with our own custom implementation that busts the cache correctly when files change - """ - return dict(url_for=url_for_with_cache_bust) - - -def url_for_with_cache_bust(endpoint, **values): - """ - Add a query parameter based on the hash of the file, this acts as a cache bust - """ - - if endpoint == 'static': - filename = values.get('filename', None) - if filename: - if 'STATIC_FILE_HASH_CACHE' not in current_app.config: - current_app.config['STATIC_FILE_HASH_CACHE'] = dict() - - if filename not in current_app.config['STATIC_FILE_HASH_CACHE']: - filepath = os.path.join(current_app.root_path, endpoint, filename) - #print(filepath) - if os.path.isfile(filepath) and os.access(filepath, os.R_OK): - - with open(filepath, 'rb') as file: - hasher = hashlib.md5() - hasher.update(file.read()) - current_app.config['STATIC_FILE_HASH_CACHE'][filename] = hasher.hexdigest()[-6:] + def url_for_with_cache_bust(endpoint, **values): + """ + Add a query parameter based on the hash of the file, this acts as a cache bust + """ + + if endpoint == 'static': + filename = values.get('filename', None) + if filename: + if 'STATIC_FILE_HASH_CACHE' not in current_app.config: + current_app.config['STATIC_FILE_HASH_CACHE'] = dict() + + if filename not in current_app.config['STATIC_FILE_HASH_CACHE']: + filepath = os.path.join(current_app.root_path, endpoint, filename) + #print(filepath) + if os.path.isfile(filepath) and os.access(filepath, os.R_OK): - values['q'] = current_app.config['STATIC_FILE_HASH_CACHE'][filename] + with open(filepath, 'rb') as file: + hasher = hashlib.md5() + hasher.update(file.read()) + current_app.config['STATIC_FILE_HASH_CACHE'][filename] = hasher.hexdigest()[-6:] + + values['q'] = current_app.config['STATIC_FILE_HASH_CACHE'][filename] - return url_for(endpoint, **values) + return url_for(endpoint, **values) + + return app +>>>>>>> tests diff --git a/capsulflask/auth.py b/capsulflask/auth.py index f8a2a74..75bdd6f 100644 --- a/capsulflask/auth.py +++ b/capsulflask/auth.py @@ -41,7 +41,6 @@ def account_required(view): return wrapped_view - def admin_account_required(view): """View decorator that redirects non-admin users to the login page.""" diff --git a/capsulflask/tests/__init__.py b/capsulflask/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/capsulflask/tests/test_auth.py b/capsulflask/tests/test_auth.py new file mode 100644 index 0000000..7886c8a --- /dev/null +++ b/capsulflask/tests/test_auth.py @@ -0,0 +1,23 @@ +from flask import url_for, session + +from capsulflask.db import get_model +from capsulflask.tests_base import BaseTestCase + + +class LoginTests(BaseTestCase): + render_templates = False + + def test_login_request(self): + with self.client as client: + response = client.get(url_for("auth.login")) + self.assert_200(response) + + # FIXME test generated login link + + def test_login_magiclink(self): + token, ignoreCaseMatches = get_model().login('test@example.com') + + with self.client as client: + response = client.get(url_for("auth.magiclink", token=token)) + self.assertRedirects(response, url_for("console.index")) + self.assertEqual(session['account'], 'test@example.com') diff --git a/capsulflask/tests/test_console.py b/capsulflask/tests/test_console.py new file mode 100644 index 0000000..7d4abee --- /dev/null +++ b/capsulflask/tests/test_console.py @@ -0,0 +1,154 @@ +from flask import url_for + +from capsulflask.db import get_model +from capsulflask.tests_base import BaseTestCase + + +class ConsoleTests(BaseTestCase): + capsul_data = { + "size": "f1-xs", + "os": "debian10", + "ssh_authorized_key_count": 1, + "ssh_key_0": "key" + } + + ssh_key_data = { + "name": "key2", + "method": "POST", + "content": "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDntq1t8Ddsa2q4p+PM7W4CLYYmxakokRRVLlf7AQlsTJFPsgBe9u0zuoOaKDMkBr0dlnuLm4Eub1Mj+BrdqAokto0YDiAnxUKRuYQKuHySKK8bLkisi2k47jGBDikx/jihgiuFTawo1mYsJJepC7PPwZGsoCImJEgq1L+ug0p3Zrj3QkUx4h25MpCSs2yvfgWjDyN8hEC76O42P+4ETezYrzrd1Kj26hdzHRnrxygvIUOtfau+5ydlaz8xQBEPrEY6/+pKDuwtXg1pBL7GmoUxBXVfHQSgq5s9jIJH+G0CR0ZoHMB25Ln4X/bsCQbLOu21+IGYKSDVM5TIMLtkKUkERQMVWvnpOp1LZKir4dC0m7SW74wpA8+2b1IsURIr9ARYGJpCEv1Q1Wz/X3yTf6Mfey7992MjUc9HcgjgU01/+kYomoXHprzolk+22Gjfgo3a4dRIoTY82GO8kkUKiaWHvDkkVURCY5dpteLA05sk3Z9aRMYsNXPLeOOPfzTlDA0=" + } + + def test_index(self): + self._login('test@example.com') + with self.client as client: + response = client.get(url_for("console.index")) + self.assert_200(response) + + def test_create_loads(self): + self._login('test@example.com') + with self.client as client: + response = client.get(url_for("console.create")) + self.assert_200(response) + + def test_create_fails_capacity(self): + with self.client as client: + client.get(url_for("console.create")) + csrf_token = self.get_context_variable('csrf_token') + + data = self.capsul_data + data['csrf-token'] = csrf_token + client.post(url_for("console.create"), data=data) + capacity_message = \ + '\n host(s) at capacity. no capsuls can be created at this time. sorry. \n ' + self.assert_message_flashed(capacity_message, category='message') + + self.assertEqual( + len(get_model().list_vms_for_account('test@example.com')), + 0 + ) + + def test_create_fails_invalid(self): + with self.client as client: + client.get(url_for("console.create")) + csrf_token = self.get_context_variable('csrf_token') + + data = self.capsul_data + data['csrf-token'] = csrf_token + data['os'] = '' + client.post(url_for("console.create"), data=data) + + self.assert_message_flashed( + 'OS is required', + category='message' + ) + + self.assertEqual( + len(get_model().list_vms_for_account('test@example.com')), + 0 + ) + + def test_create_succeeds(self): + with self.client as client: + client.get(url_for("console.create")) + csrf_token = self.get_context_variable('csrf_token') + + data = self.capsul_data + data['csrf-token'] = csrf_token + response = client.post(url_for("console.create"), data=data) + + vms = get_model().list_vms_for_account('test@example.com') + + self.assertEqual( + len(vms), + 1 # FIXME: mock create doesn't create, see #83 + ) + + return + + vm_id = vms[0].id + + self.assertRedirects( + response, + url_for("console.index") + f'?{vm_id}' + ) + + def test_keys_loads(self): + self._login('test@example.com') + with self.client as client: + response = client.get(url_for("console.ssh_public_keys")) + self.assert_200(response) + keys = self.get_context_variable('ssh_public_keys') + self.assertEqual(keys[0]['name'], 'key') + + def test_keys_add_fails_invalid(self): + self._login('test@example.com') + with self.client as client: + client.get(url_for("console.ssh_public_keys")) + csrf_token = self.get_context_variable('csrf_token') + + data = self.ssh_key_data + data['csrf-token'] = csrf_token + + data_invalid_content = data + data_invalid_content['content'] = 'foo' + client.post( + url_for("console.ssh_public_keys"), + data=data_invalid_content + ) + + self.assert_message_flashed( + 'Content must match "^(ssh|ecdsa)-[0-9A-Za-z+/_=@:. -]+$"', + category='message' + ) + + data_missing_content = data + data_missing_content['content'] = '' + client.post(url_for("console.ssh_public_keys"), data=data_missing_content) + + self.assert_message_flashed( + 'Content is required', category='message' + ) + + def test_keys_add_fails_duplicate(self): + self._login('test@example.com') + with self.client as client: + client.get(url_for("console.ssh_public_keys")) + csrf_token = self.get_context_variable('csrf_token') + + data = self.ssh_key_data + data['csrf-token'] = csrf_token + data['name'] = 'key' + client.post(url_for("console.ssh_public_keys"), data=data) + + self.assert_message_flashed( + 'A key with that name already exists', + category='message' + ) + + + def setUp(self): + self._login('test@example.com') + get_model().create_ssh_public_key('test@example.com', 'key', 'foo') + + def tearDown(self): + get_model().delete_ssh_public_key('test@example.com', 'key') diff --git a/capsulflask/tests/test_landing.py b/capsulflask/tests/test_landing.py new file mode 100644 index 0000000..8ac6f12 --- /dev/null +++ b/capsulflask/tests/test_landing.py @@ -0,0 +1,14 @@ +from capsulflask.tests_base import BaseTestCase + + +class LandingTests(BaseTestCase): + #: Do not render templates, we're only testing logic here. + render_templates = False + + def test_landing(self): + pages = ['/', 'pricing', 'faq', 'about-ssh', 'changelog', 'support'] + + with self.client as client: + for page in pages: + response = client.get(page) + self.assert_200(response) diff --git a/capsulflask/tests_base.py b/capsulflask/tests_base.py new file mode 100644 index 0000000..8067d5e --- /dev/null +++ b/capsulflask/tests_base.py @@ -0,0 +1,28 @@ +import os +from nanoid import generate + +from flask_testing import TestCase + +from capsulflask import create_app +from capsulflask.db import get_model + +class BaseTestCase(TestCase): + def create_app(self): + # Use default connection paramaters + os.environ['POSTGRES_CONNECTION_PARAMETERS'] = "host=localhost port=5432 user=postgres password=dev dbname=capsulflask_test" + os.environ['TESTING'] = '1' + os.environ['SPOKE_MODEL'] = 'mock' + os.environ['HUB_MODEL'] = 'mock' + return create_app() + + def setUp(self): + pass + + def tearDown(self): + pass + + def _login(self, user_email): + get_model().login(user_email) + with self.client.session_transaction() as session: + session['account'] = user_email + session['csrf-token'] = generate()