From 53bfdf596fbe63ca87319a959699992ad6a2ade6 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Tue, 21 Jun 2022 23:00:52 +0800 Subject: [PATCH] feature(nyz): add basic distar policy collect(ci skip) --- dizoo/distar/envs/__init__.py | 3 +- dizoo/distar/envs/fake_data.py | 10 +- dizoo/distar/envs/meta.py | 3 +- dizoo/distar/envs/stat.py | 789 +++++++++++++++++++++++++++ dizoo/distar/policy/distar_policy.py | 203 ++++--- dizoo/distar/policy/utils.py | 56 ++ 6 files changed, 993 insertions(+), 71 deletions(-) create mode 100644 dizoo/distar/envs/stat.py diff --git a/dizoo/distar/envs/__init__.py b/dizoo/distar/envs/__init__.py index 1c80372f03..be5986556a 100644 --- a/dizoo/distar/envs/__init__.py +++ b/dizoo/distar/envs/__init__.py @@ -1,3 +1,4 @@ from .meta import * -from .static_data import BEGIN_ACTIONS, ACTION_RACE_MASK, SELECTED_UNITS_MASK +from .static_data import BEGIN_ACTIONS, ACTION_RACE_MASK, SELECTED_UNITS_MASK, ACTIONS +from .stat import Stat from .fake_data import get_fake_rl_trajectory diff --git a/dizoo/distar/envs/fake_data.py b/dizoo/distar/envs/fake_data.py index ec910eaa72..9dfdf2c56a 100644 --- a/dizoo/distar/envs/fake_data.py +++ b/dizoo/distar/envs/fake_data.py @@ -2,7 +2,7 @@ import torch from ding.utils.data import default_collate -from .meta import MAX_DELAY, MAX_ENTITY_NUM, NUM_ACTIONS, ENTITY_TYPE_NUM, NUM_UPGRADES, NUM_CUMULATIVE_STAT_ACTIONS, \ +from .meta import MAX_DELAY, MAX_ENTITY_NUM, NUM_ACTIONS, NUM_UNIT_TYPES, NUM_UPGRADES, NUM_CUMULATIVE_STAT_ACTIONS, \ NUM_BEGINNING_ORDER_ACTIONS, NUM_UNIT_MIX_ABILITIES, NUM_QUEUE_ACTION, NUM_BUFFS, NUM_ADDON, MAX_SELECTED_UNITS_NUM H, W = 152, 160 @@ -28,7 +28,7 @@ def spatial_info(): def entity_info(): data = { - 'unit_type': torch.randint(0, ENTITY_TYPE_NUM, size=(MAX_ENTITY_NUM, ), dtype=torch.float), + 'unit_type': torch.randint(0, NUM_UNIT_TYPES, size=(MAX_ENTITY_NUM, ), dtype=torch.float), 'alliance': torch.randint(0, 5, size=(MAX_ENTITY_NUM, ), dtype=torch.float), 'cargo_space_taken': 
torch.randint(0, 9, size=(MAX_ENTITY_NUM, ), dtype=torch.float), 'build_progress': torch.rand(MAX_ENTITY_NUM), @@ -74,7 +74,7 @@ def scalar_info(): 'away_race': torch.randint(0, 4, size=(), dtype=torch.float), 'agent_statistics': torch.rand(10), 'time': torch.randint(0, 100, size=(), dtype=torch.float), - 'unit_counts_bow': torch.randint(0, 10, size=(ENTITY_TYPE_NUM, ), dtype=torch.float), + 'unit_counts_bow': torch.randint(0, 10, size=(NUM_UNIT_TYPES, ), dtype=torch.float), 'beginning_build_order': torch.randint(0, 20, size=(20, ), dtype=torch.float), 'cumulative_stat': torch.randint(0, 2, size=(NUM_CUMULATIVE_STAT_ACTIONS, ), dtype=torch.float), 'last_delay': torch.randint(0, MAX_DELAY, size=(), dtype=torch.float), @@ -83,8 +83,8 @@ def scalar_info(): 'upgrades': torch.randint(0, 2, size=(NUM_UPGRADES, ), dtype=torch.float), 'beginning_order': torch.randint(0, NUM_BEGINNING_ORDER_ACTIONS, size=(20, ), dtype=torch.float), 'bo_location': torch.randint(0, 100 * 100, size=(20, ), dtype=torch.float), - 'unit_type_bool': torch.randint(0, 2, size=(ENTITY_TYPE_NUM, ), dtype=torch.float), - 'enemy_unit_type_bool': torch.randint(0, 2, size=(ENTITY_TYPE_NUM, ), dtype=torch.float), + 'unit_type_bool': torch.randint(0, 2, size=(NUM_UNIT_TYPES, ), dtype=torch.float), + 'enemy_unit_type_bool': torch.randint(0, 2, size=(NUM_UNIT_TYPES, ), dtype=torch.float), 'unit_order_type': torch.randint(0, 2, size=(NUM_UNIT_MIX_ABILITIES, ), dtype=torch.float) } return data diff --git a/dizoo/distar/envs/meta.py b/dizoo/distar/envs/meta.py index 7cae00865e..bb11787f79 100644 --- a/dizoo/distar/envs/meta.py +++ b/dizoo/distar/envs/meta.py @@ -1,7 +1,7 @@ MAX_DELAY = 128 MAX_ENTITY_NUM = 512 MAX_SELECTED_UNITS_NUM = 64 -ENTITY_TYPE_NUM = 260 +NUM_UNIT_TYPES = 260 NUM_ACTIONS = 327 NUM_UPGRADES = 90 NUM_CUMULATIVE_STAT_ACTIONS = 167 @@ -10,4 +10,3 @@ NUM_QUEUE_ACTION = 49 NUM_BUFFS = 50 NUM_ADDON = 9 -MAX_SELECTED_UNITS_NUM = 64 diff --git a/dizoo/distar/envs/stat.py b/dizoo/distar/envs/stat.py 
new file mode 100644 index 0000000000..cebc3a028c --- /dev/null +++ b/dizoo/distar/envs/stat.py @@ -0,0 +1,789 @@ +from collections import defaultdict +import torch +from .static_data import ACTIONS + + +class Stat(object): + + def __init__(self, race_id): + self._unit_num = defaultdict(int) + self._unit_num['max_unit_num'] = 0 + self._race_id = race_id + for k, v in unit_dict[race_id].items(): + self._unit_num[v] = 0 + self._action_success_count = defaultdict(int) + + def update(self, last_action_type, action_result, observation, game_step): + if action_result < 1: + return + if action_result == 1: + self.count_unit_num(last_action_type) + entity_info, entity_num = observation['entity_info'], observation['entity_num'] + try: + if (entity_info['alliance'][:entity_num] == 1).sum() > 10: + self.success_rate_calc(last_action_type, action_result) + except Exception as e: + print('ERROR_ stat.py', e, entity_info['alliance'], entity_num) + + def success_rate_calc(self, last_action_type, action_result): + action_name = ACTIONS[last_action_type]['name'] + error_msg = action_result_dict[action_result] + self._action_success_count['rate/{}/{}'.format(action_name, error_msg)] += 1 + self._action_success_count['rate/{}/{}'.format(action_name, 'count')] += 1 + + def get_stat_data(self): + data = {} + for k, v in self._unit_num.items(): + if k != 'max_unit_num': + data['units/' + k] = v / self._unit_num['max_unit_num'] + for k, v in self._action_success_count.items(): + action_type = k.split('rate/')[1].split('/')[0] + if 'count' in k: + data[k] = v + else: + data[k] = v / (self._action_success_count['rate/{}/{}'.format(action_type, 'count')] + 1e-6) + return data + + def count_unit_num(self, last_action_type): + unit_name = self.get_build_unit_name(last_action_type, self._race_id) + if not unit_name: + return + self._unit_num[unit_name] += 1 + self._unit_num['max_unit_num'] = max(self._unit_num[unit_name], self._unit_num['max_unit_num']) + + @staticmethod + def 
get_build_unit_name(action_type, race_id): + action_type = ACTIONS[action_type]['func_id'] + unit_name = unit_dict[race_id].get(action_type, False) + return unit_name + + def set_race_id(self, race_id: int): + self._race_id = race_id + + @property + def unit_num(self): + return self._unit_num + + +unit_dict = { + 'zerg': { + 383: 'BroodLord', + 391: 'Lurker', + 395: 'OverlordTransport', + 396: 'Overseer', + 400: 'Ravager', + 498: 'Baneling', + 501: 'Corruptor', + 503: 'Drone', + 507: 'Hydralisk', + 508: 'Infestor', + 514: 'Mutalisk', + 515: 'Overlord', + 516: 'Queen', + 519: 'Roach', + 522: 'SwarmHost', + 524: 'Ultralisk', + 526: 'Viper', + 528: 'Zergling' + }, + 'terran': { + 499: 'Banshee', + 500: 'Battlecruiser', + 502: 'Cyclone', + 504: 'Ghost', + 505: 'Hellbat', + 506: 'Hellion', + 509: 'Liberator', + 510: 'Marauder', + 511: 'Marine', + 512: 'Medivac', + 517: 'Raven', + 518: 'Reaper', + 520: 'SCV', + 521: 'SiegeTank', + 523: 'Thor', + 525: 'VikingFighter', + 527: 'WidowMine' + }, + 'protoss': { + 86: 'Archon', + 393: 'Mothership', + 54: 'Adept', + 56: 'Carrier', + 62: 'Colossus', + 52: 'DarkTemplar', + 166: 'Disruptor', + 51: 'HighTemplar', + 63: 'Immortal', + 513: 'MothershipCore', + 21: 'Mothership', + 61: 'Observer', + 58: 'Oracle', + 55: 'Phoenix', + 64: 'Probe', + 53: 'Sentry', + 50: 'Stalker', + 59: 'Tempest', + 57: 'VoidRay', + 76: 'Adept', + 74: 'DarkTemplar', + 73: 'HighTemplar', + 60: 'WarpPrism', + 75: 'Sentry', + 72: 'Stalker', + 71: 'Zealot', + 49: 'Zealot' + } +} + +cum_dict = [ + { + 'race': ['zerg', 'terran', 'protoss'], + 'name': 'no_op' + }, { + 'race': ['terran'], + 'name': 'Armory' + }, { + 'race': ['protoss'], + 'name': 'Assimilator' + }, { + 'race': ['zerg'], + 'name': 'BanelingNest' + }, { + 'race': ['terran'], + 'name': 'Barracks' + }, { + 'race': ['terran'], + 'name': 'CommandCenter' + }, { + 'race': ['protoss'], + 'name': 'CyberneticsCore' + }, { + 'race': ['protoss'], + 'name': 'DarkShrine' + }, { + 'race': ['terran'], + 'name': 
'EngineeringBay' + }, { + 'race': ['zerg'], + 'name': 'EvolutionChamber' + }, { + 'race': ['zerg'], + 'name': 'Extractor' + }, { + 'race': ['terran'], + 'name': 'Factory' + }, { + 'race': ['protoss'], + 'name': 'FleetBeacon' + }, { + 'race': ['protoss'], + 'name': 'Forge' + }, { + 'race': ['terran'], + 'name': 'FusionCore' + }, { + 'race': ['protoss'], + 'name': 'Gateway' + }, { + 'race': ['terran'], + 'name': 'GhostAcademy' + }, { + 'race': ['zerg'], + 'name': 'Hatchery' + }, { + 'race': ['zerg'], + 'name': 'HydraliskDen' + }, { + 'race': ['zerg'], + 'name': 'InfestationPit' + }, { + 'race': ['protoss'], + 'name': 'Interceptors' + }, { + 'race': ['protoss'], + 'name': 'Interceptors' + }, { + 'race': ['zerg'], + 'name': 'LurkerDen' + }, { + 'race': ['protoss'], + 'name': 'Nexus' + }, { + 'race': ['terran'], + 'name': 'Nuke' + }, { + 'race': ['zerg'], + 'name': 'NydusNetwork' + }, { + 'race': ['zerg'], + 'name': 'NydusWorm' + }, { + 'race': ['terran'], + 'name': 'Reactor' + }, { + 'race': ['terran'], + 'name': 'Reactor' + }, { + 'race': ['terran'], + 'name': 'Refinery' + }, { + 'race': ['zerg'], + 'name': 'RoachWarren' + }, { + 'race': ['protoss'], + 'name': 'RoboticsBay' + }, { + 'race': ['protoss'], + 'name': 'RoboticsFacility' + }, { + 'race': ['terran'], + 'name': 'SensorTower' + }, { + 'race': ['zerg'], + 'name': 'SpawningPool' + }, { + 'race': ['zerg'], + 'name': 'Spire' + }, { + 'race': ['protoss'], + 'name': 'Stargate' + }, { + 'race': ['terran'], + 'name': 'Starport' + }, { + 'race': ['protoss'], + 'name': 'StasisTrap' + }, { + 'race': ['terran'], + 'name': 'TechLab' + }, { + 'race': ['terran'], + 'name': 'TechLab' + }, { + 'race': ['protoss'], + 'name': 'TemplarArchive' + }, { + 'race': ['protoss'], + 'name': 'TwilightCouncil' + }, { + 'race': ['zerg'], + 'name': 'UltraliskCavern' + }, { + 'race': ['protoss'], + 'name': 'Archon' + }, { + 'race': ['zerg'], + 'name': 'BroodLord' + }, { + 'race': ['zerg'], + 'name': 'GreaterSpire' + }, { + 'race': ['zerg'], + 
'name': 'Hive' + }, { + 'race': ['zerg'], + 'name': 'Lair' + }, { + 'race': ['zerg'], + 'name': 'LurkerDen' + }, { + 'race': ['zerg'], + 'name': 'Lurker' + }, { + 'race': ['protoss'], + 'name': 'Mothership' + }, { + 'race': ['terran'], + 'name': 'OrbitalCommand' + }, { + 'race': ['zerg'], + 'name': 'OverlordTransport' + }, { + 'race': ['terran'], + 'name': 'PlanetaryFortress' + }, { + 'race': ['zerg'], + 'name': 'Ravager' + }, { + 'race': ['zerg'], + 'name': 'Research_AdaptiveTalons' + }, { + 'race': ['protoss'], + 'name': 'Research_AdeptResonatingGlaives' + }, { + 'race': ['terran'], + 'name': 'Research_AdvancedBallistics' + }, { + 'race': ['zerg'], + 'name': 'Research_AnabolicSynthesis' + }, { + 'race': ['terran'], + 'name': 'Research_BansheeCloakingField' + }, { + 'race': ['terran'], + 'name': 'Research_BansheeHyperflightRotors' + }, { + 'race': ['terran'], + 'name': 'Research_BattlecruiserWeaponRefit' + }, { + 'race': ['protoss'], + 'name': 'Research_Blink' + }, { + 'race': ['zerg'], + 'name': 'Research_Burrow' + }, { + 'race': ['zerg'], + 'name': 'Research_CentrifugalHooks' + }, { + 'race': ['protoss'], + 'name': 'Research_Charge' + }, { + 'race': ['zerg'], + 'name': 'Research_ChitinousPlating' + }, { + 'race': ['terran'], + 'name': 'Research_CombatShield' + }, { + 'race': ['terran'], + 'name': 'Research_ConcussiveShells' + }, { + 'race': ['terran'], + 'name': 'Research_CycloneLockOnDamage' + }, { + 'race': ['terran'], + 'name': 'Research_CycloneRapidFireLaunchers' + }, { + 'race': ['terran'], + 'name': 'Research_DrillingClaws' + }, { + 'race': ['terran'], + 'name': 'Research_EnhancedShockwaves' + }, { + 'race': ['protoss'], + 'name': 'Research_ExtendedThermalLance' + }, { + 'race': ['zerg'], + 'name': 'Research_GlialRegeneration' + }, { + 'race': ['protoss'], + 'name': 'Research_GraviticBooster' + }, { + 'race': ['protoss'], + 'name': 'Research_GraviticDrive' + }, { + 'race': ['zerg'], + 'name': 'Research_GroovedSpines' + }, { + 'race': ['terran'], + 'name': 
'Research_HighCapacityFuelTanks' + }, { + 'race': ['terran'], + 'name': 'Research_HiSecAutoTracking' + }, { + 'race': ['terran'], + 'name': 'Research_InfernalPreigniter' + }, { + 'race': ['protoss'], + 'name': 'Research_InterceptorGravitonCatapult' + }, { + 'race': ['zerg'], + 'name': 'Research_MuscularAugments' + }, { + 'race': ['terran'], + 'name': 'Research_NeosteelFrame' + }, { + 'race': ['zerg'], + 'name': 'Research_NeuralParasite' + }, { + 'race': ['zerg'], + 'name': 'Research_PathogenGlands' + }, { + 'race': ['terran'], + 'name': 'Research_PersonalCloaking' + }, { + 'race': ['protoss'], + 'name': 'Research_PhoenixAnionPulseCrystals' + }, { + 'race': ['zerg'], + 'name': 'Research_PneumatizedCarapace' + }, { + 'race': ['protoss'], + 'name': 'Research_ProtossAirArmor' + }, { + 'race': ['protoss'], + 'name': 'Research_ProtossAirWeapons' + }, { + 'race': ['protoss'], + 'name': 'Research_ProtossGroundArmor' + }, { + 'race': ['protoss'], + 'name': 'Research_ProtossGroundWeapons' + }, { + 'race': ['protoss'], + 'name': 'Research_ProtossShields' + }, { + 'race': ['protoss'], + 'name': 'Research_PsiStorm' + }, { + 'race': ['terran'], + 'name': 'Research_RavenCorvidReactor' + }, { + 'race': ['terran'], + 'name': 'Research_RavenRecalibratedExplosives' + }, { + 'race': ['protoss'], + 'name': 'Research_ShadowStrike' + }, { + 'race': ['terran'], + 'name': 'Research_SmartServos' + }, { + 'race': ['terran'], + 'name': 'Research_Stimpack' + }, { + 'race': ['terran'], + 'name': 'Research_TerranInfantryArmor' + }, { + 'race': ['terran'], + 'name': 'Research_TerranInfantryWeapons' + }, { + 'race': ['terran'], + 'name': 'Research_TerranShipWeapons' + }, { + 'race': ['terran'], + 'name': 'Research_TerranStructureArmorUpgrade' + }, { + 'race': ['terran'], + 'name': 'Research_TerranVehicleAndShipPlating' + }, { + 'race': ['terran'], + 'name': 'Research_TerranVehicleWeapons' + }, { + 'race': ['zerg'], + 'name': 'Research_TunnelingClaws' + }, { + 'race': ['protoss'], + 'name': 
'Research_WarpGate' + }, { + 'race': ['zerg'], + 'name': 'Research_ZergFlyerArmor' + }, { + 'race': ['zerg'], + 'name': 'Research_ZergFlyerAttack' + }, { + 'race': ['zerg'], + 'name': 'Research_ZergGroundArmor' + }, { + 'race': ['zerg'], + 'name': 'Research_ZerglingAdrenalGlands' + }, { + 'race': ['zerg'], + 'name': 'Research_ZerglingMetabolicBoost' + }, { + 'race': ['zerg'], + 'name': 'Research_ZergMeleeWeapons' + }, { + 'race': ['zerg'], + 'name': 'Research_ZergMissileWeapons' + }, { + 'race': ['protoss'], + 'name': 'Adept' + }, { + 'race': ['zerg'], + 'name': 'Baneling' + }, { + 'race': ['terran'], + 'name': 'Banshee' + }, { + 'race': ['terran'], + 'name': 'Battlecruiser' + }, { + 'race': ['protoss'], + 'name': 'Carrier' + }, { + 'race': ['protoss'], + 'name': 'Colossus' + }, { + 'race': ['zerg'], + 'name': 'Corruptor' + }, { + 'race': ['terran'], + 'name': 'Cyclone' + }, { + 'race': ['protoss'], + 'name': 'DarkTemplar' + }, { + 'race': ['protoss'], + 'name': 'Disruptor' + }, { + 'race': ['terran'], + 'name': 'Ghost' + }, { + 'race': ['terran'], + 'name': 'Hellbat' + }, { + 'race': ['terran'], + 'name': 'Hellion' + }, { + 'race': ['protoss'], + 'name': 'HighTemplar' + }, { + 'race': ['zerg'], + 'name': 'Hydralisk' + }, { + 'race': ['protoss'], + 'name': 'Immortal' + }, { + 'race': ['zerg'], + 'name': 'Infestor' + }, { + 'race': ['terran'], + 'name': 'Liberator' + }, { + 'race': ['terran'], + 'name': 'Marauder' + }, { + 'race': ['terran'], + 'name': 'Marine' + }, { + 'race': ['terran'], + 'name': 'Medivac' + }, { + 'race': ['protoss'], + 'name': 'MothershipCore' + }, { + 'race': ['protoss'], + 'name': 'Mothership' + }, { + 'race': ['zerg'], + 'name': 'Mutalisk' + }, { + 'race': ['protoss'], + 'name': 'Observer' + }, { + 'race': ['protoss'], + 'name': 'Oracle' + }, { + 'race': ['protoss'], + 'name': 'Phoenix' + }, { + 'race': ['zerg'], + 'name': 'Queen' + }, { + 'race': ['terran'], + 'name': 'Raven' + }, { + 'race': ['terran'], + 'name': 'Reaper' + }, { + 'race': 
['zerg'], + 'name': 'Roach' + }, { + 'race': ['protoss'], + 'name': 'Sentry' + }, { + 'race': ['terran'], + 'name': 'SiegeTank' + }, { + 'race': ['protoss'], + 'name': 'Stalker' + }, { + 'race': ['zerg'], + 'name': 'SwarmHost' + }, { + 'race': ['protoss'], + 'name': 'Tempest' + }, { + 'race': ['terran'], + 'name': 'Thor' + }, { + 'race': ['zerg'], + 'name': 'Ultralisk' + }, { + 'race': ['terran'], + 'name': 'VikingFighter' + }, { + 'race': ['zerg'], + 'name': 'Viper' + }, { + 'race': ['protoss'], + 'name': 'VoidRay' + }, { + 'race': ['protoss'], + 'name': 'Adept' + }, { + 'race': ['protoss'], + 'name': 'DarkTemplar' + }, { + 'race': ['protoss'], + 'name': 'HighTemplar' + }, { + 'race': ['protoss'], + 'name': 'WarpPrism' + }, { + 'race': ['protoss'], + 'name': 'Sentry' + }, { + 'race': ['protoss'], + 'name': 'Stalker' + }, { + 'race': ['protoss'], + 'name': 'Zealot' + }, { + 'race': ['terran'], + 'name': 'WidowMine' + }, { + 'race': ['protoss'], + 'name': 'Zealot' + }, { + 'race': ['zerg'], + 'name': 'Zergling' + } +] + +action_result_dict = [ + '', 'Success', 'ERROR_NotSupported', 'ERROR_Error', 'ERROR_CantQueueThatOrder', 'ERROR_Retry', 'ERROR_Cooldown', + 'ERROR_QueueIsFull', 'ERROR_RallyQueueIsFull', 'ERROR_NotEnoughMinerals', 'ERROR_NotEnoughVespene', + 'ERROR_NotEnoughTerrazine', 'ERROR_NotEnoughCustom', 'ERROR_NotEnoughFood', 'ERROR_FoodUsageImpossible', + 'ERROR_NotEnoughLife', 'ERROR_NotEnoughShields', 'ERROR_NotEnoughEnergy', 'ERROR_LifeSuppressed', + 'ERROR_ShieldsSuppressed', 'ERROR_EnergySuppressed', 'ERROR_NotEnoughCharges', 'ERROR_CantAddMoreCharges', + 'ERROR_TooMuchMinerals', 'ERROR_TooMuchVespene', 'ERROR_TooMuchTerrazine', 'ERROR_TooMuchCustom', + 'ERROR_TooMuchFood', 'ERROR_TooMuchLife', 'ERROR_TooMuchShields', 'ERROR_TooMuchEnergy', + 'ERROR_MustTargetUnitWithLife', 'ERROR_MustTargetUnitWithShields', 'ERROR_MustTargetUnitWithEnergy', + 'ERROR_CantTrade', 'ERROR_CantSpend', 'ERROR_CantTargetThatUnit', 'ERROR_CouldntAllocateUnit', 
'ERROR_UnitCantMove', + 'ERROR_TransportIsHoldingPosition', 'ERROR_BuildTechRequirementsNotMet', 'ERROR_CantFindPlacementLocation', + 'ERROR_CantBuildOnThat', 'ERROR_CantBuildTooCloseToDropOff', 'ERROR_CantBuildLocationInvalid', + 'ERROR_CantSeeBuildLocation', 'ERROR_CantBuildTooCloseToCreepSource', 'ERROR_CantBuildTooCloseToResources', + 'ERROR_CantBuildTooFarFromWater', 'ERROR_CantBuildTooFarFromCreepSource', + 'ERROR_CantBuildTooFarFromBuildPowerSource', 'ERROR_CantBuildOnDenseTerrain', + 'ERROR_CantTrainTooFarFromTrainPowerSource', 'ERROR_CantLandLocationInvalid', 'ERROR_CantSeeLandLocation', + 'ERROR_CantLandTooCloseToCreepSource', 'ERROR_CantLandTooCloseToResources', 'ERROR_CantLandTooFarFromWater', + 'ERROR_CantLandTooFarFromCreepSource', 'ERROR_CantLandTooFarFromBuildPowerSource', + 'ERROR_CantLandTooFarFromTrainPowerSource', 'ERROR_CantLandOnDenseTerrain', 'ERROR_AddOnTooFarFromBuilding', + 'ERROR_MustBuildRefineryFirst', 'ERROR_BuildingIsUnderConstruction', 'ERROR_CantFindDropOff', + 'ERROR_CantLoadOtherPlayersUnits', 'ERROR_NotEnoughRoomToLoadUnit', 'ERROR_CantUnloadUnitsThere', + 'ERROR_CantWarpInUnitsThere', 'ERROR_CantLoadImmobileUnits', 'ERROR_CantRechargeImmobileUnits', + 'ERROR_CantRechargeUnderConstructionUnits', 'ERROR_CantLoadThatUnit', 'ERROR_NoCargoToUnload', + 'ERROR_LoadAllNoTargetsFound', 'ERROR_NotWhileOccupied', 'ERROR_CantAttackWithoutAmmo', 'ERROR_CantHoldAnyMoreAmmo', + 'ERROR_TechRequirementsNotMet', 'ERROR_MustLockdownUnitFirst', 'ERROR_MustTargetUnit', 'ERROR_MustTargetInventory', + 'ERROR_MustTargetVisibleUnit', 'ERROR_MustTargetVisibleLocation', 'ERROR_MustTargetWalkableLocation', + 'ERROR_MustTargetPawnableUnit', 'ERROR_YouCantControlThatUnit', 'ERROR_YouCantIssueCommandsToThatUnit', + 'ERROR_MustTargetResources', 'ERROR_RequiresHealTarget', 'ERROR_RequiresRepairTarget', 'ERROR_NoItemsToDrop', + 'ERROR_CantHoldAnyMoreItems', 'ERROR_CantHoldThat', 'ERROR_TargetHasNoInventory', 'ERROR_CantDropThisItem', + 'ERROR_CantMoveThisItem', 
'ERROR_CantPawnThisUnit', 'ERROR_MustTargetCaster', 'ERROR_CantTargetCaster', + 'ERROR_MustTargetOuter', 'ERROR_CantTargetOuter', 'ERROR_MustTargetYourOwnUnits', 'ERROR_CantTargetYourOwnUnits', + 'ERROR_MustTargetFriendlyUnits', 'ERROR_CantTargetFriendlyUnits', 'ERROR_MustTargetNeutralUnits', + 'ERROR_CantTargetNeutralUnits', 'ERROR_MustTargetEnemyUnits', 'ERROR_CantTargetEnemyUnits', + 'ERROR_MustTargetAirUnits', 'ERROR_CantTargetAirUnits', 'ERROR_MustTargetGroundUnits', + 'ERROR_CantTargetGroundUnits', 'ERROR_MustTargetStructures', 'ERROR_CantTargetStructures', + 'ERROR_MustTargetLightUnits', 'ERROR_CantTargetLightUnits', 'ERROR_MustTargetArmoredUnits', + 'ERROR_CantTargetArmoredUnits', 'ERROR_MustTargetBiologicalUnits', 'ERROR_CantTargetBiologicalUnits', + 'ERROR_MustTargetHeroicUnits', 'ERROR_CantTargetHeroicUnits', 'ERROR_MustTargetRoboticUnits', + 'ERROR_CantTargetRoboticUnits', 'ERROR_MustTargetMechanicalUnits', 'ERROR_CantTargetMechanicalUnits', + 'ERROR_MustTargetPsionicUnits', 'ERROR_CantTargetPsionicUnits', 'ERROR_MustTargetMassiveUnits', + 'ERROR_CantTargetMassiveUnits', 'ERROR_MustTargetMissile', 'ERROR_CantTargetMissile', 'ERROR_MustTargetWorkerUnits', + 'ERROR_CantTargetWorkerUnits', 'ERROR_MustTargetEnergyCapableUnits', 'ERROR_CantTargetEnergyCapableUnits', + 'ERROR_MustTargetShieldCapableUnits', 'ERROR_CantTargetShieldCapableUnits', 'ERROR_MustTargetFlyers', + 'ERROR_CantTargetFlyers', 'ERROR_MustTargetBuriedUnits', 'ERROR_CantTargetBuriedUnits', + 'ERROR_MustTargetCloakedUnits', 'ERROR_CantTargetCloakedUnits', 'ERROR_MustTargetUnitsInAStasisField', + 'ERROR_CantTargetUnitsInAStasisField', 'ERROR_MustTargetUnderConstructionUnits', + 'ERROR_CantTargetUnderConstructionUnits', 'ERROR_MustTargetDeadUnits', 'ERROR_CantTargetDeadUnits', + 'ERROR_MustTargetRevivableUnits', 'ERROR_CantTargetRevivableUnits', 'ERROR_MustTargetHiddenUnits', + 'ERROR_CantTargetHiddenUnits', 'ERROR_CantRechargeOtherPlayersUnits', 'ERROR_MustTargetHallucinations', + 
'ERROR_CantTargetHallucinations', 'ERROR_MustTargetInvulnerableUnits', 'ERROR_CantTargetInvulnerableUnits', + 'ERROR_MustTargetDetectedUnits', 'ERROR_CantTargetDetectedUnits', 'ERROR_CantTargetUnitWithEnergy', + 'ERROR_CantTargetUnitWithShields', 'ERROR_MustTargetUncommandableUnits', 'ERROR_CantTargetUncommandableUnits', + 'ERROR_MustTargetPreventDefeatUnits', 'ERROR_CantTargetPreventDefeatUnits', 'ERROR_MustTargetPreventRevealUnits', + 'ERROR_CantTargetPreventRevealUnits', 'ERROR_MustTargetPassiveUnits', 'ERROR_CantTargetPassiveUnits', + 'ERROR_MustTargetStunnedUnits', 'ERROR_CantTargetStunnedUnits', 'ERROR_MustTargetSummonedUnits', + 'ERROR_CantTargetSummonedUnits', 'ERROR_MustTargetUser1', 'ERROR_CantTargetUser1', + 'ERROR_MustTargetUnstoppableUnits', 'ERROR_CantTargetUnstoppableUnits', 'ERROR_MustTargetResistantUnits', + 'ERROR_CantTargetResistantUnits', 'ERROR_MustTargetDazedUnits', 'ERROR_CantTargetDazedUnits', 'ERROR_CantLockdown', + 'ERROR_CantMindControl', 'ERROR_MustTargetDestructibles', 'ERROR_CantTargetDestructibles', 'ERROR_MustTargetItems', + 'ERROR_CantTargetItems', 'ERROR_NoCalldownAvailable', 'ERROR_WaypointListFull', 'ERROR_MustTargetRace', + 'ERROR_CantTargetRace', 'ERROR_MustTargetSimilarUnits', 'ERROR_CantTargetSimilarUnits', + 'ERROR_CantFindEnoughTargets', 'ERROR_AlreadySpawningLarva', 'ERROR_CantTargetExhaustedResources', + 'ERROR_CantUseMinimap', 'ERROR_CantUseInfoPanel', 'ERROR_OrderQueueIsFull', 'ERROR_CantHarvestThatResource', + 'ERROR_HarvestersNotRequired', 'ERROR_AlreadyTargeted', 'ERROR_CantAttackWeaponsDisabled', + 'ERROR_CouldntReachTarget', 'ERROR_TargetIsOutOfRange', 'ERROR_TargetIsTooClose', 'ERROR_TargetIsOutOfArc', + 'ERROR_CantFindTeleportLocation', 'ERROR_InvalidItemClass', 'ERROR_CantFindCancelOrder' +] +NUM_ACTION_RESULT = 214 + +ACTION_RACE_MASK = { + 'zerg': torch.tensor( + [ + False, False, True, True, True, True, False, False, True, True, True, True, False, False, False, False, + True, False, False, False, True, False, 
False, False, True, True, False, False, False, False, False, False, + True, True, True, False, False, True, False, False, False, True, True, False, False, False, False, False, + True, False, False, False, False, True, True, True, True, False, False, False, False, False, False, False, + False, True, True, True, True, True, True, True, False, False, False, True, False, False, False, False, + True, False, False, False, False, False, True, True, False, False, True, False, False, True, True, False, + False, False, False, False, False, False, True, True, False, False, False, False, False, True, False, False, + True, False, False, True, False, False, False, False, False, False, False, False, False, True, True, True, + True, False, False, False, False, True, True, False, False, False, False, False, False, False, False, False, + False, False, False, False, False, False, False, False, False, False, True, True, True, False, False, False, + True, False, True, False, True, False, False, True, True, False, False, True, True, False, False, False, + True, True, True, True, False, True, True, False, False, False, False, False, False, False, True, False, + False, False, False, False, False, True, True, True, True, True, True, True, True, True, False, False, True, + False, False, False, False, True, True, False, True, False, False, False, False, False, False, False, True, + False, False, True, False, False, False, False, True, False, True, True, False, False, True, False, False, + False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, + True, False, True, True, True, True, True, True, True, True, True, True, False, True, False, False, False, + False, True, False, False, False, True, False, False, False, False, True, False, True, False, False, False, + False, False, False, True, False, False, True, False, False, True, False, False, True, False, False, False, + False, True, False, False, True, False, True, False, False, False, False, 
False, False, False, False, False, + False, True, True, True, True, True + ] + ), + 'terran': torch.tensor( + [ + False, False, True, True, False, False, True, True, False, False, True, True, False, False, True, False, + False, True, True, True, False, False, False, True, False, False, True, False, False, True, False, True, + False, False, False, False, False, False, True, False, True, False, False, False, False, True, True, True, + False, False, False, True, False, False, False, False, False, False, True, False, True, True, True, False, + False, False, True, True, True, True, True, False, False, True, True, False, False, False, True, True, + False, False, False, False, False, False, False, False, True, True, False, False, False, False, False, True, + False, False, True, True, False, False, False, False, True, True, True, True, True, False, False, True, + False, True, False, False, False, False, True, True, True, False, False, True, True, False, False, False, + True, True, True, True, False, False, False, False, True, True, True, True, False, False, False, False, + False, False, False, False, False, False, False, False, False, True, True, True, True, True, True, True, + True, False, False, False, False, True, True, False, False, True, True, False, False, False, False, True, + False, False, False, False, True, False, False, True, True, True, False, True, True, True, False, True, + True, False, False, False, False, True, True, True, True, True, True, True, True, False, False, True, False, + True, True, True, False, False, False, False, False, True, True, True, True, True, True, False, False, + False, False, False, True, True, True, False, False, True, False, False, True, False, False, False, False, + False, False, False, False, True, True, False, True, True, True, True, True, True, True, True, False, False, + False, False, False, False, False, False, False, True, True, True, False, False, True, True, False, False, + False, True, False, False, False, True, True, True, 
False, False, False, False, True, True, True, True, + False, False, False, False, False, False, False, False, False, True, True, False, True, False, True, False, + False, False, True, False, True, False, False, False, False, False, False, False, False, False, True, False, + False, True, True, True, True + ] + ), + 'protoss': torch.tensor( + [ + False, False, True, True, False, False, False, False, False, False, False, False, True, True, False, True, + False, False, False, False, False, True, True, False, False, False, False, True, True, False, True, False, + False, False, False, True, True, False, False, True, False, False, False, True, True, False, False, False, + False, True, True, False, True, False, False, False, False, True, False, True, False, False, False, True, + True, False, False, False, False, True, True, False, True, False, False, False, True, True, False, False, + False, True, True, True, True, True, False, False, False, False, False, True, True, False, False, False, + True, True, False, False, True, True, False, False, False, False, False, False, False, False, True, False, + False, False, True, False, True, True, False, False, False, True, True, False, False, False, False, False, + True, False, False, False, True, False, False, True, False, False, False, False, True, True, True, True, + True, True, True, True, True, True, True, True, True, False, True, True, True, False, False, False, True, + True, False, True, False, False, False, False, False, False, False, False, False, True, True, False, False, + False, False, False, False, False, False, False, False, False, True, False, False, False, False, False, + False, True, True, True, True, True, True, True, True, True, True, True, True, False, True, False, False, + False, False, False, True, False, False, True, False, False, False, False, False, False, False, True, False, + True, True, False, False, False, False, True, False, False, False, False, False, True, False, True, True, + True, True, True, True, 
False, False, True, False, False, False, False, False, False, False, False, False, + True, False, False, False, False, False, False, False, True, True, True, True, False, False, False, True, + True, False, False, True, True, False, False, False, False, True, False, True, False, False, False, False, + False, True, True, False, True, True, False, True, True, False, False, False, False, False, True, False, + True, False, True, False, False, False, False, True, True, True, True, True, True, True, True, False, True, + False, True, True, False, True + ] + ) +} diff --git a/dizoo/distar/policy/distar_policy.py b/dizoo/distar/policy/distar_policy.py index dd223d714a..7a0dba799a 100644 --- a/dizoo/distar/policy/distar_policy.py +++ b/dizoo/distar/policy/distar_policy.py @@ -2,7 +2,6 @@ from easydict import EasyDict import os.path as osp import torch -import torch.nn.functional as F from torch.optim import Adam from ding.model import model_wrap @@ -11,64 +10,10 @@ from ding.rl_utils import td_lambda_data, td_lambda_error, vtrace_data_with_rho, vtrace_error_with_rho, \ upgo_data, upgo_error from ding.utils import EasyTimer +from ding.utils.data import default_collate, default_decollate from dizoo.distar.model import Model -from .utils import collate_fn_learn - -EPS = 1e-9 - - -def entropy_error(target_policy_probs_dict, target_policy_log_probs_dict, mask, head_weights_dict): - total_entropy_loss = 0. 
- entropy_dict = {} - for head_type in ['action_type', 'queued', 'delay', 'selected_units', 'target_unit', 'target_location']: - ent = -target_policy_probs_dict[head_type] * target_policy_log_probs_dict[head_type] - if head_type == 'selected_units': - ent = ent.sum(dim=-1) / ( - EPS + torch.log(mask['selected_units_logits_mask'].float().sum(dim=-1) + 1).unsqueeze(-1) - ) # normalize - ent = (ent * mask['selected_units_mask']).sum(-1) - ent = ent.div(mask['selected_units_mask'].sum(-1) + EPS) - elif head_type == 'target_unit': - # normalize by unit - ent = ent.sum(dim=-1) / (EPS + torch.log(mask['target_units_logits_mask'].float().sum(dim=-1) + 1)) - else: - ent = ent.sum(dim=-1) / torch.log(torch.FloatTensor([ent.shape[-1]]).to(ent.device)) - if head_type not in ['action_type', 'delay']: - ent = ent * mask['actions_mask'][head_type] - entropy = ent.mean() - entropy_dict['entropy/' + head_type] = entropy.item() - total_entropy_loss += (-entropy * head_weights_dict[head_type]) - return total_entropy_loss, entropy_dict - - -def kl_error( - target_policy_log_probs_dict, teacher_policy_logits_dict, mask, game_steps, action_type_kl_steps, head_weights_dict -): - total_kl_loss = 0. 
- kl_loss_dict = {} - - for head_type in ['action_type', 'queued', 'delay', 'selected_units', 'target_unit', 'target_location']: - target_policy_log_probs = target_policy_log_probs_dict[head_type] - teacher_policy_logits = teacher_policy_logits_dict[head_type] - - teacher_policy_log_probs = F.log_softmax(teacher_policy_logits, dim=-1) - teacher_policy_probs = torch.exp(teacher_policy_log_probs) - kl = teacher_policy_probs * (teacher_policy_log_probs - target_policy_log_probs) - - kl = kl.sum(dim=-1) - if head_type == 'selected_units': - kl = (kl * mask['selected_units_mask']).sum(-1) - if head_type not in ['action_type', 'delay']: - kl = kl * mask['actions_mask'][head_type] - if head_type == 'action_type': - flag = game_steps < action_type_kl_steps - action_type_kl = kl * flag - action_type_kl_loss = action_type_kl.mean() - kl_loss_dict['kl/extra_at'] = action_type_kl_loss.item() - kl_loss = kl.mean() - total_kl_loss += (kl_loss * head_weights_dict[head_type]) - kl_loss_dict['kl/' + head_type] = kl_loss.item() - return total_kl_loss, action_type_kl_loss, kl_loss_dict +from dizoo.distar.envs import NUM_UNIT_TYPES, ACTIONS, Stat +from .utils import collate_fn_learn, kl_error, entropy_error class DIStarPolicy(Policy): @@ -78,6 +23,7 @@ class DIStarPolicy(Policy): cuda=True, learning_rate=1e-5, model=dict(), + # learn learn=dict(multi_gpu=False, ), loss_weights=dict( baseline=dict( @@ -156,7 +102,12 @@ class DIStarPolicy(Policy): battle=0.997, ), ), - grad_clip=dict(threshold=1.0, ) + grad_clip=dict(threshold=1.0, ), + # collect + use_value_feature=True, # whether to use value feature, this must be False when play against bot + zero_z_exceed_loop=True, # set Z to 0 if game passes the game loop in Z + zero_z_value=1, + extra_units=True, # selcet extra units if selected units exceed 64 ) def _create_model( @@ -170,6 +121,8 @@ def _create_model( field = enable_field[0] if field == 'learn': return Model(self._cfg.model, use_value_network=True) + elif field == 'collect': # 
# Collect-mode methods of ``DIStarPolicy``. The enclosing ``class`` header is outside this span,
# so the methods are written here without the class-level indent.


def _init_collect(self):
    """Prepare collect mode: wrap the model with the 'base' wrapper and reset the step-to-step state."""
    self.collect_model = model_wrap(self._model, 'base')
    self._reset_collect()


def _reset_collect(self, env_id=0):
    """
    Reset every attribute that is carried from one collect step to the next.

    ``env_id`` is accepted but not read: a single set of state is kept on the policy.
    """
    self.stat = Stat('zerg')  # TODO
    self.target_z_loop = 43200  # TODO
    self.exceed_loop_flag = False
    self.hidden_state = None
    self.last_action_type = torch.tensor(0, dtype=torch.long)
    self.last_delay = torch.tensor(0, dtype=torch.long)
    self.last_queued = torch.tensor(0, dtype=torch.long)
    self.last_selected_unit_tags = None
    self.last_targeted_unit_tag = None
    self.last_location = None  # [x, y]
    self.enemy_unit_type_bool = torch.zeros(NUM_UNIT_TYPES, dtype=torch.uint8)

    self.target_building_order = None  # TODO
    self.target_bo_location = None
    self.target_cumulative_stat = None

    self.map_size = None  # TODO


def _forward_collect(self, data):
    """
    Run one collect step: preprocess the observation, query the model and build the env action.

    ``data`` must contain a ``'game_info'`` entry; it is popped from ``data`` (the caller's dict is
    mutated). Returns the post-processed policy output, which carries the env action under ``'action'``.
    """
    game_info = data.pop('game_info')
    obs = self._data_preprocess_collect(data, game_info)
    obs = default_collate([obs])  # add the batch dimension expected by the model
    if self._cfg.cuda:
        obs = to_device(obs, self._device)

    with torch.no_grad():
        policy_output = self.collect_model.compute_logp_action(**obs)

    if self._cfg.cuda:
        # Bring the result back to host memory. The previous code moved it to ``self._device`` again,
        # which is a no-op for an output that was just produced on that device.
        policy_output = to_device(policy_output, 'cpu')
    policy_output = default_decollate(policy_output)[0]  # strip the batch dimension again
    policy_output = self._data_postprocess_collect(policy_output, game_info)
    return policy_output


def _data_preprocess_collect(self, data, game_info):
    """
    Turn the raw env data into a model observation and inject the state remembered from the last step.

    Reads ``data['raw_obs']``, ``data['opponent_obs']`` and ``data['action_result']`` as well as
    ``game_info['game_loop']`` and ``game_info['tags']``. Returns the observation dict.

    Raises ``NotImplementedError`` while the observation transform is not wired in, or when
    ``use_value_feature`` is disabled.
    """
    transform_obs = None  # TODO: plug in the raw-observation feature transform
    if self._cfg.use_value_feature:
        if transform_obs is None:
            # Fail with an explicit message instead of "'NoneType' object is not callable".
            raise NotImplementedError("transform_obs is not implemented for collect mode yet")
        obs = transform_obs(data['raw_obs'], opponent_obs=data['opponent_obs'])
    else:
        raise NotImplementedError

    game_step = game_info['game_loop']
    # Use the attribute names assigned in ``_reset_collect`` (no leading underscore).
    if self._cfg.zero_z_exceed_loop and game_step > self.target_z_loop:
        self.exceed_loop_flag = True

    entity_num = obs['entity_num']
    last_selected_units = torch.zeros(entity_num, dtype=torch.int8)
    last_targeted_unit = torch.zeros(entity_num, dtype=torch.int8)
    tags = game_info['tags']
    # One pass over ``tags`` instead of a linear search per remembered tag;
    # ``setdefault`` keeps the first occurrence, like ``list.index``.
    tag_to_index = {}
    for index, tag in enumerate(tags):
        tag_to_index.setdefault(tag, index)
    if self.last_selected_unit_tags is not None:
        for tag in self.last_selected_unit_tags:
            if tag in tag_to_index:
                last_selected_units[tag_to_index[tag]] = 1
    # The check used to be ``is None``, which meant a remembered target was never marked.
    if self.last_targeted_unit_tag is not None and self.last_targeted_unit_tag in tag_to_index:
        last_targeted_unit[tag_to_index[self.last_targeted_unit_tag]] = 1
    obs['entity_info']['last_selected_units'] = last_selected_units
    obs['entity_info']['last_targeted_unit'] = last_targeted_unit

    obs['hidden_state'] = self.hidden_state

    scalar_info = obs['scalar_info']
    scalar_info['last_action_type'] = self.last_action_type
    scalar_info['last_delay'] = self.last_delay
    scalar_info['last_queued'] = self.last_queued
    # Store the OR-ed result back on the policy. Without this the remembered tensor stays all-zero
    # and the OR has no effect. NOTE(review): assumes the intent is "enemy types seen so far".
    self.enemy_unit_type_bool = (self.enemy_unit_type_bool | scalar_info['enemy_unit_type_bool']).to(torch.uint8)
    scalar_info['enemy_unit_type_bool'] = self.enemy_unit_type_bool
    # ``~`` on a Python bool gives -1 / -2, not a logical negation, so use an explicit 0/1 factor.
    keep_z = 0 if self.exceed_loop_flag else 1
    scalar_info['beginning_order'] = self.target_building_order * keep_z
    scalar_info['bo_location'] = self.target_bo_location * keep_z
    if self.exceed_loop_flag:
        scalar_info['cumulative_stat'] = self.target_cumulative_stat * 0 + self._cfg.zero_z_value
    else:
        scalar_info['cumulative_stat'] = self.target_cumulative_stat

    # update stat
    self.stat.update(self.last_action_type, data['action_result'][0], obs, game_step)
    return obs


def _data_postprocess_collect(self, data, game_info):
    """
    Remember the chosen action for the next step and convert it into the env action format.

    The env action is a dict with keys ``func_id``, ``skip_steps``, ``queued``, ``unit_tags``,
    ``target_unit_tag`` and ``location``; it is stored in ``data['action']`` and ``data`` is returned.
    """
    self.hidden_state = data['hidden_state']

    action_info = data['action_info']
    self.last_action_type = action_info['action_type']
    self.last_delay = action_info['delay']
    self.last_queued = action_info['queued']
    action_type = self.last_action_type.item()
    action_attr = ACTIONS[action_type]

    # transform into env format action
    tags = game_info['tags']
    raw_action = {}
    raw_action['func_id'] = action_attr['func_id']
    raw_action['skip_steps'] = self.last_delay.item()
    # ``self.queued`` is never assigned anywhere; the value lives in ``self.last_queued``.
    raw_action['queued'] = self.last_queued.item()

    selected_units = action_info['selected_units']
    unit_tags = [
        tags[selected_units[i].item()] for i in range(data['selected_units_num'] - 1)  # remove end flag
    ]
    if self._cfg.extra_units:
        extra_units = torch.nonzero(data['extra_units']).squeeze(dim=1).tolist()
        unit_tags.extend(tags[unit_index] for unit_index in extra_units)
    raw_action['unit_tags'] = unit_tags
    # Only remember the selection when this action type actually uses selected units.
    self.last_selected_unit_tags = unit_tags if action_attr['selected_units'] else None

    raw_action['target_unit_tag'] = tags[action_info['target_unit'].item()]
    # Likewise, only remember the target when this action type uses a target unit.
    self.last_targeted_unit_tag = raw_action['target_unit_tag'] if action_attr['target_unit'] else None

    # The location is a flat index; split it with the map width, then flip the y axis.
    y, x = divmod(action_info['target_location'].item(), self.map_size.x)
    inverse_y = max(self.map_size.y - y, 0)
    raw_action['location'] = (x, inverse_y)
    self.last_location = action_info['target_location']

    data['action'] = raw_action

    return data


def _process_transition(self, obs, policy_output, timestep):
    """Build the stored transition from ``obs`` and the model's ``action_info``; ``timestep`` is not used."""
    return {
        'obs': obs,
        'action': policy_output['action_info'],
    }


def _get_train_sample(self):
    """Placeholder: not implemented yet."""
    pass
# Action heads handled by both loss helpers, in the order they are accumulated.
_POLICY_HEADS = ('action_type', 'queued', 'delay', 'selected_units', 'target_unit', 'target_location')
# Heads that are not multiplied by ``mask['actions_mask'][head]``.
_ALWAYS_ACTIVE_HEADS = ('action_type', 'delay')


def entropy_error(target_policy_probs_dict, target_policy_log_probs_dict, mask, head_weights_dict):
    """
    Compute the weighted entropy loss summed over all action heads.

    For each head the element-wise term ``-p * log(p)`` is summed over the last dimension and divided
    by a log-size normaliser, masked where applicable, and averaged.

    Arguments:
        - target_policy_probs_dict: head name -> probabilities.
        - target_policy_log_probs_dict: head name -> log-probabilities, same layout as the probabilities.
        - mask: dict providing ``selected_units_logits_mask``, ``selected_units_mask``,
          ``target_units_logits_mask`` and ``actions_mask[head]``.
        - head_weights_dict: head name -> loss weight.
    Returns:
        - total_entropy_loss: sum over heads of ``-entropy * head_weights_dict[head]``.
        - entropy_dict: ``'entropy/<head>'`` -> mean entropy of that head as a Python float.
    """
    total_entropy_loss = 0.
    entropy_dict = {}
    for head_type in _POLICY_HEADS:
        ent = -target_policy_probs_dict[head_type] * target_policy_log_probs_dict[head_type]
        if head_type == 'selected_units':
            # normalize by log(number of unmasked logits + 1), then average over the masked selections
            ent = ent.sum(dim=-1) / (
                EPS + torch.log(mask['selected_units_logits_mask'].float().sum(dim=-1) + 1).unsqueeze(-1)
            )
            ent = (ent * mask['selected_units_mask']).sum(-1)
            ent = ent.div(mask['selected_units_mask'].sum(-1) + EPS)
        elif head_type == 'target_unit':
            # normalize by unit
            ent = ent.sum(dim=-1) / (EPS + torch.log(mask['target_units_logits_mask'].float().sum(dim=-1) + 1))
        else:
            # normalize by log(number of classes); build the constant directly on ``ent``'s device
            # rather than through the legacy ``torch.FloatTensor`` constructor plus a ``.to`` copy
            num_classes = torch.tensor([ent.shape[-1]], dtype=torch.float, device=ent.device)
            ent = ent.sum(dim=-1) / torch.log(num_classes)
        if head_type not in _ALWAYS_ACTIVE_HEADS:
            ent = ent * mask['actions_mask'][head_type]
        entropy = ent.mean()
        entropy_dict['entropy/' + head_type] = entropy.item()
        total_entropy_loss += (-entropy * head_weights_dict[head_type])
    return total_entropy_loss, entropy_dict


def kl_error(
        target_policy_log_probs_dict, teacher_policy_logits_dict, mask, game_steps, action_type_kl_steps,
        head_weights_dict
):
    """
    Compute the weighted KL divergence ``KL(teacher || target)`` summed over all action heads.

    Arguments:
        - target_policy_log_probs_dict: head name -> log-probabilities of the trained policy.
        - teacher_policy_logits_dict: head name -> teacher logits (softmax is taken over the last dim).
        - mask: dict providing ``selected_units_mask`` and ``actions_mask[head]``.
        - game_steps: compared element-wise against ``action_type_kl_steps``.
        - action_type_kl_steps: the extra action-type KL only counts samples with ``game_steps`` below it.
        - head_weights_dict: head name -> loss weight.
    Returns:
        - total_kl_loss: sum over heads of ``mean(kl) * head_weights_dict[head]``.
        - action_type_kl_loss: mean action-type KL restricted to ``game_steps < action_type_kl_steps``.
        - kl_loss_dict: ``'kl/<head>'`` and ``'kl/extra_at'`` -> Python floats.
    """
    total_kl_loss = 0.
    # Bound up front so the return statement never depends on a branch inside the loop.
    action_type_kl_loss = 0.
    kl_loss_dict = {}

    for head_type in _POLICY_HEADS:
        target_policy_log_probs = target_policy_log_probs_dict[head_type]
        teacher_policy_logits = teacher_policy_logits_dict[head_type]

        teacher_policy_log_probs = F.log_softmax(teacher_policy_logits, dim=-1)
        teacher_policy_probs = torch.exp(teacher_policy_log_probs)
        kl = teacher_policy_probs * (teacher_policy_log_probs - target_policy_log_probs)

        kl = kl.sum(dim=-1)
        if head_type == 'selected_units':
            kl = (kl * mask['selected_units_mask']).sum(-1)
        if head_type not in _ALWAYS_ACTIVE_HEADS:
            kl = kl * mask['actions_mask'][head_type]
        if head_type == 'action_type':
            flag = game_steps < action_type_kl_steps
            action_type_kl = kl * flag
            action_type_kl_loss = action_type_kl.mean()
            kl_loss_dict['kl/extra_at'] = action_type_kl_loss.item()
        kl_loss = kl.mean()
        total_kl_loss += (kl_loss * head_weights_dict[head_type])
        kl_loss_dict['kl/' + head_type] = kl_loss.item()
    return total_kl_loss, action_type_kl_loss, kl_loss_dict