Detecting floating objects using deep learning and Sentinel-2 imagery#

Ocean Modelling Standard Python




Detect floating targets that might contain plastic, algae, sargassum, wood, etc. in coastal areas using deep learning methods and Sentinel-2 data.

Modelling approach#

In this notebook we show the potential of deep learning methods to detect and highlight floating debris of various kinds in coastal areas. First, the dataset is downloaded from Zenodo using this notebook. It contains a Sentinel-2 image over a coastal area of Rio de Janeiro in Brazil, a validation example from the Plastic Litter Project (PLP), and two use-case Sentinel-2 images of scenes containing various floating targets. We then apply pre-trained weights to identify potential floating targets. We visualise the results interactively and show that the model works on validation data and on examples containing different types of floating objects. Finally, we export the predictions in GeoTIFF format to the output folders. Our experiments were implemented using PyTorch. Further details on the labeled dataset and the code can be found here: ESA-PhiLab/floatingobjects.


  • We demonstrate the use of deep neural networks for the detection of floating objects on Sentinel-2 data.

  • Once the user downloads the dataset from Zenodo using this notebook, the predictions can be run and the results will be visualised.

  • Several use cases are included in order to show the variety of the detected objects.

  • The user can visualise the RGB image, the NDVI and FDI indices along with the predictions and classifications.

  • The predictions will be available locally in the user’s folder.



  • Jamila Mifdal (author), European Space Agency Φ-lab, @jmifdal

  • Raquel Carmo (author), European Space Agency Φ-lab, @raquelcarmo

  • Alejandro Coca-Castro (reviewer), The Alan Turing Institute, @acocac

Modelling codebase#

  • Jamila Mifdal (author), European Space Agency Φ-lab, @jmifdal

  • Raquel Carmo (author), European Space Agency Φ-lab, @raquelcarmo

  • Marc Rußwurm (author), EPFL-ECEO, @marccoru

Modelling publications#

Modelling funding#

This project is supported by Φ-lab, European Space Agency.

Install and load libraries#

!pip -q install gdown

import os
import numpy as np

# intake and pooch library
import intake
import pooch

# xarray
import xarray as xr

# machine learning libraries
import torch

# image libraries
from PIL import Image

# visualisation
import rasterio as rio 

import matplotlib
import matplotlib.pyplot as plt

from skimage.exposure import equalize_hist

import holoviews as hv
import hvplot.xarray

import warnings

hv.extension('bokeh', width=100)

Set folder structure#

# Define the project main folder
data_folder = './notebook'

# Set the folder structure
config = {
    'in_geotiff': os.path.join(data_folder, 'input','tiff'),
    'out_geotiff': os.path.join(data_folder, 'output','raster'),
    'out_usecases': os.path.join(data_folder, 'output','usecases')
}

# Create the folders if they do not exist yet
for path in config.values():
    os.makedirs(path, exist_ok=True)

Load and prepare input image#

Fetch input data using pooch and intake#

The input data of this notebook can be fetched directly from a Zenodo repository using pooch. We provide the repository’s DOI (10.5281/zenodo.5827377) and we set the target entries with corresponding output folders where the downloaded files will be stored.

dataset = pooch.create(
    # Store the downloaded files in the input folder
    path=config['in_geotiff'],
    # Use the Zenodo DOI
    base_url="doi:10.5281/zenodo.5827377/",
    registry={
        "cancun_20180914.tif": "md5:1d4cdb7db75d3ade702767bbfbd3ea44",
        "mytilini_20210621.tif": "md5:36edf5dace9c09914cc5fc3165109462",
        "RioDeJaneiro.tif": "md5:734ab89c2cbde1ad76b460f4b029d9a6",
        "tangshan_20180130.tif": "md5:e62f1faf4a6a875c88892efa4fe72723",
    },
)

# Download all the files in a Zenodo record
for fn in dataset.registry_files:
    dataset.fetch(fn)
Downloading file 'cancun_20180914.tif' from 'doi:10.5281/zenodo.5827377/cancun_20180914.tif' to '/home/jovyan/notebook/input/tiff'.
Downloading file 'mytilini_20210621.tif' from 'doi:10.5281/zenodo.5827377/mytilini_20210621.tif' to '/home/jovyan/notebook/input/tiff'.
Downloading file 'RioDeJaneiro.tif' from 'doi:10.5281/zenodo.5827377/RioDeJaneiro.tif' to '/home/jovyan/notebook/input/tiff'.
Downloading file 'tangshan_20180130.tif' from 'doi:10.5281/zenodo.5827377/tangshan_20180130.tif' to '/home/jovyan/notebook/input/tiff'.
# write a catalog YAML file for GeoTIFF images
catalog_images = os.path.join(data_folder, 'catalog_images.yaml')

# NOTE: only the 'riodejaneiro' source name is used below; the other source names are assumed
with open(catalog_images, 'w') as f:
    f.write('''
plugins:
  source:
    - module: intake_xarray
sources:
  cancun:
    driver: rasterio
    description: 'GeoTIFF image in Cancun'
    args:
      urlpath: "{{ CATALOG_DIR }}/input/tiff/cancun_20180914.tif"
  mytilini:
    driver: rasterio
    description: 'GeoTIFF image in Mytilini'
    args:
      urlpath: "{{ CATALOG_DIR }}/input/tiff/mytilini_20210621.tif"
  riodejaneiro:
    driver: rasterio
    description: 'GeoTIFF image in Rio de Janeiro'
    args:
      urlpath: "{{ CATALOG_DIR }}/input/tiff/RioDeJaneiro.tif"
  tangshan:
    driver: rasterio
    description: 'GeoTIFF image in Tangshan'
    args:
      urlpath: "{{ CATALOG_DIR }}/input/tiff/tangshan_20180130.tif"
''')

cat_images = intake.open_catalog(catalog_images)

Inspect the Sentinel-2 data image#

In the example below, we inspect the .tif image provided in the Zenodo repository, called “RioDeJaneiro.tif”.

Let’s investigate the loaded data-array fetched by intake.

image = cat_images['riodejaneiro'].read()

print('shape =', image.shape)
shape = (13, 1466, 2138)

Compute model predictions#

In this work we focus on the spatial patterns of the floating targets, thus we address this detection task as a binary classification problem of floating objects versus water surface. A deep learning-based segmentation model was chosen to perform the detection and delineation of the floating targets: a U-Net Convolutional Neural Network (CNN). This model has been pre-trained on a large open-source hand-labeled Sentinel-2 dataset, developed by the authors, containing both Level-1C Top-Of-Atmosphere (TOA) and Level-2A Bottom-Of-Atmosphere (BOA) products over coastal water bodies. For more details, please refer to ESA-PhiLab/floatingobjects.
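
To make the binary, per-pixel formulation concrete, here is a minimal sketch of how raw network outputs (logits) become probabilities of "floating object" versus "water". It uses NumPy and a hypothetical 2x3 logit map in place of a real U-Net output; `sigmoid` is a helper defined here for illustration.

```python
import numpy as np


def sigmoid(logits):
    """Map raw model outputs (logits) to probabilities in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-logits))


# Hypothetical 2x3 logit map standing in for a U-Net output
logits = np.array([[-4.0, 0.0, 4.0],
                   [-1.0, 1.0, 2.0]])
probabilities = sigmoid(logits)

# Each pixel gets an independent probability of being a floating object
print(probabilities.round(3))
```

A logit of zero corresponds to a probability of 0.5, and larger logits push the pixel towards the floating-object class.
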

Load the model and its pre-trained weights#

For the U-Net, there are two pre-trained models available (described in this link). For the sake of simplicity we are loading the weights of the ‘unet_seed0’, which refers to a U-Net architecture pre-trained for 50 epochs, with batch size of 160, learning rate of 0.001 and seed 0.

#Device to run the computations on
device = 'cuda' if torch.cuda.is_available() else 'cpu'

#Available models and weights
unet_seed0 = torch.hub.load('ESA-PhiLab/floatingobjects:master', 'unet_seed0', map_location=torch.device(device))
#unet_seed1 = torch.hub.load('ESA-PhiLab/floatingobjects:master', 'unet_seed1', map_location=torch.device(device))

#Select model
model = unet_seed0
Downloading: "" to /home/jovyan/.cache/torch/hub/

Compute predictions#

Input the image to the model to yield predictions.

l1cbands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B10", "B11", "B12"]
l2abands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12"]

image = image.assign_coords(band_id=('band', l1cbands))
image = image.set_index(band="band_id")

#If L1C image (13 bands), read only the 12 bands compatible with L2A data
if (image.shape[0] == 13):
    image = image.sel(band=l2abands)

image= image[:,400:1400,400:1400] #subset to avoid memory issues in Binder (TODO dask delay or map_blocks might help to manipulate the whole image)
image = image.values.astype(float)
image *= 1e-4
image = torch.from_numpy(image)

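#Compute predictions
model = model.to(device).eval()  # keep model and input on the same device, in inference mode
with torch.no_grad():
    # add a batch dimension and cast to float32 to match the model weights
    x = image.unsqueeze(0).float().to(device)
    # the model returns logits; the sigmoid maps them to per-pixel probabilities
    y_logits = torch.sigmoid(model(x).squeeze(0))
    y_score = y_logits.cpu().detach().numpy()[0]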
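
The preprocessing in the cell above (dropping the band that only exists in L1C products and scaling by `1e-4`) can be sketched on a synthetic array. The 4x4 scene and the random digital numbers are hypothetical stand-ins for a real Sentinel-2 tile, and plain NumPy indexing replaces the xarray selection used in the notebook.

```python
import numpy as np

l1cbands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B10", "B11", "B12"]
l2abands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12"]

# Synthetic 13-band scene of digital numbers (hypothetical 4x4 pixels)
rng = np.random.default_rng(0)
scene_l1c = rng.integers(0, 10000, size=(len(l1cbands), 4, 4))

# Keep only the bands that also exist in L2A products (this drops B10)
keep = [l1cbands.index(b) for b in l2abands]
scene = scene_l1c[keep]

# Scale digital numbers by 1e-4, mirroring `image *= 1e-4` in the notebook
reflectance = scene.astype(np.float32) * 1e-4

print(scene_l1c.shape, "->", reflectance.shape)
```

After this step the band order matches `l2abands`, which is why the index functions further down look bands up with `l2abands.index(...)`.
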

Compute NDVI and FDI indices#

For detecting marine litter, pixelwise spectral features such as the Normalized Difference Vegetation Index (NDVI) and the Floating Debris Index (FDI) are often chosen as problem-specific features when using model-driven classifiers (e.g. Random Forest or Naïve Bayes). In this case, because we’re applying a data-driven approach (U-Net), we only resort to these indices for visual inspection purposes. To compute these indices, we used the following equations:

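\[NDVI = \dfrac{R_{rs,NIR} - R_{rs,RED}}{R_{rs,NIR} + R_{rs,RED}}\]

\[FDI = R_{rs,NIR} - R'_{rs,NIR}\]

\[R'_{rs,NIR} = R_{rs,RED_2} + (R_{rs,SWIR_1} - R_{rs,RED_2}) \times \dfrac{(\lambda_{NIR} - \lambda_{RED})}{(\lambda_{SWIR_1} - \lambda_{RED})} \times 10\]

where: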


  • \(R_{rs,NIR}\) is the spectral reflectance measured in the near infrared waveband (band B08)

  • \(R_{rs,RED}\) is the spectral reflectance measured in the red waveband (band B04)

  • \(R_{rs,RED_2}\) is the spectral reflectance measured in the red edge waveband (band B06)

  • \(R_{rs,SWIR_1}\) is the spectral reflectance measured in the shortwave infrared (band B11)

  • \(\lambda_{NIR} = 832.9\) nm

  • \(\lambda_{RED} = 664.8\) nm

  • \(\lambda_{SWIR_1} = 1612.05\) nm
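
As a worked example of the arithmetic, the sketch below evaluates both indices for a single pixel, following the functions defined next: FDI is the measured NIR reflectance minus the interpolated baseline \(R'_{rs,NIR}\), and NDVI is the normalised difference of NIR and red. The four reflectance values are hypothetical and chosen only for illustration.

```python
# Hypothetical reflectances for one pixel (values chosen for illustration only)
nir, red, red_edge2, swir1 = 0.08, 0.03, 0.05, 0.02

lambda_nir, lambda_red, lambda_swir1 = 832.9, 664.8, 1612.05

# Wavelength ratio used to interpolate between the red-edge and SWIR bands
ratio = (lambda_nir - lambda_red) / (lambda_swir1 - lambda_red)

# Baseline NIR reflectance expected from the interpolation
nir_prime = red_edge2 + (swir1 - red_edge2) * ratio * 10

fdi = nir - nir_prime
ndvi = (nir - red) / (nir + red)

print(f"ratio = {ratio:.4f}, NIR' = {nir_prime:.4f}, FDI = {fdi:.4f}, NDVI = {ndvi:.4f}")
```

A pixel whose NIR reflectance sits well above the interpolated baseline yields a positive FDI, which is the signal the index is designed to highlight.
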

def calculate_fdi(scene):
    '''Compute FDI index'''
    NIR = scene[l2abands.index("B8")]
    RED2 = scene[l2abands.index("B6")]
    SWIR1 = scene[l2abands.index("B11")]

    lambda_NIR = 832.9
    lambda_RED = 664.8
    lambda_SWIR1 = 1612.05
    NIR_prime = RED2 + (SWIR1 - RED2) * 10 * (lambda_NIR - lambda_RED) / (lambda_SWIR1 - lambda_RED)
    return NIR - NIR_prime

def calculate_ndvi(scene):
    '''Compute NDVI index'''
    NIR = scene[l2abands.index("B8")].float()
    RED = scene[l2abands.index("B4")].float()
    return (NIR - RED) / (NIR + RED + 1e-12)

#Compute the NDVI and FDI bands corresponding to the image
fdi = calculate_fdi(image).cpu().detach().numpy()
ndvi = calculate_ndvi(image).cpu().detach().numpy()

#Compute RGB representation
tensor = np.stack([image[l2abands.index('B4')], image[l2abands.index('B3')], image[l2abands.index('B2')]])
rgb = equalize_hist(tensor.swapaxes(0,1).swapaxes(1,2))

#Configure visualisation settings
cmap_magma = plt.get_cmap('magma')
cmap_viridis = plt.get_cmap('viridis')
cmap_terrain = plt.get_cmap('terrain')
norm_fdi = matplotlib.colors.Normalize(vmin=0, vmax=0.1)
norm_ndvi = matplotlib.colors.Normalize(vmin=-.4, vmax=0.4)
norm = matplotlib.colors.Normalize(vmin=0, vmax=0.4)
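
The normalisation objects defined above rescale an index to the 0 to 1 range, and a colormap then converts each rescaled value to a colour. The sketch below illustrates this on a hypothetical 2x2 array standing in for an FDI map; the array values are assumptions made for the example.

```python
import numpy as np
import matplotlib.colors
import matplotlib.pyplot as plt

# Hypothetical 2x2 index values standing in for an FDI map
values = np.array([[0.00, 0.05],
                   [0.10, 0.20]])

norm = matplotlib.colors.Normalize(vmin=0, vmax=0.1)
cmap = plt.get_cmap("magma")

rgba = cmap(norm(values))   # shape (rows, cols, 4): red, green, blue, alpha
rgb = rgba[:, :, 0:3]       # drop the alpha channel before plotting

print(rgba.shape, rgb.shape)
```

Values above `vmax` (here 0.20) are drawn with the colour at the top of the colormap, so the choice of `vmin` and `vmax` controls the contrast of the displayed index.
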

Here we create the interactive plots.

general_settings = {'x':'x', 'y':'y', 'data_aspect':1, 'flip_yaxis':True, 
                    'xaxis':False, 'yaxis':None, 'tools':['tap', 'box_select']}

#convert to 'xarray.DataArray'
RGB_xr = xr.DataArray(rgb, dims=['y', 'x', 'band'], 
                      coords={'y': np.arange(rgb.shape[0]),
                              'x': np.arange(rgb.shape[1]), 
                              'band': np.arange(rgb.shape[2])})
plot_RGB = RGB_xr.hvplot.rgb(**general_settings, bands='band', title='RGB')
FDI = cmap_magma(norm_fdi(fdi))
FDI_tmp = FDI[:,:,0:3]
#convert to 'xarray.DataArray'
FDI_xr = xr.DataArray(FDI_tmp, dims=['y', 'x', 'band'], 
                      coords={'y': np.arange(FDI_tmp.shape[0]),
                              'x': np.arange(FDI_tmp.shape[1]), 
                              'band': np.arange(FDI_tmp.shape[2])})
plot_FDI = FDI_xr.hvplot.rgb(**general_settings, bands='band', title='FDI')

NDVI = cmap_viridis(norm_ndvi(ndvi))
NDVI_tmp = NDVI[:,:,0:3]
#convert to 'xarray.DataArray'
NDVI_xr = xr.DataArray(NDVI_tmp, dims=['y', 'x', 'band'], 
                       coords={'y': np.arange(NDVI_tmp.shape[0]),
                               'x': np.arange(NDVI_tmp.shape[1]),
                               'band': np.arange(NDVI_tmp.shape[2])})
plot_NDVI = NDVI_xr.hvplot.rgb(**general_settings, bands='band', title='NDVI')

Predictions = cmap_magma(norm(y_score))
#convert to 'xarray.DataArray'
Predictions_xr = xr.DataArray(Predictions, dims=['y', 'x', 'band'], 
                              coords={'y': np.arange(Predictions.shape[0]),
                                      'x': np.arange(Predictions.shape[1]), 
                                      'band': np.arange(Predictions.shape[2])})
plot_Predictions = Predictions_xr.hvplot.rgb(**general_settings, bands='band', title='Predictions')

Classification = np.where(y_score>0.4, 1, 0)
#convert to 'xarray.DataArray'
Classification_xr = xr.DataArray(Classification, dims=['y', 'x'], 
                                 coords={'y': np.arange(Classification.shape[0]),
                                         'x': np.arange(Classification.shape[1])})
plot_Classification = Classification_xr.hvplot(**general_settings, cmap='viridis', colorbar=False, title='Classification')
cplot =  plot_RGB + hv.Empty() + plot_FDI + plot_NDVI + plot_Predictions + plot_Classification
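
The classification layer is a thresholded version of the prediction scores. As a minimal sketch, assuming a hypothetical 2x3 probability map in place of `y_score`, the same 0.4 cut-off gives a binary mask from which simple summaries can be derived:

```python
import numpy as np

# Hypothetical probability map standing in for y_score
scores = np.array([[0.05, 0.39, 0.41],
                   [0.80, 0.40, 0.10]])

threshold = 0.4  # same cut-off as in the cell above
mask = np.where(scores > threshold, 1, 0)

# Summarise how much of the scene was flagged as floating objects
n_detected = int(mask.sum())
fraction = n_detected / mask.size
print(mask)
print(f"{n_detected} of {mask.size} pixels flagged ({fraction:.0%})")
```

The comparison is strict, so a pixel scoring exactly 0.4 stays in the water class. Lowering the threshold flags more pixels at the cost of more false detections.
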