Source code for thuner.visualize.attribute

"""Functions for visualizing object attributes and classifications."""

import gc
from pathlib import Path
from pydantic import Field, model_validator
import tempfile
from typing import Any, Dict
from time import sleep
import xesmf as xe
import multiprocessing as mp
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import thuner.visualize.horizontal as horizontal
from thuner.utils import initialize_process, check_results
from thuner.utils import format_time, new_angle, circular_mean
from thuner.utils import BaseHandler, AttributeHandler
from thuner.attribute.utils import read_attribute_csv
from thuner.analyze.utils import read_options
import thuner.detect.detect as detect
import thuner.visualize.utils as utils
import thuner.visualize.visualize as visualize
from thuner.option.visualize import FigureOptions, GroupedHorizontalAttributeOptions
from thuner.option.visualize import HorizontalAttributeOptions
from thuner.log import setup_logger, logging_listener
from thuner.config import get_outputs_directory


__all__ = ["series", "grouped_horizontal"]

logger = setup_logger(__name__)
proj = ccrs.PlateCarree()


mcs_legend_options = {"ncol": 3, "loc": "lower center"}


def get_altitude_labels(
    track_options,
    object_name="mcs",
    object_level=1,
    member_objects=None,
    member_levels=None,
):
    """Get altitude labels for convective and stratiform objects."""
    object_options = track_options.levels[object_level].object_by_name(object_name)
    if member_objects is None:
        member_objects = object_options.grouping.member_objects
    if member_levels is None:
        all_member_objects = object_options.grouping.member_objects
        all_member_levels = object_options.grouping.member_levels
        member_levels = []
        for i, level in enumerate(all_member_levels):
            if all_member_objects[i] in member_objects:
                member_levels.append(level)
    labels = []
    for i, obj in enumerate(member_objects):
        level = member_levels[i]
        options = track_options.levels[level].object_by_name(obj)
        altitudes = np.array(options.detection.altitudes)
        altitudes = np.round(altitudes / 1e3, 1)
        labels.append(f"{altitudes[0]:g} to {altitudes[1]:g} km")
    return labels


[docs] def series( output_directory: str | Path, start_time, end_time, figure_options, dataset_name, animate=True, parallel_figure=False, by_date=True, num_processes=4, ): """Visualize attributes at specified times.""" # Setup plt backend plt.close("all") original_backend = matplotlib.get_backend() matplotlib.use("Agg") # Setup times and masks start_time = np.datetime64(start_time) end_time = np.datetime64(end_time) options = read_options(output_directory) object_name = figure_options.object_name masks_filepath = output_directory / f"masks/{object_name}.zarr" masks = xr.open_dataset(masks_filepath, engine="zarr") times = masks.time.values times = times[(times >= start_time) & (times <= end_time)] figure_function = figure_options.method.function # Initialize the paths to save xesmf regridder weights dataset_options = options["data"].dataset_by_name(dataset_name) if dataset_options.reuse_regridder: if dataset_options.weights_filepath is None: filepath = output_directory / "records/regridder_weights" filepath = filepath / f"{dataset_options.name}.nc" dataset_options.weights_filepath = filepath # Start with first time args = [times[0], masks, output_directory, figure_options.model_dump()] args += [options, dataset_name] figure_function(*args) if len(times) == 1: # Switch back to original backend plt.close("all"), matplotlib.use(original_backend) return if parallel_figure: kwargs = {"initializer": initialize_process, "processes": num_processes} with logging_listener(), mp.get_context("spawn").Pool(**kwargs) as pool: results = [] for time in times[1:]: sleep(2) # Note need to define a new args for each iteration! Can't simply # change the first element, or we break parallization! Note also it's # bad practice to pass dataframes to mp workers. args = [time, masks, output_directory] args += [figure_options.model_dump()] args += [options, dataset_name] args = tuple(args) results.append(pool.apply_async(figure_function, args)) pool.close() pool.join() check_results(results) else: for time in times[1:]: args[0] = time figure_function(*args) if animate: figure_name = figure_options.name save_directory = output_directory / f"visualize" figure_directory = output_directory / f"visualize/{figure_name}" args = [figure_name, "mcs", output_directory, save_directory] args += [figure_directory, figure_name] visualize.animate_object(*args, by_date=by_date) # Close all figures to clear memory plt.close("all") # Switch back to original backend matplotlib.use(original_backend)
def get_mask_grid_boundary( object_name, time, filepaths_df, masks, dataset_name, options ): """Get the mask and grid for a given time.""" filepath = filepaths_df[dataset_name].loc[time] dataset_options = options["data"].dataset_by_name(dataset_name) object_level = options["track"].object_by_name(object_name).hierarchy_level message = f"Converting {dataset_name}." logger.debug(message) args = [time, filepath, options["track"], options["grid"]] outs = dataset_options.convert_dataset(*args) ds, boundary_coords, simple_boundary_coords = outs del boundary_coords logger.debug(f"Getting grid from dataset at time {time}.") if len(dataset_options.fields) > 1: raise ValueError("Non-unique dataset field.") grid = dataset_options.grid_from_dataset(ds, dataset_options.fields[0], time) del ds logger.debug(f"Rebuilding processed grid for time {time}.") args = [grid, options["track"], object_name, object_level] processed_grid = detect.rebuild_processed_grid(*args) del grid mask = masks.sel(time=time).load() grid_time = processed_grid.time.values mask_time = mask.time.values if grid_time != time or mask_time != time: message = f"Grid or mask time {grid_time} does not match requested time {time}." raise ValueError(message) return mask, processed_grid, simple_boundary_coords def get_object_colors(time, color_angle_df): """Get the object colors for a given time.""" keys = color_angle_df.loc[color_angle_df["time"] == time]["universal_id"].values values = color_angle_df.loc[color_angle_df["time"] == time]["color_angle"].values values = [visualize.mask_colormap(v / (2 * np.pi)) for v in values] return dict(zip(keys, values)) def detected_horizontal( time, masks, output_directory, figure_options_dict, options, dataset_name, ): """Create a horizontal cross section plot.""" logger.info(f"Visualizing attributes at time {time}.") # Rebuild the figure options figure_options = HorizontalAttributeOptions(**figure_options_dict) object_name = figure_options.object_name # Get filepaths dataframe record_filepath = output_directory / f"records/filepaths/{dataset_name}.csv" filepaths_df = read_attribute_csv(record_filepath, columns=[dataset_name]) # Setup colors color_angle_df = get_color_angle_df(object_name, output_directory) grid_options = options["grid"] obj_name = figure_options.object_name args = [obj_name, time, filepaths_df, masks, dataset_name, options] mask, grid, boundary_coords = get_mask_grid_boundary(*args) mask = mask[obj_name + "_mask"] grid = grid[obj_name + "_grid"] object_colors = get_object_colors(time, color_angle_df) time = grid.time.values style = figure_options.style attribute_handlers = figure_options.attribute_handlers args = [grid, mask, grid_options, figure_options, boundary_coords] kwargs = {"object_colors": object_colors} with plt.style.context(visualize.styles[style]), visualize.set_style(style): figure_features = horizontal.detected_mask(*args, **kwargs) fig, subplot_axes, colorbar_axes, legend_axes = figure_features # Create the grouped object figure instance kwargs = {"object_name": object_name, "time": time, "grid": grid, "mask": mask} kwargs.update({"boundary_coordinates": boundary_coords}) kwargs.update({"attribute_handlers": attribute_handlers}) kwargs.update({"figure": fig, "subplot_axes": subplot_axes}) kwargs.update({"colorbar_axes": colorbar_axes, "legend_axes": legend_axes}) core_filepath = output_directory / f"attributes/{obj_name}/core.csv" kwargs["core_filepath"] = str(core_filepath) detected_figure = BaseFigure(**kwargs) # Remove duplicate mask and grid from memory after generating the figure del mask, grid, boundary_coords add_attributes(time, detected_figure) create_legend(detected_figure, grid_options, figure_options) filename = f"{format_time(time)}.png" filepath = output_directory / f"visualize/{figure_options.name}/{filename}" filepath.parent.mkdir(parents=True, exist_ok=True) logger.info(f"Saving {figure_options.name} figure for {time}.") with plt.style.context(visualize.styles[style]), visualize.set_style(style): detected_figure.figure.savefig(filepath, bbox_inches="tight") del detected_figure utils.reduce_color_depth(filepath) plt.clf(), plt.close(), gc.collect()
[docs] def grouped_horizontal( time, masks, output_directory, figure_options_dict, options, dataset_name, ): """Create a horizontal cross section plot.""" logger.info(f"Visualizing attributes at time {time}.") # Rebuild the figure options figure_options = GroupedHorizontalAttributeOptions(**figure_options_dict) # Get filepaths dataframe record_filepath = output_directory / f"records/filepaths/{dataset_name}.csv" filepaths_df = read_attribute_csv(record_filepath, columns=[dataset_name]) obj_name = figure_options.object_name # Setup colors color_angle_df = get_color_angle_df(obj_name, output_directory) grid_options = options["grid"] args = [obj_name, time, filepaths_df, masks, dataset_name, options] mask, grid, boundary_coords = get_mask_grid_boundary(*args) object_colors = get_object_colors(time, color_angle_df) time = grid.time.values style = figure_options.style member_objects = figure_options.member_objects attribute_handlers = figure_options.attribute_handlers args = [grid, mask, grid_options, figure_options, member_objects] args += [boundary_coords] kwargs = {"object_colors": object_colors} with plt.style.context(visualize.styles[style]), visualize.set_style(style): figure_features = horizontal.grouped_mask(*args, **kwargs) fig, subplot_axes, colorbar_axes, legend_axes = figure_features # Set the subplot figure titles to altitudes if specified if figure_options.altitude_titles: # Get altitude labels for the member objects track_options = options["track"] args = [track_options, obj_name, 1, member_objects] altitude_labels = get_altitude_labels(*args) for i, label in enumerate(altitude_labels): subplot_axes[i].set_title(label) # Create the grouped object figure instance kwargs = {"object_name": obj_name, "time": time, "grid": grid, "mask": mask} kwargs.update({"boundary_coordinates": boundary_coords}) kwargs.update({"attribute_handlers": attribute_handlers}) kwargs.update({"member_objects": member_objects}) kwargs.update({"figure": fig, "subplot_axes": subplot_axes}) kwargs.update({"colorbar_axes": colorbar_axes, "legend_axes": legend_axes}) core_filepath = output_directory / f"attributes/{obj_name}/core.csv" kwargs["core_filepath"] = str(core_filepath) base_directory = output_directory / f"attributes/{obj_name}/" filepaths_list = [str(base_directory / f"{obj}/core.csv") for obj in member_objects] kwargs["member_core_filepaths"] = dict(zip(member_objects, filepaths_list)) grouped_figure = GroupedObjectFigure(**kwargs) # Remove duplicate mask and grid from memory after generating the figure del mask, grid, boundary_coords add_attributes(time, grouped_figure) create_legend(grouped_figure, grid_options, figure_options) filename = f"{format_time(time)}.png" filepath = output_directory / f"visualize/{figure_options.name}/{filename}" filepath.parent.mkdir(parents=True, exist_ok=True) logger.info(f"Saving {figure_options.name} figure for {time}.") with plt.style.context(visualize.styles[style]), visualize.set_style(style): grouped_figure.figure.savefig(filepath, bbox_inches="tight") del grouped_figure utils.reduce_color_depth(filepath) plt.clf(), plt.close(), gc.collect()
def add_attribute( ax, object_name, handler, attribute_artists, legend_artists, time, figure ): """Add an attribute to the figure for a given object.""" # Get attribute df attribute_artists[object_name][handler.name] = {} attribute_df = read_attribute_csv(handler.filepath, times=[time]) kwargs = {"times": [time], "columns": handler.quality_variables} quality_df = read_attribute_csv(handler.quality_filepath, **kwargs) if handler.quality_method == "all": quality_df = quality_df.all(axis=1) elif handler.quality_method == "any": quality_df = quality_df.any(axis=1) try: id_type = "universal_id" object_ids = attribute_df.reset_index()[id_type].values except KeyError: id_type = "id" object_ids = attribute_df.reset_index()[id_type].values # Will also need to load in core attributes if hasattr(figure, "member_core_filepaths"): core_filepath = figure.member_core_filepaths[object_name] elif hasattr(figure, "core_filepath"): core_filepath = figure.core_filepath core_df = read_attribute_csv(core_filepath, times=[time]) # Join the core attributes with the attribute df # Prepend column names with handler name if necessary for col in core_df.columns: if col in attribute_df.columns: core_df.rename(columns={col: f"{handler.name}_{col}"}, inplace=True) attribute_df = attribute_df.join(core_df) leg_method = handler.legend_method if leg_method is not None and handler.name not in legend_artists.keys(): # Create the legend artist func = leg_method.function keyword_arguments = leg_method.keyword_arguments legend_artist = func(**keyword_arguments) legend_artists[handler.label] = legend_artist for obj_id in object_ids: # Add the attribute for the given object to the figure object_df = attribute_df.xs(obj_id, level=id_type, drop_level=False) obj_quality_df = quality_df.xs(obj_id, level=id_type, drop_level=False) attributes = handler.attributes func = handler.method.function kwargs = handler.method.keyword_arguments artist = func(ax, attributes, object_df, obj_quality_df, **kwargs) attribute_artists[object_name][handler.name][obj_id] = artist def add_attributes(time, figure): """Add all the requisite attributes to the figure.""" legend_artists = {} attribute_artists = {} for i, obj in enumerate(figure.attribute_handlers.keys()): attribute_artists[obj] = {} for handler in figure.attribute_handlers[obj]: ax = figure.subplot_axes[i] args = [ax, obj, handler, attribute_artists, legend_artists, time] args += [figure] add_attribute(*args) figure.legend_artists = legend_artists figure.attribute_artists = attribute_artists def create_legend(figure, grid_options, figure_options): """Create a legend for the figure.""" legend_options = {"ncol": 3, "loc": "lower center"} scale = visualize.utils.get_extent(grid_options)[1] handles, labels = [], [] handle, handler = horizontal.mask_legend_artist() handles += [handle] labels += ["Object Masks"] handle = horizontal.domain_boundary_legend_artist() handles += [handle] labels += ["Domain Boundary"] handles += list(figure.legend_artists.values()) labels += list(figure.legend_artists.keys()) legend_color = visualize.figure_colors[figure_options.style]["legend"] args = [handles, labels] style = figure_options.style leg_ax = figure.legend_axes[0] with plt.style.context(visualize.styles[style]), visualize.set_style(style): if scale == 1: legend = leg_ax.legend(*args, **mcs_legend_options, handler_map=handler) elif scale == 2: legend_options["loc"] = "lower left" legend_options["bbox_to_anchor"] = (-0.0, -0.425) legend = leg_ax.legend(*args, **mcs_legend_options, handler_map=handler) legend.get_frame().set_alpha(None) legend.get_frame().set_facecolor(legend_color) class BaseFigure(BaseHandler): """Base class for a figure visualizing a field, objects, and object attributes.""" object_name: str = Field(..., description="The name of the object.") time: np.datetime64 = Field(..., description="The time of the figure.") _desc = "A dictionary with a list of attribute handlers for the given object." attribute_handlers: dict[str, list[AttributeHandler]] = Field([], description=_desc) _desc = "The artist used to visualize the domain boundary." boundary_artists: list[Any] = Field([], description=_desc) _desc = "The artists used to visualize the field, e.g. reflectivity." field_artists: list[Any] = Field([], description=_desc) _desc = "The artists used to visualize object masks." mask_artists: list[Any] = Field([], description=_desc) _desc = "The artists used to visualize object attributes." attribute_artists: Dict[str, Any] = Field({}, description=_desc) _desc = "The proxy artists used for creating legends." legend_artists: Dict[str, Any] = Field({}, description=_desc) _desc = "Layout class instance for the figure." layout: utils.BaseLayout = Field(None, description=_desc) _desc = "Options for the figure." options: FigureOptions | None = Field(None, description=_desc) figure: Any = Field(None, description="The Matplotlib figure object.") _desc = "The Matplotlib axes containing subplots." subplot_axes: list[Any] = Field([], description=_desc) _desc = "The Matplotlib axes containing legends." legend_axes: list[Any] = Field([], description=_desc) _desc = "The Matplotlib axes containing colorbars." colorbar_axes: list[Any] = Field([], description=_desc) _desc = "Filepath to csv of core attributes." core_filepath: str | None = Field(None, description=_desc) class GroupedObjectFigure(BaseFigure): """Class for visualizing grouped objects.""" member_objects: list[str] = Field([], description="Member object names.") _desc = "Filepaths to core attributes for each member object." member_core_filepaths: dict[str, str] = Field({}, description=_desc) @model_validator(mode="after") def _check_number_subplots(cls, values): """ Check the number of subplots matches the number of member objects and number of member_core_filepaths keys. """ lengths = [len(values.member_objects), len(values.subplot_axes)] lengths += [len(values.member_core_filepaths)] if len(set(lengths)) != 1: message = "Number of member objects, subplot axes, and member core " message += "filepaths must agree." raise ValueError(message) def velocity_horizontal( ax, attributes, object_df, quality_df, color="tab:red", dt=3600, reverse=False ): """ Add velocity attributes. Assumes the attribtes dataframe has already been subset to the desired time and object, so is effectively a dictionary. """ latitude = object_df["latitude"].values[0] longitude = object_df["longitude"].values[0] u, v = object_df[attributes[0]].values[0], object_df[attributes[1]].values[0] args = [ax, latitude, longitude, u, v, color] kwargs = {"quality": quality_df.values, "dt": dt, "reverse": reverse} return horizontal.cartesian_velocity(*args, **kwargs) def text_horizontal( ax, attributes, object_df, quality_df, formatter=None, labelled_attribute="universal_id", ): """Add object ID attributes.""" latitude = object_df["latitude"].values[0] longitude = object_df["longitude"].values[0] label = object_df.reset_index()[labelled_attribute].values[0] if formatter is not None: label = formatter(label) args = [ax, label, longitude, latitude] if quality_df.values[0]: return horizontal.embossed_text(*args) else: return None def orientation_horizontal(ax, attributes, object_df, quality_df=None): """Add orientation attributes to axes.""" latitude = object_df["orientation_latitude"].values[0] longitude = object_df["orientation_longitude"].values[0] if "major" in attributes: length = object_df["major"].values[0] elif "minor" in attributes: length = object_df["minor"].values[0] else: raise ValueError("No major or minor attribute in object_df.") if "orientation" in attributes: orientation = object_df["orientation"].values[0] else: raise ValueError("No orientation attribute in object_df.") args = [ax, latitude, longitude, length, orientation, quality_df.values] return horizontal.ellipse_axis(*args) def displacement_horizontal( ax, attributes, object_df, quality_df, color="tab:blue", reverse=False ): """Add displacement attributes.""" # Convert displacements from km to metres dx = object_df[attributes[0]].values[0] * 1e3 dy = object_df[attributes[1]].values[0] * 1e3 if reverse: dx, dy = -dx, -dy latitude = object_df["latitude"].values[0] longitude = object_df["longitude"].values[0] args = [ax, latitude, longitude, dx, dy, color, quality_df.values] return horizontal.cartesian_displacement(*args, arrow=True, reverse=reverse) def convert_parents(parents): """Convert a parents string to a list of integers.""" if str(parents) == "nan": return [] parents_list = parents.split(" ") return [int(parent) for parent in parents_list] def get_parent_angles(df, row, color_dict, previous_time): """Get the parent angles for the object in row.""" obj_parents = convert_parents(row["parents"]) parent_angles = [] areas = [] for parent in obj_parents: dict_universal_ids = np.array(color_dict["universal_id"]) times = np.array(color_dict["time"]) cond = (dict_universal_ids == parent) & (times == previous_time) parent_angle = np.array(color_dict["color_angle"])[cond][0] parent_angles.append(parent_angle) parent_universal_id = dict_universal_ids[cond][0] area = df.loc[previous_time, parent_universal_id]["area"] areas.append(area) return parent_angles, areas def new_color_angle(df, row, color_dict, previous_time, angle_list): """Get a new color for the new object in row.""" # Object not yet in color_dict if str(row["parents"]) == "nan": # If object has no parents, get a new color angle as different as possible # from existing color angles angles = color_dict["color_angle"] return new_angle(angles + angle_list) else: # If object has parents, get the average color angle of the parents, # weighting the average by object area args = [df, row, color_dict, previous_time] parent_angles, areas = get_parent_angles(*args) return circular_mean(parent_angles, areas) def update_color_angle(df, row, color_dict, previous_time, universal_id): # If object is already in color_dict, get its color angle dict_universal_ids = np.array(color_dict["universal_id"]) times = np.array(color_dict["time"]) cond = (dict_universal_ids == universal_id) & (times == previous_time) previous_angle = np.array(color_dict["color_angle"])[cond][0] previous_area = df.loc[previous_time, universal_id]["area"] if str(row["parents"]) == "nan": # If the object has no new parents, i.e. no mergers have occured, # retain the same color return previous_angle else: # If the object has new parents, get the average color angle of the # parents and the current object args = [df, row, color_dict, previous_time] parent_angles, areas = get_parent_angles(*args) args = [parent_angles + [previous_angle], areas + [previous_area]] return circular_mean(*args) def get_color_angle_df(object_name, output_parent, filepath=None): """ Get a dictionary containing color angles, i.e. indices, for displaying masks. The color angle is calculated to reflect object splits/merges. """ if filepath is None: filepath = output_parent / f"attributes/{object_name}/core.csv" df = read_attribute_csv(filepath, columns=["parents", "area"]) color_dict = {"time": [], "universal_id": [], "color_angle": []} times = sorted(np.unique(df.reset_index().time)) previous_time = None for i, time in enumerate(times): df_time = df.xs(time, level="time") universal_ids = sorted(np.unique(df_time.reset_index().universal_id)) time_list, universal_id_list, angle_list = [], [], [] if i > 0: previous_time = times[i - 1] for j, universal_id in enumerate(universal_ids): row = df_time.loc[universal_id] if universal_id not in color_dict["universal_id"]: # Object not yet in color_dict angle = new_color_angle(df, row, color_dict, previous_time, angle_list) else: # If object is already in color_dict, get its color angle args = [df, row, color_dict, previous_time, universal_id] angle = update_color_angle(*args) time_list.append(time) universal_id_list.append(universal_id) angle_list.append(angle) color_dict["time"] += time_list color_dict["universal_id"] += universal_id_list color_dict["color_angle"] += angle_list return pd.DataFrame(color_dict)