首页 > 其他 > 详细

FizzBuzz

时间:2020-05-19 18:14:49      阅读:49      评论:0      收藏:0      [点我收藏+]
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2
import numpy as np
# Training settings.
# NOTE(review): neither name is read by the training code in this script,
# which uses BATCH_SIZE = 128 and SGD with lr=0.05 instead.
batch_size, learning_rate = 64, 1e-4

def fize_buss_encode(i):
    """Return the FizzBuzz class index of ``i``.

    The index is 0 for a plain number, 1 for "fizz" (divisible by 3),
    2 for "buzz" (divisible by 5) and 3 for the combined case (divisible
    by both 3 and 5, i.e. by 15). It is the position of the matching
    label in the list used by ``fize_buss_decode``.
    """
    # Bit 0 flags divisibility by 3 and bit 1 divisibility by 5, so a
    # multiple of 15 sets both bits and lands on 3.
    fizz = int(i % 3 == 0)
    buzz = int(i % 5 == 0)
    return 2 * buzz + fizz
def fize_buss_decode(i, prediction):
    """Turn a FizzBuzz class index into the text printed for ``i``.

    Args:
        i: The number being classified; echoed as a string for index 0.
        prediction: Class index in ``range(4)``, as produced by
            ``fize_buss_encode`` or by taking the argmax of the model's
            4 output logits.

    Returns:
        ``str(i)``, ``"fizz"``, ``"buzz"`` or ``"fizzbuzz"``.
    """
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]
def helper(i):
    """Print the correct FizzBuzz output for ``i`` (no model involved)."""
    expected_label = fize_buss_decode(i, fize_buss_encode(i))
    print(expected_label)

# Width of the binary input encoding: numbers up to 2**13 - 1 = 8191 fit.
NUM_DIGHT = 13


def Transform(i, num_dight):
    """Return the lowest ``num_dight`` bits of ``i``, most significant first.

    Example: ``Transform(5, 4)`` gives ``array([0, 1, 0, 1])``. Bits at
    positions ``num_dight`` and above are silently dropped.
    """
    bits = [(i >> shift) & 1 for shift in range(num_dight - 1, -1, -1)]
    return np.array(bits)

# Training data: every integer from 101 to 2**NUM_DIGHT - 1 (the numbers
# 0..100 are kept for evaluation).
_train_range = range(101, 2 ** NUM_DIGHT)
# Build one ndarray first and convert it once; creating a tensor from a
# Python list of per-sample ndarrays is very slow.
trainX = torch.from_numpy(
    np.array([Transform(i, NUM_DIGHT) for i in _train_range])
).float()
trainY = torch.LongTensor([fize_buss_encode(i) for i in _train_range])
class twolayer(torch.nn.Module):
    """Two-layer perceptron: Linear(D_in, H) -> ReLU -> Linear(H, D_out).

    An explicit-class equivalent of the ``nn.Sequential`` model the script
    trains; nothing in this file instantiates it.

    Args:
        D_in: Number of input features.
        H: Width of the hidden layer.
        D_out: Number of output logits.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear1 = nn.Linear(D_in, H)
        # ReLU has no parameters; registering it as a submodule keeps the
        # architecture visible in repr(model).
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(H, D_out)

    def forward(self, x):
        """Map inputs whose last dimension is D_in to D_out logits."""
        return self.linear2(self.relu(self.linear1(x)))
H = 100

# Maps a NUM_DIGHT-bit input to 4 class logits (one per FizzBuzz label)
# through a single hidden layer of H ReLU units.
model = nn.Sequential(
    nn.Linear(NUM_DIGHT, H),
    nn.ReLU(),
    nn.Linear(H, 4),
)

loss_fc = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
BATCH_SIZE = 128

# Mini-batch SGD for a fixed 1000 epochs. trainX is ordered by value
# (101, 102, ...), so the samples are reshuffled every epoch; fixed slicing
# would present the same batches in the same order each time.
for epoch in range(1000):
    permutation = torch.randperm(len(trainX))
    for start in range(0, len(trainX), BATCH_SIZE):
        batch_idx = permutation[start:start + BATCH_SIZE]
        train_x = trainX[batch_idx]
        train_y = trainY[batch_idx]
        loss = loss_fc(model(train_x), train_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluate on the held-out numbers 0..100 (training used 101 upwards).
# Build one ndarray first and convert it once, as for the training data.
textX = torch.from_numpy(
    np.array([Transform(i, NUM_DIGHT) for i in range(101)])
).float()
# Integer class labels, with the same dtype as trainY.
textY = torch.LongTensor([fize_buss_encode(i) for i in range(101)])

with torch.no_grad():
    textY_pred = model(textX)

# Predicted class = index of the largest logit. Print every number the
# model gets wrong as "<predicted text> <expected text>".
predictions = zip(range(0, 101), textY_pred.argmax(dim=1).tolist())
for i, j in predictions:
    expected = fize_buss_encode(i)
    if j != expected:
        print(fize_buss_decode(i, j), fize_buss_decode(i, expected))

 

FizzBuzz

原文:https://www.cnblogs.com/yin101/p/12918354.html

(0)
(0)
   
举报
评论 一句话评论(0)
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!