diff --git a/climatenet/analyze_events.py b/climatenet/analyze_events.py index f233ba3..9b52f31 100644 --- a/climatenet/analyze_events.py +++ b/climatenet/analyze_events.py @@ -8,7 +8,7 @@ import cartopy.crs as ccrs import haversine as hs -def analyze_events(event_masks_xarray, class_masks_xarray, results_dir): +def analyze_events(event_masks_xarray, class_masks_xarray, results_dir, polar=False): """Analzse event masks of ARs and TCs Produces PNGs of @@ -19,6 +19,7 @@ def analyze_events(event_masks_xarray, class_masks_xarray, results_dir): class_masks_xarray -- the class masks as xarray, 0==Background, 1==TC, 2 ==AR event_masks_xarray -- the event masks as xarray with IDs as elements results_dir -- the directory where the PNGs get saved to + polar -- bool, default False. If True, use polar stereographic projection """ # create results_dir if it doesn't exist pathlib.Path(results_dir).mkdir(parents=True, exist_ok=True) @@ -26,12 +27,54 @@ def analyze_events(event_masks_xarray, class_masks_xarray, results_dir): class_masks = class_masks_xarray.values event_masks = event_masks_xarray.values + if polar: + # Constants + R = 6378137 # Radius of the Earth in meters (WGS84) + true_scale_lat = -81.06523 # Latitude of true scale in degrees + image_size = 1152 # Image dimensions (1152x1152 pixels) + center_x, center_y = image_size // 2, image_size // 2 # Center of the image + + # Convert true scale latitude to radians + phi_ts = np.radians(true_scale_lat) + + # Calculate scaling factor for the projection + k0 = 2 * R * np.tan(np.pi / 4 + phi_ts / 2) + print('calculating centroids..', flush=True) def pixel_to_degree(pos): """Returns the (lat,long) position of a pixel coordinate""" return(pos[0] * 180.0 / event_masks.shape[1] - 90, pos[1] * 360 / event_masks.shape[2] + 180) + + def pixel_to_stereographic(pos): + pixel_x, pixel_y = pos + # Convert pixel position relative to the center + x_rel = pixel_x - center_x + y_rel = center_y - pixel_y # Flipping y-axis because image coordinates grow downwards + + # Convert to stereographic coordinates (x, y in meters) + x_stereo = (x_rel / (image_size / 2)) * k0 + y_stereo = (y_rel / (image_size / 2)) * k0 + + return x_stereo, y_stereo + + def stereographic_to_latlon(pos): + x_stereo, y_stereo=pos + # Calculate rho (distance from the center) + rho = np.sqrt(x_stereo**2 + y_stereo**2) + + # Calculate the inverse stereographic projection + c= 2 * np.arctan(rho / (2 * R)) + # lat is not negative b/c south polar stereo + lat = np.pi / 2 + 2 * np.arctan(np.cos(c) / np.sin(c)) + lon = np.arctan2(x_stereo, -y_stereo) + + # Convert latitude and longitude to degrees + lat_deg = np.degrees(lat) + lon_deg = np.degrees(lon) + + return lat_deg, lon_deg def average_location(coordinates_pixel): """Returns the average geolocation in pixel space @@ -68,6 +111,37 @@ def average_location(coordinates_pixel): return (event_masks.shape[1] * (average_degree[0] + 90) / 180, event_masks.shape[2] * (average_degree[1] + 180) / 360) + def average_location_polar(coordinates_pixel): + """ + Returns the average geolocation in pixel space for polar stereographic projection. + + Parameters: + coordinates_pixel : list of tuples + List of (x, y) pixel coordinates in polar stereographic projection. + + Returns: + tuple + (average_x, average_y) coordinates of the average location. + """ + + x = 0.0 + y = 0.0 + z = 0.0 + + for x_pixel, y_pixel in coordinates_pixel: + # Convert pixel coordinates to stereographic projection + # No need to convert to radians in stereographic projection + x += x_pixel + y += y_pixel + + total = len(coordinates_pixel) + + # Calculate the average coordinates + average_x = x / total + average_y = y / total + + return average_x, average_y + global centroids # make function visible to pool def centroids(event_mask): """Returns a dict mapping from the IDs in event_mask to their centroids""" @@ -81,8 +155,12 @@ def centroids(event_mask): coordinates_per_id.setdefault(this_id, []).append((row, col)) centroid_per_id = {} - for this_id in coordinates_per_id: - centroid_per_id[this_id] = average_location(coordinates_per_id[this_id]) + if not polar: + for this_id in coordinates_per_id: + centroid_per_id[this_id] = average_location(coordinates_per_id[this_id]) + elif polar: + for this_id in coordinates_per_id: + centroid_per_id[this_id] = average_location_polar(coordinates_per_id[this_id]) return centroid_per_id @@ -188,9 +266,13 @@ def event_type_of_mask(event_mask, class_mask): for i in range(len(this_class_ids)): termination_centroids.append(centroid_per_id_per_time[termination_times[i]][list(this_class_ids)[i]]) genesis_centroids.append(centroid_per_id_per_time[genesis_times[i]][list(this_class_ids)[i]]) - - distances = np.array([hs.haversine(pixel_to_degree(pos1), pixel_to_degree(pos2)) - for pos1, pos2 in zip(termination_centroids, genesis_centroids)]) + if not polar: + distances = np.array([hs.haversine(pixel_to_degree(pos1), pixel_to_degree(pos2)) + for pos1, pos2 in zip(termination_centroids, genesis_centroids)]) + elif polar: + distances = np.array([hs.haversine(stereographic_to_latlon(pixel_to_stereographic(pos1)),stereographic_to_latlon(pixel_to_stereographic(pos2))) + for pos1, pos2 in zip(termination_centroids, genesis_centroids)]) + # travel distance histogram plt.figure(dpi=100) @@ -228,7 +310,10 @@ def map_instance(title): plt.rc('ytick',labelsize=20) mymap = plt.subplot(111,projection=ccrs.PlateCarree()) mymap.set_global() - mymap.background_img(name='BM') + if not polar: + mymap.background_img(name='BM') + elif polar: + mymap.background_img(name='BM_Polar_SPS') mymap.coastlines() mymap.gridlines(crs=ccrs.PlateCarree(),linewidth=2, color='k', alpha=0.5, linestyle='--') mymap.set_xticks([-180,-120,-60,0,60,120,180]) diff --git a/climatenet/bluemarble/BM_Polar_NPS.jpeg b/climatenet/bluemarble/BM_Polar_NPS.jpeg new file mode 100644 index 0000000..d1b517b Binary files /dev/null and b/climatenet/bluemarble/BM_Polar_NPS.jpeg differ diff --git a/climatenet/bluemarble/BM_Polar_SPS.jpeg b/climatenet/bluemarble/BM_Polar_SPS.jpeg new file mode 100644 index 0000000..c8e9a3a Binary files /dev/null and b/climatenet/bluemarble/BM_Polar_SPS.jpeg differ diff --git a/climatenet/utils/losses.py b/climatenet/utils/losses.py index 42e0389..4fbd090 100644 --- a/climatenet/utils/losses.py +++ b/climatenet/utils/losses.py @@ -18,7 +18,7 @@ def jaccard_loss(logits, true, eps=1e-7): jacc_loss: the Jaccard loss. """ num_classes = logits.shape[1] - true_1_hot = torch.eye(num_classes)[true.squeeze(1)] + true_1_hot = torch.eye(num_classes)[true.squeeze(1).long()] true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() probas = F.softmax(logits, dim=1) true_1_hot = true_1_hot.type(logits.type()) @@ -27,4 +27,4 @@ def jaccard_loss(logits, true, eps=1e-7): cardinality = torch.sum(probas + true_1_hot, dims) union = cardinality - intersection jacc_loss = (intersection / (union + eps)).mean() - return (1 - jacc_loss) \ No newline at end of file + return (1 - jacc_loss)