Source code for thuner.group.group

"""Module for grouping objects into new objects."""

import copy
import numpy as np
import xarray as xr
import networkx as nx
from networkx.algorithms.components.connected import connected_components
from thuner.utils import get_time_interval


__all__ = ["group"]


def group(
    track_input_records,
    tracks,
    level_index,
    obj,
    dataset_options,
    object_options,
    grid_options,
):
    """
    Group member objects into a new object and roll the object's state forward.

    The grouped mask comes from ``get_connected_components``. The grouped grid
    is an ``xr.Dataset`` holding each member object's ``next_grid`` under the
    key ``"<member_obj>_grid"``. Before ``next_mask`` and ``next_grid`` are
    replaced, deep copies of their current values are appended to the object's
    ``masks`` and ``grids`` lists, and the time-interval attributes are advanced
    using ``get_time_interval``.

    ``dataset_options`` and ``grid_options`` are accepted but not read here.
    Nothing is returned; ``tracks`` is modified in place.
    """
    dataset = track_input_records[object_options.dataset].dataset
    grouped = tracks.levels[level_index].objects[obj]

    # Only fill in the gridcell areas if they have not been set already.
    if grouped.gridcell_area is None:
        grouped.gridcell_area = dataset["gridcell_area"]

    grouping = object_options.grouping
    # Store the domain boundaries associated with the constituent masks
    member_grids = {
        f"{name}_grid": tracks.levels[level].objects[name].next_grid
        for name, level in zip(grouping.member_objects, grouping.member_levels)
    }
    grid = xr.Dataset(member_grids)
    mask = get_connected_components(tracks, object_options)

    # Archive a copy of the outgoing mask, then install the new one.
    outgoing_mask = copy.deepcopy(grouped.next_mask)
    grouped.masks.append(outgoing_mask)
    grouped.next_mask = mask

    # Same for the grid.
    outgoing_grid = copy.deepcopy(grouped.next_grid)
    grouped.grids.append(outgoing_grid)
    grouped.next_grid = grid

    # The old "next" interval becomes "previous"; the new one is derived from
    # the new grid and the grid it replaced.
    grouped.previous_time_interval = copy.deepcopy(grouped.next_time_interval)
    grouped.next_time_interval = get_time_interval(grid, outgoing_grid)
def get_connected_components(tracks, object_options):
    """
    Build the grouped-object mask from the member objects' next masks.

    Member masks are relabelled so labels are unique across members. An
    undirected graph then links labels of consecutive members that share at
    least one grid cell. Each connected component that spans every member
    (``component_span``) and passes the area check (``check_areas``) becomes
    one grouped object.

    Parameters
    ----------
    tracks
        Object exposing ``levels[level].objects[name].next_mask``; the
        ``gridcell_area`` attribute of the same objects is read by
        ``check_areas``.
    object_options
        Options whose ``grouping`` attribute provides ``member_objects``,
        ``member_levels`` and ``member_min_areas``.

    Returns
    -------
    xr.Dataset
        One integer mask per member object, initialised to zero and keyed by
        the ``name`` of that member's ``next_mask``. Cells belonging to an
        accepted component carry that component's new label, numbered from 1.
    """
    member_objects = object_options.grouping.member_objects
    member_levels = object_options.grouping.member_levels

    # Relabel objects in mask so that object numbers are unique: every nonzero
    # label of a member is shifted by the running total of the previous maxima.
    current_max = 0
    masks = []
    for obj, level in zip(member_objects, member_levels):
        mask = tracks.levels[level].objects[obj].next_mask
        # ``where`` keeps the zeros (background) and replaces everything else.
        shifted = mask.where(mask == 0, mask + current_max)
        masks.append(shifted.values)
        current_max += np.max(mask.values)

    # Create graph for objects that overlap at different vertical levels.
    # Labels run from 1 to current_max inclusive, so the upper bound of the
    # range must be current_max + 1 or the highest label never becomes a node.
    overlap_graph = nx.Graph()
    overlap_graph.add_nodes_from(range(1, current_max + 1))

    # Create edges between the objects that overlap vertically, assuming member
    # objects listed in increasing altitude.
    for lower, upper in zip(masks[:-1], masks[1:]):
        # Determine the objects in the lower frame.
        lower_labels = set(np.unique(lower)).difference({0})
        for j in lower_labels:
            # Determine the objects in the upper frame that overlap with object j.
            overlap = np.logical_and(lower == j, upper > 0)
            overlap_objects = set(upper[overlap].flatten())
            # Link object j to every object it overlaps in the upper frame.
            overlap_graph.add_edges_from((j, k) for k in overlap_objects)

    # Initialize one zero-filled mask per member to hold the grouped labels.
    mask_da_list = [
        xr.full_like(tracks.levels[level].objects[obj].next_mask, 0, dtype=int)
        for obj, level in zip(member_objects, member_levels)
    ]

    # Create new objects based on connected components. A separate counter is
    # used for the new labels because some components are rejected.
    new_obj_counter = 0
    for component in connected_components(overlap_graph):
        labels = list(component)
        # Require that components span all member objects
        if not component_span(masks, labels):
            continue
        # Require total areas of member objects are above thresholds after grouping
        if not check_areas(masks, tracks, object_options, labels):
            continue
        # Create new grouped objects
        new_obj_counter += 1
        for j, member_mask in enumerate(masks):
            # Keep cells outside the component; stamp the new label on the rest.
            mask_da_list[j] = mask_da_list[j].where(
                ~np.isin(member_mask, labels), new_obj_counter
            )

    grouped_mask = xr.Dataset({da.name: da for da in mask_da_list})
    return grouped_mask


def check_areas(masks, tracks, object_options, objs):
    """
    Check if the areas of the member objects after grouping are above the threshold.

    For each member ``j``, the member's ``gridcell_area`` is summed over the
    cells of ``masks[j]`` whose label is in ``objs`` and compared with
    ``object_options.grouping.member_min_areas[j]``.

    Returns
    -------
    bool
        ``False`` as soon as one member's summed area is below its minimum,
        ``True`` otherwise.
    """
    member_objects = object_options.grouping.member_objects
    member_levels = object_options.grouping.member_levels
    member_min_areas = object_options.grouping.member_min_areas
    for j, member_mask in enumerate(masks):
        mem_obj = tracks.levels[member_levels[j]].objects[member_objects[j]]
        in_component = np.isin(member_mask, objs)
        if mem_obj.gridcell_area.where(in_component).sum() < member_min_areas[j]:
            return False
    return True


def component_span(masks, new_objs):
    """
    Check if connected component spans all member objects.

    Parameters
    ----------
    masks
        Sequence of label arrays, one per member object.
    new_objs
        Labels making up the connected component.

    Returns
    -------
    bool
        ``True`` when every array in ``masks`` contains at least one of the
        labels in ``new_objs`` (trivially ``True`` when ``masks`` is empty).
    """
    return all(bool(np.isin(mask, new_objs).any()) for mask in masks)