Source code for er_evaluation.plots._plots

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from er_evaluation.data_structures import MembershipVector
from er_evaluation.error_analysis import error_metrics
from er_evaluation.estimators import (estimates_table,
                                      pairwise_precision_estimator,
                                      pairwise_recall_estimator,
                                      summary_estimates_table)
from er_evaluation.estimators._utils import _parse_weights
from er_evaluation.metrics import (metrics_table, pairwise_f,
                                   pairwise_precision, pairwise_recall)
from er_evaluation.summary import (cluster_hill_number,
                                   cluster_sizes_distribution,
                                   summary_statistics)

# Default value of the ``metrics`` argument of ``plot_metrics``.
# Maps the label shown in the plot to the metric function.
DEFAULT_METRICS = {
    "Pairwise precision": pairwise_precision,
    "Pairwise recall": pairwise_recall,
}

# Default value of the ``metrics`` argument of ``plot_comparison``.
# Each metric is evaluated on every ordered pair of predictions.
DEFAULT_COMPARISON_METRICS = {
    "Pairwise precision": pairwise_precision,
    "Pairwise F1": pairwise_f,
}

# Default value of the ``estimators`` argument of ``plot_estimates``.
# Maps the label shown in the plot to the estimator function.
DEFAULT_ESTIMATORS = {
    "Pairwise precision": pairwise_precision_estimator,
    "Pairwise recall": pairwise_recall_estimator,
}


def compare_plots(*figs, names=None, marker="color", marker_values=None):
    r"""
    Combine multiple figures into one.

    The traces of every figure in ``figs`` are restyled so that they can be
    told apart, then added to a single new figure.

    Args:
        *figs: Plotly figures to combine.
        names (sequence, optional): Trace name to use for each figure, one entry per figure.
            When omitted, figures whose first trace has an empty name are named after their position in ``figs``.
            Defaults to None.
        marker (str, optional): Marker attribute used to distinguish the figures. Defaults to "color".
        marker_values (sequence, optional): Value of the ``marker`` attribute for each figure, one entry per figure.
            When omitted, the position of the figure in ``figs`` is used. Defaults to None.

    Returns:
        Figure: New figure containing the traces of all the input figures. Its axis ranges are
        copied from the last input figure.

    Raises:
        ValueError: If ``names`` or ``marker_values`` is given and does not have one entry per figure.

    Note:
        The input figures are modified in place (marker, legend visibility and trace name).
    """
    # Explicit exceptions rather than ``assert`` so validation survives ``python -O``.
    if names is not None and len(names) != len(figs):
        raise ValueError(f"Expected {len(figs)} names (one per figure), got {len(names)}.")
    if marker_values is not None and len(marker_values) != len(figs):
        raise ValueError(f"Expected {len(figs)} marker values (one per figure), got {len(marker_values)}.")

    combined = go.Figure()
    for i, fig in enumerate(figs):
        marker_val = i if marker_values is None else marker_values[i]
        fig.update_traces(marker={marker: marker_val}, showlegend=True)

        if names is not None:
            fig.update_traces(name=names[i])
        elif fig.data and fig.data[0].name == "":
            # Unnamed traces would be indistinguishable in the legend.
            # The ``fig.data`` check avoids an IndexError on figures without traces.
            fig.update_traces(name=i)

        combined.add_traces(fig.data)

    # With no input figure there is no range to copy; return the empty figure.
    if figs:
        last_fig = figs[-1]
        combined.update_xaxes(range=last_fig.layout.xaxis.range)
        combined.update_yaxes(range=last_fig.layout.yaxis.range)

    return combined
def plot_cluster_sizes_distribution(membership, groupby=None, name=None, normalize=False):
    r"""
    Plot the cluster size distribution as a bar chart.

    Args:
        membership: Membership vector (converted with ``MembershipVector(membership, dropna=True)``).
        groupby (optional): Grouping values sharing the index of ``membership``. When given, one
            distribution is drawn per group and only the elements present in both ``membership``
            and ``groupby`` are used. Defaults to None.
        name (str, optional): Name given to the traces of the plot (useful when combining multiple
            plots together). Defaults to None.
        normalize (bool, optional): Whether to plot proportions instead of counts. With ``groupby``,
            counts are divided by the total over all groups. Defaults to False.

    Returns:
        Figure: Cluster size distribution plot.
    """
    membership = MembershipVector(membership, dropna=True)
    y_label = "proportion" if normalize else "count"

    if groupby is None:
        sizes = cluster_sizes_distribution(membership)
        if normalize:
            sizes = sizes / sizes.sum()
        fig = px.bar(
            x=sizes.index,
            y=sizes.values,
            labels={"x": "Cluster size", "y": y_label},
            barmode="group",
        )
    else:
        groupby = MembershipVector(groupby, dropna=True)
        if groupby.name is None:
            groupby.name = "groupby"

        # Restrict both vectors to their common elements.
        membership = membership[membership.index.isin(groupby.index)]
        groupby = groupby[groupby.index.isin(membership.index)]

        grouped = membership.groupby(groupby)
        sizes = grouped.apply(cluster_sizes_distribution).reset_index(name="cs_size")
        if normalize:
            sizes["cs_size"] = sizes["cs_size"] / sizes["cs_size"].sum()

        fig = px.bar(
            x=sizes["level_1"],
            y=sizes["cs_size"],
            color=sizes[groupby.name],
            labels={"x": "Cluster size", "y": y_label, "color": groupby.name},
            barmode="group",
        )

    if name is not None:
        fig.update_traces(name=name)

    return fig
def plot_entropy_curve(membership, q_range=None, groupby=None, name=None):
    r"""
    Plot the Hill number entropy curve.

    Args:
        membership: Membership vector (converted with ``MembershipVector(membership)``).
        q_range (sequence, optional): Values of ``q`` at which Hill numbers are computed.
            Defaults to ``np.linspace(0, 2)``.
        groupby (optional): Grouping values sharing the index of ``membership``. When given, one
            curve is drawn per group and only the elements present in both ``membership`` and
            ``groupby`` are used. Defaults to None.
        name (str, optional): Name given to the traces of the plot (useful when combining multiple
            plots together). Defaults to None.

    Returns:
        Figure: Hill number entropy curve.
    """
    membership = MembershipVector(membership)
    if q_range is None:
        q_range = np.linspace(0, 2)

    if groupby is None:
        curve = [cluster_hill_number(membership, q) for q in q_range]
        fig = px.line(
            x=q_range,
            y=curve,
            title="Hill Numbers entropy curve",
            labels={"x": "q", "y": "Hill Number"},
        )
    else:
        groupby = MembershipVector(groupby, dropna=True)
        if groupby.name is None:
            groupby.name = "groupby"

        # Restrict both vectors to their common elements.
        membership = membership[membership.index.isin(groupby.index)]
        groupby = groupby[groupby.index.isin(membership.index)]

        def group_curve(group):
            # Hill numbers of a single group, indexed by q.
            return pd.Series([cluster_hill_number(group, q) for q in q_range], index=q_range)

        curves = membership.groupby(groupby).apply(group_curve).reset_index(name="hill_numbers")
        fig = px.line(
            x=curves["level_1"],
            y=curves["hill_numbers"],
            color=curves[groupby.name],
            title="Hill Numbers entropy curve",
            labels={"x": "q", "y": "Hill Number", "color": groupby.name},
        )

    if name is not None:
        fig.update_traces(name=name)

    return fig
def plot_summaries(predictions, names=None, type="line", line_shape="spline", markers=True, **kwargs):
    """
    Plot summary statistics

    One subplot is drawn per statistic, with the predictions along the x axis.

    Args:
        predictions (dict): Dictionary of predictions for which to plot the summary statistics.
        names (Series, optional): Series with cluster element names. Used to compute the homonymy rate and the name variation rate, which are shown in a third row of subplots. Defaults to None.
        type (str, optional): One of "line" for a line plot, or "bar" for a bar plot. Defaults to "line".
        line_shape (str, optional): Line shape used for line plots. Defaults to "spline".
        markers (bool, optional): Whether to show markers on line plots. Defaults to True.
        **kwargs (optional): Additional arguments passed to ``make_subplots`` when creating the figure.

    Returns:
        plotly Figure

    Raises:
        ValueError: If ``type`` is neither "line" nor "bar".

    Examples:

        >>> from er_evaluation.datasets import load_pv_disambiguations
        >>> predictions, _ = load_pv_disambiguations()
        >>> fig = plot_summaries(predictions)
        >>> fig.show() # doctest: +SKIP
    """
    records = [summary_statistics(pred, names=names) for pred in predictions.values()]
    summaries = pd.DataFrame.from_records(records)
    summaries["prediction"] = predictions.keys()

    if type == "line":

        def make_plot(*args, **plot_kwargs):
            return px.line(*args, line_shape=line_shape, markers=markers, **plot_kwargs)

    elif type == "bar":
        make_plot = px.bar
    else:
        raise ValueError(f"Unknown plot type {type}. Should be one of ['line', 'bar'].")

    has_names = names is not None
    titles = (
        "Average Cluster Size",
        "Matching Rate",
        "Number of distinct cluster sizes",
        "Entropy",
        "Name Variation Rate",
        "Homonymy Rate",
    )
    fig = make_subplots(
        rows=3 if has_names else 2,
        cols=2,
        subplot_titles=titles,
        shared_xaxes=True,
        **kwargs,
    )

    # Statistic shown in each cell of the subplot grid, row by row.
    grid = [["average_cluster_size", "matching_rate"], ["H0", "H1"]]
    if has_names:
        grid.append(["name_variation_rate", "homonymy_rate"])

    for row_number, row_stats in enumerate(grid, start=1):
        for col_number, stat in enumerate(row_stats, start=1):
            trace = make_plot(summaries, x="prediction", y=stat).data[0]
            fig.add_trace(trace, row=row_number, col=col_number)

    fig.update_annotations(font_size=14)
    fig.update_layout(title_text="Disambiguation Summary Statistics")

    return fig
def add_ests_to_summaries(fig, predictions, sample, weights, names=None):
    """
    Overlay summary estimates on a figure of summary statistics.

    For each estimate, three traces are added to the corresponding subplot: the estimated value
    (dotted black line), and a lower and an upper bound at two standard deviations around it, with
    the area between the two bounds shaded.

    Args:
        fig (Figure): Figure with subplots to add the estimates to. It is modified in place:
            its existing traces are renamed "summary" and put in a common legend group.
            Presumably a figure created by ``plot_summaries`` (subplots (1, 1) and (1, 2) are used,
            as well as (3, 1) and (3, 2) when ``names`` is given).
        predictions (dict): Dictionary of predictions, passed to ``summary_estimates_table``.
        sample: Sample passed to ``summary_estimates_table``.
        weights: Sampling weights passed to ``summary_estimates_table``.
        names (optional): Names passed to ``summary_estimates_table``. When given, the name variation
            and homonymy rate estimates are added as well. Defaults to None.

    Returns:
        Figure: The input figure, with the estimate traces added.
    """
    estimates = summary_estimates_table(sample, weights, predictions, names)

    # (subplot row, subplot column, estimate label) for each panel to annotate.
    panels = [
        (1, 1, "Avg Cluster Size Estimate"),
        (1, 2, "Matching Rate Estimate"),
    ]
    if names is not None:
        panels.extend([(3, 1, "Name Variation Estimate"), (3, 2, "Homonymy Rate Estimate")])

    # Must run before the estimate traces are added, so only pre-existing traces are affected.
    fig.update_traces(legendgroup=0, name="summary")
    fig.update_traces(selector=0, showlegend=True)

    for position, (row, col, estimate) in enumerate(panels):
        dat = estimates.query(f"estimate == '{estimate}'")
        x_values = dat["prediction"]
        center = dat["value"]
        half_width = 2 * dat["std"]

        point_estimate = go.Scatter(
            name="estimate",
            x=x_values,
            y=center,
            line={"color": "black", "dash": "dot", "width": 1, "shape": "spline"},
            marker={"color": "black", "size": 4},
            legendgroup=1,
            # A single legend entry stands for all the estimate traces.
            showlegend=(position == 0),
        )
        lower_bound = go.Scatter(
            name="estimate",
            x=x_values,
            y=center - half_width,
            line={"color": "black", "dash": "dot", "width": 0, "shape": "spline"},
            mode="lines",
            legendgroup=1,
            showlegend=False,
        )
        upper_bound = go.Scatter(
            name="estimate",
            x=x_values,
            y=center + half_width,
            fill="tonexty",
            fillcolor="rgba(0, 0, 0, 0.1)",
            line={"color": "black", "dash": "dot", "width": 0, "shape": "spline"},
            mode="lines",
            legendgroup=1,
            showlegend=False,
        )

        # Order matters: the upper bound fills down to the trace added just before it.
        for trace in (point_estimate, lower_bound, upper_bound):
            fig.add_trace(trace, row=row, col=col)

    return fig
def plot_metrics(predictions, reference, metrics=DEFAULT_METRICS, type="line", **kwargs):
    """
    Plot performance metrics.

    Args:
        predictions (dict): Dictionary of predictions for which to plot the performance metrics.
        reference (Series): Reference membership vector representing the ground truth.
        metrics (dict, optional): Dictionary of metrics to display. Defaults to DEFAULT_METRICS.
        type (str, optional): One of "line" for a line plot or "bar" for a bar plot. Defaults to "line".
        **kwargs (optional): Additional arguments to pass to plotly express for plot creation.

    Returns:
        plotly Figure

    Raises:
        ValueError: If ``type`` is neither "line" nor "bar".

    Examples:

        >>> from er_evaluation.datasets import load_pv_disambiguations
        >>> predictions, reference = load_pv_disambiguations()
        >>> fig = plot_metrics(predictions, reference)
        >>> fig.show() # doctest: +SKIP
    """
    references = {"reference": reference}
    table = metrics_table(predictions, references, metrics=metrics)

    if type not in ("line", "bar"):
        raise ValueError(f"Unknown plot type {type}. Should be one of ['line', 'bar'].")

    if type == "line":
        # One line per metric, predictions along the x axis.
        fig = px.line(table, x="prediction", y="value", color="metric", **kwargs)
    else:
        # One group of bars per metric, one bar per prediction.
        fig = px.bar(table, x="metric", y="value", color="prediction", barmode="group", **kwargs)

    fig.update_layout(title_text="Performance metrics")

    return fig
def plot_estimates(
    predictions, sample_weights, estimators=DEFAULT_ESTIMATORS, type="line", line_shape="spline", **kwargs
):
    """
    Plot representative performance estimates.

    Error bars span two standard deviations on each side of the estimates.

    Args:
        predictions (dict): Dictionary of predictions for which to plot the performance estimates.
        sample_weights (dict): Dictionary with an element named "sample" containing sampled clusters, and an element named "weights" containing the sampling weights.
        estimators (dict, optional): Dictionary of estimators to use. Defaults to DEFAULT_ESTIMATORS.
        type (str, optional): One of "line" for a line plot or "bar" for a bar plot. Defaults to "line".
        line_shape (str, optional): Line shape used for line plots. Defaults to "spline".
        **kwargs (optional): Additional arguments to pass to plotly express for plot creation.

    Returns:
        plotly Figure.

    Raises:
        ValueError: If ``type`` is neither "line" nor "bar".

    Examples:

        >>> from er_evaluation.datasets import load_pv_disambiguations
        >>> predictions, reference = load_pv_disambiguations()
        >>> fig = plot_estimates(predictions, {"sample": reference, "weights": "cluster_size"})
        >>> fig.show() # doctest: +SKIP
    """
    table = estimates_table(
        predictions,
        samples_weights={"sample": sample_weights},
        estimators=estimators,
    )
    # Half-width of the error bars.
    table["std_2"] = 2 * table["std"]

    if type not in ("line", "bar"):
        raise ValueError(f"Unknown plot type {type}. Should be one of ['line', 'bar'].")

    if type == "line":
        fig = px.line(
            table,
            x="prediction",
            y="value",
            color="estimator",
            symbol="estimator",
            error_y="std_2",
            line_shape=line_shape,
            **kwargs,
        )
    else:
        fig = px.bar(
            table,
            x="estimator",
            y="value",
            color="prediction",
            barmode="group",
            error_y="std_2",
            **kwargs,
        )

    fig.update_layout(title_text="Performance estimates")

    return fig
def plot_cluster_errors(
    prediction,
    reference,
    x="expected_relative_extra",
    y="expected_relative_missing",
    groupby=None,
    weights=None,
    opacity=0.5,
    **kwargs,
):
    """
    Scatter plot of two cluster-wise error metrics.

    Metrics that can be plotted are:

    * **expected_extra** (see :meth:`er_evaluation.error_analysis.expected_extra`)
    * **expected_relative_extra** (see :meth:`er_evaluation.error_analysis.expected_relative_extra`)
    * **expected_missing** (see :meth:`er_evaluation.error_analysis.expected_missing`)
    * **expected_relative_missing** (see :meth:`er_evaluation.error_analysis.expected_relative_missing`)
    * **error_indicator** (see :meth:`er_evaluation.error_analysis.error_indicator`)

    Args:
        prediction (Series): Predicted clustering.
        reference (Series): Reference clustering.
        x (str, optional): x-axis metric to plot. Defaults to "expected_relative_extra".
        y (str, optional): y-axis metric to plot. Defaults to "expected_relative_missing".
        groupby (Series, optional): Optional Series with grouping values (corresponding to color elements). Should be indexed by cluster identifier, with values corresponding to group assignment. The Series passed in is left unchanged.
        weights (Series, optional): Optional Series with cluster weights. Should be indexed by cluster identifier, with values corresponding to cluster weight. Can also be set to the string "cluster_size" for clusters sampled with probability proportional to size. The Series passed in is left unchanged.
        opacity (float, optional): Opacity. Defaults to 0.5.
        **kwargs (optional): Additional arguments to pass to plotly express for plot creation.

    Returns:
        plotly Figure

    Note:
        Weights are not accounted for in the marginal histograms.
    """
    errors = error_metrics(prediction, reference)

    size = None
    color = None

    if weights is not None:
        size = "weight"
        # ``rename`` returns a new Series, so a caller-provided weights Series keeps its name.
        weights = _parse_weights(reference, weights).rename("weight")
        errors = errors.merge(weights, left_index=True, right_index=True, how="left")

    if groupby is not None:
        color = "color"
        # Same as above: do not overwrite the name of the caller's Series.
        errors = errors.merge(groupby.rename("color"), left_index=True, right_index=True, how="left")

    fig = px.scatter(
        errors,
        x=x,
        y=y,
        opacity=opacity,
        marginal_x="histogram",
        marginal_y="histogram",
        size=size,
        color=color,
        **kwargs,
    )
    fig.update_layout(title_text="Cluster-Wise Error Metrics")

    return fig
def plot_comparison(predictions, metrics=DEFAULT_COMPARISON_METRICS, color_continuous_scale="Blues", **kwargs):
    """
    Plot metrics computed for all prediction pairs.

    This is meant to help investigate the similarity and differences between a set of disambiguations.
    One heatmap is drawn per metric.

    Args:
        predictions (dict): dictionary of membership vectors to compare.
        metrics (dict, optional): Dictionary of metrics to compute. Defaults to DEFAULT_COMPARISON_METRICS.
        color_continuous_scale (str, optional): Color scale of the heatmaps. Defaults to "Blues".
        **kwargs (optional): Additional arguments to pass to plotly express for plot creation.

    Returns:
        plotly Figure

    Examples:

        >>> from er_evaluation.datasets import load_pv_disambiguations
        >>> predictions, _ = load_pv_disambiguations()
        >>> fig = plot_comparison(predictions)
        >>> fig.show() # doctest: +SKIP
    """
    labels = list(predictions)
    n_predictions = len(predictions)

    # values[m, a, b] holds metric number m evaluated as metric(prediction a, prediction b).
    # NOTE(review): presumably the second axis ends up on the y axis (titled "Reference" below)
    # and the third on the x axis (titled "Prediction"). If the metrics take their arguments
    # as (prediction, reference), the two axis titles are swapped -- confirm.
    values = np.zeros((len(metrics), n_predictions, n_predictions))
    for m, metric in enumerate(metrics.values()):
        for a, first in enumerate(predictions.values()):
            for b, second in enumerate(predictions.values()):
                values[m, a, b] = metric(first, second)

    fig = px.imshow(
        values,
        x=labels,
        y=labels,
        facet_col=0,
        facet_col_wrap=min(len(metrics), 2),
        aspect="equal",
        color_continuous_scale=color_continuous_scale,
        **kwargs,
    )
    fig.update_layout(title_text="Disambiguation Similarity")

    # Replace the facet annotations with the metric names.
    for position, metric_name in enumerate(metrics):
        fig.layout.annotations[position].update(text=metric_name, font_size=14)

    fig.update_xaxes(title="Prediction")
    fig.update_yaxes(title="Reference")

    return fig