langspace.probe.cluster_vis package

Submodules

langspace.probe.cluster_vis.methods module

class langspace.probe.cluster_vis.methods.ClusterVisualizationMethod(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]

Bases: Enum

PCA = 'PCA'
TSNE = 't-SNE'
UMAP = 'UMAP'
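
As a quick illustration of how the members above relate to their string values, the following sketch re-declares the enum locally. The class here is a stand-in that mirrors the documented members so the snippet runs without langspace installed; in real code you would import it from langspace.probe.cluster_vis.methods instead.

```python
from enum import Enum


class ClusterVisualizationMethod(Enum):
    """Local stand-in mirroring the documented members (not the library class)."""
    PCA = "PCA"
    TSNE = "t-SNE"
    UMAP = "UMAP"


# Look up a member by its value, e.g. when parsing a configuration string.
method = ClusterVisualizationMethod("t-SNE")

# Build the kind of list the probe's `methods` argument expects.
methods = [ClusterVisualizationMethod.PCA, method]
names = [m.name for m in methods]
```

Note that the TSNE member has the value 't-SNE', so a lookup by value needs the hyphenated spelling.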

Module contents

class langspace.probe.cluster_vis.ClusterVisualizationProbe(model: LangVAE, data: Iterable[Sentence], sample_size: int, target_roles: Dict[str, List[str]], methods: List[ClusterVisualizationMethod], cluster_annotation: str, batch_size: int = 20, annotations: Dict[str, List[str]] = None, plot_label_map: Dict[str, str] = None)[source]

Bases: LatentSpaceProbe

A probe for visualizing the latent space of a language VAE via clustering techniques.

This probe supports three visualization methods: PCA, t-SNE, and UMAP. It processes a collection of sentences, extracts their latent representations, and generates plots that highlight clusters based on the provided target roles and annotations. The generated plots are saved to image files.

model

The LM-VAE model whose latent space is to be analyzed.

Type:

LangVAE

data

An iterable of Sentence objects representing the input data.

Type:

Iterable[Sentence]

sample_size

The number of data points to process for visualization.

Type:

int

target_roles

A mapping between annotation categories and target tokens for visualization clustering.

Type:

Dict[str, List[str]]

method

A list of visualization methods to apply (e.g., TSNE, UMAP, PCA).

Type:

List[ClusterVisualizationMethod]

cluster_annot

The annotation name used to filter or identify clusters.

Type:

str

batch_size

The number of data points to encode in each batch.

Type:

int

annotations

Optional dictionary of annotation types to be processed and all their possible values, for conditional encoding.

Type:

Dict[str, List[str]], optional

plot_label_map

Optional mapping to provide custom labels for plotting.

Type:

Dict[str, str], optional
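
To make the argument shapes above concrete, here is a small sketch of keyword arguments that match the documented constructor signature. Every value is an illustrative placeholder (the role names, tokens, and the "srl" annotation key are assumptions, not taken from the library), and the enum is a local stand-in so the snippet runs on its own.

```python
from enum import Enum


class ClusterVisualizationMethod(Enum):
    # Stand-in for the documented enum so the sketch runs without langspace.
    PCA = "PCA"
    TSNE = "t-SNE"
    UMAP = "UMAP"


# All values below are hypothetical examples of the documented types.
probe_kwargs = {
    "sample_size": 1000,
    "target_roles": {"ARG0": ["animal", "water"], "ARG1": ["food"]},
    "methods": [ClusterVisualizationMethod.TSNE, ClusterVisualizationMethod.PCA],
    "cluster_annotation": "srl",
    "batch_size": 20,  # the documented default
    "annotations": {"srl": ["ARG0", "ARG1", "V"]},
    "plot_label_map": {"ARG0": "agent", "ARG1": "patient"},
}

# With langspace installed and a model and data in hand, the call would be:
# probe = ClusterVisualizationProbe(model, data, **probe_kwargs)
```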

report()[source]

Generate and save cluster visualization plots based on the encoded latent representations.

For each visualization method specified (TSNE, UMAP, PCA), the method creates a corresponding plot:
  • For TSNE, a TSNEVisualizer is created, fit with the latent vectors and labels, and saved as “t_sne.png”.

  • For UMAP, a UMAPVisualizer is created, fit and saved as “umap.png”.

  • For PCA, a PCA visualizer from Yellowbrick is used, with labels converted to integer classes, and saved as “pca.png”.

Returns:

A list containing the visualizer objects corresponding to each applied visualization method.
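
The bookkeeping described for report() can be sketched without any plotting library. The file names below come from the description above; the helper names (OUTPUT_FILES, labels_to_classes, plan_outputs) are hypothetical, and the first-appearance ordering of integer classes is an assumption, since the reference does not say how labels are converted for PCA.

```python
from enum import Enum
from typing import Dict, List, Tuple


class ClusterVisualizationMethod(Enum):
    # Stand-in for the documented enum.
    PCA = "PCA"
    TSNE = "t-SNE"
    UMAP = "UMAP"


# Output file names as listed in the report() description.
OUTPUT_FILES = {
    ClusterVisualizationMethod.TSNE: "t_sne.png",
    ClusterVisualizationMethod.UMAP: "umap.png",
    ClusterVisualizationMethod.PCA: "pca.png",
}


def labels_to_classes(labels: List[str]) -> Tuple[List[int], Dict[str, int]]:
    """Map string labels to integer classes in order of first appearance."""
    index: Dict[str, int] = {}
    classes = [index.setdefault(label, len(index)) for label in labels]
    return classes, index


def plan_outputs(methods: List[ClusterVisualizationMethod]) -> List[Tuple[str, str]]:
    """Pair each requested method name with the file it would be saved to."""
    return [(m.name, OUTPUT_FILES[m]) for m in methods]


classes, index = labels_to_classes(["ARG0", "V", "ARG0", "ARG1"])
plan = plan_outputs([ClusterVisualizationMethod.TSNE, ClusterVisualizationMethod.PCA])
```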

static role_content_viz(viz_list: Iterable[Sentence], target_roles: Dict[str, List[str]], annotation: str, plot_label_map: Dict[str, str]) → List[Tuple[Sentence, str]][source]

Extract sentences and associate them with role-specific labels for content visualization.

The method iterates through each sentence and examines its tokens. If a token’s surface form is found in the list of target tokens (as specified by the given annotation in target_roles), it constructs a label. The label is either the original annotation or a remapped label as defined in plot_label_map. Each sentence with an associated label forms a tuple that is added to the resulting list.

Parameters:
  • viz_list (Iterable[Sentence]) – An iterable of Sentence objects to be processed.

  • target_roles (Dict[str, List[str]]) – Dictionary mapping annotation keys to a list of target tokens.

  • annotation (str) – The key used to access the token’s annotations for role filtering.

  • plot_label_map (Dict[str, str]) – Optional mapping to translate or reformat the original annotation label.

Returns:

A list of tuples where each tuple contains a Sentence and its associated label.

Return type:

List[Tuple[Sentence, str]]
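
One possible reading of the labelling logic described above is sketched below. The Token and Sentence classes are hypothetical stand-ins (the real types are not shown in this reference), and the sketch assumes that each token stores its role under the given annotation key and that target_roles is keyed by role value. The actual implementation may differ in these details.

```python
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Optional, Tuple


@dataclass
class Token:
    # Hypothetical stand-in for the library's token type.
    surface: str
    annotations: Dict[str, str] = field(default_factory=dict)


@dataclass
class Sentence:
    # Hypothetical stand-in for the library's Sentence type.
    tokens: List[Token]


def role_content_sketch(
    viz_list: Iterable[Sentence],
    target_roles: Dict[str, List[str]],
    annotation: str,
    plot_label_map: Optional[Dict[str, str]] = None,
) -> List[Tuple[Sentence, str]]:
    """Pair sentences with a label whenever a token matches a target role."""
    plot_label_map = plot_label_map or {}
    labelled = []
    for sentence in viz_list:
        for token in sentence.tokens:
            role = token.annotations.get(annotation)
            # Keep the token only if its role is targeted and its surface
            # form is among the target tokens listed for that role.
            if role in target_roles and token.surface in target_roles[role]:
                # Use the remapped label when one is provided.
                labelled.append((sentence, plot_label_map.get(role, role)))
    return labelled


s = Sentence([
    Token("animals", {"srl": "ARG0"}),
    Token("eat", {"srl": "V"}),
    Token("plants", {"srl": "ARG1"}),
])
plain = role_content_sketch([s], {"ARG0": ["animals"]}, "srl")
mapped = role_content_sketch([s], {"ARG0": ["animals"]}, "srl", {"ARG0": "agent"})
```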

static structure_viz(viz_list, sample_size=1000, TopK=5)[source]

Generate a structured visualization list by removing consecutive duplicate semantic role labels.

This method processes an input list of (sentence, semantic role labels) pairs. For each pair, it removes repeated adjacent role labels; for example, “ARG0 ARG0 ARG0 V ARG1 ARG1” becomes “ARG0 V ARG1”. It then counts the occurrences of each unique label pattern and keeps only the TopK most frequent patterns. Finally, the list is balanced so that each retained pattern contributes at most (sample_size / TopK) instances.

Parameters:
  • viz_list (List[Tuple[Sentence, str]]) – A list of tuples where each tuple contains a sentence and a string of semantic role labels separated by spaces.

  • sample_size (int, optional) – The maximum number of data points to consider. Default is 1000.

  • TopK (int, optional) – The number of most frequent unique role structures to retain. Default is 5.

Returns:

A filtered and balanced list of (sentence, unique role structure) pairs.
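
The three steps described for structure_viz (collapse, count, balance) can be sketched with the standard library alone. This is an illustration under stated assumptions, not the library's code: sentences are plain strings here, the function and parameter names are hypothetical, and integer division is assumed for the per-pattern quota.

```python
from collections import Counter
from itertools import groupby
from typing import List, Tuple


def collapse_roles(roles: str) -> str:
    """Drop repeated adjacent labels, keeping one label per run."""
    return " ".join(label for label, _ in groupby(roles.split()))


def structure_sketch(
    viz_list: List[Tuple[str, str]], sample_size: int = 1000, top_k: int = 5
) -> List[Tuple[str, str]]:
    # Step 1: collapse consecutive duplicate role labels.
    collapsed = [(sent, collapse_roles(roles)) for sent, roles in viz_list]
    # Step 2: keep only the top_k most frequent patterns.
    counts = Counter(pattern for _, pattern in collapsed)
    kept = {pattern for pattern, _ in counts.most_common(top_k)}
    # Step 3: cap each retained pattern at sample_size // top_k instances.
    quota = sample_size // top_k
    taken: Counter = Counter()
    balanced = []
    for sent, pattern in collapsed:
        if pattern in kept and taken[pattern] < quota:
            taken[pattern] += 1
            balanced.append((sent, pattern))
    return balanced


pairs = [
    ("a", "ARG0 ARG0 V"),
    ("b", "ARG0 V V"),
    ("c", "ARG0 V"),
    ("d", "V ARG1 ARG1"),
    ("e", "ARG1"),
    ("f", "V V ARG1"),
]
result = structure_sketch(pairs, sample_size=4, top_k=2)
```

With sample_size=4 and top_k=2, each of the two retained patterns is capped at two instances, so the third "ARG0 V" sentence and the rare "ARG1" pattern are dropped.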