from __future__ import annotations

import csv
import json
import os
import sys
import threading
import urllib.parse
import urllib.request
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import Callable


# Address and port the dashboard server binds to; both overridable via env.
HOST = os.getenv("HOST", "0.0.0.0")
PORT = int(os.getenv("PORT", "8787"))
# Market-area code; used to namespace cache file names.
BZN = "AT"
# Directory containing this script; anchor for default static/cache paths.
ROOT = Path(__file__).resolve().parent
# Where fetched price ranges are stored as CSV; EPEX_CACHE_DIR overrides it.
CACHE_DIR = Path(os.getenv("EPEX_CACHE_DIR", ROOT / "epex_price_cache"))
# aWATTar Austria market-data endpoint, queried with start/end parameters.
AWATTAR_URL = "https://api.awattar.at/v1/marketdata"


def resolve_static_root() -> Path:
    """Pick the directory the dashboard's static files are served from.

    Candidates are tried in order:

    * ``$STATIC_ROOT``, only when set and non-empty
    * the script directory (``ROOT``)
    * ``sys.prefix``
    * ``sys.prefix/local``
    * ``/app``

    The first one containing a file named ``energy_spot_dashboard.html`` is
    returned. If none does, ``ROOT`` is returned.

    NOTE(review): the chosen directory is handed to SimpleHTTPRequestHandler
    and served wholesale, including directory listings. For ``ROOT`` that
    exposes this script and, by default, the cache directory; for
    ``sys.prefix`` it exposes the Python install. Consider a dedicated static
    directory.
    """
    candidates: list[Path] = []
    override = os.environ.get("STATIC_ROOT")
    # Path("") is Path("."), the current working directory, and Path objects
    # are always truthy. An unset/empty variable must therefore be skipped
    # explicitly, or the CWD would silently be probed before ROOT.
    if override:
        candidates.append(Path(override))
    candidates += [ROOT, Path(sys.prefix), Path(sys.prefix) / "local", Path("/app")]
    for candidate in candidates:
        if (candidate / "energy_spot_dashboard.html").is_file():
            return candidate
    return ROOT


# Resolved once at import time; DashboardHandler serves files from here.
STATIC_ROOT = resolve_static_root()


def cache_path(cache_dir: Path, start: int, end: int, bzn: str = BZN) -> Path:
    """Return the CSV cache location for one ``(bzn, start, end)`` request.

    The file is named ``epex_<bzn>_<start>_<end>.csv`` inside ``cache_dir``.
    Nothing is created on disk here.
    """
    file_name = "epex_{}_{}_{}.csv".format(bzn, start, end)
    return cache_dir.joinpath(file_name)


def read_cache(path: Path) -> dict:
    """Load a cached price CSV back into the ``{"object": "list", "data": [...]}`` shape.

    Each CSV row becomes a dict with:

    * ``start_timestamp`` and ``end_timestamp`` as ``int``
    * ``marketprice`` as ``float``
    * ``unit``, where a missing or empty value falls back to ``"Eur/MWh"``

    Raises:
        KeyError: a required column is absent.
        ValueError: a numeric column does not parse.
    """
    with path.open(newline="", encoding="utf-8") as handle:
        entries = [
            {
                "start_timestamp": int(rec["start_timestamp"]),
                "end_timestamp": int(rec["end_timestamp"]),
                "marketprice": float(rec["marketprice"]),
                "unit": rec.get("unit") or "Eur/MWh",
            }
            for rec in csv.DictReader(handle)
        ]
    return {"object": "list", "data": entries}


def write_cache(path: Path, payload: dict) -> None:
    """Persist ``payload["data"]`` to ``path`` as CSV without ever exposing a partial file.

    Rows are normalised before anything touches disk:

    * timestamps to ``int``
    * ``marketprice`` to ``float``
    * ``unit`` defaults to ``"Eur/MWh"``

    This way a malformed row raises without creating ``path``. The CSV is
    then written to a sibling temp file and moved into place with
    ``os.replace`` (atomic on POSIX). Concurrent readers on the threaded
    server, and later cache hits, therefore never see a truncated file.

    Raises:
        KeyError: a row lacks a required field.
        ValueError: a numeric field does not parse.
        OSError: writing or renaming fails. The temp file is removed in that
            case.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    rows = [
        {
            "start_timestamp": int(row["start_timestamp"]),
            "end_timestamp": int(row["end_timestamp"]),
            "marketprice": float(row["marketprice"]),
            "unit": row.get("unit") or "Eur/MWh",
        }
        for row in payload.get("data") or []
    ]
    # Unique per process and thread so concurrent writers never share a temp file.
    tmp_path = path.with_name(f".{path.name}.{os.getpid()}.{threading.get_ident()}.tmp")
    try:
        with tmp_path.open("w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(
                f,
                fieldnames=["start_timestamp", "end_timestamp", "marketprice", "unit"],
            )
            writer.writeheader()
            writer.writerows(rows)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the temp file; the original error is re-raised.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise


def fetch_awattar_json(url: str) -> dict:
    """GET ``url`` (30 s timeout) and parse the UTF-8 response body as JSON.

    Raises whatever ``urllib.request.urlopen``, ``bytes.decode`` or
    ``json.loads`` raise; no retry or status handling is done here.
    """
    with urllib.request.urlopen(url, timeout=30) as resp:
        raw_bytes = resp.read()
    text = raw_bytes.decode("utf-8")
    return json.loads(text)


def _covers_range(rows: list, end: int) -> bool:
    """Return True when ``rows`` is non-empty and its latest slot ends at or after ``end``."""
    if not rows:
        return False
    return max(int(row["end_timestamp"]) for row in rows) >= end


def get_prices(
    start: int,
    end: int,
    *,
    fetch_json: Callable[[str], dict] = fetch_awattar_json,
    cache_dir: Path = CACHE_DIR,
    bzn: str = BZN,
) -> tuple[dict, dict]:
    """Return market data for ``[start, end]``, preferring the on-disk CSV cache.

    Returns ``(payload, meta)``, where ``meta`` is
    ``{"cacheHit": bool, "path": str}``.

    On a cache hit the payload is rebuilt from the CSV. Otherwise aWATTar is
    queried with ``start``/``end``. The response is persisted only when it is
    complete, i.e. non-empty and its last slot reaches ``end``. Windows whose
    prices are not yet published therefore come back empty or partial but are
    not cached. Because cache hits never refetch, caching such a response
    would otherwise pin the incomplete answer forever.

    NOTE(review): assumes ``end`` uses the same unit as the returned
    ``end_timestamp`` values (aWATTar uses epoch milliseconds) — confirm
    against the dashboard caller.
    """
    path = cache_path(cache_dir, start, end, bzn)
    if path.exists():
        return read_cache(path), {"cacheHit": True, "path": str(path)}

    query = urllib.parse.urlencode({"start": start, "end": end})
    payload = fetch_json(f"{AWATTAR_URL}?{query}")
    if _covers_range(payload.get("data") or [], end):
        write_cache(path, payload)
    return payload, {"cacheHit": False, "path": str(path)}


class DashboardHandler(SimpleHTTPRequestHandler):
    """Serve ``STATIC_ROOT`` as static files plus a JSON price endpoint.

    ``GET /api/epex?start=<int>&end=<int>`` responds with one of:

    * 200 and ``{"object": "list", "data": [...], "cache": {...}}``
    * 400 and ``{"error": ...}`` when the query parameters are missing, not
      integers, or ``start >= end``
    * 502 and ``{"error": ...}`` when fetching or caching the prices fails

    Every other GET path falls through to SimpleHTTPRequestHandler.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, directory=str(STATIC_ROOT), **kwargs)

    def do_GET(self) -> None:
        parsed = urllib.parse.urlparse(self.path)
        if parsed.path == "/api/epex":
            self.handle_epex(parsed)
            return
        super().do_GET()

    @staticmethod
    def _parse_range(params: dict[str, list[str]]) -> tuple[int, int]:
        """Extract integer ``start``/``end`` from ``parse_qs`` output.

        Raises:
            ValueError: with a client-facing message when either parameter is
                missing or not an integer, or when ``start >= end``.
        """
        try:
            start = int(params["start"][0])
            end = int(params["end"][0])
        except (KeyError, IndexError, ValueError):
            raise ValueError("start and end query parameters must be integers") from None
        if start >= end:
            raise ValueError("start must be before end")
        return start, end

    def handle_epex(self, parsed: urllib.parse.ParseResult) -> None:
        """Answer ``/api/epex`` with cached or freshly fetched price data as JSON."""
        params = urllib.parse.parse_qs(parsed.query)
        try:
            start, end = self._parse_range(params)
        except ValueError as exc:
            # A malformed request is the client's fault, not an upstream failure.
            self._send_json_error(400, str(exc))
            return

        try:
            payload, meta = get_prices(start, end)
            body = json.dumps({"object": "list", "data": payload["data"], "cache": meta}).encode("utf-8")
        except Exception as exc:  # request boundary: report any fetch/cache failure
            self.log_error("EPEX request start=%s end=%s failed: %r", start, end, exc)
            self._send_json_error(502, str(exc))
            return
        self._send_json_body(200, body)

    def _send_json_error(self, status: int, message: str) -> None:
        """Send ``{"error": message}`` with the given HTTP status."""
        self._send_json_body(status, json.dumps({"error": message}).encode("utf-8"))

    def _send_json_body(self, status: int, body: bytes) -> None:
        """Write a complete, non-cacheable JSON response with an exact Content-Length."""
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.send_header("Cache-Control", "no-store")
        self.end_headers()
        self.wfile.write(body)


def main() -> None:
    """Ensure the cache directory exists, then serve the dashboard until interrupted."""
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    # The context manager closes the listening socket when serve_forever exits
    # (e.g. on Ctrl+C), as in the http.server documentation example.
    with ThreadingHTTPServer((HOST, PORT), DashboardHandler) as server:
        print(f"Serving energy dashboard on http://{HOST}:{PORT}/energy_spot_dashboard.html")
        print(f"Static files: {STATIC_ROOT}")
        # stdout is block-buffered when it is not a TTY (containers, service
        # managers). serve_forever never returns, so without a flush these
        # startup lines may not appear until the process exits.
        print(f"EPEX CSV cache: {CACHE_DIR}", flush=True)
        server.serve_forever()


if __name__ == "__main__":
    main()
