Performance Plots
[3]:
import pathlib

from minimel import evaluate

# Select the Wikipedia dump to evaluate. Exactly one line must be active:
# the paths below are built from `wiki`, `version` and `langcode`, so leaving
# all of them commented out makes this cell fail with a NameError on a fresh
# kernel. NOTE(review): simplewiki is picked here only because it is the first
# entry in the list; switch to the dump you actually want to plot.
wiki, version, langcode = 'simplewiki', '20211120', 'en'
# wiki, version, langcode = 'tawiki', '20220301', 'ta'
# wiki, version, langcode = 'fawiki', '20220301', 'fa'
# wiki, version, langcode = 'trwiki', '20220301', 'tr'
# wiki, version, langcode = 'nlwiki', '20220301', 'nl'
# wiki, version, langcode = 'arwiki', '20220301', 'ar'
# wiki, version, langcode = 'srwiki', '20220301', 'sr'
# wiki, version, langcode = 'eswiki', '20220301', 'es'
# wiki, version, langcode = 'jawiki', '20220301', 'ja'
# wiki, version, langcode = 'dewiki', '20220301', 'de'

# When True, the ablation table built in a later cell keeps the runs whose
# file name carries a feature suffix; when False it keeps the runs without one.
usefeats = False

# All experiment outputs (*.tsv) of the selected dump, in sorted order.
fpreds = sorted(pathlib.Path(f"wiki/{wiki}-{version}/experiments/").glob("*.tsv"))
# Mewsli-9 file for the selected language (presumably the gold annotations).
fgold = pathlib.Path(f"evaluation/Mewsli-9/{langcode}.tsv")

# `evaluate` is project code; later cells read the ("micro", "fscore") column
# of its result and parse its index as prediction names.
df = evaluate(fgold, *fpreds)
Evaluating: 100%|██████████| 2/2 [00:00<00:00, 31.01it/s]
[11]:
import pandas as pd

# Load ablations: split each entry of `df.index` (presumably the prediction
# file names) into its settings. Raw strings keep the regex backslashes intact
# without relying on Python passing unknown escapes such as "\w" through
# unchanged (a SyntaxWarning on recent Python versions).
abl = pd.Series(df.index).str.extract(
    r"pred-mewsli."
    r"(?P<kind>\w+)"
    r"-?(?P<q>q0.25|q0.5|q1)?"
    r"(?P<f>.feat-clean-q1\.p[.0-9]+?)?"
    r"(?P<b>.\d+b)?"
    r"(?P<c>\..*?)?$"
)

# `c`: for "base"/"upper" rows keep the raw suffix string; for every other
# kind collapse it to a boolean "has a suffix" flag.
abl["c"] = abl["c"].fillna("")
abl["c"] = abl.apply(
    lambda row: (row["c"] != "") if row["kind"] not in ["base", "upper"] else row["c"],
    axis=1,
)
# `b`: drop the leading separator character and the trailing "b", then store
# the number as a nullable integer (rows without a bit count stay <NA>).
abl["b"] = abl["b"].str[1:-1].astype("float").astype("Int32")
abl["f"] = abl["f"].fillna("")
abl["kind"] = abl["kind"].replace("stem", "clean-stem")

# Micro-averaged F-score, indexed by the parsed settings (kind, q, f, b, c).
f1 = df.set_index(pd.MultiIndex.from_frame(abl))[("micro", "fscore")].rename("f1")

kind_values = f1.index.get_level_values("kind")
feat_values = f1.index.get_level_values("f")
# Keep either the feature runs or the non-feature runs, depending on `usefeats`.
feat_mask = (feat_values != "") if usefeats else (feat_values == "")
# The "base" and "upper" reference rows are always kept.
reference_mask = kind_values.isin(["base", "upper"])

f1 = f1[reference_mask | feat_mask].droplevel("f")
f1
[11]:
kind q b c
NaN NaN <NA> False 0.479931
Name: f1, dtype: float64
[7]:
# Show the F1 scores of the rows whose first index level ("kind") is "base",
# ordered by the remaining index levels.
baseline_scores = f1.loc["base"]
baseline_scores.sort_index()
[7]:
q b c
NaN NaN 0.803648
.clean-q1 0.838007
.clean-stem-q1 0.828918
Name: f1, dtype: float64
[3]:
import seaborn as sns
import matplotlib.pyplot as plt
# PLOTTING
def _draw_reference_lines(grid, scores, prefix, linewidth, handles, labels):
    """Draw horizontal reference lines on the facets and register legend entries.

    Parameters
    ----------
    grid : the seaborn grid returned by ``sns.relplot``.
    scores : Series indexed by (q, b, c) holding one score per reference run.
    prefix : text put in front of each legend label (e.g. ``"base"``).
    linewidth : line width passed to ``axhline``.
    handles, labels : legend lists, extended in place with one entry per
        reference run that was drawn on at least one facet.
    """
    # `.items()` replaces `Series.iteritems()`, which pandas 2.0 removed.
    for (_, _, variant), score in scores.sort_index().items():
        is_stem = "stem" in variant
        # Drops the last three characters of the variant (the displayed
        # baseline index shows values such as ".clean-q1").
        label = f"{prefix}{variant[:-3]}"
        line = None
        for ax in grid.axes.flatten():
            # Runs without a variant go on every facet; the others only on
            # facets whose title agrees on "stem".
            if (not variant) or (is_stem == ("stem" in ax.get_title())):
                line = ax.axhline(
                    score,
                    ls="--" if is_stem else "-",
                    linewidth=linewidth,
                    color="C1" if variant else "C0",
                    label=label,
                )
        # Only add a legend entry when a line was actually drawn, so a stale
        # handle from a previous run is never reused.
        if line is not None:
            handles.append(line)
            labels.append(label)


# Ablation runs only: the "base" and "upper" rows are drawn as reference lines
# below. A boolean mask (instead of `index.drop`) does not fail when one of the
# two labels is absent.
is_reference = f1.index.get_level_values("kind").isin(["base", "upper"])
data = (
    f1[~is_reference]
    .reset_index()
    .rename(
        columns={
            "c": "Fallback",
            "q": "Quantile",
            "b": "Bits",
            "kind": "Pre-processing",
        }
    )
    .sort_values(["Pre-processing", "Quantile"], ascending=[True, False])
)

# `relplot` creates its own figure, so no separate `plt.figure(...)` is opened
# here (it only produced an empty extra figure in the output).
g = sns.relplot(
    kind="line",
    data=data,
    x="Bits",
    y="f1",
    hue="Fallback",
    col="Pre-processing",
    style="Pre-processing",
    size="Quantile",
    alpha=0.6,
)

# Start from seaborn's own legend entries and extend them by hand.
handles = g._legend.get_lines()
labels = [t.get_text() for t in g._legend.get_texts()]

# The first seaborn handle is reused as a placeholder for spacer and section
# header rows; rows whose handle is not visible are set in bold further down.
handles.append(handles[0])
labels.append("")
handles.append(handles[0])
labels.append("Baseline")
_draw_reference_lines(g, f1.loc["base"], "base", 1, handles, labels)

handles.append(handles[0])
labels.append("Upper Bound")
_draw_reference_lines(g, f1.loc["upper"], "upper", 2, handles, labels)

# Replace seaborn's legend with the combined one, placed right of the axes.
g._legend.remove()
legend = plt.legend(
    handles=handles, labels=labels, loc="center left", bbox_to_anchor=(1, 0.5)
)
for handle, text in zip(handles, legend.get_texts()):
    if not handle.get_visible():
        text.set_fontweight("bold")

# Add a little headroom above the curves. This names the last facet explicitly
# rather than relying on a loop variable left over from earlier code.
last_ax = g.axes.flatten()[-1]
y0, y1 = last_ax.get_ylim()
last_ax.set_ylim([y0, y1 + 0.02])
def sizeof_fmt(num, suffix="B"):
    """Format ``num`` as a short size string using powers of 1024.

    The value is divided by 1024 until its magnitude is below 1024, then
    rendered without decimals (padded to width 3) followed by the matching
    prefix letter and ``suffix``. Values too large for the "Z" prefix are
    rendered with one decimal and a "Y" prefix.
    """
    prefixes = ("", "K", "M", "G", "T", "P", "E", "Z")
    value = num
    step = 0
    while step < len(prefixes):
        if abs(value) < 1024.0:
            return f"{value:3.0f}{prefixes[step]}{suffix}"
        value = value / 1024.0
        step += 1
    return f"{value:.1f}Y{suffix}"
# Relabel the x axis: each tick shows the bit count and, on a second line,
# sizeof_fmt(2**bits * 16) (presumably the resulting model size in bytes).
for ax in g.axes.flatten():
    xlim = ax.get_xlim()
    ticks = ax.get_xticks()
    tick_labels = [f"{int(x)}\n{sizeof_fmt(2**x*16)}" for x in ticks]
    # Pin the tick positions before assigning text labels; calling
    # set_xticklabels alone triggers "FixedFormatter should only be used
    # together with FixedLocator" and can mislabel ticks after a redraw.
    ax.set_xticks(ticks)
    ax.set_xticklabels(tick_labels)
    # Fixing the ticks may widen the view to include them; restore it.
    ax.set_xlim(xlim)
    ax.set_xlabel("Bits / Model Size")

from minimel import code_name

# Add / override display names for these two language codes.
code_name["fa"] = "Persian"
code_name["ja"] = "Japanese"
# Fall back to the raw code when no display name is registered, instead of
# failing on `None.title()`.
langname = code_name.get(langcode, langcode).title()

g.figure.suptitle(
    f"F$_1$ score per bits of model, {wiki}-{version} on Mewsli-9 in {langname}", y=1.05
)

# Export the current figure as PGF, using pdflatex as the TeX system.
plt.rcParams["pgf.texsystem"] = "pdflatex"
plt.savefig(f"paper/fig/{wiki}-{version}.pgf", bbox_inches="tight")
None
/tmp/ipykernel_18305/2362728954.py:60: UserWarning: FixedFormatter should only be used together with FixedLocator
ax.set_xticklabels(l)
<Figure size 864x576 with 0 Axes>