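Summary of Prediction Results

Histograms of each curtailment classifier's predicted probability (the `True` column of its prediction file), coloured by the actual label and grouped by event (0.01, 0.03, 0.05, 0.10). Each event's grid is also exported as a PNG to `../docs/figures/`.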

[65]:
import altair as alt
import pandas as pd
from collections import defaultdict
from src.conf import settings
from src.models.curtailment_classifier import registry

# Folder containing the `predictions-<model>-*.parquet` files read by the plotting loop.
INPUT_DIR = settings.DATA_DIR / "processed/results/"
# TODO: this cell previously ended with an unfinished `RAW_DATA = settings.DATA_DIR /`,
# a SyntaxError that prevents the cell from running. RAW_DATA is not used by any
# visible cell; restore it with the intended path if it turns out to be needed.
[22]:
# Altair embeds chart data inline and raises MaxRowsError above its default row
# cap; lift the cap so full prediction files can be plotted. The trailing
# semicolon suppresses the DataTransformerRegistry repr the call would echo.
alt.data_transformers.disable_max_rows();
[22]:
DataTransformerRegistry.enable('default')
[66]:
def _event_from_path(prediction_file):
    """Return the event label encoded in a prediction file's name.

    Assumes names shaped like ``predictions-<model>-<x>_<y>_<event>.parquet``
    (TODO confirm against the writer): the event is the third ``_``-separated
    field of the third ``-``-separated field of the file stem.
    """
    # ``stem`` removes only the final ``.parquet`` suffix, so dots inside the
    # event label (e.g. "0.01") survive. A manual rsplit(".") would instead
    # cut the label at its decimal point whenever the suffix is not in the same
    # segment.
    return prediction_file.stem.split("-")[2].split("_")[2]


# event label -> list of probability histograms, one per model, in registry order.
charts = defaultdict(list)
for model_cls in registry:
    model = model_cls.__name__
    # sorted() makes the chart order reproducible regardless of filesystem listing order.
    for prediction_file in sorted(INPUT_DIR.glob(f"predictions-{model}-*.parquet")):
        event = _event_from_path(prediction_file)
        predictions = pd.read_parquet(prediction_file)
        chart = alt.Chart(predictions, title=f"{model} : {event}").mark_bar().encode(
            # "True" presumably holds the predicted probability of the positive class — TODO confirm.
            alt.X("True", bin=alt.Bin(step=.05), title="Probability"),
            alt.Y("count(True)", title="Count"),
            alt.Color("actual", legend=alt.Legend(title="True Value")),
        )
        charts[event].append(chart)
[78]:
def plot_event(event, ncols=2):
    """Arrange every collected histogram for one event into a grid.

    Parameters
    ----------
    event : str
        Event label as parsed from the prediction file names (e.g. "0.01").
    ncols : int, default 2
        Number of charts per row. With four charts the default gives the
        same 2 x 2 layout as before.

    Returns
    -------
    alt.VConcatChart
        Rows of horizontally concatenated charts, in ``charts[event]`` order.

    Raises
    ------
    ValueError
        If ``ncols`` is less than 1 or no charts were collected for ``event``.
    """
    if ncols < 1:
        raise ValueError(f"ncols must be >= 1, got {ncols}")
    # .get() avoids inserting an empty entry into the defaultdict on a typo'd event.
    event_charts = charts.get(event, [])
    if not event_charts:
        raise ValueError(f"No charts for event {event!r}; available: {sorted(charts)}")
    rows = [
        alt.hconcat(*event_charts[start:start + ncols])
        for start in range(0, len(event_charts), ncols)
    ]
    return alt.vconcat(*rows)
[86]:
# Event 0.01: build the per-model grid, export a static PNG for the docs
# (path relative to the notebook's working directory), then show it inline.
c = plot_event(event="0.01")
c.save(fp="../docs/figures/0-01.png", webdriver="firefox")
c
/home/ttu/.local/share/virtualenvs/CaReCur-b3qbtQ7S/lib/python3.8/site-packages/selenium/webdriver/remote/webdriver.py:381: UserWarning: find_element_by_* commands are deprecated. Please use find_element() instead
  warnings.warn("find_element_by_* commands are deprecated. Please use find_element() instead")
[86]:
[87]:
# Event 0.03: build the per-model grid, export a static PNG for the docs
# (path relative to the notebook's working directory), then show it inline.
c = plot_event(event="0.03")
c.save(fp="../docs/figures/0-03.png", webdriver="firefox")
c
/home/ttu/.local/share/virtualenvs/CaReCur-b3qbtQ7S/lib/python3.8/site-packages/selenium/webdriver/remote/webdriver.py:381: UserWarning: find_element_by_* commands are deprecated. Please use find_element() instead
  warnings.warn("find_element_by_* commands are deprecated. Please use find_element() instead")
[87]:
[88]:
# Event 0.05: build the per-model grid, export a static PNG for the docs
# (path relative to the notebook's working directory), then show it inline.
c = plot_event(event="0.05")
c.save(fp="../docs/figures/0-05.png", webdriver="firefox")
c
/home/ttu/.local/share/virtualenvs/CaReCur-b3qbtQ7S/lib/python3.8/site-packages/selenium/webdriver/remote/webdriver.py:381: UserWarning: find_element_by_* commands are deprecated. Please use find_element() instead
  warnings.warn("find_element_by_* commands are deprecated. Please use find_element() instead")
[88]:
[89]:
# Event 0.10: build the per-model grid, export a static PNG for the docs
# (path relative to the notebook's working directory), then show it inline.
c = plot_event(event="0.10")
c.save(fp="../docs/figures/0-10.png", webdriver="firefox")
c
/home/ttu/.local/share/virtualenvs/CaReCur-b3qbtQ7S/lib/python3.8/site-packages/selenium/webdriver/remote/webdriver.py:381: UserWarning: find_element_by_* commands are deprecated. Please use find_element() instead
  warnings.warn("find_element_by_* commands are deprecated. Please use find_element() instead")
[89]:
[ ]: