A lightweight implementation of REINFORCE (policy gradient) for a continuous action space using TensorFlow / Keras. The policy network outputs an action mean and optionally a learned variance (or std), and the agent updates parameters using the Monte Carlo return.
This repo is intended as a minimal research/educational baseline you can adapt to your own environments.
- ✅ Continuous-action REINFORCE agent
- ✅ Policy network with configurable MLP backbone
- ✅ Optional learned variance (stochastic policy with trainable uncertainty)
- ✅ Uses TensorFlow Probability (`tfd.Normal`) for stable log-prob computation
- ✅ Simple training script + learning-curve plotting
```
.
├── network_con.py    # ConPolicyGrad policy network (mu + optional var/std)
├── con_reinforce.py  # Agent (REINFORCE) with trajectory memory + update
├── con_main.py       # Training loop (example)
└── utils.py          # plot_learning() helper
```
If your filenames differ, update the import paths in `con_main.py`.
- Python 3.9+
- TensorFlow 2.x
- TensorFlow Probability
- NumPy
- Matplotlib (for plots)
Install:

```bash
pip install tensorflow tensorflow-probability numpy matplotlib
```

Run training:

```bash
python con_main.py
```

This will generate plots:
- `score.png`: running-average episode reward
- `mu.png`: running-average mean estimate
- `sigma.png`: running-average std/variance behavior (if enabled)
The policy is a simple MLP:
- Two hidden layers (ReLU)
- Output head for the mean `mu`
- Optional output head for variance/std (kept positive via Softplus)
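To make the architecture concrete, here is a framework-free NumPy sketch of the forward pass. It is not the repo's Keras model: the helper names (`init_params`, `policy_forward`) and the default layer sizes are hypothetical, and only the layout (two ReLU layers, a linear `mu` head, a Softplus variance head) follows the description above.

```python
import numpy as np

def softplus(x):
    # Numerically stable softplus: log(1 + exp(x)), always > 0 for moderate x
    return np.logaddexp(0.0, x)

def init_params(input_dims, fc1_dims=64, fc2_dims=64, n_actions=1, seed=0):
    """Random weights for a two-hidden-layer MLP with a mu head and a var head.
    Layer sizes are illustrative defaults, not the repo's values."""
    rng = np.random.default_rng(seed)

    def dense(n_in, n_out):
        return rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_in, n_out)), np.zeros(n_out)

    return {
        "fc1": dense(input_dims, fc1_dims),
        "fc2": dense(fc1_dims, fc2_dims),
        "mu": dense(fc2_dims, n_actions),
        "var": dense(fc2_dims, n_actions),
    }

def policy_forward(params, state, learn_var=True):
    """state: (batch, input_dims). Returns mu, or (mu, var) when learn_var=True."""
    W1, b1 = params["fc1"]
    W2, b2 = params["fc2"]
    h = np.maximum(state @ W1 + b1, 0.0)  # hidden layer 1 (ReLU)
    h = np.maximum(h @ W2 + b2, 0.0)      # hidden layer 2 (ReLU)
    Wm, bm = params["mu"]
    mu = h @ Wm + bm                      # linear head for the action mean
    if not learn_var:
        return mu
    Wv, bv = params["var"]
    var = softplus(h @ Wv + bv)           # Softplus keeps the variance positive
    return mu, var
```

Both heads share the hidden layers, so switching `learn_var` off only skips the variance head.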
When `learn_var=True`, the network returns `(mu, var)`; otherwise it returns `mu` alone.

The agent stores an episode trajectory:
- states
- actions
- rewards
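A per-episode buffer along these lines would hold the trajectory. This is a hypothetical sketch: `EpisodeMemory`, `as_arrays`, and `clear` are illustrative names, and only `store_transition` mirrors the call used in the integration pseudo-code further down. Entries are stacked into `(T, dim)` arrays so shapes stay consistent.

```python
import numpy as np

class EpisodeMemory:
    """Hypothetical per-episode buffer; the repo's Agent keeps its own storage."""

    def __init__(self):
        self.states, self.actions, self.rewards = [], [], []

    def store_transition(self, state, action, reward):
        # Keep every entry 1-D so stacking yields (T, dim) arrays
        self.states.append(np.atleast_1d(np.asarray(state, dtype=np.float32)))
        self.actions.append(np.atleast_1d(np.asarray(action, dtype=np.float32)))
        self.rewards.append(float(reward))

    def as_arrays(self):
        return (np.stack(self.states), np.stack(self.actions),
                np.asarray(self.rewards, dtype=np.float32))

    def clear(self):
        # REINFORCE is on-policy: discard the episode after each update
        self.states, self.actions, self.rewards = [], [], []
```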
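It then computes the Monte Carlo return from each time step $t$ of an episode with rewards $r_0, \dots, r_{T-1}$:

$$G_t = \sum_{k=t}^{T-1} \gamma^{\,k-t} r_k$$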
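The return for every step can be computed in a single backward pass using the recursion $G_t = r_t + \gamma G_{t+1}$ with $G_T = 0$. A minimal NumPy sketch (the function name `discounted_returns` is hypothetical):

```python
import numpy as np

def discounted_returns(rewards, gamma=0.99):
    """G_t = sum_{k=t}^{T-1} gamma^(k-t) * r_k, computed in one backward pass."""
    returns = np.zeros(len(rewards), dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running  # G_t = r_t + gamma * G_{t+1}
        returns[t] = running
    return returns
```

For example, `discounted_returns([1.0, 1.0, 1.0], gamma=0.5)` gives `[1.75, 1.5, 1.0]`. The backward loop is O(T) and, by construction, each $G_t$ only includes rewards from step $t$ onward.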
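Finally, it applies the REINFORCE update: gradient ascent on the return-weighted log-probabilities of the actions taken (equivalently, minimizing the loss $-\sum_t G_t \log \pi_\theta(a_t \mid s_t)$):

$$\theta \leftarrow \theta + \alpha \sum_{t=0}^{T-1} G_t \, \nabla_\theta \log \pi_\theta(a_t \mid s_t)$$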
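The update can be seen end to end on a toy problem. The sketch below is a NumPy illustration under stated assumptions, not the repo's training code: it uses one-step episodes (so $G_0 = r$), a state-free Gaussian policy with parameters `mu` and `log_std`, and a hypothetical quadratic `toy_reward` whose best action is `2.0`. In TensorFlow the gradient would typically come from differentiating the loss above with automatic differentiation; here the two partial derivatives of $\log \pi$ are written out by hand.

```python
import numpy as np

def gaussian_log_prob(action, mu, std):
    """log N(action; mu, std^2), the same density tfd.Normal(mu, std).log_prob uses.
    Shown for reference; the gradients in train() are its derivatives."""
    return -0.5 * np.log(2.0 * np.pi) - np.log(std) - 0.5 * ((action - mu) / std) ** 2

def toy_reward(action, target=2.0):
    # Hypothetical stand-in for a toy reward: maximized at action == target
    return -(action - target) ** 2

def train(n_iters=200, batch=100, alpha=0.05, learn_var=True, fixed_std=0.5,
          min_std=1e-3, seed=0):
    """REINFORCE with a state-free Gaussian policy on one-step episodes."""
    rng = np.random.default_rng(seed)
    mu, log_std = 0.0, np.log(fixed_std)
    for _ in range(n_iters):
        std = max(np.exp(log_std), min_std)        # floor the std for stability
        actions = mu + std * rng.standard_normal(batch)
        returns = toy_reward(actions)              # one-step episode: G = r
        z = (actions - mu) / std
        # Score-function terms: d log pi / d mu = z / std,
        # d log pi / d log_std = z^2 - 1
        grad_mu = np.mean(returns * z / std)
        grad_log_std = np.mean(returns * (z ** 2 - 1.0))
        mu += alpha * grad_mu                      # gradient *ascent* on return
        if learn_var:
            log_std += alpha * grad_log_std
    return mu, max(np.exp(log_std), min_std)
```

With `learn_var=True` the std tends to shrink as `mu` approaches the best action, which is one reason a floor such as `min_std` is useful.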
The log-probability is computed using:
```python
log_prob = tfd.Normal(loc=mu, scale=std).log_prob(action)  # tfd = tensorflow_probability.distributions
```

You can customize training via:
- `learn_var`: learn the policy variance (`True`/`False`)
- `fixed_std`: constant std used when `learn_var=False`
- `gamma`: discount factor
- `alpha`: learning rate
- layer sizes (`fc1_dims`, `fc2_dims`, etc.)
Example:

```python
from con_reinforce import Agent  # Agent is defined in con_reinforce.py

agent = Agent(
    alpha=3e-3,
    gamma=0.99,
    learn_var=True,
    fixed_std=0.1,  # only used when learn_var=False
)
```

Notes:
- Returns computation: make sure each return is computed from time `t` onward (not always from 0).
- Shape consistency: use consistent shapes for states/actions, ideally `(batch, dim)` for network inputs.
- Variance stability: enforce a minimum variance/std to avoid numerical issues (e.g., `min_std=1e-3`).
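The shape and variance notes above can be handled with two small helpers. This NumPy sketch shows one reasonable approach (Softplus plus a constant floor), not the repo's code; `as_batch` and `positive_std` are hypothetical names.

```python
import numpy as np

def as_batch(state):
    """Promote a single observation (dim,) or a scalar to (1, dim); leave (batch, dim) as-is."""
    return np.atleast_2d(np.asarray(state, dtype=np.float32))

def positive_std(raw, min_std=1e-3):
    """Map an unconstrained head output to a std that never drops below min_std."""
    return np.logaddexp(0.0, raw) + min_std  # softplus(raw) + floor
```

Adding the floor after the Softplus (rather than clipping) keeps the mapping smooth, so gradients still flow for very negative head outputs.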
This repo currently uses a toy reward function. To integrate with Gym / custom environment:
- Replace `reward()` with `env.step(action)`
- Store transitions per step
- Call `learn()` after each episode
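Pseudo-code:

```python
state = env.reset()  # Gymnasium >= 0.26 returns (state, info): state, _ = env.reset()
done = False
while not done:
    action = agent.choose_action(state)
    # Classic Gym 4-tuple; Gymnasium returns
    # (next_state, reward, terminated, truncated, info), with done = terminated or truncated
    next_state, reward, done, _ = env.step(action)
    agent.store_transition(state, action, reward)  # record one step of the trajectory
    state = next_state
agent.learn()  # one REINFORCE update per completed episode
```

MIT License (recommended for reuse).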
For questions, suggestions, or collaboration inquiries, please open a GitHub issue.
For direct communication, please email: [email protected]