From 470d680736bcb7a2b1277fda73e4dde900d6c234 Mon Sep 17 00:00:00 2001 From: ivanzvonkov Date: Tue, 15 Apr 2025 08:22:02 -0400 Subject: [PATCH 1/6] Add initial deploy code --- deploy/Presto_to_VertexAI.ipynb | 468 ++++++++++++++++++++++++++++++++ deploy/test_embeddings.js | 251 +++++++++++++++++ 2 files changed, 719 insertions(+) create mode 100644 deploy/Presto_to_VertexAI.ipynb create mode 100644 deploy/test_embeddings.js diff --git a/deploy/Presto_to_VertexAI.ipynb b/deploy/Presto_to_VertexAI.ipynb new file mode 100644 index 0000000..e881694 --- /dev/null +++ b/deploy/Presto_to_VertexAI.ipynb @@ -0,0 +1,468 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Presto in EarthEngine\n", + "\n", + "**Authors**: Ivan Zvonkov, Gabriel Tseng\n", + "\n", + "**Description**:\n", + "1. Loads default Presto model.\n", + "2. 
Deploys default model to Vertex AI.\n", + "\n", + "Inspired by: https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_PyTorch_Vertex_AI.ipynb\n", + "\n", + "**Running this demo may incur charges to your Google Cloud Account!**" + ], + "metadata": { + "id": "_SVA9v_JTq_-" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Set up" + ], + "metadata": { + "id": "hzb1bwgTUZU0" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KuoEjld3TTLO" + }, + "outputs": [], + "source": [ + "from google.colab import auth\n", + "\n", + "import ee\n", + "import google\n", + "\n", + "# REPLACE WITH YOUR CLOUD PROJECT!\n", + "PROJECT = 'presto-deployment'\n", + "\n", + "# Authenticate the notebook.\n", + "auth.authenticate_user()\n", + "\n", + "# Authenticate to Earth Engine.\n", + "credentials, _ = google.auth.default()\n", + "ee.Initialize(credentials, project=PROJECT, opt_url='https://earthengine-highvolume.googleapis.com')\n", + "\n", + "# Set the gcloud project for Vertex AI deployment.\n", + "!gcloud config set project {PROJECT}" + ] + }, + { + "cell_type": "code", + "source": [ + "!git clone https://github.com/nasaharvest/presto.git" + ], + "metadata": { + "id": "MRGjYjltsm6-" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## 1. 
Load model" + ], + "metadata": { + "id": "P1zGbf2KIhLA" + } + }, + { + "cell_type": "code", + "source": [ + "%cd /content/presto" + ], + "metadata": { + "id": "WIJKe-vAs6__" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from torch.utils.data import TensorDataset\n", + "from torch.utils.data import DataLoader\n", + "import torch\n", + "import torch.optim as optim\n", + "import torch.nn as nn\n", + "import numpy as np\n", + "\n", + "from single_file_presto import Presto\n", + "\n", + "\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" + ], + "metadata": { + "id": "UKgCxBNnYJIB" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "model = Presto.construct()\n", + "model.load_state_dict(torch.load(\"data/default_model.pt\", map_location=device))\n", + "model.eval();" + ], + "metadata": { + "id": "uF72fMTE1fIS" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Sanity check\n", + "\n", + "# from presto.eval import CropHarvestEval\n", + "\n", + "# togo_eval = CropHarvestEval(\"Togo\", ignore_dynamic_world=False, num_timesteps=12, seed=0)\n", + "# results = togo_eval.finetuning_results(model, model_modes=[\"Regression\"])\n", + "# results\n", + "\n", + "# batch_size = 8\n", + "# X_np, dw_np, latlons_np, y_np = togo_eval.dataset.as_array(num_samples=batch_size)\n", + "# month_np = np.array([togo_eval.dataset.start_month] * batch_size)" + ], + "metadata": { + "id": "HE5KeE4G17fm" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Construct input manually\n", + "batch_size = 256\n", + "\n", + "X_tensor = torch.zeros([batch_size, 12, 17])\n", + "latlons_tensor = torch.zeros([batch_size, 2])\n", + "\n", + "dw_empty = torch.full([batch_size, 12], 9, device=device).long()\n", + "month_tensor = torch.full([batch_size], 1, device=device)\n", + "\n", + "# [0 1 
2 3 4 5 6 7 8 9 10 11 12 13 14 16 17 ]\n", + "# [VV, VH, B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12, temp, precip, elev, slope, NDVI]\n", + "mask = torch.zeros(X_tensor.shape, device=device).float()" + ], + "metadata": { + "id": "hDKqqzi9F7T4" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "with torch.no_grad():\n", + " preds = model.encoder(\n", + " x=X_tensor,\n", + " dynamic_world=dw_empty,\n", + " latlons=latlons_tensor,\n", + " mask=mask,\n", + " month=month_tensor\n", + " )\n", + "preds" + ], + "metadata": { + "id": "jQVWkY1nBClT" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Deploy to Vertex AI" + ], + "metadata": { + "id": "5v7gJoysFTsf" + } + }, + { + "cell_type": "code", + "source": [ + "%cd .." + ], + "metadata": { + "id": "GjB3wVLDGbav" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!pip install torchserve torch-model-archiver -q\n", + "!mkdir pytorch_model" + ], + "metadata": { + "id": "Xz53_ow1ZWql" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from ts.torch_handler.base_handler import BaseHandler\n", + "\n", + "# Make model torchscriptable\n", + "example_kwargs = {\n", + " 'x': X_tensor,\n", + " 'dynamic_world': dw_empty,\n", + " 'latlons': latlons_tensor,\n", + " 'mask': mask,\n", + " 'month': month_tensor\n", + "}\n", + "sm = torch.jit.trace(model.encoder, example_kwarg_inputs=example_kwargs)\n", + "\n", + "!mkdir -p pytorch_model\n", + "sm.save('pytorch_model/model.pt')" + ], + "metadata": { + "id": "nFRvUVkHowKr" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "jit_model = torch.jit.load('pytorch_model/model.pt')\n", + "jit_model(**example_kwargs).shape" + ], + "metadata": { + "id": "cYuSPOyp1A0K" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + 
"%%writefile pytorch_model/custom_handler.py\n", + "\n", + "import logging\n", + "\n", + "import torch\n", + "from ts.torch_handler.base_handler import BaseHandler\n", + "import numpy as np\n", + "import sys\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "version = \"v27\"\n", + "batch_size = 256\n", + "\n", + "def printh(text):\n", + " print(f\"HANDLER {version}: {text}\")\n", + "\n", + "class ClassifierHandler(BaseHandler):\n", + "\n", + " def inference(self, data):\n", + " printh(\"inference begin\")\n", + "\n", + " # Data shape: [ num_pixels, composite_bands, 1, 1 ]\n", + " data = data[:, :, 0, 0]\n", + " printh(f\"data shape {data.shape}\")\n", + "\n", + " num_bands = 17\n", + " printh(f\"num_bands {num_bands}\")\n", + "\n", + " # Subtract first two latlon\n", + " num_timesteps = (data.shape[1] - 2) // num_bands\n", + " printh(f\"num_timesteps {num_timesteps}\")\n", + "\n", + " with torch.no_grad():\n", + "\n", + " batches = torch.split(data, batch_size, dim=0)\n", + "\n", + " # month: An int or torch.Tensor describing the first month of the instances being passed. If an int, all instances in the batch are assumed to have the same starting month.\n", + " month_tensor = torch.full([batch_size], 3, device=self.device)\n", + " printh(f\"month: 3\")\n", + "\n", + " # dynamic_world: torch.Tensor of shape [batch_size, num_timesteps]. If no Dynamic World classes are available, this tensor should be filled with the value DynamicWorld2020_2021.class_amount (i.e. 9), in which case it is ignored.\n", + " dw_empty = torch.full([batch_size, num_timesteps], 9, device=self.device).long()\n", + " printh(f\"dw {dw_empty[0]}\")\n", + "\n", + " # mask: An optional torch.Tensor of shape [batch_size, num_timesteps, bands]. mask[i, j, k] == 1 means x[i, j, k] is considered masked. 
If the mask is None, no values in x are ignored.\n", + " mask = torch.zeros((batch_size, num_timesteps, num_bands), device=self.device).float()\n", + " printh(f\"mask sample one timestep: {mask[0, 0]}\")\n", + "\n", + " preds = []\n", + " for batch in batches:\n", + "\n", + " padding = 0\n", + " if batch.shape[0] < batch_size:\n", + " padding = batch_size - batch.shape[0]\n", + " batch = torch.cat([batch, torch.zeros([padding, batch.shape[1]], device=self.device)])\n", + "\n", + " # x: torch.Tensor of shape [batch_size, num_timesteps, bands] where bands is described by NORMED_BANDS.\n", + " X_tensor = batch[:, 2:]\n", + " printh(f\"X {X_tensor.shape}\")\n", + "\n", + " X_tensor_reshaped = X_tensor.reshape(batch_size, num_timesteps, num_bands)\n", + " printh(f\"X sample one timestep: {X_tensor_reshaped[0, 0]}\")\n", + "\n", + " # latlons: torch.Tensor of shape [batch_size, 2] describing the latitude and longitude of each input instance.\n", + " latlons_tensor = batch[:, :2]\n", + "\n", + " # Shapes\n", + " printh(\"SHAPES\")\n", + " printh(f\"X {X_tensor_reshaped.shape}\")\n", + " printh(f\"dw {dw_empty.shape}\")\n", + " printh(f\"latlons {latlons_tensor.shape}\")\n", + " printh(f\"mask {mask.shape}\")\n", + " printh(f\"month {month_tensor.shape}\")\n", + "\n", + " pred = self.model(\n", + " x=X_tensor_reshaped,\n", + " dynamic_world=dw_empty,\n", + " latlons=latlons_tensor,\n", + " mask=mask,\n", + " month=month_tensor\n", + " )\n", + " pred_np = np.expand_dims(pred.numpy(), axis=[1,2])\n", + " if padding == 0:\n", + " preds.append(pred_np[:])\n", + " else:\n", + " preds.append(pred_np[:-padding])\n", + "\n", + " [printh(f\"{p.shape}\") for p in preds]\n", + " preds = np.concatenate(preds)\n", + " printh(f\"preds shape {preds.shape}\")\n", + " return preds\n", + "\n", + " def handle(self, data, context):\n", + " self.context = context\n", + " printh(f\"handle begin\")\n", + " input_tensor = self.preprocess(data)\n", + " printh(f\"input_tensor shape 
{input_tensor.shape}\")\n", + " pred_out = self.inference(input_tensor)\n", + " return self.postprocess(pred_out)" + ], + "metadata": { + "id": "htq2Ac95FJlk" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import importlib\n", + "import pytorch_model\n", + "from pytorch_model.custom_handler import ClassifierHandler\n", + "importlib.reload(pytorch_model.custom_handler)\n", + "\n", + "from pytorch_model.custom_handler import ClassifierHandler\n", + "\n", + "# Test output\n", + "data = torch.zeros([713, 206, 1, 1])\n", + "handler = ClassifierHandler()\n", + "handler.model = jit_model\n", + "preds = handler.handle(data=data, context=None)" + ], + "metadata": { + "id": "a3Dgq5Ob5b1i" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!torch-model-archiver -f \\\n", + " --model-name model \\\n", + " --version 1.0 \\\n", + " --serialized-file 'pytorch_model/model.pt' \\\n", + " --handler 'pytorch_model/custom_handler.py' \\\n", + " --export-path pytorch_model/" + ], + "metadata": { + "id": "90TXNAnfF-TD" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "version = \"v27\"\n", + "MODEL_DIR = f'gs://presto-models/default_v2025_04_10_{version}'\n", + "!gsutil cp -r pytorch_model {MODEL_DIR}" + ], + "metadata": { + "id": "n8m9UBy3GEvZ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "REGION = 'us-central1'\n", + "MODEL_NAME = f'model_{version}'\n", + "CONTAINER_IMAGE = 'us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-4:latest'\n", + "ENDPOINT_NAME = 'vertex-pytorch-presto-endpoint'\n", + "\n", + "!gcloud ai models upload \\\n", + " --artifact-uri={MODEL_DIR} \\\n", + " --project={PROJECT} \\\n", + " --region={REGION} \\\n", + " --container-image-uri={CONTAINER_IMAGE} \\\n", + " --description={MODEL_NAME} \\\n", + " --display-name={MODEL_NAME} \\\n", + " --model-id={MODEL_NAME}\n", + 
"\n", + "# Create endpoint, if endpoint does not exist\n", + "# !gcloud ai endpoints create \\\n", + "# --display-name={ENDPOINT_NAME} \\\n", + "# --endpoint-id={ENDPOINT_NAME} \\\n", + "# --region={REGION} \\\n", + "# --project={PROJECT}\n", + "\n", + "!gcloud ai endpoints deploy-model {ENDPOINT_NAME} \\\n", + " --project={PROJECT} \\\n", + " --region={REGION} \\\n", + " --model={MODEL_NAME} \\\n", + " --display-name={MODEL_NAME} \\\n", + " --machine-type=\"e2-standard-4\"\n", + "\n", + "# 21 mintues when issues with server\n", + "# 4 minutes when it's working" + ], + "metadata": { + "id": "AetRF8dcGraC" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/deploy/test_embeddings.js b/deploy/test_embeddings.js new file mode 100644 index 0000000..af8825d --- /dev/null +++ b/deploy/test_embeddings.js @@ -0,0 +1,251 @@ +/////////////////////////////////////////////////////////////////////////////////////////////// +// Author: Ivan Zvonkov (izvonkov@umd.edu) +// Last Edited: Apr 2, 2025 +// Description +// (1) Specify ROI +// (2) Create CropHarvest composite based on presto-v3 +// https://github.com/nasaharvest/presto-v3/tree/ff17611ef1433eff0b020f8e513c640c3959e381/src/data/earthengine +// (3) Use Presto deployed on VertexAI to create embeddings +// (4) Save embeddings as asset +/////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////////////////////// +// 1. 
Specifies ROI +/////////////////////////////////////////////////////////////////////////////////////////////// +var lon1 = 1.1575584013473583; +var lon2 = 1.2048954756271435; +var lat1 = 6.840427147744777; +var lat2 = 6.877467188902948; +var roi = ee.Geometry.Polygon([ + [lon1, lat1], + [lon2, lat1], + [lon2, lat2], + [lon1, lat2], + [lon1, lat1], +]); +Map.centerObject(roi, 14); + +/////////////////////////////////////////////////////////////////////////////////////////////// +// 2. Creates CropHarvest Composite +/////////////////////////////////////////////////////////////////////////////////////////////// +var rangeStart = '2019-03-01'; +var rangeEnd = '2020-03-01'; + +/////////////////////////////////////////////////////////////////////////////////////////////// +// 2a. Sentinel-1 Data +/////////////////////////////////////////////////////////////////////////////////////////////// +var S1_BANDS = ['VV', 'VH']; +var S1_all = ee + .ImageCollection('COPERNICUS/S1_GRD') + .filterBounds(roi) + .filterDate( + ee.Date(rangeStart).advance(-31, 'days'), + ee.Date(rangeEnd).advance(31, 'days') + ); + +var S1 = S1_all.filter( + ee.Filter.eq( + 'orbitProperties_pass', + S1_all.first().get('orbitProperties_pass') + ) +).filter(ee.Filter.eq('instrumentMode', 'IW')); +var S1_VV = S1.filter( + ee.Filter.listContains('transmitterReceiverPolarisation', 'VV') +); +var S1_VH = S1.filter( + ee.Filter.listContains('transmitterReceiverPolarisation', 'VH') +); + +function getCloseImages(middleDate, imageCollection) { + var fromMiddleDate = imageCollection + .map(function (img) { + var dateDist = ee + .Number(img.get('system:time_start')) + .subtract(middleDate.millis()) + .abs(); + return img.set('dateDist', dateDist); + }) + .sort({ property: 'dateDist', ascending: true }); + var fifteenDaysInMs = ee.Number(1296000000); + var maxDiff = ee + .Number(fromMiddleDate.first().get('dateDist')) + .max(fifteenDaysInMs); + return fromMiddleDate.filterMetadata( + 'dateDist', + 
'not_greater_than', + maxDiff + ); +} + +function S1_img(date1, date2) { + var startDate = ee.Date(date1); + var daysBetween = ee.Date(date2).difference(startDate, 'days'); + var middleDate = startDate.advance(daysBetween.divide(2), 'days'); + var kept_vv = getCloseImages(middleDate, S1_VV).select('VV'); + var kept_vh = getCloseImages(middleDate, S1_VH).select('VH'); + var S1_composite = ee.Image.cat([kept_vv.median(), kept_vh.median()]); + return S1_composite.select(S1_BANDS).add(25.0).divide(25.0); // S1 ranges from -50 to 1 +} + +/////////////////////////////////////////////////////////////////////////////////////////////// +// 2b. Sentinel-2 Data +/////////////////////////////////////////////////////////////////////////////////////////////// +var S2_BANDS = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']; +var S2 = ee + .ImageCollection('COPERNICUS/S2_SR_HARMONIZED') + .filterBounds(roi) + .filterDate(rangeStart, rangeEnd); +var csPlus = ee + .ImageCollection('GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED') + .filterBounds(roi) + .filterDate(rangeStart, rangeEnd); +var QA_BAND = 'cs_cdf'; // Better than cs here +var S2_cf = S2.linkCollection(csPlus, [QA_BAND]); + +function S2_img(date1, date2) { + return S2_cf.filterDate(date1, date2) + .qualityMosaic(QA_BAND) + .select(S2_BANDS) + .divide(ee.Image(1e4)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////// +// 2c. 
ERA5 Data +/////////////////////////////////////////////////////////////////////////////////////////////// +var ERA5_BANDS = ['temperature_2m', 'total_precipitation_sum']; +var ERA5 = ee + .ImageCollection('ECMWF/ERA5_LAND/MONTHLY_AGGR') + .filterBounds(roi) + .filterDate(rangeStart, rangeEnd); +function ERA5_img(date1, date2) { + return ERA5.filterDate(date1, date2) + .select(ERA5_BANDS) + .mean() + .add([-272.15, 0]) + .divide([35, 0.03]); +} + +/////////////////////////////////////////////////////////////////////////////////////////////// +// 2d. SRTM Data +/////////////////////////////////////////////////////////////////////////////////////////////// +var SRTM_BANDS = ['elevation', 'slope']; +var elevation = ee.Image('USGS/SRTMGL1_003').clip(roi).select('elevation'); +var slope = ee.Terrain.slope(elevation); +var SRTM_img = ee.Image.cat([elevation, slope]).toDouble().divide([2000, 50]); + +function cropharvest_img(d1, d2) { + var img = ee.Image.cat([ + S1_img(d1, d2), + S2_img(d1, d2), + ERA5_img(d1, d2), + SRTM_img, + ]); + var ndvi = img.normalizedDifference(['B8', 'B4']).rename('NDVI'); + return img.addBands(ndvi).clip(roi).toFloat(); // toFloat Necessary for tensor conversion +} + +var latlons = ee.Image.pixelLonLat().clip(roi).select('latitude', 'longitude'); +var imgs = [ + latlons, + cropharvest_img('2019-03-01', '2019-04-01'), + cropharvest_img('2019-04-01', '2019-05-01'), + cropharvest_img('2019-05-01', '2019-06-01'), + cropharvest_img('2019-06-01', '2019-07-01'), + cropharvest_img('2019-07-01', '2019-08-01'), + cropharvest_img('2019-08-01', '2019-09-01'), + cropharvest_img('2019-09-01', '2019-10-01'), + cropharvest_img('2019-10-01', '2019-11-01'), + cropharvest_img('2019-11-01', '2019-12-01'), + cropharvest_img('2019-12-01', '2020-01-01'), + cropharvest_img('2020-01-01', '2020-02-01'), + cropharvest_img('2020-02-01', '2020-03-01'), +]; + +var S1vis = { + bands: ['VV', 'VH', 'VV'], + min: [0, -0.2, 0.4], + max: [1.0, 0.8, 1.2], +}; 
+Map.addLayer(imgs[1], S1vis, 'S1 March'); +Map.addLayer(imgs[2], S1vis, 'S1 April'); +Map.addLayer(imgs[3], S1vis, 'S1 May'); + +var S2vis = { bands: ['B4', 'B3', 'B2'], min: 0, max: 0.25 }; +Map.addLayer(imgs[1], S2vis, 'S2 March'); +Map.addLayer(imgs[2], S2vis, 'S2 April'); +Map.addLayer(imgs[3], S2vis, 'S2 May'); + +var ERA5Palette = [ + '000080', + '0000d9', + '4000ff', + '8000ff', + '0080ff', + '00ffff', + '00ff80', + '80ff00', + 'daff00', + 'ffff00', + 'fff500', + 'ffda00', + 'ffb000', + 'ffa400', + 'ff4f00', + 'ff2500', + 'ff0a00', + 'ff00ff', +]; +var ERA5vis = { + bands: ['temperature_2m'], + min: 0, + max: 1, + palette: ERA5Palette, +}; +Map.addLayer(imgs[1], ERA5vis, 'ERA5 March'); +Map.addLayer(imgs[2], ERA5vis, 'ERA5 April'); +Map.addLayer(imgs[3], ERA5vis, 'ERA5 May'); + +Map.addLayer(imgs[1], { bands: ['slope'], min: 0, max: 1 }, 'SRTM slope'); + +var composite = ee.ImageCollection.fromImages(imgs).toBands(); +var bands = composite.bandNames(); +print(bands); + +/////////////////////////////////////////////////////////////////////////////////////////////// +// 3. Call Vertex AI Endpoint +/////////////////////////////////////////////////////////////////////////////////////////////// +var endpoint = + 'projects/presto-deployment/locations/us-central1/endpoints/vertex-pytorch-presto-endpoint'; +var vertex_model = ee.Model.fromVertexAi({ + endpoint: endpoint, + inputTileSize: [1, 1], + proj: ee.Projection('EPSG:4326').atScale(10), + fixInputProj: true, + outputTileSize: [1, 1], + outputBands: { + p: { + // 'output' + type: ee.PixelType.float(), + dimensions: 1, + }, + }, + payloadFormat: 'ND_ARRAYS', +}); + +var predictions = vertex_model.predictImage(composite).clip(roi); +print(predictions); + +//Map.addLayer(predictions, {min: 0, max: 1},'predictions') + +/////////////////////////////////////////////////////////////////////////////////////////////// +// 4. 
Save embeddings +/////////////////////////////////////////////////////////////////////////////////////////////// +Export.image.toAsset({ + image: predictions, + description: 'Presto_embeddings', + assetId: 'Togo/Presto_test_embeddings_v2025_04_10', + region: roi, + scale: 10, + maxPixels: 1e12, + crs: 'EPSG:25231', +}); From 10bf91463029cc24536a6df997d956cf6104fc75 Mon Sep 17 00:00:00 2001 From: ivanzvonkov Date: Fri, 16 May 2025 15:27:49 -0400 Subject: [PATCH 2/6] Cleaned up deploy code --- README.md | 7 + deploy/1_Presto_to_VertexAI.ipynb | 671 ++++++++++++++++++++++++++++++ deploy/2_Generate_Embeddings.js | 246 +++++++++++ deploy/Presto_to_VertexAI.ipynb | 468 --------------------- deploy/test_embeddings.js | 251 ----------- 5 files changed, 924 insertions(+), 719 deletions(-) create mode 100644 deploy/1_Presto_to_VertexAI.ipynb create mode 100644 deploy/2_Generate_Embeddings.js delete mode 100644 deploy/Presto_to_VertexAI.ipynb delete mode 100644 deploy/test_embeddings.js diff --git a/README.md b/README.md index 88fbb8a..b19de93 100644 --- a/README.md +++ b/README.md @@ -169,6 +169,13 @@ dtype: int64 ``` +## Generating Presto Embeddings in Google Earth Engine +You can generate Presto embeddings in Google Earth Engine by: +1. Deploying Presto to Vertex AI + Open In Colab + +2. Using the ee.model.fromVertexAI function in Google Earth Engine ([script on GEE](https://code.earthengine.google.com/1d196e8466506239c4780585c0e28d26)) + ## Reference If you find this code useful, please cite the following paper: ``` diff --git a/deploy/1_Presto_to_VertexAI.ipynb b/deploy/1_Presto_to_VertexAI.ipynb new file mode 100644 index 0000000..f8ac691 --- /dev/null +++ b/deploy/1_Presto_to_VertexAI.ipynb @@ -0,0 +1,671 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "_SVA9v_JTq_-" + }, + "source": [ + "# 1. 
Presto to Vertex AI\n", + "\n", + "\n", + " \"Open\n", + "\n", + "\n", + "**Authors**: Ivan Zvonkov, Gabriel Tseng, (additional credits: [Earth_Engine_PyTorch_Vertex_AI](https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_PyTorch_Vertex_AI.ipynb))\n", + "\n", + "**Description**: The notebook Deploys Presto to Vertex AI. This is a prerequisite to generating Presto embeddings on Google Earth Engine using\n", + "[ee.Model.fromVertexAi](https://developers.google.com/earth-engine/apidocs/ee-model-fromvertexai).\n", + "\n", + "Once the model is deployed this [GEE script](https://code.earthengine.google.com/df6348b8d47cd751eb5164dccb7b26a9) can be used to generate Presto embeddings.\n", + "\n", + "**Steps**:\n", + "1. Set up environment\n", + "2. Load default Presto model\n", + "3. Transform Presto model into TorchScript\n", + "4. Package TorchScript model into TorchServe\n", + "5. Deploy and use Vertex AI\n", + "\n", + " 5a. Upload TorchServe model to Vertex AI Model Registry [Free]\n", + "\n", + " 5b. Create a Vertex AI Endpoint [Free]\n", + "\n", + " 5c. Deploy model to endpoint [Cost depends on Minimum Replica Count parameter]\n", + "\n", + " 5d. Generate embeddings in Google Earth Engine [Cost depends on region size]\n", + "\n", + " 5e. Undeploy model from endpoint [Free]\n", + "\n", + "**Cost Breakdown**:\n", + "\n", + "*5a. Upload TorchServe model to Vertex AI Model Registry [Free]*\n", + "- Model files are uploaded to Cloud Storage but are lightweight (3.37 Mb total) and thus easily fall into Google Cloud's 5GB/month Storage [Free Tier](https://cloud.google.com/storage/pricing#cloud-storage-always-free)\n", + "- There is no cost to storing models in Vertex AI Model Registry ([source](https://cloud.google.com/vertex-ai/pricing#modelregistry))\n", + "\n", + "*5b. Create a Vertex AI Endpoint [Free]*\n", + "- There is no cost to creating an endpoint. Costs start when a model is deployed to that endpoint\n", + "\n", + "*5c. 
Deploy model to endpoint [Cost depends on Minimum Replica Count parameter]*\n", +    "- The `Minimum Replica Count` represents the minimum number of compute nodes started when a model is deployed; each node is an e2-standard-2 machine (\$0.0771/node hour in us-central-1)\n", +    "- So as long as the endpoint is active you will be paying \$0.0771/hour even if no predictions are made\n", +    "\n", +    "*5d. Generate embeddings in Google Earth Engine [Cost depends on region size]*\n", +    "- Once a model is deployed and `ee.model.fromVertexAi` is used, Vertex AI scales the number of nodes based on the amount of data (size of the region)\n", +    "- Our current embedding generation cost estimates are \$5.37 - \$10.14 per 1000 km2 \n", +    "- We compute a cost estimate for your ROI in our Google Earth Engine script\n", +    "\n", +    "*5e. Undeploy model from endpoint [Free]*\n", +    "- Necessary to stop incurring charges from 5c" +   ] +  }, +  { +   "cell_type": "markdown", +   "metadata": { +    "id": "hzb1bwgTUZU0" +   }, +   "source": [ +    "## 1. 
Set up environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KuoEjld3TTLO" + }, + "outputs": [], + "source": [ + "from google.colab import auth\n", + "\n", + "auth.authenticate_user()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cUI_5pWJ3V4s", + "outputId": "25e8a556-7401-45c0-f1ef-9c02b53948d9" + }, + "outputs": [], + "source": [ + "# REPLACE WITH YOUR CLOUD PROJECT!\n", + "PROJECT = 'presto-deployment'\n", + "!gcloud config set project {PROJECT}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "MRGjYjltsm6-", + "outputId": "7ac7b255-bea0-4d90-9baf-6157ba17aa26" + }, + "outputs": [], + "source": [ + "!git clone https://github.com/nasaharvest/presto.git" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "P1zGbf2KIhLA" + }, + "source": [ + "## 2. Load default Presto model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "UKgCxBNnYJIB", + "outputId": "f2384124-4109-4031-87df-95140c7d7029" + }, + "outputs": [], + "source": [ + "# Navigate inside of the repository to import Presto\n", + "%cd /content/presto\n", + "\n", + "import torch\n", + "from presto.single_file_presto import Presto\n", + "\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "model = Presto.construct()\n", + "model.load_state_dict(torch.load(\"data/default_model.pt\", map_location=device))\n", + "model.eval();\n", + "\n", + "# Navigate back to main directory\n", + "%cd /content" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5v7gJoysFTsf" + }, + "source": [ + "## 3. 
Transform Presto model into TorchScript\n", + "> TorchScript is a way to create serializable and optimizable models from PyTorch code.\n", + "https://docs.pytorch.org/docs/stable/jit.html" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hDKqqzi9F7T4" + }, + "outputs": [], + "source": [ + "# Construct input manually\n", + "batch_size = 256\n", + "\n", + "X_tensor = torch.zeros([batch_size, 12, 17])\n", + "latlons_tensor = torch.zeros([batch_size, 2])\n", + "\n", + "dw_empty = torch.full([batch_size, 12], 9, device=device).long()\n", + "month_tensor = torch.full([batch_size], 1, device=device)\n", + "\n", + "# [0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 16 17 ]\n", + "# [VV, VH, B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12, temp, precip, elev, slope, NDVI]\n", + "mask = torch.zeros(X_tensor.shape, device=device).float()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mdSXGZuikHUk" + }, + "outputs": [], + "source": [ + "# Verify forward pass with regular model\n", + "with torch.no_grad():\n", + " preds = model.encoder(\n", + " x=X_tensor,\n", + " dynamic_world=dw_empty,\n", + " latlons=latlons_tensor,\n", + " mask=mask,\n", + " month=month_tensor\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "nFRvUVkHowKr", + "outputId": "8704f15d-6af0-4c46-c4ee-82fe56881230" + }, + "outputs": [], + "source": [ + "# Make model torchscriptable\n", + "example_kwargs = {\n", + " 'x': X_tensor,\n", + " 'dynamic_world': dw_empty,\n", + " 'latlons': latlons_tensor,\n", + " 'mask': mask,\n", + " 'month': month_tensor\n", + "}\n", + "sm = torch.jit.trace(model.encoder, example_kwarg_inputs=example_kwargs)\n", + "\n", + "!mkdir -p pytorch_model\n", + "sm.save('pytorch_model/model.pt')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + 
}, + "id": "cYuSPOyp1A0K", + "outputId": "c97ef2f4-7715-48ba-8b82-3d21dc401fff" + }, + "outputs": [], + "source": [ + "jit_model = torch.jit.load('pytorch_model/model.pt')\n", + "jit_model(**example_kwargs).shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3b0LsZpqnByv" + }, + "source": [ + "## 4. Package TorchScript model into TorchServe\n", + "> TorchServe is a performant, flexible and easy to use tool for serving PyTorch models in production.\n", + "https://docs.pytorch.org/serve/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "i70o_BZml9vs", + "outputId": "439d43f5-02ed-4dfd-be86-a49c058e70cd" + }, + "outputs": [], + "source": [ + "!pip install torchserve torch-model-archiver -q" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "htq2Ac95FJlk", + "outputId": "79076ad1-7cde-44f2-8e4a-07f28599d29a" + }, + "outputs": [], + "source": [ + "%%writefile pytorch_model/custom_handler.py\n", + "import logging\n", + "import torch\n", + "from ts.torch_handler.base_handler import BaseHandler\n", + "import numpy as np\n", + "\n", + "# UPDATE BASED ON YOUR NEEDS\n", + "########################################\n", + "VERSION = \"v1\"\n", + "START_MONTH = 3\n", + "BATCH_SIZE = 256\n", + "########################################\n", + "\n", + "def printh(text):\n", + " # Prepends HANDLER to each print statement to make it easier to find in logs.\n", + " print(f\"HANDLER {VERSION}: {text}\")\n", + "\n", + "class ClassifierHandler(BaseHandler):\n", + "\n", + " def inference(self, data):\n", + " printh(\"Inference begin\")\n", + "\n", + " # Data shape: [ num_pixels, composite_bands, 1, 1 ]\n", + " data = data[:, :, 0, 0]\n", + " printh(f\"Data shape {data.shape}\")\n", + "\n", + " num_bands = 17\n", + " printh(f\"Num_bands {num_bands}\")\n", + "\n", + " # Subtract first 
two latlon\n", + " num_timesteps = (data.shape[1] - 2) // num_bands\n", + " printh(f\"Num_timesteps {num_timesteps}\")\n", + "\n", + " with torch.no_grad():\n", + "\n", + " batches = torch.split(data, BATCH_SIZE, dim=0)\n", + "\n", + " # month: An int or torch.Tensor describing the first month of the instances being passed. If an int, all instances in the batch are assumed to have the same starting month.\n", + " month_tensor = torch.full([BATCH_SIZE], START_MONTH, device=self.device)\n", + " printh(f\"Month: {START_MONTH}\")\n", + "\n", + " # dynamic_world: torch.Tensor of shape [BATCH_SIZE, num_timesteps]. If no Dynamic World classes are available, this tensor should be filled with the value DynamicWorld2020_2021.class_amount (i.e. 9), in which case it is ignored.\n", + " dw_empty = torch.full([BATCH_SIZE, num_timesteps], 9, device=self.device).long()\n", + " printh(f\"DW {dw_empty[0]}\")\n", + "\n", + " # mask: An optional torch.Tensor of shape [BATCH_SIZE, num_timesteps, bands]. mask[i, j, k] == 1 means x[i, j, k] is considered masked. 
If the mask is None, no values in x are ignored.\n", + " mask = torch.zeros((BATCH_SIZE, num_timesteps, num_bands), device=self.device).float()\n", + " printh(f\"Mask sample one timestep: {mask[0, 0]}\")\n", + "\n", + " preds_list = []\n", + " for batch in batches:\n", + " padding = 0\n", + " if batch.shape[0] < BATCH_SIZE:\n", + " padding = BATCH_SIZE - batch.shape[0]\n", + " batch = torch.cat([batch, torch.zeros([padding, batch.shape[1]], device=self.device)])\n", + "\n", + " # x: torch.Tensor of shape [BATCH_SIZE, num_timesteps, bands] where bands is described by NORMED_BANDS.\n", + " X_tensor = batch[:, 2:]\n", + " printh(f\"X {X_tensor.shape}\")\n", + "\n", + " X_tensor_reshaped = X_tensor.reshape(BATCH_SIZE, num_timesteps, num_bands)\n", + " printh(f\"X sample one timestep: {X_tensor_reshaped[0, 0]}\")\n", + "\n", + " # latlons: torch.Tensor of shape [BATCH_SIZE, 2] describing the latitude and longitude of each input instance.\n", + " latlons_tensor = batch[:, :2]\n", + "\n", + " printh(\"SHAPES\")\n", + " printh(f\"X {X_tensor_reshaped.shape}\")\n", + " printh(f\"DW {dw_empty.shape}\")\n", + " printh(f\"Latlons {latlons_tensor.shape}\")\n", + " printh(f\"Mask {mask.shape}\")\n", + " printh(f\"Month {month_tensor.shape}\")\n", + "\n", + " pred = self.model(\n", + " x=X_tensor_reshaped,\n", + " dynamic_world=dw_empty,\n", + " latlons=latlons_tensor,\n", + " mask=mask,\n", + " month=month_tensor\n", + " )\n", + " pred_np = np.expand_dims(pred.numpy(), axis=[1,2])\n", + " if padding == 0:\n", + " preds_list.append(pred_np[:])\n", + " else:\n", + " preds_list.append(pred_np[:-padding])\n", + "\n", + " [printh(f\"{p.shape}\") for p in preds_list]\n", + " preds = np.concatenate(preds_list)\n", + " printh(f\"Preds shape {preds.shape}\")\n", + " return preds\n", + "\n", + " def handle(self, data, context):\n", + " self.context = context\n", + " printh(f\"Handle begin\")\n", + " input_tensor = self.preprocess(data)\n", + " printh(f\"Input_tensor shape 
{input_tensor.shape}\")\n", + " pred_out = self.inference(input_tensor)\n", + " printh(f\"Inference complete\")\n", + " return self.postprocess(pred_out)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "a3Dgq5Ob5b1i", + "outputId": "e0200fa3-b535-4da6-d476-7614603b8370" + }, + "outputs": [], + "source": [ + "import importlib\n", + "import pytorch_model.custom_handler\n", + "\n", + "importlib.reload(pytorch_model.custom_handler)\n", + "\n", + "from pytorch_model.custom_handler import ClassifierHandler, VERSION\n", + "\n", + "# Test output\n", + "data = torch.zeros([713, 206, 1, 1])\n", + "handler = ClassifierHandler()\n", + "handler.model = jit_model\n", + "preds = handler.handle(data=data, context=None)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "90TXNAnfF-TD" + }, + "outputs": [], + "source": [ + "!torch-model-archiver -f \\\n", + " --model-name model \\\n", + " --version 1.0 \\\n", + " --serialized-file 'pytorch_model/model.pt' \\\n", + " --handler 'pytorch_model/custom_handler.py' \\\n", + " --export-path pytorch_model/" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PYK9g3r1qVru" + }, + "source": [ + "## 5. Deploy and use Vertex AI" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CH0JO_Jww5Ok" + }, + "source": [ + "### 5a. 
Upload TorchServe model to Vertex AI Model Registry\n", + "> The Vertex AI Model Registry is a central repository where you can manage the lifecycle of your ML models.\n", + "https://cloud.google.com/vertex-ai/docs/model-registry/introduction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "n8m9UBy3GEvZ", + "outputId": "5f67c5c5-caad-487b-97ed-f3d5255132cc" + }, + "outputs": [], + "source": [ + "MODEL_DIR = f'gs://presto-models/{VERSION}'\n", + "!gsutil cp -r pytorch_model {MODEL_DIR}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "AetRF8dcGraC", + "outputId": "03589695-1e44-4f79-c8bf-23b5444a0f28" + }, + "outputs": [], + "source": [ + "REGION = 'us-central1'\n", + "MODEL_NAME = f'model_{VERSION}'\n", + "CONTAINER_IMAGE = 'us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-4:latest'\n", + "\n", + "!gcloud ai models upload \\\n", + " --artifact-uri={MODEL_DIR} \\\n", + " --region={REGION} \\\n", + " --container-image-uri={CONTAINER_IMAGE} \\\n", + " --description={MODEL_NAME} \\\n", + " --display-name={MODEL_NAME} \\\n", + " --model-id={MODEL_NAME}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_BXDtoITxY2T" + }, + "source": [ + "### 5b. 
Create a Vertex AI Endpoint\n", + "> To deploy a model for online prediction, you need an endpoint.\n", + "https://cloud.google.com/vertex-ai/docs/predictions/choose-endpoint-type\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "XhMK9aA73FzI", + "outputId": "f0568a0b-e606-4605-fb68-cc41096096cd" + }, + "outputs": [], + "source": [ + "ENDPOINT_NAME = 'vertex-pytorch-presto-endpoint'\n", + "\n", + "endpoints = !gcloud ai endpoints list --region={REGION} --format='get(DISPLAY_NAME)'\n", + "\n", + "if ENDPOINT_NAME in endpoints:\n", + " print(f\"Endpoint: '{ENDPOINT_NAME}' already exists skipping endpoint creation.\")\n", + "else:\n", + " print(f\"Endpoint: '{ENDPOINT_NAME}' does not exist, creating.\")\n", + " !gcloud ai endpoints create \\\n", + " --display-name={ENDPOINT_NAME} \\\n", + " --endpoint-id={ENDPOINT_NAME} \\\n", + " --region={REGION}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dWODP6ccx8-3" + }, + "source": [ + "### 5c. 
Deploy model to endpoint\n",
+    "> Deploying a model associates physical resources with the model so that it can serve online predictions with low latency.\n",
+    "https://cloud.google.com/vertex-ai/docs/general/deployment"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "sI297v1mjtR8",
+    "outputId": "ddb82445-430c-442e-9e3b-d7765491a93b"
+   },
+   "outputs": [],
+   "source": [
+    "# Deploy model to endpoint, this will start an e2-standard-2 machine which costs money\n",
+    "print(\"Track model deployment progress and prediction logs:\")\n",
+    "print(f\"https://console.cloud.google.com/vertex-ai/online-prediction/locations/{REGION}/endpoints/{ENDPOINT_NAME}?project={PROJECT}\\n\")\n",
+    "\n",
+    "# Can take upward of 25 minutes\n",
+    "# If using for large region, set min-replica-count higher to save scaling time\n",
+    "!gcloud ai endpoints deploy-model {ENDPOINT_NAME} \\\n",
+    "  --region={REGION} \\\n",
+    "  --model={MODEL_NAME} \\\n",
+    "  --display-name={MODEL_NAME} \\\n",
+    "  --machine-type=\"e2-standard-2\" \\\n",
+    "  --min-replica-count='1' \\\n",
+    "  --max-replica-count=\"300\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "r2VtGUv9JOI9"
+   },
+   "source": [
+    "### 5d. 
Generate embeddings in Google Earth Engine\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "160PNcRRJMzn", + "outputId": "52266346-f83f-49e9-e487-ba1c96a2dc0b" + }, + "outputs": [], + "source": [ + "GEE_SCRIPT_URL = \"https://code.earthengine.google.com/c239905f788f67ecf0cee42753893d1c\"\n", + "print(f\"Open this script: {GEE_SCRIPT_URL} and use the below string for the ENDPOINT variable\")\n", + "print(\"Use the below string for the ENDPOINT variable\")\n", + "print(f\"projects/{PROJECT}/locations/{REGION}/endpoints/{ENDPOINT_NAME}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8PnhA3gfHSrY" + }, + "source": [ + "### 5e. Undeploy model from endpoint\n", + "\n", + "Once predictions are made, you must undeploy your model to stop incurring further charges.\n", + "\n", + "This can be done using the below code or by using the Google Cloud console directly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "VvOO_sfQDWPt", + "outputId": "139a2f1a-73af-4b7a-943e-6f05bae8dbfb" + }, + "outputs": [], + "source": [ + "def get_deployed_model():\n", + " deployed_models = !gcloud ai endpoints describe {ENDPOINT_NAME} --region={REGION} --format 'get(deployedModels)'\n", + " if deployed_models[1] == '':\n", + " print(\"No models deployed\")\n", + " else:\n", + " print(deployed_model_id)\n", + " return eval(deployed_models[1])['id']\n", + "\n", + "deployed_model_id = get_deployed_model()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vswFTu9kFeHy", + "outputId": "dcc12b61-5cad-4002-e0d9-33923cd4fbea" + }, + "outputs": [], + "source": [ + "!gcloud ai endpoints undeploy-model {ENDPOINT_NAME} --region={REGION} --deployed-model-id={deployed_model_id}" + ] + }, + { + 
"cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "w9Fq6yspF2ye",
+    "outputId": "ac70f2cc-4b0a-4fa7-c595-81ee9a8a276d"
+   },
+   "outputs": [],
+   "source": [
+    "get_deployed_model()"
+   ]
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/deploy/2_Generate_Embeddings.js b/deploy/2_Generate_Embeddings.js
new file mode 100644
index 0000000..efe8b7f
--- /dev/null
+++ b/deploy/2_Generate_Embeddings.js
@@ -0,0 +1,246 @@
+//------------------------------------------------------------------------------------
+// Script for generating Presto embeddings using Vertex AI
+// Author: Ivan Zvonkov (izvonkov@umd.edu)
+//------------------------------------------------------------------------------------
+// 1. Presto embedding generation parameters (set parameters according to your needs)
+//------------------------------------------------------------------------------------
+var roi = ee
+  .FeatureCollection('FAO/GAUL/2015/level2')
+  .filter(ee.Filter.eq('ADM2_NAME', 'Haho'));
+var PROJ = 'EPSG:25231';
+
+var rangeStart = ee.Date('2019-03-01');
+var rangeEnd = ee.Date('2020-03-01');
+
+var ENDPOINT =
+  'projects/presto-deployment/locations/us-central1/endpoints/vertex-pytorch-presto-endpoint';
+var RUN_VERTEX_AI = true; // Set this to false to get a cost estimate first without calling Vertex AI
+//------------------------------------------------------------------------------------
+
+Map.centerObject(roi, 10);
+Map.addLayer(roi, {}, 'Region of Interest');
+Map.setOptions('satellite');
+
+// 2. 
Cost Computation +var roiAreaKM2 = roi.geometry().area().divide(1e6); +function estimate(cost) { + return roiAreaKM2.divide(1000).multiply(cost).toInt().getInfo(); +} +print('ROI Area: ' + roiAreaKM2.toInt().getInfo() + ' km2'); +print( + 'Embedding Generation Estimates\nCost: $' + + estimate(5.37) + + '-' + + estimate(10.14) +); +if (!RUN_VERTEX_AI) + print( + 'If you are ready to generate embeddings,\nchange RUN_VERTEX_AI variable to true' + ); + +// 3. Obtain monthly Sentinel-1 composites +var S1_BANDS = ['VV', 'VH']; +var S1_all = ee + .ImageCollection('COPERNICUS/S1_GRD') + .filterBounds(roi) + .filterDate( + ee.Date(rangeStart).advance(-31, 'days'), + ee.Date(rangeEnd).advance(31, 'days') + ); + +var S1 = S1_all.filter( + ee.Filter.eq( + 'orbitProperties_pass', + S1_all.first().get('orbitProperties_pass') + ) +).filter(ee.Filter.eq('instrumentMode', 'IW')); +var S1_VV = S1.filter( + ee.Filter.listContains('transmitterReceiverPolarisation', 'VV') +); +var S1_VH = S1.filter( + ee.Filter.listContains('transmitterReceiverPolarisation', 'VH') +); + +function getCloseImages(middleDate, imageCollection) { + var fromMiddleDate = imageCollection + .map(function (img) { + var dateDist = ee + .Number(img.get('system:time_start')) + .subtract(middleDate.millis()) + .abs(); + return img.set('dateDist', dateDist); + }) + .sort({ property: 'dateDist', ascending: true }); + var fifteenDaysInMs = ee.Number(1296000000); + var maxDiff = ee + .Number(fromMiddleDate.first().get('dateDist')) + .max(fifteenDaysInMs); + return fromMiddleDate.filterMetadata( + 'dateDist', + 'not_greater_than', + maxDiff + ); +} + +function S1_img(date1, date2) { + var startDate = ee.Date(date1); + var daysBetween = ee.Date(date2).difference(startDate, 'days'); + var middleDate = startDate.advance(daysBetween.divide(2), 'days'); + var kept_vv = getCloseImages(middleDate, S1_VV).select('VV'); + var kept_vh = getCloseImages(middleDate, S1_VH).select('VH'); + var S1_composite = 
ee.Image.cat([kept_vv.median(), kept_vh.median()]); + return S1_composite.select(S1_BANDS).add(25.0).divide(25.0); // S1 ranges from -50 to 1 +} + +// 4. Obtain monthly Sentinel-2 composites +var S2_BANDS = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']; +var S2 = ee + .ImageCollection('COPERNICUS/S2_SR_HARMONIZED') + .filterBounds(roi) + .filterDate(rangeStart, rangeEnd); +var csPlus = ee + .ImageCollection('GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED') + .filterBounds(roi) + .filterDate(rangeStart, rangeEnd); +var QA_BAND = 'cs_cdf'; // Better than cs here +var S2_cf = S2.linkCollection(csPlus, [QA_BAND]); + +function S2_img(date1, date2) { + return S2_cf.filterDate(date1, date2) + .qualityMosaic(QA_BAND) + .select(S2_BANDS) + .divide(ee.Image(1e4)); +} + +// 5. Obtain monthly ERA5 composites +var ERA5_BANDS = ['temperature_2m', 'total_precipitation_sum']; +var ERA5 = ee + .ImageCollection('ECMWF/ERA5_LAND/MONTHLY_AGGR') + .filterBounds(roi) + .filterDate(rangeStart, rangeEnd); +function ERA5_img(date1, date2) { + return ERA5.filterDate(date1, date2) + .select(ERA5_BANDS) + .mean() + .add([-272.15, 0]) + .divide([35, 0.03]); +} +//var ERA5_temp = ee.Image([0,0]).rename(ERA5_BANDS).clip(roi) + +// 6. Obtain SRTM Data +var SRTM_BANDS = ['elevation', 'slope']; +var elevation = ee.Image('USGS/SRTMGL1_003').clip(roi).select('elevation'); +var slope = ee.Terrain.slope(elevation); +var SRTM_img = ee.Image.cat([elevation, slope]).toDouble().divide([2000, 50]); +//var SRTM_temp = ee.Image([0,0]).rename(SRTM_BANDS).clip(roi) + +// 7. Combine all data into a monthly CropHarvest-style monthly composite +function cropharvest_img(d1, d2) { + var img = ee.Image.cat([ + S1_img(d1, d2), + S2_img(d1, d2), + ERA5_img(d1, d2), + SRTM_img, + ]); + var ndvi = img.normalizedDifference(['B8', 'B4']).rename('NDVI'); + // toFloat Necessary for tensor conversion + return img.addBands(ndvi).clip(roi).toFloat(); +} + +// 8. 
Create and visualize Presto input +var latlons = ee.Image.pixelLonLat().clip(roi).select('latitude', 'longitude'); +var imgs = [latlons]; +var numMonths = rangeEnd.difference(rangeStart, 'month').toInt().getInfo(); +var ERA5Palette = [ + '000080', + '0000d9', + '4000ff', + '8000ff', + '0080ff', + '00ffff', + '00ff80', + '80ff00', + 'daff00', + 'ffff00', + 'fff500', + 'ffda00', + 'ffb000', + 'ffa400', + 'ff4f00', + 'ff2500', + 'ff0a00', + 'ff00ff', +]; + +for (var i = 0; i < numMonths; i++) { + var monthStart = rangeStart.advance(i, 'month'); + var monthEnd = monthStart.advance(1, 'month'); + var img = cropharvest_img(monthStart, monthEnd); + imgs.push(img); + + var monthName = monthStart.format('YY/MM').getInfo(); + Map.addLayer( + img, + { + bands: ['VV', 'VH', 'VV'], + min: [0, -0.2, 0.4], + max: [1.0, 0.8, 1.2], + }, + monthName + ' S1', + false + ); + Map.addLayer( + img, + { bands: ['B4', 'B3', 'B2'], min: 0, max: 0.25 }, + monthName + ' S2', + false + ); + Map.addLayer( + img, + { bands: ['temperature_2m'], min: 0, max: 1, palette: ERA5Palette }, + monthName + ' ERA5', + false + ); +} +Map.addLayer(imgs[1], { bands: ['slope'], min: 0, max: 0.3 }, 'SRTM', false); + +var composite = ee.ImageCollection.fromImages(imgs).toBands(); + +// 9. 
Make predictions using Presto on Vertex AI
+var vertex_model = ee.Model.fromVertexAi({
+  endpoint: ENDPOINT,
+  inputTileSize: [1, 1],
+  proj: ee.Projection('EPSG:4326').atScale(10),
+  fixInputProj: true,
+  outputTileSize: [1, 1],
+  outputBands: { p: { type: ee.PixelType.float(), dimensions: 1 } },
+  payloadFormat: 'ND_ARRAYS',
+  maxPayloadBytes: 5242880, // 5.24mb [MAX]
+});
+
+if (RUN_VERTEX_AI) {
+  // Create band names for embeddingsArrayImage
+  var bandNames = [];
+  for (var i = 0; i < 128; i++) {
+    bandNames.push('b' + i + '');
+  }
+
+  // embeddingsArrayImage is a single band image where each pixel contains an array
+  var embeddingsArrayImage = vertex_model.predictImage(composite).clip(roi);
+  var embeddingsMultiBandImage = embeddingsArrayImage.arrayFlatten([
+    bandNames,
+  ]);
+
+  // Only smaller size embeddings can be directly viewed in GEE immediately; larger ones require the batch task
+  // Map.addLayer(embeddingsMultiBandImage, {min: 0, max: 1},'embeddingsMultiBandImage')
+
+  Export.image.toAsset({
+    image: embeddingsMultiBandImage,
+    description: 'Presto_embeddings',
+    assetId: 'Togo/Presto_test_embeddings_v2025_04_23',
+    region: roi,
+    scale: 10,
+    maxPixels: 1e12,
+    crs: 'EPSG:25231',
+  });
+}
diff --git a/deploy/Presto_to_VertexAI.ipynb b/deploy/Presto_to_VertexAI.ipynb
deleted file mode 100644
index e881694..0000000
--- a/deploy/Presto_to_VertexAI.ipynb
+++ /dev/null
@@ -1,468 +0,0 @@
-{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
-  "colab": {
-   "provenance": []
-  },
-  "kernelspec": {
-   "name": "python3",
-   "display_name": "Python 3"
-  },
-  "language_info": {
-   "name": "python"
-  }
- },
- "cells": [
-  {
-   "cell_type": "markdown",
-   "source": [
-    "# Presto in EarthEngine\n",
-    "\n",
-    "**Authors**: Ivan Zvonkov, Gabriel Tseng\n",
-    "\n",
-    "**Description**:\n",
-    "1. Loads default Presto model.\n",
-    "2. 
Deploys default model to Vertex AI.\n", - "\n", - "Inspired by: https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_PyTorch_Vertex_AI.ipynb\n", - "\n", - "**Running this demo may incur charges to your Google Cloud Account!**" - ], - "metadata": { - "id": "_SVA9v_JTq_-" - } - }, - { - "cell_type": "markdown", - "source": [ - "# Set up" - ], - "metadata": { - "id": "hzb1bwgTUZU0" - } - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KuoEjld3TTLO" - }, - "outputs": [], - "source": [ - "from google.colab import auth\n", - "\n", - "import ee\n", - "import google\n", - "\n", - "# REPLACE WITH YOUR CLOUD PROJECT!\n", - "PROJECT = 'presto-deployment'\n", - "\n", - "# Authenticate the notebook.\n", - "auth.authenticate_user()\n", - "\n", - "# Authenticate to Earth Engine.\n", - "credentials, _ = google.auth.default()\n", - "ee.Initialize(credentials, project=PROJECT, opt_url='https://earthengine-highvolume.googleapis.com')\n", - "\n", - "# Set the gcloud project for Vertex AI deployment.\n", - "!gcloud config set project {PROJECT}" - ] - }, - { - "cell_type": "code", - "source": [ - "!git clone https://github.com/nasaharvest/presto.git" - ], - "metadata": { - "id": "MRGjYjltsm6-" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## 1. 
Load model" - ], - "metadata": { - "id": "P1zGbf2KIhLA" - } - }, - { - "cell_type": "code", - "source": [ - "%cd /content/presto" - ], - "metadata": { - "id": "WIJKe-vAs6__" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "from torch.utils.data import TensorDataset\n", - "from torch.utils.data import DataLoader\n", - "import torch\n", - "import torch.optim as optim\n", - "import torch.nn as nn\n", - "import numpy as np\n", - "\n", - "from single_file_presto import Presto\n", - "\n", - "\n", - "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" - ], - "metadata": { - "id": "UKgCxBNnYJIB" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "model = Presto.construct()\n", - "model.load_state_dict(torch.load(\"data/default_model.pt\", map_location=device))\n", - "model.eval();" - ], - "metadata": { - "id": "uF72fMTE1fIS" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "# Sanity check\n", - "\n", - "# from presto.eval import CropHarvestEval\n", - "\n", - "# togo_eval = CropHarvestEval(\"Togo\", ignore_dynamic_world=False, num_timesteps=12, seed=0)\n", - "# results = togo_eval.finetuning_results(model, model_modes=[\"Regression\"])\n", - "# results\n", - "\n", - "# batch_size = 8\n", - "# X_np, dw_np, latlons_np, y_np = togo_eval.dataset.as_array(num_samples=batch_size)\n", - "# month_np = np.array([togo_eval.dataset.start_month] * batch_size)" - ], - "metadata": { - "id": "HE5KeE4G17fm" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "# Construct input manually\n", - "batch_size = 256\n", - "\n", - "X_tensor = torch.zeros([batch_size, 12, 17])\n", - "latlons_tensor = torch.zeros([batch_size, 2])\n", - "\n", - "dw_empty = torch.full([batch_size, 12], 9, device=device).long()\n", - "month_tensor = torch.full([batch_size], 1, device=device)\n", - "\n", - "# [0 1 
2 3 4 5 6 7 8 9 10 11 12 13 14 16 17 ]\n", - "# [VV, VH, B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12, temp, precip, elev, slope, NDVI]\n", - "mask = torch.zeros(X_tensor.shape, device=device).float()" - ], - "metadata": { - "id": "hDKqqzi9F7T4" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "with torch.no_grad():\n", - " preds = model.encoder(\n", - " x=X_tensor,\n", - " dynamic_world=dw_empty,\n", - " latlons=latlons_tensor,\n", - " mask=mask,\n", - " month=month_tensor\n", - " )\n", - "preds" - ], - "metadata": { - "id": "jQVWkY1nBClT" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "## Deploy to Vertex AI" - ], - "metadata": { - "id": "5v7gJoysFTsf" - } - }, - { - "cell_type": "code", - "source": [ - "%cd .." - ], - "metadata": { - "id": "GjB3wVLDGbav" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "!pip install torchserve torch-model-archiver -q\n", - "!mkdir pytorch_model" - ], - "metadata": { - "id": "Xz53_ow1ZWql" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "from ts.torch_handler.base_handler import BaseHandler\n", - "\n", - "# Make model torchscriptable\n", - "example_kwargs = {\n", - " 'x': X_tensor,\n", - " 'dynamic_world': dw_empty,\n", - " 'latlons': latlons_tensor,\n", - " 'mask': mask,\n", - " 'month': month_tensor\n", - "}\n", - "sm = torch.jit.trace(model.encoder, example_kwarg_inputs=example_kwargs)\n", - "\n", - "!mkdir -p pytorch_model\n", - "sm.save('pytorch_model/model.pt')" - ], - "metadata": { - "id": "nFRvUVkHowKr" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "jit_model = torch.jit.load('pytorch_model/model.pt')\n", - "jit_model(**example_kwargs).shape" - ], - "metadata": { - "id": "cYuSPOyp1A0K" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - 
"%%writefile pytorch_model/custom_handler.py\n", - "\n", - "import logging\n", - "\n", - "import torch\n", - "from ts.torch_handler.base_handler import BaseHandler\n", - "import numpy as np\n", - "import sys\n", - "\n", - "logger = logging.getLogger(__name__)\n", - "version = \"v27\"\n", - "batch_size = 256\n", - "\n", - "def printh(text):\n", - " print(f\"HANDLER {version}: {text}\")\n", - "\n", - "class ClassifierHandler(BaseHandler):\n", - "\n", - " def inference(self, data):\n", - " printh(\"inference begin\")\n", - "\n", - " # Data shape: [ num_pixels, composite_bands, 1, 1 ]\n", - " data = data[:, :, 0, 0]\n", - " printh(f\"data shape {data.shape}\")\n", - "\n", - " num_bands = 17\n", - " printh(f\"num_bands {num_bands}\")\n", - "\n", - " # Subtract first two latlon\n", - " num_timesteps = (data.shape[1] - 2) // num_bands\n", - " printh(f\"num_timesteps {num_timesteps}\")\n", - "\n", - " with torch.no_grad():\n", - "\n", - " batches = torch.split(data, batch_size, dim=0)\n", - "\n", - " # month: An int or torch.Tensor describing the first month of the instances being passed. If an int, all instances in the batch are assumed to have the same starting month.\n", - " month_tensor = torch.full([batch_size], 3, device=self.device)\n", - " printh(f\"month: 3\")\n", - "\n", - " # dynamic_world: torch.Tensor of shape [batch_size, num_timesteps]. If no Dynamic World classes are available, this tensor should be filled with the value DynamicWorld2020_2021.class_amount (i.e. 9), in which case it is ignored.\n", - " dw_empty = torch.full([batch_size, num_timesteps], 9, device=self.device).long()\n", - " printh(f\"dw {dw_empty[0]}\")\n", - "\n", - " # mask: An optional torch.Tensor of shape [batch_size, num_timesteps, bands]. mask[i, j, k] == 1 means x[i, j, k] is considered masked. 
If the mask is None, no values in x are ignored.\n", - " mask = torch.zeros((batch_size, num_timesteps, num_bands), device=self.device).float()\n", - " printh(f\"mask sample one timestep: {mask[0, 0]}\")\n", - "\n", - " preds = []\n", - " for batch in batches:\n", - "\n", - " padding = 0\n", - " if batch.shape[0] < batch_size:\n", - " padding = batch_size - batch.shape[0]\n", - " batch = torch.cat([batch, torch.zeros([padding, batch.shape[1]], device=self.device)])\n", - "\n", - " # x: torch.Tensor of shape [batch_size, num_timesteps, bands] where bands is described by NORMED_BANDS.\n", - " X_tensor = batch[:, 2:]\n", - " printh(f\"X {X_tensor.shape}\")\n", - "\n", - " X_tensor_reshaped = X_tensor.reshape(batch_size, num_timesteps, num_bands)\n", - " printh(f\"X sample one timestep: {X_tensor_reshaped[0, 0]}\")\n", - "\n", - " # latlons: torch.Tensor of shape [batch_size, 2] describing the latitude and longitude of each input instance.\n", - " latlons_tensor = batch[:, :2]\n", - "\n", - " # Shapes\n", - " printh(\"SHAPES\")\n", - " printh(f\"X {X_tensor_reshaped.shape}\")\n", - " printh(f\"dw {dw_empty.shape}\")\n", - " printh(f\"latlons {latlons_tensor.shape}\")\n", - " printh(f\"mask {mask.shape}\")\n", - " printh(f\"month {month_tensor.shape}\")\n", - "\n", - " pred = self.model(\n", - " x=X_tensor_reshaped,\n", - " dynamic_world=dw_empty,\n", - " latlons=latlons_tensor,\n", - " mask=mask,\n", - " month=month_tensor\n", - " )\n", - " pred_np = np.expand_dims(pred.numpy(), axis=[1,2])\n", - " if padding == 0:\n", - " preds.append(pred_np[:])\n", - " else:\n", - " preds.append(pred_np[:-padding])\n", - "\n", - " [printh(f\"{p.shape}\") for p in preds]\n", - " preds = np.concatenate(preds)\n", - " printh(f\"preds shape {preds.shape}\")\n", - " return preds\n", - "\n", - " def handle(self, data, context):\n", - " self.context = context\n", - " printh(f\"handle begin\")\n", - " input_tensor = self.preprocess(data)\n", - " printh(f\"input_tensor shape 
{input_tensor.shape}\")\n", - " pred_out = self.inference(input_tensor)\n", - " return self.postprocess(pred_out)" - ], - "metadata": { - "id": "htq2Ac95FJlk" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "import importlib\n", - "import pytorch_model\n", - "from pytorch_model.custom_handler import ClassifierHandler\n", - "importlib.reload(pytorch_model.custom_handler)\n", - "\n", - "from pytorch_model.custom_handler import ClassifierHandler\n", - "\n", - "# Test output\n", - "data = torch.zeros([713, 206, 1, 1])\n", - "handler = ClassifierHandler()\n", - "handler.model = jit_model\n", - "preds = handler.handle(data=data, context=None)" - ], - "metadata": { - "id": "a3Dgq5Ob5b1i" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "!torch-model-archiver -f \\\n", - " --model-name model \\\n", - " --version 1.0 \\\n", - " --serialized-file 'pytorch_model/model.pt' \\\n", - " --handler 'pytorch_model/custom_handler.py' \\\n", - " --export-path pytorch_model/" - ], - "metadata": { - "id": "90TXNAnfF-TD" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "version = \"v27\"\n", - "MODEL_DIR = f'gs://presto-models/default_v2025_04_10_{version}'\n", - "!gsutil cp -r pytorch_model {MODEL_DIR}" - ], - "metadata": { - "id": "n8m9UBy3GEvZ" - }, - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "REGION = 'us-central1'\n", - "MODEL_NAME = f'model_{version}'\n", - "CONTAINER_IMAGE = 'us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-4:latest'\n", - "ENDPOINT_NAME = 'vertex-pytorch-presto-endpoint'\n", - "\n", - "!gcloud ai models upload \\\n", - " --artifact-uri={MODEL_DIR} \\\n", - " --project={PROJECT} \\\n", - " --region={REGION} \\\n", - " --container-image-uri={CONTAINER_IMAGE} \\\n", - " --description={MODEL_NAME} \\\n", - " --display-name={MODEL_NAME} \\\n", - " --model-id={MODEL_NAME}\n", - 
"\n", - "# Create endpoint, if endpoint does not exist\n", - "# !gcloud ai endpoints create \\\n", - "# --display-name={ENDPOINT_NAME} \\\n", - "# --endpoint-id={ENDPOINT_NAME} \\\n", - "# --region={REGION} \\\n", - "# --project={PROJECT}\n", - "\n", - "!gcloud ai endpoints deploy-model {ENDPOINT_NAME} \\\n", - " --project={PROJECT} \\\n", - " --region={REGION} \\\n", - " --model={MODEL_NAME} \\\n", - " --display-name={MODEL_NAME} \\\n", - " --machine-type=\"e2-standard-4\"\n", - "\n", - "# 21 mintues when issues with server\n", - "# 4 minutes when it's working" - ], - "metadata": { - "id": "AetRF8dcGraC" - }, - "execution_count": null, - "outputs": [] - } - ] -} \ No newline at end of file diff --git a/deploy/test_embeddings.js b/deploy/test_embeddings.js deleted file mode 100644 index af8825d..0000000 --- a/deploy/test_embeddings.js +++ /dev/null @@ -1,251 +0,0 @@ -/////////////////////////////////////////////////////////////////////////////////////////////// -// Author: Ivan Zvonkov (izvonkov@umd.edu) -// Last Edited: Apr 2, 2025 -// Description -// (1) Specify ROI -// (2) Create CropHarvest composite based on presto-v3 -// https://github.com/nasaharvest/presto-v3/tree/ff17611ef1433eff0b020f8e513c640c3959e381/src/data/earthengine -// (3) Use Presto deployed on VertexAI to create embeddings -// (4) Save embeddings as asset -/////////////////////////////////////////////////////////////////////////////////////////////// - -/////////////////////////////////////////////////////////////////////////////////////////////// -// 1. 
Specifies ROI -/////////////////////////////////////////////////////////////////////////////////////////////// -var lon1 = 1.1575584013473583; -var lon2 = 1.2048954756271435; -var lat1 = 6.840427147744777; -var lat2 = 6.877467188902948; -var roi = ee.Geometry.Polygon([ - [lon1, lat1], - [lon2, lat1], - [lon2, lat2], - [lon1, lat2], - [lon1, lat1], -]); -Map.centerObject(roi, 14); - -/////////////////////////////////////////////////////////////////////////////////////////////// -// 2. Creates CropHarvest Composite -/////////////////////////////////////////////////////////////////////////////////////////////// -var rangeStart = '2019-03-01'; -var rangeEnd = '2020-03-01'; - -/////////////////////////////////////////////////////////////////////////////////////////////// -// 2a. Sentinel-1 Data -/////////////////////////////////////////////////////////////////////////////////////////////// -var S1_BANDS = ['VV', 'VH']; -var S1_all = ee - .ImageCollection('COPERNICUS/S1_GRD') - .filterBounds(roi) - .filterDate( - ee.Date(rangeStart).advance(-31, 'days'), - ee.Date(rangeEnd).advance(31, 'days') - ); - -var S1 = S1_all.filter( - ee.Filter.eq( - 'orbitProperties_pass', - S1_all.first().get('orbitProperties_pass') - ) -).filter(ee.Filter.eq('instrumentMode', 'IW')); -var S1_VV = S1.filter( - ee.Filter.listContains('transmitterReceiverPolarisation', 'VV') -); -var S1_VH = S1.filter( - ee.Filter.listContains('transmitterReceiverPolarisation', 'VH') -); - -function getCloseImages(middleDate, imageCollection) { - var fromMiddleDate = imageCollection - .map(function (img) { - var dateDist = ee - .Number(img.get('system:time_start')) - .subtract(middleDate.millis()) - .abs(); - return img.set('dateDist', dateDist); - }) - .sort({ property: 'dateDist', ascending: true }); - var fifteenDaysInMs = ee.Number(1296000000); - var maxDiff = ee - .Number(fromMiddleDate.first().get('dateDist')) - .max(fifteenDaysInMs); - return fromMiddleDate.filterMetadata( - 'dateDist', - 
'not_greater_than', - maxDiff - ); -} - -function S1_img(date1, date2) { - var startDate = ee.Date(date1); - var daysBetween = ee.Date(date2).difference(startDate, 'days'); - var middleDate = startDate.advance(daysBetween.divide(2), 'days'); - var kept_vv = getCloseImages(middleDate, S1_VV).select('VV'); - var kept_vh = getCloseImages(middleDate, S1_VH).select('VH'); - var S1_composite = ee.Image.cat([kept_vv.median(), kept_vh.median()]); - return S1_composite.select(S1_BANDS).add(25.0).divide(25.0); // S1 ranges from -50 to 1 -} - -/////////////////////////////////////////////////////////////////////////////////////////////// -// 2b. Sentinel-2 Data -/////////////////////////////////////////////////////////////////////////////////////////////// -var S2_BANDS = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']; -var S2 = ee - .ImageCollection('COPERNICUS/S2_SR_HARMONIZED') - .filterBounds(roi) - .filterDate(rangeStart, rangeEnd); -var csPlus = ee - .ImageCollection('GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED') - .filterBounds(roi) - .filterDate(rangeStart, rangeEnd); -var QA_BAND = 'cs_cdf'; // Better than cs here -var S2_cf = S2.linkCollection(csPlus, [QA_BAND]); - -function S2_img(date1, date2) { - return S2_cf.filterDate(date1, date2) - .qualityMosaic(QA_BAND) - .select(S2_BANDS) - .divide(ee.Image(1e4)); -} - -/////////////////////////////////////////////////////////////////////////////////////////////// -// 2c. 
ERA5 Data -/////////////////////////////////////////////////////////////////////////////////////////////// -var ERA5_BANDS = ['temperature_2m', 'total_precipitation_sum']; -var ERA5 = ee - .ImageCollection('ECMWF/ERA5_LAND/MONTHLY_AGGR') - .filterBounds(roi) - .filterDate(rangeStart, rangeEnd); -function ERA5_img(date1, date2) { - return ERA5.filterDate(date1, date2) - .select(ERA5_BANDS) - .mean() - .add([-272.15, 0]) - .divide([35, 0.03]); -} - -/////////////////////////////////////////////////////////////////////////////////////////////// -// 2d. SRTM Data -/////////////////////////////////////////////////////////////////////////////////////////////// -var SRTM_BANDS = ['elevation', 'slope']; -var elevation = ee.Image('USGS/SRTMGL1_003').clip(roi).select('elevation'); -var slope = ee.Terrain.slope(elevation); -var SRTM_img = ee.Image.cat([elevation, slope]).toDouble().divide([2000, 50]); - -function cropharvest_img(d1, d2) { - var img = ee.Image.cat([ - S1_img(d1, d2), - S2_img(d1, d2), - ERA5_img(d1, d2), - SRTM_img, - ]); - var ndvi = img.normalizedDifference(['B8', 'B4']).rename('NDVI'); - return img.addBands(ndvi).clip(roi).toFloat(); // toFloat Necessary for tensor conversion -} - -var latlons = ee.Image.pixelLonLat().clip(roi).select('latitude', 'longitude'); -var imgs = [ - latlons, - cropharvest_img('2019-03-01', '2019-04-01'), - cropharvest_img('2019-04-01', '2019-05-01'), - cropharvest_img('2019-05-01', '2019-06-01'), - cropharvest_img('2019-06-01', '2019-07-01'), - cropharvest_img('2019-07-01', '2019-08-01'), - cropharvest_img('2019-08-01', '2019-09-01'), - cropharvest_img('2019-09-01', '2019-10-01'), - cropharvest_img('2019-10-01', '2019-11-01'), - cropharvest_img('2019-11-01', '2019-12-01'), - cropharvest_img('2019-12-01', '2020-01-01'), - cropharvest_img('2020-01-01', '2020-02-01'), - cropharvest_img('2020-02-01', '2020-03-01'), -]; - -var S1vis = { - bands: ['VV', 'VH', 'VV'], - min: [0, -0.2, 0.4], - max: [1.0, 0.8, 1.2], -}; 
-Map.addLayer(imgs[1], S1vis, 'S1 March'); -Map.addLayer(imgs[2], S1vis, 'S1 April'); -Map.addLayer(imgs[3], S1vis, 'S1 May'); - -var S2vis = { bands: ['B4', 'B3', 'B2'], min: 0, max: 0.25 }; -Map.addLayer(imgs[1], S2vis, 'S2 March'); -Map.addLayer(imgs[2], S2vis, 'S2 April'); -Map.addLayer(imgs[3], S2vis, 'S2 May'); - -var ERA5Palette = [ - '000080', - '0000d9', - '4000ff', - '8000ff', - '0080ff', - '00ffff', - '00ff80', - '80ff00', - 'daff00', - 'ffff00', - 'fff500', - 'ffda00', - 'ffb000', - 'ffa400', - 'ff4f00', - 'ff2500', - 'ff0a00', - 'ff00ff', -]; -var ERA5vis = { - bands: ['temperature_2m'], - min: 0, - max: 1, - palette: ERA5Palette, -}; -Map.addLayer(imgs[1], ERA5vis, 'ERA5 March'); -Map.addLayer(imgs[2], ERA5vis, 'ERA5 April'); -Map.addLayer(imgs[3], ERA5vis, 'ERA5 May'); - -Map.addLayer(imgs[1], { bands: ['slope'], min: 0, max: 1 }, 'SRTM slope'); - -var composite = ee.ImageCollection.fromImages(imgs).toBands(); -var bands = composite.bandNames(); -print(bands); - -/////////////////////////////////////////////////////////////////////////////////////////////// -// 3. Call Vertex AI Endpoint -/////////////////////////////////////////////////////////////////////////////////////////////// -var endpoint = - 'projects/presto-deployment/locations/us-central1/endpoints/vertex-pytorch-presto-endpoint'; -var vertex_model = ee.Model.fromVertexAi({ - endpoint: endpoint, - inputTileSize: [1, 1], - proj: ee.Projection('EPSG:4326').atScale(10), - fixInputProj: true, - outputTileSize: [1, 1], - outputBands: { - p: { - // 'output' - type: ee.PixelType.float(), - dimensions: 1, - }, - }, - payloadFormat: 'ND_ARRAYS', -}); - -var predictions = vertex_model.predictImage(composite).clip(roi); -print(predictions); - -//Map.addLayer(predictions, {min: 0, max: 1},'predictions') - -/////////////////////////////////////////////////////////////////////////////////////////////// -// 4. 
Save embeddings -/////////////////////////////////////////////////////////////////////////////////////////////// -Export.image.toAsset({ - image: predictions, - description: 'Presto_embeddings', - assetId: 'Togo/Presto_test_embeddings_v2025_04_10', - region: roi, - scale: 10, - maxPixels: 1e12, - crs: 'EPSG:25231', -}); From 33b9e5a8f54ca44b441aec0348102d50dae45928 Mon Sep 17 00:00:00 2001 From: ivanzvonkov Date: Fri, 16 May 2025 15:32:28 -0400 Subject: [PATCH 3/6] Add link to ee docs --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b19de93..e1cda23 100644 --- a/README.md +++ b/README.md @@ -174,7 +174,7 @@ You can generate Presto embeddings in Google Earth Engine by: 1. Deploying Presto to Vertex AI Open In Colab -2. Using the ee.model.fromVertexAI function in Google Earth Engine ([script on GEE](https://code.earthengine.google.com/1d196e8466506239c4780585c0e28d26)) +2. Using the [ee.Model.fromVertexAi](https://developers.google.com/earth-engine/apidocs/ee-model-fromvertexai) function in Google Earth Engine ([script on GEE](https://code.earthengine.google.com/1d196e8466506239c4780585c0e28d26)) ## Reference If you find this code useful, please cite the following paper: From b8d934c1d759a6400d5fb7896af69b4aac50a8b1 Mon Sep 17 00:00:00 2001 From: ivanzvonkov Date: Fri, 16 May 2025 17:21:44 -0400 Subject: [PATCH 4/6] Embeddings clustering script --- deploy/3_Kmeans_Embeddings.js | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 deploy/3_Kmeans_Embeddings.js diff --git a/deploy/3_Kmeans_Embeddings.js b/deploy/3_Kmeans_Embeddings.js new file mode 100644 index 0000000..2d66034 --- /dev/null +++ b/deploy/3_Kmeans_Embeddings.js @@ -0,0 +1,22 @@ +//------------------------------------------------------------------------------------ +// Script for checking the embedding salience through clutering +// Author: Ivan Zvonkov (izvonkov@umd.edu) 
+//------------------------------------------------------------------------------------ + +// 1. Load embeddings +var embeddings = ee.Image( + 'users/izvonkov/Togo/Presto_test_embeddings_v2025_05_16' +); +var roi = embeddings.geometry({ geodesics: true }); +Map.centerObject(roi, 11); + +// 2. Cluster embeddings and display +var training = embeddings.sample({ region: roi, scale: 10, numPixels: 10000 }); +var trainedClusterer = ee.Clusterer.wekaKMeans(7).train(training); +var result = embeddings.cluster(trainedClusterer); +Map.addLayer(result.randomVisualizer(), {}, 'clusters'); + +// 3. Display WorldCover +var WorldCover = ee.ImageCollection('ESA/WorldCover/v200').first().clip(roi); +var vis = { bands: ['Map'] }; +Map.addLayer(WorldCover, vis, 'WorldCover'); From 889051ed76cded53803aa60ff66288cecfc3379a Mon Sep 17 00:00:00 2001 From: ivanzvonkov Date: Fri, 23 May 2025 17:46:20 -0400 Subject: [PATCH 5/6] minor corrections --- deploy/1_Presto_to_VertexAI.ipynb | 138 +++++++++++------------------- 1 file changed, 49 insertions(+), 89 deletions(-) diff --git a/deploy/1_Presto_to_VertexAI.ipynb b/deploy/1_Presto_to_VertexAI.ipynb index f8ac691..a93ee73 100644 --- a/deploy/1_Presto_to_VertexAI.ipynb +++ b/deploy/1_Presto_to_VertexAI.ipynb @@ -84,16 +84,11 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cUI_5pWJ3V4s", - "outputId": "25e8a556-7401-45c0-f1ef-9c02b53948d9" + "id": "cUI_5pWJ3V4s" }, "outputs": [], "source": [ - "# REPLACE WITH YOUR CLOUD PROJECT!\n", - "PROJECT = 'presto-deployment'\n", + "PROJECT = ''\n", "!gcloud config set project {PROJECT}" ] }, @@ -101,11 +96,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "MRGjYjltsm6-", - "outputId": "7ac7b255-bea0-4d90-9baf-6157ba17aa26" + "id": "MRGjYjltsm6-" }, "outputs": [], "source": [ @@ -125,11 +116,7 @@ "cell_type": "code", "execution_count": null, 
"metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "UKgCxBNnYJIB", - "outputId": "f2384124-4109-4031-87df-95140c7d7029" + "id": "UKgCxBNnYJIB" }, "outputs": [], "source": [ @@ -137,7 +124,7 @@ "%cd /content/presto\n", "\n", "import torch\n", - "from presto.single_file_presto import Presto\n", + "from single_file_presto import Presto\n", "\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", @@ -205,11 +192,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "nFRvUVkHowKr", - "outputId": "8704f15d-6af0-4c46-c4ee-82fe56881230" + "id": "nFRvUVkHowKr" }, "outputs": [], "source": [ @@ -231,11 +214,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cYuSPOyp1A0K", - "outputId": "c97ef2f4-7715-48ba-8b82-3d21dc401fff" + "id": "cYuSPOyp1A0K" }, "outputs": [], "source": [ @@ -258,11 +237,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "i70o_BZml9vs", - "outputId": "439d43f5-02ed-4dfd-be86-a49c058e70cd" + "id": "i70o_BZml9vs" }, "outputs": [], "source": [ @@ -273,11 +248,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "htq2Ac95FJlk", - "outputId": "79076ad1-7cde-44f2-8e4a-07f28599d29a" + "id": "htq2Ac95FJlk" }, "outputs": [], "source": [ @@ -386,11 +357,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "a3Dgq5Ob5b1i", - "outputId": "e0200fa3-b535-4da6-d476-7614603b8370" + "id": "a3Dgq5Ob5b1i" }, "outputs": [], "source": [ @@ -448,15 +415,35 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "n8m9UBy3GEvZ", - "outputId": 
"5f67c5c5-caad-487b-97ed-f3d5255132cc" + "id": "lEUqcAqsTYpn" }, "outputs": [], "source": [ - "MODEL_DIR = f'gs://presto-models/{VERSION}'\n", + "REGION = 'us-central1'\n", + "BUCKET_NAME = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OiyyCNa-TBGQ" + }, + "outputs": [], + "source": [ + "# Create bucket to store model artifcats if it doesn't exist\n", + "!gcloud storage buckets create gs://{BUCKET_NAME} --location={REGION}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "n8m9UBy3GEvZ" + }, + "outputs": [], + "source": [ + "MODEL_DIR = f'gs://{BUCKET_NAME}/{VERSION}'\n", "!gsutil cp -r pytorch_model {MODEL_DIR}" ] }, @@ -464,15 +451,11 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "AetRF8dcGraC", - "outputId": "03589695-1e44-4f79-c8bf-23b5444a0f28" + "id": "AetRF8dcGraC" }, "outputs": [], "source": [ - "REGION = 'us-central1'\n", + "# Can take 2 minutes\n", "MODEL_NAME = f'model_{VERSION}'\n", "CONTAINER_IMAGE = 'us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-4:latest'\n", "\n", @@ -500,11 +483,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "XhMK9aA73FzI", - "outputId": "f0568a0b-e606-4605-fb68-cc41096096cd" + "id": "XhMK9aA73FzI" }, "outputs": [], "source": [ @@ -515,7 +494,7 @@ "if ENDPOINT_NAME in endpoints:\n", " print(f\"Endpoint: '{ENDPOINT_NAME}' already exists skipping endpoint creation.\")\n", "else:\n", - " print(f\"Endpoint: '{ENDPOINT_NAME}' does not exist, creating.\")\n", + " print(f\"Endpoint: '{ENDPOINT_NAME}' does not exist, creating... 
(~3 minutes)\")\n", " !gcloud ai endpoints create \\\n", " --display-name={ENDPOINT_NAME} \\\n", " --endpoint-id={ENDPOINT_NAME} \\\n", @@ -537,11 +516,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "sI297v1mjtR8", - "outputId": "ddb82445-430c-442e-9e3b-d7765491a93b" + "id": "sI297v1mjtR8" }, "outputs": [], "source": [ @@ -549,15 +524,16 @@ "print(\"Track model deployment progress and prediction logs:\")\n", "print(f\"https://console.cloud.google.com/vertex-ai/online-prediction/locations/{REGION}/endpoints/{ENDPOINT_NAME}?project={PROJECT}\\n\")\n", "\n", - "# Can take updward of 25 minutes\n", "# If using for large region, set min-replica-count higher to save scaling time\n", + "# Can take from 4-27 minutes\n", + "# Relevant quota: \"Custom model serving CPUs per region\"\n", "!gcloud ai endpoints deploy-model {ENDPOINT_NAME} \\\n", " --region={REGION} \\\n", " --model={MODEL_NAME} \\\n", " --display-name={MODEL_NAME} \\\n", " --machine-type=\"e2-standard-2\" \\\n", " --min-replica-count='1' \\\n", - " --max-replica-count=\"300\"" + " --max-replica-count=\"100\"" ] }, { @@ -573,16 +549,12 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "160PNcRRJMzn", - "outputId": "52266346-f83f-49e9-e487-ba1c96a2dc0b" + "id": "160PNcRRJMzn" }, "outputs": [], "source": [ "GEE_SCRIPT_URL = \"https://code.earthengine.google.com/c239905f788f67ecf0cee42753893d1c\"\n", - "print(f\"Open this script: {GEE_SCRIPT_URL} and use the below string for the ENDPOINT variable\")\n", + "print(f\"Open this script: {GEE_SCRIPT_URL}\")\n", "print(\"Use the below string for the ENDPOINT variable\")\n", "print(f\"projects/{PROJECT}/locations/{REGION}/endpoints/{ENDPOINT_NAME}\")" ] @@ -604,11 +576,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": 
"VvOO_sfQDWPt", - "outputId": "139a2f1a-73af-4b7a-943e-6f05bae8dbfb" + "id": "VvOO_sfQDWPt" }, "outputs": [], "source": [ @@ -627,11 +595,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "vswFTu9kFeHy", - "outputId": "dcc12b61-5cad-4002-e0d9-33923cd4fbea" + "id": "vswFTu9kFeHy" }, "outputs": [], "source": [ @@ -642,11 +606,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "w9Fq6yspF2ye", - "outputId": "ac70f2cc-4b0a-4fa7-c595-81ee9a8a276d" + "id": "w9Fq6yspF2ye" }, "outputs": [], "source": [ From a3933eb50f0255d3d2e34ea8de586b2b98a58587 Mon Sep 17 00:00:00 2001 From: ivanzvonkov Date: Mon, 26 May 2025 10:50:02 -0400 Subject: [PATCH 6/6] Address PR comments --- deploy/1_Presto_to_VertexAI.ipynb | 13 ++++++++----- deploy/2_Generate_Embeddings.js | 2 +- deploy/3_Kmeans_Embeddings.js | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/deploy/1_Presto_to_VertexAI.ipynb b/deploy/1_Presto_to_VertexAI.ipynb index a93ee73..8048ec4 100644 --- a/deploy/1_Presto_to_VertexAI.ipynb +++ b/deploy/1_Presto_to_VertexAI.ipynb @@ -8,7 +8,7 @@ "source": [ "# 1. 
Presto to Vertex AI\n", "\n", - "\n", + "\n", " \"Open\n", "\n", "\n", @@ -157,11 +157,11 @@ "source": [ "# Construct input manually\n", "batch_size = 256\n", - "\n", - "X_tensor = torch.zeros([batch_size, 12, 17])\n", + "NUM_TIMESTEPS = 12\n", + "X_tensor = torch.zeros([batch_size, NUM_TIMESTEPS, 17])\n", "latlons_tensor = torch.zeros([batch_size, 2])\n", "\n", - "dw_empty = torch.full([batch_size, 12], 9, device=device).long()\n", + "dw_empty = torch.full([batch_size, NUM_TIMESTEPS], 9, device=device).long()\n", "month_tensor = torch.full([batch_size], 1, device=device)\n", "\n", "# [0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 16 17 ]\n", @@ -269,6 +269,7 @@ " # Prepends HANDLER to each print statement to make it easier to find in logs.\n", " print(f\"HANDLER {VERSION}: {text}\")\n", "\n", + "# Custom TorchServe handler for the Presto model\n", "class ClassifierHandler(BaseHandler):\n", "\n", " def inference(self, data):\n", @@ -509,7 +510,9 @@ "source": [ "### 5c. Deploy model to endpoint\n", "> Deploying a model associates physical resources with the model so that it can serve online predictions with low latency.\n", - "https://cloud.google.com/vertex-ai/docs/general/deployment" + "https://cloud.google.com/vertex-ai/docs/general/deployment\n", + "\n", + "⚠️ The `Minimum Replica Count` represents the minimum amount of compute nodes started when a model is deployed is e2-standard-2 machine (\\$0.0771/node hour in us-central-1). So as long as the endpoint is active you will be paying \\$0.0771/hour even if no predictions are made." 
] }, { diff --git a/deploy/2_Generate_Embeddings.js b/deploy/2_Generate_Embeddings.js index efe8b7f..49f78cb 100644 --- a/deploy/2_Generate_Embeddings.js +++ b/deploy/2_Generate_Embeddings.js @@ -14,7 +14,7 @@ var rangeEnd = ee.Date('2020-03-01'); var ENDPOINT = 'projects/presto-deployment/locations/us-central1/endpoints/vertex-pytorch-presto-endpoint'; -var RUN_VERTEX_AI = true; // Leave this as false to get a cost estimate first +var RUN_VERTEX_AI = false; // Leave this as false to get a cost estimate first //------------------------------------------------------------------------------------ Map.centerObject(roi, 10); diff --git a/deploy/3_Kmeans_Embeddings.js b/deploy/3_Kmeans_Embeddings.js index 2d66034..bcd1b17 100644 --- a/deploy/3_Kmeans_Embeddings.js +++ b/deploy/3_Kmeans_Embeddings.js @@ -10,7 +10,7 @@ var embeddings = ee.Image( var roi = embeddings.geometry({ geodesics: true }); Map.centerObject(roi, 11); -// 2. Cluster embeddings and display +// 2. Cluster embeddings and display (7 clusters) var training = embeddings.sample({ region: roi, scale: 10, numPixels: 10000 }); var trainedClusterer = ee.Clusterer.wekaKMeans(7).train(training); var result = embeddings.cluster(trainedClusterer);