# -*- coding: utf-8 -*-
import datetime
import inspect
import logging
import re
from collections import OrderedDict
from decimal import Decimal

import dateutil.parser
import nacl.exceptions
import nacl.secret
import nacl.utils
from six import PY2
from six import text_type
from six import with_metaclass
from six.moves import builtins as __builtin__

from chemist.orm import ORM
from chemist.orm import get_engine
from chemist.orm import format_decimal
from chemist.managers import Manager
from chemist.serializers import json
from chemist.exceptions import FieldTypeValueError
from chemist.exceptions import MultipleEnginesSpecified
from chemist.exceptions import EngineNotSpecified
from chemist.exceptions import InvalidColumnName
from chemist.exceptions import InvalidModelDeclaration
logger = logging.getLogger(__name__)
# Tuple of "string" types for isinstance() checks; Model.using() relies on it
# to recognise an engine URI passed as a string.  ``basestring`` only exists
# on Python 2 and is only evaluated there.
string_types = (basestring,) if PY2 else (str,)  # noqa: F821
def try_json_deserialize(value, silent=False):
    """Attempt to decode ``value`` as JSON.

    :param value: the payload handed to ``json.loads``.
    :param silent: when true, do not log a warning if decoding fails.
    :returns: the decoded object, or ``value`` itself (unchanged) when
        ``json.loads`` raised any exception.
    """
    try:
        return json.loads(value)
    except Exception:
        # Deliberately broad: any decoding failure falls back to the raw value.
        if not silent:
            # Lazy %-style argument: the message is only rendered when the
            # WARNING level is actually enabled for this logger.
            logger.warning("could not JSON deserialize value %s", value)
        return value
# Builtin types (plus ``Decimal``) into which Model.serialize_value() may
# coerce a stored value.  Computed once at import time instead of re-scanning
# the builtins module on every call.
_COERCIBLE_TYPES = frozenset(
    member for member in vars(__builtin__).values() if isinstance(member, type)
) | frozenset([Decimal])


class Model(with_metaclass(ORM, object)):
    """Super-class of active record models.

    **Example:**

    ::

        class BlogPost(Model):
            table = db.Table(
                'blog_post',
                metadata,
                db.Column('id', db.Integer, primary_key=True),
                db.Column('title', db.Unicode(200), nullable=False),
                db.Column('slug', db.Unicode(200), nullable=False),
                db.Column('content', db.UnicodeText, nullable=False),
            )

            def preprocess(self, data):
                # always derive slug from title
                data['slug'] = slugify(data['title'])
                return data
    """

    manager = Manager

    @classmethod
    def using(cls, engine=None):
        """Return ``cls.manager(cls, engine)`` for the resolved engine.

        :param engine: ``None`` to use the engine returned by
            ``get_engine()``; a string, passed to ``get_engine`` as ``uri``;
            or an engine object, used as-is.
        """
        if engine is None:
            engine = get_engine()
        elif isinstance(engine, string_types):
            engine = get_engine(uri=engine)
        return cls.manager(cls, engine)

    @classmethod
    def objects(cls):
        """Return a manager bound to the engine returned by ``get_engine()``."""
        return cls.using(None)

    # Class-level shortcuts: each one delegates to the manager returned by
    # ``cls.using(None)``.
    create = classmethod(lambda cls, **data: cls.using(None).create(**data))
    get_or_create = classmethod(
        lambda cls, **data: cls.using(None).get_or_create(**data)
    )
    query_by = classmethod(
        lambda cls, order_by=None, **kw: cls.using(None).query_by(
            order_by=order_by, **kw
        )
    )
    find_one_by = classmethod(lambda cls, **kw: cls.using(None).find_one_by(**kw))
    find_by = classmethod(lambda cls, **kw: cls.using(None).find_by(**kw))
    all = classmethod(lambda cls, **kw: cls.using(None).all(**kw))
    total_rows = classmethod(lambda cls, **kw: cls.using(None).total_rows(**kw))
    get_connection = classmethod(lambda cls, **kw: cls.using(None).get_connection())
    many_from_query = classmethod(
        lambda cls, query: cls.using(None).many_from_query(query)
    )
    one_from_query = classmethod(
        lambda cls, query: cls.using(None).one_from_query(query)
    )
    where_many = classmethod(
        lambda cls, *args, **kw: cls.using(None).where_many(*args, **kw)
    )
    where_one = classmethod(
        lambda cls, *args, **kw: cls.using(None).where_one(*args, **kw)
    )

    def __init__(self, engine=None, **data):
        """Build a model instance from keyword arguments named after its columns.

        The new instance is ready to be persisted in the database.
        DO NOT overwrite the __init__ method of your custom model.  There are
        two customization points at construction time:

        * Implement ``preprocess(self, data)``.  It receives the dictionary of
          keyword arguments (after decryption of encrypted columns) and must
          return a dictionary with the "post-processed" data.  That returned
          dictionary is what gets assigned to the instance.
        * Implement ``initialize(self)``.  It is always called at the end of
          the constructor.

        :param engine: engine stored on the instance for later I/O.
        :raises InvalidModelDeclaration: if ``preprocess`` does not return a
            dict.
        :raises InvalidColumnName: for a key that is not a declared column.
        """
        model_class = self.__class__
        module = model_class.__module__
        name = model_class.__name__
        columns = self.__columns__
        for key, value in data.items():
            data[key] = self.decrypt_attribute(key, value)

        preprocessed_data = self.preprocess(data)
        if not isinstance(preprocessed_data, dict):
            raise InvalidModelDeclaration(
                "The model `{0}` declares a preprocess method but "
                "it does not return a dictionary!".format(name)
            )
        self.__data__ = preprocessed_data
        self.engine = engine

        # Iterate what preprocess() returned, not the raw kwargs.  Otherwise a
        # preprocess() that returns a *new* dict would have its values
        # clobbered by the unprocessed ones.
        for k, v in list(preprocessed_data.items()):
            if k not in columns:
                msg = "{0} is not a valid column name for the model {2}.{1} ({3})"
                raise InvalidColumnName(
                    msg.format(k, name, module, sorted(columns.keys()))
                )
            if callable(v):
                v = v()
            setattr(self, k, v)
        self.initialize()

    def __repr__(self):
        return "<{0} {1}={2}>".format(
            self.__class__.__name__, self.get_pk_name(), self.get_pk_value()
        )

    def preprocess(self, data):
        """Placeholder for your own custom preprocess method.

        Remember that it must return a dictionary.

        ::

            class BlogPost(Model):
                table = db.Table(
                    'blog_post',
                    metadata,
                    db.Column('id', db.Integer, primary_key=True),
                    db.Column('title', db.Unicode(200), nullable=False),
                    db.Column('slug', db.Unicode(200), nullable=False),
                    db.Column('content', db.UnicodeText, nullable=False),
                )

                def preprocess(self, data):
                    # always derive slug from title
                    data['slug'] = slugify(data['title'])
                    return data
        """
        return data

    def get_encryption_box_for_attribute(self, attr):
        """Return a ``nacl.secret.SecretBox`` for column ``attr``, or ``None``.

        Keys are read from the optional ``encryption`` attribute, a mapping of
        column name to secret key.  Returns ``None`` when that attribute is
        missing or empty, or has no entry for ``attr``.
        """
        keymap = dict(getattr(self, "encryption", None) or {})
        if attr not in keymap:
            return None
        return nacl.secret.SecretBox(keymap[attr])

    def encrypt_attribute(self, attr, value):
        """Encrypt ``value`` if column ``attr`` is configured for encryption.

        ``None`` and values of non-encrypted columns are returned unchanged.
        Non-bytes values are converted to text and encoded as UTF-8, because
        ``SecretBox.encrypt`` only accepts bytes.  A fresh random nonce is
        used for every call.
        """
        box = self.get_encryption_box_for_attribute(attr)
        if box is None or value is None:
            return value
        if isinstance(value, bytes):
            plaintext = value
        else:
            plaintext = text_type(value).encode("utf-8")
        nonce = nacl.utils.random(nacl.secret.SecretBox.NONCE_SIZE)
        return box.encrypt(plaintext, nonce)

    def decrypt_attribute(self, attr, value):
        """Decrypt ``value`` if column ``attr`` is configured for encryption.

        Values that cannot be decrypted are returned unchanged.  This covers
        ``None``, plaintext and non-bytes input.  It lets the method be
        applied to values that were never encrypted, or were already
        decrypted.  When the column's declared type is a text type, the
        decrypted bytes are decoded as UTF-8, the encoding used by
        :py:meth:`encrypt_attribute`.
        """
        box = self.get_encryption_box_for_attribute(attr)
        if box is None or value is None:
            return value
        try:
            plaintext = box.decrypt(value)
        except (ValueError, TypeError, nacl.exceptions.CryptoError):
            # Short input raises ValueError, non-bytes input raises TypeError,
            # and input that fails authentication raises CryptoError.
            return value
        kind = self.__columns__.get(attr, None)
        if isinstance(kind, type) and issubclass(kind, text_type):
            return plaintext.decode("utf-8")
        return plaintext

    def serialize_value(self, attr, value):
        """Convert the stored value of column ``attr`` into its serialized form.

        * ``None`` is replaced by the column's default, if it has one.
          Callable defaults are invoked.
        * ``Decimal`` values go through ``format_decimal``.
        * ``datetime``, ``date`` and ``time`` values become ``isoformat()``
          strings.
        * Any other falsy value is returned unchanged.
        * Otherwise, if the column's declared Python type is a builtin type
          (or ``Decimal``) and the value is not already an instance of it,
          the value is coerced by calling that type.

        :raises FieldTypeValueError: when that coercion fails.
        """
        column = self.table.columns[attr]
        default = column.default
        # Only a missing value (None) falls back to the default.  Falsy values
        # such as False, 0 or "" are legitimate and must be kept.
        if default is not None and value is None:
            value = default.arg(value) if default.is_callable else default.arg

        if isinstance(value, Decimal):
            return format_decimal(value)
        if isinstance(value, (datetime.datetime, datetime.date, datetime.time)):
            return value.isoformat()
        if not value:
            return value

        data_type = self.__columns__.get(attr, None)
        if (
            data_type
            and not isinstance(value, data_type)
            and data_type in _COERCIBLE_TYPES
        ):
            try:
                return data_type(value)
            except (TypeError, ValueError) as e:
                raise FieldTypeValueError(self, attr, e)
        return value

    def deserialize_value(self, attr, value):
        """Turn an incoming value for column ``attr`` into its in-memory form.

        The value is first passed through :py:meth:`decrypt_attribute`.  For
        ``datetime`` and ``date`` columns, a truthy value that is not already
        of the column's type is parsed with ``dateutil.parser.parse``.
        ``date`` columns get the ``.date()`` part of the parsed result.
        """
        value = self.decrypt_attribute(attr, value)
        kind = self.__columns__.get(attr, None)
        if not value or not isinstance(kind, type):
            return value
        if issubclass(kind, (datetime.datetime, datetime.date)) and not isinstance(
            value, kind
        ):
            parsed = dateutil.parser.parse(value)
            if issubclass(kind, datetime.datetime):
                return parsed
            return parsed.date()
        return value

    def __setattr__(self, attr, value):
        # Column assignments are stored (deserialized) in __data__; any other
        # attribute is a regular instance attribute.
        if attr in self.__columns__:
            self.__data__[attr] = self.deserialize_value(attr, value)
            return
        return super(Model, self).__setattr__(attr, value)

    def to_dict(self):
        """Pre-serialize the model, returning a dictionary of key-values.

        This method can be overwritten by subclasses at will.

        **Example:**

        ::

            >>> post = BlogPost.create(title='Some Title', content='lorem ipsum')
            >>> post.to_dict()
            {
                'id': 1,
                'title': 'Some Title',
                'slug': 'some-title',
            }
        """
        return self.serialize()

    def serialize(self):
        """Pre-serialize the model, returning a dictionary of key-values.

        This method is used by ``to_dict()``.  It exists as a separate method
        so that subclasses overwriting ``to_dict`` can call ``serialize()``
        instead of ``super(SubclassName, self).to_dict()``.
        """
        return {
            key: self.serialize_value(key, self.__data__.get(key))
            for key in self.__columns__.keys()
        }

    def to_insert_params(self):
        """Return the serialized (and encrypted) column values, minus primary keys.

        Used internally by :py:meth:`save`.

        **Example:**

        ::

            >>> post = BlogPost.create(title='Some Title', content='lorem ipsum')
            >>> post.to_insert_params()
            {
                'title': 'Some Title',
                'slug': 'some-title',
            }
        """
        # Call the base implementation explicitly so that subclass overrides
        # of serialize() do not leak into what gets written to the database.
        pre_data = Model.serialize(self)
        data = OrderedDict()
        for k, v in pre_data.items():
            data[k] = self.encrypt_attribute(k, v)
        primary_key_names = [x.name for x in self.table.primary_key.columns]
        keys_to_pluck = [
            key for key in data.keys() if key not in self.__columns__
        ] + primary_key_names
        # Not saving primary keys: let the SQL backend take care of auto
        # increment.  If manual primary key definition is ever needed, change
        # this code and its tests.
        for key in keys_to_pluck:
            data.pop(key, None)
        return data

    def to_json(self, indent=None, sort_keys=True, **kw):
        """Serialize the dictionary returned by ``to_dict`` to JSON."""
        data = self.to_dict()
        return json.dumps(data, indent=indent, sort_keys=sort_keys, **kw)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails.  Column values live
        # in __data__ and are returned serialized.  Anything else must raise
        # AttributeError so that hasattr()/getattr(obj, name, default) behave.
        if attr in self.__columns__:
            return self.serialize_value(attr, self.__data__.get(attr, None))
        raise AttributeError(
            "{0!r} object has no attribute {1!r}".format(type(self).__name__, attr)
        )

    def delete(self):
        """Delete the current model from the database.

        Removes the row with this model's primary key value.  The DELETE runs
        inside an explicit transaction.  The transaction is rolled back on
        error, and the connection is always closed.

        :returns: the result object of the executed DELETE statement.
        """
        self.pre_delete()
        conn = self.get_engine().connect()
        transaction = conn.begin()
        try:
            result = conn.execute(
                self.table.delete().where(
                    self.get_pk_col(self.get_pk_name()) == self.get_pk_value()
                )
            )
            transaction.commit()
        except Exception:
            if transaction.is_active:
                transaction.rollback()
            raise
        finally:
            conn.close()
        self.post_delete()
        return result

    def pre_delete(self):
        """Called right before executing a deletion.

        This method can be overwritten by subclasses in order to take any
        domain-related action.
        """

    def post_delete(self):
        """Called right after executing a deletion.

        This method can be overwritten by subclasses in order to take any
        domain-related action.
        """

    @property
    def is_persisted(self):
        """``True`` if the primary key holds a non-``None`` value.

        This mirrors the check :py:meth:`save` uses to decide between INSERT
        and UPDATE.  This property **does not perform I/O against the
        database**.
        """
        return self.__data__.get(self.get_pk_name()) is not None

    def get_engine(self, input_engine=None):
        """Return the engine to use: the instance's own or ``input_engine``.

        :raises EngineNotSpecified: if neither is set.
        :raises MultipleEnginesSpecified: if both are set.
        """
        if not self.engine and not input_engine:
            raise EngineNotSpecified(
                "You must specify a SQLAlchemy engine object in order to "
                "do operations in this model instance: {0}".format(self)
            )
        elif self.engine and input_engine:
            raise MultipleEnginesSpecified(
                "This model instance has a SQLAlchemy engine object already. "
                "You may not save it to another engine."
            )
        return self.engine or input_engine

    def save(self, input_engine=None):
        """Persist the model instance in the DB.

        Issues an INSERT when the primary key is ``None``, otherwise an UPDATE
        of the row with that primary key.  The statement and the refresh of
        local values from the bound parameters run in one transaction.  On
        any error the failure is logged, the transaction is rolled back and
        the exception is re-raised.  The connection is always closed.

        :returns: ``self``.
        """
        self.pre_save()
        engine = self.get_engine(input_engine)
        conn = engine.connect()
        transaction = conn.begin()
        primary_key_column_name = self.get_pk_name()
        mid = self.__data__.get(primary_key_column_name, None)
        try:
            if mid is None:
                values = self.to_insert_params()
                res = conn.execute(self.table.insert().values(**values))
                self.set(**{primary_key_column_name: res.inserted_primary_key[0]})
                self.set(**dict(res.last_inserted_params()))
            else:
                res = conn.execute(
                    self.table.update()
                    .values(**self.to_insert_params())
                    .where(self.get_pk_col(primary_key_column_name) == mid)
                )
                # Keep only real column names.  Any other key is a WHERE-clause
                # bind parameter (e.g. ``id_1``) whose value is the primary key
                # we already hold.
                columns = self.__columns__
                self.set(
                    **{
                        k: v
                        for k, v in res.last_updated_params().items()
                        if k in columns
                    }
                )
            transaction.commit()
        except Exception:
            logger.error("failed for %s", engine)
            if transaction.is_active:
                transaction.rollback()
            raise
        finally:
            conn.close()
        self.post_save(transaction)
        return self

    def pre_save(self):
        """Called right before executing a save.

        This method can be overwritten by subclasses in order to take any
        domain-related action.
        """

    def post_save(self, transaction):
        """Called right after executing a save.

        This method can be overwritten by subclasses in order to take any
        domain-related action.
        """

    def refresh(self):
        """Reload this record's values via :py:meth:`find_one_by`.

        The current record is updated with the fresh values, and the freshly
        fetched instance is returned.

        .. note:: any unsaved changes in the model will be lost upon
           calling this method.
        """
        params = {self.get_pk_name(): self.get_pk_value()}
        new = self.find_one_by(**params)
        self.set(**new.__data__)
        return new

    def set(self, **kw):
        """Set multiple fields; does **not** perform a save operation.

        Non-column keys that look like a bind-parameter name derived from the
        primary key (``<pk>_<digits>``, e.g. ``id_1``) are ignored.

        :raises InvalidColumnName: for any other key that is not a column.
        :returns: ``self``, so calls can be chained.
        """
        columns = self.__columns__
        pk_param = re.compile(r"^{0}_\d+$".format(re.escape(str(self.get_pk_name()))))
        for name, value in kw.items():
            if name not in columns:
                if pk_param.match(name):
                    continue
                raise InvalidColumnName("{0}.{1}".format(self, name))
            # __setattr__ stores the deserialized (decrypted / parsed) value in
            # __data__.  It must not be overwritten with the raw value afterwards.
            setattr(self, name, value)
        return self

    def update_and_save(self, **kw):
        """Set multiple fields, then save them."""
        updated = self.set(**kw)
        return updated.save()

    def get(self, name, fallback=None):
        """Get a field's raw stored value, or ``fallback`` if it is unset."""
        return self.__data__.get(name, fallback)

    def initialize(self):
        """Dummy method to be optionally overwritten in subclasses.

        Gets automatically called once a model instance is constructed.
        """

    def __eq__(self, other):
        """Compare two models.

        Instances of the same class that share a primary key name and both
        have truthy primary key values are equal iff those values are equal.
        Otherwise, all stored values except the primary key are compared.
        Non-Model operands return ``NotImplemented``.
        """
        if not isinstance(other, Model):
            return NotImplemented
        pk_name = self.get_pk_name()
        matches_pk = (
            type(self) is type(other)
            and pk_name == other.get_pk_name()
            and self.get_pk_value()
            and other.get_pk_value()
        )
        if matches_pk:
            return self.get_pk_value() == other.get_pk_value()
        keys = set(self.__data__.keys()) | set(other.__data__.keys())
        return all(
            self.__data__.get(key) == other.__data__.get(key)
            for key in keys
            if key != pk_name
        )

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    @classmethod
    def get_pk_name(cls):
        """Return the name of the first primary-key column, or ``None``."""
        for name, col in cls.table.c.items():
            if col.primary_key:
                return name
        return None

    def get_pk_value(self):
        """Return the (serialized) value of the primary-key column."""
        return getattr(self, self.get_pk_name())

    @classmethod
    def get_pk_col(cls, name):
        """Return the column object named ``name`` from ``cls.table.c``."""
        return getattr(cls.table.c, name)