# -*- coding: utf-8 -*-
"""MI-LSTM 多輸入股價預測:台股費後實測(論文 Li, Shen & Zhu 2018 的落地驗證)。

對應文章:
  https://finlab.finance/blog/mi-lstm-stock-price-prediction-paper

執行:
  cd ~/Documents/finlab && uv run --with torch --with 'numpy<2' --with matplotlib \
    python /tmp/seo-merge/mi-lstm/run_backtest.py
  (本機 venv 為 x86_64,torch 只有 2.2.2 wheel,須搭配 numpy<2)

輸出:
  - /tmp/seo-merge/mi-lstm/metrics.json
  - /tmp/seo-merge/mi-lstm/equity.csv
  - /tmp/seo-merge/mi-lstm/report_strategy.html
  - /tmp/seo-merge/mi-lstm/*.png(五張 16:9 圖)

設計(與原論文的對應與差異,文中 §G 全文揭露):
  - 四類輸入流:mainstream(自身週報酬)、positive(滾動相關係數最高 10 檔平均)、
    negative(最負相關 10 檔平均)、index(加權股價報酬指數週報酬)
  - 週頻(原文為日頻;本文驗證作者 2020 年提出的「拉長到週頻降低手續費」假說)
  - 每年初重新訓練(walk-forward),訓練窗 = 前 156 週,相關性窗 = 前 104 週,
    特徵窗 = 12 週;預測全程樣本外(out-of-sample)
  - 股票池:每年依前 52 週平均週成交值取前 300 檔(4 碼純股票代號,排除 ETF/權證)
  - 對照組:單輸入 LSTM(只用 mainstream),同流程同參數
  - 3 個隨機種子,集成平均當主結果,各 seed 結果留作穩健性
  - 策略成本:finlab sim() 台股預設(手續費 0.1425%、賣出證交稅 0.3%)
  - 基準:0050 含息 etl:adj_close 純指數算術 buy-and-hold(全站 canonical 口徑,不經 sim)

投資警語:本程式僅供量化研究與教學用途,過去績效不代表未來表現,
不構成任何投資建議;實際交易前請自行評估風險、滑價與交易容量。
"""
from __future__ import annotations

import json
import os
import re
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from numpy.lib.stride_tricks import sliding_window_view

from finlab import data
from finlab.backtest import sim

warnings.filterwarnings("ignore")
# Leave two cores free for the OS but always use at least one thread.
# os.cpu_count() returns None when the count cannot be determined; fall back to 1
# so the subtraction below cannot raise TypeError.
torch.set_num_threads(max(1, (os.cpu_count() or 1) - 2))

START = "2018-01-01"
# Data cut-off pinned to the site-wide canonical 0050 snapshot date
# (internal reference §D, 2026-06-09); override with the DATA_END env variable.
END = os.environ.get("DATA_END", "2026-06-09")
OUT = Path("/tmp/seo-merge/mi-lstm")
OUT.mkdir(parents=True, exist_ok=True)

LOOKBACK = 12        # feature window (weeks)
TRAIN_WEEKS = 156    # training window (weeks)
CORR_WEEKS = 104     # correlation estimation window (weeks)
K_NEIGHBOR = 10      # number of most-positive / most-negative neighbours each
POOL_SIZE = 300      # yearly stock-pool size (matches the paper's 300 stocks)
TOP_N = 20           # stocks held per week (paper selects 20 per day)
HIDDEN = 32
EPOCHS = 12
BATCH = 2048
LR = 1e-3
SEEDS = [0, 1, 2]
CLIP = 5.0           # clip bound for z-scaled features / targets


# ---------- 資料 ----------
def cap(df):
    """Return only the rows of ``df`` whose index is on or before the cut-off END."""
    within_cutoff = df.index <= END
    return df[within_cutoff]

# Load finlab datasets and truncate each to the END cut-off.
# Per the module docstring, etl:adj_close is the dividend-adjusted close used for
# both stock returns and the 0050 benchmark.
adj = cap(data.get("etl:adj_close"))
close = cap(data.get("price:收盤價"))
vol = cap(data.get("price:成交股數"))
twii = cap(data.get("benchmark_return:發行量加權股價報酬指數"))

# Keep only 4-digit numeric codes that do not start with "00"
# (excludes ETFs / beneficiary certificates / warrants).
stock_cols = [c for c in adj.columns if re.fullmatch(r"\d{4}", c) and not c.startswith("00")]
adj_stk = adj[stock_cols]

# Weekly frequency, weeks labelled by Friday (W-FRI).
wprice = adj_stk.resample("W-FRI").last()
wret = wprice.pct_change()
# Weekly traded value = sum over the week of (close * shares traded); drives pool selection.
dollar_vol = (close[stock_cols] * vol[stock_cols]).resample("W-FRI").sum()
# Weekly return of the index series (first column), used as the "index" input stream.
idx_w = twii.iloc[:, 0].resample("W-FRI").last().pct_change()

bench_adj = adj["0050"]
# Free the large daily frames; only the weekly derivatives and 0050 are needed below.
del adj, close, vol, adj_stk, twii

weeks = wret.index

# ---------- 模型 ----------
class MILSTM(nn.Module):
    """Simplified MI-LSTM (Li, Shen & Zhu 2018).

    The mainstream input drives the standard LSTM input/forget/output gates.
    The positive / negative / index streams each get their own input gate.
    The four gated candidate cell updates are fused into the cell state with
    an attention weighting.
    """

    def __init__(self, hidden: int = HIDDEN):
        super().__init__()
        self.hidden = hidden
        self.gates = nn.Linear(hidden + 1, 3 * hidden)   # i, f, o gates, computed from [h, y]
        # One candidate-update layer per stream (index 0 = mainstream, 1..3 = p, n, idx).
        self.cand = nn.ModuleList([nn.Linear(hidden + 1, hidden) for _ in range(4)])
        # Independent input gates for the three auxiliary streams (p, n, idx).
        self.aux_gate = nn.ModuleList([nn.Linear(hidden + 1, hidden) for _ in range(3)])
        # Attention score per stream: attn_v(tanh(attn_w(u))).
        self.attn_w = nn.Linear(hidden, hidden)
        self.attn_v = nn.Linear(hidden, 1, bias=False)
        self.head = nn.Linear(hidden, 1)

    def forward(self, x):                      # x: (B, T, 4) = [y, p, n, idx]
        """Run the cell over all T steps.

        Returns (prediction of shape (B,), per-sample attention weights over the
        four streams averaged across the T steps, shape (B, 4)).
        """
        B, T, _ = x.shape
        h = x.new_zeros(B, self.hidden)
        c = x.new_zeros(B, self.hidden)
        attn_sum = x.new_zeros(B, 4)
        for t in range(T):
            y = x[:, t, 0:1]
            streams = [x[:, t, s:s + 1] for s in range(4)]
            g = self.gates(torch.cat([h, y], dim=1))
            i_g, f_g, o_g = torch.sigmoid(g).chunk(3, dim=1)
            cands = [torch.tanh(self.cand[s](torch.cat([h, streams[s]], dim=1)))
                     for s in range(4)]
            u = [i_g * cands[0]]                                  # mainstream, gated by the main input gate
            for s in range(1, 4):                                 # auxiliary gates for p / n / idx
                a_g = torch.sigmoid(self.aux_gate[s - 1](torch.cat([h, streams[s]], dim=1)))
                u.append(a_g * cands[s])
            stack = torch.stack(u, dim=1)                         # (B, 4, H)
            score = self.attn_v(torch.tanh(self.attn_w(stack)))   # (B, 4, 1)
            alpha = torch.softmax(score, dim=1)                   # weights sum to 1 over the 4 streams
            merged = (alpha * stack).sum(dim=1)
            c = f_g * c + merged
            h = o_g * torch.tanh(c)
            attn_sum = attn_sum + alpha.squeeze(-1)
        return self.head(h).squeeze(-1), attn_sum / T


class SingleLSTM(nn.Module):
    """Baseline: a plain single-layer LSTM that sees only the mainstream series."""

    def __init__(self, hidden: int = HIDDEN):
        super().__init__()
        # Keep the construction order (lstm, then head): it fixes how parameter
        # initialisation consumes the torch RNG, and the runs are seeded.
        self.lstm = nn.LSTM(input_size=1, hidden_size=hidden, batch_first=True)
        self.head = nn.Linear(hidden, 1)

    def forward(self, x):                      # x: (B, T, 1)
        """Return (prediction of shape (B,), None); None fills MILSTM's attention slot."""
        sequence_out = self.lstm(x)[0]         # (B, T, hidden)
        last_step = sequence_out[:, -1, :]
        prediction = self.head(last_step).squeeze(-1)
        return prediction, None


def train_model(model, X, y, seed):
    """Seed the RNGs, re-initialise every nn.Linear, then fit ``model`` with Adam on MSE.

    X and y are array-likes that are converted to float32 tensors. Training runs
    EPOCHS passes of shuffled mini-batches of size BATCH. The model is returned
    after being trained in place.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    # Xavier-uniform weights and zero biases for all Linear layers. Other layers
    # (e.g. nn.LSTM) keep their constructor initialisation.
    linear_layers = (m for m in model.modules() if isinstance(m, nn.Linear))
    for lin in linear_layers:
        nn.init.xavier_uniform_(lin.weight)
        if lin.bias is not None:
            nn.init.zeros_(lin.bias)

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = nn.MSELoss()
    features = torch.tensor(X, dtype=torch.float32)
    targets = torch.tensor(y, dtype=torch.float32)
    n_samples = len(features)
    batch_starts = range(0, n_samples, BATCH)

    for _epoch in range(EPOCHS):
        order = torch.randperm(n_samples)
        for start in batch_starts:
            batch = order[start:start + BATCH]
            prediction, _ = model(features[batch])
            loss = criterion(prediction, targets[batch])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return model


# ---------- 每年 walk-forward ----------
def build_streams(pool, nb_pos, nb_neg, week_slice):
    """Assemble the four MI-LSTM input streams for ``pool`` over ``week_slice``.

    Returns ``(streams, raw)``:
      - streams: (W, N, 4) array [own return, mean of positive neighbours,
        mean of negative neighbours, index return]. Missing weekly returns are
        treated as 0, including inside the neighbour averages.
      - raw: (W, N) own weekly returns with NaN left in place; callers use it
        for target values and availability masks.
    ``nb_pos`` / ``nb_neg`` are (N, K) integer arrays of neighbour positions in ``pool``.
    """
    raw = wret.loc[week_slice, pool].values
    own = np.nan_to_num(raw, nan=0.0)
    # NaN -> 0 before gathering neighbours, so a missing neighbour counts as a zero return.
    zeroed = np.where(np.isnan(raw), 0.0, raw)
    pos_avg = zeroed[:, nb_pos].mean(axis=2)          # (W, N, K) -> (W, N)
    neg_avg = zeroed[:, nb_neg].mean(axis=2)
    market = idx_w.loc[week_slice].fillna(0.0).values
    market = np.repeat(market[:, None], len(pool), axis=1)   # same index return for every stock
    return np.stack([own, pos_avg, neg_avg, market], axis=2), raw


def make_samples(streams, raw, stds, with_target=True):
    """Cut a (W, N, C) stream tensor into LOOKBACK-week sliding-window samples.

    Parameters
    ----------
    streams : (W, N, C) array of NaN-filled stream values (from build_streams).
    raw : (W, N) array of each stock's own unfilled weekly returns. NaN marks
        weeks without data and is used to drop samples.
    stds : array broadcastable to ``streams`` (here (1, 1, C)) used to z-scale
        the features. ``stds[0, 0, 0]`` also scales the target.
    with_target : if True, the window ending at week t is paired with the
        week t+1 return, and windows without a valid target are dropped. If
        False, every window whose last week has data is kept (prediction mode).

    Returns
    -------
    X : (S, LOOKBACK, C) scaled features clipped to [-CLIP, CLIP].
    y : (S,) scaled, clipped targets, or None when ``with_target`` is False.
    idx : (S, 2) integer array of (week_pos, stock_pos). week_pos is the
        window's last week, counted within ``streams``.

    If no window qualifies (including W < LOOKBACK), empty arrays with the
    shapes above are returned instead of raising from ``np.concatenate``.
    """
    Wn, _, C = streams.shape
    z = np.clip(streams / stds, -CLIP, CLIP)
    Xs, ys, idxs = [], [], []
    if Wn >= LOOKBACK:                                    # sliding_window_view raises otherwise
        # (W-LB+1, N, C, LB) -> (W-LB+1, N, LB, C)
        win = sliding_window_view(z, LOOKBACK, axis=0).transpose(0, 1, 3, 2)
        # With a target, the last usable window must leave one week after it.
        t_end = Wn - LOOKBACK if with_target else Wn - LOOKBACK + 1
        for w in range(t_end):
            last = w + LOOKBACK - 1                       # window's final week
            if with_target:
                tgt = raw[last + 1]                       # next week's return
                ok = ~np.isnan(tgt) & ~np.isnan(raw[last])
            else:
                ok = ~np.isnan(raw[last])
            if not ok.any():
                continue
            Xs.append(win[w][ok])
            if with_target:
                ys.append(np.clip(tgt[ok] / stds[0, 0, 0], -CLIP, CLIP))
            idxs.append(np.stack([np.full(ok.sum(), last), np.where(ok)[0]], axis=1))

    if not Xs:
        X = np.empty((0, LOOKBACK, C), dtype=z.dtype)
        y = np.empty(0, dtype=z.dtype) if with_target else None
        return X, y, np.empty((0, 2), dtype=int)
    X = np.concatenate(Xs)
    y = np.concatenate(ys) if with_target else None
    return X, y, np.concatenate(idxs)


# Walk-forward over calendar years. For each year: pick the pool and neighbours
# from data before the year's first week, train on the preceding TRAIN_WEEKS
# weeks, then score every week of that year out-of-sample.
# NOTE(review): the year range is hard-coded to 2018-2026, not derived from START/END.
years = list(range(2018, 2027))
pred_records = {"mi": [], "lstm": []}          # (date, stock, seed, score)
mse_rows = []
attn_rows = []
corr_example = None

for year in years:
    pred_weeks = weeks[(weeks.year == year) & (weeks <= pd.Timestamp(END))]
    if len(pred_weeks) == 0:
        continue
    first_pred = pred_weeks[0]
    cut = weeks.get_loc(first_pred)            # position of the first prediction week
    train_idx = weeks[cut - TRAIN_WEEKS:cut]   # training window (entirely before the prediction period)

    # --- Stock pool: top POOL_SIZE by mean weekly traded value over the prior 52 weeks,
    #     restricted to stocks with >= 90% non-missing weekly returns in the training window ---
    lookwin = weeks[cut - 52:cut]
    avg_dv = dollar_vol.loc[lookwin].mean()
    coverage = wret.loc[train_idx].notna().mean()
    eligible = coverage[coverage >= 0.9].index
    pool = avg_dv[eligible].nlargest(POOL_SIZE).index.tolist()

    # --- Correlation neighbours from the CORR_WEEKS weeks before the prediction period
    #     (no look-ahead). Missing returns are treated as 0 and undefined correlations
    #     become 0. The diagonal is zeroed so a stock's self-correlation of 1 does not
    #     rank it as its own top neighbour. Neighbour sets stay fixed for the year. ---
    corr_win = wret.loc[weeks[cut - CORR_WEEKS:cut], pool]
    cmat = np.corrcoef(np.nan_to_num(corr_win.values, nan=0.0).T)
    cmat = np.nan_to_num(cmat, nan=0.0)
    np.fill_diagonal(cmat, 0.0)
    # Per stock: the K_NEIGHBOR most positively / most negatively correlated pool members.
    nb_pos = np.argsort(-cmat, axis=1)[:, :K_NEIGHBOR]
    nb_neg = np.argsort(cmat, axis=1)[:, :K_NEIGHBOR]

    if year == years[-1] and "2330" in pool:   # chart 3 material: TSMC (2330)'s multi-input pool
        j = pool.index("2330")
        corr_example = {
            "stock": "2330",
            "asof": str(weeks[cut - 1].date()),
            "positive": [(pool[k], round(float(cmat[j, k]), 3)) for k in nb_pos[j]],
            "negative": [(pool[k], round(float(cmat[j, k]), 3)) for k in nb_neg[j]],
        }

    # --- Training samples. The z-scaling stds come from the training streams only. ---
    train_streams, train_raw = build_streams(pool, nb_pos, nb_neg, train_idx)
    stds = np.nanstd(train_streams, axis=(0, 1)).reshape(1, 1, 4)
    stds[stds == 0] = 1.0
    Xtr, ytr, _ = make_samples(train_streams, train_raw, stds, with_target=True)

    # --- Prediction samples (prepend LOOKBACK-1 weeks so windows span the year boundary).
    #     Only windows ending inside this year's prediction weeks are kept. ---
    pred_slice = weeks[cut - LOOKBACK + 1: weeks.get_loc(pred_weeks[-1]) + 1]
    pred_streams, pred_raw = build_streams(pool, nb_pos, nb_neg, pred_slice)
    Xpd, _, idx_pd = make_samples(pred_streams, pred_raw, stds, with_target=False)
    pred_dates = pred_slice[idx_pd[:, 0]]
    keep = np.isin(pred_dates, pred_weeks)
    Xpd, idx_pd, pred_dates = Xpd[keep], idx_pd[keep], pred_dates[keep]

    # For test MSE: the realised return of the week after each prediction week,
    # scaled like the training target. The year's last week has no following row
    # in pred_raw, so its target stays NaN and is excluded below.
    next_ret = np.full(len(idx_pd), np.nan)
    for r, (wpos, spos) in enumerate(idx_pd):
        if wpos + 1 < len(pred_raw):
            next_ret[r] = pred_raw[wpos + 1, spos]
    target_z = np.clip(next_ret / stds[0, 0, 0], -CLIP, CLIP)

    for seed in SEEDS:
        # MI-LSTM (seed set before construction so the whole run is reproducible)
        torch.manual_seed(seed)
        mi = train_model(MILSTM(), Xtr, ytr, seed)
        mi.eval()
        with torch.no_grad():
            score_mi, attn = mi(torch.tensor(Xpd, dtype=torch.float32))
        score_mi = score_mi.numpy()
        # Mean attention weight per stream over all of this year's predictions, for this seed.
        attn_rows.append({"year": year, "seed": seed,
                          "alpha": attn.mean(dim=0).numpy().round(4).tolist()})
        ok = ~np.isnan(target_z)
        mse_mi = float(np.mean((score_mi[ok] - target_z[ok]) ** 2))

        # Single-input LSTM baseline: same samples, mainstream channel only.
        torch.manual_seed(seed)
        ls = train_model(SingleLSTM(), Xtr[:, :, 0:1], ytr, seed)
        ls.eval()
        with torch.no_grad():
            score_ls, _ = ls(torch.tensor(Xpd[:, :, 0:1], dtype=torch.float32))
        score_ls = score_ls.numpy()
        mse_ls = float(np.mean((score_ls[ok] - target_z[ok]) ** 2))
        mse_rows.append({"year": year, "seed": seed,
                         "mse_mi": round(mse_mi, 4), "mse_lstm": round(mse_ls, 4)})

        # One (signal date, stock, seed, score) record per prediction, for each model.
        for r in range(len(idx_pd)):
            pred_records["mi"].append((pred_dates[r], pool[idx_pd[r, 1]], seed, score_mi[r]))
            pred_records["lstm"].append((pred_dates[r], pool[idx_pd[r, 1]], seed, score_ls[r]))
    print(f"[{year}] pool={len(pool)} train={len(Xtr)} pred={len(Xpd)} done")


# ---------- 集成 → 持股 ----------
def to_position(records, top_n=TOP_N):
    """Build an equal-weight top-N position table from (date, stock, seed, score) records.

    Scores are first averaged across seeds per (date, stock), giving the ensemble
    score. Each date then holds the ``top_n`` highest-scoring stocks with equal
    weight. Dates where nothing is selected get all zeros.
    """
    frame = pd.DataFrame(records, columns=["date", "stock", "seed", "score"])
    ensemble = frame.groupby(["date", "stock"])["score"].mean().unstack()
    picked = ensemble.rank(axis=1, ascending=False).le(top_n)
    n_picked = picked.sum(axis=1).replace(0, np.nan)
    weights = picked.astype(float).div(n_picked, axis=0)
    return weights.fillna(0)


def to_position_single_seed(records, seed, top_n=TOP_N):
    """Equal-weight top-N position table built from a single seed's scores.

    Filters ``records`` (date, stock, seed, score) to ``seed`` and delegates to
    ``to_position``. The single-seed and ensembled portfolios therefore share
    one ranking/weighting implementation. With one score per (date, stock), the
    seed-average in ``to_position`` is just that score.
    """
    return to_position([rec for rec in records if rec[2] == seed], top_n=top_n)


# Seed-ensembled, equal-weight top-N books for the MI-LSTM and the single-input baseline.
pos_mi, pos_lstm = to_position(pred_records["mi"]), to_position(pred_records["lstm"])


def turnover(pos):
    """Average one-way turnover per rebalance of a weight table (rows = rebalance dates).

    Turnover at a rebalance is half the sum of absolute weight changes. The first
    row has no previous book, so its ``diff()`` is all-NaN. ``min_count=1`` keeps
    that row NaN so ``mean()`` skips it, instead of summing it to a spurious 0 that
    biases the average down. Returns NaN when there are fewer than two rows.
    """
    per_rebalance = pos.diff().abs().sum(axis=1, min_count=1) / 2
    return float(per_rebalance.mean())


def clip_creturn(creturn):
    """Restrict a cumulative-return series to [START, END], both ends inclusive.

    sim()'s creturn runs on to the execution date, so both ends must be trimmed
    to the pinned window before computing any statistics.
    """
    after_start = creturn.index >= START
    before_end = creturn.index <= END
    return creturn[after_start & before_end]


def summarize(name, report):
    """Headline statistics for a sim() report, computed on its clipped creturn.

    Uses plain arithmetic on the [START, END] equity curve, with the same formulas
    as the canonical 0050 benchmark. Sharpe/Sortino are annualised with sqrt(252)
    for daily data and sqrt(12) for month-end data. The Sortino denominator is the
    std of the negative returns. Percentages are rounded for reporting.
    """
    equity = clip_creturn(report.creturn).dropna()
    daily = equity.pct_change().dropna()
    monthly = equity.resample("ME").last().pct_change().dropna()
    growth = float(equity.iloc[-1] / equity.iloc[0] - 1)
    span_years = (equity.index[-1] - equity.index[0]).days / 365.25
    ann_daily = 252 ** 0.5
    ann_monthly = 12 ** 0.5
    downside_daily = daily[daily < 0].std()
    downside_monthly = monthly[monthly < 0].std()
    drawdown = equity / equity.cummax() - 1
    return {
        "name": name,
        "cagr": round(((1 + growth) ** (1 / span_years) - 1) * 100, 2),
        "daily_sharpe": round(float(daily.mean() / daily.std() * ann_daily), 2),
        "daily_sortino": round(float(daily.mean() / downside_daily * ann_daily), 2),
        "monthly_sortino": round(float(monthly.mean() / downside_monthly * ann_monthly), 2),
        "max_drawdown": round(float(drawdown.min()) * 100, 2),
        "total_return": round(growth * 100, 1),
    }


# Backtest each strategy variant with finlab sim() and collect summary statistics.
# Each spec is (label, position table, extra keyword arguments for sim()).
runs = {}
reports = {}
specs = [
    ("MI-LSTM 週頻 20 檔", pos_mi, {}),
    ("單輸入 LSTM 週頻 20 檔", pos_lstm, {}),
    # Monthly rebalance: keep the last weekly book of each calendar month.
    ("MI-LSTM 月頻再平衡", pos_mi.resample("ME").last(), {}),
    # Cost stress test: twice the 0.1425% brokerage fee.
    ("MI-LSTM 週頻 手續費加倍", pos_mi, {"fee_ratio": 1.425 / 1000 * 2}),
]
for name, pos, kw in specs:
    rep = sim(pos, upload=False, **kw)
    runs[name] = summarize(name, rep)
    runs[name]["turnover_per_rebalance_pct"] = round(turnover(pos) * 100, 1)
    reports[name] = rep
    print(runs[name])

# Robustness: weekly MI-LSTM backtest for each individual seed (no ensembling).
seed_rows = []
for seed in SEEDS:
    rep = sim(to_position_single_seed(pred_records["mi"], seed), upload=False)
    row = summarize(f"MI-LSTM seed{seed}", rep)
    seed_rows.append(row)
    print(row)


# ---------- 0050 canonical 基準 ----------
# Canonical 0050 benchmark: buy-and-hold on the adjusted close over [START, END].
# Not run through sim(), so no trading costs. The statistics use the same formulas
# as summarize() (Sortino denominator = std of negative returns).
bser = bench_adj[(bench_adj.index >= START) & (bench_adj.index <= END)].dropna()
bret = bser.pct_change().dropna()
_years = (bser.index[-1] - bser.index[0]).days / 365.25
_total = float(bser.iloc[-1] / bser.iloc[0] - 1)
_mret = bser.resample("ME").last().pct_change().dropna()
bench = {
    "name": "0050 含息",
    "cagr": round(((1 + _total) ** (1 / _years) - 1) * 100, 2),
    "daily_sharpe": round(float(bret.mean() / bret.std() * (252 ** 0.5)), 2),
    "daily_sortino": round(float(bret.mean() / bret[bret < 0].std() * (252 ** 0.5)), 2),
    "monthly_sortino": round(float(_mret.mean() / _mret[_mret < 0].std() * (12 ** 0.5)), 2),
    "max_drawdown": round(float((bser / bser.cummax() - 1).min()) * 100, 2),
    "total_return": round(_total * 100, 1),
}
print("bench:", bench)
# Benchmark equity curve normalised to 1.0 on its first date.
bench_creturn = bser / bser.iloc[0]


# ---------- 分段(穩健性) ----------
def seg_stats(creturn, lo, hi):
    """CAGR / daily Sharpe / max drawdown of ``creturn`` restricted to [lo, hi].

    Returns None when the sub-period has fewer than 30 observations. Values are
    percentages (CAGR, MDD) or ratios (Sharpe), rounded to 2 decimals.
    """
    in_range = (creturn.index >= lo) & (creturn.index <= hi)
    window = creturn[in_range]
    if len(window) < 30:
        return None
    growth = float(window.iloc[-1] / window.iloc[0] - 1)
    span_years = (window.index[-1] - window.index[0]).days / 365.25
    daily = window.pct_change().dropna()
    drawdown = window / window.cummax() - 1
    return {
        "cagr": round(((1 + growth) ** (1 / span_years) - 1) * 100, 2),
        "sharpe": round(float(daily.mean() / daily.std() * (252 ** 0.5)), 2),
        "mdd": round(float(drawdown.min()) * 100, 2),
    }


# Sub-period statistics. seg_stats filters to [lo, hi] itself, so the raw (unclipped)
# sim creturn is safe to pass; the last segment ends at END.
segments = [("2018-01-01", "2020-12-31"), ("2021-01-01", "2023-12-31"), ("2024-01-01", END)]
seg_table = {}
for lo, hi in segments:
    seg_table[f"{lo[:4]}-{hi[:4]}"] = {
        "MI-LSTM": seg_stats(reports["MI-LSTM 週頻 20 檔"].creturn, lo, hi),
        "LSTM": seg_stats(reports["單輸入 LSTM 週頻 20 檔"].creturn, lo, hi),
        "0050": seg_stats(bench_creturn, lo, hi),
    }

# MSE summary: out-of-sample MSE per (year, seed) for both models, plus how often MI-LSTM wins.
mse_df = pd.DataFrame(mse_rows)
mse_summary = {
    "mean_mse_mi": round(float(mse_df.mse_mi.mean()), 4),
    "mean_mse_lstm": round(float(mse_df.mse_lstm.mean()), 4),
    "mi_wins_year_seed": int((mse_df.mse_mi < mse_df.mse_lstm).sum()),
    "total_year_seed": len(mse_df),
    "by_year": mse_df.groupby("year")[["mse_mi", "mse_lstm"]].mean().round(4).to_dict("index"),
}

# Attention summary: the alpha list order is [mainstream, positive, negative, index].
attn_df = pd.DataFrame([{"year": r["year"], "seed": r["seed"],
                         "mainstream": r["alpha"][0], "positive": r["alpha"][1],
                         "negative": r["alpha"][2], "index": r["alpha"][3]} for r in attn_rows])
attn_mean = attn_df[["mainstream", "positive", "negative", "index"]].mean().round(4).to_dict()
attn_by_year = attn_df.groupby("year")[["mainstream", "positive", "negative", "index"]].mean().round(4)

# Everything the article reports, written to metrics.json.
metrics = {
    "data_end": END,
    "start": START,
    "benchmark": bench,
    "strategies": list(runs.values()),
    "per_seed_mi_weekly": seed_rows,
    "mse": mse_summary,
    "attention_mean": attn_mean,
    "attention_by_year": attn_by_year.to_dict("index"),
    "segments": seg_table,
    "corr_example": corr_example,
    "method": {
        "frequency": "週頻 W-FRI,訊號週五收盤,次一交易日收盤價執行(finlab 預設 trade_at_price=close)",
        "pool": f"每年依前 52 週平均週成交值前 {POOL_SIZE} 檔;4 碼純股票代號,排除 ETF;訓練窗資料覆蓋率 >= 90%",
        "model": f"MI-LSTM hidden={HIDDEN}, lookback={LOOKBACK} 週, 每年重訓(訓練窗 {TRAIN_WEEKS} 週), epochs={EPOCHS}, seeds={SEEDS} 集成",
        "neighbors": f"相關係數窗 {CORR_WEEKS} 週,正/負相關各 {K_NEIGHBOR} 檔,鄰居集每年固定",
        "cost": "策略走 finlab sim() 台股預設成本(手續費 0.1425% 雙邊 + 賣出證交稅 0.3%);0050 基準為含息純指數算術,不含成本",
        "oos": "預測全程樣本外(walk-forward 每年重訓,僅用截至前一週的資料)",
    },
}
(OUT / "metrics.json").write_text(json.dumps(metrics, ensure_ascii=False, indent=2), encoding="utf-8")

# Equity curves on a shared date index (strategy curves clipped to [START, END]).
eq = pd.DataFrame({
    "0050": bench_creturn,
    "mi_lstm": clip_creturn(reports["MI-LSTM 週頻 20 檔"].creturn),
    "lstm": clip_creturn(reports["單輸入 LSTM 週頻 20 檔"].creturn),
})
eq.to_csv(OUT / "equity.csv")
reports["MI-LSTM 週頻 20 檔"].to_html(str(OUT / "report_strategy.html"),
                                      title="MI-LSTM 多輸入週頻選股(費後)")


# ---------- 圖表 ----------
import matplotlib

# Select the non-interactive Agg backend before pyplot is imported (files only, no display).
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch, FancyBboxPatch

# CJK-capable fonts for the Chinese labels; the ASCII hyphen is used for minus signs
# because some of these fonts lack the Unicode minus glyph.
plt.rcParams["font.sans-serif"] = ["Heiti TC", "PingFang HK", "Arial Unicode MS"]
plt.rcParams["axes.unicode_minus"] = False

# Shared palette: mi = MI-LSTM, ls = single-input LSTM, bench = 0050, plus accent colours.
C = {"mi": "#2563EB", "ls": "#EF4444", "bench": "#9CA3AF",
     "green": "#10B981", "amber": "#F59E0B", "purple": "#8B5CF6"}


def style(ax):
    """Apply the shared chart look to ``ax``: hide the top/right spines, add a light grid."""
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.grid(True, alpha=0.22)


def fig_ax():
    """Create a single 16:9 figure (12 x 6.75 in at 100 dpi) and return ``(fig, ax)``."""
    size_in = (12, 6.75)
    return plt.subplots(figsize=size_in, dpi=100)


# Chart 1: thumbnail. Equity curves on a log scale, with each CAGR in the legend.
# The year span in the title comes from START / END, so a DATA_END override
# cannot leave a stale label.
eqp = eq.dropna(how="all").ffill()
fig, ax = fig_ax()
ax.plot(eqp.index, eqp["0050"], label=f"0050 含息(CAGR {bench['cagr']}%)",
        color=C["bench"], linewidth=2.2)
ax.plot(eqp.index, eqp["lstm"],
        label=f"單輸入 LSTM(CAGR {runs['單輸入 LSTM 週頻 20 檔']['cagr']}%)",
        color=C["ls"], linewidth=2.0)
ax.plot(eqp.index, eqp["mi_lstm"],
        label=f"MI-LSTM 多輸入(CAGR {runs['MI-LSTM 週頻 20 檔']['cagr']}%)",
        color=C["mi"], linewidth=2.4)
ax.set_yscale("log")
ax.set_title(f"MI-LSTM 股價預測台股實測:多輸入 vs 單輸入 LSTM vs 0050(費後,{START[:4]}-{END[:4]})",
             fontsize=15, fontweight="bold")
ax.set_ylabel("淨值(log)")
ax.legend(loc="upper left")
style(ax)
fig.tight_layout()
fig.savefig(OUT / "chart_equity_thumbnail.png")
plt.close(fig)

# Chart 2: schematic of the MI-LSTM architecture (redrawn, not a screenshot of the paper).
# Pure drawing in axes-fraction coordinates; no model data is used.
# NOTE(review): "12 週" / "10 檔" in the labels are hard-coded and mirror LOOKBACK / K_NEIGHBOR.
fig, ax = fig_ax()
ax.axis("off")
# (label, colour) for each of the four input streams, drawn top to bottom on the left.
streams = [("欲預測個股\n過去 12 週報酬\n(mainstream)", C["mi"]),
           ("最正相關 10 檔\n平均報酬\n(positive)", C["green"]),
           ("最負相關 10 檔\n平均報酬\n(negative)", C["ls"]),
           ("加權報酬指數\n(index)", C["amber"])]
for k, (label, color) in enumerate(streams):
    yc = 0.82 - k * 0.22
    ax.add_patch(FancyBboxPatch((0.03, yc - 0.075), 0.2, 0.15,
                                boxstyle="round,pad=0.012", fc=color, alpha=0.16, ec=color))
    ax.text(0.13, yc, label, ha="center", va="center", fontsize=11)
    # Arrow from each stream box towards the cell box, converging near y = 0.5.
    ax.add_patch(FancyArrowPatch((0.24, yc), (0.34, 0.5 + (yc - 0.5) * 0.4),
                                 arrowstyle="-|>", mutation_scale=16, color=color, lw=1.8))
# MI-LSTM cell box
ax.add_patch(FancyBboxPatch((0.35, 0.28), 0.24, 0.44, boxstyle="round,pad=0.014",
                            fc="#EFF6FF", ec=C["mi"], lw=2))
ax.text(0.47, 0.60, "MI-LSTM cell", ha="center", fontsize=13, fontweight="bold")
ax.text(0.47, 0.50, "mainstream 算出\ninput / forget / output gate", ha="center", fontsize=10)
ax.text(0.47, 0.37, "輔助資訊流各有獨立\ninput gate 過濾雜訊", ha="center", fontsize=10)
ax.add_patch(FancyArrowPatch((0.60, 0.5), (0.66, 0.5), arrowstyle="-|>",
                             mutation_scale=18, color="#111827", lw=2))
# Attention box
ax.add_patch(FancyBboxPatch((0.67, 0.36), 0.15, 0.28, boxstyle="round,pad=0.014",
                            fc="#FDF4FF", ec=C["purple"], lw=2))
ax.text(0.745, 0.55, "Attention", ha="center", fontsize=12, fontweight="bold")
ax.text(0.745, 0.45, "對四條資訊流\n動態加權", ha="center", fontsize=10)
ax.add_patch(FancyArrowPatch((0.83, 0.5), (0.88, 0.5), arrowstyle="-|>",
                             mutation_scale=18, color="#111827", lw=2))
# Output box (next-week return prediction)
ax.add_patch(FancyBboxPatch((0.88, 0.40), 0.1, 0.2, boxstyle="round,pad=0.014",
                            fc="#ECFDF5", ec=C["green"], lw=2))
ax.text(0.93, 0.5, "預測\n下週報酬", ha="center", va="center", fontsize=11)
ax.set_title("MI-LSTM 多輸入長短期記憶模型架構:四條資訊流 + 雜訊過濾 gate + attention",
             fontsize=15, fontweight="bold")
fig.tight_layout()
fig.savefig(OUT / "chart_mi_lstm_architecture.png")
plt.close(fig)

# Chart 3: TSMC (2330)'s multi-input pool, using the actual correlation coefficients
# from the last walk-forward year. Skipped when 2330 was not in that year's pool.
# NOTE(review): "各 10 檔" in the title is hard-coded and mirrors K_NEIGHBOR.
if corr_example:
    pos_pairs = corr_example["positive"]
    neg_pairs = corr_example["negative"]
    # Positive neighbours first (green), then negative neighbours (red).
    labels = [p[0] for p in pos_pairs] + [p[0] for p in neg_pairs]
    vals = [p[1] for p in pos_pairs] + [p[1] for p in neg_pairs]
    colors = [C["green"]] * len(pos_pairs) + [C["ls"]] * len(neg_pairs)
    fig, ax = fig_ax()
    x = np.arange(len(labels))
    ax.bar(x, vals, color=colors)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_ylabel("與 2330 的週報酬相關係數")
    ax.set_title(f"MI-LSTM 輸入池實例:與台積電 2330 最正 / 最負相關各 10 檔"
                 f"(過去 {CORR_WEEKS} 週,截至 {corr_example['asof']})",
                 fontsize=14, fontweight="bold")
    style(ax)
    fig.tight_layout()
    fig.savefig(OUT / "chart_2330_correlation_pool.png")
    plt.close(fig)

# Chart 4: mean attention weight of each input stream per test year, averaged over seeds.
# Four grouped bars per year, plus a dashed line at the uniform weight 1/4.
# The seed count in the title is taken from SEEDS so it cannot drift from the config.
fig, ax = fig_ax()
x = np.arange(len(attn_by_year))
w = 0.2
names = [("mainstream", "mainstream(自身)", C["mi"]),
         ("positive", "positive(正相關股)", C["green"]),
         ("negative", "negative(負相關股)", C["ls"]),
         ("index", "index(大盤)", C["amber"])]
for k, (col, lab, color) in enumerate(names):
    # Offsets -1.5w .. +1.5w centre the four bars on each year's tick.
    ax.bar(x + (k - 1.5) * w, attn_by_year[col].values, w, label=lab, color=color)
ax.axhline(0.25, color="#111827", linestyle="--", linewidth=1.2, label="均勻權重 0.25")
ax.set_xticks(x)
ax.set_xticklabels(attn_by_year.index.astype(str))
ax.set_ylabel("attention 平均權重")
ax.set_title(f"台股 MI-LSTM 四類資訊流的 attention 權重(各年測試期平均,{len(SEEDS)} seeds)",
             fontsize=15, fontweight="bold")
ax.legend(loc="upper right", ncol=3)
style(ax)
fig.tight_layout()
fig.savefig(OUT / "chart_attention_weights.png")
plt.close(fig)

# Chart 5: cost sensitivity. CAGR (left) and daily Sharpe (right) for weekly,
# monthly-rebalanced and double-fee MI-LSTM variants, against the 0050 benchmark line.
# The year span in the suptitle comes from START / END so a DATA_END override stays correct.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6.75), dpi=100)
variants = ["MI-LSTM 週頻 20 檔", "MI-LSTM 月頻再平衡", "MI-LSTM 週頻 手續費加倍"]
labels = ["週頻", "月頻再平衡", "週頻\n手續費加倍"]
x = np.arange(len(variants))
ax1.bar(x, [runs[v]["cagr"] for v in variants], color=[C["mi"], C["green"], C["amber"]])
ax1.axhline(bench["cagr"], color=C["bench"], linestyle="--", linewidth=1.6,
            label=f"0050 含息 {bench['cagr']}%")
ax1.set_xticks(x)
ax1.set_xticklabels(labels)
ax1.set_title("年化報酬 CAGR(%)")
ax1.legend()
ax2.bar(x, [runs[v]["daily_sharpe"] for v in variants], color=[C["mi"], C["green"], C["amber"]])
ax2.axhline(bench["daily_sharpe"], color=C["bench"], linestyle="--", linewidth=1.6,
            label=f"0050 含息 {bench['daily_sharpe']}")
ax2.set_xticks(x)
ax2.set_xticklabels(labels)
ax2.set_title("日夏普")
ax2.legend()
style(ax1)
style(ax2)
fig.suptitle(f"MI-LSTM 選股的交易成本敏感度:再平衡頻率與費率的影響(費後,{START[:4]}-{END[:4]})",
             fontsize=15, fontweight="bold")
fig.tight_layout()
fig.savefig(OUT / "chart_cost_sensitivity.png")
plt.close(fig)

# Console recap of the headline numbers; the same values are in metrics.json.
print("\n=== DONE ===")
print(json.dumps(metrics["strategies"], ensure_ascii=False, indent=1))
print("bench:", bench)
print("mse:", mse_summary)
print("attention:", attn_mean)
