"""General display functions."""
import random
import copy
from PIL import Image
import imageio
import colorsys
from pathlib import Path
import glob
import numpy as np
import contextlib
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import os
# Set the environment variable to turn off the pyart welcome message
os.environ["PYART_QUIET"] = "True"
import pyart.graph.cm_colorblind as pcm
import thuner.visualize.utils as utils
from thuner.log import setup_logger
logger = setup_logger(__name__)
style = "presentation"
__all__ = ["animate_object", "set_style"]
def discrete_cmap_norm(
levels, cmap_name="Reds", pad_start=0, pad_end=0, extend="neither"
):
"""Create a discrete colormap."""
number_levels = len(levels)
extend_above = 1 if extend in ["both", "max"] else 0
extend_below = 1 if extend in ["both", "min"] else 0
number_colors = pad_start + extend_below + number_levels + extend_above + pad_end
cmap = plt.get_cmap(cmap_name, number_colors)
colors = list(cmap(np.arange(0, number_colors)))
if extend in ["both", "max"]:
end = -pad_end - extend_above if pad_end != 0 else -extend_above
else:
end = -pad_end if pad_end != 0 else None
start = pad_start + extend_below
cmap = mcolors.ListedColormap(colors[start:end], f"{cmap_name}_discrete")
norm = mcolors.BoundaryNorm(levels, ncolors=number_levels, clip=False)
if extend in ["both", "min"]:
cmap.set_under(colors[start - extend_below])
if extend in ["both", "max"]:
cmap.set_over(colors[end])
return cmap, norm
def desaturate_colormap(cmap, factor=0.15):
"""Desaturate a colormap by a given factor."""
colors = cmap(np.linspace(0, 1, cmap.N))
hls_colors = [colorsys.rgb_to_hls(*color[:3]) for color in colors]
desaturated_hls_colors = [(h, l, s * factor) for h, l, s in hls_colors]
desaturated_rgb_colors = [
colorsys.hls_to_rgb(h, l, s) for h, l, s in desaturated_hls_colors
]
desaturated_cmap = mcolors.ListedColormap(
desaturated_rgb_colors, name=f"{cmap.name}_desaturated"
)
return desaturated_cmap
def hls_colormap(N=1, lightness=0.9, saturation=1):
"""Create a hls colormap."""
hls_colors = [(i / N, lightness, saturation) for i in range(N)]
rgb_colors = [colorsys.hls_to_rgb(h, l, s) for h, l, s in hls_colors]
hls_colormap = mcolors.ListedColormap(rgb_colors, name=f"hls_{lightness}_{N}")
return hls_colormap
mask_colors = ["cyan", "magenta", "gold", "cyan"]
mask_colormap = mcolors.LinearSegmentedColormap.from_list("mask", mask_colors, N=64)
runtime_colormap = mcolors.LinearSegmentedColormap.from_list("mask", mask_colors, N=12)
runtime_colors = [runtime_colormap(i) for i in range(12)]
runtime_colors = [mcolors.to_hex(color) for color in runtime_colors]
random.seed(4189)
random.shuffle(runtime_colors)
[docs]
@contextlib.contextmanager
def set_style(new_style):
"""Custom style manager for matplotlib."""
global style
original_style = style
style = new_style
try:
yield
finally:
style = original_style
# Desaturate the HomeyerRainbow colormap
desaturated_homeyer_rainbow = desaturate_colormap(pcm.HomeyerRainbow, factor=0.35)
reflectivity_levels = np.arange(-10, 60 + 5, 5)
reflectivity_norm = mcolors.BoundaryNorm(
reflectivity_levels, ncolors=desaturated_homeyer_rainbow.N, clip=True
)
brightness_rainbow = desaturate_colormap(pcm.HomeyerRainbow, factor=0.35)
brightness_rainbow.set_over((0, 0, 0, 0)) # Set over color to transparent
brightness_rainbow.set_under((0, 0, 0, 0))
brightness_levels = np.arange(180, 250 + 10, 10)
brightness_norm = mcolors.BoundaryNorm(
brightness_levels, ncolors=brightness_rainbow.N, clip=False
)
pcolormesh_style = {
"reflectivity": {
"cmap": desaturated_homeyer_rainbow,
"shading": "nearest",
"norm": reflectivity_norm,
},
"brightness_temperature": {
"cmap": brightness_rainbow,
"shading": "nearest",
"norm": brightness_norm,
},
}
figure_colors = {
"paper": {
"land": tuple(np.array([249.0, 246.0, 216.0]) / (256)),
"sea": tuple(np.array([240.0, 240.0, 256.0]) / (256)),
"coast": "black",
"legend": "w",
"key": "k",
"ellipse_axis": "w",
"ellipse_axis_shadow": "grey",
},
"presentation": {
"land": tuple(np.array([249.0, 246.0, 216.0]) / (256 * 3.5)),
"sea": tuple(np.array([245.0, 245.0, 256.0]) / (256 * 3.5)),
"coast": "white",
"legend": tuple(np.array([249.0, 246.0, 216.0]) / (256 * 3.5)),
"key": "tab:purple",
"ellipse_axis": "w",
"ellipse_axis_shadow": "k",
},
}
figure_colors["gadi"] = figure_colors["presentation"]
base_styles = {
"paper": "default",
"presentation": "dark_background",
"gadi": "dark_background",
}
custom_styles_dir = Path(__file__).parent / "styles"
styles = {
style: [base_styles[style], custom_styles_dir / f"{style}.mplstyle"]
for style in base_styles.keys()
}
def get_filepaths_dates(directory):
filepaths = np.array(sorted(glob.glob(str(directory / "*.png"))))
dates = []
for filepath in filepaths:
date = Path(filepath).stem
date = f"{date[:8]}"
dates.append(date)
dates = np.array(dates)
return filepaths, dates
def animate_all(visualize_options, output_directory):
if visualize_options is None:
return
for obj_options in visualize_options.objects.values():
for fig_options in obj_options.figures:
if fig_options.animate:
animate_object(fig_options.name, obj_options.name, output_directory)
[docs]
def animate_object(
fig_type,
obj,
output_directory,
save_directory=None,
figure_directory=None,
animation_name=None,
by_date=True,
):
"""
Animate object figures.
"""
if save_directory is None:
save_directory = output_directory / "visualize" / fig_type
if figure_directory is None:
figure_directory = output_directory / "visualize" / fig_type / obj
if animation_name is None:
animation_name = obj
logger.info(f"Animating {fig_type} figures for {obj} objects.")
filepaths, dates = get_filepaths_dates(figure_directory)
if by_date:
for date in np.unique(dates):
filepaths_date = filepaths[dates == date]
output_filepath = save_directory / f"{animation_name}_{date}.gif"
logger.info(f"Saving animation to {output_filepath}.")
images = [Image.open(f).convert("RGBA") for f in filepaths_date]
kwargs = {"duration": 200, "loop": 0}
imageio.mimsave(output_filepath, images, **kwargs)
else:
output_filepath = save_directory / f"{animation_name}.gif"
logger.info(f"Saving animation to {output_filepath}.")
images = [Image.open(f).convert("RGBA") for f in filepaths]
kwargs = {"duration": 200, "loop": 0}
imageio.mimsave(output_filepath, images, **kwargs)
def get_grid(time, filename, field, data_options, grid_options):
"""
Get the grid from a file.
"""
grid = utils.load_grid(filename)
return grid[field]