torchvision
datasets
torchvision.datasets 包含了许多标准数据集的加载器。例如,CIFAR10 和 ImageFolder 是其中两个非常常用的类。
CIFAR10
CIFAR10 数据集是一个广泛使用的数据集,包含10类彩色图像,每类有6000张图像(5000张训练集,1000张测试集)。下面是如何加载 CIFAR10 的示例:
from torchvision import datasets, transforms
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = datasets.CIFAR10(root=./data, train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = datasets.CIFAR10(root=./data, train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
classes = (
plane,
car,
bird,
cat,
deer,
dog,
frog,
horse,
shIP,
truck)
ImageFolder
ImageFolder 用于加载按照类别分文件夹存储的图像数据集。
from torchvision import datasets, transforms
data_dir = ./path/to/dataset
transform = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
image_datasets = datasets.ImageFolder(data_dir, transform=transform)
dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=4, shuffle=True, num_workers=2)
models
torchvision.models 提供了一系列预训练模型,如 ResNet、VGG、InceptionV3 等。
ResNet模型:
SetsNet并不是torchvision中的一个组件,而是指一类处理集合数据的神经网络。SetsNet和其他类似的网络(如DeepSets)旨在处理无序的集合输入,这些输入可以是点云、图像集合、特征向量集合等。SetsNet的设计原则是输入集合的顺序不会影响输出,即网络应该对输入的排列不变。
import torchvision.models as models
model = models.resnet50(pretrained=True)
preprocess = transforms.Compose([
transforms.CenterCrop(224),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
img_path = ./path/to/image.jpg
img = Image.open(img_path)
img_tensor = preprocess(img)
batch_img_tensor = torch.unsqueeze(img_tensor, 0)
out = model(batch_img_tensor)
VGG模型:
VGG网络是一种经典的卷积神经网络架构,广泛应用于图像分类。下面是如何加载预训练的VGG模型并在一张图像上进行预测的示例:
from torchvision import models, transforms
vgg16 = models.vgg16(pretrained=True)
preprocess = transforms.Compose([
transforms.CenterCrop(224),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
img_path = ./path/to/image.jpg
img_pil = Image.open(img_path)
img_tensor = preprocess(img_pil)
batch_img_tensor = torch.unsqueeze(img_tensor, 0)
out = vgg16(batch_img_tensor)
_, pred = torch.max(out, 1)
print("Predicted class:", pred.item())
Inception模型:
InceptionV3是一种更复杂的卷积神经网络架构,设计用于处理高分辨率图像。以下是如何加载预训练的InceptionV3模型并进行预测:
from torchvision import models, transforms
inceptionv3 = models.inception_v3(pretrained=True)
preprocess = transforms.Compose([
transforms.CenterCrop(299),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
img_path = ./path/to/image.jpg
img_pil = Image.open(img_path)
img_tensor = preprocess(img_pil)
batch_img_tensor = torch.unsqueeze(img_tensor, 0)
out = inceptionv3(batch_img_tensor)
_, pred = torch.max(out, 1)
print("Predicted class:", pred.item())
utils
make_grid 网格排列
是一个用于在PyTorch中将多个图像张量组合成一个图像网格的函数。这对于可视化数据集、模型输出或者训练过程中的变化非常有用。make_grid接受一系列图像张量,并返回一个单一的张量,该张量包含了所有输入图像按网格排列的结果
import torchvision.utils as vutils
dataiter = iter(dataloaders)
images, labels = dataiter.next()
img_grid = vutils.make_grid(images)
imshow(img_grid.numpy().transpose((1, 2, 0)))
save_image 保存
图像
save_image函数可以用来保存一个张量为图像文件。下面是一个如何保存图像的例子:
from torchvision.utils import save_image
img_tensor = torch.randn(3, 224, 224)
save_image(img_tensor, saved_image.jpg)
img_pil = Image.new(RGB, (224, 224), color=white)
img_tensor = transforms.ToTensor()(img_pil)
save_image(img_tensor, saved_image_from_pil.jpg)
请确保替换上述代码中的./path/to/image.jpg为实际的图像路径,并确保在运行代码之前有正确的权限访问指定的路径。此外,如果还没有安装torchvision和Pillow,可能需要先安装:
p
IP install torchvision pillow
transforms
是PyTorch中一个重要的模块,用于进行图像预处理和数据增强。它位于torchvision.transforms模块中,主要用于处理PIL图像和Tensor图像。transforms可以帮助你在训练神经网络时对数据进行各种变换,例如随机裁剪、大小调整、正则化等,以增加数据的多样性和模型的鲁棒性。
常见的transforms包括:
数据类型转换:
ToTensor(): 将PIL
图像或NumPy数组转换为PyTorch的Tensor格式。
几何变换:
Resize(size): 调整
图像大小。 CenterCrop(size): 中心裁剪
图像。 RandomCrop(size): 随机裁剪
图像。 RandomHorizontalFl
IP(p=0.5): 随机水平翻转
图像。
色彩变换:
ColorJitter(brightness, contrast, saturation, hue): 随机调整
图像的亮度、对比度、饱和度和色调。
正则化:
Normalize(mean, std):
标准化
图像像素值。
使用transforms
通常需要将它们组合成一个transforms.Compose对象,以便按顺序应用到图像数据上。这样可以灵活地定义数据增强的流程,适应不同的任务需求和数据特征。
当使用transforms进行图像预处理和数据增强时,通常需要按照以下步骤进行操作:
1.导入必要的库:
from torchvision import transforms
2.定义transforms操作:可以根据需求选择合适的transforms进行组合。
transform = transforms.Compose([
transforms.Resize((
256,
256)),
# 调整图像大小为256x256
transforms.RandomCrop(
224),
# 随机裁剪图像为224x224
transforms.RandomHorizontalFl
IP(),
# 随机水平翻转图像
transforms.ToTensor(),
# 将图像转换为Tensor,并归一化至[0, 1]
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
3.加载图像并应用transforms:
img = Image.open(image.jpg)
img_transformed = transform(img)
4.查看处理后的图像:处理后的图像会转换为Tensor,并进行了resize、crop、翻转等操作。
print(img_transformed.size())
# 输出处理后的图像大小
在上面的例子中,transforms.Compose用于将多个transforms组合起来,依次应用到图像上。这种方式能够让你根据任务需求定义灵活的图像处理流程,例如在训练神经网络时进行数据增强,提升模型的泛化能力。