forked from 3wordchant/capsul-flask
120 lines
4.4 KiB
Python
120 lines
4.4 KiB
Python
|
|
import logging
|
|
import unittest
|
|
import os
|
|
import sys
|
|
import json
|
|
import itertools
|
|
import time
|
|
import threading
|
|
import asyncio
|
|
import traceback
|
|
|
|
from urllib.parse import urlparse
|
|
from typing import List
|
|
from nanoid import generate
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
from flask_testing import TestCase
|
|
from flask import current_app
|
|
|
|
from capsulflask import create_app
|
|
from capsulflask.db import get_model
|
|
from capsulflask.http_client import *
|
|
from capsulflask.shared import *
|
|
|
|
class BaseTestCase(TestCase):
    """Shared base for capsul-flask tests.

    Builds the app against the local test database and gives it a TestHTTPClient,
    so HTTP requests the app makes are dispatched back into the app under test.
    """

    def create_app(self):
        """flask_testing hook: configure the environment, then build and return the app."""
        test_environment = {
            # default local connection parameters
            'POSTGRES_CONNECTION_PARAMETERS': "host=localhost port=5432 user=postgres password=dev dbname=capsulflask_test",
            'TESTING': 'True',
            'LOG_LEVEL': 'DEBUG',
            'SPOKE_MODEL': 'mock',
            'HUB_MODEL': 'capsul-flask',
        }
        for variable_name, variable_value in test_environment.items():
            os.environ[variable_name] = variable_value

        def http_client_factory(timeout_seconds):
            # self.app is read each time a request is made, not captured here:
            # it is only assigned after create_app() below returns.
            return TestHTTPClient(lambda: self.app, timeout_seconds)

        self.app = create_app(http_client_factory)
        return self.app

    def setUp(self):
        # Record the running test's id via set_mylog_test_id (presumably used to tag log output).
        set_mylog_test_id(self.id())

    def tearDown(self):
        # Clear the test id set in setUp.
        set_mylog_test_id("")

    def _login(self, user_email):
        """Log ``user_email`` in for the test client.

        Calls the model's ``login`` for the address, then stores the account and a
        freshly generated CSRF token in the test client's session.
        """
        get_model().login(user_email)
        csrf_token = generate()
        with self.client.session_transaction() as session:
            session['account'] = user_email
            session['csrf-token'] = csrf_token
|
|
|
|
class TestHTTPClient:
    """Test replacement for the app's HTTP client.

    Requests are not sent over the network. Each one is dispatched to the Flask app's
    ``test_client()`` on a worker thread, so outgoing calls made while a test runs are
    served by the app returned from ``get_app``.
    """

    def __init__(self, get_app, timeout_seconds=5):
        """
        :param get_app: zero-argument callable returning the Flask app to send requests
            to. It is called on every request rather than once, so the app may be
            resolved after this client is constructed.
        :param timeout_seconds: stored on the instance. NOTE(review): nothing in this
            class currently uses it to bound a request.
        """
        self.timeout_seconds = timeout_seconds
        self.get_app = get_app
        # Flask test-client calls are synchronous. Running them on this pool keeps the
        # event loop free, so several requests can be awaited together.
        self.executor = ThreadPoolExecutor()

    def do_multi_http_sync(self, online_hosts: List[OnlineHost], url_suffix: str, body: str, authorization_header=None) -> List[HTTPResult]:
        """Send the same request to every host and block until all have finished.

        Any per-request error is logged via ``mylog_error``. The corresponding
        ``HTTPResult`` (status -1 on failure) is still included in the output.

        :returns: one ``HTTPResult`` per host, in the same order as ``online_hosts``.
        """
        # Assumes run_coroutine schedules the coroutine on another thread's event loop
        # and returns a concurrent future. TODO confirm against capsulflask.shared.
        future = run_coroutine(self.do_multi_http(online_hosts=online_hosts, url_suffix=url_suffix, body=body, authorization_header=authorization_header))
        inter_thread_results = future.result()

        http_results = []
        for result in inter_thread_results:
            if result.error is not None and result.error != "":
                mylog_error(self.get_app(), result.error)
            http_results.append(result.http_result)
        return http_results

    def do_http_sync(self, url: str, body: str, method="POST", authorization_header=None) -> HTTPResult:
        """Send one request, block until it completes, and return its ``HTTPResult``.

        An error reported by the request is logged via ``mylog_error``. The result
        (status -1 on failure) is still returned.
        """
        future = run_coroutine(self.do_http(method=method, url=url, body=body, authorization_header=authorization_header))
        inter_thread_result = future.result()

        if inter_thread_result.error is not None and inter_thread_result.error != "":
            mylog_error(self.get_app(), inter_thread_result.error)
        return inter_thread_result.http_result

    @staticmethod
    def _failure_result(method: str, url: str, error_message: str) -> InterThreadResult:
        """Build the result returned when a request could not be performed.

        The HTTP result has status -1 and a JSON body with an ``error_message`` key.
        The ``error`` string is what the *_sync wrappers log.
        """
        response_body = json.dumps({"error_message": f"do_http (HTTP {method} {url}) {error_message}"})
        return InterThreadResult(
            HTTPResult(-1, response_body),
            f"""do_http (HTTP {method} {url}) failed with: {error_message}""",
        )

    async def do_http(self, url: str, body: str, method="POST", authorization_header=None) -> InterThreadResult:
        """Send one request to the in-process app and wrap the outcome.

        Only the path and query string of ``url`` are used. The scheme and host are
        ignored because the request goes straight to ``get_app()``.

        :param body: sent as the POST body, with ``Content-Type: application/json``
            when non-empty. It is not sent for GET.
        :param method: "POST" or "GET". Any other method yields a failure result.
        :returns: an ``InterThreadResult``. On success, ``http_result`` holds the status
            code and the response data from ``get_data()``, and ``error`` is None. On
            failure, the status is -1 and ``error`` describes the problem.
        """
        parsed_url = urlparse(url)
        path = parsed_url.path
        if parsed_url.query:
            # Werkzeug's test client treats everything after "?" in the path as the query
            # string. Without this, query parameters in `url` would be silently dropped.
            path = f"{path}?{parsed_url.query}"

        headers = {}
        if authorization_header is not None:
            headers['Authorization'] = authorization_header
        if body:
            headers['Content-Type'] = "application/json"

        if method == "POST":
            do_request = lambda: self.get_app().test_client().post(path, data=body, headers=headers)
        elif method == "GET":
            do_request = lambda: self.get_app().test_client().get(path, headers=headers)
        else:
            # Report unsupported methods explicitly. Handing a None callable to the
            # executor would only fail with an opaque "'NoneType' object is not callable".
            return self._failure_result(method, url, f"unsupported HTTP method {method!r} (only GET and POST are supported)")

        try:
            response = await get_event_loop().run_in_executor(self.executor, do_request)
        except Exception:
            # Catch Exception rather than everything: asyncio.CancelledError and
            # KeyboardInterrupt are BaseExceptions and must propagate.
            traceback.print_exc()
            return self._failure_result(method, url, my_exec_info_message(sys.exc_info()))

        return InterThreadResult(HTTPResult(response.status_code, response.get_data()), None)

    async def do_multi_http(self, online_hosts: List[OnlineHost], url_suffix: str, body: str, authorization_header=None) -> List[InterThreadResult]:
        """Send the same request to each host concurrently.

        :returns: one ``InterThreadResult`` per host, in the same order as ``online_hosts``.
        """
        requests = [
            self.do_http(url=f"{host.url}{url_suffix}", body=body, authorization_header=authorization_header)
            for host in online_hosts
        ]
        # asyncio.gather (like Promise.all) returns results in the order the awaitables
        # were passed in, which is the same order as online_hosts.
        return await asyncio.gather(*requests)
|