In [None]:
import yfinance as yf
import pandas as pd

# List of 10 example tickers (you can replace these with any tickers you prefer)
tickers = ["AAPL", "MSFT", "GOOGL", "AMZN", "META", "TSLA", "NVDA", "JPM", "JNJ", "V"]

# Download daily adjusted close prices for the last six months (period="6mo")
data = yf.download(tickers, period="6mo", interval="1d", auto_adjust=False)["Adj Close"]

# Transpose so that each row is one company's timeseries
df = pd.DataFrame(data.transpose())

# At this point, `df` has:
# • Index: the 10 tickers
# • Columns: one column per trading day in the downloaded window
# • Values: adjusted close price for that ticker on that date
# NOTE: index=False drops the ticker index, so the CSV holds only the date
# header and the price rows (hence the "_noindex" file name).
df.to_csv("stocks_data_noindex.csv", index=False)


[*********************100%***********************]  10 of 10 completed


# Stable version

In [None]:
import io
import gradio as gr
import pandas as pd
import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

torch.manual_seed(42)
# map_location="cpu": a tensor saved from a GPU session must still load (and be
# plottable by matplotlib) on a CPU-only machine.
output = torch.load("stocks_data_forecast.pt", map_location="cpu")  # (n_timeseries, pred_len, n_quantiles)

def model_forecast(input_data):
    """Placeholder model: return the precomputed forecast tensor.

    `input_data` is ignored; the same module-level tensor loaded from
    "stocks_data_forecast.pt" is returned for every call.
    """
    return output

def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):
    """Plot one series with its quantile forecasts and return the figure as an array.

    Args:
        timeseries: 1D sequence of observed values.
        quantile_predictions: 2D array/tensor indexed as [step, quantile].
        timeseries_name: text used in the plot title.

    Returns:
        NumPy array (H×W×3, RGB) of the rendered PNG.
    """
    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)
    try:
        ax.plot(timeseries, color="blue")
        # The forecast is drawn starting on the last observed index so the
        # curves visually connect to the history.
        # NOTE(review): confirm step 0 of the forecast is not meant to sit at len(timeseries).
        x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))
        quantile_lines = []
        for i in range(quantile_predictions.shape[1]):
            (line,) = ax.plot(
                x_pred,
                quantile_predictions[:, i],
                color=f"C{i}",
                label=f"Quantile {i+1}",
            )
            quantile_lines.append(line)

        # Add title
        ax.set_title(f"Timeseries: {timeseries_name}")
        # Build the legend from the quantile line handles only. Passing a bare
        # label list would attach "Quantile 1" to the history line (the first
        # artist on the axes) and shift every other label by one.
        ax.legend(handles=quantile_lines, loc='center left', bbox_to_anchor=(1, 0.5))
        fig.tight_layout(rect=[0, 0, 0.85, 1])

        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf).convert("RGB")
    return np.array(img)  # Return as an H×W×3 array

def display_forecast(file, preset_filename):
    """Load a table, forecast every row and return one plot image per row.

    Args:
        file: uploaded file from the Gradio File component (an object with
            `.name` or a plain path string), or None.
        preset_filename: path of the preset table used when no file is uploaded.

    Returns:
        (gallery_images, error_message): a list of image arrays and an error
        string that is empty on success.
    """

    def load_table(file_path):
        """Read CSV / XLS(X) / Parquet according to the file extension."""
        ext = file_path.split('.')[-1].lower()
        if ext == 'csv':
            return pd.read_csv(file_path)
        elif ext in ['xls', 'xlsx']:
            return pd.read_excel(file_path)
        elif ext == 'parquet':
            return pd.read_parquet(file_path)
        else:
            raise ValueError(f"Unsupported file format '.{ext}'. Acceptable formats: CSV, XLS, XLSX, PARQUET.")

    def split_names_and_values(df):
        """Return (series names, DataFrame holding only the data columns)."""
        # A CSV saved together with the pandas index (like the stocks_data.csv
        # preset) starts with an integer "Unnamed: 0" column; it is neither a
        # name nor data, so drop it.
        if (
            df.shape[1] > 1
            and str(df.columns[0]).startswith("Unnamed:")
            and pd.api.types.is_integer_dtype(df.iloc[:, 0])
        ):
            df = df.iloc[:, 1:]
        if df.shape[1] > 0:
            first = df.iloc[:, 0]
            # A non-numeric first column is taken as the series names.
            if not pd.api.types.is_numeric_dtype(first) and not first.astype(str).str.isnumeric().all():
                return first.tolist(), df.iloc[:, 1:]
        return [f"Series {i}" for i in range(len(df))], df

    try:
        if file is not None:
            # Gradio may pass a file-like object or a plain path string.
            df = load_table(getattr(file, "name", file))
        elif preset_filename:
            df = load_table(preset_filename)
        else:
            return [], "Please upload a file or select a preset."

        timeseries_names, values = split_names_and_values(df)

        _input = torch.tensor(values.to_numpy(dtype=float))
        _output = model_forecast(_input)
        if len(_output) < _input.shape[0]:
            raise ValueError(
                f"forecast covers {len(_output)} series but the table has {_input.shape[0]} rows"
            )

        gallery_images = [
            plot_forecast_image(_input[i], _output[i], timeseries_names[i])
            for i in range(_input.shape[0])
        ]

        return gallery_images, ""
    except Exception as e:
        return [], f"Error: {e}. Please upload files in one of the following formats: CSV, XLS, XLSX, PARQUET."



# Single-function UI: optional upload plus preset dropdown in, gallery plus
# error text out. The garbled characters in the title/description are repaired.
iface = gr.Interface(
    fn=display_forecast,
    inputs=[
        gr.File(label="Upload your CSV file (optional)"),
        gr.Dropdown(
            label="Or select a preset CSV file",
            choices=["stocks_data_noindex.csv", "stocks_data.csv"],
            value="stocks_data_noindex.csv"
        )
    ],
    outputs=[
        gr.Gallery(label="Forecast Plots (one per row)"),
        gr.Textbox(label="Error Message")
    ],
    title="CSV→Dynamic Forecast Gallery",
    description="Upload a CSV with any number of rows; each row’s forecast becomes one image in a gallery.",
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch()



# '''
# 1. Prepared datasets
# 2. Plots of different quantiles (different colors)
# 3. Filters for plots...
# 4. Different input options
# 5. README.md in there (in UI) (contact us for fine-tuning)
# 6. Requirements for dimensions
# 7. Multivariate data (x_t is vector)
# 8. LOGO of NX-AI and xLSTM and tirex
# '''

In [15]:
pd.read_csv("stocks_data.csv")

Unnamed: 0,Ticker,2024-12-04,2024-12-05,2024-12-06,2024-12-09,2024-12-10,2024-12-11,2024-12-12,2024-12-13,2024-12-16,...,2025-05-21,2025-05-22,2025-05-23,2025-05-27,2025-05-28,2025-05-29,2025-05-30,2025-06-02,2025-06-03,2025-06-04
0,AAPL,242.425201,242.455124,242.2556,246.156204,247.173752,245.89682,247.363297,247.532883,250.435867,...,202.089996,201.360001,195.270004,200.210007,200.419998,199.949997,200.850006,201.699997,203.270004,203.150101
1,AMZN,218.160004,220.550003,227.029999,226.089996,225.039993,230.259995,228.970001,227.460007,232.929993,...,201.119995,203.100006,200.990005,206.020004,204.720001,205.699997,205.009995,206.649994,205.710007,207.472
2,GOOGL,173.970016,172.244003,174.309265,175.168259,184.956985,195.175217,191.739182,189.601639,196.433777,...,168.559998,170.869995,168.470001,172.899994,172.360001,171.860001,171.740005,169.029999,166.179993,167.785004
3,JNJ,148.006256,147.071823,146.865265,147.150513,146.78656,144.238968,143.845535,144.219299,141.494659,...,151.87796,151.312805,151.639999,153.25,152.429993,153.580002,155.210007,155.399994,154.419998,153.419998
4,JPM,240.666992,242.723633,244.58252,241.072388,240.133057,240.795532,238.817978,237.24585,236.889893,...,261.040009,260.670013,260.709991,265.290009,263.48999,264.369995,264.0,264.660004,266.269989,265.065002
5,META,612.740173,607.898376,622.713257,612.530518,618.270813,631.608154,629.721375,619.299072,623.68512,...,635.5,636.570007,627.059998,642.320007,643.580017,645.049988,647.48999,670.900024,666.849976,685.159973
6,MSFT,435.74472,440.924805,441.871155,444.311768,441.63205,447.270416,447.838196,445.556976,449.860413,...,452.570007,454.859985,450.179993,460.690002,457.359985,458.679993,460.359985,461.970001,462.970001,464.190002
7,NVDA,145.116653,145.046661,142.426895,138.797226,135.057587,139.29718,137.327362,134.237656,131.987854,...,131.800003,132.830002,131.289993,135.5,134.809998,139.190002,135.130005,137.380005,141.220001,141.854996
8,TSLA,357.929993,369.48999,389.220001,389.790009,400.98999,424.769989,418.100006,436.230011,463.019989,...,334.619995,341.040009,339.339996,362.890015,356.899994,358.429993,346.459991,342.690002,344.269989,334.6716
9,V,308.866455,308.049194,309.972778,307.27179,311.338196,312.7435,313.182037,313.690308,314.836517,...,358.299988,357.970001,353.540009,359.299988,359.730011,362.399994,365.190002,365.320007,365.859985,368.179993


# Not checked but with labels filter

In [None]:
import io
import pandas as pd
import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import gradio as gr

# Set random seed and load your pretrained forecast tensor
torch.manual_seed(42)
# map_location="cpu": a tensor saved from a GPU session must still load (and be
# plottable by matplotlib) on a CPU-only machine.
_forecast_tensor = torch.load("stocks_data_forecast.pt", map_location="cpu")  # shape = (n_series, pred_len, n_q)

def model_forecast(input_data):
    """Placeholder model: return the precomputed forecast tensor.

    `input_data` is ignored; the same module-level tensor is returned on every call.
    """
    return _forecast_tensor

def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):
    """Plot one series with its quantile forecasts and return the figure as an array.

    Args:
        timeseries: 1D sequence of observed values.
        quantile_predictions: 2D array/tensor indexed as [step, quantile].
        timeseries_name: text used in the plot title.

    Returns:
        NumPy array (RGB) of the rendered PNG.
    """
    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)
    try:
        ax.plot(timeseries, color="blue")
        # The forecast is drawn starting on the last observed index so the
        # curves visually connect to the history.
        # NOTE(review): confirm step 0 of the forecast is not meant to sit at len(timeseries).
        x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))
        quantile_lines = []
        for i in range(quantile_predictions.shape[1]):
            (line,) = ax.plot(
                x_pred,
                quantile_predictions[:, i],
                color=f"C{i}",
                label=f"Quantile {i}",
            )
            quantile_lines.append(line)
        ax.set_title(f"Timeseries: {timeseries_name}")

        # Build the legend from the quantile line handles only. Passing a bare
        # label list would attach "Quantile 0" to the history line (the first
        # artist on the axes) and shift every other label by one.
        ax.legend(handles=quantile_lines, loc="center left", bbox_to_anchor=(1, 0.5))
        fig.tight_layout(rect=[0, 0, 0.85, 1])

        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf).convert("RGB")
    return np.array(img)

def load_table(file_path):
    """Read a table from disk, choosing the pandas reader by file extension.

    Supported extensions (case-insensitive): csv, xls, xlsx, parquet.
    Raises ValueError for anything else.
    """
    suffix = file_path.rsplit(".", 1)[-1].lower()
    readers = {
        "csv": pd.read_csv,
        "xls": pd.read_excel,
        "xlsx": pd.read_excel,
        "parquet": pd.read_parquet,
    }
    reader = readers.get(suffix)
    if reader is None:
        raise ValueError(
            f"Unsupported file format '.{suffix}'. Accepted: CSV, XLS, XLSX, PARQUET."
        )
    return reader(file_path)

def _uploaded_path(file):
    """Return the filesystem path of a Gradio upload (object with `.name` or plain string)."""
    return getattr(file, "name", file)

def _load_selected_table(file, preset_filename):
    """Load the uploaded file if there is one, else the preset; None when neither is given."""
    if file is not None:
        return load_table(_uploaded_path(file))
    if preset_filename:
        return load_table(preset_filename)
    return None

def _split_names_and_values(df):
    """Return (series names, DataFrame holding only the data columns)."""
    # A CSV saved together with the pandas index (like the stocks_data.csv
    # preset) starts with an integer "Unnamed: 0" column; it is neither a name
    # nor data, so drop it.
    if (
        df.shape[1] > 1
        and str(df.columns[0]).startswith("Unnamed:")
        and pd.api.types.is_integer_dtype(df.iloc[:, 0])
    ):
        df = df.iloc[:, 1:]
    if df.shape[1] > 0:
        first = df.iloc[:, 0]
        # A non-numeric first column is taken as the series names.
        if not pd.api.types.is_numeric_dtype(first) and not first.astype(str).str.isnumeric().all():
            return first.tolist(), df.iloc[:, 1:]
    return [f"Series {i}" for i in range(len(df))], df

def extract_names_and_update(file, preset_filename):
    """
    Read the table (uploaded or preset), extract timeseries names, and return:
    1) gr.update for the CheckboxGroup (all names pre-checked)
    2) the full list of names to store in state.
    """
    try:
        df = _load_selected_table(file, preset_filename)
        if df is None:
            return gr.update(choices=[], value=[]), []
        names, _ = _split_names_and_values(df)
        return gr.update(choices=names, value=names), names
    except Exception:
        # Best effort: an unreadable file only clears the selection here; the
        # same load is repeated (and its error reported) when Plot is clicked.
        return gr.update(choices=[], value=[]), []

def filter_names(search_term, all_names):
    """
    Filter the full list of names (all_names) by the search_term
    (case-insensitive substring). Returns a gr.update whose choices are the
    matching names, all of them checked; an empty search term shows every name.
    """
    if not all_names:
        return gr.update(choices=[], value=[])
    if not search_term:
        # No search term → show all
        return gr.update(choices=all_names, value=all_names)
    lower = search_term.lower()
    filtered = [n for n in all_names if lower in str(n).lower()]
    return gr.update(choices=filtered, value=filtered)

def check_all(names_list):
    """Return an update that checks all names in the checkbox."""
    return gr.update(value=names_list)

def uncheck_all(_):
    """Return an update that unchecks all names."""
    return gr.update(value=[])

def display_filtered_forecast(file, preset_filename, selected_names):
    """
    Load the table, filter by selected_names, run forecast, and return:
    - list of images (NumPy arrays) for the gallery
    - error string (empty if OK)
    """
    try:
        df = _load_selected_table(file, preset_filename)
        if df is None:
            return [], "No file selected."

        all_names, values = _split_names_and_values(df)
        data_only = values.astype(float)

        chosen = set(selected_names or [])
        mask = [name in chosen for name in all_names]
        if not any(mask):
            return [], "No timeseries chosen to plot."

        filtered_names = [name for name, keep in zip(all_names, mask) if keep]

        # Forecast on the full table, then keep the chosen rows of BOTH input
        # and output, so row i of `inp` is plotted against its own forecast
        # rather than forecast row i of the unfiltered table.
        full_input = torch.tensor(data_only.values)     # (n_series, length)
        full_output = model_forecast(full_input)        # (n_series, pred_len, n_q)
        if full_output.shape[0] != len(all_names):
            raise ValueError(
                f"forecast covers {full_output.shape[0]} series but the table has {len(all_names)} rows"
            )
        row_mask = torch.tensor(mask, dtype=torch.bool)
        inp = full_input[row_mask]                      # (n_chosen, length)
        out = full_output[row_mask]                     # (n_chosen, pred_len, n_q)

        gallery_images = [
            plot_forecast_image(inp[i], out[i], filtered_names[i])
            for i in range(inp.shape[0])
        ]

        return gallery_images, ""
    except Exception as e:
        return [], f"Error: {e}. Please upload a valid CSV, XLS, XLSX, or PARQUET file."

with gr.Blocks() as demo:
    gr.Markdown("## Upload or select a preset → search/filter by name → click Plot")

    with gr.Row():
        file_input = gr.File(
            label="Upload CSV/XLSX/PARQUET (optional)",
            file_types=[".csv", ".xls", ".xlsx", ".parquet"]
        )
        preset_dropdown = gr.Dropdown(
            label="Or pick a preset:",
            choices=["stocks_data_noindex.csv", "stocks_data.csv"],
            value="stocks_data_noindex.csv"
        )

    # A text box to type a substring (search term)
    search_box = gr.Textbox(
        label="Search/Filter timeseries by name",
        placeholder="Type to filter (e.g. 'AMZN')",
        value=""
    )

    # A CheckboxGroup to show matching names; choices/value will be updated dynamically
    filter_checkbox = gr.CheckboxGroup(
        choices=[], value=[], label="Select which timeseries to show"
    )

    # Buttons to check or uncheck all
    with gr.Row():
        check_all_btn = gr.Button("Check All")
        uncheck_all_btn = gr.Button("Uncheck All")

    plot_button = gr.Button("Plot")

    gallery = gr.Gallery(label="Forecast Plots (filtered)")
    errbox = gr.Textbox(label="Error Message")

    # State to hold the full list of names
    names_state = gr.State([])

    # 0) On page load, process the default preset. Without this the checkbox
    #    list stays empty until the dropdown value changes, because the initial
    #    value never fires a change event.
    demo.load(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )

    # 1) When file or preset changes, extract full names and update the checkbox + state
    file_input.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )
    preset_dropdown.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )

    # 2) When search text changes, filter names_state and update the checkbox
    search_box.change(
        fn=filter_names,
        inputs=[search_box, names_state],
        outputs=filter_checkbox
    )

    # 3) Check All button: set checkbox value to all names in state
    check_all_btn.click(
        fn=check_all,
        inputs=names_state,
        outputs=filter_checkbox
    )

    # 4) Uncheck All button: set checkbox value to empty list
    uncheck_all_btn.click(
        fn=uncheck_all,
        inputs=names_state,
        outputs=filter_checkbox
    )

    # 5) When "Plot" is clicked, generate the filtered plots
    plot_button.click(
        fn=display_filtered_forecast,
        inputs=[file_input, preset_dropdown, filter_checkbox],
        outputs=[gallery, errbox],
    )

demo.launch()



# Checked, filter

In [None]:
import io
import pandas as pd
import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import gradio as gr

# ----------------------------
# Helper functions (logic unchanged)
# ----------------------------

torch.manual_seed(42)
# map_location="cpu": a tensor saved from a GPU session must still load (and be
# plottable by matplotlib) on a CPU-only machine.
_forecast_tensor = torch.load("stocks_data_forecast.pt", map_location="cpu")  # shape = (n_series, pred_len, n_q)

def model_forecast(input_data):
    """Placeholder model: return the precomputed forecast tensor.

    `input_data` is ignored; the same module-level tensor is returned on every call.
    """
    return _forecast_tensor

def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):
    """Plot one series with its quantile forecasts and return the figure as an array.

    Args:
        timeseries: 1D sequence of observed values.
        quantile_predictions: 2D array/tensor indexed as [step, quantile].
        timeseries_name: text used in the plot title.

    Returns:
        NumPy array (RGB) of the rendered PNG.
    """
    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)
    try:
        ax.plot(timeseries, color="blue")
        # The forecast is drawn starting on the last observed index so the
        # curves visually connect to the history.
        # NOTE(review): confirm step 0 of the forecast is not meant to sit at len(timeseries).
        x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))
        quantile_lines = []
        for i in range(quantile_predictions.shape[1]):
            (line,) = ax.plot(
                x_pred,
                quantile_predictions[:, i],
                color=f"C{i}",
                label=f"Quantile {i}",
            )
            quantile_lines.append(line)
        ax.set_title(f"Timeseries: {timeseries_name}")
        # Build the legend from the quantile line handles only. Passing a bare
        # label list would attach "Quantile 0" to the history line (the first
        # artist on the axes) and shift every other label by one.
        ax.legend(handles=quantile_lines, loc="center left", bbox_to_anchor=(1, 0.5))
        fig.tight_layout(rect=[0, 0, 0.85, 1])
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf).convert("RGB")
    return np.array(img)

def load_table(file_path):
    """Read a table from disk, choosing the pandas reader by file extension.

    Supported extensions (case-insensitive): csv, xls, xlsx, parquet.
    Raises ValueError for anything else.
    """
    suffix = file_path.rsplit(".", 1)[-1].lower()
    readers = {
        "csv": pd.read_csv,
        "xls": pd.read_excel,
        "xlsx": pd.read_excel,
        "parquet": pd.read_parquet,
    }
    reader = readers.get(suffix)
    if reader is None:
        raise ValueError("Unsupported format. Use CSV, XLS, XLSX, or PARQUET.")
    return reader(file_path)

def _uploaded_path(file):
    """Return the filesystem path of a Gradio upload (object with `.name` or plain string)."""
    return getattr(file, "name", file)

def _load_selected_table(file, preset_filename):
    """Load the uploaded file if there is one, else the preset; None when neither is given."""
    if file is not None:
        return load_table(_uploaded_path(file))
    if preset_filename:
        return load_table(preset_filename)
    return None

def _split_names_and_values(df):
    """Return (series names, DataFrame holding only the data columns)."""
    # A CSV saved together with the pandas index (like the stocks_data.csv
    # preset) starts with an integer "Unnamed: 0" column; it is neither a name
    # nor data, so drop it.
    if (
        df.shape[1] > 1
        and str(df.columns[0]).startswith("Unnamed:")
        and pd.api.types.is_integer_dtype(df.iloc[:, 0])
    ):
        df = df.iloc[:, 1:]
    if df.shape[1] > 0:
        first = df.iloc[:, 0]
        # A non-numeric first column is taken as the series names.
        if not pd.api.types.is_numeric_dtype(first) and not first.astype(str).str.isnumeric().all():
            return first.tolist(), df.iloc[:, 1:]
    return [f"Series {i}" for i in range(len(df))], df

def extract_names_and_update(file, preset_filename):
    """Read the uploaded/preset table and return (checkbox update, full name list).

    The checkbox update lists every series name with all of them checked.
    """
    try:
        df = _load_selected_table(file, preset_filename)
        if df is None:
            return gr.update(choices=[], value=[]), []
        names, _ = _split_names_and_values(df)
        return gr.update(choices=names, value=names), names
    except Exception:
        # Best effort: an unreadable file only clears the selection here; the
        # same load is repeated (and its error reported) when Plot is clicked.
        return gr.update(choices=[], value=[]), []

def filter_names(search_term, all_names):
    """Return a checkbox update restricted to names containing search_term.

    Matching is a case-insensitive substring test; every match is checked.
    An empty search term shows (and checks) every name.
    """
    if not all_names:
        return gr.update(choices=[], value=[])
    if not search_term:
        return gr.update(choices=all_names, value=all_names)
    lower = search_term.lower()
    filtered = [n for n in all_names if lower in str(n).lower()]
    return gr.update(choices=filtered, value=filtered)

def check_all(names_list):
    """Return an update that checks every name in names_list."""
    return gr.update(value=names_list)

def uncheck_all(_):
    """Return an update that unchecks every name."""
    return gr.update(value=[])

def display_filtered_forecast(file, preset_filename, selected_names):
    """
    Load the table, filter by selected_names, run forecast (correctly sliced),
    and return a gallery + error string.
    """
    try:
        df = _load_selected_table(file, preset_filename)
        if df is None:
            return [], "No file selected."

        # Extract all_names and numeric data
        all_names, values = _split_names_and_values(df)
        data_only = values.astype(float)

        # Build mask and filtered subset (set membership keeps this linear)
        chosen = set(selected_names or [])
        mask = [name in chosen for name in all_names]
        if not any(mask):
            return [], "No timeseries chosen to plot."

        # The precomputed forecast is aligned row-for-row with the table it was
        # made for; refuse a table with a different number of rows instead of
        # failing on (or mis-aligning) the mask.
        if _forecast_tensor.shape[0] != len(all_names):
            raise ValueError(
                f"forecast covers {_forecast_tensor.shape[0]} series but the table has {len(all_names)} rows"
            )

        filtered_data = data_only.iloc[mask, :].values
        filtered_names = [name for name, keep in zip(all_names, mask) if keep]

        row_mask = torch.tensor(mask, dtype=torch.bool)
        out = _forecast_tensor[row_mask]   # shape = (n_chosen, pred_len, n_q)
        inp = torch.tensor(filtered_data)

        # Plot each chosen series against its properly-aligned forecast
        gallery_images = [
            plot_forecast_image(inp[i], out[i], filtered_names[i])
            for i in range(inp.shape[0])
        ]

        return gallery_images, ""
    except Exception as e:
        return [], f"Error: {e}. Please upload a valid CSV, XLS, XLSX, or PARQUET file."


# ----------------------------
# Gradio layout: two columns
# ----------------------------

with gr.Blocks() as demo:
    gr.Markdown("# 📈 Stock Forecast Viewer 📊")
    gr.Markdown("Upload data or choose a preset, filter by name, then click Plot.")

    with gr.Row():
        # Left column: controls
        with gr.Column():
            gr.Markdown("## Data Selection")
            file_input = gr.File(
                label="Upload CSV / XLSX / PARQUET",
                file_types=[".csv", ".xls", ".xlsx", ".parquet"]
            )
            preset_dropdown = gr.Dropdown(
                label="Or choose a preset:",
                choices=["stocks_data_noindex.csv", "stocks_data.csv"],
                value="stocks_data_noindex.csv"
            )

            gr.Markdown("## Search / Filter")
            search_box = gr.Textbox(
                placeholder="Type to filter (e.g. 'AMZN')"
            )
            filter_checkbox = gr.CheckboxGroup(
                choices=[], value=[], label="Select which timeseries to show"
            )

            with gr.Row():
                check_all_btn = gr.Button("✅ Check All")
                uncheck_all_btn = gr.Button("❎ Uncheck All")

            plot_button = gr.Button("▶️ Plot Forecasts")
            errbox = gr.Textbox(interactive=False, placeholder="")

        # Right column: gallery
        with gr.Column():
            gr.Markdown("## Forecast Gallery")
            gallery = gr.Gallery()

    names_state = gr.State([])

    # On page load, process the default preset. Without this the checkbox list
    # stays empty until the dropdown value changes, because the initial value
    # never fires a change event.
    demo.load(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )

    # When file or preset changes, update names
    file_input.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )
    preset_dropdown.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )

    # When search term changes, filter names
    search_box.change(
        fn=filter_names,
        inputs=[search_box, names_state],
        outputs=filter_checkbox
    )

    # Check All / Uncheck All
    check_all_btn.click(fn=check_all, inputs=names_state, outputs=filter_checkbox)
    uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)

    # Plot button
    plot_button.click(
        fn=display_filtered_forecast,
        inputs=[file_input, preset_dropdown, filter_checkbox],
        outputs=[gallery, errbox]
    )

demo.launch()

# Checked, almost ideal

In [None]:
import io
import pandas as pd
import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import gradio as gr

# ----------------------------
# Helper functions (logic unchanged)
# ----------------------------

torch.manual_seed(42)
# map_location="cpu": a tensor saved from a GPU session must still load (and be
# plottable by matplotlib) on a CPU-only machine.
_forecast_tensor = torch.load("stocks_data_forecast.pt", map_location="cpu")  # shape = (n_series, pred_len, n_q)

def model_forecast(input_data):
    """Placeholder model: return the precomputed forecast tensor.

    `input_data` is ignored; the same module-level tensor is returned on every call.
    """
    return _forecast_tensor

def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):
    """Plot one series with its quantile forecasts and return the figure as an array.

    Args:
        timeseries: 1D sequence of observed values.
        quantile_predictions: 2D array/tensor indexed as [step, quantile].
        timeseries_name: text used in the plot title.

    Returns:
        NumPy array (RGB) of the rendered PNG.
    """
    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)
    try:
        ax.plot(timeseries, color="blue")
        # The forecast is drawn starting on the last observed index so the
        # curves visually connect to the history.
        # NOTE(review): confirm step 0 of the forecast is not meant to sit at len(timeseries).
        x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))
        quantile_lines = []
        for i in range(quantile_predictions.shape[1]):
            (line,) = ax.plot(
                x_pred,
                quantile_predictions[:, i],
                color=f"C{i}",
                label=f"Quantile {i}",
            )
            quantile_lines.append(line)
        ax.set_title(f"Timeseries: {timeseries_name}")
        # Build the legend from the quantile line handles only. Passing a bare
        # label list would attach "Quantile 0" to the history line (the first
        # artist on the axes) and shift every other label by one.
        ax.legend(handles=quantile_lines, loc="center left", bbox_to_anchor=(1, 0.5))
        fig.tight_layout(rect=[0, 0, 0.85, 1])
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf).convert("RGB")
    return np.array(img)

def load_table(file_path):
    """Read a table from disk, choosing the pandas reader by file extension.

    Supported extensions (case-insensitive): csv, xls, xlsx, parquet.
    Raises ValueError for anything else.
    """
    suffix = file_path.rsplit(".", 1)[-1].lower()
    readers = {
        "csv": pd.read_csv,
        "xls": pd.read_excel,
        "xlsx": pd.read_excel,
        "parquet": pd.read_parquet,
    }
    reader = readers.get(suffix)
    if reader is None:
        raise ValueError("Unsupported format. Use CSV, XLS, XLSX, or PARQUET.")
    return reader(file_path)

def _uploaded_path(file):
    """Return the filesystem path of a Gradio upload (object with `.name` or plain string)."""
    return getattr(file, "name", file)

def _load_selected_table(file, preset_filename):
    """Load the uploaded file if there is one, else the preset; None when neither is given."""
    if file is not None:
        return load_table(_uploaded_path(file))
    if preset_filename:
        return load_table(preset_filename)
    return None

def _split_names_and_values(df):
    """Return (series names, DataFrame holding only the data columns)."""
    # A CSV saved together with the pandas index (like the stocks_data.csv
    # preset) starts with an integer "Unnamed: 0" column; it is neither a name
    # nor data, so drop it.
    if (
        df.shape[1] > 1
        and str(df.columns[0]).startswith("Unnamed:")
        and pd.api.types.is_integer_dtype(df.iloc[:, 0])
    ):
        df = df.iloc[:, 1:]
    if df.shape[1] > 0:
        first = df.iloc[:, 0]
        # A non-numeric first column is taken as the series names.
        if not pd.api.types.is_numeric_dtype(first) and not first.astype(str).str.isnumeric().all():
            return first.tolist(), df.iloc[:, 1:]
    return [f"Series {i}" for i in range(len(df))], df

def extract_names_and_update(file, preset_filename):
    """Read the uploaded/preset table and return (checkbox update, full name list).

    The checkbox update lists every series name with all of them checked.
    """
    try:
        df = _load_selected_table(file, preset_filename)
        if df is None:
            return gr.update(choices=[], value=[]), []
        names, _ = _split_names_and_values(df)
        return gr.update(choices=names, value=names), names
    except Exception:
        # Best effort: an unreadable file only clears the selection here; the
        # same load is repeated (and its error reported) when Plot is clicked.
        return gr.update(choices=[], value=[]), []

def filter_names(search_term, all_names):
    """Return a checkbox update restricted to names containing search_term.

    Matching is a case-insensitive substring test; every match is checked.
    An empty search term shows (and checks) every name.
    """
    if not all_names:
        return gr.update(choices=[], value=[])
    if not search_term:
        return gr.update(choices=all_names, value=all_names)
    lower = search_term.lower()
    filtered = [n for n in all_names if lower in str(n).lower()]
    return gr.update(choices=filtered, value=filtered)

def check_all(names_list):
    """Return an update that checks every name in names_list."""
    return gr.update(value=names_list)

def uncheck_all(_):
    """Return an update that unchecks every name."""
    return gr.update(value=[])

def display_filtered_forecast(file, preset_filename, selected_names):
    """Plot the selected series against their forecasts.

    Returns (gallery_images, error_message); the error string is empty on success.
    """
    try:
        df = _load_selected_table(file, preset_filename)
        if df is None:
            return [], "No file selected."

        all_names, values = _split_names_and_values(df)
        data_only = values.astype(float)

        # Set membership keeps the mask construction linear.
        chosen = set(selected_names or [])
        mask = [name in chosen for name in all_names]
        if not any(mask):
            return [], "No timeseries chosen to plot."

        # The precomputed forecast is aligned row-for-row with the table it was
        # made for; refuse a table with a different number of rows instead of
        # failing on (or mis-aligning) the mask.
        if _forecast_tensor.shape[0] != len(all_names):
            raise ValueError(
                f"forecast covers {_forecast_tensor.shape[0]} series but the table has {len(all_names)} rows"
            )

        filtered_data = data_only.iloc[mask, :].values
        filtered_names = [name for name, keep in zip(all_names, mask) if keep]
        row_mask = torch.tensor(mask, dtype=torch.bool)
        out = _forecast_tensor[row_mask]   # slice forecasts to match filtered rows
        inp = torch.tensor(filtered_data)

        gallery_images = [
            plot_forecast_image(inp[i], out[i], filtered_names[i])
            for i in range(inp.shape[0])
        ]

        return gallery_images, ""
    except Exception as e:
        return [], f"Error: {e}. Use CSV, XLS, XLSX, or PARQUET."

# ----------------------------
# Gradio layout: two columns + instructions
# ----------------------------

with gr.Blocks() as demo:
    gr.Markdown("# 📈 Stock Forecast Viewer 📊")
    gr.Markdown("Upload data or choose a preset, filter by name, then click Plot.")

    with gr.Row():
        # Left column: controls
        with gr.Column():
            gr.Markdown("## Data Selection")
            file_input = gr.File(
                label="Upload CSV / XLSX / PARQUET",
                file_types=[".csv", ".xls", ".xlsx", ".parquet"]
            )
            preset_dropdown = gr.Dropdown(
                label="Or choose a preset:",
                choices=["stocks_data_noindex.csv", "stocks_data.csv"],
                value="stocks_data_noindex.csv"
            )

            gr.Markdown("## Search / Filter")
            search_box = gr.Textbox(placeholder="Type to filter (e.g. 'AMZN')")
            filter_checkbox = gr.CheckboxGroup(
                choices=[], value=[], label="Select which timeseries to show"
            )

            with gr.Row():
                check_all_btn = gr.Button("✅ Check All")
                uncheck_all_btn = gr.Button("❎ Uncheck All")

            plot_button = gr.Button("▶️ Plot Forecasts")
            errbox = gr.Textbox(interactive=False, placeholder="")

        # Right column: gallery + instructions
        with gr.Column():
            gr.Markdown("## Forecast Gallery")
            gallery = gr.Gallery()

            # Instruction text below gallery
            gr.Markdown(
                """
                **How to format your data:**
                - Your file must be a table (CSV, XLS, XLSX, or Parquet).
                - If you haven't prepared the data, the preset file will be used.
                - **One row per timeseries.** Each row is treated as a separate series.
                - If you want to **name** each series, put the name as the first value in **every** row:
                  - Example (CSV):  
                    `AAPL, 120.5, 121.0, 119.8, ...`  
                    `AMZN, 3300.0, 3310.5, 3295.2, ...`  
                  - In that case, the first column is not numeric, so it will be used as the series name.
                - If you do **not** want named series, simply leave out the first column entirely and have all values numeric:
                  - Example:  
                    `120.5, 121.0, 119.8, ...`  
                    `3300.0, 3310.5, 3295.2, ...`  
                  - Then every row will be auto-named “Series 0, Series 1, …” in order.
                - **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.
                - The rest of the columns (after the optional name) must be numeric data points for that series.
                - You can filter by typing in the search box. Then check or uncheck individual names before plotting.
                - Use “Check All” / “Uncheck All” to quickly select or deselect every series.
                - Finally, click **Plot Forecasts** to view the quantile forecast for each selected series.
                """
            )

    names_state = gr.State([])

    # On page load, process the default preset. Without this the checkbox list
    # stays empty until the dropdown value changes, because the initial value
    # never fires a change event.
    demo.load(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )

    # When file or preset changes, update names
    file_input.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )
    preset_dropdown.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )

    # When search term changes, filter names
    search_box.change(
        fn=filter_names,
        inputs=[search_box, names_state],
        outputs=filter_checkbox
    )

    # Check All / Uncheck All
    check_all_btn.click(fn=check_all, inputs=names_state, outputs=filter_checkbox)
    uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)

    # Plot button
    plot_button.click(
        fn=display_filtered_forecast,
        inputs=[file_input, preset_dropdown, filter_checkbox],
        outputs=[gallery, errbox]
    )

demo.launch()

# Known issue: the default preset isn't processed when it is the one selected

In [None]:
import io
import pandas as pd
import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import gradio as gr

# ----------------------------
# Helper functions (logic unchanged)
# ----------------------------

torch.manual_seed(42)
# map_location="cpu": a tensor saved from a GPU session must still load (and be
# plottable by matplotlib) on a CPU-only machine.
_forecast_tensor = torch.load("stocks_data_forecast.pt", map_location="cpu")  # shape = (n_series, pred_len, n_q)

def model_forecast(input_data):
    """Placeholder model: return the precomputed forecast tensor.

    `input_data` is ignored; the same module-level tensor is returned on every call.
    """
    return _forecast_tensor

def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):
    """Plot one series with its quantile forecasts and return the figure as an array.

    Args:
        timeseries: 1D sequence of observed values.
        quantile_predictions: 2D array/tensor indexed as [step, quantile].
        timeseries_name: text used in the plot title.

    Returns:
        NumPy array (RGB) of the rendered PNG.
    """
    fig, ax = plt.subplots(figsize=(10, 6), dpi=150)
    try:
        ax.plot(timeseries, color="blue")
        # The forecast is drawn starting on the last observed index so the
        # curves visually connect to the history.
        # NOTE(review): confirm step 0 of the forecast is not meant to sit at len(timeseries).
        x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))
        quantile_lines = []
        for i in range(quantile_predictions.shape[1]):
            (line,) = ax.plot(
                x_pred,
                quantile_predictions[:, i],
                color=f"C{i}",
                label=f"Quantile {i}",
            )
            quantile_lines.append(line)
        ax.set_title(f"Timeseries: {timeseries_name}")
        # Build the legend from the quantile line handles only. Passing a bare
        # label list would attach "Quantile 0" to the history line (the first
        # artist on the axes) and shift every other label by one.
        ax.legend(handles=quantile_lines, loc="center left", bbox_to_anchor=(1, 0.5))
        fig.tight_layout(rect=[0, 0, 0.85, 1])
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf).convert("RGB")
    return np.array(img)

def load_table(file_path):
    """Read a table from disk, choosing the pandas reader by file extension.

    Supported extensions (case-insensitive): csv, xls, xlsx, parquet.
    Raises ValueError for anything else.
    """
    suffix = file_path.rsplit(".", 1)[-1].lower()
    readers = {
        "csv": pd.read_csv,
        "xls": pd.read_excel,
        "xlsx": pd.read_excel,
        "parquet": pd.read_parquet,
    }
    reader = readers.get(suffix)
    if reader is None:
        raise ValueError("Unsupported format. Use CSV, XLS, XLSX, or PARQUET.")
    return reader(file_path)

def extract_names_and_update(file, preset_filename):
    try:
        if file is not None:
            df = load_table(file.name)
        else:
            if not preset_filename:
                return gr.update(choices=[], value=[]), []
            df = load_table(preset_filename)

        if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():
            names = df.iloc[:, 0].tolist()
        else:
            names = [f"Series {i}" for i in range(len(df))]
        return gr.update(choices=names, value=names), names
    except Exception:
        return gr.update(choices=[], value=[]), []

def filter_names(search_term, all_names):
    """Build the checkbox update for a case-insensitive substring search.

    With no names the selector is emptied; with an empty search term every
    name is offered and selected; otherwise only names whose string form
    contains ``search_term`` are offered, and all of them are selected.
    """
    if not all_names:
        visible = []
    elif not search_term:
        visible = all_names
    else:
        needle = search_term.lower()
        visible = list(filter(lambda name: needle in str(name).lower(), all_names))
    return gr.update(choices=visible, value=visible)

def check_all(names_list):
    """Return the checkbox update that selects every name in ``names_list``."""
    selection = names_list
    return gr.update(value=selection)

def uncheck_all(_):
    """Return the checkbox update that clears the selection; the argument is ignored."""
    nothing_selected = []
    return gr.update(value=nothing_selected)

def display_filtered_forecast(file, preset_filename, selected_names):
    """Plot the precomputed forecast for every selected series.

    Args:
        file: Uploaded file object exposing its path as ``.name``, or ``None``.
            Takes precedence over ``preset_filename``.
        preset_filename: Path of a preset table, used when ``file`` is ``None``.
        selected_names: Names ticked in the UI; ``None`` is treated as empty.

    Returns:
        ``(images, message)``: a list of rendered plots and an empty message
        on success, or an empty list and a human-readable message otherwise.
        Exceptions are reported through the message, never raised.
    """
    try:
        if file is not None:
            df = load_table(file.name)
        else:
            if not preset_filename:
                return [], "No file selected."
            df = load_table(preset_filename)

        # Same name-detection rule as the name selector: an object-dtype first
        # column that is not purely numeric strings carries the series names.
        if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():
            all_names = df.iloc[:, 0].tolist()
            data_only = df.iloc[:, 1:].astype(float)
        else:
            all_names = [f"Series {i}" for i in range(len(df))]
            data_only = df.astype(float)

        chosen = selected_names or []
        mask = [name in chosen for name in all_names]
        if not any(mask):
            return [], "No timeseries chosen to plot."

        # The forecast tensor is precomputed, one entry per table row; a table
        # with a different row count cannot be matched to it.
        n_forecasts = _forecast_tensor.shape[0]
        if len(all_names) != n_forecasts:
            raise ValueError(
                f"the data has {len(all_names)} series but the precomputed "
                f"forecast covers {n_forecasts}"
            )

        filtered_names = [name for name, keep in zip(all_names, mask) if keep]
        # Explicit bool tensor so the selection is unambiguously a row mask.
        out = _forecast_tensor[torch.tensor(mask, dtype=torch.bool)]
        inp = torch.tensor(data_only.iloc[mask, :].values)

        gallery_images = [
            plot_forecast_image(inp[i], out[i], filtered_names[i])
            for i in range(inp.shape[0])
        ]
        return gallery_images, ""
    except Exception as e:
        # Report the failure as-is; a fixed format hint here duplicated the
        # loader's own message and was misleading for non-format errors.
        return [], f"Error: {e}"


# ----------------------------
# Gradio layout: two columns + instructions
# ----------------------------

# Two-column UI: data selection / filtering controls on the left, the forecast
# gallery and formatting instructions on the right.
with gr.Blocks() as demo:
    gr.Markdown("# üìà Stock Forecast Viewer üìä")
    gr.Markdown("Upload data or choose a preset, filter by name, then click Plot.")

    with gr.Row():
        # Left column: controls
        with gr.Column():
            gr.Markdown("## Data Selection")
            gr.Markdown("*If you haven't prepared the data, the preset file will be used.*")
            file_input = gr.File(
                label="Upload CSV / XLSX / PARQUET",
                file_types=[".csv", ".xls", ".xlsx", ".parquet"]
            )
            preset_dropdown = gr.Dropdown(
                label="Or choose a preset:",
                choices=["stocks_data_noindex.csv", "stocks_data.csv"],
                value="stocks_data_noindex.csv"
            )

            gr.Markdown("## Search / Filter")
            search_box = gr.Textbox(placeholder="Type to filter (e.g. 'AMZN')")
            filter_checkbox = gr.CheckboxGroup(
                choices=[], value=[], label="Select which timeseries to show"
            )

            with gr.Row():
                check_all_btn = gr.Button("‚úÖ Check All")
                uncheck_all_btn = gr.Button("‚ùé Uncheck All")

            plot_button = gr.Button("‚ñ∂Ô∏è Plot Forecasts")
            errbox = gr.Textbox(label="Error Message", interactive=False)

        # Right column: gallery + instructions
        with gr.Column():
            gr.Markdown("## Forecast Gallery")
            gallery = gr.Gallery()

            # Instruction text below gallery
            gr.Markdown(
                """
                **How to format your data:**
                - Your file must be a table (CSV, XLS, XLSX, or Parquet).
                - **One row per timeseries.** Each row is treated as a separate series.
                - If you want to **name** each series, put the name as the first value in **every** row:
                  - Example (CSV):  
                    `AAPL, 120.5, 121.0, 119.8, ...`  
                    `AMZN, 3300.0, 3310.5, 3295.2, ...`  
                  - In that case, the first column is not numeric, so it will be used as the series name.
                - If you do **not** want named series, simply leave out the first column entirely and have all values numeric:
                  - Example:  
                    `120.5, 121.0, 119.8, ...`  
                    `3300.0, 3310.5, 3295.2, ...`  
                  - Then every row will be auto-named ‚ÄúSeries 0, Series 1, ‚Ä¶‚Äù in order.
                - **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.
                - The rest of the columns (after the optional name) must be numeric data points for that series.
                - You can filter by typing in the search box. Then check or uncheck individual names before plotting.
                - Use ‚ÄúCheck All‚Äù / ‚ÄúUncheck All‚Äù to quickly select or deselect every series.
                - Finally, click **Plot Forecasts** to view the quantile forecast for each selected series.
                """
            )

    # Full (unfiltered) list of series names for the currently loaded table.
    names_state = gr.State([])

    # When file or preset changes, update names
    file_input.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )
    preset_dropdown.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )
    # The preset is preselected, but .change only fires on later edits, so the
    # name list must also be populated once when the page first loads.
    # Otherwise the selector starts empty and Plot reports nothing chosen.
    demo.load(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )

    # When search term changes, filter names
    search_box.change(
        fn=filter_names,
        inputs=[search_box, names_state],
        outputs=filter_checkbox
    )

    # Check All / Uncheck All
    check_all_btn.click(fn=check_all, inputs=names_state, outputs=filter_checkbox)
    uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)

    # Plot button
    plot_button.click(
        fn=display_filtered_forecast,
        inputs=[file_input, preset_dropdown, filter_checkbox],
        outputs=[gallery, errbox]
    )

demo.launch()

# Variant: no preset selected by default

In [None]:
import io
import pandas as pd
import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import gradio as gr

# ----------------------------
# Helper functions (logic unchanged)
# ----------------------------

torch.manual_seed(42)

# Precomputed forecasts; expected layout is (n_series, pred_len, n_q).
# map_location="cpu" keeps the tensor usable on machines without the device it
# was saved from, and matplotlib can only plot CPU tensors.
# NOTE(security): torch.load unpickles the file -- only load forecast files
# that come from a trusted source.
_forecast_tensor = torch.load("stocks_data_forecast.pt", map_location="cpu")

def model_forecast(input_data):
    """Stand-in for a real model: return the precomputed forecast tensor.

    `input_data` is accepted for interface compatibility but is ignored; the
    same module-level tensor is returned on every call.
    """
    return _forecast_tensor

def plot_forecast_image(timeseries, quantile_predictions, timeseries_name):
    """Render one series plus its quantile forecasts and return it as an RGB array.

    Args:
        timeseries: 1-D sequence of observed values (must support ``len``).
        quantile_predictions: 2-D array-like indexed as ``[step, quantile]``;
            ``shape[1]`` curves are drawn, one per quantile column.
        timeseries_name: Label used in the plot title.

    Returns:
        ``numpy.ndarray`` holding the rendered PNG converted to RGB.
    """
    fig, ax = plt.subplots(figsize=(10, 6), dpi=300)
    try:
        # Plot the original timeseries with thicker line and marker
        ax.plot(timeseries, color="blue", linewidth=2.5, marker='o', label="Given Data")

        # Forecast x-positions begin at the last observed index, so the first
        # forecast step shares its x with the final observation.
        # NOTE(review): confirm this overlap is intended.
        x_pred = range(len(timeseries) - 1, len(timeseries) - 1 + len(quantile_predictions))
        # One curve per quantile column, each in its own colour-cycle colour.
        for i in range(quantile_predictions.shape[1]):
            ax.plot(x_pred, quantile_predictions[:, i], color=f"C{i}", linewidth=2, alpha=0.8, label=f"Quantile {i+1}")

        ax.set_title(f"Timeseries: {timeseries_name}", fontsize=16, fontweight='bold')
        ax.set_xlabel("Time", fontsize=12)
        ax.set_ylabel("Value", fontsize=12)

        ax.grid(True, which='both', linestyle='--', linewidth=0.7, alpha=0.6)
        ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), fontsize=10, frameon=True, shadow=True)

        # Reserve the right-hand strip of the figure for the outside legend.
        fig.tight_layout(rect=[0, 0, 0.82, 1])

        with io.BytesIO() as buf:
            # Saved on an opaque background: the image is converted to RGB
            # below, which drops the alpha channel. A transparent canvas would
            # lose the alpha= styling and anti-aliasing in that conversion.
            fig.savefig(buf, format="png", bbox_inches="tight")
            buf.seek(0)
            # convert() forces PIL to decode the image before the buffer closes.
            return np.array(Image.open(buf).convert("RGB"))
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close(fig)

def load_table(file_path):
    """Read a tabular file into a DataFrame, choosing the reader by extension.

    The extension is the text after the last ``.`` in ``file_path``, compared
    case-insensitively. Supported: csv, xls, xlsx, parquet.

    Raises:
        ValueError: if the extension is not one of the supported formats.
    """
    readers = {
        "csv": pd.read_csv,
        "xls": pd.read_excel,
        "xlsx": pd.read_excel,
        "parquet": pd.read_parquet,
    }
    suffix = file_path.rsplit(".", 1)[-1].lower()
    reader = readers.get(suffix)
    if reader is None:
        raise ValueError("Unsupported format. Use CSV, XLS, XLSX, or PARQUET.")
    return reader(file_path)

def extract_names_and_update(file, preset_filename):
    """Load the active table and derive the list of series names.

    An uploaded ``file`` (read via its ``.name`` path) takes precedence over
    ``preset_filename``. If the first column has object dtype and is not made
    up entirely of numeric strings, its values are the names; otherwise rows
    are named ``Series 0``, ``Series 1``, ...

    Returns:
        A pair ``(checkbox_update, names)``. On any failure, or when neither
        source is given, the update clears the choices and ``names`` is ``[]``.
    """
    try:
        if file is None and not preset_filename:
            return gr.update(choices=[], value=[]), []

        source_path = preset_filename if file is None else file.name
        table = load_table(source_path)

        first_column_holds_names = (
            table.shape[1] > 0
            and table.iloc[:, 0].dtype == object
            and not table.iloc[:, 0].str.isnumeric().all()
        )
        if first_column_holds_names:
            names = table.iloc[:, 0].tolist()
        else:
            names = [f"Series {row}" for row in range(len(table))]
        return gr.update(choices=names, value=names), names
    except Exception:
        # Best effort: any load/parse problem simply empties the selector.
        return gr.update(choices=[], value=[]), []

def filter_names(search_term, all_names):
    """Build the checkbox update for a case-insensitive substring search.

    With no names the selector is emptied; with an empty search term every
    name is offered and selected; otherwise only names whose string form
    contains ``search_term`` are offered, and all of them are selected.
    """
    if not all_names:
        visible = []
    elif not search_term:
        visible = all_names
    else:
        needle = search_term.lower()
        visible = list(filter(lambda name: needle in str(name).lower(), all_names))
    return gr.update(choices=visible, value=visible)

def check_all(names_list):
    """Return the checkbox update that selects every name in ``names_list``."""
    selection = names_list
    return gr.update(value=selection)

def uncheck_all(_):
    """Return the checkbox update that clears the selection; the argument is ignored."""
    nothing_selected = []
    return gr.update(value=nothing_selected)

def display_filtered_forecast(file, preset_filename, selected_names):
    """Plot the precomputed forecast for every selected series.

    Args:
        file: Uploaded file object exposing its path as ``.name``, or ``None``.
            Takes precedence over ``preset_filename``.
        preset_filename: Path of a preset table, used when ``file`` is ``None``.
        selected_names: Names ticked in the UI; ``None`` is treated as empty.

    Returns:
        ``(images, message)``: a list of rendered plots and an empty message
        on success, or an empty list and a human-readable message otherwise.
        Exceptions are reported through the message, never raised.
    """
    try:
        if file is not None:
            df = load_table(file.name)
        else:
            if not preset_filename:
                return [], "No file selected."
            df = load_table(preset_filename)

        # Same name-detection rule as the name selector: an object-dtype first
        # column that is not purely numeric strings carries the series names.
        if df.shape[1] > 0 and df.iloc[:, 0].dtype == object and not df.iloc[:, 0].str.isnumeric().all():
            all_names = df.iloc[:, 0].tolist()
            data_only = df.iloc[:, 1:].astype(float)
        else:
            all_names = [f"Series {i}" for i in range(len(df))]
            data_only = df.astype(float)

        chosen = selected_names or []
        mask = [name in chosen for name in all_names]
        if not any(mask):
            return [], "No timeseries chosen to plot."

        # The forecast tensor is precomputed, one entry per table row; a table
        # with a different row count cannot be matched to it.
        n_forecasts = _forecast_tensor.shape[0]
        if len(all_names) != n_forecasts:
            raise ValueError(
                f"the data has {len(all_names)} series but the precomputed "
                f"forecast covers {n_forecasts}"
            )

        filtered_names = [name for name, keep in zip(all_names, mask) if keep]
        # Explicit bool tensor so the selection is unambiguously a row mask.
        out = _forecast_tensor[torch.tensor(mask, dtype=torch.bool)]
        inp = torch.tensor(data_only.iloc[mask, :].values)

        gallery_images = [
            plot_forecast_image(inp[i], out[i], filtered_names[i])
            for i in range(inp.shape[0])
        ]
        return gallery_images, ""
    except Exception as e:
        # Report the failure as-is; a fixed format hint here duplicated the
        # loader's own message and was misleading for non-format errors.
        return [], f"Error: {e}"


# ----------------------------
# Gradio layout: two columns + instructions
# ----------------------------

# Two-column UI: data selection / filtering controls on the left, the forecast
# gallery and formatting instructions on the right. This variant starts with
# no preset selected.
with gr.Blocks() as demo:
    gr.Markdown("# üìà Stock Forecast Viewer üìä")
    gr.Markdown("Upload data or choose a preset, filter by name, then click Plot.")

    with gr.Row():
        # Left column: controls
        with gr.Column():
            gr.Markdown("## Data Selection")
            file_input = gr.File(
                label="Upload CSV / XLSX / PARQUET",
                file_types=[".csv", ".xls", ".xlsx", ".parquet"]
            )
            preset_dropdown = gr.Dropdown(
                label="Or choose a preset:",
                choices=["stocks_data_noindex.csv", "stocks_data.csv"],
                # No preset by default. A placeholder string here would not be
                # one of `choices`, and being truthy it would skip the
                # handlers' "no file selected" guard and be loaded as a path.
                value=None
            )

            gr.Markdown("## Search / Filter")
            search_box = gr.Textbox(placeholder="Type to filter (e.g. 'AMZN')")
            filter_checkbox = gr.CheckboxGroup(
                choices=[], value=[], label="Select which timeseries to show"
            )

            with gr.Row():
                check_all_btn = gr.Button("‚úÖ Check All")
                uncheck_all_btn = gr.Button("‚ùé Uncheck All")

            plot_button = gr.Button("‚ñ∂Ô∏è Plot Forecasts")
            errbox = gr.Textbox(label="Error Message", interactive=False)

        # Right column: gallery + instructions
        with gr.Column():
            gr.Markdown("## Forecast Gallery")
            gallery = gr.Gallery()

            # Instruction text below gallery
            gr.Markdown("## Instructions")
            gr.Markdown(
                """
                **How to format your data:**
                - Your file must be a table (CSV, XLS, XLSX, or Parquet).
                - **One row per timeseries.** Each row is treated as a separate series.
                - If you want to **name** each series, put the name as the first value in **every** row:
                  - Example (CSV):  
                    `AAPL, 120.5, 121.0, 119.8, ...`  
                    `AMZN, 3300.0, 3310.5, 3295.2, ...`  
                  - In that case, the first column is not numeric, so it will be used as the series name.
                - If you do **not** want named series, simply leave out the first column entirely and have all values numeric:
                  - Example:  
                    `120.5, 121.0, 119.8, ...`  
                    `3300.0, 3310.5, 3295.2, ...`  
                  - Then every row will be auto-named ‚ÄúSeries 0, Series 1, ‚Ä¶‚Äù in order.
                - **Consistency rule:** Either all rows have a non-numeric first entry for the name, or none do. Do not mix.
                - The rest of the columns (after the optional name) must be numeric data points for that series.
                - You can filter by typing in the search box. Then check or uncheck individual names before plotting.
                - Use ‚ÄúCheck All‚Äù / ‚ÄúUncheck All‚Äù to quickly select or deselect every series.
                - Finally, click **Plot Forecasts** to view the quantile forecast for each selected series.
                """
            )

    # Full (unfiltered) list of series names for the currently loaded table.
    names_state = gr.State([])

    # When file or preset changes, update names
    file_input.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )
    preset_dropdown.change(
        fn=extract_names_and_update,
        inputs=[file_input, preset_dropdown],
        outputs=[filter_checkbox, names_state]
    )

    # When search term changes, filter names
    search_box.change(
        fn=filter_names,
        inputs=[search_box, names_state],
        outputs=filter_checkbox
    )

    # Check All / Uncheck All
    check_all_btn.click(fn=check_all, inputs=names_state, outputs=filter_checkbox)
    uncheck_all_btn.click(fn=uncheck_all, inputs=names_state, outputs=filter_checkbox)

    # Plot button
    plot_button.click(
        fn=display_filtered_forecast,
        inputs=[file_input, preset_dropdown, filter_checkbox],
        outputs=[gallery, errbox]
    )

demo.launch()