diff --git a/README.md b/README.md
index 88fbb8a..e1cda23 100644
--- a/README.md
+++ b/README.md
@@ -169,6 +169,13 @@ dtype: int64
```
+## Generating Presto Embeddings in Google Earth Engine
+You can generate Presto embeddings in Google Earth Engine by:
+1. Deploying Presto to Vertex AI
+
+
+2. Using the [ee.Model.fromVertexAi](https://developers.google.com/earth-engine/apidocs/ee-model-fromvertexai) function in Google Earth Engine ([script on GEE](https://code.earthengine.google.com/1d196e8466506239c4780585c0e28d26))
+
## Reference
If you find this code useful, please cite the following paper:
```
diff --git a/deploy/1_Presto_to_VertexAI.ipynb b/deploy/1_Presto_to_VertexAI.ipynb
new file mode 100644
index 0000000..8048ec4
--- /dev/null
+++ b/deploy/1_Presto_to_VertexAI.ipynb
@@ -0,0 +1,634 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_SVA9v_JTq_-"
+ },
+ "source": [
+ "# 1. Presto to Vertex AI\n",
+ "\n",
+ "\n",
+ "
\n",
+ "\n",
+ "\n",
+ "**Authors**: Ivan Zvonkov, Gabriel Tseng, (additional credits: [Earth_Engine_PyTorch_Vertex_AI](https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_PyTorch_Vertex_AI.ipynb))\n",
+ "\n",
+ "**Description**: The notebook Deploys Presto to Vertex AI. This is a prerequisite to generating Presto embeddings on Google Earth Engine using\n",
+ "[ee.Model.fromVertexAi](https://developers.google.com/earth-engine/apidocs/ee-model-fromvertexai).\n",
+ "\n",
+ "Once the model is deployed this [GEE script](https://code.earthengine.google.com/df6348b8d47cd751eb5164dccb7b26a9) can be used to generate Presto embeddings.\n",
+ "\n",
+ "**Steps**:\n",
+ "1. Set up environment\n",
+ "2. Load default Presto model\n",
+ "3. Transform Presto model into TorchScript\n",
+ "4. Package TorchScript model into TorchServe\n",
+ "5. Deploy and use Vertex AI\n",
+ "\n",
+ " 5a. Upload TorchServe model to Vertex AI Model Registry [Free]\n",
+ "\n",
+ " 5b. Create a Vertex AI Endpoint [Free]\n",
+ "\n",
+ " 5c. Deploy model to endpoint [Cost depends on Minimum Replica Count parameter]\n",
+ "\n",
+ " 5d. Generate embeddings in Google Earth Engine [Cost depends on region size]\n",
+ "\n",
+ " 5e. Undeploy model from endpoint [Free]\n",
+ "\n",
+ "**Cost Breakdown**:\n",
+ "\n",
+ "*5a. Upload TorchServe model to Vertex AI Model Registry [Free]*\n",
+ "- Model files are uploaded to Cloud Storage but are lightweight (3.37 Mb total) and thus easily fall into Google Cloud's 5GB/month Storage [Free Tier](https://cloud.google.com/storage/pricing#cloud-storage-always-free)\n",
+ "- There is no cost to storing models in Vertex AI Model Registry ([source](https://cloud.google.com/vertex-ai/pricing#modelregistry))\n",
+ "\n",
+ "*5b. Create a Vertex AI Endpoint [Free]*\n",
+ "- There is no cost to creating an endpoint. Costs start when a model is deployed to that endpoint\n",
+ "\n",
+ "*5c. Deploy model to endpoint [Cost depends on Minimum Replica Count parameter]*\n",
+ "- The `Minimum Replica Count` represents the minimum amount of compute nodes started when a model is deployed is e2-standard-2 machine (\\$0.0771/node hour in us-central-1)\n",
+ "- So as long as the endpoint is active you will be paying \\$0.0771/hour even if no predictions are made\n",
+ "\n",
+ "*5d. Generate embeddings in Google Earth Engine [Cost depends on region size]*\n",
+ "- Once a model is deployed and `ee.model.fromVertexAi` is used Vertex AI scales the amount of nodes based on amount of data (size of the region)\n",
+ "- Our current embedding generation cost estimates are \\$5.37 - \\$10.14 per 1000 km2 \n",
+ "- We compute a cost estimate for your ROI in our Google Earth Engine script\n",
+ "\n",
+ "*5e. Undeploy model from endpoint [Free]*\n",
+ "- Necessary to stop incurring charges from 5c"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hzb1bwgTUZU0"
+ },
+ "source": [
+ "## 1. Set up environment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "KuoEjld3TTLO"
+ },
+ "outputs": [],
+ "source": [
+ "from google.colab import auth\n",
+ "\n",
+ "auth.authenticate_user()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "cUI_5pWJ3V4s"
+ },
+ "outputs": [],
+ "source": [
+ "PROJECT = ''\n",
+ "!gcloud config set project {PROJECT}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "MRGjYjltsm6-"
+ },
+ "outputs": [],
+ "source": [
+ "!git clone https://github.com/nasaharvest/presto.git"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "P1zGbf2KIhLA"
+ },
+ "source": [
+ "## 2. Load default Presto model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "UKgCxBNnYJIB"
+ },
+ "outputs": [],
+ "source": [
+ "# Navigate inside of the repository to import Presto\n",
+ "%cd /content/presto\n",
+ "\n",
+ "import torch\n",
+ "from single_file_presto import Presto\n",
+ "\n",
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
+ "\n",
+ "model = Presto.construct()\n",
+ "model.load_state_dict(torch.load(\"data/default_model.pt\", map_location=device))\n",
+ "model.eval();\n",
+ "\n",
+ "# Navigate back to main directory\n",
+ "%cd /content"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5v7gJoysFTsf"
+ },
+ "source": [
+ "## 3. Transform Presto model into TorchScript\n",
+ "> TorchScript is a way to create serializable and optimizable models from PyTorch code.\n",
+ "https://docs.pytorch.org/docs/stable/jit.html"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "hDKqqzi9F7T4"
+ },
+ "outputs": [],
+ "source": [
+ "# Construct input manually\n",
+ "batch_size = 256\n",
+ "NUM_TIMESTEPS = 12\n",
+ "X_tensor = torch.zeros([batch_size, NUM_TIMESTEPS, 17])\n",
+ "latlons_tensor = torch.zeros([batch_size, 2])\n",
+ "\n",
+ "dw_empty = torch.full([batch_size, NUM_TIMESTEPS], 9, device=device).long()\n",
+ "month_tensor = torch.full([batch_size], 1, device=device)\n",
+ "\n",
+ "# [0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 16 17 ]\n",
+ "# [VV, VH, B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12, temp, precip, elev, slope, NDVI]\n",
+ "mask = torch.zeros(X_tensor.shape, device=device).float()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "mdSXGZuikHUk"
+ },
+ "outputs": [],
+ "source": [
+ "# Verify forward pass with regular model\n",
+ "with torch.no_grad():\n",
+ " preds = model.encoder(\n",
+ " x=X_tensor,\n",
+ " dynamic_world=dw_empty,\n",
+ " latlons=latlons_tensor,\n",
+ " mask=mask,\n",
+ " month=month_tensor\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "nFRvUVkHowKr"
+ },
+ "outputs": [],
+ "source": [
+ "# Make model torchscriptable\n",
+ "example_kwargs = {\n",
+ " 'x': X_tensor,\n",
+ " 'dynamic_world': dw_empty,\n",
+ " 'latlons': latlons_tensor,\n",
+ " 'mask': mask,\n",
+ " 'month': month_tensor\n",
+ "}\n",
+ "sm = torch.jit.trace(model.encoder, example_kwarg_inputs=example_kwargs)\n",
+ "\n",
+ "!mkdir -p pytorch_model\n",
+ "sm.save('pytorch_model/model.pt')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "cYuSPOyp1A0K"
+ },
+ "outputs": [],
+ "source": [
+ "jit_model = torch.jit.load('pytorch_model/model.pt')\n",
+ "jit_model(**example_kwargs).shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3b0LsZpqnByv"
+ },
+ "source": [
+ "## 4. Package TorchScript model into TorchServe\n",
+ "> TorchServe is a performant, flexible and easy to use tool for serving PyTorch models in production.\n",
+ "https://docs.pytorch.org/serve/"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "i70o_BZml9vs"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install torchserve torch-model-archiver -q"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "htq2Ac95FJlk"
+ },
+ "outputs": [],
+ "source": [
+ "%%writefile pytorch_model/custom_handler.py\n",
+ "import logging\n",
+ "import torch\n",
+ "from ts.torch_handler.base_handler import BaseHandler\n",
+ "import numpy as np\n",
+ "\n",
+ "# UPDATE BASED ON YOUR NEEDS\n",
+ "########################################\n",
+ "VERSION = \"v1\"\n",
+ "START_MONTH = 3\n",
+ "BATCH_SIZE = 256\n",
+ "########################################\n",
+ "\n",
+ "def printh(text):\n",
+ " # Prepends HANDLER to each print statement to make it easier to find in logs.\n",
+ " print(f\"HANDLER {VERSION}: {text}\")\n",
+ "\n",
+ "# Custom TorchServe handler for the Presto model\n",
+ "class ClassifierHandler(BaseHandler):\n",
+ "\n",
+ " def inference(self, data):\n",
+ " printh(\"Inference begin\")\n",
+ "\n",
+ " # Data shape: [ num_pixels, composite_bands, 1, 1 ]\n",
+ " data = data[:, :, 0, 0]\n",
+ " printh(f\"Data shape {data.shape}\")\n",
+ "\n",
+ " num_bands = 17\n",
+ " printh(f\"Num_bands {num_bands}\")\n",
+ "\n",
+ " # Subtract first two latlon\n",
+ " num_timesteps = (data.shape[1] - 2) // num_bands\n",
+ " printh(f\"Num_timesteps {num_timesteps}\")\n",
+ "\n",
+ " with torch.no_grad():\n",
+ "\n",
+ " batches = torch.split(data, BATCH_SIZE, dim=0)\n",
+ "\n",
+ " # month: An int or torch.Tensor describing the first month of the instances being passed. If an int, all instances in the batch are assumed to have the same starting month.\n",
+ " month_tensor = torch.full([BATCH_SIZE], START_MONTH, device=self.device)\n",
+ " printh(f\"Month: {START_MONTH}\")\n",
+ "\n",
+ " # dynamic_world: torch.Tensor of shape [BATCH_SIZE, num_timesteps]. If no Dynamic World classes are available, this tensor should be filled with the value DynamicWorld2020_2021.class_amount (i.e. 9), in which case it is ignored.\n",
+ " dw_empty = torch.full([BATCH_SIZE, num_timesteps], 9, device=self.device).long()\n",
+ " printh(f\"DW {dw_empty[0]}\")\n",
+ "\n",
+ " # mask: An optional torch.Tensor of shape [BATCH_SIZE, num_timesteps, bands]. mask[i, j, k] == 1 means x[i, j, k] is considered masked. If the mask is None, no values in x are ignored.\n",
+ " mask = torch.zeros((BATCH_SIZE, num_timesteps, num_bands), device=self.device).float()\n",
+ " printh(f\"Mask sample one timestep: {mask[0, 0]}\")\n",
+ "\n",
+ " preds_list = []\n",
+ " for batch in batches:\n",
+ " padding = 0\n",
+ " if batch.shape[0] < BATCH_SIZE:\n",
+ " padding = BATCH_SIZE - batch.shape[0]\n",
+ " batch = torch.cat([batch, torch.zeros([padding, batch.shape[1]], device=self.device)])\n",
+ "\n",
+ " # x: torch.Tensor of shape [BATCH_SIZE, num_timesteps, bands] where bands is described by NORMED_BANDS.\n",
+ " X_tensor = batch[:, 2:]\n",
+ " printh(f\"X {X_tensor.shape}\")\n",
+ "\n",
+ " X_tensor_reshaped = X_tensor.reshape(BATCH_SIZE, num_timesteps, num_bands)\n",
+ " printh(f\"X sample one timestep: {X_tensor_reshaped[0, 0]}\")\n",
+ "\n",
+ " # latlons: torch.Tensor of shape [BATCH_SIZE, 2] describing the latitude and longitude of each input instance.\n",
+ " latlons_tensor = batch[:, :2]\n",
+ "\n",
+ " printh(\"SHAPES\")\n",
+ " printh(f\"X {X_tensor_reshaped.shape}\")\n",
+ " printh(f\"DW {dw_empty.shape}\")\n",
+ " printh(f\"Latlons {latlons_tensor.shape}\")\n",
+ " printh(f\"Mask {mask.shape}\")\n",
+ " printh(f\"Month {month_tensor.shape}\")\n",
+ "\n",
+ " pred = self.model(\n",
+ " x=X_tensor_reshaped,\n",
+ " dynamic_world=dw_empty,\n",
+ " latlons=latlons_tensor,\n",
+ " mask=mask,\n",
+ " month=month_tensor\n",
+ " )\n",
+ " pred_np = np.expand_dims(pred.numpy(), axis=[1,2])\n",
+ " if padding == 0:\n",
+ " preds_list.append(pred_np[:])\n",
+ " else:\n",
+ " preds_list.append(pred_np[:-padding])\n",
+ "\n",
+ " [printh(f\"{p.shape}\") for p in preds_list]\n",
+ " preds = np.concatenate(preds_list)\n",
+ " printh(f\"Preds shape {preds.shape}\")\n",
+ " return preds\n",
+ "\n",
+ " def handle(self, data, context):\n",
+ " self.context = context\n",
+ " printh(f\"Handle begin\")\n",
+ " input_tensor = self.preprocess(data)\n",
+ " printh(f\"Input_tensor shape {input_tensor.shape}\")\n",
+ " pred_out = self.inference(input_tensor)\n",
+ " printh(f\"Inference complete\")\n",
+ " return self.postprocess(pred_out)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "a3Dgq5Ob5b1i"
+ },
+ "outputs": [],
+ "source": [
+ "import importlib\n",
+ "import pytorch_model.custom_handler\n",
+ "\n",
+ "importlib.reload(pytorch_model.custom_handler)\n",
+ "\n",
+ "from pytorch_model.custom_handler import ClassifierHandler, VERSION\n",
+ "\n",
+ "# Test output\n",
+ "data = torch.zeros([713, 206, 1, 1])\n",
+ "handler = ClassifierHandler()\n",
+ "handler.model = jit_model\n",
+ "preds = handler.handle(data=data, context=None)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "90TXNAnfF-TD"
+ },
+ "outputs": [],
+ "source": [
+ "!torch-model-archiver -f \\\n",
+ " --model-name model \\\n",
+ " --version 1.0 \\\n",
+ " --serialized-file 'pytorch_model/model.pt' \\\n",
+ " --handler 'pytorch_model/custom_handler.py' \\\n",
+ " --export-path pytorch_model/"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PYK9g3r1qVru"
+ },
+ "source": [
+ "## 5. Deploy and use Vertex AI"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "CH0JO_Jww5Ok"
+ },
+ "source": [
+ "### 5a. Upload TorchServe model to Vertex AI Model Registry\n",
+ "> The Vertex AI Model Registry is a central repository where you can manage the lifecycle of your ML models.\n",
+ "https://cloud.google.com/vertex-ai/docs/model-registry/introduction"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "lEUqcAqsTYpn"
+ },
+ "outputs": [],
+ "source": [
+ "REGION = 'us-central1'\n",
+ "BUCKET_NAME = \"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "OiyyCNa-TBGQ"
+ },
+ "outputs": [],
+ "source": [
+ "# Create bucket to store model artifcats if it doesn't exist\n",
+ "!gcloud storage buckets create gs://{BUCKET_NAME} --location={REGION}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "n8m9UBy3GEvZ"
+ },
+ "outputs": [],
+ "source": [
+ "MODEL_DIR = f'gs://{BUCKET_NAME}/{VERSION}'\n",
+ "!gsutil cp -r pytorch_model {MODEL_DIR}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "AetRF8dcGraC"
+ },
+ "outputs": [],
+ "source": [
+ "# Can take 2 minutes\n",
+ "MODEL_NAME = f'model_{VERSION}'\n",
+ "CONTAINER_IMAGE = 'us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-4:latest'\n",
+ "\n",
+ "!gcloud ai models upload \\\n",
+ " --artifact-uri={MODEL_DIR} \\\n",
+ " --region={REGION} \\\n",
+ " --container-image-uri={CONTAINER_IMAGE} \\\n",
+ " --description={MODEL_NAME} \\\n",
+ " --display-name={MODEL_NAME} \\\n",
+ " --model-id={MODEL_NAME}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_BXDtoITxY2T"
+ },
+ "source": [
+ "### 5b. Create a Vertex AI Endpoint\n",
+ "> To deploy a model for online prediction, you need an endpoint.\n",
+ "https://cloud.google.com/vertex-ai/docs/predictions/choose-endpoint-type\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "XhMK9aA73FzI"
+ },
+ "outputs": [],
+ "source": [
+ "ENDPOINT_NAME = 'vertex-pytorch-presto-endpoint'\n",
+ "\n",
+ "endpoints = !gcloud ai endpoints list --region={REGION} --format='get(DISPLAY_NAME)'\n",
+ "\n",
+ "if ENDPOINT_NAME in endpoints:\n",
+ " print(f\"Endpoint: '{ENDPOINT_NAME}' already exists skipping endpoint creation.\")\n",
+ "else:\n",
+ " print(f\"Endpoint: '{ENDPOINT_NAME}' does not exist, creating... (~3 minutes)\")\n",
+ " !gcloud ai endpoints create \\\n",
+ " --display-name={ENDPOINT_NAME} \\\n",
+ " --endpoint-id={ENDPOINT_NAME} \\\n",
+ " --region={REGION}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "dWODP6ccx8-3"
+ },
+ "source": [
+ "### 5c. Deploy model to endpoint\n",
+ "> Deploying a model associates physical resources with the model so that it can serve online predictions with low latency.\n",
+ "https://cloud.google.com/vertex-ai/docs/general/deployment\n",
+ "\n",
+ "⚠️ The `Minimum Replica Count` represents the minimum amount of compute nodes started when a model is deployed is e2-standard-2 machine (\\$0.0771/node hour in us-central-1). So as long as the endpoint is active you will be paying \\$0.0771/hour even if no predictions are made."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "sI297v1mjtR8"
+ },
+ "outputs": [],
+ "source": [
+ "# Deploy model to endpoint, this will start an e2-standard-2 machine which costs money\n",
+ "print(\"Track model deployment progress and prediction logs:\")\n",
+ "print(f\"https://console.cloud.google.com/vertex-ai/online-prediction/locations/{REGION}/endpoints/{ENDPOINT_NAME}?project={PROJECT}\\n\")\n",
+ "\n",
+ "# If using for large region, set min-replica-count higher to save scaling time\n",
+ "# Can take from 4-27 minutes\n",
+ "# Relevant quota: \"Custom model serving CPUs per region\"\n",
+ "!gcloud ai endpoints deploy-model {ENDPOINT_NAME} \\\n",
+ " --region={REGION} \\\n",
+ " --model={MODEL_NAME} \\\n",
+ " --display-name={MODEL_NAME} \\\n",
+ " --machine-type=\"e2-standard-2\" \\\n",
+ " --min-replica-count='1' \\\n",
+ " --max-replica-count=\"100\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "r2VtGUv9JOI9"
+ },
+ "source": [
+ "### 5d. Generate embeddings in Google Earth Engine\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "160PNcRRJMzn"
+ },
+ "outputs": [],
+ "source": [
+ "GEE_SCRIPT_URL = \"https://code.earthengine.google.com/c239905f788f67ecf0cee42753893d1c\"\n",
+ "print(f\"Open this script: {GEE_SCRIPT_URL}\")\n",
+ "print(\"Use the below string for the ENDPOINT variable\")\n",
+ "print(f\"projects/{PROJECT}/locations/{REGION}/endpoints/{ENDPOINT_NAME}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8PnhA3gfHSrY"
+ },
+ "source": [
+ "### 5e. Undeploy model from endpoint\n",
+ "\n",
+ "Once predictions are made, you must undeploy your model to stop incurring further charges.\n",
+ "\n",
+ "This can be done using the below code or by using the Google Cloud console directly."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "VvOO_sfQDWPt"
+ },
+ "outputs": [],
+ "source": [
+ "def get_deployed_model():\n",
+ " deployed_models = !gcloud ai endpoints describe {ENDPOINT_NAME} --region={REGION} --format 'get(deployedModels)'\n",
+ " if deployed_models[1] == '':\n",
+ " print(\"No models deployed\")\n",
+ " else:\n",
+ " print(deployed_model_id)\n",
+ " return eval(deployed_models[1])['id']\n",
+ "\n",
+ "deployed_model_id = get_deployed_model()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "vswFTu9kFeHy"
+ },
+ "outputs": [],
+ "source": [
+ "!gcloud ai endpoints undeploy-model {ENDPOINT_NAME} --region={REGION} --deployed-model-id={deployed_model_id}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "w9Fq6yspF2ye"
+ },
+ "outputs": [],
+ "source": [
+ "get_deployed_model()"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/deploy/2_Generate_Embeddings.js b/deploy/2_Generate_Embeddings.js
new file mode 100644
index 0000000..49f78cb
--- /dev/null
+++ b/deploy/2_Generate_Embeddings.js
@@ -0,0 +1,246 @@
+//------------------------------------------------------------------------------------
+// Script for generating Presto embeddings using Vertex AI
+// Author: Ivan Zvonkov (izvonkov@umd.edu)
+//------------------------------------------------------------------------------------
+// 1. Presto embedding generation parameters (set parameters according to your needs)
+//------------------------------------------------------------------------------------
+var roi = ee
+ .FeatureCollection('FAO/GAUL/2015/level2')
+ .filter(ee.Filter.eq('ADM2_NAME', 'Haho'));
+var PROJ = 'EPSG:25231';
+
+var rangeStart = ee.Date('2019-03-01');
+var rangeEnd = ee.Date('2020-03-01');
+
+var ENDPOINT =
+ 'projects/presto-deployment/locations/us-central1/endpoints/vertex-pytorch-presto-endpoint';
+var RUN_VERTEX_AI = false; // Leave this as false to get a cost estimate first
+//------------------------------------------------------------------------------------
+
+Map.centerObject(roi, 10);
+Map.addLayer(roi, {}, 'Region of Interest');
+Map.setOptions('satellite');
+
+// 2. Cost Computation
+var roiAreaKM2 = roi.geometry().area().divide(1e6);
+function estimate(cost) {
+ return roiAreaKM2.divide(1000).multiply(cost).toInt().getInfo();
+}
+print('ROI Area: ' + roiAreaKM2.toInt().getInfo() + ' km2');
+print(
+ 'Embedding Generation Estimates\nCost: $' +
+ estimate(5.37) +
+ '-' +
+ estimate(10.14)
+);
+if (!RUN_VERTEX_AI)
+ print(
+ 'If you are ready to generate embeddings,\nchange RUN_VERTEX_AI variable to true'
+ );
+
+// 3. Obtain monthly Sentinel-1 composites
+var S1_BANDS = ['VV', 'VH'];
+var S1_all = ee
+ .ImageCollection('COPERNICUS/S1_GRD')
+ .filterBounds(roi)
+ .filterDate(
+ ee.Date(rangeStart).advance(-31, 'days'),
+ ee.Date(rangeEnd).advance(31, 'days')
+ );
+
+var S1 = S1_all.filter(
+ ee.Filter.eq(
+ 'orbitProperties_pass',
+ S1_all.first().get('orbitProperties_pass')
+ )
+).filter(ee.Filter.eq('instrumentMode', 'IW'));
+var S1_VV = S1.filter(
+ ee.Filter.listContains('transmitterReceiverPolarisation', 'VV')
+);
+var S1_VH = S1.filter(
+ ee.Filter.listContains('transmitterReceiverPolarisation', 'VH')
+);
+
+function getCloseImages(middleDate, imageCollection) {
+ var fromMiddleDate = imageCollection
+ .map(function (img) {
+ var dateDist = ee
+ .Number(img.get('system:time_start'))
+ .subtract(middleDate.millis())
+ .abs();
+ return img.set('dateDist', dateDist);
+ })
+ .sort({ property: 'dateDist', ascending: true });
+ var fifteenDaysInMs = ee.Number(1296000000);
+ var maxDiff = ee
+ .Number(fromMiddleDate.first().get('dateDist'))
+ .max(fifteenDaysInMs);
+ return fromMiddleDate.filterMetadata(
+ 'dateDist',
+ 'not_greater_than',
+ maxDiff
+ );
+}
+
+function S1_img(date1, date2) {
+ var startDate = ee.Date(date1);
+ var daysBetween = ee.Date(date2).difference(startDate, 'days');
+ var middleDate = startDate.advance(daysBetween.divide(2), 'days');
+ var kept_vv = getCloseImages(middleDate, S1_VV).select('VV');
+ var kept_vh = getCloseImages(middleDate, S1_VH).select('VH');
+ var S1_composite = ee.Image.cat([kept_vv.median(), kept_vh.median()]);
+ return S1_composite.select(S1_BANDS).add(25.0).divide(25.0); // S1 ranges from -50 to 1
+}
+
+// 4. Obtain monthly Sentinel-2 composites
+var S2_BANDS = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12'];
+var S2 = ee
+ .ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
+ .filterBounds(roi)
+ .filterDate(rangeStart, rangeEnd);
+var csPlus = ee
+ .ImageCollection('GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED')
+ .filterBounds(roi)
+ .filterDate(rangeStart, rangeEnd);
+var QA_BAND = 'cs_cdf'; // Better than cs here
+var S2_cf = S2.linkCollection(csPlus, [QA_BAND]);
+
+function S2_img(date1, date2) {
+ return S2_cf.filterDate(date1, date2)
+ .qualityMosaic(QA_BAND)
+ .select(S2_BANDS)
+ .divide(ee.Image(1e4));
+}
+
+// 5. Obtain monthly ERA5 composites
+var ERA5_BANDS = ['temperature_2m', 'total_precipitation_sum'];
+var ERA5 = ee
+ .ImageCollection('ECMWF/ERA5_LAND/MONTHLY_AGGR')
+ .filterBounds(roi)
+ .filterDate(rangeStart, rangeEnd);
+function ERA5_img(date1, date2) {
+ return ERA5.filterDate(date1, date2)
+ .select(ERA5_BANDS)
+ .mean()
+ .add([-272.15, 0])
+ .divide([35, 0.03]);
+}
+//var ERA5_temp = ee.Image([0,0]).rename(ERA5_BANDS).clip(roi)
+
+// 6. Obtain SRTM Data
+var SRTM_BANDS = ['elevation', 'slope'];
+var elevation = ee.Image('USGS/SRTMGL1_003').clip(roi).select('elevation');
+var slope = ee.Terrain.slope(elevation);
+var SRTM_img = ee.Image.cat([elevation, slope]).toDouble().divide([2000, 50]);
+//var SRTM_temp = ee.Image([0,0]).rename(SRTM_BANDS).clip(roi)
+
+// 7. Combine all data into a monthly CropHarvest-style monthly composite
+function cropharvest_img(d1, d2) {
+ var img = ee.Image.cat([
+ S1_img(d1, d2),
+ S2_img(d1, d2),
+ ERA5_img(d1, d2),
+ SRTM_img,
+ ]);
+ var ndvi = img.normalizedDifference(['B8', 'B4']).rename('NDVI');
+ // toFloat Necessary for tensor conversion
+ return img.addBands(ndvi).clip(roi).toFloat();
+}
+
+// 8. Create and visualize Presto input
+var latlons = ee.Image.pixelLonLat().clip(roi).select('latitude', 'longitude');
+var imgs = [latlons];
+var numMonths = rangeEnd.difference(rangeStart, 'month').toInt().getInfo();
+var ERA5Palette = [
+ '000080',
+ '0000d9',
+ '4000ff',
+ '8000ff',
+ '0080ff',
+ '00ffff',
+ '00ff80',
+ '80ff00',
+ 'daff00',
+ 'ffff00',
+ 'fff500',
+ 'ffda00',
+ 'ffb000',
+ 'ffa400',
+ 'ff4f00',
+ 'ff2500',
+ 'ff0a00',
+ 'ff00ff',
+];
+
+for (var i = 0; i < numMonths; i++) {
+ var monthStart = rangeStart.advance(i, 'month');
+ var monthEnd = monthStart.advance(1, 'month');
+ var img = cropharvest_img(monthStart, monthEnd);
+ imgs.push(img);
+
+ var monthName = monthStart.format('YY/MM').getInfo();
+ Map.addLayer(
+ img,
+ {
+ bands: ['VV', 'VH', 'VV'],
+ min: [0, -0.2, 0.4],
+ max: [1.0, 0.8, 1.2],
+ },
+ monthName + ' S1',
+ false
+ );
+ Map.addLayer(
+ img,
+ { bands: ['B4', 'B3', 'B2'], min: 0, max: 0.25 },
+ monthName + ' S2',
+ false
+ );
+ Map.addLayer(
+ img,
+ { bands: ['temperature_2m'], min: 0, max: 1, palette: ERA5Palette },
+ monthName + ' ERA5',
+ false
+ );
+}
+Map.addLayer(imgs[1], { bands: ['slope'], min: 0, max: 0.3 }, 'SRTM', false);
+
+var composite = ee.ImageCollection.fromImages(imgs).toBands();
+
+// 9. Make predictions using Presto on Vertex AI
+var vertex_model = ee.Model.fromVertexAi({
+ endpoint: ENDPOINT,
+ inputTileSize: [1, 1],
+ proj: ee.Projection('EPSG:4326').atScale(10),
+ fixInputProj: true,
+ outputTileSize: [1, 1],
+ outputBands: { p: { type: ee.PixelType.float(), dimensions: 1 } },
+ payloadFormat: 'ND_ARRAYS',
+ maxPayloadBytes: 5242880, // 5.24mb [MAX]
+});
+
+if (RUN_VERTEX_AI) {
+ // Create band names for embeddingsArrayImage
+ var bandNames = [];
+ for (var i = 0; i < 128; i++) {
+ bandNames.push('b' + i + '');
+ }
+
+ // embeddingsArrayImage is a single band image where each pixel contains an array
+ var embeddingsArrayImage = vertex_model.predictImage(composite).clip(roi);
+ var embeddingsMultiBandImage = embeddingsArrayImage.arrayFlatten([
+ bandNames,
+ ]);
+
+ // Only smaller size embeddings can be directly viewed in GEE immediatley larger ones require the batch task
+ // Map.addLayer(embeddingsMultiBandImage, {min: 0, max: 1},'embeddingsMultiBandImage')
+
+ Export.image.toAsset({
+ image: embeddingsMultiBandImage,
+ description: 'Presto_embeddings',
+ assetId: 'Togo/Presto_test_embeddings_v2025_04_23',
+ region: roi,
+ scale: 10,
+ maxPixels: 1e12,
+ crs: 'EPSG:25231',
+ });
+}
diff --git a/deploy/3_Kmeans_Embeddings.js b/deploy/3_Kmeans_Embeddings.js
new file mode 100644
index 0000000..bcd1b17
--- /dev/null
+++ b/deploy/3_Kmeans_Embeddings.js
@@ -0,0 +1,22 @@
+//------------------------------------------------------------------------------------
+// Script for checking the embedding salience through clutering
+// Author: Ivan Zvonkov (izvonkov@umd.edu)
+//------------------------------------------------------------------------------------
+
+// 1. Load embeddings
+var embeddings = ee.Image(
+ 'users/izvonkov/Togo/Presto_test_embeddings_v2025_05_16'
+);
+var roi = embeddings.geometry({ geodesics: true });
+Map.centerObject(roi, 11);
+
+// 2. Cluster embeddings and display (7 clusters)
+var training = embeddings.sample({ region: roi, scale: 10, numPixels: 10000 });
+var trainedClusterer = ee.Clusterer.wekaKMeans(7).train(training);
+var result = embeddings.cluster(trainedClusterer);
+Map.addLayer(result.randomVisualizer(), {}, 'clusters');
+
+// 3. Display WorldCover
+var WorldCover = ee.ImageCollection('ESA/WorldCover/v200').first().clip(roi);
+var vis = { bands: ['Map'] };
+Map.addLayer(WorldCover, vis, 'WorldCover');