Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vaspy/atomco.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def get_poscar_content(self, **kwargs):
tf = self.tf
except AttributeError:
# Initialize tf with 'T's.
default_tf = np.full(self.data.shape, 'T', dtype=np.str)
default_tf = np.full(self.data.shape, 'T', dtype=str)
tf = kwargs.get("tf", default_tf)
data_tf = ''
if coord_type == 'direct':
Expand Down
120 changes: 120 additions & 0 deletions vaspy/tests/electro_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# -*- coding:utf-8 -*-
'''
Unit tests for vaspy.electro module.
'''

import unittest
import os
import copy

import numpy as np

from ..electro import DosX, ElfCar, ChgCar
from . import path


class DosXTest(unittest.TestCase):

def setUp(self):
self.filename = os.path.join(path, "DOS_SUM")

def test_load(self):
dosx = DosX(self.filename)
self.assertIsNotNone(dosx.data)
self.assertGreater(dosx.data.shape[0], 0)

def test_reset_data(self):
dosx = DosX(self.filename)
dosx.reset_data()
self.assertTrue(np.all(dosx.data[:, 1:] == 0.0))

def test_add(self):
dosx1 = DosX(self.filename)
dosx2 = DosX(self.filename)
dos_sum = dosx1 + dosx2
self.assertEqual(dos_sum.filename, "DOS_SUM")

def test_deepcopy(self):
dosx = DosX(self.filename)
dosx_copy = copy.deepcopy(dosx)
self.assertTrue(np.all(dosx.data == dosx_copy.data))
self.assertIsNot(dosx.data, dosx_copy.data)

def test_tofile(self):
dosx = DosX(self.filename)
outfile = os.path.join(path, "_test_dos_output.txt")
try:
dosx.tofile(filename=outfile)
self.assertTrue(os.path.exists(outfile))
finally:
if os.path.exists(outfile):
os.remove(outfile)

def test_get_dband_center(self):
dosx = DosX(self.filename)
dbc = dosx.get_dband_center(d_cols=(5, 10))
self.assertIsNotNone(dbc)
self.assertEqual(dosx.dband_center, dbc)

def test_get_dband_center_int_arg(self):
dosx = DosX(self.filename)
dbc = dosx.get_dband_center(d_cols=5)
self.assertIsNotNone(dbc)

def test_add_mismatched_energy_raises(self):
dosx1 = DosX(self.filename)
dosx2 = DosX(self.filename)
dosx2.data[0, 0] = 999.0
with self.assertRaises(ValueError):
dosx1 + dosx2


class ElfCarTest(unittest.TestCase):

def setUp(self):
self.filename = os.path.join(path, "ELFCAR")

def test_load(self):
elf = ElfCar(self.filename)
self.assertIsNotNone(elf.elf_data)
self.assertEqual(len(elf.elf_data.shape), 3)
self.assertIsNotNone(elf.grid)

def test_expand_data(self):
elf = ElfCar(self.filename)
expanded_data, expanded_grid = elf.expand_data(elf.elf_data, elf.grid, (2, 1, 1))
self.assertEqual(expanded_data.shape[0], elf.elf_data.shape[0] * 2)
self.assertEqual(expanded_grid[0], elf.grid[0] * 2)

def test_contour_bad_distance(self):
elf = ElfCar(self.filename)
with self.assertRaises(ValueError):
elf.plot_contour(distance=1.5)

def test_contour_bad_show_mode(self):
elf = ElfCar(self.filename)
with self.assertRaises(ValueError):
elf.plot_contour(show_mode='bad')

def test_contour_cut_x(self):
elf = ElfCar(self.filename)
elf.plot_contour(axis_cut='x', show_mode='save')

def test_contour_cut_y(self):
elf = ElfCar(self.filename)
elf.plot_contour(axis_cut='y', show_mode='save')

def test_contour_cut_z(self):
elf = ElfCar(self.filename)
elf.plot_contour(axis_cut='z', show_mode='save')


class ChgCarTest(unittest.TestCase):

def setUp(self):
self.filename = os.path.join(path, "ELFCAR")

def test_init(self):
chg = ChgCar(self.filename)
self.assertIsNotNone(chg.elf_data)
self.assertIsNotNone(chg.grid)
28 changes: 28 additions & 0 deletions vaspy/tests/elements_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# -*- coding:utf-8 -*-
'''
Unit tests for vaspy.elements module.
'''

import unittest

from .. import elements


class ElementsTest(unittest.TestCase):

def test_C12(self):
self.assertAlmostEqual(elements.C12, 1.99264648e-26)

def test_amu(self):
self.assertAlmostEqual(elements.amu, 1.66053904e-27)

def test_chem_elements_has_H(self):
self.assertIn('H', elements.chem_elements)
self.assertEqual(elements.chem_elements['H']['index'], 1)

def test_chem_elements_has_Ni(self):
self.assertIn('Ni', elements.chem_elements)
self.assertEqual(elements.chem_elements['Ni']['index'], 28)

def test_chem_elements_count(self):
self.assertEqual(len(elements.chem_elements), 9)
30 changes: 30 additions & 0 deletions vaspy/tests/errors_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# -*- coding:utf-8 -*-
'''
Unit tests for vaspy.errors module.
'''

import unittest

from ..errors import CarfileValueError, UnmatchedDataShape


class CarfileValueErrorTest(unittest.TestCase):

def test_raise(self):
with self.assertRaises(CarfileValueError):
raise CarfileValueError("test error")

def test_message(self):
err = CarfileValueError("bad value")
self.assertEqual(str(err), "bad value")


class UnmatchedDataShapeTest(unittest.TestCase):

def test_raise(self):
with self.assertRaises(UnmatchedDataShape):
raise UnmatchedDataShape("shape mismatch")

def test_message(self):
err = UnmatchedDataShape("shape mismatch")
self.assertEqual(str(err), "shape mismatch")
114 changes: 114 additions & 0 deletions vaspy/tests/functions_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# -*- coding:utf-8 -*-
'''
Unit tests for vaspy.functions module.
'''

import unittest
import numpy as np

from ..functions import (str2list, line2list, array2str,
combine_atomco_dict, atomdict2str,
get_combinations, get_angle)


class Str2listTest(unittest.TestCase):

def test_str2list(self):
result = str2list(' 1.0 2.0 3.0 ')
self.assertListEqual(result, ['1.0', '2.0', '3.0'])

def test_str2list_empty(self):
result = str2list('')
self.assertEqual(result, [])


class Line2listTest(unittest.TestCase):

def test_line2list_float(self):
result = line2list('1.0 2.0 3.0', dtype=float)
self.assertListEqual(result, [1.0, 2.0, 3.0])

def test_line2list_int(self):
result = line2list('10 20 30', dtype=int)
self.assertListEqual(result, [10, 20, 30])

def test_line2list_str(self):
result = line2list('a b c', dtype=str)
self.assertListEqual(result, ['a', 'b', 'c'])

def test_line2list_custom_field(self):
result = line2list('1.0,2.0,3.0', field=',', dtype=float)
self.assertListEqual(result, [1.0, 2.0, 3.0])

def test_line2list_empty_elements(self):
result = line2list(' 1.0 2.0 ', dtype=float)
self.assertListEqual(result, [1.0, 2.0])

def test_line2list_type_error(self):
with self.assertRaises(TypeError):
line2list('1.0 2.0', dtype=3.14)


class Array2strTest(unittest.TestCase):

def test_array2str(self):
arr = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
result = array2str(arr)
self.assertIn('1.0000000000000000', result)
self.assertIn('2.0000000000000000', result)
self.assertEqual(result.count('\n'), 2)


class CombineAtomcoDictTest(unittest.TestCase):

def test_combine_disjoint(self):
a = {'C': [[1.0, 2.0, 3.0]]}
b = {'O': [[4.0, 5.0, 6.0]]}
result = combine_atomco_dict(a, b)
self.assertEqual(set(result.keys()), {'C', 'O'})

def test_combine_overlap(self):
a = {'C': [[1.0, 2.0, 3.0]]}
b = {'C': [[4.0, 5.0, 6.0]]}
result = combine_atomco_dict(a, b)
self.assertEqual(len(result['C']), 2)

def test_combine_empty(self):
result = combine_atomco_dict({}, {})
self.assertEqual(result, {})


class Atomdict2strTest(unittest.TestCase):

def test_atomdict2str(self):
d = {'C': [[2.01115823704755, 2.33265069974919, 10.54948252493041]],
'Co': [[0.28355818414485, 2.31976779057375, 2.34330019781397],
[2.76900337448991, 0.88479534087197, 2.34330019781397]]}
result = atomdict2str(d, ['C', 'Co'])
self.assertIn('C', result)
self.assertIn('Co', result)
self.assertEqual(result.count('\n'), 3)


class GetCombinationsTest(unittest.TestCase):

def test_get_combinations(self):
result = get_combinations(3, 4, 5)
self.assertIsInstance(result, np.ndarray)


class GetAngleTest(unittest.TestCase):

def test_get_angle_90(self):
v1 = np.array([1.0, 0.0, 0.0])
v2 = np.array([0.0, 1.0, 0.0])
self.assertAlmostEqual(get_angle(v1, v2), 90.0)

def test_get_angle_0(self):
v1 = np.array([1.0, 0.0, 0.0])
self.assertAlmostEqual(get_angle(v1, v1), 0.0)

def test_get_angle_180(self):
v1 = np.array([1.0, 0.0, 0.0])
v2 = np.array([-1.0, 0.0, 0.0])
self.assertAlmostEqual(get_angle(v1, v2), 180.0)
28 changes: 28 additions & 0 deletions vaspy/tests/plotter_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# -*- coding:utf-8 -*-
'''
Unit tests for vaspy.plotter module.
'''

import unittest
import os

from ..plotter import DataPlotter
from . import path


class DataPlotterTest(unittest.TestCase):

def setUp(self):
self.filename = os.path.join(path, "PLOTCON")

def test_load(self):
plotter = DataPlotter(self.filename)
self.assertIsNotNone(plotter.data)
self.assertGreater(plotter.data.shape[0], 0)
self.assertGreater(plotter.data.shape[1], 0)

def test_attributes(self):
plotter = DataPlotter(self.filename)
self.assertEqual(plotter.filename, self.filename)
self.assertEqual(plotter.field, ' ')
self.assertEqual(plotter.dtype, float)
10 changes: 10 additions & 0 deletions vaspy/tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,14 @@
from .cif_test import CifFileTest
from .ani_test import AniFileTest
from .xdatcar_test import XdatCarTest
from .functions_test import (Str2listTest, Line2listTest, Array2strTest,
CombineAtomcoDictTest, Atomdict2strTest,
GetCombinationsTest, GetAngleTest)
from .plotter_test import DataPlotterTest
from .electro_test import DosXTest, ElfCarTest, ChgCarTest
from .elements_test import ElementsTest
from .errors_test import CarfileValueErrorTest, UnmatchedDataShapeTest

if __name__ == '__main__':
unittest.main()

Loading