{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "## 1. Prepare Python environment"
      ],
      "metadata": {
        "id": "lA8ZuoocCztC"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install gradio --quiet\n",
        "!pip install calplot --quiet"
      ],
      "metadata": {
        "id": "tCDgWf53C3Ht"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Load required Python packages\n",
        "import os, re, requests, logging, calplot\n",
        "import dask.dataframe as dd\n",
        "import multiprocessing as mp\n",
        "from tqdm import tqdm\n",
        "from itertools import product\n",
        "import pandas as pd\n",
        "import gradio as gr\n",
        "import plotly.express as px"
      ],
      "metadata": {
        "id": "SH0x9ZEXC5yV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Function to download remote file to the disk\n",
        "def urlDownload(urlLink, showProgress = False):\n",
        "  with requests.get(urlLink, stream=True) as r:\n",
        "    fileSize = int(r.headers.get('Content-Length'))\n",
        "    fileName = r.headers.get('Content-Disposition').split(\"filename=\")[1]\n",
        "    if not os.path.exists(fileName) or os.path.getsize(fileName) != fileSize:\n",
        "      block_size = 1024\n",
        "      if showProgress:\n",
        "        print(f\"Downloading {fileName}\")\n",
        "        progress_bar = tqdm(total=fileSize, unit='iB', unit_scale=True)\n",
        "      with open(fileName, 'wb') as file:\n",
        "        for data in r.iter_content(block_size):\n",
        "          if showProgress:\n",
        "            progress_bar.update(len(data))\n",
        "          file.write(data)\n",
        "      if showProgress:\n",
        "        progress_bar.close()\n",
        "    return fileName"
      ],
      "metadata": {
        "id": "7x14_wM0EFYG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 2. Prepare air quality dataset"
      ],
      "metadata": {
        "id": "CBsYINt_Eo3J"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Download the newest data\n",
        "urlLocation = 'https://aqicn.org/data-platform/covid19/report/39374-7694ec07/'\n",
        "csvFile = urlDownload(urlLocation, showProgress=True)\n",
        "\n",
        "# Define the columns to load, read data, skip the first 4 lines\n",
        "selected_cols = ['Date', 'Country', 'City', 'Specie', 'median']\n",
        "newTable = pd.read_csv(csvFile, skiprows=4, usecols=selected_cols)\n",
        "\n",
        "# Create lists of year and quarter names\n",
        "yNames = [str(i) for i in range(2022, 2024)]\n",
        "qNames = [\"Q\" + str(i) for i in range(1, 5)]\n",
        "\n",
        "# Create a data frame with the url locations and year/quarter combinations\n",
        "DF = pd.DataFrame(list(product(yNames, qNames)),columns=['yNames', 'qNames'])\n",
        "DF.insert(loc=0, column='urlLocation', value=urlLocation)\n",
        "\n",
        "# Combine url location and year/quarter combinations into a single column\n",
        "DF = pd.DataFrame({'urlLocations': DF.agg(''.join, axis=1)})\n",
        "\n",
        "# Download legacy data (in parallel)\n",
        "DDF = dd.from_pandas(DF, npartitions=mp.cpu_count())\n",
        "csvFiles = DDF.apply(lambda x : urlDownload(x[0]), axis=1, meta=pd.Series(dtype=\"str\")).compute(scheduler='threads')\n",
        "\n",
        "# Read legacy data files (sequentially)\n",
        "fileNamesQ = [f for f in os.listdir('.') if re.match(r'^.*Q\\d.csv$', f)]\n",
        "oldTable = pd.concat((pd.read_csv(f, skiprows=4, usecols=selected_cols) for f in fileNamesQ), ignore_index=True)\n",
        "\n",
        "# Append old and new (2024) data tables\n",
        "DF = pd.concat([oldTable, newTable])\n",
        "\n",
        "# Leave data of a specific country, rename main column to Value, sort, deduplicate\n",
        "selectEU = DF['Country']=='RO'\n",
        "DF = DF[selectEU].rename(columns={'median': 'Value'})\n",
        "DF = DF.sort_values(by=['Country', 'City', 'Date']).drop_duplicates()\n",
        "\n",
        "# Create a new data table with the info on selected variables\n",
        "selectedVars = ['temperature', 'humidity']\n",
        "selectedIdx = DF['Specie'].isin(selectedVars)\n",
        "dataTableEU = DF[selectedIdx]\n",
        "\n",
        "# Create pivot table, calculate THI for each row, drop rows with missing THI values\n",
        "dataTableTHI = dataTableEU.pivot_table(index=['Date', 'Country', 'City'], columns='Specie', values='Value').reset_index()\n",
        "dataTableTHI[\"THI\"] = 0.8 * dataTableTHI.temperature + (dataTableTHI.humidity/100)*(dataTableTHI.temperature-14.4) + 46.4\n",
        "dataTableTHI = dataTableTHI.dropna(subset=[\"THI\"])\n",
        "dataTableTHI = dataTableTHI[dataTableTHI['Date']>='2022-01-01']"
      ],
      "metadata": {
        "id": "47ykb910EvY7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## 3. Create Gradio dashboard"
      ],
      "metadata": {
        "id": "oZFYW9K3E3n-"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Get unique cities and years\n",
        "myCities = dataTableTHI.City.unique().tolist()\n",
        "myYears = pd.to_datetime(dataTableTHI['Date']).dt.year.unique().tolist()\n",
        "\n",
        "# Set time series plot buttons\n",
        "tsDrill = [dict(count=1, label=\"1m\", step=\"month\", stepmode=\"backward\"),\n",
        "           dict(count=6, label=\"6m\", step=\"month\", stepmode=\"backward\"),\n",
        "           dict(count=1, label=\"YTD\", step=\"year\", stepmode=\"todate\"),\n",
        "           dict(count=1, label=\"1y\", step=\"year\", stepmode=\"backward\"),\n",
        "           dict(step=\"all\")]\n",
        "\n",
        "# Function for Gradio output plots\n",
        "def make_plot(myCity, myYear, myPlot):\n",
        "    if len(myYear)==0:\n",
        "      myYear = myYears\n",
        "    myTable = dataTableTHI[dataTableTHI['City']==myCity][[\"Date\", \"THI\"]]\n",
        "    myTable = myTable[pd.to_datetime(myTable['Date']).dt.year.isin(myYear)]\n",
        "    if myPlot == \"calplot\":\n",
        "      pdTimeSeries = pd.Series(myTable['THI'].values, index=pd.DatetimeIndex(myTable['Date']))\n",
        "      logging.getLogger('matplotlib.font_manager').disabled = True\n",
        "      cp = calplot.calplot(pdTimeSeries, dropzero=True, cmap='coolwarm', yearlabel_kws={'color': 'black', 'fontsize':9})\n",
        "      plot_result = cp[0]\n",
        "    elif myPlot == \"by month\":\n",
        "      myTable['Month'] = pd.to_datetime(myTable['Date']).dt.month\n",
        "      plot_result = px.box(myTable, x=\"Month\", y=\"THI\") # alt.Chart(myTable).mark_boxplot().encode(x='Month', y='THI')\n",
        "    elif myPlot == \"by weekday\":\n",
        "      myTable['DayOfWeek'] = pd.to_datetime(myTable['Date']).dt.dayofweek + 1\n",
        "      plot_result = px.box(myTable, x=\"DayOfWeek\", y=\"THI\") # alt.Chart(myTable).mark_boxplot().encode(x='DayOfWeek', y='THI')\n",
        "    else:\n",
        "      plot_result = px.line(myTable, x='Date', y='THI')\n",
        "      plot_result.update_xaxes(rangeslider_visible=True, rangeselector=dict(buttons=list(tsDrill)))\n",
        "    return plot_result\n",
        "\n",
        "# Design of Gradio dashboard\n",
        "with gr.Blocks() as demo:\n",
        "    citySelect = gr.Dropdown(myCities, label=\"City:\", value=\"Bucharest\")\n",
        "    yearSelect = gr.CheckboxGroup(myYears, label=\"Years:\", value=myYears)\n",
        "    plotSelect = gr.Radio(label=\"Plot type:\", choices=['calplot', 'by month', 'by weekday', 'time series'], value='calplot')\n",
        "    plotVisual = gr.Plot(show_label=False, container=False)\n",
        "    citySelect.change(make_plot, inputs=[citySelect, yearSelect, plotSelect], outputs=[plotVisual])\n",
        "    yearSelect.change(make_plot, inputs=[citySelect, yearSelect, plotSelect], outputs=[plotVisual])\n",
        "    plotSelect.change(make_plot, inputs=[citySelect, yearSelect, plotSelect], outputs=[plotVisual])\n",
        "    demo.load(make_plot, inputs=[citySelect, yearSelect, plotSelect], outputs=[plotVisual])\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    demo.launch(debug=True)"
      ],
      "metadata": {
        "id": "ZqpLilULE-gc"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}