Browse Source

Update Rx

pull/190/head
Oliverpool 9 years ago
parent
commit
5e96ab5a3e
  1. 590
      patacrep/Rx.py
  2. 3
      patacrep/utils.py

590
patacrep/Rx.py

@ -9,85 +9,242 @@ import re
import types
from numbers import Number
core_types = [ ]
### Exception Classes --------------------------------------------------------
class SchemaError(Exception):
pass
class SchemaMismatch(Exception):
pass
class SchemaTypeMismatch(SchemaMismatch):
def __init__(self, name, desired_type):
SchemaMismatch.__init__(self, '{0} must be {1}'.format(name, desired_type))
def __init__(self, message, schema, error=None):
Exception.__init__(self, message)
self.type = schema.subname()
self.error = error
class TypeMismatch(SchemaMismatch):
def __init__(self, schema, data):
message = 'must be of type {} (was {})'.format(
schema.subname(),
type(data).__name__
)
class SchemaValueMismatch(SchemaMismatch):
def __init__(self, name, value):
SchemaMismatch.__init__(self, '{0} must equal {1}'.format(name, value))
SchemaMismatch.__init__(self, message, schema, 'type')
self.expected_type = schema.subname()
self.value = type(data).__name__
class SchemaRangeMismatch(SchemaMismatch):
pass
def indent(text, level=1, whitespace=' '):
return '\n'.join(whitespace*level+line for line in text.split('\n'))
class ValueMismatch(SchemaMismatch):
def __init__(self, schema, data):
class Util(object):
@staticmethod
def make_range_check(opt):
if not {'min', 'max', 'min-ex', 'max-ex'}.issuperset(opt):
raise ValueError("illegal argument to make_range_check")
if {'min', 'min-ex'}.issubset(opt):
raise ValueError("Cannot define both exclusive and inclusive min")
if {'max', 'max-ex'}.issubset(opt):
raise ValueError("Cannot define both exclusive and inclusive max")
r = opt.copy()
inf = float('inf')
def check_range(value):
return(
r.get('min', -inf) <= value and \
r.get('max', inf) >= value and \
r.get('min-ex', -inf) < value and \
r.get('max-ex', inf) > value
message = 'must equal {} (was {})'.format(
repr(schema.value),
repr(data)
)
SchemaMismatch.__init__(self, message, schema, 'value')
self.expected_value = schema.value
self.value = data
class RangeMismatch(SchemaMismatch):
    """Mismatch raised when a value is rejected by the schema's range."""

    def __init__(self, schema, data):
        """Build the message from ``schema.range`` and the offending value.

        ``schema`` must expose a ``range`` attribute; ``data`` is the value
        that was checked.  The base class is initialised with the error tag
        ``'range'``.
        """
        allowed = schema.range
        SchemaMismatch.__init__(
            self,
            'must be in range {} (was {})'.format(allowed, data),
            schema,
            'range',
        )
        # Keep both sides of the comparison for programmatic inspection.
        self.range = allowed
        self.value = data
class LengthRangeMismatch(SchemaMismatch):
    """Mismatch raised when the length of a value is outside ``schema.length``."""

    def __init__(self, schema, data):
        """Build the message from ``schema.length`` and ``len(data)``.

        The range shown in the message is a copy of ``schema.length``; the
        schema's own range object is stored unchanged in ``self.range``.
        """
        shown = Range(schema.length)
        has_lower_bound = any(
            hasattr(shown, bound) for bound in ('min', 'min_ex')
        )
        if not has_lower_bound:
            # Display-only default: applied to the copy when no lower bound
            # is set, so the message does not print "-Inf" for a length.
            shown.min = 0

        actual = len(data)
        text = 'length must be in range {} (was {})'.format(shown, actual)
        SchemaMismatch.__init__(self, text, schema, 'range')
        self.range = schema.length
        self.value = actual
class MissingFieldMismatch(SchemaMismatch):
    """Mismatch raised when required record fields are absent."""

    def __init__(self, schema, fields):
        """Build the message listing the missing field names.

        ``fields`` is the sequence of missing field names.  A single field is
        shown with ``repr``; several fields are shown comma-separated, or one
        per line when the one-line form would reach 80 characters.  The
        original ``fields`` sequence is kept on ``self.fields``.
        """
        if len(fields) == 1:
            message = 'missing required field: {}'.format(repr(fields[0]))
        else:
            # Field names are not guaranteed to be strings; str.join() would
            # raise TypeError on e.g. integer keys, hiding the real mismatch.
            names = [str(field) for field in fields]
            message = 'missing required fields: {}'.format(', '.join(names))
            if len(message) >= 80:  # if the line is too long
                message = 'missing required fields:\n{}'.format(
                    _indent('\n'.join(names))
                )

        SchemaMismatch.__init__(self, message, schema, 'missing')
        self.fields = fields
@staticmethod
def make_range_validator(opt):
check_range = Util.make_range_check(opt)
r = opt.copy()
nan = float('nan')
def validate_range(value, name='value'):
if not check_range(value):
if r.get('min', nan) == r.get('max', nan):
msg = '{0} must equal {1}'.format(name, r['min'])
raise SchemaRangeMismatch(msg)
range_str = ''
if 'min' in r:
range_str = '[{0}, '.format(r['min'])
elif 'min-ex' in r:
range_str = '({0}, '.format(r['min-ex'])
else:
range_str = '(-inf, '
if 'max' in r:
range_str += '{0}]'.format(r['max'])
elif 'max-ex' in r:
range_str += '{0})'.format(r['max-ex'])
else:
range_str += 'inf)'
raise SchemaRangeMismatch(name+' must be in range '+range_str)
return validate_range
class UnknownFieldMismatch(SchemaMismatch):
    """Mismatch raised when a record contains fields the schema does not know."""

    def __init__(self, schema, fields):
        """Build the message listing the unexpected field names.

        ``fields`` is the sequence of unknown keys found in the data.  A
        single key is shown with ``repr``; several keys are shown
        comma-separated, or one per line when the one-line form would reach
        80 characters.  The original ``fields`` sequence is kept on
        ``self.fields``.
        """
        if len(fields) == 1:
            message = 'unknown field: {}'.format(repr(fields[0]))
        else:
            # Keys come from the validated data and may be non-strings;
            # str.join() would raise TypeError instead of reporting them.
            names = [str(field) for field in fields]
            message = 'unknown fields: {}'.format(', '.join(names))
            if len(message) >= 80:  # if the line is too long
                message = 'unknown fields:\n{}'.format(
                    _indent('\n'.join(names))
                )

        SchemaMismatch.__init__(self, message, schema, 'unexpected')
        self.fields = fields
class SeqLengthMismatch(SchemaMismatch):
    """Mismatch raised when a sequence has the wrong number of elements."""

    def __init__(self, schema, data):
        """Compare ``len(data)`` against ``len(schema.content_schema)``.

        Both lengths are stored on the instance (``expected_length`` and
        ``value``); the base class is initialised with the error tag
        ``'size'``.
        """
        wanted = len(schema.content_schema)
        found = len(data)
        plural = '' if wanted == 1 else 's'
        text = 'sequence must have {} element{} (had {})'.format(
            wanted, plural, found
        )
        SchemaMismatch.__init__(self, text, schema, 'size')
        self.expected_length = wanted
        self.value = found
class TreeMismatch(SchemaMismatch):
    """Aggregate mismatch grouping several errors found at one schema node."""

    def __init__(self, schema, errors=None, child_errors=None, message=None):
        """Combine direct and per-child mismatches into one exception.

        ``errors`` is a list of mismatches that apply to the value itself.
        ``child_errors`` maps a key (mapping key or sequence index) to the
        mismatch raised for that child.  ``message`` is the headline used
        when more than one line has to be reported; it defaults to
        ``'does not match schema'``.

        Omitted containers are created fresh for each instance: they are
        stored on ``self.errors`` / ``self.child_errors``, so shared mutable
        defaults would leak state between unrelated exceptions.
        """
        errors = [] if errors is None else errors
        child_errors = {} if child_errors is None else child_errors

        ## Create error message
        error_messages = [str(err) for err in errors]

        for key, err in child_errors.items():
            if isinstance(key, int):
                index = '[item {}]'.format(key)
            else:
                index = '{}'.format(repr(key))

            # A nested tree that only wraps a single child error is rendered
            # as a path ("a > b ...") instead of a nested block.
            if isinstance(err, TreeMismatch) and \
                    not err.errors and len(err.child_errors) == 1:
                template = '{} > {}'
            else:
                template = '{} {}'

            error_messages.append(template.format(index, err))

        if message is None:
            message = 'does not match schema'

        if len(error_messages) == 1:
            msg = error_messages[0]
        else:
            msg = '{}:\n{}'.format(
                message,
                _indent('\n'.join(error_messages))
            )

        SchemaMismatch.__init__(self, msg, schema, 'multiple')
        self.errors = errors
        self.child_errors = child_errors
def _createTreeMismatch(schema, errors=[], child_errors={}, message=None):
if len(errors) == 1 and not child_errors:
return errors[0]
else:
return TreeMismatch(schema, errors, child_errors, message)
### Utilities ----------------------------------------------------------------
class Range(object):
    """Numeric range with optional inclusive or exclusive bounds.

    Bounds are stored as the instance attributes ``min``, ``max``, ``min_ex``
    and ``max_ex``; an attribute that is absent means "unbounded" on that
    side.  Calling the instance tests whether a value lies inside the range.
    """

    def __init__(self, opt):
        """Create a range from a mapping or copy another Range.

        ``opt`` is either a Range (its bounds are copied) or a mapping whose
        keys are a subset of ``'min'``, ``'max'``, ``'min-ex'``, ``'max-ex'``.

        Raises ValueError for unknown keys, or when both the inclusive and
        the exclusive variant of the same bound are given.
        """
        if isinstance(opt, Range):
            for attr in ('min', 'max', 'min_ex', 'max_ex'):
                if hasattr(opt, attr):
                    setattr(self, attr, getattr(opt, attr))
            return

        allowed = {'min', 'max', 'min-ex', 'max-ex'}
        if not allowed.issuperset(opt):
            unknown = ', '.join(repr(key) for key in opt if key not in allowed)
            raise ValueError(
                "illegal argument to Range: unknown key(s) {}".format(unknown)
            )
        if {'min', 'min-ex'}.issubset(opt):
            raise ValueError("Cannot define both exclusive and inclusive min")
        if {'max', 'max-ex'}.issubset(opt):
            raise ValueError("Cannot define both exclusive and inclusive max")

        for boundary in ('min', 'max', 'min-ex', 'max-ex'):
            if boundary in opt:
                # Schema keys use dashes; attribute names use underscores.
                setattr(self, boundary.replace('-', '_'), opt[boundary])

    def __call__(self, value):
        """Return True when ``value`` satisfies every bound that is set."""
        inf = float('inf')
        return (
            getattr(self, 'min', -inf) <= value
            and getattr(self, 'max', inf) >= value
            and getattr(self, 'min_ex', -inf) < value
            and getattr(self, 'max_ex', inf) > value
        )

    def __str__(self):
        """Render the range in interval notation, e.g. ``[0, 10)``."""
        if hasattr(self, 'min'):
            s = '[{}, '.format(self.min)
        elif hasattr(self, 'min_ex'):
            s = '({}, '.format(self.min_ex)
        else:
            s = '(-Inf, '

        if hasattr(self, 'max'):
            s += '{}]'.format(self.max)
        elif hasattr(self, 'max_ex'):
            s += '{})'.format(self.max_ex)
        else:
            s += 'Inf)'

        return s
def _indent(text, level=1, whitespace=' '):
return '\n'.join(whitespace*level+line for line in text.split('\n'))
### Schema Factory Class -----------------------------------------------------
class Factory(object):
def __init__(self, register_core_types=True):
@ -109,20 +266,20 @@ class Factory(object):
m = re.match('^/([-._a-z0-9]*)/([-._a-z0-9]+)$', type_name)
if not m:
raise ValueError("couldn't understand type name '{0}'".format(type_name))
raise ValueError("couldn't understand type name '{}'".format(type_name))
prefix, suffix = m.groups()
if prefix not in self.prefix_registry:
# Both placeholders are auto-numbered: str.format() raises ValueError when
# explicit ('{0}') and automatic ('{}') field numbering are mixed, which
# would mask the intended KeyError.
raise KeyError(
    "unknown prefix '{}' in type name '{}'".format(prefix, type_name)
)
return self.prefix_registry[ prefix ] + suffix
def add_prefix(self, name, base):
if self.prefix_registry.get(name):
raise SchemaError("the prefix '{0}' is already registered".format(name))
raise SchemaError("the prefix '{}' is already registered".format(name))
self.prefix_registry[name] = base;
@ -130,13 +287,15 @@ class Factory(object):
t_uri = t.uri()
if t_uri in self.type_registry:
raise ValueError("type already registered for {0}".format(t_uri))
raise ValueError("type already registered for {}".format(t_uri))
self.type_registry[t_uri] = t
def learn_type(self, uri, schema):
if self.type_registry.get(uri):
raise SchemaError("tried to learn type for already-registered uri {0}".format(uri))
raise SchemaError(
"tried to learn type for already-registered uri {}".format(uri)
)
# make sure schema is valid
# should this be in a try/except?
@ -153,17 +312,27 @@ class Factory(object):
uri = self.expand_uri(schema['type'])
if not self.type_registry.get(uri): raise SchemaError("unknown type {0}".format(uri))
if not self.type_registry.get(uri):
raise SchemaError("unknown type {}".format(uri))
type_class = self.type_registry[uri]
if isinstance(type_class, dict):
if not {'type'}.issuperset(schema):
raise SchemaError('composed type does not take check arguments');
raise SchemaError('composed type does not take check arguments')
return self.make_schema(type_class['schema'])
else:
return type_class(schema, self)
std_factory = None
def make_schema(schema):
    """Build a schema checker using the shared module-level Factory.

    The Factory is created on first use and cached in the module global
    ``std_factory`` for later calls.
    """
    global std_factory
    factory = std_factory
    if factory is None:
        factory = std_factory = Factory()
    return factory.make_schema(schema)
### Core Type Base Class -------------------------------------------------
class _CoreType(object):
@classmethod
def uri(self):
@ -171,7 +340,7 @@ class _CoreType(object):
def __init__(self, schema, rx):
if not {'type'}.issuperset(schema):
raise SchemaError('unknown parameter for //{0}'.format(self.subname()))
raise SchemaError('unknown parameter for //{}'.format(self.subname()))
def check(self, value):
try:
@ -180,8 +349,10 @@ class _CoreType(object):
return False
return True
def validate(self, value, name='value'):
raise SchemaMismatch('Tried to validate abstract base schema class')
def validate(self, value):
raise SchemaMismatch('Tried to validate abstract base schema class', self)
### Core Schema Types --------------------------------------------------------
class AllType(_CoreType):
@staticmethod
@ -191,26 +362,22 @@ class AllType(_CoreType):
if not {'type', 'of'}.issuperset(schema):
raise SchemaError('unknown parameter for //all')
if not(schema.get('of') and len(schema.get('of'))):
if not schema.get('of'):
raise SchemaError('no alternatives given in //all of')
self.alts = [rx.make_schema(s) for s in schema['of']]
def validate(self, value, name='value'):
error_messages = []
def validate(self, value):
errors = []
for schema in self.alts:
try:
schema.validate(value, name)
schema.validate(value)
except SchemaMismatch as e:
error_messages.append(str(e))
errors.append(e)
if errors:
raise _createTreeMismatch(self, errors)
if len(error_messages) > 1:
messages = indent('\n'.join(error_messages))
message = '{0} failed to meet all schema requirements:\n{1}'
message = message.format(name, messages)
raise SchemaMismatch(message)
elif len(error_messages) == 1:
raise SchemaMismatch(error_messages[0])
class AnyType(_CoreType):
@staticmethod
@ -223,25 +390,28 @@ class AnyType(_CoreType):
raise SchemaError('unknown parameter for //any')
if 'of' in schema:
if not schema['of']: raise SchemaError('no alternatives given in //any of')
if not schema['of']:
raise SchemaError('no alternatives given in //any of')
self.alts = [ rx.make_schema(alt) for alt in schema['of'] ]
def validate(self, value, name='value'):
def validate(self, value):
if self.alts is None:
return
error_messages = []
errors = []
for schema in self.alts:
try:
schema.validate(value, name)
schema.validate(value)
break
except SchemaMismatch as e:
error_messages.append(str(e))
errors.append(e)
if len(errors) == len(self.alts):
message = 'must satisfy at least one of the following'
raise _createTreeMismatch(self, errors, message=message)
if len(error_messages) == len(self.alts):
messages = indent('\n'.join(error_messages))
message = '{0} failed to meet any schema requirements:\n{1}'
message = message.format(name, messages)
raise SchemaMismatch(message)
class ArrType(_CoreType):
@staticmethod
@ -259,47 +429,45 @@ class ArrType(_CoreType):
self.content_schema = rx.make_schema(schema['contents'])
if schema.get('length'):
self.length = Util.make_range_validator(schema['length'])
self.length = Range(schema['length'])
def validate(self, value, name='value'):
def validate(self, value):
if not isinstance(value, (list, tuple)):
raise SchemaTypeMismatch(name, 'array')
raise TypeMismatch(self, value)
if self.length:
self.length(len(value), name+' length')
errors = []
if self.length and not self.length(len(value)):
err = LengthRangeMismatch(self, value)
errors.append(err)
error_messages = []
child_errors = {}
for i, item in enumerate(value):
for key, item in enumerate(value):
try:
self.content_schema.validate(item, 'item '+str(i))
self.content_schema.validate(item)
except SchemaMismatch as e:
error_messages.append(str(e))
child_errors[key] = e
if errors or child_errors:
raise _createTreeMismatch(self, errors, child_errors)
if len(error_messages) > 1:
messages = indent('\n'.join(error_messages))
message = '{0} sequence contains invalid elements:\n{1}'
message = message.format(name, messages)
raise SchemaMismatch(message)
elif len(error_messages) == 1:
raise SchemaMismatch(name+': '+error_messages[0])
class BoolType(_CoreType):
@staticmethod
def subname(): return 'bool'
def validate(self, value, name='value'):
def validate(self, value,):
if not isinstance(value, bool):
raise SchemaTypeMismatch(name, 'boolean')
raise TypeMismatch(self, value)
class DefType(_CoreType):
@staticmethod
def subname(): return 'def'
def validate(self, value, name='value'):
def validate(self, value):
if value is None:
raise SchemaMismatch(name+' must be non-null')
raise TypeMismatch(self, value)
class FailType(_CoreType):
@staticmethod
@ -307,8 +475,13 @@ class FailType(_CoreType):
def check(self, value): return False
def validate(self, value, name='value'):
raise SchemaMismatch(name+' is of fail type, automatically invalid.')
def validate(self, value):
raise SchemaMismatch(
'is of fail type, automatically invalid.',
self,
'fail'
)
class IntType(_CoreType):
@staticmethod
@ -326,17 +499,18 @@ class IntType(_CoreType):
self.range = None
if 'range' in schema:
self.range = Util.make_range_validator(schema['range'])
self.range = Range(schema['range'])
def validate(self, value, name='value'):
def validate(self, value):
if not isinstance(value, Number) or isinstance(value, bool) or value%1:
raise SchemaTypeMismatch(name,'integer')
raise TypeMismatch(self, value)
if self.range:
self.range(value, name)
if self.range and not self.range(value):
raise RangeMismatch(self, value)
if self.value is not None and value != self.value:
raise SchemaValueMismatch(name, self.value)
raise ValueMismatch(self, value)
class MapType(_CoreType):
@staticmethod
@ -353,25 +527,21 @@ class MapType(_CoreType):
self.value_schema = rx.make_schema(schema['values'])
def validate(self, value, name='value'):
def validate(self, value):
if not isinstance(value, dict):
raise SchemaTypeMismatch(name, 'map')
raise TypeMismatch(self, value)
error_messages = []
child_errors = {}
for key, val in value.items():
try:
self.value_schema.validate(val, key)
self.value_schema.validate(val)
except SchemaMismatch as e:
error_messages.append(str(e))
child_errors[key] = e
if child_errors:
raise _createTreeMismatch(self, child_errors=child_errors)
if len(error_messages) > 1:
messages = indent('\n'.join(error_messages))
message = '{0} map contains invalid entries:\n{1}'
message = message.format(name, messages)
raise SchemaMismatch(message)
elif len(error_messages) == 1:
raise SchemaMismatch(name+': '+error_messages[0])
class NilType(_CoreType):
@staticmethod
@ -379,9 +549,10 @@ class NilType(_CoreType):
def check(self, value): return value is None
def validate(self, value, name='value'):
def validate(self, value):
if value is not None:
raise SchemaTypeMismatch(name, 'null')
raise TypeMismatch(self, value)
class NumType(_CoreType):
@staticmethod
@ -400,25 +571,27 @@ class NumType(_CoreType):
self.range = None
if schema.get('range'):
self.range = Util.make_range_validator(schema['range'])
self.range = Range(schema['range'])
def validate(self, value, name='value'):
def validate(self, value):
if not isinstance(value, Number) or isinstance(value, bool):
raise SchemaTypeMismatch(name, 'number')
raise TypeMismatch(self, value)
if self.range:
self.range(value, name)
if self.range and not self.range(value):
raise RangeMismatch(self, value)
if self.value is not None and value != self.value:
raise SchemaValueMismatch(name, self.value)
raise ValueMismatch(self, value)
class OneType(_CoreType):
@staticmethod
def subname(): return 'one'
def validate(self, value, name='value'):
def validate(self, value):
if not isinstance(value, (Number, str)):
raise SchemaTypeMismatch(name, 'number or string')
raise TypeMismatch(self, value)
class RecType(_CoreType):
@staticmethod
@ -436,7 +609,9 @@ class RecType(_CoreType):
setattr(self, which, {})
for field in schema.get(which, {}).keys():
if field in self.known:
raise SchemaError('%s appears in both required and optional' % field)
raise SchemaError(
'%s appears in both required and optional' % field
)
self.known.add(field)
@ -444,47 +619,53 @@ class RecType(_CoreType):
schema[which][field]
)
def validate(self, value):
    """Validate a record against the required/optional/rest schemas.

    Raises TypeMismatch when ``value`` is not a dict.  Otherwise every
    problem is collected before raising: missing required fields, per-field
    mismatches (keyed by field name), and unknown fields, which are either
    validated as a group by ``self.rest_schema`` or reported as unknown.
    The collected problems are raised through ``_createTreeMismatch``.
    """
    if not isinstance(value, dict):
        raise TypeMismatch(self, value)

    errors = []
    child_errors = {}
    missing_fields = []

    for field in self.required:
        if field not in value:
            missing_fields.append(field)
        else:
            try:
                self.required[field].validate(value[field])
            except SchemaMismatch as e:
                child_errors[field] = e

    if missing_fields:
        errors.append(MissingFieldMismatch(self, missing_fields))

    for field in self.optional:
        if field not in value:
            continue
        try:
            self.optional[field].validate(value[field])
        except SchemaMismatch as e:
            child_errors[field] = e

    unknown = [k for k in value.keys() if k not in self.known]

    if unknown:
        if self.rest_schema:
            rest = {key: value[key] for key in unknown}
            try:
                self.rest_schema.validate(rest)
            except SchemaMismatch as e:
                errors.append(e)
        else:
            # The exception formats the field list itself; no pre-joined
            # string is built here (joining would fail on non-string keys).
            errors.append(UnknownFieldMismatch(self, unknown))

    if errors or child_errors:
        raise _createTreeMismatch(self, errors, child_errors)
class SeqType(_CoreType):
@ -504,34 +685,33 @@ class SeqType(_CoreType):
if (schema.get('tail')):
self.tail_schema = rx.make_schema(schema['tail'])
def validate(self, value, name='value'):
def validate(self, value):
if not isinstance(value, (list, tuple)):
raise SchemaTypeMismatch(name, 'sequence')
raise TypeMismatch(self, value)
if len(value) < len(self.content_schema):
raise SchemaMismatch(name+' is less than expected length')
errors = []
if len(value) > len(self.content_schema) and not self.tail_schema:
raise SchemaMismatch(name+' exceeds expected length')
if len(value) != len(self.content_schema):
if len(value) > len(self.content_schema) and self.tail_schema:
try:
self.tail_schema.validate(value[len(self.content_schema):])
except SchemaMismatch as e:
errors.append(e)
else:
err = SeqLengthMismatch(self, value)
errors.append(err)
error_messages = []
child_errors = {}
for i, (schema, item) in enumerate(zip(self.content_schema, value)):
for index, (schema, item) in enumerate(zip(self.content_schema, value)):
try:
schema.validate(item, 'item '+str(i))
schema.validate(item)
except SchemaMismatch as e:
error_messages.append(str(e))
child_errors[index] = e
if len(error_messages) > 1:
messages = indent('\n'.join(error_messages))
message = '{0} sequence is invalid:\n{1}'
message = message.format(name, messages)
raise SchemaMismatch(message)
elif len(error_messages) == 1:
raise SchemaMismatch(name+': '+error_messages[0])
if errors or child_errors:
raise _createTreeMismatch(self, errors, child_errors)
if len(value) > len(self.content_schema):
self.tail_schema.validate(value[len(self.content_schema):], name)
class StrType(_CoreType):
@staticmethod
@ -549,18 +729,20 @@ class StrType(_CoreType):
self.length = None
if 'length' in schema:
self.length = Util.make_range_validator(schema['length'])
self.length = Range(schema['length'])
def validate(self, value):
    """Validate a string against the optional fixed value and length range.

    Raises TypeMismatch when ``value`` is not a ``str``, ValueMismatch when
    a fixed ``self.value`` is set and differs, and LengthRangeMismatch when
    ``self.length`` is set and rejects ``len(value)``.
    """
    if not isinstance(value, str):
        raise TypeMismatch(self, value)

    if self.value is not None and value != self.value:
        # Report the offending data, not the schema object itself.
        raise ValueMismatch(self, value)

    if self.length and not self.length(len(value)):
        raise LengthRangeMismatch(self, value)
core_types = [
AllType, AnyType, ArrType, BoolType, DefType,
FailType, IntType, MapType, NilType, NumType,
OneType, RecType, SeqType, StrType
]
]

3
patacrep/utils.py

@ -83,8 +83,7 @@ def validate_yaml_schema(data, schema):
Will raise `SBFileError` if the schema is not respected.
"""
rx_checker = Rx.Factory({"register_core_types": True})
schema = rx_checker.make_schema(schema)
schema = Rx.make_schema(schema)
if isinstance(data, DictOfDict):
data = dict(data)

Loading…
Cancel
Save