Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions ogc/edr/edr_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def get_configuration(base_url: str, layers: List[pogc.Layer]) -> Dict[str, Any]

# Add the data resources and provider information
resources = configuration.get("resources", {})
configuration["resources"] = resources | EdrConfig._resources_definition(layers)
configuration["resources"] = resources | EdrConfig._resources_definition(base_url, layers)

return configuration

@staticmethod
def _resources_definition(layers: List[pogc.Layer]) -> Dict[str, Any]:
def _resources_definition(base_url: str, layers: List[pogc.Layer]) -> Dict[str, Any]:
"""Define resource related data for the configuration.

The resources dictionary holds the information needed to generate the collections.
Expand All @@ -66,6 +66,8 @@ def _resources_definition(layers: List[pogc.Layer]) -> Dict[str, Any]:

Parameters
----------
base_url : str
The base URL used as an identifier for the given layers.
layers : List[pogc.Layer]
The layers which define the data sources for the EDR server.

Expand Down Expand Up @@ -98,7 +100,7 @@ def _resources_definition(layers: List[pogc.Layer]) -> Dict[str, Any]:
"default": True,
"name": "ogc.edr.edr_provider.EdrProvider",
"data": group_name,
"layers": group_layers,
"base_url": base_url,
"crs": [
"https://www.opengis.net/def/crs/OGC/1.3/CRS84",
"https://www.opengis.net/def/crs/EPSG/0/4326",
Expand Down
111 changes: 83 additions & 28 deletions ogc/edr/edr_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import zipfile
import traitlets as tl
from datetime import datetime
from collections import defaultdict
from typing import List, Dict, Tuple, Any
from shapely.geometry.base import BaseGeometry
from pygeoapi.provider.base import ProviderConnectionError, ProviderInvalidQueryError
Expand All @@ -17,6 +18,37 @@
class EdrProvider(BaseEDRProvider):
"""Custom provider to be used with layer data sources."""

_layers_dict = defaultdict(list)

@classmethod
def set_layers(cls, base_url: str, layers: List[pogc.Layer]):
    """Register the layers that the provider serves for a base URL.

    Any layers previously stored under the same base URL are replaced.

    Parameters
    ----------
    base_url : str
        The base URL under which the layers are made available.
    layers : List[pogc.Layer]
        The layers to associate with the base URL.
    """
    # The registry lives on the class, so it is shared by every provider instance.
    registry = cls._layers_dict
    registry[base_url] = layers

@classmethod
def get_layers(cls, base_url: str) -> List[pogc.Layer]:
    """Look up the layers registered for a base URL.

    Parameters
    ----------
    base_url : str
        The base URL whose layers are requested.

    Returns
    -------
    List[pogc.Layer]
        The layers stored for the base URL, or a new empty list when the
        base URL has not been registered.
    """
    registry = cls._layers_dict
    # A membership test is used so that a lookup of an unknown base URL
    # does not insert an entry into the registry.
    if base_url in registry:
        return registry[base_url]
    return []

def __init__(self, provider_def: Dict[str, Any]):
"""Construct the provider using the provider definition.

Expand All @@ -30,7 +62,7 @@ def __init__(self, provider_def: Dict[str, Any]):
ProviderConnectionError
Raised if the specified collection is not found within any layers.
ProviderConnectionError
Raised if the provider does not specify any data sources.
Raised if the provider does not specify any base URL.
"""
super().__init__(provider_def)
collection_id = provider_def.get("data", None)
Expand All @@ -39,9 +71,9 @@ def __init__(self, provider_def: Dict[str, Any]):

self.collection_id = str(collection_id)

self.layers = provider_def.get("layers", [])
if len(self.layers) == 0:
raise ProviderConnectionError("Valid data sources not found.")
self.base_url = provider_def.get("base_url", None)
if self.base_url is None:
raise ProviderConnectionError("Valid URL identifier not found for the data.")

@property
def parameters(self) -> Dict[str, pogc.Layer]:
Expand All @@ -54,7 +86,9 @@ def parameters(self) -> Dict[str, pogc.Layer]:
Dict[str, pogc.Layer]
The parameters as a dictionary of layer identifiers and layer objects.
"""
return {layer.identifier: layer for layer in self.layers if layer.group == self.collection_id}
return {
layer.identifier: layer for layer in self.get_layers(self.base_url) if layer.group == self.collection_id
}

def handle_query(self, requested_coordinates: podpac.Coordinates, **kwargs):
"""Handle the requests to the EDR server at the specified requested coordinates.
Expand Down Expand Up @@ -136,21 +170,30 @@ def handle_query(self, requested_coordinates: podpac.Coordinates, **kwargs):
)

self.check_query_condition(
requested_native_coordinates.size > settings.MAX_GRID_COORDS_REQUEST_SIZE,
bool(requested_native_coordinates.size > settings.MAX_GRID_COORDS_REQUEST_SIZE),
"Grid coordinates x_size * y_size must be less than %d" % settings.MAX_GRID_COORDS_REQUEST_SIZE,
)

dataset = {}
for requested_parameter, layer in parameters_filtered.items():
units_data_array = layer.node.eval(requested_native_coordinates)
# Recombine stacked temporal dimensions if necessary.
# The temporal output should always be stacked, based on stacked input.
if "time_forecastOffsetHr" in units_data_array.dims:
forecast_offsets = units_data_array.forecastOffsetHr.data.copy()
time_data = units_data_array.time.data.copy()
units_data_array = units_data_array.drop_vars({"time", "time_forecastOffsetHr", "forecastOffsetHr"})
units_data_array = units_data_array.rename(time_forecastOffsetHr="time")
units_data_array = units_data_array.assign_coords(time=time_data + forecast_offsets)
dataset[requested_parameter] = units_data_array

self.check_query_condition(len(dataset) == 0, "No matching parameters found.")

# Return a coverage json if specified, else return Base64 encoded native response
if output_format == "json" or output_format == "coveragejson":
crs = self.interpret_crs(requested_native_coordinates.crs if requested_native_coordinates else None)
return self.to_coverage_json(self.layers, dataset, crs)
layers = self.get_layers(self.base_url)
return self.to_coverage_json(layers, dataset, crs)
else:
return self.to_geotiff_response(dataset, self.collection_id)

Expand Down Expand Up @@ -182,11 +225,13 @@ def position(self, **kwargs):
crs = EdrProvider.interpret_crs(crs)

if not isinstance(wkt, BaseGeometry):
raise ProviderInvalidQueryError("Invalid wkt provided.")
msg = "Invalid WKT string provided for the position query."
raise ProviderInvalidQueryError(msg, user_msg=msg)
elif wkt.geom_type == "Point":
lon, lat = EdrProvider.crs_converter([wkt.x], [wkt.y], crs)
else:
raise ProviderInvalidQueryError("Unknown WKT Type (Use Point).")
msg = "Unknown WKT string type for the position query (use Point)."
raise ProviderInvalidQueryError(msg, user_msg=msg)

requested_coordinates = podpac.Coordinates([lat, lon], dims=["lat", "lon"], crs=crs)

Expand Down Expand Up @@ -217,7 +262,8 @@ def cube(self, **kwargs):
crs = EdrProvider.interpret_crs(crs)

if not isinstance(bbox, List) or len(bbox) != 4:
raise ProviderInvalidQueryError("Invalid bounding box provided.")
msg = "Invalid bounding box provided, expected bounding box of (minx, miny, maxx, maxy)."
raise ProviderInvalidQueryError(msg, user_msg=msg)

xmin, ymin, xmax, ymax = bbox
lon, lat = EdrProvider.crs_converter([xmin, xmax], [ymin, ymax], crs)
Expand Down Expand Up @@ -254,11 +300,13 @@ def area(self, **kwargs):
crs = EdrProvider.interpret_crs(crs)

if not isinstance(wkt, BaseGeometry):
raise ProviderInvalidQueryError("Invalid wkt provided.")
msg = "Invalid WKT string provided for the area query."
raise ProviderInvalidQueryError(msg, user_msg=msg)
elif wkt.geom_type == "Polygon":
lon, lat = EdrProvider.crs_converter(wkt.exterior.xy[0], wkt.exterior.xy[1], crs)
else:
raise ProviderInvalidQueryError("Unknown WKT Type (Use Polygon).")
msg = "Unknown WKT string type for the area query (use Polygon)."
raise ProviderInvalidQueryError(msg, user_msg=msg)

requested_coordinates = podpac.Coordinates([lat, lon], dims=["lat", "lon"], crs=crs)

Expand Down Expand Up @@ -288,7 +336,8 @@ def instances(self, **kwargs) -> List[str]:
The instances available in the collection.
"""
instances = set()
for layer in self.layers:
layers = self.get_layers(self.base_url)
for layer in layers:
if layer.group == self.collection_id:
instances.update(layer.time_instances())
return list(instances)
Expand Down Expand Up @@ -382,7 +431,8 @@ def interpret_crs(crs: str | None) -> str:
return settings.crs_84_pyproj_format # Pyproj acceptable format

if crs.lower() not in [key.lower() for key in settings.EDR_CRS.keys()]:
raise ProviderInvalidQueryError("Invalid CRS provided.")
msg = f"Invalid CRS provided, expected one of {', '.join(settings.EDR_CRS.keys())}"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like all these more-informative error messages!

It's always worth thinking about, from a security perspective, whether providing additional information with an error message will reveal anything that would aid an attacker. The list of supported CRSs is not a private piece of information, and all the other messages also look good to me; I just figured I'd mention it.

raise ProviderInvalidQueryError(msg, user_msg=msg)

return crs

Expand Down Expand Up @@ -539,20 +589,24 @@ def to_coverage_json(
coordinates = next(iter(dataset.values())).coords
x_arr, y_arr = EdrProvider.crs_converter(coordinates["lon"].values, coordinates["lat"].values, crs)

# Convert numpy array coordinates to a flattened list.
x_arr = list(x_arr.flatten())
y_arr = list(y_arr.flatten())

coverage_json = {
"type": "Coverage",
"domain": {
"type": "Domain",
"domainType": "Grid",
"axes": {
"x": {
"start": x_arr[0],
"stop": x_arr[-1],
"start": x_arr[0] if len(x_arr) > 0 else None,
"stop": x_arr[-1] if len(x_arr) > 0 else None,
Comment on lines +603 to +604
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good additional case handling!

"num": len(x_arr),
},
"y": {
"start": y_arr[0],
"stop": y_arr[-1],
"start": y_arr[0] if len(y_arr) > 0 else None,
"stop": y_arr[-1] if len(y_arr) > 0 else None,
"num": len(y_arr),
},
},
Expand Down Expand Up @@ -635,7 +689,7 @@ def check_query_condition(conditional: bool, message: str):
Raised if the conditional provided is true.
"""
if conditional:
raise ProviderInvalidQueryError(message)
raise ProviderInvalidQueryError(message, user_msg=message)

@staticmethod
def validate_datetime(datetime_string: str) -> bool:
Expand Down Expand Up @@ -697,8 +751,8 @@ def get_native_coordinates(
target_coordinates: podpac.Coordinates,
source_time_instance: np.datetime64 | None,
) -> podpac.Coordinates:
"""Find the intersecting coordinates between source and target coordinates.
Convert time instances to offsets for node evalutation.
"""Find the intersecting latitude and longitude coordinates between the source and target.
Convert time instances to stacked time and forecast offsets for node evaluation.

Parameters
----------
Expand All @@ -714,6 +768,9 @@ def get_native_coordinates(
podpac.Coordinates
The converted source coordinates intersecting with the target coordinates.
"""
# Find intersections with target keeping source crs
source_intersection_coordinates = target_coordinates.intersect(source_coordinates, dims=["lat", "lon"])
source_intersection_coordinates = source_intersection_coordinates.transform(source_coordinates.crs)
# Handle conversion from times and instance to time and offsets
if (
"forecastOffsetHr" in target_coordinates.udims
Expand All @@ -728,14 +785,12 @@ def get_native_coordinates(

# This modifies the time coordinates to account for the new forecast offset hour
new_coordinates = podpac.Coordinates(
[[source_time_instance], time_deltas],
["time", "forecastOffsetHr"],
[[[source_time_instance] * len(time_deltas), time_deltas]],
[["time", "forecastOffsetHr"]],
crs=source_coordinates.crs,
)
source_coordinates = podpac.coordinates.merge_dims([source_coordinates.drop("time"), new_coordinates])

# Find intersections with target keeping source crs
source_intersection_coordinates = target_coordinates.intersect(source_coordinates)
source_intersection_coordinates = source_intersection_coordinates.transform(source_coordinates.crs)
source_intersection_coordinates = podpac.coordinates.merge_dims(
[source_intersection_coordinates.udrop(["time", "forecastOffsetHr"]), new_coordinates]
)

return source_intersection_coordinates
3 changes: 2 additions & 1 deletion ogc/edr/edr_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ogc import podpac as pogc

from .edr_config import EdrConfig
from .edr_provider import EdrProvider


class EdrRoutes(tl.HasTraits):
Expand Down Expand Up @@ -51,7 +52,7 @@ def create_api(self) -> pygeoapi.api.API:
# This is a bypass which is needed to get by a conditional check in pygeoapi.
pygeoapi.plugin.PLUGINS["formatter"]["GeoTiff"] = ""
pygeoapi.plugin.PLUGINS["formatter"]["CoverageJSON"] = ""

EdrProvider.set_layers(self.base_url, self.layers)
config = EdrConfig.get_configuration(self.base_url, self.layers)
open_api = get_oas(config, fail_on_invalid_collection=False)
return pygeoapi.api.API(config=deepcopy(config), openapi=open_api)
Expand Down
Loading