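Inspect Model Features

Audit the trained model for one ambiguous surface form and, for each entity that appears in the audit output, list the features whose weights deviate most from that feature's average across entities.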

[2]:
!pip install --user -e .
# Surface form to audit, and the Wikipedia dump the experiment files belong to.
# Other values used before: surface "madrid"; wiki "simplewiki-20211120" or
# "eswiki-20220301".
surface, wiki = "utrecht", "nlwiki-20220301"

# Trained VW model and the data file it is audited against.
modelfile = f"../data/wiki/{wiki}/experiments/clean-q0.25.24b.vw"
datafile = f"../data/wiki/{wiki}/experiments/clean-q0.25.dat"

import subprocess
import sys

# Run `minimel audit` for one surface form. From every tab-indented output line,
# collect the tab-separated fields that start with "l^" (presumably the `l`
# namespace; the run log mentions quadratic pairs "ls").
# sys.executable is the kernel's own interpreter, so the audit sees the packages
# installed for this kernel; a bare "python" may resolve to another install.
args = [sys.executable, "-m", "minimel", "audit", modelfile, datafile, surface]
feats = set()
# stderr is not captured, so the tool's progress log still shows in the output.
with subprocess.Popen(args, stdout=subprocess.PIPE, encoding="utf8") as process:
    for line in process.stdout:
        line = line.rstrip()
        if line.startswith("\t"):
            feats.update(f for f in line[1:].split("\t") if f.startswith("l^"))
# Popen.__exit__ waits for the process, so returncode is set here. Fail loudly
# instead of continuing with an empty or partial feature set.
if process.returncode != 0:
    raise subprocess.CalledProcessError(process.returncode, args)
len(feats)
Obtaining file:///home/jupyter-benno/minimEL
  Preparing metadata (setup.py) ... done
Installing collected packages: minimel
  Attempting uninstall: minimel
    Found existing installation: minimel 0.1
    Uninstalling minimel-0.1:
      Successfully uninstalled minimel-0.1
  Running setup.py develop for minimel
Successfully installed minimel-0.1
creating quadratic features for pairs: ls
only testing
using no cache
Reading datafile = none
num sources = 0
Num weight bits = 24
learning rate = 0.5
initial_t = 0
power_t = 0.5
Enabled learners: gd, scorer-identity, csoaa_ldf-prob, shared_feature_merger
Input label = CS
Output pred = SCALARS
average  since         example        example        current        current  current
loss     last          counter         weight          label        predict features
0.000000 0.000000            1            1.0          known         221653      208
0.000000 0.000000            2            2.0        unknown              0     3172
0.000000 0.000000            4            4.0        unknown              0     1664
0.000000 0.000000            8            8.0        unknown              0     2418
0.000000 0.000000           16           16.0        unknown              0     1794
0.000000 0.000000           32           32.0        unknown              0     1898
0.000000 0.000000           64           64.0          known            776      910
0.007812 0.015625          128          128.0        unknown              0     8294
0.003906 0.000000          256          256.0        unknown              0    10504
0.003906 0.003906          512          512.0        unknown              0     1248
{803: 588, 776: 229, 707767: 94, 261716: 38, 575655: 21, 221653: 14, 2679365: 11, 18108: 4, 24680: 2}

finished run
number of examples = 1001
weighted example sum = 1001.000000
weighted label sum = 0.000000
average loss = 0.008991
average multiclass log loss = 0.449134
total feature number = 2910232
[2]:
253045
[3]:
import io
import pandas as pd

# Split every audit feature on the separators "*", ":", "^" and "=" and keep
# fields 2, 4 and 7 as wid / feat / weight. Rows with any missing field are dropped.
# NOTE(review): the field positions are tied to the `minimel audit` output
# format — confirm if that format changes.
df = (
    pd.read_csv(
        io.StringIO("\n".join(feats)),
        sep=r"\*|:|\^|=",  # raw string: "\*" and "\^" are invalid str escapes
        header=None,
        engine="python",
        usecols=[2, 4, 7],
        names=["wid", "feat", "weight"],
    )
    .dropna()
    # Entity ids are formatted as "Q{wid}" for the database lookup, so keep them
    # integral even if NaN rows forced a float column before dropna().
    .astype({"wid": "int64"})
)

# Normalize weights: centre each feature's weight on its mean over all
# entities, then flip the sign.
# NOTE(review): presumably the flip is because the model predicts costs (the
# run log shows "Input label = CS"), so positive values favour an entity — confirm.
df["weight"] = -(df["weight"] - df.groupby("feat")["weight"].transform("mean"))
[4]:
select_ents = set(df["wid"].unique())

import sqlite3
from contextlib import closing

# Map each entity id to a Wikipedia title from the index database. Entities
# without a mapping row fall back to their "Q<id>" string so later tables still render.
# closing() is needed because sqlite3's own context manager only commits and
# does not close the connection.
ent_label = {}
with closing(sqlite3.connect(f"../data/wiki/{wiki}/index_{wiki}.db")) as con:
    for e in select_ents:
        # Bind the id as a parameter instead of interpolating it: a double-quoted
        # "Q..." is an identifier in SQL, and SQLite only falls back to treating
        # it as a string when no such column exists (some builds disable this).
        row = con.execute(
            "select wikipedia_title from mapping where wikidata_id = ? limit 1",
            (f"Q{e}",),
        ).fetchone()
        ent_label[e] = row[0] if row is not None else f"Q{e}"

print(dict(sorted(ent_label.items())))
{776: 'Utrecht_(provincie)', 803: 'Utrecht_(stad)', 18108: 'Utrecht_(Zuid-Afrika)', 24680: 'FC_Utrecht', 221653: 'Universiteit_Utrecht', 261716: 'Aartsbisdom_Utrecht_(rooms-katholiek)', 575655: 'Station_Utrecht_Centraal', 707767: 'Sticht_Utrecht', 847384: 'Utrechts_Conservatorium', 2012748: 'Vechtsebanen', 2193594: 'Hr.Ms._Utrecht_(1901)', 2679365: 'Heerlijkheid_Utrecht', 85308316: 'BVC_Utrecht'}
[5]:
def topfeat(gr, n=10):
    """Return the ``n`` features with the largest absolute weight in one group.

    Intended for ``df.groupby("wid").apply(topfeat)``: ``gr`` holds the rows of
    a single entity with columns ``feat`` and ``weight`` (plus ``wid``).

    Parameters
    ----------
    gr : pandas.DataFrame
        Rows for one entity. The ``wid`` column is optional, so this also works
        when pandas excludes grouping columns from ``apply``.
    n : int, default 10
        Maximum number of features to keep.

    Returns
    -------
    pandas.DataFrame
        Columns ``feat`` and ``weight``, ordered by descending ``|weight|``, with
        rows whose weight is NaN removed and a fresh integer index.
    """
    gr = gr.drop(columns="wid", errors="ignore").set_index("feat").dropna()
    # Sort positionally instead of re-selecting by label with ``.loc``: a
    # feature name that occurs twice would otherwise be duplicated in the output.
    gr = gr.sort_values(
        "weight", key=lambda w: w.abs(), ascending=False, kind="stable"
    )
    return gr.head(n).reset_index()


# Top features per entity, reshaped so that each entity contributes a "feat"
# row and a "weight" row, with the feature rank as columns.
tops = (
    df.groupby("wid")
    .apply(topfeat)
    .swaplevel()
    .unstack()
    .swaplevel(axis=1)
    .sort_index(axis=1)
    .T
    # Replace entity ids with their titles. rename() tolerates duplicate or
    # missing titles, whereas MultiIndex.set_levels() would raise.
    .rename(index=ent_label, level=0)
    .rename_axis(index=[None, None])
)

import seaborn as sns

# Diverging palette: hue 0 for negative and 230 for positive values,
# saturation 90, lightness 60.
cmap = sns.diverging_palette(0, 230, 90, 60, as_cmap=True)

# head(10) keeps the first 10 rows, i.e. the feat/weight row pairs of the
# first five entities. Only the weight rows are coloured, on one symmetric
# scale for the whole preview (axis=None, vmin=-vmax), so the palette's neutral
# centre marks zero and colours are comparable across columns. The default
# per-column scaling would centre every column on its own midpoint instead.
weight_rows = pd.IndexSlice[pd.IndexSlice[:, "weight"], :]
preview = tops.head(10)
vmax = preview.loc[weight_rows].astype(float).abs().max().max()
preview.style.background_gradient(
    cmap=cmap, subset=weight_rows, axis=None, vmin=-vmax, vmax=vmax
).format(precision=2)
[5]:
    0 1 2 3 4 5 6 7 8 9
Utrecht_(provincie) feat provincie wegverkeer stuurde baarn vormentaal samenspraak filosofie vakantieoord knutselen cartoonist
weight 1.98 -1.26 -1.07 0.87 0.86 -0.86 0.84 0.83 0.80 -0.80
Utrecht_(stad) feat blijkt utrecht provincie belangrijkste niedersächsisch anc stad museum verplegings ondeelbaar
weight -1.74 1.38 -1.18 -1.10 -1.06 -1.04 0.88 0.87 -0.87 -0.84
Utrecht_(Zuid-Afrika) feat piraten pionier balmat fusieplan polemieken ontario uitreiking university anna zalige
weight -1.19 -1.11 -0.89 -0.83 -0.80 0.79 -0.76 0.76 0.75 -0.75
FC_Utrecht feat universiteit perks brandlaag charter geslagen draait krugersdorp temperature kessels behandelen
weight 1.03 -1.01 -0.84 0.81 -0.74 0.71 -0.70 0.66 -0.66 -0.65
Universiteit_Utrecht feat orde kaunda romanpersonage opgelegd work vilain pestepidemie wet voorafgegaan yntema
weight -1.30 0.92 -0.87 -0.84 -0.82 -0.82 0.80 -0.78 -0.72 0.70
[ ]: