Skip to content

Commit 7d5340d

Browse files
committed
Update plot_contributions.py
1 parent 8c4d895 commit 7d5340d

File tree

1 file changed

+22
-9
lines changed

1 file changed

+22
-9
lines changed

scripts/plot_contributions.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import matplotlib.pyplot as plt
66
import numpy as np
77

8+
from utils import style_dict
9+
810
fig_width = 6.0 # inches
911
fig_height = 2.1
1012

@@ -14,6 +16,7 @@
1416
def parse_args() -> argparse.Namespace:
1517
parser = argparse.ArgumentParser()
1618
parser.add_argument('--path', help='path to XYZ configurations', required=True)
19+
parser.add_argument('--dft',help='path to dft XYZ configurations')
1720
return parser.parse_args()
1821

1922

@@ -22,26 +25,36 @@ def main():
2225
contributions = np.array(
2326
[atoms.info['contributions'] for atoms in ase.io.read(args.path, format='extxyz', index=':')]) # [c, e]
2427
contributions = contributions.transpose() # [e, c]
25-
energies = np.array([atoms.info['energy'] for atoms in ase.io.read(args.path, format='extxyz', index=':')]) # [c]
28+
energies = np.array([atoms.info['energy'] for atoms in ase.io.read(args.path, format='extxyz', index=':')]) # [e]
2629
energies = np.expand_dims(energies, axis=0) # [1, c]
30+
displacement = np.array([atoms.info['displacement'] for atoms in ase.io.read(args.path, format='extxyz', index=':')])
31+
dft_energy = np.array(
32+
[atoms.info['energy'] for atoms in ase.io.read(args.dft, format='extxyz', index=':')]) # [e]
2733

28-
array = np.concatenate([energies, contributions[1:]], axis=0)
34+
array = np.concatenate([energies, contributions[:]], axis=0)
2935

3036
# Plot curve
31-
fig, axes = plt.subplots(nrows=1, ncols=array.shape[0], figsize=(fig_width, fig_height), constrained_layout=True)
37+
fig, axes = plt.subplots(nrows=1,
38+
ncols=array.shape[0],
39+
sharey='row',
40+
figsize=(fig_width, fig_height),
41+
constrained_layout=True)
3242

3343
for i, (ax, energies) in enumerate(zip(axes, array)):
34-
e_min = np.min(energies)
35-
ax.plot(energies - e_min, color='black')
44+
e_shift = energies[-1]
45+
ax.plot(displacement, energies - e_shift, **style_dict['botnet'])
3646
if i == 0:
37-
ax.set_title(r'$E_\mathrm{tot}$' + f' - ({e_min:.3f})')
47+
ax.plot(displacement, dft_energy - e_shift, **style_dict['dft'])
48+
ax.set_title(r'$E_\mathrm{tot}$' + f' - ({e_shift:.3f} eV)')
3849
else:
39-
ax.set_title(rf'$E_{i}$' + f' - ({e_min:.3f})')
50+
j = i-1
51+
ax.set_title(rf'$E_{j}$' + f' - ({e_shift:.3f} eV)')
4052

4153
axes[0].set_ylabel(r'$E$ [eV]')
42-
# ax.legend()
54+
axes[0].set_xlabel('Displacement')
55+
axes[0].legend(bbox_to_anchor=(1.04,1), loc="upper left")
4356

44-
plt.show()
57+
fig.savefig('contributions.pdf')
4558

4659

4760
if __name__ == '__main__':

0 commit comments

Comments
 (0)