# MIT License
# Copyright (c) 2025 aiquniq
# See LICENSE file in the project root for full license text.

from typing import Any, Dict, Iterable, List, Optional, Union, Callable

def _as_categories(val: Optional[Union[str, List[str]]]) -> Optional[List[str]]:
    if val is None: return None
    if isinstance(val, list): return [s.strip() for s in val if s and s.strip()]
    s = val.strip()
    if not s: return None
    if "," in s: return [t.strip() for t in s.split(",") if t.strip()]
    if " " in s and ":" not in s: return None
    return [s]

def _strip_arxiv_version(arxiv_id: str) -> str:
    """Return *arxiv_id* without a trailing ``v<digits>`` version suffix.

    Only the last ``v`` is considered. Old-style identifiers whose archive
    name itself contains a ``v`` (e.g. ``solv-int/9901001v2``) therefore keep
    their full prefix instead of being cut at the archive name.
    """
    base, sep, version = arxiv_id.rpartition("v")
    if sep and version.isdigit():
        return base
    return arxiv_id


def crawl_arxiv_index_only(
    client,
    search_keyword: Optional[str],
    categories: Optional[Union[str, List[str]]],
    crawl_depth: int,
    files_per_run: int,
    year_from: Optional[int],
    year_to: Optional[int],
    progress: Optional[Callable[[str], None]] = None,
) -> List[Dict[str, Any]]:
    """Collect arXiv metadata records only; no PDFs are downloaded here.

    Args:
        client: Object exposing ``search_works(...)``. Its return value is
            iterated as dicts, and this function reads the keys ``arxiv_id``
            (required), ``published``, ``title``, ``authors``, ``link_pdf``
            and ``link_abs``. ``published`` presumably is a date/datetime
            (only its ``.year`` is read) — confirm against the client.
        search_keyword: Free-text query forwarded as ``query``. It is also
            copied into every staged record (``""`` when ``None``).
        categories: Category filter, normalised by ``_as_categories``.
        crawl_depth: Selects the cap on results requested from the client:
            0->200, 1->1000, 2->2000, 3->4000, 4->8000, anything else->1000.
        files_per_run: Maximum number of records returned. A value ``<= 0``
            returns an empty list without querying the client.
        year_from: Optional lower year bound, forwarded to the client.
        year_to: Optional upper year bound, forwarded to the client.
        progress: Optional callback that receives status messages. Any
            exception it raises is ignored.

    Returns:
        Staged record dicts, in the order the client yields them. The client
        is asked for ``submittedDate`` in descending order.
    """

    def log(message: str) -> None:
        # Progress reporting is best-effort: a broken callback must never
        # abort the crawl, so its exceptions are deliberately swallowed.
        if not progress:
            return
        try:
            progress(message)
        except Exception:
            pass

    cats = _as_categories(categories)

    # Cap on results requested from the client, keyed by crawl depth.
    depth_to_max = {0: 200, 1: 1000, 2: 2000, 3: 4000, 4: 8000}
    max_total = depth_to_max.get(int(crawl_depth), 1000)
    log(f"arXiv: staging up to {min(max_total, files_per_run)} items…")

    staged: List[Dict[str, Any]] = []
    # Guard before querying: the append-then-check loop below would otherwise
    # stage one record even when files_per_run <= 0.
    if files_per_run > 0:
        docs: Iterable[Dict[str, Any]] = client.search_works(
            query=search_keyword,
            categories=cats,
            authors=None,
            title_terms=None,
            abstract_terms=None,
            year_from=year_from,
            year_to=year_to,
            sort_by="submittedDate",
            sort_order="descending",
            batch_size=200,
            max_results_total=max_total,
        )
        for doc in docs:
            arxiv_id = doc["arxiv_id"]
            # getattr tolerates a missing/None "published" value instead of
            # aborting the whole run on one malformed record.
            published = doc.get("published")
            staged.append({
                "source": "arXiv",
                "id": f"arxiv:{arxiv_id}",
                "arxiv_id": arxiv_id,
                "title": doc.get("title"),
                "publication_year": getattr(published, "year", None),
                "authors": ", ".join(doc.get("authors") or []),
                "pdf_url": doc.get("link_pdf")
                or f"https://arxiv.org/pdf/{_strip_arxiv_version(arxiv_id)}.pdf",
                "html_url": doc.get("link_abs"),
                "search_keyword": search_keyword or "",
            })
            # Stop pulling from the client once the per-run quota is met.
            if len(staged) >= files_per_run:
                break

    log(f"arXiv staged {len(staged)} items.")
    return staged