{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# 步驟 1：掛載 Google Drive 與環境準備"
      ],
      "metadata": {
        "id": "T-_LTdc7BTWH"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "UYMOEHNRBMxc",
        "outputId": "08a33d7a-b5e7-434f-c56c-c06b7e42b915"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "正在連線至您的 Google Drive...\n",
            "Mounted at /content/drive\n",
            "目前使用的運算裝置: cuda\n"
          ]
        }
      ],
      "source": [
        "import os\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torchvision.datasets as datasets\n",
        "import torchvision.transforms as transforms\n",
        "import gradio as gr\n",
        "import cv2\n",
        "from google.colab import drive\n",
        "\n",
        "# 1. Mount Google Drive so trained weights persist across Colab sessions.\n",
        "print(\"正在連線至您的 Google Drive...\")\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "# 2. Weight storage path (a folder created on Google Drive).\n",
        "SAVE_DIR = \"/content/drive/MyDrive/MNIST_Model\"\n",
        "os.makedirs(SAVE_DIR, exist_ok=True)\n",
        "MODEL_PATH = os.path.join(SAVE_DIR, \"robust_cnn_mnist.pth\")\n",
        "\n",
        "# 3. Seed PyTorch's RNG so weight initialisation (step 2) and DataLoader\n",
        "#    shuffling (step 3) are reproducible on \"Restart & Run All\".\n",
        "#    GPU kernels may still introduce small run-to-run differences.\n",
        "SEED = 42\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "# 4. Hardware acceleration: use the GPU when one is available.\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "print(f\"目前使用的運算裝置: {device}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 步驟 2：定義 CNN 模型結構"
      ],
      "metadata": {
        "id": "OkpwXJ9ABO1O"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class RobustCNN(nn.Module):\n",
        "    \"\"\"Two-stage CNN classifier for 28x28 single-channel digit images.\n",
        "\n",
        "    Each stage is conv(3x3, padding=1) -> ReLU -> 2x2 max-pool, shrinking the\n",
        "    feature maps 28x28 -> 14x14 -> 7x7; a two-layer fully connected head then\n",
        "    maps the flattened 64*7*7 features to 10 class logits.\n",
        "\n",
        "    The attribute names `features` / `classifier` and the layer order inside\n",
        "    each nn.Sequential define the state_dict keys, so keep them unchanged or\n",
        "    weights previously saved to Google Drive will no longer load.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        # 1x28x28 -> 32x14x14 -> 64x7x7\n",
        "        self.features = nn.Sequential(\n",
        "            nn.Conv2d(1, 32, kernel_size=3, padding=1),\n",
        "            nn.ReLU(),\n",
        "            nn.MaxPool2d(2),\n",
        "            nn.Conv2d(32, 64, kernel_size=3, padding=1),\n",
        "            nn.ReLU(),\n",
        "            nn.MaxPool2d(2),\n",
        "        )\n",
        "        # 3136 flattened features -> 128 hidden units -> 10 logits\n",
        "        self.classifier = nn.Sequential(\n",
        "            nn.Linear(64 * 7 * 7, 128),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(128, 10),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Return raw (pre-softmax) logits of shape [batch, 10].\"\"\"\n",
        "        feature_maps = self.features(x)\n",
        "        flattened = torch.flatten(feature_maps, start_dim=1)\n",
        "        return self.classifier(flattened)\n",
        "\n",
        "# Build the network, then move its parameters to the selected device.\n",
        "model = RobustCNN()\n",
        "model = model.to(device)"
      ],
      "metadata": {
        "id": "jnyyHPEFBOf3"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 步驟 3：智能訓練與權重管理（核心加速區）"
      ],
      "metadata": {
        "id": "lSgVfBdABRwu"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Hyperparameters for the first-time training run.\n",
        "NUM_EPOCHS = 2\n",
        "BATCH_SIZE = 128\n",
        "LEARNING_RATE = 0.003\n",
        "\n",
        "if os.path.exists(MODEL_PATH):\n",
        "    # ==========================================\n",
        "    # Case A: trained weights already on Google Drive -> load them, skip training.\n",
        "    # ==========================================\n",
        "    print(\"🎯 偵測到雲端硬碟已有訓練好的模型權重，正在直接載入以加速啟動...\")\n",
        "    # weights_only=True limits unpickling to tensors and plain containers, so a\n",
        "    # tampered .pth on Drive cannot run arbitrary code when it is loaded.\n",
        "    state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)\n",
        "    model.load_state_dict(state_dict)\n",
        "    model.eval()\n",
        "    print(\"🎉 模型載入成功！免訓練直接進入下一步。\")\n",
        "else:\n",
        "    # ==========================================\n",
        "    # Case B: first run -> download MNIST, train, then save the weights.\n",
        "    # ==========================================\n",
        "    print(\"⚠️ 雲端硬碟無歷史權重，啟動首次訓練流程...\")\n",
        "    print(\"正在載入 MNIST 數據集...\")\n",
        "    # Normalisation constants must match those used at inference time (step 4).\n",
        "    transform = transforms.Compose([\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize((0.1307,), (0.3081,))\n",
        "    ])\n",
        "\n",
        "    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
        "    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
        "\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
        "\n",
        "    print(f\"AI 正在學習數字特徵 (請稍候，共 {NUM_EPOCHS} 個 Epoch)...\")\n",
        "    model.train()\n",
        "    for epoch in range(NUM_EPOCHS):\n",
        "        running_loss = 0.0\n",
        "        for data, target in train_loader:\n",
        "            data, target = data.to(device), target.to(device)\n",
        "            optimizer.zero_grad()\n",
        "            loss = criterion(model(data), target)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            running_loss += loss.item()\n",
        "        # Report the mean batch loss so a diverging or broken run is visible.\n",
        "        print(f\"  Epoch {epoch + 1}/{NUM_EPOCHS} - mean loss: {running_loss / len(train_loader):.4f}\")\n",
        "\n",
        "    # Write to a temporary file, then rename it into place: if the Colab\n",
        "    # session dies mid-write, MODEL_PATH never holds a truncated checkpoint\n",
        "    # that Case A would try (and fail) to load on every later run.\n",
        "    tmp_path = MODEL_PATH + \".tmp\"\n",
        "    torch.save(model.state_dict(), tmp_path)\n",
        "    os.replace(tmp_path, MODEL_PATH)\n",
        "    model.eval()\n",
        "    print(f\"🎉 首次訓練完成！權重已安全儲存至 Google Drive: {MODEL_PATH}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "j0-sLY7XB563",
        "outputId": "a033ac0b-ed70-42cf-9643-f929dc7f3a41"
      },
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "⚠️ 雲端硬碟無歷史權重，啟動首次訓練流程...\n",
            "正在載入 MNIST 數據集...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 9.91M/9.91M [00:01<00:00, 5.61MB/s]\n",
            "100%|██████████| 28.9k/28.9k [00:00<00:00, 131kB/s]\n",
            "100%|██████████| 1.65M/1.65M [00:01<00:00, 1.26MB/s]\n",
            "100%|██████████| 4.54k/4.54k [00:00<00:00, 13.1MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "AI 正在學習數字特徵 (請稍候，共 2 個 Epoch)...\n",
            "🎉 首次訓練完成！權重已安全儲存至 Google Drive: /content/drive/MyDrive/MNIST_Model/robust_cnn_mnist.pth\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# 步驟 4：影像處理與最新 Gradio 介面"
      ],
      "metadata": {
        "id": "p-H6NzT-CTAG"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import torch\n",
        "import gradio as gr\n",
        "import cv2\n",
        "\n",
        "def _extract_ink_mask(img_array):\n",
        "    \"\"\"Turn a canvas image into a binary uint8 mask: ink = 255, background = 0.\n",
        "\n",
        "    The Sketchpad composite's layout is not pinned down by this notebook: it\n",
        "    may be RGBA on a transparent canvas, or an opaque light canvas with dark\n",
        "    strokes. Both are handled:\n",
        "    * RGBA with any transparent pixel -> the alpha channel marks the strokes.\n",
        "    * otherwise (RGB, opaque RGBA, grayscale) -> luminance, inverted when the\n",
        "      background is light, so the result is always white-on-black like MNIST.\n",
        "    \"\"\"\n",
        "    img = np.asarray(img_array)\n",
        "    if img.ndim == 3 and img.shape[2] == 1:\n",
        "        img = img[:, :, 0]\n",
        "\n",
        "    if img.ndim == 3 and img.shape[2] == 4 and img[:, :, 3].min() < 255:\n",
        "        gray = np.ascontiguousarray(img[:, :, 3])\n",
        "    else:\n",
        "        if img.ndim == 3:\n",
        "            code = cv2.COLOR_RGBA2GRAY if img.shape[2] == 4 else cv2.COLOR_RGB2GRAY\n",
        "            gray = cv2.cvtColor(img, code)\n",
        "        else:\n",
        "            gray = img\n",
        "        # Assuming strokes cover less than half the canvas, the median pixel is\n",
        "        # background; a light background means dark ink that must be inverted.\n",
        "        if np.median(gray) > 127:\n",
        "            gray = cv2.bitwise_not(gray)\n",
        "\n",
        "    # Binarise (faint anti-aliased pixels below 30 count as background), then\n",
        "    # thicken strokes slightly so they survive downscaling to 28x28.\n",
        "    _, mask = cv2.threshold(gray, 30, 255, cv2.THRESH_BINARY)\n",
        "    return cv2.dilate(mask, np.ones((3, 3), np.uint8), iterations=1)\n",
        "\n",
        "\n",
        "def _mask_to_mnist_tensor(mask):\n",
        "    \"\"\"Crop, rescale and centre an ink mask the way MNIST digits were prepared.\n",
        "\n",
        "    MNIST digits were scaled to fit a 20x20 box (aspect ratio preserved) and\n",
        "    centred by centre of mass inside a 28x28 frame. Resizing the whole canvas\n",
        "    (often non-square and mostly empty) straight to 28x28 instead distorts the\n",
        "    digit and shrinks small drawings to a few faint pixels.\n",
        "\n",
        "    Requires at least one non-zero pixel in mask. Returns a normalised float32\n",
        "    tensor of shape [1, 1, 28, 28] on the global `device`.\n",
        "    \"\"\"\n",
        "    rows, cols = np.nonzero(mask)\n",
        "    digit = mask[rows.min():rows.max() + 1, cols.min():cols.max() + 1]\n",
        "\n",
        "    height, width = digit.shape\n",
        "    scale = 20.0 / max(height, width)\n",
        "    new_w = max(1, int(round(width * scale)))\n",
        "    new_h = max(1, int(round(height * scale)))\n",
        "    digit = cv2.resize(digit, (new_w, new_h), interpolation=cv2.INTER_AREA)\n",
        "\n",
        "    frame = np.zeros((28, 28), dtype=np.float32)\n",
        "    top, left = (28 - new_h) // 2, (28 - new_w) // 2\n",
        "    frame[top:top + new_h, left:left + new_w] = digit\n",
        "\n",
        "    # Shift (by whole pixels) so the centre of mass lands on the frame centre.\n",
        "    total = frame.sum()\n",
        "    if total > 0:\n",
        "        grid_y, grid_x = np.indices(frame.shape)\n",
        "        cy = (grid_y * frame).sum() / total\n",
        "        cx = (grid_x * frame).sum() / total\n",
        "        shift = np.float32([[1, 0, round(13.5 - cx)], [0, 1, round(13.5 - cy)]])\n",
        "        frame = cv2.warpAffine(frame, shift, (28, 28))\n",
        "\n",
        "    # Same normalisation as the training transform (MNIST mean / std).\n",
        "    img_tensor = torch.tensor(frame / 255.0, dtype=torch.float32)\n",
        "    img_tensor = (img_tensor - 0.1307) / 0.3081\n",
        "    return img_tensor.unsqueeze(0).unsqueeze(0).to(device)  # [1, 1, 28, 28]\n",
        "\n",
        "\n",
        "def recognize_digit(image):\n",
        "    \"\"\"Gradio callback: classify the digit drawn on the sketchpad.\n",
        "\n",
        "    Args:\n",
        "        image: Sketchpad value - a dict whose \"composite\" entry is the merged\n",
        "            canvas as a numpy array, or a bare numpy array.\n",
        "\n",
        "    Returns:\n",
        "        (message, probabilities): a user-facing result string and a dict\n",
        "        mapping \"0\"..\"9\" to softmax probabilities ({} if nothing to classify).\n",
        "    \"\"\"\n",
        "    if image is None:\n",
        "        return \"請在畫布上寫字\", {}\n",
        "\n",
        "    # Gradio's Sketchpad value is a dict; use the merged \"composite\" image.\n",
        "    if isinstance(image, dict) and \"composite\" in image:\n",
        "        img_array = image[\"composite\"]\n",
        "    elif isinstance(image, np.ndarray):\n",
        "        img_array = image\n",
        "    else:\n",
        "        return \"影像格式不符，請重試\", {}\n",
        "\n",
        "    # The composite entry may be None; treat that as an empty canvas.\n",
        "    if img_array is None:\n",
        "        return \"請在畫布上寫字\", {}\n",
        "\n",
        "    # Judge emptiness on the ink mask, not on raw pixel values: a blank light\n",
        "    # canvas has non-zero pixels everywhere yet contains no strokes.\n",
        "    mask = _extract_ink_mask(img_array)\n",
        "    if not mask.any():\n",
        "        return \"請在畫布上寫字\", {}\n",
        "\n",
        "    img_tensor = _mask_to_mnist_tensor(mask)\n",
        "\n",
        "    with torch.no_grad():\n",
        "        outputs = model(img_tensor)\n",
        "        probabilities = torch.softmax(outputs, dim=1)[0]\n",
        "\n",
        "    best_digit = int(torch.argmax(probabilities).item())\n",
        "    results = {str(i): float(probabilities[i]) for i in range(10)}\n",
        "\n",
        "    return f\"🔮 AI 辨識最終結果：【 {best_digit} 】\", results\n",
        "\n",
        "# Build the web UI: drawing canvas + button on the left, predictions on the right.\n",
        "with gr.Blocks() as demo:\n",
        "    gr.Markdown(\"# ✍️智光商工電機電子群 手寫數字即時辨識實驗室\")\n",
        "    gr.Markdown(\"### 💡 請用滑鼠在左側【畫布】寫數字，點擊按鈕進行辨識！\")\n",
        "\n",
        "    with gr.Row():\n",
        "        # Left column: the drawing input and the trigger button.\n",
        "        with gr.Column():\n",
        "            input_canvas = gr.Sketchpad(\n",
        "                label=\"手寫區 (畫筆寬度適中即可)\",\n",
        "                type=\"numpy\",\n",
        "                layers=False,  # disable multi-layer editing to keep the canvas simple\n",
        "                transforms=(),\n",
        "            )\n",
        "            recognize_button = gr.Button(\"辨識數字 (Recognize Digit)\")\n",
        "\n",
        "        # Right column: the verdict text and the top-3 probability bars.\n",
        "        with gr.Column():\n",
        "            output_text = gr.Label(label=\"預測結論\")\n",
        "            output_chart = gr.Label(\n",
        "                label=\"0 ~ 9 的置信度分佈 (Top 3)\",\n",
        "                num_top_classes=3,\n",
        "            )\n",
        "\n",
        "    # On click, the canvas value goes to recognize_digit, whose two return\n",
        "    # values fill the verdict and probability outputs, in that order.\n",
        "    prediction_targets = [output_text, output_chart]\n",
        "    recognize_button.click(recognize_digit, inputs=input_canvas, outputs=prediction_targets)\n",
        "\n",
        "demo.launch(debug=True)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 646
        },
        "id": "WIETsObxBPde",
        "outputId": "5ca01218-3f72-4f8b-c64c-86c3d8d0b952"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n",
            "\n",
            "Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().\n",
            "* Running on public URL: https://49236ecac67ba18e6a.gradio.live\n",
            "\n",
            "This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "<div><iframe src=\"https://49236ecac67ba18e6a.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "oIqTmbXhIQ0L"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Ki3sCGpzIT4p"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [],
      "metadata": {
        "id": "Ncln2FDdBSRW"
      }
    }
  ]
}