# Source code for ls_mlkit.util.plot.comparison_plot

"""
Comparison plotting utilities for multi-model analysis
"""

import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np


def _first_values_by_entity(
    model_data: List[Dict[str, Any]], entity_key: str, value_key: str
) -> Dict[Any, Any]:
    """Map each entity in one model's records to the first usable value recorded for it.

    Records missing ``entity_key`` or ``value_key``, or whose value is ``None``,
    are skipped. If an entity appears several times, the earliest usable record wins.
    """
    lookup: Dict[Any, Any] = {}
    for entry in model_data:
        if entity_key not in entry or value_key not in entry:
            continue
        value = entry[value_key]
        if value is None:
            # A None value cannot take part in np.std; treat it as "not reported".
            continue
        lookup.setdefault(entry[entity_key], value)
    return lookup


def _rank_entities_by_variation(lookups: List[Dict[Any, Any]], top_n: int) -> List[Any]:
    """Return up to ``top_n`` entities ordered by decreasing std. dev. across models.

    Only entities reported by at least two models are eligible. Ties keep
    the ascending order of the entity names (stable sort over a sorted list).
    """
    all_entities = sorted({entity for lookup in lookups for entity in lookup})
    variations = {}
    for entity in all_entities:
        values = [lookup[entity] for lookup in lookups if entity in lookup]
        if len(values) > 1:
            variations[entity] = np.std(values)
    ranked = sorted(variations.items(), key=lambda item: item[1], reverse=True)
    return [entity for entity, _ in ranked[:top_n]]


def plot_entity_wise_comparison(
    results_dict: Dict[str, List[Dict[str, Any]]],
    output_dir: str,
    figsize: Tuple[float, float] = (12, 8),
    top_n: int = 20,
    name_formatter: Optional[Callable[[str], str]] = None,
    entity_key: str = "entity_name",
    value_key: str = "value",
    title: Optional[str] = None,
    xlabel: str = "Entity Names",
    ylabel: str = "Value",
    filename: str = "entity_wise_comparison.png",
    name_truncate_length: int = 15,
    alpha: float = 0.7,
    dpi: int = 300,
    grid: bool = True,
    grid_alpha: float = 0.3,
    legend_bbox_to_anchor: Tuple[float, float] = (0.5, 1.2),
    legend_loc: str = "upper center",
    legend_fontsize: Union[int, str] = 7,
    legend_ncol: int = 1,
    rotation: float = 45,
    ha: str = "right",
    show: bool = False,
    save: bool = True,
    return_fig_ax: bool = False,
) -> Optional[Tuple[plt.Figure, plt.Axes]]:
    """Draw a grouped bar chart comparing per-entity values across models.

    Entities are ranked by the standard deviation (``np.std``) of their values
    across the models that report them. Only entities reported by at least two
    models are eligible, and the ``top_n`` most variable ones are plotted. Each
    model contributes one bar per selected entity. A model with no value for an
    entity is drawn as 0.0.

    Args:
        results_dict: Mapping from model name to a list of record dicts. Each
            record should contain ``entity_key`` and a numeric ``value_key``.
            Records missing either key, or whose value is ``None``, are ignored.
            If a model lists an entity more than once, its first usable record
            is used.
        output_dir: Directory the image is written to when ``save`` is True.
            It is created if missing.
        figsize: Base ``(width, height)`` in inches. The width grows to
            0.4 inch per plotted entity when that is larger.
        top_n: Maximum number of entities to plot.
        name_formatter: Optional callable turning a model name into its legend label.
        entity_key / value_key: Record keys holding the entity name and its value.
        title: Plot title. Defaults to a description of the ranking.
        xlabel / ylabel: Axis labels.
        filename: Output file name inside ``output_dir``.
        name_truncate_length: Tick labels longer than this are cut and suffixed with "...".
        alpha, dpi, grid, grid_alpha, legend_*, rotation, ha: Styling options
            forwarded to matplotlib.
        show: Call ``plt.show()`` after drawing.
        save: Write the figure to ``os.path.join(output_dir, filename)``.
        return_fig_ax: Return the open ``(fig, ax)`` instead of closing the figure.

    Returns:
        ``(fig, ax)`` when ``return_fig_ax`` is True. Otherwise ``None``, and the
        figure is closed. Also returns ``None``, without creating a figure, when
        no entity is reported by two or more models.
    """
    model_names = list(results_dict)
    # One entity -> value map per model, built once. This avoids rescanning every
    # model's record list for every entity (previously O(entities * records)).
    lookups = [_first_values_by_entity(results_dict[name], entity_key, value_key) for name in model_names]

    selected_entities = _rank_entities_by_variation(lookups, top_n)
    if not selected_entities:
        print("No entities found with variation across models.")
        return None

    n_models = len(model_names)
    # Widen the figure when many entities are shown so the tick labels stay legible.
    fig_width = max(figsize[0], len(selected_entities) * 0.4)
    fig, ax = plt.subplots(figsize=(fig_width, figsize[1]))
    try:
        x_positions = np.arange(len(selected_entities))
        bar_width = 0.8 / n_models
        # One distinct Set3 colour per model.
        colors = plt.cm.Set3(np.linspace(0, 1, n_models))

        for i, (model_name, lookup, color) in enumerate(zip(model_names, lookups, colors)):
            model_values = [lookup.get(entity, 0.0) for entity in selected_entities]
            label = name_formatter(model_name) if name_formatter is not None else model_name
            # Center the group of n_models bars on each tick position.
            x_pos = x_positions + i * bar_width - (n_models - 1) * bar_width / 2
            ax.bar(x_pos, model_values, bar_width, label=label, color=color, alpha=alpha)

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        if title is None:
            # Report how many entities are actually shown, which may be fewer than top_n.
            title = f"Entity-wise {ylabel} Comparison (Top {len(selected_entities)} by Variation)"
        ax.set_title(title)

        ax.set_xticks(x_positions)
        tick_labels = []
        for entity in selected_entities:
            name = str(entity)
            if len(name) > name_truncate_length:
                name = name[:name_truncate_length] + "..."
            tick_labels.append(name)
        ax.set_xticklabels(tick_labels, rotation=rotation, ha=ha)

        legend = ax.legend(
            bbox_to_anchor=legend_bbox_to_anchor, loc=legend_loc, fontsize=legend_fontsize, ncol=legend_ncol
        )
        if grid:
            ax.grid(True, alpha=grid_alpha, axis="y")
        fig.tight_layout()

        if save:
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            output_path = os.path.join(output_dir, filename)
            save_kwargs = {"dpi": dpi, "bbox_inches": "tight", "pad_inches": 0.2}
            if legend is not None:
                # The legend sits outside the axes by default; include it in the tight bbox.
                save_kwargs["bbox_extra_artists"] = [legend]
            # Save this figure explicitly rather than whatever pyplot considers "current".
            fig.savefig(output_path, **save_kwargs)
            print(f"Entity-wise comparison plot saved to: {output_path}")

        if show:
            plt.show()
    except BaseException:
        # Do not leak the pyplot-managed figure when drawing or saving fails.
        plt.close(fig)
        raise

    if return_fig_ax:
        return fig, ax
    plt.close(fig)
    return None
def _run_example() -> None:
    """Demo: plot randomly generated scores for three synthetic models."""
    # Seed numpy's legacy global RNG so the demo data is reproducible.
    np.random.seed(42)
    # (model name, mean, std). Values are drawn model by model in this order,
    # 50 scalar draws each.
    model_specs = [("Model_A", 80, 10), ("Model_B", 75, 12), ("Model_C", 85, 8)]
    example_results = {
        name: [{"entity_name": f"item_{i}", "value": np.random.normal(mean, std)} for i in range(50)]
        for name, mean, std in model_specs
    }
    plot_entity_wise_comparison(example_results, output_dir=".", top_n=10, show=True, filename="test_comparison.png")


if __name__ == "__main__":
    _run_example()