# MIT License
# Copyright (c) 2025 aiquniq
# See LICENSE file in the project root for full license text.

import re
from typing import Any, Dict, List, Optional, Callable, Iterable

from storage.files import stream_download
from storage import db as dbmod
from utils.text import slugify

def _pdf_url_of(doc: Dict[str, Any]) -> Optional[str]:
    if doc.get("link_pdf"): return doc["link_pdf"]
    base_id = (doc.get("arxiv_id") or "").split("v")[0]
    return f"https://arxiv.org/pdf/{base_id}.pdf" if base_id else None

def _authors_of(doc: Dict[str, Any]) -> str:
    a = doc.get("authors") or []
    if isinstance(a, list): return ", ".join(a[:25])
    if isinstance(a, str): return a
    return ""

# Maximum number of search results requested from the client per crawl_depth;
# depths not listed here fall back to 1000.
_DEPTH_TO_MAX_RESULTS = {0: 200, 1: 1000, 2: 2000, 3: 4000, 4: 8000}


def _split_categories(
    categories: Optional[List[str] | str],
) -> Optional[List[str]]:
    """Split a comma-separated category string into stripped, non-empty tokens.

    Lists and ``None`` are returned unchanged.
    """
    if isinstance(categories, str):
        return [t.strip() for t in categories.split(",") if t.strip()]
    return categories


def _save_one(
    conn: Any,
    client: Any,
    doc: Dict[str, Any],
    search_keyword: Optional[str],
) -> Optional[str]:
    """Download one entry's PDF and upsert its row.

    Returns the title used for the entry, or ``None`` when the entry was
    skipped (no PDF URL could be formed, or the download returned nothing).
    """
    pdf_url = _pdf_url_of(doc)
    if not pdf_url:
        return None

    # First non-blank of title / arXiv id / "paper", so the slug (and thus the
    # file name) is never empty for a whitespace-only title.
    title = (
        (doc.get("title") or "").strip()
        or (doc.get("arxiv_id") or "").strip()
        or "paper"
    )
    path = stream_download(
        pdf_url, "data/files", slugify(title)[:120], client.session, require_pdf=True
    )
    if not path:
        return None

    row = {
        # Without an arXiv id, key on the PDF URL rather than the shared key
        # "arxiv:None", which would make such entries overwrite each other.
        "id": f"arxiv:{doc.get('arxiv_id') or pdf_url}",
        "doi": doc.get("doi"),
        "title": title,
        # `published` presumably is a date/datetime (TODO confirm against the
        # client); values without a `.year` give None instead of raising.
        "publication_year": getattr(doc.get("published"), "year", None),
        "primary_location": "arXiv",
        "is_oa": 1,
        "pdf_url": pdf_url,
        "html_url": doc.get("link_abs"),
        "file_path": path,
        "abstract_inverted_index": None,
        "concepts": None,
        "authors": _authors_of(doc) or None,
        "search_keyword": search_keyword or "",
    }
    dbmod.upsert_work(conn, row)
    return title


def crawl_and_collect_arxiv_save(
    client,
    search_keyword: Optional[str],
    categories: Optional[List[str] | str],
    crawl_depth: int,
    files_per_run: int,
    year_from: Optional[int],
    year_to: Optional[int],
    progress: Optional[Callable[[str], None]] = None,
) -> int:
    """Search arXiv, download matching PDFs and upsert one DB row per saved file.

    Args:
        client: arXiv client; ``client.search_works(...)`` provides the entries
            and ``client.session`` is passed to ``stream_download``.
        search_keyword: Free-text query; also stored on each row ("" if None).
        categories: List of categories or a comma-separated string; ``None``
            is forwarded to the client unchanged.
        crawl_depth: Chooses the maximum number of results requested
            (0->200, 1->1000, 2->2000, 3->4000, 4->8000, otherwise 1000).
        files_per_run: Stop once this many PDFs have been saved; values <= 0
            save nothing.
        year_from: Lower year bound forwarded to ``client.search_works``.
        year_to: Upper year bound forwarded to ``client.search_works``.
        progress: Optional callback for human-readable status lines; any
            exception it raises is ignored.

    Returns:
        The number of PDFs downloaded and recorded in the database.
    """
    def log(m: str) -> None:
        if not progress:
            return
        try:
            progress(m)
        except Exception:
            # Progress reporting is best-effort and must never abort a crawl.
            pass

    cats = _split_categories(categories)
    max_total = _DEPTH_TO_MAX_RESULTS.get(int(crawl_depth), 1000)

    dbmod.ensure_dirs()
    conn = dbmod.connect()
    saved = 0
    try:
        it: Iterable[Dict[str, Any]] = client.search_works(
            query=search_keyword,
            categories=cats,
            authors=None,
            title_terms=None,
            abstract_terms=None,
            year_from=year_from,
            year_to=year_to,
            sort_by="submittedDate",
            sort_order="descending",
            batch_size=200,
            max_results_total=max_total,
        )

        log(f"arXiv: scanning up to {min(max_total, files_per_run)} entries…")

        # The quota is checked right after each save (not when the next entry
        # is pulled) so no further entry - possibly a further API page - is
        # requested once files_per_run is reached.
        if files_per_run > 0:
            for doc in it:
                title = _save_one(conn, client, doc, search_keyword)
                if title is None:
                    continue
                saved += 1
                if saved % 25 == 0:
                    log(f"arXiv saved ({saved}): {title[:100]}")
                if saved >= files_per_run:
                    break

        conn.commit()
    finally:
        # Always release the connection, even if the search, a download or an
        # upsert raised.
        conn.close()

    log(f"arXiv finished. Saved {saved} PDF(s).")
    return saved