langspace.probe.arithmetic package
Module contents
- class langspace.probe.arithmetic.ArithmeticProbe(model: LangVAE, data: List[Tuple[Sentence, Sentence]], ops: List[ArithmeticOps], annotations: Dict[str, List[str]] = None)
Bases: LatentSpaceProbe
A probe for exploring arithmetic operations in the latent space of a language model variational autoencoder (LM-VAE).
This class applies specified arithmetic operations to latent representations obtained from pairs of sentences. It supports operations such as summation, subtraction, and averaging. In addition, the probe can generate a report in the form of a pandas DataFrame summarizing the original source and target sentences alongside the results of the applied operations.
- model
The LM-VAE model to be probed.
- Type:
LangVAE
- data
A list of sentence pairs as (source, target) tuples.
- Type:
List[Tuple[Sentence, Sentence]]
- ops
A list of arithmetic operations to be applied to the latent vectors.
- Type:
List[ArithmeticOps]
- annotations
Dictionary mapping each annotation type to be processed to all of its possible values. Optional; defaults to None.
- Type:
Dict[str, List[str]]
- arithmetic(source: Tensor, target: Tensor) → List[Tensor]
Apply arithmetic operations to the source and target latent representations.
- Parameters:
source (Tensor) – The latent representation of the source sentences.
target (Tensor) – The latent representation of the target sentences.
- Returns:
A list of tensors, each resulting from applying the corresponding arithmetic operation.
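The following is a minimal sketch of what the returned list could look like, not the library's implementation. Several assumptions apply. Plain Python lists stand in for `Tensor` so the example has no dependencies. `ArithmeticOps` here is a hypothetical stand-in enum whose members (`ADD`, `SUB`, `AVG`) may not match the real ones. Subtraction is assumed to be `source - target`. The ops list is passed explicitly rather than read from the probe's `ops` attribute. The result holds one batch per operation, in the same order as `ops`.

```python
from enum import Enum
from typing import List

# Hypothetical stand-in for ArithmeticOps; the real member names may differ.
class ArithmeticOps(Enum):
    ADD = "add"
    SUB = "sub"
    AVG = "avg"

Vector = List[float]
Batch = List[Vector]  # (batch_size, latent_dim), standing in for a Tensor


def arithmetic(source: Batch, target: Batch, ops: List[ArithmeticOps]) -> List[Batch]:
    """Return one result batch per op, in the same order as `ops`."""
    if len(source) != len(target):
        raise ValueError("source and target must have the same batch size")
    results: List[Batch] = []
    for op in ops:
        batch: Batch = []
        for s_vec, t_vec in zip(source, target):
            if op is ArithmeticOps.ADD:
                batch.append([s + t for s, t in zip(s_vec, t_vec)])
            elif op is ArithmeticOps.SUB:
                # Assumed direction: source minus target.
                batch.append([s - t for s, t in zip(s_vec, t_vec)])
            elif op is ArithmeticOps.AVG:
                batch.append([(s + t) / 2 for s, t in zip(s_vec, t_vec)])
            else:
                raise ValueError(f"unsupported op: {op}")
        results.append(batch)
    return results


# Usage: two sentence pairs with 2-dimensional latent vectors.
src = [[1.0, 2.0], [0.0, 4.0]]
tgt = [[3.0, 0.0], [2.0, 2.0]]
out = arithmetic(src, tgt, [ArithmeticOps.ADD, ArithmeticOps.SUB, ArithmeticOps.AVG])
# out[0] == [[4.0, 2.0], [2.0, 6.0]]    (sum)
# out[1] == [[-2.0, 2.0], [-2.0, 2.0]]  (difference)
# out[2] == [[2.0, 1.0], [1.0, 3.0]]    (mean)
```

With real PyTorch tensors of shape `(batch_size, latent_dim)`, each branch reduces to a single element-wise expression, such as `source + target`, `source - target`, or `(source + target) / 2`.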
- report() → DataFrame
Generate a report summarizing the arithmetic probe results.
- The final DataFrame will have the following columns:
source: The original source sentence surfaces.
target: The original target sentence surfaces.
op: The arithmetic operation applied (as a lowercase string).
generate: The generated sentence after applying the latent operation.
- Returns:
A DataFrame containing a detailed report for each arithmetic operation.
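The sketch below shows the layout of the report's rows using the four documented columns: `source`, `target`, `op` and `generate`. It is an illustration under stated assumptions, not the probe's actual code. `build_report_records` is a hypothetical helper, and `ArithmeticOps` is a hypothetical stand-in enum. The lowercase `op` string is assumed to come from the enum member name. The decoded sentences are placeholders, since decoding requires the LM-VAE itself. The row order (grouped by operation, then by sentence pair) is also an assumption. The records are plain dicts, so passing them to `pandas.DataFrame(records)` would produce a frame with those columns.

```python
from enum import Enum
from typing import Dict, List, Tuple

# Hypothetical stand-in for ArithmeticOps; the real member names may differ.
class ArithmeticOps(Enum):
    ADD = "add"
    AVG = "avg"


def build_report_records(
    pairs: List[Tuple[str, str]],
    ops: List[ArithmeticOps],
    generated: List[List[str]],
) -> List[Dict[str, str]]:
    """Build one row per (operation, sentence pair).

    generated[i][j] is the sentence decoded after applying ops[i]
    to the latent vectors of pairs[j].
    """
    records: List[Dict[str, str]] = []
    for op, op_outputs in zip(ops, generated):
        for (source, target), gen in zip(pairs, op_outputs):
            records.append({
                "source": source,
                "target": target,
                "op": op.name.lower(),  # assumed: lowercase enum member name
                "generate": gen,
            })
    return records


# Usage with placeholder decoder outputs (not real model generations).
pairs = [("a cat sat on the mat", "a dog ran in the park")]
ops = [ArithmeticOps.ADD, ArithmeticOps.AVG]
generated = [["<decoded add #0>"], ["<decoded avg #0>"]]
records = build_report_records(pairs, ops, generated)
# records[0]["op"] == "add"; records[1]["op"] == "avg"
```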