506 lines
18 KiB
Python
506 lines
18 KiB
Python
|
"""
|
||
|
lxml-based doctest output comparison.

Note: normally, you should just import the `lxml.usedoctest` and
`lxml.html.usedoctest` modules from within a doctest, instead of this
one::

    >>> import lxml.usedoctest # for XML output

    >>> import lxml.html.usedoctest # for HTML output

To use this module directly, you must call its ``install()`` function,
which will cause doctest to use this checker in all subsequent calls.

This changes the way output is checked and comparisons are made for
XML or HTML-like content.

XML or HTML content is noticed because the example starts with ``<``
(it's HTML if it starts with ``<html``). You can also use the
``PARSE_HTML`` and ``PARSE_XML`` flags to force parsing.

Some rough wildcard-like things are allowed. Whitespace is generally
ignored (except in attributes). In text (attributes and text in the
body) you can use ``...`` as a wildcard. In an example it also
matches any trailing tags in the element, though it does not match
leading tags. You may create a tag ``<any>`` or include an ``any``
attribute in the tag. An ``any`` tag matches any tag, while the
attribute matches any and all attributes.

When a match fails, the reformatted example and the actual output are
displayed (indented), and a rough diff-like output is given. Anything
marked with ``-`` is in the output but wasn't supposed to be, and
similarly ``+`` means it's in the example but wasn't in the output.

You can disable parsing on one line with ``# doctest:+NOPARSE_MARKUP``.
|
||
|
"""
|
||
|
|
||
|
from lxml import etree
|
||
|
import sys
|
||
|
import re
|
||
|
import doctest
|
||
|
import cgi
|
||
|
|
||
|
# Public API of this module: the option flags, the two checker classes
# and the two installation helpers.
__all__ = [
    'PARSE_HTML',
    'PARSE_XML',
    'NOPARSE_MARKUP',
    'LXMLOutputChecker',
    'LHTMLOutputChecker',
    'install',
    'temp_install',
]
|
||
|
|
||
|
try:
|
||
|
_basestring = basestring
|
||
|
except NameError:
|
||
|
_basestring = (str, bytes)
|
||
|
|
||
|
_IS_PYTHON_3 = sys.version_info[0] >= 3
|
||
|
|
||
|
# Register the custom doctest option flags understood by the checkers,
# in the order PARSE_HTML, PARSE_XML, NOPARSE_MARKUP; each call returns
# the flag's bit value.
PARSE_HTML, PARSE_XML, NOPARSE_MARKUP = (
    doctest.register_optionflag(flag_name)
    for flag_name in ('PARSE_HTML', 'PARSE_XML', 'NOPARSE_MARKUP')
)

# Keep a reference to the stock checker class: install() later rebinds
# ``doctest.OutputChecker`` to one of the lxml checkers.
OutputChecker = doctest.OutputChecker
|
||
|
|
||
|
def strip(v):
    """Return ``v.strip()``, passing ``None`` through unchanged."""
    return None if v is None else v.strip()
|
||
|
|
||
|
def norm_whitespace(v):
    """Collapse each run of two or more spaces/tabs/newlines in *v* to one space.

    Single whitespace characters are left as they are.
    """
    collapsed = _norm_whitespace_re.sub(' ', v)
    return collapsed
|
||
|
|
||
|
# Shared strict HTML parser: parse errors are raised instead of being
# recovered from, and blank text nodes are dropped.
_html_parser = etree.HTMLParser(recover=False, remove_blank_text=True)


def html_fromstring(html):
    """Parse the string *html* with the module's strict HTML parser."""
    return etree.fromstring(html, parser=_html_parser)
|
||
|
|
||
|
# We use this to distinguish repr()s from elements:
|
||
|
_repr_re = re.compile(r'^<[^>]+ (at|object) ')
|
||
|
_norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
|
||
|
|
||
|
class LXMLOutputChecker(OutputChecker):
    """doctest output checker that compares XML/HTML output structurally.

    When both the expected and the actual output look like markup (or a
    PARSE_* flag forces it), both are parsed and compared element by
    element with the wildcard rules of ``compare_docs``; otherwise the
    stock ``doctest.OutputChecker`` behavior is used.
    """

    # Tags that html_empty_tag() treats as having no end tag when
    # formatting HTML for failure reports.
    empty_tags = (
        'param', 'img', 'area', 'br', 'basefont', 'input',
        'base', 'meta', 'link', 'col')

    def get_default_parser(self):
        """Return the parser used for auto-detected markup (XML here)."""
        return etree.XML

    def check_output(self, want, got, optionflags):
        """Return True if *got* matches the expected output *want*.

        NOTE: temp_install() swaps this function's code object into the
        running doctest checker's ``check_output`` function, so this body
        may execute with the ``doctest`` module's globals (temp_install
        injects ``etree`` there for that reason).  Keep the global names
        used here limited to ``etree`` and ``OutputChecker``.
        """
        alt_self = getattr(self, '_temp_override_self', None)
        if alt_self is not None:
            # Running via temp_install(): compare with the temporary lxml
            # checker, and reach the original check_output through the
            # _RestoreChecker hook.
            super_method = self._temp_call_super_check_output
            self = alt_self
        else:
            super_method = OutputChecker.check_output
        parser = self.get_parser(want, got, optionflags)
        if not parser:
            return super_method(
                self, want, got, optionflags)
        try:
            want_doc = parser(want)
        except etree.XMLSyntaxError:
            return False
        try:
            got_doc = parser(got)
        except etree.XMLSyntaxError:
            return False
        return self.compare_docs(want_doc, got_doc)

    def get_parser(self, want, got, optionflags):
        """Return the parsing function for this example, or None.

        ``NOPARSE_MARKUP`` disables parsing; ``PARSE_HTML`` / ``PARSE_XML``
        force the HTML / XML parser.  Otherwise the HTML parser is used
        when both texts start with ``<html`` (case-insensitively), and the
        default parser when both merely look like markup.
        """
        if NOPARSE_MARKUP & optionflags:
            return None
        if PARSE_HTML & optionflags:
            return html_fromstring
        if PARSE_XML & optionflags:
            return etree.XML
        # Lowercase both sides: previously only *want* was lowercased, so
        # output starting with e.g. "<HTML>" was never treated as HTML.
        if (want.strip().lower().startswith('<html')
                and got.strip().lower().startswith('<html')):
            return html_fromstring
        if self._looks_like_markup(want) and self._looks_like_markup(got):
            return self.get_default_parser()
        return None

    def _looks_like_markup(self, s):
        """Return True if *s* starts with '<' but is not an object repr."""
        s = s.strip()
        return (s.startswith('<')
                and not _repr_re.search(s))

    def compare_docs(self, want, got):
        """Recursively compare the expected element *want* with *got*.

        Tags, text, tails and attribute values are compared with
        ``tag_compare`` / ``text_compare``.  An ``any`` attribute on *want*
        skips the attribute comparison.  An expected element whose text is
        ``...`` and that has no children matches any children, and a
        ``...`` tail on the last expected child matches any further
        (trailing) children in *got*.  Leading children are never skipped.
        """
        if not self.tag_compare(want.tag, got.tag):
            return False
        if not self.text_compare(want.text, got.text, True):
            return False
        if not self.text_compare(want.tail, got.tail, True):
            return False
        if 'any' not in want.attrib:
            want_keys = sorted(want.attrib.keys())
            got_keys = sorted(got.attrib.keys())
            if want_keys != got_keys:
                return False
            for key in want_keys:
                if not self.text_compare(want.attrib[key], got.attrib[key], False):
                    return False
        if want.text != '...' or len(want):
            want_children = list(want)
            got_children = list(got)
            if len(got_children) < len(want_children):
                # Every expected child needs a counterpart in the output.
                return False
            for want_child, got_child in zip(want_children, got_children):
                if not self.compare_docs(want_child, got_child):
                    return False
            if len(got_children) > len(want_children):
                # Extra output children are only allowed after a trailing
                # "..." (the tail of the last expected child).
                if not want_children or strip(want_children[-1].tail) != '...':
                    return False
        return True

    def text_compare(self, want, got, strip):
        """Return True if the text *got* matches the expected text *want*.

        ``None`` counts as ''.  With *strip*, whitespace runs are collapsed
        and both ends stripped first.  ``...`` in *want* matches any text,
        including text spanning newlines; everything else must match
        literally and completely (a trailing newline in *got* is not
        ignored).
        """
        want = want or ''
        got = got or ''
        if strip:
            want = norm_whitespace(want).strip()
            got = norm_whitespace(got).strip()
        pattern = r'^%s\Z' % re.escape(want)
        pattern = pattern.replace(r'\.\.\.', '.*')
        # DOTALL so that a "..." wildcard can also swallow newlines.
        return re.search(pattern, got, re.DOTALL) is not None

    def tag_compare(self, want, got):
        """Return True if tag *got* matches expected tag *want*.

        ``any`` matches every tag, and a ``{...}`` namespace prefix matches
        any namespace.  Non-string tags (e.g. comments) compare by equality.
        """
        if want == 'any':
            return True
        if (not isinstance(want, _basestring)
                or not isinstance(got, _basestring)):
            return want == got
        want = want or ''
        got = got or ''
        if want.startswith('{...}'):
            # Ellipsis on the namespace
            return want.split('}')[-1] == got.split('}')[-1]
        else:
            return want == got

    def output_difference(self, example, got, optionflags):
        """Return the failure report for *example* given the output *got*.

        When no parser applies or parsing fails, the stock report is used
        (preceded by the parse errors, if any); otherwise the report shows
        the reformatted expected and actual documents and a rough diff.
        """
        want = example.want
        parser = self.get_parser(want, got, optionflags)
        errors = []
        if parser is not None:
            try:
                want_doc = parser(want)
            except etree.XMLSyntaxError:
                e = sys.exc_info()[1]
                errors.append('In example: %s' % e)
            try:
                got_doc = parser(got)
            except etree.XMLSyntaxError:
                e = sys.exc_info()[1]
                errors.append('In actual output: %s' % e)
        if parser is None or errors:
            value = OutputChecker.output_difference(
                self, example, got, optionflags)
            if errors:
                errors.append(value)
                return '\n'.join(errors)
            else:
                return value
        html = parser is html_fromstring
        diff_parts = []
        diff_parts.append('Expected:')
        diff_parts.append(self.format_doc(want_doc, html, 2))
        diff_parts.append('Got:')
        diff_parts.append(self.format_doc(got_doc, html, 2))
        diff_parts.append('Diff:')
        diff_parts.append(self.collect_diff(want_doc, got_doc, html, 2))
        return '\n'.join(diff_parts)

    def html_empty_tag(self, el, html=True):
        """Return True if *el* should be rendered without an end tag.

        Only applies in HTML mode, to tags listed in ``empty_tags`` that
        have neither text nor children.
        """
        if not html:
            return False
        if el.tag not in self.empty_tags:
            return False
        if el.text or len(el):
            # This shouldn't happen (contents in an empty tag)
            return False
        return True

    def format_doc(self, doc, html, indent, prefix=''):
        """Render *doc* as indented markup, one tag per line.

        *indent* is the number of leading spaces; *prefix* (e.g. '+' or
        '-' from collect_diff) is written just before the start tag.
        """
        parts = []
        if not len(doc):
            # No children...
            parts.append(' '*indent)
            parts.append(prefix)
            parts.append(self.format_tag(doc))
            if not self.html_empty_tag(doc, html):
                if strip(doc.text):
                    parts.append(self.format_text(doc.text))
                parts.append(self.format_end_tag(doc))
            if strip(doc.tail):
                parts.append(self.format_text(doc.tail))
            parts.append('\n')
            return ''.join(parts)
        parts.append(' '*indent)
        parts.append(prefix)
        parts.append(self.format_tag(doc))
        if not self.html_empty_tag(doc, html):
            parts.append('\n')
            if strip(doc.text):
                parts.append(' '*indent)
                parts.append(self.format_text(doc.text))
                parts.append('\n')
            for el in doc:
                parts.append(self.format_doc(el, html, indent+2))
            parts.append(' '*indent)
            parts.append(self.format_end_tag(doc))
            parts.append('\n')
        if strip(doc.tail):
            parts.append(' '*indent)
            parts.append(self.format_text(doc.tail))
            parts.append('\n')
        return ''.join(parts)

    def format_text(self, text, strip=True):
        """Return *text* escaped for display ('' for None).

        Escapes ``&``, ``<``, ``>`` and ``"`` -- the same set as the former
        ``cgi.escape(text, True)``, which no longer exists in Python 3.8+.
        """
        if text is None:
            return ''
        if strip:
            text = text.strip()
        # '&' must go first so the entities added below are not re-escaped.
        return (text.replace('&', '&amp;')
                .replace('<', '&lt;')
                .replace('>', '&gt;')
                .replace('"', '&quot;'))

    def format_tag(self, el):
        """Render the start tag of *el* with its attributes sorted by name."""
        attrs = []
        if isinstance(el, etree.CommentBase):
            # FIXME: probably PIs should be handled specially too?
            return '<!--'
        for name, value in sorted(el.attrib.items()):
            attrs.append('%s="%s"' % (name, self.format_text(value, False)))
        if not attrs:
            return '<%s>' % el.tag
        return '<%s %s>' % (el.tag, ' '.join(attrs))

    def format_end_tag(self, el):
        """Render the end tag of *el* ('-->' for comments)."""
        if isinstance(el, etree.CommentBase):
            # FIXME: probably PIs should be handled specially too?
            return '-->'
        return '</%s>' % el.tag

    def collect_diff(self, want, got, html, indent):
        """Render a rough diff between expected *want* and actual *got*.

        Mismatching parts are shown as ``expected (got: actual)``; children
        present only in the output are prefixed '-', children present only
        in the example are prefixed '+'.
        """
        parts = []
        if not len(want) and not len(got):
            parts.append(' '*indent)
            parts.append(self.collect_diff_tag(want, got))
            if not self.html_empty_tag(got, html):
                parts.append(self.collect_diff_text(want.text, got.text))
                parts.append(self.collect_diff_end_tag(want, got))
            parts.append(self.collect_diff_text(want.tail, got.tail))
            parts.append('\n')
            return ''.join(parts)
        parts.append(' '*indent)
        parts.append(self.collect_diff_tag(want, got))
        parts.append('\n')
        if strip(want.text) or strip(got.text):
            parts.append(' '*indent)
            parts.append(self.collect_diff_text(want.text, got.text))
            parts.append('\n')
        want_children = list(want)
        got_children = list(got)
        while want_children or got_children:
            if not want_children:
                parts.append(self.format_doc(got_children.pop(0), html, indent+2, '-'))
                continue
            if not got_children:
                parts.append(self.format_doc(want_children.pop(0), html, indent+2, '+'))
                continue
            parts.append(self.collect_diff(
                want_children.pop(0), got_children.pop(0), html, indent+2))
        parts.append(' '*indent)
        parts.append(self.collect_diff_end_tag(want, got))
        parts.append('\n')
        if strip(want.tail) or strip(got.tail):
            parts.append(' '*indent)
            parts.append(self.collect_diff_text(want.tail, got.tail))
            parts.append('\n')
        return ''.join(parts)

    def collect_diff_tag(self, want, got):
        """Render the start tag for the diff, marking attribute differences.

        Output-only attributes get '-', example-only attributes get '+'
        (neither when *want* is an ``any`` tag or has an ``any`` attribute).
        """
        if not self.tag_compare(want.tag, got.tag):
            tag = '%s (got: %s)' % (want.tag, got.tag)
        else:
            tag = got.tag
        attrs = []
        any = want.tag == 'any' or 'any' in want.attrib
        for name, value in sorted(got.attrib.items()):
            if name not in want.attrib and not any:
                attrs.append('-%s="%s"' % (name, self.format_text(value, False)))
            else:
                if name in want.attrib:
                    text = self.collect_diff_text(want.attrib[name], value, False)
                else:
                    text = self.format_text(value, False)
                attrs.append('%s="%s"' % (name, text))
        if not any:
            for name, value in sorted(want.attrib.items()):
                if name in got.attrib:
                    continue
                attrs.append('+%s="%s"' % (name, self.format_text(value, False)))
        if attrs:
            tag = '<%s %s>' % (tag, ' '.join(attrs))
        else:
            tag = '<%s>' % tag
        return tag

    def collect_diff_end_tag(self, want, got):
        """Render the end tag for the diff.

        Uses ``tag_compare`` (like collect_diff_tag) so that ``<any>`` and
        ``{...}`` tags are not reported as mismatches.
        """
        if not self.tag_compare(want.tag, got.tag):
            tag = '%s (got: %s)' % (want.tag, got.tag)
        else:
            tag = got.tag
        return '</%s>' % tag

    def collect_diff_text(self, want, got, strip=True):
        """Render text for the diff: *got* if it matches, else both."""
        if self.text_compare(want, got, strip):
            if not got:
                return ''
            return self.format_text(got, strip)
        text = '%s (got: %s)' % (want, got)
        return self.format_text(text, strip)
|
||
|
|
||
|
class LHTMLOutputChecker(LXMLOutputChecker):
    """Variant of LXMLOutputChecker that parses auto-detected markup as HTML."""

    def get_default_parser(self):
        """Return the HTML parsing function instead of the XML one."""
        default_parser = html_fromstring
        return default_parser
|
||
|
|
||
|
def install(html=False):
    """
    Make the lxml checker the default ``doctest.OutputChecker`` for all
    doctests run from now on.

    With a true *html*, LHTMLOutputChecker (HTML parser by default) is
    installed; otherwise LXMLOutputChecker (XML parser by default).
    """
    chosen_checker = LHTMLOutputChecker if html else LXMLOutputChecker
    doctest.OutputChecker = chosen_checker
|
||
|
|
||
|
def temp_install(html=False, del_module=None):
    """
    Use this *inside* a doctest to enable this checker for this
    doctest only.

    If html is true, then by default the HTML parser will be used;
    otherwise the XML parser is used.

    *del_module*, if given, is a dotted module name that is removed from
    ``sys.modules`` (and from its parent package) once the current doctest
    has recorded its outcome; see ``_RestoreChecker.uninstall_module``.

    Raises LookupError (from _find_doctest_frame) when not called from
    within a running doctest.
    """
    if html:
        Checker = LHTMLOutputChecker
    else:
        Checker = LXMLOutputChecker
    frame = _find_doctest_frame()
    # 'self' in that frame is presumably the DocTestRunner executing the
    # current doctest (it has ``_checker`` and ``__record_outcome``).
    dt_self = frame.f_locals['self']
    checker = Checker()
    old_checker = dt_self._checker
    dt_self._checker = checker
    # The unfortunate thing is that there is a local variable 'check'
    # in the function that runs the doctests, that is a bound method
    # into the output checker. We have to update that. We can't
    # modify the frame, so we have to modify the object in place. The
    # only way to do this is to actually change the func_code
    # attribute of the method. We change it, and then wait for
    # __record_outcome to be run, which signals the end of the __run
    # method, at which point we restore the previous check_output
    # implementation.
    if _IS_PYTHON_3:
        check_func = frame.f_locals['check'].__func__
        checker_check_func = checker.check_output.__func__
    else:
        check_func = frame.f_locals['check'].im_func
        checker_check_func = checker.check_output.im_func
    # Because we can't patch up func_globals, this is the only global
    # in check_output that we care about:
    doctest.etree = etree
    # The instance does not need to be stored: it installs itself as
    # dt_self's __record_outcome, which keeps it alive until it runs.
    _RestoreChecker(dt_self, old_checker, checker,
                    check_func, checker_check_func,
                    del_module)
|
||
|
|
||
|
class _RestoreChecker(object):
    """Temporarily redirect a running doctest to the lxml checker.

    On construction it gives ``check_func`` (the function behind the
    doctest runner's ``check`` local) the code of ``clone_func`` (the lxml
    checker's check_output) and installs itself as the runner's
    ``__record_outcome``.  When the runner then calls it in that role,
    everything is restored before forwarding to the original
    ``__record_outcome``.
    """

    def __init__(self, dt_self, old_checker, new_checker, check_func, clone_func,
                 del_module):
        # dt_self: the runner object found by _find_doctest_frame().
        # old_checker: the checker the runner used before temp_install().
        # new_checker: the lxml checker that should do the comparisons.
        # check_func / clone_func: plain functions; see class docstring.
        # del_module: optional dotted module name to drop afterwards.
        self.dt_self = dt_self
        self.checker = old_checker
        # The swapped-in check_output code runs with ``self`` bound to the
        # old checker and reads these two attributes to redirect the call
        # to new_checker and to reach the original implementation.
        self.checker._temp_call_super_check_output = self.call_super
        self.checker._temp_override_self = new_checker
        self.check_func = check_func
        self.clone_func = clone_func
        self.del_module = del_module
        self.install_clone()
        self.install_dt_self()

    def install_clone(self):
        """Save check_func's code and globals, then give it clone_func's code."""
        if _IS_PYTHON_3:
            self.func_code = self.check_func.__code__
            self.func_globals = self.check_func.__globals__
            self.check_func.__code__ = self.clone_func.__code__
        else:
            self.func_code = self.check_func.func_code
            self.func_globals = self.check_func.func_globals
            self.check_func.func_code = self.clone_func.func_code

    def uninstall_clone(self):
        """Put check_func's saved code object back."""
        if _IS_PYTHON_3:
            self.check_func.__code__ = self.func_code
        else:
            self.check_func.func_code = self.func_code

    def install_dt_self(self):
        """Replace the runner's __record_outcome with this object."""
        self.prev_func = self.dt_self._DocTestRunner__record_outcome
        self.dt_self._DocTestRunner__record_outcome = self

    def uninstall_dt_self(self):
        """Restore the runner's original __record_outcome."""
        self.dt_self._DocTestRunner__record_outcome = self.prev_func

    def uninstall_module(self):
        """Remove del_module from sys.modules and from its parent package."""
        if self.del_module:
            import sys
            del sys.modules[self.del_module]
            if '.' in self.del_module:
                package, module = self.del_module.rsplit('.', 1)
                package_mod = sys.modules[package]
                delattr(package_mod, module)

    def __call__(self, *args, **kw):
        """Stand-in for __record_outcome: restore state, then forward.

        Returns whatever the original __record_outcome returns.
        """
        self.uninstall_clone()
        self.uninstall_dt_self()
        del self.checker._temp_override_self
        del self.checker._temp_call_super_check_output
        result = self.prev_func(*args, **kw)
        # Module cleanup happens only after the outcome has been recorded.
        self.uninstall_module()
        return result

    def call_super(self, *args, **kw):
        """Call check_func with its original code.

        The clone is reinstalled afterwards, even if the call raises.
        """
        self.uninstall_clone()
        try:
            return self.check_func(*args, **kw)
        finally:
            self.install_clone()
|
||
|
|
||
|
def _find_doctest_frame():
|
||
|
import sys
|
||
|
frame = sys._getframe(1)
|
||
|
while frame:
|
||
|
l = frame.f_locals
|
||
|
if 'BOOM' in l:
|
||
|
# Sign of doctest
|
||
|
return frame
|
||
|
frame = frame.f_back
|
||
|
raise LookupError(
|
||
|
"Could not find doctest (only use this function *inside* a doctest)")
|
||
|
|
||
|
__test__ = {
|
||
|
'basic': '''
|
||
|
>>> temp_install()
|
||
|
>>> print """<xml a="1" b="2">stuff</xml>"""
|
||
|
<xml b="2" a="1">...</xml>
|
||
|
>>> print """<xml xmlns="http://example.com"><tag attr="bar" /></xml>"""
|
||
|
<xml xmlns="...">
|
||
|
<tag attr="..." />
|
||
|
</xml>
|
||
|
>>> print """<xml>blahblahblah<foo /></xml>""" # doctest: +NOPARSE_MARKUP, +ELLIPSIS
|
||
|
<xml>...foo /></xml>
|
||
|
'''}
|
||
|
|
||
|
if __name__ == '__main__':
    # Running this file directly executes its own doctests: the module
    # docstring examples and those registered in ``__test__``.
    from doctest import testmod
    testmod()
|
||
|
|
||
|
|