4 분 소요

1. Pruning

Example Image

  • Network에서 중요한 Prameter는 살리고, 그렇지 않은 Parameter는 Connection을 끊어버리는 방식.
    • 즉 영향력이 낮은 Weight를 제거하여, 연산량은 줄이고 성능은 유지하는 방식.
  • Parameter가 줄어들기에, Train 및 Inference에 소요되는 Computational Cost가 줄어듦.
  • Regularization(complexity가 줄어들기에)
  • 성능을 유지한 채, Weight Matrix를 sparse하게 만드는 것이 목표.

2. Pruning VS Drop-Out

  • Connection을 끊는 의미에서, Drop - out과 유사하지만 차이점 존재.
    1. Drop - Out
      • 학습 과정에서, 확률을 기준으로 특정 Connection Drop
      • 이는, Realization이기에 Inference 시에는 모든 Parameter가 전부 사용됨.
    2. Pruning
      • 실제 Drop된 Connection은 Train과 Inference에서 모두 Drop
      • 그렇기에, 연산량 감소에 효과적.

3. Sparsity

  • sparse하다는 것은, 희소하게 만드는 것.
    • Weight Matrix 내 0의 비율.
  • Sparsity를 어떻게 측정할 수 있지?

\[ \text{Sparsity} = 1 - \frac{\text{Non-zero parameters}}{\text{Total parameters}} \]

  • 그럼 제거하는 기준은?
    1. L1 Norm(Magnitude)
      • Weight 중 Absolute value가 작은 Weight를 제거하는 방식. \[ L_1 = \sum_{i} |w_i| \]
    2. L2 Norm
      • Weight 중 Squared value가 작은 Weight를 제거하는 방식. \[ L_2 = \sum_{i} w_i^2 \]

4. Unstructured VS Structured

  • 제거 되는 범위에 관한 문제.
    1. Unstructured
      • Global Pruning
      • 즉, 전체 Weight Matrix를 쭉 Flatten해서, 전역적으로 Pruning을 진행.

        Amount = 0.3, Unstructured pruning

        전체 Weight Matrix에서 0.3만큼 제거.

    2. Structured
      • Group(Local) Pruning
      • 특정 Layer, Filter, 뉴런 등의 단위로 Pruning

        Amount = 0.3, Structured pruning

        Input, Hidden, Ouput 별로 0.3만큼 제거.

5. code

5.1 Model Define

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1개 채널 수의 이미지를 입력값으로 이용하여 6개 채널 수의 출력값을 계산하는 방식
        # Convolution 연산을 진행하는 커널(필터)의 크기는 5x5 을 이용
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # Convolution 연산 결과 5x5 크기의 16 채널 수의 이미지
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

5.2 Define Pruning Params

parameters_to_prune = ( # Pruning 할 범위 선택
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight')
)

prune.global_unstructured( # Global Pruning
    parameters_to_prune,
    pruning_method=prune.L1Unstructured, # Pruning 기준은 L1 Norm
    amount=0.2, # Sparsity = 0.2
)

5.3 Result

total_params = 0
pruned_params = 0

for name, module in model.named_modules():
    if hasattr(module, 'weight') and hasattr(module, 'weight_mask'):
        total_params += module.weight.nelement()
        pruned_params += (module.weight_mask == 0).sum().item()

print(f"Total Parameters: {total_params}")
print(f"Pruned Parameters: {pruned_params}")
print(f"Non-Zero Parameters: {total_params - pruned_params}")
print(f"Sparsity : {round(1-((total_params - pruned_params)/total_params),1)}")

5.4 Pruning rate by layer

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)

6. Limitations of Sparsity Pruning

Example Image

  • Sparsity Pruning은 Weight를 0으로 만들기에, 연산량 감소에는 효과적.
  • 그러나, Parameter를 제거하는 것은 아니기에, Size와 관련있는 Parameter의 개수를 줄일 수는 X
  • 실제 Parameter의 개수를 줄일 수 있는 방안들이 제시됨.
    • Ex. Channel Pruning
      • Convolution 기반 Network의 Channel을 줄여 연산량을 감소시키는 방식.

7. Torch-Pruning(framework for structural pruning)

  • Torch-Pruning은 CVPR’2023에 제시된 DepGraph를 기반으로 Pruning을 제공하는 structural pruning framework
  • MetaPruner, Channel 등 다양한 hi-level Pruner를 제공
  • 해당 Framework를 사용해 Resnet의 Parameter의 개수를 줄여보자.

7.1 Load DataSets(CIFAR-10)

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)

for images, labels in testloader:
    print(images.shape)
    print(labels.shape)
    break
    
example_inputs = images[0:1].to(device)

7.2 Load Resnet-18

import torch
import torchvision.models as models
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

resnet18 = models.resnet18(pretrained=True).to(device = device)

summary(resnet18, (3, 224, 224))


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f"Total trainable parameters: {count_parameters(resnet18)}")

7.3 Define Torch-Pruning Params

# !pip install torch-pruning
import torch_pruning as tp

# 1. 중요도 기준 정의. p=1 -> L1 Norm, p=2 -> L2 Norm
imp = tp.importance.GroupNormImportance(p=2) 

ignored_layers = []
for m in resnet18.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000: # Dense Layer의 경우 Pruning 시 output 출력에 영향을 끼치므로 배제.
        ignored_layers.append(m)

pruner = tp.pruner.MetaPruner(
    resnet18,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, # Channel Pruning 비율. 즉 0.5의 경우 256 Channel은 128로 변환
    ignored_layers=ignored_layers,
    round_to=8, # 채널의 배수를 조정 즉, 4의 경우 4의 배수로, 8의 경우 8의 배수로 / 이를 통해 GPU 연산 최적화.
)

7.4 Pruning

base_macs, base_nparams = tp.utils.count_ops_and_params(resnet18, example_inputs)
pruner.step()
macs, nparams = tp.utils.count_ops_and_params(resnet18, example_inputs)
print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")

summary(resnet18, (3, 224, 224))

7.5 Tabel

Model MACs(G) Params(M) Size(mb)
resnet-18 0.0102 3.0559 107.9600
pruning resnet-18 0.0029 0.8311 19.4500

댓글남기기