# MIT License
# Copyright (c) 2025 aiquniq
# See LICENSE file in the project root for full license text.

import json
from collections import deque
from typing import Dict,Any,List,Set,Optional,Tuple,Callable, Iterable

from .client import ADSClient
from storage.files import stream_download
from storage import db as dbmod
from utils.text import slugify

def pick_pdf_and_html(doc:Dict[str,Any]):
    """Choose a PDF URL and a DOI landing-page URL for one ADS document.

    PDF URL: the first ``links_data`` entry whose URL or type/title mentions
    "pdf" and whose URL starts with ``http``. If there is none, it is derived
    from the first ``arXiv:<id>`` entry in ``identifier``
    (``https://arxiv.org/pdf/<id>.pdf``).

    HTML URL: ``https://doi.org/<doi>``, where ``doi`` is either a string or
    a list whose first element is used. An empty or non-string DOI gives no
    HTML URL.

    ``links_data`` entries may be dicts or JSON-encoded object strings (the
    raw ADS ``links_data`` field is a list of JSON strings). Undecodable or
    non-object entries are ignored.

    Returns:
        ``(pdf_url, html_url)``; either element may be ``None``.
    """
    pdf=None; html=None
    links=doc.get('links_data') or []
    if isinstance(links,list):
        for e in links:
            if isinstance(e,str):
                # Decode JSON-string entries instead of silently dropping them.
                try: e=json.loads(e)
                except ValueError: continue
            if not isinstance(e,dict): continue
            url=str(e.get('url') or e.get('href') or '')
            t=str(e.get('type') or e.get('title') or '')
            if 'pdf' in (url+' '+t).lower() and url.startswith('http'):
                pdf=url; break
    if not pdf:
        ids=doc.get('identifier') or []
        if isinstance(ids,list):
            for ident in ids:
                if isinstance(ident,str) and ident.lower().startswith('arxiv:'):
                    arx=ident.split(':',1)[1].strip()
                    if arx:
                        pdf=f'https://arxiv.org/pdf/{arx}.pdf'; break
    doi=doc.get('doi')
    if isinstance(doi,list):
        doi=doi[0] if doi else None
    # Guard against None/empty entries, which would otherwise yield ".../None".
    if isinstance(doi,str) and doi:
        html=f'https://doi.org/{doi}'
    return pdf, html

def title_of(doc:Dict[str,Any])->str:
    """Return the document's title, or '' when it is absent or unusable.

    ``title`` may be a plain string (returned as-is) or a list, in which
    case its first element is returned.
    """
    raw=doc.get('title')
    if isinstance(raw,str):
        return raw
    if isinstance(raw,list) and len(raw)>0:
        return raw[0]
    return ''

def authors_of(doc:Dict[str,Any])->str:
    """Return the document's authors as one string, or ''.

    A string ``author`` value is returned unchanged. A list is joined with
    ", ", keeping only its first 25 names. Anything else yields ''.
    """
    names=doc.get('author')
    if isinstance(names,str):
        return names
    if not isinstance(names,list):
        return ''
    return ', '.join(names[:25])

# staging
def crawl_ads_index_only(
    client: ADSClient,
    search_keyword: Optional[str],
    crawl_depth: int,
    files_per_run: int,
    year_from: Optional[int]=None,
    year_to: Optional[int]=None,
    databases=None,
    refereed_only: bool=False,
    open_access_only: bool=False,
    bibstems=None,
    doctypes=None,
    extra_query: Optional[str]=None,
    progress: Optional[Callable[[str],None]]=None
)->List[Dict[str,Any]]:
    """Collects metadata only. Returns a list of staged items.

    Runs ``client.search_works`` with the given filters to get seed works
    (depth 0). It then follows their ``reference`` bibcodes breadth-first, up
    to ``crawl_depth`` hops, fetching queued bibcodes in batches of 200 via
    ``client.get_works_batch``. Only works for which ``pick_pdf_and_html``
    finds a PDF URL are staged. Nothing is downloaded and no database is
    touched.

    Args:
        client: object providing ``search_works`` and ``get_works_batch``.
        search_keyword: search text; also copied into every staged row.
        crawl_depth: number of reference hops to follow (< 1: seeds only).
        files_per_run: maximum number of rows to stage (<= 0: none).
        year_from, year_to, databases, refereed_only, open_access_only,
            bibstems, doctypes, extra_query: forwarded to ``search_works``.
        progress: optional callback for progress messages; any exception it
            raises is ignored.

    Returns:
        Staged row dicts. Rows found by the reference crawl also carry a
        ``depth`` key; seed rows do not.
    """
    def log(m):
        try: progress and progress(m)
        except Exception: pass  # progress reporting is best-effort

    def make_row(d, bib, pdf, html):
        doi=d.get('doi')
        return {
            'source': 'NASA ADS',
            'id': f'ads:{bib}',
            'bibcode': bib,
            'doi': (doi or [None])[0] if isinstance(doi,list) else doi,
            'title': title_of(d) or None,
            'publication_year': d.get('year'),
            'authors': authors_of(d) or None,
            'pdf_url': pdf,
            'html_url': html,
            'search_keyword': search_keyword or ''
        }

    staged: List[Dict[str,Any]] = []
    seen: Set[str] = set()    # bibcodes already staged or rejected (no PDF)
    queued: Set[str] = set()  # bibcodes ever enqueued; avoids refetching popular refs
    queue = deque()           # (bibcode, depth) pairs awaiting fetch, BFS order

    def enqueue_refs(d, depth):
        """Queue the not-yet-seen, not-yet-queued references of ``d`` at ``depth``."""
        refs=d.get('reference')
        if not isinstance(refs,list): return
        added=0
        for r in refs:
            if isinstance(r,str) and r not in seen and r not in queued:
                queued.add(r); queue.append((r,depth)); added+=1
        if added: log(f'ADS enqueued {added} references (depth {depth}).')

    if files_per_run <= 0:
        log(f'ADS reached files-per-run limit ({files_per_run}).')
        return staged

    log('ADS: fetching seed works…')
    for d in client.search_works(search_keyword,year_from,year_to,databases,refereed_only,open_access_only,bibstems,doctypes,extra_query):
        bib=d.get('bibcode')
        if not bib or bib in seen: continue
        pdf,html=pick_pdf_and_html(d)
        if not pdf: continue
        staged.append(make_row(d,bib,pdf,html)); seen.add(bib)
        if len(staged) >= files_per_run:
            log(f'ADS reached files-per-run limit ({files_per_run}).')
            return staged
        # Only enqueue when a first hop is actually allowed.
        if crawl_depth >= 1:
            enqueue_refs(d,1)

    while queue and len(staged) < files_per_run:
        batch=[queue.popleft() for _ in range(min(200,len(queue)))]
        # First occurrence wins: in BFS order that is the shallowest depth.
        depth_of={}
        for b,dpt in batch:
            depth_of.setdefault(b,dpt)
        ids=[b for b,dpt in depth_of.items() if dpt<=crawl_depth]
        if not ids: continue
        log(f'ADS BFS: fetching batch of {len(ids)} at depth ≤ {crawl_depth}…')
        for d in client.get_works_batch(ids):
            bib=d.get('bibcode')
            # NOTE(review): if get_works_batch returns a different (e.g. canonical)
            # bibcode than the one queued, this lookup misses and the work is
            # skipped — confirm against ADSClient.
            depth=depth_of.get(bib) if bib else None
            if not bib or bib in seen or depth is None or depth>crawl_depth: continue
            pdf,html=pick_pdf_and_html(d)
            if not pdf: seen.add(bib); continue
            row=make_row(d,bib,pdf,html)
            row['depth']=depth
            staged.append(row); seen.add(bib)
            if depth+1<=crawl_depth:
                enqueue_refs(d,depth+1)
            if len(staged) >= files_per_run:
                log(f'ADS reached files-per-run limit ({files_per_run}).')
                break

    log(f'ADS crawl finished. Staged {len(staged)} items (no files saved).')
    return staged

# save
def crawl_and_collect_ads(
    client: ADSClient,
    search_keyword: Optional[str],
    crawl_depth: int,
    files_per_run: int,
    year_from: Optional[int]=None,
    year_to: Optional[int]=None,
    databases=None,
    refereed_only: bool=False,
    open_access_only: bool=False,
    bibstems=None,
    doctypes=None,
    extra_query: Optional[str]=None,
    progress: Optional[Callable[[str],None]]=None
)->int:
    """Downloads PDFs and writes DB.

    Runs ``client.search_works`` with the given filters to get seed works
    (depth 0). It then follows their ``reference`` bibcodes breadth-first, up
    to ``crawl_depth`` hops, fetching queued bibcodes in batches of 200 via
    ``client.get_works_batch``. For each work with a PDF URL (see
    ``pick_pdf_and_html``), the PDF is fetched with ``stream_download`` into
    ``data/files``. When that returns a path, a row is upserted with
    ``dbmod.upsert_work``.

    The connection is committed on normal completion and always closed, even
    when an exception propagates.

    Args:
        client: object providing ``search_works``, ``get_works_batch`` and
            ``session`` (passed to ``stream_download``).
        search_keyword: search text; also stored on every row.
        crawl_depth: number of reference hops to follow (< 1: seeds only).
        files_per_run: maximum number of works to save (<= 0: none).
        year_from, year_to, databases, refereed_only, open_access_only,
            bibstems, doctypes, extra_query: forwarded to ``search_works``.
        progress: optional callback for progress messages; any exception it
            raises is ignored.

    Returns:
        Number of works downloaded and upserted.
    """
    def log(m):
        try: progress and progress(m)
        except Exception: pass  # progress reporting is best-effort

    if files_per_run <= 0:
        log(f'ADS reached files-per-run limit ({files_per_run}).')
        return 0

    seen: Set[str] = set()    # bibcodes saved or rejected
    queued: Set[str] = set()  # bibcodes ever enqueued; avoids refetching popular refs
    queue = deque()           # (bibcode, depth) pairs awaiting fetch, BFS order

    def enqueue_refs(d, depth):
        """Queue the not-yet-seen, not-yet-queued references of ``d`` at ``depth``."""
        refs=d.get('reference')
        if not isinstance(refs,list): return
        added=0
        for r in refs:
            if isinstance(r,str) and r not in seen and r not in queued:
                queued.add(r); queue.append((r,depth)); added+=1
        if added: log(f'ADS enqueued {added} references (depth {depth}).')

    def download(d, bib, pdf):
        """Fetch the PDF under a slug of the title; return (title, path or falsy)."""
        title=title_of(d) or bib
        path=stream_download(pdf,'data/files',slugify(title)[:120],client.session,require_pdf=True)
        return title, path

    def make_row(d, bib, pdf, html, path):
        doi=d.get('doi')
        # NOTE(review): is_oa is always 1 here, regardless of open_access_only.
        return {'id':f'ads:{bib}','doi':(doi or [None])[0] if isinstance(doi,list) else doi,
                'title':title_of(d) or None,'publication_year':d.get('year'),'primary_location':'NASA ADS','is_oa':1,
                'pdf_url':pdf,'html_url':html,'file_path':path,'abstract_inverted_index':None,'concepts':None,
                'authors':authors_of(d) or None,'search_keyword':search_keyword or ''}

    dbmod.ensure_dirs()
    conn=dbmod.connect()
    saved=0
    try:
        log('ADS: fetching seed works...')
        for d in client.search_works(search_keyword,year_from,year_to,databases,refereed_only,open_access_only,bibstems,doctypes,extra_query):
            bib=d.get('bibcode')
            if not bib or bib in seen: continue
            pdf,html=pick_pdf_and_html(d)
            if not pdf: continue
            title,path=download(d,bib,pdf)
            if not path: continue
            dbmod.upsert_work(conn,make_row(d,bib,pdf,html,path)); seen.add(bib); saved+=1
            log(f'ADS saved ({saved}): {title[:100]}')
            # Only enqueue when a first hop is actually allowed.
            if crawl_depth >= 1:
                enqueue_refs(d,1)
            if saved>=files_per_run:
                conn.commit()
                log(f'ADS reached files-per-run limit ({files_per_run}).')
                return saved

        while queue and saved<files_per_run:
            batch=[queue.popleft() for _ in range(min(200,len(queue)))]
            # First occurrence wins: in BFS order that is the shallowest depth.
            depth_of={}
            for b,dpt in batch:
                depth_of.setdefault(b,dpt)
            ids=[b for b,dpt in depth_of.items() if dpt<=crawl_depth]
            if not ids: continue
            log(f'ADS BFS: fetching batch of {len(ids)} at depth ≤ {crawl_depth}...')
            for d in client.get_works_batch(ids):
                bib=d.get('bibcode')
                # NOTE(review): a returned bibcode differing from the queued one
                # (e.g. canonical vs. alternate) misses this lookup and is skipped.
                depth=depth_of.get(bib) if bib else None
                if not bib or bib in seen or depth is None or depth>crawl_depth: continue
                pdf,html=pick_pdf_and_html(d)
                if not pdf: seen.add(bib); continue
                title,path=download(d,bib,pdf)
                if not path: seen.add(bib); continue
                dbmod.upsert_work(conn,make_row(d,bib,pdf,html,path)); seen.add(bib); saved+=1
                log(f'ADS saved ({saved}) [depth {depth}]: {title[:100]}')
                if depth+1<=crawl_depth:
                    enqueue_refs(d,depth+1)
                if saved>=files_per_run:
                    log(f'ADS reached files-per-run limit ({files_per_run}).')
                    break

        conn.commit()
        log('ADS crawl finished.')
        return saved
    finally:
        # Close on every path, including exceptions from the client, downloads or upserts.
        conn.close()