import matplotlib.ticker as ticker
import matplotlib.pyplot as pyplot
import matplotlib.dates as mdates
from functools import reduce 
import requests
#import json
from datetime import datetime
from io import BytesIO
from flask import Blueprint
from flask import current_app
from flask import session
from flask import render_template, make_response
from werkzeug.exceptions import abort

from capsulflask.db import get_model
from capsulflask.auth import account_required

# Blueprint that groups the metric views in this module under the /metrics URL prefix.
bp = Blueprint("metrics", __name__, url_prefix="/metrics")

# Selectable time windows, keyed by "_" + the duration segment of the URL.
# Each value is [window length in seconds, base sampling step in seconds].
durations = {
  "_5m": [5*60, 15],
  "_1h": [60*60, 60],
  "_1d": [24*60*60, 20*60],
  "_30d": [30*24*60*60, 300*60],
}

# Selectable graph sizes, keyed by the size segment of the URL.
# Each value is [figure width, figure height, multiplier applied to the
# base sampling step]; width and height are handed to matplotlib as figsize.
sizes = {
  "s": [0.77, 0.23, 4],
  "m": [1, 1, 2],
  "l": [6, 4, 1],
}

# Anchor colors of the gradient used to tint graphs, as RGB floats in 0..1.
green = tuple(channel/255 for channel in (121, 240, 50))
blue = tuple(channel/255 for channel in (70, 150, 255))
red = tuple(channel/255 for channel in (255, 50, 8))

@bp.route("/html/<string:metric>/<string:capsulid>/<string:duration>")
@account_required
def display_metric(metric, capsulid, duration):
  """Render the HTML page that shows one metric graph for one capsul.

  The capsul is looked up for the account stored in the session. When that
  lookup returns nothing the request is answered with 404, the same way
  get_plot_bytes treats a missing capsul for the PNG endpoint.
  """
  vm = get_model().get_vm_detail(session["account"], capsulid)

  # Previously the template was rendered with vm=None for an unknown capsul.
  if not vm:
    abort(404)

  return render_template(
    "display-metric.html",
    vm=vm,
    duration=duration,
    # Offer the duration names without the leading "_" used as dict keys.
    durations=[name.strip("_") for name in durations],
    metric=metric
  )


@bp.route("/<string:metric>/<string:capsulid>/<string:duration>/<string:size>")
@account_required
def metric_png(metric, capsulid, duration, size):
  """Serve the graph for one metric of one capsul as a PNG image.

  Delegates to get_plot_bytes; any status other than 200 that it reports is
  turned into the matching HTTP error.
  """
  status, png_bytes = get_plot_bytes(metric, capsulid, duration, size)

  if status != 200:
    abort(status)

  png_response = make_response(png_bytes)
  png_response.headers.set('Content-Type', 'image/png')
  return png_response


def get_plot_bytes(metric, capsulid, duration, size):
  """Fetch one metric's time series from Prometheus and render it as a PNG.

  Args:
    metric: which query to run: cpu, memory, network_in, network_out or disk.
    capsulid: id of the VM, looked up for the account stored in the session.
    duration: time window name; "_" + duration must be a key of `durations`.
    size: graph size name; must be a key of `sizes`.

  Returns:
    A (status, body) tuple:
      (200, png_bytes) on success;
      (404, None) for an unknown duration, size or metric, when the VM lookup
        returns nothing, or when Prometheus has no samples for the query;
      (502, None) when Prometheus cannot be reached (or times out), answers
        with a status >= 300, or sends a body that cannot be interpreted.
  """
  duration = f"_{duration}"

  if duration not in durations or size not in sizes:
    return (404, None)

  vm = get_model().get_vm_detail(session["account"], capsulid)
  if not vm:
    return (404, None)

  # strftime("%s") is a platform-specific extension; timestamp() is portable.
  now_unix = int(datetime.now().timestamp())
  duration_seconds = durations[duration][0]
  # Smaller graphs sample more coarsely, but never finer than every 30 seconds.
  interval_seconds = max(30, durations[duration][1] * sizes[size][2])

  # Prometheus queries to pull metrics for VMs
  selector = f"{{domain='{capsulid}'}}"
  window = f"[{interval_seconds}s]"
  metric_queries = dict(
    cpu=f"irate(libvirtd_domain_info_cpu_time_seconds_total{selector}[30s])",
    memory=f"libvirtd_domain_info_memory_usage_bytes{selector}",
    network_in=f"rate(libvirtd_domain_interface_stats_receive_bytes_total{selector}{window})",
    network_out=f"rate(libvirtd_domain_interface_stats_transmit_bytes_total{selector}{window})",
    # A literal "+" is fine here: the query is URL-encoded by requests below.
    disk=(
      f"rate(libvirtd_domain_block_stats_read_bytes_total{selector}{window})"
      f"+rate(libvirtd_domain_block_stats_write_bytes_total{selector}{window})"
    ),
  )

  if metric not in metric_queries:
    return (404, None)

  # These represent the top of the graph for graphs that are designed to be viewed at a glance.
  # they are also used to colorize the graph at any size.
  scales = dict(
    cpu=vm["vcpus"],
    memory=vm["memory_mb"]*1024*1024,
    network_in=1024*1024*2,
    network_out=1024*200,
    disk=1024*1024*8,
  )

  prometheus_range_url = f"{current_app.config['PROMETHEUS_URL']}/api/v1/query_range"

  try:
    prometheus_response = requests.get(
      prometheus_range_url,
      # Let requests do the URL encoding instead of hand-building the query string.
      params=dict(
        query=metric_queries[metric],
        start=now_unix-duration_seconds,
        end=now_unix,
        step=interval_seconds,
      ),
      # Without a timeout an unresponsive Prometheus would block this request forever.
      timeout=30,
    )
  except requests.RequestException:
    return (502, None)

  if prometheus_response.status_code >= 300:
    return (502, None)

  # Parse the body once; an unparseable or unexpectedly shaped body is an upstream fault.
  try:
    results = prometheus_response.json()["data"]["result"]
  except (ValueError, KeyError, TypeError):
    return (502, None)

  if not results:
    return (404, None)

  time_series_data = [
    (datetime.fromtimestamp(sample[0]), float(sample[1]))
    for sample in results[0]["values"]
  ]

  # The plotting code needs at least one sample to work with.
  if not time_series_data:
    return (404, None)

  plot_bytes = draw_plot_png_bytes(
    time_series_data,
    scale=scales[metric],
    size_x=sizes[size][0],
    size_y=sizes[size][1],
  )

  return (200, plot_bytes)


def draw_plot_png_bytes(data, scale, size_x=3, size_y=1):
  """Draw a time series as a filled line chart and return it as PNG bytes.

  Args:
    data: non-empty sequence of (timestamp, value) pairs; the timestamps are
      subtracted from one another, so they are expected to be datetimes.
    scale: value regarded as "full" for this metric. It decides the display
      unit, the tint of the graph and, for graphs less than 4 tall, the top
      of the y axis.
    size_x: figure width handed to matplotlib.
    size_y: figure height handed to matplotlib.

  Returns:
    The rendered image as PNG-encoded bytes.

  Raises:
    ValueError: if data is empty.
  """
  if not data:
    raise ValueError("cannot draw a plot without any data points")

  divide_by, unit = _unit_for_scale(scale)
  scale /= divide_by

  def format_y_tick(value, _position):
    # Show fewer decimals the larger the numbers on the axis are.
    if scale > 10:
      return "{}{}".format(int(value), unit)
    if scale > 1:
      return "{:.1f}{}".format(value, unit)
    return "{:.2f}{}".format(value, unit)

  x = [point[0] for point in data]
  y = [point[1]/divide_by for point in data]

  # Newer matplotlib releases ship this style under the "seaborn-v0_8-" prefix only.
  style_name = "seaborn-dark"
  if style_name not in pyplot.style.available:
    style_name = "seaborn-v0_8-dark"
  pyplot.style.use(style_name)

  fig, my_plot = pyplot.subplots(figsize=(size_x, size_y))
  try:
    my_plot.get_yaxis().set_major_formatter(ticker.FuncFormatter(format_y_tick))
    _configure_time_axis(my_plot.xaxis, (x[-1] - x[0]).total_seconds())

    # Top of the background band: the scale, or the data peak if it exceeds the scale.
    max_value = max([scale, *y])

    # Average load relative to the scale, boosted so busy graphs shift color sooner.
    load = (sum(y)/len(y))/scale*1.25

    # The background is tinted slightly "hotter" than the line and fill.
    bg_color = color_gradient(load+0.1)
    fill_color = color_gradient(load)
    highlight_color = lerp_rgb_tuples(fill_color, (1,1,1), 0.5)

    my_plot.fill_between(x, max_value, color=bg_color, alpha=0.13)
    my_plot.fill_between(x, y, color=highlight_color, alpha=0.3)
    # Only the line style goes in the format string; the color comes from the keyword.
    my_plot.plot(x, y, '-', color=highlight_color)

    # Graphs less than 4 tall always span 0..scale so they can be compared at a glance.
    if size_y < 4:
      my_plot.set_yticks([0, scale])
      my_plot.set_ylim(0, scale)

    my_plot.xaxis.label.set_color(highlight_color)
    my_plot.tick_params(axis='x', colors=highlight_color)
    my_plot.yaxis.label.set_color(highlight_color)
    my_plot.tick_params(axis='y', colors=highlight_color)

    # Graphs below these sizes drop their tick labels.
    if size_x < 4:
      my_plot.set_xticklabels([])
    if size_y < 1:
      my_plot.set_yticklabels([])

    image_binary = BytesIO()
    fig.savefig(image_binary, format="png", transparent=True, bbox_inches="tight", pad_inches=0.05)
  finally:
    # Release only this figure, and do so even when drawing failed.
    pyplot.close(fig)

  return image_binary.getvalue()


def _unit_for_scale(scale):
  """Return the (divisor, unit label) used to display values of this magnitude."""
  # ">=" so that a scale of exactly 1 GiB is shown in GB rather than as raw bytes.
  if scale >= 1024*1024*1024:
    return (1024*1024*1024, "GB")
  if scale > 1024:
    return (1024*1024, "MB")
  return (1, "")


def _configure_time_axis(axis, span_seconds):
  """Pick x-axis tick spacing and label format suited to the plotted time span."""
  minutes = span_seconds/60
  hours = minutes/60
  days = hours/24

  if minutes < 10:
    locator, label_format = mdates.MinuteLocator(), '%H:%M'
  elif hours < 2:
    locator, label_format = mdates.MinuteLocator(interval=10), '%H:%M'
  elif days < 2:
    locator, label_format = mdates.HourLocator(interval=6), '%H:%M'
  else:
    locator, label_format = mdates.WeekdayLocator(), '%b %d'

  axis.set_major_locator(locator)
  axis.set_major_formatter(mdates.DateFormatter(label_format))


def lerp_rgb_tuples(a, b, lerp):
  """Blend two RGB triples linearly.

  `lerp` is clamped to the range 0..1; 0 yields `a`, 1 yields `b`, and
  values in between mix the first three components of each proportionally.
  Returns a 3-tuple.
  """
  weight = 0 if lerp < 0 else (1 if lerp > 1 else lerp)
  remainder = 1.0-weight
  return tuple(a[i]*remainder+b[i]*weight for i in range(3))

def color_gradient(value):
  """Map a value onto the green -> blue -> red gradient.

  `value` is clamped to 0..1. The lower half of the range blends from green
  to blue, the upper half from blue to red.
  """
  position = 0 if value < 0 else (1 if value > 1 else value)

  if position < 0.5:
    start, end, blend = green, blue, position*2
  else:
    start, end, blend = blue, red, (position-0.5)*2

  return lerp_rgb_tuples(start, end, blend)