Как я узнаю, что агенты работают вместе? - PullRequest
0 голосов
/ 10 февраля 2020

Я использую ML-Agents уже несколько месяцев и работаю над самобалансирующейся парой ног. Тем не менее, у меня есть вопрос, который не даёт мне покоя с самого начала: как мне УЗНАТЬ, что агенты работают вместе? Всё, что я сделал, — это скопировал и вставил префаб области (Area) 9 раз. Это всё, что нужно сделать, чтобы агенты учились более эффективно? Или я что-то упустил? Спасибо.

training

Скрипт агента >>> (Мне не понадобились никакие другие скрипты, кроме этого. Скрипты области и академии пусты.)

using MLAgents;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using MLAgents.Sensor;
using Random = UnityEngine.Random;

/// <summary>
/// ML-Agents agent that tries to keep a two-legged rig upright.
/// Each decision carries 13 discrete action branches: branch 0 yaws the waist
/// directly, branches 1-10 each nudge the limit window of one hinge joint, and
/// branches 11-12 push the waist along world X and Z.
/// </summary>
public class BalanceAgent : Agent {

    // Number of tracked body parts; Start() fixes the index order.
    private const int BodyPartCount = 11;

    // Allowed travel of each hinge's limit window, indexed like bodyParts.
    // Entry 0 (waist) is unused: the waist is rotated directly, not through a hinge window.
    private static readonly float[] JointLower = { 0f, -5f, -60f, -80f, -80f, -3f, -3f, -50f, -50f, -15f, -45f };
    private static readonly float[] JointUpper = { 0f, 60f, 5f, 80f, 80f, 80f, 80f, 50f, 50f, 45f, 15f };

    // Upper edge of each hinge's limit window at the start of an episode (entry 0 unused).
    private static readonly float[] JointResetMax = { 0f, 0f, 0f, -15f, -15f, 15f, 15f, -15f, -15f, 0f, 0f };

    private BalancingArea area;
    public GameObject floor;
    public GameObject finishBall;
    public GameObject waist;
    public GameObject wFront;           //Used to check balance of waist.
    public GameObject wBack;           //Used to check balance of waist.
    public GameObject hipR;
    public GameObject hipL;
    public GameObject buttR;
    public GameObject buttL;
    public GameObject thighR;
    public GameObject thighL;
    public GameObject legR;
    public GameObject legL;
    public GameObject footR;
    public GameObject footL;

    public BehaviorParameters behavePar;

    public GameObject sensorFront;
    public GameObject sensorBack;
    public GameObject sensorLeft;
    public GameObject sensorRight;

    // Base step applied to a joint window (or waist yaw) per decision; actions scale it by 1x, 2x or 4x.
    public float bodyMoveSensitivity = 0.5f;

    public GameObject[] bodyParts = new GameObject[BodyPartCount];
    HingeJoint[] hingeParts = new HingeJoint[BodyPartCount];
    JointLimits[] jntLimParts = new JointLimits[BodyPartCount];

    // Rigidbodies cached in Start() so the per-step code does not call GetComponent repeatedly.
    Rigidbody[] rigidParts = new Rigidbody[BodyPartCount];

    // Pose of every body part captured in Start() and restored on each reset.
    Vector3[] posStart = new Vector3[BodyPartCount];
    Vector3[] eulerStart = new Vector3[BodyPartCount];

    // Waist orientation recorded at the end of AgentReset(); the yaw reward is measured against it.
    public Vector3 waistRot;

    // Last discrete action received on each branch (public so it shows in the Inspector).
    public float waistVec = 0;
    public float buttRVec = 0;
    public float buttLVec = 0;
    public float thighRVec = 0;
    public float thighLVec = 0;
    public float legRVec = 0;
    public float legLVec = 0;
    public float footRVec = 0;
    public float footLVec = 0;
    public float hipRVec = 0;
    public float hipLVec = 0;
    public float waistPushXVec = 0;
    public float waistPushZVec = 0;

    // Step derived from the last action on each branch. An action index outside
    // the known range leaves the previous value in place.
    float waistDir = 0;
    float buttRDir = 0;
    float buttLDir = 0;
    float thighRDir = 0;
    float thighLDir = 0;
    float legRDir = 0;
    float legLDir = 0;
    float footRDir = 0;
    float footLDir = 0;
    float hipRDir = 0;
    float hipLDir = 0;
    float waistPushDirX = 0;
    float waistPushDirZ = 0;

    /// <summary>
    /// Builds the body-part table, records each part's starting pose and caches
    /// its Rigidbody and (when present) HingeJoint.
    /// </summary>
    public void Start() {
        bodyParts = new GameObject[] { waist /*0*/, buttR /*1*/, buttL /*2*/, thighR /*3*/, thighL /*4*/, legR /*5*/, legL /*6*/, footR /*7*/, footL /*8*/, hipR /*9*/, hipL /*10*/};

        for (int i = 0; i < bodyParts.Length; i++) {
            GameObject part = bodyParts[i];
            posStart[i] = part.transform.position;
            eulerStart[i] = part.transform.eulerAngles;
            rigidParts[i] = part.GetComponent<Rigidbody>();

            HingeJoint hinge = part.GetComponent<HingeJoint>();
            if (hinge != null) {
                hingeParts[i] = hinge;
                hinge.limits = jntLimParts[i];
            }
        }
    }

    /// <summary>
    /// Looks up the BalancingArea this agent lives under.
    /// </summary>
    public override void InitializeAgent() {
        base.InitializeAgent();
        area = GetComponentInParent<BalancingArea>();
    }

    /// <summary>
    /// Starts a new episode: puts every hinge window back to its reset angle,
    /// restores each body part to the pose captured in Start(), zeroes its
    /// velocities, and records the waist orientation as the yaw reference.
    /// </summary>
    public override void AgentReset() {
        //floor.transform.eulerAngles = new Vector3(Random.Range(-10, 10), 0, Random.Range(-10, 10));             //Floor random rotation
        //finishBall.transform.localPosition = new Vector3(Random.Range(-7, 7), .65f, Random.Range(-7, 7));             //Ball random position

        // Parts 1-10 are expected to carry a HingeJoint; a missing one throws here.
        for (int i = 1; i < BodyPartCount; i++) {
            SetJointWindow(i, JointResetMax[i]);
        }

        for (int i = 0; i < bodyParts.Length; i++) {
            bodyParts[i].transform.position = posStart[i];
            bodyParts[i].transform.eulerAngles = eulerStart[i];
            rigidParts[i].velocity = Vector3.zero;
            rigidParts[i].angularVelocity = Vector3.zero;
            if (hingeParts[i] != null) {
                hingeParts[i].limits = jntLimParts[i];
            }
        }

        //waist.transform.eulerAngles = new Vector3(0, Random.Range(0, 360), 0);                //Random player direction
        waistRot = waist.transform.eulerAngles;
    }

    /// <summary>
    /// Applies one decision to the rig and then scores the resulting state.
    /// </summary>
    /// <param name="vectorAction">
    /// Discrete action per branch: indices 0-10 take values 0-6 (see
    /// <see cref="StepForAction"/>), indices 11-12 take values 0-2.
    /// </param>
    public override void AgentAction(float[] vectorAction) {

        // Branch 0: yaw the waist transform directly.
        waistVec = (int)vectorAction[0];
        waistDir = StepForAction((int)waistVec, waistDir);
        bodyParts[0].transform.Rotate(0, waistDir, 0);

        // Branches 1-10: slide the limit window of one hinge each.
        MoveJoint(1, vectorAction[1], ref buttRVec, ref buttRDir);
        MoveJoint(2, vectorAction[2], ref buttLVec, ref buttLDir);
        MoveJoint(3, vectorAction[3], ref thighRVec, ref thighRDir);
        MoveJoint(4, vectorAction[4], ref thighLVec, ref thighLDir);
        MoveJoint(5, vectorAction[5], ref legRVec, ref legRDir);
        MoveJoint(6, vectorAction[6], ref legLVec, ref legLDir);
        MoveJoint(7, vectorAction[7], ref footRVec, ref footRDir);
        MoveJoint(8, vectorAction[8], ref footLVec, ref footLDir);
        MoveJoint(9, vectorAction[9], ref hipRVec, ref hipRDir);
        MoveJoint(10, vectorAction[10], ref hipLVec, ref hipLDir);

        // Branches 11-12: unit push on the waist along world X and Z.
        waistPushXVec = (int)vectorAction[11];
        waistPushDirX = PushForAction((int)waistPushXVec, waistPushDirX);
        waistPushZVec = (int)vectorAction[12];
        waistPushDirZ = PushForAction((int)waistPushZVec, waistPushDirZ);

        Rigidbody waistBody = rigidParts[0];
        waistBody.AddForce(waistPushDirX, 0, waistPushDirZ);              //Try to help move waist

        //waist.transform.eulerAngles = new Vector3(0, waistRot.y, 0);

        // Keep the four sensors level (no pitch or roll) and fanned out around the waist's current yaw.
        float waistYaw = waist.transform.eulerAngles.y;
        sensorFront.transform.eulerAngles = new Vector3(0, waistYaw - 90, 0);
        sensorBack.transform.eulerAngles = new Vector3(0, waistYaw + 90, 0);
        sensorLeft.transform.eulerAngles = new Vector3(0, waistYaw - 180, 0);
        sensorRight.transform.eulerAngles = new Vector3(0, waistYaw, 0);

        //Reward SYSTEM #####################################################################################################################################################################
        AddReward(.1f);             //Survival reward.

        // NOTE(review): the penalty needs BOTH x and z to be off by more than .25 and then scales with the x offset only — confirm this is intended.
        if (Mathf.Abs(finishBall.transform.position.x - waist.transform.position.x) > .25f && Mathf.Abs(finishBall.transform.position.z - waist.transform.position.z) > .25f) {               //Maintain waist position to ball
            AddReward(-.1f * Mathf.Abs(finishBall.transform.position.x - waist.transform.position.x));
        }

        if (waistBody.velocity.magnitude >= 20f) {               //Maintain waist slow velocity.
            AddReward(-.1f);
            Done();
        }

        if (waist.transform.position.y < -2 || waist.transform.position.y > 6) {               //Maintain waist height.
            AddReward(-.1f * Mathf.Abs(finishBall.transform.position.y - waist.transform.position.y));
            Done();
        }

        // Signed shortest angle from the reference yaw. Comparing raw eulerAngles.y
        // values breaks when the waist crosses the 0/360 degree seam.
        float yawDrift = Mathf.DeltaAngle(waistRot.y, waist.transform.eulerAngles.y);
        if (Mathf.Abs(yawDrift) > 25) {                //Maintain waist rotation on Y
            AddReward(-.1f * Mathf.Abs(yawDrift));
            Done();
        }

        // NOTE(review): the thresholds below compare world-space heights against 25 units; that looks
        // like the 25-degree yaw threshold carried over — confirm against the rig's scale.
        float pitchGap = wFront.transform.position.y - wBack.transform.position.y;
        if (pitchGap < -25) {                //Maintain waist rotation forward and backwards.
            AddReward(-.1f * Mathf.Abs(pitchGap));
            Done();
        }
        if (pitchGap > 25) {                //Maintain waist rotation forward and backwards.
            AddReward(-.1f * Mathf.Abs(pitchGap));
            Done();
        }

        float rollGap = buttR.transform.position.y - buttL.transform.position.y;
        if (rollGap < -25) {                //Maintain waist rotation left and right.
            AddReward(-.1f * Mathf.Abs(rollGap));
            Done();
        }
        if (rollGap > 25) {                //Maintain waist rotation left and right.
            AddReward(-.1f * Mathf.Abs(rollGap));
            Done();
        }

        /*
        if (waist.transform.position.x > posStart[0].x + 10 || waist.transform.position.x < posStart[0].x - 10 || waist.transform.position.z > posStart[0].z + 10 || waist.transform.position.z < posStart[0].z - 10) {              //Maintain waist position.
            AddReward(-.01f);
            Done();
        }
        */
        //Reward SYSTEM #####################################################################################################################################################################
    }

    /// <summary>
    /// Emits the observation vector: for each of the 11 body parts its position,
    /// euler angles, velocity, angular velocity and hinge window (14 floats),
    /// followed by the waist balance markers, the reference waist rotation and
    /// the target ball position (14 floats) — 168 floats in total.
    /// </summary>
    public override void CollectObservations() {

        for (int i = 0; i < bodyParts.Length; i++) {
            Transform partTransform = bodyParts[i].transform;
            Rigidbody partBody = rigidParts[i];

            AddVectorObs(partTransform.position);
            AddVectorObs(partTransform.eulerAngles);
            AddVectorObs(partBody.velocity);
            AddVectorObs(partBody.angularVelocity);
            AddVectorObs(jntLimParts[i].max);
            AddVectorObs(jntLimParts[i].min);
        }

        AddVectorObs(wFront.transform.position.y);
        AddVectorObs(wFront.transform.eulerAngles);
        AddVectorObs(wBack.transform.position.y);
        AddVectorObs(wBack.transform.eulerAngles);
        AddVectorObs(waistRot);             //Waist rotation recorded at the last reset.
        AddVectorObs(finishBall.transform.position);             //Target ball position.
    }

    /// <summary>
    /// Maps a 7-way discrete action to a signed step: 0 is no movement, odd
    /// values are positive, even values negative, at 1x, 2x and 4x
    /// <see cref="bodyMoveSensitivity"/>.
    /// </summary>
    /// <param name="action">Discrete action index.</param>
    /// <param name="previous">Value returned when <paramref name="action"/> is outside 0-6.</param>
    private float StepForAction(int action, float previous) {
        switch (action) {
            case 0:
                return 0;
            case 1:
                return bodyMoveSensitivity;
            case 2:
                return -bodyMoveSensitivity;
            case 3:
                return bodyMoveSensitivity * 2;
            case 4:
                return -bodyMoveSensitivity * 2;
            case 5:
                return bodyMoveSensitivity * 4;
            case 6:
                return -bodyMoveSensitivity * 4;
            default:
                return previous;
        }
    }

    /// <summary>
    /// Maps a 3-way discrete action to a push direction: 0 is none, 1 is -1, 2 is +1.
    /// </summary>
    /// <param name="action">Discrete action index.</param>
    /// <param name="previous">Value returned when <paramref name="action"/> is outside 0-2.</param>
    private static float PushForAction(int action, float previous) {
        switch (action) {
            case 0:
                return 0;
            case 1:
                return -1;
            case 2:
                return 1;
            default:
                return previous;
        }
    }

    /// <summary>
    /// Decodes one joint branch and slides that hinge's limit window by the
    /// resulting step. When the window has reached either end of its allowed
    /// travel it is pulled back just inside the range instead, so the next
    /// decision can move it again.
    /// </summary>
    /// <param name="index">Body-part index of the hinge (1-10).</param>
    /// <param name="rawAction">Action value for this branch as delivered by ML-Agents.</param>
    /// <param name="shownAction">Inspector field that receives the decoded action.</param>
    /// <param name="step">Field holding this joint's step; updated in place.</param>
    private void MoveJoint(int index, float rawAction, ref float shownAction, ref float step) {
        int action = (int)rawAction;
        shownAction = action;
        step = StepForAction(action, step);

        float lower = JointLower[index];
        float upper = JointUpper[index];
        float max = jntLimParts[index].max;

        if (max < upper && jntLimParts[index].min > lower) {
            max += step;
        }
        else if (jntLimParts[index].min <= lower) {
            // lower + 2 leaves min at lower + 1, strictly inside the range.
            max = lower + 2;
        }
        else {
            max = upper - 1;
        }

        // Always push the window to the hinge so the observed limits match the physical ones.
        SetJointWindow(index, max);
    }

    /// <summary>
    /// Sets a hinge's limit window to [max - 1, max] and applies it to the joint.
    /// </summary>
    private void SetJointWindow(int index, float max) {
        jntLimParts[index].max = max;
        jntLimParts[index].min = max - 1;
        hingeParts[index].limits = jntLimParts[index];
    }
}

1 Ответ

1 голос
/ 10 февраля 2020

Я считаю, что да: всё, что вам нужно сделать, — это иметь несколько экземпляров префаба. Пока в сцене находится несколько областей (Area), они должны иметь возможность координировать свои батчи для обучения.

Если вы хотите измерить, как наличие нескольких областей влияет на результат, я бы оставил одну область, дал ей поработать некоторое время и посмотрел на график совокупного вознаграждения в зависимости от количества эпизодов — до какого значения он поднимается. Затем сделайте то же самое с несколькими областями и сравните, как выглядит тот же график.

...