We may not have Google's massive clusters or a DeepMind-caliber algorithms team, but a homemade mini-game AI based on Deep Q-Networks (DQN) still performs remarkably well. A demo screenshot appears below:
/3/ Model Implementation
3.1 Overall Program Structure
The program's main function lives in PlaneDQN.py, the DQN-related functions are in BrainDQN_Nature.py, the game model is in the game folder, and the network weights saved during training go into the saved_networks folder.
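Based on the description above, the repository layout is roughly as follows (reconstructed for illustration; the actual repo may contain additional files):

```
PlaneDQN/
├── PlaneDQN.py          # main entry point, playPlane()
├── BrainDQN_Nature.py   # DQN agent (network, training, action selection)
├── game/                # pygame game model (GameState, sprites)
└── saved_networks/      # checkpoints saved during training
```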
```python
def playPlane():
    # Step 1: init BrainDQN
    actions = 3
    brain = BrainDQN(actions)
    # Step 2: init Plane Game
    plane = game.GameState()
    # Step 3: play game
    # Step 3.1: obtain init state
    action0 = np.array([1, 0, 0])  # [1,0,0] do nothing, [0,1,0] left, [0,0,1] right
    observation0, reward0, terminal = plane.frame_step(action0)
    observation0 = cv2.cvtColor(cv2.resize(observation0, (80, 80)), cv2.COLOR_BGR2GRAY)
    ret, observation0 = cv2.threshold(observation0, 1, 255, cv2.THRESH_BINARY)
    brain.setInitState(observation0)
    # Step 3.2: run the game
    while True:
        action = brain.getAction()
        nextObservation, reward, terminal = plane.frame_step(action)
        nextObservation = preprocess(nextObservation)
        brain.setPerception(nextObservation, action, reward, terminal)
```
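Inside the loop, `setInitState` and `setPerception` maintain the agent's state as a stack of the four most recent 80x80 frames. The following stand-alone sketch (a hypothetical simplification of what BrainDQN does, not the repo's exact code) shows that mechanism:

```python
import numpy as np

class FrameStack:
    """Minimal sketch of the 4-frame state stacking performed by
    BrainDQN.setInitState / setPerception (illustrative, not the repo code)."""

    def __init__(self):
        self.currentState = None

    def setInitState(self, observation):
        # Repeat the first 80x80 frame four times -> shape (80, 80, 4)
        self.currentState = np.stack((observation,) * 4, axis=2)

    def setPerception(self, nextObservation):
        # Drop the oldest frame, append the newest (nextObservation is 80x80x1)
        self.currentState = np.append(self.currentState[:, :, 1:],
                                      nextObservation, axis=2)

stack = FrameStack()
stack.setInitState(np.zeros((80, 80)))
stack.setPerception(np.ones((80, 80, 1)))
print(stack.currentState.shape)  # (80, 80, 4)
```

Stacking four frames gives the network short-term motion information (e.g. bullet and enemy velocity) that a single static frame cannot convey.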
3.3 The GameState Class and frame_step
The game interface is built with pygame. We define a bullet class, a player class, an enemy class, and a game class; the skeleton is shown below.
```python
# Bullet class
class Bullet(pygame.sprite.Sprite):
    def __init__(self, bullet_img, init_pos): ...
    def move(self): ...

# Our plane (player) class
class Player(pygame.sprite.Sprite):
    def __init__(self, plane_img, player_rect, init_pos): ...
    def shoot(self, bullet_img): ...
    def moveLeft(self): ...
    def moveRight(self): ...

# Enemy plane class
class Enemy(pygame.sprite.Sprite):
    def __init__(self, enemy_img, enemy_down_imgs, init_pos): ...
    def move(self): ...

class GameState:
    def __init__(self): ...

    def frame_step(self, input_actions):
        if input_actions[0] == 1 or input_actions[1] == 1 or input_actions[2] == 1:
            # input is well-formed
            if input_actions[0] == 0 and input_actions[1] == 1 and input_actions[2] == 0:
                self.player.moveLeft()
            elif input_actions[0] == 0 and input_actions[1] == 0 and input_actions[2] == 1:
                self.player.moveRight()
            else:
                pass
        else:
            raise ValueError('Multiple input actions!')
        image_data = pygame.surfarray.array3d(pygame.display.get_surface())
        pygame.display.update()
        clock = pygame.time.Clock()
        clock.tick(30)
        return image_data, reward, terminal
```
The frame_step() function of GameState is the basic primitive that advances the environment one step per DQN action. Each call applies the action specified by input_actions, moves the on-screen elements for that time slice, and checks for collisions. It then grabs the screen image via get_surface and returns the environment's image_data, the reward, and the terminal flag indicating whether the game has ended. A screenshot of the game is shown below:
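The reward and terminal values come from the collision checks inside frame_step. The sketch below illustrates one common reward convention for this kind of shooter game; the concrete values and structure here are assumptions for illustration, not taken from the repo:

```python
def compute_reward(enemy_down, player_hit):
    """Hypothetical reward/terminal logic of the kind frame_step computes.
    enemy_down: a bullet destroyed an enemy this frame.
    player_hit: our plane collided with an enemy this frame."""
    if player_hit:
        return -1.0, True   # penalty, and the episode ends
    if enemy_down:
        return 1.0, False   # bonus for shooting down an enemy
    return 0.1, False       # small per-frame survival reward

print(compute_reward(False, True))   # (-1.0, True)
```

Keeping rewards in a small, consistent range like this helps stabilize Q-value estimates during training.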
```python
class BrainDQN:
    def __init__(self, actions): ...

    def createQNetwork(self):
        ...
        return stateInput, QValue, W_conv1, b_conv1, W_conv2, b_conv2, \
               W_conv3, b_conv3, W_fc1, b_fc1, W_fc2, b_fc2

    def copyTargetQNetwork(self):
        self.session.run(self.copyTargetQNetworkOperation)

    def createTrainingMethod(self): ...

    def trainQNetwork(self): ...

    def getAction(self):
        ...
        return action

    def setInitState(self, observation):
        self.currentState = np.stack((observation, observation, observation, observation), axis=2)

    def weight_variable(self, shape):
        ...
        return tf.Variable(initial)

    def bias_variable(self, shape):
        ...
        return tf.Variable(initial)

    def conv2d(self, x, W, stride):
        return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding="SAME")

    def max_pool_2x2(self, x):
        return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
```
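The heart of trainQNetwork in the Nature-DQN setup is the target computation: for each sampled transition, the training target is the reward alone if the episode ended, and otherwise the reward plus the discounted maximum Q-value of the next state, evaluated with the periodically-copied target network (copyTargetQNetwork). A minimal sketch of that computation, with an assumed discount factor of 0.99:

```python
import numpy as np

GAMMA = 0.99  # discount factor (assumed; check the repo for the actual value)

def q_targets(reward_batch, terminal_batch, next_q_batch):
    """Sketch of the y-target computation inside trainQNetwork:
    y = r for terminal transitions, else r + GAMMA * max_a' Q_target(s', a')."""
    targets = []
    for r, done, next_q in zip(reward_batch, terminal_batch, next_q_batch):
        if done:
            targets.append(r)
        else:
            targets.append(r + GAMMA * np.max(next_q))
    return targets

ys = q_targets([1.0, -1.0],
               [False, True],
               [np.array([0.5, 2.0]), np.array([0.0, 0.0])])
print(ys)  # approximately [2.98, -1.0]
```

Using a separate, slowly-updated target network to evaluate `max_a' Q(s', a')` is what distinguishes the Nature 2015 variant (hence the file name BrainDQN_Nature.py) from the original 2013 DQN and keeps the bootstrapped targets from chasing a moving network.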
3.5 Image Preprocessing
Image preprocessing uses the cv2 library to resize each frame, convert it to grayscale, and binarize it.
```python
def preprocess(observation):
    # resize to 80x80 and convert to grayscale
    observation = cv2.cvtColor(cv2.resize(observation, (80, 80)), cv2.COLOR_BGR2GRAY)
    # binarize: any pixel brighter than 1 becomes 255
    ret, observation = cv2.threshold(observation, 1, 255, cv2.THRESH_BINARY)
    return np.reshape(observation, (80, 80, 1))
```
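The cv2.THRESH_BINARY step maps every pixel strictly greater than the threshold to the max value and everything else to 0, collapsing the frame into a crisp sprite/background mask. A pure-NumPy equivalent (written here only to illustrate the semantics, without requiring OpenCV) looks like this:

```python
import numpy as np

def binarize(gray, thresh=1, maxval=255):
    """NumPy equivalent of cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY):
    pixels strictly greater than thresh become maxval, the rest become 0."""
    return np.where(gray > thresh, maxval, 0).astype(np.uint8)

gray = np.array([[0, 1, 2], [100, 255, 0]], dtype=np.uint8)
out = binarize(gray)
# note the strict comparison: pixel value 1 maps to 0, value 2 maps to 255
print(out)
```

Binarizing discards brightness detail the agent does not need and leaves the network a near-binary input, which speeds up convergence on simple sprite-based games.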
/4/ Environment Setup
OS: Ubuntu 16.04, Windows 10
Source code: https://github.com/zhangbinchao/PlaneDQN