Commit ·
2322e9b
1
Parent(s): 7da2e8e
Upload folder using huggingface_hub
Browse files- goal_keeper.py +1001 -0
- openrl_policy.py +446 -0
- openrl_utils.py +421 -0
- submission.py +81 -0
goal_keeper.py
ADDED
|
@@ -0,0 +1,1001 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
# Copyright 2023 The OpenRL Authors.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
# original code from https://github.com/Sarvar-Anvarov/Google-Research-Football/blob/main/gfootball.py
|
| 18 |
+
# modified by TARTRL team
|
| 19 |
+
|
| 20 |
+
import math
|
| 21 |
+
import random
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
from functools import wraps
|
| 25 |
+
from enum import Enum
|
| 26 |
+
from typing import *
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Action(Enum):
    """Discrete action ids for the football environment.

    Values mirror the environment's default action set (idle, eight movement
    directions, passes/shot, and sticky-action toggles) — presumably the
    gfootball `action_set_v1`; TODO confirm against football_action_set.
    """
    Idle = 0
    # Eight movement directions (these are "sticky": they stay active
    # until released or replaced).
    Left = 1
    TopLeft = 2
    Top = 3
    TopRight = 4
    Right = 5
    BottomRight = 6
    Bottom = 7
    BottomLeft = 8
    # Ball actions.
    LongPass = 9
    HighPass = 10
    ShortPass = 11
    Shot = 12
    # Sticky toggles and their releases.
    Sprint = 13
    ReleaseDirection = 14
    ReleaseSprint = 15
    Slide = 16
    Dribble = 17
    ReleaseDribble = 18
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# The eight movement actions and their (dx, dy) direction vectors,
# index-aligned: ALL_DIRECTION_VECS[i] is the unit-grid direction that
# ALL_DIRECTION_ACTIONS[i] moves in (x grows to the right, y grows downward).
ALL_DIRECTION_ACTIONS = [Action.Left, Action.TopLeft, Action.Top, Action.TopRight, Action.Right, Action.BottomRight, Action.Bottom, Action.BottomLeft]
ALL_DIRECTION_VECS = [(-1, 0), (-1, -1), (0, -1), (1, -1), (1, 0), (1, 1), (0, 1), (-1, 1)]

# Maps the index of each flag in obs["sticky_actions"] to the Action that
# set it (eight directions, then Sprint and Dribble).
sticky_index_to_action = [
    Action.Left,
    Action.TopLeft,
    Action.Top,
    Action.TopRight,
    Action.Right,
    Action.BottomRight,
    Action.Bottom,
    Action.BottomLeft,
    Action.Sprint,
    Action.Dribble
]

# Small x-offset from the goal line used when anchoring on our goal centre
# (keeps the target just inside the field).
GOAL_BIAS = 0.01
|
| 69 |
+
|
| 70 |
+
class PlayerRole(Enum):
    """Role ids reported in obs["left_team_roles"] / obs["right_team_roles"]."""
    GoalKeeper = 0
    CenterBack = 1
    LeftBack = 2
    RightBack = 3
    DefenceMidfield = 4
    CentralMidfield = 5
    LeftMidfield = 6
    # NOTE(review): "RIghtMidfield" is a typo for "RightMidfield", but the
    # member name is part of the public interface — renaming would break
    # any external reference, so it is kept as-is.
    RIghtMidfield = 7
    AttackMidfield = 8
    CentralFront = 9
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class GameMode(Enum):
    """Game-mode ids reported in obs["game_mode"] (Normal is open play;
    the rest are set pieces)."""
    Normal = 0
    KickOff = 1
    GoalKick = 2
    FreeKick = 3
    Corner = 4
    ThrowIn = 5
    Penalty = 6
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def human_readable_agent(agent: Callable[[Dict], Action]):
    """
    Decorator allowing for a more human-friendly implementation of the agent
    function: the wrapped agent receives enum-typed observations and returns
    a single Action, while the wrapper exposes the raw [int] interface.

    @human_readable_agent
    def my_agent(obs):
        ...
        return football_action_set.action_right
    """

    @wraps(agent)
    def agent_wrapper(obs) -> List[int]:
        # Convert the raw 0/1 sticky-action flags into a set of Action enums.
        raw_flags = obs['sticky_actions']
        obs['sticky_actions'] = {
            sticky_index_to_action[i] for i, flag in enumerate(raw_flags) if flag
        }
        # Wrap the raw game-mode integer in the GameMode enum.
        obs['game_mode'] = GameMode(obs['game_mode'])
        # In single-agent mode 'designated' always equals 'active'; drop it.
        obs.pop('designated', None)
        # Convert raw role ids into PlayerRole enums for both teams.
        obs['left_team_roles'] = [PlayerRole(r) for r in obs['left_team_roles']]
        obs['right_team_roles'] = [PlayerRole(r) for r in obs['right_team_roles']]

        return [agent(obs).value]

    return agent_wrapper
|
| 120 |
+
|
| 121 |
+
def find_patterns(obs, player_x, player_y):
    """ find list of appropriate patterns in groups of memory patterns """
    # `groups_of_memory_patterns` is a module-level list of group factories —
    # presumably defined later in this file; TODO confirm.
    for get_group in groups_of_memory_patterns:
        group = get_group(obs, player_x, player_y)
        # The first group whose predicate matches wins.
        if group["environment_fits"](obs, player_x, player_y):
            return group["get_memory_patterns"](obs, player_x, player_y)
    # NOTE(review): implicitly returns None when no group fits; callers
    # iterate the result, so a catch-all group must always be registered.
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def get_action_of_agent(obs, player_x, player_y):
    """ get action of appropriate pattern in agent's memory """
    # Select the pattern list for the current situation (group dispatch).
    memory_patterns = find_patterns(obs, player_x, player_y)
    # find appropriate pattern in list of memory patterns; first match wins
    for get_pattern in memory_patterns:
        pattern = get_pattern(obs, player_x, player_y)
        if pattern["environment_fits"](obs, player_x, player_y):
            return pattern["get_action"](obs, player_x, player_y)
    # NOTE(review): implicitly returns None if no pattern fits — every group
    # appears to end with the always-fitting `idle` pattern to prevent this.
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def get_distance(x1, y1, right_team):
    """Planar Euclidean distance from (x1, y1) to an opponent position,
    with the y axis rescaled by 2.38 to account for the field's aspect
    ratio (x spans [-1, 1] while y spans roughly [-0.42, 0.42])."""
    dx = x1 - right_team[0]
    dy = y1 * 2.38 - right_team[1] * 2.38
    return math.sqrt(dx ** 2 + dy ** 2)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def run_to_ball_bottom(obs, player_x, player_y):
    """Chase the ball when it lies directly below the player."""

    def environment_fits(obs, player_x, player_y):
        """True when the ball is below and (nearly) horizontally aligned."""
        ball_x, ball_y = obs["ball"][0], obs["ball"][1]
        return ball_y > player_y and abs(ball_x - player_x) < 0.01

    def get_action(obs, player_x, player_y):
        """Move straight down toward the ball."""
        return Action.Bottom

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def run_to_ball_bottom_left(obs, player_x, player_y):
    """Chase the ball when it lies below and to the left of the player."""

    def environment_fits(obs, player_x, player_y):
        """True when the ball is left of and below the player."""
        ball_x, ball_y = obs["ball"][0], obs["ball"][1]
        return ball_x < player_x and ball_y > player_y

    def get_action(obs, player_x, player_y):
        """Move diagonally down-left toward the ball."""
        return Action.BottomLeft

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def run_to_ball_bottom_right(obs, player_x, player_y):
    """Chase the ball when it lies below and to the right of the player."""

    def environment_fits(obs, player_x, player_y):
        """True when the ball is right of and below the player."""
        ball_x, ball_y = obs["ball"][0], obs["ball"][1]
        return ball_x > player_x and ball_y > player_y

    def get_action(obs, player_x, player_y):
        """Move diagonally down-right toward the ball."""
        return Action.BottomRight

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def run_to_ball_left(obs, player_x, player_y):
    """Chase the ball when it lies directly to the left of the player."""

    def environment_fits(obs, player_x, player_y):
        """True when the ball is left of and (nearly) vertically aligned."""
        ball_x, ball_y = obs["ball"][0], obs["ball"][1]
        return ball_x < player_x and abs(ball_y - player_y) < 0.01

    def get_action(obs, player_x, player_y):
        """Move straight left toward the ball."""
        return Action.Left

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def run_to_ball_right(obs, player_x, player_y):
    """Chase the ball when it lies directly to the right of the player."""

    def environment_fits(obs, player_x, player_y):
        """True when the ball is right of and (nearly) vertically aligned."""
        ball_x, ball_y = obs["ball"][0], obs["ball"][1]
        return ball_x > player_x and abs(ball_y - player_y) < 0.01

    def get_action(obs, player_x, player_y):
        """Move straight right toward the ball."""
        return Action.Right

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def run_to_ball_top(obs, player_x, player_y):
    """Chase the ball when it lies directly above the player."""

    def environment_fits(obs, player_x, player_y):
        """True when the ball is above and (nearly) horizontally aligned."""
        ball_x, ball_y = obs["ball"][0], obs["ball"][1]
        return ball_y < player_y and abs(ball_x - player_x) < 0.01

    def get_action(obs, player_x, player_y):
        """Move straight up toward the ball."""
        return Action.Top

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def run_to_ball_top_left(obs, player_x, player_y):
    """Chase the ball when it lies above and to the left of the player."""

    def environment_fits(obs, player_x, player_y):
        """True when the ball is left of and above the player."""
        ball_x, ball_y = obs["ball"][0], obs["ball"][1]
        return ball_x < player_x and ball_y < player_y

    def get_action(obs, player_x, player_y):
        """Move diagonally up-left toward the ball."""
        return Action.TopLeft

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def run_to_ball_top_right(obs, player_x, player_y):
    """Chase the ball when it lies above and to the right of the player."""

    def environment_fits(obs, player_x, player_y):
        """True when the ball is right of and above the player."""
        ball_x, ball_y = obs["ball"][0], obs["ball"][1]
        return ball_x > player_x and ball_y < player_y

    def get_action(obs, player_x, player_y):
        """Move diagonally up-right toward the ball."""
        return Action.TopRight

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def idle(obs, player_x, player_y):
    """Catch-all pattern: always fits and issues the Idle action."""

    def environment_fits(obs, player_x, player_y):
        """Always matches; intended as the last pattern in every group."""
        return True

    def get_action(obs, player_x, player_y):
        """Emit Idle (no new sticky actions are set)."""
        return Action.Idle

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def start_sprinting(obs, player_x, player_y):
    """Ensure the Sprint sticky action is active."""

    def environment_fits(obs, player_x, player_y):
        """Fits whenever the player is not already sprinting."""
        return Action.Sprint not in obs["sticky_actions"]

    def get_action(obs, player_x, player_y):
        """Drop dribble first if it is active, otherwise start sprinting."""
        if Action.Dribble in obs['sticky_actions']:
            return Action.ReleaseDribble
        return Action.Sprint

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def corner(obs, player_x, player_y):
    """Aim toward the goal mouth and send a high pass in corner game mode."""

    def environment_fits(obs, player_x, player_y):
        """Fits only while the game mode is Corner."""
        return obs['game_mode'] == GameMode.Corner

    def get_action(obs, player_x, player_y):
        """Turn toward the box first, then deliver a high pass."""
        aim = Action.TopRight if player_y > 0 else Action.BottomRight
        if aim not in obs["sticky_actions"]:
            return aim
        return Action.HighPass

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def free_kick(obs, player_x, player_y):
    """Shoot when close to goal, otherwise short-pass, in free kick mode."""

    def environment_fits(obs, player_x, player_y):
        """Fits only while the game mode is FreeKick."""
        return obs['game_mode'] == GameMode.FreeKick

    def get_action(obs, player_x, player_y):
        """Close to goal: turn toward it and shoot; far: angle and pass."""
        if player_x > 0.5:
            # Attacking third — line up on the near post, then shoot.
            aim = Action.TopRight if player_y > 0 else Action.BottomRight
            if aim not in obs["sticky_actions"]:
                return aim
            return Action.Shot
        # Own half / midfield — angle infield, then play a short pass.
        aim = Action.BottomRight if player_y > 0 else Action.TopRight
        if aim not in obs['sticky_actions']:
            return aim
        return Action.ShortPass

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def goal_kick(obs, player_x, player_y):
    """Play a short pass toward the bottom right in goal kick game mode."""

    def environment_fits(obs, player_x, player_y):
        """Fits only while the game mode is GoalKick."""
        return obs['game_mode'] == GameMode.GoalKick

    def get_action(obs, player_x, player_y):
        """Face bottom-right first, then release a short pass."""
        if Action.BottomRight not in obs["sticky_actions"]:
            return Action.BottomRight
        return Action.ShortPass

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def kick_off(obs, player_x, player_y):
    """Play a short sideways pass in kick off game mode."""

    def environment_fits(obs, player_x, player_y):
        """Fits only while the game mode is KickOff."""
        return obs['game_mode'] == GameMode.KickOff

    def get_action(obs, player_x, player_y):
        """Face toward the field's centre line, then short-pass."""
        aim = Action.Top if player_y > 0 else Action.Bottom
        if aim not in obs["sticky_actions"]:
            return aim
        return Action.ShortPass

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def penalty(obs, player_x, player_y):
    """Pick a random corner and shoot in penalty game mode."""

    def environment_fits(obs, player_x, player_y):
        """Fits only while the game mode is Penalty."""
        return obs['game_mode'] == GameMode.Penalty

    def get_action(obs, player_x, player_y):
        """Randomly aim top-right or bottom-right, then shoot."""
        sticky = obs["sticky_actions"]
        # random.random() is drawn on every call, matching one RNG draw
        # per step while no aiming direction is set yet.
        pick_top = random.random() < 0.5
        if pick_top and Action.TopRight not in sticky and Action.BottomRight not in sticky:
            return Action.TopRight
        if Action.BottomRight not in sticky:
            return Action.BottomRight
        return Action.Shot

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 426 |
+
|
| 427 |
+
def throw_in(obs, player_x, player_y):
    """Throw the ball forward with a short pass in throw in game mode."""

    def environment_fits(obs, player_x, player_y):
        """Fits only while the game mode is ThrowIn."""
        return obs['game_mode'] == GameMode.ThrowIn

    def get_action(obs, player_x, player_y):
        """Face right (toward the opponent goal) first, then short-pass."""
        if Action.Right not in obs["sticky_actions"]:
            return Action.Right
        return Action.ShortPass

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def defence_memory_patterns(obs, player_x, player_y):
    """Group of memory patterns used while our team does not own the ball."""

    def environment_fits(obs, player_x, player_y):
        """Fits whenever the ball is not owned by our (left) team."""
        return obs["ball_owned_team"] != 0

    def get_memory_patterns(obs, player_x, player_y):
        """Predict the ball position a few frames ahead, then chase it.

        NOTE: mutates obs["ball"] in place — downstream patterns see the
        extrapolated position, not the raw observation.
        """
        ball = obs["ball"]
        # Extrapolate the ball along its current direction (x further than y).
        ball[0] += obs["ball_direction"][0] * 7
        ball[1] += obs["ball_direction"][1] * 3
        # When an opponent carries the ball away from the vertical centre,
        # bias the intercept point slightly toward our goal and the centre.
        if abs(ball[1]) > 0.07 and obs["ball_owned_team"] == 1:
            ball[0] -= 0.01
            if ball[1] > 0:
                ball[1] -= 0.01
            else:
                ball[1] += 0.01

        return [
            start_sprinting,
            run_to_ball_right,
            run_to_ball_left,
            run_to_ball_bottom,
            run_to_ball_top,
            run_to_ball_top_right,
            run_to_ball_top_left,
            run_to_ball_bottom_right,
            run_to_ball_bottom_left,
            idle,
        ]

    return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
| 482 |
+
|
| 483 |
+
def goalkeeper_memory_patterns(obs, player_x, player_y):
    """Group of memory patterns for when our goalkeeper holds the ball."""

    def environment_fits(obs, player_x, player_y):
        """Fits when the active player is our goalkeeper (player index 0)
        and that player owns the ball."""
        return (obs["ball_owned_player"] == obs["active"]
                and obs["ball_owned_team"] == 0
                and obs["ball_owned_player"] == 0)

    def get_memory_patterns(obs, player_x, player_y):
        """The keeper clears long up-field, falling back to idle."""
        return [
            long_pass_forward,
            idle,
        ]

    return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def offence_memory_patterns(obs, player_x, player_y):
    """Group of memory patterns used while the active player has the ball."""

    def environment_fits(obs, player_x, player_y):
        """Fits when the controlled (active) player owns the ball."""
        return obs["ball_owned_player"] == obs["active"] and obs["ball_owned_team"] == 0

    def get_memory_patterns(obs, player_x, player_y):
        """Attacking priorities: shoot, cross, pass long, then hold."""
        return [
            close_to_goalkeeper_shot,
            spot_shot,
            cross,
            long_pass_forward,
            keep_the_ball,
            idle,
        ]

    return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
def other_memory_patterns(obs, player_x, player_y):
    """Catch-all group for every situation no other group claims."""

    def environment_fits(obs, player_x, player_y):
        """Always matches; must be registered last."""
        return True

    def get_memory_patterns(obs, player_x, player_y):
        """Nothing situation-specific to do — just idle."""
        return [
            idle,
        ]

    return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
| 543 |
+
|
| 544 |
+
def special_game_modes_memory_patterns(obs, player_x, player_y):
    """Group of memory patterns for set pieces (any non-Normal game mode)."""

    def environment_fits(obs, player_x, player_y):
        """Fits whenever the game mode is anything but Normal."""
        return obs['game_mode'] != GameMode.Normal

    def get_memory_patterns(obs, player_x, player_y):
        """One pattern per set-piece mode; idle covers the remainder."""
        return [
            corner,
            free_kick,
            goal_kick,
            kick_off,
            penalty,
            throw_in,
            idle,
        ]

    return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
def special_spot_shot(obs, player_x, player_y):
    """Group that shoots when the player is deep in the box, centrally."""

    def environment_fits(obs, player_x, player_y):
        """Fits when the player is inside the shooting zone near goal."""
        return player_x > 0.8 and abs(player_y) < 0.21

    def get_memory_patterns(obs, player_x, player_y):
        """Take the shot; idle as a fallback."""
        return [
            shot,
            idle,
        ]

    return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
def own_goal(obs, player_x, player_y):
    """ group of memory patterns used when the player is very deep in our
    own half (delegates to own_goal_2, defined elsewhere in this file) """
    def environment_fits(obs, player_x, player_y):
        """ environment fits constraints """
        # NOTE(review): `player_y` is tested for truthiness, so this group
        # never fits when player_y == 0.0 exactly — possibly a leftover
        # condition (e.g. a missing comparison); confirm intent before
        # changing, since the agent's behavior depends on it.
        if player_x < -0.9 and player_y:
            return True
        return False

    def get_memory_patterns(obs, player_x, player_y):
        """ get list of memory patterns """
        # own_goal_2 is presumably defined later in this file — TODO confirm.
        memory_patterns = [
            own_goal_2
        ]
        return memory_patterns

    return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
| 606 |
+
|
| 607 |
+
def get_best_direction(obs, target_direction):
    """Return the movement Action whose direction vector points most
    closely from the active player toward target_direction (argmax of the
    dot product with the eight normalized direction vectors)."""
    me = obs["left_team"][obs["active"]]
    offset = np.array(target_direction) - me
    unit_vecs = [np.array(v) / np.linalg.norm(np.array(v)) for v in ALL_DIRECTION_VECS]
    scores = [np.dot(offset, v) for v in unit_vecs]
    return ALL_DIRECTION_ACTIONS[np.argmax(scores)]
|
| 613 |
+
|
| 614 |
+
def get_distance2ball(obs):
    """Euclidean distance from the active player to the ball (x, y only).

    Assumes obs["ball"] is a numpy array (supports slicing/subtraction) —
    TODO confirm against the observation wrapper.
    """
    delta = obs["ball"][:2] - obs["left_team"][obs['active']]
    return np.linalg.norm(delta)
|
| 616 |
+
|
| 617 |
+
def get_target2line(obs):
    """Return the point 0.03 away from our goal centre (-1, 0) along the
    line from the goal centre to the ball.

    Standing on this point keeps a defender/keeper on the shot line.

    Returns:
        np.ndarray of shape (2,): the (x, y) target position.
    """
    # Removed an unused `active_position` lookup present in the original.
    ball_x, ball_y = obs['ball'][0], obs['ball'][1]
    # Distance from our goal centre (-1, 0) to the ball; the epsilon keeps
    # the division safe when the ball sits exactly on the goal centre.
    distance2goal = ((ball_x + 1) ** 2 + ball_y ** 2) ** 0.5 + 1e-5
    cos_theta = (ball_x + 1) / distance2goal
    sin_theta = ball_y / distance2goal
    target_pos = np.array([0.03 * cos_theta - 1, 0.03 * sin_theta])
    return target_pos
|
| 625 |
+
|
| 626 |
+
def already_near_goal(obs, player_x, player_y):
    """Stop moving (release every sticky action) once the active player
    stands on the anchor spot just inside our goal."""

    def environment_fits(obs, player_x, player_y):
        """Fits when the active player is within 0.02 of the goal anchor."""
        me = obs["left_team"][obs["active"]]
        anchor = np.array([-1 + GOAL_BIAS, 0])
        return bool(np.linalg.norm(anchor - me) < 0.02)

    def get_action(obs, player_x, player_y):
        """Release sprint, then dribble, then any direction; else idle."""
        sticky = obs["sticky_actions"]
        if Action.Sprint in sticky:
            return Action.ReleaseSprint
        if Action.Dribble in sticky:
            return Action.ReleaseDribble
        if sticky:
            return Action.ReleaseDirection
        return Action.Idle

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 649 |
+
|
| 650 |
+
def already_in_line(obs, player_x, player_y):
    """ do nothing, release all sticky actions """
    def environment_fits(obs, player_x, player_y):
        """ fits when the player already holds the goal-ball line point """
        anchor = get_target2line(obs)
        gap = np.linalg.norm(anchor - obs['left_team'][obs['active']])
        return bool(gap < 0.02)

    def get_action(obs, player_x, player_y):
        """ shed sprint first, then dribble, then any held direction; finally idle """
        sticky = obs["sticky_actions"]
        if Action.Sprint in sticky:
            return Action.ReleaseSprint
        if Action.Dribble in sticky:
            return Action.ReleaseDribble
        if len(sticky) > 0:
            return Action.ReleaseDirection
        return Action.Idle

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 673 |
+
|
| 674 |
+
def run_to_goal(obs, player_x, player_y):
    """ unconditional fallback: move toward the spot in front of our own goal """
    def environment_fits(obs, player_x, player_y):
        """ always fits """
        return True

    def get_action(obs, player_x, player_y):
        """ head for the anchor point just ahead of our goal centre """
        return get_best_direction(obs, [-1 + GOAL_BIAS, 0])

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 688 |
+
|
| 689 |
+
def run_to_line(obs, player_x, player_y):
    """ unconditional fallback: move onto the goal-ball line """
    def environment_fits(obs, player_x, player_y):
        """ always fits """
        return True

    def get_action(obs, player_x, player_y):
        """ steer toward the point that keeps us between ball and goal """
        return get_best_direction(obs, get_target2line(obs))

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 699 |
+
|
| 700 |
+
def goal_keeper_far_pattern(obs, player_x, player_y):
    """ group for when the keeper (player 0) is controlled but should fall
    back toward goal (patterns: already_near_goal / start_sprinting / run_to_goal). """
    def environment_fits(obs, player_x, player_y):
        """ environment fits constraints """
        # player have the ball
        if (obs["active"] == 0):
            active_position = obs["left_team"][0]
            relative_ball_position = obs["ball"][:2] - active_position
            distance2ball = np.linalg.norm(relative_ball_position)
            # ball is far away, or one of our outfield players already owns it
            if distance2ball > 0.5 or (obs['ball_owned_team'] == 0 and obs['ball_owned_player'] != 0):
                return True
            # keeper dragged out of his area: fall back if any teammate is
            # closer to the ball than the keeper is
            if active_position[0] > -0.7 or abs(active_position[1]) > 0.25:
                for teammate_pos in obs['left_team'][1:]:
                    teammate_dis = np.linalg.norm(obs["ball"][:2] - teammate_pos)
                    if teammate_dis < distance2ball:
                        return True
        return False

    def get_memory_patterns(obs, player_x, player_y):
        """ get list of memory patterns """
        memory_patterns = [
            already_near_goal,
            start_sprinting,
            run_to_goal
        ]
        return memory_patterns

    return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
| 727 |
+
|
| 728 |
+
def ball_distance_2_5(obs, player_x, player_y):
    """ keeper controlled, ball not ours, at medium range: hold the goal-ball line """
    def environment_fits(obs, player_x, player_y):
        """ fits when the goalkeeper is the active player, our team does not
        own the ball, and the ball is between 0.2 and 0.5 away """
        if obs["active"] != 0 or obs['ball_owned_team'] == 0:
            return False
        distance = get_distance2ball(obs)
        return bool(0.2 <= distance <= 0.5)

    def get_memory_patterns(obs, player_x, player_y):
        """ get list of memory patterns """
        return [
            already_in_line,
            start_sprinting,
            run_to_line
        ]

    return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
| 748 |
+
|
| 749 |
+
def ball_distance_close(obs, player_x, player_y):
    """ keeper controlled, ball not ours and very close: clear it (shot pattern) """
    def environment_fits(obs, player_x, player_y):
        """ fits when the goalkeeper is active, our team does not own the
        ball, and the ball is closer than 0.25 """
        if obs["active"] != 0 or obs['ball_owned_team'] == 0:
            return False
        return bool(get_distance2ball(obs) < 0.25)

    def get_memory_patterns(obs, player_x, player_y):
        """ get list of memory patterns """
        return [
            shot
        ]

    return {"environment_fits": environment_fits, "get_memory_patterns": get_memory_patterns}
|
| 767 |
+
|
| 768 |
+
# list of groups of memory patterns
groups_of_memory_patterns = [
    goal_keeper_far_pattern,  # safe fallback: keeper retreats toward goal
    goalkeeper_memory_patterns,  # goalkeeper has the ball
    # special_spot_shot,  # shooting pattern; disabled (did not work)
    special_game_modes_memory_patterns,  # special game modes
    ball_distance_2_5,
    ball_distance_close,
    # own_goal,
    # offence_memory_patterns,  # our team has the ball; disabled (did not work)
    defence_memory_patterns,
    other_memory_patterns  # idle
]
|
| 781 |
+
|
| 782 |
+
|
| 783 |
+
def keep_the_ball(obs, player_x, player_y):
    """ carry the ball forward, cutting diagonally or passing when blocked.

    Fix: removed unused locals from the original (`dist`, `closest`, `back`,
    `bottom_left`, `top_left` were computed and never read).
    """
    def environment_fits(obs, player_x, player_y):
        """ always fits """
        return True

    def get_action(obs, player_x, player_y):
        right_team, left_team = obs['right_team'], obs['left_team']
        # opponents directly ahead of the carrier (within 0.2 ahead, +-0.05 sideways)
        near = [i for i in right_team if (i[0] < player_x + 0.2) and (i[0] > player_x) and (i[1] > player_y - 0.05)
                and (i[1] < player_y + 0.05)]
        # teammates in the forward diagonals (possible short-pass targets)
        bottom_right = [i for i in left_team if (i[0] > player_x - 0.05) and (i[0] < player_x + 0.2) and (i[1] < player_y + 0.2) and
                        (i[1] > player_y)]
        top_right = [i for i in left_team if (i[0] > player_x - 0.05) and (i[0] < player_x + 0.2) and (i[1] > player_y - 0.2) and
                     (i[1] < player_y)]

        # path ahead is clear: keep running straight at goal
        if len(near) == 0:
            return Action.Right

        if player_y > 0:
            if player_y > 0.35:
                return Action.Right
            if len(bottom_right) > 0:
                # face the teammate first, then play the short pass
                if Action.BottomRight not in obs['sticky_actions']:
                    return Action.BottomRight
                return Action.ShortPass
            return Action.BottomRight

        if player_y < 0:
            if player_y < -0.35:
                return Action.Right
            if len(top_right) > 0:
                if Action.TopRight not in obs['sticky_actions']:
                    return Action.TopRight
                return Action.ShortPass
            return Action.TopRight
        # NOTE(review): when player_y == 0 exactly and an opponent is near,
        # control falls through and returns None (as in the original) —
        # confirm the caller tolerates a None action here.

    return {'environment_fits': environment_fits, 'get_action': get_action}
|
| 826 |
+
|
| 827 |
+
|
| 828 |
+
def spot_shot(obs, player_x, player_y):
    """ shot if close to the goalkeeper """
    def environment_fits(obs, player_x, player_y):
        """ fits inside the central shooting zone in front of the goal """
        return bool(player_x > 0.75 and abs(player_y) < 0.21)

    def get_action(obs, player_x, player_y):
        """ angle toward a corner first, then shoot """
        sticky = obs["sticky_actions"]
        if player_y >= 0:
            if Action.TopRight not in sticky:
                return Action.TopRight
        elif Action.BottomRight not in sticky:
            return Action.BottomRight
        return Action.Shot

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 849 |
+
|
| 850 |
+
|
| 851 |
+
def cross(obs, player_x, player_y):
    """ near the byline on a flank: turn infield and deliver a high cross """
    def environment_fits(obs, player_x, player_y):
        # deep in the attacking third and outside the central channel
        if player_x > 0.7 and abs(player_y) > 0.21:
            return True
        return False

    def get_action(obs, player_x, player_y):
        if player_x > 0.88:
            # face toward the centre before crossing
            if player_y > 0:
                if Action.Top not in obs['sticky_actions']:
                    return Action.Top
            else:
                if Action.Bottom not in obs['sticky_actions']:
                    return Action.Bottom
            return Action.HighPass

        # NOTE(review): this branch is unreachable — any player_x > 0.9 already
        # satisfied player_x > 0.88 above and returned. The two thresholds look
        # swapped; confirm intent before changing.
        if player_x > 0.9:
            if (Action.Right in obs['sticky_actions'] or
                    Action.TopRight in obs['sticky_actions'] or
                    Action.BottomRight in obs['sticky_actions']):
                return Action.ReleaseDirection
            if Action.Right not in obs['sticky_actions']:
                if player_y > 0:
                    if Action.Top not in obs['sticky_actions']:
                        return Action.Top
                if player_y < 0:
                    if Action.Bottom not in obs['sticky_actions']:
                        return Action.Bottom
            return Action.HighPass
        # NOTE(review): for 0.7 < player_x <= 0.88 this returns None implicitly
        # — confirm the caller tolerates a None action.

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 883 |
+
|
| 884 |
+
|
| 885 |
+
def close_to_goalkeeper_shot(obs, player_x, player_y):
    """ shot if close to the goalkeeper """
    def environment_fits(obs, player_x, player_y):
        """ fits when within 0.25 of the keeper's extrapolated position """
        # project the opposing keeper 13 steps ahead along his current velocity
        keeper_x = obs["right_team"][0][0] + obs["right_team_direction"][0][0] * 13
        keeper_y = obs["right_team"][0][1] + obs["right_team_direction"][0][1] * 13
        return bool(get_distance(player_x, player_y, [keeper_x, keeper_y]) < 0.25)

    def get_action(obs, player_x, player_y):
        """ angle toward a corner first, then shoot """
        sticky = obs["sticky_actions"]
        if player_y >= 0:
            if Action.TopRight not in sticky:
                return Action.TopRight
        elif Action.BottomRight not in sticky:
            return Action.BottomRight
        return Action.Shot

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 908 |
+
|
| 909 |
+
|
| 910 |
+
def long_pass_forward(obs, player_x, player_y):
    """ perform a high pass, if far from opponent's goal.

    Fix: removed unused locals from the original (`right_team = obs["right_team"][1:]`
    in environment_fits, `left_team` and `closest` in get_action) and hoisted
    the repeated np.min(dist) into one computation.
    """
    def environment_fits(obs, player_x, player_y):
        """ fits whenever the player is deep in our half (x < -0.4) """
        if player_x < -0.4:
            return True
        return False

    def get_action(obs, player_x, player_y):
        """ pick a direction or high pass based on the nearest opponent's distance """
        right_team = obs['right_team']
        dist = [get_distance(player_x, player_y, i) for i in right_team]
        nearest = np.min(dist)

        # near the sideline: face forward, then launch a high pass
        if abs(player_y) > 0.22:
            if Action.Right not in obs["sticky_actions"]:
                return Action.Right
            return Action.HighPass

        if nearest > 0.4:
            # plenty of space: drift toward the centre line
            if player_y > 0:
                return Action.Bottom
            else:
                return Action.Top

        if nearest < 0.4 and nearest > 0.2:
            # opponent at medium range: cut diagonally away from him
            if player_y < 0:
                return Action.TopRight
            else:
                return Action.BottomRight

        if nearest < 0.2:
            # opponent closing in: face forward, then clear with a high pass
            if Action.Right not in obs['sticky_actions']:
                return Action.Right
            return Action.HighPass
        # NOTE(review): nearest == 0.4 or == 0.2 exactly falls through and
        # returns None (as in the original) — confirm the caller tolerates it.

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 950 |
+
|
| 951 |
+
def shot(obs, player_x, player_y):
    """ unconditionally shoot """
    def environment_fits(obs, player_x, player_y):
        """ always fits """
        return True

    def get_action(obs, player_x, player_y):
        return Action.Shot

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 965 |
+
|
| 966 |
+
|
| 967 |
+
def own_goal_2(obs, player_x, player_y):
    """ always shoot; referenced by the own_goal pattern group """
    def environment_fits(obs, player_x, player_y):
        """ always fits """
        return True

    def get_action(obs, player_x, player_y):
        return Action.Shot

    return {"environment_fits": environment_fits, "get_action": get_action}
|
| 975 |
+
|
| 976 |
+
|
| 977 |
+
# @human_readable_agent wrapper modifies raw observations
|
| 978 |
+
# provided by the environment:
|
| 979 |
+
# https://github.com/google-research/football/blob/master/gfootball/doc/observation.md#raw-observations
|
| 980 |
+
# into a form easier to work with by humans.
|
| 981 |
+
# Following modifications are applied:
|
| 982 |
+
# - Action, PlayerRole and GameMode enums are introduced.
|
| 983 |
+
# - 'sticky_actions' are turned into a set of active actions (Action enum)
|
| 984 |
+
# see usage example below.
|
| 985 |
+
# - 'game_mode' is turned into GameMode enum.
|
| 986 |
+
# - 'designated' field is removed, as it always equals to 'active'
|
| 987 |
+
# when a single player is controlled on the team.
|
| 988 |
+
# - 'left_team_roles'/'right_team_roles' are turned into PlayerRole enums.
|
| 989 |
+
# - Action enum is to be returned by the agent function.
|
| 990 |
+
@human_readable_agent
def agent_get_action(obs):
    """ Ole ole ole ole """
    # scratch space for the memory-pattern machinery
    obs["memory_patterns"] = {}
    # We always control the left team (observations and actions are mirrored
    # appropriately by the environment), so index the active player there.
    me = obs["left_team"][obs["active"]]
    # delegate to the matching memory pattern and return its action
    return get_action_of_agent(obs, me[0], me[1])
|
openrl_policy.py
ADDED
|
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
# Copyright 2023 The OpenRL Authors.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
from torch.distributions import Categorical
|
| 22 |
+
|
| 23 |
+
import gym
|
| 24 |
+
|
| 25 |
+
def check(input):
    """ Convert a numpy array to a torch tensor; pass anything else through unchanged. """
    if type(input) == np.ndarray:
        return torch.from_numpy(input)
    return input
|
| 28 |
+
|
| 29 |
+
class FcEncoder(nn.Module):
    """ Stack of `fc_num` (Linear -> ReLU -> LayerNorm) blocks.

    The first block maps input_size -> output_size; the remaining
    fc_num - 1 blocks map output_size -> output_size.
    """
    def __init__(self, fc_num, input_size, output_size):
        super(FcEncoder, self).__init__()
        self.first_mlp = nn.Sequential(
            nn.Linear(input_size, output_size), nn.ReLU(), nn.LayerNorm(output_size)
        )
        hidden_blocks = [
            nn.Sequential(
                nn.Linear(output_size, output_size), nn.ReLU(), nn.LayerNorm(output_size)
            )
            for _ in range(fc_num - 1)
        ]
        self.mlp = nn.Sequential(*hidden_blocks)

    def forward(self, x):
        return self.mlp(self.first_mlp(x))
|
| 44 |
+
|
| 45 |
+
def init(module, weight_init, bias_init, gain=1):
    """ Apply weight_init (with gain) and bias_init to a module in place; return it. """
    weight_init(module.weight.data, gain=gain)
    has_bias = module.bias is not None
    if has_bias:
        bias_init(module.bias.data)
    return module
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class FixedCategorical(torch.distributions.Categorical):
    """ Categorical distribution whose samples and log-probs carry an
    explicit trailing action dimension of size 1 (shape (batch, 1)). """
    def sample(self):
        # append the action dimension dropped by the base class
        return super().sample().unsqueeze(-1)

    def log_probs(self, actions):
        # strip the trailing dim for the base log_prob, then restore it
        flat = super().log_prob(actions.squeeze(-1))
        summed = flat.view(actions.size(0), -1).sum(-1)
        return summed.unsqueeze(-1)

    def mode(self):
        # most probable action index, keeping the trailing dimension
        return self.probs.argmax(dim=-1, keepdim=True)
|
| 67 |
+
|
| 68 |
+
class Categorical(nn.Module):
    """ Linear action head that emits a FixedCategorical over discrete actions,
    with optional masking of unavailable actions. """
    def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):
        super(Categorical, self).__init__()
        # choose orthogonal or xavier weight init; biases start at zero
        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]

        def init_(m):
            return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain)

        self.linear = init_(nn.Linear(num_inputs, num_outputs))

    def forward(self, x, available_actions=None):
        logits = self.linear(x)
        if available_actions is not None:
            # unavailable actions get a huge negative logit (probability ~0)
            logits[available_actions == 0] = -1e10
        return FixedCategorical(logits=logits)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class AddBias(nn.Module):
    """ Adds a learned bias vector, broadcasting over 2-D (N, C) or
    4-D (N, C, H, W) inputs. """
    def __init__(self, bias):
        super(AddBias, self).__init__()
        self._bias = nn.Parameter(bias.unsqueeze(1))

    def forward(self, x):
        if x.dim() == 2:
            return x + self._bias.t().view(1, -1)
        # assumes 4-D input here (channel-wise broadcast) — other ranks
        # would broadcast differently
        return x + self._bias.t().view(1, -1, 1, 1)
|
| 96 |
+
|
| 97 |
+
class ACTLayer(nn.Module):
    """ Action head: maps actor features to actions and their log-probabilities.

    The mixed/multidiscrete/continuous flags are all initialized to False and
    never set True anywhere in this class, so only the final single-discrete
    branch (one `Categorical` head with optional action masking) is reachable
    as constructed here; the other branches (which reference an undefined
    `self.action_outs`) would only matter if the flags were toggled externally.
    """
    def __init__(self, action_space, inputs_dim, use_orthogonal, gain):
        super(ACTLayer, self).__init__()
        self.multidiscrete_action = False
        self.continuous_action = False
        self.mixed_action = False

        # assumes a Discrete-like action space exposing .n — TODO confirm
        action_dim = action_space.n
        self.action_out = Categorical(inputs_dim, action_dim, use_orthogonal, gain)

    def forward(self, x, available_actions=None, deterministic=False):
        """ Sample (or argmax, if deterministic) an action for features x.

        Returns (actions, action_log_probs).
        """
        if self.mixed_action:
            # unreachable as constructed (flag is always False)
            actions = []
            action_log_probs = []
            for action_out in self.action_outs:
                action_logit = action_out(x)
                action = action_logit.mode() if deterministic else action_logit.sample()
                action_log_prob = action_logit.log_probs(action)
                actions.append(action.float())
                action_log_probs.append(action_log_prob)

            actions = torch.cat(actions, -1)
            action_log_probs = torch.sum(torch.cat(action_log_probs, -1), -1, keepdim=True)

        elif self.multidiscrete_action:
            # unreachable as constructed (flag is always False)
            actions = []
            action_log_probs = []
            for action_out in self.action_outs:
                action_logit = action_out(x)
                action = action_logit.mode() if deterministic else action_logit.sample()
                action_log_prob = action_logit.log_probs(action)
                actions.append(action)
                action_log_probs.append(action_log_prob)

            actions = torch.cat(actions, -1)
            action_log_probs = torch.cat(action_log_probs, -1)

        elif self.continuous_action:
            # unreachable as constructed (flag is always False)
            action_logits = self.action_out(x)
            actions = action_logits.mode() if deterministic else action_logits.sample()
            action_log_probs = action_logits.log_probs(actions)

        else:
            # single discrete head with optional availability mask
            action_logits = self.action_out(x, available_actions)
            actions = action_logits.mode() if deterministic else action_logits.sample()
            action_log_probs = action_logits.log_probs(actions)

        return actions, action_log_probs

    def get_probs(self, x, available_actions=None):
        """ Return the per-action probability tensor for features x. """
        if self.mixed_action or self.multidiscrete_action:
            action_probs = []
            for action_out in self.action_outs:
                action_logit = action_out(x)
                action_prob = action_logit.probs
                action_probs.append(action_prob)
            action_probs = torch.cat(action_probs, -1)
        elif self.continuous_action:
            action_logits = self.action_out(x)
            action_probs = action_logits.probs
        else:
            action_logits = self.action_out(x, available_actions)
            action_probs = action_logits.probs

        return action_probs

    def evaluate_actions(self, x, action, available_actions=None, active_masks=None, get_probs=False):
        """ Log-probability and entropy of `action` under the current policy.

        Returns (action_log_probs, dist_entropy); with get_probs=True, also
        returns the distribution object. NOTE(review): in the mixed and
        multidiscrete branches `action_logits` is never assigned, so
        get_probs=True would raise NameError there — only an issue if those
        flags are ever enabled.
        """
        if self.mixed_action:
            # split the action into continuous (2) / discrete (1) components
            a, b = action.split((2, 1), -1)
            b = b.long()
            action = [a, b]
            action_log_probs = []
            dist_entropy = []
            for action_out, act in zip(self.action_outs, action):
                action_logit = action_out(x)
                action_log_probs.append(action_logit.log_probs(act))
                if active_masks is not None:
                    # average entropy over active agents only
                    if len(action_logit.entropy().shape) == len(active_masks.shape):
                        dist_entropy.append((action_logit.entropy() * active_masks).sum()/active_masks.sum())
                    else:
                        dist_entropy.append((action_logit.entropy() * active_masks.squeeze(-1)).sum()/active_masks.sum())
                else:
                    dist_entropy.append(action_logit.entropy().mean())

            action_log_probs = torch.sum(torch.cat(action_log_probs, -1), -1, keepdim=True)
            # fixed weighting of the continuous vs discrete entropy terms
            dist_entropy = dist_entropy[0] * 0.0025 + dist_entropy[1] * 0.01

        elif self.multidiscrete_action:
            action = torch.transpose(action, 0, 1)
            action_log_probs = []
            dist_entropy = []
            for action_out, act in zip(self.action_outs, action):
                action_logit = action_out(x)
                action_log_probs.append(action_logit.log_probs(act))
                if active_masks is not None:
                    dist_entropy.append((action_logit.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum())
                else:
                    dist_entropy.append(action_logit.entropy().mean())

            action_log_probs = torch.cat(action_log_probs, -1)  # ! could be wrong
            dist_entropy = torch.tensor(dist_entropy).mean()

        elif self.continuous_action:
            action_logits = self.action_out(x)
            action_log_probs = action_logits.log_probs(action)
            act_entropy = action_logits.entropy()
            if active_masks is not None:
                dist_entropy = (act_entropy*active_masks).sum()/active_masks.sum()
            else:
                dist_entropy = act_entropy.mean()

        else:
            # single discrete head (the reachable branch)
            action_logits = self.action_out(x, available_actions)
            action_log_probs = action_logits.log_probs(action)
            if active_masks is not None:
                dist_entropy = (action_logits.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum()
            else:
                dist_entropy = action_logits.entropy().mean()
        if not get_probs:
            return action_log_probs, dist_entropy
        else:
            return action_log_probs, dist_entropy, action_logits
|
| 222 |
+
|
| 223 |
+
class RNNLayer(nn.Module):
    """ GRU/LSTM recurrence with mask-aware hidden-state resets.

    For LSTM, the hidden and cell states travel packed together in a single
    tensor; rnn_forward splits them along the last dimension before calling
    the LSTM and concatenates them back afterwards.
    """
    def __init__(self, inputs_dim, outputs_dim, recurrent_N, use_orthogonal, rnn_type='gru'):
        super(RNNLayer, self).__init__()
        self._recurrent_N = recurrent_N
        self._use_orthogonal = use_orthogonal
        self.rnn_type = rnn_type
        if rnn_type == 'gru':
            self.rnn = nn.GRU(inputs_dim, outputs_dim, num_layers=self._recurrent_N)
        elif rnn_type == 'lstm':
            self.rnn = nn.LSTM(inputs_dim, outputs_dim, num_layers=self._recurrent_N)
        else:
            raise NotImplementedError(f'RNN type {rnn_type} has not been implemented.')

        # zero all biases; orthogonal (or xavier) init for all weight matrices
        for name, param in self.rnn.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0)
            elif 'weight' in name:
                if self._use_orthogonal:
                    nn.init.orthogonal_(param)
                else:
                    nn.init.xavier_uniform_(param)
        self.norm = nn.LayerNorm(outputs_dim)

    def rnn_forward(self, x, h):
        """ Run the recurrent core; for LSTM, unpack/re-pack (h, c) from one tensor. """
        if self.rnn_type == 'lstm':
            # packed state -> the (hidden, cell) tuple nn.LSTM expects
            h = torch.split(h, h.shape[-1] // 2, dim=-1)
            h = (h[0].contiguous(), h[1].contiguous())
        x_, h_ = self.rnn(x, h)
        if self.rnn_type == 'lstm':
            # re-pack (hidden, cell) into a single tensor
            h_ = torch.cat(h_, -1)
        return x_, h_

    def forward(self, x, hxs, masks):
        """ Apply the RNN over x with hidden states hxs, zeroing the state
        wherever masks == 0 (episode boundaries).

        Two layouts are handled:
        - rollout step: x is (N, -1) matching hxs' first dim; one RNN step.
        - training: x is (T * N, -1); it is unflattened to (T, N, -1) and
          processed in contiguous segments between mask zeros.
        Returns (LayerNorm'd features, new hidden states).
        """
        if x.size(0) == hxs.size(0):
            # single step: reset the hidden state where the episode restarted
            x, hxs = self.rnn_forward(x.unsqueeze(0), (hxs * masks.repeat(1, self._recurrent_N).unsqueeze(-1)).transpose(0, 1).contiguous())
            x = x.squeeze(0)
            hxs = hxs.transpose(0, 1)
        else:
            # x is a (T, N, -1) tensor that has been flatten to (T * N, -1)
            N = hxs.size(0)
            T = int(x.size(0) / N)

            # unflatten
            x = x.view(T, N, x.size(1))

            # Same deal with masks
            masks = masks.view(T, N)

            # Let's figure out which steps in the sequence have a zero for any agent
            # We will always assume t=0 has a zero in it as that makes the logic cleaner
            has_zeros = ((masks[1:] == 0.0)
                         .any(dim=-1)
                         .nonzero()
                         .squeeze()
                         .cpu())

            # +1 to correct the masks[1:]
            if has_zeros.dim() == 0:
                # Deal with scalar
                has_zeros = [has_zeros.item() + 1]
            else:
                has_zeros = (has_zeros + 1).numpy().tolist()

            # add t=0 and t=T to the list
            has_zeros = [0] + has_zeros + [T]

            hxs = hxs.transpose(0, 1)

            outputs = []
            for i in range(len(has_zeros) - 1):
                # We can now process steps that don't have any zeros in masks together!
                # This is much faster
                start_idx = has_zeros[i]
                end_idx = has_zeros[i + 1]
                # reset the state at each segment start according to its mask
                temp = (hxs * masks[start_idx].view(1, -1, 1).repeat(self._recurrent_N, 1, 1)).contiguous()
                rnn_scores, hxs = self.rnn_forward(x[start_idx:end_idx], temp)
                outputs.append(rnn_scores)

            # x is a (T, N, -1) tensor
            x = torch.cat(outputs, dim=0)

            # flatten
            x = x.reshape(T * N, -1)
            hxs = hxs.transpose(0, 1)

        x = self.norm(x)
        return x, hxs
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class InputEncoder(nn.Module):
    """Split the flat observation vector into its named segments and run
    each segment through its own fully-connected encoder.

    Segment layout (in order): active-player (87), ball-owner (57),
    left team (88), right team (88), match state (9).
    """

    def __init__(self):
        super(InputEncoder, self).__init__()
        num_fc_layers = 2
        encoded_width = 64
        # Widths of the consecutive segments of the input vector.
        self.active_input_num = 87
        self.ball_owner_input_num = 57
        self.left_input_num = 88
        self.right_input_num = 88
        self.match_state_input_num = 9

        self.active_encoder = FcEncoder(num_fc_layers, self.active_input_num, encoded_width)
        self.ball_owner_encoder = FcEncoder(num_fc_layers, self.ball_owner_input_num, encoded_width)
        self.left_encoder = FcEncoder(num_fc_layers, self.left_input_num, encoded_width)
        self.right_encoder = FcEncoder(num_fc_layers, self.right_input_num, encoded_width)
        # The match-state segment is encoded at its original width (9).
        self.match_state_encoder = FcEncoder(num_fc_layers, self.match_state_input_num, self.match_state_input_num)

    def forward(self, x):
        """Encode every segment of `x` and concatenate along the feature axis."""
        segment_widths = [
            self.active_input_num,
            self.ball_owner_input_num,
            self.left_input_num,
            self.right_input_num,
            self.match_state_input_num,
        ]
        segment_encoders = [
            self.active_encoder,
            self.ball_owner_encoder,
            self.left_encoder,
            self.right_encoder,
            self.match_state_encoder,
        ]
        encoded_segments = []
        offset = 0
        # Walk the input left-to-right, handing each slice to its encoder.
        for width, encoder in zip(segment_widths, segment_encoders):
            encoded_segments.append(encoder(x[:, offset:offset + width]))
            offset += width
        return torch.cat(encoded_segments, 1)
|
| 352 |
+
|
| 353 |
+
def get_fc(input_size, output_size):
    """Build a Linear -> ReLU -> LayerNorm block mapping `input_size` to `output_size`."""
    layers = [
        nn.Linear(input_size, output_size),
        nn.ReLU(),
        nn.LayerNorm(output_size),
    ]
    return nn.Sequential(*layers)
|
| 355 |
+
|
| 356 |
+
class ObsEncoder(nn.Module):
    """Observation pipeline: segment encoders -> embedding -> RNN -> MLP."""

    def __init__(self, input_embedding_size, hidden_size, _recurrent_N, _use_orthogonal, rnn_type):
        super(ObsEncoder, self).__init__()
        # Stage 1: per-segment feature extraction.
        self.input_encoder = InputEncoder()
        # Stage 2: embed the concatenated encoder outputs.
        self.input_embedding = get_fc(input_embedding_size, hidden_size)
        # Stage 3: recurrent core over the embedded features.
        self.rnn = RNNLayer(hidden_size, hidden_size, _recurrent_N, _use_orthogonal, rnn_type=rnn_type)
        # Stage 4: final MLP applied to the RNN output.
        self.after_rnn_mlp = get_fc(hidden_size, hidden_size)

    def forward(self, obs, rnn_states, masks):
        """Encode `obs`, update `rnn_states`, and return (features, rnn_states)."""
        embedded = self.input_embedding(self.input_encoder(obs))
        rnn_out, rnn_states = self.rnn(embedded, rnn_states, masks)
        return self.after_rnn_mlp(rnn_out), rnn_states
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
class PolicyNetwork(nn.Module):
    """Recurrent actor network.

    Expects observations laid out as ``[active_id | features]``: the first
    element is the controlled player's id (0..10), the remainder is the
    flat feature vector consumed by ``ObsEncoder``.
    """

    def __init__(self, device=torch.device("cpu")):
        super(PolicyNetwork, self).__init__()
        self.tpdv = dict(dtype=torch.float32, device=device)  # kwargs for `.to(...)` conversions
        self.device = device
        self.hidden_size = 256
        self._use_policy_active_masks = True
        recurrent_N = 1
        use_orthogonal = True
        rnn_type = 'lstm'
        gain = 0.01
        action_space = gym.spaces.Discrete(20)
        self.action_dim = 19
        input_embedding_size = 64 * 4 + 9  # four 64-wide segment encodings + 9 match-state features
        self.active_id_size = 1
        self.id_max = 11

        self.obs_encoder = ObsEncoder(input_embedding_size, self.hidden_size, recurrent_N, use_orthogonal, rnn_type)

        # Auxiliary head: predicts the active id from the RNN features plus
        # the taken action (only used in eval_actions).
        self.predict_id = get_fc(self.hidden_size + self.action_dim, self.id_max)
        # Embeds the one-hot active id before concatenation with the
        # observation features.
        self.id_embedding = get_fc(self.id_max, self.id_max)

        self.before_act_wrapper = FcEncoder(2, self.hidden_size + self.id_max, self.hidden_size)
        self.act = ACTLayer(action_space, self.hidden_size, use_orthogonal, gain)

        self.to(device)

    def forward(self, obs, rnn_states, masks=None, available_actions=None, deterministic=False):
        """Select an action for each row of ``obs``.

        Returns (actions, rnn_states). Fix over the original: the default
        for ``masks`` was a numpy array built at function-definition time
        (mutable default argument); a None sentinel now produces the same
        ones((1, 1), float32) mask freshly on every call.
        """
        if masks is None:
            masks = np.ones((1, 1), dtype=np.float32)
        obs = check(obs).to(**self.tpdv)
        if available_actions is not None:
            available_actions = check(available_actions).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)

        # Split the active player id off the front and one-hot encode it.
        active_id = obs[:, :self.active_id_size].squeeze(1).long()
        id_onehot = torch.eye(self.id_max)[active_id].to(self.device)
        obs = obs[:, self.active_id_size:]

        obs_output, rnn_states = self.obs_encoder(obs, rnn_states, masks)
        id_output = self.id_embedding(id_onehot)
        output = torch.cat([id_output, obs_output], 1)

        output = self.before_act_wrapper(output)

        actions, action_log_probs = self.act(output, available_actions, deterministic)
        return actions, rnn_states

    def eval_actions(self, obs, rnn_states, action, masks, available_actions=None, active_masks=None):
        """Evaluate given actions (training path).

        Returns (action_log_probs, dist_entropy, values, id_prediction,
        id_groundtruth); ``values`` is always None — this is an actor-only
        network.
        """
        obs = check(obs).to(**self.tpdv)
        if available_actions is not None:
            available_actions = check(available_actions).to(**self.tpdv)
        if active_masks is not None:
            active_masks = check(active_masks).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)
        action = check(action).to(**self.tpdv)

        id_groundtruth = obs[:, :self.active_id_size].squeeze(1).long()
        id_onehot = torch.eye(self.id_max)[id_groundtruth].to(self.device)
        obs = obs[:, self.active_id_size:]

        obs_output, rnn_states = self.obs_encoder(obs, rnn_states, masks)
        id_output = self.id_embedding(id_onehot)

        action_onehot = torch.eye(self.action_dim)[action.squeeze(1).long()].to(self.device)

        # Auxiliary prediction of the active id from features + action.
        id_prediction = self.predict_id(torch.cat([obs_output, action_onehot], 1))
        output = torch.cat([id_output, obs_output], 1)

        output = self.before_act_wrapper(output)
        action_log_probs, dist_entropy = self.act.evaluate_actions(
            output, action, available_actions,
            active_masks=active_masks if self._use_policy_active_masks else None)
        values = None
        return action_log_probs, dist_entropy, values, id_prediction, id_groundtruth
|
| 446 |
+
|
openrl_utils.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
# Copyright 2023 The OpenRL Authors.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
# Area.
# Field landmarks in gfootball's normalized pitch coordinates:
# x spans [-1, 1] (own goal to opponent goal), y spans [-0.42, 0.42].
THIRD_X = 0.3
BOX_X = 0.7
MAX_X = 1.0
BOX_Y = 0.24
MAX_Y = 0.42

# Actions.
# Indices into the 19-dim discrete action space used by this agent.
IDLE = 0
LEFT = 1
TOP_LEFT = 2
TOP = 3
TOP_RIGHT = 4
RIGHT = 5
BOTTOM_RIGHT = 6
BOTTOM = 7
BOTTOM_LEFT = 8
LONG_PASS = 9
HIGH_PASS = 10
SHORT_PASS = 11
SHOT = 12
SPRINT = 13
RELEASE_DIRECTION = 14
RELEASE_SPRINT = 15
SLIDING = 16
DRIBBLE = 17
RELEASE_DRIBBLE = 18
# Indices into obs["sticky_actions"] for the eight directional sticky flags.
STICKY_LEFT = 0
STICKY_TOP_LEFT = 1
STICKY_TOP = 2
STICKY_TOP_RIGHT = 3
STICKY_RIGHT = 4
STICKY_BOTTOM_RIGHT = 5
STICKY_BOTTOM = 6
STICKY_BOTTOM_LEFT = 7

# Direction groups used to forbid movement toward a pitch boundary
# (each list names the actions pointing "toward" that side, plus the
# two perpendicular directions).
RIGHT_ACTIONS = [TOP_RIGHT, RIGHT, BOTTOM_RIGHT, TOP, BOTTOM]
LEFT_ACTIONS = [TOP_LEFT, LEFT, BOTTOM_LEFT, TOP, BOTTOM]
BOTTOM_ACTIONS = [BOTTOM_LEFT, BOTTOM, BOTTOM_RIGHT, LEFT, RIGHT]
TOP_ACTIONS = [TOP_LEFT, TOP, TOP_RIGHT, LEFT, RIGHT]
# All eight movement actions, paired index-wise with the unit-step vectors
# below (note: negative y appears to point toward the top of the pitch).
ALL_DIRECTION_ACTIONS = [LEFT, TOP_LEFT, TOP, TOP_RIGHT, RIGHT, BOTTOM_RIGHT, BOTTOM, BOTTOM_LEFT]
ALL_DIRECTION_VECS = [(-1, 0), (-1, -1), (0, -1), (1, -1), (1, 0), (1, 1), (0, 1), (-1, 1)]
|
| 61 |
+
|
| 62 |
+
def get_direction_action(available_action, sticky_actions, forbidden_action, target_action, active_direction, need_sprint):
    """Build an availability mask that forces movement toward ``target_action``.

    If ``need_sprint`` is True and sprint is not yet sticky, the mask forces
    SPRINT first; if sprint is sticky but unwanted, it forces RELEASE_SPRINT.
    Otherwise only the indices in ``target_action`` are enabled.

    Note: the incoming ``available_action`` and ``active_direction``
    arguments are ignored by the original logic; both parameters are kept
    unchanged so existing call sites remain valid. The original line
    ``available_action[forbidden_action] = 0`` was a no-op on a freshly
    zeroed mask and has been dropped.

    Returns a fresh length-19 numpy array of 0/1 flags.
    """
    available_action = np.zeros(19)
    available_action[target_action] = 1

    if need_sprint:
        available_action[RELEASE_SPRINT] = 0
        if sticky_actions[8] == 0:
            # Sprint is not active yet: force the agent to press SPRINT
            # before doing anything else.
            available_action = np.zeros(19)
            available_action[SPRINT] = 1
    else:
        available_action[SPRINT] = 0
        if sticky_actions[8] == 1:
            # Sprint is active but unwanted: force RELEASE_SPRINT first.
            available_action = np.zeros(19)
            available_action[RELEASE_SPRINT] = 1
    return available_action
|
| 78 |
+
|
| 79 |
+
def openrl_obs_deal(obs):
    """Convert a raw gfootball observation dict into the network's inputs.

    Returns a dict with:
      - obs: (330,) per-agent observation = [active_id(1), active_info(87),
        ball_own_active_info(57), left_team(88), right_team(88), match_state(9)]
      - share_obs: (220,) shared observation = [ball_info(12),
        ball_owned_player(23), left_team(88), right_team(88), match_state(9)]
      - available_action: (19,) 0/1 action availability mask

    Fix over the original: the bottom-boundary check compared ``active_x``
    against a y-coordinate bound; it now correctly uses ``active_y``.
    """
    # Normalization bounds (empirical) for velocities and ball spin.
    direction_x_bound = 0.03
    direction_y_bound = 0.02
    ball_direction_x_bound = 0.15
    ball_direction_y_bound = 0.07
    ball_direction_z_bound = 4
    ball_rotation_x_bound = 0.0005
    ball_rotation_y_bound = 0.0004
    ball_rotation_z_bound = 0.015
    active_id = [obs["active"]]
    assert active_id[0] < 11 and active_id[0] >= 0, "active id is wrong, active id = {}".format(active_id[0])

    # left team 88 = positions(22) + directions(22) + tired(11) + yellow(11)
    #               + red(11) + offside(11)
    left_position = np.concatenate(obs["left_team"])
    left_direction = np.concatenate(obs["left_team_direction"])
    left_tired_factor = obs["left_team_tired_factor"]
    left_yellow_card = obs["left_team_yellow_card"]
    left_red_card = ~obs["left_team_active"]
    left_offside = np.zeros(11)
    if obs["ball_owned_team"] == 0:
        # Offside line: max of halfway line, ball x, and second-last defender.
        left_offside_line = max(0, obs["ball"][0], np.sort(obs["right_team"][:, 0])[-2])
        left_offside = obs["left_team"][:, 0] > left_offside_line
        # The ball carrier himself is never offside.
        left_offside[obs["ball_owned_player"]] = False

    # Normalize velocities: even indices are x components, odd are y.
    new_left_direction = left_direction.copy()
    for counting in range(len(new_left_direction)):
        new_left_direction[counting] = new_left_direction[counting] / direction_x_bound if counting % 2 == 0 else new_left_direction[counting] / direction_y_bound

    left_team = np.concatenate([
        left_position,
        new_left_direction,
        left_tired_factor,
        left_yellow_card,
        left_red_card,
        left_offside,
    ]).astype(np.float64)

    # right team 88 (mirror of the left-team block)
    right_position = np.concatenate(obs["right_team"])
    right_direction = np.concatenate(obs["right_team_direction"])
    right_tired_factor = obs["right_team_tired_factor"]
    right_yellow_card = obs["right_team_yellow_card"]
    right_red_card = ~obs["right_team_active"]
    right_offside = np.zeros(11)
    if obs["ball_owned_team"] == 1:
        right_offside_line = min(0, obs["ball"][0], np.sort(obs["left_team"][:, 0])[1])
        right_offside = obs["right_team"][:, 0] < right_offside_line
        right_offside[obs["ball_owned_player"]] = False

    new_right_direction = right_direction.copy()
    for counting in range(len(new_right_direction)):
        new_right_direction[counting] = new_right_direction[counting] / direction_x_bound if counting % 2 == 0 else new_right_direction[counting] / direction_y_bound

    right_team = np.concatenate([
        right_position,
        new_right_direction,
        right_tired_factor,
        right_yellow_card,
        right_red_card,
        right_offside,
    ]).astype(np.float64)

    # active player 18 = sticky(10) + pos(2) + dir(2) + tired/yellow/red/offside(4)
    sticky_actions = obs["sticky_actions"][:10]
    active_position = obs["left_team"][obs["active"]]
    active_direction = obs["left_team_direction"][obs["active"]]
    active_tired_factor = left_tired_factor[obs["active"]]
    active_yellow_card = left_yellow_card[obs["active"]]
    active_red_card = left_red_card[obs["active"]]
    active_offside = left_offside[obs["active"]]

    new_active_direction = active_direction.copy()
    new_active_direction[0] /= direction_x_bound
    new_active_direction[1] /= direction_y_bound

    active_player = np.concatenate([
        sticky_actions,
        active_position,
        new_active_direction,
        [active_tired_factor],
        [active_yellow_card],
        [active_red_card],
        [active_offside],
    ]).astype(np.float64)

    # relative geometry 69 = ball(2)+dist(1) + left(22)+dist(11) + right(22)+dist(11)
    relative_ball_position = obs["ball"][:2] - active_position
    distance2ball = np.linalg.norm(relative_ball_position)
    relative_left_position = obs["left_team"] - active_position
    distance2left = np.linalg.norm(relative_left_position, axis=1)
    relative_left_position = np.concatenate(relative_left_position)
    relative_right_position = obs["right_team"] - active_position
    distance2right = np.linalg.norm(relative_right_position, axis=1)
    relative_right_position = np.concatenate(relative_right_position)
    relative_info = np.concatenate([
        relative_ball_position,
        [distance2ball],
        relative_left_position,
        distance2left,
        relative_right_position,
        distance2right,
    ]).astype(np.float64)

    active_info = np.concatenate([active_player, relative_info])  # 87

    # ball info 12 = pos(3) + dir(3) + owner-team one-hot(3) + rotation(3)
    ball_owned_team = np.zeros(3)
    ball_owned_team[obs["ball_owned_team"] + 1] = 1.0
    new_ball_direction = obs["ball_direction"].copy()
    new_ball_rotation = obs['ball_rotation'].copy()
    for counting in range(len(new_ball_direction)):
        if counting % 3 == 0:
            new_ball_direction[counting] /= ball_direction_x_bound
            new_ball_rotation[counting] /= ball_rotation_x_bound
        if counting % 3 == 1:
            new_ball_direction[counting] /= ball_direction_y_bound
            new_ball_rotation[counting] /= ball_rotation_y_bound
        if counting % 3 == 2:
            new_ball_direction[counting] /= ball_direction_z_bound
            new_ball_rotation[counting] /= ball_rotation_z_bound
    ball_info = np.concatenate([
        obs["ball"],
        new_ball_direction,
        ball_owned_team,
        new_ball_rotation
    ]).astype(np.float64)

    # ball owned player 23 = one-hot over 11 left + 11 right + 1 "nobody"
    ball_owned_player = np.zeros(23)
    if obs["ball_owned_team"] == 1:  # opponent team owns the ball
        ball_owned_player[11 + obs['ball_owned_player']] = 1.0
        ball_owned_player_pos = obs['right_team'][obs['ball_owned_player']]
        ball_owned_player_direction = obs["right_team_direction"][obs['ball_owned_player']]
        ball_owner_tired_factor = right_tired_factor[obs['ball_owned_player']]
        ball_owner_yellow_card = right_yellow_card[obs['ball_owned_player']]
        ball_owner_red_card = right_red_card[obs['ball_owned_player']]
        ball_owner_offside = right_offside[obs['ball_owned_player']]
    elif obs["ball_owned_team"] == 0:
        ball_owned_player[obs['ball_owned_player']] = 1.0
        ball_owned_player_pos = obs['left_team'][obs['ball_owned_player']]
        ball_owned_player_direction = obs["left_team_direction"][obs['ball_owned_player']]
        ball_owner_tired_factor = left_tired_factor[obs['ball_owned_player']]
        ball_owner_yellow_card = left_yellow_card[obs['ball_owned_player']]
        ball_owner_red_card = left_red_card[obs['ball_owned_player']]
        ball_owner_offside = left_offside[obs['ball_owned_player']]
    else:
        ball_owned_player[-1] = 1.0
        ball_owned_player_pos = np.zeros(2)
        ball_owned_player_direction = np.zeros(2)

    relative_ball_owner_position = np.zeros(2)
    distance2ballowner = np.zeros(1)
    ball_owner_info = np.zeros(4)
    if obs["ball_owned_team"] != -1:
        relative_ball_owner_position = ball_owned_player_pos - active_position
        distance2ballowner = [np.linalg.norm(relative_ball_owner_position)]
        ball_owner_info = np.concatenate([
            [ball_owner_tired_factor],
            [ball_owner_yellow_card],
            [ball_owner_red_card],
            [ball_owner_offside]
        ])

    new_ball_owned_player_direction = ball_owned_player_direction.copy()
    new_ball_owned_player_direction[0] /= direction_x_bound
    new_ball_owned_player_direction[1] /= direction_y_bound

    ball_own_active_info = np.concatenate([
        ball_info,  # 12
        ball_owned_player,  # 23
        active_position,  # 2
        new_active_direction,  # 2
        [active_tired_factor],  # 1
        [active_yellow_card],  # 1
        [active_red_card],  # 1
        [active_offside],  # 1
        relative_ball_position,  # 2
        [distance2ball],  # 1
        ball_owned_player_pos,  # 2
        new_ball_owned_player_direction,  # 2
        relative_ball_owner_position,  # 2
        distance2ballowner,  # 1
        ball_owner_info  # 4
    ])

    # match state 9 = game-mode one-hot(7) + goal diff ratio + steps-left ratio
    game_mode = np.zeros(7)
    game_mode[obs["game_mode"]] = 1.0
    goal_diff_ratio = (obs["score"][0] - obs["score"][1]) / 5
    steps_left_ratio = obs["steps_left"] / 3001
    match_state = np.concatenate([
        game_mode,
        [goal_diff_ratio],
        [steps_left_ratio],
    ]).astype(np.float64)

    # available action mask — start permissive, then rule actions out.
    available_action = np.ones(19)
    available_action[IDLE] = 0
    available_action[RELEASE_DIRECTION] = 0
    should_left = False

    if obs["game_mode"] == 0:
        active_x = active_position[0]
        counting_right_enemy_num = 0
        counting_right_teammate_num = 0
        counting_left_teammate_num = 0
        for enemy_pos in obs["right_team"][1:]:
            if active_x < enemy_pos[0]:
                counting_right_enemy_num += 1
        for teammate_pos in obs["left_team"][1:]:
            if active_x < teammate_pos[0]:
                counting_right_teammate_num += 1
            if active_x > teammate_pos[0]:
                counting_left_teammate_num += 1

        # Track back: we are ahead of the ball with too few teammates
        # behind us and we don't own the ball.
        if active_x > obs['ball'][0] + 0.05:
            if counting_left_teammate_num < 2:
                if obs['ball_owned_team'] != 0:
                    should_left = True
    if should_left:
        available_action = get_direction_action(available_action, sticky_actions, RIGHT_ACTIONS, [LEFT, BOTTOM_LEFT, TOP_LEFT], active_direction, True)

    # Ball is far away: force sprinting in the direction best aligned
    # with the ball.
    if (abs(relative_ball_position[0]) > 0.75 or abs(relative_ball_position[1]) > 0.5):
        all_directions_vecs = [np.array(v) / np.linalg.norm(np.array(v)) for v in ALL_DIRECTION_VECS]
        best_direction = np.argmax([np.dot(relative_ball_position, v) for v in all_directions_vecs])
        target_direction = ALL_DIRECTION_ACTIONS[best_direction]
        forbidden_actions = ALL_DIRECTION_ACTIONS.copy()
        forbidden_actions.remove(target_direction)
        available_action = get_direction_action(available_action, sticky_actions, forbidden_actions, [target_direction], active_direction, True)

    # Keep the player inside the pitch; the margin is wider when holding
    # the ball.
    if_i_hold_ball = (obs["ball_owned_team"] == 0 and obs["ball_owned_player"] == obs['active'])
    ball_pos_offset = 0.05
    no_ball_pos_offset = 0.03

    active_x, active_y = active_position[0], active_position[1]
    if_outside = False
    if active_x <= (-1 + no_ball_pos_offset) or (if_i_hold_ball and active_x <= (-1 + ball_pos_offset)):
        if_outside = True
        action_index = LEFT_ACTIONS
        target_direction = RIGHT
    elif active_x >= (1 - no_ball_pos_offset) or (if_i_hold_ball and active_x >= (1 - ball_pos_offset)):
        if_outside = True
        action_index = RIGHT_ACTIONS
        target_direction = LEFT
    elif active_y >= (0.42 - no_ball_pos_offset) or (if_i_hold_ball and active_y >= (0.42 - ball_pos_offset)):
        if_outside = True
        action_index = BOTTOM_ACTIONS
        target_direction = TOP
    # BUGFIX: the original compared active_x against the y bound here.
    elif active_y <= (-0.42 + no_ball_pos_offset) or (if_i_hold_ball and active_y <= (-0.42 + ball_pos_offset)):
        if_outside = True
        action_index = TOP_ACTIONS
        target_direction = BOTTOM
    if obs["game_mode"] in [1, 2, 3, 4, 5]:
        # Set piece: if we are the closest player overall, let the kick
        # taker stand wherever the restart requires.
        left2ball = np.linalg.norm(obs["left_team"] - obs["ball"][:2], axis=1)
        right2ball = np.linalg.norm(obs["right_team"] - obs["ball"][:2], axis=1)
        if np.min(left2ball) < np.min(right2ball) and obs["active"] == np.argmin(left2ball):
            if_outside = False
    elif obs["game_mode"] in [6]:
        # Penalty in the opponent box: positioning restriction lifted.
        if obs["ball"][0] > 0 and active_position[0] > BOX_X:
            if_outside = False
    if if_outside:
        available_action = get_direction_action(available_action, sticky_actions, action_index, [target_direction], active_direction, False)

    # Sticky-action bookkeeping: releases only make sense when the
    # corresponding sticky flag is set, and vice versa.
    if np.sum(sticky_actions[:8]) == 0:
        available_action[RELEASE_DIRECTION] = 0
    if sticky_actions[8] == 0:
        available_action[RELEASE_SPRINT] = 0
    else:
        available_action[SPRINT] = 0
    if sticky_actions[9] == 0:
        available_action[RELEASE_DRIBBLE] = 0
    else:
        available_action[DRIBBLE] = 0
    # Only allow shooting from near the opponent goal.
    if active_position[0] < 0.4 or abs(active_position[1]) > 0.3:
        available_action[SHOT] = 0

    if obs["game_mode"] == 0:
        if obs["ball_owned_team"] == -1:
            available_action[DRIBBLE] = 0
            if distance2ball >= 0.05:
                available_action[SLIDING] = 0
                available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS, SHOT]] = 0
        elif obs["ball_owned_team"] == 0:
            available_action[SLIDING] = 0
            if distance2ball >= 0.05:
                available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS, SHOT, DRIBBLE]] = 0
        elif obs["ball_owned_team"] == 1:
            available_action[DRIBBLE] = 0
            if distance2ball >= 0.05:
                available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS, SHOT, SLIDING]] = 0
    elif obs["game_mode"] in [1, 2, 3, 4, 5]:
        left2ball = np.linalg.norm(obs["left_team"] - obs["ball"][:2], axis=1)
        right2ball = np.linalg.norm(obs["right_team"] - obs["ball"][:2], axis=1)
        if np.min(left2ball) < np.min(right2ball) and obs["active"] == np.argmin(left2ball):
            # We take the set piece: ball actions only.
            available_action[[SPRINT, RELEASE_SPRINT, SLIDING, DRIBBLE, RELEASE_DRIBBLE]] = 0
        else:
            available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS, SHOT]] = 0
            available_action[[SLIDING, DRIBBLE, RELEASE_DRIBBLE]] = 0
    elif obs["game_mode"] == 6:
        if obs["ball"][0] > 0 and active_position[0] > BOX_X:
            # Our penalty kick: shooting only.
            available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS]] = 0
            available_action[[SPRINT, RELEASE_SPRINT, SLIDING, DRIBBLE, RELEASE_DRIBBLE]] = 0
        else:
            available_action[[LONG_PASS, HIGH_PASS, SHORT_PASS, SHOT]] = 0
            available_action[[SLIDING, DRIBBLE, RELEASE_DRIBBLE]] = 0

    obs = np.concatenate([
        active_id,  # 1
        active_info,  # 87
        ball_own_active_info,  # 57
        left_team,  # 88
        right_team,  # 88
        match_state,  # 9
    ])

    share_obs = np.concatenate([
        ball_info,  # 12
        ball_owned_player,  # 23
        left_team,  # 88
        right_team,  # 88
        match_state,  # 9
    ])

    assert available_action.sum() > 0
    return dict(
        obs=obs,
        share_obs=share_obs,
        available_action=available_action,
    )
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def _t2n(x):
|
| 417 |
+
return x.detach().cpu().numpy()
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
|
submission.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
# Copyright 2023 The OpenRL Authors.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
""""""
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
base_dir = Path(__file__).resolve().parent
|
| 26 |
+
sys.path.append(str(base_dir))
|
| 27 |
+
|
| 28 |
+
from openrl_policy import PolicyNetwork
|
| 29 |
+
from openrl_utils import openrl_obs_deal, _t2n
|
| 30 |
+
from goal_keeper import agent_get_action
|
| 31 |
+
|
| 32 |
+
class OpenRLAgent():
    """Wraps the trained PolicyNetwork plus a scripted goal keeper.

    One recurrent hidden state is kept per controlled player index (0..10);
    index 0 is handled by the scripted goal-keeper agent instead of the
    network.
    """
    def __init__(self):
        # One recurrent state per player. 512 = 2 * hidden_size (256) —
        # presumably LSTM h and c stacked on the last axis; TODO confirm
        # against RNNLayer's expected state layout.
        rnn_shape = [1,1,1,512]
        self.rnn_hidden_state = [np.zeros(rnn_shape, dtype=np.float32) for _ in range (11)]
        self.model = PolicyNetwork()
        # Weights file 'actor.pt' must sit next to this source file.
        self.model.load_state_dict(torch.load( os.path.dirname(os.path.abspath(__file__)) + '/actor.pt', map_location=torch.device("cpu")))
        self.model.eval()

    def get_action(self,raw_obs,idx):
        """Return a one-hot action as a (1, 19) nested list for player `idx`."""
        if idx == 0:
            # Goal keeper: delegate to the scripted agent.
            re_action = [[0]*19]
            re_action_index = agent_get_action(raw_obs)[0]
            re_action[0][re_action_index] = 1
            return re_action

        openrl_obs = openrl_obs_deal(raw_obs)

        obs = openrl_obs['obs']
        # Collapse the leading axis: (1, 1, 330) -> (1, 330).
        obs = np.concatenate(obs.reshape(1, 1, 330))
        rnn_hidden_state = np.concatenate(self.rnn_hidden_state[idx])
        # The network's action space has 20 slots; pad the 19 env actions
        # with one trailing slot that is always unavailable.
        avail_actions = np.zeros(20)
        avail_actions[:19] = openrl_obs['available_action']
        avail_actions = np.concatenate(avail_actions.reshape([1, 1, 20]))
        with torch.no_grad():
            actions, rnn_hidden_state = self.model(obs, rnn_hidden_state, available_actions=avail_actions, deterministic=True)
        # NOTE(review): remaps action 17 -> 15 when the sprint sticky flag
        # is set (looks like dribble -> release-sprint); confirm this is
        # the intended override.
        if actions[0][0] == 17 and raw_obs["sticky_actions"][8] == 1:
            actions[0][0] = 15
        # Persist the updated recurrent state for this player.
        self.rnn_hidden_state[idx] = np.array(np.split(_t2n(rnn_hidden_state), 1))

        re_action = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
        re_action[0][actions[0]] = 1

        return re_action
|
| 65 |
+
|
| 66 |
+
# Module-level singleton so the model weights are loaded once per process.
agent = OpenRLAgent()
|
| 67 |
+
|
| 68 |
+
def my_controller(obs_list, action_space_list, is_act_continuous=False):
    """Platform entry point: route the observation to the shared agent.

    Removes 'controlled_player_index' from the observation dict and maps
    it onto the 0..10 player slot before delegating.
    """
    player_idx = obs_list.pop('controlled_player_index') % 11
    return agent.get_action(obs_list, player_idx)
|
| 73 |
+
|
| 74 |
+
def jidi_controller(obs_list=None):
    """Renamed wrapper around my_controller (kept to avoid loading clashes).

    Returns None when called without an observation; otherwise returns the
    nested-list action after sanity-checking its shape.
    """
    if obs_list is None:
        return
    result = my_controller(obs_list, None)
    assert isinstance(result, list)
    assert isinstance(result[0], list)
    return result
|