Source code for ccsd.src.utils.plot

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""plot.py: utility functions for plotting.
"""

import math
import os
import pickle
import tempfile
import warnings
from typing import Any, Dict, FrozenSet, List, Optional, Union

import hypernetx as hnx  # to visualize CC of dim 2
import imageio.v3 as imageio
import kaleido  # import kaleido FIRST to avoid any conflicts
import matplotlib
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import plotly
import plotly.graph_objs as go
import torch
from easydict import EasyDict
from rdkit import Chem, RDLogger
from rdkit.Chem import AllChem, Draw
from toponetx.classes.combinatorial_complex import CombinatorialComplex
from tqdm import tqdm

from ccsd.src.utils.graph_utils import adjs_to_graphs, quantize, quantize_mol
from ccsd.src.utils.mol_utils import construct_mol, gen_mol

warnings.filterwarnings("ignore", category=matplotlib.MatplotlibDeprecationWarning)
plotly.io.kaleido.scope.mathjax = None
RDLogger.DisableLog("rdApp.*")
# Parameters to make graph plots look nicer.
options = {"node_size": 2, "edge_color": "black", "linewidths": 1, "width": 0.5}


def save_fig(
    config: EasyDict,
    save_dir: Optional[str] = None,
    title: str = "fig",
    dpi: int = 300,
    is_sample: bool = True,
) -> None:
    """Adjust the current matplotlib figure, then show it or save it to disk.
    Adapted from Jo, J. & al (2022)

    Args:
        config (EasyDict): configuration file. Only ``config.folder`` is read, and only when is_sample is True.
        save_dir (Optional[str], optional): directory to save the figures. If None, the figure is shown instead of saved. Defaults to None.
        title (str, optional): name of the file. Defaults to "fig".
        dpi (int, optional): DPI (Dots per Inch). Defaults to 300.
        is_sample (bool, optional): whether the figure is generated during the sample phase.
            If True, the figure goes to ``<config.folder>/samples/fig/<save_dir>``, otherwise to ``<save_dir>/fig``. Defaults to True.
    """
    plt.tight_layout()
    plt.subplots_adjust(top=0.85)
    if save_dir is None:
        plt.show()
    else:
        if is_sample:
            fig_dir = os.path.join(config.folder, "samples", "fig", save_dir)
        else:
            fig_dir = os.path.join(save_dir, "fig")
        # exist_ok avoids the check-then-create race when several runs save concurrently
        os.makedirs(fig_dir, exist_ok=True)
        plt.savefig(
            os.path.join(fig_dir, title),
            bbox_inches="tight",
            dpi=dpi,
            transparent=False,
        )
    # Always release the current figure so repeated calls do not accumulate open figures
    plt.close()
def plot_graphs_list(
    config: EasyDict,
    graphs: List[Union[nx.Graph, Dict[str, Any]]],
    title: str = "title",
    max_num: int = 16,
    save_dir: Optional[str] = None,
    N: int = 0,
) -> None:
    """Draw up to max_num graphs on a square grid of subplots and hand the figure to save_fig.
    Adapted from Jo, J. & al (2022)

    Args:
        config (EasyDict): configuration file (forwarded to save_fig)
        graphs (List[Union[nx.Graph, Dict[str, Any]]]): graphs to plot. Entries that are not
            networkx graphs must expose the graph through a ``g`` attribute.
        title (str, optional): title of the plot, also used as file name. Defaults to "title".
        max_num (int, optional): number of graphs to plot (capped by the number of graphs). Defaults to 16.
        save_dir (Optional[str], optional): directory to save the figures. Defaults to None.
        N (int, optional): page index; the first ``max_num * N`` graphs of the list are skipped. Defaults to 0.
    """
    n_plots = min(len(graphs), max_num)
    grid_side = int(math.ceil(np.sqrt(n_plots)))
    offset = n_plots * N

    fig = plt.figure()
    for slot in range(n_plots):
        entry = graphs[slot + offset]
        graph = entry.copy() if isinstance(entry, nx.Graph) else entry.g.copy()
        assert isinstance(graph, nx.Graph)  # check if we have a networkx graph

        # Isolated nodes are dropped before counting and drawing
        graph.remove_nodes_from(list(nx.isolates(graph)))
        n_edges = graph.number_of_edges()
        n_nodes = graph.number_of_nodes()
        n_loops = nx.number_of_selfloops(graph)

        axis = plt.subplot(grid_side, grid_side, slot + 1)
        layout = nx.spring_layout(graph)
        nx.draw(graph, layout, with_labels=False, **options)
        # Self-loops are excluded from the displayed edge count
        axis.title.set_text(f"e={n_edges - n_loops}, n={n_nodes}")
        axis.set_axis_off()

    fig.suptitle(title)
    save_fig(config=config, save_dir=save_dir, title=title, is_sample=True)
def save_graph_list(
    config: EasyDict,
    log_folder_name: str,
    exp_name: str,
    gen_graph_list: List[nx.Graph],
) -> str:
    """Save the generated graphs in a pickle file.
    Adapted from Jo, J. & al (2022)

    Args:
        config (EasyDict): configuration file (``config.folder`` is the root of the output tree)
        log_folder_name (str): name of the folder where the pickle file will be saved
        exp_name (str): name of the experiment (used as file name, without extension)
        gen_graph_list (List[nx.Graph]): list of generated graphs

    Returns:
        str: path to the pickle file
    """
    path = os.path.join(config.folder, "samples", "pkl", log_folder_name)
    # exist_ok avoids the check-then-create race when several runs save concurrently
    os.makedirs(path, exist_ok=True)
    save_dir = os.path.join(path, f"{exp_name}.pkl")
    with open(save_dir, "wb") as f:
        pickle.dump(obj=gen_graph_list, file=f, protocol=pickle.HIGHEST_PROTOCOL)
    return save_dir
def plot_cc_list(
    config: EasyDict,
    ccs: List[Union[CombinatorialComplex, Dict[str, Any]]],
    title: str = "title",
    max_num: int = 16,
    save_dir: Optional[str] = None,
    N: int = 0,
) -> None:
    """Draw up to max_num combinatorial complexes of dimension 2 as hypergraphs (via hypernetx)
    on a square grid of subplots and hand the figure to save_fig.

    Args:
        config (EasyDict): configuration file (forwarded to save_fig)
        ccs (List[Union[CombinatorialComplex, Dict[str, Any]]]): combinatorial complexes to plot.
            Dictionary entries must store the complex under the "cc" key.
        title (str, optional): title of the plot, also used as file name. Defaults to "title".
        max_num (int, optional): number of combinatorial complexes to plot (capped by the number of complexes). Defaults to 16.
        save_dir (Optional[str], optional): directory to save the figures. Defaults to None.
        N (int, optional): page index; the first ``max_num * N`` complexes of the list are skipped. Defaults to 0.
    """
    n_plots = min(len(ccs), max_num)
    grid_side = int(math.ceil(np.sqrt(n_plots)))
    offset = n_plots * N

    fig = plt.figure()
    for slot in range(n_plots):
        entry = ccs[slot + offset]
        cc = entry.get("cc", None) if isinstance(entry, dict) else entry
        assert isinstance(
            cc, CombinatorialComplex
        ), "elements should be combinatorial complexes"  # check if we have a combinatorial complex

        n_nodes = len(cc.skeleton(0))  # rank 0 cells
        n_edges = len(cc.skeleton(1))  # rank 1 cells
        n_faces = len(cc.skeleton(2))  # rank 2 cells

        # Only edges and faces become hyperedges, so isolated nodes never reach the plot.
        # Edges take the keys 0..n_edges-1, faces the keys that follow.
        hyperedges = {
            key: tuple(str(node) for node in edge)
            for key, edge in enumerate(cc.skeleton(1))
        }
        for key, face in enumerate(cc.skeleton(2)):
            hyperedges[key + n_edges] = tuple(str(node) for node in face)
        hypergraph = hnx.Hypergraph(hyperedges)

        axis = plt.subplot(grid_side, grid_side, slot + 1)
        hnx.drawing.draw(
            hypergraph, with_edge_labels=False, with_node_labels=False, ax=axis
        )
        axis.title.set_text(f"n={n_nodes}, e={n_edges}, f={n_faces}")
        axis.set_axis_off()

    fig.suptitle(title)
    save_fig(config=config, save_dir=save_dir, title=title, is_sample=True)
def save_cc_list(
    config: EasyDict,
    log_folder_name: str,
    exp_name: str,
    gen_cc_list: List[CombinatorialComplex],
) -> str:
    """Save the generated combinatorial complexes in a pickle file.

    Args:
        config (EasyDict): configuration file (``config.folder`` is the root of the output tree)
        log_folder_name (str): name of the folder where the pickle file will be saved
        exp_name (str): name of the experiment (used as file name, without extension)
        gen_cc_list (List[CombinatorialComplex]): list of generated ccs

    Returns:
        str: path to the pickle file
    """
    path = os.path.join(config.folder, "samples", "pkl", log_folder_name)
    # exist_ok avoids the check-then-create race when several runs save concurrently
    os.makedirs(path, exist_ok=True)
    save_dir = os.path.join(path, f"{exp_name}.pkl")
    with open(save_dir, "wb") as f:
        pickle.dump(obj=gen_cc_list, file=f, protocol=pickle.HIGHEST_PROTOCOL)
    return save_dir
def plot_molecule_list(
    config: EasyDict,
    mols: List[Chem.Mol],
    title: str = "title",
    max_num: int = 16,
    save_dir: Optional[str] = None,
    N: int = 0,
) -> None:
    """Draw up to max_num molecules (rendered by rdkit) on a square grid of subplots
    and hand the figure to save_fig. Each subplot is titled with the molecule's SMILES.

    Args:
        config (EasyDict): configuration file (forwarded to save_fig)
        mols (List[Chem.Mol]): molecules to plot
        title (str, optional): title of the plot, also used as file name. Defaults to "title".
        max_num (int, optional): number of molecules to plot (capped by the number of molecules). Defaults to 16.
        save_dir (Optional[str], optional): directory to save the figures. Defaults to None.
        N (int, optional): page index; the first ``max_num * N`` molecules of the list are skipped. Defaults to 0.
    """
    n_plots = min(len(mols), max_num)
    grid_side = int(math.ceil(np.sqrt(n_plots)))
    offset = n_plots * N

    fig = plt.figure()
    for slot in range(n_plots):
        molecule = mols[slot + offset]
        assert isinstance(
            molecule, Chem.Mol
        ), "elements should be molecules"  # check if we have a molecule

        axis = plt.subplot(grid_side, grid_side, slot + 1)
        axis.imshow(Draw.MolToImage(molecule, size=(300, 300)))
        axis.title.set_text(f"{Chem.MolToSmiles(molecule)}")
        axis.set_axis_off()

    fig.suptitle(title)
    save_fig(config=config, save_dir=save_dir, title=title, is_sample=True)
def save_molecule_list(
    config: EasyDict, log_folder_name: str, exp_name: str, gen_mol_list: List[Chem.Mol]
) -> str:
    """Save the generated molecules in a pickle file.

    Args:
        config (EasyDict): configuration file (``config.folder`` is the root of the output tree)
        log_folder_name (str): name of the folder where the pickle file will be saved
        exp_name (str): name of the experiment (used as file name, without extension)
        gen_mol_list (List[Chem.Mol]): list of generated molecules

    Returns:
        str: path to the pickle file
    """
    path = os.path.join(config.folder, "samples", "pkl", log_folder_name)
    # exist_ok avoids the check-then-create race when several runs save concurrently
    os.makedirs(path, exist_ok=True)
    save_dir = os.path.join(path, f"{exp_name}.pkl")
    with open(save_dir, "wb") as f:
        pickle.dump(obj=gen_mol_list, file=f, protocol=pickle.HIGHEST_PROTOCOL)
    return save_dir
def plot_lc(
    config: EasyDict,
    learning_curves: Dict[str, List[float]],
    f_dir: str = "./",
    filename: str = "learning_curves",
    cols: int = 3,
) -> None:
    """Plot every learning curve in its own subplot and hand the figure to save_fig.

    Args:
        config (EasyDict): configuration file (forwarded to save_fig)
        learning_curves (Dict[str, List[float]]): dictionary mapping a curve name to its values
        f_dir (str, optional): directory to save the figure. Defaults to "./".
        filename (str, optional): name of the figure. Defaults to "learning_curves".
        cols (int, optional): number of columns in the figure. Defaults to 3.
    """
    n_rows = int(math.ceil(len(learning_curves) / cols))
    fig = plt.figure(figsize=(20, 10))

    for position, name in enumerate(learning_curves, start=1):
        axis = plt.subplot(n_rows, cols, position)
        axis.plot(learning_curves[name])
        # Underscores are replaced by spaces to make the subplot title more readable
        axis.title.set_text(name.replace("_", " "))

    fig.suptitle("Learning curves")
    save_fig(config=config, save_dir=f_dir, title=filename, is_sample=False)
def _get_atom_positions(molecule: Chem.Mol) -> Any:
    """Return one (x, y, z) position per atom, embedding the molecule when it has no conformer.

    The existing conformer is used when there is one. Otherwise the RDKit embedding
    parameter sets ETKDG, ETKDGv2, ETKDGv3 and ETDG are tried in that order; if all of
    them fail, 2D coordinates are computed and returned with z = 0.
    The molecule is modified in place (a conformer is added to it).

    Args:
        molecule (Chem.Mol): The RDKit molecule.

    Returns:
        Any: indexable collection of per-atom positions, each with three components.
    """
    if molecule.GetNumConformers():
        # Coordinates are already available: nothing to embed
        return molecule.GetConformer().GetPositions()

    last_error = None
    for params_name in ("ETKDG", "ETKDGv2", "ETKDGv3", "ETDG"):
        try:
            # The parameter factory is looked up inside the try block so that a
            # factory missing from the installed RDKit simply falls through to the next one
            AllChem.EmbedMolecule(molecule, getattr(AllChem, params_name)())
            return molecule.GetConformer().GetPositions()
        except Exception as err:  # embedding failed, try the next parameter set
            last_error = err

    try:
        current_mol_name = Chem.MolToSmiles(molecule)
    except Exception:
        current_mol_name = "Unknown molecule"
    print(
        f"Could not embed molecule {current_mol_name}.\nUsing 2D coordinates instead. \nError: {last_error}\n"
    )
    AllChem.Compute2DCoords(molecule)
    conformer = molecule.GetConformer()
    flat_positions = [
        conformer.GetAtomPosition(atom.GetIdx()) for atom in molecule.GetAtoms()
    ]
    # Keep three components per atom so that bond midpoints can still be unpacked as (x, y, z)
    return [(pos.x, pos.y, 0.0) for pos in flat_positions]


def plot_3D_molecule(
    molecule: Chem.Mol,
    atomic_radii: Optional[Dict[str, float]] = None,
    cpk_colors: Optional[Dict[str, str]] = None,
) -> plotly.graph_objs.Figure:
    """Creates a 3D plot of the molecule.

    Atoms are drawn as markers (sized by atomic radius, colored with CPK colors), bonds as
    lines. Atom indices and bond lengths are added as annotations.
    If the molecule has no conformer, one is generated and stored in the molecule (in place).

    Args:
        molecule (Chem.Mol): The RDKit molecule to plot.
        atomic_radii (Optional[Dict[str, float]], optional): Dictionary mapping atomic symbols to atomic radii.
            Symbols missing from the dictionary get a radius of 1.0. Defaults to None.
        cpk_colors (Optional[Dict[str, str]], optional): Dictionary mapping atomic symbols to CPK colors.
            Symbols missing from the dictionary are drawn in gray. Defaults to None.

    Returns:
        plotly.graph_objs.Figure: The 3D plotly figure of the molecule.
    """
    # Default atomic radii from https://en.wikipedia.org/wiki/Atomic_radii_of_the_elements_(data_page)
    if atomic_radii is None:
        atomic_radii = {"C": 0.77, "F": 0.71, "H": 0.38, "N": 0.75, "O": 0.73}
    # Default CPK colors from https://en.wikipedia.org/wiki/CPK_coloring
    if cpk_colors is None:
        cpk_colors = {"C": "black", "F": "green", "H": "white", "N": "blue", "O": "red"}

    # Atom symbols and coordinates (3D if possible, 2D with z = 0 otherwise)
    atom_symbols = [atom.GetSymbol() for atom in molecule.GetAtoms()]
    atom_positions = _get_atom_positions(molecule)
    x_coordinates = [pos[0] for pos in atom_positions]
    y_coordinates = [pos[1] for pos in atom_positions]
    z_coordinates = [pos[2] for pos in atom_positions]

    radii = [atomic_radii.get(symbol, 1.0) for symbol in atom_symbols]
    colors = [cpk_colors.get(symbol, "gray") for symbol in atom_symbols]

    # Bonds as (begin atom index, end atom index, bond length)
    bonds = []
    for bond in molecule.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        dist = np.linalg.norm(
            np.array(atom_positions[i]) - np.array(atom_positions[j])
        )
        bonds.append((i, j, dist))

    # Atom trace: radii are used to adjust the marker sizes
    atom_trace = go.Scatter3d(
        x=x_coordinates,
        y=y_coordinates,
        z=z_coordinates,
        mode="markers",
        marker=dict(
            color=colors,
            line=dict(color="lightgray", width=2),
            size=[r * 10 for r in radii],  # Multiply by 10 for better visibility
            symbol="circle",
            opacity=0.8,
        ),
        text=atom_symbols,
        name="",
    )

    # Bond trace: one segment per bond, a None entry separates consecutive segments
    bond_x, bond_y, bond_z = [], [], []
    for i, j, _ in bonds:
        bond_x += [x_coordinates[i], x_coordinates[j], None]
        bond_y += [y_coordinates[i], y_coordinates[j], None]
        bond_z += [z_coordinates[i], z_coordinates[j], None]
    bond_trace = go.Scatter3d(
        x=bond_x,
        y=bond_y,
        z=bond_z,
        hoverinfo="none",
        mode="lines",
        marker=dict(color="grey", size=7, opacity=1),
    )

    # Annotations: atom indices ...
    annotations_id = [
        dict(
            text=num, x=x, y=y, z=z, showarrow=False, yshift=15, font=dict(color="blue")
        )
        for num, (x, y, z) in enumerate(
            zip(x_coordinates, y_coordinates, z_coordinates)
        )
    ]
    # ... and bond lengths, placed at the middle of each bond
    annotations_length = []
    for i, j, dist in bonds:
        x_middle, y_middle, z_middle = (
            np.array(atom_positions[i]) + np.array(atom_positions[j])
        ) / 2
        annotations_length.append(
            dict(
                text=f"{dist:.2f}",
                x=x_middle,
                y=y_middle,
                z=z_middle,
                showarrow=False,
                yshift=15,
            )
        )
    annotations = annotations_id + annotations_length

    # Create the layout: axes are hidden so that only the molecule is visible
    axis_params = dict(
        showgrid=False,
        showbackground=False,
        showticklabels=False,
        zeroline=False,
        # `title.font` is the current spelling of the deprecated `titlefont` attribute
        title=dict(font=dict(color="white")),
    )
    layout = dict(
        scene=dict(
            xaxis=axis_params,
            yaxis=axis_params,
            zaxis=axis_params,
            annotations=annotations,
        ),
        margin=dict(r=0, l=0, b=0, t=0),
        showlegend=False,
    )

    # Create the figure
    fig = go.Figure(data=[atom_trace, bond_trace], layout=layout)
    return fig
def rotate_molecule_animation(
    figure: plotly.graph_objs.Figure,
    filedir: str,
    filename: str,
    duration: float = 1.0,
    frames: int = 30,
    rotations_per_sec: float = 1.0,
    overwrite: bool = False,
    engine: str = "kaleido",
) -> None:
    """Creates an animated GIF of the molecule rotating.

    The camera is moved on a circle around the scene, one PNG is rendered per frame in a
    temporary directory, and the PNGs are assembled into ``<filedir>/<filename>.gif``.
    The figure passed by the caller is not modified.

    Args:
        figure (plotly.graph_objs.Figure): The 3D plotly figure of the molecule.
        filedir (str): The directory to save the animated GIF.
        filename (str): The filename of the output animated GIF (a trailing ".gif" is optional).
        duration (float, optional): Duration of the animation in seconds. Defaults to 1.0.
        frames (int, optional): Number of frames per second of animation. Defaults to 30.
        rotations_per_sec (float, optional): Number of rotations per second. Defaults to 1.0.
        overwrite (bool, optional): If True, overwrite the file if it already exists. Defaults to False.
        engine (str, optional): engine to use for the .write_image plotly method. Defaults to "kaleido".

    Raises:
        FileExistsError: if the GIF already exists and overwrite is False.
    """
    # Remove .gif extension if provided
    if filename.lower().endswith(".gif"):
        filename = filename[:-4]
    gif_path = os.path.join(filedir, f"{filename}.gif")
    if not overwrite and os.path.isfile(gif_path):
        raise FileExistsError(
            f"{filename}.gif already exists. Set overwrite=True to overwrite the file."
        )

    # Define the frame range
    animation_range = range(int(duration * frames))

    # Render from a copy of the figure so that the camera of the caller's figure is untouched
    frame_fig = go.Figure(data=figure.data, layout=figure.layout)

    # The temporary directory is unique to this call and removed even if rendering fails
    with tempfile.TemporaryDirectory(prefix="tempdir_animation_") as tmp_dir:
        # Create the animation frames
        print("Creating the rotating molecule animation ...")
        frame_paths = []
        for i in tqdm(animation_range):
            angle = 2 * np.pi * i * rotations_per_sec / frames
            frame_fig.update_layout(
                scene_camera_eye=dict(x=2 * np.sin(angle), y=2 * np.cos(angle), z=1)
            )
            frame_path = os.path.join(tmp_dir, f"{filename}_{i}.png")
            frame_fig.write_image(frame_path, engine=engine)
            frame_paths.append(frame_path)

        # Create the GIF
        print("Loading the images ...")
        images = [imageio.imread(frame_path) for frame_path in tqdm(frame_paths)]
        print("Creating the gif ...")
        # NOTE(review): the unit of `duration` (seconds vs milliseconds) depends on the
        # imageio version/plugin in use - confirm that 1 / frames gives the intended speed
        imageio.imwrite(gif_path, images, duration=(1 / frames), loop=0)

        # The images and the folder are deleted when leaving the `with` block
        print("Deleting the images ...")
def plot_diffusion_trajectory(
    gen_obj: List[torch.Tensor],
    is_molecule: bool = False,
    dataset: str = "QM9",
    largest_connected_comp: bool = True,
) -> Union[plotly.graph_objs.Figure, matplotlib.figure.Figure]:
    """Return the figure of one generated object as part of a diffusion trajectory.

    Args:
        gen_obj (List[torch.Tensor]): The generated object (node features (x) and adjacency matrix (adj),
            and rank-2 incidence matrix (rank2) if we generated combinatorial complexes).
            Only the first two entries are used.
        is_molecule (bool, optional): if True, we plot a molecule, otherwise a graph. Defaults to False.
        dataset (str, optional): The dataset from which the object was generated. Defaults to "QM9" (only used if is_molecule=True).
        largest_connected_comp (bool, optional): whether or not we keep only the largest connected component. Defaults to True.

    Returns:
        Union[plotly.graph_objs.Figure, matplotlib.figure.Figure]: The figure of the generated object
            (this implementation always builds a matplotlib figure).
    """
    x_, adj_ = gen_obj[0], gen_obj[1]
    fig = plt.figure(figsize=(10, 10))

    if is_molecule:
        # Preprocess the matrices
        x, adj = x_.unsqueeze(0), adj_.unsqueeze(0)  # batch it
        samples_int = quantize_mol(adj)
        samples_int = samples_int - 1
        samples_int[samples_int == -1] = 3  # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2
        adj = torch.nn.functional.one_hot(
            torch.tensor(samples_int), num_classes=4
        ).permute(0, 3, 1, 2)
        x = torch.where(x > 0.5, 1, 0)
        # Extra last channel, set to 1 where none of the other channels is active
        x = torch.concat([x, 1 - x.sum(dim=-1, keepdim=True)], dim=-1)

        # Generate the molecule
        mol = gen_mol(x, adj, dataset, largest_connected_comp)[0]
        if not (mol):  # if it failed, try to build the molecule without adjustment
            if dataset == "QM9":
                atomic_num_list = [6, 7, 8, 9, 0]
            else:
                atomic_num_list = [6, 7, 8, 9, 15, 16, 17, 35, 53, 0]
            mol = construct_mol(
                x[0].detach().cpu().numpy(),
                adj[0].detach().cpu().numpy(),
                atomic_num_list,
            )
        else:  # else, keep the generated molecule
            mol = mol[0]

        # Plot
        if mol is None:
            # The molecule could not be built: draw a blank image instead, with a
            # placeholder title since no SMILES can be computed
            mol_img = np.ones((300, 300))
            plt.imshow(mol_img, cmap="binary")
            title_str = "Invalid molecule"
        else:
            mol_img = Draw.MolToImage(mol, size=(300, 300))
            plt.imshow(mol_img)
            title_str = f"{Chem.MolToSmiles(mol)}"
    else:
        samples_int = quantize(adj_.unsqueeze(0))
        G = adjs_to_graphs(samples_int, True)[0]
        G.remove_nodes_from(list(nx.isolates(G)))
        e = G.number_of_edges()
        v = G.number_of_nodes()
        l = nx.number_of_selfloops(G)
        title_str = f"e={e - l}, n={v}"  # self-loops are not counted as edges
        pos = nx.spring_layout(G)
        nx.draw(G, pos, with_labels=False, **options)

    plt.title(title_str)
    plt.axis("off")
    return fig
def diffusion_animation(
    diff_traj: List[List[torch.Tensor]],
    is_molecule: bool = False,
    filedir: str = "./",
    filename: str = "diffusion_animation",
    fps: int = 25,
    overwrite: bool = True,
    engine: str = "kaleido",
    duration: float = 4.0,
    cropped: bool = False,
) -> None:
    """Creates an animated GIF of the diffusion trajectory.

    ``ceil(duration * fps)`` steps of the trajectory are rendered as PNGs in a temporary
    directory and assembled into ``<filedir>/<filename>.gif``.

    Args:
        diff_traj (List[List[torch.Tensor]]): The diffusion trajectory (list of generated node features (x)
            and adjacency matrices (adj), and rank-2 incidence matrices (rank2) if we generated combinatorial complexes).
        is_molecule (bool, optional): If True, the frames are molecules not graphs. Defaults to False.
        filedir (str, optional): The directory to save the animated GIF. Defaults to "./".
        filename (str, optional): The filename of the output animated GIF (a trailing ".gif" is optional). Defaults to "diffusion_animation".
        fps (int, optional): Number of frames per second. Defaults to 25.
        overwrite (bool, optional): If True, overwrite the file if it already exists. Defaults to True.
        engine (str, optional): engine to use for the .write_image plotly method if plotly is used. Defaults to "kaleido".
        duration (float, optional): duration of the animation (in seconds). Defaults to 4.0.
        cropped (bool, optional): if True, we select the first frames. Otherwise, we skip some frames to build
            the animation (evenly spaced steps, first and last steps included). Defaults to False.

    Raises:
        ValueError: if duration * fps yields no frame at all.
        AssertionError: if the trajectory does not contain more steps than the requested number of frames.
        FileExistsError: if the GIF already exists and overwrite is False.
        TypeError: if a frame is neither a plotly nor a matplotlib figure.
    """
    # Check if we have enough frames
    nb_frames = int(np.ceil(duration * fps))
    if nb_frames < 1:
        raise ValueError("duration * fps must yield at least one frame.")
    assert nb_frames < len(
        diff_traj
    ), f"Not enough frames ({len(diff_traj)}/{nb_frames}) to build a {fps}FPS animation of {duration} secondes."

    # Define the frame range
    if cropped:
        frame_indices = list(range(nb_frames))
    else:
        # Evenly spaced steps: exactly nb_frames of them (so the GIF lasts `duration`
        # seconds), always ending on the last step of the trajectory
        frame_indices = [
            int(k) for k in np.linspace(0, len(diff_traj) - 1, num=nb_frames).round()
        ]

    # Remove .gif extension if provided
    if filename.lower().endswith(".gif"):
        filename = filename[:-4]
    if not overwrite:
        # Check if the file already exists
        if os.path.isfile(os.path.join(filedir, f"{filename}.gif")):
            raise FileExistsError(
                f"{filename}.gif already exists. Set overwrite=True to overwrite the file."
            )

    fig_type = "molecule" if is_molecule else "graph"

    # The temporary directory is unique to this call and removed even if rendering fails
    with tempfile.TemporaryDirectory(prefix="tempdir_animation_") as tmp_dir:
        # Create the animation frames
        print(f"Creating the {fig_type} diffusion animation ...")
        frame_paths = []
        for i in tqdm(frame_indices):
            fig = plot_diffusion_trajectory(diff_traj[i], is_molecule=is_molecule)
            frame_path = os.path.join(tmp_dir, f"diffusion_{i}.png")
            if isinstance(fig, plotly.graph_objs.Figure):  # plotly
                fig.write_image(frame_path, engine=engine)
            elif isinstance(fig, matplotlib.figure.Figure):  # matplotlib
                fig.savefig(frame_path)
                plt.close(fig)  # close this very figure, not whichever one is current
            else:
                raise TypeError(
                    "The figure must be either a plotly.graph_objs.Figure or a matplotlib.figure.Figure. "
                    "Otherwise, it has not been implemented yet."
                )
            frame_paths.append(frame_path)

        # Create the GIF
        print("Loading the images ...")
        images = [imageio.imread(frame_path) for frame_path in tqdm(frame_paths)]
        print("Creating the gif ...")
        filepath = os.path.join(filedir, filename)
        # NOTE(review): the unit of `duration` (seconds vs milliseconds) depends on the
        # imageio version/plugin in use - confirm that 1 / fps gives the intended speed
        imageio.imwrite(f"{filepath}.gif", images, duration=(1 / fps), loop=0)

        # The images and the folder are deleted when leaving the `with` block
        print("Deleting the images ...")