[Lightweight-Sereis] 1. Pruning
1. Pruning
- Network에서 중요한 Prameter는 살리고, 그렇지 않은 Parameter는 Connection을 끊어버리는 방식.
- 즉 영향력이 낮은 Weight를 제거하여, 연산량은 줄이고 성능은 유지하는 방식.
- Parameter가 줄어들기에, Train 및 Inference에 소요되는 Computational Cost가 줄어듦.
- Regularization(complexity가 줄어들기에)
- 성능을 유지한 채, Weight Matrix를 sparse하게 만드는 것이 목표.
2. Pruning VS Drop-Out
- Connection을 끊는 의미에서, Drop - out과 유사하지만 차이점 존재.
- Drop - Out
- 학습 과정에서, 확률을 기준으로 특정 Connection Drop
- 이는, Realization이기에 Inference 시에는 모든 Parameter가 전부 사용됨.
- Pruning
- 실제 Drop된 Connection은 Train과 Inference에서 모두 Drop
- 그렇기에, 연산량 감소에 효과적.
- Drop - Out
3. Sparsity
- sparse하다는 것은, 희소하게 만드는 것.
- Weight Matrix 내 0의 비율.
- Sparsity를 어떻게 측정할 수 있지?
\[ \text{Sparsity} = 1 - \frac{\text{Non-zero parameters}}{\text{Total parameters}} \]
- 그럼 제거하는 기준은?
- L1 Norm(Magnitude)
- Weight 중 Absolute value가 작은 Weight를 제거하는 방식. \[ L_1 = \sum_{i} |w_i| \]
- L2 Norm
- Weight 중 Squared value가 작은 Weight를 제거하는 방식. \[ L_2 = \sum_{i} w_i^2 \]
- L1 Norm(Magnitude)
4. Unstructured VS Structured
- 제거 되는 범위에 관한 문제.
- Unstructured
- Global Pruning
- 즉, 전체 Weight Matrix를 쭉 Flatten해서, 전역적으로 Pruning을 진행.
Amount = 0.3, Unstructured pruning
전체 Weight Matrix에서 0.3만큼 제거.
- Structured
- Group(Local) Pruning
- 특정 Layer, Filter, 뉴런 등의 단위로 Pruning
Amount = 0.3, Structured pruning
Input, Hidden, Ouput 별로 0.3만큼 제거.
- Unstructured
5. code
5.1 Model Define
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1개 채널 수의 이미지를 입력값으로 이용하여 6개 채널 수의 출력값을 계산하는 방식
# Convolution 연산을 진행하는 커널(필터)의 크기는 5x5 을 이용
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120) # Convolution 연산 결과 5x5 크기의 16 채널 수의 이미지
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet().to(device=device)
5.2 Define Pruning Params
parameters_to_prune = ( # Pruning 할 범위 선택
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight')
)
prune.global_unstructured( # Global Pruning
parameters_to_prune,
pruning_method=prune.L1Unstructured, # Pruning 기준은 L1 Norm
amount=0.2, # Sparsity = 0.2
)
5.3 Result
total_params = 0
pruned_params = 0
for name, module in model.named_modules():
if hasattr(module, 'weight') and hasattr(module, 'weight_mask'):
total_params += module.weight.nelement()
pruned_params += (module.weight_mask == 0).sum().item()
print(f"Total Parameters: {total_params}")
print(f"Pruned Parameters: {pruned_params}")
print(f"Non-Zero Parameters: {total_params - pruned_params}")
print(f"Sparsity : {round(1-((total_params - pruned_params)/total_params),1)}")
5.4 Pruning rate by layer
print(
"Sparsity in conv1.weight: {:.2f}%".format(
100. * float(torch.sum(model.conv1.weight == 0))
/ float(model.conv1.weight.nelement())
)
)
print(
"Sparsity in conv2.weight: {:.2f}%".format(
100. * float(torch.sum(model.conv2.weight == 0))
/ float(model.conv2.weight.nelement())
)
)
print(
"Sparsity in fc1.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc1.weight == 0))
/ float(model.fc1.weight.nelement())
)
)
print(
"Sparsity in fc2.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc2.weight == 0))
/ float(model.fc2.weight.nelement())
)
)
print(
"Sparsity in fc3.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc3.weight == 0))
/ float(model.fc3.weight.nelement())
)
)
print(
"Global sparsity: {:.2f}%".format(
100. * float(
torch.sum(model.conv1.weight == 0)
+ torch.sum(model.conv2.weight == 0)
+ torch.sum(model.fc1.weight == 0)
+ torch.sum(model.fc2.weight == 0)
+ torch.sum(model.fc3.weight == 0)
)
/ float(
model.conv1.weight.nelement()
+ model.conv2.weight.nelement()
+ model.fc1.weight.nelement()
+ model.fc2.weight.nelement()
+ model.fc3.weight.nelement()
)
)
)
6. Limitations of Sparsity Pruning
- Sparsity Pruning은 Weight를 0으로 만들기에, 연산량 감소에는 효과적.
- 그러나, Parameter를 제거하는 것은 아니기에, Size와 관련있는 Parameter의 개수를 줄일 수는 X
- 실제 Parameter의 개수를 줄일 수 있는 방안들이 제시됨.
- Ex. Channel Pruning
- Convolution 기반 Network의 Channel을 줄여 연산량을 감소시키는 방식.
- Ex. Channel Pruning
7. Torch-Pruning(framework for structural pruning)
- Torch-Pruning은 CVPR’2023에 제시된 DepGraph를 기반으로 Pruning을 제공하는 structural pruning framework
- MetaPruner, Channel 등 다양한 hi-level Pruner를 제공
- 해당 Framework를 사용해 Resnet의 Parameter의 개수를 줄여보자.
7.1 Load DataSets(CIFAR-10)
import torch
from torchvision import datasets, transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)
for images, labels in testloader:
print(images.shape)
print(labels.shape)
break
example_inputs = images[0:1].to(device)
7.2 Load Resnet-18
import torch
import torchvision.models as models
from torchsummary import summary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet18 = models.resnet18(pretrained=True).to(device = device)
summary(resnet18, (3, 224, 224))
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {count_parameters(resnet18)}")
7.3 Define Torch-Pruning Params
# !pip install torch-pruning
import torch_pruning as tp
# 1. 중요도 기준 정의. p=1 -> L1 Norm, p=2 -> L2 Norm
imp = tp.importance.GroupNormImportance(p=2)
ignored_layers = []
for m in resnet18.modules():
if isinstance(m, torch.nn.Linear) and m.out_features == 1000: # Dense Layer의 경우 Pruning 시 output 출력에 영향을 끼치므로 배제.
ignored_layers.append(m)
pruner = tp.pruner.MetaPruner(
resnet18,
example_inputs,
importance=imp,
pruning_ratio=0.5, # Channel Pruning 비율. 즉 0.5의 경우 256 Channel은 128로 변환
ignored_layers=ignored_layers,
round_to=8, # 채널의 배수를 조정 즉, 4의 경우 4의 배수로, 8의 경우 8의 배수로 / 이를 통해 GPU 연산 최적화.
)
7.4 Pruning
base_macs, base_nparams = tp.utils.count_ops_and_params(resnet18, example_inputs)
pruner.step()
macs, nparams = tp.utils.count_ops_and_params(resnet18, example_inputs)
print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")
summary(resnet18, (3, 224, 224))
7.5 Tabel
Model | MACs(G) | Params(M) | Size(mb) |
---|---|---|---|
resnet-18 | 0.0102 | 3.0559 | 107.9600 |
pruning resnet-18 | 0.0029 | 0.8311 | 19.4500 |
댓글남기기