diff --git a/README.md b/README.md
index 88fbb8a..e1cda23 100644
--- a/README.md
+++ b/README.md
@@ -169,6 +169,13 @@ dtype: int64
 ```
+## Generating Presto Embeddings in Google Earth Engine
+You can generate Presto embeddings in Google Earth Engine by:
+1. Deploying Presto to Vertex AI
+   Open In Colab
+
+2. Using the [ee.Model.fromVertexAi](https://developers.google.com/earth-engine/apidocs/ee-model-fromvertexai) function in Google Earth Engine ([script on GEE](https://code.earthengine.google.com/1d196e8466506239c4780585c0e28d26))
+
 ## Reference
 If you find this code useful, please cite the following paper:
 ```
diff --git a/deploy/1_Presto_to_VertexAI.ipynb b/deploy/1_Presto_to_VertexAI.ipynb
new file mode 100644
index 0000000..8048ec4
--- /dev/null
+++ b/deploy/1_Presto_to_VertexAI.ipynb
@@ -0,0 +1,634 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "_SVA9v_JTq_-"
+   },
+   "source": [
+    "# 1. Presto to Vertex AI\n",
+    "\n",
+    " \"Open\n",
+    "\n",
+    "**Authors**: Ivan Zvonkov, Gabriel Tseng (additional credits: [Earth_Engine_PyTorch_Vertex_AI](https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_PyTorch_Vertex_AI.ipynb))\n",
+    "\n",
+    "**Description**: This notebook deploys Presto to Vertex AI. This is a prerequisite to generating Presto embeddings on Google Earth Engine using\n",
+    "[ee.Model.fromVertexAi](https://developers.google.com/earth-engine/apidocs/ee-model-fromvertexai).\n",
+    "\n",
+    "Once the model is deployed, this [GEE script](https://code.earthengine.google.com/df6348b8d47cd751eb5164dccb7b26a9) can be used to generate Presto embeddings.\n",
+    "\n",
+    "**Steps**:\n",
+    "1. Set up environment\n",
+    "2. Load default Presto model\n",
+    "3. Transform Presto model into TorchScript\n",
+    "4. Package TorchScript model into TorchServe\n",
+    "5. Deploy and use Vertex AI\n",
+    "\n",
+    "   5a. Upload TorchServe model to Vertex AI Model Registry [Free]\n",
+    "\n",
+    "   5b. Create a Vertex AI Endpoint [Free]\n",
+    "\n",
+    "   5c. Deploy model to endpoint [Cost depends on Minimum Replica Count parameter]\n",
+    "\n",
+    "   5d. Generate embeddings in Google Earth Engine [Cost depends on region size]\n",
+    "\n",
+    "   5e. Undeploy model from endpoint [Free]\n",
+    "\n",
+    "**Cost Breakdown**:\n",
+    "\n",
+    "*5a. Upload TorchServe model to Vertex AI Model Registry [Free]*\n",
+    "- Model files are uploaded to Cloud Storage but are lightweight (3.37 MB total) and thus easily fall into Google Cloud's 5 GB/month Storage [Free Tier](https://cloud.google.com/storage/pricing#cloud-storage-always-free)\n",
+    "- There is no cost to storing models in the Vertex AI Model Registry ([source](https://cloud.google.com/vertex-ai/pricing#modelregistry))\n",
+    "\n",
+    "*5b. Create a Vertex AI Endpoint [Free]*\n",
+    "- There is no cost to creating an endpoint; costs start when a model is deployed to that endpoint\n",
+    "\n",
+    "*5c. Deploy model to endpoint [Cost depends on Minimum Replica Count parameter]*\n",
+    "- The `Minimum Replica Count` sets the minimum number of compute nodes started when a model is deployed; each node is an e2-standard-2 machine (\\$0.0771/node hour in us-central1)\n",
+    "- So as long as the model remains deployed you will be paying \\$0.0771/hour even if no predictions are made\n",
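+    "- For example, at that rate a single always-on node costs 0.0771 x 24 x 30 ≈ \\$55.51 per 30-day month, regardless of prediction volume\n",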
+    "\n",
+    "*5d. Generate embeddings in Google Earth Engine [Cost depends on region size]*\n",
+    "- Once a model is deployed and `ee.Model.fromVertexAi` is used, Vertex AI scales the number of nodes based on the amount of data (the size of the region)\n",
+    "- Our current embedding generation cost estimates are \\$5.37 - \\$10.14 per 1,000 km²\n",
+    "- We compute a cost estimate for your ROI in our Google Earth Engine script\n",
+    "\n",
+    "*5e. Undeploy model from endpoint [Free]*\n",
+    "- Necessary to stop incurring the charges from 5c"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "hzb1bwgTUZU0"
+   },
+   "source": [
+    "## 1. Set up environment"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "KuoEjld3TTLO"
+   },
+   "outputs": [],
+   "source": [
+    "from google.colab import auth\n",
+    "\n",
+    "auth.authenticate_user()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "cUI_5pWJ3V4s"
+   },
+   "outputs": [],
+   "source": [
+    "PROJECT = ''\n",
+    "!gcloud config set project {PROJECT}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "MRGjYjltsm6-"
+   },
+   "outputs": [],
+   "source": [
+    "!git clone https://github.com/nasaharvest/presto.git"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "P1zGbf2KIhLA"
+   },
+   "source": [
+    "## 2. Load default Presto model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "UKgCxBNnYJIB"
+   },
+   "outputs": [],
+   "source": [
+    "# Navigate inside the repository to import Presto\n",
+    "%cd /content/presto\n",
+    "\n",
+    "import torch\n",
+    "from single_file_presto import Presto\n",
+    "\n",
+    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
+    "\n",
+    "model = Presto.construct()\n",
+    "model.load_state_dict(torch.load(\"data/default_model.pt\", map_location=device))\n",
+    "model.eval();\n",
+    "\n",
+    "# Navigate back to main directory\n",
+    "%cd /content"
+   ]
+  },
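+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Optional sanity check (not part of the original workflow): counting the encoder parameters is a quick way to confirm the checkpoint loaded into the expected architecture.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Quick smoke test: the encoder should report a non-zero parameter count\n",
+    "num_params = sum(p.numel() for p in model.encoder.parameters())\n",
+    "print(f\"Encoder parameters: {num_params:,}\")"
+   ]
+  },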
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "5v7gJoysFTsf"
+   },
+   "source": [
+    "## 3. Transform Presto model into TorchScript\n",
+    "> TorchScript is a way to create serializable and optimizable models from PyTorch code.\nhttps://docs.pytorch.org/docs/stable/jit.html"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "hDKqqzi9F7T4"
+   },
+   "outputs": [],
+   "source": [
+    "# Construct input manually\n",
+    "batch_size = 256\n",
+    "NUM_TIMESTEPS = 12\n",
+    "X_tensor = torch.zeros([batch_size, NUM_TIMESTEPS, 17])\n",
+    "latlons_tensor = torch.zeros([batch_size, 2])\n",
+    "\n",
+    "dw_empty = torch.full([batch_size, NUM_TIMESTEPS], 9, device=device).long()\n",
+    "month_tensor = torch.full([batch_size], 1, device=device)\n",
+    "\n",
+    "# Band order (indices 0-16):\n",
+    "# [VV, VH, B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12, temp, precip, elev, slope, NDVI]\n",
+    "mask = torch.zeros(X_tensor.shape, device=device).float()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "mdSXGZuikHUk"
+   },
+   "outputs": [],
+   "source": [
+    "# Verify forward pass with regular model\n",
+    "with torch.no_grad():\n",
+    "    preds = model.encoder(\n",
+    "        x=X_tensor,\n",
+    "        dynamic_world=dw_empty,\n",
+    "        latlons=latlons_tensor,\n",
+    "        mask=mask,\n",
+    "        month=month_tensor\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "nFRvUVkHowKr"
+   },
+   "outputs": [],
+   "source": [
+    "# Make model torchscriptable\n",
+    "example_kwargs = {\n",
+    "    'x': X_tensor,\n",
+    "    'dynamic_world': dw_empty,\n",
+    "    'latlons': latlons_tensor,\n",
+    "    'mask': mask,\n",
+    "    'month': month_tensor\n",
+    "}\n",
+    "sm = torch.jit.trace(model.encoder, example_kwarg_inputs=example_kwargs)\n",
+    "\n",
+    "!mkdir -p pytorch_model\n",
+    "sm.save('pytorch_model/model.pt')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "cYuSPOyp1A0K"
+   },
+   "outputs": [],
+   "source": [
+    "jit_model = torch.jit.load('pytorch_model/model.pt')\n",
+    "jit_model(**example_kwargs).shape"
+   ]
+  },
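+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Optional check (not part of the original workflow): the traced model should reproduce the eager encoder's output on the example inputs.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Compare eager vs. traced outputs; they should match to within float tolerance\n",
+    "with torch.no_grad():\n",
+    "    eager_out = model.encoder(**example_kwargs)\n",
+    "    traced_out = jit_model(**example_kwargs)\n",
+    "print(torch.allclose(eager_out, traced_out, atol=1e-5))"
+   ]
+  },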
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "3b0LsZpqnByv"
+   },
+   "source": [
+    "## 4. Package TorchScript model into TorchServe\n",
+    "> TorchServe is a performant, flexible and easy to use tool for serving PyTorch models in production.\nhttps://docs.pytorch.org/serve/"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "i70o_BZml9vs"
+   },
+   "outputs": [],
+   "source": [
+    "!pip install torchserve torch-model-archiver -q"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "htq2Ac95FJlk"
+   },
+   "outputs": [],
+   "source": [
+    "%%writefile pytorch_model/custom_handler.py\n",
+    "import logging\n",
+    "import torch\n",
+    "from ts.torch_handler.base_handler import BaseHandler\n",
+    "import numpy as np\n",
+    "\n",
+    "# UPDATE BASED ON YOUR NEEDS\n",
+    "########################################\n",
+    "VERSION = \"v1\"\n",
+    "START_MONTH = 3\n",
+    "BATCH_SIZE = 256\n",
+    "########################################\n",
+    "\n",
+    "def printh(text):\n",
+    "    # Prepends HANDLER to each print statement to make it easier to find in logs.\n",
+    "    print(f\"HANDLER {VERSION}: {text}\")\n",
+    "\n",
+    "# Custom TorchServe handler for the Presto model\n",
+    "class ClassifierHandler(BaseHandler):\n",
+    "\n",
+    "    def inference(self, data):\n",
+    "        printh(\"Inference begin\")\n",
+    "\n",
+    "        # Data shape: [num_pixels, composite_bands, 1, 1]\n",
+    "        data = data[:, :, 0, 0]\n",
+    "        printh(f\"Data shape {data.shape}\")\n",
+    "\n",
+    "        num_bands = 17\n",
+    "        printh(f\"Num_bands {num_bands}\")\n",
+    "\n",
+    "        # Subtract the first two entries (latlon) before dividing by bands per timestep\n",
+    "        num_timesteps = (data.shape[1] - 2) // num_bands\n",
+    "        printh(f\"Num_timesteps {num_timesteps}\")\n",
+    "\n",
+    "        with torch.no_grad():\n",
+    "\n",
+    "            batches = torch.split(data, BATCH_SIZE, dim=0)\n",
+    "\n",
+    "            # month: An int or torch.Tensor describing the first month of the instances being passed. If an int, all instances in the batch are assumed to have the same starting month.\n",
+    "            month_tensor = torch.full([BATCH_SIZE], START_MONTH, device=self.device)\n",
+    "            printh(f\"Month: {START_MONTH}\")\n",
+    "\n",
+    "            # dynamic_world: torch.Tensor of shape [BATCH_SIZE, num_timesteps]. If no Dynamic World classes are available, this tensor should be filled with the value DynamicWorld2020_2021.class_amount (i.e. 9), in which case it is ignored.\n",
+    "            dw_empty = torch.full([BATCH_SIZE, num_timesteps], 9, device=self.device).long()\n",
+    "            printh(f\"DW {dw_empty[0]}\")\n",
+    "\n",
+    "            # mask: An optional torch.Tensor of shape [BATCH_SIZE, num_timesteps, bands]. mask[i, j, k] == 1 means x[i, j, k] is considered masked. If the mask is None, no values in x are ignored.\n",
+    "            mask = torch.zeros((BATCH_SIZE, num_timesteps, num_bands), device=self.device).float()\n",
+    "            printh(f\"Mask sample one timestep: {mask[0, 0]}\")\n",
+    "\n",
+    "            preds_list = []\n",
+    "            for batch in batches:\n",
+    "                padding = 0\n",
+    "                if batch.shape[0] < BATCH_SIZE:\n",
+    "                    padding = BATCH_SIZE - batch.shape[0]\n",
+    "                    batch = torch.cat([batch, torch.zeros([padding, batch.shape[1]], device=self.device)])\n",
+    "\n",
+    "                # x: torch.Tensor of shape [BATCH_SIZE, num_timesteps, bands] where bands is described by NORMED_BANDS.\n",
+    "                X_tensor = batch[:, 2:]\n",
+    "                printh(f\"X {X_tensor.shape}\")\n",
+    "\n",
+    "                X_tensor_reshaped = X_tensor.reshape(BATCH_SIZE, num_timesteps, num_bands)\n",
+    "                printh(f\"X sample one timestep: {X_tensor_reshaped[0, 0]}\")\n",
+    "\n",
+    "                # latlons: torch.Tensor of shape [BATCH_SIZE, 2] describing the latitude and longitude of each input instance.\n",
+    "                latlons_tensor = batch[:, :2]\n",
+    "\n",
+    "                printh(\"SHAPES\")\n",
+    "                printh(f\"X {X_tensor_reshaped.shape}\")\n",
+    "                printh(f\"DW {dw_empty.shape}\")\n",
+    "                printh(f\"Latlons {latlons_tensor.shape}\")\n",
+    "                printh(f\"Mask {mask.shape}\")\n",
+    "                printh(f\"Month {month_tensor.shape}\")\n",
+    "\n",
+    "                pred = self.model(\n",
+    "                    x=X_tensor_reshaped,\n",
+    "                    dynamic_world=dw_empty,\n",
+    "                    latlons=latlons_tensor,\n",
+    "                    mask=mask,\n",
+    "                    month=month_tensor\n",
+    "                )\n",
+    "                pred_np = np.expand_dims(pred.numpy(), axis=[1, 2])\n",
+    "                if padding == 0:\n",
+    "                    preds_list.append(pred_np[:])\n",
+    "                else:\n",
+    "                    preds_list.append(pred_np[:-padding])\n",
+    "\n",
+    "            for p in preds_list:\n",
+    "                printh(f\"{p.shape}\")\n",
+    "            preds = np.concatenate(preds_list)\n",
+    "            printh(f\"Preds shape {preds.shape}\")\n",
+    "            return preds\n",
+    "\n",
+    "    def handle(self, data, context):\n",
+    "        self.context = context\n",
+    "        printh(\"Handle begin\")\n",
+    "        input_tensor = self.preprocess(data)\n",
+    "        printh(f\"Input_tensor shape {input_tensor.shape}\")\n",
+    "        pred_out = self.inference(input_tensor)\n",
+    "        printh(\"Inference complete\")\n",
+    "        return self.postprocess(pred_out)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "a3Dgq5Ob5b1i"
+   },
+   "outputs": [],
+   "source": [
+    "import importlib\n",
+    "import pytorch_model.custom_handler\n",
+    "\n",
+    "importlib.reload(pytorch_model.custom_handler)\n",
+    "\n",
+    "from pytorch_model.custom_handler import ClassifierHandler, VERSION\n",
+    "\n",
+    "# Test handler output locally (206 bands = 2 latlon + 12 months x 17 bands)\n",
+    "data = torch.zeros([713, 206, 1, 1])\n",
+    "handler = ClassifierHandler()\n",
+    "handler.model = jit_model\n",
+    "preds = handler.handle(data=data, context=None)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "90TXNAnfF-TD"
+   },
+   "outputs": [],
+   "source": [
+    "!torch-model-archiver -f \\\n",
+    "    --model-name model \\\n",
+    "    --version 1.0 \\\n",
+    "    --serialized-file 'pytorch_model/model.pt' \\\n",
+    "    --handler 'pytorch_model/custom_handler.py' \\\n",
+    "    --export-path pytorch_model/"
+   ]
+  },
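+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Optional check (not part of the original workflow): with `--model-name model` and `--export-path pytorch_model/`, the archiver should have written `pytorch_model/model.mar`.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Confirm the TorchServe archive was created\n",
+    "!ls -lh pytorch_model/"
+   ]
+  },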
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "PYK9g3r1qVru"
+   },
+   "source": [
+    "## 5. Deploy and use Vertex AI"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "CH0JO_Jww5Ok"
+   },
+   "source": [
+    "### 5a. Upload TorchServe model to Vertex AI Model Registry\n",
+    "> The Vertex AI Model Registry is a central repository where you can manage the lifecycle of your ML models.\nhttps://cloud.google.com/vertex-ai/docs/model-registry/introduction"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "lEUqcAqsTYpn"
+   },
+   "outputs": [],
+   "source": [
+    "REGION = 'us-central1'\n",
+    "BUCKET_NAME = \"\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "OiyyCNa-TBGQ"
+   },
+   "outputs": [],
+   "source": [
+    "# Create bucket to store model artifacts if it doesn't exist\n",
+    "!gcloud storage buckets create gs://{BUCKET_NAME} --location={REGION}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "n8m9UBy3GEvZ"
+   },
+   "outputs": [],
+   "source": [
+    "MODEL_DIR = f'gs://{BUCKET_NAME}/{VERSION}'\n",
+    "!gsutil cp -r pytorch_model {MODEL_DIR}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "AetRF8dcGraC"
+   },
+   "outputs": [],
+   "source": [
+    "# Can take 2 minutes\n",
+    "MODEL_NAME = f'model_{VERSION}'\n",
+    "CONTAINER_IMAGE = 'us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-4:latest'\n",
+    "\n",
+    "!gcloud ai models upload \\\n",
+    "    --artifact-uri={MODEL_DIR} \\\n",
+    "    --region={REGION} \\\n",
+    "    --container-image-uri={CONTAINER_IMAGE} \\\n",
+    "    --description={MODEL_NAME} \\\n",
+    "    --display-name={MODEL_NAME} \\\n",
+    "    --model-id={MODEL_NAME}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "_BXDtoITxY2T"
+   },
+   "source": [
+    "### 5b. Create a Vertex AI Endpoint\n",
+    "> To deploy a model for online prediction, you need an endpoint.\nhttps://cloud.google.com/vertex-ai/docs/predictions/choose-endpoint-type\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "XhMK9aA73FzI"
+   },
+   "outputs": [],
+   "source": [
+    "ENDPOINT_NAME = 'vertex-pytorch-presto-endpoint'\n",
+    "\n",
+    "endpoints = !gcloud ai endpoints list --region={REGION} --format='get(DISPLAY_NAME)'\n",
+    "\n",
+    "if ENDPOINT_NAME in endpoints:\n",
+    "    print(f\"Endpoint '{ENDPOINT_NAME}' already exists, skipping endpoint creation.\")\n",
+    "else:\n",
+    "    print(f\"Endpoint '{ENDPOINT_NAME}' does not exist, creating... (~3 minutes)\")\n",
+    "    !gcloud ai endpoints create \\\n",
+    "        --display-name={ENDPOINT_NAME} \\\n",
+    "        --endpoint-id={ENDPOINT_NAME} \\\n",
+    "        --region={REGION}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "dWODP6ccx8-3"
+   },
+   "source": [
+    "### 5c. Deploy model to endpoint\n",
+    "> Deploying a model associates physical resources with the model so that it can serve online predictions with low latency.\nhttps://cloud.google.com/vertex-ai/docs/general/deployment\n",
+    "\n",
+    "⚠️ The `Minimum Replica Count` sets the minimum number of compute nodes started when a model is deployed; each node is an e2-standard-2 machine (\\$0.0771/node hour in us-central1). So as long as the model remains deployed you will be paying \\$0.0771/hour even if no predictions are made."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "sI297v1mjtR8"
+   },
+   "outputs": [],
+   "source": [
+    "# Deploy model to endpoint; this will start an e2-standard-2 machine, which costs money\n",
+    "print(\"Track model deployment progress and prediction logs:\")\n",
+    "print(f\"https://console.cloud.google.com/vertex-ai/online-prediction/locations/{REGION}/endpoints/{ENDPOINT_NAME}?project={PROJECT}\\n\")\n",
+    "\n",
+    "# If using for a large region, set min-replica-count higher to save scaling time\n",
+    "# Can take from 4-27 minutes\n",
+    "# Relevant quota: \"Custom model serving CPUs per region\"\n",
+    "!gcloud ai endpoints deploy-model {ENDPOINT_NAME} \\\n",
+    "    --region={REGION} \\\n",
+    "    --model={MODEL_NAME} \\\n",
+    "    --display-name={MODEL_NAME} \\\n",
+    "    --machine-type=\"e2-standard-2\" \\\n",
+    "    --min-replica-count='1' \\\n",
+    "    --max-replica-count=\"100\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "r2VtGUv9JOI9"
+   },
+   "source": [
+    "### 5d. Generate embeddings in Google Earth Engine\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "160PNcRRJMzn"
+   },
+   "outputs": [],
+   "source": [
+    "GEE_SCRIPT_URL = \"https://code.earthengine.google.com/c239905f788f67ecf0cee42753893d1c\"\n",
+    "print(f\"Open this script: {GEE_SCRIPT_URL}\")\n",
+    "print(\"Use the below string for the ENDPOINT variable:\")\n",
+    "print(f\"projects/{PROJECT}/locations/{REGION}/endpoints/{ENDPOINT_NAME}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "8PnhA3gfHSrY"
+   },
+   "source": [
+    "### 5e. Undeploy model from endpoint\n",
+    "\n",
+    "Once predictions are made, you must undeploy your model to stop incurring further charges.\n",
+    "\n",
+    "This can be done using the code below or by using the Google Cloud console directly."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "VvOO_sfQDWPt"
+   },
+   "outputs": [],
+   "source": [
+    "def get_deployed_model():\n",
+    "    deployed_models = !gcloud ai endpoints describe {ENDPOINT_NAME} --region={REGION} --format 'get(deployedModels)'\n",
+    "    if deployed_models[1] == '':\n",
+    "        print(\"No models deployed\")\n",
+    "        return None\n",
+    "    model_id = eval(deployed_models[1])['id']\n",
+    "    print(model_id)\n",
+    "    return model_id\n",
+    "\n",
+    "deployed_model_id = get_deployed_model()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "vswFTu9kFeHy"
+   },
+   "outputs": [],
+   "source": [
+    "!gcloud ai endpoints undeploy-model {ENDPOINT_NAME} --region={REGION} --deployed-model-id={deployed_model_id}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "w9Fq6yspF2ye"
+   },
+   "outputs": [],
+   "source": [
+    "get_deployed_model()"
+   ]
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/deploy/2_Generate_Embeddings.js b/deploy/2_Generate_Embeddings.js
new file mode 100644
index 0000000..49f78cb
--- /dev/null
+++ b/deploy/2_Generate_Embeddings.js
@@ -0,0 +1,246 @@
+//------------------------------------------------------------------------------------
+// Script for generating Presto embeddings using Vertex AI
+// Author: Ivan Zvonkov (izvonkov@umd.edu)
+//------------------------------------------------------------------------------------
+// 1. Presto embedding generation parameters (set parameters according to your needs)
+//------------------------------------------------------------------------------------
+var roi = ee
+  .FeatureCollection('FAO/GAUL/2015/level2')
+  .filter(ee.Filter.eq('ADM2_NAME', 'Haho'));
+var PROJ = 'EPSG:25231';
+
+var rangeStart = ee.Date('2019-03-01');
+var rangeEnd = ee.Date('2020-03-01');
+
+var ENDPOINT =
+  'projects/presto-deployment/locations/us-central1/endpoints/vertex-pytorch-presto-endpoint';
+var RUN_VERTEX_AI = false; // Leave this as false to get a cost estimate first
+//------------------------------------------------------------------------------------
+
+Map.centerObject(roi, 10);
+Map.addLayer(roi, {}, 'Region of Interest');
+Map.setOptions('satellite');
+
+// 2. Cost Computation
+var roiAreaKM2 = roi.geometry().area().divide(1e6);
+function estimate(cost) {
+  return roiAreaKM2.divide(1000).multiply(cost).toInt().getInfo();
+}
+print('ROI Area: ' + roiAreaKM2.toInt().getInfo() + ' km2');
+print(
+  'Embedding Generation Estimates\nCost: $' +
+    estimate(5.37) +
+    '-' +
+    estimate(10.14)
+);
+if (!RUN_VERTEX_AI)
+  print(
+    'If you are ready to generate embeddings,\nchange the RUN_VERTEX_AI variable to true'
+  );
+
+// 3. Obtain monthly Sentinel-1 composites
+var S1_BANDS = ['VV', 'VH'];
+var S1_all = ee
+  .ImageCollection('COPERNICUS/S1_GRD')
+  .filterBounds(roi)
+  .filterDate(
+    ee.Date(rangeStart).advance(-31, 'days'),
+    ee.Date(rangeEnd).advance(31, 'days')
+  );
+
+var S1 = S1_all.filter(
+  ee.Filter.eq(
+    'orbitProperties_pass',
+    S1_all.first().get('orbitProperties_pass')
+  )
+).filter(ee.Filter.eq('instrumentMode', 'IW'));
+var S1_VV = S1.filter(
+  ee.Filter.listContains('transmitterReceiverPolarisation', 'VV')
+);
+var S1_VH = S1.filter(
+  ee.Filter.listContains('transmitterReceiverPolarisation', 'VH')
+);
+
+function getCloseImages(middleDate, imageCollection) {
+  var fromMiddleDate = imageCollection
+    .map(function (img) {
+      var dateDist = ee
+        .Number(img.get('system:time_start'))
+        .subtract(middleDate.millis())
+        .abs();
+      return img.set('dateDist', dateDist);
+    })
+    .sort({ property: 'dateDist', ascending: true });
+  var fifteenDaysInMs = ee.Number(1296000000);
+  var maxDiff = ee
+    .Number(fromMiddleDate.first().get('dateDist'))
+    .max(fifteenDaysInMs);
+  return fromMiddleDate.filterMetadata(
+    'dateDist',
+    'not_greater_than',
+    maxDiff
+  );
+}
+
+function S1_img(date1, date2) {
+  var startDate = ee.Date(date1);
+  var daysBetween = ee.Date(date2).difference(startDate, 'days');
+  var middleDate = startDate.advance(daysBetween.divide(2), 'days');
+  var kept_vv = getCloseImages(middleDate, S1_VV).select('VV');
+  var kept_vh = getCloseImages(middleDate, S1_VH).select('VH');
+  var S1_composite = ee.Image.cat([kept_vv.median(), kept_vh.median()]);
+  return S1_composite.select(S1_BANDS).add(25.0).divide(25.0); // S1 dB values range from roughly -50 to 1
+}
+
+// 4. Obtain monthly Sentinel-2 composites
+var S2_BANDS = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12'];
+var S2 = ee
+  .ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
+  .filterBounds(roi)
+  .filterDate(rangeStart, rangeEnd);
+var csPlus = ee
+  .ImageCollection('GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED')
+  .filterBounds(roi)
+  .filterDate(rangeStart, rangeEnd);
+var QA_BAND = 'cs_cdf'; // Better than cs here
+var S2_cf = S2.linkCollection(csPlus, [QA_BAND]);
+
+function S2_img(date1, date2) {
+  return S2_cf.filterDate(date1, date2)
+    .qualityMosaic(QA_BAND)
+    .select(S2_BANDS)
+    .divide(ee.Image(1e4));
+}
+
+// 5. Obtain monthly ERA5 composites
+var ERA5_BANDS = ['temperature_2m', 'total_precipitation_sum'];
+var ERA5 = ee
+  .ImageCollection('ECMWF/ERA5_LAND/MONTHLY_AGGR')
+  .filterBounds(roi)
+  .filterDate(rangeStart, rangeEnd);
+function ERA5_img(date1, date2) {
+  return ERA5.filterDate(date1, date2)
+    .select(ERA5_BANDS)
+    .mean()
+    .add([-272.15, 0])
+    .divide([35, 0.03]);
+}
+//var ERA5_temp = ee.Image([0,0]).rename(ERA5_BANDS).clip(roi)
+
+// 6. Obtain SRTM data
+var SRTM_BANDS = ['elevation', 'slope'];
+var elevation = ee.Image('USGS/SRTMGL1_003').clip(roi).select('elevation');
+var slope = ee.Terrain.slope(elevation);
+var SRTM_img = ee.Image.cat([elevation, slope]).toDouble().divide([2000, 50]);
+//var SRTM_temp = ee.Image([0,0]).rename(SRTM_BANDS).clip(roi)
+
+// 7. Combine all data into a monthly CropHarvest-style composite
+function cropharvest_img(d1, d2) {
+  var img = ee.Image.cat([
+    S1_img(d1, d2),
+    S2_img(d1, d2),
+    ERA5_img(d1, d2),
+    SRTM_img,
+  ]);
+  var ndvi = img.normalizedDifference(['B8', 'B4']).rename('NDVI');
+  // toFloat necessary for tensor conversion
+  return img.addBands(ndvi).clip(roi).toFloat();
+}
+
+// 8. Create and visualize Presto input
+var latlons = ee.Image.pixelLonLat().clip(roi).select('latitude', 'longitude');
+var imgs = [latlons];
+var numMonths = rangeEnd.difference(rangeStart, 'month').toInt().getInfo();
+var ERA5Palette = [
+  '000080',
+  '0000d9',
+  '4000ff',
+  '8000ff',
+  '0080ff',
+  '00ffff',
+  '00ff80',
+  '80ff00',
+  'daff00',
+  'ffff00',
+  'fff500',
+  'ffda00',
+  'ffb000',
+  'ffa400',
+  'ff4f00',
+  'ff2500',
+  'ff0a00',
+  'ff00ff',
+];
+
+for (var i = 0; i < numMonths; i++) {
+  var monthStart = rangeStart.advance(i, 'month');
+  var monthEnd = monthStart.advance(1, 'month');
+  var img = cropharvest_img(monthStart, monthEnd);
+  imgs.push(img);
+
+  var monthName = monthStart.format('YY/MM').getInfo();
+  Map.addLayer(
+    img,
+    {
+      bands: ['VV', 'VH', 'VV'],
+      min: [0, -0.2, 0.4],
+      max: [1.0, 0.8, 1.2],
+    },
+    monthName + ' S1',
+    false
+  );
+  Map.addLayer(
+    img,
+    { bands: ['B4', 'B3', 'B2'], min: 0, max: 0.25 },
+    monthName + ' S2',
+    false
+  );
+  Map.addLayer(
+    img,
+    { bands: ['temperature_2m'], min: 0, max: 1, palette: ERA5Palette },
+    monthName + ' ERA5',
+    false
+  );
+}
+Map.addLayer(imgs[1], { bands: ['slope'], min: 0, max: 0.3 }, 'SRTM', false);
+
+var composite = ee.ImageCollection.fromImages(imgs).toBands();
+
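+// Optional sanity check: the deployed handler expects 2 latlon bands plus
+// 17 bands per month, e.g. 2 + 12 * 17 = 206 bands for a 12-month range.
+print('Composite band count:', composite.bandNames().size());
+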
+// 9. Make predictions using Presto on Vertex AI
+var vertex_model = ee.Model.fromVertexAi({
+  endpoint: ENDPOINT,
+  inputTileSize: [1, 1],
+  proj: ee.Projection('EPSG:4326').atScale(10),
+  fixInputProj: true,
+  outputTileSize: [1, 1],
+  outputBands: { p: { type: ee.PixelType.float(), dimensions: 1 } },
+  payloadFormat: 'ND_ARRAYS',
+  maxPayloadBytes: 5242880, // 5,242,880 bytes = 5.24 MB (maximum allowed)
+});
+
+if (RUN_VERTEX_AI) {
+  // Create band names for embeddingsArrayImage
+  var bandNames = [];
+  for (var i = 0; i < 128; i++) {
+    bandNames.push('b' + i);
+  }
+
+  // embeddingsArrayImage is a single-band image where each pixel contains an array
+  var embeddingsArrayImage = vertex_model.predictImage(composite).clip(roi);
+  var embeddingsMultiBandImage = embeddingsArrayImage.arrayFlatten([
+    bandNames,
+  ]);
+
+  // Only smaller embeddings can be viewed directly in GEE immediately; larger ones require the batch task below
+  // Map.addLayer(embeddingsMultiBandImage, {min: 0, max: 1}, 'embeddingsMultiBandImage')
+
+  Export.image.toAsset({
+    image: embeddingsMultiBandImage,
+    description: 'Presto_embeddings',
+    assetId: 'Togo/Presto_test_embeddings_v2025_04_23',
+    region: roi,
+    scale: 10,
+    maxPixels: 1e12,
+    crs: PROJ,
+  });
+}
diff --git a/deploy/3_Kmeans_Embeddings.js b/deploy/3_Kmeans_Embeddings.js
new file mode 100644
index 0000000..bcd1b17
--- /dev/null
+++ b/deploy/3_Kmeans_Embeddings.js
@@ -0,0 +1,22 @@
+//------------------------------------------------------------------------------------
+// Script for checking embedding salience through clustering
+// Author: Ivan Zvonkov (izvonkov@umd.edu)
+//------------------------------------------------------------------------------------
+
+// 1. Load embeddings
+var embeddings = ee.Image(
+  'users/izvonkov/Togo/Presto_test_embeddings_v2025_05_16'
+);
+var roi = embeddings.geometry({ geodesics: true });
+Map.centerObject(roi, 11);
+
+// 2. Cluster embeddings and display (7 clusters)
+var training = embeddings.sample({ region: roi, scale: 10, numPixels: 10000 });
+var trainedClusterer = ee.Clusterer.wekaKMeans(7).train(training);
+var result = embeddings.cluster(trainedClusterer);
+Map.addLayer(result.randomVisualizer(), {}, 'clusters');
+
+// 3. Display WorldCover
+var WorldCover = ee.ImageCollection('ESA/WorldCover/v200').first().clip(roi);
+var vis = { bands: ['Map'] };
+Map.addLayer(WorldCover, vis, 'WorldCover');