{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[*********************100%***********************]  10 of 10 completed\n" ] } ], "source": [
"import yfinance as yf\n",
"import pandas as pd\n",
"\n",
"# List of 10 example tickers (you can replace these with any tickers you prefer)\n",
"tickers = [\"AAPL\", \"MSFT\", \"GOOGL\", \"AMZN\", \"META\", \"TSLA\", \"NVDA\", \"JPM\", \"JNJ\", \"V\"]\n",
"\n",
"# Download daily adjusted close prices for the last six months (period=\"6mo\")\n",
"data = yf.download(tickers, period=\"6mo\", interval=\"1d\", auto_adjust=False)[\"Adj Close\"]\n",
"\n",
"# Transpose so that each row is one company’s six-month timeseries\n",
"df = data.T\n",
"\n",
"# At this point, `df` has:\n",
"#   • Index: the 10 tickers\n",
"#   • Columns: one column per trading day in the last six months\n",
"#   • Values: adjusted close price for that ticker on that date\n",
"\n",
"# Unnamed preset: values only (index dropped), so the app auto-names rows \"Series 0\", \"Series 1\", …\n",
"df.to_csv(\"stocks_data_noindex.csv\", index=False)\n",
"\n",
"# Named preset: the ticker index is written as the first column and used for series names.\n",
"# NOTE(review): re-running this cell overwrites both presets with fresh data, which presumably\n",
"# no longer matches the precomputed forecasts in stocks_data_forecast.pt.\n",
"df.to_csv(\"stocks_data.csv\")\n"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Stable version" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import io\n",
"import gradio as gr\n",
"import pandas as pd\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"import numpy as np\n",
"\n",
"torch.manual_seed(42)\n",
"output = torch.load(\"stocks_data_forecast.pt\")  # (n_timeseries, pred_len, n_quantiles)\n",
"\n",
"def model_forecast(input_data):\n",
"    \"\"\"Placeholder model: ignores `input_data` and returns the precomputed forecast tensor.\"\"\"\n",
"    return output\n",
"\n",
"def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):\n",
"    \"\"\"Plot one series plus its quantile forecasts and return the figure as an H×W×3 NumPy array.\n",
"\n",
"    `quantile_predictions` is indexed as [step, quantile].\n",
"    \"\"\"\n",
"    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)\n",
"    # Give every line its own label: a bare label list passed to ax.legend() is matched to the\n",
"    # artists in drawing order, so the history line would take \"Quantile 1\" and the last\n",
"    # quantile would get no legend entry.\n",
"    ax.plot(timeseries, color=\"blue\", label=\"History\")\n",
"    # NOTE(review): the forecast starts at the last observed index, so step 1 overlaps the final\n",
"    # observation; confirm whether this is intended for visual continuity.\n",
"    x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))\n",
"    for i in range(quantile_predictions.shape[1]):\n",
"        ax.plot(x_pred, quantile_predictions[:, i], color=f\"C{i}\", label=f\"Quantile {i+1}\")\n",
"\n",
"    ax.set_title(f\"Timeseries: {timeseries_name}\")\n",
"    # Legend outside the axes on the right; reserve the right 15% of the figure for it.\n",
"    ax.legend(loc=\"center left\", bbox_to_anchor=(1, 0.5))\n",
"    fig.tight_layout(rect=[0, 0, 0.85, 1])\n",
"\n",
"    buf = io.BytesIO()\n",
"    fig.savefig(buf, format=\"png\", bbox_inches=\"tight\")\n",
"    plt.close(fig)  # one figure is created per gallery image; free it right away\n",
"    buf.seek(0)\n",
"    img = Image.open(buf).convert(\"RGB\")\n",
"    return np.array(img)  # Return as an H×W×3 array\n",
"\n",
"def display_forecast(file, preset_filename):\n",
"    \"\"\"Load the uploaded file (or the selected preset) and plot every row with its forecast.\n",
"\n",
"    Returns a (list of gallery images, error message) pair; the message is empty on success.\n",
"    \"\"\"\n",
"    def load_table(file_path):\n",
"        ext = file_path.split('.')[-1].lower()\n",
"        if ext == 'csv':\n",
"            return pd.read_csv(file_path)\n",
"        elif ext in ['xls', 'xlsx']:\n",
"            return pd.read_excel(file_path)\n",
"        elif ext == 'parquet':\n",
"            return pd.read_parquet(file_path)\n",
"        else:\n",
"            raise ValueError(f\"Unsupported file format '.{ext}'. Acceptable formats: CSV, XLS, XLSX, PARQUET.\")\n",
"\n",
"    try:\n",
"        if file is not None:\n",
"            df = load_table(file.name)\n",
"        else:\n",
"            if not preset_filename:\n",
"                return [], \"Please upload a file or select a preset.\"\n",
"            df = load_table(preset_filename)\n",
"\n",
"        # Use the first column as series names when it holds non-numeric text\n",
"        if df.shape[1] > 0 and df.iloc[:, 0].dtype == object:\n",
"            if not df.iloc[:, 0].str.isnumeric().all():\n",
"                timeseries_names = df.iloc[:, 0].tolist()\n",
"                df = df.iloc[:, 1:]\n",
"            else:\n",
"                timeseries_names = [f\"Series {i}\" for i in range(len(df))]\n",
"        else:\n",
"            timeseries_names = [f\"Series {i}\" for i in range(len(df))]\n",
"\n",
"        # Cast explicitly: an object-dtype column (e.g. numeric strings) cannot become a tensor.\n",
"        _input = torch.tensor(df.astype(float).values)\n",
"        _output = model_forecast(_input)\n",
"        # The placeholder model only has a fixed number of precomputed forecasts.\n",
"        if _output.shape[0] < _input.shape[0]:\n",
"            return [], (f\"Error: forecasts are available for {_output.shape[0]} series, \"\n",
"                        f\"but the table has {_input.shape[0]} rows.\")\n",
"\n",
"        gallery_images = []\n",
"        for i in range(_input.shape[0]):\n",
"            img_array = plot_forecast_image(_input[i], _output[i], timeseries_names[i])\n",
"            gallery_images.append(img_array)\n",
"\n",
"        return gallery_images, \"\"\n",
"    except Exception as e:\n",
"        return [], f\"Error: {e}. Please upload files in one of the following formats: CSV, XLS, XLSX, PARQUET.\"\n",
"\n",
"\n",
"iface = gr.Interface(\n",
"    fn=display_forecast,\n",
"    inputs=[\n",
"        gr.File(label=\"Upload your CSV file (optional)\"),\n",
"        gr.Dropdown(\n",
"            label=\"Or select a preset CSV file\",\n",
"            choices=[\"stocks_data_noindex.csv\", \"stocks_data.csv\"],\n",
"            value=\"stocks_data_noindex.csv\"\n",
"        )\n",
"    ],\n",
"    outputs=[\n",
"        gr.Gallery(label=\"Forecast Plots (one per row)\"),\n",
"        gr.Textbox(label=\"Error Message\")\n",
"    ],\n",
"    title=\"CSV→Dynamic Forecast Gallery\",\n",
"    description=\"Upload a CSV with any number of rows; each row’s forecast becomes one image in a gallery.\",\n",
"    allow_flagging=\"never\",\n",
")\n",
"\n",
"if __name__ == \"__main__\":\n",
"    iface.launch()\n",
"\n",
"\n",
"\n",
"# '''\n",
"# 1. Prepared datasets\n",
"# 2. Plots of different quantiles (different colors)\n",
"# 3. Filters for plots...\n",
"# 4. 
Different input options\n", "# 5. README.md in there (in UI) (contact us for fine-tuning)\n", "# 6. Requirements for dimensions\n", "# 7. Multivariate data (x_t is vector)\n", "# 8. LOGO of NX-AI and xLSTM and tirex\n", "# '''" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Ticker2024-12-042024-12-052024-12-062024-12-092024-12-102024-12-112024-12-122024-12-132024-12-16...2025-05-212025-05-222025-05-232025-05-272025-05-282025-05-292025-05-302025-06-022025-06-032025-06-04
0AAPL242.425201242.455124242.255600246.156204247.173752245.896820247.363297247.532883250.435867...202.089996201.360001195.270004200.210007200.419998199.949997200.850006201.699997203.270004203.150101
1AMZN218.160004220.550003227.029999226.089996225.039993230.259995228.970001227.460007232.929993...201.119995203.100006200.990005206.020004204.720001205.699997205.009995206.649994205.710007207.472000
2GOOGL173.970016172.244003174.309265175.168259184.956985195.175217191.739182189.601639196.433777...168.559998170.869995168.470001172.899994172.360001171.860001171.740005169.029999166.179993167.785004
3JNJ148.006256147.071823146.865265147.150513146.786560144.238968143.845535144.219299141.494659...151.877960151.312805151.639999153.250000152.429993153.580002155.210007155.399994154.419998153.419998
4JPM240.666992242.723633244.582520241.072388240.133057240.795532238.817978237.245850236.889893...261.040009260.670013260.709991265.290009263.489990264.369995264.000000264.660004266.269989265.065002
5META612.740173607.898376622.713257612.530518618.270813631.608154629.721375619.299072623.685120...635.500000636.570007627.059998642.320007643.580017645.049988647.489990670.900024666.849976685.159973
6MSFT435.744720440.924805441.871155444.311768441.632050447.270416447.838196445.556976449.860413...452.570007454.859985450.179993460.690002457.359985458.679993460.359985461.970001462.970001464.190002
7NVDA145.116653145.046661142.426895138.797226135.057587139.297180137.327362134.237656131.987854...131.800003132.830002131.289993135.500000134.809998139.190002135.130005137.380005141.220001141.854996
8TSLA357.929993369.489990389.220001389.790009400.989990424.769989418.100006436.230011463.019989...334.619995341.040009339.339996362.890015356.899994358.429993346.459991342.690002344.269989334.671600
9V308.866455308.049194309.972778307.271790311.338196312.743500313.182037313.690308314.836517...358.299988357.970001353.540009359.299988359.730011362.399994365.190002365.320007365.859985368.179993
\n", "

10 rows × 125 columns

\n", "
" ], "text/plain": [ " Ticker 2024-12-04 2024-12-05 2024-12-06 2024-12-09 2024-12-10 \\\n", "0 AAPL 242.425201 242.455124 242.255600 246.156204 247.173752 \n", "1 AMZN 218.160004 220.550003 227.029999 226.089996 225.039993 \n", "2 GOOGL 173.970016 172.244003 174.309265 175.168259 184.956985 \n", "3 JNJ 148.006256 147.071823 146.865265 147.150513 146.786560 \n", "4 JPM 240.666992 242.723633 244.582520 241.072388 240.133057 \n", "5 META 612.740173 607.898376 622.713257 612.530518 618.270813 \n", "6 MSFT 435.744720 440.924805 441.871155 444.311768 441.632050 \n", "7 NVDA 145.116653 145.046661 142.426895 138.797226 135.057587 \n", "8 TSLA 357.929993 369.489990 389.220001 389.790009 400.989990 \n", "9 V 308.866455 308.049194 309.972778 307.271790 311.338196 \n", "\n", " 2024-12-11 2024-12-12 2024-12-13 2024-12-16 ... 2025-05-21 \\\n", "0 245.896820 247.363297 247.532883 250.435867 ... 202.089996 \n", "1 230.259995 228.970001 227.460007 232.929993 ... 201.119995 \n", "2 195.175217 191.739182 189.601639 196.433777 ... 168.559998 \n", "3 144.238968 143.845535 144.219299 141.494659 ... 151.877960 \n", "4 240.795532 238.817978 237.245850 236.889893 ... 261.040009 \n", "5 631.608154 629.721375 619.299072 623.685120 ... 635.500000 \n", "6 447.270416 447.838196 445.556976 449.860413 ... 452.570007 \n", "7 139.297180 137.327362 134.237656 131.987854 ... 131.800003 \n", "8 424.769989 418.100006 436.230011 463.019989 ... 334.619995 \n", "9 312.743500 313.182037 313.690308 314.836517 ... 
358.299988 \n", "\n", " 2025-05-22 2025-05-23 2025-05-27 2025-05-28 2025-05-29 2025-05-30 \\\n", "0 201.360001 195.270004 200.210007 200.419998 199.949997 200.850006 \n", "1 203.100006 200.990005 206.020004 204.720001 205.699997 205.009995 \n", "2 170.869995 168.470001 172.899994 172.360001 171.860001 171.740005 \n", "3 151.312805 151.639999 153.250000 152.429993 153.580002 155.210007 \n", "4 260.670013 260.709991 265.290009 263.489990 264.369995 264.000000 \n", "5 636.570007 627.059998 642.320007 643.580017 645.049988 647.489990 \n", "6 454.859985 450.179993 460.690002 457.359985 458.679993 460.359985 \n", "7 132.830002 131.289993 135.500000 134.809998 139.190002 135.130005 \n", "8 341.040009 339.339996 362.890015 356.899994 358.429993 346.459991 \n", "9 357.970001 353.540009 359.299988 359.730011 362.399994 365.190002 \n", "\n", " 2025-06-02 2025-06-03 2025-06-04 \n", "0 201.699997 203.270004 203.150101 \n", "1 206.649994 205.710007 207.472000 \n", "2 169.029999 166.179993 167.785004 \n", "3 155.399994 154.419998 153.419998 \n", "4 264.660004 266.269989 265.065002 \n", "5 670.900024 666.849976 685.159973 \n", "6 461.970001 462.970001 464.190002 \n", "7 137.380005 141.220001 141.854996 \n", "8 342.690002 344.269989 334.671600 \n", "9 365.320007 365.859985 368.179993 \n", "\n", "[10 rows x 125 columns]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.read_csv(\"stocks_data.csv\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Not checked but with labels filter" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import io\n", "import pandas as pd\n", "import torch\n", "import matplotlib.pyplot as plt\n", "from PIL import Image\n", "import numpy as np\n", "import gradio as gr\n", "\n", "# Set random seed and load your pretrained forecast tensor\n", "torch.manual_seed(42)\n", "_forecast_tensor = torch.load(\"stocks_data_forecast.pt\") # shape = (n_series, 
pred_len, n_q)\n",
"\n",
"def model_forecast(input_data):\n",
"    \"\"\"Placeholder model: ignores `input_data` and returns forecasts for all preset rows.\"\"\"\n",
"    return _forecast_tensor\n",
"\n",
"def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):\n",
"    \"\"\"Given one 1D series + quantile‐matrix, return a NumPy array of the plotted figure.\"\"\"\n",
"    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)\n",
"    # Label lines as they are drawn; a bare label list passed to ax.legend() is matched to\n",
"    # artists in drawing order and would label the history line as \"Quantile 0\".\n",
"    ax.plot(timeseries, color=\"blue\", label=\"History\")\n",
"    # NOTE(review): the forecast starts at the last observed index, so step 1 overlaps the\n",
"    # final observation; confirm whether this is intended for visual continuity.\n",
"    x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))\n",
"    for i in range(quantile_predictions.shape[1]):\n",
"        ax.plot(x_pred, quantile_predictions[:, i], color=f\"C{i}\", label=f\"Quantile {i}\")\n",
"    ax.set_title(f\"Timeseries: {timeseries_name}\")\n",
"\n",
"    ax.legend(loc=\"center left\", bbox_to_anchor=(1, 0.5))\n",
"    fig.tight_layout(rect=[0, 0, 0.85, 1])\n",
"\n",
"    buf = io.BytesIO()\n",
"    fig.savefig(buf, format=\"png\", bbox_inches=\"tight\")\n",
"    plt.close(fig)\n",
"    buf.seek(0)\n",
"    img = Image.open(buf).convert(\"RGB\")\n",
"    return np.array(img)\n",
"\n",
"def load_table(file_path):\n",
"    \"\"\"Load CSV / XLS(X) / Parquet by extension, else raise.\"\"\"\n",
"    ext = file_path.split(\".\")[-1].lower()\n",
"    if ext == \"csv\":\n",
"        return pd.read_csv(file_path)\n",
"    elif ext in (\"xls\", \"xlsx\"):\n",
"        return pd.read_excel(file_path)\n",
"    elif ext == \"parquet\":\n",
"        return pd.read_parquet(file_path)\n",
"    else:\n",
"        raise ValueError(\n",
"            f\"Unsupported file format '.{ext}'. Accepted: CSV, XLS, XLSX, PARQUET.\"\n",
"        )\n",
"\n",
"def _load_input_table(file, preset_filename):\n",
"    \"\"\"Load the uploaded file if any, else the preset; return None when neither is set.\"\"\"\n",
"    if file is not None:\n",
"        return load_table(file.name)\n",
"    if not preset_filename:\n",
"        return None\n",
"    return load_table(preset_filename)\n",
"\n",
"def _has_name_column(df):\n",
"    \"\"\"True when the first column holds series names (object dtype, not all numeric strings).\"\"\"\n",
"    return (\n",
"        df.shape[1] > 0\n",
"        and df.iloc[:, 0].dtype == object\n",
"        and not df.iloc[:, 0].str.isnumeric().all()\n",
"    )\n",
"\n",
"def _matching_names(all_names, search_term):\n",
"    \"\"\"Names containing `search_term` (case-insensitive substring); all names if it is empty.\"\"\"\n",
"    if not search_term:\n",
"        return all_names\n",
"    lower = search_term.lower()\n",
"    return [n for n in all_names if lower in str(n).lower()]\n",
"\n",
"def extract_names_and_update(file, preset_filename):\n",
"    \"\"\"\n",
"    Read the table (uploaded or preset), extract timeseries names, and return:\n",
"      1) gr.update for the CheckboxGroup (all names pre‐checked)\n",
"      2) the full list of names to store in state.\n",
"    Any loading error yields an empty checkbox and an empty name list.\n",
"    \"\"\"\n",
"    try:\n",
"        df = _load_input_table(file, preset_filename)\n",
"        if df is None:\n",
"            return gr.update(choices=[], value=[]), []\n",
"\n",
"        if _has_name_column(df):\n",
"            names = df.iloc[:, 0].tolist()\n",
"        else:\n",
"            names = [f\"Series {i}\" for i in range(len(df))]\n",
"\n",
"        return gr.update(choices=names, value=names), names\n",
"    except Exception:\n",
"        return gr.update(choices=[], value=[]), []\n",
"\n",
"def filter_names(search_term, all_names):\n",
"    \"\"\"\n",
"    Filter the full list of names (all_names) by the search_term (case‐insensitive substring).\n",
"    Return gr.update that shows only the matching names, all of them checked\n",
"    (manual check/uncheck choices made before the search are not preserved).\n",
"    \"\"\"\n",
"    if not all_names:\n",
"        return gr.update(choices=[], value=[])\n",
"    matching = _matching_names(all_names, search_term)\n",
"    return gr.update(choices=matching, value=matching)\n",
"\n",
"def check_all(names_list, search_term=\"\"):\n",
"    \"\"\"Return an update that checks every name currently offered by the checkbox.\n",
"\n",
"    With a `search_term`, only matching names are checked, so the value stays within the\n",
"    filtered choices set by `filter_names` (Gradio may reject values outside the choices).\n",
"    \"\"\"\n",
"    return gr.update(value=_matching_names(names_list or [], search_term))\n",
"\n",
"def uncheck_all(_):\n",
"    \"\"\"Return an update that unchecks all names.\"\"\"\n",
"    return gr.update(value=[])\n",
"\n",
"def display_filtered_forecast(file, preset_filename, selected_names):\n",
"    \"\"\"\n",
"    Load the table, filter by selected_names, look up the matching forecasts, and return:\n",
"      - list of images (NumPy arrays) for the gallery\n",
"      - error string (empty if OK)\n",
"    \"\"\"\n",
"    try:\n",
"        df = _load_input_table(file, preset_filename)\n",
"        if df is None:\n",
"            return [], \"No file selected.\"\n",
"\n",
"        if _has_name_column(df):\n",
"            all_names = df.iloc[:, 0].tolist()\n",
"            data_only = df.iloc[:, 1:].astype(float)\n",
"        else:\n",
"            all_names = [f\"Series {i}\" for i in range(len(df))]\n",
"            data_only = df.astype(float)\n",
"\n",
"        selected = selected_names or []\n",
"        mask = [name in selected for name in all_names]\n",
"        if not any(mask):\n",
"            return [], \"No timeseries chosen to plot.\"\n",
"\n",
"        # Forecasts are precomputed per row of the preset table, so they can only be\n",
"        # aligned with a table that has exactly that many rows.\n",
"        n_forecasts = _forecast_tensor.shape[0]\n",
"        if len(mask) != n_forecasts:\n",
"            return [], f\"Error: precomputed forecasts exist for {n_forecasts} series, but the table has {len(mask)} rows.\"\n",
"\n",
"        filtered_data = data_only.iloc[mask, :].values\n",
"        filtered_names = [name for name, keep in zip(all_names, mask) if keep]\n",
"\n",
"        inp = torch.tensor(filtered_data)  # (n_chosen, length)\n",
"        # The forecast tensor holds one row per *original* table row; select it with the same\n",
"        # mask, otherwise out[i] would pair the i-th chosen series with the wrong forecast.\n",
"        out = _forecast_tensor[torch.tensor(mask, dtype=torch.bool)]  # (n_chosen, pred_len, n_q)\n",
"\n",
"        gallery_images = []\n",
"        for i in range(inp.shape[0]):\n",
"            gallery_images.append(\n",
"                plot_forecast_image(inp[i], out[i], filtered_names[i])\n",
"            )\n",
"\n",
"        return gallery_images, \"\"\n",
"    except Exception as e:\n",
"        return [], f\"Error: {e}. Please upload a valid CSV, XLS, XLSX, or PARQUET file.\"\n",
"\n",
"with gr.Blocks() as demo:\n",
"    gr.Markdown(\"## Upload or select a preset → search/filter by name → click Plot\")\n",
"\n",
"    with gr.Row():\n",
"        file_input = gr.File(\n",
"            label=\"Upload CSV/XLSX/PARQUET (optional)\",\n",
"            file_types=[\".csv\", \".xls\", \".xlsx\", \".parquet\"]\n",
"        )\n",
"        preset_dropdown = gr.Dropdown(\n",
"            label=\"Or pick a preset:\",\n",
"            choices=[\"stocks_data_noindex.csv\", \"stocks_data.csv\"],\n",
"            value=\"stocks_data_noindex.csv\"\n",
"        )\n",
"\n",
"    # A text box to type a substring (search term)\n",
"    search_box = gr.Textbox(\n",
"        label=\"Search/Filter timeseries by name\",\n",
"        placeholder=\"Type to filter (e.g. 'AMZN')\",\n",
"        value=\"\"\n",
"    )\n",
"\n",
"    # A CheckboxGroup to show matching names; choices/value will be updated dynamically\n",
"    filter_checkbox = gr.CheckboxGroup(\n",
"        choices=[], value=[], label=\"Select which timeseries to show\"\n",
"    )\n",
"\n",
"    # Buttons to check or uncheck all\n",
"    with gr.Row():\n",
"        check_all_btn = gr.Button(\"Check All\")\n",
"        uncheck_all_btn = gr.Button(\"Uncheck All\")\n",
"\n",
"    plot_button = gr.Button(\"Plot\")\n",
"\n",
"    gallery = gr.Gallery(label=\"Forecast Plots (filtered)\")\n",
"    errbox = gr.Textbox(label=\"Error Message\")\n",
"\n",
"    # State to hold the full list of names\n",
"    names_state = gr.State([])\n",
"\n",
"    # 1) On page load and whenever the file or preset changes, extract the full names and\n",
"    #    update the checkbox + state (without the load event the default preset is never read)\n",
"    demo.load(\n",
"        fn=extract_names_and_update,\n",
"        inputs=[file_input, preset_dropdown],\n",
"        outputs=[filter_checkbox, names_state]\n",
"    )\n",
"    file_input.change(\n",
"        fn=extract_names_and_update,\n",
"        inputs=[file_input, preset_dropdown],\n",
"        outputs=[filter_checkbox, names_state]\n",
"    )\n",
"    preset_dropdown.change(\n",
"        fn=extract_names_and_update,\n",
"        inputs=[file_input, preset_dropdown],\n",
"        outputs=[filter_checkbox, names_state]\n",
"    )\n",
"\n",
"    # 2) When search text changes, filter names_state and update the checkbox\n",
"    search_box.change(\n",
"        fn=filter_names,\n",
"        inputs=[search_box, names_state],\n",
"        outputs=filter_checkbox\n",
"    )\n",
"\n",
"    # 3) Check All button: check every name that matches the current search\n",
"    check_all_btn.click(\n",
"        fn=check_all,\n",
"        inputs=[names_state, search_box],\n",
"        outputs=filter_checkbox\n",
"    )\n",
"\n",
"    # 4) Uncheck All button: set checkbox value to empty list\n",
"    uncheck_all_btn.click(\n",
"        fn=uncheck_all,\n",
"        inputs=names_state,\n",
"        outputs=filter_checkbox\n",
"    )\n",
"\n",
"    # 5) When \"Plot\" is clicked, generate the filtered plots\n",
"    plot_button.click(\n",
"        fn=display_filtered_forecast,\n",
"        inputs=[file_input, preset_dropdown, filter_checkbox],\n",
"        outputs=[gallery, errbox],\n",
"    )\n",
"\n",
"demo.launch()\n",
"\n"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# 
Checked, filter" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import io\n",
"import pandas as pd\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"import numpy as np\n",
"import gradio as gr\n",
"\n",
"# ----------------------------\n",
"# Helper functions\n",
"# ----------------------------\n",
"\n",
"torch.manual_seed(42)\n",
"_forecast_tensor = torch.load(\"stocks_data_forecast.pt\")  # shape = (n_series, pred_len, n_q)\n",
"\n",
"def model_forecast(input_data):\n",
"    \"\"\"Placeholder model: ignores `input_data` and returns forecasts for all preset rows.\"\"\"\n",
"    return _forecast_tensor\n",
"\n",
"def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):\n",
"    \"\"\"Plot one series and its quantile forecasts; return the figure as an H×W×3 NumPy array.\"\"\"\n",
"    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)\n",
"    # Label lines as they are drawn; a bare label list passed to ax.legend() is matched to\n",
"    # artists in drawing order and would label the history line as \"Quantile 0\".\n",
"    ax.plot(timeseries, color=\"blue\", label=\"History\")\n",
"    # NOTE(review): the forecast starts at the last observed index, so step 1 overlaps the\n",
"    # final observation; confirm whether this is intended for visual continuity.\n",
"    x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))\n",
"    for i in range(quantile_predictions.shape[1]):\n",
"        ax.plot(x_pred, quantile_predictions[:, i], color=f\"C{i}\", label=f\"Quantile {i}\")\n",
"    ax.set_title(f\"Timeseries: {timeseries_name}\")\n",
"    ax.legend(loc=\"center left\", bbox_to_anchor=(1, 0.5))\n",
"    fig.tight_layout(rect=[0, 0, 0.85, 1])\n",
"    buf = io.BytesIO()\n",
"    fig.savefig(buf, format=\"png\", bbox_inches=\"tight\")\n",
"    plt.close(fig)\n",
"    buf.seek(0)\n",
"    img = Image.open(buf).convert(\"RGB\")\n",
"    return np.array(img)\n",
"\n",
"def load_table(file_path):\n",
"    \"\"\"Load CSV / XLS(X) / Parquet by file extension; raise ValueError otherwise.\"\"\"\n",
"    ext = file_path.split(\".\")[-1].lower()\n",
"    if ext == \"csv\":\n",
"        return pd.read_csv(file_path)\n",
"    elif ext in (\"xls\", \"xlsx\"):\n",
"        return pd.read_excel(file_path)\n",
"    elif ext == \"parquet\":\n",
"        return pd.read_parquet(file_path)\n",
"    else:\n",
"        raise ValueError(\"Unsupported format. Use CSV, XLS, XLSX, or PARQUET.\")\n",
"\n",
"def _load_input_table(file, preset_filename):\n",
"    \"\"\"Load the uploaded file if any, else the preset; return None when neither is set.\"\"\"\n",
"    if file is not None:\n",
"        return load_table(file.name)\n",
"    if not preset_filename:\n",
"        return None\n",
"    return load_table(preset_filename)\n",
"\n",
"def _has_name_column(df):\n",
"    \"\"\"True when the first column holds series names (object dtype, not all numeric strings).\"\"\"\n",
"    return (\n",
"        df.shape[1] > 0\n",
"        and df.iloc[:, 0].dtype == object\n",
"        and not df.iloc[:, 0].str.isnumeric().all()\n",
"    )\n",
"\n",
"def _matching_names(all_names, search_term):\n",
"    \"\"\"Names containing `search_term` (case-insensitive substring); all names if it is empty.\"\"\"\n",
"    if not search_term:\n",
"        return all_names\n",
"    lower = search_term.lower()\n",
"    return [n for n in all_names if lower in str(n).lower()]\n",
"\n",
"def extract_names_and_update(file, preset_filename):\n",
"    \"\"\"Return (checkbox update with every name checked, full name list for the state).\n",
"\n",
"    Any loading error yields an empty checkbox and an empty name list.\n",
"    \"\"\"\n",
"    try:\n",
"        df = _load_input_table(file, preset_filename)\n",
"        if df is None:\n",
"            return gr.update(choices=[], value=[]), []\n",
"        if _has_name_column(df):\n",
"            names = df.iloc[:, 0].tolist()\n",
"        else:\n",
"            names = [f\"Series {i}\" for i in range(len(df))]\n",
"        return gr.update(choices=names, value=names), names\n",
"    except Exception:\n",
"        return gr.update(choices=[], value=[]), []\n",
"\n",
"def filter_names(search_term, all_names):\n",
"    \"\"\"Show only names matching `search_term`, all checked (earlier manual selections are reset).\"\"\"\n",
"    if not all_names:\n",
"        return gr.update(choices=[], value=[])\n",
"    matching = _matching_names(all_names, search_term)\n",
"    return gr.update(choices=matching, value=matching)\n",
"\n",
"def check_all(names_list, search_term=\"\"):\n",
"    \"\"\"Check every name currently offered by the checkbox.\n",
"\n",
"    With a `search_term`, only matching names are checked, so the value stays within the\n",
"    filtered choices set by `filter_names` (Gradio may reject values outside the choices).\n",
"    \"\"\"\n",
"    return gr.update(value=_matching_names(names_list or [], search_term))\n",
"\n",
"def uncheck_all(_):\n",
"    \"\"\"Uncheck every name.\"\"\"\n",
"    return gr.update(value=[])\n",
"\n",
"def display_filtered_forecast(file, preset_filename, selected_names):\n",
"    \"\"\"\n",
"    Load the table, filter by selected_names, run forecast (correctly sliced),\n",
"    and return a gallery + error string.\n",
"    \"\"\"\n",
"    try:\n",
"        df = _load_input_table(file, preset_filename)\n",
"        if df is None:\n",
"            return [], \"No file selected.\"\n",
"\n",
"        # Extract all_names and numeric data\n",
"        if _has_name_column(df):\n",
"            all_names = df.iloc[:, 0].tolist()\n",
"            data_only = df.iloc[:, 1:].astype(float)\n",
"        else:\n",
"            all_names = [f\"Series {i}\" for i in range(len(df))]\n",
"            data_only = df.astype(float)\n",
"\n",
"        # Build mask and filtered subset\n",
"        selected = selected_names or []\n",
"        mask = [name in selected for name in all_names]\n",
"        if not any(mask):\n",
"            return [], \"No timeseries chosen to plot.\"\n",
"\n",
"        # Forecasts are precomputed per row of the preset table, so they can only be\n",
"        # aligned with a table that has exactly that many rows.\n",
"        n_forecasts = _forecast_tensor.shape[0]\n",
"        if len(mask) != n_forecasts:\n",
"            return [], f\"Error: precomputed forecasts exist for {n_forecasts} series, but the table has {len(mask)} rows.\"\n",
"\n",
"        filtered_data = data_only.iloc[mask, :].values\n",
"        filtered_names = [name for name, keep in zip(all_names, mask) if keep]\n",
"\n",
"        # Slice the full forecast tensor with the same row mask so each series keeps its forecast\n",
"        out = _forecast_tensor[torch.tensor(mask, dtype=torch.bool)]  # shape = (n_chosen, pred_len, n_q)\n",
"        inp = torch.tensor(filtered_data)\n",
"\n",
"        # Plot each chosen series against its properly-aligned forecast\n",
"        gallery_images = []\n",
"        for i in range(inp.shape[0]):\n",
"            gallery_images.append(\n",
"                plot_forecast_image(inp[i], out[i], filtered_names[i])\n",
"            )\n",
"\n",
"        return gallery_images, \"\"\n",
"    except Exception as e:\n",
"        return [], f\"Error: {e}. Please upload a valid CSV, XLS, XLSX, or PARQUET file.\"\n",
"\n",
"\n",
"# ----------------------------\n",
"# Gradio layout: two columns\n",
"# ----------------------------\n",
"\n",
"with gr.Blocks() as demo:\n",
"    gr.Markdown(\"# 📈 Stock Forecast Viewer 📊\")\n",
"    gr.Markdown(\"Upload data or choose a preset, filter by name, then click Plot.\")\n",
"\n",
"    with gr.Row():\n",
"        # Left column: controls\n",
"        with gr.Column():\n",
"            gr.Markdown(\"## Data Selection\")\n",
"            file_input = gr.File(\n",
"                label=\"Upload CSV / XLSX / PARQUET\",\n",
"                file_types=[\".csv\", \".xls\", \".xlsx\", \".parquet\"]\n",
"            )\n",
"            preset_dropdown = gr.Dropdown(\n",
"                label=\"Or choose a preset:\",\n",
"                choices=[\"stocks_data_noindex.csv\", \"stocks_data.csv\"],\n",
"                value=\"stocks_data_noindex.csv\"\n",
"            )\n",
"\n",
"            gr.Markdown(\"## Search / Filter\")\n",
"            search_box = gr.Textbox(\n",
"                placeholder=\"Type to filter (e.g. 'AMZN')\"\n",
"            )\n",
"            filter_checkbox = gr.CheckboxGroup(\n",
"                choices=[], value=[], label=\"Select which timeseries to show\"\n",
"            )\n",
"\n",
"            with gr.Row():\n",
"                check_all_btn = gr.Button(\"✅ Check All\")\n",
"                uncheck_all_btn = gr.Button(\"❎ Uncheck All\")\n",
"\n",
"            plot_button = gr.Button(\"▶️ Plot Forecasts\")\n",
"            errbox = gr.Textbox(interactive=False, placeholder=\"\")\n",
"\n",
"        # Right column: gallery\n",
"        with gr.Column():\n",
"            gr.Markdown(\"## Forecast Gallery\")\n",
"            gallery = gr.Gallery()\n",
"\n",
"    names_state = gr.State([])\n",
"\n",
"    # On page load and whenever the file or preset changes, update names\n",
"    # (without the load event the default preset is never read)\n",
"    demo.load(\n",
"        fn=extract_names_and_update,\n",
"        inputs=[file_input, preset_dropdown],\n",
"        outputs=[filter_checkbox, names_state]\n",
"    )\n",
"    file_input.change(\n",
"        fn=extract_names_and_update,\n",
"        inputs=[file_input, preset_dropdown],\n",
"        outputs=[filter_checkbox, names_state]\n",
"    )\n",
"    preset_dropdown.change(\n",
"        fn=extract_names_and_update,\n",
"        inputs=[file_input, preset_dropdown],\n",
"        outputs=[filter_checkbox, names_state]\n",
"    )\n",
"\n",
"    # When search term changes, filter names\n",
"    search_box.change(\n",
"        fn=filter_names,\n",
"        inputs=[search_box, names_state],\n",
"        outputs=filter_checkbox\n",
"    )\n",
"\n",
"    # Check All (respecting the current search) / Uncheck All\n",
"    check_all_btn.click(fn=check_all, inputs=[names_state, search_box], outputs=filter_checkbox)\n",
"    uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)\n",
"\n",
"    # Plot button\n",
"    plot_button.click(\n",
"        fn=display_filtered_forecast,\n",
"        inputs=[file_input, preset_dropdown, filter_checkbox],\n",
"        outputs=[gallery, errbox]\n",
"    )\n",
"\n",
"demo.launch()"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Checked, almost ideal" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import io\n",
"import pandas as pd\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"import numpy as np\n",
"import gradio as gr\n",
"\n",
"# ----------------------------\n",
"# Helper functions (logic unchanged)\n",
"# 
----------------------------\n",
"\n",
"torch.manual_seed(42)\n",
"_forecast_tensor = torch.load(\"stocks_data_forecast.pt\")  # shape = (n_series, pred_len, n_q)\n",
"\n",
"def model_forecast(input_data):\n",
"    \"\"\"Placeholder model: ignores `input_data` and returns forecasts for all preset rows.\"\"\"\n",
"    return _forecast_tensor\n",
"\n",
"def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):\n",
"    \"\"\"Plot one series and its quantile forecasts; return the figure as an H×W×3 NumPy array.\"\"\"\n",
"    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)\n",
"    # Label lines as they are drawn; a bare label list passed to ax.legend() is matched to\n",
"    # artists in drawing order and would label the history line as \"Quantile 0\".\n",
"    ax.plot(timeseries, color=\"blue\", label=\"History\")\n",
"    # NOTE(review): the forecast starts at the last observed index, so step 1 overlaps the\n",
"    # final observation; confirm whether this is intended for visual continuity.\n",
"    x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))\n",
"    for i in range(quantile_predictions.shape[1]):\n",
"        ax.plot(x_pred, quantile_predictions[:, i], color=f\"C{i}\", label=f\"Quantile {i}\")\n",
"    ax.set_title(f\"Timeseries: {timeseries_name}\")\n",
"    ax.legend(loc=\"center left\", bbox_to_anchor=(1, 0.5))\n",
"    fig.tight_layout(rect=[0, 0, 0.85, 1])\n",
"    buf = io.BytesIO()\n",
"    fig.savefig(buf, format=\"png\", bbox_inches=\"tight\")\n",
"    plt.close(fig)\n",
"    buf.seek(0)\n",
"    img = Image.open(buf).convert(\"RGB\")\n",
"    return np.array(img)\n",
"\n",
"def load_table(file_path):\n",
"    \"\"\"Load CSV / XLS(X) / Parquet by file extension; raise ValueError otherwise.\"\"\"\n",
"    ext = file_path.split(\".\")[-1].lower()\n",
"    if ext == \"csv\":\n",
"        return pd.read_csv(file_path)\n",
"    elif ext in (\"xls\", \"xlsx\"):\n",
"        return pd.read_excel(file_path)\n",
"    elif ext == \"parquet\":\n",
"        return pd.read_parquet(file_path)\n",
"    else:\n",
"        raise ValueError(\"Unsupported format. Use CSV, XLS, XLSX, or PARQUET.\")\n",
"\n",
"def _load_input_table(file, preset_filename):\n",
"    \"\"\"Load the uploaded file if any, else the preset; return None when neither is set.\"\"\"\n",
"    if file is not None:\n",
"        return load_table(file.name)\n",
"    if not preset_filename:\n",
"        return None\n",
"    return load_table(preset_filename)\n",
"\n",
"def _has_name_column(df):\n",
"    \"\"\"True when the first column holds series names (object dtype, not all numeric strings).\"\"\"\n",
"    return (\n",
"        df.shape[1] > 0\n",
"        and df.iloc[:, 0].dtype == object\n",
"        and not df.iloc[:, 0].str.isnumeric().all()\n",
"    )\n",
"\n",
"def _matching_names(all_names, search_term):\n",
"    \"\"\"Names containing `search_term` (case-insensitive substring); all names if it is empty.\"\"\"\n",
"    if not search_term:\n",
"        return all_names\n",
"    lower = search_term.lower()\n",
"    return [n for n in all_names if lower in str(n).lower()]\n",
"\n",
"def extract_names_and_update(file, preset_filename):\n",
"    \"\"\"Return (checkbox update with every name checked, full name list for the state).\n",
"\n",
"    Any loading error yields an empty checkbox and an empty name list.\n",
"    \"\"\"\n",
"    try:\n",
"        df = _load_input_table(file, preset_filename)\n",
"        if df is None:\n",
"            return gr.update(choices=[], value=[]), []\n",
"        if _has_name_column(df):\n",
"            names = df.iloc[:, 0].tolist()\n",
"        else:\n",
"            names = [f\"Series {i}\" for i in range(len(df))]\n",
"        return gr.update(choices=names, value=names), names\n",
"    except Exception:\n",
"        return gr.update(choices=[], value=[]), []\n",
"\n",
"def filter_names(search_term, all_names):\n",
"    \"\"\"Show only names matching `search_term`, all checked (earlier manual selections are reset).\"\"\"\n",
"    if not all_names:\n",
"        return gr.update(choices=[], value=[])\n",
"    matching = _matching_names(all_names, search_term)\n",
"    return gr.update(choices=matching, value=matching)\n",
"\n",
"def check_all(names_list, search_term=\"\"):\n",
"    \"\"\"Check every name currently offered by the checkbox.\n",
"\n",
"    With a `search_term`, only matching names are checked, so the value stays within the\n",
"    filtered choices set by `filter_names` (Gradio may reject values outside the choices).\n",
"    \"\"\"\n",
"    return gr.update(value=_matching_names(names_list or [], search_term))\n",
"\n",
"def uncheck_all(_):\n",
"    \"\"\"Uncheck every name.\"\"\"\n",
"    return gr.update(value=[])\n",
"\n",
"def display_filtered_forecast(file, preset_filename, selected_names):\n",
"    \"\"\"Plot the selected series with their precomputed forecasts.\n",
"\n",
"    Returns a (list of gallery images, error message) pair; the message is empty on success.\n",
"    \"\"\"\n",
"    try:\n",
"        df = _load_input_table(file, preset_filename)\n",
"        if df is None:\n",
"            return [], \"No file selected.\"\n",
"\n",
"        if _has_name_column(df):\n",
"            all_names = df.iloc[:, 0].tolist()\n",
"            data_only = df.iloc[:, 1:].astype(float)\n",
"        else:\n",
"            all_names = [f\"Series {i}\" for i in range(len(df))]\n",
"            data_only = df.astype(float)\n",
"\n",
"        selected = selected_names or []\n",
"        mask = [name in selected for name in all_names]\n",
"        if not any(mask):\n",
"            return [], \"No timeseries chosen to plot.\"\n",
"\n",
"        # Forecasts are precomputed per row of the preset table, so they can only be\n",
"        # aligned with a table that has exactly that many rows.\n",
"        n_forecasts = _forecast_tensor.shape[0]\n",
"        if len(mask) != n_forecasts:\n",
"            return [], f\"Error: precomputed forecasts exist for {n_forecasts} series, but the table has {len(mask)} rows.\"\n",
"\n",
"        filtered_data = data_only.iloc[mask, :].values\n",
"        filtered_names = [name for name, keep in zip(all_names, mask) if keep]\n",
"        out = _forecast_tensor[torch.tensor(mask, dtype=torch.bool)]  # slice forecasts to match filtered rows\n",
"        inp = torch.tensor(filtered_data)\n",
"\n",
"        gallery_images = []\n",
"        for i in range(inp.shape[0]):\n",
"            gallery_images.append(plot_forecast_image(inp[i], out[i], filtered_names[i]))\n",
"\n",
"        return gallery_images, \"\"\n",
"    except Exception as e:\n",
"        return [], f\"Error: {e}. Use CSV, XLS, XLSX, or PARQUET.\"\n",
"\n",
"# ----------------------------\n",
"# Gradio layout: two columns + instructions\n",
"# ----------------------------\n",
"\n",
"with gr.Blocks() as demo:\n",
"    gr.Markdown(\"# 📈 Stock Forecast Viewer 📊\")\n",
"    gr.Markdown(\"Upload data or choose a preset, filter by name, then click Plot.\")\n",
"\n",
"    with gr.Row():\n",
"        # Left column: controls\n",
"        with gr.Column():\n",
"            gr.Markdown(\"## Data Selection\")\n",
"            file_input = gr.File(\n",
"                label=\"Upload CSV / XLSX / PARQUET\",\n",
"                file_types=[\".csv\", \".xls\", \".xlsx\", \".parquet\"]\n",
"            )\n",
"            preset_dropdown = gr.Dropdown(\n",
"                label=\"Or choose a preset:\",\n",
"                choices=[\"stocks_data_noindex.csv\", \"stocks_data.csv\"],\n",
"                value=\"stocks_data_noindex.csv\"\n",
"            )\n",
"\n",
"            gr.Markdown(\"## Search / Filter\")\n",
"            search_box = gr.Textbox(placeholder=\"Type to filter (e.g. 'AMZN')\")\n",
"            filter_checkbox = gr.CheckboxGroup(\n",
"                choices=[], value=[], label=\"Select which timeseries to show\"\n",
"            )\n",
"\n",
"            with gr.Row():\n",
"                check_all_btn = gr.Button(\"✅ Check All\")\n",
"                uncheck_all_btn = gr.Button(\"❎ Uncheck All\")\n",
"\n",
"            plot_button = gr.Button(\"▶️ Plot Forecasts\")\n",
"            errbox = gr.Textbox(interactive=False, placeholder=\"\")\n",
"\n",
"        # Right column: gallery + instructions\n",
"        with gr.Column():\n",
"            gr.Markdown(\"## Forecast Gallery\")\n",
"            gallery = gr.Gallery()\n",
"\n",
"            # Instruction text below gallery\n",
"            gr.Markdown(\n",
"                \"\"\"\n",
"                **How to format your data:**\n",
"                - Your file must be a table (CSV, XLS, XLSX, or Parquet).\n",
"                - If you haven't prepared the data, the preset file will be used.\n",
"                - **One row per timeseries.** Each row is treated as a separate series.\n",
"                - If you want to **name** each series, put the name as the first value in **every** row:\n",
"                  - Example (CSV):  \n",
"                    `AAPL, 120.5, 121.0, 119.8, ...`  \n",
"                    `AMZN, 3300.0, 3310.5, 3295.2, ...`  \n",
"                  - In that case, the first column is not numeric, so it will be used as the series name.\n",
"                - If you do **not** want named series, simply leave out the first column entirely and have all values numeric:\n",
"                  - Example:  \n",
"                    `120.5, 121.0, 119.8, ...`  \n",
"                    `3300.0, 3310.5, 3295.2, ...`  \n",
"                  - Then every row will be auto-named “Series 0, Series 1, …” in order.\n",
"                - **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.\n",
"                - The rest of the columns (after the optional name) must be numeric data points for that series.\n",
"                - The forecasts are precomputed, so the table must have as many rows as there are precomputed forecasts.\n",
"                - You can filter by typing in the search box. Then check or uncheck individual names before plotting.\n",
"                - Use “Check All” / “Uncheck All” to quickly select or deselect every series.\n",
"                - Finally, click **Plot Forecasts** to view the quantile forecast for each selected series.\n",
"                \"\"\"\n",
"            )\n",
"\n",
"    names_state = gr.State([])\n",
"\n",
"    # On page load and whenever the file or preset changes, update names\n",
"    # (without the load event the default preset is never read)\n",
"    demo.load(\n",
"        fn=extract_names_and_update,\n",
"        inputs=[file_input, preset_dropdown],\n",
"        outputs=[filter_checkbox, names_state]\n",
"    )\n",
"    file_input.change(\n",
"        fn=extract_names_and_update,\n",
"        inputs=[file_input, preset_dropdown],\n",
"        outputs=[filter_checkbox, names_state]\n",
"    )\n",
"    preset_dropdown.change(\n",
"        fn=extract_names_and_update,\n",
"        inputs=[file_input, preset_dropdown],\n",
"        outputs=[filter_checkbox, names_state]\n",
"    )\n",
"\n",
"    # When search term changes, filter names\n",
"    search_box.change(\n",
"        fn=filter_names,\n",
"        inputs=[search_box, names_state],\n",
"        outputs=filter_checkbox\n",
"    )\n",
"\n",
"    # Check All (respecting the current search) / Uncheck All\n",
"    check_all_btn.click(fn=check_all, inputs=[names_state, search_box], outputs=filter_checkbox)\n",
"    uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)\n",
"\n",
"    # Plot button\n",
"    plot_button.click(\n",
"        fn=display_filtered_forecast,\n",
"        inputs=[file_input, preset_dropdown, filter_checkbox],\n",
"        outputs=[gallery, errbox]\n",
"    )\n",
"\n",
"demo.launch()"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# The default choice isn't processed when the default choice is chosen" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import io\n",
"import pandas as pd\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"import numpy as np\n",
"import gradio as gr\n",
"\n",
"# ----------------------------\n",
"# Helper functions (logic unchanged)\n",
"# ----------------------------\n",
"\n",
"torch.manual_seed(42)\n",
"_forecast_tensor = torch.load(\"stocks_data_forecast.pt\") # shape = (n_series, pred_len, n_q)\n",
"\n",
"def model_forecast(input_data):\n",
    "    \"\"\"Stand-in for a forecasting model: returns the precomputed tensor, ignoring `input_data`.\"\"\"\n",
    "    return _forecast_tensor\n",
    "\n",
    "\n",
    "def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):\n",
    "    \"\"\"Plot a series and its quantile forecasts; return the figure as an RGB (H x W x 3) array.\"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)\n",
    "    ax.plot(timeseries, color=\"blue\", label=\"Given Data\")\n",
    "    # NOTE(review): the forecast starts at index len - 1, sharing an x position with\n",
    "    # the last observation; confirm this overlap is intended.\n",
    "    x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))\n",
    "    for i in range(quantile_predictions.shape[1]):\n",
    "        ax.plot(x_pred, quantile_predictions[:, i], color=f\"C{i}\", label=f\"Quantile {i + 1}\")\n",
    "    ax.set_title(f\"Timeseries: {timeseries_name}\")\n",
    "    # Labels are attached per line: ax.legend(labels) would give the first label to\n",
    "    # the history line and leave the last quantile unlabelled.\n",
    "    ax.legend(loc=\"center left\", bbox_to_anchor=(1, 0.5))\n",
    "    fig.tight_layout(rect=[0, 0, 0.85, 1])\n",
    "    buf = io.BytesIO()\n",
    "    fig.savefig(buf, format=\"png\", bbox_inches=\"tight\")\n",
    "    plt.close(fig)\n",
    "    buf.seek(0)\n",
    "    img = Image.open(buf).convert(\"RGB\")\n",
    "    return np.array(img)\n",
    "\n",
    "\n",
    "def load_table(file_path):\n",
    "    \"\"\"Read a CSV, XLS/XLSX or Parquet file (chosen by extension) into a DataFrame.\n",
    "\n",
    "    pandas' default header handling applies: the first row of a CSV/Excel file\n",
    "    becomes the column labels rather than data.\n",
    "    \"\"\"\n",
    "    ext = file_path.split(\".\")[-1].lower()\n",
    "    if ext == \"csv\":\n",
    "        return pd.read_csv(file_path)\n",
    "    elif ext in (\"xls\", \"xlsx\"):\n",
    "        return pd.read_excel(file_path)\n",
    "    elif ext == \"parquet\":\n",
    "        return pd.read_parquet(file_path)\n",
    "    else:\n",
    "        raise ValueError(\"Unsupported format. Use CSV, XLS, XLSX, or PARQUET.\")\n",
    "\n",
    "\n",
    "def split_names(df):\n",
    "    \"\"\"Split a one-row-per-series table into (names, values).\n",
    "\n",
    "    If the first column has object dtype and is not entirely numeric strings\n",
    "    (per `str.isnumeric`), it supplies the series names and is dropped from the\n",
    "    values; otherwise the rows are auto-named \"Series 0\", \"Series 1\", ...\n",
    "    \"\"\"\n",
    "    if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():\n",
    "        return df.iloc[:, 0].tolist(), df.iloc[:, 1:]\n",
    "    return [f\"Series {i}\" for i in range(len(df))], df\n",
    "\n",
    "\n",
    "def extract_names_and_update(file, preset_filename):\n",
    "    \"\"\"Load the uploaded file (else the preset) and refresh the series checkboxes.\n",
    "\n",
    "    Returns a CheckboxGroup update with every name selected, plus the full name\n",
    "    list for `names_state`. Any load/parse failure yields an empty list.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        if file is not None:\n",
    "            df = load_table(file.name)\n",
    "        else:\n",
    "            if not preset_filename:\n",
    "                return gr.update(choices=[], value=[]), []\n",
    "            df = load_table(preset_filename)\n",
    "\n",
    "        names, _ = split_names(df)\n",
    "        return gr.update(choices=names, value=names), names\n",
    "    except Exception:\n",
    "        # Best effort: an unreadable file just leaves the selection empty.\n",
    "        return gr.update(choices=[], value=[]), []\n",
    "\n",
    "\n",
    "def filter_names(search_term, all_names):\n",
    "    \"\"\"Show and select only the names containing `search_term` (case-insensitive).\"\"\"\n",
    "    if not all_names:\n",
    "        return gr.update(choices=[], value=[])\n",
    "    if not search_term:\n",
    "        return gr.update(choices=all_names, value=all_names)\n",
    "    lower = search_term.lower()\n",
    "    filtered = [n for n in all_names if lower in str(n).lower()]\n",
    "    return gr.update(choices=filtered, value=filtered)\n",
    "\n",
    "\n",
    "def check_all(names_list, search_term=\"\"):\n",
    "    \"\"\"Select every name visible under the current search filter.\n",
    "\n",
    "    `choices` is reset together with `value` so the selection never includes\n",
    "    names the filter has hidden (Gradio may reject values outside `choices`).\n",
    "    \"\"\"\n",
    "    return filter_names(search_term, names_list)\n",
    "\n",
    "\n",
    "def uncheck_all(_):\n",
    "    \"\"\"Clear the selection.\"\"\"\n",
    "    return gr.update(value=[])\n",
    "\n",
    "\n",
    "def display_filtered_forecast(file, preset_filename, selected_names):\n",
    "    \"\"\"Render one forecast plot per selected series.\n",
    "\n",
    "    Returns (gallery_images, message): a list of H x W x 3 image arrays and an\n",
    "    error/status string (\"\" on success).\n",
    "    \"\"\"\n",
    "    try:\n",
    "        if file is not None:\n",
    "            df = load_table(file.name)\n",
    "        else:\n",
    "            if not preset_filename:\n",
    "                return [], \"No file selected.\"\n",
    "            df = load_table(preset_filename)\n",
    "\n",
    "        all_names, values = split_names(df)\n",
    "        data_only = values.astype(float)\n",
    "\n",
    "        selected = set(selected_names or [])\n",
    "        mask = [name in selected for name in all_names]\n",
    "        if not any(mask):\n",
    "            return [], \"No timeseries chosen to plot.\"\n",
    "\n",
    "        # model_forecast() ignores its input and returns the precomputed tensor\n",
    "        # (one row per series), so its rows must line up with the table's rows.\n",
    "        forecasts = model_forecast(data_only.values)\n",
    "        if forecasts.shape[0] != len(all_names):\n",
    "            return [], (\n",
    "                f\"Error: the table has {len(all_names)} series, but the precomputed \"\n",
    "                f\"forecast covers {forecasts.shape[0]}.\"\n",
    "            )\n",
    "\n",
    "        filtered_data = data_only.iloc[mask, :].values\n",
    "        filtered_names = [name for name, keep in zip(all_names, mask) if keep]\n",
    "        out = forecasts[torch.tensor(mask, dtype=torch.bool)]  # forecasts for the selected rows\n",
    "        inp = torch.tensor(filtered_data)\n",
    "\n",
    "        gallery_images = [\n",
    "            plot_forecast_image(series, quantiles, name)\n",
    "            for series, quantiles, name in zip(inp, out, filtered_names)\n",
    "        ]\n",
    "        return gallery_images, \"\"\n",
    "    except Exception as e:\n",
    "        return [], f\"Error: {e}. Use CSV, XLS, XLSX, or PARQUET.\"\n",
    "\n",
    "\n",
    "# ----------------------------\n",
    "# Gradio layout: two columns + instructions\n",
    "# ----------------------------\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(\"# 📈 Stock Forecast Viewer 📊\")\n",
    "    gr.Markdown(\"Upload data or choose a preset, filter by name, then click Plot.\")\n",
    "\n",
    "    with gr.Row():\n",
    "        # Left column: controls\n",
    "        with gr.Column():\n",
    "            gr.Markdown(\"## Data Selection\")\n",
    "            gr.Markdown(\"*If you haven't prepared the data, the preset file will be used.*\")\n",
    "            file_input = gr.File(\n",
    "                label=\"Upload CSV / XLSX / PARQUET\",\n",
    "                file_types=[\".csv\", \".xls\", \".xlsx\", \".parquet\"]\n",
    "            )\n",
    "            preset_dropdown = gr.Dropdown(\n",
    "                label=\"Or choose a preset:\",\n",
    "                choices=[\"stocks_data_noindex.csv\", \"stocks_data.csv\"],\n",
    "                value=\"stocks_data_noindex.csv\"\n",
    "            )\n",
    "\n",
    "            gr.Markdown(\"## Search / Filter\")\n",
    "            search_box = gr.Textbox(placeholder=\"Type to filter (e.g. 'AMZN')\")\n",
    "            filter_checkbox = gr.CheckboxGroup(\n",
    "                choices=[], value=[], label=\"Select which timeseries to show\"\n",
    "            )\n",
    "\n",
    "            with gr.Row():\n",
    "                check_all_btn = gr.Button(\"✅ Check All\")\n",
    "                uncheck_all_btn = gr.Button(\"❎ Uncheck All\")\n",
    "\n",
    "            plot_button = gr.Button(\"▶️ Plot Forecasts\")\n",
    "            errbox = gr.Textbox(label=\"Error Message\", interactive=False)\n",
    "\n",
    "        # Right column: gallery + instructions\n",
    "        with gr.Column():\n",
    "            gr.Markdown(\"## Forecast Gallery\")\n",
    "            gallery = gr.Gallery()\n",
    "\n",
    "            # Instruction text below gallery\n",
    "            gr.Markdown(\n",
    "                \"\"\"\n",
    "                **How to format your data:**\n",
    "                - Your file must be a table (CSV, XLS, XLSX, or Parquet).\n",
    "                - **One row per timeseries.** Each row is treated as a separate series.\n",
    "                - The first row of a CSV/Excel file is read as column headers, not as a series.\n",
    "                - If you want to **name** each series, put the name as the first value in **every** row:\n",
    "                    - Example (CSV):  \n",
    "                      `ticker, day1, day2, day3, ...`  \n",
    "                      `AAPL, 120.5, 121.0, 119.8, ...`  \n",
    "                      `AMZN, 3300.0, 3310.5, 3295.2, ...`  \n",
    "                    - In that case, the first column is not numeric, so it will be used as the series name.\n",
    "                - If you do **not** want named series, simply leave out the first column entirely and have all values numeric:\n",
    "                    - Example:  \n",
    "                      `day1, day2, day3, ...`  \n",
    "                      `120.5, 121.0, 119.8, ...`  \n",
    "                      `3300.0, 3310.5, 3295.2, ...`  \n",
    "                    - Then every row will be auto-named “Series 0, Series 1, …” in order.\n",
    "                - **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.\n",
    "                - The rest of the columns (after the optional name) must be numeric data points for that series.\n",
    "                - You can filter by typing in the search box. Then check or uncheck individual names before plotting.\n",
    "                - Use “Check All” / “Uncheck All” to quickly select or deselect every visible series.\n",
    "                - Finally, click **Plot Forecasts** to view the quantile forecast for each selected series.\n",
    "                \"\"\"\n",
    "            )\n",
    "\n",
    "    names_state = gr.State([])\n",
    "\n",
    "    # When file or preset changes, update names\n",
    "    file_input.change(\n",
    "        fn=extract_names_and_update,\n",
    "        inputs=[file_input, preset_dropdown],\n",
    "        outputs=[filter_checkbox, names_state]\n",
    "    )\n",
    "    preset_dropdown.change(\n",
    "        fn=extract_names_and_update,\n",
    "        inputs=[file_input, preset_dropdown],\n",
    "        outputs=[filter_checkbox, names_state]\n",
    "    )\n",
    "    # `.change` does not fire for a component's initial value, so populate the\n",
    "    # series list for the default preset when the page loads.\n",
    "    demo.load(\n",
    "        fn=extract_names_and_update,\n",
    "        inputs=[file_input, preset_dropdown],\n",
    "        outputs=[filter_checkbox, names_state]\n",
    "    )\n",
    "\n",
    "    # When search term changes, filter names\n",
    "    search_box.change(\n",
    "        fn=filter_names,\n",
    "        inputs=[search_box, names_state],\n",
    "        outputs=filter_checkbox\n",
    "    )\n",
    "\n",
    "    # Check All selects what the current search shows; Uncheck All clears the selection\n",
    "    check_all_btn.click(fn=check_all, inputs=[names_state, search_box], outputs=filter_checkbox)\n",
    "    uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)\n",
    "\n",
    "    # Plot button\n",
    "    plot_button.click(\n",
    "        fn=display_filtered_forecast,\n",
    "        inputs=[file_input, preset_dropdown, filter_checkbox],\n",
    "        outputs=[gallery, errbox]\n",
    "    )\n",
    "\n",
    "demo.launch()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Variant: no preset selected by default"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import pandas as pd\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import gradio as gr\n",
    "\n",
    "# ----------------------------\n",
    "# Helper functions\n",
    "# ----------------------------\n",
    "\n",
    "torch.manual_seed(42)\n",
    "_forecast_tensor = torch.load(\"stocks_data_forecast.pt\")  # shape = (n_series, pred_len, n_q)\n",
    "\n",
    "\n",
    "def model_forecast(input_data):\n",
    "    \"\"\"Stand-in for a forecasting model: returns the precomputed tensor, ignoring `input_data`.\"\"\"\n",
    "    return _forecast_tensor\n",
    "\n",
    "\n",
    "def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):\n",
    "    \"\"\"Plot a series and its quantile forecasts; return the figure as an RGB (H x W x 3) array.\"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(10, 6), dpi=300)\n",
    "\n",
    "    # Plot the original timeseries with thicker line and marker\n",
    "    ax.plot(timeseries, color=\"blue\", linewidth=2.5, marker='o', label=\"Given Data\")\n",
    "\n",
    "    # NOTE(review): the forecast starts at index len - 1, sharing an x position with\n",
    "    # the last observation; confirm this overlap is intended.\n",
    "    x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))\n",
    "    # One color per quantile; slight transparency keeps overlapping lines visible\n",
    "    for i in range(quantile_predictions.shape[1]):\n",
    "        ax.plot(x_pred, quantile_predictions[:, i], color=f\"C{i}\", linewidth=2, alpha=0.8, label=f\"Quantile {i+1}\")\n",
    "\n",
    "    ax.set_title(f\"Timeseries: {timeseries_name}\", fontsize=16, fontweight='bold')\n",
    "    ax.set_xlabel(\"Time\", fontsize=12)\n",
    "    ax.set_ylabel(\"Value\", fontsize=12)\n",
    "\n",
    "    ax.grid(True, which='both', linestyle='--', linewidth=0.7, alpha=0.6)\n",
    "    ax.legend(loc=\"center left\", bbox_to_anchor=(1, 0.5), fontsize=10, frameon=True, shadow=True)\n",
    "\n",
    "    fig.tight_layout(rect=[0, 0, 0.82, 1])\n",
    "\n",
    "    buf = io.BytesIO()\n",
    "    # Opaque background on purpose: .convert(\"RGB\") below drops the alpha channel,\n",
    "    # so transparency would be discarded, and discarding alpha without compositing\n",
    "    # can leave anti-aliased edges looking jagged.\n",
    "    fig.savefig(buf, format=\"png\", bbox_inches=\"tight\")\n",
    "    plt.close(fig)\n",
    "    buf.seek(0)\n",
    "    img = Image.open(buf).convert(\"RGB\")\n",
    "    return np.array(img)\n",
    "\n",
    "\n",
    "def load_table(file_path):\n",
    "    \"\"\"Read a CSV, XLS/XLSX or Parquet file (chosen by extension) into a DataFrame.\n",
    "\n",
    "    pandas' default header handling applies: the first row of a CSV/Excel file\n",
    "    becomes the column labels rather than data.\n",
    "    \"\"\"\n",
    "    ext = file_path.split(\".\")[-1].lower()\n",
    "    if ext == \"csv\":\n",
    "        return pd.read_csv(file_path)\n",
    "    elif ext in (\"xls\", \"xlsx\"):\n",
    "        return pd.read_excel(file_path)\n",
    "    elif ext == \"parquet\":\n",
    "        return pd.read_parquet(file_path)\n",
    "    else:\n",
    "        raise ValueError(\"Unsupported format. Use CSV, XLS, XLSX, or PARQUET.\")\n",
    "\n",
    "\n",
    "def split_names(df):\n",
    "    \"\"\"Split a one-row-per-series table into (names, values).\n",
    "\n",
    "    If the first column has object dtype and is not entirely numeric strings\n",
    "    (per `str.isnumeric`), it supplies the series names and is dropped from the\n",
    "    values; otherwise the rows are auto-named \"Series 0\", \"Series 1\", ...\n",
    "    \"\"\"\n",
    "    if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():\n",
    "        return df.iloc[:, 0].tolist(), df.iloc[:, 1:]\n",
    "    return [f\"Series {i}\" for i in range(len(df))], df\n",
    "\n",
    "\n",
    "def extract_names_and_update(file, preset_filename):\n",
    "    \"\"\"Load the uploaded file (else the preset) and refresh the series checkboxes.\n",
    "\n",
    "    Returns a CheckboxGroup update with every name selected, plus the full name\n",
    "    list for `names_state`. Any load/parse failure yields an empty list.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        if file is not None:\n",
    "            df = load_table(file.name)\n",
    "        else:\n",
    "            if not preset_filename:\n",
    "                return gr.update(choices=[], value=[]), []\n",
    "            df = load_table(preset_filename)\n",
    "\n",
    "        names, _ = split_names(df)\n",
    "        return gr.update(choices=names, value=names), names\n",
    "    except Exception:\n",
    "        # Best effort: an unreadable file just leaves the selection empty.\n",
    "        return gr.update(choices=[], value=[]), []\n",
    "\n",
    "\n",
    "def filter_names(search_term, all_names):\n",
    "    \"\"\"Show and select only the names containing `search_term` (case-insensitive).\"\"\"\n",
    "    if not all_names:\n",
    "        return gr.update(choices=[], value=[])\n",
    "    if not search_term:\n",
    "        return gr.update(choices=all_names, value=all_names)\n",
    "    lower = search_term.lower()\n",
    "    filtered = [n for n in all_names if lower in str(n).lower()]\n",
    "    return gr.update(choices=filtered, value=filtered)\n",
    "\n",
    "\n",
    "def check_all(names_list, search_term=\"\"):\n",
    "    \"\"\"Select every name visible under the current search filter.\n",
    "\n",
    "    `choices` is reset together with `value` so the selection never includes\n",
    "    names the filter has hidden (Gradio may reject values outside `choices`).\n",
    "    \"\"\"\n",
    "    return filter_names(search_term, names_list)\n",
    "\n",
    "\n",
    "def uncheck_all(_):\n",
    "    \"\"\"Clear the selection.\"\"\"\n",
    "    return gr.update(value=[])\n",
    "\n",
    "\n",
    "def display_filtered_forecast(file, preset_filename, selected_names):\n",
    "    \"\"\"Render one forecast plot per selected series.\n",
    "\n",
    "    Returns (gallery_images, message): a list of H x W x 3 image arrays and an\n",
    "    error/status string (\"\" on success).\n",
    "    \"\"\"\n",
    "    try:\n",
    "        if file is not None:\n",
    "            df = load_table(file.name)\n",
    "        else:\n",
    "            if not preset_filename:\n",
    "                return [], \"No file selected.\"\n",
    "            df = load_table(preset_filename)\n",
    "\n",
    "        all_names, values = split_names(df)\n",
    "        data_only = values.astype(float)\n",
    "\n",
    "        selected = set(selected_names or [])\n",
    "        mask = [name in selected for name in all_names]\n",
    "        if not any(mask):\n",
    "            return [], \"No timeseries chosen to plot.\"\n",
    "\n",
    "        # model_forecast() ignores its input and returns the precomputed tensor\n",
    "        # (one row per series), so its rows must line up with the table's rows.\n",
    "        forecasts = model_forecast(data_only.values)\n",
    "        if forecasts.shape[0] != len(all_names):\n",
    "            return [], (\n",
    "                f\"Error: the table has {len(all_names)} series, but the precomputed \"\n",
    "                f\"forecast covers {forecasts.shape[0]}.\"\n",
    "            )\n",
    "\n",
    "        filtered_data = data_only.iloc[mask, :].values\n",
    "        filtered_names = [name for name, keep in zip(all_names, mask) if keep]\n",
    "        out = forecasts[torch.tensor(mask, dtype=torch.bool)]  # forecasts for the selected rows\n",
    "        inp = torch.tensor(filtered_data)\n",
    "\n",
    "        gallery_images = [\n",
    "            plot_forecast_image(series, quantiles, name)\n",
    "            for series, quantiles, name in zip(inp, out, filtered_names)\n",
    "        ]\n",
    "        return gallery_images, \"\"\n",
    "    except Exception as e:\n",
    "        return [], f\"Error: {e}. Use CSV, XLS, XLSX, or PARQUET.\"\n",
    "\n",
    "\n",
    "# ----------------------------\n",
    "# Gradio layout: two columns + instructions\n",
    "# ----------------------------\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(\"# 📈 Stock Forecast Viewer 📊\")\n",
    "    gr.Markdown(\"Upload data or choose a preset, filter by name, then click Plot.\")\n",
    "\n",
    "    with gr.Row():\n",
    "        # Left column: controls\n",
    "        with gr.Column():\n",
    "            gr.Markdown(\"## Data Selection\")\n",
    "            file_input = gr.File(\n",
    "                label=\"Upload CSV / XLSX / PARQUET\",\n",
    "                file_types=[\".csv\", \".xls\", \".xlsx\", \".parquet\"]\n",
    "            )\n",
    "            preset_dropdown = gr.Dropdown(\n",
    "                label=\"Or choose a preset:\",\n",
    "                choices=[\"stocks_data_noindex.csv\", \"stocks_data.csv\"],\n",
    "                # None = nothing selected. A placeholder string that is not one of the\n",
    "                # choices can be rejected by Gradio's choice validation when this\n",
    "                # dropdown is passed as an event input.\n",
    "                value=None\n",
    "            )\n",
    "\n",
    "            gr.Markdown(\"## Search / Filter\")\n",
    "            search_box = gr.Textbox(placeholder=\"Type to filter (e.g. 'AMZN')\")\n",
    "            filter_checkbox = gr.CheckboxGroup(\n",
    "                choices=[], value=[], label=\"Select which timeseries to show\"\n",
    "            )\n",
    "\n",
    "            with gr.Row():\n",
    "                check_all_btn = gr.Button(\"✅ Check All\")\n",
    "                uncheck_all_btn = gr.Button(\"❎ Uncheck All\")\n",
    "\n",
    "            plot_button = gr.Button(\"▶️ Plot Forecasts\")\n",
    "            errbox = gr.Textbox(label=\"Error Message\", interactive=False)\n",
    "\n",
    "        # Right column: gallery + instructions\n",
    "        with gr.Column():\n",
    "            gr.Markdown(\"## Forecast Gallery\")\n",
    "            gallery = gr.Gallery()\n",
    "\n",
    "            # Instruction text below gallery\n",
    "            gr.Markdown(\"## Instructions\")\n",
    "            gr.Markdown(\n",
    "                \"\"\"\n",
    "                **How to format your data:**\n",
    "                - Your file must be a table (CSV, XLS, XLSX, or Parquet).\n",
    "                - **One row per timeseries.** Each row is treated as a separate series.\n",
    "                - The first row of a CSV/Excel file is read as column headers, not as a series.\n",
    "                - If you want to **name** each series, put the name as the first value in **every** row:\n",
    "                    - Example (CSV):  \n",
    "                      `ticker, day1, day2, day3, ...`  \n",
    "                      `AAPL, 120.5, 121.0, 119.8, ...`  \n",
    "                      `AMZN, 3300.0, 3310.5, 3295.2, ...`  \n",
    "                    - In that case, the first column is not numeric, so it will be used as the series name.\n",
    "                - If you do **not** want named series, simply leave out the first column entirely and have all values numeric:\n",
    "                    - Example:  \n",
    "                      `day1, day2, day3, ...`  \n",
    "                      `120.5, 121.0, 119.8, ...`  \n",
    "                      `3300.0, 3310.5, 3295.2, ...`  \n",
    "                    - Then every row will be auto-named “Series 0, Series 1, …” in order.\n",
    "                - **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.\n",
    "                - The rest of the columns (after the optional name) must be numeric data points for that series.\n",
    "                - You can filter by typing in the search box. Then check or uncheck individual names before plotting.\n",
    "                - Use “Check All” / “Uncheck All” to quickly select or deselect every visible series.\n",
    "                - Finally, click **Plot Forecasts** to view the quantile forecast for each selected series.\n",
    "                \"\"\"\n",
    "            )\n",
    "\n",
    "    names_state = gr.State([])\n",
    "\n",
    "    # When file or preset changes, update names\n",
    "    file_input.change(\n",
    "        fn=extract_names_and_update,\n",
    "        inputs=[file_input, preset_dropdown],\n",
    "        outputs=[filter_checkbox, names_state]\n",
    "    )\n",
    "    preset_dropdown.change(\n",
    "        fn=extract_names_and_update,\n",
    "        inputs=[file_input, preset_dropdown],\n",
    "        outputs=[filter_checkbox, names_state]\n",
    "    )\n",
    "\n",
    "    # When search term changes, filter names\n",
    "    search_box.change(\n",
    "        fn=filter_names,\n",
    "        inputs=[search_box, names_state],\n",
    "        outputs=filter_checkbox\n",
    "    )\n",
    "\n",
    "    # Check All selects what the current search shows; Uncheck All clears the selection\n",
    "    check_all_btn.click(fn=check_all, inputs=[names_state, search_box], outputs=filter_checkbox)\n",
    "    uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)\n",
    "\n",
    "    # Plot button\n",
    "    plot_button.click(\n",
    "        fn=display_filtered_forecast,\n",
    "        inputs=[file_input, preset_dropdown, filter_checkbox],\n",
    "        outputs=[gallery, errbox]\n",
    "    )\n",
    "\n",
    "demo.launch()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}