# MIT License
# Copyright (c) 2025 aiquniq
# See LICENSE file in the project root for full license text.

import os
import time
import datetime as dt
from typing import Dict, Any, Iterable, List, Optional, Tuple
import requests
import xml.etree.ElementTree as ET

# arXiv query API endpoints. ArxivClient starts on the HTTPS URL and swaps to
# the other scheme when a request fails.
ARXIV_API_HTTPS = "https://export.arxiv.org/api/query"
ARXIV_API_HTTP  = "http://export.arxiv.org/api/query"

# Prefix -> namespace URI map passed to the ElementTree find()/findall() calls
# that parse the returned feed.
NS = {
    "atom": "http://www.w3.org/2005/Atom",
    "arxiv": "http://arxiv.org/schemas/atom",
    "opensearch": "http://a9.com/-/spec/opensearch/1.1/",
}

def _parse_dt(s: str) -> dt.datetime:
    """Parse an ISO-8601 timestamp string into a datetime.

    Every "Z" in the input is rewritten to "+00:00" before parsing, so
    UTC-suffixed timestamps are accepted. A falsy input (None or "") falls
    back to the Unix epoch in UTC.
    """
    raw = s if s else "1970-01-01T00:00:00Z"
    normalized = raw.replace("Z", "+00:00")
    return dt.datetime.fromisoformat(normalized)

def _text(e: Optional[ET.Element]) -> Optional[str]:
    """Return the stripped text of an element, or None when it has none.

    None is returned both for a missing element and for an element whose
    text is None or empty. Whitespace-only text is stripped to "".
    """
    if e is None:
        return None
    content = e.text
    if not content:
        return None
    return content.strip()

def _strip_prefix(s: str, prefix: str) -> str:
    """Return *s* without a leading *prefix*; *s* is returned as-is when it does not start with it."""
    if not s.startswith(prefix):
        return s
    cut = len(prefix)
    return s[cut:]

class ArxivClient:
    def __init__(self, user_agent: str, pause: float = 0.3, timeout: int = 45, debug: bool = False):
        """Set up the HTTP session used for all API calls.

        Args:
            user_agent: Value sent as the User-Agent header; a falsy value is
                replaced by a built-in default string.
            pause: Seconds to sleep after a successful request. Any falsy
                value (including 0) is replaced by 0.3.
            timeout: Timeout passed to each HTTP request.
            debug: Enables debug printing; also switched on when the
                ARXIV_DEBUG environment variable is set to a non-empty value.
        """
        agent = user_agent if user_agent else "aiquniq/ROTmap_open/source"
        self.user_agent = agent.strip()

        self.pause = float(pause) if pause else float(0.3)
        self.timeout = timeout

        env_debug = os.environ.get("ARXIV_DEBUG")
        self.debug = debug or bool(env_debug)

        default_headers = {
            "User-Agent": self.user_agent,
            "Accept": "application/atom+xml",
        }
        session = requests.Session()
        session.headers.update(default_headers)
        self.session = session

        # Requests start on HTTPS; _request may switch this on failure.
        self._base_url = ARXIV_API_HTTPS

    def _request(self, params: Dict[str, Any]) -> str:
        """GET the arXiv API with *params* and return the response body.

        One attempt is made against the current base URL. If it fails, the
        client switches to the other scheme (HTTPS <-> HTTP), keeps that
        choice for later calls, and tries exactly once more.

        Args:
            params: Query-string parameters for the API call.

        Returns:
            The response body as text.

        Raises:
            requests.RequestException: if the second attempt also fails
                (connection error, timeout, or an HTTP error status).
        """
        def attempt() -> str:
            # Reads self._base_url at call time, so the retry hits the new URL.
            resp = self.session.get(
                self._base_url,
                params=params,
                timeout=self.timeout,
                allow_redirects=True,
            )
            resp.raise_for_status()
            if self.pause:
                time.sleep(self.pause)
            return resp.text

        try:
            return attempt()
        except requests.RequestException as exc:
            fallback = ARXIV_API_HTTP if self._base_url == ARXIV_API_HTTPS else ARXIV_API_HTTPS
            if self.debug:
                print(f"[arXiv] _request error on {self._base_url}: {exc}; retrying {fallback}")
            self._base_url = fallback
            return attempt()

    def _parse_feed(self, xml_text: str) -> Dict[str, Any]:
        """Parse an Atom feed returned by the arXiv API.

        Args:
            xml_text: Raw XML text of the feed.

        Returns:
            A dict with two keys:
              * "entries": list of dicts, one per atom:entry, with the keys
                source, id, arxiv_id, title, summary, authors, published,
                updated, doi, primary_category, categories, link_abs and
                link_pdf.
              * "total": the integer value of opensearch:totalResults, or
                None when the element is missing or not an integer.

        Raises:
            xml.etree.ElementTree.ParseError: if *xml_text* is not
                well-formed XML.
            ValueError: if a published/updated timestamp cannot be parsed.
        """
        root = ET.fromstring(xml_text)

        entries = []
        for e in root.findall("atom:entry", NS):
            id_url = _text(e.find("atom:id", NS)) or ""
            # Old-style identifiers contain a slash ("hep-th/9901001v1"), so
            # keep everything after "/abs/" instead of only the last path
            # segment. Fall back to the last segment for any other URL shape.
            _, marker, tail = id_url.partition("/abs/")
            arxiv_id = tail if marker else id_url.rsplit("/", 1)[-1]

            pdf_url = None
            abs_url = None
            for link in e.findall("atom:link", NS):
                rel = link.get("rel", "")
                href = link.get("href")
                if rel == "alternate":
                    abs_url = href
                elif rel in ("related", "enclosure"):
                    # Accept either an explicit PDF MIME type or a .pdf URL.
                    if link.get("type") == "application/pdf" or (href or "").endswith(".pdf"):
                        pdf_url = href

            cats = [c.get("term") for c in e.findall("atom:category", NS) if c.get("term")]
            primary_cat_el = e.find("arxiv:primary_category", NS)
            if primary_cat_el is not None:
                primary_category = primary_cat_el.get("term")
            else:
                # No explicit primary category: use the first listed one.
                primary_category = cats[0] if cats else None

            doi = _text(e.find("arxiv:doi", NS))

            authors = []
            for a in e.findall("atom:author", NS):
                n = _text(a.find("atom:name", NS))
                if n:
                    authors.append(n)

            published = _parse_dt(_text(e.find("atom:published", NS)))
            updated = _parse_dt(_text(e.find("atom:updated", NS)))

            title = _text(e.find("atom:title", NS)) or ""
            summary = _text(e.find("atom:summary", NS)) or ""

            entries.append({
                "source": "arxiv",
                "id": id_url,
                "arxiv_id": arxiv_id,
                "title": title,
                "summary": summary,
                "authors": authors,
                "published": published,
                "updated": updated,
                "doi": doi,
                "primary_category": primary_category,
                "categories": cats,
                "link_abs": abs_url or id_url,
                "link_pdf": pdf_url,
            })

        total_results = None
        total_results_el = root.find("opensearch:totalResults", NS)
        if total_results_el is not None:
            try:
                total_results = int(_text(total_results_el))
            except (TypeError, ValueError):
                # Empty element (None) or non-numeric text.
                total_results = None

        return {"entries": entries, "total": total_results}

    def _build_search_query(
        self,
        query: Optional[str],
        categories: Optional[List[str]],
        authors: Optional[List[str]],
        title_terms: Optional[List[str]],
        abstract_terms: Optional[List[str]],
    ) -> str:
        """Assemble the value of the API's ``search_query`` parameter.

        Clauses are emitted in a fixed order and joined with " AND ":
        ``all:`` for *query*, then ``ti:`` / ``abs:`` / ``au:`` for each
        title, abstract and author term, then one parenthesised
        ``cat:… OR cat:…`` group for *categories*.

        A value is wrapped in double quotes when it contains whitespace, or,
        for *query*, when the caller already passed it quoted. Double quotes
        inside a value are removed so they cannot break the quoting. Blank
        values are ignored.

        Returns:
            The query string, or "all:quantum" when no clause was produced.
        """
        def clean(value: str) -> str:
            return value.replace('"', "").strip()

        def has_space(value: str) -> bool:
            return any(ch.isspace() for ch in value)

        def field_clause(field: str, terms: Optional[List[str]]) -> Optional[str]:
            cleaned = [clean(t) for t in (terms or []) if t]
            cleaned = [t for t in cleaned if t]
            if not cleaned:
                return None
            return " AND ".join(
                f'{field}:"{t}"' if has_space(t) else f"{field}:{t}" for t in cleaned
            )

        parts: List[str] = []

        if query:
            q = query.strip()
            was_quoted = len(q) >= 2 and q.startswith('"') and q.endswith('"')
            q_clean = clean(q)
            if q_clean:
                # Keep a phrase the caller quoted explicitly as a phrase.
                if was_quoted or has_space(q_clean):
                    parts.append(f'all:"{q_clean}"')
                else:
                    parts.append(f"all:{q_clean}")

        for field, terms in (("ti", title_terms), ("abs", abstract_terms), ("au", authors)):
            clause = field_clause(field, terms)
            if clause:
                parts.append(clause)

        if categories:
            cats = [c.strip() for c in categories if c and c.strip()]
            if cats:
                parts.append("(" + " OR ".join(f"cat:{c}" for c in cats) + ")")

        return " AND ".join(parts) or "all:quantum"

    # Public API

    def search_works(
        self,
        query: Optional[str],
        categories: Optional[List[str]] = None,
        authors: Optional[List[str]] = None,
        title_terms: Optional[List[str]] = None,
        abstract_terms: Optional[List[str]] = None,
        year_from: Optional[int] = None,
        year_to: Optional[int] = None,
        sort_by: str = "submittedDate",      # relevance | lastUpdatedDate | submittedDate
        sort_order: str = "descending",      # ascending | descending
        batch_size: int = 200,
        max_results_total: int = 1000,
    ) -> Iterable[Dict[str, Any]]:
        """Yield entries matching the search, fetching one page per request.

        This is a generator: nothing is requested (and arguments are not
        validated) until iteration starts.

        Args:
            query: Free-text query, or None.
            categories: Category terms; an entry may match any of them.
            authors: Author terms.
            title_terms: Title terms.
            abstract_terms: Abstract terms.
            year_from: Lowest accepted 'published' year (inclusive); falsy
                disables the bound.
            year_to: Highest accepted 'published' year (inclusive); falsy
                disables the bound.
            sort_by: Sent as the ``sortBy`` request parameter.
            sort_order: Sent as the ``sortOrder`` request parameter.
            batch_size: Page size per request; at most 300.
            max_results_total: Maximum number of entries to yield. Entries
                dropped by the year filter do not count toward it.

        Yields:
            Entry dicts as produced by ``_parse_feed``, each exactly once.
            The year bounds are applied client-side to each entry's
            'published' year.

        Raises:
            ValueError: if *batch_size* exceeds 300.
        """
        if batch_size > 300:
            raise ValueError("arXiv API limit/ 300")

        search_query = self._build_search_query(query, categories, authors, title_terms, abstract_terms)
        if self.debug:
            print(f"[arXiv] search_query = {search_query}")

        def _needs_phrase_retry(q: Optional[str], built: str) -> bool:
            # NOTE(review): _build_search_query already emits all:"…" for an
            # unquoted multi-word query, so this looks unreachable — confirm
            # the intended behavior before relying on the phrase retry.
            if not q:
                return False
            q = q.strip()
            if not q:
                return False
            is_quoted = q.startswith('"') and q.endswith('"')
            return (not is_quoted) and any(ch.isspace() for ch in q) and ('all:"' not in built)

        def _in_year_range(entry: Dict[str, Any]) -> bool:
            year = entry["published"].year
            if year_from and year < year_from:
                return False
            if year_to and year > year_to:
                return False
            return True

        pulled = 0
        start = 0
        retried_phrase = False
        tried_protocol_fallback = False

        while pulled < max_results_total:
            to_get = min(batch_size, max_results_total - pulled)
            params = {
                "search_query": search_query,
                "start": start,
                "max_results": to_get,
                "sortBy": sort_by,
                "sortOrder": sort_order,
            }

            try:
                xml_text = self._request(params)
            except Exception as e:
                if self.debug:
                    print(f"[arXiv] request error: {e}")
                # Allow one extra scheme switch per search; after that, give up.
                if tried_protocol_fallback:
                    raise
                self._base_url = ARXIV_API_HTTP if self._base_url == ARXIV_API_HTTPS else ARXIV_API_HTTPS
                tried_protocol_fallback = True
                if self.debug:
                    print(f"[arXiv] switching base URL to {self._base_url} and retrying…")
                xml_text = self._request(params)

            feed = self._parse_feed(xml_text)
            entries = feed["entries"]

            if self.debug:
                print(f"[arXiv] page start={start} got={len(entries)} pulled={pulled}/{max_results_total} via {self._base_url}")

            if (
                start == 0
                and not entries
                and not retried_phrase
                and _needs_phrase_retry(query, search_query)
            ):
                q_clean = (query or "").strip().replace('"', "")
                search_query = self._build_search_query(f'"{q_clean}"', categories, authors, title_terms, abstract_terms)
                retried_phrase = True
                if self.debug:
                    print(f"[arXiv] first page empty; retrying with phrase: {search_query}")
                xml_text = self._request({**params, "search_query": search_query})
                feed = self._parse_feed(xml_text)
                entries = feed["entries"]
                if self.debug:
                    print(f"[arXiv] phrase retry got={len(entries)}")

            if not entries:
                break

            total = feed.get("total")

            # Yield each entry of the page once, honoring the overall cap.
            for e in entries:
                if not _in_year_range(e):
                    continue
                yield e
                pulled += 1
                if pulled >= max_results_total:
                    break

            # Advance by what the server returned, not by what was yielded.
            start += len(entries)

            # Stop once the reported result count has been consumed.
            if total is not None and start >= total:
                break

    def get_works_batch(self, arxiv_ids: List[str]) -> List[Dict[str, Any]]:
        """Fetch entries for the given arXiv ids, 50 ids per request.

        Args:
            arxiv_ids: Identifiers to look up; an empty list returns [].

        Returns:
            The parsed entries of all responses, concatenated in request
            order.
        """
        if not arxiv_ids:
            return []
        out: List[Dict[str, Any]] = []
        for i in range(0, len(arxiv_ids), 50):
            chunk = arxiv_ids[i:i+50]
            params = {
                "id_list": ",".join(chunk),
                # Ask for the whole chunk explicitly; without this the
                # response is limited to the API's default page size.
                "max_results": len(chunk),
            }
            xml_text = self._request(params)
            feed = self._parse_feed(xml_text)
            out.extend(feed["entries"])
        return out

    # downloads

    def download_pdf(self, entry: Dict[str, Any], dest_dir: str) -> Tuple[bool, Optional[str]]:
        """Download the PDF for *entry* into *dest_dir*.

        The file is named after the entry's "arxiv_id" (with "/" replaced by
        "_"), or "paper" when the id is missing. A non-empty file already
        present under that name is reused without any network access.

        Args:
            entry: Entry dict; the keys "link_pdf" and "arxiv_id" are read.
            dest_dir: Target directory; created if it does not exist.

        Returns:
            (True, path) on success or when the file was already present;
            (False, None) when the response is not a 200 with a PDF content
            type.
        """
        pdf_url = entry.get("link_pdf")
        arxiv_id = entry.get("arxiv_id") or "paper"
        if not pdf_url:
            # Drop only a trailing "vN" version suffix. Splitting on the first
            # "v" would truncate ids whose archive name contains that letter
            # (e.g. "solv-int/9901001v1").
            head, sep, version = arxiv_id.rpartition("v")
            base_id = head if sep and head and version.isdigit() else arxiv_id
            pdf_url = f"https://arxiv.org/pdf/{base_id}.pdf"

        os.makedirs(dest_dir, exist_ok=True)
        fname = arxiv_id.replace("/", "_") + ".pdf"
        fpath = os.path.join(dest_dir, fname)

        if os.path.exists(fpath) and os.path.getsize(fpath) > 0:
            return True, fpath

        with requests.get(pdf_url, headers={"User-Agent": self.user_agent}, stream=True, timeout=self.timeout) as r:
            if r.status_code != 200 or "application/pdf" not in (r.headers.get("Content-Type") or ""):
                return False, None
            # Write to a temporary name and rename at the end so a partial
            # download is never mistaken for a finished file.
            tmp = fpath + ".part"
            try:
                with open(tmp, "wb") as fh:
                    for chunk in r.iter_content(chunk_size=1 << 14):
                        if chunk:
                            fh.write(chunk)
                os.replace(tmp, fpath)
            finally:
                # Still present only if the download or rename failed.
                if os.path.exists(tmp):
                    os.remove(tmp)
        if self.pause:
            time.sleep(self.pause)
        return True, fpath