深度学习:使用深度学习处理地球观测数据进行语义分割(数据准备)

  • 时间:2025-12-08 22:33 作者: 来源: 阅读:0
  • 扫一扫,手机访问
摘要:使用深度学习处理地球观测数据进行语义分割:本教程是一份指南,用于将栅格数据和标签处理为深度学习语义分割准备的格式。 设置笔记本 # 安装必要的库(使用 -q 参数静默安装) !pip install -q rasterio==1.3.8 # 用于处理栅格地理空间数据 !pip install -q geopandas==0.13.2 # 用于处理矢量地理空间数据 !pip instal

使用深度学习处理地球观测数据进行语义分割:本教程是一份指南,用于将栅格数据和标签处理为深度学习语义分割准备的格式。

设置笔记本


# 安装必要的库(使用 -q 参数静默安装)
!pip install -q rasterio==1.3.8  # 用于处理栅格地理空间数据
!pip install -q geopandas==0.13.2  # 用于处理矢量地理空间数据
!pip install -q radiant_mlhub  # 用于访问遥感数据集,详见:https://mlhub.earth/

# 导入所需库
import os, glob, tarfile, json  # 系统、文件操作相关库
from itertools import product  # 用于创建笛卡尔积
from pathlib import Path  # 面向对象的文件路径操作

import numpy as np  # 数值计算库
from fractions import Fraction  # 分数运算
import matplotlib.pyplot as plt  # 绘图库
import matplotlib as mpl  # matplotlib基础配置
mpl.rcParams['axes.grid'] = False  # 关闭坐标轴网格
mpl.rcParams['figure.figsize'] = (12,12)  # 设置默认图形大小

from sklearn.model_selection import train_test_split  # 数据集划分工具
from sklearn.preprocessing import MinMaxScaler  # 数据归一化
import matplotlib.image as mpimg  # 图像读取
import pandas as pd  # 数据处理
from PIL import Image  # 图像处理

import rasterio  # 栅格数据处理
from rasterio import features, windows  # 栅格特征提取和窗口操作

import geopandas as gpd  # 地理空间数据处理
import cv2  # OpenCV计算机视觉库

from tqdm.notebook import tqdm  # 进度条显示

from radiant_mlhub import Dataset, Collection  # Radiant Earth MLHub数据集访问
from google.colab import drive  # Google Colab云盘挂载

# 配置Radiant Earth MLHub访问凭证
!mlhub configure

# 判断当前运行环境是否为Google Colab
if 'google.colab' in str(get_ipython()):
    # 挂载Google云盘
    drive.mount('/content/gdrive')
    # 设置处理输出目录(指向Google云盘)
    processed_outputs_dir = '/content/gdrive/My Drive/tf-eo-devseed-2-processed-outputs/'
    user_outputs_dir = '/content/gdrive/My Drive/tf-eo-devseed-2-user_outputs_dir'
    # 如果用户输出目录不存在则创建
    if not os.path.exists(user_outputs_dir):
        os.makedirs(user_outputs_dir)
    print('Running on Colab')  # 提示当前运行在Colab环境
else:
    # 本地运行时的目录设置
    processed_outputs_dir = os.path.abspath("./data/tf-eo-devseed-2-processed-outputs")
    user_outputs_dir = os.path.abspath('./tf-eo-devseed-2-user_outputs_dir')
    # 如果目录不存在则创建
    if not os.path.exists(user_outputs_dir):
        os.makedirs(user_outputs_dir)
        os.makedirs(processed_outputs_dir)
    print(f'Not running on Colab, data needs to be downloaded locally at {os.path.abspath(processed_outputs_dir)}')

# 切换到用户输出目录
%cd $user_outputs_dir

主要功能说明:

环境准备:安装地理空间数据处理、机器学习和数据集访问所需的Python库库导入:导入数据处理、可视化、机器学习、地理空间分析等各类工具库环境检测:自动检测是否运行在Google Colab云端环境,并相应设置数据存储路径路径配置: Colab环境:使用Google云盘存储处理结果本地环境:使用本地指定目录存储数据 工作目录切换:最后切换到用户输出目录,便于后续文件操作

启用 GPU


# 指定使用的TensorFlow版本(Colab特有命令)
# 注意:这个命令只在Google Colab环境中有效
%tensorflow_version 2.x

import tensorflow as tf

# 检测GPU设备
device_name = tf.test.gpu_device_name()

# 检查是否找到GPU设备
if device_name != '/device:GPU:0':
    # 如果没有找到GPU设备,抛出系统错误
    raise SystemError('GPU设备未找到')
    
# 如果找到GPU设备,打印设备信息
print('找到GPU设备位于: {}'.format(device_name))

访问数据集

数据链接


我们将使用来自 Radiant Earth MLHub 的作物类型分类数据集:该数据集包含 Sentinel-1 的雷达数据、来自 Planet Labs 的 3 米分辨率光学影像以及来自 Sentinel-2 的 10-20 米分辨率光学影像。数据集包含一个训练集和一个测试集,并附带相应的农田多边形标签。


ds = Dataset.fetch('dlr_fusion_competition_germany')
for c in ds.collections:
    print(c.id)

dlr_fusion_competition_germany_train_source_planet
dlr_fusion_competition_germany_train_source_planet_5day
dlr_fusion_competition_germany_train_source_sentinel_1
dlr_fusion_competition_germany_train_source_sentinel_2
dlr_fusion_competition_germany_test_source_planet
dlr_fusion_competition_germany_test_source_planet_5day
dlr_fusion_competition_germany_test_source_sentinel_1
dlr_fusion_competition_germany_test_source_sentinel_2
dlr_fusion_competition_germany_train_labels
dlr_fusion_competition_germany_test_labels


# 定义要下载的数据集集合列表
# 包含训练和测试数据源(Planet卫星影像)以及对应的标签数据
collections = [
    'dlr_fusion_competition_germany_train_source_planet_5day',  # 德国融合竞赛训练数据源(5天间隔的Planet影像)
    'dlr_fusion_competition_germany_test_source_planet_5day',   # 德国融合竞赛测试数据源
    'dlr_fusion_competition_germany_train_labels',              # 训练标签数据
    'dlr_fusion_competition_germany_test_labels'                # 测试标签数据
]

def download(collection_id):
    """下载指定数据集集合"""
    print(f'正在下载 {collection_id}...')
    
    # 从Radiant MLHub获取数据集集合
    collection = Collection.fetch(collection_id)
    
    # 下载数据集到当前目录
    path = collection.download('.')
    
    # 解压下载的tar.gz文件
    tar = tarfile.open(path, "r:gz")  # 以gzip格式打开tar文件
    tar.extractall()  # 解压所有文件到当前目录
    tar.close()  # 关闭tar文件
    
    # 删除压缩文件,只保留解压后的内容
    os.remove(path)
    
def resolve_path(base, path):
    """解析相对路径为绝对路径"""
    # 将基础路径和相对路径组合,并转换为绝对路径
    return Path(os.path.join(base, path)).resolve()
    
def load_df(collection_id):
    """加载数据集信息到DataFrame"""
    # 读取集合的元数据文件
    collection = json.load(open(f'{collection_id}/collection.json', 'r'))
    
    rows = []  # 存储所有数据行
    item_links = []  # 存储项目链接
    
    # 遍历集合中的所有链接,找到项目链接
    for link in collection['links']:
        if link['rel'] != 'item':  # 只处理项目链接
            continue
        item_links.append(link['href'])  # 收集项目链接
    
    # 遍历每个项目文件
    for item_link in item_links:
        item_path = f'{collection_id}/{item_link}'  # 项目文件完整路径
        current_path = os.path.dirname(item_path)   # 当前项目所在目录
        
        # 读取项目元数据
        item = json.load(open(item_path, 'r'))
        
        # 从项目ID中提取瓦片ID(最后一个下划线后的部分)
        tile_id = item['id'].split('_')[-1]
        
        # 处理项目本地的资产文件(标签文件)
        for asset_key, asset in item['assets'].items():
            rows.append([
                tile_id,               # 瓦片ID
                None,                  # 时间戳(标签没有时间信息)
                None,                  # 卫星平台(标签没有平台信息)
                asset_key,             # 资产类型(如:labels)
                str(resolve_path(current_path, asset['href']))  # 文件绝对路径
            ])
            
        # 处理关联的源数据(卫星影像)
        for link in item['links']:
            if link['rel'] != 'source':  # 只处理源数据链接
                continue
                
            link_path = resolve_path(current_path, link['href'])  # 源数据元数据文件路径
            source_path = os.path.dirname(link_path)  # 源数据所在目录
            
            try:
                # 读取源数据元数据
                source_item = json.load(open(link_path, 'r'))
            except FileNotFoundError:
                continue  # 如果文件不存在则跳过
            
            # 提取时间信息
            datetime = source_item['properties']['datetime']
            # 提取卫星平台信息(从集合名称中获取)
            satellite_platform = source_item['collection'].split('_')[-1]
            
            # 处理源数据的所有资产文件(卫星影像波段)
            for asset_key, asset in source_item['assets'].items():
                rows.append([
                    tile_id,               # 瓦片ID
                    datetime,              # 影像采集时间
                    satellite_platform,    # 卫星平台(如:planet)
                    asset_key,             # 资产类型(如:B1, B2, B3, B4等波段)
                    str(resolve_path(source_path, asset['href']))  # 文件绝对路径
                ])
    
    # 创建DataFrame并返回
    return pd.DataFrame(rows, columns=[
        'tile_id',              # 瓦片标识符
        'datetime',             # 数据采集时间(卫星影像才有)
        'satellite_platform',   # 卫星平台(如:planet)
        'asset',                # 资产类型(标签或波段)
        'file_path'             # 文件绝对路径
    ])

# 下载所有数据集集合
for c in collections:
    download(c)

# 加载训练标签和测试标签信息到DataFrame
train_df = load_df('dlr_fusion_competition_germany_train_labels')
test_df = load_df('dlr_fusion_competition_germany_test_labels')

主要功能说明:

数据集定义:定义了四个数据集集合,包含训练/测试的卫星影像和对应的土地覆盖标签

download()函数

从Radiant MLHub平台下载指定的数据集自动解压tar.gz压缩文件清理压缩包,只保留解压后的内容

resolve_path()函数

工具函数,用于将相对路径转换为绝对路径确保文件路径的正确解析

load_df()函数(核心功能):

解析元数据:读取collection.json和item.json文件提取瓦片ID:从项目ID中识别唯一的空间瓦片处理标签数据:加载土地覆盖标签文件信息处理卫星影像:加载关联的卫星影像波段文件构建结构化数据:将所有文件信息整理成DataFrame,包含: 瓦片ID:空间位置标识时间戳:影像采集时间(仅卫星影像)卫星平台:数据来源平台资产类型:具体的数据类型(标签或波段)文件路径:具体文件的绝对路径

数据处理流程

循环下载所有四个数据集分别加载训练标签和测试标签信息到pandas DataFrame为后续的数据读取和处理提供结构化索引

数据结构示例:
生成的DataFrame将包含类似以下结构的数据:


tile_id | datetime           | satellite_platform | asset | file_path
--------|--------------------|-------------------|-------|-----------
123     | 2023-06-01T10:30:00| planet            | B1    | /path/to/band1.tif
123     | 2023-06-01T10:30:00| planet            | B2    | /path/to/band2.tif
123     | None               | None              | labels| /path/to/labels.tif

这样的数据结构便于后续按瓦片、时间或数据类型进行数据检索和配对。

查看标签

通过将 geojson 加载为 GeoDataFrame 来检查存储在其中的类别标签。类别名称和标识符是从此处提供的文档中提取的:https://radiantearth.blob.core.windows.net/mlhub/esa-food-security-challenge/Crops_GT_Brandenburg_Doc.pdf


# 设置pandas显示选项,确保完整显示列内容(不截断)
pd.set_option('display.max_colwidth', None)

# 定义土地覆盖类别(Land Use Land Cover, LULC)数据
# 这是一个分类任务中的类别映射表,将类别名称映射到类别ID
data = {
    'class_names': [  # 类别名称列表
        'Background',    # 0: 背景
        'Wheat',         # 1: 小麦
        'Rye',           # 2: 黑麦
        'Barley',        # 3: 大麦
        'Oats',          # 4: 燕麦
        'Corn',          # 5: 玉米
        'Oil Seeds',     # 6: 油料作物
        'Root Crops',    # 7: 块根作物
        'Meadows',       # 8: 草甸
        'Forage Crops'   # 9: 饲料作物
    ],
    'class_ids': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]  # 对应的类别ID
}

# 创建包含类别名称和ID的DataFrame
classes = pd.DataFrame(data)
print(classes)

# 将类别映射表保存到CSV文件,方便后续使用
classes.to_csv('lulc_classes.csv')

# 接下来,从实际数据中验证类别信息
# 读取训练标签的GeoJSON文件(包含实际的地块多边形和作物类型信息)
labels_geo = gpd.read_file('dlr_fusion_competition_germany_train_labels/dlr_fusion_competition_germany_train_labels_33N_18E_242N/labels.geojson')

# 获取实际数据中存在的所有作物类别ID
classes = labels_geo.crop_id.unique()  # crop_id列包含作物类别ID
classes.sort()  # 对类别ID进行排序
print("实际标签GeoJSON文件中的类别ID: ", classes)

代码功能详解:

类别定义部分

背景说明:在遥感影像分类任务中,需要将像素分为不同的土地覆盖类别类别映射:创建了10个类别的映射表,从0到9分别代表: 0: 背景(非农田或未分类区域)1-9: 各种农作物类型(小麦、黑麦、大麦等) 数据结构:使用pandas DataFrame存储,便于查看和操作数据保存:将类别映射保存为CSV文件,供后续模型训练和评估使用

实际数据验证部分

文件读取:使用geopandas读取GeoJSON格式的标签文件 GeoJSON是一种地理空间数据格式,包含多边形几何信息和属性在这个数据集中,每个多边形代表一个农田地块,crop_id表示作物类型 类别提取:从 crop_id列提取所有唯一的类别ID排序显示:对类别ID进行排序并打印,用于: 验证数据集中实际存在的类别确认是否有缺失类别或额外类别与预定义的类别映射进行对比验证

实际应用场景

语义分割任务:每个像素都会被分类为这10个类别之一农业监测:通过卫星影像自动识别作物类型,用于产量估算、病虫害监测等数据验证:确保理论类别定义与实际数据的一致性

输出示例


   class_names  class_ids
0   Background          0
1        Wheat          1
2          Rye          2
3       Barley          3
4         Oats          4
5         Corn          5
6    Oil Seeds          6
7   Root Crops          7
8      Meadows          8
9  Forage Crops         9

实际标签GeoJSON文件中的类别ID:  [0 1 2 3 4 5 6 7 8 9]

验证结果说明
如果输出显示实际数据中的类别ID与定义的类别ID一致(0-9),说明:

数据集中包含了所有预定义的作物类型没有额外的未知类别类别映射关系是正确的,可以用于后续的模型训练

栅格处理

在读取每个 Planetscope 图像后,我们将使用均值和标准差对其进行归一化,并将所有值重新缩放到 8 位整数。将值重新缩放到 8 位整数可以使训练数据尽可能小,以便我们可以将更大的批量大小装入 GPU 内存。对每个图像进行归一化有助于模型以更高的数值稳定性进行训练,并将数据调整到反映预训练模型预训练数据的分布。


def raster_read(raster_dir):
    """
    读取和处理栅格数据(卫星影像)
    
    参数:
    raster_dir: str - 栅格数据所在目录路径
    
    返回值:
    raster_dir: str - 输入目录路径
    rgbn: rasterio数据集对象 - 处理后的8位影像数据
    rgbn_src: rasterio数据集对象 - 原始影像数据(用于元数据参考)
    target_crs: CRS对象 - 目标坐标系
    """
    print(f"正在读取栅格数据目录: {raster_dir}")

    # 1. 读取波段元数据和数组
    # 打开RGB+NIR(红、绿、蓝、近红外)四波段影像文件
    rgbn = rasterio.open(os.path.join(raster_dir, 'sr.tif'))  # sr.tif可能代表"surface reflectance"(地表反射率)
    rgbn_src = rgbn  # 保存原始数据集的引用,用于后续的元数据
    target_crs = rgbn_src.crs  # 获取坐标系信息
    print(f"RGB+NIR数据源信息: {rgbn}")

    # 2. 读取和处理影像数组
    # 将原始的16位影像重新缩放到8位(0-255范围)
    scale = True  # 控制是否进行缩放处理的开关
    
    if scale:
        # 使用OpenCV的normalize函数进行归一化处理
        # cv2.NORM_MINMAX: 使用最小-最大归一化方法
        # 将像素值从原始范围线性映射到0-255
        rgbn_norm = cv2.normalize(rgbn.read(), None, 0, 255, cv2.NORM_MINMAX)
        
        # 3. 将处理后的8位影像写入新的TIFF文件
        rgbn_norm_out = rasterio.open(
            os.path.join(raster_dir, 'sr_byte_scaled.tif'),  # 输出文件路径
            'w',  # 写入模式
            driver='Gtiff',  # 使用GeoTIFF驱动
            width=rgbn_src.width,  # 保持原始宽度
            height=rgbn_src.height,  # 保持原始高度
            count=4,  # 4个波段(RGB+NIR)
            crs=rgbn_src.crs,  # 保持原始坐标系
            transform=rgbn_src.transform,  # 保持原始地理变换参数
            dtype='uint8'  # 数据类型为8位无符号整数
        )
        
        # 将归一化后的数据写入新文件
        rgbn_norm_out.write(rgbn_norm)
        rgbn_norm_out.close()  # 关闭文件以确保数据写入磁盘
        
        # 重新打开缩放后的文件作为新的数据源
        rgbn = rasterio.open(os.path.join(raster_dir, 'sr_byte_scaled.tif'))
    else:
        # 如果不进行缩放,直接打开之前处理好的8位文件
        rgbn = rasterio.open(os.path.join(raster_dir, 'sr_byte_scaled.tif'))
    
    print("已完成8位缩放处理。")
    
    # 返回处理结果
    return raster_dir, rgbn, rgbn_src, target_crs

接下来我们将计算相关的光谱指数。在 3 波段影像上微调预训练模型时使用光谱指数,是一种快速利用多个多光谱波段信息同时仍能获得预训练好处的方法。
以下是我们将要计算的指标


WDRVI: Wide Dynamic Range Vegetation Index
WDRVI:宽动态范围植被指数
NDVI: Normalized Difference Vegetation Index
NDVI:归一化植被指数
SI: Shadow Index SI: 影子指数


# 计算光谱指数并将它们堆叠成一个3通道图像
def indexnormstack(red, green, blue, nir):
    """
    计算三个光谱指数并将它们归一化堆叠成一个3通道图像
    
    参数:
    red: numpy数组 - 红波段反射率数据
    green: numpy数组 - 绿波段反射率数据  
    blue: numpy数组 - 蓝波段反射率数据
    nir: numpy数组 - 近红外波段反射率数据
    
    返回值:
    index_stack: numpy数组 - 形状为(H, W, 3)的堆叠光谱指数图像
    """
    
    # 1. 定义光谱指数计算函数
    
    def WDRVIcalc(nir, red):
        """
        计算宽动态范围植被指数 (Wide Dynamic Range Vegetation Index)
        
        公式: WDRVI = (α * NIR - Red) / (α * NIR + Red)
        参数 α=0.15 用于调整NIR和Red的相对权重
        用途: 增强植被信号,比NDVI对高植被覆盖区域更敏感
        """
        a = 0.15  # 权重系数
        wdrvi = (a * nir - red) / (a * nir + red)
        return wdrvi
    
    def NPCRIcalc(red, blue):
        """
        计算归一化色素叶绿素反射率指数 (Normalized Pigment Chlorophyll Reflectance Index)
        
        公式: NPCRI = (Red - Blue) / (Red + Blue)
        用途: 估算叶绿素含量,监测植物营养状况
        注意: 此函数在代码中被注释掉了,当前未使用
        """
        npcri = (red - blue) / (red + blue)
        return npcri
    
    def NDVIcalc(nir, red):
        """
        计算归一化差异植被指数 (Normalized Difference Vegetation Index)
        
        公式: NDVI = (NIR - Red) / (NIR + Red)
        用途: 最常用的植被指数,反映植被生长状况和覆盖度
        分母加1e-5防止除零错误
        """
        ndvi = (nir - red) / (nir + red + 1e-5)  # 加小值防止分母为零
        return ndvi
    
    def SIcalc(red, green, blue):
        """
        计算土壤指数 (Soil Index)
        
        公式: SI = ((1-Red)*(1-Green)*(1-Blue))^(1/3)
        用途: 突出土壤特征,区分植被和裸土
        注意: 使用了Fraction确保精确的分数计算
        """
        expo = Fraction('1/3')  # 使用分数确保精确计算立方根
        si = (((1 - red) * (1 - green) * (1 - blue)) ** expo)
        return si
    
    def norm(arr):
        """
        对数组进行归一化到0-255范围
        
        参数:
        arr: numpy数组 - 输入数组
        
        返回值:
        arr_norm: numpy数组 - 归一化后的数组 (0-255)
        """
        # 重塑数组为2D以适配MinMaxScaler
        arr_reshaped = arr.reshape(-1, 1)
        
        # 创建并拟合归一化器
        scaler = MinMaxScaler(feature_range=(0, 255))
        scaler = scaler.fit(arr_reshaped)
        
        # 应用归一化
        arr_norm = scaler.transform(arr_reshaped)
        
        # 重构回原始形状
        arr_norm = arr_norm.reshape(arr.shape)
        
        # 检查重建的代码(被注释掉了)
        # arr_norm = scaler.inverse_transform(arr_norm)
        
        return arr_norm
    
    # 2. 计算三个光谱指数
    wdrvi = WDRVIcalc(nir, red)      # 宽动态范围植被指数
    # npcri = NPCRIcalc(red, blue)   # 注释掉的叶绿素指数
    ndi = NDVIcalc(nir, red)         # 归一化差异植被指数(实际是NDVI)
    si = SIcalc(red, green, blue)    # 土壤指数
    
    # 3. 打印各指数的数值范围(用于调试)
    print(f"wdrvi范围: [{wdrvi.min():.3f}, {wdrvi.max():.3f}], "
          f"ndi范围: [{ndi.min():.3f}, {ndi.max():.3f}], "
          f"si范围: [{si.min():.3f}, {si.max():.3f}]")
    
    # 4. 对每个指数进行归一化(0-255范围)
    wdrvi_norm = norm(wdrvi)
    ndi_norm = norm(ndi)
    si_norm = norm(si)
    
    # 5. 将三个指数堆叠成3通道图像
    # np.dstack: 沿第三轴(深度轴)堆叠数组
    index_stack = np.dstack((wdrvi_norm, ndi_norm, si_norm))
    
    return index_stack

光谱指数解释

1. WDRVI (宽动态范围植被指数)

公式: WDRVI = (α * NIR - Red) / (α * NIR + Red), α=0.15特点: 在高植被覆盖区域比NDVI更敏感值域: 通常为[-1, 1],正值表示植被,负值表示非植被

2. NDVI (归一化差异植被指数)

公式: NDVI = (NIR - Red) / (NIR + Red)特点: 最常用的植被指数,对植被绿度敏感值域: [-1, 1],0.2-0.5表示中等植被,>0.5表示茂密植被

3. SI (土壤指数)

公式: SI = ((1-Red)*(1-Green)*(1-Blue))^(1/3)特点: 突出土壤特征,降低植被影响用途: 识别裸土、区分土壤类型

数据处理流程


原始波段数据 (4个)
    ↓
计算三个光谱指数 (WDRVI, NDVI, SI)
    ↓
分别归一化到0-255范围
    ↓
沿深度轴堆叠成3通道图像
    ↓
输出: (H, W, 3)的numpy数组

应用场景

植被监测: WDRVI和NDVI对植被变化敏感土壤分析: SI指数突出土壤特征特征增强: 相比原始波段,光谱指数能更好地区分地物类型机器学习: 为分类模型提供更有判别力的特征

注意事项

输入数据要求: 反射率值应在合理范围内(通常0-1或0-10000)归一化: 每个指数单独归一化,保持各自的分布特征异常值: 计算中考虑了除零错误,但输入数据应预先清洗数据形状: 输入应为相同形状的2D数组

示例输出


# 假设有波段数据
red_band = ...    # 红波段
green_band = ...  # 绿波段  
blue_band = ...   # 蓝波段
nir_band = ...    # 近红外波段

# 计算光谱指数堆栈
feature_image = indexnormstack(red_band, green_band, blue_band, nir_band)

# 结果形状: (height, width, 3)
# 通道0: WDRVI (归一化到0-255)
# 通道1: NDVI (归一化到0-255)
# 通道2: SI (归一化到0-255)

我们也可以堆叠感兴趣的特定波段,并用这些数据来训练模型。


def bandstack(red, green, blue, nir):

    stack = np.dstack((red, green, blue))

    return stack

以下是(可选的)光学合成色彩校正。我们通常在训练过程中准备数据增强,这些增强可以改变亮度值并在 tensorflow 数据管道中创建其他合成数据。但提前保存已应用色彩校正的图像输入会很有帮助,这样在训练过程中更容易可视化、比较图像与标签和预测。


# 函数:调整图像亮度
def change_brightness(img, value=30):
    """
    通过HSV颜色空间调整图像亮度
    
    参数:
    img: numpy数组 - 输入的BGR格式图像
    value: int - 亮度调整值,正值增加亮度,负值减少亮度,默认30
    
    返回值:
    img: numpy数组 - 亮度调整后的BGR格式图像
    """
    # 1. 将图像从BGR颜色空间转换到HSV颜色空间
    # HSV (Hue色调, Saturation饱和度, Value明度) 颜色模型更符合人类对颜色的感知
    # Value通道直接对应亮度,便于单独调整
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    
    # 2. 分离HSV的三个通道
    # h: 色调通道 (0-180,OpenCV中Hue范围是0-180而不是0-360)
    # s: 饱和度通道 (0-255)
    # v: 明度/亮度通道 (0-255)
    h, s, v = cv2.split(hsv)
    
    # 3. 调整亮度通道
    # 使用cv2.add函数将亮度值增加指定的value
    # 如果value为正,图像变亮;如果value为负,图像变暗
    v = cv2.add(v, value)
    
    # 4. 处理亮度值溢出问题
    # 确保亮度值在有效的0-255范围内
    v[v > 255] = 255  # 超过255的值截断为255(防止过曝)
    v[v < 0] = 0      # 低于0的值截断为0(防止过暗)
    
    # 5. 重新合并调整后的HSV通道
    final_hsv = cv2.merge((h, s, v))
    
    # 6. 将HSV图像转换回BGR颜色空间
    img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
    
    return img

函数功能详解

核心原理

使用HSV颜色空间而不是RGB/BGR来调整亮度,因为:

解耦性:HSV将颜色信息(Hue)和亮度信息(Value)分离直观性:亮度调整不会改变图像的颜色色调计算效率:只需操作一个通道(Value通道)

HSV颜色空间说明


H (Hue): 色调,表示颜色类型(0°红色,120°绿色,240°蓝色)
S (Saturation): 饱和度,表示颜色纯度(0%灰色,100%纯色)
V (Value): 明度/亮度,表示颜色明亮程度(0%黑色,100%白色)

参数说明

参数类型默认值说明
imgnumpy.ndarray-输入图像,应为BGR格式(OpenCV默认)
valueint30亮度调整值:
• 正值:增加亮度
• 负值:减少亮度
• 范围:-255到255

工作流程


输入BGR图像
    ↓
转换为HSV颜色空间
    ↓
分离H、S、V三个通道
    ↓
对V通道进行亮度调整
    ↓
处理溢出值(确保在0-255范围内)
    ↓
重新合并HSV通道
    ↓
转换回BGR颜色空间
    ↓
返回调整后的图像

应用场景

1. 数据增强


# 在机器学习中创建亮度变化的数据增强
augmented_images = []
for brightness_shift in [-50, -25, 0, 25, 50]:
    bright_img = change_brightness(original_img, brightness_shift)
    augmented_images.append(bright_img)

2. 图像预处理


# 统一调整数据集中图像的亮度
normalized_img = change_brightness(dark_img, 40)  # 提高暗图像亮度

3. 视觉效果调整


# 根据显示需求调整图像亮度
bright_for_display = change_brightness(image, 20)  # 轻微提亮

使用示例


import cv2

# 读取图像
image = cv2.imread('input.jpg')

# 增加亮度
brighter_image = change_brightness(image, 50)  # 增加50单位亮度

# 降低亮度
darker_image = change_brightness(image, -30)  # 减少30单位亮度

# 显示结果
cv2.imshow('Original', image)
cv2.imshow('Brighter', brighter_image)
cv2.imshow('Darker', darker_image)
cv2.waitKey(0)
cv2.destroyAllWindows()

注意事项

输入格式:图像应为BGR格式(OpenCV默认),不是RGB数据类型:图像应为uint8类型(0-255整数)亮度范围: 原始值范围:0-255调整后自动截断到有效范围过度调整可能导致信息丢失 性能考虑:对于大图像,此操作可能较慢,考虑批量处理时优化

高级应用

自适应亮度调整


def adaptive_brightness_adjustment(img):
    """根据图像平均亮度自动调整"""
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    v = hsv[:, :, 2]
    
    # 计算平均亮度
    avg_brightness = np.mean(v)
    
    # 根据平均亮度决定调整方向
    if avg_brightness < 100:  # 图像较暗
        value = 50
    elif avg_brightness > 200:  # 图像过亮
        value = -30
    else:  # 亮度适中
        value = 0
    
    return change_brightness(img, value)

与光谱指数结合


# 在遥感图像处理中,可以先计算光谱指数,再调整亮度显示
index_image = indexnormstack(red, green, blue, nir)  # 之前定义的函数
bright_index_image = change_brightness(index_image, 30)  # 提亮显示

这个函数在计算机视觉和遥感图像处理中非常实用,特别是在需要标准化图像亮度或进行数据增强的场景中。
如果你正在将矢量文件(例如 GeoJSON 或 Shapefile)中的标签栅格化。

下面我们将矢量化标签栅格化。这是图像分割中的一个常见步骤,因为我们需要将代表标签的像素图像传递给 Tensorflow 进行训练,以便计算损失函数。

我们将标签形状文件读入一个 geopandas 数据框,检查无效的几何形状并将其设置为本地坐标参考系统。然后,我们使用其对应的灰度波段图像的元数据来栅格化标记的多边形。

在这个函数中,当使用两个矢量文件进行标记时,会使用 geo_1 。后者优先于前者,因为它在相交时会覆盖。这段代码在历史上使用过,当时不同的数据集对同一位置和日期的标签存在冲突,我们选择了一个标签进行演示。


def label(geos, labels_src):
    """
    将矢量标签数据(GeoJSON)栅格化为与影像匹配的标签图像
    
    参数:
    geos: list - GeoJSON文件路径列表(通常包含训练/验证标签)
    labels_src: rasterio数据集 - 参考影像数据源,用于获取栅格化参数
    
    返回值:
    labels: numpy数组 - 栅格化的标签图像,像素值为类别ID
    """
    
    # 1. 读取第一个GeoJSON文件(主要标签源)
    geo_0 = gpd.read_file(geos[0])
    
    # 检查并移除无效的几何图形(避免栅格化错误)
    geo_0 = geo_0.loc[geo_0.is_valid]
    
    # 将标签数据重新投影到目标坐标系(与影像一致)
    geo_0 = geo_0.to_crs(crs={'init': target_crs})
    
    # 将类别标识列(crop_id)转换为整数类型
    geo_0['landcover_int'] = geo_0.crop_id.astype(int)
    
    # 创建(几何图形, 类别值)对的生成器,用于栅格化
    shapes_0 = ((geom, value) for geom, value in zip(geo_0.geometry, geo_0.landcover_int))
    
    # 2. 处理可能存在的第二个标签源(例如验证集标签)
    if len(geos) > 1:
        geo_1 = gpd.read_file(geos[1])
        geo_1 = geo_1.loc[geo_1.is_valid]  # 移除无效几何
        geo_1 = geo_1.to_crs(crs={'init': target_crs})  # 重新投影
        geo_1['landcover_int'] = geo_1.crop_id.astype(int)  # 类型转换
        shapes_1 = ((geom, value) for geom, value in zip(geo_1.geometry, geo_1.landcover_int))
    else:
        print("只有一个矢量标签源。")  # 只有单个标签文件时输出提示
    
    # 3. 获取参考影像的元数据(用于栅格化)
    labels_src_prf = labels_src.profile  # 获取影像的完整元数据配置
    
    # 4. 栅格化第一个标签源
    # features.rasterize函数将矢量几何图形转换为栅格
    labels = features.rasterize(
        shapes=shapes_0,  # 几何图形和值对的生成器
        out_shape=(labels_src_prf['height'], labels_src_prf['width']),  # 输出形状(与影像相同)
        fill=0,  # 背景填充值(0通常表示背景类别)
        all_touched=True,  # 如果为True,所有接触到的像素都会被填充(适用于多边形)
        transform=labels_src_prf['transform'],  # 地理变换参数(坐标到像素的映射)
        dtype=labels_src_prf['dtype']  # 输出数据类型(与影像一致)
    )
    
    # 5. 如果有第二个标签源,将其叠加到现有标签上
    if len(geos) > 1:
        # 使用out参数直接更新现有的labels数组
        labels = features.rasterize(
            shapes=shapes_1,
            fill=0,  # 背景填充值
            all_touched=True,  # 所有接触的像素都填充
            out=labels,  # 在现有数组上继续绘制(叠加)
            transform=labels_src_prf['transform']
        )
    else:
        print("只有一个矢量标签源。")  # 只有单个标签文件时输出提示
    
    # 6. 打印标签图像中的唯一值(用于验证)
    print("标签图像中的值: ", np.unique(labels))
    
    return labels

下面是一个单一函数,用于将所有处理过的栅格数据写入文件,以便我们稍后用于训练。

在模型训练之前保存处理过的中间结果通常更高效,而不是实时进行所有图像处理。


def save_images(raster_dir, rgb_norm, stack, index_stack, labels, rgb_src):

    stack_computed = True # change to True if using the stack helper function above

    if stack_computed:
      stack_t = stack.transpose(2,0,1)
    else:
      stack_t = stack

    stack_out=rasterio.open(os.path.join(raster_dir,'stack.tif'), 'w', driver='Gtiff',
                              width=rgb_src.width, height=rgb_src.height,
                              count=3,
                              crs=rgb_src.crs,
                              transform=rgb_src.transform,
                              dtype='uint8')

    stack_out.write(stack_t)

    indices_computed = True # change to True if using the index helper function above
    if indices_computed:
      index_stack_t = index_stack.transpose(2,0,1)
    else:
      index_stack_t = index_stack

    index_stack_out=rasterio.open(os.path.join(raster_dir,'index_stack.tif'), 'w', driver='Gtiff',
                              width=rgb_src.width, height=rgb_src.height,
                              count=3,
                              crs=rgb_src.crs,
                              transform=rgb_src.transform,
                              dtype='uint8')

    index_stack_out.write(index_stack_t)
    #index_stack_out.close()

    labels = labels.astype(np.uint8)
    labels_out=rasterio.open(os.path.join(raster_dir,'labels.tif'), 'w', driver='Gtiff',
                              width=rgb_src.width, height=rgb_src.height,
                              count=1,
                              crs=rgb_src.crs,
                              transform=rgb_src.transform,
                              dtype='uint8')

    labels_out.write(labels, 1)
    #labels_out.close()

    print("written")

    return os.path.join(raster_dir,'stack.tif'), os.path.join(raster_dir,'index_stack.tif'), os.path.join(raster_dir,'labels.tif')

现在让我们将光学/索引堆栈和标记图像分割成 224x224 像素的瓦片。


def tile(index_stack, labels, prefix, width, height, raster_dir, output_dir, brighten=False):
    tiles_dir = os.path.join(output_dir,'tiled/')
    img_dir = os.path.join(output_dir,'tiled/stacks_brightened/')
    label_dir = os.path.join(output_dir,'tiled/labels/')
    dirs = [tiles_dir, img_dir, label_dir]
    for d in dirs:
        if not os.path.exists(d):
            os.makedirs(d)

    def get_tiles(ds):
        # get number of rows and columns (pixels) in the entire input image
        nols, nrows = ds.meta['width'], ds.meta['height']
        # get the grid from which tiles will be made 
        offsets = product(range(0, nols, width), range(0, nrows, height))
        # get the window of the entire input image
        big_window = windows.Window(col_off=0, row_off=0, width=nols, height=nrows)
        # tile the big window by mini-windows per grid cell
        for col_off, row_off in offsets:
            window = windows.Window(col_off=col_off, row_off=row_off, width=width, height=height).intersection(big_window)
            transform = windows.transform(window, ds.transform)
            yield window, transform

    tile_width, tile_height = width, height

    def crop(inpath, outpath, c):
        # read input image
        image = rasterio.open(inpath)
        # get the metadata 
        meta = image.meta.copy()
        print("meta: ", meta)
        # set the number of channels to 3 or 1, depending on if its the index image or labels image
        meta['count'] = int(c)
        # set the tile output file format to PNG (saves spatial metadata unlike JPG)
        meta['driver']='PNG'
        meta['dtype']='uint8'
        # tile the input image by the mini-windows
        i = 0
        for window, transform in get_tiles(image):
            meta['transform'] = transform
            meta['width'], meta['height'] = window.width, window.height
            outfile = os.path.join(outpath,"tile_%s_%s.png" % (prefix, str(i)))
            with rasterio.open(outfile, 'w', **meta) as outds:
                if brighten:
                  imw = image.read(window=window)
                  imw = imw.transpose(1,2,0)
                  imwb = change_brightness(imw, value=50)
                  imwb = imwb.transpose(2,0,1)
                  outds.write(imwb)
                else:
                  outds.write(image.read(window=window))
            i = i+1

    def process_tiles(index_flag):
        # tile the input images, when index_flag == True, we are tiling the spectral index image, 
        # when False we are tiling the labels image
        if index_flag==True:
            inpath = os.path.join(raster_dir,'stack.tif')
            outpath=img_dir
            crop(inpath, outpath, 3)
        else:
            inpath = os.path.join(raster_dir,'labels.tif')
            outpath=label_dir
            crop(inpath, outpath, 1)

    process_tiles(index_flag=True) # tile stack
    process_tiles(index_flag=False) # tile labels
    return tiles_dir, img_dir, label_dir

train_images_dir = 'dlr_fusion_competition_germany_train_source_planet_5day'
%cd $train_images_dir 

如果你想要将文件写入到你的个人驱动器,请设置 write_out = True,但我们建议你在空闲时间尝试,因为使用 Google Colab + Google Drive 进行存储时,所有复合体需要大约 2 小时或更长时间。


write_out = False
if write_out:
  raster_out_dir = os.path.join(user_outputs_dir,'rasters/')
  if not os.path.exists(raster_out_dir):
    os.makedirs(raster_out_dir)
  for train_image_dir in train_images_dirs: #[0:1]:
    # read the rasters and scale to 8bit
    print("reading and scaling rasters...")
    raster_dir, rgbn, rgbn_src, target_crs = raster_read(os.path.join(train_images_dir,train_image_dir))

    # Calculate indices and combine the indices into one single 3 channel image
    print("calculating spectral indices...")
    index_stack = indexnormstack(rgbn.read(3), rgbn.read(2), rgbn.read(1), rgbn.read(4))

    # Stack channels of interest (RGB) into one single 3 channel image
    print("Stacking channels of interest...")
    stack = bandstack(rgbn.read(3), rgbn.read(2), rgbn.read(1), rgbn.read(4))

    # Color correct the RGB image
    print("Color correcting a RGB image...")
    cc_stack = change_brightness(stack)

    # Rasterize labels
    labels = label([os.path.join(processed_outputs_dir,'dlr_fusion_competition_germany_train_labels/dlr_fusion_competition_germany_train_labels_33N_18E_242N/labels.geojson')], rgbn_src)

    # Save index stack and labels to geotiff
    print("writing scaled rasters and labels to file...")
    stack_file, index_stack_file, labels_file = save_images(raster_dir, rgbn, cc_stack, index_stack, labels, rgbn_src)

    # Tile images into 224x224
    print("tiling the indices and labels...")
    tiles_dir, img_dir, label_dir = tile(stack, labels, str(train_image_dir), 224, 224, raster_dir, raster_out_dir, brighten=False)
else:
  print("Not writing to file; using preprocessed dataset in shared drive.")

将数据读入内存

准备数据

我们将使用 tf-eo-devseed-processed-outputs/ 文件夹中的以下文件夹和文件:

tf-eo-devseed-processed-outputs/
├── stacks/
├── stacks_brightened/
├── indices/
├── labels/
├── background_list_train.txt
├── train_list_clean.txt
└── lulc_classes.csv

获取用于训练和测试的图像和标签瓦片对列表。


def get_train_test_lists(imdir, lbldir):
    """
    从指定目录获取训练/测试图像和标签文件列表
    
    参数:
    imdir: str - 图像文件所在目录路径
    lbldir: str - 标签文件所在目录路径
    
    返回值:
    dset_list: list - 数据集ID列表(不包含扩展名的文件名)
    x_filenames: list - 图像文件完整路径列表
    y_filenames: list - 标签文件完整路径列表
    """
    
    # 1. 获取图像目录中所有PNG格式的文件路径
    imgs = glob.glob(os.path.join(imdir, "*.png"))
    # print(imgs[0:1])  # 调试用:打印第一个图像路径(被注释掉了)
    
    # 2. 提取所有图像的基础文件名(不包含扩展名和路径)
    dset_list = []  # 存储数据集ID(文件名)
    for img in imgs:
        # 分离文件名和扩展名:例如 "image_001.png" → ("image_001", ".png")
        filename_split = os.path.splitext(img)
        filename_zero, fileext = filename_split
        
        # 获取文件名(不包含路径和扩展名)
        basename = os.path.basename(filename_zero)
        dset_list.append(basename)
    
    # 3. 构建图像和标签文件的完整路径列表
    x_filenames = []  # 图像文件路径列表
    y_filenames = []  # 标签文件路径列表
    
    for img_id in dset_list:
        # 图像文件路径:imdir/img_id.png
        x_filenames.append(os.path.join(imdir, "{}.png".format(img_id)))
        # 标签文件路径:lbldir/img_id.png
        y_filenames.append(os.path.join(lbldir, "{}.png".format(img_id)))
    
    # 4. 打印数据集大小信息
    print("图像数量: ", len(dset_list))
    
    return dset_list, x_filenames, y_filenames


# 使用函数获取训练数据集列表
train_list, x_train_filenames, y_train_filenames = get_train_test_lists(img_dir, label_dir)

在用卫星图像进行训练时,我们感兴趣的目标检测可能非常稀疏,这导致很多图像中没有任何感兴趣类别。在计算机视觉中,我们将这些图像称为背景图像。了解背景图像和非背景图像的数量及比例,以便控制类别不平衡,这很有帮助。更理想的是了解背景和非背景图像的地理分布,并通过采样策略进行控制。

这里我们检查背景瓦片的比例。这需要一段时间。所以运行一次后,可以通过加载保存的结果来跳过。


# 设置一个跳过标志,用于控制是否重新计算背景图像列表
skip = False  # 当skip=False时重新计算,skip=True时从文件加载

if not skip:  # 如果skip=False,执行计算过程
    background_list_train = []  # 初始化一个空列表,用于存储纯背景图像的ID
    
    # 遍历训练集列表中的所有图像ID
    for i in train_list: 
        # 读取对应的标签图像(注释掉的调试行:打印文件路径)
        # print(os.path.join(label_dir,"{}.png".format(i))) 
        
        # 使用PIL打开标签图像并转换为numpy数组
        img = np.array(Image.open(os.path.join(label_dir, "{}.png".format(i))))
        
        # 检查图像中是否有大于0的值(即是否有非背景的标签)
        # 由于背景类别通常用0表示,如果图像最大值是0,说明整个图像都是背景
        if img.max() == 0:
            background_list_train.append(i)  # 将纯背景图像的ID添加到列表中

    # 打印纯背景图像的数量
    print("纯背景图像数量: ", len(background_list_train))

    # 将背景图像列表保存到文件中,供以后使用
    with open(os.path.join(processed_outputs_dir, 'background_list_train.txt'), 'w') as f:
        for item in background_list_train:
            f.write("%s
" % item)  # 每个图像ID单独一行

else:  # 如果skip=True,跳过计算过程,直接从文件加载
    # 从保存的文件中读取背景图像列表
    background_list_train = [line.strip() for line in open("background_list_train.txt", 'r')]
    print("纯背景图像数量: ", len(background_list_train))

我们将只保留总量的 10%。过多的背景瓦片会导致一种形式的类别不平衡,因为背景类别可能包含很多不同的现象和误报,并且非常过度代表。


background_removal = len(background_list_train) * 0.9
train_list_clean = [y for y in train_list if y not in background_list_train[0:int(background_removal)]]

x_train_filenames = []
y_train_filenames = []

for i, img_id in zip(tqdm(range(len(train_list_clean))), train_list_clean):
  pass 
  x_train_filenames.append(os.path.join(img_dir, "{}.png".format(img_id)))
  y_train_filenames.append(os.path.join(label_dir, "{}.png".format(img_id)))

print("Number of background tiles: ", background_removal)
print("Remaining number of tiles after 90% background removal: ", len(train_list_clean))

现在我们已经有了用于开发模型的文件集,需要将它们分成三个集合:

模型用于学习的训练集

允许我们评估模型并决定是否更改模型的验证集

以及我们将用于展示最佳模型结果(由验证集确定)的测试集

我们将索引瓦片和标签瓦片分成训练、验证和测试集:分别为 70%、20%和 10%。


x_train_filenames, x_val_filenames, y_train_filenames, y_val_filenames = train_test_split(x_train_filenames, y_train_filenames, test_size=0.3, random_state=42)
x_val_filenames, x_test_filenames, y_val_filenames, y_test_filenames = train_test_split(x_val_filenames, y_val_filenames, test_size=0.33, random_state=42)

num_train_examples = len(x_train_filenames)
num_val_examples = len(x_val_filenames)
num_test_examples = len(x_test_filenames)

print("Number of training examples: {}".format(num_train_examples))
print("Number of validation examples: {}".format(num_val_examples))
print("Number of test examples: {}".format(num_test_examples))

vals_train = []
vals_val = []
vals_test = []

def get_vals_in_partition(partition_list, x_filenames, y_filenames):
  for x,y,i in zip(x_filenames, y_filenames, tqdm(range(len(y_filenames)))):
      pass 
      try:
        img = np.array(Image.open(y)) 
        vals = np.unique(img)
        partition_list.append(vals)
      except:
        continue

def flatten(partition_list):
    return [item for sublist in partition_list for item in sublist]

get_vals_in_partition(vals_train, x_train_filenames, y_train_filenames)
get_vals_in_partition(vals_val, x_val_filenames, y_val_filenames)
get_vals_in_partition(vals_test, x_test_filenames, y_test_filenames)
print("Values in training partition: ", set(flatten(vals_train)))
print("Values in validation partition: ", set(flatten(vals_val)))
print("Values in test partition: ", set(flatten(vals_test)))

可视化数据

设置要显示的样本数量

display_num = 3

从文件加载纯背景图像列表

background_list_train = [line.strip() for line in open(“background_list_train.txt”, ‘r’)]

1. 筛选出包含前景标签的图像(即非纯背景图像)

foreground_list_x = [] # 存储前景图像的路径
foreground_list_y = [] # 存储前景标签的路径

遍历所有训练图像和标签对

for x, y in zip(x_train_filenames, y_train_filenames):
try:
# 从标签文件路径中提取基本文件名(不含扩展名)
filename_split = os.path.splitext(y)
filename_zero, fileext = filename_split
basename = os.path.basename(filename_zero)


    # 如果该图像不在背景列表中(即包含前景标签)
    if basename not in background_list_train:
        foreground_list_x.append(x)  # 添加图像路径
        foreground_list_y.append(y)  # 添加标签路径
    else:
        continue  # 如果是纯背景图像,跳过
except:
    continue  # 如果出现异常,跳过该样本

计算前景样本的数量

num_foreground_examples = len(foreground_list_y)

2. 随机选择要显示的前景样本

r_choices = np.random.choice(num_foreground_examples, display_num)

3. 创建可视化图表

plt.figure(figsize=(10, 15))

for i in range(0, display_num * 2, 2): # 每次循环处理一个样本对(图像+标签)
# 计算当前要显示的样本索引
img_num = r_choices[i // 2] # 随机选择的索引
img_num = i // 2 # 注意:这里覆盖了随机选择,实际是按顺序显示前display_num个


# 获取图像和标签的路径
x_pathname = foreground_list_x[img_num]  # 原始图像路径
y_pathname = foreground_list_y[img_num]  # 标签图像路径

# 显示原始图像(左列)
plt.subplot(display_num, 2, i + 1)  # 创建子图:行数=display_num,列数=2,当前位置=i+1
plt.imshow(mpimg.imread(x_pathname))  # 使用matplotlib读取和显示图像
plt.title("原始图像")  # 设置子图标题

# 读取标签图像
example_labels = Image.open(y_pathname)

# 获取标签中的唯一值(用于了解有哪些类别)
label_vals = np.unique(np.array(example_labels))

# 显示标签图像(右列)
plt.subplot(display_num, 2, i + 2)  # 创建子图:当前位置=i+2
plt.imshow(example_labels)  # 显示标签图像
plt.title("标签图像")  # 设置子图标题

添加整个图的标题

plt.suptitle(“图像及其掩码的示例”)
plt.show() # 显示图像

  • 全部评论(0)
最新发布的资讯信息
【系统环境|】Linux 安全审计工具 Auditd(2025-12-08 23:24)
【系统环境|】使用Supervisor守护PHP进程:告别手动重启,实现自动化运维(2025-12-08 23:24)
【系统环境|】golang高性能日志库zap的使用(2025-12-08 23:24)
【系统环境|】MySQL主从复制技术详解(2025-12-08 23:24)
【系统环境|】华为MagicBook锐龙版双系统折腾记六:matlab(2025-12-08 23:24)
【系统环境|】ArrayFire:C++高性能张量计算的极速引擎(2025-12-08 23:24)
【系统环境|】一文读懂回声消除(AEC)(2025-12-08 23:23)
【系统环境|】缺人!泰达这些企业招聘!抓紧!(2025-12-08 23:23)
【系统环境|】RS485 Modbus 超级简单轮询程序(2025-12-08 23:23)
【系统环境|】RS485接口≠Modbus协议!工业通信常见认知陷阱(2025-12-08 23:23)
手机二维码手机访问领取大礼包
返回顶部