Windows10 + pytorch torchvision.transformsの実験

Windows10 + pytorch torchvision.transformsの実験

torchvision.transformsで画像処理

pytorchの仲間であるtorchvisonは画像に関する便利なライブラリ。
その中のtorchvision.transformsは様々な画像処理が行える。
pytorchを使わずとも独立して使える。



1. 画像準備

以前も使用したサイトでテスト画像を取得。



2. 環境構築

専用の環境を作る。

1
2
conda create -n torch python=3.7
conda activate torch

https://pytorch.org/get-started/locally/

公式に従ってpytorchを入れる。

1
conda install pytorch torchvision cpuonly -c pytorch

画像表示用にmatplotlibも入れる。

1
conda install matplotlib

3. 実験

3.1. torchvision.transformsの使い方

1
transform = torchvision.transforms.〇〇()

で画像変換用のインスタンスができる。

1
img = transform(img)

とするとimg(PIL imageを渡す)が変換される。

1
2
3
4
5
6
7
8
9
10
transform = torchvision.transforms.Compose([
torchvision.transforms.〇〇(),
torchvision.transforms.〇〇(),



torchvision.transforms.〇〇(),
])

img = transform(img)

とすると複数の変換処理を順番に実行する。

3.2. 実験結果

メジャーな画像処理4つと、それら全部を組み合わせたものを実験。

  • 1.回転
  • 2.輝度・コントラスト・彩度・色相変化
  • 3.パース変化
  • 4.部分消去
  • 5.全部

の、5パターン×3回の結果。



4. コード

実験に使用したコード。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torchvision

###### transformの準備 ######
# 1.回転
transform_1 = torchvision.transforms.RandomRotation(degrees=30)
# 2.輝度・コントラスト・彩度・色相
transform_2 = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
# 3.パース
transform_3 = torchvision.transforms.RandomPerspective(p=1.0)
# 4.部分消去(データ形式がTensorじゃないと使えないため、一度変換してから戻している)
transform_4 = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.RandomErasing(p=1.0, value="random"),
torchvision.transforms.ToPILImage(),
])
# 5.全部
transform_5 = torchvision.transforms.Compose([
torchvision.transforms.RandomRotation(degrees=30),
torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
torchvision.transforms.RandomPerspective(p=1.0),
torchvision.transforms.ToTensor(),
torchvision.transforms.RandomErasing(p=1.0, value="random"),
torchvision.transforms.ToPILImage(),
])

###### 画像の準備 ######
img_path = "input.jpg"
img = Image.open(img_path)

###### 画像の変換 ######
img1 = transform_1(img)
img2 = transform_2(img)
img3 = transform_3(img)
img4 = transform_4(img)
img5 = transform_5(img)

###### 画像表示用のグラフ作成 ######
plt.subplot(1, 6, 1)
plt.imshow(np.array(img))
plt.subplot(1, 6, 2)
plt.imshow(np.array(img1))
plt.subplot(1, 6, 3)
plt.imshow(np.array(img2))
plt.subplot(1, 6, 4)
plt.imshow(np.array(img3))
plt.subplot(1, 6, 5)
plt.imshow(np.array(img4))
plt.subplot(1, 6, 6)
plt.imshow(np.array(img5))

# グラフの線を消す(やらないとなんか汚くなる)
for i in range(6):
plt.subplot(1, 6, i+1)
plt.axis('off')

###### 画像の表示・保存 ######
plt.savefig('output.png')
plt.show()

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×