create TestHTTPClient that uses werkzueg test client, tests are passing
This commit is contained in:
@ -21,11 +21,9 @@ from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from capsulflask.shared import *
|
||||
from capsulflask import hub_model, spoke_model, cli
|
||||
from capsulflask.btcpay import client as btcpay
|
||||
from capsulflask.http_client import MyHTTPClient
|
||||
|
||||
|
||||
|
||||
def create_app():
|
||||
def create_app(http_client_factory):
|
||||
for var_name in [
|
||||
"SPOKE_HOST_TOKEN", "HUB_TOKEN", "STRIPE_SECRET_KEY",
|
||||
"BTCPAY_PRIVATE_KEY", "MAIL_PASSWORD"
|
||||
@ -133,7 +131,10 @@ def create_app():
|
||||
mylog_warning(app, "No MAIL_SERVER configured. capsul will simply print emails to stdout.")
|
||||
app.config['FLASK_MAIL_INSTANCE'] = StdoutMockFlaskMail()
|
||||
|
||||
app.config['HTTP_CLIENT'] = MyHTTPClient(timeout_seconds=int(app.config['INTERNAL_HTTP_TIMEOUT_SECONDS']))
|
||||
# allow a mock http client to be injected by the test code.
|
||||
app.config['HTTP_CLIENT'] = http_client_factory(int(app.config['INTERNAL_HTTP_TIMEOUT_SECONDS']))
|
||||
|
||||
|
||||
|
||||
app.config['BTCPAY_ENABLED'] = False
|
||||
if app.config['BTCPAY_URL'] != "":
|
||||
@ -160,7 +161,7 @@ def create_app():
|
||||
|
||||
# debug mode (flask reloader) runs two copies of the app. When running in debug mode,
|
||||
# we only want to start the scheduler one time.
|
||||
if is_running_server and (not app.debug or config.get('WERKZEUG_RUN_MAIN') == 'true'):
|
||||
if is_running_server and not app.config['TESTING'] and (not app.debug or config.get('WERKZEUG_RUN_MAIN') == 'true'):
|
||||
scheduler = BackgroundScheduler()
|
||||
heartbeat_task_url = f"{app.config['HUB_URL']}/hub/heartbeat-task"
|
||||
heartbeat_task_headers = {'Authorization': f"Bearer {app.config['HUB_TOKEN']}"}
|
||||
|
@ -195,13 +195,6 @@ def detail(id):
|
||||
@account_required
|
||||
def create():
|
||||
|
||||
#raise "console.create()!"
|
||||
# file_object = open('unittest-output.log', 'a')
|
||||
# file_object.write("console.create()!\n")
|
||||
# file_object.close()
|
||||
|
||||
mylog_error(current_app, "console.create()!")
|
||||
|
||||
vm_sizes = get_model().vm_sizes_dict()
|
||||
operating_systems = get_model().operating_systems_dict()
|
||||
public_keys_for_account = get_model().list_ssh_public_keys_for_account(session["account"])
|
||||
|
@ -196,7 +196,6 @@ class CapsulFlaskHub(VirtualizationInterface):
|
||||
def create(self, email: str, id: str, os: str, size: str, template_image_file_name: str, vcpus: int, memory_mb: int, ssh_authorized_keys: list):
|
||||
validate_capsul_id(id)
|
||||
online_hosts = get_model().get_online_hosts()
|
||||
mylog_debug(current_app, f"hub_model.create(): ${len(online_hosts)} hosts")
|
||||
payload = json.dumps(dict(
|
||||
type="create",
|
||||
email=email,
|
||||
|
@ -2,6 +2,7 @@ import re
|
||||
|
||||
from flask import current_app, Flask
|
||||
from typing import List
|
||||
from threading import Lock
|
||||
|
||||
class OnlineHost:
|
||||
def __init__(self, id: str, url: str):
|
||||
@ -58,8 +59,10 @@ def my_exec_info_message(exec_info):
|
||||
|
||||
|
||||
|
||||
|
||||
mylog_current_test_id_container = {
|
||||
'value': '',
|
||||
'mutex': Lock()
|
||||
}
|
||||
def set_mylog_test_id(test_id):
|
||||
mylog_current_test_id_container['value'] = ".".join(test_id.split(".")[-2:])
|
||||
@ -67,10 +70,11 @@ def set_mylog_test_id(test_id):
|
||||
|
||||
def log_output_for_tests(app: Flask, message: str):
|
||||
if app.config['TESTING'] and mylog_current_test_id_container['value'] != "":
|
||||
mylog_current_test_id_container['mutex'].acquire()
|
||||
file_object = open('unittest-log-output.log', 'a')
|
||||
file_object.write(f"{mylog_current_test_id_container['value']}: {message}\n")
|
||||
file_object.close()
|
||||
|
||||
mylog_current_test_id_container['mutex'].release()
|
||||
|
||||
def mylog_debug(app: Flask, message: str):
|
||||
log_output_for_tests(app, f"DEBUG: {message}")
|
||||
|
@ -43,9 +43,10 @@ def operation_without_id():
|
||||
|
||||
def operation_impl(operation_id: int):
|
||||
if authorized_as_hub(request.headers):
|
||||
request_body_json = request.json
|
||||
request_body = json.loads(request_body_json)
|
||||
#mylog_info(current_app, f"request.json: {request_body}")
|
||||
request_body = request.json
|
||||
if not isinstance(request.json, dict) and not isinstance(request.json, list):
|
||||
request_body = json.loads(request.json)
|
||||
|
||||
handlers = {
|
||||
"capacity_avaliable": handle_capacity_avaliable,
|
||||
"get": handle_get,
|
||||
|
@ -8,6 +8,7 @@ from flask import url_for
|
||||
|
||||
from capsulflask.db import get_model
|
||||
from capsulflask.tests_base import BaseTestCase
|
||||
from capsulflask.shared import *
|
||||
from capsulflask.spoke_model import MockSpoke
|
||||
|
||||
|
||||
@ -25,6 +26,30 @@ class ConsoleTests(BaseTestCase):
|
||||
"content": "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDntq1t8Ddsa2q4p+PM7W4CLYYmxakokRRVLlf7AQlsTJFPsgBe9u0zuoOaKDMkBr0dlnuLm4Eub1Mj+BrdqAokto0YDiAnxUKRuYQKuHySKK8bLkisi2k47jGBDikx/jihgiuFTawo1mYsJJepC7PPwZGsoCImJEgq1L+ug0p3Zrj3QkUx4h25MpCSs2yvfgWjDyN8hEC76O42P+4ETezYrzrd1Kj26hdzHRnrxygvIUOtfau+5ydlaz8xQBEPrEY6/+pKDuwtXg1pBL7GmoUxBXVfHQSgq5s9jIJH+G0CR0ZoHMB25Ln4X/bsCQbLOu21+IGYKSDVM5TIMLtkKUkERQMVWvnpOp1LZKir4dC0m7SW74wpA8+2b1IsURIr9ARYGJpCEv1Q1Wz/X3yTf6Mfey7992MjUc9HcgjgU01/+kYomoXHprzolk+22Gjfgo3a4dRIoTY82GO8kkUKiaWHvDkkVURCY5dpteLA05sk3Z9aRMYsNXPLeOOPfzTlDA0="
|
||||
}
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
|
||||
get_model().cursor.execute("DELETE FROM host_operation")
|
||||
get_model().cursor.execute("DELETE FROM operations")
|
||||
get_model().cursor.execute("DELETE FROM vm_ssh_host_key")
|
||||
get_model().cursor.execute("DELETE FROM vm_ssh_authorized_key")
|
||||
get_model().cursor.execute("DELETE FROM ssh_public_keys")
|
||||
get_model().cursor.execute("DELETE FROM login_tokens")
|
||||
get_model().cursor.execute("DELETE FROM vms")
|
||||
get_model().cursor.execute("DELETE FROM payments")
|
||||
get_model().cursor.connection.commit()
|
||||
|
||||
self._login('test@example.com')
|
||||
get_model().create_ssh_public_key('test@example.com', 'key', 'foo')
|
||||
|
||||
# heartbeat all the spokes so that the hub <--> spoke communication can work as normal.
|
||||
host_ids = get_model().list_hosts_with_networks(None).keys()
|
||||
for host_id in host_ids:
|
||||
get_model().host_heartbeat(host_id)
|
||||
|
||||
|
||||
|
||||
def test_index(self):
|
||||
self._login('test@example.com')
|
||||
with self.client as client:
|
||||
@ -84,10 +109,6 @@ class ConsoleTests(BaseTestCase):
|
||||
0
|
||||
)
|
||||
|
||||
file_object = open('unittest-output.log', 'a')
|
||||
file_object.write(f"{self.id()} captured output:\n{self.logs_from_test.getvalue()}\n")
|
||||
file_object.close()
|
||||
|
||||
def test_create_fails_invalid(self):
|
||||
with self.client as client:
|
||||
client.get(url_for("console.create"))
|
||||
@ -121,9 +142,8 @@ class ConsoleTests(BaseTestCase):
|
||||
|
||||
response = client.post(url_for("console.create"), data=data)
|
||||
|
||||
|
||||
# mylog_info(self.app, get_model().list_all_operations())
|
||||
|
||||
|
||||
self.assertEqual(
|
||||
len(get_model().list_all_operations()),
|
||||
1
|
||||
@ -196,16 +216,4 @@ class ConsoleTests(BaseTestCase):
|
||||
category='message'
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self._login('test@example.com')
|
||||
get_model().create_ssh_public_key('test@example.com', 'key', 'foo')
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
get_model().cursor.execute("DELETE FROM ssh_public_keys")
|
||||
get_model().cursor.execute("DELETE FROM login_tokens")
|
||||
get_model().cursor.execute("DELETE FROM vms")
|
||||
get_model().cursor.execute("DELETE FROM payments")
|
||||
get_model().cursor.connection.commit()
|
||||
|
||||
|
@ -1,14 +1,26 @@
|
||||
from io import StringIO
|
||||
|
||||
import logging
|
||||
import unittest
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import itertools
|
||||
import time
|
||||
import threading
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from urllib.parse import urlparse
|
||||
from typing import List
|
||||
from nanoid import generate
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from flask_testing import TestCase
|
||||
from flask import current_app
|
||||
|
||||
from capsulflask import create_app
|
||||
from capsulflask.db import get_model
|
||||
from capsulflask.http_client import *
|
||||
from capsulflask.shared import *
|
||||
|
||||
class BaseTestCase(TestCase):
|
||||
@ -19,7 +31,9 @@ class BaseTestCase(TestCase):
|
||||
os.environ['LOG_LEVEL'] = 'DEBUG'
|
||||
os.environ['SPOKE_MODEL'] = 'mock'
|
||||
os.environ['HUB_MODEL'] = 'capsul-flask'
|
||||
self.app = create_app()
|
||||
self1 = self
|
||||
get_app = lambda: self1.app
|
||||
self.app = create_app(lambda timeout_seconds: TestHTTPClient(get_app, timeout_seconds))
|
||||
return self.app
|
||||
|
||||
def setUp(self):
|
||||
@ -33,3 +47,73 @@ class BaseTestCase(TestCase):
|
||||
with self.client.session_transaction() as session:
|
||||
session['account'] = user_email
|
||||
session['csrf-token'] = generate()
|
||||
|
||||
class TestHTTPClient:
|
||||
def __init__(self, get_app, timeout_seconds = 5):
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self.get_app = get_app
|
||||
self.executor = ThreadPoolExecutor()
|
||||
|
||||
|
||||
def do_multi_http_sync(self, online_hosts: List[OnlineHost], url_suffix: str, body: str, authorization_header=None) -> List[HTTPResult]:
|
||||
future = run_coroutine(self.do_multi_http(online_hosts=online_hosts, url_suffix=url_suffix, body=body, authorization_header=authorization_header))
|
||||
fromOtherThread = future.result()
|
||||
toReturn = []
|
||||
for individualResult in fromOtherThread:
|
||||
if individualResult.error != None and individualResult.error != "":
|
||||
mylog_error(self.get_app(), individualResult.error)
|
||||
toReturn.append(individualResult.http_result)
|
||||
|
||||
return toReturn
|
||||
|
||||
def do_http_sync(self, url: str, body: str, method="POST", authorization_header=None) -> HTTPResult:
|
||||
future = run_coroutine(self.do_http(method=method, url=url, body=body, authorization_header=authorization_header))
|
||||
fromOtherThread = future.result()
|
||||
if fromOtherThread.error != None and fromOtherThread.error != "":
|
||||
mylog_error(self.get_app(), fromOtherThread.error)
|
||||
return fromOtherThread.http_result
|
||||
|
||||
async def do_http(self, url: str, body: str, method="POST", authorization_header=None) -> InterThreadResult:
|
||||
path = urlparse(url).path
|
||||
|
||||
headers = {}
|
||||
if authorization_header != None:
|
||||
headers['Authorization'] = authorization_header
|
||||
if body:
|
||||
headers['Content-Type'] = "application/json"
|
||||
|
||||
#mylog_info(self.get_app(), f"path, data=body, headers=headers: {path}, {body}, {headers}")
|
||||
|
||||
do_request = None
|
||||
if method == "POST":
|
||||
do_request = lambda: self.get_app().test_client().post(path, data=body, headers=headers)
|
||||
if method == "GET":
|
||||
do_request = lambda: self.get_app().test_client().get(path, headers=headers)
|
||||
|
||||
response = None
|
||||
try:
|
||||
response = await get_event_loop().run_in_executor(self.executor, do_request)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
error_message = my_exec_info_message(sys.exc_info())
|
||||
response_body = json.dumps({"error_message": f"do_http (HTTP {method} {url}) {error_message}"})
|
||||
|
||||
return InterThreadResult(
|
||||
HTTPResult(-1, response_body),
|
||||
f"""do_http (HTTP {method} {url}) failed with: {error_message}"""
|
||||
)
|
||||
|
||||
return InterThreadResult(HTTPResult(response.status_code, response.get_data()), None)
|
||||
|
||||
async def do_multi_http(self, online_hosts: List[OnlineHost], url_suffix: str, body: str, authorization_header=None) -> List[InterThreadResult]:
|
||||
tasks = []
|
||||
# append to tasks in the same order as online_hosts
|
||||
for host in online_hosts:
|
||||
tasks.append(
|
||||
self.do_http(url=f"{host.url}{url_suffix}", body=body, authorization_header=authorization_header)
|
||||
)
|
||||
# gather is like Promise.all from javascript, it returns a future which resolves to an array of results
|
||||
# in the same order as the tasks that we passed in -- which were in the same order as online_hosts
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
return results
|
Reference in New Issue
Block a user