import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset
from sklearn.datasets import load_breast_cancer, load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


# Project root: the resolved script path, two levels up (the script sits in a
# subdirectory of the root).
ROOT = Path(__file__).resolve().parents[1]
ASSETS = ROOT.joinpath("assets")  # rendered SVG figures
DATA = ROOT.joinpath("data")  # exported CSV tables
# Created at import time so every generator below can write without checks.
ASSETS.mkdir(exist_ok=True)
DATA.mkdir(exist_ok=True)

# Shared figure palette (hex colours).
INK = "#11110f"  # text, spines, ticks
MUTED = "#5f5a50"  # captions and secondary lines
RULE = "#cfc5b5"  # gridlines
BLUE = "#24607a"
RED = "#a83f2d"
GREEN = "#4f7f56"
GOLD = "#a56e1d"
PAPER = "#fbfaf6"  # figure / axes background


def style_axes(ax):
    """Apply the shared paper-and-ink look to a matplotlib Axes, in place.

    Sets the PAPER background and light horizontal RULE gridlines, and hides the
    top/right spines. Spines, ticks, axis labels and the title are coloured INK.
    """
    ax.set_facecolor(PAPER)
    ax.grid(axis="y", color=RULE, linewidth=0.8, alpha=0.6)
    for hidden in ("top", "right"):
        ax.spines[hidden].set_visible(False)
    for inked in ("left", "bottom"):
        ax.spines[inked].set_color(INK)
    ax.tick_params(colors=INK, labelsize=9)
    for text_artist in (ax.xaxis.label, ax.yaxis.label, ax.title):
        text_artist.set_color(INK)


def savefig(fig, name):
    """Save *fig* as ``ASSETS / name`` in SVG format on the PAPER background, then close it.

    The figure is closed even if writing fails, so a failed export does not
    leave it open in pyplot.

    Args:
        fig: matplotlib Figure to export.
        name: output file name relative to ``ASSETS``. The output is always
            SVG (``format="svg"``), whatever the extension.
    """
    try:
        fig.savefig(ASSETS / name, format="svg", bbox_inches="tight", facecolor=PAPER)
    finally:
        plt.close(fig)


def breast_cancer_selective_routing():
    """Chart confidence-based selective routing on sklearn's breast cancer data.

    Fits a standardized logistic regression on a stratified 65/35 train/test
    split (seed 42). Test cases are ranked by their top-class probability.
    Auto-executing the ``k`` most confident ones (``k >= 5``) gives coverage
    ``k / n``; the selective risk is the error rate among those ``k`` cases.

    Side effects:
        Writes ``breast_cancer_risk_coverage.csv`` and
        ``breast_cancer_operating_points.csv`` under ``DATA`` and
        ``risk_coverage_breast_cancer.svg`` under ``ASSETS``.

    Returns:
        dict: dataset size/shape, test-set accuracy of the model on all test
        cases, and ``best_1pct_coverage``. That is the largest coverage whose
        observed risk is at most 1% (NaN if no point qualifies).
    """
    dataset = load_breast_cancer()
    x_train, x_test, y_train, y_test = train_test_split(
        dataset.data,
        dataset.target,
        test_size=0.35,
        random_state=42,
        stratify=dataset.target,
    )
    model = make_pipeline(
        StandardScaler(),
        LogisticRegression(max_iter=5000, random_state=42),
    )
    model.fit(x_train, y_train)

    probs = model.predict_proba(x_test)
    conf = probs.max(axis=1)
    pred = probs.argmax(axis=1)
    n = len(y_test)

    # Most confident first: auto-executing the first k ranks gives coverage k / n.
    order = np.argsort(-conf)
    # Selective risk = errors among the k selected / k. Counting errors directly,
    # instead of using 1 - accuracy, keeps exact ratios exact.
    # 1 - 0.99 evaluates to 0.010000000000000009, which would wrongly fail a
    # `<= 0.01` target when exactly 1 in 100 auto-executed cases is wrong.
    cum_errors = np.cumsum(pred[order] != y_test[order])
    ks = np.arange(5, n + 1)  # skip the tiniest, noisiest prefixes
    coverage = ks / n
    curve = pd.DataFrame(
        {
            "dataset": "Wisconsin Diagnostic Breast Cancer",
            "coverage": coverage,
            "review_rate": 1 - coverage,
            "selective_risk": cum_errors[ks - 1] / ks,
            "selected_cases": ks,
            "test_cases": n,
        }
    )
    curve.to_csv(DATA / "breast_cancer_risk_coverage.csv", index=False)

    # For each risk target, keep the largest-coverage point whose observed risk
    # meets it (curve rows are in increasing coverage order).
    operating = []
    for target in [0.005, 0.01, 0.02, 0.05]:
        valid = curve[curve["selective_risk"] <= target]
        if len(valid):
            row = valid.iloc[-1].to_dict()
            row["target_risk"] = target
            operating.append(row)
    # Explicit columns keep the CSV header, and the 1% lookup below, valid even
    # when no target is reachable.
    operating_df = pd.DataFrame(operating, columns=[*curve.columns, "target_risk"])
    operating_df.to_csv(DATA / "breast_cancer_operating_points.csv", index=False)

    fig, ax = plt.subplots(figsize=(7.2, 4.2))
    ax.plot(curve["coverage"] * 100, curve["selective_risk"] * 100, color=BLUE, linewidth=2.6)
    for target, color in [(1, RED), (2, GOLD), (5, GREEN)]:
        ax.axhline(target, color=color, linewidth=1.0, linestyle=(0, (4, 4)), alpha=0.85)
        # Compare in fractional units (target is in percent) to avoid rounding from * 100.
        valid = curve[curve["selective_risk"] <= target / 100]
        if len(valid):
            p = valid.iloc[-1]
            ax.scatter([p["coverage"] * 100], [p["selective_risk"] * 100], color=color, s=42, zorder=3)
            ax.text(
                p["coverage"] * 100 - 2.5,
                target + 0.25,
                f"{target}% risk: {p['coverage'] * 100:.1f}% auto",
                color=color,
                fontsize=8.5,
                ha="right",
            )
    ax.set_title("Selective routing on a real medical tabular dataset", loc="left", fontsize=12)
    ax.set_xlabel("Coverage / auto-execution rate (%)")
    ax.set_ylabel("Observed error on auto-executed cases (%)")
    ax.set_xlim(0, 101)
    ax.set_ylim(0, max(8, curve["selective_risk"].max() * 105))
    style_axes(ax)
    ax.text(
        0,
        -0.22,
        f"Dataset: sklearn Wisconsin Diagnostic Breast Cancer, {len(dataset.data)} cases. Model: logistic regression. Test cases: {n}.",
        transform=ax.transAxes,
        fontsize=8,
        color=MUTED,
    )
    savefig(fig, "risk_coverage_breast_cancer.svg")

    one_pct = operating_df.loc[operating_df["target_risk"].eq(0.01), "coverage"]
    return {
        "dataset": "Wisconsin Diagnostic Breast Cancer",
        "samples": len(dataset.data),
        "features": dataset.data.shape[1],
        "classes": len(dataset.target_names),
        "test_cases": n,
        "baseline_accuracy": accuracy_score(y_test, pred),
        "best_1pct_coverage": float(one_pct.iloc[0]) if len(one_pct) else np.nan,
    }


def digits_conformal_sets():
    """Measure split-conformal prediction sets as a review budget on sklearn Digits.

    The data is split into train / calibration / test parts (stratified, seed 7),
    and a standardized logistic regression is fit on the training part. The
    nonconformity score of an (image, label) pair is ``1 - p(label)``.

    For each miscoverage level ``alpha`` from 1% to 25%, the calibration scores
    fix a threshold. Every test label whose score is within that threshold
    enters the prediction set. Singleton sets count as auto-executed; empty
    and multi-label sets count as routed to review.

    Side effects:
        Writes ``digits_conformal_sets.csv`` under ``DATA`` and
        ``conformal_digits_review_budget.svg`` / ``conformal_digits_coverage.svg``
        under ``ASSETS``.

    Returns:
        dict: dataset summary plus the review rate and empirical coverage at the
        alpha closest to 10%.
    """
    dataset = load_digits()
    x_train_full, x_test, y_train_full, y_test = train_test_split(
        dataset.data,
        dataset.target,
        test_size=0.25,
        random_state=7,
        stratify=dataset.target,
    )
    x_train, x_cal, y_train, y_cal = train_test_split(
        x_train_full,
        y_train_full,
        test_size=0.30,
        random_state=7,
        stratify=y_train_full,
    )
    model = make_pipeline(
        StandardScaler(),
        LogisticRegression(max_iter=5000, random_state=7),
    )
    model.fit(x_train, y_train)

    cal_probs = model.predict_proba(x_cal)
    # Sorted calibration scores, so the conformal order statistic is a plain index.
    cal_scores = np.sort(1 - cal_probs[np.arange(len(y_cal)), y_cal])
    n_cal = len(cal_scores)
    test_probs = model.predict_proba(x_test)
    # Score every candidate label with the same formula as calibration.
    # `p >= 1 - threshold` can disagree with `1 - p <= threshold` after rounding.
    test_scores = 1 - test_probs
    # A singleton set can only hold the most probable label (it has the lowest score).
    top_label = test_probs.argmax(axis=1)
    test_rows = np.arange(len(y_test))

    rows = []
    for alpha in np.linspace(0.01, 0.25, 25):
        # Split conformal threshold: the k-th smallest calibration score, with
        # k = ceil((n + 1)(1 - alpha)). Indexing it directly avoids
        # np.quantile(scores, k / n, method="higher"). Its virtual index,
        # k * (n - 1) / n, rounds up to the (k + 1)-th score and over-covers.
        # The 1e-9 slack stops float noise from bumping an exact integer k by one.
        k = int(np.ceil((n_cal + 1) * (1 - alpha) - 1e-9))
        # If k > n, no calibration score certifies 1 - alpha, so include every label.
        threshold = cal_scores[k - 1] if k <= n_cal else np.inf
        prediction_sets = test_scores <= threshold
        set_sizes = prediction_sets.sum(axis=1)
        covered = prediction_sets[test_rows, y_test]
        singleton = set_sizes == 1
        auto_error = (
            (top_label[singleton] != y_test[singleton]).mean()
            if singleton.any()
            else np.nan
        )
        rows.append(
            {
                "dataset": "sklearn Digits",
                "alpha": alpha,
                "target_coverage": 1 - alpha,
                "empirical_coverage": covered.mean(),
                "singleton_auto_rate": singleton.mean(),
                "review_rate": 1 - singleton.mean(),
                "mean_set_size": set_sizes.mean(),
                "auto_error_singletons": auto_error,
                "calibration_cases": len(y_cal),
                "test_cases": len(y_test),
            }
        )
    result = pd.DataFrame(rows)
    result.to_csv(DATA / "digits_conformal_sets.csv", index=False)

    fig, ax1 = plt.subplots(figsize=(7.2, 4.2))
    ax1.plot(result["alpha"] * 100, result["review_rate"] * 100, color=RED, linewidth=2.4, label="review rate")
    ax1.plot(
        result["alpha"] * 100,
        result["auto_error_singletons"] * 100,
        color=BLUE,
        linewidth=2.4,
        label="error on singleton auto-executions",
    )
    ax1.set_title("Conformal prediction sets create a measurable review budget", loc="left", fontsize=12)
    ax1.set_xlabel("Allowed miscoverage alpha (%)")
    ax1.set_ylabel("Rate (%)")
    ax1.set_xlim(1, 25)
    ax1.set_ylim(0, max(55, result["review_rate"].max() * 105))
    style_axes(ax1)
    ax1.legend(frameon=False, loc="upper right", fontsize=8.5)
    ax1.text(
        0,
        -0.22,
        f"Dataset: sklearn Digits, {len(dataset.data)} images. Calibration cases: {len(y_cal)}. Test cases: {len(y_test)}.",
        transform=ax1.transAxes,
        fontsize=8,
        color=MUTED,
    )
    savefig(fig, "conformal_digits_review_budget.svg")

    fig, ax2 = plt.subplots(figsize=(7.2, 4.2))
    ax2.plot(result["target_coverage"] * 100, result["empirical_coverage"] * 100, color=GREEN, linewidth=2.6)
    ax2.plot([75, 100], [75, 100], color=MUTED, linewidth=1.0, linestyle=(0, (4, 4)))
    ax2.set_title("Empirical coverage tracks the conformal target", loc="left", fontsize=12)
    ax2.set_xlabel("Target coverage (%)")
    ax2.set_ylabel("Observed test coverage (%)")
    ax2.set_xlim(74, 100)
    ax2.set_ylim(74, 100)
    style_axes(ax2)
    ax2.text(87, 84, "diagonal = perfect calibration", color=MUTED, fontsize=8.5)
    ax2.text(
        0,
        -0.22,
        "Prediction sets are generated from held-out calibration scores; coverage is measured on a separate test split.",
        transform=ax2.transAxes,
        fontsize=8,
        color=MUTED,
    )
    savefig(fig, "conformal_digits_coverage.svg")

    # Row whose alpha is nearest to 10% (the grid steps by ~1%).
    chosen = result.loc[(result["alpha"] - 0.10).abs().idxmin()]
    return {
        "dataset": "sklearn Digits",
        "samples": len(dataset.data),
        "features": dataset.data.shape[1],
        "classes": len(dataset.target_names),
        "calibration_cases": len(y_cal),
        "test_cases": len(y_test),
        "alpha_10_review_rate": chosen["review_rate"],
        "alpha_10_empirical_coverage": chosen["empirical_coverage"],
    }


def write_summary(rows):
    """Write one summary row per dataset to ``DATA / "dataset_summary.csv"``.

    Args:
        rows: the summary dicts returned by the dataset generators; their keys
            become CSV columns. The DataFrame index is not written.
    """
    summary = pd.DataFrame(rows)
    summary.to_csv(DATA / "dataset_summary.csv", index=False)


def _count_text_lines(texts):
    r"""Return the number of lines in each string of the pandas Series *texts*.

    Missing values are treated as empty strings. A trailing newline ends the
    last line instead of starting a new one, so ``"a\nb\n"`` and ``"a\nb"``
    both count as 2 lines and ``""`` counts as 0.
    """
    texts = texts.fillna("")
    unterminated = texts.str.len().gt(0) & ~texts.str.endswith("\n")
    return texts.str.count("\n") + unterminated.astype(int)


def _count_test_ids(value):
    """Return how many test IDs one FAIL_TO_PASS entry lists.

    NOTE(review): entries appear to be JSON-encoded lists of test IDs stored as
    strings; confirm against the dataset card. Decoding them avoids miscounting
    IDs that themselves contain commas (e.g. parametrized tests). Already-decoded
    sequences are measured directly; missing or blank values count as 0, and
    non-JSON text falls back to a comma-count heuristic.
    """
    if isinstance(value, (list, tuple, np.ndarray)):
        return len(value)
    if not isinstance(value, str) or not value.strip():
        return 0
    try:
        parsed = json.loads(value)
    except json.JSONDecodeError:
        return value.count(",") + int(len(value) > 2)
    return len(parsed) if isinstance(parsed, list) else 1


def swebench_verified_distribution():
    """Profile the SWE-bench Verified test split and chart its composition.

    Loads ``princeton-nlp/SWE-bench_Verified`` (``test`` split) with
    ``datasets.load_dataset`` and derives per-task metrics:
    ``problem_statement`` length in characters, line counts of the ``patch``
    and ``test_patch`` fields, and the number of FAIL_TO_PASS test IDs.

    Side effects:
        Writes ``swebench_verified_repo_counts.csv``,
        ``swebench_verified_difficulty_counts.csv`` and
        ``swebench_verified_task_metrics.csv`` under ``DATA``, and
        ``swebench_verified_difficulty.svg`` / ``swebench_verified_repos.svg``
        under ``ASSETS``.

    Returns:
        dict: task count, the number of exported metric columns (``features``),
        the number of distinct repos (``classes``), the median problem length,
        the median patch line count, and the repo with the most tasks.
    """
    ds = load_dataset("princeton-nlp/SWE-bench_Verified", split="test")
    df = pd.DataFrame(ds)
    df["problem_chars"] = df["problem_statement"].str.len()
    df["patch_lines"] = _count_text_lines(df["patch"])
    df["test_patch_lines"] = _count_text_lines(df["test_patch"])
    df["fail_to_pass_count"] = df["FAIL_TO_PASS"].map(_count_test_ids).astype(int)

    repo_counts = df["repo"].value_counts().reset_index()
    repo_counts.columns = ["repo", "tasks"]
    difficulty_counts = df["difficulty"].fillna("unknown").value_counts().reset_index()
    difficulty_counts.columns = ["difficulty", "tasks"]

    metric_columns = [
        "repo",
        "instance_id",
        "difficulty",
        "problem_chars",
        "patch_lines",
        "test_patch_lines",
        "fail_to_pass_count",
    ]
    repo_counts.to_csv(DATA / "swebench_verified_repo_counts.csv", index=False)
    difficulty_counts.to_csv(DATA / "swebench_verified_difficulty_counts.csv", index=False)
    df[metric_columns].to_csv(DATA / "swebench_verified_task_metrics.csv", index=False)

    # Known labels in increasing effort, each with a fixed colour. Colours are
    # looked up by label so a missing category cannot shift the others. Any
    # unexpected label is appended in MUTED rather than silently dropped.
    difficulty_colors = {
        "<15 min fix": GREEN,
        "15 min - 1 hour": BLUE,
        "1-4 hours": GOLD,
        ">4 hours": RED,
        "unknown": MUTED,
    }
    present = set(difficulty_counts["difficulty"])
    labels = [label for label in difficulty_colors if label in present]
    labels += [label for label in difficulty_counts["difficulty"] if label not in difficulty_colors]
    difficulty_plot = difficulty_counts.set_index("difficulty").loc[labels].reset_index()

    fig, ax = plt.subplots(figsize=(7.2, 4.2))
    bars = ax.bar(
        difficulty_plot["difficulty"],
        difficulty_plot["tasks"],
        color=[difficulty_colors.get(label, MUTED) for label in difficulty_plot["difficulty"]],
    )
    ax.set_title("SWE-bench Verified task difficulty distribution", loc="left", fontsize=12)
    ax.set_xlabel("Human-estimated task difficulty")
    ax.set_ylabel("Tasks")
    style_axes(ax)
    ax.tick_params(axis="x", rotation=18)
    for bar in bars:
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 3,
            f"{int(bar.get_height())}",
            ha="center",
            va="bottom",
            fontsize=8.5,
            color=INK,
        )
    ax.text(
        0,
        -0.28,
        f"Dataset: princeton-nlp/SWE-bench_Verified test split, {len(df)} real GitHub issue-fix tasks.",
        transform=ax.transAxes,
        fontsize=8,
        color=MUTED,
    )
    savefig(fig, "swebench_verified_difficulty.svg")

    top_repos = repo_counts.head(12).sort_values("tasks")
    fig, ax = plt.subplots(figsize=(7.2, 4.8))
    ax.barh(top_repos["repo"], top_repos["tasks"], color=BLUE)
    ax.set_title("SWE-bench Verified is concentrated in real OSS repos", loc="left", fontsize=12)
    ax.set_xlabel("Tasks in benchmark")
    ax.set_ylabel("")
    style_axes(ax)
    for i, value in enumerate(top_repos["tasks"]):
        ax.text(value + 0.8, i, str(int(value)), va="center", fontsize=8.5, color=INK)
    ax.text(
        0,
        -0.18,
        "Top repositories by task count. Use this as benchmark context, not a customer workload guarantee.",
        transform=ax.transAxes,
        fontsize=8,
        color=MUTED,
    )
    savefig(fig, "swebench_verified_repos.svg")

    return {
        "dataset": "SWE-bench Verified",
        "samples": len(df),
        "features": len(metric_columns),
        "classes": int(df["repo"].nunique()),
        "test_cases": len(df),
        "median_problem_chars": float(df["problem_chars"].median()),
        "median_patch_lines": float(df["patch_lines"].median()),
        "top_repo": str(repo_counts.iloc[0]["repo"]),
    }


if __name__ == "__main__":
    # Each generator writes its own CSVs/SVGs and returns one summary row.
    write_summary(
        [
            breast_cancer_selective_routing(),
            digits_conformal_sets(),
            swebench_verified_distribution(),
        ]
    )
    print("wrote:")
    outputs = sorted([*ASSETS.glob("*.svg"), *DATA.glob("*.csv")])
    for output in outputs:
        print(output.relative_to(ROOT))
