Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 92 additions & 7 deletions climatenet/analyze_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import cartopy.crs as ccrs
import haversine as hs

def analyze_events(event_masks_xarray, class_masks_xarray, results_dir):
def analyze_events(event_masks_xarray, class_masks_xarray, results_dir, polar=False):
"""Analzse event masks of ARs and TCs

Produces PNGs of
Expand All @@ -19,19 +19,62 @@ def analyze_events(event_masks_xarray, class_masks_xarray, results_dir):
class_masks_xarray -- the class masks as xarray, 0==Background, 1==TC, 2 ==AR
event_masks_xarray -- the event masks as xarray with IDs as elements
results_dir -- the directory where the PNGs get saved to
polar -- bool, default False. If True, use polar stereographic projection
"""
# create results_dir if it doesn't exist
pathlib.Path(results_dir).mkdir(parents=True, exist_ok=True)

class_masks = class_masks_xarray.values
event_masks = event_masks_xarray.values

if polar:
# Constants
R = 6378137 # Radius of the Earth in meters (WGS84)
true_scale_lat = -81.06523 # Latitude of true scale in degrees
image_size = 1152 # Image dimensions (1152x1152 pixels)
center_x, center_y = image_size // 2, image_size // 2 # Center of the image

# Convert true scale latitude to radians
phi_ts = np.radians(true_scale_lat)

# Calculate scaling factor for the projection
k0 = 2 * R * np.tan(np.pi / 4 + phi_ts / 2)

print('calculating centroids..', flush=True)

def pixel_to_degree(pos):
"""Returns the (lat,long) position of a pixel coordinate"""
return(pos[0] * 180.0 / event_masks.shape[1] - 90,
pos[1] * 360 / event_masks.shape[2] + 180)

def pixel_to_stereographic(pos):
pixel_x, pixel_y = pos
# Convert pixel position relative to the center
x_rel = pixel_x - center_x
y_rel = center_y - pixel_y # Flipping y-axis because image coordinates grow downwards

# Convert to stereographic coordinates (x, y in meters)
x_stereo = (x_rel / (image_size / 2)) * k0
y_stereo = (y_rel / (image_size / 2)) * k0

return x_stereo, y_stereo

def stereographic_to_latlon(pos):
x_stereo, y_stereo=pos
# Calculate rho (distance from the center)
rho = np.sqrt(x_stereo**2 + y_stereo**2)

# Calculate the inverse stereographic projection
c= 2 * np.arctan(rho / (2 * R))
# lat is not negative b/c south polar stereo
lat = np.pi / 2 + 2 * np.arctan(np.cos(c) / np.sin(c))
lon = np.arctan2(x_stereo, -y_stereo)

# Convert latitude and longitude to degrees
lat_deg = np.degrees(lat)
lon_deg = np.degrees(lon)

return lat_deg, lon_deg

def average_location(coordinates_pixel):
"""Returns the average geolocation in pixel space
Expand Down Expand Up @@ -68,6 +111,37 @@ def average_location(coordinates_pixel):
return (event_masks.shape[1] * (average_degree[0] + 90) / 180,
event_masks.shape[2] * (average_degree[1] + 180) / 360)

def average_location_polar(coordinates_pixel):
"""
Returns the average geolocation in pixel space for polar stereographic projection.

Parameters:
coordinates_pixel : list of tuples
List of (x, y) pixel coordinates in polar stereographic projection.

Returns:
tuple
(average_x, average_y) coordinates of the average location.
"""

x = 0.0
y = 0.0
z = 0.0

for x_pixel, y_pixel in coordinates_pixel:
# Convert pixel coordinates to stereographic projection
# No need to convert to radians in stereographic projection
x += x_pixel
y += y_pixel

total = len(coordinates_pixel)

# Calculate the average coordinates
average_x = x / total
average_y = y / total

return average_x, average_y

global centroids # make function visible to pool
def centroids(event_mask):
"""Returns a dict mapping from the IDs in event_mask to their centroids"""
Expand All @@ -81,8 +155,12 @@ def centroids(event_mask):
coordinates_per_id.setdefault(this_id, []).append((row, col))

centroid_per_id = {}
for this_id in coordinates_per_id:
centroid_per_id[this_id] = average_location(coordinates_per_id[this_id])
if not polar:
for this_id in coordinates_per_id:
centroid_per_id[this_id] = average_location(coordinates_per_id[this_id])
elif polar:
for this_id in coordinates_per_id:
centroid_per_id[this_id] = average_location_polar(coordinates_per_id[this_id])

return centroid_per_id

Expand Down Expand Up @@ -188,9 +266,13 @@ def event_type_of_mask(event_mask, class_mask):
for i in range(len(this_class_ids)):
termination_centroids.append(centroid_per_id_per_time[termination_times[i]][list(this_class_ids)[i]])
genesis_centroids.append(centroid_per_id_per_time[genesis_times[i]][list(this_class_ids)[i]])

distances = np.array([hs.haversine(pixel_to_degree(pos1), pixel_to_degree(pos2))
for pos1, pos2 in zip(termination_centroids, genesis_centroids)])
if not polar:
distances = np.array([hs.haversine(pixel_to_degree(pos1), pixel_to_degree(pos2))
for pos1, pos2 in zip(termination_centroids, genesis_centroids)])
elif polar:
distances = np.array([hs.haversine(stereographic_to_latlon(pixel_to_stereographic(pos1)),stereographic_to_latlon(pixel_to_stereographic(pos2)))
for pos1, pos2 in zip(termination_centroids, genesis_centroids)])


# travel distance histogram
plt.figure(dpi=100)
Expand Down Expand Up @@ -228,7 +310,10 @@ def map_instance(title):
plt.rc('ytick',labelsize=20)
mymap = plt.subplot(111,projection=ccrs.PlateCarree())
mymap.set_global()
mymap.background_img(name='BM')
if not polar:
mymap.background_img(name='BM')
elif polar:
mymap.background_img(name='BM_Polar_SPS')
mymap.coastlines()
mymap.gridlines(crs=ccrs.PlateCarree(),linewidth=2, color='k', alpha=0.5, linestyle='--')
mymap.set_xticks([-180,-120,-60,0,60,120,180])
Expand Down
Binary file added climatenet/bluemarble/BM_Polar_NPS.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added climatenet/bluemarble/BM_Polar_SPS.jpeg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions climatenet/utils/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def jaccard_loss(logits, true, eps=1e-7):
jacc_loss: the Jaccard loss.
"""
num_classes = logits.shape[1]
true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
true_1_hot = torch.eye(num_classes)[true.squeeze(1).long()]
true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
probas = F.softmax(logits, dim=1)
true_1_hot = true_1_hot.type(logits.type())
Expand All @@ -27,4 +27,4 @@ def jaccard_loss(logits, true, eps=1e-7):
cardinality = torch.sum(probas + true_1_hot, dims)
union = cardinality - intersection
jacc_loss = (intersection / (union + eps)).mean()
return (1 - jacc_loss)
return (1 - jacc_loss)