# invoice_revision_pdf.py

import os
from pathlib import Path
from typing import Any, Dict, Optional

from reportlab.lib.pagesizes import A4
from reportlab.pdfgen import canvas
from werkzeug.utils import secure_filename


def _safe_get(d: Dict[str, Any], path: str, default: Any = None) -> Any:
    """
    Acceso seguro a claves anidadas tipo "invoice.number" o "totals.gst.computed".
    """
    current = d
    for part in path.split('.'):
        if not isinstance(current, dict):
            return default
        current = current.get(part, default)
    return current


# Column headers and relative widths (fractions of the usable page width) of the line-item table.
_ITEM_HEADERS = ("Sku", "Description", "Qty", "Unit", "Line total", "GST", "Check")
_ITEM_COL_FRACTIONS = (0.16, 0.47, 0.04, 0.10, 0.10, 0.07, 0.06)
_ITEM_ROW_HEIGHT = 14
# Validation results that count as a pass (compared with `in`, i.e. by equality).
_PASS_VALUES = (True, "ok", "OK", "pass", "PASS")


def _status_label(result: Any) -> str:
    """Map a validation result to "OK" if it is an accepted pass value, else "FAIL"."""
    return "OK" if result in _PASS_VALUES else "FAIL"


def _first_present(data: Any, *paths: str, default: Any = "N/A") -> Any:
    """
    Return the value at the first dotted path in *paths* that is neither None nor "".

    Unlike chaining ``a or b``, a legitimate falsy value such as 0 or 0.0
    (e.g. zero GST) is kept instead of being treated as missing.
    """
    for path in paths:
        value = _safe_get(data, path)
        if value is not None and value != "":
            return value
    return default


def _fit_text(c: "canvas.Canvas", text: Any, font: str, size: float, max_width: float) -> str:
    """
    Collapse whitespace/newlines in *text* and truncate it with "..." so that it
    is at most *max_width* points wide in the given font.
    """
    s = " ".join(str(text).split())
    if c.stringWidth(s, font, size) <= max_width:
        return s
    ellipsis = "..."
    # Binary search for the longest prefix that still fits together with the ellipsis
    # (string width grows monotonically with the prefix length).
    lo, hi = 0, len(s)
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if c.stringWidth(s[:mid] + ellipsis, font, size) <= max_width:
            lo = mid
        else:
            hi = mid - 1
    return s[:lo] + ellipsis


def _build_item_rows(line_items: Any, max_items: int) -> list[tuple[str, ...]]:
    """Turn the first *max_items* extracted line items into string rows for the table."""
    rows: list[tuple[str, ...]] = []
    for item in line_items[:max_items]:
        values = (
            _first_present(item, "sku.verbatim", "sku.computed"),
            _first_present(item, "description.verbatim", "name.verbatim"),
            _first_present(item, "qty.verbatim", "qty.computed"),
            _first_present(item, "unit_price.verbatim", "unit_price.computed"),
            _first_present(item, "line_total.verbatim", "line_total.computed"),
            _first_present(item, "gst_line.verbatim", "gst_line.computed"),
        )
        qty_check = _safe_get(item, "checks.qty_x_unit_eq_total")
        check = "N/A" if qty_check is None else _status_label(qty_check)
        rows.append(tuple(str(v) for v in values) + (check,))
    return rows


def _draw_items_table(
    c: "canvas.Canvas", rows: list, table_x: float, top_y: float, table_width: float
) -> float:
    """Draw the bordered line-item table with its top border at *top_y*; return the bottom border's y."""
    col_widths = [table_width * frac for frac in _ITEM_COL_FRACTIONS]
    bottom_y = top_y - (len(rows) + 1) * _ITEM_ROW_HEIGHT  # +1 for the header row

    # Horizontal borders: one above the header plus one below every row.
    for i in range(len(rows) + 2):
        border_y = top_y - i * _ITEM_ROW_HEIGHT
        c.line(table_x, border_y, table_x + table_width, border_y)

    # Vertical borders: left edge of each column, then the table's right edge.
    col_x = table_x
    for w_col in col_widths:
        c.line(col_x, top_y, col_x, bottom_y)
        col_x += w_col
    c.line(table_x + table_width, top_y, table_x + table_width, bottom_y)

    def draw_row(cells: Any, baseline_y: float, font: str, size: float) -> None:
        c.setFont(font, size)
        cell_x = table_x
        for cell, w_col in zip(cells, col_widths):
            # 2pt padding on both sides keeps text from crossing into the next column.
            c.drawString(cell_x + 2, baseline_y, _fit_text(c, cell, font, size, w_col - 4))
            cell_x += w_col

    # Each baseline sits 4pt above its row's bottom border.
    draw_row(_ITEM_HEADERS, top_y - _ITEM_ROW_HEIGHT + 4, "Helvetica-Bold", 8)
    for n, row in enumerate(rows, start=2):
        draw_row(row, top_y - n * _ITEM_ROW_HEIGHT + 4, "Helvetica", 6)
    return bottom_y


def create_invoice_revision_pdf(
    extracted: Dict[str, Any],
    output_dir: str,
    original_filename: str,
    checks: Optional[Dict[str, Any]] = None,
    *,
    max_items: int = 10,
) -> str:
    """
    Create a PDF that summarizes an invoice extraction for developer review.

    - extracted: dict with the already-extracted data (header, totals, items, ...).
    - output_dir: folder where PDFs are written (e.g. "revisions"); created if missing.
    - original_filename: name of the uploaded file; the PDF reuses its sanitized stem.
    - checks: optional dict of validations; when None, extracted["totals"]["validations"]
      is used. Non-dict values are ignored.
    - max_items: maximum number of line items rendered in the table (keyword-only).

    Returns the absolute path of the generated PDF.
    """
    max_items = max(0, max_items)

    # 1) Ensure the output directory exists; resolve so the returned path is absolute.
    out_path = Path(output_dir).resolve()
    out_path.mkdir(parents=True, exist_ok=True)

    # 2) Name the PDF after the (sanitized) original file.
    safe_name = secure_filename(original_filename) or "invoice"
    stem = os.path.splitext(safe_name)[0] or "invoice"
    pdf_path = out_path / f"{stem}.pdf"

    # 3) Basic fields. Supplier data falls back to the "seller" key when "supplier" is absent.
    invoice_type = _first_present(extracted, "invoice_type.verbatim")
    supplier_name = _first_present(
        extracted, "header.supplier.name.verbatim", "header.seller.name.verbatim"
    )
    supplier_abn = _first_present(
        extracted, "header.supplier.abn.verbatim", "header.seller.abn.verbatim"
    )
    buyer_name = _first_present(extracted, "header.buyer.name.verbatim")
    invoice_number = _first_present(extracted, "header.invoice.number.verbatim")
    issue_date = _first_present(extracted, "header.invoice.issue_date.verbatim")

    # Computed totals take priority; otherwise use the verbatim text.
    subtotal = _first_present(extracted, "totals.items_sum.computed", "totals.subtotal.verbatim")
    gst_amount = _first_present(extracted, "totals.gst.computed", "totals.gst.verbatim")
    grand_total = _first_present(
        extracted, "totals.grand_total.computed", "totals.grand_total.verbatim"
    )

    line_items = _safe_get(extracted, "items")
    if not isinstance(line_items, (list, tuple)):
        line_items = []

    # 4) Validation checks: explicit argument first, then totals.validations.
    if checks is None:
        checks = _safe_get(extracted, "totals.validations")
    if not isinstance(checks, dict):
        checks = {}

    # 5) Start drawing.
    c = canvas.Canvas(str(pdf_path), pagesize=A4)
    width, height = A4

    x_margin = 40
    top_y = height - 50
    bottom_limit = 50
    line_height = 14
    usable_width = width - 2 * x_margin
    y = top_y

    def write(text: str, bold: bool = False) -> None:
        """Draw one line at the cursor; start a new page when the bottom margin is reached."""
        nonlocal y
        if y < bottom_limit:
            c.showPage()
            y = top_y
        font = "Helvetica-Bold" if bold else "Helvetica"
        c.setFont(font, 10)
        c.drawString(x_margin, y, _fit_text(c, text, font, 10, usable_width))
        y -= line_height

    # Title
    c.setFont("Helvetica-Bold", 14)
    c.drawString(x_margin, y, "Invoice Extraction – Developer Revision")
    y -= 2 * line_height

    # Basic invoice data
    write(f"Invoice type: {invoice_type}", bold=True)
    write(f"Supplier: {supplier_name}")
    write(f"Supplier ABN: {supplier_abn}")
    write(f"Buyer: {buyer_name}")
    write(f"Invoice number: {invoice_number}")
    write(f"Issue date: {issue_date}")
    y -= line_height

    # Totals
    write("Totals", bold=True)
    write(f"Subtotal: {subtotal}")
    write(f"GST: {gst_amount}")
    write(f"Grand total: {grand_total}")
    y -= line_height

    # Validation checks (the extra top-level flags are only shown alongside them).
    minimum_ato = extracted.get("minimum_ato", "unknown")
    minimum_xero_aeroflo = extracted.get("minimum_xero_aeroflo", "unknown")
    date_validated = extracted.get("date_validated", "unknown")
    if checks:
        write("Validation checks", bold=True)
        for name, result in checks.items():
            write(f"- {name}: {result}  {_status_label(result)}")
        write(f"- minimum_ato: {minimum_ato}")
        write(f"- minimum_xero_aeroflo: {minimum_xero_aeroflo}")
        write(f"- date_validated: {date_validated}")
        y -= line_height

    # Line items (table summary, capped at max_items rows to keep the PDF small)
    write("Line items (summary)", bold=True)
    rows = _build_item_rows(line_items, max_items)

    if rows:
        table_height = (len(rows) + 1) * _ITEM_ROW_HEIGHT
        # Start a new page if the whole table does not fit on the current one.
        if y - table_height < bottom_limit:
            c.showPage()
            y = top_y
        table_bottom_y = _draw_items_table(c, rows, x_margin, y, usable_width)
        y = table_bottom_y - line_height
    elif not line_items:
        write("No line items found.", bold=False)

    hidden = len(line_items) - len(rows)
    if hidden > 0:
        write(f"... ({hidden} more items not shown)", bold=False)

    # Closing notes for the developer
    y -= line_height
    write("Notes for developer:", bold=True)
    write("- Compare this summary with the original invoice layout.")
    write("- Check tricky layouts, GST logic, and edge cases.")
    write("- Use this PDF together with the raw JSON for debugging.")

    c.showPage()
    c.save()

    return str(pdf_path)
