# MIT License
# Copyright (c) 2025 aiquniq
# See LICENSE file in the project root for full license text.

import os
from typing import Dict, Any, List, Set, Optional, Tuple, Callable
from .client import OpenAlexClient
from storage.files import stream_download
from storage import db as dbmod

def pick_fulltext_location(work: Dict[str, Any], require_pdf: bool) -> Tuple[Optional[str], Optional[str]]:
    """Choose the full-text URLs to use for an OpenAlex work record.

    The PDF URL is taken from ``best_oa_location.pdf_url``. If that is empty,
    the first entry in ``work["locations"]`` that has a truthy ``pdf_url`` is
    used instead. The HTML URL is ``best_oa_location.landing_page_url`` or,
    failing that, ``open_access.oa_url``.

    Args:
        work: A work dict as returned by the OpenAlex API.
        require_pdf: When True and no PDF URL was found, the HTML URL is
            suppressed as well, so the caller gets ``(None, None)``.

    Returns:
        ``(pdf_url, html_url)``. Either element may be None; empty strings
        are normalised to None.
    """
    def get(d: Any, *path: str) -> Any:
        # Walk nested dicts. A missing key or any non-dict intermediate
        # (None, a string, a list, ...) yields None instead of raising.
        cur = d
        for key in path:
            if not isinstance(cur, dict):
                return None
            cur = cur.get(key)
        return cur

    pdf_url = get(work, "best_oa_location", "pdf_url") or None
    if not pdf_url:
        # Fall back to the first alternative location that advertises a PDF.
        locations = work.get("locations")
        if isinstance(locations, list):
            for loc in locations:
                if isinstance(loc, dict) and loc.get("pdf_url"):
                    pdf_url = loc["pdf_url"]
                    break
    html = get(work, "best_oa_location", "landing_page_url") or get(work, "open_access", "oa_url") or None
    html_url = None if (require_pdf and not pdf_url) else html
    return pdf_url, html_url

def as_row(work: Dict[str, Any], search_keyword: str, saved_path: Optional[str], pdf_url: Optional[str], html_url: Optional[str]) -> Dict[str, Any]:
    """Flatten an OpenAlex work dict into a flat row dict for the database.

    Args:
        work: A work dict as returned by the OpenAlex API.
        search_keyword: Keyword stored verbatim in the ``search_keyword`` column.
        saved_path: Local path of the downloaded file, or None.
        pdf_url: PDF URL chosen for the work, or None.
        html_url: HTML/landing URL chosen for the work, or None.

    Returns:
        A dict with fixed keys. Notable conversions:
        - ``is_oa`` is 1/0.
        - ``concepts`` and ``authors`` are comma-separated display names, or
          None when there are none.
        - ``abstract_inverted_index`` is the ``str()`` form of the raw value
          (a Python repr, not JSON), or None when the value is empty or missing.
    """
    def dig(node, keys, fallback=None):
        # Follow `keys` through nested dicts. Return `fallback` as soon as a
        # non-dict is hit, or if the final value is None.
        for key in keys:
            if not isinstance(node, dict):
                return fallback
            node = node.get(key)
        return fallback if node is None else node

    concept_names = [
        concept.get("display_name")
        for concept in (work.get("concepts") or [])
        if isinstance(concept, dict) and concept.get("display_name")
    ]

    candidate_authors = (
        dig(authorship, ["author", "display_name"])
        for authorship in (work.get("authorships") or [])
        if isinstance(authorship, dict)
    )
    author_names = [name for name in candidate_authors if name]

    inverted_index = work.get("abstract_inverted_index")
    abstract_text = str(inverted_index) if inverted_index else None

    # Key order is kept stable in case the DB layer relies on it.
    return {
        "id": work.get("id"),
        "doi": work.get("doi"),
        "title": work.get("title"),
        "publication_year": work.get("publication_year"),
        "primary_location": dig(work, ["primary_location", "source", "display_name"], fallback=None),
        "is_oa": int(bool(dig(work, ["open_access", "is_oa"], fallback=False))),
        "pdf_url": pdf_url,
        "html_url": html_url,
        "file_path": saved_path,
        "abstract_inverted_index": abstract_text,
        # Names are truthy strings, so a non-empty list never joins to "".
        "concepts": ", ".join(concept_names) or None,
        "authors": ", ".join(author_names) or None,
        "search_keyword": search_keyword,
    }

def crawl_and_collect(
    client: OpenAlexClient,
    search_keyword: Optional[str],
    topic_query: Optional[str],
    crawl_depth: int,
    files_per_run: int,
    require_pdf: bool = True,
    year_from: Optional[int] = None,
    year_to: Optional[int] = None,
    progress: Optional[Callable[[str], None]] = None,
) -> int:
    """Collect OpenAlex works: a seed search, then a breadth-first walk over references.

    Seed phase: iterates ``client.search_works(search_keyword, concept_id,
    is_oa_only=True, year_from=..., year_to=...)``. ``concept_id`` comes from
    ``client.resolve_concept(topic_query)`` when a topic is given. A seed is
    saved when:
    - its PDF downloads successfully, or
    - ``require_pdf`` is False and it has any full-text URL (PDF or HTML).

    BFS phase: ``referenced_works`` of saved works are fetched in batches of up
    to 200 via ``client.get_works_batch``. References of seeds are depth 1, and
    the walk stops at ``crawl_depth``. In this phase a work is saved only when
    its PDF downloads successfully, regardless of ``require_pdf``.

    Every saved work is upserted via ``dbmod.upsert_work``, and the file is
    written under ``data/files``.

    Args:
        client: OpenAlex API client (its ``session`` is reused for downloads).
        search_keyword: Keyword for the seed search; also stored on each row
            ("" when None).
        topic_query: Optional topic text resolved to a concept filter.
        crawl_depth: Maximum reference depth to follow (seeds are depth 0).
        files_per_run: Stop once this many works have been saved.
        require_pdf: Forwarded to ``pick_fulltext_location`` and
            ``stream_download``; see the seed rule above.
        year_from: Optional lower year bound for the seed search.
        year_to: Optional upper year bound for the seed search.
        progress: Optional callback receiving human-readable progress lines.

    Returns:
        The number of works saved during this run.

    NOTE(review): the limit is checked only after each save, so
    ``files_per_run <= 0`` still saves one work.

    The DB connection is always closed. It is committed only when the crawl
    completes or hits the limit. If an exception propagates, no commit is
    issued here.
    """
    from collections import deque

    batch_size = 200  # max IDs per get_works_batch call

    saved = 0
    seen: Set[str] = set()  # IDs already saved or rejected; never processed again
    # FIFO of (openalex_id, depth). BFS order keeps depths non-decreasing.
    queue = deque()

    def log(msg: str):
        # Progress reporting is best-effort: a failing callback must never
        # abort the crawl.
        try:
            if progress:
                progress(msg)
        except Exception:
            pass

    def enqueue_refs(work: Dict[str, Any], depth: int) -> None:
        # Queue the not-yet-processed references of a saved work at `depth`.
        refs = [ref for ref in (work.get("referenced_works") or []) if ref and ref not in seen]
        queue.extend((ref, depth) for ref in refs)
        if refs:
            log(f"Enqueued {len(refs)} references (depth {depth}).")

    dbmod.ensure_dirs()
    conn = dbmod.connect()
    try:
        concept_id = client.resolve_concept(topic_query) if topic_query else None
        if concept_id:
            log(f"Resolved topic '{topic_query}' → {concept_id}")
        elif topic_query:
            log(f"Topic '{topic_query}' not resolved; continuing with keyword only.")

        # --- Seed phase -------------------------------------------------------
        log("Fetching seed works...")
        for w in client.search_works(search_keyword, concept_id, is_oa_only=True, year_from=year_from, year_to=year_to):
            wid = w.get("id")
            if not wid or wid in seen:
                continue
            pdf_url, html_url = pick_fulltext_location(w, require_pdf=require_pdf)
            if not (pdf_url or html_url):
                continue
            file_path = None
            if pdf_url:
                file_path = stream_download(pdf_url, "data/files", wid.split("/")[-1], client.session, require_pdf=require_pdf)
            if require_pdf and not file_path:
                continue
            dbmod.upsert_work(conn, as_row(w, search_keyword or "", file_path, pdf_url, html_url))
            seen.add(wid)
            saved += 1
            log(f"Saved ({saved}): {(w.get('title') or '')[:100]}")
            enqueue_refs(w, 1)
            if saved >= files_per_run:
                conn.commit()
                log(f"Reached files-per-run limit ({files_per_run}).")
                return saved

        # --- BFS over references ----------------------------------------------
        while queue and saved < files_per_run:
            batch = [queue.popleft() for _ in range(min(batch_size, len(queue)))]
            # Dedupe and skip already-processed IDs before hitting the API. The
            # first occurrence wins; since queued depths are non-decreasing, it
            # carries the smallest depth for that ID.
            depth_by_id: Dict[str, int] = {}
            for qid, d in batch:
                if d <= crawl_depth and qid not in seen:
                    depth_by_id.setdefault(qid, d)
            if not depth_by_id:
                continue
            batch_ids = list(depth_by_id)
            log(f"BFS: fetching batch of {len(batch_ids)} works at depth ≤ {crawl_depth}...")
            for w in client.get_works_batch(batch_ids):
                wid = w.get("id")
                depth = depth_by_id.get(wid) if wid else None
                # Skip works we did not ask for (depth None) and duplicates.
                if not wid or wid in seen or depth is None:
                    continue
                pdf_url, html_url = pick_fulltext_location(w, require_pdf=require_pdf)
                # Unlike the seed phase, BFS keeps only works whose PDF downloads.
                if not pdf_url:
                    seen.add(wid)
                    continue
                file_path = stream_download(pdf_url, "data/files", wid.split("/")[-1], client.session, require_pdf=require_pdf)
                if not file_path:
                    seen.add(wid)
                    continue
                dbmod.upsert_work(conn, as_row(w, search_keyword or "", file_path, pdf_url, html_url))
                seen.add(wid)
                saved += 1
                log(f"Saved ({saved}) [depth {depth}]: {(w.get('title') or '')[:100]}")
                next_depth = depth + 1
                if next_depth <= crawl_depth:
                    enqueue_refs(w, next_depth)
                if saved >= files_per_run:
                    log(f"Reached files-per-run limit ({files_per_run}).")
                    break
        conn.commit()
    finally:
        # Previously the connection leaked whenever any step above raised.
        conn.close()
    log("Crawl finished.")
    return saved