# MIT License
# Copyright (c) 2025 aiquniq
# See LICENSE file in the project root for full license text.

import os
import sqlite3
from typing import Optional, Dict, Any, Iterable

# Default database location. The path is relative, so it resolves against the
# process's current working directory at the time it is used.
DB_PATH = os.path.join("data", "main.db")

# DDL for the single `works` table. `IF NOT EXISTS` makes the statement safe to
# run on every connection. `id` is the primary key and `created_at` is filled
# in by SQLite when a row is inserted without an explicit value.
# NOTE(review): the TEXT columns primary_location, abstract_inverted_index,
# concepts and authors presumably hold serialized (e.g. JSON) values, and
# is_oa presumably a 0/1 flag; confirm against the code that builds the rows.
SCHEMA = """
CREATE TABLE IF NOT EXISTS works (
    id TEXT PRIMARY KEY,
    doi TEXT,
    title TEXT,
    publication_year INTEGER,
    primary_location TEXT,
    is_oa INTEGER,
    pdf_url TEXT,
    html_url TEXT,
    file_path TEXT,
    abstract_inverted_index TEXT,
    concepts TEXT,
    authors TEXT,
    search_keyword TEXT,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
"""

def connect(db_path: Optional[str] = None):
    """Open the SQLite database, creating the file and schema if needed.

    Args:
        db_path: Path of the database file. When omitted or empty, ``DB_PATH``
            is used. A bare filename or ``":memory:"`` is accepted.

    Returns:
        An open ``sqlite3.Connection`` on which ``journal_mode=WAL`` and
        ``synchronous=NORMAL`` have been requested and ``SCHEMA`` executed.

    Raises:
        OSError: If the parent directory cannot be created.
        sqlite3.Error: If the database cannot be opened or initialised.
    """
    path = db_path or DB_PATH

    # A path without a directory part ("main.db", ":memory:") yields "" here,
    # and os.makedirs("") raises FileNotFoundError, so only create real parents.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    conn = sqlite3.connect(path)
    try:
        conn.execute("PRAGMA journal_mode=WAL;")
        conn.execute("PRAGMA synchronous=NORMAL;")
        conn.execute(SCHEMA)
    except Exception:
        # Do not leak the handle when initialisation fails half-way.
        conn.close()
        raise
    return conn

def upsert_work(conn, row: Dict[str, Any]):
    """Insert ``row`` into ``works``, or update the existing row with the same id.

    Keys of ``row`` that are not columns of the ``works`` table are ignored.
    Only the columns present in ``row`` are written; on an id conflict the
    other columns of the existing record are left untouched. The statement is
    executed on ``conn`` without committing.

    Args:
        conn: Open SQLite connection that has a ``works`` table.
        row: Mapping of column name to value; must contain a truthy ``id``.

    Raises:
        ValueError: If ``row`` has no ``id`` or the ``id`` is empty/falsy.
    """
    # Ask SQLite which columns exist (field 1 of each table_info record is the
    # column name) so only real column names are interpolated into the SQL.
    known_columns = set()
    for column_info in conn.execute("PRAGMA table_info(works)").fetchall():
        known_columns.add(column_info[1])

    record = {name: value for name, value in row.items() if name in known_columns}
    if not record.get("id"):
        raise ValueError("upsert_work: row missing required 'id'")

    names = list(record)
    column_list = ",".join(names)
    value_list = ",".join(f":{name}" for name in names)
    assignments = ",".join(
        f"{name}=excluded.{name}" for name in names if name != "id"
    )
    if not assignments:
        # An id-only row still needs a syntactically valid SET clause.
        assignments = "id=id"

    statement = (
        f"INSERT INTO works ({column_list}) VALUES ({value_list}) "
        f"ON CONFLICT(id) DO UPDATE SET {assignments}"
    )
    conn.execute(statement, record)
    
def count(conn) -> int:
    """Return the number of rows currently stored in the ``works`` table."""
    cursor = conn.execute("SELECT COUNT(*) FROM works")
    first_row = cursor.fetchone()
    return first_row[0]

def iter_recent(conn, limit=200) -> Iterable[Dict[str, Any]]:
    """Yield the most recently created works as dictionaries.

    Rows are ordered by ``created_at`` descending. Each yielded dict has the
    keys ``id``, ``title``, ``publication_year``, ``file_path``, ``pdf_url``
    and ``html_url``. This is a generator, so the query runs only once
    iteration starts.

    Args:
        conn: Open SQLite connection that has a ``works`` table.
        limit: Maximum number of rows to yield.
    """
    query = (
        "SELECT id, title, publication_year, file_path, pdf_url, html_url "
        "FROM works ORDER BY created_at DESC LIMIT ?"
    )
    cursor = conn.execute(query, (limit,))
    # Element 0 of each description entry is the column name.
    names = tuple(column[0] for column in cursor.description)
    yield from (dict(zip(names, values)) for values in cursor.fetchall())

def wipe_database():
    """Delete the default database and empty the ``data/files`` directory.

    Removes the file at ``DB_PATH`` together with any SQLite sidecar files
    next to it, then deletes every file and sub-directory below
    ``data/files``. The ``data/files`` directory itself is kept.

    Raises:
        OSError: If the database file or one of its sidecars exists but
            cannot be removed. Failures while emptying ``data/files`` are
            ignored on purpose (best-effort cleanup).
    """
    # connect() switches the database to WAL mode, so "-wal" and "-shm"
    # companions may exist. Leaving them behind would pair a stale log with
    # the next freshly created database; "-journal" covers rollback mode.
    for suffix in ("", "-wal", "-shm", "-journal"):
        target = DB_PATH + suffix
        if os.path.exists(target):
            os.remove(target)

    files_dir = os.path.join("data", "files")
    if not os.path.isdir(files_dir):
        return

    # Walk bottom-up so each directory is already empty when rmdir reaches it.
    for root, dirs, files in os.walk(files_dir, topdown=False):
        for name in files:
            try:
                os.remove(os.path.join(root, name))
            except OSError:
                # Best-effort: skip entries that are locked or already gone.
                pass
        for name in dirs:
            try:
                os.rmdir(os.path.join(root, name))
            except OSError:
                # Best-effort: a directory that still has content stays.
                pass

def ensure_dirs():
    """Create the ``data/files`` and ``exports`` directories if they are missing.

    Both paths are relative to the current working directory; directories
    that already exist are left alone.
    """
    for directory in ("data/files", "exports"):
        os.makedirs(directory, exist_ok=True)
