ToM
Requirments
numpy
scipy
pytorch >= 1.7.0
torchvision
pygame
Run
Train
the file to be run: main_both.py
args:
the path to save net_NPC: –save_net_N
the path to save net_a: –save_net_a
time steps: –T
python main_both.py --save_net_N=net_NPC.pth --save_net_a=net_agent.pth --episodes=45 --trajectories=30 --T=50 --mode=train --task=both
Test
python main_ToM.py --save_net_N=net_NPC.pth --save_net_a=net_agent.pth --episodes=45 --trajectories=30 --T=50 --mode=train --task=both