The Wayback Machine - https://web.archive.org/web/20210916112912/https://github.com/ray-project/ray/issues/16718

[RLlib] Policy weights overwritten in self-play #16718

Open
george-skal opened this issue Jun 28, 2021 · 2 comments
george-skal commented Jun 28, 2021

Hi all!
I am trying a self-play based scheme in the waterworld environment: two agents should use a policy that is being trained ("shared_policy_1"), while the other three agents sample a policy from a menagerie (a set of previous policies of the first two agents), "shared_policy_2".
My problem is that the weights stored in the menagerie are overwritten in every iteration by the current weights. The problem does not occur with ray.init(local_mode=True), but does occur without local mode.

The cause seems to be that get_weights() returns a dictionary of numpy arrays, with one entry per model parameter. With local_mode=True, the arrays in the weight dictionary are independent objects, but with local_mode=False they are references to the same underlying numpy arrays, whose values change during the learn_on_batch steps. You therefore end up with a list of dictionaries that all reference the same numpy objects, so whenever the weights are updated, every stored dictionary is updated too.

Please check here https://discuss.ray.io/t/policy-weights-overwritten-in-self-play/2520 for the full discussion, the weights printed and a proposed solution.
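The aliasing described above can be reproduced with plain numpy, without Ray; a minimal sketch showing why storing the same array object twice makes every "snapshot" track the live weights, and how copying breaks the link:

```python
import numpy as np

# Two "snapshots" that share the same underlying array.
w = np.zeros(3)
snapshot_a = {"layer/kernel": w}
snapshot_b = {"layer/kernel": w}

# An in-place update (what a training step effectively does) changes both.
w += 1.0
print(snapshot_b["layer/kernel"][0])  # 1.0 -- the "old" snapshot moved too

# Copying each array when taking the snapshot breaks the aliasing.
safe_snapshot = {k: v.copy() for k, v in snapshot_a.items()}
w += 1.0
print(safe_snapshot["layer/kernel"][0])  # still 1.0, while w is now 2.0
```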

The code is:

from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
import random
import ray
from ray.tune.registry import register_env
from ray.rllib.env.pettingzoo_env import PettingZooEnv
from pettingzoo.sisl import waterworld_v3

M = 5  # Menagerie size
men = []


class MyCallbacks(DefaultCallbacks):

    def on_train_result(self, *, trainer, result: dict, **kwargs):
        print("trainer.train() result: {} -> {} episodes".format(
            trainer, result["episodes_this_iter"]))
        i = result['training_iteration']    # starts from 1
        # the "shared_policy_1" is the only agent being trained
        print("training iteration:", i)
        global men

        if i <= M:
            # menagerie initialisation
            tmp = trainer.get_policy("shared_policy_1").get_weights()
            men.append(tmp)

            filename1 = 'file_init_' + str(i) + '.txt'
            textfile1 = open(filename1, "w")
            for element1 in men:
                textfile1.write("############# menagerie entries ##################" + "\n")
                textfile1.write(str(element1) + "\n")
            textfile1.close()

        else:
            # the first policy added is erased
            men.pop(0)
            # add current training policy in the last position of the menagerie
            w = trainer.get_policy("shared_policy_1").get_weights()
            men.append(w)
            # select one policy randomly
            sel = random.randint(0, M-1)

            trainer.set_weights({"shared_policy_2": men[sel]})

            weights = ray.put(trainer.workers.local_worker().save())
            trainer.workers.foreach_worker(
                lambda w: w.restore(ray.get(weights))
            )

        filename = 'file' + str(i) + '.txt'
        textfile = open(filename, "w")
        for element in men:
            textfile.write("############# menagerie entries ##################" + "\n")
            textfile.write(str(element) + "\n")
        textfile.close()

        # you can mutate the result dict to add new fields to return
        result["callback_ok"] = True


if __name__ == "__main__":

    ray.init()

    def env_creator(args):
        return PettingZooEnv(waterworld_v3.env(n_pursuers=5, n_evaders=5))

    env = env_creator({})
    register_env("waterworld", env_creator)
    obs_space = env.observation_space
    act_spc = env.action_space

    policies = {"shared_policy_1": (None, obs_space, act_spc, {}),
                "shared_policy_2": (None, obs_space, act_spc, {})
                }

    def policy_mapping_fn(agent_id):
        # Note: `agent_id == "pursuer_0" or "pursuer_1"` would always be
        # truthy, so the membership test below is used instead.
        if agent_id in ("pursuer_0", "pursuer_1"):
            return "shared_policy_1"
        else:
            return "shared_policy_2"

    tune.run(
        "PPO",
        name="PPO self play n = 5, M=5 trial 1",
        stop={"episodes_total": 50000},
        checkpoint_freq=10,
        config={
            # Environment specific
            "env": "waterworld",
            # General
            "framework": "torch",
            "callbacks": MyCallbacks,
            "num_gpus": 0,
            "num_workers": 0,
            # Method specific
            "multiagent": {
                "policies": policies,
                "policies_to_train": ["shared_policy_1"],
                "policy_mapping_fn": policy_mapping_fn,
            },
        },
    )

Thanks,
George

richardliaw changed the title from "[RLlib] Policy weights overwritten in self-play" to "[RLlib] Policy weights overwritten in self-play for local_mode" on Jul 2, 2021
george-skal (author) commented Jul 9, 2021

Hey @richardliaw, I saw that you changed the title, so just to make sure we are on the same page: the problem is that in the code above the weights are overwritten *without* local_mode=True; *with* local_mode=True the code works correctly.

richardliaw changed the title from "[RLlib] Policy weights overwritten in self-play for local_mode" back to "[RLlib] Policy weights overwritten in self-play" on Jul 9, 2021
richardliaw (contributor) commented Jul 9, 2021

@george-skal thanks for this point -- fixed it!

@michaelzhiluo can you take a quick look at this issue?
