import matplotlib.ticker as ticker
import matplotlib.pyplot as pyplot
import matplotlib.dates as mdates
from functools import reduce 
import requests
import json
from datetime import datetime
from threading import Lock
from io import BytesIO
from flask import Blueprint
from flask import current_app
from flask import session
from flask import render_template, make_response
from werkzeug.exceptions import abort

from capsulflask.db import get_model
from capsulflask.auth import account_required

# Blueprint serving the VM metric pages and PNG graphs under /metrics.
bp = Blueprint("metrics", __name__, url_prefix="/metrics")
# Serializes graph rendering: get_plot_bytes holds this lock while calling
# draw_plot_png_bytes, which works through pyplot's global state.
mutex = Lock()

# Graph time windows, keyed by "_" + the duration URL segment.
# Each entry is [window length in seconds, base query step in seconds];
# get_plot_bytes multiplies the step by the size's step multiplier.
durations = {
  "_5m": [5*60, 15],
  "_1h": [60*60, 60],
  "_1d": [24*60*60, 20*60],
  "_30d": [30*24*60*60, 300*60],
}

# Graph sizes, keyed by the size URL segment.
# Each entry is [figure width, figure height, query step multiplier].
sizes = {
  "s": [0.77, 0.23, 4],
  "m": [1, 1, 2],
  "l": [6, 4, 1],
}

# RGB colours as floats in 0..1 (the form matplotlib accepts for colours).
green = tuple(channel/255 for channel in (121, 240, 50))
blue = tuple(channel/255 for channel in (70, 150, 255))
red = tuple(channel/255 for channel in (255, 50, 8))

gray = tuple(channel/255 for channel in (128, 128, 128))

@bp.route("/html/<string:metric>/<string:capsulid>/<string:duration>")
@account_required
def display_metric(metric, capsulid, duration):
  """Render the HTML page that shows one metric graph for one VM.

  Responds with 404 when the VM lookup for the session's account finds
  nothing, matching get_plot_bytes, instead of rendering the template with
  vm=None.
  """
  vm = get_model().get_vm_detail(session["account"], capsulid)
  if not vm:
    abort(404)

  return render_template(
    "display-metric.html",
    vm=vm,
    duration=duration,
    # Duration choices for the page, without the "_" used in the dict keys.
    durations=[key.strip("_") for key in durations],
    metric=metric
  )


@bp.route("/<string:metric>/<string:capsulid>/<string:duration>/<string:size>")
@account_required
def metric_png(metric, capsulid, duration, size):
  """Serve a metric graph as a PNG, aborting with the helper's status on failure."""
  status, png_bytes = get_plot_bytes(metric, capsulid, duration, size)
  if status != 200:
    abort(status)

  response = make_response(png_bytes)
  response.headers.set('Content-Type', 'image/png')
  return response


def get_plot_bytes(metric, capsulid, duration, size):
  """Query Prometheus for one VM metric and render it as PNG bytes.

  metric: one of "cpu", "memory", "network_in", "network_out" or "disk".
  capsulid: id of the VM, looked up for the account stored in the session.
  duration: a key of `durations` without its leading "_" (e.g. "1h").
  size: a key of `sizes` ("s", "m" or "l").

  Returns a (status, body) tuple:
    (200, png_bytes) on success;
    (404, None) for an unknown metric/duration/size or when the VM lookup
      finds nothing;
    (502, None) when Prometheus is unreachable, answers with a status
      >= 300, or returns a body without the expected data.result field.
  """
  duration = f"_{duration}"

  if duration not in durations or size not in sizes:
    return (404, None)

  vm = get_model().get_vm_detail(session["account"], capsulid)
  if not vm:
    return (404, None)

  # datetime.timestamp() is portable; strftime("%s") is a glibc extension.
  now_unix = int(datetime.now().timestamp())
  duration_seconds = durations[duration][0]
  # Smaller graphs use a larger multiplier (coarser step); never below 30s.
  interval_seconds = max(durations[duration][1] * sizes[size][2], 30)

  # Prometheus queries to pull metrics for VMs. These are plain PromQL:
  # requests URL-encodes them below, so "+" must not be pre-escaped.
  metric_queries = dict(
    cpu=f"irate(libvirtd_domain_info_cpu_time_seconds_total{{domain='{capsulid}'}}[30s])",
    memory=f"libvirtd_domain_info_maximum_memory_bytes{{domain='{capsulid}'}}-libvirtd_domain_info_memory_unused_bytes{{domain='{capsulid}'}}",
    network_in=f"rate(libvirtd_domain_interface_stats_receive_bytes_total{{domain='{capsulid}'}}[{interval_seconds}s])",
    network_out=f"rate(libvirtd_domain_interface_stats_transmit_bytes_total{{domain='{capsulid}'}}[{interval_seconds}s])",
    disk=f"rate(libvirtd_domain_block_stats_read_bytes_total{{domain='{capsulid}'}}[{interval_seconds}s])+rate(libvirtd_domain_block_stats_write_bytes_total{{domain='{capsulid}'}}[{interval_seconds}s])",
  )

  if metric not in metric_queries:
    return (404, None)

  # These represent the top of the graph for graphs that are designed to be viewed at a glance.
  # they are also used to colorize the graph at any size.
  scales = dict(
    cpu=vm["vcpus"],
    memory=vm["memory_mb"]*1024*1024,
    network_in=1024*1024*2,
    network_out=1024*200,
    disk=1024*1024*8,
  )

  try:
    prometheus_response = requests.get(
      f"{current_app.config['PROMETHEUS_URL']}/api/v1/query_range",
      params=dict(
        query=metric_queries[metric],
        start=now_unix - duration_seconds,
        end=now_unix,
        step=interval_seconds,
      ),
      # Without a timeout a stalled Prometheus would hang this request
      # (and the worker serving it) indefinitely.
      timeout=10,
    )
  except requests.RequestException:
    return (502, None)

  if prometheus_response.status_code >= 300:
    return (502, None)

  try:
    series = prometheus_response.json()["data"]["result"]
  except (ValueError, KeyError, TypeError):
    # Not JSON, or not shaped like a query_range response.
    return (502, None)

  if len(series) == 0:
    # No data: plot a flat zero line so the graph still renders.
    now_timestamp = datetime.now().timestamp()
    series = [
      dict(
        values=[[now_timestamp - interval_seconds, float(0)], [now_timestamp, float(0)]]
      )
    ]

  # Only the first returned series is plotted.
  # NOTE(review): queries matching several series (e.g. multiple
  # interfaces or disks) would have the rest ignored — confirm intended.
  time_series_data = [
    (datetime.fromtimestamp(timestamp), float(value))
    for timestamp, value in series[0]["values"]
  ]

  with mutex:
    plot_bytes = draw_plot_png_bytes(time_series_data, scale=scales[metric], size_x=sizes[size][0], size_y=sizes[size][1])

  return (200, plot_bytes)


def draw_plot_png_bytes(data, scale, size_x=3, size_y=1):
  """Render a time series as a PNG with a transparent background.

  data: list of (datetime, float) points; must not be empty.
  scale: the value treated as the top of the graph. It selects the y-axis
    unit (none, MB or GB) and the tick-label precision, and together with
    the average of the data it picks the fill colours.
  size_x, size_y: figure width/height passed to matplotlib's `figsize`.
    Small sizes hide tick labels.

  Returns the PNG bytes. This function uses pyplot's global state (style
  and figure registry); get_plot_bytes calls it while holding `mutex`.
  """
  # matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*"
  # and 3.8 removed the old names; fall back for older matplotlib.
  style_name = "seaborn-v0_8-dark"
  if style_name not in pyplot.style.available:
    style_name = "seaborn-dark"
  pyplot.style.use(style_name)

  fig, my_plot = pyplot.subplots(figsize=(size_x, size_y))
  try:
    # Choose a unit. The GB test is >= so that a scale of exactly 1 GiB
    # (e.g. a VM with memory_mb == 1024) gets a unit instead of falling
    # through both branches and being labelled in raw bytes.
    divide_by = 1
    unit = ""
    if scale >= 1024*1024*1024:
      divide_by = 1024*1024*1024
      unit = "GB"
    elif scale > 1024:
      divide_by = 1024*1024
      unit = "MB"

    scale /= divide_by

    # Fewer decimals on the y-axis labels for larger scales.
    if scale > 10:
      my_plot.get_yaxis().set_major_formatter(ticker.FuncFormatter(lambda v, pos: "{}{}".format(int(v), unit)))
    elif scale > 1:
      my_plot.get_yaxis().set_major_formatter(ticker.FuncFormatter(lambda v, pos: "{:.1f}{}".format(v, unit)))
    else:
      my_plot.get_yaxis().set_major_formatter(ticker.FuncFormatter(lambda v, pos: "{:.2f}{}".format(v, unit)))

    x = [point[0] for point in data]
    y = [point[1]/divide_by for point in data]

    minutes = (x[-1] - x[0]).total_seconds()/60.0
    hours = minutes/60.0
    days = hours/24.0

    # Pick x-axis tick spacing and label format from the time span shown.
    hour_minute_formatter = mdates.DateFormatter('%H:%M')
    if minutes < 10:
      my_plot.xaxis.set_major_locator(mdates.MinuteLocator())
      my_plot.xaxis.set_major_formatter(hour_minute_formatter)
    elif hours < 2:
      my_plot.xaxis.set_major_locator(mdates.MinuteLocator(interval=10))
      my_plot.xaxis.set_major_formatter(hour_minute_formatter)
    elif days < 2:
      my_plot.xaxis.set_major_locator(mdates.HourLocator(interval=6))
      my_plot.xaxis.set_major_formatter(hour_minute_formatter)
    else:
      my_plot.xaxis.set_major_locator(mdates.WeekdayLocator())
      my_plot.xaxis.set_major_formatter(mdates.DateFormatter('%b %d'))

    # Top of the background band: the scale, or the peak if the data exceeds it.
    max_value = max([scale, *y])

    if len(data) > 2:
      # Colour the graph by its average value relative to the scale.
      average = (sum(y)/len(y))/scale
      average = average*1.25+0.1

      bg_color = color_gradient(average)

      average -= 0.1

      fill_color = color_gradient(average)
      highlight_color = lerp_rgb_tuples(fill_color, (1, 1, 1), 0.5)
    else:
      # Two points or fewer (e.g. the zero placeholder get_plot_bytes builds
      # when Prometheus returns no series): draw everything in gray.
      bg_color = fill_color = highlight_color = gray

    my_plot.fill_between(x, max_value, color=bg_color, alpha=0.13)
    my_plot.fill_between(x, y, color=highlight_color, alpha=0.3)
    # The colour comes from the keyword only. A colour letter in the format
    # string as well ('r-') makes matplotlib warn that colour is set twice.
    my_plot.plot(x, y, '-', color=highlight_color)

    if size_y < 4:
      # Glanceable sizes: pin the y axis to [0, scale] with two ticks.
      my_plot.set_yticks([0, scale])
      my_plot.set_ylim(0, scale)

    my_plot.xaxis.label.set_color(highlight_color)
    my_plot.tick_params(axis='x', colors=highlight_color)
    my_plot.yaxis.label.set_color(highlight_color)
    my_plot.tick_params(axis='y', colors=highlight_color)

    if size_x < 4:
      my_plot.set_xticklabels([])
    if size_y < 1:
      my_plot.set_yticklabels([])

    image_binary = BytesIO()
    fig.savefig(image_binary, transparent=True, bbox_inches="tight", pad_inches=0.05)

    return image_binary.getvalue()
  finally:
    # Release the figure even when rendering raises. pyplot keeps every open
    # figure alive otherwise, which leaks memory in a long-running server.
    pyplot.close(fig)


def lerp_rgb_tuples(a, b, lerp):
  """Linearly interpolate between two RGB tuples.

  `lerp` is clamped to [0, 1]: 0 yields `a` and 1 yields `b`. Only the
  first three components of each input are read, and a 3-tuple is
  returned.
  """
  t = min(max(lerp, 0), 1)
  return tuple(a[i]*(1.0-t) + b[i]*t for i in range(3))

def color_gradient(value):
  """Map a value onto a green -> blue -> red colour ramp.

  `value` is clamped to [0, 1]. 0 gives green, 0.5 gives blue and 1 gives
  red, with linear blending between neighbouring stops.
  """
  clamped = min(max(value, 0), 1)
  if clamped < 0.5:
    return lerp_rgb_tuples(green, blue, clamped*2)
  return lerp_rgb_tuples(blue, red, (clamped-0.5)*2)