Denoising the MNIST dataset (FashionMNIST in the code below) with Gaussian kernel convolution

2022-08-01

import os
import random

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import advertorch.defenses as defenses

seed = 2014

torch.manual_seed(seed)
np.random.seed(seed)  # NumPy module.
random.seed(seed)     # Python random module.

 

train_dataset = datasets.FashionMNIST(
    './fashionmnist_data/', train=True, download=True,
    transform=transforms.Compose([transforms.ToTensor()]))

train_loader = DataLoader(dataset=train_dataset, batch_size=500, shuffle=True)

test_dataset = datasets.FashionMNIST(
    './fashionmnist_data/', train=False,
    transform=transforms.Compose([transforms.ToTensor()]))

test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=True)

 

num_epochs = 12

 

class Linear_cliassifer(torch.nn.Module):
    def __init__(self):
        super(Linear_cliassifer, self).__init__()
        # Gaussian smoothing front-end: sigma=3, 1 channel (grayscale), 3x3 kernel.
        self.Gs = defenses.GaussianSmoothing2D(3, 1, 3)
        # Linear classifier over the flattened, smoothed 28x28 image.
        self.Line1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = self.Gs(x)                        # denoise / smooth the input image
        x = self.Line1(x.view(-1, 28 * 28))   # flatten and classify into 10 classes
        return x

 

net = Linear_cliassifer()

cost = torch.nn.CrossEntropyLoss()

 

optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

 

for k in range(num_epochs):
    sum_loss = 0.0
    train_correct = 0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)

        loss = cost(outputs, labels)
        loss.backward()
        optimizer.step()

        _, pred = torch.max(outputs.data, 1)
        sum_loss += loss.item()
        train_correct += torch.sum(pred == labels.data).item()

    print('[epoch %d] loss:%.03f' % (k + 1, sum_loss / len(train_loader)))
    print('        correct:%.03f%%' % (100 * train_correct / len(train_dataset)))

    os.makedirs('model', exist_ok=True)
    torch.save(net.state_dict(), 'model/fasion_BL.pt')
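
For reference, the GaussianSmoothing2D layer used above boils down to convolving each image with a fixed, normalized Gaussian kernel. The sketch below is a hand-rolled illustration of that idea, not advertorch's implementation; the sigma=3 and 3x3 values simply mirror the arguments passed to the layer above.

# Illustrative only: a hand-rolled Gaussian smoothing, not advertorch's code.
import torch
import torch.nn.functional as F

def gaussian_kernel_2d(sigma=3.0, kernel_size=3):
    # Normalized 2D Gaussian kernel on a kernel_size x kernel_size grid.
    coords = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2.0
    g = torch.exp(-coords ** 2 / (2 * sigma ** 2))
    kernel = g[:, None] * g[None, :]
    return kernel / kernel.sum()

def gaussian_smooth(x, sigma=3.0, kernel_size=3):
    # x: (N, 1, 28, 28) grayscale batch; padding keeps the 28x28 size.
    kernel = gaussian_kernel_2d(sigma, kernel_size).view(1, 1, kernel_size, kernel_size)
    return F.conv2d(x, kernel, padding=kernel_size // 2)

Note that with sigma=3 spread over only a 3x3 window the weights are nearly uniform, so the layer behaves almost like a 3x3 box blur.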
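
The test_loader defined above is never actually used. Below is a minimal evaluation sketch that runs the trained classifier on noise-corrupted test images, which is where the Gaussian smoothing front-end is meant to help; the noise level (std 0.3) is an arbitrary choice for illustration.

# Evaluation sketch (assumes the training loop above has run); std=0.3 noise is arbitrary.
net.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        noisy = torch.clamp(inputs + 0.3 * torch.randn_like(inputs), 0.0, 1.0)
        outputs = net(noisy)
        pred = outputs.argmax(dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
print('noisy test accuracy: %.3f%%' % (100.0 * correct / total))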
