#!/usr/bin/env python3.12
"""
Lightweight similarity utilities: prefer scikit-learn TF-IDF if available, otherwise use pure Python
token cosine.
"""
from __future__ import annotations

import argparse
import math
import pathlib
import re
from collections import Counter
from typing import Counter as CounterType, List, Tuple

try:
    from sklearn.feature_extraction.text import TfidfVectorizer  # type: ignore
    from sklearn.metrics.pairwise import cosine_similarity  # type: ignore

    _HAVE_SKLEARN = True
except Exception:  # noqa: BLE001
    _HAVE_SKLEARN = False


def tokenize(text: str) -> CounterType[str]:
    """Count alphanumeric tokens in *text*, matched case-insensitively via lower()."""
    matches = re.finditer(r"[a-zA-Z0-9]+", text.lower())
    return Counter(m.group(0) for m in matches)


def cosine_counter(a: CounterType[str], b: CounterType[str]) -> float:
    """Cosine similarity between two token-count vectors.

    Returns 0.0 when either counter is empty or has zero magnitude.
    """
    if not a or not b:
        return 0.0
    # Dot product only over tokens the two vectors share.
    dot = sum(count * b[token] for token, count in a.items() if token in b)
    norm_a = math.sqrt(sum(c * c for c in a.values()))
    norm_b = math.sqrt(sum(c * c for c in b.values()))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def score_similarity(texts: List[str]) -> List[Tuple[int, int, float]]:
    """Return pairwise similarity scores (i, j, score)."""
    if len(texts) < 2:
        return []
    if _HAVE_SKLEARN:
        vectorizer = TfidfVectorizer()
        tfidf = vectorizer.fit_transform(texts)
        sim = cosine_similarity(tfidf)
        results: List[Tuple[int, int, float]] = []
        n = sim.shape[0]
        for i in range(n):
            for j in range(i + 1, n):
                results.append((i, j, float(sim[i, j])))
        return results

    tokens = [tokenize(t) for t in texts]
    results: List[Tuple[int, int, float]] = []
    n = len(tokens)
    for i in range(n):
        for j in range(i + 1, n):
            score = cosine_counter(tokens[i], tokens[j])
            results.append((i, j, score))
    return results


def max_similarity(texts: List[str]) -> float:
    """Return the largest pairwise similarity score, or 0.0 with no pairs."""
    scores = [score for _, _, score in score_similarity(texts)]
    if not scores:
        return 0.0
    return max(scores)


def _cli() -> int:
    """Command-line entry point: print the max pairwise similarity of the given files.

    Prints a single status line (OK/WARN/ERROR prefixed) and returns the
    process exit code: 0 on success, 2 when fewer than two readable files
    were supplied.
    """
    parser = argparse.ArgumentParser(description="Compute max similarity across files.")
    parser.add_argument("paths", nargs="*", help="Text files to compare")
    args = parser.parse_args()

    if not args.paths:
        print("ERROR:no_files:no files provided for comparison")
        return 2

    texts: List[str] = []
    errors: List[str] = []
    for raw_path in args.paths:
        candidate = pathlib.Path(raw_path)
        # Unreadable or missing files are recorded and skipped, not fatal.
        if not candidate.is_file():
            errors.append(f"{candidate}:not found")
            continue
        try:
            texts.append(candidate.read_text(encoding="utf-8"))
        except Exception as exc:  # noqa: BLE001
            errors.append(f"{candidate}:{exc}")

    if len(texts) < 2:
        print(f"ERROR:insufficient_files:need at least 2 files, got {len(texts)}")
        return 2

    ms = max_similarity(texts)
    if errors:
        print(f"WARN:read_errors:{len(errors)} files skipped")
    print(f"OK:similarity:{ms:.4f} (from {len(texts)} files)")
    return 0


if __name__ == "__main__":
    raise SystemExit(_cli())
