Source code for formal.model_sqlalchemy

#!/usr/bin/env python3
# -*- coding: UTF-8 -*-

# Formal
# ======
#
# Copyright 2018-2019 Heiko 'riot' Weinen <riot@c-base.org> and others.
#
# This file has been changed and this notice has been added in
# accordance to the Apache License
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
SQL Support for Formal
"""

import re
import sqlalchemy as sql
from deepdiff import DeepDiff
from .model_base import DefaultValidatingDraft4Validator
from .database import sql_database
from jsonschema import validate, Draft4Validator, validators
from jsonschema.exceptions import ValidationError
from copy import copy, deepcopy


[docs]class Model(object): """The SQL object model class"""
[docs] def __init__(self, original_fields=None, from_find=False, *args, **kwargs): """ Creates an instance of the object.""" if original_fields is None: original_fields = {} self._from_find = from_find metadata = sql.MetaData() self._table = sql.Table(self._schema["name"], metadata) for item, value in self._schema["properties"].items(): # print(item) column_type = value["type"].upper() primary = value.get("primary", False) if primary: self._primary = item column = sql.String(64) if column_type == "INTEGER": column = sql.Integer elif column_type == "STRING": length = value.get("length", 64) column = sql.String(length) self._table.append_column(sql.Column(item, column, primary_key=primary)) from .database import sql_database metadata.create_all(sql_database) self._engine = sql_database fields = deepcopy(dict(original_fields)) has_id = False if "_id" in fields: has_id = True del fields["_id"] # populate any default fields for objects that haven't come from the DB if not from_find: DefaultValidatingDraft4Validator(self._schema).validate(fields) # for field, details in self._schema["properties"].items(): # if "default" in details and not field in fields: # fields[field] = details["default"] self._fields = self.cast(fields) self.validate() if has_id: self._fields["_id"] = original_fields["_id"]
[docs] def reload(self): """ Reload this object's data from the DB. """ pass
[docs] def save(self, *args, **kwargs): """ Saves an object to the database. """ self.validate() # print(**self._fields) insert = self._table.insert().values(**self._fields) result = self._engine.execute(insert) return result.inserted_primary_key
[docs] def delete(self): """ Removes an object from the database. """ primary = self._fields[self._primary] query = ( "DELETE FROM {table_name}" " WHERE {table_name}.{primary_key} = {primary}".format( **{ "table_name": self._schema["name"], "primary_key": self._primary, "primary": primary, } ) ) delete = sql.text(query) result = self._engine.execute(delete) return result
[docs] def serializablefields(self): """Return serializable fields of the object""" result = copy(self._fields) result["id"] = self._schema["id"] if "_id" in result: result["_id"] = str(result["_id"]) return result
[docs] @classmethod def bulk_create(cls, objects, *args, **kwargs): """ Create a number of objects (yay performance). """ print("WARNING! NOT IMPLEMENTED: BULK_CREATE") pass
# return cls.collection().insert(docs)
[docs] @classmethod def find_or_create(cls, query, *args, **kwargs): """ Retrieve an element from the database. If it doesn't exist, create it. Calling this method is equivalent to calling find_one and then creating an object. Note that this method is not atomic. """ result = cls.find_one(query, *args, **kwargs) if result is None: default = cls._schema.get("default", {}) default.update(query) result = cls(default, *args, **kwargs) return result
[docs] @classmethod def find(cls, *args, **kwargs): """ Grabs a set of elements from the DB. Note: This returns a generator, so you can't do an efficient count. To get a count, use the count() function which accepts the same arguments as find() with the exception of non-query fields like sort, limit, skip. """ options = {} for option in ["sort", "limit", "skip", "batch_size"]: if option in kwargs: options[option] = kwargs[option] del options[option] if "batch_size" in options and "skip" not in options and "limit" not in options: # run things in batches current_skip = 0 limit = options["batch_size"] found_something = True while found_something: found_something = False result = cls.collection().find(*args, **kwargs) result = result.skip(current_skip).limit(limit) if "sort" in options: result = result.sort(options["sort"]) for obj in result: found_something = True yield cls(obj, from_find=True) current_skip += limit else: result = cls._find(*args, **kwargs) if "sort" in options: result = result.sort(options["sort"]) if "skip" in options: result = result.skip(options["skip"]) if "limit" in options: result = result.limit(options["limit"]) for sql_thing in result.fetchall(): obj = cls._transform_object(sql_thing) yield cls(obj, from_find=True)
@classmethod def _transform_object(cls, thing): return dict(zip(cls._schema["properties"], thing))
[docs] @classmethod def find_by_id(cls, obj_id, **kwargs): """ Finds a single object from this collection. """ if isinstance(obj_id, str): obj_id = obj_id args = {"_id": obj_id} result = cls.collection().find_one(args, **kwargs) if result is not None: return cls(result, from_find=True) return None
[docs] @classmethod def find_latest(cls, *args, **kwargs): """ Finds the latest one by _id and returns it. """ kwargs["limit"] = 1 kwargs["sort"] = (cls._primary, "DESC") result = cls._find(*args, **kwargs) # if result.count() > 0: # return cls(result[0], from_find=True) return result.fetchone()
@classmethod def _find(cls, *args, **kwargs): # pprint(args) # pprint(kwargs) limit = kwargs.get("limit", False) sort = kwargs.get("sort", False) name = cls._schema["name"] query = "SELECT %s FROM %s" % (",".join(cls._schema["properties"].keys()), name) if len(args) > 0 and not args[0] == {}: query += " WHERE" # pprint(args) for item in args: for key, value in item.items(): if isinstance(value, str): value = "'%s'" % value query += " %s.%s = %s AND" % (name, key, value) query = query.rstrip(" AND") if sort is not False: query += " ORDER BY %s %s" % (sort[0], sort[1]) if limit is not False: query += " LIMIT %i" % limit find = sql.text(query) result = cls._engine.execute(find) return result
[docs] @classmethod def find_one(cls, *args, **kwargs): """Finds a single object from this collection.""" result = cls._find(*args, **kwargs).fetchone() # pprint(result) if result is not None: return cls(result) return None
[docs] @classmethod def count(cls, *args, **kwargs): """ Counts the number of items: """ name = cls._schema["name"] query = "SELECT COUNT(*) FROM %s" % name proxy = cls._engine.execute(query).scalar() return int(proxy)
[docs] @classmethod def clear(cls): """Clear a collection""" query = "DELETE FROM {table_name}".format(**{"table_name": cls._schema["name"]}) clear = sql.text(query) result = cls._engine.execute(clear) return result
[docs] @classmethod def collection(cls): """ Get the pymongo collection object for this model. Useful for features not supported by formal like aggregate queries and map-reduce. """ return None
# return formal.database.get_collection( # collection=cls.collection_name(), # database=cls.database_name())
[docs] @classmethod def make_migration(cls, new_schema): """Make migrations for a schema""" delta = DeepDiff(cls._schema, new_schema) return delta
[docs] @classmethod def migrate(cls, patchset): """Migrate an object model with a patchset""" def apply_patch(patch): """Apply a patch to all objects of a model""" def migrate_thing(migration_thing, migration_patch): """Apply all diffs in a patch to an object""" for diff in migration_patch: print( "Would now apply ", diff, migration_patch[diff], "to", migration_thing.name, ) for thing in cls.collection().find_all(): migrate_thing(thing, patch) for key in patchset: print("Applying patch", key) apply_patch(patchset[key])
####################################################################################
[docs] def get(self, field, default=None): """ Get a field if it exists, otherwise return the default. """ return self._fields.get(field, default)
[docs] @classmethod def collection_name(cls): """ Get the collection associated with this class. """ name = cls._schema.get( "collectionName", cls._schema.get("collectionName", cls._schema.get("name", cls.__name__)), ) # convert to snake case name = (name[0] + re.sub("([A-Z])", r"_\1", name[1:])).lower() return name
[docs] @classmethod def database_name(cls): """ Get the database associated with this class. Meant to be overridden in subclasses. """ return cls._schema.get("databaseName", None)
[docs] def to_dict(self): """ Convert the object to a dict. """ return self._fields
[docs] def validate(self): """ Validate `schema` against a dict `obj`. """ # self.validate_field("", self._schema, self._fields) try: pass # TODO: Deep-copying for validation is probably not so good ;) fields = dict(self._fields) if "_id" in fields: # Now remove for schema validation (jsonschema knows nothing # off object ids) del fields["_id"] validate(fields, self._schema) except ValidationError as e: raise ValidationError( "Error:\n" + str(e) + "\nFields:\n" + str(self._fields) )
[docs] def cast(self, fields, schema=None): """ Cast the fields from Mongo into our format - necessary to convert floats into ints since Javascript doesn't support ints. """ if schema is None: schema = self._schema value_type = schema.get("type", "object") if ( value_type == "object" and isinstance(fields, dict) and schema.get("properties") ): result = dict() for key, value in fields.items(): result[key] = self.cast(value, schema["properties"].get(key, {})) return result elif value_type == "array" and isinstance(fields, list) and schema.get("items"): return [self.cast(value, schema["items"]) for value in fields] elif value_type == "integer" and isinstance(fields, float): # The only thing that needs to be casted: floats -> ints return int(fields) elif value_type == "object_id": return str(fields) else: return fields
def __str__(self): """Return string representation of this object model""" return str(self.to_dict()) def __repr__(self): """Return a representation for debugging purposes""" return str(self.name if self.name is not None else self.to_dict()) def __getattr__(self, attr): """ Get an attribute from the fields we've selected. Note that if the field doesn't exist, this will return None. """ if attr in self._schema["properties"] and attr in self._fields: return self._fields.get(attr) else: raise AttributeError("Item has no attribute '%s'" % attr) # if attr.startswith('_'): # return super(ModelBase, self).__getattr__(attr) # # if attr in self._schema["properties"] and attr in self._fields: # #print("Direct hit") # return self._fields.get(attr) # current_schema = self._schema["properties"] # current_fields = self._fields # path = attr # new_attribute = path # # #print("Query path:", path) # #print("Initial Fields:", current_fields) # # while '.' in path: # # new_attribute, path = path.split('.', maxsplit=1) # #print("Looking for intermediate path in ", new_attribute, path) # # if new_attribute in current_schema and new_attribute in current_fields: # current_schema = current_schema[new_attribute]['properties'] # current_fields = current_fields[new_attribute] # else: # raise AttributeError("Item has no intermediate attribute '%s'" # % ( new_attribute)) # # # if new_attribute in current_schema and new_attribute in current_fields: # return current_fields.get(new_attribute) # else: # raise AttributeError("Item has no attribute '%s'" % ( attr)) def __setattr__(self, attr, value): """ Set one of the fields, with validation. Exception is on "private" fields - the ones that start with _. """ if attr.startswith("_"): return object.__setattr__(self, attr, value) if attr in self._schema["properties"]: # Check the field against our schema original = self._fields[attr] self._fields[attr] = value try: self.validate() except ValidationError as e: self._fields[attr] = original raise e elif not self._schema.get("additionalProperties", True): # not allowed to add additional properties raise ValidationError("Additional property '%s' not allowed!" % attr) self._fields[attr] = value return value
[docs] def update(self, new_fields, update_id=False): """Update an object's fields""" # try: if True: for key, value in new_fields.items(): if not key == "_id" or update_id: self.__setattr__(key, value)
# except Exception as e: # raise ValidationError( # "Unknown Validation error: '%s' (%s)" % (e, type(e)))