Skip to content

Utility Classes

Debugging Purpose

DummyAlgorithm

Bases: BaseAlgorithm

Source code in fusion_bench/method/dummy.py
class DummyAlgorithm(BaseAlgorithm):
    """
    Debugging algorithm that performs no fusion: it hands back a single model
    taken from the model pool.
    """

    def run(self, modelpool: BaseModelPool):
        """
        This method returns the pretrained model from the model pool.
        If the pretrained model is not available, it returns the first model from the model pool.

        Args:
            modelpool (BaseModelPool): The pool of models to fuse. An `nn.Module`
                is also accepted and is returned unchanged; any other object is
                passed to the `BaseModelPool` constructor first.

        Returns:
            The `nn.Module` that was passed in, or the model returned by
            `modelpool.load_pretrained_or_first_model()`.

        Raises:
            AssertionError: If the model is not found in the model pool.
        """
        # A bare module is already "the model"; there is nothing to load.
        if isinstance(modelpool, nn.Module):
            return modelpool

        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

        model = modelpool.load_pretrained_or_first_model()

        # Raise explicitly instead of using `assert`, so the check is not
        # stripped when Python runs with -O. The exception type and message
        # are the same as before.
        if model is None:
            raise AssertionError("Model is not found in the model pool.")
        return model

run(modelpool)

This method returns the pretrained model from the model pool. If the pretrained model is not available, it returns the first model from the model pool.

Parameters:

  • modelpool (BaseModelPool) –

    The pool of models to fuse.

Raises:

  • AssertionError

    If the model is not found in the model pool.

Source code in fusion_bench/method/dummy.py
def run(self, modelpool: BaseModelPool):
    """
    This method returns the pretrained model from the model pool.
    If the pretrained model is not available, it returns the first model from the model pool.

    Args:
        modelpool (BaseModelPool): The pool of models to fuse. An `nn.Module`
            is also accepted and is returned unchanged; any other object is
            passed to the `BaseModelPool` constructor first.

    Returns:
        The `nn.Module` that was passed in, or the model returned by
        `modelpool.load_pretrained_or_first_model()`.

    Raises:
        AssertionError: If the model is not found in the model pool.
    """
    # A bare module is already "the model"; there is nothing to load.
    if isinstance(modelpool, nn.Module):
        return modelpool

    if not isinstance(modelpool, BaseModelPool):
        modelpool = BaseModelPool(modelpool)

    model = modelpool.load_pretrained_or_first_model()

    # Explicit raise (rather than `assert`) keeps the check active under
    # `python -O`; exception type and message are unchanged.
    if model is None:
        raise AssertionError("Model is not found in the model pool.")
    return model

Analysis Purpose

TaskVectorCosSimilarity

Bases: BaseAlgorithm, LightningFabricMixin

This class is similar to the Dummy algorithm, but it also prints (and saves) the cosine similarity matrix between the task vectors of the models in the model pool.

Source code in fusion_bench/method/analysis/task_vector_cos_similarity.py
class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
    """
    This class is similar to the Dummy algorithm (it returns the pretrained model),
    but it also prints (and saves) the cosine similarity matrix between the task
    vectors of the models in the model pool.

    A task vector is the flattened difference between the state dict of a
    fine-tuned model and the state dict of the pretrained model.
    """

    # config_mapping is a mapping from the attributes to the key in the configuration files
    _config_mapping = BaseAlgorithm._config_mapping | {
        "plot_heatmap": "plot_heatmap",
        "trainable_only": "trainable_only",
        "max_points_per_model": "max_points_per_model",
        "_output_path": "output_path",
    }

    def __init__(
        self,
        plot_heatmap: bool,
        trainable_only: bool = True,
        max_points_per_model: Optional[int] = None,
        output_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            plot_heatmap (bool): If True, also save a heatmap of the similarity matrix.
            trainable_only (bool): If True, only trainable parameters are used to
                build the task vectors. If False, the full `state_dict()` is used.
            max_points_per_model (int, optional): If set to a positive number smaller
                than the task vector length, every task vector is downsampled to this
                many entries. The same randomly chosen entries are used for all models.
            output_path (str, optional): Directory for the CSV (and PDF) output. If None,
                the fabric logger's log directory is used.
            kwargs: Additional keyword arguments passed to the parent class(es).
        """
        self.plot_heatmap = plot_heatmap
        self.trainable_only = trainable_only
        self.max_points_per_model = max_points_per_model
        self._output_path = output_path
        # (vector_length, indices) of the most recent downsampling draw; shared
        # between models so that the sampled coordinates line up.
        self._downsample_cache = None
        super().__init__(**kwargs)

    @property
    def output_path(self):
        """Directory for output files; falls back to the fabric logger's log directory."""
        if self._output_path is None:
            return self.fabric.logger.log_dir
        else:
            return self._output_path

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """
        Compute, print and save the pairwise cosine similarity of the task vectors.

        Args:
            modelpool (BaseModelPool): Model pool providing the pretrained model and
                the fine-tuned models.

        Returns:
            The pretrained model loaded from the model pool.
        """
        pretrained_model = modelpool.load_pretrained_model()

        # Draw fresh downsampling indices for this run; they are then reused for
        # every model so the vectors are compared on the same coordinates.
        self._downsample_cache = None

        task_vectors = []
        for name, finetuned_model in tqdm(
            modelpool.named_models(), total=len(modelpool)
        ):
            print(f"computing task vectors for {name}")
            task_vectors.append(
                self.get_task_vector(pretrained_model, finetuned_model).to(
                    torch.float64
                )
            )
        # torch.stack requires all task vectors to have the same length.
        task_vectors = torch.stack(task_vectors, dim=0)

        num_models = len(modelpool)
        cos_sim_matrix = torch.zeros(num_models, num_models, dtype=torch.float64)
        # The matrix is symmetric: compute the upper triangle and mirror it.
        for i in range(num_models):
            for j in range(i, num_models):
                cos_sim_matrix[i, j] = torch.nn.functional.cosine_similarity(
                    task_vectors[i], task_vectors[j], dim=0
                )
                cos_sim_matrix[j, i] = cos_sim_matrix[i, j]

        # convert the matrix to a pandas DataFrame
        cos_sim_df = pd.DataFrame(
            cos_sim_matrix.numpy(),
            index=modelpool.model_names,
            columns=modelpool.model_names,
        )

        print(cos_sim_df)
        if self.output_path is not None:
            os.makedirs(self.output_path, exist_ok=True)
            cos_sim_df.to_csv(
                os.path.join(self.output_path, "task_vector_cos_similarity.csv")
            )

        if self.plot_heatmap:
            self._plot_heatmap(cos_sim_df)

        return pretrained_model

    def _plot_heatmap(self, data: pd.DataFrame):
        """
        This function plots a heatmap of the provided data using seaborn and
        saves it as `task_vector_cos_similarity.pdf` under `self.output_path`.

        Args:
            data (pd.DataFrame): A pandas DataFrame containing the data to be plotted.

        Returns:
            None
        """
        import matplotlib.pyplot as plt
        import seaborn as sns

        # Create a heatmap using seaborn
        plt.figure()
        sns.heatmap(
            data,
            annot=True,
            fmt=".2f",
            cmap="GnBu",
        )

        # Add title with increased font size and rotate the tick labels
        plt.title("Heatmap of Cos Similarities", fontsize=14)
        plt.xticks(rotation=45)
        plt.yticks(rotation=45)

        # Save the plot to disk and release the figure
        plt.savefig(
            os.path.join(self.output_path, "task_vector_cos_similarity.pdf"),
            bbox_inches="tight",
        )
        plt.close()

    def get_task_vector(
        self, pretrained_model: nn.Module, finetuned_model: nn.Module
    ) -> torch.Tensor:
        """
        Build the (optionally downsampled) task vector of a fine-tuned model.

        Args:
            pretrained_model (nn.Module): The pretrained base model.
            finetuned_model (nn.Module): The fine-tuned model.

        Returns:
            torch.Tensor: 1-D float32 tensor on the CPU holding
            `finetuned - pretrained`, downsampled to `max_points_per_model`
            entries when that limit is set and exceeded.
        """
        task_vector = state_dict_sub(
            self.get_state_dict(finetuned_model),
            self.get_state_dict(pretrained_model),
        )
        task_vector = state_dict_to_vector(task_vector)

        task_vector = task_vector.cpu().float().numpy()
        num_params = task_vector.shape[0]
        # downsample if necessary
        if (
            self.max_points_per_model is not None
            and self.max_points_per_model > 0
            and num_params > self.max_points_per_model
        ):
            log.info(
                f"Downsampling task vectors to {self.max_points_per_model} points."
            )
            # Cosine similarity is only meaningful when every vector is sampled
            # at the SAME coordinates, so the indices are drawn once and reused
            # for all vectors of the same length.
            cache = self._downsample_cache
            if (
                cache is None
                or cache[0] != num_params
                or cache[1].shape[0] != self.max_points_per_model
            ):
                indices = np.random.choice(
                    num_params, self.max_points_per_model, replace=False
                )
                cache = (num_params, indices)
                self._downsample_cache = cache
            # Integer-array indexing already yields a new array.
            task_vector = task_vector[cache[1]]

        task_vector = torch.from_numpy(task_vector)
        return task_vector

    def get_state_dict(self, model: nn.Module):
        """Return the trainable-only state dict if `trainable_only` is set, else the full one."""
        if self.trainable_only:
            return trainable_state_dict(model)
        else:
            return model.state_dict()

TaskVectorViolinPlot

Bases: BaseAlgorithm, LightningFabricMixin, SimpleProfilerMixin

Plot violin plots of task vectors as in: L. Shen, A. Tang, E. Yang et al., "Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging" (https://arxiv.org/abs/2410.21804).

Source code in fusion_bench/method/analysis/task_vector_violin_plot.py
class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMixin):
    R"""
    Plot violin plots of task vectors as in:
    [L.Shen, A.Tang, E.Yang et al. Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging](https://arxiv.org/abs/2410.21804)
    """

    # config_mapping is a mapping from the attributes to the key in the configuration files
    _config_mapping = BaseAlgorithm._config_mapping | {
        "trainable_only": "trainable_only",
        "max_points_per_model": "max_points_per_model",
        "fig_kwargs": "fig_kwargs",
        "_output_path": "output_path",
    }

    def __init__(
        self,
        trainable_only: bool,
        max_points_per_model: Optional[int] = 1000,
        fig_kwawrgs=None,
        output_path: Optional[str] = None,
        fig_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        R"""
        This class creates violin plots to visualize task vectors, which represent the differences
        between fine-tuned models and their pretrained base model.

        Args:
            trainable_only (bool): If True, only consider trainable parameters when computing
                task vectors. If False, use all parameters.
            max_points_per_model (int, optional): If set to a positive number smaller than the
                task vector length, each task vector is randomly downsampled (without
                replacement) to this many values before plotting. Defaults to 1000.
            fig_kwawrgs (dict, optional): Misspelled legacy name of `fig_kwargs`, kept so that
                existing callers continue to work. Ignored when `fig_kwargs` is given.
            output_path (str, optional): Path where the violin plot will be saved. If None,
                uses the fabric logger's log directory. Defaults to None.
            fig_kwargs (dict, optional): Dictionary of keyword arguments to pass to
                `matplotlib.pyplot.subplots`. Common options include:

                - figsize: Tuple of (width, height) in inches
                - dpi: Dots per inch
                - facecolor: Figure background color

                Defaults to None.
            kwargs: Additional keyword arguments passed to the parent class(es).

        Example:

            ```python
            plotter = TaskVectorViolinPlot(
                trainable_only=True,
                fig_kwargs={'figsize': (10, 6), 'dpi': 300},
                output_path='./plots'
            )

            plotter.run(modelpool)
            ```
        """
        self.trainable_only = trainable_only
        # Prefer the correctly spelled keyword; fall back to the legacy one.
        # Previously `fig_kwargs=` was swallowed by **kwargs and never stored.
        self.fig_kwargs = fig_kwargs if fig_kwargs is not None else fig_kwawrgs
        self.max_points_per_model = max_points_per_model
        self._output_path = output_path
        super().__init__(**kwargs)

    @property
    def output_path(self):
        """Directory for output files; falls back to the fabric logger's log directory."""
        if self._output_path is None:
            return self.fabric.logger.log_dir
        else:
            return self._output_path

    def run(self, modelpool: BaseModelPool):
        """Create violin plots of task vectors comparing different fine-tuned models against a pretrained model.

        This method implements the visualization technique from the paper "Efficient and Effective
        Weight-Ensembling Mixture of Experts for Multi-Task Model Merging". It:

        1. Loads the pretrained model
        2. Computes task vectors (differences between fine-tuned and pretrained models)
        3. Creates violin plots showing the distribution of values in these task vectors
           (`task_vector_violin.pdf`) and of their absolute values
           (`task_vector_violin_abs.pdf`)

        Args:
            modelpool (BaseModelPool): Model pool containing the pretrained model and fine-tuned models

        Returns:
            pretrained_model (nn.Module): The pretrained model. The plots are saved to the
                specified output path.

        Raises:
            AssertionError: If the model pool has no pretrained model.
        """
        # Explicit raise so the check also runs under `python -O`.
        if not modelpool.has_pretrained:
            raise AssertionError("The model pool must contain a pretrained model.")
        pretrained_model = modelpool.load_pretrained_model()

        # Compute task vectors for each fine-tuned model
        with torch.no_grad(), timeit_context("Computing task vectors"):
            task_vectors: Dict[str, NDArray] = {}
            for name, finetuned_model in tqdm(
                modelpool.named_models(), total=len(modelpool)
            ):
                print(f"computing task vectors for {name}")
                task_vectors[name] = self.get_task_vector(
                    pretrained_model, finetuned_model
                )

        labels = list(task_vectors.keys())

        # === Create violin plot ===
        self._save_violin_plot(
            data=list(task_vectors.values()),
            labels=labels,
            ylabel="Task Vector Values",
            title="Distribution of Task Vector Values",
            filename="task_vector_violin.pdf",
            timer_name="ploting",
        )

        # === Create violin plot (Abs values) ===
        self._save_violin_plot(
            data=[np.abs(values) for values in task_vectors.values()],
            labels=labels,
            ylabel="The Absolute Values",
            title="Distribution of Task Vector Absolute Values",
            filename="task_vector_violin_abs.pdf",
            timer_name="ploting abs value plot",
        )

        return pretrained_model

    def _save_violin_plot(self, data, labels, ylabel, title, filename, timer_name):
        """Draw one violin per entry of `data` and save the figure as `filename` under `self.output_path`."""
        fig, ax = plt.subplots(
            1, 1, **(self.fig_kwargs if self.fig_kwargs is not None else {})
        )
        fig = cast(plt.Figure, fig)
        ax = cast(plt.Axes, ax)

        # Create violin plot using seaborn
        with timeit_context(timer_name):
            sns.violinplot(data=data, ax=ax)

        # Customize plot
        ax.set_xticklabels(labels, rotation=45, ha="right")
        ax.set_ylabel(ylabel)
        ax.set_title(title)

        # Adjust layout to prevent label cutoff and save plot. The figure-bound
        # methods are used so the result does not depend on which figure
        # pyplot currently considers active.
        fig.tight_layout()
        os.makedirs(self.output_path, exist_ok=True)
        output_file = f"{self.output_path}/{filename}"
        fig.savefig(output_file, bbox_inches="tight")
        plt.close(fig)

    def get_task_vector(self, pretrained_model, finetuned_model):
        """
        Build the (optionally downsampled) task vector of a fine-tuned model.

        Args:
            pretrained_model (nn.Module): The pretrained base model.
            finetuned_model (nn.Module): The fine-tuned model.

        Returns:
            A 1-D float32 NumPy array holding `finetuned - pretrained`, randomly
            downsampled to `max_points_per_model` values when that limit is set
            and exceeded.
        """
        task_vector = state_dict_sub(
            self.get_state_dict(finetuned_model),
            self.get_state_dict(pretrained_model),
        )
        task_vector = state_dict_to_vector(task_vector)

        task_vector = task_vector.cpu().float().numpy()
        # downsample if necessary
        if (
            self.max_points_per_model is not None
            and self.max_points_per_model > 0
            and task_vector.shape[0] > self.max_points_per_model
        ):
            log.info(
                f"Downsampling task vectors to {self.max_points_per_model} points."
            )
            indices = np.random.choice(
                task_vector.shape[0], self.max_points_per_model, replace=False
            )
            task_vector = task_vector[indices].copy()

        return task_vector

    def get_state_dict(self, model: nn.Module):
        """Return the trainable-only state dict if `trainable_only` is set, else the full one."""
        if self.trainable_only:
            return trainable_state_dict(model)
        else:
            return model.state_dict()

__init__(trainable_only, max_points_per_model=1000, fig_kwawrgs=None, output_path=None, **kwargs)

This class creates violin plots to visualize task vectors, which represent the differences between fine-tuned models and their pretrained base model.

Parameters:

  • trainable_only (bool) –

    If True, only consider trainable parameters when computing task vectors. If False, use all parameters.

  • fig_kwargs (dict) –

    Dictionary of keyword arguments to pass to matplotlib.pyplot.subplots. Common options include figsize (tuple of (width, height) in inches), dpi (dots per inch), and facecolor (figure background color). Defaults to None.

  • output_path (str, default: None ) –

    Path where the violin plot will be saved. If None, uses the fabric logger's log directory. Defaults to None.

  • kwargs

    Additional keyword arguments passed to the parent class(es).

Example:

```python
plotter = TaskVectorViolinPlot(
    trainable_only=True,
    fig_kwargs={'figsize': (10, 6), 'dpi': 300},
    output_path='./plots'
)

plotter.run(modelpool)
```
Source code in fusion_bench/method/analysis/task_vector_violin_plot.py
def __init__(
    self,
    trainable_only: bool,
    max_points_per_model: Optional[int] = 1000,
    fig_kwawrgs=None,
    output_path: Optional[str] = None,
    fig_kwargs: Optional[dict] = None,
    **kwargs,
):
    R"""
    This class creates violin plots to visualize task vectors, which represent the differences
    between fine-tuned models and their pretrained base model.

    Args:
        trainable_only (bool): If True, only consider trainable parameters when computing
            task vectors. If False, use all parameters.
        max_points_per_model (int, optional): Upper bound on the number of task vector
            values kept per model; stored on the instance as `max_points_per_model`.
            Defaults to 1000.
        fig_kwawrgs (dict, optional): Misspelled legacy name of `fig_kwargs`, kept so that
            existing callers continue to work. Ignored when `fig_kwargs` is given.
        output_path (str, optional): Path where the violin plot will be saved. If None,
            uses the fabric logger's log directory. Defaults to None.
        fig_kwargs (dict, optional): Dictionary of keyword arguments to pass to
            `matplotlib.pyplot.subplots`. Common options include:

            - figsize: Tuple of (width, height) in inches
            - dpi: Dots per inch
            - facecolor: Figure background color

            Defaults to None.
        kwargs: Additional keyword arguments passed to the parent class(es).

    Example:

        ```python
        plotter = TaskVectorViolinPlot(
            trainable_only=True,
            fig_kwargs={'figsize': (10, 6), 'dpi': 300},
            output_path='./plots'
        )

        plotter.run(modelpool)
        ```
    """
    self.trainable_only = trainable_only
    # Prefer the correctly spelled keyword; fall back to the legacy one.
    # Previously `fig_kwargs=` was swallowed by **kwargs and never stored.
    self.fig_kwargs = fig_kwargs if fig_kwargs is not None else fig_kwawrgs
    self.max_points_per_model = max_points_per_model
    self._output_path = output_path
    super().__init__(**kwargs)

run(modelpool)

Create violin plots of task vectors comparing different fine-tuned models against a pretrained model.

This method implements the visualization technique from the paper "Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging". It:

  1. Loads the pretrained model
  2. Computes task vectors (differences between fine-tuned and pretrained models)
  3. Creates violin plots showing the distribution of values in these task vectors

Parameters:

  • modelpool (BaseModelPool) –

    Model pool containing the pretrained model and fine-tuned models

Returns:

  • pretrained_model (nn.Module) –

    The pretrained model loaded from the model pool. The plots are saved to the specified output path.

Source code in fusion_bench/method/analysis/task_vector_violin_plot.py
def run(self, modelpool: BaseModelPool):
    """Plot the distribution of task vector values for every fine-tuned model in the pool.

    Follows the visualization used in "Efficient and Effective Weight-Ensembling
    Mixture of Experts for Multi-Task Model Merging":

    1. Load the pretrained model.
    2. Compute one task vector (fine-tuned minus pretrained) per fine-tuned model.
    3. Save two violin plots under `self.output_path`: one of the raw values
       (`task_vector_violin.pdf`) and one of the absolute values
       (`task_vector_violin_abs.pdf`).

    Args:
        modelpool (BaseModelPool): Model pool containing the pretrained model and fine-tuned models

    Returns:
        pretrained_model (nn.Module): The pretrained model loaded from the pool.
    """
    assert modelpool.has_pretrained
    pretrained_model = modelpool.load_pretrained_model()

    # Step 2: one task vector per fine-tuned model, keyed by model name
    with torch.no_grad(), timeit_context("Computing task vectors"):
        task_vectors: Dict[str, NDArray] = {}
        model_iterator = tqdm(modelpool.named_models(), total=len(modelpool))
        for name, finetuned_model in model_iterator:
            print(f"computing task vectors for {name}")
            task_vectors[name] = self.get_task_vector(pretrained_model, finetuned_model)

    # Step 3: both plots share the same recipe and differ only in these settings:
    # (value transform, timer label, y-axis label, title, output file name)
    plot_specs = (
        (
            None,
            "ploting",
            "Task Vector Values",
            "Distribution of Task Vector Values",
            "task_vector_violin.pdf",
        ),
        (
            np.abs,
            "ploting abs value plot",
            "The Absolute Values",
            "Distribution of Task Vector Absolute Values",
            "task_vector_violin_abs.pdf",
        ),
    )

    for transform, timer_label, y_label, title, file_name in plot_specs:
        subplot_options = self.fig_kwargs if self.fig_kwargs is not None else {}
        fig, ax = plt.subplots(1, 1, **subplot_options)
        fig = cast(plt.Figure, fig)
        ax = cast(plt.Axes, ax)

        # Gather the per-model value arrays and their tick labels
        if transform is None:
            data = list(task_vectors.values())
        else:
            data = [transform(values) for values in task_vectors.values()]
        labels = list(task_vectors.keys())

        with timeit_context(timer_label):
            sns.violinplot(data=data, ax=ax)

        ax.set_xticklabels(labels, rotation=45, ha="right")
        ax.set_ylabel(y_label)
        ax.set_title(title)

        # Tighten the layout so rotated labels are not cut off, then write the file
        plt.tight_layout()
        os.makedirs(self.output_path, exist_ok=True)
        output_file = f"{self.output_path}/{file_name}"
        plt.savefig(output_file, bbox_inches="tight")
        plt.close(fig)

    return pretrained_model