"""Module for detecting objects in a grid."""
import copy
import numbers
from scipy import ndimage
import numpy as np
import xarray as xr
import thuner.detect.preprocess as preprocess
from thuner.log import setup_logger
from thuner.detect.steiner import steiner_scheme
from thuner.utils import get_time_interval
logger = setup_logger(__name__)
__all__ = ["threshold", "steiner"]
[docs]
def threshold(grid, object_options):
"""Detect objects in the given grid using thresholding."""
if object_options.detection.method != "threshold":
raise ValueError("Detection method not set to threshold.")
if object_options.detection.threshold_type == "minima":
binary_grid = grid >= object_options.detection.threshold
elif object_options.detection.threshold_type == "maxima":
binary_grid = grid <= object_options.detection.threshold
return binary_grid
[docs]
def steiner(grid, object_options):
"""Detect objects in the given grid using the Steiner et al. (1995) method."""
if object_options.detection.method != "steiner":
raise ValueError("Detection method not set to steiner.")
if "latitude" in grid.dims:
coordinates = "geographic"
x = grid.longitude.values
y = grid.latitude.values
elif "x" in grid.dims:
coordinates = "cartesian"
x = grid.x.values
y = grid.y.values
else:
raise ValueError("Could not infer grid coordinates.")
if "altitude" in grid.dims:
raise ValueError(
"Steiner et al. (1995) scheme only works with 2D grids. "
"Apply a flattener first."
)
if "altitude" in grid.coords and grid.altitude != 3e3:
logger.warning(
"Steiner et al. (1995) scheme designed to work on 3 km altitude grids. "
f"grid altitude {grid.altitude.values[0]/1e3} km."
)
binary_grid = xr.full_like(grid, 0)
binary_grid.name = "binary_grid"
if x.ndim == 1 and y.ndim == 1:
X, Y = np.meshgrid(x, y)
elif x.ndim == 2 and y.ndim == 2:
X, Y = x, y
else:
raise ValueError("x and y must both be one or two dimensional.")
steiner_class = steiner_scheme(grid.values, X, Y, coordinates=coordinates)
steiner_class = steiner_class.astype(int)
steiner_class[steiner_class != 2] = 0
binary_grid.data = steiner_class
return binary_grid
detecter_dispatcher = {
"threshold": threshold,
"steiner": steiner,
}
def rebuild_processed_grid(grid_data, track_options, obj, level):
"""Rebuild the processed grid for the given object and level."""
grid_dict = {}
object_options = track_options.levels[level].object_by_name(obj)
if "detection" in object_options.__class__.model_fields:
grid_dict[f"{obj}_grid"] = process_grid(grid_data, object_options)
elif "grouping" in object_options.__class__.model_fields:
member_objects = object_options.grouping.member_objects
member_levels = object_options.grouping.member_levels
for member_obj, member_level in zip(member_objects, member_levels):
args = [grid_data, track_options, member_obj, member_level]
grid_dict.update(rebuild_processed_grid(*args))
processed_grid = xr.Dataset(grid_dict)
# if "detection" in object_options.__class__.model_fields:
# processed_grid = processed_grid[f"{obj}_grid"]
return processed_grid
def process_grid(grid, object_options):
"""Apply a flatten method to produce a 2D grid."""
retrieval = object_options.detection.flatten_method
if retrieval is not None:
func = retrieval.function
kwargs = retrieval.keyword_arguments
processed_grid = func(grid, object_options, **kwargs)
else:
processed_grid = grid
return processed_grid
def detect(
track_input_records,
tracks,
level_index,
obj,
dataset_options,
object_options,
grid_options,
):
"""Detect objects in the given grid."""
object_tracks = tracks.levels[level_index].objects[obj]
previous_grid = copy.deepcopy(object_tracks.next_grid)
object_tracks.grids.append(previous_grid)
input_record = track_input_records[object_options.dataset]
grid = input_record.next_grid
object_tracks.previous_time_interval = copy.deepcopy(
object_tracks.next_time_interval
)
object_tracks.next_time_interval = get_time_interval(grid, previous_grid)
dataset = input_record.dataset
if object_tracks.gridcell_area is None:
object_tracks.gridcell_area = dataset["gridcell_area"]
processed_grid = process_grid(grid, object_options)
object_tracks.next_grid = processed_grid
detecter = detecter_dispatcher.get(object_options.detection.method)
if detecter is None:
raise ValueError("Invalid detection method.")
binary_grid = detecter(processed_grid, object_options)
mask = xr.full_like(binary_grid, 0, dtype=np.uint32)
mask.data = ndimage.label(binary_grid)[0]
mask.name = f"{object_options.name}_mask"
if object_options.detection.min_area is not None:
args = [mask, object_options.detection.min_area, dataset["gridcell_area"]]
mask = clear_small_area_objects(*args)
next_mask = copy.deepcopy(object_tracks.next_mask)
object_tracks.masks.append(next_mask)
object_tracks.next_mask = mask
def clear_small_area_objects(mask, min_area, gridcell_area):
"""Takes in labelled image and clears objects less than min_size."""
for obj in range(1, int(mask.max()) + 1):
if isinstance(gridcell_area, xr.DataArray) and len(gridcell_area.shape) == 2:
obj_area = gridcell_area.data[mask == obj].sum()
elif isinstance(gridcell_area, numbers.Real) and gridcell_area > 0:
obj_area = (mask == obj).sum() * gridcell_area
else:
raise ValueError("gridcell_area must be a positive number or a 2D array.")
if obj_area < min_area:
mask.data[mask == obj] = 0
# Relabel the mask after clearing the small objects
mask.data = ndimage.label(mask)[0]
return mask