# MIT License
# Copyright (c) 2025 aiquniq
# See LICENSE file in the project root for full license text.

import time
import requests
from typing import Dict, Any, List, Optional, Iterable

# Root URL of the OpenAlex REST API; endpoint paths such as "/works" are appended to it.
BASE_URL = "https://api.openalex.org"


class OpenAlexClient:
    """Small client for the OpenAlex REST API built on one shared ``requests`` session.

    Every request goes through :meth:`_get`, which adds the ``mailto`` query
    parameter when a contact email is configured, raises on HTTP error
    statuses and sleeps ``pause`` seconds after each successful response.

    The client can be used as a context manager so the underlying session is
    closed when the ``with`` block exits.
    """

    # Largest page size this client requests; a batch chunk must never exceed
    # it, otherwise part of the matching works would land on a second page.
    _PER_PAGE = 200

    def __init__(self, user_agent: str, contact_email: str, pause: float = 0.2):
        """Create the client and its HTTP session.

        Args:
            user_agent: Value for the ``User-Agent`` header; a blank value
                falls back to ``"aiquniq/ROTmap_open/source"``.
            contact_email: Contact address advertised to the API; may be blank.
            pause: Seconds to sleep after every successful request.
        """
        # ``or ""`` keeps a ``None`` argument from crashing on ``.strip()``.
        self.user_agent = (user_agent or "").strip() or "aiquniq/ROTmap_open/source"
        self.contact_email = (contact_email or "").strip()
        self.pause = pause
        self.session = requests.Session()
        self.session.headers.update(
            {"User-Agent": self._format_user_agent(self.user_agent, self.contact_email)}
        )

    @staticmethod
    def _format_user_agent(user_agent: str, contact_email: str) -> str:
        """Return the ``User-Agent`` header value.

        The ``(mailto:...)`` suffix is appended only when an email is present,
        so a blank email no longer produces a dangling ``(mailto:)``.
        """
        if contact_email:
            return f"{user_agent} (mailto:{contact_email})"
        return user_agent

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self.session.close()

    def __enter__(self) -> "OpenAlexClient":
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        self.close()

    def _get(self, path: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """GET ``BASE_URL + path`` and return the decoded JSON body.

        Args:
            path: Endpoint path beginning with ``/`` (e.g. ``"/works"``).
            params: Query parameters; the caller's dict is not mutated.

        Raises:
            Whatever ``raise_for_status()`` raises for an HTTP error status.
        """
        if self.contact_email:
            # Build a new dict rather than modifying the caller's.
            params = {**params, "mailto": self.contact_email}
        url = f"{BASE_URL}{path}"
        resp = self.session.get(url, params=params, timeout=30)
        resp.raise_for_status()
        # Throttle between consecutive calls; skipped when the request failed.
        time.sleep(self.pause)
        return resp.json()

    def resolve_concept(self, topic_query: str) -> Optional[str]:
        """Return the ``id`` of the first ``/concepts`` search hit, or ``None``.

        ``None`` is returned for an empty query (no request is made) and when
        the search yields no results.
        """
        if not topic_query:
            return None
        j = self._get("/concepts", {"search": topic_query, "per_page": 1})
        results = j.get("results", [])
        if results:
            return results[0].get("id")
        return None

    def search_works(
        self,
        query: Optional[str],
        concept_id: Optional[str],
        is_oa_only: bool = True,
        year_from: Optional[int] = None,
        year_to: Optional[int] = None,
    ) -> Iterable[Dict[str, Any]]:
        """Lazily yield works from ``/works``, following cursor pagination.

        Results are requested sorted by ``cited_by_count:desc``.

        Args:
            query: Full-text search string; omitted from the request if empty.
            concept_id: Concept identifier; only the segment after the last
                ``/`` is used in the filter, so a URL-style id is accepted.
            is_oa_only: Restrict to works with ``open_access.is_oa:true``.
            year_from: Earliest publication year (filter starts on Jan 1).
            year_to: Latest publication year (filter ends on Dec 31).

        Yields:
            One result dict per work, in the order the API returns them.
        """
        filters = []
        if concept_id:
            filters.append(f"concepts.id:{concept_id.split('/')[-1]}")
        if is_oa_only:
            filters.append("open_access.is_oa:true")
        if year_from:
            filters.append(f"from_publication_date:{year_from}-01-01")
        if year_to:
            filters.append(f"to_publication_date:{year_to}-12-31")
        params = {
            "search": query or None,
            "filter": ",".join(filters) if filters else None,
            "per_page": self._PER_PAGE,
            "cursor": "*",
            "sort": "cited_by_count:desc",
        }
        while True:
            # Parameters left as None are dropped instead of being sent.
            j = self._get("/works", {k: v for k, v in params.items() if v is not None})
            results = j.get("results", [])
            yield from results
            cursor = j.get("meta", {}).get("next_cursor")
            # Also stop on an empty page: a cursor that keeps coming back with
            # no results would otherwise loop forever.
            if not cursor or not results:
                break
            params["cursor"] = cursor

    def get_works_batch(
        self, openalex_ids: List[str], batch_size: int = 50
    ) -> List[Dict[str, Any]]:
        """Fetch works by OpenAlex id, splitting the ids into bounded requests.

        Sending every id in one OR filter fails or truncates for large inputs:
        only a single page of ``per_page`` results is read, and the API limits
        how many ``|``-separated values one filter may carry (NOTE(review):
        documented as 100 at the time of writing — confirm; the default of 50
        stays below it).

        Args:
            openalex_ids: Identifiers to look up; empty entries are skipped.
            batch_size: Maximum ids per request, capped at the page size so a
                chunk always fits on one page.

        Returns:
            The concatenated ``results`` lists of all requests, chunk by chunk.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        ids = [oid for oid in (openalex_ids or []) if oid]
        if not ids:
            return []
        chunk_size = min(batch_size, self._PER_PAGE)
        works: List[Dict[str, Any]] = []
        for start in range(0, len(ids), chunk_size):
            ids_str = "|".join(ids[start:start + chunk_size])
            j = self._get(
                "/works",
                {"filter": f"ids.openalex:{ids_str}", "per_page": self._PER_PAGE},
            )
            works.extend(j.get("results", []))
        return works
