#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
# Formal
# ======
#
# Copyright 2013 Rob Britton
# Copyright 2015-2019 Heiko 'riot' Weinen <riot@c-base.org> and others.
#
# This file has been changed and this notice has been added in
# accordance to the Apache License
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Changes notice
==============
This file has been changed by the Hackerfleet Community and this notice has
been added in accordance to the Apache License 2.0
"""
from bson import ObjectId
from pymongo import DESCENDING
from .model_base import ModelBase
import formal.database
from .exceptions import InvalidReloadException
from copy import copy
class Model(ModelBase):
    """The MongoDB object model class.

    Persistence layer on top of ``ModelBase``: every method below talks to
    the pymongo collection returned by :meth:`collection`.
    """

    def reload(self):
        """Reload this object's data from the database.

        :raises InvalidReloadException: if no document with this object's
            ``_id`` exists in the collection.
        """
        result = self.__class__.find_by_id(self._id)
        # result is None when this object hasn't been saved to the DB yet,
        # or when it has been deleted since it was fetched.
        if result is not None:
            self._fields = self.cast(result._fields)
        else:
            raise InvalidReloadException(
                "No object in the database with ID %s" % self._id
            )

    def save(self, *args, **kwargs):
        """Validate this object and write it to the database.

        Afterwards ``self._fields["_id"]`` holds the document's id.

        When the collection class still provides the legacy
        ``Collection.save`` (pymongo < 4), it is used and ``*args`` /
        ``**kwargs`` are forwarded to it unchanged. Otherwise the document
        is inserted (no ``_id`` yet) or upserted by ``_id``, with
        ``**kwargs`` forwarded to ``insert_one`` / ``replace_one``.

        :raises TypeError: if positional options are passed but the legacy
            ``Collection.save`` is not available.
        """
        self.validate()
        collection = self.collection()
        # Look the method up on the class: pymongo's Collection.__getattr__
        # returns a sub-collection for unknown attribute names, so hasattr()
        # on the instance would always succeed.
        if hasattr(type(collection), "save"):
            self._fields["_id"] = collection.save(self._fields, *args, **kwargs)
            return
        if args:
            raise TypeError(
                "positional save() options require a pymongo version "
                "providing Collection.save"
            )
        if "_id" in self._fields:
            # Mirrors legacy save(): an existing _id means upsert by _id.
            collection.replace_one(
                {"_id": self._fields["_id"]}, self._fields, upsert=True, **kwargs
            )
        else:
            inserted = collection.insert_one(self._fields, **kwargs)
            self._fields["_id"] = inserted.inserted_id

    def delete(self):
        """Remove this object's document from the database.

        Best effort: any error (for example a missing ``_id`` field) is
        printed instead of raised.
        """
        try:
            # NOTE(review): the id is always coerced to an ObjectId, so
            # documents with non-ObjectId ids cannot be deleted this way.
            self.collection().delete_one({"_id": ObjectId(str(self._fields["_id"]))})
        except Exception as e:
            print("Uh oh: ", e, type(e))

    def serializablefields(self):
        """Return a serializable shallow copy of this object's fields.

        The copy gets an ``id`` entry taken from the schema's ``id``, and
        ``_id`` (if present) is converted to a string.
        """
        result = copy(self._fields)
        result["id"] = self._schema["id"]
        if "_id" in result:
            result["_id"] = str(result["_id"])
        return result

    @classmethod
    def bulk_create(cls, objects, *args, **kwargs):
        """Insert a number of objects with a single bulk insert.

        :param objects: iterable of model instances; their ``_fields`` dicts
            are inserted as-is (pymongo adds a generated ``_id`` to each
            dict lacking one, so the instances see their new ids).
        :returns: list of inserted ids; an empty list for empty input.

        ``*args`` / ``**kwargs`` are accepted for API compatibility and
        ignored.
        """
        docs = [obj._fields for obj in objects]
        if not docs:
            # insert_many() rejects an empty document list.
            return []
        return cls.collection().insert_many(docs).inserted_ids

    @classmethod
    def find_or_create(cls, query, *args, **kwargs):
        """Retrieve a matching element, or build a new one if none exists.

        Equivalent to calling :meth:`find_one` and, on a miss, constructing
        an object from the schema's ``default`` dict updated with ``query``.
        The new object is not saved. This method is not atomic.
        """
        result = cls.find_one(query, *args, **kwargs)
        if result is None:
            # Copy first: updating the schema's dict in place would leak this
            # query into the defaults of every later object of this class.
            default = copy(cls._schema.get("default", {}))
            default.update(query)
            result = cls(default, *args, **kwargs)
        return result

    @classmethod
    def find(cls, *args, **kwargs):
        """Yield model instances for the documents matching a query.

        Positional and remaining keyword arguments go to pymongo's
        ``Collection.find``. The cursor options ``sort``, ``limit``,
        ``skip`` and ``batch_size`` are removed from ``kwargs`` and applied
        to the cursor instead.

        With a positive ``batch_size`` and neither ``skip`` nor ``limit``,
        results are fetched page by page (one skip/limit query per batch).
        Otherwise ``batch_size`` is applied as the cursor's batch size.

        Note: this is a generator, so it can't do an efficient count. Use
        :meth:`count`, which takes the query filter, for that.
        """
        options = {}
        for option in ("sort", "limit", "skip", "batch_size"):
            if option in kwargs:
                # Move cursor options out of the query arguments.
                options[option] = kwargs.pop(option)

        batch_size = options.get("batch_size")
        paged = (
            batch_size is not None
            and batch_size > 0  # limit(0) means "no limit": would never end
            and "skip" not in options
            and "limit" not in options
        )

        if paged:
            # NOTE(review): without a sort, page boundaries depend on the
            # server's natural order and may shift if the collection changes.
            offset = 0
            while True:
                cursor = cls.collection().find(*args, **kwargs)
                cursor = cursor.skip(offset).limit(batch_size)
                if "sort" in options:
                    cursor = cursor.sort(options["sort"])
                fetched = 0
                for doc in cursor:
                    fetched += 1
                    yield cls(doc, from_find=True)
                if fetched < batch_size:
                    # A short (or empty) page means there is nothing left.
                    break
                offset += batch_size
        else:
            cursor = cls.collection().find(*args, **kwargs)
            if "sort" in options:
                cursor = cursor.sort(options["sort"])
            if "skip" in options:
                cursor = cursor.skip(options["skip"])
            if "limit" in options:
                cursor = cursor.limit(options["limit"])
            if batch_size is not None:
                cursor = cursor.batch_size(batch_size)
            for doc in cursor:
                yield cls(doc, from_find=True)

    @classmethod
    def find_by_id(cls, obj_id, **kwargs):
        """Find a single object by its ``_id``.

        :param obj_id: the id; a ``str`` is converted to an ``ObjectId``.
        :returns: the model instance, or ``None`` if nothing matches.
        """
        if isinstance(obj_id, str):
            obj_id = ObjectId(obj_id)
        args = {"_id": obj_id}
        result = cls.collection().find_one(args, **kwargs)
        if result is not None:
            return cls(result, from_find=True)
        return None

    @classmethod
    def find_latest(cls, *args, **kwargs):
        """Return the matching object with the highest ``_id``, or ``None``.

        Any ``limit`` / ``sort`` passed in ``kwargs`` is overridden.
        """
        kwargs["limit"] = 1
        kwargs["sort"] = [("_id", DESCENDING)]
        cursor = cls.collection().find(*args, **kwargs)
        # Take the first document directly; Cursor.count() is deprecated
        # and no longer exists in pymongo 4.
        latest = next(iter(cursor), None)
        if latest is None:
            return None
        return cls(latest, from_find=True)

    @classmethod
    def find_one(cls, *args, **kwargs):
        """Find a single object from this collection, or return ``None``.

        Arguments are passed to pymongo's ``Collection.find_one``.
        """
        result = cls.collection().find_one(*args, **kwargs)
        if result is not None:
            # NOTE(review): unlike find()/find_by_id(), from_find=True is not
            # passed here -- confirm whether that is intentional.
            return cls(result)
        return None

    @classmethod
    def count(cls, object_filter=None):
        """Count the documents matching ``object_filter`` (default: all).

        Not the same as pymongo's ``Cursor.count``; this is the equivalent
        of counting ``collection.find(object_filter)``.
        """
        if object_filter is None:
            object_filter = {}
        collection = cls.collection()
        # Check the collection's class: Collection.__getattr__ would hand
        # back a sub-collection for any name, so an instance hasattr() check
        # can't detect whether count_documents exists (pymongo >= 3.7).
        if hasattr(type(collection), "count_documents"):
            return collection.count_documents(object_filter)
        return collection.count(object_filter)

    @classmethod
    def collection(cls):
        """Get the pymongo collection object for this model.

        Useful for features not supported by formal, like aggregate queries
        and map-reduce.
        """
        return formal.database.get_collection(
            collection=cls.collection_name(), database=cls.database_name()
        )