Convolutional Neural Network (CNN) Based Classification Workflow Example
This example explains how to perform supervised deep learning classification using a Convolutional Neural Network (CNN).
Setup & Imports
We begin by importing the required modules, setting up the environment, and downloading the sample QuickBird satellite image used in this example.
# ! pip install nickyspatial
Summary
In this notebook we will perform the following steps:
- Load a sample raster image.
- Perform segmentation on the raster.
- Define classes and collect samples.
- Apply supervised deep learning classification using a Convolutional Neural Network (CNN).
- Explore additional functions: Merge_regions, Enclosed_by, Touched_by.
import os
import requests
import matplotlib.pyplot as plt
import pandas as pd
from nickyspatial import (
    plot_layer_interactive_plotly,
    plot_classification,
    plot_sample,
    plot_layer,
    read_raster,
    plot_training_history,
    LayerManager,
    SlicSegmentation,
    MergeRuleSet,
    EnclosedByRuleSet,
    TouchedByRuleSet,
    SupervisedClassifierDL,
    Layer,
)
output_dir = "output"
os.makedirs(output_dir, exist_ok=True)
data_dir = "data"
os.makedirs(data_dir, exist_ok=True)
raster_path = os.path.join(data_dir, "sample.tif")
# Download the sample raster only if it is not already present
if not os.path.exists(raster_path):
    url = "https://github.com/kshitijrajsharma/nickyspatial/raw/refs/heads/master/data/sample.tif"
    print(f"Downloading sample raster from {url}...")
    response = requests.get(url, timeout=60)
    response.raise_for_status()  # Ensure the download succeeded
    with open(raster_path, "wb") as f:
        f.write(response.content)
    print(f"Downloaded sample raster to {raster_path}")
else:
    print(f"Using existing raster at: {raster_path}")
Reading the Raster
We now read the raster data and print some basic information about the image.
image_data, transform, crs = read_raster(raster_path)
print(f"Image dimensions: {image_data.shape}")
print(f"Coordinate system: {crs}")
Performing Segmentation
Here we perform SLIC superpixel segmentation. A LayerManager
is used to keep track of all layers created in the process. The nickyspatial package uses a Layer object, a vector segmentation tied to the underlying raster, similar in concept to a layer in eCognition.
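As background, the segmenter used here, SlicSegmentation, is named after SLIC (Simple Linear Iterative Clustering). In the standard SLIC formulation, each pixel is assigned to the nearest cluster center using a combined distance. Here $d_c$ is the spectral (color) distance, $d_s$ is the spatial distance, $S$ is the initial grid interval for $N$ pixels and $K$ superpixels, and $m$ is a compactness weight:

```latex
D = \sqrt{d_c^{2} + \left(\frac{d_s}{S}\right)^{2} m^{2}}, \qquad S = \sqrt{\frac{N}{K}}
```

A larger $m$ gives spatial proximity more weight and produces more regular, compact segments. This notebook does not confirm how nickyspatial's `scale` and `compactness` arguments map onto $K$ and $m$, so treat that correspondence as an assumption.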
manager = LayerManager()
segmenter = SlicSegmentation(scale=20, compactness=0.50)
segmentation_layer = segmenter.execute(
    image_data,
    transform,
    crs,
    layer_manager=manager,
    layer_name="Base_Segmentation",
)
print("Segmentation layer created:")
print(segmentation_layer)
Visualizing Segmentation
We utilize the built-in plotting function to visualize the segmentation. The image will be displayed inline.
# plt.close("all")
# %matplotlib inline
fig1 = plot_layer(layer=segmentation_layer, image_data=image_data, rgb_bands=(3, 2, 1), show_boundaries=True, figsize=(10, 8))
plt.show()
fig1.savefig(os.path.join(output_dir, "1_segmentation.png"))
Sample Data Collection
Using the plotly package, we plot an interactive map to collect segment IDs (samples) for classification.
Hover the mouse over the map to display each segment's segment_id.
plot_layer_interactive_plotly(segmentation_layer, image_data, rgb_bands=(3, 2, 1), show_boundaries=True, figsize=(900, 600))
# Sample Data for Classification
# This section defines the sample data used for classification.
# Each class is assigned a list of segment IDs and a specific color for visualization.
samples = {
"Water": [102, 384, 659, 1142, 1662, 1710, 2113, 2182, 2481, 1024],
"Builtup": [467, 1102, 1431, 1984, 1227, 1736, 774, 1065],
"Vegetation": [832, 1778, 2035, 1417, 1263, 242, 2049, 2397],
}
classes_color = {"Water": "#3437c2", "Builtup": "#de1421", "Vegetation": "#0f6b2f"}
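Before training, it is worth checking the sample dictionary for mistakes. One example is the same segment ID listed under two classes, which would give the classifier contradictory labels. The pure-Python sketch below builds a segment-to-class lookup and reports per-class counts. It uses a hypothetical helper, `invert_samples`, which is not part of nickyspatial, and it does not check that the IDs actually exist in the segmentation layer.

```python
from collections import Counter

# Same structure as the `samples` dictionary above: class name -> list of segment IDs.
samples = {
    "Water": [102, 384, 659, 1142, 1662, 1710, 2113, 2182, 2481, 1024],
    "Builtup": [467, 1102, 1431, 1984, 1227, 1736, 774, 1065],
    "Vegetation": [832, 1778, 2035, 1417, 1263, 242, 2049, 2397],
}


def invert_samples(samples):
    """Hypothetical helper: build a segment_id -> class lookup, rejecting IDs used by several classes."""
    counts = Counter(seg_id for ids in samples.values() for seg_id in ids)
    duplicates = sorted(seg_id for seg_id, n in counts.items() if n > 1)
    if duplicates:
        raise ValueError(f"Segment IDs assigned to more than one class: {duplicates}")
    return {seg_id: class_name for class_name, ids in samples.items() for seg_id in ids}


lookup = invert_samples(samples)
print(len(lookup))  # 26 labelled segments
print({name: len(ids) for name, ids in samples.items()})  # per-class sample counts
```

Small, unbalanced sample sets (here 8 to 10 segments per class) are a common reason for unstable validation curves, so the per-class counts are worth keeping in view.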
Sample Data Visualization
In this step, the defined sample segments are visualized.
# Copy the segment GeoDataFrame and tag each sample segment with its class name
sample_objects = segmentation_layer.objects.copy()
sample_objects["classification"] = None
for class_name, segment_ids in samples.items():
    sample_objects.loc[sample_objects["segment_id"].isin(segment_ids), "classification"] = class_name
# Wrap the modified GeoDataFrame back into a Layer
sample_layer = Layer(name="Sample Classification", type="classification")
sample_layer.objects = sample_objects
fig = plot_sample(
    sample_layer,
    image_data=image_data,
    rgb_bands=(3, 2, 1),
    transform=transform,
    class_field="classification",
    class_color=classes_color,
    figsize=(10, 8),
)
plt.show()
params = {
"patch_size": (5, 5),
"epochs": 50,
"batch_size": 16,
"early_stopping_patience": 5,
"hidden_layers": [
{"filters": 32, "kernel_size": 3, "max_pooling": True},
{"filters": 64, "kernel_size": 3, "max_pooling": True},
],
"use_batch_norm": False,
"dense_units": 64,
}
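The `patch_size` and `hidden_layers` settings interact. Each 2×2 max-pooling step roughly halves the spatial size of the patch, so a small 5×5 patch can only support a couple of pooling layers. The pure-Python sketch below traces the feature-map shape through the `hidden_layers` defined above. It assumes stride-1 convolutions, 2×2 non-overlapping pooling, and 'same' padding by default. The notebook does not show nickyspatial's actual padding and pooling settings, so treat those as assumptions.

```python
def conv_output_size(size, kernel_size, padding):
    """Spatial size after a stride-1 convolution ('same' keeps the size, 'valid' shrinks it)."""
    if padding == "same":
        return size
    return size - kernel_size + 1


def pool_output_size(size, pool_size=2):
    """Spatial size after non-overlapping max pooling (stride equals pool_size, floor division)."""
    return size // pool_size


def trace_feature_maps(patch_size, hidden_layers, padding="same"):
    """Return the (height, width, channels) shape after each hidden layer, or raise if it collapses."""
    height, width = patch_size
    shapes = []
    for layer in hidden_layers:
        k = layer["kernel_size"]
        height = conv_output_size(height, k, padding)
        width = conv_output_size(width, k, padding)
        if layer.get("max_pooling"):
            height, width = pool_output_size(height), pool_output_size(width)
        if height < 1 or width < 1:
            raise ValueError(f"Patch too small for this architecture: {(height, width)}")
        shapes.append((height, width, layer["filters"]))
    return shapes


hidden_layers = [
    {"filters": 32, "kernel_size": 3, "max_pooling": True},
    {"filters": 64, "kernel_size": 3, "max_pooling": True},
]
print(trace_feature_maps((5, 5), hidden_layers))  # [(2, 2, 32), (1, 1, 64)]
```

Under 'valid' padding, the same 5×5 patch would fail at the second convolution (5 → 3 → 1 after pooling). Checking the shapes this way is worthwhile whenever `patch_size` or the number of pooling layers is changed.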
# Define the SupervisedClassifierDL and execute it. It returns the classified layer, the model training
# history, the evaluation results, count_dict, and the IDs of segments whose patches were invalid.
CNN_classification = SupervisedClassifierDL(
    name="CNN Classification", classifier_type="Convolution Neural Network (CNN)", classifier_params=params
)
CNN_classification_layer, model_history, eval_result, count_dict, invalid_patches_segments_ids = CNN_classification.execute(
    source_layer=segmentation_layer,
    samples=samples,
    image_data=image_data,
    layer_manager=manager,
    layer_name="CNN Classification",
)
print(f"** Invalid Patch Segments: {invalid_patches_segments_ids}")
print("** Evaluation **")
print(f"Accuracy Assessment: {eval_result['accuracy']}")
cm = eval_result["confusion_matrix"]
df_cm = pd.DataFrame(cm, index=samples.keys(), columns=samples.keys())
df_cm.index.name = "Predicted Label"
print(df_cm)
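You can cross-check the overall accuracy printed above against the confusion matrix: it is the sum of the diagonal divided by the total count. Per-class recall and precision (producer's and user's accuracy in remote-sensing terms) come from the row and column totals. The pure-Python sketch below uses a hypothetical 3×3 matrix, not this notebook's results. It also assumes the scikit-learn convention of rows as true labels and columns as predicted labels. Check which convention `eval_result["confusion_matrix"]` follows, and in which class order, before relying on the row/column interpretation.

```python
def accuracy_metrics(cm, labels):
    """Overall accuracy plus per-class recall/precision from a square confusion matrix.

    Assumes rows are true labels and columns are predicted labels.
    """
    n = len(cm)
    total = sum(sum(row) for row in cm)
    correct = sum(cm[i][i] for i in range(n))
    metrics = {"overall_accuracy": correct / total}
    for i, label in enumerate(labels):
        row_total = sum(cm[i])                       # samples whose true class is `label`
        col_total = sum(cm[r][i] for r in range(n))  # samples predicted as `label`
        metrics[label] = {
            "recall": cm[i][i] / row_total if row_total else 0.0,
            "precision": cm[i][i] / col_total if col_total else 0.0,
        }
    return metrics


# Hypothetical counts for illustration only (not this notebook's output).
example_cm = [
    [8, 1, 1],
    [0, 9, 1],
    [2, 0, 8],
]
m = accuracy_metrics(example_cm, ["Water", "Builtup", "Vegetation"])
print(m["overall_accuracy"])  # 25 / 30, about 0.833
```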
# Plotting (training and validation) loss and accuracy
plot_training_history(model_history)
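The `early_stopping_patience` parameter means the training history may end well before `epochs=50`. In the usual patience scheme (as in Keras' `EarlyStopping` callback), training halts once the monitored validation loss has not improved for that many consecutive epochs. The pure-Python sketch below applies that rule to a hypothetical loss sequence. This notebook does not confirm whether nickyspatial monitors validation loss or another metric, or whether it restores the best weights.

```python
def early_stopping_epoch(val_losses, patience, min_delta=0.0):
    """Return (stop_epoch, best_epoch), both 1-based, under patience-based early stopping.

    Training stops once `patience` consecutive epochs fail to improve on the best
    validation loss by more than `min_delta`.
    """
    best_loss = float("inf")
    best_epoch = 0
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss - min_delta:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses), best_epoch


# Hypothetical validation losses, for illustration only.
losses = [0.90, 0.70, 0.55, 0.50, 0.52, 0.51, 0.53, 0.50, 0.54, 0.60]
print(early_stopping_epoch(losses, patience=5))  # (9, 4)
```

The best loss occurs at epoch 4, and five non-improving epochs later training stops at epoch 9. Epoch 8 matches but does not beat the best loss, so it still counts as a non-improving epoch.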
# plotting classification result
fig4 = plot_classification(CNN_classification_layer, class_field="classification", class_color=classes_color)
Applying merge rule
In this step we further refine the classification by merging regions based on class value, i.e., adjacent segments are merged if they share the same class label.
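Conceptually, merging adjacent segments that share a class is a connected-components problem on the segment adjacency graph. The pure-Python sketch below uses a union-find structure, a hypothetical `merge_same_class` helper, and a toy adjacency list. MergeRuleSet operates on the layer's GeoDataFrame, and this notebook does not show its internal algorithm. The sketch illustrates the idea rather than the library's implementation.

```python
def merge_same_class(classes, adjacency, class_values):
    """Group adjacent segments that share a class listed in `class_values`.

    classes:   dict segment_id -> class label
    adjacency: iterable of (segment_id, segment_id) pairs that share a boundary
    Returns dict segment_id -> merged region id (the smallest segment_id in its group).
    """
    parent = {seg: seg for seg in classes}

    def find(seg):
        while parent[seg] != seg:
            parent[seg] = parent[parent[seg]]  # path halving keeps trees shallow
            seg = parent[seg]
        return seg

    for a, b in adjacency:
        if classes[a] == classes[b] and classes[a] in class_values:
            root_a, root_b = find(a), find(b)
            if root_a != root_b:
                # Keep the smaller id as the root so region ids are deterministic
                parent[max(root_a, root_b)] = min(root_a, root_b)

    return {seg: find(seg) for seg in classes}


# Toy example: a chain of six segments.
classes = {1: "Water", 2: "Water", 3: "Builtup", 4: "Builtup", 5: "Vegetation", 6: "Vegetation"}
adjacency = [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6)]
regions = merge_same_class(classes, adjacency, ["Water", "Vegetation"])
print(regions)  # {1: 1, 2: 1, 3: 3, 4: 4, 5: 5, 6: 5}
```

Segments 3 and 4 stay separate even though both are "Builtup", because that class is not in `class_values`. This mirrors `class_value=["Water", "Vegetation"]` in the call that follows.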
merger = MergeRuleSet("Merge Segmentation")
class_value = ["Water", "Vegetation"]
merged_layer = merger.execute(
    source_layer=CNN_classification_layer,
    class_column_name="classification",
    class_value=class_value,
    layer_manager=manager,
    layer_name="Merged CNN Classification",
)
fig4 = plot_classification(merged_layer, class_field="classification", class_color=classes_color, figsize=(10, 8))
Applying Enclosed_by rule
This rule is also applied based on class labels. It determines whether an object/segment is completely enclosed by another object or class and returns the updated layer.
This is helpful for applying context-aware rules in classification.
In this example, we assign the new class "Urban Vegetation" to "Vegetation" segments that are enclosed by "Builtup".
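The enclosed-by idea can be sketched on a segment adjacency graph: a segment of class A counts as enclosed by class B if every neighbour it has belongs to class B. The pure-Python sketch below uses a hypothetical `reclassify_enclosed` helper and toy neighbour sets. It is a simplification and an assumption, not EnclosedByRuleSet's actual geometric test. A geometric implementation would also need to handle segments on the image border, which have no neighbour on one side.

```python
def reclassify_enclosed(classes, neighbours, class_a, class_b, new_class):
    """Relabel class_a segments whose neighbours all belong to class_b.

    classes:    dict segment_id -> class label
    neighbours: dict segment_id -> set of adjacent segment_ids
    """
    updated = dict(classes)
    for seg, label in classes.items():
        adjacent = neighbours.get(seg, set())
        if label == class_a and adjacent and all(classes[n] == class_b for n in adjacent):
            updated[seg] = new_class
    return updated


# Toy example: segment 1 is surrounded by built-up segments; segment 4 also touches water.
classes = {1: "Vegetation", 2: "Builtup", 3: "Builtup", 4: "Vegetation", 5: "Water"}
neighbours = {1: {2, 3}, 2: {1, 3, 4}, 3: {1, 2, 5}, 4: {2, 5}, 5: {3, 4}}
result = reclassify_enclosed(classes, neighbours, "Vegetation", "Builtup", "Urban Vegetation")
print(result[1], result[4])  # Urban Vegetation Vegetation
```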
enclosed_by_rule = EnclosedByRuleSet()
enclosed_by_layer = enclosed_by_rule.execute(
    source_layer=merged_layer,
    class_column_name="classification",
    class_value_a="Vegetation",
    class_value_b="Builtup",
    new_class_name="Urban Vegetation",
    layer_manager=manager,
    layer_name="enclosed_by_layer",
)
classes_color["Urban Vegetation"] = "#84f547"
fig4 = plot_classification(enclosed_by_layer, class_field="classification", class_color=classes_color)
Applying Touched_by rule
This rule is also applied based on class labels. It determines whether an object/segment is in direct contact with another object or class, that is, whether they share a boundary.
This is useful for implementing context-aware rules in classification, especially when spatial relationships between features matter.
In this example, we assign the new class "Builtup near WaterBodies" to "Builtup" segments that are touched by "Water".
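Where enclosed-by requires all neighbours to match, touched-by only requires one. The pure-Python sketch below uses a hypothetical `reclassify_touching` helper on toy adjacency data. It assumes "touching" means sharing a boundary in a precomputed neighbour map, which is a simplification of whatever geometric test TouchedByRuleSet performs.

```python
def reclassify_touching(classes, neighbours, class_a, class_b, new_class):
    """Relabel class_a segments that share a boundary with at least one class_b segment.

    classes:    dict segment_id -> class label
    neighbours: dict segment_id -> set of adjacent segment_ids
    """
    updated = dict(classes)
    for seg, label in classes.items():
        if label == class_a and any(classes[n] == class_b for n in neighbours.get(seg, ())):
            updated[seg] = new_class
    return updated


# Toy example: segment 1 borders water; segment 3 only borders vegetation.
classes = {1: "Builtup", 2: "Water", 3: "Builtup", 4: "Vegetation"}
neighbours = {1: {2, 4}, 2: {1}, 3: {4}, 4: {1, 3}}
result = reclassify_touching(classes, neighbours, "Builtup", "Water", "Builtup near WaterBodies")
print(result[1], "|", result[3])  # Builtup near WaterBodies | Builtup
```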
touched_by_rule = TouchedByRuleSet()
touched_by_layer = touched_by_rule.execute(
    source_layer=enclosed_by_layer,
    class_column_name="classification",
    class_value_a="Builtup",
    class_value_b="Water",
    new_class_name="Builtup near WaterBodies",
    layer_manager=manager,
    layer_name="touched_by_layer",
)
classes_color["Builtup near WaterBodies"] = "#1df0e2"
fig4 = plot_classification(touched_by_layer, class_field="classification", class_color=classes_color)
# Apply a second merge rule to dissolve adjacent segments within the two built-up classes
merge_rule = MergeRuleSet("MergeByVegAndType")
merged_layer2 = merge_rule.execute(
    source_layer=touched_by_layer,
    class_column_name="classification",
    class_value=["Builtup near WaterBodies", "Builtup"],
    layer_manager=manager,
    layer_name="Merged CNN Classification 2",
)
fig4 = plot_classification(merged_layer2, class_field="classification", class_color=classes_color)