# MIT License
# Copyright (c) 2025 aiquniq
# See LICENSE file in the project root for full license text.

import os
from typing import Dict, Any, List, Set, Optional, Tuple, Callable

# Helpers

def pick_fulltext_location(work: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
    """Return a ``(pdf_url, html_url)`` pair taken from an OpenAlex work dict.

    Purely local: only the given dict is inspected, nothing is fetched.

    PDF lookup order:
      1. ``best_oa_location.pdf_url``
      2. the first dict in ``locations`` (when it is a list) with a truthy ``pdf_url``

    HTML/landing lookup order (first truthy value wins):
      1. ``best_oa_location.landing_page_url``
      2. ``open_access.oa_url``
      3. top-level ``landing_page_url``

    Empty/falsy results are normalised to ``None``.
    """
    def _walk(node, *keys):
        # Descend through nested dicts; give up with None as soon as a level is not a dict.
        for key in keys:
            if not isinstance(node, dict):
                return None
            node = node.get(key)
        return node

    pdf = _walk(work, "best_oa_location", "pdf_url")
    if not pdf:
        candidates = work.get("locations")
        if isinstance(candidates, list):
            pdf = next(
                (
                    loc.get("pdf_url")
                    for loc in candidates
                    if isinstance(loc, dict) and loc.get("pdf_url")
                ),
                None,
            )

    landing_candidates = (
        _walk(work, "best_oa_location", "landing_page_url"),
        _walk(work, "open_access", "oa_url"),
        work.get("landing_page_url"),
    )
    html = next((url for url in landing_candidates if url), None)

    return (pdf or None), html

def make_staged_row(work: Dict[str, Any], search_keyword: str) -> Dict[str, Any]:
    """Flatten one OpenAlex work into the staged-row dict used by the ROTmap UI/exports.

    Args:
        work: An OpenAlex work object (plain dict).
        search_keyword: Keyword that produced this row; stored as "" when falsy.

    Returns:
        A dict with keys ``source``, ``id``, ``doi``, ``title``,
        ``publication_year``, ``primary_location`` (source display name),
        ``is_oa`` (1/0), ``pdf_url``, ``html_url``, ``authors``
        (comma-separated display names, or None when there are none) and
        ``search_keyword``.
    """
    def _lookup(node, keys, fallback=None):
        # Walk nested dict keys; a non-dict level or a final None yields `fallback`.
        for key in keys:
            if not isinstance(node, dict):
                return fallback
            node = node.get(key)
        return fallback if node is None else node

    # Author display names, skipping malformed authorship entries and empty names.
    candidate_names = (
        _lookup(entry, ("author", "display_name"))
        for entry in (work.get("authorships") or [])
        if isinstance(entry, dict)
    )
    names = [n for n in candidate_names if n]

    pdf_link, landing_link = pick_fulltext_location(work)
    venue = _lookup(work, ("primary_location", "source", "display_name"))
    oa_flag = _lookup(work, ("open_access", "is_oa"), False)

    return {
        "source": "OpenAlex",
        "id": work.get("id"),
        "doi": work.get("doi"),
        "title": work.get("title"),
        "publication_year": work.get("publication_year"),
        "primary_location": venue,
        "is_oa": int(bool(oa_flag)),
        "pdf_url": pdf_link,
        "html_url": landing_link,
        "authors": ", ".join(names) or None,
        "search_keyword": search_keyword or "",
    }

# Main entry point

def crawl_openalex_index_only(
    client,  # OpenAlex client; see the docstring for the methods called on it
    search_keyword: Optional[str],
    topic_query: Optional[str],
    crawl_depth: int,
    files_per_run: int,
    year_from: Optional[int],
    year_to: Optional[int],
    progress: Optional[Callable[[str], None]] = None,
    batch_size: int = 200,
) -> List[Dict[str, Any]]:
    """
    Collect metadata only from OpenAlex; no downloads, no DB writes.

    Uses the same client API as the saving crawler:
      - client.resolve_concept(topic_query) -> concept_id or None
      - client.search_works(search_keyword, concept_id, is_oa_only=True,
                            year_from=..., year_to=...) -> iterable of works
      - client.get_works_batch(openalex_ids: List[str]) -> List[works]

    Seed works from ``search_works`` are staged first. A breadth-first walk
    over each staged work's ``referenced_works`` then stages cited works, up
    to ``crawl_depth`` hops from the seeds. Staging stops as soon as
    ``files_per_run`` rows have been collected.

    Args:
        client: OpenAlex client exposing the three methods listed above.
        search_keyword: Free-text query for the seed search; also copied into
            every row's ``search_keyword`` field ("" when None).
        topic_query: Optional topic name, resolved to a concept id filter.
        crawl_depth: Maximum number of reference hops to follow; values
            ``<= 0`` disable the reference walk entirely.
        files_per_run: Maximum number of rows to return; values ``<= 0``
            return an empty list without contacting the client.
        year_from, year_to: Publication-year bounds forwarded to search_works.
        progress: Optional callback receiving human-readable status messages.
            Exceptions raised by the callback are ignored.
        batch_size: Maximum number of ids sent per ``get_works_batch`` call
            (values below 1 are treated as 1).

    Returns:
        Staged rows (as built by ``make_staged_row``), unique by OpenAlex id,
        in discovery order.
    """
    from collections import deque

    def log(msg: str) -> None:
        # Progress reporting is best-effort: a failing callback must never abort the crawl.
        if not progress:
            return
        try:
            progress(msg)
        except Exception:
            pass

    keyword_label = search_keyword or ""
    batch_size = max(1, batch_size)

    staged: List[Dict[str, Any]] = []
    seen: Set[str] = set()      # ids already turned into staged rows
    enqueued: Set[str] = set()  # ids ever placed on the BFS queue (prevents duplicate fetches)
    queue = deque()             # FIFO of (openalex_id, depth); depths are non-decreasing

    if files_per_run <= 0:
        log(f"OpenAlex staged {len(staged)} items.")
        return staged

    def enqueue_refs(work: Dict[str, Any], depth: int) -> None:
        # Queue this work's references at `depth`, skipping ids already staged or queued.
        # Because the queue is FIFO, the first enqueue of an id carries its shallowest depth.
        if depth > crawl_depth:
            return
        fresh: List[str] = []
        for ref in (work.get("referenced_works") or []):
            if ref and ref not in seen and ref not in enqueued:
                enqueued.add(ref)
                fresh.append(ref)
        queue.extend((ref, depth) for ref in fresh)
        if fresh:
            log(f"Enqueued {len(fresh)} references (depth {depth}).")

    # Resolve the optional topic into a concept id filter.
    concept_id = client.resolve_concept(topic_query) if topic_query else None
    if concept_id:
        log(f"Resolved topic '{topic_query}' → {concept_id}")
    elif topic_query:
        log(f"Topic '{topic_query}' not resolved; continuing with keyword only.")

    log("Fetching seed works...")
    # Match the saving crawler's call signature exactly (no extra kwargs).
    for w in client.search_works(
        search_keyword,
        concept_id,
        is_oa_only=True,
        year_from=year_from,
        year_to=year_to,
    ):
        wid = w.get("id")
        if not wid or wid in seen:
            continue
        staged.append(make_staged_row(w, keyword_label))
        seen.add(wid)

        if len(staged) >= files_per_run:
            log(f"Reached stage limit ({files_per_run}).")
            break

        enqueue_refs(w, 1)

    while queue and len(staged) < files_per_run:
        # Take up to `batch_size` queued ids for a single batched lookup.
        batch = [queue.popleft() for _ in range(min(batch_size, len(queue)))]

        # id -> depth for ids still worth fetching (within depth, not yet staged).
        depth_map: Dict[str, int] = {}
        for qid, d in batch:
            if d <= crawl_depth and qid not in seen and qid not in depth_map:
                depth_map[qid] = d
        if not depth_map:
            continue

        batch_ids = list(depth_map)
        log(f"BFS: fetching batch of {len(batch_ids)} works at depth ≤ {crawl_depth}...")
        # NOTE(review): assumes returned works use the same "id" format as the
        # `referenced_works` entries queued above — TODO confirm against the client.
        works = client.get_works_batch(batch_ids) or []

        for w in works:
            if not isinstance(w, dict):
                continue
            wid = w.get("id")
            if not wid or wid in seen:
                continue
            depth = depth_map.get(wid)
            if depth is None:
                # Not something we asked for in this batch.
                continue

            staged.append(make_staged_row(w, keyword_label))
            seen.add(wid)

            if len(staged) >= files_per_run:
                log(f"Reached stage limit ({files_per_run}).")
                break

            enqueue_refs(w, depth + 1)

    log(f"OpenAlex staged {len(staged)} items.")
    return staged