# MIT License
# Copyright (c) 2025 aiquniq
# See LICENSE file in the project root for full license text.

import os
import time
import requests
from typing import Dict, Any, List, Optional, Iterable

# Root URL of the ADS API (version 1). ADSClient._get appends the endpoint
# path (e.g. "search/query") to this value to form the request URL.
BASE_URL = "https://api.adsabs.harvard.edu/v1"

class ADSClient:
    """Small client for the ADS search API built on a ``requests`` session.

    The session carries the bearer token, an ``Accept: application/json``
    header and a ``User-Agent`` header on every request.
    """

    # Value sent as the ``fl`` request parameter (the document fields asked for).
    _FIELDS = "id,bibcode,title,year,doi,links_data,identifier,reference,author"
    # Rows requested per call; also the chunk size for bibcode lookups.
    _PAGE_SIZE = 200

    def __init__(self, api_token: str, user_agent: str, pause: float = 0.2, timeout: int = 45):
        """Validate the token and prepare the authenticated session.

        Args:
            api_token: ADS API token; surrounding whitespace is stripped.
            user_agent: Value for the ``User-Agent`` header. A falsy value
                falls back to ``"aiquniq/ROTmap_open/source"``.
            pause: Seconds to sleep after each successful request. Any falsy
                value (``None`` or ``0``) falls back to 0.2.
            timeout: Timeout handed to every ``session.get`` call.

        Raises:
            ValueError: If the token is empty or whitespace only.
        """
        token = (api_token or "").strip()
        if not token:
            raise ValueError("ADS API token is required.")
        self.token = token
        self.user_agent = (user_agent or "aiquniq/ROTmap_open/source").strip()
        self.pause = float(pause or 0.2)
        self.timeout = timeout

        self.session = requests.Session()
        self.session.headers.update({
            "Authorization": f"Bearer {self.token}",
            "Accept": "application/json",
            "User-Agent": self.user_agent,
        })

    def _get(self, path: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """Send a GET request to ``BASE_URL/<path>`` and return the decoded JSON.

        Raises:
            RuntimeError: If the server answers with HTTP 401.
            requests.HTTPError: For any other error status
                (via ``Response.raise_for_status``).
        """
        url = f"{BASE_URL}/{path.lstrip('/')}"
        response = self.session.get(url, params=params, timeout=self.timeout)
        if response.status_code == 401:
            raise RuntimeError("ADS 401 Unauthorized. Check your API token.")
        response.raise_for_status()
        # Throttle only after a successful call; failures propagate immediately.
        if self.pause:
            time.sleep(self.pause)
        return response.json()

    def _year_range_clause(self, yf: Optional[int], yt: Optional[int]) -> Optional[str]:
        """Return a ``year:[a TO b]`` clause, or ``None`` when both bounds are unset.

        A falsy bound (``None`` or ``0``) is treated as open and rendered as ``*``.
        """
        if not yf and not yt:
            return None
        lower = yf if yf else "*"
        upper = yt if yt else "*"
        return f"year:[{lower} TO {upper}]"

    @staticmethod
    def _or_group(field: str, values: Optional[Iterable[str]]) -> Optional[str]:
        """Return ``(field:a OR field:b ...)`` for the non-empty values, else ``None``."""
        if not values:
            return None
        # A bare string would otherwise be iterated character by character.
        if isinstance(values, str):
            values = [values]
        terms = [f"{field}:{value}" for value in values if value]
        if not terms:
            return None
        return "(" + " OR ".join(terms) + ")"

    def _build_ads_filters(
        self,
        databases: Optional[Iterable[str]] = None,
        refereed_only: bool = False,
        open_access_only: bool = False,
        bibstems: Optional[Iterable[str]] = None,
        doctypes: Optional[Iterable[str]] = None,
        extra_query: Optional[str] = None
    ) -> List[str]:
        """Translate the filter options into a list of query clauses.

        Clauses are emitted in a fixed order: databases, refereed,
        open access, bibstems, doctypes, then the stripped ``extra_query``.
        Multi-valued options become one parenthesised OR group each;
        empty entries are ignored. ``extra_query`` is added verbatim.
        """
        clauses: List[str] = []

        database_group = self._or_group("database", databases)
        if database_group:
            clauses.append(database_group)
        if refereed_only:
            clauses.append("property:refereed")
        if open_access_only:
            clauses.append("property:openaccess")
        for field, values in (("bibstem", bibstems), ("doctype", doctypes)):
            group = self._or_group(field, values)
            if group:
                clauses.append(group)

        extra = (extra_query or "").strip()
        if extra:
            clauses.append(extra)
        return clauses

    def search_works(
        self,
        query: Optional[str],
        year_from: Optional[int],
        year_to: Optional[int],
        databases: Optional[Iterable[str]] = None,
        refereed_only: bool = False,
        open_access_only: bool = False,
        bibstems: Optional[Iterable[str]] = None,
        doctypes: Optional[Iterable[str]] = None,
        extra_query: Optional[str] = None
    ) -> Iterable[Dict[str, Any]]:
        """Yield every document matching the query, fetching page by page.

        The free-text ``query``, the year range and all filter clauses are
        joined with ``AND``; with no clause at all the query is ``*:*``.
        Results are requested sorted by ``citation_count desc``.

        NOTE: ``query`` and ``extra_query`` are inserted verbatim, so a caller
        combining terms with ``OR`` should parenthesise the expression itself.

        Raises:
            RuntimeError, requests.HTTPError: Propagated from ``_get``.
        """
        clauses: List[str] = []
        text = (query or "").strip()
        # A whitespace-only query would otherwise yield a malformed " AND ..." string.
        if text:
            clauses.append(text)
        year_clause = self._year_range_clause(year_from, year_to)
        if year_clause:
            clauses.append(year_clause)
        clauses.extend(self._build_ads_filters(
            databases, refereed_only, open_access_only, bibstems, doctypes, extra_query
        ))

        params = {
            "q": " AND ".join(clauses) if clauses else "*:*",
            "fl": self._FIELDS,
            "rows": self._PAGE_SIZE,
            "start": 0,
            "sort": "citation_count desc",
        }

        while True:
            data = self._get("search/query", params)
            response = data.get("response") or {}
            docs = response.get("docs") or []
            yield from docs
            # Advance by what was actually received so a short page never skips records.
            next_start = params["start"] + len(docs)
            if not docs or next_start >= (response.get("numFound") or 0):
                break
            params["start"] = next_start

    def get_works_batch(self, bibcodes: List[str]) -> List[Dict[str, Any]]:
        """Fetch the documents for the given bibcodes.

        The bibcodes are looked up in chunks of ``_PAGE_SIZE`` so that inputs
        longer than one page are not silently truncated. Empty entries are
        skipped. Documents are returned in the order the server sends them,
        chunk after chunk; bibcodes the server does not return are simply
        absent from the result.

        Raises:
            RuntimeError, requests.HTTPError: Propagated from ``_get``.
        """
        wanted = [code for code in (bibcodes or []) if code]
        found: List[Dict[str, Any]] = []
        size = self._PAGE_SIZE
        for offset in range(0, len(wanted), size):
            chunk = wanted[offset:offset + size]
            quoted = " ".join(f'"{code}"' for code in chunk)
            params = {
                "q": f"bibcode:({quoted})",
                "fl": self._FIELDS,
                "rows": len(chunk),
                "start": 0,
            }
            data = self._get("search/query", params)
            found.extend((data.get("response") or {}).get("docs", []) or [])
        return found