{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[*********************100%***********************] 10 of 10 completed\n"
     ]
    }
   ],
   "source": [
    "import yfinance as yf\n",
    "import pandas as pd\n",
    "\n",
    "# List of 10 example tickers (you can replace these with any tickers you prefer)\n",
    "tickers = [\"AAPL\", \"MSFT\", \"GOOGL\", \"AMZN\", \"META\", \"TSLA\", \"NVDA\", \"JPM\", \"JNJ\", \"V\"]\n",
    "\n",
    "# Download daily adjusted close prices for the last 6 months (period=\"6mo\")\n",
    "data = yf.download(tickers, period=\"6mo\", interval=\"1d\", auto_adjust=False)[\"Adj Close\"]\n",
    "\n",
    "# Transpose so that each row is one ticker's 6-month timeseries\n",
    "df = data.T\n",
    "\n",
    "# At this point, `df` has:\n",
    "# • Index: the 10 tickers\n",
    "# • Columns: one column per trading day in the last 6 months\n",
    "# • Values: adjusted close price for that ticker on that date\n",
    "# index=False drops the ticker names, so this CSV contains only the price values.\n",
    "df.to_csv(\"stocks_data_noindex.csv\", index=False)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stable version\n",
    "\n",
    "Gradio app that loads a table (CSV, XLS, XLSX or Parquet; one timeseries per row) and plots each row together with the precomputed quantile forecasts loaded from `stocks_data_forecast.pt`, shown as a gallery."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import gradio as gr\n",
    "import pandas as pd\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "\n",
    "torch.manual_seed(42)\n",
    "output = torch.load(\"stocks_data_forecast.pt\")  # (n_timeseries, pred_len, n_quantiles)\n",
    "\n",
    "\n",
    "def model_forecast(input_data):\n",
    "    \"\"\"Stand-in for the forecasting model.\n",
    "\n",
    "    `input_data` is ignored and the precomputed `output` tensor is returned\n",
    "    unchanged, so the forecasts only fit the dataset they were generated for\n",
    "    (presumably the preset stock CSV -- TODO confirm).\n",
    "    \"\"\"\n",
    "    return output\n",
    "\n",
    "\n",
    "def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):\n",
    "    \"\"\"Plot one timeseries followed by its quantile forecasts.\n",
    "\n",
    "    Args:\n",
    "        timeseries: 1-D sequence of observed values, drawn at x = 0 .. T-1.\n",
    "        quantile_predictions: 2-D array/tensor of shape (pred_len, n_quantiles);\n",
    "            column i is drawn as its own forecast line.\n",
    "        timeseries_name: Name shown in the plot title.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray: the rendered figure as an H×W×3 RGB array.\n",
    "    \"\"\"\n",
    "    n_history = len(timeseries)\n",
    "    n_steps, n_quantiles = quantile_predictions.shape[0], quantile_predictions.shape[1]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)\n",
    "    ax.plot(timeseries, color=\"blue\", label=\"History\")\n",
    "    # Assumes forecast step 1 predicts the value right after the last\n",
    "    # observation, i.e. x = T (x = T-1 is already an observed point).\n",
    "    x_pred = range(n_history, n_history + n_steps)\n",
    "    for i in range(n_quantiles):\n",
    "        # Start the colour cycle at C1 so no quantile shares the history's blue.\n",
    "        ax.plot(x_pred, quantile_predictions[:, i], color=f\"C{i + 1}\", label=f\"Quantile {i + 1}\")\n",
    "\n",
    "    ax.set_title(f\"Timeseries: {timeseries_name}\")\n",
    "    # Every line carries its own label, so each legend entry names the right line.\n",
    "    ax.legend(loc=\"center left\", bbox_to_anchor=(1, 0.5))\n",
    "    fig.tight_layout(rect=[0, 0, 0.85, 1])\n",
    "\n",
    "    buf = io.BytesIO()\n",
    "    fig.savefig(buf, format=\"png\", bbox_inches=\"tight\")\n",
    "    plt.close(fig)  # free the figure; one is created per gallery image\n",
    "    buf.seek(0)\n",
    "    img = Image.open(buf).convert(\"RGB\")\n",
    "    return np.array(img)\n",
    "\n",
    "\n",
    "def display_forecast(file, preset_filename):\n",
    "    \"\"\"Gradio callback: load a table and return one forecast plot per row.\n",
    "\n",
    "    Args:\n",
    "        file: Upload from `gr.File`, or None. Older Gradio passes a file object\n",
    "            with a `.name` path; newer Gradio passes the path string itself.\n",
    "        preset_filename: Path chosen in the preset dropdown; used only when\n",
    "            nothing was uploaded.\n",
    "\n",
    "    Returns:\n",
    "        (gallery_images, message): a list of H×W×3 arrays and the text for the\n",
    "        message box (\"\" when every row was plotted).\n",
    "    \"\"\"\n",
    "    accepted_formats = [\"csv\", \"xls\", \"xlsx\", \"parquet\"]\n",
    "\n",
    "    def load_table(file_path):\n",
    "        \"\"\"Read `file_path` with the pandas reader matching its extension.\"\"\"\n",
    "        ext = str(file_path).rsplit(\".\", 1)[-1].lower()\n",
    "        if ext not in accepted_formats:\n",
    "            raise ValueError(f\"Unsupported file format '.{ext}'. Acceptable formats: CSV, XLS, XLSX, PARQUET.\")\n",
    "        if ext == \"csv\":\n",
    "            return pd.read_csv(file_path)\n",
    "        if ext == \"parquet\":\n",
    "            return pd.read_parquet(file_path)\n",
    "        return pd.read_excel(file_path)  # xls / xlsx\n",
    "\n",
    "    try:\n",
    "        if file is not None:\n",
    "            # getattr covers both a file object (.name) and a plain path string.\n",
    "            df = load_table(getattr(file, \"name\", file))\n",
    "        elif preset_filename:\n",
    "            df = load_table(preset_filename)\n",
    "        else:\n",
    "            return [], \"Please upload a file or select a preset.\"\n",
    "\n",
    "        # A text first column that is not purely numeric holds the series names\n",
    "        # (e.g. a \"Ticker\" column); strip it so only values go into the tensor.\n",
    "        first_col = df.iloc[:, 0] if df.shape[1] > 0 else None\n",
    "        if (\n",
    "            first_col is not None\n",
    "            and first_col.dtype == object\n",
    "            and not first_col.str.isnumeric().all()\n",
    "        ):\n",
    "            timeseries_names = first_col.tolist()\n",
    "            df = df.iloc[:, 1:]\n",
    "        else:\n",
    "            timeseries_names = [f\"Series {i}\" for i in range(len(df))]\n",
    "\n",
    "        # dtype=float also converts numeric strings left in an object column,\n",
    "        # which torch.tensor cannot take directly.\n",
    "        _input = torch.tensor(df.to_numpy(dtype=float))\n",
    "        _output = model_forecast(_input)\n",
    "\n",
    "        # The forecasts are precomputed, so there may be fewer of them than\n",
    "        # input rows; indexing past the end would raise IndexError.\n",
    "        n_rows = _input.shape[0]\n",
    "        n_plots = min(n_rows, len(_output))\n",
    "        gallery_images = [\n",
    "            plot_forecast_image(_input[i], _output[i], timeseries_names[i])\n",
    "            for i in range(n_plots)\n",
    "        ]\n",
    "\n",
    "        message = \"\"\n",
    "        if n_plots < n_rows:\n",
    "            message = f\"Only the first {n_plots} of {n_rows} rows have forecasts; the remaining rows were skipped.\"\n",
    "        return gallery_images, message\n",
    "    except Exception as e:\n",
    "        return [], f\"Error: {e}. Please upload files in one of the following formats: CSV, XLS, XLSX, PARQUET.\"\n",
    "\n",
    "\n",
    "iface = gr.Interface(\n",
    "    fn=display_forecast,\n",
    "    inputs=[\n",
    "        gr.File(label=\"Upload your CSV file (optional)\"),\n",
    "        gr.Dropdown(\n",
    "            label=\"Or select a preset CSV file\",\n",
    "            choices=[\"stocks_data_noindex.csv\", \"stocks_data.csv\"],\n",
    "            value=\"stocks_data_noindex.csv\",\n",
    "        ),\n",
    "    ],\n",
    "    outputs=[\n",
    "        gr.Gallery(label=\"Forecast Plots (one per row)\"),\n",
    "        gr.Textbox(label=\"Error Message\"),\n",
    "    ],\n",
    "    title=\"CSV→Dynamic Forecast Gallery\",\n",
    "    description=\"Upload a CSV with any number of rows; each row’s forecast becomes one image in a gallery.\",\n",
    "    # NOTE(review): newer Gradio releases replace this with `flagging_mode`; confirm for the installed version.\n",
    "    allow_flagging=\"never\",\n",
    ")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    iface.launch()\n",
    "\n",
    "# TODO:\n",
    "# 1. Prepared datasets\n",
    "# 2. Plots of different quantiles (different colors)\n",
    "# 3. Filters for plots...\n",
    "# 4. Different input options\n",
    "# 5. README.md in there (in UI) (contact us for fine-tuning)\n",
    "# 6. Requirements for dimensions\n",
    "# 7. Multivariate data (x_t is vector)\n",
    "# 8. LOGO of NX-AI and xLSTM and tirex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "
| \n", " | Ticker | \n", "2024-12-04 | \n", "2024-12-05 | \n", "2024-12-06 | \n", "2024-12-09 | \n", "2024-12-10 | \n", "2024-12-11 | \n", "2024-12-12 | \n", "2024-12-13 | \n", "2024-12-16 | \n", "... | \n", "2025-05-21 | \n", "2025-05-22 | \n", "2025-05-23 | \n", "2025-05-27 | \n", "2025-05-28 | \n", "2025-05-29 | \n", "2025-05-30 | \n", "2025-06-02 | \n", "2025-06-03 | \n", "2025-06-04 | \n", "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "AAPL | \n", "242.425201 | \n", "242.455124 | \n", "242.255600 | \n", "246.156204 | \n", "247.173752 | \n", "245.896820 | \n", "247.363297 | \n", "247.532883 | \n", "250.435867 | \n", "... | \n", "202.089996 | \n", "201.360001 | \n", "195.270004 | \n", "200.210007 | \n", "200.419998 | \n", "199.949997 | \n", "200.850006 | \n", "201.699997 | \n", "203.270004 | \n", "203.150101 | \n", "
| 1 | \n", "AMZN | \n", "218.160004 | \n", "220.550003 | \n", "227.029999 | \n", "226.089996 | \n", "225.039993 | \n", "230.259995 | \n", "228.970001 | \n", "227.460007 | \n", "232.929993 | \n", "... | \n", "201.119995 | \n", "203.100006 | \n", "200.990005 | \n", "206.020004 | \n", "204.720001 | \n", "205.699997 | \n", "205.009995 | \n", "206.649994 | \n", "205.710007 | \n", "207.472000 | \n", "
| 2 | \n", "GOOGL | \n", "173.970016 | \n", "172.244003 | \n", "174.309265 | \n", "175.168259 | \n", "184.956985 | \n", "195.175217 | \n", "191.739182 | \n", "189.601639 | \n", "196.433777 | \n", "... | \n", "168.559998 | \n", "170.869995 | \n", "168.470001 | \n", "172.899994 | \n", "172.360001 | \n", "171.860001 | \n", "171.740005 | \n", "169.029999 | \n", "166.179993 | \n", "167.785004 | \n", "
| 3 | \n", "JNJ | \n", "148.006256 | \n", "147.071823 | \n", "146.865265 | \n", "147.150513 | \n", "146.786560 | \n", "144.238968 | \n", "143.845535 | \n", "144.219299 | \n", "141.494659 | \n", "... | \n", "151.877960 | \n", "151.312805 | \n", "151.639999 | \n", "153.250000 | \n", "152.429993 | \n", "153.580002 | \n", "155.210007 | \n", "155.399994 | \n", "154.419998 | \n", "153.419998 | \n", "
| 4 | \n", "JPM | \n", "240.666992 | \n", "242.723633 | \n", "244.582520 | \n", "241.072388 | \n", "240.133057 | \n", "240.795532 | \n", "238.817978 | \n", "237.245850 | \n", "236.889893 | \n", "... | \n", "261.040009 | \n", "260.670013 | \n", "260.709991 | \n", "265.290009 | \n", "263.489990 | \n", "264.369995 | \n", "264.000000 | \n", "264.660004 | \n", "266.269989 | \n", "265.065002 | \n", "
| 5 | \n", "META | \n", "612.740173 | \n", "607.898376 | \n", "622.713257 | \n", "612.530518 | \n", "618.270813 | \n", "631.608154 | \n", "629.721375 | \n", "619.299072 | \n", "623.685120 | \n", "... | \n", "635.500000 | \n", "636.570007 | \n", "627.059998 | \n", "642.320007 | \n", "643.580017 | \n", "645.049988 | \n", "647.489990 | \n", "670.900024 | \n", "666.849976 | \n", "685.159973 | \n", "
| 6 | \n", "MSFT | \n", "435.744720 | \n", "440.924805 | \n", "441.871155 | \n", "444.311768 | \n", "441.632050 | \n", "447.270416 | \n", "447.838196 | \n", "445.556976 | \n", "449.860413 | \n", "... | \n", "452.570007 | \n", "454.859985 | \n", "450.179993 | \n", "460.690002 | \n", "457.359985 | \n", "458.679993 | \n", "460.359985 | \n", "461.970001 | \n", "462.970001 | \n", "464.190002 | \n", "
| 7 | \n", "NVDA | \n", "145.116653 | \n", "145.046661 | \n", "142.426895 | \n", "138.797226 | \n", "135.057587 | \n", "139.297180 | \n", "137.327362 | \n", "134.237656 | \n", "131.987854 | \n", "... | \n", "131.800003 | \n", "132.830002 | \n", "131.289993 | \n", "135.500000 | \n", "134.809998 | \n", "139.190002 | \n", "135.130005 | \n", "137.380005 | \n", "141.220001 | \n", "141.854996 | \n", "
| 8 | \n", "TSLA | \n", "357.929993 | \n", "369.489990 | \n", "389.220001 | \n", "389.790009 | \n", "400.989990 | \n", "424.769989 | \n", "418.100006 | \n", "436.230011 | \n", "463.019989 | \n", "... | \n", "334.619995 | \n", "341.040009 | \n", "339.339996 | \n", "362.890015 | \n", "356.899994 | \n", "358.429993 | \n", "346.459991 | \n", "342.690002 | \n", "344.269989 | \n", "334.671600 | \n", "
| 9 | \n", "V | \n", "308.866455 | \n", "308.049194 | \n", "309.972778 | \n", "307.271790 | \n", "311.338196 | \n", "312.743500 | \n", "313.182037 | \n", "313.690308 | \n", "314.836517 | \n", "... | \n", "358.299988 | \n", "357.970001 | \n", "353.540009 | \n", "359.299988 | \n", "359.730011 | \n", "362.399994 | \n", "365.190002 | \n", "365.320007 | \n", "365.859985 | \n", "368.179993 | \n", "
10 rows × 125 columns
\n", "