diff --git a/DP0.2/20_Introduction_to_Data_Science.ipynb b/DP0.2/20_Introduction_to_Data_Science.ipynb new file mode 100644 index 00000000..995e6d81 --- /dev/null +++ b/DP0.2/20_Introduction_to_Data_Science.ipynb @@ -0,0 +1,1380 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5eb2606b-2724-43be-83bb-fb4df67d0b35", + "metadata": {}, + "source": [ + "# Introduction to Data Science\n", + " \n", + "
\n", + "Contact author(s): Becky Nevin, Brian Nord
\n", + "Last verified to run: 2025-05-09
\n", + "LSST Science Pipelines version: Weekly 2025_09
\n", + "Container size: small
\n", + "Targeted learning level: intermediate
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a4e60f3-a825-4dcc-8cd5-37ed43b7373b", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext pycodestyle_magic\n", + "%flake8_on\n", + "import logging\n", + "logging.getLogger(\"flake8\").setLevel(logging.FATAL)" + ] + }, + { + "cell_type": "markdown", + "id": "05e4d2c6-9774-48db-930e-58536bd69de0", + "metadata": {}, + "source": [ + "**Description:** This notebook guides a PI through the process of using python's data science and machine learning libraries to explore data from complex ADQL queries with the TAP service. The goal is to build a predictive model to estimate missing $r-$band Kron Flux values when the other bands are available, and visualize the results and quantify the model performance." + ] + }, + { + "cell_type": "markdown", + "id": "242eb338-4d88-4a17-8e35-20d80774ec1b", + "metadata": {}, + "source": [ + "**Skills:** Use of data science and machine learning tools such as scikit-learn, pandas, and seaborn." + ] + }, + { + "cell_type": "markdown", + "id": "76ffcb7f-a070-47b7-8545-f10813ff78ae", + "metadata": {}, + "source": [ + "**LSST Data Products:** Object, Forcedsource, and CcdVisit tables." + ] + }, + { + "cell_type": "markdown", + "id": "bf4d0080-92f9-4a67-a829-5a772e428ac4", + "metadata": {}, + "source": [ + "**Packages:** lsst.rsp, pandas, scikit-learn, seaborn" + ] + }, + { + "cell_type": "markdown", + "id": "e354c692-b8b4-425f-bcbe-a35221a9993a", + "metadata": {}, + "source": [ + "**Credits:** Developed by Becky Nevin in collaboration with Melissa Graham, Brian Nord, and the Rubin Community Science Team for DP0.2. Based on notebooks developed by Leanne Guy (TAP query) and Alex Drlica-Wagner and Melissa Graham (Butler query).\n", + "Please consider acknowledging them if this notebook is used for the preparation of journal articles, software releases, or other notebooks." 
+ ] + }, + { + "cell_type": "markdown", + "id": "2e3e438c-a028-417f-a21f-1d31c5946376", + "metadata": {}, + "source": [ + "**Get Support:**\n", + "Find DP0-related documentation and resources at dp0.lsst.io. Questions are welcome as new topics in the Support - Data Preview 0 Category of the Rubin Community Forum. Rubin staff will respond to all questions posted there." + ] + }, + { + "cell_type": "markdown", + "id": "ca5378d1-00fe-44ad-be88-5f1a0a85b404", + "metadata": {}, + "source": [ + "# 1. Introduction\n", + "\n", + "This notebook provides an intermediate-level demonstration of how to use the Table Access Protocol (TAP) server and ADQL (Astronomical Data Query Language) to query and retrieve data from the DP0.2 catalogs.\n", + "\n", + "TAP provides standardized access to catalog data for discovery, search, and retrieval.\n", + "Full documentation for TAP is provided by the International Virtual Observatory Alliance (IVOA).\n", + "ADQL is similar to SQL (Structured Query Language).\n", + "The documentation for ADQL includes more information about syntax and keywords.\n", + "Note that not all ADQL functionality is supported yet in the DP0-era RSP.\n", + "\n", + "**See the recommendations for TAP queries in DP0.2 tutorial 02a \"Introduction to the TAP Service\".**\n", + "\n", + "The [documentation for Data Preview 0.2](https://dp0-2.lsst.io/) includes definitions\n", + "of the data products, descriptions of catalog contents, and ADQL recipes.\n", + "\n", + "## 1.1. Package imports\n", + "\n", + "Import general python packages, the Rubin TAP service utilities, and various scikit-learn utilities."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f4900a4-3358-472a-b9ba-c42e3f2f0771", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import pandas\n", + "\n", + "from astropy import units as u\n", + "from astropy.coordinates import SkyCoord\n", + "\n", + "from sklearn.model_selection import train_test_split, GridSearchCV\n", + "from sklearn.preprocessing import FunctionTransformer, StandardScaler\n", + "from sklearn.pipeline import make_pipeline\n", + "from sklearn.linear_model import LinearRegression\n", + "from sklearn.metrics import mean_squared_error\n", + "from sklearn.ensemble import RandomForestRegressor\n", + "import sklearn.metrics as metrics\n", + "import inspect\n", + "from sklearn.utils import all_estimators\n", + "\n", + "from lsst.rsp import get_tap_service, retrieve_query" + ] + }, + { + "cell_type": "markdown", + "id": "90251edc-e77a-4f2c-aef3-935381faebc9", + "metadata": {}, + "source": [ + "Set up seaborn to use a friendly palette." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94acc9f6-2033-4ace-aefd-d036a35f4221", + "metadata": {}, + "outputs": [], + "source": [ + "sns.set_style(\"whitegrid\")\n", + "palette = sns.color_palette(\"colorblind\")\n", + "sns.set_palette(palette)" + ] + }, + { + "cell_type": "markdown", + "id": "ca1e28f4-805e-4480-a7c6-0473b7e2b088", + "metadata": {}, + "source": [ + "## 1.2. Define functions and parameters\n", + "\n", + "Instantiate the TAP service." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "caf56589-100a-4481-8f24-5f5058b6671f", + "metadata": {}, + "outputs": [], + "source": [ + "service = get_tap_service(\"tap\")\n", + "assert service is not None" + ] + }, + { + "cell_type": "markdown", + "id": "9eb0f20e-6c28-404f-8032-acc00f73405a", + "metadata": {}, + "source": [ + "Set the maximum number of rows to display from pandas." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b7b6002-2457-4c20-a03e-6bfa24a0aa27", + "metadata": {}, + "outputs": [], + "source": [ + "pandas.set_option(\"display.max_rows\", 6)" + ] + }, + { + "cell_type": "markdown", + "id": "cd325fe4-6c7c-4803-ad79-f30d7edc23e3", + "metadata": {}, + "source": [ + "# 2. Query for Kron fluxes around extended (galaxy) objects.\n", + "The Kron flux is a measurement of the total flux (or brightness) of an astronomical object, typically a galaxy or extended source, obtained using an elliptical aperture that scales with the object's light profile. It’s designed to include most of the object’s light while minimizing background contamination.\n", + "\n", + "The aperture is defined based on the Kron radius, which is calculated from the first moment of the light distribution. The resulting aperture is adaptive - it changes in size and shape depending on the morphology of the source." + ] + }, + { + "cell_type": "markdown", + "id": "6b4f495d-1215-421d-bdb0-bc32fec92c25", + "metadata": {}, + "source": [ + "Define the coordinates and radius to use for the example queries in Sections 2 and 3." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ddd0344-b354-45a0-9e5a-755149c9bc54", + "metadata": {}, + "outputs": [], + "source": [ + "center_ra = 62\n", + "center_dec = -37\n", + "radius = 0.1\n", + "\n", + "str_center_coords = str(center_ra) + \", \" + str(center_dec)\n", + "str_radius = str(radius)" + ] + }, + { + "cell_type": "markdown", + "id": "dd80babb-ee05-49e9-9f9c-923d5c0cee31", + "metadata": {}, + "source": [ + "Start with the same query as used in the beginner TAP tutorial notebook 02a. Note that the extendedness flag in the $g-$band is used to select for galaxies." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "985e3b62-8065-42ec-a40c-1232c4c45f17", + "metadata": {}, + "outputs": [], + "source": [ + "query = \"SELECT coord_ra, coord_dec, g_kronFlux, g_kronFlux_flag, \"\\\n", + " \"r_kronFlux, r_kronFlux_flag, i_kronFlux, i_kronFlux_flag \"\\\n", + " \"FROM dp02_dc2_catalogs.Object \"\\\n", + " \"WHERE CONTAINS(POINT('ICRS', coord_ra, coord_dec), \"\\\n", + " \"CIRCLE('ICRS', \" + str_center_coords + \", \" + str_radius + \")) = 1 \"\\\n", + " \"AND detect_isPrimary = 1 AND g_extendedness = 1\"\n", + "print(query)" + ] + }, + { + "cell_type": "markdown", + "id": "f024085b-0f7f-45c6-8184-41b528c15396", + "metadata": {}, + "source": [ + "Run the query job asynchronously." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c02adc91-5f5e-418b-87a3-cba8beba7dd2", + "metadata": {}, + "outputs": [], + "source": [ + "job = service.submit_job(query)\n", + "job.run()\n", + "job.wait(phases=[\"COMPLETED\", \"ERROR\"])\n", + "print(\"Job phase is\", job.phase)" + ] + }, + { + "cell_type": "markdown", + "id": "80b28a39-bd12-49d6-9cce-f8ddfc31296c", + "metadata": {}, + "source": [ + "# 3. Explore the data using a `pandas` DataFrame object.\n", + "From the `pandas` docs:\n", + "> A DataFrame is a two-dimensional, size-mutable, heterogeneous tabular data structure with labeled axes (rows and columns)." + ] + }, + { + "cell_type": "markdown", + "id": "07d1cfb1-589b-402b-8b8f-c2c70652b6c6", + "metadata": {}, + "source": [ + "Return the results as a `pandas` DataFrame, and then delete the query to save space." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8cd2f538-c2d7-44ca-ab4d-825120b8f2e7", + "metadata": {}, + "outputs": [], + "source": [ + "results = job.fetch_result().to_table().to_pandas()\n", + "job.delete()\n", + "del query" + ] + }, + { + "cell_type": "markdown", + "id": "27888222-8fca-4620-a838-9260b0f5f47f", + "metadata": {}, + "source": [ + "Display `results`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee4d121e-6b4d-4371-afae-4f7587b95d51", + "metadata": {}, + "outputs": [], + "source": [ + "results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "db2168fe-593a-423d-b2f4-26ac0db60e8c", + "metadata": {}, + "outputs": [], + "source": [ + "type(results)" + ] + }, + { + "cell_type": "markdown", + "id": "9c59c4e8-90bd-4aa8-8ce2-9a14c09988a0", + "metadata": {}, + "source": [ + "There are many options for investigating DataFrame objects. Some are inspection- and summary-oriented, such as the `.head()`, `.tail()`, and `.describe()` methods.\n", + "\n", + "Check these out now. `.head()` and `.tail()` show the first and last five rows by default, respectively, but accept an argument to print a different number of rows. `.describe()` provides summary statistics of the distribution of values in each numeric column, including the mean and standard deviation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eec25f58-d3f3-4ef4-b3e2-ab105c4718fd", + "metadata": {}, + "outputs": [], + "source": [ + "print(results.head())\n", + "print(results.tail(10))\n", + "print(results.describe())" + ] + }, + { + "cell_type": "markdown", + "id": "1b37fc18-e00b-4ef2-8d18-85d823d60d9e", + "metadata": {}, + "source": [ + "# 4. Visualize the data using `seaborn`\n", + "Use the boxplot tool from `seaborn` to visualize the distribution of the values in each column of the DataFrame."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4fc9b578-2be4-4fb2-8d74-ebca809ea99f", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(8, 6))\n", + "sns.boxplot(data=results)\n", + "plt.title('Box Plot of Data Distributions')\n", + "plt.xlabel('Feature')\n", + "plt.ylabel('Value')\n", + "plt.xticks(rotation=90);" + ] + }, + { + "cell_type": "markdown", + "id": "5a1199ac-f06b-46be-8a98-9b6ca175eb8f", + "metadata": {}, + "source": [ + "> A boxplot of the following columns: `coord_ra`, `coord_dec`, `g_kronFlux`, `g_kronFlux_flag`, `r_kronFlux`, `r_kronFlux_flag`, `i_kronFlux`, and `i_kronFlux_flag`. The scaling only allows us to see the outlier circles for all of these columns, not the actual boxplot. This will need to be rescaled." + ] + }, + { + "cell_type": "markdown", + "id": "ef5f20de-0fd6-4ba1-9cab-9d59cd05df99", + "metadata": {}, + "source": [ + "The outliers (points far from the majority of the distribution) dominate the visualization. Hide these and plot only the Kron flux values." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bacf5114-6a64-4100-8eb6-f1d9ddc36f89", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(8, 6))\n", + "sns.boxplot(data=results[['g_kronFlux','r_kronFlux','i_kronFlux']], showfliers=False)\n", + "plt.title('Box Plot of Data Distributions')\n", + "plt.xlabel('Feature')\n", + "plt.ylabel('Value')\n", + "plt.xticks(rotation=90);" + ] + }, + { + "cell_type": "markdown", + "id": "22a861da-38af-47c9-a29a-782fea694615", + "metadata": {}, + "source": [ + "> A boxplot of the following columns: `g_kronFlux`, `r_kronFlux`, and `i_kronFlux`. The boxplots are weighted toward lower values (<1000), with the whiskers extending to negative values for all columns."
+ ] + }, + { + "cell_type": "markdown", + "id": "75d6336b-1068-46cf-8105-b945fd18020e", + "metadata": {}, + "source": [ + "Boxplots show a box and whiskers.\n", + "- The \"box\" is the interquartile range (IQR), which spans the 25th to the 75th percentile of the distribution.\n", + "- The horizontal line inside the box is the median of the distribution.\n", + "- Each whisker extends from the edge of the box to the most extreme data point within 1.5*IQR of that edge.\n", + "- Points outside the whiskers are considered outliers (hidden here).\n" + ] + }, + { + "cell_type": "markdown", + "id": "448bec6a-15e1-49b5-a52e-bdd743ff207a", + "metadata": {}, + "source": [ + "Use `seaborn`'s violinplot tool to visualize the distribution for these same Kron flux values, keeping only the values between the 5th and 95th percentiles of each band." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39521ac6-0bec-42e7-9062-8fc9ce5edc55", + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(8, 6))\n", + "filtered_results = results[['g_kronFlux', 'r_kronFlux', 'i_kronFlux']].apply(lambda x: x[(x > x.quantile(0.05)) & (x < x.quantile(0.95))])\n", + "sns.violinplot(data=filtered_results,\n", + " cut=0)\n", + "plt.title('Violin Plot of Data Distributions')\n", + "plt.xlabel('Feature')\n", + "plt.ylabel('Value')\n", + "plt.xticks(rotation=90);" + ] + }, + { + "cell_type": "markdown", + "id": "9e2aa0ec-7ea6-4d65-94ee-b44ef7a673b6", + "metadata": {}, + "source": [ + "> A violinplot of the following columns: `g_kronFlux`, `r_kronFlux`, and `i_kronFlux`. The violinplots are weighted toward lower values (<1000), and also include boxplots inside." + ] + }, + { + "cell_type": "markdown", + "id": "4f2c50a8-098e-4230-b122-3880c8bb1883", + "metadata": {}, + "source": [ + "A violinplot gives much of the same information as a boxplot. In fact, there is a small boxplot inside each violin: the horizontal white line is the median, the thicker grey box is the IQR, and the thin line shows the 1.5*IQR span. 
A violinplot also uses a kernel density estimator to visualize the distribution of each feature. These plots reveal that most of the data are clustered around relatively low values for all of the Kron fluxes." + ] + }, + { + "cell_type": "markdown", + "id": "3ec470a9-8db0-403a-89ba-32e0dd9bef15", + "metadata": {}, + "source": [ + "# 5. Clean the data" + ] + }, + { + "cell_type": "markdown", + "id": "18dba188-ea4c-47e9-8faf-62e3552add1e", + "metadata": {}, + "source": [ + "Use `pandas` to investigate if there are any flags on the `kronFlux` measurement. The `.value_counts()` method will show the number of `True` and `False` values in the column, where `True` marks rows for which the `g_kronFlux` measurement was flagged for any of a variety of reasons. There are many other columns that record specific reasons why this measurement is untrustworthy; the `g_kronFlux_flag` is a way to combine all of the individual flags. When this flag is set to `True`, the row is flagged." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0be4535d-cc89-45ef-98e9-591b9f459fae", + "metadata": {}, + "outputs": [], + "source": [ + "results['g_kronFlux_flag'].value_counts()" + ] + }, + { + "cell_type": "markdown", + "id": "24e94c30-66b3-4b9c-b79b-bcf024d5f214", + "metadata": {}, + "source": [ + "Now check the `r_kronFlux` measurement." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e66ccb2-3922-471b-8c15-7fb055d02a10", + "metadata": {}, + "outputs": [], + "source": [ + "results['r_kronFlux_flag'].value_counts()" + ] + }, + { + "cell_type": "markdown", + "id": "fa7a4ce1-7bd4-47f8-a480-188b2c70579a", + "metadata": {}, + "source": [ + "Perform an intersection to see if the flagged entries overlap between these two photometric bands."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06786c33-2563-4237-9d0f-22d6308c0d7b", + "metadata": {}, + "outputs": [], + "source": [ + "r_values = set(results['r_kronFlux_flag'].unique())\n", + "g_values = set(results['g_kronFlux_flag'].unique())\n", + "\n", + "overlap = r_values & g_values\n", + "\n", + "overlap_true_rows = results[\n", + " (results['r_kronFlux_flag'].isin(overlap)) & \n", + " (results['g_kronFlux_flag'].isin(overlap)) & \n", + " (results['r_kronFlux_flag'] == True) & \n", + " (results['g_kronFlux_flag'] == True)\n", + "]\n", + "\n", + "print(overlap_true_rows)" + ] + }, + { + "cell_type": "markdown", + "id": "ec13b104-ad8d-4bd6-8a93-b6d1d57b921e", + "metadata": {}, + "source": [ + "There are many overlapping rows, meaning that in these cases, both photometric bands are flagged. Since the task at hand is a prediction one between three bands ($g$, $r$, and $i$) the bigger concern is the cases where any of these individual three Kron fluxes are flagged. Exclude rows where this is the case and build an \"unflagged\" DataFrame." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6294681-9c60-4ec6-805c-d378300acaa3", + "metadata": {}, + "outputs": [], + "source": [ + "unflagged_df = results[\n", + " (results['r_kronFlux_flag'] == False) & \n", + " (results['g_kronFlux_flag'] == False) &\n", + " (results['i_kronFlux_flag'] == False)\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "38a84cda-eedb-4238-9d27-f91cb9f56531", + "metadata": {}, + "source": [ + "Visualize the relationship between $g$ and $r$ in this unflagged DataFrame." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61a66274-c3e1-4e41-b743-649fc00d69b7", + "metadata": {}, + "outputs": [], + "source": [ + "plt.scatter(unflagged_df['g_kronFlux'], unflagged_df['r_kronFlux'])\n", + "plt.xlabel(r'$g-$band kronFlux [nJy]')\n", + "plt.ylabel(r'$r-$band kronFlux [nJy]');" + ] + }, + { + "cell_type": "markdown", + "id": "ae55eb3a-8a47-44e5-8242-38467b215cde", + "metadata": {}, + "source": [ + "> Scatter plot of $g-$band kronFlux [nJy] (x-axis) versus $r-$band kronFlux [nJy] (y-axis). The blue points appear to be roughly linear in this space with more concentration towards lower values and spread at higher values. The few points at high value have values on order 1e6 and there is a concentration around values less than 0.5e5 (in both axes)." + ] + }, + { + "cell_type": "markdown", + "id": "1b76242e-6a0d-4f0c-ac35-fa27d4fef24a", + "metadata": {}, + "source": [ + "Zoom in." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5afedb17-6478-4f2b-bdfc-38e73cd4a65e", + "metadata": {}, + "outputs": [], + "source": [ + "plt.scatter(unflagged_df['g_kronFlux'], unflagged_df['r_kronFlux'])\n", + "plt.xlabel(r'$g-$band kronFlux [nJy]')\n", + "plt.ylabel(r'$r-$band kronFlux [nJy]')\n", + "plt.xlim([0,30000])\n", + "plt.ylim([0,0.5e5]);" + ] + }, + { + "cell_type": "markdown", + "id": "efffb2cf-1874-42f3-a646-fbe8753774c6", + "metadata": {}, + "source": [ + "> Scatter plot of $g-$band kronFlux [nJy] (x-axis) versus $r-$band kronFlux [nJy] (y-axis). The blue points appear to be roughly linear in this space with more concentration towards lower values and spread at higher values." + ] + }, + { + "cell_type": "markdown", + "id": "2f503394-3816-4d31-9cf0-9de88e229f87", + "metadata": {}, + "source": [ + "There does seem to be a relationship between $g-$ and $r-$band Kron fluxes, meaning that it should be possible to do some predictive work here."
+ ] + }, + { + "cell_type": "markdown", + "id": "ca8c22a0-485a-4525-af2f-1e4c8bf74cd0", + "metadata": {}, + "source": [ + "Another thing to check before starting to fit is whether there are any NaNs in these Series. (A Series is a single column of a DataFrame.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01fcf2a8-7f85-4ec0-8ebe-b69b76da7294", + "metadata": {}, + "outputs": [], + "source": [ + "print(unflagged_df['g_kronFlux'].isna().any(), unflagged_df['r_kronFlux'].isna().any())" + ] + }, + { + "cell_type": "markdown", + "id": "62479245-028e-4709-9c46-bc244463176c", + "metadata": {}, + "source": [ + "Both outputs should be `False`." + ] + }, + { + "cell_type": "markdown", + "id": "2309022c-3530-4f3c-9f7a-7d51f3537afe", + "metadata": {}, + "source": [ + "Another important step is to check for negative (or zero) values in these data; they are fluxes so they should not be negative or equal to zero. First, count the number of these values in each of the $g-$ and $r-$band Series." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47a125c4-77fa-4712-be40-241318966774", + "metadata": {}, + "outputs": [], + "source": [ + "neg_g = (unflagged_df['g_kronFlux'] <= 0).sum()\n", + "neg_r = (unflagged_df['r_kronFlux'] <= 0).sum()\n", + "\n", + "print(f\"Negative or zero values in g_kronFlux: {neg_g}\")\n", + "print(f\"Negative or zero values in r_kronFlux: {neg_r}\")" + ] + }, + { + "cell_type": "markdown", + "id": "7a5d982a-939d-4e84-afbb-abe48975d502", + "metadata": {}, + "source": [ + "Uh oh! This is a big problem. Mask all rows of the DataFrame where any individual flux measurement ($g$, $r$, or $i$) is negative or zero, making a \"clean\" DataFrame."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffbe5670-6de4-4a92-99f7-4e480789b596", + "metadata": {}, + "outputs": [], + "source": [ + "mask = (unflagged_df['g_kronFlux'] > 0) & (unflagged_df['r_kronFlux'] > 0) & (unflagged_df['i_kronFlux'] > 0)\n", + "clean_df = unflagged_df[mask]" + ] + }, + { + "cell_type": "markdown", + "id": "20799ed0-9d75-4a85-8a7c-57cebbda85f6", + "metadata": {}, + "source": [ + "It's always good to double check that the above mask worked. Below, both counts of negative or zero fluxes in `clean_df` should be zero." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "782e7e2e-372e-4fcb-b837-a19a3bc83511", + "metadata": {}, + "outputs": [], + "source": [ + "neg_g = (clean_df['g_kronFlux'] <= 0).sum()\n", + "neg_r = (clean_df['r_kronFlux'] <= 0).sum()\n", + "print(f\"Negative or zero values in g_kronFlux: {neg_g}\")\n", + "print(f\"Negative or zero values in r_kronFlux: {neg_r}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4704605a-4665-4ccc-bd7e-cefaf5e09828", + "metadata": {}, + "source": [ + "# 6. Prepare the training and test sets\n", + "The goal is to predict the $r-$band Kron flux using the $g-$band Kron flux. The first step is to define the training and test data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc90feca-ede1-44b0-929b-2fec1ddf5ad4", + "metadata": {}, + "outputs": [], + "source": [ + "X_train, X_test, y_train, y_test = train_test_split(\n", + " clean_df['g_kronFlux'].to_frame(),\n", + " clean_df['r_kronFlux'].to_frame(),\n", + " test_size=0.2, random_state=42)" + ] + }, + { + "cell_type": "markdown", + "id": "37d77db2-1a91-4696-831f-d426f3ca7539", + "metadata": {}, + "source": [ + "The `.to_frame()` method converts each Series to a single-column DataFrame, which gives the X data the 2D shape expected by `scikit-learn`."
+ ] + }, + { + "cell_type": "markdown", + "id": "ecb487df-ce82-4d2d-9821-64620e1e922b", + "metadata": {}, + "source": [ + "It is common practice to standardize the data when training machine learning models. Transform the training and test data with a standard scaler, which by default rescales to a mean of zero and a standard deviation of one.\n", + "\n", + "**Note to future developer of this notebook: I (Becky) was wondering if a log transform is also necessary because the data seems logarithmically scaled. I'm not sure if it's needed, but I've set this up to also include an invertible log transform if you want to test it.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f09a28c9-f868-4cfa-a309-c53b26193e01", + "metadata": {}, + "outputs": [], + "source": [ + "log_scaler = FunctionTransformer(func=np.log, inverse_func=np.exp, validate=True)\n", + "\n", + "X_scaler = make_pipeline(\n", + " #log_scaler,\n", + " StandardScaler()\n", + ")\n", + "\n", + "y_scaler = make_pipeline(\n", + " #log_scaler,\n", + " StandardScaler()\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9c8aeb31-3cbd-461f-88de-3e60e9acb4a7", + "metadata": {}, + "source": [ + "Fit each scaler on the training data and transform it, then apply the same fitted scaler to the test data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec1efab7-be4c-4bdc-9ead-78fd6b400345", + "metadata": {}, + "outputs": [], + "source": [ + "X_train_scaled = X_scaler.fit_transform(X_train)\n", + "X_test_scaled = X_scaler.transform(X_test)\n", + "\n", + "y_train_scaled = y_scaler.fit_transform(y_train)\n", + "y_test_scaled = y_scaler.transform(y_test)" + ] + }, + { + "cell_type": "markdown", + "id": "f990ce31-fb46-4c64-b56c-b26781c8e571", + "metadata": {}, + "source": [ + "Check that the resultant distribution for training data has a center of roughly zero."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5be7b1b1-2177-46e5-ad41-607e55d12949", + "metadata": {}, + "outputs": [], + "source": [ + "plt.clf()\n", + "plt.hist(X_train_scaled, range=[-2,2]);" + ] + }, + { + "cell_type": "markdown", + "id": "5e0d7b9a-b383-43c2-bb4a-826586326484", + "metadata": {}, + "source": [ + "> A histogram with a x-axis range from -2 to 2 for the distribution of the rescaled `X_train` data." + ] + }, + { + "cell_type": "markdown", + "id": "8dfc5412-e4a5-4f32-a27c-25acb1a8c54c", + "metadata": {}, + "source": [ + "Perform the inverse transform to check if the twice transformed data is the same as the original." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4086f5f-7c2d-4d6c-a2d4-8dff42e2fc84", + "metadata": {}, + "outputs": [], + "source": [ + "X_train_tt = X_scaler.inverse_transform(X_train_scaled)\n", + "X_test_tt = X_scaler.inverse_transform(X_test_scaled)\n", + "\n", + "y_train_tt = y_scaler.inverse_transform(y_train_scaled)\n", + "y_test_tt = y_scaler.inverse_transform(y_test_scaled)" + ] + }, + { + "cell_type": "markdown", + "id": "46ca963a-c142-438b-91a7-c6e30904e45e", + "metadata": {}, + "source": [ + "Check that the original `X_train` is the same as the twice transformed `X_train_tt`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc4cecc6-f256-4f55-8d54-f585630e8b4f", + "metadata": {}, + "outputs": [], + "source": [ + "plt.clf()\n", + "plt.hist(X_train_tt, range=[0,1500], label='twice transformed', alpha=0.5, color='red')\n", + "plt.hist(X_train, range=[0,1500], label='original', alpha=0.5, color='orange')\n", + "plt.legend();" + ] + }, + { + "cell_type": "markdown", + "id": "ad501069-4509-40d2-b05d-03d800e40296", + "metadata": {}, + "source": [ + "> Overlapping histograms for `X_train` and `X_train_tt` with an x-axis range of 0 to 1500." 
+ ] + }, + { + "cell_type": "markdown", + "id": "6011b2e6-5197-4475-8eb4-3a8841b3bd28", + "metadata": {}, + "source": [ + "# 7. Model (using `scikit-learn`)" + ] + }, + { + "cell_type": "markdown", + "id": "48c8d648-288f-4635-a961-248f127fe524", + "metadata": {}, + "source": [ + "## 7.1 Start with a linear regression" + ] + }, + { + "cell_type": "markdown", + "id": "b99ff86e-f364-4edb-8c4f-bdd645a525ae", + "metadata": {}, + "source": [ + "Instantiate the model and then fit it with the training data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e02ca479-6105-442b-9879-2eb215dc4d66", + "metadata": {}, + "outputs": [], + "source": [ + "model_lr = LinearRegression()\n", + "model_lr.fit(X_train_scaled, y_train_scaled)" + ] + }, + { + "cell_type": "markdown", + "id": "b2fe11b4-a9ad-45f9-9c1e-e25bb3a0301c", + "metadata": {}, + "source": [ + "Now predict the test data. Compare the prediction to the true value. Use the mean squared error function to get a diagnostic on how well the fit is behaving for the test set." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aca8064a-c94f-4792-86b3-62ec2130471b", + "metadata": {}, + "outputs": [], + "source": [ + "y_pred = model_lr.predict(X_test_scaled)\n", + "mse = mean_squared_error(y_test_scaled, y_pred)\n", + "print(\"MSE:\", mse)" + ] + }, + { + "cell_type": "markdown", + "id": "c9086695-43c6-49e5-a751-bcabc275172d", + "metadata": {}, + "source": [ + "Plot how the predicted values compare to the true values." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee8bd887-928e-4a41-bd77-149b344ab238", + "metadata": {}, + "outputs": [], + "source": [ + "plt.clf()\n", + "plt.scatter(y_test_scaled, y_pred)\n", + "plt.plot([-4,6],[-4,6], color='black', ls='--')\n", + "plt.xlabel('True')\n", + "plt.ylabel('Predicted')\n", + "plt.xlim([-5,5])\n", + "plt.ylim([-5,5]);" + ] + }, + { + "cell_type": "markdown", + "id": "f33c5476-7662-4e5d-95e5-a2103e6ab2d8", + "metadata": {}, + "source": [ + "> Scatter plot of true versus predicted value for the scaled $r-$band Kron flux. A 1:1 line shows the expected value if the linear regression were a perfect predictor." + ] + }, + { + "cell_type": "markdown", + "id": "8367d3db-3527-4cab-abdb-51391259ec02", + "metadata": {}, + "source": [ + "Transform the predictions back to the true values (remember the current scaling is standardized) and look at the results in that space." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6770c381-1f34-44a7-b93b-b7f915013620", + "metadata": {}, + "outputs": [], + "source": [ + "y_pred_tt = y_scaler.inverse_transform(y_pred)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f83cf984-ffa7-4486-98ec-bce890b91bad", + "metadata": {}, + "outputs": [], + "source": [ + "plt.clf()\n", + "plt.scatter(y_test_tt, y_pred_tt)\n", + "plt.plot([-1e3,2e4],[-1e3,2e4], color='black', ls='--')\n", + "plt.xlabel('True')\n", + "plt.ylabel('Predicted');\n", + "plt.xlim([-1e3,2e4])\n", + "plt.ylim([-1e3,2e4]);" + ] + }, + { + "cell_type": "markdown", + "id": "c2b12a25-236e-4a1b-804b-03b97960712c", + "metadata": {}, + "source": [ + "> A scatterplot of the true versus predicted value for the $r-$band Kron flux scaled back to the original values. The x- and y-axes scale between -1000 to 2e4." + ] + }, + { + "cell_type": "markdown", + "id": "da4c11bc-53bb-48b5-a89f-aeb2628fd77d", + "metadata": {}, + "source": [ + "Also get the MSE for the rescaled y value." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb653dbd-4e77-4099-80ac-781f20cb38eb", + "metadata": {}, + "outputs": [], + "source": [ + "mse = mean_squared_error(y_test_tt, y_pred_tt)\n", + "print(\"MSE:\", mse)" + ] + }, + { + "cell_type": "markdown", + "id": "1de26eb9-e7cf-4e94-bf14-0fda5d8efe18", + "metadata": {}, + "source": [ + "## 7.2 Improve the model with more features\n", + "Let's see if this will improve with more predictive values, this time including the $i-$band Kron flux. Split the training/test sets, perform the standard scaling, and linear regression. This time, do not scale the y-value $r-$band Kron flux." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d80e56eb-6e3d-4110-bee6-3681ee4a923b", + "metadata": {}, + "outputs": [], + "source": [ + "X_train_i, X_test_i, y_train_i, y_test_i = train_test_split(\n", + " clean_df[['g_kronFlux', 'i_kronFlux']],\n", + " clean_df['r_kronFlux'],\n", + " test_size=0.2,\n", + " random_state=42\n", + ")\n", + "\n", + "scaler = StandardScaler()\n", + "X_train_i_scaled = scaler.fit_transform(X_train_i)\n", + "X_test_i_scaled = scaler.transform(X_test_i)\n", + "\n", + "model_mlr = LinearRegression()\n", + "model_mlr.fit(X_train_i_scaled, y_train_i)\n", + "\n", + "y_pred_i = model_mlr.predict(X_test_i_scaled)\n", + "\n", + "mse = mean_squared_error(y_test_i, y_pred_i)\n", + "print(\"MSE:\", mse)\n", + "\n", + "plt.clf()\n", + "plt.scatter(y_test_i, y_pred_i)\n", + "plt.plot([0, 1e6], [0, 1e6], color='black', ls='--')\n", + "plt.xlabel('True')\n", + "plt.ylabel('Predicted')\n", + "plt.xlim([0, 2e4])\n", + "plt.ylim([0, 2e4]);" + ] + }, + { + "cell_type": "markdown", + "id": "5f72c9ad-547a-455a-8b0e-39674e94fe43", + "metadata": {}, + "source": [ + "> A scatter plot demonstrating a tighter fit when the $i-$band Kron flux is included as a predictive feature." 
+ ] + }, + { + "cell_type": "markdown", + "id": "30cd4c78-35e5-4459-84bf-9642068da002", + "metadata": {}, + "source": [ + "Exercise for the reader: try to improve the model further by including more features. The MSE has already improved relative to the value above!" + ] + }, + { + "cell_type": "markdown", + "id": "89e92033-ff41-4662-b5aa-ece41c216a07", + "metadata": {}, + "source": [ + "## 7.3 Random forest regressor\n", + "Random forests are a versatile type of model for regression or classification. Use one here to perform regression, again including the $i-$band Kron flux as a feature." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "efb56f9d-6487-444d-b283-6d60c1694948", + "metadata": {}, + "outputs": [], + "source": [ + "model_rf = RandomForestRegressor()\n", + "model_rf.fit(X_train_i_scaled, y_train_i)\n", + "\n", + "y_pred_i_rf = model_rf.predict(X_test_i_scaled)\n", + "\n", + "\n", + "mse = mean_squared_error(y_test_i, y_pred_i_rf)\n", + "print(\"MSE:\", mse)\n", + "\n", + "plt.clf()\n", + "plt.scatter(y_test_i, y_pred_i_rf)\n", + "plt.plot([0, 1e6], [0, 1e6], color='black', ls='--')\n", + "plt.xlabel('True')\n", + "plt.ylabel('Predicted')\n", + "plt.xlim([0, 2e4])\n", + "plt.ylim([0, 2e4]);" + ] + }, + { + "cell_type": "markdown", + "id": "21b186b9-80b7-455e-8433-78c531050100", + "metadata": {}, + "source": [ + "> A scatter plot of true versus predicted $r-$band Kron flux for the random forest regressor." + ] + }, + { + "cell_type": "markdown", + "id": "13845e03-08fe-4f80-8883-a671e8afd9d5", + "metadata": {}, + "source": [ + "**Strange that the MSE is actually worse than linear regression.**" + ] + }, + { + "cell_type": "markdown", + "id": "4ed0b4e6-904a-4af8-8bfe-cad402de09a4", + "metadata": {}, + "source": [ + "# 8. Hyperparameter tuning\n", + "With any `scikit-learn` model, it's possible to tune the hyperparameters to achieve better performance. 
Define the grid to search over. Here, test the `n_estimators` parameter of the random forest regressor, which sets the number of trees in the forest. The default setting is 100." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6420904-44bb-40cf-8ce0-92b22a802e59", + "metadata": {}, + "outputs": [], + "source": [ + "param_grid = {'n_estimators': [1, 10, 50, 100, 200, 1000]}" + ] + }, + { + "cell_type": "markdown", + "id": "e337134d-22fe-42d3-83d2-9f22bb93e70e", + "metadata": {}, + "source": [ + "Now define the `GridSearchCV` object, which wraps the model, and fit it using 5-fold cross-validation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48a166c2-10b9-470f-bd03-3d7e6d51a1c7", + "metadata": {}, + "outputs": [], + "source": [ + "grid = GridSearchCV(model_rf, param_grid, cv=5)\n", + "\n", + "grid.fit(X_train_i_scaled, y_train_i)\n", + "\n", + "print(grid.best_params_)" + ] + }, + { + "cell_type": "markdown", + "id": "8445ad22-4275-4a56-ace7-ce1a678dd900", + "metadata": {}, + "source": [ + "Now retrieve the best model and print out model diagnostics."
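On synthetic data, the full search-then-retrieve pattern can be sketched as follows. By default `GridSearchCV` ranks regressors with the estimator's own `score` method (R^2); the `scoring='neg_mean_squared_error'` option used here is an illustrative choice that makes the ranking consistent with an MSE diagnostic, not something this notebook does. The small grid, `cv=3`, and the synthetic data are assumptions made to keep the example fast.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV

rng = np.random.default_rng(0)

# Synthetic two-feature regression problem
X = rng.normal(size=(200, 2))
y = 2.0 * X[:, 0] - X[:, 1] + rng.normal(scale=0.1, size=200)

grid = GridSearchCV(
    RandomForestRegressor(random_state=0),
    {'n_estimators': [5, 20]},
    cv=3,
    scoring='neg_mean_squared_error',
)
grid.fit(X, y)

# The best estimator is refit on the full training set by default
best_model = grid.best_estimator_
print("Best parameters:", grid.best_params_)
print("Cross-validated MSE of best model:", -grid.best_score_)
print("Mean test score per grid point:", grid.cv_results_['mean_test_score'])
```

Inspecting `cv_results_` shows how much the score actually changes across the grid, which helps judge whether the tuned parameter matters at all.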
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89167a5d-e835-4f8d-8d57-ac82f0df6239", + "metadata": {}, + "outputs": [], + "source": [ + "best_model = grid.best_estimator_\n", + "print(\"Best Model:\", best_model)\n", + "y_pred_grid = best_model.predict(X_test_i_scaled)\n", + "plt.clf()\n", + "plt.scatter(y_test_i, y_pred_grid)\n", + "plt.plot([0, 1e6], [0, 1e6], color='black', ls='--')\n", + "plt.xlabel('True')\n", + "plt.ylabel('Predicted')\n", + "plt.xlim([0, 2e4])\n", + "plt.ylim([0, 2e4])\n", + "plt.show()\n", + "\n", + "mse = mean_squared_error(y_test_i, y_pred_grid)\n", + "print(\"MSE:\", mse)" + ] + }, + { + "cell_type": "markdown", + "id": "e6438d53-e03b-4afa-a417-489064caf51e", + "metadata": {}, + "source": [ + "**To do: The multiple linear regression currently has a better MSE than this model, investigate why that is a little more.**" + ] + }, + { + "cell_type": "markdown", + "id": "ec3e2b39-f86e-4dfa-a97a-9e018fc2208c", + "metadata": {}, + "source": [ + "# 9. Fit the best model to rows where the $r-$band Kron flux is missing." + ] + }, + { + "cell_type": "markdown", + "id": "98f817da-d4cd-40d2-be10-ba9c1669e68e", + "metadata": {}, + "source": [ + "First, return to the dataframe and snag rows where the $r-$band flux is flagged but the $g-$ and $i-$band Kron fluxes are intact." 
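The row selection can be illustrated on a toy table. This sketch assumes the flag columns have boolean dtype, in which case the `~` operator negates a mask and `&` combines masks; the four rows are invented for illustration.

```python
import pandas

# Toy flags: only the first row has r flagged while g and i are clean
toy = pandas.DataFrame({
    'g_kronFlux_flag': [False, False, True, False],
    'r_kronFlux_flag': [True, False, True, True],
    'i_kronFlux_flag': [False, False, False, True],
})

# r flagged AND g not flagged AND i not flagged
mask = (
    toy['r_kronFlux_flag']
    & ~toy['g_kronFlux_flag']
    & ~toy['i_kronFlux_flag']
)
selected = toy[mask]
print(selected)
```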
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a020b226-289d-4587-9e42-47f64574dbf6", + "metadata": {}, + "outputs": [], + "source": [ + "r_missing_df = results[\n", + " (results['r_kronFlux_flag'] == True) &\n", + " (results['g_kronFlux_flag'] == False) &\n", + " (results['i_kronFlux_flag'] == False)\n", + "]\n", + "r_missing_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd7a891e-5a56-440e-a721-91af73d74269", + "metadata": {}, + "outputs": [], + "source": [ + "X_r_missing = r_missing_df[['g_kronFlux', 'i_kronFlux']]\n", + "# Reuse the scaler fit on the training set; do not refit it on the new rows\n", + "X_r_missing_scaled = scaler.transform(X_r_missing)\n", + "y_pred_r_missing = best_model.predict(X_r_missing_scaled)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e438d24-75ba-4888-b2c9-cb4fa6c6db85", + "metadata": {}, + "outputs": [], + "source": [ + "plt.scatter(r_missing_df['r_kronFlux'].values, y_pred_r_missing)\n", + "plt.xlabel(r'Flagged $r-$band values')\n", + "plt.ylabel('Predicted $r-$band values');" + ] + }, + { + "cell_type": "markdown", + "id": "d4c11515-3263-4d8e-bf08-fa60f7cafb9d", + "metadata": {}, + "source": [ + "> A scatterplot of flagged (original) $r-$band Kron flux values against the predicted values from the best fit model. The spread is much greater in the y-axis." + ] + }, + { + "cell_type": "markdown", + "id": "68c59d12-a11f-40b7-b7fb-84276d0c7c78", + "metadata": {}, + "source": [ + "**To do: Also plot against the predictors ($g-$ and $i-$band) Kron fluxes for a better idea of how this model performs.**" + ] + }, + { + "cell_type": "markdown", + "id": "d5c33bd9-b95e-4ec8-a44d-cf3f219339e8", + "metadata": {}, + "source": [ + "# 10. Other available `scikit-learn` choices\n", + "The two cells below explore available options from `scikit-learn` for regression metrics and regression models, respectively. The metric cell prints the names of all matching metrics but shows the full documentation for only one of them (`mean_tweedie_deviance`). 
The model cell demonstrates printing the class information for the `RandomForestRegressor` class." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5adca7ed-eb36-4e8b-99e3-3b66833fb3e3", + "metadata": {}, + "outputs": [], + "source": [ + "regression_metrics = [\n", + " name for name, obj in inspect.getmembers(metrics)\n", + " if inspect.isfunction(obj)\n", + " and ('regression' in (obj.__doc__ or '').lower() or 'error' in (obj.__doc__ or '').lower())\n", + " and 'classification' not in (obj.__doc__ or '').lower()\n", + "]\n", + "print(regression_metrics)\n", + "\n", + "\n", + "# Print the documentation for one of the filtered metrics\n", + "for metric in regression_metrics:\n", + " metric_func = getattr(metrics, metric)\n", + " if metric == \"mean_tweedie_deviance\":\n", + " print(f\"--- {metric} ---\")\n", + " help(metric_func)\n", + " print(\"=\"*80)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de4a41f7-9f4f-4662-9302-2da7d68df434", + "metadata": {}, + "outputs": [], + "source": [ + "# Get all regression models\n", + "regressors = all_estimators(type_filter='regressor')\n", + "\n", + "# Print the names of all available regression models\n", + "for name, estimator in regressors:\n", + " print(name)\n", + "\n", + "for name, estimator in regressors:\n", + " if name == \"RandomForestRegressor\":\n", + " # help() prints the documentation itself and returns None\n", + " help(estimator)" + ] + }, + { + "cell_type": "markdown", + "id": "2e205cc2-26f1-4b51-bb36-3ce2ed8cece4", + "metadata": {}, + "source": [ + "# Exercise for the learner\n", + "\n", + "The uncertainty values are also available for each column of Kron fluxes. Use those to run an uncertainty-aware model fit for the $r-$band Kron fluxes."
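As a possible starting point (one option among several, not a full solution), many scikit-learn regressors accept a `sample_weight` argument in `fit`. Assuming independent Gaussian uncertainties on the target, inverse-variance weights are a common choice. The sketch below uses synthetic data; the `sigma` array is a hypothetical stand-in for a per-object flux uncertainty column.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)

# Synthetic feature and a hypothetical per-object uncertainty on the target
X = rng.uniform(0.0, 10.0, size=(300, 1))
sigma = rng.uniform(0.1, 2.0, size=300)

# True relation y = 2x + 1, with noise drawn according to each uncertainty
y = 2.0 * X[:, 0] + 1.0 + rng.normal(0.0, sigma)

# Inverse-variance weights: precise measurements count for more
model = LinearRegression()
model.fit(X, y, sample_weight=1.0 / sigma**2)

print("Slope:", model.coef_[0])
print("Intercept:", model.intercept_)
```

Note that this only accounts for uncertainty on the target; uncertainties on the input features would need a different treatment.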
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7877e1b2-e1c8-4d1f-a2b8-60c758cdf61d", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "LSST", + "language": "python", + "name": "lsst" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/DP0.2/20_Introduction_to_Data_Science_with_Rubin_Data.ipynb b/DP0.2/20_Introduction_to_Data_Science_with_Rubin_Data.ipynb new file mode 100644 index 00000000..6232a288 --- /dev/null +++ b/DP0.2/20_Introduction_to_Data_Science_with_Rubin_Data.ipynb @@ -0,0 +1,2910 @@ +{ + "cells": [ + { + "attachments": { + "90083a24-00a4-4a6f-a1c3-b9b4c6b0de9e.png": { + "image/png": "" + } + }, + "cell_type": "markdown", + "id": "6c298cbb-eb18-4cbc-b064-709a7c2b9b4c", + "metadata": {}, + "source": [ + "# XXX.X. How to use `pandas` and `scikit-learn` to do data science / machine learning with the TAP service\n", + "\n", + "
\n", + "\n", + "
\n", + "For the Rubin Science Platform at data.lsst.cloud.
\n", + "Data Release: DPX or DRX
\n", + "Container Size: small
\n", + "LSST Science Pipelines version: Weekly 2024_16
\n", + "Last verified to run: 2024-07-09
\n", + "Repository: github.com/lsst/tutorial-notebooks
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a4e60f3-a825-4dcc-8cd5-37ed43b7373b", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext pycodestyle_magic\n", + "%flake8_on\n", + "import logging\n", + "logging.getLogger(\"flake8\").setLevel(logging.FATAL)" + ] + }, + { + "cell_type": "markdown", + "id": "25e71641-fbfb-4470-8049-03ba631dbef1", + "metadata": {}, + "source": [ + "**Learning objective:** This notebook guides a PI through the process of using python's data science and machine learning libraries to explore data from complex ADQL queries with the TAP service. The goal is to build a predictive model to estimate missing $r-$band Kron Flux values when the other bands are available, and visualize the results and quantify the model performance.\n", + "\n", + "**LSST data products:** Object, Forcedsource, and CcdVisit tables.\n", + "\n", + "**Packages:** lsst.rsp, pandas, scikit-learn, seaborn\n", + "\n", + "**Credit:**\n", + "Based on notebooks developed by Leanne Guy (TAP query) and Alex Drlica-Wagner and Melissa Graham (Butler query).\n", + "Please consider acknowledging them if this notebook is used for the preparation of journal articles, software releases, or other notebooks.\n", + "\n", + "**Get Support:**\n", + "Everyone is encouraged to ask questions or raise issues in the \n", + "Support Category \n", + "of the Rubin Community Forum.\n", + "Rubin staff will respond to all questions posted there." + ] + }, + { + "cell_type": "markdown", + "id": "ca5378d1-00fe-44ad-be88-5f1a0a85b404", + "metadata": {}, + "source": [ + "## 1. 
Introduction\n", + "\n", + "This notebook provides an intermediate-level demonstration of how to use the Table Access Protocol (TAP) server and ADQL (Astronomical Data Query Language) to query and retrieve data from the DP0.2 catalogs.\n", + "\n", + "TAP provides standardized access to catalog data for discovery, search, and retrieval.\n", + "Full documentation for TAP is provided by the International Virtual Observatory Alliance (IVOA).\n", + "ADQL is similar to SQL (Structured Query Language).\n", + "The documentation for ADQL includes more information about syntax and keywords.\n", + "Note that not all ADQL functionality is supported yet in the DP0-era RSP.\n", + "\n", + "**See the recommendations for TAP queries in DP0.2 tutorial 02a \"Introduction to the TAP Service\".**\n", + "\n", + "The [documentation for Data Preview 0.2](https://dp0-2.lsst.io/) includes definitions\n", + "of the data products, descriptions of catalog contents, and ADQL recipes.\n", + "\n", + "### 1.1. Package imports\n", + "\n", + "Import general Python packages, the Rubin TAP service utilities, and various scikit-learn utilities."
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "3f4900a4-3358-472a-b9ba-c42e3f2f0771", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T18:16:10.447704Z", + "iopub.status.busy": "2024-12-03T18:16:10.447270Z", + "iopub.status.idle": "2024-12-03T18:16:11.741424Z", + "shell.execute_reply": "2024-12-03T18:16:11.740792Z", + "shell.execute_reply.started": "2024-12-03T18:16:10.447676Z" + } + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import pandas\n", + "\n", + "from astropy import units as u\n", + "from astropy.coordinates import SkyCoord\n", + "\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.preprocessing import StandardScaler\n", + "from sklearn.linear_model import LinearRegression\n", + "from sklearn.metrics import mean_squared_error\n", + "from sklearn.ensemble import RandomForestRegressor\n", + "\n", + "from lsst.rsp import get_tap_service, retrieve_query" + ] + }, + { + "cell_type": "markdown", + "id": "90251edc-e77a-4f2c-aef3-935381faebc9", + "metadata": {}, + "source": [ + "Set up `seaborn` and `matplotlib` to use the FiveThirtyEight plotting style. Note that this probably does not match the RTN-045 default plotting settings."
+ ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "94acc9f6-2033-4ace-aefd-d036a35f4221", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T18:18:25.800083Z", + "iopub.status.busy": "2024-12-03T18:18:25.799778Z", + "iopub.status.idle": "2024-12-03T18:18:25.803998Z", + "shell.execute_reply": "2024-12-03T18:18:25.803416Z", + "shell.execute_reply.started": "2024-12-03T18:18:25.800052Z" + } + }, + "outputs": [], + "source": [ + "sns.set_style('whitegrid')\n", + "plt.style.use('fivethirtyeight')\n", + "palette = sns.color_palette(\"muted\") # Choose a desired palette\n", + "sns.set_palette(palette)" + ] + }, + { + "cell_type": "markdown", + "id": "ca1e28f4-805e-4480-a7c6-0473b7e2b088", + "metadata": {}, + "source": [ + "### 1.2. Define functions and parameters\n", + "\n", + "Instantiate the TAP service." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "caf56589-100a-4481-8f24-5f5058b6671f", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T18:18:31.983738Z", + "iopub.status.busy": "2024-12-03T18:18:31.983012Z", + "iopub.status.idle": "2024-12-03T18:18:32.026282Z", + "shell.execute_reply": "2024-12-03T18:18:32.025768Z", + "shell.execute_reply.started": "2024-12-03T18:18:31.983710Z" + } + }, + "outputs": [], + "source": [ + "service = get_tap_service(\"tap\")\n", + "assert service is not None" + ] + }, + { + "cell_type": "markdown", + "id": "9eb0f20e-6c28-404f-8032-acc00f73405a", + "metadata": {}, + "source": [ + "Set the maximum number of rows to display from pandas." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "2b7b6002-2457-4c20-a03e-6bfa24a0aa27", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T18:18:32.908017Z", + "iopub.status.busy": "2024-12-03T18:18:32.907315Z", + "iopub.status.idle": "2024-12-03T18:18:32.910925Z", + "shell.execute_reply": "2024-12-03T18:18:32.910377Z", + "shell.execute_reply.started": "2024-12-03T18:18:32.907992Z" + } + }, + "outputs": [], + "source": [ + "pandas.set_option('display.max_rows', 6)" + ] + }, + { + "cell_type": "markdown", + "id": "cd325fe4-6c7c-4803-ad79-f30d7edc23e3", + "metadata": {}, + "source": [ + "## 2. Query for Kron fluxes around extended (galaxy) objects.\n", + "I forget why I chose this specific statistic for the demo.\n", + "\n", + "Kron radius: A radius that is calculated using the light profile of the object, typically as the first moment (i.e., a weighted average of radius with brightness) of the light distribution.\n", + "\n", + "Kron flux: The total flux measured within a certain multiple (often 2.5×) of the Kron radius, typically capturing about 90–95% of the total light for extended sources like galaxies." + ] + }, + { + "cell_type": "markdown", + "id": "6b4f495d-1215-421d-bdb0-bc32fec92c25", + "metadata": {}, + "source": [ + "Define the coordinates and radius to use for the example queries in Sections 2 and 3." 
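As an aside, the cone-search constraint can also be assembled with an f-string instead of string concatenation. The sketch below uses the same center and radius values and the `dp02_dc2_catalogs.Object` table, with a shortened column list for brevity; it only builds and prints the query string and does not contact the TAP service.

```python
center_ra = 62
center_dec = -37
radius = 0.1

# Adjacent string literals are joined; only the last one needs to be an f-string
query = (
    "SELECT coord_ra, coord_dec "
    "FROM dp02_dc2_catalogs.Object "
    "WHERE CONTAINS(POINT('ICRS', coord_ra, coord_dec), "
    f"CIRCLE('ICRS', {center_ra}, {center_dec}, {radius})) = 1"
)
print(query)
```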
+ ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "7ddd0344-b354-45a0-9e5a-755149c9bc54", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T18:18:35.079591Z", + "iopub.status.busy": "2024-12-03T18:18:35.078853Z", + "iopub.status.idle": "2024-12-03T18:18:35.082610Z", + "shell.execute_reply": "2024-12-03T18:18:35.082033Z", + "shell.execute_reply.started": "2024-12-03T18:18:35.079566Z" + } + }, + "outputs": [], + "source": [ + "center_ra = 62\n", + "center_dec = -37\n", + "radius = 0.1\n", + "\n", + "str_center_coords = str(center_ra) + \", \" + str(center_dec)\n", + "str_radius = str(radius)" + ] + }, + { + "cell_type": "markdown", + "id": "dd80babb-ee05-49e9-9f9c-923d5c0cee31", + "metadata": {}, + "source": [ + "Start with the same query as used in the beginner TAP tutorial notebook 02a." + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "985e3b62-8065-42ec-a40c-1232c4c45f17", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T18:19:13.171378Z", + "iopub.status.busy": "2024-12-03T18:19:13.170670Z", + "iopub.status.idle": "2024-12-03T18:19:13.174956Z", + "shell.execute_reply": "2024-12-03T18:19:13.174323Z", + "shell.execute_reply.started": "2024-12-03T18:19:13.171349Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SELECT coord_ra, coord_dec, g_kronFlux, g_kronFlux_flag, r_kronFlux, r_kronFlux_flag, i_kronFlux, i_kronFlux_flag FROM dp02_dc2_catalogs.Object WHERE CONTAINS(POINT('ICRS', coord_ra, coord_dec), CIRCLE('ICRS', 62, -37, 0.1)) = 1 AND detect_isPrimary = 1 AND g_extendedness = 1\n" + ] + } + ], + "source": [ + "query = \"SELECT coord_ra, coord_dec, g_kronFlux, g_kronFlux_flag, \"\\\n", + " \"r_kronFlux, r_kronFlux_flag, i_kronFlux, i_kronFlux_flag \"\\\n", + " \"FROM dp02_dc2_catalogs.Object \"\\\n", + " \"WHERE CONTAINS(POINT('ICRS', coord_ra, coord_dec), \"\\\n", + " \"CIRCLE('ICRS', \" + str_center_coords + \", \" + str_radius + \")) = 1 
\"\\\n", + " \"AND detect_isPrimary = 1 AND g_extendedness = 1\"\n", + "print(query)" + ] + }, + { + "cell_type": "markdown", + "id": "f024085b-0f7f-45c6-8184-41b528c15396", + "metadata": {}, + "source": [ + "Run the query job asynchronously." + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "c02adc91-5f5e-418b-87a3-cba8beba7dd2", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T18:19:14.234845Z", + "iopub.status.busy": "2024-12-03T18:19:14.234550Z", + "iopub.status.idle": "2024-12-03T18:19:21.521239Z", + "shell.execute_reply": "2024-12-03T18:19:21.520430Z", + "shell.execute_reply.started": "2024-12-03T18:19:14.234825Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Job phase is COMPLETED\n" + ] + } + ], + "source": [ + "job = service.submit_job(query)\n", + "job.run()\n", + "job.wait(phases=['COMPLETED', 'ERROR'])\n", + "print('Job phase is', job.phase)" + ] + }, + { + "cell_type": "markdown", + "id": "80b28a39-bd12-49d6-9cce-f8ddfc31296c", + "metadata": {}, + "source": [ + "## 3. Explore the data using a `pandas` DataFrame object.\n", + "DEFINE WHAT A DATAFRAME IS." + ] + }, + { + "cell_type": "markdown", + "id": "07d1cfb1-589b-402b-8b8f-c2c70652b6c6", + "metadata": {}, + "source": [ + "Return the results as a `pandas` dataframe, and then delete the query to save space." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "8cd2f538-c2d7-44ca-ab4d-825120b8f2e7", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T18:19:21.522743Z", + "iopub.status.busy": "2024-12-03T18:19:21.522437Z", + "iopub.status.idle": "2024-12-03T18:19:21.848448Z", + "shell.execute_reply": "2024-12-03T18:19:21.847864Z", + "shell.execute_reply.started": "2024-12-03T18:19:21.522716Z" + } + }, + "outputs": [], + "source": [ + "results = job.fetch_result().to_table().to_pandas()\n", + "job.delete()\n", + "del query" + ] + }, + { + "cell_type": "markdown", + "id": "27888222-8fca-4620-a838-9260b0f5f47f", + "metadata": {}, + "source": [ + "Display `results`." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "ee4d121e-6b4d-4371-afae-4f7587b95d51", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T18:18:54.059294Z", + "iopub.status.busy": "2024-12-03T18:18:54.058632Z", + "iopub.status.idle": "2024-12-03T18:18:54.074345Z", + "shell.execute_reply": "2024-12-03T18:18:54.073785Z", + "shell.execute_reply.started": "2024-12-03T18:18:54.059270Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
| | coord_ra | coord_dec | g_kronFlux | g_kronFlux_flag | r_kronFlux | r_kronFlux_flag | i_kronFlux | i_kronFlux_flag |
| 0 | 62.018897 | -37.095671 | 71.568352 | True | 91.185588 | True | 624.454022 | True |
| 1 | 62.020999 | -37.093227 | 174.729861 | False | 110.922305 | False | 52.040203 | True |
| 2 | 62.000430 | -37.093196 | 131.680920 | False | 137.655812 | False | 136.174616 | True |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 11561 | 61.950427 | -36.946586 | 51.054369 | True | 175.646973 | False | 123.073904 | True |
| 11562 | 61.976752 | -36.904225 | 199.039503 | False | 187.972452 | False | 115.825734 | False |
| 11563 | 61.932319 | -36.941077 | 266.123377 | False | 218.853195 | False | 481.650950 | True |
\n", + "

11564 rows × 8 columns

\n", + "
" + ], + "text/plain": [ + " coord_ra coord_dec g_kronFlux g_kronFlux_flag r_kronFlux \\\n", + "0 62.018897 -37.095671 71.568352 True 91.185588 \n", + "1 62.020999 -37.093227 174.729861 False 110.922305 \n", + "2 62.000430 -37.093196 131.680920 False 137.655812 \n", + "... ... ... ... ... ... \n", + "11561 61.950427 -36.946586 51.054369 True 175.646973 \n", + "11562 61.976752 -36.904225 199.039503 False 187.972452 \n", + "11563 61.932319 -36.941077 266.123377 False 218.853195 \n", + "\n", + " r_kronFlux_flag i_kronFlux i_kronFlux_flag \n", + "0 True 624.454022 True \n", + "1 False 52.040203 True \n", + "2 False 136.174616 True \n", + "... ... ... ... \n", + "11561 False 123.073904 True \n", + "11562 False 115.825734 False \n", + "11563 False 481.650950 True \n", + "\n", + "[11564 rows x 8 columns]" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results" + ] + }, + { + "cell_type": "markdown", + "id": "1d493b9b-0aba-4586-bcd0-3e1ce1f09c16", + "metadata": {}, + "source": [ + "`results` is a `pandas` DataFrame object (see below). There's lots you can do with this type of object..." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "db2168fe-593a-423d-b2f4-26ac0db60e8c", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T18:19:48.071264Z", + "iopub.status.busy": "2024-12-03T18:19:48.070609Z", + "iopub.status.idle": "2024-12-03T18:19:48.074799Z", + "shell.execute_reply": "2024-12-03T18:19:48.074312Z", + "shell.execute_reply.started": "2024-12-03T18:19:48.071232Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "pandas.core.frame.DataFrame" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(results)" + ] + }, + { + "cell_type": "markdown", + "id": "9c59c4e8-90bd-4aa8-8ce2-9a14c09988a0", + "metadata": {}, + "source": [ + "Some options are inspection- and summary-oriented, such as the `.head()`, `.tail()`, and `.describe()` methods. Let's check these out now. `.head()` and `.tail()` return the first and last five rows by default, respectively, and accept an argument to return a different number of rows. `.describe()` provides summary statistics for each numeric column, including the count, mean, and standard deviation."
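A toy example makes the behavior of these methods concrete. The values below are invented for illustration. Two details are worth noticing: `.describe()` skips the boolean flag column by default, and its `count` excludes missing (`NaN`) entries.

```python
import numpy as np
import pandas

# Small table with one missing flux value and a boolean flag column
toy = pandas.DataFrame({
    'g_kronFlux': [71.6, 174.7, np.nan, 372.7, 247.2, 158.9],
    'g_kronFlux_flag': [True, False, True, False, False, False],
})

print(toy.head(2))  # first two rows
print(toy.tail(3))  # last three rows

# Summary statistics for the numeric column only; NaN is not counted
stats = toy.describe()
print(stats)
```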
+ ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "eec25f58-d3f3-4ef4-b3e2-ab105c4718fd", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T18:23:39.207146Z", + "iopub.status.busy": "2024-12-03T18:23:39.206862Z", + "iopub.status.idle": "2024-12-03T18:23:39.230317Z", + "shell.execute_reply": "2024-12-03T18:23:39.229617Z", + "shell.execute_reply.started": "2024-12-03T18:23:39.207125Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " coord_ra coord_dec g_kronFlux g_kronFlux_flag r_kronFlux \\\n", + "0 62.018897 -37.095671 71.568352 True 91.185588 \n", + "1 62.020999 -37.093227 174.729861 False 110.922305 \n", + "2 62.000430 -37.093196 131.680920 False 137.655812 \n", + "3 62.015568 -37.092868 372.665560 False 171.582869 \n", + "4 62.002969 -37.092762 247.219720 False 153.138653 \n", + "\n", + " r_kronFlux_flag i_kronFlux i_kronFlux_flag \n", + "0 True 624.454022 True \n", + "1 False 52.040203 True \n", + "2 False 136.174616 True \n", + "3 False 211.338418 True \n", + "4 True 184.829166 True \n", + " coord_ra coord_dec g_kronFlux g_kronFlux_flag r_kronFlux \\\n", + "11554 61.913511 -36.960012 158.939682 False NaN \n", + "11555 61.986424 -36.950292 125.535210 False 71.749436 \n", + "11556 61.942562 -36.951137 108.966860 False 135.301872 \n", + "... ... ... ... ... ... \n", + "11561 61.950427 -36.946586 51.054369 True 175.646973 \n", + "11562 61.976752 -36.904225 199.039503 False 187.972452 \n", + "11563 61.932319 -36.941077 266.123377 False 218.853195 \n", + "\n", + " r_kronFlux_flag i_kronFlux i_kronFlux_flag \n", + "11554 True 102.094474 True \n", + "11555 True NaN True \n", + "11556 False 191.964068 True \n", + "... ... ... ... 
\n", + "11561 False 123.073904 True \n", + "11562 False 115.825734 False \n", + "11563 False 481.650950 True \n", + "\n", + "[10 rows x 8 columns]\n", + " coord_ra coord_dec g_kronFlux r_kronFlux i_kronFlux\n", + "count 11564.000000 11564.000000 11447.000000 1.145300e+04 1.129200e+04\n", + "mean 61.999517 -37.001530 761.468963 1.368137e+03 2.071233e+03\n", + "std 0.062390 0.050868 13194.742739 2.395294e+04 3.258561e+04\n", + "... ... ... ... ... ...\n", + "50% 61.999039 -37.001694 182.963268 2.409297e+02 3.517374e+02\n", + "75% 62.049891 -36.960187 340.536311 4.811457e+02 7.569082e+02\n", + "max 62.124349 -36.900195 782163.585831 1.957761e+06 2.870998e+06\n", + "\n", + "[8 rows x 5 columns]\n" + ] + } + ], + "source": [ + "print(results.head())\n", + "print(results.tail(10))\n", + "print(results.describe())" + ] + }, + { + "cell_type": "markdown", + "id": "1b37fc18-e00b-4ef2-8d18-85d823d60d9e", + "metadata": {}, + "source": [ + "## 4. Visualize using `seaborn`\n", + "Let's look further into visualizing these statistics using `seaborn`'s boxplot tool." + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "4fc9b578-2be4-4fb2-8d74-ebca809ea99f", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T18:55:30.244276Z", + "iopub.status.busy": "2024-12-03T18:55:30.243648Z", + "iopub.status.idle": "2024-12-03T18:55:30.835103Z", + "shell.execute_reply": "2024-12-03T18:55:30.834453Z", + "shell.execute_reply.started": "2024-12-03T18:55:30.244249Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(8, 6))\n", + "sns.boxplot(data=results)\n", + "plt.title('Box Plot of Data Distributions')\n", + "plt.xlabel('Feature')\n", + "plt.ylabel('Value')\n", + "plt.xticks(rotation=90)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "ef5f20de-0fd6-4ba1-9cab-9d59cd05df99", + "metadata": {}, + "source": [ + "It looks like outliers are dominant in the visualization. Hide these and also only plot the kron Flux values." + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "bacf5114-6a64-4100-8eb6-f1d9ddc36f89", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T18:59:02.839663Z", + "iopub.status.busy": "2024-12-03T18:59:02.839254Z", + "iopub.status.idle": "2024-12-03T18:59:03.091807Z", + "shell.execute_reply": "2024-12-03T18:59:03.091120Z", + "shell.execute_reply.started": "2024-12-03T18:59:02.839637Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(8, 6))\n", + "sns.boxplot(data=results[['g_kronFlux','r_kronFlux','i_kronFlux']], showfliers=False)\n", + "plt.title('Box Plot of Data Distributions')\n", + "plt.xlabel('Feature')\n", + "plt.ylabel('Value')\n", + "plt.xticks(rotation=90)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "75d6336b-1068-46cf-8105-b945fd18020e", + "metadata": {}, + "source": [ + "Boxplots show a box and whiskers.\n", + "- The \"box\" is the interquartile range (IQR), which is the 25th percentile of the distribution of a value to the 75th percentile.\n", + "- The horizontal line inside the box is the median of the distribution.\n", + "- The whisker extends from the IQR to 1.5*IQR away from the edge of the box.\n", + "- Points outside the whisker are considered outliers (hidden here).\n" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "39521ac6-0bec-42e7-9062-8fc9ce5edc55", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T19:11:16.903575Z", + "iopub.status.busy": "2024-12-03T19:11:16.902993Z", + "iopub.status.idle": "2024-12-03T19:11:17.202209Z", + "shell.execute_reply": "2024-12-03T19:11:17.201547Z", + "shell.execute_reply.started": "2024-12-03T19:11:16.903550Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_638/2438356921.py:3: FutureWarning: \n", + "\n", + "The `bw` parameter is deprecated in favor of `bw_method`/`bw_adjust`.\n", + "Setting `bw_method=0.2`, but please see docs for the new parameters\n", + "and update your code. This will become an error in seaborn v0.15.0.\n", + "\n", + " sns.violinplot(data=filtered_results,\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(8, 6))\n", + "filtered_results = results[['g_kronFlux', 'r_kronFlux', 'i_kronFlux']].apply(lambda x: x[(x > x.quantile(0.05)) & (x < x.quantile(0.95))])\n", + "sns.violinplot(data=filtered_results,\n", + " cut=0,\n", + " bw=0.2)\n", + "plt.title('Box Plot of Data Distributions')\n", + "plt.xlabel('Feature')\n", + "plt.ylabel('Value')\n", + "plt.xticks(rotation=90)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "4f2c50a8-098e-4230-b122-3880c8bb1883", + "metadata": {}, + "source": [ + "A violinplot gives a lot of the same information as a boxplot, in fact, there are little boxplots within the violinplot; the horizontal white line is the median, the thicker grey box is the IQR and the thin line shows the 1.5*IQR span. A violinplot also uses a kernel density extimator to visualize the distribution of each feature. Here we see that most of the data are clustered around relatively low values for all of the kron fluxes." + ] + }, + { + "cell_type": "markdown", + "id": "3ec470a9-8db0-403a-89ba-32e0dd9bef15", + "metadata": {}, + "source": [ + "## 5. Investigate the fluxes and associated flags" + ] + }, + { + "cell_type": "markdown", + "id": "18dba188-ea4c-47e9-8faf-62e3552add1e", + "metadata": {}, + "source": [ + "Use `pandas` to investigate if there are any flags on the `kronFlux` measurement. The `.value_counts()` method will show the number of True and False columns, where True are rows for which the `g_kronFlux` measurement was flagged for a variety of reasons. There are many other columns that investigate specific reasons why this measurement is untrustworthy; the `g_kronFlux_flag` is a way to combine all of the individual flags." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "0be4535d-cc89-45ef-98e9-591b9f459fae", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T00:04:41.789910Z", + "iopub.status.busy": "2024-12-03T00:04:41.789726Z", + "iopub.status.idle": "2024-12-03T00:04:41.794781Z", + "shell.execute_reply": "2024-12-03T00:04:41.794289Z", + "shell.execute_reply.started": "2024-12-03T00:04:41.789894Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "r_kronFlux_flag\n", + "False 10723\n", + "True 841\n", + "Name: count, dtype: int64" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results['r_kronFlux_flag'].value_counts()" + ] + }, + { + "cell_type": "markdown", + "id": "af623422-c6cf-4c21-8971-5599c60ea8d7", + "metadata": { + "execution": { + "iopub.execute_input": "2024-11-15T18:24:36.531474Z", + "iopub.status.busy": "2024-11-15T18:24:36.531192Z", + "iopub.status.idle": "2024-11-15T18:24:36.547424Z", + "shell.execute_reply": "2024-11-15T18:24:36.546685Z", + "shell.execute_reply.started": "2024-11-15T18:24:36.531443Z" + } + }, + "source": [ + "Let's look at the distribution of `r_kronFlux` for the flagged cases. The histogram for the unflagged cases is commented out in the cell below and can be re-enabled for comparison." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "dddfd1e7-3faa-465f-ade4-dc136ae39262", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T00:04:41.795547Z", + "iopub.status.busy": "2024-12-03T00:04:41.795330Z", + "iopub.status.idle": "2024-12-03T00:04:41.924873Z", + "shell.execute_reply": "2024-12-03T00:04:41.924253Z", + "shell.execute_reply.started": "2024-12-03T00:04:41.795530Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.clf()\n", + "plt.hist(results[results['r_kronFlux_flag']]['r_kronFlux'], label='Flagged', density=True)\n", + "#plt.hist(results[results['r_kronFlux_flag']==False]['r_kronFlux'], label='Not flagged', density=True)\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "24e94c30-66b3-4b9c-b79b-bcf024d5f214", + "metadata": {}, + "source": [ + "Okay what about the `r_kronFlux` measurement?" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "0e66ccb2-3922-471b-8c15-7fb055d02a10", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T00:04:41.925819Z", + "iopub.status.busy": "2024-12-03T00:04:41.925606Z", + "iopub.status.idle": "2024-12-03T00:04:41.931416Z", + "shell.execute_reply": "2024-12-03T00:04:41.930843Z", + "shell.execute_reply.started": "2024-12-03T00:04:41.925802Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "r_kronFlux_flag\n", + "False 10723\n", + "True 841\n", + "Name: count, dtype: int64" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results['r_kronFlux_flag'].value_counts()" + ] + }, + { + "cell_type": "markdown", + "id": "fa7a4ce1-7bd4-47f8-a480-188b2c70579a", + "metadata": {}, + "source": [ + "Perform an intersection to see if the flagged entries overlap between these two photometric bands." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "06786c33-2563-4237-9d0f-22d6308c0d7b", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T00:04:41.932301Z", + "iopub.status.busy": "2024-12-03T00:04:41.932089Z", + "iopub.status.idle": "2024-12-03T00:04:41.970411Z", + "shell.execute_reply": "2024-12-03T00:04:41.969900Z", + "shell.execute_reply.started": "2024-12-03T00:04:41.932283Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " coord_ra coord_dec g_kronFlux g_kronFlux_flag r_kronFlux \\\n", + "1749 61.916451 -37.018987 232.455339 True 380.415747 \n", + "1758 61.910063 -37.017256 115.486047 True 105.318235 \n", + "1774 61.950787 -37.015252 132.207788 True 193.917364 \n", + "... ... ... ... ... ... \n", + "11466 61.956032 -37.074942 59.901381 True 315.077832 \n", + "11471 61.942023 -37.073313 145.759753 True 120.304211 \n", + "11494 61.924542 -37.071842 248.013148 True 273.729756 \n", + "\n", + " r_kronFlux_flag i_kronFlux i_kronFlux_flag \n", + "1749 True 562.754481 False \n", + "1758 True 218.924537 True \n", + "1774 True 522.751057 False \n", + "... ... ... ... 
\n", + "11466 True NaN True \n", + "11471 True 243.456290 True \n", + "11494 True 257.108970 True \n", + "\n", + "[328 rows x 8 columns]\n" + ] + } + ], + "source": [ + "# get the unique values (which will be True or False)\n", + "r_values = set(results['r_kronFlux_flag'].unique())\n", + "g_values = set(results['g_kronFlux_flag'].unique())\n", + "\n", + "# find the intersection\n", + "overlap = r_values & g_values\n", + "\n", + "overlap_true_rows = results[\n", + " (results['r_kronFlux_flag'].isin(overlap)) & \n", + " (results['g_kronFlux_flag'].isin(overlap)) & \n", + " (results['r_kronFlux_flag'] == True) & \n", + " (results['g_kronFlux_flag'] == True)\n", + "]\n", + "\n", + "print(overlap_true_rows)" + ] + }, + { + "cell_type": "markdown", + "id": "ec13b104-ad8d-4bd6-8a93-b6d1d57b921e", + "metadata": {}, + "source": [ + "There are 328 rows that are flagged in both bands, meaning that in 328 cases neither the $g-$band nor the $r-$band kronFlux is trustworthy, so we cannot use one band's kronFlux to predict the other's." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "d29499cf-cab9-4a5c-8811-ffd1e5c9d82f", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T00:04:41.971139Z", + "iopub.status.busy": "2024-12-03T00:04:41.970965Z", + "iopub.status.idle": "2024-12-03T00:04:41.978973Z", + "shell.execute_reply": "2024-12-03T00:04:41.978503Z", + "shell.execute_reply.started": "2024-12-03T00:04:41.971124Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " coord_ra coord_dec g_kronFlux g_kronFlux_flag r_kronFlux \\\n", + "46 62.036008 -36.904244 60.546756 False 0.103155 \n", + "347 62.072778 -36.964144 15.582472 False 0.000000 \n", + "700 62.067443 -36.954921 94.684525 False NaN \n", + "... ... ... ... ... ... 
\n", + "11477 61.984400 -37.069348 63.584932 False 85.069624 \n", + "11483 61.911589 -37.067708 159.077697 False 139.900920 \n", + "11493 61.957278 -37.071416 165.461776 False 182.058161 \n", + "\n", + " r_kronFlux_flag i_kronFlux i_kronFlux_flag \n", + "46 True 68.234239 False \n", + "347 True 100.430859 False \n", + "700 True 259.766837 False \n", + "... ... ... ... \n", + "11477 True 100.221708 True \n", + "11483 True 279.320760 True \n", + "11493 True 349.606659 True \n", + "\n", + "[513 rows x 8 columns]\n" + ] + } + ], + "source": [ + "g_false_r_true = results[\n", + " (results['r_kronFlux_flag'] == True) & \n", + " (results['g_kronFlux_flag'] == False)\n", + "]\n", + "\n", + "print(g_false_r_true)" + ] + }, + { + "cell_type": "markdown", + "id": "69b67af5-de80-414c-985d-8f6cdccfc3a3", + "metadata": {}, + "source": [ + "For the unflagged values, let's look at the relationship between these fluxes." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "e6294681-9c60-4ec6-805c-d378300acaa3", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T00:04:41.979744Z", + "iopub.status.busy": "2024-12-03T00:04:41.979551Z", + "iopub.status.idle": "2024-12-03T00:04:42.213955Z", + "shell.execute_reply": "2024-12-03T00:04:42.213438Z", + "shell.execute_reply.started": "2024-12-03T00:04:41.979728Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "clean = results[\n", + " (results['r_kronFlux_flag'] == False) & \n", + " (results['g_kronFlux_flag'] == False) &\n", + " (results['i_kronFlux_flag'] == False)\n", + "]\n", + "\n", + "plt.scatter(clean['g_kronFlux'], clean['r_kronFlux'])\n", + "plt.xlabel(r'$g-$band kronFlux [nJy]')\n", + "plt.ylabel(r'$r-$band kronFlux [nJy]');" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "5afedb17-6478-4f2b-bdfc-38e73cd4a65e", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T00:04:42.215968Z", + "iopub.status.busy": "2024-12-03T00:04:42.215730Z", + "iopub.status.idle": "2024-12-03T00:04:42.382817Z", + "shell.execute_reply": "2024-12-03T00:04:42.382256Z", + "shell.execute_reply.started": "2024-12-03T00:04:42.215949Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.scatter(clean['g_kronFlux'], clean['r_kronFlux'])\n", + "plt.xlabel(r'$g-$band kronFlux [nJy]')\n", + "plt.ylabel(r'$r-$band kronFlux [nJy]')\n", + "plt.xlim([0,30000])\n", + "plt.ylim([0,0.5e5]);" + ] + }, + { + "cell_type": "markdown", + "id": "2f503394-3816-4d31-9cf0-9de88e229f87", + "metadata": {}, + "source": [ + "Zooming in on this relationship, it looks roughly linear, so we should be able to do some predictive work here." + ] + }, + { + "cell_type": "markdown", + "id": "4704605a-4665-4ccc-bd7e-cefaf5e09828", + "metadata": {}, + "source": [ + "## 6. Prepare the training and test sets" + ] + }, + { + "cell_type": "markdown", + "id": "0b3c8d4e-0541-49c7-9fa1-f9b3b5d29aac", + "metadata": {}, + "source": [ + "The first step is to define the training and validation data." + ] + }, + { + "cell_type": "markdown", + "id": "1ab2c517-125b-4c70-ad2d-b447f7e6721e", + "metadata": {}, + "source": [ + "The `.to_frame()` argument is required to input the X data as a 2D shape, as expected by scikit-learn." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "fc90feca-ede1-44b0-929b-2fec1ddf5ad4", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T00:04:42.383760Z", + "iopub.status.busy": "2024-12-03T00:04:42.383543Z", + "iopub.status.idle": "2024-12-03T00:04:42.389713Z", + "shell.execute_reply": "2024-12-03T00:04:42.389149Z", + "shell.execute_reply.started": "2024-12-03T00:04:42.383742Z" + } + }, + "outputs": [], + "source": [ + "X_train, X_test, y_train, y_test = train_test_split(\n", + " clean['g_kronFlux'].to_frame(), clean['r_kronFlux'].to_frame(), test_size=0.2, random_state=42)" + ] + }, + { + "cell_type": "markdown", + "id": "ecb487df-ce82-4d2d-9821-64620e1e922b", + "metadata": {}, + "source": [ + "Good practice to use a scaler." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "8e675e27-74f0-43e8-91af-8652e2710609", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T00:04:42.390439Z", + "iopub.status.busy": "2024-12-03T00:04:42.390274Z", + "iopub.status.idle": "2024-12-03T00:04:42.404155Z", + "shell.execute_reply": "2024-12-03T00:04:42.403690Z", + "shell.execute_reply.started": "2024-12-03T00:04:42.390426Z" + } + }, + "outputs": [], + "source": [ + "scaler = StandardScaler()\n", + "X_train = scaler.fit_transform(X_train)\n", + "X_test = scaler.transform(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "4771e145-2649-4eda-8891-987c7e9c9009", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T00:04:42.404900Z", + "iopub.status.busy": "2024-12-03T00:04:42.404711Z", + "iopub.status.idle": "2024-12-03T00:04:42.413985Z", + "shell.execute_reply": "2024-12-03T00:04:42.413340Z", + "shell.execute_reply.started": "2024-12-03T00:04:42.404885Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[-0.05664546],\n", + " [ 0.0240961 ],\n", + " [-0.06126054],\n", + " ...,\n", + " [ 0.51047933],\n", + " [-0.01411049],\n", + " [-0.03585177]])" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_test" + ] + }, + { + "cell_type": "markdown", + "id": "6011b2e6-5197-4475-8eb4-3a8841b3bd28", + "metadata": {}, + "source": [ + "## 7. 
Model (using `scikit-learn`)" + ] + }, + { + "cell_type": "markdown", + "id": "48c8d648-288f-4635-a961-248f127fe524", + "metadata": {}, + "source": [ + "### 7.1 Start with a linear regression" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "e02ca479-6105-442b-9879-2eb215dc4d66", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T00:04:42.414746Z", + "iopub.status.busy": "2024-12-03T00:04:42.414553Z", + "iopub.status.idle": "2024-12-03T00:04:42.430795Z", + "shell.execute_reply": "2024-12-03T00:04:42.430269Z", + "shell.execute_reply.started": "2024-12-03T00:04:42.414730Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "LinearRegression()" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = LinearRegression()\n", + "model.fit(X_train, y_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "aca8064a-c94f-4792-86b3-62ec2130471b", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T00:04:42.431529Z", + "iopub.status.busy": "2024-12-03T00:04:42.431336Z", + "iopub.status.idle": "2024-12-03T00:04:42.438736Z", + "shell.execute_reply": "2024-12-03T00:04:42.438086Z", + "shell.execute_reply.started": "2024-12-03T00:04:42.431514Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSE: 81831676.75317723\n" + ] + } + ], + "source": [ + "y_pred = model.predict(X_test)\n", + "mse = mean_squared_error(y_test, y_pred)\n", + "print(\"MSE:\", mse)" + ] + }, + { + "cell_type": "markdown", + "id": "c9086695-43c6-49e5-a751-bcabc275172d", + "metadata": {}, + "source": [ + "Plot how the predicted values compare to the true values." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "ee8bd887-928e-4a41-bd77-149b344ab238", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T00:04:42.439562Z", + "iopub.status.busy": "2024-12-03T00:04:42.439353Z", + "iopub.status.idle": "2024-12-03T00:04:42.606543Z", + "shell.execute_reply": "2024-12-03T00:04:42.605896Z", + "shell.execute_reply.started": "2024-12-03T00:04:42.439546Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.clf()\n", + "plt.scatter(y_test, y_pred)\n", + "plt.plot([0,1e6],[0,1e6], color='black', ls='--')\n", + "plt.xlabel('True')\n", + "plt.ylabel('Predicted')\n", + "plt.xlim([0,2e4])\n", + "plt.ylim([0,2e4]);" + ] + }, + { + "cell_type": "markdown", + "id": "1de26eb9-e7cf-4e94-bf14-0fda5d8efe18", + "metadata": {}, + "source": [ + "## 7.2 Improve the model with more features\n", + "Let's see if this will improve with more predictive values, this time including the i band information." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "d80e56eb-6e3d-4110-bee6-3681ee4a923b", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T00:04:43.535589Z", + "iopub.status.busy": "2024-12-03T00:04:43.535315Z", + "iopub.status.idle": "2024-12-03T00:04:43.657195Z", + "shell.execute_reply": "2024-12-03T00:04:43.656652Z", + "shell.execute_reply.started": "2024-12-03T00:04:43.535570Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSE: 409962.71379404626\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Train-test split\n", + "X_train, X_test, y_train, y_test = train_test_split(\n", + " clean[['g_kronFlux', 'i_kronFlux']], # Use these two features\n", + " clean['r_kronFlux'], # Target variable\n", + " test_size=0.2,\n", + " random_state=42\n", + ")\n", + "\n", + "# Standardize the features\n", + "scaler = StandardScaler()\n", + "X_train = scaler.fit_transform(X_train)\n", + "X_test = scaler.transform(X_test)\n", + "\n", + "# Train the model\n", + "model = LinearRegression()\n", + "model.fit(X_train, y_train)\n", + "\n", + "# Make predictions\n", + "y_pred = model.predict(X_test)\n", + "\n", + "# Evaluate the model\n", + "mse = mean_squared_error(y_test, y_pred)\n", + "print(\"MSE:\", mse)\n", + "\n", + "# Scatter plot: True vs Predicted\n", + "plt.clf()\n", + "plt.scatter(y_test, y_pred)\n", + "plt.plot([0, 1e6], [0, 1e6], color='black', ls='--')\n", + "plt.xlabel('True')\n", + "plt.ylabel('Predicted')\n", + "plt.xlim([0, 2e4])\n", + "plt.ylim([0, 2e4])\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "30cd4c78-35e5-4459-84bf-9642068da002", + "metadata": {}, + "source": [ + "Test for the reader: try to improve this further by including more features." 
+ ] + }, + { + "cell_type": "markdown", + "id": "89e92033-ff41-4662-b5aa-ece41c216a07", + "metadata": {}, + "source": [ + "### 7.3 Random forest regressor\n", + "A random forest fits a number of decision-tree regressors on sub-samples of the data and averages their predictions, so it can capture non-linear relationships that a linear regression cannot." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "efb56f9d-6487-444d-b283-6d60c1694948", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T00:04:45.724451Z", + "iopub.status.busy": "2024-12-03T00:04:45.723900Z", + "iopub.status.idle": "2024-12-03T00:04:48.635415Z", + "shell.execute_reply": "2024-12-03T00:04:48.634885Z", + "shell.execute_reply.started": "2024-12-03T00:04:45.724427Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSE: 33565886.06951628\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model = RandomForestRegressor()\n", + "model.fit(X_train, y_train)\n", + "\n", + "# Make predictions\n", + "y_pred = model.predict(X_test)\n", + "\n", + "# Evaluate the model\n", + "mse = mean_squared_error(y_test, y_pred)\n", + "print(\"MSE:\", mse)\n", + "\n", + "# Scatter plot: True vs Predicted\n", + "plt.clf()\n", + "plt.scatter(y_test, y_pred)\n", + "plt.plot([0, 1e6], [0, 1e6], color='black', ls='--')\n", + "plt.xlabel('True')\n", + "plt.ylabel('Predicted')\n", + "plt.xlim([0, 2e4])\n", + "plt.ylim([0, 2e4])\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "4ed0b4e6-904a-4af8-8bfe-cad402de09a4", + "metadata": {}, + "source": [ + "## 8. Hyperparameter tuning\n", + "With any `scikit-learn` model, it's possible to tune the hyperparameters to achieve better performance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6420904-44bb-40cf-8ce0-92b22a802e59", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T17:29:45.803252Z", + "iopub.status.busy": "2024-12-03T17:29:45.802663Z" + } + }, + "outputs": [], + "source": [ + "from sklearn.model_selection import GridSearchCV\n", + "param_grid = {'n_estimators': [1, 10, 50, 100, 200, 1000, 10000]} # default 100 for n_estimators\n", + "\n", + "# Create GridSearchCV object\n", + "grid = GridSearchCV(model, param_grid, cv=5)\n", + "\n", + "grid.fit(X_train, y_train)\n", + "\n", + "# Get the best parameters\n", + "print(grid.best_params_)" + ] + }, + { + "cell_type": "markdown", + "id": "8445ad22-4275-4a56-ace7-ce1a678dd900", + "metadata": {}, + "source": [ + "### 8.2 Now retrieve the best fit model" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "89167a5d-e835-4f8d-8d57-ac82f0df6239", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T00:19:21.191881Z", + "iopub.status.busy": "2024-12-03T00:19:21.191156Z", + "iopub.status.idle": 
"2024-12-03T00:19:21.944107Z", + "shell.execute_reply": "2024-12-03T00:19:21.943414Z", + "shell.execute_reply.started": "2024-12-03T00:19:21.191857Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best Model: RandomForestRegressor(n_estimators=1000)\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSE: 19817809.291009396\n" + ] + } + ], + "source": [ + "best_model = grid.best_estimator_\n", + "print(\"Best Model:\", best_model)\n", + "y_pred = best_model.predict(X_test)\n", + "plt.clf()\n", + "plt.scatter(y_test, y_pred)\n", + "plt.plot([0, 1e6], [0, 1e6], color='black', ls='--')\n", + "plt.xlabel('True')\n", + "plt.ylabel('Predicted')\n", + "plt.xlim([0, 2e4])\n", + "plt.ylim([0, 2e4])\n", + "plt.show()\n", + "\n", + "mse = mean_squared_error(y_test, y_pred)\n", + "print(\"MSE:\", mse)" + ] + }, + { + "cell_type": "markdown", + "id": "d5c33bd9-b95e-4ec8-a44d-cf3f219339e8", + "metadata": {}, + "source": [ + "## 9. Other available `scikit-learn` choices\n", + "The below two cells explore available options from `scikit-learn` for regression metrics and regression models, respectively. The metric cell is truncated with a `break` statement to only print details of the first metric. The model cell demonstrates printing the class information for the `RandomForestRegressor` class." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "5adca7ed-eb36-4e8b-99e3-3b66833fb3e3", + "metadata": { + "execution": { + "iopub.execute_input": "2024-12-03T17:56:30.644538Z", + "iopub.status.busy": "2024-12-03T17:56:30.643733Z", + "iopub.status.idle": "2024-12-03T17:56:30.650634Z", + "shell.execute_reply": "2024-12-03T17:56:30.649954Z", + "shell.execute_reply.started": "2024-12-03T17:56:30.644507Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['brier_score_loss', 'check_scoring', 'coverage_error', 'd2_absolute_error_score', 'd2_pinball_score', 'd2_tweedie_score', 'explained_variance_score', 'label_ranking_loss', 'log_loss', 'max_error', 'mean_absolute_error', 'mean_absolute_percentage_error', 'mean_gamma_deviance', 'mean_pinball_loss', 'mean_poisson_deviance', 'mean_squared_error', 'mean_squared_log_error', 'mean_tweedie_deviance', 'median_absolute_error', 'pairwise_distances', 'r2_score', 'root_mean_squared_error', 'root_mean_squared_log_error']\n", + "--- mean_tweedie_deviance ---\n", + "Help on function mean_tweedie_deviance in module sklearn.metrics._regression:\n", + "\n", + "mean_tweedie_deviance(y_true, y_pred, *, sample_weight=None, power=0)\n", + " Mean Tweedie deviance regression loss.\n", + " \n", + " Read more in the :ref:`User Guide `.\n", + " \n", + " Parameters\n", + " ----------\n", + " y_true : array-like of shape (n_samples,)\n", + " Ground truth (correct) target values.\n", + " \n", + " y_pred : array-like of shape (n_samples,)\n", + " Estimated target values.\n", + " \n", + " sample_weight : array-like of shape (n_samples,), default=None\n", + " Sample weights.\n", + " \n", + " power : float, default=0\n", + " Tweedie power parameter. Either power <= 0 or power >= 1.\n", + " \n", + " The higher `p` the less weight is given to extreme\n", + " deviations between true and predicted targets.\n", + " \n", + " - power < 0: Extreme stable distribution. 
Requires: y_pred > 0.\n", + " - power = 0 : Normal distribution, output corresponds to\n", + " mean_squared_error. y_true and y_pred can be any real numbers.\n", + " - power = 1 : Poisson distribution. Requires: y_true >= 0 and\n", + " y_pred > 0.\n", + " - 1 < p < 2 : Compound Poisson distribution. Requires: y_true >= 0\n", + " and y_pred > 0.\n", + " - power = 2 : Gamma distribution. Requires: y_true > 0 and y_pred > 0.\n", + " - power = 3 : Inverse Gaussian distribution. Requires: y_true > 0\n", + " and y_pred > 0.\n", + " - otherwise : Positive stable distribution. Requires: y_true > 0\n", + " and y_pred > 0.\n", + " \n", + " Returns\n", + " -------\n", + " loss : float\n", + " A non-negative floating point value (the best value is 0.0).\n", + " \n", + " Examples\n", + " --------\n", + " >>> from sklearn.metrics import mean_tweedie_deviance\n", + " >>> y_true = [2, 0, 1, 4]\n", + " >>> y_pred = [0.5, 0.5, 2., 2.]\n", + " >>> mean_tweedie_deviance(y_true, y_pred, power=1)\n", + " 1.4260...\n", + "\n", + "================================================================================\n" + ] + } + ], + "source": [ + "import sklearn.metrics as metrics\n", + "import inspect\n", + "regression_metrics = [\n", + " name for name, obj in inspect.getmembers(metrics)\n", + " if inspect.isfunction(obj)\n", + " and ('regression' in (obj.__doc__ or '').lower() or 'error' in (obj.__doc__ or '').lower())\n", + " and 'classification' not in (obj.__doc__ or '').lower()\n", + "]\n", + "print(regression_metrics)\n", + "\n", + "\n", + "# Print the filtered metrics and their documentation\n", + "for metric in regression_metrics:\n", + " metric_func = getattr(metrics, metric)\n", + " if metric == \"mean_tweedie_deviance\":\n", + " print(f\"--- {metric} ---\")\n", + " help(metric_func)\n", + " print(\"=\"*80)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "de4a41f7-9f4f-4662-9302-2da7d68df434", + "metadata": { + "execution": { + "iopub.execute_input": 
"2024-12-03T17:40:53.931835Z", + "iopub.status.busy": "2024-12-03T17:40:53.931115Z", + "iopub.status.idle": "2024-12-03T17:40:54.007554Z", + "shell.execute_reply": "2024-12-03T17:40:54.007006Z", + "shell.execute_reply.started": "2024-12-03T17:40:53.931811Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ARDRegression\n", + "AdaBoostRegressor\n", + "BaggingRegressor\n", + "BayesianRidge\n", + "CCA\n", + "DecisionTreeRegressor\n", + "DummyRegressor\n", + "ElasticNet\n", + "ElasticNetCV\n", + "ExtraTreeRegressor\n", + "ExtraTreesRegressor\n", + "GammaRegressor\n", + "GaussianProcessRegressor\n", + "GradientBoostingRegressor\n", + "HistGradientBoostingRegressor\n", + "HuberRegressor\n", + "IsotonicRegression\n", + "KNeighborsRegressor\n", + "KernelRidge\n", + "Lars\n", + "LarsCV\n", + "Lasso\n", + "LassoCV\n", + "LassoLars\n", + "LassoLarsCV\n", + "LassoLarsIC\n", + "LinearRegression\n", + "LinearSVR\n", + "MLPRegressor\n", + "MultiOutputRegressor\n", + "MultiTaskElasticNet\n", + "MultiTaskElasticNetCV\n", + "MultiTaskLasso\n", + "MultiTaskLassoCV\n", + "NuSVR\n", + "OrthogonalMatchingPursuit\n", + "OrthogonalMatchingPursuitCV\n", + "PLSCanonical\n", + "PLSRegression\n", + "PassiveAggressiveRegressor\n", + "PoissonRegressor\n", + "QuantileRegressor\n", + "RANSACRegressor\n", + "RadiusNeighborsRegressor\n", + "RandomForestRegressor\n", + "Help on class RandomForestRegressor in module sklearn.ensemble._forest:\n", + "\n", + "class RandomForestRegressor(ForestRegressor)\n", + " | RandomForestRegressor(n_estimators=100, *, criterion='squared_error', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=1.0, max_leaf_nodes=None, min_impurity_decrease=0.0, bootstrap=True, oob_score=False, n_jobs=None, random_state=None, verbose=0, warm_start=False, ccp_alpha=0.0, max_samples=None, monotonic_cst=None)\n", + " | \n", + " | A random forest regressor.\n", + " | \n", + " | A random forest is 
a meta estimator that fits a number of decision tree\n", + " | regressors on various sub-samples of the dataset and uses averaging to\n", + " | improve the predictive accuracy and control over-fitting.\n", + " | Trees in the forest use the best split strategy, i.e. equivalent to passing\n", + " | `splitter=\"best\"` to the underlying :class:`~sklearn.tree.DecisionTreeRegressor`.\n", + " | The sub-sample size is controlled with the `max_samples` parameter if\n", + " | `bootstrap=True` (default), otherwise the whole dataset is used to build\n", + " | each tree.\n", + " | \n", + " | For a comparison between tree-based ensemble models see the example\n", + " | :ref:`sphx_glr_auto_examples_ensemble_plot_forest_hist_grad_boosting_comparison.py`.\n", + " | \n", + " | Read more in the :ref:`User Guide `.\n", + " | \n", + " | Parameters\n", + " | ----------\n", + " | n_estimators : int, default=100\n", + " | The number of trees in the forest.\n", + " | \n", + " | .. versionchanged:: 0.22\n", + " | The default value of ``n_estimators`` changed from 10 to 100\n", + " | in 0.22.\n", + " | \n", + " | criterion : {\"squared_error\", \"absolute_error\", \"friedman_mse\", \"poisson\"}, default=\"squared_error\"\n", + " | The function to measure the quality of a split. Supported criteria\n", + " | are \"squared_error\" for the mean squared error, which is equal to\n", + " | variance reduction as feature selection criterion and minimizes the L2\n", + " | loss using the mean of each terminal node, \"friedman_mse\", which uses\n", + " | mean squared error with Friedman's improvement score for potential\n", + " | splits, \"absolute_error\" for the mean absolute error, which minimizes\n", + " | the L1 loss using the median of each terminal node, and \"poisson\" which\n", + " | uses reduction in Poisson deviance to find splits.\n", + " | Training using \"absolute_error\" is significantly slower\n", + " | than when using \"squared_error\".\n", + " | \n", + " | .. 
versionadded:: 0.18\n", + " | Mean Absolute Error (MAE) criterion.\n", + " | \n", + " | .. versionadded:: 1.0\n", + " | Poisson criterion.\n", + " | \n", + " | max_depth : int, default=None\n", + " | The maximum depth of the tree. If None, then nodes are expanded until\n", + " | all leaves are pure or until all leaves contain less than\n", + " | min_samples_split samples.\n", + " | \n", + " | min_samples_split : int or float, default=2\n", + " | The minimum number of samples required to split an internal node:\n", + " | \n", + " | - If int, then consider `min_samples_split` as the minimum number.\n", + " | - If float, then `min_samples_split` is a fraction and\n", + " | `ceil(min_samples_split * n_samples)` are the minimum\n", + " | number of samples for each split.\n", + " | \n", + " | .. versionchanged:: 0.18\n", + " | Added float values for fractions.\n", + " | \n", + " | min_samples_leaf : int or float, default=1\n", + " | The minimum number of samples required to be at a leaf node.\n", + " | A split point at any depth will only be considered if it leaves at\n", + " | least ``min_samples_leaf`` training samples in each of the left and\n", + " | right branches. This may have the effect of smoothing the model,\n", + " | especially in regression.\n", + " | \n", + " | - If int, then consider `min_samples_leaf` as the minimum number.\n", + " | - If float, then `min_samples_leaf` is a fraction and\n", + " | `ceil(min_samples_leaf * n_samples)` are the minimum\n", + " | number of samples for each node.\n", + " | \n", + " | .. versionchanged:: 0.18\n", + " | Added float values for fractions.\n", + " | \n", + " | min_weight_fraction_leaf : float, default=0.0\n", + " | The minimum weighted fraction of the sum total of weights (of all\n", + " | the input samples) required to be at a leaf node. 
Samples have\n", + " | equal weight when sample_weight is not provided.\n", + " | \n", + " | max_features : {\"sqrt\", \"log2\", None}, int or float, default=1.0\n", + " | The number of features to consider when looking for the best split:\n", + " | \n", + " | - If int, then consider `max_features` features at each split.\n", + " | - If float, then `max_features` is a fraction and\n", + " | `max(1, int(max_features * n_features_in_))` features are considered at each\n", + " | split.\n", + " | - If \"sqrt\", then `max_features=sqrt(n_features)`.\n", + " | - If \"log2\", then `max_features=log2(n_features)`.\n", + " | - If None or 1.0, then `max_features=n_features`.\n", + " | \n", + " | .. note::\n", + " | The default of 1.0 is equivalent to bagged trees and more\n", + " | randomness can be achieved by setting smaller values, e.g. 0.3.\n", + " | \n", + " | .. versionchanged:: 1.1\n", + " | The default of `max_features` changed from `\"auto\"` to 1.0.\n", + " | \n", + " | Note: the search for a split does not stop until at least one\n", + " | valid partition of the node samples is found, even if it requires to\n", + " | effectively inspect more than ``max_features`` features.\n", + " | \n", + " | max_leaf_nodes : int, default=None\n", + " | Grow trees with ``max_leaf_nodes`` in best-first fashion.\n", + " | Best nodes are defined as relative reduction in impurity.\n", + " | If None then unlimited number of leaf nodes.\n", + " | \n", + " | min_impurity_decrease : float, default=0.0\n", + " | A node will be split if this split induces a decrease of the impurity\n", + " | greater than or equal to this value.\n", + " | \n", + " | The weighted impurity decrease equation is the following::\n", + " | \n", + " | N_t / N * (impurity - N_t_R / N_t * right_impurity\n", + " | - N_t_L / N_t * left_impurity)\n", + " | \n", + " | where ``N`` is the total number of samples, ``N_t`` is the number of\n", + " | samples at the current node, ``N_t_L`` is the number of samples in the\n", 
+ " | left child, and ``N_t_R`` is the number of samples in the right child.\n", + " | \n", + " | ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,\n", + " | if ``sample_weight`` is passed.\n", + " | \n", + " | .. versionadded:: 0.19\n", + " | \n", + " | bootstrap : bool, default=True\n", + " | Whether bootstrap samples are used when building trees. If False, the\n", + " | whole dataset is used to build each tree.\n", + " | \n", + " | oob_score : bool or callable, default=False\n", + " | Whether to use out-of-bag samples to estimate the generalization score.\n", + " | By default, :func:`~sklearn.metrics.r2_score` is used.\n", + " | Provide a callable with signature `metric(y_true, y_pred)` to use a\n", + " | custom metric. Only available if `bootstrap=True`.\n", + " | \n", + " | n_jobs : int, default=None\n", + " | The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`,\n", + " | :meth:`decision_path` and :meth:`apply` are all parallelized over the\n", + " | trees. ``None`` means 1 unless in a :obj:`joblib.parallel_backend`\n", + " | context. ``-1`` means using all processors. See :term:`Glossary\n", + " | ` for more details.\n", + " | \n", + " | random_state : int, RandomState instance or None, default=None\n", + " | Controls both the randomness of the bootstrapping of the samples used\n", + " | when building trees (if ``bootstrap=True``) and the sampling of the\n", + " | features to consider when looking for the best split at each node\n", + " | (if ``max_features < n_features``).\n", + " | See :term:`Glossary ` for details.\n", + " | \n", + " | verbose : int, default=0\n", + " | Controls the verbosity when fitting and predicting.\n", + " | \n", + " | warm_start : bool, default=False\n", + " | When set to ``True``, reuse the solution of the previous call to fit\n", + " | and add more estimators to the ensemble, otherwise, just fit a whole\n", + " | new forest. 
See :term:`Glossary ` and\n", + " | :ref:`tree_ensemble_warm_start` for details.\n", + " | \n", + " | ccp_alpha : non-negative float, default=0.0\n", + " | Complexity parameter used for Minimal Cost-Complexity Pruning. The\n", + " | subtree with the largest cost complexity that is smaller than\n", + " | ``ccp_alpha`` will be chosen. By default, no pruning is performed. See\n", + " | :ref:`minimal_cost_complexity_pruning` for details.\n", + " | \n", + " | .. versionadded:: 0.22\n", + " | \n", + " | max_samples : int or float, default=None\n", + " | If bootstrap is True, the number of samples to draw from X\n", + " | to train each base estimator.\n", + " | \n", + " | - If None (default), then draw `X.shape[0]` samples.\n", + " | - If int, then draw `max_samples` samples.\n", + " | - If float, then draw `max(round(n_samples * max_samples), 1)` samples. Thus,\n", + " | `max_samples` should be in the interval `(0.0, 1.0]`.\n", + " | \n", + " | .. versionadded:: 0.22\n", + " | \n", + " | monotonic_cst : array-like of int of shape (n_features), default=None\n", + " | Indicates the monotonicity constraint to enforce on each feature.\n", + " | - 1: monotonically increasing\n", + " | - 0: no constraint\n", + " | - -1: monotonically decreasing\n", + " | \n", + " | If monotonic_cst is None, no constraints are applied.\n", + " | \n", + " | Monotonicity constraints are not supported for:\n", + " | - multioutput regressions (i.e. when `n_outputs_ > 1`),\n", + " | - regressions trained on data with missing values.\n", + " | \n", + " | Read more in the :ref:`User Guide `.\n", + " | \n", + " | .. versionadded:: 1.4\n", + " | \n", + " | Attributes\n", + " | ----------\n", + " | estimator_ : :class:`~sklearn.tree.DecisionTreeRegressor`\n", + " | The child estimator template used to create the collection of fitted\n", + " | sub-estimators.\n", + " | \n", + " | .. 
versionadded:: 1.2\n", + " | `base_estimator_` was renamed to `estimator_`.\n", + " | \n", + " | estimators_ : list of DecisionTreeRegressor\n", + " | The collection of fitted sub-estimators.\n", + " | \n", + " | feature_importances_ : ndarray of shape (n_features,)\n", + " | The impurity-based feature importances.\n", + " | The higher, the more important the feature.\n", + " | The importance of a feature is computed as the (normalized)\n", + " | total reduction of the criterion brought by that feature. It is also\n", + " | known as the Gini importance.\n", + " | \n", + " | Warning: impurity-based feature importances can be misleading for\n", + " | high cardinality features (many unique values). See\n", + " | :func:`sklearn.inspection.permutation_importance` as an alternative.\n", + " | \n", + " | n_features_in_ : int\n", + " | Number of features seen during :term:`fit`.\n", + " | \n", + " | .. versionadded:: 0.24\n", + " | \n", + " | feature_names_in_ : ndarray of shape (`n_features_in_`,)\n", + " | Names of features seen during :term:`fit`. Defined only when `X`\n", + " | has feature names that are all strings.\n", + " | \n", + " | .. versionadded:: 1.0\n", + " | \n", + " | n_outputs_ : int\n", + " | The number of outputs when ``fit`` is performed.\n", + " | \n", + " | oob_score_ : float\n", + " | Score of the training dataset obtained using an out-of-bag estimate.\n", + " | This attribute exists only when ``oob_score`` is True.\n", + " | \n", + " | oob_prediction_ : ndarray of shape (n_samples,) or (n_samples, n_outputs)\n", + " | Prediction computed with out-of-bag estimate on the training set.\n", + " | This attribute exists only when ``oob_score`` is True.\n", + " | \n", + " | estimators_samples_ : list of arrays\n", + " | The subset of drawn samples (i.e., the in-bag samples) for each base\n", + " | estimator. Each subset is defined by an array of the indices selected.\n", + " | \n", + " | .. 
versionadded:: 1.4\n", + " | \n", + " | See Also\n", + " | --------\n", + " | sklearn.tree.DecisionTreeRegressor : A decision tree regressor.\n", + " | sklearn.ensemble.ExtraTreesRegressor : Ensemble of extremely randomized\n", + " | tree regressors.\n", + " | sklearn.ensemble.HistGradientBoostingRegressor : A Histogram-based Gradient\n", + " | Boosting Regression Tree, very fast for big datasets (n_samples >=\n", + " | 10_000).\n", + " | \n", + " | Notes\n", + " | -----\n", + " | The default values for the parameters controlling the size of the trees\n", + " | (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and\n", + " | unpruned trees which can potentially be very large on some data sets. To\n", + " | reduce memory consumption, the complexity and size of the trees should be\n", + " | controlled by setting those parameter values.\n", + " | \n", + " | The features are always randomly permuted at each split. Therefore,\n", + " | the best found split may vary, even with the same training data,\n", + " | ``max_features=n_features`` and ``bootstrap=False``, if the improvement\n", + " | of the criterion is identical for several splits enumerated during the\n", + " | search of the best split. To obtain a deterministic behaviour during\n", + " | fitting, ``random_state`` has to be fixed.\n", + " | \n", + " | The default value ``max_features=1.0`` uses ``n_features``\n", + " | rather than ``n_features / 3``. The latter was originally suggested in\n", + " | [1], whereas the former was more recently justified empirically in [2].\n", + " | \n", + " | References\n", + " | ----------\n", + " | .. [1] L. Breiman, \"Random Forests\", Machine Learning, 45(1), 5-32, 2001.\n", + " | \n", + " | .. [2] P. Geurts, D. Ernst., and L. 
Wehenkel, \"Extremely randomized\n", + " | trees\", Machine Learning, 63(1), 3-42, 2006.\n", + " | \n", + " | Examples\n", + " | --------\n", + " | >>> from sklearn.ensemble import RandomForestRegressor\n", + " | >>> from sklearn.datasets import make_regression\n", + " | >>> X, y = make_regression(n_features=4, n_informative=2,\n", + " | ... random_state=0, shuffle=False)\n", + " | >>> regr = RandomForestRegressor(max_depth=2, random_state=0)\n", + " | >>> regr.fit(X, y)\n", + " | RandomForestRegressor(...)\n", + " | >>> print(regr.predict([[0, 0, 0, 0]]))\n", + " | [-8.32987858]\n", + " | \n", + " | Method resolution order:\n", + " | RandomForestRegressor\n", + " | ForestRegressor\n", + " | sklearn.base.RegressorMixin\n", + " | BaseForest\n", + " | sklearn.base.MultiOutputMixin\n", + " | sklearn.ensemble._base.BaseEnsemble\n", + " | sklearn.base.MetaEstimatorMixin\n", + " | sklearn.base.BaseEstimator\n", + " | sklearn.utils._estimator_html_repr._HTMLDocumentationLinkMixin\n", + " | sklearn.utils._metadata_requests._MetadataRequester\n", + " | builtins.object\n", + " | \n", + " | Methods defined here:\n", + " | \n", + " | __init__(self, n_estimators=100, *, criterion='squared_error', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=1.0, max_leaf_nodes=None, min_impurity_decrease=0.0, bootstrap=True, oob_score=False, n_jobs=None, random_state=None, verbose=0, warm_start=False, ccp_alpha=0.0, max_samples=None, monotonic_cst=None)\n", + " | Initialize self. 
See help(type(self)) for accurate signature.\n", + " | \n", + " | set_fit_request(self: sklearn.ensemble._forest.RandomForestRegressor, *, sample_weight: Union[bool, NoneType, str] = '$UNCHANGED$') -> sklearn.ensemble._forest.RandomForestRegressor from sklearn.utils._metadata_requests.RequestMethod.__get__.\n", + " | Request metadata passed to the ``fit`` method.\n", + " | \n", + " | Note that this method is only relevant if\n", + " | ``enable_metadata_routing=True`` (see :func:`sklearn.set_config`).\n", + " | Please see :ref:`User Guide ` on how the routing\n", + " | mechanism works.\n", + " | \n", + " | The options for each parameter are:\n", + " | \n", + " | - ``True``: metadata is requested, and passed to ``fit`` if provided. The request is ignored if metadata is not provided.\n", + " | \n", + " | - ``False``: metadata is not requested and the meta-estimator will not pass it to ``fit``.\n", + " | \n", + " | - ``None``: metadata is not requested, and the meta-estimator will raise an error if the user provides it.\n", + " | \n", + " | - ``str``: metadata should be passed to the meta-estimator with this given alias instead of the original name.\n", + " | \n", + " | The default (``sklearn.utils.metadata_routing.UNCHANGED``) retains the\n", + " | existing request. This allows you to change the request for some\n", + " | parameters and not others.\n", + " | \n", + " | .. versionadded:: 1.3\n", + " | \n", + " | .. note::\n", + " | This method is only relevant if this estimator is used as a\n", + " | sub-estimator of a meta-estimator, e.g. used inside a\n", + " | :class:`~sklearn.pipeline.Pipeline`. 
Otherwise it has no effect.\n", + " | \n", + " | Parameters\n", + " | ----------\n", + " | sample_weight : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED\n", + " | Metadata routing for ``sample_weight`` parameter in ``fit``.\n", + " | \n", + " | Returns\n", + " | -------\n", + " | self : object\n", + " | The updated object.\n", + " | \n", + " | set_score_request(self: sklearn.ensemble._forest.RandomForestRegressor, *, sample_weight: Union[bool, NoneType, str] = '$UNCHANGED$') -> sklearn.ensemble._forest.RandomForestRegressor from sklearn.utils._metadata_requests.RequestMethod.__get__.\n", + " | Request metadata passed to the ``score`` method.\n", + " | \n", + " | Note that this method is only relevant if\n", + " | ``enable_metadata_routing=True`` (see :func:`sklearn.set_config`).\n", + " | Please see :ref:`User Guide ` on how the routing\n", + " | mechanism works.\n", + " | \n", + " | The options for each parameter are:\n", + " | \n", + " | - ``True``: metadata is requested, and passed to ``score`` if provided. The request is ignored if metadata is not provided.\n", + " | \n", + " | - ``False``: metadata is not requested and the meta-estimator will not pass it to ``score``.\n", + " | \n", + " | - ``None``: metadata is not requested, and the meta-estimator will raise an error if the user provides it.\n", + " | \n", + " | - ``str``: metadata should be passed to the meta-estimator with this given alias instead of the original name.\n", + " | \n", + " | The default (``sklearn.utils.metadata_routing.UNCHANGED``) retains the\n", + " | existing request. This allows you to change the request for some\n", + " | parameters and not others.\n", + " | \n", + " | .. versionadded:: 1.3\n", + " | \n", + " | .. note::\n", + " | This method is only relevant if this estimator is used as a\n", + " | sub-estimator of a meta-estimator, e.g. used inside a\n", + " | :class:`~sklearn.pipeline.Pipeline`. 
Otherwise it has no effect.\n", + " | \n", + " | Parameters\n", + " | ----------\n", + " | sample_weight : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED\n", + " | Metadata routing for ``sample_weight`` parameter in ``score``.\n", + " | \n", + " | Returns\n", + " | -------\n", + " | self : object\n", + " | The updated object.\n", + " | \n", + " | ----------------------------------------------------------------------\n", + " | Data and other attributes defined here:\n", + " | \n", + " | __abstractmethods__ = frozenset()\n", + " | \n", + " | __annotations__ = {'_parameter_constraints': }\n", + " | \n", + " | ----------------------------------------------------------------------\n", + " | Methods inherited from ForestRegressor:\n", + " | \n", + " | predict(self, X)\n", + " | Predict regression target for X.\n", + " | \n", + " | The predicted regression target of an input sample is computed as the\n", + " | mean predicted regression targets of the trees in the forest.\n", + " | \n", + " | Parameters\n", + " | ----------\n", + " | X : {array-like, sparse matrix} of shape (n_samples, n_features)\n", + " | The input samples. Internally, its dtype will be converted to\n", + " | ``dtype=np.float32``. 
If a sparse matrix is provided, it will be\n", + " | converted into a sparse ``csr_matrix``.\n", + " | \n", + " | Returns\n", + " | -------\n", + " | y : ndarray of shape (n_samples,) or (n_samples, n_outputs)\n", + " | The predicted values.\n", + " | \n", + " | ----------------------------------------------------------------------\n", + " | Methods inherited from sklearn.base.RegressorMixin:\n", + " | \n", + " | score(self, X, y, sample_weight=None)\n", + " | Return the coefficient of determination of the prediction.\n", + " | \n", + " | The coefficient of determination :math:`R^2` is defined as\n", + " | :math:`(1 - \\frac{u}{v})`, where :math:`u` is the residual\n", + " | sum of squares ``((y_true - y_pred)** 2).sum()`` and :math:`v`\n", + " | is the total sum of squares ``((y_true - y_true.mean()) ** 2).sum()``.\n", + " | The best possible score is 1.0 and it can be negative (because the\n", + " | model can be arbitrarily worse). A constant model that always predicts\n", + " | the expected value of `y`, disregarding the input features, would get\n", + " | a :math:`R^2` score of 0.0.\n", + " | \n", + " | Parameters\n", + " | ----------\n", + " | X : array-like of shape (n_samples, n_features)\n", + " | Test samples. For some estimators this may be a precomputed\n", + " | kernel matrix or a list of generic objects instead with shape\n", + " | ``(n_samples, n_samples_fitted)``, where ``n_samples_fitted``\n", + " | is the number of samples used in the fitting for the estimator.\n", + " | \n", + " | y : array-like of shape (n_samples,) or (n_samples, n_outputs)\n", + " | True values for `X`.\n", + " | \n", + " | sample_weight : array-like of shape (n_samples,), default=None\n", + " | Sample weights.\n", + " | \n", + " | Returns\n", + " | -------\n", + " | score : float\n", + " | :math:`R^2` of ``self.predict(X)`` w.r.t. 
`y`.\n", + " | \n", + " | Notes\n", + " | -----\n", + " | The :math:`R^2` score used when calling ``score`` on a regressor uses\n", + " | ``multioutput='uniform_average'`` from version 0.23 to keep consistent\n", + " | with default value of :func:`~sklearn.metrics.r2_score`.\n", + " | This influences the ``score`` method of all the multioutput\n", + " | regressors (except for\n", + " | :class:`~sklearn.multioutput.MultiOutputRegressor`).\n", + " | \n", + " | ----------------------------------------------------------------------\n", + " | Data descriptors inherited from sklearn.base.RegressorMixin:\n", + " | \n", + " | __dict__\n", + " | dictionary for instance variables\n", + " | \n", + " | __weakref__\n", + " | list of weak references to the object\n", + " | \n", + " | ----------------------------------------------------------------------\n", + " | Methods inherited from BaseForest:\n", + " | \n", + " | apply(self, X)\n", + " | Apply trees in the forest to X, return leaf indices.\n", + " | \n", + " | Parameters\n", + " | ----------\n", + " | X : {array-like, sparse matrix} of shape (n_samples, n_features)\n", + " | The input samples. Internally, its dtype will be converted to\n", + " | ``dtype=np.float32``. If a sparse matrix is provided, it will be\n", + " | converted into a sparse ``csr_matrix``.\n", + " | \n", + " | Returns\n", + " | -------\n", + " | X_leaves : ndarray of shape (n_samples, n_estimators)\n", + " | For each datapoint x in X and for each tree in the forest,\n", + " | return the index of the leaf x ends up in.\n", + " | \n", + " | decision_path(self, X)\n", + " | Return the decision path in the forest.\n", + " | \n", + " | .. versionadded:: 0.18\n", + " | \n", + " | Parameters\n", + " | ----------\n", + " | X : {array-like, sparse matrix} of shape (n_samples, n_features)\n", + " | The input samples. Internally, its dtype will be converted to\n", + " | ``dtype=np.float32``. 
If a sparse matrix is provided, it will be\n", + " | converted into a sparse ``csr_matrix``.\n", + " | \n", + " | Returns\n", + " | -------\n", + " | indicator : sparse matrix of shape (n_samples, n_nodes)\n", + " | Return a node indicator matrix where non zero elements indicates\n", + " | that the samples goes through the nodes. The matrix is of CSR\n", + " | format.\n", + " | \n", + " | n_nodes_ptr : ndarray of shape (n_estimators + 1,)\n", + " | The columns from indicator[n_nodes_ptr[i]:n_nodes_ptr[i+1]]\n", + " | gives the indicator value for the i-th estimator.\n", + " | \n", + " | fit(self, X, y, sample_weight=None)\n", + " | Build a forest of trees from the training set (X, y).\n", + " | \n", + " | Parameters\n", + " | ----------\n", + " | X : {array-like, sparse matrix} of shape (n_samples, n_features)\n", + " | The training input samples. Internally, its dtype will be converted\n", + " | to ``dtype=np.float32``. If a sparse matrix is provided, it will be\n", + " | converted into a sparse ``csc_matrix``.\n", + " | \n", + " | y : array-like of shape (n_samples,) or (n_samples, n_outputs)\n", + " | The target values (class labels in classification, real numbers in\n", + " | regression).\n", + " | \n", + " | sample_weight : array-like of shape (n_samples,), default=None\n", + " | Sample weights. If None, then samples are equally weighted. Splits\n", + " | that would create child nodes with net zero or negative weight are\n", + " | ignored while searching for a split in each node. 
In the case of\n", + " | classification, splits are also ignored if they would result in any\n", + " | single class carrying a negative weight in either child node.\n", + " | \n", + " | Returns\n", + " | -------\n", + " | self : object\n", + " | Fitted estimator.\n", + " | \n", + " | ----------------------------------------------------------------------\n", + " | Readonly properties inherited from BaseForest:\n", + " | \n", + " | estimators_samples_\n", + " | The subset of drawn samples for each base estimator.\n", + " | \n", + " | Returns a dynamically generated list of indices identifying\n", + " | the samples used for fitting each member of the ensemble, i.e.,\n", + " | the in-bag samples.\n", + " | \n", + " | Note: the list is re-created at each call to the property in order\n", + " | to reduce the object memory footprint by not storing the sampling\n", + " | data. Thus fetching the property may be slower than expected.\n", + " | \n", + " | feature_importances_\n", + " | The impurity-based feature importances.\n", + " | \n", + " | The higher, the more important the feature.\n", + " | The importance of a feature is computed as the (normalized)\n", + " | total reduction of the criterion brought by that feature. It is also\n", + " | known as the Gini importance.\n", + " | \n", + " | Warning: impurity-based feature importances can be misleading for\n", + " | high cardinality features (many unique values). 
See\n", + " | :func:`sklearn.inspection.permutation_importance` as an alternative.\n", + " | \n", + " | Returns\n", + " | -------\n", + " | feature_importances_ : ndarray of shape (n_features,)\n", + " | The values of this array sum to 1, unless all trees are single node\n", + " | trees consisting of only the root node, in which case it will be an\n", + " | array of zeros.\n", + " | \n", + " | ----------------------------------------------------------------------\n", + " | Methods inherited from sklearn.ensemble._base.BaseEnsemble:\n", + " | \n", + " | __getitem__(self, index)\n", + " | Return the index'th estimator in the ensemble.\n", + " | \n", + " | __iter__(self)\n", + " | Return iterator over estimators in the ensemble.\n", + " | \n", + " | __len__(self)\n", + " | Return the number of estimators in the ensemble.\n", + " | \n", + " | ----------------------------------------------------------------------\n", + " | Methods inherited from sklearn.base.BaseEstimator:\n", + " | \n", + " | __getstate__(self)\n", + " | Helper for pickle.\n", + " | \n", + " | __repr__(self, N_CHAR_MAX=700)\n", + " | Return repr(self).\n", + " | \n", + " | __setstate__(self, state)\n", + " | \n", + " | __sklearn_clone__(self)\n", + " | \n", + " | get_params(self, deep=True)\n", + " | Get parameters for this estimator.\n", + " | \n", + " | Parameters\n", + " | ----------\n", + " | deep : bool, default=True\n", + " | If True, will return the parameters for this estimator and\n", + " | contained subobjects that are estimators.\n", + " | \n", + " | Returns\n", + " | -------\n", + " | params : dict\n", + " | Parameter names mapped to their values.\n", + " | \n", + " | set_params(self, **params)\n", + " | Set the parameters of this estimator.\n", + " | \n", + " | The method works on simple estimators as well as on nested objects\n", + " | (such as :class:`~sklearn.pipeline.Pipeline`). 
The latter have\n", + " | parameters of the form ``<component>__<parameter>`` so that it's\n", + " | possible to update each component of a nested object.\n", + " | \n", + " | Parameters\n", + " | ----------\n", + " | **params : dict\n", + " | Estimator parameters.\n", + " | \n", + " | Returns\n", + " | -------\n", + " | self : estimator instance\n", + " | Estimator instance.\n", + " | \n", + " | ----------------------------------------------------------------------\n", + " | Methods inherited from sklearn.utils._metadata_requests._MetadataRequester:\n", + " | \n", + " | get_metadata_routing(self)\n", + " | Get metadata routing of this object.\n", + " | \n", + " | Please check :ref:`User Guide <metadata_routing>` on how the routing\n", + " | mechanism works.\n", + " | \n", + " | Returns\n", + " | -------\n", + " | routing : MetadataRequest\n", + " | A :class:`~sklearn.utils.metadata_routing.MetadataRequest` encapsulating\n", + " | routing information.\n", + " | \n", + " | ----------------------------------------------------------------------\n", + " | Class methods inherited from sklearn.utils._metadata_requests._MetadataRequester:\n", + " | \n", + " | __init_subclass__(**kwargs)\n", + " | Set the ``set_{method}_request`` methods.\n", + " | \n", + " | This uses PEP-487 [1]_ to set the ``set_{method}_request`` methods. It\n", + " | looks for the information available in the set default values which are\n", + " | set using ``__metadata_request__*`` class attributes, or inferred\n", + " | from method signatures.\n", + " | \n", + " | The ``__metadata_request__*`` class attributes are used when a method\n", + " | does not explicitly accept a metadata through its arguments or if the\n", + " | developer would like to specify a request value for those metadata\n", + " | which are different from the default ``None``.\n", + " | \n", + " | References\n", + " | ----------\n", + " | .. 
[1] https://www.python.org/dev/peps/pep-0487\n", + "\n", + "RegressorChain\n", + "Ridge\n", + "RidgeCV\n", + "SGDRegressor\n", + "SVR\n", + "StackingRegressor\n", + "TheilSenRegressor\n", + "TransformedTargetRegressor\n", + "TweedieRegressor\n", + "VotingRegressor\n" + ] + } + ], + "source": [ + "from sklearn.utils import all_estimators\n", + "\n", + "# Get all regression models\n", + "regressors = all_estimators(type_filter='regressor')\n", + "\n", + "# Print the names of all available regression models\n", + "for name, estimator in regressors:\n", + " print(name)\n", + "\n", + "# Show the documentation for the random forest regressor.\n", + "# help() prints the text itself and returns None, so do not wrap it in print().\n", + "for name, estimator in regressors:\n", + " if name == \"RandomForestRegressor\":\n", + " help(estimator)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be44f0da-e766-4bd0-9312-4b1977397c8d", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "LSST", + "language": "python", + "name": "lsst" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}