create TestHTTPClient that uses werkzeug test client, tests are passing

forest 2021-07-31 15:28:42 -05:00 committed by 3wc
parent 16ff1b5b26
commit 5f9fc1adcf
9 changed files with 130 additions and 39 deletions
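
In short: `create_app()` now takes an `http_client_factory` callable instead of constructing `MyHTTPClient` itself, so the test suite can inject an HTTP client that routes hub <--> spoke requests through the werkzeug test client in-process. A rough sketch of the two call sites, condensed from the diffs below (not exact code):

```python
# Production wiring (app.py): build the real outbound HTTP client.
from capsulflask import create_app
from capsulflask.http_client import MyHTTPClient

app = create_app(lambda timeout_seconds: MyHTTPClient(timeout_seconds=timeout_seconds))

# Test wiring (tests_base.py): build a TestHTTPClient bound to the app under test,
# so "spoke" requests never leave the process:
# app = create_app(lambda timeout_seconds: TestHTTPClient(get_app, timeout_seconds))
```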

app.py
View File

@@ -1,4 +1,5 @@
from capsulflask import create_app
+from capsulflask.http_client import MyHTTPClient
-app = create_app()
+app = create_app(lambda timeout_seconds: MyHTTPClient(timeout_seconds=timeout_seconds))

View File

@@ -21,11 +21,9 @@ from apscheduler.schedulers.background import BackgroundScheduler
from capsulflask.shared import *
from capsulflask import hub_model, spoke_model, cli
from capsulflask.btcpay import client as btcpay
-from capsulflask.http_client import MyHTTPClient
-def create_app():
+def create_app(http_client_factory):
    for var_name in [
        "SPOKE_HOST_TOKEN", "HUB_TOKEN", "STRIPE_SECRET_KEY",
        "BTCPAY_PRIVATE_KEY", "MAIL_PASSWORD"
@@ -133,7 +131,10 @@ def create_app():
        mylog_warning(app, "No MAIL_SERVER configured. capsul will simply print emails to stdout.")
        app.config['FLASK_MAIL_INSTANCE'] = StdoutMockFlaskMail()
-    app.config['HTTP_CLIENT'] = MyHTTPClient(timeout_seconds=int(app.config['INTERNAL_HTTP_TIMEOUT_SECONDS']))
+    # allow a mock http client to be injected by the test code.
+    app.config['HTTP_CLIENT'] = http_client_factory(int(app.config['INTERNAL_HTTP_TIMEOUT_SECONDS']))
    app.config['BTCPAY_ENABLED'] = False
    if app.config['BTCPAY_URL'] != "":
@@ -160,7 +161,7 @@ def create_app():
    # debug mode (flask reloader) runs two copies of the app. When running in debug mode,
    # we only want to start the scheduler one time.
-    if is_running_server and (not app.debug or config.get('WERKZEUG_RUN_MAIN') == 'true'):
+    if is_running_server and not app.config['TESTING'] and (not app.debug or config.get('WERKZEUG_RUN_MAIN') == 'true'):
        scheduler = BackgroundScheduler()
        heartbeat_task_url = f"{app.config['HUB_URL']}/hub/heartbeat-task"
        heartbeat_task_headers = {'Authorization': f"Bearer {app.config['HUB_TOKEN']}"}

View File

@@ -195,13 +195,6 @@ def detail(id):
@account_required
def create():
-    #raise "console.create()!"
-    # file_object = open('unittest-output.log', 'a')
-    # file_object.write("console.create()!\n")
-    # file_object.close()
-    mylog_error(current_app, "console.create()!")
    vm_sizes = get_model().vm_sizes_dict()
    operating_systems = get_model().operating_systems_dict()
    public_keys_for_account = get_model().list_ssh_public_keys_for_account(session["account"])

View File

@@ -196,7 +196,6 @@ class CapsulFlaskHub(VirtualizationInterface):
    def create(self, email: str, id: str, os: str, size: str, template_image_file_name: str, vcpus: int, memory_mb: int, ssh_authorized_keys: list):
        validate_capsul_id(id)
        online_hosts = get_model().get_online_hosts()
-        mylog_debug(current_app, f"hub_model.create(): ${len(online_hosts)} hosts")
        payload = json.dumps(dict(
            type="create",
            email=email,

View File

@@ -2,6 +2,7 @@ import re
from flask import current_app, Flask
from typing import List
+from threading import Lock
class OnlineHost:
    def __init__(self, id: str, url: str):
@@ -58,8 +59,10 @@ def my_exec_info_message(exec_info):
mylog_current_test_id_container = {
    'value': '',
+    'mutex': Lock()
}
def set_mylog_test_id(test_id):
    mylog_current_test_id_container['value'] = ".".join(test_id.split(".")[-2:])
@@ -67,10 +70,11 @@ def set_mylog_test_id(test_id):
def log_output_for_tests(app: Flask, message: str):
    if app.config['TESTING'] and mylog_current_test_id_container['value'] != "":
+        mylog_current_test_id_container['mutex'].acquire()
        file_object = open('unittest-log-output.log', 'a')
        file_object.write(f"{mylog_current_test_id_container['value']}: {message}\n")
        file_object.close()
+        mylog_current_test_id_container['mutex'].release()
def mylog_debug(app: Flask, message: str):
    log_output_for_tests(app, f"DEBUG: {message}")

View File

@@ -43,9 +43,10 @@ def operation_without_id():
def operation_impl(operation_id: int):
    if authorized_as_hub(request.headers):
-        request_body_json = request.json
-        request_body = json.loads(request_body_json)
-        #mylog_info(current_app, f"request.json: {request_body}")
+        request_body = request.json
+        if not isinstance(request.json, dict) and not isinstance(request.json, list):
+            request_body = json.loads(request.json)
        handlers = {
            "capacity_avaliable": handle_capacity_avaliable,
            "get": handle_get,

View File

@@ -8,6 +8,7 @@ from flask import url_for
from capsulflask.db import get_model
from capsulflask.tests_base import BaseTestCase
+from capsulflask.shared import *
from capsulflask.spoke_model import MockSpoke
@@ -25,6 +26,30 @@ class ConsoleTests(BaseTestCase):
        "content": "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDntq1t8Ddsa2q4p+PM7W4CLYYmxakokRRVLlf7AQlsTJFPsgBe9u0zuoOaKDMkBr0dlnuLm4Eub1Mj+BrdqAokto0YDiAnxUKRuYQKuHySKK8bLkisi2k47jGBDikx/jihgiuFTawo1mYsJJepC7PPwZGsoCImJEgq1L+ug0p3Zrj3QkUx4h25MpCSs2yvfgWjDyN8hEC76O42P+4ETezYrzrd1Kj26hdzHRnrxygvIUOtfau+5ydlaz8xQBEPrEY6/+pKDuwtXg1pBL7GmoUxBXVfHQSgq5s9jIJH+G0CR0ZoHMB25Ln4X/bsCQbLOu21+IGYKSDVM5TIMLtkKUkERQMVWvnpOp1LZKir4dC0m7SW74wpA8+2b1IsURIr9ARYGJpCEv1Q1Wz/X3yTf6Mfey7992MjUc9HcgjgU01/+kYomoXHprzolk+22Gjfgo3a4dRIoTY82GO8kkUKiaWHvDkkVURCY5dpteLA05sk3Z9aRMYsNXPLeOOPfzTlDA0="
    }
+    def setUp(self):
+        super().setUp()
+        get_model().cursor.execute("DELETE FROM host_operation")
+        get_model().cursor.execute("DELETE FROM operations")
+        get_model().cursor.execute("DELETE FROM vm_ssh_host_key")
+        get_model().cursor.execute("DELETE FROM vm_ssh_authorized_key")
+        get_model().cursor.execute("DELETE FROM ssh_public_keys")
+        get_model().cursor.execute("DELETE FROM login_tokens")
+        get_model().cursor.execute("DELETE FROM vms")
+        get_model().cursor.execute("DELETE FROM payments")
+        get_model().cursor.connection.commit()
+        self._login('test@example.com')
+        get_model().create_ssh_public_key('test@example.com', 'key', 'foo')
+        # heartbeat all the spokes so that the hub <--> spoke communication can work as normal.
+        host_ids = get_model().list_hosts_with_networks(None).keys()
+        for host_id in host_ids:
+            get_model().host_heartbeat(host_id)
    def test_index(self):
        self._login('test@example.com')
        with self.client as client:
@@ -84,10 +109,6 @@ class ConsoleTests(BaseTestCase):
            0
        )
-        file_object = open('unittest-output.log', 'a')
-        file_object.write(f"{self.id()} captured output:\n{self.logs_from_test.getvalue()}\n")
-        file_object.close()
    def test_create_fails_invalid(self):
        with self.client as client:
            client.get(url_for("console.create"))
@@ -121,9 +142,8 @@ class ConsoleTests(BaseTestCase):
            response = client.post(url_for("console.create"), data=data)
-            # mylog_info(self.app, get_model().list_all_operations())
        self.assertEqual(
            len(get_model().list_all_operations()),
            1
@@ -196,16 +216,4 @@ class ConsoleTests(BaseTestCase):
            category='message'
        )
-    def setUp(self):
-        super().setUp()
-        self._login('test@example.com')
-        get_model().create_ssh_public_key('test@example.com', 'key', 'foo')
-    def tearDown(self):
-        super().tearDown()
-        get_model().cursor.execute("DELETE FROM ssh_public_keys")
-        get_model().cursor.execute("DELETE FROM login_tokens")
-        get_model().cursor.execute("DELETE FROM vms")
-        get_model().cursor.execute("DELETE FROM payments")
-        get_model().cursor.connection.commit()

View File

@@ -1,14 +1,26 @@
+from io import StringIO
import logging
import unittest
import os
+import sys
+import json
+import itertools
+import time
+import threading
+import asyncio
+import traceback
+from urllib.parse import urlparse
+from typing import List
from nanoid import generate
+from concurrent.futures import ThreadPoolExecutor
from flask_testing import TestCase
from flask import current_app
from capsulflask import create_app
from capsulflask.db import get_model
+from capsulflask.http_client import *
from capsulflask.shared import *
class BaseTestCase(TestCase):
@@ -19,7 +31,9 @@ class BaseTestCase(TestCase):
        os.environ['LOG_LEVEL'] = 'DEBUG'
        os.environ['SPOKE_MODEL'] = 'mock'
        os.environ['HUB_MODEL'] = 'capsul-flask'
-        self.app = create_app()
+        self1 = self
+        get_app = lambda: self1.app
+        self.app = create_app(lambda timeout_seconds: TestHTTPClient(get_app, timeout_seconds))
        return self.app

    def setUp(self):
@@ -33,3 +47,73 @@ class BaseTestCase(TestCase):
        with self.client.session_transaction() as session:
            session['account'] = user_email
            session['csrf-token'] = generate()
+
+class TestHTTPClient:
+    def __init__(self, get_app, timeout_seconds = 5):
+        self.timeout_seconds = timeout_seconds
+        self.get_app = get_app
+        self.executor = ThreadPoolExecutor()
+
+    def do_multi_http_sync(self, online_hosts: List[OnlineHost], url_suffix: str, body: str, authorization_header=None) -> List[HTTPResult]:
+        future = run_coroutine(self.do_multi_http(online_hosts=online_hosts, url_suffix=url_suffix, body=body, authorization_header=authorization_header))
+        fromOtherThread = future.result()
+        toReturn = []
+        for individualResult in fromOtherThread:
+            if individualResult.error != None and individualResult.error != "":
+                mylog_error(self.get_app(), individualResult.error)
+            toReturn.append(individualResult.http_result)
+        return toReturn
+
+    def do_http_sync(self, url: str, body: str, method="POST", authorization_header=None) -> HTTPResult:
+        future = run_coroutine(self.do_http(method=method, url=url, body=body, authorization_header=authorization_header))
+        fromOtherThread = future.result()
+        if fromOtherThread.error != None and fromOtherThread.error != "":
+            mylog_error(self.get_app(), fromOtherThread.error)
+        return fromOtherThread.http_result
+
+    async def do_http(self, url: str, body: str, method="POST", authorization_header=None) -> InterThreadResult:
+        path = urlparse(url).path
+        headers = {}
+        if authorization_header != None:
+            headers['Authorization'] = authorization_header
+        if body:
+            headers['Content-Type'] = "application/json"
+        #mylog_info(self.get_app(), f"path, data=body, headers=headers: {path}, {body}, {headers}")
+        do_request = None
+        if method == "POST":
+            do_request = lambda: self.get_app().test_client().post(path, data=body, headers=headers)
+        if method == "GET":
+            do_request = lambda: self.get_app().test_client().get(path, headers=headers)
+        response = None
+        try:
+            response = await get_event_loop().run_in_executor(self.executor, do_request)
+        except:
+            traceback.print_exc()
+            error_message = my_exec_info_message(sys.exc_info())
+            response_body = json.dumps({"error_message": f"do_http (HTTP {method} {url}) {error_message}"})
+            return InterThreadResult(
+                HTTPResult(-1, response_body),
+                f"""do_http (HTTP {method} {url}) failed with: {error_message}"""
+            )
+        return InterThreadResult(HTTPResult(response.status_code, response.get_data()), None)
+
+    async def do_multi_http(self, online_hosts: List[OnlineHost], url_suffix: str, body: str, authorization_header=None) -> List[InterThreadResult]:
+        tasks = []
+        # append to tasks in the same order as online_hosts
+        for host in online_hosts:
+            tasks.append(
+                self.do_http(url=f"{host.url}{url_suffix}", body=body, authorization_header=authorization_header)
+            )
+        # gather is like Promise.all from javascript, it returns a future which resolves to an array of results
+        # in the same order as the tasks that we passed in -- which were in the same order as online_hosts
+        results = await asyncio.gather(*tasks)
+        return results
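
The comments above lean on the fact that `asyncio.gather` resolves to results in the same order as the awaitables passed to it, regardless of completion order, which is what lets `do_multi_http_sync` line results up with `online_hosts`. A tiny standalone demonstration of that property (independent of the classes above):

```python
import asyncio

async def work(i):
    # Finish in reverse order to show gather still preserves argument order.
    await asyncio.sleep(0.01 * (3 - i))
    return i

async def main():
    results = await asyncio.gather(*[work(i) for i in range(3)])
    assert results == [0, 1, 2]  # same order as the awaitables passed in

asyncio.run(main())
```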

View File

@@ -6,7 +6,7 @@ To run tests:
1. create a Postgres database called `capsulflask_test`
   - e.g.: `docker exec -it 98e1ddfbbffb createdb -U postgres -O postgres capsulflask_test`
   - (`98e1ddfbbffb` is the docker container ID of the postgres container)
-2. run `python3 -m unittest && cat unittest-log-output.log && rm unittest-log-output.log`
+2. run `python3 -m unittest; cat unittest-log-output.log; rm unittest-log-output.log`

**NOTE** that right now we can't figure out how to get the tests to properly output the log messages that happened when they failed, (or passed), so for now we have hacked it to write to a file.