Skip to content

Commit e92a301

Browse files
authored
refactor load method to allow for easier mocking with pyfakefs (#57)
1 parent 28018a3 commit e92a301

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

orgparse/__init__.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -106,20 +106,19 @@
106106
"""
107107
# [[[end]]]
108108

109-
import codecs
109+
from io import IOBase
110110
from pathlib import Path
111111
from typing import Iterable, Union, Optional, TextIO
112112

113113

114-
from .node import parse_lines, OrgEnv, OrgNode # todo basenode??
115-
from .utils.py3compat import basestring
114+
from .node import parse_lines, OrgEnv, OrgNode # todo basenode??
116115

117116
__author__ = 'Takafumi Arakaki, Dmitrii Gerasimov'
118117
__license__ = 'BSD License'
119118
__all__ = ["load", "loads", "loadi"]
120119

121120

122-
def load(path: Union[str, Path, TextIO], env: Optional[OrgEnv]=None) -> OrgNode:
121+
def load(path: Union[str, Path, TextIO], env: Optional[OrgEnv] = None) -> OrgNode:
123122
"""
124123
Load org-mode document from a file.
125124
@@ -129,17 +128,24 @@ def load(path: Union[str, Path, TextIO], env: Optional[OrgEnv]=None) -> OrgNode:
129128
:rtype: :class:`orgparse.node.OrgRootNode`
130129
131130
"""
132-
orgfile: TextIO
133-
if isinstance(path, (str, Path)):
134-
# Use 'with' to close the file inside this function.
135-
with codecs.open(str(path), encoding='utf8') as orgfile:
136-
lines = (l.rstrip('\n') for l in orgfile.readlines())
137-
filename = str(path)
138-
else:
139-
orgfile = path
140-
lines = (l.rstrip('\n') for l in orgfile.readlines())
141-
filename = path.name if hasattr(path, 'name') else '<file-like>'
142-
return loadi(lines, filename=filename, env=env)
131+
# Make sure it is a Path object.
132+
if isinstance(path, str):
133+
path = Path(path)
134+
135+
# if it is a Path
136+
if isinstance(path, Path):
137+
# open that Path
138+
with path.open('r', encoding='utf8') as orgfile:
139+
# try again loading
140+
return load(orgfile, env)
141+
142+
# We assume it is a file-like object (e.g. io.StringIO)
143+
all_lines = (line.rstrip('\n') for line in path)
144+
145+
# get the filename
146+
filename = path.name if hasattr(path, 'name') else '<file-like>'
147+
148+
return loadi(all_lines, filename=filename, env=env)
143149

144150

145151
def loads(string: str, filename: str='<string>', env: Optional[OrgEnv]=None) -> OrgNode:

0 commit comments

Comments
 (0)