Quick start

Installation

To use LangSpace, first install it using pip:

(.venv) $ pip install langspace

Probing an LM-VAE

Here’s a basic example of how to perform a disentanglement evaluation and an interpolation probe on an LM-VAE model trained with LangVAE (or use our example Colab notebook):

import torch
import nltk
from langvae import LangVAE
from saf_datasets import EntailmentBankDataSet
from langspace.probe import DisentanglementProbe
from langspace.metrics.disentanglement import DisentanglementMetric as Metric
from langspace.probe import InterpolationProbe
from langspace.metrics.interpolation import InterpolationMetric as InterpMetric
from saf.importers import ListImporter

# Load the annotated EntailmentBank data set from saf_datasets.
dataset = EntailmentBankDataSet.from_resource("pos+lemma+ctag+dep+srl#expl_only-noreps")

# Reuse the data set's "srl" annotation entry under the derived "srl_f" key filled in below.
annotations = {"srl_f": dataset.annotations["srl"]}

# Each token's 'srl' annotation is a list of role labels (one per phrase in the sentence).
# 'srl_f' collapses that list into a single label: the first label other than "O" when the
# list holds more than one distinct label, otherwise simply the first entry of the list.
for sent in dataset:
    for token in sent.tokens:
        srl_labels = token.annotations["srl"]
        if len(set(srl_labels)) > 1:
            # More than one distinct label guarantees at least one of them is not "O".
            first_label = next(lbl for lbl in srl_labels if lbl != "O")
        else:
            first_label = srl_labels[0]
        token.annotations["srl_f"] = first_label


# Load the explanation LM-VAE used for generation from the HuggingFace Hub
# and switch it to evaluation mode.
model = LangVAE.load_from_hf_hub("neuro-symbolic-ai/eb-langcvae-bert-base-cased-gpt2-srl-l128")
model.eval()

# With a CUDA device available: move encoder and decoder to it first, then call
# init_pretrained_model() on each (encoder before decoder in both steps).
# NOTE(review): init_pretrained_model() is only reached on the CUDA path; confirm
# that CPU-only runs do not need it.
if torch.cuda.is_available():
    vae_parts = (model.encoder, model.decoder)
    for part in vae_parts:
        part.to("cuda")
    for part in vae_parts:
        part.init_pretrained_model()


# Probing latent disentanglement
# Generative factors: each key names a factor and maps to the SRL role labels grouped under it.
gen_factors = {
    "direction": ["ARGM-DIR"],
    "because": ["ARGM-CAU"],
    "purpose": ["ARGM-PRP", "ARGM-PNC", "ARGM-GOL"],
    "more": ["ARGM-EXT"],
    "location": ["ARGM-LOC"],
    "argument": ["ARG0", "ARG1", "ARG2", "ARG3", "ARG4"],
    "manner": ["ARGM-MNR"],
    "can": ["ARGM-MOD"],
    "argm-prd": ["ARGM-PRD"],
    "empty": ["O"],
    "negation": ["ARGM-NEG"],
    "verb": ["V"],
    "if-then": ["ARGM-ADV", "ARGM-DIS"],
    "time": ["ARGM-TMP"],
    # Was "C-AGR2": a transposition typo that followed none of the ARG label names.
    "C-ARG": ["C-ARG1", "C-ARG0", "C-ARG2"],
}

# Change SRL labels to match dataset annotation vocabulary:
# every label except "O" receives an "I-" prefix.
gen_factors = {
    factor: [lbl if lbl == "O" else "I-" + lbl for lbl in labels]
    for factor, labels in gen_factors.items()
}


# Disentanglement metrics requested from the probe.
metrics = [
    Metric.Z_DIFF,
    Metric.Z_MIN_VAR,
    Metric.MIG,
    Metric.INFORMATIVENESS,
    Metric.COMPLETENESS,
]

# Build the probe over the model and data set, then ask it for its report.
disentang_probe = DisentanglementProbe(
    model,
    dataset,
    sample_size=1000,
    metrics=metrics,
    gen_factors=gen_factors,
    annotations=annotations,
)
disentang_report = disentang_probe.report()


# Probing latent interpolation
# Fetch NLTK's 'punkt_tab' resource before running the probe.
nltk.download('punkt_tab')

# (sentence, SRL labels) pairs. Both strings are split on whitespace and paired up
# token by token, so every sentence needs exactly one label per token.
sentences = [
    ("humans require freshwater for survival", "B-ARG0 B-V B-ARG1 B-ARGM-PRP I-ARGM-PRP"),
    ("animals require food to survive", "B-ARG0 B-V B-ARG1 B-ARGM-PRP I-ARGM-PRP"),
    ("the sun is in the northern hemisphere", "B-ARG0 I-ARG0 B-V B-ARGM-LOC I-ARGM-LOC I-ARGM-LOC I-ARGM-LOC"),
    # The label string previously had 8 labels for 10 tokens, so zip() silently dropped
    # "/ plants"; the trailing ARGM-PRP span is extended to cover those two tokens.
    ("food is a source of energy for animals / plants",
     "B-ARG0 B-V B-ARG2 I-ARG2 I-ARG2 I-ARG2 B-ARGM-PRP I-ARGM-PRP I-ARGM-PRP I-ARGM-PRP"),
]

# Pair each token with its label, failing loudly on a length mismatch instead of
# letting zip() truncate the sentence.
annotated_sents = []
for text, labels in sentences:
    toks = text.split()
    lbls = labels.split()
    if len(toks) != len(lbls):
        raise ValueError(
            f"Token/label count mismatch ({len(toks)} tokens, {len(lbls)} labels) in: {text!r}"
        )
    annotated_sents.append(list(zip(toks, lbls)))

# Import the (token, label) lists as sentences carrying the "srl_f" annotation.
sentences_ds = ListImporter(annotations=["srl_f"])(annotated_sents).sentences

# Sentence pairs handed to the interpolation probe: (0, 1) and (2, 3).
interp_dataset = [(sentences_ds[0], sentences_ds[1]), (sentences_ds[2], sentences_ds[3])]

interp_report = InterpolationProbe(
    model,
    interp_dataset,
    eval=[InterpMetric.SMOOTHNESS],
    annotations=annotations,
).report()