首页 > 其他 > 详细

代码训练，图像分类通用测试代码

时间:2021-06-16 00:06:40      阅读:24      评论:0      收藏:0      [点我收藏+]

图像分类通用测试代码

def _train_one_epoch(net, train_loader, loss_function, optimizer, device):
    """Run one training pass over ``train_loader`` and return the mean batch loss.

    Prints an in-place text progress bar followed by the current batch loss.
    Returns 0.0 if the loader yields no batches.
    """
    net.train()
    running_loss = 0.0
    num_batches = len(train_loader)
    for step, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        running_loss += batch_loss
        rate = (step + 1) / num_batches
        done = "*" * int(rate * 50)
        todo = "." * int((1 - rate) * 50)
        # "{:.3f}": a format spec of "{:.f}" (no precision) raises ValueError.
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(
            int(rate * 100), done, todo, batch_loss), end="")
    print()
    # Average over the number of batches, not the last zero-based step index
    # (which under-counts by one and is 0 when there is a single batch).
    return running_loss / num_batches if num_batches else 0.0


def _evaluate(net, validate_loader, device):
    """Return how many samples in ``validate_loader`` ``net`` classifies correctly."""
    net.eval()
    correct = 0
    with torch.no_grad():
        for val_images, val_labels in validate_loader:
            outputs = net(val_images.to(device))
            # Index of the highest logit per row is the predicted class.
            predict_y = torch.max(outputs, dim=1)[1]
            correct += (predict_y == val_labels.to(device)).sum().item()
    return correct


def main():
    """Train AlexNet on the flower_data ImageFolder dataset, keeping the best weights.

    Reads ``<cwd>/../../data_set/flower_data/{train,val}``. Each split is loaded
    with ``datasets.ImageFolder``, so it is presumably one sub-directory per
    class. Writes the index -> class-name mapping to ``class_indices.json``.
    After each of 10 epochs, saves the weights to ``./AlexNet.pth`` whenever
    validation accuracy improves.

    Raises:
        FileNotFoundError: if the dataset directory does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Both splits share the same normalization; only training is augmented.
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     normalize]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   normalize]),
    }

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    image_path = os.path.join(data_root, "data_set", "flower_data")
    # Explicit raise rather than assert: asserts are stripped under ``python -O``.
    if not os.path.exists(image_path):
        raise FileNotFoundError("{} path does not exist.".format(image_path))

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # Persist index -> class name so predicted indices can be mapped back later.
    cla_dict = {idx: name for name, idx in train_dataset.class_to_idx.items()}
    with open("class_indices.json", "w") as json_file:
        json_file.write(json.dumps(cla_dict, indent=4))

    batch_size = 32
    # os.cpu_count() returns None when the CPU count cannot be determined.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])
    print("Using {} dataloader workers every process".format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    # Sample order cannot change accuracy, so validation data is not shuffled.
    validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=4,
                                                  shuffle=False, num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    # Size the classifier head from the data instead of hard-coding 5 classes.
    net = AlexNet(num_classes=len(train_dataset.classes), init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    save_path = "./AlexNet.pth"
    best_acc = 0.0
    for epoch in range(10):
        t1 = time.perf_counter()
        train_loss = _train_one_epoch(net, train_loader, loss_function, optimizer, device)
        print(time.perf_counter() - t1)  # wall-clock seconds for the training pass

        val_accurate = _evaluate(net, validate_loader, device) / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print("[epoch %d] train_loss: %.3f test_accuracy: %.3f" % (epoch + 1, train_loss, val_accurate))
    print("Finished Training")

代码训练，图像分类通用测试代码

原文:https://www.cnblogs.com/xmy-0904-lfx/p/14887537.html

(0)
(0)
   
举报
评论 一句话评论(0)
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!