Skip to content

Commit

Permalink
Add telemetry for sb3 integration (#2830)
Browse files Browse the repository at this point in the history
  • Loading branch information
raubitsj authored Oct 25, 2021
1 parent 7aa469d commit 7d69938
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -1,20 +1,5 @@
#!/usr/bin/env python
"""Test stable_baselines3 integration
---
id: 0.0.4
plugin:
- wandb
assert:
- :wandb:runs_len: 1
- :wandb:runs[0][config][policy_type]: MlpPolicy
- :wandb:runs[0][config][total_timesteps]: 200
- :wandb:runs[0][config][policy_class]: "<class 'stable_baselines3.common.policies.ActorCriticPolicy'>"
- :wandb:runs[0][config][action_space]: "Discrete(2)"
- :wandb:runs[0][config][batch_size]: 64
- :wandb:runs[0][config][n_epochs]: 10
- :wandb:runs[0][exitcode]: 0
"""
"""Test stable_baselines3 integration"""

import gym
from stable_baselines3 import PPO
Expand Down
37 changes: 37 additions & 0 deletions functional_tests/sb3/01-stable-baselines3.yea
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
id: 0.sb3.01-stable-baselines3
plugin:
- wandb
assert:
- :wandb:runs_len: 1
- :wandb:runs[0][config][policy_type]: MlpPolicy
- :wandb:runs[0][config][total_timesteps]: 200
- :wandb:runs[0][config][policy_class]: "<class 'stable_baselines3.common.policies.ActorCriticPolicy'>"
- :wandb:runs[0][config][action_space]: "Discrete(2)"
- :wandb:runs[0][config][batch_size]: 64
- :wandb:runs[0][config][n_epochs]: 10
- :wandb:runs[0][summary][global_step]: 2048
- :wandb:runs[0][summary][gradients/action_net.bias][_type]: histogram
- :wandb:runs[0][summary][gradients/action_net.weight][_type]: histogram
- :wandb:runs[0][summary][gradients/mlp_extractor.policy_net.0.bias][_type]: histogram
- :wandb:runs[0][summary][gradients/mlp_extractor.policy_net.0.weight][_type]: histogram
- :wandb:runs[0][summary][gradients/mlp_extractor.policy_net.2.bias][_type]: histogram
- :wandb:runs[0][summary][gradients/mlp_extractor.policy_net.2.weight][_type]: histogram
- :wandb:runs[0][summary][gradients/mlp_extractor.value_net.0.bias][_type]: histogram
- :wandb:runs[0][summary][gradients/mlp_extractor.value_net.0.weight][_type]: histogram
- :wandb:runs[0][summary][gradients/mlp_extractor.value_net.2.bias][_type]: histogram
- :wandb:runs[0][summary][gradients/mlp_extractor.value_net.2.weight][_type]: histogram
- :wandb:runs[0][summary][gradients/value_net.bias][_type]: histogram
- :wandb:runs[0][summary][gradients/value_net.weight][_type]: histogram
- :op:>:
- :wandb:runs[0][summary][rollout/ep_len_mean]
- 0.0
- :op:>:
- :wandb:runs[0][summary][rollout/ep_rew_mean]
- 0.0
- :op:>:
- :wandb:runs[0][summary][time/fps]
- 0.0
- :op:contains:
- :wandb:runs[0][telemetry][3] # feature
- 22 # sb3
- :wandb:runs[0][exitcode]: 0
3 changes: 3 additions & 0 deletions wandb/integration/sb3/sb3.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def make_env():

from stable_baselines3.common.callbacks import BaseCallback
import wandb
from wandb.sdk.lib import telemetry as wb_telemetry


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -80,6 +81,8 @@ def __init__(
super(WandbCallback, self).__init__(verbose)
if wandb.run is None:
raise wandb.Error("You must call wandb.init() before WandbCallback()")
with wb_telemetry.context() as tel:
tel.feature.sb3 = True
self.model_save_freq = model_save_freq
self.model_save_path = model_save_path
self.gradient_save_freq = gradient_save_freq
Expand Down
1 change: 1 addition & 0 deletions wandb/proto/wandb_telemetry.proto
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ message Feature {
bool set_config_item = 19; // users set key in run config via run.config.key or run.config["key"]
bool launch = 20; // run is created through wandb launch
bool torch_profiler_trace = 21; // wandb.profiler.torch_trace_handler() called
bool sb3 = 22; // Using stable_baselines3 integration
}

message Env {
Expand Down
19 changes: 13 additions & 6 deletions wandb/proto/wandb_telemetry_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion wandb/proto/wandb_telemetry_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class Feature(google.protobuf.message.Message):
SET_CONFIG_ITEM_FIELD_NUMBER: builtins.int
LAUNCH_FIELD_NUMBER: builtins.int
TORCH_PROFILER_TRACE_FIELD_NUMBER: builtins.int
SB3_FIELD_NUMBER: builtins.int
watch: builtins.bool = ...
finish: builtins.bool = ...
save: builtins.bool = ...
Expand All @@ -159,6 +160,7 @@ class Feature(google.protobuf.message.Message):
set_config_item: builtins.bool = ...
launch: builtins.bool = ...
torch_profiler_trace: builtins.bool = ...
sb3: builtins.bool = ...

def __init__(self,
*,
Expand All @@ -183,8 +185,9 @@ class Feature(google.protobuf.message.Message):
set_config_item : builtins.bool = ...,
launch : builtins.bool = ...,
torch_profiler_trace : builtins.bool = ...,
sb3 : builtins.bool = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal[u"artifact_incremental",b"artifact_incremental",u"finish",b"finish",u"grpc",b"grpc",u"keras",b"keras",u"launch",b"launch",u"metaflow",b"metaflow",u"metric",b"metric",u"offline",b"offline",u"prodigy",b"prodigy",u"resumed",b"resumed",u"sagemaker",b"sagemaker",u"save",b"save",u"set_config_item",b"set_config_item",u"set_init_config",b"set_init_config",u"set_init_id",b"set_init_id",u"set_init_name",b"set_init_name",u"set_init_tags",b"set_init_tags",u"set_run_name",b"set_run_name",u"set_run_tags",b"set_run_tags",u"torch_profiler_trace",b"torch_profiler_trace",u"watch",b"watch"]) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal[u"artifact_incremental",b"artifact_incremental",u"finish",b"finish",u"grpc",b"grpc",u"keras",b"keras",u"launch",b"launch",u"metaflow",b"metaflow",u"metric",b"metric",u"offline",b"offline",u"prodigy",b"prodigy",u"resumed",b"resumed",u"sagemaker",b"sagemaker",u"save",b"save",u"sb3",b"sb3",u"set_config_item",b"set_config_item",u"set_init_config",b"set_init_config",u"set_init_id",b"set_init_id",u"set_init_name",b"set_init_name",u"set_init_tags",b"set_init_tags",u"set_run_name",b"set_run_name",u"set_run_tags",b"set_run_tags",u"torch_profiler_trace",b"torch_profiler_trace",u"watch",b"watch"]) -> None: ...
global___Feature = Feature

class Env(google.protobuf.message.Message):
Expand Down

0 comments on commit 7d69938

Please sign in to comment.