From 5e96ab5a3ec0496acea03420cb6b09107de37b06 Mon Sep 17 00:00:00 2001 From: Oliverpool Date: Tue, 9 Feb 2016 17:27:23 +0100 Subject: [PATCH] Update Rx --- patacrep/Rx.py | 590 ++++++++++++++++++++++++++++++---------------- patacrep/utils.py | 3 +- 2 files changed, 387 insertions(+), 206 deletions(-) diff --git a/patacrep/Rx.py b/patacrep/Rx.py index 677a3461..a388f039 100644 --- a/patacrep/Rx.py +++ b/patacrep/Rx.py @@ -9,85 +9,242 @@ import re import types from numbers import Number - -core_types = [ ] +### Exception Classes -------------------------------------------------------- class SchemaError(Exception): pass class SchemaMismatch(Exception): - pass -class SchemaTypeMismatch(SchemaMismatch): - def __init__(self, name, desired_type): - SchemaMismatch.__init__(self, '{0} must be {1}'.format(name, desired_type)) + def __init__(self, message, schema, error=None): + Exception.__init__(self, message) + self.type = schema.subname() + self.error = error + +class TypeMismatch(SchemaMismatch): + + def __init__(self, schema, data): + message = 'must be of type {} (was {})'.format( + schema.subname(), + type(data).__name__ + ) -class SchemaValueMismatch(SchemaMismatch): - def __init__(self, name, value): - SchemaMismatch.__init__(self, '{0} must equal {1}'.format(name, value)) + SchemaMismatch.__init__(self, message, schema, 'type') + self.expected_type = schema.subname() + self.value = type(data).__name__ -class SchemaRangeMismatch(SchemaMismatch): - pass -def indent(text, level=1, whitespace=' '): - return '\n'.join(whitespace*level+line for line in text.split('\n')) +class ValueMismatch(SchemaMismatch): + + def __init__(self, schema, data): -class Util(object): - @staticmethod - def make_range_check(opt): - - if not {'min', 'max', 'min-ex', 'max-ex'}.issuperset(opt): - raise ValueError("illegal argument to make_range_check") - if {'min', 'min-ex'}.issubset(opt): - raise ValueError("Cannot define both exclusive and inclusive min") - if {'max', 'max-ex'}.issubset(opt): - raise ValueError("Cannot define both exclusive and inclusive max") - - r = opt.copy() - inf = float('inf') - - def check_range(value): - return( - r.get('min', -inf) <= value and \ - r.get('max', inf) >= value and \ - r.get('min-ex', -inf) < value and \ - r.get('max-ex', inf) > value + message = 'must equal {} (was {})'.format( + repr(schema.value), + repr(data) + ) + + SchemaMismatch.__init__(self, message, schema, 'value') + self.expected_value = schema.value + self.value = data + + + +class RangeMismatch(SchemaMismatch): + + def __init__(self, schema, data): + + message = 'must be in range {} (was {})'.format( + schema.range, + data + ) + + SchemaMismatch.__init__(self, message, schema, 'range') + self.range = schema.range + self.value = data + + +class LengthRangeMismatch(SchemaMismatch): + + def __init__(self, schema, data): + length_range = Range(schema.length) + + if not hasattr(length_range, 'min') and \ + not hasattr(length_range, 'min_ex'): + length_range.min = 0 + + message = 'length must be in range {} (was {})'.format( + length_range, + len(data) + ) + + SchemaMismatch.__init__(self, message, schema, 'range') + self.range = schema.length + self.value = len(data) + + +class MissingFieldMismatch(SchemaMismatch): + + def __init__(self, schema, fields): + + if len(fields) == 1: + message = 'missing required field: {}'.format( + repr(fields[0]) ) + else: + message = 'missing required fields: {}'.format( + ', '.join(fields) + ) + if len(message) >= 80: # if the line is too long + message = 'missing required fields:\n{}'.format( + _indent('\n'.join(fields)) + ) - return check_range + SchemaMismatch.__init__(self, message, schema, 'missing') + self.fields = fields - @staticmethod - def make_range_validator(opt): - check_range = Util.make_range_check(opt) - - r = opt.copy() - nan = float('nan') - - def validate_range(value, name='value'): - if not check_range(value): - if r.get('min', nan) == r.get('max', nan): - msg = '{0} must equal {1}'.format(name, r['min']) - raise SchemaRangeMismatch(msg) - - range_str = '' - if 'min' in r: - range_str = '[{0}, '.format(r['min']) - elif 'min-ex' in r: - range_str = '({0}, '.format(r['min-ex']) - else: - range_str = '(-inf, ' - - if 'max' in r: - range_str += '{0}]'.format(r['max']) - elif 'max-ex' in r: - range_str += '{0})'.format(r['max-ex']) - else: - range_str += 'inf)' - - raise SchemaRangeMismatch(name+' must be in range '+range_str) - - return validate_range +class UnknownFieldMismatch(SchemaMismatch): + + def __init__(self, schema, fields): + + if len(fields) == 1: + message = 'unknown field: {}'.format( + repr(fields[0]) + ) + else: + message = 'unknown fields: {}'.format( + ', '.join(fields) + ) + if len(message) >= 80: # if the line is too long + message = 'unknown fields:\n{}'.format( + _indent('\n'.join(fields)) + ) + + SchemaMismatch.__init__(self, message, schema, 'unexpected') + self.fields = fields + + +class SeqLengthMismatch(SchemaMismatch): + def __init__(self, schema, data): + + expected_length = len(schema.content_schema) + message = 'sequence must have {} element{} (had {})'.format( + expected_length, + 's'*(expected_length != 1), # plural + len(data) + ) + + SchemaMismatch.__init__(self, message, schema, 'size') + self.expected_length = expected_length + self.value = len(data) + + +class TreeMismatch(SchemaMismatch): + + def __init__(self, schema, errors=[], child_errors={}, message=None): + + ## Create error message + + error_messages = [] + + for err in errors: + error_messages.append(str(err)) + + for key, err in child_errors.items(): + + if isinstance(key, int): + index = '[item {}]'.format(key) + else: + index = '{}'.format(repr(key)) + + if isinstance(err, TreeMismatch) and \ + not err.errors and len(err.child_errors) == 1: + + template = '{} > {}' + + else: + template = '{} {}' + + msg = template.format(index, err) + error_messages.append(msg) + + if message is None: + message = 'does not match schema' + + if len(error_messages) == 1: + msg = error_messages[0] + + else: + msg = '{}:\n{}'.format( + message, + _indent('\n'.join(error_messages)) + ) + + SchemaMismatch.__init__(self, msg, schema, 'multiple') + self.errors = errors + self.child_errors = child_errors + +def _createTreeMismatch(schema, errors=[], child_errors={}, message=None): + if len(errors) == 1 and not child_errors: + return errors[0] + else: + return TreeMismatch(schema, errors, child_errors, message) + +### Utilities ---------------------------------------------------------------- + +class Range(object): + + def __init__(self, opt): + if isinstance(opt, Range): + for attr in ('min', 'max', 'min_ex', 'max_ex'): + if hasattr(opt, attr): + setattr(self, attr, getattr(opt, attr)) + else: + if not {'min', 'max', 'min-ex', 'max-ex'}.issuperset(opt): + raise ValueError("illegal argument to make_range_check") + if {'min', 'min-ex'}.issubset(opt): + raise ValueError("Cannot define both exclusive and inclusive min") + if {'max', 'max-ex'}.issubset(opt): + raise ValueError("Cannot define both exclusive and inclusive max") + + for boundary in ('min', 'max', 'min-ex', 'max-ex'): + if boundary in opt: + attr = boundary.replace('-', '_') + setattr(self, attr, opt[boundary]) + + def __call__(self, value): + INF = float('inf') + + get = lambda attr, default: getattr(self, attr, default) + + return( + get('min', -INF) <= value and \ + get('max', INF) >= value and \ + get('min_ex', -INF) < value and \ + get('max_ex', INF) > value + ) + + def __str__(self): + if hasattr(self, 'min'): + s = '[{}, '.format(self.min) + elif hasattr(self, 'min_ex'): + s = '({}, '.format(self.min_ex) + else: + s = '(-Inf, ' + + if hasattr(self, 'max'): + s += '{}]'.format(self.max) + elif hasattr(self, 'max_ex'): + s += '{})'.format(self.max_ex) + else: + s += 'Inf)' + + return s + +def _indent(text, level=1, whitespace=' '): + return '\n'.join(whitespace*level+line for line in text.split('\n')) + + ### Schema Factory Class ----------------------------------------------------- class Factory(object): def __init__(self, register_core_types=True): @@ -109,20 +266,20 @@ class Factory(object): m = re.match('^/([-._a-z0-9]*)/([-._a-z0-9]+)$', type_name) if not m: - raise ValueError("couldn't understand type name '{0}'".format(type_name)) + raise ValueError("couldn't understand type name '{}'".format(type_name)) prefix, suffix = m.groups() if prefix not in self.prefix_registry: raise KeyError( - "unknown prefix '{0}' in type name '{1}'".format(prefix, type_name) + "unknown prefix '{0}' in type name '{}'".format(prefix, type_name) ) return self.prefix_registry[ prefix ] + suffix def add_prefix(self, name, base): if self.prefix_registry.get(name): - raise SchemaError("the prefix '{0}' is already registered".format(name)) + raise SchemaError("the prefix '{}' is already registered".format(name)) self.prefix_registry[name] = base; @@ -130,13 +287,15 @@ class Factory(object): t_uri = t.uri() if t_uri in self.type_registry: - raise ValueError("type already registered for {0}".format(t_uri)) + raise ValueError("type already registered for {}".format(t_uri)) self.type_registry[t_uri] = t def learn_type(self, uri, schema): if self.type_registry.get(uri): - raise SchemaError("tried to learn type for already-registered uri {0}".format(uri)) + raise SchemaError( + "tried to learn type for already-registered uri {}".format(uri) + ) # make sure schema is valid # should this be in a try/except? @@ -153,17 +312,27 @@ class Factory(object): uri = self.expand_uri(schema['type']) - if not self.type_registry.get(uri): raise SchemaError("unknown type {0}".format(uri)) + if not self.type_registry.get(uri): + raise SchemaError("unknown type {}".format(uri)) type_class = self.type_registry[uri] if isinstance(type_class, dict): if not {'type'}.issuperset(schema): - raise SchemaError('composed type does not take check arguments'); + raise SchemaError('composed type does not take check arguments') return self.make_schema(type_class['schema']) else: return type_class(schema, self) +std_factory = None +def make_schema(schema): + global std_factory + if std_factory is None: + std_factory = Factory() + return std_factory.make_schema(schema) + +### Core Type Base Class ------------------------------------------------- + class _CoreType(object): @classmethod def uri(self): @@ -171,7 +340,7 @@ class _CoreType(object): def __init__(self, schema, rx): if not {'type'}.issuperset(schema): - raise SchemaError('unknown parameter for //{0}'.format(self.subname())) + raise SchemaError('unknown parameter for //{}'.format(self.subname())) def check(self, value): try: @@ -180,8 +349,10 @@ class _CoreType(object): return False return True - def validate(self, value, name='value'): - raise SchemaMismatch('Tried to validate abstract base schema class') + def validate(self, value): + raise SchemaMismatch('Tried to validate abstract base schema class', self) + +### Core Schema Types -------------------------------------------------------- class AllType(_CoreType): @staticmethod @@ -191,26 +362,22 @@ class AllType(_CoreType): if not {'type', 'of'}.issuperset(schema): raise SchemaError('unknown parameter for //all') - if not(schema.get('of') and len(schema.get('of'))): + if not schema.get('of'): raise SchemaError('no alternatives given in //all of') self.alts = [rx.make_schema(s) for s in schema['of']] - def validate(self, value, name='value'): - error_messages = [] + def validate(self, value): + errors = [] for schema in self.alts: try: - schema.validate(value, name) + schema.validate(value) except SchemaMismatch as e: - error_messages.append(str(e)) + errors.append(e) + + if errors: + raise _createTreeMismatch(self, errors) - if len(error_messages) > 1: - messages = indent('\n'.join(error_messages)) - message = '{0} failed to meet all schema requirements:\n{1}' - message = message.format(name, messages) - raise SchemaMismatch(message) - elif len(error_messages) == 1: - raise SchemaMismatch(error_messages[0]) class AnyType(_CoreType): @staticmethod @@ -223,25 +390,28 @@ class AnyType(_CoreType): raise SchemaError('unknown parameter for //any') if 'of' in schema: - if not schema['of']: raise SchemaError('no alternatives given in //any of') + if not schema['of']: + raise SchemaError('no alternatives given in //any of') + self.alts = [ rx.make_schema(alt) for alt in schema['of'] ] - def validate(self, value, name='value'): + def validate(self, value): if self.alts is None: return - error_messages = [] + + errors = [] + for schema in self.alts: try: - schema.validate(value, name) + schema.validate(value) break except SchemaMismatch as e: - error_messages.append(str(e)) + errors.append(e) + + if len(errors) == len(self.alts): + message = 'must satisfy at least one of the following' + raise _createTreeMismatch(self, errors, message=message) - if len(error_messages) == len(self.alts): - messages = indent('\n'.join(error_messages)) - message = '{0} failed to meet any schema requirements:\n{1}' - message = message.format(name, messages) - raise SchemaMismatch(message) class ArrType(_CoreType): @staticmethod @@ -259,47 +429,45 @@ class ArrType(_CoreType): self.content_schema = rx.make_schema(schema['contents']) if schema.get('length'): - self.length = Util.make_range_validator(schema['length']) + self.length = Range(schema['length']) - def validate(self, value, name='value'): + def validate(self, value): if not isinstance(value, (list, tuple)): - raise SchemaTypeMismatch(name, 'array') + raise TypeMismatch(self, value) - if self.length: - self.length(len(value), name+' length') + errors = [] + if self.length and not self.length(len(value)): + err = LengthRangeMismatch(self, value) + errors.append(err) - error_messages = [] + child_errors = {} - for i, item in enumerate(value): + for key, item in enumerate(value): try: - self.content_schema.validate(item, 'item '+str(i)) + self.content_schema.validate(item) except SchemaMismatch as e: - error_messages.append(str(e)) + child_errors[key] = e + if errors or child_errors: + raise _createTreeMismatch(self, errors, child_errors) - if len(error_messages) > 1: - messages = indent('\n'.join(error_messages)) - message = '{0} sequence contains invalid elements:\n{1}' - message = message.format(name, messages) - raise SchemaMismatch(message) - elif len(error_messages) == 1: - raise SchemaMismatch(name+': '+error_messages[0]) class BoolType(_CoreType): @staticmethod def subname(): return 'bool' - def validate(self, value, name='value'): + def validate(self, value,): if not isinstance(value, bool): - raise SchemaTypeMismatch(name, 'boolean') + raise TypeMismatch(self, value) + class DefType(_CoreType): @staticmethod def subname(): return 'def' - - def validate(self, value, name='value'): + def validate(self, value): if value is None: - raise SchemaMismatch(name+' must be non-null') + raise TypeMismatch(self, value) + class FailType(_CoreType): @staticmethod @@ -307,8 +475,13 @@ class FailType(_CoreType): def check(self, value): return False - def validate(self, value, name='value'): - raise SchemaMismatch(name+' is of fail type, automatically invalid.') + def validate(self, value): + raise SchemaMismatch( + 'is of fail type, automatically invalid.', + self, + 'fail' + ) + class IntType(_CoreType): @staticmethod @@ -326,17 +499,18 @@ class IntType(_CoreType): self.range = None if 'range' in schema: - self.range = Util.make_range_validator(schema['range']) + self.range = Range(schema['range']) - def validate(self, value, name='value'): + def validate(self, value): if not isinstance(value, Number) or isinstance(value, bool) or value%1: - raise SchemaTypeMismatch(name,'integer') + raise TypeMismatch(self, value) - if self.range: - self.range(value, name) + if self.range and not self.range(value): + raise RangeMismatch(self, value) if self.value is not None and value != self.value: - raise SchemaValueMismatch(name, self.value) + raise ValueMismatch(self, value) + class MapType(_CoreType): @staticmethod @@ -353,25 +527,21 @@ class MapType(_CoreType): self.value_schema = rx.make_schema(schema['values']) - def validate(self, value, name='value'): + def validate(self, value): if not isinstance(value, dict): - raise SchemaTypeMismatch(name, 'map') + raise TypeMismatch(self, value) - error_messages = [] + child_errors = {} for key, val in value.items(): try: - self.value_schema.validate(val, key) + self.value_schema.validate(val) except SchemaMismatch as e: - error_messages.append(str(e)) + child_errors[key] = e + + if child_errors: + raise _createTreeMismatch(self, child_errors=child_errors) - if len(error_messages) > 1: - messages = indent('\n'.join(error_messages)) - message = '{0} map contains invalid entries:\n{1}' - message = message.format(name, messages) - raise SchemaMismatch(message) - elif len(error_messages) == 1: - raise SchemaMismatch(name+': '+error_messages[0]) class NilType(_CoreType): @staticmethod @@ -379,9 +549,10 @@ class NilType(_CoreType): def check(self, value): return value is None - def validate(self, value, name='value'): + def validate(self, value): if value is not None: - raise SchemaTypeMismatch(name, 'null') + raise TypeMismatch(self, value) + class NumType(_CoreType): @staticmethod @@ -400,25 +571,27 @@ class NumType(_CoreType): self.range = None if schema.get('range'): - self.range = Util.make_range_validator(schema['range']) + self.range = Range(schema['range']) - def validate(self, value, name='value'): + def validate(self, value): if not isinstance(value, Number) or isinstance(value, bool): - raise SchemaTypeMismatch(name, 'number') + raise TypeMismatch(self, value) - if self.range: - self.range(value, name) + if self.range and not self.range(value): + raise RangeMismatch(self, value) if self.value is not None and value != self.value: - raise SchemaValueMismatch(name, self.value) + raise ValueMismatch(self, value) + class OneType(_CoreType): @staticmethod def subname(): return 'one' - def validate(self, value, name='value'): + def validate(self, value): if not isinstance(value, (Number, str)): - raise SchemaTypeMismatch(name, 'number or string') + raise TypeMismatch(self, value) + class RecType(_CoreType): @staticmethod @@ -436,7 +609,9 @@ class RecType(_CoreType): setattr(self, which, {}) for field in schema.get(which, {}).keys(): if field in self.known: - raise SchemaError('%s appears in both required and optional' % field) + raise SchemaError( + '%s appears in both required and optional' % field + ) self.known.add(field) @@ -444,47 +619,53 @@ class RecType(_CoreType): schema[which][field] ) - def validate(self, value, name='value'): + def validate(self, value): if not isinstance(value, dict): - raise SchemaTypeMismatch(name, 'record') + raise TypeMismatch(self, value) - unknown = [k for k in value.keys() if k not in self.known] + errors = [] + child_errors = {} - if unknown and not self.rest_schema: - fields = indent('\n'.join(unknown)) - raise SchemaMismatch(name+' contains unknown fields:\n'+fields) - - error_messages = [] + missing_fields = [] for field in self.required: - try: - if field not in value: - raise SchemaMismatch('missing required field: '+field) - self.required[field].validate(value[field], field) - except SchemaMismatch as e: - error_messages.append(str(e)) + + if field not in value: + missing_fields.append(field) + else: + try: + self.required[field].validate(value[field]) + except SchemaMismatch as e: + child_errors[field] = e + + if missing_fields: + err = MissingFieldMismatch(self, missing_fields) + errors.append(err) for field in self.optional: if field not in value: continue - try: - self.optional[field].validate(value[field], field) - except SchemaMismatch as e: - error_messages.append(str(e)) - if unknown: - rest = {key: value[key] for key in unknown} try: - self.rest_schema.validate(rest, name) + self.optional[field].validate(value[field]) except SchemaMismatch as e: - error_messages.append(str(e)) + child_errors[field] = e - if len(error_messages) > 1: - messages = indent('\n'.join(error_messages)) - message = '{0} record is invalid:\n{1}' - message = message.format(name, messages) - raise SchemaMismatch(message) - elif len(error_messages) == 1: - raise SchemaMismatch(name+': '+error_messages[0]) + unknown = [k for k in value.keys() if k not in self.known] + + if unknown: + if self.rest_schema: + rest = {key: value[key] for key in unknown} + try: + self.rest_schema.validate(rest) + except SchemaMismatch as e: + errors.append(e) + else: + fields = _indent('\n'.join(unknown)) + err = UnknownFieldMismatch(self, unknown) + errors.append(err) + + if errors or child_errors: + raise _createTreeMismatch(self, errors, child_errors) class SeqType(_CoreType): @@ -504,34 +685,33 @@ class SeqType(_CoreType): if (schema.get('tail')): self.tail_schema = rx.make_schema(schema['tail']) - def validate(self, value, name='value'): + def validate(self, value): if not isinstance(value, (list, tuple)): - raise SchemaTypeMismatch(name, 'sequence') + raise TypeMismatch(self, value) - if len(value) < len(self.content_schema): - raise SchemaMismatch(name+' is less than expected length') + errors = [] - if len(value) > len(self.content_schema) and not self.tail_schema: - raise SchemaMismatch(name+' exceeds expected length') + if len(value) != len(self.content_schema): + if len(value) > len(self.content_schema) and self.tail_schema: + try: + self.tail_schema.validate(value[len(self.content_schema):]) + except SchemaMismatch as e: + errors.append(e) + else: + err = SeqLengthMismatch(self, value) + errors.append(err) - error_messages = [] + child_errors = {} - for i, (schema, item) in enumerate(zip(self.content_schema, value)): + for index, (schema, item) in enumerate(zip(self.content_schema, value)): try: - schema.validate(item, 'item '+str(i)) + schema.validate(item) except SchemaMismatch as e: - error_messages.append(str(e)) + child_errors[index] = e - if len(error_messages) > 1: - messages = indent('\n'.join(error_messages)) - message = '{0} sequence is invalid:\n{1}' - message = message.format(name, messages) - raise SchemaMismatch(message) - elif len(error_messages) == 1: - raise SchemaMismatch(name+': '+error_messages[0]) + if errors or child_errors: + raise _createTreeMismatch(self, errors, child_errors) - if len(value) > len(self.content_schema): - self.tail_schema.validate(value[len(self.content_schema):], name) class StrType(_CoreType): @staticmethod @@ -549,18 +729,20 @@ class StrType(_CoreType): self.length = None if 'length' in schema: - self.length = Util.make_range_validator(schema['length']) + self.length = Range(schema['length']) - def validate(self, value, name='value'): + def validate(self, value): if not isinstance(value, str): - raise SchemaTypeMismatch(name, 'string') + raise TypeMismatch(self, value) + if self.value is not None and value != self.value: - raise SchemaValueMismatch(name, '"{0}"'.format(self.value)) - if self.length: - self.length(len(value), name+' length') + raise ValueMismatch(self, self) + + if self.length and not self.length(len(value)): + raise LengthRangeMismatch(self, value) core_types = [ AllType, AnyType, ArrType, BoolType, DefType, FailType, IntType, MapType, NilType, NumType, OneType, RecType, SeqType, StrType -] +] \ No newline at end of file diff --git a/patacrep/utils.py b/patacrep/utils.py index ac13b4d5..3949432e 100644 --- a/patacrep/utils.py +++ b/patacrep/utils.py @@ -83,8 +83,7 @@ def validate_yaml_schema(data, schema): Will raise `SBFileError` if the schema is not respected. """ - rx_checker = Rx.Factory({"register_core_types": True}) - schema = rx_checker.make_schema(schema) + schema = Rx.make_schema(schema) if isinstance(data, DictOfDict): data = dict(data)