import json
import os
import time
import traceback
from typing import List, Dict

import requests

# Import constants so we keep a single source of truth for the endpoint/key.
from .constants import (
    OPENAI_API_KEY,
    OPENROUTER_API_KEY,
    DEFAULT_MODEL,
    MAX_RETRIES_LLM,
    REQUEST_TIMEOUT,
)


def _append_log(log_path: str, *chunks: str) -> None:
    """Best-effort append of text chunks to the stream-debug log.

    A no-op when *log_path* is falsy. Swallows all I/O errors on purpose:
    diagnostics must never affect the request's control flow.
    """
    if not log_path:
        return
    try:
        with open(log_path, "a", encoding="utf-8") as lf:
            for chunk in chunks:
                lf.write(chunk)
    except Exception:
        pass


def chat_completion(
    messages: List[Dict[str, str]],
    model: str = DEFAULT_MODEL,
    temperature: float = 0.2,
    max_tokens: int = 2000,
) -> str:
    """
    Send a list of chat messages to the LLM (via OpenRouter) and return the
    assistant's text reply.

    Each attempt first uses a streaming request, assembling token deltas
    (or full-message events) into the reply; if streaming yields nothing
    useful, it falls back to a plain non-streaming request. Attempts are
    retried with exponential backoff (1s, 2s, 4s, ...) up to MAX_RETRIES_LLM.
    Set the OPENROUTER_STREAM_LOG env var to a file path to capture raw
    stream lines and diagnostics.

    Args:
        messages: Chat history as [{"role": ..., "content": ...}, ...].
        model: Model identifier passed to OpenRouter.
        temperature: Sampling temperature for the completion.
        max_tokens: Upper bound on completion tokens.

    Returns:
        The assistant's reply text. If the provider signalled a stream error
        after some content had already arrived, the text is suffixed with a
        "<<STREAM_ERROR:...>>" marker so the caller can react.

    Raises:
        RuntimeError: if no API key is configured, or if every attempt fails
            (the last underlying exception is chained as __cause__).
    """
    url = "https://openrouter.ai/api/v1/chat/completions"
    key = OPENROUTER_API_KEY or OPENAI_API_KEY
    if not key:
        raise RuntimeError("No OpenRouter API key found. Set OPENROUTER_API_KEY or OPENAI_API_KEY.")

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    # Optional path for raw stream debugging; read the env var once per call.
    log_path = os.getenv("OPENROUTER_STREAM_LOG")

    for attempt in range(MAX_RETRIES_LLM):
        try:
            # Prefer streaming assembly to robustly handle chunked responses.
            stream_payload = dict(payload)
            stream_payload["stream"] = True
            pieces = []
            stream_error = None
            with requests.post(
                url,
                json=stream_payload,
                headers=headers,
                timeout=REQUEST_TIMEOUT,
                stream=True,
            ) as sresp:
                sresp.raise_for_status()
                # Collect content fragments until [DONE] or an error event.
                for raw_line in sresp.iter_lines(decode_unicode=True):
                    if not raw_line:
                        continue
                    line = raw_line.strip()
                    _append_log(log_path, line + "\n")
                    # SSE frames are prefixed with "data: "; tolerate bare lines.
                    if line.startswith("data: "):
                        event_data = line[len("data: "):]
                    else:
                        event_data = line
                    if event_data == "[DONE]":
                        break
                    try:
                        ev = json.loads(event_data)
                    except Exception:
                        # Non-JSON keep-alive/comment lines; skip them.
                        continue
                    # Provider-level error event: record it and stop streaming;
                    # any pieces already assembled are returned below.
                    if isinstance(ev, dict) and ev.get("error"):
                        stream_error = ev.get("error")
                        _append_log(
                            log_path,
                            "==STREAM_ERROR==\n",
                            json.dumps(stream_error) + "\n",
                        )
                        break
                    try:
                        ch = ev.get("choices", [])[0]
                        # BUG FIX: record the error and break unconditionally.
                        # Previously this only happened when the debug log env
                        # var was set, because the assignment and break were
                        # nested under `if log_path:`.
                        if ch.get("finish_reason") == "error":
                            stream_error = ch
                            _append_log(
                                log_path,
                                "==CHOICE_FINISH_ERROR==\n",
                                json.dumps(ch) + "\n",
                            )
                            break
                        # Delta-style streaming: incremental content tokens.
                        delta = ch.get("delta", {})
                        if isinstance(delta, dict):
                            content_piece = delta.get("content")
                            if content_piece:
                                pieces.append(content_piece)
                                continue
                        # Some providers send a full message/text per event.
                        msg = ch.get("message") or ch.get("text")
                        if msg:
                            if isinstance(msg, dict):
                                cont = msg.get("content")
                                if cont:
                                    pieces.append(cont)
                            elif isinstance(msg, str):
                                pieces.append(msg)
                    except Exception:
                        # Malformed or empty "choices" entry; skip this event.
                        continue

            if pieces:
                assembled = "".join(pieces).strip()
                # Annotate with stream-error info so the caller can react.
                if stream_error:
                    try:
                        err_text = json.dumps(stream_error)
                    except Exception:
                        err_text = str(stream_error)
                    assembled = assembled + f"\n\n<<STREAM_ERROR:{err_text}>>"
                _append_log(log_path, "==ASSEMBLED==\n", assembled + "\n")
                return assembled

            # Streaming returned nothing useful: fall back to a normal request.
            resp = requests.post(url, json=payload, headers=headers, timeout=REQUEST_TIMEOUT)
            try:
                resp.raise_for_status()
            except Exception:
                # Log the response text for diagnosis when enabled.
                _append_log(log_path, "==NONSTREAM_RESPONSE==\n", resp.text + "\n")
                raise
            data = resp.json()
            return data["choices"][0]["message"]["content"].strip()
        except Exception as exc:
            if attempt < MAX_RETRIES_LLM - 1:
                # Exponential backoff before the next attempt: 1s, 2s, 4s, ...
                time.sleep(2 ** attempt)
            else:
                # For easier debugging, optionally dump the stack to the log.
                _append_log(log_path, "==EXCEPTION==\n", traceback.format_exc(), "\n")
                raise RuntimeError(f"OpenRouter request failed after {MAX_RETRIES_LLM} attempts.") from exc