# Source code for thuner.data.era5

"""Process ERA5 data."""

import calendar
import signal
from pathlib import Path
import tempfile
import numpy as np
import pandas as pd
import xarray as xr
import cdsapi
from typing import Any, Literal
from pydantic import Field, model_validator
import thuner.log as log
from thuner.utils import get_hour_interval, BaseDatasetOptions
import thuner.data._utils as _utils
from thuner.config import get_outputs_directory

# Names exported by ``from thuner.data.era5 import *``.
__all__ = [
    "ERA5Options",
    "get_era5_filepaths",
]


# Module-level logger, created through thuner.log.setup_logger.
logger = log.setup_logger(__name__)

_summary = {
    "mode": "Mode of the data, e.g. reannalysis.",
    "pressure_levels": "Pressure levels; required if data_format is pressure-levels.",
    "storage": "Storage format of the data, e.g. monthly.",
    "start_buffer": "Minutes before interval start time to include.",
}


def _format_pressure_level(level):
    """
    Return a pressure level as a string, without a trailing ".0" for whole numbers.

    Strings are returned unchanged. Numbers such as 500 or 500.0 (integer input is
    validated into ``list[float]`` by ERA5Options) become "500" rather than
    "500.0", matching the strings in ``era5_pressure_levels``.
    """
    if isinstance(level, str):
        return level
    level = float(level)
    return str(int(level)) if level.is_integer() else str(level)


class ERA5Options(BaseDatasetOptions):
    """Options for ERA5 datasets."""

    def model_post_init(self, __context):
        """Use model_post_init to change default inherited values."""
        new_defaults = {"use": "tag", "parent_remote": "/g/data/rt52"}
        self._change_defaults(**new_defaults)
        message = f"Setting default era5 {self.data_format} options name and fields."
        logger.debug(message)
        if self.data_format == "pressure-levels":
            self._change_defaults(name="era5_pl", fields=["u", "v", "z", "r", "t"])
        elif self.data_format == "single-levels":
            self._change_defaults(name="era5_sl", fields=["cape", "cin"])

    # Define additional fields for era5
    _desc = "Latitude range if accessing a directory of subsetted era5 data."
    latitude_range: list[float] = Field([-90, 90], description=_desc)
    _desc = "Longitude range if accessing a directory of subsetted era5 data."
    longitude_range: list[float] = Field([-180, 180], description=_desc)
    mode: Literal["reanalysis"] = Field("reanalysis", description=_summary["mode"])
    _FormatChoices = Literal["pressure-levels", "single-levels"]
    _desc = "Data format, e.g. pressure-levels."
    data_format: _FormatChoices = Field("pressure-levels", description=_desc)
    _desc = "Pressure levels; required if data_format is pressure-levels."
    pressure_levels: list[str] | list[float] | None = Field(None, description=_desc)
    storage: str = Field("monthly", description=_summary["storage"])

    def get_filepaths(self):
        """Override the get_filepaths method with the era5 version."""
        return get_era5_filepaths(self)

    def update_input_record(self, time, input_record, track_options, grid_options):
        """Update the era5 dataset."""
        update_era5_input_record(time, input_record, track_options, self, grid_options)

    @model_validator(mode="after")
    def _check_ranges(cls, values):
        """Validate the latitude/longitude ranges as in-bounds [min, max] pairs."""
        # The path and area helpers unpack each range as a two-element [min, max]
        # pair; reject other lengths with a ValueError rather than an IndexError.
        for name in ("latitude_range", "longitude_range"):
            if len(getattr(values, name)) != 2:
                raise ValueError(f"{name} must contain exactly two values [min, max].")
        if values.latitude_range[0] < -90 or values.latitude_range[1] > 90:
            raise ValueError("Latitude range must be between -90 and 90.")
        # A reversed latitude range would yield a north edge below the south edge.
        if values.latitude_range[0] > values.latitude_range[1]:
            raise ValueError("Latitude range must be ordered as [min, max].")
        if values.longitude_range[0] < -180 or values.longitude_range[1] > 180:
            raise ValueError("Longitude range must be between -180 and 180.")
        return values

    @model_validator(mode="after")
    def _check_defaults(cls, values):
        """Fill in default pressure levels and normalise them to strings."""
        if values.data_format == "pressure-levels":
            if values.pressure_levels is None:
                values.pressure_levels = era5_pressure_levels
                logger.debug("Assigning default era5 pressure levels.")
            # Format numbers without a trailing ".0" (500.0 -> "500").
            values.pressure_levels = [
                _format_pressure_level(level) for level in values.pressure_levels
            ]
        return values

    @model_validator(mode="after")
    def _check_times(cls, values):
        """Reject start times earlier than 1940-03-01T00:00:00."""
        start_time = np.datetime64("1940-03-01T00:00:00")
        if np.datetime64(values.start) < start_time:
            raise ValueError(f"start must be {str(start_time)} or later.")
        return values

    @model_validator(mode="after")
    def _check_filepaths(cls, values):
        """Generate filepaths from the options if none were provided."""
        if values.filepaths is None:
            logger.info("Generating era5 filepaths.")
            values.filepaths = get_era5_filepaths(values)
        if values.filepaths is None:
            raise ValueError("filepaths not provided or badly formed.")
        return values
# The 37 ERA5 pressure levels, ordered from 1000 down to 1 and stored as strings.
# ERA5Options assigns this list as its default pressure_levels.
era5_pressure_levels = [
    str(level)
    for level in (
        1000, 975, 950, 925, 900, 875, 850, 825, 800,
        775, 750, 700, 650, 600, 550, 500, 450, 400,
        350, 300, 250, 225, 200, 175, 150, 125, 100,
        70, 50, 30, 20, 10, 7, 5, 3, 2, 1,
    )
]
def get_era5_filepaths(dataset_options: ERA5Options, start=None, end=None):
    """
    Generate local era5 filepaths from the dataset options.

    Parameters
    ----------
    dataset_options : ERA5Options
        The era5 dataset options. ``fields``, ``data_format`` and ``storage`` are
        read here, plus whatever ``get_base_path`` reads.
    start, end : optional
        Start and end of the period to cover; anything ``pd.Timestamp`` accepts.
        Default to ``dataset_options.start`` and ``dataset_options.end``.

    Returns
    -------
    filepaths : dict
        Mapping from each field name to a sorted list of local filepaths, one per
        file time returned by ``get_file_datetimes``.
    """
    base_path = get_base_path(dataset_options, local=True)
    start = pd.Timestamp(dataset_options.start if start is None else start)
    end = pd.Timestamp(dataset_options.end if end is None else end)

    short_data_format = {"pressure-levels": "pl", "single-levels": "sfc"}
    short_format = short_data_format[dataset_options.data_format]
    # Times (days or months, depending on storage) identifying the individual files
    times = [pd.Timestamp(t) for t in get_file_datetimes(dataset_options, start, end)]

    def build_path(field, time):
        """Build the local filepath of one field's file for one file time."""
        daterange_str = format_daterange(dataset_options, time)
        return (
            f"{base_path}/{field}/{time.year}/{field}_era5_oper_"
            f"{short_format}_{daterange_str}.nc"
        )

    # Individual fields are stored in separate files
    return {
        field: sorted(build_path(field, time) for time in times)
        for field in dataset_options.fields
    }
def format_daterange(options, time):
    """
    Format the date range string used in ERA5 file names on NCI Gadi,
    https://dx.doi.org/10.25914/5f48874388857.

    Parameters
    ----------
    options : ERA5Options
        Dataset options; only ``options.storage`` ("daily" or "monthly") is read.
    time : np.datetime64, pd.Timestamp or str
        The time to format.

    Returns
    -------
    date_range_str : str
        ``YYYYMMDD`` for daily storage, or ``YYYYMM01-YYYYMMDD`` spanning the whole
        month for monthly storage.

    Raises
    ------
    ValueError
        If ``options.storage`` is neither "daily" nor "monthly".
    """
    time = pd.Timestamp(time)
    if options.storage == "daily":
        return f"{time.year:04}{time.month:02}{time.day:02}"
    if options.storage == "monthly":
        last_day = calendar.monthrange(time.year, time.month)[1]
        return (
            f"{time.year:04}{time.month:02}01-{time.year:04}{time.month:02}{last_day}"
        )
    raise ValueError("options.storage must be either 'daily' or 'monthly'.")


def get_base_path(options: Any, local: bool = True):
    """
    Get the base directory for the ERA5 data.

    Global ranges ([-90, 90] latitude, [-180, 180] longitude) map to
    ``{parent}/era5/{data_format}/{mode}``. Otherwise the directory group encodes
    the storage type and, unless the rounded area is global, the area string.

    Parameters
    ----------
    options : ERA5Options
        Dataset options.
    local : bool
        Use ``options.parent_local`` if True, otherwise ``options.parent_remote``.
    """
    parent = options.parent_local if local else options.parent_remote
    latitude_range = options.latitude_range
    longitude_range = options.longitude_range
    if latitude_range == [-90, 90] and longitude_range == [-180, 180]:
        return f"{parent}/era5/{options.data_format}/{options.mode}"
    area_str = get_area_string(get_area(options))
    if area_str is None:
        group = f"era5_{options.storage}"
    else:
        group = f"era5_{options.storage}_{area_str}"
    return f"{parent}/{group}/era5/{options.data_format}/{options.mode}"


def get_file_datetimes(options, start, end):
    """
    Get the datetimes corresponding to the filepaths.

    Parameters
    ----------
    options : ERA5Options
        Dataset options; only ``options.storage`` is read.
    start, end : pd.Timestamp
        Period to cover (anything with ``year``/``month``/``day`` attributes).

    Returns
    -------
    np.ndarray
        Daily (``datetime64[D]``) or monthly (``datetime64[M]``) times from the
        day/month containing ``start`` to the one containing ``end`` inclusive.
    """
    if options.storage == "daily":
        # Note we typically store data locally in daily files
        range_start = np.datetime64(f"{start.year:04}-{start.month:02}-{start.day:02}")
        range_end = np.datetime64(f"{end.year:04}-{end.month:02}-{end.day:02}")
        time_step = np.timedelta64(1, "D")
    elif options.storage == "monthly":
        # On GADI era5 data is stored in monthly files
        range_start = np.datetime64(f"{start.year:04}-{start.month:02}")
        range_end = np.datetime64(f"{end.year:04}-{end.month:02}")
        time_step = np.timedelta64(1, "M")
    else:
        raise ValueError("options.storage must be either 'daily' or 'monthly'.")
    # Add one step so range_end itself is included
    times = np.arange(range_start, range_end + time_step, time_step)
    return times


def _get_request_days(time, storage):
    """Return the zero-padded day strings covered by one file's CDS request."""
    if storage == "daily":
        return [f"{time.day:02}"]
    if storage == "monthly":
        last_day = calendar.monthrange(time.year, time.month)[1]
        return [f"{day:02}" for day in range(1, last_day + 1)]
    raise ValueError("options.storage must be either 'daily' or 'monthly'.")


def generate_cdsapi_requests(options):
    """
    Generate the CDS API requests needed to download ERA5 data.

    One request is generated per field and per file time (day or month, depending
    on ``options.storage``). Together they cover ``options.start`` to one hour
    past ``options.end``.

    Parameters
    ----------
    options : ERA5Options
        The era5 dataset options.

    Returns
    -------
    cds_name : str
        The name argument for the cdsapi retrieval.
    requests : dict
        Mapping from field name to a list of cdsapi request dictionaries.
    local_paths : dict
        Mapping from field name to the local file path for each request.
    """
    # First get the base_path for where to store the files locally
    base_path = get_base_path(options, local=True)
    short_data_format = {"pressure-levels": "pl", "single-levels": "sfc"}
    short_format = short_data_format[options.data_format]
    cds_name = f"reanalysis-era5-{options.data_format}"

    # Options are read by attribute, as everywhere else in this module
    start = pd.Timestamp(options.start)
    # Add an hour to the end time to facilitate temporal interpolation
    end = pd.Timestamp(options.end) + pd.Timedelta(hours=1)
    # Get the times corresponding to the filepaths
    times = get_file_datetimes(options, start, end)

    # Request the same outward-rounded area that get_base_path encodes in the
    # directory name, so the downloaded files cover the area their path claims.
    # get_area returns None when that area is the whole globe.
    area = get_area(options)
    if area is None:
        area = [90, -180, -90, 180]

    requests = {field: [] for field in options.fields}
    local_paths = {field: [] for field in options.fields}
    for field in options.fields:
        for time in times:
            time = pd.Timestamp(time)
            request = {
                "product_type": [options.mode],
                "data_format": "netcdf",
                "download_format": "unarchived",
                "variable": [field],
                "year": [f"{time.year:04}"],
                "month": [f"{time.month:02}"],
                "day": _get_request_days(time, options.storage),
                "time": [f"{hour:02}" for hour in range(24)],
                "area": list(area),
            }
            # Only pressure-level datasets take a pressure_level argument
            if options.data_format == "pressure-levels":
                request["pressure_level"] = options.pressure_levels
            daterange_str = format_daterange(options, time)
            local_path = f"{base_path}/{field}/{time.year}/{field}_era5_oper_"
            local_path += f"{short_format}_{daterange_str}.nc"
            requests[field].append(request)
            local_paths[field].append(local_path)
    return cds_name, requests, local_paths


def get_area(options):
    """
    Get the area ``[max_lat, min_lon, min_lat, max_lon]`` for a CDS API request.

    The ranges are rounded outward to whole degrees. A missing range defaults to
    the full globe, with a warning.

    Returns
    -------
    list of int or None
        The rounded area, or None if it covers the whole globe.
    """
    if options.longitude_range is None:
        min_lon, max_lon = -180, 180
        logger.warning("No longitude range provided. ERA5 files cover all longitudes.")
    else:
        min_lon, max_lon = options.longitude_range
    if options.latitude_range is None:
        min_lat, max_lat = -90, 90
        logger.warning("No latitude range provided. ERA5 files cover all latitudes.")
    else:
        min_lat, max_lat = options.latitude_range
    # Round outward so the area contains the whole requested range
    max_lat, max_lon = [int(np.ceil(coord)) for coord in (max_lat, max_lon)]
    min_lat, min_lon = [int(np.floor(coord)) for coord in (min_lat, min_lon)]
    if min_lon == -180 and max_lon == 180 and min_lat == -90 and max_lat == 90:
        return None
    return [max_lat, min_lon, min_lat, max_lon]


def get_area_string(area):
    """
    Get the area string used in subset directory names, e.g. "10S_130E_13S_133E".

    Returns None if ``area`` is None.
    """
    if area is None:
        return None

    # Convert a signed latitude or longitude to a string, e.g. 150E
    def format_lat(lat):
        return "0" if lat == 0 else f"{int(abs(lat))}{'N' if lat > 0 else 'S'}"

    def format_lon(lon):
        return "0" if lon == 0 else f"{int(abs(lon))}{'E' if lon > 0 else 'W'}"

    return (
        f"{format_lat(area[0])}_{format_lon(area[1])}_"
        f"{format_lat(area[2])}_{format_lon(area[3])}"
    )


def issue_cdsapi_requests(
    cds_name, requests, local_paths, enforce_timeout=False, timeout=5
):
    """
    Issue cdsapi requests, saving each download to its local path.

    Note the wait client functionality doesn't appear to work yet.

    Parameters
    ----------
    cds_name : str
        The cdsapi dataset name.
    requests, local_paths : dict
        Per-field lists of requests and matching local paths, as returned by
        ``generate_cdsapi_requests``.
    enforce_timeout : bool
        If True, abort each request after ``timeout`` seconds using SIGALRM, which
        is only available on Unix and only in the main thread. A timed-out request
        is logged and any partially written file is removed.
    timeout : int
        Seconds allowed per request when ``enforce_timeout`` is True.
    """

    def download_data(cds_name, request, local_path):
        c = cdsapi.Client(sleep_max=5, retry_max=1)
        return c.retrieve(cds_name, request, local_path)

    def signal_handler(signum, frame):
        raise TimeoutError("Request timed out.")

    def handle_request(cds_name, request, local_path):
        path = Path(local_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        if not enforce_timeout:
            download_data(cds_name, request, local_path)
            return
        previous_handler = signal.signal(signal.SIGALRM, signal_handler)
        signal.alarm(timeout)
        try:
            download_data(cds_name, request, local_path)
        except TimeoutError:
            message = f"Request for {path.name} timed out after {timeout} seconds."
            logger.warning(message)
            # Remove any partial download so it is not mistaken for a full file
            path.unlink(missing_ok=True)
        finally:
            signal.alarm(0)
            # Restore the handler in place before this request. None means it was
            # not installed from Python and cannot be re-registered.
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)

    for field, field_requests in requests.items():
        for request, local_path in zip(field_requests, local_paths[field]):
            handle_request(cds_name, request, local_path)


def convert_era5(ds):
    """
    Convert ERA5 data to thuner naming conventions.

    Renames the ``level`` coordinate to ``pressure`` and ``time_var`` to ``time``,
    and the variables ``r``, ``t`` and ``z`` to ``relative_humidity``,
    ``temperature`` and ``geopotential``, when present.
    """
    if "level" in ds.coords:
        ds = ds.rename({"level": "pressure"})
    if "time_var" in ds.coords:
        ds = ds.rename({"time_var": "time"})
        logger.debug("Renamed time_var to time in era5 dataset.")
    if "r" in ds.data_vars:
        ds = ds.rename({"r": "relative_humidity"})
    if "t" in ds.data_vars:
        ds = ds.rename({"t": "temperature"})
    if "z" in ds.data_vars:
        ds = ds.rename({"z": "geopotential"})
    return ds


def update_era5_input_record(
    time, input_record, track_options, dataset_options, grid_options
):
    """
    Update ``input_record.dataset`` with ERA5 data around ``time``.

    If any required file is missing and ``dataset_options.attempt_download`` is
    set, requests covering the whole options period are issued first. Each file
    is then subset (with ``_utils.call_ncks``) to the buffered interval and the
    grid's bounding box plus 0.25 degrees. The subsets are merged, converted with
    ``convert_era5`` and loaded into memory.
    """
    _utils.log_dataset_update(logger, dataset_options.name, time)
    start, end = get_hour_interval(
        time,
        start_buffer=dataset_options.start_buffer,
        end_buffer=dataset_options.end_buffer,
    )
    filepaths = get_era5_filepaths(dataset_options, start, end)
    all_files_exist = all(
        Path(filepath).exists() for paths in filepaths.values() for filepath in paths
    )
    if not all_files_exist and dataset_options.attempt_download:
        logger.warning("One or more filepaths do not exist; attempting download.")
        cds_name, requests, local_paths = generate_cdsapi_requests(dataset_options)
        issue_cdsapi_requests(cds_name, requests, local_paths)

    lat = np.array(grid_options.latitude)
    lon = np.array(grid_options.longitude)
    # Expand the lat and lon ranges to include a buffer to ensure required gridpoints are included
    lat_range = (lat.min() - 0.25, lat.max() + 0.25)
    lon_range = (lon.min() - 0.25, lon.max() + 0.25)

    # Assume user has write privileges in the outputs directory
    logger.info(f"Subsetting {dataset_options.name} data.")
    with tempfile.TemporaryDirectory(dir=str(get_outputs_directory())) as tmp:
        for field in dataset_options.fields:
            for filepath in filepaths[field]:
                output_filename = Path(filepath).name
                logger.debug("Subsetting %s", output_filename)
                args = [filepath, f"{tmp}/{output_filename}.nc", start, end]
                args += [lat_range, lon_range]
                _utils.call_ncks(*args)
        logger.debug("Merging files.")
        # Load into memory, then close the file handles before the temporary
        # directory is removed.
        with xr.open_mfdataset(f"{tmp}/*.nc") as merged:
            logger.debug("Converting")
            input_record.dataset = convert_era5(merged).load()